fix channel positions in stoat
This commit is contained in:
parent
0aa32217b1
commit
8cdbcab5f6
3 changed files with 84 additions and 66 deletions
|
|
@ -63,6 +63,10 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
|||
categories = await context.discord_reader.get_categories()
|
||||
channels = await context.discord_reader.get_channels()
|
||||
|
||||
# Sort categories and channels by position to preserve order
|
||||
categories = sorted(categories, key=lambda c: getattr(c, 'position', 0))
|
||||
channels = sorted(channels, key=lambda c: getattr(c, 'position', 0))
|
||||
|
||||
cloned_info = {
|
||||
"categories_created": [],
|
||||
"channels_created": [],
|
||||
|
|
@ -73,13 +77,8 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
|||
|
||||
# 1. Identify categories to create
|
||||
missing_categories = [cat for cat in categories if force or not context.state.get_target_category_id(str(cat.id))]
|
||||
missing_category_ids = {str(cat.id) for cat in missing_categories}
|
||||
|
||||
# 2. Identify channels to create or move
|
||||
# Fetch current Target state to check parent_ids
|
||||
target_channels = await context.writer.get_channels()
|
||||
target_parent_map = {str(c["id"]): (str(c.get("parent_id")) if c.get("parent_id") else None) for c in target_channels}
|
||||
|
||||
channels_to_create = []
|
||||
channels_to_move = []
|
||||
|
||||
|
|
@ -88,11 +87,8 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
|||
target_id = context.state.get_target_channel_id(discord_id)
|
||||
|
||||
if force or not target_id:
|
||||
# We'll resolve the parent_id in the loop after categories are created
|
||||
channels_to_create.append(ch)
|
||||
else:
|
||||
# Always add to move/sync list to ensure properties (topic, nsfw, slowmode) are synced
|
||||
# even if the parent category is already correct.
|
||||
channels_to_move.append((ch, target_id))
|
||||
|
||||
total = len(missing_categories) + len(channels_to_create) + len(channels_to_move)
|
||||
|
|
@ -102,7 +98,6 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
|||
return cloned_info
|
||||
|
||||
# 1. Migrate Categories
|
||||
missing_category_ids = {str(cat.id) for cat in missing_categories}
|
||||
for cat in missing_categories:
|
||||
if not context.is_running: break
|
||||
|
||||
|
|
@ -148,7 +143,6 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
|||
# Sync properties immediately
|
||||
await context.writer.modify_channel(
|
||||
channel_id=target_id,
|
||||
parent_id=None,
|
||||
name=channel.name,
|
||||
topic=topic,
|
||||
nsfw=nsfw,
|
||||
|
|
@ -170,7 +164,6 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
|||
|
||||
await context.writer.modify_channel(
|
||||
channel_id=target_id,
|
||||
parent_id=None,
|
||||
name=channel.name,
|
||||
topic=topic,
|
||||
nsfw=nsfw,
|
||||
|
|
@ -184,49 +177,66 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
|||
|
||||
# 4. Final step: Parent the channels into categories via mass server.edit()
|
||||
logger.info("Parenting all channels into their respective categories...")
|
||||
# Force refetch to ensure we see all newly created categories from the loop above
|
||||
# Force refresh server to get latest categories created during migration
|
||||
server = await context.writer._get_server(populate_channels=True, force=True)
|
||||
cats = list(server.categories) if hasattr(server, "categories") and server.categories else []
|
||||
|
||||
# Workaround: Ensure default properties are set for all categories
|
||||
for c in cats:
|
||||
if not hasattr(c, "default_permissions"): c.default_permissions = None
|
||||
if not hasattr(c, "role_permissions"): c.role_permissions = {}
|
||||
# Stoat categories are managed via server.edit(categories=...)
|
||||
# We must preserve ALL categories that exist in Stoat, but update the ones we manage.
|
||||
target_categories = []
|
||||
|
||||
# We will build a map of target_cat_id -> list of target_ch_ids
|
||||
cat_to_channels = {}
|
||||
for cat in categories:
|
||||
target_cat_id = context.state.get_target_category_id(str(cat.id))
|
||||
if not target_cat_id: continue
|
||||
# Build map of Discord Cat ID -> [Target Channel IDs]
|
||||
# Channels should already be sorted by position in the 'channels' list
|
||||
discord_cat_to_channels = {}
|
||||
for ch in channels:
|
||||
cat_id = str(getattr(ch, 'category_id', '')) if ch.category_id else "unparented"
|
||||
if cat_id not in discord_cat_to_channels:
|
||||
discord_cat_to_channels[cat_id] = []
|
||||
|
||||
target_id = context.state.get_target_channel_id(str(ch.id))
|
||||
if target_id:
|
||||
discord_cat_to_channels[cat_id].append(target_id)
|
||||
|
||||
target_channels_for_cat = []
|
||||
for ch in channels:
|
||||
if str(getattr(ch, 'category_id', '')) == str(cat.id):
|
||||
target_ch_id = context.state.get_target_channel_id(str(ch.id))
|
||||
if target_ch_id:
|
||||
target_channels_for_cat.append(target_ch_id)
|
||||
# Resolve Stoat categories
|
||||
# We iterate over the categories from the server to ensure we don't drop any
|
||||
for stoat_cat in server.categories:
|
||||
# Check if this Stoat category maps to any Discord category
|
||||
discord_cat_id = next((d_id for d_id, t_id in context.state.category_map.items() if t_id == str(stoat_cat.id)), None)
|
||||
|
||||
cat_to_channels[target_cat_id] = target_channels_for_cat
|
||||
|
||||
# Now correctly assign them in the cats array, and remove them from other cats
|
||||
all_assigned_channels = set()
|
||||
for cat_id, ch_list in cat_to_channels.items():
|
||||
all_assigned_channels.update(ch_list)
|
||||
if discord_cat_id:
|
||||
# Managed category - update its channels
|
||||
new_channels = discord_cat_to_channels.get(discord_cat_id, [])
|
||||
new_cat = stoat.Category(
|
||||
id=str(stoat_cat.id),
|
||||
title=stoat_cat.title,
|
||||
channels=new_channels
|
||||
)
|
||||
else:
|
||||
# Unmanaged category - preserve as is but ensure all channels are strings
|
||||
new_cat = stoat.Category(
|
||||
id=str(stoat_cat.id),
|
||||
title=stoat_cat.title,
|
||||
channels=[str(c_id) for c_id in stoat_cat.channels]
|
||||
)
|
||||
|
||||
for i, c in enumerate(cats):
|
||||
# Remove any channels that are being assigned to a specific category
|
||||
new_channels = [ch for ch in c.channels if ch not in all_assigned_channels]
|
||||
if c.id in cat_to_channels:
|
||||
# If this is one of our managed categories, set its channels to exactly what it should be
|
||||
new_channels = cat_to_channels[c.id]
|
||||
|
||||
new_cat = stoat.Category(id=c.id, title=c.title, channels=new_channels)
|
||||
new_cat.default_permissions = c.default_permissions
|
||||
new_cat.role_permissions = c.role_permissions
|
||||
cats[i] = new_cat
|
||||
# Workaround for stoat.py missing attributes
|
||||
if not hasattr(new_cat, "default_permissions"): new_cat.default_permissions = getattr(stoat_cat, "default_permissions", None)
|
||||
if not hasattr(new_cat, "role_permissions"): new_cat.role_permissions = getattr(stoat_cat, "role_permissions", {})
|
||||
|
||||
target_categories.append(new_cat)
|
||||
|
||||
# Sort target_categories based on the original Discord category positions
|
||||
def get_cat_position(s_cat):
|
||||
d_id = next((d_id for d_id, t_id in context.state.category_map.items() if t_id == str(s_cat.id)), None)
|
||||
if d_id:
|
||||
d_cat = next((c for c in categories if str(c.id) == d_id), None)
|
||||
if d_cat:
|
||||
return getattr(d_cat, 'position', 999)
|
||||
return 999
|
||||
|
||||
target_categories.sort(key=get_cat_position)
|
||||
|
||||
try:
|
||||
await server.edit(categories=cats)
|
||||
await server.edit(categories=target_categories)
|
||||
logger.info("Successfully parented all channels.")
|
||||
except Exception as ex:
|
||||
logger.error(f"Failed to mass parent channels: {ex}")
|
||||
|
|
|
|||
|
|
@ -97,10 +97,12 @@ class StoatWriter:
|
|||
|
||||
async def _get_server(self, populate_channels=False, force=False):
|
||||
# Always refetch if channels are requested AND we don't already have them
|
||||
# Or if force is True (e.g. after category creation/mutation)
|
||||
# Stoat Server objects use __slots__, so we can't easily add our own tracking attributes.
|
||||
if force or (populate_channels and (not self._server or not hasattr(self._server, "channels") or not self._server.channels)):
|
||||
self._server = await self.client.fetch_server(self.community_id, populate_channels=True)
|
||||
# OR if force=True is passed.
|
||||
if force:
|
||||
self._server = await self.client.fetch_server(self.community_id, populate_channels=populate_channels)
|
||||
elif populate_channels:
|
||||
if not self._server or not hasattr(self._server, "channels") or not self._server.channels:
|
||||
self._server = await self.client.fetch_server(self.community_id, populate_channels=True)
|
||||
elif not self._server:
|
||||
self._server = await self.client.fetch_server(self.community_id, populate_channels=False)
|
||||
return self._server
|
||||
|
|
@ -235,9 +237,6 @@ class StoatWriter:
|
|||
try:
|
||||
if type == 4: # Category
|
||||
# The POST /categories endpoint throws 404 on some server versions, so we use server.edit(categories)
|
||||
# Force refetch to ensure we have the absolute latest state before editing categories array
|
||||
server = await self._get_server(populate_channels=True, force=True)
|
||||
|
||||
import random
|
||||
import time
|
||||
chars = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"
|
||||
|
|
@ -255,8 +254,8 @@ class StoatWriter:
|
|||
if not hasattr(new_cat, "role_permissions"): new_cat.role_permissions = {}
|
||||
categories.append(new_cat)
|
||||
|
||||
# server.edit returns a new Server object on some versions/implementations; maintain local reference
|
||||
self._server = await server.edit(categories=categories)
|
||||
await server.edit(categories=categories)
|
||||
self._server = None # Clear cache after structural change
|
||||
return new_id
|
||||
else: # Text Channel
|
||||
ch = await server.create_text_channel(name=name, description=topic)
|
||||
|
|
@ -283,6 +282,7 @@ class StoatWriter:
|
|||
|
||||
if edit_kwargs:
|
||||
await channel.edit(**edit_kwargs)
|
||||
self._server = None # Clear cache
|
||||
|
||||
# clone_server.py now handles all parenting bulk logic
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ class ProgressScreen(Screen[None]):
|
|||
yield Button("Back", id="btn_back", disabled=False)
|
||||
yield Button("Main Menu", id="btn_main_menu", disabled=False)
|
||||
with Horizontal(classes="action_row", id="prog_actions_cancel"):
|
||||
yield Button("Cancel", id="btn_cancel", variant="error", tooltip="Stop current operation")
|
||||
yield Button("Back", id="btn_cancel", variant="default", tooltip="Stop current operation")
|
||||
yield Footer()
|
||||
|
||||
def __init__(self, log_level: str = "INFO", *args, **kwargs):
|
||||
|
|
@ -157,18 +157,22 @@ class ProgressScreen(Screen[None]):
|
|||
self.confirm_future.set_result(btn_id)
|
||||
return
|
||||
|
||||
# If Cancel is pressed during operation, invoke callback and stay on screen
|
||||
# If Cancel/Back is pressed during operation or loading
|
||||
if btn_id == "btn_cancel":
|
||||
if self.cancel_callback:
|
||||
self.cancel_callback()
|
||||
|
||||
# Show cancelling message and disable button
|
||||
self.set_status("[bold red]Cancelling... waiting for tasks to finish...[/bold red]")
|
||||
try:
|
||||
event.button.disabled = True
|
||||
event.button.label = "Stopping..."
|
||||
except Exception:
|
||||
pass
|
||||
# Show cancelling message and disable button
|
||||
self.set_status("[bold red]Cancelling... waiting for tasks to finish...[/bold red]")
|
||||
try:
|
||||
event.button.disabled = True
|
||||
event.button.label = "Stopping..."
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# If no callback, just dismiss (likely in Loading phase)
|
||||
if self.timer_event:
|
||||
self.timer_event.stop()
|
||||
self.dismiss("btn_back")
|
||||
return
|
||||
|
||||
# If operation is done (report phase), just dismiss with the action
|
||||
|
|
@ -312,7 +316,11 @@ class ProgressScreen(Screen[None]):
|
|||
except Exception: pass
|
||||
|
||||
# Show Cancel button
|
||||
try: self.query_one("#prog_actions_cancel", Horizontal).display = True
|
||||
try:
|
||||
cancel_btn = self.query_one("#btn_cancel", Button)
|
||||
cancel_btn.label = "Cancel"
|
||||
cancel_btn.variant = "error"
|
||||
self.query_one("#prog_actions_cancel", Horizontal).display = True
|
||||
except Exception: pass
|
||||
|
||||
try: self.query_one("#prog_loader", LoadingIndicator).display = True
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue