diff --git a/src/core/backup_reader.py b/src/core/backup_reader.py index 9a23218..3218898 100644 --- a/src/core/backup_reader.py +++ b/src/core/backup_reader.py @@ -945,11 +945,23 @@ class BackupGuild: return next((c for c in self._reader._channels if c.id == parse_snowflake(channel_id)), None) return None - def get_thread(self, thread_id: int) -> "BackupChannel | None": + def get_thread(self, thread_id: int) -> "BackupThread | None": + """Mock discord.Guild.get_thread.""" if self._reader: - return next((c for c in self._reader._threads if c.id == parse_snowflake(thread_id)), None) + return next((t for t in self._reader._threads if t.id == parse_snowflake(thread_id)), None) return None + async def fetch_channels(self) -> List[Union["BackupChannel", "BackupCategory"]]: + """Async stub for discord.Guild.fetch_channels.""" + if self._reader: + self._reader._ensure_structure_loaded() + return list(self._reader._channels) + list(self._reader._categories) + return [] + + async def fetch_active_threads(self) -> List["BackupThread"]: + """Async stub for discord.Guild.active_threads helper (mirrors the property).""" + return self.active_threads + def __repr__(self) -> str: return f"BackupGuild(id={self.id}, name='{self.name}')" @@ -1237,6 +1249,14 @@ class BackupReader: channels = [c for c in channels if c.category_id == category_id] return channels + async def fetch_channels(self) -> List[Union[BackupChannel, BackupCategory]]: + """Uniform interface for all channels (including categories).""" + return list(self.channels) + list(self.categories) + + async def get_active_threads(self) -> List[BackupThread]: + """Uniform interface for active threads.""" + return [t for t in self.threads if not t.archived] + async def get_backed_up_channel_ids(self) -> List[int]: """Returns a list of channel IDs that have messages in the database.""" if not self.db: return [] diff --git a/src/core/discord_reader.py b/src/core/discord_reader.py index 245434f..e70c52b 100644 --- a/src/core/discord_reader.py +++ b/src/core/discord_reader.py @@ -20,6 +20,7 @@ class DiscordReader: # Channel Types CHANNEL_TYPE_TEXT = discord.ChannelType.text + CHANNEL_TYPE_VOICE = discord.ChannelType.voice CHANNEL_TYPE_NEWS = discord.ChannelType.news CHANNEL_TYPE_FORUM = discord.ChannelType.forum @@ -228,6 +229,18 @@ class DiscordReader: all_channels = [c for c in all_channels if c.category_id == category_id] return all_channels + async def get_active_threads(self) -> list[discord.Thread]: + """Returns all active threads in the guild.""" + if not self.guild: + return [] + return self.guild.active_threads + + async def fetch_channels(self) -> list[Union[discord.TextChannel, discord.VoiceChannel, discord.CategoryChannel, discord.ForumChannel]]: + """Async stub for discord.Guild.fetch_channels.""" + if not self.guild: + return [] + return await self.guild.fetch_channels() + async def get_channel(self, channel_id: int): """Returns a channel object.""" return await self.client.fetch_channel(channel_id) diff --git a/src/core/state.py b/src/core/state.py index d50869f..0e8d181 100644 --- a/src/core/state.py +++ b/src/core/state.py @@ -224,9 +224,9 @@ class MigrationState: row = conn.execute("SELECT channel_id, target_msg_id FROM message_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone() if row: return str(row["channel_id"]), str(row["target_msg_id"]) - row = conn.execute("SELECT thread_id, target_msg_id FROM thread_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone() + row = conn.execute("SELECT channel_id, target_msg_id FROM thread_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone() if row: - return str(row["thread_id"]), str(row["target_msg_id"]) + return str(row["channel_id"]), str(row["target_msg_id"]) return None, None # --- Danger Zone Clearing --- diff --git a/src/fluxer/migrate_message.py b/src/fluxer/migrate_message.py index 92c5159..dc644cb 100644 --- a/src/fluxer/migrate_message.py +++ b/src/fluxer/migrate_message.py @@ -83,7 +83,10 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None, # 3. Try live lookup (fallback) if not name: - channel = guild.get_channel(cid) or guild.get_thread(cid) + try: + channel = guild.get_channel(cid) or guild.get_thread(cid) + except Exception: + channel = None if not channel and channel_mentions: channel = next((c for c in channel_mentions if c.id == cid), None) if channel: @@ -261,12 +264,12 @@ async def migrate_messages( try: logger.debug(f"Pre-fetching channel and thread names for guild {context.discord_reader.guild.id}...") # fetch_channels usually includes all non-thread channels - all_channels = await context.discord_reader.guild.fetch_channels() + all_channels = await context.discord_reader.fetch_channels() for c in all_channels: context.channel_names[str(c.id)] = c.name - # active_threads helps find threads that might be mentioned - threads = await context.discord_reader.guild.active_threads() + # get_active_threads helps find threads that might be mentioned + threads = await context.discord_reader.get_active_threads() for t in threads: context.channel_names[str(t.id)] = t.name diff --git a/src/stoat/migrate_message.py b/src/stoat/migrate_message.py index 229e568..c872fde 100644 --- a/src/stoat/migrate_message.py +++ b/src/stoat/migrate_message.py @@ -79,7 +79,10 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None, # 3. Try live lookup (fallback) if not name: - channel = guild.get_channel(cid) or guild.get_thread(cid) + try: + channel = guild.get_channel(cid) or guild.get_thread(cid) + except Exception: + channel = None if not channel and channel_mentions: channel = next((c for c in channel_mentions if c.id == cid), None) if channel: @@ -260,11 +263,11 @@ async def migrate_messages( context.channel_names = {} try: logger.debug(f"Pre-fetching channel and thread names for guild {context.discord_reader.guild.id}...") - all_channels = await context.discord_reader.guild.fetch_channels() + all_channels = await context.discord_reader.fetch_channels() for c in all_channels: context.channel_names[str(c.id)] = c.name - threads = await context.discord_reader.guild.active_threads() + threads = await context.discord_reader.get_active_threads() for t in threads: context.channel_names[str(t.id)] = t.name