diff --git a/README.md b/README.md index aac1b6e..486c5dc 100644 --- a/README.md +++ b/README.md @@ -9,8 +9,8 @@ ![Disco Reaper](images/fluxer-reaper.jpg) -### Modern Terminal Interface -The tool now features a unified, intuitive TUI (Terminal User Interface) - no more text commands +### Video Guide - [Youtube](https://www.youtube.com/watch?v=SwIPQDxLzqA) + | Features | Fluxer | Stoat | | :--- | :---: | :---: | diff --git a/src/core/backup_database.py b/src/core/backup_database.py index ecbd1b8..40e0011 100644 --- a/src/core/backup_database.py +++ b/src/core/backup_database.py @@ -4,20 +4,11 @@ import json import threading from pathlib import Path from typing import Dict, Any, List, Optional, Union +from src.core.utils import parse_snowflake logger = logging.getLogger(__name__) -def parse_snowflake(value: Any) -> Optional[int]: - """Safely parses a Discord ID (Snowflake) from any input, handling 'None' strings.""" - if value is None: - return None - s = str(value).strip() - if not s or s.lower() == "none" or s == "NULL": - return None - try: - return int(s) - except ValueError: - return None + class BackupDatabase: """Manages the SQLite database for local Discord backups.""" @@ -39,24 +30,101 @@ class BackupDatabase: def _migrate_db(self): """Handles backward compatibility by renaming columns in existing databases.""" with self._lock: - # Check 'media_pool' table - res = self._conn.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='media_pool'").fetchone() - if res[0] > 0: - cols = self._conn.execute("PRAGMA table_info(media_pool)").fetchall() - col_names = [c["name"] for c in cols] - if "mime_type" in col_names and "content_type" not in col_names: - logger.info("Migrating media_pool: renaming 'mime_type' to 'content_type'") - self._conn.execute("ALTER TABLE media_pool RENAME COLUMN mime_type TO content_type") + conn = self._conn + # 1. MIME Type to Content Type Migrations + for table in ["media_pool", "server_assets"]: + res = conn.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name='{table}'").fetchone() + if res and res[0] > 0: + cols = conn.execute(f"PRAGMA table_info({table})").fetchall() + col_names = [c["name"] for c in cols] + if "mime_type" in col_names and "content_type" not in col_names: + logger.info(f"Migrating {table}: renaming 'mime_type' to 'content_type'") + conn.execute(f"ALTER TABLE {table} RENAME COLUMN mime_type TO content_type") - # Check 'server_assets' table - res = self._conn.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='server_assets'").fetchone() - if res[0] > 0: - cols = self._conn.execute("PRAGMA table_info(server_assets)").fetchall() - col_names = [c["name"] for c in cols] - if "mime_type" in col_names and "content_type" not in col_names: - logger.info("Migrating server_assets: renaming 'mime_type' to 'content_type'") - self._conn.execute("ALTER TABLE server_assets RENAME COLUMN mime_type TO content_type") - self._conn.commit() + # 2. Universal ID Migration (TEXT -> INTEGER) + # Mapping of table names to columns that must be INTEGER (Snowflakes) + id_migrations = { + "guild_profile": ["id", "owner_id"], + "roles": ["id", "permissions"], + "channels": ["id", "category_id"], + "permissions": ["channel_id", "target_id"], + "users": ["id"], + "messages": ["id", "channel_id", "author_id", "message_reference"], + "attachments": ["id", "message_id"], + "embeds": ["message_id"], + "reactions": ["message_id", "emoji_id"], + "message_stickers": ["message_id", "sticker_id"], + "threads": ["id", "parent_id"], + "forum_tags": ["id", "forum_id", "emoji_id"], + "server_assets": ["id"] + } + + for table, id_cols in id_migrations.items(): + res = conn.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name='{table}'").fetchone() + if not res or res[0] == 0: + continue + + cols = conn.execute(f"PRAGMA table_info({table})").fetchall() + needs_migration = False + for col in cols: + if col[1] in id_cols and col[2] == "TEXT": + needs_migration = True + break + + if needs_migration: + logger.info(f"Migrating {table}: converting ID columns to INTEGER") + # Special Case: messages already handled id, but now generic + # We use a temporary table to handle the schema change + conn.execute(f"ALTER TABLE {table} RENAME TO {table}_old") + + # We can't easily generate the CREATE TABLE here without duplicating _init_db logic + # So we call _init_db to create the NEW table, then copy data + # But _init_db has 'IF NOT EXISTS', so we just call it once at the end? + # No, we need the table NOW for the INSERT. + # I'll just manually define the inserts or do it in _init_db. + + # Actually, a better way is to do the CREATE TABLE here for this specific table. + # I'll have to duplicate the schema from _init_db for the migration. + + # Alternatively, since we are already in _migrate_db, we can just do the + # specific CREATE TABLE for the table we are migrating. + + if table == "guild_profile": + conn.execute("CREATE TABLE guild_profile (id INTEGER PRIMARY KEY, name TEXT, description TEXT, icon_file TEXT, icon_url TEXT, banner_file TEXT, banner_url TEXT, owner_id INTEGER, last_backup TEXT, ignore_channels TEXT)") + elif table == "roles": + conn.execute("CREATE TABLE roles (id INTEGER PRIMARY KEY, name TEXT, color INTEGER, position INTEGER, permissions INTEGER, hoist INTEGER, mentionable INTEGER)") + elif table == "channels": + conn.execute("CREATE TABLE channels (id INTEGER PRIMARY KEY, name TEXT, type INTEGER, position INTEGER, category_id INTEGER, topic TEXT, nsfw INTEGER, bitrate INTEGER, slowmode_delay INTEGER)") + elif table == "permissions": + conn.execute("CREATE TABLE permissions (id INTEGER PRIMARY KEY AUTOINCREMENT, channel_id INTEGER, target_id INTEGER, target_type TEXT, allow INTEGER, deny INTEGER)") + elif table == "users": + conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT, display_name TEXT, avatar_file TEXT, avatar_url TEXT, roles TEXT)") + elif table == "messages": + conn.execute("CREATE TABLE messages (id INTEGER PRIMARY KEY, channel_id INTEGER, author_id INTEGER, content TEXT, timestamp TEXT, type INTEGER, message_reference INTEGER, is_pinned INTEGER, extra_data TEXT)") + elif table == "attachments": + conn.execute("CREATE TABLE attachments (id INTEGER PRIMARY KEY, message_id INTEGER, filename TEXT, size INTEGER, url TEXT, content_type TEXT, local_hash TEXT)") + elif table == "embeds": + conn.execute("CREATE TABLE embeds (id INTEGER PRIMARY KEY AUTOINCREMENT, message_id INTEGER, title TEXT, description TEXT, url TEXT, color INTEGER, timestamp TEXT, thumbnail_url TEXT, image_url TEXT, author_name TEXT, author_url TEXT, author_icon_url TEXT, footer_text TEXT, footer_icon_url TEXT, fields TEXT)") + elif table == "reactions": + conn.execute("CREATE TABLE reactions (id INTEGER PRIMARY KEY AUTOINCREMENT, message_id INTEGER, emoji_id INTEGER, emoji_name TEXT, count INTEGER)") + elif table == "message_stickers": + conn.execute("CREATE TABLE message_stickers (message_id INTEGER, sticker_id INTEGER, name TEXT, url TEXT, format_type INTEGER, local_hash TEXT, PRIMARY KEY (message_id, sticker_id))") + elif table == "threads": + conn.execute("CREATE TABLE threads (id INTEGER PRIMARY KEY, name TEXT, type INTEGER, parent_id INTEGER, message_count INTEGER, member_count INTEGER, archived INTEGER, archive_timestamp TEXT, auto_archive_duration INTEGER, locked INTEGER, applied_tags TEXT)") + elif table == "forum_tags": + conn.execute("CREATE TABLE forum_tags (id INTEGER PRIMARY KEY, forum_id INTEGER, name TEXT, moderated INTEGER, emoji_id INTEGER, emoji_name TEXT)") + elif table == "server_assets": + conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type INTEGER)") + + old_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table}_old)").fetchall()] + new_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table})").fetchall()] + common_cols = [c for c in old_cols if c in new_cols] + col_str = ", ".join(common_cols) + + conn.execute(f"INSERT INTO {table} ({col_str}) SELECT {col_str} FROM {table}_old") + conn.execute(f"DROP TABLE {table}_old") + + conn.commit() def _init_db(self): """Initializes the database schema.""" @@ -66,14 +134,14 @@ class BackupDatabase: # Guild Profile conn.execute(""" CREATE TABLE IF NOT EXISTS guild_profile ( - id TEXT PRIMARY KEY, + id INTEGER PRIMARY KEY, name TEXT, description TEXT, icon_file TEXT, icon_url TEXT, banner_file TEXT, banner_url TEXT, - owner_id TEXT, + owner_id INTEGER, last_backup TEXT, ignore_channels TEXT ) @@ -82,11 +150,11 @@ class BackupDatabase: # Roles conn.execute(""" CREATE TABLE IF NOT EXISTS roles ( - id TEXT PRIMARY KEY, + id INTEGER PRIMARY KEY, name TEXT, color INTEGER, position INTEGER, - permissions TEXT, + permissions INTEGER, hoist INTEGER, mentionable INTEGER ) @@ -95,11 +163,11 @@ class BackupDatabase: # Channels conn.execute(""" CREATE TABLE IF NOT EXISTS channels ( - id TEXT PRIMARY KEY, + id INTEGER PRIMARY KEY, name TEXT, type INTEGER, position INTEGER, - category_id TEXT, + category_id INTEGER, topic TEXT, nsfw INTEGER, bitrate INTEGER, @@ -111,8 +179,8 @@ class BackupDatabase: conn.execute(""" CREATE TABLE IF NOT EXISTS permissions ( id INTEGER PRIMARY KEY AUTOINCREMENT, - channel_id TEXT, - target_id TEXT, + channel_id INTEGER, + target_id INTEGER, target_type TEXT, allow INTEGER, deny INTEGER @@ -123,7 +191,7 @@ class BackupDatabase: # Users (Author cache) conn.execute(""" CREATE TABLE IF NOT EXISTS users ( - id TEXT PRIMARY KEY, + id INTEGER PRIMARY KEY, username TEXT, display_name TEXT, avatar_file TEXT, @@ -135,13 +203,13 @@ class BackupDatabase: # Messages conn.execute(""" CREATE TABLE IF NOT EXISTS messages ( - id TEXT PRIMARY KEY, - channel_id TEXT, - author_id TEXT, + id INTEGER PRIMARY KEY, + channel_id INTEGER, + author_id INTEGER, content TEXT, timestamp TEXT, type INTEGER, - message_reference TEXT, + message_reference INTEGER, is_pinned INTEGER, extra_data TEXT ) @@ -152,8 +220,8 @@ class BackupDatabase: # Attachments conn.execute(""" CREATE TABLE IF NOT EXISTS attachments ( - id TEXT PRIMARY KEY, - message_id TEXT, + id INTEGER PRIMARY KEY, + message_id INTEGER, filename TEXT, size INTEGER, url TEXT, @@ -167,7 +235,7 @@ class BackupDatabase: conn.execute(""" CREATE TABLE IF NOT EXISTS embeds ( id INTEGER PRIMARY KEY AUTOINCREMENT, - message_id TEXT, + message_id INTEGER, title TEXT, description TEXT, url TEXT, @@ -189,8 +257,8 @@ class BackupDatabase: conn.execute(""" CREATE TABLE IF NOT EXISTS reactions ( id INTEGER PRIMARY KEY AUTOINCREMENT, - message_id TEXT, - emoji_id TEXT, + message_id INTEGER, + emoji_id INTEGER, emoji_name TEXT, count INTEGER ) @@ -200,8 +268,8 @@ class BackupDatabase: # Message Stickers conn.execute(""" CREATE TABLE IF NOT EXISTS message_stickers ( - message_id TEXT, - sticker_id TEXT, + message_id INTEGER, + sticker_id INTEGER, name TEXT, url TEXT, format_type INTEGER, @@ -214,10 +282,10 @@ class BackupDatabase: # Threads conn.execute(""" CREATE TABLE IF NOT EXISTS threads ( - id TEXT PRIMARY KEY, + id INTEGER PRIMARY KEY, name TEXT, type INTEGER, - parent_id TEXT, + parent_id INTEGER, message_count INTEGER, member_count INTEGER, archived INTEGER, @@ -232,11 +300,11 @@ class BackupDatabase: # Forum Tags (Definitions for a forum channel) conn.execute(""" CREATE TABLE IF NOT EXISTS forum_tags ( - id TEXT PRIMARY KEY, - forum_id TEXT, + id INTEGER PRIMARY KEY, + forum_id INTEGER, name TEXT, moderated INTEGER, - emoji_id TEXT, + emoji_id INTEGER, emoji_name TEXT ) """) @@ -257,7 +325,7 @@ class BackupDatabase: # Server Assets (Emojis, Stickers, etc.) conn.execute(""" CREATE TABLE IF NOT EXISTS server_assets ( - id TEXT PRIMARY KEY, + id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, @@ -276,10 +344,10 @@ class BackupDatabase: INSERT OR REPLACE INTO guild_profile (id, name, description, icon_file, icon_url, banner_file, banner_url, owner_id, last_backup, ignore_channels) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( - str(data.get("id")), data.get("name"), data.get("description"), + parse_snowflake(data.get("id")), data.get("name"), data.get("description"), data.get("icon_file"), data.get("icon_url"), data.get("banner_file"), data.get("banner_url"), - str(data.get("owner_id")), + parse_snowflake(data.get("owner_id")), data.get("last_backup"), json.dumps(data.get("ignore_channels", [])) )) self._conn.commit() @@ -300,7 +368,7 @@ class BackupDatabase: with self._lock: formatted = [ { - "id": str(r["id"]), + "id": parse_snowflake(r["id"]), "name": r["name"], "color": r["color"], "position": r["position"], @@ -347,7 +415,7 @@ class BackupDatabase: with self._lock: formatted = [ { - "id": str(a["id"]), + "id": parse_snowflake(a["id"]), "name": a.get("name"), "type": a.get("type"), "filename": a.get("filename"), @@ -407,7 +475,7 @@ class BackupDatabase: for rea in msg["reactions"]: all_reactions.append({ "message_id": msg["id"], - "emoji_id": str(rea["emoji_id"]) if rea.get("emoji_id") else None, + "emoji_id": parse_snowflake(rea["emoji_id"]) if rea.get("emoji_id") else None, "emoji_name": rea.get("emoji_name"), "count": rea.get("count", 0) }) @@ -417,7 +485,7 @@ class BackupDatabase: for st in msg["stickers"]: all_stickers.append({ "message_id": msg["id"], - "sticker_id": str(st["id"]), + "sticker_id": parse_snowflake(st["id"]), "name": st.get("name"), "url": st.get("url"), "format_type": st.get("format_type"), @@ -469,7 +537,7 @@ class BackupDatabase: def get_last_message_id(self, channel_id: str) -> Optional[str]: with self._lock: - row = self._conn.execute("SELECT id FROM messages WHERE channel_id = ? ORDER BY id DESC LIMIT 1", (str(channel_id),)).fetchone() + row = self._conn.execute("SELECT id FROM messages WHERE channel_id = ? ORDER BY id DESC LIMIT 1", (parse_snowflake(channel_id),)).fetchone() return row["id"] if row else None def get_media_by_hash(self, file_hash: str) -> Optional[Dict[str, Any]]: @@ -600,7 +668,7 @@ class BackupDatabase: """Returns forum tag definitions.""" with self._lock: if forum_id: - rows = self._conn.execute("SELECT * FROM forum_tags WHERE forum_id = ?", (str(forum_id),)).fetchall() + rows = self._conn.execute("SELECT * FROM forum_tags WHERE forum_id = ?", (parse_snowflake(forum_id),)).fetchall() else: rows = self._conn.execute("SELECT * FROM forum_tags").fetchall() return [dict(r) for r in rows] @@ -608,13 +676,13 @@ class BackupDatabase: def get_threads_by_parent(self, parent_id: str) -> List[Dict[str, Any]]: """Returns all threads belonging to a parent channel.""" with self._lock: - rows = self._conn.execute("SELECT * FROM threads WHERE parent_id = ?", (str(parent_id),)).fetchall() + rows = self._conn.execute("SELECT * FROM threads WHERE parent_id = ?", (parse_snowflake(parent_id),)).fetchall() return [dict(r) for r in rows] def get_thread(self, thread_id: str) -> Optional[Dict[str, Any]]: """Retrieves a single thread's metadata.""" with self._lock: - row = self._conn.execute("SELECT * FROM threads WHERE id = ?", (str(thread_id),)).fetchone() + row = self._conn.execute("SELECT * FROM threads WHERE id = ?", (parse_snowflake(thread_id),)).fetchone() return dict(row) if row else None def get_all_users(self) -> List[Dict[str, Any]]: @@ -624,7 +692,7 @@ class BackupDatabase: def get_user(self, user_id: str) -> Optional[Dict[str, Any]]: with self._lock: - row = self._conn.execute("SELECT * FROM users WHERE id = ?", (str(user_id),)).fetchone() + row = self._conn.execute("SELECT * FROM users WHERE id = ?", (parse_snowflake(user_id),)).fetchone() if row: data = dict(row) if data.get("roles"): @@ -650,11 +718,88 @@ class BackupDatabase: def get_messages_paged(self, channel_id: str, limit: int = 100, offset: int = 0, after_id: Optional[str] = None) -> List[Dict[str, Any]]: with self._lock: query = "SELECT * FROM messages WHERE channel_id = ?" - params = [str(channel_id)] + params = [parse_snowflake(channel_id)] if after_id: query += " AND id > ?" - params.append(str(after_id)) + params.append(parse_snowflake(after_id)) + + query += " ORDER BY id ASC LIMIT ? OFFSET ?" + params.extend([limit, offset]) + + rows = self._conn.execute(query, params).fetchall() + msg_list = [dict(r) for r in rows] + + if msg_list: + msg_ids = [m["id"] for m in msg_list] + placeholders = ",".join(["?"] * len(msg_ids)) + + att_rows = self._conn.execute(f"SELECT * FROM attachments WHERE message_id IN ({placeholders})", msg_ids).fetchall() + atts_by_msg = {} + for ar in att_rows: + mid = ar["message_id"] + if mid not in atts_by_msg: atts_by_msg[mid] = [] + atts_by_msg[mid].append(dict(ar)) + + emb_rows = self._conn.execute(f"SELECT * FROM embeds WHERE message_id IN ({placeholders})", msg_ids).fetchall() + embs_by_msg = {} + for er in emb_rows: + mid = er["message_id"] + if mid not in embs_by_msg: embs_by_msg[mid] = [] + + e_dict = { + "title": er["title"], + "description": er["description"], + "url": er["url"], + "color": er["color"], + "timestamp": er["timestamp"], + "thumbnail": {"url": er["thumbnail_url"]} if er["thumbnail_url"] else None, + "image": {"url": er["image_url"]} if er["image_url"] else None, + "author": { + "name": er["author_name"], + "url": er["author_url"], + "icon_url": er["author_icon_url"] + } if er["author_name"] else None, + "footer": { + "text": er["footer_text"], + "icon_url": er["footer_icon_url"] + } if er["footer_text"] else None, + "fields": json.loads(er["fields"]) if er["fields"] else [] + } + embs_by_msg[mid].append(e_dict) + + rea_rows = self._conn.execute(f"SELECT * FROM reactions WHERE message_id IN ({placeholders})", msg_ids).fetchall() + reas_by_msg = {} + for rr in rea_rows: + mid = rr["message_id"] + if mid not in reas_by_msg: reas_by_msg[mid] = [] + reas_by_msg[mid].append(dict(rr)) + + st_rows = self._conn.execute(f"SELECT * FROM message_stickers WHERE message_id IN ({placeholders})", msg_ids).fetchall() + sts_by_msg = {} + for sr in st_rows: + mid = sr["message_id"] + if mid not in sts_by_msg: sts_by_msg[mid] = [] + sts_by_msg[mid].append(dict(sr)) + + for m in msg_list: + m_id = m["id"] + m["attachments"] = atts_by_msg.get(m_id, []) + m["embeds"] = embs_by_msg.get(m_id, []) + m["reactions"] = reas_by_msg.get(m_id, []) + m["stickers"] = sts_by_msg.get(m_id, []) + + return msg_list + + def get_global_messages_paged(self, limit: int = 100, offset: int = 0, after_id: Optional[str] = None) -> List[Dict[str, Any]]: + """Fetches messages across ALL channels globally, ordered by timestamp/ID ascending.""" + with self._lock: + query = "SELECT * FROM messages" + params = [] + + if after_id: + query += " WHERE id > ?" + params.append(parse_snowflake(after_id)) query += " ORDER BY id ASC LIMIT ? OFFSET ?" params.extend([limit, offset]) @@ -725,7 +870,7 @@ class BackupDatabase: def delete_channel_messages(self, channel_id: Union[str, int]): """Deletes all messages and related metadata for a specific channel and its threads.""" - cid = str(channel_id) + cid = parse_snowflake(channel_id) with self._lock: # 1. Identify all channel IDs involved (parent + all threads) target_ids = [cid] @@ -807,4 +952,4 @@ class BackupDatabase: self._conn.commit() self._conn.close() except Exception: - pass + pass \ No newline at end of file diff --git a/src/core/backup_reader.py b/src/core/backup_reader.py index 408e16e..c9d250d 100644 --- a/src/core/backup_reader.py +++ b/src/core/backup_reader.py @@ -6,12 +6,13 @@ Discord API. Implements the same public interface as DiscordReader so that migration scripts and UI code can use either provider transparently. """ +from __future__ import annotations import json import logging from datetime import datetime, timezone from enum import IntEnum from pathlib import Path -from typing import AsyncGenerator, Dict, Any, List, Optional +from typing import AsyncGenerator, Dict, Any, List, Optional, Union from src.core.backup_database import BackupDatabase, parse_snowflake logger = logging.getLogger(__name__) @@ -1416,6 +1417,42 @@ class BackupReader: if len(msgs) < batch_size: break + async def fetch_global_message_history( + self, + limit: int = None, + after_id: int = None + ) -> AsyncGenerator["BackupMessage", None]: + """Yields BackupMessages globally from SQLite across all channels, natively ordered by timestamp/ID.""" + if not self.db: return + + offset = 0 + batch_size = 100 + count = 0 + + while True: + actual_limit = batch_size + if limit: + rem = limit - count + if rem <= 0: break + actual_limit = min(batch_size, rem) + + msgs = self.db.get_global_messages_paged( + limit=actual_limit, + offset=offset, + after_id=str(after_id) if after_id else None + ) + + if not msgs: + break + + for m in msgs: + yield self._hydrate_message(m) + count += 1 + + offset += len(msgs) + if len(msgs) < batch_size: + break + # ── download helpers ───────────────────────────────────────────────── async def download_emoji(self, emoji: BackupEmoji) -> bytes: diff --git a/src/core/base.py b/src/core/base.py index 87086b1..81943f8 100644 --- a/src/core/base.py +++ b/src/core/base.py @@ -129,7 +129,7 @@ class MigrationContext: # or a logical subfolder. base_dir = getattr(self, "base_dir", "") - self.state.set_folder(community_id, clean_name, base_dir=base_dir) + self.state.set_folder(community_id, clean_name, self.target_platform, base_dir=base_dir) async def start_connections(self): await self.discord_reader.start() diff --git a/src/core/database.py b/src/core/database.py index b2230b9..47578ea 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -1,11 +1,13 @@ +from __future__ import annotations import sqlite3 import logging import json import random from pathlib import Path -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Union import threading import sys +from src.core.utils import parse_snowflake logger = logging.getLogger(__name__) @@ -15,8 +17,9 @@ class MigrationDatabase: Replaces the memory-bloated and O(N^2) JSON persistence for messages. """ - def __init__(self, db_path: Path): + def __init__(self, db_path: Path, platform: str = None): self.db_path = db_path + self.platform = platform.lower() if platform else None self._local = threading.local() self._init_db() @@ -27,38 +30,118 @@ class MigrationDatabase: return self._local.conn def _init_db(self): - """Initialize tables if they don't exist.""" + """Initialize tables if they don't exist and handle migrations/platform-specific schemas.""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() + # 1. Determine active platform and column types + # Create metadata table first as we need it for platform tracking + cursor.execute("CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)") + + # Load platform from DB if exists + cursor.execute("SELECT value FROM metadata WHERE key = ?", ("target_platform",)) + row = cursor.fetchone() + stored_platform = row[0] if row else None + + # If platform provided, update stored platform. If not provided, use stored. + active_platform = self.platform or stored_platform or "stoat" # Default to stoat if unknown + if self.platform and self.platform != stored_platform: + cursor.execute("INSERT OR REPLACE INTO metadata (key, value) VALUES (?, ?)", ("target_platform", active_platform)) + conn.commit() + + # Define types + # source_type is always INTEGER (Discord/Fluxer snowflakes) + # target_type is INTEGER for Fluxer, TEXT for Stoat + source_type = "INTEGER" + target_type = "INTEGER" if active_platform == "fluxer" else "TEXT" + + # 2. Universal ID Migration (TEXT -> INTEGER vs Platform Switch) + # Mapping of table names to columns that must match their respective types + # key: table, value: (discord_cols, target_cols) + id_migrations = { + "message_mappings": (["source_msg_id"], ["channel_id", "target_msg_id"]), + "thread_mappings": (["source_msg_id"], ["channel_id", "thread_id", "target_msg_id"]), + "channel_tracking": ([], ["channel_id", "last_msg_id"]), + "thread_tracking": ([], ["channel_id", "thread_id", "last_msg_id"]), + "server_mappings": (["source_id"], ["target_id"]), + "asset_mappings": (["source_id"], ["target_id"]), + "user_alias": (["user_id"], []) + } + + for table, (discord_cols, target_cols) in id_migrations.items(): + cursor.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name='{table}'") + res = cursor.fetchone() + if not res or res[0] == 0: + continue + + cursor.execute(f"PRAGMA table_info({table})") + cols = cursor.fetchall() + needs_migration = False + for col in cols: + c_name, c_type = col[1], col[2] + if c_name in discord_cols and c_type != source_type: + needs_migration = True + break + if c_name in target_cols and c_type != target_type: + needs_migration = True + break + + if needs_migration: + logger.info(f"MigrationDatabase: Migrating {table} schema (Platform: {active_platform})") + cursor.execute(f"ALTER TABLE {table} RENAME TO {table}_old") + + if table == "message_mappings": + cursor.execute(f"CREATE TABLE message_mappings (channel_id {target_type}, source_msg_id {source_type}, target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, source_msg_id))") + elif table == "thread_mappings": + cursor.execute(f"CREATE TABLE thread_mappings (channel_id {target_type}, thread_id {target_type}, source_msg_id {source_type}, target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, thread_id, source_msg_id))") + elif table == "channel_tracking": + cursor.execute(f"CREATE TABLE channel_tracking (channel_id {target_type} PRIMARY KEY, last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0)") + elif table == "thread_tracking": + cursor.execute(f"CREATE TABLE thread_tracking (channel_id {target_type}, thread_id {target_type}, last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0, completed INTEGER DEFAULT 0, PRIMARY KEY (channel_id, thread_id))") + elif table == "server_mappings": + cursor.execute(f"CREATE TABLE server_mappings (category TEXT, source_id {source_type}, target_id {target_type}, PRIMARY KEY (category, source_id))") + elif table == "asset_mappings": + cursor.execute(f"CREATE TABLE asset_mappings (category TEXT, source_id {source_type}, target_id {target_type}, PRIMARY KEY (category, source_id))") + elif table == "user_alias": + cursor.execute(f"CREATE TABLE user_alias (user_id {source_type} PRIMARY KEY, alias TEXT UNIQUE)") + + old_cols = [c[1] for c in cursor.execute(f"PRAGMA table_info({table}_old)").fetchall()] + new_cols = [c[1] for c in cursor.execute(f"PRAGMA table_info({table})").fetchall()] + common_cols = [c for c in old_cols if c in new_cols] + col_str = ", ".join(common_cols) + + cursor.execute(f"INSERT OR IGNORE INTO {table} ({col_str}) SELECT {col_str} FROM {table}_old") + cursor.execute(f"DROP TABLE {table}_old") + + # Initial Creation / Ensure Schema Correctness # Table for message mappings: SourceID -> TargetID - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS message_mappings ( - channel_id TEXT, - source_msg_id TEXT, - target_msg_id TEXT, + channel_id {target_type}, + source_msg_id {source_type}, + target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, source_msg_id) ) """) # Table for thread mappings - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS thread_mappings ( - channel_id TEXT, - thread_id TEXT, - source_msg_id TEXT, - target_msg_id TEXT, + channel_id {target_type}, + thread_id {target_type}, + source_msg_id {source_type}, + target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, thread_id, source_msg_id) ) """) # Table for per-channel stats and tracking - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS channel_tracking ( - channel_id TEXT PRIMARY KEY, - last_msg_id TEXT, + channel_id {target_type} PRIMARY KEY, + last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0 @@ -66,11 +149,11 @@ class MigrationDatabase: """) # Table for per-thread stats - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS thread_tracking ( - channel_id TEXT, - thread_id TEXT, - last_msg_id TEXT, + channel_id {target_type}, + thread_id {target_type}, + last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0, @@ -79,32 +162,32 @@ class MigrationDatabase: ) """) - # Add completed column if it doesn't exist (backward compatibility for existing resumption DBs) + # Add completed column if it doesn't exist (backward compatibility) try: cursor.execute("ALTER TABLE thread_tracking ADD COLUMN completed INTEGER DEFAULT 0") except sqlite3.OperationalError: pass # Already exists # Table for server entity mappings (channels, roles, categories) - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS server_mappings ( category TEXT, - source_id TEXT, - target_id TEXT, + source_id {source_type}, + target_id {target_type}, PRIMARY KEY (category, source_id) ) """) # Table for asset mappings (emojis, stickers) - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS asset_mappings ( category TEXT, - source_id TEXT, - target_id TEXT, + source_id {source_type}, + target_id {target_type}, PRIMARY KEY (category, source_id) ) """) - + # Migrate old entity_mappings if it exists cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='entity_mappings'") if cursor.fetchone(): @@ -125,18 +208,10 @@ class MigrationDatabase: # Drop old table cursor.execute("DROP TABLE entity_mappings") - # Table for general metadata - cursor.execute(""" - CREATE TABLE IF NOT EXISTS metadata ( - key TEXT PRIMARY KEY, - value TEXT - ) - """) - # Table for auto-generated user aliases (user_id -> alias) - cursor.execute(""" + cursor.execute(f""" CREATE TABLE IF NOT EXISTS user_alias ( - user_id TEXT PRIMARY KEY, + user_id {source_type} PRIMARY KEY, alias TEXT UNIQUE ) """) @@ -152,17 +227,30 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO message_mappings (channel_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?)", - (channel_id, source_id, target_id, timestamp) + (str(channel_id), parse_snowflake(source_id), str(target_id), timestamp) ) conn.commit() - def get_target_message_id(self, channel_id: str, source_id: str) -> Optional[str]: + def get_target_message_id(self, channel_id: str, source_id: str) -> Optional[Union[str, int]]: conn = self._get_conn() row = conn.execute( "SELECT target_msg_id FROM message_mappings WHERE channel_id = ? AND source_msg_id = ?", - (channel_id, source_id) + (str(channel_id), parse_snowflake(source_id)) ).fetchone() - return row["target_msg_id"] if row else None + if row: + val = row["target_msg_id"] + return str(val) if self.platform == "stoat" else val + return None + + def get_all_message_mappings(self, channel_id: str) -> Dict[Union[str, int], Union[str, int]]: + conn = self._get_conn() + rows = conn.execute( + "SELECT source_msg_id, target_msg_id FROM message_mappings WHERE channel_id = ?", + (str(channel_id),) + ).fetchall() + if self.platform == "stoat": + return {str(row["source_msg_id"]): str(row["target_msg_id"]) for row in rows} + return {row["source_msg_id"]: row["target_msg_id"] for row in rows} # --- User Alias Methods --- @@ -205,7 +293,7 @@ class MigrationDatabase: conn = self._get_conn() # Check for existing alias - row = conn.execute("SELECT alias FROM user_alias WHERE user_id = ?", (str(user_id),)).fetchone() + row = conn.execute("SELECT alias FROM user_alias WHERE user_id = ?", (parse_snowflake(user_id),)).fetchone() if row: return row["alias"] @@ -215,7 +303,7 @@ class MigrationDatabase: new_alias = self._generate_alias() conn.execute( "INSERT INTO user_alias (user_id, alias) VALUES (?, ?)", - (str(user_id), new_alias) + (parse_snowflake(user_id), new_alias) ) conn.commit() return new_alias @@ -239,31 +327,36 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO server_mappings (category, source_id, target_id) VALUES (?, ?, ?)", - (category, str(source_id), str(target_id)) + (category, parse_snowflake(source_id), str(target_id)) ) conn.commit() - def get_server_mapping(self, category: str, source_id: str) -> Optional[str]: + def get_server_mapping(self, category: str, source_id: str) -> Optional[Union[str, int]]: conn = self._get_conn() row = conn.execute( "SELECT target_id FROM server_mappings WHERE category = ? AND source_id = ?", - (category, str(source_id)) + (category, parse_snowflake(source_id)) ).fetchone() - return row["target_id"] if row else None + if row: + val = row["target_id"] + return str(val) if self.platform == "stoat" else val + return None - def get_all_server_mappings(self, category: str) -> Dict[str, str]: + def get_all_server_mappings(self, category: str) -> Dict[Union[str, int], Union[str, int]]: conn = self._get_conn() rows = conn.execute( "SELECT source_id, target_id FROM server_mappings WHERE category = ?", (category,) ).fetchall() + if self.platform == "stoat": + return {str(row["source_id"]): str(row["target_id"]) for row in rows} return {row["source_id"]: row["target_id"] for row in rows} def delete_server_mapping(self, category: str, source_id: str): conn = self._get_conn() conn.execute( "DELETE FROM server_mappings WHERE category = ? AND source_id = ?", - (category, str(source_id)) + (category, parse_snowflake(source_id)) ) conn.commit() @@ -281,31 +374,36 @@ class MigrationDatabase: conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO asset_mappings (category, source_id, target_id) VALUES (?, ?, ?)", - (category, str(source_id), str(target_id)) + (category, parse_snowflake(source_id), str(target_id)) ) conn.commit() - def get_asset_mapping(self, category: str, source_id: str) -> Optional[str]: + def get_asset_mapping(self, category: str, source_id: str) -> Optional[Union[str, int]]: conn = self._get_conn() row = conn.execute( "SELECT target_id FROM asset_mappings WHERE category = ? AND source_id = ?", - (category, str(source_id)) + (category, parse_snowflake(source_id)) ).fetchone() - return row["target_id"] if row else None + if row: + val = row["target_id"] + return str(val) if self.platform == "stoat" else val + return None - def get_all_asset_mappings(self, category: str) -> Dict[str, str]: + def get_all_asset_mappings(self, category: str) -> Dict[Union[str, int], Union[str, int]]: conn = self._get_conn() rows = conn.execute( "SELECT source_id, target_id FROM asset_mappings WHERE category = ?", (category,) ).fetchall() + if self.platform == "stoat": + return {str(row["source_id"]): str(row["target_id"]) for row in rows} return {row["source_id"]: row["target_id"] for row in rows} def delete_asset_mapping(self, category: str, source_id: str): conn = self._get_conn() conn.execute( "DELETE FROM asset_mappings WHERE category = ? AND source_id = ?", - (category, str(source_id)) + (category, parse_snowflake(source_id)) ) conn.commit() @@ -332,76 +430,134 @@ class MigrationDatabase: def update_channel_tracking(self, channel_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0): conn = self._get_conn() # Initialize if missing - conn.execute("INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)", (channel_id,)) + conn.execute("INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)", (str(channel_id),)) if last_msg_id: - conn.execute("UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?", (last_msg_id, channel_id)) + conn.execute("UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?", (str(last_msg_id), str(channel_id))) if last_msg_ts: - conn.execute("UPDATE channel_tracking SET last_msg_ts = ? WHERE channel_id = ?", (last_msg_ts, channel_id)) + conn.execute("UPDATE channel_tracking SET last_msg_ts = ? WHERE channel_id = ?", (last_msg_ts, str(channel_id))) if msg_inc != 0 or file_inc != 0: conn.execute( "UPDATE channel_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ?", - (msg_inc, file_inc, channel_id) + (msg_inc, file_inc, str(channel_id)) ) conn.commit() def get_channel_tracking(self, channel_id: str) -> Dict[str, Any]: conn = self._get_conn() - row = conn.execute("SELECT * FROM channel_tracking WHERE channel_id = ?", (channel_id,)).fetchone() + row = conn.execute("SELECT * FROM channel_tracking WHERE channel_id = ?", (str(channel_id),)).fetchone() if row: return dict(row) return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} + + def get_global_min_last_message_id(self, all_mapped_ids: List[str]) -> Optional[int]: + """ + Returns the minimum last_msg_id successfully migrated across all mapped channels/threads. + If any mapped entity has no progress record, it is treated as ID 0. + Returns None only if NO progress has been made across ANY entity. + """ + if not all_mapped_ids: + return None + + conn = self._get_conn() + placeholders = ",".join(["?"] * len(all_mapped_ids)) + + # 1. Get last message IDs from channel tracking + c_rows = conn.execute(f"SELECT channel_id, last_msg_id FROM channel_tracking WHERE channel_id IN ({placeholders})", all_mapped_ids).fetchall() + c_map = {r["channel_id"]: r["last_msg_id"] for r in c_rows} + + # 2. Get last message IDs from thread tracking + t_rows = conn.execute(f"SELECT thread_id, last_msg_id FROM thread_tracking WHERE thread_id IN ({placeholders})", all_mapped_ids).fetchall() + t_map = {r["thread_id"]: r["last_msg_id"] for r in t_rows} + + # Combine maps + progress_map = {**c_map, **t_map} + + # 3. Aggregate IDs + ids = [] + has_any_progress = False + for mid in all_mapped_ids: + last_id = progress_map.get(mid) + if not last_id: + ids.append(0) # Unmigrated entity + else: + try: + ids.append(int(last_id)) + has_any_progress = True + except (ValueError, TypeError): + ids.append(0) + + if not has_any_progress: + return None + + return min(ids) + # Thread methods similar to channel methods def set_thread_message_mapping(self, channel_id: str, thread_id: str, source_id: str, target_id: str, timestamp: str = None): conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO thread_mappings (channel_id, thread_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?, ?)", - (channel_id, thread_id, source_id, target_id, timestamp) + (str(channel_id), str(thread_id), parse_snowflake(source_id), str(target_id), timestamp) ) conn.commit() - def get_target_thread_message_id(self, channel_id: str, thread_id: str, source_id: str) -> Optional[str]: + def get_target_thread_message_id(self, channel_id: str, thread_id: str, source_id: str) -> Optional[Union[str, int]]: conn = self._get_conn() row = conn.execute( "SELECT target_msg_id FROM thread_mappings WHERE channel_id = ? AND thread_id = ? AND source_msg_id = ?", - (channel_id, thread_id, source_id) + (str(channel_id), str(thread_id), parse_snowflake(source_id)) ).fetchone() - return row["target_msg_id"] if row else None + if row: + val = row["target_msg_id"] + return str(val) if self.platform == "stoat" else val + return None def update_thread_tracking(self, channel_id: str, thread_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0, completed: int = None): conn = self._get_conn() - conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (channel_id, thread_id)) + conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (str(channel_id), str(thread_id))) if last_msg_id: - conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_id, channel_id, thread_id)) + conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (str(last_msg_id), str(channel_id), str(thread_id))) if last_msg_ts: - conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, channel_id, thread_id)) + conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, str(channel_id), str(thread_id))) if completed is not None: - conn.execute("UPDATE thread_tracking SET completed = ? WHERE channel_id = ? AND thread_id = ?", (completed, channel_id, thread_id)) + conn.execute("UPDATE thread_tracking SET completed = ? WHERE channel_id = ? AND thread_id = ?", (completed, str(channel_id), str(thread_id))) if msg_inc != 0 or file_inc != 0: conn.execute( "UPDATE thread_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ? AND thread_id = ?", - (msg_inc, file_inc, channel_id, thread_id) + (msg_inc, file_inc, str(channel_id), str(thread_id)) ) conn.commit() def get_thread_tracking(self, channel_id: str, thread_id: str) -> Dict[str, Any]: conn = self._get_conn() - row = conn.execute("SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?", (channel_id, thread_id)).fetchone() + row = conn.execute("SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?", (str(channel_id), str(thread_id))).fetchone() if row: return dict(row) return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} + def get_all_channel_tracking_ids(self) -> Dict[str, str]: + """Returns a map of channel_id -> last_msg_id for all tracked channels.""" + conn = self._get_conn() + rows = conn.execute("SELECT channel_id, last_msg_id FROM channel_tracking WHERE last_msg_id IS NOT NULL").fetchall() + return {str(row["channel_id"]): str(row["last_msg_id"]) for row in rows} + + def get_all_thread_tracking_ids(self) -> Dict[str, str]: + """Returns a map of thread_id -> last_msg_id for all tracked threads.""" + conn = self._get_conn() + rows = conn.execute("SELECT thread_id, last_msg_id FROM thread_tracking WHERE last_msg_id IS NOT NULL").fetchall() + return {str(row["thread_id"]): str(row["last_msg_id"]) for row in rows} + def clear_channel_data(self, channel_id: str): """Purge all mappings and tracking data for a specific channel and its threads.""" conn = self._get_conn() - conn.execute("DELETE FROM message_mappings WHERE channel_id = ?", (channel_id,)) - conn.execute("DELETE FROM thread_mappings WHERE channel_id = ?", (channel_id,)) - conn.execute("DELETE FROM channel_tracking WHERE channel_id = ?", (channel_id,)) - conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (channel_id,)) + conn.execute("DELETE FROM message_mappings WHERE channel_id = ?", (str(channel_id),)) + conn.execute("DELETE FROM thread_mappings WHERE channel_id = ?", (str(channel_id),)) + conn.execute("DELETE FROM channel_tracking WHERE channel_id = ?", (str(channel_id),)) + conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (str(channel_id),)) conn.commit() logger.info(f"Cleared all tracking and mapping data for channel: {channel_id}") diff --git a/src/core/exporter.py b/src/core/exporter.py index e60aaf7..7ad8926 100644 --- a/src/core/exporter.py +++ b/src/core/exporter.py @@ -109,7 +109,7 @@ class DiscordExporter: "name": r.name, "color": r.color.value, "position": r.position, - "permissions": str(r.permissions.value), + "permissions": r.permissions.value, "hoist": 1 if r.hoist else 0, "mentionable": 1 if r.mentionable else 0 }) @@ -563,9 +563,10 @@ class DiscordExporter: }) # 5. Message data + from src.core.utils import parse_snowflake message_reference = None if msg.reference and msg.reference.message_id: - message_reference = str(msg.reference.message_id) + message_reference = parse_snowflake(msg.reference.message_id) # 5.5 Forwarded snapshots content = msg.content or "" diff --git a/src/core/state.py b/src/core/state.py index 0e8d181..eb816c5 100644 --- a/src/core/state.py +++ b/src/core/state.py @@ -1,7 +1,10 @@ import json import logging from pathlib import Path -from typing import Dict, Any, Optional +from typing import Dict, Optional, Any, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from src.core.database import MigrationDatabase logger = logging.getLogger(__name__) @@ -35,7 +38,14 @@ class MigrationState: if self.db: self.db.delete_server_mapping("channel", str(discord_id)) + def remove_target_channel_mapping(self, discord_id: int | str): + if self.db: + self.db.delete_server_mapping("channel", str(discord_id)) + def set_target_channel_id(self, discord_id: int | str, target_id: str, *args): + """Alias for set_channel_mapping to handle legacy calls.""" + self.set_channel_mapping(discord_id, target_id) + get_fluxer_channel_id = get_target_channel_id set_target_channel_mapping = set_channel_mapping @@ -55,6 +65,10 @@ class MigrationState: if self.db: self.db.delete_server_mapping("category", str(discord_id)) + def set_target_category_id(self, discord_id: int | str, target_id: str, *args): + """Alias for set_category_mapping to handle legacy calls.""" + self.set_category_mapping(discord_id, target_id) + get_fluxer_category_id = get_category_mapping get_target_category_id = get_category_mapping set_target_category_mapping = set_category_mapping @@ -75,6 +89,10 @@ class MigrationState: if self.db: self.db.delete_server_mapping("role", str(discord_id)) + def set_target_role_id(self, discord_id: int | str, target_id: str, *args): + """Alias for set_role_mapping to handle legacy calls.""" + self.set_role_mapping(discord_id, target_id) + get_fluxer_role_id = get_role_mapping get_target_role_id = get_role_mapping set_target_role_mapping = set_role_mapping @@ -119,23 +137,23 @@ class MigrationState: # --- Properties for backward compatibility --- @property - def channel_map(self) -> Dict[str, str]: + def channel_map(self) -> Dict[Union[str, int], Union[str, int]]: return self.db.get_all_server_mappings("channel") if self.db else {} @property - def category_map(self) -> Dict[str, str]: + def category_map(self) -> Dict[Union[str, int], Union[str, int]]: return self.db.get_all_server_mappings("category") if self.db else {} @property - def role_map(self) -> Dict[str, str]: + def role_map(self) -> Dict[Union[str, int], Union[str, int]]: return self.db.get_all_server_mappings("role") if self.db else {} @property - def emoji_map(self) -> Dict[str, str]: + def emoji_map(self) -> Dict[Union[str, int], Union[str, int]]: return self.db.get_all_asset_mappings("emoji") if self.db else {} @property - def sticker_map(self) -> Dict[str, str]: + def sticker_map(self) -> Dict[Union[str, int], Union[str, int]]: return self.db.get_all_asset_mappings("sticker") if self.db else {} @property @@ -153,7 +171,7 @@ class MigrationState: if self._ensure_db(): self.db.set_message_mapping(str(target_channel_id), str(discord_id), str(target_id)) - def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | None: + def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | int | None: if self._ensure_db(): return self.db.get_target_message_id(str(target_channel_id), str(discord_id)) return None @@ -161,7 +179,7 @@ class MigrationState: def set_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str): self.set_target_message_mapping(target_channel_id, discord_id, target_id) - def get_fluxer_message_id(self, target_channel_id: str, discord_id: str) -> str | None: + def get_fluxer_message_id(self, target_channel_id: str, discord_id: str) -> str | int | None: return self.get_target_message_id(target_channel_id, discord_id) def increment_stats(self, target_channel_id: str, messages: int = 1, files: int = 0): @@ -209,8 +227,34 @@ class MigrationState: def get_last_message_id(self, target_channel_id: str) -> str | None: if self._ensure_db(): - return self.db.get_channel_tracking(str(target_channel_id)).get("last_msg_id") + tracking = self.db.get_channel_tracking(str(target_channel_id)) + return tracking.get("last_msg_id") if tracking else None return None + + + def get_global_min_last_message_id(self, all_mapped_ids: List[str]) -> int | None: + """Returns the absolute minimum last_msg_id among the given list of mapped target IDs (channels and threads).""" + if self._ensure_db(): + return self.db.get_global_min_last_message_id(all_mapped_ids) + return None + + def set_waterfall_last_id(self, last_id: str | int): + if self.db: + self.db.set_metadata("waterfall_last_id", str(last_id)) + + def get_waterfall_last_id(self) -> int | None: + if self.db: + val = self.db.get_metadata("waterfall_last_id") + return int(val) if val else None + return None + + def get_all_last_message_ids(self) -> Dict[str, str]: + """Returns a combined map of channel_id/thread_id -> last_msg_id.""" + if self._ensure_db(): + c_map = self.db.get_all_channel_tracking_ids() + t_map = self.db.get_all_thread_tracking_ids() + return {**c_map, **t_map} + return {} def get_thread_last_message_id(self, target_channel_id: str, thread_id: str) -> str | None: if self._ensure_db(): @@ -258,7 +302,7 @@ class MigrationState: if self._ensure_db(): self.db.clear_channel_data(str(target_channel_id)) - def set_folder(self, server_id: str, clean_name: str, base_dir: Path | str = ""): + def set_folder(self, server_id: str, clean_name: str, platform: str = "stoat", base_dir: Path | str = ""): """ Initializes the SQLite database based on community name and ID. Filename: {name}-{id}.db (Flat structure) @@ -294,7 +338,7 @@ class MigrationState: from src.core.database import MigrationDatabase if self.db: self.db.close() - self.db = MigrationDatabase(db_path) + self.db = MigrationDatabase(db_path, platform=platform) logger.info(f"Initialized SQLite database at {db_path}") def get_user_alias(self, user_id: str) -> str | None: diff --git a/src/core/utils.py b/src/core/utils.py index 1375ef5..1167d47 100644 --- a/src/core/utils.py +++ b/src/core/utils.py @@ -1,14 +1,29 @@ +from typing import Any, Optional import re import logging -from src.core.state import MigrationState + +def parse_snowflake(value: Any) -> Optional[int]: + """Safely parses a Discord ID (Snowflake) from any input, handling 'None' strings.""" + if value is None: + return None + s = str(value).strip() + if not s or s.lower() == "none" or s == "NULL": + return None + try: + return int(s) + except ValueError: + return None logger = logging.getLogger(__name__) -def resolve_discord_links(content: str, state: MigrationState, platform: str, target_server_id: str) -> str: +def resolve_discord_links(content: str, state, platform: str, target_server_id: str) -> str: """ Finds Discord message/channel links and resolves them to the target platform if they have been migrated. """ + from src.core.state import MigrationState + if not isinstance(state, MigrationState): + logger.warning(f"resolve_discord_links: state is not MigrationState (type: {type(state)})") if not content: return content diff --git a/src/fluxer/emoji_stickers.py b/src/fluxer/emoji_stickers.py index f8de0ec..7a7e8fe 100644 --- a/src/fluxer/emoji_stickers.py +++ b/src/fluxer/emoji_stickers.py @@ -18,10 +18,10 @@ async def sync_assets_state(context: MigrationContext): fluxer_stickers = await context.fluxer_writer.client.get_guild_stickers(context.config.fluxer_server_id) # Build name -> id maps and ID sets for Fluxer for fast lookup - fluxer_emoji_map = {e.get("name"): str(e.get("id")) for e in fluxer_emojis if e.get("name")} - fluxer_sticker_map = {s.get("name"): str(s.get("id")) for s in fluxer_stickers if s.get("name")} - fluxer_emoji_ids = {str(e.get("id")) for e in fluxer_emojis} - fluxer_sticker_ids = {str(s.get("id")) for s in fluxer_stickers} + fluxer_emoji_map = {e.get("name"): e.get("id") for e in fluxer_emojis if e.get("name")} + fluxer_sticker_map = {s.get("name"): s.get("id") for s in fluxer_stickers if s.get("name")} + fluxer_emoji_ids = {e.get("id") for e in fluxer_emojis} + fluxer_sticker_ids = {s.get("id") for s in fluxer_stickers} updates = 0 removals = 0 diff --git a/src/fluxer/migrate_message.py b/src/fluxer/migrate_message.py index dc644cb..5a2c0ea 100644 --- a/src/fluxer/migrate_message.py +++ b/src/fluxer/migrate_message.py @@ -4,6 +4,7 @@ import re import json import io from typing import Callable, Awaitable, Dict, Any, List +from pathlib import Path try: from lottie.objects import Animation @@ -27,22 +28,25 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None, uid = int(match.group(1)) if anonymize_users and state: alias = state.get_user_alias(str(uid)) - return f"`@{alias}`" + return f"`@{alias}`" if alias else "`@Anonymized User`" + # 1. Try provided guild member = guild.get_member(uid) if member: return f"`@{member.display_name}`" - # 2. Try message's user_mentions + + # 2. Try provided user_mentions if user_mentions: - for u in user_mentions: - if u.id == uid: - return f"`@{getattr(u, 'display_name', u.name)}`" + m = next((u for u in user_mentions if u.id == uid), None) + if m: + return f"`@{m.display_name}`" + # 3. Try global cache via guild.client if hasattr(guild, 'client'): user = guild.client.get_user(uid) if user: return f"`@{user.name}`" - return f"`@Unknown User`" + return "`@Unknown User`" def replace_role(match): rid = int(match.group(1)) @@ -73,8 +77,8 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None, cid = int(match.group(1)) # 1. Check if channel is mapped in state - if channel_map and str(cid) in channel_map: - return f"<#{channel_map[str(cid)]}>" + if channel_map and cid in channel_map: + return f"<#{channel_map[cid]}>" # 2. Try to resolve channel name from pre-fetched names name = None @@ -100,7 +104,7 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None, def replace_emoji(match): animated = match.group(1) == "a" name = match.group(2) - eid = match.group(3) + eid = int(match.group(3)) if emoji_map and eid in emoji_map: target_eid = emoji_map[eid] @@ -539,8 +543,11 @@ async def migrate_messages( except Exception as e: logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}") + # Check for existing mapping to avoid duplicates when resuming + if context.state.get_target_message_id(target_channel_id, str(msg.id)): + continue + try: - # Check if this message is a reply reply_to_fluxer_id = None if msg.reference and msg.reference.message_id: reply_to_fluxer_id = context.state.get_fluxer_message_id(target_channel_id, str(msg.reference.message_id)) @@ -557,23 +564,22 @@ async def migrate_messages( if thread_name and stats["messages"] == 0: content = f"> <<< THREAD: **{thread_name}** >>>\n{content}" - # Get or generate alias + # Always ensure alias is created/retrieved to populate user_alias table alias = context.state.get_user_alias(str(msg.author.id)) - - if context.config.anonymize_users: - author_name = alias - avatar_url = f"https://api.dicebear.com/9.x/fun-emoji/jpg?seed={alias}" + + anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False + if anonymize_users: + author_name = alias or "Anonymized User" + author_avatar_url = None else: author_name = msg.author.display_name - avatar_url = str(msg.author.display_avatar.url) if msg.author.display_avatar.url else None - if avatar_url and not avatar_url.startswith("http"): - avatar_url = None + author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None logger.debug(f"Fluxer: Calling send_message for Discord ID {msg.id}") fluxer_msg_id = await context.fluxer_writer.send_message( channel_id=target_channel_id, author_name=author_name, - author_avatar_url=avatar_url, + author_avatar_url=author_avatar_url, content=content, timestamp=int(msg.created_at.timestamp()), files=files if files else None, @@ -663,3 +669,261 @@ async def migrate_messages( pass return stats + + +async def analyze_global_migration(context: MigrationContext, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None) -> Dict[str, int]: + """ + Scans the entire server history to count messages, threads, and attachments globally. + """ + stats = {"messages": 0, "threads": 0, "attachments": 0} + + # In global mode, thread messages are returned natively in timestamp order by global fetch if they're in the DB + # However we just count them if the fetcher yields them. + # Fetch global progress map to skip migrated messages efficiently + progress_map = context.state.get_all_last_message_ids() + + async for msg in context.discord_reader.fetch_global_message_history(after_id=after_message_id): + if not context.is_running: + break + + # Determine target channel to check for existing mapping + if not msg.channel: + continue + + target_channel_id = context.state.get_target_channel_id(str(msg.channel.id)) + if not target_channel_id: + continue + + # Efficient skip: if message ID is <= last migrated ID for this channel/thread + # This is the primary resume mechanism: wait until we pass the last migrated ID for this channel + last_id = progress_map.get(str(msg.channel.id)) + if last_id and msg.id <= int(last_id): + continue + + if msg.type not in [ + context.discord_reader.MESSAGE_TYPE_DEFAULT, + context.discord_reader.MESSAGE_TYPE_REPLY, + context.discord_reader.MESSAGE_TYPE_THREAD_STARTER, + context.discord_reader.MESSAGE_TYPE_FORWARD, + context.discord_reader.MESSAGE_TYPE_CHAT_INPUT_COMMAND, + context.discord_reader.MESSAGE_TYPE_CONTEXT_MENU_COMMAND, + context.discord_reader.MESSAGE_TYPE_POLL_RESULT, + context.discord_reader.MESSAGE_TYPE_AUTO_MODERATION_ACTION + ]: + continue + + stats["messages"] += 1 + stats["attachments"] += len(msg.attachments) + if hasattr(msg, 'thread') and msg.thread: + # We don't recursively traverse here, we just count the fact there is a thread + # The actual thread messages are also fetched by the global fetcher because they have their own timestamp/id + stats["threads"] += 1 + + if progress_callback and stats["messages"] % 100 == 0: + await progress_callback(stats) + + if progress_callback: + await progress_callback(stats) + + return stats + + +async def migrate_global_messages( + context: MigrationContext, + after_message_id: int | None = None, + inclusive: bool = False, + progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None +) -> Dict[str, Any]: + """ + Migrates messages across all channels chronologically. + """ + stats = { + "messages": 0, + "threads": 0, + "attachments": 0, + "last_message_content": "", + "last_message_author": "", + "first_message_url": None, + "last_message_url": None + } + + processed_threads = set() + logger.info("Starting Global Waterfall Migration for Fluxer...") + + # Keep track of active thread mapping natively to pass parent target IDs if needed + thread_to_target_channel = {} + + # Emojis and mapped users cache setup + emoji_map = context.state.emoji_map + db_media = context.discord_reader.db.get_all_media() if context.discord_reader.db else {} + target_server_id = getattr(context.fluxer_writer, "server_id", None) + + # Fetch global progress map to skip migrated messages efficiently + progress_map = context.state.get_all_last_message_ids() + + try: + async for msg in context.discord_reader.fetch_global_message_history(after_id=after_message_id): + if not context.is_running: + logger.warning("Global migration interrupted by user") + break + + if msg.type not in [ + context.discord_reader.MESSAGE_TYPE_DEFAULT, + context.discord_reader.MESSAGE_TYPE_REPLY, + context.discord_reader.MESSAGE_TYPE_THREAD_STARTER, + context.discord_reader.MESSAGE_TYPE_FORWARD, + context.discord_reader.MESSAGE_TYPE_CHAT_INPUT_COMMAND, + context.discord_reader.MESSAGE_TYPE_CONTEXT_MENU_COMMAND, + context.discord_reader.MESSAGE_TYPE_POLL_RESULT, + context.discord_reader.MESSAGE_TYPE_AUTO_MODERATION_ACTION + ]: + continue + + # Determine target channel + if not msg.channel: + continue + + target_channel_id = context.state.get_target_channel_id(str(msg.channel.id)) + if not target_channel_id: + continue + + # Efficient skip: if message ID is <= last migrated ID for this channel/thread + # This ensures we only resume a channel once we reach its last known progress point + last_id = progress_map.get(str(target_channel_id)) + if last_id and msg.id <= int(last_id): + continue + + # If it's a thread message, we need to handle it based on if it's the thread starter or a reply + parent_target_id = None + if hasattr(msg, 'thread') and msg.thread and msg.id == msg.thread.id: + processed_threads.add(msg.thread.id) + stats["threads"] += 1 + elif msg.channel.type in [11, 12]: # Thread channels + # It's a message IN a thread. + # In Fluxer, threads might just be linear messages or threaded replies depending on schema + # For basic migration we just send it to the parent mapped target channel. + # The parent mapped target channel ID should already be calculated correctly by get_target_channel_id (which returns mapped thread or parent channel) + pass + + # Formatting + files = [] + file_names = [] + + # Always ensure alias is created/retrieved to populate user_alias table + alias = context.state.get_user_alias(str(msg.author.id)) + + anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False + + if anonymize_users: + author_name = alias or "Anonymized User" + author_avatar_url = None + else: + author_name = msg.author.display_name + author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None + + for att in msg.attachments: + media_info = db_media.get(att.local_hash) if db_media else None + local_path = None + if media_info: + local_path = Path(media_info["local_path"]) + elif hasattr(att, 'read'): + # Fallback + pass + + if local_path and local_path.exists(): + files.append(local_path) + file_names.append(att.filename) + + content = msg.content or "" + + # Stickers + for sticker in msg.stickers: + sticker_name = sticker.name + sticker_url = sticker.url + + # Check for uploaded media pool logic first + s_hash = sticker.local_hash + sticker_file = None + s_media = db_media.get(s_hash) if db_media and s_hash else None + if s_media: + s_path = Path(s_media["local_path"]) + if s_path.exists(): + sticker_file = s_path + + content += f"\n[Sticker: {sticker_name}]" + if sticker_file: + files.append(sticker_file) + file_names.append(f"sticker_{sticker_name}.png") + + content = clean_mentions( + content=content, + guild=context.discord_reader.guild, + user_mentions=msg.mentions, + role_mentions=msg.role_mentions, + channel_mentions=msg.channel_mentions, + emoji_map=emoji_map, + channel_map=context.state.channel_map, + state=context.state, + target_server_id=target_server_id, + channel_names=context.channel_names if hasattr(context, 'channel_names') else None, + anonymize_users=anonymize_users + ) + + if not content and not files: + logger.debug(f"Message {msg.id} empty after processing, skipping.") + continue + + timestamp_int = int(msg.created_at.timestamp()) + + if msg.reference and msg.reference.message_id: + # Resolve the author of the message being replied to + source_ref_msg = await context.discord_reader.get_message(msg.channel.id, msg.reference.message_id) + if source_ref_msg and source_ref_msg.author: + ref_author_id = str(source_ref_msg.author.id) + if anonymize_users: + ref_name = context.state.get_user_alias(ref_author_id) or "Anonymized User" + else: + ref_name = source_ref_msg.author.display_name + content = f"`@{ref_name}`\n{content}" + else: + # Fallback if author cannot be resolved (e.g. deleted/missing from backup) + tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id) + if tgt_reply: + content = f"[Reply to {tgt_reply}]\n{content}" + + try: + fluxer_msg_id = await context.fluxer_writer.send_message( + channel_id=target_channel_id, + author_name=author_name, + author_avatar_url=author_avatar_url, + content=content, + files=files, + timestamp=timestamp_int, + embeds=msg.embeds + ) + + if fluxer_msg_id: + context.state.set_target_message_mapping(target_channel_id, msg.id, fluxer_msg_id) + context.state.update_last_message_id(target_channel_id, msg.id) + context.state.set_waterfall_last_id(msg.id) + stats["attachments"] += len(files) if files else 0 + + stats["messages"] += 1 + stats["last_message_content"] = content + stats["last_message_author"] = author_name + + if not stats["first_message_url"]: + stats["first_message_url"] = msg.jump_url + stats["last_message_url"] = msg.jump_url + + if progress_callback: + await progress_callback(stats) + + except Exception as e: + logger.error(f"Failed to process global message {msg.id}: {e}") + + except (KeyboardInterrupt, asyncio.CancelledError): + context.is_running = False + pass + + return stats diff --git a/src/fluxer/roles_permissions.py b/src/fluxer/roles_permissions.py index 8b80a67..972222d 100644 --- a/src/fluxer/roles_permissions.py +++ b/src/fluxer/roles_permissions.py @@ -15,8 +15,8 @@ async def sync_roles_state(context: MigrationContext): fluxer_roles = await context.fluxer_writer.client.get_guild_roles(context.config.fluxer_server_id) # Build name -> id maps and ID sets for Fluxer for fast lookup - fluxer_role_map = {r.get("name"): str(r.get("id")) for r in fluxer_roles if r.get("name")} - fluxer_role_ids = {str(r.get("id")) for r in fluxer_roles} + fluxer_role_map = {r.get("name"): r.get("id") for r in fluxer_roles if r.get("name")} + fluxer_role_ids = {r.get("id") for r in fluxer_roles} updates = 0 removals = 0 diff --git a/src/stoat/emoji_stickers.py b/src/stoat/emoji_stickers.py index bfa1b60..5b7a0ed 100644 --- a/src/stoat/emoji_stickers.py +++ b/src/stoat/emoji_stickers.py @@ -37,8 +37,7 @@ async def sync_assets_state(context: MigrationContext): if stoat_id: if stoat_id not in stoat_emoji_ids: - context.state.emoji_map.pop(discord_id, None) - context.state.save_state() + context.state.remove_emoji_mapping(discord_id) removals += 1 elif emoji.name in stoat_emoji_map: context.state.set_target_emoji_mapping(discord_id, stoat_emoji_map[emoji.name]) diff --git a/src/stoat/migrate_message.py b/src/stoat/migrate_message.py index c0861bd..130579f 100644 --- a/src/stoat/migrate_message.py +++ b/src/stoat/migrate_message.py @@ -4,6 +4,7 @@ import re import json import io from typing import Callable, Awaitable, Dict, Any, List +from pathlib import Path try: from lottie.objects import Animation @@ -27,22 +28,25 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None, uid = int(match.group(1)) if anonymize_users and state: alias = state.get_user_alias(str(uid)) - return f"`@{alias}`" + return f"`@{alias}`" if alias else "`@Anonymized User`" + # 1. Try provided guild member = guild.get_member(uid) if member: return f"`@{member.display_name}`" - # 2. Try message's user_mentions + + # 2. Try provided user_mentions if user_mentions: - for u in user_mentions: - if u.id == uid: - return f"`@{getattr(u, 'display_name', u.name)}`" + m = next((u for u in user_mentions if u.id == uid), None) + if m: + return f"`@{m.display_name}`" + # 3. Try global cache via guild.client if hasattr(guild, 'client'): user = guild.client.get_user(uid) if user: return f"`@{user.name}`" - return f"`@Unknown User`" + return "`@Unknown User`" def replace_role(match): rid = int(match.group(1)) @@ -543,6 +547,10 @@ async def migrate_messages( logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}") try: + # Check for existing mapping to avoid duplicates when resuming + if context.state.get_target_message_id(target_channel_id, str(msg.id)): + continue + # Check if this message is a reply reply_to_stoat_id = None if msg.reference and msg.reference.message_id: @@ -560,22 +568,24 @@ async def migrate_messages( if thread_name and stats["messages"] == 0: content = f"> <<< THREAD: **{thread_name}** >>>\n{content}" - # Get or generate alias + # Always ensure alias is created/retrieved to populate user_alias table alias = context.state.get_user_alias(str(msg.author.id)) - - if context.config.anonymize_users: - author_name = alias - avatar_url = f"https://api.dicebear.com/9.x/fun-emoji/jpg?seed={alias}" + + anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False + if anonymize_users: + author_name = alias or "Anonymized User" + author_avatar_url = None else: author_name = msg.author.display_name - avatar_url = str(msg.author.display_avatar.url) if msg.author.display_avatar.url else None - if avatar_url and not avatar_url.startswith("http"): - avatar_url = None + author_avatar_url = str(msg.author.display_avatar.url) if msg.author.display_avatar.url else None + if author_avatar_url and not author_avatar_url.startswith("http"): + author_avatar_url = None + logger.debug(f"Stoat: Calling send_message for Discord ID {msg.id}") stoat_msg_id = await context.stoat_writer.send_message( channel_id=target_channel_id, author_name=author_name, - author_avatar_url=avatar_url, + author_avatar_url=author_avatar_url, content=content, timestamp=int(msg.created_at.timestamp()), files=files if files else None, @@ -664,3 +674,240 @@ async def migrate_messages( pass return stats + + +async def analyze_global_migration(context: MigrationContext, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None) -> Dict[str, int]: + """ + Scans the entire server history to count messages, threads, and attachments globally. + """ + stats = {"messages": 0, "threads": 0, "attachments": 0} + + # Fetch global progress map to skip migrated messages efficiently + progress_map = context.state.get_all_last_message_ids() + + async for msg in context.discord_reader.fetch_global_message_history(after_id=after_message_id): + if not context.is_running: + break + + # Determine target channel to check for existing mapping + if not msg.channel: + continue + + target_channel_id = context.state.get_target_channel_id(str(msg.channel.id)) + if not target_channel_id: + continue + + # Efficient skip: if message ID is <= last migrated ID for this channel/thread + # This is the primary resume mechanism: wait until we pass the last migrated ID for this channel + last_id = progress_map.get(str(msg.channel.id)) + if last_id and msg.id <= int(last_id): + continue + + if msg.type not in [ + context.discord_reader.MESSAGE_TYPE_DEFAULT, + context.discord_reader.MESSAGE_TYPE_REPLY, + context.discord_reader.MESSAGE_TYPE_THREAD_STARTER, + context.discord_reader.MESSAGE_TYPE_FORWARD, + context.discord_reader.MESSAGE_TYPE_CHAT_INPUT_COMMAND, + context.discord_reader.MESSAGE_TYPE_CONTEXT_MENU_COMMAND, + context.discord_reader.MESSAGE_TYPE_POLL_RESULT, + context.discord_reader.MESSAGE_TYPE_AUTO_MODERATION_ACTION + ]: + continue + + stats["messages"] += 1 + stats["attachments"] += len(msg.attachments) + if hasattr(msg, 'thread') and msg.thread: + stats["threads"] += 1 + + if progress_callback and stats["messages"] % 100 == 0: + await progress_callback(stats) + + if progress_callback: + await progress_callback(stats) + + return stats + + +async def migrate_global_messages( + context: MigrationContext, + after_message_id: int | None = None, + inclusive: bool = False, + progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None +) -> Dict[str, Any]: + """ + Migrates messages across all channels chronologically to Stoat. + """ + stats = { + "messages": 0, + "threads": 0, + "attachments": 0, + "last_message_content": "", + "last_message_author": "", + "first_message_url": None, + "last_message_url": None + } + + processed_threads = set() + logger.info("Starting Global Waterfall Migration for Stoat...") + + emoji_map = context.state.emoji_map + db_media = context.discord_reader.db.get_all_media() if context.discord_reader.db else {} + # Fetch global progress map to skip migrated messages efficiently + progress_map = context.state.get_all_last_message_ids() + + try: + async for msg in context.discord_reader.fetch_global_message_history(after_id=after_message_id): + if not context.is_running: + logger.warning("Global migration interrupted by user") + break + + if msg.type not in [ + context.discord_reader.MESSAGE_TYPE_DEFAULT, + context.discord_reader.MESSAGE_TYPE_REPLY, + context.discord_reader.MESSAGE_TYPE_THREAD_STARTER, + context.discord_reader.MESSAGE_TYPE_FORWARD, + context.discord_reader.MESSAGE_TYPE_CHAT_INPUT_COMMAND, + context.discord_reader.MESSAGE_TYPE_CONTEXT_MENU_COMMAND, + context.discord_reader.MESSAGE_TYPE_POLL_RESULT, + context.discord_reader.MESSAGE_TYPE_AUTO_MODERATION_ACTION + ]: + continue + + # Determine target channel + if not msg.channel: + continue + + target_channel_id = context.state.get_target_channel_id(str(msg.channel.id)) + if not target_channel_id: + continue + + # Efficient skip: if message ID is <= last migrated ID for this channel/thread + # This ensures we only resume a channel once we reach its last known progress point + last_id = progress_map.get(str(target_channel_id)) + if last_id and msg.id <= int(last_id): + continue + + parent_target_id = None + if hasattr(msg, 'thread') and msg.thread and msg.id == msg.thread.id: + processed_threads.add(msg.thread.id) + stats["threads"] += 1 + elif msg.channel.type in [11, 12]: + pass + + # Formatting + files = [] + + # Always ensure alias is created/retrieved to populate user_alias table + alias = context.state.get_user_alias(str(msg.author.id)) + + anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False + + if anonymize_users: + author_name = alias or "Anonymized User" + author_avatar_url = None + else: + author_name = msg.author.display_name + author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None + + for att in msg.attachments: + media_info = db_media.get(att.local_hash) if db_media else None + local_path = None + if media_info: + local_path = Path(media_info["local_path"]) + + if local_path and local_path.exists(): + try: + with open(local_path, "rb") as f: + files.append({"filename": att.filename, "data": f.read()}) + except Exception as fe: + logger.error(f"Failed to read file {local_path}: {fe}") + + content = msg.content or "" + + for sticker in msg.stickers: + sticker_name = sticker.name + s_hash = sticker.local_hash + sticker_file = None + s_media = db_media.get(s_hash) if db_media and s_hash else None + if s_media: + s_path = Path(s_media["local_path"]) + if s_path.exists(): + sticker_file = s_path + + content += f"\n[Sticker: {sticker_name}]" + if sticker_file: + files.append(sticker_file) + file_names.append(f"sticker_{sticker_name}.png") + + content = clean_mentions( + content=content, + guild=context.discord_reader.guild, + user_mentions=msg.mentions, + role_mentions=msg.role_mentions, + channel_mentions=msg.channel_mentions, + emoji_map=emoji_map, + channel_map=context.state.channel_map, + state=context.state, + target_server_id=target_server_id, + channel_names=context.channel_names if hasattr(context, 'channel_names') else None, + anonymize_users=anonymize_users + ) + + if not content and not files: + logger.debug(f"Message {msg.id} empty after processing, skipping.") + continue + + timestamp_int = int(msg.created_at.timestamp()) + + if msg.reference and msg.reference.message_id: + # Resolve the author of the message being replied to + source_ref_msg = await context.discord_reader.get_message(msg.channel_id, msg.reference.message_id) + if source_ref_msg and source_ref_msg.author: + ref_author_id = str(source_ref_msg.author.id) + if anonymize_users: + ref_name = context.state.get_user_alias(ref_author_id) or "Anonymized User" + else: + ref_name = source_ref_msg.author.display_name + content = f"`@{ref_name}`\n{content}" + else: + tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id) + if tgt_reply: + content = f"[Reply to {tgt_reply}]\n{content}" + + try: + stoat_msg_id = await context.stoat_writer.send_message( + channel_id=target_channel_id, + author_name=author_name, + author_avatar_url=author_avatar_url, + content=content, + files=files, + timestamp=timestamp_int, + embeds=msg.embeds + ) + + if stoat_msg_id: + context.state.set_target_message_mapping(target_channel_id, msg.id, stoat_msg_id) + context.state.update_last_message_id(target_channel_id, msg.id) + context.state.set_waterfall_last_id(msg.id) + stats["attachments"] += len(files) if files else 0 + + stats["messages"] += 1 + stats["last_message_content"] = content + stats["last_message_author"] = author_name + + if not stats["first_message_url"]: + stats["first_message_url"] = msg.jump_url + stats["last_message_url"] = msg.jump_url + + if progress_callback: + await progress_callback(stats) + + except Exception as e: + logger.error(f"Failed to process global message {msg.id}: {e}") + + except (KeyboardInterrupt, asyncio.CancelledError): + context.is_running = False + pass + + return stats diff --git a/src/ui/shuttle_ops.py b/src/ui/shuttle_ops.py index 362ecc7..fc17583 100644 --- a/src/ui/shuttle_ops.py +++ b/src/ui/shuttle_ops.py @@ -160,6 +160,7 @@ class OperationPane(Container): yield Button("Clone Server Template", id="op_clone", disabled=True, tooltip="Clone server roles, categories, and channels to the target community") yield Button("Sync Server Settings", id="op_sync", disabled=True, tooltip="Sync emojis, stickers, server name, and icon to the target community") yield Button("Migrate Message History", id="op_messages", disabled=True, variant="primary", tooltip="Migrate message history from Discord to the target platform") + yield Button("Waterfall Migration", id="op_waterfall", disabled=True, variant="primary", tooltip="Migrate all messages globally in chronological order to prevent broken links.\n(Available for Local Backups)") yield Rule(id="footer_rule") yield Button("Danger Zone ⚠", id="op_danger", variant="error", disabled=True, flat=True, tooltip="Dangerous operations:\ndelete channels, roles, emojis on target\n(use with caution)") @@ -394,7 +395,7 @@ class OperationPane(Container): lbl.update(f"{t_status}") # Buttons - for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_danger"): + for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger"): for btn in self.query(bid): btn.disabled = not self.tokens_valid # ── validation ──────────────────────────────────────────────────────── @@ -415,7 +416,7 @@ class OperationPane(Container): # Disable all operation buttons while validation is in progress if self.view_mode == "shuttle": - for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_danger"): + for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger"): for btn in self.query(bid): btn.disabled = True elif self.view_mode == "backup": for bid in ("#op_backup_msgs", "#op_backup_sync"): @@ -551,8 +552,13 @@ class OperationPane(Container): else: target_ok = v.get("target_token") and v.get("target_community") self.tokens_valid = bool(discord_ok and target_ok) - - + + # Post validation adjustments + if self.tokens_valid: + is_backup = (self.config.tool_mode == "backup_transfer") + for btn in self.query("#op_waterfall"): + btn.disabled = not is_backup + btn.display = is_backup self._update_info_labels() @@ -570,6 +576,8 @@ class OperationPane(Container): self._open_sync_menu() elif bid == "op_messages": self.run_migrate_messages() + elif bid == "op_waterfall": + self.run_waterfall_migration() elif bid == "op_danger": self._open_danger_menu() @@ -716,7 +724,6 @@ class OperationPane(Container): return elif choice == "btn_main_menu": modal.dismiss() - self.app.switch_screen("config_selection") return force_mode = (choice == "btn_start_id") @@ -798,7 +805,6 @@ class OperationPane(Container): return elif choice == "btn_main_menu": modal.dismiss() - self.app.switch_screen("config_selection") return force_mode = (choice == "btn_start_id") @@ -1099,6 +1105,11 @@ class OperationPane(Container): source_channel = next(c for c in d_channels if c.id == src_id) target_channel = next(c for c in f_channels if c.get("id") == tgt_id) + # 2. Analyze + modal = ProgressScreen(log_level=self.config.log_level) + self.app.push_screen(modal) + await asyncio.sleep(0.1) + # Determine after_id status (skip for pending channels) if pending_create_name: last_migrated = None @@ -1106,11 +1117,6 @@ class OperationPane(Container): else: last_migrated = self.engine.state.get_last_message_id(str(target_channel.get('id'))) has_previous = bool(last_migrated) - - # Analyze - modal = ProgressScreen(log_level=self.config.log_level) - self.app.push_screen(modal) - await asyncio.sleep(0.1) src_server = getattr(self.engine.discord_reader, 'guild', None) tgt_server_info = await self.engine.writer.validate() @@ -1243,7 +1249,6 @@ class OperationPane(Container): continue # Return to channel picker elif choice == "btn_main_menu": modal.dismiss() - self.app.switch_screen("config_selection") self.engine.is_running = False await self.engine.close_connections() return @@ -1294,6 +1299,14 @@ class OperationPane(Container): try: self.engine.is_running = True + + # Ensure state is initialized (database exists) + if self.target_platform == "stoat": + tid = self.config.stoat_server_id + else: + tid = self.config.fluxer_server_id + self.engine.ensure_state_initialized(str(tid or ""), platform_name) + stats_analysis = await migrate_mod.analyze_migration( self.engine, source_channel_id=source_channel.id, @@ -1399,6 +1412,239 @@ class OperationPane(Container): else: modal.write(f"[bold red]Error: {err}[/bold red]") modal.phase_report("Message Migration", "error", show_back=False) + logger.error(f"Migration Error: {traceback.format_exc()}") + finally: + self.engine.is_running = False + await self.engine.close_connections() + + @work(exclusive=True) + async def run_waterfall_migration(self) -> None: + if not self.tokens_valid: + return + + migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate + platform_name = self.target_platform.capitalize() + + modal = ProgressScreen(log_level=self.config.log_level) + self.app.push_screen(modal) + await asyncio.sleep(0.1) + + try: + modal.show_info("[bold cyan]Waterfall Migration Ready[/bold cyan]", "Checking mapping and missing channels...") + modal.set_status("Connecting to Servers...") + await self.engine.start_connections() + + modal.set_status("Synchronizing entity mappings...") + await self._perform_auto_matching() + + # 1. Missing channels check + if hasattr(self.engine.discord_reader, "get_all_channels"): + full_d = await self.engine.discord_reader.get_all_channels() + # Include TEXT (0), CATEGORY (4), and NEWS (5) + d_channels = [c for c in full_d if c.type in [0, 4, 5]] + else: + full_d = await self.engine.discord_reader.get_channels() + d_channels = [c for c in full_d if c.type in [0, 4, 5]] + missing_channels = [] + for d in d_channels: + if d.type == 4: + tgt_id = self.engine.state.get_target_category_id(str(d.id)) + else: + tgt_id = self.engine.state.get_target_channel_id(str(d.id)) + if not tgt_id: + missing_channels.append(d) + + if missing_channels: + modal.write(f"\n[bold yellow]Found {len(missing_channels)} backed-up channels/categories missing from target platform:[/bold yellow]") + for mc in missing_channels: + prefix = "[bold cyan]📁[/bold cyan] " if mc.type == 4 else "[bold white]#[/bold white] " + modal.write(f" {prefix}{mc.name}") + + choice = await modal.phase_wait_confirm( + show_continue=False, + show_id=True, + btn_start_label="Clone missing channels", + btn_id_label="Skip missing channels", + btn_start_variant="primary", + btn_start_tooltip=f"Automatically create {len(missing_channels)} entities on target", + btn_id_tooltip="Start migration without these channels" + ) + + if choice == "btn_back": + modal.dismiss() + await self.engine.close_connections() + return + elif choice == "btn_main_menu": + modal.dismiss() + await self.engine.close_connections() + return + + if choice == "btn_start_first": + modal.set_status("Cloning missing categories and channels...") + # Sort so categories (type 4) come first + missing_channels.sort(key=lambda x: 0 if x.type == 4 else 1) + + for mc in missing_channels: + try: + parent_target_id = None + if mc.type == 4: + modal.write(f"Creating category '[bold cyan]{mc.name}[/bold cyan]'...") + new_id = await self.engine.writer.create_channel(name=mc.name, type=4) + self.engine.state.set_target_category_mapping(str(mc.id), new_id) + modal.write(f"[green]Created Category {mc.name} ({new_id})[/green]") + else: + if hasattr(mc, 'category_id') and mc.category_id: + parent_target_id = self.engine.state.get_target_category_id(str(mc.category_id)) + + modal.write(f"Creating channel '#{mc.name}'...") + new_id = await self.engine.writer.create_channel(name=mc.name, parent_id=parent_target_id) + self.engine.state.set_target_channel_id(str(mc.id), new_id, self.engine.target_platform) + modal.write(f"[green]Created Channel {mc.name} ({new_id})[/green]") + except Exception as e: + logger.error(f"Failed to create {mc.name}: {e}\n{traceback.format_exc()}") + modal.write(f"[red]Failed to create {mc.name}: {e}[/red]") + elif choice == "btn_id": + # Skip missing channels: remove them from the active list + missing_ids = {str(c.id) for c in missing_channels} + d_channels = [c for c in d_channels if str(c.id) not in missing_ids] + + # 2. Resumption check + all_mapped_tgt_ids = [] + # Check regular text channels (exclude categories for resume check) + for c in d_channels: + if c.type == 4: continue + did = str(c.id) + tid = self.engine.state.get_target_channel_id(did) + if tid: all_mapped_tgt_ids.append(tid) + + # Also check threads (filtering to only include those belonging to active channels) + active_channel_ids = {str(c.id) for c in d_channels} + if hasattr(self.engine.discord_reader, "get_active_threads"): + threads = await self.engine.discord_reader.get_active_threads() + for t in threads: + pid = str(getattr(t, 'parent_id', getattr(t, 'channel_id', None))) + if pid not in active_channel_ids: continue + tid = self.engine.state.get_target_channel_id(str(t.id)) + if tid: all_mapped_tgt_ids.append(tid) + + # 2.5 Filter by actual content (Only for BackupReader) + # If a channel has NO messages in the backup, it will always be at 0 progress. + # We exclude those from the global MIN calculation to avoid pulling it to 0. + if hasattr(self.engine.discord_reader, "get_backed_up_channel_ids"): + backed_up_src_ids = await self.engine.discord_reader.get_backed_up_channel_ids() + backed_up_src_ids_str = {str(sid) for sid in backed_up_src_ids} + + filtered_tgt_ids = [] + # Find which target IDs belong to source channels that HAVE messages + for c in d_channels: # (d_channels is already filtered for skipped) + if str(c.id) in backed_up_src_ids_str: + tid = self.engine.state.get_target_channel_id(str(c.id)) + if tid: filtered_tgt_ids.append(tid) + + # Also check threads + if hasattr(self.engine.discord_reader, "threads"): + for t in self.engine.discord_reader.threads: + if str(t.id) in backed_up_src_ids_str: + tid = self.engine.state.get_target_channel_id(str(t.id)) + if tid: filtered_tgt_ids.append(tid) + + if filtered_tgt_ids: + all_mapped_tgt_ids = filtered_tgt_ids + + # 2.6 Resume Point: Prioritize Global waterfall tracker, fallback to channel minimums + min_last_id = self.engine.state.get_waterfall_last_id() + if min_last_id is None: + min_last_id = self.engine.state.get_global_min_last_message_id(all_mapped_tgt_ids) + + modal.write(f"\n[bold cyan]Waterfall Migration Resume Point:[/bold cyan]") + if min_last_id is not None: + modal.write(f"Minimum unmigrated message ID found: [green]{min_last_id}[/green]") + else: + modal.write("No previous migration state found. Starting from the beginning.") + + choice = await modal.phase_wait_confirm( + show_continue=min_last_id is not None, + show_id=False, + btn_start_label="Start From Beginning", + btn_start_tooltip="Safe, skips duplicates automatically", + btn_start_variant="default" if min_last_id is not None else "primary", + btn_continue_label=f"Continue from ID {min_last_id if min_last_id is not None else 0}" if min_last_id is not None else "Continue Migration", + btn_continue_tooltip="Fastest" + ) + + if choice == "btn_back": + modal.dismiss() + await self.engine.close_connections() + return + elif choice == "btn_main_menu": + modal.dismiss() + await self.engine.close_connections() + return + + after_id = None + if choice == "btn_continue" and min_last_id is not None: + after_id = int(min_last_id) + + # Phase 3: Progress + modal.cancel_callback = lambda: setattr(self.engine, "is_running", False) + modal.phase_progress() + modal.set_status("Migrating messages Globally...") + + self.engine.is_running = True + + # Ensure state is initialized (database exists) + if self.target_platform == "stoat": + tid = self.config.stoat_server_id + else: + tid = self.config.fluxer_server_id + self.engine.ensure_state_initialized(str(tid or ""), platform_name) + + modal.write("Scanning global footprint for totals ...") + stats_analysis = await migrate_mod.analyze_global_migration(self.engine, after_message_id=after_id) + total_messages = stats_analysis["messages"] + + modal.write(f"[bold cyan]Global Migration Started:[/bold cyan] {total_messages} total messages to process.") + modal.update_stats(messages=f"0/{total_messages}", threads=str(stats_analysis["threads"]), files=str(stats_analysis["attachments"])) + + async def update_msg(current_stats): + c_msgs = current_stats["messages"] + c_threads = current_stats["threads"] + c_files = current_stats["attachments"] + + msg_stat = f"{c_msgs}/{total_messages}" if total_messages > 0 else str(c_msgs) + modal.set_progress(c_msgs, total_messages or 100) + modal.update_stats(messages=msg_stat, threads=str(c_threads), files=str(c_files)) + + content = current_stats.get("last_message_content", "") + author = current_stats.get("last_message_author", "Unknown") + if content: + disp_content = (content[:100] + '...') if len(content) > 100 else content + modal.write(f"[bold]{author}:[/bold] {disp_content}") + + result = await migrate_mod.migrate_global_messages( + self.engine, + after_message_id=after_id, + inclusive=False, + progress_callback=update_msg, + ) + + if self.engine.is_running: + modal.write(f"[bold green]Success! {result['messages']} messages migrated globally.[/bold green]") + modal.phase_report("Waterfall Migration", show_back=False) + else: + modal.write(f"[bold yellow]Interrupted! {result['messages']} messages migrated.[/bold yellow]") + modal.phase_report("Waterfall Migration", "stopped", show_back=False) + + lines = [f"Migrated Server Globally → {platform_name}:"] + lines.append(f"{result['messages']} messages, {result['attachments']} attachments, {result['threads']} threads") + await log_audit_event(self.engine, "Waterfall Migration", "\n".join(lines)) + + except Exception as e: + err = str(e) + modal.write(f"[bold red]Error: {err}[/bold red]") + modal.phase_report("Waterfall Migration", "error", show_back=False) + + logger.error(traceback.format_exc()) finally: self.engine.is_running = False await self.engine.close_connections() @@ -1485,7 +1731,6 @@ class OperationPane(Container): return elif choice == "btn_main_menu": modal.dismiss() - self.app.switch_screen("config_selection") return modal.cancel_callback = lambda: setattr(self.engine, "is_running", False) @@ -1644,6 +1889,39 @@ class OperationPane(Container): logger.warning(f"Auto-matching: failed to fetch target data: {e}") return # Cannot match without target data + # 1.5 Cleanup deleted entities from mapping database + # This prevents "Ghost" mappings to channels/roles that were deleted on target + valid_chan_ids = {str(c.get("id")) for c in target_chans_raw} + valid_cat_ids = {str(c.get("id")) for c in target_chans_raw if c.get("type") == 4} + valid_role_ids = set(target_roles_map.values()) + valid_emoji_ids = set(target_emojis_map.values()) + + # Channels + for src_id, tgt_id in self.engine.state.channel_map.items(): + if str(tgt_id) not in valid_chan_ids: + logger.info(f"Auto-matching: clearing deleted channel mapping {src_id} -> {tgt_id}") + self.engine.state.remove_target_channel_mapping(src_id) + + # Categories + for src_id, tgt_id in self.engine.state.category_map.items(): + if str(tgt_id) not in valid_cat_ids: + logger.info(f"Auto-matching: clearing deleted category mapping {src_id} -> {tgt_id}") + self.engine.state.remove_category_mapping(src_id) + + # Roles + for src_id, tgt_id in self.engine.state.role_map.items(): + if str(tgt_id) not in valid_role_ids: + logger.info(f"Auto-matching: clearing deleted role mapping {src_id} -> {tgt_id}") + self.engine.state.remove_role_mapping(src_id) + + # Emojis + for src_id, tgt_id in self.engine.state.emoji_map.items(): + if str(tgt_id) not in valid_emoji_ids: + # Emojis might be URLs in some platforms, but we check if they are IDs first + if isinstance(tgt_id, str) and tgt_id.isdigit(): + logger.info(f"Auto-matching: clearing deleted emoji mapping {src_id} -> {tgt_id}") + self.engine.state.remove_emoji_mapping(src_id) + # 2. Match entities try: # Roles @@ -1687,6 +1965,7 @@ class OperationPane(Container): logger.info(f"Auto-matched Sticker: {s.name} -> {target_stickers_map[name_l]}") self.engine.state.set_target_sticker_mapping(s.id, target_stickers_map[name_l]) except Exception as e: + logger.error(f"Auto-matching error: {e}\n{traceback.format_exc()}") logger.warning(f"Auto-matching error: {e}") return { @@ -1951,7 +2230,6 @@ class OperationPane(Container): after_id = verified_id elif choice == "btn_main_menu": modal_prog.dismiss() - self.app.switch_screen("config_selection") return # If we are here, proceeding either via Start First or Start from ID (after_id) @@ -2062,7 +2340,7 @@ class OperationPane(Container): self.engine.is_running = False await self.engine.close_connections() if choice == "btn_main_menu": - self.app.switch_screen("config_selection") + pass return modal_prog.cancel_callback = lambda: setattr(self.engine, "is_running", False)