improve channel selection & progress display

This commit is contained in:
rambros 2026-03-01 18:09:05 +05:30
parent 281399fc1e
commit bc500641bf
2 changed files with 116 additions and 17 deletions

View file

@ -55,7 +55,24 @@ class DiscordExporter:
else: else:
metadata["banner"] = None metadata["banner"] = None
# Add metadata fields
from datetime import datetime
metadata["last_backup"] = datetime.now().isoformat()
output_file = self.export_path / "server_profile.json" output_file = self.export_path / "server_profile.json"
# Preserve ignore_channels if the file already exists
ignore_channels = []
if output_file.exists():
try:
with open(output_file, "r", encoding="utf-8") as f:
old_data = json.load(f)
ignore_channels = old_data.get("ignore_channels", [])
except Exception as e:
logger.warning(f"Could not read existing server_profile.json to preserve ignore_channels: {e}")
metadata["ignore_channels"] = ignore_channels
with open(output_file, "w", encoding="utf-8") as f: with open(output_file, "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=4, ensure_ascii=False) json.dump(metadata, f, indent=4, ensure_ascii=False)
return metadata return metadata
@ -240,10 +257,12 @@ class DiscordExporter:
"nsfw": getattr(c, "nsfw", False) "nsfw": getattr(c, "nsfw", False)
} }
async def export_channel_messages(self, channel_id: int, progress_callback=None): async def export_channel_messages(self, channel_id: int, progress_callback=None, force=False):
"""Exports all messages from a channel, including attachments, pins, reactions.""" """Fetches and saves message history for a channel, handling incremental sync."""
channel = await self.reader.get_channel(channel_id) channel = await self.reader.get_channel(channel_id)
if not channel: return 0 if not channel:
logger.error(f"Channel not found: {channel_id}")
return 0
channel_name = channel.name channel_name = channel.name
safe_name = channel_name.replace(" ", "-").lower() safe_name = channel_name.replace(" ", "-").lower()
@ -278,13 +297,21 @@ class DiscordExporter:
base_filename = str(channel_id) base_filename = str(channel_id)
json_file = backup_dir / f"{base_filename}.json" json_file = backup_dir / f"{base_filename}.json"
asset_dir = backup_dir / base_filename asset_dir = backup_dir / base_filename
if force and asset_dir.exists():
import shutil
try:
shutil.rmtree(asset_dir)
except Exception as e:
logger.warning(f"Failed to clear asset directory {asset_dir}: {e}")
asset_dir.mkdir(exist_ok=True) asset_dir.mkdir(exist_ok=True)
messages = [] messages = []
last_id = None last_id = None
# Load existing messages for incremental sync # Load existing messages for incremental sync (skip if force)
if json_file.exists(): if not force and json_file.exists():
try: try:
with open(json_file, "r", encoding="utf-8") as f: with open(json_file, "r", encoding="utf-8") as f:
old_data = json.load(f) old_data = json.load(f)
@ -309,7 +336,7 @@ class DiscordExporter:
messages.append(msg_data) messages.append(msg_data)
new_count += 1 new_count += 1
if progress_callback: if progress_callback:
await progress_callback(channel_name, count + new_count) await progress_callback(channel_name, new_count)
except discord.Forbidden: except discord.Forbidden:
logger.error(f"403 Forbidden: Missing Access to read messages in {channel_name} ({channel_id})") logger.error(f"403 Forbidden: Missing Access to read messages in {channel_name} ({channel_id})")
if not messages: return 0 if not messages: return 0
@ -532,7 +559,7 @@ class DiscordExporter:
return data return data
async def export_threads(self, channel_id: int): async def export_threads(self, channel_id: int, progress_callback=None, force=False):
"""Exports active and archived threads for a channel.""" """Exports active and archived threads for a channel."""
channel = await self.reader.get_channel(channel_id) channel = await self.reader.get_channel(channel_id)
if not hasattr(channel, "threads") and not hasattr(channel, "public_archived_threads"): if not hasattr(channel, "threads") and not hasattr(channel, "public_archived_threads"):
@ -562,7 +589,7 @@ class DiscordExporter:
logger.info(f"Found {len(all_threads)} threads in {channel.name}. Starting backup...") logger.info(f"Found {len(all_threads)} threads in {channel.name}. Starting backup...")
for thread in all_threads: for thread in all_threads:
await self.export_channel_messages(thread.id) await self.export_channel_messages(thread.id, progress_callback=progress_callback, force=force)
thread_count += 1 thread_count += 1
return thread_count return thread_count

View file

