Merge pull request #8 from rambros3d/sql-db

Switch to INTEGER for snowflake ids
This commit is contained in:
RamBros 2026-03-28 02:52:11 +05:30 committed by GitHub
commit c34d509677
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 354 additions and 168 deletions

View file

@ -4,20 +4,11 @@ import json
import threading import threading
from pathlib import Path from pathlib import Path
from typing import Dict, Any, List, Optional, Union from typing import Dict, Any, List, Optional, Union
from src.core.utils import parse_snowflake
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def parse_snowflake(value: Any) -> Optional[int]:
    """Turn any input into an integer Discord ID (snowflake), or ``None``.

    ``None``, blank text, the word "none" in any letter case, the literal
    "NULL", and anything ``int()`` refuses to parse all produce ``None``.
    Surrounding whitespace is ignored.
    """
    text = "" if value is None else str(value).strip()
    if text == "" or text == "NULL" or text.lower() == "none":
        return None
    try:
        parsed = int(text)
    except ValueError:
        return None
    return parsed
class BackupDatabase: class BackupDatabase:
"""Manages the SQLite database for local Discord backups.""" """Manages the SQLite database for local Discord backups."""
@ -39,24 +30,101 @@ class BackupDatabase:
def _migrate_db(self):
    """Upgrade an existing database file to the current schema in place.

    Two migrations run while holding ``self._lock``:

    1. ``mime_type`` is renamed to ``content_type`` in ``media_pool`` and
       ``server_assets`` when only the old column name is present.
    2. Every table whose snowflake columns are still declared ``TEXT`` is
       rebuilt with those columns declared ``INTEGER`` and its rows copied
       across.

    Tables that do not exist yet are skipped. Each rebuild runs inside its
    own savepoint, so a failure restores that table and the exception
    propagates to the caller.
    """
    # table -> (schema used for the rebuilt table, columns holding snowflakes).
    # NOTE(review): these statements duplicate the CREATE TABLE definitions in
    # _init_db and must be kept in sync with them.
    integer_id_schemas = {
        "guild_profile": (
            "CREATE TABLE guild_profile (id INTEGER PRIMARY KEY, name TEXT, description TEXT, icon_file TEXT, icon_url TEXT, banner_file TEXT, banner_url TEXT, owner_id INTEGER, last_backup TEXT, ignore_channels TEXT)",
            ("id", "owner_id"),
        ),
        "roles": (
            "CREATE TABLE roles (id INTEGER PRIMARY KEY, name TEXT, color INTEGER, position INTEGER, permissions INTEGER, hoist INTEGER, mentionable INTEGER)",
            ("id", "permissions"),
        ),
        "channels": (
            "CREATE TABLE channels (id INTEGER PRIMARY KEY, name TEXT, type INTEGER, position INTEGER, category_id INTEGER, topic TEXT, nsfw INTEGER, bitrate INTEGER, slowmode_delay INTEGER)",
            ("id", "category_id"),
        ),
        "permissions": (
            "CREATE TABLE permissions (id INTEGER PRIMARY KEY AUTOINCREMENT, channel_id INTEGER, target_id INTEGER, target_type TEXT, allow INTEGER, deny INTEGER)",
            ("channel_id", "target_id"),
        ),
        "users": (
            "CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT, display_name TEXT, avatar_file TEXT, avatar_url TEXT, roles TEXT)",
            ("id",),
        ),
        "messages": (
            "CREATE TABLE messages (id INTEGER PRIMARY KEY, channel_id INTEGER, author_id INTEGER, content TEXT, timestamp TEXT, type INTEGER, message_reference INTEGER, is_pinned INTEGER, extra_data TEXT)",
            ("id", "channel_id", "author_id", "message_reference"),
        ),
        "attachments": (
            "CREATE TABLE attachments (id INTEGER PRIMARY KEY, message_id INTEGER, filename TEXT, size INTEGER, url TEXT, content_type TEXT, local_hash TEXT)",
            ("id", "message_id"),
        ),
        "embeds": (
            "CREATE TABLE embeds (id INTEGER PRIMARY KEY AUTOINCREMENT, message_id INTEGER, title TEXT, description TEXT, url TEXT, color INTEGER, timestamp TEXT, thumbnail_url TEXT, image_url TEXT, author_name TEXT, author_url TEXT, author_icon_url TEXT, footer_text TEXT, footer_icon_url TEXT, fields TEXT)",
            ("message_id",),
        ),
        "reactions": (
            "CREATE TABLE reactions (id INTEGER PRIMARY KEY AUTOINCREMENT, message_id INTEGER, emoji_id INTEGER, emoji_name TEXT, count INTEGER)",
            ("message_id", "emoji_id"),
        ),
        "message_stickers": (
            "CREATE TABLE message_stickers (message_id INTEGER, sticker_id INTEGER, name TEXT, url TEXT, format_type INTEGER, local_hash TEXT, PRIMARY KEY (message_id, sticker_id))",
            ("message_id", "sticker_id"),
        ),
        "threads": (
            "CREATE TABLE threads (id INTEGER PRIMARY KEY, name TEXT, type INTEGER, parent_id INTEGER, message_count INTEGER, member_count INTEGER, archived INTEGER, archive_timestamp TEXT, auto_archive_duration INTEGER, locked INTEGER, applied_tags TEXT)",
            ("id", "parent_id"),
        ),
        "forum_tags": (
            "CREATE TABLE forum_tags (id INTEGER PRIMARY KEY, forum_id INTEGER, name TEXT, moderated INTEGER, emoji_id INTEGER, emoji_name TEXT)",
            ("id", "forum_id", "emoji_id"),
        ),
        # content_type holds the renamed mime_type text, so it stays TEXT.
        "server_assets": (
            "CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type TEXT)",
            ("id",),
        ),
    }

    def table_exists(conn, name):
        row = conn.execute(
            "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?",
            (name,),
        ).fetchone()
        return bool(row and row[0] > 0)

    with self._lock:
        conn = self._conn

        # 1. mime_type -> content_type. This must run before the rebuilds so
        #    that server_assets.content_type is among the copied columns.
        for table in ("media_pool", "server_assets"):
            if not table_exists(conn, table):
                continue
            # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
            col_names = [c[1] for c in conn.execute(f"PRAGMA table_info({table})").fetchall()]
            if "mime_type" in col_names and "content_type" not in col_names:
                logger.info(f"Migrating {table}: renaming 'mime_type' to 'content_type'")
                conn.execute(f"ALTER TABLE {table} RENAME COLUMN mime_type TO content_type")

        # 2. Snowflake columns: TEXT -> INTEGER.
        for table, (create_sql, id_cols) in integer_id_schemas.items():
            if not table_exists(conn, table):
                continue
            declared = conn.execute(f"PRAGMA table_info({table})").fetchall()
            still_text = any(
                c[1] in id_cols and str(c[2]).upper() == "TEXT" for c in declared
            )
            if not still_text:
                continue
            logger.info(f"Migrating {table}: converting ID columns to INTEGER")
            self._rebuild_table_with_integer_ids(conn, table, create_sql, id_cols)

        conn.commit()

