Format with black

This commit is contained in:
StockerMC
2021-10-15 13:58:15 -04:00
parent 23b390971f
commit 161affa246
2 changed files with 294 additions and 88 deletions

View File

@@ -161,7 +161,9 @@ class ObjectConverter(IDConverter[discord.Object]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Object:
match = self._get_id_match(argument) or re.match(r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument)
match = self._get_id_match(argument) or re.match(
r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument
)
if match is None:
raise ObjectNotFound(argument)
@@ -198,10 +200,14 @@ class MemberConverter(IDConverter[discord.Member]):
if len(argument) > 5 and argument[-5] == "#":
username, _, discriminator = argument.rpartition("#")
members = await guild.query_members(username, limit=100, cache=cache)
return discord.utils.get(members, name=username, discriminator=discriminator)
return discord.utils.get(
members, name=username, discriminator=discriminator
)
else:
members = await guild.query_members(argument, limit=100, cache=cache)
return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members)
return discord.utils.find(
lambda m: m.name == argument or m.nick == argument, members
)
async def query_member_by_id(self, bot, guild, user_id):
ws = bot._get_websocket(shard_id=guild.shard_id)
@@ -226,7 +232,9 @@ class MemberConverter(IDConverter[discord.Member]):
async def convert(self, ctx: Context, argument: str) -> discord.Member:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument)
match = self._get_id_match(argument) or re.match(
r"<@!?([0-9]{15,20})>$", argument
)
guild = ctx.guild
result = None
user_id = None
@@ -239,7 +247,9 @@ class MemberConverter(IDConverter[discord.Member]):
else:
user_id = int(match.group(1))
if guild:
result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id)
result = guild.get_member(user_id) or _utils_get(
ctx.message.mentions, id=user_id
)
else:
result = _get_from_guilds(bot, "get_member", user_id)
@@ -279,13 +289,17 @@ class UserConverter(IDConverter[discord.User]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.User:
match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument)
match = self._get_id_match(argument) or re.match(
r"<@!?([0-9]{15,20})>$", argument
)
result = None
state = ctx._state
if match is not None:
user_id = int(match.group(1))
result = ctx.bot.get_user(user_id) or _utils_get(ctx.message.mentions, id=user_id)
result = ctx.bot.get_user(user_id) or _utils_get(
ctx.message.mentions, id=user_id
)
if result is None:
try:
result = await ctx.bot.fetch_user(user_id)
@@ -333,7 +347,9 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod
def _get_id_matches(ctx, argument):
id_regex = re.compile(r"(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$")
id_regex = re.compile(
r"(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$"
)
link_regex = re.compile(
r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/"
r"(?P<guild_id>[0-9]{15,20}|@me)"
@@ -355,7 +371,9 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
return guild_id, message_id, channel_id
@staticmethod
def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]:
def _resolve_channel(
ctx, guild_id, channel_id
) -> Optional[PartialMessageableChannel]:
if guild_id is not None:
guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None:
@@ -389,7 +407,9 @@ class MessageConverter(IDConverter[discord.Message]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Message:
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument)
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(
ctx, argument
)
message = ctx.bot._connection._get_message(message_id)
if message:
return message
@@ -420,13 +440,19 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel:
return self._resolve_channel(ctx, argument, "channels", discord.abc.GuildChannel)
return self._resolve_channel(
ctx, argument, "channels", discord.abc.GuildChannel
)
@staticmethod
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT:
def _resolve_channel(
ctx: Context, argument: str, attribute: str, type: Type[CT]
) -> CT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument)
match = IDConverter._get_id_match(argument) or re.match(
r"<#([0-9]{15,20})>$", argument
)
result = None
guild = ctx.guild
@@ -454,10 +480,14 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
return result
@staticmethod
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT:
def _resolve_thread(
ctx: Context, argument: str, attribute: str, type: Type[TT]
) -> TT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument)
match = IDConverter._get_id_match(argument) or re.match(
r"<#([0-9]{15,20})>$", argument
)
result = None
guild = ctx.guild
@@ -494,7 +524,9 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, "text_channels", discord.TextChannel)
return GuildChannelConverter._resolve_channel(
ctx, argument, "text_channels", discord.TextChannel
)
class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
@@ -514,7 +546,9 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, "voice_channels", discord.VoiceChannel)
return GuildChannelConverter._resolve_channel(
ctx, argument, "voice_channels", discord.VoiceChannel
)
class StageChannelConverter(IDConverter[discord.StageChannel]):
@@ -533,7 +567,9 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, "stage_channels", discord.StageChannel)
return GuildChannelConverter._resolve_channel(
ctx, argument, "stage_channels", discord.StageChannel
)
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
@@ -553,7 +589,9 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, "categories", discord.CategoryChannel)
return GuildChannelConverter._resolve_channel(
ctx, argument, "categories", discord.CategoryChannel
)
class StoreChannelConverter(IDConverter[discord.StoreChannel]):
@@ -572,7 +610,9 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, "channels", discord.StoreChannel)
return GuildChannelConverter._resolve_channel(
ctx, argument, "channels", discord.StoreChannel
)
class ThreadConverter(IDConverter[discord.Thread]):
@@ -590,7 +630,9 @@ class ThreadConverter(IDConverter[discord.Thread]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Thread:
return GuildChannelConverter._resolve_thread(ctx, argument, "threads", discord.Thread)
return GuildChannelConverter._resolve_thread(
ctx, argument, "threads", discord.Thread
)
class ColourConverter(Converter[discord.Colour]):
@@ -619,7 +661,9 @@ class ColourConverter(Converter[discord.Colour]):
Added support for ``rgb`` function and 3-digit hex shortcuts
"""
RGB_REGEX = re.compile(r"rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)")
RGB_REGEX = re.compile(
r"rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)"
)
def parse_hex_number(self, argument):
arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument
@@ -700,7 +744,9 @@ class RoleConverter(IDConverter[discord.Role]):
if not guild:
raise NoPrivateMessage()
match = self._get_id_match(argument) or re.match(r"<@&([0-9]{15,20})>$", argument)
match = self._get_id_match(argument) or re.match(
r"<@&([0-9]{15,20})>$", argument
)
if match:
result = guild.get_role(int(match.group(1)))
else:
@@ -779,7 +825,9 @@ class EmojiConverter(IDConverter[discord.Emoji]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Emoji:
match = self._get_id_match(argument) or re.match(r"<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$", argument)
match = self._get_id_match(argument) or re.match(
r"<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$", argument
)
result = None
bot = ctx.bot
guild = ctx.guild
@@ -821,7 +869,10 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
emoji_id = int(match.group(3))
return discord.PartialEmoji.with_state(
ctx.bot._connection, animated=emoji_animated, name=emoji_name, id=emoji_id
ctx.bot._connection,
animated=emoji_animated,
name=emoji_name,
id=emoji_id,
)
raise PartialEmojiConversionFailure(argument)
@@ -906,7 +957,11 @@ class clean_content(Converter[str]):
def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id)
return f"@{m.display_name if self.use_nicknames else m.name}" if m else "@deleted-user"
return (
f"@{m.display_name if self.use_nicknames else m.name}"
if m
else "@deleted-user"
)
def resolve_role(id: int) -> str:
r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id)
@@ -996,7 +1051,11 @@ class Greedy(List[T]):
origin = getattr(converter, "__origin__", None)
args = getattr(converter, "__args__", ())
if not (callable(converter) or isinstance(converter, Converter) or origin is not None):
if not (
callable(converter)
or isinstance(converter, Converter)
or origin is not None
):
raise TypeError("Greedy[...] expects a type or a Converter instance.")
if converter in (str, type(None)) or origin is Greedy:
@@ -1044,7 +1103,13 @@ class Option(Generic[T, DT]): # type: ignore
"name",
)
def __init__(self, default: T = inspect.Parameter.empty, *, description: DT, name: str = discord.utils.MISSING) -> None:
def __init__(
self,
default: T = inspect.Parameter.empty,
*,
description: DT,
name: str = discord.utils.MISSING,
) -> None:
self.description = description
self.default = default
self.name: str = name
@@ -1052,7 +1117,12 @@ class Option(Generic[T, DT]): # type: ignore
if TYPE_CHECKING:
# Terrible workaround for type checking reasons
def Option(default: T = inspect.Parameter.empty, *, description: str, name: str = discord.utils.MISSING) -> T:
def Option(
default: T = inspect.Parameter.empty,
*,
description: str,
name: str = discord.utils.MISSING,
) -> T:
...
@@ -1107,7 +1177,9 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = {
}
async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter):
async def _actual_conversion(
ctx: Context, converter, argument: str, param: inspect.Parameter
):
if converter is bool:
return _convert_to_bool(argument)
@@ -1116,7 +1188,9 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
except AttributeError:
pass
else:
if module is not None and (module.startswith("discord.") and not module.endswith("converter")):
if module is not None and (
module.startswith("discord.") and not module.endswith("converter")
):
converter = CONVERTER_MAPPING.get(converter, converter)
try:
@@ -1142,10 +1216,14 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
except AttributeError:
name = converter.__class__.__name__
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
raise BadArgument(
f'Converting to "{name}" failed for parameter "{param.name}".'
) from exc
async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter):
async def run_converters(
ctx: Context, converter, argument: str, param: inspect.Parameter
):
"""|coro|
Runs converters for a given converter, argument, and parameter.

View File

@@ -53,7 +53,13 @@ from operator import itemgetter
import discord
from .errors import *
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping
from .cooldowns import (
Cooldown,
BucketType,
CooldownMapping,
MaxConcurrency,
DynamicCooldownMapping,
)
from .converter import (
CONVERTER_MAPPING,
Converter,
@@ -74,7 +80,10 @@ if TYPE_CHECKING:
from typing_extensions import Concatenate, ParamSpec, TypeGuard
from discord.message import Message
from discord.types.interactions import EditApplicationCommand, ApplicationCommandInteractionDataOption
from discord.types.interactions import (
EditApplicationCommand,
ApplicationCommandInteractionDataOption,
)
from ._types import (
Coro,
@@ -394,7 +403,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.slash_command: Optional[bool] = kwargs.get("slash_command", None)
self.message_command: Optional[bool] = kwargs.get("message_command", None)
self.slash_command_guilds: Optional[Iterable[int]] = kwargs.get("slash_command_guilds", None)
self.slash_command_guilds: Optional[Iterable[int]] = kwargs.get(
"slash_command_guilds", None
)
help_doc = kwargs.get("help")
if help_doc is not None:
@@ -413,7 +424,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.extras: Dict[str, Any] = kwargs.get("extras", {})
if not isinstance(self.aliases, (list, tuple)):
raise TypeError("Aliases of a command must be a list or a tuple of strings.")
raise TypeError(
"Aliases of a command must be a list or a tuple of strings."
)
self.description: str = inspect.cleandoc(kwargs.get("description", ""))
self.hidden: bool = kwargs.get("hidden", False)
@@ -439,7 +452,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
elif isinstance(cooldown, CooldownMapping):
buckets = cooldown
else:
raise TypeError("Cooldown must be a an instance of CooldownMapping or None.")
raise TypeError(
"Cooldown must be a an instance of CooldownMapping or None."
)
self.checks: List[Check] = checks
self._buckets: CooldownMapping = buckets
@@ -480,7 +495,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
@property
def callback(
self,
) -> Union[Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]],]:
) -> Union[
Callable[Concatenate[CogT, Context, P], Coro[T]],
Callable[Concatenate[Context, P], Coro[T]],
]:
return self._callback
@callback.setter
@@ -500,7 +518,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
except AttributeError:
globalns = {}
self.params, self.option_descriptions = get_signature_parameters(function, globalns)
self.params, self.option_descriptions = get_signature_parameters(
function, globalns
)
def _update_attrs(self, **command_attrs: Any):
for key, value in command_attrs.items():
@@ -634,7 +654,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
required = param.default is param.empty
converter = get_converter(param)
consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw
consume_rest_is_special = (
param.kind == param.KEYWORD_ONLY and not self.rest_is_raw
)
view = ctx.view
view.skip_ws()
@@ -642,9 +664,13 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
# it undos the view ready for the next parameter to use instead
if isinstance(converter, Greedy):
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
return await self._transform_greedy_pos(ctx, param, required, converter.converter)
return await self._transform_greedy_pos(
ctx, param, required, converter.converter
)
elif param.kind == param.VAR_POSITIONAL:
return await self._transform_greedy_var_pos(ctx, param, converter.converter)
return await self._transform_greedy_var_pos(
ctx, param, converter.converter
)
else:
# if we're here, then it's a KEYWORD_ONLY param type
# since this is mostly useless, we'll helpfully transform Greedy[X]
@@ -657,7 +683,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if required:
if self._is_typing_optional(param.annotation):
return None
if hasattr(converter, "__commands_is_flag__") and converter._can_be_constructible():
if (
hasattr(converter, "__commands_is_flag__")
and converter._can_be_constructible()
):
return await converter._construct_default(ctx)
raise MissingRequiredArgument(param)
return param.default
@@ -702,7 +731,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return param.default
return result
async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any:
async def _transform_greedy_var_pos(
self, ctx: Context, param: inspect.Parameter, converter: Any
) -> Any:
view = ctx.view
previous = view.index
try:
@@ -816,13 +847,17 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
try:
next(iterator)
except StopIteration:
raise discord.ClientException(f'Callback for {self.name} command is missing "self" parameter.')
raise discord.ClientException(
f'Callback for {self.name} command is missing "self" parameter.'
)
# next we have the 'ctx' as the next parameter
try:
next(iterator)
except StopIteration:
raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.')
raise discord.ClientException(
f'Callback for {self.name} command is missing "ctx" parameter.'
)
for name, param in iterator:
ctx.current_parameter = param
@@ -849,7 +884,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
break
if not self.ignore_extra and not view.eof:
raise TooManyArguments("Too many arguments passed to " + self.qualified_name)
raise TooManyArguments(
"Too many arguments passed to " + self.qualified_name
)
async def call_before_hooks(self, ctx: Context) -> None:
# now that we're done preparing we can call the pre-command hooks
@@ -909,7 +946,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
ctx.command = self
if not await self.can_run(ctx):
raise CheckFailure(f"The check functions for command {self.qualified_name} failed.")
raise CheckFailure(
f"The check functions for command {self.qualified_name} failed."
)
if self._max_concurrency is not None:
# For this application, context can be duck-typed as a Message
@@ -1118,7 +1157,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return self.help.split("\n", 1)[0]
return ""
def _is_typing_optional(self, annotation: Union[T, Optional[T]]) -> TypeGuard[Optional[T]]:
def _is_typing_optional(
self, annotation: Union[T, Optional[T]]
) -> TypeGuard[Optional[T]]:
return getattr(annotation, "__origin__", None) is Union and type(None) in annotation.__args__ # type: ignore
@property
@@ -1149,13 +1190,24 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
origin = getattr(annotation, "__origin__", None)
if origin is Literal:
name = "|".join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
name = "|".join(
f'"{v}"' if isinstance(v, str) else str(v)
for v in annotation.__args__
)
if param.default is not param.empty:
# We don't want None or '' to trigger the [name=value] case and instead it should
# do [name] since [name=None] or [name=] are not exactly useful for the user.
should_print = param.default if isinstance(param.default, str) else param.default is not None
should_print = (
param.default
if isinstance(param.default, str)
else param.default is not None
)
if should_print:
result.append(f"[{name}={param.default}]" if not greedy else f"[{name}={param.default}]...")
result.append(
f"[{name}={param.default}]"
if not greedy
else f"[{name}={param.default}]..."
)
continue
else:
result.append(f"[{name}]")
@@ -1204,21 +1256,29 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
raise DisabledCommand(f"{self.name} command is disabled")
if ctx.interaction is None and (
self.message_command is False or (self.message_command is None and not ctx.bot.message_commands)
self.message_command is False
or (self.message_command is None and not ctx.bot.message_commands)
):
raise DisabledCommand(f"{self.name} command cannot be run as a message command")
raise DisabledCommand(
f"{self.name} command cannot be run as a message command"
)
if ctx.interaction is not None and (
self.slash_command is False or (self.slash_command is None and not ctx.bot.slash_commands)
self.slash_command is False
or (self.slash_command is None and not ctx.bot.slash_commands)
):
raise DisabledCommand(f"{self.name} command cannot be run as a slash command")
raise DisabledCommand(
f"{self.name} command cannot be run as a slash command"
)
original = ctx.command
ctx.command = self
try:
if not await ctx.bot.can_run(ctx):
raise CheckFailure(f"The global check functions for command {self.qualified_name} failed.")
raise CheckFailure(
f"The global check functions for command {self.qualified_name} failed."
)
cog = self.cog
if cog is not None:
@@ -1246,7 +1306,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
param
for name, flag in annotation.get_flags().items()
for param in self._param_to_options(
name, flag.annotation, required=flag.required, varadic=flag.annotation is tuple
name,
flag.annotation,
required=flag.required,
varadic=flag.annotation is tuple,
)
]
@@ -1279,16 +1342,27 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
option["type"] = discord_type
# Set channel types
if discord_type == 7:
option["channel_types"] = application_option_channel_types[annotation]
option["channel_types"] = application_option_channel_types[
annotation
]
break
elif origin is Union:
if annotation in {Union[discord.Member, discord.Role], Union[MemberConverter, RoleConverter]}:
if annotation in {
Union[discord.Member, discord.Role],
Union[MemberConverter, RoleConverter],
}:
option["type"] = 9
elif all([arg in application_option_channel_types for arg in annotation.__args__]):
elif all(
[arg in application_option_channel_types for arg in annotation.__args__]
):
option["type"] = 7
option["channel_types"] = [discord_value for arg in annotation.__args__ for discord_value in application_option_channel_types[arg]]
option["channel_types"] = [
discord_value
for arg in annotation.__args__
for discord_value in application_option_channel_types[arg]
]
elif origin is Literal:
literal_values = annotation.__args__
@@ -1300,18 +1374,27 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
option["type"] = application_option_type_lookup[python_type]
option["choices"] = [
{"name": literal_value, "value": literal_value} for literal_value in annotation.__args__
{"name": literal_value, "value": literal_value}
for literal_value in annotation.__args__
]
return [option] # type: ignore
def to_application_command(self, nested: int = 0) -> Optional[EditApplicationCommand]:
def to_application_command(
self, nested: int = 0
) -> Optional[EditApplicationCommand]:
if self.slash_command is False:
return
elif nested == 3:
raise ApplicationCommandRegistrationError(self, f"{self.qualified_name} is too deeply nested!")
raise ApplicationCommandRegistrationError(
self, f"{self.qualified_name} is too deeply nested!"
)
payload = {"name": self.name, "description": self.short_doc or "no description", "options": []}
payload = {
"name": self.name,
"description": self.short_doc or "no description",
"options": [],
}
if nested != 0:
payload["type"] = 1
@@ -1319,15 +1402,23 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
options = self._param_to_options(
name,
param.annotation if param.annotation is not param.empty else str,
varadic=param.kind == param.KEYWORD_ONLY or isinstance(param.annotation, Greedy),
required=(param.default is param.empty and not self._is_typing_optional(param.annotation))
varadic=param.kind == param.KEYWORD_ONLY
or isinstance(param.annotation, Greedy),
required=(
param.default is param.empty
and not self._is_typing_optional(param.annotation)
)
or param.kind == param.VAR_POSITIONAL,
)
if options is not None:
payload["options"].extend(option for option in options if option is not None)
payload["options"].extend(
option for option in options if option is not None
)
# Now we have all options, make sure required is before optional.
payload["options"] = sorted(payload["options"], key=itemgetter("required"), reverse=True)
payload["options"] = sorted(
payload["options"], key=itemgetter("required"), reverse=True
)
return payload # type: ignore
@@ -1346,7 +1437,9 @@ class GroupMixin(Generic[CogT]):
def __init__(self, *args: Any, **kwargs: Any) -> None:
case_insensitive = kwargs.get("case_insensitive", True)
self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {}
self.all_commands: Dict[str, Command[CogT, Any, Any]] = (
_CaseInsensitiveDict() if case_insensitive else {}
)
self.case_insensitive: bool = case_insensitive
super().__init__(*args, **kwargs)
@@ -1552,7 +1645,12 @@ class GroupMixin(Generic[CogT]):
*args: Any,
**kwargs: Any,
) -> Callable[
[Union[Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]]]],
[
Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
],
Group[CogT, P, T],
]:
...
@@ -1703,18 +1801,23 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
view.previous = previous
await super().reinvoke(ctx, call_hooks=call_hooks)
def to_application_command(self, nested: int = 0) -> Optional[EditApplicationCommand]:
def to_application_command(
self, nested: int = 0
) -> Optional[EditApplicationCommand]:
if self.slash_command is False:
return
elif nested == 2:
raise ApplicationCommandRegistrationError(self, f"{self.qualified_name} is too deeply nested!")
raise ApplicationCommandRegistrationError(
self, f"{self.qualified_name} is too deeply nested!"
)
return { # type: ignore
"name": self.name,
"type": int(not (nested - 1)) + 1,
"description": self.short_doc or "no description",
"options": [
cmd.to_application_command(nested=nested + 1) for cmd in sorted(self.commands, key=lambda x: x.name)
cmd.to_application_command(nested=nested + 1)
for cmd in sorted(self.commands, key=lambda x: x.name)
],
}
@@ -2022,7 +2125,9 @@ def check_any(*checks: Check) -> Callable[[T], T]:
try:
pred = wrapped.predicate
except AttributeError:
raise TypeError(f"{wrapped!r} must be wrapped by commands.check decorator") from None
raise TypeError(
f"{wrapped!r} must be wrapped by commands.check decorator"
) from None
else:
unwrapped.append(pred)
@@ -2124,7 +2229,10 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
# ctx.guild is None doesn't narrow ctx.author to Member
getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore
if any(
getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items
getter(id=item) is not None
if isinstance(item, int)
else getter(name=item) is not None
for item in items
):
return True
raise MissingAnyRole(list(items))
@@ -2183,7 +2291,10 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
me = ctx.me
getter = functools.partial(discord.utils.get, me.roles)
if any(
getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items
getter(id=item) is not None
if isinstance(item, int)
else getter(name=item) is not None
for item in items
):
return True
raise BotMissingAnyRole(list(items))
@@ -2229,7 +2340,9 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
ch = ctx.channel
permissions = ch.permissions_for(ctx.author) # type: ignore
missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value]
missing = [
perm for perm, value in perms.items() if getattr(permissions, perm) != value
]
if not missing:
return True
@@ -2256,7 +2369,9 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
me = guild.me if guild is not None else ctx.bot.user
permissions = ctx.channel.permissions_for(me) # type: ignore
missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value]
missing = [
perm for perm, value in perms.items() if getattr(permissions, perm) != value
]
if not missing:
return True
@@ -2285,7 +2400,9 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
raise NoPrivateMessage
permissions = ctx.author.guild_permissions # type: ignore
missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value]
missing = [
perm for perm, value in perms.items() if getattr(permissions, perm) != value
]
if not missing:
return True
@@ -2311,7 +2428,9 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
raise NoPrivateMessage
permissions = ctx.me.guild_permissions # type: ignore
missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value]
missing = [
perm for perm, value in perms.items() if getattr(permissions, perm) != value
]
if not missing:
return True
@@ -2389,7 +2508,9 @@ def is_nsfw() -> Callable[[T], T]:
def pred(ctx: Context) -> bool:
ch = ctx.channel
if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()):
if ctx.guild is None or (
isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()
):
return True
raise NSFWChannelRequired(ch) # type: ignore
@@ -2397,7 +2518,9 @@ def is_nsfw() -> Callable[[T], T]:
def cooldown(
rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default
rate: int,
per: float,
type: Union[BucketType, Callable[[Message], Any]] = BucketType.default,
) -> Callable[[T], T]:
"""A decorator that adds a cooldown to a :class:`.Command`
@@ -2432,14 +2555,17 @@ def cooldown(
if not hasattr(func, "__command_attrs__"):
func.__command_attrs__ = {}
func.__command_attrs__["cooldown"] = CooldownMapping(Cooldown(rate, per), type)
func.__command_attrs__["cooldown"] = CooldownMapping(
Cooldown(rate, per), type
)
return func
return decorator # type: ignore
def dynamic_cooldown(
cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default
cooldown: Union[BucketType, Callable[[Message], Any]],
type: BucketType = BucketType.default,
) -> Callable[[T], T]:
"""A decorator that adds a dynamic cooldown to a :class:`.Command`
@@ -2485,7 +2611,9 @@ def dynamic_cooldown(
return decorator # type: ignore
def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]:
def max_concurrency(
number: int, per: BucketType = BucketType.default, *, wait: bool = False
) -> Callable[[T], T]:
"""A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses.
This enables you to only allow a certain number of command invocations at the same time,