TheOrb/archive_bot/auth.py

159 lines
5.1 KiB
Python

from __future__ import annotations
import base64
import hashlib
import hmac
import os
import secrets
import threading
import time
from dataclasses import dataclass
from .core import PBKDF2_ITERATIONS, bool_env
@dataclass(frozen=True)
class DashboardAuthConfig:
    """Immutable dashboard credential and session settings.

    Built once from the environment by ``dashboard_auth_from_env`` and then
    only read.
    """

    # The single login name accepted by DashboardAuth.login.
    username: str
    # Encoded "pbkdf2_sha256$<iterations>$<salt>$<digest>" string checked on login.
    password_hash: str
    # Session lifetime in seconds; refreshed on every successful cookie lookup.
    session_ttl_seconds: int
    # Not read in this module; presumably controls the cookie's Secure flag
    # in the HTTP layer (TODO confirm against caller).
    cookie_secure: bool
@dataclass
class DashboardSession:
    """State for one logged-in dashboard session.

    Mutable on purpose: ``expires_at`` is pushed forward each time the
    session is looked up from its cookie.
    """

    # Name the session was created for.
    username: str
    # Random per-session token (secrets.token_urlsafe) for CSRF checks.
    csrf_token: str
    # Absolute expiry as a time.time() timestamp.
    expires_at: float
class DashboardAuth:
    """Single-user dashboard authentication with in-memory sessions.

    Holds the configured credentials, a map of session id -> DashboardSession,
    and a per-key list of failed-login timestamps used for throttling. All
    shared state is mutated under ``self.lock``. Everything lives in this
    instance's dictionaries, so sessions do not survive a restart.
    """

    # Throttle: a key may fail at most MAX_FAILED_LOGINS times within the
    # trailing LOGIN_WINDOW_SECONDS before login_allowed() returns False.
    # Class attributes so a subclass or instance can tune them.
    LOGIN_WINDOW_SECONDS = 900
    MAX_FAILED_LOGINS = 10

    def __init__(self, config: DashboardAuthConfig) -> None:
        self.config = config
        self.lock = threading.Lock()
        self.sessions: dict[str, DashboardSession] = {}
        self.failed_logins: dict[str, list[float]] = {}

    def _recent_failures_locked(self, key: str, now: float) -> list[float]:
        """Return *key*'s failures inside the window, discarding older ones.

        The caller must hold ``self.lock``.
        """
        window_start = now - self.LOGIN_WINDOW_SECONDS
        attempts = [at for at in self.failed_logins.get(key, ()) if at >= window_start]
        if attempts:
            self.failed_logins[key] = attempts
        else:
            # Never keep (or create) empty entries. Keys are caller-supplied
            # (presumably client addresses), and storing one per distinct
            # caller would grow the dict without bound.
            self.failed_logins.pop(key, None)
        return attempts

    def _prune_expired_sessions_locked(self, now: float) -> None:
        """Drop every session whose expiry has passed. Caller holds ``self.lock``."""
        expired = [sid for sid, sess in self.sessions.items() if sess.expires_at <= now]
        for sid in expired:
            del self.sessions[sid]

    def login_allowed(self, key: str) -> bool:
        """Return True while *key* has fewer than MAX_FAILED_LOGINS recent failures."""
        with self.lock:
            recent = self._recent_failures_locked(key, time.time())
            return len(recent) < self.MAX_FAILED_LOGINS

    def record_failed_login(self, key: str) -> None:
        """Record a failed login attempt for *key* at the current time."""
        now = time.time()
        with self.lock:
            attempts = self._recent_failures_locked(key, now)
            attempts.append(now)
            self.failed_logins[key] = attempts

    def clear_failed_login(self, key: str) -> None:
        """Forget all recorded failures for *key* (e.g. after a successful login)."""
        with self.lock:
            self.failed_logins.pop(key, None)

    def login(self, username: str, password: str) -> tuple[str, DashboardSession] | None:
        """Check credentials and, on success, create a new session.

        Returns:
            ``(session_id, session)`` on success, otherwise None.
        """
        # Compare UTF-8 bytes: hmac.compare_digest raises TypeError for str
        # arguments that contain non-ASCII characters.
        username_ok = hmac.compare_digest(
            username.encode("utf-8", "surrogatepass"),
            self.config.username.encode("utf-8", "surrogatepass"),
        )
        # Always run the slow password check, so the response time does not
        # reveal whether the username alone was correct.
        password_ok = verify_password_hash(self.config.password_hash, password)
        if not (username_ok and password_ok):
            return None
        now = time.time()
        session_id = secrets.token_urlsafe(32)
        session = DashboardSession(
            username=username,
            csrf_token=secrets.token_urlsafe(32),
            expires_at=now + self.config.session_ttl_seconds,
        )
        with self.lock:
            # Abandoned sessions are otherwise only evicted when looked up again.
            self._prune_expired_sessions_locked(now)
            self.sessions[session_id] = session
        return session_id, session

    def session_from_cookie(self, cookie_header: str | None) -> tuple[str, DashboardSession] | None:
        """Resolve the live session named by the SESSION_COOKIE cookie.

        A hit slides the session's expiry forward by ``session_ttl_seconds``.
        An expired hit is evicted and treated as a miss. The returned session
        object is the stored one, not a copy.

        Returns:
            ``(session_id, session)`` or None when there is no valid session.
        """
        from http.cookies import CookieError, SimpleCookie
        from .core import SESSION_COOKIE
        if not cookie_header:
            return None
        cookie = SimpleCookie()
        try:
            cookie.load(cookie_header)
        except CookieError:
            # A malformed client header means "not logged in", not a server error.
            return None
        morsel = cookie.get(SESSION_COOKIE)
        if morsel is None:
            return None
        session_id = morsel.value
        now = time.time()
        with self.lock:
            session = self.sessions.get(session_id)
            if session is None:
                return None
            if session.expires_at <= now:
                self.sessions.pop(session_id, None)
                return None
            session.expires_at = now + self.config.session_ttl_seconds
            return session_id, session

    def logout(self, session_id: str) -> None:
        """Remove *session_id* from the store; unknown ids are ignored."""
        with self.lock:
            self.sessions.pop(session_id, None)
def password_hash(password: str) -> str:
    """Hash *password* into the encoded form that verify_password_hash accepts.

    Returns ``pbkdf2_sha256$<iterations>$<salt>$<digest>``. A fresh random
    16-byte salt is used on every call. Salt and digest are URL-safe base64
    with the ``=`` padding stripped.
    """

    def _unpadded_b64(raw: bytes) -> str:
        return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")

    rounds = PBKDF2_ITERATIONS
    salt_bytes = secrets.token_bytes(16)
    derived = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt_bytes, rounds)
    return f"pbkdf2_sha256${rounds}${_unpadded_b64(salt_bytes)}${_unpadded_b64(derived)}"
def decode_urlsafe_base64(value: str) -> bytes:
    """Decode URL-safe base64 whose trailing ``=`` padding may have been stripped.

    Raises ValueError (binascii.Error is a subclass) on malformed input.
    """
    padding = "=" * (-len(value) % 4)
    return base64.urlsafe_b64decode(value + padding)


def verify_password_hash(encoded: str, password: str) -> bool:
    """Check *password* against a ``pbkdf2_sha256$<iterations>$<salt>$<digest>`` string.

    Returns True only when the recomputed PBKDF2-SHA256 digest matches. A
    wrong password returns False. So does a malformed encoded hash: wrong
    field count, unknown algorithm, a non-numeric, non-positive or absurdly
    large iteration count, or bad base64. No exception is raised for these.
    """
    try:
        algorithm, iterations_text, salt_text, digest_text = encoded.split("$", 3)
        if algorithm != "pbkdf2_sha256":
            return False
        iterations = int(iterations_text)
        salt = decode_urlsafe_base64(salt_text)
        expected = decode_urlsafe_base64(digest_text)
        actual = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations)
    except (ValueError, TypeError, OverflowError):
        # ValueError covers: too few fields, a non-integer or < 1 iteration
        # count, invalid base64, and passwords that cannot be UTF-8 encoded.
        # OverflowError covers iteration counts too large for hashlib.
        return False
    # Constant-time comparison to avoid leaking digest prefixes via timing.
    return hmac.compare_digest(actual, expected)
def dashboard_auth_from_env() -> DashboardAuth | None:
    """Build the dashboard authenticator from environment variables.

    Reads DASHBOARD_AUTH_DISABLED, DASHBOARD_USERNAME, DASHBOARD_PASSWORD_HASH,
    DASHBOARD_SESSION_TTL_SECONDS (default 28800, i.e. 8 hours) and
    DASHBOARD_COOKIE_SECURE (default false).

    Returns:
        None when auth is explicitly disabled, otherwise a DashboardAuth.

    Raises:
        SystemExit: if the credentials are missing, or the password hash is not
            a ``pbkdf2_sha256$<iterations>$<salt>$<digest>`` string, or the
            session TTL is not a positive integer. Failing at startup is
            better than a dashboard on which every login silently fails.
    """
    if bool_env("DASHBOARD_AUTH_DISABLED", False):
        return None
    username = os.getenv("DASHBOARD_USERNAME", "").strip()
    encoded_hash = os.getenv("DASHBOARD_PASSWORD_HASH", "").strip()
    if not username or not encoded_hash:
        raise SystemExit(
            "Dashboard auth is enabled but DASHBOARD_USERNAME or DASHBOARD_PASSWORD_HASH is missing. "
            "Set credentials or explicitly set DASHBOARD_AUTH_DISABLED=true."
        )
    hash_fields = encoded_hash.split("$")
    # Shells and .env loaders commonly expand "$..." sequences, which mangles
    # the hash. Short-circuit order guarantees int() only sees ASCII digits.
    if (
        len(hash_fields) != 4
        or hash_fields[0] != "pbkdf2_sha256"
        or not hash_fields[1].isascii()
        or not hash_fields[1].isdigit()
        or int(hash_fields[1]) < 1
    ):
        raise SystemExit(
            "DASHBOARD_PASSWORD_HASH is not a valid pbkdf2_sha256$<iterations>$<salt>$<digest> hash. "
            "Regenerate it and check that '$' characters were not expanded by your shell or .env loader."
        )
    ttl_text = os.getenv("DASHBOARD_SESSION_TTL_SECONDS", "28800").strip()
    try:
        ttl = int(ttl_text)
    except ValueError:
        ttl = 0
    if ttl <= 0:
        # A TTL of 0 or less would expire every session immediately after login.
        raise SystemExit(
            f"DASHBOARD_SESSION_TTL_SECONDS must be a positive integer number of seconds, got {ttl_text!r}."
        )
    secure = bool_env("DASHBOARD_COOKIE_SECURE", False)
    return DashboardAuth(
        DashboardAuthConfig(
            username=username,
            password_hash=encoded_hash,
            session_ttl_seconds=ttl,
            cookie_secure=secure,
        )
    )
def print_password_hash() -> None:
    """Prompt twice for a dashboard password and print its encoded hash.

    The output is suitable for DASHBOARD_PASSWORD_HASH.

    Raises:
        SystemExit: if the two entries differ, or the password is shorter
            than 12 characters. The mismatch check runs first.
    """
    from getpass import getpass

    entered = getpass("Dashboard password: ")
    confirmed = getpass("Confirm password: ")
    if entered != confirmed:
        raise SystemExit("Passwords did not match")
    if len(entered) < 12:
        raise SystemExit("Use at least 12 characters")
    encoded = password_hash(entered)
    print(encoded)