optimize backup speed
This commit is contained in:
parent
4b24b29c03
commit
cefa477459
2 changed files with 287 additions and 302 deletions
|
|
@ -25,17 +25,18 @@ class BackupDatabase:
|
|||
def __init__(self, db_path: Path | str):
    """Open the backup database and make sure its schema exists.

    A single connection is created here and kept on the instance, together
    with the lock stored in ``self._lock``.

    Args:
        db_path: Location of the SQLite database file.
    """
    self.db_path = Path(db_path)
    self._lock = threading.Lock()

    connection = sqlite3.connect(str(self.db_path), check_same_thread=False)
    self._conn = connection
    connection.row_factory = sqlite3.Row

    tuning_pragmas = (
        # WAL mode allows concurrent readers and batches disk flushes significantly
        "PRAGMA journal_mode=WAL",
        "PRAGMA synchronous=NORMAL",
        # ~32 MB page cache (a negative cache_size is a size in KiB)
        "PRAGMA cache_size=-32000",
    )
    for pragma in tuning_pragmas:
        connection.execute(pragma)

    self._init_db()
|
||||
|
||||
def _get_conn(self):
    """Open and return a new connection to ``self.db_path``.

    The connection yields ``sqlite3.Row`` rows and has the same-thread
    check disabled. The caller is responsible for closing it.
    """
    fresh = sqlite3.connect(self.db_path, check_same_thread=False)
    fresh.row_factory = sqlite3.Row
    return fresh
|
||||
|
||||
def _init_db(self):
|
||||
"""Initializes the database schema."""
|
||||
with self._lock:
|
||||
conn = self._get_conn()
|
||||
conn = self._conn
|
||||
try:
|
||||
# Guild Profile
|
||||
conn.execute("""
|
||||
|
|
@ -242,12 +243,11 @@ class BackupDatabase:
|
|||
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
pass # persistent connection — do not close
|
||||
|
||||
def set_guild_profile(self, data: Dict[str, Any]):
|
||||
with self._lock:
|
||||
conn = self._get_conn()
|
||||
conn.execute("""
|
||||
self._conn.execute("""
|
||||
INSERT OR REPLACE INTO guild_profile (id, name, description, icon_file, icon_url, banner_file, banner_url, owner_id, last_backup, ignore_channels)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
|
|
@ -257,14 +257,11 @@ class BackupDatabase:
|
|||
str(data.get("owner_id")),
|
||||
data.get("last_backup"), json.dumps(data.get("ignore_channels", []))
|
||||
))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
self._conn.commit()
|
||||
|
||||
def get_guild_profile(self) -> Optional[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
conn = self._get_conn()
|
||||
row = conn.execute("SELECT * FROM guild_profile LIMIT 1").fetchone()
|
||||
conn.close()
|
||||
row = self._conn.execute("SELECT * FROM guild_profile LIMIT 1").fetchone()
|
||||
if row:
|
||||
data = dict(row)
|
||||
if data.get("ignore_channels"):
|
||||
|
|
@ -276,11 +273,8 @@ class BackupDatabase:
|
|||
|
||||
def save_roles(self, roles: List[Dict[str, Any]]):
|
||||
with self._lock:
|
||||
conn = self._get_conn()
|
||||
# Ensure complex fields are strings if they aren't already
|
||||
formatted = []
|
||||
for r in roles:
|
||||
formatted.append({
|
||||
formatted = [
|
||||
{
|
||||
"id": str(r["id"]),
|
||||
"name": r["name"],
|
||||
"color": r["color"],
|
||||
|
|
@ -288,102 +282,83 @@ class BackupDatabase:
|
|||
"permissions": str(r["permissions"]),
|
||||
"hoist": 1 if r["hoist"] else 0,
|
||||
"mentionable": 1 if r["mentionable"] else 0
|
||||
})
|
||||
|
||||
conn.executemany("""
|
||||
}
|
||||
for r in roles
|
||||
]
|
||||
self._conn.executemany("""
|
||||
INSERT OR REPLACE INTO roles (id, name, color, position, permissions, hoist, mentionable)
|
||||
VALUES (:id, :name, :color, :position, :permissions, :hoist, :mentionable)
|
||||
""", formatted)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
self._conn.commit()
|
||||
|
||||
def save_channels(self, channels: List[Dict[str, Any]]):
    """Insert or replace a batch of channel rows and commit.

    Args:
        channels: One mapping per channel. Each mapping must provide the
            named parameters used by the statement below (``id``, ``name``,
            ``type``, ``position``, ``category_id``, ``topic``, ``nsfw``,
            ``bitrate``, ``slowmode_delay``).
    """
    with self._lock:
        # Use the persistent connection; no per-call connection is opened or closed.
        self._conn.executemany("""
            INSERT OR REPLACE INTO channels (id, name, type, position, category_id, topic, nsfw, bitrate, slowmode_delay)
            VALUES (:id, :name, :type, :position, :category_id, :topic, :nsfw, :bitrate, :slowmode_delay)
        """, channels)
        self._conn.commit()
