implement rate limit notification
This commit is contained in:
parent
1ba8b3e8ea
commit
ed0e79d135
1 changed files with 38 additions and 18 deletions
|
|
@ -2,11 +2,13 @@ import sys
|
|||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from rich.console import Console
|
||||
from rich.prompt import Prompt, Confirm
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, ProgressColumn
|
||||
from rich.text import Text
|
||||
from src.core.configuration import load_config, save_config
|
||||
from src.core.base import MigrationContext
|
||||
|
||||
|
|
@ -21,8 +23,19 @@ import src.stoat.migrate_message as stoat_migrate
|
|||
|
||||
from src.core.audit import log_audit_event
|
||||
|
||||
global_rate_limit_msg = ""
|
||||
global_rate_limit_expires = 0.0
|
||||
|
||||
class RateLimitColumn(ProgressColumn):
|
||||
"""Renders the current dynamic rate limit wait."""
|
||||
def render(self, task) -> Text:
|
||||
global global_rate_limit_msg, global_rate_limit_expires
|
||||
if time.time() < global_rate_limit_expires:
|
||||
return Text.from_markup(f"[dim]\\[wait: {global_rate_limit_msg}][/dim]")
|
||||
return Text("")
|
||||
|
||||
class RateLimitHandler(logging.Handler):
|
||||
"""Intersects library logs to print clean rate limit messages."""
|
||||
"""Intersects library logs to capture clean dynamic rate limit messages."""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._last_print = ""
|
||||
|
|
@ -30,27 +43,28 @@ class RateLimitHandler(logging.Handler):
|
|||
def emit(self, record):
|
||||
try:
|
||||
msg = record.getMessage()
|
||||
# Detect rate limit messages from discord.py or fluxer.py
|
||||
# Detect rate limit messages from discord.py, fluxer.py, stoat, etc.
|
||||
if "retry" in msg.lower() and ("rate limit" in msg.lower() or "429" in msg):
|
||||
# Extract seconds using regex: supports "retry in 5.50s" and "Retrying in 5.50 seconds"
|
||||
match = re.search(r"in ([\d.]+)\s*(?:seconds?|s)", msg, re.IGNORECASE)
|
||||
if match:
|
||||
seconds = match.group(1)
|
||||
platform = "discord" if "discord" in record.name.lower() else "fluxer"
|
||||
if "discord" in record.name.lower():
|
||||
platform = "Discord"
|
||||
elif "fluxer" in record.name.lower():
|
||||
platform = "Fluxer"
|
||||
elif "stoat" in record.name.lower():
|
||||
platform = "Stoat"
|
||||
else:
|
||||
platform = "API"
|
||||
|
||||
# Format the message
|
||||
new_msg = f"{platform} API rate limit: will retry after {seconds}"
|
||||
|
||||
# Avoid spamming the exact same message if nothing changed
|
||||
if new_msg == self._last_print:
|
||||
return
|
||||
|
||||
self._last_print = new_msg
|
||||
|
||||
# Use rich console to print on the same line.
|
||||
# end="\r" works with rich's internal live-update handling.
|
||||
# We add some padding to clear old text.
|
||||
console.print(f"{new_msg} ", end="\r")
|
||||
# Update the global dynamic rate limit info
|
||||
global global_rate_limit_msg, global_rate_limit_expires
|
||||
global_rate_limit_msg = f"{platform} rate limit {seconds}s"
|
||||
try:
|
||||
global_rate_limit_expires = time.time() + float(seconds)
|
||||
except ValueError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
|
@ -76,6 +90,7 @@ class MigrationCLI:
|
|||
rl_handler = RateLimitHandler()
|
||||
logging.getLogger("discord").addHandler(rl_handler)
|
||||
logging.getLogger("fluxer").addHandler(rl_handler)
|
||||
logging.getLogger("stoat").addHandler(rl_handler)
|
||||
|
||||
async def validate_config(self):
|
||||
self.validation_results = {
|
||||
|
|
@ -454,7 +469,7 @@ class MigrationCLI:
|
|||
channels = []
|
||||
try:
|
||||
await self.engine.start_connections()
|
||||
with console.status("[yellow]Syncing Fluxer channel state...[/yellow]"):
|
||||
with console.status(f"[yellow]Syncing {self.target_platform.capitalize()} channel state...[/yellow]"):
|
||||
await sync_channel_state(self.engine)
|
||||
categories = await self.engine.discord_reader.get_categories()
|
||||
channels = await self.engine.discord_reader.get_channels()
|
||||
|
|
@ -544,6 +559,7 @@ class MigrationCLI:
|
|||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
RateLimitColumn(),
|
||||
console=console
|
||||
) as progress:
|
||||
|
||||
|
|
@ -649,6 +665,7 @@ class MigrationCLI:
|
|||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
RateLimitColumn(),
|
||||
console=console
|
||||
) as progress:
|
||||
|
||||
|
|
@ -700,6 +717,7 @@ class MigrationCLI:
|
|||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
RateLimitColumn(),
|
||||
console=console
|
||||
) as progress:
|
||||
|
||||
|
|
@ -825,6 +843,7 @@ class MigrationCLI:
|
|||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
RateLimitColumn(),
|
||||
console=console
|
||||
) as progress:
|
||||
|
||||
|
|
@ -1233,6 +1252,7 @@ class MigrationCLI:
|
|||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
RateLimitColumn(),
|
||||
console=console
|
||||
) as progress:
|
||||
task = progress.add_task(f"[cyan]Migrating 0/{total_messages} messages...", total=total_messages)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue