Implement async checks. Fixes #380.
This commit is contained in:
parent
2abdbc70c2
commit
47ef657fbd
@ -85,7 +85,7 @@ def _default_help_command(ctx, *commands : str):
|
|||||||
|
|
||||||
# help by itself just lists our own commands.
|
# help by itself just lists our own commands.
|
||||||
if len(commands) == 0:
|
if len(commands) == 0:
|
||||||
pages = bot.formatter.format_help_for(ctx, bot)
|
pages = yield from bot.formatter.format_help_for(ctx, bot)
|
||||||
elif len(commands) == 1:
|
elif len(commands) == 1:
|
||||||
# try to see if it is a cog name
|
# try to see if it is a cog name
|
||||||
name = _mention_pattern.sub(repl, commands[0])
|
name = _mention_pattern.sub(repl, commands[0])
|
||||||
@ -98,7 +98,7 @@ def _default_help_command(ctx, *commands : str):
|
|||||||
yield from destination.send(bot.command_not_found.format(name))
|
yield from destination.send(bot.command_not_found.format(name))
|
||||||
return
|
return
|
||||||
|
|
||||||
pages = bot.formatter.format_help_for(ctx, command)
|
pages = yield from bot.formatter.format_help_for(ctx, command)
|
||||||
else:
|
else:
|
||||||
name = _mention_pattern.sub(repl, commands[0])
|
name = _mention_pattern.sub(repl, commands[0])
|
||||||
command = bot.commands.get(name)
|
command = bot.commands.get(name)
|
||||||
@ -117,7 +117,7 @@ def _default_help_command(ctx, *commands : str):
|
|||||||
yield from destination.send(bot.command_has_no_subcommands.format(command, key))
|
yield from destination.send(bot.command_has_no_subcommands.format(command, key))
|
||||||
return
|
return
|
||||||
|
|
||||||
pages = bot.formatter.format_help_for(ctx, command)
|
pages = yield from bot.formatter.format_help_for(ctx, command)
|
||||||
|
|
||||||
if bot.pm_help is None:
|
if bot.pm_help is None:
|
||||||
characters = sum(map(lambda l: len(l), pages))
|
characters = sum(map(lambda l: len(l), pages))
|
||||||
@ -218,9 +218,9 @@ class BotBase(GroupMixin):
|
|||||||
on a per command basis except it is run before any command checks
|
on a per command basis except it is run before any command checks
|
||||||
have been verified and applies to every command the bot has.
|
have been verified and applies to every command the bot has.
|
||||||
|
|
||||||
.. warning::
|
.. info::
|
||||||
|
|
||||||
This function must be a *regular* function and not a coroutine.
|
This function can either be a regular function or a coroutine.
|
||||||
|
|
||||||
Similar to a command :func:`check`\, this takes a single parameter
|
Similar to a command :func:`check`\, this takes a single parameter
|
||||||
of type :class:`Context` and can only raise exceptions derived from
|
of type :class:`Context` and can only raise exceptions derived from
|
||||||
@ -268,8 +268,12 @@ class BotBase(GroupMixin):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def can_run(self, ctx):
|
def can_run(self, ctx):
|
||||||
return all(f(ctx) for f in self._checks)
|
if len(self._checks) == 0:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return (yield from discord.utils.async_all(f(ctx) for f in self._checks))
|
||||||
|
|
||||||
def before_invoke(self, coro):
|
def before_invoke(self, coro):
|
||||||
"""A decorator that registers a coroutine as a pre-invoke hook.
|
"""A decorator that registers a coroutine as a pre-invoke hook.
|
||||||
|
@ -342,6 +342,7 @@ class Command:
|
|||||||
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
|
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def _verify_checks(self, ctx):
|
def _verify_checks(self, ctx):
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
raise DisabledCommand('{0.name} command is disabled'.format(self))
|
raise DisabledCommand('{0.name} command is disabled'.format(self))
|
||||||
@ -349,10 +350,7 @@ class Command:
|
|||||||
if self.no_pm and not isinstance(ctx.channel, discord.abc.GuildChannel):
|
if self.no_pm and not isinstance(ctx.channel, discord.abc.GuildChannel):
|
||||||
raise NoPrivateMessage('This command cannot be used in private messages.')
|
raise NoPrivateMessage('This command cannot be used in private messages.')
|
||||||
|
|
||||||
if not ctx.bot.can_run(ctx):
|
if not (yield from self.can_run(ctx)):
|
||||||
raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self))
|
|
||||||
|
|
||||||
if not self.can_run(ctx):
|
|
||||||
raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self))
|
raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self))
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
@ -402,7 +400,7 @@ class Command:
|
|||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def prepare(self, ctx):
|
def prepare(self, ctx):
|
||||||
ctx.command = self
|
ctx.command = self
|
||||||
self._verify_checks(ctx)
|
yield from self._verify_checks(ctx)
|
||||||
yield from self._parse_arguments(ctx)
|
yield from self._parse_arguments(ctx)
|
||||||
|
|
||||||
if self._buckets.valid:
|
if self._buckets.valid:
|
||||||
@ -533,14 +531,17 @@ class Command:
|
|||||||
return self.help.split('\n', 1)[0]
|
return self.help.split('\n', 1)[0]
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
def can_run(self, context):
|
@asyncio.coroutine
|
||||||
"""Checks if the command can be executed by checking all the predicates
|
def can_run(self, ctx):
|
||||||
|
"""|coro|
|
||||||
|
|
||||||
|
Checks if the command can be executed by checking all the predicates
|
||||||
inside the :attr:`checks` attribute.
|
inside the :attr:`checks` attribute.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
-----------
|
-----------
|
||||||
context : :class:`Context`
|
ctx: :class:`Context`
|
||||||
The context of the command currently being invoked.
|
The ctx of the command currently being invoked.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
--------
|
--------
|
||||||
@ -548,6 +549,9 @@ class Command:
|
|||||||
A boolean indicating if the command can be invoked.
|
A boolean indicating if the command can be invoked.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not (yield from ctx.bot.can_run(ctx)):
|
||||||
|
raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self))
|
||||||
|
|
||||||
cog = self.instance
|
cog = self.instance
|
||||||
if cog is not None:
|
if cog is not None:
|
||||||
try:
|
try:
|
||||||
@ -555,14 +559,16 @@ class Command:
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if not local_check(context):
|
ret = yield from discord.utils.maybe_coroutine(local_check, ctx)
|
||||||
|
if not ret:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
predicates = self.checks
|
predicates = self.checks
|
||||||
if not predicates:
|
if not predicates:
|
||||||
# since we have no checks, then we just return True.
|
# since we have no checks, then we just return True.
|
||||||
return True
|
return True
|
||||||
return all(predicate(context) for predicate in predicates)
|
|
||||||
|
return (yield from discord.utils.async_all(predicate(ctx) for predicate in predicates))
|
||||||
|
|
||||||
class GroupMixin:
|
class GroupMixin:
|
||||||
"""A mixin that implements common functionality for classes that behave
|
"""A mixin that implements common functionality for classes that behave
|
||||||
@ -855,6 +861,10 @@ def check(predicate):
|
|||||||
will be propagated while those subclassed will be sent to
|
will be propagated while those subclassed will be sent to
|
||||||
:func:`on_command_error`.
|
:func:`on_command_error`.
|
||||||
|
|
||||||
|
.. info::
|
||||||
|
|
||||||
|
These functions can either be regular functions or coroutines.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
-----------
|
-----------
|
||||||
predicate
|
predicate
|
||||||
|
@ -26,9 +26,11 @@ DEALINGS IN THE SOFTWARE.
|
|||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import inspect
|
import inspect
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from .core import GroupMixin, Command
|
from .core import GroupMixin, Command
|
||||||
from .errors import CommandError
|
from .errors import CommandError
|
||||||
|
# from discord.iterators import _FilteredAsyncIterator
|
||||||
|
|
||||||
# help -> shows info of bot on top/bottom and lists subcommands
|
# help -> shows info of bot on top/bottom and lists subcommands
|
||||||
# help command -> shows detailed info of command
|
# help command -> shows detailed info of command
|
||||||
@ -227,6 +229,7 @@ class HelpFormatter:
|
|||||||
return "Type {0}{1} command for more info on a command.\n" \
|
return "Type {0}{1} command for more info on a command.\n" \
|
||||||
"You can also type {0}{1} category for more info on a category.".format(self.clean_prefix, command_name)
|
"You can also type {0}{1} category for more info on a category.".format(self.clean_prefix, command_name)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def filter_command_list(self):
|
def filter_command_list(self):
|
||||||
"""Returns a filtered list of commands based on the two attributes
|
"""Returns a filtered list of commands based on the two attributes
|
||||||
provided, :attr:`show_check_failure` and :attr:`show_hidden`. Also
|
provided, :attr:`show_check_failure` and :attr:`show_hidden`. Also
|
||||||
@ -238,8 +241,9 @@ class HelpFormatter:
|
|||||||
An iterable with the filter being applied. The resulting value is
|
An iterable with the filter being applied. The resulting value is
|
||||||
a (key, value) tuple of the command name and the command itself.
|
a (key, value) tuple of the command name and the command itself.
|
||||||
"""
|
"""
|
||||||
def predicate(tuple):
|
|
||||||
cmd = tuple[1]
|
def sane_no_suspension_point_predicate(tup):
|
||||||
|
cmd = tup[1]
|
||||||
if self.is_cog():
|
if self.is_cog():
|
||||||
# filter commands that don't exist to this cog.
|
# filter commands that don't exist to this cog.
|
||||||
if cmd.instance is not self.command:
|
if cmd.instance is not self.command:
|
||||||
@ -248,18 +252,31 @@ class HelpFormatter:
|
|||||||
if cmd.hidden and not self.show_hidden:
|
if cmd.hidden and not self.show_hidden:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if self.show_check_failure:
|
return True
|
||||||
# we don't wanna bother doing the checks if the user does not
|
|
||||||
# care about them, so just return true.
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def predicate(tup):
|
||||||
|
if sane_no_suspension_point_predicate(tup) is False:
|
||||||
|
return False
|
||||||
|
|
||||||
|
cmd = tup[1]
|
||||||
try:
|
try:
|
||||||
return cmd.can_run(self.context) and self.context.bot.can_run(self.context)
|
return (yield from cmd.can_run(self.context))
|
||||||
except CommandError:
|
except CommandError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
iterator = self.command.commands.items() if not self.is_cog() else self.context.bot.commands.items()
|
iterator = self.command.commands.items() if not self.is_cog() else self.context.bot.commands.items()
|
||||||
return filter(predicate, iterator)
|
if not self.show_check_failure:
|
||||||
|
return filter(sane_no_suspension_point_predicate, iterator)
|
||||||
|
|
||||||
|
# Gotta run every check and verify it
|
||||||
|
ret = []
|
||||||
|
for elem in iterator:
|
||||||
|
valid = yield from predicate(elem)
|
||||||
|
if valid:
|
||||||
|
ret.append(elem)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
def _add_subcommands_to_page(self, max_width, commands):
|
def _add_subcommands_to_page(self, max_width, commands):
|
||||||
for name, command in commands:
|
for name, command in commands:
|
||||||
@ -271,6 +288,7 @@ class HelpFormatter:
|
|||||||
shortened = self.shorten(entry)
|
shortened = self.shorten(entry)
|
||||||
self._paginator.add_line(shortened)
|
self._paginator.add_line(shortened)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def format_help_for(self, context, command_or_bot):
|
def format_help_for(self, context, command_or_bot):
|
||||||
"""Formats the help page and handles the actual heavy lifting of how
|
"""Formats the help page and handles the actual heavy lifting of how
|
||||||
the help command looks like. To change the behaviour, override the
|
the help command looks like. To change the behaviour, override the
|
||||||
@ -290,8 +308,9 @@ class HelpFormatter:
|
|||||||
"""
|
"""
|
||||||
self.context = context
|
self.context = context
|
||||||
self.command = command_or_bot
|
self.command = command_or_bot
|
||||||
return self.format()
|
return (yield from self.format())
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def format(self):
|
def format(self):
|
||||||
"""Handles the actual behaviour involved with formatting.
|
"""Handles the actual behaviour involved with formatting.
|
||||||
|
|
||||||
@ -334,18 +353,19 @@ class HelpFormatter:
|
|||||||
# last place sorting position.
|
# last place sorting position.
|
||||||
return cog + ':' if cog is not None else '\u200bNo Category:'
|
return cog + ':' if cog is not None else '\u200bNo Category:'
|
||||||
|
|
||||||
|
filtered = yield from self.filter_command_list()
|
||||||
if self.is_bot():
|
if self.is_bot():
|
||||||
data = sorted(self.filter_command_list(), key=category)
|
data = sorted(filtered, key=category)
|
||||||
for category, commands in itertools.groupby(data, key=category):
|
for category, commands in itertools.groupby(data, key=category):
|
||||||
# there simply is no prettier way of doing this.
|
# there simply is no prettier way of doing this.
|
||||||
commands = list(commands)
|
commands = sorted(commands)
|
||||||
if len(commands) > 0:
|
if len(commands) > 0:
|
||||||
self._paginator.add_line(category)
|
self._paginator.add_line(category)
|
||||||
|
|
||||||
self._add_subcommands_to_page(max_width, commands)
|
self._add_subcommands_to_page(max_width, commands)
|
||||||
else:
|
else:
|
||||||
self._paginator.add_line('Commands:')
|
self._paginator.add_line('Commands:')
|
||||||
self._add_subcommands_to_page(max_width, self.filter_command_list())
|
self._add_subcommands_to_page(max_width, sorted(filtered))
|
||||||
|
|
||||||
# add the ending note
|
# add the ending note
|
||||||
self._paginator.add_line()
|
self._paginator.add_line()
|
||||||
|
@ -30,18 +30,11 @@ import aiohttp
|
|||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from .errors import NoMoreItems
|
from .errors import NoMoreItems
|
||||||
from .utils import time_snowflake
|
from .utils import time_snowflake, maybe_coroutine
|
||||||
from .object import Object
|
from .object import Object
|
||||||
|
|
||||||
PY35 = sys.version_info >= (3, 5)
|
PY35 = sys.version_info >= (3, 5)
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def _probably_coroutine(f, e):
|
|
||||||
if asyncio.iscoroutinefunction(f):
|
|
||||||
return (yield from f(e))
|
|
||||||
else:
|
|
||||||
return f(e)
|
|
||||||
|
|
||||||
class _AsyncIterator:
|
class _AsyncIterator:
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
@ -67,7 +60,7 @@ class _AsyncIterator:
|
|||||||
except NoMoreItems:
|
except NoMoreItems:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
ret = yield from _probably_coroutine(predicate, elem)
|
ret = yield from maybe_coroutine(predicate, elem)
|
||||||
if ret:
|
if ret:
|
||||||
return elem
|
return elem
|
||||||
|
|
||||||
@ -114,7 +107,7 @@ class _MappedAsyncIterator(_AsyncIterator):
|
|||||||
def get(self):
|
def get(self):
|
||||||
# this raises NoMoreItems and will propagate appropriately
|
# this raises NoMoreItems and will propagate appropriately
|
||||||
item = yield from self.iterator.get()
|
item = yield from self.iterator.get()
|
||||||
return (yield from _probably_coroutine(self.func, item))
|
return (yield from maybe_coroutine(self.func, item))
|
||||||
|
|
||||||
class _FilteredAsyncIterator(_AsyncIterator):
|
class _FilteredAsyncIterator(_AsyncIterator):
|
||||||
def __init__(self, iterator, predicate):
|
def __init__(self, iterator, predicate):
|
||||||
@ -132,7 +125,7 @@ class _FilteredAsyncIterator(_AsyncIterator):
|
|||||||
while True:
|
while True:
|
||||||
# propagate NoMoreItems similar to _MappedAsyncIterator
|
# propagate NoMoreItems similar to _MappedAsyncIterator
|
||||||
item = yield from getter()
|
item = yield from getter()
|
||||||
ret = yield from _probably_coroutine(pred, item)
|
ret = yield from maybe_coroutine(pred, item)
|
||||||
if ret:
|
if ret:
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
@ -260,3 +260,19 @@ def _bytes_to_base64_data(data):
|
|||||||
def to_json(obj):
|
def to_json(obj):
|
||||||
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
|
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def maybe_coroutine(f, e):
|
||||||
|
if asyncio.iscoroutinefunction(f):
|
||||||
|
return (yield from f(e))
|
||||||
|
else:
|
||||||
|
return f(e)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def async_all(gen):
|
||||||
|
check = asyncio.iscoroutine
|
||||||
|
for elem in gen:
|
||||||
|
if check(elem):
|
||||||
|
elem = yield from elem
|
||||||
|
if not elem:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
Loading…
x
Reference in New Issue
Block a user