From 73d52d2183485d263fad4cd72c6cb91b4df67196 Mon Sep 17 00:00:00 2001 From: rambros3d Date: Mon, 30 Mar 2026 01:40:05 +0530 Subject: [PATCH] improve thread safety for backup db --- src/core/backup_database.py | 56 ++++++++++++++++++++++++++++++++++++- src/core/backup_reader.py | 38 +++++++++++-------------- 2 files changed, 72 insertions(+), 22 deletions(-) diff --git a/src/core/backup_database.py b/src/core/backup_database.py index 40e0011..381cdb7 100644 --- a/src/core/backup_database.py +++ b/src/core/backup_database.py @@ -114,7 +114,7 @@ class BackupDatabase: elif table == "forum_tags": conn.execute("CREATE TABLE forum_tags (id INTEGER PRIMARY KEY, forum_id INTEGER, name TEXT, moderated INTEGER, emoji_id INTEGER, emoji_name TEXT)") elif table == "server_assets": - conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type INTEGER)") + conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type TEXT)") old_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table}_old)").fetchall()] new_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table})").fetchall()] @@ -945,6 +945,60 @@ class BackupDatabase: return purged_count + def get_backed_up_channel_ids(self) -> List[int]: + """Returns a list of distinct channel IDs that have messages in the database.""" + with self._lock: + rows = self._conn.execute("SELECT DISTINCT channel_id FROM messages").fetchall() + return [parse_snowflake(r[0]) for r in rows if parse_snowflake(r[0])] + + def get_message_with_relations(self, message_id) -> Optional[Dict[str, Any]]: + """Fetches a single message with its attachments, embeds, reactions, and stickers.""" + with self._lock: + mid = parse_snowflake(message_id) + row = self._conn.execute("SELECT * FROM messages WHERE id = ?", (mid,)).fetchone() + if not row: + return None + data = dict(row) + + # Attachments + atts = self._conn.execute("SELECT * FROM attachments WHERE message_id = ?", (mid,)).fetchall() + data["attachments"] = [dict(a) for a in atts] + + # Embeds + embs = self._conn.execute("SELECT * FROM embeds WHERE message_id = ?", (mid,)).fetchall() + data["embeds"] = [] + for er in embs: + e_dict = { + "title": er["title"], + "description": er["description"], + "url": er["url"], + "color": er["color"], + "timestamp": er["timestamp"], + "thumbnail": {"url": er["thumbnail_url"]} if er["thumbnail_url"] else None, + "image": {"url": er["image_url"]} if er["image_url"] else None, + "author": { + "name": er["author_name"], + "url": er["author_url"], + "icon_url": er["author_icon_url"] + } if er["author_name"] else None, + "footer": { + "text": er["footer_text"], + "icon_url": er["footer_icon_url"] + } if er["footer_text"] else None, + "fields": json.loads(er["fields"]) if er["fields"] else [] + } + data["embeds"].append(e_dict) + + # Reactions + reas = self._conn.execute("SELECT * FROM reactions WHERE message_id = ?", (mid,)).fetchall() + data["reactions"] = [dict(r) for r in reas] + + # Stickers + sts = self._conn.execute("SELECT * FROM message_stickers WHERE message_id = ?", (mid,)).fetchall() + data["stickers"] = [dict(s) for s in sts] + + return data + def close(self): """Commits any pending writes and closes the connection.""" with self._lock: diff --git a/src/core/backup_reader.py b/src/core/backup_reader.py index c9d250d..36369c8 100644 --- a/src/core/backup_reader.py +++ b/src/core/backup_reader.py @@ -393,6 +393,20 @@ class BackupMember: # Fallback for unexpected data format self.id = 0 self.name = "Unknown" + self.display_name = "Unknown" + self.global_name = "Unknown" + self.bot = False + self.system = False + self.discriminator = "0000" + self.color = BackupColor(0) + self.roles = sorted(role_objects or [], key=lambda r: r.position, reverse=True) + self.guild_permissions = BackupPermissions(0) + self.created_at = datetime.now(timezone.utc) + self.joined_at = datetime.now(timezone.utc) + self.status = type("Status", (), {"value": "offline"})() + self.activity = None + self._avatar_url = None + self.avatar = BackupAsset(None) return self.id = parse_snowflake(data["id"]) self.name = data.get("username", "Unknown") @@ -1266,11 +1280,7 @@ class BackupReader: async def get_backed_up_channel_ids(self) -> List[int]: """Returns a list of channel IDs that have messages in the database.""" if not self.db: return [] - import sqlite3 - conn = sqlite3.connect(self.db.db_path) - rows = conn.execute("SELECT DISTINCT channel_id FROM messages").fetchall() - conn.close() - return [parse_snowflake(r[0]) for r in rows if parse_snowflake(r[0])] + return self.db.get_backed_up_channel_ids() async def get_channel(self, channel_id: int) -> BackupChannel | BackupThread | None: for c in self.channels: @@ -1351,23 +1361,9 @@ class BackupReader: async def get_message(self, channel_id: int, message_id: int) -> BackupMessage | None: """Fetch a specific message from SQLite.""" if not self.db: return None - import sqlite3 - conn = sqlite3.connect(self.db.db_path) - conn.row_factory = sqlite3.Row - row = conn.execute("SELECT * FROM messages WHERE id = ?", (str(message_id),)).fetchone() - if row: - data = dict(row) - # Fetch attachments - atts = conn.execute("SELECT * FROM attachments WHERE message_id = ?", (str(message_id),)).fetchall() - data["attachments"] = [dict(a) for a in atts] - - # Fetch stickers - sts = conn.execute("SELECT * FROM message_stickers WHERE message_id = ?", (str(message_id),)).fetchall() - data["stickers"] = [dict(s) for s in sts] - - conn.close() + data = self.db.get_message_with_relations(message_id) + if data: return self._hydrate_message(data) - conn.close() return None async def get_first_message(self, channel_id: int) -> BackupMessage | None: