Implement async checks. Fixes #380.
This commit is contained in:
		| @@ -85,7 +85,7 @@ def _default_help_command(ctx, *commands : str): | ||||
|  | ||||
|     # help by itself just lists our own commands. | ||||
|     if len(commands) == 0: | ||||
|         pages = bot.formatter.format_help_for(ctx, bot) | ||||
|         pages = yield from bot.formatter.format_help_for(ctx, bot) | ||||
|     elif len(commands) == 1: | ||||
|         # try to see if it is a cog name | ||||
|         name = _mention_pattern.sub(repl, commands[0]) | ||||
| @@ -98,7 +98,7 @@ def _default_help_command(ctx, *commands : str): | ||||
|                 yield from destination.send(bot.command_not_found.format(name)) | ||||
|                 return | ||||
|  | ||||
|         pages = bot.formatter.format_help_for(ctx, command) | ||||
|         pages = yield from bot.formatter.format_help_for(ctx, command) | ||||
|     else: | ||||
|         name = _mention_pattern.sub(repl, commands[0]) | ||||
|         command = bot.commands.get(name) | ||||
| @@ -117,7 +117,7 @@ def _default_help_command(ctx, *commands : str): | ||||
|                 yield from destination.send(bot.command_has_no_subcommands.format(command, key)) | ||||
|                 return | ||||
|  | ||||
|         pages = bot.formatter.format_help_for(ctx, command) | ||||
|         pages = yield from bot.formatter.format_help_for(ctx, command) | ||||
|  | ||||
|     if bot.pm_help is None: | ||||
|         characters = sum(map(lambda l: len(l), pages)) | ||||
| @@ -218,9 +218,9 @@ class BotBase(GroupMixin): | ||||
|         on a per command basis except it is run before any command checks | ||||
|         have been verified and applies to every command the bot has. | ||||
|  | ||||
|         .. warning:: | ||||
|         .. info:: | ||||
|  | ||||
|             This function must be a *regular* function and not a coroutine. | ||||
|             This function can either be a regular function or a coroutine. | ||||
|  | ||||
|         Similar to a command :func:`check`\, this takes a single parameter | ||||
|         of type :class:`Context` and can only raise exceptions derived from | ||||
| @@ -268,8 +268,12 @@ class BotBase(GroupMixin): | ||||
|         except ValueError: | ||||
|             pass | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def can_run(self, ctx): | ||||
|         return all(f(ctx) for f in self._checks) | ||||
|         if len(self._checks) == 0: | ||||
|             return True | ||||
|  | ||||
|         return (yield from discord.utils.async_all(f(ctx) for f in self._checks)) | ||||
|  | ||||
|     def before_invoke(self, coro): | ||||
|         """A decorator that registers a coroutine as a pre-invoke hook. | ||||
|   | ||||
| @@ -342,6 +342,7 @@ class Command: | ||||
|                 raise TooManyArguments('Too many arguments passed to ' + self.qualified_name) | ||||
|  | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def _verify_checks(self, ctx): | ||||
|         if not self.enabled: | ||||
|             raise DisabledCommand('{0.name} command is disabled'.format(self)) | ||||
| @@ -349,10 +350,7 @@ class Command: | ||||
|         if self.no_pm and not isinstance(ctx.channel, discord.abc.GuildChannel): | ||||
|             raise NoPrivateMessage('This command cannot be used in private messages.') | ||||
|  | ||||
|         if not ctx.bot.can_run(ctx): | ||||
|             raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self)) | ||||
|  | ||||
|         if not self.can_run(ctx): | ||||
|         if not (yield from self.can_run(ctx)): | ||||
|             raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self)) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
| @@ -402,7 +400,7 @@ class Command: | ||||
|     @asyncio.coroutine | ||||
|     def prepare(self, ctx): | ||||
|         ctx.command = self | ||||
|         self._verify_checks(ctx) | ||||
|         yield from self._verify_checks(ctx) | ||||
|         yield from self._parse_arguments(ctx) | ||||
|  | ||||
|         if self._buckets.valid: | ||||
| @@ -533,14 +531,17 @@ class Command: | ||||
|             return self.help.split('\n', 1)[0] | ||||
|         return '' | ||||
|  | ||||
|     def can_run(self, context): | ||||
|         """Checks if the command can be executed by checking all the predicates | ||||
|     @asyncio.coroutine | ||||
|     def can_run(self, ctx): | ||||
|         """|coro| | ||||
|  | ||||
|         Checks if the command can be executed by checking all the predicates | ||||
|         inside the :attr:`checks` attribute. | ||||
|  | ||||
|         Parameters | ||||
|         ----------- | ||||
|         context : :class:`Context` | ||||
|             The context of the command currently being invoked. | ||||
|         ctx: :class:`Context` | ||||
|             The ctx of the command currently being invoked. | ||||
|  | ||||
|         Returns | ||||
|         -------- | ||||
| @@ -548,6 +549,9 @@ class Command: | ||||
|             A boolean indicating if the command can be invoked. | ||||
|         """ | ||||
|  | ||||
|         if not (yield from ctx.bot.can_run(ctx)): | ||||
|             raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self)) | ||||
|  | ||||
|         cog = self.instance | ||||
|         if cog is not None: | ||||
|             try: | ||||
| @@ -555,14 +559,16 @@ class Command: | ||||
|             except AttributeError: | ||||
|                 pass | ||||
|             else: | ||||
|                 if not local_check(context): | ||||
|                 ret = yield from discord.utils.maybe_coroutine(local_check, ctx) | ||||
|                 if not ret: | ||||
|                     return False | ||||
|  | ||||
|         predicates = self.checks | ||||
|         if not predicates: | ||||
|             # since we have no checks, then we just return True. | ||||
|             return True | ||||
|         return all(predicate(context) for predicate in predicates) | ||||
|  | ||||
|         return (yield from discord.utils.async_all(predicate(ctx) for predicate in predicates)) | ||||
|  | ||||
| class GroupMixin: | ||||
|     """A mixin that implements common functionality for classes that behave | ||||
| @@ -855,6 +861,10 @@ def check(predicate): | ||||
|     will be propagated while those subclassed will be sent to | ||||
|     :func:`on_command_error`. | ||||
|  | ||||
|     .. info:: | ||||
|  | ||||
|         These functions can either be regular functions or coroutines. | ||||
|  | ||||
|     Parameters | ||||
|     ----------- | ||||
|     predicate | ||||
|   | ||||
| @@ -26,9 +26,11 @@ DEALINGS IN THE SOFTWARE. | ||||
|  | ||||
| import itertools | ||||
| import inspect | ||||
| import asyncio | ||||
|  | ||||
| from .core import GroupMixin, Command | ||||
| from .errors import CommandError | ||||
| # from discord.iterators import _FilteredAsyncIterator | ||||
|  | ||||
| # help -> shows info of bot on top/bottom and lists subcommands | ||||
| # help command -> shows detailed info of command | ||||
| @@ -227,6 +229,7 @@ class HelpFormatter: | ||||
|         return "Type {0}{1} command for more info on a command.\n" \ | ||||
|                "You can also type {0}{1} category for more info on a category.".format(self.clean_prefix, command_name) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def filter_command_list(self): | ||||
|         """Returns a filtered list of commands based on the two attributes | ||||
|         provided, :attr:`show_check_failure` and :attr:`show_hidden`. Also | ||||
| @@ -238,8 +241,9 @@ class HelpFormatter: | ||||
|             An iterable with the filter being applied. The resulting value is | ||||
|             a (key, value) tuple of the command name and the command itself. | ||||
|         """ | ||||
|         def predicate(tuple): | ||||
|             cmd = tuple[1] | ||||
|  | ||||
|         def sane_no_suspension_point_predicate(tup): | ||||
|             cmd = tup[1] | ||||
|             if self.is_cog(): | ||||
|                 # filter commands that don't exist to this cog. | ||||
|                 if cmd.instance is not self.command: | ||||
| @@ -248,18 +252,31 @@ class HelpFormatter: | ||||
|             if cmd.hidden and not self.show_hidden: | ||||
|                 return False | ||||
|  | ||||
|             if self.show_check_failure: | ||||
|                 # we don't wanna bother doing the checks if the user does not | ||||
|                 # care about them, so just return true. | ||||
|                 return True | ||||
|             return True | ||||
|  | ||||
|         @asyncio.coroutine | ||||
|         def predicate(tup): | ||||
|             if sane_no_suspension_point_predicate(tup) is False: | ||||
|                 return False | ||||
|  | ||||
|             cmd = tup[1] | ||||
|             try: | ||||
|                 return cmd.can_run(self.context) and self.context.bot.can_run(self.context) | ||||
|                 return (yield from cmd.can_run(self.context)) | ||||
|             except CommandError: | ||||
|                 return False | ||||
|  | ||||
|         iterator = self.command.commands.items() if not self.is_cog() else self.context.bot.commands.items() | ||||
|         return filter(predicate, iterator) | ||||
|         if not self.show_check_failure: | ||||
|             return filter(sane_no_suspension_point_predicate, iterator) | ||||
|  | ||||
|         # Gotta run every check and verify it | ||||
|         ret = [] | ||||
|         for elem in iterator: | ||||
|             valid = yield from predicate(elem) | ||||
|             if valid: | ||||
|                 ret.append(elem) | ||||
|  | ||||
|         return ret | ||||
|  | ||||
|     def _add_subcommands_to_page(self, max_width, commands): | ||||
|         for name, command in commands: | ||||
| @@ -271,6 +288,7 @@ class HelpFormatter: | ||||
|             shortened = self.shorten(entry) | ||||
|             self._paginator.add_line(shortened) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def format_help_for(self, context, command_or_bot): | ||||
|         """Formats the help page and handles the actual heavy lifting of how | ||||
|         the help command looks like. To change the behaviour, override the | ||||
| @@ -290,8 +308,9 @@ class HelpFormatter: | ||||
|         """ | ||||
|         self.context = context | ||||
|         self.command = command_or_bot | ||||
|         return self.format() | ||||
|         return (yield from self.format()) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def format(self): | ||||
|         """Handles the actual behaviour involved with formatting. | ||||
|  | ||||
| @@ -334,18 +353,19 @@ class HelpFormatter: | ||||
|             # last place sorting position. | ||||
|             return cog + ':' if cog is not None else '\u200bNo Category:' | ||||
|  | ||||
|         filtered = yield from self.filter_command_list() | ||||
|         if self.is_bot(): | ||||
|             data = sorted(self.filter_command_list(), key=category) | ||||
|             data = sorted(filtered, key=category) | ||||
|             for category, commands in itertools.groupby(data, key=category): | ||||
|                 # there simply is no prettier way of doing this. | ||||
|                 commands = list(commands) | ||||
|                 commands = sorted(commands) | ||||
|                 if len(commands) > 0: | ||||
|                     self._paginator.add_line(category) | ||||
|  | ||||
|                 self._add_subcommands_to_page(max_width, commands) | ||||
|         else: | ||||
|             self._paginator.add_line('Commands:') | ||||
|             self._add_subcommands_to_page(max_width, self.filter_command_list()) | ||||
|             self._add_subcommands_to_page(max_width, sorted(filtered)) | ||||
|  | ||||
|         # add the ending note | ||||
|         self._paginator.add_line() | ||||
|   | ||||
| @@ -30,18 +30,11 @@ import aiohttp | ||||
| import datetime | ||||
|  | ||||
| from .errors import NoMoreItems | ||||
| from .utils import time_snowflake | ||||
| from .utils import time_snowflake, maybe_coroutine | ||||
| from .object import Object | ||||
|  | ||||
| PY35 = sys.version_info >= (3, 5) | ||||
|  | ||||
| @asyncio.coroutine | ||||
| def _probably_coroutine(f, e): | ||||
|     if asyncio.iscoroutinefunction(f): | ||||
|         return (yield from f(e)) | ||||
|     else: | ||||
|         return f(e) | ||||
|  | ||||
| class _AsyncIterator: | ||||
|     __slots__ = () | ||||
|  | ||||
| @@ -67,7 +60,7 @@ class _AsyncIterator: | ||||
|             except NoMoreItems: | ||||
|                 return None | ||||
|  | ||||
|             ret = yield from _probably_coroutine(predicate, elem) | ||||
|             ret = yield from maybe_coroutine(predicate, elem) | ||||
|             if ret: | ||||
|                 return elem | ||||
|  | ||||
| @@ -114,7 +107,7 @@ class _MappedAsyncIterator(_AsyncIterator): | ||||
|     def get(self): | ||||
|         # this raises NoMoreItems and will propagate appropriately | ||||
|         item = yield from self.iterator.get() | ||||
|         return (yield from _probably_coroutine(self.func, item)) | ||||
|         return (yield from maybe_coroutine(self.func, item)) | ||||
|  | ||||
| class _FilteredAsyncIterator(_AsyncIterator): | ||||
|     def __init__(self, iterator, predicate): | ||||
| @@ -132,7 +125,7 @@ class _FilteredAsyncIterator(_AsyncIterator): | ||||
|         while True: | ||||
|             # propagate NoMoreItems similar to _MappedAsyncIterator | ||||
|             item = yield from getter() | ||||
|             ret = yield from _probably_coroutine(pred, item) | ||||
|             ret = yield from maybe_coroutine(pred, item) | ||||
|             if ret: | ||||
|                 return item | ||||
|  | ||||
|   | ||||
| @@ -260,3 +260,19 @@ def _bytes_to_base64_data(data): | ||||
| def to_json(obj): | ||||
|     return json.dumps(obj, separators=(',', ':'), ensure_ascii=True) | ||||
|  | ||||
| @asyncio.coroutine | ||||
| def maybe_coroutine(f, e): | ||||
|     if asyncio.iscoroutinefunction(f): | ||||
|         return (yield from f(e)) | ||||
|     else: | ||||
|         return f(e) | ||||
|  | ||||
| @asyncio.coroutine | ||||
| def async_all(gen): | ||||
|     check = asyncio.iscoroutine | ||||
|     for elem in gen: | ||||
|         if check(elem): | ||||
|             elem = yield from elem | ||||
|         if not elem: | ||||
|             return False | ||||
|     return True | ||||
|   | ||||
		Reference in New Issue
	
	Block a user