From 5ae515dd7ab918990adf13a54a5fdc8d477230aa Mon Sep 17 00:00:00 2001 From: rambros Date: Sat, 28 Mar 2026 02:47:14 +0530 Subject: [PATCH] int for fluxer, str for stoat --- src/core/base.py | 2 +- src/core/database.py | 217 +++++++++++++++++++------------- src/core/state.py | 23 ++-- src/core/utils.py | 7 +- src/fluxer/emoji_stickers.py | 8 +- src/fluxer/migrate_message.py | 6 +- src/fluxer/roles_permissions.py | 4 +- src/stoat/emoji_stickers.py | 3 +- 8 files changed, 158 insertions(+), 112 deletions(-) diff --git a/src/core/base.py b/src/core/base.py index 87086b1..81943f8 100644 --- a/src/core/base.py +++ b/src/core/base.py @@ -129,7 +129,7 @@ class MigrationContext: # or a logical subfolder. base_dir = getattr(self, "base_dir", "") - self.state.set_folder(community_id, clean_name, base_dir=base_dir) + self.state.set_folder(community_id, clean_name, self.target_platform, base_dir=base_dir) async def start_connections(self): await self.discord_reader.start() diff --git a/src/core/database.py b/src/core/database.py index aeffa49..996e083 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -16,8 +16,9 @@ class MigrationDatabase: Replaces the memory-bloated and O(N^2) JSON persistence for messages. """ - def __init__(self, db_path: Path): + def __init__(self, db_path: Path, platform: str = None): self.db_path = db_path + self.platform = platform.lower() if platform else None self._local = threading.local() self._init_db() @@ -28,25 +29,45 @@ class MigrationDatabase: return self._local.conn def _init_db(self): - """Initialize tables if they don't exist and handle migrations.""" + """Initialize tables if they don't exist and handle migrations/platform-specific schemas.""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - # 1. MIME Type to Content Type Migrations (if applicable - not in this class usually) + # 1. Determine active platform and column types + # Create metadata table first as we need it for platform tracking + cursor.execute("CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)") - # 2. Universal ID Migration (TEXT -> INTEGER) - # Mapping of table names to columns that must be INTEGER (Snowflakes) + # Load platform from DB if exists + cursor.execute("SELECT value FROM metadata WHERE key = ?", ("target_platform",)) + row = cursor.fetchone() + stored_platform = row[0] if row else None + + # If platform provided, update stored platform. If not provided, use stored. + active_platform = self.platform or stored_platform or "stoat" # Default to stoat if unknown + if self.platform and self.platform != stored_platform: + cursor.execute("INSERT OR REPLACE INTO metadata (key, value) VALUES (?, ?)", ("target_platform", active_platform)) + conn.commit() + + # Define types + # source_type is always INTEGER (Discord/Fluxer snowflakes) + # target_type is INTEGER for Fluxer, TEXT for Stoat + source_type = "INTEGER" + target_type = "INTEGER" if active_platform == "fluxer" else "TEXT" + + # 2. Universal ID Migration (TEXT -> INTEGER vs Platform Switch) + # Mapping of table names to columns that must match their respective types + # key: table, value: (discord_cols, target_cols) id_migrations = { - "message_mappings": ["channel_id", "source_msg_id", "target_msg_id"], - "thread_mappings": ["channel_id", "thread_id", "source_msg_id", "target_msg_id"], - "channel_tracking": ["channel_id", "last_msg_id"], - "thread_tracking": ["channel_id", "thread_id", "last_msg_id"], - "server_mappings": ["source_id", "target_id"], - "asset_mappings": ["source_id", "target_id"], - "user_alias": ["user_id"] + "message_mappings": (["source_msg_id"], ["channel_id", "target_msg_id"]), + "thread_mappings": (["source_msg_id"], ["channel_id", "thread_id", "target_msg_id"]), + "channel_tracking": ([], ["channel_id", "last_msg_id"]), + "thread_tracking": ([], ["channel_id", "thread_id", "last_msg_id"]), + "server_mappings": (["source_id"], ["target_id"]), + "asset_mappings": (["source_id"], ["target_id"]), + "user_alias": (["user_id"], []) } - for table, id_cols in id_migrations.items(): + for table, (discord_cols, target_cols) in id_migrations.items(): cursor.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name='{table}'") res = cursor.fetchone() if not res or res[0] == 0: @@ -56,28 +77,32 @@ class MigrationDatabase: cols = cursor.fetchall() needs_migration = False for col in cols: - if col[1] in id_cols and col[2] == "TEXT": + c_name, c_type = col[1], col[2] + if c_name in discord_cols and c_type != source_type: + needs_migration = True + break + if c_name in target_cols and c_type != target_type: needs_migration = True break if needs_migration: - logger.info(f"MigrationDatabase: Migrating {table}: converting ID columns to INTEGER") + logger.info(f"MigrationDatabase: Migrating {table} schema (Platform: {active_platform})") cursor.execute(f"ALTER TABLE {table} RENAME TO {table}_old") if table == "message_mappings": - cursor.execute("CREATE TABLE message_mappings (channel_id INTEGER, source_msg_id INTEGER, target_msg_id INTEGER, timestamp TEXT, PRIMARY KEY (channel_id, source_msg_id))") + cursor.execute(f"CREATE TABLE message_mappings (channel_id {target_type}, source_msg_id {source_type}, target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, source_msg_id))") elif table == "thread_mappings": - cursor.execute("CREATE TABLE thread_mappings (channel_id INTEGER, thread_id INTEGER, source_msg_id INTEGER, target_msg_id INTEGER, timestamp TEXT, PRIMARY KEY (channel_id, thread_id, source_msg_id))") + cursor.execute(f"CREATE TABLE thread_mappings (channel_id {target_type}, thread_id {target_type}, source_msg_id {source_type}, target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, thread_id, source_msg_id))") elif table == "channel_tracking": - cursor.execute("CREATE TABLE channel_tracking (channel_id INTEGER PRIMARY KEY, last_msg_id INTEGER, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0)") + cursor.execute(f"CREATE TABLE channel_tracking (channel_id {target_type} PRIMARY KEY, last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0)") elif table == "thread_tracking": - cursor.execute("CREATE TABLE thread_tracking (channel_id INTEGER, thread_id INTEGER, last_msg_id INTEGER, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0, completed INTEGER DEFAULT 0, PRIMARY KEY (channel_id, thread_id))") + cursor.execute(f"CREATE TABLE thread_tracking (channel_id {target_type}, thread_id {target_type}, last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0, completed INTEGER DEFAULT 0, PRIMARY KEY (channel_id, thread_id))") elif table == "server_mappings": - cursor.execute("CREATE TABLE server_mappings (category TEXT, source_id INTEGER, target_id INTEGER, PRIMARY KEY (category, source_id))") + cursor.execute(f"CREATE TABLE server_mappings (category TEXT, source_id {source_type}, target_id {target_type}, PRIMARY KEY (category, source_id))") elif table == "asset_mappings": - cursor.execute("CREATE TABLE asset_mappings (category TEXT, source_id INTEGER, target_id INTEGER, PRIMARY KEY (category, source_id))") + cursor.execute(f"CREATE TABLE asset_mappings (category TEXT, source_id {source_type}, target_id {target_type}, PRIMARY KEY (category, source_id))") elif table == "user_alias": - cursor.execute("CREATE TABLE user_alias (user_id INTEGER PRIMARY KEY, alias TEXT UNIQUE)") + cursor.execute(f"CREATE TABLE user_alias (user_id {source_type} PRIMARY KEY, alias TEXT UNIQUE)") old_cols = [c[1] for c in cursor.execute(f"PRAGMA table_info({table}_old)").fetchall()] new_cols = [c[1] for c in cursor.execute(f"PRAGMA table_info({table})").fetchall()] @@ -89,33 +114,33 @@ class MigrationDatabase: # Initial Creation / Ensure Schema Correctness # Table for message mappings: SourceID -> TargetID - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS message_mappings ( - channel_id INTEGER, - source_msg_id INTEGER, - target_msg_id INTEGER, + channel_id {target_type}, + source_msg_id {source_type}, + target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, source_msg_id) ) """) # Table for thread mappings - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS thread_mappings ( - channel_id INTEGER, - thread_id INTEGER, - source_msg_id INTEGER, - target_msg_id INTEGER, + channel_id {target_type}, + thread_id {target_type}, + source_msg_id {source_type}, + target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, thread_id, source_msg_id) ) """) # Table for per-channel stats and tracking - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS channel_tracking ( - channel_id INTEGER PRIMARY KEY, - last_msg_id INTEGER, + channel_id {target_type} PRIMARY KEY, + last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0 @@ -123,11 +148,11 @@ class MigrationDatabase: """) # Table for per-thread stats - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS thread_tracking ( - channel_id INTEGER, - thread_id INTEGER, - last_msg_id INTEGER, + channel_id {target_type}, + thread_id {target_type}, + last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0, @@ -136,32 +161,32 @@ class MigrationDatabase: ) """) - # Add completed column if it doesn't exist (backward compatibility for existing resumption DBs) + # Add completed column if it doesn't exist (backward compatibility) try: cursor.execute("ALTER TABLE thread_tracking ADD COLUMN completed INTEGER DEFAULT 0") except sqlite3.OperationalError: pass # Already exists # Table for server entity mappings (channels, roles, categories) - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS server_mappings ( category TEXT, - source_id INTEGER, - target_id INTEGER, + source_id {source_type}, + target_id {target_type}, PRIMARY KEY (category, source_id) ) """) # Table for asset mappings (emojis, stickers) - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS asset_mappings ( category TEXT, - source_id INTEGER, - target_id INTEGER, + source_id {source_type}, + target_id {target_type}, PRIMARY KEY (category, source_id) ) """) - + # Migrate old entity_mappings if it exists cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='entity_mappings'") if cursor.fetchone(): @@ -182,18 +207,10 @@ class MigrationDatabase: # Drop old table cursor.execute("DROP TABLE entity_mappings") - # Table for general metadata - cursor.execute(""" - CREATE TABLE IF NOT EXISTS metadata ( - key TEXT PRIMARY KEY, - value TEXT - ) - """) - # Table for auto-generated user aliases (user_id -> alias) - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS user_alias ( - user_id INTEGER PRIMARY KEY, + user_id {source_type} PRIMARY KEY, alias TEXT UNIQUE ) """) @@ -209,17 +226,30 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO message_mappings (channel_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?)", - (parse_snowflake(channel_id), parse_snowflake(source_id), parse_snowflake(target_id), timestamp) + (str(channel_id), parse_snowflake(source_id), str(target_id), timestamp) ) conn.commit() - def get_target_message_id(self, channel_id: str, source_id: str) -> Optional[str]: + def get_target_message_id(self, channel_id: str, source_id: str) -> Optional[Union[str, int]]: conn = self._get_conn() row = conn.execute( "SELECT target_msg_id FROM message_mappings WHERE channel_id = ? AND source_msg_id = ?", - (parse_snowflake(channel_id), parse_snowflake(source_id)) + (str(channel_id), parse_snowflake(source_id)) ).fetchone() - return row["target_msg_id"] if row else None + if row: + val = row["target_msg_id"] + return str(val) if self.platform == "stoat" else val + return None + + def get_all_message_mappings(self, channel_id: str) -> Dict[Union[str, int], Union[str, int]]: + conn = self._get_conn() + rows = conn.execute( + "SELECT source_msg_id, target_msg_id FROM message_mappings WHERE channel_id = ?", + (str(channel_id),) + ).fetchall() + if self.platform == "stoat": + return {str(row["source_msg_id"]): str(row["target_msg_id"]) for row in rows} + return {row["source_msg_id"]: row["target_msg_id"] for row in rows} # --- User Alias Methods --- @@ -296,24 +326,29 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO server_mappings (category, source_id, target_id) VALUES (?, ?, ?)", - (category, parse_snowflake(source_id), parse_snowflake(target_id)) + (category, parse_snowflake(source_id), str(target_id)) ) conn.commit() - def get_server_mapping(self, category: str, source_id: str) -> Optional[str]: + def get_server_mapping(self, category: str, source_id: str) -> Optional[Union[str, int]]: conn = self._get_conn() row = conn.execute( "SELECT target_id FROM server_mappings WHERE category = ? AND source_id = ?", (category, parse_snowflake(source_id)) ).fetchone() - return row["target_id"] if row else None + if row: + val = row["target_id"] + return str(val) if self.platform == "stoat" else val + return None - def get_all_server_mappings(self, category: str) -> Dict[str, str]: + def get_all_server_mappings(self, category: str) -> Dict[Union[str, int], Union[str, int]]: conn = self._get_conn() rows = conn.execute( "SELECT source_id, target_id FROM server_mappings WHERE category = ?", (category,) ).fetchall() + if self.platform == "stoat": + return {str(row["source_id"]): str(row["target_id"]) for row in rows} return {row["source_id"]: row["target_id"] for row in rows} def delete_server_mapping(self, category: str, source_id: str): @@ -338,24 +373,29 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO asset_mappings (category, source_id, target_id) VALUES (?, ?, ?)", - (category, parse_snowflake(source_id), parse_snowflake(target_id)) + (category, parse_snowflake(source_id), str(target_id)) ) conn.commit() - def get_asset_mapping(self, category: str, source_id: str) -> Optional[str]: + def get_asset_mapping(self, category: str, source_id: str) -> Optional[Union[str, int]]: conn = self._get_conn() row = conn.execute( "SELECT target_id FROM asset_mappings WHERE category = ? AND source_id = ?", (category, parse_snowflake(source_id)) ).fetchone() - return row["target_id"] if row else None + if row: + val = row["target_id"] + return str(val) if self.platform == "stoat" else val + return None - def get_all_asset_mappings(self, category: str) -> Dict[str, str]: + def get_all_asset_mappings(self, category: str) -> Dict[Union[str, int], Union[str, int]]: conn = self._get_conn() rows = conn.execute( "SELECT source_id, target_id FROM asset_mappings WHERE category = ?", (category,) ).fetchall() + if self.platform == "stoat": + return {str(row["source_id"]): str(row["target_id"]) for row in rows} return {row["source_id"]: row["target_id"] for row in rows} def delete_asset_mapping(self, category: str, source_id: str): @@ -389,23 +429,23 @@ class MigrationDatabase: def update_channel_tracking(self, channel_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0): conn = self._get_conn() # Initialize if missing - conn.execute("INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)", (parse_snowflake(channel_id),)) + conn.execute("INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)", (str(channel_id),)) if last_msg_id: - conn.execute("UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?", (parse_snowflake(last_msg_id), parse_snowflake(channel_id))) + conn.execute("UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?", (str(last_msg_id), str(channel_id))) if last_msg_ts: - conn.execute("UPDATE channel_tracking SET last_msg_ts = ? WHERE channel_id = ?", (last_msg_ts, parse_snowflake(channel_id))) + conn.execute("UPDATE channel_tracking SET last_msg_ts = ? WHERE channel_id = ?", (last_msg_ts, str(channel_id))) if msg_inc != 0 or file_inc != 0: conn.execute( "UPDATE channel_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ?", - (msg_inc, file_inc, parse_snowflake(channel_id)) + (msg_inc, file_inc, str(channel_id)) ) conn.commit() def get_channel_tracking(self, channel_id: str) -> Dict[str, Any]: conn = self._get_conn() - row = conn.execute("SELECT * FROM channel_tracking WHERE channel_id = ?", (parse_snowflake(channel_id),)).fetchone() + row = conn.execute("SELECT * FROM channel_tracking WHERE channel_id = ?", (str(channel_id),)).fetchone() if row: return dict(row) return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} @@ -415,39 +455,42 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO thread_mappings (channel_id, thread_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?, ?)", - (parse_snowflake(channel_id), parse_snowflake(thread_id), parse_snowflake(source_id), parse_snowflake(target_id), timestamp) + (str(channel_id), str(thread_id), parse_snowflake(source_id), str(target_id), timestamp) ) conn.commit() - def get_target_thread_message_id(self, channel_id: str, thread_id: str, source_id: str) -> Optional[str]: + def get_target_thread_message_id(self, channel_id: str, thread_id: str, source_id: str) -> Optional[Union[str, int]]: conn = self._get_conn() row = conn.execute( "SELECT target_msg_id FROM thread_mappings WHERE channel_id = ? AND thread_id = ? AND source_msg_id = ?", - (parse_snowflake(channel_id), parse_snowflake(thread_id), parse_snowflake(source_id)) + (str(channel_id), str(thread_id), parse_snowflake(source_id)) ).fetchone() - return row["target_msg_id"] if row else None + if row: + val = row["target_msg_id"] + return str(val) if self.platform == "stoat" else val + return None def update_thread_tracking(self, channel_id: str, thread_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0, completed: int = None): conn = self._get_conn() - conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (parse_snowflake(channel_id), parse_snowflake(thread_id))) + conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (str(channel_id), str(thread_id))) if last_msg_id: - conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (parse_snowflake(last_msg_id), parse_snowflake(channel_id), parse_snowflake(thread_id))) + conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (str(last_msg_id), str(channel_id), str(thread_id))) if last_msg_ts: - conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, parse_snowflake(channel_id), parse_snowflake(thread_id))) + conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, str(channel_id), str(thread_id))) if completed is not None: - conn.execute("UPDATE thread_tracking SET completed = ? WHERE channel_id = ? AND thread_id = ?", (completed, parse_snowflake(channel_id), parse_snowflake(thread_id))) + conn.execute("UPDATE thread_tracking SET completed = ? WHERE channel_id = ? AND thread_id = ?", (completed, str(channel_id), str(thread_id))) if msg_inc != 0 or file_inc != 0: conn.execute( "UPDATE thread_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ? AND thread_id = ?", - (msg_inc, file_inc, parse_snowflake(channel_id), parse_snowflake(thread_id)) + (msg_inc, file_inc, str(channel_id), str(thread_id)) ) conn.commit() def get_thread_tracking(self, channel_id: str, thread_id: str) -> Dict[str, Any]: conn = self._get_conn() - row = conn.execute("SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?", (parse_snowflake(channel_id), parse_snowflake(thread_id))).fetchone() + row = conn.execute("SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?", (str(channel_id), str(thread_id))).fetchone() if row: return dict(row) return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} @@ -455,10 +498,10 @@ class MigrationDatabase: def clear_channel_data(self, channel_id: str): """Purge all mappings and tracking data for a specific channel and its threads.""" conn = self._get_conn() - conn.execute("DELETE FROM message_mappings WHERE channel_id = ?", (parse_snowflake(channel_id),)) - conn.execute("DELETE FROM thread_mappings WHERE channel_id = ?", (parse_snowflake(channel_id),)) - conn.execute("DELETE FROM channel_tracking WHERE channel_id = ?", (parse_snowflake(channel_id),)) - conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (parse_snowflake(channel_id),)) + conn.execute("DELETE FROM message_mappings WHERE channel_id = ?", (str(channel_id),)) + conn.execute("DELETE FROM thread_mappings WHERE channel_id = ?", (str(channel_id),)) + conn.execute("DELETE FROM channel_tracking WHERE channel_id = ?", (str(channel_id),)) + conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (str(channel_id),)) conn.commit() logger.info(f"Cleared all tracking and mapping data for channel: {channel_id}") diff --git a/src/core/state.py b/src/core/state.py index 0e8d181..dcfd446 100644 --- a/src/core/state.py +++ b/src/core/state.py @@ -1,7 +1,10 @@ import json import logging from pathlib import Path -from typing import Dict, Any, Optional +from typing import Dict, Optional, Any, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from src.core.database import MigrationDatabase logger = logging.getLogger(__name__) @@ -119,23 +122,23 @@ class MigrationState: # --- Properties for backward compatibility --- @property - def channel_map(self) -> Dict[str, str]: + def channel_map(self) -> Dict[Union[str, int], Union[str, int]]: return self.db.get_all_server_mappings("channel") if self.db else {} @property - def category_map(self) -> Dict[str, str]: + def category_map(self) -> Dict[Union[str, int], Union[str, int]]: return self.db.get_all_server_mappings("category") if self.db else {} @property - def role_map(self) -> Dict[str, str]: + def role_map(self) -> Dict[Union[str, int], Union[str, int]]: return self.db.get_all_server_mappings("role") if self.db else {} @property - def emoji_map(self) -> Dict[str, str]: + def emoji_map(self) -> Dict[Union[str, int], Union[str, int]]: return self.db.get_all_asset_mappings("emoji") if self.db else {} @property - def sticker_map(self) -> Dict[str, str]: + def sticker_map(self) -> Dict[Union[str, int], Union[str, int]]: return self.db.get_all_asset_mappings("sticker") if self.db else {} @property @@ -153,7 +156,7 @@ class MigrationState: if self._ensure_db(): self.db.set_message_mapping(str(target_channel_id), str(discord_id), str(target_id)) - def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | None: + def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | int | None: if self._ensure_db(): return self.db.get_target_message_id(str(target_channel_id), str(discord_id)) return None @@ -161,7 +164,7 @@ class MigrationState: def set_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str): self.set_target_message_mapping(target_channel_id, discord_id, target_id) - def get_fluxer_message_id(self, target_channel_id: str, discord_id: str) -> str | None: + def get_fluxer_message_id(self, target_channel_id: str, discord_id: str) -> str | int | None: return self.get_target_message_id(target_channel_id, discord_id) def increment_stats(self, target_channel_id: str, messages: int = 1, files: int = 0): @@ -258,7 +261,7 @@ class MigrationState: if self._ensure_db(): self.db.clear_channel_data(str(target_channel_id)) - def set_folder(self, server_id: str, clean_name: str, base_dir: Path | str = ""): + def set_folder(self, server_id: str, clean_name: str, platform: str = "stoat", base_dir: Path | str = ""): """ Initializes the SQLite database based on community name and ID. Filename: {name}-{id}.db (Flat structure) @@ -294,7 +297,7 @@ class MigrationState: from src.core.database import MigrationDatabase if self.db: self.db.close() - self.db = MigrationDatabase(db_path) + self.db = MigrationDatabase(db_path, platform=platform) logger.info(f"Initialized SQLite database at {db_path}") def get_user_alias(self, user_id: str) -> str | None: diff --git a/src/core/utils.py b/src/core/utils.py index 8402395..1167d47 100644 --- a/src/core/utils.py +++ b/src/core/utils.py @@ -14,15 +14,16 @@ def parse_snowflake(value: Any) -> Optional[int]: except ValueError: return None -from src.core.state import MigrationState - logger = logging.getLogger(__name__) -def resolve_discord_links(content: str, state: MigrationState, platform: str, target_server_id: str) -> str: +def resolve_discord_links(content: str, state, platform: str, target_server_id: str) -> str: """ Finds Discord message/channel links and resolves them to the target platform if they have been migrated. """ + from src.core.state import MigrationState + if not isinstance(state, MigrationState): + logger.warning(f"resolve_discord_links: state is not MigrationState (type: {type(state)})") if not content: return content diff --git a/src/fluxer/emoji_stickers.py b/src/fluxer/emoji_stickers.py index f8de0ec..7a7e8fe 100644 --- a/src/fluxer/emoji_stickers.py +++ b/src/fluxer/emoji_stickers.py @@ -18,10 +18,10 @@ async def sync_assets_state(context: MigrationContext): fluxer_stickers = await context.fluxer_writer.client.get_guild_stickers(context.config.fluxer_server_id) # Build name -> id maps and ID sets for Fluxer for fast lookup - fluxer_emoji_map = {e.get("name"): str(e.get("id")) for e in fluxer_emojis if e.get("name")} - fluxer_sticker_map = {s.get("name"): str(s.get("id")) for s in fluxer_stickers if s.get("name")} - fluxer_emoji_ids = {str(e.get("id")) for e in fluxer_emojis} - fluxer_sticker_ids = {str(s.get("id")) for s in fluxer_stickers} + fluxer_emoji_map = {e.get("name"): e.get("id") for e in fluxer_emojis if e.get("name")} + fluxer_sticker_map = {s.get("name"): s.get("id") for s in fluxer_stickers if s.get("name")} + fluxer_emoji_ids = {e.get("id") for e in fluxer_emojis} + fluxer_sticker_ids = {s.get("id") for s in fluxer_stickers} updates = 0 removals = 0 diff --git a/src/fluxer/migrate_message.py b/src/fluxer/migrate_message.py index dc644cb..12e656c 100644 --- a/src/fluxer/migrate_message.py +++ b/src/fluxer/migrate_message.py @@ -73,8 +73,8 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None, cid = int(match.group(1)) # 1. Check if channel is mapped in state - if channel_map and str(cid) in channel_map: - return f"<#{channel_map[str(cid)]}>" + if channel_map and cid in channel_map: + return f"<#{channel_map[cid]}>" # 2. Try to resolve channel name from pre-fetched names name = None @@ -100,7 +100,7 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None, def replace_emoji(match): animated = match.group(1) == "a" name = match.group(2) - eid = match.group(3) + eid = int(match.group(3)) if emoji_map and eid in emoji_map: target_eid = emoji_map[eid] diff --git a/src/fluxer/roles_permissions.py b/src/fluxer/roles_permissions.py index 83ff25b..5973c27 100644 --- a/src/fluxer/roles_permissions.py +++ b/src/fluxer/roles_permissions.py @@ -15,8 +15,8 @@ async def sync_roles_state(context: MigrationContext): fluxer_roles = await context.fluxer_writer.client.get_guild_roles(context.config.fluxer_server_id) # Build name -> id maps and ID sets for Fluxer for fast lookup - fluxer_role_map = {r.get("name"): str(r.get("id")) for r in fluxer_roles if r.get("name")} - fluxer_role_ids = {str(r.get("id")) for r in fluxer_roles} + fluxer_role_map = {r.get("name"): r.get("id") for r in fluxer_roles if r.get("name")} + fluxer_role_ids = {r.get("id") for r in fluxer_roles} updates = 0 removals = 0 diff --git a/src/stoat/emoji_stickers.py b/src/stoat/emoji_stickers.py index bfa1b60..5b7a0ed 100644 --- a/src/stoat/emoji_stickers.py +++ b/src/stoat/emoji_stickers.py @@ -37,8 +37,7 @@ async def sync_assets_state(context: MigrationContext): if stoat_id: if stoat_id not in stoat_emoji_ids: - context.state.emoji_map.pop(discord_id, None) - context.state.save_state() + context.state.remove_emoji_mapping(discord_id) removals += 1 elif emoji.name in stoat_emoji_map: context.state.set_target_emoji_mapping(discord_id, stoat_emoji_map[emoji.name])