def _rebuild_table_with_integer_ids(self, conn, table, create_sql, id_cols):
    """Recreate ``table`` from ``create_sql`` and copy the columns both versions share.

    The old table is renamed to ``<table>_old``, the new one is created, the
    rows are copied and the old table is dropped. All of this happens inside
    a savepoint; on any error the savepoint is rolled back, which restores
    the original table, and the error is re-raised.

    NOTE(review): indexes on the old table are dropped together with it;
    confirm they are recreated elsewhere (e.g. by _init_db).
    """
    savepoint = f"migrate_{table}"
    conn.execute(f"SAVEPOINT {savepoint}")
    try:
        conn.execute(f"ALTER TABLE {table} RENAME TO {table}_old")
        conn.execute(create_sql)

        old_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table}_old)").fetchall()]
        new_info = conn.execute(f"PRAGMA table_info({table})").fetchall()
        new_cols = {c[1] for c in new_info}
        pk_cols = {c[1] for c in new_info if c[5]}
        common_cols = [c for c in old_cols if c in new_cols]

        if common_cols:
            select_exprs = []
            for col in common_cols:
                if col in id_cols and col not in pk_cols:
                    # The previous schema stored str(None) / empty text for
                    # missing IDs. Store NULL instead so the INTEGER column
                    # does not keep a non-numeric string.
                    select_exprs.append(
                        f"CASE WHEN LOWER(TRIM({col})) IN ('', 'none', 'null') "
                        f"THEN NULL ELSE {col} END"
                    )
                else:
                    select_exprs.append(col)
            conn.execute(
                f"INSERT INTO {table} ({', '.join(common_cols)}) "
                f"SELECT {', '.join(select_exprs)} FROM {table}_old"
            )
        conn.execute(f"DROP TABLE {table}_old")
    except Exception:
        conn.execute(f"ROLLBACK TO SAVEPOINT {savepoint}")
        conn.execute(f"RELEASE SAVEPOINT {savepoint}")
        raise
    conn.execute(f"RELEASE SAVEPOINT {savepoint}")
