Compare commits
18 commits
fb773ba948
...
0898003749
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0898003749 | ||
|
|
7fd487ad66 | ||
|
|
9f27af971c | ||
|
|
d63ed0c440 | ||
|
|
e5ff2ae37f | ||
|
|
90ced9df65 | ||
|
|
2d33b16d2c | ||
|
|
bd80760667 | ||
|
|
2ef141776c | ||
|
|
2ddc4424cd | ||
|
|
071de5296b | ||
|
|
5b315ab2bf | ||
|
|
ef2e945477 | ||
|
|
514a2e551c | ||
|
|
73d52d2183 | ||
|
|
0cb678b848 | ||
|
|
3f649b3062 | ||
|
|
c581911f68 |
25 changed files with 2196 additions and 741 deletions
5
.gitignore
vendored
5
.gitignore
vendored
|
|
@ -21,6 +21,7 @@ wheels/
|
||||||
.installed.cfg
|
.installed.cfg
|
||||||
*.egg
|
*.egg
|
||||||
*.txt
|
*.txt
|
||||||
|
*.exe
|
||||||
|
|
||||||
# Virtual Environment
|
# Virtual Environment
|
||||||
venv/
|
venv/
|
||||||
|
|
@ -40,11 +41,11 @@ reaper_config.yaml
|
||||||
*.log
|
*.log
|
||||||
|
|
||||||
# Temporary Test Scripts
|
# Temporary Test Scripts
|
||||||
|
#test_*.py
|
||||||
tmp/
|
tmp/
|
||||||
test_*.py
|
|
||||||
test_release.zip
|
test_release.zip
|
||||||
test_release/
|
test_release/
|
||||||
DiscoReaper-*
|
DiscoReaper
|
||||||
*.zip
|
*.zip
|
||||||
|
|
||||||
# App data files
|
# App data files
|
||||||
|
|
|
||||||
14
README.md
14
README.md
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
**DiscoReaper** is a powerful tool designed to help you migrate your entire Discord server to Fluxer or Stoat. It clones channels, roles, emojis, permissions, and also your community's full message history.
|
**DiscoReaper** is a powerful tool designed to help you migrate your entire Discord server to Fluxer or Stoat. It clones channels, roles, emojis, permissions, and also your community's full message history.
|
||||||
|
|
||||||
### Get it here: [](https://github.com/rambros3d/disco-reaper/releases/latest/download/disco-reaper-linux.zip) [](https://github.com/rambros3d/disco-reaper/releases/latest/download/disco-reaper-windows.zip) [](https://github.com/rambros3d/disco-reaper/releases/latest/download/disco-reaper-macos.zip)
|
### Get it here: [](https://git.mithraic.cloud/ad3laid3/disco-reaper/releases/latest) [](https://git.mithraic.cloud/ad3laid3/disco-reaper/releases/latest)
|
||||||
|
|
||||||
|
|
||||||
>Join our [**Reaper Community**](https://fluxer.gg/9KxDP8WH) if you need help or have any questions.
|
>Join our [**Reaper Community**](https://fluxer.gg/9KxDP8WH) if you need help or have any questions.
|
||||||
|
|
@ -81,7 +81,7 @@ Migrate directly from Discord or your Local Backups to the target community.
|
||||||
#### Setup the bots as per this [guide](BOT-SETUP.md)
|
#### Setup the bots as per this [guide](BOT-SETUP.md)
|
||||||
|
|
||||||
### Option 1: Using Pre-built Binaries (Easiest)
|
### Option 1: Using Pre-built Binaries (Easiest)
|
||||||
1. **Download**: Grab the latest version from the [Releases](https://github.com/rambros3d/disco-reaper/releases) page.
|
1. **Download**: Grab the latest version from the [Releases](https://git.mithraic.cloud/ad3laid3/disco-reaper/releases) page.
|
||||||
2. **Run**:
|
2. **Run**:
|
||||||
- **Linux**: Run the `disco-reaper` binary (e.g., `./launch.sh` or double-click).
|
- **Linux**: Run the `disco-reaper` binary (e.g., `./launch.sh` or double-click).
|
||||||
- **Windows**: Run `disco-reaper.exe`.
|
- **Windows**: Run `disco-reaper.exe`.
|
||||||
|
|
@ -89,7 +89,7 @@ Migrate directly from Discord or your Local Backups to the target community.
|
||||||
### Option 2: Running from Source (To use latest unstable code)
|
### Option 2: Running from Source (To use latest unstable code)
|
||||||
1. **Clone**: Clone the repository to your local machine:
|
1. **Clone**: Clone the repository to your local machine:
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/rambros3d/disco-reaper.git
|
git clone https://git.mithraic.cloud/ad3laid3/disco-reaper.git
|
||||||
cd disco-reaper
|
cd disco-reaper
|
||||||
```
|
```
|
||||||
2. **Launch**: Run the appropriate launcher script for your OS. It will automatically create a virtual environment and install dependencies:
|
2. **Launch**: Run the appropriate launcher script for your OS. It will automatically create a virtual environment and install dependencies:
|
||||||
|
|
@ -153,10 +153,4 @@ But now their own website states that **Persona** will be used in some countries
|
||||||
|
|
||||||
## Contributors
|
## Contributors
|
||||||
|
|
||||||
<a href="https://github.com/rambros3d/disco-reaper/graphs/contributors">
|
MiTHRAL — fork maintainer, Stoat/Revolt integration & bug fixes.
|
||||||
<img src="https://contrib.rocks/image?repo=rambros3d/disco-reaper" />
|
|
||||||
</a>
|
|
||||||
|
|
||||||
Made with [contrib.rocks](https://contrib.rocks).
|
|
||||||
|
|
||||||
[](https://www.star-history.com/?repos=rambros3d%2Fdisco-reaper&type=date&legend=top-left)
|
|
||||||
|
|
|
||||||
77
build.bat
Normal file
77
build.bat
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
@echo off
|
||||||
|
setlocal enabledelayedexpansion
|
||||||
|
cd /d "%~dp0"
|
||||||
|
|
||||||
|
echo --- Disco-Reaper Windows Build Script ---
|
||||||
|
|
||||||
|
REM Check for venv
|
||||||
|
IF NOT EXIST "venv" (
|
||||||
|
echo Creating virtual environment...
|
||||||
|
python -m venv venv
|
||||||
|
IF ERRORLEVEL 1 (
|
||||||
|
echo Error: Failed to create virtual environment.
|
||||||
|
echo Ensure Python is installed and added to PATH.
|
||||||
|
IF NOT DEFINED AUTO_BUILD pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
echo Activating virtual environment...
|
||||||
|
call venv\Scripts\activate.bat
|
||||||
|
|
||||||
|
REM Self-healing pip check
|
||||||
|
python -m pip --version >nul 2>&1
|
||||||
|
IF ERRORLEVEL 1 (
|
||||||
|
echo Warning: pip is missing or broken in venv. Attempting repair...
|
||||||
|
python -m ensurepip --default-pip
|
||||||
|
IF ERRORLEVEL 1 (
|
||||||
|
echo Error: Failed to repair pip automatically.
|
||||||
|
echo Try recreating the venv: rmdir /s /q venv ^&^& python -m venv venv
|
||||||
|
IF NOT DEFINED AUTO_BUILD pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
echo Ensuring build dependencies are up to date...
|
||||||
|
REM python -m pip install --upgrade pip --quiet
|
||||||
|
python -m pip install pyinstaller --quiet
|
||||||
|
python -m pip install -r requirements.txt --quiet
|
||||||
|
|
||||||
|
echo Cleaning previous build artifacts...
|
||||||
|
IF EXIST "build" rmdir /s /q build
|
||||||
|
IF EXIST "dist" rmdir /s /q dist
|
||||||
|
|
||||||
|
echo Starting PyInstaller build...
|
||||||
|
REM Get git version tag
|
||||||
|
set "GIT_VERSION=Unknown"
|
||||||
|
for /f "tokens=*" %%i in ('git describe --tags --abbrev^=0 2^>nul') do set "GIT_VERSION=%%i"
|
||||||
|
echo Baking version: %GIT_VERSION%
|
||||||
|
echo __version__ = "%GIT_VERSION%"> src\core\_baked_version.py
|
||||||
|
|
||||||
|
python -m PyInstaller --clean disco-reaper.spec
|
||||||
|
IF ERRORLEVEL 1 (
|
||||||
|
echo Error: PyInstaller build failed.
|
||||||
|
del /f src\core\_baked_version.py 2>nul
|
||||||
|
IF NOT DEFINED AUTO_BUILD pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
echo Cleaning up baked version file...
|
||||||
|
del /f src\core\_baked_version.py 2>nul
|
||||||
|
|
||||||
|
echo Packaging release: disco-reaper-windows.zip...
|
||||||
|
cd dist
|
||||||
|
powershell -Command "Compress-Archive -Path 'DiscoReaper.exe' -DestinationPath 'disco-reaper-windows.zip' -Force" 2>nul
|
||||||
|
IF ERRORLEVEL 1 (
|
||||||
|
echo Warning: Failed to create zip. Files are available in dist\ directory.
|
||||||
|
) ELSE (
|
||||||
|
echo Package created: dist\disco-reaper-windows.zip
|
||||||
|
)
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
echo -----------------------------------
|
||||||
|
echo Build complete!
|
||||||
|
echo Standalone executable: dist\DiscoReaper.exe
|
||||||
|
echo Release Package: dist\disco-reaper-windows.zip
|
||||||
|
echo ---
|
||||||
|
IF NOT DEFINED AUTO_BUILD pause
|
||||||
|
|
@ -1,17 +1,18 @@
|
||||||
import sys
|
import sys
|
||||||
import logging
|
import logging
|
||||||
|
from logging.handlers import RotatingFileHandler
|
||||||
from src.ui.main_app import run_disco_reaper_tui
|
from src.ui.main_app import run_disco_reaper_tui
|
||||||
from src.core.configuration import load_config
|
from src.core.configuration import load_config
|
||||||
|
|
||||||
def setup_logging():
|
def setup_logging():
|
||||||
try:
|
try:
|
||||||
config = load_config(create_if_missing=False)
|
config = load_config(create_if_missing=False)
|
||||||
log_level_str = config.migration.log_level.upper()
|
log_level_str = config.log_level.upper()
|
||||||
level = getattr(logging, log_level_str, logging.INFO)
|
level = getattr(logging, log_level_str, logging.INFO)
|
||||||
except Exception:
|
except Exception:
|
||||||
level = logging.INFO
|
level = logging.INFO
|
||||||
|
|
||||||
handlers = [logging.FileHandler('.reaper.log', mode='a')]
|
handlers = [RotatingFileHandler('.reaper.log', mode='w', maxBytes=10*1024*1024, backupCount=3)]
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
|
format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
|
||||||
datefmt='%H:%M:%S',
|
datefmt='%H:%M:%S',
|
||||||
|
|
@ -92,18 +93,21 @@ def cleanup_old_update():
|
||||||
"""Removes the .old executable left behind by a Windows update."""
|
"""Removes the .old executable left behind by a Windows update."""
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
if sys.platform != "win32":
|
if sys.platform != "win32":
|
||||||
return
|
return
|
||||||
|
|
||||||
current_exe = sys.executable if getattr(sys, 'frozen', False) else sys.argv[0]
|
# In frozen (PyInstaller) builds, sys.executable points to the temp _MEIxxxxx dir.
|
||||||
old_exe = current_exe + ".old"
|
# sys.argv[0] always points to the real .exe on disk, so use that and resolve() it.
|
||||||
|
current_exe = Path(sys.argv[0]).resolve()
|
||||||
|
old_exe = current_exe.with_suffix(current_exe.suffix + ".old")
|
||||||
|
|
||||||
if os.path.exists(old_exe):
|
if old_exe.exists():
|
||||||
try:
|
try:
|
||||||
os.remove(old_exe)
|
old_exe.unlink()
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logging.getLogger(__name__).debug(f"Could not remove old update file {old_exe}: {e}")
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
|
|
@ -9,3 +9,6 @@ lottie # Lottie file manipulation and conversion
|
||||||
Pillow # Image processing (required for GIF rendering)
|
Pillow # Image processing (required for GIF rendering)
|
||||||
cairosvg # SVG rendering (required for Lottie conversion)
|
cairosvg # SVG rendering (required for Lottie conversion)
|
||||||
psutil # System information (CPU, RAM, etc.)
|
psutil # System information (CPU, RAM, etc.)
|
||||||
|
pytest # Testing framework
|
||||||
|
pytest-asyncio # Async testing for pytest
|
||||||
|
pytest-mock # Mocking for pytest
|
||||||
21
runtests.sh
Executable file
21
runtests.sh
Executable file
|
|
@ -0,0 +1,21 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Convenient script to run the Discord Reaper test suite.
|
||||||
|
# Automatically handles VENV activation and PYTHONPATH.
|
||||||
|
|
||||||
|
# Change to the project root directory (where the script is located)
|
||||||
|
cd "$(dirname "$0")"
|
||||||
|
|
||||||
|
# Activate the virtual environment if it exists
|
||||||
|
if [ -d "venv" ]; then
|
||||||
|
source venv/bin/activate
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Set PYTHONPATH to the current directory so src is discoverable
|
||||||
|
export PYTHONPATH=.
|
||||||
|
|
||||||
|
# Run pytest with any arguments passed to this script, defaulting to tests/
|
||||||
|
if [ $# -eq 0 ]; then
|
||||||
|
pytest -v -s -p no:warnings tests/
|
||||||
|
else
|
||||||
|
pytest -s "$@"
|
||||||
|
fi
|
||||||
|
|
@ -98,9 +98,9 @@ class BackupDatabase:
|
||||||
elif table == "permissions":
|
elif table == "permissions":
|
||||||
conn.execute("CREATE TABLE permissions (id INTEGER PRIMARY KEY AUTOINCREMENT, channel_id INTEGER, target_id INTEGER, target_type TEXT, allow INTEGER, deny INTEGER)")
|
conn.execute("CREATE TABLE permissions (id INTEGER PRIMARY KEY AUTOINCREMENT, channel_id INTEGER, target_id INTEGER, target_type TEXT, allow INTEGER, deny INTEGER)")
|
||||||
elif table == "users":
|
elif table == "users":
|
||||||
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT, display_name TEXT, avatar_file TEXT, avatar_url TEXT, roles TEXT)")
|
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT, display_name TEXT, avatar_file TEXT, avatar_url TEXT, roles TEXT, type INTEGER DEFAULT 0)")
|
||||||
elif table == "messages":
|
elif table == "messages":
|
||||||
conn.execute("CREATE TABLE messages (id INTEGER PRIMARY KEY, channel_id INTEGER, author_id INTEGER, content TEXT, timestamp TEXT, type INTEGER, message_reference INTEGER, is_pinned INTEGER, extra_data TEXT)")
|
conn.execute("CREATE TABLE messages (id INTEGER PRIMARY KEY, channel_id INTEGER, author_id INTEGER, content TEXT, timestamp TEXT, type INTEGER, message_reference INTEGER, is_pinned INTEGER, extra_data TEXT, custom_display_name TEXT, custom_avatar_url TEXT)")
|
||||||
elif table == "attachments":
|
elif table == "attachments":
|
||||||
conn.execute("CREATE TABLE attachments (id INTEGER PRIMARY KEY, message_id INTEGER, filename TEXT, size INTEGER, url TEXT, content_type TEXT, local_hash TEXT)")
|
conn.execute("CREATE TABLE attachments (id INTEGER PRIMARY KEY, message_id INTEGER, filename TEXT, size INTEGER, url TEXT, content_type TEXT, local_hash TEXT)")
|
||||||
elif table == "embeds":
|
elif table == "embeds":
|
||||||
|
|
@ -114,7 +114,7 @@ class BackupDatabase:
|
||||||
elif table == "forum_tags":
|
elif table == "forum_tags":
|
||||||
conn.execute("CREATE TABLE forum_tags (id INTEGER PRIMARY KEY, forum_id INTEGER, name TEXT, moderated INTEGER, emoji_id INTEGER, emoji_name TEXT)")
|
conn.execute("CREATE TABLE forum_tags (id INTEGER PRIMARY KEY, forum_id INTEGER, name TEXT, moderated INTEGER, emoji_id INTEGER, emoji_name TEXT)")
|
||||||
elif table == "server_assets":
|
elif table == "server_assets":
|
||||||
conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type INTEGER)")
|
conn.execute("CREATE TABLE server_assets (id INTEGER PRIMARY KEY, name TEXT, type TEXT, filename TEXT, url TEXT, content_type TEXT)")
|
||||||
|
|
||||||
old_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table}_old)").fetchall()]
|
old_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table}_old)").fetchall()]
|
||||||
new_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table})").fetchall()]
|
new_cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table})").fetchall()]
|
||||||
|
|
@ -124,6 +124,25 @@ class BackupDatabase:
|
||||||
conn.execute(f"INSERT INTO {table} ({col_str}) SELECT {col_str} FROM {table}_old")
|
conn.execute(f"INSERT INTO {table} ({col_str}) SELECT {col_str} FROM {table}_old")
|
||||||
conn.execute(f"DROP TABLE {table}_old")
|
conn.execute(f"DROP TABLE {table}_old")
|
||||||
|
|
||||||
|
# 3. Custom Author Profile Migration
|
||||||
|
res = conn.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='messages'").fetchone()
|
||||||
|
if res and res[0] > 0:
|
||||||
|
cols = conn.execute("PRAGMA table_info(messages)").fetchall()
|
||||||
|
col_names = [c["name"] for c in cols]
|
||||||
|
if "custom_display_name" not in col_names:
|
||||||
|
logger.info("Migrating messages: adding custom author profile columns")
|
||||||
|
conn.execute("ALTER TABLE messages ADD COLUMN custom_display_name TEXT")
|
||||||
|
conn.execute("ALTER TABLE messages ADD COLUMN custom_avatar_url TEXT")
|
||||||
|
|
||||||
|
# 4. User Type Categorization Migration
|
||||||
|
res = conn.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='users'").fetchone()
|
||||||
|
if res and res[0] > 0:
|
||||||
|
cols = conn.execute("PRAGMA table_info(users)").fetchall()
|
||||||
|
col_names = [c["name"] for c in cols]
|
||||||
|
if "type" not in col_names:
|
||||||
|
logger.info("Migrating users: adding type column")
|
||||||
|
conn.execute("ALTER TABLE users ADD COLUMN type INTEGER DEFAULT 0")
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def _init_db(self):
|
def _init_db(self):
|
||||||
|
|
@ -196,7 +215,8 @@ class BackupDatabase:
|
||||||
display_name TEXT,
|
display_name TEXT,
|
||||||
avatar_file TEXT,
|
avatar_file TEXT,
|
||||||
avatar_url TEXT,
|
avatar_url TEXT,
|
||||||
roles TEXT
|
roles TEXT,
|
||||||
|
type INTEGER DEFAULT 0
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
@ -211,7 +231,9 @@ class BackupDatabase:
|
||||||
type INTEGER,
|
type INTEGER,
|
||||||
message_reference INTEGER,
|
message_reference INTEGER,
|
||||||
is_pinned INTEGER,
|
is_pinned INTEGER,
|
||||||
extra_data TEXT
|
extra_data TEXT,
|
||||||
|
custom_display_name TEXT,
|
||||||
|
custom_avatar_url TEXT
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_messages_channel ON messages(channel_id)")
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_messages_channel ON messages(channel_id)")
|
||||||
|
|
@ -405,8 +427,8 @@ class BackupDatabase:
|
||||||
"""Saves users to the author cache."""
|
"""Saves users to the author cache."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._conn.executemany("""
|
self._conn.executemany("""
|
||||||
INSERT OR REPLACE INTO users (id, username, display_name, avatar_file, avatar_url, roles)
|
INSERT OR REPLACE INTO users (id, username, display_name, avatar_file, avatar_url, roles, type)
|
||||||
VALUES (:id, :username, :display_name, :avatar_file, :avatar_url, :roles)
|
VALUES (:id, :username, :display_name, :avatar_file, :avatar_url, :roles, :type)
|
||||||
""", users)
|
""", users)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
|
|
||||||
|
|
@ -454,8 +476,8 @@ class BackupDatabase:
|
||||||
conn = self._conn
|
conn = self._conn
|
||||||
# Insert messages
|
# Insert messages
|
||||||
conn.executemany("""
|
conn.executemany("""
|
||||||
INSERT OR REPLACE INTO messages (id, channel_id, author_id, content, timestamp, type, message_reference, is_pinned, extra_data)
|
INSERT OR REPLACE INTO messages (id, channel_id, author_id, content, timestamp, type, message_reference, is_pinned, extra_data, custom_display_name, custom_avatar_url)
|
||||||
VALUES (:id, :channel_id, :author_id, :content, :timestamp, :type, :message_reference, :is_pinned, :extra_data)
|
VALUES (:id, :channel_id, :author_id, :content, :timestamp, :type, :message_reference, :is_pinned, :extra_data, :custom_display_name, :custom_avatar_url)
|
||||||
""", messages)
|
""", messages)
|
||||||
|
|
||||||
# Extract attachments, reactions, and stickers
|
# Extract attachments, reactions, and stickers
|
||||||
|
|
@ -945,6 +967,60 @@ class BackupDatabase:
|
||||||
|
|
||||||
return purged_count
|
return purged_count
|
||||||
|
|
||||||
|
def get_backed_up_channel_ids(self) -> List[int]:
|
||||||
|
"""Returns a list of distinct channel IDs that have messages in the database."""
|
||||||
|
with self._lock:
|
||||||
|
rows = self._conn.execute("SELECT DISTINCT channel_id FROM messages").fetchall()
|
||||||
|
return [parse_snowflake(r[0]) for r in rows if parse_snowflake(r[0])]
|
||||||
|
|
||||||
|
def get_message_with_relations(self, message_id) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Fetches a single message with its attachments, embeds, reactions, and stickers."""
|
||||||
|
with self._lock:
|
||||||
|
mid = parse_snowflake(message_id)
|
||||||
|
row = self._conn.execute("SELECT * FROM messages WHERE id = ?", (mid,)).fetchone()
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
data = dict(row)
|
||||||
|
|
||||||
|
# Attachments
|
||||||
|
atts = self._conn.execute("SELECT * FROM attachments WHERE message_id = ?", (mid,)).fetchall()
|
||||||
|
data["attachments"] = [dict(a) for a in atts]
|
||||||
|
|
||||||
|
# Embeds
|
||||||
|
embs = self._conn.execute("SELECT * FROM embeds WHERE message_id = ?", (mid,)).fetchall()
|
||||||
|
data["embeds"] = []
|
||||||
|
for er in embs:
|
||||||
|
e_dict = {
|
||||||
|
"title": er["title"],
|
||||||
|
"description": er["description"],
|
||||||
|
"url": er["url"],
|
||||||
|
"color": er["color"],
|
||||||
|
"timestamp": er["timestamp"],
|
||||||
|
"thumbnail": {"url": er["thumbnail_url"]} if er["thumbnail_url"] else None,
|
||||||
|
"image": {"url": er["image_url"]} if er["image_url"] else None,
|
||||||
|
"author": {
|
||||||
|
"name": er["author_name"],
|
||||||
|
"url": er["author_url"],
|
||||||
|
"icon_url": er["author_icon_url"]
|
||||||
|
} if er["author_name"] else None,
|
||||||
|
"footer": {
|
||||||
|
"text": er["footer_text"],
|
||||||
|
"icon_url": er["footer_icon_url"]
|
||||||
|
} if er["footer_text"] else None,
|
||||||
|
"fields": json.loads(er["fields"]) if er["fields"] else []
|
||||||
|
}
|
||||||
|
data["embeds"].append(e_dict)
|
||||||
|
|
||||||
|
# Reactions
|
||||||
|
reas = self._conn.execute("SELECT * FROM reactions WHERE message_id = ?", (mid,)).fetchall()
|
||||||
|
data["reactions"] = [dict(r) for r in reas]
|
||||||
|
|
||||||
|
# Stickers
|
||||||
|
sts = self._conn.execute("SELECT * FROM message_stickers WHERE message_id = ?", (mid,)).fetchall()
|
||||||
|
data["stickers"] = [dict(s) for s in sts]
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Commits any pending writes and closes the connection."""
|
"""Commits any pending writes and closes the connection."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
|
|
||||||
|
|
@ -393,6 +393,20 @@ class BackupMember:
|
||||||
# Fallback for unexpected data format
|
# Fallback for unexpected data format
|
||||||
self.id = 0
|
self.id = 0
|
||||||
self.name = "Unknown"
|
self.name = "Unknown"
|
||||||
|
self.display_name = "Unknown"
|
||||||
|
self.global_name = "Unknown"
|
||||||
|
self.bot = False
|
||||||
|
self.system = False
|
||||||
|
self.discriminator = "0000"
|
||||||
|
self.color = BackupColor(0)
|
||||||
|
self.roles = sorted(role_objects or [], key=lambda r: r.position, reverse=True)
|
||||||
|
self.guild_permissions = BackupPermissions(0)
|
||||||
|
self.created_at = datetime.now(timezone.utc)
|
||||||
|
self.joined_at = datetime.now(timezone.utc)
|
||||||
|
self.status = type("Status", (), {"value": "offline"})()
|
||||||
|
self.activity = None
|
||||||
|
self._avatar_url = None
|
||||||
|
self.avatar = BackupAsset(None)
|
||||||
return
|
return
|
||||||
self.id = parse_snowflake(data["id"])
|
self.id = parse_snowflake(data["id"])
|
||||||
self.name = data.get("username", "Unknown")
|
self.name = data.get("username", "Unknown")
|
||||||
|
|
@ -516,12 +530,13 @@ class BackupEmoji:
|
||||||
class BackupSticker:
|
class BackupSticker:
|
||||||
"""Minimal stand-in for discord.GuildSticker."""
|
"""Minimal stand-in for discord.GuildSticker."""
|
||||||
|
|
||||||
__slots__ = ("id", "name", "url", "format", "_backup_root", "_file_path")
|
__slots__ = ("id", "name", "url", "format", "_backup_root", "_file_path", "local_hash")
|
||||||
|
|
||||||
def __init__(self, data: dict, backup_root: Path | None = None, media_pool: dict | None = None):
|
def __init__(self, data: dict, backup_root: Path | None = None, media_pool: dict | None = None):
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
self.id = 0
|
self.id = 0
|
||||||
self.name = "Sticker"
|
self.name = "Sticker"
|
||||||
|
self.local_hash = None
|
||||||
return
|
return
|
||||||
self.id = parse_snowflake(data.get("id") or data.get("sticker_id", 0)) or 0
|
self.id = parse_snowflake(data.get("id") or data.get("sticker_id", 0)) or 0
|
||||||
self.name = data.get("name", "Sticker")
|
self.name = data.get("name", "Sticker")
|
||||||
|
|
@ -536,14 +551,14 @@ class BackupSticker:
|
||||||
self._backup_root = backup_root
|
self._backup_root = backup_root
|
||||||
|
|
||||||
# 1. Check if it's a CAS-based sticker (from message_stickers table)
|
# 1. Check if it's a CAS-based sticker (from message_stickers table)
|
||||||
local_hash = data.get("local_hash")
|
self.local_hash = data.get("local_hash")
|
||||||
if local_hash and backup_root:
|
if self.local_hash and backup_root:
|
||||||
ext = ".png"
|
ext = ".png"
|
||||||
if self.format == StickerFormatType.lottie: ext = ".json"
|
if self.format == StickerFormatType.lottie: ext = ".json"
|
||||||
elif self.format == StickerFormatType.apng: ext = ".png"
|
elif self.format == StickerFormatType.apng: ext = ".png"
|
||||||
elif self.format == StickerFormatType.gif: ext = ".gif"
|
elif self.format == StickerFormatType.gif: ext = ".gif"
|
||||||
|
|
||||||
self._file_path = backup_root / "attachments" / f"{local_hash}{ext}"
|
self._file_path = backup_root / "attachments" / f"{self.local_hash}{ext}"
|
||||||
# 2. Check if it's a server asset sticker (legacy or manual save)
|
# 2. Check if it's a server asset sticker (legacy or manual save)
|
||||||
elif data.get("filename") and backup_root:
|
elif data.get("filename") and backup_root:
|
||||||
self._file_path = backup_root / "server_assets" / data["filename"]
|
self._file_path = backup_root / "server_assets" / data["filename"]
|
||||||
|
|
@ -1266,11 +1281,7 @@ class BackupReader:
|
||||||
async def get_backed_up_channel_ids(self) -> List[int]:
|
async def get_backed_up_channel_ids(self) -> List[int]:
|
||||||
"""Returns a list of channel IDs that have messages in the database."""
|
"""Returns a list of channel IDs that have messages in the database."""
|
||||||
if not self.db: return []
|
if not self.db: return []
|
||||||
import sqlite3
|
return self.db.get_backed_up_channel_ids()
|
||||||
conn = sqlite3.connect(self.db.db_path)
|
|
||||||
rows = conn.execute("SELECT DISTINCT channel_id FROM messages").fetchall()
|
|
||||||
conn.close()
|
|
||||||
return [parse_snowflake(r[0]) for r in rows if parse_snowflake(r[0])]
|
|
||||||
|
|
||||||
async def get_channel(self, channel_id: int) -> BackupChannel | BackupThread | None:
|
async def get_channel(self, channel_id: int) -> BackupChannel | BackupThread | None:
|
||||||
for c in self.channels:
|
for c in self.channels:
|
||||||
|
|
@ -1328,6 +1339,18 @@ class BackupReader:
|
||||||
user_id = parse_snowflake(msg_data.get("author_id", 0)) or 0
|
user_id = parse_snowflake(msg_data.get("author_id", 0)) or 0
|
||||||
author = self._resolve_author(user_id)
|
author = self._resolve_author(user_id)
|
||||||
|
|
||||||
|
# Check for custom author profile (Webhooks / Masquerade)
|
||||||
|
over_name = msg_data.get("custom_display_name")
|
||||||
|
over_avatar = msg_data.get("custom_avatar_url")
|
||||||
|
if over_name:
|
||||||
|
# Create an ephemeral author object for this message
|
||||||
|
author = BackupMember({
|
||||||
|
"id": str(user_id),
|
||||||
|
"username": over_name,
|
||||||
|
"display_name": over_name,
|
||||||
|
"avatar_url": over_avatar
|
||||||
|
})
|
||||||
|
|
||||||
self._ensure_media_pool_loaded()
|
self._ensure_media_pool_loaded()
|
||||||
|
|
||||||
channel_id = parse_snowflake(msg_data["channel_id"])
|
channel_id = parse_snowflake(msg_data["channel_id"])
|
||||||
|
|
@ -1351,23 +1374,9 @@ class BackupReader:
|
||||||
async def get_message(self, channel_id: int, message_id: int) -> BackupMessage | None:
|
async def get_message(self, channel_id: int, message_id: int) -> BackupMessage | None:
|
||||||
"""Fetch a specific message from SQLite."""
|
"""Fetch a specific message from SQLite."""
|
||||||
if not self.db: return None
|
if not self.db: return None
|
||||||
import sqlite3
|
data = self.db.get_message_with_relations(message_id)
|
||||||
conn = sqlite3.connect(self.db.db_path)
|
if data:
|
||||||
conn.row_factory = sqlite3.Row
|
|
||||||
row = conn.execute("SELECT * FROM messages WHERE id = ?", (str(message_id),)).fetchone()
|
|
||||||
if row:
|
|
||||||
data = dict(row)
|
|
||||||
# Fetch attachments
|
|
||||||
atts = conn.execute("SELECT * FROM attachments WHERE message_id = ?", (str(message_id),)).fetchall()
|
|
||||||
data["attachments"] = [dict(a) for a in atts]
|
|
||||||
|
|
||||||
# Fetch stickers
|
|
||||||
sts = conn.execute("SELECT * FROM message_stickers WHERE message_id = ?", (str(message_id),)).fetchall()
|
|
||||||
data["stickers"] = [dict(s) for s in sts]
|
|
||||||
|
|
||||||
conn.close()
|
|
||||||
return self._hydrate_message(data)
|
return self._hydrate_message(data)
|
||||||
conn.close()
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_first_message(self, channel_id: int) -> BackupMessage | None:
|
async def get_first_message(self, channel_id: int) -> BackupMessage | None:
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,13 @@ class MigrationContext:
|
||||||
self.target_platform = target_platform or config.target_platform or "fluxer"
|
self.target_platform = target_platform or config.target_platform or "fluxer"
|
||||||
self.state = MigrationState()
|
self.state = MigrationState()
|
||||||
|
|
||||||
|
# Apply config-based log level dynamically to the root logger
|
||||||
|
if hasattr(self.config, "log_level") and self.config.log_level:
|
||||||
|
import logging
|
||||||
|
level = getattr(logging, self.config.log_level.upper(), logging.INFO)
|
||||||
|
logging.getLogger().setLevel(level)
|
||||||
|
logger.info(f"Log level updated to {self.config.log_level.upper()}")
|
||||||
|
|
||||||
# Select the appropriate source reader
|
# Select the appropriate source reader
|
||||||
if source_mode == "backup":
|
if source_mode == "backup":
|
||||||
from src.core.backup_reader import BackupReader
|
from src.core.backup_reader import BackupReader
|
||||||
|
|
@ -37,8 +44,8 @@ class MigrationContext:
|
||||||
|
|
||||||
# Build the writer for the active target platform only
|
# Build the writer for the active target platform only
|
||||||
if self.target_platform == "stoat":
|
if self.target_platform == "stoat":
|
||||||
token = config.stoat_bot_token or ""
|
token = config.stoat_bot_token
|
||||||
community_id = config.stoat_server_id or ""
|
community_id = config.stoat_server_id
|
||||||
api_url = config.stoat_api_url or "default"
|
api_url = config.stoat_api_url or "default"
|
||||||
self.writer = StoatWriter(token=token, community_id=community_id, api_url=api_url)
|
self.writer = StoatWriter(token=token, community_id=community_id, api_url=api_url)
|
||||||
self.stoat_writer = self.writer
|
self.stoat_writer = self.writer
|
||||||
|
|
@ -97,11 +104,17 @@ class MigrationContext:
|
||||||
}
|
}
|
||||||
|
|
||||||
# CONSISTENCY: Once target metadata is known, initialize the flat SQLite DB.
|
# CONSISTENCY: Once target metadata is known, initialize the flat SQLite DB.
|
||||||
if results["target_community"] and results["target_community_name"]:
|
if results["target_community"]:
|
||||||
tid = self.config.fluxer_server_id if self.target_platform == "fluxer" else self.config.stoat_server_id
|
tid = self.config.fluxer_server_id if self.target_platform == "fluxer" else self.config.stoat_server_id
|
||||||
|
|
||||||
|
# Prefer the original discord community name for the DB file if available (e.g. from live load or backup)
|
||||||
|
db_name = results.get("discord_server_name")
|
||||||
|
if not db_name or db_name == "Not Found" or db_name == "Unknown":
|
||||||
|
db_name = results.get("target_community_name") or "Unknown"
|
||||||
|
|
||||||
self.ensure_state_initialized(
|
self.ensure_state_initialized(
|
||||||
str(tid or ""),
|
str(tid or ""),
|
||||||
results["target_community_name"]
|
db_name
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
@ -120,6 +133,23 @@ class MigrationContext:
|
||||||
return
|
return
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
import json
|
||||||
|
|
||||||
|
# Override the target name explicitly with the original Discord source name if available.
|
||||||
|
# This fixes naming collisions and UI confusion like "Fluxer-123456.db" instead of "MyServer-123456.db"
|
||||||
|
try:
|
||||||
|
if hasattr(self.discord_reader, "guild") and getattr(self.discord_reader, "guild", None):
|
||||||
|
community_name = getattr(self.discord_reader, "guild").name
|
||||||
|
elif getattr(self, "source_mode", "live") == "backup" and hasattr(self.discord_reader, "backup_dir"):
|
||||||
|
b_dir = getattr(self.discord_reader, "backup_dir")
|
||||||
|
if b_dir and b_dir.exists():
|
||||||
|
meta_file = b_dir / "metadata.json"
|
||||||
|
if meta_file.exists():
|
||||||
|
data = json.loads(meta_file.read_text())
|
||||||
|
community_name = data.get("name", community_name)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
clean_name = re.sub(r'[^\w\s-]', '', community_name).strip()
|
clean_name = re.sub(r'[^\w\s-]', '', community_name).strip()
|
||||||
clean_name = re.sub(r'[-\s]+', '_', clean_name)
|
clean_name = re.sub(r'[-\s]+', '_', clean_name)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import logging
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Dict, Any, Union
|
from typing import Optional, Dict, Any, List, Union
|
||||||
import threading
|
import threading
|
||||||
import sys
|
import sys
|
||||||
from src.core.utils import parse_snowflake
|
from src.core.utils import parse_snowflake
|
||||||
|
|
@ -560,6 +560,15 @@ class MigrationDatabase:
|
||||||
conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (str(channel_id),))
|
conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (str(channel_id),))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logger.info(f"Cleared all tracking and mapping data for channel: {channel_id}")
|
logger.info(f"Cleared all tracking and mapping data for channel: {channel_id}")
|
||||||
|
def clear_all_migration_data(self):
|
||||||
|
"""Purge all mappings and tracking data for ALL channels and threads."""
|
||||||
|
conn = self._get_conn()
|
||||||
|
conn.execute("DELETE FROM message_mappings")
|
||||||
|
conn.execute("DELETE FROM thread_mappings")
|
||||||
|
conn.execute("DELETE FROM channel_tracking")
|
||||||
|
conn.execute("DELETE FROM thread_tracking")
|
||||||
|
conn.commit()
|
||||||
|
logger.info("Cleared ALL tracking and message mapping data globally.")
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if hasattr(self._local, "conn"):
|
if hasattr(self._local, "conn"):
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ class DiscordExporter:
|
||||||
self.server_name = ""
|
self.server_name = ""
|
||||||
self.server_id = ""
|
self.server_id = ""
|
||||||
self.user_cache = {}
|
self.user_cache = {}
|
||||||
|
self.member_cache: Dict[int, Any] = {} # Pre-fetched member objects (id -> Member)
|
||||||
self.base_dir = Path(base_dir) if base_dir else Path(".")
|
self.base_dir = Path(base_dir) if base_dir else Path(".")
|
||||||
self.is_running = True
|
self.is_running = True
|
||||||
self.db: Optional[BackupDatabase] = None
|
self.db: Optional[BackupDatabase] = None
|
||||||
|
|
@ -62,6 +63,20 @@ class DiscordExporter:
|
||||||
hash_sha256.update(chunk)
|
hash_sha256.update(chunk)
|
||||||
return hash_sha256.hexdigest()
|
return hash_sha256.hexdigest()
|
||||||
|
|
||||||
|
async def prefetch_members(self):
|
||||||
|
"""Pre-fetches all guild members into a local cache for role resolution.
|
||||||
|
|
||||||
|
msg.author is a discord.User (no roles). This cache allows us to
|
||||||
|
resolve roles without an API call per message during message export.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
members = await self.reader.get_members()
|
||||||
|
self.member_cache = {m.id: m for m in members}
|
||||||
|
logger.info(f"Pre-fetched {len(self.member_cache)} members for role resolution.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not pre-fetch members (roles will be empty): {e}")
|
||||||
|
self.member_cache = {}
|
||||||
|
|
||||||
async def export_metadata(self):
|
async def export_metadata(self):
|
||||||
"""Saves server metadata to the SQLite database."""
|
"""Saves server metadata to the SQLite database."""
|
||||||
metadata = await self.reader.get_server_metadata()
|
metadata = await self.reader.get_server_metadata()
|
||||||
|
|
@ -439,37 +454,65 @@ class DiscordExporter:
|
||||||
|
|
||||||
return accumulated_count, accumulated_threads, accumulated_files
|
return accumulated_count, accumulated_threads, accumulated_files
|
||||||
|
|
||||||
async def _format_user(self, user):
|
async def _format_user(self, user, is_webhook=False):
|
||||||
"""Formats user data for the author or a mention.
|
"""Formats user data for the author or a mention.
|
||||||
|
|
||||||
Avatar downloads are intentionally deferred to keep this off the hot
|
For Webhooks, we use a generic name and the default Discord avatar system
|
||||||
message-formatting path. Call _flush_pending_avatars() after each batch.
|
for the base profile in the user cache.
|
||||||
"""
|
"""
|
||||||
user_id = str(user.id)
|
user_id_int = int(user.id)
|
||||||
|
user_id = str(user_id_int)
|
||||||
|
|
||||||
if user_id in self.user_cache:
|
if user_id in self.user_cache:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
username = user.name
|
||||||
|
display_name = getattr(user, "display_name", user.name)
|
||||||
|
avatar = user.avatar
|
||||||
|
avatar_url = str(user.display_avatar.url) if user.display_avatar else None
|
||||||
|
|
||||||
|
if is_webhook:
|
||||||
|
# For webhooks, we use the ID as the username for technical clarity,
|
||||||
|
# and the current name as the display name.
|
||||||
|
username = user_id
|
||||||
|
display_name = user.name
|
||||||
|
# Discord default avatar formula: (ID >> 22) % 5
|
||||||
|
default_index = (user_id_int >> 22) % 5
|
||||||
|
avatar_url = f"https://cdn.discordapp.com/embed/avatars/{default_index}.png"
|
||||||
|
avatar = None # Don't download character avatar as the "base" webhook avatar
|
||||||
|
|
||||||
# New user discovered — schedule avatar download but don't block here
|
# New user discovered — schedule avatar download but don't block here
|
||||||
avatar_file = None
|
avatar_file = None
|
||||||
if user.avatar:
|
if avatar:
|
||||||
av_name = f"{user_id}.png"
|
av_name = f"{user_id}.png"
|
||||||
av_target = self.users_path / av_name
|
av_target = self.users_path / av_name
|
||||||
avatar_file = f"users/{av_name}"
|
avatar_file = f"users/{av_name}"
|
||||||
if not av_target.exists():
|
if not av_target.exists():
|
||||||
# Queue for deferred download
|
# Queue for deferred download
|
||||||
self._pending_avatars.append((user_id, user.avatar, av_target))
|
self._pending_avatars.append((user_id, avatar, av_target))
|
||||||
|
|
||||||
roles = []
|
roles = []
|
||||||
if hasattr(user, "roles"):
|
if hasattr(user, "roles"):
|
||||||
roles = [str(r.id) for r in user.roles if not r.is_default()]
|
roles = [str(r.id) for r in user.roles if not r.is_default()]
|
||||||
|
|
||||||
|
# Determine user type
|
||||||
|
# 0: Regular User, 1: Bot, 2: Webhook, 3: System
|
||||||
|
u_type = 0
|
||||||
|
if is_webhook:
|
||||||
|
u_type = 2
|
||||||
|
elif getattr(user, "system", False):
|
||||||
|
u_type = 3
|
||||||
|
elif getattr(user, "bot", False):
|
||||||
|
u_type = 1
|
||||||
|
|
||||||
user_data = {
|
user_data = {
|
||||||
"id": user_id,
|
"id": user_id,
|
||||||
"username": user.name,
|
"username": username,
|
||||||
"display_name": getattr(user, "display_name", user.name),
|
"display_name": display_name,
|
||||||
"avatar_file": avatar_file,
|
"avatar_file": avatar_file,
|
||||||
"avatar_url": str(user.display_avatar.url) if user.avatar else None,
|
"avatar_url": avatar_url,
|
||||||
"roles": json.dumps(roles)
|
"roles": json.dumps(roles),
|
||||||
|
"type": u_type
|
||||||
}
|
}
|
||||||
self.user_cache[user_id] = user_data
|
self.user_cache[user_id] = user_data
|
||||||
return user_data
|
return user_data
|
||||||
|
|
@ -494,13 +537,21 @@ class DiscordExporter:
|
||||||
new_users = []
|
new_users = []
|
||||||
|
|
||||||
# 1. Author handling
|
# 1. Author handling
|
||||||
u_data = await self._format_user(msg.author)
|
is_webhook = bool(getattr(msg, "webhook_id", None))
|
||||||
|
author = msg.author
|
||||||
|
# msg.author is discord.User (no roles). Resolve to Member for role data.
|
||||||
|
if not is_webhook:
|
||||||
|
member = self.member_cache.get(msg.author.id)
|
||||||
|
if member:
|
||||||
|
author = member
|
||||||
|
u_data = await self._format_user(author, is_webhook=is_webhook)
|
||||||
if u_data: new_users.append(u_data)
|
if u_data: new_users.append(u_data)
|
||||||
|
|
||||||
# 1.5 Mentions handling (ensure all mentioned users are saved)
|
# 1.5 Mentions handling (ensure all mentioned users are saved)
|
||||||
if msg.mentions:
|
if msg.mentions:
|
||||||
for mention in msg.mentions:
|
for mention in msg.mentions:
|
||||||
u_ment = await self._format_user(mention)
|
# Mentions can be Member objects already, so roles work naturally
|
||||||
|
u_ment = await self._format_user(mention, is_webhook=False)
|
||||||
if u_ment: new_users.append(u_ment)
|
if u_ment: new_users.append(u_ment)
|
||||||
|
|
||||||
# 2. Attachments handling (Content-Addressable Storage)
|
# 2. Attachments handling (Content-Addressable Storage)
|
||||||
|
|
@ -603,6 +654,15 @@ class DiscordExporter:
|
||||||
for s_emb in snapshot.embeds:
|
for s_emb in snapshot.embeds:
|
||||||
embeds.append(s_emb.to_dict())
|
embeds.append(s_emb.to_dict())
|
||||||
|
|
||||||
|
# 5.6 Author Overrides (Webhooks / Masquerade)
|
||||||
|
custom_display_name = None
|
||||||
|
custom_avatar_url = None
|
||||||
|
|
||||||
|
# Webhooks or bots with masquerade often use per-message names/avatars
|
||||||
|
if getattr(msg, "webhook_id", None) or (msg.author and msg.author.bot):
|
||||||
|
custom_display_name = msg.author.name
|
||||||
|
custom_avatar_url = str(msg.author.display_avatar.url) if msg.author.display_avatar else None
|
||||||
|
|
||||||
m_data = {
|
m_data = {
|
||||||
"id": str(msg.id),
|
"id": str(msg.id),
|
||||||
"channel_id": str(msg.channel.id),
|
"channel_id": str(msg.channel.id),
|
||||||
|
|
@ -616,7 +676,9 @@ class DiscordExporter:
|
||||||
"stickers": stickers,
|
"stickers": stickers,
|
||||||
"embeds": embeds,
|
"embeds": embeds,
|
||||||
"reactions": reactions,
|
"reactions": reactions,
|
||||||
"extra_data": None
|
"extra_data": None,
|
||||||
|
"custom_display_name": custom_display_name,
|
||||||
|
"custom_avatar_url": custom_avatar_url
|
||||||
}
|
}
|
||||||
|
|
||||||
return m_data, new_users
|
return m_data, new_users
|
||||||
|
|
|
||||||
|
|
@ -232,21 +232,17 @@ class MigrationState:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_global_min_last_message_id(self, all_mapped_ids: List[str]) -> int | None:
|
def get_global_min_last_message_id(self, all_mapped_ids: list[str]) -> int | None:
|
||||||
"""Returns the absolute minimum last_msg_id among the given list of mapped target IDs (channels and threads)."""
|
"""Returns the absolute minimum last_msg_id among the given list of mapped target IDs (channels and threads)."""
|
||||||
if self._ensure_db():
|
if self._ensure_db():
|
||||||
return self.db.get_global_min_last_message_id(all_mapped_ids)
|
return self.db.get_global_min_last_message_id(all_mapped_ids)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def set_waterfall_last_id(self, last_id: str | int):
|
|
||||||
if self.db:
|
|
||||||
self.db.set_metadata("waterfall_last_id", str(last_id))
|
|
||||||
|
|
||||||
def get_waterfall_last_id(self) -> int | None:
|
def clear_all_migration_data(self):
|
||||||
if self.db:
|
"""Clears all message mapping and tracking state globally."""
|
||||||
val = self.db.get_metadata("waterfall_last_id")
|
if self._ensure_db():
|
||||||
return int(val) if val else None
|
self.db.clear_all_migration_data()
|
||||||
return None
|
|
||||||
|
|
||||||
def get_all_last_message_ids(self) -> Dict[str, str]:
|
def get_all_last_message_ids(self) -> Dict[str, str]:
|
||||||
"""Returns a combined map of channel_id/thread_id -> last_msg_id."""
|
"""Returns a combined map of channel_id/thread_id -> last_msg_id."""
|
||||||
|
|
|
||||||
|
|
@ -10,9 +10,7 @@ from src.core.utils import get_app_version
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
REPO_OWNER = "rambros3d"
|
API_URL = "https://git.mithraic.cloud/api/v1/repos/ad3laid3/disco-reaper/releases"
|
||||||
REPO_NAME = "disco-reaper"
|
|
||||||
API_URL = f"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/releases"
|
|
||||||
|
|
||||||
def get_current_version() -> str:
|
def get_current_version() -> str:
|
||||||
"""Returns the current version string, e.g., '1.0.0'. Strips 'Reaper-' and 'v'."""
|
"""Returns the current version string, e.g., '1.0.0'. Strips 'Reaper-' and 'v'."""
|
||||||
|
|
@ -52,7 +50,7 @@ async def check_for_updates() -> Optional[Dict[str, Any]]:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(API_URL, headers={"Accept": "application/vnd.github.v3+json"}) as resp:
|
async with session.get(API_URL) as resp:
|
||||||
if resp.status == 200:
|
if resp.status == 200:
|
||||||
releases = await resp.json()
|
releases = await resp.json()
|
||||||
if not isinstance(releases, list) or not releases:
|
if not isinstance(releases, list) or not releases:
|
||||||
|
|
|
||||||
|
|
@ -15,41 +15,52 @@ async def sync_channel_state(context: MigrationContext):
|
||||||
channels = await context.discord_reader.get_channels()
|
channels = await context.discord_reader.get_channels()
|
||||||
fluxer_channels = await context.fluxer_writer.get_channels()
|
fluxer_channels = await context.fluxer_writer.get_channels()
|
||||||
|
|
||||||
# Build name -> id map and ID set for Fluxer for fast lookup
|
# Build maps for Fluxer lookup
|
||||||
fluxer_name_map = {c.get("name"): str(c.get("id")) for c in fluxer_channels if c.get("name")}
|
# {name: id} for categories
|
||||||
fluxer_id_set = {str(c.get("id")) for c in fluxer_channels}
|
fluxer_cats = {c.get("name"): str(c.get("id")) for c in fluxer_channels if c.get("type") == 4}
|
||||||
|
# {parent_id: {name: id}} for channels
|
||||||
|
fluxer_structure = {}
|
||||||
|
for c in fluxer_channels:
|
||||||
|
if c.get("type") == 4: continue
|
||||||
|
p_id = str(c.get("parent_id")) if c.get("parent_id") else "root"
|
||||||
|
if p_id not in fluxer_structure: fluxer_structure[p_id] = {}
|
||||||
|
fluxer_structure[p_id][c.get("name")] = str(c.get("id"))
|
||||||
|
|
||||||
|
fluxer_id_set = {str(c.get("id")) for c in fluxer_channels}
|
||||||
updates = 0
|
updates = 0
|
||||||
removals = 0
|
removals = 0
|
||||||
|
|
||||||
# 1. Verify and Sync Categories
|
# 1. Sync Categories
|
||||||
for cat in categories:
|
for cat in categories:
|
||||||
discord_id = str(cat.id)
|
discord_id = str(cat.id)
|
||||||
fluxer_id = context.state.get_fluxer_category_id(discord_id)
|
fluxer_id = context.state.get_fluxer_category_id(discord_id)
|
||||||
|
|
||||||
if fluxer_id:
|
if fluxer_id:
|
||||||
if fluxer_id not in fluxer_id_set:
|
if fluxer_id not in fluxer_id_set:
|
||||||
context.state.remove_category_mapping(discord_id)
|
context.state.remove_category_mapping(discord_id)
|
||||||
removals += 1
|
removals += 1
|
||||||
elif cat.name in fluxer_name_map:
|
elif cat.name in fluxer_cats:
|
||||||
context.state.set_category_mapping(discord_id, fluxer_name_map[cat.name])
|
context.state.set_category_mapping(discord_id, fluxer_cats[cat.name])
|
||||||
updates += 1
|
updates += 1
|
||||||
|
|
||||||
# 2. Verify and Sync Channels
|
# 2. Sync Channels (parent-aware)
|
||||||
for ch in channels:
|
for ch in channels:
|
||||||
discord_id = str(ch.id)
|
discord_id = str(ch.id)
|
||||||
fluxer_id = context.state.get_fluxer_channel_id(discord_id)
|
fluxer_id = context.state.get_fluxer_channel_id(discord_id)
|
||||||
|
|
||||||
if fluxer_id:
|
if fluxer_id:
|
||||||
if fluxer_id not in fluxer_id_set:
|
if fluxer_id not in fluxer_id_set:
|
||||||
context.state.remove_channel_mapping(discord_id)
|
context.state.remove_channel_mapping(discord_id)
|
||||||
removals += 1
|
removals += 1
|
||||||
elif ch.name in fluxer_name_map:
|
else:
|
||||||
context.state.set_channel_mapping(discord_id, fluxer_name_map[ch.name])
|
# Try to match by name within the mapped parent category
|
||||||
|
p_discord_id = str(ch.category_id) if ch.category_id else "root"
|
||||||
|
p_fluxer_id = context.state.get_fluxer_category_id(p_discord_id) if p_discord_id != "root" else "root"
|
||||||
|
|
||||||
|
if p_fluxer_id in fluxer_structure and ch.name in fluxer_structure[p_fluxer_id]:
|
||||||
|
context.state.set_channel_mapping(discord_id, fluxer_structure[p_fluxer_id][ch.name])
|
||||||
updates += 1
|
updates += 1
|
||||||
|
|
||||||
if updates > 0 or removals > 0:
|
if updates > 0 or removals > 0:
|
||||||
logger.info(f"Channel sync: {updates} mapped, {removals} stale mappings removed")
|
logger.info(f"Fluxer Channel sync: {updates} mapped, {removals} stale mappings removed")
|
||||||
|
|
||||||
|
|
||||||
async def migrate_channels(context: MigrationContext, progress_callback: Callable[[str, str, int, int], Awaitable[None]] | None = None, force: bool = False) -> dict:
|
async def migrate_channels(context: MigrationContext, progress_callback: Callable[[str, str, int, int], Awaitable[None]] | None = None, force: bool = False) -> dict:
|
||||||
|
|
|
||||||
|
|
@ -158,7 +158,190 @@ async def get_channel_threads(reader: Any, channel_id: int) -> List[Any]:
|
||||||
return threads
|
return threads
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def _process_and_send_message(
|
||||||
|
context: MigrationContext,
|
||||||
|
msg: Any,
|
||||||
|
target_channel_id: str,
|
||||||
|
stats: Dict[str, Any],
|
||||||
|
thread_id: str | None = None,
|
||||||
|
parent_target_id: str | None = None,
|
||||||
|
thread_name: str | None = None,
|
||||||
|
processed_threads: set | None = None
|
||||||
|
) -> str | None:
|
||||||
|
"""
|
||||||
|
Internal helper to process a single Discord message (mentions, attachments, stickers)
|
||||||
|
and send it to the Fluxer platform.
|
||||||
|
"""
|
||||||
|
# 1. Formatting
|
||||||
|
content = msg.content or ""
|
||||||
|
|
||||||
|
# Check for forwarded flag
|
||||||
|
is_forwarded = False
|
||||||
|
if hasattr(msg.flags, 'forwarded'):
|
||||||
|
is_forwarded = msg.flags.forwarded
|
||||||
|
|
||||||
|
# Always ensure alias is created/retrieved to populate user_alias table
|
||||||
|
alias = context.state.get_user_alias(str(msg.author.id))
|
||||||
|
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
|
||||||
|
|
||||||
|
# Process Stickers
|
||||||
|
files = []
|
||||||
|
if hasattr(msg, 'stickers') and msg.stickers:
|
||||||
|
for s in msg.stickers:
|
||||||
|
try:
|
||||||
|
sticker_data = await context.discord_reader.download_sticker(s)
|
||||||
|
if not sticker_data: continue
|
||||||
|
|
||||||
|
format_val = getattr(s, 'format', 'png')
|
||||||
|
if hasattr(format_val, 'name'):
|
||||||
|
ext = format_val.name.lower()
|
||||||
|
elif isinstance(format_val, int):
|
||||||
|
format_map = {1: 'png', 2: 'apng', 3: 'lottie', 4: 'gif'}
|
||||||
|
ext = format_map.get(format_val, 'png')
|
||||||
|
else:
|
||||||
|
ext = str(format_val).lower()
|
||||||
|
|
||||||
|
# Conversion logic (Simplified for unification)
|
||||||
|
if ext == 'lottie' and HAS_LOTTIE:
|
||||||
|
try:
|
||||||
|
lottie_data = json.loads(sticker_data)
|
||||||
|
def _convert_lottie(data):
|
||||||
|
anim = Animation.load(data)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
export_gif(anim, buf)
|
||||||
|
buf.seek(0)
|
||||||
|
return buf
|
||||||
|
gif_buf = await asyncio.to_thread(_convert_lottie, lottie_data)
|
||||||
|
from PIL import Image
|
||||||
|
def _convert_gif_to_webp(buf):
|
||||||
|
img = Image.open(buf)
|
||||||
|
w_buf = io.BytesIO()
|
||||||
|
if getattr(img, 'n_frames', 1) > 1:
|
||||||
|
img.save(w_buf, format='WEBP', save_all=True, loop=0, quality=80)
|
||||||
|
else:
|
||||||
|
img.save(w_buf, format='WEBP', quality=80)
|
||||||
|
return w_buf.getvalue()
|
||||||
|
sticker_data = await asyncio.to_thread(_convert_gif_to_webp, gif_buf)
|
||||||
|
ext = 'webp'
|
||||||
|
except Exception: ext = 'json'
|
||||||
|
elif ext in ('apng', 'gif'):
|
||||||
|
try:
|
||||||
|
from PIL import Image
|
||||||
|
def _process_animated_sticker(data):
|
||||||
|
img = Image.open(io.BytesIO(data))
|
||||||
|
webp_buf = io.BytesIO()
|
||||||
|
if getattr(img, 'n_frames', 1) > 1:
|
||||||
|
img.save(webp_buf, format='WEBP', save_all=True, loop=0, quality=80)
|
||||||
|
else:
|
||||||
|
img.save(webp_buf, format='WEBP', quality=80)
|
||||||
|
return webp_buf.getvalue()
|
||||||
|
sticker_data = await asyncio.to_thread(_process_animated_sticker, sticker_data)
|
||||||
|
ext = 'webp'
|
||||||
|
except Exception: pass
|
||||||
|
|
||||||
|
files.append({"filename": f"sticker_{s.name}_{s.id}.{ext}", "data": sticker_data})
|
||||||
|
stats["attachments"] += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}")
|
||||||
|
|
||||||
|
# Process Attachments
|
||||||
|
attachments_to_process = list(msg.attachments)
|
||||||
|
if is_forwarded and hasattr(msg, 'message_snapshots') and msg.message_snapshots:
|
||||||
|
snapshot = msg.message_snapshots[0]
|
||||||
|
if not content:
|
||||||
|
content = snapshot.content
|
||||||
|
attachments_to_process.extend(snapshot.attachments)
|
||||||
|
|
||||||
|
for att in attachments_to_process:
|
||||||
|
try:
|
||||||
|
att_data = await context.discord_reader.download_attachment(att)
|
||||||
|
files.append({"filename": att.filename, "data": att_data})
|
||||||
|
stats["attachments"] += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to download attachment {att.filename}: {e}")
|
||||||
|
|
||||||
|
# Clean Mentions
|
||||||
|
content = clean_mentions(
|
||||||
|
content=content,
|
||||||
|
guild=context.discord_reader.guild,
|
||||||
|
user_mentions=msg.mentions,
|
||||||
|
role_mentions=msg.role_mentions,
|
||||||
|
channel_mentions=msg.channel_mentions,
|
||||||
|
emoji_map=context.state.emoji_map,
|
||||||
|
channel_map=context.state.channel_map,
|
||||||
|
state=context.state,
|
||||||
|
target_server_id=context.fluxer_writer.community_id,
|
||||||
|
channel_names=context.channel_names if hasattr(context, 'channel_names') else None,
|
||||||
|
anonymize_users=anonymize_users
|
||||||
|
)
|
||||||
|
|
||||||
|
if not content and not files:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Reply Resolution
|
||||||
|
reply_to_fluxer_id = None
|
||||||
|
if msg.reference and msg.reference.message_id:
|
||||||
|
reply_to_fluxer_id = context.state.get_fluxer_message_id(target_channel_id, str(msg.reference.message_id))
|
||||||
|
|
||||||
|
# Fallback author tagging for replies if mapping not found
|
||||||
|
if not reply_to_fluxer_id:
|
||||||
|
try:
|
||||||
|
source_ref_msg = await context.discord_reader.get_message(msg.channel.id, msg.reference.message_id)
|
||||||
|
if source_ref_msg and source_ref_msg.author:
|
||||||
|
ref_name = context.state.get_user_alias(str(source_ref_msg.author.id)) if anonymize_users else source_ref_msg.author.display_name
|
||||||
|
content = f"`@{ref_name}`\n{content}"
|
||||||
|
else:
|
||||||
|
tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id)
|
||||||
|
if tgt_reply: content = f"[Reply to {tgt_reply}]\n{content}"
|
||||||
|
except Exception: pass
|
||||||
|
|
||||||
|
# Thread logic
|
||||||
|
if not reply_to_fluxer_id and parent_target_id and stats["messages"] == 0:
|
||||||
|
reply_to_fluxer_id = parent_target_id
|
||||||
|
if thread_name and stats["messages"] == 0:
|
||||||
|
content = f"> <<< THREAD: **{thread_name}** >>>\n{content}"
|
||||||
|
|
||||||
|
# Send Message
|
||||||
|
if anonymize_users:
|
||||||
|
author_name = alias or "Anonymized User"
|
||||||
|
author_avatar_url = None
|
||||||
|
else:
|
||||||
|
author_name = msg.author.display_name
|
||||||
|
author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None
|
||||||
|
|
||||||
|
fluxer_msg_id = await context.fluxer_writer.send_message(
|
||||||
|
channel_id=target_channel_id,
|
||||||
|
author_name=author_name,
|
||||||
|
author_avatar_url=author_avatar_url,
|
||||||
|
content=content,
|
||||||
|
timestamp=int(msg.created_at.timestamp()),
|
||||||
|
files=files if files else None,
|
||||||
|
reply_to_message_id=reply_to_fluxer_id,
|
||||||
|
is_forwarded=is_forwarded,
|
||||||
|
embeds=msg.embeds
|
||||||
|
)
|
||||||
|
|
||||||
|
if fluxer_msg_id:
|
||||||
|
if thread_id:
|
||||||
|
context.state.set_thread_message_mapping(target_channel_id, thread_id, str(msg.id), fluxer_msg_id)
|
||||||
|
context.state.update_thread_last_message_timestamp(target_channel_id, thread_id, str(msg.created_at))
|
||||||
|
context.state.update_thread_last_message_id(target_channel_id, thread_id, str(msg.id))
|
||||||
|
context.state.increment_thread_stats(target_channel_id, thread_id, messages=1, files=len(files) if files else 0)
|
||||||
|
else:
|
||||||
|
context.state.set_message_mapping(target_channel_id, str(msg.id), fluxer_msg_id)
|
||||||
|
context.state.update_last_message_timestamp(target_channel_id, str(msg.created_at))
|
||||||
|
context.state.update_last_message_id(target_channel_id, str(msg.id))
|
||||||
|
context.state.increment_stats(target_channel_id, messages=1, files=len(files) if files else 0)
|
||||||
|
|
||||||
|
stats["messages"] += 1
|
||||||
|
stats["last_message_content"] = content
|
||||||
|
stats["last_message_author"] = msg.author.display_name
|
||||||
|
|
||||||
|
return fluxer_msg_id
|
||||||
|
|
||||||
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None, processed_threads: set | None = None) -> Dict[str, int]:
|
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None, processed_threads: set | None = None) -> Dict[str, int]:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Scans channel history to count messages, threads, and attachments.
|
Scans channel history to count messages, threads, and attachments.
|
||||||
"""
|
"""
|
||||||
|
|
@ -542,87 +725,30 @@ async def migrate_messages(
|
||||||
logger.debug(f"Added sticker {s.name} as attachment (extension: {ext}, size: {sticker_size} bytes)")
|
logger.debug(f"Added sticker {s.name} as attachment (extension: {ext}, size: {sticker_size} bytes)")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}")
|
logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}")
|
||||||
|
|
||||||
# Check for existing mapping to avoid duplicates when resuming
|
# Check for existing mapping to avoid duplicates when resuming
|
||||||
if context.state.get_target_message_id(target_channel_id, str(msg.id)):
|
if context.state.get_target_message_id(target_channel_id, str(msg.id)):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
reply_to_fluxer_id = None
|
fluxer_msg_id = await _process_and_send_message(
|
||||||
if msg.reference and msg.reference.message_id:
|
context=context,
|
||||||
reply_to_fluxer_id = context.state.get_fluxer_message_id(target_channel_id, str(msg.reference.message_id))
|
msg=msg,
|
||||||
if reply_to_fluxer_id:
|
target_channel_id=target_channel_id,
|
||||||
logger.debug(f"Detected reply to Discord ID {msg.reference.message_id} -> Fluxer ID {reply_to_fluxer_id}")
|
stats=stats,
|
||||||
else:
|
thread_id=thread_id,
|
||||||
logger.debug(f"Reply target Discord ID {msg.reference.message_id} not found in current session map.")
|
parent_target_id=parent_target_id,
|
||||||
|
thread_name=thread_name,
|
||||||
# If this is the FIRST thread message and we have a parent_target_id, force it as reply to the starter
|
processed_threads=processed_threads
|
||||||
if not reply_to_fluxer_id and parent_target_id and stats["messages"] == 0:
|
|
||||||
reply_to_fluxer_id = parent_target_id
|
|
||||||
|
|
||||||
# Prepend thread marker to the first message of the thread
|
|
||||||
if thread_name and stats["messages"] == 0:
|
|
||||||
content = f"> <<< THREAD: **{thread_name}** >>>\n{content}"
|
|
||||||
|
|
||||||
# Always ensure alias is created/retrieved to populate user_alias table
|
|
||||||
alias = context.state.get_user_alias(str(msg.author.id))
|
|
||||||
|
|
||||||
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
|
|
||||||
if anonymize_users:
|
|
||||||
author_name = alias or "Anonymized User"
|
|
||||||
author_avatar_url = None
|
|
||||||
else:
|
|
||||||
author_name = msg.author.display_name
|
|
||||||
author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None
|
|
||||||
|
|
||||||
logger.debug(f"Fluxer: Calling send_message for Discord ID {msg.id}")
|
|
||||||
fluxer_msg_id = await context.fluxer_writer.send_message(
|
|
||||||
channel_id=target_channel_id,
|
|
||||||
author_name=author_name,
|
|
||||||
author_avatar_url=author_avatar_url,
|
|
||||||
content=content,
|
|
||||||
timestamp=int(msg.created_at.timestamp()),
|
|
||||||
files=files if files else None,
|
|
||||||
reply_to_message_id=reply_to_fluxer_id,
|
|
||||||
is_forwarded=is_forwarded,
|
|
||||||
embeds=msg.embeds
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if fluxer_msg_id:
|
# Check for associated thread (Individual mode recursion)
|
||||||
if thread_id:
|
|
||||||
context.state.set_thread_message_mapping(target_channel_id, thread_id, str(msg.id), fluxer_msg_id)
|
|
||||||
else:
|
|
||||||
context.state.set_message_mapping(target_channel_id, str(msg.id), fluxer_msg_id)
|
|
||||||
else:
|
|
||||||
logger.warning(f"Fluxer: send_message returned None for Discord ID {msg.id} (message might have been skipped or timed out)")
|
|
||||||
|
|
||||||
if thread_id:
|
|
||||||
context.state.update_thread_last_message_timestamp(target_channel_id, thread_id, str(msg.created_at))
|
|
||||||
context.state.update_thread_last_message_id(target_channel_id, thread_id, str(msg.id))
|
|
||||||
context.state.increment_thread_stats(target_channel_id, thread_id, messages=1, files=len(files) if files else 0)
|
|
||||||
else:
|
|
||||||
context.state.update_last_message_timestamp(target_channel_id, str(msg.created_at))
|
|
||||||
context.state.update_last_message_id(target_channel_id, str(msg.id))
|
|
||||||
context.state.increment_stats(target_channel_id, messages=1, files=len(files) if files else 0)
|
|
||||||
|
|
||||||
stats["messages"] += 1
|
|
||||||
stats["last_message_content"] = content
|
|
||||||
stats["last_message_author"] = msg.author.display_name
|
|
||||||
|
|
||||||
# Check for associated thread (Normal case: parent message is migrated)
|
|
||||||
if hasattr(msg, 'thread') and msg.thread:
|
if hasattr(msg, 'thread') and msg.thread:
|
||||||
thread = msg.thread
|
thread = msg.thread
|
||||||
if thread.id not in processed_threads:
|
if thread.id not in processed_threads:
|
||||||
processed_threads.add(thread.id)
|
processed_threads.add(thread.id)
|
||||||
# Track thread entry
|
|
||||||
stats["threads"] += 1
|
stats["threads"] += 1
|
||||||
|
|
||||||
# Fetch last migrated message ID for this thread
|
|
||||||
thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(thread.id))
|
thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(thread.id))
|
||||||
if thread_after_id:
|
|
||||||
logger.info(f"Resuming thread '{thread.name}' from after message ID: {thread_after_id}")
|
|
||||||
|
|
||||||
# Migrate thread messages recursively
|
|
||||||
thread_stats = await migrate_messages(
|
thread_stats = await migrate_messages(
|
||||||
context=context,
|
context=context,
|
||||||
source_channel_id=thread.id,
|
source_channel_id=thread.id,
|
||||||
|
|
@ -637,22 +763,19 @@ async def migrate_messages(
|
||||||
stats["attachments"] += thread_stats["attachments"]
|
stats["attachments"] += thread_stats["attachments"]
|
||||||
stats["threads"] += thread_stats["threads"]
|
stats["threads"] += thread_stats["threads"]
|
||||||
|
|
||||||
# Send End Marker
|
|
||||||
if context.is_running:
|
if context.is_running:
|
||||||
await context.fluxer_writer.send_marker(
|
await context.fluxer_writer.send_marker(
|
||||||
channel_id=target_channel_id,
|
channel_id=target_channel_id,
|
||||||
content=f"> <<< END OF THREAD >>>"
|
content=f"> <<< END OF THREAD >>>"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update Link Tracking (but prevent threaded messages from overwriting the parent channel pointers)
|
# Update Link Tracking (Parent pointer updates)
|
||||||
# The 'after_message_id' param usually means it's the main function call and not a thread recursive call
|
|
||||||
if not stats["first_message_url"]:
|
if not stats["first_message_url"]:
|
||||||
stats["first_message_url"] = msg.jump_url
|
stats["first_message_url"] = msg.jump_url
|
||||||
stats["last_message_url"] = msg.jump_url
|
stats["last_message_url"] = msg.jump_url
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback(stats)
|
await progress_callback(stats)
|
||||||
logger.debug(f"Fluxer: Finished processing message Discord ID {msg.id}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to process message {msg.id}: {e}")
|
logger.error(f"Failed to process message {msg.id}: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
@ -735,7 +858,7 @@ async def migrate_global_messages(
|
||||||
progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None
|
progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Migrates messages across all channels chronologically.
|
Migrates messages across all channels chronologically to Fluxer.
|
||||||
"""
|
"""
|
||||||
stats = {
|
stats = {
|
||||||
"messages": 0,
|
"messages": 0,
|
||||||
|
|
@ -750,14 +873,6 @@ async def migrate_global_messages(
|
||||||
processed_threads = set()
|
processed_threads = set()
|
||||||
logger.info("Starting Global Waterfall Migration for Fluxer...")
|
logger.info("Starting Global Waterfall Migration for Fluxer...")
|
||||||
|
|
||||||
# Keep track of active thread mapping natively to pass parent target IDs if needed
|
|
||||||
thread_to_target_channel = {}
|
|
||||||
|
|
||||||
# Emojis and mapped users cache setup
|
|
||||||
emoji_map = context.state.emoji_map
|
|
||||||
db_media = context.discord_reader.db.get_all_media() if context.discord_reader.db else {}
|
|
||||||
target_server_id = getattr(context.fluxer_writer, "server_id", None)
|
|
||||||
|
|
||||||
# Fetch global progress map to skip migrated messages efficiently
|
# Fetch global progress map to skip migrated messages efficiently
|
||||||
progress_map = context.state.get_all_last_message_ids()
|
progress_map = context.state.get_all_last_message_ids()
|
||||||
|
|
||||||
|
|
@ -794,124 +909,19 @@ async def migrate_global_messages(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# If it's a thread message, we need to handle it based on if it's the thread starter or a reply
|
# If it's a thread message, we need to handle it based on if it's the thread starter or a reply
|
||||||
parent_target_id = None
|
|
||||||
if hasattr(msg, 'thread') and msg.thread and msg.id == msg.thread.id:
|
if hasattr(msg, 'thread') and msg.thread and msg.id == msg.thread.id:
|
||||||
processed_threads.add(msg.thread.id)
|
processed_threads.add(msg.thread.id)
|
||||||
stats["threads"] += 1
|
stats["threads"] += 1
|
||||||
elif msg.channel.type in [11, 12]: # Thread channels
|
|
||||||
# It's a message IN a thread.
|
|
||||||
# In Fluxer, threads might just be linear messages or threaded replies depending on schema
|
|
||||||
# For basic migration we just send it to the parent mapped target channel.
|
|
||||||
# The parent mapped target channel ID should already be calculated correctly by get_target_channel_id (which returns mapped thread or parent channel)
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Formatting
|
|
||||||
files = []
|
|
||||||
file_names = []
|
|
||||||
|
|
||||||
# Always ensure alias is created/retrieved to populate user_alias table
|
|
||||||
alias = context.state.get_user_alias(str(msg.author.id))
|
|
||||||
|
|
||||||
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
|
|
||||||
|
|
||||||
if anonymize_users:
|
|
||||||
author_name = alias or "Anonymized User"
|
|
||||||
author_avatar_url = None
|
|
||||||
else:
|
|
||||||
author_name = msg.author.display_name
|
|
||||||
author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None
|
|
||||||
|
|
||||||
for att in msg.attachments:
|
|
||||||
media_info = db_media.get(att.local_hash) if db_media else None
|
|
||||||
local_path = None
|
|
||||||
if media_info:
|
|
||||||
local_path = Path(media_info["local_path"])
|
|
||||||
elif hasattr(att, 'read'):
|
|
||||||
# Fallback
|
|
||||||
pass
|
|
||||||
|
|
||||||
if local_path and local_path.exists():
|
|
||||||
files.append(local_path)
|
|
||||||
file_names.append(att.filename)
|
|
||||||
|
|
||||||
content = msg.content or ""
|
|
||||||
|
|
||||||
# Stickers
|
|
||||||
for sticker in msg.stickers:
|
|
||||||
sticker_name = sticker.name
|
|
||||||
sticker_url = sticker.url
|
|
||||||
|
|
||||||
# Check for uploaded media pool logic first
|
|
||||||
s_hash = sticker.local_hash
|
|
||||||
sticker_file = None
|
|
||||||
s_media = db_media.get(s_hash) if db_media and s_hash else None
|
|
||||||
if s_media:
|
|
||||||
s_path = Path(s_media["local_path"])
|
|
||||||
if s_path.exists():
|
|
||||||
sticker_file = s_path
|
|
||||||
|
|
||||||
content += f"\n[Sticker: {sticker_name}]"
|
|
||||||
if sticker_file:
|
|
||||||
files.append(sticker_file)
|
|
||||||
file_names.append(f"sticker_{sticker_name}.png")
|
|
||||||
|
|
||||||
content = clean_mentions(
|
|
||||||
content=content,
|
|
||||||
guild=context.discord_reader.guild,
|
|
||||||
user_mentions=msg.mentions,
|
|
||||||
role_mentions=msg.role_mentions,
|
|
||||||
channel_mentions=msg.channel_mentions,
|
|
||||||
emoji_map=emoji_map,
|
|
||||||
channel_map=context.state.channel_map,
|
|
||||||
state=context.state,
|
|
||||||
target_server_id=target_server_id,
|
|
||||||
channel_names=context.channel_names if hasattr(context, 'channel_names') else None,
|
|
||||||
anonymize_users=anonymize_users
|
|
||||||
)
|
|
||||||
|
|
||||||
if not content and not files:
|
|
||||||
logger.debug(f"Message {msg.id} empty after processing, skipping.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
timestamp_int = int(msg.created_at.timestamp())
|
|
||||||
|
|
||||||
if msg.reference and msg.reference.message_id:
|
|
||||||
# Resolve the author of the message being replied to
|
|
||||||
source_ref_msg = await context.discord_reader.get_message(msg.channel.id, msg.reference.message_id)
|
|
||||||
if source_ref_msg and source_ref_msg.author:
|
|
||||||
ref_author_id = str(source_ref_msg.author.id)
|
|
||||||
if anonymize_users:
|
|
||||||
ref_name = context.state.get_user_alias(ref_author_id) or "Anonymized User"
|
|
||||||
else:
|
|
||||||
ref_name = source_ref_msg.author.display_name
|
|
||||||
content = f"`@{ref_name}`\n{content}"
|
|
||||||
else:
|
|
||||||
# Fallback if author cannot be resolved (e.g. deleted/missing from backup)
|
|
||||||
tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id)
|
|
||||||
if tgt_reply:
|
|
||||||
content = f"[Reply to {tgt_reply}]\n{content}"
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
fluxer_msg_id = await context.fluxer_writer.send_message(
|
await _process_and_send_message(
|
||||||
channel_id=target_channel_id,
|
context=context,
|
||||||
author_name=author_name,
|
msg=msg,
|
||||||
author_avatar_url=author_avatar_url,
|
target_channel_id=target_channel_id,
|
||||||
content=content,
|
stats=stats,
|
||||||
files=files,
|
processed_threads=processed_threads
|
||||||
timestamp=timestamp_int,
|
|
||||||
embeds=msg.embeds
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if fluxer_msg_id:
|
|
||||||
context.state.set_target_message_mapping(target_channel_id, msg.id, fluxer_msg_id)
|
|
||||||
context.state.update_last_message_id(target_channel_id, msg.id)
|
|
||||||
context.state.set_waterfall_last_id(msg.id)
|
|
||||||
stats["attachments"] += len(files) if files else 0
|
|
||||||
|
|
||||||
stats["messages"] += 1
|
|
||||||
stats["last_message_content"] = content
|
|
||||||
stats["last_message_author"] = author_name
|
|
||||||
|
|
||||||
if not stats["first_message_url"]:
|
if not stats["first_message_url"]:
|
||||||
stats["first_message_url"] = msg.jump_url
|
stats["first_message_url"] = msg.jump_url
|
||||||
stats["last_message_url"] = msg.jump_url
|
stats["last_message_url"] = msg.jump_url
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ class FluxerWriter:
|
||||||
guilds_list.append((label, str(g.id)))
|
guilds_list.append((label, str(g.id)))
|
||||||
return guilds_list
|
return guilds_list
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to fetch Fluxer communities via HTTP: {e}")
|
||||||
logger.error(f"Failed to fetch Fluxer communities via HTTP: {e}")
|
logger.error(f"Failed to fetch Fluxer communities via HTTP: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
@ -61,6 +62,7 @@ class FluxerWriter:
|
||||||
return w
|
return w
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to manage webhook for channel {channel_id}: {e}")
|
print(f"Failed to manage webhook for channel {channel_id}: {e}")
|
||||||
|
logger.error(f"Failed to manage webhook for channel {channel_id}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
|
|
@ -322,6 +324,7 @@ class FluxerWriter:
|
||||||
logger.debug(f"Fluxer: Webhook send complete, msg_id={msg.id if msg else 'None'}")
|
logger.debug(f"Fluxer: Webhook send complete, msg_id={msg.id if msg else 'None'}")
|
||||||
return str(msg.id) if msg else None
|
return str(msg.id) if msg else None
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
print(f"Fluxer: Webhook send timed out after 45s for channel {channel_id}")
|
||||||
logger.error(f"Fluxer: Webhook send timed out after 45s for channel {channel_id}")
|
logger.error(f"Fluxer: Webhook send timed out after 45s for channel {channel_id}")
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
|
|
@ -340,19 +343,24 @@ class FluxerWriter:
|
||||||
|
|
||||||
logger.debug(f"Fluxer: Sending message via bot for user '{author_name}'")
|
logger.debug(f"Fluxer: Sending message via bot for user '{author_name}'")
|
||||||
try:
|
try:
|
||||||
|
kwargs = {
|
||||||
|
"channel_id": channel_id,
|
||||||
|
"content": final_bot_content,
|
||||||
|
"embeds": normalized_embeds
|
||||||
|
}
|
||||||
|
if fluxer_files:
|
||||||
|
kwargs["files"] = fluxer_files
|
||||||
|
if message_reference:
|
||||||
|
kwargs["message_reference"] = message_reference
|
||||||
|
|
||||||
msg_data = await asyncio.wait_for(
|
msg_data = await asyncio.wait_for(
|
||||||
self.client.send_message(
|
self.client.send_message(**kwargs),
|
||||||
channel_id=channel_id,
|
|
||||||
content=final_bot_content,
|
|
||||||
files=fluxer_files,
|
|
||||||
embeds=normalized_embeds,
|
|
||||||
message_reference=message_reference
|
|
||||||
),
|
|
||||||
timeout=45.0
|
timeout=45.0
|
||||||
)
|
)
|
||||||
logger.debug(f"Fluxer: Bot send complete, msg_id={msg_data.get('id') if msg_data else 'None'}")
|
logger.debug(f"Fluxer: Bot send complete, msg_id={msg_data.get('id') if msg_data else 'None'}")
|
||||||
return str(msg_data["id"]) if msg_data else None
|
return str(msg_data["id"]) if msg_data else None
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
print(f"Fluxer: Bot send timed out after 45s for channel {channel_id}")
|
||||||
logger.error(f"Fluxer: Bot send timed out after 45s for channel {channel_id}")
|
logger.error(f"Fluxer: Bot send timed out after 45s for channel {channel_id}")
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -370,22 +378,27 @@ class FluxerWriter:
|
||||||
|
|
||||||
fluxer_files = None
|
fluxer_files = None
|
||||||
if files:
|
if files:
|
||||||
fluxer_files = [File(io.BytesIO(f["data"]), filename=f["filename"]) for f in files]
|
fluxer_files = [f if hasattr(f, "filename") else File(io.BytesIO(f["data"]), filename=f["filename"]) for f in files]
|
||||||
|
|
||||||
message_reference = None
|
message_reference = None
|
||||||
if reply_to_message_id:
|
if reply_to_message_id:
|
||||||
message_reference = {"message_id": str(reply_to_message_id), "channel_id": str(channel_id)}
|
message_reference = {"message_id": str(reply_to_message_id), "channel_id": str(channel_id)}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
msg_data = await self.client.send_message(
|
kwargs = {
|
||||||
channel_id=channel_id,
|
"channel_id": channel_id,
|
||||||
content=content,
|
"content": content
|
||||||
files=fluxer_files,
|
}
|
||||||
message_reference=message_reference
|
if fluxer_files:
|
||||||
)
|
kwargs["files"] = fluxer_files
|
||||||
|
if message_reference:
|
||||||
|
kwargs["message_reference"] = message_reference
|
||||||
|
|
||||||
|
msg_data = await self.client.send_message(**kwargs)
|
||||||
return str(msg_data["id"]) if msg_data else None
|
return str(msg_data["id"]) if msg_data else None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to send marker: {e}")
|
print(f"Failed to send marker: {e}")
|
||||||
|
logger.error(f"Failed to send marker: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def create_role(self, name: str, color: int, hoist: bool, mentionable: bool, permissions: int, position: Optional[int] = None) -> str:
|
async def create_role(self, name: str, color: int, hoist: bool, mentionable: bool, permissions: int, position: Optional[int] = None) -> str:
|
||||||
|
|
@ -408,6 +421,7 @@ class FluxerWriter:
|
||||||
return str(role["id"])
|
return str(role["id"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to copy role {name}: {e}")
|
print(f"Failed to copy role {name}: {e}")
|
||||||
|
logger.error(f"Failed to copy role {name}: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def create_emoji(self, name: str, image_bytes: bytes) -> str:
|
async def create_emoji(self, name: str, image_bytes: bytes) -> str:
|
||||||
|
|
@ -473,6 +487,7 @@ class FluxerWriter:
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to update community metadata: {e}")
|
print(f"Failed to update community metadata: {e}")
|
||||||
|
logger.error(f"Failed to update community metadata: {e}")
|
||||||
|
|
||||||
async def remove_community_logo_and_banner(self) -> dict:
|
async def remove_community_logo_and_banner(self) -> dict:
|
||||||
"""
|
"""
|
||||||
|
|
@ -503,6 +518,7 @@ class FluxerWriter:
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to remove community icon: {e}")
|
print(f"Failed to remove community icon: {e}")
|
||||||
|
logger.error(f"Failed to remove community icon: {e}")
|
||||||
|
|
||||||
# 3. Remove banner if set
|
# 3. Remove banner if set
|
||||||
if has_banner:
|
if has_banner:
|
||||||
|
|
@ -513,6 +529,7 @@ class FluxerWriter:
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to remove community banner: {e}")
|
print(f"Failed to remove community banner: {e}")
|
||||||
|
logger.error(f"Failed to remove community banner: {e}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"icon": "REMOVED" if has_icon else "SKIP",
|
"icon": "REMOVED" if has_icon else "SKIP",
|
||||||
|
|
@ -544,6 +561,7 @@ class FluxerWriter:
|
||||||
await progress_callback(ch.get("name", "Unknown"), deleted, total)
|
await progress_callback(ch.get("name", "Unknown"), deleted, total)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to delete channel {ch.get('name')}: {e}")
|
print(f"Failed to delete channel {ch.get('name')}: {e}")
|
||||||
|
logger.error(f"Failed to delete channel {ch.get('name')}: {e}")
|
||||||
return deleted
|
return deleted
|
||||||
|
|
||||||
async def reset_channel_permissions(self, progress_callback=None) -> int:
|
async def reset_channel_permissions(self, progress_callback=None) -> int:
|
||||||
|
|
@ -576,12 +594,14 @@ class FluxerWriter:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to delete overwrite {ow['id']} for channel {ch['id']}: {e}")
|
||||||
logger.error(f"Failed to delete overwrite {ow['id']} for channel {ch['id']}: {e}")
|
logger.error(f"Failed to delete overwrite {ow['id']} for channel {ch['id']}: {e}")
|
||||||
processed += 1
|
processed += 1
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback(ch.get("name", "Unknown"), processed, total)
|
await progress_callback(ch.get("name", "Unknown"), processed, total)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to reset permissions for channel {ch.get('name')}: {e}")
|
print(f"Failed to reset permissions for channel {ch.get('name')}: {e}")
|
||||||
|
logger.error(f"Failed to reset permissions for channel {ch.get('name')}: {e}")
|
||||||
return processed
|
return processed
|
||||||
|
|
||||||
async def set_channel_permission(self, channel_id: str, overwrite_id: str, allow: int, deny: int, is_role: bool = True):
|
async def set_channel_permission(self, channel_id: str, overwrite_id: str, allow: int, deny: int, is_role: bool = True):
|
||||||
|
|
@ -603,6 +623,7 @@ class FluxerWriter:
|
||||||
type=0 if is_role else 1
|
type=0 if is_role else 1
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to set permission on channel {channel_id} for overwrite {overwrite_id}: {e}")
|
||||||
logger.error(f"Failed to set permission on channel {channel_id} for overwrite {overwrite_id}: {e}")
|
logger.error(f"Failed to set permission on channel {channel_id} for overwrite {overwrite_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -644,6 +665,7 @@ class FluxerWriter:
|
||||||
await progress_callback(role.get("name", "Unknown"), deleted, total)
|
await progress_callback(role.get("name", "Unknown"), deleted, total)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to delete role {role.get('name')}: {e}")
|
print(f"Failed to delete role {role.get('name')}: {e}")
|
||||||
|
logger.error(f"Failed to delete role {role.get('name')}: {e}")
|
||||||
return deleted
|
return deleted
|
||||||
|
|
||||||
async def delete_all_emojis_and_stickers(self, progress_callback=None) -> dict:
|
async def delete_all_emojis_and_stickers(self, progress_callback=None) -> dict:
|
||||||
|
|
@ -667,8 +689,10 @@ class FluxerWriter:
|
||||||
await progress_callback(emoji.get("name", "Unknown"), "Emoji", emoji_deleted, emoji_total)
|
await progress_callback(emoji.get("name", "Unknown"), "Emoji", emoji_deleted, emoji_total)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to delete emoji {emoji.get('name')}: {e}")
|
print(f"Failed to delete emoji {emoji.get('name')}: {e}")
|
||||||
|
logger.error(f"Failed to delete emoji {emoji.get('name')}: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to fetch emojis: {e}")
|
print(f"Failed to fetch emojis: {e}")
|
||||||
|
logger.error(f"Failed to fetch emojis: {e}")
|
||||||
|
|
||||||
# Delete stickers
|
# Delete stickers
|
||||||
try:
|
try:
|
||||||
|
|
@ -682,8 +706,10 @@ class FluxerWriter:
|
||||||
await progress_callback(sticker.get("name", "Unknown"), "Sticker", sticker_deleted, sticker_total)
|
await progress_callback(sticker.get("name", "Unknown"), "Sticker", sticker_deleted, sticker_total)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to delete sticker {sticker.get('name')}: {e}")
|
print(f"Failed to delete sticker {sticker.get('name')}: {e}")
|
||||||
|
logger.error(f"Failed to delete sticker {sticker.get('name')}: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to fetch stickers: {e}")
|
print(f"Failed to fetch stickers: {e}")
|
||||||
|
logger.error(f"Failed to fetch stickers: {e}")
|
||||||
|
|
||||||
return {"emojis": emoji_deleted, "stickers": sticker_deleted}
|
return {"emojis": emoji_deleted, "stickers": sticker_deleted}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,41 +16,52 @@ async def sync_channel_state(context: MigrationContext):
|
||||||
channels = await context.discord_reader.get_channels()
|
channels = await context.discord_reader.get_channels()
|
||||||
target_channels = await context.writer.get_channels()
|
target_channels = await context.writer.get_channels()
|
||||||
|
|
||||||
# Build name -> id map and ID set for Stoat for fast lookup
|
# Build maps for Stoat lookup
|
||||||
target_name_map = {c.get("name"): str(c.get("id")) for c in target_channels if c.get("name")}
|
# {name: id} for categories (type 4)
|
||||||
target_id_set = {str(c.get("id")) for c in target_channels}
|
target_cats = {c.get("name"): str(c.get("id")) for c in target_channels if c.get("type") == 4}
|
||||||
|
# {parent_id: {name: id}} for channels
|
||||||
|
target_structure = {}
|
||||||
|
for c in target_channels:
|
||||||
|
if c.get("type") == 4: continue
|
||||||
|
p_id = str(c.get("parent_id")) if c.get("parent_id") else "root"
|
||||||
|
if p_id not in target_structure: target_structure[p_id] = {}
|
||||||
|
target_structure[p_id][c.get("name")] = str(c.get("id"))
|
||||||
|
|
||||||
|
target_id_set = {str(c.get("id")) for c in target_channels}
|
||||||
updates = 0
|
updates = 0
|
||||||
removals = 0
|
removals = 0
|
||||||
|
|
||||||
# 1. Verify and Sync Categories
|
# 1. Sync Categories
|
||||||
for cat in categories:
|
for cat in categories:
|
||||||
discord_id = str(cat.id)
|
discord_id = str(cat.id)
|
||||||
target_id = context.state.get_target_category_id(discord_id)
|
target_id = context.state.get_target_category_id(discord_id)
|
||||||
|
|
||||||
if target_id:
|
if target_id:
|
||||||
if target_id not in target_id_set:
|
if target_id not in target_id_set:
|
||||||
context.state.remove_category_mapping(discord_id)
|
context.state.remove_category_mapping(discord_id)
|
||||||
removals += 1
|
removals += 1
|
||||||
elif cat.name in target_name_map:
|
elif cat.name in target_cats:
|
||||||
context.state.set_target_category_mapping(discord_id, target_name_map[cat.name])
|
context.state.set_target_category_mapping(discord_id, target_cats[cat.name])
|
||||||
updates += 1
|
updates += 1
|
||||||
|
|
||||||
# 2. Verify and Sync Channels
|
# 2. Sync Channels (parent-aware)
|
||||||
for ch in channels:
|
for ch in channels:
|
||||||
discord_id = str(ch.id)
|
discord_id = str(ch.id)
|
||||||
target_id = context.state.get_target_channel_id(discord_id)
|
target_id = context.state.get_target_channel_id(discord_id)
|
||||||
|
|
||||||
if target_id:
|
if target_id:
|
||||||
if target_id not in target_id_set:
|
if target_id not in target_id_set:
|
||||||
context.state.remove_channel_mapping(discord_id)
|
context.state.remove_channel_mapping(discord_id)
|
||||||
removals += 1
|
removals += 1
|
||||||
elif ch.name in target_name_map:
|
else:
|
||||||
context.state.set_target_channel_mapping(discord_id, target_name_map[ch.name])
|
# Try to match by name within the mapped parent category
|
||||||
|
p_discord_id = str(ch.category_id) if ch.category_id else "root"
|
||||||
|
p_target_id = context.state.get_target_category_id(p_discord_id) if p_discord_id != "root" else "root"
|
||||||
|
|
||||||
|
if p_target_id in target_structure and ch.name in target_structure[p_target_id]:
|
||||||
|
context.state.set_target_channel_mapping(discord_id, target_structure[p_target_id][ch.name])
|
||||||
updates += 1
|
updates += 1
|
||||||
|
|
||||||
if updates > 0 or removals > 0:
|
if updates > 0 or removals > 0:
|
||||||
logger.info(f"Channel sync: {updates} mapped, {removals} stale mappings removed")
|
logger.info(f"Stoat Channel sync: {updates} mapped, {removals} stale mappings removed")
|
||||||
|
|
||||||
|
|
||||||
async def migrate_channels(context: MigrationContext, progress_callback: Callable[[str, str, int, int], Awaitable[None]] | None = None, force: bool = False) -> dict:
|
async def migrate_channels(context: MigrationContext, progress_callback: Callable[[str, str, int, int], Awaitable[None]] | None = None, force: bool = False) -> dict:
|
||||||
|
|
@ -107,7 +118,22 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
||||||
if total == 0:
|
if total == 0:
|
||||||
return cloned_info
|
return cloned_info
|
||||||
|
|
||||||
# 1. Create missing channels (unparented for now)
|
# 1. Create missing categories first
|
||||||
|
for cat in missing_categories:
|
||||||
|
if not context.is_running: break
|
||||||
|
|
||||||
|
state_key = str(cat.id)
|
||||||
|
target_id = await context.writer.create_channel(cat.name, type=4)
|
||||||
|
if target_id:
|
||||||
|
context.state.set_target_category_id(state_key, target_id)
|
||||||
|
cloned_info["categories_created"].append(cat.name)
|
||||||
|
if cat.name not in cloned_info["structure"]:
|
||||||
|
cloned_info["structure"][cat.name] = []
|
||||||
|
|
||||||
|
current_idx += 1
|
||||||
|
if progress_callback: await progress_callback(cat.name, "Copying", current_idx, total)
|
||||||
|
|
||||||
|
# 2. Create missing channels (now with parent_id available)
|
||||||
for channel in channels_to_create:
|
for channel in channels_to_create:
|
||||||
if not context.is_running: break
|
if not context.is_running: break
|
||||||
|
|
||||||
|
|
@ -119,7 +145,6 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
||||||
logger.debug(f"Creating channel {channel.name}: topic={topic}, nsfw={nsfw}, slowmode={slowmode}")
|
logger.debug(f"Creating channel {channel.name}: topic={topic}, nsfw={nsfw}, slowmode={slowmode}")
|
||||||
|
|
||||||
# Map Discord-specific types to target-supported types
|
# Map Discord-specific types to target-supported types
|
||||||
# 5 (News) -> 0 (Text), and fallback any unknown non-voice types to text
|
|
||||||
raw_type = channel.type.value if hasattr(channel.type, 'value') else 0
|
raw_type = channel.type.value if hasattr(channel.type, 'value') else 0
|
||||||
if raw_type == context.discord_reader.CHANNEL_TYPE_VOICE.value:
|
if raw_type == context.discord_reader.CHANNEL_TYPE_VOICE.value:
|
||||||
ch_type = 2
|
ch_type = 2
|
||||||
|
|
@ -128,20 +153,22 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
||||||
ch_type = 0
|
ch_type = 0
|
||||||
is_voice = False
|
is_voice = False
|
||||||
else:
|
else:
|
||||||
# Fallback for Stage channels (13) etc. to Text for safety
|
|
||||||
ch_type = 0
|
ch_type = 0
|
||||||
is_voice = False
|
is_voice = False
|
||||||
|
|
||||||
|
# Resolve parent category
|
||||||
|
parent_id = context.state.get_target_category_id(str(channel.category_id)) if channel.category_id else None
|
||||||
|
|
||||||
target_id = await context.writer.create_channel(
|
target_id = await context.writer.create_channel(
|
||||||
name=channel.name,
|
name=channel.name,
|
||||||
topic=topic if not is_voice else "",
|
topic=topic if not is_voice else "",
|
||||||
type=ch_type,
|
type=ch_type,
|
||||||
parent_id=None,
|
parent_id=parent_id,
|
||||||
nsfw=nsfw if not is_voice else False,
|
nsfw=nsfw if not is_voice else False,
|
||||||
slowmode_delay=slowmode if not is_voice else 0
|
slowmode_delay=slowmode if not is_voice else 0
|
||||||
)
|
)
|
||||||
if target_id:
|
if target_id:
|
||||||
context.state.set_target_channel_mapping(state_key, target_id)
|
context.state.set_target_channel_id(state_key, target_id)
|
||||||
cloned_info["channels_created"].append(channel.name)
|
cloned_info["channels_created"].append(channel.name)
|
||||||
|
|
||||||
parent_name = cat_name_map.get(str(channel.category_id), "No Category") if channel.category_id else "No Category"
|
parent_name = cat_name_map.get(str(channel.category_id), "No Category") if channel.category_id else "No Category"
|
||||||
|
|
@ -156,7 +183,8 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
||||||
name=channel.name,
|
name=channel.name,
|
||||||
topic=topic,
|
topic=topic,
|
||||||
nsfw=nsfw,
|
nsfw=nsfw,
|
||||||
slowmode_delay=slowmode
|
slowmode_delay=slowmode,
|
||||||
|
parent_id=parent_id
|
||||||
)
|
)
|
||||||
|
|
||||||
current_idx += 1
|
current_idx += 1
|
||||||
|
|
@ -172,32 +200,21 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
||||||
|
|
||||||
logger.debug(f"Syncing existing channel {channel.name} ({target_id}): topic={topic}, nsfw={nsfw}, slowmode={slowmode}")
|
logger.debug(f"Syncing existing channel {channel.name} ({target_id}): topic={topic}, nsfw={nsfw}, slowmode={slowmode}")
|
||||||
|
|
||||||
|
parent_id = context.state.get_target_category_id(str(channel.category_id)) if channel.category_id else None
|
||||||
|
|
||||||
await context.writer.modify_channel(
|
await context.writer.modify_channel(
|
||||||
channel_id=target_id,
|
channel_id=target_id,
|
||||||
name=channel.name,
|
name=channel.name,
|
||||||
topic=topic,
|
topic=topic,
|
||||||
nsfw=nsfw,
|
nsfw=nsfw,
|
||||||
slowmode_delay=slowmode
|
slowmode_delay=slowmode,
|
||||||
|
parent_id=parent_id
|
||||||
)
|
)
|
||||||
|
|
||||||
cloned_info["channels_synced"].append(channel.name)
|
cloned_info["channels_synced"].append(channel.name)
|
||||||
|
|
||||||
current_idx += 1
|
current_idx += 1
|
||||||
if progress_callback: await progress_callback(channel.name, "Syncing", current_idx, total)
|
if progress_callback: await progress_callback(channel.name, "Syncing", current_idx, total)
|
||||||
|
|
||||||
# 3. Create missing categories
|
|
||||||
for cat in missing_categories:
|
|
||||||
if not context.is_running: break
|
|
||||||
|
|
||||||
state_key = str(cat.id)
|
|
||||||
target_id = await context.writer.create_channel(cat.name, type=4)
|
|
||||||
if target_id:
|
|
||||||
context.state.set_target_category_mapping(state_key, target_id)
|
|
||||||
cloned_info["categories_created"].append(cat.name)
|
|
||||||
if cat.name not in cloned_info["structure"]:
|
|
||||||
cloned_info["structure"][cat.name] = []
|
|
||||||
|
|
||||||
current_idx += 1
|
|
||||||
if progress_callback: await progress_callback(f"Cat: {cat.name}", "Copying", current_idx, total)
|
if progress_callback: await progress_callback(f"Cat: {cat.name}", "Copying", current_idx, total)
|
||||||
|
|
||||||
# 4. Final step: Parent the channels into categories via mass server.edit()
|
# 4. Final step: Parent the channels into categories via mass server.edit()
|
||||||
|
|
@ -223,7 +240,7 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
||||||
|
|
||||||
# Resolve Stoat categories
|
# Resolve Stoat categories
|
||||||
# We iterate over the categories from the server to ensure we don't drop any
|
# We iterate over the categories from the server to ensure we don't drop any
|
||||||
for stoat_cat in server.categories:
|
for stoat_cat in (server.categories or []):
|
||||||
# Check if this Stoat category maps to any Discord category
|
# Check if this Stoat category maps to any Discord category
|
||||||
discord_cat_id = next((d_id for d_id, t_id in context.state.category_map.items() if t_id == str(stoat_cat.id)), None)
|
discord_cat_id = next((d_id for d_id, t_id in context.state.category_map.items() if t_id == str(stoat_cat.id)), None)
|
||||||
|
|
||||||
|
|
@ -261,8 +278,35 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
||||||
target_categories.sort(key=get_cat_position)
|
target_categories.sort(key=get_cat_position)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await server.edit(categories=target_categories)
|
import aiohttp
|
||||||
|
api_base = (context.writer.api_url or "https://api.stoat.chat/0.8").rstrip("/")
|
||||||
|
cats_payload = []
|
||||||
|
for c in target_categories:
|
||||||
|
entry: dict = {"id": str(c.id), "title": c.title, "channels": [str(x) for x in c.channels]}
|
||||||
|
if getattr(c, "default_permissions", None) is not None:
|
||||||
|
entry["default_permissions"] = c.default_permissions
|
||||||
|
if getattr(c, "role_permissions", None):
|
||||||
|
entry["role_permissions"] = c.role_permissions
|
||||||
|
cats_payload.append(entry)
|
||||||
|
for _attempt in range(5):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.patch(
|
||||||
|
f"{api_base}/servers/{context.writer.community_id}",
|
||||||
|
json={"categories": cats_payload},
|
||||||
|
headers={"X-Bot-Token": context.writer.token.strip(), "Content-Type": "application/json"},
|
||||||
|
) as resp:
|
||||||
|
if resp.status == 429:
|
||||||
|
import re as _re
|
||||||
|
body = await resp.json(content_type=None)
|
||||||
|
wait_ms = body.get("retry_after", 10000)
|
||||||
|
logger.warning(f"Rate limited on mass parent, retrying in {wait_ms/1000:.1f}s…")
|
||||||
|
await asyncio.sleep(wait_ms / 1000 + 0.5)
|
||||||
|
continue
|
||||||
|
if resp.status not in (200, 204):
|
||||||
|
body = await resp.json(content_type=None)
|
||||||
|
raise Exception(f"HTTP {resp.status}: {body}")
|
||||||
logger.info("Successfully parented all channels.")
|
logger.info("Successfully parented all channels.")
|
||||||
|
break
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.error(f"Failed to mass parent channels: {ex}")
|
logger.error(f"Failed to mass parent channels: {ex}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -155,7 +155,192 @@ async def get_channel_threads(reader: Any, channel_id: int) -> List[Any]:
|
||||||
return threads
|
return threads
|
||||||
|
|
||||||
|
|
||||||
|
async def _process_and_send_message(
|
||||||
|
context: MigrationContext,
|
||||||
|
msg: Any,
|
||||||
|
target_channel_id: str,
|
||||||
|
stats: Dict[str, Any],
|
||||||
|
thread_id: str | None = None,
|
||||||
|
parent_target_id: str | None = None,
|
||||||
|
thread_name: str | None = None,
|
||||||
|
processed_threads: set | None = None
|
||||||
|
) -> str | None:
|
||||||
|
"""
|
||||||
|
Internal helper to process a single Discord message (mentions, attachments, stickers)
|
||||||
|
and send it to the Stoat platform.
|
||||||
|
"""
|
||||||
|
# 1. Processing Flags
|
||||||
|
is_forwarded = False
|
||||||
|
if hasattr(msg.flags, 'forwarded'):
|
||||||
|
is_forwarded = msg.flags.forwarded
|
||||||
|
|
||||||
|
# 2. Content & Formatting
|
||||||
|
content = msg.content or ""
|
||||||
|
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
|
||||||
|
alias = context.state.get_user_alias(str(msg.author.id))
|
||||||
|
|
||||||
|
# Process Stickers
|
||||||
|
files = []
|
||||||
|
if hasattr(msg, 'stickers') and msg.stickers:
|
||||||
|
for s in msg.stickers:
|
||||||
|
try:
|
||||||
|
sticker_data = await context.discord_reader.download_sticker(s)
|
||||||
|
if not sticker_data: continue
|
||||||
|
|
||||||
|
format_val = getattr(s, 'format', 'png')
|
||||||
|
if hasattr(format_val, 'name'):
|
||||||
|
ext = format_val.name.lower()
|
||||||
|
elif isinstance(format_val, int):
|
||||||
|
format_map = {1: 'png', 2: 'apng', 3: 'lottie', 4: 'gif'}
|
||||||
|
ext = format_map.get(format_val, 'png')
|
||||||
|
else:
|
||||||
|
ext = str(format_val).lower()
|
||||||
|
|
||||||
|
# Conversion logic for Stoat (WebP or GIF focus)
|
||||||
|
if ext == 'lottie' and HAS_LOTTIE:
|
||||||
|
try:
|
||||||
|
lottie_data = json.loads(sticker_data)
|
||||||
|
def _convert_lottie_to_gif(data):
|
||||||
|
animation = Animation.load(data)
|
||||||
|
output = io.BytesIO()
|
||||||
|
export_gif(animation, output)
|
||||||
|
return output.getvalue()
|
||||||
|
sticker_data = await asyncio.to_thread(_convert_lottie_to_gif, lottie_data)
|
||||||
|
ext = 'gif'
|
||||||
|
except Exception: ext = 'json'
|
||||||
|
elif ext == 'apng':
|
||||||
|
try:
|
||||||
|
from PIL import Image
|
||||||
|
def _convert_apng_to_gif(data):
|
||||||
|
img = Image.open(io.BytesIO(data))
|
||||||
|
gif_buf = io.BytesIO()
|
||||||
|
if getattr(img, 'n_frames', 1) > 1:
|
||||||
|
frames = []
|
||||||
|
durations = []
|
||||||
|
for i in range(img.n_frames):
|
||||||
|
img.seek(i)
|
||||||
|
frame = img.convert('RGBA')
|
||||||
|
current_frame = Image.new('RGBA', img.size, (0,0,0,0))
|
||||||
|
current_frame.paste(frame, (0, 0))
|
||||||
|
frames.append(current_frame)
|
||||||
|
durations.append(img.info.get('duration', 100))
|
||||||
|
frames[0].save(gif_buf, format='GIF', save_all=True, append_images=frames[1:], loop=0, duration=durations, disposal=2, transparency=0)
|
||||||
|
else: img.save(gif_buf, format='GIF')
|
||||||
|
return gif_buf.getvalue()
|
||||||
|
sticker_data = await asyncio.to_thread(_convert_apng_to_gif, sticker_data)
|
||||||
|
ext = 'gif'
|
||||||
|
except Exception: pass
|
||||||
|
|
||||||
|
files.append({
|
||||||
|
"filename": f"sticker_{s.name}_{s.id}.{ext}",
|
||||||
|
"data": sticker_data,
|
||||||
|
"content_type": f"image/{ext}" if ext != "json" else "application/json"
|
||||||
|
})
|
||||||
|
stats["attachments"] += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}")
|
||||||
|
|
||||||
|
# Process Attachments
|
||||||
|
attachments_to_process = list(msg.attachments)
|
||||||
|
if is_forwarded and hasattr(msg, 'message_snapshots') and msg.message_snapshots:
|
||||||
|
snapshot = msg.message_snapshots[0]
|
||||||
|
if not content: content = snapshot.content
|
||||||
|
attachments_to_process.extend(snapshot.attachments)
|
||||||
|
|
||||||
|
for att in attachments_to_process:
|
||||||
|
try:
|
||||||
|
att_data = await context.discord_reader.download_attachment(att)
|
||||||
|
files.append({
|
||||||
|
"filename": att.filename,
|
||||||
|
"data": att_data,
|
||||||
|
"content_type": getattr(att, "content_type", None)
|
||||||
|
})
|
||||||
|
stats["attachments"] += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to download attachment {att.filename}: {e}")
|
||||||
|
|
||||||
|
# Clean Mentions
|
||||||
|
content = clean_mentions(
|
||||||
|
content=content,
|
||||||
|
guild=context.discord_reader.guild,
|
||||||
|
user_mentions=msg.mentions,
|
||||||
|
role_mentions=msg.role_mentions,
|
||||||
|
channel_mentions=msg.channel_mentions,
|
||||||
|
emoji_map=context.state.emoji_map,
|
||||||
|
channel_map=context.state.channel_map,
|
||||||
|
state=context.state,
|
||||||
|
target_server_id=context.stoat_writer.community_id,
|
||||||
|
channel_names=context.channel_names if hasattr(context, 'channel_names') else None,
|
||||||
|
anonymize_users=anonymize_users
|
||||||
|
)
|
||||||
|
|
||||||
|
if not content and not files:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Reply Resolution
|
||||||
|
reply_to_stoat_id = None
|
||||||
|
if msg.reference and msg.reference.message_id:
|
||||||
|
reply_to_stoat_id = context.state.get_target_message_id(target_channel_id, str(msg.reference.message_id))
|
||||||
|
if not reply_to_stoat_id:
|
||||||
|
# Fallback author tagging
|
||||||
|
try:
|
||||||
|
source_ref_msg = await context.discord_reader.get_message(msg.channel.id, msg.reference.message_id)
|
||||||
|
if source_ref_msg and source_ref_msg.author:
|
||||||
|
ref_name = context.state.get_user_alias(str(source_ref_msg.author.id)) if anonymize_users else source_ref_msg.author.display_name
|
||||||
|
content = f"`@{ref_name}`\n{content}"
|
||||||
|
else:
|
||||||
|
tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id)
|
||||||
|
if tgt_reply: content = f"[Reply to {tgt_reply}]\n{content}"
|
||||||
|
except Exception: pass
|
||||||
|
|
||||||
|
# Thread logic
|
||||||
|
if not reply_to_stoat_id and parent_target_id and stats["messages"] == 0:
|
||||||
|
reply_to_stoat_id = parent_target_id
|
||||||
|
if thread_name and stats["messages"] == 0:
|
||||||
|
content = f"> <<< THREAD: **{thread_name}** >>>\n{content}"
|
||||||
|
|
||||||
|
# Author resolution
|
||||||
|
if anonymize_users:
|
||||||
|
author_name = alias or "Anonymized User"
|
||||||
|
author_avatar_url = None
|
||||||
|
else:
|
||||||
|
author_name = msg.author.display_name
|
||||||
|
author_avatar_url = str(msg.author.display_avatar.url) if msg.author.display_avatar.url else None
|
||||||
|
if author_avatar_url and not author_avatar_url.startswith("http"): author_avatar_url = None
|
||||||
|
|
||||||
|
# Send Message
|
||||||
|
stoat_msg_id = await context.stoat_writer.send_message(
|
||||||
|
channel_id=target_channel_id,
|
||||||
|
author_name=author_name,
|
||||||
|
author_avatar_url=author_avatar_url,
|
||||||
|
content=content,
|
||||||
|
timestamp=int(msg.created_at.timestamp()),
|
||||||
|
files=files if files else None,
|
||||||
|
reply_to_message_id=reply_to_stoat_id,
|
||||||
|
is_forwarded=is_forwarded,
|
||||||
|
embeds=msg.embeds
|
||||||
|
)
|
||||||
|
|
||||||
|
if stoat_msg_id:
|
||||||
|
if thread_id:
|
||||||
|
context.state.set_thread_message_mapping(target_channel_id, thread_id, str(msg.id), stoat_msg_id)
|
||||||
|
context.state.update_thread_last_message_timestamp(target_channel_id, thread_id, str(msg.created_at))
|
||||||
|
context.state.update_thread_last_message_id(target_channel_id, thread_id, str(msg.id))
|
||||||
|
context.state.increment_thread_stats(target_channel_id, thread_id, messages=1, files=len(files) if files else 0)
|
||||||
|
else:
|
||||||
|
context.state.set_message_mapping(target_channel_id, str(msg.id), stoat_msg_id)
|
||||||
|
context.state.update_last_message_timestamp(target_channel_id, str(msg.created_at))
|
||||||
|
context.state.update_last_message_id(target_channel_id, str(msg.id))
|
||||||
|
context.state.increment_stats(target_channel_id, messages=1, files=len(files) if files else 0)
|
||||||
|
|
||||||
|
stats["messages"] += 1
|
||||||
|
stats["last_message_content"] = content
|
||||||
|
stats["last_message_author"] = msg.author.display_name
|
||||||
|
|
||||||
|
return stoat_msg_id
|
||||||
|
|
||||||
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None, processed_threads: set | None = None) -> Dict[str, int]:
|
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None, processed_threads: set | None = None) -> Dict[str, int]:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Scans channel history to count messages, threads, and attachments.
|
Scans channel history to count messages, threads, and attachments.
|
||||||
"""
|
"""
|
||||||
|
|
@ -546,87 +731,30 @@ async def migrate_messages(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}")
|
logger.error(f"Failed to download sticker {getattr(s, 'name', 'unknown')}: {e}")
|
||||||
|
|
||||||
try:
|
|
||||||
# Check for existing mapping to avoid duplicates when resuming
|
# Check for existing mapping to avoid duplicates when resuming
|
||||||
if context.state.get_target_message_id(target_channel_id, str(msg.id)):
|
if context.state.get_target_message_id(target_channel_id, str(msg.id)):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check if this message is a reply
|
try:
|
||||||
reply_to_stoat_id = None
|
stoat_msg_id = await _process_and_send_message(
|
||||||
if msg.reference and msg.reference.message_id:
|
context=context,
|
||||||
reply_to_stoat_id = context.state.get_target_message_id(target_channel_id, str(msg.reference.message_id))
|
msg=msg,
|
||||||
if reply_to_stoat_id:
|
target_channel_id=target_channel_id,
|
||||||
logger.debug(f"Detected reply to Discord ID {msg.reference.message_id} -> Stoat ID {reply_to_stoat_id}")
|
stats=stats,
|
||||||
else:
|
thread_id=thread_id,
|
||||||
logger.debug(f"Reply target Discord ID {msg.reference.message_id} not found in current session map.")
|
parent_target_id=parent_target_id,
|
||||||
|
thread_name=thread_name,
|
||||||
# If this is the FIRST thread message and we have a parent_target_id, force it as reply to the starter
|
processed_threads=processed_threads
|
||||||
if not reply_to_stoat_id and parent_target_id and stats["messages"] == 0:
|
|
||||||
reply_to_stoat_id = parent_target_id
|
|
||||||
|
|
||||||
# Prepend thread marker to the first message of the thread
|
|
||||||
if thread_name and stats["messages"] == 0:
|
|
||||||
content = f"> <<< THREAD: **{thread_name}** >>>\n{content}"
|
|
||||||
|
|
||||||
# Always ensure alias is created/retrieved to populate user_alias table
|
|
||||||
alias = context.state.get_user_alias(str(msg.author.id))
|
|
||||||
|
|
||||||
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
|
|
||||||
if anonymize_users:
|
|
||||||
author_name = alias or "Anonymized User"
|
|
||||||
author_avatar_url = None
|
|
||||||
else:
|
|
||||||
author_name = msg.author.display_name
|
|
||||||
author_avatar_url = str(msg.author.display_avatar.url) if msg.author.display_avatar.url else None
|
|
||||||
if author_avatar_url and not author_avatar_url.startswith("http"):
|
|
||||||
author_avatar_url = None
|
|
||||||
|
|
||||||
logger.debug(f"Stoat: Calling send_message for Discord ID {msg.id}")
|
|
||||||
stoat_msg_id = await context.stoat_writer.send_message(
|
|
||||||
channel_id=target_channel_id,
|
|
||||||
author_name=author_name,
|
|
||||||
author_avatar_url=author_avatar_url,
|
|
||||||
content=content,
|
|
||||||
timestamp=int(msg.created_at.timestamp()),
|
|
||||||
files=files if files else None,
|
|
||||||
reply_to_message_id=reply_to_stoat_id,
|
|
||||||
is_forwarded=is_forwarded,
|
|
||||||
embeds=msg.embeds
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if stoat_msg_id:
|
# Check for associated thread (Individual mode recursion)
|
||||||
if thread_id:
|
|
||||||
context.state.set_thread_message_mapping(target_channel_id, thread_id, str(msg.id), stoat_msg_id)
|
|
||||||
else:
|
|
||||||
context.state.set_message_mapping(target_channel_id, str(msg.id), stoat_msg_id)
|
|
||||||
|
|
||||||
if thread_id:
|
|
||||||
context.state.update_thread_last_message_timestamp(target_channel_id, thread_id, str(msg.created_at))
|
|
||||||
context.state.update_thread_last_message_id(target_channel_id, thread_id, str(msg.id))
|
|
||||||
context.state.increment_thread_stats(target_channel_id, thread_id, messages=1, files=len(files) if files else 0)
|
|
||||||
else:
|
|
||||||
context.state.update_last_message_timestamp(target_channel_id, str(msg.created_at))
|
|
||||||
context.state.update_last_message_id(target_channel_id, str(msg.id))
|
|
||||||
context.state.increment_stats(target_channel_id, messages=1, files=len(files) if files else 0)
|
|
||||||
|
|
||||||
stats["messages"] += 1
|
|
||||||
stats["last_message_content"] = content
|
|
||||||
stats["last_message_author"] = msg.author.display_name
|
|
||||||
|
|
||||||
# Check for associated thread (Normal case: parent message is migrated)
|
|
||||||
if hasattr(msg, 'thread') and msg.thread:
|
if hasattr(msg, 'thread') and msg.thread:
|
||||||
thread = msg.thread
|
thread = msg.thread
|
||||||
if thread.id not in processed_threads:
|
if thread.id not in processed_threads:
|
||||||
processed_threads.add(thread.id)
|
processed_threads.add(thread.id)
|
||||||
# Track thread entry
|
|
||||||
stats["threads"] += 1
|
stats["threads"] += 1
|
||||||
|
|
||||||
# Fetch last migrated message ID for this thread
|
|
||||||
thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(thread.id))
|
thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(thread.id))
|
||||||
if thread_after_id:
|
|
||||||
logger.info(f"Resuming thread '{thread.name}' from after message ID: {thread_after_id}")
|
|
||||||
|
|
||||||
# Migrate thread messages recursively
|
|
||||||
thread_stats = await migrate_messages(
|
thread_stats = await migrate_messages(
|
||||||
context=context,
|
context=context,
|
||||||
source_channel_id=thread.id,
|
source_channel_id=thread.id,
|
||||||
|
|
@ -641,7 +769,6 @@ async def migrate_messages(
|
||||||
stats["attachments"] += thread_stats["attachments"]
|
stats["attachments"] += thread_stats["attachments"]
|
||||||
stats["threads"] += thread_stats["threads"]
|
stats["threads"] += thread_stats["threads"]
|
||||||
|
|
||||||
# Send End Marker
|
|
||||||
if context.is_running:
|
if context.is_running:
|
||||||
await context.stoat_writer.send_marker(
|
await context.stoat_writer.send_marker(
|
||||||
channel_id=target_channel_id,
|
channel_id=target_channel_id,
|
||||||
|
|
@ -656,13 +783,12 @@ async def migrate_messages(
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback(stats)
|
await progress_callback(stats)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If it's a permission error, stop the entire migration
|
if "MissingPermission" in str(e): raise
|
||||||
if "MissingPermission" in str(e):
|
|
||||||
raise
|
|
||||||
logger.error(f"Failed to process message {msg.id}: {e}")
|
logger.error(f"Failed to process message {msg.id}: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
# Mark thread as completed if we finished the loop without being interrupted
|
# Mark thread as completed if we finished the loop without being interrupted
|
||||||
if thread_id and context.is_running:
|
if thread_id and context.is_running:
|
||||||
context.state.update_thread_completed(target_channel_id, thread_id, completed=True)
|
context.state.update_thread_completed(target_channel_id, thread_id, completed=True)
|
||||||
|
|
@ -795,107 +921,15 @@ async def migrate_global_messages(
|
||||||
elif msg.channel.type in [11, 12]:
|
elif msg.channel.type in [11, 12]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Formatting
|
|
||||||
files = []
|
|
||||||
|
|
||||||
# Always ensure alias is created/retrieved to populate user_alias table
|
|
||||||
alias = context.state.get_user_alias(str(msg.author.id))
|
|
||||||
|
|
||||||
anonymize_users = context.config.anonymize_users if hasattr(context, 'config') else False
|
|
||||||
|
|
||||||
if anonymize_users:
|
|
||||||
author_name = alias or "Anonymized User"
|
|
||||||
author_avatar_url = None
|
|
||||||
else:
|
|
||||||
author_name = msg.author.display_name
|
|
||||||
author_avatar_url = msg.author.avatar.url if hasattr(msg.author, 'avatar') and msg.author.avatar else None
|
|
||||||
|
|
||||||
for att in msg.attachments:
|
|
||||||
media_info = db_media.get(att.local_hash) if db_media else None
|
|
||||||
local_path = None
|
|
||||||
if media_info:
|
|
||||||
local_path = Path(media_info["local_path"])
|
|
||||||
|
|
||||||
if local_path and local_path.exists():
|
|
||||||
try:
|
try:
|
||||||
with open(local_path, "rb") as f:
|
await _process_and_send_message(
|
||||||
files.append({"filename": att.filename, "data": f.read()})
|
context=context,
|
||||||
except Exception as fe:
|
msg=msg,
|
||||||
logger.error(f"Failed to read file {local_path}: {fe}")
|
target_channel_id=target_channel_id,
|
||||||
|
stats=stats,
|
||||||
content = msg.content or ""
|
processed_threads=processed_threads
|
||||||
|
|
||||||
for sticker in msg.stickers:
|
|
||||||
sticker_name = sticker.name
|
|
||||||
s_hash = sticker.local_hash
|
|
||||||
sticker_file = None
|
|
||||||
s_media = db_media.get(s_hash) if db_media and s_hash else None
|
|
||||||
if s_media:
|
|
||||||
s_path = Path(s_media["local_path"])
|
|
||||||
if s_path.exists():
|
|
||||||
sticker_file = s_path
|
|
||||||
|
|
||||||
content += f"\n[Sticker: {sticker_name}]"
|
|
||||||
if sticker_file:
|
|
||||||
files.append(sticker_file)
|
|
||||||
file_names.append(f"sticker_{sticker_name}.png")
|
|
||||||
|
|
||||||
content = clean_mentions(
|
|
||||||
content=content,
|
|
||||||
guild=context.discord_reader.guild,
|
|
||||||
user_mentions=msg.mentions,
|
|
||||||
role_mentions=msg.role_mentions,
|
|
||||||
channel_mentions=msg.channel_mentions,
|
|
||||||
emoji_map=emoji_map,
|
|
||||||
channel_map=context.state.channel_map,
|
|
||||||
state=context.state,
|
|
||||||
target_server_id=target_server_id,
|
|
||||||
channel_names=context.channel_names if hasattr(context, 'channel_names') else None,
|
|
||||||
anonymize_users=anonymize_users
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not content and not files:
|
|
||||||
logger.debug(f"Message {msg.id} empty after processing, skipping.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
timestamp_int = int(msg.created_at.timestamp())
|
|
||||||
|
|
||||||
if msg.reference and msg.reference.message_id:
|
|
||||||
# Resolve the author of the message being replied to
|
|
||||||
source_ref_msg = await context.discord_reader.get_message(msg.channel_id, msg.reference.message_id)
|
|
||||||
if source_ref_msg and source_ref_msg.author:
|
|
||||||
ref_author_id = str(source_ref_msg.author.id)
|
|
||||||
if anonymize_users:
|
|
||||||
ref_name = context.state.get_user_alias(ref_author_id) or "Anonymized User"
|
|
||||||
else:
|
|
||||||
ref_name = source_ref_msg.author.display_name
|
|
||||||
content = f"`@{ref_name}`\n{content}"
|
|
||||||
else:
|
|
||||||
tgt_reply = context.state.get_target_message_id(target_channel_id, msg.reference.message_id)
|
|
||||||
if tgt_reply:
|
|
||||||
content = f"[Reply to {tgt_reply}]\n{content}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
stoat_msg_id = await context.stoat_writer.send_message(
|
|
||||||
channel_id=target_channel_id,
|
|
||||||
author_name=author_name,
|
|
||||||
author_avatar_url=author_avatar_url,
|
|
||||||
content=content,
|
|
||||||
files=files,
|
|
||||||
timestamp=timestamp_int,
|
|
||||||
embeds=msg.embeds
|
|
||||||
)
|
|
||||||
|
|
||||||
if stoat_msg_id:
|
|
||||||
context.state.set_target_message_mapping(target_channel_id, msg.id, stoat_msg_id)
|
|
||||||
context.state.update_last_message_id(target_channel_id, msg.id)
|
|
||||||
context.state.set_waterfall_last_id(msg.id)
|
|
||||||
stats["attachments"] += len(files) if files else 0
|
|
||||||
|
|
||||||
stats["messages"] += 1
|
|
||||||
stats["last_message_content"] = content
|
|
||||||
stats["last_message_author"] = author_name
|
|
||||||
|
|
||||||
if not stats["first_message_url"]:
|
if not stats["first_message_url"]:
|
||||||
stats["first_message_url"] = msg.jump_url
|
stats["first_message_url"] = msg.jump_url
|
||||||
stats["last_message_url"] = msg.jump_url
|
stats["last_message_url"] = msg.jump_url
|
||||||
|
|
@ -906,6 +940,7 @@ async def migrate_global_messages(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to process global message {msg.id}: {e}")
|
logger.error(f"Failed to process global message {msg.id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
except (KeyboardInterrupt, asyncio.CancelledError):
|
except (KeyboardInterrupt, asyncio.CancelledError):
|
||||||
context.is_running = False
|
context.is_running = False
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -5,24 +5,90 @@ from typing import Optional, List, Dict, Any
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def _discover_stoat_config(api_url: str) -> Dict[str, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Fetches the Stoat/Revolt instance configuration to discover the WS and CDN URLs.
|
||||||
|
Returns a dict with 'ws' and 'cdn' keys.
|
||||||
|
"""
|
||||||
|
results = {"ws": None, "cdn": None}
|
||||||
|
|
||||||
|
if not api_url or api_url == "default" or "stoat.chat" in api_url:
|
||||||
|
return results
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
try:
|
||||||
|
# Standard Revolt discovery endpoint is the root API URL
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(api_url.rstrip("/") + "/") as resp:
|
||||||
|
if resp.status == 200:
|
||||||
|
data = await resp.json()
|
||||||
|
results["ws"] = data.get("ws")
|
||||||
|
# Features might be in 'features.autumn.url'
|
||||||
|
features = data.get("features", {})
|
||||||
|
autumn = features.get("autumn", {})
|
||||||
|
results["cdn"] = autumn.get("url")
|
||||||
|
|
||||||
|
if results["ws"] or results["cdn"]:
|
||||||
|
logger.debug(f"Stoat Discovery: Found WS={results['ws']}, CDN={results['cdn']}")
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Stoat Discovery failed (fetching): {e}")
|
||||||
|
|
||||||
|
# Fallback to inference if fetch failed
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
try:
|
||||||
|
parsed = urlparse(api_url)
|
||||||
|
if parsed.netloc:
|
||||||
|
# Traditional defaults for self-hosted
|
||||||
|
results["ws"] = f"wss://{parsed.netloc}/ws"
|
||||||
|
results["cdn"] = f"https://{parsed.netloc}/autumn"
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
class StoatWriter:
|
class StoatWriter:
|
||||||
def __init__(self, token: str, community_id: str, api_url: str = "default"):
|
def __init__(self, token: str, community_id: str, api_url: str = "default", ws_url: str = None):
|
||||||
self.token = token
|
self.token = token
|
||||||
self.community_id = str(community_id)
|
self.community_id = str(community_id)
|
||||||
self.api_url = api_url
|
self.api_url = api_url
|
||||||
|
self.ws_url = ws_url
|
||||||
self.client: Optional[stoat.Client] = None
|
self.client: Optional[stoat.Client] = None
|
||||||
self._server = None
|
self._server = None
|
||||||
self._me = None
|
self._me = None
|
||||||
self._validation_cache = None
|
self._validation_cache = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def fetch_guilds(token: str, api_url: str = "default") -> list[tuple[str, str]]:
|
async def fetch_guilds(token: str, api_url: str = "default", ws_url: str = None, cdn_url: str = None) -> list[tuple[str, str]]:
|
||||||
"""Fetches the list of Stoat servers the bot is in. Returns list of (label, id)."""
|
"""Fetches the list of Stoat servers the bot is in. Returns list of (label, id)."""
|
||||||
client_kwargs = {"token": token, "bot": True}
|
token = token.strip()
|
||||||
if api_url and api_url != "default":
|
api_url = (api_url or "default").strip()
|
||||||
client_kwargs["http_base"] = api_url
|
ws_url = (ws_url or "").strip()
|
||||||
|
cdn_url = (cdn_url or "").strip()
|
||||||
|
|
||||||
|
# Auto-discover URLs if not provided for custom domains
|
||||||
|
if api_url != "default" and (not ws_url or not cdn_url):
|
||||||
|
discovery = await _discover_stoat_config(api_url)
|
||||||
|
if not ws_url: ws_url = discovery["ws"] or ""
|
||||||
|
if not cdn_url: cdn_url = discovery["cdn"] or ""
|
||||||
|
|
||||||
|
# Diagnostics to both stdout and logger
|
||||||
|
log_msg = f"Stoat: Fetching guilds using API URL: {api_url}"
|
||||||
|
if ws_url: log_msg += f" [WS: {ws_url}]"
|
||||||
|
if cdn_url: log_msg += f" [CDN: {cdn_url}]"
|
||||||
|
print(log_msg)
|
||||||
|
logger.debug(log_msg)
|
||||||
|
|
||||||
|
client_kwargs = {
|
||||||
|
"token": token,
|
||||||
|
"bot": True,
|
||||||
|
"http_base": api_url if api_url != "default" else None,
|
||||||
|
"websocket_base": ws_url or None,
|
||||||
|
"cdn_base": cdn_url or None
|
||||||
|
}
|
||||||
client = stoat.Client(**client_kwargs)
|
client = stoat.Client(**client_kwargs)
|
||||||
|
logger.debug(f"Stoat: Initialized client with native http_base={client_kwargs['http_base']} and websocket_base={client_kwargs['websocket_base']}")
|
||||||
|
|
||||||
ready_event = asyncio.Event()
|
ready_event = asyncio.Event()
|
||||||
servers_list = []
|
servers_list = []
|
||||||
|
|
||||||
|
|
@ -58,14 +124,21 @@ class StoatWriter:
|
||||||
else:
|
else:
|
||||||
raise asyncio.TimeoutError("Timed out waiting for Stoat to be ready")
|
raise asyncio.TimeoutError("Timed out waiting for Stoat to be ready")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to fetch Stoat servers: {e}")
|
||||||
logger.error(f"Failed to fetch Stoat servers: {e}")
|
logger.error(f"Failed to fetch Stoat servers: {e}")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
|
# Shutdown the specific client instance used for fetching
|
||||||
|
try:
|
||||||
await client.close()
|
await client.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
client_task.cancel()
|
client_task.cancel()
|
||||||
try:
|
try:
|
||||||
await client_task
|
# Wait for the task to actually finish terminating
|
||||||
except asyncio.CancelledError:
|
await asyncio.wait_for(client_task, timeout=2.0)
|
||||||
|
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return guilds_list
|
return guilds_list
|
||||||
|
|
@ -80,16 +153,45 @@ class StoatWriter:
|
||||||
pass
|
pass
|
||||||
self.client = None
|
self.client = None
|
||||||
|
|
||||||
client_kwargs = {"token": self.token, "bot": True}
|
api_url = (self.api_url or "default").strip()
|
||||||
if self.api_url and self.api_url != "default":
|
ws_url = (self.ws_url or "").strip()
|
||||||
client_kwargs["http_base"] = self.api_url
|
cdn_url = getattr(self, "cdn_url", "").strip()
|
||||||
|
|
||||||
|
# Auto-discover if not provided for custom domains
|
||||||
|
if api_url != "default" and (not ws_url or not cdn_url):
|
||||||
|
discovery = await _discover_stoat_config(api_url)
|
||||||
|
if not ws_url: ws_url = discovery["ws"] or ""
|
||||||
|
if not cdn_url: cdn_url = discovery["cdn"] or ""
|
||||||
|
|
||||||
|
token = self.token.strip()
|
||||||
|
|
||||||
|
# Diagnostics to both stdout and logger
|
||||||
|
log_msg = f"Stoat: Starting client using API URL: {api_url}"
|
||||||
|
if ws_url: log_msg += f" [WS: {ws_url}]"
|
||||||
|
if cdn_url: log_msg += f" [CDN: {cdn_url}]"
|
||||||
|
print(log_msg)
|
||||||
|
logger.debug(log_msg)
|
||||||
|
|
||||||
|
client_kwargs = {
|
||||||
|
"token": token,
|
||||||
|
"bot": True,
|
||||||
|
"http_base": api_url if api_url != "default" else None,
|
||||||
|
"websocket_base": ws_url or None,
|
||||||
|
"cdn_base": cdn_url or None
|
||||||
|
}
|
||||||
self.client = stoat.Client(**client_kwargs)
|
self.client = stoat.Client(**client_kwargs)
|
||||||
|
|
||||||
|
# Keep track of the start task so we can clean it up later
|
||||||
|
self._start_task = asyncio.create_task(self.client.start())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._me = await self.client.fetch_user("@me")
|
self._me = await self.client.fetch_user("@me")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to fetch bot user in StoatWriter: {e}")
|
||||||
logger.error(f"Failed to fetch bot user in StoatWriter: {e}")
|
logger.error(f"Failed to fetch bot user in StoatWriter: {e}")
|
||||||
self.client = None # Reset if we can't even fetch @me
|
# Ensure we clean up if start fails
|
||||||
|
await self.close()
|
||||||
|
self.client = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def my_id(self):
|
def my_id(self):
|
||||||
|
|
@ -235,68 +337,93 @@ class StoatWriter:
|
||||||
|
|
||||||
return results
|
return results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to fetch Stoat channels: {e}")
|
||||||
logger.error(f"Failed to fetch Stoat channels: {e}")
|
logger.error(f"Failed to fetch Stoat channels: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def create_channel(self, name: str, type: int = 0, topic: str = "", parent_id: Optional[str] = None, **kwargs) -> str:
|
async def create_channel(self, name: str, type: int = 0, topic: str = "", parent_id: Optional[str] = None, **kwargs) -> str:
|
||||||
server = await self._get_server(populate_channels=True)
|
server = await self._get_server(populate_channels=True)
|
||||||
try:
|
try:
|
||||||
if type == 4: # Category
|
if type == 4: # Category — use direct HTTP to avoid stoat.py auth issues
|
||||||
# The POST /categories endpoint throws 404 on some server versions, so we use server.edit(categories)
|
import aiohttp, random, time
|
||||||
import random
|
|
||||||
import time
|
|
||||||
chars = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"
|
chars = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"
|
||||||
# Mock a ULID
|
ts = int(time.time() * 1000)
|
||||||
new_id = "01" + "".join(random.choice(chars) for _ in range(24))
|
new_id = format(ts, '010X') + "".join(random.choice(chars) for _ in range(16))
|
||||||
|
api_base = (self.api_url or "https://api.stoat.chat/0.8").rstrip("/")
|
||||||
categories = list(server.categories) if hasattr(server, "categories") and server.categories else []
|
# Fetch current server to get existing categories
|
||||||
# Workaround for stoat.py bug: existing categories may fail to_dict() if slots are uninitialized
|
async with aiohttp.ClientSession() as session:
|
||||||
for c in categories:
|
async with session.get(
|
||||||
if not hasattr(c, "default_permissions"): c.default_permissions = None
|
f"{api_base}/servers/{self.community_id}",
|
||||||
if not hasattr(c, "role_permissions"): c.role_permissions = {}
|
headers={"X-Bot-Token": self.token.strip()},
|
||||||
|
) as resp:
|
||||||
new_cat = stoat.Category(id=new_id, title=name, channels=[])
|
sdata = await resp.json()
|
||||||
if not hasattr(new_cat, "default_permissions"): new_cat.default_permissions = None
|
existing = sdata.get("categories") or []
|
||||||
if not hasattr(new_cat, "role_permissions"): new_cat.role_permissions = {}
|
existing.append({"id": new_id, "title": name, "channels": []})
|
||||||
categories.append(new_cat)
|
async with session.patch(
|
||||||
|
f"{api_base}/servers/{self.community_id}",
|
||||||
await server.edit(categories=categories)
|
json={"categories": existing},
|
||||||
self._server = None # Clear cache after structural change
|
headers={"X-Bot-Token": self.token.strip(), "Content-Type": "application/json"},
|
||||||
|
) as resp:
|
||||||
|
data = await resp.json()
|
||||||
|
logger.debug(f"PATCH /servers category resp {resp.status}: {data}")
|
||||||
|
print(f"[writer] PATCH /servers category resp {resp.status}: {data}")
|
||||||
|
if resp.status not in (200, 201):
|
||||||
|
raise Exception(f"HTTP {resp.status}: {data}")
|
||||||
|
self._server = None
|
||||||
return new_id
|
return new_id
|
||||||
elif type == 2: # Voice Channel
|
else:
|
||||||
ch = await server.create_voice_channel(name=name)
|
# Use direct HTTP instead of stoat.py client to avoid aiohttp session issues
|
||||||
self._server = None # Clear cache
|
import aiohttp
|
||||||
return str(ch.id)
|
api_base = (self.api_url or "https://api.stoat.chat/0.8").rstrip("/")
|
||||||
else: # Text Channel
|
channel_type = "Voice" if type == 2 else "Text"
|
||||||
ch = await server.create_text_channel(name=name, description=topic)
|
payload: Dict[str, Any] = {"name": name, "type": channel_type}
|
||||||
self._server = None # Clear cache
|
if topic:
|
||||||
return str(ch.id)
|
payload["description"] = topic
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{api_base}/servers/{self.community_id}/channels",
|
||||||
|
json=payload,
|
||||||
|
headers={"X-Bot-Token": self.token.strip()},
|
||||||
|
) as resp:
|
||||||
|
try:
|
||||||
|
data = await resp.json(content_type=None)
|
||||||
|
except Exception:
|
||||||
|
data = await resp.text()
|
||||||
|
if resp.status not in (200, 201):
|
||||||
|
raise Exception(f"HTTP {resp.status}: {data}")
|
||||||
|
self._server = None
|
||||||
|
return str(data["_id"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to create Stoat channel {name}: {e}")
|
||||||
logger.error(f"Failed to create Stoat channel {name}: {e}")
|
logger.error(f"Failed to create Stoat channel {name}: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def modify_channel(self, channel_id: str, name: Optional[str] = None, topic: Optional[str] = None, nsfw: Optional[bool] = None, slowmode_delay: Optional[int] = None, **kwargs) -> bool:
|
async def modify_channel(self, channel_id: str, name: Optional[str] = None, topic: Optional[str] = None, nsfw: Optional[bool] = None, slowmode_delay: Optional[int] = None, **kwargs) -> bool:
|
||||||
server = await self._get_server(populate_channels=True)
|
import aiohttp
|
||||||
try:
|
api_base = (self.api_url or "https://api.stoat.chat/0.8").rstrip("/")
|
||||||
channel = next((c for c in server.channels if str(c.id) == channel_id), None)
|
payload: Dict[str, Any] = {}
|
||||||
if not channel:
|
|
||||||
return False
|
|
||||||
|
|
||||||
edit_kwargs = {}
|
|
||||||
if name is not None:
|
if name is not None:
|
||||||
edit_kwargs["name"] = name
|
payload["name"] = name
|
||||||
if topic is not None:
|
if topic is not None:
|
||||||
edit_kwargs["description"] = topic
|
payload["description"] = topic
|
||||||
if nsfw is not None:
|
if nsfw is not None:
|
||||||
edit_kwargs["nsfw"] = nsfw
|
payload["nsfw"] = nsfw
|
||||||
|
if not payload:
|
||||||
if edit_kwargs:
|
return True
|
||||||
await channel.edit(**edit_kwargs)
|
try:
|
||||||
self._server = None # Clear cache
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.patch(
|
||||||
# clone_server.py now handles all parenting bulk logic
|
f"{api_base}/channels/{channel_id}",
|
||||||
|
json=payload,
|
||||||
|
headers={"X-Bot-Token": self.token.strip()},
|
||||||
|
) as resp:
|
||||||
|
if resp.status not in (200, 204):
|
||||||
|
data = await resp.json()
|
||||||
|
raise Exception(f"HTTP {resp.status}: {data}")
|
||||||
|
self._server = None
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to modify Stoat channel {channel_id}: {e}")
|
||||||
logger.error(f"Failed to modify Stoat channel {channel_id}: {e}")
|
logger.error(f"Failed to modify Stoat channel {channel_id}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -397,6 +524,10 @@ class StoatWriter:
|
||||||
color=color
|
color=color
|
||||||
))
|
))
|
||||||
|
|
||||||
|
# Retry logic for 'NotFound' (common race condition on self-hosted instances)
|
||||||
|
max_tries = 2
|
||||||
|
for attempt in range(max_tries):
|
||||||
|
try:
|
||||||
msg = await channel.send(
|
msg = await channel.send(
|
||||||
content=final_content,
|
content=final_content,
|
||||||
masquerade=masquerade,
|
masquerade=masquerade,
|
||||||
|
|
@ -405,6 +536,14 @@ class StoatWriter:
|
||||||
embeds=stoat_embeds
|
embeds=stoat_embeds
|
||||||
)
|
)
|
||||||
return str(msg.id) if msg else None
|
return str(msg.id) if msg else None
|
||||||
|
except Exception as send_err:
|
||||||
|
# 'NotFound' often means the attachment is still being processed by the database
|
||||||
|
if "NotFound" in str(send_err) and attempt < max_tries - 1:
|
||||||
|
logger.warning(f"Stoat: Received NotFound during send (likely race condition). Retrying in 1.5s... (Attempt {attempt+1}/{max_tries})")
|
||||||
|
await asyncio.sleep(1.5)
|
||||||
|
continue
|
||||||
|
raise send_err
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If file type not allowed, skip attachments and still send the message
|
# If file type not allowed, skip attachments and still send the message
|
||||||
if "FileTypeNotAllowed" in str(e) and attachments:
|
if "FileTypeNotAllowed" in str(e) and attachments:
|
||||||
|
|
@ -419,6 +558,7 @@ class StoatWriter:
|
||||||
return str(msg.id) if msg else None
|
return str(msg.id) if msg else None
|
||||||
raise # Re-raise MissingPermission and other errors
|
raise # Re-raise MissingPermission and other errors
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to send Stoat message to {channel_id}: {e}")
|
||||||
logger.error(f"Failed to send Stoat message to {channel_id}: {e}")
|
logger.error(f"Failed to send Stoat message to {channel_id}: {e}")
|
||||||
raise # Let caller handle (migration loop will stop for permission errors)
|
raise # Let caller handle (migration loop will stop for permission errors)
|
||||||
|
|
||||||
|
|
@ -442,6 +582,7 @@ class StoatWriter:
|
||||||
)
|
)
|
||||||
return str(msg.id)
|
return str(msg.id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to send Stoat marker to {channel_id}: {e}")
|
||||||
logger.error(f"Failed to send Stoat marker to {channel_id}: {e}")
|
logger.error(f"Failed to send Stoat marker to {channel_id}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -465,11 +606,39 @@ class StoatWriter:
|
||||||
|
|
||||||
# Set permissions
|
# Set permissions
|
||||||
if permissions != 0:
|
if permissions != 0:
|
||||||
s_perms = self._map_permissions(permissions)
|
requested_perms = self._map_permissions(permissions)
|
||||||
await server.set_role_permissions(role, allow=s_perms)
|
|
||||||
|
# Fetch bot's own permissions to mask the requested set.
|
||||||
|
# Stoat/Revolt prevents granting permissions that the bot itself lacks.
|
||||||
|
try:
|
||||||
|
me = await server.fetch_member(self.my_id)
|
||||||
|
bot_perms = server.permissions_for(me)
|
||||||
|
|
||||||
|
# Manual masking of boolean attributes
|
||||||
|
final_perms = stoat.Permissions.none()
|
||||||
|
for attr in dir(requested_perms):
|
||||||
|
if not attr.startswith("_") and isinstance(getattr(requested_perms, attr), bool):
|
||||||
|
if getattr(requested_perms, attr) and getattr(bot_perms, attr, False):
|
||||||
|
try:
|
||||||
|
setattr(final_perms, attr, True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
await server.set_role_permissions(role, allow=final_perms)
|
||||||
|
except Exception as perm_err:
|
||||||
|
logger.warning(f"Stoat: Could not mask/verify role permissions for {name} (falling back to minimal): {perm_err}")
|
||||||
|
# Attempt to set at least one basic permission if possible, or just skip
|
||||||
|
try:
|
||||||
|
minimal = stoat.Permissions.none()
|
||||||
|
minimal.view_channel = True
|
||||||
|
minimal.send_messages = True
|
||||||
|
await server.set_role_permissions(role, allow=minimal)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
return str(role.id)
|
return str(role.id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to create Stoat role {name}: {e}")
|
||||||
logger.error(f"Failed to create Stoat role {name}: {e}")
|
logger.error(f"Failed to create Stoat role {name}: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
@ -526,6 +695,7 @@ class StoatWriter:
|
||||||
await server.set_default_permissions(s_perms)
|
await server.set_default_permissions(s_perms)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to update Stoat default permissions: {e}")
|
||||||
logger.error(f"Failed to update Stoat default permissions: {e}")
|
logger.error(f"Failed to update Stoat default permissions: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -536,6 +706,7 @@ class StoatWriter:
|
||||||
emoji = await server.create_server_emoji(name=name, image=image_bytes)
|
emoji = await server.create_server_emoji(name=name, image=image_bytes)
|
||||||
return str(emoji.id)
|
return str(emoji.id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to create Stoat emoji {name}: {e}")
|
||||||
logger.error(f"Failed to create Stoat emoji {name}: {e}")
|
logger.error(f"Failed to create Stoat emoji {name}: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
@ -551,6 +722,7 @@ class StoatWriter:
|
||||||
banner=banner if banner is not None else stoat.UNDEFINED
|
banner=banner if banner is not None else stoat.UNDEFINED
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to update Stoat guild metadata: {e}")
|
||||||
logger.error(f"Failed to update Stoat guild metadata: {e}")
|
logger.error(f"Failed to update Stoat guild metadata: {e}")
|
||||||
|
|
||||||
async def remove_community_logo_and_banner(self) -> dict:
|
async def remove_community_logo_and_banner(self) -> dict:
|
||||||
|
|
@ -562,12 +734,14 @@ class StoatWriter:
|
||||||
try:
|
try:
|
||||||
await server.edit(icon=None)
|
await server.edit(icon=None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to remove Stoat community icon: {e}")
|
||||||
logger.error(f"Failed to remove Stoat community icon: {e}")
|
logger.error(f"Failed to remove Stoat community icon: {e}")
|
||||||
|
|
||||||
if has_banner:
|
if has_banner:
|
||||||
try:
|
try:
|
||||||
await server.edit(banner=None)
|
await server.edit(banner=None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to remove Stoat community banner: {e}")
|
||||||
logger.error(f"Failed to remove Stoat community banner: {e}")
|
logger.error(f"Failed to remove Stoat community banner: {e}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -594,6 +768,7 @@ class StoatWriter:
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback(name, i, total)
|
await progress_callback(name, i, total)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to delete Stoat channel {ch.id}: {e}")
|
||||||
logger.error(f"Failed to delete Stoat channel {ch.id}: {e}")
|
logger.error(f"Failed to delete Stoat channel {ch.id}: {e}")
|
||||||
|
|
||||||
# To delete categories, we can wipe the categories array via server.edit to avoid 404 endpoint
|
# To delete categories, we can wipe the categories array via server.edit to avoid 404 endpoint
|
||||||
|
|
@ -618,6 +793,7 @@ class StoatWriter:
|
||||||
await progress_callback(name, j, total)
|
await progress_callback(name, j, total)
|
||||||
j += 1
|
j += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to wipe Stoat categories via edit: {e}")
|
||||||
logger.error(f"Failed to wipe Stoat categories via edit: {e}")
|
logger.error(f"Failed to wipe Stoat categories via edit: {e}")
|
||||||
|
|
||||||
return count
|
return count
|
||||||
|
|
@ -634,17 +810,23 @@ class StoatWriter:
|
||||||
logger.info(f"Danger Zone: Skipping permission reset for audit channel {name}")
|
logger.info(f"Danger Zone: Skipping permission reset for audit channel {name}")
|
||||||
total -= 1
|
total -= 1
|
||||||
continue
|
continue
|
||||||
# In Stoat, clearing overrides might involve setting them to default or explicitly removing the role_permissions/default_permissions
|
|
||||||
# Since we don't know an explicit "clear_overrides" method, we'll wipe them by setting empty/none if possible.
|
# Fetch fresh channel to get current role_permissions
|
||||||
# Actually Stoat allows overwriting. Setting allow=0 deny=0 for role overrides isn't explicitly clear.
|
fresh_ch = await self.client.fetch_channel(ch.id)
|
||||||
# For safety, we will just pass. If the user expects it, we'd iterate over roles and set empty.
|
# Clear default permissions
|
||||||
# A quick way is to edit the channel permissions to empty state if possible.
|
if hasattr(fresh_ch, "default_permissions") and fresh_ch.default_permissions is not None:
|
||||||
# Let's count them anyway.
|
await fresh_ch.set_default_permissions(None)
|
||||||
# (Fluxer writer does a loop over existing overrides, we can just return 0 for now until we inspect Stoat `PermissionOverride` deletion)
|
|
||||||
|
# Clear all role overrides
|
||||||
|
if hasattr(fresh_ch, "role_permissions"):
|
||||||
|
for role_id in list(fresh_ch.role_permissions.keys()):
|
||||||
|
await fresh_ch.set_role_permissions(str(role_id), allow=stoat.Permissions.none(), deny=stoat.Permissions.none())
|
||||||
|
|
||||||
count += 1
|
count += 1
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback(name, i, total)
|
await progress_callback(name, i, total)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to reset Stoat channel permissions for {ch.id}: {e}")
|
||||||
logger.error(f"Failed to reset Stoat channel permissions for {ch.id}: {e}")
|
logger.error(f"Failed to reset Stoat channel permissions for {ch.id}: {e}")
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
@ -671,6 +853,7 @@ class StoatWriter:
|
||||||
if "MissingPermission" in err_msg and "ViewChannel" in err_msg:
|
if "MissingPermission" in err_msg and "ViewChannel" in err_msg:
|
||||||
logger.error(f"Stoat LOCKOUT: Bot lacks 'ViewChannel' to edit {channel_id}. "
|
logger.error(f"Stoat LOCKOUT: Bot lacks 'ViewChannel' to edit {channel_id}. "
|
||||||
"Ensure the bot has 'Manage Server' or a role with 'Allow View Channel' rank higher than @everyone.")
|
"Ensure the bot has 'Manage Server' or a role with 'Allow View Channel' rank higher than @everyone.")
|
||||||
|
print(f"Failed to set Stoat channel permission for {overwrite_id} on {channel_id}: {e}")
|
||||||
logger.error(f"Failed to set Stoat channel permission for {overwrite_id} on {channel_id}: {e}")
|
logger.error(f"Failed to set Stoat channel permission for {overwrite_id} on {channel_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -712,18 +895,31 @@ class StoatWriter:
|
||||||
await emoji.delete()
|
await emoji.delete()
|
||||||
count += 1
|
count += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"Failed to delete Stoat emoji {emoji.name}: {e}")
|
||||||
logger.error(f"Failed to delete Stoat emoji {emoji.name}: {e}")
|
logger.error(f"Failed to delete Stoat emoji {emoji.name}: {e}")
|
||||||
|
|
||||||
return {"emojis": count, "stickers": 0}
|
return {"emojis": count, "stickers": 0}
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
client = self.client
|
client = self.client
|
||||||
|
task = getattr(self, "_start_task", None)
|
||||||
|
|
||||||
self.client = None # Atomic clear to prevent new usage
|
self.client = None # Atomic clear to prevent new usage
|
||||||
|
self._start_task = None
|
||||||
self._me = None
|
self._me = None
|
||||||
self._server = None
|
self._server = None
|
||||||
|
|
||||||
if client:
|
if client:
|
||||||
try:
|
try:
|
||||||
await client.close()
|
await client.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Error closing Stoat client: {e}")
|
logger.debug(f"Error closing Stoat client session: {e}")
|
||||||
|
|
||||||
|
if task and not task.done():
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(task, timeout=2.0)
|
||||||
|
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||||
|
pass
|
||||||
|
|
||||||
self._validation_cache = None
|
self._validation_cache = None
|
||||||
|
|
|
||||||
|
|
@ -164,6 +164,9 @@ class OperationPane(Container):
|
||||||
yield Rule(id="footer_rule")
|
yield Rule(id="footer_rule")
|
||||||
yield Button("Danger Zone ⚠", id="op_danger", variant="error", disabled=True, flat=True, tooltip="Dangerous operations:\ndelete channels, roles, emojis on target\n(use with caution)")
|
yield Button("Danger Zone ⚠", id="op_danger", variant="error", disabled=True, flat=True, tooltip="Dangerous operations:\ndelete channels, roles, emojis on target\n(use with caution)")
|
||||||
|
|
||||||
|
if self.cfg_name == "AutoTest":
|
||||||
|
yield Button("RUN AUTO TEST", id="op_autotest", variant="warning", flat="false", tooltip="Execute automated test sequence for the AutoTest profile")
|
||||||
|
|
||||||
def on_mount(self) -> None:
|
def on_mount(self) -> None:
|
||||||
self._rebuild_engine()
|
self._rebuild_engine()
|
||||||
self._update_info_labels()
|
self._update_info_labels()
|
||||||
|
|
@ -326,8 +329,13 @@ class OperationPane(Container):
|
||||||
for pne in self.query("#op_target_pane"): pne.display = False
|
for pne in self.query("#op_target_pane"): pne.display = False
|
||||||
|
|
||||||
enabled = (v.get("discord_token") and v.get("discord_server") and not d_missing)
|
enabled = (v.get("discord_token") and v.get("discord_server") and not d_missing)
|
||||||
for bid in ("#op_backup_msgs", "#op_backup_sync"):
|
for btn in self.query("#op_backup_msgs"):
|
||||||
for btn in self.query(bid): btn.disabled = not enabled
|
btn.disabled = not enabled
|
||||||
|
for btn in self.query("#op_backup_sync"):
|
||||||
|
btn.display = self.has_backup
|
||||||
|
btn.disabled = not (enabled and self.has_backup)
|
||||||
|
for btn in self.query("#op_autotest"):
|
||||||
|
btn.disabled = not enabled
|
||||||
|
|
||||||
for btn in self.query("#op_backup_stats"):
|
for btn in self.query("#op_backup_stats"):
|
||||||
btn.display = self.has_backup
|
btn.display = self.has_backup
|
||||||
|
|
@ -395,7 +403,7 @@ class OperationPane(Container):
|
||||||
lbl.update(f"{t_status}")
|
lbl.update(f"{t_status}")
|
||||||
|
|
||||||
# Buttons
|
# Buttons
|
||||||
for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger"):
|
for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger", "#op_autotest"):
|
||||||
for btn in self.query(bid): btn.disabled = not self.tokens_valid
|
for btn in self.query(bid): btn.disabled = not self.tokens_valid
|
||||||
|
|
||||||
# ── validation ────────────────────────────────────────────────────────
|
# ── validation ────────────────────────────────────────────────────────
|
||||||
|
|
@ -416,10 +424,10 @@ class OperationPane(Container):
|
||||||
|
|
||||||
# Disable all operation buttons while validation is in progress
|
# Disable all operation buttons while validation is in progress
|
||||||
if self.view_mode == "shuttle":
|
if self.view_mode == "shuttle":
|
||||||
for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger"):
|
for bid in ("#op_clone", "#op_sync", "#op_messages", "#op_waterfall", "#op_danger", "#op_autotest"):
|
||||||
for btn in self.query(bid): btn.disabled = True
|
for btn in self.query(bid): btn.disabled = True
|
||||||
elif self.view_mode == "backup":
|
elif self.view_mode == "backup":
|
||||||
for bid in ("#op_backup_msgs", "#op_backup_sync"):
|
for bid in ("#op_backup_msgs", "#op_backup_sync", "#op_autotest"):
|
||||||
for btn in self.query(bid): btn.disabled = True
|
for btn in self.query(bid): btn.disabled = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in run_validate setup: {e}")
|
logger.error(f"Error in run_validate setup: {e}")
|
||||||
|
|
@ -590,7 +598,124 @@ class OperationPane(Container):
|
||||||
from src.ui.backup_stats import BackupStatsScreen
|
from src.ui.backup_stats import BackupStatsScreen
|
||||||
target_dir = Path(self._base_dir()) / f"DISCORD_BACKUP-{self.config.discord_server_id}"
|
target_dir = Path(self._base_dir()) / f"DISCORD_BACKUP-{self.config.discord_server_id}"
|
||||||
self.app.push_screen(BackupStatsScreen(self.cfg_name, target_dir))
|
self.app.push_screen(BackupStatsScreen(self.cfg_name, target_dir))
|
||||||
|
elif bid == "op_autotest" or bid == "btn_autotest":
|
||||||
|
self.run_autotest_sequence()
|
||||||
|
|
||||||
|
@work(exclusive=True)
|
||||||
|
async def run_autotest_sequence(self) -> None:
|
||||||
|
"""Entry point for the AUTO TEST sequence."""
|
||||||
|
if not self.tokens_valid:
|
||||||
|
return
|
||||||
|
|
||||||
|
modal = ProgressScreen(log_level=self.config.log_level)
|
||||||
|
self.app.push_screen(modal)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.view_mode == "shuttle":
|
||||||
|
await self._run_migration_autotest_logic(modal)
|
||||||
|
elif self.view_mode == "backup":
|
||||||
|
await self._run_backup_autotest_logic(modal)
|
||||||
|
|
||||||
|
modal.phase_report("AUTO TEST Complete", show_back=False)
|
||||||
|
modal.write("[bold green]Full automated test sequence finished successfully![/bold green]")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Auto-Test Error: {e}\n{traceback.format_exc()}")
|
||||||
|
modal.write(f"[bold red]Error: {e}[/bold red]")
|
||||||
|
modal.phase_report("Auto-Test", "error", show_back=False)
|
||||||
|
finally:
|
||||||
|
self.engine.is_running = False
|
||||||
|
await self.engine.close_connections()
|
||||||
|
|
||||||
|
async def _run_migration_autotest_logic(self, modal: ProgressScreen) -> None:
|
||||||
|
"""Executes the full migration test sequence."""
|
||||||
|
modal.set_status("AUTO TEST: Launching Migration Sequence...")
|
||||||
|
modal.write("[bold yellow]Starting Migration Auto-Test...[/bold yellow]")
|
||||||
|
|
||||||
|
migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate
|
||||||
|
|
||||||
|
# 1. Connect and initialize state
|
||||||
|
modal.set_status("Connecting and Initializing State...")
|
||||||
|
await self.engine.start_connections()
|
||||||
|
self.engine.is_running = True
|
||||||
|
|
||||||
|
# Initialize state database early to avoid Warnings
|
||||||
|
tid = self.config.fluxer_server_id if self.target_platform == "fluxer" else self.config.stoat_server_id
|
||||||
|
tgt_info = await self.engine.writer.validate()
|
||||||
|
self.engine.ensure_state_initialized(str(tid or ""), tgt_info.get("community_name", "Target"))
|
||||||
|
|
||||||
|
# 2. Danger Zone Clean
|
||||||
|
modal.write("\n[bold red]Phase 1: Danger Zone Cleanup[/bold red]")
|
||||||
|
await self._logic_dz_delete_channels(modal)
|
||||||
|
await self._logic_dz_reset_perms(modal)
|
||||||
|
await self._logic_dz_delete_roles(modal)
|
||||||
|
await self._logic_dz_delete_assets(modal)
|
||||||
|
|
||||||
|
# 3. Clone Roles & Permissions
|
||||||
|
modal.write("\n[bold cyan]Phase 2: Cloning Roles & Permissions[/bold cyan]")
|
||||||
|
await self._logic_clone_roles(modal, force=True)
|
||||||
|
|
||||||
|
# 4. Clone Assets (Logo, Banner, Emojis, Stickers)
|
||||||
|
modal.write("\n[bold cyan]Phase 3: Syncing Metadata & Assets[/bold cyan]")
|
||||||
|
await self._logic_copy_assets(modal, ["Emoji", "Sticker"], force=True)
|
||||||
|
await self._logic_sync_metadata(modal, ["name", "icon", "banner"])
|
||||||
|
|
||||||
|
# 5. Clone Template (Structure)
|
||||||
|
modal.write("\n[bold cyan]Phase 4: Cloning Server Structure (1/2)[/bold cyan]")
|
||||||
|
await self._logic_clone_channels(modal, force=True)
|
||||||
|
await self._logic_sync_permissions(modal)
|
||||||
|
|
||||||
|
# 6. Waterfall Migration (Only for backups)
|
||||||
|
if self.engine.source_mode == "backup" :
|
||||||
|
modal.write("\n[bold cyan]Phase 5: Waterfall Message Migration[/bold cyan]")
|
||||||
|
await self._logic_waterfall_migration(modal=modal, is_autotest=True)
|
||||||
|
else:
|
||||||
|
modal.write("\n[bold yellow]Phase 5: Skipping Waterfall (Live mode selected)[/bold yellow]")
|
||||||
|
|
||||||
|
# 7. Individual Channel Migration (Automated)
|
||||||
|
modal.write("\n[bold cyan]Phase 7: Individual Channel Migration[/bold cyan]")
|
||||||
|
await self._logic_autotest_migrate_all_channels(modal=modal)
|
||||||
|
|
||||||
|
async def _run_backup_autotest_logic(self, modal: ProgressScreen) -> None:
|
||||||
|
"""Executes the full backup test sequence."""
|
||||||
|
import shutil
|
||||||
|
modal.set_status("AUTO TEST: Launching Backup Sequence...")
|
||||||
|
modal.write("[bold yellow]Starting Backup Auto-Test...[/bold yellow]")
|
||||||
|
|
||||||
|
# 1. Clear old backup directory completely
|
||||||
|
server_dir = Path(self._base_dir()) / f"DISCORD_BACKUP-{self.config.discord_server_id}"
|
||||||
|
if server_dir.exists():
|
||||||
|
modal.write(f"[yellow]Deleting existing backup directory: {server_dir}[/yellow]")
|
||||||
|
try:
|
||||||
|
shutil.rmtree(server_dir)
|
||||||
|
except Exception as e:
|
||||||
|
modal.write(f"[red]Warning: Could not delete directory: {e}[/red]")
|
||||||
|
|
||||||
|
# 2. Setup & Metadata
|
||||||
|
modal.set_status("Initializing Discord connection...")
|
||||||
|
await self.engine.discord_reader.start()
|
||||||
|
await self.exporter.setup()
|
||||||
|
self.exporter.is_running = True
|
||||||
|
|
||||||
|
modal.write("\n[bold cyan]Phase 1: Full Server Backup[/bold cyan]")
|
||||||
|
|
||||||
|
# 3. Use unified backup logic in autotest mode
|
||||||
|
all_channels = await self.engine.discord_reader.get_channels()
|
||||||
|
eligible_channels = [
|
||||||
|
c for c in all_channels
|
||||||
|
if c.type in [
|
||||||
|
self.engine.discord_reader.CHANNEL_TYPE_TEXT,
|
||||||
|
self.engine.discord_reader.CHANNEL_TYPE_NEWS,
|
||||||
|
self.engine.discord_reader.CHANNEL_TYPE_FORUM
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
await self._logic_full_backup(
|
||||||
|
modal=modal,
|
||||||
|
selected_channels=eligible_channels,
|
||||||
|
force_overwrite=True,
|
||||||
|
is_autotest=True
|
||||||
|
)
|
||||||
# ── (1) clone server template (combined) ─────────────────────────────
|
# ── (1) clone server template (combined) ─────────────────────────────
|
||||||
|
|
||||||
def _open_clone_menu(self):
|
def _open_clone_menu(self):
|
||||||
|
|
@ -1012,13 +1137,64 @@ class OperationPane(Container):
|
||||||
# ── (5) message migration ─────────────────────────────────────────────
|
# ── (5) message migration ─────────────────────────────────────────────
|
||||||
|
|
||||||
@work(exclusive=True)
|
@work(exclusive=True)
|
||||||
async def run_migrate_messages(self) -> None:
|
async def run_migrate_messages(self, modal: ProgressScreen | None = None) -> None:
|
||||||
|
await self._logic_migrate_messages(modal)
|
||||||
|
|
||||||
|
async def _logic_autotest_migrate_all_channels(self, modal: ProgressScreen) -> None:
|
||||||
|
"""Automated name-based channel migration for Auto-Test."""
|
||||||
|
migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate
|
||||||
|
|
||||||
|
# 1. Matching
|
||||||
|
modal.set_status("Auto-matching channels by name...")
|
||||||
|
await self._perform_auto_matching()
|
||||||
|
|
||||||
|
# 2. Get channels
|
||||||
|
d_channels = await self.engine.discord_reader.get_channels()
|
||||||
|
text_channels = [c for c in d_channels if c.type in [
|
||||||
|
self.engine.discord_reader.CHANNEL_TYPE_TEXT,
|
||||||
|
self.engine.discord_reader.CHANNEL_TYPE_NEWS
|
||||||
|
]]
|
||||||
|
|
||||||
|
modal.write(f"[bold cyan]Auto-Test: Found {len(text_channels)} channels to migrate.[/bold cyan]")
|
||||||
|
|
||||||
|
# 3. Migrate loop
|
||||||
|
for i, ch in enumerate(text_channels):
|
||||||
|
if not self.engine.is_running: break
|
||||||
|
|
||||||
|
tgt_id = self.engine.state.get_target_channel_id(str(ch.id))
|
||||||
|
if not tgt_id:
|
||||||
|
modal.write(f"[yellow]Skipping #{ch.name} (no target mapping found)[/yellow]")
|
||||||
|
continue
|
||||||
|
|
||||||
|
modal.write(f"\n[bold]Migrating #{ch.name} ({i+1}/{len(text_channels)}) -> Target ID {tgt_id}[/bold]")
|
||||||
|
|
||||||
|
# Analyze (silent)
|
||||||
|
stats = await migrate_mod.analyze_migration(self.engine, source_channel_id=ch.id, after_message_id=None)
|
||||||
|
ch_total = stats["messages"]
|
||||||
|
|
||||||
|
async def update_indiv(curr):
|
||||||
|
c = curr["messages"]
|
||||||
|
modal.set_progress(c, ch_total or 100)
|
||||||
|
modal.set_item_status(f"#{ch.name}: {c}/{ch_total} messages")
|
||||||
|
|
||||||
|
await migrate_mod.migrate_messages(
|
||||||
|
self.engine,
|
||||||
|
source_channel_id=ch.id,
|
||||||
|
target_channel_id=tgt_id,
|
||||||
|
after_message_id=None,
|
||||||
|
progress_callback=update_indiv
|
||||||
|
)
|
||||||
|
|
||||||
|
modal.write("\n[bold green]Automated channel migration complete.[/bold green]")
|
||||||
|
|
||||||
|
async def _logic_migrate_messages(self, modal: ProgressScreen | None = None, is_autotest: bool = False) -> None:
|
||||||
if not self.tokens_valid:
|
if not self.tokens_valid:
|
||||||
return
|
return
|
||||||
|
|
||||||
migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate
|
migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate
|
||||||
platform_name = self.target_platform.capitalize()
|
platform_name = self.target_platform.capitalize()
|
||||||
|
|
||||||
|
if not modal:
|
||||||
modal = ProgressScreen(log_level=self.config.log_level)
|
modal = ProgressScreen(log_level=self.config.log_level)
|
||||||
self.app.push_screen(modal)
|
self.app.push_screen(modal)
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
@ -1418,13 +1594,17 @@ class OperationPane(Container):
|
||||||
await self.engine.close_connections()
|
await self.engine.close_connections()
|
||||||
|
|
||||||
@work(exclusive=True)
|
@work(exclusive=True)
|
||||||
async def run_waterfall_migration(self) -> None:
|
async def run_waterfall_migration(self, modal: ProgressScreen | None = None) -> None:
|
||||||
|
await self._logic_waterfall_migration(modal)
|
||||||
|
|
||||||
|
async def _logic_waterfall_migration(self, modal: ProgressScreen | None = None, is_autotest: bool = False) -> None:
|
||||||
if not self.tokens_valid:
|
if not self.tokens_valid:
|
||||||
return
|
return
|
||||||
|
|
||||||
migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate
|
migrate_mod = fluxer_migrate if self.target_platform == "fluxer" else stoat_migrate
|
||||||
platform_name = self.target_platform.capitalize()
|
platform_name = self.target_platform.capitalize()
|
||||||
|
|
||||||
|
if not modal:
|
||||||
modal = ProgressScreen(log_level=self.config.log_level)
|
modal = ProgressScreen(log_level=self.config.log_level)
|
||||||
self.app.push_screen(modal)
|
self.app.push_screen(modal)
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
@ -1434,6 +1614,9 @@ class OperationPane(Container):
|
||||||
modal.set_status("Connecting to Servers...")
|
modal.set_status("Connecting to Servers...")
|
||||||
await self.engine.start_connections()
|
await self.engine.start_connections()
|
||||||
|
|
||||||
|
# Ensure writer is validated before auto-matching (fixes NoneType role fetch error)
|
||||||
|
await self.engine.writer.validate()
|
||||||
|
|
||||||
modal.set_status("Synchronizing entity mappings...")
|
modal.set_status("Synchronizing entity mappings...")
|
||||||
await self._perform_auto_matching()
|
await self._perform_auto_matching()
|
||||||
|
|
||||||
|
|
@ -1460,6 +1643,10 @@ class OperationPane(Container):
|
||||||
prefix = "[bold cyan]📁[/bold cyan] " if mc.type == 4 else "[bold white]#[/bold white] "
|
prefix = "[bold cyan]📁[/bold cyan] " if mc.type == 4 else "[bold white]#[/bold white] "
|
||||||
modal.write(f" {prefix}{mc.name}")
|
modal.write(f" {prefix}{mc.name}")
|
||||||
|
|
||||||
|
if is_autotest:
|
||||||
|
choice = "btn_start_first"
|
||||||
|
modal.write("[bold cyan]Auto-Test: Automatically cloning missing channels.[/bold cyan]")
|
||||||
|
else:
|
||||||
choice = await modal.phase_wait_confirm(
|
choice = await modal.phase_wait_confirm(
|
||||||
show_continue=False,
|
show_continue=False,
|
||||||
show_id=True,
|
show_id=True,
|
||||||
|
|
@ -1551,9 +1738,7 @@ class OperationPane(Container):
|
||||||
if filtered_tgt_ids:
|
if filtered_tgt_ids:
|
||||||
all_mapped_tgt_ids = filtered_tgt_ids
|
all_mapped_tgt_ids = filtered_tgt_ids
|
||||||
|
|
||||||
# 2.6 Resume Point: Prioritize Global waterfall tracker, fallback to channel minimums
|
# 2.6 Resume Point: Calculate from global channel minimums
|
||||||
min_last_id = self.engine.state.get_waterfall_last_id()
|
|
||||||
if min_last_id is None:
|
|
||||||
min_last_id = self.engine.state.get_global_min_last_message_id(all_mapped_tgt_ids)
|
min_last_id = self.engine.state.get_global_min_last_message_id(all_mapped_tgt_ids)
|
||||||
|
|
||||||
modal.write(f"\n[bold cyan]Waterfall Migration Resume Point:[/bold cyan]")
|
modal.write(f"\n[bold cyan]Waterfall Migration Resume Point:[/bold cyan]")
|
||||||
|
|
@ -1562,11 +1747,15 @@ class OperationPane(Container):
|
||||||
else:
|
else:
|
||||||
modal.write("No previous migration state found. Starting from the beginning.")
|
modal.write("No previous migration state found. Starting from the beginning.")
|
||||||
|
|
||||||
|
if is_autotest:
|
||||||
|
choice = "btn_continue" if min_last_id is not None else "btn_start_first"
|
||||||
|
modal.write(f"[bold cyan]Auto-Test: Automatically choosing {choice.replace('btn_', '').replace('_', ' ')}.[/bold cyan]")
|
||||||
|
else:
|
||||||
choice = await modal.phase_wait_confirm(
|
choice = await modal.phase_wait_confirm(
|
||||||
show_continue=min_last_id is not None,
|
show_continue=min_last_id is not None,
|
||||||
show_id=False,
|
show_id=False,
|
||||||
btn_start_label="Start From Beginning",
|
btn_start_label="Start From Beginning",
|
||||||
btn_start_tooltip="Safe, skips duplicates automatically",
|
btn_start_tooltip="Wipes migration progress and restarts from the beginning; may create duplicates",
|
||||||
btn_start_variant="default" if min_last_id is not None else "primary",
|
btn_start_variant="default" if min_last_id is not None else "primary",
|
||||||
btn_continue_label=f"Continue from ID {min_last_id if min_last_id is not None else 0}" if min_last_id is not None else "Continue Migration",
|
btn_continue_label=f"Continue from ID {min_last_id if min_last_id is not None else 0}" if min_last_id is not None else "Continue Migration",
|
||||||
btn_continue_tooltip="Fastest"
|
btn_continue_tooltip="Fastest"
|
||||||
|
|
@ -1582,7 +1771,11 @@ class OperationPane(Container):
|
||||||
return
|
return
|
||||||
|
|
||||||
after_id = None
|
after_id = None
|
||||||
if choice == "btn_continue" and min_last_id is not None:
|
if choice == "btn_start_first":
|
||||||
|
logger.info("Proceeding with 'Start from Beginning' (global clean sink).")
|
||||||
|
self.engine.state.clear_all_migration_data()
|
||||||
|
after_id = None
|
||||||
|
elif choice == "btn_continue" and min_last_id is not None:
|
||||||
after_id = int(min_last_id)
|
after_id = int(min_last_id)
|
||||||
|
|
||||||
# Phase 3: Progress
|
# Phase 3: Progress
|
||||||
|
|
@ -1599,6 +1792,7 @@ class OperationPane(Container):
|
||||||
tid = self.config.fluxer_server_id
|
tid = self.config.fluxer_server_id
|
||||||
self.engine.ensure_state_initialized(str(tid or ""), platform_name)
|
self.engine.ensure_state_initialized(str(tid or ""), platform_name)
|
||||||
|
|
||||||
|
modal.show_stats()
|
||||||
modal.write("Scanning global footprint for totals ...")
|
modal.write("Scanning global footprint for totals ...")
|
||||||
stats_analysis = await migrate_mod.analyze_global_migration(self.engine, after_message_id=after_id)
|
stats_analysis = await migrate_mod.analyze_global_migration(self.engine, after_message_id=after_id)
|
||||||
total_messages = stats_analysis["messages"]
|
total_messages = stats_analysis["messages"]
|
||||||
|
|
@ -2087,6 +2281,7 @@ class OperationPane(Container):
|
||||||
|
|
||||||
@work(exclusive=True)
|
@work(exclusive=True)
|
||||||
async def run_backup_messages(self) -> None:
|
async def run_backup_messages(self) -> None:
|
||||||
|
"""UI entry point for full backup."""
|
||||||
modal_prog = ProgressScreen(log_level=self.config.log_level)
|
modal_prog = ProgressScreen(log_level=self.config.log_level)
|
||||||
self.app.push_screen(modal_prog)
|
self.app.push_screen(modal_prog)
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
@ -2096,38 +2291,6 @@ class OperationPane(Container):
|
||||||
await self.engine.discord_reader.start()
|
await self.engine.discord_reader.start()
|
||||||
await self.exporter.setup()
|
await self.exporter.setup()
|
||||||
|
|
||||||
# Check if profile is empty
|
|
||||||
profile_exists = False
|
|
||||||
if self.exporter.db:
|
|
||||||
try:
|
|
||||||
profile_exists = self.exporter.db.get_guild_profile() is not None
|
|
||||||
except Exception:
|
|
||||||
profile_exists = False
|
|
||||||
|
|
||||||
if not profile_exists:
|
|
||||||
modal_prog.set_status("First-Time Setup: Exporting Server Profile...")
|
|
||||||
modal_prog.write("[yellow]No existing profile found. Performing primary profile backup...[/yellow]")
|
|
||||||
|
|
||||||
modal_prog.write("[yellow]Exporting server metadata...[/yellow]")
|
|
||||||
await self.exporter.export_metadata()
|
|
||||||
|
|
||||||
modal_prog.write("[yellow]Syncing server assets (icon/banner)...[/yellow]")
|
|
||||||
await self.exporter.download_server_assets()
|
|
||||||
|
|
||||||
modal_prog.write("[yellow]Exporting server structure...[/yellow]")
|
|
||||||
await self.exporter.export_channels_structure()
|
|
||||||
|
|
||||||
modal_prog.write("[yellow]Exporting roles & permissions...[/yellow]")
|
|
||||||
await self.exporter.export_roles()
|
|
||||||
|
|
||||||
modal_prog.write("[yellow]Exporting custom emojis & stickers...[/yellow]")
|
|
||||||
await self.exporter.export_assets()
|
|
||||||
|
|
||||||
modal_prog.write("[bold green]Primary profile setup complete![/bold green]")
|
|
||||||
modal_prog.write("")
|
|
||||||
else:
|
|
||||||
modal_prog.write("[dim]Existing profile detected. Scanning structure...[/dim]")
|
|
||||||
await self.exporter.export_channels_structure()
|
|
||||||
all_channels = await self.engine.discord_reader.get_channels()
|
all_channels = await self.engine.discord_reader.get_channels()
|
||||||
all_categories = await self.engine.discord_reader.get_categories()
|
all_categories = await self.engine.discord_reader.get_categories()
|
||||||
cat_map = {c.id: c.name for c in all_categories}
|
cat_map = {c.id: c.name for c in all_categories}
|
||||||
|
|
@ -2146,8 +2309,9 @@ class OperationPane(Container):
|
||||||
modal_prog.allow_close()
|
modal_prog.allow_close()
|
||||||
return
|
return
|
||||||
|
|
||||||
any_found = False
|
# Analyze which are already backed up
|
||||||
backed_up_ids = set()
|
backed_up_ids = set()
|
||||||
|
any_found = False
|
||||||
if self.exporter.db:
|
if self.exporter.db:
|
||||||
channel_stats = self.exporter.db.get_stats_by_channel()
|
channel_stats = self.exporter.db.get_stats_by_channel()
|
||||||
for chan in eligible_channels:
|
for chan in eligible_channels:
|
||||||
|
|
@ -2155,15 +2319,11 @@ class OperationPane(Container):
|
||||||
any_found = True
|
any_found = True
|
||||||
backed_up_ids.add(chan.id)
|
backed_up_ids.add(chan.id)
|
||||||
|
|
||||||
self.app.pop_screen()
|
# Manual selection
|
||||||
|
|
||||||
while True:
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
future = loop.create_future()
|
future = loop.create_future()
|
||||||
|
|
||||||
def check_channels(reply: dict | None) -> None:
|
def check_channels(reply: dict | None) -> None:
|
||||||
if not future.done():
|
if not future.done(): future.set_result(reply)
|
||||||
future.set_result(reply)
|
|
||||||
|
|
||||||
self.app.push_screen(
|
self.app.push_screen(
|
||||||
ChannelSelectScreen(eligible_channels, cat_map, backed_up_ids, any_found),
|
ChannelSelectScreen(eligible_channels, cat_map, backed_up_ids, any_found),
|
||||||
|
|
@ -2171,71 +2331,93 @@ class OperationPane(Container):
|
||||||
)
|
)
|
||||||
|
|
||||||
reply = await future
|
reply = await future
|
||||||
if not reply:
|
if not reply: return
|
||||||
return
|
|
||||||
|
|
||||||
selected_ids = reply["channels"]
|
selected_ids = reply["channels"]
|
||||||
force_overwrite = reply["force"]
|
force_overwrite = reply["force"]
|
||||||
selected_channels = [c for c in eligible_channels if c.id in selected_ids]
|
selected_channels = [c for c in eligible_channels if c.id in selected_ids]
|
||||||
|
|
||||||
# Phase 2: Confirmation
|
# Confirmation phase
|
||||||
modal_prog = ProgressScreen(log_level=self.config.log_level) # Re-instantiate to avoid Textual re-push UI freeze
|
|
||||||
self.app.push_screen(modal_prog)
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
new_channels = [c for c in selected_channels if c.id not in backed_up_ids]
|
new_channels = [c for c in selected_channels if c.id not in backed_up_ids]
|
||||||
existing_channels = [c for c in selected_channels if c.id in backed_up_ids]
|
existing_channels = [c for c in selected_channels if c.id in backed_up_ids]
|
||||||
|
|
||||||
server = getattr(self.engine.discord_reader, 'guild', None)
|
modal_confirm = ProgressScreen(log_level=self.config.log_level)
|
||||||
if server:
|
self.app.push_screen(modal_confirm)
|
||||||
modal_prog.write(f"[bold cyan]Server Profile:[/bold cyan]")
|
await asyncio.sleep(0.1)
|
||||||
modal_prog.write(f" Name: [green]{server.name}[/green]")
|
|
||||||
modal_prog.write(f" Icon: [green]{'Present' if server.icon else 'None'}[/green]")
|
|
||||||
modal_prog.write("")
|
|
||||||
|
|
||||||
modal_prog.set_status(f"Confirm to proceed with Backup of [bold]{len(selected_channels)}[/bold] channels")
|
modal_confirm.set_status(f"Confirm to proceed with Backup of [bold]{len(selected_channels)}[/bold] channels")
|
||||||
modal_prog.show_info(f"[cyan]Backup Channels[/cyan]", f"{len(new_channels)} new, {len(existing_channels)} existing")
|
modal_confirm.show_info(f"[cyan]Backup Channels[/cyan]", f"{len(new_channels)} new, {len(existing_channels)} existing")
|
||||||
|
|
||||||
# Show categorized channel lists in the bottom log
|
choice = await modal_confirm.phase_wait_confirm(btn_start_label="Start Channel Backup", show_id=False)
|
||||||
if new_channels:
|
if choice != "btn_start_first":
|
||||||
modal_prog.write("[bold green]New Backups to be created:[/bold green]")
|
modal_confirm.dismiss()
|
||||||
for idx, c in enumerate(new_channels):
|
|
||||||
modal_prog.write(f" {idx+1}. #{c.name}")
|
|
||||||
|
|
||||||
if existing_channels:
|
|
||||||
action = "Overwritten" if force_overwrite else "Updated"
|
|
||||||
modal_prog.write(f"[bold yellow]\nExisting backups to be {action}:[/bold yellow]")
|
|
||||||
for idx, c in enumerate(existing_channels):
|
|
||||||
modal_prog.write(f" {idx+1}. #{c.name}")
|
|
||||||
|
|
||||||
choice = await modal_prog.phase_wait_confirm(btn_start_label="Start Channel Backup", show_id=False)
|
|
||||||
if choice == "btn_back":
|
|
||||||
modal_prog.dismiss()
|
|
||||||
continue
|
|
||||||
elif choice == "btn_start_id":
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
future = loop.create_future()
|
|
||||||
def id_callback(res: int | None) -> None:
|
|
||||||
if not future.done():
|
|
||||||
future.set_result(res)
|
|
||||||
|
|
||||||
id_modal = MessageIDInputModal(self.engine.discord_reader, selected_channels[0].id)
|
|
||||||
self.app.push_screen(id_modal, id_callback)
|
|
||||||
verified_id = await future
|
|
||||||
|
|
||||||
if verified_id is None:
|
|
||||||
# User cancelled the ID input
|
|
||||||
continue
|
|
||||||
|
|
||||||
after_id = verified_id
|
|
||||||
elif choice == "btn_main_menu":
|
|
||||||
modal_prog.dismiss()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# If we are here, proceeding either via Start First or Start from ID (after_id)
|
modal_confirm.cancel_callback = lambda: setattr(self.exporter, "is_running", False)
|
||||||
if choice == "btn_start_first":
|
modal_confirm.phase_progress()
|
||||||
after_id = None
|
|
||||||
break
|
await self._logic_full_backup(
|
||||||
|
modal=modal_confirm,
|
||||||
|
selected_channels=selected_channels,
|
||||||
|
force_overwrite=force_overwrite,
|
||||||
|
is_autotest=False
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Backup Error: {traceback.format_exc()}")
|
||||||
|
modal_prog.write(f"[bold red]Backup failed: {e}[/bold red]")
|
||||||
|
modal_prog.phase_report("Backup", "error", show_back=False)
|
||||||
|
finally:
|
||||||
|
await self.engine.close_connections()
|
||||||
|
|
||||||
|
async def _logic_full_backup(self, modal: ProgressScreen, selected_channels: list, force_overwrite: bool, is_autotest: bool = False) -> None:
|
||||||
|
"""Non-interactive core backup logic."""
|
||||||
|
if not self.exporter.is_running:
|
||||||
|
self.exporter.is_running = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. Metadata and Assets
|
||||||
|
modal.set_status("Exporting Server Profile...")
|
||||||
|
await self.exporter.export_metadata()
|
||||||
|
await self.exporter.download_server_assets()
|
||||||
|
await self.exporter.export_channels_structure()
|
||||||
|
await self.exporter.export_roles()
|
||||||
|
await self.exporter.export_assets()
|
||||||
|
|
||||||
|
# Pre-fetch all members once for role resolution during message export
|
||||||
|
modal.set_status("Pre-fetching server members...")
|
||||||
|
await self.exporter.prefetch_members()
|
||||||
|
|
||||||
|
# 2. Channel Messages
|
||||||
|
total_chans = len(selected_channels)
|
||||||
|
modal.write(f"\n[bold cyan]Backing up {total_chans} channels...[/bold cyan]")
|
||||||
|
modal.show_stats()
|
||||||
|
|
||||||
|
for i, chan in enumerate(selected_channels):
|
||||||
|
if not self.exporter.is_running: break
|
||||||
|
|
||||||
|
modal.set_item_status(f"[cyan]Processing ({i+1}/{total_chans}): #{chan.name}[/cyan]")
|
||||||
|
modal.set_progress(i, total_chans)
|
||||||
|
modal.write(f"[cyan]Backing up: #{chan.name}[/cyan]")
|
||||||
|
|
||||||
|
async def update_backup(name, count, author_name=None, message_preview=None, thread_count=0, file_count=0):
|
||||||
|
modal.update_stats(messages=str(count), threads=str(thread_count), files=str(file_count))
|
||||||
|
if author_name and message_preview and count % 20 == 0:
|
||||||
|
modal.write(f"[dim]{author_name}:[/dim] {message_preview}")
|
||||||
|
|
||||||
|
await self.exporter.export_channel_messages(
|
||||||
|
chan.id, progress_callback=update_backup, force=force_overwrite
|
||||||
|
)
|
||||||
|
modal.write(f"[green]Completed: #{chan.name}[/green]")
|
||||||
|
|
||||||
|
modal.set_progress(total_chans, total_chans)
|
||||||
|
modal.write("[bold green]Backup complete![/bold green]")
|
||||||
|
modal.phase_report("Full Backup", show_back=False)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
modal.write(f"[bold red]Core backup failed: {e}[/bold red]")
|
||||||
|
logger.error(f"Core Backup Error: {traceback.format_exc()}")
|
||||||
|
raise e
|
||||||
|
|
||||||
modal_prog.phase_progress()
|
modal_prog.phase_progress()
|
||||||
modal_prog.show_stats()
|
modal_prog.show_stats()
|
||||||
|
|
|
||||||
170
tests/conftest.py
Normal file
170
tests/conftest.py
Normal file
|
|
@ -0,0 +1,170 @@
|
||||||
|
import pytest
|
||||||
|
import shutil
|
||||||
|
import yaml
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from src.core.backup_database import BackupDatabase
|
||||||
|
from src.core.configuration import AppConfig
|
||||||
|
|
||||||
|
LOG_FILE = (Path(__file__).parent.parent / ".reaper_tests.log").resolve()
|
||||||
|
|
||||||
|
def _log(msg: str):
|
||||||
|
"""Write a message to the shared test log file."""
|
||||||
|
with open(LOG_FILE, "a") as f:
|
||||||
|
f.write(f"{msg}\n")
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def log():
|
||||||
|
"""Fixture to provide the logging function to tests."""
|
||||||
|
return _log
|
||||||
|
|
||||||
|
def pytest_configure(config):
|
||||||
|
"""Register global warning filters and marks."""
|
||||||
|
# Register marks if needed
|
||||||
|
config.addinivalue_line("markers", "asyncio: mark test as asyncio")
|
||||||
|
|
||||||
|
# Silence benign async mock and ResourceWarnings globally
|
||||||
|
config.addinivalue_line("filterwarnings", "ignore::RuntimeWarning")
|
||||||
|
config.addinivalue_line("filterwarnings", "ignore::ResourceWarning")
|
||||||
|
|
||||||
|
def pytest_sessionstart(session):
|
||||||
|
"""Clear the log file at the beginning of the test session."""
|
||||||
|
with open(LOG_FILE, "w") as f:
|
||||||
|
f.write("--- Reaper Test Session Started ---\n")
|
||||||
|
|
||||||
|
def pytest_report_header(config):
|
||||||
|
"""Print data source status to the console header."""
|
||||||
|
test_data_dir = (Path(__file__).parent.parent / "ReaperFiles-AutoTest").resolve()
|
||||||
|
if test_data_dir.exists():
|
||||||
|
return f"[DATA_SOURCE] Automated Test Directory: FOUND ({test_data_dir.name})"
|
||||||
|
else:
|
||||||
|
return "[DATA_SOURCE] Automated Test Directory: NOT FOUND (Falling back to MOCKS)"
|
||||||
|
|
||||||
|
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
|
||||||
|
def pytest_runtest_makereport(item, call):
|
||||||
|
"""Hook to catch test results and log them to .reaper_tests.log."""
|
||||||
|
outcome = yield
|
||||||
|
rep = outcome.get_result()
|
||||||
|
|
||||||
|
# We only log on the 'call' phase (the actual test run)
|
||||||
|
if rep.when == "call":
|
||||||
|
status = rep.outcome.upper()
|
||||||
|
_log(f"RESULT: {item.nodeid} -> {status}")
|
||||||
|
if rep.failed:
|
||||||
|
_log(f"--- FAILURE DETAILS for {item.name} ---\n{rep.longreprtext}\n---------------------")
|
||||||
|
elif not rep.passed:
|
||||||
|
# Catch setup/teardown failures
|
||||||
|
_log(f"ERROR in {item.nodeid} [{rep.when}]: {rep.outcome.upper()}")
|
||||||
|
if rep.failed:
|
||||||
|
_log(f"--- ERROR DETAILS ---\n{rep.longreprtext}\n---------------------")
|
||||||
|
|
||||||
|
def pytest_warning_recorded(warning_message, when, nodeid, location):
|
||||||
|
"""Hook to catch and log warnings to .reaper_tests.log."""
|
||||||
|
msg = f"WARNING: {warning_message.message}"
|
||||||
|
if nodeid:
|
||||||
|
msg = f"WARNING in {nodeid} [{when}]: {warning_message.message}"
|
||||||
|
_log(msg)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_data_dir():
|
||||||
|
return (Path(__file__).parent.parent / "ReaperFiles-AutoTest").resolve()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def reaper_config(test_data_dir):
|
||||||
|
config_path = test_data_dir / "reaper_config.yaml"
|
||||||
|
if config_path.exists():
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
return data
|
||||||
|
# Fallback mock config
|
||||||
|
return {
|
||||||
|
"discord_bot_token": "mock_discord_token",
|
||||||
|
"discord_server_id": "123456789012345678",
|
||||||
|
"tool_mode": "backup_transfer",
|
||||||
|
"target_platform": "fluxer",
|
||||||
|
"fluxer_bot_token": "mock_fluxer_token",
|
||||||
|
"fluxer_server_id": "987654321098765432",
|
||||||
|
"stoat_bot_token": "mock_stoat_token",
|
||||||
|
"stoat_server_id": "MOCK_STOAT_COMMUNITY",
|
||||||
|
"anonymize_users": False,
|
||||||
|
"log_level": "DEBUG"
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_db(test_data_dir, tmp_path, reaper_config):
|
||||||
|
# If the real test data exists, use it
|
||||||
|
if test_data_dir.exists():
|
||||||
|
target_id = reaper_config.get("fluxer_server_id") if reaper_config.get("target_platform") == "fluxer" else reaper_config.get("stoat_server_id")
|
||||||
|
db_files = list(test_data_dir.glob(f"*-{target_id}.db"))
|
||||||
|
if not db_files:
|
||||||
|
db_files = list(test_data_dir.glob("*.db"))
|
||||||
|
|
||||||
|
if db_files:
|
||||||
|
original_db = db_files[0]
|
||||||
|
temp_db_path = tmp_path / "test_migration.db"
|
||||||
|
print(f"[DATA_SOURCE] USE_SAMPLE_DB: {original_db.name}")
|
||||||
|
_log(f"[DATA_SOURCE] USE_SAMPLE_DB: {original_db.name}")
|
||||||
|
shutil.copy(original_db, temp_db_path)
|
||||||
|
return temp_db_path
|
||||||
|
|
||||||
|
# Fallback: Create an empty mock database with the required schema
|
||||||
|
temp_db_path = tmp_path / "mock_migration.db"
|
||||||
|
print("[DATA_SOURCE] USE_MOCK_DB: Fallback to empty schema")
|
||||||
|
_log("[DATA_SOURCE] USE_MOCK_DB: Fallback to empty schema")
|
||||||
|
db = BackupDatabase(temp_db_path)
|
||||||
|
# The BackupDatabase constructor already initializes the schema
|
||||||
|
return temp_db_path
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def backup_db(temp_db):
|
||||||
|
db = BackupDatabase(temp_db)
|
||||||
|
yield db
|
||||||
|
# No explicit close needed for now as it's a persistent connection
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def backup_reader(test_data_dir, reaper_config, tmp_path):
|
||||||
|
sid = reaper_config.get("discord_server_id")
|
||||||
|
backup_path = test_data_dir / f"DISCORD_BACKUP-{sid}"
|
||||||
|
|
||||||
|
if not test_data_dir.exists() or not backup_path.exists():
|
||||||
|
# Fallback: create mock backup structure
|
||||||
|
mock_path = tmp_path / f"DISCORD_BACKUP-{sid}"
|
||||||
|
print(f"[DATA_SOURCE] USE_MOCK_BACKUP: {mock_path.name}")
|
||||||
|
_log(f"[DATA_SOURCE] USE_MOCK_BACKUP: {mock_path.name}")
|
||||||
|
mock_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
db_path = mock_path / "backup.db"
|
||||||
|
db = BackupDatabase(db_path)
|
||||||
|
# Populate with minimal mock data for BackupReader to work
|
||||||
|
db._conn.execute("INSERT OR IGNORE INTO guild_profile (id, name) VALUES (?, ?)", (int(sid), "Mock Guild"))
|
||||||
|
db._conn.execute("INSERT OR IGNORE INTO channels (id, name, type) VALUES (?, ?, ?)", (123, "mock-channel", 0))
|
||||||
|
db._conn.commit()
|
||||||
|
from src.core.backup_reader import BackupReader
|
||||||
|
return BackupReader(mock_path)
|
||||||
|
|
||||||
|
print(f"[DATA_SOURCE] USE_SAMPLE_BACKUP: {backup_path.name}")
|
||||||
|
_log(f"[DATA_SOURCE] USE_SAMPLE_BACKUP: {backup_path.name}")
|
||||||
|
from src.core.backup_reader import BackupReader
|
||||||
|
return BackupReader(backup_path)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_discord_reader():
|
||||||
|
reader = MagicMock()
|
||||||
|
reader.guild = MagicMock()
|
||||||
|
reader.fetch_message_history = AsyncMock()
|
||||||
|
reader.download_attachment = AsyncMock(return_value=b"fake_data")
|
||||||
|
reader.download_sticker = AsyncMock(return_value=b"fake_sticker_data")
|
||||||
|
return reader
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_fluxer_writer():
|
||||||
|
writer = MagicMock()
|
||||||
|
writer.send_message = AsyncMock(return_value="fluxer_msg_123")
|
||||||
|
writer.send_marker = AsyncMock()
|
||||||
|
return writer
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_stoat_writer():
|
||||||
|
writer = MagicMock()
|
||||||
|
writer.send_message = AsyncMock(return_value="stoat_msg_123")
|
||||||
|
writer.send_marker = AsyncMock()
|
||||||
|
return writer
|
||||||
139
tests/test_database.py
Normal file
139
tests/test_database.py
Normal file
|
|
@ -0,0 +1,139 @@
|
||||||
|
import pytest
|
||||||
|
from src.core.backup_database import BackupDatabase
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db():
|
||||||
|
from tests.conftest import _log
|
||||||
|
msg = "[DATA_SOURCE] UNIT_TEST: Using in-memory database"
|
||||||
|
print(msg)
|
||||||
|
_log(msg)
|
||||||
|
return BackupDatabase(":memory:")
|
||||||
|
|
||||||
|
|
||||||
|
def test_db_initialization(db):
|
||||||
|
res = db._conn.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name='guild_profile'"
|
||||||
|
).fetchone()
|
||||||
|
assert res is not None
|
||||||
|
assert res["name"] == "guild_profile"
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_get_guild_profile(db):
|
||||||
|
profile = {
|
||||||
|
"id": "123456789",
|
||||||
|
"name": "Test Server",
|
||||||
|
"description": "Testing",
|
||||||
|
"owner_id": "987654321",
|
||||||
|
"ignore_channels": ["111", "222"],
|
||||||
|
}
|
||||||
|
db.set_guild_profile(profile)
|
||||||
|
got = db.get_guild_profile()
|
||||||
|
assert got["name"] == "Test Server"
|
||||||
|
assert got["id"] == 123456789
|
||||||
|
assert got["ignore_channels"] == ["111", "222"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_get_roles(db):
|
||||||
|
roles = [
|
||||||
|
{"id": "1", "name": "Admin", "color": 0xFF0000, "position": 1, "permissions": 8, "hoist": True, "mentionable": True},
|
||||||
|
{"id": "2", "name": "User", "color": 0x0000FF, "position": 2, "permissions": 0, "hoist": False, "mentionable": False},
|
||||||
|
]
|
||||||
|
db.save_roles(roles)
|
||||||
|
got = db.get_all_roles()
|
||||||
|
assert len(got) == 2
|
||||||
|
assert {r["name"] for r in got} == {"Admin", "User"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_get_channels(db):
|
||||||
|
channels = [
|
||||||
|
{"id": 10, "name": "general", "type": 0, "position": 0, "category_id": None,
|
||||||
|
"topic": "Talk", "nsfw": 0, "bitrate": None, "slowmode_delay": None},
|
||||||
|
{"id": 20, "name": "voice", "type": 2, "position": 1, "category_id": None,
|
||||||
|
"topic": None, "nsfw": 0, "bitrate": 64000, "slowmode_delay": None},
|
||||||
|
]
|
||||||
|
db.save_channels(channels)
|
||||||
|
got = db.get_all_channels()
|
||||||
|
assert len(got) == 2
|
||||||
|
assert {c["name"] for c in got} == {"general", "voice"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_messages_and_attachments(db):
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"id": 101, "channel_id": 10, "author_id": 999, "content": "Hello",
|
||||||
|
"timestamp": "2023-01-01T00:00:00Z", "type": 0,
|
||||||
|
"message_reference": None, "is_pinned": 0, "extra_data": None,
|
||||||
|
"attachments": [
|
||||||
|
{"id": 1, "filename": "file.png", "size": 100,
|
||||||
|
"url": "http://cdn.test/file.png", "content_type": "image/png", "local_hash": "abc"}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
db.save_messages_batch(messages)
|
||||||
|
msg = db._conn.execute("SELECT content FROM messages WHERE id=101").fetchone()
|
||||||
|
assert msg["content"] == "Hello"
|
||||||
|
att = db._conn.execute("SELECT filename FROM attachments WHERE message_id=101").fetchone()
|
||||||
|
assert att["filename"] == "file.png"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_last_message_id(db):
|
||||||
|
msgs = [
|
||||||
|
{"id": 200, "channel_id": 10, "author_id": 1, "content": "a",
|
||||||
|
"timestamp": "2023-01-01T00:00:00Z", "type": 0,
|
||||||
|
"message_reference": None, "is_pinned": 0, "extra_data": None},
|
||||||
|
{"id": 201, "channel_id": 10, "author_id": 1, "content": "b",
|
||||||
|
"timestamp": "2023-01-01T00:01:00Z", "type": 0,
|
||||||
|
"message_reference": None, "is_pinned": 0, "extra_data": None},
|
||||||
|
]
|
||||||
|
db.save_messages_batch(msgs)
|
||||||
|
assert db.get_last_message_id("10") == 201
|
||||||
|
|
||||||
|
|
||||||
|
def test_stats_by_channel(db):
|
||||||
|
msgs = [
|
||||||
|
{"id": 101, "channel_id": 10, "author_id": 1, "content": "Hi",
|
||||||
|
"timestamp": "2023-01-01T00:00:00Z", "type": 0,
|
||||||
|
"message_reference": None, "is_pinned": 0, "extra_data": None},
|
||||||
|
{"id": 102, "channel_id": 10, "author_id": 1, "content": "Bye",
|
||||||
|
"timestamp": "2023-01-01T00:01:00Z", "type": 0,
|
||||||
|
"message_reference": None, "is_pinned": 0, "extra_data": None},
|
||||||
|
]
|
||||||
|
db.save_messages_batch(msgs)
|
||||||
|
stats = db.get_stats_by_channel()
|
||||||
|
assert stats[10]["message_count"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_threads(db):
|
||||||
|
threads = [
|
||||||
|
{"id": 300, "name": "thread-1", "type": 11, "parent_id": 10,
|
||||||
|
"message_count": 5, "member_count": 2, "archived": 0,
|
||||||
|
"archive_timestamp": None, "auto_archive_duration": 1440,
|
||||||
|
"locked": 0, "applied_tags": None},
|
||||||
|
]
|
||||||
|
db.save_threads(threads)
|
||||||
|
got = db.get_threads_by_parent("10")
|
||||||
|
assert len(got) == 1
|
||||||
|
assert got[0]["name"] == "thread-1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_media_pool(db):
|
||||||
|
db.add_media_to_pool("hash123", "/path/file.png", 512, "image/png", "http://cdn.test/file.png")
|
||||||
|
db._conn.commit()
|
||||||
|
entry = db.get_media_by_hash("hash123")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["local_path"] == "/path/file.png"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_messages_paged_after_id(db):
|
||||||
|
msgs = [
|
||||||
|
{"id": i, "channel_id": 10, "author_id": 1, "content": f"msg{i}",
|
||||||
|
"timestamp": f"2023-01-01T00:0{i}:00Z", "type": 0,
|
||||||
|
"message_reference": None, "is_pinned": 0, "extra_data": None}
|
||||||
|
for i in range(5)
|
||||||
|
]
|
||||||
|
db.save_messages_batch(msgs)
|
||||||
|
page = db.get_messages_paged("10", limit=3, offset=0)
|
||||||
|
assert len(page) == 3
|
||||||
|
page_after = db.get_messages_paged("10", limit=10, after_id="2")
|
||||||
|
assert all(m["id"] > 2 for m in page_after)
|
||||||
138
tests/test_migration.py
Normal file
138
tests/test_migration.py
Normal file
|
|
@ -0,0 +1,138 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
from src.core.base import MigrationContext
|
||||||
|
from src.core.configuration import AppConfig
|
||||||
|
from src.core.backup_reader import ChannelType
|
||||||
|
from src.fluxer.migrate_message import migrate_messages as fluxer_migrate, _process_and_send_message as fluxer_send
|
||||||
|
from src.stoat.migrate_message import migrate_messages as stoat_migrate, _process_and_send_message as stoat_send
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# --- Platform Detection (Same as e2e_simulation) ---
|
||||||
|
def get_platforms():
|
||||||
|
"""Determine which platforms to test based on config, or use defaults."""
|
||||||
|
config_path = (Path(__file__).parent.parent / "ReaperFiles-AutoTest/reaper_config.yaml").resolve()
|
||||||
|
if not config_path.exists():
|
||||||
|
return ["fluxer", "stoat"]
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
platforms = []
|
||||||
|
if data.get("fluxer_bot_token"): platforms.append("fluxer")
|
||||||
|
if data.get("stoat_bot_token"): platforms.append("stoat")
|
||||||
|
return platforms if platforms else ["fluxer", "stoat"]
|
||||||
|
|
||||||
|
# --- Unit Tests (Transformation Logic) ---
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_context(mock_discord_reader, mock_fluxer_writer, mock_stoat_writer):
|
||||||
|
context = MagicMock(spec=MigrationContext)
|
||||||
|
context.discord_reader = mock_discord_reader
|
||||||
|
context.fluxer_writer = mock_fluxer_writer
|
||||||
|
context.stoat_writer = mock_stoat_writer
|
||||||
|
context.state = MagicMock()
|
||||||
|
context.state.get_user_alias.return_value = "TestAlias"
|
||||||
|
context.state.emoji_map = {}
|
||||||
|
context.state.channel_map = {}
|
||||||
|
context.is_running = True
|
||||||
|
return context
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_message():
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.id = 111
|
||||||
|
msg.author.id = 222
|
||||||
|
msg.author.display_name = "Author"
|
||||||
|
msg.content = "Test content"
|
||||||
|
msg.attachments = []
|
||||||
|
msg.embeds = []
|
||||||
|
msg.stickers = []
|
||||||
|
msg.created_at.timestamp.return_value = 1600000000.0
|
||||||
|
msg.flags.forwarded = False
|
||||||
|
msg.reference = None
|
||||||
|
return msg
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_migration_transform_fluxer(mock_context, mock_message):
|
||||||
|
stats = {"messages": 0, "attachments": 0}
|
||||||
|
result = await fluxer_send(context=mock_context, msg=mock_message, target_channel_id="c1", stats=stats)
|
||||||
|
assert result == "fluxer_msg_123"
|
||||||
|
assert stats["messages"] == 1
|
||||||
|
assert mock_context.fluxer_writer.send_message.called
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_migration_transform_stoat(mock_context, mock_message):
|
||||||
|
stats = {"messages": 0, "attachments": 0}
|
||||||
|
result = await stoat_send(context=mock_context, msg=mock_message, target_channel_id="c1", stats=stats)
|
||||||
|
assert result == "stoat_msg_123"
|
||||||
|
assert stats["messages"] == 1
|
||||||
|
assert mock_context.stoat_writer.send_message.called
|
||||||
|
|
||||||
|
# --- Integration Tests (Backup Reader) ---
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_backup_reader_interaction(backup_reader, reaper_config):
|
||||||
|
await backup_reader.start()
|
||||||
|
assert backup_reader.guild is not None
|
||||||
|
# Verify ID from config (or mock default)
|
||||||
|
assert str(backup_reader.guild.id) == reaper_config.get("discord_server_id")
|
||||||
|
|
||||||
|
channels = await backup_reader.fetch_channels()
|
||||||
|
assert len(channels) > 0
|
||||||
|
|
||||||
|
# --- E2E Simulation ---
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("platform", get_platforms())
|
||||||
|
async def test_migration_e2e_loop(reaper_config, test_data_dir, tmp_path, platform, request):
|
||||||
|
config = AppConfig(
|
||||||
|
discord_bot_token=reaper_config["discord_bot_token"],
|
||||||
|
discord_server_id=reaper_config["discord_server_id"],
|
||||||
|
target_platform=platform,
|
||||||
|
fluxer_bot_token=reaper_config.get("fluxer_bot_token"),
|
||||||
|
fluxer_server_id=reaper_config.get("fluxer_server_id"),
|
||||||
|
stoat_bot_token=reaper_config.get("stoat_bot_token"),
|
||||||
|
stoat_server_id=reaper_config.get("stoat_server_id"),
|
||||||
|
anonymize_users=reaper_config["anonymize_users"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if test_data_dir.exists():
|
||||||
|
base_dir = test_data_dir
|
||||||
|
msg = f"[DATA_SOURCE] E2E_SIMULATION: Using sample data from {base_dir.name}"
|
||||||
|
else:
|
||||||
|
base_dir = tmp_path
|
||||||
|
msg = f"[DATA_SOURCE] E2E_SIMULATION: Fallback to mock data in {base_dir}"
|
||||||
|
|
||||||
|
from tests.conftest import _log
|
||||||
|
print(msg)
|
||||||
|
_log(msg)
|
||||||
|
|
||||||
|
context = MigrationContext(config, source_mode="backup", base_dir=str(base_dir))
|
||||||
|
|
||||||
|
mock_writer = request.getfixturevalue(f"mock_{platform}_writer")
|
||||||
|
if platform == "fluxer":
|
||||||
|
context.fluxer_writer = mock_writer
|
||||||
|
migrate_func = fluxer_migrate
|
||||||
|
else:
|
||||||
|
context.stoat_writer = mock_writer
|
||||||
|
migrate_func = stoat_migrate
|
||||||
|
|
||||||
|
context.is_running = True
|
||||||
|
from src.core.database import MigrationDatabase
|
||||||
|
context.state.db = MigrationDatabase(tmp_path / f"e2e_{platform}.db", platform=platform)
|
||||||
|
|
||||||
|
await context.discord_reader.start()
|
||||||
|
channels = await context.discord_reader.fetch_channels()
|
||||||
|
text_channels = [c for c in channels if c.type == ChannelType.text]
|
||||||
|
|
||||||
|
if text_channels:
|
||||||
|
source_channel_id = text_channels[0].id
|
||||||
|
target_channel_id = "999" if platform == "fluxer" else "stoat123"
|
||||||
|
|
||||||
|
mock_writer.send_message.side_effect = lambda **kwargs: "ok"
|
||||||
|
|
||||||
|
# Test just the first available channel to keep it fast
|
||||||
|
stats = await migrate_func(context=context, source_channel_id=source_channel_id, target_channel_id=target_channel_id)
|
||||||
|
assert stats["messages"] >= 0
|
||||||
|
|
||||||
|
await context.discord_reader.close()
|
||||||
164
tests/test_ui.py
Normal file
164
tests/test_ui.py
Normal file
|
|
@ -0,0 +1,164 @@
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
from textual.app import App
|
||||||
|
from src.ui.main_app import ReaperApp, ConfigSelectionScreen, ConfigScreen
|
||||||
|
from src.ui.mode_screen import ModeScreen
|
||||||
|
from src.core.configuration import AppConfig
|
||||||
|
|
||||||
|
import os
|
||||||
|
from textual.widgets import ListItem, ListView, Input, Button, Label
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_configs(tmp_path, log):
|
||||||
|
reaper_dir = tmp_path / "ReaperFiles-TestConfig"
|
||||||
|
reaper_dir.mkdir()
|
||||||
|
(reaper_dir / "reaper_config.yaml").write_text("discord_bot_token: 'fake'\ndiscord_server_id: '123'\ntool_mode: 'backup_only'")
|
||||||
|
|
||||||
|
autotest_dir = tmp_path / "ReaperFiles-AutoTest"
|
||||||
|
autotest_dir.mkdir()
|
||||||
|
(autotest_dir / "reaper_config.yaml").write_text("discord_bot_token: 'fake'\ndiscord_server_id: '123'\ntool_mode: 'backup_transfer'")
|
||||||
|
|
||||||
|
old_cwd = os.getcwd()
|
||||||
|
os.chdir(tmp_path)
|
||||||
|
log(f"CWD changed to: {tmp_path}")
|
||||||
|
yield tmp_path
|
||||||
|
os.chdir(old_cwd)
|
||||||
|
|
||||||
|
async def wait_for_screen(app, screen_class, timeout=5.0):
|
||||||
|
import time
|
||||||
|
start = time.time()
|
||||||
|
while time.time() - start < timeout:
|
||||||
|
if isinstance(app.screen, screen_class):
|
||||||
|
return True
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
return False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ui_minimal_launch(mock_configs, log):
|
||||||
|
"""Verify app launch and screen transition to ModeScreen."""
|
||||||
|
log("Running test_ui_minimal_launch")
|
||||||
|
try:
|
||||||
|
# Reverting to AsyncMock to avoid WorkerError.
|
||||||
|
# RuntimeWarnings are now handled globally in conftest.py.
|
||||||
|
with patch("src.ui.main_app.ConfigSelectionScreen.check_updates", AsyncMock()):
|
||||||
|
app = ReaperApp()
|
||||||
|
async with app.run_test() as pilot:
|
||||||
|
await wait_for_screen(app, ConfigSelectionScreen)
|
||||||
|
await pilot.click(ListItem)
|
||||||
|
await wait_for_screen(app, ModeScreen)
|
||||||
|
assert isinstance(app.screen, ModeScreen)
|
||||||
|
log("test_ui_minimal_launch PASSED")
|
||||||
|
except Exception as e:
|
||||||
|
log(f"test_ui_minimal_launch FAILED: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ui_config_wizard_save(mock_configs, log):
|
||||||
|
"""Verify configuration editing and saving."""
|
||||||
|
log("Running test_ui_config_wizard_save")
|
||||||
|
try:
|
||||||
|
with patch("src.ui.main_app.ConfigSelectionScreen.check_updates", AsyncMock()):
|
||||||
|
app = ReaperApp()
|
||||||
|
async with app.run_test() as pilot:
|
||||||
|
await wait_for_screen(app, ConfigSelectionScreen)
|
||||||
|
await pilot.click(ListItem)
|
||||||
|
await wait_for_screen(app, ModeScreen)
|
||||||
|
await pilot.click("#btn_config")
|
||||||
|
await wait_for_screen(app, ConfigScreen)
|
||||||
|
|
||||||
|
screen = app.screen
|
||||||
|
inp = screen.query_one("#inp_discord_token", Input)
|
||||||
|
inp.value = "new_fake_token"
|
||||||
|
|
||||||
|
with patch("src.ui.main_app.save_config") as mock_save:
|
||||||
|
await pilot.click("#btn_save")
|
||||||
|
await pilot.pause(0.2)
|
||||||
|
assert mock_save.called
|
||||||
|
log("test_ui_config_wizard_save PASSED")
|
||||||
|
except Exception as e:
|
||||||
|
log(f"test_ui_config_wizard_save FAILED: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ui_operation_trigger(mock_configs, log):
|
||||||
|
"""Verify that an operation can be triggered."""
|
||||||
|
log("Running test_ui_operation_trigger")
|
||||||
|
from src.ui.shuttle_ops import OperationPane
|
||||||
|
from src.ui.modals import ChannelPickerScreen, ProgressScreen
|
||||||
|
try:
|
||||||
|
with patch("src.ui.main_app.ConfigSelectionScreen.check_updates", AsyncMock()):
|
||||||
|
# run_validate is decorated with @work in shuttle_ops.py, so it MUST be a coroutine (AsyncMock)
|
||||||
|
with patch.object(OperationPane, "run_validate", AsyncMock()):
|
||||||
|
app = ReaperApp()
|
||||||
|
async with app.run_test() as pilot:
|
||||||
|
await wait_for_screen(app, ConfigSelectionScreen)
|
||||||
|
await pilot.click(ListItem)
|
||||||
|
await wait_for_screen(app, ModeScreen)
|
||||||
|
|
||||||
|
pane = app.screen.query_one(OperationPane)
|
||||||
|
pane.tokens_valid = True
|
||||||
|
pane.src_channels = [{"id": 1, "name": "t"}]
|
||||||
|
pane.src_cat_map = {None: "D"}
|
||||||
|
pane.tgt_channels = [{"id": 2, "name": "t"}]
|
||||||
|
pane.tgt_cat_map = {None: "D"}
|
||||||
|
pane.all_tgt_channels = pane.tgt_channels
|
||||||
|
|
||||||
|
btn = pane.query_one("#op_backup_msgs", Button)
|
||||||
|
btn.disabled = False
|
||||||
|
await pilot.pause(0.2)
|
||||||
|
btn.focus()
|
||||||
|
await pilot.press("enter")
|
||||||
|
|
||||||
|
await wait_for_screen(app, (ChannelPickerScreen, ProgressScreen))
|
||||||
|
assert isinstance(app.screen, (ChannelPickerScreen, ProgressScreen))
|
||||||
|
log("test_ui_operation_trigger PASSED")
|
||||||
|
except Exception as e:
|
||||||
|
log(f"test_ui_operation_trigger FAILED: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ui_autotest_button(mock_configs, log):
|
||||||
|
"""Verify visibility and trigger of the AUTO TEST button."""
|
||||||
|
log("Running test_ui_autotest_button")
|
||||||
|
from src.ui.shuttle_ops import OperationPane
|
||||||
|
try:
|
||||||
|
with patch("src.ui.main_app.ConfigSelectionScreen.check_updates", AsyncMock()):
|
||||||
|
with patch.object(OperationPane, "run_validate", AsyncMock()):
|
||||||
|
app = ReaperApp()
|
||||||
|
async with app.run_test() as pilot:
|
||||||
|
await wait_for_screen(app, ConfigSelectionScreen)
|
||||||
|
# Find the AutoTest item in the list
|
||||||
|
lv = app.screen.query_one(ListView)
|
||||||
|
autotest_index = -1
|
||||||
|
for idx, item in enumerate(lv.children):
|
||||||
|
if item.name == "AutoTest":
|
||||||
|
autotest_index = idx
|
||||||
|
break
|
||||||
|
|
||||||
|
assert autotest_index != -1, "AutoTest profile not found in ListView"
|
||||||
|
target_item = lv.children[autotest_index]
|
||||||
|
await pilot.click(target_item)
|
||||||
|
assert await wait_for_screen(app, ModeScreen), "Timed out waiting for ModeScreen"
|
||||||
|
|
||||||
|
# Verify button is present
|
||||||
|
pane = app.screen.query_one(OperationPane)
|
||||||
|
btn = pane.query_one("#op_autotest", Button)
|
||||||
|
assert btn.display is True
|
||||||
|
assert "AUTO TEST" in str(btn.label)
|
||||||
|
|
||||||
|
# Mock the sequence and trigger it
|
||||||
|
with patch.object(OperationPane, "run_autotest_sequence", AsyncMock()) as mock_seq:
|
||||||
|
btn.disabled = False
|
||||||
|
await pilot.pause(0.1)
|
||||||
|
btn.focus()
|
||||||
|
await pilot.press("enter")
|
||||||
|
await pilot.pause(0.1)
|
||||||
|
assert mock_seq.called
|
||||||
|
log("test_ui_autotest_button PASSED")
|
||||||
|
except Exception as e:
|
||||||
|
log(f"test_ui_autotest_button FAILED: {e}")
|
||||||
|
raise
|
||||||
60
tests/test_utils.py
Normal file
60
tests/test_utils.py
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from src.core.utils import parse_snowflake, resolve_discord_links
|
||||||
|
|
||||||
|
def test_parse_snowflake_valid():
|
||||||
|
assert parse_snowflake("12345") == 12345
|
||||||
|
assert parse_snowflake(12345) == 12345
|
||||||
|
assert parse_snowflake(" 67890 ") == 67890
|
||||||
|
|
||||||
|
def test_parse_snowflake_invalid():
|
||||||
|
assert parse_snowflake(None) is None
|
||||||
|
assert parse_snowflake("") is None
|
||||||
|
assert parse_snowflake("none") is None
|
||||||
|
assert parse_snowflake("NULL") is None
|
||||||
|
assert parse_snowflake("not_a_number") is None
|
||||||
|
|
||||||
|
def test_resolve_discord_links_no_content():
|
||||||
|
assert resolve_discord_links("", None, "fluxer", "target_id") == ""
|
||||||
|
assert resolve_discord_links(None, None, "fluxer", "target_id") is None
|
||||||
|
|
||||||
|
def test_resolve_discord_links_no_mapping():
|
||||||
|
mock_state = MagicMock()
|
||||||
|
mock_state.get_target_channel_id.return_value = None
|
||||||
|
mock_state.get_target_category_id.return_value = None
|
||||||
|
mock_state.find_message_mapping.return_value = (None, None)
|
||||||
|
|
||||||
|
content = "Check this: https://discord.com/channels/1/2/3"
|
||||||
|
resolved = resolve_discord_links(content, mock_state, "fluxer", "target_server")
|
||||||
|
assert "[`discord-message`](<https://discord.com/channels/1/2/3>)" in resolved
|
||||||
|
|
||||||
|
def test_resolve_discord_links_channel_mapping():
|
||||||
|
mock_state = MagicMock()
|
||||||
|
mock_state.get_target_channel_id.return_value = "target_chan_456"
|
||||||
|
mock_state.find_message_mapping.return_value = (None, None)
|
||||||
|
|
||||||
|
content = "Go to https://discord.com/channels/123/456"
|
||||||
|
|
||||||
|
# Test Fluxer
|
||||||
|
resolved_fluxer = resolve_discord_links(content, mock_state, "fluxer", "target_server")
|
||||||
|
assert "https://fluxer.app/channels/target_server/target_chan_456" in resolved_fluxer
|
||||||
|
|
||||||
|
# Test Stoat
|
||||||
|
resolved_stoat = resolve_discord_links(content, mock_state, "stoat", "target_server")
|
||||||
|
assert "https://stoat.chat/server/target_server/channel/target_chan_456" in resolved_stoat
|
||||||
|
|
||||||
|
def test_resolve_discord_links_message_mapping():
|
||||||
|
mock_state = MagicMock()
|
||||||
|
mock_state.find_message_mapping.return_value = ("target_chan_456", "target_msg_789")
|
||||||
|
|
||||||
|
content = "Look at this: https://discord.com/channels/123/456/789"
|
||||||
|
|
||||||
|
# Test Fluxer
|
||||||
|
resolved_fluxer = resolve_discord_links(content, mock_state, "fluxer", "target_server")
|
||||||
|
assert "https://fluxer.app/channels/target_server/target_chan_456/target_msg_789" in resolved_fluxer
|
||||||
|
|
||||||
|
def test_resolve_discord_links_skips_wrapped():
|
||||||
|
mock_state = MagicMock()
|
||||||
|
content = "Already wrapped: [link](https://discord.com/channels/1/2/3)"
|
||||||
|
resolved = resolve_discord_links(content, mock_state, "fluxer", "target_server")
|
||||||
|
assert resolved == content
|
||||||
Loading…
Add table
Reference in a new issue