Implement async checks. Fixes #380.

This commit is contained in:
Rapptz
2017-02-12 12:13:23 -05:00
parent 2abdbc70c2
commit 47ef657fbd
5 changed files with 83 additions and 40 deletions

View File

@@ -85,7 +85,7 @@ def _default_help_command(ctx, *commands : str):
# help by itself just lists our own commands. # help by itself just lists our own commands.
if len(commands) == 0: if len(commands) == 0:
pages = bot.formatter.format_help_for(ctx, bot) pages = yield from bot.formatter.format_help_for(ctx, bot)
elif len(commands) == 1: elif len(commands) == 1:
# try to see if it is a cog name # try to see if it is a cog name
name = _mention_pattern.sub(repl, commands[0]) name = _mention_pattern.sub(repl, commands[0])
@@ -98,7 +98,7 @@ def _default_help_command(ctx, *commands : str):
yield from destination.send(bot.command_not_found.format(name)) yield from destination.send(bot.command_not_found.format(name))
return return
pages = bot.formatter.format_help_for(ctx, command) pages = yield from bot.formatter.format_help_for(ctx, command)
else: else:
name = _mention_pattern.sub(repl, commands[0]) name = _mention_pattern.sub(repl, commands[0])
command = bot.commands.get(name) command = bot.commands.get(name)
@@ -117,7 +117,7 @@ def _default_help_command(ctx, *commands : str):
yield from destination.send(bot.command_has_no_subcommands.format(command, key)) yield from destination.send(bot.command_has_no_subcommands.format(command, key))
return return
pages = bot.formatter.format_help_for(ctx, command) pages = yield from bot.formatter.format_help_for(ctx, command)
if bot.pm_help is None: if bot.pm_help is None:
characters = sum(map(lambda l: len(l), pages)) characters = sum(map(lambda l: len(l), pages))
@@ -218,9 +218,9 @@ class BotBase(GroupMixin):
on a per command basis except it is run before any command checks on a per command basis except it is run before any command checks
have been verified and applies to every command the bot has. have been verified and applies to every command the bot has.
.. warning:: .. info::
This function must be a *regular* function and not a coroutine. This function can either be a regular function or a coroutine.
Similar to a command :func:`check`\, this takes a single parameter Similar to a command :func:`check`\, this takes a single parameter
of type :class:`Context` and can only raise exceptions derived from of type :class:`Context` and can only raise exceptions derived from
@@ -268,8 +268,12 @@ class BotBase(GroupMixin):
except ValueError: except ValueError:
pass pass
@asyncio.coroutine
def can_run(self, ctx): def can_run(self, ctx):
return all(f(ctx) for f in self._checks) if len(self._checks) == 0:
return True
return (yield from discord.utils.async_all(f(ctx) for f in self._checks))
def before_invoke(self, coro): def before_invoke(self, coro):
"""A decorator that registers a coroutine as a pre-invoke hook. """A decorator that registers a coroutine as a pre-invoke hook.

View File

@@ -342,6 +342,7 @@ class Command:
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name) raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
@asyncio.coroutine
def _verify_checks(self, ctx): def _verify_checks(self, ctx):
if not self.enabled: if not self.enabled:
raise DisabledCommand('{0.name} command is disabled'.format(self)) raise DisabledCommand('{0.name} command is disabled'.format(self))
@@ -349,10 +350,7 @@ class Command:
if self.no_pm and not isinstance(ctx.channel, discord.abc.GuildChannel): if self.no_pm and not isinstance(ctx.channel, discord.abc.GuildChannel):
raise NoPrivateMessage('This command cannot be used in private messages.') raise NoPrivateMessage('This command cannot be used in private messages.')
if not ctx.bot.can_run(ctx): if not (yield from self.can_run(ctx)):
raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self))
if not self.can_run(ctx):
raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self)) raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self))
@asyncio.coroutine @asyncio.coroutine
@@ -402,7 +400,7 @@ class Command:
@asyncio.coroutine @asyncio.coroutine
def prepare(self, ctx): def prepare(self, ctx):
ctx.command = self ctx.command = self
self._verify_checks(ctx) yield from self._verify_checks(ctx)
yield from self._parse_arguments(ctx) yield from self._parse_arguments(ctx)
if self._buckets.valid: if self._buckets.valid:
@@ -533,14 +531,17 @@ class Command:
return self.help.split('\n', 1)[0] return self.help.split('\n', 1)[0]
return '' return ''
def can_run(self, context): @asyncio.coroutine
"""Checks if the command can be executed by checking all the predicates def can_run(self, ctx):
"""|coro|
Checks if the command can be executed by checking all the predicates
inside the :attr:`checks` attribute. inside the :attr:`checks` attribute.
Parameters Parameters
----------- -----------
context : :class:`Context` ctx: :class:`Context`
The context of the command currently being invoked. The ctx of the command currently being invoked.
Returns Returns
-------- --------
@@ -548,6 +549,9 @@ class Command:
A boolean indicating if the command can be invoked. A boolean indicating if the command can be invoked.
""" """
if not (yield from ctx.bot.can_run(ctx)):
raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self))
cog = self.instance cog = self.instance
if cog is not None: if cog is not None:
try: try:
@@ -555,14 +559,16 @@ class Command:
except AttributeError: except AttributeError:
pass pass
else: else:
if not local_check(context): ret = yield from discord.utils.maybe_coroutine(local_check, ctx)
if not ret:
return False return False
predicates = self.checks predicates = self.checks
if not predicates: if not predicates:
# since we have no checks, then we just return True. # since we have no checks, then we just return True.
return True return True
return all(predicate(context) for predicate in predicates)
return (yield from discord.utils.async_all(predicate(ctx) for predicate in predicates))
class GroupMixin: class GroupMixin:
"""A mixin that implements common functionality for classes that behave """A mixin that implements common functionality for classes that behave
@@ -855,6 +861,10 @@ def check(predicate):
will be propagated while those subclassed will be sent to will be propagated while those subclassed will be sent to
:func:`on_command_error`. :func:`on_command_error`.
.. info::
These functions can either be regular functions or coroutines.
Parameters Parameters
----------- -----------
predicate predicate

View File

@@ -26,9 +26,11 @@ DEALINGS IN THE SOFTWARE.
import itertools import itertools
import inspect import inspect
import asyncio
from .core import GroupMixin, Command from .core import GroupMixin, Command
from .errors import CommandError from .errors import CommandError
# from discord.iterators import _FilteredAsyncIterator
# help -> shows info of bot on top/bottom and lists subcommands # help -> shows info of bot on top/bottom and lists subcommands
# help command -> shows detailed info of command # help command -> shows detailed info of command
@@ -227,6 +229,7 @@ class HelpFormatter:
return "Type {0}{1} command for more info on a command.\n" \ return "Type {0}{1} command for more info on a command.\n" \
"You can also type {0}{1} category for more info on a category.".format(self.clean_prefix, command_name) "You can also type {0}{1} category for more info on a category.".format(self.clean_prefix, command_name)
@asyncio.coroutine
def filter_command_list(self): def filter_command_list(self):
"""Returns a filtered list of commands based on the two attributes """Returns a filtered list of commands based on the two attributes
provided, :attr:`show_check_failure` and :attr:`show_hidden`. Also provided, :attr:`show_check_failure` and :attr:`show_hidden`. Also
@@ -238,8 +241,9 @@ class HelpFormatter:
An iterable with the filter being applied. The resulting value is An iterable with the filter being applied. The resulting value is
a (key, value) tuple of the command name and the command itself. a (key, value) tuple of the command name and the command itself.
""" """
def predicate(tuple):
cmd = tuple[1] def sane_no_suspension_point_predicate(tup):
cmd = tup[1]
if self.is_cog(): if self.is_cog():
# filter commands that don't exist to this cog. # filter commands that don't exist to this cog.
if cmd.instance is not self.command: if cmd.instance is not self.command:
@@ -248,18 +252,31 @@ class HelpFormatter:
if cmd.hidden and not self.show_hidden: if cmd.hidden and not self.show_hidden:
return False return False
if self.show_check_failure: return True
# we don't wanna bother doing the checks if the user does not
# care about them, so just return true.
return True
@asyncio.coroutine
def predicate(tup):
if sane_no_suspension_point_predicate(tup) is False:
return False
cmd = tup[1]
try: try:
return cmd.can_run(self.context) and self.context.bot.can_run(self.context) return (yield from cmd.can_run(self.context))
except CommandError: except CommandError:
return False return False
iterator = self.command.commands.items() if not self.is_cog() else self.context.bot.commands.items() iterator = self.command.commands.items() if not self.is_cog() else self.context.bot.commands.items()
return filter(predicate, iterator) if not self.show_check_failure:
return filter(sane_no_suspension_point_predicate, iterator)
# Gotta run every check and verify it
ret = []
for elem in iterator:
valid = yield from predicate(elem)
if valid:
ret.append(elem)
return ret
def _add_subcommands_to_page(self, max_width, commands): def _add_subcommands_to_page(self, max_width, commands):
for name, command in commands: for name, command in commands:
@@ -271,6 +288,7 @@ class HelpFormatter:
shortened = self.shorten(entry) shortened = self.shorten(entry)
self._paginator.add_line(shortened) self._paginator.add_line(shortened)
@asyncio.coroutine
def format_help_for(self, context, command_or_bot): def format_help_for(self, context, command_or_bot):
"""Formats the help page and handles the actual heavy lifting of how """Formats the help page and handles the actual heavy lifting of how
the help command looks like. To change the behaviour, override the the help command looks like. To change the behaviour, override the
@@ -290,8 +308,9 @@ class HelpFormatter:
""" """
self.context = context self.context = context
self.command = command_or_bot self.command = command_or_bot
return self.format() return (yield from self.format())
@asyncio.coroutine
def format(self): def format(self):
"""Handles the actual behaviour involved with formatting. """Handles the actual behaviour involved with formatting.
@@ -334,18 +353,19 @@ class HelpFormatter:
# last place sorting position. # last place sorting position.
return cog + ':' if cog is not None else '\u200bNo Category:' return cog + ':' if cog is not None else '\u200bNo Category:'
filtered = yield from self.filter_command_list()
if self.is_bot(): if self.is_bot():
data = sorted(self.filter_command_list(), key=category) data = sorted(filtered, key=category)
for category, commands in itertools.groupby(data, key=category): for category, commands in itertools.groupby(data, key=category):
# there simply is no prettier way of doing this. # there simply is no prettier way of doing this.
commands = list(commands) commands = sorted(commands)
if len(commands) > 0: if len(commands) > 0:
self._paginator.add_line(category) self._paginator.add_line(category)
self._add_subcommands_to_page(max_width, commands) self._add_subcommands_to_page(max_width, commands)
else: else:
self._paginator.add_line('Commands:') self._paginator.add_line('Commands:')
self._add_subcommands_to_page(max_width, self.filter_command_list()) self._add_subcommands_to_page(max_width, sorted(filtered))
# add the ending note # add the ending note
self._paginator.add_line() self._paginator.add_line()

View File

@@ -30,18 +30,11 @@ import aiohttp
import datetime import datetime
from .errors import NoMoreItems from .errors import NoMoreItems
from .utils import time_snowflake from .utils import time_snowflake, maybe_coroutine
from .object import Object from .object import Object
PY35 = sys.version_info >= (3, 5) PY35 = sys.version_info >= (3, 5)
@asyncio.coroutine
def _probably_coroutine(f, e):
if asyncio.iscoroutinefunction(f):
return (yield from f(e))
else:
return f(e)
class _AsyncIterator: class _AsyncIterator:
__slots__ = () __slots__ = ()
@@ -67,7 +60,7 @@ class _AsyncIterator:
except NoMoreItems: except NoMoreItems:
return None return None
ret = yield from _probably_coroutine(predicate, elem) ret = yield from maybe_coroutine(predicate, elem)
if ret: if ret:
return elem return elem
@@ -114,7 +107,7 @@ class _MappedAsyncIterator(_AsyncIterator):
def get(self): def get(self):
# this raises NoMoreItems and will propagate appropriately # this raises NoMoreItems and will propagate appropriately
item = yield from self.iterator.get() item = yield from self.iterator.get()
return (yield from _probably_coroutine(self.func, item)) return (yield from maybe_coroutine(self.func, item))
class _FilteredAsyncIterator(_AsyncIterator): class _FilteredAsyncIterator(_AsyncIterator):
def __init__(self, iterator, predicate): def __init__(self, iterator, predicate):
@@ -132,7 +125,7 @@ class _FilteredAsyncIterator(_AsyncIterator):
while True: while True:
# propagate NoMoreItems similar to _MappedAsyncIterator # propagate NoMoreItems similar to _MappedAsyncIterator
item = yield from getter() item = yield from getter()
ret = yield from _probably_coroutine(pred, item) ret = yield from maybe_coroutine(pred, item)
if ret: if ret:
return item return item

View File

@@ -260,3 +260,19 @@ def _bytes_to_base64_data(data):
def to_json(obj): def to_json(obj):
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True) return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
@asyncio.coroutine
def maybe_coroutine(f, e):
if asyncio.iscoroutinefunction(f):
return (yield from f(e))
else:
return f(e)
@asyncio.coroutine
def async_all(gen):
check = asyncio.iscoroutine
for elem in gen:
if check(elem):
elem = yield from elem
if not elem:
return False
return True