[commands] Refactor BucketType to not repeat in other places in code
This commit is contained in:
		@@ -48,6 +48,25 @@ class BucketType(Enum):
 | 
			
		||||
    category = 5
 | 
			
		||||
    role     = 6
 | 
			
		||||
 | 
			
		||||
    def get_key(self, msg):
 | 
			
		||||
        if self is BucketType.user:
 | 
			
		||||
            return msg.author.id
 | 
			
		||||
        elif self is BucketType.guild:
 | 
			
		||||
            return (msg.guild or msg.author).id
 | 
			
		||||
        elif self is BucketType.channel:
 | 
			
		||||
            return msg.channel.id
 | 
			
		||||
        elif self is BucketType.member:
 | 
			
		||||
            return ((msg.guild and msg.guild.id), msg.author.id)
 | 
			
		||||
        elif self is BucketType.category:
 | 
			
		||||
            return (msg.channel.category or msg.channel).id
 | 
			
		||||
        elif self is BucketType.role:
 | 
			
		||||
            # we return the channel id of a private-channel as there are only roles in guilds
 | 
			
		||||
            # and that yields the same result as for a guild with only the @everyone role
 | 
			
		||||
            # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
 | 
			
		||||
            # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
 | 
			
		||||
            return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Cooldown:
 | 
			
		||||
    __slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last')
 | 
			
		||||
 | 
			
		||||
@@ -123,23 +142,7 @@ class CooldownMapping:
 | 
			
		||||
        return cls(Cooldown(rate, per, type))
 | 
			
		||||
 | 
			
		||||
    def _bucket_key(self, msg):
 | 
			
		||||
        bucket_type = self._cooldown.type
 | 
			
		||||
        if bucket_type is BucketType.user:
 | 
			
		||||
            return msg.author.id
 | 
			
		||||
        elif bucket_type is BucketType.guild:
 | 
			
		||||
            return (msg.guild or msg.author).id
 | 
			
		||||
        elif bucket_type is BucketType.channel:
 | 
			
		||||
            return msg.channel.id
 | 
			
		||||
        elif bucket_type is BucketType.member:
 | 
			
		||||
            return ((msg.guild and msg.guild.id), msg.author.id)
 | 
			
		||||
        elif bucket_type is BucketType.category:
 | 
			
		||||
            return (msg.channel.category or msg.channel).id
 | 
			
		||||
        elif bucket_type is BucketType.role:
 | 
			
		||||
            # we return the channel id of a private-channel as there are only roles in guilds
 | 
			
		||||
            # and that yields the same result as for a guild with only the @everyone role
 | 
			
		||||
            # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
 | 
			
		||||
            # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
 | 
			
		||||
            return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id
 | 
			
		||||
        return self._cooldown.type.get_key(msg)
 | 
			
		||||
 | 
			
		||||
    def _verify_cache_integrity(self, current=None):
 | 
			
		||||
        # we want to delete all cache objects that haven't been used
 | 
			
		||||
@@ -245,29 +248,11 @@ class MaxConcurrency:
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return '<MaxConcurrency per={0.per!r} number={0.number} wait={0.wait}>'.format(self)
 | 
			
		||||
 | 
			
		||||
    def get_bucket(self, message):
 | 
			
		||||
        bucket_type = self.per
 | 
			
		||||
        if bucket_type is BucketType.default:
 | 
			
		||||
            return 'global'
 | 
			
		||||
        elif bucket_type is BucketType.user:
 | 
			
		||||
            return message.author.id
 | 
			
		||||
        elif bucket_type is BucketType.guild:
 | 
			
		||||
            return (message.guild or message.author).id
 | 
			
		||||
        elif bucket_type is BucketType.channel:
 | 
			
		||||
            return message.channel.id
 | 
			
		||||
        elif bucket_type is BucketType.member:
 | 
			
		||||
            return ((message.guild and message.guild.id), message.author.id)
 | 
			
		||||
        elif bucket_type is BucketType.category:
 | 
			
		||||
            return (message.channel.category or message.channel).id
 | 
			
		||||
        elif bucket_type is BucketType.role:
 | 
			
		||||
            # we return the channel id of a private-channel as there are only roles in guilds
 | 
			
		||||
            # and that yields the same result as for a guild with only the @everyone role
 | 
			
		||||
            # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
 | 
			
		||||
            # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
 | 
			
		||||
            return (message.channel if isinstance(message.channel, PrivateChannel) else message.author.top_role).id
 | 
			
		||||
    def get_key(self, message):
 | 
			
		||||
        return self.per.get_key(message)
 | 
			
		||||
 | 
			
		||||
    async def acquire(self, message):
 | 
			
		||||
        key = self.get_bucket(message)
 | 
			
		||||
        key = self.get_key(message)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            sem = self._mapping[key]
 | 
			
		||||
@@ -281,7 +266,7 @@ class MaxConcurrency:
 | 
			
		||||
    async def release(self, message):
 | 
			
		||||
        # Technically there's no reason for this function to be async
 | 
			
		||||
        # But it might be more useful in the future
 | 
			
		||||
        key = self.get_bucket(message)
 | 
			
		||||
        key = self.get_key(message)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            sem = self._mapping[key]
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user