[commands] Change cooldowns to take context instead of message

This commit is contained in:
Mikey
2022-07-23 04:08:44 -07:00
committed by GitHub
parent 406495b465
commit 311891912e
2 changed files with 33 additions and 31 deletions

View File

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, Callable, Deque, Dict, Optional, TYPE_CHECKING
from typing import Any, Callable, Deque, Dict, Optional, Union, Generic, TypeVar, TYPE_CHECKING
from discord.enums import Enum
import time
import asyncio
@ -33,12 +33,14 @@ from collections import deque
from ...abc import PrivateChannel
from .errors import MaxConcurrencyReached
from .context import Context
from discord.app_commands import Cooldown as Cooldown
if TYPE_CHECKING:
from typing_extensions import Self
from ...message import Message
from ._types import BotT
__all__ = (
'BucketType',
@ -48,6 +50,8 @@ __all__ = (
'MaxConcurrency',
)
T = TypeVar('T')
class BucketType(Enum):
default = 0
@ -58,7 +62,7 @@ class BucketType(Enum):
category = 5
role = 6
def get_key(self, msg: Message) -> Any:
def get_key(self, msg: Union[Message, Context[BotT]]) -> Any:
if self is BucketType.user:
return msg.author.id
elif self is BucketType.guild:
@ -76,22 +80,22 @@ class BucketType(Enum):
# receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore
def __call__(self, msg: Message) -> Any:
def __call__(self, msg: Union[Message, Context[BotT]]) -> Any:
return self.get_key(msg)
class CooldownMapping:
class CooldownMapping(Generic[T]):
def __init__(
self,
original: Optional[Cooldown],
type: Callable[[Message], Any],
type: Callable[[T], Any],
) -> None:
if not callable(type):
raise TypeError('Cooldown type must be a BucketType or callable')
self._cache: Dict[Any, Cooldown] = {}
self._cooldown: Optional[Cooldown] = original
self._type: Callable[[Message], Any] = type
self._type: Callable[[T], Any] = type
def copy(self) -> CooldownMapping:
ret = CooldownMapping(self._cooldown, self._type)
@ -103,14 +107,14 @@ class CooldownMapping:
return self._cooldown is not None
@property
def type(self) -> Callable[[Message], Any]:
def type(self) -> Callable[[T], Any]:
return self._type
@classmethod
def from_cooldown(cls, rate: float, per: float, type: Callable[[Message], Any]) -> Self:
def from_cooldown(cls, rate: float, per: float, type: Callable[[T], Any]) -> Self:
return cls(Cooldown(rate, per), type)
def _bucket_key(self, msg: Message) -> Any:
def _bucket_key(self, msg: T) -> Any:
return self._type(msg)
def _verify_cache_integrity(self, current: Optional[float] = None) -> None:
@ -122,10 +126,10 @@ class CooldownMapping:
for k in dead_keys:
del self._cache[k]
def create_bucket(self, message: Message) -> Cooldown:
def create_bucket(self, message: T) -> Cooldown:
return self._cooldown.copy() # type: ignore
def get_bucket(self, message: Message, current: Optional[float] = None) -> Optional[Cooldown]:
def get_bucket(self, message: T, current: Optional[float] = None) -> Optional[Cooldown]:
if self._type is BucketType.default:
return self._cooldown
@ -140,21 +144,21 @@ class CooldownMapping:
return bucket
def update_rate_limit(self, message: Message, current: Optional[float] = None, tokens: int = 1) -> Optional[float]:
def update_rate_limit(self, message: T, current: Optional[float] = None, tokens: int = 1) -> Optional[float]:
bucket = self.get_bucket(message, current)
if bucket is None:
return None
return bucket.update_rate_limit(current, tokens=tokens)
class DynamicCooldownMapping(CooldownMapping):
class DynamicCooldownMapping(CooldownMapping[T]):
def __init__(
self,
factory: Callable[[Message], Optional[Cooldown]],
type: Callable[[Message], Any],
factory: Callable[[T], Optional[Cooldown]],
type: Callable[[T], Any],
) -> None:
super().__init__(None, type)
self._factory: Callable[[Message], Optional[Cooldown]] = factory
self._factory: Callable[[T], Optional[Cooldown]] = factory
def copy(self) -> DynamicCooldownMapping:
ret = DynamicCooldownMapping(self._factory, self._type)
@ -165,7 +169,7 @@ class DynamicCooldownMapping(CooldownMapping):
def valid(self) -> bool:
return True
def create_bucket(self, message: Message) -> Optional[Cooldown]:
def create_bucket(self, message: T) -> Optional[Cooldown]:
return self._factory(message)