Fix code style issues with Black

This commit is contained in:
Lint Action
2021-09-05 21:34:20 +00:00
parent a23dae8604
commit 7513c2138f
108 changed files with 5369 additions and 4858 deletions

View File

@ -31,7 +31,7 @@ if TYPE_CHECKING:
from .cog import Cog
from .errors import CommandError
T = TypeVar('T')
T = TypeVar("T")
Coro = Coroutine[Any, Any, T]
MaybeCoro = Union[T, Coro[T]]
@ -39,7 +39,9 @@ CoroFunc = Callable[..., Coro[Any]]
Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]]
Error = Union[
Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]
]
# This is merely a tag type to avoid circular import issues.

View File

@ -54,17 +54,18 @@ if TYPE_CHECKING:
)
__all__ = (
'when_mentioned',
'when_mentioned_or',
'Bot',
'AutoShardedBot',
"when_mentioned",
"when_mentioned_or",
"Bot",
"AutoShardedBot",
)
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context')
T = TypeVar("T")
CFT = TypeVar("CFT", bound="CoroFunc")
CXT = TypeVar("CXT", bound="Context")
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned.
@ -72,7 +73,8 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
"""
# bot.user will never be None when this is called
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
return [f"<@{bot.user.id}> ", f"<@!{bot.user.id}> "] # type: ignore
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided.
@ -103,6 +105,7 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
----------
:func:`.when_mentioned`
"""
def inner(bot, msg):
r = list(prefixes)
r = when_mentioned(bot, msg) + r
@ -110,15 +113,19 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
return inner
def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".")
class _DefaultRepr:
def __repr__(self):
return '<default-help-command>'
return "<default-help-command>"
_default = _DefaultRepr()
class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, *, intents: discord.Intents, **options):
super().__init__(**options, intents=intents)
@ -131,16 +138,16 @@ class BotBase(GroupMixin):
self._before_invoke = None
self._after_invoke = None
self._help_command = None
self.description = inspect.cleandoc(description) if description else ''
self.owner_id = options.get('owner_id')
self.owner_ids = options.get('owner_ids', set())
self.strip_after_prefix = options.get('strip_after_prefix', False)
self.description = inspect.cleandoc(description) if description else ""
self.owner_id = options.get("owner_id")
self.owner_ids = options.get("owner_ids", set())
self.strip_after_prefix = options.get("strip_after_prefix", False)
if self.owner_id and self.owner_ids:
raise TypeError('Both owner_id and owner_ids are set.')
raise TypeError("Both owner_id and owner_ids are set.")
if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection):
raise TypeError(f'owner_ids must be a collection not {self.owner_ids.__class__!r}')
raise TypeError(f"owner_ids must be a collection not {self.owner_ids.__class__!r}")
if help_command is _default:
self.help_command = DefaultHelpCommand()
@ -152,7 +159,7 @@ class BotBase(GroupMixin):
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
# super() will resolve to Client
super().dispatch(event_name, *args, **kwargs) # type: ignore
ev = 'on_' + event_name
ev = "on_" + event_name
for event in self.extra_events.get(ev, []):
self._schedule_event(event, ev, *args, **kwargs) # type: ignore
@ -182,7 +189,7 @@ class BotBase(GroupMixin):
This only fires if you do not specify any listeners for command error.
"""
if self.extra_events.get('on_command_error', None):
if self.extra_events.get("on_command_error", None):
return
command = context.command
@ -193,7 +200,7 @@ class BotBase(GroupMixin):
if cog and cog.has_error_handler():
return
print(f'Ignoring exception in command {context.command}:', file=sys.stderr)
print(f"Ignoring exception in command {context.command}:", file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
# global check registration
@ -425,7 +432,7 @@ class BotBase(GroupMixin):
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The pre-invoke hook must be a coroutine.')
raise TypeError("The pre-invoke hook must be a coroutine.")
self._before_invoke = coro
return coro
@ -458,7 +465,7 @@ class BotBase(GroupMixin):
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The post-invoke hook must be a coroutine.')
raise TypeError("The post-invoke hook must be a coroutine.")
self._after_invoke = coro
return coro
@ -490,7 +497,7 @@ class BotBase(GroupMixin):
name = func.__name__ if name is MISSING else name
if not asyncio.iscoroutinefunction(func):
raise TypeError('Listeners must be coroutines')
raise TypeError("Listeners must be coroutines")
if name in self.extra_events:
self.extra_events[name].append(func)
@ -586,14 +593,14 @@ class BotBase(GroupMixin):
"""
if not isinstance(cog, Cog):
raise TypeError('cogs must derive from Cog')
raise TypeError("cogs must derive from Cog")
cog_name = cog.__cog_name__
existing = self.__cogs.get(cog_name)
if existing is not None:
if not override:
raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
raise discord.ClientException(f"Cog named {cog_name!r} already loaded")
self.remove_cog(cog_name)
cog = cog._inject(self)
@ -681,7 +688,7 @@ class BotBase(GroupMixin):
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
try:
func = getattr(lib, 'teardown')
func = getattr(lib, "teardown")
except AttributeError:
pass
else:
@ -708,7 +715,7 @@ class BotBase(GroupMixin):
raise errors.ExtensionFailed(key, e) from e
try:
setup = getattr(lib, 'setup')
setup = getattr(lib, "setup")
except AttributeError:
del sys.modules[key]
raise errors.NoEntryPointError(key)
@ -858,11 +865,7 @@ class BotBase(GroupMixin):
raise errors.ExtensionNotLoaded(name)
# get the previous module states from sys modules
modules = {
name: module
for name, module in sys.modules.items()
if _is_submodule(lib.__name__, name)
}
modules = {name: module for name, module in sys.modules.items() if _is_submodule(lib.__name__, name)}
try:
# Unload and then load the module...
@ -895,7 +898,7 @@ class BotBase(GroupMixin):
def help_command(self, value: Optional[HelpCommand]) -> None:
if value is not None:
if not isinstance(value, HelpCommand):
raise TypeError('help_command must be a subclass of HelpCommand')
raise TypeError("help_command must be a subclass of HelpCommand")
if self._help_command is not None:
self._help_command._remove_from_bot(self)
self._help_command = value
@ -938,8 +941,10 @@ class BotBase(GroupMixin):
if isinstance(ret, collections.abc.Iterable):
raise
raise TypeError("command_prefix must be plain string, iterable of strings, or callable "
f"returning either of these, not {ret.__class__.__name__}")
raise TypeError(
"command_prefix must be plain string, iterable of strings, or callable "
f"returning either of these, not {ret.__class__.__name__}"
)
if not ret:
raise ValueError("Iterable command_prefix must contain at least one prefix")
@ -999,14 +1004,18 @@ class BotBase(GroupMixin):
except TypeError:
if not isinstance(prefix, list):
raise TypeError("get_prefix must return either a string or a list of string, "
f"not {prefix.__class__.__name__}")
raise TypeError(
"get_prefix must return either a string or a list of string, "
f"not {prefix.__class__.__name__}"
)
# It's possible a bad command_prefix got us here.
for value in prefix:
if not isinstance(value, str):
raise TypeError("Iterable command_prefix or list returned from get_prefix must "
f"contain only strings, not {value.__class__.__name__}")
raise TypeError(
"Iterable command_prefix or list returned from get_prefix must "
f"contain only strings, not {value.__class__.__name__}"
)
# Getting here shouldn't happen
raise
@ -1033,19 +1042,19 @@ class BotBase(GroupMixin):
The invocation context to invoke.
"""
if ctx.command is not None:
self.dispatch('command', ctx)
self.dispatch("command", ctx)
try:
if await self.can_run(ctx, call_once=True):
await ctx.command.invoke(ctx)
else:
raise errors.CheckFailure('The global check once functions failed.')
raise errors.CheckFailure("The global check once functions failed.")
except errors.CommandError as exc:
await ctx.command.dispatch_error(ctx, exc)
else:
self.dispatch('command_completion', ctx)
self.dispatch("command_completion", ctx)
elif ctx.invoked_with:
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
self.dispatch('command_error', ctx, exc)
self.dispatch("command_error", ctx, exc)
async def process_commands(self, message: Message) -> None:
"""|coro|
@ -1078,6 +1087,7 @@ class BotBase(GroupMixin):
async def on_message(self, message):
await self.process_commands(message)
class Bot(BotBase, discord.Client):
"""Represents a discord bot.
@ -1148,10 +1158,13 @@ class Bot(BotBase, discord.Client):
.. versionadded:: 1.7
"""
pass
class AutoShardedBot(BotBase, discord.AutoShardedClient):
"""This is similar to :class:`.Bot` except that it is inherited from
:class:`discord.AutoShardedClient` instead.
"""
pass

View File

@ -36,15 +36,16 @@ if TYPE_CHECKING:
from .core import Command
__all__ = (
'CogMeta',
'Cog',
"CogMeta",
"Cog",
)
CogT = TypeVar('CogT', bound='Cog')
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
CogT = TypeVar("CogT", bound="Cog")
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
MISSING: Any = discord.utils.MISSING
class CogMeta(type):
"""A metaclass for defining a cog.
@ -104,6 +105,7 @@ class CogMeta(type):
async def bar(self, ctx):
pass # hidden -> False
"""
__cog_name__: str
__cog_settings__: Dict[str, Any]
__cog_commands__: List[Command]
@ -111,17 +113,17 @@ class CogMeta(type):
def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
name, bases, attrs = args
attrs['__cog_name__'] = kwargs.pop('name', name)
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
attrs["__cog_name__"] = kwargs.pop("name", name)
attrs["__cog_settings__"] = kwargs.pop("command_attrs", {})
description = kwargs.pop('description', None)
description = kwargs.pop("description", None)
if description is None:
description = inspect.cleandoc(attrs.get('__doc__', ''))
attrs['__cog_description__'] = description
description = inspect.cleandoc(attrs.get("__doc__", ""))
attrs["__cog_description__"] = description
commands = {}
listeners = {}
no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})'
no_bot_cog = "Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})"
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
for base in reversed(new_cls.__mro__):
@ -136,21 +138,21 @@ class CogMeta(type):
value = value.__func__
if isinstance(value, _BaseCommand):
if is_static_method:
raise TypeError(f'Command in method {base}.{elem!r} must not be staticmethod.')
if elem.startswith(('cog_', 'bot_')):
raise TypeError(f"Command in method {base}.{elem!r} must not be staticmethod.")
if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem))
commands[elem] = value
elif inspect.iscoroutinefunction(value):
try:
getattr(value, '__cog_listener__')
getattr(value, "__cog_listener__")
except AttributeError:
continue
else:
if elem.startswith(('cog_', 'bot_')):
if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem))
listeners[elem] = value
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
listeners_as_list = []
for listener in listeners.values():
@ -169,10 +171,12 @@ class CogMeta(type):
def qualified_name(cls) -> str:
return cls.__cog_name__
def _cog_special_method(func: FuncT) -> FuncT:
func.__cog_special_method__ = None
return func
class Cog(metaclass=CogMeta):
"""The base class that all cogs must inherit from.
@ -183,6 +187,7 @@ class Cog(metaclass=CogMeta):
When inheriting from this class, the options shown in :class:`CogMeta`
are equally valid here.
"""
__cog_name__: ClassVar[str]
__cog_settings__: ClassVar[Dict[str, Any]]
__cog_commands__: ClassVar[List[Command]]
@ -199,10 +204,7 @@ class Cog(metaclass=CogMeta):
# r.e type ignore, type-checker complains about overriding a ClassVar
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
lookup = {
cmd.qualified_name: cmd
for cmd in self.__cog_commands__
}
lookup = {cmd.qualified_name: cmd for cmd in self.__cog_commands__}
# Update the Command instances dynamically as well
for command in self.__cog_commands__:
@ -255,6 +257,7 @@ class Cog(metaclass=CogMeta):
A command or group from the cog.
"""
from .core import GroupMixin
for command in self.__cog_commands__:
if command.parent is None:
yield command
@ -274,7 +277,7 @@ class Cog(metaclass=CogMeta):
@classmethod
def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]:
"""Return None if the method is not overridden. Otherwise returns the overridden method."""
return getattr(method.__func__, '__cog_special_method__', method)
return getattr(method.__func__, "__cog_special_method__", method)
@classmethod
def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]:
@ -296,14 +299,14 @@ class Cog(metaclass=CogMeta):
"""
if name is not MISSING and not isinstance(name, str):
raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.')
raise TypeError(f"Cog.listener expected str but received {name.__class__.__name__!r} instead.")
def decorator(func: FuncT) -> FuncT:
actual = func
if isinstance(actual, staticmethod):
actual = actual.__func__
if not inspect.iscoroutinefunction(actual):
raise TypeError('Listener function must be a coroutine function.')
raise TypeError("Listener function must be a coroutine function.")
actual.__cog_listener__ = True
to_assign = name or actual.__name__
try:
@ -315,6 +318,7 @@ class Cog(metaclass=CogMeta):
# to pick it up but the metaclass unfurls the function and
# thus the assignments need to be on the actual function
return func
return decorator
def has_error_handler(self) -> bool:
@ -322,7 +326,7 @@ class Cog(metaclass=CogMeta):
.. versionadded:: 1.7
"""
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
return not hasattr(self.cog_command_error.__func__, "__cog_special_method__")
@_cog_special_method
def cog_unload(self) -> None:

View File

@ -50,21 +50,19 @@ if TYPE_CHECKING:
from .help import HelpCommand
from .view import StringView
__all__ = (
'Context',
)
__all__ = ("Context",)
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog")
T = TypeVar("T")
BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar("CogT", bound="Cog")
if TYPE_CHECKING:
P = ParamSpec('P')
P = ParamSpec("P")
else:
P = TypeVar('P')
P = TypeVar("P")
class Context(discord.abc.Messageable, Generic[BotT]):
@ -123,7 +121,8 @@ class Context(discord.abc.Messageable, Generic[BotT]):
or invoked.
"""
def __init__(self,
def __init__(
self,
*,
message: Message,
bot: BotT,
@ -220,7 +219,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
cmd = self.command
view = self.view
if cmd is None:
raise ValueError('This context is not valid.')
raise ValueError("This context is not valid.")
# some state to revert to when we're done
index, previous = view.index, view.previous
@ -231,10 +230,10 @@ class Context(discord.abc.Messageable, Generic[BotT]):
if restart:
to_call = cmd.root_parent or cmd
view.index = len(self.prefix or '')
view.index = len(self.prefix or "")
view.previous = 0
self.invoked_parents = []
self.invoked_with = view.get_word() # advance to get the root command
self.invoked_with = view.get_word() # advance to get the root command
else:
to_call = cmd
@ -264,7 +263,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
.. versionadded:: 2.0
"""
if self.prefix is None:
return ''
return ""
user = self.me
# this breaks if the prefix mention is not the bot itself but I
@ -272,7 +271,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
# for this common use case rather than waste performance for the
# odd one.
pattern = re.compile(r"<@!?%s>" % user.id)
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix)
return pattern.sub("@%s" % user.display_name.replace("\\", r"\\"), self.prefix)
@property
def cog(self) -> Optional[Cog]:
@ -389,7 +388,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
await cmd.prepare_help_command(self, entity.qualified_name)
try:
if hasattr(entity, '__cog_commands__'):
if hasattr(entity, "__cog_commands__"):
injected = wrap_callback(cmd.send_cog_help)
return await injected(entity)
elif isinstance(entity, Group):

View File

@ -52,32 +52,32 @@ if TYPE_CHECKING:
__all__ = (
'Converter',
'ObjectConverter',
'MemberConverter',
'UserConverter',
'MessageConverter',
'PartialMessageConverter',
'TextChannelConverter',
'InviteConverter',
'GuildConverter',
'RoleConverter',
'GameConverter',
'ColourConverter',
'ColorConverter',
'VoiceChannelConverter',
'StageChannelConverter',
'EmojiConverter',
'PartialEmojiConverter',
'CategoryChannelConverter',
'IDConverter',
'StoreChannelConverter',
'ThreadConverter',
'GuildChannelConverter',
'GuildStickerConverter',
'clean_content',
'Greedy',
'run_converters',
"Converter",
"ObjectConverter",
"MemberConverter",
"UserConverter",
"MessageConverter",
"PartialMessageConverter",
"TextChannelConverter",
"InviteConverter",
"GuildConverter",
"RoleConverter",
"GameConverter",
"ColourConverter",
"ColorConverter",
"VoiceChannelConverter",
"StageChannelConverter",
"EmojiConverter",
"PartialEmojiConverter",
"CategoryChannelConverter",
"IDConverter",
"StoreChannelConverter",
"ThreadConverter",
"GuildChannelConverter",
"GuildStickerConverter",
"clean_content",
"Greedy",
"run_converters",
)
@ -91,10 +91,10 @@ def _get_from_guilds(bot, getter, argument):
_utils_get = discord.utils.get
T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
CT = TypeVar('CT', bound=discord.abc.GuildChannel)
TT = TypeVar('TT', bound=discord.Thread)
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
CT = TypeVar("CT", bound=discord.abc.GuildChannel)
TT = TypeVar("TT", bound=discord.Thread)
@runtime_checkable
@ -132,10 +132,10 @@ class Converter(Protocol[T_co]):
:exc:`.BadArgument`
The converter failed to convert the argument.
"""
raise NotImplementedError('Derived classes need to implement this.')
raise NotImplementedError("Derived classes need to implement this.")
_ID_REGEX = re.compile(r'([0-9]{15,20})$')
_ID_REGEX = re.compile(r"([0-9]{15,20})$")
class IDConverter(Converter[T_co]):
@ -158,7 +158,7 @@ class ObjectConverter(IDConverter[discord.Object]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Object:
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument)
if match is None:
raise ObjectNotFound(argument)
@ -192,8 +192,8 @@ class MemberConverter(IDConverter[discord.Member]):
async def query_member_named(self, guild, argument):
cache = guild._state.member_cache_flags.joined
if len(argument) > 5 and argument[-5] == '#':
username, _, discriminator = argument.rpartition('#')
if len(argument) > 5 and argument[-5] == "#":
username, _, discriminator = argument.rpartition("#")
members = await guild.query_members(username, limit=100, cache=cache)
return discord.utils.get(members, name=username, discriminator=discriminator)
else:
@ -223,7 +223,7 @@ class MemberConverter(IDConverter[discord.Member]):
async def convert(self, ctx: Context, argument: str) -> discord.Member:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument)
guild = ctx.guild
result = None
user_id = None
@ -232,13 +232,13 @@ class MemberConverter(IDConverter[discord.Member]):
if guild:
result = guild.get_member_named(argument)
else:
result = _get_from_guilds(bot, 'get_member_named', argument)
result = _get_from_guilds(bot, "get_member_named", argument)
else:
user_id = int(match.group(1))
if guild:
result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id)
else:
result = _get_from_guilds(bot, 'get_member', user_id)
result = _get_from_guilds(bot, "get_member", user_id)
if result is None:
if guild is None:
@ -276,7 +276,7 @@ class UserConverter(IDConverter[discord.User]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.User:
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument)
result = None
state = ctx._state
@ -294,12 +294,12 @@ class UserConverter(IDConverter[discord.User]):
arg = argument
# Remove the '@' character if this is the first character from the argument
if arg[0] == '@':
if arg[0] == "@":
# Remove first character
arg = arg[1:]
# check for discriminator if it exists,
if len(arg) > 5 and arg[-5] == '#':
if len(arg) > 5 and arg[-5] == "#":
discrim = arg[-4:]
name = arg[:-5]
predicate = lambda u: u.name == name and u.discriminator == discrim
@ -330,22 +330,22 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod
def _get_id_matches(ctx, argument):
id_regex = re.compile(r'(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$')
id_regex = re.compile(r"(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$")
link_regex = re.compile(
r'https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/'
r'(?P<guild_id>[0-9]{15,20}|@me)'
r'/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$'
r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/"
r"(?P<guild_id>[0-9]{15,20}|@me)"
r"/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$"
)
match = id_regex.match(argument) or link_regex.match(argument)
if not match:
raise MessageNotFound(argument)
data = match.groupdict()
channel_id = discord.utils._get_as_snowflake(data, 'channel_id')
message_id = int(data['message_id'])
guild_id = data.get('guild_id')
channel_id = discord.utils._get_as_snowflake(data, "channel_id")
message_id = int(data["message_id"])
guild_id = data.get("guild_id")
if guild_id is None:
guild_id = ctx.guild and ctx.guild.id
elif guild_id == '@me':
elif guild_id == "@me":
guild_id = None
else:
guild_id = int(guild_id)
@ -417,13 +417,13 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel:
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel)
return self._resolve_channel(ctx, argument, "channels", discord.abc.GuildChannel)
@staticmethod
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument)
result = None
guild = ctx.guild
@ -443,7 +443,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
if guild:
result = guild.get_channel(channel_id)
else:
result = _get_from_guilds(bot, 'get_channel', channel_id)
result = _get_from_guilds(bot, "get_channel", channel_id)
if not isinstance(result, type):
raise ChannelNotFound(argument)
@ -454,7 +454,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument)
result = None
guild = ctx.guild
@ -491,7 +491,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel)
return GuildChannelConverter._resolve_channel(ctx, argument, "text_channels", discord.TextChannel)
class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
@ -511,7 +511,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel)
return GuildChannelConverter._resolve_channel(ctx, argument, "voice_channels", discord.VoiceChannel)
class StageChannelConverter(IDConverter[discord.StageChannel]):
@ -530,7 +530,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel)
return GuildChannelConverter._resolve_channel(ctx, argument, "stage_channels", discord.StageChannel)
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
@ -550,7 +550,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel)
return GuildChannelConverter._resolve_channel(ctx, argument, "categories", discord.CategoryChannel)
class StoreChannelConverter(IDConverter[discord.StoreChannel]):
@ -569,7 +569,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel)
return GuildChannelConverter._resolve_channel(ctx, argument, "channels", discord.StoreChannel)
class ThreadConverter(IDConverter[discord.Thread]):
@ -587,7 +587,7 @@ class ThreadConverter(IDConverter[discord.Thread]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Thread:
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread)
return GuildChannelConverter._resolve_thread(ctx, argument, "threads", discord.Thread)
class ColourConverter(Converter[discord.Colour]):
@ -616,10 +616,10 @@ class ColourConverter(Converter[discord.Colour]):
Added support for ``rgb`` function and 3-digit hex shortcuts
"""
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
RGB_REGEX = re.compile(r"rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)")
def parse_hex_number(self, argument):
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument
try:
value = int(arg, base=16)
if not (0 <= value <= 0xFFFFFF):
@ -630,7 +630,7 @@ class ColourConverter(Converter[discord.Colour]):
return discord.Color(value=value)
def parse_rgb_number(self, argument, number):
if number[-1] == '%':
if number[-1] == "%":
value = int(number[:-1])
if not (0 <= value <= 100):
raise BadColourArgument(argument)
@ -646,29 +646,29 @@ class ColourConverter(Converter[discord.Colour]):
if match is None:
raise BadColourArgument(argument)
red = self.parse_rgb_number(argument, match.group('r'))
green = self.parse_rgb_number(argument, match.group('g'))
blue = self.parse_rgb_number(argument, match.group('b'))
red = self.parse_rgb_number(argument, match.group("r"))
green = self.parse_rgb_number(argument, match.group("g"))
blue = self.parse_rgb_number(argument, match.group("b"))
return discord.Color.from_rgb(red, green, blue)
async def convert(self, ctx: Context, argument: str) -> discord.Colour:
if argument[0] == '#':
if argument[0] == "#":
return self.parse_hex_number(argument[1:])
if argument[0:2] == '0x':
if argument[0:2] == "0x":
rest = argument[2:]
# Legacy backwards compatible syntax
if rest.startswith('#'):
if rest.startswith("#"):
return self.parse_hex_number(rest[1:])
return self.parse_hex_number(rest)
arg = argument.lower()
if arg[0:3] == 'rgb':
if arg[0:3] == "rgb":
return self.parse_rgb(arg)
arg = arg.replace(' ', '_')
arg = arg.replace(" ", "_")
method = getattr(discord.Colour, arg, None)
if arg.startswith('from_') or method is None or not inspect.ismethod(method):
if arg.startswith("from_") or method is None or not inspect.ismethod(method):
raise BadColourArgument(arg)
return method()
@ -697,7 +697,7 @@ class RoleConverter(IDConverter[discord.Role]):
if not guild:
raise NoPrivateMessage()
match = self._get_id_match(argument) or re.match(r'<@&([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(r"<@&([0-9]{15,20})>$", argument)
if match:
result = guild.get_role(int(match.group(1)))
else:
@ -776,7 +776,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Emoji:
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(r"<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$", argument)
result = None
bot = ctx.bot
guild = ctx.guild
@ -810,7 +810,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji:
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument)
match = re.match(r"<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$", argument)
if match:
emoji_animated = bool(match.group(1))
@ -903,37 +903,37 @@ class clean_content(Converter[str]):
def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id)
return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user'
return f"@{m.display_name if self.use_nicknames else m.name}" if m else "@deleted-user"
def resolve_role(id: int) -> str:
r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id)
return f'@{r.name}' if r else '@deleted-role'
return f"@{r.name}" if r else "@deleted-role"
else:
def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id)
return f'@{m.name}' if m else '@deleted-user'
return f"@{m.name}" if m else "@deleted-user"
def resolve_role(id: int) -> str:
return '@deleted-role'
return "@deleted-role"
if self.fix_channel_mentions and ctx.guild:
def resolve_channel(id: int) -> str:
c = ctx.guild.get_channel(id)
return f'#{c.name}' if c else '#deleted-channel'
return f"#{c.name}" if c else "#deleted-channel"
else:
def resolve_channel(id: int) -> str:
return f'<#{id}>'
return f"<#{id}>"
transforms = {
'@': resolve_member,
'@!': resolve_member,
'#': resolve_channel,
'@&': resolve_role,
"@": resolve_member,
"@!": resolve_member,
"#": resolve_channel,
"@&": resolve_role,
}
def repl(match: re.Match) -> str:
@ -942,7 +942,7 @@ class clean_content(Converter[str]):
transformed = transforms[type](id)
return transformed
result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument)
result = re.sub(r"<(@[!&]?|#)([0-9]{15,20})>", repl, argument)
if self.escape_markdown:
result = discord.utils.escape_markdown(result)
elif self.remove_markdown:
@ -974,42 +974,42 @@ class Greedy(List[T]):
For more information, check :ref:`ext_commands_special_converters`.
"""
__slots__ = ('converter',)
__slots__ = ("converter",)
def __init__(self, *, converter: T):
self.converter = converter
def __repr__(self):
converter = getattr(self.converter, '__name__', repr(self.converter))
return f'Greedy[{converter}]'
converter = getattr(self.converter, "__name__", repr(self.converter))
return f"Greedy[{converter}]"
def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]:
if not isinstance(params, tuple):
params = (params,)
if len(params) != 1:
raise TypeError('Greedy[...] only takes a single argument')
raise TypeError("Greedy[...] only takes a single argument")
converter = params[0]
origin = getattr(converter, '__origin__', None)
args = getattr(converter, '__args__', ())
origin = getattr(converter, "__origin__", None)
args = getattr(converter, "__args__", ())
if not (callable(converter) or isinstance(converter, Converter) or origin is not None):
raise TypeError('Greedy[...] expects a type or a Converter instance.')
raise TypeError("Greedy[...] expects a type or a Converter instance.")
if converter in (str, type(None)) or origin is Greedy:
raise TypeError(f'Greedy[{converter.__name__}] is invalid.')
raise TypeError(f"Greedy[{converter.__name__}] is invalid.")
if origin is Union and type(None) in args:
raise TypeError(f'Greedy[{converter!r}] is invalid.')
raise TypeError(f"Greedy[{converter!r}] is invalid.")
return cls(converter=converter)
def _convert_to_bool(argument: str) -> bool:
lowered = argument.lower()
if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
if lowered in ("yes", "y", "true", "t", "1", "enable", "on"):
return True
elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'):
elif lowered in ("no", "n", "false", "f", "0", "disable", "off"):
return False
else:
raise BadBoolArgument(lowered)
@ -1065,7 +1065,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
except AttributeError:
pass
else:
if module is not None and (module.startswith('discord.') and not module.endswith('converter')):
if module is not None and (module.startswith("discord.") and not module.endswith("converter")):
converter = CONVERTER_MAPPING.get(converter, converter)
try:
@ -1124,7 +1124,7 @@ async def run_converters(ctx: Context, converter, argument: str, param: inspect.
Any
The resulting conversion.
"""
origin = getattr(converter, '__origin__', None)
origin = getattr(converter, "__origin__", None)
if origin is Union:
errors = []

View File

@ -38,24 +38,25 @@ if TYPE_CHECKING:
from ...message import Message
__all__ = (
'BucketType',
'Cooldown',
'CooldownMapping',
'DynamicCooldownMapping',
'MaxConcurrency',
"BucketType",
"Cooldown",
"CooldownMapping",
"DynamicCooldownMapping",
"MaxConcurrency",
)
C = TypeVar('C', bound='CooldownMapping')
MC = TypeVar('MC', bound='MaxConcurrency')
C = TypeVar("C", bound="CooldownMapping")
MC = TypeVar("MC", bound="MaxConcurrency")
class BucketType(Enum):
default = 0
user = 1
guild = 2
channel = 3
member = 4
default = 0
user = 1
guild = 2
channel = 3
member = 4
category = 5
role = 6
role = 6
def get_key(self, msg: Message) -> Any:
if self is BucketType.user:
@ -90,7 +91,7 @@ class Cooldown:
The length of the cooldown period in seconds.
"""
__slots__ = ('rate', 'per', '_window', '_tokens', '_last')
__slots__ = ("rate", "per", "_window", "_tokens", "_last")
def __init__(self, rate: float, per: float) -> None:
self.rate: int = int(rate)
@ -190,7 +191,8 @@ class Cooldown:
return Cooldown(self.rate, self.per)
def __repr__(self) -> str:
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
return f"<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>"
class CooldownMapping:
def __init__(
@ -199,7 +201,7 @@ class CooldownMapping:
type: Callable[[Message], Any],
) -> None:
if not callable(type):
raise TypeError('Cooldown type must be a BucketType or callable')
raise TypeError("Cooldown type must be a BucketType or callable")
self._cache: Dict[Any, Cooldown] = {}
self._cooldown: Optional[Cooldown] = original
@ -256,13 +258,9 @@ class CooldownMapping:
bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current)
class DynamicCooldownMapping(CooldownMapping):
def __init__(
self,
factory: Callable[[Message], Cooldown],
type: Callable[[Message], Any]
) -> None:
class DynamicCooldownMapping(CooldownMapping):
def __init__(self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]) -> None:
super().__init__(None, type)
self._factory: Callable[[Message], Cooldown] = factory
@ -278,6 +276,7 @@ class DynamicCooldownMapping(CooldownMapping):
def create_bucket(self, message: Message) -> Cooldown:
return self._factory(message)
class _Semaphore:
"""This class is a version of a semaphore.
@ -291,7 +290,7 @@ class _Semaphore:
overkill for what is basically a counter.
"""
__slots__ = ('value', 'loop', '_waiters')
__slots__ = ("value", "loop", "_waiters")
def __init__(self, number: int) -> None:
self.value: int = number
@ -299,7 +298,7 @@ class _Semaphore:
self._waiters: Deque[asyncio.Future] = deque()
def __repr__(self) -> str:
return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>'
return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>"
def locked(self) -> bool:
return self.value == 0
@ -337,8 +336,9 @@ class _Semaphore:
self.value += 1
self.wake_up()
class MaxConcurrency:
__slots__ = ('number', 'per', 'wait', '_mapping')
__slots__ = ("number", "per", "wait", "_mapping")
def __init__(self, number: int, *, per: BucketType, wait: bool) -> None:
self._mapping: Dict[Any, _Semaphore] = {}
@ -347,16 +347,16 @@ class MaxConcurrency:
self.wait: bool = wait
if number <= 0:
raise ValueError('max_concurrency \'number\' cannot be less than 1')
raise ValueError("max_concurrency 'number' cannot be less than 1")
if not isinstance(per, BucketType):
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}')
raise TypeError(f"max_concurrency 'per' must be of type BucketType not {type(per)!r}")
def copy(self: MC) -> MC:
return self.__class__(self.number, per=self.per, wait=self.wait)
def __repr__(self) -> str:
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>'
return f"<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>"
def get_key(self, message: Message) -> Any:
return self.per.get_key(message)

View File

@ -70,52 +70,53 @@ if TYPE_CHECKING:
__all__ = (
'Command',
'Group',
'GroupMixin',
'command',
'group',
'has_role',
'has_permissions',
'has_any_role',
'check',
'check_any',
'before_invoke',
'after_invoke',
'bot_has_role',
'bot_has_permissions',
'bot_has_any_role',
'cooldown',
'dynamic_cooldown',
'max_concurrency',
'dm_only',
'guild_only',
'is_owner',
'is_nsfw',
'has_guild_permissions',
'bot_has_guild_permissions'
"Command",
"Group",
"GroupMixin",
"command",
"group",
"has_role",
"has_permissions",
"has_any_role",
"check",
"check_any",
"before_invoke",
"after_invoke",
"bot_has_role",
"bot_has_permissions",
"bot_has_any_role",
"cooldown",
"dynamic_cooldown",
"max_concurrency",
"dm_only",
"guild_only",
"is_owner",
"is_nsfw",
"has_guild_permissions",
"bot_has_guild_permissions",
)
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CogT = TypeVar('CogT', bound='Cog')
CommandT = TypeVar('CommandT', bound='Command')
ContextT = TypeVar('ContextT', bound='Context')
T = TypeVar("T")
CogT = TypeVar("CogT", bound="Cog")
CommandT = TypeVar("CommandT", bound="Command")
ContextT = TypeVar("ContextT", bound="Context")
# CHT = TypeVar('CHT', bound='Check')
GroupT = TypeVar('GroupT', bound='Group')
HookT = TypeVar('HookT', bound='Hook')
ErrorT = TypeVar('ErrorT', bound='Error')
GroupT = TypeVar("GroupT", bound="Group")
HookT = TypeVar("HookT", bound="Hook")
ErrorT = TypeVar("ErrorT", bound="Error")
if TYPE_CHECKING:
P = ParamSpec('P')
P = ParamSpec("P")
else:
P = TypeVar('P')
P = TypeVar("P")
def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
partial = functools.partial
while True:
if hasattr(function, '__wrapped__'):
if hasattr(function, "__wrapped__"):
function = function.__wrapped__
elif isinstance(function, partial):
function = function.func
@ -139,7 +140,7 @@ def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, A
annotation = eval_annotation(annotation, globalns, globalns, cache)
if annotation is Greedy:
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
raise TypeError("Unparameterized Greedy[...] is disallowed in signature.")
params[name] = parameter.replace(annotation=annotation)
@ -158,8 +159,10 @@ def wrap_callback(coro):
except Exception as exc:
raise CommandInvokeError(exc) from exc
return ret
return wrapped
def hooked_wrapped_callback(command, ctx, coro):
@functools.wraps(coro)
async def wrapped(*args, **kwargs):
@ -180,6 +183,7 @@ def hooked_wrapped_callback(command, ctx, coro):
await command.call_after_hooks(ctx)
return ret
return wrapped
@ -202,6 +206,7 @@ class _CaseInsensitiveDict(dict):
def __setitem__(self, k, v):
super().__setitem__(k.casefold(), v)
class Command(_BaseCommand, Generic[CogT, P, T]):
r"""A class that implements the protocol for a bot text command.
@ -269,8 +274,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
which calls converters. If ``False`` then cooldown processing is done
first and then the converters are called second. Defaults to ``False``.
extras: :class:`dict`
A dict of user provided extras to attach to the Command.
A dict of user provided extras to attach to the Command.
.. note::
This object may be copied by the library.
@ -295,56 +300,60 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.__original_kwargs__ = kwargs.copy()
return self
def __init__(self, func: Union[
def __init__(
self,
func: Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
], **kwargs: Any):
],
**kwargs: Any,
):
if not asyncio.iscoroutinefunction(func):
raise TypeError('Callback must be a coroutine.')
raise TypeError("Callback must be a coroutine.")
name = kwargs.get('name') or func.__name__
name = kwargs.get("name") or func.__name__
if not isinstance(name, str):
raise TypeError('Name of a command must be a string.')
raise TypeError("Name of a command must be a string.")
self.name: str = name
self.callback = func
self.enabled: bool = kwargs.get('enabled', True)
self.enabled: bool = kwargs.get("enabled", True)
help_doc = kwargs.get('help')
help_doc = kwargs.get("help")
if help_doc is not None:
help_doc = inspect.cleandoc(help_doc)
else:
help_doc = inspect.getdoc(func)
if isinstance(help_doc, bytes):
help_doc = help_doc.decode('utf-8')
help_doc = help_doc.decode("utf-8")
self.help: Optional[str] = help_doc
self.brief: Optional[str] = kwargs.get('brief')
self.usage: Optional[str] = kwargs.get('usage')
self.rest_is_raw: bool = kwargs.get('rest_is_raw', False)
self.aliases: Union[List[str], Tuple[str]] = kwargs.get('aliases', [])
self.extras: Dict[str, Any] = kwargs.get('extras', {})
self.brief: Optional[str] = kwargs.get("brief")
self.usage: Optional[str] = kwargs.get("usage")
self.rest_is_raw: bool = kwargs.get("rest_is_raw", False)
self.aliases: Union[List[str], Tuple[str]] = kwargs.get("aliases", [])
self.extras: Dict[str, Any] = kwargs.get("extras", {})
if not isinstance(self.aliases, (list, tuple)):
raise TypeError("Aliases of a command must be a list or a tuple of strings.")
self.description: str = inspect.cleandoc(kwargs.get('description', ''))
self.hidden: bool = kwargs.get('hidden', False)
self.description: str = inspect.cleandoc(kwargs.get("description", ""))
self.hidden: bool = kwargs.get("hidden", False)
try:
checks = func.__commands_checks__
checks.reverse()
except AttributeError:
checks = kwargs.get('checks', [])
checks = kwargs.get("checks", [])
self.checks: List[Check] = checks
try:
cooldown = func.__commands_cooldown__
except AttributeError:
cooldown = kwargs.get('cooldown')
cooldown = kwargs.get("cooldown")
if cooldown is None:
buckets = CooldownMapping(cooldown, BucketType.default)
elif isinstance(cooldown, CooldownMapping):
@ -356,17 +365,17 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
try:
max_concurrency = func.__commands_max_concurrency__
except AttributeError:
max_concurrency = kwargs.get('max_concurrency')
max_concurrency = kwargs.get("max_concurrency")
self._max_concurrency: Optional[MaxConcurrency] = max_concurrency
self.require_var_positional: bool = kwargs.get('require_var_positional', False)
self.ignore_extra: bool = kwargs.get('ignore_extra', True)
self.cooldown_after_parsing: bool = kwargs.get('cooldown_after_parsing', False)
self.require_var_positional: bool = kwargs.get("require_var_positional", False)
self.ignore_extra: bool = kwargs.get("ignore_extra", True)
self.cooldown_after_parsing: bool = kwargs.get("cooldown_after_parsing", False)
self.cog: Optional[CogT] = None
# bandaid for the fact that sometimes parent can be the bot instance
parent = kwargs.get('parent')
parent = kwargs.get("parent")
self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore
self._before_invoke: Optional[Hook] = None
@ -386,17 +395,19 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.after_invoke(after_invoke)
@property
def callback(self) -> Union[
Callable[Concatenate[CogT, Context, P], Coro[T]],
Callable[Concatenate[Context, P], Coro[T]],
]:
def callback(
self,
) -> Union[Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]],]:
return self._callback
@callback.setter
def callback(self, function: Union[
def callback(
self,
function: Union[
Callable[Concatenate[CogT, Context, P], Coro[T]],
Callable[Concatenate[Context, P], Coro[T]],
]) -> None:
],
) -> None:
self._callback = function
unwrap = unwrap_function(function)
self.module = unwrap.__module__
@ -527,7 +538,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
wrapped = wrap_callback(local)
await wrapped(ctx, error)
finally:
ctx.bot.dispatch('command_error', ctx, error)
ctx.bot.dispatch("command_error", ctx, error)
async def transform(self, ctx: Context, param: inspect.Parameter) -> Any:
required = param.default is param.empty
@ -551,11 +562,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if view.eof:
if param.kind == param.VAR_POSITIONAL:
raise RuntimeError() # break the loop
raise RuntimeError() # break the loop
if required:
if self._is_typing_optional(param.annotation):
return None
if hasattr(converter, '__commands_is_flag__') and converter._can_be_constructible():
if hasattr(converter, "__commands_is_flag__") and converter._can_be_constructible():
return await converter._construct_default(ctx)
raise MissingRequiredArgument(param)
return param.default
@ -577,7 +588,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
# type-checker fails to narrow argument
return await run_converters(ctx, converter, argument, param) # type: ignore
async def _transform_greedy_pos(self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any) -> Any:
async def _transform_greedy_pos(
self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any
) -> Any:
view = ctx.view
result = []
while not view.eof:
@ -606,7 +619,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
value = await run_converters(ctx, converter, argument, param) # type: ignore
except (CommandError, ArgumentParsingError):
view.index = previous
raise RuntimeError() from None # break loop
raise RuntimeError() from None # break loop
else:
return value
@ -643,11 +656,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
entries = []
command = self
# command.parent is type-hinted as GroupMixin some attributes are resolved via MRO
while command.parent is not None: # type: ignore
command = command.parent # type: ignore
entries.append(command.name) # type: ignore
while command.parent is not None: # type: ignore
command = command.parent # type: ignore
entries.append(command.name) # type: ignore
return ' '.join(reversed(entries))
return " ".join(reversed(entries))
@property
def parents(self) -> List[Group]:
@ -661,8 +674,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
"""
entries = []
command = self
while command.parent is not None: # type: ignore
command = command.parent # type: ignore
while command.parent is not None: # type: ignore
command = command.parent # type: ignore
entries.append(command)
return entries
@ -690,7 +703,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
parent = self.full_parent_name
if parent:
return parent + ' ' + self.name
return parent + " " + self.name
else:
return self.name
@ -745,7 +758,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
break
if not self.ignore_extra and not view.eof:
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
raise TooManyArguments("Too many arguments passed to " + self.qualified_name)
async def call_before_hooks(self, ctx: Context) -> None:
# now that we're done preparing we can call the pre-command hooks
@ -753,7 +766,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
cog = self.cog
if self._before_invoke is not None:
# should be cog if @commands.before_invoke is used
instance = getattr(self._before_invoke, '__self__', cog)
instance = getattr(self._before_invoke, "__self__", cog)
# __self__ only exists for methods, not functions
# however, if @command.before_invoke is used, it will be a function
if instance:
@ -775,7 +788,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
async def call_after_hooks(self, ctx: Context) -> None:
cog = self.cog
if self._after_invoke is not None:
instance = getattr(self._after_invoke, '__self__', cog)
instance = getattr(self._after_invoke, "__self__", cog)
if instance:
await self._after_invoke(instance, ctx) # type: ignore
else:
@ -805,7 +818,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
ctx.command = self
if not await self.can_run(ctx):
raise CheckFailure(f'The check functions for command {self.qualified_name} failed.')
raise CheckFailure(f"The check functions for command {self.qualified_name} failed.")
if self._max_concurrency is not None:
# For this application, context can be duck-typed as a Message
@ -929,7 +942,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.')
raise TypeError("The error handler must be a coroutine.")
self.on_error: Error = coro
return coro
@ -939,7 +952,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
.. versionadded:: 1.7
"""
return hasattr(self, 'on_error')
return hasattr(self, "on_error")
def before_invoke(self, coro: HookT) -> HookT:
"""A decorator that registers a coroutine as a pre-invoke hook.
@ -963,7 +976,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The pre-invoke hook must be a coroutine.')
raise TypeError("The pre-invoke hook must be a coroutine.")
self._before_invoke = coro
return coro
@ -990,7 +1003,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The post-invoke hook must be a coroutine.')
raise TypeError("The post-invoke hook must be a coroutine.")
self._after_invoke = coro
return coro
@ -1011,11 +1024,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if self.brief is not None:
return self.brief
if self.help is not None:
return self.help.split('\n', 1)[0]
return ''
return self.help.split("\n", 1)[0]
return ""
def _is_typing_optional(self, annotation: Union[T, Optional[T]]) -> TypeGuard[Optional[T]]:
return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__ # type: ignore
return getattr(annotation, "__origin__", None) is Union and type(None) in annotation.__args__ # type: ignore
@property
def signature(self) -> str:
@ -1025,7 +1038,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
params = self.clean_params
if not params:
return ''
return ""
result = []
for name, param in params.items():
@ -1035,41 +1048,40 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
# parameter signature is a literal list of it's values
annotation = param.annotation.converter if greedy else param.annotation
origin = getattr(annotation, '__origin__', None)
origin = getattr(annotation, "__origin__", None)
if not greedy and origin is Union:
none_cls = type(None)
union_args = annotation.__args__
optional = union_args[-1] is none_cls
if len(union_args) == 2 and optional:
annotation = union_args[0]
origin = getattr(annotation, '__origin__', None)
origin = getattr(annotation, "__origin__", None)
if origin is Literal:
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
name = "|".join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
if param.default is not param.empty:
# We don't want None or '' to trigger the [name=value] case and instead it should
# do [name] since [name=None] or [name=] are not exactly useful for the user.
should_print = param.default if isinstance(param.default, str) else param.default is not None
if should_print:
result.append(f'[{name}={param.default}]' if not greedy else
f'[{name}={param.default}]...')
result.append(f"[{name}={param.default}]" if not greedy else f"[{name}={param.default}]...")
continue
else:
result.append(f'[{name}]')
result.append(f"[{name}]")
elif param.kind == param.VAR_POSITIONAL:
if self.require_var_positional:
result.append(f'<{name}...>')
result.append(f"<{name}...>")
else:
result.append(f'[{name}...]')
result.append(f"[{name}...]")
elif greedy:
result.append(f'[{name}]...')
result.append(f"[{name}]...")
elif optional:
result.append(f'[{name}]')
result.append(f"[{name}]")
else:
result.append(f'<{name}>')
result.append(f"<{name}>")
return ' '.join(result)
return " ".join(result)
async def can_run(self, ctx: Context) -> bool:
"""|coro|
@ -1099,14 +1111,14 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
"""
if not self.enabled:
raise DisabledCommand(f'{self.name} command is disabled')
raise DisabledCommand(f"{self.name} command is disabled")
original = ctx.command
ctx.command = self
try:
if not await ctx.bot.can_run(ctx):
raise CheckFailure(f'The global check functions for command {self.qualified_name} failed.')
raise CheckFailure(f"The global check functions for command {self.qualified_name} failed.")
cog = self.cog
if cog is not None:
@ -1125,6 +1137,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
finally:
ctx.command = original
class GroupMixin(Generic[CogT]):
"""A mixin that implements common functionality for classes that behave
similar to :class:`.Group` and are allowed to register commands.
@ -1137,8 +1150,9 @@ class GroupMixin(Generic[CogT]):
case_insensitive: :class:`bool`
Whether the commands should be case insensitive. Defaults to ``True``.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
case_insensitive = kwargs.get('case_insensitive', True)
case_insensitive = kwargs.get("case_insensitive", True)
self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {}
self.case_insensitive: bool = case_insensitive
super().__init__(*args, **kwargs)
@ -1177,7 +1191,7 @@ class GroupMixin(Generic[CogT]):
"""
if not isinstance(command, Command):
raise TypeError('The command passed must be a subclass of Command')
raise TypeError("The command passed must be a subclass of Command")
if isinstance(self, Command):
command.parent = self
@ -1267,7 +1281,7 @@ class GroupMixin(Generic[CogT]):
"""
# fast path, no space in name.
if ' ' not in name:
if " " not in name:
return self.all_commands.get(name)
names = name.split()
@ -1298,7 +1312,9 @@ class GroupMixin(Generic[CogT]):
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
], Command[CogT, P, T]]:
],
Command[CogT, P, T],
]:
...
@overload
@ -1326,8 +1342,9 @@ class GroupMixin(Generic[CogT]):
Callable[..., :class:`Command`]
A decorator that converts the provided method into a Command, adds it to the bot, then returns it.
"""
def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> CommandT:
kwargs.setdefault('parent', self)
kwargs.setdefault("parent", self)
result = command(name=name, cls=cls, *args, **kwargs)(func)
self.add_command(result)
return result
@ -1341,12 +1358,10 @@ class GroupMixin(Generic[CogT]):
cls: Type[Group[CogT, P, T]] = ...,
*args: Any,
**kwargs: Any,
) -> Callable[[
Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]]
]
], Group[CogT, P, T]]:
) -> Callable[
[Union[Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]]]],
Group[CogT, P, T],
]:
...
@overload
@ -1374,14 +1389,16 @@ class GroupMixin(Generic[CogT]):
Callable[..., :class:`Group`]
A decorator that converts the provided method into a Group, adds it to the bot, then returns it.
"""
def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> GroupT:
kwargs.setdefault('parent', self)
kwargs.setdefault("parent", self)
result = group(name=name, cls=cls, *args, **kwargs)(func)
self.add_command(result)
return result
return decorator
class Group(GroupMixin[CogT], Command[CogT, P, T]):
"""A class that implements a grouping protocol for commands to be
executed as subcommands.
@ -1404,8 +1421,9 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
Indicates if the group's commands should be case insensitive.
Defaults to ``False``.
"""
def __init__(self, *args: Any, **attrs: Any) -> None:
self.invoke_without_command: bool = attrs.pop('invoke_without_command', False)
self.invoke_without_command: bool = attrs.pop("invoke_without_command", False)
super().__init__(*args, **attrs)
def copy(self: GroupT) -> GroupT:
@ -1492,8 +1510,10 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
view.previous = previous
await super().reinvoke(ctx, call_hooks=call_hooks)
# Decorators
@overload
def command(
name: str = ...,
@ -1505,10 +1525,12 @@ def command(
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
]
, Command[CogT, P, T]]:
],
Command[CogT, P, T],
]:
...
@overload
def command(
name: str = ...,
@ -1520,22 +1542,23 @@ def command(
Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
Callable[Concatenate[ContextT, P], Coro[Any]],
]
]
, CommandT]:
],
CommandT,
]:
...
def command(
name: str = MISSING,
cls: Type[CommandT] = MISSING,
**attrs: Any
name: str = MISSING, cls: Type[CommandT] = MISSING, **attrs: Any
) -> Callable[
[
Union[
Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
]
]
, Union[Command[CogT, P, T], CommandT]]:
],
Union[Command[CogT, P, T], CommandT],
]:
"""A decorator that transforms a function into a :class:`.Command`
or if called with :func:`.group`, :class:`.Group`.
@ -1568,16 +1591,19 @@ def command(
if cls is MISSING:
cls = Command # type: ignore
def decorator(func: Union[
def decorator(
func: Union[
Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
]) -> CommandT:
]
) -> CommandT:
if isinstance(func, Command):
raise TypeError('Callback is already a command.')
raise TypeError("Callback is already a command.")
return cls(func, name=name, **attrs)
return decorator
@overload
def group(
name: str = ...,
@ -1589,10 +1615,12 @@ def group(
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
]
, Group[CogT, P, T]]:
],
Group[CogT, P, T],
]:
...
@overload
def group(
name: str = ...,
@ -1604,10 +1632,12 @@ def group(
Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
Callable[Concatenate[ContextT, P], Coro[Any]],
]
]
, GroupT]:
],
GroupT,
]:
...
def group(
name: str = MISSING,
cls: Type[GroupT] = MISSING,
@ -1618,8 +1648,9 @@ def group(
Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
]
]
, Union[Group[CogT, P, T], GroupT]]:
],
Union[Group[CogT, P, T], GroupT],
]:
"""A decorator that transforms a function into a :class:`.Group`.
This is similar to the :func:`.command` decorator but the ``cls``
@ -1632,6 +1663,7 @@ def group(
cls = Group # type: ignore
return command(name=name, cls=cls, **attrs) # type: ignore
def check(predicate: Check) -> Callable[[T], T]:
r"""A decorator that adds a check to the :class:`.Command` or its
subclasses. These checks could be accessed via :attr:`.Command.checks`.
@ -1707,7 +1739,7 @@ def check(predicate: Check) -> Callable[[T], T]:
if isinstance(func, Command):
func.checks.append(predicate)
else:
if not hasattr(func, '__commands_checks__'):
if not hasattr(func, "__commands_checks__"):
func.__commands_checks__ = []
func.__commands_checks__.append(predicate)
@ -1717,13 +1749,16 @@ def check(predicate: Check) -> Callable[[T], T]:
if inspect.iscoroutinefunction(predicate):
decorator.predicate = predicate
else:
@functools.wraps(predicate)
async def wrapper(ctx):
return predicate(ctx) # type: ignore
decorator.predicate = wrapper
return decorator # type: ignore
def check_any(*checks: Check) -> Callable[[T], T]:
r"""A :func:`check` that is added that checks if any of the checks passed
will pass, i.e. using logical OR.
@ -1773,7 +1808,7 @@ def check_any(*checks: Check) -> Callable[[T], T]:
try:
pred = wrapped.predicate
except AttributeError:
raise TypeError(f'{wrapped!r} must be wrapped by commands.check decorator') from None
raise TypeError(f"{wrapped!r} must be wrapped by commands.check decorator") from None
else:
unwrapped.append(pred)
@ -1792,6 +1827,7 @@ def check_any(*checks: Check) -> Callable[[T], T]:
return check(predicate)
def has_role(item: Union[int, str]) -> Callable[[T], T]:
"""A :func:`.check` that is added that checks if the member invoking the
command has the role specified via the name or ID specified.
@ -1834,6 +1870,7 @@ def has_role(item: Union[int, str]) -> Callable[[T], T]:
return check(predicate)
def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
r"""A :func:`.check` that is added that checks if the member invoking the
command has **any** of the roles specified. This means that if they have
@ -1865,18 +1902,22 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
async def cool(ctx):
await ctx.send('You are cool indeed')
"""
def predicate(ctx):
if ctx.guild is None:
raise NoPrivateMessage()
# ctx.guild is None doesn't narrow ctx.author to Member
getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
if any(
getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items
):
return True
raise MissingAnyRole(list(items))
return check(predicate)
def bot_has_role(item: int) -> Callable[[T], T]:
"""Similar to :func:`.has_role` except checks if the bot itself has the
role.
@ -1903,8 +1944,10 @@ def bot_has_role(item: int) -> Callable[[T], T]:
if role is None:
raise BotMissingRole(item)
return True
return check(predicate)
def bot_has_any_role(*items: int) -> Callable[[T], T]:
"""Similar to :func:`.has_any_role` except checks if the bot itself has
any of the roles listed.
@ -1918,17 +1961,22 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
Raise :exc:`.BotMissingAnyRole` or :exc:`.NoPrivateMessage`
instead of generic checkfailure
"""
def predicate(ctx):
if ctx.guild is None:
raise NoPrivateMessage()
me = ctx.me
getter = functools.partial(discord.utils.get, me.roles)
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
if any(
getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items
):
return True
raise BotMissingAnyRole(list(items))
return check(predicate)
def has_permissions(**perms: bool) -> Callable[[T], T]:
"""A :func:`.check` that is added that checks if the member has all of
the permissions necessary.
@ -1976,6 +2024,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
"""Similar to :func:`.has_permissions` except checks if the bot itself has
the permissions listed.
@ -2002,6 +2051,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
"""Similar to :func:`.has_permissions`, but operates on guild wide
permissions instead of the current channel permissions.
@ -2030,6 +2080,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
"""Similar to :func:`.has_guild_permissions`, but checks the bot
members guild permissions.
@ -2055,6 +2106,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def dm_only() -> Callable[[T], T]:
"""A :func:`.check` that indicates this command must only be used in a
DM context. Only private messages are allowed when
@ -2073,6 +2125,7 @@ def dm_only() -> Callable[[T], T]:
return check(predicate)
def guild_only() -> Callable[[T], T]:
"""A :func:`.check` that indicates this command must only be used in a
guild context only. Basically, no private messages are allowed when
@ -2089,6 +2142,7 @@ def guild_only() -> Callable[[T], T]:
return check(predicate)
def is_owner() -> Callable[[T], T]:
"""A :func:`.check` that checks if the person invoking this command is the
owner of the bot.
@ -2101,11 +2155,12 @@ def is_owner() -> Callable[[T], T]:
async def predicate(ctx: Context) -> bool:
if not await ctx.bot.is_owner(ctx.author):
raise NotOwner('You do not own this bot.')
raise NotOwner("You do not own this bot.")
return True
return check(predicate)
def is_nsfw() -> Callable[[T], T]:
"""A :func:`.check` that checks if the channel is a NSFW channel.
@ -2117,14 +2172,19 @@ def is_nsfw() -> Callable[[T], T]:
Raise :exc:`.NSFWChannelRequired` instead of generic :exc:`.CheckFailure`.
DM channels will also now pass this check.
"""
def pred(ctx: Context) -> bool:
ch = ctx.channel
if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()):
return True
raise NSFWChannelRequired(ch) # type: ignore
return check(pred)
def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default) -> Callable[[T], T]:
def cooldown(
rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default
) -> Callable[[T], T]:
"""A decorator that adds a cooldown to a :class:`.Command`
A cooldown allows a command to only be used a specific amount
@ -2157,9 +2217,13 @@ def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message],
else:
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type)
return func
return decorator # type: ignore
def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default) -> Callable[[T], T]:
def dynamic_cooldown(
cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default
) -> Callable[[T], T]:
"""A decorator that adds a dynamic cooldown to a :class:`.Command`
This differs from :func:`.cooldown` in that it takes a function that
@ -2197,8 +2261,10 @@ def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type
else:
func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type)
return func
return decorator # type: ignore
def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]:
"""A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses.
@ -2230,8 +2296,10 @@ def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait:
else:
func.__commands_max_concurrency__ = value
return func
return decorator # type: ignore
def before_invoke(coro) -> Callable[[T], T]:
"""A decorator that registers a coroutine as a pre-invoke hook.
@ -2270,14 +2338,17 @@ def before_invoke(coro) -> Callable[[T], T]:
bot.add_cog(What())
"""
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
if isinstance(func, Command):
func.before_invoke(coro)
else:
func.__before_invoke__ = coro
return func
return decorator # type: ignore
def after_invoke(coro) -> Callable[[T], T]:
"""A decorator that registers a coroutine as a post-invoke hook.
@ -2286,10 +2357,12 @@ def after_invoke(coro) -> Callable[[T], T]:
.. versionadded:: 1.4
"""
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
if isinstance(func, Command):
func.after_invoke(coro)
else:
func.__after_invoke__ = coro
return func
return decorator # type: ignore

View File

@ -41,65 +41,66 @@ if TYPE_CHECKING:
__all__ = (
'CommandError',
'MissingRequiredArgument',
'BadArgument',
'PrivateMessageOnly',
'NoPrivateMessage',
'CheckFailure',
'CheckAnyFailure',
'CommandNotFound',
'DisabledCommand',
'CommandInvokeError',
'TooManyArguments',
'UserInputError',
'CommandOnCooldown',
'MaxConcurrencyReached',
'NotOwner',
'MessageNotFound',
'ObjectNotFound',
'MemberNotFound',
'GuildNotFound',
'UserNotFound',
'ChannelNotFound',
'ThreadNotFound',
'ChannelNotReadable',
'BadColourArgument',
'BadColorArgument',
'RoleNotFound',
'BadInviteArgument',
'EmojiNotFound',
'GuildStickerNotFound',
'PartialEmojiConversionFailure',
'BadBoolArgument',
'MissingRole',
'BotMissingRole',
'MissingAnyRole',
'BotMissingAnyRole',
'MissingPermissions',
'BotMissingPermissions',
'NSFWChannelRequired',
'ConversionError',
'BadUnionArgument',
'BadLiteralArgument',
'ArgumentParsingError',
'UnexpectedQuoteError',
'InvalidEndOfQuotedStringError',
'ExpectedClosingQuoteError',
'ExtensionError',
'ExtensionAlreadyLoaded',
'ExtensionNotLoaded',
'NoEntryPointError',
'ExtensionFailed',
'ExtensionNotFound',
'CommandRegistrationError',
'FlagError',
'BadFlagArgument',
'MissingFlagArgument',
'TooManyFlags',
'MissingRequiredFlag',
"CommandError",
"MissingRequiredArgument",
"BadArgument",
"PrivateMessageOnly",
"NoPrivateMessage",
"CheckFailure",
"CheckAnyFailure",
"CommandNotFound",
"DisabledCommand",
"CommandInvokeError",
"TooManyArguments",
"UserInputError",
"CommandOnCooldown",
"MaxConcurrencyReached",
"NotOwner",
"MessageNotFound",
"ObjectNotFound",
"MemberNotFound",
"GuildNotFound",
"UserNotFound",
"ChannelNotFound",
"ThreadNotFound",
"ChannelNotReadable",
"BadColourArgument",
"BadColorArgument",
"RoleNotFound",
"BadInviteArgument",
"EmojiNotFound",
"GuildStickerNotFound",
"PartialEmojiConversionFailure",
"BadBoolArgument",
"MissingRole",
"BotMissingRole",
"MissingAnyRole",
"BotMissingAnyRole",
"MissingPermissions",
"BotMissingPermissions",
"NSFWChannelRequired",
"ConversionError",
"BadUnionArgument",
"BadLiteralArgument",
"ArgumentParsingError",
"UnexpectedQuoteError",
"InvalidEndOfQuotedStringError",
"ExpectedClosingQuoteError",
"ExtensionError",
"ExtensionAlreadyLoaded",
"ExtensionNotLoaded",
"NoEntryPointError",
"ExtensionFailed",
"ExtensionNotFound",
"CommandRegistrationError",
"FlagError",
"BadFlagArgument",
"MissingFlagArgument",
"TooManyFlags",
"MissingRequiredFlag",
)
class CommandError(DiscordException):
r"""The base exception type for all command related errors.
@ -109,14 +110,16 @@ class CommandError(DiscordException):
in a special way as they are caught and passed into a special event
from :class:`.Bot`\, :func:`.on_command_error`.
"""
def __init__(self, message: Optional[str] = None, *args: Any) -> None:
if message is not None:
# clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere")
super().__init__(m, *args)
else:
super().__init__(*args)
class ConversionError(CommandError):
"""Exception raised when a Converter class raises non-CommandError.
@ -130,18 +133,22 @@ class ConversionError(CommandError):
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, converter: Converter, original: Exception) -> None:
self.converter: Converter = converter
self.original: Exception = original
class UserInputError(CommandError):
"""The base exception type for errors that involve errors
regarding user input.
This inherits from :exc:`CommandError`.
"""
pass
class CommandNotFound(CommandError):
"""Exception raised when a command is attempted to be invoked
but no command under that name is found.
@ -151,8 +158,10 @@ class CommandNotFound(CommandError):
This inherits from :exc:`CommandError`.
"""
pass
class MissingRequiredArgument(UserInputError):
"""Exception raised when parsing a command and a parameter
that is required is not encountered.
@ -164,9 +173,11 @@ class MissingRequiredArgument(UserInputError):
param: :class:`inspect.Parameter`
The argument that is missing.
"""
def __init__(self, param: Parameter) -> None:
self.param: Parameter = param
super().__init__(f'{param.name} is a required argument that is missing.')
super().__init__(f"{param.name} is a required argument that is missing.")
class TooManyArguments(UserInputError):
"""Exception raised when the command was passed too many arguments and its
@ -174,23 +185,29 @@ class TooManyArguments(UserInputError):
This inherits from :exc:`UserInputError`
"""
pass
class BadArgument(UserInputError):
"""Exception raised when a parsing or conversion failure is encountered
on an argument to pass into a command.
This inherits from :exc:`UserInputError`
"""
pass
class CheckFailure(CommandError):
"""Exception raised when the predicates in :attr:`.Command.checks` have failed.
This inherits from :exc:`CommandError`
"""
pass
class CheckAnyFailure(CheckFailure):
"""Exception raised when all predicates in :func:`check_any` fail.
@ -209,7 +226,8 @@ class CheckAnyFailure(CheckFailure):
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None:
self.checks: List[CheckFailure] = checks
self.errors: List[Callable[[Context], bool]] = errors
super().__init__('You do not have permission to run this command.')
super().__init__("You do not have permission to run this command.")
class PrivateMessageOnly(CheckFailure):
"""Exception raised when an operation does not work outside of private
@ -217,8 +235,10 @@ class PrivateMessageOnly(CheckFailure):
This inherits from :exc:`CheckFailure`
"""
def __init__(self, message: Optional[str] = None) -> None:
super().__init__(message or 'This command can only be used in private messages.')
super().__init__(message or "This command can only be used in private messages.")
class NoPrivateMessage(CheckFailure):
"""Exception raised when an operation does not work in private message
@ -228,15 +248,18 @@ class NoPrivateMessage(CheckFailure):
"""
def __init__(self, message: Optional[str] = None) -> None:
super().__init__(message or 'This command cannot be used in private messages.')
super().__init__(message or "This command cannot be used in private messages.")
class NotOwner(CheckFailure):
"""Exception raised when the message author is not the owner of the bot.
This inherits from :exc:`CheckFailure`
"""
pass
class ObjectNotFound(BadArgument):
"""Exception raised when the argument provided did not match the format
of an ID or a mention.
@ -250,9 +273,11 @@ class ObjectNotFound(BadArgument):
argument: :class:`str`
The argument supplied by the caller that was not matched
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'{argument!r} does not follow a valid ID or mention format.')
super().__init__(f"{argument!r} does not follow a valid ID or mention format.")
class MemberNotFound(BadArgument):
"""Exception raised when the member provided was not found in the bot's
@ -267,10 +292,12 @@ class MemberNotFound(BadArgument):
argument: :class:`str`
The member supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Member "{argument}" not found.')
class GuildNotFound(BadArgument):
"""Exception raised when the guild provided was not found in the bot's cache.
@ -283,10 +310,12 @@ class GuildNotFound(BadArgument):
argument: :class:`str`
The guild supplied by the called that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Guild "{argument}" not found.')
class UserNotFound(BadArgument):
"""Exception raised when the user provided was not found in the bot's
cache.
@ -300,10 +329,12 @@ class UserNotFound(BadArgument):
argument: :class:`str`
The user supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'User "{argument}" not found.')
class MessageNotFound(BadArgument):
"""Exception raised when the message provided was not found in the channel.
@ -316,10 +347,12 @@ class MessageNotFound(BadArgument):
argument: :class:`str`
The message supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Message "{argument}" not found.')
class ChannelNotReadable(BadArgument):
"""Exception raised when the bot does not have permission to read messages
in the channel.
@ -333,10 +366,12 @@ class ChannelNotReadable(BadArgument):
argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel supplied by the caller that was not readable
"""
def __init__(self, argument: Union[GuildChannel, Thread]) -> None:
self.argument: Union[GuildChannel, Thread] = argument
super().__init__(f"Can't read messages in {argument.mention}.")
class ChannelNotFound(BadArgument):
"""Exception raised when the bot can not find the channel.
@ -349,10 +384,12 @@ class ChannelNotFound(BadArgument):
argument: :class:`str`
The channel supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Channel "{argument}" not found.')
class ThreadNotFound(BadArgument):
"""Exception raised when the bot can not find the thread.
@ -365,10 +402,12 @@ class ThreadNotFound(BadArgument):
argument: :class:`str`
The thread supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Thread "{argument}" not found.')
class BadColourArgument(BadArgument):
"""Exception raised when the colour is not valid.
@ -381,12 +420,15 @@ class BadColourArgument(BadArgument):
argument: :class:`str`
The colour supplied by the caller that was not valid
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Colour "{argument}" is invalid.')
BadColorArgument = BadColourArgument
class RoleNotFound(BadArgument):
"""Exception raised when the bot can not find the role.
@ -399,10 +441,12 @@ class RoleNotFound(BadArgument):
argument: :class:`str`
The role supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Role "{argument}" not found.')
class BadInviteArgument(BadArgument):
"""Exception raised when the invite is invalid or expired.
@ -410,10 +454,12 @@ class BadInviteArgument(BadArgument):
.. versionadded:: 1.5
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Invite "{argument}" is invalid or expired.')
class EmojiNotFound(BadArgument):
"""Exception raised when the bot can not find the emoji.
@ -426,10 +472,12 @@ class EmojiNotFound(BadArgument):
argument: :class:`str`
The emoji supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Emoji "{argument}" not found.')
class PartialEmojiConversionFailure(BadArgument):
"""Exception raised when the emoji provided does not match the correct
format.
@ -443,10 +491,12 @@ class PartialEmojiConversionFailure(BadArgument):
argument: :class:`str`
The emoji supplied by the caller that did not match the regex
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.')
class GuildStickerNotFound(BadArgument):
"""Exception raised when the bot can not find the sticker.
@ -459,10 +509,12 @@ class GuildStickerNotFound(BadArgument):
argument: :class:`str`
The sticker supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Sticker "{argument}" not found.')
class BadBoolArgument(BadArgument):
"""Exception raised when a boolean argument was not convertable.
@ -475,17 +527,21 @@ class BadBoolArgument(BadArgument):
argument: :class:`str`
The boolean argument supplied by the caller that is not in the predefined list
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'{argument} is not a recognised boolean option')
super().__init__(f"{argument} is not a recognised boolean option")
class DisabledCommand(CommandError):
"""Exception raised when the command being invoked is disabled.
This inherits from :exc:`CommandError`
"""
pass
class CommandInvokeError(CommandError):
"""Exception raised when the command being invoked raised an exception.
@ -497,9 +553,11 @@ class CommandInvokeError(CommandError):
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, e: Exception) -> None:
self.original: Exception = e
super().__init__(f'Command raised an exception: {e.__class__.__name__}: {e}')
super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}")
class CommandOnCooldown(CommandError):
"""Exception raised when the command being invoked is on cooldown.
@ -516,11 +574,13 @@ class CommandOnCooldown(CommandError):
retry_after: :class:`float`
The amount of seconds to wait before you can retry again.
"""
def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None:
self.cooldown: Cooldown = cooldown
self.retry_after: float = retry_after
self.type: BucketType = type
super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s')
super().__init__(f"You are on cooldown. Try again in {retry_after:.2f}s")
class MaxConcurrencyReached(CommandError):
"""Exception raised when the command being invoked has reached its maximum concurrency.
@ -539,10 +599,11 @@ class MaxConcurrencyReached(CommandError):
self.number: int = number
self.per: BucketType = per
name = per.name
suffix = 'per %s' % name if per.name != 'default' else 'globally'
plural = '%s times %s' if number > 1 else '%s time %s'
suffix = "per %s" % name if per.name != "default" else "globally"
plural = "%s times %s" if number > 1 else "%s time %s"
fmt = plural % (number, suffix)
super().__init__(f'Too many people are using this command. It can only be used {fmt} concurrently.')
super().__init__(f"Too many people are using this command. It can only be used {fmt} concurrently.")
class MissingRole(CheckFailure):
"""Exception raised when the command invoker lacks a role to run a command.
@ -557,11 +618,13 @@ class MissingRole(CheckFailure):
The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`.
"""
def __init__(self, missing_role: Snowflake) -> None:
self.missing_role: Snowflake = missing_role
message = f'Role {missing_role!r} is required to run this command.'
message = f"Role {missing_role!r} is required to run this command."
super().__init__(message)
class BotMissingRole(CheckFailure):
"""Exception raised when the bot's member lacks a role to run a command.
@ -575,11 +638,13 @@ class BotMissingRole(CheckFailure):
The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`.
"""
def __init__(self, missing_role: Snowflake) -> None:
self.missing_role: Snowflake = missing_role
message = f'Bot requires the role {missing_role!r} to run this command'
message = f"Bot requires the role {missing_role!r} to run this command"
super().__init__(message)
class MissingAnyRole(CheckFailure):
"""Exception raised when the command invoker lacks any of
the roles specified to run a command.
@ -594,15 +659,16 @@ class MissingAnyRole(CheckFailure):
The roles that the invoker is missing.
These are the parameters passed to :func:`~.commands.has_any_role`.
"""
def __init__(self, missing_roles: SnowflakeList) -> None:
self.missing_roles: SnowflakeList = missing_roles
missing = [f"'{role}'" for role in missing_roles]
if len(missing) > 2:
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' or '.join(missing)
fmt = " or ".join(missing)
message = f"You are missing at least one of the required roles: {fmt}"
super().__init__(message)
@ -623,19 +689,21 @@ class BotMissingAnyRole(CheckFailure):
These are the parameters passed to :func:`~.commands.has_any_role`.
"""
def __init__(self, missing_roles: SnowflakeList) -> None:
self.missing_roles: SnowflakeList = missing_roles
missing = [f"'{role}'" for role in missing_roles]
if len(missing) > 2:
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' or '.join(missing)
fmt = " or ".join(missing)
message = f"Bot is missing at least one of the required roles: {fmt}"
super().__init__(message)
class NSFWChannelRequired(CheckFailure):
"""Exception raised when a channel does not have the required NSFW setting.
@ -648,10 +716,12 @@ class NSFWChannelRequired(CheckFailure):
channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel that does not have NSFW enabled.
"""
def __init__(self, channel: Union[GuildChannel, Thread]) -> None:
self.channel: Union[GuildChannel, Thread] = channel
super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.")
class MissingPermissions(CheckFailure):
"""Exception raised when the command invoker lacks permissions to run a
command.
@ -663,18 +733,20 @@ class MissingPermissions(CheckFailure):
missing_permissions: List[:class:`str`]
The required permissions that are missing.
"""
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
missing = [perm.replace("_", " ").replace("guild", "server").title() for perm in missing_permissions]
if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' and '.join(missing)
message = f'You are missing {fmt} permission(s) to run this command.'
fmt = " and ".join(missing)
message = f"You are missing {fmt} permission(s) to run this command."
super().__init__(message, *args)
class BotMissingPermissions(CheckFailure):
"""Exception raised when the bot's member lacks permissions to run a
command.
@ -686,18 +758,20 @@ class BotMissingPermissions(CheckFailure):
missing_permissions: List[:class:`str`]
The required permissions that are missing.
"""
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
missing = [perm.replace("_", " ").replace("guild", "server").title() for perm in missing_permissions]
if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' and '.join(missing)
message = f'Bot requires {fmt} permission(s) to run this command.'
fmt = " and ".join(missing)
message = f"Bot requires {fmt} permission(s) to run this command."
super().__init__(message, *args)
class BadUnionArgument(UserInputError):
"""Exception raised when a :data:`typing.Union` converter fails for all
its associated types.
@ -713,6 +787,7 @@ class BadUnionArgument(UserInputError):
errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param
self.converters: Tuple[Type, ...] = converters
@ -722,18 +797,19 @@ class BadUnionArgument(UserInputError):
try:
return x.__name__
except AttributeError:
if hasattr(x, '__origin__'):
if hasattr(x, "__origin__"):
return repr(x)
return x.__class__.__name__
to_string = [_get_name(x) for x in converters]
if len(to_string) > 2:
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1])
else:
fmt = ' or '.join(to_string)
fmt = " or ".join(to_string)
super().__init__(f'Could not convert "{param.name}" into {fmt}.')
class BadLiteralArgument(UserInputError):
"""Exception raised when a :data:`typing.Literal` converter fails for all
its associated values.
@ -751,6 +827,7 @@ class BadLiteralArgument(UserInputError):
errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param
self.literals: Tuple[Any, ...] = literals
@ -758,12 +835,13 @@ class BadLiteralArgument(UserInputError):
to_string = [repr(l) for l in literals]
if len(to_string) > 2:
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1])
else:
fmt = ' or '.join(to_string)
fmt = " or ".join(to_string)
super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.')
class ArgumentParsingError(UserInputError):
"""An exception raised when the parser fails to parse a user's input.
@ -772,8 +850,10 @@ class ArgumentParsingError(UserInputError):
There are child classes that implement more granular parsing errors for
i18n purposes.
"""
pass
class UnexpectedQuoteError(ArgumentParsingError):
"""An exception raised when the parser encounters a quote mark inside a non-quoted string.
@ -784,9 +864,11 @@ class UnexpectedQuoteError(ArgumentParsingError):
quote: :class:`str`
The quote mark that was found inside the non-quoted string.
"""
def __init__(self, quote: str) -> None:
self.quote: str = quote
super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string')
super().__init__(f"Unexpected quote mark, {quote!r}, in non-quoted string")
class InvalidEndOfQuotedStringError(ArgumentParsingError):
"""An exception raised when a space is expected after the closing quote in a string
@ -799,9 +881,11 @@ class InvalidEndOfQuotedStringError(ArgumentParsingError):
char: :class:`str`
The character found instead of the expected string.
"""
def __init__(self, char: str) -> None:
self.char: str = char
super().__init__(f'Expected space after closing quotation but received {char!r}')
super().__init__(f"Expected space after closing quotation but received {char!r}")
class ExpectedClosingQuoteError(ArgumentParsingError):
"""An exception raised when a quote character is expected but not found.
@ -816,7 +900,8 @@ class ExpectedClosingQuoteError(ArgumentParsingError):
def __init__(self, close_quote: str) -> None:
self.close_quote: str = close_quote
super().__init__(f'Expected closing {close_quote}.')
super().__init__(f"Expected closing {close_quote}.")
class ExtensionError(DiscordException):
"""Base exception for extension related errors.
@ -828,37 +913,45 @@ class ExtensionError(DiscordException):
name: :class:`str`
The extension that had an error.
"""
def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None:
self.name: str = name
message = message or f'Extension {name!r} had an error.'
message = message or f"Extension {name!r} had an error."
# clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere")
super().__init__(m, *args)
class ExtensionAlreadyLoaded(ExtensionError):
"""An exception raised when an extension has already been loaded.
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name: str) -> None:
super().__init__(f'Extension {name!r} is already loaded.', name=name)
super().__init__(f"Extension {name!r} is already loaded.", name=name)
class ExtensionNotLoaded(ExtensionError):
"""An exception raised when an extension was not loaded.
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name: str) -> None:
super().__init__(f'Extension {name!r} has not been loaded.', name=name)
super().__init__(f"Extension {name!r} has not been loaded.", name=name)
class NoEntryPointError(ExtensionError):
"""An exception raised when an extension does not have a ``setup`` entry point function.
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name: str) -> None:
super().__init__(f"Extension {name!r} has no 'setup' function.", name=name)
class ExtensionFailed(ExtensionError):
"""An exception raised when an extension failed to load during execution of the module or ``setup`` entry point.
@ -872,11 +965,13 @@ class ExtensionFailed(ExtensionError):
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, name: str, original: Exception) -> None:
self.original: Exception = original
msg = f'Extension {name!r} raised an error: {original.__class__.__name__}: {original}'
msg = f"Extension {name!r} raised an error: {original.__class__.__name__}: {original}"
super().__init__(msg, name=name)
class ExtensionNotFound(ExtensionError):
"""An exception raised when an extension is not found.
@ -890,10 +985,12 @@ class ExtensionNotFound(ExtensionError):
name: :class:`str`
The extension that had the error.
"""
def __init__(self, name: str) -> None:
msg = f'Extension {name!r} could not be loaded.'
msg = f"Extension {name!r} could not be loaded."
super().__init__(msg, name=name)
class CommandRegistrationError(ClientException):
"""An exception raised when the command can't be added
because the name is already taken by a different command.
@ -909,11 +1006,13 @@ class CommandRegistrationError(ClientException):
alias_conflict: :class:`bool`
Whether the name that conflicts is an alias of the command we try to add.
"""
def __init__(self, name: str, *, alias_conflict: bool = False) -> None:
self.name: str = name
self.alias_conflict: bool = alias_conflict
type_ = 'alias' if alias_conflict else 'command'
super().__init__(f'The {type_} {name} is already an existing command or alias.')
type_ = "alias" if alias_conflict else "command"
super().__init__(f"The {type_} {name} is already an existing command or alias.")
class FlagError(BadArgument):
"""The base exception type for all flag parsing related errors.
@ -922,8 +1021,10 @@ class FlagError(BadArgument):
.. versionadded:: 2.0
"""
pass
class TooManyFlags(FlagError):
"""An exception raised when a flag has received too many values.
@ -938,10 +1039,12 @@ class TooManyFlags(FlagError):
values: List[:class:`str`]
The values that were passed.
"""
def __init__(self, flag: Flag, values: List[str]) -> None:
self.flag: Flag = flag
self.values: List[str] = values
super().__init__(f'Too many flag values, expected {flag.max_args} but received {len(values)}.')
super().__init__(f"Too many flag values, expected {flag.max_args} but received {len(values)}.")
class BadFlagArgument(FlagError):
"""An exception raised when a flag failed to convert a value.
@ -955,6 +1058,7 @@ class BadFlagArgument(FlagError):
flag: :class:`~discord.ext.commands.Flag`
The flag that failed to convert.
"""
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
try:
@ -962,7 +1066,8 @@ class BadFlagArgument(FlagError):
except AttributeError:
name = flag.annotation.__class__.__name__
super().__init__(f'Could not convert to {name!r} for flag {flag.name!r}')
super().__init__(f"Could not convert to {name!r} for flag {flag.name!r}")
class MissingRequiredFlag(FlagError):
"""An exception raised when a required flag was not given.
@ -976,9 +1081,11 @@ class MissingRequiredFlag(FlagError):
flag: :class:`~discord.ext.commands.Flag`
The required flag that was not found.
"""
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} is required and missing')
super().__init__(f"Flag {flag.name!r} is required and missing")
class MissingFlagArgument(FlagError):
"""An exception raised when a flag did not get a value.
@ -992,6 +1099,7 @@ class MissingFlagArgument(FlagError):
flag: :class:`~discord.ext.commands.Flag`
The flag that did not get a value.
"""
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} does not have an argument')
super().__init__(f"Flag {flag.name!r} does not have an argument")

View File

@ -59,9 +59,9 @@ import sys
import re
__all__ = (
'Flag',
'flag',
'FlagConverter',
"Flag",
"flag",
"FlagConverter",
)
@ -148,20 +148,20 @@ def flag(
def validate_flag_name(name: str, forbidden: Set[str]):
if not name:
raise ValueError('flag names should not be empty')
raise ValueError("flag names should not be empty")
for ch in name:
if ch.isspace():
raise ValueError(f'flag name {name!r} cannot have spaces')
if ch == '\\':
raise ValueError(f'flag name {name!r} cannot have backslashes')
raise ValueError(f"flag name {name!r} cannot have spaces")
if ch == "\\":
raise ValueError(f"flag name {name!r} cannot have backslashes")
if ch in forbidden:
raise ValueError(f'flag name {name!r} cannot have any of {forbidden!r} within them')
raise ValueError(f"flag name {name!r} cannot have any of {forbidden!r} within them")
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
annotations = namespace.get('__annotations__', {})
case_insensitive = namespace['__commands_flag_case_insensitive__']
annotations = namespace.get("__annotations__", {})
case_insensitive = namespace["__commands_flag_case_insensitive__"]
flags: Dict[str, Flag] = {}
cache: Dict[str, Any] = {}
names: Set[str] = set()
@ -178,7 +178,11 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
if flag.default is MISSING and hasattr(annotation, '__commands_is_flag__') and annotation._can_be_constructible():
if (
flag.default is MISSING
and hasattr(annotation, "__commands_is_flag__")
and annotation._can_be_constructible()
):
flag.default = annotation._construct_default
if flag.aliases is MISSING:
@ -229,7 +233,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
if flag.max_args is MISSING:
flag.max_args = 1
else:
raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag')
raise TypeError(f"Unsupported typing annotation {annotation!r} for {flag.name!r} flag")
if flag.override is MISSING:
flag.override = False
@ -237,7 +241,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
# Validate flag names are unique
name = flag.name.casefold() if case_insensitive else flag.name
if name in names:
raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.')
raise TypeError(f"{flag.name!r} flag conflicts with previous flag or alias.")
else:
names.add(name)
@ -245,7 +249,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
# Validate alias is unique
alias = alias.casefold() if case_insensitive else alias
if alias in names:
raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.')
raise TypeError(f"{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.")
else:
names.add(alias)
@ -274,10 +278,10 @@ class FlagsMeta(type):
delimiter: str = MISSING,
prefix: str = MISSING,
):
attrs['__commands_is_flag__'] = True
attrs["__commands_is_flag__"] = True
try:
global_ns = sys.modules[attrs['__module__']].__dict__
global_ns = sys.modules[attrs["__module__"]].__dict__
except KeyError:
global_ns = {}
@ -296,26 +300,26 @@ class FlagsMeta(type):
flags: Dict[str, Flag] = {}
aliases: Dict[str, str] = {}
for base in reversed(bases):
if base.__dict__.get('__commands_is_flag__', False):
flags.update(base.__dict__['__commands_flags__'])
aliases.update(base.__dict__['__commands_flag_aliases__'])
if base.__dict__.get("__commands_is_flag__", False):
flags.update(base.__dict__["__commands_flags__"])
aliases.update(base.__dict__["__commands_flag_aliases__"])
if case_insensitive is MISSING:
attrs['__commands_flag_case_insensitive__'] = base.__dict__['__commands_flag_case_insensitive__']
attrs["__commands_flag_case_insensitive__"] = base.__dict__["__commands_flag_case_insensitive__"]
if delimiter is MISSING:
attrs['__commands_flag_delimiter__'] = base.__dict__['__commands_flag_delimiter__']
attrs["__commands_flag_delimiter__"] = base.__dict__["__commands_flag_delimiter__"]
if prefix is MISSING:
attrs['__commands_flag_prefix__'] = base.__dict__['__commands_flag_prefix__']
attrs["__commands_flag_prefix__"] = base.__dict__["__commands_flag_prefix__"]
if case_insensitive is not MISSING:
attrs['__commands_flag_case_insensitive__'] = case_insensitive
attrs["__commands_flag_case_insensitive__"] = case_insensitive
if delimiter is not MISSING:
attrs['__commands_flag_delimiter__'] = delimiter
attrs["__commands_flag_delimiter__"] = delimiter
if prefix is not MISSING:
attrs['__commands_flag_prefix__'] = prefix
attrs["__commands_flag_prefix__"] = prefix
case_insensitive = attrs.setdefault('__commands_flag_case_insensitive__', False)
delimiter = attrs.setdefault('__commands_flag_delimiter__', ':')
prefix = attrs.setdefault('__commands_flag_prefix__', '')
case_insensitive = attrs.setdefault("__commands_flag_case_insensitive__", False)
delimiter = attrs.setdefault("__commands_flag_delimiter__", ":")
prefix = attrs.setdefault("__commands_flag_prefix__", "")
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
flags[flag_name] = flag
@ -337,11 +341,11 @@ class FlagsMeta(type):
keys.extend(re.escape(a) for a in aliases)
keys = sorted(keys, key=lambda t: len(t), reverse=True)
joined = '|'.join(keys)
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags)
attrs['__commands_flag_regex__'] = pattern
attrs['__commands_flags__'] = flags
attrs['__commands_flag_aliases__'] = aliases
joined = "|".join(keys)
pattern = re.compile(f"(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})", regex_flags)
attrs["__commands_flag_regex__"] = pattern
attrs["__commands_flags__"] = flags
attrs["__commands_flag_aliases__"] = aliases
return type.__new__(cls, name, bases, attrs)
@ -432,7 +436,7 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -
raise BadFlagArgument(flag) from e
F = TypeVar('F', bound='FlagConverter')
F = TypeVar("F", bound="FlagConverter")
class FlagConverter(metaclass=FlagsMeta):
@ -493,8 +497,8 @@ class FlagConverter(metaclass=FlagsMeta):
return self
def __repr__(self) -> str:
pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()])
return f'<{self.__class__.__name__} {pairs}>'
pairs = " ".join([f"{flag.attribute}={getattr(self, flag.attribute)!r}" for flag in self.get_flags().values()])
return f"<{self.__class__.__name__} {pairs}>"
@classmethod
def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
@ -507,7 +511,7 @@ class FlagConverter(metaclass=FlagsMeta):
case_insensitive = cls.__commands_flag_case_insensitive__
for match in cls.__commands_flag_regex__.finditer(argument):
begin, end = match.span(0)
key = match.group('flag')
key = match.group("flag")
if case_insensitive:
key = key.casefold()

View File

@ -39,10 +39,10 @@ if TYPE_CHECKING:
from .context import Context
__all__ = (
'Paginator',
'HelpCommand',
'DefaultHelpCommand',
'MinimalHelpCommand',
"Paginator",
"HelpCommand",
"DefaultHelpCommand",
"MinimalHelpCommand",
)
# help -> shows info of bot on top/bottom and lists subcommands
@ -89,7 +89,7 @@ class Paginator:
.. versionadded:: 1.7
"""
def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'):
def __init__(self, prefix="```", suffix="```", max_size=2000, linesep="\n"):
self.prefix = prefix
self.suffix = suffix
self.max_size = max_size
@ -118,7 +118,7 @@ class Paginator:
def _linesep_len(self):
return len(self.linesep)
def add_line(self, line='', *, empty=False):
def add_line(self, line="", *, empty=False):
"""Adds a line to the current page.
If the line exceeds the :attr:`max_size` then an exception
@ -138,7 +138,7 @@ class Paginator:
"""
max_page_size = self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len
if len(line) > max_page_size:
raise RuntimeError(f'Line exceeds maximum page size {max_page_size}')
raise RuntimeError(f"Line exceeds maximum page size {max_page_size}")
if self._count + len(line) + self._linesep_len > self.max_size - self._suffix_len:
self.close_page()
@ -147,7 +147,7 @@ class Paginator:
self._current_page.append(line)
if empty:
self._current_page.append('')
self._current_page.append("")
self._count += self._linesep_len
def close_page(self):
@ -176,7 +176,7 @@ class Paginator:
return self._pages
def __repr__(self):
fmt = '<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>'
fmt = "<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>"
return fmt.format(self)
@ -197,7 +197,7 @@ class _HelpCommandImpl(Command):
self.callback = injected.command_callback
on_error = injected.on_help_command_error
if not hasattr(on_error, '__help_command_not_overriden__'):
if not hasattr(on_error, "__help_command_not_overriden__"):
if self.cog is not None:
self.on_error = self._on_error_cog_implementation
else:
@ -224,7 +224,7 @@ class _HelpCommandImpl(Command):
try:
del result[next(iter(result))]
except StopIteration:
raise ValueError('Missing context parameter') from None
raise ValueError("Missing context parameter") from None
else:
return result
@ -296,13 +296,13 @@ class HelpCommand:
"""
MENTION_TRANSFORMS = {
'@everyone': '@\u200beveryone',
'@here': '@\u200bhere',
r'<@!?[0-9]{17,22}>': '@deleted-user',
r'<@&[0-9]{17,22}>': '@deleted-role',
"@everyone": "@\u200beveryone",
"@here": "@\u200bhere",
r"<@!?[0-9]{17,22}>": "@deleted-user",
r"<@&[0-9]{17,22}>": "@deleted-role",
}
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys()))
MENTION_PATTERN = re.compile("|".join(MENTION_TRANSFORMS.keys()))
def __new__(cls, *args, **kwargs):
# To prevent race conditions of a single instance while also allowing
@ -321,11 +321,11 @@ class HelpCommand:
return self
def __init__(self, **options):
self.show_hidden = options.pop('show_hidden', False)
self.verify_checks = options.pop('verify_checks', True)
self.command_attrs = attrs = options.pop('command_attrs', {})
attrs.setdefault('name', 'help')
attrs.setdefault('help', 'Shows this message')
self.show_hidden = options.pop("show_hidden", False)
self.verify_checks = options.pop("verify_checks", True)
self.command_attrs = attrs = options.pop("command_attrs", {})
attrs.setdefault("name", "help")
attrs.setdefault("help", "Shows this message")
self.context: Context = discord.utils.MISSING
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
@ -422,20 +422,20 @@ class HelpCommand:
if not parent.signature or parent.invoke_without_command:
entries.append(parent.name)
else:
entries.append(parent.name + ' ' + parent.signature)
entries.append(parent.name + " " + parent.signature)
parent = parent.parent
parent_sig = ' '.join(reversed(entries))
parent_sig = " ".join(reversed(entries))
if len(command.aliases) > 0:
aliases = '|'.join(command.aliases)
fmt = f'[{command.name}|{aliases}]'
aliases = "|".join(command.aliases)
fmt = f"[{command.name}|{aliases}]"
if parent_sig:
fmt = parent_sig + ' ' + fmt
fmt = parent_sig + " " + fmt
alias = fmt
else:
alias = command.name if not parent_sig else parent_sig + ' ' + command.name
alias = command.name if not parent_sig else parent_sig + " " + command.name
return f'{self.context.clean_prefix}{alias} {command.signature}'
return f"{self.context.clean_prefix}{alias} {command.signature}"
def remove_mentions(self, string):
"""Removes mentions from the string to prevent abuse.
@ -449,7 +449,7 @@ class HelpCommand:
"""
def replace(obj, *, transforms=self.MENTION_TRANSFORMS):
return transforms.get(obj.group(0), '@invalid')
return transforms.get(obj.group(0), "@invalid")
return self.MENTION_PATTERN.sub(replace, string)
@ -846,7 +846,7 @@ class HelpCommand:
# Since we want to have detailed errors when someone
# passes an invalid subcommand, we need to walk through
# the command group chain ourselves.
keys = command.split(' ')
keys = command.split(" ")
cmd = bot.all_commands.get(keys[0])
if cmd is None:
string = await maybe_coro(self.command_not_found, self.remove_mentions(keys[0]))
@ -907,14 +907,14 @@ class DefaultHelpCommand(HelpCommand):
"""
def __init__(self, **options):
self.width = options.pop('width', 80)
self.indent = options.pop('indent', 2)
self.sort_commands = options.pop('sort_commands', True)
self.dm_help = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
self.commands_heading = options.pop('commands_heading', "Commands:")
self.no_category = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None)
self.width = options.pop("width", 80)
self.indent = options.pop("indent", 2)
self.sort_commands = options.pop("sort_commands", True)
self.dm_help = options.pop("dm_help", False)
self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
self.commands_heading = options.pop("commands_heading", "Commands:")
self.no_category = options.pop("no_category", "No Category")
self.paginator = options.pop("paginator", None)
if self.paginator is None:
self.paginator = Paginator()
@ -924,7 +924,7 @@ class DefaultHelpCommand(HelpCommand):
def shorten_text(self, text):
""":class:`str`: Shortens text to fit into the :attr:`width`."""
if len(text) > self.width:
return text[:self.width - 3].rstrip() + '...'
return text[: self.width - 3].rstrip() + "..."
return text
def get_ending_note(self):
@ -1021,11 +1021,11 @@ class DefaultHelpCommand(HelpCommand):
# <description> portion
self.paginator.add_line(bot.description, empty=True)
no_category = f'\u200b{self.no_category}:'
no_category = f"\u200b{self.no_category}:"
def get_category(command, *, no_category=no_category):
cog = command.cog
return cog.qualified_name + ':' if cog is not None else no_category
return cog.qualified_name + ":" if cog is not None else no_category
filtered = await self.filter_commands(bot.commands, sort=True, key=get_category)
max_size = self.get_max_size(filtered)
@ -1110,13 +1110,13 @@ class MinimalHelpCommand(HelpCommand):
"""
def __init__(self, **options):
self.sort_commands = options.pop('sort_commands', True)
self.commands_heading = options.pop('commands_heading', "Commands")
self.dm_help = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
self.aliases_heading = options.pop('aliases_heading', "Aliases:")
self.no_category = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None)
self.sort_commands = options.pop("sort_commands", True)
self.commands_heading = options.pop("commands_heading", "Commands")
self.dm_help = options.pop("dm_help", False)
self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
self.aliases_heading = options.pop("aliases_heading", "Aliases:")
self.no_category = options.pop("no_category", "No Category")
self.paginator = options.pop("paginator", None)
if self.paginator is None:
self.paginator = Paginator(suffix=None, prefix=None)
@ -1149,7 +1149,7 @@ class MinimalHelpCommand(HelpCommand):
)
def get_command_signature(self, command):
return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}'
return f"{self.context.clean_prefix}{command.qualified_name} {command.signature}"
def get_ending_note(self):
"""Return the help command's ending note. This is mainly useful to override for i18n purposes.
@ -1180,8 +1180,8 @@ class MinimalHelpCommand(HelpCommand):
"""
if commands:
# U+2002 Middle Dot
joined = '\u2002'.join(c.name for c in commands)
self.paginator.add_line(f'__**{heading}**__')
joined = "\u2002".join(c.name for c in commands)
self.paginator.add_line(f"__**{heading}**__")
self.paginator.add_line(joined)
def add_subcommand_formatting(self, command):
@ -1197,7 +1197,7 @@ class MinimalHelpCommand(HelpCommand):
command: :class:`Command`
The command to show information of.
"""
fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}'
fmt = "{0}{1} \N{EN DASH} {2}" if command.short_doc else "{0}{1}"
self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc))
def add_aliases_formatting(self, aliases):
@ -1268,7 +1268,7 @@ class MinimalHelpCommand(HelpCommand):
if note:
self.paginator.add_line(note, empty=True)
no_category = f'\u200b{self.no_category}'
no_category = f"\u200b{self.no_category}"
def get_category(command, *, no_category=no_category):
cog = command.cog
@ -1302,7 +1302,7 @@ class MinimalHelpCommand(HelpCommand):
filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands)
if filtered:
self.paginator.add_line(f'**{cog.qualified_name} {self.commands_heading}**')
self.paginator.add_line(f"**{cog.qualified_name} {self.commands_heading}**")
for command in filtered:
self.add_subcommand_formatting(command)
@ -1322,7 +1322,7 @@ class MinimalHelpCommand(HelpCommand):
if note:
self.paginator.add_line(note, empty=True)
self.paginator.add_line(f'**{self.commands_heading}**')
self.paginator.add_line(f"**{self.commands_heading}**")
for command in filtered:
self.add_subcommand_formatting(command)

