Merge pull request #60
* Rework how checks add attributes to Commmand * Merge remote-tracking branch 'upstream/2.0' into command-attrs-checks
This commit is contained in:
parent
2ecf755372
commit
e65415d3c8
@ -330,6 +330,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
__original_kwargs__: Dict[str, Any]
|
||||
_max_concurrency: Optional[MaxConcurrency]
|
||||
|
||||
def __new__(cls: Type[CommandT], *args: Any, **kwargs: Any) -> CommandT:
|
||||
# if you're wondering why this is done, it's because we need to ensure
|
||||
@ -392,17 +393,20 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
self.description: str = inspect.cleandoc(kwargs.get("description", ""))
|
||||
self.hidden: bool = kwargs.get("hidden", False)
|
||||
|
||||
if hasattr(func, "__command_attrs__"):
|
||||
command_attrs: Dict[str, Any] = func.__command_attrs__
|
||||
else:
|
||||
command_attrs = {}
|
||||
|
||||
try:
|
||||
checks = func.__commands_checks__
|
||||
checks = command_attrs.pop("checks")
|
||||
checks.reverse()
|
||||
except AttributeError:
|
||||
except KeyError:
|
||||
checks = kwargs.get("checks", [])
|
||||
|
||||
self.checks: List[Check] = checks
|
||||
|
||||
try:
|
||||
cooldown = func.__commands_cooldown__
|
||||
except AttributeError:
|
||||
cooldown = command_attrs.pop("cooldown")
|
||||
except KeyError:
|
||||
cooldown = kwargs.get("cooldown")
|
||||
|
||||
if cooldown is None:
|
||||
@ -411,14 +415,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
buckets = cooldown
|
||||
else:
|
||||
raise TypeError("Cooldown must be a an instance of CooldownMapping or None.")
|
||||
|
||||
self.checks: List[Check] = checks
|
||||
self._buckets: CooldownMapping = buckets
|
||||
|
||||
try:
|
||||
max_concurrency = func.__commands_max_concurrency__
|
||||
except AttributeError:
|
||||
max_concurrency = kwargs.get("max_concurrency")
|
||||
|
||||
self._max_concurrency: Optional[MaxConcurrency] = max_concurrency
|
||||
self._max_concurrency = kwargs.get("max_concurrency")
|
||||
|
||||
self.require_var_positional: bool = kwargs.get("require_var_positional", False)
|
||||
self.ignore_extra: bool = kwargs.get("ignore_extra", True)
|
||||
@ -435,20 +435,23 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
|
||||
self._before_invoke: Optional[Hook] = None
|
||||
try:
|
||||
before_invoke = func.__before_invoke__
|
||||
except AttributeError:
|
||||
before_invoke = command_attrs.pop("before_invoke")
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
self.before_invoke(before_invoke)
|
||||
|
||||
self._after_invoke: Optional[Hook] = None
|
||||
try:
|
||||
after_invoke = func.__after_invoke__
|
||||
except AttributeError:
|
||||
after_invoke = command_attrs.pop("after_invoke")
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
self.after_invoke(after_invoke)
|
||||
|
||||
# Handle user provided command attrs
|
||||
self._update_attrs(**command_attrs)
|
||||
|
||||
@property
|
||||
def callback(
|
||||
self,
|
||||
@ -474,6 +477,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
|
||||
self.params, self.option_descriptions = get_signature_parameters(function, globalns)
|
||||
|
||||
def _update_attrs(self, **command_attrs: Any):
|
||||
for key, value in command_attrs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def add_check(self, func: Check) -> None:
|
||||
"""Adds a check to the command.
|
||||
|
||||
@ -1829,7 +1836,7 @@ def group(
|
||||
return command(name=name, cls=cls, **attrs) # type: ignore
|
||||
|
||||
|
||||
def check(predicate: Check) -> Callable[[T], T]:
|
||||
def check(predicate: Check, **command_attrs: Any) -> Callable[[T], T]:
|
||||
r"""A decorator that adds a check to the :class:`.Command` or its
|
||||
subclasses. These checks could be accessed via :attr:`.Command.checks`.
|
||||
|
||||
@ -1898,16 +1905,22 @@ def check(predicate: Check) -> Callable[[T], T]:
|
||||
-----------
|
||||
predicate: Callable[[:class:`Context`], :class:`bool`]
|
||||
The predicate to check if the command should be invoked.
|
||||
**command_attrs: Dict[:class:`str`, Any]
|
||||
key: value pairs to be added to the command's attributes.
|
||||
"""
|
||||
|
||||
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
|
||||
if isinstance(func, Command):
|
||||
func.checks.append(predicate)
|
||||
func._update_attrs(**command_attrs)
|
||||
else:
|
||||
if not hasattr(func, "__commands_checks__"):
|
||||
func.__commands_checks__ = []
|
||||
if not hasattr(func, "__command_attrs__"):
|
||||
func.__command_attrs__ = {}
|
||||
|
||||
func.__commands_checks__.append(predicate)
|
||||
func.__command_attrs__.update(command_attrs)
|
||||
|
||||
return func
|
||||
|
||||
@ -2080,7 +2093,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
|
||||
return True
|
||||
raise MissingAnyRole(list(items))
|
||||
|
||||
return check(predicate)
|
||||
return check(predicate, required_roles=items)
|
||||
|
||||
|
||||
def bot_has_role(item: int) -> Callable[[T], T]:
|
||||
@ -2110,7 +2123,7 @@ def bot_has_role(item: int) -> Callable[[T], T]:
|
||||
raise BotMissingRole(item)
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
return check(predicate, bot_required_role=item)
|
||||
|
||||
|
||||
def bot_has_any_role(*items: int) -> Callable[[T], T]:
|
||||
@ -2139,7 +2152,7 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
|
||||
return True
|
||||
raise BotMissingAnyRole(list(items))
|
||||
|
||||
return check(predicate)
|
||||
return check(predicate, bot_required_roles=items)
|
||||
|
||||
|
||||
def has_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
@ -2187,7 +2200,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
|
||||
raise MissingPermissions(missing)
|
||||
|
||||
return check(predicate)
|
||||
return check(predicate, required_permissions=perms)
|
||||
|
||||
|
||||
def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
@ -2214,7 +2227,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
|
||||
raise BotMissingPermissions(missing)
|
||||
|
||||
return check(predicate)
|
||||
return check(predicate, bot_required_permissions=perms)
|
||||
|
||||
|
||||
def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
@ -2243,7 +2256,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
|
||||
raise MissingPermissions(missing)
|
||||
|
||||
return check(predicate)
|
||||
return check(predicate, required_guild_permissions=perms)
|
||||
|
||||
|
||||
def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
@ -2269,7 +2282,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
|
||||
raise BotMissingPermissions(missing)
|
||||
|
||||
return check(predicate)
|
||||
return check(predicate, bot_required_guild_permissions=perms)
|
||||
|
||||
|
||||
def dm_only() -> Callable[[T], T]:
|
||||
@ -2380,7 +2393,10 @@ def cooldown(
|
||||
if isinstance(func, Command):
|
||||
func._buckets = CooldownMapping(Cooldown(rate, per), type)
|
||||
else:
|
||||
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type)
|
||||
if not hasattr(func, "__command_attrs__"):
|
||||
func.__command_attrs__ = {}
|
||||
|
||||
func.__command_attrs__["cooldown"] = CooldownMapping(Cooldown(rate, per), type)
|
||||
return func
|
||||
|
||||
return decorator # type: ignore
|
||||
@ -2424,7 +2440,10 @@ def dynamic_cooldown(
|
||||
if isinstance(func, Command):
|
||||
func._buckets = DynamicCooldownMapping(cooldown, type)
|
||||
else:
|
||||
func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type)
|
||||
if not hasattr(func, "__command_attrs__"):
|
||||
func.__command_attrs__ = {}
|
||||
|
||||
func.__command_attrs__["cooldown"] = DynamicCooldownMapping(cooldown, type)
|
||||
return func
|
||||
|
||||
return decorator # type: ignore
|
||||
@ -2459,7 +2478,10 @@ def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait:
|
||||
if isinstance(func, Command):
|
||||
func._max_concurrency = value
|
||||
else:
|
||||
func.__commands_max_concurrency__ = value
|
||||
if not hasattr(func, "__command_attrs__"):
|
||||
func.__command_attrs__ = {}
|
||||
|
||||
func.__command_attrs__["_max_concurrency"] = value
|
||||
return func
|
||||
|
||||
return decorator # type: ignore
|
||||
@ -2508,7 +2530,10 @@ def before_invoke(coro) -> Callable[[T], T]:
|
||||
if isinstance(func, Command):
|
||||
func.before_invoke(coro)
|
||||
else:
|
||||
func.__before_invoke__ = coro
|
||||
if not hasattr(func, "__command_attrs__"):
|
||||
func.__command_attrs__ = {}
|
||||
|
||||
func.__command_attrs__["before_invoke"] = coro
|
||||
return func
|
||||
|
||||
return decorator # type: ignore
|
||||
@ -2527,7 +2552,7 @@ def after_invoke(coro) -> Callable[[T], T]:
|
||||
if isinstance(func, Command):
|
||||
func.after_invoke(coro)
|
||||
else:
|
||||
func.__after_invoke__ = coro
|
||||
func.__command_attrs__["after_invoke"] = coro
|
||||
return func
|
||||
|
||||
return decorator # type: ignore
|
||||
|
Loading…
x
Reference in New Issue
Block a user