Implement async checks. Fixes #380.

This commit is contained in:
Rapptz
2017-02-12 12:13:23 -05:00
parent 2abdbc70c2
commit 47ef657fbd
5 changed files with 83 additions and 40 deletions

View File

@@ -342,6 +342,7 @@ class Command:
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
@asyncio.coroutine
def _verify_checks(self, ctx):
if not self.enabled:
raise DisabledCommand('{0.name} command is disabled'.format(self))
@@ -349,10 +350,7 @@ class Command:
if self.no_pm and not isinstance(ctx.channel, discord.abc.GuildChannel):
raise NoPrivateMessage('This command cannot be used in private messages.')
if not ctx.bot.can_run(ctx):
raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self))
if not self.can_run(ctx):
if not (yield from self.can_run(ctx)):
raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self))
@asyncio.coroutine
@@ -402,7 +400,7 @@ class Command:
@asyncio.coroutine
def prepare(self, ctx):
ctx.command = self
self._verify_checks(ctx)
yield from self._verify_checks(ctx)
yield from self._parse_arguments(ctx)
if self._buckets.valid:
@@ -533,14 +531,17 @@ class Command:
return self.help.split('\n', 1)[0]
return ''
def can_run(self, context):
"""Checks if the command can be executed by checking all the predicates
@asyncio.coroutine
def can_run(self, ctx):
"""|coro|
Checks if the command can be executed by checking all the predicates
inside the :attr:`checks` attribute.
Parameters
-----------
context : :class:`Context`
The context of the command currently being invoked.
ctx: :class:`Context`
The ctx of the command currently being invoked.
Returns
--------
@@ -548,6 +549,9 @@ class Command:
A boolean indicating if the command can be invoked.
"""
if not (yield from ctx.bot.can_run(ctx)):
raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self))
cog = self.instance
if cog is not None:
try:
@@ -555,14 +559,16 @@ class Command:
except AttributeError:
pass
else:
if not local_check(context):
ret = yield from discord.utils.maybe_coroutine(local_check, ctx)
if not ret:
return False
predicates = self.checks
if not predicates:
# since we have no checks, then we just return True.
return True
return all(predicate(context) for predicate in predicates)
return (yield from discord.utils.async_all(predicate(ctx) for predicate in predicates))
class GroupMixin:
"""A mixin that implements common functionality for classes that behave
@@ -855,6 +861,10 @@ def check(predicate):
will be propagated while those subclassed will be sent to
:func:`on_command_error`.
.. info::
These functions can either be regular functions or coroutines.
Parameters
-----------
predicate