fix channel positions in stoat

This commit is contained in:
rambros 2026-03-09 12:49:47 +05:30
parent 0aa32217b1
commit 8cdbcab5f6
3 changed files with 84 additions and 66 deletions

View file

@ -63,6 +63,10 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
categories = await context.discord_reader.get_categories() categories = await context.discord_reader.get_categories()
channels = await context.discord_reader.get_channels() channels = await context.discord_reader.get_channels()
# Sort categories and channels by position to preserve order
categories = sorted(categories, key=lambda c: getattr(c, 'position', 0))
channels = sorted(channels, key=lambda c: getattr(c, 'position', 0))
cloned_info = { cloned_info = {
"categories_created": [], "categories_created": [],
"channels_created": [], "channels_created": [],
@ -73,13 +77,8 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
# 1. Identify categories to create # 1. Identify categories to create
missing_categories = [cat for cat in categories if force or not context.state.get_target_category_id(str(cat.id))] missing_categories = [cat for cat in categories if force or not context.state.get_target_category_id(str(cat.id))]
missing_category_ids = {str(cat.id) for cat in missing_categories}
# 2. Identify channels to create or move # 2. Identify channels to create or move
# Fetch current Target state to check parent_ids
target_channels = await context.writer.get_channels()
target_parent_map = {str(c["id"]): (str(c.get("parent_id")) if c.get("parent_id") else None) for c in target_channels}
channels_to_create = [] channels_to_create = []
channels_to_move = [] channels_to_move = []
@ -88,11 +87,8 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
target_id = context.state.get_target_channel_id(discord_id) target_id = context.state.get_target_channel_id(discord_id)
if force or not target_id: if force or not target_id:
# We'll resolve the parent_id in the loop after categories are created
channels_to_create.append(ch) channels_to_create.append(ch)
else: else:
# Always add to move/sync list to ensure properties (topic, nsfw, slowmode) are synced
# even if the parent category is already correct.
channels_to_move.append((ch, target_id)) channels_to_move.append((ch, target_id))
total = len(missing_categories) + len(channels_to_create) + len(channels_to_move) total = len(missing_categories) + len(channels_to_create) + len(channels_to_move)
@ -102,7 +98,6 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
return cloned_info return cloned_info
# 1. Migrate Categories # 1. Migrate Categories
missing_category_ids = {str(cat.id) for cat in missing_categories}
for cat in missing_categories: for cat in missing_categories:
if not context.is_running: break if not context.is_running: break
@ -148,7 +143,6 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
# Sync properties immediately # Sync properties immediately
await context.writer.modify_channel( await context.writer.modify_channel(
channel_id=target_id, channel_id=target_id,
parent_id=None,
name=channel.name, name=channel.name,
topic=topic, topic=topic,
nsfw=nsfw, nsfw=nsfw,
@ -170,7 +164,6 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
await context.writer.modify_channel( await context.writer.modify_channel(
channel_id=target_id, channel_id=target_id,
parent_id=None,
name=channel.name, name=channel.name,
topic=topic, topic=topic,
nsfw=nsfw, nsfw=nsfw,
@ -184,49 +177,66 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
# 4. Final step: Parent the channels into categories via mass server.edit() # 4. Final step: Parent the channels into categories via mass server.edit()
logger.info("Parenting all channels into their respective categories...") logger.info("Parenting all channels into their respective categories...")
# Force refetch to ensure we see all newly created categories from the loop above # Force refresh server to get latest categories created during migration
server = await context.writer._get_server(populate_channels=True, force=True) server = await context.writer._get_server(populate_channels=True, force=True)
cats = list(server.categories) if hasattr(server, "categories") and server.categories else []
# Workaround: Ensure default properties are set for all categories # Stoat categories are managed via server.edit(categories=...)
for c in cats: # We must preserve ALL categories that exist in Stoat, but update the ones we manage.
if not hasattr(c, "default_permissions"): c.default_permissions = None target_categories = []
if not hasattr(c, "role_permissions"): c.role_permissions = {}
# We will build a map of target_cat_id -> list of target_ch_ids # Build map of Discord Cat ID -> [Target Channel IDs]
cat_to_channels = {} # Channels should already be sorted by position in the 'channels' list
for cat in categories: discord_cat_to_channels = {}
target_cat_id = context.state.get_target_category_id(str(cat.id)) for ch in channels:
if not target_cat_id: continue cat_id = str(getattr(ch, 'category_id', '')) if ch.category_id else "unparented"
if cat_id not in discord_cat_to_channels:
discord_cat_to_channels[cat_id] = []
target_channels_for_cat = [] target_id = context.state.get_target_channel_id(str(ch.id))
for ch in channels: if target_id:
if str(getattr(ch, 'category_id', '')) == str(cat.id): discord_cat_to_channels[cat_id].append(target_id)
target_ch_id = context.state.get_target_channel_id(str(ch.id))
if target_ch_id:
target_channels_for_cat.append(target_ch_id)
cat_to_channels[target_cat_id] = target_channels_for_cat # Resolve Stoat categories
# We iterate over the categories from the server to ensure we don't drop any
for stoat_cat in server.categories:
# Check if this Stoat category maps to any Discord category
discord_cat_id = next((d_id for d_id, t_id in context.state.category_map.items() if t_id == str(stoat_cat.id)), None)
# Now correctly assign them in the cats array, and remove them from other cats if discord_cat_id:
all_assigned_channels = set() # Managed category - update its channels
for cat_id, ch_list in cat_to_channels.items(): new_channels = discord_cat_to_channels.get(discord_cat_id, [])
all_assigned_channels.update(ch_list) new_cat = stoat.Category(
id=str(stoat_cat.id),
title=stoat_cat.title,
channels=new_channels
)
else:
# Unmanaged category - preserve as is but ensure all channels are strings
new_cat = stoat.Category(
id=str(stoat_cat.id),
title=stoat_cat.title,
channels=[str(c_id) for c_id in stoat_cat.channels]
)
for i, c in enumerate(cats): # Workaround for stoat.py missing attributes
# Remove any channels that are being assigned to a specific category if not hasattr(new_cat, "default_permissions"): new_cat.default_permissions = getattr(stoat_cat, "default_permissions", None)
new_channels = [ch for ch in c.channels if ch not in all_assigned_channels] if not hasattr(new_cat, "role_permissions"): new_cat.role_permissions = getattr(stoat_cat, "role_permissions", {})
if c.id in cat_to_channels:
# If this is one of our managed categories, set its channels to exactly what it should be
new_channels = cat_to_channels[c.id]
new_cat = stoat.Category(id=c.id, title=c.title, channels=new_channels) target_categories.append(new_cat)
new_cat.default_permissions = c.default_permissions
new_cat.role_permissions = c.role_permissions # Sort target_categories based on the original Discord category positions
cats[i] = new_cat def get_cat_position(s_cat):
d_id = next((d_id for d_id, t_id in context.state.category_map.items() if t_id == str(s_cat.id)), None)
if d_id:
d_cat = next((c for c in categories if str(c.id) == d_id), None)
if d_cat:
return getattr(d_cat, 'position', 999)
return 999
target_categories.sort(key=get_cat_position)
try: try:
await server.edit(categories=cats) await server.edit(categories=target_categories)
logger.info("Successfully parented all channels.") logger.info("Successfully parented all channels.")
except Exception as ex: except Exception as ex:
logger.error(f"Failed to mass parent channels: {ex}") logger.error(f"Failed to mass parent channels: {ex}")

View file

@ -97,10 +97,12 @@ class StoatWriter:
async def _get_server(self, populate_channels=False, force=False): async def _get_server(self, populate_channels=False, force=False):
# Always refetch if channels are requested AND we don't already have them # Always refetch if channels are requested AND we don't already have them
# Or if force is True (e.g. after category creation/mutation) # OR if force=True is passed.
# Stoat Server objects use __slots__, so we can't easily add our own tracking attributes. if force:
if force or (populate_channels and (not self._server or not hasattr(self._server, "channels") or not self._server.channels)): self._server = await self.client.fetch_server(self.community_id, populate_channels=populate_channels)
self._server = await self.client.fetch_server(self.community_id, populate_channels=True) elif populate_channels:
if not self._server or not hasattr(self._server, "channels") or not self._server.channels:
self._server = await self.client.fetch_server(self.community_id, populate_channels=True)
elif not self._server: elif not self._server:
self._server = await self.client.fetch_server(self.community_id, populate_channels=False) self._server = await self.client.fetch_server(self.community_id, populate_channels=False)
return self._server return self._server
@ -235,9 +237,6 @@ class StoatWriter:
try: try:
if type == 4: # Category if type == 4: # Category
# The POST /categories endpoint throws 404 on some server versions, so we use server.edit(categories) # The POST /categories endpoint throws 404 on some server versions, so we use server.edit(categories)
# Force refetch to ensure we have the absolute latest state before editing categories array
server = await self._get_server(populate_channels=True, force=True)
import random import random
import time import time
chars = "0123456789ABCDEFGHJKMNPQRSTVWXYZ" chars = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"
@ -255,8 +254,8 @@ class StoatWriter:
if not hasattr(new_cat, "role_permissions"): new_cat.role_permissions = {} if not hasattr(new_cat, "role_permissions"): new_cat.role_permissions = {}
categories.append(new_cat) categories.append(new_cat)
# server.edit returns a new Server object on some versions/implementations; maintain local reference await server.edit(categories=categories)
self._server = await server.edit(categories=categories) self._server = None # Clear cache after structural change
return new_id return new_id
else: # Text Channel else: # Text Channel
ch = await server.create_text_channel(name=name, description=topic) ch = await server.create_text_channel(name=name, description=topic)
@ -283,6 +282,7 @@ class StoatWriter:
if edit_kwargs: if edit_kwargs:
await channel.edit(**edit_kwargs) await channel.edit(**edit_kwargs)
self._server = None # Clear cache
# clone_server.py now handles all parenting bulk logic # clone_server.py now handles all parenting bulk logic
return True return True

View file

@ -110,7 +110,7 @@ class ProgressScreen(Screen[None]):
yield Button("Back", id="btn_back", disabled=False) yield Button("Back", id="btn_back", disabled=False)
yield Button("Main Menu", id="btn_main_menu", disabled=False) yield Button("Main Menu", id="btn_main_menu", disabled=False)
with Horizontal(classes="action_row", id="prog_actions_cancel"): with Horizontal(classes="action_row", id="prog_actions_cancel"):
yield Button("Cancel", id="btn_cancel", variant="error", tooltip="Stop current operation") yield Button("Back", id="btn_cancel", variant="default", tooltip="Stop current operation")
yield Footer() yield Footer()
def __init__(self, log_level: str = "INFO", *args, **kwargs): def __init__(self, log_level: str = "INFO", *args, **kwargs):
@ -157,18 +157,22 @@ class ProgressScreen(Screen[None]):
self.confirm_future.set_result(btn_id) self.confirm_future.set_result(btn_id)
return return
# If Cancel is pressed during operation, invoke callback and stay on screen # If Cancel/Back is pressed during operation or loading
if btn_id == "btn_cancel": if btn_id == "btn_cancel":
if self.cancel_callback: if self.cancel_callback:
self.cancel_callback() self.cancel_callback()
# Show cancelling message and disable button
# Show cancelling message and disable button self.set_status("[bold red]Cancelling... waiting for tasks to finish...[/bold red]")
self.set_status("[bold red]Cancelling... waiting for tasks to finish...[/bold red]") try:
try: event.button.disabled = True
event.button.disabled = True event.button.label = "Stopping..."
event.button.label = "Stopping..." except Exception:
except Exception: pass
pass else:
# If no callback, just dismiss (likely in Loading phase)
if self.timer_event:
self.timer_event.stop()
self.dismiss("btn_back")
return return
# If operation is done (report phase), just dismiss with the action # If operation is done (report phase), just dismiss with the action
@ -312,7 +316,11 @@ class ProgressScreen(Screen[None]):
except Exception: pass except Exception: pass
# Show Cancel button # Show Cancel button
try: self.query_one("#prog_actions_cancel", Horizontal).display = True try:
cancel_btn = self.query_one("#btn_cancel", Button)
cancel_btn.label = "Cancel"
cancel_btn.variant = "error"
self.query_one("#prog_actions_cancel", Horizontal).display = True
except Exception: pass except Exception: pass
try: self.query_one("#prog_loader", LoadingIndicator).display = True try: self.query_one("#prog_loader", LoadingIndicator).display = True