from __future__ import annotations import json from dataclasses import dataclass from datetime import datetime, timezone from typing import Any from .core import BotRuntime, MediaItem from .db import connect_db, initialize_database from .media import ( catalog_item_for_source_id, load_media_library, media_item_card_payload, resolve_jellyfin_playback_source, ) from .worker_client import worker_control, worker_create_session, worker_status, worker_play WATCH_PARTY_STATUSES = { "draft", "queued", "connecting", "playing", "paused", "stopped", "error", } QUEUE_STATUSES = {"queued", "playing", "played", "skipped", "failed"} @dataclass(frozen=True) class WatchPartySession: id: int guild_id: str voice_channel_id: str text_channel_id: str owner_user_id: str title: str status: str worker_session_id: str current_queue_entry_id: int | None created_at: str updated_at: str def initialize_watchparty_schema() -> None: initialize_database() with connect_db() as connection: connection.execute( """ CREATE TABLE IF NOT EXISTS watch_party_sessions ( id INTEGER PRIMARY KEY AUTOINCREMENT, guild_id TEXT NOT NULL, voice_channel_id TEXT NOT NULL, text_channel_id TEXT NOT NULL, owner_user_id TEXT NOT NULL, title TEXT NOT NULL, status TEXT NOT NULL, worker_session_id TEXT NOT NULL DEFAULT '', current_queue_entry_id INTEGER, created_at TEXT NOT NULL, updated_at TEXT NOT NULL ) """ ) connection.execute( """ CREATE TABLE IF NOT EXISTS watch_party_queue ( id INTEGER PRIMARY KEY AUTOINCREMENT, session_id INTEGER NOT NULL, position INTEGER NOT NULL, media_type TEXT NOT NULL, jellyfin_source_id TEXT NOT NULL, title TEXT NOT NULL, year TEXT NOT NULL DEFAULT '', runtime TEXT NOT NULL DEFAULT '', summary TEXT NOT NULL DEFAULT '', poster_url TEXT NOT NULL DEFAULT '', status TEXT NOT NULL, created_at TEXT NOT NULL, FOREIGN KEY(session_id) REFERENCES watch_party_sessions(id) ON DELETE CASCADE ) """ ) connection.execute( """ CREATE TABLE IF NOT EXISTS watch_party_events ( id INTEGER PRIMARY KEY AUTOINCREMENT, session_id INTEGER NOT NULL, event_type TEXT NOT NULL, payload_json TEXT NOT NULL, created_at TEXT NOT NULL, FOREIGN KEY(session_id) REFERENCES watch_party_sessions(id) ON DELETE CASCADE ) """ ) connection.execute( """ CREATE TABLE IF NOT EXISTS watch_party_worker_state ( session_id INTEGER PRIMARY KEY, worker_status TEXT NOT NULL DEFAULT 'idle', playback_state TEXT NOT NULL DEFAULT 'idle', current_title TEXT NOT NULL DEFAULT '', position_seconds INTEGER NOT NULL DEFAULT 0, duration_seconds INTEGER NOT NULL DEFAULT 0, last_error TEXT NOT NULL DEFAULT '', updated_at TEXT NOT NULL, FOREIGN KEY(session_id) REFERENCES watch_party_sessions(id) ON DELETE CASCADE ) """ ) connection.commit() def utc_now() -> str: return datetime.now(timezone.utc).isoformat() def normalize_watch_party_status(status: str) -> str: cleaned = status.strip().lower() if cleaned not in WATCH_PARTY_STATUSES: raise ValueError(f"Invalid watch party status: {status}") return cleaned def normalize_queue_status(status: str) -> str: cleaned = status.strip().lower() if cleaned not in QUEUE_STATUSES: raise ValueError(f"Invalid queue status: {status}") return cleaned def watch_party_session_from_row(row: Any) -> WatchPartySession: return WatchPartySession( id=int(row["id"]), guild_id=str(row["guild_id"]), voice_channel_id=str(row["voice_channel_id"]), text_channel_id=str(row["text_channel_id"]), owner_user_id=str(row["owner_user_id"]), title=str(row["title"]), status=str(row["status"]), worker_session_id=str(row["worker_session_id"] or ""), current_queue_entry_id=int(row["current_queue_entry_id"]) if row["current_queue_entry_id"] is not None else None, created_at=str(row["created_at"]), updated_at=str(row["updated_at"]), ) def session_to_jsonable(session: WatchPartySession) -> dict[str, Any]: return { "id": session.id, "guildId": session.guild_id, "voiceChannelId": session.voice_channel_id, "textChannelId": session.text_channel_id, "ownerUserId": session.owner_user_id, "title": session.title, "status": session.status, "workerSessionId": session.worker_session_id, "currentQueueEntryId": session.current_queue_entry_id, "createdAt": session.created_at, "updatedAt": session.updated_at, } def queue_entry_to_jsonable(row: Any) -> dict[str, Any]: return { "id": int(row["id"]), "sessionId": int(row["session_id"]), "position": int(row["position"]), "mediaType": str(row["media_type"]), "jellyfinSourceId": str(row["jellyfin_source_id"]), "title": str(row["title"]), "year": str(row["year"]), "runtime": str(row["runtime"]), "summary": str(row["summary"]), "posterUrl": str(row["poster_url"]), "status": str(row["status"]), "createdAt": str(row["created_at"]), } def worker_state_to_jsonable(row: Any) -> dict[str, Any]: return { "sessionId": int(row["session_id"]), "workerStatus": str(row["worker_status"]), "playbackState": str(row["playback_state"]), "currentTitle": str(row["current_title"]), "positionSeconds": int(row["position_seconds"]), "durationSeconds": int(row["duration_seconds"]), "lastError": str(row["last_error"]) or "", "updatedAt": str(row["updated_at"]), } def log_watch_party_event(connection: Any, session_id: int, event_type: str, payload: dict[str, Any]) -> None: connection.execute( """ INSERT INTO watch_party_events(session_id, event_type, payload_json, created_at) VALUES (?, ?, ?, ?) """, (session_id, event_type, json.dumps(payload, sort_keys=True), utc_now()), ) def list_media_candidates(runtime: BotRuntime, *, media_type: str = "all", search: str = "", limit: int = 25) -> list[dict[str, Any]]: movies, shows = load_media_library(runtime) pools = { "movie": movies, "show": shows, "all": [*movies, *shows], } selected = pools.get(media_type, pools["all"]) query = search.strip().casefold() if query: selected = [ item for item in selected if query in " ".join([item.title, item.year or "", item.genres or "", item.summary or ""]).casefold() ] selected = selected[: max(1, min(limit, 100))] return [media_item_card_payload(item) for item in selected] def find_media_item_by_source_id(runtime: BotRuntime, source_id: str) -> MediaItem | None: movies, shows = load_media_library(runtime) for item in [*movies, *shows]: if item.source_id and item.source_id == source_id: return item return None def create_watch_party_session( runtime: BotRuntime, *, guild_id: str, voice_channel_id: str, text_channel_id: str, owner_user_id: str, title: str, ) -> dict[str, Any]: del runtime initialize_watchparty_schema() if not guild_id.strip(): raise ValueError("guildId is required") if not voice_channel_id.strip(): raise ValueError("voiceChannelId is required") if not text_channel_id.strip(): raise ValueError("textChannelId is required") if not owner_user_id.strip(): raise ValueError("ownerUserId is required") now = utc_now() with connect_db() as connection: cursor = connection.execute( """ INSERT INTO watch_party_sessions( guild_id, voice_channel_id, text_channel_id, owner_user_id, title, status, worker_session_id, current_queue_entry_id, created_at, updated_at ) VALUES (?, ?, ?, ?, ?, 'draft', '', NULL, ?, ?) """, (guild_id.strip(), voice_channel_id.strip(), text_channel_id.strip(), owner_user_id.strip(), title.strip(), now, now), ) session_id = int(cursor.lastrowid) connection.execute( """ INSERT INTO watch_party_worker_state( session_id, worker_status, playback_state, current_title, position_seconds, duration_seconds, last_error, updated_at ) VALUES (?, 'idle', 'idle', '', 0, 0, '', ?) """, (session_id, now), ) log_watch_party_event(connection, session_id, "session.created", {"title": title.strip(), "ownerUserId": owner_user_id.strip()}) connection.commit() return get_watch_party_session(session_id) def list_watch_party_sessions() -> list[dict[str, Any]]: initialize_watchparty_schema() with connect_db() as connection: rows = connection.execute( "SELECT * FROM watch_party_sessions ORDER BY updated_at DESC, id DESC" ).fetchall() return [session_to_jsonable(watch_party_session_from_row(row)) for row in rows] def find_watch_party_session( *, guild_id: str, voice_channel_id: str | None = None, allowed_statuses: set[str] | None = None, ) -> dict[str, Any] | None: initialize_watchparty_schema() clauses = ["guild_id = ?"] values: list[Any] = [guild_id.strip()] if voice_channel_id is not None: clauses.append("voice_channel_id = ?") values.append(voice_channel_id.strip()) if allowed_statuses: placeholders = ", ".join("?" for _ in allowed_statuses) clauses.append(f"status IN ({placeholders})") values.extend(sorted(allowed_statuses)) query = f""" SELECT * FROM watch_party_sessions WHERE {' AND '.join(clauses)} ORDER BY updated_at DESC, id DESC LIMIT 1 """ with connect_db() as connection: row = connection.execute(query, tuple(values)).fetchone() if row is None: return None return session_to_jsonable(watch_party_session_from_row(row)) def get_watch_party_session(session_id: int) -> dict[str, Any]: initialize_watchparty_schema() with connect_db() as connection: row = connection.execute( "SELECT * FROM watch_party_sessions WHERE id = ?", (session_id,), ).fetchone() if row is None: raise ValueError(f"Watch party session not found: {session_id}") queue_rows = connection.execute( "SELECT * FROM watch_party_queue WHERE session_id = ? ORDER BY position ASC, id ASC", (session_id,), ).fetchall() worker_row = connection.execute( "SELECT * FROM watch_party_worker_state WHERE session_id = ?", (session_id,), ).fetchone() event_rows = connection.execute( "SELECT id, event_type, payload_json, created_at FROM watch_party_events WHERE session_id = ? ORDER BY id DESC LIMIT 25", (session_id,), ).fetchall() return { "session": session_to_jsonable(watch_party_session_from_row(row)), "queue": [queue_entry_to_jsonable(entry) for entry in queue_rows], "worker": worker_state_to_jsonable(worker_row) if worker_row is not None else None, "events": [ { "id": int(event["id"]), "eventType": str(event["event_type"]), "payload": json.loads(str(event["payload_json"])), "createdAt": str(event["created_at"]), } for event in event_rows ], } def next_queue_position(connection: Any, session_id: int) -> int: row = connection.execute( "SELECT COALESCE(MAX(position), 0) AS position FROM watch_party_queue WHERE session_id = ?", (session_id,), ).fetchone() return int(row["position"]) + 1 def add_watch_party_queue_item(runtime: BotRuntime, session_id: int, jellyfin_source_id: str) -> dict[str, Any]: initialize_watchparty_schema() if not jellyfin_source_id.strip(): raise ValueError("jellyfinSourceId is required") item = find_media_item_by_source_id(runtime, jellyfin_source_id.strip()) if item is None: raise ValueError(f"Media item not found in saved library: {jellyfin_source_id}") with connect_db() as connection: session_row = connection.execute( "SELECT * FROM watch_party_sessions WHERE id = ?", (session_id,), ).fetchone() if session_row is None: raise ValueError(f"Watch party session not found: {session_id}") position = next_queue_position(connection, session_id) cursor = connection.execute( """ INSERT INTO watch_party_queue( session_id, position, media_type, jellyfin_source_id, title, year, runtime, summary, poster_url, status, created_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 'queued', ?) """, ( session_id, position, item.media_type, item.source_id or "", item.title, item.year or "", item.runtime or "", item.summary or "", media_item_card_payload(item).get("posterUrl", ""), utc_now(), ), ) queue_id = int(cursor.lastrowid) connection.execute( "UPDATE watch_party_sessions SET status = ?, updated_at = ? WHERE id = ?", ("queued", utc_now(), session_id), ) log_watch_party_event( connection, session_id, "queue.added", {"queueEntryId": queue_id, "jellyfinSourceId": jellyfin_source_id.strip(), "title": item.title}, ) connection.commit() return get_watch_party_session(session_id) def update_watch_party_status(session_id: int, status: str, *, worker_session_id: str | None = None, current_queue_entry_id: int | None = None) -> dict[str, Any]: initialize_watchparty_schema() normalized_status = normalize_watch_party_status(status) updates = ["status = ?", "updated_at = ?"] values: list[Any] = [normalized_status, utc_now()] if worker_session_id is not None: updates.append("worker_session_id = ?") values.append(worker_session_id.strip()) if current_queue_entry_id is not None: updates.append("current_queue_entry_id = ?") values.append(current_queue_entry_id) values.append(session_id) with connect_db() as connection: if connection.execute("SELECT id FROM watch_party_sessions WHERE id = ?", (session_id,)).fetchone() is None: raise ValueError(f"Watch party session not found: {session_id}") connection.execute( f"UPDATE watch_party_sessions SET {', '.join(updates)} WHERE id = ?", tuple(values), ) log_watch_party_event( connection, session_id, "session.status", {"status": normalized_status, "workerSessionId": worker_session_id or "", "currentQueueEntryId": current_queue_entry_id}, ) connection.commit() return get_watch_party_session(session_id) def update_queue_entry_status(session_id: int, queue_entry_id: int, status: str) -> dict[str, Any]: initialize_watchparty_schema() normalized_status = normalize_queue_status(status) with connect_db() as connection: row = connection.execute( "SELECT id FROM watch_party_queue WHERE id = ? AND session_id = ?", (queue_entry_id, session_id), ).fetchone() if row is None: raise ValueError(f"Queue entry not found: {queue_entry_id}") connection.execute( "UPDATE watch_party_queue SET status = ? WHERE id = ?", (normalized_status, queue_entry_id), ) log_watch_party_event(connection, session_id, "queue.status", {"queueEntryId": queue_entry_id, "status": normalized_status}) connection.commit() return get_watch_party_session(session_id) def update_worker_state( session_id: int, *, worker_status: str, playback_state: str, current_title: str = "", position_seconds: int = 0, duration_seconds: int = 0, last_error: str = "", ) -> dict[str, Any]: initialize_watchparty_schema() with connect_db() as connection: if connection.execute("SELECT id FROM watch_party_sessions WHERE id = ?", (session_id,)).fetchone() is None: raise ValueError(f"Watch party session not found: {session_id}") connection.execute( """ INSERT INTO watch_party_worker_state( session_id, worker_status, playback_state, current_title, position_seconds, duration_seconds, last_error, updated_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(session_id) DO UPDATE SET worker_status = excluded.worker_status, playback_state = excluded.playback_state, current_title = excluded.current_title, position_seconds = excluded.position_seconds, duration_seconds = excluded.duration_seconds, last_error = excluded.last_error, updated_at = excluded.updated_at """, ( session_id, worker_status.strip(), playback_state.strip(), current_title.strip(), max(0, position_seconds), max(0, duration_seconds), last_error.strip(), utc_now(), ), ) log_watch_party_event( connection, session_id, "worker.state", { "workerStatus": worker_status.strip(), "playbackState": playback_state.strip(), "currentTitle": current_title.strip(), "positionSeconds": max(0, position_seconds), "durationSeconds": max(0, duration_seconds), "lastError": last_error.strip(), }, ) connection.commit() return get_watch_party_session(session_id) def worker_command_payload(session_id: int, action: str, data: dict[str, Any]) -> dict[str, Any]: payload = get_watch_party_session(session_id) payload["workerCommand"] = { "action": action, "sessionId": session_id, "data": data, } return payload def first_queued_entry(session_id: int) -> dict[str, Any] | None: initialize_watchparty_schema() with connect_db() as connection: row = connection.execute( """ SELECT * FROM watch_party_queue WHERE session_id = ? AND status = 'queued' ORDER BY position ASC, id ASC LIMIT 1 """, (session_id,), ).fetchone() return queue_entry_to_jsonable(row) if row is not None else None def start_watch_party_worker(session_id: int) -> dict[str, Any]: session_payload = get_watch_party_session(session_id) session = session_payload["session"] worker = worker_create_session( session_id=session_id, guild_id=str(session["guildId"]), voice_channel_id=str(session["voiceChannelId"]), text_channel_id=str(session["textChannelId"]), ) update_watch_party_status( session_id, "connecting", worker_session_id=str(worker.get("workerSessionId", session_id)), current_queue_entry_id=session["currentQueueEntryId"], ) result = update_worker_state( session_id, worker_status=str(worker.get("workerStatus", "idle")), playback_state=str(worker.get("playbackState", "idle")), current_title=str(worker.get("currentTitle", "")), position_seconds=int(worker.get("positionSeconds", 0) or 0), duration_seconds=int(worker.get("durationSeconds", 0) or 0), last_error=str(worker.get("lastError", "")), ) return {"worker": worker, **result} def play_next_watch_party_item(runtime: BotRuntime, session_id: int) -> dict[str, Any]: queue_entry = first_queued_entry(session_id) if queue_entry is None: raise ValueError("No queued items available") media_item = catalog_item_for_source_id(runtime, str(queue_entry["jellyfinSourceId"])) if media_item is None: raise ValueError(f"Queued media item no longer exists in the saved library: {queue_entry['jellyfinSourceId']}") playback = resolve_jellyfin_playback_source(runtime, media_item) worker = worker_play( session_id=session_id, queue_entry_id=int(queue_entry["id"]), jellyfin_source_id=str(queue_entry["jellyfinSourceId"]), title=str(queue_entry["title"]), media_type=str(queue_entry["mediaType"]), playback=playback, ) update_queue_entry_status(session_id, int(queue_entry["id"]), "playing") update_watch_party_status( session_id, "playing", current_queue_entry_id=int(queue_entry["id"]), ) result = update_worker_state( session_id, worker_status=str(worker.get("workerStatus", "idle")), playback_state=str(worker.get("playbackState", "idle")), current_title=str(worker.get("currentTitle", "")), position_seconds=int(worker.get("positionSeconds", 0) or 0), duration_seconds=int(worker.get("durationSeconds", 0) or 0), last_error=str(worker.get("lastError", "")), ) return {"worker": worker, "playback": playback, **result} def control_watch_party_worker(session_id: int, action: str, data: dict[str, Any] | None = None) -> dict[str, Any]: worker = worker_control(session_id=session_id, action=action, data=data or {}) if action == "pause": update_watch_party_status(session_id, "paused") elif action == "resume": update_watch_party_status(session_id, "playing") elif action == "stop": update_watch_party_status(session_id, "stopped") result = update_worker_state( session_id, worker_status=str(worker.get("workerStatus", "idle")), playback_state=str(worker.get("playbackState", "idle")), current_title=str(worker.get("currentTitle", "")), position_seconds=int(worker.get("positionSeconds", 0) or 0), duration_seconds=int(worker.get("durationSeconds", 0) or 0), last_error=str(worker.get("lastError", "")), ) return {"worker": worker, **result} def refresh_watch_party_worker(session_id: int) -> dict[str, Any]: worker = worker_status(session_id) result = update_worker_state( session_id, worker_status=str(worker.get("workerStatus", "idle")), playback_state=str(worker.get("playbackState", "idle")), current_title=str(worker.get("currentTitle", "")), position_seconds=int(worker.get("positionSeconds", 0) or 0), duration_seconds=int(worker.get("durationSeconds", 0) or 0), last_error=str(worker.get("lastError", "")), ) return {"worker": worker, **result}