Compare commits
No commits in common. "089800374957f3dc549e48dc72f68e8d4f97a308" and "fb773ba948dd2778dd2d4cb7ac205a2b6fa4dd75" have entirely different histories.
0898003749
...
fb773ba948
25 changed files with 742 additions and 2197 deletions
5
.gitignore
vendored
5
.gitignore
vendored
|
|
@ -21,7 +21,6 @@ wheels/
|
|||
.installed.cfg
|
||||
*.egg
|
||||
*.txt
|
||||
*.exe
|
||||
|
||||
# Virtual Environment
|
||||
venv/
|
||||
|
|
@ -41,11 +40,11 @@ reaper_config.yaml
|
|||
*.log
|
||||
|
||||
# Temporary Test Scripts
|
||||
#test_*.py
|
||||
tmp/
|
||||
test_*.py
|
||||
test_release.zip
|
||||
test_release/
|
||||
DiscoReaper
|
||||
DiscoReaper-*
|
||||
*.zip
|
||||
|
||||
# App data files
|
||||
|
|
|
|||
14
README.md
14
README.md
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
**DiscoReaper** is a powerful tool designed to help you migrate your entire Discord server to Fluxer or Stoat. It clones channels, roles, emojis, permissions, and also your community's full message history.
|
||||
|
||||
### Get it here: [](https://git.mithraic.cloud/ad3laid3/disco-reaper/releases/latest) [](https://git.mithraic.cloud/ad3laid3/disco-reaper/releases/latest)
|
||||
### Get it here: [](https://github.com/rambros3d/disco-reaper/releases/latest/download/disco-reaper-linux.zip) [](https://github.com/rambros3d/disco-reaper/releases/latest/download/disco-reaper-windows.zip) [](https://github.com/rambros3d/disco-reaper/releases/latest/download/disco-reaper-macos.zip)
|
||||
|
||||
|
||||
>Join our [**Reaper Community**](https://fluxer.gg/9KxDP8WH) if you need help or have any questions.
|
||||
|
|
@ -81,7 +81,7 @@ Migrate directly from Discord or your Local Backups to the target community.
|
|||
#### Setup the bots as per this [guide](BOT-SETUP.md)
|
||||
|
||||
### Option 1: Using Pre-built Binaries (Easiest)
|
||||
1. **Download**: Grab the latest version from the [Releases](https://git.mithraic.cloud/ad3laid3/disco-reaper/releases) page.
|
||||
1. **Download**: Grab the latest version from the [Releases](https://github.com/rambros3d/disco-reaper/releases) page.
|
||||
2. **Run**:
|
||||
- **Linux**: Run the `disco-reaper` binary (e.g., `./launch.sh` or double-click).
|
||||
- **Windows**: Run `disco-reaper.exe`.
|
||||
|
|
@ -89,7 +89,7 @@ Migrate directly from Discord or your Local Backups to the target community.
|
|||
### Option 2: Running from Source (To use latest unstable code)
|
||||
1. **Clone**: Clone the repository to your local machine:
|
||||
```bash
|
||||
git clone https://git.mithraic.cloud/ad3laid3/disco-reaper.git
|
||||
git clone https://github.com/rambros3d/disco-reaper.git
|
||||
cd disco-reaper
|
||||
```
|
||||
2. **Launch**: Run the appropriate launcher script for your OS. It will automatically create a virtual environment and install dependencies:
|
||||
|
|
@ -153,4 +153,10 @@ But now their own website states that **Persona** will be used in some countries
|
|||
|
||||
## Contributors
|
||||
|
||||
MiTHRAL — fork maintainer, Stoat/Revolt integration & bug fixes.
|
||||
<a href="https://github.com/rambros3d/disco-reaper/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=rambros3d/disco-reaper" />
|
||||
</a>
|
||||
|
||||
Made with [contrib.rocks](https://contrib.rocks).
|
||||
|
||||
[](https://www.star-history.com/?repos=rambros3d%2Fdisco-reaper&type=date&legend=top-left)
|
||||
|
|
|
|||
77
build.bat
77
build.bat
|
|
@ -1,77 +0,0 @@
|
|||
@echo off
|
||||
setlocal enabledelayedexpansion
|
||||
cd /d "%~dp0"
|
||||
|
||||
echo --- Disco-Reaper Windows Build Script ---
|
||||
|
||||
REM Check for venv
|
||||
IF NOT EXIST "venv" (
|
||||
echo Creating virtual environment...
|
||||
python -m venv venv
|
||||
IF ERRORLEVEL 1 (
|
||||
echo Error: Failed to create virtual environment.
|
||||
echo Ensure Python is installed and added to PATH.
|
||||
IF NOT DEFINED AUTO_BUILD pause
|
||||
exit /b 1
|
||||
)
|
||||
)
|
||||
|
||||
echo Activating virtual environment...
|
||||
call venv\Scripts\activate.bat
|
||||
|
||||
REM Self-healing pip check
|
||||
python -m pip --version >nul 2>&1
|
||||
IF ERRORLEVEL 1 (
|
||||
echo Warning: pip is missing or broken in venv. Attempting repair...
|
||||
python -m ensurepip --default-pip
|
||||
IF ERRORLEVEL 1 (
|
||||
echo Error: Failed to repair pip automatically.
|
||||
echo Try recreating the venv: rmdir /s /q venv ^&^& python -m venv venv
|
||||
IF NOT DEFINED AUTO_BUILD pause
|
||||
exit /b 1
|
||||
)
|
||||
)
|
||||
|
||||
echo Ensuring build dependencies are up to date...
|
||||
REM python -m pip install --upgrade pip --quiet
|
||||
python -m pip install pyinstaller --quiet
|
||||
python -m pip install -r requirements.txt --quiet
|
||||
|
||||
echo Cleaning previous build artifacts...
|
||||
IF EXIST "build" rmdir /s /q build
|
||||
IF EXIST "dist" rmdir /s /q dist
|
||||
|
||||
echo Starting PyInstaller build...
|
||||
REM Get git version tag
|
||||
set "GIT_VERSION=Unknown"
|
||||
for /f "tokens=*" %%i in ('git describe --tags --abbrev^=0 2^>nul') do set "GIT_VERSION=%%i"
|
||||
echo Baking version: %GIT_VERSION%
|
||||
echo __version__ = "%GIT_VERSION%"> src\core\_baked_version.py
|
||||
|
||||
python -m PyInstaller --clean disco-reaper.spec
|
||||
IF ERRORLEVEL 1 (
|
||||
echo Error: PyInstaller build failed.
|
||||
del /f src\core\_baked_version.py 2>nul
|
||||
IF NOT DEFINED AUTO_BUILD pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo Cleaning up baked version file...
|
||||
del /f src\core\_baked_version.py 2>nul
|
||||
|
||||
echo Packaging release: disco-reaper-windows.zip...
|
||||
cd dist
|
||||
powershell -Command "Compress-Archive -Path 'DiscoReaper.exe' -DestinationPath 'disco-reaper-windows.zip' -Force" 2>nul
|
||||
IF ERRORLEVEL 1 (
|
||||
echo Warning: Failed to create zip. Files are available in dist\ directory.
|
||||
) ELSE (
|
||||
echo Package created: dist\disco-reaper-windows.zip
|
||||
)
|
||||
cd ..
|
||||
|
||||
echo -----------------------------------
|
||||
echo Build complete!
|
||||
echo Standalone executable: dist\DiscoReaper.exe
|
||||
echo Release Package: dist\disco-reaper-windows.zip
|
||||
echo ---
|
||||
IF NOT DEFINED AUTO_BUILD pause
|
||||
|
|
@ -1,18 +1,17 @@
|
|||
import sys
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from src.ui.main_app import run_disco_reaper_tui
|
||||
from src.core.configuration import load_config
|
||||
|
||||
def setup_logging():
|
||||
try:
|
||||
config = load_config(create_if_missing=False)
|
||||
log_level_str = config.log_level.upper()
|
||||
log_level_str = config.migration.log_level.upper()
|
||||
level = getattr(logging, log_level_str, logging.INFO)
|
||||
except Exception:
|
||||
level = logging.INFO
|
||||
|
||||
handlers = [RotatingFileHandler('.reaper.log', mode='w', maxBytes=10*1024*1024, backupCount=3)]
|
||||
handlers = [logging.FileHandler('.reaper.log', mode='a')]
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
|
||||
datefmt='%H:%M:%S',
|
||||
|
|
@ -93,21 +92,18 @@ def cleanup_old_update():
|
|||
"""Removes the .old executable left behind by a Windows update."""
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
if sys.platform != "win32":
|
||||
return
|
||||
|
||||
# In frozen (PyInstaller) builds, sys.executable points to the temp _MEIxxxxx dir.
|
||||
# sys.argv[0] always points to the real .exe on disk, so use that and resolve() it.
|
||||
current_exe = Path(sys.argv[0]).resolve()
|
||||
old_exe = current_exe.with_suffix(current_exe.suffix + ".old")
|
||||
current_exe = sys.executable if getattr(sys, 'frozen', False) else sys.argv[0]
|
||||
old_exe = current_exe + ".old"
|
||||
|
||||
if old_exe.exists():
|
||||
if os.path.exists(old_exe):
|
||||
try:
|
||||
old_exe.unlink()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).debug(f"Could not remove old update file {old_exe}: {e}")
|
||||
os.remove(old_exe)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def main():
|
||||
import os
|
||||
|
|
|
|||
|
|
@ -9,6 +9,3 @@ lottie # Lottie file manipulation and conversion
|
|||
Pillow # Image processing (required for GIF rendering)
|
||||
cairosvg # SVG rendering (required for Lottie conversion)
|
||||
psutil # System information (CPU, RAM, etc.)
|
||||
pytest # Testing framework
|
||||
pytest-asyncio # Async testing for pytest
|
||||
pytest-mock # Mocking for pytest
|
||||
21
runtests.sh
21
runtests.sh
|
|
@ -1,21 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Convenient script to run the Discord Reaper test suite.
|
||||
# Automatically handles VENV activation and PYTHONPATH.
|
||||
|
||||
# Change to the project root directory (where the script is located)
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
# Activate the virtual environment if it exists
|
||||
if [ -d "venv" ]; then
|
||||
source venv/bin/activate
|
||||
fi
|
||||
|
||||
# Set PYTHONPATH to the current directory so src is discoverable
|
||||
export PYTHONPATH=.
|
||||
|
||||
# Run pytest with any arguments passed to this script, defaulting to tests/
|
||||
if [ $# -eq 0 ]; then
|
||||
pytest -v -s -p no:warnings tests/
|
||||
else
|
||||
pytest -s "$@"
|
||||
fi
|
||||
|
|
@ -98,9 +98,9 @@ class BackupDatabase:
|
|||
elif table == "permissions":
|
||||
conn.execute("CREATE TABLE permissions (id INTEGER PRIMARY KEY AUTOINCREMENT, channel_id INTEGER, target_id INTEGER, target_type TEXT, allow INTEGER, deny INTEGER)")
|
||||
elif table == "users":
|
||||
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT, display_name TEXT, avatar_file TEXT, avatar_url TEXT, roles TEXT, type INTEGER DEFAULT 0)")
|
||||
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT, display_name TEXT, avatar_file TEXT, avatar_url TEXT, roles TEXT)")
|
||||
elif table == "messages":
|
||||
conn.execute("CREATE TABLE messages (id INTEGER PRIMARY KEY, channel_id INTEGER, author_id INTEGER, content TEXT, timestamp TEXT, type INTEGER, message_reference INTEGER, is_pinned INTEGER, extra_data TEXT, custom_display_name TEXT, custom_avatar_url TEXT)")
|
||||
conn.execute("CREATE TABLE messages (id INTEGER PRIMARY KEY, channel_id INTEGER, author_id INTEGER, content TEXT, timestamp TEXT, type INTEGER, message_reference INTEGER, is_pinned INTEGER, extra_data TEXT)")
|
||||
elif table == "attachments":
|
||||
conn.execute("CREATE TABLE attachments (id INTEGER PRIMARY KEY, message_id INTEGER, filename TEXT, size INTEGER, url TEXT, content_type TEXT, local_hash TEXT)")
|
||||
elif table == "embeds":
|
||||
|
|
@ -114,7 +114,7 @@ class BackupDatabase:
|
|||
elif table == "forum_tags":
|
||||
conn.execute("CREATE TABLE forum_tags (id INTEGER PRIMARY KEY, forum_id INTEGER, name TEXT, moderated INTEGER, emoji_id INTEGER, emoji_name TEXT)")
|
||||
elif table == "server_assets":
|
||||
conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type TEXT)")
|
||||
conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type INTEGER)")
|
||||
|
||||
old_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table}_old)").fetchall()]
|
||||
new_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table})").fetchall()]
|
||||
|
|
@ -124,25 +124,6 @@ class BackupDatabase:
|
|||
conn.execute(f"INSERT INTO {table} ({col_str}) SELECT {col_str} FROM {table}_old")
|
||||
conn.execute(f"DROP TABLE {table}_old")
|
||||
|
||||
# 3. Custom Author Profile Migration
|
||||
res = conn.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='messages'").fetchone()
|
||||
if res and res[0] > 0:
|
||||
cols = conn.execute("PRAGMA table_info(messages)").fetchall()
|
||||
col_names = [c["name"] for c in cols]
|
||||
if "custom_display_name" not in col_names:
|
||||
logger.info("Migrating messages: adding custom author profile columns")
|
||||
conn.execute("ALTER TABLE messages ADD COLUMN custom_display_name TEXT")
|
||||
conn.execute("ALTER TABLE messages ADD COLUMN custom_avatar_url TEXT")
|
||||
|
||||
# 4. User Type Categorization Migration
|
||||
res = conn.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='users'").fetchone()
|
||||
if res and res[0] > 0:
|
||||
cols = conn.execute("PRAGMA table_info(users)").fetchall()
|
||||
col_names = [c["name"] for c in cols]
|
||||
if "type" not in col_names:
|
||||
logger.info("Migrating users: adding type column")
|
||||
conn.execute("ALTER TABLE users ADD COLUMN type INTEGER DEFAULT 0")
|
||||
|
||||
conn.commit()
|
||||
|
||||
def _init_db(self):
|
||||
|
|
@ -215,8 +196,7 @@ class BackupDatabase:
|
|||
display_name TEXT,
|
||||
avatar_file TEXT,
|
||||
avatar_url TEXT,
|
||||
roles TEXT,
|
||||
type INTEGER DEFAULT 0
|
||||
roles TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
|
|
@ -231,9 +211,7 @@ class BackupDatabase:
|
|||
type INTEGER,
|
||||
message_reference INTEGER,
|
||||
is_pinned INTEGER,
|
||||
extra_data TEXT,
|
||||
custom_display_name TEXT,
|
||||
custom_avatar_url TEXT
|
||||
extra_data TEXT
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_messages_channel ON messages(channel_id)")
|
||||
|
|
@ -427,8 +405,8 @@ class BackupDatabase:
|
|||
"""Saves users to the author cache."""
|
||||
with self._lock:
|
||||
self._conn.executemany("""
|
||||
INSERT OR REPLACE INTO users (id, username, display_name, avatar_file, avatar_url, roles, type)
|
||||
VALUES (:id, :username, :display_name, :avatar_file, :avatar_url, :roles, :type)
|
||||
INSERT OR REPLACE INTO users (id, username, display_name, avatar_file, avatar_url, roles)
|
||||
VALUES (:id, :username, :display_name, :avatar_file, :avatar_url, :roles)
|
||||
""", users)
|
||||
self._conn.commit()
|
||||
|
||||
|
|
@ -476,8 +454,8 @@ class BackupDatabase:
|
|||
conn = self._conn
|
||||
# Insert messages
|
||||
conn.executemany("""
|
||||
INSERT OR REPLACE INTO messages (id, channel_id, author_id, content, timestamp, type, message_reference, is_pinned, extra_data, custom_display_name, custom_avatar_url)
|
||||
VALUES (:id, :channel_id, :author_id, :content, :timestamp, :type, :message_reference, :is_pinned, :extra_data, :custom_display_name, :custom_avatar_url)
|
||||
INSERT OR REPLACE INTO messages (id, channel_id, author_id, content, timestamp, type, message_reference, is_pinned, extra_data)
|
||||
VALUES (:id, :channel_id, :author_id, :content, :timestamp, :type, :message_reference, :is_pinned, :extra_data)
|
||||
""", messages)
|
||||
|
||||
# Extract attachments, reactions, and stickers
|
||||
|
|
@ -967,60 +945,6 @@ class BackupDatabase:
|
|||
|
||||
return purged_count
|
||||
|
||||
def get_backed_up_channel_ids(self) -> List[int]:
|
||||
"""Returns a list of distinct channel IDs that have messages in the database."""
|
||||
with self._lock:
|
||||
rows = self._conn.execute("SELECT DISTINCT channel_id FROM messages").fetchall()
|
||||
return [parse_snowflake(r[0]) for r in rows if parse_snowflake(r[0])]
|
||||
|
||||
def get_message_with_relations(self, message_id) -> Optional[Dict[str, Any]]:
|
||||
"""Fetches a single message with its attachments, embeds, reactions, and stickers."""
|
||||
with self._lock:
|
||||
mid = parse_snowflake(message_id)
|
||||
row = self._conn.execute("SELECT * FROM messages WHERE id = ?", (mid,)).fetchone()
|
||||
if not row:
|
||||
return None
|
||||
data = dict(row)
|
||||
|
||||
# Attachments
|
||||
atts = self._conn.execute("SELECT * FROM attachments WHERE message_id = ?", (mid,)).fetchall()
|
||||
data["attachments"] = [dict(a) for a in atts]
|
||||
|
||||
# Embeds
|
||||
embs = self._conn.execute("SELECT * FROM embeds WHERE message_id = ?", (mid,)).fetchall()
|
||||
data["embeds"] = []
|
||||
for er in embs:
|
||||
e_dict = {
|
||||
"title": er["title"],
|
||||
"description": er["description"],
|
||||
"url": er["url"],
|
||||
"color": er["color"],
|
||||
"timestamp": er["timestamp"],
|
||||
"thumbnail": {"url": er["thumbnail_url"]} if er["thumbnail_url"] else None,
|
||||
"image": {"url": er["image_url"]} if er["image_url"] else None,
|
||||
"author": {
|
||||
"name": er["author_name"],
|
||||
"url": er["author_url"],
|
||||
"icon_url": er["author_icon_url"]
|
||||
} if er["author_name"] else None,
|
||||
"footer": {
|
||||
"text": er["footer_text"],
|
||||
"icon_url": er["footer_icon_url"]
|
||||
} if er["footer_text"] else None,
|
||||
"fields": json.loads(er["fields"]) if er["fields"] else []
|
||||
}
|
||||
data["embeds"].append(e_dict)
|
||||
|
||||
# Reactions
|
||||
reas = self._conn.execute("SELECT * FROM reactions WHERE message_id = ?", (mid,)).fetchall()
|
||||
data["reactions"] = [dict(r) for r in reas]
|
||||
|
||||
# Stickers
|
||||
sts = self._conn.execute("SELECT * FROM message_stickers WHERE message_id = ?", (mid,)).fetchall()
|
||||
data["stickers"] = [dict(s) for s in sts]
|
||||
|
||||
return data
|
||||
|
||||
def close(self):
|
||||
"""Commits any pending writes and closes the connection."""
|
||||
with self._lock:
|
||||
|
|
|
|||
|
|
@ -393,20 +393,6 @@ class BackupMember:
|
|||
# Fallback for unexpected data format
|
||||
self.id = 0
|
||||
self.name = "Unknown"
|
||||
self.display_name = "Unknown"
|
||||
self.global_name = "Unknown"
|
||||
self.bot = False
|
||||
self.system = False
|
||||
self.discriminator = "0000"
|
||||
self.color = BackupColor(0)
|
||||
self.roles = sorted(role_objects or [], key=lambda r: r.position, reverse=True)
|
||||
self.guild_permissions = BackupPermissions(0)
|
||||
self.created_at = datetime.now(timezone.utc)
|
||||
self.joined_at = datetime.now(timezone.utc)
|
||||
self.status = type("Status", (), {"value": "offline"})()
|
||||
self.activity = None
|
||||
self._avatar_url = None
|
||||
self.avatar = BackupAsset(None)
|
||||
return
|
||||
self.id = parse_snowflake(data["id"])
|
||||
self.name = data.get("username", "Unknown")
|
||||
|
|
@ -530,13 +516,12 @@ class BackupEmoji:
|
|||
class BackupSticker:
|
||||
"""Minimal stand-in for discord.GuildSticker."""
|
||||
|
||||
__slots__ = ("id", "name", "url", "format", "_backup_root", "_file_path", "local_hash")
|
||||
__slots__ = ("id", "name", "url", "format", "_backup_root", "_file_path")
|
||||
|
||||
def __init__(self, data: dict, backup_root: Path | None = None, media_pool: dict | None = None):
|
||||
if not isinstance(data, dict):
|
||||
self.id = 0
|
||||
self.name = "Sticker"
|
||||
self.local_hash = None
|
||||
return
|
||||
self.id = parse_snowflake(data.get("id") or data.get("sticker_id", 0)) or 0
|
||||
self.name = data.get("name", "Sticker")
|
||||
|
|
@ -551,14 +536,14 @@ class BackupSticker:
|
|||
self._backup_root = backup_root
|
||||
|
||||
# 1. Check if it's a CAS-based sticker (from message_stickers table)
|
||||
self.local_hash = data.get("local_hash")
|
||||
if self.local_hash and backup_root:
|
||||
local_hash = data.get("local_hash")
|
||||
if local_hash and backup_root:
|
||||
ext = ".png"
|
||||
if self.format == StickerFormatType.lottie: ext = ".json"
|
||||
elif self.format == StickerFormatType.apng: ext = ".png"
|
||||
elif self.format == StickerFormatType.gif: ext = ".gif"
|
||||
|
||||
self._file_path = backup_root / "attachments" / f"{self.local_hash}{ext}"
|
||||
self._file_path = backup_root / "attachments" / f"{local_hash}{ext}"
|
||||
# 2. Check if it's a server asset sticker (legacy or manual save)
|
||||
elif data.get("filename") and backup_root:
|
||||
self._file_path = backup_root / "server_assets" / data["filename"]
|
||||
|
|
@ -1281,7 +1266,11 @@ class BackupReader:
|
|||
async def get_backed_up_channel_ids(self) -> List[int]:
|
||||
"""Returns a list of channel IDs that have messages in the database."""
|
||||
if not self.db: return []
|
||||
return self.db.get_backed_up_channel_ids()
|
||||
import sqlite3
|
||||
conn = sqlite3.connect(self.db.db_path)
|
||||
rows = conn.execute("SELECT DISTINCT channel_id FROM messages").fetchall()
|
||||
conn.close()
|
||||
return [parse_snowflake(r[0]) for r in rows if parse_snowflake(r[0])]
|
||||
|
||||
async def get_channel(self, channel_id: int) -> BackupChannel | BackupThread | None:
|
||||
for c in self.channels:
|
||||
|
|
@ -1339,18 +1328,6 @@ class BackupReader:
|
|||
user_id = parse_snowflake(msg_data.get("author_id", 0)) or 0
|
||||
author = self._resolve_author(user_id)
|
||||
|
||||
# Check for custom author profile (Webhooks / Masquerade)
|
||||
over_name = msg_data.get("custom_display_name")
|
||||
over_avatar = msg_data.get("custom_avatar_url")
|
||||
if over_name:
|
||||
# Create an ephemeral author object for this message
|
||||
author = BackupMember({
|
||||
"id": str(user_id),
|
||||
"username": over_name,
|
||||
"display_name": over_name,
|
||||
"avatar_url": over_avatar
|
||||
})
|
||||
|
||||
self._ensure_media_pool_loaded()
|
||||
|
||||
channel_id = parse_snowflake(msg_data["channel_id"])
|
||||
|
|
@ -1374,9 +1351,23 @@ class BackupReader:
|
|||
async def get_message(self, channel_id: int, message_id: int) -> BackupMessage | None:
|
||||
"""Fetch a specific message from SQLite."""
|
||||
if not self.db: return None
|
||||
data = self.db.get_message_with_relations(message_id)
|
||||
if data:
|
||||
import sqlite3
|
||||
conn = sqlite3.connect(self.db.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
row = conn.execute("SELECT * FROM messages WHERE id = ?", (str(message_id),)).fetchone()
|
||||
if row:
|
||||
data = dict(row)
|
||||
# Fetch attachments
|
||||
atts = conn.execute("SELECT * FROM attachments WHERE message_id = ?", (str(message_id),)).fetchall()
|
||||
data["attachments"] = [dict(a) for a in atts]
|
||||
|
||||
# Fetch stickers
|
||||
sts = conn.execute("SELECT * FROM message_stickers WHERE message_id = ?", (str(message_id),)).fetchall()
|
||||
data["stickers"] = [dict(s) for s in sts]
|
||||
|
||||
conn.close()
|
||||
return self._hydrate_message(data)
|
||||
conn.close()
|
||||
return None
|
||||
|
||||
async def get_first_message(self, channel_id: int) -> BackupMessage | None:
|
||||
|
|
|
|||
|
|
@ -22,13 +22,6 @@ class MigrationContext:
|
|||
self.target_platform = target_platform or config.target_platform or "fluxer"
|
||||
self.state = MigrationState()
|
||||
|
||||
# Apply config-based log level dynamically to the root logger
|
||||
if hasattr(self.config, "log_level") and self.config.log_level:
|
||||
import logging
|
||||
level = getattr(logging, self.config.log_level.upper(), logging.INFO)
|
||||
logging.getLogger().setLevel(level)
|
||||
logger.info(f"Log level updated to {self.config.log_level.upper()}")
|
||||
|
||||
# Select the appropriate source reader
|
||||
if source_mode == "backup":
|
||||
from src.core.backup_reader import BackupReader
|
||||
|
|
@ -44,8 +37,8 @@ class MigrationContext:
|
|||
|
||||
# Build the writer for the active target platform only
|
||||
if self.target_platform == "stoat":
|
||||
token = config.stoat_bot_token
|
||||
community_id = config.stoat_server_id
|
||||
token = config.stoat_bot_token or ""
|
||||
community_id = config.stoat_server_id or ""
|
||||
api_url = config.stoat_api_url or "default"
|
||||
self.writer = StoatWriter(token=token, community_id=community_id, api_url=api_url)
|
||||
self.stoat_writer = self.writer
|
||||
|
|
@ -104,17 +97,11 @@ class MigrationContext:
|
|||
}
|
||||
|
||||
# CONSISTENCY: Once target metadata is known, initialize the flat SQLite DB.
|
||||
if results["target_community"]:
|
||||
if results["target_community"] and results["target_community_name"]:
|
||||
tid = self.config.fluxer_server_id if self.target_platform == "fluxer" else self.config.stoat_server_id
|
||||
|
||||
# Prefer the original discord community name for the DB file if available (e.g. from live load or backup)
|
||||
db_name = results.get("discord_server_name")
|
||||
if not db_name or db_name == "Not Found" or db_name == "Unknown":
|
||||
db_name = results.get("target_community_name") or "Unknown"
|
||||
|
||||
self.ensure_state_initialized(
|
||||
str(tid or ""),
|
||||
db_name
|
||||
results["target_community_name"]
|
||||
)
|
||||
|
||||
return results
|
||||
|
|
@ -133,23 +120,6 @@ class MigrationContext:
|
|||
return
|
||||
|
||||
import re
|
||||
import json
|
||||
|
||||
# Override the target name explicitly with the original Discord source name if available.
|
||||
# This fixes naming collisions and UI confusion like "Fluxer-123456.db" instead of "MyServer-123456.db"
|
||||
try:
|
||||
if hasattr(self.discord_reader, "guild") and getattr(self.discord_reader, "guild", None):
|
||||
community_name = getattr(self.discord_reader, "guild").name
|
||||
elif getattr(self, "source_mode", "live") == "backup" and hasattr(self.discord_reader, "backup_dir"):
|
||||
b_dir = getattr(self.discord_reader, "backup_dir")
|
||||
if b_dir and b_dir.exists():
|
||||
meta_file = b_dir / "metadata.json"
|
||||
if meta_file.exists():
|
||||
data = json.loads(meta_file.read_text())
|
||||
community_name = data.get("name", community_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
clean_name = re.sub(r'[^\w\s-]', '', community_name).strip()
|
||||
clean_name = re.sub(r'[-\s]+', '_', clean_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import logging
|
|||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List, Union
|
||||
from typing import Optional, Dict, Any, Union
|
||||
import threading
|
||||
import sys
|
||||
from src.core.utils import parse_snowflake
|
||||
|
|
@ -560,15 +560,6 @@ class MigrationDatabase:
|
|||
conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (str(channel_id),))
|
||||
conn.commit()
|
||||
logger.info(f"Cleared all tracking and mapping data for channel: {channel_id}")
|
||||
def clear_all_migration_data(self):
|
||||
"""Purge all mappings and tracking data for ALL channels and threads."""
|
||||
conn = self._get_conn()
|
||||
conn.execute("DELETE FROM message_mappings")
|
||||
conn.execute("DELETE FROM thread_mappings")
|
||||
conn.execute("DELETE FROM channel_tracking")
|
||||
conn.execute("DELETE FROM thread_tracking")
|
||||
conn.commit()
|
||||
logger.info("Cleared ALL tracking and message mapping data globally.")
|
||||
|
||||
def close(self):
|
||||
if hasattr(self._local, "conn"):
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ class DiscordExporter:
|
|||
self.server_name = ""
|
||||
self.server_id = ""
|
||||
self.user_cache = {}
|
||||
self.member_cache: Dict[int, Any] = {} # Pre-fetched member objects (id -> Member)
|
||||
self.base_dir = Path(base_dir) if base_dir else Path(".")
|
||||
self.is_running = True
|
||||
self.db: Optional[BackupDatabase] = None
|
||||
|
|
@ -63,20 +62,6 @@ class DiscordExporter:
|
|||
hash_sha256.update(chunk)
|
||||
return hash_sha256.hexdigest()
|
||||
|
||||
async def prefetch_members(self):
|
||||
"""Pre-fetches all guild members into a local cache for role resolution.
|
||||
|
||||
msg.author is a discord.User (no roles). This cache allows us to
|
||||
resolve roles without an API call per message during message export.
|
||||
"""
|
||||
try:
|
||||
members = await self.reader.get_members()
|
||||
self.member_cache = {m.id: m for m in members}
|
||||
logger.info(f"Pre-fetched {len(self.member_cache)} members for role resolution.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not pre-fetch members (roles will be empty): {e}")
|
||||
self.member_cache = {}
|
||||
|
||||
async def export_metadata(self):
|
||||
"""Saves server metadata to the SQLite database."""
|
||||
metadata = await self.reader.get_server_metadata()
|
||||
|
|
@ -454,65 +439,37 @@ class DiscordExporter:
|
|||
|
||||
return accumulated_count, accumulated_threads, accumulated_files
|
||||
|
||||
async def _format_user(self, user, is_webhook=False):
|
||||
async def _format_user(self, user):
|
||||
"""Formats user data for the author or a mention.
|
||||
|
||||
For Webhooks, we use a generic name and the default Discord avatar system
|
||||
for the base profile in the user cache.
|
||||
Avatar downloads are intentionally deferred to keep this off the hot
|
||||
message-formatting path. Call _flush_pending_avatars() after each batch.
|
||||
"""
|
||||
user_id_int = int(user.id)
|
||||
user_id = str(user_id_int)
|
||||
|
||||
user_id = str(user.id)
|
||||
if user_id in self.user_cache:
|
||||
return None
|
||||
|
||||
username = user.name
|
||||
display_name = getattr(user, "display_name", user.name)
|
||||
avatar = user.avatar
|
||||
avatar_url = str(user.display_avatar.url) if user.display_avatar else None
|
||||
|
||||
if is_webhook:
|
||||
# For webhooks, we use the ID as the username for technical clarity,
|
||||
# and the current name as the display name.
|
||||
username = user_id
|
||||
display_name = user.name
|
||||
# Discord default avatar formula: (ID >> 22) % 5
|
||||
default_index = (user_id_int >> 22) % 5
|
||||
avatar_url = f"https://cdn.discordapp.com/embed/avatars/{default_index}.png"
|
||||
avatar = None # Don't download character avatar as the "base" webhook avatar
|
||||
|
||||
# New user discovered — schedule avatar download but don't block here
|
||||
avatar_file = None
|
||||
if avatar:
|
||||
if user.avatar:
|
||||
av_name = f"{user_id}.png"
|
||||
av_target = self.users_path / av_name
|
||||
avatar_file = f"users/{av_name}"
|
||||
if not av_target.exists():
|
||||
# Queue for deferred download
|
||||
self._pending_avatars.append((user_id, avatar, av_target))
|
||||
self._pending_avatars.append((user_id, user.avatar, av_target))
|
||||
|
||||
roles = []
|
||||
if hasattr(user, "roles"):
|
||||
roles = [str(r.id) for r in user.roles if not r.is_default()]
|
||||
|
||||
# Determine user type
|
||||
# 0: Regular User, 1: Bot, 2: Webhook, 3: System
|
||||
u_type = 0
|
||||
if is_webhook:
|
||||
u_type = 2
|
||||
elif getattr(user, "system", False):
|
||||
u_type = 3
|
||||
elif getattr(user, "bot", False):
|
||||
u_type = 1
|
||||
|
||||
user_data = {
|
||||
"id": user_id,
|
||||
"username": username,
|
||||
"display_name": display_name,
|
||||
"username": user.name,
|
||||
"display_name": getattr(user, "display_name", user.name),
|
||||
"avatar_file": avatar_file,
|
||||
"avatar_url": avatar_url,
|
||||
"roles": json.dumps(roles),
|
||||
"type": u_type
|
||||
"avatar_url": str(user.display_avatar.url) if user.avatar else None,
|
||||
"roles": json.dumps(roles)
|
||||
}
|
||||
self.user_cache[user_id] = user_data
|
||||
return user_data
|
||||
|
|
@ -537,21 +494,13 @@ class DiscordExporter:
|
|||
new_users = []
|
||||
|
||||
# 1. Author handling
|
||||
is_webhook = bool(getattr(msg, "webhook_id", None))
|
||||
author = msg.author
|
||||
# msg.author is discord.User (no roles). Resolve to Member for role data.
|
||||
if not is_webhook:
|
||||
member = self.member_cache.get(msg.author.id)
|
||||
if member:
|
||||
author = member
|
||||
u_data = await self._format_user(author, is_webhook=is_webhook)
|
||||
u_data = await self._format_user(msg.author)
|
||||
if u_data: new_users.append(u_data)
|
||||
|
||||
# 1.5 Mentions handling (ensure all mentioned users are saved)
|
||||
if msg.mentions:
|
||||
for mention in msg.mentions:
|
||||
# Mentions can be Member objects already, so roles work naturally
|
||||
u_ment = await self._format_user(mention, is_webhook=False)
|
||||
u_ment = await self._format_user(mention)
|
||||
if u_ment: new_users.append(u_ment)
|
||||
|
||||
# 2. Attachments handling (Content-Addressable Storage)
|
||||
|
|
@ -654,15 +603,6 @@ class DiscordExporter:
|
|||
for s_emb in snapshot.embeds:
|
||||
embeds.append(s_emb.to_dict())
|
||||
|
||||
# 5.6 Author Overrides (Webhooks / Masquerade)
|
||||
custom_display_name = None
|
||||
custom_avatar_url = None
|
||||
|
||||
# Webhooks or bots with masquerade often use per-message names/avatars
|
||||
if getattr(msg, "webhook_id", None) or (msg.author and msg.author.bot):
|
||||
custom_display_name = msg.author.name
|
||||
custom_avatar_url = str(msg.author.display_avatar.url) if msg.author.display_avatar else None
|
||||
|
||||
m_data = {
|
||||
"id": str(msg.id),
|
||||
"channel_id": str(msg.channel.id),
|
||||
|
|
@ -676,9 +616,7 @@ class DiscordExporter:
|
|||
"stickers": stickers,
|
||||
"embeds": embeds,
|
||||
"reactions": reactions,
|
||||
"extra_data": None,
|
||||
"custom_display_name": custom_display_name,
|
||||
"custom_avatar_url": custom_avatar_url
|
||||
"extra_data": None
|
||||
}
|
||||
|
||||
return m_data, new_users
|
||||
|
|
|
|||
|
|
@ -232,17 +232,21 @@ class MigrationState:
|
|||
return None
|
||||
|
||||
|
||||
def get_global_min_last_message_id(self, all_mapped_ids: list[str]) -> int | None:
|
||||
def get_global_min_last_message_id(self, all_mapped_ids: List[str]) -> int | None:
|
||||
"""Returns the absolute minimum last_msg_id among the given list of mapped target IDs (channels and threads)."""
|
||||
if self._ensure_db():
|
||||
return self.db.get_global_min_last_message_id(all_mapped_ids)
|
||||
return None
|
||||
|
||||
def set_waterfall_last_id(self, last_id: str | int):
|
||||
if self.db:
|
||||
self.db.set_metadata("waterfall_last_id", str(last_id))
|
||||
|
||||
def clear_all_migration_data(self):
|
||||
"""Clears all message mapping and tracking state globally."""
|
||||
if self._ensure_db():
|
||||
self.db.clear_all_migration_data()
|
||||
def get_waterfall_last_id(self) -> int | None:
|
||||
if self.db:
|
||||
val = self.db.get_metadata("waterfall_last_id")
|
||||
return int(val) if val else None
|
||||
return None
|
||||
|
||||
def get_all_last_message_ids(self) -> Dict[str, str]:
|
||||
"""Returns a combined map of channel_id/thread_id -> last_msg_id."""
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@ from src.core.utils import get_app_version
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
API_URL = "https://git.mithraic.cloud/api/v1/repos/ad3laid3/disco-reaper/releases"
|
||||
REPO_OWNER = "rambros3d"
|
||||
REPO_NAME = "disco-reaper"
|
||||
API_URL = f"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/releases"
|
||||
|
||||
def get_current_version() -> str:
|
||||
"""Returns the current version string, e.g., '1.0.0'. Strips 'Reaper-' and 'v'."""
|
||||
|
|
@ -50,7 +52,7 @@ async def check_for_updates() -> Optional[Dict[str, Any]]:
|
|||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(API_URL) as resp:
|
||||
async with session.get(API_URL, headers={"Accept": "application/vnd.github.v3+json"}) as resp:
|
||||
if resp.status == 200:
|
||||
releases = await resp.json()
|
||||
if not isinstance(releases, list) or not releases:
|
||||
|
|
|
|||
|
|
@ -15,52 +15,41 @@ async def sync_channel_state(context: MigrationContext):
|
|||
channels = await context.discord_reader.get_channels()
|
||||
fluxer_channels = await context.fluxer_writer.get_channels()
|
||||
|
||||
# Build maps for Fluxer lookup
|
||||
# {name: id} for categories
|
||||
fluxer_cats = {c.get("name"): str(c.get("id")) for c in fluxer_channels if c.get("type") == 4}
|
||||
# {parent_id: {name: id}} for channels
|
||||
fluxer_structure = {}
|
||||
for c in fluxer_channels:
|
||||
if c.get("type") == 4: continue
|
||||
p_id = str(c.get("parent_id")) if c.get("parent_id") else "root"
|
||||
if p_id not in fluxer_structure: fluxer_structure[p_id] = {}
|
||||
fluxer_structure[p_id][c.get("name")] = str(c.get("id"))
|
||||
|
||||
# Build name -> id map and ID set for Fluxer for fast lookup
|
||||
fluxer_name_map = {c.get("name"): str(c.get("id")) for c in fluxer_channels if c.get("name")}
|
||||
fluxer_id_set = {str(c.get("id")) for c in fluxer_channels}
|
||||
|
||||
updates = 0
|
||||
removals = 0
|
||||
|
||||
# 1. Sync Categories
|
||||
# 1. Verify and Sync Categories
|
||||
for cat in categories:
|
||||
discord_id = str(cat.id)
|
||||
fluxer_id = context.state.get_fluxer_category_id(discord_id)
|
||||
|
||||
if fluxer_id:
|
||||
if fluxer_id not in fluxer_id_set:
|
||||
context.state.remove_category_mapping(discord_id)
|
||||
removals += 1
|
||||
elif cat.name in fluxer_cats:
|
||||
context.state.set_category_mapping(discord_id, fluxer_cats[cat.name])
|
||||
elif cat.name in fluxer_name_map:
|
||||
context.state.set_category_mapping(discord_id, fluxer_name_map[cat.name])
|
||||
updates += 1
|
||||
|
||||
# 2. Sync Channels (parent-aware)
|
||||
# 2. Verify and Sync Channels
|
||||
for ch in channels:
|
||||
discord_id = str(ch.id)
|
||||
fluxer_id = context.state.get_fluxer_channel_id(discord_id)
|
||||
|
||||
if fluxer_id:
|
||||
if fluxer_id not in fluxer_id_set:
|
||||
context.state.remove_channel_mapping(discord_id)
|
||||
removals += 1
|
||||
else:
|
||||
# Try to match by name within the mapped parent category
|
||||
p_discord_id = str(ch.category_id) if ch.category_id else "root"
|
||||
p_fluxer_id = context.state.get_fluxer_category_id(p_discord_id) if p_discord_id != "root" else "root"
|
||||
|
||||
if p_fluxer_id in fluxer_structure and ch.name in fluxer_structure[p_fluxer_id]:
|
||||
context.state.set_channel_mapping(discord_id, fluxer_structure[p_fluxer_id][ch.name])
|
||||
updates += 1
|
||||
elif ch.name in fluxer_name_map:
|
||||
context.state.set_channel_mapping(discord_id, fluxer_name_map[ch.name])
|
||||
updates += 1
|
||||
|
||||
if updates > 0 or removals > 0:
|
||||
logger.info(f"Fluxer Channel sync: {updates} mapped, {removals} stale mappings removed")
|
||||
logger.info(f"Channel sync: {updates} mapped, {removals} stale mappings removed")
|
||||
|
||||
|
||||
async def migrate_channels(context: MigrationContext, progress_callback: Callable[[str, str, int, int], Awaitable[None]] | None = None, force: bool = False) -> dict:
|
||||
|
|
|
|||
|
|
@ -158,190 +158,7 @@ async def get_channel_threads(reader: Any, channel_id: int) -> List[Any]:
|
|||
return threads
|
||||
|
||||
|
||||
|
||||
async def _process_and_send_message(
|
||||
context: MigrationContext,
|
||||
msg: Any,
|
||||
target_channel_id: str,
|
||||
stats: Dict[str, Any],
|
||||
thread_id: str | None = None,
|
||||
parent_target_id: str | None = None,
|
||||
thread_name: str | None = None,
|
||||
processed_threads: set | None = None
|
||||
) -> str | None:
|
||||
"""
|
||||
Internal helper to process a single Discord message (mentions, attachments, stickers)
|
||||
and send it to the Fluxer platform.
|
||||
"""
|
||||
# 1. Formatting
|
||||
content = msg.content or ""
|
||||
|
||||
# Check for forwarded flag
|
||||
is_forwarded = False
|
||||
if hasattr(msg.flags, 'forwarded'):
|
||||
is_forwarded = msg.flags.forwarded
|
||||
|
||||
# Always ensure alias is created/retrieved to populate user_alias table
|
||||
alias = context.state.get_user_alias(str(msg.author.id))
|
||||
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
|
||||
|
||||
# Process Stickers
|
||||
files = []
|
||||
if hasattr(msg, 'stickers') and msg.stickers:
|
||||
for s in msg.stickers:
|
||||
try:
|
||||
sticker_data = await context.discord_reader.download_sticker(s)
|
||||
if not sticker_data: continue
|
||||
|
||||
format_val = getattr(s, 'format', 'png')
|
||||
if hasattr(format_val, 'name'):
|
||||
ext = format_val.name.lower()
|
||||
elif isinstance(format_val, int):
|
||||
format_map = {1: 'png', 2: 'apng', 3: 'lottie', 4: 'gif'}
|
||||
ext = format_map.get(format_val, 'png')
|
||||
else:
|
||||
ext = str(format_val).lower()
|
||||
|
||||
# Conversion logic (Simplified for unification)
|
||||
if ext == 'lottie' and HAS_LOTTIE:
|
||||
try:
|
||||
lottie_data = json.loads(sticker_data)
|
||||
def _convert_lottie(data):
|
||||
anim = Animation.load(data)
|
||||
buf = io.BytesIO()
|
||||
export_gif(anim, buf)
|
||||
buf.seek(0)
|
||||
return buf
|
||||
gif_buf = await asyncio.to_thread(_convert_lottie, lottie_data)
|
||||
from PIL import Image
|
||||
def _convert_gif_to_webp(buf):
|
||||
img = Image.open(buf)
|
||||
w_buf = io.BytesIO()
|
||||
if getattr(img, 'n_frames', 1) > 1:
|
||||
img.save(w_buf, format='WEBP', save_all=True, loop=0, quality=80)
|
||||
else:
|
||||
img.save(w_buf, format='WEBP', quality=80)
|
||||
return w_buf.getvalue()
|
||||
sticker_data = await asyncio.to_thread(_convert_gif_to_webp, gif_buf)
|
||||
ext = 'webp'
|
||||
except Exception: ext = 'json'
|
||||
elif ext in ('apng', 'gif'):
|
||||
try:
|
||||
from PIL import Image
|
||||
def _process_animated_sticker(data):
|
||||
img = Image.open(io.BytesIO(data))
|
||||
webp_buf = io.BytesIO()
|
||||
if getattr(img, 'n_frames', 1) > 1:
|
||||
img.save(webp_buf, format='WEBP', save_all=True, loop=0, quality=80)
|
||||
else:
|
||||
img.save(webp_buf, format='WEBP', quality=80)
|
||||
return webp_buf.getvalue()
|
||||
sticker_data = await asyncio.to_thread(_process_animated_sticker, sticker_data)
|
||||
ext = 'webp'
|
||||
except Exception: pass
|
||||
|
||||
files.append({"filename": f"sticker_{s.name}_{s.id}.{ext}", "data": sticker_data})
|
||||
stats["attachments"] += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}")
|
||||
|
||||
# Process Attachments
|
||||
attachments_to_process = list(msg.attachments)
|
||||
if is_forwarded and hasattr(msg, 'message_snapshots') and msg.message_snapshots:
|
||||
snapshot = msg.message_snapshots[0]
|
||||
if not content:
|
||||
content = snapshot.content
|
||||
attachments_to_process.extend(snapshot.attachments)
|
||||
|
||||
for att in attachments_to_process:
|
||||
try:
|
||||
att_data = await context.discord_reader.download_attachment(att)
|
||||
files.append({"filename": att.filename, "data": att_data})
|
||||
stats["attachments"] += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download attachment {att.filename}: {e}")
|
||||
|
||||
# Clean Mentions
|
||||
content = clean_mentions(
|
||||
content=content,
|
||||
guild=context.discord_reader.guild,
|
||||
user_mentions=msg.mentions,
|
||||
role_mentions=msg.role_mentions,
|
||||
channel_mentions=msg.channel_mentions,
|
||||
emoji_map=context.state.emoji_map,
|
||||
channel_map=context.state.channel_map,
|
||||
state=context.state,
|
||||
target_server_id=context.fluxer_writer.community_id,
|
||||
channel_names=context.channel_names if hasattr(context, 'channel_names') else None,
|
||||
anonymize_users=anonymize_users
|
||||
)
|
||||
|
||||
if not content and not files:
|
||||
return None
|
||||
|
||||
# Reply Resolution
|
||||
reply_to_fluxer_id = None
|
||||
if msg.reference and msg.reference.message_id:
|
||||
reply_to_fluxer_id = context.state.get_fluxer_message_id(target_channel_id, str(msg.reference.message_id))
|
||||
|
||||
# Fallback author tagging for replies if mapping not found
|
||||
if not reply_to_fluxer_id:
|
||||
try:
|
||||
source_ref_msg = await context.discord_reader.get_message(msg.channel.id, msg.reference.message_id)
|
||||
if source_ref_msg and source_ref_msg.author:
|
||||
ref_name = context.state.get_user_alias(str(source_ref_msg.author.id)) if anonymize_users else source_ref_msg.author.display_name
|
||||
content = f"`@{ref_name}`\n{content}"
|
||||
else:
|
||||
tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id)
|
||||
if tgt_reply: content = f"[Reply to {tgt_reply}]\n{content}"
|
||||
except Exception: pass
|
||||
|
||||
# Thread logic
|
||||
if not reply_to_fluxer_id and parent_target_id and stats["messages"] == 0:
|
||||
reply_to_fluxer_id = parent_target_id
|
||||
if thread_name and stats["messages"] == 0:
|
||||
content = f"> <<< THREAD: **{thread_name}** >>>\n{content}"
|
||||
|
||||
# Send Message
|
||||
if anonymize_users:
|
||||
author_name = alias or "Anonymized User"
|
||||
author_avatar_url = None
|
||||
else:
|
||||
author_name = msg.author.display_name
|
||||
author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None
|
||||
|
||||
fluxer_msg_id = await context.fluxer_writer.send_message(
|
||||
channel_id=target_channel_id,
|
||||
author_name=author_name,
|
||||
author_avatar_url=author_avatar_url,
|
||||
content=content,
|
||||
timestamp=int(msg.created_at.timestamp()),
|
||||
files=files if files else None,
|
||||
reply_to_message_id=reply_to_fluxer_id,
|
||||
is_forwarded=is_forwarded,
|
||||
embeds=msg.embeds
|
||||
)
|
||||
|
||||
if fluxer_msg_id:
|
||||
if thread_id:
|
||||
context.state.set_thread_message_mapping(target_channel_id, thread_id, str(msg.id), fluxer_msg_id)
|
||||
context.state.update_thread_last_message_timestamp(target_channel_id, thread_id, str(msg.created_at))
|
||||
context.state.update_thread_last_message_id(target_channel_id, thread_id, str(msg.id))
|
||||
context.state.increment_thread_stats(target_channel_id, thread_id, messages=1, files=len(files) if files else 0)
|
||||
else:
|
||||
context.state.set_message_mapping(target_channel_id, str(msg.id), fluxer_msg_id)
|
||||
context.state.update_last_message_timestamp(target_channel_id, str(msg.created_at))
|
||||
context.state.update_last_message_id(target_channel_id, str(msg.id))
|
||||
context.state.increment_stats(target_channel_id, messages=1, files=len(files) if files else 0)
|
||||
|
||||
stats["messages"] += 1
|
||||
stats["last_message_content"] = content
|
||||
stats["last_message_author"] = msg.author.display_name
|
||||
|
||||
return fluxer_msg_id
|
||||
|
||||
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None, processed_threads: set | None = None) -> Dict[str, int]:
|
||||
|
||||
"""
|
||||
Scans channel history to count messages, threads, and attachments.
|
||||
"""
|
||||
|
|
@ -725,30 +542,87 @@ async def migrate_messages(
|
|||
logger.debug(f"Added sticker {s.name} as attachment (extension: {ext}, size: {sticker_size} bytes)")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}")
|
||||
# Check for existing mapping to avoid duplicates when resuming
|
||||
|
||||
# Check for existing mapping to avoid duplicates when resuming
|
||||
if context.state.get_target_message_id(target_channel_id, str(msg.id)):
|
||||
continue
|
||||
|
||||
try:
|
||||
fluxer_msg_id = await _process_and_send_message(
|
||||
context=context,
|
||||
msg=msg,
|
||||
target_channel_id=target_channel_id,
|
||||
stats=stats,
|
||||
thread_id=thread_id,
|
||||
parent_target_id=parent_target_id,
|
||||
thread_name=thread_name,
|
||||
processed_threads=processed_threads
|
||||
reply_to_fluxer_id = None
|
||||
if msg.reference and msg.reference.message_id:
|
||||
reply_to_fluxer_id = context.state.get_fluxer_message_id(target_channel_id, str(msg.reference.message_id))
|
||||
if reply_to_fluxer_id:
|
||||
logger.debug(f"Detected reply to Discord ID {msg.reference.message_id} -> Fluxer ID {reply_to_fluxer_id}")
|
||||
else:
|
||||
logger.debug(f"Reply target Discord ID {msg.reference.message_id} not found in current session map.")
|
||||
|
||||
# If this is the FIRST thread message and we have a parent_target_id, force it as reply to the starter
|
||||
if not reply_to_fluxer_id and parent_target_id and stats["messages"] == 0:
|
||||
reply_to_fluxer_id = parent_target_id
|
||||
|
||||
# Prepend thread marker to the first message of the thread
|
||||
if thread_name and stats["messages"] == 0:
|
||||
content = f"> <<< THREAD: **{thread_name}** >>>\n{content}"
|
||||
|
||||
# Always ensure alias is created/retrieved to populate user_alias table
|
||||
alias = context.state.get_user_alias(str(msg.author.id))
|
||||
|
||||
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
|
||||
if anonymize_users:
|
||||
author_name = alias or "Anonymized User"
|
||||
author_avatar_url = None
|
||||
else:
|
||||
author_name = msg.author.display_name
|
||||
author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None
|
||||
|
||||
logger.debug(f"Fluxer: Calling send_message for Discord ID {msg.id}")
|
||||
fluxer_msg_id = await context.fluxer_writer.send_message(
|
||||
channel_id=target_channel_id,
|
||||
author_name=author_name,
|
||||
author_avatar_url=author_avatar_url,
|
||||
content=content,
|
||||
timestamp=int(msg.created_at.timestamp()),
|
||||
files=files if files else None,
|
||||
reply_to_message_id=reply_to_fluxer_id,
|
||||
is_forwarded=is_forwarded,
|
||||
embeds=msg.embeds
|
||||
)
|
||||
|
||||
# Check for associated thread (Individual mode recursion)
|
||||
if fluxer_msg_id:
|
||||
if thread_id:
|
||||
context.state.set_thread_message_mapping(target_channel_id, thread_id, str(msg.id), fluxer_msg_id)
|
||||
else:
|
||||
context.state.set_message_mapping(target_channel_id, str(msg.id), fluxer_msg_id)
|
||||
else:
|
||||
logger.warning(f"Fluxer: send_message returned None for Discord ID {msg.id} (message might have been skipped or timed out)")
|
||||
|
||||
if thread_id:
|
||||
context.state.update_thread_last_message_timestamp(target_channel_id, thread_id, str(msg.created_at))
|
||||
context.state.update_thread_last_message_id(target_channel_id, thread_id, str(msg.id))
|
||||
context.state.increment_thread_stats(target_channel_id, thread_id, messages=1, files=len(files) if files else 0)
|
||||
else:
|
||||
context.state.update_last_message_timestamp(target_channel_id, str(msg.created_at))
|
||||
context.state.update_last_message_id(target_channel_id, str(msg.id))
|
||||
context.state.increment_stats(target_channel_id, messages=1, files=len(files) if files else 0)
|
||||
|
||||
stats["messages"] += 1
|
||||
stats["last_message_content"] = content
|
||||
stats["last_message_author"] = msg.author.display_name
|
||||
|
||||
# Check for associated thread (Normal case: parent message is migrated)
|
||||
if hasattr(msg, 'thread') and msg.thread:
|
||||
thread = msg.thread
|
||||
if thread.id not in processed_threads:
|
||||
processed_threads.add(thread.id)
|
||||
# Track thread entry
|
||||
stats["threads"] += 1
|
||||
|
||||
# Fetch last migrated message ID for this thread
|
||||
thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(thread.id))
|
||||
if thread_after_id:
|
||||
logger.info(f"Resuming thread '{thread.name}' from after message ID: {thread_after_id}")
|
||||
|
||||
# Migrate thread messages recursively
|
||||
thread_stats = await migrate_messages(
|
||||
context=context,
|
||||
source_channel_id=thread.id,
|
||||
|
|
@ -763,19 +637,22 @@ async def migrate_messages(
|
|||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
# Send End Marker
|
||||
if context.is_running:
|
||||
await context.fluxer_writer.send_marker(
|
||||
channel_id=target_channel_id,
|
||||
content=f"> <<< END OF THREAD >>>"
|
||||
)
|
||||
|
||||
# Update Link Tracking (Parent pointer updates)
|
||||
# Update Link Tracking (but prevent threaded messages from overwriting the parent channel pointers)
|
||||
# The 'after_message_id' param usually means it's the main function call and not a thread recursive call
|
||||
if not stats["first_message_url"]:
|
||||
stats["first_message_url"] = msg.jump_url
|
||||
stats["last_message_url"] = msg.jump_url
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(stats)
|
||||
logger.debug(f"Fluxer: Finished processing message Discord ID {msg.id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process message {msg.id}: {e}")
|
||||
import traceback
|
||||
|
|
@ -858,7 +735,7 @@ async def migrate_global_messages(
|
|||
progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Migrates messages across all channels chronologically to Fluxer.
|
||||
Migrates messages across all channels chronologically.
|
||||
"""
|
||||
stats = {
|
||||
"messages": 0,
|
||||
|
|
@ -873,6 +750,14 @@ async def migrate_global_messages(
|
|||
processed_threads = set()
|
||||
logger.info("Starting Global Waterfall Migration for Fluxer...")
|
||||
|
||||
# Keep track of active thread mapping natively to pass parent target IDs if needed
|
||||
thread_to_target_channel = {}
|
||||
|
||||
# Emojis and mapped users cache setup
|
||||
emoji_map = context.state.emoji_map
|
||||
db_media = context.discord_reader.db.get_all_media() if context.discord_reader.db else {}
|
||||
target_server_id = getattr(context.fluxer_writer, "server_id", None)
|
||||
|
||||
# Fetch global progress map to skip migrated messages efficiently
|
||||
progress_map = context.state.get_all_last_message_ids()
|
||||
|
||||
|
|
@ -909,19 +794,124 @@ async def migrate_global_messages(
|
|||
continue
|
||||
|
||||
# If it's a thread message, we need to handle it based on if it's the thread starter or a reply
|
||||
parent_target_id = None
|
||||
if hasattr(msg, 'thread') and msg.thread and msg.id == msg.thread.id:
|
||||
processed_threads.add(msg.thread.id)
|
||||
stats["threads"] += 1
|
||||
elif msg.channel.type in [11, 12]: # Thread channels
|
||||
# It's a message IN a thread.
|
||||
# In Fluxer, threads might just be linear messages or threaded replies depending on schema
|
||||
# For basic migration we just send it to the parent mapped target channel.
|
||||
# The parent mapped target channel ID should already be calculated correctly by get_target_channel_id (which returns mapped thread or parent channel)
|
||||
pass
|
||||
|
||||
# Formatting
|
||||
files = []
|
||||
file_names = []
|
||||
|
||||
# Always ensure alias is created/retrieved to populate user_alias table
|
||||
alias = context.state.get_user_alias(str(msg.author.id))
|
||||
|
||||
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
|
||||
|
||||
if anonymize_users:
|
||||
author_name = alias or "Anonymized User"
|
||||
author_avatar_url = None
|
||||
else:
|
||||
author_name = msg.author.display_name
|
||||
author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None
|
||||
|
||||
for att in msg.attachments:
|
||||
media_info = db_media.get(att.local_hash) if db_media else None
|
||||
local_path = None
|
||||
if media_info:
|
||||
local_path = Path(media_info["local_path"])
|
||||
elif hasattr(att, 'read'):
|
||||
# Fallback
|
||||
pass
|
||||
|
||||
if local_path and local_path.exists():
|
||||
files.append(local_path)
|
||||
file_names.append(att.filename)
|
||||
|
||||
content = msg.content or ""
|
||||
|
||||
# Stickers
|
||||
for sticker in msg.stickers:
|
||||
sticker_name = sticker.name
|
||||
sticker_url = sticker.url
|
||||
|
||||
# Check for uploaded media pool logic first
|
||||
s_hash = sticker.local_hash
|
||||
sticker_file = None
|
||||
s_media = db_media.get(s_hash) if db_media and s_hash else None
|
||||
if s_media:
|
||||
s_path = Path(s_media["local_path"])
|
||||
if s_path.exists():
|
||||
sticker_file = s_path
|
||||
|
||||
content += f"\n[Sticker: {sticker_name}]"
|
||||
if sticker_file:
|
||||
files.append(sticker_file)
|
||||
file_names.append(f"sticker_{sticker_name}.png")
|
||||
|
||||
content = clean_mentions(
|
||||
content=content,
|
||||
guild=context.discord_reader.guild,
|
||||
user_mentions=msg.mentions,
|
||||
role_mentions=msg.role_mentions,
|
||||
channel_mentions=msg.channel_mentions,
|
||||
emoji_map=emoji_map,
|
||||
channel_map=context.state.channel_map,
|
||||
state=context.state,
|
||||
target_server_id=target_server_id,
|
||||
channel_names=context.channel_names if hasattr(context, 'channel_names') else None,
|
||||
anonymize_users=anonymize_users
|
||||
)
|
||||
|
||||
if not content and not files:
|
||||
logger.debug(f"Message {msg.id} empty after processing, skipping.")
|
||||
continue
|
||||
|
||||
timestamp_int = int(msg.created_at.timestamp())
|
||||
|
||||
if msg.reference and msg.reference.message_id:
|
||||
# Resolve the author of the message being replied to
|
||||
source_ref_msg = await context.discord_reader.get_message(msg.channel.id, msg.reference.message_id)
|
||||
if source_ref_msg and source_ref_msg.author:
|
||||
ref_author_id = str(source_ref_msg.author.id)
|
||||
if anonymize_users:
|
||||
ref_name = context.state.get_user_alias(ref_author_id) or "Anonymized User"
|
||||
else:
|
||||
ref_name = source_ref_msg.author.display_name
|
||||
content = f"`@{ref_name}`\n{content}"
|
||||
else:
|
||||
# Fallback if author cannot be resolved (e.g. deleted/missing from backup)
|
||||
tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id)
|
||||
if tgt_reply:
|
||||
content = f"[Reply to {tgt_reply}]\n{content}"
|
||||
|
||||
try:
|
||||
await _process_and_send_message(
|
||||
context=context,
|
||||
msg=msg,
|
||||
target_channel_id=target_channel_id,
|
||||
stats=stats,
|
||||
processed_threads=processed_threads
|
||||
fluxer_msg_id = await context.fluxer_writer.send_message(
|
||||
channel_id=target_channel_id,
|
||||
author_name=author_name,
|
||||
author_avatar_url=author_avatar_url,
|
||||
content=content,
|
||||
files=files,
|
||||
timestamp=timestamp_int,
|
||||
embeds=msg.embeds
|
||||
)
|
||||
|
||||
if fluxer_msg_id:
|
||||
context.state.set_target_message_mapping(target_channel_id, msg.id, fluxer_msg_id)
|
||||
context.state.update_last_message_id(target_channel_id, msg.id)
|
||||
context.state.set_waterfall_last_id(msg.id)
|
||||
stats["attachments"] += len(files) if files else 0
|
||||
|
||||
stats["messages"] += 1
|
||||
stats["last_message_content"] = content
|
||||
stats["last_message_author"] = author_name
|
||||
|
||||
if not stats["first_message_url"]:
|
||||
stats["first_message_url"] = msg.jump_url
|
||||
stats["last_message_url"] = msg.jump_url
|
||||
|
|
|
|||
|
|
@ -36,7 +36,6 @@ class FluxerWriter:
|
|||
guilds_list.append((label, str(g.id)))
|
||||
return guilds_list
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch Fluxer communities via HTTP: {e}")
|
||||
logger.error(f"Failed to fetch Fluxer communities via HTTP: {e}")
|
||||
raise
|
||||
|
||||
|
|
@ -62,7 +61,6 @@ class FluxerWriter:
|
|||
return w
|
||||
except Exception as e:
|
||||
print(f"Failed to manage webhook for channel {channel_id}: {e}")
|
||||
logger.error(f"Failed to manage webhook for channel {channel_id}: {e}")
|
||||
return None
|
||||
|
||||
async def start(self):
|
||||
|
|
@ -324,7 +322,6 @@ class FluxerWriter:
|
|||
logger.debug(f"Fluxer: Webhook send complete, msg_id={msg.id if msg else 'None'}")
|
||||
return str(msg.id) if msg else None
|
||||
except asyncio.TimeoutError:
|
||||
print(f"Fluxer: Webhook send timed out after 45s for channel {channel_id}")
|
||||
logger.error(f"Fluxer: Webhook send timed out after 45s for channel {channel_id}")
|
||||
return None
|
||||
else:
|
||||
|
|
@ -343,24 +340,19 @@ class FluxerWriter:
|
|||
|
||||
logger.debug(f"Fluxer: Sending message via bot for user '{author_name}'")
|
||||
try:
|
||||
kwargs = {
|
||||
"channel_id": channel_id,
|
||||
"content": final_bot_content,
|
||||
"embeds": normalized_embeds
|
||||
}
|
||||
if fluxer_files:
|
||||
kwargs["files"] = fluxer_files
|
||||
if message_reference:
|
||||
kwargs["message_reference"] = message_reference
|
||||
|
||||
msg_data = await asyncio.wait_for(
|
||||
self.client.send_message(**kwargs),
|
||||
self.client.send_message(
|
||||
channel_id=channel_id,
|
||||
content=final_bot_content,
|
||||
files=fluxer_files,
|
||||
embeds=normalized_embeds,
|
||||
message_reference=message_reference
|
||||
),
|
||||
timeout=45.0
|
||||
)
|
||||
logger.debug(f"Fluxer: Bot send complete, msg_id={msg_data.get('id') if msg_data else 'None'}")
|
||||
return str(msg_data["id"]) if msg_data else None
|
||||
except asyncio.TimeoutError:
|
||||
print(f"Fluxer: Bot send timed out after 45s for channel {channel_id}")
|
||||
logger.error(f"Fluxer: Bot send timed out after 45s for channel {channel_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
|
|
@ -378,27 +370,22 @@ class FluxerWriter:
|
|||
|
||||
fluxer_files = None
|
||||
if files:
|
||||
fluxer_files = [f if hasattr(f, "filename") else File(io.BytesIO(f["data"]), filename=f["filename"]) for f in files]
|
||||
fluxer_files = [File(io.BytesIO(f["data"]), filename=f["filename"]) for f in files]
|
||||
|
||||
message_reference = None
|
||||
if reply_to_message_id:
|
||||
message_reference = {"message_id": str(reply_to_message_id), "channel_id": str(channel_id)}
|
||||
|
||||
try:
|
||||
kwargs = {
|
||||
"channel_id": channel_id,
|
||||
"content": content
|
||||
}
|
||||
if fluxer_files:
|
||||
kwargs["files"] = fluxer_files
|
||||
if message_reference:
|
||||
kwargs["message_reference"] = message_reference
|
||||
|
||||
msg_data = await self.client.send_message(**kwargs)
|
||||
msg_data = await self.client.send_message(
|
||||
channel_id=channel_id,
|
||||
content=content,
|
||||
files=fluxer_files,
|
||||
message_reference=message_reference
|
||||
)
|
||||
return str(msg_data["id"]) if msg_data else None
|
||||
except Exception as e:
|
||||
print(f"Failed to send marker: {e}")
|
||||
logger.error(f"Failed to send marker: {e}")
|
||||
return None
|
||||
|
||||
async def create_role(self, name: str, color: int, hoist: bool, mentionable: bool, permissions: int, position: Optional[int] = None) -> str:
|
||||
|
|
@ -421,7 +408,6 @@ class FluxerWriter:
|
|||
return str(role["id"])
|
||||
except Exception as e:
|
||||
print(f"Failed to copy role {name}: {e}")
|
||||
logger.error(f"Failed to copy role {name}: {e}")
|
||||
return ""
|
||||
|
||||
async def create_emoji(self, name: str, image_bytes: bytes) -> str:
|
||||
|
|
@ -487,7 +473,6 @@ class FluxerWriter:
|
|||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to update community metadata: {e}")
|
||||
logger.error(f"Failed to update community metadata: {e}")
|
||||
|
||||
async def remove_community_logo_and_banner(self) -> dict:
|
||||
"""
|
||||
|
|
@ -518,7 +503,6 @@ class FluxerWriter:
|
|||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to remove community icon: {e}")
|
||||
logger.error(f"Failed to remove community icon: {e}")
|
||||
|
||||
# 3. Remove banner if set
|
||||
if has_banner:
|
||||
|
|
@ -529,7 +513,6 @@ class FluxerWriter:
|
|||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to remove community banner: {e}")
|
||||
logger.error(f"Failed to remove community banner: {e}")
|
||||
|
||||
return {
|
||||
"icon": "REMOVED" if has_icon else "SKIP",
|
||||
|
|
@ -561,7 +544,6 @@ class FluxerWriter:
|
|||
await progress_callback(ch.get("name", "Unknown"), deleted, total)
|
||||
except Exception as e:
|
||||
print(f"Failed to delete channel {ch.get('name')}: {e}")
|
||||
logger.error(f"Failed to delete channel {ch.get('name')}: {e}")
|
||||
return deleted
|
||||
|
||||
async def reset_channel_permissions(self, progress_callback=None) -> int:
|
||||
|
|
@ -594,14 +576,12 @@ class FluxerWriter:
|
|||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to delete overwrite {ow['id']} for channel {ch['id']}: {e}")
|
||||
logger.error(f"Failed to delete overwrite {ow['id']} for channel {ch['id']}: {e}")
|
||||
processed += 1
|
||||
if progress_callback:
|
||||
await progress_callback(ch.get("name", "Unknown"), processed, total)
|
||||
except Exception as e:
|
||||
print(f"Failed to reset permissions for channel {ch.get('name')}: {e}")
|
||||
logger.error(f"Failed to reset permissions for channel {ch.get('name')}: {e}")
|
||||
return processed
|
||||
|
||||
async def set_channel_permission(self, channel_id: str, overwrite_id: str, allow: int, deny: int, is_role: bool = True):
|
||||
|
|
@ -623,7 +603,6 @@ class FluxerWriter:
|
|||
type=0 if is_role else 1
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to set permission on channel {channel_id} for overwrite {overwrite_id}: {e}")
|
||||
logger.error(f"Failed to set permission on channel {channel_id} for overwrite {overwrite_id}: {e}")
|
||||
|
||||
|
||||
|
|
@ -665,7 +644,6 @@ class FluxerWriter:
|
|||
await progress_callback(role.get("name", "Unknown"), deleted, total)
|
||||
except Exception as e:
|
||||
print(f"Failed to delete role {role.get('name')}: {e}")
|
||||
logger.error(f"Failed to delete role {role.get('name')}: {e}")
|
||||
return deleted
|
||||
|
||||
async def delete_all_emojis_and_stickers(self, progress_callback=None) -> dict:
|
||||
|
|
@ -689,10 +667,8 @@ class FluxerWriter:
|
|||
await progress_callback(emoji.get("name", "Unknown"), "Emoji", emoji_deleted, emoji_total)
|
||||
except Exception as e:
|
||||
print(f"Failed to delete emoji {emoji.get('name')}: {e}")
|
||||
logger.error(f"Failed to delete emoji {emoji.get('name')}: {e}")
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch emojis: {e}")
|
||||
logger.error(f"Failed to fetch emojis: {e}")
|
||||
|
||||
# Delete stickers
|
||||
try:
|
||||
|
|
@ -706,10 +682,8 @@ class FluxerWriter:
|
|||
await progress_callback(sticker.get("name", "Unknown"), "Sticker", sticker_deleted, sticker_total)
|
||||
except Exception as e:
|
||||
print(f"Failed to delete sticker {sticker.get('name')}: {e}")
|
||||
logger.error(f"Failed to delete sticker {sticker.get('name')}: {e}")
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch stickers: {e}")
|
||||
logger.error(f"Failed to fetch stickers: {e}")
|
||||
|
||||
return {"emojis": emoji_deleted, "stickers": sticker_deleted}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,52 +16,41 @@ async def sync_channel_state(context: MigrationContext):
|
|||
channels = await context.discord_reader.get_channels()
|
||||
target_channels = await context.writer.get_channels()
|
||||
|
||||
# Build maps for Stoat lookup
|
||||
# {name: id} for categories (type 4)
|
||||
target_cats = {c.get("name"): str(c.get("id")) for c in target_channels if c.get("type") == 4}
|
||||
# {parent_id: {name: id}} for channels
|
||||
target_structure = {}
|
||||
for c in target_channels:
|
||||
if c.get("type") == 4: continue
|
||||
p_id = str(c.get("parent_id")) if c.get("parent_id") else "root"
|
||||
if p_id not in target_structure: target_structure[p_id] = {}
|
||||
target_structure[p_id][c.get("name")] = str(c.get("id"))
|
||||
|
||||
# Build name -> id map and ID set for Stoat for fast lookup
|
||||
target_name_map = {c.get("name"): str(c.get("id")) for c in target_channels if c.get("name")}
|
||||
target_id_set = {str(c.get("id")) for c in target_channels}
|
||||
|
||||
updates = 0
|
||||
removals = 0
|
||||
|
||||
# 1. Sync Categories
|
||||
# 1. Verify and Sync Categories
|
||||
for cat in categories:
|
||||
discord_id = str(cat.id)
|
||||
target_id = context.state.get_target_category_id(discord_id)
|
||||
|
||||
if target_id:
|
||||
if target_id not in target_id_set:
|
||||
context.state.remove_category_mapping(discord_id)
|
||||
removals += 1
|
||||
elif cat.name in target_cats:
|
||||
context.state.set_target_category_mapping(discord_id, target_cats[cat.name])
|
||||
elif cat.name in target_name_map:
|
||||
context.state.set_target_category_mapping(discord_id, target_name_map[cat.name])
|
||||
updates += 1
|
||||
|
||||
# 2. Sync Channels (parent-aware)
|
||||
# 2. Verify and Sync Channels
|
||||
for ch in channels:
|
||||
discord_id = str(ch.id)
|
||||
target_id = context.state.get_target_channel_id(discord_id)
|
||||
|
||||
if target_id:
|
||||
if target_id not in target_id_set:
|
||||
context.state.remove_channel_mapping(discord_id)
|
||||
removals += 1
|
||||
else:
|
||||
# Try to match by name within the mapped parent category
|
||||
p_discord_id = str(ch.category_id) if ch.category_id else "root"
|
||||
p_target_id = context.state.get_target_category_id(p_discord_id) if p_discord_id != "root" else "root"
|
||||
|
||||
if p_target_id in target_structure and ch.name in target_structure[p_target_id]:
|
||||
context.state.set_target_channel_mapping(discord_id, target_structure[p_target_id][ch.name])
|
||||
updates += 1
|
||||
elif ch.name in target_name_map:
|
||||
context.state.set_target_channel_mapping(discord_id, target_name_map[ch.name])
|
||||
updates += 1
|
||||
|
||||
if updates > 0 or removals > 0:
|
||||
logger.info(f"Stoat Channel sync: {updates} mapped, {removals} stale mappings removed")
|
||||
logger.info(f"Channel sync: {updates} mapped, {removals} stale mappings removed")
|
||||
|
||||
|
||||
async def migrate_channels(context: MigrationContext, progress_callback: Callable[[str, str, int, int], Awaitable[None]] | None = None, force: bool = False) -> dict:
|
||||
|
|
@ -118,22 +107,7 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
|||
if total == 0:
|
||||
return cloned_info
|
||||
|
||||
# 1. Create missing categories first
|
||||
for cat in missing_categories:
|
||||
if not context.is_running: break
|
||||
|
||||
state_key = str(cat.id)
|
||||
target_id = await context.writer.create_channel(cat.name, type=4)
|
||||
if target_id:
|
||||
context.state.set_target_category_id(state_key, target_id)
|
||||
cloned_info["categories_created"].append(cat.name)
|
||||
if cat.name not in cloned_info["structure"]:
|
||||
cloned_info["structure"][cat.name] = []
|
||||
|
||||
current_idx += 1
|
||||
if progress_callback: await progress_callback(cat.name, "Copying", current_idx, total)
|
||||
|
||||
# 2. Create missing channels (now with parent_id available)
|
||||
# 1. Create missing channels (unparented for now)
|
||||
for channel in channels_to_create:
|
||||
if not context.is_running: break
|
||||
|
||||
|
|
@ -145,6 +119,7 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
|||
logger.debug(f"Creating channel {channel.name}: topic={topic}, nsfw={nsfw}, slowmode={slowmode}")
|
||||
|
||||
# Map Discord-specific types to target-supported types
|
||||
# 5 (News) -> 0 (Text), and fallback any unknown non-voice types to text
|
||||
raw_type = channel.type.value if hasattr(channel.type, 'value') else 0
|
||||
if raw_type == context.discord_reader.CHANNEL_TYPE_VOICE.value:
|
||||
ch_type = 2
|
||||
|
|
@ -153,22 +128,20 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
|||
ch_type = 0
|
||||
is_voice = False
|
||||
else:
|
||||
# Fallback for Stage channels (13) etc. to Text for safety
|
||||
ch_type = 0
|
||||
is_voice = False
|
||||
|
||||
# Resolve parent category
|
||||
parent_id = context.state.get_target_category_id(str(channel.category_id)) if channel.category_id else None
|
||||
|
||||
target_id = await context.writer.create_channel(
|
||||
name=channel.name,
|
||||
topic=topic if not is_voice else "",
|
||||
type=ch_type,
|
||||
parent_id=parent_id,
|
||||
parent_id=None,
|
||||
nsfw=nsfw if not is_voice else False,
|
||||
slowmode_delay=slowmode if not is_voice else 0
|
||||
)
|
||||
if target_id:
|
||||
context.state.set_target_channel_id(state_key, target_id)
|
||||
context.state.set_target_channel_mapping(state_key, target_id)
|
||||
cloned_info["channels_created"].append(channel.name)
|
||||
|
||||
parent_name = cat_name_map.get(str(channel.category_id), "No Category") if channel.category_id else "No Category"
|
||||
|
|
@ -183,8 +156,7 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
|||
name=channel.name,
|
||||
topic=topic,
|
||||
nsfw=nsfw,
|
||||
slowmode_delay=slowmode,
|
||||
parent_id=parent_id
|
||||
slowmode_delay=slowmode
|
||||
)
|
||||
|
||||
current_idx += 1
|
||||
|
|
@ -200,21 +172,32 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
|||
|
||||
logger.debug(f"Syncing existing channel {channel.name} ({target_id}): topic={topic}, nsfw={nsfw}, slowmode={slowmode}")
|
||||
|
||||
parent_id = context.state.get_target_category_id(str(channel.category_id)) if channel.category_id else None
|
||||
|
||||
await context.writer.modify_channel(
|
||||
channel_id=target_id,
|
||||
name=channel.name,
|
||||
topic=topic,
|
||||
nsfw=nsfw,
|
||||
slowmode_delay=slowmode,
|
||||
parent_id=parent_id
|
||||
slowmode_delay=slowmode
|
||||
)
|
||||
|
||||
cloned_info["channels_synced"].append(channel.name)
|
||||
|
||||
current_idx += 1
|
||||
if progress_callback: await progress_callback(channel.name, "Syncing", current_idx, total)
|
||||
|
||||
# 3. Create missing categories
|
||||
for cat in missing_categories:
|
||||
if not context.is_running: break
|
||||
|
||||
state_key = str(cat.id)
|
||||
target_id = await context.writer.create_channel(cat.name, type=4)
|
||||
if target_id:
|
||||
context.state.set_target_category_mapping(state_key, target_id)
|
||||
cloned_info["categories_created"].append(cat.name)
|
||||
if cat.name not in cloned_info["structure"]:
|
||||
cloned_info["structure"][cat.name] = []
|
||||
|
||||
current_idx += 1
|
||||
if progress_callback: await progress_callback(f"Cat: {cat.name}", "Copying", current_idx, total)
|
||||
|
||||
# 4. Final step: Parent the channels into categories via mass server.edit()
|
||||
|
|
@ -240,7 +223,7 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
|||
|
||||
# Resolve Stoat categories
|
||||
# We iterate over the categories from the server to ensure we don't drop any
|
||||
for stoat_cat in (server.categories or []):
|
||||
for stoat_cat in server.categories:
|
||||
# Check if this Stoat category maps to any Discord category
|
||||
discord_cat_id = next((d_id for d_id, t_id in context.state.category_map.items() if t_id == str(stoat_cat.id)), None)
|
||||
|
||||
|
|
@ -278,35 +261,8 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
|||
target_categories.sort(key=get_cat_position)
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
api_base = (context.writer.api_url or "https://api.stoat.chat/0.8").rstrip("/")
|
||||
cats_payload = []
|
||||
for c in target_categories:
|
||||
entry: dict = {"id": str(c.id), "title": c.title, "channels": [str(x) for x in c.channels]}
|
||||
if getattr(c, "default_permissions", None) is not None:
|
||||
entry["default_permissions"] = c.default_permissions
|
||||
if getattr(c, "role_permissions", None):
|
||||
entry["role_permissions"] = c.role_permissions
|
||||
cats_payload.append(entry)
|
||||
for _attempt in range(5):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.patch(
|
||||
f"{api_base}/servers/{context.writer.community_id}",
|
||||
json={"categories": cats_payload},
|
||||
headers={"X-Bot-Token": context.writer.token.strip(), "Content-Type": "application/json"},
|
||||
) as resp:
|
||||
if resp.status == 429:
|
||||
import re as _re
|
||||
body = await resp.json(content_type=None)
|
||||
wait_ms = body.get("retry_after", 10000)
|
||||
logger.warning(f"Rate limited on mass parent, retrying in {wait_ms/1000:.1f}s…")
|
||||
await asyncio.sleep(wait_ms / 1000 + 0.5)
|
||||
continue
|
||||
if resp.status not in (200, 204):
|
||||
body = await resp.json(content_type=None)
|
||||
raise Exception(f"HTTP {resp.status}: {body}")
|
||||
logger.info("Successfully parented all channels.")
|
||||
break
|
||||
await server.edit(categories=target_categories)
|
||||
logger.info("Successfully parented all channels.")
|
||||
except Exception as ex:
|
||||
logger.error(f"Failed to mass parent channels: {ex}")
|
||||
|
||||
|
|
|
|||
|
|
@ -155,192 +155,7 @@ async def get_channel_threads(reader: Any, channel_id: int) -> List[Any]:
|
|||
return threads
|
||||
|
||||
|
||||
async def _process_and_send_message(
|
||||
context: MigrationContext,
|
||||
msg: Any,
|
||||
target_channel_id: str,
|
||||
stats: Dict[str, Any],
|
||||
thread_id: str | None = None,
|
||||
parent_target_id: str | None = None,
|
||||
thread_name: str | None = None,
|
||||
processed_threads: set | None = None
|
||||
) -> str | None:
|
||||
"""
|
||||
Internal helper to process a single Discord message (mentions, attachments, stickers)
|
||||
and send it to the Stoat platform.
|
||||
"""
|
||||
# 1. Processing Flags
|
||||
is_forwarded = False
|
||||
if hasattr(msg.flags, 'forwarded'):
|
||||
is_forwarded = msg.flags.forwarded
|
||||
|
||||
# 2. Content & Formatting
|
||||
content = msg.content or ""
|
||||
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
|
||||
alias = context.state.get_user_alias(str(msg.author.id))
|
||||
|
||||
# Process Stickers
|
||||
files = []
|
||||
if hasattr(msg, 'stickers') and msg.stickers:
|
||||
for s in msg.stickers:
|
||||
try:
|
||||
sticker_data = await context.discord_reader.download_sticker(s)
|
||||
if not sticker_data: continue
|
||||
|
||||
format_val = getattr(s, 'format', 'png')
|
||||
if hasattr(format_val, 'name'):
|
||||
ext = format_val.name.lower()
|
||||
elif isinstance(format_val, int):
|
||||
format_map = {1: 'png', 2: 'apng', 3: 'lottie', 4: 'gif'}
|
||||
ext = format_map.get(format_val, 'png')
|
||||
else:
|
||||
ext = str(format_val).lower()
|
||||
|
||||
# Conversion logic for Stoat (WebP or GIF focus)
|
||||
if ext == 'lottie' and HAS_LOTTIE:
|
||||
try:
|
||||
lottie_data = json.loads(sticker_data)
|
||||
def _convert_lottie_to_gif(data):
|
||||
animation = Animation.load(data)
|
||||
output = io.BytesIO()
|
||||
export_gif(animation, output)
|
||||
return output.getvalue()
|
||||
sticker_data = await asyncio.to_thread(_convert_lottie_to_gif, lottie_data)
|
||||
ext = 'gif'
|
||||
except Exception: ext = 'json'
|
||||
elif ext == 'apng':
|
||||
try:
|
||||
from PIL import Image
|
||||
def _convert_apng_to_gif(data):
|
||||
img = Image.open(io.BytesIO(data))
|
||||
gif_buf = io.BytesIO()
|
||||
if getattr(img, 'n_frames', 1) > 1:
|
||||
frames = []
|
||||
durations = []
|
||||
for i in range(img.n_frames):
|
||||
img.seek(i)
|
||||
frame = img.convert('RGBA')
|
||||
current_frame = Image.new('RGBA', img.size, (0,0,0,0))
|
||||
current_frame.paste(frame, (0, 0))
|
||||
frames.append(current_frame)
|
||||
durations.append(img.info.get('duration', 100))
|
||||
frames[0].save(gif_buf, format='GIF', save_all=True, append_images=frames[1:], loop=0, duration=durations, disposal=2, transparency=0)
|
||||
else: img.save(gif_buf, format='GIF')
|
||||
return gif_buf.getvalue()
|
||||
sticker_data = await asyncio.to_thread(_convert_apng_to_gif, sticker_data)
|
||||
ext = 'gif'
|
||||
except Exception: pass
|
||||
|
||||
files.append({
|
||||
"filename": f"sticker_{s.name}_{s.id}.{ext}",
|
||||
"data": sticker_data,
|
||||
"content_type": f"image/{ext}" if ext != "json" else "application/json"
|
||||
})
|
||||
stats["attachments"] += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}")
|
||||
|
||||
# Process Attachments
|
||||
attachments_to_process = list(msg.attachments)
|
||||
if is_forwarded and hasattr(msg, 'message_snapshots') and msg.message_snapshots:
|
||||
snapshot = msg.message_snapshots[0]
|
||||
if not content: content = snapshot.content
|
||||
attachments_to_process.extend(snapshot.attachments)
|
||||
|
||||
for att in attachments_to_process:
|
||||
try:
|
||||
att_data = await context.discord_reader.download_attachment(att)
|
||||
files.append({
|
||||
"filename": att.filename,
|
||||
"data": att_data,
|
||||
"content_type": getattr(att, "content_type", None)
|
||||
})
|
||||
stats["attachments"] += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download attachment {att.filename}: {e}")
|
||||
|
||||
# Clean Mentions
|
||||
content = clean_mentions(
|
||||
content=content,
|
||||
guild=context.discord_reader.guild,
|
||||
user_mentions=msg.mentions,
|
||||
role_mentions=msg.role_mentions,
|
||||
channel_mentions=msg.channel_mentions,
|
||||
emoji_map=context.state.emoji_map,
|
||||
channel_map=context.state.channel_map,
|
||||
state=context.state,
|
||||
target_server_id=context.stoat_writer.community_id,
|
||||
channel_names=context.channel_names if hasattr(context, 'channel_names') else None,
|
||||
anonymize_users=anonymize_users
|
||||
)
|
||||
|
||||
if not content and not files:
|
||||
return None
|
||||
|
||||
# Reply Resolution
|
||||
reply_to_stoat_id = None
|
||||
if msg.reference and msg.reference.message_id:
|
||||
reply_to_stoat_id = context.state.get_target_message_id(target_channel_id, str(msg.reference.message_id))
|
||||
if not reply_to_stoat_id:
|
||||
# Fallback author tagging
|
||||
try:
|
||||
source_ref_msg = await context.discord_reader.get_message(msg.channel.id, msg.reference.message_id)
|
||||
if source_ref_msg and source_ref_msg.author:
|
||||
ref_name = context.state.get_user_alias(str(source_ref_msg.author.id)) if anonymize_users else source_ref_msg.author.display_name
|
||||
content = f"`@{ref_name}`\n{content}"
|
||||
else:
|
||||
tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id)
|
||||
if tgt_reply: content = f"[Reply to {tgt_reply}]\n{content}"
|
||||
except Exception: pass
|
||||
|
||||
# Thread logic
|
||||
if not reply_to_stoat_id and parent_target_id and stats["messages"] == 0:
|
||||
reply_to_stoat_id = parent_target_id
|
||||
if thread_name and stats["messages"] == 0:
|
||||
content = f"> <<< THREAD: **{thread_name}** >>>\n{content}"
|
||||
|
||||
# Author resolution
|
||||
if anonymize_users:
|
||||
author_name = alias or "Anonymized User"
|
||||
author_avatar_url = None
|
||||
else:
|
||||
author_name = msg.author.display_name
|
||||
author_avatar_url = str(msg.author.display_avatar.url) if msg.author.display_avatar.url else None
|
||||
if author_avatar_url and not author_avatar_url.startswith("http"): author_avatar_url = None
|
||||
|
||||
# Send Message
|
||||
stoat_msg_id = await context.stoat_writer.send_message(
|
||||
channel_id=target_channel_id,
|
||||
author_name=author_name,
|
||||
author_avatar_url=author_avatar_url,
|
||||
content=content,
|
||||
timestamp=int(msg.created_at.timestamp()),
|
||||
files=files if files else None,
|
||||
reply_to_message_id=reply_to_stoat_id,
|
||||
is_forwarded=is_forwarded,
|
||||
embeds=msg.embeds
|
||||
)
|
||||
|
||||
if stoat_msg_id:
|
||||
if thread_id:
|
||||
context.state.set_thread_message_mapping(target_channel_id, thread_id, str(msg.id), stoat_msg_id)
|
||||
context.state.update_thread_last_message_timestamp(target_channel_id, thread_id, str(msg.created_at))
|
||||
context.state.update_thread_last_message_id(target_channel_id, thread_id, str(msg.id))
|
||||
context.state.increment_thread_stats(target_channel_id, thread_id, messages=1, files=len(files) if files else 0)
|
||||
else:
|
||||
context.state.set_message_mapping(target_channel_id, str(msg.id), stoat_msg_id)
|
||||
context.state.update_last_message_timestamp(target_channel_id, str(msg.created_at))
|
||||
context.state.update_last_message_id(target_channel_id, str(msg.id))
|
||||
context.state.increment_stats(target_channel_id, messages=1, files=len(files) if files else 0)
|
||||
|
||||
stats["messages"] += 1
|
||||
stats["last_message_content"] = content
|
||||
stats["last_message_author"] = msg.author.display_name
|
||||
|
||||
return stoat_msg_id
|
||||
|
||||
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None, processed_threads: set | None = None) -> Dict[str, int]:
|
||||
|
||||
"""
|
||||
Scans channel history to count messages, threads, and attachments.
|
||||
"""
|
||||
|
|
@ -731,30 +546,87 @@ async def migrate_messages(
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}")
|
||||
|
||||
# Check for existing mapping to avoid duplicates when resuming
|
||||
if context.state.get_target_message_id(target_channel_id, str(msg.id)):
|
||||
continue
|
||||
|
||||
try:
|
||||
stoat_msg_id = await _process_and_send_message(
|
||||
context=context,
|
||||
msg=msg,
|
||||
target_channel_id=target_channel_id,
|
||||
stats=stats,
|
||||
thread_id=thread_id,
|
||||
parent_target_id=parent_target_id,
|
||||
thread_name=thread_name,
|
||||
processed_threads=processed_threads
|
||||
# Check for existing mapping to avoid duplicates when resuming
|
||||
if context.state.get_target_message_id(target_channel_id, str(msg.id)):
|
||||
continue
|
||||
|
||||
# Check if this message is a reply
|
||||
reply_to_stoat_id = None
|
||||
if msg.reference and msg.reference.message_id:
|
||||
reply_to_stoat_id = context.state.get_target_message_id(target_channel_id, str(msg.reference.message_id))
|
||||
if reply_to_stoat_id:
|
||||
logger.debug(f"Detected reply to Discord ID {msg.reference.message_id} -> Stoat ID {reply_to_stoat_id}")
|
||||
else:
|
||||
logger.debug(f"Reply target Discord ID {msg.reference.message_id} not found in current session map.")
|
||||
|
||||
# If this is the FIRST thread message and we have a parent_target_id, force it as reply to the starter
|
||||
if not reply_to_stoat_id and parent_target_id and stats["messages"] == 0:
|
||||
reply_to_stoat_id = parent_target_id
|
||||
|
||||
# Prepend thread marker to the first message of the thread
|
||||
if thread_name and stats["messages"] == 0:
|
||||
content = f"> <<< THREAD: **{thread_name}** >>>\n{content}"
|
||||
|
||||
# Always ensure alias is created/retrieved to populate user_alias table
|
||||
alias = context.state.get_user_alias(str(msg.author.id))
|
||||
|
||||
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
|
||||
if anonymize_users:
|
||||
author_name = alias or "Anonymized User"
|
||||
author_avatar_url = None
|
||||
else:
|
||||
author_name = msg.author.display_name
|
||||
author_avatar_url = str(msg.author.display_avatar.url) if msg.author.display_avatar.url else None
|
||||
if author_avatar_url and not author_avatar_url.startswith("http"):
|
||||
author_avatar_url = None
|
||||
|
||||
logger.debug(f"Stoat: Calling send_message for Discord ID {msg.id}")
|
||||
stoat_msg_id = await context.stoat_writer.send_message(
|
||||
channel_id=target_channel_id,
|
||||
author_name=author_name,
|
||||
author_avatar_url=author_avatar_url,
|
||||
content=content,
|
||||
timestamp=int(msg.created_at.timestamp()),
|
||||
files=files if files else None,
|
||||
reply_to_message_id=reply_to_stoat_id,
|
||||
is_forwarded=is_forwarded,
|
||||
embeds=msg.embeds
|
||||
)
|
||||
|
||||
# Check for associated thread (Individual mode recursion)
|
||||
if stoat_msg_id:
|
||||
if thread_id:
|
||||
context.state.set_thread_message_mapping(target_channel_id, thread_id, str(msg.id), stoat_msg_id)
|
||||
else:
|
||||
context.state.set_message_mapping(target_channel_id, str(msg.id), stoat_msg_id)
|
||||
|
||||
if thread_id:
|
||||
context.state.update_thread_last_message_timestamp(target_channel_id, thread_id, str(msg.created_at))
|
||||
context.state.update_thread_last_message_id(target_channel_id, thread_id, str(msg.id))
|
||||
context.state.increment_thread_stats(target_channel_id, thread_id, messages=1, files=len(files) if files else 0)
|
||||
else:
|
||||
context.state.update_last_message_timestamp(target_channel_id, str(msg.created_at))
|
||||
context.state.update_last_message_id(target_channel_id, str(msg.id))
|
||||
context.state.increment_stats(target_channel_id, messages=1, files=len(files) if files else 0)
|
||||
|
||||
stats["messages"] += 1
|
||||
stats["last_message_content"] = content
|
||||
stats["last_message_author"] = msg.author.display_name
|
||||
|
||||
# Check for associated thread (Normal case: parent message is migrated)
|
||||
if hasattr(msg, 'thread') and msg.thread:
|
||||
thread = msg.thread
|
||||
if thread.id not in processed_threads:
|
||||
processed_threads.add(thread.id)
|
||||
# Track thread entry
|
||||
stats["threads"] += 1
|
||||
|
||||
# Fetch last migrated message ID for this thread
|
||||
thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(thread.id))
|
||||
if thread_after_id:
|
||||
logger.info(f"Resuming thread '{thread.name}' from after message ID: {thread_after_id}")
|
||||
|
||||
# Migrate thread messages recursively
|
||||
thread_stats = await migrate_messages(
|
||||
context=context,
|
||||
source_channel_id=thread.id,
|
||||
|
|
@ -769,6 +641,7 @@ async def migrate_messages(
|
|||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
# Send End Marker
|
||||
if context.is_running:
|
||||
await context.stoat_writer.send_marker(
|
||||
channel_id=target_channel_id,
|
||||
|
|
@ -783,12 +656,13 @@ async def migrate_messages(
|
|||
if progress_callback:
|
||||
await progress_callback(stats)
|
||||
except Exception as e:
|
||||
if "MissingPermission" in str(e): raise
|
||||
# If it's a permission error, stop the entire migration
|
||||
if "MissingPermission" in str(e):
|
||||
raise
|
||||
logger.error(f"Failed to process message {msg.id}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
# Mark thread as completed if we finished the loop without being interrupted
|
||||
if thread_id and context.is_running:
|
||||
context.state.update_thread_completed(target_channel_id, thread_id, completed=True)
|
||||
|
|
@ -921,15 +795,107 @@ async def migrate_global_messages(
|
|||
elif msg.channel.type in [11, 12]:
|
||||
pass
|
||||
|
||||
# Formatting
|
||||
files = []
|
||||
|
||||
# Always ensure alias is created/retrieved to populate user_alias table
|
||||
alias = context.state.get_user_alias(str(msg.author.id))
|
||||
|
||||
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
|
||||
|
||||
if anonymize_users:
|
||||
author_name = alias or "Anonymized User"
|
||||
author_avatar_url = None
|
||||
else:
|
||||
author_name = msg.author.display_name
|
||||
author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None
|
||||
|
||||
for att in msg.attachments:
|
||||
media_info = db_media.get(att.local_hash) if db_media else None
|
||||
local_path = None
|
||||
if media_info:
|
||||
local_path = Path(media_info["local_path"])
|
||||
|
||||
if local_path and local_path.exists():
|
||||
try:
|
||||
with open(local_path, "rb") as f:
|
||||
files.append({"filename": att.filename, "data": f.read()})
|
||||
except Exception as fe:
|
||||
logger.error(f"Failed to read file {local_path}: {fe}")
|
||||
|
||||
content = msg.content or ""
|
||||
|
||||
for sticker in msg.stickers:
|
||||
sticker_name = sticker.name
|
||||
s_hash = sticker.local_hash
|
||||
sticker_file = None
|
||||
s_media = db_media.get(s_hash) if db_media and s_hash else None
|
||||
if s_media:
|
||||
s_path = Path(s_media["local_path"])
|
||||
if s_path.exists():
|
||||
sticker_file = s_path
|
||||
|
||||
content += f"\n[Sticker: {sticker_name}]"
|
||||
if sticker_file:
|
||||
files.append(sticker_file)
|
||||
file_names.append(f"sticker_{sticker_name}.png")
|
||||
|
||||
content = clean_mentions(
|
||||
content=content,
|
||||
guild=context.discord_reader.guild,
|
||||
user_mentions=msg.mentions,
|
||||
role_mentions=msg.role_mentions,
|
||||
channel_mentions=msg.channel_mentions,
|
||||
emoji_map=emoji_map,
|
||||
channel_map=context.state.channel_map,
|
||||
state=context.state,
|
||||
target_server_id=target_server_id,
|
||||
channel_names=context.channel_names if hasattr(context, 'channel_names') else None,
|
||||
anonymize_users=anonymize_users
|
||||
)
|
||||
|
||||
if not content and not files:
|
||||
logger.debug(f"Message {msg.id} empty after processing, skipping.")
|
||||
continue
|
||||
|
||||
timestamp_int = int(msg.created_at.timestamp())
|
||||
|
||||
if msg.reference and msg.reference.message_id:
|
||||
# Resolve the author of the message being replied to
|
||||
source_ref_msg = await context.discord_reader.get_message(msg.channel_id, msg.reference.message_id)
|
||||
if source_ref_msg and source_ref_msg.author:
|
||||
ref_author_id = str(source_ref_msg.author.id)
|
||||
if anonymize_users:
|
||||
ref_name = context.state.get_user_alias(ref_author_id) or "Anonymized User"
|
||||
else:
|
||||
ref_name = source_ref_msg.author.display_name
|
||||
content = f"`@{ref_name}`\n{content}"
|
||||
else:
|
||||
tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id)
|
||||
if tgt_reply:
|
||||
content = f"[Reply to {tgt_reply}]\n{content}"
|
||||
|
||||
try:
|
||||
await _process_and_send_message(
|
||||
context=context,
|
||||
msg=msg,
|
||||
target_channel_id=target_channel_id,
|
||||
stats=stats,
|
||||
processed_threads=processed_threads
|
||||
stoat_msg_id = await context.stoat_writer.send_message(
|
||||
channel_id=target_channel_id,
|
||||
author_name=author_name,
|
||||
author_avatar_url=author_avatar_url,
|
||||
content=content,
|
||||
files=files,
|
||||
timestamp=timestamp_int,
|
||||
embeds=msg.embeds
|
||||
)
|
||||
|
||||
if stoat_msg_id:
|
||||
context.state.set_target_message_mapping(target_channel_id, msg.id, stoat_msg_id)
|
||||
context.state.update_last_message_id(target_channel_id, msg.id)
|
||||
context.state.set_waterfall_last_id(msg.id)
|
||||
stats["attachments"] += len(files) if files else 0
|
||||
|
||||
stats["messages"] += 1
|
||||
stats["last_message_content"] = content
|
||||
stats["last_message_author"] = author_name
|
||||
|
||||
if not stats["first_message_url"]:
|
||||
stats["first_message_url"] = msg.jump_url
|
||||
stats["last_message_url"] = msg.jump_url
|
||||
|
|
@ -940,7 +906,6 @@ async def migrate_global_messages(
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to process global message {msg.id}: {e}")
|
||||
|
||||
|
||||
except (KeyboardInterrupt, asyncio.CancelledError):
|
||||
context.is_running = False
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -5,90 +5,24 @@ from typing import Optional, List, Dict, Any
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def _discover_stoat_config(api_url: str) -> Dict[str, Optional[str]]:
|
||||
"""
|
||||
Fetches the Stoat/Revolt instance configuration to discover the WS and CDN URLs.
|
||||
Returns a dict with 'ws' and 'cdn' keys.
|
||||
"""
|
||||
results = {"ws": None, "cdn": None}
|
||||
|
||||
if not api_url or api_url == "default" or "stoat.chat" in api_url:
|
||||
return results
|
||||
|
||||
import aiohttp
|
||||
try:
|
||||
# Standard Revolt discovery endpoint is the root API URL
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(api_url.rstrip("/") + "/") as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
results["ws"] = data.get("ws")
|
||||
# Features might be in 'features.autumn.url'
|
||||
features = data.get("features", {})
|
||||
autumn = features.get("autumn", {})
|
||||
results["cdn"] = autumn.get("url")
|
||||
|
||||
if results["ws"] or results["cdn"]:
|
||||
logger.debug(f"Stoat Discovery: Found WS={results['ws']}, CDN={results['cdn']}")
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.debug(f"Stoat Discovery failed (fetching): {e}")
|
||||
|
||||
# Fallback to inference if fetch failed
|
||||
from urllib.parse import urlparse
|
||||
try:
|
||||
parsed = urlparse(api_url)
|
||||
if parsed.netloc:
|
||||
# Traditional defaults for self-hosted
|
||||
results["ws"] = f"wss://{parsed.netloc}/ws"
|
||||
results["cdn"] = f"https://{parsed.netloc}/autumn"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return results
|
||||
|
||||
class StoatWriter:
|
||||
def __init__(self, token: str, community_id: str, api_url: str = "default", ws_url: str = None):
|
||||
def __init__(self, token: str, community_id: str, api_url: str = "default"):
|
||||
self.token = token
|
||||
self.community_id = str(community_id)
|
||||
self.api_url = api_url
|
||||
self.ws_url = ws_url
|
||||
self.client: Optional[stoat.Client] = None
|
||||
self._server = None
|
||||
self._me = None
|
||||
self._validation_cache = None
|
||||
|
||||
@staticmethod
|
||||
async def fetch_guilds(token: str, api_url: str = "default", ws_url: str = None, cdn_url: str = None) -> list[tuple[str, str]]:
|
||||
async def fetch_guilds(token: str, api_url: str = "default") -> list[tuple[str, str]]:
|
||||
"""Fetches the list of Stoat servers the bot is in. Returns list of (label, id)."""
|
||||
token = token.strip()
|
||||
api_url = (api_url or "default").strip()
|
||||
ws_url = (ws_url or "").strip()
|
||||
cdn_url = (cdn_url or "").strip()
|
||||
client_kwargs = {"token": token, "bot": True}
|
||||
if api_url and api_url != "default":
|
||||
client_kwargs["http_base"] = api_url
|
||||
|
||||
# Auto-discover URLs if not provided for custom domains
|
||||
if api_url != "default" and (not ws_url or not cdn_url):
|
||||
discovery = await _discover_stoat_config(api_url)
|
||||
if not ws_url: ws_url = discovery["ws"] or ""
|
||||
if not cdn_url: cdn_url = discovery["cdn"] or ""
|
||||
|
||||
# Diagnostics to both stdout and logger
|
||||
log_msg = f"Stoat: Fetching guilds using API URL: {api_url}"
|
||||
if ws_url: log_msg += f" [WS: {ws_url}]"
|
||||
if cdn_url: log_msg += f" [CDN: {cdn_url}]"
|
||||
print(log_msg)
|
||||
logger.debug(log_msg)
|
||||
|
||||
client_kwargs = {
|
||||
"token": token,
|
||||
"bot": True,
|
||||
"http_base": api_url if api_url != "default" else None,
|
||||
"websocket_base": ws_url or None,
|
||||
"cdn_base": cdn_url or None
|
||||
}
|
||||
client = stoat.Client(**client_kwargs)
|
||||
logger.debug(f"Stoat: Initialized client with native http_base={client_kwargs['http_base']} and websocket_base={client_kwargs['websocket_base']}")
|
||||
|
||||
ready_event = asyncio.Event()
|
||||
servers_list = []
|
||||
|
||||
|
|
@ -124,21 +58,14 @@ class StoatWriter:
|
|||
else:
|
||||
raise asyncio.TimeoutError("Timed out waiting for Stoat to be ready")
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch Stoat servers: {e}")
|
||||
logger.error(f"Failed to fetch Stoat servers: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Shutdown the specific client instance used for fetching
|
||||
try:
|
||||
await client.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await client.close()
|
||||
client_task.cancel()
|
||||
try:
|
||||
# Wait for the task to actually finish terminating
|
||||
await asyncio.wait_for(client_task, timeout=2.0)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
await client_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
return guilds_list
|
||||
|
|
@ -153,45 +80,16 @@ class StoatWriter:
|
|||
pass
|
||||
self.client = None
|
||||
|
||||
api_url = (self.api_url or "default").strip()
|
||||
ws_url = (self.ws_url or "").strip()
|
||||
cdn_url = getattr(self, "cdn_url", "").strip()
|
||||
client_kwargs = {"token": self.token, "bot": True}
|
||||
if self.api_url and self.api_url != "default":
|
||||
client_kwargs["http_base"] = self.api_url
|
||||
|
||||
# Auto-discover if not provided for custom domains
|
||||
if api_url != "default" and (not ws_url or not cdn_url):
|
||||
discovery = await _discover_stoat_config(api_url)
|
||||
if not ws_url: ws_url = discovery["ws"] or ""
|
||||
if not cdn_url: cdn_url = discovery["cdn"] or ""
|
||||
|
||||
token = self.token.strip()
|
||||
|
||||
# Diagnostics to both stdout and logger
|
||||
log_msg = f"Stoat: Starting client using API URL: {api_url}"
|
||||
if ws_url: log_msg += f" [WS: {ws_url}]"
|
||||
if cdn_url: log_msg += f" [CDN: {cdn_url}]"
|
||||
print(log_msg)
|
||||
logger.debug(log_msg)
|
||||
|
||||
client_kwargs = {
|
||||
"token": token,
|
||||
"bot": True,
|
||||
"http_base": api_url if api_url != "default" else None,
|
||||
"websocket_base": ws_url or None,
|
||||
"cdn_base": cdn_url or None
|
||||
}
|
||||
self.client = stoat.Client(**client_kwargs)
|
||||
|
||||
# Keep track of the start task so we can clean it up later
|
||||
self._start_task = asyncio.create_task(self.client.start())
|
||||
|
||||
try:
|
||||
self._me = await self.client.fetch_user("@me")
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch bot user in StoatWriter: {e}")
|
||||
logger.error(f"Failed to fetch bot user in StoatWriter: {e}")
|
||||
# Ensure we clean up if start fails
|
||||
await self.close()
|
||||
self.client = None
|
||||
self.client = None # Reset if we can't even fetch @me
|
||||
|
||||
@property
|
||||
def my_id(self):
|
||||
|
|
@ -337,93 +235,68 @@ class StoatWriter:
|
|||
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch Stoat channels: {e}")
|
||||
logger.error(f"Failed to fetch Stoat channels: {e}")
|
||||
return []
|
||||
|
||||
async def create_channel(self, name: str, type: int = 0, topic: str = "", parent_id: Optional[str] = None, **kwargs) -> str:
|
||||
server = await self._get_server(populate_channels=True)
|
||||
try:
|
||||
if type == 4: # Category — use direct HTTP to avoid stoat.py auth issues
|
||||
import aiohttp, random, time
|
||||
if type == 4: # Category
|
||||
# The POST /categories endpoint throws 404 on some server versions, so we use server.edit(categories)
|
||||
import random
|
||||
import time
|
||||
chars = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"
|
||||
ts = int(time.time() * 1000)
|
||||
new_id = format(ts, '010X') + "".join(random.choice(chars) for _ in range(16))
|
||||
api_base = (self.api_url or "https://api.stoat.chat/0.8").rstrip("/")
|
||||
# Fetch current server to get existing categories
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{api_base}/servers/{self.community_id}",
|
||||
headers={"X-Bot-Token": self.token.strip()},
|
||||
) as resp:
|
||||
sdata = await resp.json()
|
||||
existing = sdata.get("categories") or []
|
||||
existing.append({"id": new_id, "title": name, "channels": []})
|
||||
async with session.patch(
|
||||
f"{api_base}/servers/{self.community_id}",
|
||||
json={"categories": existing},
|
||||
headers={"X-Bot-Token": self.token.strip(), "Content-Type": "application/json"},
|
||||
) as resp:
|
||||
data = await resp.json()
|
||||
logger.debug(f"PATCH /servers category resp {resp.status}: {data}")
|
||||
print(f"[writer] PATCH /servers category resp {resp.status}: {data}")
|
||||
if resp.status not in (200, 201):
|
||||
raise Exception(f"HTTP {resp.status}: {data}")
|
||||
self._server = None
|
||||
# Mock a ULID
|
||||
new_id = "01" + "".join(random.choice(chars) for _ in range(24))
|
||||
|
||||
categories = list(server.categories) if hasattr(server, "categories") and server.categories else []
|
||||
# Workaround for stoat.py bug: existing categories may fail to_dict() if slots are uninitialized
|
||||
for c in categories:
|
||||
if not hasattr(c, "default_permissions"): c.default_permissions = None
|
||||
if not hasattr(c, "role_permissions"): c.role_permissions = {}
|
||||
|
||||
new_cat = stoat.Category(id=new_id, title=name, channels=[])
|
||||
if not hasattr(new_cat, "default_permissions"): new_cat.default_permissions = None
|
||||
if not hasattr(new_cat, "role_permissions"): new_cat.role_permissions = {}
|
||||
categories.append(new_cat)
|
||||
|
||||
await server.edit(categories=categories)
|
||||
self._server = None # Clear cache after structural change
|
||||
return new_id
|
||||
else:
|
||||
# Use direct HTTP instead of stoat.py client to avoid aiohttp session issues
|
||||
import aiohttp
|
||||
api_base = (self.api_url or "https://api.stoat.chat/0.8").rstrip("/")
|
||||
channel_type = "Voice" if type == 2 else "Text"
|
||||
payload: Dict[str, Any] = {"name": name, "type": channel_type}
|
||||
if topic:
|
||||
payload["description"] = topic
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{api_base}/servers/{self.community_id}/channels",
|
||||
json=payload,
|
||||
headers={"X-Bot-Token": self.token.strip()},
|
||||
) as resp:
|
||||
try:
|
||||
data = await resp.json(content_type=None)
|
||||
except Exception:
|
||||
data = await resp.text()
|
||||
if resp.status not in (200, 201):
|
||||
raise Exception(f"HTTP {resp.status}: {data}")
|
||||
self._server = None
|
||||
return str(data["_id"])
|
||||
elif type == 2: # Voice Channel
|
||||
ch = await server.create_voice_channel(name=name)
|
||||
self._server = None # Clear cache
|
||||
return str(ch.id)
|
||||
else: # Text Channel
|
||||
ch = await server.create_text_channel(name=name, description=topic)
|
||||
self._server = None # Clear cache
|
||||
return str(ch.id)
|
||||
except Exception as e:
|
||||
print(f"Failed to create Stoat channel {name}: {e}")
|
||||
logger.error(f"Failed to create Stoat channel {name}: {e}")
|
||||
return ""
|
||||
|
||||
async def modify_channel(self, channel_id: str, name: Optional[str] = None, topic: Optional[str] = None, nsfw: Optional[bool] = None, slowmode_delay: Optional[int] = None, **kwargs) -> bool:
|
||||
import aiohttp
|
||||
api_base = (self.api_url or "https://api.stoat.chat/0.8").rstrip("/")
|
||||
payload: Dict[str, Any] = {}
|
||||
if name is not None:
|
||||
payload["name"] = name
|
||||
if topic is not None:
|
||||
payload["description"] = topic
|
||||
if nsfw is not None:
|
||||
payload["nsfw"] = nsfw
|
||||
if not payload:
|
||||
return True
|
||||
server = await self._get_server(populate_channels=True)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.patch(
|
||||
f"{api_base}/channels/{channel_id}",
|
||||
json=payload,
|
||||
headers={"X-Bot-Token": self.token.strip()},
|
||||
) as resp:
|
||||
if resp.status not in (200, 204):
|
||||
data = await resp.json()
|
||||
raise Exception(f"HTTP {resp.status}: {data}")
|
||||
self._server = None
|
||||
channel = next((c for c in server.channels if str(c.id) == channel_id), None)
|
||||
if not channel:
|
||||
return False
|
||||
|
||||
edit_kwargs = {}
|
||||
if name is not None:
|
||||
edit_kwargs["name"] = name
|
||||
if topic is not None:
|
||||
edit_kwargs["description"] = topic
|
||||
if nsfw is not None:
|
||||
edit_kwargs["nsfw"] = nsfw
|
||||
|
||||
if edit_kwargs:
|
||||
await channel.edit(**edit_kwargs)
|
||||
self._server = None # Clear cache
|
||||
|
||||
# clone_server.py now handles all parenting bulk logic
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to modify Stoat channel {channel_id}: {e}")
|
||||
logger.error(f"Failed to modify Stoat channel {channel_id}: {e}")
|
||||
return False
|
||||
|
||||
|
|
@ -524,26 +397,14 @@ class StoatWriter:
|
|||
color=color
|
||||
))
|
||||
|
||||
# Retry logic for 'NotFound' (common race condition on self-hosted instances)
|
||||
max_tries = 2
|
||||
for attempt in range(max_tries):
|
||||
try:
|
||||
msg = await channel.send(
|
||||
content=final_content,
|
||||
masquerade=masquerade,
|
||||
replies=replies,
|
||||
attachments=attachments,
|
||||
embeds=stoat_embeds
|
||||
)
|
||||
return str(msg.id) if msg else None
|
||||
except Exception as send_err:
|
||||
# 'NotFound' often means the attachment is still being processed by the database
|
||||
if "NotFound" in str(send_err) and attempt < max_tries - 1:
|
||||
logger.warning(f"Stoat: Received NotFound during send (likely race condition). Retrying in 1.5s... (Attempt {attempt+1}/{max_tries})")
|
||||
await asyncio.sleep(1.5)
|
||||
continue
|
||||
raise send_err
|
||||
|
||||
msg = await channel.send(
|
||||
content=final_content,
|
||||
masquerade=masquerade,
|
||||
replies=replies,
|
||||
attachments=attachments,
|
||||
embeds=stoat_embeds
|
||||
)
|
||||
return str(msg.id) if msg else None
|
||||
except Exception as e:
|
||||
# If file type not allowed, skip attachments and still send the message
|
||||
if "FileTypeNotAllowed" in str(e) and attachments:
|
||||
|
|
@ -558,7 +419,6 @@ class StoatWriter:
|
|||
return str(msg.id) if msg else None
|
||||
raise # Re-raise MissingPermission and other errors
|
||||
except Exception as e:
|
||||
print(f"Failed to send Stoat message to {channel_id}: {e}")
|
||||
logger.error(f"Failed to send Stoat message to {channel_id}: {e}")
|
||||
raise # Let caller handle (migration loop will stop for permission errors)
|
||||
|
||||
|
|
@ -582,7 +442,6 @@ class StoatWriter:
|
|||
)
|
||||
return str(msg.id)
|
||||
except Exception as e:
|
||||
print(f"Failed to send Stoat marker to {channel_id}: {e}")
|
||||
logger.error(f"Failed to send Stoat marker to {channel_id}: {e}")
|
||||
return None
|
||||
|
||||
|
|
@ -606,39 +465,11 @@ class StoatWriter:
|
|||
|
||||
# Set permissions
|
||||
if permissions != 0:
|
||||
requested_perms = self._map_permissions(permissions)
|
||||
|
||||
# Fetch bot's own permissions to mask the requested set.
|
||||
# Stoat/Revolt prevents granting permissions that the bot itself lacks.
|
||||
try:
|
||||
me = await server.fetch_member(self.my_id)
|
||||
bot_perms = server.permissions_for(me)
|
||||
|
||||
# Manual masking of boolean attributes
|
||||
final_perms = stoat.Permissions.none()
|
||||
for attr in dir(requested_perms):
|
||||
if not attr.startswith("_") and isinstance(getattr(requested_perms, attr), bool):
|
||||
if getattr(requested_perms, attr) and getattr(bot_perms, attr, False):
|
||||
try:
|
||||
setattr(final_perms, attr, True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await server.set_role_permissions(role, allow=final_perms)
|
||||
except Exception as perm_err:
|
||||
logger.warning(f"Stoat: Could not mask/verify role permissions for {name} (falling back to minimal): {perm_err}")
|
||||
# Attempt to set at least one basic permission if possible, or just skip
|
||||
try:
|
||||
minimal = stoat.Permissions.none()
|
||||
minimal.view_channel = True
|
||||
minimal.send_messages = True
|
||||
await server.set_role_permissions(role, allow=minimal)
|
||||
except Exception:
|
||||
pass
|
||||
s_perms = self._map_permissions(permissions)
|
||||
await server.set_role_permissions(role, allow=s_perms)
|
||||
|
||||
return str(role.id)
|
||||
except Exception as e:
|
||||
print(f"Failed to create Stoat role {name}: {e}")
|
||||
logger.error(f"Failed to create Stoat role {name}: {e}")
|
||||
return ""
|
||||
|
||||
|
|
@ -695,7 +526,6 @@ class StoatWriter:
|
|||
await server.set_default_permissions(s_perms)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to update Stoat default permissions: {e}")
|
||||
logger.error(f"Failed to update Stoat default permissions: {e}")
|
||||
return False
|
||||
|
||||
|
|
@ -706,7 +536,6 @@ class StoatWriter:
|
|||
emoji = await server.create_server_emoji(name=name, image=image_bytes)
|
||||
return str(emoji.id)
|
||||
except Exception as e:
|
||||
print(f"Failed to create Stoat emoji {name}: {e}")
|
||||
logger.error(f"Failed to create Stoat emoji {name}: {e}")
|
||||
return ""
|
||||
|
||||
|
|
@ -722,7 +551,6 @@ class StoatWriter:
|
|||
banner=banner if banner is not None else stoat.UNDEFINED
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to update Stoat guild metadata: {e}")
|
||||
logger.error(f"Failed to update Stoat guild metadata: {e}")
|
||||
|
||||
async def remove_community_logo_and_banner(self) -> dict:
|
||||
|
|
@ -734,14 +562,12 @@ class StoatWriter:
|
|||
try:
|
||||
await server.edit(icon=None)
|
||||
except Exception as e:
|
||||
print(f"Failed to remove Stoat community icon: {e}")
|
||||
logger.error(f"Failed to remove Stoat community icon: {e}")
|
||||
|
||||
if has_banner:
|
||||
try:
|
||||
await server.edit(banner=None)
|
||||
except Exception as e:
|
||||
print(f"Failed to remove Stoat community banner: {e}")
|
||||
logger.error(f"Failed to remove Stoat community banner: {e}")
|
||||
|
||||
return {
|
||||
|
|
@ -768,7 +594,6 @@ class StoatWriter:
|
|||
if progress_callback:
|
||||
await progress_callback(name, i, total)
|
||||
except Exception as e:
|
||||
print(f"Failed to delete Stoat channel {ch.id}: {e}")
|
||||
logger.error(f"Failed to delete Stoat channel {ch.id}: {e}")
|
||||
|
||||
# To delete categories, we can wipe the categories array via server.edit to avoid 404 endpoint
|
||||
|
|
@ -793,7 +618,6 @@ class StoatWriter:
|
|||
await progress_callback(name, j, total)
|
||||
j += 1
|
||||
except Exception as e:
|
||||
print(f"Failed to wipe Stoat categories via edit: {e}")
|
||||
logger.error(f"Failed to wipe Stoat categories via edit: {e}")
|
||||
|
||||
return count
|
||||
|
|
@ -810,23 +634,17 @@ class StoatWriter:
|
|||
logger.info(f"Danger Zone: Skipping permission reset for audit channel {name}")
|
||||
total -= 1
|
||||
continue
|
||||
|
||||
# Fetch fresh channel to get current role_permissions
|
||||
fresh_ch = await self.client.fetch_channel(ch.id)
|
||||
# Clear default permissions
|
||||
if hasattr(fresh_ch, "default_permissions") and fresh_ch.default_permissions is not None:
|
||||
await fresh_ch.set_default_permissions(None)
|
||||
|
||||
# Clear all role overrides
|
||||
if hasattr(fresh_ch, "role_permissions"):
|
||||
for role_id in list(fresh_ch.role_permissions.keys()):
|
||||
await fresh_ch.set_role_permissions(str(role_id), allow=stoat.Permissions.none(), deny=stoat.Permissions.none())
|
||||
|
||||
# In Stoat, clearing overrides might involve setting them to default or explicitly removing the role_permissions/default_permissions
|
||||
# Since we don't know an explicit "clear_overrides" method, we'll wipe them by setting empty/none if possible.
|
||||
# Actually Stoat allows overwriting. Setting allow=0 deny=0 for role overrides isn't explicitly clear.
|
||||
# For safety, we will just pass. If the user expects it, we'd iterate over roles and set empty.
|
||||
# A quick way is to edit the channel permissions to empty state if possible.
|
||||
# Let's count them anyway.
|
||||
# (Fluxer writer does a loop over existing overrides, we can just return 0 for now until we inspect Stoat `PermissionOverride` deletion)
|
||||
count += 1
|
||||
if progress_callback:
|
||||
await progress_callback(name, i, total)
|
||||
except Exception as e:
|
||||
print(f"Failed to reset Stoat channel permissions for {ch.id}: {e}")
|
||||
logger.error(f"Failed to reset Stoat channel permissions for {ch.id}: {e}")
|
||||
return count
|
||||
|
||||
|
|
@ -853,7 +671,6 @@ class StoatWriter:
|
|||
if "MissingPermission" in err_msg and "ViewChannel" in err_msg:
|
||||
logger.error(f"Stoat LOCKOUT: Bot lacks 'ViewChannel' to edit {channel_id}. "
|
||||
"Ensure the bot has 'Manage Server' or a role with 'Allow View Channel' rank higher than @everyone.")
|
||||
print(f"Failed to set Stoat channel permission for {overwrite_id} on {channel_id}: {e}")
|
||||
logger.error(f"Failed to set Stoat channel permission for {overwrite_id} on {channel_id}: {e}")
|
||||
|
||||
|
||||
|
|
@ -895,31 +712,18 @@ class StoatWriter:
|
|||
await emoji.delete()
|
||||
count += 1
|
||||
except Exception as e:
|
||||
print(f"Failed to delete Stoat emoji {emoji.name}: {e}")
|
||||
logger.error(f"Failed to delete Stoat emoji {emoji.name}: {e}")
|
||||
|
||||
return {"emojis": count, "stickers": 0}
|
||||
|
||||
async def close(self):
|
||||
client = self.client
|
||||
task = getattr(self, "_start_task", None)
|
||||
|
||||
self.client = None # Atomic clear to prevent new usage
|
||||
self._start_task = None
|
||||
self._me = None
|
||||
self._server = None
|
||||
|
||||
if client:
|
||||
try:
|
||||
await client.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error closing Stoat client session: {e}")
|
||||
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(task, timeout=2.0)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
pass
|
||||
|
||||
logger.debug(f"Error closing Stoat client: {e}")
|
||||
self._validation_cache = None
|
||||
|
|
|
|||
|
|
@ -164,9 +164,6 @@ class OperationPane(Container):
|
|||
yield Rule(id="footer_rule")
|
||||
yield Button("Danger Zone ⚠", id="op_danger", variant="error", disabled=True, flat=True, tooltip="Dangerous operations:\ndelete channels, roles, emojis on target\n(use with caution)")
|
||||
|
||||
if self.cfg_name == "AutoTest":
|
||||
yield Button("RUN AUTO TEST", id="op_autotest", variant="warning", flat="false", tooltip="Execute automated test sequence for the AutoTest profile")
|
||||
|
||||
def on_mount(self) -> None:
|
||||
self._rebuild_engine()
|
||||
self._update_info_labels()
|
||||
|
|
@ -329,13 +326,8 @@ class OperationPane(Container):
|
|||
for pne in self.query("#op_target_pane"): pne.display = False
|
||||
|
||||
enabled = (v.get("discord_token") and v.get("discord_server") and not d_missing)
|
||||
for btn in self.query("#op_backup_msgs"):
|
||||
btn.disabled = not enabled
|
||||
for btn in self.query("#op_backup_sync"):
|
||||
btn.display = self.has_backup
|
||||
btn.disabled = not (enabled and self.has_backup)
|
||||
for btn in self.query("#op_autotest"):
|
||||
btn.disabled = not enabled
|
||||
for bid in ("#op_backup_msgs", "#op_backup_sync"):
|
||||
for btn in self.query(bid): btn.disabled = not enabled
|
||||
|
||||
for btn in self.query("#op_backup_stats"):
|
||||
btn.display = self.has_backup
|
||||
|
|
@ -403,7 +395,7 @@ class OperationPane(Container):
|
|||
lbl.update(f"{t_status}")
|
||||
|
||||
# Buttons
|
||||
for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger", "#op_autotest"):
|
||||
for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger"):
|
||||
for btn in self.query(bid): btn.disabled = not self.tokens_valid
|
||||
|
||||
# ── validation ────────────────────────────────────────────────────────
|
||||
|
|
@ -424,10 +416,10 @@ class OperationPane(Container):
|
|||
|
||||
# Disable all operation buttons while validation is in progress
|
||||
if self.view_mode == "shuttle":
|
||||
for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger", "#op_autotest"):
|
||||
for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger"):
|
||||
for btn in self.query(bid): btn.disabled = True
|
||||
elif self.view_mode == "backup":
|
||||
for bid in ("#op_backup_msgs", "#op_backup_sync", "#op_autotest"):
|
||||
for bid in ("#op_backup_msgs", "#op_backup_sync"):
|
||||
for btn in self.query(bid): btn.disabled = True
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_validate setup: {e}")
|
||||
|
|
@ -598,124 +590,7 @@ class OperationPane(Container):
|
|||
from src.ui.backup_stats import BackupStatsScreen
|
||||
target_dir = Path(self._base_dir()) / f"DISCORD_BACKUP-{self.config.discord_server_id}"
|
||||
self.app.push_screen(BackupStatsScreen(self.cfg_name, target_dir))
|
||||
elif bid == "op_autotest" or bid == "btn_autotest":
|
||||
self.run_autotest_sequence()
|
||||
|
||||
@work(exclusive=True)
|
||||
async def run_autotest_sequence(self) -> None:
|
||||
"""Entry point for the AUTO TEST sequence."""
|
||||
if not self.tokens_valid:
|
||||
return
|
||||
|
||||
modal = ProgressScreen(log_level=self.config.log_level)
|
||||
self.app.push_screen(modal)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
try:
|
||||
if self.view_mode == "shuttle":
|
||||
await self._run_migration_autotest_logic(modal)
|
||||
elif self.view_mode == "backup":
|
||||
await self._run_backup_autotest_logic(modal)
|
||||
|
||||
modal.phase_report("AUTO TEST Complete", show_back=False)
|
||||
modal.write("[bold green]Full automated test sequence finished successfully![/bold green]")
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-Test Error: {e}\n{traceback.format_exc()}")
|
||||
modal.write(f"[bold red]Error: {e}[/bold red]")
|
||||
modal.phase_report("Auto-Test", "error", show_back=False)
|
||||
finally:
|
||||
self.engine.is_running = False
|
||||
await self.engine.close_connections()
|
||||
|
||||
async def _run_migration_autotest_logic(self, modal: ProgressScreen) -> None:
|
||||
"""Executes the full migration test sequence."""
|
||||
modal.set_status("AUTO TEST: Launching Migration Sequence...")
|
||||
modal.write("[bold yellow]Starting Migration Auto-Test...[/bold yellow]")
|
||||
|
||||
migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate
|
||||
|
||||
# 1. Connect and initialize state
|
||||
modal.set_status("Connecting and Initializing State...")
|
||||
await self.engine.start_connections()
|
||||
self.engine.is_running = True
|
||||
|
||||
# Initialize state database early to avoid Warnings
|
||||
tid = self.config.fluxer_server_id if self.target_platform == "fluxer" else self.config.stoat_server_id
|
||||
tgt_info = await self.engine.writer.validate()
|
||||
self.engine.ensure_state_initialized(str(tid or ""), tgt_info.get("community_name", "Target"))
|
||||
|
||||
# 2. Danger Zone Clean
|
||||
modal.write("\n[bold red]Phase 1: Danger Zone Cleanup[/bold red]")
|
||||
await self._logic_dz_delete_channels(modal)
|
||||
await self._logic_dz_reset_perms(modal)
|
||||
await self._logic_dz_delete_roles(modal)
|
||||
await self._logic_dz_delete_assets(modal)
|
||||
|
||||
# 3. Clone Roles & Permissions
|
||||
modal.write("\n[bold cyan]Phase 2: Cloning Roles & Permissions[/bold cyan]")
|
||||
await self._logic_clone_roles(modal, force=True)
|
||||
|
||||
# 4. Clone Assets (Logo, Banner, Emojis, Stickers)
|
||||
modal.write("\n[bold cyan]Phase 3: Syncing Metadata & Assets[/bold cyan]")
|
||||
await self._logic_copy_assets(modal, ["Emoji", "Sticker"], force=True)
|
||||
await self._logic_sync_metadata(modal, ["name", "icon", "banner"])
|
||||
|
||||
# 5. Clone Template (Structure)
|
||||
modal.write("\n[bold cyan]Phase 4: Cloning Server Structure (1/2)[/bold cyan]")
|
||||
await self._logic_clone_channels(modal, force=True)
|
||||
await self._logic_sync_permissions(modal)
|
||||
|
||||
# 6. Waterfall Migration (Only for backups)
|
||||
if self.engine.source_mode == "backup" :
|
||||
modal.write("\n[bold cyan]Phase 5: Waterfall Message Migration[/bold cyan]")
|
||||
await self._logic_waterfall_migration(modal=modal, is_autotest=True)
|
||||
else:
|
||||
modal.write("\n[bold yellow]Phase 5: Skipping Waterfall (Live mode selected)[/bold yellow]")
|
||||
|
||||
# 7. Individual Channel Migration (Automated)
|
||||
modal.write("\n[bold cyan]Phase 7: Individual Channel Migration[/bold cyan]")
|
||||
await self._logic_autotest_migrate_all_channels(modal=modal)
|
||||
|
||||
async def _run_backup_autotest_logic(self, modal: ProgressScreen) -> None:
|
||||
"""Executes the full backup test sequence."""
|
||||
import shutil
|
||||
modal.set_status("AUTO TEST: Launching Backup Sequence...")
|
||||
modal.write("[bold yellow]Starting Backup Auto-Test...[/bold yellow]")
|
||||
|
||||
# 1. Clear old backup directory completely
|
||||
server_dir = Path(self._base_dir()) / f"DISCORD_BACKUP-{self.config.discord_server_id}"
|
||||
if server_dir.exists():
|
||||
modal.write(f"[yellow]Deleting existing backup directory: {server_dir}[/yellow]")
|
||||
try:
|
||||
shutil.rmtree(server_dir)
|
||||
except Exception as e:
|
||||
modal.write(f"[red]Warning: Could not delete directory: {e}[/red]")
|
||||
|
||||
# 2. Setup & Metadata
|
||||
modal.set_status("Initializing Discord connection...")
|
||||
await self.engine.discord_reader.start()
|
||||
await self.exporter.setup()
|
||||
self.exporter.is_running = True
|
||||
|
||||
modal.write("\n[bold cyan]Phase 1: Full Server Backup[/bold cyan]")
|
||||
|
||||
# 3. Use unified backup logic in autotest mode
|
||||
all_channels = await self.engine.discord_reader.get_channels()
|
||||
eligible_channels = [
|
||||
c for c in all_channels
|
||||
if c.type in [
|
||||
self.engine.discord_reader.CHANNEL_TYPE_TEXT,
|
||||
self.engine.discord_reader.CHANNEL_TYPE_NEWS,
|
||||
self.engine.discord_reader.CHANNEL_TYPE_FORUM
|
||||
]
|
||||
]
|
||||
|
||||
await self._logic_full_backup(
|
||||
modal=modal,
|
||||
selected_channels=eligible_channels,
|
||||
force_overwrite=True,
|
||||
is_autotest=True
|
||||
)
|
||||
# ── (1) clone server template (combined) ─────────────────────────────
|
||||
|
||||
def _open_clone_menu(self):
|
||||
|
|
@ -1137,67 +1012,16 @@ class OperationPane(Container):
|
|||
# ── (5) message migration ─────────────────────────────────────────────
|
||||
|
||||
@work(exclusive=True)
|
||||
async def run_migrate_messages(self, modal: ProgressScreen | None = None) -> None:
|
||||
await self._logic_migrate_messages(modal)
|
||||
|
||||
async def _logic_autotest_migrate_all_channels(self, modal: ProgressScreen) -> None:
|
||||
"""Automated name-based channel migration for Auto-Test."""
|
||||
migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate
|
||||
|
||||
# 1. Matching
|
||||
modal.set_status("Auto-matching channels by name...")
|
||||
await self._perform_auto_matching()
|
||||
|
||||
# 2. Get channels
|
||||
d_channels = await self.engine.discord_reader.get_channels()
|
||||
text_channels = [c for c in d_channels if c.type in [
|
||||
self.engine.discord_reader.CHANNEL_TYPE_TEXT,
|
||||
self.engine.discord_reader.CHANNEL_TYPE_NEWS
|
||||
]]
|
||||
|
||||
modal.write(f"[bold cyan]Auto-Test: Found {len(text_channels)} channels to migrate.[/bold cyan]")
|
||||
|
||||
# 3. Migrate loop
|
||||
for i, ch in enumerate(text_channels):
|
||||
if not self.engine.is_running: break
|
||||
|
||||
tgt_id = self.engine.state.get_target_channel_id(str(ch.id))
|
||||
if not tgt_id:
|
||||
modal.write(f"[yellow]Skipping #{ch.name} (no target mapping found)[/yellow]")
|
||||
continue
|
||||
|
||||
modal.write(f"\n[bold]Migrating #{ch.name} ({i+1}/{len(text_channels)}) -> Target ID {tgt_id}[/bold]")
|
||||
|
||||
# Analyze (silent)
|
||||
stats = await migrate_mod.analyze_migration(self.engine, source_channel_id=ch.id, after_message_id=None)
|
||||
ch_total = stats["messages"]
|
||||
|
||||
async def update_indiv(curr):
|
||||
c = curr["messages"]
|
||||
modal.set_progress(c, ch_total or 100)
|
||||
modal.set_item_status(f"#{ch.name}: {c}/{ch_total} messages")
|
||||
|
||||
await migrate_mod.migrate_messages(
|
||||
self.engine,
|
||||
source_channel_id=ch.id,
|
||||
target_channel_id=tgt_id,
|
||||
after_message_id=None,
|
||||
progress_callback=update_indiv
|
||||
)
|
||||
|
||||
modal.write("\n[bold green]Automated channel migration complete.[/bold green]")
|
||||
|
||||
async def _logic_migrate_messages(self, modal: ProgressScreen | None = None, is_autotest: bool = False) -> None:
|
||||
async def run_migrate_messages(self) -> None:
|
||||
if not self.tokens_valid:
|
||||
return
|
||||
|
||||
migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate
|
||||
platform_name = self.target_platform.capitalize()
|
||||
|
||||
if not modal:
|
||||
modal = ProgressScreen(log_level=self.config.log_level)
|
||||
self.app.push_screen(modal)
|
||||
await asyncio.sleep(0.1)
|
||||
modal = ProgressScreen(log_level=self.config.log_level)
|
||||
self.app.push_screen(modal)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
try:
|
||||
# Show info container
|
||||
|
|
@ -1594,29 +1418,22 @@ class OperationPane(Container):
|
|||
await self.engine.close_connections()
|
||||
|
||||
@work(exclusive=True)
|
||||
async def run_waterfall_migration(self, modal: ProgressScreen | None = None) -> None:
|
||||
await self._logic_waterfall_migration(modal)
|
||||
|
||||
async def _logic_waterfall_migration(self, modal: ProgressScreen | None = None, is_autotest: bool = False) -> None:
|
||||
async def run_waterfall_migration(self) -> None:
|
||||
if not self.tokens_valid:
|
||||
return
|
||||
|
||||
migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate
|
||||
platform_name = self.target_platform.capitalize()
|
||||
|
||||
if not modal:
|
||||
modal = ProgressScreen(log_level=self.config.log_level)
|
||||
self.app.push_screen(modal)
|
||||
await asyncio.sleep(0.1)
|
||||
modal = ProgressScreen(log_level=self.config.log_level)
|
||||
self.app.push_screen(modal)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
try:
|
||||
modal.show_info("[bold cyan]Waterfall Migration Ready[/bold cyan]", "Checking mapping and missing channels...")
|
||||
modal.set_status("Connecting to Servers...")
|
||||
await self.engine.start_connections()
|
||||
|
||||
# Ensure writer is validated before auto-matching (fixes NoneType role fetch error)
|
||||
await self.engine.writer.validate()
|
||||
|
||||
modal.set_status("Synchronizing entity mappings...")
|
||||
await self._perform_auto_matching()
|
||||
|
||||
|
|
@ -1643,19 +1460,15 @@ class OperationPane(Container):
|
|||
prefix = "[bold cyan]📁[/bold cyan] " if mc.type == 4 else "[bold white]#[/bold white] "
|
||||
modal.write(f" {prefix}{mc.name}")
|
||||
|
||||
if is_autotest:
|
||||
choice = "btn_start_first"
|
||||
modal.write("[bold cyan]Auto-Test: Automatically cloning missing channels.[/bold cyan]")
|
||||
else:
|
||||
choice = await modal.phase_wait_confirm(
|
||||
show_continue=False,
|
||||
show_id=True,
|
||||
btn_start_label="Clone missing channels",
|
||||
btn_id_label="Skip missing channels",
|
||||
btn_start_variant="primary",
|
||||
btn_start_tooltip=f"Automatically create {len(missing_channels)} entities on target",
|
||||
btn_id_tooltip="Start migration without these channels"
|
||||
)
|
||||
choice = await modal.phase_wait_confirm(
|
||||
show_continue=False,
|
||||
show_id=True,
|
||||
btn_start_label="Clone missing channels",
|
||||
btn_id_label="Skip missing channels",
|
||||
btn_start_variant="primary",
|
||||
btn_start_tooltip=f"Automatically create {len(missing_channels)} entities on target",
|
||||
btn_id_tooltip="Start migration without these channels"
|
||||
)
|
||||
|
||||
if choice == "btn_back":
|
||||
modal.dismiss()
|
||||
|
|
@ -1738,8 +1551,10 @@ class OperationPane(Container):
|
|||
if filtered_tgt_ids:
|
||||
all_mapped_tgt_ids = filtered_tgt_ids
|
||||
|
||||
# 2.6 Resume Point: Calculate from global channel minimums
|
||||
min_last_id = self.engine.state.get_global_min_last_message_id(all_mapped_tgt_ids)
|
||||
# 2.6 Resume Point: Prioritize Global waterfall tracker, fallback to channel minimums
|
||||
min_last_id = self.engine.state.get_waterfall_last_id()
|
||||
if min_last_id is None:
|
||||
min_last_id = self.engine.state.get_global_min_last_message_id(all_mapped_tgt_ids)
|
||||
|
||||
modal.write(f"\n[bold cyan]Waterfall Migration Resume Point:[/bold cyan]")
|
||||
if min_last_id is not None:
|
||||
|
|
@ -1747,19 +1562,15 @@ class OperationPane(Container):
|
|||
else:
|
||||
modal.write("No previous migration state found. Starting from the beginning.")
|
||||
|
||||
if is_autotest:
|
||||
choice = "btn_continue" if min_last_id is not None else "btn_start_first"
|
||||
modal.write(f"[bold cyan]Auto-Test: Automatically choosing {choice.replace('btn_', '').replace('_', ' ')}.[/bold cyan]")
|
||||
else:
|
||||
choice = await modal.phase_wait_confirm(
|
||||
show_continue=min_last_id is not None,
|
||||
show_id=False,
|
||||
btn_start_label="Start From Beginning",
|
||||
btn_start_tooltip="Wipes migration progress and restarts from the beginning; may create duplicates",
|
||||
btn_start_variant="default" if min_last_id is not None else "primary",
|
||||
btn_continue_label=f"Continue from ID {min_last_id if min_last_id is not None else 0}" if min_last_id is not None else "Continue Migration",
|
||||
btn_continue_tooltip="Fastest"
|
||||
)
|
||||
choice = await modal.phase_wait_confirm(
|
||||
show_continue=min_last_id is not None,
|
||||
show_id=False,
|
||||
btn_start_label="Start From Beginning",
|
||||
btn_start_tooltip="Safe, skips duplicates automatically",
|
||||
btn_start_variant="default" if min_last_id is not None else "primary",
|
||||
btn_continue_label=f"Continue from ID {min_last_id if min_last_id is not None else 0}" if min_last_id is not None else "Continue Migration",
|
||||
btn_continue_tooltip="Fastest"
|
||||
)
|
||||
|
||||
if choice == "btn_back":
|
||||
modal.dismiss()
|
||||
|
|
@ -1771,11 +1582,7 @@ class OperationPane(Container):
|
|||
return
|
||||
|
||||
after_id = None
|
||||
if choice == "btn_start_first":
|
||||
logger.info("Proceeding with 'Start from Beginning' (global clean sink).")
|
||||
self.engine.state.clear_all_migration_data()
|
||||
after_id = None
|
||||
elif choice == "btn_continue" and min_last_id is not None:
|
||||
if choice == "btn_continue" and min_last_id is not None:
|
||||
after_id = int(min_last_id)
|
||||
|
||||
# Phase 3: Progress
|
||||
|
|
@ -1792,7 +1599,6 @@ class OperationPane(Container):
|
|||
tid = self.config.fluxer_server_id
|
||||
self.engine.ensure_state_initialized(str(tid or ""), platform_name)
|
||||
|
||||
modal.show_stats()
|
||||
modal.write("Scanning global footprint for totals ...")
|
||||
stats_analysis = await migrate_mod.analyze_global_migration(self.engine, after_message_id=after_id)
|
||||
total_messages = stats_analysis["messages"]
|
||||
|
|
@ -2281,7 +2087,6 @@ class OperationPane(Container):
|
|||
|
||||
@work(exclusive=True)
|
||||
async def run_backup_messages(self) -> None:
|
||||
"""UI entry point for full backup."""
|
||||
modal_prog = ProgressScreen(log_level=self.config.log_level)
|
||||
self.app.push_screen(modal_prog)
|
||||
await asyncio.sleep(0.1)
|
||||
|
|
@ -2291,6 +2096,38 @@ class OperationPane(Container):
|
|||
await self.engine.discord_reader.start()
|
||||
await self.exporter.setup()
|
||||
|
||||
# Check if profile is empty
|
||||
profile_exists = False
|
||||
if self.exporter.db:
|
||||
try:
|
||||
profile_exists = self.exporter.db.get_guild_profile() is not None
|
||||
except Exception:
|
||||
profile_exists = False
|
||||
|
||||
if not profile_exists:
|
||||
modal_prog.set_status("First-Time Setup: Exporting Server Profile...")
|
||||
modal_prog.write("[yellow]No existing profile found. Performing primary profile backup...[/yellow]")
|
||||
|
||||
modal_prog.write("[yellow]Exporting server metadata...[/yellow]")
|
||||
await self.exporter.export_metadata()
|
||||
|
||||
modal_prog.write("[yellow]Syncing server assets (icon/banner)...[/yellow]")
|
||||
await self.exporter.download_server_assets()
|
||||
|
||||
modal_prog.write("[yellow]Exporting server structure...[/yellow]")
|
||||
await self.exporter.export_channels_structure()
|
||||
|
||||
modal_prog.write("[yellow]Exporting roles & permissions...[/yellow]")
|
||||
await self.exporter.export_roles()
|
||||
|
||||
modal_prog.write("[yellow]Exporting custom emojis & stickers...[/yellow]")
|
||||
await self.exporter.export_assets()
|
||||
|
||||
modal_prog.write("[bold green]Primary profile setup complete![/bold green]")
|
||||
modal_prog.write("")
|
||||
else:
|
||||
modal_prog.write("[dim]Existing profile detected. Scanning structure...[/dim]")
|
||||
await self.exporter.export_channels_structure()
|
||||
all_channels = await self.engine.discord_reader.get_channels()
|
||||
all_categories = await self.engine.discord_reader.get_categories()
|
||||
cat_map = {c.id: c.name for c in all_categories}
|
||||
|
|
@ -2309,9 +2146,8 @@ class OperationPane(Container):
|
|||
modal_prog.allow_close()
|
||||
return
|
||||
|
||||
# Analyze which are already backed up
|
||||
backed_up_ids = set()
|
||||
any_found = False
|
||||
backed_up_ids = set()
|
||||
if self.exporter.db:
|
||||
channel_stats = self.exporter.db.get_stats_by_channel()
|
||||
for chan in eligible_channels:
|
||||
|
|
@ -2319,105 +2155,87 @@ class OperationPane(Container):
|
|||
any_found = True
|
||||
backed_up_ids.add(chan.id)
|
||||
|
||||
# Manual selection
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
def check_channels(reply: dict | None) -> None:
|
||||
if not future.done(): future.set_result(reply)
|
||||
self.app.pop_screen()
|
||||
|
||||
self.app.push_screen(
|
||||
ChannelSelectScreen(eligible_channels, cat_map, backed_up_ids, any_found),
|
||||
check_channels,
|
||||
)
|
||||
while True:
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
|
||||
reply = await future
|
||||
if not reply: return
|
||||
def check_channels(reply: dict | None) -> None:
|
||||
if not future.done():
|
||||
future.set_result(reply)
|
||||
|
||||
selected_ids = reply["channels"]
|
||||
force_overwrite = reply["force"]
|
||||
selected_channels = [c for c in eligible_channels if c.id in selected_ids]
|
||||
|
||||
# Confirmation phase
|
||||
new_channels = [c for c in selected_channels if c.id not in backed_up_ids]
|
||||
existing_channels = [c for c in selected_channels if c.id in backed_up_ids]
|
||||
|
||||
modal_confirm = ProgressScreen(log_level=self.config.log_level)
|
||||
self.app.push_screen(modal_confirm)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
modal_confirm.set_status(f"Confirm to proceed with Backup of [bold]{len(selected_channels)}[/bold] channels")
|
||||
modal_confirm.show_info(f"[cyan]Backup Channels[/cyan]", f"{len(new_channels)} new, {len(existing_channels)} existing")
|
||||
|
||||
choice = await modal_confirm.phase_wait_confirm(btn_start_label="Start Channel Backup", show_id=False)
|
||||
if choice != "btn_start_first":
|
||||
modal_confirm.dismiss()
|
||||
return
|
||||
|
||||
modal_confirm.cancel_callback = lambda: setattr(self.exporter, "is_running", False)
|
||||
modal_confirm.phase_progress()
|
||||
|
||||
await self._logic_full_backup(
|
||||
modal=modal_confirm,
|
||||
selected_channels=selected_channels,
|
||||
force_overwrite=force_overwrite,
|
||||
is_autotest=False
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Backup Error: {traceback.format_exc()}")
|
||||
modal_prog.write(f"[bold red]Backup failed: {e}[/bold red]")
|
||||
modal_prog.phase_report("Backup", "error", show_back=False)
|
||||
finally:
|
||||
await self.engine.close_connections()
|
||||
|
||||
async def _logic_full_backup(self, modal: ProgressScreen, selected_channels: list, force_overwrite: bool, is_autotest: bool = False) -> None:
|
||||
"""Non-interactive core backup logic."""
|
||||
if not self.exporter.is_running:
|
||||
self.exporter.is_running = True
|
||||
|
||||
try:
|
||||
# 1. Metadata and Assets
|
||||
modal.set_status("Exporting Server Profile...")
|
||||
await self.exporter.export_metadata()
|
||||
await self.exporter.download_server_assets()
|
||||
await self.exporter.export_channels_structure()
|
||||
await self.exporter.export_roles()
|
||||
await self.exporter.export_assets()
|
||||
|
||||
# Pre-fetch all members once for role resolution during message export
|
||||
modal.set_status("Pre-fetching server members...")
|
||||
await self.exporter.prefetch_members()
|
||||
|
||||
# 2. Channel Messages
|
||||
total_chans = len(selected_channels)
|
||||
modal.write(f"\n[bold cyan]Backing up {total_chans} channels...[/bold cyan]")
|
||||
modal.show_stats()
|
||||
|
||||
for i, chan in enumerate(selected_channels):
|
||||
if not self.exporter.is_running: break
|
||||
|
||||
modal.set_item_status(f"[cyan]Processing ({i+1}/{total_chans}): #{chan.name}[/cyan]")
|
||||
modal.set_progress(i, total_chans)
|
||||
modal.write(f"[cyan]Backing up: #{chan.name}[/cyan]")
|
||||
|
||||
async def update_backup(name, count, author_name=None, message_preview=None, thread_count=0, file_count=0):
|
||||
modal.update_stats(messages=str(count), threads=str(thread_count), files=str(file_count))
|
||||
if author_name and message_preview and count % 20 == 0:
|
||||
modal.write(f"[dim]{author_name}:[/dim] {message_preview}")
|
||||
|
||||
await self.exporter.export_channel_messages(
|
||||
chan.id, progress_callback=update_backup, force=force_overwrite
|
||||
self.app.push_screen(
|
||||
ChannelSelectScreen(eligible_channels, cat_map, backed_up_ids, any_found),
|
||||
check_channels,
|
||||
)
|
||||
modal.write(f"[green]Completed: #{chan.name}[/green]")
|
||||
|
||||
modal.set_progress(total_chans, total_chans)
|
||||
modal.write("[bold green]Backup complete![/bold green]")
|
||||
modal.phase_report("Full Backup", show_back=False)
|
||||
reply = await future
|
||||
if not reply:
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
modal.write(f"[bold red]Core backup failed: {e}[/bold red]")
|
||||
logger.error(f"Core Backup Error: {traceback.format_exc()}")
|
||||
raise e
|
||||
selected_ids = reply["channels"]
|
||||
force_overwrite = reply["force"]
|
||||
selected_channels = [c for c in eligible_channels if c.id in selected_ids]
|
||||
|
||||
# Phase 2: Confirmation
|
||||
modal_prog = ProgressScreen(log_level=self.config.log_level) # Re-instantiate to avoid Textual re-push UI freeze
|
||||
self.app.push_screen(modal_prog)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
new_channels = [c for c in selected_channels if c.id not in backed_up_ids]
|
||||
existing_channels = [c for c in selected_channels if c.id in backed_up_ids]
|
||||
|
||||
server = getattr(self.engine.discord_reader, 'guild', None)
|
||||
if server:
|
||||
modal_prog.write(f"[bold cyan]Server Profile:[/bold cyan]")
|
||||
modal_prog.write(f" Name: [green]{server.name}[/green]")
|
||||
modal_prog.write(f" Icon: [green]{'Present' if server.icon else 'None'}[/green]")
|
||||
modal_prog.write("")
|
||||
|
||||
modal_prog.set_status(f"Confirm to proceed with Backup of [bold]{len(selected_channels)}[/bold] channels")
|
||||
modal_prog.show_info(f"[cyan]Backup Channels[/cyan]", f"{len(new_channels)} new, {len(existing_channels)} existing")
|
||||
|
||||
# Show categorized channel lists in the bottom log
|
||||
if new_channels:
|
||||
modal_prog.write("[bold green]New Backups to be created:[/bold green]")
|
||||
for idx, c in enumerate(new_channels):
|
||||
modal_prog.write(f" {idx+1}. #{c.name}")
|
||||
|
||||
if existing_channels:
|
||||
action = "Overwritten" if force_overwrite else "Updated"
|
||||
modal_prog.write(f"[bold yellow]\nExisting backups to be {action}:[/bold yellow]")
|
||||
for idx, c in enumerate(existing_channels):
|
||||
modal_prog.write(f" {idx+1}. #{c.name}")
|
||||
|
||||
choice = await modal_prog.phase_wait_confirm(btn_start_label="Start Channel Backup", show_id=False)
|
||||
if choice == "btn_back":
|
||||
modal_prog.dismiss()
|
||||
continue
|
||||
elif choice == "btn_start_id":
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
def id_callback(res: int | None) -> None:
|
||||
if not future.done():
|
||||
future.set_result(res)
|
||||
|
||||
id_modal = MessageIDInputModal(self.engine.discord_reader, selected_channels[0].id)
|
||||
self.app.push_screen(id_modal, id_callback)
|
||||
verified_id = await future
|
||||
|
||||
if verified_id is None:
|
||||
# User cancelled the ID input
|
||||
continue
|
||||
|
||||
after_id = verified_id
|
||||
elif choice == "btn_main_menu":
|
||||
modal_prog.dismiss()
|
||||
return
|
||||
|
||||
# If we are here, proceeding either via Start First or Start from ID (after_id)
|
||||
if choice == "btn_start_first":
|
||||
after_id = None
|
||||
break
|
||||
|
||||
modal_prog.phase_progress()
|
||||
modal_prog.show_stats()
|
||||
|
|
|
|||
|
|
@ -1,170 +0,0 @@
|
|||
import pytest
|
||||
import shutil
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from src.core.backup_database import BackupDatabase
|
||||
from src.core.configuration import AppConfig
|
||||
|
||||
LOG_FILE = (Path(__file__).parent.parent / ".reaper_tests.log").resolve()
|
||||
|
||||
def _log(msg: str):
|
||||
"""Write a message to the shared test log file."""
|
||||
with open(LOG_FILE, "a") as f:
|
||||
f.write(f"{msg}\n")
|
||||
|
||||
@pytest.fixture
|
||||
def log():
|
||||
"""Fixture to provide the logging function to tests."""
|
||||
return _log
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Register global warning filters and marks."""
|
||||
# Register marks if needed
|
||||
config.addinivalue_line("markers", "asyncio: mark test as asyncio")
|
||||
|
||||
# Silence benign async mock and ResourceWarnings globally
|
||||
config.addinivalue_line("filterwarnings", "ignore::RuntimeWarning")
|
||||
config.addinivalue_line("filterwarnings", "ignore::ResourceWarning")
|
||||
|
||||
def pytest_sessionstart(session):
|
||||
"""Clear the log file at the beginning of the test session."""
|
||||
with open(LOG_FILE, "w") as f:
|
||||
f.write("--- Reaper Test Session Started ---\n")
|
||||
|
||||
def pytest_report_header(config):
|
||||
"""Print data source status to the console header."""
|
||||
test_data_dir = (Path(__file__).parent.parent / "ReaperFiles-AutoTest").resolve()
|
||||
if test_data_dir.exists():
|
||||
return f"[DATA_SOURCE] Automated Test Directory: FOUND ({test_data_dir.name})"
|
||||
else:
|
||||
return "[DATA_SOURCE] Automated Test Directory: NOT FOUND (Falling back to MOCKS)"
|
||||
|
||||
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
|
||||
def pytest_runtest_makereport(item, call):
|
||||
"""Hook to catch test results and log them to .reaper_tests.log."""
|
||||
outcome = yield
|
||||
rep = outcome.get_result()
|
||||
|
||||
# We only log on the 'call' phase (the actual test run)
|
||||
if rep.when == "call":
|
||||
status = rep.outcome.upper()
|
||||
_log(f"RESULT: {item.nodeid} -> {status}")
|
||||
if rep.failed:
|
||||
_log(f"--- FAILURE DETAILS for {item.name} ---\n{rep.longreprtext}\n---------------------")
|
||||
elif not rep.passed:
|
||||
# Catch setup/teardown failures
|
||||
_log(f"ERROR in {item.nodeid} [{rep.when}]: {rep.outcome.upper()}")
|
||||
if rep.failed:
|
||||
_log(f"--- ERROR DETAILS ---\n{rep.longreprtext}\n---------------------")
|
||||
|
||||
def pytest_warning_recorded(warning_message, when, nodeid, location):
|
||||
"""Hook to catch and log warnings to .reaper_tests.log."""
|
||||
msg = f"WARNING: {warning_message.message}"
|
||||
if nodeid:
|
||||
msg = f"WARNING in {nodeid} [{when}]: {warning_message.message}"
|
||||
_log(msg)
|
||||
|
||||
@pytest.fixture
|
||||
def test_data_dir():
|
||||
return (Path(__file__).parent.parent / "ReaperFiles-AutoTest").resolve()
|
||||
|
||||
@pytest.fixture
|
||||
def reaper_config(test_data_dir):
|
||||
config_path = test_data_dir / "reaper_config.yaml"
|
||||
if config_path.exists():
|
||||
with open(config_path, "r") as f:
|
||||
data = yaml.safe_load(f)
|
||||
return data
|
||||
# Fallback mock config
|
||||
return {
|
||||
"discord_bot_token": "mock_discord_token",
|
||||
"discord_server_id": "123456789012345678",
|
||||
"tool_mode": "backup_transfer",
|
||||
"target_platform": "fluxer",
|
||||
"fluxer_bot_token": "mock_fluxer_token",
|
||||
"fluxer_server_id": "987654321098765432",
|
||||
"stoat_bot_token": "mock_stoat_token",
|
||||
"stoat_server_id": "MOCK_STOAT_COMMUNITY",
|
||||
"anonymize_users": False,
|
||||
"log_level": "DEBUG"
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db(test_data_dir, tmp_path, reaper_config):
|
||||
# If the real test data exists, use it
|
||||
if test_data_dir.exists():
|
||||
target_id = reaper_config.get("fluxer_server_id") if reaper_config.get("target_platform") == "fluxer" else reaper_config.get("stoat_server_id")
|
||||
db_files = list(test_data_dir.glob(f"*-{target_id}.db"))
|
||||
if not db_files:
|
||||
db_files = list(test_data_dir.glob("*.db"))
|
||||
|
||||
if db_files:
|
||||
original_db = db_files[0]
|
||||
temp_db_path = tmp_path / "test_migration.db"
|
||||
print(f"[DATA_SOURCE] USE_SAMPLE_DB: {original_db.name}")
|
||||
_log(f"[DATA_SOURCE] USE_SAMPLE_DB: {original_db.name}")
|
||||
shutil.copy(original_db, temp_db_path)
|
||||
return temp_db_path
|
||||
|
||||
# Fallback: Create an empty mock database with the required schema
|
||||
temp_db_path = tmp_path / "mock_migration.db"
|
||||
print("[DATA_SOURCE] USE_MOCK_DB: Fallback to empty schema")
|
||||
_log("[DATA_SOURCE] USE_MOCK_DB: Fallback to empty schema")
|
||||
db = BackupDatabase(temp_db_path)
|
||||
# The BackupDatabase constructor already initializes the schema
|
||||
return temp_db_path
|
||||
|
||||
@pytest.fixture
|
||||
def backup_db(temp_db):
|
||||
db = BackupDatabase(temp_db)
|
||||
yield db
|
||||
# No explicit close needed for now as it's a persistent connection
|
||||
|
||||
@pytest.fixture
|
||||
def backup_reader(test_data_dir, reaper_config, tmp_path):
|
||||
sid = reaper_config.get("discord_server_id")
|
||||
backup_path = test_data_dir / f"DISCORD_BACKUP-{sid}"
|
||||
|
||||
if not test_data_dir.exists() or not backup_path.exists():
|
||||
# Fallback: create mock backup structure
|
||||
mock_path = tmp_path / f"DISCORD_BACKUP-{sid}"
|
||||
print(f"[DATA_SOURCE] USE_MOCK_BACKUP: {mock_path.name}")
|
||||
_log(f"[DATA_SOURCE] USE_MOCK_BACKUP: {mock_path.name}")
|
||||
mock_path.mkdir(parents=True, exist_ok=True)
|
||||
db_path = mock_path / "backup.db"
|
||||
db = BackupDatabase(db_path)
|
||||
# Populate with minimal mock data for BackupReader to work
|
||||
db._conn.execute("INSERT OR IGNORE INTO guild_profile (id, name) VALUES (?, ?)", (int(sid), "Mock Guild"))
|
||||
db._conn.execute("INSERT OR IGNORE INTO channels (id, name, type) VALUES (?, ?, ?)", (123, "mock-channel", 0))
|
||||
db._conn.commit()
|
||||
from src.core.backup_reader import BackupReader
|
||||
return BackupReader(mock_path)
|
||||
|
||||
print(f"[DATA_SOURCE] USE_SAMPLE_BACKUP: {backup_path.name}")
|
||||
_log(f"[DATA_SOURCE] USE_SAMPLE_BACKUP: {backup_path.name}")
|
||||
from src.core.backup_reader import BackupReader
|
||||
return BackupReader(backup_path)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_reader():
|
||||
reader = MagicMock()
|
||||
reader.guild = MagicMock()
|
||||
reader.fetch_message_history = AsyncMock()
|
||||
reader.download_attachment = AsyncMock(return_value=b"fake_data")
|
||||
reader.download_sticker = AsyncMock(return_value=b"fake_sticker_data")
|
||||
return reader
|
||||
|
||||
@pytest.fixture
|
||||
def mock_fluxer_writer():
|
||||
writer = MagicMock()
|
||||
writer.send_message = AsyncMock(return_value="fluxer_msg_123")
|
||||
writer.send_marker = AsyncMock()
|
||||
return writer
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stoat_writer():
|
||||
writer = MagicMock()
|
||||
writer.send_message = AsyncMock(return_value="stoat_msg_123")
|
||||
writer.send_marker = AsyncMock()
|
||||
return writer
|
||||
|
|
@ -1,139 +0,0 @@
|
|||
import pytest
|
||||
from src.core.backup_database import BackupDatabase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db():
|
||||
from tests.conftest import _log
|
||||
msg = "[DATA_SOURCE] UNIT_TEST: Using in-memory database"
|
||||
print(msg)
|
||||
_log(msg)
|
||||
return BackupDatabase(":memory:")
|
||||
|
||||
|
||||
def test_db_initialization(db):
|
||||
res = db._conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='guild_profile'"
|
||||
).fetchone()
|
||||
assert res is not None
|
||||
assert res["name"] == "guild_profile"
|
||||
|
||||
|
||||
def test_save_get_guild_profile(db):
|
||||
profile = {
|
||||
"id": "123456789",
|
||||
"name": "Test Server",
|
||||
"description": "Testing",
|
||||
"owner_id": "987654321",
|
||||
"ignore_channels": ["111", "222"],
|
||||
}
|
||||
db.set_guild_profile(profile)
|
||||
got = db.get_guild_profile()
|
||||
assert got["name"] == "Test Server"
|
||||
assert got["id"] == 123456789
|
||||
assert got["ignore_channels"] == ["111", "222"]
|
||||
|
||||
|
||||
def test_save_get_roles(db):
|
||||
roles = [
|
||||
{"id": "1", "name": "Admin", "color": 0xFF0000, "position": 1, "permissions": 8, "hoist": True, "mentionable": True},
|
||||
{"id": "2", "name": "User", "color": 0x0000FF, "position": 2, "permissions": 0, "hoist": False, "mentionable": False},
|
||||
]
|
||||
db.save_roles(roles)
|
||||
got = db.get_all_roles()
|
||||
assert len(got) == 2
|
||||
assert {r["name"] for r in got} == {"Admin", "User"}
|
||||
|
||||
|
||||
def test_save_get_channels(db):
|
||||
channels = [
|
||||
{"id": 10, "name": "general", "type": 0, "position": 0, "category_id": None,
|
||||
"topic": "Talk", "nsfw": 0, "bitrate": None, "slowmode_delay": None},
|
||||
{"id": 20, "name": "voice", "type": 2, "position": 1, "category_id": None,
|
||||
"topic": None, "nsfw": 0, "bitrate": 64000, "slowmode_delay": None},
|
||||
]
|
||||
db.save_channels(channels)
|
||||
got = db.get_all_channels()
|
||||
assert len(got) == 2
|
||||
assert {c["name"] for c in got} == {"general", "voice"}
|
||||
|
||||
|
||||
def test_save_messages_and_attachments(db):
|
||||
messages = [
|
||||
{
|
||||
"id": 101, "channel_id": 10, "author_id": 999, "content": "Hello",
|
||||
"timestamp": "2023-01-01T00:00:00Z", "type": 0,
|
||||
"message_reference": None, "is_pinned": 0, "extra_data": None,
|
||||
"attachments": [
|
||||
{"id": 1, "filename": "file.png", "size": 100,
|
||||
"url": "http://cdn.test/file.png", "content_type": "image/png", "local_hash": "abc"}
|
||||
],
|
||||
}
|
||||
]
|
||||
db.save_messages_batch(messages)
|
||||
msg = db._conn.execute("SELECT content FROM messages WHERE id=101").fetchone()
|
||||
assert msg["content"] == "Hello"
|
||||
att = db._conn.execute("SELECT filename FROM attachments WHERE message_id=101").fetchone()
|
||||
assert att["filename"] == "file.png"
|
||||
|
||||
|
||||
def test_get_last_message_id(db):
|
||||
msgs = [
|
||||
{"id": 200, "channel_id": 10, "author_id": 1, "content": "a",
|
||||
"timestamp": "2023-01-01T00:00:00Z", "type": 0,
|
||||
"message_reference": None, "is_pinned": 0, "extra_data": None},
|
||||
{"id": 201, "channel_id": 10, "author_id": 1, "content": "b",
|
||||
"timestamp": "2023-01-01T00:01:00Z", "type": 0,
|
||||
"message_reference": None, "is_pinned": 0, "extra_data": None},
|
||||
]
|
||||
db.save_messages_batch(msgs)
|
||||
assert db.get_last_message_id("10") == 201
|
||||
|
||||
|
||||
def test_stats_by_channel(db):
|
||||
msgs = [
|
||||
{"id": 101, "channel_id": 10, "author_id": 1, "content": "Hi",
|
||||
"timestamp": "2023-01-01T00:00:00Z", "type": 0,
|
||||
"message_reference": None, "is_pinned": 0, "extra_data": None},
|
||||
{"id": 102, "channel_id": 10, "author_id": 1, "content": "Bye",
|
||||
"timestamp": "2023-01-01T00:01:00Z", "type": 0,
|
||||
"message_reference": None, "is_pinned": 0, "extra_data": None},
|
||||
]
|
||||
db.save_messages_batch(msgs)
|
||||
stats = db.get_stats_by_channel()
|
||||
assert stats[10]["message_count"] == 2
|
||||
|
||||
|
||||
def test_save_threads(db):
|
||||
threads = [
|
||||
{"id": 300, "name": "thread-1", "type": 11, "parent_id": 10,
|
||||
"message_count": 5, "member_count": 2, "archived": 0,
|
||||
"archive_timestamp": None, "auto_archive_duration": 1440,
|
||||
"locked": 0, "applied_tags": None},
|
||||
]
|
||||
db.save_threads(threads)
|
||||
got = db.get_threads_by_parent("10")
|
||||
assert len(got) == 1
|
||||
assert got[0]["name"] == "thread-1"
|
||||
|
||||
|
||||
def test_media_pool(db):
|
||||
db.add_media_to_pool("hash123", "/path/file.png", 512, "image/png", "http://cdn.test/file.png")
|
||||
db._conn.commit()
|
||||
entry = db.get_media_by_hash("hash123")
|
||||
assert entry is not None
|
||||
assert entry["local_path"] == "/path/file.png"
|
||||
|
||||
|
||||
def test_get_messages_paged_after_id(db):
|
||||
msgs = [
|
||||
{"id": i, "channel_id": 10, "author_id": 1, "content": f"msg{i}",
|
||||
"timestamp": f"2023-01-01T00:0{i}:00Z", "type": 0,
|
||||
"message_reference": None, "is_pinned": 0, "extra_data": None}
|
||||
for i in range(5)
|
||||
]
|
||||
db.save_messages_batch(msgs)
|
||||
page = db.get_messages_paged("10", limit=3, offset=0)
|
||||
assert len(page) == 3
|
||||
page_after = db.get_messages_paged("10", limit=10, after_id="2")
|
||||
assert all(m["id"] > 2 for m in page_after)
|
||||
|
|
@ -1,138 +0,0 @@
|
|||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from src.core.base import MigrationContext
|
||||
from src.core.configuration import AppConfig
|
||||
from src.core.backup_reader import ChannelType
|
||||
from src.fluxer.migrate_message import migrate_messages as fluxer_migrate, _process_and_send_message as fluxer_send
|
||||
from src.stoat.migrate_message import migrate_messages as stoat_migrate, _process_and_send_message as stoat_send
|
||||
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
# --- Platform Detection (Same as e2e_simulation) ---
|
||||
def get_platforms():
|
||||
"""Determine which platforms to test based on config, or use defaults."""
|
||||
config_path = (Path(__file__).parent.parent / "ReaperFiles-AutoTest/reaper_config.yaml").resolve()
|
||||
if not config_path.exists():
|
||||
return ["fluxer", "stoat"]
|
||||
with open(config_path, "r") as f:
|
||||
data = yaml.safe_load(f)
|
||||
platforms = []
|
||||
if data.get("fluxer_bot_token"): platforms.append("fluxer")
|
||||
if data.get("stoat_bot_token"): platforms.append("stoat")
|
||||
return platforms if platforms else ["fluxer", "stoat"]
|
||||
|
||||
# --- Unit Tests (Transformation Logic) ---
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context(mock_discord_reader, mock_fluxer_writer, mock_stoat_writer):
|
||||
context = MagicMock(spec=MigrationContext)
|
||||
context.discord_reader = mock_discord_reader
|
||||
context.fluxer_writer = mock_fluxer_writer
|
||||
context.stoat_writer = mock_stoat_writer
|
||||
context.state = MagicMock()
|
||||
context.state.get_user_alias.return_value = "TestAlias"
|
||||
context.state.emoji_map = {}
|
||||
context.state.channel_map = {}
|
||||
context.is_running = True
|
||||
return context
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message():
|
||||
msg = MagicMock()
|
||||
msg.id = 111
|
||||
msg.author.id = 222
|
||||
msg.author.display_name = "Author"
|
||||
msg.content = "Test content"
|
||||
msg.attachments = []
|
||||
msg.embeds = []
|
||||
msg.stickers = []
|
||||
msg.created_at.timestamp.return_value = 1600000000.0
|
||||
msg.flags.forwarded = False
|
||||
msg.reference = None
|
||||
return msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migration_transform_fluxer(mock_context, mock_message):
|
||||
stats = {"messages": 0, "attachments": 0}
|
||||
result = await fluxer_send(context=mock_context, msg=mock_message, target_channel_id="c1", stats=stats)
|
||||
assert result == "fluxer_msg_123"
|
||||
assert stats["messages"] == 1
|
||||
assert mock_context.fluxer_writer.send_message.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migration_transform_stoat(mock_context, mock_message):
|
||||
stats = {"messages": 0, "attachments": 0}
|
||||
result = await stoat_send(context=mock_context, msg=mock_message, target_channel_id="c1", stats=stats)
|
||||
assert result == "stoat_msg_123"
|
||||
assert stats["messages"] == 1
|
||||
assert mock_context.stoat_writer.send_message.called
|
||||
|
||||
# --- Integration Tests (Backup Reader) ---
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backup_reader_interaction(backup_reader, reaper_config):
|
||||
await backup_reader.start()
|
||||
assert backup_reader.guild is not None
|
||||
# Verify ID from config (or mock default)
|
||||
assert str(backup_reader.guild.id) == reaper_config.get("discord_server_id")
|
||||
|
||||
channels = await backup_reader.fetch_channels()
|
||||
assert len(channels) > 0
|
||||
|
||||
# --- E2E Simulation ---
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("platform", get_platforms())
|
||||
async def test_migration_e2e_loop(reaper_config, test_data_dir, tmp_path, platform, request):
|
||||
config = AppConfig(
|
||||
discord_bot_token=reaper_config["discord_bot_token"],
|
||||
discord_server_id=reaper_config["discord_server_id"],
|
||||
target_platform=platform,
|
||||
fluxer_bot_token=reaper_config.get("fluxer_bot_token"),
|
||||
fluxer_server_id=reaper_config.get("fluxer_server_id"),
|
||||
stoat_bot_token=reaper_config.get("stoat_bot_token"),
|
||||
stoat_server_id=reaper_config.get("stoat_server_id"),
|
||||
anonymize_users=reaper_config["anonymize_users"]
|
||||
)
|
||||
|
||||
if test_data_dir.exists():
|
||||
base_dir = test_data_dir
|
||||
msg = f"[DATA_SOURCE] E2E_SIMULATION: Using sample data from {base_dir.name}"
|
||||
else:
|
||||
base_dir = tmp_path
|
||||
msg = f"[DATA_SOURCE] E2E_SIMULATION: Fallback to mock data in {base_dir}"
|
||||
|
||||
from tests.conftest import _log
|
||||
print(msg)
|
||||
_log(msg)
|
||||
|
||||
context = MigrationContext(config, source_mode="backup", base_dir=str(base_dir))
|
||||
|
||||
mock_writer = request.getfixturevalue(f"mock_{platform}_writer")
|
||||
if platform == "fluxer":
|
||||
context.fluxer_writer = mock_writer
|
||||
migrate_func = fluxer_migrate
|
||||
else:
|
||||
context.stoat_writer = mock_writer
|
||||
migrate_func = stoat_migrate
|
||||
|
||||
context.is_running = True
|
||||
from src.core.database import MigrationDatabase
|
||||
context.state.db = MigrationDatabase(tmp_path / f"e2e_{platform}.db", platform=platform)
|
||||
|
||||
await context.discord_reader.start()
|
||||
channels = await context.discord_reader.fetch_channels()
|
||||
text_channels = [c for c in channels if c.type == ChannelType.text]
|
||||
|
||||
if text_channels:
|
||||
source_channel_id = text_channels[0].id
|
||||
target_channel_id = "999" if platform == "fluxer" else "stoat123"
|
||||
|
||||
mock_writer.send_message.side_effect = lambda **kwargs: "ok"
|
||||
|
||||
# Test just the first available channel to keep it fast
|
||||
stats = await migrate_func(context=context, source_channel_id=source_channel_id, target_channel_id=target_channel_id)
|
||||
assert stats["messages"] >= 0
|
||||
|
||||
await context.discord_reader.close()
|
||||
164
tests/test_ui.py
164
tests/test_ui.py
|
|
@ -1,164 +0,0 @@
|
|||
import pytest
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from textual.app import App
|
||||
from src.ui.main_app import ReaperApp, ConfigSelectionScreen, ConfigScreen
|
||||
from src.ui.mode_screen import ModeScreen
|
||||
from src.core.configuration import AppConfig
|
||||
|
||||
import os
|
||||
from textual.widgets import ListItem, ListView, Input, Button, Label
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_configs(tmp_path, log):
|
||||
reaper_dir = tmp_path / "ReaperFiles-TestConfig"
|
||||
reaper_dir.mkdir()
|
||||
(reaper_dir / "reaper_config.yaml").write_text("discord_bot_token: 'fake'\ndiscord_server_id: '123'\ntool_mode: 'backup_only'")
|
||||
|
||||
autotest_dir = tmp_path / "ReaperFiles-AutoTest"
|
||||
autotest_dir.mkdir()
|
||||
(autotest_dir / "reaper_config.yaml").write_text("discord_bot_token: 'fake'\ndiscord_server_id: '123'\ntool_mode: 'backup_transfer'")
|
||||
|
||||
old_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
log(f"CWD changed to: {tmp_path}")
|
||||
yield tmp_path
|
||||
os.chdir(old_cwd)
|
||||
|
||||
async def wait_for_screen(app, screen_class, timeout=5.0):
|
||||
import time
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
if isinstance(app.screen, screen_class):
|
||||
return True
|
||||
await asyncio.sleep(0.1)
|
||||
return False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_minimal_launch(mock_configs, log):
|
||||
"""Verify app launch and screen transition to ModeScreen."""
|
||||
log("Running test_ui_minimal_launch")
|
||||
try:
|
||||
# Reverting to AsyncMock to avoid WorkerError.
|
||||
# RuntimeWarnings are now handled globally in conftest.py.
|
||||
with patch("src.ui.main_app.ConfigSelectionScreen.check_updates", AsyncMock()):
|
||||
app = ReaperApp()
|
||||
async with app.run_test() as pilot:
|
||||
await wait_for_screen(app, ConfigSelectionScreen)
|
||||
await pilot.click(ListItem)
|
||||
await wait_for_screen(app, ModeScreen)
|
||||
assert isinstance(app.screen, ModeScreen)
|
||||
log("test_ui_minimal_launch PASSED")
|
||||
except Exception as e:
|
||||
log(f"test_ui_minimal_launch FAILED: {e}")
|
||||
raise
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_config_wizard_save(mock_configs, log):
|
||||
"""Verify configuration editing and saving."""
|
||||
log("Running test_ui_config_wizard_save")
|
||||
try:
|
||||
with patch("src.ui.main_app.ConfigSelectionScreen.check_updates", AsyncMock()):
|
||||
app = ReaperApp()
|
||||
async with app.run_test() as pilot:
|
||||
await wait_for_screen(app, ConfigSelectionScreen)
|
||||
await pilot.click(ListItem)
|
||||
await wait_for_screen(app, ModeScreen)
|
||||
await pilot.click("#btn_config")
|
||||
await wait_for_screen(app, ConfigScreen)
|
||||
|
||||
screen = app.screen
|
||||
inp = screen.query_one("#inp_discord_token", Input)
|
||||
inp.value = "new_fake_token"
|
||||
|
||||
with patch("src.ui.main_app.save_config") as mock_save:
|
||||
await pilot.click("#btn_save")
|
||||
await pilot.pause(0.2)
|
||||
assert mock_save.called
|
||||
log("test_ui_config_wizard_save PASSED")
|
||||
except Exception as e:
|
||||
log(f"test_ui_config_wizard_save FAILED: {e}")
|
||||
raise
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_operation_trigger(mock_configs, log):
|
||||
"""Verify that an operation can be triggered."""
|
||||
log("Running test_ui_operation_trigger")
|
||||
from src.ui.shuttle_ops import OperationPane
|
||||
from src.ui.modals import ChannelPickerScreen, ProgressScreen
|
||||
try:
|
||||
with patch("src.ui.main_app.ConfigSelectionScreen.check_updates", AsyncMock()):
|
||||
# run_validate is decorated with @work in shuttle_ops.py, so it MUST be a coroutine (AsyncMock)
|
||||
with patch.object(OperationPane, "run_validate", AsyncMock()):
|
||||
app = ReaperApp()
|
||||
async with app.run_test() as pilot:
|
||||
await wait_for_screen(app, ConfigSelectionScreen)
|
||||
await pilot.click(ListItem)
|
||||
await wait_for_screen(app, ModeScreen)
|
||||
|
||||
pane = app.screen.query_one(OperationPane)
|
||||
pane.tokens_valid = True
|
||||
pane.src_channels = [{"id": 1, "name": "t"}]
|
||||
pane.src_cat_map = {None: "D"}
|
||||
pane.tgt_channels = [{"id": 2, "name": "t"}]
|
||||
pane.tgt_cat_map = {None: "D"}
|
||||
pane.all_tgt_channels = pane.tgt_channels
|
||||
|
||||
btn = pane.query_one("#op_backup_msgs", Button)
|
||||
btn.disabled = False
|
||||
await pilot.pause(0.2)
|
||||
btn.focus()
|
||||
await pilot.press("enter")
|
||||
|
||||
await wait_for_screen(app, (ChannelPickerScreen, ProgressScreen))
|
||||
assert isinstance(app.screen, (ChannelPickerScreen, ProgressScreen))
|
||||
log("test_ui_operation_trigger PASSED")
|
||||
except Exception as e:
|
||||
log(f"test_ui_operation_trigger FAILED: {e}")
|
||||
raise
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_autotest_button(mock_configs, log):
|
||||
"""Verify visibility and trigger of the AUTO TEST button."""
|
||||
log("Running test_ui_autotest_button")
|
||||
from src.ui.shuttle_ops import OperationPane
|
||||
try:
|
||||
with patch("src.ui.main_app.ConfigSelectionScreen.check_updates", AsyncMock()):
|
||||
with patch.object(OperationPane, "run_validate", AsyncMock()):
|
||||
app = ReaperApp()
|
||||
async with app.run_test() as pilot:
|
||||
await wait_for_screen(app, ConfigSelectionScreen)
|
||||
# Find the AutoTest item in the list
|
||||
lv = app.screen.query_one(ListView)
|
||||
autotest_index = -1
|
||||
for idx, item in enumerate(lv.children):
|
||||
if item.name == "AutoTest":
|
||||
autotest_index = idx
|
||||
break
|
||||
|
||||
assert autotest_index != -1, "AutoTest profile not found in ListView"
|
||||
target_item = lv.children[autotest_index]
|
||||
await pilot.click(target_item)
|
||||
assert await wait_for_screen(app, ModeScreen), "Timed out waiting for ModeScreen"
|
||||
|
||||
# Verify button is present
|
||||
pane = app.screen.query_one(OperationPane)
|
||||
btn = pane.query_one("#op_autotest", Button)
|
||||
assert btn.display is True
|
||||
assert "AUTO TEST" in str(btn.label)
|
||||
|
||||
# Mock the sequence and trigger it
|
||||
with patch.object(OperationPane, "run_autotest_sequence", AsyncMock()) as mock_seq:
|
||||
btn.disabled = False
|
||||
await pilot.pause(0.1)
|
||||
btn.focus()
|
||||
await pilot.press("enter")
|
||||
await pilot.pause(0.1)
|
||||
assert mock_seq.called
|
||||
log("test_ui_autotest_button PASSED")
|
||||
except Exception as e:
|
||||
log(f"test_ui_autotest_button FAILED: {e}")
|
||||
raise
|
||||
|
|
@ -1,60 +0,0 @@
|
|||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from src.core.utils import parse_snowflake, resolve_discord_links
|
||||
|
||||
def test_parse_snowflake_valid():
|
||||
assert parse_snowflake("12345") == 12345
|
||||
assert parse_snowflake(12345) == 12345
|
||||
assert parse_snowflake(" 67890 ") == 67890
|
||||
|
||||
def test_parse_snowflake_invalid():
|
||||
assert parse_snowflake(None) is None
|
||||
assert parse_snowflake("") is None
|
||||
assert parse_snowflake("none") is None
|
||||
assert parse_snowflake("NULL") is None
|
||||
assert parse_snowflake("not_a_number") is None
|
||||
|
||||
def test_resolve_discord_links_no_content():
|
||||
assert resolve_discord_links("", None, "fluxer", "target_id") == ""
|
||||
assert resolve_discord_links(None, None, "fluxer", "target_id") is None
|
||||
|
||||
def test_resolve_discord_links_no_mapping():
|
||||
mock_state = MagicMock()
|
||||
mock_state.get_target_channel_id.return_value = None
|
||||
mock_state.get_target_category_id.return_value = None
|
||||
mock_state.find_message_mapping.return_value = (None, None)
|
||||
|
||||
content = "Check this: https://discord.com/channels/1/2/3"
|
||||
resolved = resolve_discord_links(content, mock_state, "fluxer", "target_server")
|
||||
assert "[`discord-message`](<https://discord.com/channels/1/2/3>)" in resolved
|
||||
|
||||
def test_resolve_discord_links_channel_mapping():
|
||||
mock_state = MagicMock()
|
||||
mock_state.get_target_channel_id.return_value = "target_chan_456"
|
||||
mock_state.find_message_mapping.return_value = (None, None)
|
||||
|
||||
content = "Go to https://discord.com/channels/123/456"
|
||||
|
||||
# Test Fluxer
|
||||
resolved_fluxer = resolve_discord_links(content, mock_state, "fluxer", "target_server")
|
||||
assert "https://fluxer.app/channels/target_server/target_chan_456" in resolved_fluxer
|
||||
|
||||
# Test Stoat
|
||||
resolved_stoat = resolve_discord_links(content, mock_state, "stoat", "target_server")
|
||||
assert "https://stoat.chat/server/target_server/channel/target_chan_456" in resolved_stoat
|
||||
|
||||
def test_resolve_discord_links_message_mapping():
|
||||
mock_state = MagicMock()
|
||||
mock_state.find_message_mapping.return_value = ("target_chan_456", "target_msg_789")
|
||||
|
||||
content = "Look at this: https://discord.com/channels/123/456/789"
|
||||
|
||||
# Test Fluxer
|
||||
resolved_fluxer = resolve_discord_links(content, mock_state, "fluxer", "target_server")
|
||||
assert "https://fluxer.app/channels/target_server/target_chan_456/target_msg_789" in resolved_fluxer
|
||||
|
||||
def test_resolve_discord_links_skips_wrapped():
|
||||
mock_state = MagicMock()
|
||||
content = "Already wrapped: [link](https://discord.com/channels/1/2/3)"
|
||||
resolved = resolve_discord_links(content, mock_state, "fluxer", "target_server")
|
||||
assert resolved == content
|
||||
Loading…
Add table
Reference in a new issue