implement thread mentions
This commit is contained in:
parent
4967183442
commit
939e063d46
5 changed files with 50 additions and 11 deletions
|
|
@ -945,11 +945,23 @@ class BackupGuild:
|
||||||
return next((c for c in self._reader._channels if c.id == parse_snowflake(channel_id)), None)
|
return next((c for c in self._reader._channels if c.id == parse_snowflake(channel_id)), None)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_thread(self, thread_id: int) -> "BackupChannel | None":
|
def get_thread(self, thread_id: int) -> "BackupThread | None":
|
||||||
|
"""Mock discord.Guild.get_thread."""
|
||||||
if self._reader:
|
if self._reader:
|
||||||
return next((c for c in self._reader._threads if c.id == parse_snowflake(thread_id)), None)
|
return next((t for t in self._reader._threads if t.id == parse_snowflake(thread_id)), None)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def fetch_channels(self) -> List[Union["BackupChannel", "BackupCategory"]]:
|
||||||
|
"""Async stub for discord.Guild.fetch_channels."""
|
||||||
|
if self._reader:
|
||||||
|
self._reader._ensure_structure_loaded()
|
||||||
|
return list(self._reader._channels) + list(self._reader._categories)
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def fetch_active_threads(self) -> List["BackupThread"]:
|
||||||
|
"""Async stub for discord.Guild.active_threads helper (mirrors the property)."""
|
||||||
|
return self.active_threads
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"BackupGuild(id={self.id}, name='{self.name}')"
|
return f"BackupGuild(id={self.id}, name='{self.name}')"
|
||||||
|
|
||||||
|
|
@ -1237,6 +1249,14 @@ class BackupReader:
|
||||||
channels = [c for c in channels if c.category_id == category_id]
|
channels = [c for c in channels if c.category_id == category_id]
|
||||||
return channels
|
return channels
|
||||||
|
|
||||||
|
async def fetch_channels(self) -> List[Union[BackupChannel, BackupCategory]]:
|
||||||
|
"""Uniform interface for all channels (including categories)."""
|
||||||
|
return list(self.channels) + list(self.categories)
|
||||||
|
|
||||||
|
async def get_active_threads(self) -> List[BackupThread]:
|
||||||
|
"""Uniform interface for active threads."""
|
||||||
|
return [t for t in self.threads if not t.archived]
|
||||||
|
|
||||||
async def get_backed_up_channel_ids(self) -> List[int]:
|
async def get_backed_up_channel_ids(self) -> List[int]:
|
||||||
"""Returns a list of channel IDs that have messages in the database."""
|
"""Returns a list of channel IDs that have messages in the database."""
|
||||||
if not self.db: return []
|
if not self.db: return []
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ class DiscordReader:
|
||||||
|
|
||||||
# Channel Types
|
# Channel Types
|
||||||
CHANNEL_TYPE_TEXT = discord.ChannelType.text
|
CHANNEL_TYPE_TEXT = discord.ChannelType.text
|
||||||
|
CHANNEL_TYPE_VOICE = discord.ChannelType.voice
|
||||||
CHANNEL_TYPE_NEWS = discord.ChannelType.news
|
CHANNEL_TYPE_NEWS = discord.ChannelType.news
|
||||||
CHANNEL_TYPE_FORUM = discord.ChannelType.forum
|
CHANNEL_TYPE_FORUM = discord.ChannelType.forum
|
||||||
|
|
||||||
|
|
@ -228,6 +229,18 @@ class DiscordReader:
|
||||||
all_channels = [c for c in all_channels if c.category_id == category_id]
|
all_channels = [c for c in all_channels if c.category_id == category_id]
|
||||||
return all_channels
|
return all_channels
|
||||||
|
|
||||||
|
async def get_active_threads(self) -> list[discord.Thread]:
|
||||||
|
"""Returns all active threads in the guild."""
|
||||||
|
if not self.guild:
|
||||||
|
return []
|
||||||
|
return self.guild.active_threads
|
||||||
|
|
||||||
|
async def fetch_channels(self) -> list[Union[discord.TextChannel, discord.VoiceChannel, discord.CategoryChannel, discord.ForumChannel]]:
|
||||||
|
"""Async stub for discord.Guild.fetch_channels."""
|
||||||
|
if not self.guild:
|
||||||
|
return []
|
||||||
|
return await self.guild.fetch_channels()
|
||||||
|
|
||||||
async def get_channel(self, channel_id: int):
|
async def get_channel(self, channel_id: int):
|
||||||
"""Returns a channel object."""
|
"""Returns a channel object."""
|
||||||
return await self.client.fetch_channel(channel_id)
|
return await self.client.fetch_channel(channel_id)
|
||||||
|
|
|
||||||
|
|
@ -224,9 +224,9 @@ class MigrationState:
|
||||||
row = conn.execute("SELECT channel_id, target_msg_id FROM message_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone()
|
row = conn.execute("SELECT channel_id, target_msg_id FROM message_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone()
|
||||||
if row:
|
if row:
|
||||||
return str(row["channel_id"]), str(row["target_msg_id"])
|
return str(row["channel_id"]), str(row["target_msg_id"])
|
||||||
row = conn.execute("SELECT thread_id, target_msg_id FROM thread_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone()
|
row = conn.execute("SELECT channel_id, target_msg_id FROM thread_mappings WHERE source_msg_id = ?", (str(discord_id),)).fetchone()
|
||||||
if row:
|
if row:
|
||||||
return str(row["thread_id"]), str(row["target_msg_id"])
|
return str(row["channel_id"]), str(row["target_msg_id"])
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# --- Danger Zone Clearing ---
|
# --- Danger Zone Clearing ---
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,10 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
|
||||||
|
|
||||||
# 3. Try live lookup (fallback)
|
# 3. Try live lookup (fallback)
|
||||||
if not name:
|
if not name:
|
||||||
|
try:
|
||||||
channel = guild.get_channel(cid) or guild.get_thread(cid)
|
channel = guild.get_channel(cid) or guild.get_thread(cid)
|
||||||
|
except Exception:
|
||||||
|
channel = None
|
||||||
if not channel and channel_mentions:
|
if not channel and channel_mentions:
|
||||||
channel = next((c for c in channel_mentions if c.id == cid), None)
|
channel = next((c for c in channel_mentions if c.id == cid), None)
|
||||||
if channel:
|
if channel:
|
||||||
|
|
@ -261,12 +264,12 @@ async def migrate_messages(
|
||||||
try:
|
try:
|
||||||
logger.debug(f"Pre-fetching channel and thread names for guild {context.discord_reader.guild.id}...")
|
logger.debug(f"Pre-fetching channel and thread names for guild {context.discord_reader.guild.id}...")
|
||||||
# fetch_channels usually includes all non-thread channels
|
# fetch_channels usually includes all non-thread channels
|
||||||
all_channels = await context.discord_reader.guild.fetch_channels()
|
all_channels = await context.discord_reader.fetch_channels()
|
||||||
for c in all_channels:
|
for c in all_channels:
|
||||||
context.channel_names[str(c.id)] = c.name
|
context.channel_names[str(c.id)] = c.name
|
||||||
|
|
||||||
# active_threads helps find threads that might be mentioned
|
# get_active_threads helps find threads that might be mentioned
|
||||||
threads = await context.discord_reader.guild.active_threads()
|
threads = await context.discord_reader.get_active_threads()
|
||||||
for t in threads:
|
for t in threads:
|
||||||
context.channel_names[str(t.id)] = t.name
|
context.channel_names[str(t.id)] = t.name
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,10 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
|
||||||
|
|
||||||
# 3. Try live lookup (fallback)
|
# 3. Try live lookup (fallback)
|
||||||
if not name:
|
if not name:
|
||||||
|
try:
|
||||||
channel = guild.get_channel(cid) or guild.get_thread(cid)
|
channel = guild.get_channel(cid) or guild.get_thread(cid)
|
||||||
|
except Exception:
|
||||||
|
channel = None
|
||||||
if not channel and channel_mentions:
|
if not channel and channel_mentions:
|
||||||
channel = next((c for c in channel_mentions if c.id == cid), None)
|
channel = next((c for c in channel_mentions if c.id == cid), None)
|
||||||
if channel:
|
if channel:
|
||||||
|
|
@ -260,11 +263,11 @@ async def migrate_messages(
|
||||||
context.channel_names = {}
|
context.channel_names = {}
|
||||||
try:
|
try:
|
||||||
logger.debug(f"Pre-fetching channel and thread names for guild {context.discord_reader.guild.id}...")
|
logger.debug(f"Pre-fetching channel and thread names for guild {context.discord_reader.guild.id}...")
|
||||||
all_channels = await context.discord_reader.guild.fetch_channels()
|
all_channels = await context.discord_reader.fetch_channels()
|
||||||
for c in all_channels:
|
for c in all_channels:
|
||||||
context.channel_names[str(c.id)] = c.name
|
context.channel_names[str(c.id)] = c.name
|
||||||
|
|
||||||
threads = await context.discord_reader.guild.active_threads()
|
threads = await context.discord_reader.get_active_threads()
|
||||||
for t in threads:
|
for t in threads:
|
||||||
context.channel_names[str(t.id)] = t.name
|
context.channel_names[str(t.id)] = t.name
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue