[commands] Implement before and after invoke command hooks.
Fixes #464.
This commit is contained in:
@@ -136,6 +136,8 @@ class BotBase(GroupMixin):
|
||||
self.cogs = {}
|
||||
self.extensions = {}
|
||||
self._checks = []
|
||||
self._before_invoke = None
|
||||
self._after_invoke = None
|
||||
self.description = inspect.cleandoc(description) if description else ''
|
||||
self.pm_help = pm_help
|
||||
self.command_not_found = options.pop('command_not_found', 'No command called "{}" found.')
|
||||
@@ -269,6 +271,71 @@ class BotBase(GroupMixin):
|
||||
def can_run(self, ctx):
|
||||
return all(f(ctx) for f in self._checks)
|
||||
|
||||
def before_invoke(self, coro):
|
||||
"""A decorator that registers a coroutine as a pre-invoke hook.
|
||||
|
||||
A pre-invoke hook is called directly before the command is
|
||||
called. This makes it a useful function to set up database
|
||||
connections or any type of set up required.
|
||||
|
||||
This pre-invoke hook takes a sole parameter, a :class:`Context`.
|
||||
|
||||
.. note::
|
||||
|
||||
The :meth:`before_invoke` and :meth:`after_invoke` hooks are
|
||||
only called if all checks and argument parsing procedures pass
|
||||
without error. If any check or argument parsing procedures fail
|
||||
then the hooks are not called.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
coro
|
||||
The coroutine to register as the pre-invoke hook.
|
||||
|
||||
Raises
|
||||
-------
|
||||
discord.ClientException
|
||||
The coroutine is not actually a coroutine.
|
||||
"""
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise discord.ClientException('The error handler must be a coroutine.')
|
||||
|
||||
self._before_invoke = coro
|
||||
return coro
|
||||
|
||||
def after_invoke(self, coro):
|
||||
"""A decorator that registers a coroutine as a post-invoke hook.
|
||||
|
||||
A post-invoke hook is called directly after the command is
|
||||
called. This makes it a useful function to clean-up database
|
||||
connections or any type of clean up required.
|
||||
|
||||
This post-invoke hook takes a sole parameter, a :class:`Context`.
|
||||
|
||||
.. note::
|
||||
|
||||
Similar to :meth:`before_invoke`\, this is not called unless
|
||||
checks and argument parsing procedures succeed. This hook is,
|
||||
however, **always** called regardless of the internal command
|
||||
callback raising an error (i.e. :exc:`CommandInvokeError`\).
|
||||
This makes it ideal for clean-up scenarios.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
coro
|
||||
The coroutine to register as the post-invoke hook.
|
||||
|
||||
Raises
|
||||
-------
|
||||
discord.ClientException
|
||||
The coroutine is not actually a coroutine.
|
||||
"""
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise discord.ClientException('The error handler must be a coroutine.')
|
||||
|
||||
self._after_invoke = coro
|
||||
return coro
|
||||
|
||||
# listener registration
|
||||
|
||||
def add_listener(self, func, name=None):
|
||||
|
@@ -52,6 +52,21 @@ def wrap_callback(coro):
|
||||
return ret
|
||||
return wrapped
|
||||
|
||||
def hooked_wrapped_callback(command, ctx, coro):
|
||||
@functools.wraps(coro)
|
||||
@asyncio.coroutine
|
||||
def wrapped(*args, **kwargs):
|
||||
try:
|
||||
ret = yield from coro(*args, **kwargs)
|
||||
except CommandError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise CommandInvokeError(e) from e
|
||||
finally:
|
||||
yield from command.call_after_hooks(ctx)
|
||||
return ret
|
||||
return wrapped
|
||||
|
||||
def _convert_to_bool(argument):
|
||||
lowered = argument.lower()
|
||||
if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
|
||||
@@ -144,6 +159,8 @@ class Command:
|
||||
self.instance = None
|
||||
self.parent = None
|
||||
self._buckets = CooldownMapping(kwargs.get('cooldown'))
|
||||
self._before_invoke = None
|
||||
self._after_invoke = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def dispatch_error(self, error, ctx):
|
||||
@@ -335,6 +352,50 @@ class Command:
|
||||
if not self.can_run(ctx):
|
||||
raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self))
|
||||
|
||||
@asyncio.coroutine
|
||||
def call_before_hooks(self, ctx):
|
||||
# now that we're done preparing we can call the pre-command hooks
|
||||
# first, call the command local hook:
|
||||
cog = self.instance
|
||||
if self._before_invoke is not None:
|
||||
if cog is None:
|
||||
yield from self._before_invoke(ctx)
|
||||
else:
|
||||
yield from self._before_invoke(cog, ctx)
|
||||
|
||||
# call the cog local hook if applicable:
|
||||
try:
|
||||
hook = getattr(cog, '_{0.__class__.__name__}__before_invoke'.format(cog))
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
yield from hook(ctx)
|
||||
|
||||
# call the bot global hook if necessary
|
||||
hook = ctx.bot._before_invoke
|
||||
if hook is not None:
|
||||
yield from hook(ctx)
|
||||
|
||||
@asyncio.coroutine
|
||||
def call_after_hooks(self, ctx):
|
||||
cog = self.instance
|
||||
if self._after_invoke is not None:
|
||||
if cog is None:
|
||||
yield from self._after_invoke(ctx)
|
||||
else:
|
||||
yield from self._after_invoke(cog, ctx)
|
||||
|
||||
try:
|
||||
hook = getattr(cog, '_{0.__class__.__name__}__after_invoke'.format(cog))
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
yield from hook(ctx)
|
||||
|
||||
hook = ctx.bot._after_invoke
|
||||
if hook is not None:
|
||||
yield from hook(ctx)
|
||||
|
||||
@asyncio.coroutine
|
||||
def prepare(self, ctx):
|
||||
ctx.command = self
|
||||
@@ -347,6 +408,8 @@ class Command:
|
||||
if retry_after:
|
||||
raise CommandOnCooldown(bucket, retry_after)
|
||||
|
||||
yield from self.call_before_hooks(ctx)
|
||||
|
||||
def reset_cooldown(self, ctx):
|
||||
"""Resets the cooldown on this command.
|
||||
|
||||
@@ -367,7 +430,7 @@ class Command:
|
||||
# since we're in a regular command (and not a group) then
|
||||
# the invoked subcommand is None.
|
||||
ctx.invoked_subcommand = None
|
||||
injected = wrap_callback(self.callback)
|
||||
injected = hooked_wrapped_callback(self, ctx, self.callback)
|
||||
yield from injected(*ctx.args, **ctx.kwargs)
|
||||
|
||||
def error(self, coro):
|
||||
@@ -394,6 +457,60 @@ class Command:
|
||||
self.on_error = coro
|
||||
return coro
|
||||
|
||||
def before_invoke(self, coro):
|
||||
"""A decorator that registers a coroutine as a pre-invoke hook.
|
||||
|
||||
A pre-invoke hook is called directly before :meth:`invoke` is
|
||||
called. This makes it a useful function to set up database
|
||||
connections or any type of set up required.
|
||||
|
||||
This pre-invoke hook takes a sole parameter, a :class:`Context`.
|
||||
|
||||
See :meth:`Bot.before_invoke` for more info.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
coro
|
||||
The coroutine to register as the pre-invoke hook.
|
||||
|
||||
Raises
|
||||
-------
|
||||
discord.ClientException
|
||||
The coroutine is not actually a coroutine.
|
||||
"""
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise discord.ClientException('The error handler must be a coroutine.')
|
||||
|
||||
self._before_invoke = coro
|
||||
return coro
|
||||
|
||||
def after_invoke(self, coro):
|
||||
"""A decorator that registers a coroutine as a post-invoke hook.
|
||||
|
||||
A post-invoke hook is called directly after :meth:`invoke` is
|
||||
called. This makes it a useful function to clean-up database
|
||||
connections or any type of clean up required.
|
||||
|
||||
This post-invoke hook takes a sole parameter, a :class:`Context`.
|
||||
|
||||
See :meth:`Bot.after_invoke` for more info.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
coro
|
||||
The coroutine to register as the post-invoke hook.
|
||||
|
||||
Raises
|
||||
-------
|
||||
discord.ClientException
|
||||
The coroutine is not actually a coroutine.
|
||||
"""
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise discord.ClientException('The error handler must be a coroutine.')
|
||||
|
||||
self._after_invoke = coro
|
||||
return coro
|
||||
|
||||
@property
|
||||
def cog_name(self):
|
||||
"""The name of the cog this command belongs to. None otherwise."""
|
||||
@@ -610,7 +727,7 @@ class Group(GroupMixin, Command):
|
||||
ctx.invoked_subcommand = self.commands.get(trigger, None)
|
||||
|
||||
if early_invoke:
|
||||
injected = wrap_callback(self.callback)
|
||||
injected = hooked_wrapped_callback(self, ctx, self.callback)
|
||||
yield from injected(*ctx.args, **ctx.kwargs)
|
||||
|
||||
if trigger and ctx.invoked_subcommand:
|
||||
|
Reference in New Issue
Block a user