fix user mentions in backup migration
This commit is contained in:
parent
06091423a2
commit
ee61316451
3 changed files with 94 additions and 76 deletions
|
|
@ -75,7 +75,9 @@ class BackupDatabase:
|
|||
position INTEGER,
|
||||
category_id TEXT,
|
||||
topic TEXT,
|
||||
nsfw INTEGER
|
||||
nsfw INTEGER,
|
||||
bitrate INTEGER,
|
||||
slowmode_delay INTEGER
|
||||
)
|
||||
""")
|
||||
|
||||
|
|
@ -299,8 +301,8 @@ class BackupDatabase:
|
|||
with self._lock:
|
||||
conn = self._get_conn()
|
||||
conn.executemany("""
|
||||
INSERT OR REPLACE INTO channels (id, name, type, position, category_id, topic, nsfw)
|
||||
VALUES (:id, :name, :type, :position, :category_id, :topic, :nsfw)
|
||||
INSERT OR REPLACE INTO channels (id, name, type, position, category_id, topic, nsfw, bitrate, slowmode_delay)
|
||||
VALUES (:id, :name, :type, :position, :category_id, :topic, :nsfw, :bitrate, :slowmode_delay)
|
||||
""", channels)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
|
|
|||
|
|
@ -321,7 +321,8 @@ class BackupChannel:
|
|||
"""Minimal stand-in for discord.TextChannel / ForumChannel / VoiceChannel."""
|
||||
|
||||
__slots__ = ("id", "name", "type", "position", "topic", "nsfw",
|
||||
"category_id", "available_tags", "parent_id", "guild", "overwrites")
|
||||
"category_id", "available_tags", "parent_id", "guild", "overwrites",
|
||||
"bitrate", "slowmode_delay")
|
||||
|
||||
_TYPE_MAP = {
|
||||
"text": ChannelType.text,
|
||||
|
|
@ -340,6 +341,8 @@ class BackupChannel:
|
|||
self.position = data.get("position", 0)
|
||||
self.topic = data.get("topic")
|
||||
self.nsfw = bool(data.get("nsfw", False))
|
||||
self.bitrate = data.get("bitrate")
|
||||
self.slowmode_delay = data.get("slowmode_delay")
|
||||
cid = data.get("category_id")
|
||||
self.category_id = parse_snowflake(cid) if cid else category_id
|
||||
self.parent_id = self.category_id
|
||||
|
|
@ -683,7 +686,7 @@ class BackupMessage:
|
|||
cid = data.get("channel_id")
|
||||
self.channel_id = parse_snowflake(cid) if cid else (channel.id if channel else None)
|
||||
|
||||
# Mentions (simplified)
|
||||
# Mentions (resolved on the fly by clean_mentions via guild.get_member)
|
||||
self.mentions = []
|
||||
self.role_mentions = []
|
||||
self.channel_mentions = []
|
||||
|
|
@ -925,7 +928,11 @@ class BackupGuild:
|
|||
|
||||
def get_member(self, user_id: int) -> "BackupMember | None":
|
||||
if self._reader:
|
||||
return self._reader._member_map.get(parse_snowflake(user_id))
|
||||
uid = parse_snowflake(user_id)
|
||||
if uid in self._reader._member_map:
|
||||
return self._reader._member_map[uid]
|
||||
# Lazy load from DB
|
||||
return self._reader._resolve_author(uid)
|
||||
return None
|
||||
|
||||
def get_role(self, role_id: int) -> "BackupRole | None":
|
||||
|
|
|
|||
|
|
@ -250,7 +250,9 @@ class DiscordExporter:
|
|||
"position": cat.position,
|
||||
"category_id": None,
|
||||
"topic": None,
|
||||
"nsfw": 0
|
||||
"nsfw": 0,
|
||||
"bitrate": None,
|
||||
"slowmode_delay": None
|
||||
})
|
||||
|
||||
# Add child channels to list
|
||||
|
|
@ -323,7 +325,9 @@ class DiscordExporter:
|
|||
"type": int(c.type.value) if hasattr(c.type, "value") else 0,
|
||||
"position": c.position,
|
||||
"topic": getattr(c, "topic", None),
|
||||
"nsfw": 1 if getattr(c, "nsfw", False) else 0
|
||||
"nsfw": 1 if getattr(c, "nsfw", False) else 0,
|
||||
"bitrate": getattr(c, "bitrate", None),
|
||||
"slowmode_delay": getattr(c, "slowmode_delay", None)
|
||||
}
|
||||
|
||||
return data, ch_permissions, ch_forum_tags
|
||||
|
|
@ -363,9 +367,9 @@ class DiscordExporter:
|
|||
|
||||
if len(batch_raw) >= BATCH_SIZE:
|
||||
results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
|
||||
for m_data, u_data in results:
|
||||
for m_data, u_list in results:
|
||||
batch_messages.append(m_data)
|
||||
if u_data: batch_users.append(u_data)
|
||||
if u_list: batch_users.extend(u_list)
|
||||
|
||||
new_count += len(batch_messages)
|
||||
accumulated_count += len(batch_messages)
|
||||
|
|
@ -389,33 +393,32 @@ class DiscordExporter:
|
|||
batch_users.clear()
|
||||
batch_raw.clear()
|
||||
|
||||
# Final partial batch
|
||||
if batch_raw and self.is_running:
|
||||
results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
|
||||
for m_data, u_data in results:
|
||||
batch_messages.append(m_data)
|
||||
if u_data: batch_users.append(u_data)
|
||||
|
||||
new_count += len(batch_messages)
|
||||
accumulated_count += len(batch_messages)
|
||||
|
||||
for m in batch_messages:
|
||||
if "attachments" in m:
|
||||
accumulated_files += len(m["attachments"])
|
||||
|
||||
|
||||
if self.db:
|
||||
if batch_users: self.db.save_users(batch_users)
|
||||
self.db.save_messages_batch(batch_messages)
|
||||
|
||||
if progress_callback:
|
||||
last_msg = batch_raw[-1]
|
||||
author_name = getattr(last_msg.author, "display_name", "Unknown")
|
||||
await progress_callback(channel_name, accumulated_count, author_name=author_name, thread_count=accumulated_threads, file_count=accumulated_files)
|
||||
|
||||
batch_messages.clear()
|
||||
batch_users.clear()
|
||||
batch_raw.clear()
|
||||
if batch_raw and self.is_running:
|
||||
results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
|
||||
for m_data, u_list in results:
|
||||
batch_messages.append(m_data)
|
||||
if u_list: batch_users.extend(u_list)
|
||||
|
||||
new_count += len(batch_messages)
|
||||
accumulated_count += len(batch_messages)
|
||||
|
||||
for m in batch_messages:
|
||||
if "attachments" in m:
|
||||
accumulated_files += len(m["attachments"])
|
||||
|
||||
|
||||
if self.db:
|
||||
if batch_users: self.db.save_users(batch_users)
|
||||
self.db.save_messages_batch(batch_messages)
|
||||
|
||||
if progress_callback:
|
||||
last_msg = batch_raw[-1]
|
||||
author_name = getattr(last_msg.author, "display_name", "Unknown")
|
||||
await progress_callback(channel_name, accumulated_count, author_name=author_name, thread_count=accumulated_threads, file_count=accumulated_files)
|
||||
|
||||
batch_messages.clear()
|
||||
batch_users.clear()
|
||||
batch_raw.clear()
|
||||
|
||||
except discord.Forbidden:
|
||||
logger.error(f"403 Forbidden: Missing Access to read messages in {channel_name} ({channel_id})")
|
||||
|
|
@ -427,41 +430,55 @@ class DiscordExporter:
|
|||
|
||||
return accumulated_count, accumulated_threads, accumulated_files
|
||||
|
||||
async def _format_user(self, user):
    """Format a user (message author or mentioned user) for DB storage.

    Args:
        user: Object exposing ``id``, ``name``, ``avatar`` and
            ``display_avatar``; ``display_name`` and ``roles`` are used
            when present.

    Returns:
        The new user record (dict) the first time a user id is seen, or
        ``None`` when the id is already present in ``self.user_cache``.
    """
    user_id = str(user.id)
    if user_id in self.user_cache:
        return None

    roles = []
    if hasattr(user, "roles"):
        roles = [str(r.id) for r in user.roles if not r.is_default()]

    user_data = {
        "id": user_id,
        "username": user.name,
        "display_name": getattr(user, "display_name", user.name),
        "avatar_file": None,  # filled in below once the avatar is on disk
        "avatar_url": str(user.display_avatar.url) if user.avatar else None,
        "roles": json.dumps(roles),
    }

    # Claim the cache slot BEFORE the first await. Messages are formatted
    # concurrently (asyncio.gather over _format_message), so a check-then-set
    # that straddles the avatar download would let two coroutines format the
    # same user, write the same avatar file twice and return duplicate records.
    self.user_cache[user_id] = user_data

    if user.avatar:
        try:
            av_name = f"{user_id}.png"
            av_target = self.users_path / av_name
            if not av_target.exists():
                await user.avatar.save(av_target)
            user_data["avatar_file"] = f"users/{av_name}"
        except Exception as e:
            # Best effort: a failed avatar download must not drop the user.
            logger.error(f"Failed to save avatar for {user.name}: {e}")

    return user_data
|
||||
|
||||
async def _format_message(self, msg):
|
||||
"""Formats a single message and its author for DB storage."""
|
||||
# 1. Author handling
|
||||
author = msg.author
|
||||
user_id = str(author.id)
|
||||
user_data = None
|
||||
new_users = []
|
||||
|
||||
if user_id not in self.user_cache:
|
||||
# New user discovered
|
||||
avatar_file = None
|
||||
if author.avatar:
|
||||
try:
|
||||
av_name = f"{user_id}.png"
|
||||
av_target = self.users_path / av_name
|
||||
if not av_target.exists():
|
||||
await author.avatar.save(av_target)
|
||||
avatar_file = f"users/{av_name}"
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save avatar for {author.name}: {e}")
|
||||
# 1. Author handling
|
||||
u_data = await self._format_user(msg.author)
|
||||
if u_data: new_users.append(u_data)
|
||||
|
||||
roles = []
|
||||
if hasattr(author, "roles"):
|
||||
roles = [str(r.id) for r in author.roles if not r.is_default()]
|
||||
|
||||
user_data = {
|
||||
"id": user_id,
|
||||
"username": author.name,
|
||||
"display_name": getattr(author, "display_name", author.name),
|
||||
"avatar_file": avatar_file,
|
||||
"avatar_url": str(author.display_avatar.url) if author.avatar else None,
|
||||
"roles": json.dumps(roles)
|
||||
}
|
||||
self.user_cache[user_id] = user_data
|
||||
# 1.5 Mentions handling (ensure all mentioned users are saved)
|
||||
if msg.mentions:
|
||||
for mention in msg.mentions:
|
||||
u_ment = await self._format_user(mention)
|
||||
if u_ment: new_users.append(u_ment)
|
||||
|
||||
# 2. Attachments handling (Content-Addressable Storage)
|
||||
# ... (rest of the logic remains same, just updating m_data)
|
||||
attachments = []
|
||||
if msg.attachments:
|
||||
for att in msg.attachments:
|
||||
|
|
@ -480,7 +497,6 @@ class DiscordExporter:
|
|||
stickers = []
|
||||
if msg.stickers:
|
||||
for st in msg.stickers:
|
||||
# Deduplicate downloads for the same sticker in one session
|
||||
if st.id in self.sticker_cache:
|
||||
st_bytes = self.sticker_cache[st.id]
|
||||
else:
|
||||
|
|
@ -501,8 +517,6 @@ class DiscordExporter:
|
|||
st_data["name"] = st.name
|
||||
st_data["format_type"] = int(st.format.value) if hasattr(st, "format") and hasattr(st.format, "value") else 1
|
||||
stickers.append(st_data)
|
||||
else:
|
||||
logger.warning(f"Could not download message sticker {st.id} in message {msg.id}")
|
||||
|
||||
# 3. Embeds
|
||||
embeds = []
|
||||
|
|
@ -521,7 +535,6 @@ class DiscordExporter:
|
|||
})
|
||||
|
||||
# 5. Message data
|
||||
# Check for reference (reply/origin of forward)
|
||||
message_reference = None
|
||||
if msg.reference and msg.reference.message_id:
|
||||
message_reference = str(msg.reference.message_id)
|
||||
|
|
@ -530,7 +543,6 @@ class DiscordExporter:
|
|||
content = msg.content or ""
|
||||
msg_type = int(msg.type.value) if hasattr(msg.type, "value") else 0
|
||||
|
||||
# Detect if this message is forwarded (discord.py 2.5+)
|
||||
is_forwarded = getattr(msg.flags, 'forwarded', False)
|
||||
if is_forwarded and hasattr(msg, 'message_snapshots') and msg.message_snapshots:
|
||||
msg_type = 100 # Custom Forward type
|
||||
|
|
@ -538,7 +550,6 @@ class DiscordExporter:
|
|||
if snapshot.content:
|
||||
content = snapshot.content
|
||||
|
||||
# Process snapshot attachments into main attachments list
|
||||
for s_att in snapshot.attachments:
|
||||
att_res = await self._process_media(
|
||||
media_id=s_att.id,
|
||||
|
|
@ -550,15 +561,13 @@ class DiscordExporter:
|
|||
)
|
||||
if att_res: attachments.append(att_res)
|
||||
|
||||
# Process snapshot embeds (simplified)
|
||||
for s_emb in snapshot.embeds:
|
||||
# We reuse the main embeds list
|
||||
embeds.append(s_emb.to_dict())
|
||||
|
||||
m_data = {
|
||||
"id": str(msg.id),
|
||||
"channel_id": str(msg.channel.id),
|
||||
"author_id": user_id,
|
||||
"author_id": str(msg.author.id),
|
||||
"content": content,
|
||||
"timestamp": msg.created_at.isoformat(),
|
||||
"type": msg_type,
|
||||
|
|
@ -571,7 +580,7 @@ class DiscordExporter:
|
|||
"extra_data": None
|
||||
}
|
||||
|
||||
return m_data, user_data
|
||||
return m_data, new_users
|
||||
|
||||
async def _process_media(self, media_id, url, filename, size=None, content_type=None, save_method=None, data=None):
|
||||
"""Downloads and deduplicates any media (attachment or sticker) using SHA-256 (CAS)."""
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue