diff --git a/src/core/engine.py b/src/core/engine.py index 707108f..d540f31 100644 --- a/src/core/engine.py +++ b/src/core/engine.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Callable, Awaitable +from typing import Callable, Awaitable, List from src.config import AppConfig from src.core.state import MigrationState from src.discord_bot.reader import DiscordReader @@ -194,28 +194,44 @@ class MigrationEngine: if progress_callback: await progress_callback(role.name, idx + 1, total) await asyncio.sleep(self.config.migration.rate_limit_delay_seconds) - async def migrate_emojis(self, progress_callback: Callable[[str, int, int], Awaitable[None]] | None = None): + async def migrate_emojis(self, progress_callback: Callable[[str, str, int, int], Awaitable[None]] | None = None, types_to_include: List[str] = ["Emoji", "Sticker"]): """Copies custom emojis and stickers.""" - emojis = await self.discord_reader.get_emojis() - total = len(emojis) + objs = [] + if "Emoji" in types_to_include: + emojis = await self.discord_reader.get_emojis() + objs.extend([(e, "Emoji") for e in emojis]) + if "Sticker" in types_to_include: + stickers = await self.discord_reader.get_stickers() + objs.extend([(s, "Sticker") for s in stickers]) + + total = len(objs) - for idx, emoji in enumerate(emojis): + for idx, (obj, obj_type) in enumerate(objs): if not self.is_running: break - fluxer_id = self.state.get_fluxer_channel_id(f"emoji_{emoji.id}") + state_key = f"{obj_type.lower()}_{obj.id}" + fluxer_id = self.state.get_fluxer_channel_id(state_key) if not fluxer_id: try: - img_data = await self.discord_reader.download_emoji(emoji) - fluxer_id = await self.fluxer_writer.create_emoji( - name=emoji.name, - image_bytes=img_data - ) + if obj_type == "Emoji": + img_data = await self.discord_reader.download_emoji(obj) + fluxer_id = await self.fluxer_writer.create_emoji( + name=obj.name, + image_bytes=img_data + ) + else: + img_data = await self.discord_reader.download_sticker(obj) + fluxer_id = await self.fluxer_writer.create_sticker( + name=obj.name, + image_bytes=img_data + ) + if fluxer_id: - self.state.set_channel_mapping(f"emoji_{emoji.id}", fluxer_id) + self.state.set_channel_mapping(state_key, fluxer_id) except Exception as e: - logger.error(f"Error downloading/uploading emoji {emoji.name}: {e}") + logger.error(f"Error downloading/uploading {obj_type.lower()} {obj.name}: {e}") - if progress_callback: await progress_callback(emoji.name, idx + 1, total) + if progress_callback: await progress_callback(obj.name, obj_type, idx + 1, total) await asyncio.sleep(self.config.migration.rate_limit_delay_seconds) async def run_full_migration(self): diff --git a/src/discord_bot/reader.py b/src/discord_bot/reader.py index 5888350..66a99b3 100644 --- a/src/discord_bot/reader.py +++ b/src/discord_bot/reader.py @@ -77,6 +77,12 @@ class DiscordReader: return [] return await self.guild.fetch_emojis() + async def get_stickers(self): + """Returns all custom stickers in the server.""" + if not self.guild: + return [] + return await self.guild.fetch_stickers() + async def get_channels(self, category_id: int | None = None): """Yields all non-category channels.""" if not self.guild: @@ -99,6 +105,10 @@ class DiscordReader: """Downloads a Discord emoji into memory.""" return await emoji.read() + async def download_sticker(self, sticker: discord.GuildSticker) -> bytes: + """Downloads a Discord sticker into memory.""" + return await sticker.read() + async def download_attachment(self, attachment: discord.Attachment) -> bytes: """Downloads a Discord attachment into memory.""" return await attachment.read() diff --git a/src/fluxer_bot/writer.py b/src/fluxer_bot/writer.py index dfa3486..141143d 100644 --- a/src/fluxer_bot/writer.py +++ b/src/fluxer_bot/writer.py @@ -115,6 +115,23 @@ class FluxerWriter: print(f"Failed to copy emoji {name}: {e}") return "" + async def create_sticker(self, name: str, image_bytes: bytes) -> str: + """ + Creates a custom sticker in the Fluxer community. + """ + assert self.client is not None + + try: + sticker = await self.client.create_guild_sticker( + guild_id=self.community_id, + name=name, + image=image_bytes + ) + return str(sticker["id"]) + except Exception as e: + print(f"Failed to copy sticker {name}: {e}") + return "" + async def update_guild_metadata(self, name: Optional[str] = None, icon: Optional[bytes] = None, banner: Optional[bytes] = None) -> None: """ Updates the Fluxer community name, icon, and banner. diff --git a/src/ui/app.py b/src/ui/app.py index 9289dae..b73c4d9 100644 --- a/src/ui/app.py +++ b/src/ui/app.py @@ -29,7 +29,7 @@ class MigrationCLI: self.tokens_valid = all(self.validation_results.values()) async def run(self): - console.print(Panel.fit("Discord Reaper", style="bold blue")) + console.print(Panel.fit("Fluxer Reaper", style="bold blue")) await self.validate_config() while True: @@ -230,11 +230,39 @@ class MigrationCLI: self.engine.is_running = False async def copy_emojis(self): - if not Confirm.ask("Are you sure you want to copy emojis and stickers?"): - return - - console.print("\n[bold green]Starting Emoji Migration...[/bold green]") + console.print("\n[yellow]Fetching emojis and stickers...[/yellow]") try: + await self.engine.start_connections() + emojis = await self.engine.discord_reader.get_emojis() + stickers = await self.engine.discord_reader.get_stickers() + + console.print(f"\n[bold]Custom emojis found: {len(emojis)}[/bold]") + for e in emojis: + console.print(f" - Emoji: {e.name}") + + console.print(f"[bold]Custom stickers found: {len(stickers)}[/bold]") + for s in stickers: + console.print(f" - Sticker: {s.name}") + + console.print("\n(1) Copy Emojis only") + console.print("(2) Copy Stickers only") + console.print("(3) Copy Emojis and Stickers") + console.print("(B) Back") + + choice = Prompt.ask("Select an option", choices=["1", "2", "3", "B", "b"], default="1").upper() + + if choice == "B": + return + + types_to_include = [] + if choice == "1": + types_to_include = ["Emoji"] + elif choice == "2": + types_to_include = ["Sticker"] + elif choice == "3": + types_to_include = ["Emoji", "Sticker"] + + console.print("\n[bold green]Starting Migration...[/bold green]") with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), @@ -243,16 +271,15 @@ class MigrationCLI: console=console ) as progress: - emoji_task = progress.add_task("[cyan]Copying Emojis...", total=100) + emoji_task = progress.add_task("[cyan]Copying Assets...", total=100) - async def update_progress(item_name: str, current: int, total: int): - progress.update(emoji_task, total=total, completed=current, description=f"[cyan]Copying Emoji: {item_name}") + async def update_progress(item_name: str, item_type: str, current: int, total: int): + progress.update(emoji_task, total=total, completed=current, description=f"[cyan]Copying {item_type}: {item_name}") - await self.engine.start_connections() self.engine.is_running = True - await self.engine.migrate_emojis(progress_callback=update_progress) + await self.engine.migrate_emojis(progress_callback=update_progress, types_to_include=types_to_include) - console.print("[bold green]Emoji migration complete![/bold green]") + console.print("[bold green]Migration complete![/bold green]") except Exception as e: console.print(f"[bold red]Error during emoji migration: {str(e)}[/bold red]")