mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-10-25 02:23:04 +00:00 
			
		
		
		
	[commands] Minimise code duplication in channel converters
This commit is contained in:
		| @@ -26,7 +26,7 @@ from __future__ import annotations | |||||||
|  |  | ||||||
| import re | import re | ||||||
| import inspect | import inspect | ||||||
| from typing import TYPE_CHECKING, List, Protocol, TypeVar, Tuple, Union, runtime_checkable | from typing import Iterable, Optional, TYPE_CHECKING, List, Protocol, Type, TypeVar, Tuple, Union, runtime_checkable | ||||||
|  |  | ||||||
| import discord | import discord | ||||||
| from .errors import * | from .errors import * | ||||||
| @@ -72,6 +72,7 @@ def _get_from_guilds(bot, getter, argument): | |||||||
| _utils_get = discord.utils.get | _utils_get = discord.utils.get | ||||||
| T = TypeVar('T') | T = TypeVar('T') | ||||||
| T_co = TypeVar('T_co', covariant=True) | T_co = TypeVar('T_co', covariant=True) | ||||||
|  | CT = TypeVar('CT', bound=discord.abc.GuildChannel) | ||||||
|  |  | ||||||
|  |  | ||||||
| @runtime_checkable | @runtime_checkable | ||||||
| @@ -112,13 +113,13 @@ class Converter(Protocol[T_co]): | |||||||
|         raise NotImplementedError('Derived classes need to implement this.') |         raise NotImplementedError('Derived classes need to implement this.') | ||||||
|  |  | ||||||
|  |  | ||||||
| class IDConverter(Converter[T_co]): | _ID_REGEX = re.compile(r'([0-9]{15,20})$') | ||||||
|     def __init__(self): |  | ||||||
|         self._id_regex = re.compile(r'([0-9]{15,20})$') |  | ||||||
|         super().__init__() |  | ||||||
|  |  | ||||||
|     def _get_id_match(self, argument): |  | ||||||
|         return self._id_regex.match(argument) | class IDConverter(Converter[T_co]): | ||||||
|  |     @staticmethod | ||||||
|  |     def _get_id_match(argument): | ||||||
|  |         return _ID_REGEX.match(argument) | ||||||
|  |  | ||||||
|  |  | ||||||
| class MemberConverter(IDConverter[discord.Member]): | class MemberConverter(IDConverter[discord.Member]): | ||||||
| @@ -351,20 +352,24 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: |     async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: | ||||||
|  |         return self._resolve_channel(ctx, argument, ctx.guild.text_channels, discord.TextChannel) | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def _resolve_channel(ctx: Context, argument: str, iterable: Iterable[CT], type: Type[CT]) -> CT: | ||||||
|         bot = ctx.bot |         bot = ctx.bot | ||||||
|  |  | ||||||
|         match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) |         match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) | ||||||
|         result = None |         result = None | ||||||
|         guild = ctx.guild |         guild = ctx.guild | ||||||
|  |  | ||||||
|         if match is None: |         if match is None: | ||||||
|             # not a mention |             # not a mention | ||||||
|             if guild: |             if guild: | ||||||
|                 result = discord.utils.get(guild.text_channels, name=argument) |                 result: Optional[CT] = discord.utils.get(iterable, name=argument) | ||||||
|             else: |             else: | ||||||
|  |  | ||||||
|                 def check(c): |                 def check(c): | ||||||
|                     return isinstance(c, discord.TextChannel) and c.name == argument |                     return isinstance(c, type) and c.name == argument | ||||||
|  |  | ||||||
|                 result = discord.utils.find(check, bot.get_all_channels()) |                 result = discord.utils.find(check, bot.get_all_channels()) | ||||||
|         else: |         else: | ||||||
| @@ -374,7 +379,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): | |||||||
|             else: |             else: | ||||||
|                 result = _get_from_guilds(bot, 'get_channel', channel_id) |                 result = _get_from_guilds(bot, 'get_channel', channel_id) | ||||||
|  |  | ||||||
|         if not isinstance(result, discord.TextChannel): |         if not isinstance(result, type): | ||||||
|             raise ChannelNotFound(argument) |             raise ChannelNotFound(argument) | ||||||
|  |  | ||||||
|         return result |         return result | ||||||
| @@ -397,32 +402,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: |     async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: | ||||||
|         bot = ctx.bot |         return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.voice_channels, discord.VoiceChannel) | ||||||
|         match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) |  | ||||||
|         result = None |  | ||||||
|         guild = ctx.guild |  | ||||||
|  |  | ||||||
|         if match is None: |  | ||||||
|             # not a mention |  | ||||||
|             if guild: |  | ||||||
|                 result = discord.utils.get(guild.voice_channels, name=argument) |  | ||||||
|             else: |  | ||||||
|  |  | ||||||
|                 def check(c): |  | ||||||
|                     return isinstance(c, discord.VoiceChannel) and c.name == argument |  | ||||||
|  |  | ||||||
|                 result = discord.utils.find(check, bot.get_all_channels()) |  | ||||||
|         else: |  | ||||||
|             channel_id = int(match.group(1)) |  | ||||||
|             if guild: |  | ||||||
|                 result = guild.get_channel(channel_id) |  | ||||||
|             else: |  | ||||||
|                 result = _get_from_guilds(bot, 'get_channel', channel_id) |  | ||||||
|  |  | ||||||
|         if not isinstance(result, discord.VoiceChannel): |  | ||||||
|             raise ChannelNotFound(argument) |  | ||||||
|  |  | ||||||
|         return result |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class StageChannelConverter(IDConverter[discord.StageChannel]): | class StageChannelConverter(IDConverter[discord.StageChannel]): | ||||||
| @@ -441,32 +421,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: |     async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: | ||||||
|         bot = ctx.bot |         return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.stage_channels, discord.StageChannel) | ||||||
|         match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) |  | ||||||
|         result = None |  | ||||||
|         guild = ctx.guild |  | ||||||
|  |  | ||||||
|         if match is None: |  | ||||||
|             # not a mention |  | ||||||
|             if guild: |  | ||||||
|                 result = discord.utils.get(guild.stage_channels, name=argument) |  | ||||||
|             else: |  | ||||||
|  |  | ||||||
|                 def check(c): |  | ||||||
|                     return isinstance(c, discord.StageChannel) and c.name == argument |  | ||||||
|  |  | ||||||
|                 result = discord.utils.find(check, bot.get_all_channels()) |  | ||||||
|         else: |  | ||||||
|             channel_id = int(match.group(1)) |  | ||||||
|             if guild: |  | ||||||
|                 result = guild.get_channel(channel_id) |  | ||||||
|             else: |  | ||||||
|                 result = _get_from_guilds(bot, 'get_channel', channel_id) |  | ||||||
|  |  | ||||||
|         if not isinstance(result, discord.StageChannel): |  | ||||||
|             raise ChannelNotFound(argument) |  | ||||||
|  |  | ||||||
|         return result |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): | class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): | ||||||
| @@ -486,33 +441,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: |     async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: | ||||||
|         bot = ctx.bot |         return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.categories, discord.CategoryChannel) | ||||||
|  |  | ||||||
|         match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) |  | ||||||
|         result = None |  | ||||||
|         guild = ctx.guild |  | ||||||
|  |  | ||||||
|         if match is None: |  | ||||||
|             # not a mention |  | ||||||
|             if guild: |  | ||||||
|                 result = discord.utils.get(guild.categories, name=argument) |  | ||||||
|             else: |  | ||||||
|  |  | ||||||
|                 def check(c): |  | ||||||
|                     return isinstance(c, discord.CategoryChannel) and c.name == argument |  | ||||||
|  |  | ||||||
|                 result = discord.utils.find(check, bot.get_all_channels()) |  | ||||||
|         else: |  | ||||||
|             channel_id = int(match.group(1)) |  | ||||||
|             if guild: |  | ||||||
|                 result = guild.get_channel(channel_id) |  | ||||||
|             else: |  | ||||||
|                 result = _get_from_guilds(bot, 'get_channel', channel_id) |  | ||||||
|  |  | ||||||
|         if not isinstance(result, discord.CategoryChannel): |  | ||||||
|             raise ChannelNotFound(argument) |  | ||||||
|  |  | ||||||
|         return result |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class StoreChannelConverter(IDConverter[discord.StoreChannel]): | class StoreChannelConverter(IDConverter[discord.StoreChannel]): | ||||||
| @@ -531,32 +460,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel: |     async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel: | ||||||
|         bot = ctx.bot |         return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.channels, discord.StoreChannel) | ||||||
|         match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) |  | ||||||
|         result = None |  | ||||||
|         guild = ctx.guild |  | ||||||
|  |  | ||||||
|         if match is None: |  | ||||||
|             # not a mention |  | ||||||
|             if guild: |  | ||||||
|                 result = discord.utils.get(guild.channels, name=argument) |  | ||||||
|             else: |  | ||||||
|  |  | ||||||
|                 def check(c): |  | ||||||
|                     return isinstance(c, discord.StoreChannel) and c.name == argument |  | ||||||
|  |  | ||||||
|                 result = discord.utils.find(check, bot.get_all_channels()) |  | ||||||
|         else: |  | ||||||
|             channel_id = int(match.group(1)) |  | ||||||
|             if guild: |  | ||||||
|                 result = guild.get_channel(channel_id) |  | ||||||
|             else: |  | ||||||
|                 result = _get_from_guilds(bot, 'get_channel', channel_id) |  | ||||||
|  |  | ||||||
|         if not isinstance(result, discord.StoreChannel): |  | ||||||
|             raise ChannelNotFound(argument) |  | ||||||
|  |  | ||||||
|         return result |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ColourConverter(Converter[discord.Colour]): | class ColourConverter(Converter[discord.Colour]): | ||||||
| @@ -865,10 +769,12 @@ class clean_content(Converter[str]): | |||||||
|                 r = _find(_id) |                 r = _find(_id) | ||||||
|                 return '@' + r.name if r else '@deleted-role' |                 return '@' + r.name if r else '@deleted-role' | ||||||
|  |  | ||||||
|  |             # fmt: off | ||||||
|             transformations.update( |             transformations.update( | ||||||
|                 (f'<@&{role_id}>', resolve_role(role_id)) |                 (f'<@&{role_id}>', resolve_role(role_id)) | ||||||
|                 for role_id in message.raw_role_mentions |                 for role_id in message.raw_role_mentions | ||||||
|             )  # fmt: off |             ) | ||||||
|  |             # fmt: on | ||||||
|  |  | ||||||
|         def repl(obj): |         def repl(obj): | ||||||
|             return transformations.get(obj.group(0), '') |             return transformations.get(obj.group(0), '') | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user