[commands] Minimise code duplication in channel converters
This commit is contained in:
parent
ec71eb2fcb
commit
353737239a
@ -26,7 +26,7 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, List, Protocol, TypeVar, Tuple, Union, runtime_checkable
|
||||
from typing import Iterable, Optional, TYPE_CHECKING, List, Protocol, Type, TypeVar, Tuple, Union, runtime_checkable
|
||||
|
||||
import discord
|
||||
from .errors import *
|
||||
@ -72,6 +72,7 @@ def _get_from_guilds(bot, getter, argument):
|
||||
_utils_get = discord.utils.get
|
||||
T = TypeVar('T')
|
||||
T_co = TypeVar('T_co', covariant=True)
|
||||
CT = TypeVar('CT', bound=discord.abc.GuildChannel)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@ -112,13 +113,13 @@ class Converter(Protocol[T_co]):
|
||||
raise NotImplementedError('Derived classes need to implement this.')
|
||||
|
||||
|
||||
class IDConverter(Converter[T_co]):
|
||||
def __init__(self):
|
||||
self._id_regex = re.compile(r'([0-9]{15,20})$')
|
||||
super().__init__()
|
||||
_ID_REGEX = re.compile(r'([0-9]{15,20})$')
|
||||
|
||||
def _get_id_match(self, argument):
|
||||
return self._id_regex.match(argument)
|
||||
|
||||
class IDConverter(Converter[T_co]):
|
||||
@staticmethod
|
||||
def _get_id_match(argument):
|
||||
return _ID_REGEX.match(argument)
|
||||
|
||||
|
||||
class MemberConverter(IDConverter[discord.Member]):
|
||||
@ -351,20 +352,24 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel:
|
||||
return self._resolve_channel(ctx, argument, ctx.guild.text_channels, discord.TextChannel)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_channel(ctx: Context, argument: str, iterable: Iterable[CT], type: Type[CT]) -> CT:
|
||||
bot = ctx.bot
|
||||
|
||||
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
|
||||
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
|
||||
result = None
|
||||
guild = ctx.guild
|
||||
|
||||
if match is None:
|
||||
# not a mention
|
||||
if guild:
|
||||
result = discord.utils.get(guild.text_channels, name=argument)
|
||||
result: Optional[CT] = discord.utils.get(iterable, name=argument)
|
||||
else:
|
||||
|
||||
def check(c):
|
||||
return isinstance(c, discord.TextChannel) and c.name == argument
|
||||
return isinstance(c, type) and c.name == argument
|
||||
|
||||
result = discord.utils.find(check, bot.get_all_channels())
|
||||
else:
|
||||
@ -374,7 +379,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
|
||||
else:
|
||||
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
||||
|
||||
if not isinstance(result, discord.TextChannel):
|
||||
if not isinstance(result, type):
|
||||
raise ChannelNotFound(argument)
|
||||
|
||||
return result
|
||||
@ -397,32 +402,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel:
|
||||
bot = ctx.bot
|
||||
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
|
||||
result = None
|
||||
guild = ctx.guild
|
||||
|
||||
if match is None:
|
||||
# not a mention
|
||||
if guild:
|
||||
result = discord.utils.get(guild.voice_channels, name=argument)
|
||||
else:
|
||||
|
||||
def check(c):
|
||||
return isinstance(c, discord.VoiceChannel) and c.name == argument
|
||||
|
||||
result = discord.utils.find(check, bot.get_all_channels())
|
||||
else:
|
||||
channel_id = int(match.group(1))
|
||||
if guild:
|
||||
result = guild.get_channel(channel_id)
|
||||
else:
|
||||
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
||||
|
||||
if not isinstance(result, discord.VoiceChannel):
|
||||
raise ChannelNotFound(argument)
|
||||
|
||||
return result
|
||||
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.voice_channels, discord.VoiceChannel)
|
||||
|
||||
|
||||
class StageChannelConverter(IDConverter[discord.StageChannel]):
|
||||
@ -441,32 +421,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel:
|
||||
bot = ctx.bot
|
||||
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
|
||||
result = None
|
||||
guild = ctx.guild
|
||||
|
||||
if match is None:
|
||||
# not a mention
|
||||
if guild:
|
||||
result = discord.utils.get(guild.stage_channels, name=argument)
|
||||
else:
|
||||
|
||||
def check(c):
|
||||
return isinstance(c, discord.StageChannel) and c.name == argument
|
||||
|
||||
result = discord.utils.find(check, bot.get_all_channels())
|
||||
else:
|
||||
channel_id = int(match.group(1))
|
||||
if guild:
|
||||
result = guild.get_channel(channel_id)
|
||||
else:
|
||||
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
||||
|
||||
if not isinstance(result, discord.StageChannel):
|
||||
raise ChannelNotFound(argument)
|
||||
|
||||
return result
|
||||
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.stage_channels, discord.StageChannel)
|
||||
|
||||
|
||||
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
|
||||
@ -486,33 +441,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel:
|
||||
bot = ctx.bot
|
||||
|
||||
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
|
||||
result = None
|
||||
guild = ctx.guild
|
||||
|
||||
if match is None:
|
||||
# not a mention
|
||||
if guild:
|
||||
result = discord.utils.get(guild.categories, name=argument)
|
||||
else:
|
||||
|
||||
def check(c):
|
||||
return isinstance(c, discord.CategoryChannel) and c.name == argument
|
||||
|
||||
result = discord.utils.find(check, bot.get_all_channels())
|
||||
else:
|
||||
channel_id = int(match.group(1))
|
||||
if guild:
|
||||
result = guild.get_channel(channel_id)
|
||||
else:
|
||||
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
||||
|
||||
if not isinstance(result, discord.CategoryChannel):
|
||||
raise ChannelNotFound(argument)
|
||||
|
||||
return result
|
||||
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.categories, discord.CategoryChannel)
|
||||
|
||||
|
||||
class StoreChannelConverter(IDConverter[discord.StoreChannel]):
|
||||
@ -531,32 +460,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel:
|
||||
bot = ctx.bot
|
||||
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
|
||||
result = None
|
||||
guild = ctx.guild
|
||||
|
||||
if match is None:
|
||||
# not a mention
|
||||
if guild:
|
||||
result = discord.utils.get(guild.channels, name=argument)
|
||||
else:
|
||||
|
||||
def check(c):
|
||||
return isinstance(c, discord.StoreChannel) and c.name == argument
|
||||
|
||||
result = discord.utils.find(check, bot.get_all_channels())
|
||||
else:
|
||||
channel_id = int(match.group(1))
|
||||
if guild:
|
||||
result = guild.get_channel(channel_id)
|
||||
else:
|
||||
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
||||
|
||||
if not isinstance(result, discord.StoreChannel):
|
||||
raise ChannelNotFound(argument)
|
||||
|
||||
return result
|
||||
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.channels, discord.StoreChannel)
|
||||
|
||||
|
||||
class ColourConverter(Converter[discord.Colour]):
|
||||
@ -865,10 +769,12 @@ class clean_content(Converter[str]):
|
||||
r = _find(_id)
|
||||
return '@' + r.name if r else '@deleted-role'
|
||||
|
||||
# fmt: off
|
||||
transformations.update(
|
||||
(f'<@&{role_id}>', resolve_role(role_id))
|
||||
for role_id in message.raw_role_mentions
|
||||
) # fmt: off
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def repl(obj):
|
||||
return transformations.get(obj.group(0), '')
|
||||
|
Loading…
x
Reference in New Issue
Block a user