From 0aa32217b145f7401dfb9601f9cf34266656ee24 Mon Sep 17 00:00:00 2001 From: rambros Date: Mon, 9 Mar 2026 12:43:30 +0530 Subject: [PATCH] use local cache to avoid redundant api calls --- README.md | 6 +- src/core/discord_reader.py | 52 ++++++++++++--- src/fluxer/writer.py | 21 ++++-- src/stoat/clone_server.py | 3 +- src/stoat/writer.py | 71 ++++++++++++++------ src/ui/main_app.py | 6 +- src/ui/shuttle_ops.py | 128 ++++++++++++++++++++++++++++++++----- 7 files changed, 234 insertions(+), 53 deletions(-) diff --git a/README.md b/README.md index 9f93c0e..7081932 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ ![Disco Reaper](images/fluxer-reaper.jpg) ### Modern Terminal Interface -The tool now features a unified, intuitive TUI (Terminal User Interface) - no more commands +The tool now features a unified, intuitive TUI (Terminal User Interface) - no more text commands | Features | Fluxer | Stoat | | :--- | :---: | :---: | @@ -21,7 +21,7 @@ The tool now features a unified, intuitive TUI (Terminal User Interface) - no mo | - Roles Cloning | 🟩 | 🟩 | | - Roles Permissions | 🟩 | 🟩 | | - Category Permissions | 🟩 | ⚠️ | -| - Channel Permissions | 🟩 | ⏳to be done | +| - Channel Permissions | 🟩 | ⏳ | | **Emojis & Stickers** | | | | - Copy Emojis | 🟩 | 🟩 | | - Copy Stickers | 🟩 | ⚠️ | @@ -39,7 +39,7 @@ The tool now features a unified, intuitive TUI (Terminal User Interface) - no mo - ⚠️**Fluxer/Stoat**: Threads & Forums type channels are not yet natively available. As a workaround, threads are migrated in their parent channels as normal messages. - ⚠️**Stoat**: doesn't have features like Category Permissions or Slowmode settings for channels. -- ⏳**Stoat**: permission sync for channels is pending due to architectural differences. +- ⏳**Stoat**: permission sync for channels was not implemented due to architectural differences. --- diff --git a/src/core/discord_reader.py b/src/core/discord_reader.py index fcda75a..1dff7d0 100644 --- a/src/core/discord_reader.py +++ b/src/core/discord_reader.py @@ -63,6 +63,12 @@ class DiscordReader: self.client: discord.Client | None = None self.role_map: Dict[int, str] = {} self.channel_name_map: Dict[int, str] = {} + # Session-level caches to avoid redundant fetch calls + self._roles_cache: list[discord.Role] | None = None + self._channels_cache: list[discord.abc.GuildChannel] | None = None + self._categories_cache: list[discord.CategoryChannel] | None = None + self._emojis_cache: list[discord.Emoji] | None = None + self._stickers_cache: list[discord.GuildSticker] | None = None def _create_client(self): intents = discord.Intents.default() @@ -92,6 +98,7 @@ class DiscordReader: try: roles = await self.guild.fetch_roles() self.role_map = {r.id: r.name for r in roles} + self._roles_cache = [r for r in roles if not r.is_default()] except discord.Forbidden: logger.warning("403 Forbidden: Missing Access to fetch roles. Continuing without role mapping.") self.role_map = {} @@ -103,6 +110,8 @@ class DiscordReader: try: channels = await self.guild.fetch_channels() self.channel_name_map = {c.id: c.name for c in channels} + self._channels_cache = [c for c in channels if not isinstance(c, discord.CategoryChannel)] + self._categories_cache = [c for c in channels if isinstance(c, discord.CategoryChannel)] logger.debug(f"Pre-fetched {len(self.channel_name_map)} channels") except discord.Forbidden: logger.warning("403 Forbidden: Missing Access to fetch channels. Continuing without channel name mapping.") @@ -171,28 +180,40 @@ class DiscordReader: async def get_categories(self): if not self.guild: return [] + if self._categories_cache is not None: + return self._categories_cache categories = await self.guild.fetch_channels() - return [c for c in categories if isinstance(c, discord.CategoryChannel)] + self._categories_cache = [c for c in categories if isinstance(c, discord.CategoryChannel)] + return self._categories_cache async def get_roles(self): """Returns all roles in the server (excluding @everyone).""" if not self.guild: return [] + if self._roles_cache is not None: + return self._roles_cache roles = await self.guild.fetch_roles() # Filter out default @everyone role which cannot typically be created - return [r for r in roles if not r.is_default()] + self._roles_cache = [r for r in roles if not r.is_default()] + return self._roles_cache async def get_emojis(self): """Returns all custom emojis in the server.""" if not self.guild: return [] - return await self.guild.fetch_emojis() + if self._emojis_cache is not None: + return self._emojis_cache + self._emojis_cache = await self.guild.fetch_emojis() + return self._emojis_cache async def get_stickers(self): """Returns all custom stickers in the server.""" if not self.guild: return [] - return await self.guild.fetch_stickers() + if self._stickers_cache is not None: + return self._stickers_cache + self._stickers_cache = await self.guild.fetch_stickers() + return self._stickers_cache async def get_members(self): """Returns all members in the server.""" @@ -208,8 +229,12 @@ class DiscordReader: """Yields all non-category channels.""" if not self.guild: return [] - channels = await self.guild.fetch_channels() - all_channels = [c for c in channels if not isinstance(c, discord.CategoryChannel)] + + if self._channels_cache is None: + channels = await self.guild.fetch_channels() + self._channels_cache = [c for c in channels if not isinstance(c, discord.CategoryChannel)] + + all_channels = self._channels_cache if category_id: all_channels = [c for c in all_channels if c.category_id == category_id] return all_channels @@ -256,5 +281,16 @@ class DiscordReader: return await attachment.read() async def close(self): - if self.client: - await self.client.close() + client = self.client + self.client = None # Atomic clear + self.guild = None + self._roles_cache = None + self._channels_cache = None + self._categories_cache = None + self._emojis_cache = None + self._stickers_cache = None + if client: + try: + await client.close() + except Exception as e: + logger.debug(f"Error closing Discord client: {e}") diff --git a/src/fluxer/writer.py b/src/fluxer/writer.py index e1a9eb0..b1c5864 100644 --- a/src/fluxer/writer.py +++ b/src/fluxer/writer.py @@ -15,6 +15,7 @@ class FluxerWriter: self._bot_task: Optional[asyncio.Task] = None self._ready_event = asyncio.Event() self._webhooks: Dict[str, Webhook] = {} # channel_id -> Webhook + self._channels_cache: List[Dict[str, Any]] | None = None @staticmethod async def fetch_guilds(token: str, api_url: str = "default") -> list[tuple[str, str]]: @@ -243,8 +244,11 @@ class FluxerWriter: async def get_channels(self) -> List[Dict[str, Any]]: """Returns all channels in the community.""" + if self._channels_cache is not None: + return self._channels_cache assert self.client is not None - return await self.client.get_guild_channels(self.community_id) + self._channels_cache = await self.client.get_guild_channels(self.community_id) + return self._channels_cache async def send_message(self, channel_id: str, author_name: str, content: str, timestamp: int, author_avatar_url: Optional[str] = None, files: Optional[List[Dict[str, Any]]] = None, reply_to_message_id: Optional[str] = None, is_forwarded: bool = False, embeds: Optional[List[Dict[str, Any]]] = None) -> Optional[str]: """ @@ -660,16 +664,23 @@ class FluxerWriter: async def close(self): """Cleanly close connection and stop bot task.""" - if self.bot: + bot = self.bot + self.bot = None # Atomic clear + self._channels_cache = None + self._webhooks.clear() + + if bot: try: - await self.bot.close() + await bot.close() except Exception as e: logger.debug(f"Error closing Fluxer bot: {e}") if self._bot_task: - self._bot_task.cancel() + task = self._bot_task + self._bot_task = None + task.cancel() try: - await self._bot_task + await task except asyncio.CancelledError: pass self._ready_event.clear() diff --git a/src/stoat/clone_server.py b/src/stoat/clone_server.py index a7fb6cc..22667f4 100644 --- a/src/stoat/clone_server.py +++ b/src/stoat/clone_server.py @@ -184,7 +184,8 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl # 4. Final step: Parent the channels into categories via mass server.edit() logger.info("Parenting all channels into their respective categories...") - server = await context.writer._get_server(populate_channels=True) + # Force refetch to ensure we see all newly created categories from the loop above + server = await context.writer._get_server(populate_channels=True, force=True) cats = list(server.categories) if hasattr(server, "categories") and server.categories else [] # Workaround: Ensure default properties are set for all categories diff --git a/src/stoat/writer.py b/src/stoat/writer.py index d6db9bf..07a37af 100644 --- a/src/stoat/writer.py +++ b/src/stoat/writer.py @@ -10,6 +10,10 @@ class StoatWriter: self.token = token self.community_id = str(community_id) self.api_url = api_url + self.client: Optional[stoat.Client] = None + self._server = None + self._me = None + self._validation_cache = None @staticmethod async def fetch_guilds(token: str, api_url: str = "default") -> list[tuple[str, str]]: @@ -67,30 +71,44 @@ class StoatWriter: return guilds_list async def start(self): + if self.client: + # Check if client is actually usable (not half-closed) + try: + if self.client and not self.client.is_closed: + return + except Exception: + pass + self.client = None + client_kwargs = {"token": self.token, "bot": True} if self.api_url and self.api_url != "default": client_kwargs["http_base"] = self.api_url self.client = stoat.Client(**client_kwargs) - self._server = None - self._me = None try: self._me = await self.client.fetch_user("@me") except Exception as e: logger.error(f"Failed to fetch bot user in StoatWriter: {e}") + self.client = None # Reset if we can't even fetch @me @property def my_id(self): return str(self._me.id) if self._me else None - async def _get_server(self, populate_channels=False): - # Always refetch if channels are requested to ensure we have them + async def _get_server(self, populate_channels=False, force=False): + # Always refetch if channels are requested AND we don't already have them + # Or if force is True (e.g. after category creation/mutation) # Stoat Server objects use __slots__, so we can't easily add our own tracking attributes. - if not self._server or populate_channels: - self._server = await self.client.fetch_server(self.community_id, populate_channels=populate_channels) + if force or (populate_channels and (not self._server or not hasattr(self._server, "channels") or not self._server.channels)): + self._server = await self.client.fetch_server(self.community_id, populate_channels=True) + elif not self._server: + self._server = await self.client.fetch_server(self.community_id, populate_channels=False) return self._server async def validate(self) -> dict: + if self._validation_cache: + return self._validation_cache + results = { "token": False, "community": False, @@ -105,16 +123,21 @@ class StoatWriter: } } - # Use a temporary client for validation - client_kwargs = {"token": self.token, "bot": True} - if self.api_url and self.api_url != "default": - client_kwargs["http_base"] = self.api_url + # Ensure client is started + if not self.client: + await self.start() + + client = self.client + assert client is not None - client = stoat.Client(**client_kwargs) try: # Validate token by fetching current user try: - current_user = await client.fetch_user("@me") + # Reuse self._me if already fetched during start() + if not self._me: + self._me = await client.fetch_user("@me") + + current_user = self._me results["token"] = True results["bot_name"] = current_user.display_name or current_user.name except stoat.Unauthorized: @@ -128,12 +151,8 @@ class StoatWriter: results["community_name"] = server.name # Check permissions using effective server permissions for the bot - # Use current_user.id since @me might not be supported in all member endpoints try: me = await server.fetch_member(current_user.id) - # We use server.permissions_for(me) instead of me.server_permissions - # to avoid cache-related NoData exceptions. - # safe=False allows calculating even if some roles aren't in local cache. perms = server.permissions_for(me, safe=False) results["permissions"] = { @@ -166,9 +185,8 @@ class StoatWriter: except Exception as e: logger.error(f"Stoat validation failed: {str(e)}") - finally: - await client.close() + self._validation_cache = results return results async def get_channels(self) -> List[Dict[str, Any]]: @@ -217,6 +235,9 @@ class StoatWriter: try: if type == 4: # Category # The POST /categories endpoint throws 404 on some server versions, so we use server.edit(categories) + # Force refetch to ensure we have the absolute latest state before editing categories array + server = await self._get_server(populate_channels=True, force=True) + import random import time chars = "0123456789ABCDEFGHJKMNPQRSTVWXYZ" @@ -234,7 +255,8 @@ class StoatWriter: if not hasattr(new_cat, "role_permissions"): new_cat.role_permissions = {} categories.append(new_cat) - await server.edit(categories=categories) + # server.edit returns a new Server object on some versions/implementations; maintain local reference + self._server = await server.edit(categories=categories) return new_id else: # Text Channel ch = await server.create_text_channel(name=name, description=topic) @@ -676,4 +698,13 @@ class StoatWriter: return {"emojis": count, "stickers": 0} async def close(self): - pass + client = self.client + self.client = None # Atomic clear to prevent new usage + self._me = None + self._server = None + if client: + try: + await client.close() + except Exception as e: + logger.debug(f"Error closing Stoat client: {e}") + self._validation_cache = None diff --git a/src/ui/main_app.py b/src/ui/main_app.py index 94d5060..c6c377f 100644 --- a/src/ui/main_app.py +++ b/src/ui/main_app.py @@ -473,8 +473,12 @@ class ConfigScreen(Screen): # ────────────────────────────────────────────────────────────────────────────── class ReaperApp(App): + SCREENS = { + "config_selection": ConfigSelectionScreen, + } + def on_mount(self) -> None: - self.push_screen(ConfigSelectionScreen()) + self.push_screen("config_selection") self.theme = "dracula" def action_screenshot(self, filename: str | None = None, path: str | None = None) -> None: diff --git a/src/ui/shuttle_ops.py b/src/ui/shuttle_ops.py index f2171e0..d9d59dc 100644 --- a/src/ui/shuttle_ops.py +++ b/src/ui/shuttle_ops.py @@ -138,6 +138,7 @@ class ShuttlePane(Container): def on_mount(self) -> None: self._rebuild_engine() + # run_validate is handled by the writer's internal caching now to prevent log flooding self.run_validate() def reload_config(self) -> None: @@ -386,6 +387,9 @@ class ShuttlePane(Container): try: await self.engine.start_connections() connections_started = True + # Sync all entities before preview/confirmation + modal.set_status("Synchronizing entity mappings...") + await self._perform_auto_matching() except Exception as e: logger.warning(f"Could not pre-connect for Clone preview: {e}") @@ -394,7 +398,7 @@ class ShuttlePane(Container): modal.set_status(f"Awaiting Confirmation for {len(selections)} Operations...") - # Fetch and display live preview with presence highlighting + # Fetch and display live preview auto-matching already ran above preview = await self._fetch_clone_preview(selections) if connections_started else {} if connections_started: @@ -503,6 +507,7 @@ class ShuttlePane(Container): modal.phase_report("Batch Operation", "error", show_back=False) finally: self.engine.is_running = False + # Ensure we only close if we actually started them and no other task is inheriting await self.engine.close_connections() @work(exclusive=True) @@ -775,8 +780,12 @@ class ShuttlePane(Container): # Show info container modal.show_info("[bold cyan]Message Migration Ready[/bold cyan]", "Checking channel permissions...") - modal.set_status("Fetching channels...") + modal.set_status("Connecting to Servers...") await self.engine.start_connections() + + # Sync all entities before confirmation + modal.set_status("Synchronizing entity mappings...") + await self._perform_auto_matching() full_d = await self.engine.discord_reader.get_channels() @@ -1052,6 +1061,10 @@ class ShuttlePane(Container): try: await self.engine.start_target_only() target_started = True + + # Sync all entities before confirmation (even in danger zone) + modal.set_status("Synchronizing entity mappings...") + await self._perform_auto_matching() except Exception as e: logger.warning(f"Could not pre-connect for DZ preview: {e}") @@ -1207,48 +1220,133 @@ class ShuttlePane(Container): async def _fetch_clone_preview(self, selections: list[str]) -> dict[str, Any]: """Fetches preview data from Discord (source server) for cloning confirmation, comparing with existing entities on the target server for presence highlighting.""" - preview = {} + async def _perform_auto_matching(self): + """Matches Discord entities (roles, channels, emojis, stickers) with target platform items by name.""" + if not self.engine: + return + reader = self.engine.discord_reader writer = self.engine.writer is_fluxer = self.target_platform == "fluxer" - # Fetch target data for comparison - target_roles = [] - target_channels = [] + # 1. Fetch target data for comparison + target_roles_map = {} + target_chans_map = {} + target_cats_map = {} + target_emojis_map = {} + target_stickers_map = {} try: if is_fluxer: target_roles_raw = await writer.client.get_guild_roles(self.engine.config.target_server_id) - target_roles = [r.get("name", "").lower() for r in target_roles_raw] + target_roles_map = {r.get("name", "").lower(): str(r.get("id")) for r in target_roles_raw} + + target_emojis_raw = await writer.client.get_guild_emojis(self.engine.config.target_server_id) + target_emojis_map = {e.get("name", "").lower(): str(e.get("id")) for e in target_emojis_raw} + + try: + target_stickers_raw = await writer.client.get_guild_stickers(self.engine.config.target_server_id) + target_stickers_map = {s.get("name", "").lower(): str(s.get("id")) for s in target_stickers_raw} + except Exception: + pass else: server = await writer._get_server() - target_roles = [r.name.lower() for r in server.roles.values()] + target_roles_map = {r.name.lower(): str(r.id) for r in server.roles.values()} + + target_emojis_raw = await server.fetch_emojis() + target_emojis_map = {e.name.lower(): str(e.id) for e in target_emojis_raw} target_chans_raw = await writer.get_channels() - target_channels = [c.get("name", "").lower() for c in target_chans_raw] + target_chans_map = {c.get("name", "").lower(): str(c.get("id")) for c in target_chans_raw if c.get("type") != 4} + target_cats_map = {c.get("name", "").lower(): str(c.get("id")) for c in target_chans_raw if c.get("type") == 4} except Exception as e: - logger.warning(f"Clone Preview: failed to fetch target data for comparison: {e}") + logger.warning(f"Auto-matching: failed to fetch target data: {e}") + return # Cannot match without target data + + # 2. Match entities + try: + # Roles + src_roles = await reader.get_roles() + for r in src_roles: + name_l = r.name.lower() + if name_l in target_roles_map and not self.engine.state.get_target_role_id(r.id): + logger.info(f"Auto-matched Role: {r.name} -> {target_roles_map[name_l]}") + self.engine.state.set_target_role_mapping(r.id, target_roles_map[name_l]) + + # Categories + src_cats = await reader.get_categories() + for cat in src_cats: + name_l = cat.name.lower() + if name_l in target_cats_map and not self.engine.state.get_target_category_id(cat.id): + logger.info(f"Auto-matched Category: {cat.name} -> {target_cats_map[name_l]}") + self.engine.state.set_target_category_mapping(cat.id, target_cats_map[name_l]) + + # Channels + src_channels = await reader.get_channels() + for ch in src_channels: + name_l = ch.name.lower() + if name_l in target_chans_map and not self.engine.state.get_target_channel_id(ch.id): + logger.info(f"Auto-matched Channel: {ch.name} -> {target_chans_map[name_l]}") + self.engine.state.set_target_channel_mapping(ch.id, target_chans_map[name_l]) + + # Emojis + src_emojis = await reader.get_emojis() + for e in src_emojis: + name_l = e.name.lower() + if name_l in target_emojis_map and not self.engine.state.get_target_emoji_id(e.id): + logger.info(f"Auto-matched Emoji: {e.name} -> {target_emojis_map[name_l]}") + self.engine.state.set_target_emoji_mapping(e.id, target_emojis_map[name_l]) + + # Stickers + if is_fluxer: + src_stickers = await reader.get_stickers() + for s in src_stickers: + name_l = s.name.lower() + if name_l in target_stickers_map and not self.engine.state.get_target_sticker_id(s.id): + logger.info(f"Auto-matched Sticker: {s.name} -> {target_stickers_map[name_l]}") + self.engine.state.set_target_sticker_mapping(s.id, target_stickers_map[name_l]) + except Exception as e: + logger.warning(f"Auto-matching error: {e}") + + return { + "target_roles": target_roles_map, + "target_channels": target_chans_map, + "target_categories": target_cats_map, + "target_emojis": target_emojis_map, + "target_stickers": target_stickers_map + } + + async def _fetch_clone_preview(self, selections: list[str]) -> dict[str, Any]: + """Fetches preview data from Discord (source server) for cloning confirmation, + comparing with existing mappings in state-migration.json for presence highlighting.""" + preview = {} + reader = self.engine.discord_reader + + # We rely on the global auto-match that ran during connection + mapping_ch = self.engine.state.channel_map + mapping_cat = self.engine.state.category_map + mapping_role = self.engine.state.role_map try: if "sub_clone_roles" in selections: roles = await reader.get_roles() - preview["roles"] = [(r.name, r.name.lower() in target_roles) for r in roles] + # Highlight if existing in mapping + preview["roles"] = [(r.name, str(r.id) in mapping_role) for r in roles] except Exception as e: logger.warning(f"Clone Preview: failed to fetch roles: {e}") try: if "sub_clone_channels" in selections: - # Build hierarchy src_categories = await reader.get_categories() src_channels = await reader.get_channels() - # structure[cat_id] = (cat_name, cat_exists, [(ch_name, ch_exists), ...]) + # Build hierarchy for preview structure = {} for cat in src_categories: - cat_exists = cat.name.lower() in target_channels + cat_exists = str(cat.id) in mapping_cat structure[cat.id] = (cat.name, cat_exists, []) for ch in src_channels: - ch_exists = ch.name.lower() in target_channels + ch_exists = str(ch.id) in mapping_ch if ch.category_id in structure: structure[ch.category_id][2].append((ch.name, ch_exists)) else: