TheOrb/archive_bot/watchparty.py

730 lines
27 KiB
Python

from __future__ import annotations
import json
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any
from .core import BotRuntime, MediaItem
from .db import connect_db, initialize_database
# Status values accepted for a watch party session; membership is checked by
# normalize_watch_party_status.
WATCH_PARTY_STATUSES = {
    "draft",
    "queued",
    "connecting",
    "playing",
    "paused",
    "stopped",
    "error",
}

# Status values accepted for a single queue entry; membership is checked by
# normalize_queue_status.
QUEUE_STATUSES = {
    "queued",
    "playing",
    "played",
    "skipped",
    "failed",
}
@dataclass(frozen=True)
class WatchPartySession:
    """Immutable in-memory view of one ``watch_party_sessions`` row."""

    # Row primary key.
    id: int
    # Identifiers stored as text in the sessions table.
    guild_id: str
    voice_channel_id: str
    text_channel_id: str
    owner_user_id: str
    # Human-readable session title.
    title: str
    # Session status string as stored in the database.
    status: str
    # Worker-side session identifier; the table defaults it to ''.
    worker_session_id: str
    # Queue entry currently referenced by the session, or None when unset.
    current_queue_entry_id: int | None
    # Timestamps stored as text.
    created_at: str
    updated_at: str
def initialize_watchparty_schema() -> None:
    """Create the watch party tables when they do not exist yet.

    Calls ``initialize_database()`` first, then issues one
    ``CREATE TABLE IF NOT EXISTS`` per table (sessions, queue, events,
    worker state) on a single connection and commits.
    """
    # DDL is executed in this order; the three child tables reference
    # watch_party_sessions(id) with ON DELETE CASCADE.
    ddl_statements = (
        """
        CREATE TABLE IF NOT EXISTS watch_party_sessions (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        guild_id TEXT NOT NULL,
        voice_channel_id TEXT NOT NULL,
        text_channel_id TEXT NOT NULL,
        owner_user_id TEXT NOT NULL,
        title TEXT NOT NULL,
        status TEXT NOT NULL,
        worker_session_id TEXT NOT NULL DEFAULT '',
        current_queue_entry_id INTEGER,
        created_at TEXT NOT NULL,
        updated_at TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS watch_party_queue (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        session_id INTEGER NOT NULL,
        position INTEGER NOT NULL,
        media_type TEXT NOT NULL,
        jellyfin_source_id TEXT NOT NULL,
        title TEXT NOT NULL,
        year TEXT NOT NULL DEFAULT '',
        runtime TEXT NOT NULL DEFAULT '',
        summary TEXT NOT NULL DEFAULT '',
        poster_url TEXT NOT NULL DEFAULT '',
        status TEXT NOT NULL,
        created_at TEXT NOT NULL,
        FOREIGN KEY(session_id) REFERENCES watch_party_sessions(id) ON DELETE CASCADE
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS watch_party_events (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        session_id INTEGER NOT NULL,
        event_type TEXT NOT NULL,
        payload_json TEXT NOT NULL,
        created_at TEXT NOT NULL,
        FOREIGN KEY(session_id) REFERENCES watch_party_sessions(id) ON DELETE CASCADE
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS watch_party_worker_state (
        session_id INTEGER PRIMARY KEY,
        worker_status TEXT NOT NULL DEFAULT 'idle',
        playback_state TEXT NOT NULL DEFAULT 'idle',
        current_title TEXT NOT NULL DEFAULT '',
        position_seconds INTEGER NOT NULL DEFAULT 0,
        duration_seconds INTEGER NOT NULL DEFAULT 0,
        last_error TEXT NOT NULL DEFAULT '',
        updated_at TEXT NOT NULL,
        FOREIGN KEY(session_id) REFERENCES watch_party_sessions(id) ON DELETE CASCADE
        )
        """,
    )
    initialize_database()
    with connect_db() as db:
        for ddl in ddl_statements:
            db.execute(ddl)
        db.commit()
def utc_now() -> str:
    """Return the current UTC time as a timezone-aware ISO 8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
def normalize_watch_party_status(status: str) -> str:
    """Return *status* trimmed and lower-cased after validating it.

    Raises:
        ValueError: If the normalised value is not in WATCH_PARTY_STATUSES.
    """
    candidate = status.strip().lower()
    if candidate in WATCH_PARTY_STATUSES:
        return candidate
    # The message echoes the caller's original, un-normalised input.
    raise ValueError(f"Invalid watch party status: {status}")
def normalize_queue_status(status: str) -> str:
    """Return *status* trimmed and lower-cased after validating it.

    Raises:
        ValueError: If the normalised value is not in QUEUE_STATUSES.
    """
    candidate = status.strip().lower()
    if candidate in QUEUE_STATUSES:
        return candidate
    # The message echoes the caller's original, un-normalised input.
    raise ValueError(f"Invalid queue status: {status}")
def watch_party_session_from_row(row: Any) -> WatchPartySession:
    """Build a WatchPartySession from a row indexable by column name.

    Every value is coerced to the dataclass field type. A falsy
    ``worker_session_id`` becomes ``""`` and a NULL
    ``current_queue_entry_id`` stays ``None``.
    """
    plain_text_columns = (
        "guild_id",
        "voice_channel_id",
        "text_channel_id",
        "owner_user_id",
        "title",
        "status",
    )
    fields: dict[str, Any] = {"id": int(row["id"])}
    for column in plain_text_columns:
        fields[column] = str(row[column])
    fields["worker_session_id"] = str(row["worker_session_id"] or "")
    entry_id = row["current_queue_entry_id"]
    fields["current_queue_entry_id"] = None if entry_id is None else int(entry_id)
    fields["created_at"] = str(row["created_at"])
    fields["updated_at"] = str(row["updated_at"])
    return WatchPartySession(**fields)
def session_to_jsonable(session: WatchPartySession) -> dict[str, Any]:
    """Return the session as a dict with camelCase keys for JSON output."""
    # (output key, attribute on the session) in output order.
    key_to_attribute = (
        ("id", "id"),
        ("guildId", "guild_id"),
        ("voiceChannelId", "voice_channel_id"),
        ("textChannelId", "text_channel_id"),
        ("ownerUserId", "owner_user_id"),
        ("title", "title"),
        ("status", "status"),
        ("workerSessionId", "worker_session_id"),
        ("currentQueueEntryId", "current_queue_entry_id"),
        ("createdAt", "created_at"),
        ("updatedAt", "updated_at"),
    )
    return {key: getattr(session, attribute) for key, attribute in key_to_attribute}
def queue_entry_to_jsonable(row: Any) -> dict[str, Any]:
    """Convert a ``watch_party_queue`` row into a camelCase JSON-ready dict.

    Integer columns are coerced with ``int`` and all others with ``str``.
    """
    integer_columns = (
        ("id", "id"),
        ("sessionId", "session_id"),
        ("position", "position"),
    )
    text_columns = (
        ("mediaType", "media_type"),
        ("jellyfinSourceId", "jellyfin_source_id"),
        ("title", "title"),
        ("year", "year"),
        ("runtime", "runtime"),
        ("summary", "summary"),
        ("posterUrl", "poster_url"),
        ("status", "status"),
        ("createdAt", "created_at"),
    )
    entry: dict[str, Any] = {}
    for key, column in integer_columns:
        entry[key] = int(row[column])
    for key, column in text_columns:
        entry[key] = str(row[column])
    return entry
def worker_state_to_jsonable(row: Any) -> dict[str, Any]:
    """Convert a ``watch_party_worker_state`` row into a camelCase dict.

    Args:
        row: Row (or mapping) indexable by column name.

    Returns:
        A JSON-ready dict. ``lastError`` is ``""`` when the stored value is
        empty or NULL.
    """
    return {
        "sessionId": int(row["session_id"]),
        "workerStatus": str(row["worker_status"]),
        "playbackState": str(row["playback_state"]),
        "currentTitle": str(row["current_title"]),
        "positionSeconds": int(row["position_seconds"]),
        "durationSeconds": int(row["duration_seconds"]),
        # Coalesce BEFORE str(): str(None) is the truthy text "None", so a
        # trailing `or ""` applied after the conversion can never take effect.
        "lastError": str(row["last_error"] or ""),
        "updatedAt": str(row["updated_at"]),
    }
def log_watch_party_event(connection: Any, session_id: int, event_type: str, payload: dict[str, Any]) -> None:
    """Insert one row into ``watch_party_events`` on *connection*.

    The payload is stored as JSON with sorted keys. This function does not
    commit; the caller owns the transaction.
    """
    serialized_payload = json.dumps(payload, sort_keys=True)
    event_row = (session_id, event_type, serialized_payload, utc_now())
    connection.execute(
        """
        INSERT INTO watch_party_events(session_id, event_type, payload_json, created_at)
        VALUES (?, ?, ?, ?)
        """,
        event_row,
    )
def list_media_candidates(runtime: BotRuntime, *, media_type: str = "all", search: str = "", limit: int = 25) -> list[dict[str, Any]]:
    """Return card payloads for library items matching the given filters.

    Args:
        runtime: Passed through to ``load_media_library``.
        media_type: ``"movie"``, ``"show"`` or ``"all"``; any other value is
            treated like ``"all"``.
        search: Case-insensitive substring matched against title, year,
            genres and summary. Blank means no filtering.
        limit: Maximum number of results, clamped to the range 1..100.
    """
    from .media import load_media_library, media_item_card_payload

    movies, shows = load_media_library(runtime)
    everything = [*movies, *shows]
    pools = {"movie": movies, "show": shows, "all": everything}
    candidates = pools.get(media_type, everything)

    needle = search.strip().casefold()
    if needle:

        def matches(item: Any) -> bool:
            haystack = " ".join([item.title, item.year or "", item.genres or "", item.summary or ""])
            return needle in haystack.casefold()

        candidates = [item for item in candidates if matches(item)]

    cap = max(1, min(limit, 100))
    return [media_item_card_payload(item) for item in candidates[:cap]]
def find_media_item_by_source_id(runtime: BotRuntime, source_id: str) -> MediaItem | None:
    """Return the first library item whose ``source_id`` equals *source_id*.

    Movies are searched before shows. Items with a falsy ``source_id`` never
    match. Returns ``None`` when nothing matches.
    """
    from .media import load_media_library

    movies, shows = load_media_library(runtime)
    matching = (
        candidate
        for candidate in [*movies, *shows]
        if candidate.source_id and candidate.source_id == source_id
    )
    return next(matching, None)
def create_watch_party_session(
    runtime: BotRuntime,
    *,
    guild_id: str,
    voice_channel_id: str,
    text_channel_id: str,
    owner_user_id: str,
    title: str,
) -> dict[str, Any]:
    """Insert a new ``draft`` session plus its idle worker-state row.

    All text arguments are stored stripped. A ``session.created`` event is
    logged in the same transaction.

    Returns:
        The full session payload from ``get_watch_party_session``.

    Raises:
        ValueError: If any of the four identifiers is blank.
    """
    del runtime  # unused here
    initialize_watchparty_schema()

    # Checked in this order so the first blank field decides the error.
    required_fields = (
        (guild_id, "guildId"),
        (voice_channel_id, "voiceChannelId"),
        (text_channel_id, "textChannelId"),
        (owner_user_id, "ownerUserId"),
    )
    for raw_value, field_name in required_fields:
        if not raw_value.strip():
            raise ValueError(f"{field_name} is required")

    timestamp = utc_now()
    with connect_db() as db:
        session_values = (
            guild_id.strip(),
            voice_channel_id.strip(),
            text_channel_id.strip(),
            owner_user_id.strip(),
            title.strip(),
            timestamp,
            timestamp,
        )
        inserted = db.execute(
            """
            INSERT INTO watch_party_sessions(
            guild_id, voice_channel_id, text_channel_id, owner_user_id,
            title, status, worker_session_id, current_queue_entry_id, created_at, updated_at
            )
            VALUES (?, ?, ?, ?, ?, 'draft', '', NULL, ?, ?)
            """,
            session_values,
        )
        new_session_id = int(inserted.lastrowid)
        db.execute(
            """
            INSERT INTO watch_party_worker_state(
            session_id, worker_status, playback_state, current_title, position_seconds,
            duration_seconds, last_error, updated_at
            )
            VALUES (?, 'idle', 'idle', '', 0, 0, '', ?)
            """,
            (new_session_id, timestamp),
        )
        creation_details = {"title": title.strip(), "ownerUserId": owner_user_id.strip()}
        log_watch_party_event(db, new_session_id, "session.created", creation_details)
        db.commit()
    return get_watch_party_session(new_session_id)
def list_watch_party_sessions() -> list[dict[str, Any]]:
    """Return every session as a JSON-ready dict, most recently updated first."""
    initialize_watchparty_schema()
    with connect_db() as db:
        cursor = db.execute(
            "SELECT * FROM watch_party_sessions ORDER BY updated_at DESC, id DESC"
        )
        records = cursor.fetchall()
    sessions: list[dict[str, Any]] = []
    for record in records:
        sessions.append(session_to_jsonable(watch_party_session_from_row(record)))
    return sessions
def find_watch_party_session(
    *,
    guild_id: str,
    voice_channel_id: str | None = None,
    allowed_statuses: set[str] | None = None,
) -> dict[str, Any] | None:
    """Return the most recently updated session matching the filters.

    Args:
        guild_id: Required; compared after stripping.
        voice_channel_id: Optional extra equality filter.
        allowed_statuses: Optional status whitelist. ``None`` or an empty set
            applies no status filter.

    Returns:
        The session as a JSON-ready dict, or ``None`` when nothing matches.
    """
    initialize_watchparty_schema()

    conditions = ["guild_id = ?"]
    parameters: list[Any] = [guild_id.strip()]
    if voice_channel_id is not None:
        conditions.append("voice_channel_id = ?")
        parameters.append(voice_channel_id.strip())
    if allowed_statuses:
        # Only "?" markers are interpolated into the SQL text; the status
        # values themselves are bound as parameters.
        markers = ", ".join("?" * len(allowed_statuses))
        conditions.append(f"status IN ({markers})")
        parameters.extend(sorted(allowed_statuses))

    where_clause = " AND ".join(conditions)
    sql = f"""
        SELECT *
        FROM watch_party_sessions
        WHERE {where_clause}
        ORDER BY updated_at DESC, id DESC
        LIMIT 1
        """
    with connect_db() as db:
        match = db.execute(sql, tuple(parameters)).fetchone()
    if match is not None:
        return session_to_jsonable(watch_party_session_from_row(match))
    return None
def get_watch_party_session(session_id: int) -> dict[str, Any]:
    """Load one session with its queue, worker state and recent events.

    Returns:
        A dict with keys ``session``, ``queue`` (ordered by position, then
        id), ``worker`` (``None`` when no worker-state row exists) and
        ``events`` (the 25 newest, newest first, with decoded payloads).

    Raises:
        ValueError: If no session has the given id.
    """
    initialize_watchparty_schema()
    key = (session_id,)
    with connect_db() as db:
        session_row = db.execute(
            "SELECT * FROM watch_party_sessions WHERE id = ?",
            key,
        ).fetchone()
        if session_row is None:
            raise ValueError(f"Watch party session not found: {session_id}")
        queue_rows = db.execute(
            "SELECT * FROM watch_party_queue WHERE session_id = ? ORDER BY position ASC, id ASC",
            key,
        ).fetchall()
        worker_row = db.execute(
            "SELECT * FROM watch_party_worker_state WHERE session_id = ?",
            key,
        ).fetchone()
        event_rows = db.execute(
            "SELECT id, event_type, payload_json, created_at FROM watch_party_events WHERE session_id = ? ORDER BY id DESC LIMIT 25",
            key,
        ).fetchall()

    session_payload = session_to_jsonable(watch_party_session_from_row(session_row))
    queue_payload = [queue_entry_to_jsonable(queue_row) for queue_row in queue_rows]
    worker_payload = None if worker_row is None else worker_state_to_jsonable(worker_row)
    events_payload = []
    for event_row in event_rows:
        events_payload.append(
            {
                "id": int(event_row["id"]),
                "eventType": str(event_row["event_type"]),
                "payload": json.loads(str(event_row["payload_json"])),
                "createdAt": str(event_row["created_at"]),
            }
        )
    return {
        "session": session_payload,
        "queue": queue_payload,
        "worker": worker_payload,
        "events": events_payload,
    }
def next_queue_position(connection: Any, session_id: int) -> int:
    """Return one more than the highest queue position for the session.

    An empty queue yields ``1`` because the query coalesces the maximum to 0.
    """
    sql = "SELECT COALESCE(MAX(position), 0) AS position FROM watch_party_queue WHERE session_id = ?"
    highest = connection.execute(sql, (session_id,)).fetchone()["position"]
    return 1 + int(highest)
def add_watch_party_queue_item(runtime: BotRuntime, session_id: int, jellyfin_source_id: str) -> dict[str, Any]:
    """Append a library item to the end of a session's queue.

    The item's metadata is copied into the queue row, a ``queue.added`` event
    is logged, and the session's ``updated_at`` is bumped. The session status
    is set to ``queued`` unless it is currently ``connecting``, ``playing``
    or ``paused``; those statuses are left untouched.

    Args:
        runtime: Passed to the media library lookup.
        session_id: Target session id.
        jellyfin_source_id: Source id of the media item; stripped before use.

    Returns:
        The full session payload from ``get_watch_party_session``.

    Raises:
        ValueError: If the source id is blank, the item is not in the saved
            library, or the session does not exist.
    """
    from .media import media_item_card_payload

    initialize_watchparty_schema()
    source_id = jellyfin_source_id.strip()
    if not source_id:
        raise ValueError("jellyfinSourceId is required")
    item = find_media_item_by_source_id(runtime, source_id)
    if item is None:
        raise ValueError(f"Media item not found in saved library: {jellyfin_source_id}")

    # Statuses that describe an active worker. Queueing more items during
    # playback must not overwrite them, mirroring clear_watch_party_queue,
    # which only changes the status of sessions that are still 'queued'.
    active_statuses = ("connecting", "playing", "paused")

    with connect_db() as connection:
        session_row = connection.execute(
            "SELECT * FROM watch_party_sessions WHERE id = ?",
            (session_id,),
        ).fetchone()
        if session_row is None:
            raise ValueError(f"Watch party session not found: {session_id}")
        position = next_queue_position(connection, session_id)
        now = utc_now()
        cursor = connection.execute(
            """
            INSERT INTO watch_party_queue(
            session_id, position, media_type, jellyfin_source_id, title,
            year, runtime, summary, poster_url, status, created_at
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 'queued', ?)
            """,
            (
                session_id,
                position,
                item.media_type,
                item.source_id or "",
                item.title,
                item.year or "",
                item.runtime or "",
                item.summary or "",
                media_item_card_payload(item).get("posterUrl", ""),
                now,
            ),
        )
        queue_id = int(cursor.lastrowid)
        if str(session_row["status"]) in active_statuses:
            connection.execute(
                "UPDATE watch_party_sessions SET updated_at = ? WHERE id = ?",
                (now, session_id),
            )
        else:
            connection.execute(
                "UPDATE watch_party_sessions SET status = ?, updated_at = ? WHERE id = ?",
                ("queued", now, session_id),
            )
        log_watch_party_event(
            connection,
            session_id,
            "queue.added",
            {"queueEntryId": queue_id, "jellyfinSourceId": source_id, "title": item.title},
        )
        connection.commit()
    return get_watch_party_session(session_id)
def update_watch_party_status(session_id: int, status: str, *, worker_session_id: str | None = None, current_queue_entry_id: int | None = None) -> dict[str, Any]:
    """Set a session's status and optionally its worker/queue references.

    Args:
        session_id: Session to update.
        status: New status; validated by ``normalize_watch_party_status``.
        worker_session_id: Written (stripped) only when not ``None``.
        current_queue_entry_id: Written only when not ``None``, so this
            function cannot reset the column to NULL.

    Returns:
        The full session payload from ``get_watch_party_session``.

    Raises:
        ValueError: If the status is invalid or the session does not exist.
    """
    initialize_watchparty_schema()
    clean_status = normalize_watch_party_status(status)

    # Column -> new value, in the order the SET clause lists them.
    assignments: dict[str, Any] = {"status": clean_status, "updated_at": utc_now()}
    if worker_session_id is not None:
        assignments["worker_session_id"] = worker_session_id.strip()
    if current_queue_entry_id is not None:
        assignments["current_queue_entry_id"] = current_queue_entry_id
    set_clause = ", ".join(f"{column} = ?" for column in assignments)
    parameters = (*assignments.values(), session_id)

    with connect_db() as db:
        existing = db.execute("SELECT id FROM watch_party_sessions WHERE id = ?", (session_id,)).fetchone()
        if existing is None:
            raise ValueError(f"Watch party session not found: {session_id}")
        db.execute(
            f"UPDATE watch_party_sessions SET {set_clause} WHERE id = ?",
            parameters,
        )
        event_details = {
            "status": clean_status,
            "workerSessionId": worker_session_id or "",
            "currentQueueEntryId": current_queue_entry_id,
        }
        log_watch_party_event(db, session_id, "session.status", event_details)
        db.commit()
    return get_watch_party_session(session_id)
def update_queue_entry_status(session_id: int, queue_entry_id: int, status: str) -> dict[str, Any]:
    """Change the status of one queue entry belonging to the session.

    Returns:
        The full session payload from ``get_watch_party_session``.

    Raises:
        ValueError: If the status is invalid, or no entry with that id
            exists in the given session.
    """
    initialize_watchparty_schema()
    clean_status = normalize_queue_status(status)
    with connect_db() as db:
        # Scoping the lookup by session_id rejects entries of other sessions.
        owned_entry = db.execute(
            "SELECT id FROM watch_party_queue WHERE id = ? AND session_id = ?",
            (queue_entry_id, session_id),
        ).fetchone()
        if owned_entry is None:
            raise ValueError(f"Queue entry not found: {queue_entry_id}")
        db.execute(
            "UPDATE watch_party_queue SET status = ? WHERE id = ?",
            (clean_status, queue_entry_id),
        )
        event_details = {"queueEntryId": queue_entry_id, "status": clean_status}
        log_watch_party_event(db, session_id, "queue.status", event_details)
        db.commit()
    return get_watch_party_session(session_id)
def update_worker_state(
    session_id: int,
    *,
    worker_status: str,
    playback_state: str,
    current_title: str = "",
    position_seconds: int = 0,
    duration_seconds: int = 0,
    last_error: str = "",
) -> dict[str, Any]:
    """Insert or replace the worker-state row for a session.

    Text values are stripped and negative second counts are clamped to 0.
    The same normalised values are recorded in a ``worker.state`` event.

    Returns:
        The full session payload from ``get_watch_party_session``.

    Raises:
        ValueError: If the session does not exist.
    """
    initialize_watchparty_schema()
    with connect_db() as db:
        existing = db.execute("SELECT id FROM watch_party_sessions WHERE id = ?", (session_id,)).fetchone()
        if existing is None:
            raise ValueError(f"Watch party session not found: {session_id}")

        # Built once; the key order matches the column order of the INSERT
        # below, so the values can be bound positionally.
        state = {
            "workerStatus": worker_status.strip(),
            "playbackState": playback_state.strip(),
            "currentTitle": current_title.strip(),
            "positionSeconds": max(0, position_seconds),
            "durationSeconds": max(0, duration_seconds),
            "lastError": last_error.strip(),
        }
        db.execute(
            """
            INSERT INTO watch_party_worker_state(
            session_id, worker_status, playback_state, current_title, position_seconds,
            duration_seconds, last_error, updated_at
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(session_id) DO UPDATE SET
            worker_status = excluded.worker_status,
            playback_state = excluded.playback_state,
            current_title = excluded.current_title,
            position_seconds = excluded.position_seconds,
            duration_seconds = excluded.duration_seconds,
            last_error = excluded.last_error,
            updated_at = excluded.updated_at
            """,
            (session_id, *state.values(), utc_now()),
        )
        log_watch_party_event(db, session_id, "worker.state", state)
        db.commit()
    return get_watch_party_session(session_id)
def worker_command_payload(session_id: int, action: str, data: dict[str, Any]) -> dict[str, Any]:
    """Return the session payload extended with a ``workerCommand`` entry."""
    command = {"action": action, "sessionId": session_id, "data": data}
    envelope = get_watch_party_session(session_id)
    envelope["workerCommand"] = command
    return envelope
def first_queued_entry(session_id: int) -> dict[str, Any] | None:
    """Return the lowest-position entry still marked ``queued``, or ``None``."""
    initialize_watchparty_schema()
    with connect_db() as db:
        head = db.execute(
            """
            SELECT * FROM watch_party_queue
            WHERE session_id = ? AND status = 'queued'
            ORDER BY position ASC, id ASC
            LIMIT 1
            """,
            (session_id,),
        ).fetchone()
    if head is None:
        return None
    return queue_entry_to_jsonable(head)
def start_watch_party_worker(session_id: int) -> dict[str, Any]:
    """Create a worker session and mirror its reported state locally.

    The session is moved to ``connecting`` and the worker's reply is written
    to the worker-state table, with ``idle``/empty/0 used for missing fields.

    Returns:
        The refreshed session payload.
    """
    from .worker_client import worker_create_session

    snapshot = get_watch_party_session(session_id)["session"]
    worker_reply = worker_create_session(
        session_id=session_id,
        guild_id=str(snapshot["guildId"]),
        voice_channel_id=str(snapshot["voiceChannelId"]),
        text_channel_id=str(snapshot["textChannelId"]),
    )
    # Falls back to the local session id when the worker names none.
    assigned_worker_id = str(worker_reply.get("workerSessionId", session_id))
    update_watch_party_status(
        session_id,
        "connecting",
        worker_session_id=assigned_worker_id,
        current_queue_entry_id=snapshot["currentQueueEntryId"],
    )
    reported_state = {
        "worker_status": str(worker_reply.get("workerStatus", "idle")),
        "playback_state": str(worker_reply.get("playbackState", "idle")),
        "current_title": str(worker_reply.get("currentTitle", "")),
        "position_seconds": int(worker_reply.get("positionSeconds", 0) or 0),
        "duration_seconds": int(worker_reply.get("durationSeconds", 0) or 0),
        "last_error": str(worker_reply.get("lastError", "")),
    }
    refreshed = update_worker_state(session_id, **reported_state)
    # The unpacked session payload comes last, so any "worker" key it carries
    # replaces the raw worker reply placed first.
    return {"worker": worker_reply, **refreshed}
def play_next_watch_party_item(runtime: BotRuntime, session_id: int) -> dict[str, Any]:
    """Start playback of the first queued entry of a session.

    The worker is asked to play first; only afterwards are the queue entry
    and session marked ``playing`` and the worker state stored.

    Returns:
        The refreshed session payload plus a ``playback`` entry holding the
        resolved playback source.

    Raises:
        ValueError: If nothing is queued, or the queued item can no longer
            be found in the catalog.
    """
    from .media import catalog_item_for_source_id, resolve_jellyfin_playback_source
    from .worker_client import worker_play

    upcoming = first_queued_entry(session_id)
    if upcoming is None:
        raise ValueError("No queued items available")
    media_item = catalog_item_for_source_id(runtime, str(upcoming["jellyfinSourceId"]))
    if media_item is None:
        raise ValueError(f"Queued media item no longer exists in the saved library: {upcoming['jellyfinSourceId']}")

    playback = resolve_jellyfin_playback_source(runtime, media_item)
    worker_reply = worker_play(
        session_id=session_id,
        queue_entry_id=int(upcoming["id"]),
        jellyfin_source_id=str(upcoming["jellyfinSourceId"]),
        title=str(upcoming["title"]),
        media_type=str(upcoming["mediaType"]),
        playback=playback,
    )

    update_queue_entry_status(session_id, int(upcoming["id"]), "playing")
    update_watch_party_status(
        session_id,
        "playing",
        current_queue_entry_id=int(upcoming["id"]),
    )
    reported_state = {
        "worker_status": str(worker_reply.get("workerStatus", "idle")),
        "playback_state": str(worker_reply.get("playbackState", "idle")),
        "current_title": str(worker_reply.get("currentTitle", "")),
        "position_seconds": int(worker_reply.get("positionSeconds", 0) or 0),
        "duration_seconds": int(worker_reply.get("durationSeconds", 0) or 0),
        "last_error": str(worker_reply.get("lastError", "")),
    }
    refreshed = update_worker_state(session_id, **reported_state)
    # The unpacked session payload comes last, so any "worker" or "playback"
    # key it carries replaces the values placed before it.
    return {"worker": worker_reply, "playback": playback, **refreshed}
def control_watch_party_worker(session_id: int, action: str, data: dict[str, Any] | None = None) -> dict[str, Any]:
    """Forward a control action to the worker and record the outcome.

    ``pause``, ``resume`` and ``stop`` also move the session to ``paused``,
    ``playing`` and ``stopped``; any other action leaves the session status
    alone. The worker's reply is always written to the worker-state table.

    Returns:
        The refreshed session payload.
    """
    from .worker_client import worker_control

    worker_reply = worker_control(session_id=session_id, action=action, data=data or {})

    # (action, resulting session status)
    status_transitions = (
        ("pause", "paused"),
        ("resume", "playing"),
        ("stop", "stopped"),
    )
    for trigger, resulting_status in status_transitions:
        if action == trigger:
            update_watch_party_status(session_id, resulting_status)
            break

    reported_state = {
        "worker_status": str(worker_reply.get("workerStatus", "idle")),
        "playback_state": str(worker_reply.get("playbackState", "idle")),
        "current_title": str(worker_reply.get("currentTitle", "")),
        "position_seconds": int(worker_reply.get("positionSeconds", 0) or 0),
        "duration_seconds": int(worker_reply.get("durationSeconds", 0) or 0),
        "last_error": str(worker_reply.get("lastError", "")),
    }
    refreshed = update_worker_state(session_id, **reported_state)
    # The unpacked session payload comes last, so any "worker" key it carries
    # replaces the raw worker reply placed first.
    return {"worker": worker_reply, **refreshed}
def refresh_watch_party_worker(session_id: int) -> dict[str, Any]:
    """Pull the worker's current status and store it for the session.

    Returns the stored payload unchanged when the session has no worker
    session id, or when the status call raises a ``RuntimeError`` whose text
    contains ``HTTP 404``. Other ``RuntimeError``s propagate.
    """
    stored_payload = get_watch_party_session(session_id)
    stored_session = stored_payload["session"]
    has_worker = bool(str(stored_session.get("workerSessionId", "")).strip())
    if not has_worker:
        return stored_payload

    # Imported only once a worker session is known to exist.
    from .worker_client import worker_status

    try:
        worker_reply = worker_status(session_id)
    except RuntimeError as exc:
        if "HTTP 404" not in str(exc):
            raise
        return stored_payload

    reported_state = {
        "worker_status": str(worker_reply.get("workerStatus", "idle")),
        "playback_state": str(worker_reply.get("playbackState", "idle")),
        "current_title": str(worker_reply.get("currentTitle", "")),
        "position_seconds": int(worker_reply.get("positionSeconds", 0) or 0),
        "duration_seconds": int(worker_reply.get("durationSeconds", 0) or 0),
        "last_error": str(worker_reply.get("lastError", "")),
    }
    refreshed = update_worker_state(session_id, **reported_state)
    # The unpacked session payload comes last, so any "worker" key it carries
    # replaces the raw worker reply placed first.
    return {"worker": worker_reply, **refreshed}
def clear_watch_party_queue(session_id: int) -> dict[str, Any]:
    """Delete all still-``queued`` entries of a session.

    Entries in any other status are kept. A session whose status is
    ``queued`` is moved back to ``draft``; other statuses are left as they
    are. A ``queue.cleared`` event is logged.

    Returns:
        The full session payload from ``get_watch_party_session``.

    Raises:
        ValueError: If the session does not exist.
    """
    initialize_watchparty_schema()
    with connect_db() as db:
        existing = db.execute("SELECT id FROM watch_party_sessions WHERE id = ?", (session_id,)).fetchone()
        if existing is None:
            raise ValueError(f"Watch party session not found: {session_id}")
        db.execute(
            "DELETE FROM watch_party_queue WHERE session_id = ? AND status = 'queued'",
            (session_id,),
        )
        status_row = db.execute(
            "SELECT status FROM watch_party_sessions WHERE id = ?", (session_id,)
        ).fetchone()
        session_status = "" if not status_row else status_row["status"]
        if session_status == "queued":
            db.execute(
                "UPDATE watch_party_sessions SET status = 'draft', updated_at = ? WHERE id = ?",
                (utc_now(), session_id),
            )
        log_watch_party_event(db, session_id, "queue.cleared", {})
        db.commit()
    return get_watch_party_session(session_id)
def end_watch_party_session(session_id: int) -> dict[str, Any]:
    """Stop a session and reset its stored worker state to idle.

    Notifying the worker is best effort: when the worker is enabled a
    ``stop`` is sent, and any failure is printed rather than raised, so the
    local state is still updated.

    Returns:
        The refreshed session payload.
    """
    from .worker_client import worker_control, worker_enabled

    initialize_watchparty_schema()
    if worker_enabled():
        try:
            worker_control(session_id=session_id, action="stop")
        except Exception as exc:
            # Deliberately broad: a worker failure must not block the
            # local shutdown below.
            print(f"Failed to notify worker of stop during session end: {exc}", flush=True)

    update_watch_party_status(session_id, "stopped")
    idle_state = {
        "worker_status": "idle",
        "playback_state": "idle",
        "current_title": "",
        "position_seconds": 0,
        "duration_seconds": 0,
        "last_error": "",
    }
    return update_worker_state(session_id, **idle_state)
def end_all_watch_party_sessions() -> dict[str, Any]:
    """Stop the workers of all non-stopped sessions, then delete every session.

    Worker notification is best effort and failures are only printed. The
    final DELETE has no WHERE clause, so already-stopped sessions are removed
    too and ``endedCount`` is the number of session rows deleted.
    NOTE(review): child rows are removed only if the connection enforces
    foreign keys (ON DELETE CASCADE) — confirm in ``connect_db``.

    Returns:
        ``{"ok": True, "endedCount": <deleted row count>}``.
    """
    from .worker_client import worker_control, worker_enabled

    initialize_watchparty_schema()
    with connect_db() as db:
        open_rows = db.execute(
            "SELECT id FROM watch_party_sessions WHERE status != 'stopped'"
        ).fetchall()
    open_session_ids = [int(open_row["id"]) for open_row in open_rows]

    for open_session_id in open_session_ids:
        if not worker_enabled():
            continue
        try:
            worker_control(session_id=open_session_id, action="stop")
        except Exception as exc:
            # Deliberately broad: one failing worker must not stop the sweep.
            print(f"Failed to notify worker of stop during session end_all: {exc}", flush=True)

    with connect_db() as db:
        removed = db.execute("DELETE FROM watch_party_sessions").rowcount
        db.commit()
    return {"ok": True, "endedCount": removed}