implement thread mentions

This commit is contained in:
rambros 2026-03-27 11:22:07 +05:30
parent 4967183442
commit 939e063d46
5 changed files with 50 additions and 11 deletions

View file

@ -945,11 +945,23 @@ class BackupGuild:
return next((c for c in self._reader._channels if c.id == parse_snowflake(channel_id)), None)
return None
def get_thread(self, thread_id: int) -> "BackupChannel | None":
def get_thread(self, thread_id: int) -> "BackupThread | None":
"""Mock discord.Guild.get_thread."""
if self._reader:
return next((c for c in self._reader._threads if c.id == parse_snowflake(thread_id)), None)
return next((t for t in self._reader._threads if t.id == parse_snowflake(thread_id)), None)
return None
async def fetch_channels(self) -> List[Union["BackupChannel", "BackupCategory"]]:
"""Async stub for discord.Guild.fetch_channels."""
if self._reader:
self._reader._ensure_structure_loaded()
return list(self._reader._channels) + list(self._reader._categories)
return []
async def fetch_active_threads(self) -> List["BackupThread"]:
"""Async stub for discord.Guild.active_threads helper (mirrors the property)."""
return self.active_threads
def __repr__(self) -> str:
return f"BackupGuild(id={self.id}, name='{self.name}')"
@ -1237,6 +1249,14 @@ class BackupReader:
channels = [c for c in channels if c.category_id == category_id]
return channels
async def fetch_channels(self) -> List[Union[BackupChannel, BackupCategory]]:
"""Uniform interface for all channels (including categories)."""
return list(self.channels) + list(self.categories)
async def get_active_threads(self) -> List[BackupThread]:
"""Uniform interface for active threads."""
return [t for t in self.threads if not t.archived]
async def get_backed_up_channel_ids(self) -> List[int]:
"""Returns a list of channel IDs that have messages in the database."""
if not self.db: return []

View file

@ -20,6 +20,7 @@ class DiscordReader:
# Channel Types
CHANNEL_TYPE_TEXT = discord.ChannelType.text
CHANNEL_TYPE_VOICE = discord.ChannelType.voice
CHANNEL_TYPE_NEWS = discord.ChannelType.news
CHANNEL_TYPE_FORUM = discord.ChannelType.forum
@ -228,6 +229,18 @@ class DiscordReader:
all_channels = [c for c in all_channels if c.category_id == category_id]
return all_channels
async def get_active_threads(self) -> list[discord.Thread]:
"""Returns all active threads in the guild."""
if not self.guild:
return []
return self.guild.active_threads
async def fetch_channels(self) -> list[Union[discord.TextChannel, discord.VoiceChannel, discord.CategoryChannel, discord.ForumChannel]]:
"""Async stub for discord.Guild.fetch_channels."""
if not self.guild:
return []
return await self.guild.fetch_channels()
async def get_channel(self, channel_id: int):
"""Returns a channel object."""
return await self.client.fetch_channel(channel_id)

View file

@ -224,9 +224,9 @@ class MigrationState:
row = conn.execute("SELECT channel_id, target_msg_id FROM message_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone()
if row:
return str(row["channel_id"]), str(row["target_msg_id"])
row = conn.execute("SELECT thread_id, target_msg_id FROM thread_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone()
row = conn.execute("SELECT channel_id, target_msg_id FROM thread_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone()
if row:
return str(row["thread_id"]), str(row["target_msg_id"])
return str(row["channel_id"]), str(row["target_msg_id"])
return None, None
# --- Danger Zone Clearing ---

View file

@ -83,7 +83,10 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
# 3. Try live lookup (fallback)
if not name:
try:
channel = guild.get_channel(cid) or guild.get_thread(cid)
except Exception:
channel = None
if not channel and channel_mentions:
channel = next((c for c in channel_mentions if c.id == cid), None)
if channel:
@ -261,12 +264,12 @@ async def migrate_messages(
try:
logger.debug(f"Pre-fetching channel and thread names for guild {context.discord_reader.guild.id}...")
# fetch_channels usually includes all non-thread channels
all_channels = await context.discord_reader.guild.fetch_channels()
all_channels = await context.discord_reader.fetch_channels()
for c in all_channels:
context.channel_names[str(c.id)] = c.name
# active_threads helps find threads that might be mentioned
threads = await context.discord_reader.guild.active_threads()
# get_active_threads helps find threads that might be mentioned
threads = await context.discord_reader.get_active_threads()
for t in threads:
context.channel_names[str(t.id)] = t.name

View file

@ -79,7 +79,10 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
# 3. Try live lookup (fallback)
if not name:
try:
channel = guild.get_channel(cid) or guild.get_thread(cid)
except Exception:
channel = None
if not channel and channel_mentions:
channel = next((c for c in channel_mentions if c.id == cid), None)
if channel:
@ -260,11 +263,11 @@ async def migrate_messages(
context.channel_names = {}
try:
logger.debug(f"Pre-fetching channel and thread names for guild {context.discord_reader.guild.id}...")
all_channels = await context.discord_reader.guild.fetch_channels()
all_channels = await context.discord_reader.fetch_channels()
for c in all_channels:
context.channel_names[str(c.id)] = c.name
threads = await context.discord_reader.guild.active_threads()
threads = await context.discord_reader.get_active_threads()
for t in threads:
context.channel_names[str(t.id)] = t.name