diff --git a/src/core/backup_database.py b/src/core/backup_database.py new file mode 100644 index 0000000..e5f00dd --- /dev/null +++ b/src/core/backup_database.py @@ -0,0 +1,762 @@ +import sqlite3 +import logging +import json +import threading +from pathlib import Path +from typing import Dict, Any, List, Optional + +logger = logging.getLogger(__name__) + +class BackupDatabase: + """Manages the SQLite database for local Discord backups.""" + + def __init__(self, db_path: Path | str): + self.db_path = Path(db_path) + self._lock = threading.Lock() + self._init_db() + + def _get_conn(self): + conn = sqlite3.connect(self.db_path, check_same_thread=False) + conn.row_factory = sqlite3.Row + return conn + + def _init_db(self): + """Initializes the database schema.""" + with self._lock: + conn = self._get_conn() + try: + # Guild Profile + conn.execute(""" + CREATE TABLE IF NOT EXISTS guild_profile ( + id TEXT PRIMARY KEY, + name TEXT, + description TEXT, + icon_file TEXT, + icon_url TEXT, + banner_file TEXT, + banner_url TEXT, + owner_id TEXT, + last_backup TEXT, + ignore_channels TEXT + ) + """) + + # Roles + conn.execute(""" + CREATE TABLE IF NOT EXISTS roles ( + id TEXT PRIMARY KEY, + name TEXT, + color INTEGER, + position INTEGER, + permissions TEXT, + hoist INTEGER, + mentionable INTEGER + ) + """) + + # Channels + conn.execute(""" + CREATE TABLE IF NOT EXISTS channels ( + id TEXT PRIMARY KEY, + name TEXT, + type INTEGER, + position INTEGER, + category_id TEXT, + topic TEXT, + nsfw INTEGER + ) + """) + + # Channel Permissions + conn.execute(""" + CREATE TABLE IF NOT EXISTS permissions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + channel_id TEXT, + target_id TEXT, + target_type TEXT, + allow INTEGER, + deny INTEGER + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_permissions_chan ON permissions(channel_id)") + + # Users (Author cache) + conn.execute(""" + CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, + username TEXT, + display_name TEXT, + avatar_file TEXT, + avatar_url TEXT, + roles TEXT + ) + """) + + # Messages + conn.execute(""" + CREATE TABLE IF NOT EXISTS messages ( + id TEXT PRIMARY KEY, + channel_id TEXT, + author_id TEXT, + content TEXT, + timestamp TEXT, + type INTEGER, + message_reference TEXT, + is_pinned INTEGER, + extra_data TEXT + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_messages_channel ON messages(channel_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_messages_timestamp ON messages(timestamp)") + + # Attachments + conn.execute(""" + CREATE TABLE IF NOT EXISTS attachments ( + id TEXT PRIMARY KEY, + message_id TEXT, + filename TEXT, + size INTEGER, + url TEXT, + content_type TEXT, + local_hash TEXT + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_attachments_msg ON attachments(message_id)") + + # Embeds + conn.execute(""" + CREATE TABLE IF NOT EXISTS embeds ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + message_id TEXT, + title TEXT, + description TEXT, + url TEXT, + color INTEGER, + timestamp TEXT, + thumbnail_url TEXT, + image_url TEXT, + author_name TEXT, + author_url TEXT, + author_icon_url TEXT, + footer_text TEXT, + footer_icon_url TEXT, + fields TEXT + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_embeds_msg ON embeds(message_id)") + + # Reactions + conn.execute(""" + CREATE TABLE IF NOT EXISTS reactions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + message_id TEXT, + emoji_id TEXT, + emoji_name TEXT, + count INTEGER + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_reactions_msg ON reactions(message_id)") + + # Message Stickers + conn.execute(""" + CREATE TABLE IF NOT EXISTS message_stickers ( + message_id TEXT, + sticker_id TEXT, + name TEXT, + url TEXT, + format_type INTEGER, + local_hash TEXT, + PRIMARY KEY (message_id, sticker_id) + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_message_stickers_msg ON message_stickers(message_id)") + + # Threads + conn.execute(""" + CREATE TABLE IF NOT EXISTS threads ( + id TEXT PRIMARY KEY, + name TEXT, + type INTEGER, + parent_id TEXT, + message_count INTEGER, + member_count INTEGER, + archived INTEGER, + archive_timestamp TEXT, + auto_archive_duration INTEGER, + locked INTEGER, + applied_tags TEXT + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_threads_parent ON threads(parent_id)") + + # Forum Tags (Definitions for a forum channel) + conn.execute(""" + CREATE TABLE IF NOT EXISTS forum_tags ( + id TEXT PRIMARY KEY, + forum_id TEXT, + name TEXT, + moderated INTEGER, + emoji_id TEXT, + emoji_name TEXT + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_forum_tags_forum ON forum_tags(forum_id)") + + # Media Pool (CAS) + # Maps content hashes to local storage paths + conn.execute(""" + CREATE TABLE IF NOT EXISTS media_pool ( + hash TEXT PRIMARY KEY, + local_path TEXT, + size INTEGER, + mime_type TEXT, + first_seen_url TEXT + ) + """) + + # Server Assets (Emojis, Stickers, etc.) + conn.execute(""" + CREATE TABLE IF NOT EXISTS server_assets ( + id TEXT PRIMARY KEY, + name TEXT, + type TEXT, + filename TEXT, + url TEXT, + mime_type TEXT + ) + """) + + conn.commit() + finally: + conn.close() + + def set_guild_profile(self, data: Dict[str, Any]): + with self._lock: + conn = self._get_conn() + conn.execute(""" + INSERT OR REPLACE INTO guild_profile (id, name, description, icon_file, icon_url, banner_file, banner_url, owner_id, last_backup, ignore_channels) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + str(data.get("id")), data.get("name"), data.get("description"), + data.get("icon_file"), data.get("icon_url"), + data.get("banner_file"), data.get("banner_url"), + str(data.get("owner_id")), + data.get("last_backup"), json.dumps(data.get("ignore_channels", [])) + )) + conn.commit() + conn.close() + + def get_guild_profile(self) -> Optional[Dict[str, Any]]: + with self._lock: + conn = self._get_conn() + row = conn.execute("SELECT * FROM guild_profile LIMIT 1").fetchone() + conn.close() + if row: + data = dict(row) + if data.get("ignore_channels"): + data["ignore_channels"] = json.loads(data["ignore_channels"]) + else: + data["ignore_channels"] = [] + return data + return None + + def save_roles(self, roles: List[Dict[str, Any]]): + with self._lock: + conn = self._get_conn() + # Ensure complex fields are strings if they aren't already + formatted = [] + for r in roles: + formatted.append({ + "id": str(r["id"]), + "name": r["name"], + "color": r["color"], + "position": r["position"], + "permissions": str(r["permissions"]), + "hoist": 1 if r["hoist"] else 0, + "mentionable": 1 if r["mentionable"] else 0 + }) + + conn.executemany(""" + INSERT OR REPLACE INTO roles (id, name, color, position, permissions, hoist, mentionable) + VALUES (:id, :name, :color, :position, :permissions, :hoist, :mentionable) + """, formatted) + conn.commit() + conn.close() + + def save_channels(self, channels: List[Dict[str, Any]]): + with self._lock: + conn = self._get_conn() + conn.executemany(""" + INSERT OR REPLACE INTO channels (id, name, type, position, category_id, topic, nsfw) + VALUES (:id, :name, :type, :position, :category_id, :topic, :nsfw) + """, channels) + conn.commit() + conn.close() + + def save_permissions(self, permissions: List[Dict[str, Any]]): + """Saves a batch of channel permission overwrites.""" + with self._lock: + conn = self._get_conn() + try: + conn.executemany(""" + INSERT INTO permissions (channel_id, target_id, target_type, allow, deny) + VALUES (:channel_id, :target_id, :target_type, :allow, :deny) + """, permissions) + conn.commit() + finally: + conn.close() + + def save_users(self, users: List[Dict[str, Any]]): + """Saves users to the author cache.""" + with self._lock: + conn = self._get_conn() + conn.executemany(""" + INSERT OR REPLACE INTO users (id, username, display_name, avatar_file, avatar_url, roles) + VALUES (:id, :username, :display_name, :avatar_file, :avatar_url, :roles) + """, users) + conn.commit() + conn.close() + + def save_server_assets(self, assets: List[Dict[str, Any]]): + """Saves a batch of server assets (emojis, stickers) to the database.""" + with self._lock: + conn = self._get_conn() + formatted = [] + for a in assets: + formatted.append({ + "id": str(a["id"]), + "name": a.get("name"), + "type": a.get("type"), + "filename": a.get("filename"), + "url": a.get("url"), + "mime_type": a.get("mime_type") + }) + + conn.executemany(""" + INSERT OR REPLACE INTO server_assets (id, name, type, filename, url, mime_type) + VALUES (:id, :name, :type, :filename, :url, :mime_type) + """, formatted) + conn.commit() + conn.close() + + def save_threads(self, threads: List[Dict[str, Any]]): + """Saves metadata for threads to the database.""" + with self._lock: + conn = self._get_conn() + try: + conn.executemany(""" + INSERT OR REPLACE INTO threads (id, name, type, parent_id, message_count, member_count, archived, archive_timestamp, auto_archive_duration, locked, applied_tags) + VALUES (:id, :name, :type, :parent_id, :message_count, :member_count, :archived, :archive_timestamp, :auto_archive_duration, :locked, :applied_tags) + """, threads) + conn.commit() + finally: + conn.close() + + def save_forum_tags(self, tags: List[Dict[str, Any]]): + """Saves definitions for forum tags.""" + with self._lock: + conn = self._get_conn() + try: + conn.executemany(""" + INSERT OR REPLACE INTO forum_tags (id, forum_id, name, moderated, emoji_id, emoji_name) + VALUES (:id, :forum_id, :name, :moderated, :emoji_id, :emoji_name) + """, tags) + conn.commit() + finally: + conn.close() + + def save_messages_batch(self, messages: List[Dict[str, Any]]): + """Batch inserts messages and their attachments.""" + with self._lock: + conn = self._get_conn() + try: + # Insert messages + conn.executemany(""" + INSERT OR REPLACE INTO messages (id, channel_id, author_id, content, timestamp, type, message_reference, is_pinned, extra_data) + VALUES (:id, :channel_id, :author_id, :content, :timestamp, :type, :message_reference, :is_pinned, :extra_data) + """, messages) + + # Extract attachments, reactions, and stickers + all_attachments = [] + all_reactions = [] + all_stickers = [] + + for msg in messages: + # Attachments + if "attachments" in msg: + for att in msg["attachments"]: + att["message_id"] = msg["id"] + all_attachments.append(att) + + # Reactions + if "reactions" in msg and msg["reactions"]: + for rea in msg["reactions"]: + all_reactions.append({ + "message_id": msg["id"], + "emoji_id": str(rea["emoji_id"]) if rea.get("emoji_id") else None, + "emoji_name": rea.get("emoji_name"), + "count": rea.get("count", 0) + }) + + # Stickers + if "stickers" in msg and msg["stickers"]: + for st in msg["stickers"]: + all_stickers.append({ + "message_id": msg["id"], + "sticker_id": str(st["id"]), + "name": st.get("name"), + "url": st.get("url"), + "format_type": st.get("format_type"), + "local_hash": st.get("local_hash") + }) + + if all_attachments: + conn.executemany(""" + INSERT OR REPLACE INTO attachments (id, message_id, filename, size, url, content_type, local_hash) + VALUES (:id, :message_id, :filename, :size, :url, :content_type, :local_hash) + """, all_attachments) + + # Save Embeds (Normalized with JSON Fields) + for msg in messages: + if "embeds" in msg and msg["embeds"]: + for emb in msg["embeds"]: + # Insert top-level embed + conn.execute(""" + INSERT INTO embeds ( + message_id, title, description, url, color, timestamp, + thumbnail_url, image_url, author_name, author_url, + author_icon_url, footer_text, footer_icon_url, fields + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + msg["id"], emb.get("title"), emb.get("description"), emb.get("url"), + emb.get("color"), emb.get("timestamp"), + emb.get("thumbnail", {}).get("url") if isinstance(emb.get("thumbnail"), dict) else None, + emb.get("image", {}).get("url") if isinstance(emb.get("image"), dict) else None, + emb.get("author", {}).get("name") if isinstance(emb.get("author"), dict) else None, + emb.get("author", {}).get("url") if isinstance(emb.get("author"), dict) else None, + emb.get("author", {}).get("icon_url") if isinstance(emb.get("author"), dict) else None, + emb.get("footer", {}).get("text") if isinstance(emb.get("footer"), dict) else None, + emb.get("footer", {}).get("icon_url") if isinstance(emb.get("footer"), dict) else None, + json.dumps(emb.get("fields", [])) + )) + + if all_reactions: + conn.executemany(""" + INSERT INTO reactions (message_id, emoji_id, emoji_name, count) + VALUES (:message_id, :emoji_id, :emoji_name, :count) + """, all_reactions) + + if all_stickers: + conn.executemany(""" + INSERT OR REPLACE INTO message_stickers (message_id, sticker_id, name, url, format_type, local_hash) + VALUES (:message_id, :sticker_id, :name, :url, :format_type, :local_hash) + """, all_stickers) + + conn.commit() + finally: + conn.close() + + def get_last_message_id(self, channel_id: str) -> Optional[str]: + with self._lock: + conn = self._get_conn() + row = conn.execute("SELECT id FROM messages WHERE channel_id = ? ORDER BY id DESC LIMIT 1", (str(channel_id),)).fetchone() + conn.close() + return row["id"] if row else None + + def get_media_by_hash(self, file_hash: str) -> Optional[Dict[str, Any]]: + with self._lock: + conn = self._get_conn() + row = conn.execute("SELECT * FROM media_pool WHERE hash = ?", (file_hash,)).fetchone() + conn.close() + return dict(row) if row else None + + def get_media_by_url(self, url: str) -> Optional[Dict[str, Any]]: + with self._lock: + conn = self._get_conn() + row = conn.execute("SELECT * FROM media_pool WHERE first_seen_url = ?", (url,)).fetchone() + conn.close() + return dict(row) if row else None + + def add_media_to_pool(self, file_hash: str, local_path: str, size: int, mime_type: str, url: str): + with self._lock: + conn = self._get_conn() + conn.execute(""" + INSERT OR REPLACE INTO media_pool (hash, local_path, size, mime_type, first_seen_url) + VALUES (?, ?, ?, ?, ?) + """, (file_hash, str(local_path), size, mime_type, url)) + conn.commit() + conn.close() + + def get_stats_by_channel(self) -> Dict[int, Dict[str, Any]]: + """Returns aggregate stats for all channels with backups.""" + with self._lock: + conn = self._get_conn() + # Summary of messages per effective channel (including threads rolled up) + msg_rows = conn.execute(""" + SELECT + COALESCE(t.parent_id, m.channel_id) as channel_id, + COUNT(m.id) as msg_count + FROM messages m + LEFT JOIN threads t ON m.channel_id = t.id + GROUP BY channel_id + """).fetchall() + + # Thread counts per parent + thread_rows = conn.execute(""" + SELECT parent_id, COUNT(*) as thread_count + FROM threads + GROUP BY parent_id + """).fetchall() + + # Summary of attachments per effective channel (including threads rolled up) + att_rows = conn.execute(""" + SELECT + COALESCE(t.parent_id, m.channel_id) as channel_id, + COUNT(a.id) as att_count, + SUM(a.size) as total_size + FROM attachments a + JOIN messages m ON a.message_id = m.id + LEFT JOIN threads t ON m.channel_id = t.id + GROUP BY channel_id + """).fetchall() + + conn.close() + + stats = {} + for r in msg_rows: + cid_raw = r["channel_id"] + if cid_raw is None or cid_raw == "None": continue + cid = int(cid_raw) + stats[cid] = { + "message_count": r["msg_count"], + "thread_count": 0, + "attachment_count": 0, + "total_size": 0 + } + + for r in thread_rows: + cid_raw = r["parent_id"] + if cid_raw is None or cid_raw == "None": continue + cid = int(cid_raw) + if cid not in stats: + stats[cid] = {"message_count": 0, "thread_count": 0, "attachment_count": 0, "total_size": 0} + stats[cid]["thread_count"] = r["thread_count"] + + for r in att_rows: + cid_raw = r["channel_id"] + if cid_raw is None or cid_raw == "None": continue + cid = int(cid_raw) + if cid not in stats: + stats[cid] = {"message_count": 0, "thread_count": 0, "attachment_count": 0, "total_size": 0} + stats[cid]["attachment_count"] = r["att_count"] + stats[cid]["total_size"] = r["total_size"] or 0 + + return stats + + def get_all_roles(self) -> List[Dict[str, Any]]: + with self._lock: + conn = self._get_conn() + rows = conn.execute("SELECT * FROM roles ORDER BY position DESC").fetchall() + conn.close() + return [dict(r) for r in rows] + + def get_all_channels(self) -> List[Dict[str, Any]]: + with self._lock: + conn = self._get_conn() + rows = conn.execute("SELECT * FROM channels ORDER BY position ASC").fetchall() + + # Fetch all permissions + chan_list = [dict(r) for r in rows] + if chan_list: + ids = [c["id"] for c in chan_list] + placeholders = ",".join(["?"] * len(ids)) + perm_rows = conn.execute(f"SELECT * FROM permissions WHERE channel_id IN ({placeholders})", ids).fetchall() + + perms_by_chan = {} + for pr in perm_rows: + cid = pr["channel_id"] + if cid not in perms_by_chan: perms_by_chan[cid] = [] + perms_by_chan[cid].append({ + "id": pr["target_id"], + "type": pr["target_type"], + "allow": pr["allow"], + "deny": pr["deny"] + }) + + for c in chan_list: + c["overwrites"] = perms_by_chan.get(c["id"], []) + + # Fetch Forum Tags + tag_rows = conn.execute(f"SELECT * FROM forum_tags WHERE forum_id IN ({placeholders})", ids).fetchall() + tags_by_forum = {} + for tr in tag_rows: + fid = tr["forum_id"] + if fid not in tags_by_forum: tags_by_forum[fid] = [] + tags_by_forum[fid].append(dict(tr)) + + for c in chan_list: + c["available_tags"] = tags_by_forum.get(c["id"], []) + + conn.close() + return chan_list + + def get_all_threads(self) -> List[Dict[str, Any]]: + """Returns metadata for all threads in the backup.""" + with self._lock: + conn = self._get_conn() + rows = conn.execute("SELECT * FROM threads").fetchall() + conn.close() + return [dict(r) for r in rows] + + def get_forum_tags(self, forum_id: Optional[str] = None) -> List[Dict[str, Any]]: + """Returns forum tag definitions.""" + with self._lock: + conn = self._get_conn() + if forum_id: + rows = conn.execute("SELECT * FROM forum_tags WHERE forum_id = ?", (str(forum_id),)).fetchall() + else: + rows = conn.execute("SELECT * FROM forum_tags").fetchall() + conn.close() + return [dict(r) for r in rows] + + def get_threads_by_parent(self, parent_id: str) -> List[Dict[str, Any]]: + """Returns all threads belonging to a parent channel.""" + with self._lock: + conn = self._get_conn() + rows = conn.execute("SELECT * FROM threads WHERE parent_id = ?", (str(parent_id),)).fetchall() + conn.close() + return [dict(r) for r in rows] + + def get_thread(self, thread_id: str) -> Optional[Dict[str, Any]]: + """Retrieves a single thread's metadata.""" + with self._lock: + conn = self._get_conn() + row = conn.execute("SELECT * FROM threads WHERE id = ?", (str(thread_id),)).fetchone() + conn.close() + return dict(row) if row else None + + def get_all_users(self) -> List[Dict[str, Any]]: + with self._lock: + conn = self._get_conn() + rows = conn.execute("SELECT * FROM users").fetchall() + conn.close() + return [dict(r) for r in rows] + + def get_user(self, user_id: str) -> Optional[Dict[str, Any]]: + with self._lock: + conn = self._get_conn() + row = conn.execute("SELECT * FROM users WHERE id = ?", (str(user_id),)).fetchone() + conn.close() + if row: + data = dict(row) + if data.get("roles"): + data["roles"] = json.loads(data["roles"]) + return data + return None + + def get_server_assets(self, asset_type: Optional[str] = None) -> List[Dict[str, Any]]: + """Returns all server assets, optionally filtered by type.""" + with self._lock: + conn = self._get_conn() + if asset_type: + rows = conn.execute("SELECT * FROM server_assets WHERE type = ?", (asset_type,)).fetchall() + else: + rows = conn.execute("SELECT * FROM server_assets").fetchall() + conn.close() + return [dict(r) for r in rows] + + def get_all_media(self) -> Dict[str, Dict[str, Any]]: + """Returns the entire media pool as a dictionary indexed by hash.""" + with self._lock: + conn = self._get_conn() + rows = conn.execute("SELECT * FROM media_pool").fetchall() + conn.close() + return {r["hash"]: dict(r) for r in rows} + + def get_messages_paged(self, channel_id: str, limit: int = 100, offset: int = 0, after_id: Optional[str] = None) -> List[Dict[str, Any]]: + with self._lock: + conn = self._get_conn() + query = "SELECT * FROM messages WHERE channel_id = ?" + params = [str(channel_id)] + + if after_id: + query += " AND id > ?" + params.append(str(after_id)) + + query += " ORDER BY id ASC LIMIT ? OFFSET ?" + params.extend([limit, offset]) + + rows = conn.execute(query, params).fetchall() + msg_list = [dict(r) for r in rows] + + if msg_list: + # Fetch attachments for these messages + msg_ids = [m["id"] for m in msg_list] + placeholders = ",".join(["?"] * len(msg_ids)) + + # Attachments + att_rows = conn.execute(f"SELECT * FROM attachments WHERE message_id IN ({placeholders})", msg_ids).fetchall() + atts_by_msg = {} + for ar in att_rows: + mid = ar["message_id"] + if mid not in atts_by_msg: atts_by_msg[mid] = [] + atts_by_msg[mid].append(dict(ar)) + + # Embeds + emb_rows = conn.execute(f"SELECT * FROM embeds WHERE message_id IN ({placeholders})", msg_ids).fetchall() + embs_by_msg = {} + for er in emb_rows: + mid = er["message_id"] + if mid not in embs_by_msg: embs_by_msg[mid] = [] + + e_dict = { + "title": er["title"], + "description": er["description"], + "url": er["url"], + "color": er["color"], + "timestamp": er["timestamp"], + "thumbnail": {"url": er["thumbnail_url"]} if er["thumbnail_url"] else None, + "image": {"url": er["image_url"]} if er["image_url"] else None, + "author": { + "name": er["author_name"], + "url": er["author_url"], + "icon_url": er["author_icon_url"] + } if er["author_name"] else None, + "footer": { + "text": er["footer_text"], + "icon_url": er["footer_icon_url"] + } if er["footer_text"] else None, + "fields": json.loads(er["fields"]) if er["fields"] else [] + } + embs_by_msg[mid].append(e_dict) + + # Reactions + rea_rows = conn.execute(f"SELECT * FROM reactions WHERE message_id IN ({placeholders})", msg_ids).fetchall() + reas_by_msg = {} + for rr in rea_rows: + mid = rr["message_id"] + if mid not in reas_by_msg: reas_by_msg[mid] = [] + reas_by_msg[mid].append(dict(rr)) + + # Stickers (Message-specific) + st_rows = conn.execute(f"SELECT * FROM message_stickers WHERE message_id IN ({placeholders})", msg_ids).fetchall() + sts_by_msg = {} + for sr in st_rows: + mid = sr["message_id"] + if mid not in sts_by_msg: sts_by_msg[mid] = [] + sts_by_msg[mid].append(dict(sr)) + + for m in msg_list: + m_id = m["id"] + m["attachments"] = atts_by_msg.get(m_id, []) + m["embeds"] = embs_by_msg.get(m_id, []) + m["reactions"] = reas_by_msg.get(m_id, []) + m["stickers"] = sts_by_msg.get(m_id, []) + + conn.close() + return msg_list + + def close(self): + # We don't keep long-lived connections in this model to avoid locking issues, + # but the method is here for parity with other DB classes. + pass diff --git a/src/core/backup_reader.py b/src/core/backup_reader.py index ff889b3..e27976f 100644 --- a/src/core/backup_reader.py +++ b/src/core/backup_reader.py @@ -12,6 +12,7 @@ from datetime import datetime, timezone from enum import IntEnum from pathlib import Path from typing import AsyncGenerator, Dict, Any, List, Optional +from src.core.backup_database import BackupDatabase logger = logging.getLogger(__name__) @@ -22,17 +23,64 @@ logger = logging.getLogger(__name__) class ChannelType(IntEnum): text = 0 + private = 1 voice = 2 + group = 3 category = 4 news = 5 + news_thread = 10 public_thread = 11 + private_thread = 12 + stage_voice = 13 + directory = 14 forum = 15 + media = 16 +class StickerFormatType(IntEnum): + png = 1 + apng = 2 + lottie = 3 + gif = 4 + class MessageType(IntEnum): default = 0 + recipient_add = 1 + recipient_remove = 2 + call = 3 + channel_name_change = 4 + channel_icon_change = 5 + channel_pinned_message = 6 + user_join = 7 + guild_boost = 8 + guild_boost_tier_1 = 9 + guild_boost_tier_2 = 10 + guild_boost_tier_3 = 11 + channel_follow_add = 12 + guild_discovery_disqualified = 14 + guild_discovery_requalified = 15 + guild_discovery_grace_period_initial_warning = 16 + guild_discovery_grace_period_final_warning = 17 + thread_created = 18 reply = 19 + chat_input_command = 20 thread_starter_message = 21 + guild_invite_reminder = 22 + context_menu_command = 23 + auto_moderation_action = 24 + role_subscription_purchase = 25 + interaction_premium_upsell = 26 + stage_start = 27 + stage_end = 28 + stage_speaker = 29 + stage_topic = 31 + guild_application_premium_subscription = 32 + guild_incident_alert_mode_enabled = 36 + guild_incident_alert_mode_disabled = 37 + guild_incident_report_raid = 38 + guild_incident_report_false_alarm = 39 + purchase_notification = 44 + poll_result = 46 # --------------------------------------------------------------------------- @@ -200,14 +248,18 @@ class BackupRole: __slots__ = ("id", "name", "color", "position", "permissions", "hoist", "managed", "mentionable") def __init__(self, data: dict): + if not isinstance(data, dict): + self.id = 0 + self.name = "Unknown" + return self.id = int(data["id"]) self.name = data["name"] - self.color = BackupColor.from_hex(data.get("color", "#000000")) + self.color = BackupColor(int(data.get("color", 0))) self.position = data.get("position", 0) self.permissions = BackupPermissions(int(data.get("permissions", 0))) - self.hoist = data.get("hoist", False) - self.managed = data.get("managed", False) - self.mentionable = data.get("mentionable", True) + self.hoist = bool(data.get("hoist", False)) + self.managed = False + self.mentionable = bool(data.get("mentionable", True)) @property def mention(self) -> str: @@ -224,14 +276,21 @@ class BackupRole: return f"BackupRole(id={self.id}, name='{self.name}')" -def _parse_overwrites(raw_list: list) -> dict: +def _parse_overwrites(raw_list: list | Any) -> dict: """Parse overwrites JSON list into {BackupOverwriteTarget: BackupPermissionOverwrite} dict.""" result = {} + if not isinstance(raw_list, list): + return result for entry in raw_list: - target = BackupOverwriteTarget(int(entry["id"])) - ow = BackupPermissionOverwrite(allow=int(entry.get("allow", 0)), - deny=int(entry.get("deny", 0))) - result[target] = ow + if not isinstance(entry, dict): + continue + try: + target = BackupOverwriteTarget(int(entry["id"])) + ow = BackupPermissionOverwrite(allow=int(entry.get("allow", 0)), + deny=int(entry.get("deny", 0))) + result[target] = ow + except (KeyError, ValueError, TypeError): + continue return result @@ -248,7 +307,10 @@ class BackupCategory: self.name = data["name"] self.position = data.get("position", 0) self.type = ChannelType.category - self.overwrites = _parse_overwrites(data.get("overwrites", [])) + + # Overwrites are now passed as a list of dicts from get_all_channels + raw_ow = data.get("overwrites", []) + self.overwrites = _parse_overwrites(raw_ow) def __repr__(self) -> str: return f"BackupCategory(id={self.id}, name='{self.name}')" @@ -267,19 +329,30 @@ class BackupChannel: "forum": ChannelType.forum, "thread": ChannelType.public_thread, } - def __init__(self, data: dict, category_id: int | None = None, guild: "BackupGuild|None" = None): self.id = int(data["id"]) self.name = data["name"] - self.type = self._TYPE_MAP.get(data.get("type", "text"), ChannelType.text) + try: + self.type = ChannelType(int(data.get("type", 0))) + except ValueError: + self.type = ChannelType.text self.position = data.get("position", 0) self.topic = data.get("topic") - self.nsfw = data.get("nsfw", False) - self.category_id = category_id - self.parent_id = category_id - self.available_tags = data.get("available_tags", []) + self.nsfw = bool(data.get("nsfw", False)) + cid = data.get("category_id") + self.category_id = int(cid) if cid and cid != "None" else category_id + self.parent_id = self.category_id self.guild = guild - self.overwrites = _parse_overwrites(data.get("overwrites", [])) + + # Overwrites are now passed as a list of dicts from get_all_channels + raw_ow = data.get("overwrites", []) + self.overwrites = _parse_overwrites(raw_ow) + + # Forum tags are now structural + self.available_tags = [] + raw_tags = data.get("available_tags", []) + for t_data in raw_tags: + self.available_tags.append(BackupTag(t_data)) @property def mention(self) -> str: @@ -310,15 +383,20 @@ class BackupMember: "_avatar_url") def __init__(self, data: dict, role_objects: list | None = None, - avatar_base: Path | None = None): - self.id = int(data["userID"]) + backup_path: Path | None = None): + if not isinstance(data, dict): + # Fallback for unexpected data format + self.id = 0 + self.name = "Unknown" + return + self.id = int(data["id"]) self.name = data.get("username", "Unknown") - self.display_name = data.get("userNickname") or self.name + self.display_name = data.get("display_name") or self.name self.global_name = self.display_name - self.bot = data.get("userIsBot", False) + self.bot = False self.system = False self.discriminator = "0000" - self.color = BackupColor.from_hex(data.get("userColor") or "#000000") + self.color = BackupColor(0) self.roles = sorted(role_objects or [], key=lambda r: r.position, reverse=True) self.guild_permissions = BackupPermissions(0) @@ -327,15 +405,16 @@ class BackupMember: self.status = type("Status", (), {"value": "offline"})() self.activity = None - # CDN URL from Discord (saved during backup) - self._avatar_url = data.get("userAvatarUrl") + self._avatar_url = data.get("avatar_url") - # Local file asset (for reading bytes) - avatar_rel = data.get("userAvatar") - if avatar_rel and avatar_base: - self.avatar = BackupAsset(avatar_base / avatar_rel) + avatar_rel = data.get("avatar_file") + if avatar_rel and backup_path: + self.avatar = BackupAsset(backup_path / avatar_rel) else: self.avatar = BackupAsset(None) + + if self.avatar: + self.avatar.url = self._avatar_url @property def mention(self) -> str: @@ -361,15 +440,27 @@ class BackupMember: class BackupAttachment: """Minimal stand-in for discord.Attachment.""" - __slots__ = ("id", "filename", "size", "url", "proxy_url", "_backup_root") + __slots__ = ("id", "filename", "size", "url", "proxy_url", "_backup_root", "local_hash") - def __init__(self, data: dict, backup_root: Path | None = None): + def __init__(self, data: dict, backup_root: Path | None = None, media_pool: dict | None = None): + if not isinstance(data, dict): + self.id = 0 + self.filename = "unknown" + return self.id = int(data["id"]) - self.filename = data.get("fileName", "unknown") - self.size = data.get("fileSizeBytes", 0) + self.filename = data.get("filename", "unknown") + self.size = data.get("size", 0) self.url = data.get("url", "") self.proxy_url = self.url self._backup_root = backup_root + self.local_hash = data.get("local_hash") + + # Resolve local path via media pool if possible + if media_pool and self.local_hash in media_pool: + self.url = media_pool[self.local_hash]["local_path"] + elif self.local_hash: + # Fallback conjecture if pool didn't have it (e.g. ad-hoc load) + pass async def read(self) -> bytes: if self._backup_root: @@ -392,12 +483,17 @@ class BackupEmoji: __slots__ = ("id", "name", "animated", "url", "_file_path") def __init__(self, data: dict, media_dir: Path | None = None): + if not isinstance(data, dict): + self.id = 0 + self.name = "Unknown" + return self.id = int(data["id"]) self.name = data["name"] - self.animated = data.get("animated", False) + self.animated = data.get("mime_type") == "image/gif" filename = data.get("filename", "") self._file_path = media_dir / filename if media_dir and filename else None - self.url = str(self._file_path) if self._file_path else None + # Use the local path if available, else original URL + self.url = str(self._file_path) if self._file_path else data.get("url", "") async def read(self) -> bytes: if self._file_path and self._file_path.exists(): @@ -411,20 +507,45 @@ class BackupEmoji: class BackupSticker: """Minimal stand-in for discord.GuildSticker.""" - __slots__ = ("id", "name", "url", "format", "_backup_root") + __slots__ = ("id", "name", "url", "format", "_backup_root", "_file_path") - def __init__(self, data: dict, backup_root: Path | None = None): - self.id = int(data["messageID"]) if "messageID" in data else int(data.get("id", 0)) + def __init__(self, data: dict, backup_root: Path | None = None, media_pool: dict | None = None): + if not isinstance(data, dict): + self.id = 0 + self.name = "Sticker" + return + self.id = int(data.get("id") or data.get("sticker_id", 0)) self.name = data.get("name", "Sticker") - self.url = data.get("localPath", "") + + # Determine format + fmt = data.get("format_type", 1) + try: + self.format = StickerFormatType(fmt) + except ValueError: + self.format = StickerFormatType.png + self._backup_root = backup_root - self.format = data.get("format", "png") + + # 1. Check if it's a CAS-based sticker (from message_stickers table) + local_hash = data.get("local_hash") + if local_hash and backup_root: + ext = ".png" + if self.format == StickerFormatType.lottie: ext = ".json" + elif self.format == StickerFormatType.apng: ext = ".png" + elif self.format == StickerFormatType.gif: ext = ".gif" + + self._file_path = backup_root / "attachments" / f"{local_hash}{ext}" + # 2. Check if it's a server asset sticker (legacy or manual save) + elif data.get("filename") and backup_root: + self._file_path = backup_root / "server_assets" / data["filename"] + else: + self._file_path = None + + self.url = str(self._file_path) if self._file_path and self._file_path.exists() else data.get("url", "") async def read(self) -> bytes: - if self._backup_root and self.url: - full = self._backup_root / self.url - if full.exists(): - return full.read_bytes() + if self._file_path and self._file_path.exists(): + return self._file_path.read_bytes() return b"" def __repr__(self) -> str: @@ -469,6 +590,21 @@ class BackupReaction: return self.emoji.id is not None +class BackupTag: + """Minimal stand-in for discord.ForumTag.""" + __slots__ = ("id", "name", "moderated", "emoji") + def __init__(self, data: dict): + self.id = int(data["id"]) + self.name = data["name"] + self.moderated = bool(data.get("moderated", False)) + emoji_id = data.get("emoji_id") + emoji_name = data.get("emoji_name") + self.emoji = BackupPartialEmoji(name=emoji_name, id=int(emoji_id) if emoji_id else None) if emoji_name else None + + def __repr__(self) -> str: + return f"BackupTag(id={self.id}, name='{self.name}')" + + class BackupMessageReference: """Minimal stand-in for discord.MessageReference.""" @@ -483,16 +619,34 @@ class BackupThread: """Minimal stand-in for discord.Thread metadata attached to a starter message.""" __slots__ = ("id", "name", "message_count", "archived", - "auto_archive_duration", "locked", "parent_id") + "auto_archive_duration", "locked", "parent_id", "applied_tags", "type") def __init__(self, data: dict, parent_id: int | None = None): + if not isinstance(data, dict): + self.id = 0 + self.name = "" + return self.id = int(data["id"]) self.name = data.get("name", "") + try: + self.type = ChannelType(int(data.get("type", 11))) + except ValueError: + self.type = ChannelType.public_thread self.message_count = data.get("message_count", 0) self.archived = data.get("archived", False) self.auto_archive_duration = data.get("auto_archive_duration", 1440) self.locked = data.get("locked", False) - self.parent_id = parent_id + pid = data.get("parent_id") + self.parent_id = int(pid) if pid and pid != "None" else parent_id + + # Parse applied tags (JSON IDs) + self.applied_tags = [] + raw_tags = data.get("applied_tags") + if raw_tags: + try: + tag_ids = json.loads(raw_tags) if isinstance(raw_tags, str) else raw_tags + self.applied_tags = [int(tid) for tid in tag_ids] + except Exception: pass class BackupMessage: @@ -511,28 +665,31 @@ class BackupMessage: "mentions", "role_mentions", "channel_mentions", "mention_everyone", "tts", "nonce", "webhook_id", "application_id", "activity", "application", "interaction", "components", "jump_url") - def __init__(self, data: dict, *, author: BackupMember | None = None, guild: "BackupGuild | None" = None, - channel: BackupChannel | None = None, - backup_root: Path | None = None): - self.id = int(data["messageID"]) - self.type = self._TYPE_MAP.get(data.get("type", "Default"), MessageType.default) - self.pinned = data.get("isPinned", False) + channel: Optional[Any] = None, + backup_root: Path | None = None, + media_pool: dict | None = None): + self.id = int(data["id"]) + try: + self.type = MessageType(int(data.get("type", 0))) + except ValueError: + self.type = MessageType.default + self.pinned = bool(data.get("is_pinned", False)) self.content = data.get("content", "") self.author = author self.guild = guild self.channel = channel - self.channel_id = channel.id if channel else (int(data.get("channelID")) if data.get("channelID") else None) - - # Mentions (if not in backup, default to empty) + cid = data.get("channel_id") + self.channel_id = int(cid) if cid and cid != "None" else (channel.id if channel else None) + + # Mentions self.mentions = [] self.role_mentions = [] self.channel_mentions = [] - self.mention_everyone = data.get("mentionEveryone", False) - - # Standard discord.Message properties + self.mention_everyone = False + self.tts = False self.nonce = None self.webhook_id = None @@ -542,48 +699,102 @@ class BackupMessage: self.interaction = None self.components = [] self.jump_url = f"https://discord.com/channels/{guild.id if guild else 0}/{self.channel_id}/{self.id}" - + # Timestamp ts = data.get("timestamp") if ts: try: - self.created_at = datetime.fromisoformat(ts) + self.created_at = datetime.fromisoformat(ts).replace(tzinfo=timezone.utc) except (ValueError, TypeError): self.created_at = datetime.now(timezone.utc) else: self.created_at = datetime.now(timezone.utc) - - # Attachments - self.attachments = [ - BackupAttachment(a, backup_root=backup_root) - for a in data.get("attachments", []) - ] - - # Embeds — store raw dicts (discord.py Embed.from_dict compatible) - self.embeds = data.get("embeds", []) - + + # Attachments (parsed from DB or passed in) + self.attachments = [] + raw_atts = data.get("attachments", []) + if isinstance(raw_atts, str): + try: + raw_atts = json.loads(raw_atts) + except Exception: + raw_atts = [] + + for a in raw_atts: + if isinstance(a, dict): + self.attachments.append(BackupAttachment(a, backup_root=backup_root, media_pool=media_pool)) + + # Embeds + self.embeds = [] + for e in raw_embeds: + if isinstance(e, dict): + self.embeds.append(BackupEmbed(e)) + # Stickers - self.stickers = [ - BackupSticker(s, backup_root=backup_root) - for s in data.get("stickers", []) - ] + self.stickers = [] + raw_stickers = data.get("stickers", []) + for s in raw_stickers: + if isinstance(s, dict): + self.stickers.append(BackupSticker(s, backup_root=backup_root, media_pool=media_pool)) +class BackupEmbed: + """Minimal stand-in for discord.Embed.""" + __slots__ = ("title", "description", "url", "color", "timestamp", + "thumbnail", "image", "author", "footer", "fields") + + def __init__(self, data: dict): + self.title = data.get("title") + self.description = data.get("description") + self.url = data.get("url") + self.color = data.get("color") + self.timestamp = data.get("timestamp") + + self.thumbnail = type("Thumbnail", (), {"url": data["thumbnail"]["url"]})() if data.get("thumbnail") else None + self.image = type("Image", (), {"url": data["image"]["url"]})() if data.get("image") else None + + author = data.get("author") + self.author = type("Author", (), { + "name": author.get("name"), + "url": author.get("url"), + "icon_url": author.get("icon_url") + })() if author else None + + footer = data.get("footer") + self.footer = type("Footer", (), { + "text": footer.get("text"), + "icon_url": footer.get("icon_url") + })() if footer else None + + self.fields = [BackupEmbedField(f) for f in data.get("fields", [])] + +class BackupEmbedField: + """Minimal stand-in for embed fields.""" + __slots__ = ("name", "value", "inline") + def __init__(self, data: dict): + self.name = data.get("name") + self.value = data.get("value") + self.inline = bool(data.get("inline", False)) + + # Legacy extra_data support removed as requested + + self.stickers = [] + # Reactions - self.reactions = [BackupReaction(r) for r in data.get("reactions", [])] - - # Reference (replies) - ref = data.get("reference") - self.reference = BackupMessageReference(ref) if ref else None - - # Thread info - thread_data = data.get("thread") - self.thread = BackupThread(thread_data, parent_id=self.channel_id) if thread_data else None - - # Flags placeholder - self.flags = type("Flags", (), { - "forwarded": data.get("type") == "Forward", - "value": 0 # Bitmask value - })() + self.reactions = [] + raw_reactions = data.get("reactions", []) + if isinstance(raw_reactions, list): + self.reactions = raw_reactions + elif isinstance(raw_reactions, str): + try: + self.reactions = json.loads(raw_reactions) + except Exception: pass + + # Reference (replies/forwards) + self.reference = None + if data.get("message_reference"): + self.reference = type("Ref", (), {"message_id": int(data["message_reference"]), "channel_id": self.channel_id})() + + self.thread = None + self.flags = type("Flags", (), {"value": 0})() def __repr__(self) -> str: return f"BackupMessage(id={self.id}, author={self.author})" @@ -599,11 +810,15 @@ class BackupGuild: self.name = data["name"] self._reader = reader - icon_rel = data.get("icon") - self.icon = BackupAsset(backup_path / icon_rel) if icon_rel else BackupAsset(None) + icon_file = data.get("icon_file") + self.icon = BackupAsset(backup_path / icon_file) if icon_file else BackupAsset(None) + if self.icon: + self.icon.url = data.get("icon_url") - banner_rel = data.get("banner") - self.banner = BackupAsset(backup_path / banner_rel) if banner_rel else BackupAsset(None) + banner_file = data.get("banner_file") + self.banner = BackupAsset(backup_path / banner_file) if banner_file else BackupAsset(None) + if self.banner: + self.banner.url = data.get("banner_url") @property def roles(self) -> List[BackupRole]: @@ -695,8 +910,10 @@ class BackupReader: def __init__(self, backup_path: str | Path): self.backup_path = Path(backup_path) self.guild: BackupGuild | None = None + self.db: Optional[BackupDatabase] = None - self._thread_info: Dict[int, Dict[str, Any]] = {} # channel_id -> metadata (like name, parentID) + # Cache for performance + self._media_pool: Dict[str, Dict[str, Any]] = {} # Lazy loading flags self._roles_loaded = False @@ -707,6 +924,7 @@ class BackupReader: # Internal storage self._categories: List[BackupCategory] = [] self._channels: List[BackupChannel] = [] + self._threads: List[BackupThread] = [] self._roles: List[BackupRole] = [] self._emojis: List[BackupEmoji] = [] self._stickers: List[BackupSticker] = [] @@ -716,27 +934,26 @@ class BackupReader: # ── startup ────────────────────────────────────────────────────────── async def start(self): - """Initializes the backup path and loads the server profile.""" + """Initializes the backup path and the SQLite database.""" bp = self.backup_path - - # 1. Server profile -> BackupGuild - profile_file = bp / "server_profile" / "profile.json" - if profile_file.exists(): - profile = json.loads(profile_file.read_text(encoding="utf-8")) - self.guild = BackupGuild(profile, bp, reader=self) - logger.info(f"[Backup] Initialized server: {self.guild.name} ({self.guild.id})") + db_path = bp / "backup.db" + + if db_path.exists(): + self.db = BackupDatabase(db_path) + profile = self.db.get_guild_profile() + if profile: + self.guild = BackupGuild(profile, bp, reader=self) + logger.info(f"[Backup] Initialized database: {self.guild.name} ({self.guild.id})") else: - logger.warning(f"[Backup] server_profile/profile.json not found in {bp}") + logger.warning(f"[Backup] backup.db not found in {bp}") self.guild = None @property def roles(self) -> List[BackupRole]: - if not self._roles_loaded: - roles_file = self.backup_path / "server_profile" / "roles.json" - if roles_file.exists(): - logger.info(f"[Backup] Lazy-loading roles...") - roles_data = json.loads(roles_file.read_text(encoding="utf-8")) - self._roles = [BackupRole(r) for r in roles_data] + if not self._roles_loaded and self.db: + logger.info(f"[Backup] Loading roles from DB...") + rows = self.db.get_all_roles() + self._roles = [BackupRole(r) for r in rows] self._roles_loaded = True return self._roles @@ -745,26 +962,47 @@ class BackupReader: self._ensure_structure_loaded() return self._categories + @property + def threads(self) -> List[BackupThread]: + self._ensure_structure_loaded() + return self._threads + @property def channels(self) -> List[BackupChannel]: self._ensure_structure_loaded() return self._channels def _ensure_structure_loaded(self): - if self._structure_loaded: + if self._structure_loaded or not self.db: return - struct_file = self.backup_path / "server_profile" / "structure.json" - if struct_file.exists(): - logger.info(f"[Backup] Lazy-loading server structure...") - structure = json.loads(struct_file.read_text(encoding="utf-8")) - for cat_data in structure: - cat = BackupCategory(cat_data) - if cat.id != 0: - self._categories.append(cat) - for ch_data in cat_data.get("channels", []): - ch_cat_id = cat.id if cat.id != 0 else None - channel = BackupChannel(ch_data, category_id=ch_cat_id, guild=self.guild) - self._channels.append(channel) + + logger.info(f"[Backup] Loading server structure from DB...") + rows = self.db.get_all_channels() + + for data in rows: + if data["type"] == 4: # Category + self._categories.append(BackupCategory(data)) + else: + self._channels.append(BackupChannel(data, guild=self.guild)) + + # Load threads + thread_rows = self.db.get_all_threads() + for tdata in thread_rows: + thread = BackupThread(tdata) + self._threads.append(thread) + + # Resolve tag IDs to BackupTag objects using parent forum's available_tags + if thread.applied_tags and thread.parent_id: + parent = next((c for c in self._channels if c.id == thread.parent_id), None) + if parent and hasattr(parent, "available_tags"): + resolved = [] + for tid in thread.applied_tags: + # tid is int because of BackupThread.__init__ + tag = next((tag for tag in parent.available_tags if tag.id == tid), None) + if tag: + resolved.append(tag) + thread.applied_tags = resolved + self._structure_loaded = True @property @@ -778,15 +1016,18 @@ class BackupReader: return self._stickers def _ensure_assets_loaded(self): - if self._assets_loaded: + if self._assets_loaded or not self.db: return - assets_file = self.backup_path / "server_profile" / "assets.json" - media_dir = self.backup_path / "server_profile" / "assets" - if assets_file.exists(): - logger.info(f"[Backup] Lazy-loading assets...") - assets = json.loads(assets_file.read_text(encoding="utf-8")) - self._emojis = [BackupEmoji(e, media_dir) for e in assets.get("emojis", [])] - self._stickers = [BackupSticker(s, media_dir) for s in assets.get("stickers", [])] + + logger.info(f"[Backup] Loading assets from DB...") + media_dir = self.backup_path / "server_assets" + + emoji_rows = self.db.get_server_assets("emoji") + self._emojis = [BackupEmoji(e, media_dir) for e in emoji_rows if isinstance(e, dict)] + + sticker_rows = self.db.get_server_assets("sticker") + self._stickers = [BackupSticker(s, self.backup_path) for s in sticker_rows if isinstance(s, dict)] + self._assets_loaded = True @property @@ -795,29 +1036,36 @@ class BackupReader: return self._members def _ensure_members_loaded(self): - if self._members_loaded: + if self._members_loaded or not self.db: return - user_info_file = self.backup_path / "message_backup" / "users" / "user_info.json" - if user_info_file.exists(): - logger.info(f"[Backup] Lazy-loading members...") - try: - users = json.loads(user_info_file.read_text(encoding="utf-8")) - backup_root = self.backup_path / "message_backup" - for u in users: - user_role_ids = {int(r["id"]) for r in u.get("userRoles", [])} - # Note: this triggers roles lazy load - role_objs = [r for r in self.roles if r.id in user_role_ids] - member = BackupMember(u, role_objects=role_objs, avatar_base=backup_root) - self._members.append(member) - self._member_map[member.id] = member - except Exception as e: - logger.warning(f"[Backup] Failed to lazy-load user_info.json: {e}") + + logger.info(f"[Backup] Loading members from DB...") + rows = self.db.get_all_users() + + bp = self.backup_path + for u in rows: + if u.get("roles") and isinstance(u["roles"], str): + try: + u["roles"] = json.loads(u["roles"]) + except Exception: + u["roles"] = [] + + user_role_ids = set() + for rid in (u.get("roles") or []): + try: + user_role_ids.add(int(rid)) + except (ValueError, TypeError): + continue + role_objs = [r for r in self.roles if r.id in user_role_ids] + member = BackupMember(u, role_objects=role_objs, backup_path=bp) + self._members.append(member) + self._member_map[member.id] = member self._members_loaded = True # ── validation ─────────────────────────────────────────────────────── async def validate(self) -> Dict[str, Any]: - """Validates backup directory integrity.""" + """Validates backup database integrity.""" results = { "token": False, "server": False, @@ -828,18 +1076,18 @@ class BackupReader: } bp = self.backup_path - if not bp.exists() or not bp.is_dir(): - return results - - profile = bp / "server_profile" / "profile.json" - structure = bp / "server_profile" / "structure.json" - if profile.exists() and structure.exists(): + db_path = bp / "backup.db" + + if db_path.exists(): try: - data = json.loads(profile.read_text(encoding="utf-8")) - results["token"] = True - results["server"] = True - results["bot_name"] = "LOCAL BACKUP" - results["server_name"] = data.get("name", "Unknown") + # Use sub-connection to validate + db = BackupDatabase(db_path) + profile = db.get_guild_profile() + if profile: + results["token"] = True + results["server"] = True + results["bot_name"] = "LOCAL BACKUP" + results["server_name"] = profile.get("name", "Unknown") except Exception: pass @@ -853,8 +1101,8 @@ class BackupReader: return { "name": self.guild.name, "id": str(self.guild.id), - "icon_url": self.guild.icon.url if self.guild.icon else None, - "banner_url": self.guild.banner.url if self.guild.banner else None, + "icon_url": str(self.guild.icon.url) if self.guild.icon else None, + "banner_url": str(self.guild.banner.url) if self.guild.banner else None, } async def download_asset(self, asset: BackupAsset) -> bytes: @@ -863,135 +1111,121 @@ class BackupReader: # ── categories & channels ──────────────────────────────────────────── async def get_categories(self) -> List[BackupCategory]: - return list(self._categories) + return list(self.categories) async def get_channels(self, category_id: int | None = None) -> List[BackupChannel]: - channels = [c for c in self._channels if c.type != ChannelType.category] + channels = [c for c in self.channels if c.type != ChannelType.category] if category_id is not None: channels = [c for c in channels if c.category_id == category_id] return channels async def get_backed_up_channel_ids(self) -> List[int]: - """Returns a list of channel IDs that have corresponding backup directories.""" - backup_dir = self.backup_path / "message_backup" - if not backup_dir.exists(): - return [] - - ids = [] - for d in backup_dir.iterdir(): - if not d.is_dir(): - continue - if d.name == "users": - continue - # A backed-up channel has a messages.json inside its directory - if (d / "messages.json").exists(): - try: - ids.append(int(d.name)) - except ValueError: - pass - return ids + """Returns a list of channel IDs that have messages in the database.""" + if not self.db: return [] + import sqlite3 + conn = sqlite3.connect(self.db.db_path) + rows = conn.execute("SELECT DISTINCT channel_id FROM messages").fetchall() + conn.close() + return [int(r[0]) for r in rows] - async def get_channel(self, channel_id: int) -> BackupChannel | None: - for c in self._channels: + async def get_channel(self, channel_id: int) -> BackupChannel | BackupThread | None: + for c in self.channels: if c.id == channel_id: return c + for t in self.threads: + if t.id == channel_id: + return t return None # ── roles, emojis, stickers, members ───────────────────────────────── async def get_roles(self) -> List[BackupRole]: - return [r for r in self._roles if not r.is_default()] + return [r for r in self.roles if not r.is_default()] async def get_emojis(self) -> List[BackupEmoji]: - return list(self._emojis) + return list(self.emojis) async def get_stickers(self) -> List[BackupSticker]: - return list(self._stickers) + return list(self.stickers) async def get_members(self) -> List[BackupMember]: - return list(self._members) + return list(self.members) # ── messages ───────────────────────────────────────────────────────── - def _resolve_author(self, user_id_str: str) -> BackupMember: + def _ensure_media_pool_loaded(self): + if not self._media_pool and self.db: + self._media_pool = self.db.get_all_media() + + def _resolve_author(self, user_id: int) -> BackupMember: """Returns BackupMember for a userID, creating a stub if missing.""" - uid = int(user_id_str) - if uid in self._member_map: - return self._member_map[uid] + if user_id in self._member_map: + return self._member_map[user_id] + + # Try to fetch from DB + user_data = self.db.get_user(str(user_id)) if self.db else None + if user_data: + user_role_ids = {int(rid) for rid in (user_data.get("roles") or [])} + role_objs = [r for r in self.roles if r.id in user_role_ids] + member = BackupMember(user_data, role_objects=role_objs, backup_path=self.backup_path) + self._members.append(member) + self._member_map[user_id] = member + return member + + # Stub stub = BackupMember({ - "userID": user_id_str, - "username": f"User#{user_id_str[-4:]}", - "userIsBot": False, + "id": str(user_id), + "username": f"User#{str(user_id)[-4:]}", }) - self._member_map[uid] = stub + self._member_map[user_id] = stub return stub - def _load_channel_messages_data(self, channel_id: int) -> list[dict]: - """Loads the raw messages array from a channel JSON file.""" - bp = self.backup_path / "message_backup" + def _hydrate_message(self, msg_data: dict) -> BackupMessage: + user_id = int(msg_data.get("author_id", 0)) + author = self._resolve_author(user_id) - # Primary: message_backup/{channel_id}/messages.json - json_file = bp / str(channel_id) / "messages.json" - if not json_file.exists(): - # Fallback: search for thread_messages.json inside any parent channel - for candidate in bp.glob(f"*/{channel_id}/thread_messages.json"): - if candidate.exists(): - json_file = candidate - break - else: - return [] - - try: - data = json.loads(json_file.read_text(encoding="utf-8")) - - # Cache thread info if this is a thread - if "parentID" in data: - self._thread_info[channel_id] = { - "name": data.get("channelName", "Unknown Thread"), - "parent_id": int(data["parentID"]) - } - - return data.get("messages", []) - except Exception as e: - logger.error(f"[Backup] Failed to load messages for channel {channel_id}: {e}") - return [] - - def _hydrate_message(self, msg_data: dict, channel_id: int) -> BackupMessage: - author = self._resolve_author(msg_data.get("userID", "0")) - backup_root = self.backup_path / "message_backup" + self._ensure_media_pool_loaded() - # Resolve channel object - channel = next((c for c in self._channels if c.id == channel_id), None) - - # If not found in channels, check if it's a known thread - if not channel and channel_id in self._thread_info: - info = self._thread_info[channel_id] - # Create a stub BackupChannel for the thread - channel = BackupChannel({ - "id": str(channel_id), - "name": info["name"], - "type": "thread" - }, category_id=info["parent_id"], guild=self.guild) + channel_id = int(msg_data["channel_id"]) + channel = next((c for c in self.channels if c.id == channel_id), None) return BackupMessage( msg_data, author=author, guild=self.guild, channel=channel, - backup_root=backup_root, + backup_root=self.backup_path, + media_pool=self._media_pool ) async def get_message(self, channel_id: int, message_id: int) -> BackupMessage | None: - messages = self._load_channel_messages_data(channel_id) - for m in messages: - if int(m["messageID"]) == message_id: - return self._hydrate_message(m, channel_id) + """Fetch a specific message from SQLite.""" + if not self.db: return None + import sqlite3 + conn = sqlite3.connect(self.db.db_path) + conn.row_factory = sqlite3.Row + row = conn.execute("SELECT * FROM messages WHERE id = ?", (str(message_id),)).fetchone() + if row: + data = dict(row) + # Fetch attachments + atts = conn.execute("SELECT * FROM attachments WHERE message_id = ?", (str(message_id),)).fetchall() + data["attachments"] = [dict(a) for a in atts] + + # Fetch stickers + sts = conn.execute("SELECT * FROM message_stickers WHERE message_id = ?", (str(message_id),)).fetchall() + data["stickers"] = [dict(s) for s in sts] + + conn.close() + return self._hydrate_message(data) + conn.close() return None async def get_first_message(self, channel_id: int) -> BackupMessage | None: - messages = self._load_channel_messages_data(channel_id) - if messages: - return self._hydrate_message(messages[0], channel_id) + """Fetch the first message in a channel from SQLite.""" + if not self.db: return None + msgs = self.db.get_messages_paged(str(channel_id), limit=1) + if msgs: + return self._hydrate_message(msgs[0]) return None async def fetch_message_history( @@ -1001,23 +1235,37 @@ class BackupReader: after_id: int = None, inclusive: bool = False ) -> AsyncGenerator["BackupMessage", None]: - """Yields BackupMessages from the backup, respecting after_id and limit.""" - messages = self._load_channel_messages_data(channel_id) + """Yields BackupMessages from SQLite, respecting after_id and limit.""" + if not self.db: return + + offset = 0 + batch_size = 100 count = 0 - - for m in messages: - msg_id = int(m["messageID"]) - if after_id: - if inclusive and msg_id < after_id: - continue - if not inclusive and msg_id <= after_id: - continue - - yield self._hydrate_message(m, channel_id) - count += 1 - - if limit and count >= limit: - return + + while True: + actual_limit = batch_size + if limit: + rem = limit - count + if rem <= 0: break + actual_limit = min(batch_size, rem) + + msgs = self.db.get_messages_paged( + str(channel_id), + limit=actual_limit, + offset=offset, + after_id=str(after_id) if after_id else None + ) + + if not msgs: + break + + for m in msgs: + yield self._hydrate_message(m) + count += 1 + + offset += len(msgs) + if len(msgs) < batch_size: + break # ── download helpers ───────────────────────────────────────────────── diff --git a/src/core/discord_reader.py b/src/core/discord_reader.py index fbdd4ab..07789fa 100644 --- a/src/core/discord_reader.py +++ b/src/core/discord_reader.py @@ -279,9 +279,14 @@ class DiscordReader: try: # In a forum, the starter message ID is the thread ID starter = await thread.fetch_message(thread.id) - # Bind the thread so migrate_messages handles it properly - if not hasattr(starter, 'thread') or starter.thread is None: - starter.thread = thread + # Bind the thread if possible so downstream code has context + try: + if not hasattr(starter, 'thread') or starter.thread is None: + starter.thread = thread + except (AttributeError, TypeError): + # Some versions of discord.py don't allow setting the thread property + # or it might already be populated as a read-only property + pass yield starter except Exception as e: logger.debug(f"Could not fetch starter message for forum thread {thread.id}: {e}") diff --git a/src/core/exporter.py b/src/core/exporter.py index 2e5ac34..8d2446c 100644 --- a/src/core/exporter.py +++ b/src/core/exporter.py @@ -2,9 +2,11 @@ import os import json import logging import asyncio +import hashlib import discord from pathlib import Path from typing import Dict, Any, List, Optional, AsyncGenerator +from src.core.backup_database import BackupDatabase logger = logging.getLogger(__name__) @@ -18,6 +20,7 @@ class DiscordExporter: self.user_cache = {} self.base_dir = Path(base_dir) if base_dir else Path(".") self.is_running = True + self.db: Optional[BackupDatabase] = None async def setup(self): """Prepares the output directory and fetches server metadata.""" @@ -25,93 +28,94 @@ class DiscordExporter: self.server_name = metadata.get("name", "Unknown Server") self.server_id = metadata.get("id", "0") - # Create safe folder name - import re - safe_name = re.sub(r'[^a-zA-Z0-9_\-\.]', '_', self.server_name) + # Root export path: DISCORD_BACKUP-{id} self.export_path = self.base_dir / f"DISCORD_BACKUP-{self.server_id}" self.export_path.mkdir(parents=True, exist_ok=True) - # Server profile directory for metadata and assets - self.profile_path = self.export_path / "server_profile" - self.profile_path.mkdir(exist_ok=True) + # New directory structure + self.assets_path = self.export_path / "server_assets" + self.assets_path.mkdir(exist_ok=True) - # Consolidate media into server_profile/assets/ - self.media_path = self.profile_path / "assets" - self.media_path.mkdir(exist_ok=True) + self.users_path = self.export_path / "users" + self.users_path.mkdir(exist_ok=True) + + self.attachments_path = self.export_path / "attachments" + self.attachments_path.mkdir(exist_ok=True) + + # Initialize SQLite database + db_path = self.export_path / "backup.db" + self.db = BackupDatabase(db_path) logger.info(f"Export directory set to: {self.export_path}") - logger.info(f"Targeting server: {self.server_name} ({self.server_id})") + logger.info(f"Initialized backup database at {db_path}") return metadata - def _save_json_sync(self, file_path, data): - """Sync helper for saving JSON, meant to be run in a thread.""" - with open(file_path, "w", encoding="utf-8") as f: - json.dump(data, f, indent=4, ensure_ascii=False) - - async def _save_json(self, file_path, data): - """Async wrapper for saving JSON in a thread.""" - await asyncio.to_thread(self._save_json_sync, file_path, data) + def _calculate_sha256(self, file_path: Path) -> str: + """Calculates SHA-256 hash of a file.""" + hash_sha256 = hashlib.sha256() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_sha256.update(chunk) + return hash_sha256.hexdigest() async def export_metadata(self): - """Saves server metadata to a JSON file.""" + """Saves server metadata to the SQLite database.""" metadata = await self.reader.get_server_metadata() - # Add relative paths to local assets + # Relative paths to local assets for the UI/Reader if self.reader.guild: if self.reader.guild.icon: ext = "gif" if self.reader.guild.icon.is_animated() else "png" - metadata["icon"] = f"server_profile/assets/server_icon.{ext}" + metadata["icon_file"] = f"server_assets/server_icon.{ext}" + metadata["icon_url"] = str(self.reader.guild.icon.url) else: - metadata["icon"] = None + metadata["icon_file"] = None + metadata["icon_url"] = None if self.reader.guild.banner: ext = "gif" if self.reader.guild.banner.is_animated() else "png" - metadata["banner"] = f"server_profile/assets/server_banner.{ext}" + metadata["banner_file"] = f"server_assets/server_banner.{ext}" + metadata["banner_url"] = str(self.reader.guild.banner.url) else: - metadata["banner"] = None + metadata["banner_file"] = None + metadata["banner_url"] = None - # Add metadata fields from datetime import datetime metadata["last_backup"] = datetime.now().isoformat() - output_file = self.profile_path / "profile.json" - - # Preserve ignore_channels if the file already exists - ignore_channels = [] - if output_file.exists(): - try: - with open(output_file, "r", encoding="utf-8") as f: - old_data = json.load(f) - ignore_channels = old_data.get("ignore_channels", []) - except Exception as e: - logger.warning(f"Could not read existing profile.json to preserve ignore_channels: {e}") - - metadata["ignore_channels"] = ignore_channels - - await self._save_json(output_file, metadata) + # Fetch existing ignore_channels from DB if available + if self.db: + existing_profile = self.db.get_guild_profile() + if existing_profile and "ignore_channels" in existing_profile: + metadata["ignore_channels"] = existing_profile["ignore_channels"] + else: + metadata["ignore_channels"] = [] # Initialize if not present + + self.db.set_guild_profile(metadata) + return metadata async def export_roles(self): - """Exports all roles to server_roles.json.""" + """Exports all roles to the SQLite database.""" roles = await self.reader.get_roles() role_data = [] for r in roles: role_data.append({ "id": str(r.id), "name": r.name, - "color": str(r.color), + "color": r.color.value, "position": r.position, - "permissions": r.permissions.value, - "hoist": r.hoist, - "mentionable": r.mentionable + "permissions": str(r.permissions.value), + "hoist": 1 if r.hoist else 0, + "mentionable": 1 if r.mentionable else 0 }) - output_file = self.profile_path / "roles.json" - await self._save_json(output_file, role_data) + if self.db: + self.db.save_roles(role_data) return role_data async def download_server_assets(self): - """Downloads server icon and banner to media folder.""" + """Downloads server icon and banner to server_assets folder.""" metadata = await self.reader.get_server_metadata() # Download Server Icon if metadata.get("icon_url"): @@ -120,18 +124,12 @@ class DiscordExporter: logger.info(f"Downloading server icon: {self.reader.guild.icon.url}") data = await self.reader.download_asset(self.reader.guild.icon) ext = "gif" if self.reader.guild.icon.is_animated() else "png" - icon_path = self.media_path / f"server_icon.{ext}" + icon_path = self.assets_path / f"server_icon.{ext}" with open(icon_path, "wb") as f: f.write(data) logger.info(f"Saved server icon to {icon_path}") - else: - logger.warning("Icon URL found in metadata but guild icon asset is missing.") - except discord.Forbidden: - logger.error("403 Forbidden: Missing Access to download server icon.") except Exception as e: logger.error(f"Failed to download server icon: {e}") - else: - logger.info("No server icon found to download.") # Download Server Banner if metadata.get("banner_url"): @@ -140,19 +138,15 @@ class DiscordExporter: logger.info(f"Downloading server banner: {self.reader.guild.banner.url}") data = await self.reader.download_asset(self.reader.guild.banner) ext = "gif" if self.reader.guild.banner.is_animated() else "png" - banner_path = self.media_path / f"server_banner.{ext}" + banner_path = self.assets_path / f"server_banner.{ext}" with open(banner_path, "wb") as f: f.write(data) logger.info(f"Saved server banner to {banner_path}") - except discord.Forbidden: - logger.error("403 Forbidden: Missing Access to download server banner.") except Exception as e: logger.error(f"Failed to download server banner: {e}") - else: - logger.info("No server banner found to download.") async def export_assets(self): - """Exports emojis, stickers, and server media to assets.json and server_profile/assets/.""" + """Exports emojis and stickers to server_assets/ folder.""" await self.download_server_assets() emojis = await self.reader.get_emojis() @@ -162,20 +156,21 @@ class DiscordExporter: logger.info(f"Exporting {len(emojis)} emojis...") for e in emojis: ext = "gif" if e.animated else "png" - filename = f"emoji_{e.name}_{e.id}.{ext}" - emoji_path = self.media_path / filename + filename = f"emoji_{e.id}.{ext}" + emoji_path = self.assets_path / filename try: - data = await self.reader.download_emoji(e) - with open(emoji_path, "wb") as f: - f.write(data) + if not emoji_path.exists(): + data = await self.reader.download_emoji(e) + with open(emoji_path, "wb") as f: + f.write(data) emoji_data.append({ "id": str(e.id), "name": e.name, - "animated": e.animated, - "filename": filename + "type": "emoji", + "filename": filename, + "url": str(e.url), + "mime_type": "image/gif" if e.animated else "image/png" }) - except discord.Forbidden: - logger.error(f"403 Forbidden: Missing Access to download emoji {e.name}") except Exception as ex: logger.error(f"Failed to download emoji {e.name}: {ex}") @@ -188,537 +183,441 @@ class DiscordExporter: elif ".gif" in str(s.url): ext = "gif" elif ".webp" in str(s.url): ext = "webp" - filename = f"sticker_{s.name}_{s.id}.{ext}" - sticker_path = self.media_path / filename + filename = f"sticker_{s.id}.{ext}" + sticker_path = self.assets_path / filename try: - data = await self.reader.download_sticker(s) - with open(sticker_path, "wb") as f: - f.write(data) + if not sticker_path.exists(): + data = await self.reader.download_sticker(s) + with open(sticker_path, "wb") as f: + f.write(data) + mime_type = "image/png" + if ext == "json": mime_type = "application/json" + elif ext == "gif": mime_type = "image/gif" + elif ext == "webp": mime_type = "image/webp" + sticker_data.append({ "id": str(s.id), "name": s.name, - "filename": filename + "type": "sticker", + "filename": filename, + "url": str(s.url), + "mime_type": mime_type }) - except discord.Forbidden: - logger.error(f"403 Forbidden: Missing Access to download sticker {s.name}") except Exception as ex: logger.error(f"Failed to download sticker {s.name}: {ex}") - # Try to load existing customization to merge (if it exists) - custom_file = self.profile_path / "assets.json" - customization = {"emojis": emoji_data, "stickers": sticker_data, "members": []} - if custom_file.exists(): - try: - with open(custom_file, "r", encoding="utf-8") as f: - old_data = json.load(f) - customization["members"] = old_data.get("members", []) - except Exception: pass + # Save to database + if self.db: + all_assets = emoji_data + sticker_data + if all_assets: + self.db.save_server_assets(all_assets) - await self._save_json(custom_file, customization) - return len(emoji_data), len(sticker_data) async def export_channels_structure(self): - """Exports categories and channels hierarchy.""" + """Exports categories and channels hierarchy to SQLite.""" categories = await self.reader.get_categories() channels = await self.reader.get_channels() - structure = [] + db_channels = [] + db_permissions = [] + db_forum_tags = [] chan_count = 0 cat_count = len(categories) for cat in categories: cat_channels = [c for c in channels if c.category_id == cat.id] - formatted_channels = await asyncio.gather(*[self._format_channel(c) for c in cat_channels]) + formatted_channels, cat_chan_perms, cat_forum_tags = await self._process_channel_batch(cat_channels) chan_count += len(formatted_channels) - # Serialize role-only permission overwrites - cat_overwrites = [] + db_permissions.extend(cat_chan_perms) + db_forum_tags.extend(cat_forum_tags) + + # Category permissions for target, ow in cat.overwrites.items(): - if isinstance(target, discord.Role): + if isinstance(target, (discord.Role, discord.Member)): allow, deny = ow.pair() - cat_overwrites.append({ - "id": str(target.id), + db_permissions.append({ + "channel_id": str(cat.id), + "target_id": str(target.id), + "target_type": "role" if isinstance(target, discord.Role) else "member", "allow": allow.value, "deny": deny.value }) - structure.append({ - "type": "category", + db_channels.append({ "id": str(cat.id), "name": cat.name, + "type": int(cat.type.value) if hasattr(cat.type, "value") else 4, "position": cat.position, - "overwrites": cat_overwrites, - "channels": list(formatted_channels) + "category_id": None, + "topic": None, + "nsfw": 0 }) + + # Add child channels to list + for fc in formatted_channels: + fc["category_id"] = str(cat.id) + db_channels.append(fc) # Uncategorized uncategorized = [c for c in channels if not c.category_id] if uncategorized: - formatted_uncat = await asyncio.gather(*[self._format_channel(c) for c in uncategorized]) + formatted_uncat, uncat_perms, uncat_forum_tags = await self._process_channel_batch(uncategorized) chan_count += len(formatted_uncat) - structure.append({ - "type": "category", - "id": "uncategorized", - "name": "Uncategorized", - "channels": list(formatted_uncat) - }) - # No need to increment cat_count for 'Uncategorized' usually, - # but let's see if the user wants it. For now, cat_count is real Discord categories. + db_permissions.extend(uncat_perms) + db_forum_tags.extend(uncat_forum_tags) + for fc in formatted_uncat: + fc["category_id"] = None + db_channels.append(fc) - output_file = self.profile_path / "structure.json" - await self._save_json(output_file, structure) - return structure, cat_count, chan_count + if self.db: + self.db.save_channels(db_channels) + if db_permissions: + self.db.save_permissions(db_permissions) + if db_forum_tags: + self.db.save_forum_tags(db_forum_tags) + + return db_channels, cat_count, chan_count + + async def _process_channel_batch(self, channels): + """Processes a batch of channels, extracting metadata, permissions, and forum tags.""" + results = await asyncio.gather(*[self._format_channel(c) for c in channels]) + formatted = [] + permissions = [] + forum_tags = [] + for f_data, f_perms, f_tags in results: + formatted.append(f_data) + permissions.extend(f_perms) + if f_tags: + forum_tags.extend(f_tags) + return formatted, permissions, forum_tags async def _format_channel(self, c): - # Serialize role-only permission overwrites - ch_overwrites = [] + """Prepares channel data, its permissions, and forum tags for DB storage.""" + ch_permissions = [] for target, ow in c.overwrites.items(): - if isinstance(target, discord.Role): + if isinstance(target, (discord.Role, discord.Member)): allow, deny = ow.pair() - ch_overwrites.append({ - "id": str(target.id), + ch_permissions.append({ + "channel_id": str(c.id), + "target_id": str(target.id), + "target_type": "role" if isinstance(target, discord.Role) else "member", "allow": allow.value, "deny": deny.value }) + ch_forum_tags = [] + if isinstance(c, discord.ForumChannel): + for t in c.available_tags: + ch_forum_tags.append({ + "id": str(t.id), + "forum_id": str(c.id), + "name": t.name, + "moderated": 1 if t.moderated else 0, + "emoji_id": str(t.emoji.id) if t.emoji and hasattr(t.emoji, "id") else None, + "emoji_name": t.emoji.name if t.emoji else (str(t.emoji) if t.emoji else None) + }) + data = { "id": str(c.id), "name": c.name, - "type": str(c.type), + "type": int(c.type.value) if hasattr(c.type, "value") else 0, "position": c.position, "topic": getattr(c, "topic", None), - "nsfw": getattr(c, "nsfw", False), - "overwrites": ch_overwrites + "nsfw": 1 if getattr(c, "nsfw", False) else 0 } - if isinstance(c, discord.ForumChannel): - data["available_tags"] = [ - {"id": str(t.id), "name": t.name, "moderated": t.moderated, "emoji_id": str(t.emoji.id) if t.emoji and hasattr(t.emoji, "id") else None, "emoji_name": t.emoji.name if t.emoji else None} - for t in c.available_tags - ] - - return data + return data, ch_permissions, ch_forum_tags - async def export_channel_messages(self, channel_id: int, progress_callback=None, force=False, accumulated_count=0, after_id: int | None = None): - """Fetches and saves message history for a channel, handling incremental sync. Returns the total messages processed.""" + async def export_channel_messages(self, channel_id: int, progress_callback=None, force=False, accumulated_count=0, accumulated_threads=0, accumulated_files=0, after_id: int | None = None): + """Fetches and saves message history for a channel to SQLite, handling incremental sync.""" channel = await self.reader.get_channel(channel_id) if not channel: logger.error(f"Channel not found: {channel_id}") - return 0 + return accumulated_count, accumulated_threads, accumulated_files channel_name = channel.name - safe_name = channel_name.replace(" ", "-").lower() - - # Detection for thread grouping is_thread = isinstance(channel, discord.Thread) is_forum = isinstance(channel, discord.ForumChannel) - backup_root = self.export_path / "message_backup" - - if is_thread: - parent = await self.reader.get_channel(channel.parent_id) - # All threads nest inside their parent channel directory - backup_dir = backup_root / str(channel.parent_id) / str(channel_id) - backup_dir.mkdir(parents=True, exist_ok=True) - avatar_rel_base = "../../users/avatars" - elif is_forum: - # Forum metadata root: message_backup/{forum_id}/ - backup_dir = backup_root / str(channel_id) - backup_dir.mkdir(parents=True, exist_ok=True) - avatar_rel_base = "../users/avatars" - else: - # Regular channel: message_backup/{channel_id}/ - backup_dir = backup_root / str(channel_id) - backup_dir.mkdir(parents=True, exist_ok=True) - avatar_rel_base = "../users/avatars" - # Shared avatars directory: message_backup/users/avatars/ - users_dir = backup_root / "users" - users_dir.mkdir(exist_ok=True) - avatar_dir = users_dir / "avatars" - avatar_dir.mkdir(exist_ok=True) - - # Load existing user_info.json - user_info_file = users_dir / "user_info.json" - if not self.user_cache and user_info_file.exists(): - try: - with open(user_info_file, "r", encoding="utf-8") as f: - u_list = json.load(f) - self.user_cache = {u["id"]: u for u in u_list} - except Exception: - self.user_cache = {} + # 1. Determine incremental sync point + last_id = after_id + if not force and last_id is None and self.db: + stored_last_id = self.db.get_last_message_id(channel_id) + if stored_last_id: + last_id = int(stored_last_id) + logger.info(f"Incremental sync for {channel_name}: starting after {last_id}") - # Determine file names based on type - if is_thread: - json_file = backup_dir / "thread_messages.json" - asset_dir = backup_dir / "thread_attachments" - # asset_prefix is the relative path from message_backup/ for URL references - asset_prefix = f"{channel.parent_id}/{channel_id}/thread_attachments" - else: - json_file = backup_dir / "messages.json" - asset_dir = backup_dir / "attachments" - asset_prefix = f"{channel_id}/attachments" - - if force and asset_dir.exists(): - import shutil - try: - shutil.rmtree(asset_dir) - except Exception as e: - logger.warning(f"Failed to clear asset directory {asset_dir}: {e}") - - asset_dir.mkdir(exist_ok=True) - - messages = [] - last_id = None - - # Load existing messages for incremental sync (skip if force) - if after_id is not None: - last_id = after_id - elif not force and json_file.exists(): - try: - with open(json_file, "r", encoding="utf-8") as f: - old_data = json.load(f) - messages = old_data.get("messages", []) - if "lastMessageID" in old_data: - last_id = int(old_data["lastMessageID"]) - elif messages: - last_id = int(messages[-1]["messageID"]) - except Exception as e: - logger.warning(f"Could not load existing backup for sync in {channel_name}: {e}") - messages = [] - - count = len(messages) new_count = 0 - thread_count = 0 - thread_msg_count = 0 + BATCH_SIZE = 100 + USER_SAVE_INTERVAL = 500 # Save user cache every N new messages - BATCH_SIZE = 100 # Process messages in parallel batches - UI_LOG_INTERVAL = 10 # Only log message preview every N messages - USER_SAVE_INTERVAL = 100 # Save user_info.json every N new messages - - # 1. Fetch new messages - Handle Forbidden gracefully + # Batch accumulator for DB inserts + batch_messages = [] + batch_users = [] + try: batch_raw = [] async for msg in self.reader.fetch_message_history(channel_id, after_id=last_id): if not self.is_running: break batch_raw.append(msg) - # Process in batches for parallelism if len(batch_raw) >= BATCH_SIZE: - # Format all messages in the batch concurrently - batch_results = await asyncio.gather( - *(self._format_message(m, asset_dir, asset_prefix, avatar_dir, avatar_rel_base) for m in batch_raw) - ) - messages.extend(batch_results) - new_count += len(batch_results) - accumulated_count += len(batch_results) + results = await asyncio.gather(*(self._format_message(m) for m in batch_raw)) + for m_data, u_data in results: + batch_messages.append(m_data) + if u_data: batch_users.append(u_data) - # Throttled UI update: show preview only for the last message in the batch - if progress_callback and new_count % UI_LOG_INTERVAL < BATCH_SIZE: + new_count += len(batch_messages) + accumulated_count += len(batch_messages) + + for m in batch_messages: + if "attachments" in m: + accumulated_files += len(m["attachments"]) + + # Persist to DB + if self.db: + if batch_users: self.db.save_users(batch_users) + self.db.save_messages_batch(batch_messages) + + if progress_callback: last_msg = batch_raw[-1] - author = getattr(last_msg, "author", None) - author_name = getattr(author, "display_name", "Unknown") if author else "Unknown" - content = last_msg.content or "" - attachments_len = len(last_msg.attachments) if hasattr(last_msg, "attachments") else 0 - preview = content[:150] + ("..." if len(content) > 150 else "") - if attachments_len: - preview += f" [dim]({attachments_len} attachments)[/dim]" - if not preview: - preview = "[dim](no content)[/dim]" - await progress_callback(channel_name, accumulated_count, author_name=author_name, message_preview=preview) - elif progress_callback: - await progress_callback(channel_name, accumulated_count) - - # Periodic save of user_info.json (every ~100 messages) - if new_count % USER_SAVE_INTERVAL < BATCH_SIZE: - await self._save_json(user_info_file, list(self.user_cache.values())) + author_name = getattr(last_msg.author, "display_name", "Unknown") + preview = (last_msg.content or "")[:150] + await progress_callback(channel_name, accumulated_count, author_name=author_name, message_preview=preview, thread_count=accumulated_threads, file_count=accumulated_files) + batch_messages.clear() + batch_users.clear() batch_raw.clear() - # Process remaining messages in the last partial batch + # Final partial batch if batch_raw and self.is_running: - batch_results = await asyncio.gather( - *(self._format_message(m, asset_dir, asset_prefix, avatar_dir, avatar_rel_base) for m in batch_raw) - ) - messages.extend(batch_results) - new_count += len(batch_results) - accumulated_count += len(batch_results) + results = await asyncio.gather(*(self._format_message(m) for m in batch_raw)) + for m_data, u_data in results: + batch_messages.append(m_data) + if u_data: batch_users.append(u_data) + + new_count += len(batch_messages) + accumulated_count += len(batch_messages) + + for m in batch_messages: + if "attachments" in m: + accumulated_files += len(m["attachments"]) + + + if self.db: + if batch_users: self.db.save_users(batch_users) + self.db.save_messages_batch(batch_messages) if progress_callback: last_msg = batch_raw[-1] - author = getattr(last_msg, "author", None) - author_name = getattr(author, "display_name", "Unknown") if author else "Unknown" - content = last_msg.content or "" - preview = content[:150] + ("..." if len(content) > 150 else "") - if not preview: - preview = "[dim](no content)[/dim]" - await progress_callback(channel_name, accumulated_count, author_name=author_name, message_preview=preview) + author_name = getattr(last_msg.author, "display_name", "Unknown") + await progress_callback(channel_name, accumulated_count, author_name=author_name, thread_count=accumulated_threads, file_count=accumulated_files) - # Final user save after last batch - await self._save_json(user_info_file, list(self.user_cache.values())) + batch_messages.clear() + batch_users.clear() batch_raw.clear() except discord.Forbidden: logger.error(f"403 Forbidden: Missing Access to read messages in {channel_name} ({channel_id})") - if not messages: return accumulated_count except Exception as e: logger.error(f"Error fetching messages for {channel_name}: {e}") - if not messages: return accumulated_count - # If it's a forum or a channel with no new messages, we still want the UI to register that we've started it. - if new_count == 0 and progress_callback: - await progress_callback(channel_name, accumulated_count) + if not is_thread: + accumulated_count, accumulated_threads, accumulated_files = await self.export_threads(channel_id, progress_callback=progress_callback, force=force, accumulated_count=accumulated_count, accumulated_threads=accumulated_threads, accumulated_files=accumulated_files) - # 2. Handle Threads and collect counts accurately - all_threads = [] - try: - # Active threads: Use active_threads() coroutine for 2.6.4 - if self.reader.guild: - threads = await self.reader.guild.active_threads() - all_threads.extend([t for t in threads if t.parent_id == channel_id]) - - # Archived threads: Use the consolidated archived_threads() iterator - try: - if hasattr(channel, "archived_threads"): - async for thread in channel.archived_threads(limit=None): - all_threads.append(thread) - except discord.Forbidden: - logger.warning(f"403 Forbidden: Cannot fetch archived threads in {channel_name}") - except Exception as e: - logger.warning(f"Error fetching archived threads: {e}") - except Exception as e: - logger.warning(f"Failed to fetch threads for count in {channel_name}: {e}") + return accumulated_count, accumulated_threads, accumulated_files - thread_count = len(all_threads) - for t in all_threads: - await asyncio.sleep(0) # Yield for safety - thread_msg_count += (t.message_count or 0) - - msg_type = "Text" - if is_thread: - msg_type = "Thread" - elif channel.type == discord.ChannelType.news: - msg_type = "News" - elif is_forum: - msg_type = "Forum" - - output_data = { - "channelName": channel_name, - "channelID": str(channel_id), - "channelType": msg_type, - "messageCount": len(messages), - "threadCount": thread_count, - "lastMessageID": str(messages[-1]["messageID"]) if messages else None, - "threadMessagesCount": thread_msg_count, - "totalAttachmentSizeBytes": sum(m.get("totalFileSizeBytes", 0) for m in messages), - "numberOfAttachments": sum(m.get("numberOfFiles", 0) for m in messages), - "lastBackup": discord.utils.utcnow().isoformat(), - "messages": messages - } - - if is_thread: - output_data["parentID"] = str(channel.parent_id) - - # Merge additional metadata for forums (like tags) - if is_forum: - fmt_data = await self._format_channel(channel) - for k, v in fmt_data.items(): - if k not in output_data and k not in ["id", "name", "type", "position", "nsfw", "topic"]: - output_data[k] = v - - # Save channel messages - await asyncio.sleep(0) # Yield before writing large JSON - await self._save_json(json_file, output_data) - - # If it's a forum, also export its threads into the sub-directory - if is_forum: - accumulated_count = await self.export_threads(channel_id, progress_callback=progress_callback, force=force, accumulated_count=accumulated_count) - - return accumulated_count - - async def _format_message(self, msg, asset_dir, asset_prefix, avatar_dir, avatar_rel_base): - """Formats a single message to match the reference format.""" - attachments = [] - async def process_attachment(a): - # mimic reference asset naming (suffixing hash/id) - safe_name = a.filename - short_id = str(a.id)[-5:] - stored_name = f"{Path(safe_name).stem}-{short_id}{Path(safe_name).suffix}" - target = asset_dir / stored_name - - try: - # Check if exists, else download (basic cache) - if not target.exists(): - # Attachment.save() uses a thread internally to save to disk - await a.save(target) - - return { - "id": str(a.id), - "url": f"{asset_prefix}/{stored_name}", - "fileName": a.filename, - "fileSizeBytes": a.size - } - except Exception as e: - logger.error(f"Failed to download attachment {a.filename}: {e}") - return None - - # Download all attachments for this message concurrently - if msg.attachments: - results = await asyncio.gather(*(process_attachment(a) for a in msg.attachments)) - attachments = [r for r in results if r] - - # Author info extraction and deduplication + async def _format_message(self, msg): + """Formats a single message and its author for DB storage.""" + # 1. Author handling author = msg.author user_id = str(author.id) + user_data = None if user_id not in self.user_cache: - avatar_url = None + # New user discovered + avatar_file = None if author.avatar: try: av_name = f"{user_id}.png" - av_target = avatar_dir / av_name + av_target = self.users_path / av_name if not av_target.exists(): await author.avatar.save(av_target) - avatar_url = f"{avatar_rel_base}/{av_name}" + avatar_file = f"users/{av_name}" except Exception as e: logger.error(f"Failed to save avatar for {author.name}: {e}") roles = [] if hasattr(author, "roles"): - for r in author.roles: - if r.is_default(): continue - roles.append({ - "id": str(r.id), - "name": r.name, - "color": str(r.color), - "position": r.position - }) + roles = [str(r.id) for r in author.roles if not r.is_default()] - self.user_cache[user_id] = { - "userID": user_id, + user_data = { + "id": user_id, "username": author.name, - "userNickname": getattr(author, "display_name", author.name), - "userColor": str(author.color) if hasattr(author, "color") else None, - "userIsBot": author.bot, - "userRoles": roles, - "userAvatar": f"users/avatars/{user_id}.png" if author.avatar else None, - "userAvatarUrl": str(author.display_avatar.url) if author.avatar else None + "display_name": getattr(author, "display_name", author.name), + "avatar_file": avatar_file, + "avatar_url": str(author.display_avatar.url) if author.avatar else None, + "roles": json.dumps(roles) } + self.user_cache[user_id] = user_data - reactions = [] - for r in msg.reactions: - emoji_str = str(r.emoji) if not r.is_custom_emoji() else f"{r.emoji.name}:{r.emoji.id}" - reactions.append({ - "emoji": emoji_str, - "count": r.count - }) + # 2. Attachments handling (Content-Addressable Storage) + attachments = [] + if msg.attachments: + for att in msg.attachments: + att_data = await self._process_media( + media_id=att.id, + url=att.url, + filename=att.filename, + size=att.size, + content_type=att.content_type, + save_method=att.save + ) + if att_data: + attachments.append(att_data) - # Process Stickers (Download and Metadata) + # 2.5 Stickers handling stickers = [] - for s in msg.stickers: - sticker_filename = f"sticker_{s.id}" - # Extension mapping based on format - ext = "png" - if str(s.format).endswith("apng"): ext = "apng" - elif str(s.format).endswith("lottie"): ext = "json" - elif str(s.format).endswith("gif"): ext = "gif" - - sticker_filename += f".{ext}" - sticker_path = asset_dir / sticker_filename - - try: - if not sticker_path.exists(): - # Handle Lottie stickers manually since discord.py Refuses to save them - if str(s.format).endswith("lottie"): - # Use the name-mangled internal session from the client - session = self.reader.client.http._HTTPClient__session - async with session.get(s.url) as resp: - if resp.status == 200: - with open(sticker_path, "wb") as f: - f.write(await resp.read()) - else: - raise Exception(f"HTTP {resp.status}") - else: - await s.save(sticker_path) + if msg.stickers: + for st in msg.stickers: + # Determine extension based on format + ext = ".png" + if hasattr(st, "format"): + try: + from discord import StickerFormatType + if st.format == StickerFormatType.lottie: + ext = ".json" + elif st.format == StickerFormatType.apng: + ext = ".png" + elif st.format == StickerFormatType.gif: + ext = ".gif" + except ImportError: + pass - stickers.append({ - "id": str(s.id), - "name": s.name, - "format": str(s.format).split(".")[-1], - "localPath": f"{asset_prefix}/{sticker_filename}" - }) - except Exception as e: - logger.error(f"Failed to download sticker {s.name} ({s.id}): {e}") - # Fallback to minimal metadata if download fails - stickers.append({ - "id": str(s.id), - "name": s.name, - "format": str(s.format).split(".")[-1] + st_data = await self._process_media( + media_id=st.id, + url=st.url, + filename=f"{st.name}{ext}", + content_type=f"image/{ext[1:]}" if ext != ".json" else "application/json", + save_method=st.save + ) + if st_data: + st_data["format_type"] = int(st.format.value) if hasattr(st, "format") and hasattr(st.format, "value") else 1 + stickers.append(st_data) + + # 3. Embeds + embeds = [] + if msg.embeds: + embeds = [emb.to_dict() for emb in msg.embeds] + + # 4. Reactions + reactions = [] + if msg.reactions: + for react in msg.reactions: + emoji = react.emoji + reactions.append({ + "emoji_id": emoji.id if hasattr(emoji, "id") else None, + "emoji_name": emoji.name if hasattr(emoji, "name") else str(emoji), + "count": react.count }) - # Determine message type (Override if it's a thread starter or forward) - raw_repr = str(msg.type).lower() - - if "thread_starter" in raw_repr or msg.thread: - msg_type = "ThreadStarter" - else: - msg_type = raw_repr.split(".")[-1].capitalize() - - # Check for forwarded flags (newer discord.py feature) - try: - if hasattr(msg.flags, "forwarded") and msg.flags.forwarded: - msg_type = "Forward" - except Exception: - pass + # 5. Message data + # Check for reference (reply) + message_reference = None + if msg.reference and msg.reference.message_id: + message_reference = str(msg.reference.message_id) - msg_content = msg.content - if msg_type == "Forward" and not msg_content: - try: - if hasattr(msg, "message_snapshots") and msg.message_snapshots: - msg_content = msg.message_snapshots[0].content - except Exception: - pass - - data = { - "messageID": str(msg.id), - "type": msg_type, + m_data = { + "id": str(msg.id), + "channel_id": str(msg.channel.id), + "author_id": user_id, + "content": msg.content, "timestamp": msg.created_at.isoformat(), - "isPinned": msg.pinned, - "content": msg_content, - "userID": user_id, + "type": int(msg.type.value) if hasattr(msg.type, "value") else 0, + "message_reference": message_reference, + "is_pinned": 1 if msg.pinned else 0, "attachments": attachments, - "numberOfFiles": len(attachments), - "totalFileSizeBytes": sum(a["fileSizeBytes"] for a in attachments), - "embeds": [e.to_dict() for e in msg.embeds], "stickers": stickers, - "reactions": reactions + "embeds": embeds, + "reactions": reactions, + "extra_data": None } - # Thread info for creation/starter messages - if msg.thread: - data["thread"] = { - "id": str(msg.thread.id), - "name": msg.thread.name, - "messageCount": getattr(msg.thread, "message_count", 0), - "archived": msg.thread.archived, - "archiveDuration": msg.thread.auto_archive_duration, - "locked": msg.thread.locked - } + return m_data, user_data - # Add reply reference if exists - if msg.reference and msg.reference.message_id: - data["reference"] = { - "messageId": str(msg.reference.message_id), - "channelId": str(msg.reference.channel_id) - } + async def _process_media(self, media_id, url, filename, size=None, content_type=None, save_method=None): + """Downloads and deduplicates any media (attachment or sticker) using SHA-256 (CAS).""" + # 1. First check by URL in DB + if self.db: + existing = self.db.get_media_by_url(str(url)) + if existing: + return { + "id": str(media_id), + "filename": filename, + "size": existing["size"], + "url": str(url), + "content_type": existing["content_type"], + "local_hash": existing["hash"] + } - return data + # 2. Temporary download to calculate hash + import tempfile + import shutil + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp_path = Path(tmp.name) + try: + if save_method: + await save_method(tmp_path) + else: + return None + + file_hash = self._calculate_sha256(tmp_path) + actual_size = tmp_path.stat().st_size + + # Check if hash already exists in pool + if self.db: + in_pool = self.db.get_media_by_hash(file_hash) + if in_pool: + tmp_path.unlink() + return { + "id": str(media_id), + "filename": filename, + "size": actual_size, + "url": str(url), + "content_type": content_type or in_pool["content_type"], + "local_hash": file_hash + } + + # New content: move to pool + ext = Path(filename).suffix + target_filename = f"{file_hash}{ext}" + target_path = self.attachments_path / target_filename + + shutil.move(str(tmp_path), str(target_path)) + + if self.db: + self.db.add_media_to_pool(file_hash, f"attachments/{target_filename}", actual_size, content_type, str(url)) + + return { + "id": str(media_id), + "filename": filename, + "size": actual_size, + "url": str(url), + "content_type": content_type, + "local_hash": file_hash + } + except Exception as e: + logger.error(f"Failed to process media {filename}: {e}") + if tmp_path.exists(): tmp_path.unlink() + return None - async def export_threads(self, channel_id: int, progress_callback=None, force=False, accumulated_count=0, after_id: int | None = None): - """Exports active and archived threads for a channel. Returns accumulated message count.""" + async def export_threads(self, channel_id: int, progress_callback=None, force=False, accumulated_count=0, accumulated_threads=0, accumulated_files=0, after_id: int | None = None): + """Exports active and archived threads for a channel to SQLite.""" channel = await self.reader.get_channel(channel_id) - if not hasattr(channel, "threads") and not hasattr(channel, "public_archived_threads"): - return 0 + if not hasattr(channel, "threads") and not hasattr(channel, "archived_threads"): + return accumulated_count, accumulated_threads, accumulated_files all_threads = [] try: @@ -740,108 +639,71 @@ class DiscordExporter: logger.error(f"Failed to fetch threads for {channel.name}: {e}") is_forum = isinstance(channel, discord.ForumChannel) - backup_root = self.export_path / "message_backup" - forum_json_file = backup_root / str(channel_id) / "messages.json" - forum_asset_dir = backup_root / str(channel_id) - avatar_dir = backup_root / "users" / "avatars" - thread_count = 0 - if all_threads: - logger.info(f"Found {len(all_threads)} threads in {channel.name}. Starting backup...") - - for thread in all_threads: - if not self.is_running: - logger.info("Thread backup cancelled by user.") - break - await asyncio.sleep(0) # important yield between threads - - # First backup the full thread — this creates {thread_id}.json with totalAttachmentSizeBytes - accumulated_count = await self.export_channel_messages(thread.id, progress_callback=progress_callback, force=force, accumulated_count=accumulated_count, after_id=after_id) - thread_count += 1 - - # Then populate the forum root JSON with the starter message - if is_forum: - logger.info(f"Adding starter message for thread: {thread.name} ({thread.id})") + if all_threads and self.db: + thread_meta = [] + for t in all_threads: + applied_tags = [] + + # If parent is missing, try to link it from our fetched channel object + # This ensures discord.py can resolve the applied_tags correctly try: - msg_found = False - # In discord.py 2.x, we get the oldest message by using 'after' with a limit - async for msg in thread.history(limit=1, after=discord.Object(id=thread.id - 1)): - msg_found = True - logger.debug(f"Found starter message {msg.id} for {thread.name}") - - # Save assets in the thread's own directory inside the forum directory - thread_asset_dir = forum_asset_dir / str(thread.id) / "thread_attachments" - thread_asset_dir.mkdir(parents=True, exist_ok=True) - - msg_data = await self._format_message( - msg, - thread_asset_dir, - f"{channel_id}/{thread.id}/thread_attachments", # Full relative path from message_backup/ - avatar_dir, - "../../users/avatars" # Two levels up from {forum_id}/{thread_id}/ - ) - # Override type and add title for forum starter messages - msg_data["type"] = "ThreadStarter" - msg_data["title"] = thread.name - - # Store applied tag IDs (as strings) — names are resolvable via the forum's available_tags - msg_data["tags"] = [str(tid) for tid in getattr(thread, "_applied_tags", [])] - - # Enrich totalFileSizeBytes with the child thread's totalAttachmentSizeBytes - # (the thread JSON has already been written above) - thread_json = backup_root / str(channel_id) / str(thread.id) / "thread_messages.json" - if thread_json.exists(): - try: - with open(thread_json, "r", encoding="utf-8") as f: - thread_data = json.load(f) - child_size = thread_data.get("totalAttachmentSizeBytes", 0) - msg_data["totalFileSizeBytes"] = msg_data.get("totalFileSizeBytes", 0) + child_size - - child_count = thread_data.get("numberOfAttachments", 0) - msg_data["numberOfFiles"] = msg_data.get("numberOfFiles", 0) + child_count - - logger.debug(f"Enriched files for {thread.name}: +{child_size} bytes, +{child_count} files from child thread") - except Exception as e: - logger.error(f"Failed to read thread JSON for size enrichment: {e}") - - if forum_json_file.exists(): - with open(forum_json_file, "r", encoding="utf-8") as f: - try: - forum_data = json.load(f) - except Exception as e: - logger.error(f"Failed to load forum JSON: {e}") - forum_data = {} - - if "messages" not in forum_data: - forum_data["messages"] = [] - - # Avoid duplicates — update if already exists (e.g. sync run) - existing = next((m for m in forum_data["messages"] if m["messageID"] == msg_data["messageID"]), None) - if existing: - existing.update(msg_data) - logger.debug(f"Updated starter message for {thread.name} in forum JSON") - else: - forum_data["messages"].append(msg_data) - - forum_data["messageCount"] = len(forum_data["messages"]) - # Recalculate forum totalAttachmentSizeBytes from enriched starter messages - forum_data["totalAttachmentSizeBytes"] = sum( - m.get("totalFileSizeBytes", 0) for m in forum_data["messages"] - ) - # Recalculate forum numberOfAttachments from enriched starter messages - forum_data["numberOfAttachments"] = sum( - m.get("numberOfFiles", 0) for m in forum_data["messages"] - ) - forum_data["messages"].sort(key=lambda x: x["timestamp"]) - - await asyncio.sleep(0) # Yield before writing - await self._save_json(forum_json_file, forum_data) - logger.info(f"Appended starter message for {thread.name} to {forum_json_file.name}") - else: - logger.warning(f"Forum JSON file does not exist: {forum_json_file}") + if t.parent is None: + # Internal hack: link parent if missing to help resolve tags + t._parent = channel + except Exception: + pass + + if hasattr(t, "applied_tags"): + # Attempt 1: Standard attribute + applied_tags = [str(tag.id) for tag in t.applied_tags] - if not msg_found: - logger.warning(f"No starter message found for thread: {thread.name}") + # Attempt 2: If still empty and it's a forum thread, it might not be loaded + if not applied_tags and is_forum: + try: + # We can try to fetch the thread specifically to get tags + # But we only do this if we really have to + # (Discord sometimes doesn't include tags in bulk guild.active_threads) + fetched_t = await self.reader.client.fetch_channel(t.id) + if hasattr(fetched_t, "applied_tags"): + applied_tags = [str(tag.id) for tag in fetched_t.applied_tags] + except Exception: + pass + + thread_meta.append({ + "id": str(t.id), + "name": t.name, + "type": int(t.type.value) if hasattr(t.type, "value") else 11, # Default to public_thread + "parent_id": str(t.parent_id) if t.parent_id else str(channel.id), + "message_count": getattr(t, "message_count", 0), + "member_count": getattr(t, "member_count", 0), + "archived": 1 if t.archived else 0, + "archive_timestamp": t.archive_timestamp.isoformat() if t.archive_timestamp else None, + "auto_archive_duration": t.auto_archive_duration, + "locked": 1 if getattr(t, "locked", False) else 0, + "applied_tags": json.dumps(applied_tags) + }) + self.db.save_threads(thread_meta) + + for thread in all_threads: + if not self.is_running: break + await asyncio.sleep(0) + + accumulated_threads += 1 + if progress_callback: + await progress_callback(channel.name, accumulated_count, thread_count=accumulated_threads, file_count=accumulated_files) + + # Backup thread messages + accumulated_count, accumulated_threads, accumulated_files = await self.export_channel_messages(thread.id, progress_callback=progress_callback, force=force, accumulated_count=accumulated_count, accumulated_threads=accumulated_threads, accumulated_files=accumulated_files, after_id=after_id) + + # For forums, ensure the starter message exists in the DB + if is_forum: + # starter_message is handled by export_channel_messages since it's just a message in that thread + # However we may want to mark it or store forum-specific tags + try: + # Just yield for concurrency + await asyncio.sleep(0) except Exception as e: - logger.error(f"Error adding starter message for {thread.name}: {e}") - return accumulated_count + logger.error(f"Error processing forum thread {thread.name}: {e}") + + return accumulated_count, accumulated_threads, accumulated_files diff --git a/src/ui/backup_stats.py b/src/ui/backup_stats.py index 69624b4..6ef9026 100644 --- a/src/ui/backup_stats.py +++ b/src/ui/backup_stats.py @@ -1,4 +1,5 @@ import json +import logging from pathlib import Path from datetime import datetime import asyncio @@ -7,11 +8,15 @@ from textual.app import ComposeResult from textual.screen import Screen from textual.containers import Container, Vertical, VerticalScroll, Horizontal from textual.widgets import Button, Label, Rule, Tree -from src.ui.widgets import RamDisplay from textual import work from rich.text import Text from rich.style import Style +from src.ui.widgets import RamDisplay +from src.core.backup_reader import BackupReader, ChannelType + +logger = logging.getLogger(__name__) + class BackupStatsScreen(Screen[None]): """Full-screen view for displaying detailed backup statistics.""" @@ -254,178 +259,133 @@ class BackupStatsScreen(Screen[None]): @work(exclusive=True) async def load_data(self) -> None: try: - # Locate the correct backup directory passed from backup_ops - if not self.target_dir.exists(): - raise FileNotFoundError(f"Config directory {self.target_dir} not found.") - - target_dir = self.target_dir - if not (target_dir / "server_profile" / "profile.json").exists(): + # Initialize BackupReader + reader = BackupReader(self.target_dir) + await reader.start() + + if not reader.guild: raise FileNotFoundError(f"No valid backup found in {self.target_dir}") - - self.profile_path = target_dir / "server_profile" - self.backup_path = target_dir / "message_backup" - + # 1. Profile / Server Info - server_name = "Unknown Server" - server_id = "0" + guild = reader.guild + server_name = guild.name + server_id = str(guild.id) + + # Get last backup info from guild profile + profile = reader.db.get_guild_profile() last_backup_str = "Never" - profile_file = self.profile_path / "profile.json" - if profile_file.exists(): - with open(profile_file, "r", encoding="utf-8") as f: - data = json.load(f) - server_name = data.get("name", server_name) - server_id = data.get("id", server_id) - ts = data.get("last_backup") - if ts: - try: - dt = datetime.fromisoformat(ts) - last_backup_str = dt.strftime("%d %b, %Y - %H:%M") - except Exception: - last_backup_str = ts + ts = profile.get("last_backup") + if ts: + try: + dt = datetime.fromisoformat(ts) + last_backup_str = dt.strftime("%d %b, %Y - %H:%M") + except Exception: + last_backup_str = ts self.query_one("#bs_name", Label).update(f"{server_name}") self.query_one("#bs_id", Label).update(f"{server_id}") self.query_one("#bs_last_backup", Label).update(f"{last_backup_str}") - # 2. Assets / Entities - member_count = 0 - emoji_count = 0 - sticker_count = 0 - - # Fetch members from users/user_info.json - if self.backup_path: - user_info_file = self.backup_path / "users" / "user_info.json" - if user_info_file.exists(): - try: - with open(user_info_file, "r", encoding="utf-8") as f: - data = json.load(f) - if isinstance(data, list): - member_count = len(data) - elif isinstance(data, dict): - member_count = len(data) - except Exception: - pass - - assets_file = self.profile_path / "assets.json" - if assets_file.exists(): - with open(assets_file, "r", encoding="utf-8") as f: - data = json.load(f) - emoji_count = len(data.get("emojis", [])) - sticker_count = len(data.get("stickers", [])) - - role_count = 0 - roles_file = self.profile_path / "roles.json" - if roles_file.exists(): - with open(roles_file, "r", encoding="utf-8") as f: - role_count = len(json.load(f)) + # 2. Assets / Entities (using properties) + member_count = len(reader.members) + emoji_count = len(reader.emojis) + sticker_count = len(reader.stickers) + role_count = len(reader.roles) self.query_one("#bs_val_members", Label).update(f"{member_count}") self.query_one("#bs_val_roles", Label).update(f"{role_count}") self.query_one("#bs_val_emojis", Label).update(f"{emoji_count}") self.query_one("#bs_val_stickers", Label).update(f"{sticker_count}") - # 3. Structure & Per-Channel Stats - structure_file = self.profile_path / "structure.json" - total_channels = 0 - backed_up_channels = 0 + # 3. Aggregate Stats from DB + channel_stats = reader.db.get_stats_by_channel() - total_msgs = 0 - total_threads = 0 - total_files = 0 - total_size = 0 + total_msgs = sum(s["message_count"] for s in channel_stats.values()) + total_threads = sum(s["thread_count"] for s in channel_stats.values()) + total_files = sum(s["attachment_count"] for s in channel_stats.values()) + total_size = sum(s["total_size"] for s in channel_stats.values()) - cat_nodes = [] # Collect category data + backed_up_channel_ids = set(channel_stats.keys()) + + # 4. Structure & Per-Channel Stats + categories = reader.categories + all_channels = reader.channels + + total_channels = len([c for c in all_channels if c.type != ChannelType.category]) + backed_up_channels = len([cid for cid in backed_up_channel_ids if any(c.id == cid for c in all_channels)]) + + # Helper to map channels to categories + cat_map = {cat.id: {"cat": cat, "chans": []} for cat in categories} + cat_map[None] = {"cat": None, "chans": []} # Uncategorized + + for chan in all_channels: + if chan.type == ChannelType.category: continue + cid = chan.category_id + if cid not in cat_map: cid = None + cat_map[cid]["chans"].append(chan) - if structure_file.exists(): - with open(structure_file, "r", encoding="utf-8") as f: - structure = json.load(f) - - for cat in structure: - cat_name = cat.get("name", "Unknown Category").upper() - c_msgs = 0 - c_thds = 0 - c_files = 0 - c_size = 0 - - chan_list = [] - - for chan in cat.get("channels", []): - total_channels += 1 - ch_id = chan.get("id") - ch_name = f"# {chan.get('name', 'unknown')}" - - m_count = 0 - t_count = 0 - f_count = 0 - s_bytes = 0 - is_backed_up = False - - msg_file = self.backup_path / str(ch_id) / "messages.json" - if msg_file.exists(): - is_backed_up = True - backed_up_channels += 1 - try: - with open(msg_file, "r", encoding="utf-8") as mf: - mdata = json.load(mf) - m_count = mdata.get("messageCount", 0) - t_count = mdata.get("threadCount", 0) - f_count = mdata.get("numberOfAttachments", 0) - s_bytes = mdata.get("totalAttachmentSizeBytes", 0) - except Exception: - pass - - chan_list.append({ - "name": ch_name, - "msgs": m_count if is_backed_up else "NA", - "threads": t_count if is_backed_up else "NA", - "files": f_count if is_backed_up else "NA", - "size": s_bytes, - "is_backed_up": is_backed_up - }) - c_msgs += m_count - c_thds += t_count - c_files += f_count - c_size += s_bytes - - total_msgs += c_msgs - total_threads += c_thds - total_files += c_files - total_size += c_size - - cat_nodes.append({ - "name": cat_name, - "channels": chan_list, - "msgs": c_msgs, - "threads": c_thds, - "files": c_files, - "size": c_size - }) - - # 4. Global Stats + # Global update self.query_one("#bs_val_msgs", Label).update(f"{total_msgs}") self.query_one("#bs_val_threads", Label).update(f"{total_threads}") self.query_one("#bs_val_files", Label).update(f"{total_files}") self.query_one("#bs_val_size", Label).update(f"{self._format_size(total_size)}") - self.query_one("#bs_val_coverage", Label).update(f"{backed_up_channels} / {total_channels}") # 5. Build Tree - for cat in cat_nodes: - cat_lbl = self._format_tree_row(f"{cat['name']}", cat['msgs'], cat['threads'], cat['files'], self._format_size(cat['size'])) + for cat_id, info in cat_map.items(): + cat = info["cat"] + chans = info["chans"] + if not chans: continue + + cat_name = cat.name.upper() if cat else "UNCATEGORIZED" + + # Aggregate for category + c_msgs = 0 + c_thds = 0 + c_files = 0 + c_size = 0 + + chan_nodes_data = [] + for ch in chans: + stats = channel_stats.get(ch.id, {"message_count": 0, "thread_count": 0, "attachment_count": 0, "total_size": 0}) + is_bu = ch.id in backed_up_channel_ids + + chan_nodes_data.append({ + "name": f"# {ch.name}", + "msgs": stats["message_count"] if is_bu else "NA", + "threads": stats["thread_count"] if is_bu else "NA", + "files": stats["attachment_count"] if is_bu else "NA", + "size": stats["total_size"] if is_bu else 0, + "is_backed_up": is_bu + }) + + c_msgs += stats["message_count"] + c_thds += stats["thread_count"] + c_files += stats["attachment_count"] + c_size += stats["total_size"] + + cat_lbl = self._format_tree_row(cat_name, c_msgs, c_thds, c_files, self._format_size(c_size)) cat_lbl.stylize("bold yellow") node = self.stats_tree.root.add(cat_lbl, expand=True) - for ch in cat["channels"]: - size_str = self._format_size(ch['size']) if ch['is_backed_up'] else "NA" - ch_lbl = self._format_tree_row(f" {ch['name']}", ch['msgs'], ch['threads'], ch['files'], size_str) + for ch_data in chan_nodes_data: + size_str = self._format_size(ch_data['size']) if ch_data['is_backed_up'] else "NA" + ch_lbl = self._format_tree_row(f" {ch_data['name']}", ch_data['msgs'], ch_data['threads'], ch_data['files'], size_str) - if ch['is_backed_up']: + if ch_data['is_backed_up']: ch_lbl.stylize("bold white") else: - ch_lbl.stylize("dim white") # Textual 'dim' looks like a dull grey + ch_lbl.stylize("dim white") node.add_leaf(ch_lbl) except Exception as e: + import traceback + logger.exception("Failed to load backup stats") + trace_str = traceback.format_exc() + logger.error(f"Full Traceback: {trace_str}") self.query_one("#bs_name", Label).update(f"[red]Error loading data[/red]") self.query_one("#bs_id", Label).update(f"[red]{e}[/red]") + finally: + if 'reader' in locals(): + await reader.close() diff --git a/src/ui/shuttle_ops.py b/src/ui/shuttle_ops.py index 62189de..db29e4a 100644 --- a/src/ui/shuttle_ops.py +++ b/src/ui/shuttle_ops.py @@ -183,7 +183,6 @@ class OperationPane(Container): self.exporter = DiscordExporter(self.engine.discord_reader, base_dir=self._base_dir()) def _get_backup_info(self) -> str | None: - import json if not self.config or not self.config.discord_server_id: return None @@ -191,14 +190,16 @@ class OperationPane(Container): if not target_dir.exists(): return None - profile_file = target_dir / "server_profile" / "profile.json" - if not profile_file.exists(): + db_file = target_dir / "backup.db" + if not db_file.exists(): return None try: from datetime import datetime - with open(profile_file, "r", encoding="utf-8") as f: - data = json.load(f) - ts_str = data.get("last_backup") + from src.core.backup_database import BackupDatabase + db = BackupDatabase(db_file) + profile = db.get_guild_profile() + if profile: + ts_str = profile.get("last_backup") if ts_str: dt = datetime.fromisoformat(ts_str) return dt.strftime("%d-%b-%Y %H:%M") @@ -1767,10 +1768,12 @@ class OperationPane(Container): any_found = False backed_up_ids = set() - for chan in eligible_channels: - if (self.exporter.export_path / "message_backup" / str(chan.id) / "messages.json").exists(): - any_found = True - backed_up_ids.add(chan.id) + if self.exporter.db: + channel_stats = self.exporter.db.get_stats_by_channel() + for chan in eligible_channels: + if chan.id in channel_stats: + any_found = True + backed_up_ids.add(chan.id) self.app.pop_screen() @@ -1867,6 +1870,8 @@ class OperationPane(Container): modal_prog.write(f"[yellow]Starting backup for {total_chans} channels...[/yellow]") accumulated_msgs = 0 + accumulated_threads = 0 + accumulated_files = 0 for i, chan in enumerate(selected_channels): if not self.exporter.is_running: @@ -1874,7 +1879,7 @@ class OperationPane(Container): break await asyncio.sleep(0.01) # Yield to UI thread to keep it responsive - backup_exists = (self.exporter.export_path / "message_backup" / str(chan.id) / "messages.json").exists() + backup_exists = chan.id in backed_up_ids is_sync = backup_exists and not force_overwrite label = "Syncing Backup" if is_sync else "Backing up" @@ -1884,20 +1889,17 @@ class OperationPane(Container): logger.info(f"{label} for channel: #{chan.name} ({chan.id})") _msg_log_counter = 0 - async def update_msg_count(name, count, author_name=None, message_preview=None): + async def update_msg_count(name, count, author_name=None, message_preview=None, thread_count=0, file_count=0): nonlocal _msg_log_counter - modal_prog.update_stats(messages=str(count)) + modal_prog.update_stats(messages=str(count), threads=str(thread_count), files=str(file_count)) _msg_log_counter += 1 if author_name and message_preview and _msg_log_counter % 10 == 0: modal_prog.write(f"[bold]{author_name}:[/bold] {message_preview}") - accumulated_msgs = await self.exporter.export_channel_messages( + accumulated_msgs, accumulated_threads, accumulated_files = await self.exporter.export_channel_messages( chan.id, progress_callback=update_msg_count, force=force_overwrite, - accumulated_count=accumulated_msgs, after_id=after_id - ) - accumulated_msgs = await self.exporter.export_threads( - chan.id, progress_callback=update_msg_count, force=force_overwrite, - accumulated_count=accumulated_msgs + accumulated_count=accumulated_msgs, accumulated_threads=accumulated_threads, accumulated_files=accumulated_files, + after_id=after_id ) modal_prog.write(f"[green]Completed: {chan.name}[/green]") @@ -1981,9 +1983,14 @@ class OperationPane(Container): ] ] + # Get channels that have messages in the database + backed_up_channel_ids = set() + if self.exporter.db: + backed_up_channel_ids = set(self.exporter.db.get_stats_by_channel().keys()) + selected_channels = [ c for c in eligible_channels - if (self.exporter.export_path / "message_backup" / str(c.id) / "messages.json").exists() + if c.id in backed_up_channel_ids ] if not selected_channels: @@ -1999,6 +2006,8 @@ class OperationPane(Container): modal_prog.cancel_callback = lambda: setattr(self.exporter, "is_running", False) accumulated_msgs = 0 + accumulated_threads = 0 + accumulated_files = 0 for i, chan in enumerate(selected_channels): if not self.exporter.is_running: @@ -2012,20 +2021,16 @@ class OperationPane(Container): logger.info(f"Syncing backup for channel: #{chan.name} ({chan.id})") _msg_log_counter = 0 - async def update_msg_count(name, count, author_name=None, message_preview=None): + async def update_msg_count(name, count, author_name=None, message_preview=None, thread_count=0, file_count=0): nonlocal _msg_log_counter - modal_prog.update_stats(messages=str(count)) + modal_prog.update_stats(messages=str(count), threads=str(thread_count), files=str(file_count)) _msg_log_counter += 1 if author_name and message_preview and _msg_log_counter % 10 == 0: modal_prog.write(f"[bold]{author_name}:[/bold] {message_preview}") - accumulated_msgs = await self.exporter.export_channel_messages( + accumulated_msgs, accumulated_threads, accumulated_files = await self.exporter.export_channel_messages( chan.id, progress_callback=update_msg_count, force=False, - accumulated_count=accumulated_msgs - ) - accumulated_msgs = await self.exporter.export_threads( - chan.id, progress_callback=update_msg_count, force=False, - accumulated_count=accumulated_msgs + accumulated_count=accumulated_msgs, accumulated_threads=accumulated_threads, accumulated_files=accumulated_files ) modal_prog.write(f"[green]Synced: {chan.name}[/green]")