From a7191427553be5b29e0d01554a516f7c8a9b1b2b Mon Sep 17 00:00:00 2001 From: rambros Date: Mon, 16 Mar 2026 00:42:45 +0530 Subject: [PATCH] add thread message tracking in migration --- src/core/database.py | 21 +++++- src/core/state.py | 19 +++++ src/fluxer/migrate_message.py | 132 +++++++++++++++++++++++---------- src/stoat/migrate_message.py | 133 ++++++++++++++++++++++++---------- src/ui/shuttle_ops.py | 2 + 5 files changed, 229 insertions(+), 78 deletions(-) diff --git a/src/core/database.py b/src/core/database.py index 6cb2ba6..801da00 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -72,9 +72,16 @@ class MigrationDatabase: last_msg_ts TEXT, msg_count INTEGER DEFAULT 0, file_count INTEGER DEFAULT 0, + completed INTEGER DEFAULT 0, PRIMARY KEY (channel_id, thread_id) ) """) + + # Add completed column if it doesn't exist (backward compatibility for existing resumption DBs) + try: + cursor.execute("ALTER TABLE thread_tracking ADD COLUMN completed INTEGER DEFAULT 0") + except sqlite3.OperationalError: + pass # Already exists # Table for entity mappings (channels, roles, etc.) cursor.execute(""" @@ -208,7 +215,7 @@ class MigrationDatabase: ).fetchone() return row["target_msg_id"] if row else None - def update_thread_tracking(self, channel_id: str, thread_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0): + def update_thread_tracking(self, channel_id: str, thread_id: str, last_msg_id: str = None, last_msg_ts: str = None, msg_inc: int = 0, file_inc: int = 0, completed: int = None): conn = self._get_conn() conn.execute("INSERT OR IGNORE INTO thread_tracking (channel_id, thread_id) VALUES (?, ?)", (channel_id, thread_id)) @@ -216,6 +223,8 @@ class MigrationDatabase: conn.execute("UPDATE thread_tracking SET last_msg_id = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_id, channel_id, thread_id)) if last_msg_ts: conn.execute("UPDATE thread_tracking SET last_msg_ts = ? WHERE channel_id = ? AND thread_id = ?", (last_msg_ts, channel_id, thread_id)) + if completed is not None: + conn.execute("UPDATE thread_tracking SET completed = ? WHERE channel_id = ? AND thread_id = ?", (completed, channel_id, thread_id)) if msg_inc != 0 or file_inc != 0: conn.execute( @@ -231,6 +240,16 @@ class MigrationDatabase: return dict(row) return {"last_msg_id": None, "last_msg_ts": None, "msg_count": 0, "file_count": 0} + def clear_channel_data(self, channel_id: str): + """Purge all mappings and tracking data for a specific channel and its threads.""" + conn = self._get_conn() + conn.execute("DELETE FROM message_mappings WHERE channel_id = ?", (channel_id,)) + conn.execute("DELETE FROM thread_mappings WHERE channel_id = ?", (channel_id,)) + conn.execute("DELETE FROM channel_tracking WHERE channel_id = ?", (channel_id,)) + conn.execute("DELETE FROM thread_tracking WHERE channel_id = ?", (channel_id,)) + conn.commit() + logger.info(f"Cleared all tracking and mapping data for channel: {channel_id}") + def close(self): if hasattr(self._local, "conn"): self._local.conn.close() diff --git a/src/core/state.py b/src/core/state.py index 08666c1..f2ec869 100644 --- a/src/core/state.py +++ b/src/core/state.py @@ -149,6 +149,16 @@ class MigrationState: if self._ensure_db(): self.db.update_thread_tracking(str(target_channel_id), str(thread_id), last_msg_id=str(message_id)) + def update_thread_completed(self, target_channel_id: str, thread_id: str, completed: bool = True): + if self._ensure_db(): + self.db.update_thread_tracking(str(target_channel_id), str(thread_id), completed=1 if completed else 0) + + def is_thread_completed(self, target_channel_id: str, thread_id: str) -> bool: + if self._ensure_db(): + tracking = self.db.get_thread_tracking(str(target_channel_id), str(thread_id)) + return bool(tracking.get("completed", 0)) + return False + def get_thread_message_id(self, target_channel_id: str, thread_id: str, discord_id: str) -> str | None: if self._ensure_db(): return self.db.get_target_thread_message_id(str(target_channel_id), str(thread_id), str(discord_id)) @@ -167,6 +177,11 @@ class MigrationState: return self.db.get_channel_tracking(str(target_channel_id)).get("last_msg_id") return None + def get_thread_last_message_id(self, target_channel_id: str, thread_id: str) -> str | None: + if self._ensure_db(): + return self.db.get_thread_tracking(str(target_channel_id), str(thread_id)).get("last_msg_id") + return None + def find_message_mapping(self, discord_id: str) -> tuple[str, str] | tuple[None, None]: if not self.db: return None, None @@ -204,6 +219,10 @@ class MigrationState: conn.execute("DELETE FROM thread_tracking") conn.commit() + def clear_channel_data(self, target_channel_id: str): + if self._ensure_db(): + self.db.clear_channel_data(str(target_channel_id)) + def set_folder(self, server_id: str, clean_name: str, base_dir: Path | str = ""): """ Initializes the SQLite database based on community name and ID. diff --git a/src/fluxer/migrate_message.py b/src/fluxer/migrate_message.py index 24e3377..a7802d3 100644 --- a/src/fluxer/migrate_message.py +++ b/src/fluxer/migrate_message.py @@ -150,8 +150,15 @@ async def analyze_migration(context: MigrationContext, source_channel_id: int, a if thread.id not in processed_threads: processed_threads.add(thread.id) stats["threads"] += 1 + + # Fetch last migrated message ID for this thread + target_channel_id = context.state.get_target_channel_id(str(source_channel_id)) + thread_after_id = None + if target_channel_id: + thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(thread.id)) + # Recursively count thread content - thread_stats = await analyze_migration(context, thread.id, processed_threads=processed_threads) + thread_stats = await analyze_migration(context, thread.id, after_message_id=int(thread_after_id) if thread_after_id else None, processed_threads=processed_threads) stats["messages"] += thread_stats["messages"] stats["attachments"] += thread_stats["attachments"] stats["threads"] += thread_stats["threads"] # Nested threads (rare in Discord but possible in forum channels) @@ -186,7 +193,14 @@ async def analyze_migration(context: MigrationContext, source_channel_id: int, a if t.id not in processed_threads: processed_threads.add(t.id) stats["threads"] += 1 - thread_stats = await analyze_migration(context, t.id, processed_threads=processed_threads) + + # Fetch last migrated message ID for this thread + target_channel_id = context.state.get_target_channel_id(str(source_channel_id)) + thread_after_id = None + if target_channel_id: + thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(t.id)) + + thread_stats = await analyze_migration(context, t.id, after_message_id=int(thread_after_id) if thread_after_id else None, processed_threads=processed_threads) stats["messages"] += thread_stats["messages"] stats["attachments"] += thread_stats["attachments"] stats["threads"] += thread_stats["threads"] @@ -224,7 +238,57 @@ async def migrate_messages( if processed_threads is None: processed_threads = set() + async def _process_missed_threads(): + """Helper to scan for threads not yet processed in the current scan.""" + if not context.is_running: + return + logger.info(f"Checking for missed or pending threads in channel {source_channel_id}...") + all_threads = await get_channel_threads(context.discord_reader, source_channel_id) + for t in all_threads: + if not context.is_running: + break + if t.id not in processed_threads: + processed_threads.add(t.id) + + # Skip if thread was already fully migrated in a previous run + if context.state.is_thread_completed(target_channel_id, str(t.id)): + logger.debug(f"Skipping already completed thread '{t.name}' (ID: {t.id})") + continue + + logger.info(f"Checking missed thread '{t.name}' (ID: {t.id})") + + # Fetch last migrated message ID for this thread + thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(t.id)) + if thread_after_id: + logger.info(f"Resuming missed/pending thread '{t.name}' from after message ID: {thread_after_id}") + + stats["threads"] += 1 + thread_stats = await migrate_messages( + context=context, + source_channel_id=t.id, + target_channel_id=target_channel_id, + after_message_id=int(thread_after_id) if thread_after_id else None, + thread_id=str(t.id), + parent_target_id=None, + thread_name=t.name, + processed_threads=processed_threads + ) + stats["messages"] += thread_stats["messages"] + stats["attachments"] += thread_stats["attachments"] + stats["threads"] += thread_stats["threads"] + + if context.is_running: + await context.fluxer_writer.send_marker( + channel_id=target_channel_id, + content=f"> <<< END OF THREAD >>>" + ) + try: + # If resuming (after_message_id is set) and at top level, check for pending threads FIRST + # to preserve chronological order (finish old unfinished business first) + if not thread_id and after_message_id is not None: + await _process_missed_threads() + async for msg in context.discord_reader.fetch_message_history(source_channel_id, after_id=after_message_id, inclusive=inclusive): if not context.is_running: logger.warning("Migration interrupted by user (is_running=False)") @@ -252,11 +316,17 @@ async def migrate_messages( # Track thread entry stats["threads"] += 1 + # Fetch last migrated message ID for this thread + thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(thread.id)) + if thread_after_id: + logger.info(f"Resuming thread '{thread.name}' from after message ID: {thread_after_id}") + # Migrate thread messages recursively thread_stats = await migrate_messages( context=context, source_channel_id=thread.id, target_channel_id=target_channel_id, + after_message_id=int(thread_after_id) if thread_after_id else None, thread_id=str(thread.id), parent_target_id=None, thread_name=thread.name, @@ -267,10 +337,11 @@ async def migrate_messages( stats["threads"] += thread_stats["threads"] # Send End Marker - await context.fluxer_writer.send_marker( - channel_id=target_channel_id, - content=f"> <<< END OF THREAD >>>" - ) + if context.is_running: + await context.fluxer_writer.send_marker( + channel_id=target_channel_id, + content=f"> <<< END OF THREAD >>>" + ) if progress_callback: await progress_callback(stats) @@ -483,11 +554,17 @@ async def migrate_messages( # Track thread entry stats["threads"] += 1 + # Fetch last migrated message ID for this thread + thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(thread.id)) + if thread_after_id: + logger.info(f"Resuming thread '{thread.name}' from after message ID: {thread_after_id}") + # Migrate thread messages recursively thread_stats = await migrate_messages( context=context, source_channel_id=thread.id, target_channel_id=target_channel_id, + after_message_id=int(thread_after_id) if thread_after_id else None, thread_id=str(thread.id), parent_target_id=fluxer_msg_id, thread_name=thread.name, @@ -498,10 +575,11 @@ async def migrate_messages( stats["threads"] += thread_stats["threads"] # Send End Marker - await context.fluxer_writer.send_marker( - channel_id=target_channel_id, - content=f"> <<< END OF THREAD >>>" - ) + if context.is_running: + await context.fluxer_writer.send_marker( + channel_id=target_channel_id, + content=f"> <<< END OF THREAD >>>" + ) # Update Link Tracking (but prevent threaded messages from overwriting the parent channel pointers) # The 'after_message_id' param usually means it's the main function call and not a thread recursive call @@ -516,34 +594,12 @@ async def migrate_messages( logger.error(f"Failed to process message {msg.id}: {e}") import traceback logger.error(traceback.format_exc()) - - # After scanning messages, explicitly check for any missed threads (e.g. archived or skipped in scan) - # Only do this at the top level - if not thread_id and (after_message_id is not None or inclusive or stats["messages"] > 0): - all_threads = await get_channel_threads(context.discord_reader, source_channel_id) - for t in all_threads: - if t.id not in processed_threads: - processed_threads.add(t.id) - logger.info(f"Migrating missed thread '{t.name}' (ID: {t.id})") - - stats["threads"] += 1 - thread_stats = await migrate_messages( - context=context, - source_channel_id=t.id, - target_channel_id=target_channel_id, - thread_id=str(t.id), - parent_target_id=None, - thread_name=t.name, - processed_threads=processed_threads - ) - stats["messages"] += thread_stats["messages"] - stats["attachments"] += thread_stats["attachments"] - stats["threads"] += thread_stats["threads"] - - await context.fluxer_writer.send_marker( - channel_id=target_channel_id, - content=f"> <<< END OF THREAD >>>" - ) + + # Mark thread as completed if we finished the loop without being interrupted + if thread_id and context.is_running: + context.state.update_thread_completed(target_channel_id, thread_id, completed=True) + logger.info(f"Thread '{thread_name}' (ID: {thread_id}) marked as completed.") + except (KeyboardInterrupt, asyncio.CancelledError): context.is_running = False diff --git a/src/stoat/migrate_message.py b/src/stoat/migrate_message.py index bcdb294..b2f5ea3 100644 --- a/src/stoat/migrate_message.py +++ b/src/stoat/migrate_message.py @@ -156,7 +156,14 @@ async def analyze_migration(context: MigrationContext, source_channel_id: int, a if thread.id not in processed_threads: processed_threads.add(thread.id) stats["threads"] += 1 - thread_stats = await analyze_migration(context, thread.id, processed_threads=processed_threads) + + # Fetch last migrated message ID for this thread + target_channel_id = context.state.get_target_channel_id(str(source_channel_id)) + thread_after_id = None + if target_channel_id: + thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(thread.id)) + + thread_stats = await analyze_migration(context, thread.id, after_message_id=int(thread_after_id) if thread_after_id else None, processed_threads=processed_threads) stats["messages"] += thread_stats["messages"] stats["attachments"] += thread_stats["attachments"] stats["threads"] += thread_stats["threads"] @@ -189,7 +196,14 @@ async def analyze_migration(context: MigrationContext, source_channel_id: int, a if t.id not in processed_threads: processed_threads.add(t.id) stats["threads"] += 1 - thread_stats = await analyze_migration(context, t.id, processed_threads=processed_threads) + + # Fetch last migrated message ID for this thread + target_channel_id = context.state.get_target_channel_id(str(source_channel_id)) + thread_after_id = None + if target_channel_id: + thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(t.id)) + + thread_stats = await analyze_migration(context, t.id, after_message_id=int(thread_after_id) if thread_after_id else None, processed_threads=processed_threads) stats["messages"] += thread_stats["messages"] stats["attachments"] += thread_stats["attachments"] stats["threads"] += thread_stats["threads"] @@ -226,8 +240,57 @@ async def migrate_messages( if processed_threads is None: processed_threads = set() - + + async def _process_missed_threads(): + """Helper to scan for threads not yet processed in the current scan.""" + if not context.is_running: + return + logger.info(f"Checking for missed or pending threads in channel {source_channel_id}...") + all_threads = await get_channel_threads(context.discord_reader, source_channel_id) + for t in all_threads: + if not context.is_running: + break + if t.id not in processed_threads: + processed_threads.add(t.id) + + # Skip if thread was already fully migrated in a previous run + if context.state.is_thread_completed(target_channel_id, str(t.id)): + logger.debug(f"Skipping already completed thread '{t.name}' (ID: {t.id})") + continue + + logger.info(f"Checking missed thread '{t.name}' (ID: {t.id})") + + # Fetch last migrated message ID for this thread + thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(t.id)) + if thread_after_id: + logger.info(f"Resuming missed/pending thread '{t.name}' from after message ID: {thread_after_id}") + + stats["threads"] += 1 + thread_stats = await migrate_messages( + context=context, + source_channel_id=t.id, + target_channel_id=target_channel_id, + after_message_id=int(thread_after_id) if thread_after_id else None, + thread_id=str(t.id), + parent_target_id=None, + thread_name=t.name, + processed_threads=processed_threads + ) + stats["messages"] += thread_stats["messages"] + stats["attachments"] += thread_stats["attachments"] + stats["threads"] += thread_stats["threads"] + + if context.is_running: + await context.stoat_writer.send_marker( + channel_id=target_channel_id, + content=f"> <<< END OF THREAD >>>" + ) + try: + # If resuming (after_message_id is set) and at top level, check for pending threads FIRST + # to preserve chronological order (finish old unfinished business first) + if not thread_id and after_message_id is not None: + await _process_missed_threads() async for msg in context.discord_reader.fetch_message_history(source_channel_id, after_id=after_message_id, inclusive=inclusive): if not context.is_running: logger.warning("Migration interrupted by user (is_running=False)") @@ -256,11 +319,17 @@ async def migrate_messages( # Track thread entry stats["threads"] += 1 + # Fetch last migrated message ID for this thread + thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(thread.id)) + if thread_after_id: + logger.info(f"Resuming thread '{thread.name}' from after message ID: {thread_after_id}") + # Migrate thread messages recursively thread_stats = await migrate_messages( context=context, source_channel_id=thread.id, target_channel_id=target_channel_id, + after_message_id=int(thread_after_id) if thread_after_id else None, thread_id=str(thread.id), parent_target_id=None, thread_name=thread.name, @@ -271,10 +340,11 @@ async def migrate_messages( stats["threads"] += thread_stats["threads"] # Send End Marker - await context.stoat_writer.send_marker( - channel_id=target_channel_id, - content=f"> <<< END OF THREAD >>>" - ) + if context.is_running: + await context.stoat_writer.send_marker( + channel_id=target_channel_id, + content=f"> <<< END OF THREAD >>>" + ) if progress_callback: await progress_callback(stats) @@ -478,12 +548,18 @@ async def migrate_messages( processed_threads.add(thread.id) # Track thread entry stats["threads"] += 1 - + + # Fetch last migrated message ID for this thread + thread_after_id = context.state.get_thread_last_message_id(target_channel_id, str(thread.id)) + if thread_after_id: + logger.info(f"Resuming thread '{thread.name}' from after message ID: {thread_after_id}") + # Migrate thread messages recursively thread_stats = await migrate_messages( context=context, source_channel_id=thread.id, target_channel_id=target_channel_id, + after_message_id=int(thread_after_id) if thread_after_id else None, thread_id=str(thread.id), parent_target_id=stoat_msg_id, thread_name=thread.name, @@ -494,10 +570,11 @@ async def migrate_messages( stats["threads"] += thread_stats["threads"] # Send End Marker - await context.stoat_writer.send_marker( - channel_id=target_channel_id, - content=f"> <<< END OF THREAD >>>" - ) + if context.is_running: + await context.stoat_writer.send_marker( + channel_id=target_channel_id, + content=f"> <<< END OF THREAD >>>" + ) # Update Link Tracking if not stats["first_message_url"]: @@ -514,33 +591,11 @@ async def migrate_messages( import traceback logger.error(traceback.format_exc()) - # After scanning messages, explicitly check for any missed threads (e.g. archived or skipped in scan) - # Only do this at the top level - if not thread_id and (after_message_id is not None or inclusive or stats["messages"] > 0): - all_threads = await get_channel_threads(context.discord_reader, source_channel_id) - for t in all_threads: - if t.id not in processed_threads: - processed_threads.add(t.id) - logger.info(f"Migrating missed thread '{t.name}' (ID: {t.id})") - - stats["threads"] += 1 - thread_stats = await migrate_messages( - context=context, - source_channel_id=t.id, - target_channel_id=target_channel_id, - thread_id=str(t.id), - parent_target_id=None, - thread_name=t.name, - processed_threads=processed_threads - ) - stats["messages"] += thread_stats["messages"] - stats["attachments"] += thread_stats["attachments"] - stats["threads"] += thread_stats["threads"] - - await context.stoat_writer.send_marker( - channel_id=target_channel_id, - content=f"> <<< END OF THREAD >>>" - ) + # Mark thread as completed if we finished the loop without being interrupted + if thread_id and context.is_running: + context.state.update_thread_completed(target_channel_id, thread_id, completed=True) + logger.info(f"Thread '{thread_name}' (ID: {thread_id}) marked as completed.") + except (KeyboardInterrupt, asyncio.CancelledError): context.is_running = False diff --git a/src/ui/shuttle_ops.py b/src/ui/shuttle_ops.py index b3cef33..d21ec7f 100644 --- a/src/ui/shuttle_ops.py +++ b/src/ui/shuttle_ops.py @@ -1167,6 +1167,8 @@ class OperationPane(Container): else: logger.info("Proceeding with 'Start from First' (clean sink).") after_id = None + # Clear previous tracking data for this channel + self.engine.state.clear_channel_data(target_channel.get("id")) is_inclusive = (choice == "btn_start_id")