[commands] Provide a dynamic cooldown system
This commit is contained in:
@@ -34,6 +34,7 @@ __all__ = (
|
||||
'BucketType',
|
||||
'Cooldown',
|
||||
'CooldownMapping',
|
||||
'DynamicCooldownMapping',
|
||||
'MaxConcurrency',
|
||||
)
|
||||
|
||||
@@ -69,19 +70,15 @@ class BucketType(Enum):
|
||||
|
||||
|
||||
class Cooldown:
|
||||
__slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last')
|
||||
__slots__ = ('rate', 'per', '_window', '_tokens', '_last')
|
||||
|
||||
def __init__(self, rate, per, type):
|
||||
def __init__(self, rate, per):
|
||||
self.rate = int(rate)
|
||||
self.per = float(per)
|
||||
self.type = type
|
||||
self._window = 0.0
|
||||
self._tokens = self.rate
|
||||
self._last = 0.0
|
||||
|
||||
if not callable(self.type):
|
||||
raise TypeError('Cooldown type must be a BucketType or callable')
|
||||
|
||||
def get_tokens(self, current=None):
|
||||
if not current:
|
||||
current = time.time()
|
||||
@@ -128,15 +125,19 @@ class Cooldown:
|
||||
self._last = 0.0
|
||||
|
||||
def copy(self):
|
||||
return Cooldown(self.rate, self.per, self.type)
|
||||
return Cooldown(self.rate, self.per)
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
|
||||
|
||||
class CooldownMapping:
|
||||
def __init__(self, original):
|
||||
def __init__(self, original, type):
|
||||
if not callable(type):
|
||||
raise TypeError('Cooldown type must be a BucketType or callable')
|
||||
|
||||
self._cache = {}
|
||||
self._cooldown = original
|
||||
self._type = type
|
||||
|
||||
def copy(self):
|
||||
ret = CooldownMapping(self._cooldown)
|
||||
@@ -152,7 +153,7 @@ class CooldownMapping:
|
||||
return cls(Cooldown(rate, per, type))
|
||||
|
||||
def _bucket_key(self, msg):
|
||||
return self._cooldown.type(msg)
|
||||
return self._type(msg)
|
||||
|
||||
def _verify_cache_integrity(self, current=None):
|
||||
# we want to delete all cache objects that haven't been used
|
||||
@@ -163,14 +164,18 @@ class CooldownMapping:
|
||||
for k in dead_keys:
|
||||
del self._cache[k]
|
||||
|
||||
def create_bucket(self, message):
|
||||
return self._cooldown.copy()
|
||||
|
||||
def get_bucket(self, message, current=None):
|
||||
if self._cooldown.type is BucketType.default:
|
||||
if self._type is BucketType.default:
|
||||
return self._cooldown
|
||||
|
||||
self._verify_cache_integrity(current)
|
||||
key = self._bucket_key(message)
|
||||
if key not in self._cache:
|
||||
bucket = self._cooldown.copy()
|
||||
bucket = self.create_bucket(message)
|
||||
if bucket is not None:
|
||||
self._cache[key] = bucket
|
||||
else:
|
||||
bucket = self._cache[key]
|
||||
@@ -181,6 +186,19 @@ class CooldownMapping:
|
||||
bucket = self.get_bucket(message, current)
|
||||
return bucket.update_rate_limit(current)
|
||||
|
||||
class DynamicCooldownMapping(CooldownMapping):
|
||||
|
||||
def __init__(self, factory, type):
|
||||
super().__init__(None, type)
|
||||
self._factory = factory
|
||||
|
||||
@property
|
||||
def valid(self):
|
||||
return True
|
||||
|
||||
def create_bucket(self, message):
|
||||
return self._factory(message)
|
||||
|
||||
class _Semaphore:
|
||||
"""This class is a version of a semaphore.
|
||||
|
||||
|
@@ -32,7 +32,7 @@ import sys
|
||||
import discord
|
||||
|
||||
from .errors import *
|
||||
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency
|
||||
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping
|
||||
from . import converter as converters
|
||||
from ._types import _BaseCommand
|
||||
from .cog import Cog
|
||||
@@ -54,6 +54,7 @@ __all__ = (
|
||||
'bot_has_permissions',
|
||||
'bot_has_any_role',
|
||||
'cooldown',
|
||||
'dynamic_cooldown',
|
||||
'max_concurrency',
|
||||
'dm_only',
|
||||
'guild_only',
|
||||
@@ -256,7 +257,10 @@ class Command(_BaseCommand):
|
||||
except AttributeError:
|
||||
cooldown = kwargs.get('cooldown')
|
||||
finally:
|
||||
self._buckets = CooldownMapping(cooldown)
|
||||
if cooldown is None:
|
||||
self._buckets = CooldownMapping(cooldown, BucketType.default)
|
||||
elif isinstance(cooldown, CooldownMapping):
|
||||
self._buckets = cooldown
|
||||
|
||||
try:
|
||||
max_concurrency = func.__commands_max_concurrency__
|
||||
@@ -799,6 +803,7 @@ class Command(_BaseCommand):
|
||||
dt = ctx.message.edited_at or ctx.message.created_at
|
||||
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
|
||||
bucket = self._buckets.get_bucket(ctx.message, current)
|
||||
if bucket is not None:
|
||||
retry_after = bucket.update_rate_limit(current)
|
||||
if retry_after:
|
||||
raise CommandOnCooldown(bucket, retry_after)
|
||||
@@ -2014,9 +2019,48 @@ def cooldown(rate, per, type=BucketType.default):
|
||||
|
||||
def decorator(func):
|
||||
if isinstance(func, Command):
|
||||
func._buckets = CooldownMapping(Cooldown(rate, per, type))
|
||||
func._buckets = CooldownMapping(Cooldown(rate, per), type)
|
||||
else:
|
||||
func.__commands_cooldown__ = Cooldown(rate, per, type)
|
||||
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type)
|
||||
return func
|
||||
return decorator
|
||||
|
||||
def dynamic_cooldown(cooldown, type=BucketType.default):
|
||||
"""A decorator that adds a dynamic cooldown to a :class:`.Command`
|
||||
|
||||
This differs from :func:`.cooldown` in that it takes a function that
|
||||
accepts a single parameter of type :class:`.discord.Message` and must
|
||||
return a :class:`.Cooldown`
|
||||
|
||||
A cooldown allows a command to only be used a specific amount
|
||||
of times in a specific time frame. These cooldowns can be based
|
||||
either on a per-guild, per-channel, per-user, per-role or global basis.
|
||||
Denoted by the third argument of ``type`` which must be of enum
|
||||
type :class:`.BucketType`.
|
||||
|
||||
If a cooldown is triggered, then :exc:`.CommandOnCooldown` is triggered in
|
||||
:func:`.on_command_error` and the local error handler.
|
||||
|
||||
A command can only have a single cooldown.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Parameters
|
||||
------------
|
||||
cooldown: Callable[[:class:`.discord.Message`], :class:`.Cooldown`]
|
||||
A function that takes a message and returns a cooldown that will
|
||||
apply to this invocation
|
||||
type: :class:`.BucketType`
|
||||
The type of cooldown to have.
|
||||
"""
|
||||
if not callable(cooldown):
|
||||
raise TypeError("A callable must be provided")
|
||||
|
||||
def decorator(func):
|
||||
if isinstance(func, Command):
|
||||
func._buckets = DynamicCooldownMapping(cooldown, type)
|
||||
else:
|
||||
func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type)
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
Reference in New Issue
Block a user