|
||||
|
||||
def save_permissions(self, permissions: List[Dict[str, Any]]):
    """Save a batch of channel permission overwrites and commit.

    Rows are added with a plain ``INSERT``, so repeated calls with the same
    data append further rows rather than replacing existing ones.

    Args:
        permissions: One mapping per overwrite, providing ``channel_id``,
            ``target_id``, ``target_type``, ``allow`` and ``deny``.
    """
    with self._lock:
        # Use the persistent connection; it must stay open after the call.
        self._conn.executemany("""
            INSERT INTO permissions (channel_id, target_id, target_type, allow, deny)
            VALUES (:channel_id, :target_id, :target_type, :allow, :deny)
        """, permissions)
        self._conn.commit()
|
||||
|
||||
def save_users(self, users: List[Dict[str, Any]]):
    """Save users to the author cache (insert or replace) and commit.

    Args:
        users: One mapping per user, providing ``id``, ``username``,
            ``display_name``, ``avatar_file``, ``avatar_url`` and ``roles``.
    """
    with self._lock:
        # Use the persistent connection; no per-call connection is opened or closed.
        self._conn.executemany("""
            INSERT OR REPLACE INTO users (id, username, display_name, avatar_file, avatar_url, roles)
            VALUES (:id, :username, :display_name, :avatar_file, :avatar_url, :roles)
        """, users)
        self._conn.commit()
|
||||
|
||||
def save_server_assets(self, assets: List[Dict[str, Any]]):
    """Save a batch of server assets (emojis, stickers) and commit.

    Args:
        assets: One mapping per asset. ``id`` is required and is stored as a
            string; ``name``, ``type``, ``filename``, ``url`` and
            ``mime_type`` are optional and default to ``None``.
    """
    with self._lock:
        # Normalise every asset to exactly the named parameters the
        # statement expects, so extra keys in the input are ignored.
        formatted = [
            {
                "id": str(a["id"]),
                "name": a.get("name"),
                "type": a.get("type"),
                "filename": a.get("filename"),
                "url": a.get("url"),
                "mime_type": a.get("mime_type"),
            }
            for a in assets
        ]
        self._conn.executemany("""
            INSERT OR REPLACE INTO server_assets (id, name, type, filename, url, mime_type)
            VALUES (:id, :name, :type, :filename, :url, :mime_type)
        """, formatted)
        self._conn.commit()
|
||||
|
||||
def save_threads(self, threads: List[Dict[str, Any]]):
    """Save thread metadata (insert or replace) and commit.

    Args:
        threads: One mapping per thread, providing ``id``, ``name``,
            ``type``, ``parent_id``, ``message_count``, ``member_count``,
            ``archived``, ``archive_timestamp``, ``auto_archive_duration``,
            ``locked`` and ``applied_tags``.
    """
    with self._lock:
        # Use the persistent connection; it must stay open after the call.
        self._conn.executemany("""
            INSERT OR REPLACE INTO threads (id, name, type, parent_id, message_count, member_count, archived, archive_timestamp, auto_archive_duration, locked, applied_tags)
            VALUES (:id, :name, :type, :parent_id, :message_count, :member_count, :archived, :archive_timestamp, :auto_archive_duration, :locked, :applied_tags)
        """, threads)
        self._conn.commit()
|
||||
|
||||
def save_forum_tags(self, tags: List[Dict[str, Any]]):
    """Save forum tag definitions (insert or replace) and commit.

    Args:
        tags: One mapping per tag, providing ``id``, ``forum_id``, ``name``,
            ``moderated``, ``emoji_id`` and ``emoji_name``.
    """
    with self._lock:
        # Use the persistent connection; it must stay open after the call.
        self._conn.executemany("""
            INSERT OR REPLACE INTO forum_tags (id, forum_id, name, moderated, emoji_id, emoji_name)
            VALUES (:id, :forum_id, :name, :moderated, :emoji_id, :emoji_name)
        """, tags)
        self._conn.commit()
