Rework how checks add attributes to Commmand

This commit is contained in:
Gnome
2021-09-06 11:27:25 +01:00
parent 65640ddfc7
commit 98bde70ac1

View File

@@ -278,6 +278,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
.. versionadded:: 2.0
"""
__original_kwargs__: Dict[str, Any]
_max_concurrency: Optional[MaxConcurrency]
def __new__(cls: Type[CommandT], *args: Any, **kwargs: Any) -> CommandT:
# if you're wondering why this is done, it's because we need to ensure
@@ -332,33 +333,33 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.description: str = inspect.cleandoc(kwargs.get('description', ''))
self.hidden: bool = kwargs.get('hidden', False)
if hasattr(func, '__command_attrs__'):
command_attrs: Dict[str, Any] = func.__command_attrs__
else:
command_attrs = {}
try:
checks = func.__commands_checks__
checks = command_attrs.pop('checks')
checks.reverse()
except AttributeError:
except KeyError:
checks = kwargs.get('checks', [])
self.checks: List[Check] = checks
try:
cooldown = func.__commands_cooldown__
except AttributeError:
cooldown = command_attrs.pop('cooldown')
except KeyError:
cooldown = kwargs.get('cooldown')
if cooldown is None:
buckets = CooldownMapping(cooldown, BucketType.default)
elif isinstance(cooldown, CooldownMapping):
buckets = cooldown
else:
raise TypeError("Cooldown must be a an instance of CooldownMapping or None.")
self.checks: List[Check] = checks
self._buckets: CooldownMapping = buckets
try:
max_concurrency = func.__commands_max_concurrency__
except AttributeError:
max_concurrency = kwargs.get('max_concurrency')
self._max_concurrency: Optional[MaxConcurrency] = max_concurrency
self._max_concurrency = kwargs.get('max_concurrency')
self.require_var_positional: bool = kwargs.get('require_var_positional', False)
self.ignore_extra: bool = kwargs.get('ignore_extra', True)
@@ -371,20 +372,24 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self._before_invoke: Optional[Hook] = None
try:
before_invoke = func.__before_invoke__
except AttributeError:
before_invoke = command_attrs.pop('before_invoke')
except KeyError:
pass
else:
self.before_invoke(before_invoke)
self._after_invoke: Optional[Hook] = None
try:
after_invoke = func.__after_invoke__
except AttributeError:
after_invoke = command_attrs.pop('after_invoke')
except KeyError:
pass
else:
self.after_invoke(after_invoke)
# Handle user provided command attrs
self._update_attrs(**command_attrs)
@property
def callback(self) -> Union[
Callable[Concatenate[CogT, Context, P], Coro[T]],
@@ -408,6 +413,12 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.params = get_signature_parameters(function, globalns)
def _update_attrs(self, **command_attrs: Any):
for key, value in command_attrs.items():
setattr(self, key, value)
def add_check(self, func: Check) -> None:
"""Adds a check to the command.
@@ -1632,7 +1643,7 @@ def group(
cls = Group # type: ignore
return command(name=name, cls=cls, **attrs) # type: ignore
def check(predicate: Check) -> Callable[[T], T]:
def check(predicate: Check, **command_attrs: Any) -> Callable[[T], T]:
r"""A decorator that adds a check to the :class:`.Command` or its
subclasses. These checks could be accessed via :attr:`.Command.checks`.
@@ -1701,16 +1712,22 @@ def check(predicate: Check) -> Callable[[T], T]:
-----------
predicate: Callable[[:class:`Context`], :class:`bool`]
The predicate to check if the command should be invoked.
**command_attrs: Dict[:class:`str`, Any]
key: value pairs to be added to the command's attributes.
"""
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
if isinstance(func, Command):
func.checks.append(predicate)
func._update_attrs(**command_attrs)
else:
if not hasattr(func, '__commands_checks__'):
func.__commands_checks__ = []
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__commands_checks__.append(predicate)
func.__command_attrs__.update(command_attrs)
return func
@@ -1875,7 +1892,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
return True
raise MissingAnyRole(list(items))
return check(predicate)
return check(predicate, required_roles=items)
def bot_has_role(item: int) -> Callable[[T], T]:
"""Similar to :func:`.has_role` except checks if the bot itself has the
@@ -1903,7 +1920,7 @@ def bot_has_role(item: int) -> Callable[[T], T]:
if role is None:
raise BotMissingRole(item)
return True
return check(predicate)
return check(predicate, bot_required_role=item)
def bot_has_any_role(*items: int) -> Callable[[T], T]:
"""Similar to :func:`.has_any_role` except checks if the bot itself has
@@ -1927,7 +1944,7 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
return True
raise BotMissingAnyRole(list(items))
return check(predicate)
return check(predicate, bot_required_roles=items)
def has_permissions(**perms: bool) -> Callable[[T], T]:
"""A :func:`.check` that is added that checks if the member has all of
@@ -1974,7 +1991,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
raise MissingPermissions(missing)
return check(predicate)
return check(predicate, required_permissions=perms)
def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
"""Similar to :func:`.has_permissions` except checks if the bot itself has
@@ -2000,7 +2017,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
raise BotMissingPermissions(missing)
return check(predicate)
return check(predicate, bot_required_permissions=perms)
def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
"""Similar to :func:`.has_permissions`, but operates on guild wide
@@ -2028,7 +2045,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
raise MissingPermissions(missing)
return check(predicate)
return check(predicate, required_guild_permissions=perms)
def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
"""Similar to :func:`.has_guild_permissions`, but checks the bot
@@ -2053,7 +2070,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
raise BotMissingPermissions(missing)
return check(predicate)
return check(predicate, bot_required_guild_permissions=perms)
def dm_only() -> Callable[[T], T]:
"""A :func:`.check` that indicates this command must only be used in a
@@ -2155,7 +2172,10 @@ def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message],
if isinstance(func, Command):
func._buckets = CooldownMapping(Cooldown(rate, per), type)
else:
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type)
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__command_attrs__['cooldown'] = CooldownMapping(Cooldown(rate, per), type)
return func
return decorator # type: ignore
@@ -2195,7 +2215,10 @@ def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type
if isinstance(func, Command):
func._buckets = DynamicCooldownMapping(cooldown, type)
else:
func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type)
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__command_attrs__['cooldown'] = DynamicCooldownMapping(cooldown, type)
return func
return decorator # type: ignore
@@ -2228,7 +2251,10 @@ def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait:
if isinstance(func, Command):
func._max_concurrency = value
else:
func.__commands_max_concurrency__ = value
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__command_attrs__['_max_concurrency'] = value
return func
return decorator # type: ignore
@@ -2274,7 +2300,10 @@ def before_invoke(coro) -> Callable[[T], T]:
if isinstance(func, Command):
func.before_invoke(coro)
else:
func.__before_invoke__ = coro
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__command_attrs__['before_invoke'] = coro
return func
return decorator # type: ignore
@@ -2290,6 +2319,6 @@ def after_invoke(coro) -> Callable[[T], T]:
if isinstance(func, Command):
func.after_invoke(coro)
else:
func.__after_invoke__ = coro
func.__command_attrs__['after_invoke'] = coro
return func
return decorator # type: ignore