From 07ad6951fbb2151eb6b166484cf5bed5ee00f5b7 Mon Sep 17 00:00:00 2001 From: Bryan Forbes Date: Mon, 1 Aug 2022 05:24:55 -0500 Subject: [PATCH] Fix various generics throughout the public interface Fix CooldownMapping generic typing and ensure other public methods have proper generics --- discord/app_commands/commands.py | 2 +- discord/ext/commands/cooldowns.py | 45 +++++++++++++++---------------- discord/ext/commands/core.py | 20 +++++++------- 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/discord/app_commands/commands.py b/discord/app_commands/commands.py index 396afb869..054969964 100644 --- a/discord/app_commands/commands.py +++ b/discord/app_commands/commands.py @@ -134,7 +134,7 @@ else: AutocompleteCallback = Callable[..., Coro[T]] -CheckInputParameter = Union['Command[Any, ..., Any]', 'ContextMenu', CommandCallback, ContextMenuCallback] +CheckInputParameter = Union['Command[Any, ..., Any]', 'ContextMenu', 'CommandCallback[Any, ..., Any]', ContextMenuCallback] # The re module doesn't support \p{} so we have to list characters from Thai and Devanagari manually. THAI_COMBINING = r'\u0e31-\u0e3a\u0e47-\u0e4e' diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 1a332370c..2af7cb017 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -40,7 +40,6 @@ if TYPE_CHECKING: from typing_extensions import Self from ...message import Message - from ._types import BotT __all__ = ( 'BucketType', @@ -50,7 +49,7 @@ __all__ = ( 'MaxConcurrency', ) -T = TypeVar('T') +T_contra = TypeVar('T_contra', contravariant=True) class BucketType(Enum): @@ -62,7 +61,7 @@ class BucketType(Enum): category = 5 role = 6 - def get_key(self, msg: Union[Message, Context[BotT]]) -> Any: + def get_key(self, msg: Union[Message, Context[Any]]) -> Any: if self is BucketType.user: return msg.author.id elif self is BucketType.guild: @@ -80,24 +79,24 @@ class BucketType(Enum): # receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore - def __call__(self, msg: Union[Message, Context[BotT]]) -> Any: + def __call__(self, msg: Union[Message, Context[Any]]) -> Any: return self.get_key(msg) -class CooldownMapping(Generic[T]): +class CooldownMapping(Generic[T_contra]): def __init__( self, original: Optional[Cooldown], - type: Callable[[T], Any], + type: Callable[[T_contra], Any], ) -> None: if not callable(type): raise TypeError('Cooldown type must be a BucketType or callable') self._cache: Dict[Any, Cooldown] = {} self._cooldown: Optional[Cooldown] = original - self._type: Callable[[T], Any] = type + self._type: Callable[[T_contra], Any] = type - def copy(self) -> CooldownMapping: + def copy(self) -> CooldownMapping[T_contra]: ret = CooldownMapping(self._cooldown, self._type) ret._cache = self._cache.copy() return ret @@ -107,14 +106,14 @@ class CooldownMapping(Generic[T]): return self._cooldown is not None @property - def type(self) -> Callable[[T], Any]: + def type(self) -> Callable[[T_contra], Any]: return self._type @classmethod - def from_cooldown(cls, rate: float, per: float, type: Callable[[T], Any]) -> Self: + def from_cooldown(cls, rate: float, per: float, type: Callable[[T_contra], Any]) -> Self: return cls(Cooldown(rate, per), type) - def _bucket_key(self, msg: T) -> Any: + def _bucket_key(self, msg: T_contra) -> Any: return self._type(msg) def _verify_cache_integrity(self, current: Optional[float] = None) -> None: @@ -126,10 +125,10 @@ class CooldownMapping(Generic[T]): for k in dead_keys: del self._cache[k] - def create_bucket(self, message: T) -> Cooldown: + def create_bucket(self, message: T_contra) -> Cooldown: return self._cooldown.copy() # type: ignore - def get_bucket(self, message: T, current: Optional[float] = None) -> Optional[Cooldown]: + def get_bucket(self, message: T_contra, current: Optional[float] = None) -> Optional[Cooldown]: if self._type is BucketType.default: return self._cooldown @@ -144,23 +143,23 @@ class CooldownMapping(Generic[T]): return bucket - def update_rate_limit(self, message: T, current: Optional[float] = None, tokens: int = 1) -> Optional[float]: + def update_rate_limit(self, message: T_contra, current: Optional[float] = None, tokens: int = 1) -> Optional[float]: bucket = self.get_bucket(message, current) if bucket is None: return None return bucket.update_rate_limit(current, tokens=tokens) -class DynamicCooldownMapping(CooldownMapping[T]): +class DynamicCooldownMapping(CooldownMapping[T_contra]): def __init__( self, - factory: Callable[[T], Optional[Cooldown]], - type: Callable[[T], Any], + factory: Callable[[T_contra], Optional[Cooldown]], + type: Callable[[T_contra], Any], ) -> None: super().__init__(None, type) - self._factory: Callable[[T], Optional[Cooldown]] = factory + self._factory: Callable[[T_contra], Optional[Cooldown]] = factory - def copy(self) -> DynamicCooldownMapping: + def copy(self) -> DynamicCooldownMapping[T_contra]: ret = DynamicCooldownMapping(self._factory, self._type) ret._cache = self._cache.copy() return ret @@ -169,7 +168,7 @@ class DynamicCooldownMapping(CooldownMapping[T]): def valid(self) -> bool: return True - def create_bucket(self, message: T) -> Optional[Cooldown]: + def create_bucket(self, message: T_contra) -> Optional[Cooldown]: return self._factory(message) @@ -254,10 +253,10 @@ class MaxConcurrency: def __repr__(self) -> str: return f'' - def get_key(self, message: Message) -> Any: + def get_key(self, message: Union[Message, Context[Any]]) -> Any: return self.per.get_key(message) - async def acquire(self, message: Message) -> None: + async def acquire(self, message: Union[Message, Context[Any]]) -> None: key = self.get_key(message) try: @@ -269,7 +268,7 @@ class MaxConcurrency: if not acquired: raise MaxConcurrencyReached(self.number, self.per) - async def release(self, message: Message) -> None: + async def release(self, message: Union[Message, Context[Any]]) -> None: # Technically there's no reason for this function to be async # But it might be more useful in the future key = self.get_key(message) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 14f177fba..68b0f0950 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -91,9 +91,9 @@ __all__ = ( MISSING: Any = discord.utils.MISSING T = TypeVar('T') -CommandT = TypeVar('CommandT', bound='Command') +CommandT = TypeVar('CommandT', bound='Command[Any, ..., Any]') # CHT = TypeVar('CHT', bound='Check') -GroupT = TypeVar('GroupT', bound='Group') +GroupT = TypeVar('GroupT', bound='Group[Any, ..., Any]') if TYPE_CHECKING: P = ParamSpec('P') @@ -404,10 +404,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if cooldown is None: buckets = CooldownMapping(cooldown, BucketType.default) elif isinstance(cooldown, CooldownMapping): - buckets: CooldownMapping[Context] = cooldown + buckets: CooldownMapping[Context[Any]] = cooldown else: raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") - self._buckets: CooldownMapping[Context] = buckets + self._buckets: CooldownMapping[Context[Any]] = buckets try: max_concurrency = func.__commands_max_concurrency__ @@ -452,15 +452,15 @@ class Command(_BaseCommand, Generic[CogT, P, T]): @property def callback( self, - ) -> Union[Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]],]: + ) -> Union[Callable[Concatenate[CogT, Context[Any], P], Coro[T]], Callable[Concatenate[Context[Any], P], Coro[T]],]: return self._callback @callback.setter def callback( self, function: Union[ - Callable[Concatenate[CogT, Context, P], Coro[T]], - Callable[Concatenate[Context, P], Coro[T]], + Callable[Concatenate[CogT, Context[Any], P], Coro[T]], + Callable[Concatenate[Context[Any], P], Coro[T]], ], ) -> None: self._callback = function @@ -2394,7 +2394,7 @@ def is_nsfw() -> Check[Any]: def cooldown( rate: int, per: float, - type: Union[BucketType, Callable[[Context], Any]] = BucketType.default, + type: Union[BucketType, Callable[[Context[Any]], Any]] = BucketType.default, ) -> Callable[[T], T]: """A decorator that adds a cooldown to a :class:`.Command` @@ -2433,8 +2433,8 @@ def cooldown( def dynamic_cooldown( - cooldown: Callable[[Context], Optional[Cooldown]], - type: Union[BucketType, Callable[[Context], Any]], + cooldown: Callable[[Context[Any]], Optional[Cooldown]], + type: Union[BucketType, Callable[[Context[Any]], Any]], ) -> Callable[[T], T]: """A decorator that adds a dynamic cooldown to a :class:`.Command`