Fix code style issues with Black

This commit is contained in:
Lint Action
2021-09-05 21:34:20 +00:00
parent a23dae8604
commit 7513c2138f
108 changed files with 5369 additions and 4858 deletions

View File

@@ -70,52 +70,53 @@ if TYPE_CHECKING:
__all__ = (
'Command',
'Group',
'GroupMixin',
'command',
'group',
'has_role',
'has_permissions',
'has_any_role',
'check',
'check_any',
'before_invoke',
'after_invoke',
'bot_has_role',
'bot_has_permissions',
'bot_has_any_role',
'cooldown',
'dynamic_cooldown',
'max_concurrency',
'dm_only',
'guild_only',
'is_owner',
'is_nsfw',
'has_guild_permissions',
'bot_has_guild_permissions'
"Command",
"Group",
"GroupMixin",
"command",
"group",
"has_role",
"has_permissions",
"has_any_role",
"check",
"check_any",
"before_invoke",
"after_invoke",
"bot_has_role",
"bot_has_permissions",
"bot_has_any_role",
"cooldown",
"dynamic_cooldown",
"max_concurrency",
"dm_only",
"guild_only",
"is_owner",
"is_nsfw",
"has_guild_permissions",
"bot_has_guild_permissions",
)
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CogT = TypeVar('CogT', bound='Cog')
CommandT = TypeVar('CommandT', bound='Command')
ContextT = TypeVar('ContextT', bound='Context')
T = TypeVar("T")
CogT = TypeVar("CogT", bound="Cog")
CommandT = TypeVar("CommandT", bound="Command")
ContextT = TypeVar("ContextT", bound="Context")
# CHT = TypeVar('CHT', bound='Check')
GroupT = TypeVar('GroupT', bound='Group')
HookT = TypeVar('HookT', bound='Hook')
ErrorT = TypeVar('ErrorT', bound='Error')
GroupT = TypeVar("GroupT", bound="Group")
HookT = TypeVar("HookT", bound="Hook")
ErrorT = TypeVar("ErrorT", bound="Error")
if TYPE_CHECKING:
P = ParamSpec('P')
P = ParamSpec("P")
else:
P = TypeVar('P')
P = TypeVar("P")
def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
partial = functools.partial
while True:
if hasattr(function, '__wrapped__'):
if hasattr(function, "__wrapped__"):
function = function.__wrapped__
elif isinstance(function, partial):
function = function.func
@@ -139,7 +140,7 @@ def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, A
annotation = eval_annotation(annotation, globalns, globalns, cache)
if annotation is Greedy:
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
raise TypeError("Unparameterized Greedy[...] is disallowed in signature.")
params[name] = parameter.replace(annotation=annotation)
@@ -158,8 +159,10 @@ def wrap_callback(coro):
except Exception as exc:
raise CommandInvokeError(exc) from exc
return ret
return wrapped
def hooked_wrapped_callback(command, ctx, coro):
@functools.wraps(coro)
async def wrapped(*args, **kwargs):
@@ -180,6 +183,7 @@ def hooked_wrapped_callback(command, ctx, coro):
await command.call_after_hooks(ctx)
return ret
return wrapped
@@ -202,6 +206,7 @@ class _CaseInsensitiveDict(dict):
def __setitem__(self, k, v):
super().__setitem__(k.casefold(), v)
class Command(_BaseCommand, Generic[CogT, P, T]):
r"""A class that implements the protocol for a bot text command.
@@ -269,8 +274,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
which calls converters. If ``False`` then cooldown processing is done
first and then the converters are called second. Defaults to ``False``.
extras: :class:`dict`
A dict of user provided extras to attach to the Command.
A dict of user provided extras to attach to the Command.
.. note::
This object may be copied by the library.
@@ -295,56 +300,60 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.__original_kwargs__ = kwargs.copy()
return self
def __init__(self, func: Union[
def __init__(
self,
func: Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
], **kwargs: Any):
],
**kwargs: Any,
):
if not asyncio.iscoroutinefunction(func):
raise TypeError('Callback must be a coroutine.')
raise TypeError("Callback must be a coroutine.")
name = kwargs.get('name') or func.__name__
name = kwargs.get("name") or func.__name__
if not isinstance(name, str):
raise TypeError('Name of a command must be a string.')
raise TypeError("Name of a command must be a string.")
self.name: str = name
self.callback = func
self.enabled: bool = kwargs.get('enabled', True)
self.enabled: bool = kwargs.get("enabled", True)
help_doc = kwargs.get('help')
help_doc = kwargs.get("help")
if help_doc is not None:
help_doc = inspect.cleandoc(help_doc)
else:
help_doc = inspect.getdoc(func)
if isinstance(help_doc, bytes):
help_doc = help_doc.decode('utf-8')
help_doc = help_doc.decode("utf-8")
self.help: Optional[str] = help_doc
self.brief: Optional[str] = kwargs.get('brief')
self.usage: Optional[str] = kwargs.get('usage')
self.rest_is_raw: bool = kwargs.get('rest_is_raw', False)
self.aliases: Union[List[str], Tuple[str]] = kwargs.get('aliases', [])
self.extras: Dict[str, Any] = kwargs.get('extras', {})
self.brief: Optional[str] = kwargs.get("brief")
self.usage: Optional[str] = kwargs.get("usage")
self.rest_is_raw: bool = kwargs.get("rest_is_raw", False)
self.aliases: Union[List[str], Tuple[str]] = kwargs.get("aliases", [])
self.extras: Dict[str, Any] = kwargs.get("extras", {})
if not isinstance(self.aliases, (list, tuple)):
raise TypeError("Aliases of a command must be a list or a tuple of strings.")
self.description: str = inspect.cleandoc(kwargs.get('description', ''))
self.hidden: bool = kwargs.get('hidden', False)
self.description: str = inspect.cleandoc(kwargs.get("description", ""))
self.hidden: bool = kwargs.get("hidden", False)
try:
checks = func.__commands_checks__
checks.reverse()
except AttributeError:
checks = kwargs.get('checks', [])
checks = kwargs.get("checks", [])
self.checks: List[Check] = checks
try:
cooldown = func.__commands_cooldown__
except AttributeError:
cooldown = kwargs.get('cooldown')
cooldown = kwargs.get("cooldown")
if cooldown is None:
buckets = CooldownMapping(cooldown, BucketType.default)
elif isinstance(cooldown, CooldownMapping):
@@ -356,17 +365,17 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
try:
max_concurrency = func.__commands_max_concurrency__
except AttributeError:
max_concurrency = kwargs.get('max_concurrency')
max_concurrency = kwargs.get("max_concurrency")
self._max_concurrency: Optional[MaxConcurrency] = max_concurrency
self.require_var_positional: bool = kwargs.get('require_var_positional', False)
self.ignore_extra: bool = kwargs.get('ignore_extra', True)
self.cooldown_after_parsing: bool = kwargs.get('cooldown_after_parsing', False)
self.require_var_positional: bool = kwargs.get("require_var_positional", False)
self.ignore_extra: bool = kwargs.get("ignore_extra", True)
self.cooldown_after_parsing: bool = kwargs.get("cooldown_after_parsing", False)
self.cog: Optional[CogT] = None
# bandaid for the fact that sometimes parent can be the bot instance
parent = kwargs.get('parent')
parent = kwargs.get("parent")
self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore
self._before_invoke: Optional[Hook] = None
@@ -386,17 +395,19 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.after_invoke(after_invoke)
@property
def callback(self) -> Union[
Callable[Concatenate[CogT, Context, P], Coro[T]],
Callable[Concatenate[Context, P], Coro[T]],
]:
def callback(
self,
) -> Union[Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]],]:
return self._callback
@callback.setter
def callback(self, function: Union[
def callback(
self,
function: Union[
Callable[Concatenate[CogT, Context, P], Coro[T]],
Callable[Concatenate[Context, P], Coro[T]],
]) -> None:
],
) -> None:
self._callback = function
unwrap = unwrap_function(function)
self.module = unwrap.__module__
@@ -527,7 +538,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
wrapped = wrap_callback(local)
await wrapped(ctx, error)
finally:
ctx.bot.dispatch('command_error', ctx, error)
ctx.bot.dispatch("command_error", ctx, error)
async def transform(self, ctx: Context, param: inspect.Parameter) -> Any:
required = param.default is param.empty
@@ -551,11 +562,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if view.eof:
if param.kind == param.VAR_POSITIONAL:
raise RuntimeError() # break the loop
raise RuntimeError() # break the loop
if required:
if self._is_typing_optional(param.annotation):
return None
if hasattr(converter, '__commands_is_flag__') and converter._can_be_constructible():
if hasattr(converter, "__commands_is_flag__") and converter._can_be_constructible():
return await converter._construct_default(ctx)
raise MissingRequiredArgument(param)
return param.default
@@ -577,7 +588,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
# type-checker fails to narrow argument
return await run_converters(ctx, converter, argument, param) # type: ignore
async def _transform_greedy_pos(self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any) -> Any:
async def _transform_greedy_pos(
self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any
) -> Any:
view = ctx.view
result = []
while not view.eof:
@@ -606,7 +619,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
value = await run_converters(ctx, converter, argument, param) # type: ignore
except (CommandError, ArgumentParsingError):
view.index = previous
raise RuntimeError() from None # break loop
raise RuntimeError() from None # break loop
else:
return value
@@ -643,11 +656,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
entries = []
command = self
# command.parent is type-hinted as GroupMixin some attributes are resolved via MRO
while command.parent is not None: # type: ignore
command = command.parent # type: ignore
entries.append(command.name) # type: ignore
while command.parent is not None: # type: ignore
command = command.parent # type: ignore
entries.append(command.name) # type: ignore
return ' '.join(reversed(entries))
return " ".join(reversed(entries))
@property
def parents(self) -> List[Group]:
@@ -661,8 +674,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
"""
entries = []
command = self
while command.parent is not None: # type: ignore
command = command.parent # type: ignore
while command.parent is not None: # type: ignore
command = command.parent # type: ignore
entries.append(command)
return entries
@@ -690,7 +703,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
parent = self.full_parent_name
if parent:
return parent + ' ' + self.name
return parent + " " + self.name
else:
return self.name
@@ -745,7 +758,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
break
if not self.ignore_extra and not view.eof:
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
raise TooManyArguments("Too many arguments passed to " + self.qualified_name)
async def call_before_hooks(self, ctx: Context) -> None:
# now that we're done preparing we can call the pre-command hooks
@@ -753,7 +766,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
cog = self.cog
if self._before_invoke is not None:
# should be cog if @commands.before_invoke is used
instance = getattr(self._before_invoke, '__self__', cog)
instance = getattr(self._before_invoke, "__self__", cog)
# __self__ only exists for methods, not functions
# however, if @command.before_invoke is used, it will be a function
if instance:
@@ -775,7 +788,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
async def call_after_hooks(self, ctx: Context) -> None:
cog = self.cog
if self._after_invoke is not None:
instance = getattr(self._after_invoke, '__self__', cog)
instance = getattr(self._after_invoke, "__self__", cog)
if instance:
await self._after_invoke(instance, ctx) # type: ignore
else:
@@ -805,7 +818,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
ctx.command = self
if not await self.can_run(ctx):
raise CheckFailure(f'The check functions for command {self.qualified_name} failed.')
raise CheckFailure(f"The check functions for command {self.qualified_name} failed.")
if self._max_concurrency is not None:
# For this application, context can be duck-typed as a Message
@@ -929,7 +942,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.')
raise TypeError("The error handler must be a coroutine.")
self.on_error: Error = coro
return coro
@@ -939,7 +952,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
.. versionadded:: 1.7
"""
return hasattr(self, 'on_error')
return hasattr(self, "on_error")
def before_invoke(self, coro: HookT) -> HookT:
"""A decorator that registers a coroutine as a pre-invoke hook.
@@ -963,7 +976,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The pre-invoke hook must be a coroutine.')
raise TypeError("The pre-invoke hook must be a coroutine.")
self._before_invoke = coro
return coro
@@ -990,7 +1003,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The post-invoke hook must be a coroutine.')
raise TypeError("The post-invoke hook must be a coroutine.")
self._after_invoke = coro
return coro
@@ -1011,11 +1024,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if self.brief is not None:
return self.brief
if self.help is not None:
return self.help.split('\n', 1)[0]
return ''
return self.help.split("\n", 1)[0]
return ""
def _is_typing_optional(self, annotation: Union[T, Optional[T]]) -> TypeGuard[Optional[T]]:
return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__ # type: ignore
return getattr(annotation, "__origin__", None) is Union and type(None) in annotation.__args__ # type: ignore
@property
def signature(self) -> str:
@@ -1025,7 +1038,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
params = self.clean_params
if not params:
return ''
return ""
result = []
for name, param in params.items():
@@ -1035,41 +1048,40 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
# parameter signature is a literal list of it's values
annotation = param.annotation.converter if greedy else param.annotation
origin = getattr(annotation, '__origin__', None)
origin = getattr(annotation, "__origin__", None)
if not greedy and origin is Union:
none_cls = type(None)
union_args = annotation.__args__
optional = union_args[-1] is none_cls
if len(union_args) == 2 and optional:
annotation = union_args[0]
origin = getattr(annotation, '__origin__', None)
origin = getattr(annotation, "__origin__", None)
if origin is Literal:
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
name = "|".join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
if param.default is not param.empty:
# We don't want None or '' to trigger the [name=value] case and instead it should
# do [name] since [name=None] or [name=] are not exactly useful for the user.
should_print = param.default if isinstance(param.default, str) else param.default is not None
if should_print:
result.append(f'[{name}={param.default}]' if not greedy else
f'[{name}={param.default}]...')
result.append(f"[{name}={param.default}]" if not greedy else f"[{name}={param.default}]...")
continue
else:
result.append(f'[{name}]')
result.append(f"[{name}]")
elif param.kind == param.VAR_POSITIONAL:
if self.require_var_positional:
result.append(f'<{name}...>')
result.append(f"<{name}...>")
else:
result.append(f'[{name}...]')
result.append(f"[{name}...]")
elif greedy:
result.append(f'[{name}]...')
result.append(f"[{name}]...")
elif optional:
result.append(f'[{name}]')
result.append(f"[{name}]")
else:
result.append(f'<{name}>')
result.append(f"<{name}>")
return ' '.join(result)
return " ".join(result)
async def can_run(self, ctx: Context) -> bool:
"""|coro|
@@ -1099,14 +1111,14 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
"""
if not self.enabled:
raise DisabledCommand(f'{self.name} command is disabled')
raise DisabledCommand(f"{self.name} command is disabled")
original = ctx.command
ctx.command = self
try:
if not await ctx.bot.can_run(ctx):
raise CheckFailure(f'The global check functions for command {self.qualified_name} failed.')
raise CheckFailure(f"The global check functions for command {self.qualified_name} failed.")
cog = self.cog
if cog is not None:
@@ -1125,6 +1137,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
finally:
ctx.command = original
class GroupMixin(Generic[CogT]):
"""A mixin that implements common functionality for classes that behave
similar to :class:`.Group` and are allowed to register commands.
@@ -1137,8 +1150,9 @@ class GroupMixin(Generic[CogT]):
case_insensitive: :class:`bool`
Whether the commands should be case insensitive. Defaults to ``True``.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
case_insensitive = kwargs.get('case_insensitive', True)
case_insensitive = kwargs.get("case_insensitive", True)
self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {}
self.case_insensitive: bool = case_insensitive
super().__init__(*args, **kwargs)
@@ -1177,7 +1191,7 @@ class GroupMixin(Generic[CogT]):
"""
if not isinstance(command, Command):
raise TypeError('The command passed must be a subclass of Command')
raise TypeError("The command passed must be a subclass of Command")
if isinstance(self, Command):
command.parent = self
@@ -1267,7 +1281,7 @@ class GroupMixin(Generic[CogT]):
"""
# fast path, no space in name.
if ' ' not in name:
if " " not in name:
return self.all_commands.get(name)
names = name.split()
@@ -1298,7 +1312,9 @@ class GroupMixin(Generic[CogT]):
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
], Command[CogT, P, T]]:
],
Command[CogT, P, T],
]:
...
@overload
@@ -1326,8 +1342,9 @@ class GroupMixin(Generic[CogT]):
Callable[..., :class:`Command`]
A decorator that converts the provided method into a Command, adds it to the bot, then returns it.
"""
def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> CommandT:
kwargs.setdefault('parent', self)
kwargs.setdefault("parent", self)
result = command(name=name, cls=cls, *args, **kwargs)(func)
self.add_command(result)
return result
@@ -1341,12 +1358,10 @@ class GroupMixin(Generic[CogT]):
cls: Type[Group[CogT, P, T]] = ...,
*args: Any,
**kwargs: Any,
) -> Callable[[
Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]]
]
], Group[CogT, P, T]]:
) -> Callable[
[Union[Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]]]],
Group[CogT, P, T],
]:
...
@overload
@@ -1374,14 +1389,16 @@ class GroupMixin(Generic[CogT]):
Callable[..., :class:`Group`]
A decorator that converts the provided method into a Group, adds it to the bot, then returns it.
"""
def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> GroupT:
kwargs.setdefault('parent', self)
kwargs.setdefault("parent", self)
result = group(name=name, cls=cls, *args, **kwargs)(func)
self.add_command(result)
return result
return decorator
class Group(GroupMixin[CogT], Command[CogT, P, T]):
"""A class that implements a grouping protocol for commands to be
executed as subcommands.
@@ -1404,8 +1421,9 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
Indicates if the group's commands should be case insensitive.
Defaults to ``False``.
"""
def __init__(self, *args: Any, **attrs: Any) -> None:
self.invoke_without_command: bool = attrs.pop('invoke_without_command', False)
self.invoke_without_command: bool = attrs.pop("invoke_without_command", False)
super().__init__(*args, **attrs)
def copy(self: GroupT) -> GroupT:
@@ -1492,8 +1510,10 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
view.previous = previous
await super().reinvoke(ctx, call_hooks=call_hooks)
# Decorators
@overload
def command(
name: str = ...,
@@ -1505,10 +1525,12 @@ def command(
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
]
, Command[CogT, P, T]]:
],
Command[CogT, P, T],
]:
...
@overload
def command(
name: str = ...,
@@ -1520,22 +1542,23 @@ def command(
Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
Callable[Concatenate[ContextT, P], Coro[Any]],
]
]
, CommandT]:
],
CommandT,
]:
...
def command(
name: str = MISSING,
cls: Type[CommandT] = MISSING,
**attrs: Any
name: str = MISSING, cls: Type[CommandT] = MISSING, **attrs: Any
) -> Callable[
[
Union[
Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
]
]
, Union[Command[CogT, P, T], CommandT]]:
],
Union[Command[CogT, P, T], CommandT],
]:
"""A decorator that transforms a function into a :class:`.Command`
or if called with :func:`.group`, :class:`.Group`.
@@ -1568,16 +1591,19 @@ def command(
if cls is MISSING:
cls = Command # type: ignore
def decorator(func: Union[
def decorator(
func: Union[
Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
]) -> CommandT:
]
) -> CommandT:
if isinstance(func, Command):
raise TypeError('Callback is already a command.')
raise TypeError("Callback is already a command.")
return cls(func, name=name, **attrs)
return decorator
@overload
def group(
name: str = ...,
@@ -1589,10 +1615,12 @@ def group(
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
]
, Group[CogT, P, T]]:
],
Group[CogT, P, T],
]:
...
@overload
def group(
name: str = ...,
@@ -1604,10 +1632,12 @@ def group(
Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
Callable[Concatenate[ContextT, P], Coro[Any]],
]
]
, GroupT]:
],
GroupT,
]:
...
def group(
name: str = MISSING,
cls: Type[GroupT] = MISSING,
@@ -1618,8 +1648,9 @@ def group(
Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
]
]
, Union[Group[CogT, P, T], GroupT]]:
],
Union[Group[CogT, P, T], GroupT],
]:
"""A decorator that transforms a function into a :class:`.Group`.
This is similar to the :func:`.command` decorator but the ``cls``
@@ -1632,6 +1663,7 @@ def group(
cls = Group # type: ignore
return command(name=name, cls=cls, **attrs) # type: ignore
def check(predicate: Check) -> Callable[[T], T]:
r"""A decorator that adds a check to the :class:`.Command` or its
subclasses. These checks could be accessed via :attr:`.Command.checks`.
@@ -1707,7 +1739,7 @@ def check(predicate: Check) -> Callable[[T], T]:
if isinstance(func, Command):
func.checks.append(predicate)
else:
if not hasattr(func, '__commands_checks__'):
if not hasattr(func, "__commands_checks__"):
func.__commands_checks__ = []
func.__commands_checks__.append(predicate)
@@ -1717,13 +1749,16 @@ def check(predicate: Check) -> Callable[[T], T]:
if inspect.iscoroutinefunction(predicate):
decorator.predicate = predicate
else:
@functools.wraps(predicate)
async def wrapper(ctx):
return predicate(ctx) # type: ignore
decorator.predicate = wrapper
return decorator # type: ignore
def check_any(*checks: Check) -> Callable[[T], T]:
r"""A :func:`check` that is added that checks if any of the checks passed
will pass, i.e. using logical OR.
@@ -1773,7 +1808,7 @@ def check_any(*checks: Check) -> Callable[[T], T]:
try:
pred = wrapped.predicate
except AttributeError:
raise TypeError(f'{wrapped!r} must be wrapped by commands.check decorator') from None
raise TypeError(f"{wrapped!r} must be wrapped by commands.check decorator") from None
else:
unwrapped.append(pred)
@@ -1792,6 +1827,7 @@ def check_any(*checks: Check) -> Callable[[T], T]:
return check(predicate)
def has_role(item: Union[int, str]) -> Callable[[T], T]:
"""A :func:`.check` that is added that checks if the member invoking the
command has the role specified via the name or ID specified.
@@ -1834,6 +1870,7 @@ def has_role(item: Union[int, str]) -> Callable[[T], T]:
return check(predicate)
def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
r"""A :func:`.check` that is added that checks if the member invoking the
command has **any** of the roles specified. This means that if they have
@@ -1865,18 +1902,22 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
async def cool(ctx):
await ctx.send('You are cool indeed')
"""
def predicate(ctx):
if ctx.guild is None:
raise NoPrivateMessage()
# ctx.guild is None doesn't narrow ctx.author to Member
getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
if any(
getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items
):
return True
raise MissingAnyRole(list(items))
return check(predicate)
def bot_has_role(item: int) -> Callable[[T], T]:
"""Similar to :func:`.has_role` except checks if the bot itself has the
role.
@@ -1903,8 +1944,10 @@ def bot_has_role(item: int) -> Callable[[T], T]:
if role is None:
raise BotMissingRole(item)
return True
return check(predicate)
def bot_has_any_role(*items: int) -> Callable[[T], T]:
"""Similar to :func:`.has_any_role` except checks if the bot itself has
any of the roles listed.
@@ -1918,17 +1961,22 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
Raise :exc:`.BotMissingAnyRole` or :exc:`.NoPrivateMessage`
instead of generic checkfailure
"""
def predicate(ctx):
if ctx.guild is None:
raise NoPrivateMessage()
me = ctx.me
getter = functools.partial(discord.utils.get, me.roles)
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
if any(
getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items
):
return True
raise BotMissingAnyRole(list(items))
return check(predicate)
def has_permissions(**perms: bool) -> Callable[[T], T]:
"""A :func:`.check` that is added that checks if the member has all of
the permissions necessary.
@@ -1976,6 +2024,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
"""Similar to :func:`.has_permissions` except checks if the bot itself has
the permissions listed.
@@ -2002,6 +2051,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
"""Similar to :func:`.has_permissions`, but operates on guild wide
permissions instead of the current channel permissions.
@@ -2030,6 +2080,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
"""Similar to :func:`.has_guild_permissions`, but checks the bot
members guild permissions.
@@ -2055,6 +2106,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def dm_only() -> Callable[[T], T]:
"""A :func:`.check` that indicates this command must only be used in a
DM context. Only private messages are allowed when
@@ -2073,6 +2125,7 @@ def dm_only() -> Callable[[T], T]:
return check(predicate)
def guild_only() -> Callable[[T], T]:
"""A :func:`.check` that indicates this command must only be used in a
guild context only. Basically, no private messages are allowed when
@@ -2089,6 +2142,7 @@ def guild_only() -> Callable[[T], T]:
return check(predicate)
def is_owner() -> Callable[[T], T]:
"""A :func:`.check` that checks if the person invoking this command is the
owner of the bot.
@@ -2101,11 +2155,12 @@ def is_owner() -> Callable[[T], T]:
async def predicate(ctx: Context) -> bool:
if not await ctx.bot.is_owner(ctx.author):
raise NotOwner('You do not own this bot.')
raise NotOwner("You do not own this bot.")
return True
return check(predicate)
def is_nsfw() -> Callable[[T], T]:
"""A :func:`.check` that checks if the channel is a NSFW channel.
@@ -2117,14 +2172,19 @@ def is_nsfw() -> Callable[[T], T]:
Raise :exc:`.NSFWChannelRequired` instead of generic :exc:`.CheckFailure`.
DM channels will also now pass this check.
"""
def pred(ctx: Context) -> bool:
ch = ctx.channel
if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()):
return True
raise NSFWChannelRequired(ch) # type: ignore
return check(pred)
def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default) -> Callable[[T], T]:
def cooldown(
rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default
) -> Callable[[T], T]:
"""A decorator that adds a cooldown to a :class:`.Command`
A cooldown allows a command to only be used a specific amount
@@ -2157,9 +2217,13 @@ def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message],
else:
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type)
return func
return decorator # type: ignore
def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default) -> Callable[[T], T]:
def dynamic_cooldown(
cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default
) -> Callable[[T], T]:
"""A decorator that adds a dynamic cooldown to a :class:`.Command`
This differs from :func:`.cooldown` in that it takes a function that
@@ -2197,8 +2261,10 @@ def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type
else:
func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type)
return func
return decorator # type: ignore
def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]:
"""A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses.
@@ -2230,8 +2296,10 @@ def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait:
else:
func.__commands_max_concurrency__ = value
return func
return decorator # type: ignore
def before_invoke(coro) -> Callable[[T], T]:
"""A decorator that registers a coroutine as a pre-invoke hook.
@@ -2270,14 +2338,17 @@ def before_invoke(coro) -> Callable[[T], T]:
bot.add_cog(What())
"""
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
if isinstance(func, Command):
func.before_invoke(coro)
else:
func.__before_invoke__ = coro
return func
return decorator # type: ignore
def after_invoke(coro) -> Callable[[T], T]:
"""A decorator that registers a coroutine as a post-invoke hook.
@@ -2286,10 +2357,12 @@ def after_invoke(coro) -> Callable[[T], T]:
.. versionadded:: 1.4
"""
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
if isinstance(func, Command):
func.after_invoke(coro)
else:
func.__after_invoke__ = coro
return func
return decorator # type: ignore