|
||||
|
||||
def save_messages_batch(self, messages: List[Dict[str, Any]]):
|
||||
"""Batch inserts messages and their attachments."""
|
||||
with self._lock:
|
||||
conn = self._get_conn()
|
||||
try:
|
||||
conn = self._conn
|
||||
# Insert messages
|
||||
conn.executemany("""
|
||||
INSERT OR REPLACE INTO messages (id, channel_id, author_id, content, timestamp, type, message_reference, is_pinned, extra_data)
|
||||
|
|
@ -434,7 +409,6 @@ class BackupDatabase:
|
|||
for msg in messages:
|
||||
if "embeds" in msg and msg["embeds"]:
|
||||
for emb in msg["embeds"]:
|
||||
# Insert top-level embed
|
||||
conn.execute("""
|
||||
INSERT INTO embeds (
|
||||
message_id, title, description, url, color, timestamp,
|
||||
|
|
@ -467,46 +441,34 @@ class BackupDatabase:
|
|||
""", all_stickers)
|
||||
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_last_message_id(self, channel_id: str) -> Optional[str]:
    """Return the highest stored message id for a channel.

    Args:
        channel_id: Channel to look up; converted to ``str`` before binding.

    Returns:
        The ``id`` of the first row when ordered by ``id`` descending, or
        ``None`` when the channel has no stored messages.
    """
    with self._lock:
        # NOTE(review): if ``messages.id`` is stored as TEXT, ORDER BY is
        # lexicographic and ids of different lengths would mis-sort; confirm
        # against the schema.
        row = self._conn.execute(
            "SELECT id FROM messages WHERE channel_id = ? ORDER BY id DESC LIMIT 1",
            (str(channel_id),),
        ).fetchone()
        return row["id"] if row else None
|
||||
|
||||
def get_media_by_hash(self, file_hash: str) -> Optional[Dict[str, Any]]:
    """Look up a media pool entry by its content hash.

    Args:
        file_hash: Value to match against the ``hash`` column.

    Returns:
        The matching row as a dict, or ``None`` when there is no match.
    """
    with self._lock:
        row = self._conn.execute(
            "SELECT * FROM media_pool WHERE hash = ?", (file_hash,)
        ).fetchone()
        return dict(row) if row else None
|
||||
|
||||
def get_media_by_url(self, url: str) -> Optional[Dict[str, Any]]:
    """Look up a media pool entry by the URL it was first seen at.

    Args:
        url: Value to match against the ``first_seen_url`` column.

    Returns:
        The first matching row as a dict, or ``None`` when there is no match.
    """
    with self._lock:
        row = self._conn.execute(
            "SELECT * FROM media_pool WHERE first_seen_url = ?", (url,)
        ).fetchone()
        return dict(row) if row else None
|
||||
|
||||
def add_media_to_pool(self, file_hash: str, local_path: str, size: int, mime_type: str, url: str):
    """Register a stored media file in the pool (insert or replace).

    The write is left uncommitted on the persistent connection.

    Args:
        file_hash: Content hash used as the pool key.
        local_path: Where the file was stored; converted to ``str``.
        size: Value stored in the ``size`` column.
        mime_type: Value stored in the ``mime_type`` column.
        url: URL the media was first seen at.
    """
    # NOTE: No commit here — caller (save_messages_batch) commits at end of batch
    with self._lock:
        self._conn.execute("""
            INSERT OR REPLACE INTO media_pool (hash, local_path, size, mime_type, first_seen_url)
            VALUES (?, ?, ?, ?, ?)
        """, (file_hash, str(local_path), size, mime_type, url))
|
||||
|
||||
def get_stats_by_channel(self) -> Dict[int, Dict[str, Any]]:
|
||||
"""Returns aggregate stats for all channels with backups."""
|
||||
with self._lock:
|
||||
conn = self._get_conn()
|
||||
# Summary of messages per effective channel (including threads rolled up)
|
||||
msg_rows = conn.execute("""
|
||||
msg_rows = self._conn.execute("""
|
||||
SELECT
|
||||
COALESCE(t.parent_id, m.channel_id) as channel_id,
|
||||
COUNT(m.id) as msg_count
|
||||
|
|
@ -515,15 +477,13 @@ class BackupDatabase:
|
|||
GROUP BY channel_id
|
||||
""").fetchall()
|
||||
|
||||
# Thread counts per parent
|
||||
thread_rows = conn.execute("""
|
||||
thread_rows = self._conn.execute("""
|
||||
SELECT parent_id, COUNT(*) as thread_count
|
||||
FROM threads
|
||||
GROUP BY parent_id
|
||||
""").fetchall()
|
||||
|
||||
# Summary of attachments per effective channel (including threads rolled up)
|
||||
att_rows = conn.execute("""
|
||||
att_rows = self._conn.execute("""
|
||||
SELECT
|
||||
COALESCE(t.parent_id, m.channel_id) as channel_id,
|
||||
COUNT(a.id) as att_count,
|
||||
|
|
@ -534,8 +494,6 @@ class BackupDatabase:
|
|||
GROUP BY channel_id
|
||||
""").fetchall()
|
||||
|
||||
conn.close()
|
||||
|
||||
stats = {}
|
||||
for r in msg_rows:
|
||||
cid = parse_snowflake(r["channel_id"])
|
||||
|
|
@ -566,22 +524,18 @@ class BackupDatabase:
|
|||
|
||||
def get_all_roles(self) -> List[Dict[str, Any]]:
    """Return every stored role as a dict, highest ``position`` first."""
    with self._lock:
        rows = self._conn.execute("SELECT * FROM roles ORDER BY position DESC").fetchall()
        return [dict(r) for r in rows]
|
||||
|
||||
def get_all_channels(self) -> List[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
conn = self._get_conn()
|
||||
rows = conn.execute("SELECT * FROM channels ORDER BY position ASC").fetchall()
|
||||
rows = self._conn.execute("SELECT * FROM channels ORDER BY position ASC").fetchall()
|
||||
|
||||
# Fetch all permissions
|
||||
chan_list = [dict(r) for r in rows]
|
||||
if chan_list:
|
||||
ids = [c["id"] for c in chan_list]
|
||||
placeholders = ",".join(["?"] * len(ids))
|
||||
perm_rows = conn.execute(f"SELECT * FROM permissions WHERE channel_id IN ({placeholders})", ids).fetchall()
|
||||
perm_rows = self._conn.execute(f"SELECT * FROM permissions WHERE channel_id IN ({placeholders})", ids).fetchall()
|
||||
|
||||
perms_by_chan = {}
|
||||
for pr in perm_rows:
|
||||
|
|
@ -597,8 +551,7 @@ class BackupDatabase:
|
|||
for c in chan_list:
|
||||
c["overwrites"] = perms_by_chan.get(c["id"], [])
|
||||
|
||||
# Fetch Forum Tags
|
||||
tag_rows = conn.execute(f"SELECT * FROM forum_tags WHERE forum_id IN ({placeholders})", ids).fetchall()
|
||||
tag_rows = self._conn.execute(f"SELECT * FROM forum_tags WHERE forum_id IN ({placeholders})", ids).fetchall()
|
||||
tags_by_forum = {}
|
||||
for tr in tag_rows:
|
||||
fid = tr["forum_id"]
|
||||
|
|
@ -608,56 +561,43 @@ class BackupDatabase:
|
|||
for c in chan_list:
|
||||
c["available_tags"] = tags_by_forum.get(c["id"], [])
|
||||
|
||||
conn.close()
|
||||
return chan_list
|
||||
|
||||
def get_all_threads(self) -> List[Dict[str, Any]]:
    """Return metadata for all threads in the backup, one dict per row."""
    with self._lock:
        rows = self._conn.execute("SELECT * FROM threads").fetchall()
        return [dict(r) for r in rows]
|
||||
|
||||
def get_forum_tags(self, forum_id: Optional[str] = None) -> List[Dict[str, Any]]:
    """Return forum tag definitions.

    Args:
        forum_id: When truthy, only tags whose ``forum_id`` matches
            ``str(forum_id)`` are returned; otherwise all tags are returned.

    Returns:
        One dict per matching row.
    """
    with self._lock:
        if forum_id:
            rows = self._conn.execute(
                "SELECT * FROM forum_tags WHERE forum_id = ?", (str(forum_id),)
            ).fetchall()
        else:
            rows = self._conn.execute("SELECT * FROM forum_tags").fetchall()
        return [dict(r) for r in rows]
|
||||
|
||||
def get_threads_by_parent(self, parent_id: str) -> List[Dict[str, Any]]:
    """Return all threads belonging to a parent channel.

    Args:
        parent_id: Parent channel id; converted to ``str`` before binding.

    Returns:
        One dict per thread row whose ``parent_id`` matches.
    """
    with self._lock:
        rows = self._conn.execute(
            "SELECT * FROM threads WHERE parent_id = ?", (str(parent_id),)
        ).fetchall()
        return [dict(r) for r in rows]
|
||||
|
||||
def get_thread(self, thread_id: str) -> Optional[Dict[str, Any]]:
    """Retrieve a single thread's metadata.

    Args:
        thread_id: Thread id; converted to ``str`` before binding.

    Returns:
        The thread row as a dict, or ``None`` when no such thread is stored.
    """
    with self._lock:
        row = self._conn.execute(
            "SELECT * FROM threads WHERE id = ?", (str(thread_id),)
        ).fetchone()
        return dict(row) if row else None
|
||||
|
||||
def get_all_users(self) -> List[Dict[str, Any]]:
    """Return every row of the ``users`` table as a dict."""
    with self._lock:
        rows = self._conn.execute("SELECT * FROM users").fetchall()
        return [dict(r) for r in rows]
|
||||
|
||||
def get_user(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
conn = self._get_conn()
|
||||
row = conn.execute("SELECT * FROM users WHERE id = ?", (str(user_id),)).fetchone()
|
||||
conn.close()
|
||||
row = self._conn.execute("SELECT * FROM users WHERE id = ?", (str(user_id),)).fetchone()
|
||||
if row:
|
||||
data = dict(row)
|
||||
if data.get("roles"):
|
||||
|
|
@ -668,25 +608,20 @@ class BackupDatabase:
|
|||
def get_server_assets(self, asset_type: Optional[str] = None) -> List[Dict[str, Any]]:
    """Return all server assets, optionally filtered by type.

    Args:
        asset_type: When truthy, only assets whose ``type`` equals this
            value are returned; otherwise every asset is returned.

    Returns:
        One dict per matching row.
    """
    with self._lock:
        if asset_type:
            rows = self._conn.execute(
                "SELECT * FROM server_assets WHERE type = ?", (asset_type,)
            ).fetchall()
        else:
            rows = self._conn.execute("SELECT * FROM server_assets").fetchall()
        return [dict(r) for r in rows]
|
||||
|
||||
def get_all_media(self) -> Dict[str, Dict[str, Any]]:
    """Return the entire media pool as a dictionary indexed by hash."""
    with self._lock:
        rows = self._conn.execute("SELECT * FROM media_pool").fetchall()
        return {r["hash"]: dict(r) for r in rows}
|
||||
|
||||
def get_messages_paged(self, channel_id: str, limit: int = 100, offset: int = 0, after_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
conn = self._get_conn()
|
||||
query = "SELECT * FROM messages WHERE channel_id = ?"
|
||||
params = [str(channel_id)]
|
||||
|
||||
|
|
@ -697,24 +632,21 @@ class BackupDatabase:
|
|||
query += " ORDER BY id ASC LIMIT ? OFFSET ?"
|
||||
params.extend([limit, offset])
|
||||
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
rows = self._conn.execute(query, params).fetchall()
|
||||
msg_list = [dict(r) for r in rows]
|
||||
|
||||
if msg_list:
|
||||
# Fetch attachments for these messages
|
||||
msg_ids = [m["id"] for m in msg_list]
|
||||
placeholders = ",".join(["?"] * len(msg_ids))
|
||||
|
||||
# Attachments
|
||||
att_rows = conn.execute(f"SELECT * FROM attachments WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
||||
att_rows = self._conn.execute(f"SELECT * FROM attachments WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
||||
atts_by_msg = {}
|
||||
for ar in att_rows:
|
||||
mid = ar["message_id"]
|
||||
if mid not in atts_by_msg: atts_by_msg[mid] = []
|
||||
atts_by_msg[mid].append(dict(ar))
|
||||
|
||||
# Embeds
|
||||
emb_rows = conn.execute(f"SELECT * FROM embeds WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
||||
emb_rows = self._conn.execute(f"SELECT * FROM embeds WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
||||
embs_by_msg = {}
|
||||
for er in emb_rows:
|
||||
mid = er["message_id"]
|
||||
|
|
@ -741,16 +673,14 @@ class BackupDatabase:
|
|||
}
|
||||
embs_by_msg[mid].append(e_dict)
|
||||
|
||||
# Reactions
|
||||
rea_rows = conn.execute(f"SELECT * FROM reactions WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
||||
rea_rows = self._conn.execute(f"SELECT * FROM reactions WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
||||
reas_by_msg = {}
|
||||
for rr in rea_rows:
|
||||
mid = rr["message_id"]
|
||||
if mid not in reas_by_msg: reas_by_msg[mid] = []
|
||||
reas_by_msg[mid].append(dict(rr))
|
||||
|
||||
# Stickers (Message-specific)
|
||||
st_rows = conn.execute(f"SELECT * FROM message_stickers WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
||||
st_rows = self._conn.execute(f"SELECT * FROM message_stickers WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
||||
sts_by_msg = {}
|
||||
for sr in st_rows:
|
||||
mid = sr["message_id"]
|
||||
|
|
@ -764,10 +694,13 @@ class BackupDatabase:
|
|||
m["reactions"] = reas_by_msg.get(m_id, [])
|
||||
m["stickers"] = sts_by_msg.get(m_id, [])
|
||||
|
||||
conn.close()
|
||||
return msg_list
|
||||
|
||||
def close(self):
    """Commit any pending writes and close the connection.

    Best-effort: SQLite errors raised while committing or closing are
    suppressed, so calling this more than once does not raise.
    """
    with self._lock:
        try:
            self._conn.commit()
        except sqlite3.Error:
            # Nothing useful can be done at shutdown; the connection may
            # already be closed.
            pass
        finally:
            # Attempt to close the connection even if the final commit failed.
            try:
                self._conn.close()
            except sqlite3.Error:
                pass
|
||||
|
|
|
|||
|
|
@ -22,6 +22,9 @@ class DiscordExporter:
|
|||
self.is_running = True
|
||||
self.db: Optional[BackupDatabase] = None
|
||||
self.sticker_cache: Dict[int, bytes] = {} # Deduplicate downloads in one session
|
||||
# Pending avatar downloads — flushed after each message batch to keep the
|
||||
# hot message-formatting path free of HTTP latency.
|
||||
self._pending_avatars: List[tuple] = [] # (user_id, save_coroutine, av_path)
|
||||
|
||||
async def setup(self):
|
||||
"""Prepares the output directory and fetches server metadata."""
|
||||
|
|
@ -379,6 +382,9 @@ class DiscordExporter:
|
|||
accumulated_files += len(m["attachments"])
|
||||
|
||||
# Persist to DB
|
||||
# Flush deferred avatar downloads before persisting this batch
|
||||
await self._flush_pending_avatars()
|
||||
|
||||
if self.db:
|
||||
if batch_users: self.db.save_users(batch_users)
|
||||
self.db.save_messages_batch(batch_messages)
|
||||
|
|
@ -393,6 +399,7 @@ class DiscordExporter:
|
|||
batch_users.clear()
|
||||
batch_raw.clear()
|
||||
|
||||
# Process any remaining messages that didn't fill a full batch
|
||||
if batch_raw and self.is_running:
|
||||
results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
|
||||
for m_data, u_list in results:
|
||||
|
|
@ -406,6 +413,8 @@ class DiscordExporter:
|
|||
if "attachments" in m:
|
||||
accumulated_files += len(m["attachments"])
|
||||
|
||||
# Flush deferred avatar downloads before persisting this batch
|
||||
await self._flush_pending_avatars()
|
||||
|
||||
if self.db:
|
||||
if batch_users: self.db.save_users(batch_users)
|
||||
|
|
@ -431,22 +440,24 @@ class DiscordExporter:
|
|||
return accumulated_count, accumulated_threads, accumulated_files
|
||||
|
||||
async def _format_user(self, user):
|
||||
"""Formats user data for the author or a mention."""
|
||||
"""Formats user data for the author or a mention.
|
||||
|
||||
Avatar downloads are intentionally deferred to keep this off the hot
|
||||
message-formatting path. Call _flush_pending_avatars() after each batch.
|
||||
"""
|
||||
user_id = str(user.id)
|
||||
if user_id in self.user_cache:
|
||||
return None
|
||||
|
||||
# New user discovered
|
||||
# New user discovered — schedule avatar download but don't block here
|
||||
avatar_file = None
|
||||
if user.avatar:
|
||||
try:
|
||||
av_name = f"{user_id}.png"
|
||||
av_target = self.users_path / av_name
|
||||
if not av_target.exists():
|
||||
await user.avatar.save(av_target)
|
||||
avatar_file = f"users/{av_name}"
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save avatar for {user.name}: {e}")
|
||||
if not av_target.exists():
|
||||
# Queue for deferred download
|
||||
self._pending_avatars.append((user_id, user.avatar, av_target))
|
||||
|
||||
roles = []
|
||||
if hasattr(user, "roles"):
|
||||
|
|
@ -463,6 +474,21 @@ class DiscordExporter:
|
|||
self.user_cache[user_id] = user_data
|
||||
return user_data
|
||||
|
||||
async def _flush_pending_avatars(self):
    """Download every queued user avatar concurrently and empty the queue.

    The queue is copied and cleared before any download starts. A failed
    download is logged and does not affect the others.
    """
    if not self._pending_avatars:
        return

    queued = list(self._pending_avatars)
    self._pending_avatars.clear()

    async def _download(user_id, asset, destination):
        # Swallow per-avatar failures so one bad avatar cannot fail the batch.
        try:
            await asset.save(destination)
        except Exception as e:
            logger.error(f"Failed to save avatar for user {user_id}: {e}")

    downloads = [_download(uid, asset, dest) for uid, asset, dest in queued]
    await asyncio.gather(*downloads)
|
||||
|
||||
async def _format_message(self, msg):
|
||||
"""Formats a single message and its author for DB storage."""
|
||||
new_users = []
|
||||
|
|
@ -478,11 +504,11 @@ class DiscordExporter:
|
|||
if u_ment: new_users.append(u_ment)
|
||||
|
||||
# 2. Attachments handling (Content-Addressable Storage)
|
||||
# ... (rest of the logic remains same, just updating m_data)
|
||||
# All attachments in a message are downloaded concurrently.
|
||||
attachments = []
|
||||
if msg.attachments:
|
||||
for att in msg.attachments:
|
||||
att_data = await self._process_media(
|
||||
att_tasks = [
|
||||
self._process_media(
|
||||
media_id=att.id,
|
||||
url=att.url,
|
||||
filename=att.filename,
|
||||
|
|
@ -490,8 +516,10 @@ class DiscordExporter:
|
|||
content_type=att.content_type,
|
||||
save_method=att.save
|
||||
)
|
||||
if att_data:
|
||||
attachments.append(att_data)
|
||||
for att in msg.attachments
|
||||
]
|
||||
att_results = await asyncio.gather(*att_tasks)
|
||||
attachments = [r for r in att_results if r]
|
||||
|
||||
# 2.5 Stickers handling
|
||||
stickers = []
|
||||
|
|
@ -556,8 +584,9 @@ class DiscordExporter:
|
|||
if snapshot.content:
|
||||
content = snapshot.content
|
||||
|
||||
for s_att in snapshot.attachments:
|
||||
att_res = await self._process_media(
|
||||
if snapshot.attachments:
|
||||
snap_tasks = [
|
||||
self._process_media(
|
||||
media_id=s_att.id,
|
||||
url=s_att.url,
|
||||
filename=s_att.filename,
|
||||
|
|
@ -565,7 +594,10 @@ class DiscordExporter:
|
|||
content_type=s_att.content_type,
|
||||
save_method=s_att.save
|
||||
)
|
||||
if att_res: attachments.append(att_res)
|
||||
for s_att in snapshot.attachments
|
||||
]
|
||||
snap_results = await asyncio.gather(*snap_tasks)
|
||||
attachments.extend(r for r in snap_results if r)
|
||||
|
||||
for s_emb in snapshot.embeds:
|
||||
embeds.append(s_emb.to_dict())
|
||||
|
|
@ -623,14 +655,16 @@ class DiscordExporter:
|
|||
try: tmp.close()
|
||||
except: pass
|
||||
|
||||
file_hash = self._calculate_sha256(tmp_path)
|
||||
actual_size = tmp_path.stat().st_size
|
||||
# Offload CPU-bound hashing and blocking file ops to the thread pool
|
||||
# so we don't stall concurrent downloads on the event loop.
|
||||
file_hash = await asyncio.to_thread(self._calculate_sha256, tmp_path)
|
||||
actual_size = (await asyncio.to_thread(tmp_path.stat)).st_size
|
||||
|
||||
# Check if hash already exists in pool
|
||||
if self.db:
|
||||
in_pool = self.db.get_media_by_hash(file_hash)
|
||||
if in_pool:
|
||||
tmp_path.unlink()
|
||||
await asyncio.to_thread(tmp_path.unlink)
|
||||
return {
|
||||
"id": str(media_id),
|
||||
"filename": filename,
|
||||
|
|
@ -645,7 +679,10 @@ class DiscordExporter:
|
|||
target_filename = f"{file_hash}{ext}"
|
||||
target_path = self.attachments_path / target_filename
|
||||
|
||||
shutil.move(str(tmp_path), str(target_path))
|
||||
await asyncio.to_thread(shutil.move, str(tmp_path), str(target_path))
|
||||
|
||||
# Mark as successfully moved so finally block doesn't delete it
|
||||
tmp_path = None
|
||||
|
||||
if self.db:
|
||||
self.db.add_media_to_pool(file_hash, f"attachments/{target_filename}", actual_size, content_type, str(url))
|
||||
|
|
@ -658,9 +695,14 @@ class DiscordExporter:
|
|||
"content_type": content_type,
|
||||
"local_hash": file_hash
|
||||
}
|
||||
except Exception as e:
|
||||
except BaseException as e:
|
||||
if not isinstance(e, asyncio.CancelledError):
|
||||
logger.error(f"Failed to process media {filename}: {e}")
|
||||
if tmp_path and tmp_path.exists(): tmp_path.unlink()
|
||||
raise
|
||||
finally:
|
||||
if tmp_path and tmp_path.exists():
|
||||
try: tmp_path.unlink()
|
||||
except: pass
|
||||
return None
|
||||
|
||||
async def export_threads(self, channel_id: int, progress_callback=None, force=False, accumulated_count=0, accumulated_threads=0, accumulated_files=0, after_id: int | None = None):
|
||||
|
|
@ -750,25 +792,35 @@ class DiscordExporter:
|
|||
})
|
||||
self.db.save_threads(thread_meta)
|
||||
|
||||
for thread in all_threads:
|
||||
if not self.is_running: break
|
||||
await asyncio.sleep(0)
|
||||
# Export threads concurrently — semaphore limits to 5 at a time to
|
||||
# avoid flooding Discord's rate limiter.
|
||||
sem = asyncio.Semaphore(5)
|
||||
|
||||
async def _export_one_thread(thread, t_idx):
|
||||
async with sem:
|
||||
if not self.is_running:
|
||||
return 0, 0, 0
|
||||
cnt, thr, fls = await self.export_channel_messages(
|
||||
thread.id,
|
||||
progress_callback=progress_callback,
|
||||
force=force,
|
||||
accumulated_count=0,
|
||||
accumulated_threads=0,
|
||||
accumulated_files=0,
|
||||
after_id=after_id
|
||||
)
|
||||
return cnt, thr, fls
|
||||
|
||||
if all_threads:
|
||||
thread_results = await asyncio.gather(
|
||||
*[_export_one_thread(t, i) for i, t in enumerate(all_threads)]
|
||||
)
|
||||
for cnt, thr, fls in thread_results:
|
||||
accumulated_count += cnt
|
||||
accumulated_threads += 1 + thr # +1 for the thread itself
|
||||
accumulated_files += fls
|
||||
|
||||
accumulated_threads += 1
|
||||
if progress_callback:
|
||||
await progress_callback(channel.name, accumulated_count, thread_count=accumulated_threads, file_count=accumulated_files)
|
||||
|
||||
# Backup thread messages
|
||||
accumulated_count, accumulated_threads, accumulated_files = await self.export_channel_messages(thread.id, progress_callback=progress_callback, force=force, accumulated_count=accumulated_count, accumulated_threads=accumulated_threads, accumulated_files=accumulated_files, after_id=after_id)
|
||||
|
||||
# For forums, ensure the starter message exists in the DB
|
||||
if is_forum:
|
||||
# starter_message is handled by export_channel_messages since it's just a message in that thread
|
||||
# However we may want to mark it or store forum-specific tags
|
||||
try:
|
||||
# Just yield for concurrency
|
||||
await asyncio.sleep(0)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing forum thread {thread.name}: {e}")
|
||||
|
||||
return accumulated_count, accumulated_threads, accumulated_files
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue