Compare commits

..

No commits in common. "089800374957f3dc549e48dc72f68e8d4f97a308" and "fb773ba948dd2778dd2d4cb7ac205a2b6fa4dd75" have entirely different histories.

25 changed files with 742 additions and 2197 deletions

5
.gitignore vendored
View file

@ -21,7 +21,6 @@ wheels/
.installed.cfg
*.egg
*.txt
*.exe
# Virtual Environment
venv/
@ -41,11 +40,11 @@ reaper_config.yaml
*.log
# Temporary Test Scripts
#test_*.py
tmp/
test_*.py
test_release.zip
test_release/
DiscoReaper
DiscoReaper-*
*.zip
# App data files

View file

@ -2,7 +2,7 @@
**DiscoReaper** is a powerful tool designed to help you migrate your entire Discord server to Fluxer or Stoat. It clones channels, roles, emojis, permissions, and also your community's full message history.
### Get it here: [![Linux](https://img.shields.io/badge/Linux-FCC624?logo=linux&logoColor=black)](https://git.mithraic.cloud/ad3laid3/disco-reaper/releases/latest) [![Windows](https://custom-icon-badges.demolab.com/badge/Windows-0078D6?logo=windows11&logoColor=white)](https://git.mithraic.cloud/ad3laid3/disco-reaper/releases/latest)
### Get it here: [![Linux](https://img.shields.io/badge/Linux-FCC624?logo=linux&logoColor=black)](https://github.com/rambros3d/disco-reaper/releases/latest/download/disco-reaper-linux.zip) [![Windows](https://custom-icon-badges.demolab.com/badge/Windows-0078D6?logo=windows11&logoColor=white)](https://github.com/rambros3d/disco-reaper/releases/latest/download/disco-reaper-windows.zip) [![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](https://github.com/rambros3d/disco-reaper/releases/latest/download/disco-reaper-macos.zip)
>Join our [**Reaper Community**](https://fluxer.gg/9KxDP8WH) if you need help or have any questions.
@ -81,7 +81,7 @@ Migrate directly from Discord or your Local Backups to the target community.
#### Setup the bots as per this [guide](BOT-SETUP.md)
### Option 1: Using Pre-built Binaries (Easiest)
1. **Download**: Grab the latest version from the [Releases](https://git.mithraic.cloud/ad3laid3/disco-reaper/releases) page.
1. **Download**: Grab the latest version from the [Releases](https://github.com/rambros3d/disco-reaper/releases) page.
2. **Run**:
- **Linux**: Run the `disco-reaper` binary (e.g., `./launch.sh` or double-click).
- **Windows**: Run `disco-reaper.exe`.
@ -89,7 +89,7 @@ Migrate directly from Discord or your Local Backups to the target community.
### Option 2: Running from Source (To use latest unstable code)
1. **Clone**: Clone the repository to your local machine:
```bash
git clone https://git.mithraic.cloud/ad3laid3/disco-reaper.git
git clone https://github.com/rambros3d/disco-reaper.git
cd disco-reaper
```
2. **Launch**: Run the appropriate launcher script for your OS. It will automatically create a virtual environment and install dependencies:
@ -153,4 +153,10 @@ But now their own website states that **Persona** will be used in some countries
## Contributors
MiTHRAL — fork maintainer, Stoat/Revolt integration & bug fixes.
<a href="https://github.com/rambros3d/disco-reaper/graphs/contributors">
<img src="https://contrib.rocks/image?repo=rambros3d/disco-reaper" />
</a>
Made with [contrib.rocks](https://contrib.rocks).
[![Star History Chart](https://api.star-history.com/image?repos=rambros3d/disco-reaper&type=date&legend=top-left)](https://www.star-history.com/?repos=rambros3d%2Fdisco-reaper&type=date&legend=top-left)

View file

@ -1,77 +0,0 @@
@echo off
setlocal enabledelayedexpansion
cd /d "%~dp0"
echo --- Disco-Reaper Windows Build Script ---
REM Check for venv
IF NOT EXIST "venv" (
echo Creating virtual environment...
python -m venv venv
IF ERRORLEVEL 1 (
echo Error: Failed to create virtual environment.
echo Ensure Python is installed and added to PATH.
IF NOT DEFINED AUTO_BUILD pause
exit /b 1
)
)
echo Activating virtual environment...
call venv\Scripts\activate.bat
REM Self-healing pip check
python -m pip --version >nul 2>&1
IF ERRORLEVEL 1 (
echo Warning: pip is missing or broken in venv. Attempting repair...
python -m ensurepip --default-pip
IF ERRORLEVEL 1 (
echo Error: Failed to repair pip automatically.
echo Try recreating the venv: rmdir /s /q venv ^&^& python -m venv venv
IF NOT DEFINED AUTO_BUILD pause
exit /b 1
)
)
echo Ensuring build dependencies are up to date...
REM python -m pip install --upgrade pip --quiet
python -m pip install pyinstaller --quiet
python -m pip install -r requirements.txt --quiet
echo Cleaning previous build artifacts...
IF EXIST "build" rmdir /s /q build
IF EXIST "dist" rmdir /s /q dist
echo Starting PyInstaller build...
REM Get git version tag
set "GIT_VERSION=Unknown"
for /f "tokens=*" %%i in ('git describe --tags --abbrev^=0 2^>nul') do set "GIT_VERSION=%%i"
echo Baking version: %GIT_VERSION%
echo __version__ = "%GIT_VERSION%"> src\core\_baked_version.py
python -m PyInstaller --clean disco-reaper.spec
IF ERRORLEVEL 1 (
echo Error: PyInstaller build failed.
del /f src\core\_baked_version.py 2>nul
IF NOT DEFINED AUTO_BUILD pause
exit /b 1
)
echo Cleaning up baked version file...
del /f src\core\_baked_version.py 2>nul
echo Packaging release: disco-reaper-windows.zip...
cd dist
powershell -Command "Compress-Archive -Path 'DiscoReaper.exe' -DestinationPath 'disco-reaper-windows.zip' -Force" 2>nul
IF ERRORLEVEL 1 (
echo Warning: Failed to create zip. Files are available in dist\ directory.
) ELSE (
echo Package created: dist\disco-reaper-windows.zip
)
cd ..
echo -----------------------------------
echo Build complete!
echo Standalone executable: dist\DiscoReaper.exe
echo Release Package: dist\disco-reaper-windows.zip
echo ---
IF NOT DEFINED AUTO_BUILD pause

View file

@ -1,18 +1,17 @@
import sys
import logging
from logging.handlers import RotatingFileHandler
from src.ui.main_app import run_disco_reaper_tui
from src.core.configuration import load_config
def setup_logging():
try:
config = load_config(create_if_missing=False)
log_level_str = config.log_level.upper()
log_level_str = config.migration.log_level.upper()
level = getattr(logging, log_level_str, logging.INFO)
except Exception:
level = logging.INFO
handlers = [RotatingFileHandler('.reaper.log', mode='w', maxBytes=10*1024*1024, backupCount=3)]
handlers = [logging.FileHandler('.reaper.log', mode='a')]
logging.basicConfig(
format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
datefmt='%H:%M:%S',
@ -93,21 +92,18 @@ def cleanup_old_update():
"""Removes the .old executable left behind by a Windows update."""
import os
import sys
from pathlib import Path
if sys.platform != "win32":
return
current_exe = sys.executable if getattr(sys, 'frozen', False) else sys.argv[0]
old_exe = current_exe + ".old"
# In frozen (PyInstaller) builds, sys.executable points to the temp _MEIxxxxx dir.
# sys.argv[0] always points to the real .exe on disk, so use that and resolve() it.
current_exe = Path(sys.argv[0]).resolve()
old_exe = current_exe.with_suffix(current_exe.suffix + ".old")
if old_exe.exists():
if os.path.exists(old_exe):
try:
old_exe.unlink()
except Exception as e:
logging.getLogger(__name__).debug(f"Could not remove old update file {old_exe}: {e}")
os.remove(old_exe)
except Exception:
pass
def main():
import os

View file

@ -8,7 +8,4 @@ pydantic # Data validation using Python type hints
lottie # Lottie file manipulation and conversion
Pillow # Image processing (required for GIF rendering)
cairosvg # SVG rendering (required for Lottie conversion)
psutil # System information (CPU, RAM, etc.)
pytest # Testing framework
pytest-asyncio # Async testing for pytest
pytest-mock # Mocking for pytest
psutil # System information (CPU, RAM, etc.)

View file

@ -1,21 +0,0 @@
#!/bin/bash
# Convenient script to run the Discord Reaper test suite.
# Automatically handles VENV activation and PYTHONPATH.
# Change to the project root directory (where the script is located)
cd "$(dirname "$0")"
# Activate the virtual environment if it exists
if [ -d "venv" ]; then
source venv/bin/activate
fi
# Set PYTHONPATH to the current directory so src is discoverable
export PYTHONPATH=.
# Run pytest with any arguments passed to this script, defaulting to tests/
if [ $# -eq 0 ]; then
pytest -v -s -p no:warnings tests/
else
pytest -s "$@"
fi

View file

@ -98,9 +98,9 @@ class BackupDatabase:
elif table == "permissions":
conn.execute("CREATE TABLE permissions (id INTEGER PRIMARY KEY AUTOINCREMENT, channel_id INTEGER, target_id INTEGER, target_type TEXT, allow INTEGER, deny INTEGER)")
elif table == "users":
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT, display_name TEXT, avatar_file TEXT, avatar_url TEXT, roles TEXT, type INTEGER DEFAULT 0)")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT, display_name TEXT, avatar_file TEXT, avatar_url TEXT, roles TEXT)")
elif table == "messages":
conn.execute("CREATE TABLE messages (id INTEGER PRIMARY KEY, channel_id INTEGER, author_id INTEGER, content TEXT, timestamp TEXT, type INTEGER, message_reference INTEGER, is_pinned INTEGER, extra_data TEXT, custom_display_name TEXT, custom_avatar_url TEXT)")
conn.execute("CREATE TABLE messages (id INTEGER PRIMARY KEY, channel_id INTEGER, author_id INTEGER, content TEXT, timestamp TEXT, type INTEGER, message_reference INTEGER, is_pinned INTEGER, extra_data TEXT)")
elif table == "attachments":
conn.execute("CREATE TABLE attachments (id INTEGER PRIMARY KEY, message_id INTEGER, filename TEXT, size INTEGER, url TEXT, content_type TEXT, local_hash TEXT)")
elif table == "embeds":
@ -114,7 +114,7 @@ class BackupDatabase:
elif table == "forum_tags":
conn.execute("CREATE TABLE forum_tags (id INTEGER PRIMARY KEY, forum_id INTEGER, name TEXT, moderated INTEGER, emoji_id INTEGER, emoji_name TEXT)")
elif table == "server_assets":
conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type TEXT)")
conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type INTEGER)")
old_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table}_old)").fetchall()]
new_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table})").fetchall()]
@ -123,25 +123,6 @@ class BackupDatabase:
conn.execute(f"INSERT INTO {table} ({col_str}) SELECT {col_str} FROM {table}_old")
conn.execute(f"DROP TABLE {table}_old")
# 3. Custom Author Profile Migration
res = conn.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='messages'").fetchone()
if res and res[0] > 0:
cols = conn.execute("PRAGMA table_info(messages)").fetchall()
col_names = [c["name"] for c in cols]
if "custom_display_name" not in col_names:
logger.info("Migrating messages: adding custom author profile columns")
conn.execute("ALTER TABLE messages ADD COLUMN custom_display_name TEXT")
conn.execute("ALTER TABLE messages ADD COLUMN custom_avatar_url TEXT")
# 4. User Type Categorization Migration
res = conn.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='users'").fetchone()
if res and res[0] > 0:
cols = conn.execute("PRAGMA table_info(users)").fetchall()
col_names = [c["name"] for c in cols]
if "type" not in col_names:
logger.info("Migrating users: adding type column")
conn.execute("ALTER TABLE users ADD COLUMN type INTEGER DEFAULT 0")
conn.commit()
@ -215,8 +196,7 @@ class BackupDatabase:
display_name TEXT,
avatar_file TEXT,
avatar_url TEXT,
roles TEXT,
type INTEGER DEFAULT 0
roles TEXT
)
""")
@ -231,9 +211,7 @@ class BackupDatabase:
type INTEGER,
message_reference INTEGER,
is_pinned INTEGER,
extra_data TEXT,
custom_display_name TEXT,
custom_avatar_url TEXT
extra_data TEXT
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_messages_channel ON messages(channel_id)")
@ -427,8 +405,8 @@ class BackupDatabase:
"""Saves users to the author cache."""
with self._lock:
self._conn.executemany("""
INSERT OR REPLACE INTO users (id, username, display_name, avatar_file, avatar_url, roles, type)
VALUES (:id, :username, :display_name, :avatar_file, :avatar_url, :roles, :type)
INSERT OR REPLACE INTO users (id, username, display_name, avatar_file, avatar_url, roles)
VALUES (:id, :username, :display_name, :avatar_file, :avatar_url, :roles)
""", users)
self._conn.commit()
@ -476,8 +454,8 @@ class BackupDatabase:
conn = self._conn
# Insert messages
conn.executemany("""
INSERT OR REPLACE INTO messages (id, channel_id, author_id, content, timestamp, type, message_reference, is_pinned, extra_data, custom_display_name, custom_avatar_url)
VALUES (:id, :channel_id, :author_id, :content, :timestamp, :type, :message_reference, :is_pinned, :extra_data, :custom_display_name, :custom_avatar_url)
INSERT OR REPLACE INTO messages (id, channel_id, author_id, content, timestamp, type, message_reference, is_pinned, extra_data)
VALUES (:id, :channel_id, :author_id, :content, :timestamp, :type, :message_reference, :is_pinned, :extra_data)
""", messages)
# Extract attachments, reactions, and stickers
@ -967,60 +945,6 @@ class BackupDatabase:
return purged_count
def get_backed_up_channel_ids(self) -> List[int]:
"""Returns a list of distinct channel IDs that have messages in the database."""
with self._lock:
rows = self._conn.execute("SELECT DISTINCT channel_id FROM messages").fetchall()
return [parse_snowflake(r[0]) for r in rows if parse_snowflake(r[0])]
def get_message_with_relations(self, message_id) -> Optional[Dict[str, Any]]:
"""Fetches a single message with its attachments, embeds, reactions, and stickers."""
with self._lock:
mid = parse_snowflake(message_id)
row = self._conn.execute("SELECT * FROM messages WHERE id = ?", (mid,)).fetchone()
if not row:
return None
data = dict(row)
# Attachments
atts = self._conn.execute("SELECT * FROM attachments WHERE message_id = ?", (mid,)).fetchall()
data["attachments"] = [dict(a) for a in atts]
# Embeds
embs = self._conn.execute("SELECT * FROM embeds WHERE message_id = ?", (mid,)).fetchall()
data["embeds"] = []
for er in embs:
e_dict = {
"title": er["title"],
"description": er["description"],
"url": er["url"],
"color": er["color"],
"timestamp": er["timestamp"],
"thumbnail": {"url": er["thumbnail_url"]} if er["thumbnail_url"] else None,
"image": {"url": er["image_url"]} if er["image_url"] else None,
"author": {
"name": er["author_name"],
"url": er["author_url"],
"icon_url": er["author_icon_url"]
} if er["author_name"] else None,
"footer": {
"text": er["footer_text"],
"icon_url": er["footer_icon_url"]
} if er["footer_text"] else None,
"fields": json.loads(er["fields"]) if er["fields"] else []
}
data["embeds"].append(e_dict)
# Reactions
reas = self._conn.execute("SELECT * FROM reactions WHERE message_id = ?", (mid,)).fetchall()
data["reactions"] = [dict(r) for r in reas]
# Stickers
sts = self._conn.execute("SELECT * FROM message_stickers WHERE message_id = ?", (mid,)).fetchall()
data["stickers"] = [dict(s) for s in sts]
return data
def close(self):
"""Commits any pending writes and closes the connection."""
with self._lock:

View file

@ -393,20 +393,6 @@ class BackupMember:
# Fallback for unexpected data format
self.id = 0
self.name = "Unknown"
self.display_name = "Unknown"
self.global_name = "Unknown"
self.bot = False
self.system = False
self.discriminator = "0000"
self.color = BackupColor(0)
self.roles = sorted(role_objects or [], key=lambda r: r.position, reverse=True)
self.guild_permissions = BackupPermissions(0)
self.created_at = datetime.now(timezone.utc)
self.joined_at = datetime.now(timezone.utc)
self.status = type("Status", (), {"value": "offline"})()
self.activity = None
self._avatar_url = None
self.avatar = BackupAsset(None)
return
self.id = parse_snowflake(data["id"])
self.name = data.get("username", "Unknown")
@ -530,13 +516,12 @@ class BackupEmoji:
class BackupSticker:
"""Minimal stand-in for discord.GuildSticker."""
__slots__ = ("id", "name", "url", "format", "_backup_root", "_file_path", "local_hash")
__slots__ = ("id", "name", "url", "format", "_backup_root", "_file_path")
def __init__(self, data: dict, backup_root: Path | None = None, media_pool: dict | None = None):
if not isinstance(data, dict):
self.id = 0
self.name = "Sticker"
self.local_hash = None
return
self.id = parse_snowflake(data.get("id") or data.get("sticker_id", 0)) or 0
self.name = data.get("name", "Sticker")
@ -551,14 +536,14 @@ class BackupSticker:
self._backup_root = backup_root
# 1. Check if it's a CAS-based sticker (from message_stickers table)
self.local_hash = data.get("local_hash")
if self.local_hash and backup_root:
local_hash = data.get("local_hash")
if local_hash and backup_root:
ext = ".png"
if self.format == StickerFormatType.lottie: ext = ".json"
elif self.format == StickerFormatType.apng: ext = ".png"
elif self.format == StickerFormatType.gif: ext = ".gif"
self._file_path = backup_root / "attachments" / f"{self.local_hash}{ext}"
self._file_path = backup_root / "attachments" / f"{local_hash}{ext}"
# 2. Check if it's a server asset sticker (legacy or manual save)
elif data.get("filename") and backup_root:
self._file_path = backup_root / "server_assets" / data["filename"]
@ -1281,7 +1266,11 @@ class BackupReader:
async def get_backed_up_channel_ids(self) -> List[int]:
"""Returns a list of channel IDs that have messages in the database."""
if not self.db: return []
return self.db.get_backed_up_channel_ids()
import sqlite3
conn = sqlite3.connect(self.db.db_path)
rows = conn.execute("SELECT DISTINCT channel_id FROM messages").fetchall()
conn.close()
return [parse_snowflake(r[0]) for r in rows if parse_snowflake(r[0])]
async def get_channel(self, channel_id: int) -> BackupChannel | BackupThread | None:
for c in self.channels:
@ -1339,18 +1328,6 @@ class BackupReader:
user_id = parse_snowflake(msg_data.get("author_id", 0)) or 0
author = self._resolve_author(user_id)
# Check for custom author profile (Webhooks / Masquerade)
over_name = msg_data.get("custom_display_name")
over_avatar = msg_data.get("custom_avatar_url")
if over_name:
# Create an ephemeral author object for this message
author = BackupMember({
"id": str(user_id),
"username": over_name,
"display_name": over_name,
"avatar_url": over_avatar
})
self._ensure_media_pool_loaded()
channel_id = parse_snowflake(msg_data["channel_id"])
@ -1374,9 +1351,23 @@ class BackupReader:
async def get_message(self, channel_id: int, message_id: int) -> BackupMessage | None:
"""Fetch a specific message from SQLite."""
if not self.db: return None
data = self.db.get_message_with_relations(message_id)
if data:
import sqlite3
conn = sqlite3.connect(self.db.db_path)
conn.row_factory = sqlite3.Row
row = conn.execute("SELECT * FROM messages WHERE id = ?", (str(message_id),)).fetchone()
if row:
data = dict(row)
# Fetch attachments
atts = conn.execute("SELECT * FROM attachments WHERE message_id = ?", (str(message_id),)).fetchall()
data["attachments"] = [dict(a) for a in atts]
# Fetch stickers
sts = conn.execute("SELECT * FROM message_stickers WHERE message_id = ?", (str(message_id),)).fetchall()
data["stickers"] = [dict(s) for s in sts]
conn.close()
return self._hydrate_message(data)
conn.close()
return None
async def get_first_message(self, channel_id: int) -> BackupMessage | None:

View file

@ -22,13 +22,6 @@ class MigrationContext:
self.target_platform = target_platform or config.target_platform or "fluxer"
self.state = MigrationState()
# Apply config-based log level dynamically to the root logger
if hasattr(self.config, "log_level") and self.config.log_level:
import logging
level = getattr(logging, self.config.log_level.upper(), logging.INFO)
logging.getLogger().setLevel(level)
logger.info(f"Log level updated to {self.config.log_level.upper()}")
# Select the appropriate source reader
if source_mode == "backup":
from src.core.backup_reader import BackupReader
@ -44,8 +37,8 @@ class MigrationContext:
# Build the writer for the active target platform only
if self.target_platform == "stoat":
token = config.stoat_bot_token
community_id = config.stoat_server_id
token = config.stoat_bot_token or ""
community_id = config.stoat_server_id or ""
api_url = config.stoat_api_url or "default"
self.writer = StoatWriter(token=token, community_id=community_id, api_url=api_url)
self.stoat_writer = self.writer
@ -104,17 +97,11 @@ class MigrationContext:
}
# CONSISTENCY: Once target metadata is known, initialize the flat SQLite DB.
if results["target_community"]:
if results["target_community"] and results["target_community_name"]:
tid = self.config.fluxer_server_id if self.target_platform == "fluxer" else self.config.stoat_server_id
# Prefer the original discord community name for the DB file if available (e.g. from live load or backup)
db_name = results.get("discord_server_name")
if not db_name or db_name == "Not Found" or db_name == "Unknown":
db_name = results.get("target_community_name") or "Unknown"
self.ensure_state_initialized(
str(tid or ""),
db_name
results["target_community_name"]
)
return results
@ -133,23 +120,6 @@ class MigrationContext:
return
import re
import json
# Override the target name explicitly with the original Discord source name if available.
# This fixes naming collisions and UI confusion like "Fluxer-123456.db" instead of "MyServer-123456.db"
try:
if hasattr(self.discord_reader, "guild") and getattr(self.discord_reader, "guild", None):
community_name = getattr(self.discord_reader, "guild").name
elif getattr(self, "source_mode", "live") == "backup" and hasattr(self.discord_reader, "backup_dir"):
b_dir = getattr(self.discord_reader, "backup_dir")
if b_dir and b_dir.exists():
meta_file = b_dir / "metadata.json"
if meta_file.exists():
data = json.loads(meta_file.read_text())
community_name = data.get("name", community_name)
except Exception:
pass
clean_name = re.sub(r'[^\w\s-]', '', community_name).strip()
clean_name = re.sub(r'[-\s]+', '_', clean_name)

View file

@ -4,7 +4,7 @@ import logging
import json
import random
from pathlib import Path
from typing import Optional, Dict, Any, List, Union
from typing import Optional, Dict, Any, Union
import threading
import sys
from src.core.utils import parse_snowflake
@ -560,15 +560,6 @@ class MigrationDatabase:
conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (str(channel_id),))
conn.commit()
logger.info(f"Cleared all tracking and mapping data for channel: {channel_id}")
def clear_all_migration_data(self):
"""Purge all mappings and tracking data for ALL channels and threads."""
conn = self._get_conn()
conn.execute("DELETE FROM message_mappings")
conn.execute("DELETE FROM thread_mappings")
conn.execute("DELETE FROM channel_tracking")
conn.execute("DELETE FROM thread_tracking")
conn.commit()
logger.info("Cleared ALL tracking and message mapping data globally.")
def close(self):
if hasattr(self._local, "conn"):

View file

@ -18,7 +18,6 @@ class DiscordExporter:
self.server_name = ""
self.server_id = ""
self.user_cache = {}
self.member_cache: Dict[int, Any] = {} # Pre-fetched member objects (id -> Member)
self.base_dir = Path(base_dir) if base_dir else Path(".")
self.is_running = True
self.db: Optional[BackupDatabase] = None
@ -63,20 +62,6 @@ class DiscordExporter:
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
async def prefetch_members(self):
"""Pre-fetches all guild members into a local cache for role resolution.
msg.author is a discord.User (no roles). This cache allows us to
resolve roles without an API call per message during message export.
"""
try:
members = await self.reader.get_members()
self.member_cache = {m.id: m for m in members}
logger.info(f"Pre-fetched {len(self.member_cache)} members for role resolution.")
except Exception as e:
logger.warning(f"Could not pre-fetch members (roles will be empty): {e}")
self.member_cache = {}
async def export_metadata(self):
"""Saves server metadata to the SQLite database."""
metadata = await self.reader.get_server_metadata()
@ -454,65 +439,37 @@ class DiscordExporter:
return accumulated_count, accumulated_threads, accumulated_files
async def _format_user(self, user, is_webhook=False):
async def _format_user(self, user):
"""Formats user data for the author or a mention.
For Webhooks, we use a generic name and the default Discord avatar system
for the base profile in the user cache.
Avatar downloads are intentionally deferred to keep this off the hot
message-formatting path. Call _flush_pending_avatars() after each batch.
"""
user_id_int = int(user.id)
user_id = str(user_id_int)
user_id = str(user.id)
if user_id in self.user_cache:
return None
username = user.name
display_name = getattr(user, "display_name", user.name)
avatar = user.avatar
avatar_url = str(user.display_avatar.url) if user.display_avatar else None
if is_webhook:
# For webhooks, we use the ID as the username for technical clarity,
# and the current name as the display name.
username = user_id
display_name = user.name
# Discord default avatar formula: (ID >> 22) % 5
default_index = (user_id_int >> 22) % 5
avatar_url = f"https://cdn.discordapp.com/embed/avatars/{default_index}.png"
avatar = None # Don't download character avatar as the "base" webhook avatar
# New user discovered — schedule avatar download but don't block here
avatar_file = None
if avatar:
if user.avatar:
av_name = f"{user_id}.png"
av_target = self.users_path / av_name
avatar_file = f"users/{av_name}"
if not av_target.exists():
# Queue for deferred download
self._pending_avatars.append((user_id, avatar, av_target))
self._pending_avatars.append((user_id, user.avatar, av_target))
roles = []
if hasattr(user, "roles"):
roles = [str(r.id) for r in user.roles if not r.is_default()]
# Determine user type
# 0: Regular User, 1: Bot, 2: Webhook, 3: System
u_type = 0
if is_webhook:
u_type = 2
elif getattr(user, "system", False):
u_type = 3
elif getattr(user, "bot", False):
u_type = 1
user_data = {
"id": user_id,
"username": username,
"display_name": display_name,
"username": user.name,
"display_name": getattr(user, "display_name", user.name),
"avatar_file": avatar_file,
"avatar_url": avatar_url,
"roles": json.dumps(roles),
"type": u_type
"avatar_url": str(user.display_avatar.url) if user.avatar else None,
"roles": json.dumps(roles)
}
self.user_cache[user_id] = user_data
return user_data
@ -537,21 +494,13 @@ class DiscordExporter:
new_users = []
# 1. Author handling
is_webhook = bool(getattr(msg, "webhook_id", None))
author = msg.author
# msg.author is discord.User (no roles). Resolve to Member for role data.
if not is_webhook:
member = self.member_cache.get(msg.author.id)
if member:
author = member
u_data = await self._format_user(author, is_webhook=is_webhook)
u_data = await self._format_user(msg.author)
if u_data: new_users.append(u_data)
# 1.5 Mentions handling (ensure all mentioned users are saved)
if msg.mentions:
for mention in msg.mentions:
# Mentions can be Member objects already, so roles work naturally
u_ment = await self._format_user(mention, is_webhook=False)
u_ment = await self._format_user(mention)
if u_ment: new_users.append(u_ment)
# 2. Attachments handling (Content-Addressable Storage)
@ -654,15 +603,6 @@ class DiscordExporter:
for s_emb in snapshot.embeds:
embeds.append(s_emb.to_dict())
# 5.6 Author Overrides (Webhooks / Masquerade)
custom_display_name = None
custom_avatar_url = None
# Webhooks or bots with masquerade often use per-message names/avatars
if getattr(msg, "webhook_id", None) or (msg.author and msg.author.bot):
custom_display_name = msg.author.name
custom_avatar_url = str(msg.author.display_avatar.url) if msg.author.display_avatar else None
m_data = {
"id": str(msg.id),
"channel_id": str(msg.channel.id),
@ -676,9 +616,7 @@ class DiscordExporter:
"stickers": stickers,
"embeds": embeds,
"reactions": reactions,
"extra_data": None,
"custom_display_name": custom_display_name,
"custom_avatar_url": custom_avatar_url
"extra_data": None
}
return m_data, new_users

View file

@ -232,18 +232,22 @@ class MigrationState:
return None
def get_global_min_last_message_id(self, all_mapped_ids: list[str]) -> int | None:
def get_global_min_last_message_id(self, all_mapped_ids: List[str]) -> int | None:
"""Returns the absolute minimum last_msg_id among the given list of mapped target IDs (channels and threads)."""
if self._ensure_db():
return self.db.get_global_min_last_message_id(all_mapped_ids)
return None
def clear_all_migration_data(self):
"""Clears all message mapping and tracking state globally."""
if self._ensure_db():
self.db.clear_all_migration_data()
def set_waterfall_last_id(self, last_id: str | int):
if self.db:
self.db.set_metadata("waterfall_last_id", str(last_id))
def get_waterfall_last_id(self) -> int | None:
if self.db:
val = self.db.get_metadata("waterfall_last_id")
return int(val) if val else None
return None
def get_all_last_message_ids(self) -> Dict[str, str]:
"""Returns a combined map of channel_id/thread_id -> last_msg_id."""
if self._ensure_db():

View file

@ -10,7 +10,9 @@ from src.core.utils import get_app_version
logger = logging.getLogger(__name__)
API_URL = "https://git.mithraic.cloud/api/v1/repos/ad3laid3/disco-reaper/releases"
REPO_OWNER = "rambros3d"
REPO_NAME = "disco-reaper"
API_URL = f"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/releases"
def get_current_version() -> str:
"""Returns the current version string, e.g., '1.0.0'. Strips 'Reaper-' and 'v'."""
@ -50,7 +52,7 @@ async def check_for_updates() -> Optional[Dict[str, Any]]:
try:
async with aiohttp.ClientSession() as session:
async with session.get(API_URL) as resp:
async with session.get(API_URL, headers={"Accept": "application/vnd.github.v3+json"}) as resp:
if resp.status == 200:
releases = await resp.json()
if not isinstance(releases, list) or not releases:

View file

@ -15,52 +15,41 @@ async def sync_channel_state(context: MigrationContext):
channels = await context.discord_reader.get_channels()
fluxer_channels = await context.fluxer_writer.get_channels()
# Build maps for Fluxer lookup
# {name: id} for categories
fluxer_cats = {c.get("name"): str(c.get("id")) for c in fluxer_channels if c.get("type") == 4}
# {parent_id: {name: id}} for channels
fluxer_structure = {}
for c in fluxer_channels:
if c.get("type") == 4: continue
p_id = str(c.get("parent_id")) if c.get("parent_id") else "root"
if p_id not in fluxer_structure: fluxer_structure[p_id] = {}
fluxer_structure[p_id][c.get("name")] = str(c.get("id"))
# Build name -> id map and ID set for Fluxer for fast lookup
fluxer_name_map = {c.get("name"): str(c.get("id")) for c in fluxer_channels if c.get("name")}
fluxer_id_set = {str(c.get("id")) for c in fluxer_channels}
updates = 0
removals = 0
# 1. Sync Categories
# 1. Verify and Sync Categories
for cat in categories:
discord_id = str(cat.id)
fluxer_id = context.state.get_fluxer_category_id(discord_id)
if fluxer_id:
if fluxer_id not in fluxer_id_set:
context.state.remove_category_mapping(discord_id)
removals += 1
elif cat.name in fluxer_cats:
context.state.set_category_mapping(discord_id, fluxer_cats[cat.name])
elif cat.name in fluxer_name_map:
context.state.set_category_mapping(discord_id, fluxer_name_map[cat.name])
updates += 1
# 2. Sync Channels (parent-aware)
# 2. Verify and Sync Channels
for ch in channels:
discord_id = str(ch.id)
fluxer_id = context.state.get_fluxer_channel_id(discord_id)
if fluxer_id:
if fluxer_id not in fluxer_id_set:
context.state.remove_channel_mapping(discord_id)
removals += 1
else:
# Try to match by name within the mapped parent category
p_discord_id = str(ch.category_id) if ch.category_id else "root"
p_fluxer_id = context.state.get_fluxer_category_id(p_discord_id) if p_discord_id != "root" else "root"
if p_fluxer_id in fluxer_structure and ch.name in fluxer_structure[p_fluxer_id]:
context.state.set_channel_mapping(discord_id, fluxer_structure[p_fluxer_id][ch.name])
updates += 1
elif ch.name in fluxer_name_map:
context.state.set_channel_mapping(discord_id, fluxer_name_map[ch.name])
updates += 1
if updates > 0 or removals > 0:
logger.info(f"Fluxer Channel sync: {updates} mapped, {removals} stale mappings removed")
logger.info(f"Channel sync: {updates} mapped, {removals} stale mappings removed")
async def migrate_channels(context: MigrationContext, progress_callback: Callable[[str, str, int, int], Awaitable[None]] | None = None, force: bool = False) -> dict:

View file

@ -158,190 +158,7 @@ async def get_channel_threads(reader: Any, channel_id: int) -> List[Any]:
return threads
async def _process_and_send_message(
context: MigrationContext,
msg: Any,
target_channel_id: str,
stats: Dict[str, Any],
thread_id: str | None = None,
parent_target_id: str | None = None,
thread_name: str | None = None,
processed_threads: set | None = None
) -> str | None:
"""
Internal helper to process a single Discord message (mentions, attachments, stickers)
and send it to the Fluxer platform.
"""
# 1. Formatting
content = msg.content or ""
# Check for forwarded flag
is_forwarded = False
if hasattr(msg.flags, 'forwarded'):
is_forwarded = msg.flags.forwarded
# Always ensure alias is created/retrieved to populate user_alias table
alias = context.state.get_user_alias(str(msg.author.id))
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
# Process Stickers
files = []
if hasattr(msg, 'stickers') and msg.stickers:
for s in msg.stickers:
try:
sticker_data = await context.discord_reader.download_sticker(s)
if not sticker_data: continue
format_val = getattr(s, 'format', 'png')
if hasattr(format_val, 'name'):
ext = format_val.name.lower()
elif isinstance(format_val, int):
format_map = {1: 'png', 2: 'apng', 3: 'lottie', 4: 'gif'}
ext = format_map.get(format_val, 'png')
else:
ext = str(format_val).lower()
# Conversion logic (Simplified for unification)
if ext == 'lottie' and HAS_LOTTIE:
try:
lottie_data = json.loads(sticker_data)
def _convert_lottie(data):
anim = Animation.load(data)
buf = io.BytesIO()
export_gif(anim, buf)
buf.seek(0)
return buf
gif_buf = await asyncio.to_thread(_convert_lottie, lottie_data)
from PIL import Image
def _convert_gif_to_webp(buf):
img = Image.open(buf)
w_buf = io.BytesIO()
if getattr(img, 'n_frames', 1) > 1:
img.save(w_buf, format='WEBP', save_all=True, loop=0, quality=80)
else:
img.save(w_buf, format='WEBP', quality=80)
return w_buf.getvalue()
sticker_data = await asyncio.to_thread(_convert_gif_to_webp, gif_buf)
ext = 'webp'
except Exception: ext = 'json'
elif ext in ('apng', 'gif'):
try:
from PIL import Image
def _process_animated_sticker(data):
img = Image.open(io.BytesIO(data))
webp_buf = io.BytesIO()
if getattr(img, 'n_frames', 1) > 1:
img.save(webp_buf, format='WEBP', save_all=True, loop=0, quality=80)
else:
img.save(webp_buf, format='WEBP', quality=80)
return webp_buf.getvalue()
sticker_data = await asyncio.to_thread(_process_animated_sticker, sticker_data)
ext = 'webp'
except Exception: pass
files.append({"filename": f"sticker_{s.name}_{s.id}.{ext}", "data": sticker_data})
stats["attachments"] += 1
except Exception as e:
logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}")
# Process Attachments
attachments_to_process = list(msg.attachments)
if is_forwarded and hasattr(msg, 'message_snapshots') and msg.message_snapshots:
snapshot = msg.message_snapshots[0]
if not content:
content = snapshot.content
attachments_to_process.extend(snapshot.attachments)
for att in attachments_to_process:
try:
att_data = await context.discord_reader.download_attachment(att)
files.append({"filename": att.filename, "data": att_data})
stats["attachments"] += 1
except Exception as e:
logger.error(f"Failed to download attachment {att.filename}: {e}")
# Clean Mentions
content = clean_mentions(
content=content,
guild=context.discord_reader.guild,
user_mentions=msg.mentions,
role_mentions=msg.role_mentions,
channel_mentions=msg.channel_mentions,
emoji_map=context.state.emoji_map,
channel_map=context.state.channel_map,
state=context.state,
target_server_id=context.fluxer_writer.community_id,
channel_names=context.channel_names if hasattr(context, 'channel_names') else None,
anonymize_users=anonymize_users
)
if not content and not files:
return None
# Reply Resolution
reply_to_fluxer_id = None
if msg.reference and msg.reference.message_id:
reply_to_fluxer_id = context.state.get_fluxer_message_id(target_channel_id, str(msg.reference.message_id))
# Fallback author tagging for replies if mapping not found
if not reply_to_fluxer_id:
try:
source_ref_msg = await context.discord_reader.get_message(msg.channel.id, msg.reference.message_id)
if source_ref_msg and source_ref_msg.author:
ref_name = context.state.get_user_alias(str(source_ref_msg.author.id)) if anonymize_users else source_ref_msg.author.display_name
content = f"`@{ref_name}`\n{content}"
else:
tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id)
if tgt_reply: content = f"[Reply to {tgt_reply}]\n{content}"
except Exception: pass
# Thread logic
if not reply_to_fluxer_id and parent_target_id and stats["messages"] == 0:
reply_to_fluxer_id = parent_target_id
if thread_name and stats["messages"] == 0:
content = f"> <<< THREAD: **{thread_name}** >>>\n{content}"
# Send Message
if anonymize_users:
author_name = alias or "Anonymized User"
author_avatar_url = None
else:
author_name = msg.author.display_name
author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None
fluxer_msg_id = await context.fluxer_writer.send_message(
channel_id=target_channel_id,
author_name=author_name,
author_avatar_url=author_avatar_url,
content=content,
timestamp=int(msg.created_at.timestamp()),
files=files if files else None,
reply_to_message_id=reply_to_fluxer_id,
is_forwarded=is_forwarded,
embeds=msg.embeds
)
if fluxer_msg_id:
if thread_id:
context.state.set_thread_message_mapping(target_channel_id, thread_id, str(msg.id), fluxer_msg_id)
context.state.update_thread_last_message_timestamp(target_channel_id, thread_id, str(msg.created_at))
context.state.update_thread_last_message_id(target_channel_id, thread_id, str(msg.id))
context.state.increment_thread_stats(target_channel_id, thread_id, messages=1, files=len(files) if files else 0)
else:
context.state.set_message_mapping(target_channel_id, str(msg.id), fluxer_msg_id)
context.state.update_last_message_timestamp(target_channel_id, str(msg.created_at))
context.state.update_last_message_id(target_channel_id, str(msg.id))
context.state.increment_stats(target_channel_id, messages=1, files=len(files) if files else 0)
stats["messages"] += 1
stats["last_message_content"] = content
stats["last_message_author"] = msg.author.display_name
return fluxer_msg_id
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None, processed_threads: set | None = None) -> Dict[str, int]:
"""
Scans channel history to count messages, threads, and attachments.
"""
@ -725,30 +542,87 @@ async def migrate_messages(
logger.debug(f"Added sticker {s.name} as attachment (extension: {ext}, size: {sticker_size} bytes)")
except Exception as e:
logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}")
# Check for existing mapping to avoid duplicates when resuming
# Check for existing mapping to avoid duplicates when resuming
if context.state.get_target_message_id(target_channel_id, str(msg.id)):
continue
try:
fluxer_msg_id = await _process_and_send_message(
context=context,
msg=msg,
target_channel_id=target_channel_id,
stats=stats,
thread_id=thread_id,
parent_target_id=parent_target_id,
thread_name=thread_name,
processed_threads=processed_threads
reply_to_fluxer_id = None
if msg.reference and msg.reference.message_id:
reply_to_fluxer_id = context.state.get_fluxer_message_id(target_channel_id, str(msg.reference.message_id))
if reply_to_fluxer_id:
logger.debug(f"Detected reply to Discord ID {msg.reference.message_id} -> Fluxer ID {reply_to_fluxer_id}")
else:
logger.debug(f"Reply target Discord ID {msg.reference.message_id} not found in current session map.")
# If this is the FIRST thread message and we have a parent_target_id, force it as reply to the starter
if not reply_to_fluxer_id and parent_target_id and stats["messages"] == 0:
reply_to_fluxer_id = parent_target_id
# Prepend thread marker to the first message of the thread
if thread_name and stats["messages"] == 0:
content = f"> <<< THREAD: **{thread_name}** >>>\n{content}"
# Always ensure alias is created/retrieved to populate user_alias table
alias = context.state.get_user_alias(str(msg.author.id))
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
if anonymize_users:
author_name = alias or "Anonymized User"
author_avatar_url = None
else:
author_name = msg.author.display_name
author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None
logger.debug(f"Fluxer: Calling send_message for Discord ID {msg.id}")
fluxer_msg_id = await context.fluxer_writer.send_message(
channel_id=target_channel_id,
author_name=author_name,
author_avatar_url=author_avatar_url,
content=content,
timestamp=int(msg.created_at.timestamp()),
files=files if files else None,
reply_to_message_id=reply_to_fluxer_id,
is_forwarded=is_forwarded,
embeds=msg.embeds
)
# Check for associated thread (Individual mode recursion)
if fluxer_msg_id:
if thread_id:
context.state.set_thread_message_mapping(target_channel_id, thread_id, str(msg.id), fluxer_msg_id)
else:
context.state.set_message_mapping(target_channel_id, str(msg.id), fluxer_msg_id)
else:
logger.warning(f"Fluxer: send_message returned None for Discord ID {msg.id} (message might have been skipped or timed out)")
if thread_id:
context.state.update_thread_last_message_timestamp(target_channel_id, thread_id, str(msg.created_at))
context.state.update_thread_last_message_id(target_channel_id, thread_id, str(msg.id))
context.state.increment_thread_stats(target_channel_id, thread_id, messages=1, files=len(files) if files else 0)
else:
context.state.update_last_message_timestamp(target_channel_id, str(msg.created_at))
context.state.update_last_message_id(target_channel_id, str(msg.id))
context.state.increment_stats(target_channel_id, messages=1, files=len(files) if files else 0)
stats["messages"] += 1
stats["last_message_content"] = content
stats["last_message_author"] = msg.author.display_name
# Check for associated thread (Normal case: parent message is migrated)
if hasattr(msg, 'thread') and msg.thread:
thread = msg.thread
if thread.id not in processed_threads:
processed_threads.add(thread.id)
# Track thread entry
stats["threads"] += 1
# Fetch last migrated message ID for this thread
thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(thread.id))
if thread_after_id:
logger.info(f"Resuming thread '{thread.name}' from after message ID: {thread_after_id}")
# Migrate thread messages recursively
thread_stats = await migrate_messages(
context=context,
source_channel_id=thread.id,
@ -763,19 +637,22 @@ async def migrate_messages(
stats["attachments"] += thread_stats["attachments"]
stats["threads"] += thread_stats["threads"]
# Send End Marker
if context.is_running:
await context.fluxer_writer.send_marker(
channel_id=target_channel_id,
content=f"> <<< END OF THREAD >>>"
)
# Update Link Tracking (Parent pointer updates)
# Update Link Tracking (but prevent threaded messages from overwriting the parent channel pointers)
# The 'after_message_id' param usually means it's the main function call and not a thread recursive call
if not stats["first_message_url"]:
stats["first_message_url"] = msg.jump_url
stats["last_message_url"] = msg.jump_url
if progress_callback:
await progress_callback(stats)
logger.debug(f"Fluxer: Finished processing message Discord ID {msg.id}")
except Exception as e:
logger.error(f"Failed to process message {msg.id}: {e}")
import traceback
@ -858,7 +735,7 @@ async def migrate_global_messages(
progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None
) -> Dict[str, Any]:
"""
Migrates messages across all channels chronologically to Fluxer.
Migrates messages across all channels chronologically.
"""
stats = {
"messages": 0,
@ -873,6 +750,14 @@ async def migrate_global_messages(
processed_threads = set()
logger.info("Starting Global Waterfall Migration for Fluxer...")
# Keep track of active thread mapping natively to pass parent target IDs if needed
thread_to_target_channel = {}
# Emojis and mapped users cache setup
emoji_map = context.state.emoji_map
db_media = context.discord_reader.db.get_all_media() if context.discord_reader.db else {}
target_server_id = getattr(context.fluxer_writer, "server_id", None)
# Fetch global progress map to skip migrated messages efficiently
progress_map = context.state.get_all_last_message_ids()
@ -909,18 +794,123 @@ async def migrate_global_messages(
continue
# If it's a thread message, we need to handle it based on if it's the thread starter or a reply
parent_target_id = None
if hasattr(msg, 'thread') and msg.thread and msg.id == msg.thread.id:
processed_threads.add(msg.thread.id)
stats["threads"] += 1
elif msg.channel.type in [11, 12]: # Thread channels
# It's a message IN a thread.
# In Fluxer, threads might just be linear messages or threaded replies depending on schema
# For basic migration we just send it to the parent mapped target channel.
# The parent mapped target channel ID should already be calculated correctly by get_target_channel_id (which returns mapped thread or parent channel)
pass
# Formatting
files = []
file_names = []
# Always ensure alias is created/retrieved to populate user_alias table
alias = context.state.get_user_alias(str(msg.author.id))
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
if anonymize_users:
author_name = alias or "Anonymized User"
author_avatar_url = None
else:
author_name = msg.author.display_name
author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None
for att in msg.attachments:
media_info = db_media.get(att.local_hash) if db_media else None
local_path = None
if media_info:
local_path = Path(media_info["local_path"])
elif hasattr(att, 'read'):
# Fallback
pass
if local_path and local_path.exists():
files.append(local_path)
file_names.append(att.filename)
content = msg.content or ""
# Stickers
for sticker in msg.stickers:
sticker_name = sticker.name
sticker_url = sticker.url
# Check for uploaded media pool logic first
s_hash = sticker.local_hash
sticker_file = None
s_media = db_media.get(s_hash) if db_media and s_hash else None
if s_media:
s_path = Path(s_media["local_path"])
if s_path.exists():
sticker_file = s_path
content += f"\n[Sticker: {sticker_name}]"
if sticker_file:
files.append(sticker_file)
file_names.append(f"sticker_{sticker_name}.png")
content = clean_mentions(
content=content,
guild=context.discord_reader.guild,
user_mentions=msg.mentions,
role_mentions=msg.role_mentions,
channel_mentions=msg.channel_mentions,
emoji_map=emoji_map,
channel_map=context.state.channel_map,
state=context.state,
target_server_id=target_server_id,
channel_names=context.channel_names if hasattr(context, 'channel_names') else None,
anonymize_users=anonymize_users
)
if not content and not files:
logger.debug(f"Message {msg.id} empty after processing, skipping.")
continue
timestamp_int = int(msg.created_at.timestamp())
if msg.reference and msg.reference.message_id:
# Resolve the author of the message being replied to
source_ref_msg = await context.discord_reader.get_message(msg.channel.id, msg.reference.message_id)
if source_ref_msg and source_ref_msg.author:
ref_author_id = str(source_ref_msg.author.id)
if anonymize_users:
ref_name = context.state.get_user_alias(ref_author_id) or "Anonymized User"
else:
ref_name = source_ref_msg.author.display_name
content = f"`@{ref_name}`\n{content}"
else:
# Fallback if author cannot be resolved (e.g. deleted/missing from backup)
tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id)
if tgt_reply:
content = f"[Reply to {tgt_reply}]\n{content}"
try:
await _process_and_send_message(
context=context,
msg=msg,
target_channel_id=target_channel_id,
stats=stats,
processed_threads=processed_threads
fluxer_msg_id = await context.fluxer_writer.send_message(
channel_id=target_channel_id,
author_name=author_name,
author_avatar_url=author_avatar_url,
content=content,
files=files,
timestamp=timestamp_int,
embeds=msg.embeds
)
if fluxer_msg_id:
context.state.set_target_message_mapping(target_channel_id, msg.id, fluxer_msg_id)
context.state.update_last_message_id(target_channel_id, msg.id)
context.state.set_waterfall_last_id(msg.id)
stats["attachments"] += len(files) if files else 0
stats["messages"] += 1
stats["last_message_content"] = content
stats["last_message_author"] = author_name
if not stats["first_message_url"]:
stats["first_message_url"] = msg.jump_url

View file

@ -36,7 +36,6 @@ class FluxerWriter:
guilds_list.append((label, str(g.id)))
return guilds_list
except Exception as e:
print(f"Failed to fetch Fluxer communities via HTTP: {e}")
logger.error(f"Failed to fetch Fluxer communities via HTTP: {e}")
raise
@ -62,7 +61,6 @@ class FluxerWriter:
return w
except Exception as e:
print(f"Failed to manage webhook for channel {channel_id}: {e}")
logger.error(f"Failed to manage webhook for channel {channel_id}: {e}")
return None
async def start(self):
@ -324,7 +322,6 @@ class FluxerWriter:
logger.debug(f"Fluxer: Webhook send complete, msg_id={msg.id if msg else 'None'}")
return str(msg.id) if msg else None
except asyncio.TimeoutError:
print(f"Fluxer: Webhook send timed out after 45s for channel {channel_id}")
logger.error(f"Fluxer: Webhook send timed out after 45s for channel {channel_id}")
return None
else:
@ -343,24 +340,19 @@ class FluxerWriter:
logger.debug(f"Fluxer: Sending message via bot for user '{author_name}'")
try:
kwargs = {
"channel_id": channel_id,
"content": final_bot_content,
"embeds": normalized_embeds
}
if fluxer_files:
kwargs["files"] = fluxer_files
if message_reference:
kwargs["message_reference"] = message_reference
msg_data = await asyncio.wait_for(
self.client.send_message(**kwargs),
self.client.send_message(
channel_id=channel_id,
content=final_bot_content,
files=fluxer_files,
embeds=normalized_embeds,
message_reference=message_reference
),
timeout=45.0
)
logger.debug(f"Fluxer: Bot send complete, msg_id={msg_data.get('id') if msg_data else 'None'}")
return str(msg_data["id"]) if msg_data else None
except asyncio.TimeoutError:
print(f"Fluxer: Bot send timed out after 45s for channel {channel_id}")
logger.error(f"Fluxer: Bot send timed out after 45s for channel {channel_id}")
return None
except Exception as e:
@ -378,27 +370,22 @@ class FluxerWriter:
fluxer_files = None
if files:
fluxer_files = [f if hasattr(f, "filename") else File(io.BytesIO(f["data"]), filename=f["filename"]) for f in files]
fluxer_files = [File(io.BytesIO(f["data"]), filename=f["filename"]) for f in files]
message_reference = None
if reply_to_message_id:
message_reference = {"message_id": str(reply_to_message_id), "channel_id": str(channel_id)}
try:
kwargs = {
"channel_id": channel_id,
"content": content
}
if fluxer_files:
kwargs["files"] = fluxer_files
if message_reference:
kwargs["message_reference"] = message_reference
msg_data = await self.client.send_message(**kwargs)
msg_data = await self.client.send_message(
channel_id=channel_id,
content=content,
files=fluxer_files,
message_reference=message_reference
)
return str(msg_data["id"]) if msg_data else None
except Exception as e:
print(f"Failed to send marker: {e}")
logger.error(f"Failed to send marker: {e}")
return None
async def create_role(self, name: str, color: int, hoist: bool, mentionable: bool, permissions: int, position: Optional[int] = None) -> str:
@ -421,7 +408,6 @@ class FluxerWriter:
return str(role["id"])
except Exception as e:
print(f"Failed to copy role {name}: {e}")
logger.error(f"Failed to copy role {name}: {e}")
return ""
async def create_emoji(self, name: str, image_bytes: bytes) -> str:
@ -487,7 +473,6 @@ class FluxerWriter:
)
except Exception as e:
print(f"Failed to update community metadata: {e}")
logger.error(f"Failed to update community metadata: {e}")
async def remove_community_logo_and_banner(self) -> dict:
"""
@ -518,7 +503,6 @@ class FluxerWriter:
)
except Exception as e:
print(f"Failed to remove community icon: {e}")
logger.error(f"Failed to remove community icon: {e}")
# 3. Remove banner if set
if has_banner:
@ -529,7 +513,6 @@ class FluxerWriter:
)
except Exception as e:
print(f"Failed to remove community banner: {e}")
logger.error(f"Failed to remove community banner: {e}")
return {
"icon": "REMOVED" if has_icon else "SKIP",
@ -561,7 +544,6 @@ class FluxerWriter:
await progress_callback(ch.get("name", "Unknown"), deleted, total)
except Exception as e:
print(f"Failed to delete channel {ch.get('name')}: {e}")
logger.error(f"Failed to delete channel {ch.get('name')}: {e}")
return deleted
async def reset_channel_permissions(self, progress_callback=None) -> int:
@ -594,14 +576,12 @@ class FluxerWriter:
)
)
except Exception as e:
print(f"Failed to delete overwrite {ow['id']} for channel {ch['id']}: {e}")
logger.error(f"Failed to delete overwrite {ow['id']} for channel {ch['id']}: {e}")
processed += 1
if progress_callback:
await progress_callback(ch.get("name", "Unknown"), processed, total)
except Exception as e:
print(f"Failed to reset permissions for channel {ch.get('name')}: {e}")
logger.error(f"Failed to reset permissions for channel {ch.get('name')}: {e}")
return processed
async def set_channel_permission(self, channel_id: str, overwrite_id: str, allow: int, deny: int, is_role: bool = True):
@ -623,7 +603,6 @@ class FluxerWriter:
type=0 if is_role else 1
)
except Exception as e:
print(f"Failed to set permission on channel {channel_id} for overwrite {overwrite_id}: {e}")
logger.error(f"Failed to set permission on channel {channel_id} for overwrite {overwrite_id}: {e}")
@ -665,7 +644,6 @@ class FluxerWriter:
await progress_callback(role.get("name", "Unknown"), deleted, total)
except Exception as e:
print(f"Failed to delete role {role.get('name')}: {e}")
logger.error(f"Failed to delete role {role.get('name')}: {e}")
return deleted
async def delete_all_emojis_and_stickers(self, progress_callback=None) -> dict:
@ -689,10 +667,8 @@ class FluxerWriter:
await progress_callback(emoji.get("name", "Unknown"), "Emoji", emoji_deleted, emoji_total)
except Exception as e:
print(f"Failed to delete emoji {emoji.get('name')}: {e}")
logger.error(f"Failed to delete emoji {emoji.get('name')}: {e}")
except Exception as e:
print(f"Failed to fetch emojis: {e}")
logger.error(f"Failed to fetch emojis: {e}")
# Delete stickers
try:
@ -706,10 +682,8 @@ class FluxerWriter:
await progress_callback(sticker.get("name", "Unknown"), "Sticker", sticker_deleted, sticker_total)
except Exception as e:
print(f"Failed to delete sticker {sticker.get('name')}: {e}")
logger.error(f"Failed to delete sticker {sticker.get('name')}: {e}")
except Exception as e:
print(f"Failed to fetch stickers: {e}")
logger.error(f"Failed to fetch stickers: {e}")
return {"emojis": emoji_deleted, "stickers": sticker_deleted}

View file

@ -16,52 +16,41 @@ async def sync_channel_state(context: MigrationContext):
channels = await context.discord_reader.get_channels()
target_channels = await context.writer.get_channels()
# Build maps for Stoat lookup
# {name: id} for categories (type 4)
target_cats = {c.get("name"): str(c.get("id")) for c in target_channels if c.get("type") == 4}
# {parent_id: {name: id}} for channels
target_structure = {}
for c in target_channels:
if c.get("type") == 4: continue
p_id = str(c.get("parent_id")) if c.get("parent_id") else "root"
if p_id not in target_structure: target_structure[p_id] = {}
target_structure[p_id][c.get("name")] = str(c.get("id"))
# Build name -> id map and ID set for Stoat for fast lookup
target_name_map = {c.get("name"): str(c.get("id")) for c in target_channels if c.get("name")}
target_id_set = {str(c.get("id")) for c in target_channels}
updates = 0
removals = 0
# 1. Sync Categories
# 1. Verify and Sync Categories
for cat in categories:
discord_id = str(cat.id)
target_id = context.state.get_target_category_id(discord_id)
if target_id:
if target_id not in target_id_set:
context.state.remove_category_mapping(discord_id)
removals += 1
elif cat.name in target_cats:
context.state.set_target_category_mapping(discord_id, target_cats[cat.name])
elif cat.name in target_name_map:
context.state.set_target_category_mapping(discord_id, target_name_map[cat.name])
updates += 1
# 2. Sync Channels (parent-aware)
# 2. Verify and Sync Channels
for ch in channels:
discord_id = str(ch.id)
target_id = context.state.get_target_channel_id(discord_id)
if target_id:
if target_id not in target_id_set:
context.state.remove_channel_mapping(discord_id)
removals += 1
else:
# Try to match by name within the mapped parent category
p_discord_id = str(ch.category_id) if ch.category_id else "root"
p_target_id = context.state.get_target_category_id(p_discord_id) if p_discord_id != "root" else "root"
if p_target_id in target_structure and ch.name in target_structure[p_target_id]:
context.state.set_target_channel_mapping(discord_id, target_structure[p_target_id][ch.name])
updates += 1
elif ch.name in target_name_map:
context.state.set_target_channel_mapping(discord_id, target_name_map[ch.name])
updates += 1
if updates > 0 or removals > 0:
logger.info(f"Stoat Channel sync: {updates} mapped, {removals} stale mappings removed")
logger.info(f"Channel sync: {updates} mapped, {removals} stale mappings removed")
async def migrate_channels(context: MigrationContext, progress_callback: Callable[[str, str, int, int], Awaitable[None]] | None = None, force: bool = False) -> dict:
@ -118,22 +107,7 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
if total == 0:
return cloned_info
# 1. Create missing categories first
for cat in missing_categories:
if not context.is_running: break
state_key = str(cat.id)
target_id = await context.writer.create_channel(cat.name, type=4)
if target_id:
context.state.set_target_category_id(state_key, target_id)
cloned_info["categories_created"].append(cat.name)
if cat.name not in cloned_info["structure"]:
cloned_info["structure"][cat.name] = []
current_idx += 1
if progress_callback: await progress_callback(cat.name, "Copying", current_idx, total)
# 2. Create missing channels (now with parent_id available)
# 1. Create missing channels (unparented for now)
for channel in channels_to_create:
if not context.is_running: break
@ -145,6 +119,7 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
logger.debug(f"Creating channel {channel.name}: topic={topic}, nsfw={nsfw}, slowmode={slowmode}")
# Map Discord-specific types to target-supported types
# 5 (News) -> 0 (Text), and fallback any unknown non-voice types to text
raw_type = channel.type.value if hasattr(channel.type, 'value') else 0
if raw_type == context.discord_reader.CHANNEL_TYPE_VOICE.value:
ch_type = 2
@ -153,22 +128,20 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
ch_type = 0
is_voice = False
else:
# Fallback for Stage channels (13) etc. to Text for safety
ch_type = 0
is_voice = False
# Resolve parent category
parent_id = context.state.get_target_category_id(str(channel.category_id)) if channel.category_id else None
target_id = await context.writer.create_channel(
name=channel.name,
topic=topic if not is_voice else "",
type=ch_type,
parent_id=parent_id,
parent_id=None,
nsfw=nsfw if not is_voice else False,
slowmode_delay=slowmode if not is_voice else 0
)
if target_id:
context.state.set_target_channel_id(state_key, target_id)
context.state.set_target_channel_mapping(state_key, target_id)
cloned_info["channels_created"].append(channel.name)
parent_name = cat_name_map.get(str(channel.category_id), "No Category") if channel.category_id else "No Category"
@ -183,8 +156,7 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
name=channel.name,
topic=topic,
nsfw=nsfw,
slowmode_delay=slowmode,
parent_id=parent_id
slowmode_delay=slowmode
)
current_idx += 1
@ -200,21 +172,32 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
logger.debug(f"Syncing existing channel {channel.name} ({target_id}): topic={topic}, nsfw={nsfw}, slowmode={slowmode}")
parent_id = context.state.get_target_category_id(str(channel.category_id)) if channel.category_id else None
await context.writer.modify_channel(
channel_id=target_id,
name=channel.name,
topic=topic,
nsfw=nsfw,
slowmode_delay=slowmode,
parent_id=parent_id
slowmode_delay=slowmode
)
cloned_info["channels_synced"].append(channel.name)
current_idx += 1
if progress_callback: await progress_callback(channel.name, "Syncing", current_idx, total)
# 3. Create missing categories
for cat in missing_categories:
if not context.is_running: break
state_key = str(cat.id)
target_id = await context.writer.create_channel(cat.name, type=4)
if target_id:
context.state.set_target_category_mapping(state_key, target_id)
cloned_info["categories_created"].append(cat.name)
if cat.name not in cloned_info["structure"]:
cloned_info["structure"][cat.name] = []
current_idx += 1
if progress_callback: await progress_callback(f"Cat: {cat.name}", "Copying", current_idx, total)
# 4. Final step: Parent the channels into categories via mass server.edit()
@ -240,7 +223,7 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
# Resolve Stoat categories
# We iterate over the categories from the server to ensure we don't drop any
for stoat_cat in (server.categories or []):
for stoat_cat in server.categories:
# Check if this Stoat category maps to any Discord category
discord_cat_id = next((d_id for d_id, t_id in context.state.category_map.items() if t_id == str(stoat_cat.id)), None)
@ -278,35 +261,8 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
target_categories.sort(key=get_cat_position)
try:
import aiohttp
api_base = (context.writer.api_url or "https://api.stoat.chat/0.8").rstrip("/")
cats_payload = []
for c in target_categories:
entry: dict = {"id": str(c.id), "title": c.title, "channels": [str(x) for x in c.channels]}
if getattr(c, "default_permissions", None) is not None:
entry["default_permissions"] = c.default_permissions
if getattr(c, "role_permissions", None):
entry["role_permissions"] = c.role_permissions
cats_payload.append(entry)
for _attempt in range(5):
async with aiohttp.ClientSession() as session:
async with session.patch(
f"{api_base}/servers/{context.writer.community_id}",
json={"categories": cats_payload},
headers={"X-Bot-Token": context.writer.token.strip(), "Content-Type": "application/json"},
) as resp:
if resp.status == 429:
import re as _re
body = await resp.json(content_type=None)
wait_ms = body.get("retry_after", 10000)
logger.warning(f"Rate limited on mass parent, retrying in {wait_ms/1000:.1f}s…")
await asyncio.sleep(wait_ms / 1000 + 0.5)
continue
if resp.status not in (200, 204):
body = await resp.json(content_type=None)
raise Exception(f"HTTP {resp.status}: {body}")
logger.info("Successfully parented all channels.")
break
await server.edit(categories=target_categories)
logger.info("Successfully parented all channels.")
except Exception as ex:
logger.error(f"Failed to mass parent channels: {ex}")

