mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-09-04 17:06:21 +00:00
Fix various generics throughout the public interface
Fix CooldownMapping generic typing and ensure other public methods have proper generics
This commit is contained in:
@ -40,7 +40,6 @@ if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
|
||||
from ...message import Message
|
||||
from ._types import BotT
|
||||
|
||||
__all__ = (
|
||||
'BucketType',
|
||||
@ -50,7 +49,7 @@ __all__ = (
|
||||
'MaxConcurrency',
|
||||
)
|
||||
|
||||
T = TypeVar('T')
|
||||
T_contra = TypeVar('T_contra', contravariant=True)
|
||||
|
||||
|
||||
class BucketType(Enum):
|
||||
@ -62,7 +61,7 @@ class BucketType(Enum):
|
||||
category = 5
|
||||
role = 6
|
||||
|
||||
def get_key(self, msg: Union[Message, Context[BotT]]) -> Any:
|
||||
def get_key(self, msg: Union[Message, Context[Any]]) -> Any:
|
||||
if self is BucketType.user:
|
||||
return msg.author.id
|
||||
elif self is BucketType.guild:
|
||||
@ -80,24 +79,24 @@ class BucketType(Enum):
|
||||
# receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do
|
||||
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore
|
||||
|
||||
def __call__(self, msg: Union[Message, Context[BotT]]) -> Any:
|
||||
def __call__(self, msg: Union[Message, Context[Any]]) -> Any:
|
||||
return self.get_key(msg)
|
||||
|
||||
|
||||
class CooldownMapping(Generic[T]):
|
||||
class CooldownMapping(Generic[T_contra]):
|
||||
def __init__(
|
||||
self,
|
||||
original: Optional[Cooldown],
|
||||
type: Callable[[T], Any],
|
||||
type: Callable[[T_contra], Any],
|
||||
) -> None:
|
||||
if not callable(type):
|
||||
raise TypeError('Cooldown type must be a BucketType or callable')
|
||||
|
||||
self._cache: Dict[Any, Cooldown] = {}
|
||||
self._cooldown: Optional[Cooldown] = original
|
||||
self._type: Callable[[T], Any] = type
|
||||
self._type: Callable[[T_contra], Any] = type
|
||||
|
||||
def copy(self) -> CooldownMapping:
|
||||
def copy(self) -> CooldownMapping[T_contra]:
|
||||
ret = CooldownMapping(self._cooldown, self._type)
|
||||
ret._cache = self._cache.copy()
|
||||
return ret
|
||||
@ -107,14 +106,14 @@ class CooldownMapping(Generic[T]):
|
||||
return self._cooldown is not None
|
||||
|
||||
@property
|
||||
def type(self) -> Callable[[T], Any]:
|
||||
def type(self) -> Callable[[T_contra], Any]:
|
||||
return self._type
|
||||
|
||||
@classmethod
|
||||
def from_cooldown(cls, rate: float, per: float, type: Callable[[T], Any]) -> Self:
|
||||
def from_cooldown(cls, rate: float, per: float, type: Callable[[T_contra], Any]) -> Self:
|
||||
return cls(Cooldown(rate, per), type)
|
||||
|
||||
def _bucket_key(self, msg: T) -> Any:
|
||||
def _bucket_key(self, msg: T_contra) -> Any:
|
||||
return self._type(msg)
|
||||
|
||||
def _verify_cache_integrity(self, current: Optional[float] = None) -> None:
|
||||
@ -126,10 +125,10 @@ class CooldownMapping(Generic[T]):
|
||||
for k in dead_keys:
|
||||
del self._cache[k]
|
||||
|
||||
def create_bucket(self, message: T) -> Cooldown:
|
||||
def create_bucket(self, message: T_contra) -> Cooldown:
|
||||
return self._cooldown.copy() # type: ignore
|
||||
|
||||
def get_bucket(self, message: T, current: Optional[float] = None) -> Optional[Cooldown]:
|
||||
def get_bucket(self, message: T_contra, current: Optional[float] = None) -> Optional[Cooldown]:
|
||||
if self._type is BucketType.default:
|
||||
return self._cooldown
|
||||
|
||||
@ -144,23 +143,23 @@ class CooldownMapping(Generic[T]):
|
||||
|
||||
return bucket
|
||||
|
||||
def update_rate_limit(self, message: T, current: Optional[float] = None, tokens: int = 1) -> Optional[float]:
|
||||
def update_rate_limit(self, message: T_contra, current: Optional[float] = None, tokens: int = 1) -> Optional[float]:
|
||||
bucket = self.get_bucket(message, current)
|
||||
if bucket is None:
|
||||
return None
|
||||
return bucket.update_rate_limit(current, tokens=tokens)
|
||||
|
||||
|
||||
class DynamicCooldownMapping(CooldownMapping[T]):
|
||||
class DynamicCooldownMapping(CooldownMapping[T_contra]):
|
||||
def __init__(
|
||||
self,
|
||||
factory: Callable[[T], Optional[Cooldown]],
|
||||
type: Callable[[T], Any],
|
||||
factory: Callable[[T_contra], Optional[Cooldown]],
|
||||
type: Callable[[T_contra], Any],
|
||||
) -> None:
|
||||
super().__init__(None, type)
|
||||
self._factory: Callable[[T], Optional[Cooldown]] = factory
|
||||
self._factory: Callable[[T_contra], Optional[Cooldown]] = factory
|
||||
|
||||
def copy(self) -> DynamicCooldownMapping:
|
||||
def copy(self) -> DynamicCooldownMapping[T_contra]:
|
||||
ret = DynamicCooldownMapping(self._factory, self._type)
|
||||
ret._cache = self._cache.copy()
|
||||
return ret
|
||||
@ -169,7 +168,7 @@ class DynamicCooldownMapping(CooldownMapping[T]):
|
||||
def valid(self) -> bool:
|
||||
return True
|
||||
|
||||
def create_bucket(self, message: T) -> Optional[Cooldown]:
|
||||
def create_bucket(self, message: T_contra) -> Optional[Cooldown]:
|
||||
return self._factory(message)
|
||||
|
||||
|
||||
@ -254,10 +253,10 @@ class MaxConcurrency:
|
||||
def __repr__(self) -> str:
|
||||
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>'
|
||||
|
||||
def get_key(self, message: Message) -> Any:
|
||||
def get_key(self, message: Union[Message, Context[Any]]) -> Any:
|
||||
return self.per.get_key(message)
|
||||
|
||||
async def acquire(self, message: Message) -> None:
|
||||
async def acquire(self, message: Union[Message, Context[Any]]) -> None:
|
||||
key = self.get_key(message)
|
||||
|
||||
try:
|
||||
@ -269,7 +268,7 @@ class MaxConcurrency:
|
||||
if not acquired:
|
||||
raise MaxConcurrencyReached(self.number, self.per)
|
||||
|
||||
async def release(self, message: Message) -> None:
|
||||
async def release(self, message: Union[Message, Context[Any]]) -> None:
|
||||
# Technically there's no reason for this function to be async
|
||||
# But it might be more useful in the future
|
||||
key = self.get_key(message)
|
||||
|
Reference in New Issue
Block a user