mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-08-30 15:01:42 +00:00
[commands] Fix typing of check/check_any
This changes the type information of check decorators to return a protocol representing that the decorator leaves the underlying object unchanged while having a .predicate attribute. resolves #7949
This commit is contained in:
parent
4a73de946a
commit
d0667d08e3
@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple, Optional
|
from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, Protocol, TypeVar, Union, Tuple, Optional
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
@ -49,13 +49,23 @@ MaybeCoro = Union[T, Coro[T]]
|
|||||||
MaybeAwaitable = Union[T, Awaitable[T]]
|
MaybeAwaitable = Union[T, Awaitable[T]]
|
||||||
|
|
||||||
CogT = TypeVar('CogT', bound='Optional[Cog]')
|
CogT = TypeVar('CogT', bound='Optional[Cog]')
|
||||||
Check = Callable[["ContextT"], MaybeCoro[bool]]
|
UserCheck = Callable[["ContextT"], MaybeCoro[bool]]
|
||||||
Hook = Union[Callable[["CogT", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]]
|
Hook = Union[Callable[["CogT", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]]
|
||||||
Error = Union[Callable[["CogT", "ContextT", "CommandError"], Coro[Any]], Callable[["ContextT", "CommandError"], Coro[Any]]]
|
Error = Union[Callable[["CogT", "ContextT", "CommandError"], Coro[Any]], Callable[["ContextT", "CommandError"], Coro[Any]]]
|
||||||
|
|
||||||
ContextT = TypeVar('ContextT', bound='Context[Any]')
|
ContextT = TypeVar('ContextT', bound='Context[Any]')
|
||||||
BotT = TypeVar('BotT', bound=_Bot, covariant=True)
|
BotT = TypeVar('BotT', bound=_Bot, covariant=True)
|
||||||
|
|
||||||
|
ContextT_co = TypeVar('ContextT_co', bound='Context[Any]', covariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Check(Protocol[ContextT_co]):
|
||||||
|
|
||||||
|
predicate: Callable[[ContextT_co], Coroutine[Any, Any, bool]]
|
||||||
|
|
||||||
|
def __call__(self, coro_or_commands: T) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# This is merely a tag type to avoid circular import issues.
|
# This is merely a tag type to avoid circular import issues.
|
||||||
# Yes, this is a terrible solution but ultimately it is the only solution.
|
# Yes, this is a terrible solution but ultimately it is the only solution.
|
||||||
|
@ -73,7 +73,7 @@ if TYPE_CHECKING:
|
|||||||
from ._types import (
|
from ._types import (
|
||||||
_Bot,
|
_Bot,
|
||||||
BotT,
|
BotT,
|
||||||
Check,
|
UserCheck,
|
||||||
CoroFunc,
|
CoroFunc,
|
||||||
ContextT,
|
ContextT,
|
||||||
MaybeAwaitableFunc,
|
MaybeAwaitableFunc,
|
||||||
@ -173,8 +173,8 @@ class BotBase(GroupMixin[None]):
|
|||||||
self.__tree: app_commands.CommandTree[Self] = tree_cls(self) # type: ignore
|
self.__tree: app_commands.CommandTree[Self] = tree_cls(self) # type: ignore
|
||||||
self.__cogs: Dict[str, Cog] = {}
|
self.__cogs: Dict[str, Cog] = {}
|
||||||
self.__extensions: Dict[str, types.ModuleType] = {}
|
self.__extensions: Dict[str, types.ModuleType] = {}
|
||||||
self._checks: List[Check] = []
|
self._checks: List[UserCheck] = []
|
||||||
self._check_once: List[Check] = []
|
self._check_once: List[UserCheck] = []
|
||||||
self._before_invoke: Optional[CoroFunc] = None
|
self._before_invoke: Optional[CoroFunc] = None
|
||||||
self._after_invoke: Optional[CoroFunc] = None
|
self._after_invoke: Optional[CoroFunc] = None
|
||||||
self._help_command: Optional[HelpCommand] = None
|
self._help_command: Optional[HelpCommand] = None
|
||||||
@ -359,7 +359,7 @@ class BotBase(GroupMixin[None]):
|
|||||||
self.add_check(func) # type: ignore
|
self.add_check(func) # type: ignore
|
||||||
return func
|
return func
|
||||||
|
|
||||||
def add_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
|
def add_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None:
|
||||||
"""Adds a global check to the bot.
|
"""Adds a global check to the bot.
|
||||||
|
|
||||||
This is the non-decorator interface to :meth:`.check`
|
This is the non-decorator interface to :meth:`.check`
|
||||||
@ -383,7 +383,7 @@ class BotBase(GroupMixin[None]):
|
|||||||
else:
|
else:
|
||||||
self._checks.append(func)
|
self._checks.append(func)
|
||||||
|
|
||||||
def remove_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
|
def remove_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None:
|
||||||
"""Removes a global check from the bot.
|
"""Removes a global check from the bot.
|
||||||
|
|
||||||
This function is idempotent and will not raise an exception
|
This function is idempotent and will not raise an exception
|
||||||
|
@ -60,7 +60,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
from discord.message import Message
|
from discord.message import Message
|
||||||
|
|
||||||
from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook
|
from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook, UserCheck
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
@ -378,7 +378,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
checks = kwargs.get('checks', [])
|
checks = kwargs.get('checks', [])
|
||||||
|
|
||||||
self.checks: List[Check[ContextT]] = checks
|
self.checks: List[UserCheck[ContextT]] = checks
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cooldown = func.__commands_cooldown__
|
cooldown = func.__commands_cooldown__
|
||||||
@ -458,7 +458,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
|||||||
|
|
||||||
self.params: Dict[str, Parameter] = get_signature_parameters(function, globalns)
|
self.params: Dict[str, Parameter] = get_signature_parameters(function, globalns)
|
||||||
|
|
||||||
def add_check(self, func: Check[ContextT], /) -> None:
|
def add_check(self, func: UserCheck[ContextT], /) -> None:
|
||||||
"""Adds a check to the command.
|
"""Adds a check to the command.
|
||||||
|
|
||||||
This is the non-decorator interface to :func:`.check`.
|
This is the non-decorator interface to :func:`.check`.
|
||||||
@ -477,7 +477,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
|||||||
|
|
||||||
self.checks.append(func)
|
self.checks.append(func)
|
||||||
|
|
||||||
def remove_check(self, func: Check[ContextT], /) -> None:
|
def remove_check(self, func: UserCheck[ContextT], /) -> None:
|
||||||
"""Removes a check from the command.
|
"""Removes a check from the command.
|
||||||
|
|
||||||
This function is idempotent and will not raise an exception
|
This function is idempotent and will not raise an exception
|
||||||
@ -1745,7 +1745,7 @@ def group(
|
|||||||
return command(name=name, cls=cls, **attrs)
|
return command(name=name, cls=cls, **attrs)
|
||||||
|
|
||||||
|
|
||||||
def check(predicate: Check[ContextT], /) -> Callable[[T], T]:
|
def check(predicate: UserCheck[ContextT], /) -> Check[ContextT]:
|
||||||
r"""A decorator that adds a check to the :class:`.Command` or its
|
r"""A decorator that adds a check to the :class:`.Command` or its
|
||||||
subclasses. These checks could be accessed via :attr:`.Command.checks`.
|
subclasses. These checks could be accessed via :attr:`.Command.checks`.
|
||||||
|
|
||||||
@ -1844,7 +1844,7 @@ def check(predicate: Check[ContextT], /) -> Callable[[T], T]:
|
|||||||
return decorator # type: ignore
|
return decorator # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def check_any(*checks: Check[ContextT]) -> Callable[[T], T]:
|
def check_any(*checks: Check[ContextT]) -> Check[ContextT]:
|
||||||
r"""A :func:`check` that is added that checks if any of the checks passed
|
r"""A :func:`check` that is added that checks if any of the checks passed
|
||||||
will pass, i.e. using logical OR.
|
will pass, i.e. using logical OR.
|
||||||
|
|
||||||
@ -1910,10 +1910,10 @@ def check_any(*checks: Check[ContextT]) -> Callable[[T], T]:
|
|||||||
# if we're here, all checks failed
|
# if we're here, all checks failed
|
||||||
raise CheckAnyFailure(unwrapped, errors)
|
raise CheckAnyFailure(unwrapped, errors)
|
||||||
|
|
||||||
return check(predicate)
|
return check(predicate) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def has_role(item: Union[int, str], /) -> Callable[[T], T]:
|
def has_role(item: Union[int, str], /) -> Check[Any]:
|
||||||
"""A :func:`.check` that is added that checks if the member invoking the
|
"""A :func:`.check` that is added that checks if the member invoking the
|
||||||
command has the role specified via the name or ID specified.
|
command has the role specified via the name or ID specified.
|
||||||
|
|
||||||
@ -2066,7 +2066,7 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
|
|||||||
return check(predicate)
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
def has_permissions(**perms: bool) -> Callable[[T], T]:
|
def has_permissions(**perms: bool) -> Check[Any]:
|
||||||
"""A :func:`.check` that is added that checks if the member has all of
|
"""A :func:`.check` that is added that checks if the member has all of
|
||||||
the permissions necessary.
|
the permissions necessary.
|
||||||
|
|
||||||
@ -2114,7 +2114,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
|
|||||||
return check(predicate)
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
|
def bot_has_permissions(**perms: bool) -> Check[Any]:
|
||||||
"""Similar to :func:`.has_permissions` except checks if the bot itself has
|
"""Similar to :func:`.has_permissions` except checks if the bot itself has
|
||||||
the permissions listed.
|
the permissions listed.
|
||||||
|
|
||||||
@ -2141,7 +2141,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
|
|||||||
return check(predicate)
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
def has_guild_permissions(**perms: bool) -> Check[Any]:
|
||||||
"""Similar to :func:`.has_permissions`, but operates on guild wide
|
"""Similar to :func:`.has_permissions`, but operates on guild wide
|
||||||
permissions instead of the current channel permissions.
|
permissions instead of the current channel permissions.
|
||||||
|
|
||||||
@ -2170,7 +2170,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
|||||||
return check(predicate)
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
def bot_has_guild_permissions(**perms: bool) -> Check[Any]:
|
||||||
"""Similar to :func:`.has_guild_permissions`, but checks the bot
|
"""Similar to :func:`.has_guild_permissions`, but checks the bot
|
||||||
members guild permissions.
|
members guild permissions.
|
||||||
|
|
||||||
@ -2196,7 +2196,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
|||||||
return check(predicate)
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
def dm_only() -> Callable[[T], T]:
|
def dm_only() -> Check[Any]:
|
||||||
"""A :func:`.check` that indicates this command must only be used in a
|
"""A :func:`.check` that indicates this command must only be used in a
|
||||||
DM context. Only private messages are allowed when
|
DM context. Only private messages are allowed when
|
||||||
using the command.
|
using the command.
|
||||||
@ -2215,7 +2215,7 @@ def dm_only() -> Callable[[T], T]:
|
|||||||
return check(predicate)
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
def guild_only() -> Callable[[T], T]:
|
def guild_only() -> Check[Any]:
|
||||||
"""A :func:`.check` that indicates this command must only be used in a
|
"""A :func:`.check` that indicates this command must only be used in a
|
||||||
guild context only. Basically, no private messages are allowed when
|
guild context only. Basically, no private messages are allowed when
|
||||||
using the command.
|
using the command.
|
||||||
@ -2232,7 +2232,7 @@ def guild_only() -> Callable[[T], T]:
|
|||||||
return check(predicate)
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
def is_owner() -> Callable[[T], T]:
|
def is_owner() -> Check[Any]:
|
||||||
"""A :func:`.check` that checks if the person invoking this command is the
|
"""A :func:`.check` that checks if the person invoking this command is the
|
||||||
owner of the bot.
|
owner of the bot.
|
||||||
|
|
||||||
@ -2250,7 +2250,7 @@ def is_owner() -> Callable[[T], T]:
|
|||||||
return check(predicate)
|
return check(predicate)
|
||||||
|
|
||||||
|
|
||||||
def is_nsfw() -> Callable[[T], T]:
|
def is_nsfw() -> Check[Any]:
|
||||||
"""A :func:`.check` that checks if the channel is a NSFW channel.
|
"""A :func:`.check` that checks if the channel is a NSFW channel.
|
||||||
|
|
||||||
This check raises a special exception, :exc:`.NSFWChannelRequired`
|
This check raises a special exception, :exc:`.NSFWChannelRequired`
|
||||||
|
@ -60,7 +60,7 @@ if TYPE_CHECKING:
|
|||||||
from .parameters import Parameter
|
from .parameters import Parameter
|
||||||
|
|
||||||
from ._types import (
|
from ._types import (
|
||||||
Check,
|
UserCheck,
|
||||||
ContextT,
|
ContextT,
|
||||||
BotT,
|
BotT,
|
||||||
_Bot,
|
_Bot,
|
||||||
@ -378,7 +378,7 @@ class HelpCommand:
|
|||||||
bot.remove_command(self._command_impl.name)
|
bot.remove_command(self._command_impl.name)
|
||||||
self._command_impl._eject_cog()
|
self._command_impl._eject_cog()
|
||||||
|
|
||||||
def add_check(self, func: Check[ContextT], /) -> None:
|
def add_check(self, func: UserCheck[ContextT], /) -> None:
|
||||||
"""
|
"""
|
||||||
Adds a check to the help command.
|
Adds a check to the help command.
|
||||||
|
|
||||||
@ -396,7 +396,7 @@ class HelpCommand:
|
|||||||
|
|
||||||
self._command_impl.add_check(func)
|
self._command_impl.add_check(func)
|
||||||
|
|
||||||
def remove_check(self, func: Check[ContextT], /) -> None:
|
def remove_check(self, func: UserCheck[ContextT], /) -> None:
|
||||||
"""
|
"""
|
||||||
Removes a check from the help command.
|
Removes a check from the help command.
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user