View File

@ -46,6 +46,7 @@ _quotes = {
}
_all_quotes = set(_quotes.keys()) | set(_quotes.values())
class StringView:
def __init__(self, buffer):
self.index = 0
@ -81,20 +82,20 @@ class StringView:
def skip_string(self, string):
strlen = len(string)
if self.buffer[self.index:self.index + strlen] == string:
if self.buffer[self.index : self.index + strlen] == string:
self.previous = self.index
self.index += strlen
return True
return False
def read_rest(self):
result = self.buffer[self.index:]
result = self.buffer[self.index :]
self.previous = self.index
self.index = self.end
return result
def read(self, n):
result = self.buffer[self.index:self.index + n]
result = self.buffer[self.index : self.index + n]
self.previous = self.index
self.index += n
return result
@ -120,7 +121,7 @@ class StringView:
except IndexError:
break
self.previous = self.index
result = self.buffer[self.index:self.index + pos]
result = self.buffer[self.index : self.index + pos]
self.index += pos
return result
@ -144,11 +145,11 @@ class StringView:
if is_quoted:
# unexpected EOF
raise ExpectedClosingQuoteError(close_quote)
return ''.join(result)
return "".join(result)
# currently we accept strings in the format of "hello world"
# to embed a quote inside the string you must escape it: "a \"world\""
if current == '\\':
if current == "\\":
next_char = self.get()
if not next_char:
# string ends with \ and no character after it
@ -156,7 +157,7 @@ class StringView:
# if we're quoted then we're expecting a closing quote
raise ExpectedClosingQuoteError(close_quote)
# if we aren't then we just let it through
return ''.join(result)
return "".join(result)
if next_char in _escaped_quotes:
# escaped quote
@ -179,14 +180,13 @@ class StringView:
raise InvalidEndOfQuotedStringError(next_char)
# we're quoted so it's okay
return ''.join(result)
return "".join(result)
if current.isspace() and not is_quoted:
# end of word found
return ''.join(result)
return "".join(result)
result.append(current)
def __repr__(self):
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>'
return f"<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>"

View File

@ -48,19 +48,17 @@ from collections.abc import Sequence
from discord.backoff import ExponentialBackoff
from discord.utils import MISSING
__all__ = (
'loop',
)
__all__ = ("loop",)
T = TypeVar('T')
T = TypeVar("T")
_func = Callable[..., Awaitable[Any]]
LF = TypeVar('LF', bound=_func)
FT = TypeVar('FT', bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
LF = TypeVar("LF", bound=_func)
FT = TypeVar("FT", bound=_func)
ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]])
class SleepHandle:
__slots__ = ('future', 'loop', 'handle')
__slots__ = ("future", "loop", "handle")
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop
@ -124,7 +122,7 @@ class Loop(Generic[LF]):
self._stop_next_iteration = False
if self.count is not None and self.count <= 0:
raise ValueError('count must be greater than 0 or None.')
raise ValueError("count must be greater than 0 or None.")
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
self._last_iteration_failed = False
@ -132,10 +130,10 @@ class Loop(Generic[LF]):
self._next_iteration = None
if not inspect.iscoroutinefunction(self.coro):
raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.')
raise TypeError(f"Expected coroutine function, not {type(self.coro).__name__!r}.")
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
coro = getattr(self, '_' + name)
coro = getattr(self, "_" + name)
if coro is None:
return
@ -150,7 +148,7 @@ class Loop(Generic[LF]):
async def _loop(self, *args: Any, **kwargs: Any) -> None:
backoff = ExponentialBackoff()
await self._call_loop_function('before_loop')
await self._call_loop_function("before_loop")
self._last_iteration_failed = False
if self._time is not MISSING:
# the time index should be prepared every time the internal loop is started
@ -193,10 +191,10 @@ class Loop(Generic[LF]):
raise
except Exception as exc:
self._has_failed = True
await self._call_loop_function('error', exc)
await self._call_loop_function("error", exc)
raise exc
finally:
await self._call_loop_function('after_loop')
await self._call_loop_function("after_loop")
self._handle.cancel()
self._is_being_cancelled = False
self._current_loop = 0
@ -323,7 +321,7 @@ class Loop(Generic[LF]):
"""
if self._task is not MISSING and not self._task.done():
raise RuntimeError('Task is already launched and is not completed.')
raise RuntimeError("Task is already launched and is not completed.")
if self._injected is not None:
args = (self._injected, *args)
@ -410,9 +408,9 @@ class Loop(Generic[LF]):
for exc in exceptions:
if not inspect.isclass(exc):
raise TypeError(f'{exc!r} must be a class.')
raise TypeError(f"{exc!r} must be a class.")
if not issubclass(exc, BaseException):
raise TypeError(f'{exc!r} must inherit from BaseException.')
raise TypeError(f"{exc!r} must inherit from BaseException.")
self._valid_exception = (*self._valid_exception, *exceptions)
@ -466,7 +464,7 @@ class Loop(Generic[LF]):
async def _error(self, *args: Any) -> None:
exception: Exception = args[-1]
print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr)
print(f"Unhandled exception in internal background task {self.coro.__name__!r}.", file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
def before_loop(self, coro: FT) -> FT:
@ -489,7 +487,7 @@ class Loop(Generic[LF]):
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
self._before_loop = coro
return coro
@ -517,7 +515,7 @@ class Loop(Generic[LF]):
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
self._after_loop = coro
return coro
@ -543,7 +541,7 @@ class Loop(Generic[LF]):
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
self._error = coro # type: ignore
return coro
@ -601,16 +599,16 @@ class Loop(Generic[LF]):
return [inner]
if not isinstance(time, Sequence):
raise TypeError(
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.'
f"Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead."
)
if not time:
raise ValueError('time parameter must not be an empty sequence.')
raise ValueError("time parameter must not be an empty sequence.")
ret: List[datetime.time] = []
for index, t in enumerate(time):
if not isinstance(t, dt):
raise TypeError(
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.'
f"Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead."
)
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
@ -663,7 +661,7 @@ class Loop(Generic[LF]):
hours = hours or 0
sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
if sleep < 0:
raise ValueError('Total number of seconds cannot be less than zero.')
raise ValueError("Total number of seconds cannot be less than zero.")
self._sleep = sleep
self._seconds = float(seconds)
@ -672,7 +670,7 @@ class Loop(Generic[LF]):
self._time: List[datetime.time] = MISSING
else:
if any((seconds, minutes, hours)):
raise TypeError('Cannot mix explicit time with relative time')
raise TypeError("Cannot mix explicit time with relative time")
self._time = self._get_time_parameter(time)
self._sleep = self._seconds = self._minutes = self._hours = MISSING