diff --git a/src/core/base.py b/src/core/base.py index 815481a..398d5af 100644 --- a/src/core/base.py +++ b/src/core/base.py @@ -16,6 +16,7 @@ class MigrationContext: def __init__(self, config: AppConfig, target_platform: str | None = None, source_mode: str = "live", base_dir: str = ""): self.config = config self.source_mode = source_mode + self.base_dir = base_dir # If caller didn't specify, fall back to config value self.target_platform = target_platform or config.target_platform or "fluxer" self.state = MigrationState() @@ -76,7 +77,7 @@ class MigrationContext: try: d_valid = await self.discord_reader.validate() t_valid = await self.writer.validate() - return { + results = { "discord_token": d_valid.get("token", False), "discord_bot_name": d_valid.get("bot_name"), "discord_server": d_valid.get("server", False), @@ -92,11 +93,10 @@ class MigrationContext: # CONSISTENCY: Once target metadata is known, initialize the flat SQLite DB. if results["target_community"] and results["target_community_name"]: - import re - clean_name = re.sub(r'[^\w\s-]', '', results["target_community_name"]).strip() - clean_name = re.sub(r'[-\s]+', '_', clean_name) - db_community_id = str(self.config.target_server_id or "") - self.state.set_folder(db_community_id, clean_name, base_dir=base_dir) + self.ensure_state_initialized( + str(self.config.target_server_id or ""), + results["target_community_name"] + ) return results except Exception as e: @@ -108,6 +108,23 @@ class MigrationContext: "target_community": False } + def ensure_state_initialized(self, community_id: str, community_name: str): + """Ensures the MigrationState database is initialized with the correct folder naming.""" + if not community_id or not community_name: + return + + import re + clean_name = re.sub(r'[^\w\s-]', '', community_name).strip() + clean_name = re.sub(r'[-\s]+', '_', clean_name) + + # Determine base directory (same logic as used in _find_backup_path) + # We assume the caller might have provided a base_dir in __init__ + # but for state we usually want it in the same place as backups + # or a logical subfolder. + base_dir = getattr(self, "base_dir", "") + + self.state.set_folder(community_id, clean_name, base_dir=base_dir) + async def start_connections(self): await self.discord_reader.start() await self.writer.start() diff --git a/src/ui/shuttle_ops.py b/src/ui/shuttle_ops.py index 23bfd51..a57bf06 100644 --- a/src/ui/shuttle_ops.py +++ b/src/ui/shuttle_ops.py @@ -1088,6 +1088,9 @@ class OperationPane(Container): tgt_server_info = await self.engine.writer.validate() tgt_server_name = tgt_server_info.get("community_name", "target community") + # ENSURE INITIALIZED for mapping lookup in analyze/migrate + self.engine.ensure_state_initialized(str(self.engine.config.target_server_id), tgt_server_name) + if src_server: modal.write(f"[bold cyan]Source Server Profile:[/bold cyan]") modal.write(f" Name: [green]{src_server.name}[/green]") @@ -1579,6 +1582,15 @@ class OperationPane(Container): target_emojis_raw = await writer.client.get_guild_emojis(self.engine.config.target_server_id) target_emojis_map = {e.get("name", "").lower(): str(e.get("id")) for e in target_emojis_raw} + # RE-INITIALIZE STATE if we found community info + # This ensures mapping persistence even if validate_all was skipped + try: + community_info = await writer.client.get_guild(self.engine.config.target_server_id) + if community_info: + self.engine.ensure_state_initialized(str(self.engine.config.target_server_id), community_info.get("name", "Target")) + except Exception: + pass + try: target_stickers_raw = await writer.client.get_guild_stickers(self.engine.config.target_server_id) target_stickers_map = {s.get("name", "").lower(): str(s.get("id")) for s in target_stickers_raw} @@ -1590,6 +1602,9 @@ class OperationPane(Container): target_emojis_raw = await server.fetch_emojis() target_emojis_map = {e.name.lower(): str(e.id) for e in target_emojis_raw} + + # RE-INITIALIZE STATE + self.engine.ensure_state_initialized(str(self.engine.config.target_server_id), server.name) target_chans_raw = await writer.get_channels() target_chans_map = {c.get("name", "").lower(): str(c.get("id")) for c in target_chans_raw if c.get("type") != 4}