diff --git a/src/stoat/clone_server.py b/src/stoat/clone_server.py index 22667f4..04c36bb 100644 --- a/src/stoat/clone_server.py +++ b/src/stoat/clone_server.py @@ -63,6 +63,10 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl categories = await context.discord_reader.get_categories() channels = await context.discord_reader.get_channels() + # Sort categories and channels by position to preserve order + categories = sorted(categories, key=lambda c: getattr(c, 'position', 0)) + channels = sorted(channels, key=lambda c: getattr(c, 'position', 0)) + cloned_info = { "categories_created": [], "channels_created": [], @@ -73,13 +77,8 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl # 1. Identify categories to create missing_categories = [cat for cat in categories if force or not context.state.get_target_category_id(str(cat.id))] - missing_category_ids = {str(cat.id) for cat in missing_categories} # 2. Identify channels to create or move - # Fetch current Target state to check parent_ids - target_channels = await context.writer.get_channels() - target_parent_map = {str(c["id"]): (str(c.get("parent_id")) if c.get("parent_id") else None) for c in target_channels} - channels_to_create = [] channels_to_move = [] @@ -88,11 +87,8 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl target_id = context.state.get_target_channel_id(discord_id) if force or not target_id: - # We'll resolve the parent_id in the loop after categories are created channels_to_create.append(ch) else: - # Always add to move/sync list to ensure properties (topic, nsfw, slowmode) are synced - # even if the parent category is already correct. channels_to_move.append((ch, target_id)) total = len(missing_categories) + len(channels_to_create) + len(channels_to_move) @@ -102,7 +98,6 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl return cloned_info # 1. Migrate Categories - missing_category_ids = {str(cat.id) for cat in missing_categories} for cat in missing_categories: if not context.is_running: break @@ -148,7 +143,6 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl # Sync properties immediately await context.writer.modify_channel( channel_id=target_id, - parent_id=None, name=channel.name, topic=topic, nsfw=nsfw, @@ -170,7 +164,6 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl await context.writer.modify_channel( channel_id=target_id, - parent_id=None, name=channel.name, topic=topic, nsfw=nsfw, @@ -184,49 +177,66 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl # 4. Final step: Parent the channels into categories via mass server.edit() logger.info("Parenting all channels into their respective categories...") - # Force refetch to ensure we see all newly created categories from the loop above + # Force refresh server to get latest categories created during migration server = await context.writer._get_server(populate_channels=True, force=True) - cats = list(server.categories) if hasattr(server, "categories") and server.categories else [] - # Workaround: Ensure default properties are set for all categories - for c in cats: - if not hasattr(c, "default_permissions"): c.default_permissions = None - if not hasattr(c, "role_permissions"): c.role_permissions = {} + # Stoat categories are managed via server.edit(categories=...) + # We must preserve ALL categories that exist in Stoat, but update the ones we manage. + target_categories = [] - # We will build a map of target_cat_id -> list of target_ch_ids - cat_to_channels = {} - for cat in categories: - target_cat_id = context.state.get_target_category_id(str(cat.id)) - if not target_cat_id: continue + # Build map of Discord Cat ID -> [Target Channel IDs] + # Channels should already be sorted by position in the 'channels' list + discord_cat_to_channels = {} + for ch in channels: + cat_id = str(getattr(ch, 'category_id', '')) if ch.category_id else "unparented" + if cat_id not in discord_cat_to_channels: + discord_cat_to_channels[cat_id] = [] + + target_id = context.state.get_target_channel_id(str(ch.id)) + if target_id: + discord_cat_to_channels[cat_id].append(target_id) - target_channels_for_cat = [] - for ch in channels: - if str(getattr(ch, 'category_id', '')) == str(cat.id): - target_ch_id = context.state.get_target_channel_id(str(ch.id)) - if target_ch_id: - target_channels_for_cat.append(target_ch_id) + # Resolve Stoat categories + # We iterate over the categories from the server to ensure we don't drop any + for stoat_cat in server.categories: + # Check if this Stoat category maps to any Discord category + discord_cat_id = next((d_id for d_id, t_id in context.state.category_map.items() if t_id == str(stoat_cat.id)), None) - cat_to_channels[target_cat_id] = target_channels_for_cat - - # Now correctly assign them in the cats array, and remove them from other cats - all_assigned_channels = set() - for cat_id, ch_list in cat_to_channels.items(): - all_assigned_channels.update(ch_list) + if discord_cat_id: + # Managed category - update its channels + new_channels = discord_cat_to_channels.get(discord_cat_id, []) + new_cat = stoat.Category( + id=str(stoat_cat.id), + title=stoat_cat.title, + channels=new_channels + ) + else: + # Unmanaged category - preserve as is but ensure all channels are strings + new_cat = stoat.Category( + id=str(stoat_cat.id), + title=stoat_cat.title, + channels=[str(c_id) for c_id in stoat_cat.channels] + ) - for i, c in enumerate(cats): - # Remove any channels that are being assigned to a specific category - new_channels = [ch for ch in c.channels if ch not in all_assigned_channels] - if c.id in cat_to_channels: - # If this is one of our managed categories, set its channels to exactly what it should be - new_channels = cat_to_channels[c.id] - - new_cat = stoat.Category(id=c.id, title=c.title, channels=new_channels) - new_cat.default_permissions = c.default_permissions - new_cat.role_permissions = c.role_permissions - cats[i] = new_cat + # Workaround for stoat.py missing attributes + if not hasattr(new_cat, "default_permissions"): new_cat.default_permissions = getattr(stoat_cat, "default_permissions", None) + if not hasattr(new_cat, "role_permissions"): new_cat.role_permissions = getattr(stoat_cat, "role_permissions", {}) + target_categories.append(new_cat) + + # Sort target_categories based on the original Discord category positions + def get_cat_position(s_cat): + d_id = next((d_id for d_id, t_id in context.state.category_map.items() if t_id == str(s_cat.id)), None) + if d_id: + d_cat = next((c for c in categories if str(c.id) == d_id), None) + if d_cat: + return getattr(d_cat, 'position', 999) + return 999 + + target_categories.sort(key=get_cat_position) + try: - await server.edit(categories=cats) + await server.edit(categories=target_categories) logger.info("Successfully parented all channels.") except Exception as ex: logger.error(f"Failed to mass parent channels: {ex}") diff --git a/src/stoat/writer.py b/src/stoat/writer.py index 07a37af..f5e42b6 100644 --- a/src/stoat/writer.py +++ b/src/stoat/writer.py @@ -97,10 +97,12 @@ class StoatWriter: async def _get_server(self, populate_channels=False, force=False): # Always refetch if channels are requested AND we don't already have them - # Or if force is True (e.g. after category creation/mutation) - # Stoat Server objects use __slots__, so we can't easily add our own tracking attributes. - if force or (populate_channels and (not self._server or not hasattr(self._server, "channels") or not self._server.channels)): - self._server = await self.client.fetch_server(self.community_id, populate_channels=True) + # OR if force=True is passed. + if force: + self._server = await self.client.fetch_server(self.community_id, populate_channels=populate_channels) + elif populate_channels: + if not self._server or not hasattr(self._server, "channels") or not self._server.channels: + self._server = await self.client.fetch_server(self.community_id, populate_channels=True) elif not self._server: self._server = await self.client.fetch_server(self.community_id, populate_channels=False) return self._server @@ -235,9 +237,6 @@ class StoatWriter: try: if type == 4: # Category # The POST /categories endpoint throws 404 on some server versions, so we use server.edit(categories) - # Force refetch to ensure we have the absolute latest state before editing categories array - server = await self._get_server(populate_channels=True, force=True) - import random import time chars = "0123456789ABCDEFGHJKMNPQRSTVWXYZ" @@ -255,8 +254,8 @@ class StoatWriter: if not hasattr(new_cat, "role_permissions"): new_cat.role_permissions = {} categories.append(new_cat) - # server.edit returns a new Server object on some versions/implementations; maintain local reference - self._server = await server.edit(categories=categories) + await server.edit(categories=categories) + self._server = None # Clear cache after structural change return new_id else: # Text Channel ch = await server.create_text_channel(name=name, description=topic) @@ -283,6 +282,7 @@ class StoatWriter: if edit_kwargs: await channel.edit(**edit_kwargs) + self._server = None # Clear cache # clone_server.py now handles all parenting bulk logic return True diff --git a/src/ui/modals.py b/src/ui/modals.py index 062c259..7b7b1a5 100644 --- a/src/ui/modals.py +++ b/src/ui/modals.py @@ -110,7 +110,7 @@ class ProgressScreen(Screen[None]): yield Button("Back", id="btn_back", disabled=False) yield Button("Main Menu", id="btn_main_menu", disabled=False) with Horizontal(classes="action_row", id="prog_actions_cancel"): - yield Button("Cancel", id="btn_cancel", variant="error", tooltip="Stop current operation") + yield Button("Back", id="btn_cancel", variant="default", tooltip="Stop current operation") yield Footer() def __init__(self, log_level: str = "INFO", *args, **kwargs): @@ -157,18 +157,22 @@ class ProgressScreen(Screen[None]): self.confirm_future.set_result(btn_id) return - # If Cancel is pressed during operation, invoke callback and stay on screen + # If Cancel/Back is pressed during operation or loading if btn_id == "btn_cancel": if self.cancel_callback: self.cancel_callback() - - # Show cancelling message and disable button - self.set_status("[bold red]Cancelling... waiting for tasks to finish...[/bold red]") - try: - event.button.disabled = True - event.button.label = "Stopping..." - except Exception: - pass + # Show cancelling message and disable button + self.set_status("[bold red]Cancelling... waiting for tasks to finish...[/bold red]") + try: + event.button.disabled = True + event.button.label = "Stopping..." + except Exception: + pass + else: + # If no callback, just dismiss (likely in Loading phase) + if self.timer_event: + self.timer_event.stop() + self.dismiss("btn_back") return # If operation is done (report phase), just dismiss with the action @@ -312,7 +316,11 @@ class ProgressScreen(Screen[None]): except Exception: pass # Show Cancel button - try: self.query_one("#prog_actions_cancel", Horizontal).display = True + try: + cancel_btn = self.query_one("#btn_cancel", Button) + cancel_btn.label = "Cancel" + cancel_btn.variant = "error" + self.query_one("#prog_actions_cancel", Horizontal).display = True except Exception: pass try: self.query_one("#prog_loader", LoadingIndicator).display = True