diff --git a/archive_bot/dashboard.py b/archive_bot/dashboard.py index 48bd943..8b7094a 100644 --- a/archive_bot/dashboard.py +++ b/archive_bot/dashboard.py @@ -38,24 +38,22 @@ from .storage import ( ) from .watchparty import ( add_watch_party_queue_item, + control_watch_party_worker, create_watch_party_session, + find_watch_party_session, first_queued_entry, get_watch_party_session, list_media_candidates, list_watch_party_sessions, + play_next_watch_party_item, + refresh_watch_party_worker, + start_watch_party_worker, update_queue_entry_status, update_watch_party_status, update_worker_state, worker_command_payload, ) -from .worker_client import ( - worker_control, - worker_create_session, - worker_enabled, - worker_health, - worker_play, - worker_status, -) +from .worker_client import worker_enabled, worker_health def make_dashboard_handler(runtime: BotRuntime, auth: DashboardAuth | None) -> type[BaseHTTPRequestHandler]: @@ -654,29 +652,8 @@ def make_dashboard_handler(runtime: BotRuntime, auth: DashboardAuth | None) -> t try: data = self.read_json() session_id = int(str(data.get("sessionId", "")).strip()) - session_payload = get_watch_party_session(session_id) - session = session_payload["session"] - worker = worker_create_session( - session_id=session_id, - guild_id=str(session["guildId"]), - voice_channel_id=str(session["voiceChannelId"]), - text_channel_id=str(session["textChannelId"]), - ) - result = update_watch_party_status( - session_id, - "connecting", - worker_session_id=str(worker.get("workerSessionId", session_id)), - current_queue_entry_id=session["currentQueueEntryId"], - ) - result = update_worker_state( - session_id, - worker_status=str(worker.get("workerStatus", "idle")), - playback_state=str(worker.get("playbackState", "idle")), - current_title=str(worker.get("currentTitle", "")), - position_seconds=int(worker.get("positionSeconds", 0) or 0), - duration_seconds=int(worker.get("durationSeconds", 0) or 0), - last_error=str(worker.get("lastError", "")), - ) + result = start_watch_party_worker(session_id) + worker = result["worker"] except Exception as exc: self.send_json(HTTPStatus.BAD_REQUEST, {"error": str(exc)}) return @@ -687,36 +664,9 @@ def make_dashboard_handler(runtime: BotRuntime, auth: DashboardAuth | None) -> t try: data = self.read_json() session_id = int(str(data.get("sessionId", "")).strip()) - queue_entry = first_queued_entry(session_id) - if queue_entry is None: - raise ValueError("No queued items available") - media_item = catalog_item_for_source_id(runtime, str(queue_entry["jellyfinSourceId"])) - if media_item is None: - raise ValueError(f"Queued media item no longer exists in the saved library: {queue_entry['jellyfinSourceId']}") - playback = resolve_jellyfin_playback_source(runtime, media_item) - worker = worker_play( - session_id=session_id, - queue_entry_id=int(queue_entry["id"]), - jellyfin_source_id=str(queue_entry["jellyfinSourceId"]), - title=str(queue_entry["title"]), - media_type=str(queue_entry["mediaType"]), - playback=playback, - ) - update_queue_entry_status(session_id, int(queue_entry["id"]), "playing") - update_watch_party_status( - session_id, - "playing", - current_queue_entry_id=int(queue_entry["id"]), - ) - result = update_worker_state( - session_id, - worker_status=str(worker.get("workerStatus", "idle")), - playback_state=str(worker.get("playbackState", "idle")), - current_title=str(worker.get("currentTitle", "")), - position_seconds=int(worker.get("positionSeconds", 0) or 0), - duration_seconds=int(worker.get("durationSeconds", 0) or 0), - last_error=str(worker.get("lastError", "")), - ) + result = play_next_watch_party_item(runtime, session_id) + worker = result["worker"] + playback = result["playback"] except Exception as exc: self.send_json(HTTPStatus.BAD_REQUEST, {"error": str(exc)}) return @@ -731,22 +681,8 @@ def make_dashboard_handler(runtime: BotRuntime, auth: DashboardAuth | None) -> t command_data = data.get("data", {}) if not isinstance(command_data, dict): raise ValueError("Worker control data must be an object") - worker = worker_control(session_id=session_id, action=action, data=command_data) - if action == "pause": - update_watch_party_status(session_id, "paused") - elif action == "resume": - update_watch_party_status(session_id, "playing") - elif action == "stop": - update_watch_party_status(session_id, "stopped") - result = update_worker_state( - session_id, - worker_status=str(worker.get("workerStatus", "idle")), - playback_state=str(worker.get("playbackState", "idle")), - current_title=str(worker.get("currentTitle", "")), - position_seconds=int(worker.get("positionSeconds", 0) or 0), - duration_seconds=int(worker.get("durationSeconds", 0) or 0), - last_error=str(worker.get("lastError", "")), - ) + result = control_watch_party_worker(session_id, action, command_data) + worker = result["worker"] except Exception as exc: self.send_json(HTTPStatus.BAD_REQUEST, {"error": str(exc)}) return @@ -757,16 +693,8 @@ def make_dashboard_handler(runtime: BotRuntime, auth: DashboardAuth | None) -> t try: data = self.read_json() session_id = int(str(data.get("sessionId", "")).strip()) - worker = worker_status(session_id) - result = update_worker_state( - session_id, - worker_status=str(worker.get("workerStatus", "idle")), - playback_state=str(worker.get("playbackState", "idle")), - current_title=str(worker.get("currentTitle", "")), - position_seconds=int(worker.get("positionSeconds", 0) or 0), - duration_seconds=int(worker.get("durationSeconds", 0) or 0), - last_error=str(worker.get("lastError", "")), - ) + result = refresh_watch_party_worker(session_id) + worker = result["worker"] except Exception as exc: self.send_json(HTTPStatus.BAD_REQUEST, {"error": str(exc)}) return diff --git a/archive_bot/discord_api.py b/archive_bot/discord_api.py index eef038c..63fdc22 100644 --- a/archive_bot/discord_api.py +++ b/archive_bot/discord_api.py @@ -9,17 +9,70 @@ import urllib.error import urllib.request from typing import Any -from .core import DISCORD_API +from .core import BotRuntime, DISCORD_API +from .watchparty import ( + add_watch_party_queue_item, + control_watch_party_worker, + create_watch_party_session, + find_watch_party_session, + get_watch_party_session, + list_media_candidates, + play_next_watch_party_item, + refresh_watch_party_worker, + start_watch_party_worker, +) +from .worker_client import worker_enabled + + +WATCH_PARTY_ACTIVE_STATUSES = {"draft", "queued", "connecting", "playing", "paused"} + + +def format_watch_party_summary(payload: dict[str, Any]) -> str: + session = payload.get("session", {}) + worker = payload.get("worker") or {} + queue = payload.get("queue") or [] + lines = [ + f"Session `{session.get('title', 'Watch Party')}`", + f"Status: `{session.get('status', 'unknown')}`", + f"Worker: `{worker.get('workerStatus', 'idle')}` · Playback: `{worker.get('playbackState', 'idle')}`", + ] + current_title = str(worker.get("currentTitle", "") or "").strip() + if current_title: + lines.append(f"Now playing: `{current_title}`") + if queue: + preview = [] + for entry in queue[:5]: + preview.append(f"{entry.get('position', '?')}. {entry.get('title', 'Untitled')} [{entry.get('status', 'queued')}]") + lines.append("Queue:\n" + "\n".join(preview)) + if len(queue) > 5: + lines.append(f"...and {len(queue) - 5} more") + else: + lines.append("Queue: empty") + return "\n".join(lines) + + +def format_media_search_results(items: list[dict[str, Any]]) -> str: + if not items: + return "No matches found." + lines: list[str] = [] + for item in items[:8]: + details = [str(item.get("mediaType", "")), str(item.get("year", "")), str(item.get("genres", ""))] + meta = " · ".join(part for part in details if part) + lines.append(f"- `{item.get('title', 'Untitled')}`{f' — {meta}' if meta else ''}") + return "\n".join(lines) class DiscordGatewayManager: - def __init__(self, token: str) -> None: + def __init__(self, token: str, runtime: BotRuntime) -> None: self.token = token + self.runtime = runtime self.thread: threading.Thread | None = None self.loop: asyncio.AbstractEventLoop | None = None self.client: Any = None self.ready = threading.Event() self._disconnecting = threading.Event() + self._commands_registered = False + self._commands_synced = False def start(self) -> None: if self.thread is not None: @@ -50,14 +103,20 @@ class DiscordGatewayManager: def __init__(self, manager: DiscordGatewayManager) -> None: intents = discord.Intents.default() intents.guilds = True + intents.voice_states = True super().__init__(intents=intents) self.manager = manager + self.tree = discord.app_commands.CommandTree(self) + + async def setup_hook(self) -> None: + self.manager.register_watchparty_commands(self.tree, discord) async def on_ready(self) -> None: await self.change_presence(status=discord.Status.online) user = self.user name = user.name if user is not None else "unknown" bot_id = user.id if user is not None else "unknown" + await self.manager.sync_commands(self.tree, discord) print(f"Discord gateway connected as {name} ({bot_id})", flush=True) self.manager.ready.set() @@ -84,6 +143,221 @@ class DiscordGatewayManager: finally: self.loop.close() + async def sync_commands(self, tree: Any, discord: Any) -> None: + if self._commands_synced: + return + try: + await tree.sync() + for guild in self.client.guilds if self.client is not None else []: + tree.copy_global_to(guild=guild) + await tree.sync(guild=guild) + except Exception as exc: + print(f"Discord command sync failed: {exc}", file=sys.stderr, flush=True) + else: + self._commands_synced = True + + def register_watchparty_commands(self, tree: Any, discord: Any) -> None: + if self._commands_registered: + return + + manager = self + runtime = self.runtime + + def current_voice_channel(interaction: Any) -> Any | None: + user = interaction.user + voice = getattr(user, "voice", None) + return getattr(voice, "channel", None) + + def ensure_guild_context(interaction: Any) -> None: + if interaction.guild is None: + raise ValueError("This command only works inside a server") + + def ensure_voice_context(interaction: Any) -> Any: + ensure_guild_context(interaction) + voice_channel = current_voice_channel(interaction) + if voice_channel is None: + raise ValueError("Join a voice channel first") + return voice_channel + + def session_for_voice(interaction: Any) -> dict[str, Any] | None: + voice_channel = current_voice_channel(interaction) + if interaction.guild is None or voice_channel is None: + return None + session = find_watch_party_session( + guild_id=str(interaction.guild.id), + voice_channel_id=str(voice_channel.id), + allowed_statuses=WATCH_PARTY_ACTIVE_STATUSES, + ) + if session is not None: + return session + return find_watch_party_session( + guild_id=str(interaction.guild.id), + voice_channel_id=str(voice_channel.id), + allowed_statuses=None, + ) + + def ensure_session(interaction: Any, *, create_if_missing: bool = False, title: str = "Watch Party") -> dict[str, Any]: + ensure_guild_context(interaction) + voice_channel = ensure_voice_context(interaction) + session = session_for_voice(interaction) + if session is not None: + return session + if not create_if_missing: + raise ValueError("No watch-party session exists for your voice channel. Run /watchparty create first.") + payload = create_watch_party_session( + runtime, + guild_id=str(interaction.guild.id), + voice_channel_id=str(voice_channel.id), + text_channel_id=str(interaction.channel_id), + owner_user_id=str(interaction.user.id), + title=title, + ) + return payload["session"] + + def pick_media_item(query: str, media_type: str) -> dict[str, Any]: + items = list_media_candidates(runtime, media_type=media_type, search=query, limit=8) + if not items: + raise ValueError("No library matches found for that query") + exact_matches = [item for item in items if str(item.get("title", "")).casefold() == query.strip().casefold()] + if len(items) > 1 and not exact_matches: + preview = format_media_search_results(items[:5]) + raise ValueError(f"Query matched multiple titles. Refine it:\n{preview}") + return exact_matches[0] if exact_matches else items[0] + + group = discord.app_commands.Group(name="watchparty", description="Control Jellyfin watch parties in Discord") + + @group.command(name="create", description="Create a watch-party session for your current voice channel") + @discord.app_commands.describe(title="Optional session title") + async def create_command(interaction: Any, title: str | None = None) -> None: + try: + voice_channel = ensure_voice_context(interaction) + session = session_for_voice(interaction) + if session is None: + payload = create_watch_party_session( + runtime, + guild_id=str(interaction.guild.id), + voice_channel_id=str(voice_channel.id), + text_channel_id=str(interaction.channel_id), + owner_user_id=str(interaction.user.id), + title=(title or f"{voice_channel.name} Watch Party").strip(), + ) + session = payload["session"] + message = f"Created `{session['title']}` for <#{session['voiceChannelId']}>." + else: + message = f"Using existing session `{session['title']}` for <#{session['voiceChannelId']}>." + await interaction.response.send_message(message, ephemeral=True) + except Exception as exc: + await interaction.response.send_message(str(exc), ephemeral=True) + + @group.command(name="search", description="Search the saved Jellyfin library for watch-party titles") + @discord.app_commands.describe(query="Movie or show title", media_type="Limit the search to movies or shows") + async def search_command(interaction: Any, query: str, media_type: str = "all") -> None: + try: + ensure_guild_context(interaction) + choice = media_type.strip().lower() if media_type.strip().lower() in {"all", "movie", "show"} else "all" + items = list_media_candidates(runtime, media_type=choice, search=query, limit=8) + await interaction.response.send_message(format_media_search_results(items), ephemeral=True) + except Exception as exc: + await interaction.response.send_message(str(exc), ephemeral=True) + + @group.command(name="add", description="Add a Jellyfin title to the current voice channel watch-party queue") + @discord.app_commands.describe(query="Exact or refined movie/show title", media_type="Limit the search to movies or shows") + async def add_command(interaction: Any, query: str, media_type: str = "all") -> None: + try: + session = ensure_session(interaction, create_if_missing=True, title="Watch Party") + choice = media_type.strip().lower() if media_type.strip().lower() in {"all", "movie", "show"} else "all" + item = pick_media_item(query, choice) + payload = add_watch_party_queue_item(runtime, int(session["id"]), str(item["sourceId"])) + queue = payload["queue"] + await interaction.response.send_message( + f"Added `{item['title']}` to `{payload['session']['title']}`. Queue size: {len(queue)}.", + ephemeral=True, + ) + except Exception as exc: + await interaction.response.send_message(str(exc), ephemeral=True) + + @group.command(name="start", description="Connect the worker and start playback in your current voice channel") + async def start_command(interaction: Any) -> None: + try: + session = ensure_session(interaction, create_if_missing=False) + if not worker_enabled(): + raise ValueError("Watch-party worker is not configured") + worker_payload = start_watch_party_worker(int(session["id"])) + try: + payload = play_next_watch_party_item(runtime, int(session["id"])) + await interaction.response.send_message( + f"Started `{payload['session']['title']}` in <#{payload['session']['voiceChannelId']}>.\n" + f"Now playing `{payload['worker'].get('currentTitle') or payload['session']['title']}`.", + ephemeral=True, + ) + except ValueError as exc: + if "No queued items available" not in str(exc): + raise + await interaction.response.send_message( + f"Connected the worker for `{worker_payload['session']['title']}`, but the queue is empty.", + ephemeral=True, + ) + except Exception as exc: + await interaction.response.send_message(str(exc), ephemeral=True) + + @group.command(name="pause", description="Pause the current watch-party stream") + async def pause_command(interaction: Any) -> None: + try: + session = ensure_session(interaction, create_if_missing=False) + payload = control_watch_party_worker(int(session["id"]), "pause", {}) + await interaction.response.send_message(format_watch_party_summary(payload), ephemeral=True) + except Exception as exc: + await interaction.response.send_message(str(exc), ephemeral=True) + + @group.command(name="resume", description="Resume the current watch-party stream") + async def resume_command(interaction: Any) -> None: + try: + session = ensure_session(interaction, create_if_missing=False) + payload = control_watch_party_worker(int(session["id"]), "resume", {}) + await interaction.response.send_message(format_watch_party_summary(payload), ephemeral=True) + except Exception as exc: + await interaction.response.send_message(str(exc), ephemeral=True) + + @group.command(name="stop", description="Stop the current watch-party stream") + async def stop_command(interaction: Any) -> None: + try: + session = ensure_session(interaction, create_if_missing=False) + payload = control_watch_party_worker(int(session["id"]), "stop", {}) + await interaction.response.send_message(format_watch_party_summary(payload), ephemeral=True) + except Exception as exc: + await interaction.response.send_message(str(exc), ephemeral=True) + + @group.command(name="seek", description="Seek the current stream to a time offset in seconds") + @discord.app_commands.describe(position_seconds="Target playback position in seconds") + async def seek_command(interaction: Any, position_seconds: int) -> None: + try: + session = ensure_session(interaction, create_if_missing=False) + payload = control_watch_party_worker(int(session["id"]), "seek", {"positionSeconds": max(0, position_seconds)}) + await interaction.response.send_message(format_watch_party_summary(payload), ephemeral=True) + except Exception as exc: + await interaction.response.send_message(str(exc), ephemeral=True) + + @group.command(name="queue", description="Show the queue for your current voice channel watch party") + async def queue_command(interaction: Any) -> None: + try: + session = ensure_session(interaction, create_if_missing=False) + payload = get_watch_party_session(int(session["id"])) + await interaction.response.send_message(format_watch_party_summary(payload), ephemeral=True) + except Exception as exc: + await interaction.response.send_message(str(exc), ephemeral=True) + + @group.command(name="status", description="Refresh and show the worker status for your current voice channel watch party") + async def status_command(interaction: Any) -> None: + try: + session = ensure_session(interaction, create_if_missing=False) + payload = refresh_watch_party_worker(int(session["id"])) if worker_enabled() else get_watch_party_session(int(session["id"])) + await interaction.response.send_message(format_watch_party_summary(payload), ephemeral=True) + except Exception as exc: + await interaction.response.send_message(str(exc), ephemeral=True) + + tree.add_command(group) + self._commands_registered = True + def discord_request( method: str, @@ -175,4 +449,3 @@ def discord_delete_message(token: str, channel_id: str, message_id: str) -> None def discord_bot_identity(token: str) -> dict[str, Any]: return discord_request("GET", token, "/users/@me") - diff --git a/archive_bot/main.py b/archive_bot/main.py index a13ae0f..771e45f 100644 --- a/archive_bot/main.py +++ b/archive_bot/main.py @@ -62,7 +62,7 @@ def main() -> int: ) gateway = None if bool_env("DISCORD_GATEWAY_ENABLED", True) and not runtime.dry_run: - gateway = DiscordGatewayManager(token) + gateway = DiscordGatewayManager(token, runtime) gateway.start() dashboard = maybe_start_dashboard(runtime) diff --git a/archive_bot/watchparty.py b/archive_bot/watchparty.py index 59f76bc..84ce8b4 100644 --- a/archive_bot/watchparty.py +++ b/archive_bot/watchparty.py @@ -7,7 +7,13 @@ from typing import Any from .core import BotRuntime, MediaItem from .db import connect_db, initialize_database -from .media import load_media_library, media_item_card_payload +from .media import ( + catalog_item_for_source_id, + load_media_library, + media_item_card_payload, + resolve_jellyfin_playback_source, +) +from .worker_client import worker_control, worker_create_session, worker_status, worker_play WATCH_PARTY_STATUSES = { "draft", @@ -279,6 +285,37 @@ def list_watch_party_sessions() -> list[dict[str, Any]]: return [session_to_jsonable(watch_party_session_from_row(row)) for row in rows] +def find_watch_party_session( + *, + guild_id: str, + voice_channel_id: str | None = None, + allowed_statuses: set[str] | None = None, +) -> dict[str, Any] | None: + initialize_watchparty_schema() + clauses = ["guild_id = ?"] + values: list[Any] = [guild_id.strip()] + if voice_channel_id is not None: + clauses.append("voice_channel_id = ?") + values.append(voice_channel_id.strip()) + if allowed_statuses: + placeholders = ", ".join("?" for _ in allowed_statuses) + clauses.append(f"status IN ({placeholders})") + values.extend(sorted(allowed_statuses)) + + query = f""" + SELECT * + FROM watch_party_sessions + WHERE {' AND '.join(clauses)} + ORDER BY updated_at DESC, id DESC + LIMIT 1 + """ + with connect_db() as connection: + row = connection.execute(query, tuple(values)).fetchone() + if row is None: + return None + return session_to_jsonable(watch_party_session_from_row(row)) + + def get_watch_party_session(session_id: int) -> dict[str, Any]: initialize_watchparty_schema() with connect_db() as connection: @@ -506,3 +543,98 @@ def first_queued_entry(session_id: int) -> dict[str, Any] | None: (session_id,), ).fetchone() return queue_entry_to_jsonable(row) if row is not None else None + + +def start_watch_party_worker(session_id: int) -> dict[str, Any]: + session_payload = get_watch_party_session(session_id) + session = session_payload["session"] + worker = worker_create_session( + session_id=session_id, + guild_id=str(session["guildId"]), + voice_channel_id=str(session["voiceChannelId"]), + text_channel_id=str(session["textChannelId"]), + ) + update_watch_party_status( + session_id, + "connecting", + worker_session_id=str(worker.get("workerSessionId", session_id)), + current_queue_entry_id=session["currentQueueEntryId"], + ) + result = update_worker_state( + session_id, + worker_status=str(worker.get("workerStatus", "idle")), + playback_state=str(worker.get("playbackState", "idle")), + current_title=str(worker.get("currentTitle", "")), + position_seconds=int(worker.get("positionSeconds", 0) or 0), + duration_seconds=int(worker.get("durationSeconds", 0) or 0), + last_error=str(worker.get("lastError", "")), + ) + return {"worker": worker, **result} + + +def play_next_watch_party_item(runtime: BotRuntime, session_id: int) -> dict[str, Any]: + queue_entry = first_queued_entry(session_id) + if queue_entry is None: + raise ValueError("No queued items available") + media_item = catalog_item_for_source_id(runtime, str(queue_entry["jellyfinSourceId"])) + if media_item is None: + raise ValueError(f"Queued media item no longer exists in the saved library: {queue_entry['jellyfinSourceId']}") + playback = resolve_jellyfin_playback_source(runtime, media_item) + worker = worker_play( + session_id=session_id, + queue_entry_id=int(queue_entry["id"]), + jellyfin_source_id=str(queue_entry["jellyfinSourceId"]), + title=str(queue_entry["title"]), + media_type=str(queue_entry["mediaType"]), + playback=playback, + ) + update_queue_entry_status(session_id, int(queue_entry["id"]), "playing") + update_watch_party_status( + session_id, + "playing", + current_queue_entry_id=int(queue_entry["id"]), + ) + result = update_worker_state( + session_id, + worker_status=str(worker.get("workerStatus", "idle")), + playback_state=str(worker.get("playbackState", "idle")), + current_title=str(worker.get("currentTitle", "")), + position_seconds=int(worker.get("positionSeconds", 0) or 0), + duration_seconds=int(worker.get("durationSeconds", 0) or 0), + last_error=str(worker.get("lastError", "")), + ) + return {"worker": worker, "playback": playback, **result} + + +def control_watch_party_worker(session_id: int, action: str, data: dict[str, Any] | None = None) -> dict[str, Any]: + worker = worker_control(session_id=session_id, action=action, data=data or {}) + if action == "pause": + update_watch_party_status(session_id, "paused") + elif action == "resume": + update_watch_party_status(session_id, "playing") + elif action == "stop": + update_watch_party_status(session_id, "stopped") + result = update_worker_state( + session_id, + worker_status=str(worker.get("workerStatus", "idle")), + playback_state=str(worker.get("playbackState", "idle")), + current_title=str(worker.get("currentTitle", "")), + position_seconds=int(worker.get("positionSeconds", 0) or 0), + duration_seconds=int(worker.get("durationSeconds", 0) or 0), + last_error=str(worker.get("lastError", "")), + ) + return {"worker": worker, **result} + + +def refresh_watch_party_worker(session_id: int) -> dict[str, Any]: + worker = worker_status(session_id) + result = update_worker_state( + session_id, + worker_status=str(worker.get("workerStatus", "idle")), + playback_state=str(worker.get("playbackState", "idle")), + current_title=str(worker.get("currentTitle", "")), + position_seconds=int(worker.get("positionSeconds", 0) or 0), + duration_seconds=int(worker.get("durationSeconds", 0) or 0), + last_error=str(worker.get("lastError", "")), + ) + return {"worker": worker, **result}