@ -3,6 +3,10 @@ import asyncio
import discord import discord
import logging import logging
import time import time
import re
import json
from datetime import datetime
from pathlib import Path
from rich.console import Console from rich.console import Console
from rich.prompt import Prompt, Confirm from rich.prompt import Prompt, Confirm
from rich.table import Table from rich.table import Table
@ -66,6 +70,29 @@ class DiscoReaperCLI:
return False return False
def get_backup_info(self):
"""Checks for existing backup and returns formatted timestamp."""
d_name = self.validation_results.get("discord_server_name")
d_id = self.config.discord_server_id
if not d_name or not d_id or d_id == "DISCORD_SERVER_ID":
return None
safe_name = re.sub(r'[^a-zA-Z0-9_\-\.]', '_', d_name)
export_path = Path(".") / f"EXPORT-{safe_name}-{d_id}"
profile_file = export_path / "server_profile.json"
if profile_file.exists():
try:
with open(profile_file, "r", encoding="utf-8") as f:
data = json.load(f)
ts_str = data.get("last_backup")
if ts_str:
dt = datetime.fromisoformat(ts_str)
return dt.strftime("%d-%b-%Y %H:%M")
except Exception:
pass
return None
async def run(self): async def run(self):
await self.validate_config() await self.validate_config()
@ -81,9 +108,13 @@ class DiscoReaperCLI:
b_display = f"[bold green]\"{b_name}\"[/bold green]" if b_name else "[bold red]UNKNOWN[/bold red]" b_display = f"[bold green]\"{b_name}\"[/bold green]" if b_name else "[bold red]UNKNOWN[/bold red]"
console.print(f"[bold cyan]Bot name:[/bold cyan] {b_display}") console.print(f"[bold cyan]Bot name:[/bold cyan] {b_display}")
backup_ts = self.get_backup_info()
if backup_ts:
console.print(f"[bold cyan]Backup Found:[/bold cyan] [bold yellow]{backup_ts}[/bold yellow]")
console.print("\n[bold]Main Menu[/bold]") console.print("\n[bold]Main Menu[/bold]")
console.print("(1) Backup Server Profile") console.print("(1) Backup Server Profile")
console.print("(2) Backup Messages") console.print("(2) Backup Channel Messages")
console.print("(3) Update & Sync Backup") console.print("(3) Update & Sync Backup")
console.print("(4) Configuration") console.print("(4) Configuration")
console.print("(Q) Exit") console.print("(Q) Exit")
@ -182,7 +213,14 @@ class DiscoReaperCLI:
cat_name = cat_map.get(chan.category_id) cat_name = cat_map.get(chan.category_id)
if cat_name: if cat_name:
display_name = f"{chan.name} [{cat_name}]" display_name = f"{chan.name} [{cat_name}]"
console.print(f"({i+1}) {display_name}")
# Check for existing backup
found_prefix = ""
backup_file = self.exporter.export_path / "message_backup" / f"{chan.id}.json"
if backup_file.exists():
found_prefix = "[green][FOUND][/green] "
console.print(f"({i+1}) {found_prefix}{display_name}")
console.print("(A) [bold green]All Channels[/bold green]") console.print("(A) [bold green]All Channels[/bold green]")
console.print("(B) Back") console.print("(B) Back")
@ -209,23 +247,57 @@ class DiscoReaperCLI:
console.print("[yellow]No valid channels selected.[/yellow]") console.print("[yellow]No valid channels selected.[/yellow]")
return return
# Check if any have [FOUND]
any_found = False
for chan in selected_channels:
if (self.exporter.export_path / "message_backup" / f"{chan.id}.json").exists():
any_found = True
break
force_overwrite = False
if any_found:
console.print("\n[bold yellow]Existing backup(s) found for some selected channels.[/bold yellow]")
console.print("(Y) Update & Sync Backup")
console.print("(F) [bold red]Force Overwrite Backup[/bold red]")
console.print("(B) Back")
sync_choice = Prompt.ask("\nSelect option", choices=["Y", "F", "B"], default="Y").upper()
if sync_choice == "B":
return
force_overwrite = (sync_choice == "F")
console.print(f"\n[yellow]Starting backup for {len(selected_channels)} channels...[/yellow]") console.print(f"\n[yellow]Starting backup for {len(selected_channels)} channels...[/yellow]")
with Progress( with Progress(
SpinnerColumn(), SpinnerColumn(),
TextColumn("[progress.description]{task.description}"), TextColumn("[progress.description]{task.description}"),
BarColumn(), TextColumn("- [bold yellow]{task.fields[msg_count]} {task.fields[suffix]}"),
TaskProgressColumn(),
console=console console=console
) as progress: ) as progress:
overall_task = progress.add_task("[cyan]Exporting Channels...", total=len(selected_channels))
for chan in selected_channels: for chan in selected_channels:
progress.update(overall_task, description=f"[cyan]Backing up: {chan.name}") # Determine if it's a sync or fresh backup
await self.exporter.export_channel_messages(chan.id) backup_exists = (self.exporter.export_path / "message_backup" / f"{chan.id}.json").exists()
await self.exporter.export_threads(chan.id) is_sync = backup_exists and not force_overwrite
progress.advance(overall_task)
label = "Syncing Backup" if is_sync else "Backing up"
suffix = "new messages" if is_sync else "messages"
# Create a specific task for each channel to keep it on its own line
task_id = progress.add_task(f"[cyan]{label}: {chan.name}", total=None, msg_count=0, suffix=suffix)
async def update_msg_count(name, count, tid=task_id):
progress.update(tid, msg_count=count)
await self.exporter.export_channel_messages(chan.id, progress_callback=update_msg_count, force=force_overwrite)
await self.exporter.export_threads(chan.id, progress_callback=update_msg_count, force=force_overwrite)
# Mark as finished (stop spinner)
progress.stop_task(task_id)
progress.update(task_id, description=f"[green]Completed: {chan.name}")
console.print("[bold green]Message backup complete![/bold green]") console.print("[bold green]Message backup complete![/bold green]")
# Update last_backup in server_profile.json
await self.exporter.export_metadata()
except Exception as e: except Exception as e:
console.print(f"[bold red]Message backup failed: {e}[/bold red]") console.print(f"[bold red]Message backup failed: {e}[/bold red]")
finally: finally: