mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-07-02 00:00:02 +00:00
[commands] Change cooldowns to take context instead of message
This commit is contained in:
parent
406495b465
commit
311891912e
@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
from typing import Any, Callable, Deque, Dict, Optional, TYPE_CHECKING
|
from typing import Any, Callable, Deque, Dict, Optional, Union, Generic, TypeVar, TYPE_CHECKING
|
||||||
from discord.enums import Enum
|
from discord.enums import Enum
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -33,12 +33,14 @@ from collections import deque
|
|||||||
|
|
||||||
from ...abc import PrivateChannel
|
from ...abc import PrivateChannel
|
||||||
from .errors import MaxConcurrencyReached
|
from .errors import MaxConcurrencyReached
|
||||||
|
from .context import Context
|
||||||
from discord.app_commands import Cooldown as Cooldown
|
from discord.app_commands import Cooldown as Cooldown
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from ...message import Message
|
from ...message import Message
|
||||||
|
from ._types import BotT
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
'BucketType',
|
'BucketType',
|
||||||
@ -48,6 +50,8 @@ __all__ = (
|
|||||||
'MaxConcurrency',
|
'MaxConcurrency',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
class BucketType(Enum):
|
class BucketType(Enum):
|
||||||
default = 0
|
default = 0
|
||||||
@ -58,7 +62,7 @@ class BucketType(Enum):
|
|||||||
category = 5
|
category = 5
|
||||||
role = 6
|
role = 6
|
||||||
|
|
||||||
def get_key(self, msg: Message) -> Any:
|
def get_key(self, msg: Union[Message, Context[BotT]]) -> Any:
|
||||||
if self is BucketType.user:
|
if self is BucketType.user:
|
||||||
return msg.author.id
|
return msg.author.id
|
||||||
elif self is BucketType.guild:
|
elif self is BucketType.guild:
|
||||||
@ -76,22 +80,22 @@ class BucketType(Enum):
|
|||||||
# receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do
|
# receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do
|
||||||
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore
|
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore
|
||||||
|
|
||||||
def __call__(self, msg: Message) -> Any:
|
def __call__(self, msg: Union[Message, Context[BotT]]) -> Any:
|
||||||
return self.get_key(msg)
|
return self.get_key(msg)
|
||||||
|
|
||||||
|
|
||||||
class CooldownMapping:
|
class CooldownMapping(Generic[T]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
original: Optional[Cooldown],
|
original: Optional[Cooldown],
|
||||||
type: Callable[[Message], Any],
|
type: Callable[[T], Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
if not callable(type):
|
if not callable(type):
|
||||||
raise TypeError('Cooldown type must be a BucketType or callable')
|
raise TypeError('Cooldown type must be a BucketType or callable')
|
||||||
|
|
||||||
self._cache: Dict[Any, Cooldown] = {}
|
self._cache: Dict[Any, Cooldown] = {}
|
||||||
self._cooldown: Optional[Cooldown] = original
|
self._cooldown: Optional[Cooldown] = original
|
||||||
self._type: Callable[[Message], Any] = type
|
self._type: Callable[[T], Any] = type
|
||||||
|
|
||||||
def copy(self) -> CooldownMapping:
|
def copy(self) -> CooldownMapping:
|
||||||
ret = CooldownMapping(self._cooldown, self._type)
|
ret = CooldownMapping(self._cooldown, self._type)
|
||||||
@ -103,14 +107,14 @@ class CooldownMapping:
|
|||||||
return self._cooldown is not None
|
return self._cooldown is not None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self) -> Callable[[Message], Any]:
|
def type(self) -> Callable[[T], Any]:
|
||||||
return self._type
|
return self._type
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cooldown(cls, rate: float, per: float, type: Callable[[Message], Any]) -> Self:
|
def from_cooldown(cls, rate: float, per: float, type: Callable[[T], Any]) -> Self:
|
||||||
return cls(Cooldown(rate, per), type)
|
return cls(Cooldown(rate, per), type)
|
||||||
|
|
||||||
def _bucket_key(self, msg: Message) -> Any:
|
def _bucket_key(self, msg: T) -> Any:
|
||||||
return self._type(msg)
|
return self._type(msg)
|
||||||
|
|
||||||
def _verify_cache_integrity(self, current: Optional[float] = None) -> None:
|
def _verify_cache_integrity(self, current: Optional[float] = None) -> None:
|
||||||
@ -122,10 +126,10 @@ class CooldownMapping:
|
|||||||
for k in dead_keys:
|
for k in dead_keys:
|
||||||
del self._cache[k]
|
del self._cache[k]
|
||||||
|
|
||||||
def create_bucket(self, message: Message) -> Cooldown:
|
def create_bucket(self, message: T) -> Cooldown:
|
||||||
return self._cooldown.copy() # type: ignore
|
return self._cooldown.copy() # type: ignore
|
||||||
|
|
||||||
def get_bucket(self, message: Message, current: Optional[float] = None) -> Optional[Cooldown]:
|
def get_bucket(self, message: T, current: Optional[float] = None) -> Optional[Cooldown]:
|
||||||
if self._type is BucketType.default:
|
if self._type is BucketType.default:
|
||||||
return self._cooldown
|
return self._cooldown
|
||||||
|
|
||||||
@ -140,21 +144,21 @@ class CooldownMapping:
|
|||||||
|
|
||||||
return bucket
|
return bucket
|
||||||
|
|
||||||
def update_rate_limit(self, message: Message, current: Optional[float] = None, tokens: int = 1) -> Optional[float]:
|
def update_rate_limit(self, message: T, current: Optional[float] = None, tokens: int = 1) -> Optional[float]:
|
||||||
bucket = self.get_bucket(message, current)
|
bucket = self.get_bucket(message, current)
|
||||||
if bucket is None:
|
if bucket is None:
|
||||||
return None
|
return None
|
||||||
return bucket.update_rate_limit(current, tokens=tokens)
|
return bucket.update_rate_limit(current, tokens=tokens)
|
||||||
|
|
||||||
|
|
||||||
class DynamicCooldownMapping(CooldownMapping):
|
class DynamicCooldownMapping(CooldownMapping[T]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
factory: Callable[[Message], Optional[Cooldown]],
|
factory: Callable[[T], Optional[Cooldown]],
|
||||||
type: Callable[[Message], Any],
|
type: Callable[[T], Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(None, type)
|
super().__init__(None, type)
|
||||||
self._factory: Callable[[Message], Optional[Cooldown]] = factory
|
self._factory: Callable[[T], Optional[Cooldown]] = factory
|
||||||
|
|
||||||
def copy(self) -> DynamicCooldownMapping:
|
def copy(self) -> DynamicCooldownMapping:
|
||||||
ret = DynamicCooldownMapping(self._factory, self._type)
|
ret = DynamicCooldownMapping(self._factory, self._type)
|
||||||
@ -165,7 +169,7 @@ class DynamicCooldownMapping(CooldownMapping):
|
|||||||
def valid(self) -> bool:
|
def valid(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def create_bucket(self, message: Message) -> Optional[Cooldown]:
|
def create_bucket(self, message: T) -> Optional[Cooldown]:
|
||||||
return self._factory(message)
|
return self._factory(message)
|
||||||
|
|
||||||
|
|
||||||
|
@ -58,8 +58,6 @@ from .parameters import Parameter, Signature
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing_extensions import Concatenate, ParamSpec, Self
|
from typing_extensions import Concatenate, ParamSpec, Self
|
||||||
|
|
||||||
from discord.message import Message
|
|
||||||
|
|
||||||
from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook, UserCheck
|
from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook, UserCheck
|
||||||
|
|
||||||
|
|
||||||
@ -409,10 +407,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
|||||||
if cooldown is None:
|
if cooldown is None:
|
||||||
buckets = CooldownMapping(cooldown, BucketType.default)
|
buckets = CooldownMapping(cooldown, BucketType.default)
|
||||||
elif isinstance(cooldown, CooldownMapping):
|
elif isinstance(cooldown, CooldownMapping):
|
||||||
buckets = cooldown
|
buckets: CooldownMapping[Context] = cooldown
|
||||||
else:
|
else:
|
||||||
raise TypeError("Cooldown must be a an instance of CooldownMapping or None.")
|
raise TypeError("Cooldown must be a an instance of CooldownMapping or None.")
|
||||||
self._buckets: CooldownMapping = buckets
|
self._buckets: CooldownMapping[Context] = buckets
|
||||||
|
|
||||||
try:
|
try:
|
||||||
max_concurrency = func.__commands_max_concurrency__
|
max_concurrency = func.__commands_max_concurrency__
|
||||||
@ -879,7 +877,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
|||||||
if self._buckets.valid:
|
if self._buckets.valid:
|
||||||
dt = ctx.message.edited_at or ctx.message.created_at
|
dt = ctx.message.edited_at or ctx.message.created_at
|
||||||
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
|
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
|
||||||
bucket = self._buckets.get_bucket(ctx.message, current)
|
bucket = self._buckets.get_bucket(ctx, current)
|
||||||
if bucket is not None:
|
if bucket is not None:
|
||||||
retry_after = bucket.update_rate_limit(current)
|
retry_after = bucket.update_rate_limit(current)
|
||||||
if retry_after:
|
if retry_after:
|
||||||
@ -929,7 +927,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
|||||||
if not self._buckets.valid:
|
if not self._buckets.valid:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
bucket = self._buckets.get_bucket(ctx.message)
|
bucket = self._buckets.get_bucket(ctx)
|
||||||
if bucket is None:
|
if bucket is None:
|
||||||
return False
|
return False
|
||||||
dt = ctx.message.edited_at or ctx.message.created_at
|
dt = ctx.message.edited_at or ctx.message.created_at
|
||||||
@ -949,7 +947,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
|||||||
The invocation context to reset the cooldown under.
|
The invocation context to reset the cooldown under.
|
||||||
"""
|
"""
|
||||||
if self._buckets.valid:
|
if self._buckets.valid:
|
||||||
bucket = self._buckets.get_bucket(ctx.message)
|
bucket = self._buckets.get_bucket(ctx)
|
||||||
if bucket is not None:
|
if bucket is not None:
|
||||||
bucket.reset()
|
bucket.reset()
|
||||||
|
|
||||||
@ -974,7 +972,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
|||||||
If this is ``0.0`` then the command isn't on cooldown.
|
If this is ``0.0`` then the command isn't on cooldown.
|
||||||
"""
|
"""
|
||||||
if self._buckets.valid:
|
if self._buckets.valid:
|
||||||
bucket = self._buckets.get_bucket(ctx.message)
|
bucket = self._buckets.get_bucket(ctx)
|
||||||
if bucket is None:
|
if bucket is None:
|
||||||
return 0.0
|
return 0.0
|
||||||
dt = ctx.message.edited_at or ctx.message.created_at
|
dt = ctx.message.edited_at or ctx.message.created_at
|
||||||
@ -2399,7 +2397,7 @@ def is_nsfw() -> Check[Any]:
|
|||||||
def cooldown(
|
def cooldown(
|
||||||
rate: int,
|
rate: int,
|
||||||
per: float,
|
per: float,
|
||||||
type: Union[BucketType, Callable[[Message], Any]] = BucketType.default,
|
type: Union[BucketType, Callable[[Context], Any]] = BucketType.default,
|
||||||
) -> Callable[[T], T]:
|
) -> Callable[[T], T]:
|
||||||
"""A decorator that adds a cooldown to a :class:`.Command`
|
"""A decorator that adds a cooldown to a :class:`.Command`
|
||||||
|
|
||||||
@ -2420,7 +2418,7 @@ def cooldown(
|
|||||||
The number of times a command can be used before triggering a cooldown.
|
The number of times a command can be used before triggering a cooldown.
|
||||||
per: :class:`float`
|
per: :class:`float`
|
||||||
The amount of seconds to wait for a cooldown when it's been triggered.
|
The amount of seconds to wait for a cooldown when it's been triggered.
|
||||||
type: Union[:class:`.BucketType`, Callable[[:class:`.Message`], Any]]
|
type: Union[:class:`.BucketType`, Callable[[:class:`.Context`], Any]]
|
||||||
The type of cooldown to have. If callable, should return a key for the mapping.
|
The type of cooldown to have. If callable, should return a key for the mapping.
|
||||||
|
|
||||||
.. versionchanged:: 1.7
|
.. versionchanged:: 1.7
|
||||||
@ -2431,15 +2429,15 @@ def cooldown(
|
|||||||
if isinstance(func, Command):
|
if isinstance(func, Command):
|
||||||
func._buckets = CooldownMapping(Cooldown(rate, per), type)
|
func._buckets = CooldownMapping(Cooldown(rate, per), type)
|
||||||
else:
|
else:
|
||||||
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type)
|
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type) # type: ignore # typevar cannot be inferred without annotation
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator # type: ignore
|
return decorator # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def dynamic_cooldown(
|
def dynamic_cooldown(
|
||||||
cooldown: Union[BucketType, Callable[[Message], Any]],
|
cooldown: Callable[[Context], Cooldown | None],
|
||||||
type: BucketType,
|
type: BucketType | Callable[[Context], Any],
|
||||||
) -> Callable[[T], T]:
|
) -> Callable[[T], T]:
|
||||||
"""A decorator that adds a dynamic cooldown to a :class:`.Command`
|
"""A decorator that adds a dynamic cooldown to a :class:`.Command`
|
||||||
|
|
||||||
@ -2463,7 +2461,7 @@ def dynamic_cooldown(
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
------------
|
------------
|
||||||
cooldown: Callable[[:class:`.discord.Message`], Optional[:class:`~discord.app_commands.Cooldown`]]
|
cooldown: Callable[[:class:`.Context`], Optional[:class:`~discord.app_commands.Cooldown`]]
|
||||||
A function that takes a message and returns a cooldown that will
|
A function that takes a message and returns a cooldown that will
|
||||||
apply to this invocation or ``None`` if the cooldown should be bypassed.
|
apply to this invocation or ``None`` if the cooldown should be bypassed.
|
||||||
type: :class:`.BucketType`
|
type: :class:`.BucketType`
|
||||||
|
Loading…
x
Reference in New Issue
Block a user