From 5541f82dba8ceb90e0382db9b7105f132aadea07 Mon Sep 17 00:00:00 2001 From: HuntingFighter Date: Fri, 27 Mar 2026 16:36:11 +0100 Subject: [PATCH 1/8] Fixed a bug where some messages were not imported in the right order. --- src/core/backup_database.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/core/backup_database.py b/src/core/backup_database.py index ecbd1b8..61b0aaa 100644 --- a/src/core/backup_database.py +++ b/src/core/backup_database.py @@ -40,6 +40,7 @@ class BackupDatabase: """Handles backward compatibility by renaming columns in existing databases.""" with self._lock: # Check 'media_pool' table + self._conn.execute("") res = self._conn.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='media_pool'").fetchone() if res[0] > 0: cols = self._conn.execute("PRAGMA table_info(media_pool)").fetchall() @@ -56,6 +57,27 @@ class BackupDatabase: if "mime_type" in col_names and "content_type" not in col_names: logger.info("Migrating server_assets: renaming 'mime_type' to 'content_type'") self._conn.execute("ALTER TABLE server_assets RENAME COLUMN mime_type TO content_type") + res = self._conn.execute("SELECT count(*) FROM messages LIMIT 1").fetchone() + if res[0] > 0: + cols = self._conn.execute("PRAGMA table_info(messages)").fetchall() + id_type = next(col for col in cols if col[1] == "id")[2] + if id_type == "TEXT": + logger.info("Migrating messages: Changing id column type to integer") + self._conn.execute("ALTER TABLE messages RENAME TO messages_old") + self._conn.execute("""CREATE TABLE IF NOT EXISTS messages + ( + id INTEGER PRIMARY KEY, + channel_id TEXT, + author_id TEXT, + content TEXT, + timestamp TEXT, + type INTEGER, + message_reference TEXT, + is_pinned INTEGER, + extra_data TEXT + ) + """) + self._conn.execute("INSERT INTO messages SELECT * FROM messages_old") self._conn.commit() def _init_db(self): @@ -135,7 +157,7 @@ class BackupDatabase: # Messages conn.execute(""" CREATE TABLE 
IF NOT EXISTS messages ( - id TEXT PRIMARY KEY, + id INTEGER PRIMARY KEY, channel_id TEXT, author_id TEXT, content TEXT, From 0111bc0eae726d39e3238890fe675819582f11cd Mon Sep 17 00:00:00 2001 From: HuntingFighter Date: Fri, 27 Mar 2026 16:36:58 +0100 Subject: [PATCH 2/8] Drop old messages table --- src/core/backup_database.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/core/backup_database.py b/src/core/backup_database.py index 61b0aaa..733612c 100644 --- a/src/core/backup_database.py +++ b/src/core/backup_database.py @@ -78,6 +78,7 @@ class BackupDatabase: ) """) self._conn.execute("INSERT INTO messages SELECT * FROM messages_old") + self._conn.execute("DROP TABLE messages_old") self._conn.commit() def _init_db(self): From 5750133b3a49c32ecd23106f6100853fcd2f9db6 Mon Sep 17 00:00:00 2001 From: rambros Date: Sat, 28 Mar 2026 02:46:19 +0530 Subject: [PATCH 3/8] use int for discord snowflake ids --- src/core/backup_database.py | 227 +++++++++++++++++++++--------------- src/core/database.py | 147 ++++++++++++++++------- src/core/exporter.py | 5 +- src/core/utils.py | 14 +++ 4 files changed, 255 insertions(+), 138 deletions(-) diff --git a/src/core/backup_database.py b/src/core/backup_database.py index 733612c..88d0df1 100644 --- a/src/core/backup_database.py +++ b/src/core/backup_database.py @@ -4,20 +4,11 @@ import json import threading from pathlib import Path from typing import Dict, Any, List, Optional, Union +from src.core.utils import parse_snowflake logger = logging.getLogger(__name__) -def parse_snowflake(value: Any) -> Optional[int]: - """Safely parses a Discord ID (Snowflake) from any input, handling 'None' strings.""" - if value is None: - return None - s = str(value).strip() - if not s or s.lower() == "none" or s == "NULL": - return None - try: - return int(s) - except ValueError: - return None + class BackupDatabase: """Manages the SQLite database for local Discord backups.""" @@ -39,47 +30,101 @@ class BackupDatabase: def _migrate_db(self): 
"""Handles backward compatibility by renaming columns in existing databases.""" with self._lock: - # Check 'media_pool' table - self._conn.execute("") - res = self._conn.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='media_pool'").fetchone() - if res[0] > 0: - cols = self._conn.execute("PRAGMA table_info(media_pool)").fetchall() - col_names = [c["name"] for c in cols] - if "mime_type" in col_names and "content_type" not in col_names: - logger.info("Migrating media_pool: renaming 'mime_type' to 'content_type'") - self._conn.execute("ALTER TABLE media_pool RENAME COLUMN mime_type TO content_type") + conn = self._conn + # 1. MIME Type to Content Type Migrations + for table in ["media_pool", "server_assets"]: + res = conn.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name='{table}'").fetchone() + if res and res[0] > 0: + cols = conn.execute(f"PRAGMA table_info({table})").fetchall() + col_names = [c["name"] for c in cols] + if "mime_type" in col_names and "content_type" not in col_names: + logger.info(f"Migrating {table}: renaming 'mime_type' to 'content_type'") + conn.execute(f"ALTER TABLE {table} RENAME COLUMN mime_type TO content_type") - # Check 'server_assets' table - res = self._conn.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='server_assets'").fetchone() - if res[0] > 0: - cols = self._conn.execute("PRAGMA table_info(server_assets)").fetchall() - col_names = [c["name"] for c in cols] - if "mime_type" in col_names and "content_type" not in col_names: - logger.info("Migrating server_assets: renaming 'mime_type' to 'content_type'") - self._conn.execute("ALTER TABLE server_assets RENAME COLUMN mime_type TO content_type") - res = self._conn.execute("SELECT count(*) FROM messages LIMIT 1").fetchone() - if res[0] > 0: - cols = self._conn.execute("PRAGMA table_info(messages)").fetchall() - id_type = next(col for col in cols if col[1] == "id")[2] - if id_type == "TEXT": - logger.info("Migrating 
messages: Changing id column type to integer") - self._conn.execute("ALTER TABLE messages RENAME TO messages_old") - self._conn.execute("""CREATE TABLE IF NOT EXISTS messages - ( - id INTEGER PRIMARY KEY, - channel_id TEXT, - author_id TEXT, - content TEXT, - timestamp TEXT, - type INTEGER, - message_reference TEXT, - is_pinned INTEGER, - extra_data TEXT - ) - """) - self._conn.execute("INSERT INTO messages SELECT * FROM messages_old") - self._conn.execute("DROP TABLE messages_old") - self._conn.commit() + # 2. Universal ID Migration (TEXT -> INTEGER) + # Mapping of table names to columns that must be INTEGER (Snowflakes) + id_migrations = { + "guild_profile": ["id", "owner_id"], + "roles": ["id", "permissions"], + "channels": ["id", "category_id"], + "permissions": ["channel_id", "target_id"], + "users": ["id"], + "messages": ["id", "channel_id", "author_id", "message_reference"], + "attachments": ["id", "message_id"], + "embeds": ["message_id"], + "reactions": ["message_id", "emoji_id"], + "message_stickers": ["message_id", "sticker_id"], + "threads": ["id", "parent_id"], + "forum_tags": ["id", "forum_id", "emoji_id"], + "server_assets": ["id"] + } + + for table, id_cols in id_migrations.items(): + res = conn.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name='{table}'").fetchone() + if not res or res[0] == 0: + continue + + cols = conn.execute(f"PRAGMA table_info({table})").fetchall() + needs_migration = False + for col in cols: + if col[1] in id_cols and col[2] == "TEXT": + needs_migration = True + break + + if needs_migration: + logger.info(f"Migrating {table}: converting ID columns to INTEGER") + # Special Case: messages already handled id, but now generic + # We use a temporary table to handle the schema change + conn.execute(f"ALTER TABLE {table} RENAME TO {table}_old") + + # We can't easily generate the CREATE TABLE here without duplicating _init_db logic + # So we call _init_db to create the NEW table, then copy data + # But _init_db 
has 'IF NOT EXISTS', so we just call it once at the end? + # No, we need the table NOW for the INSERT. + # I'll just manually define the inserts or do it in _init_db. + + # Actually, a better way is to do the CREATE TABLE here for this specific table. + # I'll have to duplicate the schema from _init_db for the migration. + + # Alternatively, since we are already in _migrate_db, we can just do the + # specific CREATE TABLE for the table we are migrating. + + if table == "guild_profile": + conn.execute("CREATE TABLE guild_profile (id INTEGER PRIMARY KEY, name TEXT, description TEXT, icon_file TEXT, icon_url TEXT, banner_file TEXT, banner_url TEXT, owner_id INTEGER, last_backup TEXT, ignore_channels TEXT)") + elif table == "roles": + conn.execute("CREATE TABLE roles (id INTEGER PRIMARY KEY, name TEXT, color INTEGER, position INTEGER, permissions INTEGER, hoist INTEGER, mentionable INTEGER)") + elif table == "channels": + conn.execute("CREATE TABLE channels (id INTEGER PRIMARY KEY, name TEXT, type INTEGER, position INTEGER, category_id INTEGER, topic TEXT, nsfw INTEGER, bitrate INTEGER, slowmode_delay INTEGER)") + elif table == "permissions": + conn.execute("CREATE TABLE permissions (id INTEGER PRIMARY KEY AUTOINCREMENT, channel_id INTEGER, target_id INTEGER, target_type TEXT, allow INTEGER, deny INTEGER)") + elif table == "users": + conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT, display_name TEXT, avatar_file TEXT, avatar_url TEXT, roles TEXT)") + elif table == "messages": + conn.execute("CREATE TABLE messages (id INTEGER PRIMARY KEY, channel_id INTEGER, author_id INTEGER, content TEXT, timestamp TEXT, type INTEGER, message_reference INTEGER, is_pinned INTEGER, extra_data TEXT)") + elif table == "attachments": + conn.execute("CREATE TABLE attachments (id INTEGER PRIMARY KEY, message_id INTEGER, filename TEXT, size INTEGER, url TEXT, content_type TEXT, local_hash TEXT)") + elif table == "embeds": + conn.execute("CREATE TABLE embeds (id INTEGER 
PRIMARY KEY AUTOINCREMENT, message_id INTEGER, title TEXT, description TEXT, url TEXT, color INTEGER, timestamp TEXT, thumbnail_url TEXT, image_url TEXT, author_name TEXT, author_url TEXT, author_icon_url TEXT, footer_text TEXT, footer_icon_url TEXT, fields TEXT)") + elif table == "reactions": + conn.execute("CREATE TABLE reactions (id INTEGER PRIMARY KEY AUTOINCREMENT, message_id INTEGER, emoji_id INTEGER, emoji_name TEXT, count INTEGER)") + elif table == "message_stickers": + conn.execute("CREATE TABLE message_stickers (message_id INTEGER, sticker_id INTEGER, name TEXT, url TEXT, format_type INTEGER, local_hash TEXT, PRIMARY KEY (message_id, sticker_id))") + elif table == "threads": + conn.execute("CREATE TABLE threads (id INTEGER PRIMARY KEY, name TEXT, type INTEGER, parent_id INTEGER, message_count INTEGER, member_count INTEGER, archived INTEGER, archive_timestamp TEXT, auto_archive_duration INTEGER, locked INTEGER, applied_tags TEXT)") + elif table == "forum_tags": + conn.execute("CREATE TABLE forum_tags (id INTEGER PRIMARY KEY, forum_id INTEGER, name TEXT, moderated INTEGER, emoji_id INTEGER, emoji_name TEXT)") + elif table == "server_assets": + conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type INTEGER)") + + old_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table}_old)").fetchall()] + new_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table})").fetchall()] + common_cols = [c for c in old_cols if c in new_cols] + col_str = ", ".join(common_cols) + + conn.execute(f"INSERT INTO {table} ({col_str}) SELECT {col_str} FROM {table}_old") + conn.execute(f"DROP TABLE {table}_old") + + conn.commit() def _init_db(self): """Initializes the database schema.""" @@ -89,14 +134,14 @@ class BackupDatabase: # Guild Profile conn.execute(""" CREATE TABLE IF NOT EXISTS guild_profile ( - id TEXT PRIMARY KEY, + id INTEGER PRIMARY KEY, name TEXT, description TEXT, icon_file TEXT, icon_url 
TEXT, banner_file TEXT, banner_url TEXT, - owner_id TEXT, + owner_id INTEGER, last_backup TEXT, ignore_channels TEXT ) @@ -105,11 +150,11 @@ class BackupDatabase: # Roles conn.execute(""" CREATE TABLE IF NOT EXISTS roles ( - id TEXT PRIMARY KEY, + id INTEGER PRIMARY KEY, name TEXT, color INTEGER, position INTEGER, - permissions TEXT, + permissions INTEGER, hoist INTEGER, mentionable INTEGER ) @@ -118,11 +163,11 @@ class BackupDatabase: # Channels conn.execute(""" CREATE TABLE IF NOT EXISTS channels ( - id TEXT PRIMARY KEY, + id INTEGER PRIMARY KEY, name TEXT, type INTEGER, position INTEGER, - category_id TEXT, + category_id INTEGER, topic TEXT, nsfw INTEGER, bitrate INTEGER, @@ -134,8 +179,8 @@ class BackupDatabase: conn.execute(""" CREATE TABLE IF NOT EXISTS permissions ( id INTEGER PRIMARY KEY AUTOINCREMENT, - channel_id TEXT, - target_id TEXT, + channel_id INTEGER, + target_id INTEGER, target_type TEXT, allow INTEGER, deny INTEGER @@ -146,7 +191,7 @@ class BackupDatabase: # Users (Author cache) conn.execute(""" CREATE TABLE IF NOT EXISTS users ( - id TEXT PRIMARY KEY, + id INTEGER PRIMARY KEY, username TEXT, display_name TEXT, avatar_file TEXT, @@ -159,12 +204,12 @@ class BackupDatabase: conn.execute(""" CREATE TABLE IF NOT EXISTS messages ( id INTEGER PRIMARY KEY, - channel_id TEXT, - author_id TEXT, + channel_id INTEGER, + author_id INTEGER, content TEXT, timestamp TEXT, type INTEGER, - message_reference TEXT, + message_reference INTEGER, is_pinned INTEGER, extra_data TEXT ) @@ -175,8 +220,8 @@ class BackupDatabase: # Attachments conn.execute(""" CREATE TABLE IF NOT EXISTS attachments ( - id TEXT PRIMARY KEY, - message_id TEXT, + id INTEGER PRIMARY KEY, + message_id INTEGER, filename TEXT, size INTEGER, url TEXT, @@ -190,7 +235,7 @@ class BackupDatabase: conn.execute(""" CREATE TABLE IF NOT EXISTS embeds ( id INTEGER PRIMARY KEY AUTOINCREMENT, - message_id TEXT, + message_id INTEGER, title TEXT, description TEXT, url TEXT, @@ -212,8 +257,8 @@ class 
BackupDatabase: conn.execute(""" CREATE TABLE IF NOT EXISTS reactions ( id INTEGER PRIMARY KEY AUTOINCREMENT, - message_id TEXT, - emoji_id TEXT, + message_id INTEGER, + emoji_id INTEGER, emoji_name TEXT, count INTEGER ) @@ -223,8 +268,8 @@ class BackupDatabase: # Message Stickers conn.execute(""" CREATE TABLE IF NOT EXISTS message_stickers ( - message_id TEXT, - sticker_id TEXT, + message_id INTEGER, + sticker_id INTEGER, name TEXT, url TEXT, format_type INTEGER, @@ -237,10 +282,10 @@ class BackupDatabase: # Threads conn.execute(""" CREATE TABLE IF NOT EXISTS threads ( - id TEXT PRIMARY KEY, + id INTEGER PRIMARY KEY, name TEXT, type INTEGER, - parent_id TEXT, + parent_id INTEGER, message_count INTEGER, member_count INTEGER, archived INTEGER, @@ -255,11 +300,11 @@ class BackupDatabase: # Forum Tags (Definitions for a forum channel) conn.execute(""" CREATE TABLE IF NOT EXISTS forum_tags ( - id TEXT PRIMARY KEY, - forum_id TEXT, + id INTEGER PRIMARY KEY, + forum_id INTEGER, name TEXT, moderated INTEGER, - emoji_id TEXT, + emoji_id INTEGER, emoji_name TEXT ) """) @@ -280,7 +325,7 @@ class BackupDatabase: # Server Assets (Emojis, Stickers, etc.) conn.execute(""" CREATE TABLE IF NOT EXISTS server_assets ( - id TEXT PRIMARY KEY, + id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, @@ -299,10 +344,10 @@ class BackupDatabase: INSERT OR REPLACE INTO guild_profile (id, name, description, icon_file, icon_url, banner_file, banner_url, owner_id, last_backup, ignore_channels) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
""", ( - str(data.get("id")), data.get("name"), data.get("description"), + parse_snowflake(data.get("id")), data.get("name"), data.get("description"), data.get("icon_file"), data.get("icon_url"), data.get("banner_file"), data.get("banner_url"), - str(data.get("owner_id")), + parse_snowflake(data.get("owner_id")), data.get("last_backup"), json.dumps(data.get("ignore_channels", [])) )) self._conn.commit() @@ -323,7 +368,7 @@ class BackupDatabase: with self._lock: formatted = [ { - "id": str(r["id"]), + "id": parse_snowflake(r["id"]), "name": r["name"], "color": r["color"], "position": r["position"], @@ -370,7 +415,7 @@ class BackupDatabase: with self._lock: formatted = [ { - "id": str(a["id"]), + "id": parse_snowflake(a["id"]), "name": a.get("name"), "type": a.get("type"), "filename": a.get("filename"), @@ -430,7 +475,7 @@ class BackupDatabase: for rea in msg["reactions"]: all_reactions.append({ "message_id": msg["id"], - "emoji_id": str(rea["emoji_id"]) if rea.get("emoji_id") else None, + "emoji_id": parse_snowflake(rea["emoji_id"]) if rea.get("emoji_id") else None, "emoji_name": rea.get("emoji_name"), "count": rea.get("count", 0) }) @@ -440,7 +485,7 @@ class BackupDatabase: for st in msg["stickers"]: all_stickers.append({ "message_id": msg["id"], - "sticker_id": str(st["id"]), + "sticker_id": parse_snowflake(st["id"]), "name": st.get("name"), "url": st.get("url"), "format_type": st.get("format_type"), @@ -492,7 +537,7 @@ class BackupDatabase: def get_last_message_id(self, channel_id: str) -> Optional[str]: with self._lock: - row = self._conn.execute("SELECT id FROM messages WHERE channel_id = ? ORDER BY id DESC LIMIT 1", (str(channel_id),)).fetchone() + row = self._conn.execute("SELECT id FROM messages WHERE channel_id = ? 
ORDER BY id DESC LIMIT 1", (parse_snowflake(channel_id),)).fetchone() return row["id"] if row else None def get_media_by_hash(self, file_hash: str) -> Optional[Dict[str, Any]]: @@ -623,7 +668,7 @@ class BackupDatabase: """Returns forum tag definitions.""" with self._lock: if forum_id: - rows = self._conn.execute("SELECT * FROM forum_tags WHERE forum_id = ?", (str(forum_id),)).fetchall() + rows = self._conn.execute("SELECT * FROM forum_tags WHERE forum_id = ?", (parse_snowflake(forum_id),)).fetchall() else: rows = self._conn.execute("SELECT * FROM forum_tags").fetchall() return [dict(r) for r in rows] @@ -631,13 +676,13 @@ class BackupDatabase: def get_threads_by_parent(self, parent_id: str) -> List[Dict[str, Any]]: """Returns all threads belonging to a parent channel.""" with self._lock: - rows = self._conn.execute("SELECT * FROM threads WHERE parent_id = ?", (str(parent_id),)).fetchall() + rows = self._conn.execute("SELECT * FROM threads WHERE parent_id = ?", (parse_snowflake(parent_id),)).fetchall() return [dict(r) for r in rows] def get_thread(self, thread_id: str) -> Optional[Dict[str, Any]]: """Retrieves a single thread's metadata.""" with self._lock: - row = self._conn.execute("SELECT * FROM threads WHERE id = ?", (str(thread_id),)).fetchone() + row = self._conn.execute("SELECT * FROM threads WHERE id = ?", (parse_snowflake(thread_id),)).fetchone() return dict(row) if row else None def get_all_users(self) -> List[Dict[str, Any]]: @@ -647,7 +692,7 @@ class BackupDatabase: def get_user(self, user_id: str) -> Optional[Dict[str, Any]]: with self._lock: - row = self._conn.execute("SELECT * FROM users WHERE id = ?", (str(user_id),)).fetchone() + row = self._conn.execute("SELECT * FROM users WHERE id = ?", (parse_snowflake(user_id),)).fetchone() if row: data = dict(row) if data.get("roles"): @@ -673,11 +718,11 @@ class BackupDatabase: def get_messages_paged(self, channel_id: str, limit: int = 100, offset: int = 0, after_id: Optional[str] = None) -> List[Dict[str, 
Any]]: with self._lock: query = "SELECT * FROM messages WHERE channel_id = ?" - params = [str(channel_id)] + params = [parse_snowflake(channel_id)] if after_id: query += " AND id > ?" - params.append(str(after_id)) + params.append(parse_snowflake(after_id)) query += " ORDER BY id ASC LIMIT ? OFFSET ?" params.extend([limit, offset]) @@ -748,7 +793,7 @@ class BackupDatabase: def delete_channel_messages(self, channel_id: Union[str, int]): """Deletes all messages and related metadata for a specific channel and its threads.""" - cid = str(channel_id) + cid = parse_snowflake(channel_id) with self._lock: # 1. Identify all channel IDs involved (parent + all threads) target_ids = [cid] @@ -830,4 +875,4 @@ class BackupDatabase: self._conn.commit() self._conn.close() except Exception: - pass + pass \ No newline at end of file diff --git a/src/core/database.py b/src/core/database.py index b2230b9..aeffa49 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Optional, Dict, Any import threading import sys +from src.core.utils import parse_snowflake logger = logging.getLogger(__name__) @@ -27,16 +28,72 @@ class MigrationDatabase: return self._local.conn def _init_db(self): - """Initialize tables if they don't exist.""" + """Initialize tables if they don't exist and handle migrations.""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() + # 1. MIME Type to Content Type Migrations (if applicable - not in this class usually) + + # 2. 
Universal ID Migration (TEXT -> INTEGER) + # Mapping of table names to columns that must be INTEGER (Snowflakes) + id_migrations = { + "message_mappings": ["channel_id", "source_msg_id", "target_msg_id"], + "thread_mappings": ["channel_id", "thread_id", "source_msg_id", "target_msg_id"], + "channel_tracking": ["channel_id", "last_msg_id"], + "thread_tracking": ["channel_id", "thread_id", "last_msg_id"], + "server_mappings": ["source_id", "target_id"], + "asset_mappings": ["source_id", "target_id"], + "user_alias": ["user_id"] + } + + for table, id_cols in id_migrations.items(): + cursor.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name='{table}'") + res = cursor.fetchone() + if not res or res[0] == 0: + continue + + cursor.execute(f"PRAGMA table_info({table})") + cols = cursor.fetchall() + needs_migration = False + for col in cols: + if col[1] in id_cols and col[2] == "TEXT": + needs_migration = True + break + + if needs_migration: + logger.info(f"MigrationDatabase: Migrating {table}: converting ID columns to INTEGER") + cursor.execute(f"ALTER TABLE {table} RENAME TO {table}_old") + + if table == "message_mappings": + cursor.execute("CREATE TABLE message_mappings (channel_id INTEGER, source_msg_id INTEGER, target_msg_id INTEGER, timestamp TEXT, PRIMARY KEY (channel_id, source_msg_id))") + elif table == "thread_mappings": + cursor.execute("CREATE TABLE thread_mappings (channel_id INTEGER, thread_id INTEGER, source_msg_id INTEGER, target_msg_id INTEGER, timestamp TEXT, PRIMARY KEY (channel_id, thread_id, source_msg_id))") + elif table == "channel_tracking": + cursor.execute("CREATE TABLE channel_tracking (channel_id INTEGER PRIMARY KEY, last_msg_id INTEGER, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0)") + elif table == "thread_tracking": + cursor.execute("CREATE TABLE thread_tracking (channel_id INTEGER, thread_id INTEGER, last_msg_id INTEGER, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER 
DEFAULT 0, completed INTEGER DEFAULT 0, PRIMARY KEY (channel_id, thread_id))") + elif table == "server_mappings": + cursor.execute("CREATE TABLE server_mappings (category TEXT, source_id INTEGER, target_id INTEGER, PRIMARY KEY (category, source_id))") + elif table == "asset_mappings": + cursor.execute("CREATE TABLE asset_mappings (category TEXT, source_id INTEGER, target_id INTEGER, PRIMARY KEY (category, source_id))") + elif table == "user_alias": + cursor.execute("CREATE TABLE user_alias (user_id INTEGER PRIMARY KEY, alias TEXT UNIQUE)") + + old_cols = [c[1] for c in cursor.execute(f"PRAGMA table_info({table}_old)").fetchall()] + new_cols = [c[1] for c in cursor.execute(f"PRAGMA table_info({table})").fetchall()] + common_cols = [c for c in old_cols if c in new_cols] + col_str = ", ".join(common_cols) + + cursor.execute(f"INSERT OR IGNORE INTO {table} ({col_str}) SELECT {col_str} FROM {table}_old") + cursor.execute(f"DROP TABLE {table}_old") + + # Initial Creation / Ensure Schema Correctness # Table for message mappings: SourceID -> TargetID cursor.execute(""" CREATE TABLE IF NOT EXISTS message_mappings ( - channel_id TEXT, - source_msg_id TEXT, - target_msg_id TEXT, + channel_id INTEGER, + source_msg_id INTEGER, + target_msg_id INTEGER, timestamp TEXT, PRIMARY KEY (channel_id, source_msg_id) ) @@ -45,10 +102,10 @@ class MigrationDatabase: # Table for thread mappings cursor.execute(""" CREATE TABLE IF NOT EXISTS thread_mappings ( - channel_id TEXT, - thread_id TEXT, - source_msg_id TEXT, - target_msg_id TEXT, + channel_id INTEGER, + thread_id INTEGER, + source_msg_id INTEGER, + target_msg_id INTEGER, timestamp TEXT, PRIMARY KEY (channel_id, thread_id, source_msg_id) ) @@ -57,8 +114,8 @@ class MigrationDatabase: # Table for per-channel stats and tracking cursor.execute(""" CREATE TABLE IF NOT EXISTS channel_tracking ( - channel_id TEXT PRIMARY KEY, - last_msg_id TEXT, + channel_id INTEGER PRIMARY KEY, + last_msg_id INTEGER, last_msg_ts TEXT, msg_count INTEGER 
DEFAULT 0, file_count INTEGER DEFAULT 0 @@ -68,9 +125,9 @@ class MigrationDatabase: # Table for per-thread stats cursor.execute(""" CREATE TABLE IF NOT EXISTS thread_tracking ( - channel_id TEXT, - thread_id TEXT, - last_msg_id TEXT, + channel_id INTEGER, + thread_id INTEGER, + last_msg_id INTEGER, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0, @@ -89,8 +146,8 @@ class MigrationDatabase: cursor.execute(""" CREATE TABLE IF NOT EXISTS server_mappings ( category TEXT, - source_id TEXT, - target_id TEXT, + source_id INTEGER, + target_id INTEGER, PRIMARY KEY (category, source_id) ) """) @@ -99,8 +156,8 @@ class MigrationDatabase: cursor.execute(""" CREATE TABLE IF NOT EXISTS asset_mappings ( category TEXT, - source_id TEXT, - target_id TEXT, + source_id INTEGER, + target_id INTEGER, PRIMARY KEY (category, source_id) ) """) @@ -136,7 +193,7 @@ class MigrationDatabase: # Table for auto-generated user aliases (user_id -> alias) cursor.execute(""" CREATE TABLE IF NOT EXISTS user_alias ( - user_id TEXT PRIMARY KEY, + user_id INTEGER PRIMARY KEY, alias TEXT UNIQUE ) """) @@ -152,7 +209,7 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO message_mappings (channel_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?)", - (channel_id, source_id, target_id, timestamp) + (parse_snowflake(channel_id), parse_snowflake(source_id), parse_snowflake(target_id), timestamp) ) conn.commit() @@ -160,7 +217,7 @@ class MigrationDatabase: conn = self._get_conn() row = conn.execute( "SELECT target_msg_id FROM message_mappings WHERE channel_id = ? 
AND source_msg_id = ?", - (channel_id, source_id) + (parse_snowflake(channel_id), parse_snowflake(source_id)) ).fetchone() return row["target_msg_id"] if row else None @@ -205,7 +262,7 @@ class MigrationDatabase: conn = self._get_conn() # Check for existing alias - row = conn.execute("SELECT alias FROM user_alias WHERE user_id = ?", (str(user_id),)).fetchone() + row = conn.execute("SELECT alias FROM user_alias WHERE user_id = ?", (parse_snowflake(user_id),)).fetchone() if row: return row["alias"] @@ -215,7 +272,7 @@ class MigrationDatabase: new_alias = self._generate_alias() conn.execute( "INSERT INTO user_alias (user_id, alias) VALUES (?, ?)", - (str(user_id), new_alias) + (parse_snowflake(user_id), new_alias) ) conn.commit() return new_alias @@ -239,7 +296,7 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO server_mappings (category, source_id, target_id) VALUES (?, ?, ?)", - (category, str(source_id), str(target_id)) + (category, parse_snowflake(source_id), parse_snowflake(target_id)) ) conn.commit() @@ -247,7 +304,7 @@ class MigrationDatabase: conn = self._get_conn() row = conn.execute( "SELECT target_id FROM server_mappings WHERE category = ? AND source_id = ?", - (category, str(source_id)) + (category, parse_snowflake(source_id)) ).fetchone() return row["target_id"] if row else None @@ -263,7 +320,7 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "DELETE FROM server_mappings WHERE category = ? 
AND source_id = ?", - (category, str(source_id)) + (category, parse_snowflake(source_id)) ) conn.commit() @@ -281,7 +338,7 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO asset_mappings (category, source_id, target_id) VALUES (?, ?, ?)", - (category, str(source_id), str(target_id)) + (category, parse_snowflake(source_id), parse_snowflake(target_id)) ) conn.commit() @@ -289,7 +346,7 @@ class MigrationDatabase: conn = self._get_conn() row = conn.execute( "SELECT target_id FROM asset_mappings WHERE category = ? AND source_id = ?", - (category, str(source_id)) + (category, parse_snowflake(source_id)) ).fetchone() return row["target_id"] if row else None @@ -305,7 +362,7 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "DELETE FROM asset_mappings WHERE category = ? AND source_id = ?", - (category, str(source_id)) + (category, parse_snowflake(source_id)) ) conn.commit() @@ -332,23 +389,23 @@ class MigrationDatabase: def update_channel_tracking(self, channel_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0): conn = self._get_conn() # Initialize if missing - conn.execute("INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)", (channel_id,)) + conn.execute("INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)", (parse_snowflake(channel_id),)) if last_msg_id: - conn.execute("UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?", (last_msg_id, channel_id)) + conn.execute("UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?", (parse_snowflake(last_msg_id), parse_snowflake(channel_id))) if last_msg_ts: - conn.execute("UPDATE channel_tracking SET last_msg_ts = ? WHERE channel_id = ?", (last_msg_ts, channel_id)) + conn.execute("UPDATE channel_tracking SET last_msg_ts = ? 
WHERE channel_id = ?", (last_msg_ts, parse_snowflake(channel_id))) if msg_inc != 0 or file_inc != 0: conn.execute( "UPDATE channel_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ?", - (msg_inc, file_inc, channel_id) + (msg_inc, file_inc, parse_snowflake(channel_id)) ) conn.commit() def get_channel_tracking(self, channel_id: str) -> Dict[str, Any]: conn = self._get_conn() - row = conn.execute("SELECT * FROM channel_tracking WHERE channel_id = ?", (channel_id,)).fetchone() + row = conn.execute("SELECT * FROM channel_tracking WHERE channel_id = ?", (parse_snowflake(channel_id),)).fetchone() if row: return dict(row) return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} @@ -358,7 +415,7 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO thread_mappings (channel_id, thread_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?, ?)", - (channel_id, thread_id, source_id, target_id, timestamp) + (parse_snowflake(channel_id), parse_snowflake(thread_id), parse_snowflake(source_id), parse_snowflake(target_id), timestamp) ) conn.commit() @@ -366,31 +423,31 @@ class MigrationDatabase: conn = self._get_conn() row = conn.execute( "SELECT target_msg_id FROM thread_mappings WHERE channel_id = ? AND thread_id = ? 
AND source_msg_id = ?", - (channel_id, thread_id, source_id) + (parse_snowflake(channel_id), parse_snowflake(thread_id), parse_snowflake(source_id)) ).fetchone() return row["target_msg_id"] if row else None def update_thread_tracking(self, channel_id: str, thread_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0, completed: int = None): conn = self._get_conn() - conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (channel_id, thread_id)) + conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (parse_snowflake(channel_id), parse_snowflake(thread_id))) if last_msg_id: - conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_id, channel_id, thread_id)) + conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (parse_snowflake(last_msg_id), parse_snowflake(channel_id), parse_snowflake(thread_id))) if last_msg_ts: - conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, channel_id, thread_id)) + conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, parse_snowflake(channel_id), parse_snowflake(thread_id))) if completed is not None: - conn.execute("UPDATE thread_tracking SET completed = ? WHERE channel_id = ? AND thread_id = ?", (completed, channel_id, thread_id)) + conn.execute("UPDATE thread_tracking SET completed = ? WHERE channel_id = ? AND thread_id = ?", (completed, parse_snowflake(channel_id), parse_snowflake(thread_id))) if msg_inc != 0 or file_inc != 0: conn.execute( "UPDATE thread_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ? 
AND thread_id = ?", - (msg_inc, file_inc, channel_id, thread_id) + (msg_inc, file_inc, parse_snowflake(channel_id), parse_snowflake(thread_id)) ) conn.commit() def get_thread_tracking(self, channel_id: str, thread_id: str) -> Dict[str, Any]: conn = self._get_conn() - row = conn.execute("SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?", (channel_id, thread_id)).fetchone() + row = conn.execute("SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?", (parse_snowflake(channel_id), parse_snowflake(thread_id))).fetchone() if row: return dict(row) return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} @@ -398,10 +455,10 @@ class MigrationDatabase: def clear_channel_data(self, channel_id: str): """Purge all mappings and tracking data for a specific channel and its threads.""" conn = self._get_conn() - conn.execute("DELETE FROM message_mappings WHERE channel_id = ?", (channel_id,)) - conn.execute("DELETE FROM thread_mappings WHERE channel_id = ?", (channel_id,)) - conn.execute("DELETE FROM channel_tracking WHERE channel_id = ?", (channel_id,)) - conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (channel_id,)) + conn.execute("DELETE FROM message_mappings WHERE channel_id = ?", (parse_snowflake(channel_id),)) + conn.execute("DELETE FROM thread_mappings WHERE channel_id = ?", (parse_snowflake(channel_id),)) + conn.execute("DELETE FROM channel_tracking WHERE channel_id = ?", (parse_snowflake(channel_id),)) + conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (parse_snowflake(channel_id),)) conn.commit() logger.info(f"Cleared all tracking and mapping data for channel: {channel_id}") diff --git a/src/core/exporter.py b/src/core/exporter.py index e60aaf7..7ad8926 100644 --- a/src/core/exporter.py +++ b/src/core/exporter.py @@ -109,7 +109,7 @@ class DiscordExporter: "name": r.name, "color": r.color.value, "position": r.position, - "permissions": str(r.permissions.value), + "permissions": 
r.permissions.value, "hoist": 1 if r.hoist else 0, "mentionable": 1 if r.mentionable else 0 }) @@ -563,9 +563,10 @@ class DiscordExporter: }) # 5. Message data + from src.core.utils import parse_snowflake message_reference = None if msg.reference and msg.reference.message_id: - message_reference = str(msg.reference.message_id) + message_reference = parse_snowflake(msg.reference.message_id) # 5.5 Forwarded snapshots content = msg.content or "" diff --git a/src/core/utils.py b/src/core/utils.py index 1375ef5..8402395 100644 --- a/src/core/utils.py +++ b/src/core/utils.py @@ -1,5 +1,19 @@ +from typing import Any, Optional import re import logging + +def parse_snowflake(value: Any) -> Optional[int]: + """Safely parses a Discord ID (Snowflake) from any input, handling 'None' strings.""" + if value is None: + return None + s = str(value).strip() + if not s or s.lower() == "none" or s == "NULL": + return None + try: + return int(s) + except ValueError: + return None + from src.core.state import MigrationState logger = logging.getLogger(__name__) From 5ae515dd7ab918990adf13a54a5fdc8d477230aa Mon Sep 17 00:00:00 2001 From: rambros Date: Sat, 28 Mar 2026 02:47:14 +0530 Subject: [PATCH 4/8] int for fluxer, str for stoat --- src/core/base.py | 2 +- src/core/database.py | 217 +++++++++++++++++++------------- src/core/state.py | 23 ++-- src/core/utils.py | 7 +- src/fluxer/emoji_stickers.py | 8 +- src/fluxer/migrate_message.py | 6 +- src/fluxer/roles_permissions.py | 4 +- src/stoat/emoji_stickers.py | 3 +- 8 files changed, 158 insertions(+), 112 deletions(-) diff --git a/src/core/base.py b/src/core/base.py index 87086b1..81943f8 100644 --- a/src/core/base.py +++ b/src/core/base.py @@ -129,7 +129,7 @@ class MigrationContext: # or a logical subfolder. 
base_dir = getattr(self, "base_dir", "") - self.state.set_folder(community_id, clean_name, base_dir=base_dir) + self.state.set_folder(community_id, clean_name, self.target_platform, base_dir=base_dir) async def start_connections(self): await self.discord_reader.start() diff --git a/src/core/database.py b/src/core/database.py index aeffa49..996e083 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -16,8 +16,9 @@ class MigrationDatabase: Replaces the memory-bloated and O(N^2) JSON persistence for messages. """ - def __init__(self, db_path: Path): + def __init__(self, db_path: Path, platform: str = None): self.db_path = db_path + self.platform = platform.lower() if platform else None self._local = threading.local() self._init_db() @@ -28,25 +29,45 @@ class MigrationDatabase: return self._local.conn def _init_db(self): - """Initialize tables if they don't exist and handle migrations.""" + """Initialize tables if they don't exist and handle migrations/platform-specific schemas.""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - # 1. MIME Type to Content Type Migrations (if applicable - not in this class usually) + # 1. Determine active platform and column types + # Create metadata table first as we need it for platform tracking + cursor.execute("CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)") - # 2. Universal ID Migration (TEXT -> INTEGER) - # Mapping of table names to columns that must be INTEGER (Snowflakes) + # Load platform from DB if exists + cursor.execute("SELECT value FROM metadata WHERE key = ?", ("target_platform",)) + row = cursor.fetchone() + stored_platform = row[0] if row else None + + # If platform provided, update stored platform. If not provided, use stored. 
+ active_platform = self.platform or stored_platform or "stoat" # Default to stoat if unknown + if self.platform and self.platform != stored_platform: + cursor.execute("INSERT OR REPLACE INTO metadata (key, value) VALUES (?, ?)", ("target_platform", active_platform)) + conn.commit() + + # Define types + # source_type is always INTEGER (Discord/Fluxer snowflakes) + # target_type is INTEGER for Fluxer, TEXT for Stoat + source_type = "INTEGER" + target_type = "INTEGER" if active_platform == "fluxer" else "TEXT" + + # 2. Universal ID Migration (TEXT -> INTEGER vs Platform Switch) + # Mapping of table names to columns that must match their respective types + # key: table, value: (discord_cols, target_cols) id_migrations = { - "message_mappings": ["channel_id", "source_msg_id", "target_msg_id"], - "thread_mappings": ["channel_id", "thread_id", "source_msg_id", "target_msg_id"], - "channel_tracking": ["channel_id", "last_msg_id"], - "thread_tracking": ["channel_id", "thread_id", "last_msg_id"], - "server_mappings": ["source_id", "target_id"], - "asset_mappings": ["source_id", "target_id"], - "user_alias": ["user_id"] + "message_mappings": (["source_msg_id"], ["channel_id", "target_msg_id"]), + "thread_mappings": (["source_msg_id"], ["channel_id", "thread_id", "target_msg_id"]), + "channel_tracking": ([], ["channel_id", "last_msg_id"]), + "thread_tracking": ([], ["channel_id", "thread_id", "last_msg_id"]), + "server_mappings": (["source_id"], ["target_id"]), + "asset_mappings": (["source_id"], ["target_id"]), + "user_alias": (["user_id"], []) } - for table, id_cols in id_migrations.items(): + for table, (discord_cols, target_cols) in id_migrations.items(): cursor.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name='{table}'") res = cursor.fetchone() if not res or res[0] == 0: @@ -56,28 +77,32 @@ class MigrationDatabase: cols = cursor.fetchall() needs_migration = False for col in cols: - if col[1] in id_cols and col[2] == "TEXT": + c_name, c_type = 
col[1], col[2] + if c_name in discord_cols and c_type != source_type: + needs_migration = True + break + if c_name in target_cols and c_type != target_type: needs_migration = True break if needs_migration: - logger.info(f"MigrationDatabase: Migrating {table}: converting ID columns to INTEGER") + logger.info(f"MigrationDatabase: Migrating {table} schema (Platform: {active_platform})") cursor.execute(f"ALTER TABLE {table} RENAME TO {table}_old") if table == "message_mappings": - cursor.execute("CREATE TABLE message_mappings (channel_id INTEGER, source_msg_id INTEGER, target_msg_id INTEGER, timestamp TEXT, PRIMARY KEY (channel_id, source_msg_id))") + cursor.execute(f"CREATE TABLE message_mappings (channel_id {target_type}, source_msg_id {source_type}, target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, source_msg_id))") elif table == "thread_mappings": - cursor.execute("CREATE TABLE thread_mappings (channel_id INTEGER, thread_id INTEGER, source_msg_id INTEGER, target_msg_id INTEGER, timestamp TEXT, PRIMARY KEY (channel_id, thread_id, source_msg_id))") + cursor.execute(f"CREATE TABLE thread_mappings (channel_id {target_type}, thread_id {target_type}, source_msg_id {source_type}, target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, thread_id, source_msg_id))") elif table == "channel_tracking": - cursor.execute("CREATE TABLE channel_tracking (channel_id INTEGER PRIMARY KEY, last_msg_id INTEGER, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0)") + cursor.execute(f"CREATE TABLE channel_tracking (channel_id {target_type} PRIMARY KEY, last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0)") elif table == "thread_tracking": - cursor.execute("CREATE TABLE thread_tracking (channel_id INTEGER, thread_id INTEGER, last_msg_id INTEGER, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0, completed INTEGER DEFAULT 0, PRIMARY KEY (channel_id, 
thread_id))") + cursor.execute(f"CREATE TABLE thread_tracking (channel_id {target_type}, thread_id {target_type}, last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0, completed INTEGER DEFAULT 0, PRIMARY KEY (channel_id, thread_id))") elif table == "server_mappings": - cursor.execute("CREATE TABLE server_mappings (category TEXT, source_id INTEGER, target_id INTEGER, PRIMARY KEY (category, source_id))") + cursor.execute(f"CREATE TABLE server_mappings (category TEXT, source_id {source_type}, target_id {target_type}, PRIMARY KEY (category, source_id))") elif table == "asset_mappings": - cursor.execute("CREATE TABLE asset_mappings (category TEXT, source_id INTEGER, target_id INTEGER, PRIMARY KEY (category, source_id))") + cursor.execute(f"CREATE TABLE asset_mappings (category TEXT, source_id {source_type}, target_id {target_type}, PRIMARY KEY (category, source_id))") elif table == "user_alias": - cursor.execute("CREATE TABLE user_alias (user_id INTEGER PRIMARY KEY, alias TEXT UNIQUE)") + cursor.execute(f"CREATE TABLE user_alias (user_id {source_type} PRIMARY KEY, alias TEXT UNIQUE)") old_cols = [c[1] for c in cursor.execute(f"PRAGMA table_info({table}_old)").fetchall()] new_cols = [c[1] for c in cursor.execute(f"PRAGMA table_info({table})").fetchall()] @@ -89,33 +114,33 @@ class MigrationDatabase: # Initial Creation / Ensure Schema Correctness # Table for message mappings: SourceID -> TargetID - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS message_mappings ( - channel_id INTEGER, - source_msg_id INTEGER, - target_msg_id INTEGER, + channel_id {target_type}, + source_msg_id {source_type}, + target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, source_msg_id) ) """) # Table for thread mappings - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS thread_mappings ( - channel_id INTEGER, - thread_id INTEGER, - source_msg_id INTEGER, - target_msg_id INTEGER, + channel_id 
{target_type}, + thread_id {target_type}, + source_msg_id {source_type}, + target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, thread_id, source_msg_id) ) """) # Table for per-channel stats and tracking - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS channel_tracking ( - channel_id INTEGER PRIMARY KEY, - last_msg_id INTEGER, + channel_id {target_type} PRIMARY KEY, + last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0 @@ -123,11 +148,11 @@ class MigrationDatabase: """) # Table for per-thread stats - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS thread_tracking ( - channel_id INTEGER, - thread_id INTEGER, - last_msg_id INTEGER, + channel_id {target_type}, + thread_id {target_type}, + last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0, @@ -136,32 +161,32 @@ class MigrationDatabase: ) """) - # Add completed column if it doesn't exist (backward compatibility for existing resumption DBs) + # Add completed column if it doesn't exist (backward compatibility) try: cursor.execute("ALTER TABLE thread_tracking ADD COLUMN completed INTEGER DEFAULT 0") except sqlite3.OperationalError: pass # Already exists # Table for server entity mappings (channels, roles, categories) - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS server_mappings ( category TEXT, - source_id INTEGER, - target_id INTEGER, + source_id {source_type}, + target_id {target_type}, PRIMARY KEY (category, source_id) ) """) # Table for asset mappings (emojis, stickers) - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS asset_mappings ( category TEXT, - source_id INTEGER, - target_id INTEGER, + source_id {source_type}, + target_id {target_type}, PRIMARY KEY (category, source_id) ) """) - + # Migrate old entity_mappings if it exists cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND 
name='entity_mappings'") if cursor.fetchone(): @@ -182,18 +207,10 @@ class MigrationDatabase: # Drop old table cursor.execute("DROP TABLE entity_mappings") - # Table for general metadata - cursor.execute(""" - CREATE TABLE IF NOT EXISTS metadata ( - key TEXT PRIMARY KEY, - value TEXT - ) - """) - # Table for auto-generated user aliases (user_id -> alias) - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS user_alias ( - user_id INTEGER PRIMARY KEY, + user_id {source_type} PRIMARY KEY, alias TEXT UNIQUE ) """) @@ -209,17 +226,30 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO message_mappings (channel_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?)", - (parse_snowflake(channel_id), parse_snowflake(source_id), parse_snowflake(target_id), timestamp) + (str(channel_id), parse_snowflake(source_id), str(target_id), timestamp) ) conn.commit() - def get_target_message_id(self, channel_id: str, source_id: str) -> Optional[str]: + def get_target_message_id(self, channel_id: str, source_id: str) -> Optional[Union[str, int]]: conn = self._get_conn() row = conn.execute( "SELECT target_msg_id FROM message_mappings WHERE channel_id = ? 
AND source_msg_id = ?", - (parse_snowflake(channel_id), parse_snowflake(source_id)) + (str(channel_id), parse_snowflake(source_id)) ).fetchone() - return row["target_msg_id"] if row else None + if row: + val = row["target_msg_id"] + return str(val) if self.platform == "stoat" else val + return None + + def get_all_message_mappings(self, channel_id: str) -> Dict[Union[str, int], Union[str, int]]: + conn = self._get_conn() + rows = conn.execute( + "SELECT source_msg_id, target_msg_id FROM message_mappings WHERE channel_id = ?", + (str(channel_id),) + ).fetchall() + if self.platform == "stoat": + return {str(row["source_msg_id"]): str(row["target_msg_id"]) for row in rows} + return {row["source_msg_id"]: row["target_msg_id"] for row in rows} # --- User Alias Methods --- @@ -296,24 +326,29 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO server_mappings (category, source_id, target_id) VALUES (?, ?, ?)", - (category, parse_snowflake(source_id), parse_snowflake(target_id)) + (category, parse_snowflake(source_id), str(target_id)) ) conn.commit() - def get_server_mapping(self, category: str, source_id: str) -> Optional[str]: + def get_server_mapping(self, category: str, source_id: str) -> Optional[Union[str, int]]: conn = self._get_conn() row = conn.execute( "SELECT target_id FROM server_mappings WHERE category = ? 
AND source_id = ?", (category, parse_snowflake(source_id)) ).fetchone() - return row["target_id"] if row else None + if row: + val = row["target_id"] + return str(val) if self.platform == "stoat" else val + return None - def get_all_server_mappings(self, category: str) -> Dict[str, str]: + def get_all_server_mappings(self, category: str) -> Dict[Union[str, int], Union[str, int]]: conn = self._get_conn() rows = conn.execute( "SELECT source_id, target_id FROM server_mappings WHERE category = ?", (category,) ).fetchall() + if self.platform == "stoat": + return {str(row["source_id"]): str(row["target_id"]) for row in rows} return {row["source_id"]: row["target_id"] for row in rows} def delete_server_mapping(self, category: str, source_id: str): @@ -338,24 +373,29 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO asset_mappings (category, source_id, target_id) VALUES (?, ?, ?)", - (category, parse_snowflake(source_id), parse_snowflake(target_id)) + (category, parse_snowflake(source_id), str(target_id)) ) conn.commit() - def get_asset_mapping(self, category: str, source_id: str) -> Optional[str]: + def get_asset_mapping(self, category: str, source_id: str) -> Optional[Union[str, int]]: conn = self._get_conn() row = conn.execute( "SELECT target_id FROM asset_mappings WHERE category = ? 
AND source_id = ?", (category, parse_snowflake(source_id)) ).fetchone() - return row["target_id"] if row else None + if row: + val = row["target_id"] + return str(val) if self.platform == "stoat" else val + return None - def get_all_asset_mappings(self, category: str) -> Dict[str, str]: + def get_all_asset_mappings(self, category: str) -> Dict[Union[str, int], Union[str, int]]: conn = self._get_conn() rows = conn.execute( "SELECT source_id, target_id FROM asset_mappings WHERE category = ?", (category,) ).fetchall() + if self.platform == "stoat": + return {str(row["source_id"]): str(row["target_id"]) for row in rows} return {row["source_id"]: row["target_id"] for row in rows} def delete_asset_mapping(self, category: str, source_id: str): @@ -389,23 +429,23 @@ class MigrationDatabase: def update_channel_tracking(self, channel_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0): conn = self._get_conn() # Initialize if missing - conn.execute("INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)", (parse_snowflake(channel_id),)) + conn.execute("INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)", (str(channel_id),)) if last_msg_id: - conn.execute("UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?", (parse_snowflake(last_msg_id), parse_snowflake(channel_id))) + conn.execute("UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?", (str(last_msg_id), str(channel_id))) if last_msg_ts: - conn.execute("UPDATE channel_tracking SET last_msg_ts = ? WHERE channel_id = ?", (last_msg_ts, parse_snowflake(channel_id))) + conn.execute("UPDATE channel_tracking SET last_msg_ts = ? WHERE channel_id = ?", (last_msg_ts, str(channel_id))) if msg_inc != 0 or file_inc != 0: conn.execute( "UPDATE channel_tracking SET msg_count = msg_count + ?, file_count = file_count + ? 
WHERE channel_id = ?", - (msg_inc, file_inc, parse_snowflake(channel_id)) + (msg_inc, file_inc, str(channel_id)) ) conn.commit() def get_channel_tracking(self, channel_id: str) -> Dict[str, Any]: conn = self._get_conn() - row = conn.execute("SELECT * FROM channel_tracking WHERE channel_id = ?", (parse_snowflake(channel_id),)).fetchone() + row = conn.execute("SELECT * FROM channel_tracking WHERE channel_id = ?", (str(channel_id),)).fetchone() if row: return dict(row) return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} @@ -415,39 +455,42 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO thread_mappings (channel_id, thread_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?, ?)", - (parse_snowflake(channel_id), parse_snowflake(thread_id), parse_snowflake(source_id), parse_snowflake(target_id), timestamp) + (str(channel_id), str(thread_id), parse_snowflake(source_id), str(target_id), timestamp) ) conn.commit() - def get_target_thread_message_id(self, channel_id: str, thread_id: str, source_id: str) -> Optional[str]: + def get_target_thread_message_id(self, channel_id: str, thread_id: str, source_id: str) -> Optional[Union[str, int]]: conn = self._get_conn() row = conn.execute( "SELECT target_msg_id FROM thread_mappings WHERE channel_id = ? AND thread_id = ? 
AND source_msg_id = ?", - (parse_snowflake(channel_id), parse_snowflake(thread_id), parse_snowflake(source_id)) + (str(channel_id), str(thread_id), parse_snowflake(source_id)) ).fetchone() - return row["target_msg_id"] if row else None + if row: + val = row["target_msg_id"] + return str(val) if self.platform == "stoat" else val + return None def update_thread_tracking(self, channel_id: str, thread_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0, completed: int = None): conn = self._get_conn() - conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (parse_snowflake(channel_id), parse_snowflake(thread_id))) + conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (str(channel_id), str(thread_id))) if last_msg_id: - conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (parse_snowflake(last_msg_id), parse_snowflake(channel_id), parse_snowflake(thread_id))) + conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (str(last_msg_id), str(channel_id), str(thread_id))) if last_msg_ts: - conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, parse_snowflake(channel_id), parse_snowflake(thread_id))) + conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, str(channel_id), str(thread_id))) if completed is not None: - conn.execute("UPDATE thread_tracking SET completed = ? WHERE channel_id = ? AND thread_id = ?", (completed, parse_snowflake(channel_id), parse_snowflake(thread_id))) + conn.execute("UPDATE thread_tracking SET completed = ? WHERE channel_id = ? AND thread_id = ?", (completed, str(channel_id), str(thread_id))) if msg_inc != 0 or file_inc != 0: conn.execute( "UPDATE thread_tracking SET msg_count = msg_count + ?, file_count = file_count + ? 
WHERE channel_id = ? AND thread_id = ?", - (msg_inc, file_inc, parse_snowflake(channel_id), parse_snowflake(thread_id)) + (msg_inc, file_inc, str(channel_id), str(thread_id)) ) conn.commit() def get_thread_tracking(self, channel_id: str, thread_id: str) -> Dict[str, Any]: conn = self._get_conn() - row = conn.execute("SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?", (parse_snowflake(channel_id), parse_snowflake(thread_id))).fetchone() + row = conn.execute("SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?", (str(channel_id), str(thread_id))).fetchone() if row: return dict(row) return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} @@ -455,10 +498,10 @@ class MigrationDatabase: def clear_channel_data(self, channel_id: str): """Purge all mappings and tracking data for a specific channel and its threads.""" conn = self._get_conn() - conn.execute("DELETE FROM message_mappings WHERE channel_id = ?", (parse_snowflake(channel_id),)) - conn.execute("DELETE FROM thread_mappings WHERE channel_id = ?", (parse_snowflake(channel_id),)) - conn.execute("DELETE FROM channel_tracking WHERE channel_id = ?", (parse_snowflake(channel_id),)) - conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (parse_snowflake(channel_id),)) + conn.execute("DELETE FROM message_mappings WHERE channel_id = ?", (str(channel_id),)) + conn.execute("DELETE FROM thread_mappings WHERE channel_id = ?", (str(channel_id),)) + conn.execute("DELETE FROM channel_tracking WHERE channel_id = ?", (str(channel_id),)) + conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (str(channel_id),)) conn.commit() logger.info(f"Cleared all tracking and mapping data for channel: {channel_id}") diff --git a/src/core/state.py b/src/core/state.py index 0e8d181..dcfd446 100644 --- a/src/core/state.py +++ b/src/core/state.py @@ -1,7 +1,10 @@ import json import logging from pathlib import Path -from typing import Dict, Any, Optional +from 
typing import Dict, Optional, Any, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from src.core.database import MigrationDatabase logger = logging.getLogger(__name__) @@ -119,23 +122,23 @@ class MigrationState: # --- Properties for backward compatibility --- @property - def channel_map(self) -> Dict[str, str]: + def channel_map(self) -> Dict[Union[str, int], Union[str, int]]: return self.db.get_all_server_mappings("channel") if self.db else {} @property - def category_map(self) -> Dict[str, str]: + def category_map(self) -> Dict[Union[str, int], Union[str, int]]: return self.db.get_all_server_mappings("category") if self.db else {} @property - def role_map(self) -> Dict[str, str]: + def role_map(self) -> Dict[Union[str, int], Union[str, int]]: return self.db.get_all_server_mappings("role") if self.db else {} @property - def emoji_map(self) -> Dict[str, str]: + def emoji_map(self) -> Dict[Union[str, int], Union[str, int]]: return self.db.get_all_asset_mappings("emoji") if self.db else {} @property - def sticker_map(self) -> Dict[str, str]: + def sticker_map(self) -> Dict[Union[str, int], Union[str, int]]: return self.db.get_all_asset_mappings("sticker") if self.db else {} @property @@ -153,7 +156,7 @@ class MigrationState: if self._ensure_db(): self.db.set_message_mapping(str(target_channel_id), str(discord_id), str(target_id)) - def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | None: + def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | int | None: if self._ensure_db(): return self.db.get_target_message_id(str(target_channel_id), str(discord_id)) return None @@ -161,7 +164,7 @@ class MigrationState: def set_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str): self.set_target_message_mapping(target_channel_id, discord_id, target_id) - def get_fluxer_message_id(self, target_channel_id: str, discord_id: str) -> str | None: + def get_fluxer_message_id(self, target_channel_id: str, 
discord_id: str) -> str | int | None: return self.get_target_message_id(target_channel_id, discord_id) def increment_stats(self, target_channel_id: str, messages: int = 1, files: int = 0): @@ -258,7 +261,7 @@ class MigrationState: if self._ensure_db(): self.db.clear_channel_data(str(target_channel_id)) - def set_folder(self, server_id: str, clean_name: str, base_dir: Path | str = ""): + def set_folder(self, server_id: str, clean_name: str, platform: str = "stoat", base_dir: Path | str = ""): """ Initializes the SQLite database based on community name and ID. Filename: {name}-{id}.db (Flat structure) @@ -294,7 +297,7 @@ class MigrationState: from src.core.database import MigrationDatabase if self.db: self.db.close() - self.db = MigrationDatabase(db_path) + self.db = MigrationDatabase(db_path, platform=platform) logger.info(f"Initialized SQLite database at {db_path}") def get_user_alias(self, user_id: str) -> str | None: diff --git a/src/core/utils.py b/src/core/utils.py index 8402395..1167d47 100644 --- a/src/core/utils.py +++ b/src/core/utils.py @@ -14,15 +14,16 @@ def parse_snowflake(value: Any) -> Optional[int]: except ValueError: return None -from src.core.state import MigrationState - logger = logging.getLogger(__name__) -def resolve_discord_links(content: str, state: MigrationState, platform: str, target_server_id: str) -> str: +def resolve_discord_links(content: str, state, platform: str, target_server_id: str) -> str: """ Finds Discord message/channel links and resolves them to the target platform if they have been migrated. 
""" + from src.core.state import MigrationState + if not isinstance(state, MigrationState): + logger.warning(f"resolve_discord_links: state is not MigrationState (type: {type(state)})") if not content: return content diff --git a/src/fluxer/emoji_stickers.py b/src/fluxer/emoji_stickers.py index f8de0ec..7a7e8fe 100644 --- a/src/fluxer/emoji_stickers.py +++ b/src/fluxer/emoji_stickers.py @@ -18,10 +18,10 @@ async def sync_assets_state(context: MigrationContext): fluxer_stickers = await context.fluxer_writer.client.get_guild_stickers(context.config.fluxer_server_id) # Build name -> id maps and ID sets for Fluxer for fast lookup - fluxer_emoji_map = {e.get("name"): str(e.get("id")) for e in fluxer_emojis if e.get("name")} - fluxer_sticker_map = {s.get("name"): str(s.get("id")) for s in fluxer_stickers if s.get("name")} - fluxer_emoji_ids = {str(e.get("id")) for e in fluxer_emojis} - fluxer_sticker_ids = {str(s.get("id")) for s in fluxer_stickers} + fluxer_emoji_map = {e.get("name"): e.get("id") for e in fluxer_emojis if e.get("name")} + fluxer_sticker_map = {s.get("name"): s.get("id") for s in fluxer_stickers if s.get("name")} + fluxer_emoji_ids = {e.get("id") for e in fluxer_emojis} + fluxer_sticker_ids = {s.get("id") for s in fluxer_stickers} updates = 0 removals = 0 diff --git a/src/fluxer/migrate_message.py b/src/fluxer/migrate_message.py index dc644cb..12e656c 100644 --- a/src/fluxer/migrate_message.py +++ b/src/fluxer/migrate_message.py @@ -73,8 +73,8 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None, cid = int(match.group(1)) # 1. Check if channel is mapped in state - if channel_map and str(cid) in channel_map: - return f"<#{channel_map[str(cid)]}>" + if channel_map and cid in channel_map: + return f"<#{channel_map[cid]}>" # 2. 
Try to resolve channel name from pre-fetched names name = None @@ -100,7 +100,7 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None, def replace_emoji(match): animated = match.group(1) == "a" name = match.group(2) - eid = match.group(3) + eid = int(match.group(3)) if emoji_map and eid in emoji_map: target_eid = emoji_map[eid] diff --git a/src/fluxer/roles_permissions.py b/src/fluxer/roles_permissions.py index 83ff25b..5973c27 100644 --- a/src/fluxer/roles_permissions.py +++ b/src/fluxer/roles_permissions.py @@ -15,8 +15,8 @@ async def sync_roles_state(context: MigrationContext): fluxer_roles = await context.fluxer_writer.client.get_guild_roles(context.config.fluxer_server_id) # Build name -> id maps and ID sets for Fluxer for fast lookup - fluxer_role_map = {r.get("name"): str(r.get("id")) for r in fluxer_roles if r.get("name")} - fluxer_role_ids = {str(r.get("id")) for r in fluxer_roles} + fluxer_role_map = {r.get("name"): r.get("id") for r in fluxer_roles if r.get("name")} + fluxer_role_ids = {r.get("id") for r in fluxer_roles} updates = 0 removals = 0 diff --git a/src/stoat/emoji_stickers.py b/src/stoat/emoji_stickers.py index bfa1b60..5b7a0ed 100644 --- a/src/stoat/emoji_stickers.py +++ b/src/stoat/emoji_stickers.py @@ -37,8 +37,7 @@ async def sync_assets_state(context: MigrationContext): if stoat_id: if stoat_id not in stoat_emoji_ids: - context.state.emoji_map.pop(discord_id, None) - context.state.save_state() + context.state.remove_emoji_mapping(discord_id) removals += 1 elif emoji.name in stoat_emoji_map: context.state.set_target_emoji_mapping(discord_id, stoat_emoji_map[emoji.name]) From 011c0ca4e09e134f57a8a9424ed9f98895ba994b Mon Sep 17 00:00:00 2001 From: rambros Date: Sat, 28 Mar 2026 11:24:18 +0530 Subject: [PATCH 5/8] fix windows PyiFrozenLoader error --- src/core/backup_reader.py | 3 ++- src/core/database.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/core/backup_reader.py 
b/src/core/backup_reader.py index 408e16e..3ec7c92 100644 --- a/src/core/backup_reader.py +++ b/src/core/backup_reader.py @@ -6,12 +6,13 @@ Discord API. Implements the same public interface as DiscordReader so that migration scripts and UI code can use either provider transparently. """ +from __future__ import annotations import json import logging from datetime import datetime, timezone from enum import IntEnum from pathlib import Path -from typing import AsyncGenerator, Dict, Any, List, Optional +from typing import AsyncGenerator, Dict, Any, List, Optional, Union from src.core.backup_database import BackupDatabase, parse_snowflake logger = logging.getLogger(__name__) diff --git a/src/core/database.py b/src/core/database.py index 996e083..bfaf490 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -1,9 +1,10 @@ +from __future__ import annotations import sqlite3 import logging import json import random from pathlib import Path -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Union import threading import sys from src.core.utils import parse_snowflake From 220f97aad45a9912f544acd79f3af19d6ed8ac37 Mon Sep 17 00:00:00 2001 From: rambros Date: Sat, 28 Mar 2026 16:58:38 +0530 Subject: [PATCH 6/8] add waterfall mode --- README.md | 4 +- src/core/backup_database.py | 77 ++++++++++ src/core/backup_reader.py | 36 +++++ src/core/database.py | 29 ++++ src/core/state.py | 6 + src/fluxer/migrate_message.py | 262 ++++++++++++++++++++++++++++++++-- src/stoat/migrate_message.py | 245 +++++++++++++++++++++++++++++-- src/ui/shuttle_ops.py | 205 +++++++++++++++++++++++++- 8 files changed, 828 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index aac1b6e..486c5dc 100644 --- a/README.md +++ b/README.md @@ -9,8 +9,8 @@ ![Disco Reaper](images/fluxer-reaper.jpg) -### Modern Terminal Interface -The tool now features a unified, intuitive TUI (Terminal User Interface) - no more text commands +### Video Guide - 
[Youtube](https://www.youtube.com/watch?v=SwIPQDxLzqA) + | Features | Fluxer | Stoat | | :--- | :---: | :---: | diff --git a/src/core/backup_database.py b/src/core/backup_database.py index 88d0df1..40e0011 100644 --- a/src/core/backup_database.py +++ b/src/core/backup_database.py @@ -791,6 +791,83 @@ class BackupDatabase: return msg_list + def get_global_messages_paged(self, limit: int = 100, offset: int = 0, after_id: Optional[str] = None) -> List[Dict[str, Any]]: + """Fetches messages across ALL channels globally, ordered by timestamp/ID ascending.""" + with self._lock: + query = "SELECT * FROM messages" + params = [] + + if after_id: + query += " WHERE id > ?" + params.append(parse_snowflake(after_id)) + + query += " ORDER BY id ASC LIMIT ? OFFSET ?" + params.extend([limit, offset]) + + rows = self._conn.execute(query, params).fetchall() + msg_list = [dict(r) for r in rows] + + if msg_list: + msg_ids = [m["id"] for m in msg_list] + placeholders = ",".join(["?"] * len(msg_ids)) + + att_rows = self._conn.execute(f"SELECT * FROM attachments WHERE message_id IN ({placeholders})", msg_ids).fetchall() + atts_by_msg = {} + for ar in att_rows: + mid = ar["message_id"] + if mid not in atts_by_msg: atts_by_msg[mid] = [] + atts_by_msg[mid].append(dict(ar)) + + emb_rows = self._conn.execute(f"SELECT * FROM embeds WHERE message_id IN ({placeholders})", msg_ids).fetchall() + embs_by_msg = {} + for er in emb_rows: + mid = er["message_id"] + if mid not in embs_by_msg: embs_by_msg[mid] = [] + + e_dict = { + "title": er["title"], + "description": er["description"], + "url": er["url"], + "color": er["color"], + "timestamp": er["timestamp"], + "thumbnail": {"url": er["thumbnail_url"]} if er["thumbnail_url"] else None, + "image": {"url": er["image_url"]} if er["image_url"] else None, + "author": { + "name": er["author_name"], + "url": er["author_url"], + "icon_url": er["author_icon_url"] + } if er["author_name"] else None, + "footer": { + "text": er["footer_text"], + "icon_url": 
er["footer_icon_url"] + } if er["footer_text"] else None, + "fields": json.loads(er["fields"]) if er["fields"] else [] + } + embs_by_msg[mid].append(e_dict) + + rea_rows = self._conn.execute(f"SELECT * FROM reactions WHERE message_id IN ({placeholders})", msg_ids).fetchall() + reas_by_msg = {} + for rr in rea_rows: + mid = rr["message_id"] + if mid not in reas_by_msg: reas_by_msg[mid] = [] + reas_by_msg[mid].append(dict(rr)) + + st_rows = self._conn.execute(f"SELECT * FROM message_stickers WHERE message_id IN ({placeholders})", msg_ids).fetchall() + sts_by_msg = {} + for sr in st_rows: + mid = sr["message_id"] + if mid not in sts_by_msg: sts_by_msg[mid] = [] + sts_by_msg[mid].append(dict(sr)) + + for m in msg_list: + m_id = m["id"] + m["attachments"] = atts_by_msg.get(m_id, []) + m["embeds"] = embs_by_msg.get(m_id, []) + m["reactions"] = reas_by_msg.get(m_id, []) + m["stickers"] = sts_by_msg.get(m_id, []) + + return msg_list + def delete_channel_messages(self, channel_id: Union[str, int]): """Deletes all messages and related metadata for a specific channel and its threads.""" cid = parse_snowflake(channel_id) diff --git a/src/core/backup_reader.py b/src/core/backup_reader.py index 3ec7c92..c9d250d 100644 --- a/src/core/backup_reader.py +++ b/src/core/backup_reader.py @@ -1417,6 +1417,42 @@ class BackupReader: if len(msgs) < batch_size: break + async def fetch_global_message_history( + self, + limit: int = None, + after_id: int = None + ) -> AsyncGenerator["BackupMessage", None]: + """Yields BackupMessages globally from SQLite across all channels, natively ordered by timestamp/ID.""" + if not self.db: return + + offset = 0 + batch_size = 100 + count = 0 + + while True: + actual_limit = batch_size + if limit: + rem = limit - count + if rem <= 0: break + actual_limit = min(batch_size, rem) + + msgs = self.db.get_global_messages_paged( + limit=actual_limit, + offset=offset, + after_id=str(after_id) if after_id else None + ) + + if not msgs: + break + + for m in msgs: + 
yield self._hydrate_message(m) + count += 1 + + offset += len(msgs) + if len(msgs) < batch_size: + break + # ── download helpers ───────────────────────────────────────────────── async def download_emoji(self, emoji: BackupEmoji) -> bytes: diff --git a/src/core/database.py b/src/core/database.py index bfaf490..419d9aa 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -451,6 +451,35 @@ class MigrationDatabase: return dict(row) return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} + def get_global_min_last_message_id(self, mapped_channel_ids: List[str]) -> Optional[str]: + """Returns the minimum last_msg_id across all mapped channels. If any mapped channel has NO last_msg_id, returns None.""" + if not mapped_channel_ids: + return None + + conn = self._get_conn() + placeholders = ",".join(["?"] * len(mapped_channel_ids)) + rows = conn.execute(f"SELECT last_msg_id FROM channel_tracking WHERE channel_id IN ({placeholders})", mapped_channel_ids).fetchall() + + # If the number of tracked channels is less than mapped, it means some mapped channels haven't started. 
+ if len(rows) < len(mapped_channel_ids): + return None + + # Parse all ids + ids = [] + for r in rows: + val = r["last_msg_id"] + if not val: + return None # One channel has no messages yet + try: + ids.append(int(val)) + except ValueError: + pass + + if not ids: + return None + + return str(min(ids)) + # Thread methods similar to channel methods def set_thread_message_mapping(self, channel_id: str, thread_id: str, source_id: str, target_id: str, timestamp: str = None): conn = self._get_conn() diff --git a/src/core/state.py b/src/core/state.py index dcfd446..ed9c1c6 100644 --- a/src/core/state.py +++ b/src/core/state.py @@ -215,6 +215,12 @@ class MigrationState: return self.db.get_channel_tracking(str(target_channel_id)).get("last_msg_id") return None + def get_global_min_last_message_id(self, mapped_channel_ids: List[str]) -> str | None: + """Returns the absolute minimum last_msg_id among the given list of mapped target channel IDs.""" + if self._ensure_db(): + return self.db.get_global_min_last_message_id(mapped_channel_ids) + return None + def get_thread_last_message_id(self, target_channel_id: str, thread_id: str) -> str | None: if self._ensure_db(): return self.db.get_thread_tracking(str(target_channel_id), str(thread_id)).get("last_msg_id") diff --git a/src/fluxer/migrate_message.py b/src/fluxer/migrate_message.py index 12e656c..c63693e 100644 --- a/src/fluxer/migrate_message.py +++ b/src/fluxer/migrate_message.py @@ -4,6 +4,7 @@ import re import json import io from typing import Callable, Awaitable, Dict, Any, List +from pathlib import Path try: from lottie.objects import Animation @@ -27,22 +28,25 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None, uid = int(match.group(1)) if anonymize_users and state: alias = state.get_user_alias(str(uid)) - return f"`@{alias}`" + return f"`@{alias}`" if alias else "`@Anonymized User`" + # 1. 
Try provided guild member = guild.get_member(uid) if member: return f"`@{member.display_name}`" - # 2. Try message's user_mentions + + # 2. Try provided user_mentions if user_mentions: - for u in user_mentions: - if u.id == uid: - return f"`@{getattr(u, 'display_name', u.name)}`" + m = next((u for u in user_mentions if u.id == uid), None) + if m: + return f"`@{m.display_name}`" + # 3. Try global cache via guild.client if hasattr(guild, 'client'): user = guild.client.get_user(uid) if user: return f"`@{user.name}`" - return f"`@Unknown User`" + return "`@Unknown User`" def replace_role(match): rid = int(match.group(1)) @@ -557,23 +561,22 @@ async def migrate_messages( if thread_name and stats["messages"] == 0: content = f"> <<< THREAD: **{thread_name}** >>>\n{content}" - # Get or generate alias + # Always ensure alias is created/retrieved to populate user_alias table alias = context.state.get_user_alias(str(msg.author.id)) - - if context.config.anonymize_users: - author_name = alias - avatar_url = f"https://api.dicebear.com/9.x/fun-emoji/jpg?seed={alias}" + + anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False + if anonymize_users: + author_name = alias or "Anonymized User" + author_avatar_url = None else: author_name = msg.author.display_name - avatar_url = str(msg.author.display_avatar.url) if msg.author.display_avatar.url else None - if avatar_url and not avatar_url.startswith("http"): - avatar_url = None + author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None logger.debug(f"Fluxer: Calling send_message for Discord ID {msg.id}") fluxer_msg_id = await context.fluxer_writer.send_message( channel_id=target_channel_id, author_name=author_name, - author_avatar_url=avatar_url, + author_avatar_url=author_avatar_url, content=content, timestamp=int(msg.created_at.timestamp()), files=files if files else None, @@ -663,3 +666,232 @@ async def migrate_messages( pass return stats + + 
+async def analyze_global_migration(context: MigrationContext, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None) -> Dict[str, int]: + """ + Scans the entire server history to count messages, threads, and attachments globally. + """ + stats = {"messages": 0, "threads": 0, "attachments": 0} + + # In global mode, thread messages are returned natively in timestamp order by global fetch if they're in the DB + # However we just count them if the fetcher yields them. + async for msg in context.discord_reader.fetch_global_message_history(after_id=after_message_id): + if not context.is_running: + break + + if msg.type not in [ + context.discord_reader.MESSAGE_TYPE_DEFAULT, + context.discord_reader.MESSAGE_TYPE_REPLY, + context.discord_reader.MESSAGE_TYPE_THREAD_STARTER, + context.discord_reader.MESSAGE_TYPE_FORWARD, + context.discord_reader.MESSAGE_TYPE_CHAT_INPUT_COMMAND, + context.discord_reader.MESSAGE_TYPE_CONTEXT_MENU_COMMAND, + context.discord_reader.MESSAGE_TYPE_POLL_RESULT, + context.discord_reader.MESSAGE_TYPE_AUTO_MODERATION_ACTION + ]: + continue + + stats["messages"] += 1 + stats["attachments"] += len(msg.attachments) + if hasattr(msg, 'thread') and msg.thread: + # We don't recursively traverse here, we just count the fact there is a thread + # The actual thread messages are also fetched by the global fetcher because they have their own timestamp/id + stats["threads"] += 1 + + if progress_callback and stats["messages"] % 100 == 0: + await progress_callback(stats) + + if progress_callback: + await progress_callback(stats) + + return stats + + +async def migrate_global_messages( + context: MigrationContext, + after_message_id: int | None = None, + inclusive: bool = False, + progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None +) -> Dict[str, Any]: + """ + Migrates messages across all channels chronologically. 
+ """ + stats = { + "messages": 0, + "threads": 0, + "attachments": 0, + "last_message_content": "", + "last_message_author": "", + "first_message_url": None, + "last_message_url": None + } + + processed_threads = set() + logger.info("Starting Global Waterfall Migration for Fluxer...") + + # Keep track of active thread mapping natively to pass parent target IDs if needed + thread_to_target_channel = {} + + # Emojis and mapped users cache setup + emoji_map = context.state.emoji_map + db_media = context.discord_reader.db.get_all_media() if context.discord_reader.db else {} + target_server_id = getattr(context.fluxer_writer, "server_id", None) + + try: + async for msg in context.discord_reader.fetch_global_message_history(after_id=after_message_id): + if not context.is_running: + logger.warning("Global migration interrupted by user") + break + + if msg.type not in [ + context.discord_reader.MESSAGE_TYPE_DEFAULT, + context.discord_reader.MESSAGE_TYPE_REPLY, + context.discord_reader.MESSAGE_TYPE_THREAD_STARTER, + context.discord_reader.MESSAGE_TYPE_FORWARD, + context.discord_reader.MESSAGE_TYPE_CHAT_INPUT_COMMAND, + context.discord_reader.MESSAGE_TYPE_CONTEXT_MENU_COMMAND, + context.discord_reader.MESSAGE_TYPE_POLL_RESULT, + context.discord_reader.MESSAGE_TYPE_AUTO_MODERATION_ACTION + ]: + continue + + # Determine target channel + target_channel_id = context.state.get_target_channel_id(str(msg.channel.id)) + if not target_channel_id: + logger.debug(f"Skipping msg {msg.id}: channel {msg.channel.id} not mapped.") + continue + + # If it's a thread message, we need to handle it based on if it's the thread starter or a reply + parent_target_id = None + if hasattr(msg, 'thread') and msg.thread and msg.id == msg.thread.id: + processed_threads.add(msg.thread.id) + stats["threads"] += 1 + elif msg.channel.type in [11, 12]: # Thread channels + # It's a message IN a thread. 
+ # In Fluxer, threads might just be linear messages or threaded replies depending on schema + # For basic migration we just send it to the parent mapped target channel. + # The parent mapped target channel ID should already be calculated correctly by get_target_channel_id (which returns mapped thread or parent channel) + pass + + # Formatting + files = [] + file_names = [] + + # Always ensure alias is created/retrieved to populate user_alias table + alias = context.state.get_user_alias(str(msg.author.id)) + + anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False + + if anonymize_users: + author_name = alias or "Anonymized User" + author_avatar_url = None + else: + author_name = msg.author.display_name + author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None + + for att in msg.attachments: + media_info = db_media.get(att.local_hash) if db_media else None + local_path = None + if media_info: + local_path = Path(media_info["local_path"]) + elif hasattr(att, 'read'): + # Fallback + pass + + if local_path and local_path.exists(): + files.append(local_path) + file_names.append(att.filename) + + content = msg.content or "" + + # Stickers + for sticker in msg.stickers: + sticker_name = sticker.name + sticker_url = sticker.url + + # Check for uploaded media pool logic first + s_hash = sticker.local_hash + sticker_file = None + s_media = db_media.get(s_hash) if db_media and s_hash else None + if s_media: + s_path = Path(s_media["local_path"]) + if s_path.exists(): + sticker_file = s_path + + content += f"\n[Sticker: {sticker_name}]" + if sticker_file: + files.append(sticker_file) + file_names.append(f"sticker_{sticker_name}.png") + + content = clean_mentions( + content=content, + guild=context.discord_reader.guild, + user_mentions=msg.mentions, + role_mentions=msg.role_mentions, + channel_mentions=msg.channel_mentions, + emoji_map=emoji_map, + channel_map=context.state.channel_map, + 
state=context.state, + target_server_id=target_server_id, + channel_names=context.channel_names if hasattr(context, 'channel_names') else None, + anonymize_users=anonymize_users + ) + + if not content and not files: + logger.debug(f"Message {msg.id} empty after processing, skipping.") + continue + + timestamp_int = int(msg.created_at.timestamp()) + + if msg.reference and msg.reference.message_id: + # Resolve the author of the message being replied to + source_ref_msg = await context.discord_reader.get_message(msg.channel.id, msg.reference.message_id) + if source_ref_msg and source_ref_msg.author: + ref_author_id = str(source_ref_msg.author.id) + if anonymize_users: + ref_name = context.state.get_user_alias(ref_author_id) or "Anonymized User" + else: + ref_name = source_ref_msg.author.display_name + content = f"`@{ref_name}`\n{content}" + else: + # Fallback if author cannot be resolved (e.g. deleted/missing from backup) + tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id) + if tgt_reply: + content = f"[Reply to {tgt_reply}]\n{content}" + + try: + fluxer_msg_id = await context.fluxer_writer.send_message( + channel_id=target_channel_id, + author_name=author_name, + author_avatar_url=author_avatar_url, + content=content, + files=files, + timestamp=timestamp_int, + embeds=msg.embeds + ) + + if fluxer_msg_id: + context.state.set_target_message_mapping(target_channel_id, msg.id, fluxer_msg_id) + context.state.update_last_message_id(target_channel_id, msg.id) + stats["attachments"] += len(files) if files else 0 + + stats["messages"] += 1 + stats["last_message_content"] = content + stats["last_message_author"] = author_name + + if not stats["first_message_url"]: + stats["first_message_url"] = msg.jump_url + stats["last_message_url"] = msg.jump_url + + if progress_callback: + await progress_callback(stats) + + except Exception as e: + logger.error(f"Failed to process global message {msg.id}: {e}") + + except (KeyboardInterrupt, 
asyncio.CancelledError): + context.is_running = False + pass + + return stats diff --git a/src/stoat/migrate_message.py b/src/stoat/migrate_message.py index c0861bd..45d0fee 100644 --- a/src/stoat/migrate_message.py +++ b/src/stoat/migrate_message.py @@ -4,6 +4,7 @@ import re import json import io from typing import Callable, Awaitable, Dict, Any, List +from pathlib import Path try: from lottie.objects import Animation @@ -27,22 +28,25 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None, uid = int(match.group(1)) if anonymize_users and state: alias = state.get_user_alias(str(uid)) - return f"`@{alias}`" + return f"`@{alias}`" if alias else "`@Anonymized User`" + # 1. Try provided guild member = guild.get_member(uid) if member: return f"`@{member.display_name}`" - # 2. Try message's user_mentions + + # 2. Try provided user_mentions if user_mentions: - for u in user_mentions: - if u.id == uid: - return f"`@{getattr(u, 'display_name', u.name)}`" + m = next((u for u in user_mentions if u.id == uid), None) + if m: + return f"`@{m.display_name}`" + # 3. 
Try global cache via guild.client if hasattr(guild, 'client'): user = guild.client.get_user(uid) if user: return f"`@{user.name}`" - return f"`@Unknown User`" + return "`@Unknown User`" def replace_role(match): rid = int(match.group(1)) @@ -560,22 +564,24 @@ async def migrate_messages( if thread_name and stats["messages"] == 0: content = f"> <<< THREAD: **{thread_name}** >>>\n{content}" - # Get or generate alias + # Always ensure alias is created/retrieved to populate user_alias table alias = context.state.get_user_alias(str(msg.author.id)) - - if context.config.anonymize_users: - author_name = alias - avatar_url = f"https://api.dicebear.com/9.x/fun-emoji/jpg?seed={alias}" + + anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False + if anonymize_users: + author_name = alias or "Anonymized User" + author_avatar_url = None else: author_name = msg.author.display_name - avatar_url = str(msg.author.display_avatar.url) if msg.author.display_avatar.url else None - if avatar_url and not avatar_url.startswith("http"): - avatar_url = None + author_avatar_url = str(msg.author.display_avatar.url) if msg.author.display_avatar.url else None + if author_avatar_url and not author_avatar_url.startswith("http"): + author_avatar_url = None + logger.debug(f"Stoat: Calling send_message for Discord ID {msg.id}") stoat_msg_id = await context.stoat_writer.send_message( channel_id=target_channel_id, author_name=author_name, - author_avatar_url=avatar_url, + author_avatar_url=author_avatar_url, content=content, timestamp=int(msg.created_at.timestamp()), files=files if files else None, @@ -664,3 +670,212 @@ async def migrate_messages( pass return stats + + +async def analyze_global_migration(context: MigrationContext, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None) -> Dict[str, int]: + """ + Scans the entire server history to count messages, threads, and attachments 
globally. + """ + stats = {"messages": 0, "threads": 0, "attachments": 0} + + async for msg in context.discord_reader.fetch_global_message_history(after_id=after_message_id): + if not context.is_running: + break + + if msg.type not in [ + context.discord_reader.MESSAGE_TYPE_DEFAULT, + context.discord_reader.MESSAGE_TYPE_REPLY, + context.discord_reader.MESSAGE_TYPE_THREAD_STARTER, + context.discord_reader.MESSAGE_TYPE_FORWARD, + context.discord_reader.MESSAGE_TYPE_CHAT_INPUT_COMMAND, + context.discord_reader.MESSAGE_TYPE_CONTEXT_MENU_COMMAND, + context.discord_reader.MESSAGE_TYPE_POLL_RESULT, + context.discord_reader.MESSAGE_TYPE_AUTO_MODERATION_ACTION + ]: + continue + + stats["messages"] += 1 + stats["attachments"] += len(msg.attachments) + if hasattr(msg, 'thread') and msg.thread: + stats["threads"] += 1 + + if progress_callback and stats["messages"] % 100 == 0: + await progress_callback(stats) + + if progress_callback: + await progress_callback(stats) + + return stats + + +async def migrate_global_messages( + context: MigrationContext, + after_message_id: int | None = None, + inclusive: bool = False, + progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None +) -> Dict[str, Any]: + """ + Migrates messages across all channels chronologically to Stoat. 
+ """ + stats = { + "messages": 0, + "threads": 0, + "attachments": 0, + "last_message_content": "", + "last_message_author": "", + "first_message_url": None, + "last_message_url": None + } + + processed_threads = set() + logger.info("Starting Global Waterfall Migration for Stoat...") + + emoji_map = context.state.emoji_map + db_media = context.discord_reader.db.get_all_media() if context.discord_reader.db else {} + target_server_id = getattr(context.stoat_writer, "community_id", None) + + try: + async for msg in context.discord_reader.fetch_global_message_history(after_id=after_message_id): + if not context.is_running: + logger.warning("Global migration interrupted by user") + break + + if msg.type not in [ + context.discord_reader.MESSAGE_TYPE_DEFAULT, + context.discord_reader.MESSAGE_TYPE_REPLY, + context.discord_reader.MESSAGE_TYPE_THREAD_STARTER, + context.discord_reader.MESSAGE_TYPE_FORWARD, + context.discord_reader.MESSAGE_TYPE_CHAT_INPUT_COMMAND, + context.discord_reader.MESSAGE_TYPE_CONTEXT_MENU_COMMAND, + context.discord_reader.MESSAGE_TYPE_POLL_RESULT, + context.discord_reader.MESSAGE_TYPE_AUTO_MODERATION_ACTION + ]: + continue + + target_channel_id = context.state.get_target_channel_id(str(msg.channel.id)) + if not target_channel_id: + logger.debug(f"Skipping msg {msg.id}: channel {msg.channel.id} not mapped.") + continue + + parent_target_id = None + if hasattr(msg, 'thread') and msg.thread and msg.id == msg.thread.id: + processed_threads.add(msg.thread.id) + stats["threads"] += 1 + elif msg.channel.type in [11, 12]: + pass + + # Formatting + files = [] + + # Always ensure alias is created/retrieved to populate user_alias table + alias = context.state.get_user_alias(str(msg.author.id)) + + anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False + + if anonymize_users: + author_name = alias or "Anonymized User" + author_avatar_url = None + else: + author_name = msg.author.display_name + author_avatar_url = 
msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None + + for att in msg.attachments: + media_info = db_media.get(att.local_hash) if db_media else None + local_path = None + if media_info: + local_path = Path(media_info["local_path"]) + + if local_path and local_path.exists(): + try: + with open(local_path, "rb") as f: + files.append({"filename": att.filename, "data": f.read()}) + except Exception as fe: + logger.error(f"Failed to read file {local_path}: {fe}") + + content = msg.content or "" + + for sticker in msg.stickers: + sticker_name = sticker.name + s_hash = sticker.local_hash + sticker_file = None + s_media = db_media.get(s_hash) if db_media and s_hash else None + if s_media: + s_path = Path(s_media["local_path"]) + if s_path.exists(): + sticker_file = s_path + + content += f"\n[Sticker: {sticker_name}]" + if sticker_file: + files.append(sticker_file) + file_names.append(f"sticker_{sticker_name}.png") + + content = clean_mentions( + content=content, + guild=context.discord_reader.guild, + user_mentions=msg.mentions, + role_mentions=msg.role_mentions, + channel_mentions=msg.channel_mentions, + emoji_map=emoji_map, + channel_map=context.state.channel_map, + state=context.state, + target_server_id=target_server_id, + channel_names=context.channel_names if hasattr(context, 'channel_names') else None, + anonymize_users=anonymize_users + ) + + if not content and not files: + logger.debug(f"Message {msg.id} empty after processing, skipping.") + continue + + timestamp_int = int(msg.created_at.timestamp()) + + if msg.reference and msg.reference.message_id: + # Resolve the author of the message being replied to + source_ref_msg = await context.discord_reader.get_message(msg.channel_id, msg.reference.message_id) + if source_ref_msg and source_ref_msg.author: + ref_author_id = str(source_ref_msg.author.id) + if anonymize_users: + ref_name = context.state.get_user_alias(ref_author_id) or "Anonymized User" + else: + ref_name = 
source_ref_msg.author.display_name + content = f"`@{ref_name}`\n{content}" + else: + tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id) + if tgt_reply: + content = f"[Reply to {tgt_reply}]\n{content}" + + try: + stoat_msg_id = await context.stoat_writer.send_message( + channel_id=target_channel_id, + author_name=author_name, + author_avatar_url=author_avatar_url, + content=content, + files=files, + timestamp=timestamp_int, + embeds=msg.embeds + ) + + if stoat_msg_id: + context.state.set_target_message_mapping(target_channel_id, msg.id, stoat_msg_id) + context.state.update_last_message_id(target_channel_id, msg.id) + stats["attachments"] += len(files) if files else 0 + + stats["messages"] += 1 + stats["last_message_content"] = content + stats["last_message_author"] = author_name + + if not stats["first_message_url"]: + stats["first_message_url"] = msg.jump_url + stats["last_message_url"] = msg.jump_url + + if progress_callback: + await progress_callback(stats) + + except Exception as e: + logger.error(f"Failed to process global message {msg.id}: {e}") + + except (KeyboardInterrupt, asyncio.CancelledError): + context.is_running = False + pass + + return stats diff --git a/src/ui/shuttle_ops.py b/src/ui/shuttle_ops.py index 362ecc7..09243a8 100644 --- a/src/ui/shuttle_ops.py +++ b/src/ui/shuttle_ops.py @@ -160,6 +160,7 @@ class OperationPane(Container): yield Button("Clone Server Template", id="op_clone", disabled=True, tooltip="Clone server roles, categories, and channels to the target community") yield Button("Sync Server Settings", id="op_sync", disabled=True, tooltip="Sync emojis, stickers, server name, and icon to the target community") yield Button("Migrate Message History", id="op_messages", disabled=True, variant="primary", tooltip="Migrate message history from Discord to the target platform") + yield Button("Waterfall Migration", id="op_waterfall", disabled=True, variant="primary", tooltip="Migrate all messages 
globally in chronological order to prevent broken links.\n(Available for Local Backups)") yield Rule(id="footer_rule") yield Button("Danger Zone ⚠", id="op_danger", variant="error", disabled=True, flat=True, tooltip="Dangerous operations:\ndelete channels, roles, emojis on target\n(use with caution)") @@ -394,7 +395,7 @@ class OperationPane(Container): lbl.update(f"{t_status}") # Buttons - for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_danger"): + for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger"): for btn in self.query(bid): btn.disabled = not self.tokens_valid # ── validation ──────────────────────────────────────────────────────── @@ -415,7 +416,7 @@ class OperationPane(Container): # Disable all operation buttons while validation is in progress if self.view_mode == "shuttle": - for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_danger"): + for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger"): for btn in self.query(bid): btn.disabled = True elif self.view_mode == "backup": for bid in ("#op_backup_msgs", "#op_backup_sync"): @@ -551,8 +552,13 @@ class OperationPane(Container): else: target_ok = v.get("target_token") and v.get("target_community") self.tokens_valid = bool(discord_ok and target_ok) - - + + # Post validation adjustments + if self.tokens_valid: + is_backup = (self.config.tool_mode == "backup_transfer") + for btn in self.query("#op_waterfall"): + btn.disabled = not is_backup + btn.display = is_backup self._update_info_labels() @@ -570,6 +576,8 @@ class OperationPane(Container): self._open_sync_menu() elif bid == "op_messages": self.run_migrate_messages() + elif bid == "op_waterfall": + self.run_waterfall_migration() elif bid == "op_danger": self._open_danger_menu() @@ -1294,6 +1302,14 @@ class OperationPane(Container): try: self.engine.is_running = True + + # Ensure state is initialized (database exists) + if self.target_platform == "stoat": + tid = 
self.config.stoat_server_id + else: + tid = self.config.fluxer_server_id + self.engine.ensure_state_initialized(str(tid or ""), platform_name) + stats_analysis = await migrate_mod.analyze_migration( self.engine, source_channel_id=source_channel.id, @@ -1399,6 +1415,187 @@ class OperationPane(Container): else: modal.write(f"[bold red]Error: {err}[/bold red]") modal.phase_report("Message Migration", "error", show_back=False) + import traceback + logger.error(f"Migration Error: {traceback.format_exc()}") + finally: + self.engine.is_running = False + await self.engine.close_connections() + + @work(exclusive=True) + async def run_waterfall_migration(self) -> None: + if not self.tokens_valid: + return + + migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate + platform_name = self.target_platform.capitalize() + + modal = ProgressScreen(log_level=self.config.log_level) + self.app.push_screen(modal) + await asyncio.sleep(0.1) + + try: + modal.show_info("[bold cyan]Waterfall Migration Ready[/bold cyan]", "Checking mapping and missing channels...") + modal.set_status("Connecting to Servers...") + await self.engine.start_connections() + + modal.set_status("Synchronizing entity mappings...") + await self._perform_auto_matching() + + # 1. 
Missing channels check + full_d = await self.engine.discord_reader.get_channels() + if hasattr(self.engine.discord_reader, "get_backed_up_channel_ids"): + valid_ids = await self.engine.discord_reader.get_backed_up_channel_ids() + d_channels = [c for c in full_d if c.id in valid_ids and c.type in [0, 5]] + else: + d_channels = [c for c in full_d if c.type in [0, 5]] + + missing_channels = [] + for d in d_channels: + tgt_id = self.engine.state.get_target_channel_id(str(d.id)) + if not tgt_id: + missing_channels.append(d) + + if missing_channels: + modal.write(f"\n[bold yellow]Found {len(missing_channels)} channels with backups but no target mapping.[/bold yellow]") + modal.write("[dim]Do you want to automatically create these missing channels now?[/dim]") + + choice = await modal.phase_wait_confirm( + show_continue=False, + show_id=True, + btn_start_label=f"Yes, Create {len(missing_channels)} Missing Channels", + btn_id_label="No, Skip Them", + btn_start_variant="primary", + btn_start_tooltip="Create channels and map them", + btn_id_tooltip="Skip them (Warning: may cause broken mentions)" + ) + + if choice == "btn_back": + modal.dismiss() + return + elif choice == "btn_main_menu": + modal.dismiss() + self.app.switch_screen("config_selection") + return + + if choice == "btn_start_first": + modal.set_status("Creating missing channels...") + for mc in missing_channels: + try: + modal.write(f"Creating channel '#{mc.name}'...") + new_id = await self.engine.writer.create_channel(name=mc.name) + # Link them + self.engine.state.set_target_channel_id(str(mc.id), new_id, self.engine.platform) + modal.write(f"[green]Created {mc.name} ({new_id})[/green]") + except Exception as e: + modal.write(f"[red]Failed to create {mc.name}: {e}[/red]") + + # 2. 
Resumption check + all_mapped_tgt_ids = [] + # Check regular channels + for did in [str(c.id) for c in d_channels]: + tid = self.engine.state.get_target_channel_id(did) + if tid: all_mapped_tgt_ids.append(tid) + + # Also check threads + if hasattr(self.engine.discord_reader, "get_active_threads"): + threads = await self.engine.discord_reader.get_active_threads() + for t in threads: + tid = self.engine.state.get_target_channel_id(str(t.id)) + if tid: all_mapped_tgt_ids.append(tid) + + min_last_id = self.engine.state.get_global_min_last_message_id(all_mapped_tgt_ids) + + modal.write(f"\n[bold cyan]Waterfall Migration Resume Point:[/bold cyan]") + if min_last_id: + modal.write(f"Minimum unmigrated message ID found: [green]{min_last_id}[/green]") + else: + modal.write("No previous migration state found. Starting from the beginning.") + + choice = await modal.phase_wait_confirm( + show_continue=bool(min_last_id), + show_id=False, + btn_start_label="Start From Beginning", + btn_start_tooltip="Safe, skips duplicates automatically", + btn_start_variant="default" if min_last_id else "primary", + btn_continue_label=f"Continue from ID {min_last_id}" if min_last_id else "Continue Migration", + btn_continue_tooltip="Fastest" + ) + + if choice == "btn_back": + modal.dismiss() + await self.engine.close_connections() + return + elif choice == "btn_main_menu": + modal.dismiss() + await self.engine.close_connections() + self.app.switch_screen("config_selection") + return + + after_id = None + if choice == "btn_continue" and min_last_id: + after_id = int(min_last_id) + + # Phase 3: Progress + modal.cancel_callback = lambda: setattr(self.engine, "is_running", False) + modal.phase_progress() + modal.set_status("Migrating messages Globally...") + + self.engine.is_running = True + + # Ensure state is initialized (database exists) + if self.target_platform == "stoat": + tid = self.config.stoat_server_id + else: + tid = self.config.fluxer_server_id + 
self.engine.ensure_state_initialized(str(tid or ""), platform_name) + + modal.write("Scanning global footprint for totals ...") + stats_analysis = await migrate_mod.analyze_global_migration(self.engine, after_message_id=after_id) + total_messages = stats_analysis["messages"] + + modal.write(f"[bold cyan]Global Migration Started:[/bold cyan] {total_messages} total messages to process.") + modal.update_stats(messages=f"0/{total_messages}", threads=str(stats_analysis["threads"]), files=str(stats_analysis["attachments"])) + + async def update_msg(current_stats): + c_msgs = current_stats["messages"] + c_threads = current_stats["threads"] + c_files = current_stats["attachments"] + + msg_stat = f"{c_msgs}/{total_messages}" if total_messages > 0 else str(c_msgs) + modal.set_progress(c_msgs, total_messages or 100) + modal.update_stats(messages=msg_stat, threads=str(c_threads), files=str(c_files)) + + content = current_stats.get("last_message_content", "") + author = current_stats.get("last_message_author", "Unknown") + if content: + disp_content = (content[:100] + '...') if len(content) > 100 else content + modal.write(f"[bold]{author}:[/bold] {disp_content}") + + result = await migrate_mod.migrate_global_messages( + self.engine, + after_message_id=after_id, + inclusive=False, + progress_callback=update_msg, + ) + + if self.engine.is_running: + modal.write(f"[bold green]Success! {result['messages']} messages migrated globally.[/bold green]") + modal.phase_report("Waterfall Migration", show_back=False) + else: + modal.write(f"[bold yellow]Interrupted! 
{result['messages']} messages migrated.[/bold yellow]") + modal.phase_report("Waterfall Migration", "stopped", show_back=False) + + lines = [f"Migrated Server Globally → {platform_name}:"] + lines.append(f"{result['messages']} messages, {result['attachments']} attachments, {result['threads']} threads") + await log_audit_event(self.engine, "Waterfall Migration", "\n".join(lines)) + + except Exception as e: + err = str(e) + modal.write(f"[bold red]Error: {err}[/bold red]") + modal.phase_report("Waterfall Migration", "error", show_back=False) + + import traceback + logger.error(traceback.format_exc()) finally: self.engine.is_running = False await self.engine.close_connections() From 9074582a270fcab7821ad80bdd06e0781e7d8593 Mon Sep 17 00:00:00 2001 From: rambros Date: Sat, 28 Mar 2026 19:21:28 +0530 Subject: [PATCH 7/8] add resumability for waterfall mode --- src/core/database.py | 68 ++++++++++----- src/core/state.py | 41 +++++++-- src/fluxer/migrate_message.py | 36 +++++++- src/stoat/migrate_message.py | 36 +++++++- src/ui/shuttle_ops.py | 157 ++++++++++++++++++++++++++-------- 5 files changed, 269 insertions(+), 69 deletions(-) diff --git a/src/core/database.py b/src/core/database.py index 419d9aa..47578ea 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -451,34 +451,48 @@ class MigrationDatabase: return dict(row) return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} - def get_global_min_last_message_id(self, mapped_channel_ids: List[str]) -> Optional[str]: - """Returns the minimum last_msg_id across all mapped channels. If any mapped channel has NO last_msg_id, returns None.""" - if not mapped_channel_ids: + + def get_global_min_last_message_id(self, all_mapped_ids: List[str]) -> Optional[int]: + """ + Returns the minimum last_msg_id successfully migrated across all mapped channels/threads. + If any mapped entity has no progress record, it is treated as ID 0. 
+ Returns None only if NO progress has been made across ANY entity. + """ + if not all_mapped_ids: return None conn = self._get_conn() - placeholders = ",".join(["?"] * len(mapped_channel_ids)) - rows = conn.execute(f"SELECT last_msg_id FROM channel_tracking WHERE channel_id IN ({placeholders})", mapped_channel_ids).fetchall() + placeholders = ",".join(["?"] * len(all_mapped_ids)) - # If the number of tracked channels is less than mapped, it means some mapped channels haven't started. - if len(rows) < len(mapped_channel_ids): - return None - - # Parse all ids + # 1. Get last message IDs from channel tracking + c_rows = conn.execute(f"SELECT channel_id, last_msg_id FROM channel_tracking WHERE channel_id IN ({placeholders})", all_mapped_ids).fetchall() + c_map = {r["channel_id"]: r["last_msg_id"] for r in c_rows} + + # 2. Get last message IDs from thread tracking + t_rows = conn.execute(f"SELECT thread_id, last_msg_id FROM thread_tracking WHERE thread_id IN ({placeholders})", all_mapped_ids).fetchall() + t_map = {r["thread_id"]: r["last_msg_id"] for r in t_rows} + + # Combine maps + progress_map = {**c_map, **t_map} + + # 3. 
Aggregate IDs ids = [] - for r in rows: - val = r["last_msg_id"] - if not val: - return None # One channel has no messages yet - try: - ids.append(int(val)) - except ValueError: - pass - - if not ids: + has_any_progress = False + for mid in all_mapped_ids: + last_id = progress_map.get(mid) + if not last_id: + ids.append(0) # Unmigrated entity + else: + try: + ids.append(int(last_id)) + has_any_progress = True + except (ValueError, TypeError): + ids.append(0) + + if not has_any_progress: return None - return str(min(ids)) + return min(ids) # Thread methods similar to channel methods def set_thread_message_mapping(self, channel_id: str, thread_id: str, source_id: str, target_id: str, timestamp: str = None): @@ -525,6 +539,18 @@ class MigrationDatabase: return dict(row) return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} + def get_all_channel_tracking_ids(self) -> Dict[str, str]: + """Returns a map of channel_id -> last_msg_id for all tracked channels.""" + conn = self._get_conn() + rows = conn.execute("SELECT channel_id, last_msg_id FROM channel_tracking WHERE last_msg_id IS NOT NULL").fetchall() + return {str(row["channel_id"]): str(row["last_msg_id"]) for row in rows} + + def get_all_thread_tracking_ids(self) -> Dict[str, str]: + """Returns a map of thread_id -> last_msg_id for all tracked threads.""" + conn = self._get_conn() + rows = conn.execute("SELECT thread_id, last_msg_id FROM thread_tracking WHERE last_msg_id IS NOT NULL").fetchall() + return {str(row["thread_id"]): str(row["last_msg_id"]) for row in rows} + def clear_channel_data(self, channel_id: str): """Purge all mappings and tracking data for a specific channel and its threads.""" conn = self._get_conn() diff --git a/src/core/state.py b/src/core/state.py index ed9c1c6..adb46ec 100644 --- a/src/core/state.py +++ b/src/core/state.py @@ -38,7 +38,14 @@ class MigrationState: if self.db: self.db.delete_server_mapping("channel", str(discord_id)) + def 
remove_target_channel_mapping(self, discord_id: int | str): + if self.db: + self.db.delete_server_mapping("channel", str(discord_id)) + def set_target_channel_id(self, discord_id: int | str, target_id: str, *args): + """Alias for set_channel_mapping to handle legacy calls.""" + self.set_channel_mapping(discord_id, target_id) + get_fluxer_channel_id = get_target_channel_id set_target_channel_mapping = set_channel_mapping @@ -58,6 +65,10 @@ class MigrationState: if self.db: self.db.delete_server_mapping("category", str(discord_id)) + def set_target_category_id(self, discord_id: int | str, target_id: str, *args): + """Alias for set_category_mapping to handle legacy calls.""" + self.set_category_mapping(discord_id, target_id) + get_fluxer_category_id = get_category_mapping get_target_category_id = get_category_mapping set_target_category_mapping = set_category_mapping @@ -78,6 +89,10 @@ class MigrationState: if self.db: self.db.delete_server_mapping("role", str(discord_id)) + def set_target_role_id(self, discord_id: int | str, target_id: str, *args): + """Alias for set_role_mapping to handle legacy calls.""" + self.set_role_mapping(discord_id, target_id) + get_fluxer_role_id = get_role_mapping get_target_role_id = get_role_mapping set_target_role_mapping = set_role_mapping @@ -210,16 +225,30 @@ class MigrationState: if self._ensure_db(): self.db.update_channel_tracking(str(target_channel_id), last_msg_id=str(message_id)) - def get_last_message_id(self, target_channel_id: str) -> str | None: + + def get_global_min_last_message_id(self, all_mapped_ids: List[str]) -> int | None: + """Returns the absolute minimum last_msg_id among the given list of mapped target IDs (channels and threads).""" if self._ensure_db(): - return self.db.get_channel_tracking(str(target_channel_id)).get("last_msg_id") + return self.db.get_global_min_last_message_id(all_mapped_ids) return None - def get_global_min_last_message_id(self, mapped_channel_ids: List[str]) -> str | None: - """Returns the 
absolute minimum last_msg_id among the given list of mapped target channel IDs.""" - if self._ensure_db(): - return self.db.get_global_min_last_message_id(mapped_channel_ids) + def set_waterfall_last_id(self, last_id: str | int): + if self.db: + self.db.set_metadata("waterfall_last_id", str(last_id)) + + def get_waterfall_last_id(self) -> int | None: + if self.db: + val = self.db.get_metadata("waterfall_last_id") + return int(val) if val else None return None + + def get_all_last_message_ids(self) -> Dict[str, str]: + """Returns a combined map of channel_id/thread_id -> last_msg_id.""" + if self._ensure_db(): + c_map = self.db.get_all_channel_tracking_ids() + t_map = self.db.get_all_thread_tracking_ids() + return {**c_map, **t_map} + return {} def get_thread_last_message_id(self, target_channel_id: str, thread_id: str) -> str | None: if self._ensure_db(): diff --git a/src/fluxer/migrate_message.py b/src/fluxer/migrate_message.py index c63693e..5a2c0ea 100644 --- a/src/fluxer/migrate_message.py +++ b/src/fluxer/migrate_message.py @@ -543,8 +543,11 @@ async def migrate_messages( except Exception as e: logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}") + # Check for existing mapping to avoid duplicates when resuming + if context.state.get_target_message_id(target_channel_id, str(msg.id)): + continue + try: - # Check if this message is a reply reply_to_fluxer_id = None if msg.reference and msg.reference.message_id: reply_to_fluxer_id = context.state.get_fluxer_message_id(target_channel_id, str(msg.reference.message_id)) @@ -676,10 +679,27 @@ async def analyze_global_migration(context: MigrationContext, after_message_id: # In global mode, thread messages are returned natively in timestamp order by global fetch if they're in the DB # However we just count them if the fetcher yields them. 
+ # Fetch global progress map to skip migrated messages efficiently + progress_map = context.state.get_all_last_message_ids() + async for msg in context.discord_reader.fetch_global_message_history(after_id=after_message_id): if not context.is_running: break + # Determine target channel to check for existing mapping + if not msg.channel: + continue + + target_channel_id = context.state.get_target_channel_id(str(msg.channel.id)) + if not target_channel_id: + continue + + # Efficient skip: if message ID is <= last migrated ID for this channel/thread + # This is the primary resume mechanism: wait until we pass the last migrated ID for this channel + last_id = progress_map.get(str(msg.channel.id)) + if last_id and msg.id <= int(last_id): + continue + if msg.type not in [ context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY, @@ -738,6 +758,9 @@ async def migrate_global_messages( db_media = context.discord_reader.db.get_all_media() if context.discord_reader.db else {} target_server_id = getattr(context.fluxer_writer, "server_id", None) + # Fetch global progress map to skip migrated messages efficiently + progress_map = context.state.get_all_last_message_ids() + try: async for msg in context.discord_reader.fetch_global_message_history(after_id=after_message_id): if not context.is_running: @@ -757,9 +780,17 @@ async def migrate_global_messages( continue # Determine target channel + if not msg.channel: + continue + target_channel_id = context.state.get_target_channel_id(str(msg.channel.id)) if not target_channel_id: - logger.debug(f"Skipping msg {msg.id}: channel {msg.channel.id} not mapped.") + continue + + # Efficient skip: if message ID is <= last migrated ID for this channel/thread + # This ensures we only resume a channel once we reach its last known progress point + last_id = progress_map.get(str(target_channel_id)) + if last_id and msg.id <= int(last_id): continue # If it's a thread message, we need to handle it based on if it's the 
thread starter or a reply @@ -874,6 +905,7 @@ async def migrate_global_messages( if fluxer_msg_id: context.state.set_target_message_mapping(target_channel_id, msg.id, fluxer_msg_id) context.state.update_last_message_id(target_channel_id, msg.id) + context.state.set_waterfall_last_id(msg.id) stats["attachments"] += len(files) if files else 0 stats["messages"] += 1 diff --git a/src/stoat/migrate_message.py b/src/stoat/migrate_message.py index 45d0fee..130579f 100644 --- a/src/stoat/migrate_message.py +++ b/src/stoat/migrate_message.py @@ -547,6 +547,10 @@ async def migrate_messages( logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}") try: + # Check for existing mapping to avoid duplicates when resuming + if context.state.get_target_message_id(target_channel_id, str(msg.id)): + continue + # Check if this message is a reply reply_to_stoat_id = None if msg.reference and msg.reference.message_id: @@ -678,10 +682,27 @@ async def analyze_global_migration(context: MigrationContext, after_message_id: """ stats = {"messages": 0, "threads": 0, "attachments": 0} + # Fetch global progress map to skip migrated messages efficiently + progress_map = context.state.get_all_last_message_ids() + async for msg in context.discord_reader.fetch_global_message_history(after_id=after_message_id): if not context.is_running: break + # Determine target channel to check for existing mapping + if not msg.channel: + continue + + target_channel_id = context.state.get_target_channel_id(str(msg.channel.id)) + if not target_channel_id: + continue + + # Efficient skip: if message ID is <= last migrated ID for this channel/thread + # This is the primary resume mechanism: wait until we pass the last migrated ID for this channel + last_id = progress_map.get(str(msg.channel.id)) + if last_id and msg.id <= int(last_id): + continue + if msg.type not in [ context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY, @@ -732,7 +753,8 @@ async def 
migrate_global_messages( emoji_map = context.state.emoji_map db_media = context.discord_reader.db.get_all_media() if context.discord_reader.db else {} - target_server_id = getattr(context.stoat_writer, "community_id", None) + # Fetch global progress map to skip migrated messages efficiently + progress_map = context.state.get_all_last_message_ids() try: async for msg in context.discord_reader.fetch_global_message_history(after_id=after_message_id): @@ -752,9 +774,18 @@ async def migrate_global_messages( ]: continue + # Determine target channel + if not msg.channel: + continue + target_channel_id = context.state.get_target_channel_id(str(msg.channel.id)) if not target_channel_id: - logger.debug(f"Skipping msg {msg.id}: channel {msg.channel.id} not mapped.") + continue + + # Efficient skip: if message ID is <= last migrated ID for this channel/thread + # This ensures we only resume a channel once we reach its last known progress point + last_id = progress_map.get(str(target_channel_id)) + if last_id and msg.id <= int(last_id): continue parent_target_id = None @@ -858,6 +889,7 @@ async def migrate_global_messages( if stoat_msg_id: context.state.set_target_message_mapping(target_channel_id, msg.id, stoat_msg_id) context.state.update_last_message_id(target_channel_id, msg.id) + context.state.set_waterfall_last_id(msg.id) stats["attachments"] += len(files) if files else 0 stats["messages"] += 1 diff --git a/src/ui/shuttle_ops.py b/src/ui/shuttle_ops.py index 09243a8..1e810d6 100644 --- a/src/ui/shuttle_ops.py +++ b/src/ui/shuttle_ops.py @@ -724,7 +724,6 @@ class OperationPane(Container): return elif choice == "btn_main_menu": modal.dismiss() - self.app.switch_screen("config_selection") return force_mode = (choice == "btn_start_id") @@ -806,7 +805,6 @@ class OperationPane(Container): return elif choice == "btn_main_menu": modal.dismiss() - self.app.switch_screen("config_selection") return force_mode = (choice == "btn_start_id") @@ -1251,7 +1249,6 @@ class 
OperationPane(Container): continue # Return to channel picker elif choice == "btn_main_menu": modal.dismiss() - self.app.switch_screen("config_selection") self.engine.is_running = False await self.engine.close_connections() return @@ -1415,7 +1412,6 @@ class OperationPane(Container): else: modal.write(f"[bold red]Error: {err}[/bold red]") modal.phase_report("Message Migration", "error", show_back=False) - import traceback logger.error(f"Migration Error: {traceback.format_exc()}") finally: self.engine.is_running = False @@ -1442,82 +1438,137 @@ class OperationPane(Container): await self._perform_auto_matching() # 1. Missing channels check - full_d = await self.engine.discord_reader.get_channels() - if hasattr(self.engine.discord_reader, "get_backed_up_channel_ids"): - valid_ids = await self.engine.discord_reader.get_backed_up_channel_ids() - d_channels = [c for c in full_d if c.id in valid_ids and c.type in [0, 5]] + if hasattr(self.engine.discord_reader, "get_all_channels"): + full_d = await self.engine.discord_reader.get_all_channels() + # Include TEXT (0), CATEGORY (4), and NEWS (5) + d_channels = [c for c in full_d if c.type in [0, 4, 5]] else: - d_channels = [c for c in full_d if c.type in [0, 5]] - + full_d = await self.engine.discord_reader.get_channels() + d_channels = [c for c in full_d if c.type in [0, 4, 5]] missing_channels = [] for d in d_channels: - tgt_id = self.engine.state.get_target_channel_id(str(d.id)) + if d.type == 4: + tgt_id = self.engine.state.get_target_category_id(str(d.id)) + else: + tgt_id = self.engine.state.get_target_channel_id(str(d.id)) if not tgt_id: missing_channels.append(d) if missing_channels: - modal.write(f"\n[bold yellow]Found {len(missing_channels)} channels with backups but no target mapping.[/bold yellow]") - modal.write("[dim]Do you want to automatically create these missing channels now?[/dim]") + modal.write(f"\n[bold yellow]Found {len(missing_channels)} backed-up channels/categories missing from target platform:[/bold 
yellow]") + for mc in missing_channels: + prefix = "[bold cyan]📁[/bold cyan] " if mc.type == 4 else "[bold white]#[/bold white] " + modal.write(f" {prefix}{mc.name}") choice = await modal.phase_wait_confirm( show_continue=False, show_id=True, - btn_start_label=f"Yes, Create {len(missing_channels)} Missing Channels", - btn_id_label="No, Skip Them", + btn_start_label="Clone missing channels", + btn_id_label="Skip missing channels", btn_start_variant="primary", - btn_start_tooltip="Create channels and map them", - btn_id_tooltip="Skip them (Warning: may cause broken mentions)" + btn_start_tooltip=f"Automatically create {len(missing_channels)} entities on target", + btn_id_tooltip="Start migration without these channels" ) if choice == "btn_back": modal.dismiss() + await self.engine.close_connections() return elif choice == "btn_main_menu": modal.dismiss() - self.app.switch_screen("config_selection") + await self.engine.close_connections() return if choice == "btn_start_first": - modal.set_status("Creating missing channels...") + modal.set_status("Cloning missing categories and channels...") + # Sort so categories (type 4) come first + missing_channels.sort(key=lambda x: 0 if x.type == 4 else 1) + for mc in missing_channels: try: - modal.write(f"Creating channel '#{mc.name}'...") - new_id = await self.engine.writer.create_channel(name=mc.name) - # Link them - self.engine.state.set_target_channel_id(str(mc.id), new_id, self.engine.platform) - modal.write(f"[green]Created {mc.name} ({new_id})[/green]") + parent_target_id = None + if mc.type == 4: + modal.write(f"Creating category '[bold cyan]{mc.name}[/bold cyan]'...") + new_id = await self.engine.writer.create_channel(name=mc.name, type=4) + self.engine.state.set_target_category_mapping(str(mc.id), new_id) + modal.write(f"[green]Created Category {mc.name} ({new_id})[/green]") + else: + if hasattr(mc, 'category_id') and mc.category_id: + parent_target_id = self.engine.state.get_target_category_id(str(mc.category_id)) + + 
modal.write(f"Creating channel '#{mc.name}'...") + new_id = await self.engine.writer.create_channel(name=mc.name, parent_id=parent_target_id) + self.engine.state.set_target_channel_id(str(mc.id), new_id, self.engine.target_platform) + modal.write(f"[green]Created Channel {mc.name} ({new_id})[/green]") except Exception as e: + logger.error(f"Failed to create {mc.name}: {e}\n{traceback.format_exc()}") modal.write(f"[red]Failed to create {mc.name}: {e}[/red]") + elif choice == "btn_id": + # Skip missing channels: remove them from the active list + missing_ids = {str(c.id) for c in missing_channels} + d_channels = [c for c in d_channels if str(c.id) not in missing_ids] # 2. Resumption check all_mapped_tgt_ids = [] - # Check regular channels - for did in [str(c.id) for c in d_channels]: + # Check regular text channels (exclude categories for resume check) + for c in d_channels: + if c.type == 4: continue + did = str(c.id) tid = self.engine.state.get_target_channel_id(did) if tid: all_mapped_tgt_ids.append(tid) - # Also check threads + # Also check threads (filtering to only include those belonging to active channels) + active_channel_ids = {str(c.id) for c in d_channels} if hasattr(self.engine.discord_reader, "get_active_threads"): threads = await self.engine.discord_reader.get_active_threads() for t in threads: + pid = str(getattr(t, 'parent_id', getattr(t, 'channel_id', None))) + if pid not in active_channel_ids: continue tid = self.engine.state.get_target_channel_id(str(t.id)) if tid: all_mapped_tgt_ids.append(tid) + + # 2.5 Filter by actual content (Only for BackupReader) + # If a channel has NO messages in the backup, it will always be at 0 progress. + # We exclude those from the global MIN calculation to avoid pulling it to 0. 
+ if hasattr(self.engine.discord_reader, "get_backed_up_channel_ids"): + backed_up_src_ids = await self.engine.discord_reader.get_backed_up_channel_ids() + backed_up_src_ids_str = {str(sid) for sid in backed_up_src_ids} - min_last_id = self.engine.state.get_global_min_last_message_id(all_mapped_tgt_ids) + filtered_tgt_ids = [] + # Find which target IDs belong to source channels that HAVE messages + for c in d_channels: # (d_channels is already filtered for skipped) + if str(c.id) in backed_up_src_ids_str: + tid = self.engine.state.get_target_channel_id(str(c.id)) + if tid: filtered_tgt_ids.append(tid) + + # Also check threads + if hasattr(self.engine.discord_reader, "threads"): + for t in self.engine.discord_reader.threads: + if str(t.id) in backed_up_src_ids_str: + tid = self.engine.state.get_target_channel_id(str(t.id)) + if tid: filtered_tgt_ids.append(tid) + + if filtered_tgt_ids: + all_mapped_tgt_ids = filtered_tgt_ids + + # 2.6 Resume Point: Prioritize Global waterfall tracker, fallback to channel minimums + min_last_id = self.engine.state.get_waterfall_last_id() + if min_last_id is None: + min_last_id = self.engine.state.get_global_min_last_message_id(all_mapped_tgt_ids) modal.write(f"\n[bold cyan]Waterfall Migration Resume Point:[/bold cyan]") - if min_last_id: + if min_last_id is not None: modal.write(f"Minimum unmigrated message ID found: [green]{min_last_id}[/green]") else: modal.write("No previous migration state found. 
Starting from the beginning.") choice = await modal.phase_wait_confirm( - show_continue=bool(min_last_id), + show_continue=min_last_id is not None, show_id=False, btn_start_label="Start From Beginning", btn_start_tooltip="Safe, skips duplicates automatically", - btn_start_variant="default" if min_last_id else "primary", - btn_continue_label=f"Continue from ID {min_last_id}" if min_last_id else "Continue Migration", + btn_start_variant="default" if min_last_id is not None else "primary", + btn_continue_label=f"Continue from ID {min_last_id if min_last_id is not None else 0}" if min_last_id is not None else "Continue Migration", btn_continue_tooltip="Fastest" ) @@ -1528,11 +1579,10 @@ class OperationPane(Container): elif choice == "btn_main_menu": modal.dismiss() await self.engine.close_connections() - self.app.switch_screen("config_selection") return after_id = None - if choice == "btn_continue" and min_last_id: + if choice == "btn_continue" and min_last_id is not None: after_id = int(min_last_id) # Phase 3: Progress @@ -1594,7 +1644,6 @@ class OperationPane(Container): modal.write(f"[bold red]Error: {err}[/bold red]") modal.phase_report("Waterfall Migration", "error", show_back=False) - import traceback logger.error(traceback.format_exc()) finally: self.engine.is_running = False @@ -1682,7 +1731,6 @@ class OperationPane(Container): return elif choice == "btn_main_menu": modal.dismiss() - self.app.switch_screen("config_selection") return modal.cancel_callback = lambda: setattr(self.engine, "is_running", False) @@ -1841,6 +1889,39 @@ class OperationPane(Container): logger.warning(f"Auto-matching: failed to fetch target data: {e}") return # Cannot match without target data + # 1.5 Cleanup deleted entities from mapping database + # This prevents "Ghost" mappings to channels/roles that were deleted on target + valid_chan_ids = {str(c.get("id")) for c in target_chans_raw} + valid_cat_ids = {str(c.get("id")) for c in target_chans_raw if c.get("type") == 4} + 
valid_role_ids = set(target_roles_map.values()) + valid_emoji_ids = set(target_emojis_map.values()) + + # Channels + for src_id, tgt_id in self.engine.state.channel_map.items(): + if str(tgt_id) not in valid_chan_ids: + logger.info(f"Auto-matching: clearing deleted channel mapping {src_id} -> {tgt_id}") + self.engine.state.remove_target_channel_mapping(src_id) + + # Categories + for src_id, tgt_id in self.engine.state.category_map.items(): + if str(tgt_id) not in valid_cat_ids: + logger.info(f"Auto-matching: clearing deleted category mapping {src_id} -> {tgt_id}") + self.engine.state.remove_category_mapping(src_id) + + # Roles + for src_id, tgt_id in self.engine.state.role_map.items(): + if str(tgt_id) not in valid_role_ids: + logger.info(f"Auto-matching: clearing deleted role mapping {src_id} -> {tgt_id}") + self.engine.state.remove_role_mapping(src_id) + + # Emojis + for src_id, tgt_id in self.engine.state.emoji_map.items(): + if str(tgt_id) not in valid_emoji_ids: + # Emojis might be URLs in some platforms, but we check if they are IDs first + if isinstance(tgt_id, str) and tgt_id.isdigit(): + logger.info(f"Auto-matching: clearing deleted emoji mapping {src_id} -> {tgt_id}") + self.engine.state.remove_emoji_mapping(src_id) + # 2. 
Match entities try: # Roles @@ -1884,6 +1965,7 @@ class OperationPane(Container): logger.info(f"Auto-matched Sticker: {s.name} -> {target_stickers_map[name_l]}") self.engine.state.set_target_sticker_mapping(s.id, target_stickers_map[name_l]) except Exception as e: + logger.error(f"Auto-matching error: {e}\n{traceback.format_exc()}") logger.warning(f"Auto-matching error: {e}") return { @@ -2148,7 +2230,6 @@ class OperationPane(Container): after_id = verified_id elif choice == "btn_main_menu": modal_prog.dismiss() - self.app.switch_screen("config_selection") return # If we are here, proceeding either via Start First or Start from ID (after_id) @@ -2259,7 +2340,7 @@ class OperationPane(Container): self.engine.is_running = False await self.engine.close_connections() if choice == "btn_main_menu": - self.app.switch_screen("config_selection") + pass return modal_prog.cancel_callback = lambda: setattr(self.engine, "is_running", False) From 654d63c407701c070fcb959812e905176fabba60 Mon Sep 17 00:00:00 2001 From: rambros Date: Sat, 28 Mar 2026 20:15:20 +0530 Subject: [PATCH 8/8] message migration fix --- src/core/state.py | 6 ++++++ src/ui/shuttle_ops.py | 10 +++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/core/state.py b/src/core/state.py index adb46ec..eb816c5 100644 --- a/src/core/state.py +++ b/src/core/state.py @@ -225,6 +225,12 @@ class MigrationState: if self._ensure_db(): self.db.update_channel_tracking(str(target_channel_id), last_msg_id=str(message_id)) + def get_last_message_id(self, target_channel_id: str) -> str | None: + if self._ensure_db(): + tracking = self.db.get_channel_tracking(str(target_channel_id)) + return tracking.get("last_msg_id") if tracking else None + return None + def get_global_min_last_message_id(self, all_mapped_ids: List[str]) -> int | None: """Returns the absolute minimum last_msg_id among the given list of mapped target IDs (channels and threads).""" diff --git a/src/ui/shuttle_ops.py b/src/ui/shuttle_ops.py 
index 1e810d6..fc17583 100644 --- a/src/ui/shuttle_ops.py +++ b/src/ui/shuttle_ops.py @@ -1105,6 +1105,11 @@ class OperationPane(Container): source_channel = next(c for c in d_channels if c.id == src_id) target_channel = next(c for c in f_channels if c.get("id") == tgt_id) + # 2. Analyze + modal = ProgressScreen(log_level=self.config.log_level) + self.app.push_screen(modal) + await asyncio.sleep(0.1) + # Determine after_id status (skip for pending channels) if pending_create_name: last_migrated = None @@ -1112,11 +1117,6 @@ class OperationPane(Container): else: last_migrated = self.engine.state.get_last_message_id(str(target_channel.get('id'))) has_previous = bool(last_migrated) - - # Analyze - modal = ProgressScreen(log_level=self.config.log_level) - self.app.push_screen(modal) - await asyncio.sleep(0.1) src_server = getattr(self.engine.discord_reader, 'guild', None) tgt_server_info = await self.engine.writer.validate()