def _init_db(self): def _init_db(self):
"""Initializes the database schema.""" """Initializes the database schema."""
@ -66,14 +134,14 @@ class BackupDatabase:
# Guild Profile # Guild Profile
conn.execute(""" conn.execute("""
CREATE TABLE IF NOT EXISTS guild_profile ( CREATE TABLE IF NOT EXISTS guild_profile (
id TEXT PRIMARY KEY, id INTEGER PRIMARY KEY,
name TEXT, name TEXT,
description TEXT, description TEXT,
icon_file TEXT, icon_file TEXT,
icon_url TEXT, icon_url TEXT,
banner_file TEXT, banner_file TEXT,
banner_url TEXT, banner_url TEXT,
owner_id TEXT, owner_id INTEGER,
last_backup TEXT, last_backup TEXT,
ignore_channels TEXT ignore_channels TEXT
) )
@ -82,11 +150,11 @@ class BackupDatabase:
# Roles # Roles
conn.execute(""" conn.execute("""
CREATE TABLE IF NOT EXISTS roles ( CREATE TABLE IF NOT EXISTS roles (
id TEXT PRIMARY KEY, id INTEGER PRIMARY KEY,
name TEXT, name TEXT,
color INTEGER, color INTEGER,
position INTEGER, position INTEGER,
permissions TEXT, permissions INTEGER,
hoist INTEGER, hoist INTEGER,
mentionable INTEGER mentionable INTEGER
) )
@ -95,11 +163,11 @@ class BackupDatabase:
# Channels # Channels
conn.execute(""" conn.execute("""
CREATE TABLE IF NOT EXISTS channels ( CREATE TABLE IF NOT EXISTS channels (
id TEXT PRIMARY KEY, id INTEGER PRIMARY KEY,
name TEXT, name TEXT,
type INTEGER, type INTEGER,
position INTEGER, position INTEGER,
category_id TEXT, category_id INTEGER,
topic TEXT, topic TEXT,
nsfw INTEGER, nsfw INTEGER,
bitrate INTEGER, bitrate INTEGER,
@ -111,8 +179,8 @@ class BackupDatabase:
conn.execute(""" conn.execute("""
CREATE TABLE IF NOT EXISTS permissions ( CREATE TABLE IF NOT EXISTS permissions (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id TEXT, channel_id INTEGER,
target_id TEXT, target_id INTEGER,
target_type TEXT, target_type TEXT,
allow INTEGER, allow INTEGER,
deny INTEGER deny INTEGER
@ -123,7 +191,7 @@ class BackupDatabase:
# Users (Author cache) # Users (Author cache)
conn.execute(""" conn.execute("""
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY, id INTEGER PRIMARY KEY,
username TEXT, username TEXT,
display_name TEXT, display_name TEXT,
avatar_file TEXT, avatar_file TEXT,
@ -135,13 +203,13 @@ class BackupDatabase:
# Messages # Messages
conn.execute(""" conn.execute("""
CREATE TABLE IF NOT EXISTS messages ( CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY, id INTEGER PRIMARY KEY,
channel_id TEXT, channel_id INTEGER,
author_id TEXT, author_id INTEGER,
content TEXT, content TEXT,
timestamp TEXT, timestamp TEXT,
type INTEGER, type INTEGER,
message_reference TEXT, message_reference INTEGER,
is_pinned INTEGER, is_pinned INTEGER,
extra_data TEXT extra_data TEXT
) )
@ -152,8 +220,8 @@ class BackupDatabase:
# Attachments # Attachments
conn.execute(""" conn.execute("""
CREATE TABLE IF NOT EXISTS attachments ( CREATE TABLE IF NOT EXISTS attachments (
id TEXT PRIMARY KEY, id INTEGER PRIMARY KEY,
message_id TEXT, message_id INTEGER,
filename TEXT, filename TEXT,
size INTEGER, size INTEGER,
url TEXT, url TEXT,
@ -167,7 +235,7 @@ class BackupDatabase:
conn.execute(""" conn.execute("""
CREATE TABLE IF NOT EXISTS embeds ( CREATE TABLE IF NOT EXISTS embeds (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
message_id TEXT, message_id INTEGER,
title TEXT, title TEXT,
description TEXT, description TEXT,
url TEXT, url TEXT,
@ -189,8 +257,8 @@ class BackupDatabase:
conn.execute(""" conn.execute("""
CREATE TABLE IF NOT EXISTS reactions ( CREATE TABLE IF NOT EXISTS reactions (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
message_id TEXT, message_id INTEGER,
emoji_id TEXT, emoji_id INTEGER,
emoji_name TEXT, emoji_name TEXT,
count INTEGER count INTEGER
) )
@ -200,8 +268,8 @@ class BackupDatabase:
# Message Stickers # Message Stickers
conn.execute(""" conn.execute("""
CREATE TABLE IF NOT EXISTS message_stickers ( CREATE TABLE IF NOT EXISTS message_stickers (
message_id TEXT, message_id INTEGER,
sticker_id TEXT, sticker_id INTEGER,
name TEXT, name TEXT,
url TEXT, url TEXT,
format_type INTEGER, format_type INTEGER,
@ -214,10 +282,10 @@ class BackupDatabase:
# Threads # Threads
conn.execute(""" conn.execute("""
CREATE TABLE IF NOT EXISTS threads ( CREATE TABLE IF NOT EXISTS threads (
id TEXT PRIMARY KEY, id INTEGER PRIMARY KEY,
name TEXT, name TEXT,
type INTEGER, type INTEGER,
parent_id TEXT, parent_id INTEGER,
message_count INTEGER, message_count INTEGER,
member_count INTEGER, member_count INTEGER,
archived INTEGER, archived INTEGER,
@ -232,11 +300,11 @@ class BackupDatabase:
# Forum Tags (Definitions for a forum channel) # Forum Tags (Definitions for a forum channel)
conn.execute(""" conn.execute("""
CREATE TABLE IF NOT EXISTS forum_tags ( CREATE TABLE IF NOT EXISTS forum_tags (
id TEXT PRIMARY KEY, id INTEGER PRIMARY KEY,
forum_id TEXT, forum_id INTEGER,
name TEXT, name TEXT,
moderated INTEGER, moderated INTEGER,
emoji_id TEXT, emoji_id INTEGER,
emoji_name TEXT emoji_name TEXT
) )
""") """)
@ -257,7 +325,7 @@ class BackupDatabase:
# Server Assets (Emojis, Stickers, etc.) # Server Assets (Emojis, Stickers, etc.)
conn.execute(""" conn.execute("""
CREATE TABLE IF NOT EXISTS server_assets ( CREATE TABLE IF NOT EXISTS server_assets (
id TEXT PRIMARY KEY, id INTEGER PRIMARY KEY,
name TEXT, name TEXT,
type TEXT, type TEXT,
filename TEXT, filename TEXT,
@ -276,10 +344,10 @@ class BackupDatabase:
INSERT OR REPLACE INTO guild_profile (id, name, description, icon_file, icon_url, banner_file, banner_url, owner_id, last_backup, ignore_channels) INSERT OR REPLACE INTO guild_profile (id, name, description, icon_file, icon_url, banner_file, banner_url, owner_id, last_backup, ignore_channels)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", ( """, (
str(data.get("id")), data.get("name"), data.get("description"), parse_snowflake(data.get("id")), data.get("name"), data.get("description"),
data.get("icon_file"), data.get("icon_url"), data.get("icon_file"), data.get("icon_url"),
data.get("banner_file"), data.get("banner_url"), data.get("banner_file"), data.get("banner_url"),
str(data.get("owner_id")), parse_snowflake(data.get("owner_id")),
data.get("last_backup"), json.dumps(data.get("ignore_channels", [])) data.get("last_backup"), json.dumps(data.get("ignore_channels", []))
)) ))
self._conn.commit() self._conn.commit()
@ -300,7 +368,7 @@ class BackupDatabase:
with self._lock: with self._lock:
formatted = [ formatted = [
{ {
"id": str(r["id"]), "id": parse_snowflake(r["id"]),
"name": r["name"], "name": r["name"],
"color": r["color"], "color": r["color"],
"position": r["position"], "position": r["position"],
@ -347,7 +415,7 @@ class BackupDatabase:
with self._lock: with self._lock:
formatted = [ formatted = [
{ {
"id": str(a["id"]), "id": parse_snowflake(a["id"]),
"name": a.get("name"), "name": a.get("name"),
"type": a.get("type"), "type": a.get("type"),
"filename": a.get("filename"), "filename": a.get("filename"),
@ -407,7 +475,7 @@ class BackupDatabase:
for rea in msg["reactions"]: for rea in msg["reactions"]:
all_reactions.append({ all_reactions.append({
"message_id": msg["id"], "message_id": msg["id"],
"emoji_id": str(rea["emoji_id"]) if rea.get("emoji_id") else None, "emoji_id": parse_snowflake(rea["emoji_id"]) if rea.get("emoji_id") else None,
"emoji_name": rea.get("emoji_name"), "emoji_name": rea.get("emoji_name"),
"count": rea.get("count", 0) "count": rea.get("count", 0)
}) })
@ -417,7 +485,7 @@ class BackupDatabase:
for st in msg["stickers"]: for st in msg["stickers"]:
all_stickers.append({ all_stickers.append({
"message_id": msg["id"], "message_id": msg["id"],
"sticker_id": str(st["id"]), "sticker_id": parse_snowflake(st["id"]),
"name": st.get("name"), "name": st.get("name"),
"url": st.get("url"), "url": st.get("url"),
"format_type": st.get("format_type"), "format_type": st.get("format_type"),
@ -469,7 +537,7 @@ class BackupDatabase:
def get_last_message_id(self, channel_id: Union[str, int]) -> Optional[int]:
    """Return the highest stored message ID of a channel.

    Args:
        channel_id: Channel snowflake; normalised with ``parse_snowflake``
            before it is bound to the query.

    Returns:
        The ``id`` of the channel's row that sorts last by ``id``, or
        ``None`` when the channel has no stored messages. ``messages.id``
        is declared ``INTEGER PRIMARY KEY``, so the value is an ``int``.
    """
    with self._lock:
        row = self._conn.execute(
            "SELECT id FROM messages WHERE channel_id = ? ORDER BY id DESC LIMIT 1",
            (parse_snowflake(channel_id),),
        ).fetchone()
        return row["id"] if row else None
def get_media_by_hash(self, file_hash: str) -> Optional[Dict[str, Any]]: def get_media_by_hash(self, file_hash: str) -> Optional[Dict[str, Any]]:
@ -600,7 +668,7 @@ class BackupDatabase:
"""Returns forum tag definitions.""" """Returns forum tag definitions."""
with self._lock: with self._lock:
if forum_id: if forum_id:
rows = self._conn.execute("SELECT * FROM forum_tags WHERE forum_id = ?", (str(forum_id),)).fetchall() rows = self._conn.execute("SELECT * FROM forum_tags WHERE forum_id = ?", (parse_snowflake(forum_id),)).fetchall()
else: else:
rows = self._conn.execute("SELECT * FROM forum_tags").fetchall() rows = self._conn.execute("SELECT * FROM forum_tags").fetchall()
return [dict(r) for r in rows] return [dict(r) for r in rows]
@ -608,13 +676,13 @@ class BackupDatabase:
def get_threads_by_parent(self, parent_id: str) -> List[Dict[str, Any]]:
    """List every stored thread whose ``parent_id`` matches the given channel."""
    sql = "SELECT * FROM threads WHERE parent_id = ?"
    with self._lock:
        cursor = self._conn.execute(sql, (parse_snowflake(parent_id),))
        # Each row becomes a plain dict keyed by column name.
        return list(map(dict, cursor.fetchall()))
def get_thread(self, thread_id: str) -> Optional[Dict[str, Any]]:
    """Fetch one thread row by ID as a dict; ``None`` when it is not stored."""
    sql = "SELECT * FROM threads WHERE id = ?"
    with self._lock:
        found = self._conn.execute(sql, (parse_snowflake(thread_id),)).fetchone()
        if not found:
            return None
        return dict(found)
def get_all_users(self) -> List[Dict[str, Any]]: def get_all_users(self) -> List[Dict[str, Any]]:
@ -624,7 +692,7 @@ class BackupDatabase:
def get_user(self, user_id: str) -> Optional[Dict[str, Any]]: def get_user(self, user_id: str) -> Optional[Dict[str, Any]]:
with self._lock: with self._lock:
row = self._conn.execute("SELECT * FROM users WHERE id = ?", (str(user_id),)).fetchone() row = self._conn.execute("SELECT * FROM users WHERE id = ?", (parse_snowflake(user_id),)).fetchone()
if row: if row:
data = dict(row) data = dict(row)
if data.get("roles"): if data.get("roles"):
@ -650,11 +718,11 @@ class BackupDatabase:
def get_messages_paged(self, channel_id: str, limit: int = 100, offset: int = 0, after_id: Optional[str] = None) -> List[Dict[str, Any]]: def get_messages_paged(self, channel_id: str, limit: int = 100, offset: int = 0, after_id: Optional[str] = None) -> List[Dict[str, Any]]:
with self._lock: with self._lock:
query = "SELECT * FROM messages WHERE channel_id = ?" query = "SELECT * FROM messages WHERE channel_id = ?"
params = [str(channel_id)] params = [parse_snowflake(channel_id)]
if after_id: if after_id:
query += " AND id > ?" query += " AND id > ?"
params.append(str(after_id)) params.append(parse_snowflake(after_id))
query += " ORDER BY id ASC LIMIT ? OFFSET ?" query += " ORDER BY id ASC LIMIT ? OFFSET ?"
params.extend([limit, offset]) params.extend([limit, offset])
@ -725,7 +793,7 @@ class BackupDatabase:
def delete_channel_messages(self, channel_id: Union[str, int]): def delete_channel_messages(self, channel_id: Union[str, int]):
"""Deletes all messages and related metadata for a specific channel and its threads.""" """Deletes all messages and related metadata for a specific channel and its threads."""
cid = str(channel_id) cid = parse_snowflake(channel_id)
with self._lock: with self._lock:
# 1. Identify all channel IDs involved (parent + all threads) # 1. Identify all channel IDs involved (parent + all threads)
target_ids = [cid] target_ids = [cid]
@ -807,4 +875,4 @@ class BackupDatabase:
self._conn.commit() self._conn.commit()
self._conn.close() self._conn.close()
except Exception: except Exception:
pass pass

View file

@ -129,7 +129,7 @@ class MigrationContext:
# or a logical subfolder. # or a logical subfolder.
base_dir = getattr(self, "base_dir", "") base_dir = getattr(self, "base_dir", "")
self.state.set_folder(community_id, clean_name, base_dir=base_dir) self.state.set_folder(community_id, clean_name, self.target_platform, base_dir=base_dir)
async def start_connections(self): async def start_connections(self):
await self.discord_reader.start() await self.discord_reader.start()

View file

@ -6,6 +6,7 @@ from pathlib import Path
from typing import Optional, Dict, Any, Union
import threading
import sys
from src.core.utils import parse_snowflake
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,8 +16,9 @@ class MigrationDatabase:
Replaces the memory-bloated and O(N^2) JSON persistence for messages. Replaces the memory-bloated and O(N^2) JSON persistence for messages.
""" """
def __init__(self, db_path: Path, platform: Optional[str] = None) -> None:
    """Store the database location and target platform, then set up the schema.

    Args:
        db_path: Path of the SQLite file; ``_init_db`` opens it with
            ``sqlite3.connect``.
        platform: Name of the target platform. Stored lower-cased; ``None``
            (or an empty string) is stored as ``None``.
    """
    self.db_path = db_path
    # Must be set before _init_db(), which reads self.platform.
    self.platform = platform.lower() if platform else None
    self._local = threading.local()
    self._init_db()
@ -27,38 +29,118 @@ class MigrationDatabase:
return self._local.conn return self._local.conn
def _init_db(self): def _init_db(self):
"""Initialize tables if they don't exist.""" """Initialize tables if they don't exist and handle migrations/platform-specific schemas."""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
# 1. Determine active platform and column types
# Create metadata table first as we need it for platform tracking
cursor.execute("CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)")
# Load platform from DB if exists
cursor.execute("SELECT value FROM metadata WHERE key = ?", ("target_platform",))
row = cursor.fetchone()
stored_platform = row[0] if row else None
# If platform provided, update stored platform. If not provided, use stored.
active_platform = self.platform or stored_platform or "stoat" # Default to stoat if unknown
if self.platform and self.platform != stored_platform:
cursor.execute("INSERT OR REPLACE INTO metadata (key, value) VALUES (?, ?)", ("target_platform", active_platform))
conn.commit()
# Define types
# source_type is always INTEGER (Discord/Fluxer snowflakes)
# target_type is INTEGER for Fluxer, TEXT for Stoat
source_type = "INTEGER"
target_type = "INTEGER" if active_platform == "fluxer" else "TEXT"
# 2. Universal ID Migration (TEXT -> INTEGER vs Platform Switch)
# Mapping of table names to columns that must match their respective types
# key: table, value: (discord_cols, target_cols)
id_migrations = {
"message_mappings": (["source_msg_id"], ["channel_id", "target_msg_id"]),
"thread_mappings": (["source_msg_id"], ["channel_id", "thread_id", "target_msg_id"]),
"channel_tracking": ([], ["channel_id", "last_msg_id"]),
"thread_tracking": ([], ["channel_id", "thread_id", "last_msg_id"]),
"server_mappings": (["source_id"], ["target_id"]),
"asset_mappings": (["source_id"], ["target_id"]),
"user_alias": (["user_id"], [])
}
for table, (discord_cols, target_cols) in id_migrations.items():
cursor.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name='{table}'")
res = cursor.fetchone()
if not res or res[0] == 0:
continue
cursor.execute(f"PRAGMA table_info({table})")
cols = cursor.fetchall()
needs_migration = False
for col in cols:
c_name, c_type = col[1], col[2]
if c_name in discord_cols and c_type != source_type:
needs_migration = True
break
if c_name in target_cols and c_type != target_type:
needs_migration = True
break
if needs_migration:
logger.info(f"MigrationDatabase: Migrating {table} schema (Platform: {active_platform})")
cursor.execute(f"ALTER TABLE {table} RENAME TO {table}_old")
if table == "message_mappings":
cursor.execute(f"CREATE TABLE message_mappings (channel_id {target_type}, source_msg_id {source_type}, target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, source_msg_id))")
elif table == "thread_mappings":
cursor.execute(f"CREATE TABLE thread_mappings (channel_id {target_type}, thread_id {target_type}, source_msg_id {source_type}, target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, thread_id, source_msg_id))")
elif table == "channel_tracking":
cursor.execute(f"CREATE TABLE channel_tracking (channel_id {target_type} PRIMARY KEY, last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0)")
elif table == "thread_tracking":
cursor.execute(f"CREATE TABLE thread_tracking (channel_id {target_type}, thread_id {target_type}, last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0, completed INTEGER DEFAULT 0, PRIMARY KEY (channel_id, thread_id))")
elif table == "server_mappings":
cursor.execute(f"CREATE TABLE server_mappings (category TEXT, source_id {source_type}, target_id {target_type}, PRIMARY KEY (category, source_id))")
elif table == "asset_mappings":
cursor.execute(f"CREATE TABLE asset_mappings (category TEXT, source_id {source_type}, target_id {target_type}, PRIMARY KEY (category, source_id))")
elif table == "user_alias":
cursor.execute(f"CREATE TABLE user_alias (user_id {source_type} PRIMARY KEY, alias TEXT UNIQUE)")
old_cols = [c[1] for c in cursor.execute(f"PRAGMA table_info({table}_old)").fetchall()]
new_cols = [c[1] for c in cursor.execute(f"PRAGMA table_info({table})").fetchall()]
common_cols = [c for c in old_cols if c in new_cols]
col_str = ", ".join(common_cols)
cursor.execute(f"INSERT OR IGNORE INTO {table} ({col_str}) SELECT {col_str} FROM {table}_old")
cursor.execute(f"DROP TABLE {table}_old")
# Initial Creation / Ensure Schema Correctness
# Table for message mappings: SourceID -> TargetID # Table for message mappings: SourceID -> TargetID
cursor.execute(""" cursor.execute(f"""
CREATE TABLE IF NOT EXISTS message_mappings ( CREATE TABLE IF NOT EXISTS message_mappings (
channel_id TEXT, channel_id {target_type},
source_msg_id TEXT, source_msg_id {source_type},
target_msg_id TEXT, target_msg_id {target_type},
timestamp TEXT, timestamp TEXT,
PRIMARY KEY (channel_id, source_msg_id) PRIMARY KEY (channel_id, source_msg_id)
) )
""") """)
# Table for thread mappings # Table for thread mappings
cursor.execute(""" cursor.execute(f"""
CREATE TABLE IF NOT EXISTS thread_mappings ( CREATE TABLE IF NOT EXISTS thread_mappings (
channel_id TEXT, channel_id {target_type},
thread_id TEXT, thread_id {target_type},
source_msg_id TEXT, source_msg_id {source_type},
target_msg_id TEXT, target_msg_id {target_type},
timestamp TEXT, timestamp TEXT,
PRIMARY KEY (channel_id, thread_id, source_msg_id) PRIMARY KEY (channel_id, thread_id, source_msg_id)
) )
""") """)
# Table for per-channel stats and tracking # Table for per-channel stats and tracking
cursor.execute(""" cursor.execute(f"""
CREATE TABLE IF NOT EXISTS channel_tracking ( CREATE TABLE IF NOT EXISTS channel_tracking (
channel_id TEXT PRIMARY KEY, channel_id {target_type} PRIMARY KEY,
last_msg_id TEXT, last_msg_id {target_type},
last_msg_ts TEXT, last_msg_ts TEXT,
msg_count INTEGER DEFAULT 0, msg_count INTEGER DEFAULT 0,
file_count INTEGER DEFAULT 0 file_count INTEGER DEFAULT 0
@ -66,11 +148,11 @@ class MigrationDatabase:
""") """)
# Table for per-thread stats # Table for per-thread stats
cursor.execute(""" cursor.execute(f"""
CREATE TABLE IF NOT EXISTS thread_tracking ( CREATE TABLE IF NOT EXISTS thread_tracking (
channel_id TEXT, channel_id {target_type},
thread_id TEXT, thread_id {target_type},
last_msg_id TEXT, last_msg_id {target_type},
last_msg_ts TEXT, last_msg_ts TEXT,
msg_count INTEGER DEFAULT 0, msg_count INTEGER DEFAULT 0,
file_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0,
@ -79,32 +161,32 @@ class MigrationDatabase:
) )
""") """)
# Add completed column if it doesn't exist (backward compatibility for existing resumption DBs) # Add completed column if it doesn't exist (backward compatibility)
try: try:
cursor.execute("ALTER TABLE thread_tracking ADD COLUMN completed INTEGER DEFAULT 0") cursor.execute("ALTER TABLE thread_tracking ADD COLUMN completed INTEGER DEFAULT 0")
except sqlite3.OperationalError: except sqlite3.OperationalError:
pass # Already exists pass # Already exists
# Table for server entity mappings (channels, roles, categories) # Table for server entity mappings (channels, roles, categories)
cursor.execute(""" cursor.execute(f"""
CREATE TABLE IF NOT EXISTS server_mappings ( CREATE TABLE IF NOT EXISTS server_mappings (
category TEXT, category TEXT,
source_id TEXT, source_id {source_type},
target_id TEXT, target_id {target_type},
PRIMARY KEY (category, source_id) PRIMARY KEY (category, source_id)
) )
""") """)
# Table for asset mappings (emojis, stickers) # Table for asset mappings (emojis, stickers)
cursor.execute(""" cursor.execute(f"""
CREATE TABLE IF NOT EXISTS asset_mappings ( CREATE TABLE IF NOT EXISTS asset_mappings (
category TEXT, category TEXT,
source_id TEXT, source_id {source_type},
target_id TEXT, target_id {target_type},
PRIMARY KEY (category, source_id) PRIMARY KEY (category, source_id)
) )
""") """)
# Migrate old entity_mappings if it exists # Migrate old entity_mappings if it exists
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='entity_mappings'") cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='entity_mappings'")
if cursor.fetchone(): if cursor.fetchone():
@ -125,18 +207,10 @@ class MigrationDatabase:
# Drop old table # Drop old table
cursor.execute("DROP TABLE entity_mappings") cursor.execute("DROP TABLE entity_mappings")
# Table for general metadata
cursor.execute("""
CREATE TABLE IF NOT EXISTS metadata (
key TEXT PRIMARY KEY,
value TEXT
)
""")
# Table for auto-generated user aliases (user_id -> alias) # Table for auto-generated user aliases (user_id -> alias)
cursor.execute(""" cursor.execute(f"""
CREATE TABLE IF NOT EXISTS user_alias ( CREATE TABLE IF NOT EXISTS user_alias (
user_id TEXT PRIMARY KEY, user_id {source_type} PRIMARY KEY,
alias TEXT UNIQUE alias TEXT UNIQUE
) )
""") """)
@ -152,17 +226,30 @@ class MigrationDatabase:
conn = self._get_conn() conn = self._get_conn()
conn.execute( conn.execute(
"INSERT OR REPLACE INTO message_mappings (channel_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?)", "INSERT OR REPLACE INTO message_mappings (channel_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?)",
(channel_id, source_id, target_id, timestamp) (str(channel_id), parse_snowflake(source_id), str(target_id), timestamp)
) )
conn.commit() conn.commit()
def get_target_message_id(self, channel_id: str, source_id: str) -> Optional[Union[str, int]]:
    """Return the target-platform message ID recorded for a source message.

    Args:
        channel_id: Channel key of the mapping; bound as ``str(channel_id)``.
        source_id: Source message ID; normalised with ``parse_snowflake``.

    Returns:
        ``None`` when no mapping row exists or the stored target ID is NULL.
        Otherwise the stored value, converted to ``str`` when
        ``self.platform`` is ``"stoat"`` and returned as stored for any
        other platform.
    """
    conn = self._get_conn()
    row = conn.execute(
        "SELECT target_msg_id FROM message_mappings WHERE channel_id = ? AND source_msg_id = ?",
        (str(channel_id), parse_snowflake(source_id)),
    ).fetchone()
    if not row:
        return None
    val = row["target_msg_id"]
    if val is None:
        # A NULL column must stay None: str(None) would hand the caller the
        # truthy string "None".
        return None
    return str(val) if self.platform == "stoat" else val
def get_all_message_mappings(self, channel_id: str) -> Dict[Union[str, int], Union[str, int]]:
    """Return every source -> target message ID pair stored for one channel.

    The channel ID is bound to the query as a string. When ``self.platform``
    is ``"stoat"`` both IDs of each pair are converted to ``str``; for any
    other platform the values are returned exactly as read from the row.
    """
    records = self._get_conn().execute(
        "SELECT source_msg_id, target_msg_id FROM message_mappings WHERE channel_id = ?",
        (str(channel_id),)
    ).fetchall()

    stringify = self.platform == "stoat"
    mapping = {}
    for record in records:
        source_key = record["source_msg_id"]
        target_val = record["target_msg_id"]
        if stringify:
            source_key, target_val = str(source_key), str(target_val)
        mapping[source_key] = target_val
    return mapping
# --- User Alias Methods --- # --- User Alias Methods ---
@ -205,7 +292,7 @@ class MigrationDatabase:
conn = self._get_conn() conn = self._get_conn()
# Check for existing alias # Check for existing alias
row = conn.execute("SELECT alias FROM user_alias WHERE user_id = ?", (str(user_id),)).fetchone() row = conn.execute("SELECT alias FROM user_alias WHERE user_id = ?", (parse_snowflake(user_id),)).fetchone()
if row: if row:
return row["alias"] return row["alias"]
@ -215,7 +302,7 @@ class MigrationDatabase:
new_alias = self._generate_alias() new_alias = self._generate_alias()
conn.execute( conn.execute(
"INSERT INTO user_alias (user_id, alias) VALUES (?, ?)", "INSERT INTO user_alias (user_id, alias) VALUES (?, ?)",
(str(user_id), new_alias) (parse_snowflake(user_id), new_alias)
) )
conn.commit() conn.commit()
return new_alias return new_alias
@ -239,31 +326,36 @@ class MigrationDatabase:
conn = self._get_conn() conn = self._get_conn()
conn.execute( conn.execute(
"INSERT OR REPLACE INTO server_mappings (category, source_id, target_id) VALUES (?, ?, ?)", "INSERT OR REPLACE INTO server_mappings (category, source_id, target_id) VALUES (?, ?, ?)",
(category, str(source_id), str(target_id)) (category, parse_snowflake(source_id), str(target_id))
) )
conn.commit() conn.commit()
def get_server_mapping(self, category: str, source_id: str) -> Optional[str]: def get_server_mapping(self, category: str, source_id: str) -> Optional[Union[str, int]]:
conn = self._get_conn() conn = self._get_conn()
row = conn.execute( row = conn.execute(
"SELECT target_id FROM server_mappings WHERE category = ? AND source_id = ?", "SELECT target_id FROM server_mappings WHERE category = ? AND source_id = ?",
(category, str(source_id)) (category, parse_snowflake(source_id))
).fetchone() ).fetchone()
return row["target_id"] if row else None if row:
val = row["target_id"]
return str(val) if self.platform == "stoat" else val
return None
def get_all_server_mappings(self, category: str) -> Dict[str, str]: def get_all_server_mappings(self, category: str) -> Dict[Union[str, int], Union[str, int]]:
conn = self._get_conn() conn = self._get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT source_id, target_id FROM server_mappings WHERE category = ?", "SELECT source_id, target_id FROM server_mappings WHERE category = ?",
(category,) (category,)
).fetchall() ).fetchall()
if self.platform == "stoat":
return {str(row["source_id"]): str(row["target_id"]) for row in rows}
return {row["source_id"]: row["target_id"] for row in rows} return {row["source_id"]: row["target_id"] for row in rows}
def delete_server_mapping(self, category: str, source_id: str): def delete_server_mapping(self, category: str, source_id: str):
conn = self._get_conn() conn = self._get_conn()
conn.execute( conn.execute(
"DELETE FROM server_mappings WHERE category = ? AND source_id = ?", "DELETE FROM server_mappings WHERE category = ? AND source_id = ?",
(category, str(source_id)) (category, parse_snowflake(source_id))
) )
conn.commit() conn.commit()
@ -281,31 +373,36 @@ class MigrationDatabase:
conn = self._get_conn() conn = self._get_conn()
conn.execute( conn.execute(
"INSERT OR REPLACE INTO asset_mappings (category, source_id, target_id) VALUES (?, ?, ?)", "INSERT OR REPLACE INTO asset_mappings (category, source_id, target_id) VALUES (?, ?, ?)",
(category, str(source_id), str(target_id)) (category, parse_snowflake(source_id), str(target_id))
) )
conn.commit() conn.commit()
def get_asset_mapping(self, category: str, source_id: str) -> Optional[str]: def get_asset_mapping(self, category: str, source_id: str) -> Optional[Union[str, int]]:
conn = self._get_conn() conn = self._get_conn()
row = conn.execute( row = conn.execute(
"SELECT target_id FROM asset_mappings WHERE category = ? AND source_id = ?", "SELECT target_id FROM asset_mappings WHERE category = ? AND source_id = ?",
(category, str(source_id)) (category, parse_snowflake(source_id))
).fetchone() ).fetchone()
return row["target_id"] if row else None if row:
val = row["target_id"]
return str(val) if self.platform == "stoat" else val
return None
def get_all_asset_mappings(self, category: str) -> Dict[str, str]: def get_all_asset_mappings(self, category: str) -> Dict[Union[str, int], Union[str, int]]:
conn = self._get_conn() conn = self._get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT source_id, target_id FROM asset_mappings WHERE category = ?", "SELECT source_id, target_id FROM asset_mappings WHERE category = ?",
(category,) (category,)
).fetchall() ).fetchall()
if self.platform == "stoat":
return {str(row["source_id"]): str(row["target_id"]) for row in rows}
return {row["source_id"]: row["target_id"] for row in rows} return {row["source_id"]: row["target_id"] for row in rows}
def delete_asset_mapping(self, category: str, source_id: str): def delete_asset_mapping(self, category: str, source_id: str):
conn = self._get_conn() conn = self._get_conn()
conn.execute( conn.execute(
"DELETE FROM asset_mappings WHERE category = ? AND source_id = ?", "DELETE FROM asset_mappings WHERE category = ? AND source_id = ?",
(category, str(source_id)) (category, parse_snowflake(source_id))
) )
conn.commit() conn.commit()
@ -332,23 +429,23 @@ class MigrationDatabase:
def update_channel_tracking(self, channel_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0): def update_channel_tracking(self, channel_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0):
conn = self._get_conn() conn = self._get_conn()
# Initialize if missing # Initialize if missing
conn.execute("INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)", (channel_id,)) conn.execute("INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)", (str(channel_id),))
if last_msg_id: if last_msg_id:
conn.execute("UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?", (last_msg_id, channel_id)) conn.execute("UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?", (str(last_msg_id), str(channel_id)))
if last_msg_ts: if last_msg_ts:
conn.execute("UPDATE channel_tracking SET last_msg_ts = ? WHERE channel_id = ?", (last_msg_ts, channel_id)) conn.execute("UPDATE channel_tracking SET last_msg_ts = ? WHERE channel_id = ?", (last_msg_ts, str(channel_id)))
if msg_inc != 0 or file_inc != 0: if msg_inc != 0 or file_inc != 0:
conn.execute( conn.execute(
"UPDATE channel_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ?", "UPDATE channel_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ?",
(msg_inc, file_inc, channel_id) (msg_inc, file_inc, str(channel_id))
) )
conn.commit() conn.commit()
def get_channel_tracking(self, channel_id: str) -> Dict[str, Any]: def get_channel_tracking(self, channel_id: str) -> Dict[str, Any]:
conn = self._get_conn() conn = self._get_conn()
row = conn.execute("SELECT * FROM channel_tracking WHERE channel_id = ?", (channel_id,)).fetchone() row = conn.execute("SELECT * FROM channel_tracking WHERE channel_id = ?", (str(channel_id),)).fetchone()
if row: if row:
return dict(row) return dict(row)
return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0}
@ -358,39 +455,42 @@ class MigrationDatabase:
conn = self._get_conn() conn = self._get_conn()
conn.execute( conn.execute(
"INSERT OR REPLACE INTO thread_mappings (channel_id, thread_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?, ?)", "INSERT OR REPLACE INTO thread_mappings (channel_id, thread_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?, ?)",
(channel_id, thread_id, source_id, target_id, timestamp) (str(channel_id), str(thread_id), parse_snowflake(source_id), str(target_id), timestamp)
) )
conn.commit() conn.commit()
def get_target_thread_message_id(self, channel_id: str, thread_id: str, source_id: str) -> Optional[str]: def get_target_thread_message_id(self, channel_id: str, thread_id: str, source_id: str) -> Optional[Union[str, int]]:
conn = self._get_conn() conn = self._get_conn()
row = conn.execute( row = conn.execute(
"SELECT target_msg_id FROM thread_mappings WHERE channel_id = ? AND thread_id = ? AND source_msg_id = ?", "SELECT target_msg_id FROM thread_mappings WHERE channel_id = ? AND thread_id = ? AND source_msg_id = ?",
(channel_id, thread_id, source_id) (str(channel_id), str(thread_id), parse_snowflake(source_id))
).fetchone() ).fetchone()
return row["target_msg_id"] if row else None if row:
val = row["target_msg_id"]
return str(val) if self.platform == "stoat" else val
return None
def update_thread_tracking(self, channel_id: str, thread_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0, completed: int = None): def update_thread_tracking(self, channel_id: str, thread_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0, completed: int = None):
conn = self._get_conn() conn = self._get_conn()
conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (channel_id, thread_id)) conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (str(channel_id), str(thread_id)))
if last_msg_id: if last_msg_id:
conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_id, channel_id, thread_id)) conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (str(last_msg_id), str(channel_id), str(thread_id)))
if last_msg_ts: if last_msg_ts:
conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, channel_id, thread_id)) conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, str(channel_id), str(thread_id)))
if completed is not None: if completed is not None:
conn.execute("UPDATE thread_tracking SET completed = ? WHERE channel_id = ? AND thread_id = ?", (completed, channel_id, thread_id)) conn.execute("UPDATE thread_tracking SET completed = ? WHERE channel_id = ? AND thread_id = ?", (completed, str(channel_id), str(thread_id)))
if msg_inc != 0 or file_inc != 0: if msg_inc != 0 or file_inc != 0:
conn.execute( conn.execute(
"UPDATE thread_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ? AND thread_id = ?", "UPDATE thread_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ? AND thread_id = ?",
(msg_inc, file_inc, channel_id, thread_id) (msg_inc, file_inc, str(channel_id), str(thread_id))
) )
conn.commit() conn.commit()
def get_thread_tracking(self, channel_id: str, thread_id: str) -> Dict[str, Any]: def get_thread_tracking(self, channel_id: str, thread_id: str) -> Dict[str, Any]:
conn = self._get_conn() conn = self._get_conn()
row = conn.execute("SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?", (channel_id, thread_id)).fetchone() row = conn.execute("SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?", (str(channel_id), str(thread_id))).fetchone()
if row: if row:
return dict(row) return dict(row)
return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0}
@ -398,10 +498,10 @@ class MigrationDatabase:
def clear_channel_data(self, channel_id: str): def clear_channel_data(self, channel_id: str):
"""Purge all mappings and tracking data for a specific channel and its threads.""" """Purge all mappings and tracking data for a specific channel and its threads."""
conn = self._get_conn() conn = self._get_conn()
conn.execute("DELETE FROM message_mappings WHERE channel_id = ?", (channel_id,)) conn.execute("DELETE FROM message_mappings WHERE channel_id = ?", (str(channel_id),))
conn.execute("DELETE FROM thread_mappings WHERE channel_id = ?", (channel_id,)) conn.execute("DELETE FROM thread_mappings WHERE channel_id = ?", (str(channel_id),))
conn.execute("DELETE FROM channel_tracking WHERE channel_id = ?", (channel_id,)) conn.execute("DELETE FROM channel_tracking WHERE channel_id = ?", (str(channel_id),))
conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (channel_id,)) conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (str(channel_id),))
conn.commit() conn.commit()
logger.info(f"Cleared all tracking and mapping data for channel: {channel_id}") logger.info(f"Cleared all tracking and mapping data for channel: {channel_id}")

View file

@ -109,7 +109,7 @@ class DiscordExporter:
"name": r.name, "name": r.name,
"color": r.color.value, "color": r.color.value,
"position": r.position, "position": r.position,
"permissions": str(r.permissions.value), "permissions": r.permissions.value,
"hoist": 1 if r.hoist else 0, "hoist": 1 if r.hoist else 0,
"mentionable": 1 if r.mentionable else 0 "mentionable": 1 if r.mentionable else 0
}) })
@ -563,9 +563,10 @@ class DiscordExporter:
}) })
# 5. Message data # 5. Message data
from src.core.utils import parse_snowflake
message_reference = None message_reference = None
if msg.reference and msg.reference.message_id: if msg.reference and msg.reference.message_id:
message_reference = str(msg.reference.message_id) message_reference = parse_snowflake(msg.reference.message_id)
# 5.5 Forwarded snapshots # 5.5 Forwarded snapshots
content = msg.content or "" content = msg.content or ""

View file

@ -1,7 +1,10 @@
import json import json
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Dict, Any, Optional from typing import Dict, Optional, Any, Union, TYPE_CHECKING
if TYPE_CHECKING:
from src.core.database import MigrationDatabase
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -119,23 +122,23 @@ class MigrationState:
# --- Properties for backward compatibility --- # --- Properties for backward compatibility ---
@property @property
def channel_map(self) -> Dict[str, str]: def channel_map(self) -> Dict[Union[str, int], Union[str, int]]:
return self.db.get_all_server_mappings("channel") if self.db else {} return self.db.get_all_server_mappings("channel") if self.db else {}
@property @property
def category_map(self) -> Dict[str, str]: def category_map(self) -> Dict[Union[str, int], Union[str, int]]:
return self.db.get_all_server_mappings("category") if self.db else {} return self.db.get_all_server_mappings("category") if self.db else {}
@property @property
def role_map(self) -> Dict[str, str]: def role_map(self) -> Dict[Union[str, int], Union[str, int]]:
return self.db.get_all_server_mappings("role") if self.db else {} return self.db.get_all_server_mappings("role") if self.db else {}
@property @property
def emoji_map(self) -> Dict[str, str]: def emoji_map(self) -> Dict[Union[str, int], Union[str, int]]:
return self.db.get_all_asset_mappings("emoji") if self.db else {} return self.db.get_all_asset_mappings("emoji") if self.db else {}
@property @property
def sticker_map(self) -> Dict[str, str]: def sticker_map(self) -> Dict[Union[str, int], Union[str, int]]:
return self.db.get_all_asset_mappings("sticker") if self.db else {} return self.db.get_all_asset_mappings("sticker") if self.db else {}
@property @property
@ -153,7 +156,7 @@ class MigrationState:
if self._ensure_db(): if self._ensure_db():
self.db.set_message_mapping(str(target_channel_id), str(discord_id), str(target_id)) self.db.set_message_mapping(str(target_channel_id), str(discord_id), str(target_id))
def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | None: def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | int | None:
if self._ensure_db(): if self._ensure_db():
return self.db.get_target_message_id(str(target_channel_id), str(discord_id)) return self.db.get_target_message_id(str(target_channel_id), str(discord_id))
return None return None
@ -161,7 +164,7 @@ class MigrationState:
def set_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str): def set_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str):
self.set_target_message_mapping(target_channel_id, discord_id, target_id) self.set_target_message_mapping(target_channel_id, discord_id, target_id)
def get_fluxer_message_id(self, target_channel_id: str, discord_id: str) -> str | None: def get_fluxer_message_id(self, target_channel_id: str, discord_id: str) -> str | int | None:
return self.get_target_message_id(target_channel_id, discord_id) return self.get_target_message_id(target_channel_id, discord_id)
def increment_stats(self, target_channel_id: str, messages: int = 1, files: int = 0): def increment_stats(self, target_channel_id: str, messages: int = 1, files: int = 0):
@ -258,7 +261,7 @@ class MigrationState:
if self._ensure_db(): if self._ensure_db():
self.db.clear_channel_data(str(target_channel_id)) self.db.clear_channel_data(str(target_channel_id))
def set_folder(self, server_id: str, clean_name: str, base_dir: Path | str = ""): def set_folder(self, server_id: str, clean_name: str, platform: str = "stoat", base_dir: Path | str = ""):
""" """
Initializes the SQLite database based on community name and ID. Initializes the SQLite database based on community name and ID.
Filename: {name}-{id}.db (Flat structure) Filename: {name}-{id}.db (Flat structure)
@ -294,7 +297,7 @@ class MigrationState:
from src.core.database import MigrationDatabase from src.core.database import MigrationDatabase
if self.db: if self.db:
self.db.close() self.db.close()
self.db = MigrationDatabase(db_path) self.db = MigrationDatabase(db_path, platform=platform)
logger.info(f"Initialized SQLite database at {db_path}") logger.info(f"Initialized SQLite database at {db_path}")
def get_user_alias(self, user_id: str) -> str | None: def get_user_alias(self, user_id: str) -> str | None:

View file

@ -1,14 +1,29 @@
from typing import Any, Optional
import re import re
import logging import logging
from src.core.state import MigrationState
def parse_snowflake(value: Any) -> Optional[int]:
    """Convert a Discord ID (Snowflake) given in any form to an ``int``.

    The input is stringified and stripped of surrounding whitespace before
    parsing. ``None`` is returned for a ``None`` input, for blank text, for
    the placeholders "none" (any letter case) and "NULL", and for any text
    that ``int()`` rejects.
    """
    if value is None:
        return None

    text = str(value).strip()
    is_placeholder = text == "" or text.lower() == "none" or text == "NULL"
    if is_placeholder:
        return None

    try:
        parsed = int(text)
    except ValueError:
        # Not an integer literal: treat it the same as a missing ID.
        return None
    return parsed
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def resolve_discord_links(content: str, state: MigrationState, platform: str, target_server_id: str) -> str: def resolve_discord_links(content: str, state, platform: str, target_server_id: str) -> str:
""" """
Finds Discord message/channel links and resolves them to the target platform Finds Discord message/channel links and resolves them to the target platform
if they have been migrated. if they have been migrated.
""" """
from src.core.state import MigrationState
if not isinstance(state, MigrationState):
logger.warning(f"resolve_discord_links: state is not MigrationState (type: {type(state)})")
if not content: if not content:
return content return content

View file

@ -18,10 +18,10 @@ async def sync_assets_state(context: MigrationContext):
fluxer_stickers = await context.fluxer_writer.client.get_guild_stickers(context.config.fluxer_server_id) fluxer_stickers = await context.fluxer_writer.client.get_guild_stickers(context.config.fluxer_server_id)
# Build name -> id maps and ID sets for Fluxer for fast lookup # Build name -> id maps and ID sets for Fluxer for fast lookup
fluxer_emoji_map = {e.get("name"): str(e.get("id")) for e in fluxer_emojis if e.get("name")} fluxer_emoji_map = {e.get("name"): e.get("id") for e in fluxer_emojis if e.get("name")}
fluxer_sticker_map = {s.get("name"): str(s.get("id")) for s in fluxer_stickers if s.get("name")} fluxer_sticker_map = {s.get("name"): s.get("id") for s in fluxer_stickers if s.get("name")}
fluxer_emoji_ids = {str(e.get("id")) for e in fluxer_emojis} fluxer_emoji_ids = {e.get("id") for e in fluxer_emojis}
fluxer_sticker_ids = {str(s.get("id")) for s in fluxer_stickers} fluxer_sticker_ids = {s.get("id") for s in fluxer_stickers}
updates = 0 updates = 0
removals = 0 removals = 0

View file

@ -73,8 +73,8 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
cid = int(match.group(1)) cid = int(match.group(1))
# 1. Check if channel is mapped in state # 1. Check if channel is mapped in state
if channel_map and str(cid) in channel_map: if channel_map and cid in channel_map:
return f"<#{channel_map[str(cid)]}>" return f"<#{channel_map[cid]}>"
# 2. Try to resolve channel name from pre-fetched names # 2. Try to resolve channel name from pre-fetched names
name = None name = None
@ -100,7 +100,7 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
def replace_emoji(match): def replace_emoji(match):
animated = match.group(1) == "a" animated = match.group(1) == "a"
name = match.group(2) name = match.group(2)
eid = match.group(3) eid = int(match.group(3))
if emoji_map and eid in emoji_map: if emoji_map and eid in emoji_map:
target_eid = emoji_map[eid] target_eid = emoji_map[eid]

View file

@ -15,8 +15,8 @@ async def sync_roles_state(context: MigrationContext):
fluxer_roles = await context.fluxer_writer.client.get_guild_roles(context.config.fluxer_server_id) fluxer_roles = await context.fluxer_writer.client.get_guild_roles(context.config.fluxer_server_id)
# Build name -> id maps and ID sets for Fluxer for fast lookup # Build name -> id maps and ID sets for Fluxer for fast lookup
fluxer_role_map = {r.get("name"): str(r.get("id")) for r in fluxer_roles if r.get("name")} fluxer_role_map = {r.get("name"): r.get("id") for r in fluxer_roles if r.get("name")}
fluxer_role_ids = {str(r.get("id")) for r in fluxer_roles} fluxer_role_ids = {r.get("id") for r in fluxer_roles}
updates = 0 updates = 0
removals = 0 removals = 0

View file

@ -37,8 +37,7 @@ async def sync_assets_state(context: MigrationContext):
if stoat_id: if stoat_id:
if stoat_id not in stoat_emoji_ids: if stoat_id not in stoat_emoji_ids:
context.state.emoji_map.pop(discord_id, None) context.state.remove_emoji_mapping(discord_id)
context.state.save_state()
removals += 1 removals += 1
elif emoji.name in stoat_emoji_map: elif emoji.name in stoat_emoji_map:
context.state.set_target_emoji_mapping(discord_id, stoat_emoji_map[emoji.name]) context.state.set_target_emoji_mapping(discord_id, stoat_emoji_map[emoji.name])