add support for stoat custom instance
This commit is contained in:
parent
9f27af971c
commit
7fd487ad66
2 changed files with 185 additions and 27 deletions
|
|
@ -22,6 +22,13 @@ class MigrationContext:
|
|||
self.target_platform = target_platform or config.target_platform or "fluxer"
|
||||
self.state = MigrationState()
|
||||
|
||||
# Apply config-based log level dynamically to the root logger
|
||||
if hasattr(self.config, "log_level") and self.config.log_level:
|
||||
import logging
|
||||
level = getattr(logging, self.config.log_level.upper(), logging.INFO)
|
||||
logging.getLogger().setLevel(level)
|
||||
logger.info(f"Log level updated to {self.config.log_level.upper()}")
|
||||
|
||||
# Select the appropriate source reader
|
||||
if source_mode == "backup":
|
||||
from src.core.backup_reader import BackupReader
|
||||
|
|
@ -37,8 +44,8 @@ class MigrationContext:
|
|||
|
||||
# Build the writer for the active target platform only
|
||||
if self.target_platform == "stoat":
|
||||
token = config.stoat_bot_token or ""
|
||||
community_id = config.stoat_server_id or ""
|
||||
token = config.stoat_bot_token
|
||||
community_id = config.stoat_server_id
|
||||
api_url = config.stoat_api_url or "default"
|
||||
self.writer = StoatWriter(token=token, community_id=community_id, api_url=api_url)
|
||||
self.stoat_writer = self.writer
|
||||
|
|
|
|||
|
|
@ -5,24 +5,90 @@ from typing import Optional, List, Dict, Any
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def _discover_stoat_config(api_url: str) -> Dict[str, Optional[str]]:
|
||||
"""
|
||||
Fetches the Stoat/Revolt instance configuration to discover the WS and CDN URLs.
|
||||
Returns a dict with 'ws' and 'cdn' keys.
|
||||
"""
|
||||
results = {"ws": None, "cdn": None}
|
||||
|
||||
if not api_url or api_url == "default" or "stoat.chat" in api_url:
|
||||
return results
|
||||
|
||||
import aiohttp
|
||||
try:
|
||||
# Standard Revolt discovery endpoint is the root API URL
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(api_url.rstrip("/") + "/") as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
results["ws"] = data.get("ws")
|
||||
# Features might be in 'features.autumn.url'
|
||||
features = data.get("features", {})
|
||||
autumn = features.get("autumn", {})
|
||||
results["cdn"] = autumn.get("url")
|
||||
|
||||
if results["ws"] or results["cdn"]:
|
||||
logger.debug(f"Stoat Discovery: Found WS={results['ws']}, CDN={results['cdn']}")
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.debug(f"Stoat Discovery failed (fetching): {e}")
|
||||
|
||||
# Fallback to inference if fetch failed
|
||||
from urllib.parse import urlparse
|
||||
try:
|
||||
parsed = urlparse(api_url)
|
||||
if parsed.netloc:
|
||||
# Traditional defaults for self-hosted
|
||||
results["ws"] = f"wss://{parsed.netloc}/ws"
|
||||
results["cdn"] = f"https://{parsed.netloc}/autumn"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return results
|
||||
|
||||
class StoatWriter:
|
||||
def __init__(self, token: str, community_id: str, api_url: str = "default"):
|
||||
def __init__(self, token: str, community_id: str, api_url: str = "default", ws_url: str = None):
|
||||
self.token = token
|
||||
self.community_id = str(community_id)
|
||||
self.api_url = api_url
|
||||
self.ws_url = ws_url
|
||||
self.client: Optional[stoat.Client] = None
|
||||
self._server = None
|
||||
self._me = None
|
||||
self._validation_cache = None
|
||||
|
||||
@staticmethod
|
||||
async def fetch_guilds(token: str, api_url: str = "default") -> list[tuple[str, str]]:
|
||||
async def fetch_guilds(token: str, api_url: str = "default", ws_url: str = None, cdn_url: str = None) -> list[tuple[str, str]]:
|
||||
"""Fetches the list of Stoat servers the bot is in. Returns list of (label, id)."""
|
||||
client_kwargs = {"token": token, "bot": True}
|
||||
if api_url and api_url != "default":
|
||||
client_kwargs["http_base"] = api_url
|
||||
|
||||
token = token.strip()
|
||||
api_url = (api_url or "default").strip()
|
||||
ws_url = (ws_url or "").strip()
|
||||
cdn_url = (cdn_url or "").strip()
|
||||
|
||||
# Auto-discover URLs if not provided for custom domains
|
||||
if api_url != "default" and (not ws_url or not cdn_url):
|
||||
discovery = await _discover_stoat_config(api_url)
|
||||
if not ws_url: ws_url = discovery["ws"] or ""
|
||||
if not cdn_url: cdn_url = discovery["cdn"] or ""
|
||||
|
||||
# Diagnostics to both stdout and logger
|
||||
log_msg = f"Stoat: Fetching guilds using API URL: {api_url}"
|
||||
if ws_url: log_msg += f" [WS: {ws_url}]"
|
||||
if cdn_url: log_msg += f" [CDN: {cdn_url}]"
|
||||
print(log_msg)
|
||||
logger.debug(log_msg)
|
||||
|
||||
client_kwargs = {
|
||||
"token": token,
|
||||
"bot": True,
|
||||
"http_base": api_url if api_url != "default" else None,
|
||||
"websocket_base": ws_url or None,
|
||||
"cdn_base": cdn_url or None
|
||||
}
|
||||
client = stoat.Client(**client_kwargs)
|
||||
logger.debug(f"Stoat: Initialized client with native http_base={client_kwargs['http_base']} and websocket_base={client_kwargs['websocket_base']}")
|
||||
|
||||
ready_event = asyncio.Event()
|
||||
servers_list = []
|
||||
|
||||
|
|
@ -62,11 +128,17 @@ class StoatWriter:
|
|||
logger.error(f"Failed to fetch Stoat servers: {e}")
|
||||
raise
|
||||
finally:
|
||||
await client.close()
|
||||
# Shutdown the specific client instance used for fetching
|
||||
try:
|
||||
await client.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
client_task.cancel()
|
||||
try:
|
||||
await client_task
|
||||
except asyncio.CancelledError:
|
||||
# Wait for the task to actually finish terminating
|
||||
await asyncio.wait_for(client_task, timeout=2.0)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
pass
|
||||
|
||||
return guilds_list
|
||||
|
|
@ -81,17 +153,45 @@ class StoatWriter:
|
|||
pass
|
||||
self.client = None
|
||||
|
||||
client_kwargs = {"token": self.token, "bot": True}
|
||||
if self.api_url and self.api_url != "default":
|
||||
client_kwargs["http_base"] = self.api_url
|
||||
|
||||
api_url = (self.api_url or "default").strip()
|
||||
ws_url = (self.ws_url or "").strip()
|
||||
cdn_url = getattr(self, "cdn_url", "").strip()
|
||||
|
||||
# Auto-discover if not provided for custom domains
|
||||
if api_url != "default" and (not ws_url or not cdn_url):
|
||||
discovery = await _discover_stoat_config(api_url)
|
||||
if not ws_url: ws_url = discovery["ws"] or ""
|
||||
if not cdn_url: cdn_url = discovery["cdn"] or ""
|
||||
|
||||
token = self.token.strip()
|
||||
|
||||
# Diagnostics to both stdout and logger
|
||||
log_msg = f"Stoat: Starting client using API URL: {api_url}"
|
||||
if ws_url: log_msg += f" [WS: {ws_url}]"
|
||||
if cdn_url: log_msg += f" [CDN: {cdn_url}]"
|
||||
print(log_msg)
|
||||
logger.debug(log_msg)
|
||||
|
||||
client_kwargs = {
|
||||
"token": token,
|
||||
"bot": True,
|
||||
"http_base": api_url if api_url != "default" else None,
|
||||
"websocket_base": ws_url or None,
|
||||
"cdn_base": cdn_url or None
|
||||
}
|
||||
self.client = stoat.Client(**client_kwargs)
|
||||
|
||||
# Keep track of the start task so we can clean it up later
|
||||
self._start_task = asyncio.create_task(self.client.start())
|
||||
|
||||
try:
|
||||
self._me = await self.client.fetch_user("@me")
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch bot user in StoatWriter: {e}")
|
||||
logger.error(f"Failed to fetch bot user in StoatWriter: {e}")
|
||||
self.client = None # Reset if we can't even fetch @me
|
||||
# Ensure we clean up if start fails
|
||||
await self.close()
|
||||
self.client = None
|
||||
|
||||
@property
|
||||
def my_id(self):
|
||||
|
|
@ -402,14 +502,26 @@ class StoatWriter:
|
|||
color=color
|
||||
))
|
||||
|
||||
msg = await channel.send(
|
||||
content=final_content,
|
||||
masquerade=masquerade,
|
||||
replies=replies,
|
||||
attachments=attachments,
|
||||
embeds=stoat_embeds
|
||||
)
|
||||
return str(msg.id) if msg else None
|
||||
# Retry logic for 'NotFound' (common race condition on self-hosted instances)
|
||||
max_tries = 2
|
||||
for attempt in range(max_tries):
|
||||
try:
|
||||
msg = await channel.send(
|
||||
content=final_content,
|
||||
masquerade=masquerade,
|
||||
replies=replies,
|
||||
attachments=attachments,
|
||||
embeds=stoat_embeds
|
||||
)
|
||||
return str(msg.id) if msg else None
|
||||
except Exception as send_err:
|
||||
# 'NotFound' often means the attachment is still being processed by the database
|
||||
if "NotFound" in str(send_err) and attempt < max_tries - 1:
|
||||
logger.warning(f"Stoat: Received NotFound during send (likely race condition). Retrying in 1.5s... (Attempt {attempt+1}/{max_tries})")
|
||||
await asyncio.sleep(1.5)
|
||||
continue
|
||||
raise send_err
|
||||
|
||||
except Exception as e:
|
||||
# If file type not allowed, skip attachments and still send the message
|
||||
if "FileTypeNotAllowed" in str(e) and attachments:
|
||||
|
|
@ -472,8 +584,35 @@ class StoatWriter:
|
|||
|
||||
# Set permissions
|
||||
if permissions != 0:
|
||||
s_perms = self._map_permissions(permissions)
|
||||
await server.set_role_permissions(role, allow=s_perms)
|
||||
requested_perms = self._map_permissions(permissions)
|
||||
|
||||
# Fetch bot's own permissions to mask the requested set.
|
||||
# Stoat/Revolt prevents granting permissions that the bot itself lacks.
|
||||
try:
|
||||
me = await server.fetch_member(self.my_id)
|
||||
bot_perms = server.permissions_for(me)
|
||||
|
||||
# Manual masking of boolean attributes
|
||||
final_perms = stoat.Permissions.none()
|
||||
for attr in dir(requested_perms):
|
||||
if not attr.startswith("_") and isinstance(getattr(requested_perms, attr), bool):
|
||||
if getattr(requested_perms, attr) and getattr(bot_perms, attr, False):
|
||||
try:
|
||||
setattr(final_perms, attr, True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await server.set_role_permissions(role, allow=final_perms)
|
||||
except Exception as perm_err:
|
||||
logger.warning(f"Stoat: Could not mask/verify role permissions for {name} (falling back to minimal): {perm_err}")
|
||||
# Attempt to set at least one basic permission if possible, or just skip
|
||||
try:
|
||||
minimal = stoat.Permissions.none()
|
||||
minimal.view_channel = True
|
||||
minimal.send_messages = True
|
||||
await server.set_role_permissions(role, allow=minimal)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return str(role.id)
|
||||
except Exception as e:
|
||||
|
|
@ -741,12 +880,24 @@ class StoatWriter:
|
|||
|
||||
async def close(self):
|
||||
client = self.client
|
||||
task = getattr(self, "_start_task", None)
|
||||
|
||||
self.client = None # Atomic clear to prevent new usage
|
||||
self._start_task = None
|
||||
self._me = None
|
||||
self._server = None
|
||||
|
||||
if client:
|
||||
try:
|
||||
await client.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error closing Stoat client: {e}")
|
||||
logger.debug(f"Error closing Stoat client session: {e}")
|
||||
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(task, timeout=2.0)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
pass
|
||||
|
||||
self._validation_cache = None
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue