[commands] Add max_concurrency decorator
This commit is contained in:
		@@ -26,13 +26,17 @@ DEALINGS IN THE SOFTWARE.
 | 
			
		||||
 | 
			
		||||
from discord.enums import Enum
 | 
			
		||||
import time
 | 
			
		||||
import asyncio
 | 
			
		||||
from collections import deque
 | 
			
		||||
 | 
			
		||||
from ...abc import PrivateChannel
 | 
			
		||||
from .errors import MaxConcurrencyReached
 | 
			
		||||
 | 
			
		||||
__all__ = (
 | 
			
		||||
    'BucketType',
 | 
			
		||||
    'Cooldown',
 | 
			
		||||
    'CooldownMapping',
 | 
			
		||||
    'MaxConcurrency',
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
class BucketType(Enum):
 | 
			
		||||
@@ -163,3 +167,129 @@ class CooldownMapping:
 | 
			
		||||
    def update_rate_limit(self, message, current=None):
 | 
			
		||||
        bucket = self.get_bucket(message, current)
 | 
			
		||||
        return bucket.update_rate_limit(current)
 | 
			
		||||
 | 
			
		||||
class _Semaphore:
 | 
			
		||||
    """This class is a version of a semaphore.
 | 
			
		||||
 | 
			
		||||
    If you're wondering why asyncio.Semaphore isn't being used,
 | 
			
		||||
    it's because it doesn't expose the internal value. This internal
 | 
			
		||||
    value is necessary because I need to support both `wait=True` and
 | 
			
		||||
    `wait=False`.
 | 
			
		||||
 | 
			
		||||
    An asyncio.Queue could have been used to do this as well -- but it
 | 
			
		||||
    not as inefficient since internally that uses two queues and is a bit
 | 
			
		||||
    overkill for what is basically a counter.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    __slots__ = ('value', 'loop', '_waiters')
 | 
			
		||||
 | 
			
		||||
    def __init__(self, number):
 | 
			
		||||
        self.value = number
 | 
			
		||||
        self.loop = asyncio.get_event_loop()
 | 
			
		||||
        self._waiters = deque()
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return '<_Semaphore value={0.value} waiters={1}>'.format(self, len(self._waiters))
 | 
			
		||||
 | 
			
		||||
    def locked(self):
 | 
			
		||||
        return self.value == 0
 | 
			
		||||
 | 
			
		||||
    def wake_up(self):
 | 
			
		||||
        while self._waiters:
 | 
			
		||||
            future = self._waiters.popleft()
 | 
			
		||||
            if not future.done():
 | 
			
		||||
                future.set_result(None)
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
    async def acquire(self, *, wait=False):
 | 
			
		||||
        if not wait and self.value <= 0:
 | 
			
		||||
            # signal that we're not acquiring
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        while self.value <= 0:
 | 
			
		||||
            future = self.loop.create_future()
 | 
			
		||||
            self._waiters.append(future)
 | 
			
		||||
            try:
 | 
			
		||||
                await future
 | 
			
		||||
            except:
 | 
			
		||||
                future.cancel()
 | 
			
		||||
                if self.value > 0 and not future.cancelled():
 | 
			
		||||
                    self.wake_up()
 | 
			
		||||
                raise
 | 
			
		||||
 | 
			
		||||
        self.value -= 1
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def release(self):
 | 
			
		||||
        self.value += 1
 | 
			
		||||
        self.wake_up()
 | 
			
		||||
 | 
			
		||||
class MaxConcurrency:
 | 
			
		||||
    __slots__ = ('number', 'per', 'wait', '_mapping')
 | 
			
		||||
 | 
			
		||||
    def __init__(self, number, *, per, wait):
 | 
			
		||||
        self._mapping = {}
 | 
			
		||||
        self.per = per
 | 
			
		||||
        self.number = number
 | 
			
		||||
        self.wait = wait
 | 
			
		||||
 | 
			
		||||
        if number <= 0:
 | 
			
		||||
            raise ValueError('max_concurrency \'number\' cannot be less than 1')
 | 
			
		||||
 | 
			
		||||
        if not isinstance(per, BucketType):
 | 
			
		||||
            raise TypeError('max_concurrency \'per\' must be of type BucketType not %r' % type(per))
 | 
			
		||||
 | 
			
		||||
    def copy(self):
 | 
			
		||||
        return self.__class__(self.number, per=self.per, wait=self.wait)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return '<MaxConcurrency per={0.per!r} number={0.number} wait={0.wait}>'.format(self)
 | 
			
		||||
 | 
			
		||||
    def get_bucket(self, message):
 | 
			
		||||
        bucket_type = self.per
 | 
			
		||||
        if bucket_type is BucketType.default:
 | 
			
		||||
            return 'global'
 | 
			
		||||
        elif bucket_type is BucketType.user:
 | 
			
		||||
            return message.author.id
 | 
			
		||||
        elif bucket_type is BucketType.guild:
 | 
			
		||||
            return (message.guild or message.author).id
 | 
			
		||||
        elif bucket_type is BucketType.channel:
 | 
			
		||||
            return message.channel.id
 | 
			
		||||
        elif bucket_type is BucketType.member:
 | 
			
		||||
            return ((message.guild and message.guild.id), message.author.id)
 | 
			
		||||
        elif bucket_type is BucketType.category:
 | 
			
		||||
            return (message.channel.category or message.channel).id
 | 
			
		||||
        elif bucket_type is BucketType.role:
 | 
			
		||||
            # we return the channel id of a private-channel as there are only roles in guilds
 | 
			
		||||
            # and that yields the same result as for a guild with only the @everyone role
 | 
			
		||||
            # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
 | 
			
		||||
            # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
 | 
			
		||||
            return (message.channel if isinstance(message.channel, PrivateChannel) else message.author.top_role).id
 | 
			
		||||
 | 
			
		||||
    async def acquire(self, message):
 | 
			
		||||
        key = self.get_bucket(message)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            sem = self._mapping[key]
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            self._mapping[key] = sem = _Semaphore(self.number)
 | 
			
		||||
 | 
			
		||||
        acquired = await sem.acquire(wait=self.wait)
 | 
			
		||||
        if not acquired:
 | 
			
		||||
            raise MaxConcurrencyReached(self.number, self.per)
 | 
			
		||||
 | 
			
		||||
    async def release(self, message):
 | 
			
		||||
        # Technically there's no reason for this function to be async
 | 
			
		||||
        # But it might be more useful in the future
 | 
			
		||||
        key = self.get_bucket(message)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            sem = self._mapping[key]
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            # ...? peculiar
 | 
			
		||||
            return
 | 
			
		||||
        else:
 | 
			
		||||
            sem.release()
 | 
			
		||||
 | 
			
		||||
        if sem.value >= self.number:
 | 
			
		||||
            del self._mapping[key]
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user