wrap discord.py into discord_reader
This commit is contained in:
parent
650aea98eb
commit
600209e419
4 changed files with 29 additions and 14 deletions
|
|
@ -5,6 +5,24 @@ from typing import AsyncGenerator, Dict, Any
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class DiscordReader:
|
class DiscordReader:
|
||||||
|
# -- Provider constants (used by migration scripts instead of importing discord) --
|
||||||
|
MESSAGE_TYPE_DEFAULT = discord.MessageType.default
|
||||||
|
MESSAGE_TYPE_REPLY = discord.MessageType.reply
|
||||||
|
MESSAGE_TYPE_THREAD_STARTER = discord.MessageType.thread_starter_message
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_item(iterable, **attrs):
|
||||||
|
"""Find first item in iterable matching all attrs. Drop-in for discord.utils.get()."""
|
||||||
|
for item in iterable:
|
||||||
|
if all(getattr(item, k, None) == v for k, v in attrs.items()):
|
||||||
|
return item
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_permission_overwrite():
|
||||||
|
"""Factory for discord.PermissionOverwrite, keeps the import centralized."""
|
||||||
|
return discord.PermissionOverwrite()
|
||||||
|
|
||||||
def __init__(self, token: str, server_id: str):
|
def __init__(self, token: str, server_id: str):
|
||||||
self.token = token
|
self.token = token
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import discord
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Callable, Awaitable, Dict, Any
|
from typing import Callable, Awaitable, Dict, Any
|
||||||
|
|
@ -33,10 +32,10 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
|
||||||
def replace_role(match):
|
def replace_role(match):
|
||||||
rid = int(match.group(1))
|
rid = int(match.group(1))
|
||||||
# 1. Try provided guild cache/list
|
# 1. Try provided guild cache/list
|
||||||
role = guild.get_role(rid) or discord.utils.get(guild.roles, id=rid)
|
role = guild.get_role(rid) or next((r for r in guild.roles if r.id == rid), None)
|
||||||
# 2. Try message's role_mentions
|
# 2. Try message's role_mentions
|
||||||
if not role and role_mentions:
|
if not role and role_mentions:
|
||||||
role = discord.utils.get(role_mentions, id=rid)
|
role = next((r for r in role_mentions if r.id == rid), None)
|
||||||
|
|
||||||
# 3. Try all guilds the client is aware of (fallback for cache issues)
|
# 3. Try all guilds the client is aware of (fallback for cache issues)
|
||||||
if not role and hasattr(guild, 'client'):
|
if not role and hasattr(guild, 'client'):
|
||||||
|
|
@ -85,7 +84,7 @@ async def analyze_migration(context: MigrationContext, source_channel_id: int, a
|
||||||
stats["threads"] += thread_stats["threads"] # Nested threads (rare in Discord but possible in forum channels)
|
stats["threads"] += thread_stats["threads"] # Nested threads (rare in Discord but possible in forum channels)
|
||||||
|
|
||||||
# Consistent filtering with migrate_messages
|
# Consistent filtering with migrate_messages
|
||||||
if msg.type not in [discord.MessageType.default, discord.MessageType.reply, discord.MessageType.thread_starter_message]:
|
if msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY, context.discord_reader.MESSAGE_TYPE_THREAD_STARTER]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
stats["messages"] += 1
|
stats["messages"] += 1
|
||||||
|
|
@ -115,9 +114,9 @@ async def migrate_messages(context: MigrationContext, source_channel_id: int, ta
|
||||||
|
|
||||||
# Skip system messages like "pinned a message", etc.
|
# Skip system messages like "pinned a message", etc.
|
||||||
# We treat thread_starter_message (type 21) as our thread marker.
|
# We treat thread_starter_message (type 21) as our thread marker.
|
||||||
if msg.type == discord.MessageType.thread_starter_message:
|
if msg.type == context.discord_reader.MESSAGE_TYPE_THREAD_STARTER:
|
||||||
content = f"> <<< THREAD: **{msg.channel.name}** >>>"
|
content = f"> <<< THREAD: **{msg.channel.name}** >>>"
|
||||||
elif msg.type not in [discord.MessageType.default, discord.MessageType.reply]:
|
elif msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY]:
|
||||||
# If we are skipping the parent, we STILL need to check for a thread!
|
# If we are skipping the parent, we STILL need to check for a thread!
|
||||||
if hasattr(msg, 'thread') and msg.thread:
|
if hasattr(msg, 'thread') and msg.thread:
|
||||||
thread = msg.thread
|
thread = msg.thread
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import discord
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Callable, Awaitable, Dict, Any
|
from typing import Callable, Awaitable, Dict, Any
|
||||||
|
|
@ -33,10 +32,10 @@ def clean_mentions(content: str, guild, user_mentions=None, role_mentions=None,
|
||||||
def replace_role(match):
|
def replace_role(match):
|
||||||
rid = int(match.group(1))
|
rid = int(match.group(1))
|
||||||
# 1. Try provided guild cache/list
|
# 1. Try provided guild cache/list
|
||||||
role = guild.get_role(rid) or discord.utils.get(guild.roles, id=rid)
|
role = guild.get_role(rid) or next((r for r in guild.roles if r.id == rid), None)
|
||||||
# 2. Try message's role_mentions
|
# 2. Try message's role_mentions
|
||||||
if not role and role_mentions:
|
if not role and role_mentions:
|
||||||
role = discord.utils.get(role_mentions, id=rid)
|
role = next((r for r in role_mentions if r.id == rid), None)
|
||||||
|
|
||||||
# 3. Try all guilds the client is aware of (fallback for cache issues)
|
# 3. Try all guilds the client is aware of (fallback for cache issues)
|
||||||
if not role and hasattr(guild, 'client'):
|
if not role and hasattr(guild, 'client'):
|
||||||
|
|
@ -84,7 +83,7 @@ async def analyze_migration(context: MigrationContext, source_channel_id: int, a
|
||||||
stats["threads"] += thread_stats["threads"]
|
stats["threads"] += thread_stats["threads"]
|
||||||
|
|
||||||
# Consistent filtering with migrate_messages
|
# Consistent filtering with migrate_messages
|
||||||
if msg.type not in [discord.MessageType.default, discord.MessageType.reply, discord.MessageType.thread_starter_message]:
|
if msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY, context.discord_reader.MESSAGE_TYPE_THREAD_STARTER]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
stats["messages"] += 1
|
stats["messages"] += 1
|
||||||
|
|
@ -115,11 +114,11 @@ async def migrate_messages(context: MigrationContext, source_channel_id: int, ta
|
||||||
# Skip system messages like "pinned a message", etc.
|
# Skip system messages like "pinned a message", etc.
|
||||||
# We treat thread_starter_message (type 21) as our thread marker.
|
# We treat thread_starter_message (type 21) as our thread marker.
|
||||||
content = "" # Initialize content
|
content = "" # Initialize content
|
||||||
if msg.type == discord.MessageType.thread_starter_message:
|
if msg.type == context.discord_reader.MESSAGE_TYPE_THREAD_STARTER:
|
||||||
content = f"> <<< THREAD: **{msg.channel.name}** >>>"
|
content = f"> <<< THREAD: **{msg.channel.name}** >>>"
|
||||||
# If it's a thread starter and we already processed the thread at the top,
|
# If it's a thread starter and we already processed the thread at the top,
|
||||||
# we might be double-posting. But we want it as a marker.
|
# we might be double-posting. But we want it as a marker.
|
||||||
elif msg.type not in [discord.MessageType.default, discord.MessageType.reply]:
|
elif msg.type not in [context.discord_reader.MESSAGE_TYPE_DEFAULT, context.discord_reader.MESSAGE_TYPE_REPLY]:
|
||||||
# If we are skipping the parent, we STILL need to check for a thread!
|
# If we are skipping the parent, we STILL need to check for a thread!
|
||||||
if hasattr(msg, 'thread') and msg.thread:
|
if hasattr(msg, 'thread') and msg.thread:
|
||||||
thread = msg.thread
|
thread = msg.thread
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import discord
|
|
||||||
from typing import Callable, Awaitable
|
from typing import Callable, Awaitable
|
||||||
|
|
||||||
from src.core.base import MigrationContext
|
from src.core.base import MigrationContext
|
||||||
|
|
@ -106,7 +105,7 @@ async def sync_permissions(context: MigrationContext, progress_callback: Callabl
|
||||||
# Merge logic: channel overwrites specific settings, but
|
# Merge logic: channel overwrites specific settings, but
|
||||||
# keeps inherited settings for any permission marked as 'None' in channel
|
# keeps inherited settings for any permission marked as 'None' in channel
|
||||||
cat_ow = final_overwrites[target.id]
|
cat_ow = final_overwrites[target.id]
|
||||||
merged = discord.PermissionOverwrite()
|
merged = context.discord_reader.create_permission_overwrite()
|
||||||
# Apply category settings
|
# Apply category settings
|
||||||
for name, value in cat_ow:
|
for name, value in cat_ow:
|
||||||
setattr(merged, name, value)
|
setattr(merged, name, value)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue