[commands] Minimise code duplication in channel converters
This commit is contained in:
parent
ec71eb2fcb
commit
353737239a
@ -26,7 +26,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
import inspect
|
import inspect
|
||||||
from typing import TYPE_CHECKING, List, Protocol, TypeVar, Tuple, Union, runtime_checkable
|
from typing import Iterable, Optional, TYPE_CHECKING, List, Protocol, Type, TypeVar, Tuple, Union, runtime_checkable
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from .errors import *
|
from .errors import *
|
||||||
@ -72,6 +72,7 @@ def _get_from_guilds(bot, getter, argument):
|
|||||||
_utils_get = discord.utils.get
|
_utils_get = discord.utils.get
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
T_co = TypeVar('T_co', covariant=True)
|
T_co = TypeVar('T_co', covariant=True)
|
||||||
|
CT = TypeVar('CT', bound=discord.abc.GuildChannel)
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
@ -112,13 +113,13 @@ class Converter(Protocol[T_co]):
|
|||||||
raise NotImplementedError('Derived classes need to implement this.')
|
raise NotImplementedError('Derived classes need to implement this.')
|
||||||
|
|
||||||
|
|
||||||
class IDConverter(Converter[T_co]):
|
_ID_REGEX = re.compile(r'([0-9]{15,20})$')
|
||||||
def __init__(self):
|
|
||||||
self._id_regex = re.compile(r'([0-9]{15,20})$')
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def _get_id_match(self, argument):
|
|
||||||
return self._id_regex.match(argument)
|
class IDConverter(Converter[T_co]):
|
||||||
|
@staticmethod
|
||||||
|
def _get_id_match(argument):
|
||||||
|
return _ID_REGEX.match(argument)
|
||||||
|
|
||||||
|
|
||||||
class MemberConverter(IDConverter[discord.Member]):
|
class MemberConverter(IDConverter[discord.Member]):
|
||||||
@ -351,20 +352,24 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel:
|
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel:
|
||||||
|
return self._resolve_channel(ctx, argument, ctx.guild.text_channels, discord.TextChannel)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_channel(ctx: Context, argument: str, iterable: Iterable[CT], type: Type[CT]) -> CT:
|
||||||
bot = ctx.bot
|
bot = ctx.bot
|
||||||
|
|
||||||
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
|
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
|
||||||
result = None
|
result = None
|
||||||
guild = ctx.guild
|
guild = ctx.guild
|
||||||
|
|
||||||
if match is None:
|
if match is None:
|
||||||
# not a mention
|
# not a mention
|
||||||
if guild:
|
if guild:
|
||||||
result = discord.utils.get(guild.text_channels, name=argument)
|
result: Optional[CT] = discord.utils.get(iterable, name=argument)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def check(c):
|
def check(c):
|
||||||
return isinstance(c, discord.TextChannel) and c.name == argument
|
return isinstance(c, type) and c.name == argument
|
||||||
|
|
||||||
result = discord.utils.find(check, bot.get_all_channels())
|
result = discord.utils.find(check, bot.get_all_channels())
|
||||||
else:
|
else:
|
||||||
@ -374,7 +379,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
|
|||||||
else:
|
else:
|
||||||
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
||||||
|
|
||||||
if not isinstance(result, discord.TextChannel):
|
if not isinstance(result, type):
|
||||||
raise ChannelNotFound(argument)
|
raise ChannelNotFound(argument)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@ -397,32 +402,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel:
|
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel:
|
||||||
bot = ctx.bot
|
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.voice_channels, discord.VoiceChannel)
|
||||||
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
|
|
||||||
result = None
|
|
||||||
guild = ctx.guild
|
|
||||||
|
|
||||||
if match is None:
|
|
||||||
# not a mention
|
|
||||||
if guild:
|
|
||||||
result = discord.utils.get(guild.voice_channels, name=argument)
|
|
||||||
else:
|
|
||||||
|
|
||||||
def check(c):
|
|
||||||
return isinstance(c, discord.VoiceChannel) and c.name == argument
|
|
||||||
|
|
||||||
result = discord.utils.find(check, bot.get_all_channels())
|
|
||||||
else:
|
|
||||||
channel_id = int(match.group(1))
|
|
||||||
if guild:
|
|
||||||
result = guild.get_channel(channel_id)
|
|
||||||
else:
|
|
||||||
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
|
||||||
|
|
||||||
if not isinstance(result, discord.VoiceChannel):
|
|
||||||
raise ChannelNotFound(argument)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class StageChannelConverter(IDConverter[discord.StageChannel]):
|
class StageChannelConverter(IDConverter[discord.StageChannel]):
|
||||||
@ -441,32 +421,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel:
|
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel:
|
||||||
bot = ctx.bot
|
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.stage_channels, discord.StageChannel)
|
||||||
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
|
|
||||||
result = None
|
|
||||||
guild = ctx.guild
|
|
||||||
|
|
||||||
if match is None:
|
|
||||||
# not a mention
|
|
||||||
if guild:
|
|
||||||
result = discord.utils.get(guild.stage_channels, name=argument)
|
|
||||||
else:
|
|
||||||
|
|
||||||
def check(c):
|
|
||||||
return isinstance(c, discord.StageChannel) and c.name == argument
|
|
||||||
|
|
||||||
result = discord.utils.find(check, bot.get_all_channels())
|
|
||||||
else:
|
|
||||||
channel_id = int(match.group(1))
|
|
||||||
if guild:
|
|
||||||
result = guild.get_channel(channel_id)
|
|
||||||
else:
|
|
||||||
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
|
||||||
|
|
||||||
if not isinstance(result, discord.StageChannel):
|
|
||||||
raise ChannelNotFound(argument)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
|
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
|
||||||
@ -486,33 +441,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel:
|
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel:
|
||||||
bot = ctx.bot
|
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.categories, discord.CategoryChannel)
|
||||||
|
|
||||||
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
|
|
||||||
result = None
|
|
||||||
guild = ctx.guild
|
|
||||||
|
|
||||||
if match is None:
|
|
||||||
# not a mention
|
|
||||||
if guild:
|
|
||||||
result = discord.utils.get(guild.categories, name=argument)
|
|
||||||
else:
|
|
||||||
|
|
||||||
def check(c):
|
|
||||||
return isinstance(c, discord.CategoryChannel) and c.name == argument
|
|
||||||
|
|
||||||
result = discord.utils.find(check, bot.get_all_channels())
|
|
||||||
else:
|
|
||||||
channel_id = int(match.group(1))
|
|
||||||
if guild:
|
|
||||||
result = guild.get_channel(channel_id)
|
|
||||||
else:
|
|
||||||
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
|
||||||
|
|
||||||
if not isinstance(result, discord.CategoryChannel):
|
|
||||||
raise ChannelNotFound(argument)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class StoreChannelConverter(IDConverter[discord.StoreChannel]):
|
class StoreChannelConverter(IDConverter[discord.StoreChannel]):
|
||||||
@ -531,32 +460,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel:
|
async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel:
|
||||||
bot = ctx.bot
|
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.channels, discord.StoreChannel)
|
||||||
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
|
|
||||||
result = None
|
|
||||||
guild = ctx.guild
|
|
||||||
|
|
||||||
if match is None:
|
|
||||||
# not a mention
|
|
||||||
if guild:
|
|
||||||
result = discord.utils.get(guild.channels, name=argument)
|
|
||||||
else:
|
|
||||||
|
|
||||||
def check(c):
|
|
||||||
return isinstance(c, discord.StoreChannel) and c.name == argument
|
|
||||||
|
|
||||||
result = discord.utils.find(check, bot.get_all_channels())
|
|
||||||
else:
|
|
||||||
channel_id = int(match.group(1))
|
|
||||||
if guild:
|
|
||||||
result = guild.get_channel(channel_id)
|
|
||||||
else:
|
|
||||||
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
|
||||||
|
|
||||||
if not isinstance(result, discord.StoreChannel):
|
|
||||||
raise ChannelNotFound(argument)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class ColourConverter(Converter[discord.Colour]):
|
class ColourConverter(Converter[discord.Colour]):
|
||||||
@ -865,10 +769,12 @@ class clean_content(Converter[str]):
|
|||||||
r = _find(_id)
|
r = _find(_id)
|
||||||
return '@' + r.name if r else '@deleted-role'
|
return '@' + r.name if r else '@deleted-role'
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
transformations.update(
|
transformations.update(
|
||||||
(f'<@&{role_id}>', resolve_role(role_id))
|
(f'<@&{role_id}>', resolve_role(role_id))
|
||||||
for role_id in message.raw_role_mentions
|
for role_id in message.raw_role_mentions
|
||||||
) # fmt: off
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
def repl(obj):
|
def repl(obj):
|
||||||
return transformations.get(obj.group(0), '')
|
return transformations.get(obj.group(0), '')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user