From cefa477459af40c588363b3b37f6abc9fb20fc2e Mon Sep 17 00:00:00 2001 From: rambros Date: Wed, 25 Mar 2026 13:45:17 +0530 Subject: [PATCH] optimize backup speed --- src/core/backup_database.py | 383 +++++++++++++++--------------------- src/core/exporter.py | 206 +++++++++++-------- 2 files changed, 287 insertions(+), 302 deletions(-) diff --git a/src/core/backup_database.py b/src/core/backup_database.py index c96f9ad..a6f7adf 100644 --- a/src/core/backup_database.py +++ b/src/core/backup_database.py @@ -25,17 +25,18 @@ class BackupDatabase: def __init__(self, db_path: Path | str): self.db_path = Path(db_path) self._lock = threading.Lock() + self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + self._conn.row_factory = sqlite3.Row + # WAL mode allows concurrent readers and batches disk flushes significantly + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA synchronous=NORMAL") + self._conn.execute("PRAGMA cache_size=-32000") # 32 MB page cache self._init_db() - def _get_conn(self): - conn = sqlite3.connect(self.db_path, check_same_thread=False) - conn.row_factory = sqlite3.Row - return conn - def _init_db(self): """Initializes the database schema.""" with self._lock: - conn = self._get_conn() + conn = self._conn try: # Guild Profile conn.execute(""" @@ -242,12 +243,11 @@ class BackupDatabase: conn.commit() finally: - conn.close() + pass # persistent connection — do not close def set_guild_profile(self, data: Dict[str, Any]): with self._lock: - conn = self._get_conn() - conn.execute(""" + self._conn.execute(""" INSERT OR REPLACE INTO guild_profile (id, name, description, icon_file, icon_url, banner_file, banner_url, owner_id, last_backup, ignore_channels) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( @@ -257,14 +257,11 @@ class BackupDatabase: str(data.get("owner_id")), data.get("last_backup"), json.dumps(data.get("ignore_channels", [])) )) - conn.commit() - conn.close() + self._conn.commit() def get_guild_profile(self) -> Optional[Dict[str, Any]]: with self._lock: - conn = self._get_conn() - row = conn.execute("SELECT * FROM guild_profile LIMIT 1").fetchone() - conn.close() + row = self._conn.execute("SELECT * FROM guild_profile LIMIT 1").fetchone() if row: data = dict(row) if data.get("ignore_channels"): @@ -276,11 +273,8 @@ class BackupDatabase: def save_roles(self, roles: List[Dict[str, Any]]): with self._lock: - conn = self._get_conn() - # Ensure complex fields are strings if they aren't already - formatted = [] - for r in roles: - formatted.append({ + formatted = [ + { "id": str(r["id"]), "name": r["name"], "color": r["color"], @@ -288,225 +282,193 @@ class BackupDatabase: "permissions": str(r["permissions"]), "hoist": 1 if r["hoist"] else 0, "mentionable": 1 if r["mentionable"] else 0 - }) - - conn.executemany(""" + } + for r in roles + ] + self._conn.executemany(""" INSERT OR REPLACE INTO roles (id, name, color, position, permissions, hoist, mentionable) VALUES (:id, :name, :color, :position, :permissions, :hoist, :mentionable) """, formatted) - conn.commit() - conn.close() + self._conn.commit() def save_channels(self, channels: List[Dict[str, Any]]): with self._lock: - conn = self._get_conn() - conn.executemany(""" + self._conn.executemany(""" INSERT OR REPLACE INTO channels (id, name, type, position, category_id, topic, nsfw, bitrate, slowmode_delay) VALUES (:id, :name, :type, :position, :category_id, :topic, :nsfw, :bitrate, :slowmode_delay) """, channels) - conn.commit() - conn.close() + self._conn.commit() def save_permissions(self, permissions: List[Dict[str, Any]]): """Saves a batch of channel permission overwrites.""" with self._lock: - conn = self._get_conn() - try: - conn.executemany(""" - INSERT INTO permissions (channel_id, target_id, target_type, allow, deny) - VALUES (:channel_id, :target_id, :target_type, :allow, :deny) - """, permissions) - conn.commit() - finally: - conn.close() + self._conn.executemany(""" + INSERT INTO permissions (channel_id, target_id, target_type, allow, deny) + VALUES (:channel_id, :target_id, :target_type, :allow, :deny) + """, permissions) + self._conn.commit() def save_users(self, users: List[Dict[str, Any]]): """Saves users to the author cache.""" with self._lock: - conn = self._get_conn() - conn.executemany(""" + self._conn.executemany(""" INSERT OR REPLACE INTO users (id, username, display_name, avatar_file, avatar_url, roles) VALUES (:id, :username, :display_name, :avatar_file, :avatar_url, :roles) """, users) - conn.commit() - conn.close() + self._conn.commit() def save_server_assets(self, assets: List[Dict[str, Any]]): """Saves a batch of server assets (emojis, stickers) to the database.""" with self._lock: - conn = self._get_conn() - formatted = [] - for a in assets: - formatted.append({ + formatted = [ + { "id": str(a["id"]), "name": a.get("name"), "type": a.get("type"), "filename": a.get("filename"), "url": a.get("url"), "mime_type": a.get("mime_type") - }) - - conn.executemany(""" + } + for a in assets + ] + self._conn.executemany(""" INSERT OR REPLACE INTO server_assets (id, name, type, filename, url, mime_type) VALUES (:id, :name, :type, :filename, :url, :mime_type) """, formatted) - conn.commit() - conn.close() + self._conn.commit() def save_threads(self, threads: List[Dict[str, Any]]): """Saves metadata for threads to the database.""" with self._lock: - conn = self._get_conn() - try: - conn.executemany(""" - INSERT OR REPLACE INTO threads (id, name, type, parent_id, message_count, member_count, archived, archive_timestamp, auto_archive_duration, locked, applied_tags) - VALUES (:id, :name, :type, :parent_id, :message_count, :member_count, :archived, :archive_timestamp, :auto_archive_duration, :locked, :applied_tags) - """, threads) - conn.commit() - finally: - conn.close() + self._conn.executemany(""" + INSERT OR REPLACE INTO threads (id, name, type, parent_id, message_count, member_count, archived, archive_timestamp, auto_archive_duration, locked, applied_tags) + VALUES (:id, :name, :type, :parent_id, :message_count, :member_count, :archived, :archive_timestamp, :auto_archive_duration, :locked, :applied_tags) + """, threads) + self._conn.commit() def save_forum_tags(self, tags: List[Dict[str, Any]]): """Saves definitions for forum tags.""" with self._lock: - conn = self._get_conn() - try: - conn.executemany(""" - INSERT OR REPLACE INTO forum_tags (id, forum_id, name, moderated, emoji_id, emoji_name) - VALUES (:id, :forum_id, :name, :moderated, :emoji_id, :emoji_name) - """, tags) - conn.commit() - finally: - conn.close() + self._conn.executemany(""" + INSERT OR REPLACE INTO forum_tags (id, forum_id, name, moderated, emoji_id, emoji_name) + VALUES (:id, :forum_id, :name, :moderated, :emoji_id, :emoji_name) + """, tags) + self._conn.commit() def save_messages_batch(self, messages: List[Dict[str, Any]]): """Batch inserts messages and their attachments.""" with self._lock: - conn = self._get_conn() - try: - # Insert messages + conn = self._conn + # Insert messages + conn.executemany(""" + INSERT OR REPLACE INTO messages (id, channel_id, author_id, content, timestamp, type, message_reference, is_pinned, extra_data) + VALUES (:id, :channel_id, :author_id, :content, :timestamp, :type, :message_reference, :is_pinned, :extra_data) + """, messages) + + # Extract attachments, reactions, and stickers + all_attachments = [] + all_reactions = [] + all_stickers = [] + + for msg in messages: + # Attachments + if "attachments" in msg: + for att in msg["attachments"]: + att["message_id"] = msg["id"] + all_attachments.append(att) + + # Reactions + if "reactions" in msg and msg["reactions"]: + for rea in msg["reactions"]: + all_reactions.append({ + "message_id": msg["id"], + "emoji_id": str(rea["emoji_id"]) if rea.get("emoji_id") else None, + "emoji_name": rea.get("emoji_name"), + "count": rea.get("count", 0) + }) + + # Stickers + if "stickers" in msg and msg["stickers"]: + for st in msg["stickers"]: + all_stickers.append({ + "message_id": msg["id"], + "sticker_id": str(st["id"]), + "name": st.get("name"), + "url": st.get("url"), + "format_type": st.get("format_type"), + "local_hash": st.get("local_hash") + }) + + if all_attachments: conn.executemany(""" - INSERT OR REPLACE INTO messages (id, channel_id, author_id, content, timestamp, type, message_reference, is_pinned, extra_data) - VALUES (:id, :channel_id, :author_id, :content, :timestamp, :type, :message_reference, :is_pinned, :extra_data) - """, messages) - - # Extract attachments, reactions, and stickers - all_attachments = [] - all_reactions = [] - all_stickers = [] - - for msg in messages: - # Attachments - if "attachments" in msg: - for att in msg["attachments"]: - att["message_id"] = msg["id"] - all_attachments.append(att) - - # Reactions - if "reactions" in msg and msg["reactions"]: - for rea in msg["reactions"]: - all_reactions.append({ - "message_id": msg["id"], - "emoji_id": str(rea["emoji_id"]) if rea.get("emoji_id") else None, - "emoji_name": rea.get("emoji_name"), - "count": rea.get("count", 0) - }) - - # Stickers - if "stickers" in msg and msg["stickers"]: - for st in msg["stickers"]: - all_stickers.append({ - "message_id": msg["id"], - "sticker_id": str(st["id"]), - "name": st.get("name"), - "url": st.get("url"), - "format_type": st.get("format_type"), - "local_hash": st.get("local_hash") - }) - - if all_attachments: - conn.executemany(""" - INSERT OR REPLACE INTO attachments (id, message_id, filename, size, url, content_type, local_hash) - VALUES (:id, :message_id, :filename, :size, :url, :content_type, :local_hash) - """, all_attachments) + INSERT OR REPLACE INTO attachments (id, message_id, filename, size, url, content_type, local_hash) + VALUES (:id, :message_id, :filename, :size, :url, :content_type, :local_hash) + """, all_attachments) - # Save Embeds (Normalized with JSON Fields) - for msg in messages: - if "embeds" in msg and msg["embeds"]: - for emb in msg["embeds"]: - # Insert top-level embed - conn.execute(""" - INSERT INTO embeds ( - message_id, title, description, url, color, timestamp, - thumbnail_url, image_url, author_name, author_url, - author_icon_url, footer_text, footer_icon_url, fields - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - msg["id"], emb.get("title"), emb.get("description"), emb.get("url"), - emb.get("color"), emb.get("timestamp"), - emb.get("thumbnail", {}).get("url") if isinstance(emb.get("thumbnail"), dict) else None, - emb.get("image", {}).get("url") if isinstance(emb.get("image"), dict) else None, - emb.get("author", {}).get("name") if isinstance(emb.get("author"), dict) else None, - emb.get("author", {}).get("url") if isinstance(emb.get("author"), dict) else None, - emb.get("author", {}).get("icon_url") if isinstance(emb.get("author"), dict) else None, - emb.get("footer", {}).get("text") if isinstance(emb.get("footer"), dict) else None, - emb.get("footer", {}).get("icon_url") if isinstance(emb.get("footer"), dict) else None, - json.dumps(emb.get("fields", [])) - )) + # Save Embeds (Normalized with JSON Fields) + for msg in messages: + if "embeds" in msg and msg["embeds"]: + for emb in msg["embeds"]: + conn.execute(""" + INSERT INTO embeds ( + message_id, title, description, url, color, timestamp, + thumbnail_url, image_url, author_name, author_url, + author_icon_url, footer_text, footer_icon_url, fields + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + msg["id"], emb.get("title"), emb.get("description"), emb.get("url"), + emb.get("color"), emb.get("timestamp"), + emb.get("thumbnail", {}).get("url") if isinstance(emb.get("thumbnail"), dict) else None, + emb.get("image", {}).get("url") if isinstance(emb.get("image"), dict) else None, + emb.get("author", {}).get("name") if isinstance(emb.get("author"), dict) else None, + emb.get("author", {}).get("url") if isinstance(emb.get("author"), dict) else None, + emb.get("author", {}).get("icon_url") if isinstance(emb.get("author"), dict) else None, + emb.get("footer", {}).get("text") if isinstance(emb.get("footer"), dict) else None, + emb.get("footer", {}).get("icon_url") if isinstance(emb.get("footer"), dict) else None, + json.dumps(emb.get("fields", [])) + )) - if all_reactions: - conn.executemany(""" - INSERT INTO reactions (message_id, emoji_id, emoji_name, count) - VALUES (:message_id, :emoji_id, :emoji_name, :count) - """, all_reactions) - - if all_stickers: - conn.executemany(""" - INSERT OR REPLACE INTO message_stickers (message_id, sticker_id, name, url, format_type, local_hash) - VALUES (:message_id, :sticker_id, :name, :url, :format_type, :local_hash) - """, all_stickers) - - conn.commit() - finally: - conn.close() + if all_reactions: + conn.executemany(""" + INSERT INTO reactions (message_id, emoji_id, emoji_name, count) + VALUES (:message_id, :emoji_id, :emoji_name, :count) + """, all_reactions) + + if all_stickers: + conn.executemany(""" + INSERT OR REPLACE INTO message_stickers (message_id, sticker_id, name, url, format_type, local_hash) + VALUES (:message_id, :sticker_id, :name, :url, :format_type, :local_hash) + """, all_stickers) + + conn.commit() def get_last_message_id(self, channel_id: str) -> Optional[str]: with self._lock: - conn = self._get_conn() - row = conn.execute("SELECT id FROM messages WHERE channel_id = ? ORDER BY id DESC LIMIT 1", (str(channel_id),)).fetchone() - conn.close() + row = self._conn.execute("SELECT id FROM messages WHERE channel_id = ? ORDER BY id DESC LIMIT 1", (str(channel_id),)).fetchone() return row["id"] if row else None def get_media_by_hash(self, file_hash: str) -> Optional[Dict[str, Any]]: with self._lock: - conn = self._get_conn() - row = conn.execute("SELECT * FROM media_pool WHERE hash = ?", (file_hash,)).fetchone() - conn.close() + row = self._conn.execute("SELECT * FROM media_pool WHERE hash = ?", (file_hash,)).fetchone() return dict(row) if row else None def get_media_by_url(self, url: str) -> Optional[Dict[str, Any]]: with self._lock: - conn = self._get_conn() - row = conn.execute("SELECT * FROM media_pool WHERE first_seen_url = ?", (url,)).fetchone() - conn.close() + row = self._conn.execute("SELECT * FROM media_pool WHERE first_seen_url = ?", (url,)).fetchone() return dict(row) if row else None def add_media_to_pool(self, file_hash: str, local_path: str, size: int, mime_type: str, url: str): + # NOTE: No commit here — caller (save_messages_batch) commits at end of batch with self._lock: - conn = self._get_conn() - conn.execute(""" + self._conn.execute(""" INSERT OR REPLACE INTO media_pool (hash, local_path, size, mime_type, first_seen_url) VALUES (?, ?, ?, ?, ?) """, (file_hash, str(local_path), size, mime_type, url)) - conn.commit() - conn.close() def get_stats_by_channel(self) -> Dict[int, Dict[str, Any]]: """Returns aggregate stats for all channels with backups.""" with self._lock: - conn = self._get_conn() - # Summary of messages per effective channel (including threads rolled up) - msg_rows = conn.execute(""" + msg_rows = self._conn.execute(""" SELECT COALESCE(t.parent_id, m.channel_id) as channel_id, COUNT(m.id) as msg_count @@ -515,15 +477,13 @@ class BackupDatabase: GROUP BY channel_id """).fetchall() - # Thread counts per parent - thread_rows = conn.execute(""" + thread_rows = self._conn.execute(""" SELECT parent_id, COUNT(*) as thread_count FROM threads GROUP BY parent_id """).fetchall() - # Summary of attachments per effective channel (including threads rolled up) - att_rows = conn.execute(""" + att_rows = self._conn.execute(""" SELECT COALESCE(t.parent_id, m.channel_id) as channel_id, COUNT(a.id) as att_count, @@ -534,8 +494,6 @@ class BackupDatabase: GROUP BY channel_id """).fetchall() - conn.close() - stats = {} for r in msg_rows: cid = parse_snowflake(r["channel_id"]) @@ -566,22 +524,18 @@ class BackupDatabase: def get_all_roles(self) -> List[Dict[str, Any]]: with self._lock: - conn = self._get_conn() - rows = conn.execute("SELECT * FROM roles ORDER BY position DESC").fetchall() - conn.close() + rows = self._conn.execute("SELECT * FROM roles ORDER BY position DESC").fetchall() return [dict(r) for r in rows] def get_all_channels(self) -> List[Dict[str, Any]]: with self._lock: - conn = self._get_conn() - rows = conn.execute("SELECT * FROM channels ORDER BY position ASC").fetchall() + rows = self._conn.execute("SELECT * FROM channels ORDER BY position ASC").fetchall() - # Fetch all permissions chan_list = [dict(r) for r in rows] if chan_list: ids = [c["id"] for c in chan_list] placeholders = ",".join(["?"] * len(ids)) - perm_rows = conn.execute(f"SELECT * FROM permissions WHERE channel_id IN ({placeholders})", ids).fetchall() + perm_rows = self._conn.execute(f"SELECT * FROM permissions WHERE channel_id IN ({placeholders})", ids).fetchall() perms_by_chan = {} for pr in perm_rows: @@ -597,8 +551,7 @@ class BackupDatabase: for c in chan_list: c["overwrites"] = perms_by_chan.get(c["id"], []) - # Fetch Forum Tags - tag_rows = conn.execute(f"SELECT * FROM forum_tags WHERE forum_id IN ({placeholders})", ids).fetchall() + tag_rows = self._conn.execute(f"SELECT * FROM forum_tags WHERE forum_id IN ({placeholders})", ids).fetchall() tags_by_forum = {} for tr in tag_rows: fid = tr["forum_id"] @@ -608,56 +561,43 @@ class BackupDatabase: for c in chan_list: c["available_tags"] = tags_by_forum.get(c["id"], []) - conn.close() return chan_list def get_all_threads(self) -> List[Dict[str, Any]]: """Returns metadata for all threads in the backup.""" with self._lock: - conn = self._get_conn() - rows = conn.execute("SELECT * FROM threads").fetchall() - conn.close() + rows = self._conn.execute("SELECT * FROM threads").fetchall() return [dict(r) for r in rows] def get_forum_tags(self, forum_id: Optional[str] = None) -> List[Dict[str, Any]]: """Returns forum tag definitions.""" with self._lock: - conn = self._get_conn() if forum_id: - rows = conn.execute("SELECT * FROM forum_tags WHERE forum_id = ?", (str(forum_id),)).fetchall() + rows = self._conn.execute("SELECT * FROM forum_tags WHERE forum_id = ?", (str(forum_id),)).fetchall() else: - rows = conn.execute("SELECT * FROM forum_tags").fetchall() - conn.close() + rows = self._conn.execute("SELECT * FROM forum_tags").fetchall() return [dict(r) for r in rows] def get_threads_by_parent(self, parent_id: str) -> List[Dict[str, Any]]: """Returns all threads belonging to a parent channel.""" with self._lock: - conn = self._get_conn() - rows = conn.execute("SELECT * FROM threads WHERE parent_id = ?", (str(parent_id),)).fetchall() - conn.close() + rows = self._conn.execute("SELECT * FROM threads WHERE parent_id = ?", (str(parent_id),)).fetchall() return [dict(r) for r in rows] def get_thread(self, thread_id: str) -> Optional[Dict[str, Any]]: """Retrieves a single thread's metadata.""" with self._lock: - conn = self._get_conn() - row = conn.execute("SELECT * FROM threads WHERE id = ?", (str(thread_id),)).fetchone() - conn.close() + row = self._conn.execute("SELECT * FROM threads WHERE id = ?", (str(thread_id),)).fetchone() return dict(row) if row else None def get_all_users(self) -> List[Dict[str, Any]]: with self._lock: - conn = self._get_conn() - rows = conn.execute("SELECT * FROM users").fetchall() - conn.close() + rows = self._conn.execute("SELECT * FROM users").fetchall() return [dict(r) for r in rows] def get_user(self, user_id: str) -> Optional[Dict[str, Any]]: with self._lock: - conn = self._get_conn() - row = conn.execute("SELECT * FROM users WHERE id = ?", (str(user_id),)).fetchone() - conn.close() + row = self._conn.execute("SELECT * FROM users WHERE id = ?", (str(user_id),)).fetchone() if row: data = dict(row) if data.get("roles"): @@ -668,25 +608,20 @@ class BackupDatabase: def get_server_assets(self, asset_type: Optional[str] = None) -> List[Dict[str, Any]]: """Returns all server assets, optionally filtered by type.""" with self._lock: - conn = self._get_conn() if asset_type: - rows = conn.execute("SELECT * FROM server_assets WHERE type = ?", (asset_type,)).fetchall() + rows = self._conn.execute("SELECT * FROM server_assets WHERE type = ?", (asset_type,)).fetchall() else: - rows = conn.execute("SELECT * FROM server_assets").fetchall() - conn.close() + rows = self._conn.execute("SELECT * FROM server_assets").fetchall() return [dict(r) for r in rows] def get_all_media(self) -> Dict[str, Dict[str, Any]]: """Returns the entire media pool as a dictionary indexed by hash.""" with self._lock: - conn = self._get_conn() - rows = conn.execute("SELECT * FROM media_pool").fetchall() - conn.close() + rows = self._conn.execute("SELECT * FROM media_pool").fetchall() return {r["hash"]: dict(r) for r in rows} def get_messages_paged(self, channel_id: str, limit: int = 100, offset: int = 0, after_id: Optional[str] = None) -> List[Dict[str, Any]]: with self._lock: - conn = self._get_conn() query = "SELECT * FROM messages WHERE channel_id = ?" params = [str(channel_id)] @@ -697,24 +632,21 @@ class BackupDatabase: query += " ORDER BY id ASC LIMIT ? OFFSET ?" params.extend([limit, offset]) - rows = conn.execute(query, params).fetchall() + rows = self._conn.execute(query, params).fetchall() msg_list = [dict(r) for r in rows] if msg_list: - # Fetch attachments for these messages msg_ids = [m["id"] for m in msg_list] placeholders = ",".join(["?"] * len(msg_ids)) - # Attachments - att_rows = conn.execute(f"SELECT * FROM attachments WHERE message_id IN ({placeholders})", msg_ids).fetchall() + att_rows = self._conn.execute(f"SELECT * FROM attachments WHERE message_id IN ({placeholders})", msg_ids).fetchall() atts_by_msg = {} for ar in att_rows: mid = ar["message_id"] if mid not in atts_by_msg: atts_by_msg[mid] = [] atts_by_msg[mid].append(dict(ar)) - # Embeds - emb_rows = conn.execute(f"SELECT * FROM embeds WHERE message_id IN ({placeholders})", msg_ids).fetchall() + emb_rows = self._conn.execute(f"SELECT * FROM embeds WHERE message_id IN ({placeholders})", msg_ids).fetchall() embs_by_msg = {} for er in emb_rows: mid = er["message_id"] @@ -741,16 +673,14 @@ class BackupDatabase: } embs_by_msg[mid].append(e_dict) - # Reactions - rea_rows = conn.execute(f"SELECT * FROM reactions WHERE message_id IN ({placeholders})", msg_ids).fetchall() + rea_rows = self._conn.execute(f"SELECT * FROM reactions WHERE message_id IN ({placeholders})", msg_ids).fetchall() reas_by_msg = {} for rr in rea_rows: mid = rr["message_id"] if mid not in reas_by_msg: reas_by_msg[mid] = [] reas_by_msg[mid].append(dict(rr)) - # Stickers (Message-specific) - st_rows = conn.execute(f"SELECT * FROM message_stickers WHERE message_id IN ({placeholders})", msg_ids).fetchall() + st_rows = self._conn.execute(f"SELECT * FROM message_stickers WHERE message_id IN ({placeholders})", msg_ids).fetchall() sts_by_msg = {} for sr in st_rows: mid = sr["message_id"] @@ -764,10 +694,13 @@ class BackupDatabase: m["reactions"] = reas_by_msg.get(m_id, []) m["stickers"] = sts_by_msg.get(m_id, []) - conn.close() return msg_list def close(self): - # We don't keep long-lived connections in this model to avoid locking issues, - # but the method is here for parity with other DB classes. - pass + """Commits any pending writes and closes the connection.""" + with self._lock: + try: + self._conn.commit() + self._conn.close() + except Exception: + pass diff --git a/src/core/exporter.py b/src/core/exporter.py index f04541a..993d422 100644 --- a/src/core/exporter.py +++ b/src/core/exporter.py @@ -21,7 +21,10 @@ class DiscordExporter: self.base_dir = Path(base_dir) if base_dir else Path(".") self.is_running = True self.db: Optional[BackupDatabase] = None - self.sticker_cache: Dict[int, bytes] = {} # Deduplicate downloads in one session + self.sticker_cache: Dict[int, bytes] = {} # Deduplicate downloads in one session + # Pending avatar downloads — flushed after each message batch to keep the + # hot message-formatting path free of HTTP latency. + self._pending_avatars: List[tuple] = [] # (user_id, save_coroutine, av_path) async def setup(self): """Prepares the output directory and fetches server metadata.""" @@ -379,6 +382,9 @@ class DiscordExporter: accumulated_files += len(m["attachments"]) # Persist to DB + # Flush deferred avatar downloads before persisting this batch + await self._flush_pending_avatars() + if self.db: if batch_users: self.db.save_users(batch_users) self.db.save_messages_batch(batch_messages) @@ -393,32 +399,35 @@ class DiscordExporter: batch_users.clear() batch_raw.clear() - if batch_raw and self.is_running: - results = await asyncio.gather(*(self._format_message(m) for m in batch_raw)) - for m_data, u_list in results: - batch_messages.append(m_data) - if u_list: batch_users.extend(u_list) - - new_count += len(batch_messages) - accumulated_count += len(batch_messages) - - for m in batch_messages: - if "attachments" in m: - accumulated_files += len(m["attachments"]) - - - if self.db: - if batch_users: self.db.save_users(batch_users) - self.db.save_messages_batch(batch_messages) - - if progress_callback: - last_msg = batch_raw[-1] - author_name = getattr(last_msg.author, "display_name", "Unknown") - await progress_callback(channel_name, accumulated_count, author_name=author_name, thread_count=accumulated_threads, file_count=accumulated_files) - - batch_messages.clear() - batch_users.clear() - batch_raw.clear() + # Process any remaining messages that didn't fill a full batch + if batch_raw and self.is_running: + results = await asyncio.gather(*(self._format_message(m) for m in batch_raw)) + for m_data, u_list in results: + batch_messages.append(m_data) + if u_list: batch_users.extend(u_list) + + new_count += len(batch_messages) + accumulated_count += len(batch_messages) + + for m in batch_messages: + if "attachments" in m: + accumulated_files += len(m["attachments"]) + + # Flush deferred avatar downloads before persisting this batch + await self._flush_pending_avatars() + + if self.db: + if batch_users: self.db.save_users(batch_users) + self.db.save_messages_batch(batch_messages) + + if progress_callback: + last_msg = batch_raw[-1] + author_name = getattr(last_msg.author, "display_name", "Unknown") + await progress_callback(channel_name, accumulated_count, author_name=author_name, thread_count=accumulated_threads, file_count=accumulated_files) + + batch_messages.clear() + batch_users.clear() + batch_raw.clear() except discord.Forbidden: logger.error(f"403 Forbidden: Missing Access to read messages in {channel_name} ({channel_id})") @@ -431,22 +440,24 @@ class DiscordExporter: return accumulated_count, accumulated_threads, accumulated_files async def _format_user(self, user): - """Formats user data for the author or a mention.""" + """Formats user data for the author or a mention. + + Avatar downloads are intentionally deferred to keep this off the hot + message-formatting path. Call _flush_pending_avatars() after each batch. + """ user_id = str(user.id) if user_id in self.user_cache: return None - # New user discovered + # New user discovered — schedule avatar download but don't block here avatar_file = None if user.avatar: - try: - av_name = f"{user_id}.png" - av_target = self.users_path / av_name - if not av_target.exists(): - await user.avatar.save(av_target) - avatar_file = f"users/{av_name}" - except Exception as e: - logger.error(f"Failed to save avatar for {user.name}: {e}") + av_name = f"{user_id}.png" + av_target = self.users_path / av_name + avatar_file = f"users/{av_name}" + if not av_target.exists(): + # Queue for deferred download + self._pending_avatars.append((user_id, user.avatar, av_target)) roles = [] if hasattr(user, "roles"): @@ -463,6 +474,21 @@ class DiscordExporter: self.user_cache[user_id] = user_data return user_data + async def _flush_pending_avatars(self): + """Downloads all queued user avatars concurrently, then clears the queue.""" + if not self._pending_avatars: + return + + async def _save_avatar(user_id, avatar_asset, target_path): + try: + await avatar_asset.save(target_path) + except Exception as e: + logger.error(f"Failed to save avatar for user {user_id}: {e}") + + pending = self._pending_avatars[:] + self._pending_avatars.clear() + await asyncio.gather(*[_save_avatar(uid, av, path) for uid, av, path in pending]) + async def _format_message(self, msg): """Formats a single message and its author for DB storage.""" new_users = [] @@ -478,11 +504,11 @@ class DiscordExporter: if u_ment: new_users.append(u_ment) # 2. Attachments handling (Content-Addressable Storage) - # ... (rest of the logic remains same, just updating m_data) + # All attachments in a message are downloaded concurrently. attachments = [] if msg.attachments: - for att in msg.attachments: - att_data = await self._process_media( + att_tasks = [ + self._process_media( media_id=att.id, url=att.url, filename=att.filename, @@ -490,8 +516,10 @@ class DiscordExporter: content_type=att.content_type, save_method=att.save ) - if att_data: - attachments.append(att_data) + for att in msg.attachments + ] + att_results = await asyncio.gather(*att_tasks) + attachments = [r for r in att_results if r] # 2.5 Stickers handling stickers = [] @@ -556,16 +584,20 @@ class DiscordExporter: if snapshot.content: content = snapshot.content - for s_att in snapshot.attachments: - att_res = await self._process_media( - media_id=s_att.id, - url=s_att.url, - filename=s_att.filename, - size=s_att.size, - content_type=s_att.content_type, - save_method=s_att.save - ) - if att_res: attachments.append(att_res) + if snapshot.attachments: + snap_tasks = [ + self._process_media( + media_id=s_att.id, + url=s_att.url, + filename=s_att.filename, + size=s_att.size, + content_type=s_att.content_type, + save_method=s_att.save + ) + for s_att in snapshot.attachments + ] + snap_results = await asyncio.gather(*snap_tasks) + attachments.extend(r for r in snap_results if r) for s_emb in snapshot.embeds: embeds.append(s_emb.to_dict()) @@ -623,14 +655,16 @@ class DiscordExporter: try: tmp.close() except: pass - file_hash = self._calculate_sha256(tmp_path) - actual_size = tmp_path.stat().st_size + # Offload CPU-bound hashing and blocking file ops to the thread pool + # so we don't stall concurrent downloads on the event loop. + file_hash = await asyncio.to_thread(self._calculate_sha256, tmp_path) + actual_size = (await asyncio.to_thread(tmp_path.stat)).st_size # Check if hash already exists in pool if self.db: in_pool = self.db.get_media_by_hash(file_hash) if in_pool: - tmp_path.unlink() + await asyncio.to_thread(tmp_path.unlink) return { "id": str(media_id), "filename": filename, @@ -645,8 +679,11 @@ class DiscordExporter: target_filename = f"{file_hash}{ext}" target_path = self.attachments_path / target_filename - shutil.move(str(tmp_path), str(target_path)) + await asyncio.to_thread(shutil.move, str(tmp_path), str(target_path)) + # Mark as successfully moved so finally block doesn't delete it + tmp_path = None + if self.db: self.db.add_media_to_pool(file_hash, f"attachments/{target_filename}", actual_size, content_type, str(url)) @@ -658,9 +695,14 @@ class DiscordExporter: "content_type": content_type, "local_hash": file_hash } - except Exception as e: - logger.error(f"Failed to process media {filename}: {e}") - if tmp_path and tmp_path.exists(): tmp_path.unlink() + except BaseException as e: + if not isinstance(e, asyncio.CancelledError): + logger.error(f"Failed to process media {filename}: {e}") + raise + finally: + if tmp_path and tmp_path.exists(): + try: tmp_path.unlink() + except: pass return None async def export_threads(self, channel_id: int, progress_callback=None, force=False, accumulated_count=0, accumulated_threads=0, accumulated_files=0, after_id: int | None = None): @@ -750,25 +792,35 @@ class DiscordExporter: }) self.db.save_threads(thread_meta) - for thread in all_threads: - if not self.is_running: break - await asyncio.sleep(0) - - accumulated_threads += 1 + # Export threads concurrently — semaphore limits to 5 at a time to + # avoid flooding Discord's rate limiter. + sem = asyncio.Semaphore(5) + + async def _export_one_thread(thread, t_idx): + async with sem: + if not self.is_running: + return 0, 0, 0 + cnt, thr, fls = await self.export_channel_messages( + thread.id, + progress_callback=progress_callback, + force=force, + accumulated_count=0, + accumulated_threads=0, + accumulated_files=0, + after_id=after_id + ) + return cnt, thr, fls + + if all_threads: + thread_results = await asyncio.gather( + *[_export_one_thread(t, i) for i, t in enumerate(all_threads)] + ) + for cnt, thr, fls in thread_results: + accumulated_count += cnt + accumulated_threads += 1 + thr # +1 for the thread itself + accumulated_files += fls + if progress_callback: await progress_callback(channel.name, accumulated_count, thread_count=accumulated_threads, file_count=accumulated_files) - - # Backup thread messages - accumulated_count, accumulated_threads, accumulated_files = await self.export_channel_messages(thread.id, progress_callback=progress_callback, force=force, accumulated_count=accumulated_count, accumulated_threads=accumulated_threads, accumulated_files=accumulated_files, after_id=after_id) - # For forums, ensure the starter message exists in the DB - if is_forum: - # starter_message is handled by export_channel_messages since it's just a message in that thread - # However we may want to mark it or store forum-specific tags - try: - # Just yield for concurrency - await asyncio.sleep(0) - except Exception as e: - logger.error(f"Error processing forum thread {thread.name}: {e}") - return accumulated_count, accumulated_threads, accumulated_files