diff --git a/src/core/backup_reader.py b/src/core/backup_reader.py index b85c5a2..ff889b3 100644 --- a/src/core/backup_reader.py +++ b/src/core/backup_reader.py @@ -696,7 +696,15 @@ class BackupReader: self.backup_path = Path(backup_path) self.guild: BackupGuild | None = None - # Internal caches populated by start() + self._thread_info: Dict[int, Dict[str, Any]] = {} # channel_id -> metadata (like name, parentID) + + # Lazy loading flags + self._roles_loaded = False + self._structure_loaded = False + self._assets_loaded = False + self._members_loaded = False + + # Internal storage self._categories: List[BackupCategory] = [] self._channels: List[BackupChannel] = [] self._roles: List[BackupRole] = [] @@ -704,12 +712,11 @@ class BackupReader: self._stickers: List[BackupSticker] = [] self._members: List[BackupMember] = [] self._member_map: Dict[int, BackupMember] = {} - self._thread_info: Dict[int, Dict[str, Any]] = {} # channel_id -> metadata (like name, parentID) # ── startup ────────────────────────────────────────────────────────── async def start(self): - """Loads all JSON files from the backup directory into memory.""" + """Initializes the backup path and loads the server profile.""" bp = self.backup_path # 1. Server profile -> BackupGuild @@ -717,60 +724,95 @@ class BackupReader: if profile_file.exists(): profile = json.loads(profile_file.read_text(encoding="utf-8")) self.guild = BackupGuild(profile, bp, reader=self) - logger.info(f"[Backup] Loaded server profile: {self.guild.name} ({self.guild.id})") + logger.info(f"[Backup] Initialized server: {self.guild.name} ({self.guild.id})") else: logger.warning(f"[Backup] server_profile/profile.json not found in {bp}") self.guild = None - # 2. Roles - roles_file = bp / "server_profile" / "roles.json" - if roles_file.exists(): - roles_data = json.loads(roles_file.read_text(encoding="utf-8")) - self._roles = [BackupRole(r) for r in roles_data] - logger.info(f"[Backup] Loaded {len(self._roles)} roles") + @property + def roles(self) -> List[BackupRole]: + if not self._roles_loaded: + roles_file = self.backup_path / "server_profile" / "roles.json" + if roles_file.exists(): + logger.info(f"[Backup] Lazy-loading roles...") + roles_data = json.loads(roles_file.read_text(encoding="utf-8")) + self._roles = [BackupRole(r) for r in roles_data] + self._roles_loaded = True + return self._roles - # 3. Structure -> categories + channels - struct_file = bp / "server_profile" / "structure.json" + @property + def categories(self) -> List[BackupCategory]: + self._ensure_structure_loaded() + return self._categories + + @property + def channels(self) -> List[BackupChannel]: + self._ensure_structure_loaded() + return self._channels + + def _ensure_structure_loaded(self): + if self._structure_loaded: + return + struct_file = self.backup_path / "server_profile" / "structure.json" if struct_file.exists(): + logger.info(f"[Backup] Lazy-loading server structure...") structure = json.loads(struct_file.read_text(encoding="utf-8")) for cat_data in structure: cat = BackupCategory(cat_data) - if cat.id != 0: # skip 'uncategorized' as a real category + if cat.id != 0: self._categories.append(cat) - for ch_data in cat_data.get("channels", []): ch_cat_id = cat.id if cat.id != 0 else None channel = BackupChannel(ch_data, category_id=ch_cat_id, guild=self.guild) self._channels.append(channel) + self._structure_loaded = True - logger.info(f"[Backup] Loaded {len(self._categories)} categories, " - f"{len(self._channels)} channels") + @property + def emojis(self) -> List[BackupEmoji]: + self._ensure_assets_loaded() + return self._emojis - # 4. Assets (emojis + stickers) - assets_file = bp / "server_profile" / "assets.json" - media_dir = bp / "server_profile" / "assets" + @property + def stickers(self) -> List[BackupSticker]: + self._ensure_assets_loaded() + return self._stickers + + def _ensure_assets_loaded(self): + if self._assets_loaded: + return + assets_file = self.backup_path / "server_profile" / "assets.json" + media_dir = self.backup_path / "server_profile" / "assets" if assets_file.exists(): + logger.info(f"[Backup] Lazy-loading assets...") assets = json.loads(assets_file.read_text(encoding="utf-8")) self._emojis = [BackupEmoji(e, media_dir) for e in assets.get("emojis", [])] self._stickers = [BackupSticker(s, media_dir) for s in assets.get("stickers", [])] - logger.info(f"[Backup] Loaded {len(self._emojis)} emojis, " - f"{len(self._stickers)} stickers") + self._assets_loaded = True - # 5. Users - user_info_file = bp / "message_backup" / "users" / "user_info.json" + @property + def members(self) -> List[BackupMember]: + self._ensure_members_loaded() + return self._members + + def _ensure_members_loaded(self): + if self._members_loaded: + return + user_info_file = self.backup_path / "message_backup" / "users" / "user_info.json" if user_info_file.exists(): + logger.info(f"[Backup] Lazy-loading members...") try: users = json.loads(user_info_file.read_text(encoding="utf-8")) - backup_root = bp / "message_backup" + backup_root = self.backup_path / "message_backup" for u in users: user_role_ids = {int(r["id"]) for r in u.get("userRoles", [])} - role_objs = [r for r in self._roles if r.id in user_role_ids] + # Note: this triggers roles lazy load + role_objs = [r for r in self.roles if r.id in user_role_ids] member = BackupMember(u, role_objects=role_objs, avatar_base=backup_root) self._members.append(member) self._member_map[member.id] = member - logger.info(f"[Backup] Loaded {len(self._members)} users") except Exception as e: - logger.warning(f"[Backup] Failed to load user_info.json: {e}") + logger.warning(f"[Backup] Failed to lazy-load user_info.json: {e}") + self._members_loaded = True # ── validation ─────────────────────────────────────────────────────── diff --git a/src/core/base.py b/src/core/base.py index 1fcf0d0..815481a 100644 --- a/src/core/base.py +++ b/src/core/base.py @@ -89,6 +89,16 @@ class MigrationContext: "target_community_name": t_valid.get("community_name"), "target_permissions": t_valid.get("permissions", {}) } + + # CONSISTENCY: Once target metadata is known, initialize the flat SQLite DB. + if results["target_community"] and results["target_community_name"]: + import re + clean_name = re.sub(r'[^\w\s-]', '', results["target_community_name"]).strip() + clean_name = re.sub(r'[-\s]+', '_', clean_name) + db_community_id = str(self.config.target_server_id or "") + self.state.set_folder(db_community_id, clean_name, base_dir=base_dir) + + return results except Exception as e: logger.error(f"Validation failed with exception: {e}") return { diff --git a/src/core/database.py b/src/core/database.py new file mode 100644 index 0000000..6cb2ba6 --- /dev/null +++ b/src/core/database.py @@ -0,0 +1,237 @@ +import sqlite3 +import logging +from pathlib import Path +from typing import Optional, Dict, Any +import threading + +logger = logging.getLogger(__name__) + +class MigrationDatabase: + """ + SQLite-based persistence for large-scale migration mappings and stats. + Replaces the memory-bloated and O(N^2) JSON persistence for messages. + """ + + _local = threading.local() + + def __init__(self, db_path: Path): + self.db_path = db_path + self._init_db() + + def _get_conn(self) -> sqlite3.Connection: + if not hasattr(self._local, "conn"): + self._local.conn = sqlite3.connect(self.db_path, check_same_thread=False) + self._local.conn.row_factory = sqlite3.Row + return self._local.conn + + def _init_db(self): + """Initialize tables if they don't exist.""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + # Table for message mappings: SourceID -> TargetID + cursor.execute(""" + CREATE TABLE IF NOT EXISTS message_mappings ( + channel_id TEXT, + source_msg_id TEXT, + target_msg_id TEXT, + timestamp TEXT, + PRIMARY KEY (channel_id, source_msg_id) + ) + """) + + # Table for thread mappings + cursor.execute(""" + CREATE TABLE IF NOT EXISTS thread_mappings ( + channel_id TEXT, + thread_id TEXT, + source_msg_id TEXT, + target_msg_id TEXT, + timestamp TEXT, + PRIMARY KEY (channel_id, thread_id, source_msg_id) + ) + """) + + # Table for per-channel stats and tracking + cursor.execute(""" + CREATE TABLE IF NOT EXISTS channel_tracking ( + channel_id TEXT PRIMARY KEY, + last_msg_id TEXT, + last_msg_ts TEXT, + msg_count INTEGER DEFAULT 0, + file_count INTEGER DEFAULT 0 + ) + """) + + # Table for per-thread stats + cursor.execute(""" + CREATE TABLE IF NOT EXISTS thread_tracking ( + channel_id TEXT, + thread_id TEXT, + last_msg_id TEXT, + last_msg_ts TEXT, + msg_count INTEGER DEFAULT 0, + file_count INTEGER DEFAULT 0, + PRIMARY KEY (channel_id, thread_id) + ) + """) + + # Table for entity mappings (channels, roles, etc.) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS entity_mappings ( + category TEXT, + source_id TEXT, + target_id TEXT, + PRIMARY KEY (category, source_id) + ) + """) + + # Table for general metadata + cursor.execute(""" + CREATE TABLE IF NOT EXISTS metadata ( + key TEXT PRIMARY KEY, + value TEXT + ) + """) + + conn.commit() + conn.close() + + def set_message_mapping(self, channel_id: str, source_id: str, target_id: str, timestamp: str = None): + conn = self._get_conn() + conn.execute( + "INSERT OR REPLACE INTO message_mappings (channel_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?)", + (channel_id, source_id, target_id, timestamp) + ) + conn.commit() + + def get_target_message_id(self, channel_id: str, source_id: str) -> Optional[str]: + conn = self._get_conn() + row = conn.execute( + "SELECT target_msg_id FROM message_mappings WHERE channel_id = ? AND source_msg_id = ?", + (channel_id, source_id) + ).fetchone() + return row["target_msg_id"] if row else None + + # --- New Entity Mapping Methods --- + + def set_entity_mapping(self, category: str, source_id: str, target_id: str): + conn = self._get_conn() + conn.execute( + "INSERT OR REPLACE INTO entity_mappings (category, source_id, target_id) VALUES (?, ?, ?)", + (category, str(source_id), str(target_id)) + ) + conn.commit() + + def get_entity_mapping(self, category: str, source_id: str) -> Optional[str]: + conn = self._get_conn() + row = conn.execute( + "SELECT target_id FROM entity_mappings WHERE category = ? AND source_id = ?", + (category, str(source_id)) + ).fetchone() + return row["target_id"] if row else None + + def get_all_entity_mappings(self, category: str) -> Dict[str, str]: + conn = self._get_conn() + rows = conn.execute( + "SELECT source_id, target_id FROM entity_mappings WHERE category = ?", + (category,) + ).fetchall() + return {row["source_id"]: row["target_id"] for row in rows} + + def delete_entity_mapping(self, category: str, source_id: str): + conn = self._get_conn() + conn.execute( + "DELETE FROM entity_mappings WHERE category = ? AND source_id = ?", + (category, str(source_id)) + ) + conn.commit() + + def clear_entities(self, category: str = None): + conn = self._get_conn() + if category: + conn.execute("DELETE FROM entity_mappings WHERE category = ?", (category,)) + else: + conn.execute("DELETE FROM entity_mappings") + conn.commit() + + # --- Metadata Methods --- + + def set_metadata(self, key: str, value: str): + conn = self._get_conn() + conn.execute("INSERT OR REPLACE INTO metadata (key, value) VALUES (?, ?)", (key, str(value) if value is not None else None)) + conn.commit() + + def get_metadata(self, key: str) -> Optional[str]: + conn = self._get_conn() + row = conn.execute("SELECT value FROM metadata WHERE key = ?", (key,)).fetchone() + return row["value"] if row else None + + def update_channel_tracking(self, channel_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0): + conn = self._get_conn() + # Initialize if missing + conn.execute("INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)", (channel_id,)) + + if last_msg_id: + conn.execute("UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?", (last_msg_id, channel_id)) + if last_msg_ts: + conn.execute("UPDATE channel_tracking SET last_msg_ts = ? WHERE channel_id = ?", (last_msg_ts, channel_id)) + + if msg_inc != 0 or file_inc != 0: + conn.execute( + "UPDATE channel_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ?", + (msg_inc, file_inc, channel_id) + ) + conn.commit() + + def get_channel_tracking(self, channel_id: str) -> Dict[str, Any]: + conn = self._get_conn() + row = conn.execute("SELECT * FROM channel_tracking WHERE channel_id = ?", (channel_id,)).fetchone() + if row: + return dict(row) + return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} + + # Thread methods similar to channel methods + def set_thread_message_mapping(self, channel_id: str, thread_id: str, source_id: str, target_id: str, timestamp: str = None): + conn = self._get_conn() + conn.execute( + "INSERT OR REPLACE INTO thread_mappings (channel_id, thread_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?, ?)", + (channel_id, thread_id, source_id, target_id, timestamp) + ) + conn.commit() + + def get_target_thread_message_id(self, channel_id: str, thread_id: str, source_id: str) -> Optional[str]: + conn = self._get_conn() + row = conn.execute( + "SELECT target_msg_id FROM thread_mappings WHERE channel_id = ? AND thread_id = ? AND source_msg_id = ?", + (channel_id, thread_id, source_id) + ).fetchone() + return row["target_msg_id"] if row else None + + def update_thread_tracking(self, channel_id: str, thread_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0): + conn = self._get_conn() + conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (channel_id, thread_id)) + + if last_msg_id: + conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_id, channel_id, thread_id)) + if last_msg_ts: + conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, channel_id, thread_id)) + + if msg_inc != 0 or file_inc != 0: + conn.execute( + "UPDATE thread_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ? AND thread_id = ?", + (msg_inc, file_inc, channel_id, thread_id) + ) + conn.commit() + + def get_thread_tracking(self, channel_id: str, thread_id: str) -> Dict[str, Any]: + conn = self._get_conn() + row = conn.execute("SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?", (channel_id, thread_id)).fetchone() + if row: + return dict(row) + return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} + + def close(self): + if hasattr(self._local, "conn"): + self._local.conn.close() + del self._local.conn diff --git a/src/core/state.py b/src/core/state.py index 0923180..08666c1 100644 --- a/src/core/state.py +++ b/src/core/state.py @@ -1,366 +1,248 @@ import json import logging from pathlib import Path -from typing import Dict, Any +from typing import Dict, Any, Optional logger = logging.getLogger(__name__) class MigrationState: - """Manages persistence of the migration state to allow resumability.""" + """Manages persistence of the migration state to allow resumability. + Uses SQLite for ALL mappings and metadata. + """ - def __init__(self, state_file: str | Path = "", messages_file: str | Path = ""): - self.state_file: Path | None = Path(state_file) if state_file else None - self.messages_file: Path | None = Path(messages_file) if messages_file else None + def __init__(self): + # database instance for all persistence + self.db: Optional['MigrationDatabase'] = None - # mappings: discord_id -> fluxer_id - self.channel_map: Dict[str, str] = {} - self.category_map: Dict[str, str] = {} - self.role_map: Dict[str, str] = {} - self.emoji_map: Dict[str, str] = {} - self.sticker_map: Dict[str, str] = {} - - # audit log tracking - self.audit_log_channel: str | None = None - - # message tracking per target channel - # Format: { target_channel_id: {"message_map": {}, "last_message_id": "", "last_message_timestamp": ""} } - self.channel_messages: Dict[str, Dict[str, Any]] = {} - - self.load() + def _ensure_db(self): + if not self.db: + logger.warning("MigrationState: Accessing database before initialization") + return False + return True - def load(self): - migrated_state = False - migrated_messages = False + # --- Type Specific Getters/Setters (Database Backed) --- - # 1. Load primary state file - if self.state_file and self.state_file.exists(): - with open(self.state_file, "r", encoding="utf-8") as f: - data = json.load(f) - self.channel_map = data.get("channels", {}) - self.category_map = data.get("categories", {}) - self.role_map = data.get("roles", {}) - self.emoji_map = data.get("emojis", {}) - self.sticker_map = data.get("stickers", {}) - self.audit_log_channel = data.get("audit_log_channel") + def set_channel_mapping(self, discord_id: str, target_id: str): + if self._ensure_db(): + self.db.set_entity_mapping("channel", str(discord_id), str(target_id)) - # 2. Load separate messages file - if self.messages_file and self.messages_file.exists(): - logger.info(f"Loading messages from {self.messages_file.name}") - try: - with open(self.messages_file, "r", encoding="utf-8") as f: - msg_data = json.load(f) - - # Check for new schema (nested under 'channels') - if "channels" in msg_data: - self.channel_messages = msg_data.get("channels", {}) - logger.debug(f"Loaded {len(self.channel_messages)} tracked channels.") - else: - logger.warning("Legacy schema or empty tracker detected in messages file.") - # Legacy schema detection & conversion to a default 'unknown_channel' just in case, - # though new migrations shouldn't hit this based on previous removals. - legacy_map = msg_data.get("messages", {}) - legacy_ids = msg_data.get("last_message_ids", {}) - legacy_times = msg_data.get("last_message_timestamps", {}) - - if legacy_map or legacy_ids or legacy_times: - self.channel_messages = { - "legacy_migrated_channel": { - "message_map": legacy_map, - "last_message_id": list(legacy_ids.values())[-1] if legacy_ids else "", - "last_message_timestamp": list(legacy_times.values())[-1] if legacy_times else "" - } - } - except Exception as e: - logger.error(f"Failed to load messages file: {e}") - - - - def save_state(self): - """Saves only the core server configuration (channels, roles, emojis).""" - if not self.state_file: - return - - logger.debug(f"Saving state to {self.state_file.name}") - data = { - "channels": self.channel_map, - "categories": self.category_map, - "roles": self.role_map, - "emojis": self.emoji_map, - "stickers": self.sticker_map, - "audit_log_channel": self.audit_log_channel - } - try: - with open(self.state_file, "w", encoding="utf-8") as f: - json.dump(data, f, indent=4) - except Exception as e: - logger.error(f"Failed to save state file: {e}") - - def save_messages(self): - """Saves only the message tracking data.""" - if not self.messages_file: - return - - logger.debug(f"Saving messages to {self.messages_file.name}") - data = { - "channels": self.channel_messages - } - try: - with open(self.messages_file, "w", encoding="utf-8") as f: - json.dump(data, f, indent=4) - except Exception as e: - logger.error(f"Failed to save messages file: {e}") - - # --- Type Specific Getters/Setters --- - - def set_channel_mapping(self, discord_id: str, fluxer_id: str): - self.channel_map[str(discord_id)] = str(fluxer_id) - self.save_state() - - def get_fluxer_channel_id(self, discord_id: str) -> str | None: - return self.channel_map.get(str(discord_id)) - - def remove_channel_mapping(self, discord_id: str): - self.channel_map.pop(str(discord_id), None) - self.save_state() - - def set_category_mapping(self, discord_id: str, fluxer_id: str): - self.category_map[str(discord_id)] = str(fluxer_id) - self.save_state() - - def get_fluxer_category_id(self, discord_id: str) -> str | None: - return self.category_map.get(str(discord_id)) - - def remove_category_mapping(self, discord_id: str): - self.category_map.pop(str(discord_id), None) - self.save_state() - - def set_role_mapping(self, discord_id: str, fluxer_id: str): - self.role_map[str(discord_id)] = str(fluxer_id) - self.save_state() - - def get_fluxer_role_id(self, discord_id: str) -> str | None: - return self.role_map.get(str(discord_id)) - - def remove_role_mapping(self, discord_id: str): - self.role_map.pop(str(discord_id), None) - self.save_state() - - def set_emoji_mapping(self, discord_id: str, fluxer_id: str): - self.emoji_map[str(discord_id)] = str(fluxer_id) - self.save_state() - - def get_fluxer_emoji_id(self, discord_id: str) -> str | None: - return self.emoji_map.get(str(discord_id)) - - def remove_emoji_mapping(self, discord_id: str): - self.emoji_map.pop(str(discord_id), None) - self.save_state() - - def set_sticker_mapping(self, discord_id: str, fluxer_id: str): - self.sticker_map[str(discord_id)] = str(fluxer_id) - self.save_state() - - def get_fluxer_sticker_id(self, discord_id: str) -> str | None: - return self.sticker_map.get(str(discord_id)) - - def remove_sticker_mapping(self, discord_id: str): - self.sticker_map.pop(str(discord_id), None) - self.save_state() - - # --- Generic Aliases for target platform migration --- + def get_target_channel_id(self, discord_id: str) -> str | None: + if self._ensure_db(): + return self.db.get_entity_mapping("channel", str(discord_id)) + return None - get_target_channel_id = get_fluxer_channel_id - set_channel_mapping = set_channel_mapping # already generic enough in name if we ignore the 'fluxer' in implementation - - def set_target_channel_mapping(self, discord_id: str, target_id: str): - self.set_channel_mapping(discord_id, target_id) + get_fluxer_channel_id = get_target_channel_id + set_target_channel_mapping = set_channel_mapping + + def set_category_mapping(self, discord_id: str, target_id: str): + if self._ensure_db(): + self.db.set_entity_mapping("category", str(discord_id), str(target_id)) def get_target_category_id(self, discord_id: str) -> str | None: - return self.get_fluxer_category_id(discord_id) + if self._ensure_db(): + return self.db.get_entity_mapping("category", str(discord_id)) + return None + + get_fluxer_category_id = get_target_category_id + set_target_category_mapping = set_category_mapping - def set_target_category_mapping(self, discord_id: str, target_id: str): - self.set_category_mapping(discord_id, target_id) + def set_role_mapping(self, discord_id: str, target_id: str): + if self._ensure_db(): + self.db.set_entity_mapping("role", str(discord_id), str(target_id)) def get_target_role_id(self, discord_id: str) -> str | None: - return self.get_fluxer_role_id(discord_id) + if self._ensure_db(): + return self.db.get_entity_mapping("role", str(discord_id)) + return None + + get_fluxer_role_id = get_target_role_id + set_target_role_mapping = set_role_mapping - def set_target_role_mapping(self, discord_id: str, target_id: str): - self.set_role_mapping(discord_id, target_id) + def set_emoji_mapping(self, discord_id: str, target_id: str): + if self._ensure_db(): + self.db.set_entity_mapping("emoji", str(discord_id), str(target_id)) def get_target_emoji_id(self, discord_id: str) -> str | None: - return self.get_fluxer_emoji_id(discord_id) + if self._ensure_db(): + return self.db.get_entity_mapping("emoji", str(discord_id)) + return None + + get_fluxer_emoji_id = get_target_emoji_id + set_target_emoji_mapping = set_emoji_mapping - def set_target_emoji_mapping(self, discord_id: str, target_id: str): - self.set_emoji_mapping(discord_id, target_id) + def set_sticker_mapping(self, discord_id: str, target_id: str): + if self._ensure_db(): + self.db.set_entity_mapping("sticker", str(discord_id), str(target_id)) def get_target_sticker_id(self, discord_id: str) -> str | None: - return self.get_fluxer_sticker_id(discord_id) + if self._ensure_db(): + return self.db.get_entity_mapping("sticker", str(discord_id)) + return None + + get_fluxer_sticker_id = get_target_sticker_id + set_target_sticker_mapping = set_sticker_mapping - def set_target_sticker_mapping(self, discord_id: str, target_id: str): - self.set_sticker_mapping(discord_id, target_id) - - def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | None: - return self.get_fluxer_message_id(target_channel_id, discord_id) + # --- Properties for backward compatibility --- + @property + def channel_map(self) -> Dict[str, str]: + return self.db.get_all_entity_mappings("channel") if self.db else {} - def set_target_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str): - self.set_message_mapping(target_channel_id, discord_id, target_id) + @property + def category_map(self) -> Dict[str, str]: + return self.db.get_all_entity_mappings("category") if self.db else {} + + @property + def role_map(self) -> Dict[str, str]: + return self.db.get_all_entity_mappings("role") if self.db else {} + + @property + def emoji_map(self) -> Dict[str, str]: + return self.db.get_all_entity_mappings("emoji") if self.db else {} + + @property + def sticker_map(self) -> Dict[str, str]: + return self.db.get_all_entity_mappings("sticker") if self.db else {} + + @property + def audit_log_channel(self) -> str | None: + return self.db.get_metadata("audit_log_channel") if self.db else None + + @audit_log_channel.setter + def audit_log_channel(self, value: str | None): + if self._ensure_db(): + self.db.set_metadata("audit_log_channel", value) # --- Message Management --- - def _ensure_channel_tracking(self, target_channel_id: str): - if str(target_channel_id) not in self.channel_messages: - self.channel_messages[str(target_channel_id)] = { - "message_map": {}, - "last_message_id": "", - "last_message_timestamp": "", - "total_messages": 0, - "total_files": 0, - "threads": {} - } - - def increment_stats(self, target_channel_id: str, messages: int = 1, files: int = 0): - self._ensure_channel_tracking(target_channel_id) - c = self.channel_messages[str(target_channel_id)] - c["total_messages"] = c.get("total_messages", 0) + messages - c["total_files"] = c.get("total_files", 0) + files - self.save_messages() + def set_target_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str): + if self._ensure_db(): + self.db.set_message_mapping(str(target_channel_id), str(discord_id), str(target_id)) - # --- Thread Tracking --- - - def _ensure_thread_tracking(self, target_channel_id: str, thread_id: str): - self._ensure_channel_tracking(target_channel_id) - threads = self.channel_messages[str(target_channel_id)].setdefault("threads", {}) - if str(thread_id) not in threads: - threads[str(thread_id)] = { - "thread_map": {}, - "last_message_id": "", - "last_message_timestamp": "", - "total_messages": 0, - "total_files": 0 - } - - def increment_thread_stats(self, target_channel_id: str, thread_id: str, messages: int = 1, files: int = 0): - self._ensure_thread_tracking(target_channel_id, thread_id) - t = self.channel_messages[str(target_channel_id)]["threads"][str(thread_id)] - t["total_messages"] = t.get("total_messages", 0) + messages - t["total_files"] = t.get("total_files", 0) + files - self.save_messages() - - def set_thread_message_mapping(self, target_channel_id: str, thread_id: str, discord_id: str, target_id: str): - self._ensure_thread_tracking(target_channel_id, thread_id) - self.channel_messages[str(target_channel_id)]["threads"][str(thread_id)]["thread_map"][str(discord_id)] = str(target_id) - # Also add to main message_map for global message resolution (like replies) - self.set_message_mapping(target_channel_id, discord_id, target_id) - - def update_thread_last_message_timestamp(self, target_channel_id: str, thread_id: str, timestamp: str): - self._ensure_thread_tracking(target_channel_id, thread_id) - self.channel_messages[str(target_channel_id)]["threads"][str(thread_id)]["last_message_timestamp"] = str(timestamp) - self.save_messages() - - def update_thread_last_message_id(self, target_channel_id: str, thread_id: str, message_id: str): - self._ensure_thread_tracking(target_channel_id, thread_id) - self.channel_messages[str(target_channel_id)]["threads"][str(thread_id)]["last_message_id"] = str(message_id) - self.save_messages() - - def get_thread_message_id(self, target_channel_id: str, thread_id: str, discord_id: str) -> str | None: - if str(target_channel_id) in self.channel_messages: - threads = self.channel_messages[str(target_channel_id)].get("threads", {}) - if str(thread_id) in threads: - return threads[str(thread_id)]["thread_map"].get(str(discord_id)) + def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | None: + if self._ensure_db(): + return self.db.get_target_message_id(str(target_channel_id), str(discord_id)) return None - - def set_message_mapping(self, target_channel_id: str, discord_id: str, fluxer_id: str): - self._ensure_channel_tracking(target_channel_id) - self.channel_messages[str(target_channel_id)]["message_map"][str(discord_id)] = str(fluxer_id) - self.save_messages() + + def set_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str): + self.set_target_message_mapping(target_channel_id, discord_id, target_id) def get_fluxer_message_id(self, target_channel_id: str, discord_id: str) -> str | None: - if str(target_channel_id) in self.channel_messages: - return self.channel_messages[str(target_channel_id)]["message_map"].get(str(discord_id)) + return self.get_target_message_id(target_channel_id, discord_id) + + def increment_stats(self, target_channel_id: str, messages: int = 1, files: int = 0): + if self._ensure_db(): + self.db.update_channel_tracking(str(target_channel_id), msg_inc=messages, file_inc=files) + + def increment_thread_stats(self, target_channel_id: str, thread_id: str, messages: int = 1, files: int = 0): + if self._ensure_db(): + self.db.update_thread_tracking(str(target_channel_id), str(thread_id), msg_inc=messages, file_inc=files) + + def set_thread_message_mapping(self, target_channel_id: str, thread_id: str, discord_id: str, target_id: str): + if self._ensure_db(): + self.db.set_thread_message_mapping(str(target_channel_id), str(thread_id), str(discord_id), str(target_id)) + + def update_thread_last_message_timestamp(self, target_channel_id: str, thread_id: str, timestamp: str): + if self._ensure_db(): + self.db.update_thread_tracking(str(target_channel_id), str(thread_id), last_msg_ts=str(timestamp)) + + def update_thread_last_message_id(self, target_channel_id: str, thread_id: str, message_id: str): + if self._ensure_db(): + self.db.update_thread_tracking(str(target_channel_id), str(thread_id), last_msg_id=str(message_id)) + + def get_thread_message_id(self, target_channel_id: str, thread_id: str, discord_id: str) -> str | None: + if self._ensure_db(): + return self.db.get_target_thread_message_id(str(target_channel_id), str(thread_id), str(discord_id)) + return None + + def update_last_message_timestamp(self, target_channel_id: str, timestamp: str): + if self._ensure_db(): + self.db.update_channel_tracking(str(target_channel_id), last_msg_ts=str(timestamp)) + + def update_last_message_id(self, target_channel_id: str, message_id: str): + if self._ensure_db(): + self.db.update_channel_tracking(str(target_channel_id), last_msg_id=str(message_id)) + + def get_last_message_id(self, target_channel_id: str) -> str | None: + if self._ensure_db(): + return self.db.get_channel_tracking(str(target_channel_id)).get("last_msg_id") return None def find_message_mapping(self, discord_id: str) -> tuple[str, str] | tuple[None, None]: - """ - Searches for a message mapping across all tracked channels. - Returns (target_channel_id, target_message_id) or (None, None). - """ - d_id = str(discord_id) - for t_cid, data in self.channel_messages.items(): - # Check main message map - if d_id in data.get("message_map", {}): - return str(t_cid), str(data["message_map"][d_id]) - # Check threads - for t_tid, t_data in data.get("threads", {}).items(): - if d_id in t_data.get("thread_map", {}): - # For thread links, the target_channel_id is technically the thread ID in some contexts, - # but usually for the URL it's the thread ID itself. - return str(t_tid), str(t_data["thread_map"][d_id]) + if not self.db: + return None, None + conn = self.db._get_conn() + row = conn.execute("SELECT channel_id, target_msg_id FROM message_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone() + if row: + return str(row["channel_id"]), str(row["target_msg_id"]) + row = conn.execute("SELECT thread_id, target_msg_id FROM thread_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone() + if row: + return str(row["thread_id"]), str(row["target_msg_id"]) return None, None - - def update_last_message_timestamp(self, target_channel_id: str, timestamp: str): - self._ensure_channel_tracking(target_channel_id) - self.channel_messages[str(target_channel_id)]["last_message_timestamp"] = str(timestamp) - self.save_messages() - - def update_last_message_id(self, target_channel_id: str, message_id: str): - self._ensure_channel_tracking(target_channel_id) - self.channel_messages[str(target_channel_id)]["last_message_id"] = str(message_id) - self.save_messages() - - def get_last_message_id(self, target_channel_id: str) -> str | None: - if str(target_channel_id) in self.channel_messages: - return self.channel_messages[str(target_channel_id)].get("last_message_id") - return None # --- Danger Zone Clearing --- def clear_channel_mappings(self): - """Clears all channel and category mappings.""" - self.channel_map.clear() - self.category_map.clear() - self.save_state() + if self._ensure_db(): + self.db.clear_entities("channel") + self.db.clear_entities("category") def clear_role_mappings(self): - """Clears all role mappings.""" - self.role_map.clear() - self.save_state() + if self._ensure_db(): + self.db.clear_entities("role") def clear_asset_mappings(self): - """Clears all emoji and sticker mappings.""" - self.emoji_map.clear() - self.sticker_map.clear() - self.save_state() + if self._ensure_db(): + self.db.clear_entities("emoji") + self.db.clear_entities("sticker") def clear_message_history(self): - """Clears all message mappings and timestamps.""" - self.channel_messages.clear() - self.save_messages() + if self.db: + conn = self.db._get_conn() + conn.execute("DELETE FROM message_mappings") + conn.execute("DELETE FROM thread_mappings") + conn.execute("DELETE FROM channel_tracking") + conn.execute("DELETE FROM thread_tracking") + conn.commit() def set_folder(self, server_id: str, clean_name: str, base_dir: Path | str = ""): + """ + Initializes the SQLite database based on community name and ID. + Filename: {name}-{id}.db (Flat structure) + ID is priority: if a DB with the same ID exists but different name, rename it. + """ base = Path(base_dir) if base_dir else Path(".") - new_folder = base / f"{clean_name}-{server_id}" - logger.info(f"Setting active migration folder: {new_folder}") + desired_filename = f"{clean_name}-{server_id}.db" + desired_path = base / desired_filename - # 1. Search base_dir to see if an older folder for this server_id exists - existing_folder: Path | None = None - if base.exists() and base.is_dir(): - for d in base.iterdir(): - if d.is_dir() and d.name.endswith(f"-{server_id}"): - existing_folder = d - break - - # 2. Rename it if it doesn't match the new desired name - if existing_folder and existing_folder != new_folder: - logger.info(f"Renaming existing folder {existing_folder.name} to {new_folder.name}") - try: - existing_folder.rename(new_folder) - except Exception as e: - logger.debug(f"Could not rename {existing_folder} to {new_folder}: {e}") + # Priority 1: Match by ID + existing_db: Path | None = None + # Look for any file ending with -{server_id}.db + for f in base.glob(f"*-{server_id}.db"): + if f.is_file(): + existing_db = f + break + + db_path = desired_path + if existing_db: + if existing_db.name != desired_filename: + logger.info(f"Server renamed: moving {existing_db.name} -> {desired_filename}") + try: + existing_db.rename(desired_path) + except Exception as e: + logger.error(f"Failed to rename database: {e}") + # If rename fails, we'll use the existing one if it exists at the old path, + # or the desired one if it exists there. + if not desired_path.exists(): + db_path = existing_db + + logger.info(f"Setting active migration database: {db_path}") + + from src.core.database import MigrationDatabase + if self.db: + self.db.close() + self.db = MigrationDatabase(db_path) + logger.info(f"Initialized SQLite database at {db_path}") - new_folder.mkdir(parents=True, exist_ok=True) - - self.state_file = new_folder / "state-migration.json" - self.messages_file = new_folder / "message-tracker.json" - - logger.debug("Re-loading data from new folder location.") - self.load() + # No-op methods kept for compatibility with callers that might try to load/save JSON + def load(self): pass + def save_state(self): pass diff --git a/src/ui/shuttle_ops.py b/src/ui/shuttle_ops.py index febcdc8..62189de 100644 --- a/src/ui/shuttle_ops.py +++ b/src/ui/shuttle_ops.py @@ -1576,7 +1576,7 @@ class OperationPane(Container): async def _fetch_clone_preview(self, selections: list[str]) -> dict[str, Any]: """Fetches preview data from Discord (source server) for cloning confirmation, - comparing with existing mappings in state-migration.json for presence highlighting.""" + comparing with existing mappings in SQLite for presence highlighting.""" preview = {} reader = self.engine.discord_reader