switch to sqlite for migration states
This commit is contained in:
parent
d0a8a1acb4
commit
de28ecdbb1
5 changed files with 513 additions and 342 deletions
|
|
@ -696,7 +696,15 @@ class BackupReader:
|
||||||
self.backup_path = Path(backup_path)
|
self.backup_path = Path(backup_path)
|
||||||
self.guild: BackupGuild | None = None
|
self.guild: BackupGuild | None = None
|
||||||
|
|
||||||
# Internal caches populated by start()
|
self._thread_info: Dict[int, Dict[str, Any]] = {} # channel_id -> metadata (like name, parentID)
|
||||||
|
|
||||||
|
# Lazy loading flags
|
||||||
|
self._roles_loaded = False
|
||||||
|
self._structure_loaded = False
|
||||||
|
self._assets_loaded = False
|
||||||
|
self._members_loaded = False
|
||||||
|
|
||||||
|
# Internal storage
|
||||||
self._categories: List[BackupCategory] = []
|
self._categories: List[BackupCategory] = []
|
||||||
self._channels: List[BackupChannel] = []
|
self._channels: List[BackupChannel] = []
|
||||||
self._roles: List[BackupRole] = []
|
self._roles: List[BackupRole] = []
|
||||||
|
|
@ -704,12 +712,11 @@ class BackupReader:
|
||||||
self._stickers: List[BackupSticker] = []
|
self._stickers: List[BackupSticker] = []
|
||||||
self._members: List[BackupMember] = []
|
self._members: List[BackupMember] = []
|
||||||
self._member_map: Dict[int, BackupMember] = {}
|
self._member_map: Dict[int, BackupMember] = {}
|
||||||
self._thread_info: Dict[int, Dict[str, Any]] = {} # channel_id -> metadata (like name, parentID)
|
|
||||||
|
|
||||||
# ── startup ──────────────────────────────────────────────────────────
|
# ── startup ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""Loads all JSON files from the backup directory into memory."""
|
"""Initializes the backup path and loads the server profile."""
|
||||||
bp = self.backup_path
|
bp = self.backup_path
|
||||||
|
|
||||||
# 1. Server profile -> BackupGuild
|
# 1. Server profile -> BackupGuild
|
||||||
|
|
@ -717,60 +724,95 @@ class BackupReader:
|
||||||
if profile_file.exists():
|
if profile_file.exists():
|
||||||
profile = json.loads(profile_file.read_text(encoding="utf-8"))
|
profile = json.loads(profile_file.read_text(encoding="utf-8"))
|
||||||
self.guild = BackupGuild(profile, bp, reader=self)
|
self.guild = BackupGuild(profile, bp, reader=self)
|
||||||
logger.info(f"[Backup] Loaded server profile: {self.guild.name} ({self.guild.id})")
|
logger.info(f"[Backup] Initialized server: {self.guild.name} ({self.guild.id})")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[Backup] server_profile/profile.json not found in {bp}")
|
logger.warning(f"[Backup] server_profile/profile.json not found in {bp}")
|
||||||
self.guild = None
|
self.guild = None
|
||||||
|
|
||||||
# 2. Roles
|
@property
|
||||||
roles_file = bp / "server_profile" / "roles.json"
|
def roles(self) -> List[BackupRole]:
|
||||||
if roles_file.exists():
|
if not self._roles_loaded:
|
||||||
roles_data = json.loads(roles_file.read_text(encoding="utf-8"))
|
roles_file = self.backup_path / "server_profile" / "roles.json"
|
||||||
self._roles = [BackupRole(r) for r in roles_data]
|
if roles_file.exists():
|
||||||
logger.info(f"[Backup] Loaded {len(self._roles)} roles")
|
logger.info(f"[Backup] Lazy-loading roles...")
|
||||||
|
roles_data = json.loads(roles_file.read_text(encoding="utf-8"))
|
||||||
|
self._roles = [BackupRole(r) for r in roles_data]
|
||||||
|
self._roles_loaded = True
|
||||||
|
return self._roles
|
||||||
|
|
||||||
# 3. Structure -> categories + channels
|
@property
|
||||||
struct_file = bp / "server_profile" / "structure.json"
|
def categories(self) -> List[BackupCategory]:
|
||||||
|
self._ensure_structure_loaded()
|
||||||
|
return self._categories
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channels(self) -> List[BackupChannel]:
|
||||||
|
self._ensure_structure_loaded()
|
||||||
|
return self._channels
|
||||||
|
|
||||||
|
def _ensure_structure_loaded(self):
|
||||||
|
if self._structure_loaded:
|
||||||
|
return
|
||||||
|
struct_file = self.backup_path / "server_profile" / "structure.json"
|
||||||
if struct_file.exists():
|
if struct_file.exists():
|
||||||
|
logger.info(f"[Backup] Lazy-loading server structure...")
|
||||||
structure = json.loads(struct_file.read_text(encoding="utf-8"))
|
structure = json.loads(struct_file.read_text(encoding="utf-8"))
|
||||||
for cat_data in structure:
|
for cat_data in structure:
|
||||||
cat = BackupCategory(cat_data)
|
cat = BackupCategory(cat_data)
|
||||||
if cat.id != 0: # skip 'uncategorized' as a real category
|
if cat.id != 0:
|
||||||
self._categories.append(cat)
|
self._categories.append(cat)
|
||||||
|
|
||||||
for ch_data in cat_data.get("channels", []):
|
for ch_data in cat_data.get("channels", []):
|
||||||
ch_cat_id = cat.id if cat.id != 0 else None
|
ch_cat_id = cat.id if cat.id != 0 else None
|
||||||
channel = BackupChannel(ch_data, category_id=ch_cat_id, guild=self.guild)
|
channel = BackupChannel(ch_data, category_id=ch_cat_id, guild=self.guild)
|
||||||
self._channels.append(channel)
|
self._channels.append(channel)
|
||||||
|
self._structure_loaded = True
|
||||||
|
|
||||||
logger.info(f"[Backup] Loaded {len(self._categories)} categories, "
|
@property
|
||||||
f"{len(self._channels)} channels")
|
def emojis(self) -> List[BackupEmoji]:
|
||||||
|
self._ensure_assets_loaded()
|
||||||
|
return self._emojis
|
||||||
|
|
||||||
# 4. Assets (emojis + stickers)
|
@property
|
||||||
assets_file = bp / "server_profile" / "assets.json"
|
def stickers(self) -> List[BackupSticker]:
|
||||||
media_dir = bp / "server_profile" / "assets"
|
self._ensure_assets_loaded()
|
||||||
|
return self._stickers
|
||||||
|
|
||||||
|
def _ensure_assets_loaded(self):
|
||||||
|
if self._assets_loaded:
|
||||||
|
return
|
||||||
|
assets_file = self.backup_path / "server_profile" / "assets.json"
|
||||||
|
media_dir = self.backup_path / "server_profile" / "assets"
|
||||||
if assets_file.exists():
|
if assets_file.exists():
|
||||||
|
logger.info(f"[Backup] Lazy-loading assets...")
|
||||||
assets = json.loads(assets_file.read_text(encoding="utf-8"))
|
assets = json.loads(assets_file.read_text(encoding="utf-8"))
|
||||||
self._emojis = [BackupEmoji(e, media_dir) for e in assets.get("emojis", [])]
|
self._emojis = [BackupEmoji(e, media_dir) for e in assets.get("emojis", [])]
|
||||||
self._stickers = [BackupSticker(s, media_dir) for s in assets.get("stickers", [])]
|
self._stickers = [BackupSticker(s, media_dir) for s in assets.get("stickers", [])]
|
||||||
logger.info(f"[Backup] Loaded {len(self._emojis)} emojis, "
|
self._assets_loaded = True
|
||||||
f"{len(self._stickers)} stickers")
|
|
||||||
|
|
||||||
# 5. Users
|
@property
|
||||||
user_info_file = bp / "message_backup" / "users" / "user_info.json"
|
def members(self) -> List[BackupMember]:
|
||||||
|
self._ensure_members_loaded()
|
||||||
|
return self._members
|
||||||
|
|
||||||
|
def _ensure_members_loaded(self):
|
||||||
|
if self._members_loaded:
|
||||||
|
return
|
||||||
|
user_info_file = self.backup_path / "message_backup" / "users" / "user_info.json"
|
||||||
if user_info_file.exists():
|
if user_info_file.exists():
|
||||||
|
logger.info(f"[Backup] Lazy-loading members...")
|
||||||
try:
|
try:
|
||||||
users = json.loads(user_info_file.read_text(encoding="utf-8"))
|
users = json.loads(user_info_file.read_text(encoding="utf-8"))
|
||||||
backup_root = bp / "message_backup"
|
backup_root = self.backup_path / "message_backup"
|
||||||
for u in users:
|
for u in users:
|
||||||
user_role_ids = {int(r["id"]) for r in u.get("userRoles", [])}
|
user_role_ids = {int(r["id"]) for r in u.get("userRoles", [])}
|
||||||
role_objs = [r for r in self._roles if r.id in user_role_ids]
|
# Note: this triggers roles lazy load
|
||||||
|
role_objs = [r for r in self.roles if r.id in user_role_ids]
|
||||||
member = BackupMember(u, role_objects=role_objs, avatar_base=backup_root)
|
member = BackupMember(u, role_objects=role_objs, avatar_base=backup_root)
|
||||||
self._members.append(member)
|
self._members.append(member)
|
||||||
self._member_map[member.id] = member
|
self._member_map[member.id] = member
|
||||||
logger.info(f"[Backup] Loaded {len(self._members)} users")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[Backup] Failed to load user_info.json: {e}")
|
logger.warning(f"[Backup] Failed to lazy-load user_info.json: {e}")
|
||||||
|
self._members_loaded = True
|
||||||
|
|
||||||
# ── validation ───────────────────────────────────────────────────────
|
# ── validation ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -89,6 +89,16 @@ class MigrationContext:
|
||||||
"target_community_name": t_valid.get("community_name"),
|
"target_community_name": t_valid.get("community_name"),
|
||||||
"target_permissions": t_valid.get("permissions", {})
|
"target_permissions": t_valid.get("permissions", {})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# CONSISTENCY: Once target metadata is known, initialize the flat SQLite DB.
|
||||||
|
if results["target_community"] and results["target_community_name"]:
|
||||||
|
import re
|
||||||
|
clean_name = re.sub(r'[^\w\s-]', '', results["target_community_name"]).strip()
|
||||||
|
clean_name = re.sub(r'[-\s]+', '_', clean_name)
|
||||||
|
db_community_id = str(self.config.target_server_id or "")
|
||||||
|
self.state.set_folder(db_community_id, clean_name, base_dir=base_dir)
|
||||||
|
|
||||||
|
return results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Validation failed with exception: {e}")
|
logger.error(f"Validation failed with exception: {e}")
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
237
src/core/database.py
Normal file
237
src/core/database.py
Normal file
|
|
@ -0,0 +1,237 @@
|
||||||
|
import sqlite3
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
import threading
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class MigrationDatabase:
|
||||||
|
"""
|
||||||
|
SQLite-based persistence for large-scale migration mappings and stats.
|
||||||
|
Replaces the memory-bloated and O(N^2) JSON persistence for messages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_local = threading.local()
|
||||||
|
|
||||||
|
def __init__(self, db_path: Path):
|
||||||
|
self.db_path = db_path
|
||||||
|
self._init_db()
|
||||||
|
|
||||||
|
def _get_conn(self) -> sqlite3.Connection:
|
||||||
|
if not hasattr(self._local, "conn"):
|
||||||
|
self._local.conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||||
|
self._local.conn.row_factory = sqlite3.Row
|
||||||
|
return self._local.conn
|
||||||
|
|
||||||
|
def _init_db(self):
|
||||||
|
"""Initialize tables if they don't exist."""
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# Table for message mappings: SourceID -> TargetID
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS message_mappings (
|
||||||
|
channel_id TEXT,
|
||||||
|
source_msg_id TEXT,
|
||||||
|
target_msg_id TEXT,
|
||||||
|
timestamp TEXT,
|
||||||
|
PRIMARY KEY (channel_id, source_msg_id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Table for thread mappings
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS thread_mappings (
|
||||||
|
channel_id TEXT,
|
||||||
|
thread_id TEXT,
|
||||||
|
source_msg_id TEXT,
|
||||||
|
target_msg_id TEXT,
|
||||||
|
timestamp TEXT,
|
||||||
|
PRIMARY KEY (channel_id, thread_id, source_msg_id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Table for per-channel stats and tracking
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS channel_tracking (
|
||||||
|
channel_id TEXT PRIMARY KEY,
|
||||||
|
last_msg_id TEXT,
|
||||||
|
last_msg_ts TEXT,
|
||||||
|
msg_count INTEGER DEFAULT 0,
|
||||||
|
file_count INTEGER DEFAULT 0
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Table for per-thread stats
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS thread_tracking (
|
||||||
|
channel_id TEXT,
|
||||||
|
thread_id TEXT,
|
||||||
|
last_msg_id TEXT,
|
||||||
|
last_msg_ts TEXT,
|
||||||
|
msg_count INTEGER DEFAULT 0,
|
||||||
|
file_count INTEGER DEFAULT 0,
|
||||||
|
PRIMARY KEY (channel_id, thread_id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Table for entity mappings (channels, roles, etc.)
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS entity_mappings (
|
||||||
|
category TEXT,
|
||||||
|
source_id TEXT,
|
||||||
|
target_id TEXT,
|
||||||
|
PRIMARY KEY (category, source_id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Table for general metadata
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS metadata (
|
||||||
|
key TEXT PRIMARY KEY,
|
||||||
|
value TEXT
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def set_message_mapping(self, channel_id: str, source_id: str, target_id: str, timestamp: str = None):
|
||||||
|
conn = self._get_conn()
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR REPLACE INTO message_mappings (channel_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?)",
|
||||||
|
(channel_id, source_id, target_id, timestamp)
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def get_target_message_id(self, channel_id: str, source_id: str) -> Optional[str]:
|
||||||
|
conn = self._get_conn()
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT target_msg_id FROM message_mappings WHERE channel_id = ? AND source_msg_id = ?",
|
||||||
|
(channel_id, source_id)
|
||||||
|
).fetchone()
|
||||||
|
return row["target_msg_id"] if row else None
|
||||||
|
|
||||||
|
# --- New Entity Mapping Methods ---
|
||||||
|
|
||||||
|
def set_entity_mapping(self, category: str, source_id: str, target_id: str):
|
||||||
|
conn = self._get_conn()
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR REPLACE INTO entity_mappings (category, source_id, target_id) VALUES (?, ?, ?)",
|
||||||
|
(category, str(source_id), str(target_id))
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def get_entity_mapping(self, category: str, source_id: str) -> Optional[str]:
|
||||||
|
conn = self._get_conn()
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT target_id FROM entity_mappings WHERE category = ? AND source_id = ?",
|
||||||
|
(category, str(source_id))
|
||||||
|
).fetchone()
|
||||||
|
return row["target_id"] if row else None
|
||||||
|
|
||||||
|
def get_all_entity_mappings(self, category: str) -> Dict[str, str]:
|
||||||
|
conn = self._get_conn()
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT source_id, target_id FROM entity_mappings WHERE category = ?",
|
||||||
|
(category,)
|
||||||
|
).fetchall()
|
||||||
|
return {row["source_id"]: row["target_id"] for row in rows}
|
||||||
|
|
||||||
|
def delete_entity_mapping(self, category: str, source_id: str):
|
||||||
|
conn = self._get_conn()
|
||||||
|
conn.execute(
|
||||||
|
"DELETE FROM entity_mappings WHERE category = ? AND source_id = ?",
|
||||||
|
(category, str(source_id))
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def clear_entities(self, category: str = None):
|
||||||
|
conn = self._get_conn()
|
||||||
|
if category:
|
||||||
|
conn.execute("DELETE FROM entity_mappings WHERE category = ?", (category,))
|
||||||
|
else:
|
||||||
|
conn.execute("DELETE FROM entity_mappings")
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
# --- Metadata Methods ---
|
||||||
|
|
||||||
|
def set_metadata(self, key: str, value: str):
|
||||||
|
conn = self._get_conn()
|
||||||
|
conn.execute("INSERT OR REPLACE INTO metadata (key, value) VALUES (?, ?)", (key, str(value) if value is not None else None))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def get_metadata(self, key: str) -> Optional[str]:
|
||||||
|
conn = self._get_conn()
|
||||||
|
row = conn.execute("SELECT value FROM metadata WHERE key = ?", (key,)).fetchone()
|
||||||
|
return row["value"] if row else None
|
||||||
|
|
||||||
|
def update_channel_tracking(self, channel_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0):
|
||||||
|
conn = self._get_conn()
|
||||||
|
# Initialize if missing
|
||||||
|
conn.execute("INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)", (channel_id,))
|
||||||
|
|
||||||
|
if last_msg_id:
|
||||||
|
conn.execute("UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?", (last_msg_id, channel_id))
|
||||||
|
if last_msg_ts:
|
||||||
|
conn.execute("UPDATE channel_tracking SET last_msg_ts = ? WHERE channel_id = ?", (last_msg_ts, channel_id))
|
||||||
|
|
||||||
|
if msg_inc != 0 or file_inc != 0:
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE channel_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ?",
|
||||||
|
(msg_inc, file_inc, channel_id)
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def get_channel_tracking(self, channel_id: str) -> Dict[str, Any]:
|
||||||
|
conn = self._get_conn()
|
||||||
|
row = conn.execute("SELECT * FROM channel_tracking WHERE channel_id = ?", (channel_id,)).fetchone()
|
||||||
|
if row:
|
||||||
|
return dict(row)
|
||||||
|
return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0}
|
||||||
|
|
||||||
|
# Thread methods similar to channel methods
|
||||||
|
def set_thread_message_mapping(self, channel_id: str, thread_id: str, source_id: str, target_id: str, timestamp: str = None):
|
||||||
|
conn = self._get_conn()
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR REPLACE INTO thread_mappings (channel_id, thread_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?, ?)",
|
||||||
|
(channel_id, thread_id, source_id, target_id, timestamp)
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def get_target_thread_message_id(self, channel_id: str, thread_id: str, source_id: str) -> Optional[str]:
|
||||||
|
conn = self._get_conn()
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT target_msg_id FROM thread_mappings WHERE channel_id = ? AND thread_id = ? AND source_msg_id = ?",
|
||||||
|
(channel_id, thread_id, source_id)
|
||||||
|
).fetchone()
|
||||||
|
return row["target_msg_id"] if row else None
|
||||||
|
|
||||||
|
def update_thread_tracking(self, channel_id: str, thread_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0):
|
||||||
|
conn = self._get_conn()
|
||||||
|
conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (channel_id, thread_id))
|
||||||
|
|
||||||
|
if last_msg_id:
|
||||||
|
conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_id, channel_id, thread_id))
|
||||||
|
if last_msg_ts:
|
||||||
|
conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, channel_id, thread_id))
|
||||||
|
|
||||||
|
if msg_inc != 0 or file_inc != 0:
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE thread_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ? AND thread_id = ?",
|
||||||
|
(msg_inc, file_inc, channel_id, thread_id)
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def get_thread_tracking(self, channel_id: str, thread_id: str) -> Dict[str, Any]:
|
||||||
|
conn = self._get_conn()
|
||||||
|
row = conn.execute("SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?", (channel_id, thread_id)).fetchone()
|
||||||
|
if row:
|
||||||
|
return dict(row)
|
||||||
|
return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0}
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if hasattr(self._local, "conn"):
|
||||||
|
self._local.conn.close()
|
||||||
|
del self._local.conn
|
||||||
|
|
@ -1,366 +1,248 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class MigrationState:
|
class MigrationState:
|
||||||
"""Manages persistence of the migration state to allow resumability."""
|
"""Manages persistence of the migration state to allow resumability.
|
||||||
|
Uses SQLite for ALL mappings and metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, state_file: str | Path = "", messages_file: str | Path = ""):
|
def __init__(self):
|
||||||
self.state_file: Path | None = Path(state_file) if state_file else None
|
# database instance for all persistence
|
||||||
self.messages_file: Path | None = Path(messages_file) if messages_file else None
|
self.db: Optional['MigrationDatabase'] = None
|
||||||
|
|
||||||
# mappings: discord_id -> fluxer_id
|
def _ensure_db(self):
|
||||||
self.channel_map: Dict[str, str] = {}
|
if not self.db:
|
||||||
self.category_map: Dict[str, str] = {}
|
logger.warning("MigrationState: Accessing database before initialization")
|
||||||
self.role_map: Dict[str, str] = {}
|
return False
|
||||||
self.emoji_map: Dict[str, str] = {}
|
return True
|
||||||
self.sticker_map: Dict[str, str] = {}
|
|
||||||
|
|
||||||
# audit log tracking
|
|
||||||
self.audit_log_channel: str | None = None
|
|
||||||
|
|
||||||
# message tracking per target channel
|
|
||||||
# Format: { target_channel_id: {"message_map": {}, "last_message_id": "", "last_message_timestamp": ""} }
|
|
||||||
self.channel_messages: Dict[str, Dict[str, Any]] = {}
|
|
||||||
|
|
||||||
self.load()
|
|
||||||
|
|
||||||
def load(self):
|
# --- Type Specific Getters/Setters (Database Backed) ---
|
||||||
migrated_state = False
|
|
||||||
migrated_messages = False
|
|
||||||
|
|
||||||
# 1. Load primary state file
|
def set_channel_mapping(self, discord_id: str, target_id: str):
|
||||||
if self.state_file and self.state_file.exists():
|
if self._ensure_db():
|
||||||
with open(self.state_file, "r", encoding="utf-8") as f:
|
self.db.set_entity_mapping("channel", str(discord_id), str(target_id))
|
||||||
data = json.load(f)
|
|
||||||
self.channel_map = data.get("channels", {})
|
|
||||||
self.category_map = data.get("categories", {})
|
|
||||||
self.role_map = data.get("roles", {})
|
|
||||||
self.emoji_map = data.get("emojis", {})
|
|
||||||
self.sticker_map = data.get("stickers", {})
|
|
||||||
self.audit_log_channel = data.get("audit_log_channel")
|
|
||||||
|
|
||||||
# 2. Load separate messages file
|
def get_target_channel_id(self, discord_id: str) -> str | None:
|
||||||
if self.messages_file and self.messages_file.exists():
|
if self._ensure_db():
|
||||||
logger.info(f"Loading messages from {self.messages_file.name}")
|
return self.db.get_entity_mapping("channel", str(discord_id))
|
||||||
try:
|
return None
|
||||||
with open(self.messages_file, "r", encoding="utf-8") as f:
|
|
||||||
msg_data = json.load(f)
|
|
||||||
|
|
||||||
# Check for new schema (nested under 'channels')
|
|
||||||
if "channels" in msg_data:
|
|
||||||
self.channel_messages = msg_data.get("channels", {})
|
|
||||||
logger.debug(f"Loaded {len(self.channel_messages)} tracked channels.")
|
|
||||||
else:
|
|
||||||
logger.warning("Legacy schema or empty tracker detected in messages file.")
|
|
||||||
# Legacy schema detection & conversion to a default 'unknown_channel' just in case,
|
|
||||||
# though new migrations shouldn't hit this based on previous removals.
|
|
||||||
legacy_map = msg_data.get("messages", {})
|
|
||||||
legacy_ids = msg_data.get("last_message_ids", {})
|
|
||||||
legacy_times = msg_data.get("last_message_timestamps", {})
|
|
||||||
|
|
||||||
if legacy_map or legacy_ids or legacy_times:
|
|
||||||
self.channel_messages = {
|
|
||||||
"legacy_migrated_channel": {
|
|
||||||
"message_map": legacy_map,
|
|
||||||
"last_message_id": list(legacy_ids.values())[-1] if legacy_ids else "",
|
|
||||||
"last_message_timestamp": list(legacy_times.values())[-1] if legacy_times else ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to load messages file: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def save_state(self):
|
|
||||||
"""Saves only the core server configuration (channels, roles, emojis)."""
|
|
||||||
if not self.state_file:
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug(f"Saving state to {self.state_file.name}")
|
|
||||||
data = {
|
|
||||||
"channels": self.channel_map,
|
|
||||||
"categories": self.category_map,
|
|
||||||
"roles": self.role_map,
|
|
||||||
"emojis": self.emoji_map,
|
|
||||||
"stickers": self.sticker_map,
|
|
||||||
"audit_log_channel": self.audit_log_channel
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
with open(self.state_file, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(data, f, indent=4)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to save state file: {e}")
|
|
||||||
|
|
||||||
def save_messages(self):
|
|
||||||
"""Saves only the message tracking data."""
|
|
||||||
if not self.messages_file:
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug(f"Saving messages to {self.messages_file.name}")
|
|
||||||
data = {
|
|
||||||
"channels": self.channel_messages
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
with open(self.messages_file, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(data, f, indent=4)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to save messages file: {e}")
|
|
||||||
|
|
||||||
# --- Type Specific Getters/Setters ---
|
|
||||||
|
|
||||||
def set_channel_mapping(self, discord_id: str, fluxer_id: str):
|
|
||||||
self.channel_map[str(discord_id)] = str(fluxer_id)
|
|
||||||
self.save_state()
|
|
||||||
|
|
||||||
def get_fluxer_channel_id(self, discord_id: str) -> str | None:
|
|
||||||
return self.channel_map.get(str(discord_id))
|
|
||||||
|
|
||||||
def remove_channel_mapping(self, discord_id: str):
|
|
||||||
self.channel_map.pop(str(discord_id), None)
|
|
||||||
self.save_state()
|
|
||||||
|
|
||||||
def set_category_mapping(self, discord_id: str, fluxer_id: str):
|
|
||||||
self.category_map[str(discord_id)] = str(fluxer_id)
|
|
||||||
self.save_state()
|
|
||||||
|
|
||||||
def get_fluxer_category_id(self, discord_id: str) -> str | None:
|
|
||||||
return self.category_map.get(str(discord_id))
|
|
||||||
|
|
||||||
def remove_category_mapping(self, discord_id: str):
|
|
||||||
self.category_map.pop(str(discord_id), None)
|
|
||||||
self.save_state()
|
|
||||||
|
|
||||||
def set_role_mapping(self, discord_id: str, fluxer_id: str):
|
|
||||||
self.role_map[str(discord_id)] = str(fluxer_id)
|
|
||||||
self.save_state()
|
|
||||||
|
|
||||||
def get_fluxer_role_id(self, discord_id: str) -> str | None:
|
|
||||||
return self.role_map.get(str(discord_id))
|
|
||||||
|
|
||||||
def remove_role_mapping(self, discord_id: str):
|
|
||||||
self.role_map.pop(str(discord_id), None)
|
|
||||||
self.save_state()
|
|
||||||
|
|
||||||
def set_emoji_mapping(self, discord_id: str, fluxer_id: str):
|
|
||||||
self.emoji_map[str(discord_id)] = str(fluxer_id)
|
|
||||||
self.save_state()
|
|
||||||
|
|
||||||
def get_fluxer_emoji_id(self, discord_id: str) -> str | None:
|
|
||||||
return self.emoji_map.get(str(discord_id))
|
|
||||||
|
|
||||||
def remove_emoji_mapping(self, discord_id: str):
|
|
||||||
self.emoji_map.pop(str(discord_id), None)
|
|
||||||
self.save_state()
|
|
||||||
|
|
||||||
def set_sticker_mapping(self, discord_id: str, fluxer_id: str):
|
|
||||||
self.sticker_map[str(discord_id)] = str(fluxer_id)
|
|
||||||
self.save_state()
|
|
||||||
|
|
||||||
def get_fluxer_sticker_id(self, discord_id: str) -> str | None:
|
|
||||||
return self.sticker_map.get(str(discord_id))
|
|
||||||
|
|
||||||
def remove_sticker_mapping(self, discord_id: str):
|
|
||||||
self.sticker_map.pop(str(discord_id), None)
|
|
||||||
self.save_state()
|
|
||||||
|
|
||||||
# --- Generic Aliases for target platform migration ---
|
|
||||||
|
|
||||||
get_target_channel_id = get_fluxer_channel_id
|
get_fluxer_channel_id = get_target_channel_id
|
||||||
set_channel_mapping = set_channel_mapping # already generic enough in name if we ignore the 'fluxer' in implementation
|
set_target_channel_mapping = set_channel_mapping
|
||||||
|
|
||||||
def set_target_channel_mapping(self, discord_id: str, target_id: str):
|
def set_category_mapping(self, discord_id: str, target_id: str):
|
||||||
self.set_channel_mapping(discord_id, target_id)
|
if self._ensure_db():
|
||||||
|
self.db.set_entity_mapping("category", str(discord_id), str(target_id))
|
||||||
|
|
||||||
def get_target_category_id(self, discord_id: str) -> str | None:
|
def get_target_category_id(self, discord_id: str) -> str | None:
|
||||||
return self.get_fluxer_category_id(discord_id)
|
if self._ensure_db():
|
||||||
|
return self.db.get_entity_mapping("category", str(discord_id))
|
||||||
|
return None
|
||||||
|
|
||||||
|
get_fluxer_category_id = get_target_category_id
|
||||||
|
set_target_category_mapping = set_category_mapping
|
||||||
|
|
||||||
def set_target_category_mapping(self, discord_id: str, target_id: str):
|
def set_role_mapping(self, discord_id: str, target_id: str):
|
||||||
self.set_category_mapping(discord_id, target_id)
|
if self._ensure_db():
|
||||||
|
self.db.set_entity_mapping("role", str(discord_id), str(target_id))
|
||||||
|
|
||||||
def get_target_role_id(self, discord_id: str) -> str | None:
|
def get_target_role_id(self, discord_id: str) -> str | None:
|
||||||
return self.get_fluxer_role_id(discord_id)
|
if self._ensure_db():
|
||||||
|
return self.db.get_entity_mapping("role", str(discord_id))
|
||||||
|
return None
|
||||||
|
|
||||||
|
get_fluxer_role_id = get_target_role_id
|
||||||
|
set_target_role_mapping = set_role_mapping
|
||||||
|
|
||||||
def set_target_role_mapping(self, discord_id: str, target_id: str):
|
def set_emoji_mapping(self, discord_id: str, target_id: str):
|
||||||
self.set_role_mapping(discord_id, target_id)
|
if self._ensure_db():
|
||||||
|
self.db.set_entity_mapping("emoji", str(discord_id), str(target_id))
|
||||||
|
|
||||||
def get_target_emoji_id(self, discord_id: str) -> str | None:
|
def get_target_emoji_id(self, discord_id: str) -> str | None:
|
||||||
return self.get_fluxer_emoji_id(discord_id)
|
if self._ensure_db():
|
||||||
|
return self.db.get_entity_mapping("emoji", str(discord_id))
|
||||||
|
return None
|
||||||
|
|
||||||
|
get_fluxer_emoji_id = get_target_emoji_id
|
||||||
|
set_target_emoji_mapping = set_emoji_mapping
|
||||||
|
|
||||||
def set_target_emoji_mapping(self, discord_id: str, target_id: str):
|
def set_sticker_mapping(self, discord_id: str, target_id: str):
|
||||||
self.set_emoji_mapping(discord_id, target_id)
|
if self._ensure_db():
|
||||||
|
self.db.set_entity_mapping("sticker", str(discord_id), str(target_id))
|
||||||
|
|
||||||
def get_target_sticker_id(self, discord_id: str) -> str | None:
|
def get_target_sticker_id(self, discord_id: str) -> str | None:
|
||||||
return self.get_fluxer_sticker_id(discord_id)
|
if self._ensure_db():
|
||||||
|
return self.db.get_entity_mapping("sticker", str(discord_id))
|
||||||
|
return None
|
||||||
|
|
||||||
|
get_fluxer_sticker_id = get_target_sticker_id
|
||||||
|
set_target_sticker_mapping = set_sticker_mapping
|
||||||
|
|
||||||
def set_target_sticker_mapping(self, discord_id: str, target_id: str):
|
# --- Properties for backward compatibility ---
|
||||||
self.set_sticker_mapping(discord_id, target_id)
|
@property
|
||||||
|
def channel_map(self) -> Dict[str, str]:
|
||||||
def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | None:
|
return self.db.get_all_entity_mappings("channel") if self.db else {}
|
||||||
return self.get_fluxer_message_id(target_channel_id, discord_id)
|
|
||||||
|
|
||||||
def set_target_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str):
|
@property
|
||||||
self.set_message_mapping(target_channel_id, discord_id, target_id)
|
def category_map(self) -> Dict[str, str]:
|
||||||
|
return self.db.get_all_entity_mappings("category") if self.db else {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def role_map(self) -> Dict[str, str]:
|
||||||
|
return self.db.get_all_entity_mappings("role") if self.db else {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def emoji_map(self) -> Dict[str, str]:
|
||||||
|
return self.db.get_all_entity_mappings("emoji") if self.db else {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sticker_map(self) -> Dict[str, str]:
|
||||||
|
return self.db.get_all_entity_mappings("sticker") if self.db else {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audit_log_channel(self) -> str | None:
|
||||||
|
return self.db.get_metadata("audit_log_channel") if self.db else None
|
||||||
|
|
||||||
|
@audit_log_channel.setter
|
||||||
|
def audit_log_channel(self, value: str | None):
|
||||||
|
if self._ensure_db():
|
||||||
|
self.db.set_metadata("audit_log_channel", value)
|
||||||
|
|
||||||
# --- Message Management ---
|
# --- Message Management ---
|
||||||
|
|
||||||
def _ensure_channel_tracking(self, target_channel_id: str):
|
def set_target_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str):
|
||||||
if str(target_channel_id) not in self.channel_messages:
|
if self._ensure_db():
|
||||||
self.channel_messages[str(target_channel_id)] = {
|
self.db.set_message_mapping(str(target_channel_id), str(discord_id), str(target_id))
|
||||||
"message_map": {},
|
|
||||||
"last_message_id": "",
|
|
||||||
"last_message_timestamp": "",
|
|
||||||
"total_messages": 0,
|
|
||||||
"total_files": 0,
|
|
||||||
"threads": {}
|
|
||||||
}
|
|
||||||
|
|
||||||
def increment_stats(self, target_channel_id: str, messages: int = 1, files: int = 0):
|
|
||||||
self._ensure_channel_tracking(target_channel_id)
|
|
||||||
c = self.channel_messages[str(target_channel_id)]
|
|
||||||
c["total_messages"] = c.get("total_messages", 0) + messages
|
|
||||||
c["total_files"] = c.get("total_files", 0) + files
|
|
||||||
self.save_messages()
|
|
||||||
|
|
||||||
# --- Thread Tracking ---
|
def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | None:
|
||||||
|
if self._ensure_db():
|
||||||
def _ensure_thread_tracking(self, target_channel_id: str, thread_id: str):
|
return self.db.get_target_message_id(str(target_channel_id), str(discord_id))
|
||||||
self._ensure_channel_tracking(target_channel_id)
|
|
||||||
threads = self.channel_messages[str(target_channel_id)].setdefault("threads", {})
|
|
||||||
if str(thread_id) not in threads:
|
|
||||||
threads[str(thread_id)] = {
|
|
||||||
"thread_map": {},
|
|
||||||
"last_message_id": "",
|
|
||||||
"last_message_timestamp": "",
|
|
||||||
"total_messages": 0,
|
|
||||||
"total_files": 0
|
|
||||||
}
|
|
||||||
|
|
||||||
def increment_thread_stats(self, target_channel_id: str, thread_id: str, messages: int = 1, files: int = 0):
|
|
||||||
self._ensure_thread_tracking(target_channel_id, thread_id)
|
|
||||||
t = self.channel_messages[str(target_channel_id)]["threads"][str(thread_id)]
|
|
||||||
t["total_messages"] = t.get("total_messages", 0) + messages
|
|
||||||
t["total_files"] = t.get("total_files", 0) + files
|
|
||||||
self.save_messages()
|
|
||||||
|
|
||||||
def set_thread_message_mapping(self, target_channel_id: str, thread_id: str, discord_id: str, target_id: str):
|
|
||||||
self._ensure_thread_tracking(target_channel_id, thread_id)
|
|
||||||
self.channel_messages[str(target_channel_id)]["threads"][str(thread_id)]["thread_map"][str(discord_id)] = str(target_id)
|
|
||||||
# Also add to main message_map for global message resolution (like replies)
|
|
||||||
self.set_message_mapping(target_channel_id, discord_id, target_id)
|
|
||||||
|
|
||||||
def update_thread_last_message_timestamp(self, target_channel_id: str, thread_id: str, timestamp: str):
|
|
||||||
self._ensure_thread_tracking(target_channel_id, thread_id)
|
|
||||||
self.channel_messages[str(target_channel_id)]["threads"][str(thread_id)]["last_message_timestamp"] = str(timestamp)
|
|
||||||
self.save_messages()
|
|
||||||
|
|
||||||
def update_thread_last_message_id(self, target_channel_id: str, thread_id: str, message_id: str):
|
|
||||||
self._ensure_thread_tracking(target_channel_id, thread_id)
|
|
||||||
self.channel_messages[str(target_channel_id)]["threads"][str(thread_id)]["last_message_id"] = str(message_id)
|
|
||||||
self.save_messages()
|
|
||||||
|
|
||||||
def get_thread_message_id(self, target_channel_id: str, thread_id: str, discord_id: str) -> str | None:
|
|
||||||
if str(target_channel_id) in self.channel_messages:
|
|
||||||
threads = self.channel_messages[str(target_channel_id)].get("threads", {})
|
|
||||||
if str(thread_id) in threads:
|
|
||||||
return threads[str(thread_id)]["thread_map"].get(str(discord_id))
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def set_message_mapping(self, target_channel_id: str, discord_id: str, fluxer_id: str):
|
def set_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str):
|
||||||
self._ensure_channel_tracking(target_channel_id)
|
self.set_target_message_mapping(target_channel_id, discord_id, target_id)
|
||||||
self.channel_messages[str(target_channel_id)]["message_map"][str(discord_id)] = str(fluxer_id)
|
|
||||||
self.save_messages()
|
|
||||||
|
|
||||||
def get_fluxer_message_id(self, target_channel_id: str, discord_id: str) -> str | None:
|
def get_fluxer_message_id(self, target_channel_id: str, discord_id: str) -> str | None:
|
||||||
if str(target_channel_id) in self.channel_messages:
|
return self.get_target_message_id(target_channel_id, discord_id)
|
||||||
return self.channel_messages[str(target_channel_id)]["message_map"].get(str(discord_id))
|
|
||||||
|
def increment_stats(self, target_channel_id: str, messages: int = 1, files: int = 0):
|
||||||
|
if self._ensure_db():
|
||||||
|
self.db.update_channel_tracking(str(target_channel_id), msg_inc=messages, file_inc=files)
|
||||||
|
|
||||||
|
def increment_thread_stats(self, target_channel_id: str, thread_id: str, messages: int = 1, files: int = 0):
|
||||||
|
if self._ensure_db():
|
||||||
|
self.db.update_thread_tracking(str(target_channel_id), str(thread_id), msg_inc=messages, file_inc=files)
|
||||||
|
|
||||||
|
def set_thread_message_mapping(self, target_channel_id: str, thread_id: str, discord_id: str, target_id: str):
|
||||||
|
if self._ensure_db():
|
||||||
|
self.db.set_thread_message_mapping(str(target_channel_id), str(thread_id), str(discord_id), str(target_id))
|
||||||
|
|
||||||
|
def update_thread_last_message_timestamp(self, target_channel_id: str, thread_id: str, timestamp: str):
|
||||||
|
if self._ensure_db():
|
||||||
|
self.db.update_thread_tracking(str(target_channel_id), str(thread_id), last_msg_ts=str(timestamp))
|
||||||
|
|
||||||
|
def update_thread_last_message_id(self, target_channel_id: str, thread_id: str, message_id: str):
|
||||||
|
if self._ensure_db():
|
||||||
|
self.db.update_thread_tracking(str(target_channel_id), str(thread_id), last_msg_id=str(message_id))
|
||||||
|
|
||||||
|
def get_thread_message_id(self, target_channel_id: str, thread_id: str, discord_id: str) -> str | None:
|
||||||
|
if self._ensure_db():
|
||||||
|
return self.db.get_target_thread_message_id(str(target_channel_id), str(thread_id), str(discord_id))
|
||||||
|
return None
|
||||||
|
|
||||||
|
def update_last_message_timestamp(self, target_channel_id: str, timestamp: str):
|
||||||
|
if self._ensure_db():
|
||||||
|
self.db.update_channel_tracking(str(target_channel_id), last_msg_ts=str(timestamp))
|
||||||
|
|
||||||
|
def update_last_message_id(self, target_channel_id: str, message_id: str):
|
||||||
|
if self._ensure_db():
|
||||||
|
self.db.update_channel_tracking(str(target_channel_id), last_msg_id=str(message_id))
|
||||||
|
|
||||||
|
def get_last_message_id(self, target_channel_id: str) -> str | None:
|
||||||
|
if self._ensure_db():
|
||||||
|
return self.db.get_channel_tracking(str(target_channel_id)).get("last_msg_id")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def find_message_mapping(self, discord_id: str) -> tuple[str, str] | tuple[None, None]:
|
def find_message_mapping(self, discord_id: str) -> tuple[str, str] | tuple[None, None]:
|
||||||
"""
|
if not self.db:
|
||||||
Searches for a message mapping across all tracked channels.
|
return None, None
|
||||||
Returns (target_channel_id, target_message_id) or (None, None).
|
conn = self.db._get_conn()
|
||||||
"""
|
row = conn.execute("SELECT channel_id, target_msg_id FROM message_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone()
|
||||||
d_id = str(discord_id)
|
if row:
|
||||||
for t_cid, data in self.channel_messages.items():
|
return str(row["channel_id"]), str(row["target_msg_id"])
|
||||||
# Check main message map
|
row = conn.execute("SELECT thread_id, target_msg_id FROM thread_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone()
|
||||||
if d_id in data.get("message_map", {}):
|
if row:
|
||||||
return str(t_cid), str(data["message_map"][d_id])
|
return str(row["thread_id"]), str(row["target_msg_id"])
|
||||||
# Check threads
|
|
||||||
for t_tid, t_data in data.get("threads", {}).items():
|
|
||||||
if d_id in t_data.get("thread_map", {}):
|
|
||||||
# For thread links, the target_channel_id is technically the thread ID in some contexts,
|
|
||||||
# but usually for the URL it's the thread ID itself.
|
|
||||||
return str(t_tid), str(t_data["thread_map"][d_id])
|
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
def update_last_message_timestamp(self, target_channel_id: str, timestamp: str):
|
|
||||||
self._ensure_channel_tracking(target_channel_id)
|
|
||||||
self.channel_messages[str(target_channel_id)]["last_message_timestamp"] = str(timestamp)
|
|
||||||
self.save_messages()
|
|
||||||
|
|
||||||
def update_last_message_id(self, target_channel_id: str, message_id: str):
|
|
||||||
self._ensure_channel_tracking(target_channel_id)
|
|
||||||
self.channel_messages[str(target_channel_id)]["last_message_id"] = str(message_id)
|
|
||||||
self.save_messages()
|
|
||||||
|
|
||||||
def get_last_message_id(self, target_channel_id: str) -> str | None:
|
|
||||||
if str(target_channel_id) in self.channel_messages:
|
|
||||||
return self.channel_messages[str(target_channel_id)].get("last_message_id")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# --- Danger Zone Clearing ---
|
# --- Danger Zone Clearing ---
|
||||||
|
|
||||||
def clear_channel_mappings(self):
|
def clear_channel_mappings(self):
|
||||||
"""Clears all channel and category mappings."""
|
if self._ensure_db():
|
||||||
self.channel_map.clear()
|
self.db.clear_entities("channel")
|
||||||
self.category_map.clear()
|
self.db.clear_entities("category")
|
||||||
self.save_state()
|
|
||||||
|
|
||||||
def clear_role_mappings(self):
|
def clear_role_mappings(self):
|
||||||
"""Clears all role mappings."""
|
if self._ensure_db():
|
||||||
self.role_map.clear()
|
self.db.clear_entities("role")
|
||||||
self.save_state()
|
|
||||||
|
|
||||||
def clear_asset_mappings(self):
|
def clear_asset_mappings(self):
|
||||||
"""Clears all emoji and sticker mappings."""
|
if self._ensure_db():
|
||||||
self.emoji_map.clear()
|
self.db.clear_entities("emoji")
|
||||||
self.sticker_map.clear()
|
self.db.clear_entities("sticker")
|
||||||
self.save_state()
|
|
||||||
|
|
||||||
def clear_message_history(self):
|
def clear_message_history(self):
|
||||||
"""Clears all message mappings and timestamps."""
|
if self.db:
|
||||||
self.channel_messages.clear()
|
conn = self.db._get_conn()
|
||||||
self.save_messages()
|
conn.execute("DELETE FROM message_mappings")
|
||||||
|
conn.execute("DELETE FROM thread_mappings")
|
||||||
|
conn.execute("DELETE FROM channel_tracking")
|
||||||
|
conn.execute("DELETE FROM thread_tracking")
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
def set_folder(self, server_id: str, clean_name: str, base_dir: Path | str = ""):
|
def set_folder(self, server_id: str, clean_name: str, base_dir: Path | str = ""):
|
||||||
|
"""
|
||||||
|
Initializes the SQLite database based on community name and ID.
|
||||||
|
Filename: {name}-{id}.db (Flat structure)
|
||||||
|
ID is priority: if a DB with the same ID exists but different name, rename it.
|
||||||
|
"""
|
||||||
base = Path(base_dir) if base_dir else Path(".")
|
base = Path(base_dir) if base_dir else Path(".")
|
||||||
new_folder = base / f"{clean_name}-{server_id}"
|
desired_filename = f"{clean_name}-{server_id}.db"
|
||||||
logger.info(f"Setting active migration folder: {new_folder}")
|
desired_path = base / desired_filename
|
||||||
|
|
||||||
# 1. Search base_dir to see if an older folder for this server_id exists
|
# Priority 1: Match by ID
|
||||||
existing_folder: Path | None = None
|
existing_db: Path | None = None
|
||||||
if base.exists() and base.is_dir():
|
# Look for any file ending with -{server_id}.db
|
||||||
for d in base.iterdir():
|
for f in base.glob(f"*-{server_id}.db"):
|
||||||
if d.is_dir() and d.name.endswith(f"-{server_id}"):
|
if f.is_file():
|
||||||
existing_folder = d
|
existing_db = f
|
||||||
break
|
break
|
||||||
|
|
||||||
# 2. Rename it if it doesn't match the new desired name
|
db_path = desired_path
|
||||||
if existing_folder and existing_folder != new_folder:
|
if existing_db:
|
||||||
logger.info(f"Renaming existing folder {existing_folder.name} to {new_folder.name}")
|
if existing_db.name != desired_filename:
|
||||||
try:
|
logger.info(f"Server renamed: moving {existing_db.name} -> {desired_filename}")
|
||||||
existing_folder.rename(new_folder)
|
try:
|
||||||
except Exception as e:
|
existing_db.rename(desired_path)
|
||||||
logger.debug(f"Could not rename {existing_folder} to {new_folder}: {e}")
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to rename database: {e}")
|
||||||
|
# If rename fails, we'll use the existing one if it exists at the old path,
|
||||||
|
# or the desired one if it exists there.
|
||||||
|
if not desired_path.exists():
|
||||||
|
db_path = existing_db
|
||||||
|
|
||||||
|
logger.info(f"Setting active migration database: {db_path}")
|
||||||
|
|
||||||
|
from src.core.database import MigrationDatabase
|
||||||
|
if self.db:
|
||||||
|
self.db.close()
|
||||||
|
self.db = MigrationDatabase(db_path)
|
||||||
|
logger.info(f"Initialized SQLite database at {db_path}")
|
||||||
|
|
||||||
new_folder.mkdir(parents=True, exist_ok=True)
|
# No-op methods kept for compatibility with callers that might try to load/save JSON
|
||||||
|
def load(self): pass
|
||||||
self.state_file = new_folder / "state-migration.json"
|
def save_state(self): pass
|
||||||
self.messages_file = new_folder / "message-tracker.json"
|
|
||||||
|
|
||||||
logger.debug("Re-loading data from new folder location.")
|
|
||||||
self.load()
|
|
||||||
|
|
|
||||||
|
|
@ -1576,7 +1576,7 @@ class OperationPane(Container):
|
||||||
|
|
||||||
async def _fetch_clone_preview(self, selections: list[str]) -> dict[str, Any]:
|
async def _fetch_clone_preview(self, selections: list[str]) -> dict[str, Any]:
|
||||||
"""Fetches preview data from Discord (source server) for cloning confirmation,
|
"""Fetches preview data from Discord (source server) for cloning confirmation,
|
||||||
comparing with existing mappings in state-migration.json for presence highlighting."""
|
comparing with existing mappings in SQLite for presence highlighting."""
|
||||||
preview = {}
|
preview = {}
|
||||||
reader = self.engine.discord_reader
|
reader = self.engine.discord_reader
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue