311 lines
13 KiB
Python
311 lines
13 KiB
Python
import json
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Dict, Optional, Any, Union, TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
    # Imported only for static type checkers. At runtime the class is imported
    # lazily inside MigrationState.set_folder (presumably to avoid an import
    # cycle with src.core.database — TODO confirm).
    from src.core.database import MigrationDatabase


# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
class MigrationState:
|
|
"""Manages persistence of the migration state to allow resumability.
|
|
Uses SQLite for ALL mappings and metadata.
|
|
"""
|
|
|
|
def __init__(self):
|
|
# database instance for all persistence
|
|
self.db: Optional['MigrationDatabase'] = None
|
|
|
|
def _ensure_db(self):
|
|
if not self.db:
|
|
logger.warning("MigrationState: Accessing database before initialization")
|
|
return False
|
|
return True
|
|
|
|
# --- Type Specific Getters/Setters # --- Channel Mapping ---
|
|
def set_channel_mapping(self, discord_id: int | str, target_id: str):
|
|
"""Maps an original text/voice/forum channel ID to a minted server channel ID."""
|
|
if self.db:
|
|
self.db.set_server_mapping("channel", str(discord_id), str(target_id))
|
|
|
|
def get_target_channel_id(self, discord_id: int | str) -> str | None:
|
|
if self.db:
|
|
return self.db.get_server_mapping("channel", str(discord_id))
|
|
return None
|
|
|
|
def remove_channel_mapping(self, discord_id: int | str):
|
|
if self.db:
|
|
self.db.delete_server_mapping("channel", str(discord_id))
|
|
|
|
|
|
get_fluxer_channel_id = get_target_channel_id
|
|
set_target_channel_mapping = set_channel_mapping
|
|
|
|
# --- Category Mapping ---
|
|
def set_category_mapping(self, discord_id: int | str, target_id: str):
|
|
"""Maps an original discord category ID to a Stoat category Group ID."""
|
|
if self.db:
|
|
self.db.set_server_mapping("category", str(discord_id), str(target_id))
|
|
|
|
def get_category_mapping(self, discord_id: int | str) -> str | None:
|
|
"""Returns the Stoat Group ID for a previously migrated category."""
|
|
if self.db:
|
|
return self.db.get_server_mapping("category", str(discord_id))
|
|
return None
|
|
|
|
def remove_category_mapping(self, discord_id: int | str):
|
|
if self.db:
|
|
self.db.delete_server_mapping("category", str(discord_id))
|
|
|
|
get_fluxer_category_id = get_category_mapping
|
|
get_target_category_id = get_category_mapping
|
|
set_target_category_mapping = set_category_mapping
|
|
|
|
# --- Role Mapping ---
|
|
def set_role_mapping(self, discord_id: int | str, target_id: str):
|
|
"""Maps an original discord Role ID to a Stoat Role ID."""
|
|
if self.db:
|
|
self.db.set_server_mapping("role", str(discord_id), str(target_id))
|
|
|
|
def get_role_mapping(self, discord_id: int | str) -> str | None:
|
|
"""Returns the target Role ID for a previously migrated Role."""
|
|
if self.db:
|
|
return self.db.get_server_mapping("role", str(discord_id))
|
|
return None
|
|
|
|
def remove_role_mapping(self, discord_id: int | str):
|
|
if self.db:
|
|
self.db.delete_server_mapping("role", str(discord_id))
|
|
|
|
get_fluxer_role_id = get_role_mapping
|
|
get_target_role_id = get_role_mapping
|
|
set_target_role_mapping = set_role_mapping
|
|
|
|
# --- Emoji Mapping ---
|
|
def set_emoji_mapping(self, discord_id: int | str, target_id: str):
|
|
"""Maps an original discord Custom Emoji ID to a minted Emoji ID/URL."""
|
|
if self.db:
|
|
self.db.set_asset_mapping("emoji", str(discord_id), str(target_id))
|
|
|
|
def get_emoji_mapping(self, discord_id: int | str) -> str | None:
|
|
if self.db:
|
|
return self.db.get_asset_mapping("emoji", str(discord_id))
|
|
return None
|
|
|
|
def remove_emoji_mapping(self, discord_id: int | str):
|
|
if self.db:
|
|
self.db.delete_asset_mapping("emoji", str(discord_id))
|
|
|
|
get_fluxer_emoji_id = get_emoji_mapping
|
|
get_target_emoji_id = get_emoji_mapping
|
|
set_target_emoji_mapping = set_emoji_mapping
|
|
|
|
# --- Sticker Mapping ---
|
|
def set_sticker_mapping(self, discord_id: int | str, target_id: str):
|
|
"""Maps an original discord Custom Sticker ID to a target URL or ID."""
|
|
if self.db:
|
|
self.db.set_asset_mapping("sticker", str(discord_id), str(target_id))
|
|
|
|
def get_sticker_mapping(self, discord_id: int | str) -> str | None:
|
|
if self.db:
|
|
return self.db.get_asset_mapping("sticker", str(discord_id))
|
|
return None
|
|
|
|
def remove_sticker_mapping(self, discord_id: int | str):
|
|
if self.db:
|
|
self.db.delete_asset_mapping("sticker", str(discord_id))
|
|
|
|
get_fluxer_sticker_id = get_sticker_mapping
|
|
get_target_sticker_id = get_sticker_mapping
|
|
set_target_sticker_mapping = set_sticker_mapping
|
|
|
|
# --- Properties for backward compatibility ---
|
|
@property
|
|
def channel_map(self) -> Dict[Union[str, int], Union[str, int]]:
|
|
return self.db.get_all_server_mappings("channel") if self.db else {}
|
|
|
|
@property
|
|
def category_map(self) -> Dict[Union[str, int], Union[str, int]]:
|
|
return self.db.get_all_server_mappings("category") if self.db else {}
|
|
|
|
@property
|
|
def role_map(self) -> Dict[Union[str, int], Union[str, int]]:
|
|
return self.db.get_all_server_mappings("role") if self.db else {}
|
|
|
|
@property
|
|
def emoji_map(self) -> Dict[Union[str, int], Union[str, int]]:
|
|
return self.db.get_all_asset_mappings("emoji") if self.db else {}
|
|
|
|
@property
|
|
def sticker_map(self) -> Dict[Union[str, int], Union[str, int]]:
|
|
return self.db.get_all_asset_mappings("sticker") if self.db else {}
|
|
|
|
@property
|
|
def audit_log_channel(self) -> str | None:
|
|
return self.db.get_metadata("audit_log_channel") if self.db else None
|
|
|
|
@audit_log_channel.setter
|
|
def audit_log_channel(self, value: str | None):
|
|
if self._ensure_db():
|
|
self.db.set_metadata("audit_log_channel", value)
|
|
|
|
# --- Message Management ---
|
|
|
|
def set_target_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str):
|
|
if self._ensure_db():
|
|
self.db.set_message_mapping(str(target_channel_id), str(discord_id), str(target_id))
|
|
|
|
def get_target_message_id(self, target_channel_id: str, discord_id: str) -> str | int | None:
|
|
if self._ensure_db():
|
|
return self.db.get_target_message_id(str(target_channel_id), str(discord_id))
|
|
return None
|
|
|
|
def set_message_mapping(self, target_channel_id: str, discord_id: str, target_id: str):
|
|
self.set_target_message_mapping(target_channel_id, discord_id, target_id)
|
|
|
|
def get_fluxer_message_id(self, target_channel_id: str, discord_id: str) -> str | int | None:
|
|
return self.get_target_message_id(target_channel_id, discord_id)
|
|
|
|
def increment_stats(self, target_channel_id: str, messages: int = 1, files: int = 0):
|
|
if self._ensure_db():
|
|
self.db.update_channel_tracking(str(target_channel_id), msg_inc=messages, file_inc=files)
|
|
|
|
def increment_thread_stats(self, target_channel_id: str, thread_id: str, messages: int = 1, files: int = 0):
|
|
if self._ensure_db():
|
|
self.db.update_thread_tracking(str(target_channel_id), str(thread_id), msg_inc=messages, file_inc=files)
|
|
|
|
def set_thread_message_mapping(self, target_channel_id: str, thread_id: str, discord_id: str, target_id: str):
|
|
if self._ensure_db():
|
|
self.db.set_thread_message_mapping(str(target_channel_id), str(thread_id), str(discord_id), str(target_id))
|
|
|
|
def update_thread_last_message_timestamp(self, target_channel_id: str, thread_id: str, timestamp: str):
|
|
if self._ensure_db():
|
|
self.db.update_thread_tracking(str(target_channel_id), str(thread_id), last_msg_ts=str(timestamp))
|
|
|
|
def update_thread_last_message_id(self, target_channel_id: str, thread_id: str, message_id: str):
|
|
if self._ensure_db():
|
|
self.db.update_thread_tracking(str(target_channel_id), str(thread_id), last_msg_id=str(message_id))
|
|
|
|
def update_thread_completed(self, target_channel_id: str, thread_id: str, completed: bool = True):
|
|
if self._ensure_db():
|
|
self.db.update_thread_tracking(str(target_channel_id), str(thread_id), completed=1 if completed else 0)
|
|
|
|
def is_thread_completed(self, target_channel_id: str, thread_id: str) -> bool:
|
|
if self._ensure_db():
|
|
tracking = self.db.get_thread_tracking(str(target_channel_id), str(thread_id))
|
|
return bool(tracking.get("completed", 0))
|
|
return False
|
|
|
|
def get_thread_message_id(self, target_channel_id: str, thread_id: str, discord_id: str) -> str | None:
|
|
if self._ensure_db():
|
|
return self.db.get_target_thread_message_id(str(target_channel_id), str(thread_id), str(discord_id))
|
|
return None
|
|
|
|
def update_last_message_timestamp(self, target_channel_id: str, timestamp: str):
|
|
if self._ensure_db():
|
|
self.db.update_channel_tracking(str(target_channel_id), last_msg_ts=str(timestamp))
|
|
|
|
def update_last_message_id(self, target_channel_id: str, message_id: str):
|
|
if self._ensure_db():
|
|
self.db.update_channel_tracking(str(target_channel_id), last_msg_id=str(message_id))
|
|
|
|
def get_last_message_id(self, target_channel_id: str) -> str | None:
|
|
if self._ensure_db():
|
|
return self.db.get_channel_tracking(str(target_channel_id)).get("last_msg_id")
|
|
return None
|
|
|
|
def get_thread_last_message_id(self, target_channel_id: str, thread_id: str) -> str | None:
|
|
if self._ensure_db():
|
|
return self.db.get_thread_tracking(str(target_channel_id), str(thread_id)).get("last_msg_id")
|
|
return None
|
|
|
|
def find_message_mapping(self, discord_id: str) -> tuple[str, str] | tuple[None, None]:
|
|
if not self.db:
|
|
return None, None
|
|
conn = self.db._get_conn()
|
|
row = conn.execute("SELECT channel_id, target_msg_id FROM message_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone()
|
|
if row:
|
|
return str(row["channel_id"]), str(row["target_msg_id"])
|
|
row = conn.execute("SELECT channel_id, target_msg_id FROM thread_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone()
|
|
if row:
|
|
return str(row["channel_id"]), str(row["target_msg_id"])
|
|
return None, None
|
|
|
|
# --- Danger Zone Clearing ---
|
|
|
|
def clear_channel_mappings(self):
|
|
if self.db:
|
|
self.db.clear_server_mappings("channel")
|
|
self.db.clear_server_mappings("category")
|
|
|
|
def clear_role_mappings(self):
|
|
if self.db:
|
|
self.db.clear_server_mappings("role")
|
|
|
|
def clear_asset_mappings(self):
|
|
if self.db:
|
|
self.db.clear_asset_mappings("emoji")
|
|
self.db.clear_asset_mappings("sticker")
|
|
|
|
def clear_message_history(self):
|
|
if self.db:
|
|
conn = self.db._get_conn()
|
|
conn.execute("DELETE FROM message_mappings")
|
|
conn.execute("DELETE FROM thread_mappings")
|
|
conn.execute("DELETE FROM channel_tracking")
|
|
conn.execute("DELETE FROM thread_tracking")
|
|
conn.commit()
|
|
|
|
def clear_channel_data(self, target_channel_id: str):
|
|
if self._ensure_db():
|
|
self.db.clear_channel_data(str(target_channel_id))
|
|
|
|
def set_folder(self, server_id: str, clean_name: str, platform: str = "stoat", base_dir: Path | str = ""):
|
|
"""
|
|
Initializes the SQLite database based on community name and ID.
|
|
Filename: {name}-{id}.db (Flat structure)
|
|
ID is priority: if a DB with the same ID exists but different name, rename it.
|
|
"""
|
|
base = Path(base_dir) if base_dir else Path(".")
|
|
desired_filename = f"{clean_name}-{server_id}.db"
|
|
desired_path = base / desired_filename
|
|
|
|
# Priority 1: Match by ID
|
|
existing_db: Path | None = None
|
|
# Look for any file ending with -{server_id}.db
|
|
for f in base.glob(f"*-{server_id}.db"):
|
|
if f.is_file():
|
|
existing_db = f
|
|
break
|
|
|
|
db_path = desired_path
|
|
if existing_db:
|
|
if existing_db.name != desired_filename:
|
|
logger.info(f"Server renamed: moving {existing_db.name} -> {desired_filename}")
|
|
try:
|
|
existing_db.rename(desired_path)
|
|
except Exception as e:
|
|
logger.error(f"Failed to rename database: {e}")
|
|
# If rename fails, we'll use the existing one if it exists at the old path,
|
|
# or the desired one if it exists there.
|
|
if not desired_path.exists():
|
|
db_path = existing_db
|
|
|
|
logger.info(f"Setting active migration database: {db_path}")
|
|
|
|
from src.core.database import MigrationDatabase
|
|
if self.db:
|
|
self.db.close()
|
|
self.db = MigrationDatabase(db_path, platform=platform)
|
|
logger.info(f"Initialized SQLite database at {db_path}")
|
|
|
|
def get_user_alias(self, user_id: str) -> str | None:
|
|
"""Gets or generates a unique alias for a given user ID via the Migration Database."""
|
|
if self.db:
|
|
return self.db.get_or_create_user_alias(user_id)
|
|
return None
|
|
|
|
# No-op methods kept for compatibility with callers that might try to load/save JSON
|
|
def load(self): pass
|
|
def save_state(self): pass
|