diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index fe763fb62..9efb51041 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -48,6 +48,25 @@ class BucketType(Enum): category = 5 role = 6 + def get_key(self, msg): + if self is BucketType.user: + return msg.author.id + elif self is BucketType.guild: + return (msg.guild or msg.author).id + elif self is BucketType.channel: + return msg.channel.id + elif self is BucketType.member: + return ((msg.guild and msg.guild.id), msg.author.id) + elif self is BucketType.category: + return (msg.channel.category or msg.channel).id + elif self is BucketType.role: + # we return the channel id of a private-channel as there are only roles in guilds + # and that yields the same result as for a guild with only the @everyone role + # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are + # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do + return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id + + class Cooldown: __slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last') @@ -123,23 +142,7 @@ class CooldownMapping: return cls(Cooldown(rate, per, type)) def _bucket_key(self, msg): - bucket_type = self._cooldown.type - if bucket_type is BucketType.user: - return msg.author.id - elif bucket_type is BucketType.guild: - return (msg.guild or msg.author).id - elif bucket_type is BucketType.channel: - return msg.channel.id - elif bucket_type is BucketType.member: - return ((msg.guild and msg.guild.id), msg.author.id) - elif bucket_type is BucketType.category: - return (msg.channel.category or msg.channel).id - elif bucket_type is BucketType.role: - # we return the channel id of a private-channel as there are only roles in guilds - # and that yields the same result as for a guild with only the @everyone role - # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are - # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do - return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id + return self._cooldown.type.get_key(msg) def _verify_cache_integrity(self, current=None): # we want to delete all cache objects that haven't been used @@ -245,29 +248,11 @@ class MaxConcurrency: def __repr__(self): return ''.format(self) - def get_bucket(self, message): - bucket_type = self.per - if bucket_type is BucketType.default: - return 'global' - elif bucket_type is BucketType.user: - return message.author.id - elif bucket_type is BucketType.guild: - return (message.guild or message.author).id - elif bucket_type is BucketType.channel: - return message.channel.id - elif bucket_type is BucketType.member: - return ((message.guild and message.guild.id), message.author.id) - elif bucket_type is BucketType.category: - return (message.channel.category or message.channel).id - elif bucket_type is BucketType.role: - # we return the channel id of a private-channel as there are only roles in guilds - # and that yields the same result as for a guild with only the @everyone role - # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are - # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do - return (message.channel if isinstance(message.channel, PrivateChannel) else message.author.top_role).id + def get_key(self, message): + return self.per.get_key(message) async def acquire(self, message): - key = self.get_bucket(message) + key = self.get_key(message) try: sem = self._mapping[key] @@ -281,7 +266,7 @@ class MaxConcurrency: async def release(self, message): # Technically there's no reason for this function to be async # But it might be more useful in the future - key = self.get_bucket(message) + key = self.get_key(message) try: sem = self._mapping[key]