[commands] Default converters now take in IDs to match against.
This commit is contained in:
		| @@ -68,12 +68,19 @@ class Converter: | |||||||
|     def convert(self): |     def convert(self): | ||||||
|         raise NotImplementedError('Derived classes need to implement this.') |         raise NotImplementedError('Derived classes need to implement this.') | ||||||
|  |  | ||||||
| class MemberConverter(Converter): | class IDConverter(Converter): | ||||||
|  |     def __init__(self, ctx, argument): | ||||||
|  |         super().__init__(ctx, argument) | ||||||
|  |         self._id_regex = re.compile(r'([0-9]{15,21})$') | ||||||
|  |  | ||||||
|  |     def _get_id_match(self): | ||||||
|  |         return self._id_regex.match(self.argument) | ||||||
|  |  | ||||||
|  | class MemberConverter(IDConverter): | ||||||
|     def convert(self): |     def convert(self): | ||||||
|         message = self.ctx.message |         message = self.ctx.message | ||||||
|         bot = self.ctx.bot |         bot = self.ctx.bot | ||||||
|  |         match = self._get_id_match() or re.match(r'<@!?([0-9]+)>$', self.argument) | ||||||
|         match = re.match(r'<@!?([0-9]+)>$', self.argument) |  | ||||||
|         server = message.server |         server = message.server | ||||||
|         result = None |         result = None | ||||||
|         if match is None: |         if match is None: | ||||||
| @@ -96,12 +103,12 @@ class MemberConverter(Converter): | |||||||
|  |  | ||||||
| UserConverter = MemberConverter | UserConverter = MemberConverter | ||||||
|  |  | ||||||
| class ChannelConverter(Converter): | class ChannelConverter(IDConverter): | ||||||
|     def convert(self): |     def convert(self): | ||||||
|         message = self.ctx.message |         message = self.ctx.message | ||||||
|         bot = self.ctx.bot |         bot = self.ctx.bot | ||||||
|  |  | ||||||
|         match = re.match(r'<#([0-9]+)>$', self.argument) |         match = self._get_id_match() or re.match(r'<#([0-9]+)>$', self.argument) | ||||||
|         result = None |         result = None | ||||||
|         server = message.server |         server = message.server | ||||||
|         if match is None: |         if match is None: | ||||||
| @@ -137,13 +144,13 @@ class ColourConverter(Converter): | |||||||
|                 raise BadArgument('Colour "{}" is invalid.'.format(arg)) |                 raise BadArgument('Colour "{}" is invalid.'.format(arg)) | ||||||
|             return method() |             return method() | ||||||
|  |  | ||||||
| class RoleConverter(Converter): | class RoleConverter(IDConverter): | ||||||
|     def convert(self): |     def convert(self): | ||||||
|         server = self.ctx.message.server |         server = self.ctx.message.server | ||||||
|         if not server: |         if not server: | ||||||
|             raise NoPrivateMessage() |             raise NoPrivateMessage() | ||||||
|  |  | ||||||
|         match = re.match(r'<@&([0-9]+)>$', self.argument) |         match = self._get_id_match() or re.match(r'<@&([0-9]+)>$', self.argument) | ||||||
|         params = dict(id=match.group(1)) if match else dict(name=self.argument) |         params = dict(id=match.group(1)) if match else dict(name=self.argument) | ||||||
|         result = discord.utils.get(server.roles, **params) |         result = discord.utils.get(server.roles, **params) | ||||||
|         if result is None: |         if result is None: | ||||||
| @@ -163,13 +170,13 @@ class InviteConverter(Converter): | |||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             raise BadArgument('Invite is invalid or expired') from e |             raise BadArgument('Invite is invalid or expired') from e | ||||||
|  |  | ||||||
| class EmojiConverter(Converter): | class EmojiConverter(IDConverter): | ||||||
|     @asyncio.coroutine |     @asyncio.coroutine | ||||||
|     def convert(self): |     def convert(self): | ||||||
|         message = self.ctx.message |         message = self.ctx.message | ||||||
|         bot = self.ctx.bot |         bot = self.ctx.bot | ||||||
|  |  | ||||||
|         match = re.match(r'<:([a-zA-Z0-9]+):([0-9]+)>$', self.argument) |         match = self._get_id_match() or re.match(r'<:[a-zA-Z0-9]+:([0-9]+)>$', self.argument) | ||||||
|         result = None |         result = None | ||||||
|         server = message.server |         server = message.server | ||||||
|         if match is None: |         if match is None: | ||||||
| @@ -180,7 +187,7 @@ class EmojiConverter(Converter): | |||||||
|             if result is None: |             if result is None: | ||||||
|                 result = discord.utils.get(bot.get_all_emojis(), name=self.argument) |                 result = discord.utils.get(bot.get_all_emojis(), name=self.argument) | ||||||
|         else: |         else: | ||||||
|             emoji_id = match.group(2) |             emoji_id = match.group(1) | ||||||
|  |  | ||||||
|             # Try to look up emoji by id. |             # Try to look up emoji by id. | ||||||
|             if server: |             if server: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user