fix forwarded messages
This commit is contained in:
parent
c2ece58209
commit
7bea5582e9
5 changed files with 367 additions and 123 deletions
|
|
@ -81,6 +81,7 @@ class MessageType(IntEnum):
|
|||
guild_incident_report_false_alarm = 39
|
||||
purchase_notification = 44
|
||||
poll_result = 46
|
||||
forward = 100
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -659,7 +660,7 @@ class BackupMessage:
|
|||
"Default": MessageType.default,
|
||||
"Reply": MessageType.reply,
|
||||
"ThreadStarter": MessageType.thread_starter_message,
|
||||
"Forward": MessageType.default,
|
||||
"Forward": 100,
|
||||
}
|
||||
|
||||
__slots__ = ("id", "type", "created_at", "pinned", "content", "author",
|
||||
|
|
@ -667,7 +668,7 @@ class BackupMessage:
|
|||
"reference", "thread", "channel_id", "flags", "guild", "channel",
|
||||
"mentions", "role_mentions", "channel_mentions", "mention_everyone",
|
||||
"tts", "nonce", "webhook_id", "application_id", "activity",
|
||||
"application", "interaction", "components", "jump_url")
|
||||
"application", "interaction", "components", "jump_url", "message_snapshots")
|
||||
def __init__(self, data: dict, *,
|
||||
author: BackupMember | None = None,
|
||||
guild: "BackupGuild | None" = None,
|
||||
|
|
@ -771,7 +772,10 @@ class BackupMessage:
|
|||
self.reference = type("Ref", (), {"message_id": parse_snowflake(data["message_reference"]), "channel_id": self.channel_id})()
|
||||
|
||||
self.thread = None
|
||||
self.flags = type("Flags", (), {"value": 0})()
|
||||
self.flags = type("Flags", (), {"value": 0, "forwarded": self.type == MessageType.forward})()
|
||||
|
||||
# snapshots not used in latest refined structure as content is in main message
|
||||
self.message_snapshots = []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"BackupMessage(id={self.id}, author={self.author})"
|
||||
|
|
@ -837,6 +841,13 @@ class BackupGuild:
|
|||
if self.banner:
|
||||
self.banner.url = data.get("banner_url")
|
||||
|
||||
@property
|
||||
def active_threads(self) -> List[BackupThread]:
|
||||
"""Returns all threads that are not archived."""
|
||||
if self._reader:
|
||||
return [t for t in self._reader._threads if not t.archived]
|
||||
return []
|
||||
|
||||
@property
|
||||
def roles(self) -> List[BackupRole]:
|
||||
return self._reader._roles if self._reader else []
|
||||
|
|
@ -904,6 +915,7 @@ class BackupReader:
|
|||
MESSAGE_TYPE_DEFAULT = MessageType.default
|
||||
MESSAGE_TYPE_REPLY = MessageType.reply
|
||||
MESSAGE_TYPE_THREAD_STARTER = MessageType.thread_starter_message
|
||||
MESSAGE_TYPE_FORWARD = MessageType.forward # Custom Reaper constant
|
||||
|
||||
Forbidden = BackupForbidden
|
||||
|
||||
|
|
@ -942,6 +954,7 @@ class BackupReader:
|
|||
self._categories: List[BackupCategory] = []
|
||||
self._channels: List[BackupChannel] = []
|
||||
self._threads: List[BackupThread] = []
|
||||
self._thread_map: Dict[int, BackupThread] = {} # Starter Message ID -> Thread
|
||||
self._roles: List[BackupRole] = []
|
||||
self._emojis: List[BackupEmoji] = []
|
||||
self._stickers: List[BackupSticker] = []
|
||||
|
|
@ -1007,6 +1020,8 @@ class BackupReader:
|
|||
for tdata in thread_rows:
|
||||
thread = BackupThread(tdata)
|
||||
self._threads.append(thread)
|
||||
if thread.id:
|
||||
self._thread_map[thread.id] = thread
|
||||
|
||||
# Resolve tag IDs to BackupTag objects using parent forum's available_tags
|
||||
if thread.applied_tags and thread.parent_id:
|
||||
|
|
@ -1208,7 +1223,7 @@ class BackupReader:
|
|||
channel_id = parse_snowflake(msg_data["channel_id"])
|
||||
channel = next((c for c in self.channels if c.id == channel_id), None)
|
||||
|
||||
return BackupMessage(
|
||||
bm = BackupMessage(
|
||||
msg_data,
|
||||
author=author,
|
||||
guild=self.guild,
|
||||
|
|
@ -1216,6 +1231,12 @@ class BackupReader:
|
|||
backup_root=self.backup_path,
|
||||
media_pool=self._media_pool
|
||||
)
|
||||
|
||||
# Link thread if this is a starter message
|
||||
if bm.id in self._thread_map:
|
||||
bm.thread = self._thread_map[bm.id]
|
||||
|
||||
return bm
|
||||
|
||||
async def get_message(self, channel_id: int, message_id: int) -> BackupMessage | None:
|
||||
"""Fetch a specific message from SQLite."""
|
||||
|
|
|
|||
|
|
@ -524,18 +524,47 @@ class DiscordExporter:
|
|||
})
|
||||
|
||||
# 5. Message data
|
||||
# Check for reference (reply)
|
||||
# Check for reference (reply/origin of forward)
|
||||
message_reference = None
|
||||
if msg.reference and msg.reference.message_id:
|
||||
message_reference = str(msg.reference.message_id)
|
||||
|
||||
# 5.5 Forwarded snapshots
|
||||
content = msg.content or ""
|
||||
msg_type = int(msg.type.value) if hasattr(msg.type, "value") else 0
|
||||
|
||||
# Detect if this message is forwarded (discord.py 2.5+)
|
||||
is_forwarded = getattr(msg.flags, 'forwarded', False)
|
||||
if is_forwarded and hasattr(msg, 'message_snapshots') and msg.message_snapshots:
|
||||
msg_type = 100 # Custom Forward type
|
||||
snapshot = msg.message_snapshots[0]
|
||||
if snapshot.content:
|
||||
content = snapshot.content
|
||||
|
||||
# Process snapshot attachments into main attachments list
|
||||
for s_att in snapshot.attachments:
|
||||
att_res = await self._process_media(
|
||||
media_id=s_att.id,
|
||||
url=s_att.url,
|
||||
filename=s_att.filename,
|
||||
size=s_att.size,
|
||||
content_type=s_att.content_type,
|
||||
save_method=s_att.save
|
||||
)
|
||||
if att_res: attachments.append(att_res)
|
||||
|
||||
# Process snapshot embeds (simplified)
|
||||
for s_emb in snapshot.embeds:
|
||||
# We reuse the main embeds list
|
||||
embeds.append(s_emb.to_dict())
|
||||
|
||||
m_data = {
|
||||
"id": str(msg.id),
|
||||
"channel_id": str(msg.channel.id),
|
||||
"author_id": user_id,
|
||||
"content": msg.content,
|
||||
"content": content,
|
||||
"timestamp": msg.created_at.isoformat(),
|
||||
"type": int(msg.type.value) if hasattr(msg.type, "value") else 0,
|
||||
"type": msg_type,
|
||||
"message_reference": message_reference,
|
||||
"is_pinned": 1 if msg.pinned else 0,
|
||||
"attachments": attachments,
|
||||
|
|
|
|||
|
|
@ -94,27 +94,73 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
|
|||
return content
|
||||
|
||||
|
||||
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None) -> Dict[str, int]:
|
||||
return content
|
||||
|
||||
|
||||
async def get_channel_threads(reader: Any, channel_id: int) -> List[Any]:
|
||||
"""Helper to fetch all threads (active and archived) for a channel from Live or Backup."""
|
||||
threads = []
|
||||
|
||||
# 1. From Backup (BackupReader has 'db' attribute)
|
||||
if hasattr(reader, 'db') and hasattr(reader, 'threads'):
|
||||
for t in reader.threads:
|
||||
if t.parent_id == channel_id:
|
||||
threads.append(t)
|
||||
return threads
|
||||
|
||||
# 2. From live Discord
|
||||
if hasattr(reader, 'guild') and reader.guild:
|
||||
try:
|
||||
# Guild-wide active threads
|
||||
if hasattr(reader.guild, 'active_threads'):
|
||||
for t in reader.guild.active_threads:
|
||||
if t.parent_id == channel_id:
|
||||
threads.append(t)
|
||||
|
||||
# Archived threads for this specific channel
|
||||
channel = await reader.get_channel(channel_id)
|
||||
if hasattr(channel, 'archived_threads'):
|
||||
# discord.py method
|
||||
async for t in channel.archived_threads(limit=None):
|
||||
threads.append(t)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not fetch live threads for {channel_id}: {e}")
|
||||
|
||||
return threads
|
||||
|
||||
|
||||
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None, processed_threads: set | None = None) -> Dict[str, int]:
|
||||
"""
|
||||
Scans channel history to count messages, threads, and attachments.
|
||||
"""
|
||||
stats = {"messages": 0, "threads": 0, "attachments": 0}
|
||||
|
||||
if processed_threads is None:
|
||||
processed_threads = set()
|
||||
|
||||
async for msg in context.discord_reader.fetch_message_history(source_channel_id, after_id=after_message_id, inclusive=inclusive):
|
||||
if not context.is_running:
|
||||
break
|
||||
|
||||
# Count thread messages and markers even if parent is skipped
|
||||
if hasattr(msg, 'thread') and msg.thread:
|
||||
stats["threads"] += 1
|
||||
# Recursively count thread content
|
||||
thread_stats = await analyze_migration(context, msg.thread.id)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"] # Nested threads (rare in Discord but possible in forum channels)
|
||||
thread = msg.thread
|
||||
if thread.id not in processed_threads:
|
||||
processed_threads.add(thread.id)
|
||||
stats["threads"] += 1
|
||||
# Recursively count thread content
|
||||
thread_stats = await analyze_migration(context, thread.id, processed_threads=processed_threads)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"] # Nested threads (rare in Discord but possible in forum channels)
|
||||
|
||||
# Consistent filtering with migrate_messages
|
||||
if msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY, context.discord_reader.MESSAGE_TYPE_THREAD_STARTER]:
|
||||
if msg.type not in [
|
||||
context.discord_reader.MESSAGE_TYPE_DEFAULT,
|
||||
context.discord_reader.MESSAGE_TYPE_REPLY,
|
||||
context.discord_reader.MESSAGE_TYPE_THREAD_STARTER,
|
||||
context.discord_reader.MESSAGE_TYPE_FORWARD
|
||||
]:
|
||||
continue
|
||||
|
||||
stats["messages"] += 1
|
||||
|
|
@ -123,6 +169,20 @@ async def analyze_migration(context: MigrationContext, source_channel_id: int, a
|
|||
if progress_callback and stats["messages"] % 10 == 0:
|
||||
await progress_callback(stats)
|
||||
|
||||
# After scanning messages, explicitly check for any missed threads (e.g. archived or skipped in scan)
|
||||
# Only do this at the top level (not in recursive thread calls)
|
||||
if after_message_id is not None or inclusive: # Usually top level calls have some start point
|
||||
# Optimization: We check all threads for the channel
|
||||
all_threads = await get_channel_threads(context.discord_reader, source_channel_id)
|
||||
for t in all_threads:
|
||||
if t.id not in processed_threads:
|
||||
processed_threads.add(t.id)
|
||||
stats["threads"] += 1
|
||||
thread_stats = await analyze_migration(context, t.id, processed_threads=processed_threads)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
|
|
@ -135,7 +195,8 @@ async def migrate_messages(
|
|||
progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None,
|
||||
thread_id: str | None = None,
|
||||
parent_target_id: str | None = None,
|
||||
thread_name: str | None = None
|
||||
thread_name: str | None = None,
|
||||
processed_threads: set | None = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Migrate messages for a specific channel and returns detailed statistics."""
|
||||
stats = {
|
||||
|
|
@ -152,6 +213,9 @@ async def migrate_messages(
|
|||
if after_message_id:
|
||||
logger.info(f"Resuming migration from after message ID: {after_message_id}")
|
||||
|
||||
if processed_threads is None:
|
||||
processed_threads = set()
|
||||
|
||||
try:
|
||||
async for msg in context.discord_reader.fetch_message_history(source_channel_id, after_id=after_message_id, inclusive=inclusive):
|
||||
if not context.is_running:
|
||||
|
|
@ -161,33 +225,39 @@ async def migrate_messages(
|
|||
|
||||
|
||||
# Skip system messages like "pinned a message", etc.
|
||||
if msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY, context.discord_reader.MESSAGE_TYPE_THREAD_STARTER]:
|
||||
if msg.type not in [
|
||||
context.discord_reader.MESSAGE_TYPE_DEFAULT,
|
||||
context.discord_reader.MESSAGE_TYPE_REPLY,
|
||||
context.discord_reader.MESSAGE_TYPE_THREAD_STARTER,
|
||||
context.discord_reader.MESSAGE_TYPE_FORWARD
|
||||
]:
|
||||
# If we are skipping the parent, we STILL need to check for a thread!
|
||||
if hasattr(msg, 'thread') and msg.thread:
|
||||
thread = msg.thread
|
||||
logger.info(f"Detected thread '{thread.name}' on skipped message {msg.id}")
|
||||
|
||||
# Track thread entry
|
||||
stats["threads"] += 1
|
||||
|
||||
# Migrate thread messages recursively
|
||||
thread_stats = await migrate_messages(
|
||||
context=context,
|
||||
source_channel_id=thread.id,
|
||||
target_channel_id=target_channel_id,
|
||||
thread_id=str(thread.id),
|
||||
parent_target_id=None,
|
||||
thread_name=thread.name
|
||||
)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
# Send End Marker
|
||||
await context.fluxer_writer.send_marker(
|
||||
channel_id=target_channel_id,
|
||||
content=f"> <<< END OF THREAD >>>"
|
||||
)
|
||||
if thread.id not in processed_threads:
|
||||
processed_threads.add(thread.id)
|
||||
# Track thread entry
|
||||
stats["threads"] += 1
|
||||
|
||||
# Migrate thread messages recursively
|
||||
thread_stats = await migrate_messages(
|
||||
context=context,
|
||||
source_channel_id=thread.id,
|
||||
target_channel_id=target_channel_id,
|
||||
thread_id=str(thread.id),
|
||||
parent_target_id=None,
|
||||
thread_name=thread.name,
|
||||
processed_threads=processed_threads
|
||||
)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
# Send End Marker
|
||||
await context.fluxer_writer.send_marker(
|
||||
channel_id=target_channel_id,
|
||||
content=f"> <<< END OF THREAD >>>"
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(stats)
|
||||
|
|
@ -378,29 +448,30 @@ async def migrate_messages(
|
|||
# Check for associated thread (Normal case: parent message is migrated)
|
||||
if hasattr(msg, 'thread') and msg.thread:
|
||||
thread = msg.thread
|
||||
logger.info(f"Detected thread '{thread.name}' on message {msg.id}")
|
||||
|
||||
# Track thread entry
|
||||
stats["threads"] += 1
|
||||
|
||||
# Migrate thread messages recursively
|
||||
thread_stats = await migrate_messages(
|
||||
context=context,
|
||||
source_channel_id=thread.id,
|
||||
target_channel_id=target_channel_id,
|
||||
thread_id=str(thread.id),
|
||||
parent_target_id=fluxer_msg_id,
|
||||
thread_name=thread.name
|
||||
)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
# Send End Marker
|
||||
await context.fluxer_writer.send_marker(
|
||||
channel_id=target_channel_id,
|
||||
content=f"> <<< END OF THREAD >>>"
|
||||
)
|
||||
if thread.id not in processed_threads:
|
||||
processed_threads.add(thread.id)
|
||||
# Track thread entry
|
||||
stats["threads"] += 1
|
||||
|
||||
# Migrate thread messages recursively
|
||||
thread_stats = await migrate_messages(
|
||||
context=context,
|
||||
source_channel_id=thread.id,
|
||||
target_channel_id=target_channel_id,
|
||||
thread_id=str(thread.id),
|
||||
parent_target_id=fluxer_msg_id,
|
||||
thread_name=thread.name,
|
||||
processed_threads=processed_threads
|
||||
)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
# Send End Marker
|
||||
await context.fluxer_writer.send_marker(
|
||||
channel_id=target_channel_id,
|
||||
content=f"> <<< END OF THREAD >>>"
|
||||
)
|
||||
|
||||
# Update Link Tracking (but prevent threaded messages from overwriting the parent channel pointers)
|
||||
# The 'after_message_id' param usually means it's the main function call and not a thread recursive call
|
||||
|
|
@ -415,7 +486,34 @@ async def migrate_messages(
|
|||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# Delay for rate limit safety
|
||||
# After scanning messages, explicitly check for any missed threads (e.g. archived or skipped in scan)
|
||||
# Only do this at the top level
|
||||
if not thread_id and (after_message_id is not None or inclusive or stats["messages"] > 0):
|
||||
all_threads = await get_channel_threads(context.discord_reader, source_channel_id)
|
||||
for t in all_threads:
|
||||
if t.id not in processed_threads:
|
||||
processed_threads.add(t.id)
|
||||
logger.info(f"Migrating missed thread '{t.name}' (ID: {t.id})")
|
||||
|
||||
stats["threads"] += 1
|
||||
thread_stats = await migrate_messages(
|
||||
context=context,
|
||||
source_channel_id=t.id,
|
||||
target_channel_id=target_channel_id,
|
||||
thread_id=str(t.id),
|
||||
parent_target_id=None,
|
||||
thread_name=t.name,
|
||||
processed_threads=processed_threads
|
||||
)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
await context.fluxer_writer.send_marker(
|
||||
channel_id=target_channel_id,
|
||||
content=f"> <<< END OF THREAD >>>"
|
||||
)
|
||||
|
||||
except (KeyboardInterrupt, asyncio.CancelledError):
|
||||
context.is_running = False
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -94,7 +94,42 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
|
|||
return content
|
||||
|
||||
|
||||
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None) -> Dict[str, int]:
|
||||
return content
|
||||
|
||||
|
||||
async def get_channel_threads(reader: Any, channel_id: int) -> List[Any]:
|
||||
"""Helper to fetch all threads (active and archived) for a channel from Live or Backup."""
|
||||
threads = []
|
||||
|
||||
# 1. From Backup (BackupReader has 'db' attribute)
|
||||
if hasattr(reader, 'db') and hasattr(reader, 'threads'):
|
||||
for t in reader.threads:
|
||||
if t.parent_id == channel_id:
|
||||
threads.append(t)
|
||||
return threads
|
||||
|
||||
# 2. From live Discord
|
||||
if hasattr(reader, 'guild') and reader.guild:
|
||||
try:
|
||||
# Guild-wide active threads
|
||||
if hasattr(reader.guild, 'active_threads'):
|
||||
for t in reader.guild.active_threads:
|
||||
if t.parent_id == channel_id:
|
||||
threads.append(t)
|
||||
|
||||
# Archived threads for this specific channel
|
||||
channel = await reader.get_channel(channel_id)
|
||||
if hasattr(channel, 'archived_threads'):
|
||||
# discord.py method
|
||||
async for t in channel.archived_threads(limit=None):
|
||||
threads.append(t)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not fetch live threads for {channel_id}: {e}")
|
||||
|
||||
return threads
|
||||
|
||||
|
||||
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None, processed_threads: set | None = None) -> Dict[str, int]:
|
||||
"""
|
||||
Scans channel history to count messages, threads, and attachments.
|
||||
"""
|
||||
|
|
@ -106,20 +141,31 @@ async def analyze_migration(context: MigrationContext, source_channel_id: int, a
|
|||
"last_message_url": ""
|
||||
}
|
||||
|
||||
if processed_threads is None:
|
||||
processed_threads = set()
|
||||
|
||||
async for msg in context.discord_reader.fetch_message_history(source_channel_id, after_id=after_message_id, inclusive=inclusive):
|
||||
if not context.is_running:
|
||||
break
|
||||
|
||||
# Count thread messages and markers even if parent is skipped
|
||||
if hasattr(msg, 'thread') and msg.thread:
|
||||
stats["threads"] += 1
|
||||
thread_stats = await analyze_migration(context, msg.thread.id)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
thread = msg.thread
|
||||
if thread.id not in processed_threads:
|
||||
processed_threads.add(thread.id)
|
||||
stats["threads"] += 1
|
||||
thread_stats = await analyze_migration(context, thread.id, processed_threads=processed_threads)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
# Consistent filtering with migrate_messages
|
||||
if msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY, context.discord_reader.MESSAGE_TYPE_THREAD_STARTER]:
|
||||
if msg.type not in [
|
||||
context.discord_reader.MESSAGE_TYPE_DEFAULT,
|
||||
context.discord_reader.MESSAGE_TYPE_REPLY,
|
||||
context.discord_reader.MESSAGE_TYPE_THREAD_STARTER,
|
||||
context.discord_reader.MESSAGE_TYPE_FORWARD
|
||||
]:
|
||||
continue
|
||||
|
||||
stats["messages"] += 1
|
||||
|
|
@ -128,6 +174,18 @@ async def analyze_migration(context: MigrationContext, source_channel_id: int, a
|
|||
if progress_callback and stats["messages"] % 10 == 0:
|
||||
await progress_callback(stats)
|
||||
|
||||
# After scanning messages, explicitly check for any missed threads (e.g. archived or skipped in scan)
|
||||
# Only do this at the top level
|
||||
if after_message_id is not None or inclusive:
|
||||
all_threads = await get_channel_threads(context.discord_reader, source_channel_id)
|
||||
for t in all_threads:
|
||||
if t.id not in processed_threads:
|
||||
processed_threads.add(t.id)
|
||||
stats["threads"] += 1
|
||||
thread_stats = await analyze_migration(context, t.id, processed_threads=processed_threads)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
return stats
|
||||
|
||||
|
|
@ -141,7 +199,8 @@ async def migrate_messages(
|
|||
progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None,
|
||||
thread_id: str | None = None,
|
||||
parent_target_id: str | None = None,
|
||||
thread_name: str | None = None
|
||||
thread_name: str | None = None,
|
||||
processed_threads: set | None = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Migrate messages for a specific channel using Stoat masquerade for author impersonation."""
|
||||
stats = {
|
||||
|
|
@ -158,6 +217,9 @@ async def migrate_messages(
|
|||
if after_message_id:
|
||||
logger.info(f"Resuming migration from after message ID: {after_message_id}")
|
||||
|
||||
if processed_threads is None:
|
||||
processed_threads = set()
|
||||
|
||||
try:
|
||||
async for msg in context.discord_reader.fetch_message_history(source_channel_id, after_id=after_message_id, inclusive=inclusive):
|
||||
if not context.is_running:
|
||||
|
|
@ -168,35 +230,39 @@ async def migrate_messages(
|
|||
|
||||
# Skip system messages like "pinned a message", etc.
|
||||
content = "" # Initialize content
|
||||
if msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY, context.discord_reader.MESSAGE_TYPE_THREAD_STARTER]:
|
||||
if msg.type not in [
|
||||
context.discord_reader.MESSAGE_TYPE_DEFAULT,
|
||||
context.discord_reader.MESSAGE_TYPE_REPLY,
|
||||
context.discord_reader.MESSAGE_TYPE_THREAD_STARTER,
|
||||
context.discord_reader.MESSAGE_TYPE_FORWARD
|
||||
]:
|
||||
# If we are skipping the parent, we STILL need to check for a thread!
|
||||
if hasattr(msg, 'thread') and msg.thread:
|
||||
thread = msg.thread
|
||||
logger.info(f"Detected thread '{thread.name}' on skipped message {msg.id}")
|
||||
|
||||
# Track thread entry
|
||||
stats["threads"] += 1
|
||||
|
||||
pass
|
||||
|
||||
# Migrate thread messages recursively
|
||||
thread_stats = await migrate_messages(
|
||||
context=context,
|
||||
source_channel_id=thread.id,
|
||||
target_channel_id=target_channel_id,
|
||||
thread_id=str(thread.id),
|
||||
parent_target_id=None,
|
||||
thread_name=thread.name
|
||||
)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
# Send End Marker
|
||||
await context.stoat_writer.send_marker(
|
||||
channel_id=target_channel_id,
|
||||
content=f"> <<< END OF THREAD >>>"
|
||||
)
|
||||
if thread.id not in processed_threads:
|
||||
processed_threads.add(thread.id)
|
||||
# Track thread entry
|
||||
stats["threads"] += 1
|
||||
|
||||
# Migrate thread messages recursively
|
||||
thread_stats = await migrate_messages(
|
||||
context=context,
|
||||
source_channel_id=thread.id,
|
||||
target_channel_id=target_channel_id,
|
||||
thread_id=str(thread.id),
|
||||
parent_target_id=None,
|
||||
thread_name=thread.name,
|
||||
processed_threads=processed_threads
|
||||
)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
# Send End Marker
|
||||
await context.stoat_writer.send_marker(
|
||||
channel_id=target_channel_id,
|
||||
content=f"> <<< END OF THREAD >>>"
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(stats)
|
||||
|
|
@ -390,31 +456,30 @@ async def migrate_messages(
|
|||
# Check for associated thread (Normal case: parent message is migrated)
|
||||
if hasattr(msg, 'thread') and msg.thread:
|
||||
thread = msg.thread
|
||||
logger.info(f"Detected thread '{thread.name}' on message {msg.id}")
|
||||
|
||||
# Track thread entry
|
||||
stats["threads"] += 1
|
||||
|
||||
pass
|
||||
|
||||
# Migrate thread messages recursively
|
||||
thread_stats = await migrate_messages(
|
||||
context=context,
|
||||
source_channel_id=thread.id,
|
||||
target_channel_id=target_channel_id,
|
||||
thread_id=str(thread.id),
|
||||
parent_target_id=stoat_msg_id,
|
||||
thread_name=thread.name
|
||||
)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
# Send End Marker
|
||||
await context.stoat_writer.send_marker(
|
||||
channel_id=target_channel_id,
|
||||
content=f"> <<< END OF THREAD >>>"
|
||||
)
|
||||
if thread.id not in processed_threads:
|
||||
processed_threads.add(thread.id)
|
||||
# Track thread entry
|
||||
stats["threads"] += 1
|
||||
|
||||
# Migrate thread messages recursively
|
||||
thread_stats = await migrate_messages(
|
||||
context=context,
|
||||
source_channel_id=thread.id,
|
||||
target_channel_id=target_channel_id,
|
||||
thread_id=str(thread.id),
|
||||
parent_target_id=stoat_msg_id,
|
||||
thread_name=thread.name,
|
||||
processed_threads=processed_threads
|
||||
)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
# Send End Marker
|
||||
await context.stoat_writer.send_marker(
|
||||
channel_id=target_channel_id,
|
||||
content=f"> <<< END OF THREAD >>>"
|
||||
)
|
||||
|
||||
# Update Link Tracking
|
||||
if not stats["first_message_url"]:
|
||||
|
|
@ -431,6 +496,34 @@ async def migrate_messages(
|
|||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# After scanning messages, explicitly check for any missed threads (e.g. archived or skipped in scan)
|
||||
# Only do this at the top level
|
||||
if not thread_id and (after_message_id is not None or inclusive or stats["messages"] > 0):
|
||||
all_threads = await get_channel_threads(context.discord_reader, source_channel_id)
|
||||
for t in all_threads:
|
||||
if t.id not in processed_threads:
|
||||
processed_threads.add(t.id)
|
||||
logger.info(f"Migrating missed thread '{t.name}' (ID: {t.id})")
|
||||
|
||||
stats["threads"] += 1
|
||||
thread_stats = await migrate_messages(
|
||||
context=context,
|
||||
source_channel_id=t.id,
|
||||
target_channel_id=target_channel_id,
|
||||
thread_id=str(t.id),
|
||||
parent_target_id=None,
|
||||
thread_name=t.name,
|
||||
processed_threads=processed_threads
|
||||
)
|
||||
stats["messages"] += thread_stats["messages"]
|
||||
stats["attachments"] += thread_stats["attachments"]
|
||||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
await context.stoat_writer.send_marker(
|
||||
channel_id=target_channel_id,
|
||||
content=f"> <<< END OF THREAD >>>"
|
||||
)
|
||||
|
||||
except (KeyboardInterrupt, asyncio.CancelledError):
|
||||
context.is_running = False
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -158,6 +158,9 @@ class OperationPane(Container):
|
|||
def on_show(self) -> None:
|
||||
"""Re-validate when the pane regains visibility."""
|
||||
if self.view_mode == "backup" or self.config.tool_mode == "backup_transfer":
|
||||
if self.view_mode == "shuttle":
|
||||
# Re-run path discovery in case a new backup was just made
|
||||
self._rebuild_engine()
|
||||
self.run_validate()
|
||||
|
||||
def reload_config(self) -> None:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue