optimize backup speed
This commit is contained in:
parent
4b24b29c03
commit
cefa477459
2 changed files with 287 additions and 302 deletions
|
|
@ -25,17 +25,18 @@ class BackupDatabase:
|
||||||
def __init__(self, db_path: Path | str):
|
def __init__(self, db_path: Path | str):
|
||||||
self.db_path = Path(db_path)
|
self.db_path = Path(db_path)
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
|
||||||
|
self._conn.row_factory = sqlite3.Row
|
||||||
|
# WAL mode allows concurrent readers and batches disk flushes significantly
|
||||||
|
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
self._conn.execute("PRAGMA synchronous=NORMAL")
|
||||||
|
self._conn.execute("PRAGMA cache_size=-32000") # 32 MB page cache
|
||||||
self._init_db()
|
self._init_db()
|
||||||
|
|
||||||
def _get_conn(self):
|
|
||||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
|
||||||
conn.row_factory = sqlite3.Row
|
|
||||||
return conn
|
|
||||||
|
|
||||||
def _init_db(self):
|
def _init_db(self):
|
||||||
"""Initializes the database schema."""
|
"""Initializes the database schema."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
conn = self._conn
|
||||||
try:
|
try:
|
||||||
# Guild Profile
|
# Guild Profile
|
||||||
conn.execute("""
|
conn.execute("""
|
||||||
|
|
@ -242,12 +243,11 @@ class BackupDatabase:
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
pass # persistent connection — do not close
|
||||||
|
|
||||||
def set_guild_profile(self, data: Dict[str, Any]):
|
def set_guild_profile(self, data: Dict[str, Any]):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
self._conn.execute("""
|
||||||
conn.execute("""
|
|
||||||
INSERT OR REPLACE INTO guild_profile (id, name, description, icon_file, icon_url, banner_file, banner_url, owner_id, last_backup, ignore_channels)
|
INSERT OR REPLACE INTO guild_profile (id, name, description, icon_file, icon_url, banner_file, banner_url, owner_id, last_backup, ignore_channels)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""", (
|
""", (
|
||||||
|
|
@ -257,14 +257,11 @@ class BackupDatabase:
|
||||||
str(data.get("owner_id")),
|
str(data.get("owner_id")),
|
||||||
data.get("last_backup"), json.dumps(data.get("ignore_channels", []))
|
data.get("last_backup"), json.dumps(data.get("ignore_channels", []))
|
||||||
))
|
))
|
||||||
conn.commit()
|
self._conn.commit()
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def get_guild_profile(self) -> Optional[Dict[str, Any]]:
|
def get_guild_profile(self) -> Optional[Dict[str, Any]]:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
row = self._conn.execute("SELECT * FROM guild_profile LIMIT 1").fetchone()
|
||||||
row = conn.execute("SELECT * FROM guild_profile LIMIT 1").fetchone()
|
|
||||||
conn.close()
|
|
||||||
if row:
|
if row:
|
||||||
data = dict(row)
|
data = dict(row)
|
||||||
if data.get("ignore_channels"):
|
if data.get("ignore_channels"):
|
||||||
|
|
@ -276,11 +273,8 @@ class BackupDatabase:
|
||||||
|
|
||||||
def save_roles(self, roles: List[Dict[str, Any]]):
|
def save_roles(self, roles: List[Dict[str, Any]]):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
formatted = [
|
||||||
# Ensure complex fields are strings if they aren't already
|
{
|
||||||
formatted = []
|
|
||||||
for r in roles:
|
|
||||||
formatted.append({
|
|
||||||
"id": str(r["id"]),
|
"id": str(r["id"]),
|
||||||
"name": r["name"],
|
"name": r["name"],
|
||||||
"color": r["color"],
|
"color": r["color"],
|
||||||
|
|
@ -288,225 +282,193 @@ class BackupDatabase:
|
||||||
"permissions": str(r["permissions"]),
|
"permissions": str(r["permissions"]),
|
||||||
"hoist": 1 if r["hoist"] else 0,
|
"hoist": 1 if r["hoist"] else 0,
|
||||||
"mentionable": 1 if r["mentionable"] else 0
|
"mentionable": 1 if r["mentionable"] else 0
|
||||||
})
|
}
|
||||||
|
for r in roles
|
||||||
conn.executemany("""
|
]
|
||||||
|
self._conn.executemany("""
|
||||||
INSERT OR REPLACE INTO roles (id, name, color, position, permissions, hoist, mentionable)
|
INSERT OR REPLACE INTO roles (id, name, color, position, permissions, hoist, mentionable)
|
||||||
VALUES (:id, :name, :color, :position, :permissions, :hoist, :mentionable)
|
VALUES (:id, :name, :color, :position, :permissions, :hoist, :mentionable)
|
||||||
""", formatted)
|
""", formatted)
|
||||||
conn.commit()
|
self._conn.commit()
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def save_channels(self, channels: List[Dict[str, Any]]):
|
def save_channels(self, channels: List[Dict[str, Any]]):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
self._conn.executemany("""
|
||||||
conn.executemany("""
|
|
||||||
INSERT OR REPLACE INTO channels (id, name, type, position, category_id, topic, nsfw, bitrate, slowmode_delay)
|
INSERT OR REPLACE INTO channels (id, name, type, position, category_id, topic, nsfw, bitrate, slowmode_delay)
|
||||||
VALUES (:id, :name, :type, :position, :category_id, :topic, :nsfw, :bitrate, :slowmode_delay)
|
VALUES (:id, :name, :type, :position, :category_id, :topic, :nsfw, :bitrate, :slowmode_delay)
|
||||||
""", channels)
|
""", channels)
|
||||||
conn.commit()
|
self._conn.commit()
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def save_permissions(self, permissions: List[Dict[str, Any]]):
|
def save_permissions(self, permissions: List[Dict[str, Any]]):
|
||||||
"""Saves a batch of channel permission overwrites."""
|
"""Saves a batch of channel permission overwrites."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
self._conn.executemany("""
|
||||||
try:
|
INSERT INTO permissions (channel_id, target_id, target_type, allow, deny)
|
||||||
conn.executemany("""
|
VALUES (:channel_id, :target_id, :target_type, :allow, :deny)
|
||||||
INSERT INTO permissions (channel_id, target_id, target_type, allow, deny)
|
""", permissions)
|
||||||
VALUES (:channel_id, :target_id, :target_type, :allow, :deny)
|
self._conn.commit()
|
||||||
""", permissions)
|
|
||||||
conn.commit()
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def save_users(self, users: List[Dict[str, Any]]):
|
def save_users(self, users: List[Dict[str, Any]]):
|
||||||
"""Saves users to the author cache."""
|
"""Saves users to the author cache."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
self._conn.executemany("""
|
||||||
conn.executemany("""
|
|
||||||
INSERT OR REPLACE INTO users (id, username, display_name, avatar_file, avatar_url, roles)
|
INSERT OR REPLACE INTO users (id, username, display_name, avatar_file, avatar_url, roles)
|
||||||
VALUES (:id, :username, :display_name, :avatar_file, :avatar_url, :roles)
|
VALUES (:id, :username, :display_name, :avatar_file, :avatar_url, :roles)
|
||||||
""", users)
|
""", users)
|
||||||
conn.commit()
|
self._conn.commit()
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def save_server_assets(self, assets: List[Dict[str, Any]]):
|
def save_server_assets(self, assets: List[Dict[str, Any]]):
|
||||||
"""Saves a batch of server assets (emojis, stickers) to the database."""
|
"""Saves a batch of server assets (emojis, stickers) to the database."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
formatted = [
|
||||||
formatted = []
|
{
|
||||||
for a in assets:
|
|
||||||
formatted.append({
|
|
||||||
"id": str(a["id"]),
|
"id": str(a["id"]),
|
||||||
"name": a.get("name"),
|
"name": a.get("name"),
|
||||||
"type": a.get("type"),
|
"type": a.get("type"),
|
||||||
"filename": a.get("filename"),
|
"filename": a.get("filename"),
|
||||||
"url": a.get("url"),
|
"url": a.get("url"),
|
||||||
"mime_type": a.get("mime_type")
|
"mime_type": a.get("mime_type")
|
||||||
})
|
}
|
||||||
|
for a in assets
|
||||||
conn.executemany("""
|
]
|
||||||
|
self._conn.executemany("""
|
||||||
INSERT OR REPLACE INTO server_assets (id, name, type, filename, url, mime_type)
|
INSERT OR REPLACE INTO server_assets (id, name, type, filename, url, mime_type)
|
||||||
VALUES (:id, :name, :type, :filename, :url, :mime_type)
|
VALUES (:id, :name, :type, :filename, :url, :mime_type)
|
||||||
""", formatted)
|
""", formatted)
|
||||||
conn.commit()
|
self._conn.commit()
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def save_threads(self, threads: List[Dict[str, Any]]):
|
def save_threads(self, threads: List[Dict[str, Any]]):
|
||||||
"""Saves metadata for threads to the database."""
|
"""Saves metadata for threads to the database."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
self._conn.executemany("""
|
||||||
try:
|
INSERT OR REPLACE INTO threads (id, name, type, parent_id, message_count, member_count, archived, archive_timestamp, auto_archive_duration, locked, applied_tags)
|
||||||
conn.executemany("""
|
VALUES (:id, :name, :type, :parent_id, :message_count, :member_count, :archived, :archive_timestamp, :auto_archive_duration, :locked, :applied_tags)
|
||||||
INSERT OR REPLACE INTO threads (id, name, type, parent_id, message_count, member_count, archived, archive_timestamp, auto_archive_duration, locked, applied_tags)
|
""", threads)
|
||||||
VALUES (:id, :name, :type, :parent_id, :message_count, :member_count, :archived, :archive_timestamp, :auto_archive_duration, :locked, :applied_tags)
|
self._conn.commit()
|
||||||
""", threads)
|
|
||||||
conn.commit()
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def save_forum_tags(self, tags: List[Dict[str, Any]]):
|
def save_forum_tags(self, tags: List[Dict[str, Any]]):
|
||||||
"""Saves definitions for forum tags."""
|
"""Saves definitions for forum tags."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
self._conn.executemany("""
|
||||||
try:
|
INSERT OR REPLACE INTO forum_tags (id, forum_id, name, moderated, emoji_id, emoji_name)
|
||||||
conn.executemany("""
|
VALUES (:id, :forum_id, :name, :moderated, :emoji_id, :emoji_name)
|
||||||
INSERT OR REPLACE INTO forum_tags (id, forum_id, name, moderated, emoji_id, emoji_name)
|
""", tags)
|
||||||
VALUES (:id, :forum_id, :name, :moderated, :emoji_id, :emoji_name)
|
self._conn.commit()
|
||||||
""", tags)
|
|
||||||
conn.commit()
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def save_messages_batch(self, messages: List[Dict[str, Any]]):
|
def save_messages_batch(self, messages: List[Dict[str, Any]]):
|
||||||
"""Batch inserts messages and their attachments."""
|
"""Batch inserts messages and their attachments."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
conn = self._conn
|
||||||
try:
|
# Insert messages
|
||||||
# Insert messages
|
conn.executemany("""
|
||||||
|
INSERT OR REPLACE INTO messages (id, channel_id, author_id, content, timestamp, type, message_reference, is_pinned, extra_data)
|
||||||
|
VALUES (:id, :channel_id, :author_id, :content, :timestamp, :type, :message_reference, :is_pinned, :extra_data)
|
||||||
|
""", messages)
|
||||||
|
|
||||||
|
# Extract attachments, reactions, and stickers
|
||||||
|
all_attachments = []
|
||||||
|
all_reactions = []
|
||||||
|
all_stickers = []
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
# Attachments
|
||||||
|
if "attachments" in msg:
|
||||||
|
for att in msg["attachments"]:
|
||||||
|
att["message_id"] = msg["id"]
|
||||||
|
all_attachments.append(att)
|
||||||
|
|
||||||
|
# Reactions
|
||||||
|
if "reactions" in msg and msg["reactions"]:
|
||||||
|
for rea in msg["reactions"]:
|
||||||
|
all_reactions.append({
|
||||||
|
"message_id": msg["id"],
|
||||||
|
"emoji_id": str(rea["emoji_id"]) if rea.get("emoji_id") else None,
|
||||||
|
"emoji_name": rea.get("emoji_name"),
|
||||||
|
"count": rea.get("count", 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
# Stickers
|
||||||
|
if "stickers" in msg and msg["stickers"]:
|
||||||
|
for st in msg["stickers"]:
|
||||||
|
all_stickers.append({
|
||||||
|
"message_id": msg["id"],
|
||||||
|
"sticker_id": str(st["id"]),
|
||||||
|
"name": st.get("name"),
|
||||||
|
"url": st.get("url"),
|
||||||
|
"format_type": st.get("format_type"),
|
||||||
|
"local_hash": st.get("local_hash")
|
||||||
|
})
|
||||||
|
|
||||||
|
if all_attachments:
|
||||||
conn.executemany("""
|
conn.executemany("""
|
||||||
INSERT OR REPLACE INTO messages (id, channel_id, author_id, content, timestamp, type, message_reference, is_pinned, extra_data)
|
INSERT OR REPLACE INTO attachments (id, message_id, filename, size, url, content_type, local_hash)
|
||||||
VALUES (:id, :channel_id, :author_id, :content, :timestamp, :type, :message_reference, :is_pinned, :extra_data)
|
VALUES (:id, :message_id, :filename, :size, :url, :content_type, :local_hash)
|
||||||
""", messages)
|
""", all_attachments)
|
||||||
|
|
||||||
# Extract attachments, reactions, and stickers
|
|
||||||
all_attachments = []
|
|
||||||
all_reactions = []
|
|
||||||
all_stickers = []
|
|
||||||
|
|
||||||
for msg in messages:
|
|
||||||
# Attachments
|
|
||||||
if "attachments" in msg:
|
|
||||||
for att in msg["attachments"]:
|
|
||||||
att["message_id"] = msg["id"]
|
|
||||||
all_attachments.append(att)
|
|
||||||
|
|
||||||
# Reactions
|
|
||||||
if "reactions" in msg and msg["reactions"]:
|
|
||||||
for rea in msg["reactions"]:
|
|
||||||
all_reactions.append({
|
|
||||||
"message_id": msg["id"],
|
|
||||||
"emoji_id": str(rea["emoji_id"]) if rea.get("emoji_id") else None,
|
|
||||||
"emoji_name": rea.get("emoji_name"),
|
|
||||||
"count": rea.get("count", 0)
|
|
||||||
})
|
|
||||||
|
|
||||||
# Stickers
|
|
||||||
if "stickers" in msg and msg["stickers"]:
|
|
||||||
for st in msg["stickers"]:
|
|
||||||
all_stickers.append({
|
|
||||||
"message_id": msg["id"],
|
|
||||||
"sticker_id": str(st["id"]),
|
|
||||||
"name": st.get("name"),
|
|
||||||
"url": st.get("url"),
|
|
||||||
"format_type": st.get("format_type"),
|
|
||||||
"local_hash": st.get("local_hash")
|
|
||||||
})
|
|
||||||
|
|
||||||
if all_attachments:
|
|
||||||
conn.executemany("""
|
|
||||||
INSERT OR REPLACE INTO attachments (id, message_id, filename, size, url, content_type, local_hash)
|
|
||||||
VALUES (:id, :message_id, :filename, :size, :url, :content_type, :local_hash)
|
|
||||||
""", all_attachments)
|
|
||||||
|
|
||||||
# Save Embeds (Normalized with JSON Fields)
|
# Save Embeds (Normalized with JSON Fields)
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if "embeds" in msg and msg["embeds"]:
|
if "embeds" in msg and msg["embeds"]:
|
||||||
for emb in msg["embeds"]:
|
for emb in msg["embeds"]:
|
||||||
# Insert top-level embed
|
conn.execute("""
|
||||||
conn.execute("""
|
INSERT INTO embeds (
|
||||||
INSERT INTO embeds (
|
message_id, title, description, url, color, timestamp,
|
||||||
message_id, title, description, url, color, timestamp,
|
thumbnail_url, image_url, author_name, author_url,
|
||||||
thumbnail_url, image_url, author_name, author_url,
|
author_icon_url, footer_text, footer_icon_url, fields
|
||||||
author_icon_url, footer_text, footer_icon_url, fields
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
""", (
|
||||||
""", (
|
msg["id"], emb.get("title"), emb.get("description"), emb.get("url"),
|
||||||
msg["id"], emb.get("title"), emb.get("description"), emb.get("url"),
|
emb.get("color"), emb.get("timestamp"),
|
||||||
emb.get("color"), emb.get("timestamp"),
|
emb.get("thumbnail", {}).get("url") if isinstance(emb.get("thumbnail"), dict) else None,
|
||||||
emb.get("thumbnail", {}).get("url") if isinstance(emb.get("thumbnail"), dict) else None,
|
emb.get("image", {}).get("url") if isinstance(emb.get("image"), dict) else None,
|
||||||
emb.get("image", {}).get("url") if isinstance(emb.get("image"), dict) else None,
|
emb.get("author", {}).get("name") if isinstance(emb.get("author"), dict) else None,
|
||||||
emb.get("author", {}).get("name") if isinstance(emb.get("author"), dict) else None,
|
emb.get("author", {}).get("url") if isinstance(emb.get("author"), dict) else None,
|
||||||
emb.get("author", {}).get("url") if isinstance(emb.get("author"), dict) else None,
|
emb.get("author", {}).get("icon_url") if isinstance(emb.get("author"), dict) else None,
|
||||||
emb.get("author", {}).get("icon_url") if isinstance(emb.get("author"), dict) else None,
|
emb.get("footer", {}).get("text") if isinstance(emb.get("footer"), dict) else None,
|
||||||
emb.get("footer", {}).get("text") if isinstance(emb.get("footer"), dict) else None,
|
emb.get("footer", {}).get("icon_url") if isinstance(emb.get("footer"), dict) else None,
|
||||||
emb.get("footer", {}).get("icon_url") if isinstance(emb.get("footer"), dict) else None,
|
json.dumps(emb.get("fields", []))
|
||||||
json.dumps(emb.get("fields", []))
|
))
|
||||||
))
|
|
||||||
|
|
||||||
if all_reactions:
|
if all_reactions:
|
||||||
conn.executemany("""
|
conn.executemany("""
|
||||||
INSERT INTO reactions (message_id, emoji_id, emoji_name, count)
|
INSERT INTO reactions (message_id, emoji_id, emoji_name, count)
|
||||||
VALUES (:message_id, :emoji_id, :emoji_name, :count)
|
VALUES (:message_id, :emoji_id, :emoji_name, :count)
|
||||||
""", all_reactions)
|
""", all_reactions)
|
||||||
|
|
||||||
if all_stickers:
|
if all_stickers:
|
||||||
conn.executemany("""
|
conn.executemany("""
|
||||||
INSERT OR REPLACE INTO message_stickers (message_id, sticker_id, name, url, format_type, local_hash)
|
INSERT OR REPLACE INTO message_stickers (message_id, sticker_id, name, url, format_type, local_hash)
|
||||||
VALUES (:message_id, :sticker_id, :name, :url, :format_type, :local_hash)
|
VALUES (:message_id, :sticker_id, :name, :url, :format_type, :local_hash)
|
||||||
""", all_stickers)
|
""", all_stickers)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def get_last_message_id(self, channel_id: str) -> Optional[str]:
|
def get_last_message_id(self, channel_id: str) -> Optional[str]:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
row = self._conn.execute("SELECT id FROM messages WHERE channel_id = ? ORDER BY id DESC LIMIT 1", (str(channel_id),)).fetchone()
|
||||||
row = conn.execute("SELECT id FROM messages WHERE channel_id = ? ORDER BY id DESC LIMIT 1", (str(channel_id),)).fetchone()
|
|
||||||
conn.close()
|
|
||||||
return row["id"] if row else None
|
return row["id"] if row else None
|
||||||
|
|
||||||
def get_media_by_hash(self, file_hash: str) -> Optional[Dict[str, Any]]:
|
def get_media_by_hash(self, file_hash: str) -> Optional[Dict[str, Any]]:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
row = self._conn.execute("SELECT * FROM media_pool WHERE hash = ?", (file_hash,)).fetchone()
|
||||||
row = conn.execute("SELECT * FROM media_pool WHERE hash = ?", (file_hash,)).fetchone()
|
|
||||||
conn.close()
|
|
||||||
return dict(row) if row else None
|
return dict(row) if row else None
|
||||||
|
|
||||||
def get_media_by_url(self, url: str) -> Optional[Dict[str, Any]]:
|
def get_media_by_url(self, url: str) -> Optional[Dict[str, Any]]:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
row = self._conn.execute("SELECT * FROM media_pool WHERE first_seen_url = ?", (url,)).fetchone()
|
||||||
row = conn.execute("SELECT * FROM media_pool WHERE first_seen_url = ?", (url,)).fetchone()
|
|
||||||
conn.close()
|
|
||||||
return dict(row) if row else None
|
return dict(row) if row else None
|
||||||
|
|
||||||
def add_media_to_pool(self, file_hash: str, local_path: str, size: int, mime_type: str, url: str):
|
def add_media_to_pool(self, file_hash: str, local_path: str, size: int, mime_type: str, url: str):
|
||||||
|
# NOTE: No commit here — caller (save_messages_batch) commits at end of batch
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
self._conn.execute("""
|
||||||
conn.execute("""
|
|
||||||
INSERT OR REPLACE INTO media_pool (hash, local_path, size, mime_type, first_seen_url)
|
INSERT OR REPLACE INTO media_pool (hash, local_path, size, mime_type, first_seen_url)
|
||||||
VALUES (?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?)
|
||||||
""", (file_hash, str(local_path), size, mime_type, url))
|
""", (file_hash, str(local_path), size, mime_type, url))
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def get_stats_by_channel(self) -> Dict[int, Dict[str, Any]]:
|
def get_stats_by_channel(self) -> Dict[int, Dict[str, Any]]:
|
||||||
"""Returns aggregate stats for all channels with backups."""
|
"""Returns aggregate stats for all channels with backups."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
msg_rows = self._conn.execute("""
|
||||||
# Summary of messages per effective channel (including threads rolled up)
|
|
||||||
msg_rows = conn.execute("""
|
|
||||||
SELECT
|
SELECT
|
||||||
COALESCE(t.parent_id, m.channel_id) as channel_id,
|
COALESCE(t.parent_id, m.channel_id) as channel_id,
|
||||||
COUNT(m.id) as msg_count
|
COUNT(m.id) as msg_count
|
||||||
|
|
@ -515,15 +477,13 @@ class BackupDatabase:
|
||||||
GROUP BY channel_id
|
GROUP BY channel_id
|
||||||
""").fetchall()
|
""").fetchall()
|
||||||
|
|
||||||
# Thread counts per parent
|
thread_rows = self._conn.execute("""
|
||||||
thread_rows = conn.execute("""
|
|
||||||
SELECT parent_id, COUNT(*) as thread_count
|
SELECT parent_id, COUNT(*) as thread_count
|
||||||
FROM threads
|
FROM threads
|
||||||
GROUP BY parent_id
|
GROUP BY parent_id
|
||||||
""").fetchall()
|
""").fetchall()
|
||||||
|
|
||||||
# Summary of attachments per effective channel (including threads rolled up)
|
att_rows = self._conn.execute("""
|
||||||
att_rows = conn.execute("""
|
|
||||||
SELECT
|
SELECT
|
||||||
COALESCE(t.parent_id, m.channel_id) as channel_id,
|
COALESCE(t.parent_id, m.channel_id) as channel_id,
|
||||||
COUNT(a.id) as att_count,
|
COUNT(a.id) as att_count,
|
||||||
|
|
@ -534,8 +494,6 @@ class BackupDatabase:
|
||||||
GROUP BY channel_id
|
GROUP BY channel_id
|
||||||
""").fetchall()
|
""").fetchall()
|
||||||
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
stats = {}
|
stats = {}
|
||||||
for r in msg_rows:
|
for r in msg_rows:
|
||||||
cid = parse_snowflake(r["channel_id"])
|
cid = parse_snowflake(r["channel_id"])
|
||||||
|
|
@ -566,22 +524,18 @@ class BackupDatabase:
|
||||||
|
|
||||||
def get_all_roles(self) -> List[Dict[str, Any]]:
|
def get_all_roles(self) -> List[Dict[str, Any]]:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
rows = self._conn.execute("SELECT * FROM roles ORDER BY position DESC").fetchall()
|
||||||
rows = conn.execute("SELECT * FROM roles ORDER BY position DESC").fetchall()
|
|
||||||
conn.close()
|
|
||||||
return [dict(r) for r in rows]
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
def get_all_channels(self) -> List[Dict[str, Any]]:
|
def get_all_channels(self) -> List[Dict[str, Any]]:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
rows = self._conn.execute("SELECT * FROM channels ORDER BY position ASC").fetchall()
|
||||||
rows = conn.execute("SELECT * FROM channels ORDER BY position ASC").fetchall()
|
|
||||||
|
|
||||||
# Fetch all permissions
|
|
||||||
chan_list = [dict(r) for r in rows]
|
chan_list = [dict(r) for r in rows]
|
||||||
if chan_list:
|
if chan_list:
|
||||||
ids = [c["id"] for c in chan_list]
|
ids = [c["id"] for c in chan_list]
|
||||||
placeholders = ",".join(["?"] * len(ids))
|
placeholders = ",".join(["?"] * len(ids))
|
||||||
perm_rows = conn.execute(f"SELECT * FROM permissions WHERE channel_id IN ({placeholders})", ids).fetchall()
|
perm_rows = self._conn.execute(f"SELECT * FROM permissions WHERE channel_id IN ({placeholders})", ids).fetchall()
|
||||||
|
|
||||||
perms_by_chan = {}
|
perms_by_chan = {}
|
||||||
for pr in perm_rows:
|
for pr in perm_rows:
|
||||||
|
|
@ -597,8 +551,7 @@ class BackupDatabase:
|
||||||
for c in chan_list:
|
for c in chan_list:
|
||||||
c["overwrites"] = perms_by_chan.get(c["id"], [])
|
c["overwrites"] = perms_by_chan.get(c["id"], [])
|
||||||
|
|
||||||
# Fetch Forum Tags
|
tag_rows = self._conn.execute(f"SELECT * FROM forum_tags WHERE forum_id IN ({placeholders})", ids).fetchall()
|
||||||
tag_rows = conn.execute(f"SELECT * FROM forum_tags WHERE forum_id IN ({placeholders})", ids).fetchall()
|
|
||||||
tags_by_forum = {}
|
tags_by_forum = {}
|
||||||
for tr in tag_rows:
|
for tr in tag_rows:
|
||||||
fid = tr["forum_id"]
|
fid = tr["forum_id"]
|
||||||
|
|
@ -608,56 +561,43 @@ class BackupDatabase:
|
||||||
for c in chan_list:
|
for c in chan_list:
|
||||||
c["available_tags"] = tags_by_forum.get(c["id"], [])
|
c["available_tags"] = tags_by_forum.get(c["id"], [])
|
||||||
|
|
||||||
conn.close()
|
|
||||||
return chan_list
|
return chan_list
|
||||||
|
|
||||||
def get_all_threads(self) -> List[Dict[str, Any]]:
|
def get_all_threads(self) -> List[Dict[str, Any]]:
|
||||||
"""Returns metadata for all threads in the backup."""
|
"""Returns metadata for all threads in the backup."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
rows = self._conn.execute("SELECT * FROM threads").fetchall()
|
||||||
rows = conn.execute("SELECT * FROM threads").fetchall()
|
|
||||||
conn.close()
|
|
||||||
return [dict(r) for r in rows]
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
def get_forum_tags(self, forum_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
def get_forum_tags(self, forum_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
"""Returns forum tag definitions."""
|
"""Returns forum tag definitions."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
|
||||||
if forum_id:
|
if forum_id:
|
||||||
rows = conn.execute("SELECT * FROM forum_tags WHERE forum_id = ?", (str(forum_id),)).fetchall()
|
rows = self._conn.execute("SELECT * FROM forum_tags WHERE forum_id = ?", (str(forum_id),)).fetchall()
|
||||||
else:
|
else:
|
||||||
rows = conn.execute("SELECT * FROM forum_tags").fetchall()
|
rows = self._conn.execute("SELECT * FROM forum_tags").fetchall()
|
||||||
conn.close()
|
|
||||||
return [dict(r) for r in rows]
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
def get_threads_by_parent(self, parent_id: str) -> List[Dict[str, Any]]:
|
def get_threads_by_parent(self, parent_id: str) -> List[Dict[str, Any]]:
|
||||||
"""Returns all threads belonging to a parent channel."""
|
"""Returns all threads belonging to a parent channel."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
rows = self._conn.execute("SELECT * FROM threads WHERE parent_id = ?", (str(parent_id),)).fetchall()
|
||||||
rows = conn.execute("SELECT * FROM threads WHERE parent_id = ?", (str(parent_id),)).fetchall()
|
|
||||||
conn.close()
|
|
||||||
return [dict(r) for r in rows]
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
def get_thread(self, thread_id: str) -> Optional[Dict[str, Any]]:
|
def get_thread(self, thread_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""Retrieves a single thread's metadata."""
|
"""Retrieves a single thread's metadata."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
row = self._conn.execute("SELECT * FROM threads WHERE id = ?", (str(thread_id),)).fetchone()
|
||||||
row = conn.execute("SELECT * FROM threads WHERE id = ?", (str(thread_id),)).fetchone()
|
|
||||||
conn.close()
|
|
||||||
return dict(row) if row else None
|
return dict(row) if row else None
|
||||||
|
|
||||||
def get_all_users(self) -> List[Dict[str, Any]]:
|
def get_all_users(self) -> List[Dict[str, Any]]:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
rows = self._conn.execute("SELECT * FROM users").fetchall()
|
||||||
rows = conn.execute("SELECT * FROM users").fetchall()
|
|
||||||
conn.close()
|
|
||||||
return [dict(r) for r in rows]
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
def get_user(self, user_id: str) -> Optional[Dict[str, Any]]:
|
def get_user(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
row = self._conn.execute("SELECT * FROM users WHERE id = ?", (str(user_id),)).fetchone()
|
||||||
row = conn.execute("SELECT * FROM users WHERE id = ?", (str(user_id),)).fetchone()
|
|
||||||
conn.close()
|
|
||||||
if row:
|
if row:
|
||||||
data = dict(row)
|
data = dict(row)
|
||||||
if data.get("roles"):
|
if data.get("roles"):
|
||||||
|
|
@ -668,25 +608,20 @@ class BackupDatabase:
|
||||||
def get_server_assets(self, asset_type: Optional[str] = None) -> List[Dict[str, Any]]:
|
def get_server_assets(self, asset_type: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
"""Returns all server assets, optionally filtered by type."""
|
"""Returns all server assets, optionally filtered by type."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
|
||||||
if asset_type:
|
if asset_type:
|
||||||
rows = conn.execute("SELECT * FROM server_assets WHERE type = ?", (asset_type,)).fetchall()
|
rows = self._conn.execute("SELECT * FROM server_assets WHERE type = ?", (asset_type,)).fetchall()
|
||||||
else:
|
else:
|
||||||
rows = conn.execute("SELECT * FROM server_assets").fetchall()
|
rows = self._conn.execute("SELECT * FROM server_assets").fetchall()
|
||||||
conn.close()
|
|
||||||
return [dict(r) for r in rows]
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
def get_all_media(self) -> Dict[str, Dict[str, Any]]:
|
def get_all_media(self) -> Dict[str, Dict[str, Any]]:
|
||||||
"""Returns the entire media pool as a dictionary indexed by hash."""
|
"""Returns the entire media pool as a dictionary indexed by hash."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
rows = self._conn.execute("SELECT * FROM media_pool").fetchall()
|
||||||
rows = conn.execute("SELECT * FROM media_pool").fetchall()
|
|
||||||
conn.close()
|
|
||||||
return {r["hash"]: dict(r) for r in rows}
|
return {r["hash"]: dict(r) for r in rows}
|
||||||
|
|
||||||
def get_messages_paged(self, channel_id: str, limit: int = 100, offset: int = 0, after_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
def get_messages_paged(self, channel_id: str, limit: int = 100, offset: int = 0, after_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
|
||||||
query = "SELECT * FROM messages WHERE channel_id = ?"
|
query = "SELECT * FROM messages WHERE channel_id = ?"
|
||||||
params = [str(channel_id)]
|
params = [str(channel_id)]
|
||||||
|
|
||||||
|
|
@ -697,24 +632,21 @@ class BackupDatabase:
|
||||||
query += " ORDER BY id ASC LIMIT ? OFFSET ?"
|
query += " ORDER BY id ASC LIMIT ? OFFSET ?"
|
||||||
params.extend([limit, offset])
|
params.extend([limit, offset])
|
||||||
|
|
||||||
rows = conn.execute(query, params).fetchall()
|
rows = self._conn.execute(query, params).fetchall()
|
||||||
msg_list = [dict(r) for r in rows]
|
msg_list = [dict(r) for r in rows]
|
||||||
|
|
||||||
if msg_list:
|
if msg_list:
|
||||||
# Fetch attachments for these messages
|
|
||||||
msg_ids = [m["id"] for m in msg_list]
|
msg_ids = [m["id"] for m in msg_list]
|
||||||
placeholders = ",".join(["?"] * len(msg_ids))
|
placeholders = ",".join(["?"] * len(msg_ids))
|
||||||
|
|
||||||
# Attachments
|
att_rows = self._conn.execute(f"SELECT * FROM attachments WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
||||||
att_rows = conn.execute(f"SELECT * FROM attachments WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
|
||||||
atts_by_msg = {}
|
atts_by_msg = {}
|
||||||
for ar in att_rows:
|
for ar in att_rows:
|
||||||
mid = ar["message_id"]
|
mid = ar["message_id"]
|
||||||
if mid not in atts_by_msg: atts_by_msg[mid] = []
|
if mid not in atts_by_msg: atts_by_msg[mid] = []
|
||||||
atts_by_msg[mid].append(dict(ar))
|
atts_by_msg[mid].append(dict(ar))
|
||||||
|
|
||||||
# Embeds
|
emb_rows = self._conn.execute(f"SELECT * FROM embeds WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
||||||
emb_rows = conn.execute(f"SELECT * FROM embeds WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
|
||||||
embs_by_msg = {}
|
embs_by_msg = {}
|
||||||
for er in emb_rows:
|
for er in emb_rows:
|
||||||
mid = er["message_id"]
|
mid = er["message_id"]
|
||||||
|
|
@ -741,16 +673,14 @@ class BackupDatabase:
|
||||||
}
|
}
|
||||||
embs_by_msg[mid].append(e_dict)
|
embs_by_msg[mid].append(e_dict)
|
||||||
|
|
||||||
# Reactions
|
rea_rows = self._conn.execute(f"SELECT * FROM reactions WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
||||||
rea_rows = conn.execute(f"SELECT * FROM reactions WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
|
||||||
reas_by_msg = {}
|
reas_by_msg = {}
|
||||||
for rr in rea_rows:
|
for rr in rea_rows:
|
||||||
mid = rr["message_id"]
|
mid = rr["message_id"]
|
||||||
if mid not in reas_by_msg: reas_by_msg[mid] = []
|
if mid not in reas_by_msg: reas_by_msg[mid] = []
|
||||||
reas_by_msg[mid].append(dict(rr))
|
reas_by_msg[mid].append(dict(rr))
|
||||||
|
|
||||||
# Stickers (Message-specific)
|
st_rows = self._conn.execute(f"SELECT * FROM message_stickers WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
||||||
st_rows = conn.execute(f"SELECT * FROM message_stickers WHERE message_id IN ({placeholders})", msg_ids).fetchall()
|
|
||||||
sts_by_msg = {}
|
sts_by_msg = {}
|
||||||
for sr in st_rows:
|
for sr in st_rows:
|
||||||
mid = sr["message_id"]
|
mid = sr["message_id"]
|
||||||
|
|
@ -764,10 +694,13 @@ class BackupDatabase:
|
||||||
m["reactions"] = reas_by_msg.get(m_id, [])
|
m["reactions"] = reas_by_msg.get(m_id, [])
|
||||||
m["stickers"] = sts_by_msg.get(m_id, [])
|
m["stickers"] = sts_by_msg.get(m_id, [])
|
||||||
|
|
||||||
conn.close()
|
|
||||||
return msg_list
|
return msg_list
|
||||||
|
|
||||||
def close(self):
    """Commit any pending writes and close the persistent connection.

    The connection is closed even when the final commit fails, so a failed
    flush cannot leak the database handle. Calling this method again after
    the connection is closed is harmless: the commit then raises
    ``sqlite3.ProgrammingError``, which is suppressed.

    Only ``sqlite3.Error`` is suppressed; unrelated programming errors
    still propagate instead of being silently swallowed.
    """
    with self._lock:
        try:
            self._conn.commit()
        except sqlite3.Error:
            # Best-effort flush during shutdown (also hit when the
            # connection has already been closed).
            pass
        finally:
            # Always release the handle, regardless of the commit outcome.
            try:
                self._conn.close()
            except sqlite3.Error:
                pass
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,10 @@ class DiscordExporter:
|
||||||
self.base_dir = Path(base_dir) if base_dir else Path(".")
|
self.base_dir = Path(base_dir) if base_dir else Path(".")
|
||||||
self.is_running = True
|
self.is_running = True
|
||||||
self.db: Optional[BackupDatabase] = None
|
self.db: Optional[BackupDatabase] = None
|
||||||
self.sticker_cache: Dict[int, bytes] = {} # Deduplicate downloads in one session
|
self.sticker_cache: Dict[int, bytes] = {} # Deduplicate downloads in one session
|
||||||
|
# Pending avatar downloads — flushed after each message batch to keep the
|
||||||
|
# hot message-formatting path free of HTTP latency.
|
||||||
|
self._pending_avatars: List[tuple] = [] # (user_id, save_coroutine, av_path)
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
"""Prepares the output directory and fetches server metadata."""
|
"""Prepares the output directory and fetches server metadata."""
|
||||||
|
|
@ -379,6 +382,9 @@ class DiscordExporter:
|
||||||
accumulated_files += len(m["attachments"])
|
accumulated_files += len(m["attachments"])
|
||||||
|
|
||||||
# Persist to DB
|
# Persist to DB
|
||||||
|
# Flush deferred avatar downloads before persisting this batch
|
||||||
|
await self._flush_pending_avatars()
|
||||||
|
|
||||||
if self.db:
|
if self.db:
|
||||||
if batch_users: self.db.save_users(batch_users)
|
if batch_users: self.db.save_users(batch_users)
|
||||||
self.db.save_messages_batch(batch_messages)
|
self.db.save_messages_batch(batch_messages)
|
||||||
|
|
@ -393,32 +399,35 @@ class DiscordExporter:
|
||||||
batch_users.clear()
|
batch_users.clear()
|
||||||
batch_raw.clear()
|
batch_raw.clear()
|
||||||
|
|
||||||
if batch_raw and self.is_running:
|
# Process any remaining messages that didn't fill a full batch
|
||||||
results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
|
if batch_raw and self.is_running:
|
||||||
for m_data, u_list in results:
|
results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
|
||||||
batch_messages.append(m_data)
|
for m_data, u_list in results:
|
||||||
if u_list: batch_users.extend(u_list)
|
batch_messages.append(m_data)
|
||||||
|
if u_list: batch_users.extend(u_list)
|
||||||
new_count += len(batch_messages)
|
|
||||||
accumulated_count += len(batch_messages)
|
new_count += len(batch_messages)
|
||||||
|
accumulated_count += len(batch_messages)
|
||||||
for m in batch_messages:
|
|
||||||
if "attachments" in m:
|
for m in batch_messages:
|
||||||
accumulated_files += len(m["attachments"])
|
if "attachments" in m:
|
||||||
|
accumulated_files += len(m["attachments"])
|
||||||
|
|
||||||
if self.db:
|
# Flush deferred avatar downloads before persisting this batch
|
||||||
if batch_users: self.db.save_users(batch_users)
|
await self._flush_pending_avatars()
|
||||||
self.db.save_messages_batch(batch_messages)
|
|
||||||
|
if self.db:
|
||||||
if progress_callback:
|
if batch_users: self.db.save_users(batch_users)
|
||||||
last_msg = batch_raw[-1]
|
self.db.save_messages_batch(batch_messages)
|
||||||
author_name = getattr(last_msg.author, "display_name", "Unknown")
|
|
||||||
await progress_callback(channel_name, accumulated_count, author_name=author_name, thread_count=accumulated_threads, file_count=accumulated_files)
|
if progress_callback:
|
||||||
|
last_msg = batch_raw[-1]
|
||||||
batch_messages.clear()
|
author_name = getattr(last_msg.author, "display_name", "Unknown")
|
||||||
batch_users.clear()
|
await progress_callback(channel_name, accumulated_count, author_name=author_name, thread_count=accumulated_threads, file_count=accumulated_files)
|
||||||
batch_raw.clear()
|
|
||||||
|
batch_messages.clear()
|
||||||
|
batch_users.clear()
|
||||||
|
batch_raw.clear()
|
||||||
|
|
||||||
except discord.Forbidden:
|
except discord.Forbidden:
|
||||||
logger.error(f"403 Forbidden: Missing Access to read messages in {channel_name} ({channel_id})")
|
logger.error(f"403 Forbidden: Missing Access to read messages in {channel_name} ({channel_id})")
|
||||||
|
|
@ -431,22 +440,24 @@ class DiscordExporter:
|
||||||
return accumulated_count, accumulated_threads, accumulated_files
|
return accumulated_count, accumulated_threads, accumulated_files
|
||||||
|
|
||||||
async def _format_user(self, user):
|
async def _format_user(self, user):
|
||||||
"""Formats user data for the author or a mention."""
|
"""Formats user data for the author or a mention.
|
||||||
|
|
||||||
|
Avatar downloads are intentionally deferred to keep this off the hot
|
||||||
|
message-formatting path. Call _flush_pending_avatars() after each batch.
|
||||||
|
"""
|
||||||
user_id = str(user.id)
|
user_id = str(user.id)
|
||||||
if user_id in self.user_cache:
|
if user_id in self.user_cache:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# New user discovered
|
# New user discovered — schedule avatar download but don't block here
|
||||||
avatar_file = None
|
avatar_file = None
|
||||||
if user.avatar:
|
if user.avatar:
|
||||||
try:
|
av_name = f"{user_id}.png"
|
||||||
av_name = f"{user_id}.png"
|
av_target = self.users_path / av_name
|
||||||
av_target = self.users_path / av_name
|
avatar_file = f"users/{av_name}"
|
||||||
if not av_target.exists():
|
if not av_target.exists():
|
||||||
await user.avatar.save(av_target)
|
# Queue for deferred download
|
||||||
avatar_file = f"users/{av_name}"
|
self._pending_avatars.append((user_id, user.avatar, av_target))
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to save avatar for {user.name}: {e}")
|
|
||||||
|
|
||||||
roles = []
|
roles = []
|
||||||
if hasattr(user, "roles"):
|
if hasattr(user, "roles"):
|
||||||
|
|
@ -463,6 +474,21 @@ class DiscordExporter:
|
||||||
self.user_cache[user_id] = user_data
|
self.user_cache[user_id] = user_data
|
||||||
return user_data
|
return user_data
|
||||||
|
|
||||||
|
async def _flush_pending_avatars(self):
    """Download all queued user avatars concurrently, then clear the queue.

    Each queue entry is a ``(user_id, avatar_asset, target_path)`` tuple.
    The queue is snapshotted and emptied before any download is awaited, so
    entries appended while the downloads are in flight are kept for the
    next flush.

    A failed download is logged and does not abort the other downloads.
    Any partially written file is removed afterwards: avatars are only
    queued when the target file does not exist yet, so a truncated file
    left on disk would otherwise never be downloaded again.
    """
    if not self._pending_avatars:
        return

    async def _save_avatar(user_id, avatar_asset, target_path):
        try:
            await avatar_asset.save(target_path)
        except Exception as e:
            logger.error(f"Failed to save avatar for user {user_id}: {e}")
            # Drop whatever was written so a later run retries the download.
            try:
                Path(target_path).unlink(missing_ok=True)
            except OSError:
                pass

    # Snapshot first and clear in place; no await happens in between.
    pending = list(self._pending_avatars)
    self._pending_avatars.clear()
    await asyncio.gather(*(_save_avatar(uid, av, path) for uid, av, path in pending))
|
||||||
|
|
||||||
async def _format_message(self, msg):
|
async def _format_message(self, msg):
|
||||||
"""Formats a single message and its author for DB storage."""
|
"""Formats a single message and its author for DB storage."""
|
||||||
new_users = []
|
new_users = []
|
||||||
|
|
@ -478,11 +504,11 @@ class DiscordExporter:
|
||||||
if u_ment: new_users.append(u_ment)
|
if u_ment: new_users.append(u_ment)
|
||||||
|
|
||||||
# 2. Attachments handling (Content-Addressable Storage)
|
# 2. Attachments handling (Content-Addressable Storage)
|
||||||
# ... (rest of the logic remains same, just updating m_data)
|
# All attachments in a message are downloaded concurrently.
|
||||||
attachments = []
|
attachments = []
|
||||||
if msg.attachments:
|
if msg.attachments:
|
||||||
for att in msg.attachments:
|
att_tasks = [
|
||||||
att_data = await self._process_media(
|
self._process_media(
|
||||||
media_id=att.id,
|
media_id=att.id,
|
||||||
url=att.url,
|
url=att.url,
|
||||||
filename=att.filename,
|
filename=att.filename,
|
||||||
|
|
@ -490,8 +516,10 @@ class DiscordExporter:
|
||||||
content_type=att.content_type,
|
content_type=att.content_type,
|
||||||
save_method=att.save
|
save_method=att.save
|
||||||
)
|
)
|
||||||
if att_data:
|
for att in msg.attachments
|
||||||
attachments.append(att_data)
|
]
|
||||||
|
att_results = await asyncio.gather(*att_tasks)
|
||||||
|
attachments = [r for r in att_results if r]
|
||||||
|
|
||||||
# 2.5 Stickers handling
|
# 2.5 Stickers handling
|
||||||
stickers = []
|
stickers = []
|
||||||
|
|
@ -556,16 +584,20 @@ class DiscordExporter:
|
||||||
if snapshot.content:
|
if snapshot.content:
|
||||||
content = snapshot.content
|
content = snapshot.content
|
||||||
|
|
||||||
for s_att in snapshot.attachments:
|
if snapshot.attachments:
|
||||||
att_res = await self._process_media(
|
snap_tasks = [
|
||||||
media_id=s_att.id,
|
self._process_media(
|
||||||
url=s_att.url,
|
media_id=s_att.id,
|
||||||
filename=s_att.filename,
|
url=s_att.url,
|
||||||
size=s_att.size,
|
filename=s_att.filename,
|
||||||
content_type=s_att.content_type,
|
size=s_att.size,
|
||||||
save_method=s_att.save
|
content_type=s_att.content_type,
|
||||||
)
|
save_method=s_att.save
|
||||||
if att_res: attachments.append(att_res)
|
)
|
||||||
|
for s_att in snapshot.attachments
|
||||||
|
]
|
||||||
|
snap_results = await asyncio.gather(*snap_tasks)
|
||||||
|
attachments.extend(r for r in snap_results if r)
|
||||||
|
|
||||||
for s_emb in snapshot.embeds:
|
for s_emb in snapshot.embeds:
|
||||||
embeds.append(s_emb.to_dict())
|
embeds.append(s_emb.to_dict())
|
||||||
|
|
@ -623,14 +655,16 @@ class DiscordExporter:
|
||||||
try: tmp.close()
|
try: tmp.close()
|
||||||
except: pass
|
except: pass
|
||||||
|
|
||||||
file_hash = self._calculate_sha256(tmp_path)
|
# Offload CPU-bound hashing and blocking file ops to the thread pool
|
||||||
actual_size = tmp_path.stat().st_size
|
# so we don't stall concurrent downloads on the event loop.
|
||||||
|
file_hash = await asyncio.to_thread(self._calculate_sha256, tmp_path)
|
||||||
|
actual_size = (await asyncio.to_thread(tmp_path.stat)).st_size
|
||||||
|
|
||||||
# Check if hash already exists in pool
|
# Check if hash already exists in pool
|
||||||
if self.db:
|
if self.db:
|
||||||
in_pool = self.db.get_media_by_hash(file_hash)
|
in_pool = self.db.get_media_by_hash(file_hash)
|
||||||
if in_pool:
|
if in_pool:
|
||||||
tmp_path.unlink()
|
await asyncio.to_thread(tmp_path.unlink)
|
||||||
return {
|
return {
|
||||||
"id": str(media_id),
|
"id": str(media_id),
|
||||||
"filename": filename,
|
"filename": filename,
|
||||||
|
|
@ -645,8 +679,11 @@ class DiscordExporter:
|
||||||
target_filename = f"{file_hash}{ext}"
|
target_filename = f"{file_hash}{ext}"
|
||||||
target_path = self.attachments_path / target_filename
|
target_path = self.attachments_path / target_filename
|
||||||
|
|
||||||
shutil.move(str(tmp_path), str(target_path))
|
await asyncio.to_thread(shutil.move, str(tmp_path), str(target_path))
|
||||||
|
|
||||||
|
# Mark as successfully moved so finally block doesn't delete it
|
||||||
|
tmp_path = None
|
||||||
|
|
||||||
if self.db:
|
if self.db:
|
||||||
self.db.add_media_to_pool(file_hash, f"attachments/{target_filename}", actual_size, content_type, str(url))
|
self.db.add_media_to_pool(file_hash, f"attachments/{target_filename}", actual_size, content_type, str(url))
|
||||||
|
|
||||||
|
|
@ -658,9 +695,14 @@ class DiscordExporter:
|
||||||
"content_type": content_type,
|
"content_type": content_type,
|
||||||
"local_hash": file_hash
|
"local_hash": file_hash
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except BaseException as e:
|
||||||
logger.error(f"Failed to process media {filename}: {e}")
|
if not isinstance(e, asyncio.CancelledError):
|
||||||
if tmp_path and tmp_path.exists(): tmp_path.unlink()
|
logger.error(f"Failed to process media {filename}: {e}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if tmp_path and tmp_path.exists():
|
||||||
|
try: tmp_path.unlink()
|
||||||
|
except: pass
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def export_threads(self, channel_id: int, progress_callback=None, force=False, accumulated_count=0, accumulated_threads=0, accumulated_files=0, after_id: int | None = None):
|
async def export_threads(self, channel_id: int, progress_callback=None, force=False, accumulated_count=0, accumulated_threads=0, accumulated_files=0, after_id: int | None = None):
|
||||||
|
|
@ -750,25 +792,35 @@ class DiscordExporter:
|
||||||
})
|
})
|
||||||
self.db.save_threads(thread_meta)
|
self.db.save_threads(thread_meta)
|
||||||
|
|
||||||
for thread in all_threads:
|
# Export threads concurrently — semaphore limits to 5 at a time to
|
||||||
if not self.is_running: break
|
# avoid flooding Discord's rate limiter.
|
||||||
await asyncio.sleep(0)
|
sem = asyncio.Semaphore(5)
|
||||||
|
|
||||||
accumulated_threads += 1
|
async def _export_one_thread(thread, t_idx):
|
||||||
|
async with sem:
|
||||||
|
if not self.is_running:
|
||||||
|
return 0, 0, 0
|
||||||
|
cnt, thr, fls = await self.export_channel_messages(
|
||||||
|
thread.id,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
force=force,
|
||||||
|
accumulated_count=0,
|
||||||
|
accumulated_threads=0,
|
||||||
|
accumulated_files=0,
|
||||||
|
after_id=after_id
|
||||||
|
)
|
||||||
|
return cnt, thr, fls
|
||||||
|
|
||||||
|
if all_threads:
|
||||||
|
thread_results = await asyncio.gather(
|
||||||
|
*[_export_one_thread(t, i) for i, t in enumerate(all_threads)]
|
||||||
|
)
|
||||||
|
for cnt, thr, fls in thread_results:
|
||||||
|
accumulated_count += cnt
|
||||||
|
accumulated_threads += 1 + thr # +1 for the thread itself
|
||||||
|
accumulated_files += fls
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback(channel.name, accumulated_count, thread_count=accumulated_threads, file_count=accumulated_files)
|
await progress_callback(channel.name, accumulated_count, thread_count=accumulated_threads, file_count=accumulated_files)
|
||||||
|
|
||||||
# Backup thread messages
|
|
||||||
accumulated_count, accumulated_threads, accumulated_files = await self.export_channel_messages(thread.id, progress_callback=progress_callback, force=force, accumulated_count=accumulated_count, accumulated_threads=accumulated_threads, accumulated_files=accumulated_files, after_id=after_id)
|
|
||||||
|
|
||||||
# For forums, ensure the starter message exists in the DB
|
|
||||||
if is_forum:
|
|
||||||
# starter_message is handled by export_channel_messages since it's just a message in that thread
|
|
||||||
# However we may want to mark it or store forum-specific tags
|
|
||||||
try:
|
|
||||||
# Just yield for concurrency
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing forum thread {thread.name}: {e}")
|
|
||||||
|
|
||||||
return accumulated_count, accumulated_threads, accumulated_files
|
return accumulated_count, accumulated_threads, accumulated_files
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue