[commands] Provide a dynamic cooldown system
This commit is contained in:
parent
ea32147d02
commit
f2d5ab6f80
@ -34,6 +34,7 @@ __all__ = (
|
|||||||
'BucketType',
|
'BucketType',
|
||||||
'Cooldown',
|
'Cooldown',
|
||||||
'CooldownMapping',
|
'CooldownMapping',
|
||||||
|
'DynamicCooldownMapping',
|
||||||
'MaxConcurrency',
|
'MaxConcurrency',
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -69,19 +70,15 @@ class BucketType(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class Cooldown:
|
class Cooldown:
|
||||||
__slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last')
|
__slots__ = ('rate', 'per', '_window', '_tokens', '_last')
|
||||||
|
|
||||||
def __init__(self, rate, per, type):
|
def __init__(self, rate, per):
|
||||||
self.rate = int(rate)
|
self.rate = int(rate)
|
||||||
self.per = float(per)
|
self.per = float(per)
|
||||||
self.type = type
|
|
||||||
self._window = 0.0
|
self._window = 0.0
|
||||||
self._tokens = self.rate
|
self._tokens = self.rate
|
||||||
self._last = 0.0
|
self._last = 0.0
|
||||||
|
|
||||||
if not callable(self.type):
|
|
||||||
raise TypeError('Cooldown type must be a BucketType or callable')
|
|
||||||
|
|
||||||
def get_tokens(self, current=None):
|
def get_tokens(self, current=None):
|
||||||
if not current:
|
if not current:
|
||||||
current = time.time()
|
current = time.time()
|
||||||
@ -128,15 +125,19 @@ class Cooldown:
|
|||||||
self._last = 0.0
|
self._last = 0.0
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
return Cooldown(self.rate, self.per, self.type)
|
return Cooldown(self.rate, self.per)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
|
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
|
||||||
|
|
||||||
class CooldownMapping:
|
class CooldownMapping:
|
||||||
def __init__(self, original):
|
def __init__(self, original, type):
|
||||||
|
if not callable(type):
|
||||||
|
raise TypeError('Cooldown type must be a BucketType or callable')
|
||||||
|
|
||||||
self._cache = {}
|
self._cache = {}
|
||||||
self._cooldown = original
|
self._cooldown = original
|
||||||
|
self._type = type
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
ret = CooldownMapping(self._cooldown)
|
ret = CooldownMapping(self._cooldown)
|
||||||
@ -152,7 +153,7 @@ class CooldownMapping:
|
|||||||
return cls(Cooldown(rate, per, type))
|
return cls(Cooldown(rate, per, type))
|
||||||
|
|
||||||
def _bucket_key(self, msg):
|
def _bucket_key(self, msg):
|
||||||
return self._cooldown.type(msg)
|
return self._type(msg)
|
||||||
|
|
||||||
def _verify_cache_integrity(self, current=None):
|
def _verify_cache_integrity(self, current=None):
|
||||||
# we want to delete all cache objects that haven't been used
|
# we want to delete all cache objects that haven't been used
|
||||||
@ -163,14 +164,18 @@ class CooldownMapping:
|
|||||||
for k in dead_keys:
|
for k in dead_keys:
|
||||||
del self._cache[k]
|
del self._cache[k]
|
||||||
|
|
||||||
|
def create_bucket(self, message):
|
||||||
|
return self._cooldown.copy()
|
||||||
|
|
||||||
def get_bucket(self, message, current=None):
|
def get_bucket(self, message, current=None):
|
||||||
if self._cooldown.type is BucketType.default:
|
if self._type is BucketType.default:
|
||||||
return self._cooldown
|
return self._cooldown
|
||||||
|
|
||||||
self._verify_cache_integrity(current)
|
self._verify_cache_integrity(current)
|
||||||
key = self._bucket_key(message)
|
key = self._bucket_key(message)
|
||||||
if key not in self._cache:
|
if key not in self._cache:
|
||||||
bucket = self._cooldown.copy()
|
bucket = self.create_bucket(message)
|
||||||
|
if bucket is not None:
|
||||||
self._cache[key] = bucket
|
self._cache[key] = bucket
|
||||||
else:
|
else:
|
||||||
bucket = self._cache[key]
|
bucket = self._cache[key]
|
||||||
@ -181,6 +186,19 @@ class CooldownMapping:
|
|||||||
bucket = self.get_bucket(message, current)
|
bucket = self.get_bucket(message, current)
|
||||||
return bucket.update_rate_limit(current)
|
return bucket.update_rate_limit(current)
|
||||||
|
|
||||||
|
class DynamicCooldownMapping(CooldownMapping):
|
||||||
|
|
||||||
|
def __init__(self, factory, type):
|
||||||
|
super().__init__(None, type)
|
||||||
|
self._factory = factory
|
||||||
|
|
||||||
|
@property
|
||||||
|
def valid(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def create_bucket(self, message):
|
||||||
|
return self._factory(message)
|
||||||
|
|
||||||
class _Semaphore:
|
class _Semaphore:
|
||||||
"""This class is a version of a semaphore.
|
"""This class is a version of a semaphore.
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ import sys
|
|||||||
import discord
|
import discord
|
||||||
|
|
||||||
from .errors import *
|
from .errors import *
|
||||||
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency
|
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping
|
||||||
from . import converter as converters
|
from . import converter as converters
|
||||||
from ._types import _BaseCommand
|
from ._types import _BaseCommand
|
||||||
from .cog import Cog
|
from .cog import Cog
|
||||||
@ -54,6 +54,7 @@ __all__ = (
|
|||||||
'bot_has_permissions',
|
'bot_has_permissions',
|
||||||
'bot_has_any_role',
|
'bot_has_any_role',
|
||||||
'cooldown',
|
'cooldown',
|
||||||
|
'dynamic_cooldown',
|
||||||
'max_concurrency',
|
'max_concurrency',
|
||||||
'dm_only',
|
'dm_only',
|
||||||
'guild_only',
|
'guild_only',
|
||||||
@ -256,7 +257,10 @@ class Command(_BaseCommand):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
cooldown = kwargs.get('cooldown')
|
cooldown = kwargs.get('cooldown')
|
||||||
finally:
|
finally:
|
||||||
self._buckets = CooldownMapping(cooldown)
|
if cooldown is None:
|
||||||
|
self._buckets = CooldownMapping(cooldown, BucketType.default)
|
||||||
|
elif isinstance(cooldown, CooldownMapping):
|
||||||
|
self._buckets = cooldown
|
||||||
|
|
||||||
try:
|
try:
|
||||||
max_concurrency = func.__commands_max_concurrency__
|
max_concurrency = func.__commands_max_concurrency__
|
||||||
@ -799,6 +803,7 @@ class Command(_BaseCommand):
|
|||||||
dt = ctx.message.edited_at or ctx.message.created_at
|
dt = ctx.message.edited_at or ctx.message.created_at
|
||||||
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
|
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
|
||||||
bucket = self._buckets.get_bucket(ctx.message, current)
|
bucket = self._buckets.get_bucket(ctx.message, current)
|
||||||
|
if bucket is not None:
|
||||||
retry_after = bucket.update_rate_limit(current)
|
retry_after = bucket.update_rate_limit(current)
|
||||||
if retry_after:
|
if retry_after:
|
||||||
raise CommandOnCooldown(bucket, retry_after)
|
raise CommandOnCooldown(bucket, retry_after)
|
||||||
@ -2014,9 +2019,48 @@ def cooldown(rate, per, type=BucketType.default):
|
|||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
if isinstance(func, Command):
|
if isinstance(func, Command):
|
||||||
func._buckets = CooldownMapping(Cooldown(rate, per, type))
|
func._buckets = CooldownMapping(Cooldown(rate, per), type)
|
||||||
else:
|
else:
|
||||||
func.__commands_cooldown__ = Cooldown(rate, per, type)
|
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type)
|
||||||
|
return func
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def dynamic_cooldown(cooldown, type=BucketType.default):
|
||||||
|
"""A decorator that adds a dynamic cooldown to a :class:`.Command`
|
||||||
|
|
||||||
|
This differs from :func:`.cooldown` in that it takes a function that
|
||||||
|
accepts a single parameter of type :class:`.discord.Message` and must
|
||||||
|
return a :class:`.Cooldown`
|
||||||
|
|
||||||
|
A cooldown allows a command to only be used a specific amount
|
||||||
|
of times in a specific time frame. These cooldowns can be based
|
||||||
|
either on a per-guild, per-channel, per-user, per-role or global basis.
|
||||||
|
Denoted by the third argument of ``type`` which must be of enum
|
||||||
|
type :class:`.BucketType`.
|
||||||
|
|
||||||
|
If a cooldown is triggered, then :exc:`.CommandOnCooldown` is triggered in
|
||||||
|
:func:`.on_command_error` and the local error handler.
|
||||||
|
|
||||||
|
A command can only have a single cooldown.
|
||||||
|
|
||||||
|
.. versionadded:: 2.0
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
------------
|
||||||
|
cooldown: Callable[[:class:`.discord.Message`], :class:`.Cooldown`]
|
||||||
|
A function that takes a message and returns a cooldown that will
|
||||||
|
apply to this invocation
|
||||||
|
type: :class:`.BucketType`
|
||||||
|
The type of cooldown to have.
|
||||||
|
"""
|
||||||
|
if not callable(cooldown):
|
||||||
|
raise TypeError("A callable must be provided")
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
if isinstance(func, Command):
|
||||||
|
func._buckets = DynamicCooldownMapping(cooldown, type)
|
||||||
|
else:
|
||||||
|
func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type)
|
||||||
return func
|
return func
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user