switch to sqlite for migration states
This commit is contained in:
parent
d0a8a1acb4
commit
de28ecdbb1
5 changed files with 513 additions and 342 deletions
|
|
@ -696,7 +696,15 @@ class BackupReader:
|
|||
self.backup_path = Path(backup_path)
|
||||
self.guild: BackupGuild | None = None
|
||||
|
||||
# Internal caches populated by start()
|
||||
self._thread_info: Dict[int, Dict[str, Any]] = {} # channel_id -> metadata (like name, parentID)
|
||||
|
||||
# Lazy loading flags
|
||||
self._roles_loaded = False
|
||||
self._structure_loaded = False
|
||||
self._assets_loaded = False
|
||||
self._members_loaded = False
|
||||
|
||||
# Internal storage
|
||||
self._categories: List[BackupCategory] = []
|
||||
self._channels: List[BackupChannel] = []
|
||||
self._roles: List[BackupRole] = []
|
||||
|
|
@ -704,12 +712,11 @@ class BackupReader:
|
|||
self._stickers: List[BackupSticker] = []
|
||||
self._members: List[BackupMember] = []
|
||||
self._member_map: Dict[int, BackupMember] = {}
|
||||
self._thread_info: Dict[int, Dict[str, Any]] = {} # channel_id -> metadata (like name, parentID)
|
||||
|
||||
# ── startup ──────────────────────────────────────────────────────────
|
||||
|
||||
async def start(self):
|
||||
"""Loads all JSON files from the backup directory into memory."""
|
||||
"""Initializes the backup path and loads the server profile."""
|
||||
bp = self.backup_path
|
||||
|
||||
# 1. Server profile -> BackupGuild
|
||||
|
|
@ -717,60 +724,95 @@ class BackupReader:
|
|||
if profile_file.exists():
|
||||
profile = json.loads(profile_file.read_text(encoding="utf-8"))
|
||||
self.guild = BackupGuild(profile, bp, reader=self)
|
||||
logger.info(f"[Backup] Loaded server profile: {self.guild.name} ({self.guild.id})")
|
||||
logger.info(f"[Backup] Initialized server: {self.guild.name} ({self.guild.id})")
|
||||
else:
|
||||
logger.warning(f"[Backup] server_profile/profile.json not found in {bp}")
|
||||
self.guild = None
|
||||
|
||||
# 2. Roles
|
||||
roles_file = bp / "server_profile" / "roles.json"
|
||||
@property
|
||||
def roles(self) -> List[BackupRole]:
|
||||
if not self._roles_loaded:
|
||||
roles_file = self.backup_path / "server_profile" / "roles.json"
|
||||
if roles_file.exists():
|
||||
logger.info(f"[Backup] Lazy-loading roles...")
|
||||
roles_data = json.loads(roles_file.read_text(encoding="utf-8"))
|
||||
self._roles = [BackupRole(r) for r in roles_data]
|
||||
logger.info(f"[Backup] Loaded {len(self._roles)} roles")
|
||||
self._roles_loaded = True
|
||||
return self._roles
|
||||
|
||||
# 3. Structure -> categories + channels
|
||||
struct_file = bp / "server_profile" / "structure.json"
|
||||
@property
|
||||
def categories(self) -> List[BackupCategory]:
|
||||
self._ensure_structure_loaded()
|
||||
return self._categories
|
||||
|
||||
@property
|
||||
def channels(self) -> List[BackupChannel]:
|
||||
self._ensure_structure_loaded()
|
||||
return self._channels
|
||||
|
||||
def _ensure_structure_loaded(self):
|
||||
if self._structure_loaded:
|
||||
return
|
||||
struct_file = self.backup_path / "server_profile" / "structure.json"
|
||||
if struct_file.exists():
|
||||
logger.info(f"[Backup] Lazy-loading server structure...")
|
||||
structure = json.loads(struct_file.read_text(encoding="utf-8"))
|
||||
for cat_data in structure:
|
||||
cat = BackupCategory(cat_data)
|
||||
if cat.id != 0: # skip 'uncategorized' as a real category
|
||||
if cat.id != 0:
|
||||
self._categories.append(cat)
|
||||
|
||||
for ch_data in cat_data.get("channels", []):
|
||||
ch_cat_id = cat.id if cat.id != 0 else None
|
||||
channel = BackupChannel(ch_data, category_id=ch_cat_id, guild=self.guild)
|
||||
self._channels.append(channel)
|
||||
self._structure_loaded = True
|
||||
|
||||
logger.info(f"[Backup] Loaded {len(self._categories)} categories, "
|
||||
f"{len(self._channels)} channels")
|
||||
@property
|
||||
def emojis(self) -> List[BackupEmoji]:
|
||||
self._ensure_assets_loaded()
|
||||
return self._emojis
|
||||
|
||||
# 4. Assets (emojis + stickers)
|
||||
assets_file = bp / "server_profile" / "assets.json"
|
||||
media_dir = bp / "server_profile" / "assets"
|
||||
@property
|
||||
def stickers(self) -> List[BackupSticker]:
|
||||
self._ensure_assets_loaded()
|
||||
return self._stickers
|
||||
|
||||
def _ensure_assets_loaded(self):
|
||||
if self._assets_loaded:
|
||||
return
|
||||
assets_file = self.backup_path / "server_profile" / "assets.json"
|
||||
media_dir = self.backup_path / "server_profile" / "assets"
|
||||
if assets_file.exists():
|
||||
logger.info(f"[Backup] Lazy-loading assets...")
|
||||
assets = json.loads(assets_file.read_text(encoding="utf-8"))
|
||||
self._emojis = [BackupEmoji(e, media_dir) for e in assets.get("emojis", [])]
|
||||
self._stickers = [BackupSticker(s, media_dir) for s in assets.get("stickers", [])]
|
||||
logger.info(f"[Backup] Loaded {len(self._emojis)} emojis, "
|
||||
f"{len(self._stickers)} stickers")
|
||||
self._assets_loaded = True
|
||||
|
||||
# 5. Users
|
||||
user_info_file = bp / "message_backup" / "users" / "user_info.json"
|
||||
@property
|
||||
def members(self) -> List[BackupMember]:
|
||||
self._ensure_members_loaded()
|
||||
return self._members
|
||||
|
||||
def _ensure_members_loaded(self):
|
||||
if self._members_loaded:
|
||||
return
|
||||
user_info_file = self.backup_path / "message_backup" / "users" / "user_info.json"
|
||||
if user_info_file.exists():
|
||||
logger.info(f"[Backup] Lazy-loading members...")
|
||||
try:
|
||||
users = json.loads(user_info_file.read_text(encoding="utf-8"))
|
||||
backup_root = bp / "message_backup"
|
||||
backup_root = self.backup_path / "message_backup"
|
||||
for u in users:
|
||||
user_role_ids = {int(r["id"]) for r in u.get("userRoles", [])}
|
||||
role_objs = [r for r in self._roles if r.id in user_role_ids]
|
||||
# Note: this triggers roles lazy load
|
||||
role_objs = [r for r in self.roles if r.id in user_role_ids]
|
||||
member = BackupMember(u, role_objects=role_objs, avatar_base=backup_root)
|
||||
self._members.append(member)
|
||||
self._member_map[member.id] = member
|
||||
logger.info(f"[Backup] Loaded {len(self._members)} users")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Backup] Failed to load user_info.json: {e}")
|
||||
logger.warning(f"[Backup] Failed to lazy-load user_info.json: {e}")
|
||||
self._members_loaded = True
|
||||
|
||||
# ── validation ───────────────────────────────────────────────────────
|
||||
|
||||
|
|
|
|||
|
|
@ -89,6 +89,16 @@ class MigrationContext:
|
|||
"target_community_name": t_valid.get("community_name"),
|
||||
"target_permissions": t_valid.get("permissions", {})
|
||||
}
|
||||
|
||||
# CONSISTENCY: Once target metadata is known, initialize the flat SQLite DB.
|
||||
if results["target_community"] and results["target_community_name"]:
|
||||
import re
|
||||
clean_name = re.sub(r'[^\w\s-]', '', results["target_community_name"]).strip()
|
||||
clean_name = re.sub(r'[-\s]+', '_', clean_name)
|
||||
db_community_id = str(self.config.target_server_id or "")
|
||||
self.state.set_folder(db_community_id, clean_name, base_dir=base_dir)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"Validation failed with exception: {e}")
|
||||
return {
|
||||
|
|
|
|||
237
src/core/database.py
Normal file
237
src/core/database.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
import sqlite3
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MigrationDatabase:
|
||||
"""
|
||||
SQLite-based persistence for large-scale migration mappings and stats.
|
||||
Replaces the memory-bloated and O(N^2) JSON persistence for messages.
|
||||
"""
|
||||
|
||||
_local = threading.local()
|
||||
|
||||
def __init__(self, db_path: Path):
|
||||
self.db_path = db_path
|
||||
self._init_db()
|
||||
|
||||
def _get_conn(self) -> sqlite3.Connection:
|
||||
if not hasattr(self._local, "conn"):
|
||||
self._local.conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
self._local.conn.row_factory = sqlite3.Row
|
||||
return self._local.conn
|
||||
|
||||
def _init_db(self):
|
||||
"""Initialize tables if they don't exist."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Table for message mappings: SourceID -> TargetID
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS message_mappings (
|
||||
channel_id TEXT,
|
||||
source_msg_id TEXT,
|
||||
target_msg_id TEXT,
|
||||
timestamp TEXT,
|
||||
PRIMARY KEY (channel_id, source_msg_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Table for thread mappings
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS thread_mappings (
|
||||
channel_id TEXT,
|
||||
thread_id TEXT,
|
||||
source_msg_id TEXT,
|
||||
target_msg_id TEXT,
|
||||
timestamp TEXT,
|
||||
PRIMARY KEY (channel_id, thread_id, source_msg_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Table for per-channel stats and tracking
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS channel_tracking (
|
||||
channel_id TEXT PRIMARY KEY,
|
||||
last_msg_id TEXT,
|
||||
last_msg_ts TEXT,
|
||||
msg_count INTEGER DEFAULT 0,
|
||||
file_count INTEGER DEFAULT 0
|
||||
)
|
||||
""")
|
||||
|
||||
# Table for per-thread stats
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS thread_tracking (
|
||||
channel_id TEXT,
|
||||
thread_id TEXT,
|
||||
last_msg_id TEXT,
|
||||
last_msg_ts TEXT,
|
||||
msg_count INTEGER DEFAULT 0,
|
||||
file_count INTEGER DEFAULT 0,
|
||||
PRIMARY KEY (channel_id, thread_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Table for entity mappings (channels, roles, etc.)
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS entity_mappings (
|
||||
category TEXT,
|
||||
source_id TEXT,
|
||||
target_id TEXT,
|
||||
PRIMARY KEY (category, source_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Table for general metadata
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS metadata (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def set_message_mapping(self, channel_id: str, source_id: str, target_id: str, timestamp: str = None):
|
||||
conn = self._get_conn()
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO message_mappings (channel_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?)",
|
||||
(channel_id, source_id, target_id, timestamp)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_target_message_id(self, channel_id: str, source_id: str) -> Optional[str]:
|
||||
conn = self._get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT target_msg_id FROM message_mappings WHERE channel_id = ? AND source_msg_id = ?",
|
||||
(channel_id, source_id)
|
||||
).fetchone()
|
||||
return row["target_msg_id"] if row else None
|
||||
|
||||
# --- New Entity Mapping Methods ---
|
||||
|
||||
def set_entity_mapping(self, category: str, source_id: str, target_id: str):
|
||||
conn = self._get_conn()
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO entity_mappings (category, source_id, target_id) VALUES (?, ?, ?)",
|
||||
(category, str(source_id), str(target_id))
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_entity_mapping(self, category: str, source_id: str) -> Optional[str]:
|
||||
conn = self._get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT target_id FROM entity_mappings WHERE category = ? AND source_id = ?",
|
||||
(category, str(source_id))
|
||||
).fetchone()
|
||||
return row["target_id"] if row else None
|
||||
|
||||
def get_all_entity_mappings(self, category: str) -> Dict[str, str]:
|
||||
conn = self._get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT source_id, target_id FROM entity_mappings WHERE category = ?",
|
||||
(category,)
|
||||
).fetchall()
|
||||
return {row["source_id"]: row["target_id"] for row in rows}
|
||||
|
||||
def delete_entity_mapping(self, category: str, source_id: str):
|
||||
conn = self._get_conn()
|
||||
conn.execute(
|
||||
"DELETE FROM entity_mappings WHERE category = ? AND source_id = ?",
|
||||
(category, str(source_id))
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def clear_entities(self, category: str = None):
|
||||
conn = self._get_conn()
|
||||
if category:
|
||||
conn.execute("DELETE FROM entity_mappings WHERE category = ?", (category,))
|
||||
else:
|
||||
conn.execute("DELETE FROM entity_mappings")
|
||||
conn.commit()
|
||||
|
||||
# --- Metadata Methods ---
|
||||
|
||||
def set_metadata(self, key: str, value: str):
|
||||
conn = self._get_conn()
|
||||
conn.execute("INSERT OR REPLACE INTO metadata (key, value) VALUES (?, ?)", (key, str(value) if value is not None else None))
|
||||
conn.commit()
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[str]:
|
||||
conn = self._get_conn()
|
||||
row = conn.execute("SELECT value FROM metadata WHERE key = ?", (key,)).fetchone()
|
||||
return row["value"] if row else None
|
||||
|
||||
def update_channel_tracking(self, channel_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0):
|
||||
conn = self._get_conn()
|
||||
# Initialize if missing
|
||||
conn.execute("INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)", (channel_id,))
|
||||
|
||||
if last_msg_id:
|
||||
conn.execute("UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?", (last_msg_id, channel_id))
|
||||
if last_msg_ts:
|
||||
conn.execute("UPDATE channel_tracking SET last_msg_ts = ? WHERE channel_id = ?", (last_msg_ts, channel_id))
|
||||
|
||||
if msg_inc != 0 or file_inc != 0:
|
||||
conn.execute(
|
||||
"UPDATE channel_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ?",
|
||||
(msg_inc, file_inc, channel_id)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_channel_tracking(self, channel_id: str) -> Dict[str, Any]:
|
||||
conn = self._get_conn()
|
||||
row = conn.execute("SELECT * FROM channel_tracking WHERE channel_id = ?", (channel_id,)).fetchone()
|
||||
if row:
|
||||
return dict(row)
|
||||
return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0}
|
||||
|
||||
# Thread methods similar to channel methods
|
||||
def set_thread_message_mapping(self, channel_id: str, thread_id: str, source_id: str, target_id: str, timestamp: str = None):
|
||||
conn = self._get_conn()
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO thread_mappings (channel_id, thread_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?, ?)",
|
||||
(channel_id, thread_id, source_id, target_id, timestamp)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_target_thread_message_id(self, channel_id: str, thread_id: str, source_id: str) -> Optional[str]:
|
||||
conn = self._get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT target_msg_id FROM thread_mappings WHERE channel_id = ? AND thread_id = ? AND source_msg_id = ?",
|
||||
(channel_id, thread_id, source_id)
|
||||
).fetchone()
|
||||
return row["target_msg_id"] if row else None
|
||||
|
||||
def update_thread_tracking(self, channel_id: str, thread_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0):
|
||||
conn = self._get_conn()
|
||||
conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (channel_id, thread_id))
|
||||
|
||||
if last_msg_id:
|
||||
conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_id, channel_id, thread_id))
|
||||
if last_msg_ts:
|
||||
conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, channel_id, thread_id))
|
||||
|
||||
if msg_inc != 0 or file_inc != 0:
|
||||
conn.execute(
|
||||
"UPDATE thread_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ? AND thread_id = ?",
|
||||
(msg_inc, file_inc, channel_id, thread_id)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_thread_tracking(self, channel_id: str, thread_id: str) -> Dict[str, Any]:
|
||||
conn = self._get_conn()
|
||||
row = conn.execute("SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?", (channel_id, thread_id)).fetchone()
|
||||
if row:
|
||||
return dict(row)
|
||||
return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0}
|
||||
|
||||
def close(self):
|
||||
if hasattr(self._local, "conn"):
|
||||
self._local.conn.close()
|
||||
del self._local.conn
|
||||
|
|
@ -1,366 +1,248 @@
|
|||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MigrationState:
|
||||
"""Manages persistence of the migration state to allow resumability."""
|
||||
"""Manages persistence of the migration state to allow resumability.
|
||||
Uses SQLite for ALL mappings and metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, state_file: str | Path = "", messages_file: str | Path = ""):
|
||||
self.state_file: Path | None = Path(state_file) if state_file else None
|
||||
self.messages_file: Path | None = Path(messages_file) if messages_file else None
|
||||
def __init__(self):
|
||||
# database instance for all persistence
|
||||
self.db: Optional['MigrationDatabase'] = None
|
||||
|
||||
# mappings: discord_id -> fluxer_id
|
||||
self.channel_map: Dict[str, str] = {}
|
||||
self.category_map: Dict[str, str] = {}
|
||||
self.role_map: Dict[str, str] = {}
|
||||
self.emoji_map: Dict[str, str] = {}
|
||||
self.sticker_map: Dict[str, str] = {}
|
||||
def _ensure_db(self):
|
||||
if not self.db:
|
||||
logger.warning("MigrationState: Accessing database before initialization")
|
||||
return False
|
||||
return True
|
||||
|
||||
# audit log tracking
|
||||
self.audit_log_channel: str | None = None
|
||||
# --- Type Specific Getters/Setters (Database Backed) ---
|
||||
|
||||
# message tracking per target channel
|
||||
# Format: { target_channel_id: {"message_map": {}, "last_message_id": "", "last_message_timestamp": ""} }
|
||||
self.channel_messages: Dict[str, Dict[str, Any]] = {}
|
||||
def set_channel_mapping(self, discord_id: str, target_id: str):
|
||||
if self._ensure_db():
|
||||
self.db.set_entity_mapping("channel", str(discord_id), str(target_id))
|
||||
|
||||
self.load()
|
||||
def get_target_channel_id(self, discord_id: str) -> str | None:
|
||||
if self._ensure_db():
|
||||
return self.db.get_entity_mapping("channel", str(discord_id))
|
||||
return None
|
||||
|
||||
def load(self):
|
||||
migrated_state = False
|
||||
migrated_messages = False
|
||||
get_fluxer_channel_id = get_target_channel_id
|
||||
set_target_channel_mapping = set_channel_mapping
|
||||
|
||||
# 1. Load primary state file
|
||||
if self.state_file and self.state_file.exists():
|
||||
with open(self.state_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
self.channel_map = data.get("channels", {})
|
||||
self.category_map = data.get("categories", {})
|
||||
self.role_map = data.get("roles", {})
|
||||
self.emoji_map = data.get("emojis", {})
|
||||
self.sticker_map = data.get("stickers", {})
|
||||
self.audit_log_channel = data.get("audit_log_channel")
|
||||
|
||||
# 2. Load separate messages file
|
||||
if self.messages_file and self.messages_file.exists():
|
||||
logger.info(f"Loading messages from {self.messages_file.name}")
|
||||
try:
|
||||
with open(self.messages_file, "r", encoding="utf-8") as f:
|
||||
msg_data = json.load(f)
|
||||
|
||||
# Check for new schema (nested under 'channels')
|
||||
if "channels" in msg_data:
|
||||
self.channel_messages = msg_data.get("channels", {})
|
||||
logger.debug(f"Loaded {len(self.channel_messages)} tracked channels.")
|
||||
else:
|
||||
logger.warning("Legacy schema or empty tracker detected in messages file.")
|
||||
# Legacy schema detection & conversion to a default 'unknown_channel' just in case,
|
||||
# though new migrations shouldn't hit this based on previous removals.
|
||||
legacy_map = msg_data.get("messages", {})
|
||||
legacy_ids = msg_data.get("last_message_ids", {})
|
||||
legacy_times = msg_data.get("last_message_timestamps", {})
|
||||
|
||||
if legacy_map or legacy_ids or legacy_times:
|
||||
self.channel_messages = {
|
||||
"legacy_migrated_channel": {
|
||||
"message_map": legacy_map,
|
||||
"last_message_id": list(legacy_ids.values())[-1] if legacy_ids else "",
|
||||
"last_message_timestamp": list(legacy_times.values())[-1] if legacy_times else ""
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load messages file: {e}")
|
||||
|
||||
|
||||
|
||||
def save_state(self):
|
||||
"""Saves only the core server configuration (channels, roles, emojis)."""
|
||||
if not self.state_file:
|
||||
return
|
||||
|
||||
logger.debug(f"Saving state to {self.state_file.name}")
|
||||
data = {
|
||||
"channels": self.channel_map,
|
||||
"categories": self.category_map,
|
||||
"roles": self.role_map,
|
||||
"emojis": self.emoji_map,
|
||||
"stickers": self.sticker_map,
|
||||
"audit_log_channel": self.audit_log_channel
|
||||
}
|
||||
try:
|
||||
with open(self.state_file, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=4)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save state file: {e}")
|
||||
|
||||
def save_messages(self):
|
||||
"""Saves only the message tracking data."""
|
||||
if not self.messages_file:
|
||||
return
|
||||
|
||||
logger.debug(f"Saving messages to {self.messages_file.name}")
|
||||
data = {
|
||||
"channels": self.channel_messages
|
||||
}
|
||||
try:
|
||||
with open(self.messages_file, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=4)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save messages file: {e}")
|
||||
|
||||
# --- Type Specific Getters/Setters ---
|
||||
|
||||
def set_channel_mapping(self, discord_id: str, fluxer_id: str):
|
||||
self.channel_map[str(discord_id)] = str(fluxer_id)
|
||||
self.save_state()
|
||||
|
||||
def get_fluxer_channel_id(self, discord_id: str) -> str | None:
|
||||
return self.channel_map.get(str(discord_id))
|
||||
|
||||
def remove_channel_mapping(self, discord_id: str):
|
||||
self.channel_map.pop(str(discord_id), None)
|
||||
self.save_state()
|
||||
|
||||
def set_category_mapping(self, discord_id: str, fluxer_id: str):
|
||||
self.category_map[str(discord_id)] = str(fluxer_id)
|
||||
self.save_state()
|
||||
|
||||
def get_fluxer_category_id(self, discord_id: str) -> str | None:
|
||||
return self.category_map.get(str(discord_id))
|
||||
|
||||
def remove_category_mapping(self, discord_id: str):
|
||||
self.category_map.pop(str(discord_id), None)
|
||||
self.save_state()
|
||||
|
||||
def set_role_mapping(self, discord_id: str, fluxer_id: str):
|
||||
self.role_map[str(discord_id)] = str(fluxer_id)
|
||||
self.save_state()
|
||||
|
||||
def get_fluxer_role_id(self, discord_id: str) -> str | None:
|
||||
return self.role_map.get(str(discord_id))
|
||||
|
||||
def remove_role_mapping(self, discord_id: str):
|
||||
self.role_map.pop(str(discord_id), None)
|
||||
self.save_state()
|
||||
|
||||
def set_emoji_mapping(self, discord_id: str, fluxer_id: str):
|
||||
self.emoji_map[str(discord_id)] = str(fluxer_id)
|
||||
self.save_state()
|
||||
|
||||
def get_fluxer_emoji_id(self, discord_id: str) -> str | None:
|
||||
return self.emoji_map.get(str(discord_id))
|
||||
|
||||
def remove_emoji_mapping(self, discord_id: str):
|
||||
self.emoji_map.pop(str(discord_id), None)
|
||||
self.save_state()
|
||||
|
||||
def set_sticker_mapping(self, discord_id: str, fluxer_id: str):
|
||||
self.sticker_map[str(discord_id)] = str(fluxer_id)
|
||||
self.save_state()
|
||||
|
||||
def get_fluxer_sticker_id(self, discord_id: str) -> str | None:
|
||||
return self.sticker_map.get(str(discord_id))
|
||||
|
||||
def remove_sticker_mapping(self, discord_id: str):
|
||||
self.sticker_map.pop(str(discord_id), None)
|
||||
self.save_state()
|
||||
|
||||
# --- Generic Aliases for target platform migration ---
|
||||
|
||||
get_target_channel_id = get_fluxer_channel_id
|
||||
set_channel_mapping = set_channel_mapping # already generic enough in name if we ignore the 'fluxer' in implementation
|
||||
|
||||
def set_target_channel_mapping(self, discord_id: str, target_id: str):
|
||||
self.set_channel_mapping(discord_id, target_id)
|
||||
def set_category_mapping(self, discord_id: str, target_id: str):
|
||||
if self._ensure_db():
|
||||
self.db.set_entity_mapping("category", str(discord_id), str(target_id))
|
||||
|
||||
def get_target_category_id(self, discord_id: str) -> str | None:
|
||||
return self.get_fluxer_category_id(discord_id)
|
||||
if self._ensure_db():
|
||||
return self.db.get_entity_mapping("category", str(discord_id))
|
||||
return None
|
||||
|
||||
def set_target_category_mapping(self, discord_id: str, target_id: str):
|
||||
self.set_category_mapping(discord_id, target_id)
|
||||
get_fluxer_category_id = get_target_category_id
|
||||
set_target_category_mapping = set_category_mapping
|
||||
|
||||
def set_role_mapping(self, discord_id: str, target_id: str):
|
||||
if self._ensure_db():
|
||||
self.db.set_entity_mapping("role", str(discord_id), str(target_id))
|
||||
|
||||
def get_target_role_id(self, discord_id: str) -> str | None:
|
||||
return self.get_fluxer_role_id(discord_id)
|
||||
if self._ensure_db():
|
||||
return self.db.get_entity_mapping("role", str(discord_id))
|
||||
return None
|
||||
|
||||
def set_target_role_mapping(self, discord_id: str, target_id: str):
|
||||
self.set_role_mapping(discord_id, target_id)
|
||||
get_fluxer_role_id = get_target_role_id
|
||||
set_target_role_mapping = set_role_mapping
|
||||
|
||||
def set_emoji_mapping(self, discord_id: str, target_id: str):
|
||||
if self._ensure_db():
|
||||
self.db.set_entity_mapping("emoji", str(discord_id), str(target_id))
|
||||
|
||||
def get_target_emoji_id(self, discord_id: str) -> str | None:
|
||||
return self.get_fluxer_emoji_id(discord_id)
|
||||
if self._ensure_db():
|
||||
return self.db.get_entity_mapping("emoji", str(discord_id))
|
||||
return None
|
||||
|
||||
def set_target_emoji_mapping(self, discord_id: str, target_id: str):
|
||||
self.set_emoji_mapping(discord_id, target_id)
|
||||
get_fluxer_emoji_id = get_target_emoji_id
|
||||
set_target_emoji_mapping = set_emoji_mapping
|
||||
|
||||
def set_sticker_mapping(self, discord_id: str, target_id: str):
|
||||
if self._ensure_db():
|
||||
self.db.set_entity_mapping("sticker", str(discord_id), str(target_id))
|
||||
|
||||
def get_target_sticker_id(self, discord_id: str) -> str | None:
|
||||
return self.get_fluxer_sticker_id(discord_id)
|
||||
if self._ensure_db():
|
||||
return self.db.get_entity_mapping("sticker", str(discord_id))
|
||||
return None
|
||||
|
||||
def set_target_sticker_mapping(self, discord_id: str, target_id: str):
|
||||
self.set_sticker_mapping(discord_id, target_id)
|
||||
get_fluxer_sticker_id = get_target_sticker_id
|
||||
set_target_sticker_mapping = set_sticker_mapping
|
||||
|
||||
def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | None:
|
||||
return self.get_fluxer_message_id(target_channel_id, discord_id)
|
||||
# --- Properties for backward compatibility ---
|
||||
@property
|
||||
def channel_map(self) -> Dict[str, str]:
|
||||
return self.db.get_all_entity_mappings("channel") if self.db else {}
|
||||
|
||||
def set_target_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str):
|
||||
self.set_message_mapping(target_channel_id, discord_id, target_id)
|
||||
@property
|
||||
def category_map(self) -> Dict[str, str]:
|
||||
return self.db.get_all_entity_mappings("category") if self.db else {}
|
||||
|
||||
@property
|
||||
def role_map(self) -> Dict[str, str]:
|
||||
return self.db.get_all_entity_mappings("role") if self.db else {}
|
||||
|
||||
@property
|
||||
def emoji_map(self) -> Dict[str, str]:
|
||||
return self.db.get_all_entity_mappings("emoji") if self.db else {}
|
||||
|
||||
@property
|
||||
def sticker_map(self) -> Dict[str, str]:
|
||||
return self.db.get_all_entity_mappings("sticker") if self.db else {}
|
||||
|
||||
@property
|
||||
def audit_log_channel(self) -> str | None:
|
||||
return self.db.get_metadata("audit_log_channel") if self.db else None
|
||||
|
||||
@audit_log_channel.setter
|
||||
def audit_log_channel(self, value: str | None):
|
||||
if self._ensure_db():
|
||||
self.db.set_metadata("audit_log_channel", value)
|
||||
|
||||
# --- Message Management ---
|
||||
|
||||
def _ensure_channel_tracking(self, target_channel_id: str):
|
||||
if str(target_channel_id) not in self.channel_messages:
|
||||
self.channel_messages[str(target_channel_id)] = {
|
||||
"message_map": {},
|
||||
"last_message_id": "",
|
||||
"last_message_timestamp": "",
|
||||
"total_messages": 0,
|
||||
"total_files": 0,
|
||||
"threads": {}
|
||||
}
|
||||
def set_target_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str):
|
||||
if self._ensure_db():
|
||||
self.db.set_message_mapping(str(target_channel_id), str(discord_id), str(target_id))
|
||||
|
||||
def increment_stats(self, target_channel_id: str, messages: int = 1, files: int = 0):
|
||||
self._ensure_channel_tracking(target_channel_id)
|
||||
c = self.channel_messages[str(target_channel_id)]
|
||||
c["total_messages"] = c.get("total_messages", 0) + messages
|
||||
c["total_files"] = c.get("total_files", 0) + files
|
||||
self.save_messages()
|
||||
|
||||
# --- Thread Tracking ---
|
||||
|
||||
def _ensure_thread_tracking(self, target_channel_id: str, thread_id: str):
|
||||
self._ensure_channel_tracking(target_channel_id)
|
||||
threads = self.channel_messages[str(target_channel_id)].setdefault("threads", {})
|
||||
if str(thread_id) not in threads:
|
||||
threads[str(thread_id)] = {
|
||||
"thread_map": {},
|
||||
"last_message_id": "",
|
||||
"last_message_timestamp": "",
|
||||
"total_messages": 0,
|
||||
"total_files": 0
|
||||
}
|
||||
|
||||
def increment_thread_stats(self, target_channel_id: str, thread_id: str, messages: int = 1, files: int = 0):
|
||||
self._ensure_thread_tracking(target_channel_id, thread_id)
|
||||
t = self.channel_messages[str(target_channel_id)]["threads"][str(thread_id)]
|
||||
t["total_messages"] = t.get("total_messages", 0) + messages
|
||||
t["total_files"] = t.get("total_files", 0) + files
|
||||
self.save_messages()
|
||||
|
||||
def set_thread_message_mapping(self, target_channel_id: str, thread_id: str, discord_id: str, target_id: str):
|
||||
self._ensure_thread_tracking(target_channel_id, thread_id)
|
||||
self.channel_messages[str(target_channel_id)]["threads"][str(thread_id)]["thread_map"][str(discord_id)] = str(target_id)
|
||||
# Also add to main message_map for global message resolution (like replies)
|
||||
self.set_message_mapping(target_channel_id, discord_id, target_id)
|
||||
|
||||
def update_thread_last_message_timestamp(self, target_channel_id: str, thread_id: str, timestamp: str):
|
||||
self._ensure_thread_tracking(target_channel_id, thread_id)
|
||||
self.channel_messages[str(target_channel_id)]["threads"][str(thread_id)]["last_message_timestamp"] = str(timestamp)
|
||||
self.save_messages()
|
||||
|
||||
def update_thread_last_message_id(self, target_channel_id: str, thread_id: str, message_id: str):
|
||||
self._ensure_thread_tracking(target_channel_id, thread_id)
|
||||
self.channel_messages[str(target_channel_id)]["threads"][str(thread_id)]["last_message_id"] = str(message_id)
|
||||
self.save_messages()
|
||||
|
||||
def get_thread_message_id(self, target_channel_id: str, thread_id: str, discord_id: str) -> str | None:
|
||||
if str(target_channel_id) in self.channel_messages:
|
||||
threads = self.channel_messages[str(target_channel_id)].get("threads", {})
|
||||
if str(thread_id) in threads:
|
||||
return threads[str(thread_id)]["thread_map"].get(str(discord_id))
|
||||
def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | None:
|
||||
if self._ensure_db():
|
||||
return self.db.get_target_message_id(str(target_channel_id), str(discord_id))
|
||||
return None
|
||||
|
||||
def set_message_mapping(self, target_channel_id: str, discord_id: str, fluxer_id: str):
|
||||
self._ensure_channel_tracking(target_channel_id)
|
||||
self.channel_messages[str(target_channel_id)]["message_map"][str(discord_id)] = str(fluxer_id)
|
||||
self.save_messages()
|
||||
def set_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str):
|
||||
self.set_target_message_mapping(target_channel_id, discord_id, target_id)
|
||||
|
||||
def get_fluxer_message_id(self, target_channel_id: str, discord_id: str) -> str | None:
|
||||
if str(target_channel_id) in self.channel_messages:
|
||||
return self.channel_messages[str(target_channel_id)]["message_map"].get(str(discord_id))
|
||||
return self.get_target_message_id(target_channel_id, discord_id)
|
||||
|
||||
def increment_stats(self, target_channel_id: str, messages: int = 1, files: int = 0):
|
||||
if self._ensure_db():
|
||||
self.db.update_channel_tracking(str(target_channel_id), msg_inc=messages, file_inc=files)
|
||||
|
||||
def increment_thread_stats(self, target_channel_id: str, thread_id: str, messages: int = 1, files: int = 0):
|
||||
if self._ensure_db():
|
||||
self.db.update_thread_tracking(str(target_channel_id), str(thread_id), msg_inc=messages, file_inc=files)
|
||||
|
||||
def set_thread_message_mapping(self, target_channel_id: str, thread_id: str, discord_id: str, target_id: str):
|
||||
if self._ensure_db():
|
||||
self.db.set_thread_message_mapping(str(target_channel_id), str(thread_id), str(discord_id), str(target_id))
|
||||
|
||||
def update_thread_last_message_timestamp(self, target_channel_id: str, thread_id: str, timestamp: str):
|
||||
if self._ensure_db():
|
||||
self.db.update_thread_tracking(str(target_channel_id), str(thread_id), last_msg_ts=str(timestamp))
|
||||
|
||||
def update_thread_last_message_id(self, target_channel_id: str, thread_id: str, message_id: str):
|
||||
if self._ensure_db():
|
||||
self.db.update_thread_tracking(str(target_channel_id), str(thread_id), last_msg_id=str(message_id))
|
||||
|
||||
def get_thread_message_id(self, target_channel_id: str, thread_id: str, discord_id: str) -> str | None:
|
||||
if self._ensure_db():
|
||||
return self.db.get_target_thread_message_id(str(target_channel_id), str(thread_id), str(discord_id))
|
||||
return None
|
||||
|
||||
def update_last_message_timestamp(self, target_channel_id: str, timestamp: str):
|
||||
if self._ensure_db():
|
||||
self.db.update_channel_tracking(str(target_channel_id), last_msg_ts=str(timestamp))
|
||||
|
||||
def update_last_message_id(self, target_channel_id: str, message_id: str):
|
||||
if self._ensure_db():
|
||||
self.db.update_channel_tracking(str(target_channel_id), last_msg_id=str(message_id))
|
||||
|
||||
def get_last_message_id(self, target_channel_id: str) -> str | None:
|
||||
if self._ensure_db():
|
||||
return self.db.get_channel_tracking(str(target_channel_id)).get("last_msg_id")
|
||||
return None
|
||||
|
||||
def find_message_mapping(self, discord_id: str) -> tuple[str, str] | tuple[None, None]:
|
||||
"""
|
||||
Searches for a message mapping across all tracked channels.
|
||||
Returns (target_channel_id, target_message_id) or (None, None).
|
||||
"""
|
||||
d_id = str(discord_id)
|
||||
for t_cid, data in self.channel_messages.items():
|
||||
# Check main message map
|
||||
if d_id in data.get("message_map", {}):
|
||||
return str(t_cid), str(data["message_map"][d_id])
|
||||
# Check threads
|
||||
for t_tid, t_data in data.get("threads", {}).items():
|
||||
if d_id in t_data.get("thread_map", {}):
|
||||
# For thread links, the target_channel_id is technically the thread ID in some contexts,
|
||||
# but usually for the URL it's the thread ID itself.
|
||||
return str(t_tid), str(t_data["thread_map"][d_id])
|
||||
if not self.db:
|
||||
return None, None
|
||||
conn = self.db._get_conn()
|
||||
row = conn.execute("SELECT channel_id, target_msg_id FROM message_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone()
|
||||
if row:
|
||||
return str(row["channel_id"]), str(row["target_msg_id"])
|
||||
row = conn.execute("SELECT thread_id, target_msg_id FROM thread_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone()
|
||||
if row:
|
||||
return str(row["thread_id"]), str(row["target_msg_id"])
|
||||
return None, None
|
||||
|
||||
def update_last_message_timestamp(self, target_channel_id: str, timestamp: str):
|
||||
self._ensure_channel_tracking(target_channel_id)
|
||||
self.channel_messages[str(target_channel_id)]["last_message_timestamp"] = str(timestamp)
|
||||
self.save_messages()
|
||||
|
||||
def update_last_message_id(self, target_channel_id: str, message_id: str):
|
||||
self._ensure_channel_tracking(target_channel_id)
|
||||
self.channel_messages[str(target_channel_id)]["last_message_id"] = str(message_id)
|
||||
self.save_messages()
|
||||
|
||||
def get_last_message_id(self, target_channel_id: str) -> str | None:
|
||||
if str(target_channel_id) in self.channel_messages:
|
||||
return self.channel_messages[str(target_channel_id)].get("last_message_id")
|
||||
return None
|
||||
|
||||
# --- Danger Zone Clearing ---
|
||||
|
||||
def clear_channel_mappings(self):
|
||||
"""Clears all channel and category mappings."""
|
||||
self.channel_map.clear()
|
||||
self.category_map.clear()
|
||||
self.save_state()
|
||||
if self._ensure_db():
|
||||
self.db.clear_entities("channel")
|
||||
self.db.clear_entities("category")
|
||||
|
||||
def clear_role_mappings(self):
|
||||
"""Clears all role mappings."""
|
||||
self.role_map.clear()
|
||||
self.save_state()
|
||||
if self._ensure_db():
|
||||
self.db.clear_entities("role")
|
||||
|
||||
def clear_asset_mappings(self):
|
||||
"""Clears all emoji and sticker mappings."""
|
||||
self.emoji_map.clear()
|
||||
self.sticker_map.clear()
|
||||
self.save_state()
|
||||
if self._ensure_db():
|
||||
self.db.clear_entities("emoji")
|
||||
self.db.clear_entities("sticker")
|
||||
|
||||
def clear_message_history(self):
|
||||
"""Clears all message mappings and timestamps."""
|
||||
self.channel_messages.clear()
|
||||
self.save_messages()
|
||||
if self.db:
|
||||
conn = self.db._get_conn()
|
||||
conn.execute("DELETE FROM message_mappings")
|
||||
conn.execute("DELETE FROM thread_mappings")
|
||||
conn.execute("DELETE FROM channel_tracking")
|
||||
conn.execute("DELETE FROM thread_tracking")
|
||||
conn.commit()
|
||||
|
||||
def set_folder(self, server_id: str, clean_name: str, base_dir: Path | str = ""):
|
||||
"""
|
||||
Initializes the SQLite database based on community name and ID.
|
||||
Filename: {name}-{id}.db (Flat structure)
|
||||
ID is priority: if a DB with the same ID exists but different name, rename it.
|
||||
"""
|
||||
base = Path(base_dir) if base_dir else Path(".")
|
||||
new_folder = base / f"{clean_name}-{server_id}"
|
||||
logger.info(f"Setting active migration folder: {new_folder}")
|
||||
desired_filename = f"{clean_name}-{server_id}.db"
|
||||
desired_path = base / desired_filename
|
||||
|
||||
# 1. Search base_dir to see if an older folder for this server_id exists
|
||||
existing_folder: Path | None = None
|
||||
if base.exists() and base.is_dir():
|
||||
for d in base.iterdir():
|
||||
if d.is_dir() and d.name.endswith(f"-{server_id}"):
|
||||
existing_folder = d
|
||||
# Priority 1: Match by ID
|
||||
existing_db: Path | None = None
|
||||
# Look for any file ending with -{server_id}.db
|
||||
for f in base.glob(f"*-{server_id}.db"):
|
||||
if f.is_file():
|
||||
existing_db = f
|
||||
break
|
||||
|
||||
# 2. Rename it if it doesn't match the new desired name
|
||||
if existing_folder and existing_folder != new_folder:
|
||||
logger.info(f"Renaming existing folder {existing_folder.name} to {new_folder.name}")
|
||||
db_path = desired_path
|
||||
if existing_db:
|
||||
if existing_db.name != desired_filename:
|
||||
logger.info(f"Server renamed: moving {existing_db.name} -> {desired_filename}")
|
||||
try:
|
||||
existing_folder.rename(new_folder)
|
||||
existing_db.rename(desired_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not rename {existing_folder} to {new_folder}: {e}")
|
||||
logger.error(f"Failed to rename database: {e}")
|
||||
# If rename fails, we'll use the existing one if it exists at the old path,
|
||||
# or the desired one if it exists there.
|
||||
if not desired_path.exists():
|
||||
db_path = existing_db
|
||||
|
||||
new_folder.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Setting active migration database: {db_path}")
|
||||
|
||||
self.state_file = new_folder / "state-migration.json"
|
||||
self.messages_file = new_folder / "message-tracker.json"
|
||||
from src.core.database import MigrationDatabase
|
||||
if self.db:
|
||||
self.db.close()
|
||||
self.db = MigrationDatabase(db_path)
|
||||
logger.info(f"Initialized SQLite database at {db_path}")
|
||||
|
||||
logger.debug("Re-loading data from new folder location.")
|
||||
self.load()
|
||||
# No-op methods kept for compatibility with callers that might try to load/save JSON
|
||||
def load(self): pass
|
||||
def save_state(self): pass
|
||||
|
|
|
|||
|
|
@ -1576,7 +1576,7 @@ class OperationPane(Container):
|
|||
|
||||
async def _fetch_clone_preview(self, selections: list[str]) -> dict[str, Any]:
|
||||
"""Fetches preview data from Discord (source server) for cloning confirmation,
|
||||
comparing with existing mappings in state-migration.json for presence highlighting."""
|
||||
comparing with existing mappings in SQLite for presence highlighting."""
|
||||
preview = {}
|
||||
reader = self.engine.discord_reader
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue