[commands] Add max_concurrency decorator
This commit is contained in:
parent
3149f15165
commit
bf84c63396
@ -26,13 +26,17 @@ DEALINGS IN THE SOFTWARE.
|
|||||||
|
|
||||||
from discord.enums import Enum
|
from discord.enums import Enum
|
||||||
import time
|
import time
|
||||||
|
import asyncio
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
from ...abc import PrivateChannel
|
from ...abc import PrivateChannel
|
||||||
|
from .errors import MaxConcurrencyReached
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
'BucketType',
|
'BucketType',
|
||||||
'Cooldown',
|
'Cooldown',
|
||||||
'CooldownMapping',
|
'CooldownMapping',
|
||||||
|
'MaxConcurrency',
|
||||||
)
|
)
|
||||||
|
|
||||||
class BucketType(Enum):
|
class BucketType(Enum):
|
||||||
@ -163,3 +167,129 @@ class CooldownMapping:
|
|||||||
def update_rate_limit(self, message, current=None):
|
def update_rate_limit(self, message, current=None):
|
||||||
bucket = self.get_bucket(message, current)
|
bucket = self.get_bucket(message, current)
|
||||||
return bucket.update_rate_limit(current)
|
return bucket.update_rate_limit(current)
|
||||||
|
|
||||||
|
class _Semaphore:
|
||||||
|
"""This class is a version of a semaphore.
|
||||||
|
|
||||||
|
If you're wondering why asyncio.Semaphore isn't being used,
|
||||||
|
it's because it doesn't expose the internal value. This internal
|
||||||
|
value is necessary because I need to support both `wait=True` and
|
||||||
|
`wait=False`.
|
||||||
|
|
||||||
|
An asyncio.Queue could have been used to do this as well -- but it
|
||||||
|
not as inefficient since internally that uses two queues and is a bit
|
||||||
|
overkill for what is basically a counter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ('value', 'loop', '_waiters')
|
||||||
|
|
||||||
|
def __init__(self, number):
|
||||||
|
self.value = number
|
||||||
|
self.loop = asyncio.get_event_loop()
|
||||||
|
self._waiters = deque()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<_Semaphore value={0.value} waiters={1}>'.format(self, len(self._waiters))
|
||||||
|
|
||||||
|
def locked(self):
|
||||||
|
return self.value == 0
|
||||||
|
|
||||||
|
def wake_up(self):
|
||||||
|
while self._waiters:
|
||||||
|
future = self._waiters.popleft()
|
||||||
|
if not future.done():
|
||||||
|
future.set_result(None)
|
||||||
|
return
|
||||||
|
|
||||||
|
async def acquire(self, *, wait=False):
|
||||||
|
if not wait and self.value <= 0:
|
||||||
|
# signal that we're not acquiring
|
||||||
|
return False
|
||||||
|
|
||||||
|
while self.value <= 0:
|
||||||
|
future = self.loop.create_future()
|
||||||
|
self._waiters.append(future)
|
||||||
|
try:
|
||||||
|
await future
|
||||||
|
except:
|
||||||
|
future.cancel()
|
||||||
|
if self.value > 0 and not future.cancelled():
|
||||||
|
self.wake_up()
|
||||||
|
raise
|
||||||
|
|
||||||
|
self.value -= 1
|
||||||
|
return True
|
||||||
|
|
||||||
|
def release(self):
|
||||||
|
self.value += 1
|
||||||
|
self.wake_up()
|
||||||
|
|
||||||
|
class MaxConcurrency:
|
||||||
|
__slots__ = ('number', 'per', 'wait', '_mapping')
|
||||||
|
|
||||||
|
def __init__(self, number, *, per, wait):
|
||||||
|
self._mapping = {}
|
||||||
|
self.per = per
|
||||||
|
self.number = number
|
||||||
|
self.wait = wait
|
||||||
|
|
||||||
|
if number <= 0:
|
||||||
|
raise ValueError('max_concurrency \'number\' cannot be less than 1')
|
||||||
|
|
||||||
|
if not isinstance(per, BucketType):
|
||||||
|
raise TypeError('max_concurrency \'per\' must be of type BucketType not %r' % type(per))
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
return self.__class__(self.number, per=self.per, wait=self.wait)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<MaxConcurrency per={0.per!r} number={0.number} wait={0.wait}>'.format(self)
|
||||||
|
|
||||||
|
def get_bucket(self, message):
|
||||||
|
bucket_type = self.per
|
||||||
|
if bucket_type is BucketType.default:
|
||||||
|
return 'global'
|
||||||
|
elif bucket_type is BucketType.user:
|
||||||
|
return message.author.id
|
||||||
|
elif bucket_type is BucketType.guild:
|
||||||
|
return (message.guild or message.author).id
|
||||||
|
elif bucket_type is BucketType.channel:
|
||||||
|
return message.channel.id
|
||||||
|
elif bucket_type is BucketType.member:
|
||||||
|
return ((message.guild and message.guild.id), message.author.id)
|
||||||
|
elif bucket_type is BucketType.category:
|
||||||
|
return (message.channel.category or message.channel).id
|
||||||
|
elif bucket_type is BucketType.role:
|
||||||
|
# we return the channel id of a private-channel as there are only roles in guilds
|
||||||
|
# and that yields the same result as for a guild with only the @everyone role
|
||||||
|
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
|
||||||
|
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
|
||||||
|
return (message.channel if isinstance(message.channel, PrivateChannel) else message.author.top_role).id
|
||||||
|
|
||||||
|
async def acquire(self, message):
|
||||||
|
key = self.get_bucket(message)
|
||||||
|
|
||||||
|
try:
|
||||||
|
sem = self._mapping[key]
|
||||||
|
except KeyError:
|
||||||
|
self._mapping[key] = sem = _Semaphore(self.number)
|
||||||
|
|
||||||
|
acquired = await sem.acquire(wait=self.wait)
|
||||||
|
if not acquired:
|
||||||
|
raise MaxConcurrencyReached(self.number, self.per)
|
||||||
|
|
||||||
|
async def release(self, message):
|
||||||
|
# Technically there's no reason for this function to be async
|
||||||
|
# But it might be more useful in the future
|
||||||
|
key = self.get_bucket(message)
|
||||||
|
|
||||||
|
try:
|
||||||
|
sem = self._mapping[key]
|
||||||
|
except KeyError:
|
||||||
|
# ...? peculiar
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
sem.release()
|
||||||
|
|
||||||
|
if sem.value >= self.number:
|
||||||
|
del self._mapping[key]
|
||||||
|
@ -33,7 +33,7 @@ import datetime
|
|||||||
import discord
|
import discord
|
||||||
|
|
||||||
from .errors import *
|
from .errors import *
|
||||||
from .cooldowns import Cooldown, BucketType, CooldownMapping
|
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency
|
||||||
from . import converter as converters
|
from . import converter as converters
|
||||||
from ._types import _BaseCommand
|
from ._types import _BaseCommand
|
||||||
from .cog import Cog
|
from .cog import Cog
|
||||||
@ -53,6 +53,7 @@ __all__ = (
|
|||||||
'bot_has_permissions',
|
'bot_has_permissions',
|
||||||
'bot_has_any_role',
|
'bot_has_any_role',
|
||||||
'cooldown',
|
'cooldown',
|
||||||
|
'max_concurrency',
|
||||||
'dm_only',
|
'dm_only',
|
||||||
'guild_only',
|
'guild_only',
|
||||||
'is_owner',
|
'is_owner',
|
||||||
@ -90,6 +91,9 @@ def hooked_wrapped_callback(command, ctx, coro):
|
|||||||
ctx.command_failed = True
|
ctx.command_failed = True
|
||||||
raise CommandInvokeError(exc) from exc
|
raise CommandInvokeError(exc) from exc
|
||||||
finally:
|
finally:
|
||||||
|
if command._max_concurrency is not None:
|
||||||
|
await command._max_concurrency.release(ctx)
|
||||||
|
|
||||||
await command.call_after_hooks(ctx)
|
await command.call_after_hooks(ctx)
|
||||||
return ret
|
return ret
|
||||||
return wrapped
|
return wrapped
|
||||||
@ -248,6 +252,13 @@ class Command(_BaseCommand):
|
|||||||
finally:
|
finally:
|
||||||
self._buckets = CooldownMapping(cooldown)
|
self._buckets = CooldownMapping(cooldown)
|
||||||
|
|
||||||
|
try:
|
||||||
|
max_concurrency = func.__commands_max_concurrency__
|
||||||
|
except AttributeError:
|
||||||
|
max_concurrency = kwargs.get('max_concurrency')
|
||||||
|
finally:
|
||||||
|
self._max_concurrency = max_concurrency
|
||||||
|
|
||||||
self.ignore_extra = kwargs.get('ignore_extra', True)
|
self.ignore_extra = kwargs.get('ignore_extra', True)
|
||||||
self.cooldown_after_parsing = kwargs.get('cooldown_after_parsing', False)
|
self.cooldown_after_parsing = kwargs.get('cooldown_after_parsing', False)
|
||||||
self.cog = None
|
self.cog = None
|
||||||
@ -331,6 +342,9 @@ class Command(_BaseCommand):
|
|||||||
other.checks = self.checks.copy()
|
other.checks = self.checks.copy()
|
||||||
if self._buckets.valid and not other._buckets.valid:
|
if self._buckets.valid and not other._buckets.valid:
|
||||||
other._buckets = self._buckets.copy()
|
other._buckets = self._buckets.copy()
|
||||||
|
if self._max_concurrency != other._max_concurrency:
|
||||||
|
other._max_concurrency = self._max_concurrency.copy()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
other.on_error = self.on_error
|
other.on_error = self.on_error
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
@ -718,6 +732,9 @@ class Command(_BaseCommand):
|
|||||||
self._prepare_cooldowns(ctx)
|
self._prepare_cooldowns(ctx)
|
||||||
await self._parse_arguments(ctx)
|
await self._parse_arguments(ctx)
|
||||||
|
|
||||||
|
if self._max_concurrency is not None:
|
||||||
|
await self._max_concurrency.acquire(ctx)
|
||||||
|
|
||||||
await self.call_before_hooks(ctx)
|
await self.call_before_hooks(ctx)
|
||||||
|
|
||||||
def is_on_cooldown(self, ctx):
|
def is_on_cooldown(self, ctx):
|
||||||
@ -1800,3 +1817,36 @@ def cooldown(rate, per, type=BucketType.default):
|
|||||||
func.__commands_cooldown__ = Cooldown(rate, per, type)
|
func.__commands_cooldown__ = Cooldown(rate, per, type)
|
||||||
return func
|
return func
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def max_concurrency(number, per=BucketType.default, *, wait=False):
|
||||||
|
"""A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses.
|
||||||
|
|
||||||
|
This enables you to only allow a certain number of command invocations at the same time,
|
||||||
|
for example if a command takes too long or if only one user can use it at a time. This
|
||||||
|
differs from a cooldown in that there is no set waiting period or token bucket -- only
|
||||||
|
a set number of people can run the command.
|
||||||
|
|
||||||
|
.. versionadded:: 1.3.0
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-------------
|
||||||
|
number: :class:`int`
|
||||||
|
The maximum number of invocations of this command that can be running at the same time.
|
||||||
|
per: :class:`.BucketType`
|
||||||
|
The bucket that this concurrency is based on, e.g. ``BucketType.guild`` would allow
|
||||||
|
it to be used up to ``number`` times per guild.
|
||||||
|
wait: :class:`bool`
|
||||||
|
Whether the command should wait for the queue to be over. If this is set to ``False``
|
||||||
|
then instead of waiting until the command can run again, the command raises
|
||||||
|
:exc:`.MaxConcurrencyReached` to its error handler. If this is set to ``True``
|
||||||
|
then the command waits until it can be executed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
value = MaxConcurrency(number, per=per, wait=wait)
|
||||||
|
if isinstance(func, Command):
|
||||||
|
func._max_concurrency = value
|
||||||
|
else:
|
||||||
|
func.__commands_max_concurrency__ = value
|
||||||
|
return func
|
||||||
|
return decorator
|
||||||
|
@ -41,6 +41,7 @@ __all__ = (
|
|||||||
'TooManyArguments',
|
'TooManyArguments',
|
||||||
'UserInputError',
|
'UserInputError',
|
||||||
'CommandOnCooldown',
|
'CommandOnCooldown',
|
||||||
|
'MaxConcurrencyReached',
|
||||||
'NotOwner',
|
'NotOwner',
|
||||||
'MissingRole',
|
'MissingRole',
|
||||||
'BotMissingRole',
|
'BotMissingRole',
|
||||||
@ -240,6 +241,28 @@ class CommandOnCooldown(CommandError):
|
|||||||
self.retry_after = retry_after
|
self.retry_after = retry_after
|
||||||
super().__init__('You are on cooldown. Try again in {:.2f}s'.format(retry_after))
|
super().__init__('You are on cooldown. Try again in {:.2f}s'.format(retry_after))
|
||||||
|
|
||||||
|
class MaxConcurrencyReached(CommandError):
|
||||||
|
"""Exception raised when the command being invoked has reached its maximum concurrency.
|
||||||
|
|
||||||
|
This inherits from :exc:`CommandError`.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
------------
|
||||||
|
number: :class:`int`
|
||||||
|
The maximum number of concurrent invokers allowed.
|
||||||
|
per: :class:`BucketType`
|
||||||
|
The bucket type passed to the :func:`.max_concurrency` decorator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, number, per):
|
||||||
|
self.number = number
|
||||||
|
self.per = per
|
||||||
|
name = per.name
|
||||||
|
suffix = 'per %s' % name if per.name != 'default' else 'globally'
|
||||||
|
plural = '%s times %s' if number > 1 else '%s time %s'
|
||||||
|
fmt = plural % (number, suffix)
|
||||||
|
super().__init__('Too many people using this command. It can only be used {}.'.format(fmt))
|
||||||
|
|
||||||
class MissingRole(CheckFailure):
|
class MissingRole(CheckFailure):
|
||||||
"""Exception raised when the command invoker lacks a role to run a command.
|
"""Exception raised when the command invoker lacks a role to run a command.
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user