improve thread safety for backup db
This commit is contained in:
parent
0cb678b848
commit
73d52d2183
2 changed files with 72 additions and 22 deletions
|
|
@ -114,7 +114,7 @@ class BackupDatabase:
|
||||||
elif table == "forum_tags":
|
elif table == "forum_tags":
|
||||||
conn.execute("CREATE TABLE forum_tags (id INTEGER PRIMARY KEY, forum_id INTEGER, name TEXT, moderated INTEGER, emoji_id INTEGER, emoji_name TEXT)")
|
conn.execute("CREATE TABLE forum_tags (id INTEGER PRIMARY KEY, forum_id INTEGER, name TEXT, moderated INTEGER, emoji_id INTEGER, emoji_name TEXT)")
|
||||||
elif table == "server_assets":
|
elif table == "server_assets":
|
||||||
conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type INTEGER)")
|
conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type TEXT)")
|
||||||
|
|
||||||
old_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table}_old)").fetchall()]
|
old_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table}_old)").fetchall()]
|
||||||
new_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table})").fetchall()]
|
new_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table})").fetchall()]
|
||||||
|
|
@ -945,6 +945,60 @@ class BackupDatabase:
|
||||||
|
|
||||||
return purged_count
|
return purged_count
|
||||||
|
|
||||||
|
def get_backed_up_channel_ids(self) -> List[int]:
|
||||||
|
"""Returns a list of distinct channel IDs that have messages in the database."""
|
||||||
|
with self._lock:
|
||||||
|
rows = self._conn.execute("SELECT DISTINCT channel_id FROM messages").fetchall()
|
||||||
|
return [parse_snowflake(r[0]) for r in rows if parse_snowflake(r[0])]
|
||||||
|
|
||||||
|
def get_message_with_relations(self, message_id) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Fetches a single message with its attachments, embeds, reactions, and stickers."""
|
||||||
|
with self._lock:
|
||||||
|
mid = parse_snowflake(message_id)
|
||||||
|
row = self._conn.execute("SELECT * FROM messages WHERE id = ?", (mid,)).fetchone()
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
data = dict(row)
|
||||||
|
|
||||||
|
# Attachments
|
||||||
|
atts = self._conn.execute("SELECT * FROM attachments WHERE message_id = ?", (mid,)).fetchall()
|
||||||
|
data["attachments"] = [dict(a) for a in atts]
|
||||||
|
|
||||||
|
# Embeds
|
||||||
|
embs = self._conn.execute("SELECT * FROM embeds WHERE message_id = ?", (mid,)).fetchall()
|
||||||
|
data["embeds"] = []
|
||||||
|
for er in embs:
|
||||||
|
e_dict = {
|
||||||
|
"title": er["title"],
|
||||||
|
"description": er["description"],
|
||||||
|
"url": er["url"],
|
||||||
|
"color": er["color"],
|
||||||
|
"timestamp": er["timestamp"],
|
||||||
|
"thumbnail": {"url": er["thumbnail_url"]} if er["thumbnail_url"] else None,
|
||||||
|
"image": {"url": er["image_url"]} if er["image_url"] else None,
|
||||||
|
"author": {
|
||||||
|
"name": er["author_name"],
|
||||||
|
"url": er["author_url"],
|
||||||
|
"icon_url": er["author_icon_url"]
|
||||||
|
} if er["author_name"] else None,
|
||||||
|
"footer": {
|
||||||
|
"text": er["footer_text"],
|
||||||
|
"icon_url": er["footer_icon_url"]
|
||||||
|
} if er["footer_text"] else None,
|
||||||
|
"fields": json.loads(er["fields"]) if er["fields"] else []
|
||||||
|
}
|
||||||
|
data["embeds"].append(e_dict)
|
||||||
|
|
||||||
|
# Reactions
|
||||||
|
reas = self._conn.execute("SELECT * FROM reactions WHERE message_id = ?", (mid,)).fetchall()
|
||||||
|
data["reactions"] = [dict(r) for r in reas]
|
||||||
|
|
||||||
|
# Stickers
|
||||||
|
sts = self._conn.execute("SELECT * FROM message_stickers WHERE message_id = ?", (mid,)).fetchall()
|
||||||
|
data["stickers"] = [dict(s) for s in sts]
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Commits any pending writes and closes the connection."""
|
"""Commits any pending writes and closes the connection."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
|
|
||||||
|
|
@ -393,6 +393,20 @@ class BackupMember:
|
||||||
# Fallback for unexpected data format
|
# Fallback for unexpected data format
|
||||||
self.id = 0
|
self.id = 0
|
||||||
self.name = "Unknown"
|
self.name = "Unknown"
|
||||||
|
self.display_name = "Unknown"
|
||||||
|
self.global_name = "Unknown"
|
||||||
|
self.bot = False
|
||||||
|
self.system = False
|
||||||
|
self.discriminator = "0000"
|
||||||
|
self.color = BackupColor(0)
|
||||||
|
self.roles = sorted(role_objects or [], key=lambda r: r.position, reverse=True)
|
||||||
|
self.guild_permissions = BackupPermissions(0)
|
||||||
|
self.created_at = datetime.now(timezone.utc)
|
||||||
|
self.joined_at = datetime.now(timezone.utc)
|
||||||
|
self.status = type("Status", (), {"value": "offline"})()
|
||||||
|
self.activity = None
|
||||||
|
self._avatar_url = None
|
||||||
|
self.avatar = BackupAsset(None)
|
||||||
return
|
return
|
||||||
self.id = parse_snowflake(data["id"])
|
self.id = parse_snowflake(data["id"])
|
||||||
self.name = data.get("username", "Unknown")
|
self.name = data.get("username", "Unknown")
|
||||||
|
|
@ -1266,11 +1280,7 @@ class BackupReader:
|
||||||
async def get_backed_up_channel_ids(self) -> List[int]:
|
async def get_backed_up_channel_ids(self) -> List[int]:
|
||||||
"""Returns a list of channel IDs that have messages in the database."""
|
"""Returns a list of channel IDs that have messages in the database."""
|
||||||
if not self.db: return []
|
if not self.db: return []
|
||||||
import sqlite3
|
return self.db.get_backed_up_channel_ids()
|
||||||
conn = sqlite3.connect(self.db.db_path)
|
|
||||||
rows = conn.execute("SELECT DISTINCT channel_id FROM messages").fetchall()
|
|
||||||
conn.close()
|
|
||||||
return [parse_snowflake(r[0]) for r in rows if parse_snowflake(r[0])]
|
|
||||||
|
|
||||||
async def get_channel(self, channel_id: int) -> BackupChannel | BackupThread | None:
|
async def get_channel(self, channel_id: int) -> BackupChannel | BackupThread | None:
|
||||||
for c in self.channels:
|
for c in self.channels:
|
||||||
|
|
@ -1351,23 +1361,9 @@ class BackupReader:
|
||||||
async def get_message(self, channel_id: int, message_id: int) -> BackupMessage | None:
|
async def get_message(self, channel_id: int, message_id: int) -> BackupMessage | None:
|
||||||
"""Fetch a specific message from SQLite."""
|
"""Fetch a specific message from SQLite."""
|
||||||
if not self.db: return None
|
if not self.db: return None
|
||||||
import sqlite3
|
data = self.db.get_message_with_relations(message_id)
|
||||||
conn = sqlite3.connect(self.db.db_path)
|
if data:
|
||||||
conn.row_factory = sqlite3.Row
|
|
||||||
row = conn.execute("SELECT * FROM messages WHERE id = ?", (str(message_id),)).fetchone()
|
|
||||||
if row:
|
|
||||||
data = dict(row)
|
|
||||||
# Fetch attachments
|
|
||||||
atts = conn.execute("SELECT * FROM attachments WHERE message_id = ?", (str(message_id),)).fetchall()
|
|
||||||
data["attachments"] = [dict(a) for a in atts]
|
|
||||||
|
|
||||||
# Fetch stickers
|
|
||||||
sts = conn.execute("SELECT * FROM message_stickers WHERE message_id = ?", (str(message_id),)).fetchall()
|
|
||||||
data["stickers"] = [dict(s) for s in sts]
|
|
||||||
|
|
||||||
conn.close()
|
|
||||||
return self._hydrate_message(data)
|
return self._hydrate_message(data)
|
||||||
conn.close()
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_first_message(self, channel_id: int) -> BackupMessage | None:
|
async def get_first_message(self, channel_id: int) -> BackupMessage | None:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue