mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-07-21 10:26:47 +00:00
Add CommandTree.error decorator to set on_error dynamically
This commit is contained in:
parent
2bf612cd67
commit
698d1e12a1
@ -31,6 +31,7 @@ from typing import (
|
|||||||
Any,
|
Any,
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Callable,
|
Callable,
|
||||||
|
Coroutine,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
Generic,
|
Generic,
|
||||||
@ -66,6 +67,15 @@ if TYPE_CHECKING:
|
|||||||
from ..abc import Snowflake
|
from ..abc import Snowflake
|
||||||
from .commands import ContextMenuCallback, CommandCallback, P, T
|
from .commands import ContextMenuCallback, CommandCallback, P, T
|
||||||
|
|
||||||
|
ErrorFunc = Callable[
|
||||||
|
[
|
||||||
|
Interaction,
|
||||||
|
Optional[Union[ContextMenu, Command[Any, ..., Any]]],
|
||||||
|
AppCommandError,
|
||||||
|
],
|
||||||
|
Coroutine[Any, Any, Any],
|
||||||
|
]
|
||||||
|
|
||||||
__all__ = ('CommandTree',)
|
__all__ = ('CommandTree',)
|
||||||
|
|
||||||
ClientT = TypeVar('ClientT', bound='Client')
|
ClientT = TypeVar('ClientT', bound='Client')
|
||||||
@ -681,6 +691,36 @@ class CommandTree(Generic[ClientT]):
|
|||||||
|
|
||||||
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
|
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
|
||||||
|
|
||||||
|
def error(self, coro: ErrorFunc) -> ErrorFunc:
|
||||||
|
"""A decorator that registers a coroutine as a local error handler.
|
||||||
|
|
||||||
|
This must match the signature of the :meth:`on_error` callback.
|
||||||
|
|
||||||
|
The error passed will be derived from :exc:`AppCommandError`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
coro: :ref:`coroutine <coroutine>`
|
||||||
|
The coroutine to register as the local error handler.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
-------
|
||||||
|
TypeError
|
||||||
|
The coroutine passed is not actually a coroutine or does
|
||||||
|
not match the signature.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not inspect.iscoroutinefunction(coro):
|
||||||
|
raise TypeError('The error handler must be a coroutine.')
|
||||||
|
|
||||||
|
params = inspect.signature(coro).parameters
|
||||||
|
if len(params) != 3:
|
||||||
|
raise TypeError('error handler must have 3 parameters')
|
||||||
|
|
||||||
|
# Type checker doesn't like overriding methods like this
|
||||||
|
self.on_error = coro # type: ignore
|
||||||
|
return coro
|
||||||
|
|
||||||
def command(
|
def command(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user