[commands] Document / type-hint cooldown
This commit is contained in:
parent
ec32b71ff9
commit
1c63816cc0
@ -289,7 +289,7 @@ class Client:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def stickers(self) -> List[GuildSticker]:
|
def stickers(self) -> List[GuildSticker]:
|
||||||
"""List[:class:`GuildSticker`]: The stickers that the connected client has.
|
"""List[:class:`.GuildSticker`]: The stickers that the connected client has.
|
||||||
|
|
||||||
.. versionadded:: 2.0
|
.. versionadded:: 2.0
|
||||||
"""
|
"""
|
||||||
|
@ -22,6 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
|||||||
DEALINGS IN THE SOFTWARE.
|
DEALINGS IN THE SOFTWARE.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Any, Callable, Deque, Dict, Optional, Type, TypeVar, TYPE_CHECKING
|
||||||
from discord.enums import Enum
|
from discord.enums import Enum
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -30,6 +34,9 @@ from collections import deque
|
|||||||
from ...abc import PrivateChannel
|
from ...abc import PrivateChannel
|
||||||
from .errors import MaxConcurrencyReached
|
from .errors import MaxConcurrencyReached
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...message import Message
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
'BucketType',
|
'BucketType',
|
||||||
'Cooldown',
|
'Cooldown',
|
||||||
@ -38,6 +45,9 @@ __all__ = (
|
|||||||
'MaxConcurrency',
|
'MaxConcurrency',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
C = TypeVar('C', bound='CooldownMapping')
|
||||||
|
MC = TypeVar('MC', bound='MaxConcurrency')
|
||||||
|
|
||||||
class BucketType(Enum):
|
class BucketType(Enum):
|
||||||
default = 0
|
default = 0
|
||||||
user = 1
|
user = 1
|
||||||
@ -47,7 +57,7 @@ class BucketType(Enum):
|
|||||||
category = 5
|
category = 5
|
||||||
role = 6
|
role = 6
|
||||||
|
|
||||||
def get_key(self, msg):
|
def get_key(self, msg: Message) -> Any:
|
||||||
if self is BucketType.user:
|
if self is BucketType.user:
|
||||||
return msg.author.id
|
return msg.author.id
|
||||||
elif self is BucketType.guild:
|
elif self is BucketType.guild:
|
||||||
@ -57,29 +67,52 @@ class BucketType(Enum):
|
|||||||
elif self is BucketType.member:
|
elif self is BucketType.member:
|
||||||
return ((msg.guild and msg.guild.id), msg.author.id)
|
return ((msg.guild and msg.guild.id), msg.author.id)
|
||||||
elif self is BucketType.category:
|
elif self is BucketType.category:
|
||||||
return (msg.channel.category or msg.channel).id
|
return (msg.channel.category or msg.channel).id # type: ignore
|
||||||
elif self is BucketType.role:
|
elif self is BucketType.role:
|
||||||
# we return the channel id of a private-channel as there are only roles in guilds
|
# we return the channel id of a private-channel as there are only roles in guilds
|
||||||
# and that yields the same result as for a guild with only the @everyone role
|
# and that yields the same result as for a guild with only the @everyone role
|
||||||
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
|
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
|
||||||
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
|
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
|
||||||
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id
|
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore
|
||||||
|
|
||||||
def __call__(self, msg):
|
def __call__(self, msg: Message) -> Any:
|
||||||
return self.get_key(msg)
|
return self.get_key(msg)
|
||||||
|
|
||||||
|
|
||||||
class Cooldown:
|
class Cooldown:
|
||||||
|
"""Represents a cooldown for a command.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
-----------
|
||||||
|
rate: :class:`int`
|
||||||
|
The total number of tokens available per :attr:`per` seconds.
|
||||||
|
per: :class:`float`
|
||||||
|
The length of the cooldown period in seconds.
|
||||||
|
"""
|
||||||
|
|
||||||
__slots__ = ('rate', 'per', '_window', '_tokens', '_last')
|
__slots__ = ('rate', 'per', '_window', '_tokens', '_last')
|
||||||
|
|
||||||
def __init__(self, rate, per):
|
def __init__(self, rate: float, per: float) -> None:
|
||||||
self.rate = int(rate)
|
self.rate: int = int(rate)
|
||||||
self.per = float(per)
|
self.per: float = float(per)
|
||||||
self._window = 0.0
|
self._window: float = 0.0
|
||||||
self._tokens = self.rate
|
self._tokens: int = self.rate
|
||||||
self._last = 0.0
|
self._last: float = 0.0
|
||||||
|
|
||||||
def get_tokens(self, current=None):
|
def get_tokens(self, current: Optional[float] = None) -> int:
|
||||||
|
"""Returns the number of available tokens before rate limiting is applied.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
------------
|
||||||
|
current: Optional[:class:`float`]
|
||||||
|
The time in seconds since Unix epoch to calculate tokens at.
|
||||||
|
If not supplied then :func:`time.time()` is used.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
--------
|
||||||
|
:class:`int`
|
||||||
|
The number of tokens available before the cooldown is to be applied.
|
||||||
|
"""
|
||||||
if not current:
|
if not current:
|
||||||
current = time.time()
|
current = time.time()
|
||||||
|
|
||||||
@ -89,7 +122,20 @@ class Cooldown:
|
|||||||
tokens = self.rate
|
tokens = self.rate
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
def get_retry_after(self, current=None):
|
def get_retry_after(self, current: Optional[float] = None) -> float:
|
||||||
|
"""Returns the time in seconds until the cooldown will be reset.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-------------
|
||||||
|
current: Optional[:class:`float`]
|
||||||
|
The current time in seconds since Unix epoch.
|
||||||
|
If not supplied, then :func:`time.time()` is used.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
:class:`float`
|
||||||
|
The number of seconds to wait before this cooldown will be reset.
|
||||||
|
"""
|
||||||
current = current or time.time()
|
current = current or time.time()
|
||||||
tokens = self.get_tokens(current)
|
tokens = self.get_tokens(current)
|
||||||
|
|
||||||
@ -98,7 +144,20 @@ class Cooldown:
|
|||||||
|
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
def update_rate_limit(self, current=None):
|
def update_rate_limit(self, current: Optional[float] = None) -> Optional[float]:
|
||||||
|
"""Updates the cooldown rate limit.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-------------
|
||||||
|
current: Optional[:class:`float`]
|
||||||
|
The time in seconds since Unix epoch to update the rate limit at.
|
||||||
|
If not supplied, then :func:`time.time()` is used.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Optional[:class:`float`]
|
||||||
|
The retry-after time in seconds if rate limited.
|
||||||
|
"""
|
||||||
current = current or time.time()
|
current = current or time.time()
|
||||||
self._last = current
|
self._last = current
|
||||||
|
|
||||||
@ -115,46 +174,58 @@ class Cooldown:
|
|||||||
# we're not so decrement our tokens
|
# we're not so decrement our tokens
|
||||||
self._tokens -= 1
|
self._tokens -= 1
|
||||||
|
|
||||||
def reset(self):
|
def reset(self) -> None:
|
||||||
|
"""Reset the cooldown to its initial state."""
|
||||||
self._tokens = self.rate
|
self._tokens = self.rate
|
||||||
self._last = 0.0
|
self._last = 0.0
|
||||||
|
|
||||||
def copy(self):
|
def copy(self) -> Cooldown:
|
||||||
|
"""Creates a copy of this cooldown.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
--------
|
||||||
|
:class:`Cooldown`
|
||||||
|
A new instance of this cooldown.
|
||||||
|
"""
|
||||||
return Cooldown(self.rate, self.per)
|
return Cooldown(self.rate, self.per)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
|
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
|
||||||
|
|
||||||
class CooldownMapping:
|
class CooldownMapping:
|
||||||
def __init__(self, original, type):
|
def __init__(
|
||||||
|
self,
|
||||||
|
original: Optional[Cooldown],
|
||||||
|
type: Callable[[Message], Any],
|
||||||
|
) -> None:
|
||||||
if not callable(type):
|
if not callable(type):
|
||||||
raise TypeError('Cooldown type must be a BucketType or callable')
|
raise TypeError('Cooldown type must be a BucketType or callable')
|
||||||
|
|
||||||
self._cache = {}
|
self._cache: Dict[Any, Cooldown] = {}
|
||||||
self._cooldown = original
|
self._cooldown: Optional[Cooldown] = original
|
||||||
self._type = type
|
self._type: Callable[[Message], Any] = type
|
||||||
|
|
||||||
def copy(self):
|
def copy(self) -> CooldownMapping:
|
||||||
ret = CooldownMapping(self._cooldown, self._type)
|
ret = CooldownMapping(self._cooldown, self._type)
|
||||||
ret._cache = self._cache.copy()
|
ret._cache = self._cache.copy()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def valid(self):
|
def valid(self) -> bool:
|
||||||
return self._cooldown is not None
|
return self._cooldown is not None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self):
|
def type(self) -> Callable[[Message], Any]:
|
||||||
return self._type
|
return self._type
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cooldown(cls, rate, per, type):
|
def from_cooldown(cls: Type[C], rate, per, type) -> C:
|
||||||
return cls(Cooldown(rate, per), type)
|
return cls(Cooldown(rate, per), type)
|
||||||
|
|
||||||
def _bucket_key(self, msg):
|
def _bucket_key(self, msg: Message) -> Any:
|
||||||
return self._type(msg)
|
return self._type(msg)
|
||||||
|
|
||||||
def _verify_cache_integrity(self, current=None):
|
def _verify_cache_integrity(self, current: Optional[float] = None) -> None:
|
||||||
# we want to delete all cache objects that haven't been used
|
# we want to delete all cache objects that haven't been used
|
||||||
# in a cooldown window. e.g. if we have a command that has a
|
# in a cooldown window. e.g. if we have a command that has a
|
||||||
# cooldown of 60s and it has not been used in 60s then that key should be deleted
|
# cooldown of 60s and it has not been used in 60s then that key should be deleted
|
||||||
@ -163,12 +234,12 @@ class CooldownMapping:
|
|||||||
for k in dead_keys:
|
for k in dead_keys:
|
||||||
del self._cache[k]
|
del self._cache[k]
|
||||||
|
|
||||||
def create_bucket(self, message):
|
def create_bucket(self, message: Message) -> Cooldown:
|
||||||
return self._cooldown.copy()
|
return self._cooldown.copy() # type: ignore
|
||||||
|
|
||||||
def get_bucket(self, message, current=None):
|
def get_bucket(self, message: Message, current: Optional[float] = None) -> Cooldown:
|
||||||
if self._type is BucketType.default:
|
if self._type is BucketType.default:
|
||||||
return self._cooldown
|
return self._cooldown # type: ignore
|
||||||
|
|
||||||
self._verify_cache_integrity(current)
|
self._verify_cache_integrity(current)
|
||||||
key = self._bucket_key(message)
|
key = self._bucket_key(message)
|
||||||
@ -181,26 +252,30 @@ class CooldownMapping:
|
|||||||
|
|
||||||
return bucket
|
return bucket
|
||||||
|
|
||||||
def update_rate_limit(self, message, current=None):
|
def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]:
|
||||||
bucket = self.get_bucket(message, current)
|
bucket = self.get_bucket(message, current)
|
||||||
return bucket.update_rate_limit(current)
|
return bucket.update_rate_limit(current)
|
||||||
|
|
||||||
class DynamicCooldownMapping(CooldownMapping):
|
class DynamicCooldownMapping(CooldownMapping):
|
||||||
|
|
||||||
def __init__(self, factory, type):
|
def __init__(
|
||||||
|
self,
|
||||||
|
factory: Callable[[Message], Cooldown],
|
||||||
|
type: Callable[[Message], Any]
|
||||||
|
) -> None:
|
||||||
super().__init__(None, type)
|
super().__init__(None, type)
|
||||||
self._factory = factory
|
self._factory: Callable[[Message], Cooldown] = factory
|
||||||
|
|
||||||
def copy(self):
|
def copy(self) -> DynamicCooldownMapping:
|
||||||
ret = DynamicCooldownMapping(self._factory, self._type)
|
ret = DynamicCooldownMapping(self._factory, self._type)
|
||||||
ret._cache = self._cache.copy()
|
ret._cache = self._cache.copy()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def valid(self):
|
def valid(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def create_bucket(self, message):
|
def create_bucket(self, message: Message) -> Cooldown:
|
||||||
return self._factory(message)
|
return self._factory(message)
|
||||||
|
|
||||||
class _Semaphore:
|
class _Semaphore:
|
||||||
@ -218,28 +293,28 @@ class _Semaphore:
|
|||||||
|
|
||||||
__slots__ = ('value', 'loop', '_waiters')
|
__slots__ = ('value', 'loop', '_waiters')
|
||||||
|
|
||||||
def __init__(self, number):
|
def __init__(self, number: int) -> None:
|
||||||
self.value = number
|
self.value: int = number
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||||
self._waiters = deque()
|
self._waiters: Deque[asyncio.Future] = deque()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>'
|
return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>'
|
||||||
|
|
||||||
def locked(self):
|
def locked(self) -> bool:
|
||||||
return self.value == 0
|
return self.value == 0
|
||||||
|
|
||||||
def is_active(self):
|
def is_active(self) -> bool:
|
||||||
return len(self._waiters) > 0
|
return len(self._waiters) > 0
|
||||||
|
|
||||||
def wake_up(self):
|
def wake_up(self) -> None:
|
||||||
while self._waiters:
|
while self._waiters:
|
||||||
future = self._waiters.popleft()
|
future = self._waiters.popleft()
|
||||||
if not future.done():
|
if not future.done():
|
||||||
future.set_result(None)
|
future.set_result(None)
|
||||||
return
|
return
|
||||||
|
|
||||||
async def acquire(self, *, wait=False):
|
async def acquire(self, *, wait: bool = False) -> bool:
|
||||||
if not wait and self.value <= 0:
|
if not wait and self.value <= 0:
|
||||||
# signal that we're not acquiring
|
# signal that we're not acquiring
|
||||||
return False
|
return False
|
||||||
@ -258,18 +333,18 @@ class _Semaphore:
|
|||||||
self.value -= 1
|
self.value -= 1
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def release(self):
|
def release(self) -> None:
|
||||||
self.value += 1
|
self.value += 1
|
||||||
self.wake_up()
|
self.wake_up()
|
||||||
|
|
||||||
class MaxConcurrency:
|
class MaxConcurrency:
|
||||||
__slots__ = ('number', 'per', 'wait', '_mapping')
|
__slots__ = ('number', 'per', 'wait', '_mapping')
|
||||||
|
|
||||||
def __init__(self, number, *, per, wait):
|
def __init__(self, number: int, *, per: BucketType, wait: bool) -> None:
|
||||||
self._mapping = {}
|
self._mapping: Dict[Any, _Semaphore] = {}
|
||||||
self.per = per
|
self.per: BucketType = per
|
||||||
self.number = number
|
self.number: int = number
|
||||||
self.wait = wait
|
self.wait: bool = wait
|
||||||
|
|
||||||
if number <= 0:
|
if number <= 0:
|
||||||
raise ValueError('max_concurrency \'number\' cannot be less than 1')
|
raise ValueError('max_concurrency \'number\' cannot be less than 1')
|
||||||
@ -277,16 +352,16 @@ class MaxConcurrency:
|
|||||||
if not isinstance(per, BucketType):
|
if not isinstance(per, BucketType):
|
||||||
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}')
|
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}')
|
||||||
|
|
||||||
def copy(self):
|
def copy(self: MC) -> MC:
|
||||||
return self.__class__(self.number, per=self.per, wait=self.wait)
|
return self.__class__(self.number, per=self.per, wait=self.wait)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>'
|
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>'
|
||||||
|
|
||||||
def get_key(self, message):
|
def get_key(self, message: Message) -> Any:
|
||||||
return self.per.get_key(message)
|
return self.per.get_key(message)
|
||||||
|
|
||||||
async def acquire(self, message):
|
async def acquire(self, message: Message) -> None:
|
||||||
key = self.get_key(message)
|
key = self.get_key(message)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -298,7 +373,7 @@ class MaxConcurrency:
|
|||||||
if not acquired:
|
if not acquired:
|
||||||
raise MaxConcurrencyReached(self.number, self.per)
|
raise MaxConcurrencyReached(self.number, self.per)
|
||||||
|
|
||||||
async def release(self, message):
|
async def release(self, message: Message) -> None:
|
||||||
# Technically there's no reason for this function to be async
|
# Technically there's no reason for this function to be async
|
||||||
# But it might be more useful in the future
|
# But it might be more useful in the future
|
||||||
key = self.get_key(message)
|
key = self.get_key(message)
|
||||||
|
@ -493,7 +493,7 @@ class CommandOnCooldown(CommandError):
|
|||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
-----------
|
-----------
|
||||||
cooldown: ``Cooldown``
|
cooldown: :class:`.Cooldown`
|
||||||
A class with attributes ``rate`` and ``per`` similar to the
|
A class with attributes ``rate`` and ``per`` similar to the
|
||||||
:func:`.cooldown` decorator.
|
:func:`.cooldown` decorator.
|
||||||
type: :class:`BucketType`
|
type: :class:`BucketType`
|
||||||
|
@ -330,6 +330,14 @@ Checks
|
|||||||
|
|
||||||
.. _ext_commands_api_context:
|
.. _ext_commands_api_context:
|
||||||
|
|
||||||
|
Cooldown
|
||||||
|
---------
|
||||||
|
|
||||||
|
.. attributetable:: discord.ext.commands.Cooldown
|
||||||
|
|
||||||
|
.. autoclass:: discord.ext.commands.Cooldown
|
||||||
|
:members:
|
||||||
|
|
||||||
Context
|
Context
|
||||||
--------
|
--------
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user