fix forwarded messages

This commit is contained in:
rambros 2026-03-14 21:42:56 +05:30
parent c2ece58209
commit 7bea5582e9
5 changed files with 367 additions and 123 deletions

View file

@ -81,6 +81,7 @@ class MessageType(IntEnum):
guild_incident_report_false_alarm = 39
purchase_notification = 44
poll_result = 46
forward = 100
# ---------------------------------------------------------------------------
@ -659,7 +660,7 @@ class BackupMessage:
"Default": MessageType.default,
"Reply": MessageType.reply,
"ThreadStarter": MessageType.thread_starter_message,
"Forward": MessageType.default,
"Forward": 100,
}
__slots__ = ("id", "type", "created_at", "pinned", "content", "author",
@ -667,7 +668,7 @@ class BackupMessage:
"reference", "thread", "channel_id", "flags", "guild", "channel",
"mentions", "role_mentions", "channel_mentions", "mention_everyone",
"tts", "nonce", "webhook_id", "application_id", "activity",
"application", "interaction", "components", "jump_url")
"application", "interaction", "components", "jump_url", "message_snapshots")
def __init__(self, data: dict, *,
author: BackupMember | None = None,
guild: "BackupGuild | None" = None,
@ -771,7 +772,10 @@ class BackupMessage:
self.reference = type("Ref", (), {"message_id": parse_snowflake(data["message_reference"]), "channel_id": self.channel_id})()
self.thread = None
self.flags = type("Flags", (), {"value": 0})()
self.flags = type("Flags", (), {"value": 0, "forwarded": self.type == MessageType.forward})()
# snapshots not used in latest refined structure as content is in main message
self.message_snapshots = []
def __repr__(self) -> str:
return f"BackupMessage(id={self.id}, author={self.author})"
@ -837,6 +841,13 @@ class BackupGuild:
if self.banner:
self.banner.url = data.get("banner_url")
@property
def active_threads(self) -> List[BackupThread]:
"""Returns all threads that are not archived."""
if self._reader:
return [t for t in self._reader._threads if not t.archived]
return []
@property
def roles(self) -> List[BackupRole]:
return self._reader._roles if self._reader else []
@ -904,6 +915,7 @@ class BackupReader:
MESSAGE_TYPE_DEFAULT = MessageType.default
MESSAGE_TYPE_REPLY = MessageType.reply
MESSAGE_TYPE_THREAD_STARTER = MessageType.thread_starter_message
MESSAGE_TYPE_FORWARD = MessageType.forward # Custom Reaper constant
Forbidden = BackupForbidden
@ -942,6 +954,7 @@ class BackupReader:
self._categories: List[BackupCategory] = []
self._channels: List[BackupChannel] = []
self._threads: List[BackupThread] = []
self._thread_map: Dict[int, BackupThread] = {} # Starter Message ID -> Thread
self._roles: List[BackupRole] = []
self._emojis: List[BackupEmoji] = []
self._stickers: List[BackupSticker] = []
@ -1007,6 +1020,8 @@ class BackupReader:
for tdata in thread_rows:
thread = BackupThread(tdata)
self._threads.append(thread)
if thread.id:
self._thread_map[thread.id] = thread
# Resolve tag IDs to BackupTag objects using parent forum's available_tags
if thread.applied_tags and thread.parent_id:
@ -1208,7 +1223,7 @@ class BackupReader:
channel_id = parse_snowflake(msg_data["channel_id"])
channel = next((c for c in self.channels if c.id == channel_id), None)
return BackupMessage(
bm = BackupMessage(
msg_data,
author=author,
guild=self.guild,
@ -1217,6 +1232,12 @@ class BackupReader:
media_pool=self._media_pool
)
# Link thread if this is a starter message
if bm.id in self._thread_map:
bm.thread = self._thread_map[bm.id]
return bm
async def get_message(self, channel_id: int, message_id: int) -> BackupMessage | None:
"""Fetch a specific message from SQLite."""
if not self.db: return None

View file

@ -524,18 +524,47 @@ class DiscordExporter:
})
# 5. Message data
# Check for reference (reply)
# Check for reference (reply/origin of forward)
message_reference = None
if msg.reference and msg.reference.message_id:
message_reference = str(msg.reference.message_id)
# 5.5 Forwarded snapshots
content = msg.content or ""
msg_type = int(msg.type.value) if hasattr(msg.type, "value") else 0
# Detect if this message is forwarded (discord.py 2.5+)
is_forwarded = getattr(msg.flags, 'forwarded', False)
if is_forwarded and hasattr(msg, 'message_snapshots') and msg.message_snapshots:
msg_type = 100 # Custom Forward type
snapshot = msg.message_snapshots[0]
if snapshot.content:
content = snapshot.content
# Process snapshot attachments into main attachments list
for s_att in snapshot.attachments:
att_res = await self._process_media(
media_id=s_att.id,
url=s_att.url,
filename=s_att.filename,
size=s_att.size,
content_type=s_att.content_type,
save_method=s_att.save
)
if att_res: attachments.append(att_res)
# Process snapshot embeds (simplified)
for s_emb in snapshot.embeds:
# We reuse the main embeds list
embeds.append(s_emb.to_dict())
m_data = {
"id": str(msg.id),
"channel_id": str(msg.channel.id),
"author_id": user_id,
"content": msg.content,
"content": content,
"timestamp": msg.created_at.isoformat(),
"type": int(msg.type.value) if hasattr(msg.type, "value") else 0,
"type": msg_type,
"message_reference": message_reference,
"is_pinned": 1 if msg.pinned else 0,
"attachments": attachments,

View file

@ -94,27 +94,73 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
return content
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None) -> Dict[str, int]:
return content
async def get_channel_threads(reader: Any, channel_id: int) -> List[Any]:
"""Helper to fetch all threads (active and archived) for a channel from Live or Backup."""
threads = []
# 1. From Backup (BackupReader has 'db' attribute)
if hasattr(reader, 'db') and hasattr(reader, 'threads'):
for t in reader.threads:
if t.parent_id == channel_id:
threads.append(t)
return threads
# 2. From live Discord
if hasattr(reader, 'guild') and reader.guild:
try:
# Guild-wide active threads
if hasattr(reader.guild, 'active_threads'):
for t in reader.guild.active_threads:
if t.parent_id == channel_id:
threads.append(t)
# Archived threads for this specific channel
channel = await reader.get_channel(channel_id)
if hasattr(channel, 'archived_threads'):
# discord.py method
async for t in channel.archived_threads(limit=None):
threads.append(t)
except Exception as e:
logger.debug(f"Could not fetch live threads for {channel_id}: {e}")
return threads
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None, processed_threads: set | None = None) -> Dict[str, int]:
"""
Scans channel history to count messages, threads, and attachments.
"""
stats = {"messages": 0, "threads": 0, "attachments": 0}
if processed_threads is None:
processed_threads = set()
async for msg in context.discord_reader.fetch_message_history(source_channel_id, after_id=after_message_id, inclusive=inclusive):
if not context.is_running:
break
# Count thread messages and markers even if parent is skipped
if hasattr(msg, 'thread') and msg.thread:
thread = msg.thread
if thread.id not in processed_threads:
processed_threads.add(thread.id)
stats["threads"] += 1
# Recursively count thread content
thread_stats = await analyze_migration(context, msg.thread.id)
thread_stats = await analyze_migration(context, thread.id, processed_threads=processed_threads)
stats["messages"] += thread_stats["messages"]
stats["attachments"] += thread_stats["attachments"]
stats["threads"] += thread_stats["threads"] # Nested threads (rare in Discord but possible in forum channels)
# Consistent filtering with migrate_messages
if msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY, context.discord_reader.MESSAGE_TYPE_THREAD_STARTER]:
if msg.type not in [
context.discord_reader.MESSAGE_TYPE_DEFAULT,
context.discord_reader.MESSAGE_TYPE_REPLY,
context.discord_reader.MESSAGE_TYPE_THREAD_STARTER,
context.discord_reader.MESSAGE_TYPE_FORWARD
]:
continue
stats["messages"] += 1
@ -123,6 +169,20 @@ async def analyze_migration(context: MigrationContext, source_channel_id: int, a
if progress_callback and stats["messages"] % 10 == 0:
await progress_callback(stats)
# After scanning messages, explicitly check for any missed threads (e.g. archived or skipped in scan)
# Only do this at the top level (not in recursive thread calls)
if after_message_id is not None or inclusive: # Usually top level calls have some start point
# Optimization: We check all threads for the channel
all_threads = await get_channel_threads(context.discord_reader, source_channel_id)
for t in all_threads:
if t.id not in processed_threads:
processed_threads.add(t.id)
stats["threads"] += 1
thread_stats = await analyze_migration(context, t.id, processed_threads=processed_threads)
stats["messages"] += thread_stats["messages"]
stats["attachments"] += thread_stats["attachments"]
stats["threads"] += thread_stats["threads"]
return stats
@ -135,7 +195,8 @@ async def migrate_messages(
progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None,
thread_id: str | None = None,
parent_target_id: str | None = None,
thread_name: str | None = None
thread_name: str | None = None,
processed_threads: set | None = None
) -> Dict[str, Any]:
"""Migrate messages for a specific channel and returns detailed statistics."""
stats = {
@ -152,6 +213,9 @@ async def migrate_messages(
if after_message_id:
logger.info(f"Resuming migration from after message ID: {after_message_id}")
if processed_threads is None:
processed_threads = set()
try:
async for msg in context.discord_reader.fetch_message_history(source_channel_id, after_id=after_message_id, inclusive=inclusive):
if not context.is_running:
@ -161,12 +225,17 @@ async def migrate_messages(
# Skip system messages like "pinned a message", etc.
if msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY, context.discord_reader.MESSAGE_TYPE_THREAD_STARTER]:
if msg.type not in [
context.discord_reader.MESSAGE_TYPE_DEFAULT,
context.discord_reader.MESSAGE_TYPE_REPLY,
context.discord_reader.MESSAGE_TYPE_THREAD_STARTER,
context.discord_reader.MESSAGE_TYPE_FORWARD
]:
# If we are skipping the parent, we STILL need to check for a thread!
if hasattr(msg, 'thread') and msg.thread:
thread = msg.thread
logger.info(f"Detected thread '{thread.name}' on skipped message {msg.id}")
if thread.id not in processed_threads:
processed_threads.add(thread.id)
# Track thread entry
stats["threads"] += 1
@ -177,7 +246,8 @@ async def migrate_messages(
target_channel_id=target_channel_id,
thread_id=str(thread.id),
parent_target_id=None,
thread_name=thread.name
thread_name=thread.name,
processed_threads=processed_threads
)
stats["messages"] += thread_stats["messages"]
stats["attachments"] += thread_stats["attachments"]
@ -378,8 +448,8 @@ async def migrate_messages(
# Check for associated thread (Normal case: parent message is migrated)
if hasattr(msg, 'thread') and msg.thread:
thread = msg.thread
logger.info(f"Detected thread '{thread.name}' on message {msg.id}")
if thread.id not in processed_threads:
processed_threads.add(thread.id)
# Track thread entry
stats["threads"] += 1
@ -390,7 +460,8 @@ async def migrate_messages(
target_channel_id=target_channel_id,
thread_id=str(thread.id),
parent_target_id=fluxer_msg_id,
thread_name=thread.name
thread_name=thread.name,
processed_threads=processed_threads
)
stats["messages"] += thread_stats["messages"]
stats["attachments"] += thread_stats["attachments"]
@ -415,7 +486,34 @@ async def migrate_messages(
import traceback
logger.error(traceback.format_exc())
# Delay for rate limit safety
# After scanning messages, explicitly check for any missed threads (e.g. archived or skipped in scan)
# Only do this at the top level
if not thread_id and (after_message_id is not None or inclusive or stats["messages"] > 0):
all_threads = await get_channel_threads(context.discord_reader, source_channel_id)
for t in all_threads:
if t.id not in processed_threads:
processed_threads.add(t.id)
logger.info(f"Migrating missed thread '{t.name}' (ID: {t.id})")
stats["threads"] += 1
thread_stats = await migrate_messages(
context=context,
source_channel_id=t.id,
target_channel_id=target_channel_id,
thread_id=str(t.id),
parent_target_id=None,
thread_name=t.name,
processed_threads=processed_threads
)
stats["messages"] += thread_stats["messages"]
stats["attachments"] += thread_stats["attachments"]
stats["threads"] += thread_stats["threads"]
await context.fluxer_writer.send_marker(
channel_id=target_channel_id,
content=f"> <<< END OF THREAD >>>"
)
except (KeyboardInterrupt, asyncio.CancelledError):
context.is_running = False
pass

View file

@ -94,7 +94,42 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
return content
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None) -> Dict[str, int]:
return content
async def get_channel_threads(reader: Any, channel_id: int) -> List[Any]:
"""Helper to fetch all threads (active and archived) for a channel from Live or Backup."""
threads = []
# 1. From Backup (BackupReader has 'db' attribute)
if hasattr(reader, 'db') and hasattr(reader, 'threads'):
for t in reader.threads:
if t.parent_id == channel_id:
threads.append(t)
return threads
# 2. From live Discord
if hasattr(reader, 'guild') and reader.guild:
try:
# Guild-wide active threads
if hasattr(reader.guild, 'active_threads'):
for t in reader.guild.active_threads:
if t.parent_id == channel_id:
threads.append(t)
# Archived threads for this specific channel
channel = await reader.get_channel(channel_id)
if hasattr(channel, 'archived_threads'):
# discord.py method
async for t in channel.archived_threads(limit=None):
threads.append(t)
except Exception as e:
logger.debug(f"Could not fetch live threads for {channel_id}: {e}")
return threads
async def analyze_migration(context: MigrationContext, source_channel_id: int, after_message_id: int | None = None, inclusive: bool = False, progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None, processed_threads: set | None = None) -> Dict[str, int]:
"""
Scans channel history to count messages, threads, and attachments.
"""
@ -106,20 +141,31 @@ async def analyze_migration(context: MigrationContext, source_channel_id: int, a
"last_message_url": ""
}
if processed_threads is None:
processed_threads = set()
async for msg in context.discord_reader.fetch_message_history(source_channel_id, after_id=after_message_id, inclusive=inclusive):
if not context.is_running:
break
# Count thread messages and markers even if parent is skipped
if hasattr(msg, 'thread') and msg.thread:
thread = msg.thread
if thread.id not in processed_threads:
processed_threads.add(thread.id)
stats["threads"] += 1
thread_stats = await analyze_migration(context, msg.thread.id)
thread_stats = await analyze_migration(context, thread.id, processed_threads=processed_threads)
stats["messages"] += thread_stats["messages"]
stats["attachments"] += thread_stats["attachments"]
stats["threads"] += thread_stats["threads"]
# Consistent filtering with migrate_messages
if msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY, context.discord_reader.MESSAGE_TYPE_THREAD_STARTER]:
if msg.type not in [
context.discord_reader.MESSAGE_TYPE_DEFAULT,
context.discord_reader.MESSAGE_TYPE_REPLY,
context.discord_reader.MESSAGE_TYPE_THREAD_STARTER,
context.discord_reader.MESSAGE_TYPE_FORWARD
]:
continue
stats["messages"] += 1
@ -128,6 +174,18 @@ async def analyze_migration(context: MigrationContext, source_channel_id: int, a
if progress_callback and stats["messages"] % 10 == 0:
await progress_callback(stats)
# After scanning messages, explicitly check for any missed threads (e.g. archived or skipped in scan)
# Only do this at the top level
if after_message_id is not None or inclusive:
all_threads = await get_channel_threads(context.discord_reader, source_channel_id)
for t in all_threads:
if t.id not in processed_threads:
processed_threads.add(t.id)
stats["threads"] += 1
thread_stats = await analyze_migration(context, t.id, processed_threads=processed_threads)
stats["messages"] += thread_stats["messages"]
stats["attachments"] += thread_stats["attachments"]
stats["threads"] += thread_stats["threads"]
return stats
@ -141,7 +199,8 @@ async def migrate_messages(
progress_callback: Callable[[Dict[str, Any]], Awaitable[None]] | None = None,
thread_id: str | None = None,
parent_target_id: str | None = None,
thread_name: str | None = None
thread_name: str | None = None,
processed_threads: set | None = None
) -> Dict[str, Any]:
"""Migrate messages for a specific channel using Stoat masquerade for author impersonation."""
stats = {
@ -158,6 +217,9 @@ async def migrate_messages(
if after_message_id:
logger.info(f"Resuming migration from after message ID: {after_message_id}")
if processed_threads is None:
processed_threads = set()
try:
async for msg in context.discord_reader.fetch_message_history(source_channel_id, after_id=after_message_id, inclusive=inclusive):
if not context.is_running:
@ -168,17 +230,20 @@ async def migrate_messages(
# Skip system messages like "pinned a message", etc.
content = "" # Initialize content
if msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY, context.discord_reader.MESSAGE_TYPE_THREAD_STARTER]:
if msg.type not in [
context.discord_reader.MESSAGE_TYPE_DEFAULT,
context.discord_reader.MESSAGE_TYPE_REPLY,
context.discord_reader.MESSAGE_TYPE_THREAD_STARTER,
context.discord_reader.MESSAGE_TYPE_FORWARD
]:
# If we are skipping the parent, we STILL need to check for a thread!
if hasattr(msg, 'thread') and msg.thread:
thread = msg.thread
logger.info(f"Detected thread '{thread.name}' on skipped message {msg.id}")
if thread.id not in processed_threads:
processed_threads.add(thread.id)
# Track thread entry
stats["threads"] += 1
pass
# Migrate thread messages recursively
thread_stats = await migrate_messages(
context=context,
@ -186,7 +251,8 @@ async def migrate_messages(
target_channel_id=target_channel_id,
thread_id=str(thread.id),
parent_target_id=None,
thread_name=thread.name
thread_name=thread.name,
processed_threads=processed_threads
)
stats["messages"] += thread_stats["messages"]
stats["attachments"] += thread_stats["attachments"]
@ -390,13 +456,11 @@ async def migrate_messages(
# Check for associated thread (Normal case: parent message is migrated)
if hasattr(msg, 'thread') and msg.thread:
thread = msg.thread
logger.info(f"Detected thread '{thread.name}' on message {msg.id}")
if thread.id not in processed_threads:
processed_threads.add(thread.id)
# Track thread entry
stats["threads"] += 1
pass
# Migrate thread messages recursively
thread_stats = await migrate_messages(
context=context,
@ -404,7 +468,8 @@ async def migrate_messages(
target_channel_id=target_channel_id,
thread_id=str(thread.id),
parent_target_id=stoat_msg_id,
thread_name=thread.name
thread_name=thread.name,
processed_threads=processed_threads
)
stats["messages"] += thread_stats["messages"]
stats["attachments"] += thread_stats["attachments"]
@ -431,6 +496,34 @@ async def migrate_messages(
import traceback
logger.error(traceback.format_exc())
# After scanning messages, explicitly check for any missed threads (e.g. archived or skipped in scan)
# Only do this at the top level
if not thread_id and (after_message_id is not None or inclusive or stats["messages"] > 0):
all_threads = await get_channel_threads(context.discord_reader, source_channel_id)
for t in all_threads:
if t.id not in processed_threads:
processed_threads.add(t.id)
logger.info(f"Migrating missed thread '{t.name}' (ID: {t.id})")
stats["threads"] += 1
thread_stats = await migrate_messages(
context=context,
source_channel_id=t.id,
target_channel_id=target_channel_id,
thread_id=str(t.id),
parent_target_id=None,
thread_name=t.name,
processed_threads=processed_threads
)
stats["messages"] += thread_stats["messages"]
stats["attachments"] += thread_stats["attachments"]
stats["threads"] += thread_stats["threads"]
await context.stoat_writer.send_marker(
channel_id=target_channel_id,
content=f"> <<< END OF THREAD >>>"
)
except (KeyboardInterrupt, asyncio.CancelledError):
context.is_running = False
pass

View file

@ -158,6 +158,9 @@ class OperationPane(Container):
def on_show(self) -> None:
"""Re-validate when the pane regains visibility."""
if self.view_mode == "backup" or self.config.tool_mode == "backup_transfer":
if self.view_mode == "shuttle":
# Re-run path discovery in case a new backup was just made
self._rebuild_engine()
self.run_validate()
def reload_config(self) -> None: