fix channel positions in stoat
This commit is contained in:
parent
0aa32217b1
commit
8cdbcab5f6
3 changed files with 84 additions and 66 deletions
|
|
@ -63,6 +63,10 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
||||||
categories = await context.discord_reader.get_categories()
|
categories = await context.discord_reader.get_categories()
|
||||||
channels = await context.discord_reader.get_channels()
|
channels = await context.discord_reader.get_channels()
|
||||||
|
|
||||||
|
# Sort categories and channels by position to preserve order
|
||||||
|
categories = sorted(categories, key=lambda c: getattr(c, 'position', 0))
|
||||||
|
channels = sorted(channels, key=lambda c: getattr(c, 'position', 0))
|
||||||
|
|
||||||
cloned_info = {
|
cloned_info = {
|
||||||
"categories_created": [],
|
"categories_created": [],
|
||||||
"channels_created": [],
|
"channels_created": [],
|
||||||
|
|
@ -73,13 +77,8 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
||||||
|
|
||||||
# 1. Identify categories to create
|
# 1. Identify categories to create
|
||||||
missing_categories = [cat for cat in categories if force or not context.state.get_target_category_id(str(cat.id))]
|
missing_categories = [cat for cat in categories if force or not context.state.get_target_category_id(str(cat.id))]
|
||||||
missing_category_ids = {str(cat.id) for cat in missing_categories}
|
|
||||||
|
|
||||||
# 2. Identify channels to create or move
|
# 2. Identify channels to create or move
|
||||||
# Fetch current Target state to check parent_ids
|
|
||||||
target_channels = await context.writer.get_channels()
|
|
||||||
target_parent_map = {str(c["id"]): (str(c.get("parent_id")) if c.get("parent_id") else None) for c in target_channels}
|
|
||||||
|
|
||||||
channels_to_create = []
|
channels_to_create = []
|
||||||
channels_to_move = []
|
channels_to_move = []
|
||||||
|
|
||||||
|
|
@ -88,11 +87,8 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
||||||
target_id = context.state.get_target_channel_id(discord_id)
|
target_id = context.state.get_target_channel_id(discord_id)
|
||||||
|
|
||||||
if force or not target_id:
|
if force or not target_id:
|
||||||
# We'll resolve the parent_id in the loop after categories are created
|
|
||||||
channels_to_create.append(ch)
|
channels_to_create.append(ch)
|
||||||
else:
|
else:
|
||||||
# Always add to move/sync list to ensure properties (topic, nsfw, slowmode) are synced
|
|
||||||
# even if the parent category is already correct.
|
|
||||||
channels_to_move.append((ch, target_id))
|
channels_to_move.append((ch, target_id))
|
||||||
|
|
||||||
total = len(missing_categories) + len(channels_to_create) + len(channels_to_move)
|
total = len(missing_categories) + len(channels_to_create) + len(channels_to_move)
|
||||||
|
|
@ -102,7 +98,6 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
||||||
return cloned_info
|
return cloned_info
|
||||||
|
|
||||||
# 1. Migrate Categories
|
# 1. Migrate Categories
|
||||||
missing_category_ids = {str(cat.id) for cat in missing_categories}
|
|
||||||
for cat in missing_categories:
|
for cat in missing_categories:
|
||||||
if not context.is_running: break
|
if not context.is_running: break
|
||||||
|
|
||||||
|
|
@ -148,7 +143,6 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
||||||
# Sync properties immediately
|
# Sync properties immediately
|
||||||
await context.writer.modify_channel(
|
await context.writer.modify_channel(
|
||||||
channel_id=target_id,
|
channel_id=target_id,
|
||||||
parent_id=None,
|
|
||||||
name=channel.name,
|
name=channel.name,
|
||||||
topic=topic,
|
topic=topic,
|
||||||
nsfw=nsfw,
|
nsfw=nsfw,
|
||||||
|
|
@ -170,7 +164,6 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
||||||
|
|
||||||
await context.writer.modify_channel(
|
await context.writer.modify_channel(
|
||||||
channel_id=target_id,
|
channel_id=target_id,
|
||||||
parent_id=None,
|
|
||||||
name=channel.name,
|
name=channel.name,
|
||||||
topic=topic,
|
topic=topic,
|
||||||
nsfw=nsfw,
|
nsfw=nsfw,
|
||||||
|
|
@ -184,49 +177,66 @@ async def migrate_channels(context: MigrationContext, progress_callback: Callabl
|
||||||
|
|
||||||
# 4. Final step: Parent the channels into categories via mass server.edit()
|
# 4. Final step: Parent the channels into categories via mass server.edit()
|
||||||
logger.info("Parenting all channels into their respective categories...")
|
logger.info("Parenting all channels into their respective categories...")
|
||||||
# Force refetch to ensure we see all newly created categories from the loop above
|
# Force refresh server to get latest categories created during migration
|
||||||
server = await context.writer._get_server(populate_channels=True, force=True)
|
server = await context.writer._get_server(populate_channels=True, force=True)
|
||||||
cats = list(server.categories) if hasattr(server, "categories") and server.categories else []
|
|
||||||
|
|
||||||
# Workaround: Ensure default properties are set for all categories
|
# Stoat categories are managed via server.edit(categories=...)
|
||||||
for c in cats:
|
# We must preserve ALL categories that exist in Stoat, but update the ones we manage.
|
||||||
if not hasattr(c, "default_permissions"): c.default_permissions = None
|
target_categories = []
|
||||||
if not hasattr(c, "role_permissions"): c.role_permissions = {}
|
|
||||||
|
|
||||||
# We will build a map of target_cat_id -> list of target_ch_ids
|
# Build map of Discord Cat ID -> [Target Channel IDs]
|
||||||
cat_to_channels = {}
|
# Channels should already be sorted by position in the 'channels' list
|
||||||
for cat in categories:
|
discord_cat_to_channels = {}
|
||||||
target_cat_id = context.state.get_target_category_id(str(cat.id))
|
for ch in channels:
|
||||||
if not target_cat_id: continue
|
cat_id = str(getattr(ch, 'category_id', '')) if ch.category_id else "unparented"
|
||||||
|
if cat_id not in discord_cat_to_channels:
|
||||||
|
discord_cat_to_channels[cat_id] = []
|
||||||
|
|
||||||
|
target_id = context.state.get_target_channel_id(str(ch.id))
|
||||||
|
if target_id:
|
||||||
|
discord_cat_to_channels[cat_id].append(target_id)
|
||||||
|
|
||||||
target_channels_for_cat = []
|
# Resolve Stoat categories
|
||||||
for ch in channels:
|
# We iterate over the categories from the server to ensure we don't drop any
|
||||||
if str(getattr(ch, 'category_id', '')) == str(cat.id):
|
for stoat_cat in server.categories:
|
||||||
target_ch_id = context.state.get_target_channel_id(str(ch.id))
|
# Check if this Stoat category maps to any Discord category
|
||||||
if target_ch_id:
|
discord_cat_id = next((d_id for d_id, t_id in context.state.category_map.items() if t_id == str(stoat_cat.id)), None)
|
||||||
target_channels_for_cat.append(target_ch_id)
|
|
||||||
|
|
||||||
cat_to_channels[target_cat_id] = target_channels_for_cat
|
if discord_cat_id:
|
||||||
|
# Managed category - update its channels
|
||||||
# Now correctly assign them in the cats array, and remove them from other cats
|
new_channels = discord_cat_to_channels.get(discord_cat_id, [])
|
||||||
all_assigned_channels = set()
|
new_cat = stoat.Category(
|
||||||
for cat_id, ch_list in cat_to_channels.items():
|
id=str(stoat_cat.id),
|
||||||
all_assigned_channels.update(ch_list)
|
title=stoat_cat.title,
|
||||||
|
channels=new_channels
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Unmanaged category - preserve as is but ensure all channels are strings
|
||||||
|
new_cat = stoat.Category(
|
||||||
|
id=str(stoat_cat.id),
|
||||||
|
title=stoat_cat.title,
|
||||||
|
channels=[str(c_id) for c_id in stoat_cat.channels]
|
||||||
|
)
|
||||||
|
|
||||||
for i, c in enumerate(cats):
|
# Workaround for stoat.py missing attributes
|
||||||
# Remove any channels that are being assigned to a specific category
|
if not hasattr(new_cat, "default_permissions"): new_cat.default_permissions = getattr(stoat_cat, "default_permissions", None)
|
||||||
new_channels = [ch for ch in c.channels if ch not in all_assigned_channels]
|
if not hasattr(new_cat, "role_permissions"): new_cat.role_permissions = getattr(stoat_cat, "role_permissions", {})
|
||||||
if c.id in cat_to_channels:
|
|
||||||
# If this is one of our managed categories, set its channels to exactly what it should be
|
|
||||||
new_channels = cat_to_channels[c.id]
|
|
||||||
|
|
||||||
new_cat = stoat.Category(id=c.id, title=c.title, channels=new_channels)
|
|
||||||
new_cat.default_permissions = c.default_permissions
|
|
||||||
new_cat.role_permissions = c.role_permissions
|
|
||||||
cats[i] = new_cat
|
|
||||||
|
|
||||||
|
target_categories.append(new_cat)
|
||||||
|
|
||||||
|
# Sort target_categories based on the original Discord category positions
|
||||||
|
def get_cat_position(s_cat):
|
||||||
|
d_id = next((d_id for d_id, t_id in context.state.category_map.items() if t_id == str(s_cat.id)), None)
|
||||||
|
if d_id:
|
||||||
|
d_cat = next((c for c in categories if str(c.id) == d_id), None)
|
||||||
|
if d_cat:
|
||||||
|
return getattr(d_cat, 'position', 999)
|
||||||
|
return 999
|
||||||
|
|
||||||
|
target_categories.sort(key=get_cat_position)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await server.edit(categories=cats)
|
await server.edit(categories=target_categories)
|
||||||
logger.info("Successfully parented all channels.")
|
logger.info("Successfully parented all channels.")
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.error(f"Failed to mass parent channels: {ex}")
|
logger.error(f"Failed to mass parent channels: {ex}")
|
||||||
|
|
|
||||||
|
|
@ -97,10 +97,12 @@ class StoatWriter:
|
||||||
|
|
||||||
async def _get_server(self, populate_channels=False, force=False):
|
async def _get_server(self, populate_channels=False, force=False):
|
||||||
# Always refetch if channels are requested AND we don't already have them
|
# Always refetch if channels are requested AND we don't already have them
|
||||||
# Or if force is True (e.g. after category creation/mutation)
|
# OR if force=True is passed.
|
||||||
# Stoat Server objects use __slots__, so we can't easily add our own tracking attributes.
|
if force:
|
||||||
if force or (populate_channels and (not self._server or not hasattr(self._server, "channels") or not self._server.channels)):
|
self._server = await self.client.fetch_server(self.community_id, populate_channels=populate_channels)
|
||||||
self._server = await self.client.fetch_server(self.community_id, populate_channels=True)
|
elif populate_channels:
|
||||||
|
if not self._server or not hasattr(self._server, "channels") or not self._server.channels:
|
||||||
|
self._server = await self.client.fetch_server(self.community_id, populate_channels=True)
|
||||||
elif not self._server:
|
elif not self._server:
|
||||||
self._server = await self.client.fetch_server(self.community_id, populate_channels=False)
|
self._server = await self.client.fetch_server(self.community_id, populate_channels=False)
|
||||||
return self._server
|
return self._server
|
||||||
|
|
@ -235,9 +237,6 @@ class StoatWriter:
|
||||||
try:
|
try:
|
||||||
if type == 4: # Category
|
if type == 4: # Category
|
||||||
# The POST /categories endpoint throws 404 on some server versions, so we use server.edit(categories)
|
# The POST /categories endpoint throws 404 on some server versions, so we use server.edit(categories)
|
||||||
# Force refetch to ensure we have the absolute latest state before editing categories array
|
|
||||||
server = await self._get_server(populate_channels=True, force=True)
|
|
||||||
|
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
chars = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"
|
chars = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"
|
||||||
|
|
@ -255,8 +254,8 @@ class StoatWriter:
|
||||||
if not hasattr(new_cat, "role_permissions"): new_cat.role_permissions = {}
|
if not hasattr(new_cat, "role_permissions"): new_cat.role_permissions = {}
|
||||||
categories.append(new_cat)
|
categories.append(new_cat)
|
||||||
|
|
||||||
# server.edit returns a new Server object on some versions/implementations; maintain local reference
|
await server.edit(categories=categories)
|
||||||
self._server = await server.edit(categories=categories)
|
self._server = None # Clear cache after structural change
|
||||||
return new_id
|
return new_id
|
||||||
else: # Text Channel
|
else: # Text Channel
|
||||||
ch = await server.create_text_channel(name=name, description=topic)
|
ch = await server.create_text_channel(name=name, description=topic)
|
||||||
|
|
@ -283,6 +282,7 @@ class StoatWriter:
|
||||||
|
|
||||||
if edit_kwargs:
|
if edit_kwargs:
|
||||||
await channel.edit(**edit_kwargs)
|
await channel.edit(**edit_kwargs)
|
||||||
|
self._server = None # Clear cache
|
||||||
|
|
||||||
# clone_server.py now handles all parenting bulk logic
|
# clone_server.py now handles all parenting bulk logic
|
||||||
return True
|
return True
|
||||||
|
|
|
||||||
|
|
@ -110,7 +110,7 @@ class ProgressScreen(Screen[None]):
|
||||||
yield Button("Back", id="btn_back", disabled=False)
|
yield Button("Back", id="btn_back", disabled=False)
|
||||||
yield Button("Main Menu", id="btn_main_menu", disabled=False)
|
yield Button("Main Menu", id="btn_main_menu", disabled=False)
|
||||||
with Horizontal(classes="action_row", id="prog_actions_cancel"):
|
with Horizontal(classes="action_row", id="prog_actions_cancel"):
|
||||||
yield Button("Cancel", id="btn_cancel", variant="error", tooltip="Stop current operation")
|
yield Button("Back", id="btn_cancel", variant="default", tooltip="Stop current operation")
|
||||||
yield Footer()
|
yield Footer()
|
||||||
|
|
||||||
def __init__(self, log_level: str = "INFO", *args, **kwargs):
|
def __init__(self, log_level: str = "INFO", *args, **kwargs):
|
||||||
|
|
@ -157,18 +157,22 @@ class ProgressScreen(Screen[None]):
|
||||||
self.confirm_future.set_result(btn_id)
|
self.confirm_future.set_result(btn_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
# If Cancel is pressed during operation, invoke callback and stay on screen
|
# If Cancel/Back is pressed during operation or loading
|
||||||
if btn_id == "btn_cancel":
|
if btn_id == "btn_cancel":
|
||||||
if self.cancel_callback:
|
if self.cancel_callback:
|
||||||
self.cancel_callback()
|
self.cancel_callback()
|
||||||
|
# Show cancelling message and disable button
|
||||||
# Show cancelling message and disable button
|
self.set_status("[bold red]Cancelling... waiting for tasks to finish...[/bold red]")
|
||||||
self.set_status("[bold red]Cancelling... waiting for tasks to finish...[/bold red]")
|
try:
|
||||||
try:
|
event.button.disabled = True
|
||||||
event.button.disabled = True
|
event.button.label = "Stopping..."
|
||||||
event.button.label = "Stopping..."
|
except Exception:
|
||||||
except Exception:
|
pass
|
||||||
pass
|
else:
|
||||||
|
# If no callback, just dismiss (likely in Loading phase)
|
||||||
|
if self.timer_event:
|
||||||
|
self.timer_event.stop()
|
||||||
|
self.dismiss("btn_back")
|
||||||
return
|
return
|
||||||
|
|
||||||
# If operation is done (report phase), just dismiss with the action
|
# If operation is done (report phase), just dismiss with the action
|
||||||
|
|
@ -312,7 +316,11 @@ class ProgressScreen(Screen[None]):
|
||||||
except Exception: pass
|
except Exception: pass
|
||||||
|
|
||||||
# Show Cancel button
|
# Show Cancel button
|
||||||
try: self.query_one("#prog_actions_cancel", Horizontal).display = True
|
try:
|
||||||
|
cancel_btn = self.query_one("#btn_cancel", Button)
|
||||||
|
cancel_btn.label = "Cancel"
|
||||||
|
cancel_btn.variant = "error"
|
||||||
|
self.query_one("#prog_actions_cancel", Horizontal).display = True
|
||||||
except Exception: pass
|
except Exception: pass
|
||||||
|
|
||||||
try: self.query_one("#prog_loader", LoadingIndicator).display = True
|
try: self.query_one("#prog_loader", LoadingIndicator).display = True
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue