mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-10-24 10:02:56 +00:00
[commands] Refactor BucketType to not repeat in other places in code
This commit is contained in:
@@ -48,6 +48,25 @@ class BucketType(Enum):
|
||||
category = 5
|
||||
role = 6
|
||||
|
||||
def get_key(self, msg):
|
||||
if self is BucketType.user:
|
||||
return msg.author.id
|
||||
elif self is BucketType.guild:
|
||||
return (msg.guild or msg.author).id
|
||||
elif self is BucketType.channel:
|
||||
return msg.channel.id
|
||||
elif self is BucketType.member:
|
||||
return ((msg.guild and msg.guild.id), msg.author.id)
|
||||
elif self is BucketType.category:
|
||||
return (msg.channel.category or msg.channel).id
|
||||
elif self is BucketType.role:
|
||||
# we return the channel id of a private-channel as there are only roles in guilds
|
||||
# and that yields the same result as for a guild with only the @everyone role
|
||||
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
|
||||
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
|
||||
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id
|
||||
|
||||
|
||||
class Cooldown:
|
||||
__slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last')
|
||||
|
||||
@@ -123,23 +142,7 @@ class CooldownMapping:
|
||||
return cls(Cooldown(rate, per, type))
|
||||
|
||||
def _bucket_key(self, msg):
|
||||
bucket_type = self._cooldown.type
|
||||
if bucket_type is BucketType.user:
|
||||
return msg.author.id
|
||||
elif bucket_type is BucketType.guild:
|
||||
return (msg.guild or msg.author).id
|
||||
elif bucket_type is BucketType.channel:
|
||||
return msg.channel.id
|
||||
elif bucket_type is BucketType.member:
|
||||
return ((msg.guild and msg.guild.id), msg.author.id)
|
||||
elif bucket_type is BucketType.category:
|
||||
return (msg.channel.category or msg.channel).id
|
||||
elif bucket_type is BucketType.role:
|
||||
# we return the channel id of a private-channel as there are only roles in guilds
|
||||
# and that yields the same result as for a guild with only the @everyone role
|
||||
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
|
||||
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
|
||||
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id
|
||||
return self._cooldown.type.get_key(msg)
|
||||
|
||||
def _verify_cache_integrity(self, current=None):
|
||||
# we want to delete all cache objects that haven't been used
|
||||
@@ -245,29 +248,11 @@ class MaxConcurrency:
|
||||
def __repr__(self):
|
||||
return '<MaxConcurrency per={0.per!r} number={0.number} wait={0.wait}>'.format(self)
|
||||
|
||||
def get_bucket(self, message):
|
||||
bucket_type = self.per
|
||||
if bucket_type is BucketType.default:
|
||||
return 'global'
|
||||
elif bucket_type is BucketType.user:
|
||||
return message.author.id
|
||||
elif bucket_type is BucketType.guild:
|
||||
return (message.guild or message.author).id
|
||||
elif bucket_type is BucketType.channel:
|
||||
return message.channel.id
|
||||
elif bucket_type is BucketType.member:
|
||||
return ((message.guild and message.guild.id), message.author.id)
|
||||
elif bucket_type is BucketType.category:
|
||||
return (message.channel.category or message.channel).id
|
||||
elif bucket_type is BucketType.role:
|
||||
# we return the channel id of a private-channel as there are only roles in guilds
|
||||
# and that yields the same result as for a guild with only the @everyone role
|
||||
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
|
||||
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
|
||||
return (message.channel if isinstance(message.channel, PrivateChannel) else message.author.top_role).id
|
||||
def get_key(self, message):
|
||||
return self.per.get_key(message)
|
||||
|
||||
async def acquire(self, message):
|
||||
key = self.get_bucket(message)
|
||||
key = self.get_key(message)
|
||||
|
||||
try:
|
||||
sem = self._mapping[key]
|
||||
@@ -281,7 +266,7 @@ class MaxConcurrency:
|
||||
async def release(self, message):
|
||||
# Technically there's no reason for this function to be async
|
||||
# But it might be more useful in the future
|
||||
key = self.get_bucket(message)
|
||||
key = self.get_key(message)
|
||||
|
||||
try:
|
||||
sem = self._mapping[key]
|
||||
|
Reference in New Issue
Block a user