View file

@ -155,192 +155,7 @@ async def get_channel_threads(reader: Any, channel_id: int) -> List[Any]:
return threads
async def _process_and_send_message(
context: MigrationContext,
msg: Any,
target_channel_id: str,
stats: Dict[str, Any],
thread_id: str | None = None,
parent_target_id: str | None = None,
thread_name: str | None = None,
processed_threads: set | None = None
) -> str | None:
"""
Internal helper to process a single Discord message (mentions, attachments, stickers)
and send it to the Stoat platform.
"""
# 1. Processing Flags
is_forwarded = False
if hasattr(msg.flags, 'forwarded'):
is_forwarded = msg.flags.forwarded
# 2. Content & Formatting
content = msg.content or ""
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
alias = context.state.get_user_alias(str(msg.author.id))
# Process Stickers
files = []
if hasattr(msg, 'stickers') and msg.stickers:
for s in msg.stickers:
try:
sticker_data = await context.discord_reader.download_sticker(s)
if not sticker_data: continue
format_val = getattr(s, 'format', 'png')
if hasattr(format_val, 'name'):
ext = format_val.name.lower()
elif isinstance(format_val, int):
format_map = {1: 'png', 2: 'apng', 3: 'lottie', 4: 'gif'}
ext = format_map.get(format_val, 'png')
else:
ext = str(format_val).lower()
# Conversion logic for Stoat (WebP or GIF focus)
if ext == 'lottie' and HAS_LOTTIE:
try:
lottie_data = json.loads(sticker_data)
def _convert_lottie_to_gif(data):
animation = Animation.load(data)
output = io.BytesIO()
export_gif(animation, output)
return output.getvalue()
sticker_data = await asyncio.to_thread(_convert_lottie_to_gif, lottie_data)
ext = 'gif'
except Exception: ext = 'json'
elif ext == 'apng':
try:
from PIL import Image
def _convert_apng_to_gif(data):
img = Image.open(io.BytesIO(data))
gif_buf = io.BytesIO()
if getattr(img, 'n_frames', 1) > 1:
frames = []
durations = []
for i in range(img.n_frames):
img.seek(i)
frame = img.convert('RGBA')
current_frame = Image.new('RGBA', img.size, (0,0,0,0))
current_frame.paste(frame, (0, 0))
frames.append(current_frame)
durations.append(img.info.get('duration', 100))
frames[0].save(gif_buf, format='GIF', save_all=True, append_images=frames[1:], loop=0, duration=durations, disposal=2, transparency=0)
else: img.save(gif_buf, format='GIF')
return gif_buf.getvalue()
sticker_data = await asyncio.to_thread(_convert_apng_to_gif, sticker_data)
ext = 'gif'
except Exception: pass
files.append({
"filename": f"sticker_{s.name}_{s.id}.{ext}",
"data": sticker_data,
"content_type": f"image/{ext}" if ext != "json" else "application/json"
})
stats["attachments"] += 1
except Exception as e:
logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}")
# Process Attachments
attachments_to_process = list(msg.attachments)
if is_forwarded and hasattr(msg, 'message_snapshots') and msg.message_snapshots:
snapshot = msg.message_snapshots[0]
if not content: content = snapshot.content
attachments_to_process.extend(snapshot.attachments)
for att in attachments_to_process:
try:
att_data = await context.discord_reader.download_attachment(att)
files.append({
"filename": att.filename,
"data": att_data,
"content_type": getattr(att, "content_type", None)
})
stats["attachments"] += 1
except Exception as e:
logger.error(f"Failed to download attachment {att.filename}: {e}")
# Clean Mentions
content = clean_mentions(
content=content,
guild=context.discord_reader.guild,
user_mentions=msg.mentions,
role_mentions=msg.role_mentions,
channel_mentions=msg.channel_mentions,
emoji_map=context.state.emoji_map,
channel_map=context.state.channel_map,
state=context.state,
target_server_id=context.stoat_writer.community_id,
channel_names=context.channel_names if hasattr(context, 'channel_names') else None,
anonymize_users=anonymize_users
)
if not content and not files:
return None
# Reply Resolution
reply_to_stoat_id = None
if msg.reference and msg.reference.message_id:
reply_to_stoat_id = context.state.get_target_message_id(target_channel_id, str(msg.reference.message_id))
if not reply_to_stoat_id:
# Fallback author tagging
try:
source_ref_msg = await context.discord_reader.get_message(msg.channel.id, msg.reference.message_id)
if source_ref_msg and source_ref_msg.author:
ref_name = context.state.get_user_alias(str(source_ref_msg.author.id)) if anonymize_users else source_ref_msg.author.display_name
content = f"`@{ref_name}`\n{content}"
else:
tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id)
if tgt_reply: content = f"[Reply to {tgt_reply}]\n{content}"
except Exception: pass
# Thread logic
if not reply_to_stoat_id and parent_target_id and stats["messages"] == 0:
reply_to_stoat_id = parent_target_id
if thread_name and stats["messages"] == 0:
content = f"> <<< THREAD: **{thread_name}** >>>\n{content}"
# Author resolution
if anonymize_users:
author_name = alias or "Anonymized User"
author_avatar_url = None
else:
author_name = msg.author.display_name
author_avatar_url = str(msg.author.display_avatar.url) if msg.author.display_avatar.url else None
if author_avatar_url and not author_avatar_url.startswith("http"): author_avatar_url = None
# Send Message
stoat_msg_id = await context.stoat_writer.send_message(
channel_id=target_channel_id,
author_name=author_name,
author_avatar_url=author_avatar_url,
content=content,
timestamp=int(msg.created_at.timestamp()),
files=files if files else None,
reply_to_message_id=reply_to_stoat_id,
is_forwarded=is_forwarded,
embeds=msg.embeds
)
if stoat_msg_id:
if thread_id:
context.state.set_thread_message_mapping(target_channel_id, thread_id, str(msg.id), stoat_msg_id)
context.state.update_thread_last_message_timestamp(target_channel_id, thread_id, str(msg.created_at))
context.state.update_thread_last_message_id(target_channel_id, thread_id, str(msg.id))
context.state.increment_thread_stats(target_channel_id, thread_id, messages=1, files=len(files) if files else 0)
else:
context.state.set_message_mapping(target_channel_id, str(msg.id), stoat_msg_id)
context.state.update_last_message_timestamp(target_channel_id, str(msg.created_at))
context.state.update_last_message_id(target_channel_id, str(msg.id))
context.state.increment_stats(target_channel_id, messages=1, files=len(files) if files else 0)
stats["messages"] += 1
stats["last_message_content"] = content
stats["last_message_author"] = msg.author.display_name
return stoat_msg_id
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None, processed_threads: set | None = None) -> Dict[str, int]:
"""
Scans channel history to count messages, threads, and attachments.
"""
@ -731,30 +546,87 @@ async def migrate_messages(
except Exception as e:
logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}")
# Check for existing mapping to avoid duplicates when resuming
if context.state.get_target_message_id(target_channel_id, str(msg.id)):
continue
try:
stoat_msg_id = await _process_and_send_message(
context=context,
msg=msg,
target_channel_id=target_channel_id,
stats=stats,
thread_id=thread_id,
parent_target_id=parent_target_id,
thread_name=thread_name,
processed_threads=processed_threads
)
# Check for existing mapping to avoid duplicates when resuming
if context.state.get_target_message_id(target_channel_id, str(msg.id)):
continue
# Check for associated thread (Individual mode recursion)
# Check if this message is a reply
reply_to_stoat_id = None
if msg.reference and msg.reference.message_id:
reply_to_stoat_id = context.state.get_target_message_id(target_channel_id, str(msg.reference.message_id))
if reply_to_stoat_id:
logger.debug(f"Detected reply to Discord ID {msg.reference.message_id} -> Stoat ID {reply_to_stoat_id}")
else:
logger.debug(f"Reply target Discord ID {msg.reference.message_id} not found in current session map.")
# If this is the FIRST thread message and we have a parent_target_id, force it as reply to the starter
if not reply_to_stoat_id and parent_target_id and stats["messages"] == 0:
reply_to_stoat_id = parent_target_id
# Prepend thread marker to the first message of the thread
if thread_name and stats["messages"] == 0:
content = f"> <<< THREAD: **{thread_name}** >>>\n{content}"
# Always ensure alias is created/retrieved to populate user_alias table
alias = context.state.get_user_alias(str(msg.author.id))
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
if anonymize_users:
author_name = alias or "Anonymized User"
author_avatar_url = None
else:
author_name = msg.author.display_name
author_avatar_url = str(msg.author.display_avatar.url) if msg.author.display_avatar.url else None
if author_avatar_url and not author_avatar_url.startswith("http"):
author_avatar_url = None
logger.debug(f"Stoat: Calling send_message for Discord ID {msg.id}")
stoat_msg_id = await context.stoat_writer.send_message(
channel_id=target_channel_id,
author_name=author_name,
author_avatar_url=author_avatar_url,
content=content,
timestamp=int(msg.created_at.timestamp()),
files=files if files else None,
reply_to_message_id=reply_to_stoat_id,
is_forwarded=is_forwarded,
embeds=msg.embeds
)
if stoat_msg_id:
if thread_id:
context.state.set_thread_message_mapping(target_channel_id, thread_id, str(msg.id), stoat_msg_id)
else:
context.state.set_message_mapping(target_channel_id, str(msg.id), stoat_msg_id)
if thread_id:
context.state.update_thread_last_message_timestamp(target_channel_id, thread_id, str(msg.created_at))
context.state.update_thread_last_message_id(target_channel_id, thread_id, str(msg.id))
context.state.increment_thread_stats(target_channel_id, thread_id, messages=1, files=len(files) if files else 0)
else:
context.state.update_last_message_timestamp(target_channel_id, str(msg.created_at))
context.state.update_last_message_id(target_channel_id, str(msg.id))
context.state.increment_stats(target_channel_id, messages=1, files=len(files) if files else 0)
stats["messages"] += 1
stats["last_message_content"] = content
stats["last_message_author"] = msg.author.display_name
# Check for associated thread (Normal case: parent message is migrated)
if hasattr(msg, 'thread') and msg.thread:
thread = msg.thread
if thread.id not in processed_threads:
processed_threads.add(thread.id)
# Track thread entry
stats["threads"] += 1
# Fetch last migrated message ID for this thread
thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(thread.id))
if thread_after_id:
logger.info(f"Resuming thread '{thread.name}' from after message ID: {thread_after_id}")
# Migrate thread messages recursively
thread_stats = await migrate_messages(
context=context,
source_channel_id=thread.id,
@ -769,6 +641,7 @@ async def migrate_messages(
stats["attachments"] += thread_stats["attachments"]
stats["threads"] += thread_stats["threads"]
# Send End Marker
if context.is_running:
await context.stoat_writer.send_marker(
channel_id=target_channel_id,
@ -783,11 +656,12 @@ async def migrate_messages(
if progress_callback:
await progress_callback(stats)
except Exception as e:
if "MissingPermission" in str(e): raise
# If it's a permission error, stop the entire migration
if "MissingPermission" in str(e):
raise
logger.error(f"Failed to process message {msg.id}: {e}")
import traceback
logger.error(traceback.format_exc())
# Mark thread as completed if we finished the loop without being interrupted
if thread_id and context.is_running:
@ -921,14 +795,106 @@ async def migrate_global_messages(
elif msg.channel.type in [11, 12]:
pass
# Formatting
files = []
# Always ensure alias is created/retrieved to populate user_alias table
alias = context.state.get_user_alias(str(msg.author.id))
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
if anonymize_users:
author_name = alias or "Anonymized User"
author_avatar_url = None
else:
author_name = msg.author.display_name
author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None
for att in msg.attachments:
media_info = db_media.get(att.local_hash) if db_media else None
local_path = None
if media_info:
local_path = Path(media_info["local_path"])
if local_path and local_path.exists():
try:
with open(local_path, "rb") as f:
files.append({"filename": att.filename, "data": f.read()})
except Exception as fe:
logger.error(f"Failed to read file {local_path}: {fe}")
content = msg.content or ""
for sticker in msg.stickers:
sticker_name = sticker.name
s_hash = sticker.local_hash
sticker_file = None
s_media = db_media.get(s_hash) if db_media and s_hash else None
if s_media:
s_path = Path(s_media["local_path"])
if s_path.exists():
sticker_file = s_path
content += f"\n[Sticker: {sticker_name}]"
if sticker_file:
files.append(sticker_file)
file_names.append(f"sticker_{sticker_name}.png")
content = clean_mentions(
content=content,
guild=context.discord_reader.guild,
user_mentions=msg.mentions,
role_mentions=msg.role_mentions,
channel_mentions=msg.channel_mentions,
emoji_map=emoji_map,
channel_map=context.state.channel_map,
state=context.state,
target_server_id=target_server_id,
channel_names=context.channel_names if hasattr(context, 'channel_names') else None,
anonymize_users=anonymize_users
)
if not content and not files:
logger.debug(f"Message {msg.id} empty after processing, skipping.")
continue
timestamp_int = int(msg.created_at.timestamp())
if msg.reference and msg.reference.message_id:
# Resolve the author of the message being replied to
source_ref_msg = await context.discord_reader.get_message(msg.channel_id, msg.reference.message_id)
if source_ref_msg and source_ref_msg.author:
ref_author_id = str(source_ref_msg.author.id)
if anonymize_users:
ref_name = context.state.get_user_alias(ref_author_id) or "Anonymized User"
else:
ref_name = source_ref_msg.author.display_name
content = f"`@{ref_name}`\n{content}"
else:
tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id)
if tgt_reply:
content = f"[Reply to {tgt_reply}]\n{content}"
try:
await _process_and_send_message(
context=context,
msg=msg,
target_channel_id=target_channel_id,
stats=stats,
processed_threads=processed_threads
stoat_msg_id = await context.stoat_writer.send_message(
channel_id=target_channel_id,
author_name=author_name,
author_avatar_url=author_avatar_url,
content=content,
files=files,
timestamp=timestamp_int,
embeds=msg.embeds
)
if stoat_msg_id:
context.state.set_target_message_mapping(target_channel_id, msg.id, stoat_msg_id)
context.state.update_last_message_id(target_channel_id, msg.id)
context.state.set_waterfall_last_id(msg.id)
stats["attachments"] += len(files) if files else 0
stats["messages"] += 1
stats["last_message_content"] = content
stats["last_message_author"] = author_name
if not stats["first_message_url"]:
stats["first_message_url"] = msg.jump_url
@ -939,7 +905,6 @@ async def migrate_global_messages(
except Exception as e:
logger.error(f"Failed to process global message {msg.id}: {e}")
except (KeyboardInterrupt, asyncio.CancelledError):
context.is_running = False

View file

@ -5,90 +5,24 @@ from typing import Optional, List, Dict, Any
logger = logging.getLogger(__name__)
async def _discover_stoat_config(api_url: str) -> Dict[str, Optional[str]]:
"""
Fetches the Stoat/Revolt instance configuration to discover the WS and CDN URLs.
Returns a dict with 'ws' and 'cdn' keys.
"""
results = {"ws": None, "cdn": None}
if not api_url or api_url == "default" or "stoat.chat" in api_url:
return results
import aiohttp
try:
# Standard Revolt discovery endpoint is the root API URL
async with aiohttp.ClientSession() as session:
async with session.get(api_url.rstrip("/") + "/") as resp:
if resp.status == 200:
data = await resp.json()
results["ws"] = data.get("ws")
# Features might be in 'features.autumn.url'
features = data.get("features", {})
autumn = features.get("autumn", {})
results["cdn"] = autumn.get("url")
if results["ws"] or results["cdn"]:
logger.debug(f"Stoat Discovery: Found WS={results['ws']}, CDN={results['cdn']}")
return results
except Exception as e:
logger.debug(f"Stoat Discovery failed (fetching): {e}")
# Fallback to inference if fetch failed
from urllib.parse import urlparse
try:
parsed = urlparse(api_url)
if parsed.netloc:
# Traditional defaults for self-hosted
results["ws"] = f"wss://{parsed.netloc}/ws"
results["cdn"] = f"https://{parsed.netloc}/autumn"
except Exception:
pass
return results
class StoatWriter:
def __init__(self, token: str, community_id: str, api_url: str = "default", ws_url: str = None):
def __init__(self, token: str, community_id: str, api_url: str = "default"):
self.token = token
self.community_id = str(community_id)
self.api_url = api_url
self.ws_url = ws_url
self.client: Optional[stoat.Client] = None
self._server = None
self._me = None
self._validation_cache = None
@staticmethod
async def fetch_guilds(token: str, api_url: str = "default", ws_url: str = None, cdn_url: str = None) -> list[tuple[str, str]]:
async def fetch_guilds(token: str, api_url: str = "default") -> list[tuple[str, str]]:
"""Fetches the list of Stoat servers the bot is in. Returns list of (label, id)."""
token = token.strip()
api_url = (api_url or "default").strip()
ws_url = (ws_url or "").strip()
cdn_url = (cdn_url or "").strip()
# Auto-discover URLs if not provided for custom domains
if api_url != "default" and (not ws_url or not cdn_url):
discovery = await _discover_stoat_config(api_url)
if not ws_url: ws_url = discovery["ws"] or ""
if not cdn_url: cdn_url = discovery["cdn"] or ""
# Diagnostics to both stdout and logger
log_msg = f"Stoat: Fetching guilds using API URL: {api_url}"
if ws_url: log_msg += f" [WS: {ws_url}]"
if cdn_url: log_msg += f" [CDN: {cdn_url}]"
print(log_msg)
logger.debug(log_msg)
client_kwargs = {
"token": token,
"bot": True,
"http_base": api_url if api_url != "default" else None,
"websocket_base": ws_url or None,
"cdn_base": cdn_url or None
}
client_kwargs = {"token": token, "bot": True}
if api_url and api_url != "default":
client_kwargs["http_base"] = api_url
client = stoat.Client(**client_kwargs)
logger.debug(f"Stoat: Initialized client with native http_base={client_kwargs['http_base']} and websocket_base={client_kwargs['websocket_base']}")
ready_event = asyncio.Event()
servers_list = []
@ -124,21 +58,14 @@ class StoatWriter:
else:
raise asyncio.TimeoutError("Timed out waiting for Stoat to be ready")
except Exception as e:
print(f"Failed to fetch Stoat servers: {e}")
logger.error(f"Failed to fetch Stoat servers: {e}")
raise
finally:
# Shutdown the specific client instance used for fetching
try:
await client.close()
except Exception:
pass
await client.close()
client_task.cancel()
try:
# Wait for the task to actually finish terminating
await asyncio.wait_for(client_task, timeout=2.0)
except (asyncio.CancelledError, asyncio.TimeoutError):
await client_task
except asyncio.CancelledError:
pass
return guilds_list
@ -153,45 +80,16 @@ class StoatWriter:
pass
self.client = None
api_url = (self.api_url or "default").strip()
ws_url = (self.ws_url or "").strip()
cdn_url = getattr(self, "cdn_url", "").strip()
# Auto-discover if not provided for custom domains
if api_url != "default" and (not ws_url or not cdn_url):
discovery = await _discover_stoat_config(api_url)
if not ws_url: ws_url = discovery["ws"] or ""
if not cdn_url: cdn_url = discovery["cdn"] or ""
token = self.token.strip()
# Diagnostics to both stdout and logger
log_msg = f"Stoat: Starting client using API URL: {api_url}"
if ws_url: log_msg += f" [WS: {ws_url}]"
if cdn_url: log_msg += f" [CDN: {cdn_url}]"
print(log_msg)
logger.debug(log_msg)
client_kwargs = {
"token": token,
"bot": True,
"http_base": api_url if api_url != "default" else None,
"websocket_base": ws_url or None,
"cdn_base": cdn_url or None
}
client_kwargs = {"token": self.token, "bot": True}
if self.api_url and self.api_url != "default":
client_kwargs["http_base"] = self.api_url
self.client = stoat.Client(**client_kwargs)
# Keep track of the start task so we can clean it up later
self._start_task = asyncio.create_task(self.client.start())
try:
self._me = await self.client.fetch_user("@me")
except Exception as e:
print(f"Failed to fetch bot user in StoatWriter: {e}")
logger.error(f"Failed to fetch bot user in StoatWriter: {e}")
# Ensure we clean up if start fails
await self.close()
self.client = None
self.client = None # Reset if we can't even fetch @me
@property
def my_id(self):
@ -337,93 +235,68 @@ class StoatWriter:
return results
except Exception as e:
print(f"Failed to fetch Stoat channels: {e}")
logger.error(f"Failed to fetch Stoat channels: {e}")
return []
async def create_channel(self, name: str, type: int = 0, topic: str = "", parent_id: Optional[str] = None, **kwargs) -> str:
server = await self._get_server(populate_channels=True)
try:
if type == 4: # Category — use direct HTTP to avoid stoat.py auth issues
import aiohttp, random, time
if type == 4: # Category
# The POST /categories endpoint throws 404 on some server versions, so we use server.edit(categories)
import random
import time
chars = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"
ts = int(time.time() * 1000)
new_id = format(ts, '010X') + "".join(random.choice(chars) for _ in range(16))
api_base = (self.api_url or "https://api.stoat.chat/0.8").rstrip("/")
# Fetch current server to get existing categories
async with aiohttp.ClientSession() as session:
async with session.get(
f"{api_base}/servers/{self.community_id}",
headers={"X-Bot-Token": self.token.strip()},
) as resp:
sdata = await resp.json()
existing = sdata.get("categories") or []
existing.append({"id": new_id, "title": name, "channels": []})
async with session.patch(
f"{api_base}/servers/{self.community_id}",
json={"categories": existing},
headers={"X-Bot-Token": self.token.strip(), "Content-Type": "application/json"},
) as resp:
data = await resp.json()
logger.debug(f"PATCH /servers category resp {resp.status}: {data}")
print(f"[writer] PATCH /servers category resp {resp.status}: {data}")
if resp.status not in (200, 201):
raise Exception(f"HTTP {resp.status}: {data}")
self._server = None
# Mock a ULID
new_id = "01" + "".join(random.choice(chars) for _ in range(24))
categories = list(server.categories) if hasattr(server, "categories") and server.categories else []
# Workaround for stoat.py bug: existing categories may fail to_dict() if slots are uninitialized
for c in categories:
if not hasattr(c, "default_permissions"): c.default_permissions = None
if not hasattr(c, "role_permissions"): c.role_permissions = {}
new_cat = stoat.Category(id=new_id, title=name, channels=[])
if not hasattr(new_cat, "default_permissions"): new_cat.default_permissions = None
if not hasattr(new_cat, "role_permissions"): new_cat.role_permissions = {}
categories.append(new_cat)
await server.edit(categories=categories)
self._server = None # Clear cache after structural change
return new_id
else:
# Use direct HTTP instead of stoat.py client to avoid aiohttp session issues
import aiohttp
api_base = (self.api_url or "https://api.stoat.chat/0.8").rstrip("/")
channel_type = "Voice" if type == 2 else "Text"
payload: Dict[str, Any] = {"name": name, "type": channel_type}
if topic:
payload["description"] = topic
async with aiohttp.ClientSession() as session:
async with session.post(
f"{api_base}/servers/{self.community_id}/channels",
json=payload,
headers={"X-Bot-Token": self.token.strip()},
) as resp:
try:
data = await resp.json(content_type=None)
except Exception:
data = await resp.text()
if resp.status not in (200, 201):
raise Exception(f"HTTP {resp.status}: {data}")
self._server = None
return str(data["_id"])
elif type == 2: # Voice Channel
ch = await server.create_voice_channel(name=name)
self._server = None # Clear cache
return str(ch.id)
else: # Text Channel
ch = await server.create_text_channel(name=name, description=topic)
self._server = None # Clear cache
return str(ch.id)
except Exception as e:
print(f"Failed to create Stoat channel {name}: {e}")
logger.error(f"Failed to create Stoat channel {name}: {e}")
return ""
async def modify_channel(self, channel_id: str, name: Optional[str] = None, topic: Optional[str] = None, nsfw: Optional[bool] = None, slowmode_delay: Optional[int] = None, **kwargs) -> bool:
import aiohttp
api_base = (self.api_url or "https://api.stoat.chat/0.8").rstrip("/")
payload: Dict[str, Any] = {}
if name is not None:
payload["name"] = name
if topic is not None:
payload["description"] = topic
if nsfw is not None:
payload["nsfw"] = nsfw
if not payload:
return True
server = await self._get_server(populate_channels=True)
try:
async with aiohttp.ClientSession() as session:
async with session.patch(
f"{api_base}/channels/{channel_id}",
json=payload,
headers={"X-Bot-Token": self.token.strip()},
) as resp:
if resp.status not in (200, 204):
data = await resp.json()
raise Exception(f"HTTP {resp.status}: {data}")
self._server = None
channel = next((c for c in server.channels if str(c.id) == channel_id), None)
if not channel:
return False
edit_kwargs = {}
if name is not None:
edit_kwargs["name"] = name
if topic is not None:
edit_kwargs["description"] = topic
if nsfw is not None:
edit_kwargs["nsfw"] = nsfw
if edit_kwargs:
await channel.edit(**edit_kwargs)
self._server = None # Clear cache
# clone_server.py now handles all parenting bulk logic
return True
except Exception as e:
print(f"Failed to modify Stoat channel {channel_id}: {e}")
logger.error(f"Failed to modify Stoat channel {channel_id}: {e}")
return False
@ -524,26 +397,14 @@ class StoatWriter:
color=color
))
# Retry logic for 'NotFound' (common race condition on self-hosted instances)
max_tries = 2
for attempt in range(max_tries):
try:
msg = await channel.send(
content=final_content,
masquerade=masquerade,
replies=replies,
attachments=attachments,
embeds=stoat_embeds
)
return str(msg.id) if msg else None
except Exception as send_err:
# 'NotFound' often means the attachment is still being processed by the database
if "NotFound" in str(send_err) and attempt < max_tries - 1:
logger.warning(f"Stoat: Received NotFound during send (likely race condition). Retrying in 1.5s... (Attempt {attempt+1}/{max_tries})")
await asyncio.sleep(1.5)
continue
raise send_err
msg = await channel.send(
content=final_content,
masquerade=masquerade,
replies=replies,
attachments=attachments,
embeds=stoat_embeds
)
return str(msg.id) if msg else None
except Exception as e:
# If file type not allowed, skip attachments and still send the message
if "FileTypeNotAllowed" in str(e) and attachments:
@ -558,7 +419,6 @@ class StoatWriter:
return str(msg.id) if msg else None
raise # Re-raise MissingPermission and other errors
except Exception as e:
print(f"Failed to send Stoat message to {channel_id}: {e}")
logger.error(f"Failed to send Stoat message to {channel_id}: {e}")
raise # Let caller handle (migration loop will stop for permission errors)
@ -582,7 +442,6 @@ class StoatWriter:
)
return str(msg.id)
except Exception as e:
print(f"Failed to send Stoat marker to {channel_id}: {e}")
logger.error(f"Failed to send Stoat marker to {channel_id}: {e}")
return None
@ -606,39 +465,11 @@ class StoatWriter:
# Set permissions
if permissions != 0:
requested_perms = self._map_permissions(permissions)
# Fetch bot's own permissions to mask the requested set.
# Stoat/Revolt prevents granting permissions that the bot itself lacks.
try:
me = await server.fetch_member(self.my_id)
bot_perms = server.permissions_for(me)
# Manual masking of boolean attributes
final_perms = stoat.Permissions.none()
for attr in dir(requested_perms):
if not attr.startswith("_") and isinstance(getattr(requested_perms, attr), bool):
if getattr(requested_perms, attr) and getattr(bot_perms, attr, False):
try:
setattr(final_perms, attr, True)
except Exception:
pass
await server.set_role_permissions(role, allow=final_perms)
except Exception as perm_err:
logger.warning(f"Stoat: Could not mask/verify role permissions for {name} (falling back to minimal): {perm_err}")
# Attempt to set at least one basic permission if possible, or just skip
try:
minimal = stoat.Permissions.none()
minimal.view_channel = True
minimal.send_messages = True
await server.set_role_permissions(role, allow=minimal)
except Exception:
pass
s_perms = self._map_permissions(permissions)
await server.set_role_permissions(role, allow=s_perms)
return str(role.id)
except Exception as e:
print(f"Failed to create Stoat role {name}: {e}")
logger.error(f"Failed to create Stoat role {name}: {e}")
return ""
@ -695,7 +526,6 @@ class StoatWriter:
await server.set_default_permissions(s_perms)
return True
except Exception as e:
print(f"Failed to update Stoat default permissions: {e}")
logger.error(f"Failed to update Stoat default permissions: {e}")
return False
@ -706,7 +536,6 @@ class StoatWriter:
emoji = await server.create_server_emoji(name=name, image=image_bytes)
return str(emoji.id)
except Exception as e:
print(f"Failed to create Stoat emoji {name}: {e}")
logger.error(f"Failed to create Stoat emoji {name}: {e}")
return ""
@ -722,7 +551,6 @@ class StoatWriter:
banner=banner if banner is not None else stoat.UNDEFINED
)
except Exception as e:
print(f"Failed to update Stoat guild metadata: {e}")
logger.error(f"Failed to update Stoat guild metadata: {e}")
async def remove_community_logo_and_banner(self) -> dict:
@ -734,14 +562,12 @@ class StoatWriter:
try:
await server.edit(icon=None)
except Exception as e:
print(f"Failed to remove Stoat community icon: {e}")
logger.error(f"Failed to remove Stoat community icon: {e}")
if has_banner:
try:
await server.edit(banner=None)
except Exception as e:
print(f"Failed to remove Stoat community banner: {e}")
logger.error(f"Failed to remove Stoat community banner: {e}")
return {
@ -768,7 +594,6 @@ class StoatWriter:
if progress_callback:
await progress_callback(name, i, total)
except Exception as e:
print(f"Failed to delete Stoat channel {ch.id}: {e}")
logger.error(f"Failed to delete Stoat channel {ch.id}: {e}")
# To delete categories, we can wipe the categories array via server.edit to avoid 404 endpoint
@ -793,7 +618,6 @@ class StoatWriter:
await progress_callback(name, j, total)
j += 1
except Exception as e:
print(f"Failed to wipe Stoat categories via edit: {e}")
logger.error(f"Failed to wipe Stoat categories via edit: {e}")
return count
@ -810,23 +634,17 @@ class StoatWriter:
logger.info(f"Danger Zone: Skipping permission reset for audit channel {name}")
total -= 1
continue
# Fetch fresh channel to get current role_permissions
fresh_ch = await self.client.fetch_channel(ch.id)
# Clear default permissions
if hasattr(fresh_ch, "default_permissions") and fresh_ch.default_permissions is not None:
await fresh_ch.set_default_permissions(None)
# Clear all role overrides
if hasattr(fresh_ch, "role_permissions"):
for role_id in list(fresh_ch.role_permissions.keys()):
await fresh_ch.set_role_permissions(str(role_id), allow=stoat.Permissions.none(), deny=stoat.Permissions.none())
# In Stoat, clearing overrides might involve setting them to default or explicitly removing the role_permissions/default_permissions
# Since we don't know an explicit "clear_overrides" method, we'll wipe them by setting empty/none if possible.
# Actually Stoat allows overwriting. Setting allow=0 deny=0 for role overrides isn't explicitly clear.
# For safety, we will just pass. If the user expects it, we'd iterate over roles and set empty.
# A quick way is to edit the channel permissions to empty state if possible.
# Let's count them anyway.
# (Fluxer writer does a loop over existing overrides, we can just return 0 for now until we inspect Stoat `PermissionOverride` deletion)
count += 1
if progress_callback:
await progress_callback(name, i, total)
except Exception as e:
print(f"Failed to reset Stoat channel permissions for {ch.id}: {e}")
logger.error(f"Failed to reset Stoat channel permissions for {ch.id}: {e}")
return count
@ -853,7 +671,6 @@ class StoatWriter:
if "MissingPermission" in err_msg and "ViewChannel" in err_msg:
logger.error(f"Stoat LOCKOUT: Bot lacks 'ViewChannel' to edit {channel_id}. "
"Ensure the bot has 'Manage Server' or a role with 'Allow View Channel' rank higher than @everyone.")
print(f"Failed to set Stoat channel permission for {overwrite_id} on {channel_id}: {e}")
logger.error(f"Failed to set Stoat channel permission for {overwrite_id} on {channel_id}: {e}")
@ -895,31 +712,18 @@ class StoatWriter:
await emoji.delete()
count += 1
except Exception as e:
print(f"Failed to delete Stoat emoji {emoji.name}: {e}")
logger.error(f"Failed to delete Stoat emoji {emoji.name}: {e}")
return {"emojis": count, "stickers": 0}
async def close(self):
client = self.client
task = getattr(self, "_start_task", None)
self.client = None # Atomic clear to prevent new usage
self._start_task = None
self._me = None
self._server = None
if client:
try:
await client.close()
except Exception as e:
logger.debug(f"Error closing Stoat client session: {e}")
if task and not task.done():
task.cancel()
try:
await asyncio.wait_for(task, timeout=2.0)
except (asyncio.CancelledError, asyncio.TimeoutError):
pass
logger.debug(f"Error closing Stoat client: {e}")
self._validation_cache = None

View file

@ -163,9 +163,6 @@ class OperationPane(Container):
yield Button("Waterfall Migration", id="op_waterfall", disabled=True, variant="primary", tooltip="Migrate all messages globally in chronological order to prevent broken links.\n(Available for Local Backups)")
yield Rule(id="footer_rule")
yield Button("Danger Zone ⚠", id="op_danger", variant="error", disabled=True, flat=True, tooltip="Dangerous operations:\ndelete channels, roles, emojis on target\n(use with caution)")
if self.cfg_name == "AutoTest":
yield Button("RUN AUTO TEST", id="op_autotest", variant="warning", flat="false", tooltip="Execute automated test sequence for the AutoTest profile")
def on_mount(self) -> None:
self._rebuild_engine()
@ -329,13 +326,8 @@ class OperationPane(Container):
for pne in self.query("#op_target_pane"): pne.display = False
enabled = (v.get("discord_token") and v.get("discord_server") and not d_missing)
for btn in self.query("#op_backup_msgs"):
btn.disabled = not enabled
for btn in self.query("#op_backup_sync"):
btn.display = self.has_backup
btn.disabled = not (enabled and self.has_backup)
for btn in self.query("#op_autotest"):
btn.disabled = not enabled
for bid in ("#op_backup_msgs", "#op_backup_sync"):
for btn in self.query(bid): btn.disabled = not enabled
for btn in self.query("#op_backup_stats"):
btn.display = self.has_backup
@ -403,7 +395,7 @@ class OperationPane(Container):
lbl.update(f"{t_status}")
# Buttons
for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger", "#op_autotest"):
for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger"):
for btn in self.query(bid): btn.disabled = not self.tokens_valid
# ── validation ────────────────────────────────────────────────────────
@ -424,10 +416,10 @@ class OperationPane(Container):
# Disable all operation buttons while validation is in progress
if self.view_mode == "shuttle":
for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger", "#op_autotest"):
for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger"):
for btn in self.query(bid): btn.disabled = True
elif self.view_mode == "backup":
for bid in ("#op_backup_msgs", "#op_backup_sync", "#op_autotest"):
for bid in ("#op_backup_msgs", "#op_backup_sync"):
for btn in self.query(bid): btn.disabled = True
except Exception as e:
logger.error(f"Error in run_validate setup: {e}")
@ -598,124 +590,7 @@ class OperationPane(Container):
from src.ui.backup_stats import BackupStatsScreen
target_dir = Path(self._base_dir()) / f"DISCORD_BACKUP-{self.config.discord_server_id}"
self.app.push_screen(BackupStatsScreen(self.cfg_name, target_dir))
elif bid == "op_autotest" or bid == "btn_autotest":
self.run_autotest_sequence()
@work(exclusive=True)
async def run_autotest_sequence(self) -> None:
"""Entry point for the AUTO TEST sequence."""
if not self.tokens_valid:
return
modal = ProgressScreen(log_level=self.config.log_level)
self.app.push_screen(modal)
await asyncio.sleep(0.1)
try:
if self.view_mode == "shuttle":
await self._run_migration_autotest_logic(modal)
elif self.view_mode == "backup":
await self._run_backup_autotest_logic(modal)
modal.phase_report("AUTO TEST Complete", show_back=False)
modal.write("[bold green]Full automated test sequence finished successfully![/bold green]")
except Exception as e:
logger.error(f"Auto-Test Error: {e}\n{traceback.format_exc()}")
modal.write(f"[bold red]Error: {e}[/bold red]")
modal.phase_report("Auto-Test", "error", show_back=False)
finally:
self.engine.is_running = False
await self.engine.close_connections()
async def _run_migration_autotest_logic(self, modal: ProgressScreen) -> None:
"""Executes the full migration test sequence."""
modal.set_status("AUTO TEST: Launching Migration Sequence...")
modal.write("[bold yellow]Starting Migration Auto-Test...[/bold yellow]")
migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate
# 1. Connect and initialize state
modal.set_status("Connecting and Initializing State...")
await self.engine.start_connections()
self.engine.is_running = True
# Initialize state database early to avoid Warnings
tid = self.config.fluxer_server_id if self.target_platform == "fluxer" else self.config.stoat_server_id
tgt_info = await self.engine.writer.validate()
self.engine.ensure_state_initialized(str(tid or ""), tgt_info.get("community_name", "Target"))
# 2. Danger Zone Clean
modal.write("\n[bold red]Phase 1: Danger Zone Cleanup[/bold red]")
await self._logic_dz_delete_channels(modal)
await self._logic_dz_reset_perms(modal)
await self._logic_dz_delete_roles(modal)
await self._logic_dz_delete_assets(modal)
# 3. Clone Roles & Permissions
modal.write("\n[bold cyan]Phase 2: Cloning Roles & Permissions[/bold cyan]")
await self._logic_clone_roles(modal, force=True)
# 4. Clone Assets (Logo, Banner, Emojis, Stickers)
modal.write("\n[bold cyan]Phase 3: Syncing Metadata & Assets[/bold cyan]")
await self._logic_copy_assets(modal, ["Emoji", "Sticker"], force=True)
await self._logic_sync_metadata(modal, ["name", "icon", "banner"])
# 5. Clone Template (Structure)
modal.write("\n[bold cyan]Phase 4: Cloning Server Structure (1/2)[/bold cyan]")
await self._logic_clone_channels(modal, force=True)
await self._logic_sync_permissions(modal)
# 6. Waterfall Migration (Only for backups)
if self.engine.source_mode == "backup" :
modal.write("\n[bold cyan]Phase 5: Waterfall Message Migration[/bold cyan]")
await self._logic_waterfall_migration(modal=modal, is_autotest=True)
else:
modal.write("\n[bold yellow]Phase 5: Skipping Waterfall (Live mode selected)[/bold yellow]")
# 7. Individual Channel Migration (Automated)
modal.write("\n[bold cyan]Phase 7: Individual Channel Migration[/bold cyan]")
await self._logic_autotest_migrate_all_channels(modal=modal)
async def _run_backup_autotest_logic(self, modal: ProgressScreen) -> None:
"""Executes the full backup test sequence."""
import shutil
modal.set_status("AUTO TEST: Launching Backup Sequence...")
modal.write("[bold yellow]Starting Backup Auto-Test...[/bold yellow]")
# 1. Clear old backup directory completely
server_dir = Path(self._base_dir()) / f"DISCORD_BACKUP-{self.config.discord_server_id}"
if server_dir.exists():
modal.write(f"[yellow]Deleting existing backup directory: {server_dir}[/yellow]")
try:
shutil.rmtree(server_dir)
except Exception as e:
modal.write(f"[red]Warning: Could not delete directory: {e}[/red]")
# 2. Setup & Metadata
modal.set_status("Initializing Discord connection...")
await self.engine.discord_reader.start()
await self.exporter.setup()
self.exporter.is_running = True
modal.write("\n[bold cyan]Phase 1: Full Server Backup[/bold cyan]")
# 3. Use unified backup logic in autotest mode
all_channels = await self.engine.discord_reader.get_channels()
eligible_channels = [
c for c in all_channels
if c.type in [
self.engine.discord_reader.CHANNEL_TYPE_TEXT,
self.engine.discord_reader.CHANNEL_TYPE_NEWS,
self.engine.discord_reader.CHANNEL_TYPE_FORUM
]
]
await self._logic_full_backup(
modal=modal,
selected_channels=eligible_channels,
force_overwrite=True,
is_autotest=True
)
# ── (1) clone server template (combined) ─────────────────────────────
def _open_clone_menu(self):
@ -1137,67 +1012,16 @@ class OperationPane(Container):
# ── (5) message migration ─────────────────────────────────────────────
@work(exclusive=True)
async def run_migrate_messages(self, modal: ProgressScreen | None = None) -> None:
await self._logic_migrate_messages(modal)
async def _logic_autotest_migrate_all_channels(self, modal: ProgressScreen) -> None:
"""Automated name-based channel migration for Auto-Test."""
migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate
# 1. Matching
modal.set_status("Auto-matching channels by name...")
await self._perform_auto_matching()
# 2. Get channels
d_channels = await self.engine.discord_reader.get_channels()
text_channels = [c for c in d_channels if c.type in [
self.engine.discord_reader.CHANNEL_TYPE_TEXT,
self.engine.discord_reader.CHANNEL_TYPE_NEWS
]]
modal.write(f"[bold cyan]Auto-Test: Found {len(text_channels)} channels to migrate.[/bold cyan]")
# 3. Migrate loop
for i, ch in enumerate(text_channels):
if not self.engine.is_running: break
tgt_id = self.engine.state.get_target_channel_id(str(ch.id))
if not tgt_id:
modal.write(f"[yellow]Skipping #{ch.name} (no target mapping found)[/yellow]")
continue
modal.write(f"\n[bold]Migrating #{ch.name} ({i+1}/{len(text_channels)}) -> Target ID {tgt_id}[/bold]")
# Analyze (silent)
stats = await migrate_mod.analyze_migration(self.engine, source_channel_id=ch.id, after_message_id=None)
ch_total = stats["messages"]
async def update_indiv(curr):
c = curr["messages"]
modal.set_progress(c, ch_total or 100)
modal.set_item_status(f"#{ch.name}: {c}/{ch_total} messages")
await migrate_mod.migrate_messages(
self.engine,
source_channel_id=ch.id,
target_channel_id=tgt_id,
after_message_id=None,
progress_callback=update_indiv
)
modal.write("\n[bold green]Automated channel migration complete.[/bold green]")
async def _logic_migrate_messages(self, modal: ProgressScreen | None = None, is_autotest: bool = False) -> None:
async def run_migrate_messages(self) -> None:
if not self.tokens_valid:
return
migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate
platform_name = self.target_platform.capitalize()
if not modal:
modal = ProgressScreen(log_level=self.config.log_level)
self.app.push_screen(modal)
await asyncio.sleep(0.1)
modal = ProgressScreen(log_level=self.config.log_level)
self.app.push_screen(modal)
await asyncio.sleep(0.1)
try:
# Show info container
@ -1594,29 +1418,22 @@ class OperationPane(Container):
await self.engine.close_connections()
@work(exclusive=True)
async def run_waterfall_migration(self, modal: ProgressScreen | None = None) -> None:
await self._logic_waterfall_migration(modal)
async def _logic_waterfall_migration(self, modal: ProgressScreen | None = None, is_autotest: bool = False) -> None:
async def run_waterfall_migration(self) -> None:
if not self.tokens_valid:
return
migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate
platform_name = self.target_platform.capitalize()
if not modal:
modal = ProgressScreen(log_level=self.config.log_level)
self.app.push_screen(modal)
await asyncio.sleep(0.1)
modal = ProgressScreen(log_level=self.config.log_level)
self.app.push_screen(modal)
await asyncio.sleep(0.1)
try:
modal.show_info("[bold cyan]Waterfall Migration Ready[/bold cyan]", "Checking mapping and missing channels...")
modal.set_status("Connecting to Servers...")
await self.engine.start_connections()
# Ensure writer is validated before auto-matching (fixes NoneType role fetch error)
await self.engine.writer.validate()
modal.set_status("Synchronizing entity mappings...")
await self._perform_auto_matching()
@ -1643,19 +1460,15 @@ class OperationPane(Container):
prefix = "[bold cyan]📁[/bold cyan] " if mc.type == 4 else "[bold white]#[/bold white] "
modal.write(f" {prefix}{mc.name}")
if is_autotest:
choice = "btn_start_first"
modal.write("[bold cyan]Auto-Test: Automatically cloning missing channels.[/bold cyan]")
else:
choice = await modal.phase_wait_confirm(
show_continue=False,
show_id=True,
btn_start_label="Clone missing channels",
btn_id_label="Skip missing channels",
btn_start_variant="primary",
btn_start_tooltip=f"Automatically create {len(missing_channels)} entities on target",
btn_id_tooltip="Start migration without these channels"
)
choice = await modal.phase_wait_confirm(
show_continue=False,
show_id=True,
btn_start_label="Clone missing channels",
btn_id_label="Skip missing channels",
btn_start_variant="primary",
btn_start_tooltip=f"Automatically create {len(missing_channels)} entities on target",
btn_id_tooltip="Start migration without these channels"
)
if choice == "btn_back":
modal.dismiss()
@ -1737,9 +1550,11 @@ class OperationPane(Container):
if filtered_tgt_ids:
all_mapped_tgt_ids = filtered_tgt_ids
# 2.6 Resume Point: Calculate from global channel minimums
min_last_id = self.engine.state.get_global_min_last_message_id(all_mapped_tgt_ids)
# 2.6 Resume Point: Prioritize Global waterfall tracker, fallback to channel minimums
min_last_id = self.engine.state.get_waterfall_last_id()
if min_last_id is None:
min_last_id = self.engine.state.get_global_min_last_message_id(all_mapped_tgt_ids)
modal.write(f"\n[bold cyan]Waterfall Migration Resume Point:[/bold cyan]")
if min_last_id is not None:
@ -1747,19 +1562,15 @@ class OperationPane(Container):
else:
modal.write("No previous migration state found. Starting from the beginning.")
if is_autotest:
choice = "btn_continue" if min_last_id is not None else "btn_start_first"
modal.write(f"[bold cyan]Auto-Test: Automatically choosing {choice.replace('btn_', '').replace('_', ' ')}.[/bold cyan]")
else:
choice = await modal.phase_wait_confirm(
show_continue=min_last_id is not None,
show_id=False,
btn_start_label="Start From Beginning",
btn_start_tooltip="Wipes migration progress and restarts from the beginning; may create duplicates",
btn_start_variant="default" if min_last_id is not None else "primary",
btn_continue_label=f"Continue from ID {min_last_id if min_last_id is not None else 0}" if min_last_id is not None else "Continue Migration",
btn_continue_tooltip="Fastest"
)
choice = await modal.phase_wait_confirm(
show_continue=min_last_id is not None,
show_id=False,
btn_start_label="Start From Beginning",
btn_start_tooltip="Safe, skips duplicates automatically",
btn_start_variant="default" if min_last_id is not None else "primary",
btn_continue_label=f"Continue from ID {min_last_id if min_last_id is not None else 0}" if min_last_id is not None else "Continue Migration",
btn_continue_tooltip="Fastest"
)
if choice == "btn_back":
modal.dismiss()
@ -1771,11 +1582,7 @@ class OperationPane(Container):
return
after_id = None
if choice == "btn_start_first":
logger.info("Proceeding with 'Start from Beginning' (global clean sink).")
self.engine.state.clear_all_migration_data()
after_id = None
elif choice == "btn_continue" and min_last_id is not None:
if choice == "btn_continue" and min_last_id is not None:
after_id = int(min_last_id)
# Phase 3: Progress
@ -1792,7 +1599,6 @@ class OperationPane(Container):
tid = self.config.fluxer_server_id
self.engine.ensure_state_initialized(str(tid or ""), platform_name)
modal.show_stats()
modal.write("Scanning global footprint for totals ...")
stats_analysis = await migrate_mod.analyze_global_migration(self.engine, after_message_id=after_id)
total_messages = stats_analysis["messages"]
@ -2281,7 +2087,6 @@ class OperationPane(Container):
@work(exclusive=True)
async def run_backup_messages(self) -> None:
"""UI entry point for full backup."""
modal_prog = ProgressScreen(log_level=self.config.log_level)
self.app.push_screen(modal_prog)
await asyncio.sleep(0.1)
@ -2291,6 +2096,38 @@ class OperationPane(Container):
await self.engine.discord_reader.start()
await self.exporter.setup()
# Check if profile is empty
profile_exists = False
if self.exporter.db:
try:
profile_exists = self.exporter.db.get_guild_profile() is not None
except Exception:
profile_exists = False
if not profile_exists:
modal_prog.set_status("First-Time Setup: Exporting Server Profile...")
modal_prog.write("[yellow]No existing profile found. Performing primary profile backup...[/yellow]")
modal_prog.write("[yellow]Exporting server metadata...[/yellow]")
await self.exporter.export_metadata()
modal_prog.write("[yellow]Syncing server assets (icon/banner)...[/yellow]")
await self.exporter.download_server_assets()
modal_prog.write("[yellow]Exporting server structure...[/yellow]")
await self.exporter.export_channels_structure()
modal_prog.write("[yellow]Exporting roles & permissions...[/yellow]")
await self.exporter.export_roles()
modal_prog.write("[yellow]Exporting custom emojis & stickers...[/yellow]")
await self.exporter.export_assets()
modal_prog.write("[bold green]Primary profile setup complete![/bold green]")
modal_prog.write("")
else:
modal_prog.write("[dim]Existing profile detected. Scanning structure...[/dim]")
await self.exporter.export_channels_structure()
all_channels = await self.engine.discord_reader.get_channels()
all_categories = await self.engine.discord_reader.get_categories()
cat_map = {c.id: c.name for c in all_categories}
@ -2309,9 +2146,8 @@ class OperationPane(Container):
modal_prog.allow_close()
return
# Analyze which are already backed up
backed_up_ids = set()
any_found = False
backed_up_ids = set()
if self.exporter.db:
channel_stats = self.exporter.db.get_stats_by_channel()
for chan in eligible_channels:
@ -2319,105 +2155,87 @@ class OperationPane(Container):
any_found = True
backed_up_ids.add(chan.id)
# Manual selection
loop = asyncio.get_running_loop()
future = loop.create_future()
def check_channels(reply: dict | None) -> None:
if not future.done(): future.set_result(reply)
self.app.pop_screen()
self.app.push_screen(
ChannelSelectScreen(eligible_channels, cat_map, backed_up_ids, any_found),
check_channels,
)
while True:
loop = asyncio.get_running_loop()
future = loop.create_future()
reply = await future
if not reply: return
def check_channels(reply: dict | None) -> None:
if not future.done():
future.set_result(reply)
selected_ids = reply["channels"]
force_overwrite = reply["force"]
selected_channels = [c for c in eligible_channels if c.id in selected_ids]
# Confirmation phase
new_channels = [c for c in selected_channels if c.id not in backed_up_ids]
existing_channels = [c for c in selected_channels if c.id in backed_up_ids]
modal_confirm = ProgressScreen(log_level=self.config.log_level)
self.app.push_screen(modal_confirm)
await asyncio.sleep(0.1)
modal_confirm.set_status(f"Confirm to proceed with Backup of [bold]{len(selected_channels)}[/bold] channels")
modal_confirm.show_info(f"[cyan]Backup Channels[/cyan]", f"{len(new_channels)} new, {len(existing_channels)} existing")
choice = await modal_confirm.phase_wait_confirm(btn_start_label="Start Channel Backup", show_id=False)
if choice != "btn_start_first":
modal_confirm.dismiss()
return
modal_confirm.cancel_callback = lambda: setattr(self.exporter, "is_running", False)
modal_confirm.phase_progress()
await self._logic_full_backup(
modal=modal_confirm,
selected_channels=selected_channels,
force_overwrite=force_overwrite,
is_autotest=False
)
except Exception as e:
logger.error(f"Backup Error: {traceback.format_exc()}")
modal_prog.write(f"[bold red]Backup failed: {e}[/bold red]")
modal_prog.phase_report("Backup", "error", show_back=False)
finally:
await self.engine.close_connections()
async def _logic_full_backup(self, modal: ProgressScreen, selected_channels: list, force_overwrite: bool, is_autotest: bool = False) -> None:
"""Non-interactive core backup logic."""
if not self.exporter.is_running:
self.exporter.is_running = True
try:
# 1. Metadata and Assets
modal.set_status("Exporting Server Profile...")
await self.exporter.export_metadata()
await self.exporter.download_server_assets()
await self.exporter.export_channels_structure()
await self.exporter.export_roles()
await self.exporter.export_assets()
# Pre-fetch all members once for role resolution during message export
modal.set_status("Pre-fetching server members...")
await self.exporter.prefetch_members()
# 2. Channel Messages
total_chans = len(selected_channels)
modal.write(f"\n[bold cyan]Backing up {total_chans} channels...[/bold cyan]")
modal.show_stats()
for i, chan in enumerate(selected_channels):
if not self.exporter.is_running: break
modal.set_item_status(f"[cyan]Processing ({i+1}/{total_chans}): #{chan.name}[/cyan]")
modal.set_progress(i, total_chans)
modal.write(f"[cyan]Backing up: #{chan.name}[/cyan]")
async def update_backup(name, count, author_name=None, message_preview=None, thread_count=0, file_count=0):
modal.update_stats(messages=str(count), threads=str(thread_count), files=str(file_count))
if author_name and message_preview and count % 20 == 0:
modal.write(f"[dim]{author_name}:[/dim] {message_preview}")
await self.exporter.export_channel_messages(
chan.id, progress_callback=update_backup, force=force_overwrite
self.app.push_screen(
ChannelSelectScreen(eligible_channels, cat_map, backed_up_ids, any_found),
check_channels,
)
modal.write(f"[green]Completed: #{chan.name}[/green]")
modal.set_progress(total_chans, total_chans)
modal.write("[bold green]Backup complete![/bold green]")
modal.phase_report("Full Backup", show_back=False)
reply = await future
if not reply:
return
except Exception as e:
modal.write(f"[bold red]Core backup failed: {e}[/bold red]")
logger.error(f"Core Backup Error: {traceback.format_exc()}")
raise e
selected_ids = reply["channels"]
force_overwrite = reply["force"]
selected_channels = [c for c in eligible_channels if c.id in selected_ids]
# Phase 2: Confirmation
modal_prog = ProgressScreen(log_level=self.config.log_level) # Re-instantiate to avoid Textual re-push UI freeze
self.app.push_screen(modal_prog)
await asyncio.sleep(0.1)
new_channels = [c for c in selected_channels if c.id not in backed_up_ids]
existing_channels = [c for c in selected_channels if c.id in backed_up_ids]
server = getattr(self.engine.discord_reader, 'guild', None)
if server:
modal_prog.write(f"[bold cyan]Server Profile:[/bold cyan]")
modal_prog.write(f" Name: [green]{server.name}[/green]")
modal_prog.write(f" Icon: [green]{'Present' if server.icon else 'None'}[/green]")
modal_prog.write("")
modal_prog.set_status(f"Confirm to proceed with Backup of [bold]{len(selected_channels)}[/bold] channels")
modal_prog.show_info(f"[cyan]Backup Channels[/cyan]", f"{len(new_channels)} new, {len(existing_channels)} existing")
# Show categorized channel lists in the bottom log
if new_channels:
modal_prog.write("[bold green]New Backups to be created:[/bold green]")
for idx, c in enumerate(new_channels):
modal_prog.write(f" {idx+1}. #{c.name}")
if existing_channels:
action = "Overwritten" if force_overwrite else "Updated"
modal_prog.write(f"[bold yellow]\nExisting backups to be {action}:[/bold yellow]")
for idx, c in enumerate(existing_channels):
modal_prog.write(f" {idx+1}. #{c.name}")
choice = await modal_prog.phase_wait_confirm(btn_start_label="Start Channel Backup", show_id=False)
if choice == "btn_back":
modal_prog.dismiss()
continue
elif choice == "btn_start_id":
loop = asyncio.get_running_loop()
future = loop.create_future()
def id_callback(res: int | None) -> None:
if not future.done():
future.set_result(res)
id_modal = MessageIDInputModal(self.engine.discord_reader, selected_channels[0].id)
self.app.push_screen(id_modal, id_callback)
verified_id = await future
if verified_id is None:
# User cancelled the ID input
continue
after_id = verified_id
elif choice == "btn_main_menu":
modal_prog.dismiss()
return
# If we are here, proceeding either via Start First or Start from ID (after_id)
if choice == "btn_start_first":
after_id = None
break
modal_prog.phase_progress()
modal_prog.show_stats()

View file

@ -1,170 +0,0 @@
import pytest
import shutil
import yaml
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
from src.core.backup_database import BackupDatabase
from src.core.configuration import AppConfig
LOG_FILE = (Path(__file__).parent.parent / ".reaper_tests.log").resolve()
def _log(msg: str):
"""Write a message to the shared test log file."""
with open(LOG_FILE, "a") as f:
f.write(f"{msg}\n")
@pytest.fixture
def log():
"""Fixture to provide the logging function to tests."""
return _log
def pytest_configure(config):
"""Register global warning filters and marks."""
# Register marks if needed
config.addinivalue_line("markers", "asyncio: mark test as asyncio")
# Silence benign async mock and ResourceWarnings globally
config.addinivalue_line("filterwarnings", "ignore::RuntimeWarning")
config.addinivalue_line("filterwarnings", "ignore::ResourceWarning")
def pytest_sessionstart(session):
"""Clear the log file at the beginning of the test session."""
with open(LOG_FILE, "w") as f:
f.write("--- Reaper Test Session Started ---\n")
def pytest_report_header(config):
"""Print data source status to the console header."""
test_data_dir = (Path(__file__).parent.parent / "ReaperFiles-AutoTest").resolve()
if test_data_dir.exists():
return f"[DATA_SOURCE] Automated Test Directory: FOUND ({test_data_dir.name})"
else:
return "[DATA_SOURCE] Automated Test Directory: NOT FOUND (Falling back to MOCKS)"
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item, call):
"""Hook to catch test results and log them to .reaper_tests.log."""
outcome = yield
rep = outcome.get_result()
# We only log on the 'call' phase (the actual test run)
if rep.when == "call":
status = rep.outcome.upper()
_log(f"RESULT: {item.nodeid} -> {status}")
if rep.failed:
_log(f"--- FAILURE DETAILS for {item.name} ---\n{rep.longreprtext}\n---------------------")
elif not rep.passed:
# Catch setup/teardown failures
_log(f"ERROR in {item.nodeid} [{rep.when}]: {rep.outcome.upper()}")
if rep.failed:
_log(f"--- ERROR DETAILS ---\n{rep.longreprtext}\n---------------------")
def pytest_warning_recorded(warning_message, when, nodeid, location):
"""Hook to catch and log warnings to .reaper_tests.log."""
msg = f"WARNING: {warning_message.message}"
if nodeid:
msg = f"WARNING in {nodeid} [{when}]: {warning_message.message}"
_log(msg)
@pytest.fixture
def test_data_dir():
return (Path(__file__).parent.parent / "ReaperFiles-AutoTest").resolve()
@pytest.fixture
def reaper_config(test_data_dir):
config_path = test_data_dir / "reaper_config.yaml"
if config_path.exists():
with open(config_path, "r") as f:
data = yaml.safe_load(f)
return data
# Fallback mock config
return {
"discord_bot_token": "mock_discord_token",
"discord_server_id": "123456789012345678",
"tool_mode": "backup_transfer",
"target_platform": "fluxer",
"fluxer_bot_token": "mock_fluxer_token",
"fluxer_server_id": "987654321098765432",
"stoat_bot_token": "mock_stoat_token",
"stoat_server_id": "MOCK_STOAT_COMMUNITY",
"anonymize_users": False,
"log_level": "DEBUG"
}
@pytest.fixture
def temp_db(test_data_dir, tmp_path, reaper_config):
# If the real test data exists, use it
if test_data_dir.exists():
target_id = reaper_config.get("fluxer_server_id") if reaper_config.get("target_platform") == "fluxer" else reaper_config.get("stoat_server_id")
db_files = list(test_data_dir.glob(f"*-{target_id}.db"))
if not db_files:
db_files = list(test_data_dir.glob("*.db"))
if db_files:
original_db = db_files[0]
temp_db_path = tmp_path / "test_migration.db"
print(f"[DATA_SOURCE] USE_SAMPLE_DB: {original_db.name}")
_log(f"[DATA_SOURCE] USE_SAMPLE_DB: {original_db.name}")
shutil.copy(original_db, temp_db_path)
return temp_db_path
# Fallback: Create an empty mock database with the required schema
temp_db_path = tmp_path / "mock_migration.db"
print("[DATA_SOURCE] USE_MOCK_DB: Fallback to empty schema")
_log("[DATA_SOURCE] USE_MOCK_DB: Fallback to empty schema")
db = BackupDatabase(temp_db_path)
# The BackupDatabase constructor already initializes the schema
return temp_db_path
@pytest.fixture
def backup_db(temp_db):
db = BackupDatabase(temp_db)
yield db
# No explicit close needed for now as it's a persistent connection
@pytest.fixture
def backup_reader(test_data_dir, reaper_config, tmp_path):
sid = reaper_config.get("discord_server_id")
backup_path = test_data_dir / f"DISCORD_BACKUP-{sid}"
if not test_data_dir.exists() or not backup_path.exists():
# Fallback: create mock backup structure
mock_path = tmp_path / f"DISCORD_BACKUP-{sid}"
print(f"[DATA_SOURCE] USE_MOCK_BACKUP: {mock_path.name}")
_log(f"[DATA_SOURCE] USE_MOCK_BACKUP: {mock_path.name}")
mock_path.mkdir(parents=True, exist_ok=True)
db_path = mock_path / "backup.db"
db = BackupDatabase(db_path)
# Populate with minimal mock data for BackupReader to work
db._conn.execute("INSERT OR IGNORE INTO guild_profile (id, name) VALUES (?, ?)", (int(sid), "Mock Guild"))
db._conn.execute("INSERT OR IGNORE INTO channels (id, name, type) VALUES (?, ?, ?)", (123, "mock-channel", 0))
db._conn.commit()
from src.core.backup_reader import BackupReader
return BackupReader(mock_path)
print(f"[DATA_SOURCE] USE_SAMPLE_BACKUP: {backup_path.name}")
_log(f"[DATA_SOURCE] USE_SAMPLE_BACKUP: {backup_path.name}")
from src.core.backup_reader import BackupReader
return BackupReader(backup_path)
@pytest.fixture
def mock_discord_reader():
reader = MagicMock()
reader.guild = MagicMock()
reader.fetch_message_history = AsyncMock()
reader.download_attachment = AsyncMock(return_value=b"fake_data")
reader.download_sticker = AsyncMock(return_value=b"fake_sticker_data")
return reader
@pytest.fixture
def mock_fluxer_writer():
writer = MagicMock()
writer.send_message = AsyncMock(return_value="fluxer_msg_123")
writer.send_marker = AsyncMock()
return writer
@pytest.fixture
def mock_stoat_writer():
writer = MagicMock()
writer.send_message = AsyncMock(return_value="stoat_msg_123")
writer.send_marker = AsyncMock()
return writer

View file

@ -1,139 +0,0 @@
import pytest
from src.core.backup_database import BackupDatabase
@pytest.fixture
def db():
from tests.conftest import _log
msg = "[DATA_SOURCE] UNIT_TEST: Using in-memory database"
print(msg)
_log(msg)
return BackupDatabase(":memory:")
def test_db_initialization(db):
res = db._conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='guild_profile'"
).fetchone()
assert res is not None
assert res["name"] == "guild_profile"
def test_save_get_guild_profile(db):
profile = {
"id": "123456789",
"name": "Test Server",
"description": "Testing",
"owner_id": "987654321",
"ignore_channels": ["111", "222"],
}
db.set_guild_profile(profile)
got = db.get_guild_profile()
assert got["name"] == "Test Server"
assert got["id"] == 123456789
assert got["ignore_channels"] == ["111", "222"]
def test_save_get_roles(db):
roles = [
{"id": "1", "name": "Admin", "color": 0xFF0000, "position": 1, "permissions": 8, "hoist": True, "mentionable": True},
{"id": "2", "name": "User", "color": 0x0000FF, "position": 2, "permissions": 0, "hoist": False, "mentionable": False},
]
db.save_roles(roles)
got = db.get_all_roles()
assert len(got) == 2
assert {r["name"] for r in got} == {"Admin", "User"}
def test_save_get_channels(db):
channels = [
{"id": 10, "name": "general", "type": 0, "position": 0, "category_id": None,
"topic": "Talk", "nsfw": 0, "bitrate": None, "slowmode_delay": None},
{"id": 20, "name": "voice", "type": 2, "position": 1, "category_id": None,
"topic": None, "nsfw": 0, "bitrate": 64000, "slowmode_delay": None},
]
db.save_channels(channels)
got = db.get_all_channels()
assert len(got) == 2
assert {c["name"] for c in got} == {"general", "voice"}
def test_save_messages_and_attachments(db):
messages = [
{
"id": 101, "channel_id": 10, "author_id": 999, "content": "Hello",
"timestamp": "2023-01-01T00:00:00Z", "type": 0,
"message_reference": None, "is_pinned": 0, "extra_data": None,
"attachments": [
{"id": 1, "filename": "file.png", "size": 100,
"url": "http://cdn.test/file.png", "content_type": "image/png", "local_hash": "abc"}
],
}
]
db.save_messages_batch(messages)
msg = db._conn.execute("SELECT content FROM messages WHERE id=101").fetchone()
assert msg["content"] == "Hello"
att = db._conn.execute("SELECT filename FROM attachments WHERE message_id=101").fetchone()
assert att["filename"] == "file.png"
def test_get_last_message_id(db):
msgs = [
{"id": 200, "channel_id": 10, "author_id": 1, "content": "a",
"timestamp": "2023-01-01T00:00:00Z", "type": 0,
"message_reference": None, "is_pinned": 0, "extra_data": None},
{"id": 201, "channel_id": 10, "author_id": 1, "content": "b",
"timestamp": "2023-01-01T00:01:00Z", "type": 0,
"message_reference": None, "is_pinned": 0, "extra_data": None},
]
db.save_messages_batch(msgs)
assert db.get_last_message_id("10") == 201
def test_stats_by_channel(db):
msgs = [
{"id": 101, "channel_id": 10, "author_id": 1, "content": "Hi",
"timestamp": "2023-01-01T00:00:00Z", "type": 0,
"message_reference": None, "is_pinned": 0, "extra_data": None},
{"id": 102, "channel_id": 10, "author_id": 1, "content": "Bye",
"timestamp": "2023-01-01T00:01:00Z", "type": 0,
"message_reference": None, "is_pinned": 0, "extra_data": None},
]
db.save_messages_batch(msgs)
stats = db.get_stats_by_channel()
assert stats[10]["message_count"] == 2
def test_save_threads(db):
threads = [
{"id": 300, "name": "thread-1", "type": 11, "parent_id": 10,
"message_count": 5, "member_count": 2, "archived": 0,
"archive_timestamp": None, "auto_archive_duration": 1440,
"locked": 0, "applied_tags": None},
]
db.save_threads(threads)
got = db.get_threads_by_parent("10")
assert len(got) == 1
assert got[0]["name"] == "thread-1"
def test_media_pool(db):
db.add_media_to_pool("hash123", "/path/file.png", 512, "image/png", "http://cdn.test/file.png")
db._conn.commit()
entry = db.get_media_by_hash("hash123")
assert entry is not None
assert entry["local_path"] == "/path/file.png"
def test_get_messages_paged_after_id(db):
msgs = [
{"id": i, "channel_id": 10, "author_id": 1, "content": f"msg{i}",
"timestamp": f"2023-01-01T00:0{i}:00Z", "type": 0,
"message_reference": None, "is_pinned": 0, "extra_data": None}
for i in range(5)
]
db.save_messages_batch(msgs)
page = db.get_messages_paged("10", limit=3, offset=0)
assert len(page) == 3
page_after = db.get_messages_paged("10", limit=10, after_id="2")
assert all(m["id"] > 2 for m in page_after)

View file

@ -1,138 +0,0 @@
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from src.core.base import MigrationContext
from src.core.configuration import AppConfig
from src.core.backup_reader import ChannelType
from src.fluxer.migrate_message import migrate_messages as fluxer_migrate, _process_and_send_message as fluxer_send
from src.stoat.migrate_message import migrate_messages as stoat_migrate, _process_and_send_message as stoat_send
import yaml
from pathlib import Path
# --- Platform Detection (Same as e2e_simulation) ---
def get_platforms():
"""Determine which platforms to test based on config, or use defaults."""
config_path = (Path(__file__).parent.parent / "ReaperFiles-AutoTest/reaper_config.yaml").resolve()
if not config_path.exists():
return ["fluxer", "stoat"]
with open(config_path, "r") as f:
data = yaml.safe_load(f)
platforms = []
if data.get("fluxer_bot_token"): platforms.append("fluxer")
if data.get("stoat_bot_token"): platforms.append("stoat")
return platforms if platforms else ["fluxer", "stoat"]
# --- Unit Tests (Transformation Logic) ---
@pytest.fixture
def mock_context(mock_discord_reader, mock_fluxer_writer, mock_stoat_writer):
context = MagicMock(spec=MigrationContext)
context.discord_reader = mock_discord_reader
context.fluxer_writer = mock_fluxer_writer
context.stoat_writer = mock_stoat_writer
context.state = MagicMock()
context.state.get_user_alias.return_value = "TestAlias"
context.state.emoji_map = {}
context.state.channel_map = {}
context.is_running = True
return context
@pytest.fixture
def mock_message():
msg = MagicMock()
msg.id = 111
msg.author.id = 222
msg.author.display_name = "Author"
msg.content = "Test content"
msg.attachments = []
msg.embeds = []
msg.stickers = []
msg.created_at.timestamp.return_value = 1600000000.0
msg.flags.forwarded = False
msg.reference = None
return msg
@pytest.mark.asyncio
async def test_migration_transform_fluxer(mock_context, mock_message):
stats = {"messages": 0, "attachments": 0}
result = await fluxer_send(context=mock_context, msg=mock_message, target_channel_id="c1", stats=stats)
assert result == "fluxer_msg_123"
assert stats["messages"] == 1
assert mock_context.fluxer_writer.send_message.called
@pytest.mark.asyncio
async def test_migration_transform_stoat(mock_context, mock_message):
stats = {"messages": 0, "attachments": 0}
result = await stoat_send(context=mock_context, msg=mock_message, target_channel_id="c1", stats=stats)
assert result == "stoat_msg_123"
assert stats["messages"] == 1
assert mock_context.stoat_writer.send_message.called
# --- Integration Tests (Backup Reader) ---
@pytest.mark.asyncio
async def test_backup_reader_interaction(backup_reader, reaper_config):
await backup_reader.start()
assert backup_reader.guild is not None
# Verify ID from config (or mock default)
assert str(backup_reader.guild.id) == reaper_config.get("discord_server_id")
channels = await backup_reader.fetch_channels()
assert len(channels) > 0
# --- E2E Simulation ---
@pytest.mark.asyncio
@pytest.mark.parametrize("platform", get_platforms())
async def test_migration_e2e_loop(reaper_config, test_data_dir, tmp_path, platform, request):
config = AppConfig(
discord_bot_token=reaper_config["discord_bot_token"],
discord_server_id=reaper_config["discord_server_id"],
target_platform=platform,
fluxer_bot_token=reaper_config.get("fluxer_bot_token"),
fluxer_server_id=reaper_config.get("fluxer_server_id"),
stoat_bot_token=reaper_config.get("stoat_bot_token"),
stoat_server_id=reaper_config.get("stoat_server_id"),
anonymize_users=reaper_config["anonymize_users"]
)
if test_data_dir.exists():
base_dir = test_data_dir
msg = f"[DATA_SOURCE] E2E_SIMULATION: Using sample data from {base_dir.name}"
else:
base_dir = tmp_path
msg = f"[DATA_SOURCE] E2E_SIMULATION: Fallback to mock data in {base_dir}"
from tests.conftest import _log
print(msg)
_log(msg)
context = MigrationContext(config, source_mode="backup", base_dir=str(base_dir))
mock_writer = request.getfixturevalue(f"mock_{platform}_writer")
if platform == "fluxer":
context.fluxer_writer = mock_writer
migrate_func = fluxer_migrate
else:
context.stoat_writer = mock_writer
migrate_func = stoat_migrate
context.is_running = True
from src.core.database import MigrationDatabase
context.state.db = MigrationDatabase(tmp_path / f"e2e_{platform}.db", platform=platform)
await context.discord_reader.start()
channels = await context.discord_reader.fetch_channels()
text_channels = [c for c in channels if c.type == ChannelType.text]
if text_channels:
source_channel_id = text_channels[0].id
target_channel_id = "999" if platform == "fluxer" else "stoat123"
mock_writer.send_message.side_effect = lambda **kwargs: "ok"
# Test just the first available channel to keep it fast
stats = await migrate_func(context=context, source_channel_id=source_channel_id, target_channel_id=target_channel_id)
assert stats["messages"] >= 0
await context.discord_reader.close()

View file

@ -1,164 +0,0 @@
import pytest
import asyncio
from pathlib import Path
from unittest.mock import MagicMock, AsyncMock, patch
from textual.app import App
from src.ui.main_app import ReaperApp, ConfigSelectionScreen, ConfigScreen
from src.ui.mode_screen import ModeScreen
from src.core.configuration import AppConfig
import os
from textual.widgets import ListItem, ListView, Input, Button, Label
@pytest.fixture
def mock_configs(tmp_path, log):
reaper_dir = tmp_path / "ReaperFiles-TestConfig"
reaper_dir.mkdir()
(reaper_dir / "reaper_config.yaml").write_text("discord_bot_token: 'fake'\ndiscord_server_id: '123'\ntool_mode: 'backup_only'")
autotest_dir = tmp_path / "ReaperFiles-AutoTest"
autotest_dir.mkdir()
(autotest_dir / "reaper_config.yaml").write_text("discord_bot_token: 'fake'\ndiscord_server_id: '123'\ntool_mode: 'backup_transfer'")
old_cwd = os.getcwd()
os.chdir(tmp_path)
log(f"CWD changed to: {tmp_path}")
yield tmp_path
os.chdir(old_cwd)
async def wait_for_screen(app, screen_class, timeout=5.0):
import time
start = time.time()
while time.time() - start < timeout:
if isinstance(app.screen, screen_class):
return True
await asyncio.sleep(0.1)
return False
@pytest.mark.asyncio
async def test_ui_minimal_launch(mock_configs, log):
"""Verify app launch and screen transition to ModeScreen."""
log("Running test_ui_minimal_launch")
try:
# Reverting to AsyncMock to avoid WorkerError.
# RuntimeWarnings are now handled globally in conftest.py.
with patch("src.ui.main_app.ConfigSelectionScreen.check_updates", AsyncMock()):
app = ReaperApp()
async with app.run_test() as pilot:
await wait_for_screen(app, ConfigSelectionScreen)
await pilot.click(ListItem)
await wait_for_screen(app, ModeScreen)
assert isinstance(app.screen, ModeScreen)
log("test_ui_minimal_launch PASSED")
except Exception as e:
log(f"test_ui_minimal_launch FAILED: {e}")
raise
@pytest.mark.asyncio
async def test_ui_config_wizard_save(mock_configs, log):
"""Verify configuration editing and saving."""
log("Running test_ui_config_wizard_save")
try:
with patch("src.ui.main_app.ConfigSelectionScreen.check_updates", AsyncMock()):
app = ReaperApp()
async with app.run_test() as pilot:
await wait_for_screen(app, ConfigSelectionScreen)
await pilot.click(ListItem)
await wait_for_screen(app, ModeScreen)
await pilot.click("#btn_config")
await wait_for_screen(app, ConfigScreen)
screen = app.screen
inp = screen.query_one("#inp_discord_token", Input)
inp.value = "new_fake_token"
with patch("src.ui.main_app.save_config") as mock_save:
await pilot.click("#btn_save")
await pilot.pause(0.2)
assert mock_save.called
log("test_ui_config_wizard_save PASSED")
except Exception as e:
log(f"test_ui_config_wizard_save FAILED: {e}")
raise
@pytest.mark.asyncio
async def test_ui_operation_trigger(mock_configs, log):
"""Verify that an operation can be triggered."""
log("Running test_ui_operation_trigger")
from src.ui.shuttle_ops import OperationPane
from src.ui.modals import ChannelPickerScreen, ProgressScreen
try:
with patch("src.ui.main_app.ConfigSelectionScreen.check_updates", AsyncMock()):
# run_validate is decorated with @work in shuttle_ops.py, so it MUST be a coroutine (AsyncMock)
with patch.object(OperationPane, "run_validate", AsyncMock()):
app = ReaperApp()
async with app.run_test() as pilot:
await wait_for_screen(app, ConfigSelectionScreen)
await pilot.click(ListItem)
await wait_for_screen(app, ModeScreen)
pane = app.screen.query_one(OperationPane)
pane.tokens_valid = True
pane.src_channels = [{"id": 1, "name": "t"}]
pane.src_cat_map = {None: "D"}
pane.tgt_channels = [{"id": 2, "name": "t"}]
pane.tgt_cat_map = {None: "D"}
pane.all_tgt_channels = pane.tgt_channels
btn = pane.query_one("#op_backup_msgs", Button)
btn.disabled = False
await pilot.pause(0.2)
btn.focus()
await pilot.press("enter")
await wait_for_screen(app, (ChannelPickerScreen, ProgressScreen))
assert isinstance(app.screen, (ChannelPickerScreen, ProgressScreen))
log("test_ui_operation_trigger PASSED")
except Exception as e:
log(f"test_ui_operation_trigger FAILED: {e}")
raise
@pytest.mark.asyncio
async def test_ui_autotest_button(mock_configs, log):
"""Verify visibility and trigger of the AUTO TEST button."""
log("Running test_ui_autotest_button")
from src.ui.shuttle_ops import OperationPane
try:
with patch("src.ui.main_app.ConfigSelectionScreen.check_updates", AsyncMock()):
with patch.object(OperationPane, "run_validate", AsyncMock()):
app = ReaperApp()
async with app.run_test() as pilot:
await wait_for_screen(app, ConfigSelectionScreen)
# Find the AutoTest item in the list
lv = app.screen.query_one(ListView)
autotest_index = -1
for idx, item in enumerate(lv.children):
if item.name == "AutoTest":
autotest_index = idx
break
assert autotest_index != -1, "AutoTest profile not found in ListView"
target_item = lv.children[autotest_index]
await pilot.click(target_item)
assert await wait_for_screen(app, ModeScreen), "Timed out waiting for ModeScreen"
# Verify button is present
pane = app.screen.query_one(OperationPane)
btn = pane.query_one("#op_autotest", Button)
assert btn.display is True
assert "AUTO TEST" in str(btn.label)
# Mock the sequence and trigger it
with patch.object(OperationPane, "run_autotest_sequence", AsyncMock()) as mock_seq:
btn.disabled = False
await pilot.pause(0.1)
btn.focus()
await pilot.press("enter")
await pilot.pause(0.1)
assert mock_seq.called
log("test_ui_autotest_button PASSED")
except Exception as e:
log(f"test_ui_autotest_button FAILED: {e}")
raise

View file

@ -1,60 +0,0 @@
import pytest
from unittest.mock import MagicMock
from src.core.utils import parse_snowflake, resolve_discord_links
def test_parse_snowflake_valid():
assert parse_snowflake("12345") == 12345
assert parse_snowflake(12345) == 12345
assert parse_snowflake(" 67890 ") == 67890
def test_parse_snowflake_invalid():
assert parse_snowflake(None) is None
assert parse_snowflake("") is None
assert parse_snowflake("none") is None
assert parse_snowflake("NULL") is None
assert parse_snowflake("not_a_number") is None
def test_resolve_discord_links_no_content():
assert resolve_discord_links("", None, "fluxer", "target_id") == ""
assert resolve_discord_links(None, None, "fluxer", "target_id") is None
def test_resolve_discord_links_no_mapping():
mock_state = MagicMock()
mock_state.get_target_channel_id.return_value = None
mock_state.get_target_category_id.return_value = None
mock_state.find_message_mapping.return_value = (None, None)
content = "Check this: https://discord.com/channels/1/2/3"
resolved = resolve_discord_links(content, mock_state, "fluxer", "target_server")
assert "[`discord-message`](<https://discord.com/channels/1/2/3>)" in resolved
def test_resolve_discord_links_channel_mapping():
mock_state = MagicMock()
mock_state.get_target_channel_id.return_value = "target_chan_456"
mock_state.find_message_mapping.return_value = (None, None)
content = "Go to https://discord.com/channels/123/456"
# Test Fluxer
resolved_fluxer = resolve_discord_links(content, mock_state, "fluxer", "target_server")
assert "https://fluxer.app/channels/target_server/target_chan_456" in resolved_fluxer
# Test Stoat
resolved_stoat = resolve_discord_links(content, mock_state, "stoat", "target_server")
assert "https://stoat.chat/server/target_server/channel/target_chan_456" in resolved_stoat
def test_resolve_discord_links_message_mapping():
mock_state = MagicMock()
mock_state.find_message_mapping.return_value = ("target_chan_456", "target_msg_789")
content = "Look at this: https://discord.com/channels/123/456/789"
# Test Fluxer
resolved_fluxer = resolve_discord_links(content, mock_state, "fluxer", "target_server")
assert "https://fluxer.app/channels/target_server/target_chan_456/target_msg_789" in resolved_fluxer
def test_resolve_discord_links_skips_wrapped():
mock_state = MagicMock()
content = "Already wrapped: [link](https://discord.com/channels/1/2/3)"
resolved = resolve_discord_links(content, mock_state, "fluxer", "target_server")
assert resolved == content