wrap discord.py into discord_reader
This commit is contained in:
parent
650aea98eb
commit
600209e419
4 changed files with 29 additions and 14 deletions
|
|
@ -5,6 +5,24 @@ from typing import AsyncGenerator, Dict, Any
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DiscordReader:
|
||||
# -- Provider constants (used by migration scripts instead of importing discord) --
|
||||
MESSAGE_TYPE_DEFAULT = discord.MessageType.default
|
||||
MESSAGE_TYPE_REPLY = discord.MessageType.reply
|
||||
MESSAGE_TYPE_THREAD_STARTER = discord.MessageType.thread_starter_message
|
||||
|
||||
@staticmethod
|
||||
def find_item(iterable, **attrs):
|
||||
"""Find first item in iterable matching all attrs. Drop-in for discord.utils.get()."""
|
||||
for item in iterable:
|
||||
if all(getattr(item, k, None) == v for k, v in attrs.items()):
|
||||
return item
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def create_permission_overwrite():
|
||||
"""Factory for discord.PermissionOverwrite, keeps the import centralized."""
|
||||
return discord.PermissionOverwrite()
|
||||
|
||||
def __init__(self, token: str, server_id: str):
|
||||
self.token = token
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import asyncio
|
||||
import discord
|
||||
import logging
|
||||
import re
|
||||
from typing import Callable, Awaitable, Dict, Any
|
||||
|
|
@ -33,10 +32,10 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
|
|||
def replace_role(match):
|
||||
rid = int(match.group(1))
|
||||
# 1. Try provided guild cache/list
|
||||
role = guild.get_role(rid) or discord.utils.get(guild.roles, id=rid)
|
||||
role = guild.get_role(rid) or next((r for r in guild.roles if r.id == rid), None)
|
||||
# 2. Try message's role_mentions
|
||||
if not role and role_mentions:
|
||||
role = discord.utils.get(role_mentions, id=rid)
|
||||
role = next((r for r in role_mentions if r.id == rid), None)
|
||||
|
||||
# 3. Try all guilds the client is aware of (fallback for cache issues)
|
||||
if not role and hasattr(guild, 'client'):
|
||||
|
|
@ -85,7 +84,7 @@ async def analyze_migration(context: MigrationContext, source_channel_id: int, a
|
|||
stats["threads"] += thread_stats["threads"] # Nested threads (rare in Discord but possible in forum channels)
|
||||
|
||||
# Consistent filtering with migrate_messages
|
||||
if msg.type not in [discord.MessageType.default, discord.MessageType.reply, discord.MessageType.thread_starter_message]:
|
||||
if msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY, context.discord_reader.MESSAGE_TYPE_THREAD_STARTER]:
|
||||
continue
|
||||
|
||||
stats["messages"] += 1
|
||||
|
|
@ -115,9 +114,9 @@ async def migrate_messages(context: MigrationContext, source_channel_id: int, ta
|
|||
|
||||
# Skip system messages like "pinned a message", etc.
|
||||
# We treat thread_starter_message (type 21) as our thread marker.
|
||||
if msg.type == discord.MessageType.thread_starter_message:
|
||||
if msg.type == context.discord_reader.MESSAGE_TYPE_THREAD_STARTER:
|
||||
content = f"> <<< THREAD: **{msg.channel.name}** >>>"
|
||||
elif msg.type not in [discord.MessageType.default, discord.MessageType.reply]:
|
||||
elif msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY]:
|
||||
# If we are skipping the parent, we STILL need to check for a thread!
|
||||
if hasattr(msg, 'thread') and msg.thread:
|
||||
thread = msg.thread
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import asyncio
|
||||
import discord
|
||||
import logging
|
||||
import re
|
||||
from typing import Callable, Awaitable, Dict, Any
|
||||
|
|
@ -33,10 +32,10 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
|
|||
def replace_role(match):
|
||||
rid = int(match.group(1))
|
||||
# 1. Try provided guild cache/list
|
||||
role = guild.get_role(rid) or discord.utils.get(guild.roles, id=rid)
|
||||
role = guild.get_role(rid) or next((r for r in guild.roles if r.id == rid), None)
|
||||
# 2. Try message's role_mentions
|
||||
if not role and role_mentions:
|
||||
role = discord.utils.get(role_mentions, id=rid)
|
||||
role = next((r for r in role_mentions if r.id == rid), None)
|
||||
|
||||
# 3. Try all guilds the client is aware of (fallback for cache issues)
|
||||
if not role and hasattr(guild, 'client'):
|
||||
|
|
@ -84,7 +83,7 @@ async def analyze_migration(context: MigrationContext, source_channel_id: int, a
|
|||
stats["threads"] += thread_stats["threads"]
|
||||
|
||||
# Consistent filtering with migrate_messages
|
||||
if msg.type not in [discord.MessageType.default, discord.MessageType.reply, discord.MessageType.thread_starter_message]:
|
||||
if msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY, context.discord_reader.MESSAGE_TYPE_THREAD_STARTER]:
|
||||
continue
|
||||
|
||||
stats["messages"] += 1
|
||||
|
|
@ -115,11 +114,11 @@ async def migrate_messages(context: MigrationContext, source_channel_id: int, ta
|
|||
# Skip system messages like "pinned a message", etc.
|
||||
# We treat thread_starter_message (type 21) as our thread marker.
|
||||
content = "" # Initialize content
|
||||
if msg.type == discord.MessageType.thread_starter_message:
|
||||
if msg.type == context.discord_reader.MESSAGE_TYPE_THREAD_STARTER:
|
||||
content = f"> <<< THREAD: **{msg.channel.name}** >>>"
|
||||
# If it's a thread starter and we already processed the thread at the top,
|
||||
# we might be double-posting. But we want it as a marker.
|
||||
elif msg.type not in [discord.MessageType.default, discord.MessageType.reply]:
|
||||
elif msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY]:
|
||||
# If we are skipping the parent, we STILL need to check for a thread!
|
||||
if hasattr(msg, 'thread') and msg.thread:
|
||||
thread = msg.thread
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import discord
|
||||
from typing import Callable, Awaitable
|
||||
|
||||
from src.core.base import MigrationContext
|
||||
|
|
@ -106,7 +105,7 @@ async def sync_permissions(context: MigrationContext, progress_callback: Callabl
|
|||
# Merge logic: channel overwrites specific settings, but
|
||||
# keeps inherited settings for any permission marked as 'None' in channel
|
||||
cat_ow = final_overwrites[target.id]
|
||||
merged = discord.PermissionOverwrite()
|
||||
merged = context.discord_reader.create_permission_overwrite()
|
||||
# Apply category settings
|
||||
for name, value in cat_ow:
|
||||
setattr(merged, name, value)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue