improve thread safety for backup db
This commit is contained in:
parent
0cb678b848
commit
73d52d2183
2 changed files with 72 additions and 22 deletions
|
|
@ -114,7 +114,7 @@ class BackupDatabase:
|
|||
elif table == "forum_tags":
|
||||
conn.execute("CREATE TABLE forum_tags (id INTEGER PRIMARY KEY, forum_id INTEGER, name TEXT, moderated INTEGER, emoji_id INTEGER, emoji_name TEXT)")
|
||||
elif table == "server_assets":
|
||||
conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type INTEGER)")
|
||||
conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type TEXT)")
|
||||
|
||||
old_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table}_old)").fetchall()]
|
||||
new_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table})").fetchall()]
|
||||
|
|
@ -945,6 +945,60 @@ class BackupDatabase:
|
|||
|
||||
return purged_count
|
||||
|
||||
def get_backed_up_channel_ids(self) -> List[int]:
|
||||
"""Returns a list of distinct channel IDs that have messages in the database."""
|
||||
with self._lock:
|
||||
rows = self._conn.execute("SELECT DISTINCT channel_id FROM messages").fetchall()
|
||||
return [parse_snowflake(r[0]) for r in rows if parse_snowflake(r[0])]
|
||||
|
||||
def get_message_with_relations(self, message_id) -> Optional[Dict[str, Any]]:
|
||||
"""Fetches a single message with its attachments, embeds, reactions, and stickers."""
|
||||
with self._lock:
|
||||
mid = parse_snowflake(message_id)
|
||||
row = self._conn.execute("SELECT * FROM messages WHERE id = ?", (mid,)).fetchone()
|
||||
if not row:
|
||||
return None
|
||||
data = dict(row)
|
||||
|
||||
# Attachments
|
||||
atts = self._conn.execute("SELECT * FROM attachments WHERE message_id = ?", (mid,)).fetchall()
|
||||
data["attachments"] = [dict(a) for a in atts]
|
||||
|
||||
# Embeds
|
||||
embs = self._conn.execute("SELECT * FROM embeds WHERE message_id = ?", (mid,)).fetchall()
|
||||
data["embeds"] = []
|
||||
for er in embs:
|
||||
e_dict = {
|
||||
"title": er["title"],
|
||||
"description": er["description"],
|
||||
"url": er["url"],
|
||||
"color": er["color"],
|
||||
"timestamp": er["timestamp"],
|
||||
"thumbnail": {"url": er["thumbnail_url"]} if er["thumbnail_url"] else None,
|
||||
"image": {"url": er["image_url"]} if er["image_url"] else None,
|
||||
"author": {
|
||||
"name": er["author_name"],
|
||||
"url": er["author_url"],
|
||||
"icon_url": er["author_icon_url"]
|
||||
} if er["author_name"] else None,
|
||||
"footer": {
|
||||
"text": er["footer_text"],
|
||||
"icon_url": er["footer_icon_url"]
|
||||
} if er["footer_text"] else None,
|
||||
"fields": json.loads(er["fields"]) if er["fields"] else []
|
||||
}
|
||||
data["embeds"].append(e_dict)
|
||||
|
||||
# Reactions
|
||||
reas = self._conn.execute("SELECT * FROM reactions WHERE message_id = ?", (mid,)).fetchall()
|
||||
data["reactions"] = [dict(r) for r in reas]
|
||||
|
||||
# Stickers
|
||||
sts = self._conn.execute("SELECT * FROM message_stickers WHERE message_id = ?", (mid,)).fetchall()
|
||||
data["stickers"] = [dict(s) for s in sts]
|
||||
|
||||
return data
|
||||
|
||||
def close(self):
|
||||
"""Commits any pending writes and closes the connection."""
|
||||
with self._lock:
|
||||
|
|
|
|||
|
|
@ -393,6 +393,20 @@ class BackupMember:
|
|||
# Fallback for unexpected data format
|
||||
self.id = 0
|
||||
self.name = "Unknown"
|
||||
self.display_name = "Unknown"
|
||||
self.global_name = "Unknown"
|
||||
self.bot = False
|
||||
self.system = False
|
||||
self.discriminator = "0000"
|
||||
self.color = BackupColor(0)
|
||||
self.roles = sorted(role_objects or [], key=lambda r: r.position, reverse=True)
|
||||
self.guild_permissions = BackupPermissions(0)
|
||||
self.created_at = datetime.now(timezone.utc)
|
||||
self.joined_at = datetime.now(timezone.utc)
|
||||
self.status = type("Status", (), {"value": "offline"})()
|
||||
self.activity = None
|
||||
self._avatar_url = None
|
||||
self.avatar = BackupAsset(None)
|
||||
return
|
||||
self.id = parse_snowflake(data["id"])
|
||||
self.name = data.get("username", "Unknown")
|
||||
|
|
@ -1266,11 +1280,7 @@ class BackupReader:
|
|||
async def get_backed_up_channel_ids(self) -> List[int]:
|
||||
"""Returns a list of channel IDs that have messages in the database."""
|
||||
if not self.db: return []
|
||||
import sqlite3
|
||||
conn = sqlite3.connect(self.db.db_path)
|
||||
rows = conn.execute("SELECT DISTINCT channel_id FROM messages").fetchall()
|
||||
conn.close()
|
||||
return [parse_snowflake(r[0]) for r in rows if parse_snowflake(r[0])]
|
||||
return self.db.get_backed_up_channel_ids()
|
||||
|
||||
async def get_channel(self, channel_id: int) -> BackupChannel | BackupThread | None:
|
||||
for c in self.channels:
|
||||
|
|
@ -1351,23 +1361,9 @@ class BackupReader:
|
|||
async def get_message(self, channel_id: int, message_id: int) -> BackupMessage | None:
|
||||
"""Fetch a specific message from SQLite."""
|
||||
if not self.db: return None
|
||||
import sqlite3
|
||||
conn = sqlite3.connect(self.db.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
row = conn.execute("SELECT * FROM messages WHERE id = ?", (str(message_id),)).fetchone()
|
||||
if row:
|
||||
data = dict(row)
|
||||
# Fetch attachments
|
||||
atts = conn.execute("SELECT * FROM attachments WHERE message_id = ?", (str(message_id),)).fetchall()
|
||||
data["attachments"] = [dict(a) for a in atts]
|
||||
|
||||
# Fetch stickers
|
||||
sts = conn.execute("SELECT * FROM message_stickers WHERE message_id = ?", (str(message_id),)).fetchall()
|
||||
data["stickers"] = [dict(s) for s in sts]
|
||||
|
||||
conn.close()
|
||||
data = self.db.get_message_with_relations(message_id)
|
||||
if data:
|
||||
return self._hydrate_message(data)
|
||||
conn.close()
|
||||
return None
|
||||
|
||||
async def get_first_message(self, channel_id: int) -> BackupMessage | None:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue