256 lines
10 KiB
Python
256 lines
10 KiB
Python
import logging
import sqlite3
import threading
from contextlib import closing
from pathlib import Path
from typing import Optional, Dict, Any
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
class MigrationDatabase:
    """SQLite-based persistence for large-scale migration mappings and stats.

    Replaces the memory-bloated and O(N^2) JSON persistence for messages.

    Connections are cached in a ``threading.local`` that belongs to the
    instance, so each (instance, thread) pair lazily opens its own connection
    to ``db_path``. Two instances pointing at different files therefore never
    see each other's connection.

    Every write method runs inside ``with conn:`` so its statements are
    committed together on success and rolled back if any of them raises.
    """

    # DDL run on every start-up. Each statement is idempotent
    # (CREATE TABLE IF NOT EXISTS), so re-opening an existing file is safe.
    _SCHEMA = (
        # Table for message mappings: SourceID -> TargetID
        """
        CREATE TABLE IF NOT EXISTS message_mappings (
            channel_id TEXT,
            source_msg_id TEXT,
            target_msg_id TEXT,
            timestamp TEXT,
            PRIMARY KEY (channel_id, source_msg_id)
        )
        """,
        # Table for thread mappings
        """
        CREATE TABLE IF NOT EXISTS thread_mappings (
            channel_id TEXT,
            thread_id TEXT,
            source_msg_id TEXT,
            target_msg_id TEXT,
            timestamp TEXT,
            PRIMARY KEY (channel_id, thread_id, source_msg_id)
        )
        """,
        # Table for per-channel stats and tracking
        """
        CREATE TABLE IF NOT EXISTS channel_tracking (
            channel_id TEXT PRIMARY KEY,
            last_msg_id TEXT,
            last_msg_ts TEXT,
            msg_count INTEGER DEFAULT 0,
            file_count INTEGER DEFAULT 0
        )
        """,
        # Table for per-thread stats
        """
        CREATE TABLE IF NOT EXISTS thread_tracking (
            channel_id TEXT,
            thread_id TEXT,
            last_msg_id TEXT,
            last_msg_ts TEXT,
            msg_count INTEGER DEFAULT 0,
            file_count INTEGER DEFAULT 0,
            completed INTEGER DEFAULT 0,
            PRIMARY KEY (channel_id, thread_id)
        )
        """,
        # Table for entity mappings (channels, roles, etc.)
        """
        CREATE TABLE IF NOT EXISTS entity_mappings (
            category TEXT,
            source_id TEXT,
            target_id TEXT,
            PRIMARY KEY (category, source_id)
        )
        """,
        # Table for general metadata
        """
        CREATE TABLE IF NOT EXISTS metadata (
            key TEXT PRIMARY KEY,
            value TEXT
        )
        """,
    )

    def __init__(self, db_path: Path):
        """Open (creating if necessary) the database stored at ``db_path``.

        Args:
            db_path: Location of the SQLite file.
        """
        self.db_path = db_path
        # Per-instance store: a class-level threading.local would hand the
        # same cached connection to every instance living on one thread.
        self._local = threading.local()
        self._init_db()

    def _get_conn(self) -> sqlite3.Connection:
        """Return the calling thread's connection, opening it on first use."""
        conn = getattr(self._local, "conn", None)
        if conn is None:
            conn = sqlite3.connect(self.db_path, check_same_thread=False)
            conn.row_factory = sqlite3.Row
            self._local.conn = conn
        return conn

    def _init_db(self) -> None:
        """Initialize tables if they don't exist."""
        # A short-lived dedicated connection; closing() guarantees it is
        # released even when a statement fails.
        with closing(sqlite3.connect(self.db_path)) as conn:
            with conn:
                for statement in self._SCHEMA:
                    conn.execute(statement)

                # Add completed column if it doesn't exist (backward
                # compatibility for existing resumption DBs). Inspecting the
                # schema instead of catching OperationalError keeps real
                # failures (locked file, I/O error) visible.
                existing_columns = {
                    info[1]  # table_info rows are (cid, name, type, ...)
                    for info in conn.execute("PRAGMA table_info(thread_tracking)")
                }
                if "completed" not in existing_columns:
                    conn.execute(
                        "ALTER TABLE thread_tracking ADD COLUMN completed INTEGER DEFAULT 0"
                    )

    def set_message_mapping(
        self,
        channel_id: str,
        source_id: str,
        target_id: str,
        timestamp: Optional[str] = None,
    ) -> None:
        """Insert or overwrite the source -> target message mapping for a channel."""
        conn = self._get_conn()
        with conn:
            conn.execute(
                "INSERT OR REPLACE INTO message_mappings (channel_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?)",
                (channel_id, source_id, target_id, timestamp),
            )

    def get_target_message_id(self, channel_id: str, source_id: str) -> Optional[str]:
        """Return the target message id mapped to ``source_id``, or ``None``."""
        row = self._get_conn().execute(
            "SELECT target_msg_id FROM message_mappings WHERE channel_id = ? AND source_msg_id = ?",
            (channel_id, source_id),
        ).fetchone()
        return row["target_msg_id"] if row else None

    # --- New Entity Mapping Methods ---

    def set_entity_mapping(self, category: str, source_id: str, target_id: str) -> None:
        """Insert or overwrite an entity mapping; both ids are stored as ``str``."""
        conn = self._get_conn()
        with conn:
            conn.execute(
                "INSERT OR REPLACE INTO entity_mappings (category, source_id, target_id) VALUES (?, ?, ?)",
                (category, str(source_id), str(target_id)),
            )

    def get_entity_mapping(self, category: str, source_id: str) -> Optional[str]:
        """Return the target id for ``source_id`` within ``category``, or ``None``."""
        row = self._get_conn().execute(
            "SELECT target_id FROM entity_mappings WHERE category = ? AND source_id = ?",
            (category, str(source_id)),
        ).fetchone()
        return row["target_id"] if row else None

    def get_all_entity_mappings(self, category: str) -> Dict[str, str]:
        """Return every ``source_id -> target_id`` pair stored for ``category``."""
        rows = self._get_conn().execute(
            "SELECT source_id, target_id FROM entity_mappings WHERE category = ?",
            (category,),
        ).fetchall()
        return {row["source_id"]: row["target_id"] for row in rows}

    def delete_entity_mapping(self, category: str, source_id: str) -> None:
        """Delete one entity mapping; a missing row is not an error."""
        conn = self._get_conn()
        with conn:
            conn.execute(
                "DELETE FROM entity_mappings WHERE category = ? AND source_id = ?",
                (category, str(source_id)),
            )

    def clear_entities(self, category: Optional[str] = None) -> None:
        """Delete entity mappings.

        Args:
            category: When given, only that category is removed. When omitted
                (``None``), every entity mapping is removed. An empty string
                is treated as a category name, not as "everything", so a
                blank value cannot wipe the whole table by accident.
        """
        conn = self._get_conn()
        with conn:
            if category is not None:
                conn.execute("DELETE FROM entity_mappings WHERE category = ?", (category,))
            else:
                conn.execute("DELETE FROM entity_mappings")

    # --- Metadata Methods ---

    def set_metadata(self, key: str, value: str) -> None:
        """Store ``value`` under ``key``; non-``None`` values are converted with ``str``."""
        stored = str(value) if value is not None else None
        conn = self._get_conn()
        with conn:
            conn.execute(
                "INSERT OR REPLACE INTO metadata (key, value) VALUES (?, ?)",
                (key, stored),
            )

    def get_metadata(self, key: str) -> Optional[str]:
        """Return the value stored under ``key``, or ``None`` when absent."""
        row = self._get_conn().execute(
            "SELECT value FROM metadata WHERE key = ?", (key,)
        ).fetchone()
        return row["value"] if row else None

    def update_channel_tracking(
        self,
        channel_id: str,
        last_msg_id: Optional[str] = None,
        last_msg_ts: Optional[str] = None,
        msg_inc: int = 0,
        file_inc: int = 0,
    ) -> None:
        """Create the channel's tracking row if needed and apply the given changes.

        Args:
            channel_id: Channel whose row is updated.
            last_msg_id: New last-message id; falsy values leave the column untouched.
            last_msg_ts: New last-message timestamp; falsy values leave it untouched.
            msg_inc: Amount added to ``msg_count``.
            file_inc: Amount added to ``file_count``.
        """
        conn = self._get_conn()
        with conn:
            # Initialize if missing
            conn.execute(
                "INSERT OR IGNORE INTO channel_tracking (channel_id) VALUES (?)",
                (channel_id,),
            )
            if last_msg_id:
                conn.execute(
                    "UPDATE channel_tracking SET last_msg_id = ? WHERE channel_id = ?",
                    (last_msg_id, channel_id),
                )
            if last_msg_ts:
                conn.execute(
                    "UPDATE channel_tracking SET last_msg_ts = ? WHERE channel_id = ?",
                    (last_msg_ts, channel_id),
                )
            if msg_inc != 0 or file_inc != 0:
                conn.execute(
                    "UPDATE channel_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ?",
                    (msg_inc, file_inc, channel_id),
                )

    def get_channel_tracking(self, channel_id: str) -> Dict[str, Any]:
        """Return the channel's tracking row as a dict, or zeroed defaults if none exists."""
        row = self._get_conn().execute(
            "SELECT * FROM channel_tracking WHERE channel_id = ?", (channel_id,)
        ).fetchone()
        if row:
            return dict(row)
        return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0}

    # Thread methods similar to channel methods

    def set_thread_message_mapping(
        self,
        channel_id: str,
        thread_id: str,
        source_id: str,
        target_id: str,
        timestamp: Optional[str] = None,
    ) -> None:
        """Insert or overwrite the source -> target message mapping inside a thread."""
        conn = self._get_conn()
        with conn:
            conn.execute(
                "INSERT OR REPLACE INTO thread_mappings (channel_id, thread_id, source_msg_id, target_msg_id, timestamp) VALUES (?, ?, ?, ?, ?)",
                (channel_id, thread_id, source_id, target_id, timestamp),
            )

    def get_target_thread_message_id(
        self, channel_id: str, thread_id: str, source_id: str
    ) -> Optional[str]:
        """Return the target message id mapped to ``source_id`` in the thread, or ``None``."""
        row = self._get_conn().execute(
            "SELECT target_msg_id FROM thread_mappings WHERE channel_id = ? AND thread_id = ? AND source_msg_id = ?",
            (channel_id, thread_id, source_id),
        ).fetchone()
        return row["target_msg_id"] if row else None

    def update_thread_tracking(
        self,
        channel_id: str,
        thread_id: str,
        last_msg_id: Optional[str] = None,
        last_msg_ts: Optional[str] = None,
        msg_inc: int = 0,
        file_inc: int = 0,
        completed: Optional[int] = None,
    ) -> None:
        """Create the thread's tracking row if needed and apply the given changes.

        Args:
            channel_id: Channel that owns the thread.
            thread_id: Thread whose row is updated.
            last_msg_id: New last-message id; falsy values leave the column untouched.
            last_msg_ts: New last-message timestamp; falsy values leave it untouched.
            msg_inc: Amount added to ``msg_count``.
            file_inc: Amount added to ``file_count``.
            completed: New value for the ``completed`` column; ``None`` leaves it
                untouched (``0`` is written, unlike the falsy-skipping fields above).
        """
        conn = self._get_conn()
        with conn:
            conn.execute(
                "INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)",
                (channel_id, thread_id),
            )
            if last_msg_id:
                conn.execute(
                    "UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?",
                    (last_msg_id, channel_id, thread_id),
                )
            if last_msg_ts:
                conn.execute(
                    "UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?",
                    (last_msg_ts, channel_id, thread_id),
                )
            if completed is not None:
                conn.execute(
                    "UPDATE thread_tracking SET completed = ? WHERE channel_id = ? AND thread_id = ?",
                    (completed, channel_id, thread_id),
                )
            if msg_inc != 0 or file_inc != 0:
                conn.execute(
                    "UPDATE thread_tracking SET msg_count = msg_count + ?, file_count = file_count + ? WHERE channel_id = ? AND thread_id = ?",
                    (msg_inc, file_inc, channel_id, thread_id),
                )

    def get_thread_tracking(self, channel_id: str, thread_id: str) -> Dict[str, Any]:
        """Return the thread's tracking row as a dict, or zeroed defaults if none exists.

        The defaults include ``completed`` so callers can read that key
        without first checking whether the thread has been seen.
        """
        row = self._get_conn().execute(
            "SELECT * FROM thread_tracking WHERE channel_id = ? AND thread_id = ?",
            (channel_id, thread_id),
        ).fetchone()
        if row:
            return dict(row)
        return {
            "last_msg_id": None,
            "last_msg_ts": None,
            "msg_count": 0,
            "file_count": 0,
            "completed": 0,
        }

    def clear_channel_data(self, channel_id: str) -> None:
        """Purge all mappings and tracking data for a specific channel and its threads."""
        conn = self._get_conn()
        with conn:
            for table in (
                "message_mappings",
                "thread_mappings",
                "channel_tracking",
                "thread_tracking",
            ):
                # Table names come from the fixed tuple above, never from input.
                conn.execute(f"DELETE FROM {table} WHERE channel_id = ?", (channel_id,))
        logger.info("Cleared all tracking and mapping data for channel: %s", channel_id)

    def close(self) -> None:
        """Close the calling thread's connection, if one was opened.

        Connections opened by other threads are not touched.
        """
        conn = getattr(self._local, "conn", None)
        if conn is not None:
            conn.close()
            del self._local.conn