mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-21 16:24:14 +00:00
[commands] Fix typing of check/check_any
This changes the type information of check decorators to return a protocol representing that the decorator leaves the underlying object unchanged while having a .predicate attribute. resolves #7949
This commit is contained in:
parent
4a73de946a
commit
d0667d08e3
@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
|
||||
from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple, Optional
|
||||
from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, Protocol, TypeVar, Union, Tuple, Optional
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
@ -49,13 +49,23 @@ MaybeCoro = Union[T, Coro[T]]
|
||||
MaybeAwaitable = Union[T, Awaitable[T]]
|
||||
|
||||
CogT = TypeVar('CogT', bound='Optional[Cog]')
|
||||
Check = Callable[["ContextT"], MaybeCoro[bool]]
|
||||
UserCheck = Callable[["ContextT"], MaybeCoro[bool]]
|
||||
Hook = Union[Callable[["CogT", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]]
|
||||
Error = Union[Callable[["CogT", "ContextT", "CommandError"], Coro[Any]], Callable[["ContextT", "CommandError"], Coro[Any]]]
|
||||
|
||||
ContextT = TypeVar('ContextT', bound='Context[Any]')
|
||||
BotT = TypeVar('BotT', bound=_Bot, covariant=True)
|
||||
|
||||
ContextT_co = TypeVar('ContextT_co', bound='Context[Any]', covariant=True)
|
||||
|
||||
|
||||
class Check(Protocol[ContextT_co]):
|
||||
|
||||
predicate: Callable[[ContextT_co], Coroutine[Any, Any, bool]]
|
||||
|
||||
def __call__(self, coro_or_commands: T) -> T:
|
||||
...
|
||||
|
||||
|
||||
# This is merely a tag type to avoid circular import issues.
|
||||
# Yes, this is a terrible solution but ultimately it is the only solution.
|
||||
|
@ -73,7 +73,7 @@ if TYPE_CHECKING:
|
||||
from ._types import (
|
||||
_Bot,
|
||||
BotT,
|
||||
Check,
|
||||
UserCheck,
|
||||
CoroFunc,
|
||||
ContextT,
|
||||
MaybeAwaitableFunc,
|
||||
@ -173,8 +173,8 @@ class BotBase(GroupMixin[None]):
|
||||
self.__tree: app_commands.CommandTree[Self] = tree_cls(self) # type: ignore
|
||||
self.__cogs: Dict[str, Cog] = {}
|
||||
self.__extensions: Dict[str, types.ModuleType] = {}
|
||||
self._checks: List[Check] = []
|
||||
self._check_once: List[Check] = []
|
||||
self._checks: List[UserCheck] = []
|
||||
self._check_once: List[UserCheck] = []
|
||||
self._before_invoke: Optional[CoroFunc] = None
|
||||
self._after_invoke: Optional[CoroFunc] = None
|
||||
self._help_command: Optional[HelpCommand] = None
|
||||
@ -359,7 +359,7 @@ class BotBase(GroupMixin[None]):
|
||||
self.add_check(func) # type: ignore
|
||||
return func
|
||||
|
||||
def add_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
|
||||
def add_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None:
|
||||
"""Adds a global check to the bot.
|
||||
|
||||
This is the non-decorator interface to :meth:`.check`
|
||||
@ -383,7 +383,7 @@ class BotBase(GroupMixin[None]):
|
||||
else:
|
||||
self._checks.append(func)
|
||||
|
||||
def remove_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
|
||||
def remove_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None:
|
||||
"""Removes a global check from the bot.
|
||||
|
||||
This function is idempotent and will not raise an exception
|
||||
|
@ -60,7 +60,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from discord.message import Message
|
||||
|
||||
from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook
|
||||
from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook, UserCheck
|
||||
|
||||
|
||||
__all__ = (
|
||||
@ -378,7 +378,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
except AttributeError:
|
||||
checks = kwargs.get('checks', [])
|
||||
|
||||
self.checks: List[Check[ContextT]] = checks
|
||||
self.checks: List[UserCheck[ContextT]] = checks
|
||||
|
||||
try:
|
||||
cooldown = func.__commands_cooldown__
|
||||
@ -458,7 +458,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
|
||||
self.params: Dict[str, Parameter] = get_signature_parameters(function, globalns)
|
||||
|
||||
def add_check(self, func: Check[ContextT], /) -> None:
|
||||
def add_check(self, func: UserCheck[ContextT], /) -> None:
|
||||
"""Adds a check to the command.
|
||||
|
||||
This is the non-decorator interface to :func:`.check`.
|
||||
@ -477,7 +477,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
|
||||
self.checks.append(func)
|
||||
|
||||
def remove_check(self, func: Check[ContextT], /) -> None:
|
||||
def remove_check(self, func: UserCheck[ContextT], /) -> None:
|
||||
"""Removes a check from the command.
|
||||
|
||||
This function is idempotent and will not raise an exception
|
||||
@ -1745,7 +1745,7 @@ def group(
|
||||
return command(name=name, cls=cls, **attrs)
|
||||
|
||||
|
||||
def check(predicate: Check[ContextT], /) -> Callable[[T], T]:
|
||||
def check(predicate: UserCheck[ContextT], /) -> Check[ContextT]:
|
||||
r"""A decorator that adds a check to the :class:`.Command` or its
|
||||
subclasses. These checks could be accessed via :attr:`.Command.checks`.
|
||||
|
||||
@ -1844,7 +1844,7 @@ def check(predicate: Check[ContextT], /) -> Callable[[T], T]:
|
||||
return decorator # type: ignore
|
||||
|
||||
|
||||
def check_any(*checks: Check[ContextT]) -> Callable[[T], T]:
|
||||
def check_any(*checks: Check[ContextT]) -> Check[ContextT]:
|
||||
r"""A :func:`check` that is added that checks if any of the checks passed
|
||||
will pass, i.e. using logical OR.
|
||||
|
||||
@ -1910,10 +1910,10 @@ def check_any(*checks: Check[ContextT]) -> Callable[[T], T]:
|
||||
# if we're here, all checks failed
|
||||
raise CheckAnyFailure(unwrapped, errors)
|
||||
|
||||
return check(predicate)
|
||||
return check(predicate) # type: ignore
|
||||
|
||||
|
||||
def has_role(item: Union[int, str], /) -> Callable[[T], T]:
|
||||
def has_role(item: Union[int, str], /) -> Check[Any]:
|
||||
"""A :func:`.check` that is added that checks if the member invoking the
|
||||
command has the role specified via the name or ID specified.
|
||||
|
||||
@ -2066,7 +2066,7 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def has_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
def has_permissions(**perms: bool) -> Check[Any]:
|
||||
"""A :func:`.check` that is added that checks if the member has all of
|
||||
the permissions necessary.
|
||||
|
||||
@ -2114,7 +2114,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
def bot_has_permissions(**perms: bool) -> Check[Any]:
|
||||
"""Similar to :func:`.has_permissions` except checks if the bot itself has
|
||||
the permissions listed.
|
||||
|
||||
@ -2141,7 +2141,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
def has_guild_permissions(**perms: bool) -> Check[Any]:
|
||||
"""Similar to :func:`.has_permissions`, but operates on guild wide
|
||||
permissions instead of the current channel permissions.
|
||||
|
||||
@ -2170,7 +2170,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
def bot_has_guild_permissions(**perms: bool) -> Check[Any]:
|
||||
"""Similar to :func:`.has_guild_permissions`, but checks the bot
|
||||
members guild permissions.
|
||||
|
||||
@ -2196,7 +2196,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def dm_only() -> Callable[[T], T]:
|
||||
def dm_only() -> Check[Any]:
|
||||
"""A :func:`.check` that indicates this command must only be used in a
|
||||
DM context. Only private messages are allowed when
|
||||
using the command.
|
||||
@ -2215,7 +2215,7 @@ def dm_only() -> Callable[[T], T]:
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def guild_only() -> Callable[[T], T]:
|
||||
def guild_only() -> Check[Any]:
|
||||
"""A :func:`.check` that indicates this command must only be used in a
|
||||
guild context only. Basically, no private messages are allowed when
|
||||
using the command.
|
||||
@ -2232,7 +2232,7 @@ def guild_only() -> Callable[[T], T]:
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def is_owner() -> Callable[[T], T]:
|
||||
def is_owner() -> Check[Any]:
|
||||
"""A :func:`.check` that checks if the person invoking this command is the
|
||||
owner of the bot.
|
||||
|
||||
@ -2250,7 +2250,7 @@ def is_owner() -> Callable[[T], T]:
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def is_nsfw() -> Callable[[T], T]:
|
||||
def is_nsfw() -> Check[Any]:
|
||||
"""A :func:`.check` that checks if the channel is a NSFW channel.
|
||||
|
||||
This check raises a special exception, :exc:`.NSFWChannelRequired`
|
||||
|
@ -60,7 +60,7 @@ if TYPE_CHECKING:
|
||||
from .parameters import Parameter
|
||||
|
||||
from ._types import (
|
||||
Check,
|
||||
UserCheck,
|
||||
ContextT,
|
||||
BotT,
|
||||
_Bot,
|
||||
@ -378,7 +378,7 @@ class HelpCommand:
|
||||
bot.remove_command(self._command_impl.name)
|
||||
self._command_impl._eject_cog()
|
||||
|
||||
def add_check(self, func: Check[ContextT], /) -> None:
|
||||
def add_check(self, func: UserCheck[ContextT], /) -> None:
|
||||
"""
|
||||
Adds a check to the help command.
|
||||
|
||||
@ -396,7 +396,7 @@ class HelpCommand:
|
||||
|
||||
self._command_impl.add_check(func)
|
||||
|
||||
def remove_check(self, func: Check[ContextT], /) -> None:
|
||||
def remove_check(self, func: UserCheck[ContextT], /) -> None:
|
||||
"""
|
||||
Removes a check from the help command.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user