add support for stoat custom instance
This commit is contained in:
parent
9f27af971c
commit
7fd487ad66
2 changed files with 185 additions and 27 deletions
|
|
@ -22,6 +22,13 @@ class MigrationContext:
|
||||||
self.target_platform = target_platform or config.target_platform or "fluxer"
|
self.target_platform = target_platform or config.target_platform or "fluxer"
|
||||||
self.state = MigrationState()
|
self.state = MigrationState()
|
||||||
|
|
||||||
|
# Apply config-based log level dynamically to the root logger
|
||||||
|
if hasattr(self.config, "log_level") and self.config.log_level:
|
||||||
|
import logging
|
||||||
|
level = getattr(logging, self.config.log_level.upper(), logging.INFO)
|
||||||
|
logging.getLogger().setLevel(level)
|
||||||
|
logger.info(f"Log level updated to {self.config.log_level.upper()}")
|
||||||
|
|
||||||
# Select the appropriate source reader
|
# Select the appropriate source reader
|
||||||
if source_mode == "backup":
|
if source_mode == "backup":
|
||||||
from src.core.backup_reader import BackupReader
|
from src.core.backup_reader import BackupReader
|
||||||
|
|
@ -37,8 +44,8 @@ class MigrationContext:
|
||||||
|
|
||||||
# Build the writer for the active target platform only
|
# Build the writer for the active target platform only
|
||||||
if self.target_platform == "stoat":
|
if self.target_platform == "stoat":
|
||||||
token = config.stoat_bot_token or ""
|
token = config.stoat_bot_token
|
||||||
community_id = config.stoat_server_id or ""
|
community_id = config.stoat_server_id
|
||||||
api_url = config.stoat_api_url or "default"
|
api_url = config.stoat_api_url or "default"
|
||||||
self.writer = StoatWriter(token=token, community_id=community_id, api_url=api_url)
|
self.writer = StoatWriter(token=token, community_id=community_id, api_url=api_url)
|
||||||
self.stoat_writer = self.writer
|
self.stoat_writer = self.writer
|
||||||
|
|
|
||||||
|
|
@ -5,24 +5,90 @@ from typing import Optional, List, Dict, Any
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def _discover_stoat_config(api_url: str) -> Dict[str, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Fetches the Stoat/Revolt instance configuration to discover the WS and CDN URLs.
|
||||||
|
Returns a dict with 'ws' and 'cdn' keys.
|
||||||
|
"""
|
||||||
|
results = {"ws": None, "cdn": None}
|
||||||
|
|
||||||
|
if not api_url or api_url == "default" or "stoat.chat" in api_url:
|
||||||
|
return results
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
try:
|
||||||
|
# Standard Revolt discovery endpoint is the root API URL
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(api_url.rstrip("/") + "/") as resp:
|
||||||
|
if resp.status == 200:
|
||||||
|
data = await resp.json()
|
||||||
|
results["ws"] = data.get("ws")
|
||||||
|
# Features might be in 'features.autumn.url'
|
||||||
|
features = data.get("features", {})
|
||||||
|
autumn = features.get("autumn", {})
|
||||||
|
results["cdn"] = autumn.get("url")
|
||||||
|
|
||||||
|
if results["ws"] or results["cdn"]:
|
||||||
|
logger.debug(f"Stoat Discovery: Found WS={results['ws']}, CDN={results['cdn']}")
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Stoat Discovery failed (fetching): {e}")
|
||||||
|
|
||||||
|
# Fallback to inference if fetch failed
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
try:
|
||||||
|
parsed = urlparse(api_url)
|
||||||
|
if parsed.netloc:
|
||||||
|
# Traditional defaults for self-hosted
|
||||||
|
results["ws"] = f"wss://{parsed.netloc}/ws"
|
||||||
|
results["cdn"] = f"https://{parsed.netloc}/autumn"
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
class StoatWriter:
|
class StoatWriter:
|
||||||
def __init__(self, token: str, community_id: str, api_url: str = "default"):
|
def __init__(self, token: str, community_id: str, api_url: str = "default", ws_url: str = None):
|
||||||
self.token = token
|
self.token = token
|
||||||
self.community_id = str(community_id)
|
self.community_id = str(community_id)
|
||||||
self.api_url = api_url
|
self.api_url = api_url
|
||||||
|
self.ws_url = ws_url
|
||||||
self.client: Optional[stoat.Client] = None
|
self.client: Optional[stoat.Client] = None
|
||||||
self._server = None
|
self._server = None
|
||||||
self._me = None
|
self._me = None
|
||||||
self._validation_cache = None
|
self._validation_cache = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def fetch_guilds(token: str, api_url: str = "default") -> list[tuple[str, str]]:
|
async def fetch_guilds(token: str, api_url: str = "default", ws_url: str = None, cdn_url: str = None) -> list[tuple[str, str]]:
|
||||||
"""Fetches the list of Stoat servers the bot is in. Returns list of (label, id)."""
|
"""Fetches the list of Stoat servers the bot is in. Returns list of (label, id)."""
|
||||||
client_kwargs = {"token": token, "bot": True}
|
token = token.strip()
|
||||||
if api_url and api_url != "default":
|
api_url = (api_url or "default").strip()
|
||||||
client_kwargs["http_base"] = api_url
|
ws_url = (ws_url or "").strip()
|
||||||
|
cdn_url = (cdn_url or "").strip()
|
||||||
|
|
||||||
|
# Auto-discover URLs if not provided for custom domains
|
||||||
|
if api_url != "default" and (not ws_url or not cdn_url):
|
||||||
|
discovery = await _discover_stoat_config(api_url)
|
||||||
|
if not ws_url: ws_url = discovery["ws"] or ""
|
||||||
|
if not cdn_url: cdn_url = discovery["cdn"] or ""
|
||||||
|
|
||||||
|
# Diagnostics to both stdout and logger
|
||||||
|
log_msg = f"Stoat: Fetching guilds using API URL: {api_url}"
|
||||||
|
if ws_url: log_msg += f" [WS: {ws_url}]"
|
||||||
|
if cdn_url: log_msg += f" [CDN: {cdn_url}]"
|
||||||
|
print(log_msg)
|
||||||
|
logger.debug(log_msg)
|
||||||
|
|
||||||
|
client_kwargs = {
|
||||||
|
"token": token,
|
||||||
|
"bot": True,
|
||||||
|
"http_base": api_url if api_url != "default" else None,
|
||||||
|
"websocket_base": ws_url or None,
|
||||||
|
"cdn_base": cdn_url or None
|
||||||
|
}
|
||||||
client = stoat.Client(**client_kwargs)
|
client = stoat.Client(**client_kwargs)
|
||||||
|
logger.debug(f"Stoat: Initialized client with native http_base={client_kwargs['http_base']} and websocket_base={client_kwargs['websocket_base']}")
|
||||||
|
|
||||||
ready_event = asyncio.Event()
|
ready_event = asyncio.Event()
|
||||||
servers_list = []
|
servers_list = []
|
||||||
|
|
||||||
|
|
@ -62,11 +128,17 @@ class StoatWriter:
|
||||||
logger.error(f"Failed to fetch Stoat servers: {e}")
|
logger.error(f"Failed to fetch Stoat servers: {e}")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
|
# Shutdown the specific client instance used for fetching
|
||||||
|
try:
|
||||||
await client.close()
|
await client.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
client_task.cancel()
|
client_task.cancel()
|
||||||
try:
|
try:
|
||||||
await client_task
|
# Wait for the task to actually finish terminating
|
||||||
except asyncio.CancelledError:
|
await asyncio.wait_for(client_task, timeout=2.0)
|
||||||
|
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return guilds_list
|
return guilds_list
|
||||||
|
|
@ -81,17 +153,45 @@ class StoatWriter:
|
||||||
pass
|
pass
|
||||||
self.client = None
|
self.client = None
|
||||||
|
|
||||||
client_kwargs = {"token": self.token, "bot": True}
|
api_url = (self.api_url or "default").strip()
|
||||||
if self.api_url and self.api_url != "default":
|
ws_url = (self.ws_url or "").strip()
|
||||||
client_kwargs["http_base"] = self.api_url
|
cdn_url = getattr(self, "cdn_url", "").strip()
|
||||||
|
|
||||||
|
# Auto-discover if not provided for custom domains
|
||||||
|
if api_url != "default" and (not ws_url or not cdn_url):
|
||||||
|
discovery = await _discover_stoat_config(api_url)
|
||||||
|
if not ws_url: ws_url = discovery["ws"] or ""
|
||||||
|
if not cdn_url: cdn_url = discovery["cdn"] or ""
|
||||||
|
|
||||||
|
token = self.token.strip()
|
||||||
|
|
||||||
|
# Diagnostics to both stdout and logger
|
||||||
|
log_msg = f"Stoat: Starting client using API URL: {api_url}"
|
||||||
|
if ws_url: log_msg += f" [WS: {ws_url}]"
|
||||||
|
if cdn_url: log_msg += f" [CDN: {cdn_url}]"
|
||||||
|
print(log_msg)
|
||||||
|
logger.debug(log_msg)
|
||||||
|
|
||||||
|
client_kwargs = {
|
||||||
|
"token": token,
|
||||||
|
"bot": True,
|
||||||
|
"http_base": api_url if api_url != "default" else None,
|
||||||
|
"websocket_base": ws_url or None,
|
||||||
|
"cdn_base": cdn_url or None
|
||||||
|
}
|
||||||
self.client = stoat.Client(**client_kwargs)
|
self.client = stoat.Client(**client_kwargs)
|
||||||
|
|
||||||
|
# Keep track of the start task so we can clean it up later
|
||||||
|
self._start_task = asyncio.create_task(self.client.start())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._me = await self.client.fetch_user("@me")
|
self._me = await self.client.fetch_user("@me")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to fetch bot user in StoatWriter: {e}")
|
print(f"Failed to fetch bot user in StoatWriter: {e}")
|
||||||
logger.error(f"Failed to fetch bot user in StoatWriter: {e}")
|
logger.error(f"Failed to fetch bot user in StoatWriter: {e}")
|
||||||
self.client = None # Reset if we can't even fetch @me
|
# Ensure we clean up if start fails
|
||||||
|
await self.close()
|
||||||
|
self.client = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def my_id(self):
|
def my_id(self):
|
||||||
|
|
@ -402,6 +502,10 @@ class StoatWriter:
|
||||||
color=color
|
color=color
|
||||||
))
|
))
|
||||||
|
|
||||||
|
# Retry logic for 'NotFound' (common race condition on self-hosted instances)
|
||||||
|
max_tries = 2
|
||||||
|
for attempt in range(max_tries):
|
||||||
|
try:
|
||||||
msg = await channel.send(
|
msg = await channel.send(
|
||||||
content=final_content,
|
content=final_content,
|
||||||
masquerade=masquerade,
|
masquerade=masquerade,
|
||||||
|
|
@ -410,6 +514,14 @@ class StoatWriter:
|
||||||
embeds=stoat_embeds
|
embeds=stoat_embeds
|
||||||
)
|
)
|
||||||
return str(msg.id) if msg else None
|
return str(msg.id) if msg else None
|
||||||
|
except Exception as send_err:
|
||||||
|
# 'NotFound' often means the attachment is still being processed by the database
|
||||||
|
if "NotFound" in str(send_err) and attempt < max_tries - 1:
|
||||||
|
logger.warning(f"Stoat: Received NotFound during send (likely race condition). Retrying in 1.5s... (Attempt {attempt+1}/{max_tries})")
|
||||||
|
await asyncio.sleep(1.5)
|
||||||
|
continue
|
||||||
|
raise send_err
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If file type not allowed, skip attachments and still send the message
|
# If file type not allowed, skip attachments and still send the message
|
||||||
if "FileTypeNotAllowed" in str(e) and attachments:
|
if "FileTypeNotAllowed" in str(e) and attachments:
|
||||||
|
|
@ -472,8 +584,35 @@ class StoatWriter:
|
||||||
|
|
||||||
# Set permissions
|
# Set permissions
|
||||||
if permissions != 0:
|
if permissions != 0:
|
||||||
s_perms = self._map_permissions(permissions)
|
requested_perms = self._map_permissions(permissions)
|
||||||
await server.set_role_permissions(role, allow=s_perms)
|
|
||||||
|
# Fetch bot's own permissions to mask the requested set.
|
||||||
|
# Stoat/Revolt prevents granting permissions that the bot itself lacks.
|
||||||
|
try:
|
||||||
|
me = await server.fetch_member(self.my_id)
|
||||||
|
bot_perms = server.permissions_for(me)
|
||||||
|
|
||||||
|
# Manual masking of boolean attributes
|
||||||
|
final_perms = stoat.Permissions.none()
|
||||||
|
for attr in dir(requested_perms):
|
||||||
|
if not attr.startswith("_") and isinstance(getattr(requested_perms, attr), bool):
|
||||||
|
if getattr(requested_perms, attr) and getattr(bot_perms, attr, False):
|
||||||
|
try:
|
||||||
|
setattr(final_perms, attr, True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
await server.set_role_permissions(role, allow=final_perms)
|
||||||
|
except Exception as perm_err:
|
||||||
|
logger.warning(f"Stoat: Could not mask/verify role permissions for {name} (falling back to minimal): {perm_err}")
|
||||||
|
# Attempt to set at least one basic permission if possible, or just skip
|
||||||
|
try:
|
||||||
|
minimal = stoat.Permissions.none()
|
||||||
|
minimal.view_channel = True
|
||||||
|
minimal.send_messages = True
|
||||||
|
await server.set_role_permissions(role, allow=minimal)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
return str(role.id)
|
return str(role.id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -741,12 +880,24 @@ class StoatWriter:
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
client = self.client
|
client = self.client
|
||||||
|
task = getattr(self, "_start_task", None)
|
||||||
|
|
||||||
self.client = None # Atomic clear to prevent new usage
|
self.client = None # Atomic clear to prevent new usage
|
||||||
|
self._start_task = None
|
||||||
self._me = None
|
self._me = None
|
||||||
self._server = None
|
self._server = None
|
||||||
|
|
||||||
if client:
|
if client:
|
||||||
try:
|
try:
|
||||||
await client.close()
|
await client.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Error closing Stoat client: {e}")
|
logger.debug(f"Error closing Stoat client session: {e}")
|
||||||
|
|
||||||
|
if task and not task.done():
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(task, timeout=2.0)
|
||||||
|
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||||
|
pass
|
||||||
|
|
||||||
self._validation_cache = None
|
self._validation_cache = None
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue