improve thread safety for backup db

This commit is contained in:
rambros3d 2026-03-30 01:40:05 +05:30
parent 0cb678b848
commit 73d52d2183
2 changed files with 72 additions and 22 deletions

View file

@ -114,7 +114,7 @@ class BackupDatabase:
elif table == "forum_tags": elif table == "forum_tags":
conn.execute("CREATE TABLE forum_tags (id INTEGER PRIMARY KEY, forum_id INTEGER, name TEXT, moderated INTEGER, emoji_id INTEGER, emoji_name TEXT)") conn.execute("CREATE TABLE forum_tags (id INTEGER PRIMARY KEY, forum_id INTEGER, name TEXT, moderated INTEGER, emoji_id INTEGER, emoji_name TEXT)")
elif table == "server_assets": elif table == "server_assets":
conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type INTEGER)") conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type TEXT)")
old_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table}_old)").fetchall()] old_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table}_old)").fetchall()]
new_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table})").fetchall()] new_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table})").fetchall()]
@ -945,6 +945,60 @@ class BackupDatabase:
return purged_count return purged_count
def get_backed_up_channel_ids(self) -> List[int]:
"""Returns a list of distinct channel IDs that have messages in the database."""
with self._lock:
rows = self._conn.execute("SELECT DISTINCT channel_id FROM messages").fetchall()
return [parse_snowflake(r[0]) for r in rows if parse_snowflake(r[0])]
def get_message_with_relations(self, message_id) -> Optional[Dict[str, Any]]:
"""Fetches a single message with its attachments, embeds, reactions, and stickers."""
with self._lock:
mid = parse_snowflake(message_id)
row = self._conn.execute("SELECT * FROM messages WHERE id = ?", (mid,)).fetchone()
if not row:
return None
data = dict(row)
# Attachments
atts = self._conn.execute("SELECT * FROM attachments WHERE message_id = ?", (mid,)).fetchall()
data["attachments"] = [dict(a) for a in atts]
# Embeds
embs = self._conn.execute("SELECT * FROM embeds WHERE message_id = ?", (mid,)).fetchall()
data["embeds"] = []
for er in embs:
e_dict = {
"title": er["title"],
"description": er["description"],
"url": er["url"],
"color": er["color"],
"timestamp": er["timestamp"],
"thumbnail": {"url": er["thumbnail_url"]} if er["thumbnail_url"] else None,
"image": {"url": er["image_url"]} if er["image_url"] else None,
"author": {
"name": er["author_name"],
"url": er["author_url"],
"icon_url": er["author_icon_url"]
} if er["author_name"] else None,
"footer": {
"text": er["footer_text"],
"icon_url": er["footer_icon_url"]
} if er["footer_text"] else None,
"fields": json.loads(er["fields"]) if er["fields"] else []
}
data["embeds"].append(e_dict)
# Reactions
reas = self._conn.execute("SELECT * FROM reactions WHERE message_id = ?", (mid,)).fetchall()
data["reactions"] = [dict(r) for r in reas]
# Stickers
sts = self._conn.execute("SELECT * FROM message_stickers WHERE message_id = ?", (mid,)).fetchall()
data["stickers"] = [dict(s) for s in sts]
return data
def close(self): def close(self):
"""Commits any pending writes and closes the connection.""" """Commits any pending writes and closes the connection."""
with self._lock: with self._lock:

View file

@ -393,6 +393,20 @@ class BackupMember:
# Fallback for unexpected data format # Fallback for unexpected data format
self.id = 0 self.id = 0
self.name = "Unknown" self.name = "Unknown"
self.display_name = "Unknown"
self.global_name = "Unknown"
self.bot = False
self.system = False
self.discriminator = "0000"
self.color = BackupColor(0)
self.roles = sorted(role_objects or [], key=lambda r: r.position, reverse=True)
self.guild_permissions = BackupPermissions(0)
self.created_at = datetime.now(timezone.utc)
self.joined_at = datetime.now(timezone.utc)
self.status = type("Status", (), {"value": "offline"})()
self.activity = None
self._avatar_url = None
self.avatar = BackupAsset(None)
return return
self.id = parse_snowflake(data["id"]) self.id = parse_snowflake(data["id"])
self.name = data.get("username", "Unknown") self.name = data.get("username", "Unknown")
@ -1266,11 +1280,7 @@ class BackupReader:
async def get_backed_up_channel_ids(self) -> List[int]: async def get_backed_up_channel_ids(self) -> List[int]:
"""Returns a list of channel IDs that have messages in the database.""" """Returns a list of channel IDs that have messages in the database."""
if not self.db: return [] if not self.db: return []
import sqlite3 return self.db.get_backed_up_channel_ids()
conn = sqlite3.connect(self.db.db_path)
rows = conn.execute("SELECT DISTINCT channel_id FROM messages").fetchall()
conn.close()
return [parse_snowflake(r[0]) for r in rows if parse_snowflake(r[0])]
async def get_channel(self, channel_id: int) -> BackupChannel | BackupThread | None: async def get_channel(self, channel_id: int) -> BackupChannel | BackupThread | None:
for c in self.channels: for c in self.channels:
@ -1351,23 +1361,9 @@ class BackupReader:
async def get_message(self, channel_id: int, message_id: int) -> BackupMessage | None: async def get_message(self, channel_id: int, message_id: int) -> BackupMessage | None:
"""Fetch a specific message from SQLite.""" """Fetch a specific message from SQLite."""
if not self.db: return None if not self.db: return None
import sqlite3 data = self.db.get_message_with_relations(message_id)
conn = sqlite3.connect(self.db.db_path) if data:
conn.row_factory = sqlite3.Row
row = conn.execute("SELECT * FROM messages WHERE id = ?", (str(message_id),)).fetchone()
if row:
data = dict(row)
# Fetch attachments
atts = conn.execute("SELECT * FROM attachments WHERE message_id = ?", (str(message_id),)).fetchall()
data["attachments"] = [dict(a) for a in atts]
# Fetch stickers
sts = conn.execute("SELECT * FROM message_stickers WHERE message_id = ?", (str(message_id),)).fetchall()
data["stickers"] = [dict(s) for s in sts]
conn.close()
return self._hydrate_message(data) return self._hydrate_message(data)
conn.close()
return None return None
async def get_first_message(self, channel_id: int) -> BackupMessage | None: async def get_first_message(self, channel_id: int) -> BackupMessage | None: