# TheOrb/archive_bot/watchparty.py
#
# 508 lines
# 18 KiB
# Python

from __future__ import annotations
import json
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any
from .core import BotRuntime, MediaItem
from .db import connect_db, initialize_database
from .media import load_media_library, media_item_card_payload
# Every lifecycle state a watch party session row may hold.
WATCH_PARTY_STATUSES = {"draft", "queued", "connecting", "playing", "paused", "stopped", "error"}
# Every state an individual watch_party_queue entry may hold.
QUEUE_STATUSES = {"queued", "playing", "played", "skipped", "failed"}
@dataclass(frozen=True)
class WatchPartySession:
    """Immutable snapshot of one ``watch_party_sessions`` row.

    Built by ``watch_party_session_from_row`` and serialized by
    ``session_to_jsonable``; the class itself only carries values.
    """

    # Row identity plus the guild/channel/owner the session belongs to.
    id: int
    guild_id: str
    voice_channel_id: str
    text_channel_id: str
    owner_user_id: str
    # Display title and lifecycle state (new sessions are inserted as 'draft').
    title: str
    status: str
    # Worker linkage: new sessions start with '' and NULL respectively.
    worker_session_id: str
    current_queue_entry_id: int | None
    # ISO-8601 timestamp strings.
    created_at: str
    updated_at: str
def initialize_watchparty_schema() -> None:
    """Create the watch-party tables if they are missing.

    Calls ``initialize_database()`` first. It then runs one
    ``CREATE TABLE IF NOT EXISTS`` statement per table (sessions, queue,
    events, worker state) on a single connection and commits once.
    Because each statement is ``IF NOT EXISTS``, repeated calls leave
    already-created watch-party tables untouched.
    """
    # NOTE(review): SQLite only enforces the FOREIGN KEY ... ON DELETE CASCADE
    # clauses below when PRAGMA foreign_keys=ON is set on the connection;
    # confirm connect_db() enables it.
    table_definitions = (
        """
        CREATE TABLE IF NOT EXISTS watch_party_sessions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            guild_id TEXT NOT NULL,
            voice_channel_id TEXT NOT NULL,
            text_channel_id TEXT NOT NULL,
            owner_user_id TEXT NOT NULL,
            title TEXT NOT NULL,
            status TEXT NOT NULL,
            worker_session_id TEXT NOT NULL DEFAULT '',
            current_queue_entry_id INTEGER,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS watch_party_queue (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id INTEGER NOT NULL,
            position INTEGER NOT NULL,
            media_type TEXT NOT NULL,
            jellyfin_source_id TEXT NOT NULL,
            title TEXT NOT NULL,
            year TEXT NOT NULL DEFAULT '',
            runtime TEXT NOT NULL DEFAULT '',
            summary TEXT NOT NULL DEFAULT '',
            poster_url TEXT NOT NULL DEFAULT '',
            status TEXT NOT NULL,
            created_at TEXT NOT NULL,
            FOREIGN KEY(session_id) REFERENCES watch_party_sessions(id) ON DELETE CASCADE
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS watch_party_events (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id INTEGER NOT NULL,
            event_type TEXT NOT NULL,
            payload_json TEXT NOT NULL,
            created_at TEXT NOT NULL,
            FOREIGN KEY(session_id) REFERENCES watch_party_sessions(id) ON DELETE CASCADE
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS watch_party_worker_state (
            session_id INTEGER PRIMARY KEY,
            worker_status TEXT NOT NULL DEFAULT 'idle',
            playback_state TEXT NOT NULL DEFAULT 'idle',
            current_title TEXT NOT NULL DEFAULT '',
            position_seconds INTEGER NOT NULL DEFAULT 0,
            duration_seconds INTEGER NOT NULL DEFAULT 0,
            last_error TEXT NOT NULL DEFAULT '',
            updated_at TEXT NOT NULL,
            FOREIGN KEY(session_id) REFERENCES watch_party_sessions(id) ON DELETE CASCADE
        )
        """,
    )
    initialize_database()
    with connect_db() as connection:
        for statement in table_definitions:
            connection.execute(statement)
        connection.commit()
