mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-09-04 17:06:21 +00:00
[commands] Add max_concurrency decorator
This commit is contained in:
@ -26,13 +26,17 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from discord.enums import Enum
|
||||
import time
|
||||
import asyncio
|
||||
from collections import deque
|
||||
|
||||
from ...abc import PrivateChannel
|
||||
from .errors import MaxConcurrencyReached
|
||||
|
||||
__all__ = (
|
||||
'BucketType',
|
||||
'Cooldown',
|
||||
'CooldownMapping',
|
||||
'MaxConcurrency',
|
||||
)
|
||||
|
||||
class BucketType(Enum):
|
||||
@ -163,3 +167,129 @@ class CooldownMapping:
|
||||
def update_rate_limit(self, message, current=None):
|
||||
bucket = self.get_bucket(message, current)
|
||||
return bucket.update_rate_limit(current)
|
||||
|
||||
class _Semaphore:
|
||||
"""This class is a version of a semaphore.
|
||||
|
||||
If you're wondering why asyncio.Semaphore isn't being used,
|
||||
it's because it doesn't expose the internal value. This internal
|
||||
value is necessary because I need to support both `wait=True` and
|
||||
`wait=False`.
|
||||
|
||||
An asyncio.Queue could have been used to do this as well -- but it
|
||||
not as inefficient since internally that uses two queues and is a bit
|
||||
overkill for what is basically a counter.
|
||||
"""
|
||||
|
||||
__slots__ = ('value', 'loop', '_waiters')
|
||||
|
||||
def __init__(self, number):
|
||||
self.value = number
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self._waiters = deque()
|
||||
|
||||
def __repr__(self):
|
||||
return '<_Semaphore value={0.value} waiters={1}>'.format(self, len(self._waiters))
|
||||
|
||||
def locked(self):
|
||||
return self.value == 0
|
||||
|
||||
def wake_up(self):
|
||||
while self._waiters:
|
||||
future = self._waiters.popleft()
|
||||
if not future.done():
|
||||
future.set_result(None)
|
||||
return
|
||||
|
||||
async def acquire(self, *, wait=False):
|
||||
if not wait and self.value <= 0:
|
||||
# signal that we're not acquiring
|
||||
return False
|
||||
|
||||
while self.value <= 0:
|
||||
future = self.loop.create_future()
|
||||
self._waiters.append(future)
|
||||
try:
|
||||
await future
|
||||
except:
|
||||
future.cancel()
|
||||
if self.value > 0 and not future.cancelled():
|
||||
self.wake_up()
|
||||
raise
|
||||
|
||||
self.value -= 1
|
||||
return True
|
||||
|
||||
def release(self):
|
||||
self.value += 1
|
||||
self.wake_up()
|
||||
|
||||
class MaxConcurrency:
|
||||
__slots__ = ('number', 'per', 'wait', '_mapping')
|
||||
|
||||
def __init__(self, number, *, per, wait):
|
||||
self._mapping = {}
|
||||
self.per = per
|
||||
self.number = number
|
||||
self.wait = wait
|
||||
|
||||
if number <= 0:
|
||||
raise ValueError('max_concurrency \'number\' cannot be less than 1')
|
||||
|
||||
if not isinstance(per, BucketType):
|
||||
raise TypeError('max_concurrency \'per\' must be of type BucketType not %r' % type(per))
|
||||
|
||||
def copy(self):
|
||||
return self.__class__(self.number, per=self.per, wait=self.wait)
|
||||
|
||||
def __repr__(self):
|
||||
return '<MaxConcurrency per={0.per!r} number={0.number} wait={0.wait}>'.format(self)
|
||||
|
||||
def get_bucket(self, message):
|
||||
bucket_type = self.per
|
||||
if bucket_type is BucketType.default:
|
||||
return 'global'
|
||||
elif bucket_type is BucketType.user:
|
||||
return message.author.id
|
||||
elif bucket_type is BucketType.guild:
|
||||
return (message.guild or message.author).id
|
||||
elif bucket_type is BucketType.channel:
|
||||
return message.channel.id
|
||||
elif bucket_type is BucketType.member:
|
||||
return ((message.guild and message.guild.id), message.author.id)
|
||||
elif bucket_type is BucketType.category:
|
||||
return (message.channel.category or message.channel).id
|
||||
elif bucket_type is BucketType.role:
|
||||
# we return the channel id of a private-channel as there are only roles in guilds
|
||||
# and that yields the same result as for a guild with only the @everyone role
|
||||
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
|
||||
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
|
||||
return (message.channel if isinstance(message.channel, PrivateChannel) else message.author.top_role).id
|
||||
|
||||
async def acquire(self, message):
|
||||
key = self.get_bucket(message)
|
||||
|
||||
try:
|
||||
sem = self._mapping[key]
|
||||
except KeyError:
|
||||
self._mapping[key] = sem = _Semaphore(self.number)
|
||||
|
||||
acquired = await sem.acquire(wait=self.wait)
|
||||
if not acquired:
|
||||
raise MaxConcurrencyReached(self.number, self.per)
|
||||
|
||||
async def release(self, message):
|
||||
# Technically there's no reason for this function to be async
|
||||
# But it might be more useful in the future
|
||||
key = self.get_bucket(message)
|
||||
|
||||
try:
|
||||
sem = self._mapping[key]
|
||||
except KeyError:
|
||||
# ...? peculiar
|
||||
return
|
||||
else:
|
||||
sem.release()
|
||||
|
||||
if sem.value >= self.number:
|
||||
del self._mapping[key]
|
||||
|
Reference in New Issue
Block a user