[commands] Add max_concurrency decorator

This commit is contained in:
Rapptz
2020-01-21 03:26:41 -05:00
parent 3149f15165
commit bf84c63396
3 changed files with 204 additions and 1 deletions

View File

@@ -33,7 +33,7 @@ import datetime
import discord
from .errors import *
from .cooldowns import Cooldown, BucketType, CooldownMapping
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency
from . import converter as converters
from ._types import _BaseCommand
from .cog import Cog
@@ -53,6 +53,7 @@ __all__ = (
'bot_has_permissions',
'bot_has_any_role',
'cooldown',
'max_concurrency',
'dm_only',
'guild_only',
'is_owner',
@@ -90,6 +91,9 @@ def hooked_wrapped_callback(command, ctx, coro):
ctx.command_failed = True
raise CommandInvokeError(exc) from exc
finally:
if command._max_concurrency is not None:
await command._max_concurrency.release(ctx)
await command.call_after_hooks(ctx)
return ret
return wrapped
@@ -248,6 +252,13 @@ class Command(_BaseCommand):
finally:
self._buckets = CooldownMapping(cooldown)
try:
max_concurrency = func.__commands_max_concurrency__
except AttributeError:
max_concurrency = kwargs.get('max_concurrency')
finally:
self._max_concurrency = max_concurrency
self.ignore_extra = kwargs.get('ignore_extra', True)
self.cooldown_after_parsing = kwargs.get('cooldown_after_parsing', False)
self.cog = None
@@ -331,6 +342,9 @@ class Command(_BaseCommand):
other.checks = self.checks.copy()
if self._buckets.valid and not other._buckets.valid:
other._buckets = self._buckets.copy()
if self._max_concurrency != other._max_concurrency:
other._max_concurrency = self._max_concurrency.copy()
try:
other.on_error = self.on_error
except AttributeError:
@@ -718,6 +732,9 @@ class Command(_BaseCommand):
self._prepare_cooldowns(ctx)
await self._parse_arguments(ctx)
if self._max_concurrency is not None:
await self._max_concurrency.acquire(ctx)
await self.call_before_hooks(ctx)
def is_on_cooldown(self, ctx):
@@ -1800,3 +1817,36 @@ def cooldown(rate, per, type=BucketType.default):
func.__commands_cooldown__ = Cooldown(rate, per, type)
return func
return decorator
def max_concurrency(number, per=BucketType.default, *, wait=False):
"""A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses.
This enables you to only allow a certain number of command invocations at the same time,
for example if a command takes too long or if only one user can use it at a time. This
differs from a cooldown in that there is no set waiting period or token bucket -- only
a set number of people can run the command.
.. versionadded:: 1.3.0
Parameters
-------------
number: :class:`int`
The maximum number of invocations of this command that can be running at the same time.
per: :class:`.BucketType`
The bucket that this concurrency is based on, e.g. ``BucketType.guild`` would allow
it to be used up to ``number`` times per guild.
wait: :class:`bool`
Whether the command should wait for the queue to be over. If this is set to ``False``
then instead of waiting until the command can run again, the command raises
:exc:`.MaxConcurrencyReached` to its error handler. If this is set to ``True``
then the command waits until it can be executed.
"""
def decorator(func):
value = MaxConcurrency(number, per=per, wait=wait)
if isinstance(func, Command):
func._max_concurrency = value
else:
func.__commands_max_concurrency__ = value
return func
return decorator