def utc_now() -> str:
    """Return the current instant as a timezone-aware UTC ISO-8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
def _normalize_status(status: str, allowed: set[str], label: str) -> str:
    """Return *status* stripped and lower-cased, or raise if it is not in *allowed*.

    *label* fills the error message ("Invalid {label} status: ...").
    The message echoes the caller's original, un-normalized input.
    """
    candidate = status.strip().lower()
    if candidate in allowed:
        return candidate
    raise ValueError(f"Invalid {label} status: {status}")


def normalize_watch_party_status(status: str) -> str:
    """Validate a session status against WATCH_PARTY_STATUSES, ignoring case and surrounding whitespace.

    Raises:
        ValueError: if the cleaned value is not a known session status.
    """
    return _normalize_status(status, WATCH_PARTY_STATUSES, "watch party")


def normalize_queue_status(status: str) -> str:
    """Validate a queue-entry status against QUEUE_STATUSES, ignoring case and surrounding whitespace.

    Raises:
        ValueError: if the cleaned value is not a known queue status.
    """
    return _normalize_status(status, QUEUE_STATUSES, "queue")
def watch_party_session_from_row(row: Any) -> WatchPartySession:
    """Convert a ``watch_party_sessions`` row into a WatchPartySession.

    *row* is read by column name (presumably ``sqlite3.Row`` from
    connect_db — confirm). Text columns are coerced with ``str``. A NULL
    ``worker_session_id`` becomes ``""``; a NULL ``current_queue_entry_id``
    stays ``None``.
    """
    raw_entry_id = row["current_queue_entry_id"]
    return WatchPartySession(
        id=int(row["id"]),
        guild_id=str(row["guild_id"]),
        voice_channel_id=str(row["voice_channel_id"]),
        text_channel_id=str(row["text_channel_id"]),
        owner_user_id=str(row["owner_user_id"]),
        title=str(row["title"]),
        status=str(row["status"]),
        worker_session_id=str(row["worker_session_id"] or ""),
        current_queue_entry_id=None if raw_entry_id is None else int(raw_entry_id),
        created_at=str(row["created_at"]),
        updated_at=str(row["updated_at"]),
    )
def session_to_jsonable(session: WatchPartySession) -> dict[str, Any]:
    """Serialize a session into a camelCase dict suitable for JSON responses.

    Attribute values are passed through unchanged; only the key names are
    translated from snake_case to camelCase.
    """
    key_to_attribute = (
        ("id", "id"),
        ("guildId", "guild_id"),
        ("voiceChannelId", "voice_channel_id"),
        ("textChannelId", "text_channel_id"),
        ("ownerUserId", "owner_user_id"),
        ("title", "title"),
        ("status", "status"),
        ("workerSessionId", "worker_session_id"),
        ("currentQueueEntryId", "current_queue_entry_id"),
        ("createdAt", "created_at"),
        ("updatedAt", "updated_at"),
    )
    return {key: getattr(session, attribute) for key, attribute in key_to_attribute}
def queue_entry_to_jsonable(row: Any) -> dict[str, Any]:
    """Serialize a ``watch_party_queue`` row into a camelCase dict.

    ``id``, ``session_id`` and ``position`` are coerced with ``int``; every
    other column is coerced with ``str``.
    """
    layout = (
        ("id", "id", int),
        ("sessionId", "session_id", int),
        ("position", "position", int),
        ("mediaType", "media_type", str),
        ("jellyfinSourceId", "jellyfin_source_id", str),
        ("title", "title", str),
        ("year", "year", str),
        ("runtime", "runtime", str),
        ("summary", "summary", str),
        ("posterUrl", "poster_url", str),
        ("status", "status", str),
        ("createdAt", "created_at", str),
    )
    return {key: cast(row[column]) for key, column, cast in layout}
def worker_state_to_jsonable(row: Any) -> dict[str, Any]:
    """Serialize a ``watch_party_worker_state`` row into a camelCase dict.

    Position and duration are coerced with ``int``, the remaining columns
    with ``str``. A NULL ``last_error`` is reported as ``""``.
    """
    return {
        "sessionId": int(row["session_id"]),
        "workerStatus": str(row["worker_status"]),
        "playbackState": str(row["playback_state"]),
        "currentTitle": str(row["current_title"]),
        "positionSeconds": int(row["position_seconds"]),
        "durationSeconds": int(row["duration_seconds"]),
        # Coerce before str(): str(None) is the truthy string "None", so an
        # `or ""` applied after str() would never take effect.
        "lastError": str(row["last_error"] or ""),
        "updatedAt": str(row["updated_at"]),
    }
def log_watch_party_event(connection: Any, session_id: int, event_type: str, payload: dict[str, Any]) -> None:
    """Insert one ``watch_party_events`` row on *connection*.

    *payload* is stored as JSON with sorted keys and stamped with ``utc_now()``.
    This function does not commit; the caller's transaction owns the write.
    """
    encoded_payload = json.dumps(payload, sort_keys=True)
    timestamp = utc_now()
    connection.execute(
        "INSERT INTO watch_party_events(session_id, event_type, payload_json, created_at) VALUES (?, ?, ?, ?)",
        (session_id, event_type, encoded_payload, timestamp),
    )
def list_media_candidates(runtime: BotRuntime, *, media_type: str = "all", search: str = "", limit: int = 25) -> list[dict[str, Any]]:
    """Return card payloads for saved-library items, optionally filtered.

    Args:
        runtime: passed to ``load_media_library``.
        media_type: ``"movie"`` or ``"show"`` restricts the pool. Any other
            value, including the default ``"all"``, uses movies then shows.
        search: case-insensitive substring matched against title, year,
            genres and summary joined with spaces. Blank means no filtering.
        limit: maximum number of results, clamped to the range 1..100.
    """
    movies, shows = load_media_library(runtime)
    if media_type == "movie":
        candidates = movies
    elif media_type == "show":
        candidates = shows
    else:
        candidates = [*movies, *shows]

    needle = search.strip().casefold()
    if needle:
        candidates = [
            item
            for item in candidates
            if needle in " ".join((item.title, item.year or "", item.genres or "", item.summary or "")).casefold()
        ]

    cap = max(1, min(limit, 100))
    return [media_item_card_payload(item) for item in candidates[:cap]]
def find_media_item_by_source_id(runtime: BotRuntime, source_id: str) -> MediaItem | None:
    """Return the first saved movie or show whose ``source_id`` equals *source_id*.

    Movies are searched before shows. Items with a falsy ``source_id`` are
    skipped, so they never match. Returns None when nothing matches.
    """
    movies, shows = load_media_library(runtime)
    matches = (
        candidate
        for candidate in (*movies, *shows)
        if candidate.source_id and candidate.source_id == source_id
    )
    return next(matches, None)
def create_watch_party_session(
    runtime: BotRuntime,
    *,
    guild_id: str,
    voice_channel_id: str,
    text_channel_id: str,
    owner_user_id: str,
    title: str,
) -> dict[str, Any]:
    """Create a ``draft`` session together with an idle worker-state row.

    All id arguments are stripped and must be non-blank. *title* is stripped
    but may be empty. *runtime* is not used by this function. A
    ``session.created`` event is logged, and the result is
    ``get_watch_party_session`` for the new id.

    Raises:
        ValueError: if any of the required ids is blank.
    """
    del runtime  # unused here
    initialize_watchparty_schema()
    required_fields = (
        (guild_id, "guildId"),
        (voice_channel_id, "voiceChannelId"),
        (text_channel_id, "textChannelId"),
        (owner_user_id, "ownerUserId"),
    )
    for value, field_name in required_fields:
        if not value.strip():
            raise ValueError(f"{field_name} is required")

    guild = guild_id.strip()
    voice = voice_channel_id.strip()
    text = text_channel_id.strip()
    owner = owner_user_id.strip()
    clean_title = title.strip()
    timestamp = utc_now()

    with connect_db() as connection:
        cursor = connection.execute(
            """
            INSERT INTO watch_party_sessions(
                guild_id, voice_channel_id, text_channel_id, owner_user_id,
                title, status, worker_session_id, current_queue_entry_id, created_at, updated_at
            )
            VALUES (?, ?, ?, ?, ?, 'draft', '', NULL, ?, ?)
            """,
            (guild, voice, text, owner, clean_title, timestamp, timestamp),
        )
        new_session_id = int(cursor.lastrowid)
        # Seed the worker-state row so readers always find one for this session.
        connection.execute(
            """
            INSERT INTO watch_party_worker_state(
                session_id, worker_status, playback_state, current_title, position_seconds,
                duration_seconds, last_error, updated_at
            )
            VALUES (?, 'idle', 'idle', '', 0, 0, '', ?)
            """,
            (new_session_id, timestamp),
        )
        log_watch_party_event(
            connection,
            new_session_id,
            "session.created",
            {"title": clean_title, "ownerUserId": owner},
        )
        connection.commit()
    return get_watch_party_session(new_session_id)
def list_watch_party_sessions() -> list[dict[str, Any]]:
    """Return every session as a camelCase dict, most recently updated first.

    Ties on ``updated_at`` are broken by descending id.
    """
    initialize_watchparty_schema()
    with connect_db() as connection:
        cursor = connection.execute(
            "SELECT * FROM watch_party_sessions ORDER BY updated_at DESC, id DESC"
        )
        rows = cursor.fetchall()
    sessions = map(watch_party_session_from_row, rows)
    return [session_to_jsonable(session) for session in sessions]
def get_watch_party_session(session_id: int) -> dict[str, Any]:
    """Return the full payload for one session.

    The result has four keys:
      * ``session`` — the session as a camelCase dict;
      * ``queue`` — queue entries ordered by position, then id;
      * ``worker`` — the worker-state dict, or None if no row exists;
      * ``events`` — the 25 most recent events, newest first, with their JSON
        payloads decoded.

    Raises:
        ValueError: if no session has id *session_id*.
    """
    initialize_watchparty_schema()
    params = (session_id,)
    with connect_db() as connection:
        session_row = connection.execute(
            "SELECT * FROM watch_party_sessions WHERE id = ?",
            params,
        ).fetchone()
        if session_row is None:
            raise ValueError(f"Watch party session not found: {session_id}")
        queue_rows = connection.execute(
            "SELECT * FROM watch_party_queue WHERE session_id = ? ORDER BY position ASC, id ASC",
            params,
        ).fetchall()
        worker_row = connection.execute(
            "SELECT * FROM watch_party_worker_state WHERE session_id = ?",
            params,
        ).fetchone()
        event_rows = connection.execute(
            "SELECT id, event_type, payload_json, created_at FROM watch_party_events WHERE session_id = ? ORDER BY id DESC LIMIT 25",
            params,
        ).fetchall()

    def decode_event(event: Any) -> dict[str, Any]:
        # payload_json was written by log_watch_party_event via json.dumps.
        return {
            "id": int(event["id"]),
            "eventType": str(event["event_type"]),
            "payload": json.loads(str(event["payload_json"])),
            "createdAt": str(event["created_at"]),
        }

    return {
        "session": session_to_jsonable(watch_party_session_from_row(session_row)),
        "queue": list(map(queue_entry_to_jsonable, queue_rows)),
        "worker": None if worker_row is None else worker_state_to_jsonable(worker_row),
        "events": list(map(decode_event, event_rows)),
    }
def next_queue_position(connection: Any, session_id: int) -> int:
    """Return one past the session's highest queue position (1 for an empty queue).

    *connection* must return rows that support access by column name.
    """
    row = connection.execute(
        "SELECT COALESCE(MAX(position), 0) + 1 AS next_position FROM watch_party_queue WHERE session_id = ?",
        (session_id,),
    ).fetchone()
    return int(row["next_position"])
def add_watch_party_queue_item(runtime: BotRuntime, session_id: int, jellyfin_source_id: str) -> dict[str, Any]:
    """Append a saved-library item to the end of a session's queue.

    The item is looked up in the saved library by its (stripped) source id.
    Its display metadata is copied into a new ``queued`` row at the next
    queue position, and a ``queue.added`` event is logged.

    The session's ``updated_at`` is always refreshed. Its status is set to
    ``queued`` unless the session is currently connecting, playing or paused:
    adding to the queue of an active session must not knock it back to
    ``queued``.

    Raises:
        ValueError: if the source id is blank, the item is not in the saved
            library, or the session does not exist.
    """
    # Sessions in these states are actively using the queue; keep their status.
    active_statuses = {"connecting", "playing", "paused"}

    initialize_watchparty_schema()
    source_id = jellyfin_source_id.strip()
    if not source_id:
        raise ValueError("jellyfinSourceId is required")
    item = find_media_item_by_source_id(runtime, source_id)
    if item is None:
        raise ValueError(f"Media item not found in saved library: {jellyfin_source_id}")

    with connect_db() as connection:
        session_row = connection.execute(
            "SELECT status FROM watch_party_sessions WHERE id = ?",
            (session_id,),
        ).fetchone()
        if session_row is None:
            raise ValueError(f"Watch party session not found: {session_id}")

        # One timestamp for both the new entry and the session touch.
        now = utc_now()
        position = next_queue_position(connection, session_id)
        # .get(key, "") still returns None when the key is present with a None
        # value; poster_url is NOT NULL, so coerce like the other optional fields.
        poster_url = media_item_card_payload(item).get("posterUrl") or ""
        cursor = connection.execute(
            """
            INSERT INTO watch_party_queue(
                session_id, position, media_type, jellyfin_source_id, title,
                year, runtime, summary, poster_url, status, created_at
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 'queued', ?)
            """,
            (
                session_id,
                position,
                item.media_type,
                item.source_id or "",
                item.title,
                item.year or "",
                item.runtime or "",
                item.summary or "",
                poster_url,
                now,
            ),
        )
        queue_id = int(cursor.lastrowid)

        current_status = str(session_row["status"])
        new_status = current_status if current_status in active_statuses else "queued"
        connection.execute(
            "UPDATE watch_party_sessions SET status = ?, updated_at = ? WHERE id = ?",
            (new_status, now, session_id),
        )
        log_watch_party_event(
            connection,
            session_id,
            "queue.added",
            {"queueEntryId": queue_id, "jellyfinSourceId": source_id, "title": item.title},
        )
        connection.commit()
    return get_watch_party_session(session_id)
def update_watch_party_status(session_id: int, status: str, *, worker_session_id: str | None = None, current_queue_entry_id: int | None = None) -> dict[str, Any]:
    """Set a session's status, and optionally its worker id and current queue entry.

    *status* is validated and normalized by ``normalize_watch_party_status``.
    Passing ``None`` for *worker_session_id* or *current_queue_entry_id* leaves
    that column unchanged. As a result, ``current_queue_entry_id`` cannot be
    reset to NULL through this function. A ``session.status`` event records
    the values actually written.

    Raises:
        ValueError: if *status* is invalid or the session does not exist.
    """
    initialize_watchparty_schema()
    normalized_status = normalize_watch_party_status(status)
    # Strip once so the stored value and the logged value agree.
    cleaned_worker_id = worker_session_id.strip() if worker_session_id is not None else None

    assignments = ["status = ?", "updated_at = ?"]
    values: list[Any] = [normalized_status, utc_now()]
    if cleaned_worker_id is not None:
        assignments.append("worker_session_id = ?")
        values.append(cleaned_worker_id)
    if current_queue_entry_id is not None:
        assignments.append("current_queue_entry_id = ?")
        values.append(current_queue_entry_id)
    values.append(session_id)

    with connect_db() as connection:
        if connection.execute("SELECT id FROM watch_party_sessions WHERE id = ?", (session_id,)).fetchone() is None:
            raise ValueError(f"Watch party session not found: {session_id}")
        # The SET clause is assembled only from the fixed column names above,
        # never from caller input, so the f-string cannot inject SQL.
        connection.execute(
            f"UPDATE watch_party_sessions SET {', '.join(assignments)} WHERE id = ?",
            tuple(values),
        )
        log_watch_party_event(
            connection,
            session_id,
            "session.status",
            {
                "status": normalized_status,
                "workerSessionId": cleaned_worker_id or "",
                "currentQueueEntryId": current_queue_entry_id,
            },
        )
        connection.commit()
    return get_watch_party_session(session_id)
def update_queue_entry_status(session_id: int, queue_entry_id: int, status: str) -> dict[str, Any]:
    """Change one queue entry's status and return the refreshed session payload.

    The entry must belong to *session_id*. Only the queue row is updated; the
    session row (status, current entry, updated_at) is left untouched. A
    ``queue.status`` event is logged.

    Raises:
        ValueError: for an invalid status, or if the entry is not in this session.
    """
    initialize_watchparty_schema()
    new_status = normalize_queue_status(status)
    with connect_db() as connection:
        owned_entry = connection.execute(
            "SELECT id FROM watch_party_queue WHERE id = ? AND session_id = ?",
            (queue_entry_id, session_id),
        ).fetchone()
        if owned_entry is None:
            raise ValueError(f"Queue entry not found: {queue_entry_id}")
        connection.execute(
            "UPDATE watch_party_queue SET status = ? WHERE id = ?",
            (new_status, queue_entry_id),
        )
        event_payload = {"queueEntryId": queue_entry_id, "status": new_status}
        log_watch_party_event(connection, session_id, "queue.status", event_payload)
        connection.commit()
    return get_watch_party_session(session_id)
def update_worker_state(
    session_id: int,
    *,
    worker_status: str,
    playback_state: str,
    current_title: str = "",
    position_seconds: int = 0,
    duration_seconds: int = 0,
    last_error: str = "",
) -> dict[str, Any]:
    """Upsert a session's worker-state row and log a ``worker.state`` event.

    Text fields are stripped, and negative position/duration values are
    clamped to 0. The logged event carries exactly the cleaned values that
    were stored. The worker/playback state strings are not checked against
    any fixed set here.

    Raises:
        ValueError: if the session does not exist.
    """
    initialize_watchparty_schema()
    with connect_db() as connection:
        if connection.execute("SELECT id FROM watch_party_sessions WHERE id = ?", (session_id,)).fetchone() is None:
            raise ValueError(f"Watch party session not found: {session_id}")

        snapshot = {
            "workerStatus": worker_status.strip(),
            "playbackState": playback_state.strip(),
            "currentTitle": current_title.strip(),
            "positionSeconds": max(0, position_seconds),
            "durationSeconds": max(0, duration_seconds),
            "lastError": last_error.strip(),
        }
        connection.execute(
            """
            INSERT INTO watch_party_worker_state(
                session_id, worker_status, playback_state, current_title, position_seconds,
                duration_seconds, last_error, updated_at
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(session_id) DO UPDATE SET
                worker_status = excluded.worker_status,
                playback_state = excluded.playback_state,
                current_title = excluded.current_title,
                position_seconds = excluded.position_seconds,
                duration_seconds = excluded.duration_seconds,
                last_error = excluded.last_error,
                updated_at = excluded.updated_at
            """,
            (
                session_id,
                snapshot["workerStatus"],
                snapshot["playbackState"],
                snapshot["currentTitle"],
                snapshot["positionSeconds"],
                snapshot["durationSeconds"],
                snapshot["lastError"],
                utc_now(),
            ),
        )
        log_watch_party_event(connection, session_id, "worker.state", snapshot)
        connection.commit()
    return get_watch_party_session(session_id)
def worker_command_payload(session_id: int, action: str, data: dict[str, Any]) -> dict[str, Any]:
    """Return the session payload with a ``workerCommand`` envelope attached.

    The envelope holds *action*, *session_id* and *data*. *data* is embedded
    by reference, not copied.

    Raises:
        ValueError: propagated from ``get_watch_party_session`` for an unknown session.
    """
    envelope = get_watch_party_session(session_id)
    envelope["workerCommand"] = dict(action=action, sessionId=session_id, data=data)
    return envelope
def first_queued_entry(session_id: int) -> dict[str, Any] | None:
    """Return the session's lowest-position entry still in ``queued`` status.

    Ties on position are broken by id. Returns None when nothing is queued.
    """
    initialize_watchparty_schema()
    query = (
        "SELECT * FROM watch_party_queue "
        "WHERE session_id = ? AND status = 'queued' "
        "ORDER BY position ASC, id ASC "
        "LIMIT 1"
    )
    with connect_db() as connection:
        row = connection.execute(query, (session_id,)).fetchone()
    if row is None:
        return None
    return queue_entry_to_jsonable(row)