Merge pull request #60

* Rework how checks add attributes to Commmand

* Merge remote-tracking branch 'upstream/2.0' into command-attrs-checks
This commit is contained in:
Gnome! 2021-09-21 19:47:28 +01:00 committed by GitHub
parent 2ecf755372
commit e65415d3c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -330,6 +330,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
.. versionadded:: 2.0
"""
__original_kwargs__: Dict[str, Any]
_max_concurrency: Optional[MaxConcurrency]
def __new__(cls: Type[CommandT], *args: Any, **kwargs: Any) -> CommandT:
# if you're wondering why this is done, it's because we need to ensure
@ -392,17 +393,20 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.description: str = inspect.cleandoc(kwargs.get("description", ""))
self.hidden: bool = kwargs.get("hidden", False)
if hasattr(func, "__command_attrs__"):
command_attrs: Dict[str, Any] = func.__command_attrs__
else:
command_attrs = {}
try:
checks = func.__commands_checks__
checks = command_attrs.pop("checks")
checks.reverse()
except AttributeError:
except KeyError:
checks = kwargs.get("checks", [])
self.checks: List[Check] = checks
try:
cooldown = func.__commands_cooldown__
except AttributeError:
cooldown = command_attrs.pop("cooldown")
except KeyError:
cooldown = kwargs.get("cooldown")
if cooldown is None:
@ -411,14 +415,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
buckets = cooldown
else:
raise TypeError("Cooldown must be a an instance of CooldownMapping or None.")
self.checks: List[Check] = checks
self._buckets: CooldownMapping = buckets
try:
max_concurrency = func.__commands_max_concurrency__
except AttributeError:
max_concurrency = kwargs.get("max_concurrency")
self._max_concurrency: Optional[MaxConcurrency] = max_concurrency
self._max_concurrency = kwargs.get("max_concurrency")
self.require_var_positional: bool = kwargs.get("require_var_positional", False)
self.ignore_extra: bool = kwargs.get("ignore_extra", True)
@ -435,20 +435,23 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self._before_invoke: Optional[Hook] = None
try:
before_invoke = func.__before_invoke__
except AttributeError:
before_invoke = command_attrs.pop("before_invoke")
except KeyError:
pass
else:
self.before_invoke(before_invoke)
self._after_invoke: Optional[Hook] = None
try:
after_invoke = func.__after_invoke__
except AttributeError:
after_invoke = command_attrs.pop("after_invoke")
except KeyError:
pass
else:
self.after_invoke(after_invoke)
# Handle user provided command attrs
self._update_attrs(**command_attrs)
@property
def callback(
self,
@ -474,6 +477,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.params, self.option_descriptions = get_signature_parameters(function, globalns)
def _update_attrs(self, **command_attrs: Any):
for key, value in command_attrs.items():
setattr(self, key, value)
def add_check(self, func: Check) -> None:
"""Adds a check to the command.
@ -1829,7 +1836,7 @@ def group(
return command(name=name, cls=cls, **attrs) # type: ignore
def check(predicate: Check) -> Callable[[T], T]:
def check(predicate: Check, **command_attrs: Any) -> Callable[[T], T]:
r"""A decorator that adds a check to the :class:`.Command` or its
subclasses. These checks could be accessed via :attr:`.Command.checks`.
@ -1898,16 +1905,22 @@ def check(predicate: Check) -> Callable[[T], T]:
-----------
predicate: Callable[[:class:`Context`], :class:`bool`]
The predicate to check if the command should be invoked.
**command_attrs: Dict[:class:`str`, Any]
key: value pairs to be added to the command's attributes.
"""
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
if isinstance(func, Command):
func.checks.append(predicate)
func._update_attrs(**command_attrs)
else:
if not hasattr(func, "__commands_checks__"):
func.__commands_checks__ = []
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__commands_checks__.append(predicate)
func.__command_attrs__.update(command_attrs)
return func
@ -2080,7 +2093,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
return True
raise MissingAnyRole(list(items))
return check(predicate)
return check(predicate, required_roles=items)
def bot_has_role(item: int) -> Callable[[T], T]:
@ -2110,7 +2123,7 @@ def bot_has_role(item: int) -> Callable[[T], T]:
raise BotMissingRole(item)
return True
return check(predicate)
return check(predicate, bot_required_role=item)
def bot_has_any_role(*items: int) -> Callable[[T], T]:
@ -2139,7 +2152,7 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
return True
raise BotMissingAnyRole(list(items))
return check(predicate)
return check(predicate, bot_required_roles=items)
def has_permissions(**perms: bool) -> Callable[[T], T]:
@ -2187,7 +2200,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
raise MissingPermissions(missing)
return check(predicate)
return check(predicate, required_permissions=perms)
def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
@ -2214,7 +2227,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
raise BotMissingPermissions(missing)
return check(predicate)
return check(predicate, bot_required_permissions=perms)
def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
@ -2243,7 +2256,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
raise MissingPermissions(missing)
return check(predicate)
return check(predicate, required_guild_permissions=perms)
def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
@ -2269,7 +2282,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
raise BotMissingPermissions(missing)
return check(predicate)
return check(predicate, bot_required_guild_permissions=perms)
def dm_only() -> Callable[[T], T]:
@ -2380,7 +2393,10 @@ def cooldown(
if isinstance(func, Command):
func._buckets = CooldownMapping(Cooldown(rate, per), type)
else:
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type)
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__command_attrs__["cooldown"] = CooldownMapping(Cooldown(rate, per), type)
return func
return decorator # type: ignore
@ -2424,7 +2440,10 @@ def dynamic_cooldown(
if isinstance(func, Command):
func._buckets = DynamicCooldownMapping(cooldown, type)
else:
func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type)
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__command_attrs__["cooldown"] = DynamicCooldownMapping(cooldown, type)
return func
return decorator # type: ignore
@ -2459,7 +2478,10 @@ def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait:
if isinstance(func, Command):
func._max_concurrency = value
else:
func.__commands_max_concurrency__ = value
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__command_attrs__["_max_concurrency"] = value
return func
return decorator # type: ignore
@ -2508,7 +2530,10 @@ def before_invoke(coro) -> Callable[[T], T]:
if isinstance(func, Command):
func.before_invoke(coro)
else:
func.__before_invoke__ = coro
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__command_attrs__["before_invoke"] = coro
return func
return decorator # type: ignore
@ -2527,7 +2552,7 @@ def after_invoke(coro) -> Callable[[T], T]:
if isinstance(func, Command):
func.after_invoke(coro)
else:
func.__after_invoke__ = coro
func.__command_attrs__["after_invoke"] = coro
return func
return decorator # type: ignore