optimize backup speed

This commit is contained in:
rambros 2026-03-25 13:45:17 +05:30
parent 4b24b29c03
commit cefa477459
2 changed files with 287 additions and 302 deletions

View file

@ -25,17 +25,18 @@ class BackupDatabase:
def __init__(self, db_path: Path | str):
    """Open the backup database and keep one connection for the object's lifetime.

    Args:
        db_path: Location of the SQLite file.

    Raises:
        sqlite3.Error: If the database cannot be opened or configured.
    """
    self.db_path = Path(db_path)
    self._lock = threading.Lock()
    # A single shared connection; every access elsewhere is serialised by
    # self._lock, hence check_same_thread=False.
    self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
    try:
        self._conn.row_factory = sqlite3.Row
        # WAL mode allows concurrent readers and batches disk flushes significantly
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.execute("PRAGMA synchronous=NORMAL")
        self._conn.execute("PRAGMA cache_size=-32000")  # 32 MB page cache
        self._init_db()
    except Exception:
        # Do not leak the open handle when configuration or schema setup fails.
        self._conn.close()
        raise
# Opens a brand-new SQLite connection to self.db_path (usable across threads)
# with rows returned as sqlite3.Row, and returns it without closing it.
# NOTE(review): the hunk line counts suggest this helper is deleted by this
# commit in favour of the persistent self._conn — confirm against the real file.
def _get_conn(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
conn.row_factory = sqlite3.Row
return conn
def _init_db(self):
"""Initializes the database schema."""
with self._lock:
conn = self._get_conn()
conn = self._conn
try:
# Guild Profile
conn.execute("""
@ -242,12 +243,11 @@ class BackupDatabase:
conn.commit()
finally:
conn.close()
pass # persistent connection — do not close
def set_guild_profile(self, data: Dict[str, Any]):
with self._lock:
conn = self._get_conn()
conn.execute("""
self._conn.execute("""
INSERT OR REPLACE INTO guild_profile (id, name, description, icon_file, icon_url, banner_file, banner_url, owner_id, last_backup, ignore_channels)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
@ -257,14 +257,11 @@ class BackupDatabase:
str(data.get("owner_id")),
data.get("last_backup"), json.dumps(data.get("ignore_channels", []))
))
conn.commit()
conn.close()
self._conn.commit()
def get_guild_profile(self) -> Optional[Dict[str, Any]]:
with self._lock:
conn = self._get_conn()
row = conn.execute("SELECT * FROM guild_profile LIMIT 1").fetchone()
conn.close()
row = self._conn.execute("SELECT * FROM guild_profile LIMIT 1").fetchone()
if row:
data = dict(row)
if data.get("ignore_channels"):
@ -276,11 +273,8 @@ class BackupDatabase:
def save_roles(self, roles: List[Dict[str, Any]]):
with self._lock:
conn = self._get_conn()
# Ensure complex fields are strings if they aren't already
formatted = []
for r in roles:
formatted.append({
formatted = [
{
"id": str(r["id"]),
"name": r["name"],
"color": r["color"],
@ -288,225 +282,193 @@ class BackupDatabase:
"permissions": str(r["permissions"]),
"hoist": 1 if r["hoist"] else 0,
"mentionable": 1 if r["mentionable"] else 0
})
conn.executemany("""
}
for r in roles
]
self._conn.executemany("""
INSERT OR REPLACE INTO roles (id, name, color, position, permissions, hoist, mentionable)
VALUES (:id, :name, :color, :position, :permissions, :hoist, :mentionable)
""", formatted)
conn.commit()
conn.close()
self._conn.commit()
def save_channels(self, channels: List[Dict[str, Any]]):
    """Insert or replace a batch of channel rows.

    Args:
        channels: Dicts providing id, name, type, position, category_id,
            topic, nsfw, bitrate and slowmode_delay.

    Raises:
        sqlite3.Error: If the batch cannot be written; nothing is persisted.
    """
    # The connection context manager commits on success and rolls back on
    # error, so a failed batch cannot linger in the shared connection's open
    # transaction and be committed by a later call.
    with self._lock, self._conn:
        self._conn.executemany("""
            INSERT OR REPLACE INTO channels (id, name, type, position, category_id, topic, nsfw, bitrate, slowmode_delay)
            VALUES (:id, :name, :type, :position, :category_id, :topic, :nsfw, :bitrate, :slowmode_delay)
        """, channels)
def save_permissions(self, permissions: List[Dict[str, Any]]):
    """Save a batch of channel permission overwrites.

    Rows are plain INSERTs (no replace), so calling this twice with the same
    data stores the overwrites twice.

    Args:
        permissions: Dicts providing channel_id, target_id, target_type,
            allow and deny.

    Raises:
        sqlite3.Error: If the batch cannot be written; nothing is persisted.
    """
    # Commit on success, roll back on error (shared persistent connection).
    with self._lock, self._conn:
        self._conn.executemany("""
            INSERT INTO permissions (channel_id, target_id, target_type, allow, deny)
            VALUES (:channel_id, :target_id, :target_type, :allow, :deny)
        """, permissions)
def save_users(self, users: List[Dict[str, Any]]):
    """Save users to the author cache, replacing rows with the same id.

    Args:
        users: Dicts providing id, username, display_name, avatar_file,
            avatar_url and roles.

    Raises:
        sqlite3.Error: If the batch cannot be written; nothing is persisted.
    """
    # Commit on success, roll back on error (shared persistent connection).
    with self._lock, self._conn:
        self._conn.executemany("""
            INSERT OR REPLACE INTO users (id, username, display_name, avatar_file, avatar_url, roles)
            VALUES (:id, :username, :display_name, :avatar_file, :avatar_url, :roles)
        """, users)
def save_server_assets(self, assets: List[Dict[str, Any]]):
    """Save a batch of server assets (emojis, stickers) to the database.

    Args:
        assets: Dicts with a mandatory "id" and optional name, type,
            filename, url and mime_type (missing keys are stored as NULL).

    Raises:
        KeyError: If an asset has no "id".
        sqlite3.Error: If the batch cannot be written; nothing is persisted.
    """
    # Normalise outside the lock: ids become strings, optional keys default to None.
    formatted = [
        {
            "id": str(a["id"]),
            "name": a.get("name"),
            "type": a.get("type"),
            "filename": a.get("filename"),
            "url": a.get("url"),
            "mime_type": a.get("mime_type"),
        }
        for a in assets
    ]
    # Commit on success, roll back on error (shared persistent connection).
    with self._lock, self._conn:
        self._conn.executemany("""
            INSERT OR REPLACE INTO server_assets (id, name, type, filename, url, mime_type)
            VALUES (:id, :name, :type, :filename, :url, :mime_type)
        """, formatted)
def save_threads(self, threads: List[Dict[str, Any]]):
    """Save thread metadata, replacing rows with the same id.

    Args:
        threads: Dicts providing id, name, type, parent_id, message_count,
            member_count, archived, archive_timestamp,
            auto_archive_duration, locked and applied_tags.

    Raises:
        sqlite3.Error: If the batch cannot be written; nothing is persisted.
    """
    # Commit on success, roll back on error (shared persistent connection).
    with self._lock, self._conn:
        self._conn.executemany("""
            INSERT OR REPLACE INTO threads (id, name, type, parent_id, message_count, member_count, archived, archive_timestamp, auto_archive_duration, locked, applied_tags)
            VALUES (:id, :name, :type, :parent_id, :message_count, :member_count, :archived, :archive_timestamp, :auto_archive_duration, :locked, :applied_tags)
        """, threads)
def save_forum_tags(self, tags: List[Dict[str, Any]]):
    """Save forum tag definitions, replacing rows with the same id.

    Args:
        tags: Dicts providing id, forum_id, name, moderated, emoji_id and
            emoji_name.

    Raises:
        sqlite3.Error: If the batch cannot be written; nothing is persisted.
    """
    # Commit on success, roll back on error (shared persistent connection).
    with self._lock, self._conn:
        self._conn.executemany("""
            INSERT OR REPLACE INTO forum_tags (id, forum_id, name, moderated, emoji_id, emoji_name)
            VALUES (:id, :forum_id, :name, :moderated, :emoji_id, :emoji_name)
        """, tags)
def save_messages_batch(self, messages: List[Dict[str, Any]]):
    """Batch-insert messages with their attachments, embeds, reactions and stickers.

    The whole batch is one transaction: it is committed when every statement
    succeeds and rolled back otherwise. Because the connection is shared and
    long-lived, the commit/rollback also applies to any writes left pending
    on it by earlier calls that did not commit themselves.

    Args:
        messages: Message dicts carrying the ``messages`` table columns and,
            optionally, "attachments", "reactions", "stickers" and "embeds"
            lists. The dicts (and their attachment dicts) are not modified.

    Raises:
        KeyError: If a message lacks "id" or a sticker lacks "id".
        sqlite3.Error: If any statement fails; nothing from the batch is kept.
    """
    def _nested(embed, part, field):
        # Embed sub-objects (thumbnail/image/author/footer) are only read
        # when they really are dicts; anything else yields NULL.
        section = embed.get(part)
        return section.get(field) if isinstance(section, dict) else None

    # Flatten child rows before taking the lock so it is held only for SQL.
    attachment_rows = []
    embed_rows = []
    reaction_rows = []
    sticker_rows = []
    for msg in messages:
        msg_id = msg["id"]
        for att in msg.get("attachments") or ():
            # Copy instead of writing message_id into the caller's dict.
            attachment_rows.append({**att, "message_id": msg_id})
        for rea in msg.get("reactions") or ():
            reaction_rows.append({
                "message_id": msg_id,
                "emoji_id": str(rea["emoji_id"]) if rea.get("emoji_id") else None,
                "emoji_name": rea.get("emoji_name"),
                "count": rea.get("count", 0),
            })
        for st in msg.get("stickers") or ():
            sticker_rows.append({
                "message_id": msg_id,
                "sticker_id": str(st["id"]),
                "name": st.get("name"),
                "url": st.get("url"),
                "format_type": st.get("format_type"),
                "local_hash": st.get("local_hash"),
            })
        for emb in msg.get("embeds") or ():
            embed_rows.append((
                msg_id, emb.get("title"), emb.get("description"), emb.get("url"),
                emb.get("color"), emb.get("timestamp"),
                _nested(emb, "thumbnail", "url"),
                _nested(emb, "image", "url"),
                _nested(emb, "author", "name"),
                _nested(emb, "author", "url"),
                _nested(emb, "author", "icon_url"),
                _nested(emb, "footer", "text"),
                _nested(emb, "footer", "icon_url"),
                json.dumps(emb.get("fields", [])),
            ))

    with self._lock, self._conn:
        conn = self._conn
        conn.executemany("""
            INSERT OR REPLACE INTO messages (id, channel_id, author_id, content, timestamp, type, message_reference, is_pinned, extra_data)
            VALUES (:id, :channel_id, :author_id, :content, :timestamp, :type, :message_reference, :is_pinned, :extra_data)
        """, messages)

        if attachment_rows:
            conn.executemany("""
                INSERT OR REPLACE INTO attachments (id, message_id, filename, size, url, content_type, local_hash)
                VALUES (:id, :message_id, :filename, :size, :url, :content_type, :local_hash)
            """, attachment_rows)

        # Embeds are normalised into columns, with "fields" kept as JSON.
        if embed_rows:
            conn.executemany("""
                INSERT INTO embeds (
                    message_id, title, description, url, color, timestamp,
                    thumbnail_url, image_url, author_name, author_url,
                    author_icon_url, footer_text, footer_icon_url, fields
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, embed_rows)

        if reaction_rows:
            conn.executemany("""
                INSERT INTO reactions (message_id, emoji_id, emoji_name, count)
                VALUES (:message_id, :emoji_id, :emoji_name, :count)
            """, reaction_rows)

        if sticker_rows:
            conn.executemany("""
                INSERT OR REPLACE INTO message_stickers (message_id, sticker_id, name, url, format_type, local_hash)
                VALUES (:message_id, :sticker_id, :name, :url, :format_type, :local_hash)
            """, sticker_rows)
def get_last_message_id(self, channel_id: str) -> Optional[str]:
    """Return the greatest stored message id for a channel, or None if it has none.

    "Greatest" follows the database's ordering of the ``id`` column
    (``ORDER BY id DESC``).

    Args:
        channel_id: Channel identifier; converted with ``str()`` before lookup.
    """
    with self._lock:
        row = self._conn.execute(
            "SELECT id FROM messages WHERE channel_id = ? ORDER BY id DESC LIMIT 1",
            (str(channel_id),),
        ).fetchone()
    return row["id"] if row else None
def get_media_by_hash(self, file_hash: str) -> Optional[Dict[str, Any]]:
    """Return the media_pool row with the given content hash as a dict, or None."""
    with self._lock:
        row = self._conn.execute(
            "SELECT * FROM media_pool WHERE hash = ?", (file_hash,)
        ).fetchone()
    return dict(row) if row else None
def get_media_by_url(self, url: str) -> Optional[Dict[str, Any]]:
    """Return the media_pool row whose first_seen_url equals ``url``, or None."""
    with self._lock:
        row = self._conn.execute(
            "SELECT * FROM media_pool WHERE first_seen_url = ?", (url,)
        ).fetchone()
    return dict(row) if row else None
def add_media_to_pool(self, file_hash: str, local_path: str, size: int, mime_type: str, url: str):
    """Register a stored media file in the pool, replacing any row with the same hash.

    NOTE: No commit here — the row stays in the shared connection's open
    transaction and only becomes durable when a later call commits on that
    connection (save_messages_batch does so at the end of each batch).

    Args:
        file_hash: Content hash used as the pool key.
        local_path: Where the file was stored; saved as ``str(local_path)``.
        size: File size to record.
        mime_type: Content type to record.
        url: URL the file was first seen at.
    """
    with self._lock:
        self._conn.execute("""
            INSERT OR REPLACE INTO media_pool (hash, local_path, size, mime_type, first_seen_url)
            VALUES (?, ?, ?, ?, ?)
        """, (file_hash, str(local_path), size, mime_type, url))
def get_stats_by_channel(self) -> Dict[int, Dict[str, Any]]:
"""Returns aggregate stats for all channels with backups."""
with self._lock:
conn = self._get_conn()
# Summary of messages per effective channel (including threads rolled up)
msg_rows = conn.execute("""
msg_rows = self._conn.execute("""
SELECT
COALESCE(t.parent_id, m.channel_id) as channel_id,
COUNT(m.id) as msg_count
@ -515,15 +477,13 @@ class BackupDatabase:
GROUP BY channel_id
""").fetchall()
# Thread counts per parent
thread_rows = conn.execute("""
thread_rows = self._conn.execute("""
SELECT parent_id, COUNT(*) as thread_count
FROM threads
GROUP BY parent_id
""").fetchall()
# Summary of attachments per effective channel (including threads rolled up)
att_rows = conn.execute("""
att_rows = self._conn.execute("""
SELECT
COALESCE(t.parent_id, m.channel_id) as channel_id,
COUNT(a.id) as att_count,
@ -534,8 +494,6 @@ class BackupDatabase:
GROUP BY channel_id
""").fetchall()
conn.close()
stats = {}
for r in msg_rows:
cid = parse_snowflake(r["channel_id"])
@ -566,22 +524,18 @@ class BackupDatabase:
def get_all_roles(self) -> List[Dict[str, Any]]:
    """Return every stored role as a dict, highest position first."""
    with self._lock:
        rows = self._conn.execute("SELECT * FROM roles ORDER BY position DESC").fetchall()
    return [dict(r) for r in rows]
def get_all_channels(self) -> List[Dict[str, Any]]:
with self._lock:
conn = self._get_conn()
rows = conn.execute("SELECT * FROM channels ORDER BY position ASC").fetchall()
rows = self._conn.execute("SELECT * FROM channels ORDER BY position ASC").fetchall()
# Fetch all permissions
chan_list = [dict(r) for r in rows]
if chan_list:
ids = [c["id"] for c in chan_list]
placeholders = ",".join(["?"] * len(ids))
perm_rows = conn.execute(f"SELECT * FROM permissions WHERE channel_id IN ({placeholders})", ids).fetchall()
perm_rows = self._conn.execute(f"SELECT * FROM permissions WHERE channel_id IN ({placeholders})", ids).fetchall()
perms_by_chan = {}
for pr in perm_rows:
@ -597,8 +551,7 @@ class BackupDatabase:
for c in chan_list:
c["overwrites"] = perms_by_chan.get(c["id"], [])
# Fetch Forum Tags
tag_rows = conn.execute(f"SELECT * FROM forum_tags WHERE forum_id IN ({placeholders})", ids).fetchall()
tag_rows = self._conn.execute(f"SELECT * FROM forum_tags WHERE forum_id IN ({placeholders})", ids).fetchall()
tags_by_forum = {}
for tr in tag_rows:
fid = tr["forum_id"]
@ -608,56 +561,43 @@ class BackupDatabase:
for c in chan_list:
c["available_tags"] = tags_by_forum.get(c["id"], [])
conn.close()
return chan_list
def get_all_threads(self) -> List[Dict[str, Any]]:
    """Return metadata for all threads in the backup, one dict per row."""
    with self._lock:
        rows = self._conn.execute("SELECT * FROM threads").fetchall()
    return [dict(r) for r in rows]
def get_forum_tags(self, forum_id: Optional[str] = None) -> List[Dict[str, Any]]:
    """Return forum tag definitions.

    Args:
        forum_id: When truthy, only tags of that forum are returned (the
            value is converted with ``str()``); otherwise all tags.
    """
    with self._lock:
        if forum_id:
            rows = self._conn.execute(
                "SELECT * FROM forum_tags WHERE forum_id = ?", (str(forum_id),)
            ).fetchall()
        else:
            rows = self._conn.execute("SELECT * FROM forum_tags").fetchall()
    return [dict(r) for r in rows]
def get_threads_by_parent(self, parent_id: str) -> List[Dict[str, Any]]:
    """Return all threads whose parent_id equals ``str(parent_id)``."""
    with self._lock:
        rows = self._conn.execute(
            "SELECT * FROM threads WHERE parent_id = ?", (str(parent_id),)
        ).fetchall()
    return [dict(r) for r in rows]
def get_thread(self, thread_id: str) -> Optional[Dict[str, Any]]:
    """Return a single thread's metadata as a dict, or None if it is not stored."""
    with self._lock:
        row = self._conn.execute(
            "SELECT * FROM threads WHERE id = ?", (str(thread_id),)
        ).fetchone()
    return dict(row) if row else None
def get_all_users(self) -> List[Dict[str, Any]]:
    """Return every row of the users table as a dict."""
    with self._lock:
        rows = self._conn.execute("SELECT * FROM users").fetchall()
    return [dict(r) for r in rows]
def get_user(self, user_id: str) -> Optional[Dict[str, Any]]:
with self._lock:
conn = self._get_conn()
row = conn.execute("SELECT * FROM users WHERE id = ?", (str(user_id),)).fetchone()
conn.close()
row = self._conn.execute("SELECT * FROM users WHERE id = ?", (str(user_id),)).fetchone()
if row:
data = dict(row)
if data.get("roles"):
@ -668,25 +608,20 @@ class BackupDatabase:
def get_server_assets(self, asset_type: Optional[str] = None) -> List[Dict[str, Any]]:
    """Return all server assets, optionally filtered by type.

    Args:
        asset_type: When truthy, only assets whose ``type`` column equals
            this value are returned; otherwise every asset.
    """
    with self._lock:
        if asset_type:
            rows = self._conn.execute(
                "SELECT * FROM server_assets WHERE type = ?", (asset_type,)
            ).fetchall()
        else:
            rows = self._conn.execute("SELECT * FROM server_assets").fetchall()
    return [dict(r) for r in rows]
def get_all_media(self) -> Dict[str, Dict[str, Any]]:
    """Return the entire media pool as a dictionary indexed by hash."""
    with self._lock:
        rows = self._conn.execute("SELECT * FROM media_pool").fetchall()
    return {r["hash"]: dict(r) for r in rows}
def get_messages_paged(self, channel_id: str, limit: int = 100, offset: int = 0, after_id: Optional[str] = None) -> List[Dict[str, Any]]:
with self._lock:
conn = self._get_conn()
query = "SELECT * FROM messages WHERE channel_id = ?"
params = [str(channel_id)]
@ -697,24 +632,21 @@ class BackupDatabase:
query += " ORDER BY id ASC LIMIT ? OFFSET ?"
params.extend([limit, offset])
rows = conn.execute(query, params).fetchall()
rows = self._conn.execute(query, params).fetchall()
msg_list = [dict(r) for r in rows]
if msg_list:
# Fetch attachments for these messages
msg_ids = [m["id"] for m in msg_list]
placeholders = ",".join(["?"] * len(msg_ids))
# Attachments
att_rows = conn.execute(f"SELECT * FROM attachments WHERE message_id IN ({placeholders})", msg_ids).fetchall()
att_rows = self._conn.execute(f"SELECT * FROM attachments WHERE message_id IN ({placeholders})", msg_ids).fetchall()
atts_by_msg = {}
for ar in att_rows:
mid = ar["message_id"]
if mid not in atts_by_msg: atts_by_msg[mid] = []
atts_by_msg[mid].append(dict(ar))
# Embeds
emb_rows = conn.execute(f"SELECT * FROM embeds WHERE message_id IN ({placeholders})", msg_ids).fetchall()
emb_rows = self._conn.execute(f"SELECT * FROM embeds WHERE message_id IN ({placeholders})", msg_ids).fetchall()
embs_by_msg = {}
for er in emb_rows:
mid = er["message_id"]
@ -741,16 +673,14 @@ class BackupDatabase:
}
embs_by_msg[mid].append(e_dict)
# Reactions
rea_rows = conn.execute(f"SELECT * FROM reactions WHERE message_id IN ({placeholders})", msg_ids).fetchall()
rea_rows = self._conn.execute(f"SELECT * FROM reactions WHERE message_id IN ({placeholders})", msg_ids).fetchall()
reas_by_msg = {}
for rr in rea_rows:
mid = rr["message_id"]
if mid not in reas_by_msg: reas_by_msg[mid] = []
reas_by_msg[mid].append(dict(rr))
# Stickers (Message-specific)
st_rows = conn.execute(f"SELECT * FROM message_stickers WHERE message_id IN ({placeholders})", msg_ids).fetchall()
st_rows = self._conn.execute(f"SELECT * FROM message_stickers WHERE message_id IN ({placeholders})", msg_ids).fetchall()
sts_by_msg = {}
for sr in st_rows:
mid = sr["message_id"]
@ -764,10 +694,13 @@ class BackupDatabase:
m["reactions"] = reas_by_msg.get(m_id, [])
m["stickers"] = sts_by_msg.get(m_id, [])
conn.close()
return msg_list
def close(self):
    """Commit any pending writes and close the persistent connection.

    Best-effort and safe to call more than once: SQLite errors are ignored,
    and the connection is closed even when the final commit fails.
    """
    with self._lock:
        try:
            self._conn.commit()
        except sqlite3.Error:
            # e.g. the connection was already closed; still try to close below.
            pass
        finally:
            try:
                self._conn.close()
            except sqlite3.Error:
                pass

View file

@ -21,7 +21,10 @@ class DiscordExporter:
self.base_dir = Path(base_dir) if base_dir else Path(".")
self.is_running = True
self.db: Optional[BackupDatabase] = None
self.sticker_cache: Dict[int, bytes] = {} # Deduplicate downloads in one session
self.sticker_cache: Dict[int, bytes] = {} # Deduplicate downloads in one session
# Pending avatar downloads — flushed after each message batch to keep the
# hot message-formatting path free of HTTP latency.
self._pending_avatars: List[tuple] = [] # (user_id, avatar_asset, av_path)
async def setup(self):
"""Prepares the output directory and fetches server metadata."""
@ -379,6 +382,9 @@ class DiscordExporter:
accumulated_files += len(m["attachments"])
# Persist to DB
# Flush deferred avatar downloads before persisting this batch
await self._flush_pending_avatars()
if self.db:
if batch_users: self.db.save_users(batch_users)
self.db.save_messages_batch(batch_messages)
@ -393,32 +399,35 @@ class DiscordExporter:
batch_users.clear()
batch_raw.clear()
if batch_raw and self.is_running:
results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
for m_data, u_list in results:
batch_messages.append(m_data)
if u_list: batch_users.extend(u_list)
# Process any remaining messages that didn't fill a full batch
if batch_raw and self.is_running:
results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
for m_data, u_list in results:
batch_messages.append(m_data)
if u_list: batch_users.extend(u_list)
new_count += len(batch_messages)
accumulated_count += len(batch_messages)
new_count += len(batch_messages)
accumulated_count += len(batch_messages)
for m in batch_messages:
if "attachments" in m:
accumulated_files += len(m["attachments"])
for m in batch_messages:
if "attachments" in m:
accumulated_files += len(m["attachments"])
# Flush deferred avatar downloads before persisting this batch
await self._flush_pending_avatars()
if self.db:
if batch_users: self.db.save_users(batch_users)
self.db.save_messages_batch(batch_messages)
if self.db:
if batch_users: self.db.save_users(batch_users)
self.db.save_messages_batch(batch_messages)
if progress_callback:
last_msg = batch_raw[-1]
author_name = getattr(last_msg.author, "display_name", "Unknown")
await progress_callback(channel_name, accumulated_count, author_name=author_name, thread_count=accumulated_threads, file_count=accumulated_files)
if progress_callback:
last_msg = batch_raw[-1]
author_name = getattr(last_msg.author, "display_name", "Unknown")
await progress_callback(channel_name, accumulated_count, author_name=author_name, thread_count=accumulated_threads, file_count=accumulated_files)
batch_messages.clear()
batch_users.clear()
batch_raw.clear()
batch_messages.clear()
batch_users.clear()
batch_raw.clear()
except discord.Forbidden:
logger.error(f"403 Forbidden: Missing Access to read messages in {channel_name} ({channel_id})")
@ -431,22 +440,24 @@ class DiscordExporter:
return accumulated_count, accumulated_threads, accumulated_files
async def _format_user(self, user):
"""Formats user data for the author or a mention."""
"""Formats user data for the author or a mention.
Avatar downloads are intentionally deferred to keep this off the hot
message-formatting path. Call _flush_pending_avatars() after each batch.
"""
user_id = str(user.id)
if user_id in self.user_cache:
return None
# New user discovered
# New user discovered — schedule avatar download but don't block here
avatar_file = None
if user.avatar:
try:
av_name = f"{user_id}.png"
av_target = self.users_path / av_name
if not av_target.exists():
await user.avatar.save(av_target)
avatar_file = f"users/{av_name}"
except Exception as e:
logger.error(f"Failed to save avatar for {user.name}: {e}")
av_name = f"{user_id}.png"
av_target = self.users_path / av_name
avatar_file = f"users/{av_name}"
if not av_target.exists():
# Queue for deferred download
self._pending_avatars.append((user_id, user.avatar, av_target))
roles = []
if hasattr(user, "roles"):
@ -463,6 +474,21 @@ class DiscordExporter:
self.user_cache[user_id] = user_data
return user_data
async def _flush_pending_avatars(self):
    """Save every queued avatar concurrently and leave the queue empty.

    Each queue entry is a (user_id, avatar_asset, target_path) tuple. A
    failed download is logged and does not stop the others.
    """
    if not self._pending_avatars:
        return

    async def _download(user_id, asset, destination):
        try:
            await asset.save(destination)
        except Exception as e:
            logger.error(f"Failed to save avatar for user {user_id}: {e}")

    # Snapshot first, then empty the shared list in place, so entries queued
    # while the downloads run are kept for the next flush.
    queued = list(self._pending_avatars)
    self._pending_avatars.clear()
    await asyncio.gather(*(_download(*entry) for entry in queued))
async def _format_message(self, msg):
"""Formats a single message and its author for DB storage."""
new_users = []
@ -478,11 +504,11 @@ class DiscordExporter:
if u_ment: new_users.append(u_ment)
# 2. Attachments handling (Content-Addressable Storage)
# ... (rest of the logic remains same, just updating m_data)
# All attachments in a message are downloaded concurrently.
attachments = []
if msg.attachments:
for att in msg.attachments:
att_data = await self._process_media(
att_tasks = [
self._process_media(
media_id=att.id,
url=att.url,
filename=att.filename,
@ -490,8 +516,10 @@ class DiscordExporter:
content_type=att.content_type,
save_method=att.save
)
if att_data:
attachments.append(att_data)
for att in msg.attachments
]
att_results = await asyncio.gather(*att_tasks)
attachments = [r for r in att_results if r]
# 2.5 Stickers handling
stickers = []
@ -556,16 +584,20 @@ class DiscordExporter:
if snapshot.content:
content = snapshot.content
for s_att in snapshot.attachments:
att_res = await self._process_media(
media_id=s_att.id,
url=s_att.url,
filename=s_att.filename,
size=s_att.size,
content_type=s_att.content_type,
save_method=s_att.save
)
if att_res: attachments.append(att_res)
if snapshot.attachments:
snap_tasks = [
self._process_media(
media_id=s_att.id,
url=s_att.url,
filename=s_att.filename,
size=s_att.size,
content_type=s_att.content_type,
save_method=s_att.save
)
for s_att in snapshot.attachments
]
snap_results = await asyncio.gather(*snap_tasks)
attachments.extend(r for r in snap_results if r)
for s_emb in snapshot.embeds:
embeds.append(s_emb.to_dict())
@ -623,14 +655,16 @@ class DiscordExporter:
try: tmp.close()
except: pass
file_hash = self._calculate_sha256(tmp_path)
actual_size = tmp_path.stat().st_size
# Offload CPU-bound hashing and blocking file ops to the thread pool
# so we don't stall concurrent downloads on the event loop.
file_hash = await asyncio.to_thread(self._calculate_sha256, tmp_path)
actual_size = (await asyncio.to_thread(tmp_path.stat)).st_size
# Check if hash already exists in pool
if self.db:
in_pool = self.db.get_media_by_hash(file_hash)
if in_pool:
tmp_path.unlink()
await asyncio.to_thread(tmp_path.unlink)
return {
"id": str(media_id),
"filename": filename,
@ -645,7 +679,10 @@ class DiscordExporter:
target_filename = f"{file_hash}{ext}"
target_path = self.attachments_path / target_filename
shutil.move(str(tmp_path), str(target_path))
await asyncio.to_thread(shutil.move, str(tmp_path), str(target_path))
# Mark as successfully moved so finally block doesn't delete it
tmp_path = None
if self.db:
self.db.add_media_to_pool(file_hash, f"attachments/{target_filename}", actual_size, content_type, str(url))
@ -658,9 +695,14 @@ class DiscordExporter:
"content_type": content_type,
"local_hash": file_hash
}
except Exception as e:
logger.error(f"Failed to process media {filename}: {e}")
if tmp_path and tmp_path.exists(): tmp_path.unlink()
except BaseException as e:
if not isinstance(e, asyncio.CancelledError):
logger.error(f"Failed to process media {filename}: {e}")
raise
finally:
if tmp_path and tmp_path.exists():
try: tmp_path.unlink()
except: pass
return None
async def export_threads(self, channel_id: int, progress_callback=None, force=False, accumulated_count=0, accumulated_threads=0, accumulated_files=0, after_id: int | None = None):
@ -750,25 +792,35 @@ class DiscordExporter:
})
self.db.save_threads(thread_meta)
for thread in all_threads:
if not self.is_running: break
await asyncio.sleep(0)
# Export threads concurrently — semaphore limits to 5 at a time to
# avoid flooding Discord's rate limiter.
sem = asyncio.Semaphore(5)
async def _export_one_thread(thread, t_idx):
async with sem:
if not self.is_running:
return 0, 0, 0
cnt, thr, fls = await self.export_channel_messages(
thread.id,
progress_callback=progress_callback,
force=force,
accumulated_count=0,
accumulated_threads=0,
accumulated_files=0,
after_id=after_id
)
return cnt, thr, fls
if all_threads:
thread_results = await asyncio.gather(
*[_export_one_thread(t, i) for i, t in enumerate(all_threads)]
)
for cnt, thr, fls in thread_results:
accumulated_count += cnt
accumulated_threads += 1 + thr # +1 for the thread itself
accumulated_files += fls
accumulated_threads += 1
if progress_callback:
await progress_callback(channel.name, accumulated_count, thread_count=accumulated_threads, file_count=accumulated_files)
# Backup thread messages
accumulated_count, accumulated_threads, accumulated_files = await self.export_channel_messages(thread.id, progress_callback=progress_callback, force=force, accumulated_count=accumulated_count, accumulated_threads=accumulated_threads, accumulated_files=accumulated_files, after_id=after_id)
# For forums, ensure the starter message exists in the DB
if is_forum:
# starter_message is handled by export_channel_messages since it's just a message in that thread
# However we may want to mark it or store forum-specific tags
try:
# Just yield for concurrency
await asyncio.sleep(0)
except Exception as e:
logger.error(f"Error processing forum thread {thread.name}: {e}")
return accumulated_count, accumulated_threads, accumulated_files