fix user mentions in backup migration

This commit is contained in:
rambros 2026-03-15 23:59:10 +05:30
parent 06091423a2
commit ee61316451
3 changed files with 94 additions and 76 deletions

View file

@ -75,7 +75,9 @@ class BackupDatabase:
position INTEGER,
category_id TEXT,
topic TEXT,
nsfw INTEGER
nsfw INTEGER,
bitrate INTEGER,
slowmode_delay INTEGER
)
""")
@ -299,8 +301,8 @@ class BackupDatabase:
with self._lock:
conn = self._get_conn()
conn.executemany("""
INSERT OR REPLACE INTO channels (id, name, type, position, category_id, topic, nsfw)
VALUES (:id, :name, :type, :position, :category_id, :topic, :nsfw)
INSERT OR REPLACE INTO channels (id, name, type, position, category_id, topic, nsfw, bitrate, slowmode_delay)
VALUES (:id, :name, :type, :position, :category_id, :topic, :nsfw, :bitrate, :slowmode_delay)
""", channels)
conn.commit()
conn.close()

View file

@ -321,7 +321,8 @@ class BackupChannel:
"""Minimal stand-in for discord.TextChannel / ForumChannel / VoiceChannel."""
__slots__ = ("id", "name", "type", "position", "topic", "nsfw",
"category_id", "available_tags", "parent_id", "guild", "overwrites")
"category_id", "available_tags", "parent_id", "guild", "overwrites",
"bitrate", "slowmode_delay")
_TYPE_MAP = {
"text": ChannelType.text,
@ -340,6 +341,8 @@ class BackupChannel:
self.position = data.get("position", 0)
self.topic = data.get("topic")
self.nsfw = bool(data.get("nsfw", False))
self.bitrate = data.get("bitrate")
self.slowmode_delay = data.get("slowmode_delay")
cid = data.get("category_id")
self.category_id = parse_snowflake(cid) if cid else category_id
self.parent_id = self.category_id
@ -683,7 +686,7 @@ class BackupMessage:
cid = data.get("channel_id")
self.channel_id = parse_snowflake(cid) if cid else (channel.id if channel else None)
# Mentions (simplified)
# Mentions (resolved on the fly by clean_mentions via guild.get_member)
self.mentions = []
self.role_mentions = []
self.channel_mentions = []
@ -925,7 +928,11 @@ class BackupGuild:
def get_member(self, user_id: int) -> "BackupMember | None":
if self._reader:
return self._reader._member_map.get(parse_snowflake(user_id))
uid = parse_snowflake(user_id)
if uid in self._reader._member_map:
return self._reader._member_map[uid]
# Lazy load from DB
return self._reader._resolve_author(uid)
return None
def get_role(self, role_id: int) -> "BackupRole | None":

View file

@ -250,7 +250,9 @@ class DiscordExporter:
"position": cat.position,
"category_id": None,
"topic": None,
"nsfw": 0
"nsfw": 0,
"bitrate": None,
"slowmode_delay": None
})
# Add child channels to list
@ -323,7 +325,9 @@ class DiscordExporter:
"type": int(c.type.value) if hasattr(c.type, "value") else 0,
"position": c.position,
"topic": getattr(c, "topic", None),
"nsfw": 1 if getattr(c, "nsfw", False) else 0
"nsfw": 1 if getattr(c, "nsfw", False) else 0,
"bitrate": getattr(c, "bitrate", None),
"slowmode_delay": getattr(c, "slowmode_delay", None)
}
return data, ch_permissions, ch_forum_tags
@ -363,9 +367,9 @@ class DiscordExporter:
if len(batch_raw) >= BATCH_SIZE:
results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
for m_data, u_data in results:
for m_data, u_list in results:
batch_messages.append(m_data)
if u_data: batch_users.append(u_data)
if u_list: batch_users.extend(u_list)
new_count += len(batch_messages)
accumulated_count += len(batch_messages)
@ -389,33 +393,32 @@ class DiscordExporter:
batch_users.clear()
batch_raw.clear()
# Final partial batch
if batch_raw and self.is_running:
results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
for m_data, u_data in results:
batch_messages.append(m_data)
if u_data: batch_users.append(u_data)
new_count += len(batch_messages)
accumulated_count += len(batch_messages)
for m in batch_messages:
if "attachments" in m:
accumulated_files += len(m["attachments"])
if self.db:
if batch_users: self.db.save_users(batch_users)
self.db.save_messages_batch(batch_messages)
if progress_callback:
last_msg = batch_raw[-1]
author_name = getattr(last_msg.author, "display_name", "Unknown")
await progress_callback(channel_name, accumulated_count, author_name=author_name, thread_count=accumulated_threads, file_count=accumulated_files)
batch_messages.clear()
batch_users.clear()
batch_raw.clear()
if batch_raw and self.is_running:
results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
for m_data, u_list in results:
batch_messages.append(m_data)
if u_list: batch_users.extend(u_list)
new_count += len(batch_messages)
accumulated_count += len(batch_messages)
for m in batch_messages:
if "attachments" in m:
accumulated_files += len(m["attachments"])
if self.db:
if batch_users: self.db.save_users(batch_users)
self.db.save_messages_batch(batch_messages)
if progress_callback:
last_msg = batch_raw[-1]
author_name = getattr(last_msg.author, "display_name", "Unknown")
await progress_callback(channel_name, accumulated_count, author_name=author_name, thread_count=accumulated_threads, file_count=accumulated_files)
batch_messages.clear()
batch_users.clear()
batch_raw.clear()
except discord.Forbidden:
logger.error(f"403 Forbidden: Missing Access to read messages in {channel_name} ({channel_id})")
@ -427,41 +430,55 @@ class DiscordExporter:
return accumulated_count, accumulated_threads, accumulated_files
async def _format_user(self, user):
"""Formats user data for the author or a mention."""
user_id = str(user.id)
if user_id in self.user_cache:
return None
# New user discovered
avatar_file = None
if user.avatar:
try:
av_name = f"{user_id}.png"
av_target = self.users_path / av_name
if not av_target.exists():
await user.avatar.save(av_target)
avatar_file = f"users/{av_name}"
except Exception as e:
logger.error(f"Failed to save avatar for {user.name}: {e}")
roles = []
if hasattr(user, "roles"):
roles = [str(r.id) for r in user.roles if not r.is_default()]
user_data = {
"id": user_id,
"username": user.name,
"display_name": getattr(user, "display_name", user.name),
"avatar_file": avatar_file,
"avatar_url": str(user.display_avatar.url) if user.avatar else None,
"roles": json.dumps(roles)
}
self.user_cache[user_id] = user_data
return user_data
async def _format_message(self, msg):
"""Formats a single message and its author for DB storage."""
# 1. Author handling
author = msg.author
user_id = str(author.id)
user_data = None
new_users = []
if user_id not in self.user_cache:
# New user discovered
avatar_file = None
if author.avatar:
try:
av_name = f"{user_id}.png"
av_target = self.users_path / av_name
if not av_target.exists():
await author.avatar.save(av_target)
avatar_file = f"users/{av_name}"
except Exception as e:
logger.error(f"Failed to save avatar for {author.name}: {e}")
# 1. Author handling
u_data = await self._format_user(msg.author)
if u_data: new_users.append(u_data)
roles = []
if hasattr(author, "roles"):
roles = [str(r.id) for r in author.roles if not r.is_default()]
user_data = {
"id": user_id,
"username": author.name,
"display_name": getattr(author, "display_name", author.name),
"avatar_file": avatar_file,
"avatar_url": str(author.display_avatar.url) if author.avatar else None,
"roles": json.dumps(roles)
}
self.user_cache[user_id] = user_data
# 1.5 Mentions handling (ensure all mentioned users are saved)
if msg.mentions:
for mention in msg.mentions:
u_ment = await self._format_user(mention)
if u_ment: new_users.append(u_ment)
# 2. Attachments handling (Content-Addressable Storage)
# ... (rest of the logic remains same, just updating m_data)
attachments = []
if msg.attachments:
for att in msg.attachments:
@ -480,7 +497,6 @@ class DiscordExporter:
stickers = []
if msg.stickers:
for st in msg.stickers:
# Deduplicate downloads for the same sticker in one session
if st.id in self.sticker_cache:
st_bytes = self.sticker_cache[st.id]
else:
@ -501,8 +517,6 @@ class DiscordExporter:
st_data["name"] = st.name
st_data["format_type"] = int(st.format.value) if hasattr(st, "format") and hasattr(st.format, "value") else 1
stickers.append(st_data)
else:
logger.warning(f"Could not download message sticker {st.id} in message {msg.id}")
# 3. Embeds
embeds = []
@ -521,7 +535,6 @@ class DiscordExporter:
})
# 5. Message data
# Check for reference (reply/origin of forward)
message_reference = None
if msg.reference and msg.reference.message_id:
message_reference = str(msg.reference.message_id)
@ -530,7 +543,6 @@ class DiscordExporter:
content = msg.content or ""
msg_type = int(msg.type.value) if hasattr(msg.type, "value") else 0
# Detect if this message is forwarded (discord.py 2.5+)
is_forwarded = getattr(msg.flags, 'forwarded', False)
if is_forwarded and hasattr(msg, 'message_snapshots') and msg.message_snapshots:
msg_type = 100 # Custom Forward type
@ -538,7 +550,6 @@ class DiscordExporter:
if snapshot.content:
content = snapshot.content
# Process snapshot attachments into main attachments list
for s_att in snapshot.attachments:
att_res = await self._process_media(
media_id=s_att.id,
@ -550,15 +561,13 @@ class DiscordExporter:
)
if att_res: attachments.append(att_res)
# Process snapshot embeds (simplified)
for s_emb in snapshot.embeds:
# We reuse the main embeds list
embeds.append(s_emb.to_dict())
m_data = {
"id": str(msg.id),
"channel_id": str(msg.channel.id),
"author_id": user_id,
"author_id": str(msg.author.id),
"content": content,
"timestamp": msg.created_at.isoformat(),
"type": msg_type,
@ -571,7 +580,7 @@ class DiscordExporter:
"extra_data": None
}
return m_data, user_data
return m_data, new_users
async def _process_media(self, media_id, url, filename, size=None, content_type=None, save_method=None, data=None):
"""Downloads and deduplicates any media (attachment or sticker) using SHA-256 (CAS)."""