fix user mentions in backup migration
This commit is contained in:
parent
06091423a2
commit
ee61316451
3 changed files with 94 additions and 76 deletions
|
|
@ -75,7 +75,9 @@ class BackupDatabase:
|
||||||
position INTEGER,
|
position INTEGER,
|
||||||
category_id TEXT,
|
category_id TEXT,
|
||||||
topic TEXT,
|
topic TEXT,
|
||||||
nsfw INTEGER
|
nsfw INTEGER,
|
||||||
|
bitrate INTEGER,
|
||||||
|
slowmode_delay INTEGER
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
@ -299,8 +301,8 @@ class BackupDatabase:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
conn = self._get_conn()
|
conn = self._get_conn()
|
||||||
conn.executemany("""
|
conn.executemany("""
|
||||||
INSERT OR REPLACE INTO channels (id, name, type, position, category_id, topic, nsfw)
|
INSERT OR REPLACE INTO channels (id, name, type, position, category_id, topic, nsfw, bitrate, slowmode_delay)
|
||||||
VALUES (:id, :name, :type, :position, :category_id, :topic, :nsfw)
|
VALUES (:id, :name, :type, :position, :category_id, :topic, :nsfw, :bitrate, :slowmode_delay)
|
||||||
""", channels)
|
""", channels)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
|
||||||
|
|
@ -321,7 +321,8 @@ class BackupChannel:
|
||||||
"""Minimal stand-in for discord.TextChannel / ForumChannel / VoiceChannel."""
|
"""Minimal stand-in for discord.TextChannel / ForumChannel / VoiceChannel."""
|
||||||
|
|
||||||
__slots__ = ("id", "name", "type", "position", "topic", "nsfw",
|
__slots__ = ("id", "name", "type", "position", "topic", "nsfw",
|
||||||
"category_id", "available_tags", "parent_id", "guild", "overwrites")
|
"category_id", "available_tags", "parent_id", "guild", "overwrites",
|
||||||
|
"bitrate", "slowmode_delay")
|
||||||
|
|
||||||
_TYPE_MAP = {
|
_TYPE_MAP = {
|
||||||
"text": ChannelType.text,
|
"text": ChannelType.text,
|
||||||
|
|
@ -340,6 +341,8 @@ class BackupChannel:
|
||||||
self.position = data.get("position", 0)
|
self.position = data.get("position", 0)
|
||||||
self.topic = data.get("topic")
|
self.topic = data.get("topic")
|
||||||
self.nsfw = bool(data.get("nsfw", False))
|
self.nsfw = bool(data.get("nsfw", False))
|
||||||
|
self.bitrate = data.get("bitrate")
|
||||||
|
self.slowmode_delay = data.get("slowmode_delay")
|
||||||
cid = data.get("category_id")
|
cid = data.get("category_id")
|
||||||
self.category_id = parse_snowflake(cid) if cid else category_id
|
self.category_id = parse_snowflake(cid) if cid else category_id
|
||||||
self.parent_id = self.category_id
|
self.parent_id = self.category_id
|
||||||
|
|
@ -683,7 +686,7 @@ class BackupMessage:
|
||||||
cid = data.get("channel_id")
|
cid = data.get("channel_id")
|
||||||
self.channel_id = parse_snowflake(cid) if cid else (channel.id if channel else None)
|
self.channel_id = parse_snowflake(cid) if cid else (channel.id if channel else None)
|
||||||
|
|
||||||
# Mentions (simplified)
|
# Mentions (resolved on the fly by clean_mentions via guild.get_member)
|
||||||
self.mentions = []
|
self.mentions = []
|
||||||
self.role_mentions = []
|
self.role_mentions = []
|
||||||
self.channel_mentions = []
|
self.channel_mentions = []
|
||||||
|
|
@ -925,7 +928,11 @@ class BackupGuild:
|
||||||
|
|
||||||
def get_member(self, user_id: int) -> "BackupMember | None":
|
def get_member(self, user_id: int) -> "BackupMember | None":
|
||||||
if self._reader:
|
if self._reader:
|
||||||
return self._reader._member_map.get(parse_snowflake(user_id))
|
uid = parse_snowflake(user_id)
|
||||||
|
if uid in self._reader._member_map:
|
||||||
|
return self._reader._member_map[uid]
|
||||||
|
# Lazy load from DB
|
||||||
|
return self._reader._resolve_author(uid)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_role(self, role_id: int) -> "BackupRole | None":
|
def get_role(self, role_id: int) -> "BackupRole | None":
|
||||||
|
|
|
||||||
|
|
@ -250,7 +250,9 @@ class DiscordExporter:
|
||||||
"position": cat.position,
|
"position": cat.position,
|
||||||
"category_id": None,
|
"category_id": None,
|
||||||
"topic": None,
|
"topic": None,
|
||||||
"nsfw": 0
|
"nsfw": 0,
|
||||||
|
"bitrate": None,
|
||||||
|
"slowmode_delay": None
|
||||||
})
|
})
|
||||||
|
|
||||||
# Add child channels to list
|
# Add child channels to list
|
||||||
|
|
@ -323,7 +325,9 @@ class DiscordExporter:
|
||||||
"type": int(c.type.value) if hasattr(c.type, "value") else 0,
|
"type": int(c.type.value) if hasattr(c.type, "value") else 0,
|
||||||
"position": c.position,
|
"position": c.position,
|
||||||
"topic": getattr(c, "topic", None),
|
"topic": getattr(c, "topic", None),
|
||||||
"nsfw": 1 if getattr(c, "nsfw", False) else 0
|
"nsfw": 1 if getattr(c, "nsfw", False) else 0,
|
||||||
|
"bitrate": getattr(c, "bitrate", None),
|
||||||
|
"slowmode_delay": getattr(c, "slowmode_delay", None)
|
||||||
}
|
}
|
||||||
|
|
||||||
return data, ch_permissions, ch_forum_tags
|
return data, ch_permissions, ch_forum_tags
|
||||||
|
|
@ -363,9 +367,9 @@ class DiscordExporter:
|
||||||
|
|
||||||
if len(batch_raw) >= BATCH_SIZE:
|
if len(batch_raw) >= BATCH_SIZE:
|
||||||
results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
|
results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
|
||||||
for m_data, u_data in results:
|
for m_data, u_list in results:
|
||||||
batch_messages.append(m_data)
|
batch_messages.append(m_data)
|
||||||
if u_data: batch_users.append(u_data)
|
if u_list: batch_users.extend(u_list)
|
||||||
|
|
||||||
new_count += len(batch_messages)
|
new_count += len(batch_messages)
|
||||||
accumulated_count += len(batch_messages)
|
accumulated_count += len(batch_messages)
|
||||||
|
|
@ -389,33 +393,32 @@ class DiscordExporter:
|
||||||
batch_users.clear()
|
batch_users.clear()
|
||||||
batch_raw.clear()
|
batch_raw.clear()
|
||||||
|
|
||||||
# Final partial batch
|
if batch_raw and self.is_running:
|
||||||
if batch_raw and self.is_running:
|
results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
|
||||||
results = await asyncio.gather(*(self._format_message(m) for m in batch_raw))
|
for m_data, u_list in results:
|
||||||
for m_data, u_data in results:
|
batch_messages.append(m_data)
|
||||||
batch_messages.append(m_data)
|
if u_list: batch_users.extend(u_list)
|
||||||
if u_data: batch_users.append(u_data)
|
|
||||||
|
|
||||||
new_count += len(batch_messages)
|
new_count += len(batch_messages)
|
||||||
accumulated_count += len(batch_messages)
|
accumulated_count += len(batch_messages)
|
||||||
|
|
||||||
for m in batch_messages:
|
for m in batch_messages:
|
||||||
if "attachments" in m:
|
if "attachments" in m:
|
||||||
accumulated_files += len(m["attachments"])
|
accumulated_files += len(m["attachments"])
|
||||||
|
|
||||||
|
|
||||||
if self.db:
|
if self.db:
|
||||||
if batch_users: self.db.save_users(batch_users)
|
if batch_users: self.db.save_users(batch_users)
|
||||||
self.db.save_messages_batch(batch_messages)
|
self.db.save_messages_batch(batch_messages)
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
last_msg = batch_raw[-1]
|
last_msg = batch_raw[-1]
|
||||||
author_name = getattr(last_msg.author, "display_name", "Unknown")
|
author_name = getattr(last_msg.author, "display_name", "Unknown")
|
||||||
await progress_callback(channel_name, accumulated_count, author_name=author_name, thread_count=accumulated_threads, file_count=accumulated_files)
|
await progress_callback(channel_name, accumulated_count, author_name=author_name, thread_count=accumulated_threads, file_count=accumulated_files)
|
||||||
|
|
||||||
batch_messages.clear()
|
batch_messages.clear()
|
||||||
batch_users.clear()
|
batch_users.clear()
|
||||||
batch_raw.clear()
|
batch_raw.clear()
|
||||||
|
|
||||||
except discord.Forbidden:
|
except discord.Forbidden:
|
||||||
logger.error(f"403 Forbidden: Missing Access to read messages in {channel_name} ({channel_id})")
|
logger.error(f"403 Forbidden: Missing Access to read messages in {channel_name} ({channel_id})")
|
||||||
|
|
@ -427,41 +430,55 @@ class DiscordExporter:
|
||||||
|
|
||||||
return accumulated_count, accumulated_threads, accumulated_files
|
return accumulated_count, accumulated_threads, accumulated_files
|
||||||
|
|
||||||
|
async def _format_user(self, user):
|
||||||
|
"""Formats user data for the author or a mention."""
|
||||||
|
user_id = str(user.id)
|
||||||
|
if user_id in self.user_cache:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# New user discovered
|
||||||
|
avatar_file = None
|
||||||
|
if user.avatar:
|
||||||
|
try:
|
||||||
|
av_name = f"{user_id}.png"
|
||||||
|
av_target = self.users_path / av_name
|
||||||
|
if not av_target.exists():
|
||||||
|
await user.avatar.save(av_target)
|
||||||
|
avatar_file = f"users/{av_name}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save avatar for {user.name}: {e}")
|
||||||
|
|
||||||
|
roles = []
|
||||||
|
if hasattr(user, "roles"):
|
||||||
|
roles = [str(r.id) for r in user.roles if not r.is_default()]
|
||||||
|
|
||||||
|
user_data = {
|
||||||
|
"id": user_id,
|
||||||
|
"username": user.name,
|
||||||
|
"display_name": getattr(user, "display_name", user.name),
|
||||||
|
"avatar_file": avatar_file,
|
||||||
|
"avatar_url": str(user.display_avatar.url) if user.avatar else None,
|
||||||
|
"roles": json.dumps(roles)
|
||||||
|
}
|
||||||
|
self.user_cache[user_id] = user_data
|
||||||
|
return user_data
|
||||||
|
|
||||||
async def _format_message(self, msg):
|
async def _format_message(self, msg):
|
||||||
"""Formats a single message and its author for DB storage."""
|
"""Formats a single message and its author for DB storage."""
|
||||||
|
new_users = []
|
||||||
|
|
||||||
# 1. Author handling
|
# 1. Author handling
|
||||||
author = msg.author
|
u_data = await self._format_user(msg.author)
|
||||||
user_id = str(author.id)
|
if u_data: new_users.append(u_data)
|
||||||
user_data = None
|
|
||||||
|
|
||||||
if user_id not in self.user_cache:
|
# 1.5 Mentions handling (ensure all mentioned users are saved)
|
||||||
# New user discovered
|
if msg.mentions:
|
||||||
avatar_file = None
|
for mention in msg.mentions:
|
||||||
if author.avatar:
|
u_ment = await self._format_user(mention)
|
||||||
try:
|
if u_ment: new_users.append(u_ment)
|
||||||
av_name = f"{user_id}.png"
|
|
||||||
av_target = self.users_path / av_name
|
|
||||||
if not av_target.exists():
|
|
||||||
await author.avatar.save(av_target)
|
|
||||||
avatar_file = f"users/{av_name}"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to save avatar for {author.name}: {e}")
|
|
||||||
|
|
||||||
roles = []
|
|
||||||
if hasattr(author, "roles"):
|
|
||||||
roles = [str(r.id) for r in author.roles if not r.is_default()]
|
|
||||||
|
|
||||||
user_data = {
|
|
||||||
"id": user_id,
|
|
||||||
"username": author.name,
|
|
||||||
"display_name": getattr(author, "display_name", author.name),
|
|
||||||
"avatar_file": avatar_file,
|
|
||||||
"avatar_url": str(author.display_avatar.url) if author.avatar else None,
|
|
||||||
"roles": json.dumps(roles)
|
|
||||||
}
|
|
||||||
self.user_cache[user_id] = user_data
|
|
||||||
|
|
||||||
# 2. Attachments handling (Content-Addressable Storage)
|
# 2. Attachments handling (Content-Addressable Storage)
|
||||||
|
# ... (rest of the logic remains same, just updating m_data)
|
||||||
attachments = []
|
attachments = []
|
||||||
if msg.attachments:
|
if msg.attachments:
|
||||||
for att in msg.attachments:
|
for att in msg.attachments:
|
||||||
|
|
@ -480,7 +497,6 @@ class DiscordExporter:
|
||||||
stickers = []
|
stickers = []
|
||||||
if msg.stickers:
|
if msg.stickers:
|
||||||
for st in msg.stickers:
|
for st in msg.stickers:
|
||||||
# Deduplicate downloads for the same sticker in one session
|
|
||||||
if st.id in self.sticker_cache:
|
if st.id in self.sticker_cache:
|
||||||
st_bytes = self.sticker_cache[st.id]
|
st_bytes = self.sticker_cache[st.id]
|
||||||
else:
|
else:
|
||||||
|
|
@ -501,8 +517,6 @@ class DiscordExporter:
|
||||||
st_data["name"] = st.name
|
st_data["name"] = st.name
|
||||||
st_data["format_type"] = int(st.format.value) if hasattr(st, "format") and hasattr(st.format, "value") else 1
|
st_data["format_type"] = int(st.format.value) if hasattr(st, "format") and hasattr(st.format, "value") else 1
|
||||||
stickers.append(st_data)
|
stickers.append(st_data)
|
||||||
else:
|
|
||||||
logger.warning(f"Could not download message sticker {st.id} in message {msg.id}")
|
|
||||||
|
|
||||||
# 3. Embeds
|
# 3. Embeds
|
||||||
embeds = []
|
embeds = []
|
||||||
|
|
@ -521,7 +535,6 @@ class DiscordExporter:
|
||||||
})
|
})
|
||||||
|
|
||||||
# 5. Message data
|
# 5. Message data
|
||||||
# Check for reference (reply/origin of forward)
|
|
||||||
message_reference = None
|
message_reference = None
|
||||||
if msg.reference and msg.reference.message_id:
|
if msg.reference and msg.reference.message_id:
|
||||||
message_reference = str(msg.reference.message_id)
|
message_reference = str(msg.reference.message_id)
|
||||||
|
|
@ -530,7 +543,6 @@ class DiscordExporter:
|
||||||
content = msg.content or ""
|
content = msg.content or ""
|
||||||
msg_type = int(msg.type.value) if hasattr(msg.type, "value") else 0
|
msg_type = int(msg.type.value) if hasattr(msg.type, "value") else 0
|
||||||
|
|
||||||
# Detect if this message is forwarded (discord.py 2.5+)
|
|
||||||
is_forwarded = getattr(msg.flags, 'forwarded', False)
|
is_forwarded = getattr(msg.flags, 'forwarded', False)
|
||||||
if is_forwarded and hasattr(msg, 'message_snapshots') and msg.message_snapshots:
|
if is_forwarded and hasattr(msg, 'message_snapshots') and msg.message_snapshots:
|
||||||
msg_type = 100 # Custom Forward type
|
msg_type = 100 # Custom Forward type
|
||||||
|
|
@ -538,7 +550,6 @@ class DiscordExporter:
|
||||||
if snapshot.content:
|
if snapshot.content:
|
||||||
content = snapshot.content
|
content = snapshot.content
|
||||||
|
|
||||||
# Process snapshot attachments into main attachments list
|
|
||||||
for s_att in snapshot.attachments:
|
for s_att in snapshot.attachments:
|
||||||
att_res = await self._process_media(
|
att_res = await self._process_media(
|
||||||
media_id=s_att.id,
|
media_id=s_att.id,
|
||||||
|
|
@ -550,15 +561,13 @@ class DiscordExporter:
|
||||||
)
|
)
|
||||||
if att_res: attachments.append(att_res)
|
if att_res: attachments.append(att_res)
|
||||||
|
|
||||||
# Process snapshot embeds (simplified)
|
|
||||||
for s_emb in snapshot.embeds:
|
for s_emb in snapshot.embeds:
|
||||||
# We reuse the main embeds list
|
|
||||||
embeds.append(s_emb.to_dict())
|
embeds.append(s_emb.to_dict())
|
||||||
|
|
||||||
m_data = {
|
m_data = {
|
||||||
"id": str(msg.id),
|
"id": str(msg.id),
|
||||||
"channel_id": str(msg.channel.id),
|
"channel_id": str(msg.channel.id),
|
||||||
"author_id": user_id,
|
"author_id": str(msg.author.id),
|
||||||
"content": content,
|
"content": content,
|
||||||
"timestamp": msg.created_at.isoformat(),
|
"timestamp": msg.created_at.isoformat(),
|
||||||
"type": msg_type,
|
"type": msg_type,
|
||||||
|
|
@ -571,7 +580,7 @@ class DiscordExporter:
|
||||||
"extra_data": None
|
"extra_data": None
|
||||||
}
|
}
|
||||||
|
|
||||||
return m_data, user_data
|
return m_data, new_users
|
||||||
|
|
||||||
async def _process_media(self, media_id, url, filename, size=None, content_type=None, save_method=None, data=None):
|
async def _process_media(self, media_id, url, filename, size=None, content_type=None, save_method=None, data=None):
|
||||||
"""Downloads and deduplicates any media (attachment or sticker) using SHA-256 (CAS)."""
|
"""Downloads and deduplicates any media (attachment or sticker) using SHA-256 (CAS)."""
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue