mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-10-24 10:02:56 +00:00
Implement async checks. Fixes #380.
This commit is contained in:
@@ -85,7 +85,7 @@ def _default_help_command(ctx, *commands : str):
|
||||
|
||||
# help by itself just lists our own commands.
|
||||
if len(commands) == 0:
|
||||
pages = bot.formatter.format_help_for(ctx, bot)
|
||||
pages = yield from bot.formatter.format_help_for(ctx, bot)
|
||||
elif len(commands) == 1:
|
||||
# try to see if it is a cog name
|
||||
name = _mention_pattern.sub(repl, commands[0])
|
||||
@@ -98,7 +98,7 @@ def _default_help_command(ctx, *commands : str):
|
||||
yield from destination.send(bot.command_not_found.format(name))
|
||||
return
|
||||
|
||||
pages = bot.formatter.format_help_for(ctx, command)
|
||||
pages = yield from bot.formatter.format_help_for(ctx, command)
|
||||
else:
|
||||
name = _mention_pattern.sub(repl, commands[0])
|
||||
command = bot.commands.get(name)
|
||||
@@ -117,7 +117,7 @@ def _default_help_command(ctx, *commands : str):
|
||||
yield from destination.send(bot.command_has_no_subcommands.format(command, key))
|
||||
return
|
||||
|
||||
pages = bot.formatter.format_help_for(ctx, command)
|
||||
pages = yield from bot.formatter.format_help_for(ctx, command)
|
||||
|
||||
if bot.pm_help is None:
|
||||
characters = sum(map(lambda l: len(l), pages))
|
||||
@@ -218,9 +218,9 @@ class BotBase(GroupMixin):
|
||||
on a per command basis except it is run before any command checks
|
||||
have been verified and applies to every command the bot has.
|
||||
|
||||
.. warning::
|
||||
.. info::
|
||||
|
||||
This function must be a *regular* function and not a coroutine.
|
||||
This function can either be a regular function or a coroutine.
|
||||
|
||||
Similar to a command :func:`check`\, this takes a single parameter
|
||||
of type :class:`Context` and can only raise exceptions derived from
|
||||
@@ -268,8 +268,12 @@ class BotBase(GroupMixin):
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
@asyncio.coroutine
|
||||
def can_run(self, ctx):
|
||||
return all(f(ctx) for f in self._checks)
|
||||
if len(self._checks) == 0:
|
||||
return True
|
||||
|
||||
return (yield from discord.utils.async_all(f(ctx) for f in self._checks))
|
||||
|
||||
def before_invoke(self, coro):
|
||||
"""A decorator that registers a coroutine as a pre-invoke hook.
|
||||
|
@@ -342,6 +342,7 @@ class Command:
|
||||
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def _verify_checks(self, ctx):
|
||||
if not self.enabled:
|
||||
raise DisabledCommand('{0.name} command is disabled'.format(self))
|
||||
@@ -349,10 +350,7 @@ class Command:
|
||||
if self.no_pm and not isinstance(ctx.channel, discord.abc.GuildChannel):
|
||||
raise NoPrivateMessage('This command cannot be used in private messages.')
|
||||
|
||||
if not ctx.bot.can_run(ctx):
|
||||
raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self))
|
||||
|
||||
if not self.can_run(ctx):
|
||||
if not (yield from self.can_run(ctx)):
|
||||
raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self))
|
||||
|
||||
@asyncio.coroutine
|
||||
@@ -402,7 +400,7 @@ class Command:
|
||||
@asyncio.coroutine
|
||||
def prepare(self, ctx):
|
||||
ctx.command = self
|
||||
self._verify_checks(ctx)
|
||||
yield from self._verify_checks(ctx)
|
||||
yield from self._parse_arguments(ctx)
|
||||
|
||||
if self._buckets.valid:
|
||||
@@ -533,14 +531,17 @@ class Command:
|
||||
return self.help.split('\n', 1)[0]
|
||||
return ''
|
||||
|
||||
def can_run(self, context):
|
||||
"""Checks if the command can be executed by checking all the predicates
|
||||
@asyncio.coroutine
|
||||
def can_run(self, ctx):
|
||||
"""|coro|
|
||||
|
||||
Checks if the command can be executed by checking all the predicates
|
||||
inside the :attr:`checks` attribute.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
context : :class:`Context`
|
||||
The context of the command currently being invoked.
|
||||
ctx: :class:`Context`
|
||||
The ctx of the command currently being invoked.
|
||||
|
||||
Returns
|
||||
--------
|
||||
@@ -548,6 +549,9 @@ class Command:
|
||||
A boolean indicating if the command can be invoked.
|
||||
"""
|
||||
|
||||
if not (yield from ctx.bot.can_run(ctx)):
|
||||
raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self))
|
||||
|
||||
cog = self.instance
|
||||
if cog is not None:
|
||||
try:
|
||||
@@ -555,14 +559,16 @@ class Command:
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
if not local_check(context):
|
||||
ret = yield from discord.utils.maybe_coroutine(local_check, ctx)
|
||||
if not ret:
|
||||
return False
|
||||
|
||||
predicates = self.checks
|
||||
if not predicates:
|
||||
# since we have no checks, then we just return True.
|
||||
return True
|
||||
return all(predicate(context) for predicate in predicates)
|
||||
|
||||
return (yield from discord.utils.async_all(predicate(ctx) for predicate in predicates))
|
||||
|
||||
class GroupMixin:
|
||||
"""A mixin that implements common functionality for classes that behave
|
||||
@@ -855,6 +861,10 @@ def check(predicate):
|
||||
will be propagated while those subclassed will be sent to
|
||||
:func:`on_command_error`.
|
||||
|
||||
.. info::
|
||||
|
||||
These functions can either be regular functions or coroutines.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
predicate
|
||||
|
@@ -26,9 +26,11 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
import itertools
|
||||
import inspect
|
||||
import asyncio
|
||||
|
||||
from .core import GroupMixin, Command
|
||||
from .errors import CommandError
|
||||
# from discord.iterators import _FilteredAsyncIterator
|
||||
|
||||
# help -> shows info of bot on top/bottom and lists subcommands
|
||||
# help command -> shows detailed info of command
|
||||
@@ -227,6 +229,7 @@ class HelpFormatter:
|
||||
return "Type {0}{1} command for more info on a command.\n" \
|
||||
"You can also type {0}{1} category for more info on a category.".format(self.clean_prefix, command_name)
|
||||
|
||||
@asyncio.coroutine
|
||||
def filter_command_list(self):
|
||||
"""Returns a filtered list of commands based on the two attributes
|
||||
provided, :attr:`show_check_failure` and :attr:`show_hidden`. Also
|
||||
@@ -238,8 +241,9 @@ class HelpFormatter:
|
||||
An iterable with the filter being applied. The resulting value is
|
||||
a (key, value) tuple of the command name and the command itself.
|
||||
"""
|
||||
def predicate(tuple):
|
||||
cmd = tuple[1]
|
||||
|
||||
def sane_no_suspension_point_predicate(tup):
|
||||
cmd = tup[1]
|
||||
if self.is_cog():
|
||||
# filter commands that don't exist to this cog.
|
||||
if cmd.instance is not self.command:
|
||||
@@ -248,18 +252,31 @@ class HelpFormatter:
|
||||
if cmd.hidden and not self.show_hidden:
|
||||
return False
|
||||
|
||||
if self.show_check_failure:
|
||||
# we don't wanna bother doing the checks if the user does not
|
||||
# care about them, so just return true.
|
||||
return True
|
||||
return True
|
||||
|
||||
@asyncio.coroutine
|
||||
def predicate(tup):
|
||||
if sane_no_suspension_point_predicate(tup) is False:
|
||||
return False
|
||||
|
||||
cmd = tup[1]
|
||||
try:
|
||||
return cmd.can_run(self.context) and self.context.bot.can_run(self.context)
|
||||
return (yield from cmd.can_run(self.context))
|
||||
except CommandError:
|
||||
return False
|
||||
|
||||
iterator = self.command.commands.items() if not self.is_cog() else self.context.bot.commands.items()
|
||||
return filter(predicate, iterator)
|
||||
if not self.show_check_failure:
|
||||
return filter(sane_no_suspension_point_predicate, iterator)
|
||||
|
||||
# Gotta run every check and verify it
|
||||
ret = []
|
||||
for elem in iterator:
|
||||
valid = yield from predicate(elem)
|
||||
if valid:
|
||||
ret.append(elem)
|
||||
|
||||
return ret
|
||||
|
||||
def _add_subcommands_to_page(self, max_width, commands):
|
||||
for name, command in commands:
|
||||
@@ -271,6 +288,7 @@ class HelpFormatter:
|
||||
shortened = self.shorten(entry)
|
||||
self._paginator.add_line(shortened)
|
||||
|
||||
@asyncio.coroutine
|
||||
def format_help_for(self, context, command_or_bot):
|
||||
"""Formats the help page and handles the actual heavy lifting of how
|
||||
the help command looks like. To change the behaviour, override the
|
||||
@@ -290,8 +308,9 @@ class HelpFormatter:
|
||||
"""
|
||||
self.context = context
|
||||
self.command = command_or_bot
|
||||
return self.format()
|
||||
return (yield from self.format())
|
||||
|
||||
@asyncio.coroutine
|
||||
def format(self):
|
||||
"""Handles the actual behaviour involved with formatting.
|
||||
|
||||
@@ -334,18 +353,19 @@ class HelpFormatter:
|
||||
# last place sorting position.
|
||||
return cog + ':' if cog is not None else '\u200bNo Category:'
|
||||
|
||||
filtered = yield from self.filter_command_list()
|
||||
if self.is_bot():
|
||||
data = sorted(self.filter_command_list(), key=category)
|
||||
data = sorted(filtered, key=category)
|
||||
for category, commands in itertools.groupby(data, key=category):
|
||||
# there simply is no prettier way of doing this.
|
||||
commands = list(commands)
|
||||
commands = sorted(commands)
|
||||
if len(commands) > 0:
|
||||
self._paginator.add_line(category)
|
||||
|
||||
self._add_subcommands_to_page(max_width, commands)
|
||||
else:
|
||||
self._paginator.add_line('Commands:')
|
||||
self._add_subcommands_to_page(max_width, self.filter_command_list())
|
||||
self._add_subcommands_to_page(max_width, sorted(filtered))
|
||||
|
||||
# add the ending note
|
||||
self._paginator.add_line()
|
||||
|
@@ -30,18 +30,11 @@ import aiohttp
|
||||
import datetime
|
||||
|
||||
from .errors import NoMoreItems
|
||||
from .utils import time_snowflake
|
||||
from .utils import time_snowflake, maybe_coroutine
|
||||
from .object import Object
|
||||
|
||||
PY35 = sys.version_info >= (3, 5)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _probably_coroutine(f, e):
|
||||
if asyncio.iscoroutinefunction(f):
|
||||
return (yield from f(e))
|
||||
else:
|
||||
return f(e)
|
||||
|
||||
class _AsyncIterator:
|
||||
__slots__ = ()
|
||||
|
||||
@@ -67,7 +60,7 @@ class _AsyncIterator:
|
||||
except NoMoreItems:
|
||||
return None
|
||||
|
||||
ret = yield from _probably_coroutine(predicate, elem)
|
||||
ret = yield from maybe_coroutine(predicate, elem)
|
||||
if ret:
|
||||
return elem
|
||||
|
||||
@@ -114,7 +107,7 @@ class _MappedAsyncIterator(_AsyncIterator):
|
||||
def get(self):
|
||||
# this raises NoMoreItems and will propagate appropriately
|
||||
item = yield from self.iterator.get()
|
||||
return (yield from _probably_coroutine(self.func, item))
|
||||
return (yield from maybe_coroutine(self.func, item))
|
||||
|
||||
class _FilteredAsyncIterator(_AsyncIterator):
|
||||
def __init__(self, iterator, predicate):
|
||||
@@ -132,7 +125,7 @@ class _FilteredAsyncIterator(_AsyncIterator):
|
||||
while True:
|
||||
# propagate NoMoreItems similar to _MappedAsyncIterator
|
||||
item = yield from getter()
|
||||
ret = yield from _probably_coroutine(pred, item)
|
||||
ret = yield from maybe_coroutine(pred, item)
|
||||
if ret:
|
||||
return item
|
||||
|
||||
|
@@ -260,3 +260,19 @@ def _bytes_to_base64_data(data):
|
||||
def to_json(obj):
|
||||
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
|
||||
|
||||
@asyncio.coroutine
|
||||
def maybe_coroutine(f, e):
|
||||
if asyncio.iscoroutinefunction(f):
|
||||
return (yield from f(e))
|
||||
else:
|
||||
return f(e)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_all(gen):
|
||||
check = asyncio.iscoroutine
|
||||
for elem in gen:
|
||||
if check(elem):
|
||||
elem = yield from elem
|
||||
if not elem:
|
||||
return False
|
||||
return True
|
||||
|
Reference in New Issue
Block a user