Merge pull request #8 from rambros3d/sql-db
Switch to INTEGER for snowflake ids
This commit is contained in:
commit
c34d509677
10 changed files with 354 additions and 168 deletions
|
|
@ -4,20 +4,11 @@ import json
|
|||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from src.core.utils import parse_snowflake
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def parse_snowflake(value: Any) -> Optional[int]:
|
||||
"""Safely parses a Discord ID (Snowflake) from any input, handling 'None' strings."""
|
||||
if value is None:
|
||||
return None
|
||||
s = str(value).strip()
|
||||
if not s or s.lower() == "none" or s == "NULL":
|
||||
return None
|
||||
try:
|
||||
return int(s)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
class BackupDatabase:
|
||||
"""Manages the SQLite database for local Discord backups."""
|
||||
|
|
@ -39,24 +30,101 @@ class BackupDatabase:
|
|||
def _migrate_db(self):
|
||||
"""Handles backward compatibility by renaming columns in existing databases."""
|
||||
with self._lock:
|
||||
# Check 'media_pool' table
|
||||
res = self._conn.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='media_pool'").fetchone()
|
||||
if res[0] > 0:
|
||||
cols = self._conn.execute("PRAGMA table_info(media_pool)").fetchall()
|
||||
col_names = [c["name"] for c in cols]
|
||||
if "mime_type" in col_names and "content_type" not in col_names:
|
||||
logger.info("Migrating media_pool: renaming 'mime_type' to 'content_type'")
|
||||
self._conn.execute("ALTER TABLE media_pool RENAME COLUMN mime_type TO content_type")
|
||||
conn = self._conn
|
||||
# 1. MIME Type to Content Type Migrations
|
||||
for table in ["media_pool", "server_assets"]:
|
||||
res = conn.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name='{table}'").fetchone()
|
||||
if res and res[0] > 0:
|
||||
cols = conn.execute(f"PRAGMA table_info({table})").fetchall()
|
||||
col_names = [c["name"] for c in cols]
|
||||
if "mime_type" in col_names and "content_type" not in col_names:
|
||||
logger.info(f"Migrating {table}: renaming 'mime_type' to 'content_type'")
|
||||
conn.execute(f"ALTER TABLE {table} RENAME COLUMN mime_type TO content_type")
|
||||
|
||||
# Check 'server_assets' table
|
||||
res = self._conn.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='server_assets'").fetchone()
|
||||
if res[0] > 0:
|
||||
cols = self._conn.execute("PRAGMA table_info(server_assets)").fetchall()
|
||||
col_names = [c["name"] for c in cols]
|
||||
if "mime_type" in col_names and "content_type" not in col_names:
|
||||
logger.info("Migrating server_assets: renaming 'mime_type' to 'content_type'")
|
||||
self._conn.execute("ALTER TABLE server_assets RENAME COLUMN mime_type TO content_type")
|
||||
self._conn.commit()
|
||||
# 2. Universal ID Migration (TEXT -> INTEGER)
|
||||
# Mapping of table names to columns that must be INTEGER (Snowflakes)
|
||||
id_migrations = {
|
||||
"guild_profile": ["id", "owner_id"],
|
||||
"roles": ["id", "permissions"],
|
||||
"channels": ["id", "category_id"],
|
||||
"permissions": ["channel_id", "target_id"],
|
||||
"users": ["id"],
|
||||
"messages": ["id", "channel_id", "author_id", "message_reference"],
|
||||
"attachments": ["id", "message_id"],
|
||||
"embeds": ["message_id"],
|
||||
"reactions": ["message_id", "emoji_id"],
|
||||
"message_stickers": ["message_id", "sticker_id"],
|
||||
"threads": ["id", "parent_id"],
|
||||
"forum_tags": ["id", "forum_id", "emoji_id"],
|
||||
"server_assets": ["id"]
|
||||
}
|
||||
|
||||
for table, id_cols in id_migrations.items():
|
||||
res = conn.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name='{table}'").fetchone()
|
||||
if not res or res[0] == 0:
|
||||
continue
|
||||
|
||||
cols = conn.execute(f"PRAGMA table_info({table})").fetchall()
|
||||
needs_migration = False
|
||||
for col in cols:
|
||||
if col[1] in id_cols and col[2] == "TEXT":
|
||||
needs_migration = True
|
||||
break
|
||||
|
||||
if needs_migration:
|
||||
logger.info(f"Migrating {table}: converting ID columns to INTEGER")
|
||||
# Special Case: messages already handled id, but now generic
|
||||
# We use a temporary table to handle the schema change
|
||||
conn.execute(f"ALTER TABLE {table} RENAME TO {table}_old")
|
||||
|
||||
# We can't easily generate the CREATE TABLE here without duplicating _init_db logic
|
||||
# So we call _init_db to create the NEW table, then copy data
|
||||
# But _init_db has 'IF NOT EXISTS', so we just call it once at the end?
|
||||
# No, we need the table NOW for the INSERT.
|
||||
# I'll just manually define the inserts or do it in _init_db.
|
||||
|
||||
# Actually, a better way is to do the CREATE TABLE here for this specific table.
|
||||
# I'll have to duplicate the schema from _init_db for the migration.
|
||||
|
||||
# Alternatively, since we are already in _migrate_db, we can just do the
|
||||
# specific CREATE TABLE for the table we are migrating.
|
||||
|
||||
if table == "guild_profile":
|
||||
conn.execute("CREATE TABLE guild_profile (id INTEGER PRIMARY KEY, name TEXT, description TEXT, icon_file TEXT, icon_url TEXT, banner_file TEXT, banner_url TEXT, owner_id INTEGER, last_backup TEXT, ignore_channels TEXT)")
|
||||
elif table == "roles":
|
||||
conn.execute("CREATE TABLE roles (id INTEGER PRIMARY KEY, name TEXT, color INTEGER, position INTEGER, permissions INTEGER, hoist INTEGER, mentionable INTEGER)")
|
||||
elif table == "channels":
|
||||
conn.execute("CREATE TABLE channels (id INTEGER PRIMARY KEY, name TEXT, type INTEGER, position INTEGER, category_id INTEGER, topic TEXT, nsfw INTEGER, bitrate INTEGER, slowmode_delay INTEGER)")
|
||||
elif table == "permissions":
|
||||
conn.execute("CREATE TABLE permissions (id INTEGER PRIMARY KEY AUTOINCREMENT, channel_id INTEGER, target_id INTEGER, target_type TEXT, allow INTEGER, deny INTEGER)")
|
||||
elif table == "users":
|
||||
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT, display_name TEXT, avatar_file TEXT, avatar_url TEXT, roles TEXT)")
|
||||
elif table == "messages":
|
||||
conn.execute("CREATE TABLE messages (id INTEGER PRIMARY KEY, channel_id INTEGER, author_id INTEGER, content TEXT, timestamp TEXT, type INTEGER, message_reference INTEGER, is_pinned INTEGER, extra_data TEXT)")
|
||||
elif table == "attachments":
|
||||
conn.execute("CREATE TABLE attachments (id INTEGER PRIMARY KEY, message_id INTEGER, filename TEXT, size INTEGER, url TEXT, content_type TEXT, local_hash TEXT)")
|
||||
elif table == "embeds":
|
||||
conn.execute("CREATE TABLE embeds (id INTEGER PRIMARY KEY AUTOINCREMENT, message_id INTEGER, title TEXT, description TEXT, url TEXT, color INTEGER, timestamp TEXT, thumbnail_url TEXT, image_url TEXT, author_name TEXT, author_url TEXT, author_icon_url TEXT, footer_text TEXT, footer_icon_url TEXT, fields TEXT)")
|
||||
elif table == "reactions":
|
||||
conn.execute("CREATE TABLE reactions (id INTEGER PRIMARY KEY AUTOINCREMENT, message_id INTEGER, emoji_id INTEGER, emoji_name TEXT, count INTEGER)")
|
||||
elif table == "message_stickers":
|
||||
conn.execute("CREATE TABLE message_stickers (message_id INTEGER, sticker_id INTEGER, name TEXT, url TEXT, format_type INTEGER, local_hash TEXT, PRIMARY KEY (message_id, sticker_id))")
|
||||
elif table == "threads":
|
||||
conn.execute("CREATE TABLE threads (id INTEGER PRIMARY KEY, name TEXT, type INTEGER, parent_id INTEGER, message_count INTEGER, member_count INTEGER, archived INTEGER, archive_timestamp TEXT, auto_archive_duration INTEGER, locked INTEGER, applied_tags TEXT)")
|
||||
elif table == "forum_tags":
|
||||
conn.execute("CREATE TABLE forum_tags (id INTEGER PRIMARY KEY, forum_id INTEGER, name TEXT, moderated INTEGER, emoji_id INTEGER, emoji_name TEXT)")
|
||||
elif table == "server_assets":
|
||||
conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type INTEGER)")
|
||||
|
||||
old_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table}_old)").fetchall()]
|
||||
new_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table})").fetchall()]
|
||||
common_cols = [c for c in old_cols if c in new_cols]
|
||||
col_str = ", ".join(common_cols)
|
||||
|
||||
conn.execute(f"INSERT INTO {table} ({col_str}) SELECT {col_str} FROM {table}_old")
|
||||
conn.execute(f"DROP TABLE {table}_old")
|
||||
|
||||
conn.commit()
|
||||
|
||||
def _init_db(self):
|
||||
"""Initializes the database schema."""
|
||||
|
|
@ -66,14 +134,14 @@ class BackupDatabase:
|
|||
# Guild Profile
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS guild_profile (
|
||||
id TEXT PRIMARY KEY,
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT,
|
||||
description TEXT,
|
||||
icon_file TEXT,
|
||||
icon_url TEXT,
|
||||
banner_file TEXT,
|
||||
banner_url TEXT,
|
||||
owner_id TEXT,
|
||||
owner_id INTEGER,
|
||||
last_backup TEXT,
|
||||
ignore_channels TEXT
|
||||
)
|
||||
|
|
@ -82,11 +150,11 @@ class BackupDatabase:
|
|||
# Roles
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS roles (
|
||||
id TEXT PRIMARY KEY,
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT,
|
||||
color INTEGER,
|
||||
position INTEGER,
|
||||
permissions TEXT,
|
||||
permissions INTEGER,
|
||||
hoist INTEGER,
|
||||
mentionable INTEGER
|
||||
)
|
||||
|
|
@ -95,11 +163,11 @@ class BackupDatabase:
|
|||
# Channels
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS channels (
|
||||
id TEXT PRIMARY KEY,
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT,
|
||||
type INTEGER,
|
||||
position INTEGER,
|
||||
category_id TEXT,
|
||||
category_id INTEGER,
|
||||
topic TEXT,
|
||||
nsfw INTEGER,
|
||||
bitrate INTEGER,
|
||||
|
|
@ -111,8 +179,8 @@ class BackupDatabase:
|
|||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS permissions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
channel_id TEXT,
|
||||
target_id TEXT,
|
||||
channel_id INTEGER,
|
||||
target_id INTEGER,
|
||||
target_type TEXT,
|
||||
allow INTEGER,
|
||||
deny INTEGER
|
||||
|
|
@ -123,7 +191,7 @@ class BackupDatabase:
|
|||
# Users (Author cache)
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
id INTEGER PRIMARY KEY,
|
||||
username TEXT,
|
||||
display_name TEXT,
|
||||
avatar_file TEXT,
|
||||
|
|
@ -135,13 +203,13 @@ class BackupDatabase:
|
|||
# Messages
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id TEXT PRIMARY KEY,
|
||||
channel_id TEXT,
|
||||
author_id TEXT,
|
||||
id INTEGER PRIMARY KEY,
|
||||
channel_id INTEGER,
|
||||
author_id INTEGER,
|
||||
content TEXT,
|
||||
timestamp TEXT,
|
||||
type INTEGER,
|
||||
message_reference TEXT,
|
||||
message_reference INTEGER,
|
||||
is_pinned INTEGER,
|
||||
extra_data TEXT
|
||||
)
|
||||
|
|
@ -152,8 +220,8 @@ class BackupDatabase:
|
|||
# Attachments
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS attachments (
|
||||
id TEXT PRIMARY KEY,
|
||||
message_id TEXT,
|
||||
id INTEGER PRIMARY KEY,
|
||||
message_id INTEGER,
|
||||
filename TEXT,
|
||||
size INTEGER,
|
||||
url TEXT,
|
||||
|
|
@ -167,7 +235,7 @@ class BackupDatabase:
|
|||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS embeds (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
message_id TEXT,
|
||||
message_id INTEGER,
|
||||
title TEXT,
|
||||
description TEXT,
|
||||
url TEXT,
|
||||
|
|
@ -189,8 +257,8 @@ class BackupDatabase:
|
|||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS reactions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
message_id TEXT,
|
||||
emoji_id TEXT,
|
||||
message_id INTEGER,
|
||||
emoji_id INTEGER,
|
||||
emoji_name TEXT,
|
||||
count INTEGER
|
||||
)
|
||||
|
|
@ -200,8 +268,8 @@ class BackupDatabase:
|
|||
# Message Stickers
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS message_stickers (
|
||||
message_id TEXT,
|
||||
sticker_id TEXT,
|
||||
message_id INTEGER,
|
||||
sticker_id INTEGER,
|
||||
name TEXT,
|
||||
url TEXT,
|
||||
format_type INTEGER,
|
||||
|
|
@ -214,10 +282,10 @@ class BackupDatabase:
|
|||
# Threads
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS threads (
|
||||
id TEXT PRIMARY KEY,
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT,
|
||||
type INTEGER,
|
||||
parent_id TEXT,
|
||||
parent_id INTEGER,
|
||||
message_count INTEGER,
|
||||
member_count INTEGER,
|
||||
archived INTEGER,
|
||||
|
|
@ -232,11 +300,11 @@ class BackupDatabase:
|
|||
# Forum Tags (Definitions for a forum channel)
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS forum_tags (
|
||||
id TEXT PRIMARY KEY,
|
||||
forum_id TEXT,
|
||||
id INTEGER PRIMARY KEY,
|
||||
forum_id INTEGER,
|
||||
name TEXT,
|
||||
moderated INTEGER,
|
||||
emoji_id TEXT,
|
||||
emoji_id INTEGER,
|
||||
emoji_name TEXT
|
||||
)
|
||||
""")
|
||||
|
|
@ -257,7 +325,7 @@ class BackupDatabase:
|
|||
# Server Assets (Emojis, Stickers, etc.)
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS server_assets (
|
||||
id TEXT PRIMARY KEY,
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT,
|
||||
type TEXT,
|
||||
filename TEXT,
|
||||
|
|
@ -276,10 +344,10 @@ class BackupDatabase:
|
|||
INSERT OR REPLACE INTO guild_profile (id, name, description, icon_file, icon_url, banner_file, banner_url, owner_id, last_backup, ignore_channels)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
str(data.get("id")), data.get("name"), data.get("description"),
|
||||
parse_snowflake(data.get("id")), data.get("name"), data.get("description"),
|
||||
data.get("icon_file"), data.get("icon_url"),
|
||||
data.get("banner_file"), data.get("banner_url"),
|
||||
str(data.get("owner_id")),
|
||||
parse_snowflake(data.get("owner_id")),
|
||||
data.get("last_backup"), json.dumps(data.get("ignore_channels", []))
|
||||
))
|
||||
self._conn.commit()
|
||||
|
|
@ -300,7 +368,7 @@ class BackupDatabase:
|
|||
with self._lock:
|
||||
formatted = [
|
||||
{
|
||||
"id": str(r["id"]),
|
||||
"id": parse_snowflake(r["id"]),
|
||||
"name": r["name"],
|
||||
"color": r["color"],
|
||||
"position": r["position"],
|
||||
|
|
@ -347,7 +415,7 @@ class BackupDatabase:
|
|||
with self._lock:
|
||||
formatted = [
|
||||
{
|
||||
"id": str(a["id"]),
|
||||
"id": parse_snowflake(a["id"]),
|
||||
"name": a.get("name"),
|
||||
"type": a.get("type"),
|
||||
"filename": a.get("filename"),
|
||||
|
|
@ -407,7 +475,7 @@ class BackupDatabase:
|
|||
for rea in msg["reactions"]:
|
||||
all_reactions.append({
|
||||
"message_id": msg["id"],
|
||||
"emoji_id": str(rea["emoji_id"]) if rea.get("emoji_id") else None,
|
||||
"emoji_id": parse_snowflake(rea["emoji_id"]) if rea.get("emoji_id") else None,
|
||||
"emoji_name": rea.get("emoji_name"),
|
||||
"count": rea.get("count", 0)
|
||||
})
|
||||
|
|
@ -417,7 +485,7 @@ class BackupDatabase:
|
|||
for st in msg["stickers"]:
|
||||
all_stickers.append({
|
||||
"message_id": msg["id"],
|
||||
"sticker_id": str(st["id"]),
|
||||
"sticker_id": parse_snowflake(st["id"]),
|
||||
"name": st.get("name"),
|
||||
"url": st.get("url"),
|
||||
"format_type": st.get("format_type"),
|
||||
|
|
@ -469,7 +537,7 @@ class BackupDatabase:
|
|||
|
||||
def get_last_message_id(self, channel_id: str) -> Optional[str]:
|
||||
with self._lock:
|
||||
row = self._conn.execute("SELECT id FROM messages WHERE channel_id = ? ORDER BY id DESC LIMIT 1", (str(channel_id),)).fetchone()
|
||||
row = self._conn.execute("SELECT id FROM messages WHERE channel_id = ? ORDER BY id DESC LIMIT 1", (parse_snowflake(channel_id),)).fetchone()
|
||||
return row["id"] if row else None
|
||||
|
||||
def get_media_by_hash(self, file_hash: str) -> Optional[Dict[str, Any]]:
|
||||
|
|
@ -600,7 +668,7 @@ class BackupDatabase:
|
|||
"""Returns forum tag definitions."""
|
||||
with self._lock:
|
||||
if forum_id:
|
||||
rows = self._conn.execute("SELECT * FROM forum_tags WHERE forum_id = ?", (str(forum_id),)).fetchall()
|
||||
rows = self._conn.execute("SELECT * FROM forum_tags WHERE forum_id = ?", (parse_snowflake(forum_id),)).fetchall()
|
||||
else:
|
||||
rows = self._conn.execute("SELECT * FROM forum_tags").fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
|
@ -608,13 +676,13 @@ class BackupDatabase:
|
|||
def get_threads_by_parent(self, parent_id: str) -> List[Dict[str, Any]]:
|
||||
"""Returns all threads belonging to a parent channel."""
|
||||
with self._lock:
|
||||
rows = self._conn.execute("SELECT * FROM threads WHERE parent_id = ?", (str(parent_id),)).fetchall()
|
||||
rows = self._conn.execute("SELECT * FROM threads WHERE parent_id = ?", (parse_snowflake(parent_id),)).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
def get_thread(self, thread_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieves a single thread's metadata."""
|
||||
with self._lock:
|
||||
row = self._conn.execute("SELECT * FROM threads WHERE id = ?", (str(thread_id),)).fetchone()
|
||||
row = self._conn.execute("SELECT * FROM threads WHERE id = ?", (parse_snowflake(thread_id),)).fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def get_all_users(self) -> List[Dict[str, Any]]:
|
||||
|
|
@ -624,7 +692,7 @@ class BackupDatabase:
|
|||
|
||||
def get_user(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
row = self._conn.execute("SELECT * FROM users WHERE id = ?", (str(user_id),)).fetchone()
|
||||
row = self._conn.execute("SELECT * FROM users WHERE id = ?", (parse_snowflake(user_id),)).fetchone()
|
||||
if row:
|
||||
data = dict(row)
|
||||
if data.get("roles"):
|
||||
|
|
@ -650,11 +718,11 @@ class BackupDatabase:
|
|||
def get_messages_paged(self, channel_id: str, limit: int = 100, offset: int = 0, after_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
query = "SELECT * FROM messages WHERE channel_id = ?"
|
||||
params = [str(channel_id)]
|
||||
params = [parse_snowflake(channel_id)]
|
||||
|
||||
if after_id:
|
||||
query += " AND id > ?"
|
||||
params.append(str(after_id))
|
||||
params.append(parse_snowflake(after_id))
|
||||
|
||||
query += " ORDER BY id ASC LIMIT ? OFFSET ?"
|
||||
params.extend([limit, offset])
|
||||
|
|
@ -725,7 +793,7 @@ class BackupDatabase:
|
|||
|
||||
def delete_channel_messages(self, channel_id: Union[str, int]):
|
||||
"""Deletes all messages and related metadata for a specific channel and its threads."""
|
||||
cid = str(channel_id)
|
||||
cid = parse_snowflake(channel_id)
|
||||
with self._lock:
|
||||
# 1. Identify all channel IDs involved (parent + all threads)
|
||||
target_ids = [cid]
|
||||
|
|
|
|||
|
|
@ -129,7 +129,7 @@ class MigrationContext:
|
|||
# or a logical subfolder.
|
||||
base_dir = getattr(self, "base_dir", "")
|
||||
|
||||
self.state.set_folder(community_id, clean_name, base_dir=base_dir)
|
||||
self.state.set_folder(community_id, clean_name, self.target_platform, base_dir=base_dir)
|
||||
|
||||
async def start_connections(self):
|
||||
await self.discord_reader.start()
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from pathlib import Path
|
|||
from typing import Optional, Dict, Any
|
||||
import threading
|
||||
import sys
|
||||
from src.core.utils import parse_snowflake
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -15,8 +16,9 @@ class MigrationDatabase:
|
|||
Replaces the memory-bloated and O(N^2) JSON persistence for messages.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Path):
|
||||
def __init__(self, db_path: Path, platform: str = None):
|
||||
self.db_path = db_path
|
||||
self.platform = platform.lower() if platform else None
|
||||
self._local = threading.local()
|
||||
self._init_db()
|
||||
|
||||
|
|
@ -27,38 +29,118 @@ class MigrationDatabase:
|
|||
return self._local.conn
|
||||
|
||||
def _init_db(self):
|
||||
"""Initialize tables if they don't exist."""
|
||||
"""Initialize tables if they don't exist and handle migrations/platform-specific schemas."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 1. Determine active platform and column types
|
||||
# Create metadata table first as we need it for platform tracking
|
||||
cursor.execute("CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)")
|
||||
|
||||
# Load platform from DB if exists
|
||||
cursor.execute("SELECT value FROM metadata WHERE key = ?", ("target_platform",))
|
||||
row = cursor.fetchone()
|
||||
stored_platform = row[0] if row else None
|
||||
|
||||
# If platform provided, update stored platform. If not provided, use stored.
|
||||
active_platform = self.platform or stored_platform or "stoat" # Default to stoat if unknown
|
||||
if self.platform and self.platform != stored_platform:
|
||||
cursor.execute("INSERT OR REPLACE INTO metadata (key, value) VALUES (?, ?)", ("target_platform", active_platform))
|
||||
conn.commit()
|
||||
|
||||
# Define types
|
||||
# source_type is always INTEGER (Discord/Fluxer snowflakes)
|
||||
# target_type is INTEGER for Fluxer, TEXT for Stoat
|
||||
source_type = "INTEGER"
|
||||
target_type = "INTEGER" if active_platform == "fluxer" else "TEXT"
|
||||
|
||||
# 2. Universal ID Migration (TEXT -> INTEGER vs Platform Switch)
|
||||
# Mapping of table names to columns that must match their respective types
|
||||
# key: table, value: (discord_cols, target_cols)
|
||||
id_migrations = {
|
||||
"message_mappings": (["source_msg_id"], ["channel_id", "target_msg_id"]),
|
||||
"thread_mappings": (["source_msg_id"], ["channel_id", "thread_id", "target_msg_id"]),
|
||||
"channel_tracking": ([], ["channel_id", "last_msg_id"]),
|
||||
"thread_tracking": ([], ["channel_id", "thread_id", "last_msg_id"]),
|
||||
"server_mappings": (["source_id"], ["target_id"]),
|
||||
"asset_mappings": (["source_id"], ["target_id"]),
|
||||
"user_alias": (["user_id"], [])
|
||||
}
|
||||
|
||||
for table, (discord_cols, target_cols) in id_migrations.items():
|
||||
cursor.execute(f"SELECT count(*) FROM sqlite_master WHERE type='table' AND name='{table}'")
|
||||
res = cursor.fetchone()
|
||||
if not res or res[0] == 0:
|
||||
continue
|
||||
|
||||
cursor.execute(f"PRAGMA table_info({table})")
|
||||
cols = cursor.fetchall()
|
||||
needs_migration = False
|
||||
for col in cols:
|
||||
c_name, c_type = col[1], col[2]
|
||||
if c_name in discord_cols and c_type != source_type:
|
||||
needs_migration = True
|
||||
break
|
||||
if c_name in target_cols and c_type != target_type:
|
||||
needs_migration = True
|
||||
break
|
||||
|
||||
if needs_migration:
|
||||
logger.info(f"MigrationDatabase: Migrating {table} schema (Platform: {active_platform})")
|
||||
cursor.execute(f"ALTER TABLE {table} RENAME TO {table}_old")
|
||||
|
||||
if table == "message_mappings":
|
||||
cursor.execute(f"CREATE TABLE message_mappings (channel_id {target_type}, source_msg_id {source_type}, target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, source_msg_id))")
|
||||
elif table == "thread_mappings":
|
||||
cursor.execute(f"CREATE TABLE thread_mappings (channel_id {target_type}, thread_id {target_type}, source_msg_id {source_type}, target_msg_id {target_type}, timestamp TEXT, PRIMARY KEY (channel_id, thread_id, source_msg_id))")
|
||||
elif table == "channel_tracking":
|
||||
cursor.execute(f"CREATE TABLE channel_tracking (channel_id {target_type} PRIMARY KEY, last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0)")
|
||||
elif table == "thread_tracking":
|
||||
cursor.execute(f"CREATE TABLE thread_tracking (channel_id {target_type}, thread_id {target_type}, last_msg_id {target_type}, last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0, completed INTEGER DEFAULT 0, PRIMARY KEY (channel_id, thread_id))")
|
||||
elif table == "server_mappings":
|
||||
cursor.execute(f"CREATE TABLE server_mappings (category TEXT, source_id {source_type}, target_id {target_type}, PRIMARY KEY (category, source_id))")
|
||||
elif table == "asset_mappings":
|
||||
cursor.execute(f"CREATE TABLE asset_mappings (category TEXT, source_id {source_type}, target_id {target_type}, PRIMARY KEY (category, source_id))")
|
||||
elif table == "user_alias":
|
||||
cursor.execute(f"CREATE TABLE user_alias (user_id {source_type} PRIMARY KEY, alias TEXT UNIQUE)")
|
||||
|
||||
old_cols = [c[1] for c in cursor.execute(f"PRAGMA table_info({table}_old)").fetchall()]
|
||||
new_cols = [c[1] for c in cursor.execute(f"PRAGMA table_info({table})").fetchall()]
|
||||
common_cols = [c for c in old_cols if c in new_cols]
|
||||
col_str = ", ".join(common_cols)
|
||||
|
||||
cursor.execute(f"INSERT OR IGNORE INTO {table} ({col_str}) SELECT {col_str} FROM {table}_old")
|
||||
cursor.execute(f"DROP TABLE {table}_old")
|
||||
|
||||
# Initial Creation / Ensure Schema Correctness
|
||||
# Table for message mappings: SourceID -> TargetID
|
||||
cursor.execute("""
|
||||
cursor.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS message_mappings (
|
||||
channel_id TEXT,
|
||||
source_msg_id TEXT,
|
||||
target_msg_id TEXT,
|
||||
channel_id {target_type},
|
||||
source_msg_id {source_type},
|
||||
target_msg_id {target_type},
|
||||
timestamp TEXT,
|
||||
PRIMARY KEY (channel_id, source_msg_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Table for thread mappings
|
||||
cursor.execute("""
|
||||
cursor.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS thread_mappings (
|
||||
channel_id TEXT,
|
||||
thread_id TEXT,
|
||||
source_msg_id TEXT,
|
||||
target_msg_id TEXT,
|
||||
channel_id {target_type},
|
||||
thread_id {target_type},
|
||||
source_msg_id {source_type},
|
||||
target_msg_id {target_type},
|
||||
timestamp TEXT,
|
||||
PRIMARY KEY (channel_id, thread_id, source_msg_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Table for per-channel stats and tracking
|
||||
cursor.execute("""
|
||||
cursor.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS channel_tracking (
|
||||
channel_id TEXT PRIMARY KEY,
|
||||
last_msg_id TEXT,
|
||||
channel_id {target_type} PRIMARY KEY,
|
||||
last_msg_id {target_type},
|
||||
last_msg_ts TEXT,
|
||||
msg_count INTEGER DEFAULT 0,
|
||||
file_count INTEGER DEFAULT 0
|
||||
|
|
@ -66,11 +148,11 @@ class MigrationDatabase:
|
|||
""")
|
||||
|
||||
# Table for per-thread stats
|
||||
cursor.execute("""
|
||||
cursor.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS thread_tracking (
|
||||
channel_id TEXT,
|
||||
thread_id TEXT,
|
||||
last_msg_id TEXT,
|
||||
channel_id {target_type},
|
||||
thread_id {target_type},
|
||||
last_msg_id {target_type},
|
||||
last_msg_ts TEXT,
|
||||
msg_count INTEGER DEFAULT 0,
|
||||
file_count INTEGER DEFAULT 0,
|
||||
|
|
@ -79,28 +161,28 @@ class MigrationDatabase:
|
|||
)
|
||||
""")
|
||||
|
||||
# Add completed column if it doesn't exist (backward compatibility for existing resumption DBs)
|
||||
# Add completed column if it doesn't exist (backward compatibility)
|
||||
try:
|
||||
cursor.execute("ALTER TABLE thread_tracking ADD COLUMN completed INTEGER DEFAULT 0")
|
||||
except sqlite3.OperationalError:
|
||||
pass # Already exists
|
||||
|
||||
# Table for server entity mappings (channels, roles, categories)
|
||||
cursor.execute("""
|
||||
cursor.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS server_mappings (
|
||||
category TEXT,
|
||||
source_id TEXT,
|
||||
target_id TEXT,
|
||||
source_id {source_type},
|
||||
target_id {target_type},
|
||||
PRIMARY KEY (category, source_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Table for asset mappings (emojis, stickers)
|
||||
cursor.execute("""
|
||||
cursor.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS asset_mappings (
|
||||
category TEXT,
|
||||
source_id TEXT,
|
||||
target_id TEXT,
|
||||
source_id {source_type},
|
||||
target_id {target_type},
|
||||
PRIMARY KEY (category, source_id)
|
||||
)
|
||||
""")
|
||||
|
|
@ -125,18 +207,10 @@ class MigrationDatabase:
|
|||
# Drop old table
|
||||
cursor.execute("DROP TABLE entity_mappings")
|
||||
|
||||
# Table for general metadata
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS metadata (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
# Table for auto-generated user aliases (user_id -> alias)
|
||||
cursor.execute("""
|
||||
cursor.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS user_alias (
|
||||
user_id TEXT PRIMARY KEY,
|
||||
user_id {source_type} PRIMARY KEY,
|
||||
alias TEXT UNIQUE
|
||||
)
|
||||
""")
|
||||
|
|
@ -152,17 +226,30 @@ class MigrationDatabase:
|
|||
conn = self._get_conn()
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO message_mappings (channel_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?)",
|
||||
(channel_id, source_id, target_id, timestamp)
|
||||
(str(channel_id), parse_snowflake(source_id), str(target_id), timestamp)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_target_message_id(self, channel_id: str, source_id: str) -> Optional[str]:
|
||||
def get_target_message_id(self, channel_id: str, source_id: str) -> Optional[Union[str, int]]:
|
||||
conn = self._get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT target_msg_id FROM message_mappings WHERE channel_id = ? AND source_msg_id = ?",
|
||||
(channel_id, source_id)
|
||||
(str(channel_id), parse_snowflake(source_id))
|
||||
).fetchone()
|
||||
return row["target_msg_id"] if row else None
|
||||
if row:
|
||||
val = row["target_msg_id"]
|
||||
return str(val) if self.platform == "stoat" else val
|
||||
return None
|
||||
|
||||
def get_all_message_mappings(self, channel_id: str) -> Dict[Union[str, int], Union[str, int]]:
|
||||
conn = self._get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT source_msg_id, target_msg_id FROM message_mappings WHERE channel_id = ?",
|
||||
(str(channel_id),)
|
||||
).fetchall()
|
||||
if self.platform == "stoat":
|
||||
return {str(row["source_msg_id"]): str(row["target_msg_id"]) for row in rows}
|
||||
return {row["source_msg_id"]: row["target_msg_id"] for row in rows}
|
||||
|
||||
# --- User Alias Methods ---
|
||||
|
||||
|
|
@ -205,7 +292,7 @@ class MigrationDatabase:
|
|||
conn = self._get_conn()
|
||||
|
||||
# Check for existing alias
|
||||
row = conn.execute("SELECT alias FROM user_alias WHERE user_id = ?", (str(user_id),)).fetchone()
|
||||
row = conn.execute("SELECT alias FROM user_alias WHERE user_id = ?", (parse_snowflake(user_id),)).fetchone()
|
||||
if row:
|
||||
return row["alias"]
|
||||
|
||||
|
|
@ -215,7 +302,7 @@ class MigrationDatabase:
|
|||
new_alias = self._generate_alias()
|
||||
conn.execute(
|
||||
"INSERT INTO user_alias (user_id, alias) VALUES (?, ?)",
|
||||
(str(user_id), new_alias)
|
||||
(parse_snowflake(user_id), new_alias)
|
||||
)
|
||||
conn.commit()
|
||||
return new_alias
|
||||
|
|
@ -239,31 +326,36 @@ class MigrationDatabase:
|
|||
conn = self._get_conn()
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO server_mappings (category, source_id, target_id) VALUES (?, ?, ?)",
|
||||
(category, str(source_id), str(target_id))
|
||||
(category, parse_snowflake(source_id), str(target_id))
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_server_mapping(self, category: str, source_id: str) -> Optional[str]:
|
||||
def get_server_mapping(self, category: str, source_id: str) -> Optional[Union[str, int]]:
|
||||
conn = self._get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT target_id FROM server_mappings WHERE category = ? AND source_id = ?",
|
||||
(category, str(source_id))
|
||||
(category, parse_snowflake(source_id))
|
||||
).fetchone()
|
||||
return row["target_id"] if row else None
|
||||
if row:
|
||||
val = row["target_id"]
|
||||
return str(val) if self.platform == "stoat" else val
|
||||
return None
|
||||
|
||||
def get_all_server_mappings(self, category: str) -> Dict[str, str]:
|
||||
def get_all_server_mappings(self, category: str) -> Dict[Union[str, int], Union[str, int]]:
|
||||
conn = self._get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT source_id, target_id FROM server_mappings WHERE category = ?",
|
||||
(category,)
|
||||
).fetchall()
|
||||
if self.platform == "stoat":
|
||||
return {str(row["source_id"]): str(row["target_id"]) for row in rows}
|
||||
return {row["source_id"]: row["target_id"] for row in rows}
|
||||
|
||||
def delete_server_mapping(self, category: str, source_id: str):
|
||||
conn = self._get_conn()
|
||||
conn.execute(
|
||||
"DELETE FROM server_mappings WHERE category = ? AND source_id = ?",
|
||||
(category, str(source_id))
|
||||
(category, parse_snowflake(source_id))
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
|
@ -281,31 +373,36 @@ class MigrationDatabase:
|
|||
conn = self._get_conn()
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO asset_mappings (category, source_id, target_id) VALUES (?, ?, ?)",
|
||||
(category, str(source_id), str(target_id))
|
||||
(category, parse_snowflake(source_id), str(target_id))
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_asset_mapping(self, category: str, source_id: str) -> Optional[str]:
|
||||
def get_asset_mapping(self, category: str, source_id: str) -> Optional[Union[str, int]]:
|
||||
conn = self._get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT target_id FROM asset_mappings WHERE category = ? AND source_id = ?",
|
||||
(category, str(source_id))
|
||||
(category, parse_snowflake(source_id))
|
||||
).fetchone()
|
||||
return row["target_id"] if row else None
|
||||
if row:
|
||||
val = row["target_id"]
|
||||
return str(val) if self.platform == "stoat" else val
|
||||
return None
|
||||
|
||||
def get_all_asset_mappings(self, category: str) -> Dict[str, str]:
|
||||
def get_all_asset_mappings(self, category: str) -> Dict[Union[str, int], Union[str, int]]:
|
||||
conn = self._get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT source_id, target_id FROM asset_mappings WHERE category = ?",
|
||||
(category,)
|
||||
).fetchall()
|
||||
if self.platform == "stoat":
|
||||
return {str(row["source_id"]): str(row["target_id"]) for row in rows}
|
||||
return {row["source_id"]: row["target_id"] for row in rows}
|
||||
|
||||
def delete_asset_mapping(self, category: str, source_id: str):
|
||||
conn = self._get_conn()
|
||||
conn.execute(
|
||||
"DELETE FROM asset_mappings WHERE category = ? AND source_id = ?",
|
||||
(category, str(source_id))
|
||||
(category, parse_snowflake(source_id))
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
|
@ -332,23 +429,23 @@ class MigrationDatabase:
|
|||
def update_channel_tracking(self, channel_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0):
|
||||
conn = self._get_conn()
|
||||
# Initialize if missing
|
||||
conn.execute("INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)", (channel_id,))
|
||||
conn.execute("INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)", (str(channel_id),))
|
||||
|
||||
if last_msg_id:
|
||||
conn.execute("UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?", (last_msg_id, channel_id))
|
||||
conn.execute("UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?", (str(last_msg_id), str(channel_id)))
|
||||
if last_msg_ts:
|
||||
conn.execute("UPDATE channel_tracking SET last_msg_ts = ? WHERE channel_id = ?", (last_msg_ts, channel_id))
|
||||
conn.execute("UPDATE channel_tracking SET last_msg_ts = ? WHERE channel_id = ?", (last_msg_ts, str(channel_id)))
|
||||
|
||||
if msg_inc != 0 or file_inc != 0:
|
||||
conn.execute(
|
||||
"UPDATE channel_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ?",
|
||||
(msg_inc, file_inc, channel_id)
|
||||
(msg_inc, file_inc, str(channel_id))
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_channel_tracking(self, channel_id: str) -> Dict[str, Any]:
|
||||
conn = self._get_conn()
|
||||
row = conn.execute("SELECT * FROM channel_tracking WHERE channel_id = ?", (channel_id,)).fetchone()
|
||||
row = conn.execute("SELECT * FROM channel_tracking WHERE channel_id = ?", (str(channel_id),)).fetchone()
|
||||
if row:
|
||||
return dict(row)
|
||||
return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0}
|
||||
|
|
@ -358,39 +455,42 @@ class MigrationDatabase:
|
|||
conn = self._get_conn()
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO thread_mappings (channel_id, thread_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?, ?)",
|
||||
(channel_id, thread_id, source_id, target_id, timestamp)
|
||||
(str(channel_id), str(thread_id), parse_snowflake(source_id), str(target_id), timestamp)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_target_thread_message_id(self, channel_id: str, thread_id: str, source_id: str) -> Optional[str]:
|
||||
def get_target_thread_message_id(self, channel_id: str, thread_id: str, source_id: str) -> Optional[Union[str, int]]:
|
||||
conn = self._get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT target_msg_id FROM thread_mappings WHERE channel_id = ? AND thread_id = ? AND source_msg_id = ?",
|
||||
(channel_id, thread_id, source_id)
|
||||
(str(channel_id), str(thread_id), parse_snowflake(source_id))
|
||||
).fetchone()
|
||||
return row["target_msg_id"] if row else None
|
||||
if row:
|
||||
val = row["target_msg_id"]
|
||||
return str(val) if self.platform == "stoat" else val
|
||||
return None
|
||||
|
||||
def update_thread_tracking(self, channel_id: str, thread_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0, completed: int = None):
|
||||
conn = self._get_conn()
|
||||
conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (channel_id, thread_id))
|
||||
conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (str(channel_id), str(thread_id)))
|
||||
|
||||
if last_msg_id:
|
||||
conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_id, channel_id, thread_id))
|
||||
conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (str(last_msg_id), str(channel_id), str(thread_id)))
|
||||
if last_msg_ts:
|
||||
conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, channel_id, thread_id))
|
||||
conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, str(channel_id), str(thread_id)))
|
||||
if completed is not None:
|
||||
conn.execute("UPDATE thread_tracking SET completed = ? WHERE channel_id = ? AND thread_id = ?", (completed, channel_id, thread_id))
|
||||
conn.execute("UPDATE thread_tracking SET completed = ? WHERE channel_id = ? AND thread_id = ?", (completed, str(channel_id), str(thread_id)))
|
||||
|
||||
if msg_inc != 0 or file_inc != 0:
|
||||
conn.execute(
|
||||
"UPDATE thread_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ? AND thread_id = ?",
|
||||
(msg_inc, file_inc, channel_id, thread_id)
|
||||
(msg_inc, file_inc, str(channel_id), str(thread_id))
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_thread_tracking(self, channel_id: str, thread_id: str) -> Dict[str, Any]:
|
||||
conn = self._get_conn()
|
||||
row = conn.execute("SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?", (channel_id, thread_id)).fetchone()
|
||||
row = conn.execute("SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?", (str(channel_id), str(thread_id))).fetchone()
|
||||
if row:
|
||||
return dict(row)
|
||||
return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0}
|
||||
|
|
@ -398,10 +498,10 @@ class MigrationDatabase:
|
|||
def clear_channel_data(self, channel_id: str):
|
||||
"""Purge all mappings and tracking data for a specific channel and its threads."""
|
||||
conn = self._get_conn()
|
||||
conn.execute("DELETE FROM message_mappings WHERE channel_id = ?", (channel_id,))
|
||||
conn.execute("DELETE FROM thread_mappings WHERE channel_id = ?", (channel_id,))
|
||||
conn.execute("DELETE FROM channel_tracking WHERE channel_id = ?", (channel_id,))
|
||||
conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (channel_id,))
|
||||
conn.execute("DELETE FROM message_mappings WHERE channel_id = ?", (str(channel_id),))
|
||||
conn.execute("DELETE FROM thread_mappings WHERE channel_id = ?", (str(channel_id),))
|
||||
conn.execute("DELETE FROM channel_tracking WHERE channel_id = ?", (str(channel_id),))
|
||||
conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (str(channel_id),))
|
||||
conn.commit()
|
||||
logger.info(f"Cleared all tracking and mapping data for channel: {channel_id}")
|
||||
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ class DiscordExporter:
|
|||
"name": r.name,
|
||||
"color": r.color.value,
|
||||
"position": r.position,
|
||||
"permissions": str(r.permissions.value),
|
||||
"permissions": r.permissions.value,
|
||||
"hoist": 1 if r.hoist else 0,
|
||||
"mentionable": 1 if r.mentionable else 0
|
||||
})
|
||||
|
|
@ -563,9 +563,10 @@ class DiscordExporter:
|
|||
})
|
||||
|
||||
# 5. Message data
|
||||
from src.core.utils import parse_snowflake
|
||||
message_reference = None
|
||||
if msg.reference and msg.reference.message_id:
|
||||
message_reference = str(msg.reference.message_id)
|
||||
message_reference = parse_snowflake(msg.reference.message_id)
|
||||
|
||||
# 5.5 Forwarded snapshots
|
||||
content = msg.content or ""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Dict, Optional, Any, Union, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.core.database import MigrationDatabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -119,23 +122,23 @@ class MigrationState:
|
|||
|
||||
# --- Properties for backward compatibility ---
|
||||
@property
|
||||
def channel_map(self) -> Dict[str, str]:
|
||||
def channel_map(self) -> Dict[Union[str, int], Union[str, int]]:
|
||||
return self.db.get_all_server_mappings("channel") if self.db else {}
|
||||
|
||||
@property
|
||||
def category_map(self) -> Dict[str, str]:
|
||||
def category_map(self) -> Dict[Union[str, int], Union[str, int]]:
|
||||
return self.db.get_all_server_mappings("category") if self.db else {}
|
||||
|
||||
@property
|
||||
def role_map(self) -> Dict[str, str]:
|
||||
def role_map(self) -> Dict[Union[str, int], Union[str, int]]:
|
||||
return self.db.get_all_server_mappings("role") if self.db else {}
|
||||
|
||||
@property
|
||||
def emoji_map(self) -> Dict[str, str]:
|
||||
def emoji_map(self) -> Dict[Union[str, int], Union[str, int]]:
|
||||
return self.db.get_all_asset_mappings("emoji") if self.db else {}
|
||||
|
||||
@property
|
||||
def sticker_map(self) -> Dict[str, str]:
|
||||
def sticker_map(self) -> Dict[Union[str, int], Union[str, int]]:
|
||||
return self.db.get_all_asset_mappings("sticker") if self.db else {}
|
||||
|
||||
@property
|
||||
|
|
@ -153,7 +156,7 @@ class MigrationState:
|
|||
if self._ensure_db():
|
||||
self.db.set_message_mapping(str(target_channel_id), str(discord_id), str(target_id))
|
||||
|
||||
def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | None:
|
||||
def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | int | None:
|
||||
if self._ensure_db():
|
||||
return self.db.get_target_message_id(str(target_channel_id), str(discord_id))
|
||||
return None
|
||||
|
|
@ -161,7 +164,7 @@ class MigrationState:
|
|||
def set_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str):
|
||||
self.set_target_message_mapping(target_channel_id, discord_id, target_id)
|
||||
|
||||
def get_fluxer_message_id(self, target_channel_id: str, discord_id: str) -> str | None:
|
||||
def get_fluxer_message_id(self, target_channel_id: str, discord_id: str) -> str | int | None:
|
||||
return self.get_target_message_id(target_channel_id, discord_id)
|
||||
|
||||
def increment_stats(self, target_channel_id: str, messages: int = 1, files: int = 0):
|
||||
|
|
@ -258,7 +261,7 @@ class MigrationState:
|
|||
if self._ensure_db():
|
||||
self.db.clear_channel_data(str(target_channel_id))
|
||||
|
||||
def set_folder(self, server_id: str, clean_name: str, base_dir: Path | str = ""):
|
||||
def set_folder(self, server_id: str, clean_name: str, platform: str = "stoat", base_dir: Path | str = ""):
|
||||
"""
|
||||
Initializes the SQLite database based on community name and ID.
|
||||
Filename: {name}-{id}.db (Flat structure)
|
||||
|
|
@ -294,7 +297,7 @@ class MigrationState:
|
|||
from src.core.database import MigrationDatabase
|
||||
if self.db:
|
||||
self.db.close()
|
||||
self.db = MigrationDatabase(db_path)
|
||||
self.db = MigrationDatabase(db_path, platform=platform)
|
||||
logger.info(f"Initialized SQLite database at {db_path}")
|
||||
|
||||
def get_user_alias(self, user_id: str) -> str | None:
|
||||
|
|
|
|||
|
|
@ -1,14 +1,29 @@
|
|||
from typing import Any, Optional
|
||||
import re
|
||||
import logging
|
||||
from src.core.state import MigrationState
|
||||
|
||||
def parse_snowflake(value: Any) -> Optional[int]:
|
||||
"""Safely parses a Discord ID (Snowflake) from any input, handling 'None' strings."""
|
||||
if value is None:
|
||||
return None
|
||||
s = str(value).strip()
|
||||
if not s or s.lower() == "none" or s == "NULL":
|
||||
return None
|
||||
try:
|
||||
return int(s)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def resolve_discord_links(content: str, state: MigrationState, platform: str, target_server_id: str) -> str:
|
||||
def resolve_discord_links(content: str, state, platform: str, target_server_id: str) -> str:
|
||||
"""
|
||||
Finds Discord message/channel links and resolves them to the target platform
|
||||
if they have been migrated.
|
||||
"""
|
||||
from src.core.state import MigrationState
|
||||
if not isinstance(state, MigrationState):
|
||||
logger.warning(f"resolve_discord_links: state is not MigrationState (type: {type(state)})")
|
||||
if not content:
|
||||
return content
|
||||
|
||||
|
|
|
|||
|
|
@ -18,10 +18,10 @@ async def sync_assets_state(context: MigrationContext):
|
|||
fluxer_stickers = await context.fluxer_writer.client.get_guild_stickers(context.config.fluxer_server_id)
|
||||
|
||||
# Build name -> id maps and ID sets for Fluxer for fast lookup
|
||||
fluxer_emoji_map = {e.get("name"): str(e.get("id")) for e in fluxer_emojis if e.get("name")}
|
||||
fluxer_sticker_map = {s.get("name"): str(s.get("id")) for s in fluxer_stickers if s.get("name")}
|
||||
fluxer_emoji_ids = {str(e.get("id")) for e in fluxer_emojis}
|
||||
fluxer_sticker_ids = {str(s.get("id")) for s in fluxer_stickers}
|
||||
fluxer_emoji_map = {e.get("name"): e.get("id") for e in fluxer_emojis if e.get("name")}
|
||||
fluxer_sticker_map = {s.get("name"): s.get("id") for s in fluxer_stickers if s.get("name")}
|
||||
fluxer_emoji_ids = {e.get("id") for e in fluxer_emojis}
|
||||
fluxer_sticker_ids = {s.get("id") for s in fluxer_stickers}
|
||||
|
||||
updates = 0
|
||||
removals = 0
|
||||
|
|
|
|||
|
|
@ -73,8 +73,8 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
|
|||
cid = int(match.group(1))
|
||||
|
||||
# 1. Check if channel is mapped in state
|
||||
if channel_map and str(cid) in channel_map:
|
||||
return f"<#{channel_map[str(cid)]}>"
|
||||
if channel_map and cid in channel_map:
|
||||
return f"<#{channel_map[cid]}>"
|
||||
|
||||
# 2. Try to resolve channel name from pre-fetched names
|
||||
name = None
|
||||
|
|
@ -100,7 +100,7 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
|
|||
def replace_emoji(match):
|
||||
animated = match.group(1) == "a"
|
||||
name = match.group(2)
|
||||
eid = match.group(3)
|
||||
eid = int(match.group(3))
|
||||
|
||||
if emoji_map and eid in emoji_map:
|
||||
target_eid = emoji_map[eid]
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@ async def sync_roles_state(context: MigrationContext):
|
|||
fluxer_roles = await context.fluxer_writer.client.get_guild_roles(context.config.fluxer_server_id)
|
||||
|
||||
# Build name -> id maps and ID sets for Fluxer for fast lookup
|
||||
fluxer_role_map = {r.get("name"): str(r.get("id")) for r in fluxer_roles if r.get("name")}
|
||||
fluxer_role_ids = {str(r.get("id")) for r in fluxer_roles}
|
||||
fluxer_role_map = {r.get("name"): r.get("id") for r in fluxer_roles if r.get("name")}
|
||||
fluxer_role_ids = {r.get("id") for r in fluxer_roles}
|
||||
|
||||
updates = 0
|
||||
removals = 0
|
||||
|
|
|
|||
|
|
@ -37,8 +37,7 @@ async def sync_assets_state(context: MigrationContext):
|
|||
|
||||
if stoat_id:
|
||||
if stoat_id not in stoat_emoji_ids:
|
||||
context.state.emoji_map.pop(discord_id, None)
|
||||
context.state.save_state()
|
||||
context.state.remove_emoji_mapping(discord_id)
|
||||
removals += 1
|
||||
elif emoji.name in stoat_emoji_map:
|
||||
context.state.set_target_emoji_mapping(discord_id, stoat_emoji_map[emoji.name])
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue