[commands] Implement before and after invoke command hooks.
Fixes #464.
This commit is contained in:
parent
8fa50a8f3e
commit
1c49374210
@ -136,6 +136,8 @@ class BotBase(GroupMixin):
|
|||||||
self.cogs = {}
|
self.cogs = {}
|
||||||
self.extensions = {}
|
self.extensions = {}
|
||||||
self._checks = []
|
self._checks = []
|
||||||
|
self._before_invoke = None
|
||||||
|
self._after_invoke = None
|
||||||
self.description = inspect.cleandoc(description) if description else ''
|
self.description = inspect.cleandoc(description) if description else ''
|
||||||
self.pm_help = pm_help
|
self.pm_help = pm_help
|
||||||
self.command_not_found = options.pop('command_not_found', 'No command called "{}" found.')
|
self.command_not_found = options.pop('command_not_found', 'No command called "{}" found.')
|
||||||
@ -269,6 +271,71 @@ class BotBase(GroupMixin):
|
|||||||
def can_run(self, ctx):
|
def can_run(self, ctx):
|
||||||
return all(f(ctx) for f in self._checks)
|
return all(f(ctx) for f in self._checks)
|
||||||
|
|
||||||
|
def before_invoke(self, coro):
|
||||||
|
"""A decorator that registers a coroutine as a pre-invoke hook.
|
||||||
|
|
||||||
|
A pre-invoke hook is called directly before the command is
|
||||||
|
called. This makes it a useful function to set up database
|
||||||
|
connections or any type of set up required.
|
||||||
|
|
||||||
|
This pre-invoke hook takes a sole parameter, a :class:`Context`.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
The :meth:`before_invoke` and :meth:`after_invoke` hooks are
|
||||||
|
only called if all checks and argument parsing procedures pass
|
||||||
|
without error. If any check or argument parsing procedures fail
|
||||||
|
then the hooks are not called.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
coro
|
||||||
|
The coroutine to register as the pre-invoke hook.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
-------
|
||||||
|
discord.ClientException
|
||||||
|
The coroutine is not actually a coroutine.
|
||||||
|
"""
|
||||||
|
if not asyncio.iscoroutinefunction(coro):
|
||||||
|
raise discord.ClientException('The error handler must be a coroutine.')
|
||||||
|
|
||||||
|
self._before_invoke = coro
|
||||||
|
return coro
|
||||||
|
|
||||||
|
def after_invoke(self, coro):
|
||||||
|
"""A decorator that registers a coroutine as a post-invoke hook.
|
||||||
|
|
||||||
|
A post-invoke hook is called directly after the command is
|
||||||
|
called. This makes it a useful function to clean-up database
|
||||||
|
connections or any type of clean up required.
|
||||||
|
|
||||||
|
This post-invoke hook takes a sole parameter, a :class:`Context`.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Similar to :meth:`before_invoke`\, this is not called unless
|
||||||
|
checks and argument parsing procedures succeed. This hook is,
|
||||||
|
however, **always** called regardless of the internal command
|
||||||
|
callback raising an error (i.e. :exc:`CommandInvokeError`\).
|
||||||
|
This makes it ideal for clean-up scenarios.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
coro
|
||||||
|
The coroutine to register as the post-invoke hook.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
-------
|
||||||
|
discord.ClientException
|
||||||
|
The coroutine is not actually a coroutine.
|
||||||
|
"""
|
||||||
|
if not asyncio.iscoroutinefunction(coro):
|
||||||
|
raise discord.ClientException('The error handler must be a coroutine.')
|
||||||
|
|
||||||
|
self._after_invoke = coro
|
||||||
|
return coro
|
||||||
|
|
||||||
# listener registration
|
# listener registration
|
||||||
|
|
||||||
def add_listener(self, func, name=None):
|
def add_listener(self, func, name=None):
|
||||||
|
@ -52,6 +52,21 @@ def wrap_callback(coro):
|
|||||||
return ret
|
return ret
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
def hooked_wrapped_callback(command, ctx, coro):
|
||||||
|
@functools.wraps(coro)
|
||||||
|
@asyncio.coroutine
|
||||||
|
def wrapped(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
ret = yield from coro(*args, **kwargs)
|
||||||
|
except CommandError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise CommandInvokeError(e) from e
|
||||||
|
finally:
|
||||||
|
yield from command.call_after_hooks(ctx)
|
||||||
|
return ret
|
||||||
|
return wrapped
|
||||||
|
|
||||||
def _convert_to_bool(argument):
|
def _convert_to_bool(argument):
|
||||||
lowered = argument.lower()
|
lowered = argument.lower()
|
||||||
if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
|
if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
|
||||||
@ -144,6 +159,8 @@ class Command:
|
|||||||
self.instance = None
|
self.instance = None
|
||||||
self.parent = None
|
self.parent = None
|
||||||
self._buckets = CooldownMapping(kwargs.get('cooldown'))
|
self._buckets = CooldownMapping(kwargs.get('cooldown'))
|
||||||
|
self._before_invoke = None
|
||||||
|
self._after_invoke = None
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def dispatch_error(self, error, ctx):
|
def dispatch_error(self, error, ctx):
|
||||||
@ -335,6 +352,50 @@ class Command:
|
|||||||
if not self.can_run(ctx):
|
if not self.can_run(ctx):
|
||||||
raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self))
|
raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self))
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def call_before_hooks(self, ctx):
|
||||||
|
# now that we're done preparing we can call the pre-command hooks
|
||||||
|
# first, call the command local hook:
|
||||||
|
cog = self.instance
|
||||||
|
if self._before_invoke is not None:
|
||||||
|
if cog is None:
|
||||||
|
yield from self._before_invoke(ctx)
|
||||||
|
else:
|
||||||
|
yield from self._before_invoke(cog, ctx)
|
||||||
|
|
||||||
|
# call the cog local hook if applicable:
|
||||||
|
try:
|
||||||
|
hook = getattr(cog, '_{0.__class__.__name__}__before_invoke'.format(cog))
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
yield from hook(ctx)
|
||||||
|
|
||||||
|
# call the bot global hook if necessary
|
||||||
|
hook = ctx.bot._before_invoke
|
||||||
|
if hook is not None:
|
||||||
|
yield from hook(ctx)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def call_after_hooks(self, ctx):
|
||||||
|
cog = self.instance
|
||||||
|
if self._after_invoke is not None:
|
||||||
|
if cog is None:
|
||||||
|
yield from self._after_invoke(ctx)
|
||||||
|
else:
|
||||||
|
yield from self._after_invoke(cog, ctx)
|
||||||
|
|
||||||
|
try:
|
||||||
|
hook = getattr(cog, '_{0.__class__.__name__}__after_invoke'.format(cog))
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
yield from hook(ctx)
|
||||||
|
|
||||||
|
hook = ctx.bot._after_invoke
|
||||||
|
if hook is not None:
|
||||||
|
yield from hook(ctx)
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def prepare(self, ctx):
|
def prepare(self, ctx):
|
||||||
ctx.command = self
|
ctx.command = self
|
||||||
@ -347,6 +408,8 @@ class Command:
|
|||||||
if retry_after:
|
if retry_after:
|
||||||
raise CommandOnCooldown(bucket, retry_after)
|
raise CommandOnCooldown(bucket, retry_after)
|
||||||
|
|
||||||
|
yield from self.call_before_hooks(ctx)
|
||||||
|
|
||||||
def reset_cooldown(self, ctx):
|
def reset_cooldown(self, ctx):
|
||||||
"""Resets the cooldown on this command.
|
"""Resets the cooldown on this command.
|
||||||
|
|
||||||
@ -367,7 +430,7 @@ class Command:
|
|||||||
# since we're in a regular command (and not a group) then
|
# since we're in a regular command (and not a group) then
|
||||||
# the invoked subcommand is None.
|
# the invoked subcommand is None.
|
||||||
ctx.invoked_subcommand = None
|
ctx.invoked_subcommand = None
|
||||||
injected = wrap_callback(self.callback)
|
injected = hooked_wrapped_callback(self, ctx, self.callback)
|
||||||
yield from injected(*ctx.args, **ctx.kwargs)
|
yield from injected(*ctx.args, **ctx.kwargs)
|
||||||
|
|
||||||
def error(self, coro):
|
def error(self, coro):
|
||||||
@ -394,6 +457,60 @@ class Command:
|
|||||||
self.on_error = coro
|
self.on_error = coro
|
||||||
return coro
|
return coro
|
||||||
|
|
||||||
|
def before_invoke(self, coro):
|
||||||
|
"""A decorator that registers a coroutine as a pre-invoke hook.
|
||||||
|
|
||||||
|
A pre-invoke hook is called directly before :meth:`invoke` is
|
||||||
|
called. This makes it a useful function to set up database
|
||||||
|
connections or any type of set up required.
|
||||||
|
|
||||||
|
This pre-invoke hook takes a sole parameter, a :class:`Context`.
|
||||||
|
|
||||||
|
See :meth:`Bot.before_invoke` for more info.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
coro
|
||||||
|
The coroutine to register as the pre-invoke hook.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
-------
|
||||||
|
discord.ClientException
|
||||||
|
The coroutine is not actually a coroutine.
|
||||||
|
"""
|
||||||
|
if not asyncio.iscoroutinefunction(coro):
|
||||||
|
raise discord.ClientException('The error handler must be a coroutine.')
|
||||||
|
|
||||||
|
self._before_invoke = coro
|
||||||
|
return coro
|
||||||
|
|
||||||
|
def after_invoke(self, coro):
|
||||||
|
"""A decorator that registers a coroutine as a post-invoke hook.
|
||||||
|
|
||||||
|
A post-invoke hook is called directly after :meth:`invoke` is
|
||||||
|
called. This makes it a useful function to clean-up database
|
||||||
|
connections or any type of clean up required.
|
||||||
|
|
||||||
|
This post-invoke hook takes a sole parameter, a :class:`Context`.
|
||||||
|
|
||||||
|
See :meth:`Bot.after_invoke` for more info.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
coro
|
||||||
|
The coroutine to register as the post-invoke hook.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
-------
|
||||||
|
discord.ClientException
|
||||||
|
The coroutine is not actually a coroutine.
|
||||||
|
"""
|
||||||
|
if not asyncio.iscoroutinefunction(coro):
|
||||||
|
raise discord.ClientException('The error handler must be a coroutine.')
|
||||||
|
|
||||||
|
self._after_invoke = coro
|
||||||
|
return coro
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cog_name(self):
|
def cog_name(self):
|
||||||
"""The name of the cog this command belongs to. None otherwise."""
|
"""The name of the cog this command belongs to. None otherwise."""
|
||||||
@ -610,7 +727,7 @@ class Group(GroupMixin, Command):
|
|||||||
ctx.invoked_subcommand = self.commands.get(trigger, None)
|
ctx.invoked_subcommand = self.commands.get(trigger, None)
|
||||||
|
|
||||||
if early_invoke:
|
if early_invoke:
|
||||||
injected = wrap_callback(self.callback)
|
injected = hooked_wrapped_callback(self, ctx, self.callback)
|
||||||
yield from injected(*ctx.args, **ctx.kwargs)
|
yield from injected(*ctx.args, **ctx.kwargs)
|
||||||
|
|
||||||
if trigger and ctx.invoked_subcommand:
|
if trigger and ctx.invoked_subcommand:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user