From 98bde70ac1be4abbbfa403989150b6180de6cdb3 Mon Sep 17 00:00:00 2001 From: Gnome Date: Mon, 6 Sep 2021 11:27:25 +0100 Subject: [PATCH] Rework how checks add attributes to Commmand --- discord/ext/commands/core.py | 91 ++++++++++++++++++++++++------------ 1 file changed, 60 insertions(+), 31 deletions(-) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index f122e9ad..a5e61043 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -278,6 +278,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): .. versionadded:: 2.0 """ __original_kwargs__: Dict[str, Any] + _max_concurrency: Optional[MaxConcurrency] def __new__(cls: Type[CommandT], *args: Any, **kwargs: Any) -> CommandT: # if you're wondering why this is done, it's because we need to ensure @@ -332,33 +333,33 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.description: str = inspect.cleandoc(kwargs.get('description', '')) self.hidden: bool = kwargs.get('hidden', False) + if hasattr(func, '__command_attrs__'): + command_attrs: Dict[str, Any] = func.__command_attrs__ + else: + command_attrs = {} + try: - checks = func.__commands_checks__ + checks = command_attrs.pop('checks') checks.reverse() - except AttributeError: + except KeyError: checks = kwargs.get('checks', []) - self.checks: List[Check] = checks - try: - cooldown = func.__commands_cooldown__ - except AttributeError: + cooldown = command_attrs.pop('cooldown') + except KeyError: cooldown = kwargs.get('cooldown') - + + if cooldown is None: buckets = CooldownMapping(cooldown, BucketType.default) elif isinstance(cooldown, CooldownMapping): buckets = cooldown else: raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") + + self.checks: List[Check] = checks self._buckets: CooldownMapping = buckets - - try: - max_concurrency = func.__commands_max_concurrency__ - except AttributeError: - max_concurrency = kwargs.get('max_concurrency') - - self._max_concurrency: Optional[MaxConcurrency] = max_concurrency + self._max_concurrency = kwargs.get('max_concurrency') self.require_var_positional: bool = kwargs.get('require_var_positional', False) self.ignore_extra: bool = kwargs.get('ignore_extra', True) @@ -371,20 +372,24 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self._before_invoke: Optional[Hook] = None try: - before_invoke = func.__before_invoke__ - except AttributeError: + before_invoke = command_attrs.pop('before_invoke') + except KeyError: pass else: self.before_invoke(before_invoke) self._after_invoke: Optional[Hook] = None try: - after_invoke = func.__after_invoke__ - except AttributeError: + after_invoke = command_attrs.pop('after_invoke') + except KeyError: pass else: self.after_invoke(after_invoke) + # Handle user provided command attrs + self._update_attrs(**command_attrs) + + @property def callback(self) -> Union[ Callable[Concatenate[CogT, Context, P], Coro[T]], @@ -408,6 +413,12 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.params = get_signature_parameters(function, globalns) + + def _update_attrs(self, **command_attrs: Any): + for key, value in command_attrs.items(): + setattr(self, key, value) + + def add_check(self, func: Check) -> None: """Adds a check to the command. @@ -1632,7 +1643,7 @@ def group( cls = Group # type: ignore return command(name=name, cls=cls, **attrs) # type: ignore -def check(predicate: Check) -> Callable[[T], T]: +def check(predicate: Check, **command_attrs: Any) -> Callable[[T], T]: r"""A decorator that adds a check to the :class:`.Command` or its subclasses. These checks could be accessed via :attr:`.Command.checks`. @@ -1701,16 +1712,22 @@ def check(predicate: Check) -> Callable[[T], T]: ----------- predicate: Callable[[:class:`Context`], :class:`bool`] The predicate to check if the command should be invoked. + **command_attrs: Dict[:class:`str`, Any] + key: value pairs to be added to the command's attributes. """ def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: if isinstance(func, Command): func.checks.append(predicate) + func._update_attrs(**command_attrs) else: if not hasattr(func, '__commands_checks__'): func.__commands_checks__ = [] + if not hasattr(func, "__command_attrs__"): + func.__command_attrs__ = {} func.__commands_checks__.append(predicate) + func.__command_attrs__.update(command_attrs) return func @@ -1875,7 +1892,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: return True raise MissingAnyRole(list(items)) - return check(predicate) + return check(predicate, required_roles=items) def bot_has_role(item: int) -> Callable[[T], T]: """Similar to :func:`.has_role` except checks if the bot itself has the @@ -1903,7 +1920,7 @@ def bot_has_role(item: int) -> Callable[[T], T]: if role is None: raise BotMissingRole(item) return True - return check(predicate) + return check(predicate, bot_required_role=item) def bot_has_any_role(*items: int) -> Callable[[T], T]: """Similar to :func:`.has_any_role` except checks if the bot itself has @@ -1927,7 +1944,7 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]: if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items): return True raise BotMissingAnyRole(list(items)) - return check(predicate) + return check(predicate, bot_required_roles=items) def has_permissions(**perms: bool) -> Callable[[T], T]: """A :func:`.check` that is added that checks if the member has all of @@ -1974,7 +1991,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]: raise MissingPermissions(missing) - return check(predicate) + return check(predicate, required_permissions=perms) def bot_has_permissions(**perms: bool) -> Callable[[T], T]: """Similar to :func:`.has_permissions` except checks if the bot itself has @@ -2000,7 +2017,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]: raise BotMissingPermissions(missing) - return check(predicate) + return check(predicate, bot_required_permissions=perms) def has_guild_permissions(**perms: bool) -> Callable[[T], T]: """Similar to :func:`.has_permissions`, but operates on guild wide @@ -2028,7 +2045,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]: raise MissingPermissions(missing) - return check(predicate) + return check(predicate, required_guild_permissions=perms) def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: """Similar to :func:`.has_guild_permissions`, but checks the bot @@ -2053,7 +2070,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: raise BotMissingPermissions(missing) - return check(predicate) + return check(predicate, bot_required_guild_permissions=perms) def dm_only() -> Callable[[T], T]: """A :func:`.check` that indicates this command must only be used in a @@ -2155,7 +2172,10 @@ def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message], if isinstance(func, Command): func._buckets = CooldownMapping(Cooldown(rate, per), type) else: - func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type) + if not hasattr(func, "__command_attrs__"): + func.__command_attrs__ = {} + + func.__command_attrs__['cooldown'] = CooldownMapping(Cooldown(rate, per), type) return func return decorator # type: ignore @@ -2195,7 +2215,10 @@ def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type if isinstance(func, Command): func._buckets = DynamicCooldownMapping(cooldown, type) else: - func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type) + if not hasattr(func, "__command_attrs__"): + func.__command_attrs__ = {} + + func.__command_attrs__['cooldown'] = DynamicCooldownMapping(cooldown, type) return func return decorator # type: ignore @@ -2228,7 +2251,10 @@ def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: if isinstance(func, Command): func._max_concurrency = value else: - func.__commands_max_concurrency__ = value + if not hasattr(func, "__command_attrs__"): + func.__command_attrs__ = {} + + func.__command_attrs__['_max_concurrency'] = value return func return decorator # type: ignore @@ -2274,7 +2300,10 @@ def before_invoke(coro) -> Callable[[T], T]: if isinstance(func, Command): func.before_invoke(coro) else: - func.__before_invoke__ = coro + if not hasattr(func, "__command_attrs__"): + func.__command_attrs__ = {} + + func.__command_attrs__['before_invoke'] = coro return func return decorator # type: ignore @@ -2290,6 +2319,6 @@ def after_invoke(coro) -> Callable[[T], T]: if isinstance(func, Command): func.after_invoke(coro) else: - func.__after_invoke__ = coro + func.__command_attrs__['after_invoke'] = coro return func return decorator # type: ignore -- 2.47.2