[commands] Converter.convert is always a coroutine.
Along with this change comes with the removal of Converter.prepare and adding two arguments to Converter.convert, the context and the argument. I suppose an added benefit is that you don't have to do attribute access since it's a local variable.
This commit is contained in:
parent
8ef984746a
commit
d7478425ca
@ -53,46 +53,52 @@ class Converter:
|
|||||||
special cased ``discord`` classes.
|
special cased ``discord`` classes.
|
||||||
|
|
||||||
Classes that derive from this should override the :meth:`convert` method
|
Classes that derive from this should override the :meth:`convert` method
|
||||||
to do its conversion logic. This method could be a coroutine or a regular
|
to do its conversion logic. This method must be a coroutine.
|
||||||
function.
|
"""
|
||||||
|
|
||||||
Before the convert method is called, :meth:`prepare` is called. This
|
@asyncio.coroutine
|
||||||
method must set the attributes below if overwritten.
|
def convert(self, ctx, argument):
|
||||||
|
"""|coro|
|
||||||
|
|
||||||
Attributes
|
The method to override to do conversion logic.
|
||||||
|
|
||||||
|
This can either be a coroutine or a regular function.
|
||||||
|
|
||||||
|
If an error is found while converting, it is recommended to
|
||||||
|
raise a :class:`CommandError` derived exception as it will
|
||||||
|
properly propagate to the error handlers.
|
||||||
|
|
||||||
|
Parameters
|
||||||
-----------
|
-----------
|
||||||
ctx: :class:`Context`
|
ctx: :class:`Context`
|
||||||
The invocation context that the argument is being used in.
|
The invocation context that the argument is being used in.
|
||||||
argument: str
|
argument: str
|
||||||
The argument that is being converted.
|
The argument that is being converted.
|
||||||
"""
|
"""
|
||||||
def prepare(self, ctx, argument):
|
|
||||||
self.ctx = ctx
|
|
||||||
self.argument = argument
|
|
||||||
|
|
||||||
def convert(self):
|
|
||||||
raise NotImplementedError('Derived classes need to implement this.')
|
raise NotImplementedError('Derived classes need to implement this.')
|
||||||
|
|
||||||
class IDConverter(Converter):
|
class IDConverter(Converter):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._id_regex = re.compile(r'([0-9]{15,21})$')
|
self._id_regex = re.compile(r'([0-9]{15,21})$')
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def _get_id_match(self):
|
def _get_id_match(self, argument):
|
||||||
return self._id_regex.match(self.argument)
|
return self._id_regex.match(argument)
|
||||||
|
|
||||||
class MemberConverter(IDConverter):
|
class MemberConverter(IDConverter):
|
||||||
def convert(self):
|
@asyncio.coroutine
|
||||||
message = self.ctx.message
|
def convert(self, ctx, argument):
|
||||||
bot = self.ctx.bot
|
message = ctx.message
|
||||||
match = self._get_id_match() or re.match(r'<@!?([0-9]+)>$', self.argument)
|
bot = ctx.bot
|
||||||
|
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument)
|
||||||
guild = message.guild
|
guild = message.guild
|
||||||
result = None
|
result = None
|
||||||
if match is None:
|
if match is None:
|
||||||
# not a mention...
|
# not a mention...
|
||||||
if guild:
|
if guild:
|
||||||
result = guild.get_member_named(self.argument)
|
result = guild.get_member_named(argument)
|
||||||
else:
|
else:
|
||||||
result = _get_from_guilds(bot, 'get_member_named', self.argument)
|
result = _get_from_guilds(bot, 'get_member_named', argument)
|
||||||
else:
|
else:
|
||||||
user_id = int(match.group(1))
|
user_id = int(match.group(1))
|
||||||
if guild:
|
if guild:
|
||||||
@ -101,21 +107,22 @@ class MemberConverter(IDConverter):
|
|||||||
result = _get_from_guilds(bot, 'get_member', user_id)
|
result = _get_from_guilds(bot, 'get_member', user_id)
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
raise BadArgument('Member "{}" not found'.format(self.argument))
|
raise BadArgument('Member "{}" not found'.format(argument))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
class UserConverter(IDConverter):
|
class UserConverter(IDConverter):
|
||||||
def convert(self):
|
@asyncio.coroutine
|
||||||
match = self._get_id_match() or re.match(r'<@!?([0-9]+)>$', self.argument)
|
def convert(self, ctx, argument):
|
||||||
|
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument)
|
||||||
result = None
|
result = None
|
||||||
state = self.ctx._state
|
state = ctx._state
|
||||||
|
|
||||||
if match is not None:
|
if match is not None:
|
||||||
user_id = int(match.group(1))
|
user_id = int(match.group(1))
|
||||||
result = self.ctx.bot.get_user(user_id)
|
result = ctx.bot.get_user(user_id)
|
||||||
else:
|
else:
|
||||||
arg = self.argument
|
arg = argument
|
||||||
# check for discriminator if it exists
|
# check for discriminator if it exists
|
||||||
if len(arg) > 5 and arg[-5] == '#':
|
if len(arg) > 5 and arg[-5] == '#':
|
||||||
discrim = arg[-4:]
|
discrim = arg[-4:]
|
||||||
@ -129,25 +136,26 @@ class UserConverter(IDConverter):
|
|||||||
result = discord.utils.find(predicate, state._users.values())
|
result = discord.utils.find(predicate, state._users.values())
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
raise BadArgument('User "{}" not found'.format(self.argument))
|
raise BadArgument('User "{}" not found'.format(argument))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
class TextChannelConverter(IDConverter):
|
class TextChannelConverter(IDConverter):
|
||||||
def convert(self):
|
@asyncio.coroutine
|
||||||
bot = self.ctx.bot
|
def convert(self, ctx, argument):
|
||||||
|
bot = ctx.bot
|
||||||
|
|
||||||
match = self._get_id_match() or re.match(r'<#([0-9]+)>$', self.argument)
|
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
|
||||||
result = None
|
result = None
|
||||||
guild = self.ctx.guild
|
guild = ctx.guild
|
||||||
|
|
||||||
if match is None:
|
if match is None:
|
||||||
# not a mention
|
# not a mention
|
||||||
if guild:
|
if guild:
|
||||||
result = discord.utils.get(guild.text_channels, name=self.argument)
|
result = discord.utils.get(guild.text_channels, name=argument)
|
||||||
else:
|
else:
|
||||||
def check(c):
|
def check(c):
|
||||||
return isinstance(c, discord.TextChannel) and c.name == self.argument
|
return isinstance(c, discord.TextChannel) and c.name == argument
|
||||||
result = discord.utils.find(check, bot.get_all_channels())
|
result = discord.utils.find(check, bot.get_all_channels())
|
||||||
else:
|
else:
|
||||||
channel_id = int(match.group(1))
|
channel_id = int(match.group(1))
|
||||||
@ -157,25 +165,25 @@ class TextChannelConverter(IDConverter):
|
|||||||
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
||||||
|
|
||||||
if not isinstance(result, discord.TextChannel):
|
if not isinstance(result, discord.TextChannel):
|
||||||
raise BadArgument('Channel "{}" not found.'.format(self.argument))
|
raise BadArgument('Channel "{}" not found.'.format(argument))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
class VoiceChannelConverter(IDConverter):
|
class VoiceChannelConverter(IDConverter):
|
||||||
def convert(self):
|
@asyncio.coroutine
|
||||||
bot = self.ctx.bot
|
def convert(self, ctx, argument):
|
||||||
|
bot = ctx.bot
|
||||||
match = self._get_id_match() or re.match(r'<#([0-9]+)>$', self.argument)
|
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
|
||||||
result = None
|
result = None
|
||||||
guild = self.ctx.guild
|
guild = ctx.guild
|
||||||
|
|
||||||
if match is None:
|
if match is None:
|
||||||
# not a mention
|
# not a mention
|
||||||
if guild:
|
if guild:
|
||||||
result = discord.utils.get(guild.voice_channels, name=self.argument)
|
result = discord.utils.get(guild.voice_channels, name=argument)
|
||||||
else:
|
else:
|
||||||
def check(c):
|
def check(c):
|
||||||
return isinstance(c, discord.VoiceChannel) and c.name == self.argument
|
return isinstance(c, discord.VoiceChannel) and c.name == argument
|
||||||
result = discord.utils.find(check, bot.get_all_channels())
|
result = discord.utils.find(check, bot.get_all_channels())
|
||||||
else:
|
else:
|
||||||
channel_id = int(match.group(1))
|
channel_id = int(match.group(1))
|
||||||
@ -185,13 +193,14 @@ class VoiceChannelConverter(IDConverter):
|
|||||||
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
||||||
|
|
||||||
if not isinstance(result, discord.VoiceChannel):
|
if not isinstance(result, discord.VoiceChannel):
|
||||||
raise BadArgument('Channel "{}" not found.'.format(self.argument))
|
raise BadArgument('Channel "{}" not found.'.format(argument))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
class ColourConverter(Converter):
|
class ColourConverter(Converter):
|
||||||
def convert(self):
|
@asyncio.coroutine
|
||||||
arg = self.argument.replace('0x', '').lower()
|
def convert(self, ctx, argument):
|
||||||
|
arg = argument.replace('0x', '').lower()
|
||||||
|
|
||||||
if arg[0] == '#':
|
if arg[0] == '#':
|
||||||
arg = arg[1:]
|
arg = arg[1:]
|
||||||
@ -205,47 +214,48 @@ class ColourConverter(Converter):
|
|||||||
return method()
|
return method()
|
||||||
|
|
||||||
class RoleConverter(IDConverter):
|
class RoleConverter(IDConverter):
|
||||||
def convert(self):
|
@asyncio.coroutine
|
||||||
guild = self.ctx.message.guild
|
def convert(self, ctx, argument):
|
||||||
|
guild = ctx.message.guild
|
||||||
if not guild:
|
if not guild:
|
||||||
raise NoPrivateMessage()
|
raise NoPrivateMessage()
|
||||||
|
|
||||||
match = self._get_id_match() or re.match(r'<@&([0-9]+)>$', self.argument)
|
match = self._get_id_match(argument) or re.match(r'<@&([0-9]+)>$', argument)
|
||||||
params = dict(id=int(match.group(1))) if match else dict(name=self.argument)
|
params = dict(id=int(match.group(1))) if match else dict(name=argument)
|
||||||
result = discord.utils.get(guild.roles, **params)
|
result = discord.utils.get(guild.roles, **params)
|
||||||
if result is None:
|
if result is None:
|
||||||
raise BadArgument('Role "{}" not found.'.format(self.argument))
|
raise BadArgument('Role "{}" not found.'.format(argument))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
class GameConverter(Converter):
|
class GameConverter(Converter):
|
||||||
def convert(self):
|
@asyncio.coroutine
|
||||||
return discord.Game(name=self.argument)
|
def convert(self, ctx, argument):
|
||||||
|
return discord.Game(name=argument)
|
||||||
|
|
||||||
class InviteConverter(Converter):
|
class InviteConverter(Converter):
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def convert(self):
|
def convert(self, ctx, argument):
|
||||||
try:
|
try:
|
||||||
invite = yield from self.ctx.bot.get_invite(self.argument)
|
invite = yield from ctx.bot.get_invite(argument)
|
||||||
return invite
|
return invite
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise BadArgument('Invite is invalid or expired') from e
|
raise BadArgument('Invite is invalid or expired') from e
|
||||||
|
|
||||||
class EmojiConverter(IDConverter):
|
class EmojiConverter(IDConverter):
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def convert(self):
|
def convert(self, ctx, argument):
|
||||||
message = self.ctx.message
|
match = self._get_id_match(argument) or re.match(r'<:[a-zA-Z0-9]+:([0-9]+)>$', argument)
|
||||||
bot = self.ctx.bot
|
|
||||||
|
|
||||||
match = self._get_id_match() or re.match(r'<:[a-zA-Z0-9]+:([0-9]+)>$', self.argument)
|
|
||||||
result = None
|
result = None
|
||||||
guild = message.guild
|
bot = ctx.bot
|
||||||
|
guild = ctx.guild
|
||||||
|
|
||||||
if match is None:
|
if match is None:
|
||||||
# Try to get the emoji by name. Try local guild first.
|
# Try to get the emoji by name. Try local guild first.
|
||||||
if guild:
|
if guild:
|
||||||
result = discord.utils.get(guild.emojis, name=self.argument)
|
result = discord.utils.get(guild.emojis, name=argument)
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
result = discord.utils.get(bot.emojis, name=self.argument)
|
result = discord.utils.get(bot.emojis, name=argument)
|
||||||
else:
|
else:
|
||||||
emoji_id = int(match.group(1))
|
emoji_id = int(match.group(1))
|
||||||
|
|
||||||
@ -257,7 +267,7 @@ class EmojiConverter(IDConverter):
|
|||||||
result = discord.utils.get(bot.emojis, id=emoji_id)
|
result = discord.utils.get(bot.emojis, id=emoji_id)
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
raise BadArgument('Emoji "{}" not found.'.format(self.argument))
|
raise BadArgument('Emoji "{}" not found.'.format(argument))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -266,8 +276,9 @@ class clean_content(Converter):
|
|||||||
self.fix_channel_mentions = fix_channel_mentions
|
self.fix_channel_mentions = fix_channel_mentions
|
||||||
self.use_nicknames = use_nicknames
|
self.use_nicknames = use_nicknames
|
||||||
|
|
||||||
def convert(self):
|
@asyncio.coroutine
|
||||||
message = self.ctx.message
|
def convert(self, ctx, argument):
|
||||||
|
message = ctx.message
|
||||||
transformations = {}
|
transformations = {}
|
||||||
|
|
||||||
if self.fix_channel_mentions:
|
if self.fix_channel_mentions:
|
||||||
@ -306,7 +317,7 @@ class clean_content(Converter):
|
|||||||
return transformations.get(obj.group(0), '')
|
return transformations.get(obj.group(0), '')
|
||||||
|
|
||||||
pattern = re.compile('|'.join(transformations.keys()))
|
pattern = re.compile('|'.join(transformations.keys()))
|
||||||
result = pattern.sub(repl, self.argument)
|
result = pattern.sub(repl, argument)
|
||||||
|
|
||||||
transformations = {
|
transformations = {
|
||||||
'@everyone': '@\u200beveryone',
|
'@everyone': '@\u200beveryone',
|
||||||
|
@ -202,13 +202,10 @@ class Command:
|
|||||||
|
|
||||||
if inspect.isclass(converter) and issubclass(converter, converters.Converter):
|
if inspect.isclass(converter) and issubclass(converter, converters.Converter):
|
||||||
instance = converter()
|
instance = converter()
|
||||||
instance.prepare(ctx, argument)
|
ret = yield from instance.convert(ctx, argument)
|
||||||
ret = yield from discord.utils.maybe_coroutine(instance.convert)
|
|
||||||
return ret
|
return ret
|
||||||
|
elif isinstance(converter, converters.Converter):
|
||||||
if isinstance(converter, converters.Converter):
|
ret = yield from converter.convert(ctx, argument)
|
||||||
converter.prepare(ctx, argument)
|
|
||||||
ret = yield from discord.utils.maybe_coroutine(converter.convert)
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
return converter(argument)
|
return converter(argument)
|
||||||
@ -220,9 +217,6 @@ class Command:
|
|||||||
converter = str if param.default is None else type(param.default)
|
converter = str if param.default is None else type(param.default)
|
||||||
else:
|
else:
|
||||||
converter = str
|
converter = str
|
||||||
elif not inspect.isclass(type(converter)):
|
|
||||||
raise discord.ClientException('Function annotation must be a type')
|
|
||||||
|
|
||||||
return converter
|
return converter
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
Loading…
x
Reference in New Issue
Block a user