implement backup_reader to parse local backups
This commit is contained in:
parent
71bcd3c9bb
commit
eb89a6c453
4 changed files with 811 additions and 32 deletions
|
|
@ -0,0 +1,734 @@
|
|||
"""
|
||||
BackupReader — discord.py-compatible local data provider.
|
||||
|
||||
Reads from local backup JSON files (produced by DiscordExporter) instead of the
|
||||
Discord API. Implements the same public interface as DiscordReader so that
|
||||
migration scripts and UI code can use either provider transparently.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, Dict, Any, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lightweight enum clones (mirror discord.py values for compatibility)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ChannelType(IntEnum):
    """Channel kinds understood by the backup reader.

    Only the kinds the reader needs are listed; the integer values are kept
    identical to discord.py's ``ChannelType`` so comparisons stay compatible.
    """

    # Plain message channels
    text = 0
    news = 5

    # Non-message containers
    voice = 2
    category = 4

    # Thread-based channels
    public_thread = 11
    forum = 15
|
||||
|
||||
|
||||
class MessageType(IntEnum):
    """Message kinds the backup reader distinguishes.

    Integer values are kept identical to discord.py's ``MessageType``.
    """

    default = 0                   # ordinary message
    reply = 19                    # message replying to another message
    thread_starter_message = 21   # first message of a thread
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock colour / permissions helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MockColor:
    """Minimal stand-in for discord.Color.

    Wraps one integer colour value. Instances compare equal to other
    ``MockColor`` objects and to plain ints carrying the same value, and hash
    like that int, so they can be used in sets and as dict keys.
    """

    __slots__ = ("value",)

    def __init__(self, value: int = 0):
        self.value = value

    def __str__(self) -> str:
        # Zero-padded six-digit lowercase hex, e.g. "#0000ff".
        return f"#{self.value:06x}"

    def __repr__(self) -> str:
        return f"MockColor(value={self.value})"

    def __eq__(self, other):
        if isinstance(other, MockColor):
            return self.value == other.value
        if isinstance(other, int):
            return self.value == other
        return NotImplemented

    def __hash__(self) -> int:
        # Defining __eq__ alone sets __hash__ to None; hash like the wrapped
        # int so that "equal objects hash equal" also holds against ints.
        return hash(self.value)

    @classmethod
    def from_hex(cls, hex_str: str) -> "MockColor":
        """Parse '#rrggbb' or '0x...' to MockColor.

        Anything that cannot be parsed — including ``None`` or another
        non-string, as produced by a null colour in the backup JSON — yields
        ``MockColor(0)`` instead of raising.
        """
        if not isinstance(hex_str, str):
            return cls(0)
        try:
            return cls(int(hex_str.lstrip("#"), 16))
        except ValueError:
            return cls(0)
|
||||
|
||||
|
||||
class MockPermissions:
    """Lightweight replacement for discord.Permissions.

    Holds the raw permission bitfield in ``value`` and exposes the two flags
    the reader needs as read-only properties.
    """

    __slots__ = ("value",)

    # Bit masks of the flags exposed below.
    _VIEW_CHANNEL = 0x400
    _READ_MESSAGE_HISTORY = 0x10000

    def __init__(self, value: int = 0):
        self.value = value

    @property
    def view_channel(self) -> bool:
        """True when the view-channel bit is set."""
        return (self.value & self._VIEW_CHANNEL) != 0

    @property
    def read_message_history(self) -> bool:
        """True when the read-message-history bit is set."""
        return (self.value & self._READ_MESSAGE_HISTORY) != 0

    def __repr__(self) -> str:
        return f"MockPermissions(value={self.value})"
|
||||
|
||||
|
||||
class MockPermissionOverwrite:
    """Lightweight replacement for discord.PermissionOverwrite.

    Intentionally empty: it defines no fields of its own, and because it has
    no ``__slots__`` callers may set arbitrary attributes on an instance.
    """
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock discord.py model objects
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MockAsset:
|
||||
"""Minimal stand-in for discord.Asset. Points to a local file."""
|
||||
|
||||
__slots__ = ("_path", "url")
|
||||
|
||||
def __init__(self, local_path: Path | str | None):
|
||||
self._path = Path(local_path) if local_path else None
|
||||
self.url = str(local_path) if local_path else None
|
||||
|
||||
async def read(self) -> bytes:
|
||||
if self._path and self._path.exists():
|
||||
return self._path.read_bytes()
|
||||
return b""
|
||||
|
||||
def is_animated(self) -> bool:
|
||||
if self._path:
|
||||
return self._path.suffix.lower() in (".gif",)
|
||||
return False
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self._path is not None and self._path.exists()
|
||||
|
||||
|
||||
class MockRole:
    """Lightweight replacement for discord.Role built from one role record.

    ``id`` and ``name`` are required keys of the record; every other field
    falls back to a default. ``managed`` is always False — it is not read
    from the record.
    """

    __slots__ = ("id", "name", "color", "position", "permissions",
                 "hoist", "mentionable", "managed")

    def __init__(self, data: dict):
        self.id = int(data["id"])
        self.name = data["name"]

        colour_hex = data.get("color", "#000000")
        self.color = MockColor.from_hex(colour_hex)

        self.position = data.get("position", 0)

        permission_bits = int(data.get("permissions", 0))
        self.permissions = MockPermissions(permission_bits)

        self.hoist = data.get("hoist", False)
        self.mentionable = data.get("mentionable", False)
        self.managed = False

    def is_default(self) -> bool:
        """True for the role literally named ``@everyone``."""
        return self.name == "@everyone"

    def __repr__(self) -> str:
        return f"MockRole(id={self.id}, name='{self.name}')"
|
||||
|
||||
|
||||
class MockCategory:
    """Lightweight replacement for discord.CategoryChannel.

    A record whose ``id`` cannot be converted to an int gets id 0, which the
    reader treats as the 'uncategorized' sentinel.
    """

    __slots__ = ("id", "name", "position", "type")

    def __init__(self, data: dict):
        try:
            category_id = int(data["id"])
        except (ValueError, TypeError):
            category_id = 0  # 'uncategorized' sentinel
        self.id = category_id

        self.name = data["name"]
        self.position = data.get("position", 0)
        self.type = ChannelType.category

    def __repr__(self) -> str:
        return f"MockCategory(id={self.id}, name='{self.name}')"
|
||||
|
||||
|
||||
class MockChannel:
    """Lightweight replacement for discord.TextChannel / ForumChannel / VoiceChannel.

    The ``type`` string found in the backup is translated to a
    ``ChannelType``; a missing or unknown string is treated as a text
    channel. ``parent_id`` mirrors ``category_id``.
    """

    __slots__ = ("id", "name", "type", "position", "topic", "nsfw",
                 "category_id", "available_tags", "parent_id")

    # Backup type string -> ChannelType
    _TYPE_MAP = {
        "text": ChannelType.text,
        "voice": ChannelType.voice,
        "news": ChannelType.news,
        "forum": ChannelType.forum,
        "thread": ChannelType.public_thread,
    }

    def __init__(self, data: dict, category_id: int | None = None):
        self.id = int(data["id"])
        self.name = data["name"]

        type_name = data.get("type", "text")
        self.type = self._TYPE_MAP.get(type_name, ChannelType.text)

        self.position = data.get("position", 0)
        self.topic = data.get("topic")
        self.nsfw = data.get("nsfw", False)

        # Both attributes carry the owning category's id.
        self.category_id = category_id
        self.parent_id = category_id

        self.available_tags = data.get("available_tags", [])

    def __repr__(self) -> str:
        return f"MockChannel(id={self.id}, name='{self.name}', type={self.type})"
|
||||
|
||||
|
||||
class MockMember:
    """Lightweight replacement for discord.Member / discord.User.

    Built from one user record of the backup. ``guild_permissions`` is always
    an empty permission set, and ``avatar`` only points at a file when the
    record names an avatar and ``avatar_base`` is given.
    """

    __slots__ = ("id", "name", "display_name", "bot", "color",
                 "roles", "avatar", "guild_permissions")

    def __init__(self, data: dict, role_objects: list | None = None,
                 avatar_base: Path | None = None):
        self.id = int(data["userID"])
        self.name = data.get("username", "Unknown")

        # An empty or missing nickname falls back to the username.
        nickname = data.get("userNickname")
        self.display_name = nickname if nickname else self.name

        self.bot = data.get("userIsBot", False)

        colour_hex = data.get("userColor")
        self.color = MockColor.from_hex(colour_hex if colour_hex else "#000000")

        self.roles = role_objects if role_objects else []
        self.guild_permissions = MockPermissions(0)

        avatar_rel = data.get("userAvatar")
        has_avatar = avatar_rel and avatar_base
        self.avatar = MockAsset(avatar_base / avatar_rel) if has_avatar else MockAsset(None)

    def __repr__(self) -> str:
        return f"MockMember(id={self.id}, name='{self.name}')"
|
||||
|
||||
|
||||
class MockAttachment:
|
||||
"""Minimal stand-in for discord.Attachment."""
|
||||
|
||||
__slots__ = ("id", "filename", "size", "url", "proxy_url", "_backup_root")
|
||||
|
||||
def __init__(self, data: dict, backup_root: Path | None = None):
|
||||
self.id = int(data["id"])
|
||||
self.filename = data.get("fileName", "unknown")
|
||||
self.size = data.get("fileSizeBytes", 0)
|
||||
self.url = data.get("url", "")
|
||||
self.proxy_url = self.url
|
||||
self._backup_root = backup_root
|
||||
|
||||
async def read(self) -> bytes:
|
||||
if self._backup_root:
|
||||
full = self._backup_root / self.url
|
||||
if full.exists():
|
||||
return full.read_bytes()
|
||||
return b""
|
||||
|
||||
async def save(self, path) -> None:
|
||||
data = await self.read()
|
||||
Path(path).write_bytes(data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"MockAttachment(id={self.id}, filename='{self.filename}')"
|
||||
|
||||
|
||||
class MockEmoji:
|
||||
"""Minimal stand-in for discord.Emoji."""
|
||||
|
||||
__slots__ = ("id", "name", "animated", "url", "_file_path")
|
||||
|
||||
def __init__(self, data: dict, media_dir: Path | None = None):
|
||||
self.id = int(data["id"])
|
||||
self.name = data["name"]
|
||||
self.animated = data.get("animated", False)
|
||||
filename = data.get("filename", "")
|
||||
self._file_path = media_dir / filename if media_dir and filename else None
|
||||
self.url = str(self._file_path) if self._file_path else None
|
||||
|
||||
async def read(self) -> bytes:
|
||||
if self._file_path and self._file_path.exists():
|
||||
return self._file_path.read_bytes()
|
||||
return b""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"MockEmoji(id={self.id}, name='{self.name}')"
|
||||
|
||||
|
||||
class MockSticker:
|
||||
"""Minimal stand-in for discord.GuildSticker."""
|
||||
|
||||
__slots__ = ("id", "name", "url", "format", "_file_path")
|
||||
|
||||
def __init__(self, data: dict, media_dir: Path | None = None):
|
||||
self.id = int(data["id"])
|
||||
self.name = data["name"]
|
||||
filename = data.get("filename", "")
|
||||
self._file_path = media_dir / filename if media_dir and filename else None
|
||||
self.url = str(self._file_path) if self._file_path else None
|
||||
self.format = data.get("format", "png")
|
||||
|
||||
async def read(self) -> bytes:
|
||||
if self._file_path and self._file_path.exists():
|
||||
return self._file_path.read_bytes()
|
||||
return b""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"MockSticker(id={self.id}, name='{self.name}')"
|
||||
|
||||
|
||||
class MockPartialEmoji:
|
||||
"""Minimal stand-in for discord.PartialEmoji."""
|
||||
|
||||
__slots__ = ("name", "id", "animated")
|
||||
|
||||
def __init__(self, name: str, id: int | None = None, animated: bool = False):
|
||||
self.name = name
|
||||
self.id = id
|
||||
self.animated = animated
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.id:
|
||||
return f"<{'a' if self.animated else ''}:{self.name}:{self.id}>"
|
||||
return self.name
|
||||
|
||||
|
||||
class MockReaction:
    """Lightweight replacement for discord.Reaction.

    The ``emoji`` string of the record is parsed as ``name:id`` when it
    contains a colon and does not start with ``<``; in every other case
    (including a non-numeric id part) the whole string becomes the name.
    """

    __slots__ = ("emoji", "count")

    def __init__(self, data: dict):
        emoji_raw = data.get("emoji", "")
        self.count = data.get("count", 0)

        looks_like_pair = ":" in emoji_raw and not emoji_raw.startswith("<")
        if not looks_like_pair:
            self.emoji = MockPartialEmoji(name=emoji_raw)
            return

        # Split on the first colon only: "name:id".
        name_part, _, id_part = emoji_raw.partition(":")
        try:
            self.emoji = MockPartialEmoji(name=name_part, id=int(id_part))
        except (ValueError, IndexError):
            self.emoji = MockPartialEmoji(name=emoji_raw)

    def is_custom_emoji(self) -> bool:
        """True when an emoji id was parsed."""
        return self.emoji.id is not None
|
||||
|
||||
|
||||
class MockMessageReference:
    """Lightweight replacement for discord.MessageReference.

    Both ``messageId`` and ``channelId`` are required keys and are stored as
    ints.
    """

    __slots__ = ("message_id", "channel_id")

    def __init__(self, data: dict):
        referenced_message = data["messageId"]
        self.message_id = int(referenced_message)
        referenced_channel = data["channelId"]
        self.channel_id = int(referenced_channel)
|
||||
|
||||
|
||||
class MockThread:
|
||||
"""Minimal stand-in for discord.Thread metadata attached to a starter message."""
|
||||
|
||||
__slots__ = ("id", "name", "message_count", "archived",
|
||||
"auto_archive_duration", "locked", "parent_id")
|
||||
|
||||
def __init__(self, data: dict, parent_id: int | None = None):
|
||||
self.id = int(data["id"])
|
||||
self.name = data.get("name", "")
|
||||
self.message_count = data.get("messageCount", 0)
|
||||
self.archived = data.get("archived", False)
|
||||
self.auto_archive_duration = data.get("archiveDuration", 1440)
|
||||
self.locked = data.get("locked", False)
|
||||
self.parent_id = parent_id
|
||||
|
||||
|
||||
class MockMessage:
    """Minimal stand-in for discord.Message.

    Built from one message record of a channel backup file. Attachments,
    reactions, the reply reference and thread metadata are wrapped in their
    mock classes; embeds and stickers are kept as the raw lists from the
    record.
    """

    # Backup type string -> MessageType ("Forward" is treated as a default
    # message and flagged through ``flags.forwarded`` instead).
    _TYPE_MAP = {
        "Default": MessageType.default,
        "Reply": MessageType.reply,
        "ThreadStarter": MessageType.thread_starter_message,
        "Thread_starter_message": MessageType.thread_starter_message,
        "Forward": MessageType.default,
    }

    __slots__ = ("id", "type", "created_at", "pinned", "content", "author",
                 "attachments", "embeds", "stickers", "reactions",
                 "reference", "thread", "channel_id", "flags")

    def __init__(self, data: dict, *,
                 author: MockMember | None = None,
                 channel_id: int | None = None,
                 backup_root: Path | None = None):
        """Populate the message from *data*.

        Args:
            data: Message record; ``messageID`` is required.
            author: Already-resolved author object, stored as-is.
            channel_id: Id of the channel the record came from; also used as
                the parent id of any attached thread metadata.
            backup_root: Directory attachments are read from.
        """
        self.id = int(data["messageID"])
        self.type = self._TYPE_MAP.get(data.get("type", "Default"), MessageType.default)
        self.pinned = data.get("isPinned", False)
        self.content = data.get("content", "")
        self.author = author
        self.channel_id = channel_id

        # Timestamp
        self.created_at = self._parse_timestamp(data.get("timestamp"))

        # Attachments
        self.attachments = [
            MockAttachment(a, backup_root=backup_root)
            for a in data.get("attachments", [])
        ]

        # Embeds — store raw dicts (discord.py Embed.from_dict compatible)
        self.embeds = data.get("embeds", [])

        # Stickers — store raw dicts
        self.stickers = data.get("stickers", [])

        # Reactions
        self.reactions = [MockReaction(r) for r in data.get("reactions", [])]

        # Reference (replies)
        ref = data.get("reference")
        self.reference = MockMessageReference(ref) if ref else None

        # Thread info
        thread_data = data.get("thread")
        self.thread = MockThread(thread_data, parent_id=channel_id) if thread_data else None

        # Flags placeholder
        self.flags = type("Flags", (), {"forwarded": data.get("type") == "Forward"})()

    @staticmethod
    def _parse_timestamp(ts) -> datetime:
        """Parse an ISO-8601 timestamp, falling back to the current UTC time.

        The fallback is used when *ts* is empty, not a string, or cannot be
        parsed.
        """
        if ts:
            try:
                # datetime.fromisoformat() only accepts a trailing "Z" from
                # Python 3.11 on; rewrite it as an explicit UTC offset so
                # older interpreters do not silently fall back to "now".
                normalized = ts[:-1] + "+00:00" if ts.endswith("Z") else ts
                return datetime.fromisoformat(normalized)
            except (ValueError, TypeError, AttributeError):
                pass  # unparseable: use the fallback below
        return datetime.now(timezone.utc)

    def __repr__(self) -> str:
        return f"MockMessage(id={self.id}, author={self.author})"
|
||||
|
||||
|
||||
class MockGuild:
    """Lightweight replacement for discord.Guild.

    ``icon`` and ``banner`` are always MockAsset objects: they point below
    ``backup_path`` when the profile names a file and are empty assets
    otherwise.
    """

    __slots__ = ("id", "name", "icon", "banner")

    def __init__(self, data: dict, backup_path: Path):
        self.id = int(data["id"])
        self.name = data["name"]

        # The profile key and the attribute share the same name.
        for field in ("icon", "banner"):
            relative = data.get(field)
            if relative:
                asset = MockAsset(backup_path / relative)
            else:
                asset = MockAsset(None)
            setattr(self, field, asset)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom Forbidden exception for backup context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class BackupForbidden(Exception):
    """Signals that a requested resource is not present in the backup.

    Exposed by the reader as its ``Forbidden`` exception type.
    """
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BackupReader — main provider class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class BackupReader:
    """Reads from local backup files instead of the Discord API.

    Implements the same public interface as DiscordReader so that
    migration scripts and UI code can use either provider transparently.

    Call ``await start()`` after construction to load the server-level JSON
    files into memory. The ``get_*`` accessors for categories, channels,
    roles, emojis, stickers and members serve those cached objects; the
    message accessors read the per-channel JSON file on every call.
    """

    # -- Provider constants (same interface as DiscordReader) --
    MESSAGE_TYPE_DEFAULT = MessageType.default
    MESSAGE_TYPE_REPLY = MessageType.reply
    MESSAGE_TYPE_THREAD_STARTER = MessageType.thread_starter_message

    Forbidden = BackupForbidden

    CHANNEL_TYPE_TEXT = ChannelType.text
    CHANNEL_TYPE_NEWS = ChannelType.news
    CHANNEL_TYPE_FORUM = ChannelType.forum

    @staticmethod
    def find_item(iterable, **attrs):
        """Find first item in iterable matching all attrs.

        A missing attribute is treated as None. Returns None when nothing
        matches.
        """
        for item in iterable:
            if all(getattr(item, k, None) == v for k, v in attrs.items()):
                return item
        return None

    @staticmethod
    def create_permission_overwrite():
        """Factory for a PermissionOverwrite mock."""
        return MockPermissionOverwrite()

    def __init__(self, backup_path: str | Path):
        self.backup_path = Path(backup_path)
        self.guild: MockGuild | None = None
        self.role_map: Dict[int, str] = {}  # role id -> role name

        # Internal caches populated by start()
        self._categories: List[MockCategory] = []
        self._channels: List[MockChannel] = []
        self._roles: List[MockRole] = []
        self._emojis: List[MockEmoji] = []
        self._stickers: List[MockSticker] = []
        self._members: List[MockMember] = []
        self._member_map: Dict[int, MockMember] = {}

    # ── startup ──────────────────────────────────────────────────────────

    async def start(self):
        """Loads all JSON files from the backup directory into memory.

        Files that are absent are skipped; a missing ``server_profile.json``
        is logged and leaves ``guild`` as None. Errors while reading the
        profile, roles, structure or assets files propagate to the caller;
        only a failure in ``user_info.json`` is caught and logged.

        Safe to call more than once: every cache is rebuilt from scratch.
        """
        bp = self.backup_path

        # Start from empty caches so a repeated start() reloads the backup
        # instead of appending duplicate categories, channels and members.
        self.guild = None
        self.role_map = {}
        self._categories = []
        self._channels = []
        self._roles = []
        self._emojis = []
        self._stickers = []
        self._members = []
        self._member_map = {}

        # 1. Server profile -> MockGuild
        profile_file = bp / "server_profile.json"
        if profile_file.exists():
            profile = json.loads(profile_file.read_text(encoding="utf-8"))
            self.guild = MockGuild(profile, bp)
            logger.info(f"[Backup] Loaded server profile: {self.guild.name} ({self.guild.id})")
        else:
            logger.warning(f"[Backup] server_profile.json not found in {bp}")

        # 2. Roles
        roles_file = bp / "server_roles.json"
        if roles_file.exists():
            roles_data = json.loads(roles_file.read_text(encoding="utf-8"))
            self._roles = [MockRole(r) for r in roles_data]
            self.role_map = {r.id: r.name for r in self._roles}
            logger.info(f"[Backup] Loaded {len(self._roles)} roles")

        # 3. Structure -> categories + channels
        struct_file = bp / "server_structure.json"
        if struct_file.exists():
            structure = json.loads(struct_file.read_text(encoding="utf-8"))
            for cat_data in structure:
                cat = MockCategory(cat_data)
                is_real_category = cat.id != 0  # 0 is the 'uncategorized' sentinel
                if is_real_category:
                    self._categories.append(cat)

                # Channels under the sentinel get no category id.
                ch_cat_id = cat.id if is_real_category else None
                for ch_data in cat_data.get("channels", []):
                    self._channels.append(MockChannel(ch_data, category_id=ch_cat_id))

            logger.info(f"[Backup] Loaded {len(self._categories)} categories, "
                        f"{len(self._channels)} channels")

        # 4. Assets (emojis + stickers)
        assets_file = bp / "server_assets.json"
        media_dir = bp / "server_media"
        if assets_file.exists():
            assets = json.loads(assets_file.read_text(encoding="utf-8"))
            self._emojis = [MockEmoji(e, media_dir) for e in assets.get("emojis", [])]
            self._stickers = [MockSticker(s, media_dir) for s in assets.get("stickers", [])]
            logger.info(f"[Backup] Loaded {len(self._emojis)} emojis, "
                        f"{len(self._stickers)} stickers")

        # 5. Users
        user_info_file = bp / "message_backup" / "user_info.json"
        if user_info_file.exists():
            try:
                users = json.loads(user_info_file.read_text(encoding="utf-8"))
                backup_root = bp / "message_backup"
                for u in users:
                    user_role_ids = {int(r["id"]) for r in u.get("userRoles", [])}
                    role_objs = [r for r in self._roles if r.id in user_role_ids]
                    member = MockMember(u, role_objects=role_objs, avatar_base=backup_root)
                    self._members.append(member)
                    self._member_map[member.id] = member
                logger.info(f"[Backup] Loaded {len(self._members)} users")
            except Exception as e:
                # Best effort: messages can still be read, authors become stubs.
                logger.warning(f"[Backup] Failed to load user_info.json: {e}")

    # ── validation ───────────────────────────────────────────────────────

    async def validate(self) -> Dict[str, Any]:
        """Validates backup directory integrity.

        Returns a result dict with the keys ``token``, ``server``,
        ``bot_name``, ``server_name``, ``intents`` and ``permissions``.
        ``token`` and ``server`` become True only when the backup directory
        exists and ``server_profile.json`` could be read as a JSON object.
        """
        results = {
            "token": False,
            "server": False,
            "bot_name": None,
            "server_name": None,
            "intents": {"message_content": True},
            "permissions": {"view_channel": True, "read_message_history": True},
        }

        bp = self.backup_path
        if not bp.exists() or not bp.is_dir():
            return results

        profile = bp / "server_profile.json"
        if not profile.exists():
            return results

        try:
            data = json.loads(profile.read_text(encoding="utf-8"))
            server_name = data.get("name", "Unknown")
        except (OSError, ValueError, AttributeError) as e:
            # ValueError covers malformed JSON and bad encoding; AttributeError
            # covers a profile whose top level is not a JSON object.
            logger.warning(f"[Backup] Could not validate server_profile.json: {e}")
            return results

        # Only flip the flags once the profile is known to be usable.
        results["token"] = True
        results["server"] = True
        results["bot_name"] = "BackupReader"
        results["server_name"] = server_name
        return results

    # ── server metadata ──────────────────────────────────────────────────

    async def get_server_metadata(self) -> Dict[str, Any]:
        """Return name, id and icon/banner paths of the loaded server.

        Returns an empty dict when no server profile was loaded. The URL
        entries are None when the corresponding file does not exist.
        """
        if not self.guild:
            return {}
        return {
            "name": self.guild.name,
            "id": str(self.guild.id),
            "icon_url": self.guild.icon.url if self.guild.icon else None,
            "banner_url": self.guild.banner.url if self.guild.banner else None,
        }

    async def download_asset(self, asset: MockAsset) -> bytes:
        """Return the bytes of *asset*."""
        return await asset.read()

    # ── categories & channels ────────────────────────────────────────────

    async def get_categories(self) -> List[MockCategory]:
        """Return a copy of the loaded category list."""
        return list(self._categories)

    async def get_channels(self, category_id: int | None = None) -> List[MockChannel]:
        """Return the loaded non-category channels, optionally of one category."""
        return [
            c for c in self._channels
            if c.type != ChannelType.category
            and (category_id is None or c.category_id == category_id)
        ]

    async def get_channel(self, channel_id: int) -> MockChannel | None:
        """Return the loaded channel with this id, or None."""
        return next((c for c in self._channels if c.id == channel_id), None)

    # ── roles, emojis, stickers, members ─────────────────────────────────

    async def get_roles(self) -> List[MockRole]:
        """Return the loaded roles without the default ``@everyone`` role."""
        return [r for r in self._roles if not r.is_default()]

    async def get_emojis(self) -> List[MockEmoji]:
        """Return a copy of the loaded emoji list."""
        return list(self._emojis)

    async def get_stickers(self) -> List[MockSticker]:
        """Return a copy of the loaded sticker list."""
        return list(self._stickers)

    async def get_members(self) -> List[MockMember]:
        """Return a copy of the loaded member list."""
        return list(self._members)

    # ── messages ─────────────────────────────────────────────────────────

    def _resolve_author(self, user_id_str: str) -> MockMember:
        """Returns MockMember for a userID, creating a stub if missing.

        The stub is cached in the member map (not in the member list), so
        later messages by the same unknown user share one object.
        """
        # The id is sliced below; coerce in case the JSON stored a number.
        user_id_str = str(user_id_str)
        uid = int(user_id_str)
        if uid in self._member_map:
            return self._member_map[uid]
        stub = MockMember({
            "userID": user_id_str,
            "username": f"User#{user_id_str[-4:]}",
            "userIsBot": False,
        })
        self._member_map[uid] = stub
        return stub

    def _load_channel_messages(self, channel_id: int) -> list[dict]:
        """Loads the messages array from a channel JSON file.

        Looks for ``message_backup/<id>.json`` first, then
        ``message_backup/threads/<id>.json``, then any direct sub-directory
        of ``message_backup``. Returns an empty list when no file is found or
        the file cannot be read; read failures are logged.
        """
        bp = self.backup_path / "message_backup"
        json_file = bp / f"{channel_id}.json"
        if not json_file.exists():
            fallbacks = [
                bp / "threads" / f"{channel_id}.json",
                *bp.glob(f"*/{channel_id}.json"),
            ]
            json_file = next((c for c in fallbacks if c.exists()), None)
            if json_file is None:
                return []

        try:
            data = json.loads(json_file.read_text(encoding="utf-8"))
            return data.get("messages", [])
        except Exception as e:
            logger.error(f"[Backup] Failed to load messages for channel {channel_id}: {e}")
            return []

    def _hydrate_message(self, msg_data: dict, channel_id: int) -> MockMessage:
        """Wrap one raw message record in a MockMessage with its author resolved."""
        author = self._resolve_author(msg_data.get("userID", "0"))
        backup_root = self.backup_path / "message_backup"
        return MockMessage(
            msg_data,
            author=author,
            channel_id=channel_id,
            backup_root=backup_root,
        )

    async def get_message(self, channel_id: int, message_id: int) -> MockMessage | None:
        """Return the message with this id from the channel's backup, or None."""
        for m in self._load_channel_messages(channel_id):
            if int(m["messageID"]) == message_id:
                return self._hydrate_message(m, channel_id)
        return None

    async def get_first_message(self, channel_id: int) -> MockMessage | None:
        """Return the first message stored for the channel, or None if empty."""
        messages = self._load_channel_messages(channel_id)
        if messages:
            return self._hydrate_message(messages[0], channel_id)
        return None

    async def fetch_message_history(
        self,
        channel_id: int,
        limit: int | None = None,
        after_id: int | None = None,
    ) -> AsyncGenerator["MockMessage", None]:
        """Yields MockMessages from the backup, respecting after_id and limit.

        Messages are yielded in the order they appear in the file. A falsy
        ``after_id`` or ``limit`` (None or 0) disables that filter.
        """
        messages = self._load_channel_messages(channel_id)
        count = 0

        for m in messages:
            msg_id = int(m["messageID"])
            if after_id and msg_id <= after_id:
                continue

            yield self._hydrate_message(m, channel_id)
            count += 1

            if limit and count >= limit:
                return

    # ── download helpers ─────────────────────────────────────────────────

    async def download_emoji(self, emoji: MockEmoji) -> bytes:
        """Return the bytes of the emoji image."""
        return await emoji.read()

    async def download_sticker(self, sticker: MockSticker) -> bytes:
        """Return the bytes of the sticker image."""
        return await sticker.read()

    async def download_attachment(self, attachment: MockAttachment) -> bytes:
        """Return the bytes of the attachment."""
        return await attachment.read()

    # ── lifecycle ────────────────────────────────────────────────────────

    async def close(self):
        """No-op for backup reader (no connections to close)."""
|
||||
|
|
@ -17,13 +17,14 @@ class DiscordExporter:
|
|||
self.server_id = ""
|
||||
self.user_cache = {}
|
||||
self.base_dir = Path(base_dir) if base_dir else Path(".")
|
||||
self.is_running = True
|
||||
|
||||
async def setup(self):
|
||||
"""Prepares the output directory and fetches server metadata."""
|
||||
metadata = await self.reader.get_server_metadata()
|
||||
self.server_name = metadata.get("name", "Unknown Server")
|
||||
self.server_id = metadata.get("id", "0")
|
||||
|
||||
|
||||
# Create safe folder name
|
||||
import re
|
||||
safe_name = re.sub(r'[^a-zA-Z0-9_\-\.]', '_', self.server_name)
|
||||
|
|
@ -38,6 +39,15 @@ class DiscordExporter:
|
|||
logger.info(f"Targeting server: {self.server_name} ({self.server_id})")
|
||||
return metadata
|
||||
|
||||
def _save_json_sync(self, file_path, data):
    """Sync helper for saving JSON, meant to be run in a thread.

    Writes *data* to *file_path* as UTF-8 JSON with a 4-space indent and
    non-ASCII characters left unescaped.

    Raises:
        TypeError / ValueError: if *data* is not JSON-serializable; in that
            case *file_path* is left untouched.
    """
    # Serialize before opening: open(..., "w") truncates the target, so a
    # payload that fails to serialize would otherwise destroy the previous
    # backup file. The trade-off is holding the whole document in memory.
    payload = json.dumps(data, indent=4, ensure_ascii=False)
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(payload)
|
||||
|
||||
async def _save_json(self, file_path, data):
    """Save *data* as JSON to *file_path* without blocking the event loop.

    The synchronous writer is handed to a worker thread through
    ``asyncio.to_thread`` and awaited.
    """
    writer = self._save_json_sync
    await asyncio.to_thread(writer, file_path, data)
|
||||
|
||||
async def export_metadata(self):
|
||||
"""Saves server metadata to a JSON file."""
|
||||
metadata = await self.reader.get_server_metadata()
|
||||
|
|
@ -74,8 +84,7 @@ class DiscordExporter:
|
|||
|
||||
metadata["ignore_channels"] = ignore_channels
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, indent=4, ensure_ascii=False)
|
||||
await self._save_json(output_file, metadata)
|
||||
return metadata
|
||||
|
||||
async def export_roles(self):
|
||||
|
|
@ -94,8 +103,7 @@ class DiscordExporter:
|
|||
})
|
||||
|
||||
output_file = self.export_path / "server_roles.json"
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(role_data, f, indent=4, ensure_ascii=False)
|
||||
await self._save_json(output_file, role_data)
|
||||
return role_data
|
||||
|
||||
async def download_server_assets(self):
|
||||
|
|
@ -202,8 +210,7 @@ class DiscordExporter:
|
|||
customization["members"] = old_data.get("members", [])
|
||||
except Exception: pass
|
||||
|
||||
with open(custom_file, "w", encoding="utf-8") as f:
|
||||
json.dump(customization, f, indent=4, ensure_ascii=False)
|
||||
await self._save_json(custom_file, customization)
|
||||
|
||||
return len(emoji_data), len(sticker_data)
|
||||
|
||||
|
|
@ -244,8 +251,7 @@ class DiscordExporter:
|
|||
# but let's see if the user wants it. For now, cat_count is real Discord categories.
|
||||
|
||||
output_file = self.export_path / "server_structure.json"
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(structure, f, indent=4, ensure_ascii=False)
|
||||
await self._save_json(output_file, structure)
|
||||
return structure, cat_count, chan_count
|
||||
|
||||
async def _format_channel(self, c):
|
||||
|
|
@ -354,6 +360,7 @@ class DiscordExporter:
|
|||
# 1. Fetch new messages - Handle Forbidden gracefully
|
||||
try:
|
||||
async for msg in self.reader.fetch_message_history(channel_id, after_id=last_id):
|
||||
if not self.is_running: break
|
||||
await asyncio.sleep(0) # Yield control
|
||||
msg_data = await self._format_message(msg, asset_dir, base_filename, avatar_dir, avatar_rel_base)
|
||||
messages.append(msg_data)
|
||||
|
|
@ -440,12 +447,10 @@ class DiscordExporter:
|
|||
|
||||
# Save channel messages
|
||||
await asyncio.sleep(0) # Yield before writing large JSON
|
||||
with open(json_file, "w", encoding="utf-8") as f:
|
||||
json.dump(output_data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
# Save/Update user_info.json
|
||||
with open(user_info_file, "w", encoding="utf-8") as f:
|
||||
json.dump(list(self.user_cache.values()), f, indent=4, ensure_ascii=False)
|
||||
await self._save_json(json_file, output_data)
|
||||
|
||||
# Save/Update user_info.json (usually small, but consistent to thread it)
|
||||
await self._save_json(user_info_file, list(self.user_cache.values()))
|
||||
|
||||
# If it's a forum, also export its threads into the sub-directory
|
||||
if is_forum:
|
||||
|
|
@ -456,28 +461,33 @@ class DiscordExporter:
|
|||
async def _format_message(self, msg, asset_dir, asset_prefix, avatar_dir, avatar_rel_base):
|
||||
"""Formats a single message to match the reference format."""
|
||||
attachments = []
|
||||
for a in msg.attachments:
|
||||
async def process_attachment(a):
|
||||
# mimic reference asset naming (suffixing hash/id)
|
||||
safe_name = a.filename
|
||||
short_id = str(a.id)[-5:]
|
||||
stored_name = f"{Path(safe_name).stem}-{short_id}{Path(safe_name).suffix}"
|
||||
target = asset_dir / stored_name
|
||||
|
||||
try:
|
||||
# Check if exists, else download (basic cache)
|
||||
target = asset_dir / stored_name
|
||||
if not target.exists():
|
||||
data = await a.read()
|
||||
with open(target, "wb") as f:
|
||||
f.write(data)
|
||||
# Attachment.save() uses a thread internally to save to disk
|
||||
await a.save(target)
|
||||
|
||||
attachments.append({
|
||||
return {
|
||||
"id": str(a.id),
|
||||
"url": f"{asset_prefix}/{stored_name}",
|
||||
"fileName": a.filename,
|
||||
"fileSizeBytes": a.size
|
||||
})
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download attachment {a.filename}: {e}")
|
||||
return None
|
||||
|
||||
# Download all attachments for this message concurrently
|
||||
if msg.attachments:
|
||||
results = await asyncio.gather(*(process_attachment(a) for a in msg.attachments))
|
||||
attachments = [r for r in results if r]
|
||||
|
||||
# Author info extraction and deduplication
|
||||
author = msg.author
|
||||
|
|
@ -658,6 +668,9 @@ class DiscordExporter:
|
|||
logger.info(f"Found {len(all_threads)} threads in {channel.name}. Starting backup...")
|
||||
|
||||
for thread in all_threads:
|
||||
if not self.is_running:
|
||||
logger.info("Thread backup cancelled by user.")
|
||||
break
|
||||
await asyncio.sleep(0) # important yield between threads
|
||||
|
||||
# First backup the full thread — this creates {thread_id}.json with totalAttachmentSizeBytes
|
||||
|
|
@ -737,12 +750,10 @@ class DiscordExporter:
|
|||
forum_data["numberOfAttachments"] = sum(
|
||||
m.get("numberOfFiles", 0) for m in forum_data["messages"]
|
||||
)
|
||||
# Keep chronological order
|
||||
forum_data["messages"].sort(key=lambda x: x["timestamp"])
|
||||
|
||||
await asyncio.sleep(0) # Yield before writing
|
||||
with open(forum_json_file, "w", encoding="utf-8") as f:
|
||||
json.dump(forum_data, f, indent=4, ensure_ascii=False)
|
||||
await self._save_json(forum_json_file, forum_data)
|
||||
logger.info(f"Appended starter message for {thread.name} to {forum_json_file.name}")
|
||||
else:
|
||||
logger.warning(f"Forum JSON file does not exist: {forum_json_file}")
|
||||
|
|
|
|||
|
|
@ -227,14 +227,19 @@ class BackupPane(Container):
|
|||
self.app.push_screen(modal_prog)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
msg = "Sync existing backups" if not force_overwrite else "Overwriting existing backups"
|
||||
msg = "Backup Channels" if not force_overwrite else "Overwriting existing backups"
|
||||
target_preview = ", ".join([c.name for c in selected_channels[:3]])
|
||||
if len(selected_channels) > 3:
|
||||
target_preview += "..."
|
||||
|
||||
modal_prog.set_status(f"Awaiting Confirmation to backup [bold]{len(selected_channels)}[/bold] channels...")
|
||||
modal_prog.set_status(f"Confirm to proceed with Backup of [bold]{len(selected_channels)}[/bold] channels")
|
||||
modal_prog.show_info(f"[cyan]{msg}[/cyan]", f"Targets: {target_preview}")
|
||||
|
||||
# Show full target list in the bottom log
|
||||
modal_prog.write("[bold]Target Channels:[/bold]")
|
||||
for idx, c in enumerate(selected_channels):
|
||||
modal_prog.write(f" {idx+1}. #{c.name}")
|
||||
|
||||
choice = await modal_prog.phase_wait_confirm(btn_start_label="Start Channel Backup", show_id=False)
|
||||
if choice == "btn_back":
|
||||
modal_prog.dismiss()
|
||||
|
|
@ -249,6 +254,10 @@ class BackupPane(Container):
|
|||
|
||||
modal_prog.phase_progress()
|
||||
modal_prog.show_stats()
|
||||
|
||||
# Reset running flag and set cancel callback
|
||||
self.exporter.is_running = True
|
||||
modal_prog.cancel_callback = lambda: setattr(self.exporter, "is_running", False)
|
||||
|
||||
total_chans = len(selected_channels)
|
||||
modal_prog.set_status("Backing up messages...")
|
||||
|
|
@ -257,6 +266,9 @@ class BackupPane(Container):
|
|||
accumulated_msgs = 0
|
||||
|
||||
for i, chan in enumerate(selected_channels):
|
||||
if not self.exporter.is_running:
|
||||
modal_prog.write("[bold red]Backup cancelled by user.[/bold red]")
|
||||
break
|
||||
await asyncio.sleep(0.01) # Yield to UI thread to keep it responsive
|
||||
|
||||
backup_exists = (self.exporter.export_path / "message_backup" / f"{chan.id}.json").exists()
|
||||
|
|
@ -284,6 +296,11 @@ class BackupPane(Container):
|
|||
|
||||
modal_prog.write(f"[green]Completed: {chan.name}[/green]")
|
||||
|
||||
if not self.exporter.is_running:
|
||||
modal_prog.set_item_status("[bold red]Backup Cancelled.[/bold red]")
|
||||
modal_prog.phase_report("Message Backup", "stopped")
|
||||
return
|
||||
|
||||
modal_prog.set_progress(total_chans, total_chans)
|
||||
modal_prog.set_item_status("[bold green]Backup completed successfully![/bold green]")
|
||||
|
||||
|
|
@ -340,9 +357,16 @@ class BackupPane(Container):
|
|||
modal_prog.set_status("Syncing messages...")
|
||||
modal_prog.write(f"[yellow]Syncing {total_chans} channels...[/yellow]")
|
||||
|
||||
# Reset running flag and set cancel callback
|
||||
self.exporter.is_running = True
|
||||
modal_prog.cancel_callback = lambda: setattr(self.exporter, "is_running", False)
|
||||
|
||||
accumulated_msgs = 0
|
||||
|
||||
for i, chan in enumerate(selected_channels):
|
||||
if not self.exporter.is_running:
|
||||
modal_prog.write("[bold red]Sync cancelled by user.[/bold red]")
|
||||
break
|
||||
await asyncio.sleep(0.01) # Yield to UI thread
|
||||
|
||||
modal_prog.set_item_status(f"[cyan]Syncing ({i+1}/{total_chans}): #{chan.name}[/cyan]")
|
||||
|
|
@ -365,6 +389,11 @@ class BackupPane(Container):
|
|||
)
|
||||
modal_prog.write(f"[green]Synced: {chan.name}[/green]")
|
||||
|
||||
if not self.exporter.is_running:
|
||||
modal_prog.set_item_status("[bold red]Sync Cancelled.[/bold red]")
|
||||
modal_prog.phase_report("Backup Sync", "stopped")
|
||||
return
|
||||
|
||||
modal_prog.set_progress(total_chans, total_chans)
|
||||
modal_prog.set_item_status("[bold green]Sync operation complete![/bold green]")
|
||||
|
||||
|
|
|
|||
|
|
@ -157,13 +157,18 @@ class ProgressScreen(Screen[None]):
|
|||
self.confirm_future.set_result(btn_id)
|
||||
return
|
||||
|
||||
# If Cancel is pressed during operation, invoke callback and dismiss
|
||||
# If Cancel is pressed during operation, invoke callback and stay on screen
|
||||
if btn_id == "btn_cancel":
|
||||
if self.cancel_callback:
|
||||
self.cancel_callback()
|
||||
if self.timer_event:
|
||||
self.timer_event.stop()
|
||||
self.dismiss("btn_cancel")
|
||||
|
||||
# Show cancelling message and disable button
|
||||
self.set_status("[bold red]Cancelling... waiting for tasks to finish...[/bold red]")
|
||||
try:
|
||||
event.button.disabled = True
|
||||
event.button.label = "Stopping..."
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
# If operation is done (report phase), just dismiss with the action
|
||||
|
|
@ -660,7 +665,7 @@ class ChannelSelectScreen(Screen[dict]):
|
|||
if self.any_found:
|
||||
yield Label("Existing backups found:", classes="label_warning")
|
||||
yield Button("Sync", variant="success", id="btn_sync")
|
||||
yield Button("Force Overwrite", variant="error", id="btn_force")
|
||||
yield Button("Force Overwrite", variant="warning", id="btn_force")
|
||||
else:
|
||||
yield Button("Backup", variant="success", id="btn_backup")
|
||||
yield Button("Back", id="btn_cancel_chan")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue