[Commands] Allow custom checks to add attributes to the Command instance #60

Merged
Gnome-py merged 2 commits from command-attrs-checks into 2.0 2021-09-21 18:47:28 +00:00

View File

@ -283,6 +283,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
.. versionadded:: 2.0
"""
__original_kwargs__: Dict[str, Any]
_max_concurrency: Optional[MaxConcurrency]
def __new__(cls: Type[CommandT], *args: Any, **kwargs: Any) -> CommandT:
# if you're wondering why this is done, it's because we need to ensure
@ -341,17 +342,20 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.description: str = inspect.cleandoc(kwargs.get("description", ""))
self.hidden: bool = kwargs.get("hidden", False)
if hasattr(func, "__command_attrs__"):
command_attrs: Dict[str, Any] = func.__command_attrs__
else:
command_attrs = {}
try:
checks = func.__commands_checks__
checks = command_attrs.pop("checks")
checks.reverse()
except AttributeError:
except KeyError:
checks = kwargs.get("checks", [])
self.checks: List[Check] = checks
try:
cooldown = func.__commands_cooldown__
except AttributeError:
cooldown = command_attrs.pop("cooldown")
except KeyError:
cooldown = kwargs.get("cooldown")
if cooldown is None:
@ -360,14 +364,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
buckets = cooldown
else:
raise TypeError("Cooldown must be a an instance of CooldownMapping or None.")
self.checks: List[Check] = checks
self._buckets: CooldownMapping = buckets
try:
max_concurrency = func.__commands_max_concurrency__
except AttributeError:
max_concurrency = kwargs.get("max_concurrency")
self._max_concurrency: Optional[MaxConcurrency] = max_concurrency
self._max_concurrency = kwargs.get("max_concurrency")
self.require_var_positional: bool = kwargs.get("require_var_positional", False)
self.ignore_extra: bool = kwargs.get("ignore_extra", True)
@ -380,20 +380,23 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self._before_invoke: Optional[Hook] = None
try:
before_invoke = func.__before_invoke__
except AttributeError:
before_invoke = command_attrs.pop("before_invoke")
except KeyError:
pass
else:
self.before_invoke(before_invoke)
self._after_invoke: Optional[Hook] = None
try:
after_invoke = func.__after_invoke__
except AttributeError:
after_invoke = command_attrs.pop("after_invoke")
except KeyError:
pass
else:
self.after_invoke(after_invoke)
# Handle user provided command attrs
self._update_attrs(**command_attrs)
@property
def callback(
self,
@ -419,6 +422,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.params = get_signature_parameters(function, globalns)
def _update_attrs(self, **command_attrs: Any):
for key, value in command_attrs.items():
setattr(self, key, value)
def add_check(self, func: Check) -> None:
"""Adds a check to the command.
@ -1664,7 +1671,7 @@ def group(
return command(name=name, cls=cls, **attrs) # type: ignore
def check(predicate: Check) -> Callable[[T], T]:
def check(predicate: Check, **command_attrs: Any) -> Callable[[T], T]:
r"""A decorator that adds a check to the :class:`.Command` or its
subclasses. These checks could be accessed via :attr:`.Command.checks`.
@ -1733,16 +1740,22 @@ def check(predicate: Check) -> Callable[[T], T]:
-----------
predicate: Callable[[:class:`Context`], :class:`bool`]
The predicate to check if the command should be invoked.
**command_attrs: Dict[:class:`str`, Any]
key: value pairs to be added to the command's attributes.
"""
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
if isinstance(func, Command):
func.checks.append(predicate)
func._update_attrs(**command_attrs)
else:
if not hasattr(func, "__commands_checks__"):
func.__commands_checks__ = []
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__commands_checks__.append(predicate)
func.__command_attrs__.update(command_attrs)
return func
@ -1915,7 +1928,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
return True
raise MissingAnyRole(list(items))
return check(predicate)
return check(predicate, required_roles=items)
def bot_has_role(item: int) -> Callable[[T], T]:
@ -1945,7 +1958,7 @@ def bot_has_role(item: int) -> Callable[[T], T]:
raise BotMissingRole(item)
return True
return check(predicate)
return check(predicate, bot_required_role=item)
def bot_has_any_role(*items: int) -> Callable[[T], T]:
@ -1974,7 +1987,7 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
return True
raise BotMissingAnyRole(list(items))
return check(predicate)
return check(predicate, bot_required_roles=items)
def has_permissions(**perms: bool) -> Callable[[T], T]:
@ -2022,7 +2035,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
raise MissingPermissions(missing)
return check(predicate)
return check(predicate, required_permissions=perms)
def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
@ -2049,7 +2062,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
raise BotMissingPermissions(missing)
return check(predicate)
return check(predicate, bot_required_permissions=perms)
def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
@ -2078,7 +2091,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
raise MissingPermissions(missing)
return check(predicate)
return check(predicate, required_guild_permissions=perms)
def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
@ -2104,7 +2117,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
raise BotMissingPermissions(missing)
return check(predicate)
return check(predicate, bot_required_guild_permissions=perms)
def dm_only() -> Callable[[T], T]:
@ -2215,7 +2228,10 @@ def cooldown(
if isinstance(func, Command):
func._buckets = CooldownMapping(Cooldown(rate, per), type)
else:
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type)
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__command_attrs__["cooldown"] = CooldownMapping(Cooldown(rate, per), type)
return func
return decorator # type: ignore
@ -2259,7 +2275,10 @@ def dynamic_cooldown(
if isinstance(func, Command):
func._buckets = DynamicCooldownMapping(cooldown, type)
else:
func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type)
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__command_attrs__["cooldown"] = DynamicCooldownMapping(cooldown, type)
return func
return decorator # type: ignore
@ -2294,7 +2313,10 @@ def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait:
if isinstance(func, Command):
func._max_concurrency = value
else:
func.__commands_max_concurrency__ = value
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__command_attrs__["_max_concurrency"] = value
return func
return decorator # type: ignore
@ -2343,7 +2365,10 @@ def before_invoke(coro) -> Callable[[T], T]:
if isinstance(func, Command):
func.before_invoke(coro)
else:
func.__before_invoke__ = coro
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__command_attrs__["before_invoke"] = coro
return func
return decorator # type: ignore
@ -2362,7 +2387,7 @@ def after_invoke(coro) -> Callable[[T], T]:
if isinstance(func, Command):
func.after_invoke(coro)
else:
func.__after_invoke__ = coro
func.__command_attrs__["after_invoke"] = coro
return func
return decorator # type: ignore