From 7fd487ad66e7a072ff5c241385eb81d5bc4c76c9 Mon Sep 17 00:00:00 2001 From: rambros Date: Fri, 3 Apr 2026 10:39:38 +0530 Subject: [PATCH] add support for stoat custom instance --- src/core/base.py | 11 ++- src/stoat/writer.py | 201 ++++++++++++++++++++++++++++++++++++++------ 2 files changed, 185 insertions(+), 27 deletions(-) diff --git a/src/core/base.py b/src/core/base.py index cbf6145..e9a1541 100644 --- a/src/core/base.py +++ b/src/core/base.py @@ -22,6 +22,13 @@ class MigrationContext: self.target_platform = target_platform or config.target_platform or "fluxer" self.state = MigrationState() + # Apply config-based log level dynamically to the root logger + if hasattr(self.config, "log_level") and self.config.log_level: + import logging + level = getattr(logging, self.config.log_level.upper(), logging.INFO) + logging.getLogger().setLevel(level) + logger.info(f"Log level updated to {self.config.log_level.upper()}") + # Select the appropriate source reader if source_mode == "backup": from src.core.backup_reader import BackupReader @@ -37,8 +44,8 @@ class MigrationContext: # Build the writer for the active target platform only if self.target_platform == "stoat": - token = config.stoat_bot_token or "" - community_id = config.stoat_server_id or "" + token = config.stoat_bot_token + community_id = config.stoat_server_id api_url = config.stoat_api_url or "default" self.writer = StoatWriter(token=token, community_id=community_id, api_url=api_url) self.stoat_writer = self.writer diff --git a/src/stoat/writer.py b/src/stoat/writer.py index d3a8c2c..efe7a8d 100644 --- a/src/stoat/writer.py +++ b/src/stoat/writer.py @@ -5,24 +5,90 @@ from typing import Optional, List, Dict, Any logger = logging.getLogger(__name__) +async def _discover_stoat_config(api_url: str) -> Dict[str, Optional[str]]: + """ + Fetches the Stoat/Revolt instance configuration to discover the WS and CDN URLs. + Returns a dict with 'ws' and 'cdn' keys. + """ + results = {"ws": None, "cdn": None} + + if not api_url or api_url == "default" or "stoat.chat" in api_url: + return results + + import aiohttp + try: + # Standard Revolt discovery endpoint is the root API URL + async with aiohttp.ClientSession() as session: + async with session.get(api_url.rstrip("/") + "/") as resp: + if resp.status == 200: + data = await resp.json() + results["ws"] = data.get("ws") + # Features might be in 'features.autumn.url' + features = data.get("features", {}) + autumn = features.get("autumn", {}) + results["cdn"] = autumn.get("url") + + if results["ws"] or results["cdn"]: + logger.debug(f"Stoat Discovery: Found WS={results['ws']}, CDN={results['cdn']}") + return results + except Exception as e: + logger.debug(f"Stoat Discovery failed (fetching): {e}") + + # Fallback to inference if fetch failed + from urllib.parse import urlparse + try: + parsed = urlparse(api_url) + if parsed.netloc: + # Traditional defaults for self-hosted + results["ws"] = f"wss://{parsed.netloc}/ws" + results["cdn"] = f"https://{parsed.netloc}/autumn" + except Exception: + pass + + return results + class StoatWriter: - def __init__(self, token: str, community_id: str, api_url: str = "default"): + def __init__(self, token: str, community_id: str, api_url: str = "default", ws_url: str = None): self.token = token self.community_id = str(community_id) self.api_url = api_url + self.ws_url = ws_url self.client: Optional[stoat.Client] = None self._server = None self._me = None self._validation_cache = None @staticmethod - async def fetch_guilds(token: str, api_url: str = "default") -> list[tuple[str, str]]: + async def fetch_guilds(token: str, api_url: str = "default", ws_url: str = None, cdn_url: str = None) -> list[tuple[str, str]]: """Fetches the list of Stoat servers the bot is in. Returns list of (label, id).""" - client_kwargs = {"token": token, "bot": True} - if api_url and api_url != "default": - client_kwargs["http_base"] = api_url - + token = token.strip() + api_url = (api_url or "default").strip() + ws_url = (ws_url or "").strip() + cdn_url = (cdn_url or "").strip() + + # Auto-discover URLs if not provided for custom domains + if api_url != "default" and (not ws_url or not cdn_url): + discovery = await _discover_stoat_config(api_url) + if not ws_url: ws_url = discovery["ws"] or "" + if not cdn_url: cdn_url = discovery["cdn"] or "" + + # Diagnostics to both stdout and logger + log_msg = f"Stoat: Fetching guilds using API URL: {api_url}" + if ws_url: log_msg += f" [WS: {ws_url}]" + if cdn_url: log_msg += f" [CDN: {cdn_url}]" + print(log_msg) + logger.debug(log_msg) + + client_kwargs = { + "token": token, + "bot": True, + "http_base": api_url if api_url != "default" else None, + "websocket_base": ws_url or None, + "cdn_base": cdn_url or None + } client = stoat.Client(**client_kwargs) + logger.debug(f"Stoat: Initialized client with native http_base={client_kwargs['http_base']} and websocket_base={client_kwargs['websocket_base']}") + ready_event = asyncio.Event() servers_list = [] @@ -62,11 +128,17 @@ class StoatWriter: logger.error(f"Failed to fetch Stoat servers: {e}") raise finally: - await client.close() + # Shutdown the specific client instance used for fetching + try: + await client.close() + except Exception: + pass + client_task.cancel() try: - await client_task - except asyncio.CancelledError: + # Wait for the task to actually finish terminating + await asyncio.wait_for(client_task, timeout=2.0) + except (asyncio.CancelledError, asyncio.TimeoutError): pass return guilds_list @@ -81,17 +153,45 @@ class StoatWriter: pass self.client = None - client_kwargs = {"token": self.token, "bot": True} - if self.api_url and self.api_url != "default": - client_kwargs["http_base"] = self.api_url - + api_url = (self.api_url or "default").strip() + ws_url = (self.ws_url or "").strip() + cdn_url = getattr(self, "cdn_url", "").strip() + + # Auto-discover if not provided for custom domains + if api_url != "default" and (not ws_url or not cdn_url): + discovery = await _discover_stoat_config(api_url) + if not ws_url: ws_url = discovery["ws"] or "" + if not cdn_url: cdn_url = discovery["cdn"] or "" + + token = self.token.strip() + + # Diagnostics to both stdout and logger + log_msg = f"Stoat: Starting client using API URL: {api_url}" + if ws_url: log_msg += f" [WS: {ws_url}]" + if cdn_url: log_msg += f" [CDN: {cdn_url}]" + print(log_msg) + logger.debug(log_msg) + + client_kwargs = { + "token": token, + "bot": True, + "http_base": api_url if api_url != "default" else None, + "websocket_base": ws_url or None, + "cdn_base": cdn_url or None + } self.client = stoat.Client(**client_kwargs) + + # Keep track of the start task so we can clean it up later + self._start_task = asyncio.create_task(self.client.start()) + try: self._me = await self.client.fetch_user("@me") except Exception as e: print(f"Failed to fetch bot user in StoatWriter: {e}") logger.error(f"Failed to fetch bot user in StoatWriter: {e}") - self.client = None # Reset if we can't even fetch @me + # Ensure we clean up if start fails + await self.close() + self.client = None @property def my_id(self): @@ -402,14 +502,26 @@ class StoatWriter: color=color )) - msg = await channel.send( - content=final_content, - masquerade=masquerade, - replies=replies, - attachments=attachments, - embeds=stoat_embeds - ) - return str(msg.id) if msg else None + # Retry logic for 'NotFound' (common race condition on self-hosted instances) + max_tries = 2 + for attempt in range(max_tries): + try: + msg = await channel.send( + content=final_content, + masquerade=masquerade, + replies=replies, + attachments=attachments, + embeds=stoat_embeds + ) + return str(msg.id) if msg else None + except Exception as send_err: + # 'NotFound' often means the attachment is still being processed by the database + if "NotFound" in str(send_err) and attempt < max_tries - 1: + logger.warning(f"Stoat: Received NotFound during send (likely race condition). Retrying in 1.5s... (Attempt {attempt+1}/{max_tries})") + await asyncio.sleep(1.5) + continue + raise send_err + except Exception as e: # If file type not allowed, skip attachments and still send the message if "FileTypeNotAllowed" in str(e) and attachments: @@ -472,8 +584,35 @@ class StoatWriter: # Set permissions if permissions != 0: - s_perms = self._map_permissions(permissions) - await server.set_role_permissions(role, allow=s_perms) + requested_perms = self._map_permissions(permissions) + + # Fetch bot's own permissions to mask the requested set. + # Stoat/Revolt prevents granting permissions that the bot itself lacks. + try: + me = await server.fetch_member(self.my_id) + bot_perms = server.permissions_for(me) + + # Manual masking of boolean attributes + final_perms = stoat.Permissions.none() + for attr in dir(requested_perms): + if not attr.startswith("_") and isinstance(getattr(requested_perms, attr), bool): + if getattr(requested_perms, attr) and getattr(bot_perms, attr, False): + try: + setattr(final_perms, attr, True) + except Exception: + pass + + await server.set_role_permissions(role, allow=final_perms) + except Exception as perm_err: + logger.warning(f"Stoat: Could not mask/verify role permissions for {name} (falling back to minimal): {perm_err}") + # Attempt to set at least one basic permission if possible, or just skip + try: + minimal = stoat.Permissions.none() + minimal.view_channel = True + minimal.send_messages = True + await server.set_role_permissions(role, allow=minimal) + except Exception: + pass return str(role.id) except Exception as e: @@ -741,12 +880,24 @@ class StoatWriter: async def close(self): client = self.client + task = getattr(self, "_start_task", None) + self.client = None # Atomic clear to prevent new usage + self._start_task = None self._me = None self._server = None + if client: try: await client.close() except Exception as e: - logger.debug(f"Error closing Stoat client: {e}") + logger.debug(f"Error closing Stoat client session: {e}") + + if task and not task.done(): + task.cancel() + try: + await asyncio.wait_for(task, timeout=2.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + self._validation_cache = None