"""Export of Discord server data (metadata, assets, channels, messages) to a local backup."""

import asyncio
import hashlib
import json
import logging
import os
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List, Optional, AsyncGenerator

import discord

from src.core.backup_database import BackupDatabase

logger = logging.getLogger(__name__)


class DiscordExporter:
    """Core logic for exporting Discord server data.

    The exporter pulls data through ``reader`` (an object exposing the
    ``get_*``/``download_*``/``fetch_message_history`` coroutines used below)
    and persists it into a ``BackupDatabase`` plus a directory tree rooted at
    ``<base_dir>/DISCORD_BACKUP-<server id>``.

    ``setup()`` must be awaited before any method that touches
    ``export_path``, ``assets_path``, ``users_path``, ``attachments_path``
    or ``db``; those attributes are only created there.
    """

    # Sticker file extension -> MIME type; anything else is treated as PNG.
    _STICKER_MIME_TYPES = {
        "json": "application/json",
        "gif": "image/gif",
        "webp": "image/webp",
    }

    # Number of raw messages formatted and persisted per DB round trip.
    _MESSAGE_BATCH_SIZE = 100

    def __init__(self, reader, base_dir: Path | str = ""):
        """Store the reader and the base output directory.

        Args:
            reader: Data source used for every Discord call.
            base_dir: Directory under which the backup folder is created;
                an empty value means the current working directory.
        """
        self.reader = reader
        self.server_name = ""
        self.server_id = ""
        # user id (str) -> user row already emitted during this session
        self.user_cache = {}
        self.base_dir = Path(base_dir) if base_dir else Path(".")
        # Cooperative cancellation flag checked inside the export loops.
        self.is_running = True
        self.db: Optional[BackupDatabase] = None
        self.sticker_cache: Dict[int, bytes] = {}  # Deduplicate downloads in one session

    async def setup(self):
        """Prepare the output directories, open the database and return server metadata."""
        metadata = await self.reader.get_server_metadata()
        self.server_name = metadata.get("name", "Unknown Server")
        self.server_id = metadata.get("id", "0")

        # Root export path: DISCORD_BACKUP-{id}
        self.export_path = self.base_dir / f"DISCORD_BACKUP-{self.server_id}"
        self.export_path.mkdir(parents=True, exist_ok=True)

        # Sub-directories of the backup layout.
        self.assets_path = self.export_path / "server_assets"
        self.assets_path.mkdir(exist_ok=True)

        self.users_path = self.export_path / "users"
        self.users_path.mkdir(exist_ok=True)

        self.attachments_path = self.export_path / "attachments"
        self.attachments_path.mkdir(exist_ok=True)

        # Initialize SQLite database
        db_path = self.export_path / "backup.db"
        self.db = BackupDatabase(db_path)

        logger.info(f"Export directory set to: {self.export_path}")
        logger.info(f"Initialized backup database at {db_path}")
        return metadata

    def _calculate_sha256(self, file_path: Path) -> str:
        """Return the hex SHA-256 digest of the file at ``file_path``."""
        digest = hashlib.sha256()
        with open(file_path, "rb") as f:
            while chunk := f.read(65536):
                digest.update(chunk)
        return digest.hexdigest()

    async def export_metadata(self):
        """Save server metadata to the database and return the metadata dict.

        Adds local/remote icon and banner references (when the reader has a
        guild), a ``last_backup`` timestamp, and carries over any
        ``ignore_channels`` value already stored in the guild profile.
        """
        metadata = await self.reader.get_server_metadata()

        # Relative paths to local assets for the UI/Reader
        guild = self.reader.guild
        if guild:
            if guild.icon:
                ext = "gif" if guild.icon.is_animated() else "png"
                metadata["icon_file"] = f"server_assets/server_icon.{ext}"
                metadata["icon_url"] = str(guild.icon.url)
            else:
                metadata["icon_file"] = None
                metadata["icon_url"] = None

            if guild.banner:
                ext = "gif" if guild.banner.is_animated() else "png"
                metadata["banner_file"] = f"server_assets/server_banner.{ext}"
                metadata["banner_url"] = str(guild.banner.url)
            else:
                metadata["banner_file"] = None
                metadata["banner_url"] = None

        metadata["last_backup"] = datetime.now().isoformat()

        if self.db:
            # Preserve a previously stored ignore list instead of resetting it.
            existing_profile = self.db.get_guild_profile()
            if existing_profile and "ignore_channels" in existing_profile:
                metadata["ignore_channels"] = existing_profile["ignore_channels"]
            else:
                metadata["ignore_channels"] = []  # Initialize if not present
            self.db.set_guild_profile(metadata)

        return metadata

    async def export_roles(self):
        """Export all roles to the database and return the list of role rows."""
        roles = await self.reader.get_roles()
        role_data = [
            {
                "id": str(r.id),
                "name": r.name,
                "color": r.color.value,
                "position": r.position,
                "permissions": str(r.permissions.value),
                "hoist": 1 if r.hoist else 0,
                "mentionable": 1 if r.mentionable else 0,
            }
            for r in roles
        ]

        if self.db:
            self.db.save_roles(role_data)

        return role_data

    async def _download_guild_asset(self, asset, label: str) -> None:
        """Download one guild asset to ``server_assets/server_<label>.<ext>``.

        Failures are logged and swallowed so a missing icon/banner never
        aborts the backup.
        """
        try:
            logger.info(f"Downloading server {label}: {asset.url}")
            data = await self.reader.download_asset(asset)
            ext = "gif" if asset.is_animated() else "png"
            target = self.assets_path / f"server_{label}.{ext}"
            with open(target, "wb") as f:
                f.write(data)
            logger.info(f"Saved server {label} to {target}")
        except Exception as e:
            logger.error(f"Failed to download server {label}: {e}")

    async def download_server_assets(self):
        """Download the server icon and banner into the ``server_assets`` folder."""
        metadata = await self.reader.get_server_metadata()
        guild = self.reader.guild

        # An asset is fetched only when the metadata advertises a URL AND the
        # guild object actually carries it.
        if metadata.get("icon_url") and guild and guild.icon:
            await self._download_guild_asset(guild.icon, "icon")

        if metadata.get("banner_url") and guild and guild.banner:
            await self._download_guild_asset(guild.banner, "banner")

    async def export_assets(self):
        """Export the icon, banner, emojis and stickers to ``server_assets/``.

        Returns:
            ``(emoji_count, sticker_count)`` for the assets recorded.
        """
        await self.download_server_assets()

        emojis = await self.reader.get_emojis()
        stickers = await self.reader.get_stickers()

        emoji_data = []
        logger.info(f"Exporting {len(emojis)} emojis...")
        for e in emojis:
            ext = "gif" if e.animated else "png"
            filename = f"emoji_{e.id}.{ext}"
            emoji_path = self.assets_path / filename

            try:
                # Files already on disk from an earlier run are not re-downloaded.
                if not emoji_path.exists():
                    data = await self.reader.download_emoji(e)
                    with open(emoji_path, "wb") as f:
                        f.write(data)

                emoji_data.append({
                    "id": str(e.id),
                    "name": e.name,
                    "type": "emoji",
                    "filename": filename,
                    "url": str(e.url),
                    "mime_type": "image/gif" if e.animated else "image/png",
                })
            except Exception as ex:
                logger.error(f"Failed to download emoji {e.name}: {ex}")

        sticker_data = []
        logger.info(f"Exporting {len(stickers)} stickers...")
        for s in stickers:
            ext = self.reader.get_sticker_extension(s)
            filename = f"sticker_{s.id}.{ext}"
            sticker_path = self.assets_path / filename

            try:
                if not sticker_path.exists():
                    data = await self.reader.download_sticker(s)
                    if data:
                        with open(sticker_path, "wb") as f:
                            f.write(data)

                sticker_data.append({
                    "id": str(s.id),
                    "name": getattr(s, "name", "unknown"),
                    "type": "sticker",
                    "filename": filename,
                    "url": str(s.url),
                    "mime_type": self._STICKER_MIME_TYPES.get(ext, "image/png"),
                })
            except Exception as ex:
                logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {ex}")

        # Save to database
        if self.db:
            all_assets = emoji_data + sticker_data
            if all_assets:
                self.db.save_server_assets(all_assets)

        return len(emoji_data), len(sticker_data)

    @staticmethod
    def _overwrite_rows(channel) -> List[Dict[str, Any]]:
        """Turn a channel's role/member permission overwrites into DB rows."""
        rows = []
        for target, ow in channel.overwrites.items():
            # Overwrite targets of any other kind are skipped.
            if isinstance(target, (discord.Role, discord.Member)):
                allow, deny = ow.pair()
                rows.append({
                    "channel_id": str(channel.id),
                    "target_id": str(target.id),
                    "target_type": "role" if isinstance(target, discord.Role) else "member",
                    "allow": allow.value,
                    "deny": deny.value,
                })
        return rows

    async def export_channels_structure(self):
        """Export the category/channel hierarchy, permissions and forum tags.

        Returns:
            ``(channel_rows, category_count, channel_count)`` where
            ``channel_rows`` contains categories followed by their children,
            then uncategorized channels.
        """
        categories = await self.reader.get_categories()
        channels = await self.reader.get_channels()

        db_channels = []
        db_permissions = []
        db_forum_tags = []

        chan_count = 0
        cat_count = len(categories)

        for cat in categories:
            cat_channels = [c for c in channels if c.category_id == cat.id]
            formatted_channels, cat_chan_perms, cat_forum_tags = await self._process_channel_batch(cat_channels)
            chan_count += len(formatted_channels)
            db_permissions.extend(cat_chan_perms)
            db_forum_tags.extend(cat_forum_tags)

            # Category permissions
            db_permissions.extend(self._overwrite_rows(cat))

            db_channels.append({
                "id": str(cat.id),
                "name": cat.name,
                "type": int(cat.type.value) if hasattr(cat.type, "value") else 4,
                "position": cat.position,
                "category_id": None,
                "topic": None,
                "nsfw": 0,
            })

            # Add child channels to list
            for fc in formatted_channels:
                fc["category_id"] = str(cat.id)
                db_channels.append(fc)

        # Uncategorized
        uncategorized = [c for c in channels if not c.category_id]
        if uncategorized:
            formatted_uncat, uncat_perms, uncat_forum_tags = await self._process_channel_batch(uncategorized)
            chan_count += len(formatted_uncat)
            db_permissions.extend(uncat_perms)
            db_forum_tags.extend(uncat_forum_tags)
            for fc in formatted_uncat:
                fc["category_id"] = None
                db_channels.append(fc)

        if self.db:
            self.db.save_channels(db_channels)
            if db_permissions:
                self.db.save_permissions(db_permissions)
            if db_forum_tags:
                self.db.save_forum_tags(db_forum_tags)

        return db_channels, cat_count, chan_count

    async def _process_channel_batch(self, channels):
        """Format a batch of channels.

        Returns:
            ``(channel_rows, permission_rows, forum_tag_rows)`` flattened
            across the whole batch, in input order.
        """
        results = await asyncio.gather(*[self._format_channel(c) for c in channels])

        formatted = []
        permissions = []
        forum_tags = []
        for f_data, f_perms, f_tags in results:
            formatted.append(f_data)
            permissions.extend(f_perms)
            if f_tags:
                forum_tags.extend(f_tags)

        return formatted, permissions, forum_tags

    async def _format_channel(self, c):
        """Build the DB row, permission rows and forum-tag rows for one channel."""
        ch_permissions = self._overwrite_rows(c)

        ch_forum_tags = []
        if isinstance(c, discord.ForumChannel):
            for t in c.available_tags:
                ch_forum_tags.append({
                    "id": str(t.id),
                    "forum_id": str(c.id),
                    "name": t.name,
                    "moderated": 1 if t.moderated else 0,
                    "emoji_id": str(t.emoji.id) if t.emoji and hasattr(t.emoji, "id") else None,
                    "emoji_name": t.emoji.name if t.emoji else None,
                })

        data = {
            "id": str(c.id),
            "name": c.name,
            "type": int(c.type.value) if hasattr(c.type, "value") else 0,
            "position": c.position,
            "topic": getattr(c, "topic", None),
            "nsfw": 1 if getattr(c, "nsfw", False) else 0,
        }
        return data, ch_permissions, ch_forum_tags

    async def _flush_message_batch(self, batch_raw, channel_name, progress_callback,
                                   accumulated_count, accumulated_threads, accumulated_files):
        """Format, persist and report one batch of raw messages.

        Returns:
            The updated ``(accumulated_count, accumulated_files)``. Counters
            are only advanced once the batch has been written, so a failed
            write is not counted.
        """
        results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
        messages = [m_data for m_data, _ in results]
        users = [u_data for _, u_data in results if u_data]

        # Persist to DB
        if self.db:
            if users:
                self.db.save_users(users)
            self.db.save_messages_batch(messages)

        accumulated_count += len(messages)
        accumulated_files += sum(len(m.get("attachments", ())) for m in messages)

        if progress_callback:
            last_msg = batch_raw[-1]
            author_name = getattr(last_msg.author, "display_name", "Unknown")
            preview = (last_msg.content or "")[:150]
            await progress_callback(
                channel_name,
                accumulated_count,
                author_name=author_name,
                message_preview=preview,
                thread_count=accumulated_threads,
                file_count=accumulated_files,
            )

        return accumulated_count, accumulated_files

    async def export_channel_messages(self, channel_id: int, progress_callback=None, force=False,
                                      accumulated_count=0, accumulated_threads=0,
                                      accumulated_files=0, after_id: int | None = None):
        """Fetch and save a channel's message history, handling incremental sync.

        Args:
            channel_id: Channel (or thread) to export.
            progress_callback: Optional coroutine called after every persisted
                batch with the channel name, running message count and the
                ``author_name``/``message_preview``/``thread_count``/
                ``file_count`` keyword arguments.
            force: When true, ignore the last message id stored in the DB.
            accumulated_count / accumulated_threads / accumulated_files:
                Running totals carried across channels and threads.
            after_id: Explicit message id to resume after; takes precedence
                over the stored sync point.

        Returns:
            The updated ``(accumulated_count, accumulated_threads,
            accumulated_files)`` totals. For non-thread channels the
            channel's threads are exported as well.
        """
        channel = await self.reader.get_channel(channel_id)
        if not channel:
            logger.error(f"Channel not found: {channel_id}")
            return accumulated_count, accumulated_threads, accumulated_files

        channel_name = channel.name
        is_thread = isinstance(channel, discord.Thread)

        # 1. Determine incremental sync point
        last_id = after_id
        if not force and last_id is None and self.db:
            stored_last_id = self.db.get_last_message_id(channel_id)
            if stored_last_id:
                last_id = int(stored_last_id)
                logger.info(f"Incremental sync for {channel_name}: starting after {last_id}")

        try:
            batch_raw = []
            async for msg in self.reader.fetch_message_history(channel_id, after_id=last_id):
                if not self.is_running:
                    break

                batch_raw.append(msg)
                if len(batch_raw) >= self._MESSAGE_BATCH_SIZE:
                    accumulated_count, accumulated_files = await self._flush_message_batch(
                        batch_raw, channel_name, progress_callback,
                        accumulated_count, accumulated_threads, accumulated_files,
                    )
                    batch_raw = []

            # Final partial batch (skipped when the export was cancelled)
            if batch_raw and self.is_running:
                accumulated_count, accumulated_files = await self._flush_message_batch(
                    batch_raw, channel_name, progress_callback,
                    accumulated_count, accumulated_threads, accumulated_files,
                )
        except discord.Forbidden:
            logger.error(f"403 Forbidden: Missing Access to read messages in {channel_name} ({channel_id})")
        except Exception as e:
            logger.error(f"Error fetching messages for {channel_name}: {e}")

        # Threads cannot contain threads, so only recurse from regular channels.
        if not is_thread:
            accumulated_count, accumulated_threads, accumulated_files = await self.export_threads(
                channel_id,
                progress_callback=progress_callback,
                force=force,
                accumulated_count=accumulated_count,
                accumulated_threads=accumulated_threads,
                accumulated_files=accumulated_files,
            )

        return accumulated_count, accumulated_threads, accumulated_files

    async def _process_attachment(self, att):
        """Run one attachment-like object through the media pool."""
        return await self._process_media(
            media_id=att.id,
            url=att.url,
            filename=att.filename,
            size=att.size,
            content_type=att.content_type,
            save_method=att.save,
        )

    async def _format_message(self, msg):
        """Format a single message and its author for DB storage.

        Returns:
            ``(message_row, user_row)``; ``user_row`` is ``None`` when the
            author was already seen in this session.
        """
        # 1. Author handling
        author = msg.author
        user_id = str(author.id)

        user_data = None
        if user_id not in self.user_cache:
            # New user discovered
            roles = []
            if hasattr(author, "roles"):
                roles = [str(r.id) for r in author.roles if not r.is_default()]

            user_data = {
                "id": user_id,
                "username": author.name,
                "display_name": getattr(author, "display_name", author.name),
                "avatar_file": None,
                "avatar_url": str(author.display_avatar.url) if author.avatar else None,
                "roles": json.dumps(roles),
            }
            # Reserve the cache slot BEFORE the first await: messages are
            # formatted concurrently via asyncio.gather, and without this
            # every message by the same new author would download the avatar
            # to the same path and emit a duplicate user row.
            self.user_cache[user_id] = user_data

            if author.avatar:
                try:
                    av_name = f"{user_id}.png"
                    av_target = self.users_path / av_name
                    if not av_target.exists():
                        await author.avatar.save(av_target)
                    user_data["avatar_file"] = f"users/{av_name}"
                except Exception as e:
                    logger.error(f"Failed to save avatar for {author.name}: {e}")

        # 2. Attachments handling (Content-Addressable Storage)
        attachments = []
        for att in msg.attachments or ():
            att_data = await self._process_attachment(att)
            if att_data:
                attachments.append(att_data)

        # 2.5 Stickers handling
        stickers = []
        for st in msg.stickers or ():
            # Deduplicate downloads for the same sticker in one session
            st_bytes = self.sticker_cache.get(st.id)
            if st_bytes is None:
                st_bytes = await self.reader.download_sticker(st)
                if st_bytes:
                    self.sticker_cache[st.id] = st_bytes

            if not st_bytes:
                logger.warning(f"Could not download message sticker {st.id} in message {msg.id}")
                continue

            ext = self.reader.get_sticker_extension(st)
            st_data = await self._process_media(
                media_id=st.id,
                url=st.url,
                filename=f"{st.name}.{ext}",
                content_type=f"image/{ext}" if ext != "json" else "application/json",
                data=st_bytes,
            )
            if st_data:
                st_data["name"] = st.name
                st_data["format_type"] = (
                    int(st.format.value)
                    if hasattr(st, "format") and hasattr(st.format, "value")
                    else 1
                )
                stickers.append(st_data)

        # 3. Embeds
        embeds = [emb.to_dict() for emb in msg.embeds] if msg.embeds else []

        # 4. Reactions
        reactions = []
        for react in msg.reactions or ():
            emoji = react.emoji
            reactions.append({
                "emoji_id": emoji.id if hasattr(emoji, "id") else None,
                "emoji_name": emoji.name if hasattr(emoji, "name") else str(emoji),
                "count": react.count,
            })

        # 5. Message data
        # Check for reference (reply/origin of forward)
        message_reference = None
        if msg.reference and msg.reference.message_id:
            message_reference = str(msg.reference.message_id)

        # 5.5 Forwarded snapshots
        content = msg.content or ""
        msg_type = int(msg.type.value) if hasattr(msg.type, "value") else 0

        # Detect if this message is forwarded (discord.py 2.5+)
        is_forwarded = getattr(msg.flags, 'forwarded', False)
        if is_forwarded and hasattr(msg, 'message_snapshots') and msg.message_snapshots:
            msg_type = 100  # Custom Forward type
            # Only the first snapshot is used.
            snapshot = msg.message_snapshots[0]
            if snapshot.content:
                content = snapshot.content

            # Process snapshot attachments into main attachments list
            for s_att in snapshot.attachments:
                att_res = await self._process_attachment(s_att)
                if att_res:
                    attachments.append(att_res)

            # Process snapshot embeds (simplified): reuse the main embeds list
            embeds.extend(s_emb.to_dict() for s_emb in snapshot.embeds)

        m_data = {
            "id": str(msg.id),
            "channel_id": str(msg.channel.id),
            "author_id": user_id,
            "content": content,
            "timestamp": msg.created_at.isoformat(),
            "type": msg_type,
            "message_reference": message_reference,
            "is_pinned": 1 if msg.pinned else 0,
            "attachments": attachments,
            "stickers": stickers,
            "embeds": embeds,
            "reactions": reactions,
            "extra_data": None,
        }

        return m_data, user_data

    async def _process_media(self, media_id, url, filename, size=None, content_type=None,
                             save_method=None, data=None):
        """Download and deduplicate a media item using SHA-256 (CAS).

        The content comes either from ``data`` (bytes already in memory) or
        from ``save_method`` (a coroutine that writes the media to a given
        path). New content is stored as ``attachments/<sha256><ext>``.

        Args:
            media_id: Discord id of the attachment/sticker.
            url: Source URL, also used as a lookup key in the DB.
            filename: Original filename; its suffix is kept for the pool file.
            size: Unused; the size is measured from the downloaded file.
            content_type: MIME type reported by the caller, if any.
            save_method: Coroutine ``save_method(path)`` writing the content.
            data: Raw content bytes; takes precedence over ``save_method``.

        Returns:
            A dict describing the media (including ``local_hash``), or
            ``None`` when there is no content source or processing failed.
        """
        url_str = str(url)

        # 1. First check by URL in DB
        if self.db:
            existing = self.db.get_media_by_url(url_str)
            if existing:
                return {
                    "id": str(media_id),
                    "filename": filename,
                    "size": existing["size"],
                    "url": url_str,
                    "content_type": existing["content_type"],
                    "local_hash": existing["hash"],
                }

        # Nothing to download from: bail out BEFORE creating a temp file so
        # no orphan is left behind in the system temp directory.
        if not data and not save_method:
            return None

        # 2. Temporary download to calculate hash
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                tmp_path = Path(tmp.name)
                if data:
                    tmp.write(data)

            # The handle is closed here, so save_method can open the path
            # itself (required on platforms with exclusive file locks).
            if not data:
                await save_method(tmp_path)

            file_hash = self._calculate_sha256(tmp_path)
            actual_size = tmp_path.stat().st_size

            # Check if hash already exists in pool
            if self.db:
                in_pool = self.db.get_media_by_hash(file_hash)
                if in_pool:
                    tmp_path.unlink()
                    return {
                        "id": str(media_id),
                        "filename": filename,
                        "size": actual_size,
                        "url": url_str,
                        "content_type": content_type or in_pool["content_type"],
                        "local_hash": file_hash,
                    }

            # New content: move to pool
            ext = Path(filename).suffix
            target_filename = f"{file_hash}{ext}"
            target_path = self.attachments_path / target_filename
            shutil.move(str(tmp_path), str(target_path))

            if self.db:
                self.db.add_media_to_pool(
                    file_hash, f"attachments/{target_filename}", actual_size, content_type, url_str
                )

            return {
                "id": str(media_id),
                "filename": filename,
                "size": actual_size,
                "url": url_str,
                "content_type": content_type,
                "local_hash": file_hash,
            }
        except Exception as e:
            logger.error(f"Failed to process media {filename}: {e}")
            if tmp_path is not None:
                try:
                    tmp_path.unlink(missing_ok=True)
                except OSError as cleanup_error:
                    logger.debug(f"Could not remove temp file {tmp_path}: {cleanup_error}")
            return None

    async def _resolve_applied_tags(self, thread, channel, is_forum: bool) -> List[str]:
        """Return the forum tag ids applied to ``thread`` as strings.

        Tries, in order: the public ``applied_tags`` attribute, the private
        ``_applied_tags`` id list, and (forum threads only) a fresh fetch of
        the thread.
        """
        # If parent is missing, try to link it from our fetched channel object.
        # Internal hack intended to help discord.py resolve applied_tags;
        # best-effort, so a failure is only logged.
        try:
            if thread.parent is None:
                thread._parent = channel
        except Exception as e:
            logger.debug(f"Could not link parent for thread {thread.id}: {e}")

        applied_tags = []

        # Attempt 1: Standard attribute
        if hasattr(thread, "applied_tags"):
            applied_tags = [str(tag.id) for tag in thread.applied_tags]

        # Attempt 2: Internal list of IDs if available
        if not applied_tags and hasattr(thread, "_applied_tags"):
            raw_ids = getattr(thread, "_applied_tags", [])
            if raw_ids:
                logger.info(f"Thread '{thread.name}' ({thread.id}) found raw tags in _applied_tags: {raw_ids}")
                applied_tags = [str(tid) for tid in raw_ids]

        # Attempt 3: If still empty and it's a forum thread, fetch it specifically
        if not applied_tags and is_forum:
            try:
                fetched_t = await self.reader.client.fetch_channel(thread.id)
                if hasattr(fetched_t, "applied_tags"):
                    applied_tags = [str(tag.id) for tag in fetched_t.applied_tags]
                if not applied_tags and hasattr(fetched_t, "_applied_tags"):
                    raw_ids = getattr(fetched_t, "_applied_tags", None) or []
                    applied_tags = [str(tid) for tid in raw_ids]
            except Exception as e:
                logger.debug(f"Failed to fetch thread {thread.id} for tags: {e}")

        if not applied_tags and is_forum:
            logger.warning(
                f"Thread '{thread.name}' ({thread.id}) is in forum '{channel.name}' but NO tags found (tried all methods)"
            )

        return applied_tags

    async def export_threads(self, channel_id: int, progress_callback=None, force=False,
                             accumulated_count=0, accumulated_threads=0, accumulated_files=0,
                             after_id: int | None = None):
        """Export active and archived threads of a channel, including their messages.

        Thread metadata is saved to the DB (when one is open), then each
        thread's messages are exported through ``export_channel_messages``.

        Returns:
            The updated ``(accumulated_count, accumulated_threads,
            accumulated_files)`` totals.
        """
        channel = await self.reader.get_channel(channel_id)
        # Also covers a missing channel: None has neither attribute.
        if not hasattr(channel, "threads") and not hasattr(channel, "archived_threads"):
            return accumulated_count, accumulated_threads, accumulated_files

        all_threads = []
        try:
            # Active threads
            if self.reader.guild:
                threads = await self.reader.guild.active_threads()
                all_threads.extend([t for t in threads if t.parent_id == channel_id])

            # Archived threads
            try:
                if hasattr(channel, "archived_threads"):
                    async for thread in channel.archived_threads(limit=None):
                        all_threads.append(thread)
            except discord.Forbidden:
                logger.warning(f"403 Forbidden: Cannot fetch archived threads in {channel.name}")
            except Exception as e:
                logger.warning(f"Error fetching archived threads: {e}")
        except Exception as e:
            logger.error(f"Failed to fetch threads for {channel.name}: {e}")

        is_forum = isinstance(channel, discord.ForumChannel)
        logger.debug(
            f"Exporting threads for channel '{channel.name}' ({channel.id}) [Type: {type(channel)}] [Is Forum: {is_forum}]"
        )

        if all_threads and self.db:
            thread_meta = []
            for t in all_threads:
                applied_tags = await self._resolve_applied_tags(t, channel, is_forum)
                thread_meta.append({
                    "id": str(t.id),
                    "name": t.name,
                    "type": int(t.type.value) if hasattr(t.type, "value") else 11,  # Default to public_thread
                    "parent_id": str(t.parent_id) if t.parent_id else str(channel.id),
                    "message_count": getattr(t, "message_count", 0),
                    "member_count": getattr(t, "member_count", 0),
                    "archived": 1 if t.archived else 0,
                    "archive_timestamp": t.archive_timestamp.isoformat() if t.archive_timestamp else None,
                    "auto_archive_duration": t.auto_archive_duration,
                    "locked": 1 if getattr(t, "locked", False) else 0,
                    "applied_tags": json.dumps(applied_tags),
                })
            self.db.save_threads(thread_meta)

        for thread in all_threads:
            if not self.is_running:
                break

            # Yield to the event loop between threads.
            await asyncio.sleep(0)

            accumulated_threads += 1
            if progress_callback:
                await progress_callback(
                    channel.name,
                    accumulated_count,
                    thread_count=accumulated_threads,
                    file_count=accumulated_files,
                )

            # Backup thread messages. A forum post's starter message is just
            # a message in its thread, so it is covered by this call.
            accumulated_count, accumulated_threads, accumulated_files = await self.export_channel_messages(
                thread.id,
                progress_callback=progress_callback,
                force=force,
                accumulated_count=accumulated_count,
                accumulated_threads=accumulated_threads,
                accumulated_files=accumulated_files,
                after_id=after_id,
            )

        return accumulated_count, accumulated_threads, accumulated_files