Run black on the repository, with the default configuration.

This commit is contained in:
Arthur Jovart
2021-09-01 21:30:56 +02:00
parent 6f5614373a
commit 4d9a1989a0
107 changed files with 8671 additions and 5258 deletions

View File

@@ -31,15 +31,23 @@ if TYPE_CHECKING:
from .cog import Cog
from .errors import CommandError
T = TypeVar('T')
T = TypeVar("T")
Coro = Coroutine[Any, Any, T]
MaybeCoro = Union[T, Coro[T]]
CoroFunc = Callable[..., Coro[Any]]
Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]]
Check = Union[
Callable[["Cog", "Context[Any]"], MaybeCoro[bool]],
Callable[["Context[Any]"], MaybeCoro[bool]],
]
Hook = Union[
Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]
]
Error = Union[
Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]],
Callable[["Context[Any]", "CommandError"], Coro[Any]],
]
# This is merely a tag type to avoid circular import issues.

View File

@@ -33,7 +33,18 @@ import importlib.util
import sys
import traceback
import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
from typing import (
Any,
Callable,
Mapping,
List,
Dict,
TYPE_CHECKING,
Optional,
TypeVar,
Type,
Union,
)
import discord
@@ -54,17 +65,18 @@ if TYPE_CHECKING:
)
__all__ = (
'when_mentioned',
'when_mentioned_or',
'Bot',
'AutoShardedBot',
"when_mentioned",
"when_mentioned_or",
"Bot",
"AutoShardedBot",
)
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context')
T = TypeVar("T")
CFT = TypeVar("CFT", bound="CoroFunc")
CXT = TypeVar("CXT", bound="Context")
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned.
@@ -72,9 +84,12 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
"""
# bot.user will never be None when this is called
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
return [f"<@{bot.user.id}> ", f"<@!{bot.user.id}> "] # type: ignore
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
def when_mentioned_or(
*prefixes: str,
) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@@ -103,6 +118,7 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
----------
:func:`.when_mentioned`
"""
def inner(bot, msg):
r = list(prefixes)
r = when_mentioned(bot, msg) + r
@@ -110,17 +126,29 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
return inner
def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".")
class _DefaultRepr:
def __repr__(self):
return '<default-help-command>'
return "<default-help-command>"
_default = _DefaultRepr()
class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, *, intents: discord.Intents, **options):
def __init__(
self,
command_prefix,
help_command=_default,
description=None,
*,
intents: discord.Intents,
**options,
):
super().__init__(**options, intents=intents)
self.command_prefix = command_prefix
self.extra_events: Dict[str, List[CoroFunc]] = {}
@@ -131,16 +159,20 @@ class BotBase(GroupMixin):
self._before_invoke = None
self._after_invoke = None
self._help_command = None
self.description = inspect.cleandoc(description) if description else ''
self.owner_id = options.get('owner_id')
self.owner_ids = options.get('owner_ids', set())
self.strip_after_prefix = options.get('strip_after_prefix', False)
self.description = inspect.cleandoc(description) if description else ""
self.owner_id = options.get("owner_id")
self.owner_ids = options.get("owner_ids", set())
self.strip_after_prefix = options.get("strip_after_prefix", False)
if self.owner_id and self.owner_ids:
raise TypeError('Both owner_id and owner_ids are set.')
raise TypeError("Both owner_id and owner_ids are set.")
if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection):
raise TypeError(f'owner_ids must be a collection not {self.owner_ids.__class__!r}')
if self.owner_ids and not isinstance(
self.owner_ids, collections.abc.Collection
):
raise TypeError(
f"owner_ids must be a collection not {self.owner_ids.__class__!r}"
)
if help_command is _default:
self.help_command = DefaultHelpCommand()
@@ -152,7 +184,7 @@ class BotBase(GroupMixin):
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
# super() will resolve to Client
super().dispatch(event_name, *args, **kwargs) # type: ignore
ev = 'on_' + event_name
ev = "on_" + event_name
for event in self.extra_events.get(ev, []):
self._schedule_event(event, ev, *args, **kwargs) # type: ignore
@@ -172,7 +204,9 @@ class BotBase(GroupMixin):
await super().close() # type: ignore
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
async def on_command_error(
self, context: Context, exception: errors.CommandError
) -> None:
"""|coro|
The default command error handler provided by the bot.
@@ -182,7 +216,7 @@ class BotBase(GroupMixin):
This only fires if you do not specify any listeners for command error.
"""
if self.extra_events.get('on_command_error', None):
if self.extra_events.get("on_command_error", None):
return
command = context.command
@@ -193,8 +227,10 @@ class BotBase(GroupMixin):
if cog and cog.has_error_handler():
return
print(f'Ignoring exception in command {context.command}:', file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
print(f"Ignoring exception in command {context.command}:", file=sys.stderr)
traceback.print_exception(
type(exception), exception, exception.__traceback__, file=sys.stderr
)
# global check registration
@@ -380,7 +416,7 @@ class BotBase(GroupMixin):
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The pre-invoke hook must be a coroutine.')
raise TypeError("The pre-invoke hook must be a coroutine.")
self._before_invoke = coro
return coro
@@ -413,7 +449,7 @@ class BotBase(GroupMixin):
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The post-invoke hook must be a coroutine.')
raise TypeError("The post-invoke hook must be a coroutine.")
self._after_invoke = coro
return coro
@@ -445,7 +481,7 @@ class BotBase(GroupMixin):
name = func.__name__ if name is MISSING else name
if not asyncio.iscoroutinefunction(func):
raise TypeError('Listeners must be coroutines')
raise TypeError("Listeners must be coroutines")
if name in self.extra_events:
self.extra_events[name].append(func)
@@ -541,14 +577,14 @@ class BotBase(GroupMixin):
"""
if not isinstance(cog, Cog):
raise TypeError('cogs must derive from Cog')
raise TypeError("cogs must derive from Cog")
cog_name = cog.__cog_name__
existing = self.__cogs.get(cog_name)
if existing is not None:
if not override:
raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
raise discord.ClientException(f"Cog named {cog_name!r} already loaded")
self.remove_cog(cog_name)
cog = cog._inject(self)
@@ -628,7 +664,9 @@ class BotBase(GroupMixin):
for event_list in self.extra_events.copy().values():
remove = []
for index, event in enumerate(event_list):
if event.__module__ is not None and _is_submodule(name, event.__module__):
if event.__module__ is not None and _is_submodule(
name, event.__module__
):
remove.append(index)
for index in reversed(remove):
@@ -636,7 +674,7 @@ class BotBase(GroupMixin):
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
try:
func = getattr(lib, 'teardown')
func = getattr(lib, "teardown")
except AttributeError:
pass
else:
@@ -652,7 +690,9 @@ class BotBase(GroupMixin):
if _is_submodule(name, module):
del sys.modules[module]
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
def _load_from_module_spec(
self, spec: importlib.machinery.ModuleSpec, key: str
) -> None:
# precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec)
sys.modules[key] = lib
@@ -663,7 +703,7 @@ class BotBase(GroupMixin):
raise errors.ExtensionFailed(key, e) from e
try:
setup = getattr(lib, 'setup')
setup = getattr(lib, "setup")
except AttributeError:
del sys.modules[key]
raise errors.NoEntryPointError(key)
@@ -850,7 +890,7 @@ class BotBase(GroupMixin):
def help_command(self, value: Optional[HelpCommand]) -> None:
if value is not None:
if not isinstance(value, HelpCommand):
raise TypeError('help_command must be a subclass of HelpCommand')
raise TypeError("help_command must be a subclass of HelpCommand")
if self._help_command is not None:
self._help_command._remove_from_bot(self)
self._help_command = value
@@ -893,11 +933,15 @@ class BotBase(GroupMixin):
if isinstance(ret, collections.abc.Iterable):
raise
raise TypeError("command_prefix must be plain string, iterable of strings, or callable "
f"returning either of these, not {ret.__class__.__name__}")
raise TypeError(
"command_prefix must be plain string, iterable of strings, or callable "
f"returning either of these, not {ret.__class__.__name__}"
)
if not ret:
raise ValueError("Iterable command_prefix must contain at least one prefix")
raise ValueError(
"Iterable command_prefix must contain at least one prefix"
)
return ret
@@ -954,14 +998,18 @@ class BotBase(GroupMixin):
except TypeError:
if not isinstance(prefix, list):
raise TypeError("get_prefix must return either a string or a list of string, "
f"not {prefix.__class__.__name__}")
raise TypeError(
"get_prefix must return either a string or a list of string, "
f"not {prefix.__class__.__name__}"
)
# It's possible a bad command_prefix got us here.
for value in prefix:
if not isinstance(value, str):
raise TypeError("Iterable command_prefix or list returned from get_prefix must "
f"contain only strings, not {value.__class__.__name__}")
raise TypeError(
"Iterable command_prefix or list returned from get_prefix must "
f"contain only strings, not {value.__class__.__name__}"
)
# Getting here shouldn't happen
raise
@@ -988,19 +1036,19 @@ class BotBase(GroupMixin):
The invocation context to invoke.
"""
if ctx.command is not None:
self.dispatch('command', ctx)
self.dispatch("command", ctx)
try:
if await self.can_run(ctx, call_once=True):
await ctx.command.invoke(ctx)
else:
raise errors.CheckFailure('The global check once functions failed.')
raise errors.CheckFailure("The global check once functions failed.")
except errors.CommandError as exc:
await ctx.command.dispatch_error(ctx, exc)
else:
self.dispatch('command_completion', ctx)
self.dispatch("command_completion", ctx)
elif ctx.invoked_with:
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
self.dispatch('command_error', ctx, exc)
self.dispatch("command_error", ctx, exc)
async def process_commands(self, message: Message) -> None:
"""|coro|
@@ -1033,6 +1081,7 @@ class BotBase(GroupMixin):
async def on_message(self, message):
await self.process_commands(message)
class Bot(BotBase, discord.Client):
"""Represents a discord bot.
@@ -1103,10 +1152,13 @@ class Bot(BotBase, discord.Client):
.. versionadded:: 1.7
"""
pass
class AutoShardedBot(BotBase, discord.AutoShardedClient):
"""This is similar to :class:`.Bot` except that it is inherited from
:class:`discord.AutoShardedClient` instead.
"""
pass

View File

@@ -26,7 +26,19 @@ from __future__ import annotations
import inspect
import discord.utils
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type
from typing import (
Any,
Callable,
ClassVar,
Dict,
Generator,
List,
Optional,
TYPE_CHECKING,
Tuple,
TypeVar,
Type,
)
from ._types import _BaseCommand
@@ -36,15 +48,16 @@ if TYPE_CHECKING:
from .core import Command
__all__ = (
'CogMeta',
'Cog',
"CogMeta",
"Cog",
)
CogT = TypeVar('CogT', bound='Cog')
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
CogT = TypeVar("CogT", bound="Cog")
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
MISSING: Any = discord.utils.MISSING
class CogMeta(type):
"""A metaclass for defining a cog.
@@ -104,6 +117,7 @@ class CogMeta(type):
async def bar(self, ctx):
pass # hidden -> False
"""
__cog_name__: str
__cog_settings__: Dict[str, Any]
__cog_commands__: List[Command]
@@ -111,17 +125,17 @@ class CogMeta(type):
def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
name, bases, attrs = args
attrs['__cog_name__'] = kwargs.pop('name', name)
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
attrs["__cog_name__"] = kwargs.pop("name", name)
attrs["__cog_settings__"] = kwargs.pop("command_attrs", {})
description = kwargs.pop('description', None)
description = kwargs.pop("description", None)
if description is None:
description = inspect.cleandoc(attrs.get('__doc__', ''))
attrs['__cog_description__'] = description
description = inspect.cleandoc(attrs.get("__doc__", ""))
attrs["__cog_description__"] = description
commands = {}
listeners = {}
no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})'
no_bot_cog = "Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})"
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
for base in reversed(new_cls.__mro__):
@@ -136,21 +150,25 @@ class CogMeta(type):
value = value.__func__
if isinstance(value, _BaseCommand):
if is_static_method:
raise TypeError(f'Command in method {base}.{elem!r} must not be staticmethod.')
if elem.startswith(('cog_', 'bot_')):
raise TypeError(
f"Command in method {base}.{elem!r} must not be staticmethod."
)
if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem))
commands[elem] = value
elif inspect.iscoroutinefunction(value):
try:
getattr(value, '__cog_listener__')
getattr(value, "__cog_listener__")
except AttributeError:
continue
else:
if elem.startswith(('cog_', 'bot_')):
if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem))
listeners[elem] = value
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
new_cls.__cog_commands__ = list(
commands.values()
) # this will be copied in Cog.__new__
listeners_as_list = []
for listener in listeners.values():
@@ -169,10 +187,12 @@ class CogMeta(type):
def qualified_name(cls) -> str:
return cls.__cog_name__
def _cog_special_method(func: FuncT) -> FuncT:
func.__cog_special_method__ = None
return func
class Cog(metaclass=CogMeta):
"""The base class that all cogs must inherit from.
@@ -183,6 +203,7 @@ class Cog(metaclass=CogMeta):
When inheriting from this class, the options shown in :class:`CogMeta`
are equally valid here.
"""
__cog_name__: ClassVar[str]
__cog_settings__: ClassVar[Dict[str, Any]]
__cog_commands__: ClassVar[List[Command]]
@@ -199,10 +220,7 @@ class Cog(metaclass=CogMeta):
# r.e type ignore, type-checker complains about overriding a ClassVar
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
lookup = {
cmd.qualified_name: cmd
for cmd in self.__cog_commands__
}
lookup = {cmd.qualified_name: cmd for cmd in self.__cog_commands__}
# Update the Command instances dynamically as well
for command in self.__cog_commands__:
@@ -255,6 +273,7 @@ class Cog(metaclass=CogMeta):
A command or group from the cog.
"""
from .core import GroupMixin
for command in self.__cog_commands__:
if command.parent is None:
yield command
@@ -269,12 +288,15 @@ class Cog(metaclass=CogMeta):
List[Tuple[:class:`str`, :ref:`coroutine <coroutine>`]]
The listeners defined in this cog.
"""
return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__]
return [
(name, getattr(self, method_name))
for name, method_name in self.__cog_listeners__
]
@classmethod
def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]:
"""Return None if the method is not overridden. Otherwise returns the overridden method."""
return getattr(method.__func__, '__cog_special_method__', method)
return getattr(method.__func__, "__cog_special_method__", method)
@classmethod
def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]:
@@ -296,14 +318,16 @@ class Cog(metaclass=CogMeta):
"""
if name is not MISSING and not isinstance(name, str):
raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.')
raise TypeError(
f"Cog.listener expected str but received {name.__class__.__name__!r} instead."
)
def decorator(func: FuncT) -> FuncT:
actual = func
if isinstance(actual, staticmethod):
actual = actual.__func__
if not inspect.iscoroutinefunction(actual):
raise TypeError('Listener function must be a coroutine function.')
raise TypeError("Listener function must be a coroutine function.")
actual.__cog_listener__ = True
to_assign = name or actual.__name__
try:
@@ -315,6 +339,7 @@ class Cog(metaclass=CogMeta):
# to pick it up but the metaclass unfurls the function and
# thus the assignments need to be on the actual function
return func
return decorator
def has_error_handler(self) -> bool:
@@ -322,7 +347,7 @@ class Cog(metaclass=CogMeta):
.. versionadded:: 1.7
"""
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
return not hasattr(self.cog_command_error.__func__, "__cog_special_method__")
@_cog_special_method
def cog_unload(self) -> None:

View File

@@ -49,21 +49,19 @@ if TYPE_CHECKING:
from .help import HelpCommand
from .view import StringView
__all__ = (
'Context',
)
__all__ = ("Context",)
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog")
T = TypeVar("T")
BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar("CogT", bound="Cog")
if TYPE_CHECKING:
P = ParamSpec('P')
P = ParamSpec("P")
else:
P = TypeVar('P')
P = TypeVar("P")
class Context(discord.abc.Messageable, Generic[BotT]):
@@ -122,7 +120,8 @@ class Context(discord.abc.Messageable, Generic[BotT]):
or invoked.
"""
def __init__(self,
def __init__(
self,
*,
message: Message,
bot: BotT,
@@ -153,7 +152,9 @@ class Context(discord.abc.Messageable, Generic[BotT]):
self.current_parameter: Optional[inspect.Parameter] = current_parameter
self._state: ConnectionState = self.message._state
async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
async def invoke(
self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs
) -> T:
r"""|coro|
Calls a command with the arguments given.
@@ -219,7 +220,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
cmd = self.command
view = self.view
if cmd is None:
raise ValueError('This context is not valid.')
raise ValueError("This context is not valid.")
# some state to revert to when we're done
index, previous = view.index, view.previous
@@ -230,10 +231,10 @@ class Context(discord.abc.Messageable, Generic[BotT]):
if restart:
to_call = cmd.root_parent or cmd
view.index = len(self.prefix or '')
view.index = len(self.prefix or "")
view.previous = 0
self.invoked_parents = []
self.invoked_with = view.get_word() # advance to get the root command
self.invoked_with = view.get_word() # advance to get the root command
else:
to_call = cmd
@@ -263,7 +264,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
.. versionadded:: 2.0
"""
if self.prefix is None:
return ''
return ""
user = self.me
# this breaks if the prefix mention is not the bot itself but I
@@ -271,7 +272,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
# for this common use case rather than waste performance for the
# odd one.
pattern = re.compile(r"<@!?%s>" % user.id)
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix)
return pattern.sub("@%s" % user.display_name.replace("\\", r"\\"), self.prefix)
@property
def cog(self) -> Optional[Cog]:
@@ -381,7 +382,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
await cmd.prepare_help_command(self, entity.qualified_name)
try:
if hasattr(entity, '__cog_commands__'):
if hasattr(entity, "__cog_commands__"):
injected = wrap_callback(cmd.send_cog_help)
return await injected(entity)
elif isinstance(entity, Group):

View File

@@ -52,32 +52,32 @@ if TYPE_CHECKING:
__all__ = (
'Converter',
'ObjectConverter',
'MemberConverter',
'UserConverter',
'MessageConverter',
'PartialMessageConverter',
'TextChannelConverter',
'InviteConverter',
'GuildConverter',
'RoleConverter',
'GameConverter',
'ColourConverter',
'ColorConverter',
'VoiceChannelConverter',
'StageChannelConverter',
'EmojiConverter',
'PartialEmojiConverter',
'CategoryChannelConverter',
'IDConverter',
'StoreChannelConverter',
'ThreadConverter',
'GuildChannelConverter',
'GuildStickerConverter',
'clean_content',
'Greedy',
'run_converters',
"Converter",
"ObjectConverter",
"MemberConverter",
"UserConverter",
"MessageConverter",
"PartialMessageConverter",
"TextChannelConverter",
"InviteConverter",
"GuildConverter",
"RoleConverter",
"GameConverter",
"ColourConverter",
"ColorConverter",
"VoiceChannelConverter",
"StageChannelConverter",
"EmojiConverter",
"PartialEmojiConverter",
"CategoryChannelConverter",
"IDConverter",
"StoreChannelConverter",
"ThreadConverter",
"GuildChannelConverter",
"GuildStickerConverter",
"clean_content",
"Greedy",
"run_converters",
)
@@ -91,10 +91,10 @@ def _get_from_guilds(bot, getter, argument):
_utils_get = discord.utils.get
T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
CT = TypeVar('CT', bound=discord.abc.GuildChannel)
TT = TypeVar('TT', bound=discord.Thread)
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
CT = TypeVar("CT", bound=discord.abc.GuildChannel)
TT = TypeVar("TT", bound=discord.Thread)
@runtime_checkable
@@ -132,10 +132,10 @@ class Converter(Protocol[T_co]):
:exc:`.BadArgument`
The converter failed to convert the argument.
"""
raise NotImplementedError('Derived classes need to implement this.')
raise NotImplementedError("Derived classes need to implement this.")
_ID_REGEX = re.compile(r'([0-9]{15,20})$')
_ID_REGEX = re.compile(r"([0-9]{15,20})$")
class IDConverter(Converter[T_co]):
@@ -158,7 +158,9 @@ class ObjectConverter(IDConverter[discord.Object]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Object:
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(
r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument
)
if match is None:
raise ObjectNotFound(argument)
@@ -192,13 +194,17 @@ class MemberConverter(IDConverter[discord.Member]):
async def query_member_named(self, guild, argument):
cache = guild._state.member_cache_flags.joined
if len(argument) > 5 and argument[-5] == '#':
username, _, discriminator = argument.rpartition('#')
if len(argument) > 5 and argument[-5] == "#":
username, _, discriminator = argument.rpartition("#")
members = await guild.query_members(username, limit=100, cache=cache)
return discord.utils.get(members, name=username, discriminator=discriminator)
return discord.utils.get(
members, name=username, discriminator=discriminator
)
else:
members = await guild.query_members(argument, limit=100, cache=cache)
return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members)
return discord.utils.find(
lambda m: m.name == argument or m.nick == argument, members
)
async def query_member_by_id(self, bot, guild, user_id):
ws = bot._get_websocket(shard_id=guild.shard_id)
@@ -223,7 +229,9 @@ class MemberConverter(IDConverter[discord.Member]):
async def convert(self, ctx: Context, argument: str) -> discord.Member:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(
r"<@!?([0-9]{15,20})>$", argument
)
guild = ctx.guild
result = None
user_id = None
@@ -232,13 +240,15 @@ class MemberConverter(IDConverter[discord.Member]):
if guild:
result = guild.get_member_named(argument)
else:
result = _get_from_guilds(bot, 'get_member_named', argument)
result = _get_from_guilds(bot, "get_member_named", argument)
else:
user_id = int(match.group(1))
if guild:
result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id)
result = guild.get_member(user_id) or _utils_get(
ctx.message.mentions, id=user_id
)
else:
result = _get_from_guilds(bot, 'get_member', user_id)
result = _get_from_guilds(bot, "get_member", user_id)
if result is None:
if guild is None:
@@ -276,13 +286,17 @@ class UserConverter(IDConverter[discord.User]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.User:
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(
r"<@!?([0-9]{15,20})>$", argument
)
result = None
state = ctx._state
if match is not None:
user_id = int(match.group(1))
result = ctx.bot.get_user(user_id) or _utils_get(ctx.message.mentions, id=user_id)
result = ctx.bot.get_user(user_id) or _utils_get(
ctx.message.mentions, id=user_id
)
if result is None:
try:
result = await ctx.bot.fetch_user(user_id)
@@ -294,12 +308,12 @@ class UserConverter(IDConverter[discord.User]):
arg = argument
# Remove the '@' character if this is the first character from the argument
if arg[0] == '@':
if arg[0] == "@":
# Remove first character
arg = arg[1:]
# check for discriminator if it exists,
if len(arg) > 5 and arg[-5] == '#':
if len(arg) > 5 and arg[-5] == "#":
discrim = arg[-4:]
name = arg[:-5]
predicate = lambda u: u.name == name and u.discriminator == discrim
@@ -330,29 +344,33 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod
def _get_id_matches(ctx, argument):
id_regex = re.compile(r'(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$')
id_regex = re.compile(
r"(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$"
)
link_regex = re.compile(
r'https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/'
r'(?P<guild_id>[0-9]{15,20}|@me)'
r'/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$'
r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/"
r"(?P<guild_id>[0-9]{15,20}|@me)"
r"/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$"
)
match = id_regex.match(argument) or link_regex.match(argument)
if not match:
raise MessageNotFound(argument)
data = match.groupdict()
channel_id = discord.utils._get_as_snowflake(data, 'channel_id')
message_id = int(data['message_id'])
guild_id = data.get('guild_id')
channel_id = discord.utils._get_as_snowflake(data, "channel_id")
message_id = int(data["message_id"])
guild_id = data.get("guild_id")
if guild_id is None:
guild_id = ctx.guild and ctx.guild.id
elif guild_id == '@me':
elif guild_id == "@me":
guild_id = None
else:
guild_id = int(guild_id)
return guild_id, message_id, channel_id
@staticmethod
def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]:
def _resolve_channel(
ctx, guild_id, channel_id
) -> Optional[PartialMessageableChannel]:
if guild_id is not None:
guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None:
@@ -386,7 +404,9 @@ class MessageConverter(IDConverter[discord.Message]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Message:
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument)
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(
ctx, argument
)
message = ctx.bot._connection._get_message(message_id)
if message:
return message
@@ -417,13 +437,19 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel:
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel)
return self._resolve_channel(
ctx, argument, "channels", discord.abc.GuildChannel
)
@staticmethod
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT:
def _resolve_channel(
ctx: Context, argument: str, attribute: str, type: Type[CT]
) -> CT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
match = IDConverter._get_id_match(argument) or re.match(
r"<#([0-9]{15,20})>$", argument
)
result = None
guild = ctx.guild
@@ -443,7 +469,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
if guild:
result = guild.get_channel(channel_id)
else:
result = _get_from_guilds(bot, 'get_channel', channel_id)
result = _get_from_guilds(bot, "get_channel", channel_id)
if not isinstance(result, type):
raise ChannelNotFound(argument)
@@ -451,10 +477,14 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
return result
@staticmethod
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT:
def _resolve_thread(
ctx: Context, argument: str, attribute: str, type: Type[TT]
) -> TT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
match = IDConverter._get_id_match(argument) or re.match(
r"<#([0-9]{15,20})>$", argument
)
result = None
guild = ctx.guild
@@ -491,7 +521,9 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel)
return GuildChannelConverter._resolve_channel(
ctx, argument, "text_channels", discord.TextChannel
)
class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
@@ -511,7 +543,9 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel)
return GuildChannelConverter._resolve_channel(
ctx, argument, "voice_channels", discord.VoiceChannel
)
class StageChannelConverter(IDConverter[discord.StageChannel]):
@@ -530,7 +564,9 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel)
return GuildChannelConverter._resolve_channel(
ctx, argument, "stage_channels", discord.StageChannel
)
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
@@ -550,7 +586,9 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel)
return GuildChannelConverter._resolve_channel(
ctx, argument, "categories", discord.CategoryChannel
)
class StoreChannelConverter(IDConverter[discord.StoreChannel]):
@@ -569,7 +607,9 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel)
return GuildChannelConverter._resolve_channel(
ctx, argument, "channels", discord.StoreChannel
)
class ThreadConverter(IDConverter[discord.Thread]):
@@ -587,7 +627,9 @@ class ThreadConverter(IDConverter[discord.Thread]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Thread:
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread)
return GuildChannelConverter._resolve_thread(
ctx, argument, "threads", discord.Thread
)
class ColourConverter(Converter[discord.Colour]):
@@ -616,10 +658,12 @@ class ColourConverter(Converter[discord.Colour]):
Added support for ``rgb`` function and 3-digit hex shortcuts
"""
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
RGB_REGEX = re.compile(
r"rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)"
)
def parse_hex_number(self, argument):
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument
try:
value = int(arg, base=16)
if not (0 <= value <= 0xFFFFFF):
@@ -630,7 +674,7 @@ class ColourConverter(Converter[discord.Colour]):
return discord.Color(value=value)
def parse_rgb_number(self, argument, number):
if number[-1] == '%':
if number[-1] == "%":
value = int(number[:-1])
if not (0 <= value <= 100):
raise BadColourArgument(argument)
@@ -646,29 +690,29 @@ class ColourConverter(Converter[discord.Colour]):
if match is None:
raise BadColourArgument(argument)
red = self.parse_rgb_number(argument, match.group('r'))
green = self.parse_rgb_number(argument, match.group('g'))
blue = self.parse_rgb_number(argument, match.group('b'))
red = self.parse_rgb_number(argument, match.group("r"))
green = self.parse_rgb_number(argument, match.group("g"))
blue = self.parse_rgb_number(argument, match.group("b"))
return discord.Color.from_rgb(red, green, blue)
async def convert(self, ctx: Context, argument: str) -> discord.Colour:
if argument[0] == '#':
if argument[0] == "#":
return self.parse_hex_number(argument[1:])
if argument[0:2] == '0x':
if argument[0:2] == "0x":
rest = argument[2:]
# Legacy backwards compatible syntax
if rest.startswith('#'):
if rest.startswith("#"):
return self.parse_hex_number(rest[1:])
return self.parse_hex_number(rest)
arg = argument.lower()
if arg[0:3] == 'rgb':
if arg[0:3] == "rgb":
return self.parse_rgb(arg)
arg = arg.replace(' ', '_')
arg = arg.replace(" ", "_")
method = getattr(discord.Colour, arg, None)
if arg.startswith('from_') or method is None or not inspect.ismethod(method):
if arg.startswith("from_") or method is None or not inspect.ismethod(method):
raise BadColourArgument(arg)
return method()
@@ -697,7 +741,9 @@ class RoleConverter(IDConverter[discord.Role]):
if not guild:
raise NoPrivateMessage()
match = self._get_id_match(argument) or re.match(r'<@&([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(
r"<@&([0-9]{15,20})>$", argument
)
if match:
result = guild.get_role(int(match.group(1)))
else:
@@ -776,7 +822,9 @@ class EmojiConverter(IDConverter[discord.Emoji]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Emoji:
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(
r"<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$", argument
)
result = None
bot = ctx.bot
guild = ctx.guild
@@ -810,7 +858,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji:
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument)
match = re.match(r"<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$", argument)
if match:
emoji_animated = bool(match.group(1))
@@ -818,7 +866,10 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
emoji_id = int(match.group(3))
return discord.PartialEmoji.with_state(
ctx.bot._connection, animated=emoji_animated, name=emoji_name, id=emoji_id
ctx.bot._connection,
animated=emoji_animated,
name=emoji_name,
id=emoji_id,
)
raise PartialEmojiConversionFailure(argument)
@@ -903,37 +954,41 @@ class clean_content(Converter[str]):
def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id)
return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user'
return (
f"@{m.display_name if self.use_nicknames else m.name}"
if m
else "@deleted-user"
)
def resolve_role(id: int) -> str:
r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id)
return f'@{r.name}' if r else '@deleted-role'
return f"@{r.name}" if r else "@deleted-role"
else:
def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id)
return f'@{m.name}' if m else '@deleted-user'
return f"@{m.name}" if m else "@deleted-user"
def resolve_role(id: int) -> str:
return '@deleted-role'
return "@deleted-role"
if self.fix_channel_mentions and ctx.guild:
def resolve_channel(id: int) -> str:
c = ctx.guild.get_channel(id)
return f'#{c.name}' if c else '#deleted-channel'
return f"#{c.name}" if c else "#deleted-channel"
else:
def resolve_channel(id: int) -> str:
return f'<#{id}>'
return f"<#{id}>"
transforms = {
'@': resolve_member,
'@!': resolve_member,
'#': resolve_channel,
'@&': resolve_role,
"@": resolve_member,
"@!": resolve_member,
"#": resolve_channel,
"@&": resolve_role,
}
def repl(match: re.Match) -> str:
@@ -942,7 +997,7 @@ class clean_content(Converter[str]):
transformed = transforms[type](id)
return transformed
result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument)
result = re.sub(r"<(@[!&]?|#)([0-9]{15,20})>", repl, argument)
if self.escape_markdown:
result = discord.utils.escape_markdown(result)
elif self.remove_markdown:
@@ -974,42 +1029,46 @@ class Greedy(List[T]):
For more information, check :ref:`ext_commands_special_converters`.
"""
__slots__ = ('converter',)
__slots__ = ("converter",)
def __init__(self, *, converter: T):
self.converter = converter
def __repr__(self):
converter = getattr(self.converter, '__name__', repr(self.converter))
return f'Greedy[{converter}]'
converter = getattr(self.converter, "__name__", repr(self.converter))
return f"Greedy[{converter}]"
def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]:
if not isinstance(params, tuple):
params = (params,)
if len(params) != 1:
raise TypeError('Greedy[...] only takes a single argument')
raise TypeError("Greedy[...] only takes a single argument")
converter = params[0]
origin = getattr(converter, '__origin__', None)
args = getattr(converter, '__args__', ())
origin = getattr(converter, "__origin__", None)
args = getattr(converter, "__args__", ())
if not (callable(converter) or isinstance(converter, Converter) or origin is not None):
raise TypeError('Greedy[...] expects a type or a Converter instance.')
if not (
callable(converter)
or isinstance(converter, Converter)
or origin is not None
):
raise TypeError("Greedy[...] expects a type or a Converter instance.")
if converter in (str, type(None)) or origin is Greedy:
raise TypeError(f'Greedy[{converter.__name__}] is invalid.')
raise TypeError(f"Greedy[{converter.__name__}] is invalid.")
if origin is Union and type(None) in args:
raise TypeError(f'Greedy[{converter!r}] is invalid.')
raise TypeError(f"Greedy[{converter!r}] is invalid.")
return cls(converter=converter)
def _convert_to_bool(argument: str) -> bool:
lowered = argument.lower()
if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
if lowered in ("yes", "y", "true", "t", "1", "enable", "on"):
return True
elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'):
elif lowered in ("no", "n", "false", "f", "0", "disable", "off"):
return False
else:
raise BadBoolArgument(lowered)
@@ -1056,7 +1115,9 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = {
}
async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter):
async def _actual_conversion(
ctx: Context, converter, argument: str, param: inspect.Parameter
):
if converter is bool:
return _convert_to_bool(argument)
@@ -1065,7 +1126,9 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
except AttributeError:
pass
else:
if module is not None and (module.startswith('discord.') and not module.endswith('converter')):
if module is not None and (
module.startswith("discord.") and not module.endswith("converter")
):
converter = CONVERTER_MAPPING.get(converter, converter)
try:
@@ -1091,10 +1154,14 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
except AttributeError:
name = converter.__class__.__name__
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
raise BadArgument(
f'Converting to "{name}" failed for parameter "{param.name}".'
) from exc
async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter):
async def run_converters(
ctx: Context, converter, argument: str, param: inspect.Parameter
):
"""|coro|
Runs converters for a given converter, argument, and parameter.
@@ -1124,7 +1191,7 @@ async def run_converters(ctx: Context, converter, argument: str, param: inspect.
Any
The resulting conversion.
"""
origin = getattr(converter, '__origin__', None)
origin = getattr(converter, "__origin__", None)
if origin is Union:
errors = []

View File

@@ -38,24 +38,25 @@ if TYPE_CHECKING:
from ...message import Message
__all__ = (
'BucketType',
'Cooldown',
'CooldownMapping',
'DynamicCooldownMapping',
'MaxConcurrency',
"BucketType",
"Cooldown",
"CooldownMapping",
"DynamicCooldownMapping",
"MaxConcurrency",
)
C = TypeVar('C', bound='CooldownMapping')
MC = TypeVar('MC', bound='MaxConcurrency')
C = TypeVar("C", bound="CooldownMapping")
MC = TypeVar("MC", bound="MaxConcurrency")
class BucketType(Enum):
default = 0
user = 1
guild = 2
channel = 3
member = 4
default = 0
user = 1
guild = 2
channel = 3
member = 4
category = 5
role = 6
role = 6
def get_key(self, msg: Message) -> Any:
if self is BucketType.user:
@@ -90,7 +91,7 @@ class Cooldown:
The length of the cooldown period in seconds.
"""
__slots__ = ('rate', 'per', '_window', '_tokens', '_last')
__slots__ = ("rate", "per", "_window", "_tokens", "_last")
def __init__(self, rate: float, per: float) -> None:
self.rate: int = int(rate)
@@ -190,7 +191,8 @@ class Cooldown:
return Cooldown(self.rate, self.per)
def __repr__(self) -> str:
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
return f"<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>"
class CooldownMapping:
def __init__(
@@ -199,7 +201,7 @@ class CooldownMapping:
type: Callable[[Message], Any],
) -> None:
if not callable(type):
raise TypeError('Cooldown type must be a BucketType or callable')
raise TypeError("Cooldown type must be a BucketType or callable")
self._cache: Dict[Any, Cooldown] = {}
self._cooldown: Optional[Cooldown] = original
@@ -252,16 +254,16 @@ class CooldownMapping:
return bucket
def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]:
def update_rate_limit(
self, message: Message, current: Optional[float] = None
) -> Optional[float]:
bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current)
class DynamicCooldownMapping(CooldownMapping):
class DynamicCooldownMapping(CooldownMapping):
def __init__(
self,
factory: Callable[[Message], Cooldown],
type: Callable[[Message], Any]
self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]
) -> None:
super().__init__(None, type)
self._factory: Callable[[Message], Cooldown] = factory
@@ -278,6 +280,7 @@ class DynamicCooldownMapping(CooldownMapping):
def create_bucket(self, message: Message) -> Cooldown:
return self._factory(message)
class _Semaphore:
"""This class is a version of a semaphore.
@@ -291,7 +294,7 @@ class _Semaphore:
overkill for what is basically a counter.
"""
__slots__ = ('value', 'loop', '_waiters')
__slots__ = ("value", "loop", "_waiters")
def __init__(self, number: int) -> None:
self.value: int = number
@@ -299,7 +302,7 @@ class _Semaphore:
self._waiters: Deque[asyncio.Future] = deque()
def __repr__(self) -> str:
return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>'
return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>"
def locked(self) -> bool:
return self.value == 0
@@ -337,8 +340,9 @@ class _Semaphore:
self.value += 1
self.wake_up()
class MaxConcurrency:
__slots__ = ('number', 'per', 'wait', '_mapping')
__slots__ = ("number", "per", "wait", "_mapping")
def __init__(self, number: int, *, per: BucketType, wait: bool) -> None:
self._mapping: Dict[Any, _Semaphore] = {}
@@ -347,16 +351,20 @@ class MaxConcurrency:
self.wait: bool = wait
if number <= 0:
raise ValueError('max_concurrency \'number\' cannot be less than 1')
raise ValueError("max_concurrency 'number' cannot be less than 1")
if not isinstance(per, BucketType):
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}')
raise TypeError(
f"max_concurrency 'per' must be of type BucketType not {type(per)!r}"
)
def copy(self: MC) -> MC:
return self.__class__(self.number, per=self.per, wait=self.wait)
def __repr__(self) -> str:
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>'
return (
f"<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>"
)
def get_key(self, message: Message) -> Any:
return self.per.get_key(message)

File diff suppressed because it is too large Load Diff

View File

@@ -41,65 +41,66 @@ if TYPE_CHECKING:
__all__ = (
'CommandError',
'MissingRequiredArgument',
'BadArgument',
'PrivateMessageOnly',
'NoPrivateMessage',
'CheckFailure',
'CheckAnyFailure',
'CommandNotFound',
'DisabledCommand',
'CommandInvokeError',
'TooManyArguments',
'UserInputError',
'CommandOnCooldown',
'MaxConcurrencyReached',
'NotOwner',
'MessageNotFound',
'ObjectNotFound',
'MemberNotFound',
'GuildNotFound',
'UserNotFound',
'ChannelNotFound',
'ThreadNotFound',
'ChannelNotReadable',
'BadColourArgument',
'BadColorArgument',
'RoleNotFound',
'BadInviteArgument',
'EmojiNotFound',
'GuildStickerNotFound',
'PartialEmojiConversionFailure',
'BadBoolArgument',
'MissingRole',
'BotMissingRole',
'MissingAnyRole',
'BotMissingAnyRole',
'MissingPermissions',
'BotMissingPermissions',
'NSFWChannelRequired',
'ConversionError',
'BadUnionArgument',
'BadLiteralArgument',
'ArgumentParsingError',
'UnexpectedQuoteError',
'InvalidEndOfQuotedStringError',
'ExpectedClosingQuoteError',
'ExtensionError',
'ExtensionAlreadyLoaded',
'ExtensionNotLoaded',
'NoEntryPointError',
'ExtensionFailed',
'ExtensionNotFound',
'CommandRegistrationError',
'FlagError',
'BadFlagArgument',
'MissingFlagArgument',
'TooManyFlags',
'MissingRequiredFlag',
"CommandError",
"MissingRequiredArgument",
"BadArgument",
"PrivateMessageOnly",
"NoPrivateMessage",
"CheckFailure",
"CheckAnyFailure",
"CommandNotFound",
"DisabledCommand",
"CommandInvokeError",
"TooManyArguments",
"UserInputError",
"CommandOnCooldown",
"MaxConcurrencyReached",
"NotOwner",
"MessageNotFound",
"ObjectNotFound",
"MemberNotFound",
"GuildNotFound",
"UserNotFound",
"ChannelNotFound",
"ThreadNotFound",
"ChannelNotReadable",
"BadColourArgument",
"BadColorArgument",
"RoleNotFound",
"BadInviteArgument",
"EmojiNotFound",
"GuildStickerNotFound",
"PartialEmojiConversionFailure",
"BadBoolArgument",
"MissingRole",
"BotMissingRole",
"MissingAnyRole",
"BotMissingAnyRole",
"MissingPermissions",
"BotMissingPermissions",
"NSFWChannelRequired",
"ConversionError",
"BadUnionArgument",
"BadLiteralArgument",
"ArgumentParsingError",
"UnexpectedQuoteError",
"InvalidEndOfQuotedStringError",
"ExpectedClosingQuoteError",
"ExtensionError",
"ExtensionAlreadyLoaded",
"ExtensionNotLoaded",
"NoEntryPointError",
"ExtensionFailed",
"ExtensionNotFound",
"CommandRegistrationError",
"FlagError",
"BadFlagArgument",
"MissingFlagArgument",
"TooManyFlags",
"MissingRequiredFlag",
)
class CommandError(DiscordException):
r"""The base exception type for all command related errors.
@@ -109,14 +110,18 @@ class CommandError(DiscordException):
in a special way as they are caught and passed into a special event
from :class:`.Bot`\, :func:`.on_command_error`.
"""
def __init__(self, message: Optional[str] = None, *args: Any) -> None:
if message is not None:
# clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
m = message.replace("@everyone", "@\u200beveryone").replace(
"@here", "@\u200bhere"
)
super().__init__(m, *args)
else:
super().__init__(*args)
class ConversionError(CommandError):
"""Exception raised when a Converter class raises non-CommandError.
@@ -130,18 +135,22 @@ class ConversionError(CommandError):
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, converter: Converter, original: Exception) -> None:
self.converter: Converter = converter
self.original: Exception = original
class UserInputError(CommandError):
"""The base exception type for errors that involve errors
regarding user input.
This inherits from :exc:`CommandError`.
"""
pass
class CommandNotFound(CommandError):
"""Exception raised when a command is attempted to be invoked
but no command under that name is found.
@@ -151,8 +160,10 @@ class CommandNotFound(CommandError):
This inherits from :exc:`CommandError`.
"""
pass
class MissingRequiredArgument(UserInputError):
"""Exception raised when parsing a command and a parameter
that is required is not encountered.
@@ -164,9 +175,11 @@ class MissingRequiredArgument(UserInputError):
param: :class:`inspect.Parameter`
The argument that is missing.
"""
def __init__(self, param: Parameter) -> None:
self.param: Parameter = param
super().__init__(f'{param.name} is a required argument that is missing.')
super().__init__(f"{param.name} is a required argument that is missing.")
class TooManyArguments(UserInputError):
"""Exception raised when the command was passed too many arguments and its
@@ -174,23 +187,29 @@ class TooManyArguments(UserInputError):
This inherits from :exc:`UserInputError`
"""
pass
class BadArgument(UserInputError):
"""Exception raised when a parsing or conversion failure is encountered
on an argument to pass into a command.
This inherits from :exc:`UserInputError`
"""
pass
class CheckFailure(CommandError):
"""Exception raised when the predicates in :attr:`.Command.checks` have failed.
This inherits from :exc:`CommandError`
"""
pass
class CheckAnyFailure(CheckFailure):
"""Exception raised when all predicates in :func:`check_any` fail.
@@ -206,10 +225,13 @@ class CheckAnyFailure(CheckFailure):
A list of check predicates that failed.
"""
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None:
def __init__(
self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]
) -> None:
self.checks: List[CheckFailure] = checks
self.errors: List[Callable[[Context], bool]] = errors
super().__init__('You do not have permission to run this command.')
super().__init__("You do not have permission to run this command.")
class PrivateMessageOnly(CheckFailure):
"""Exception raised when an operation does not work outside of private
@@ -217,8 +239,12 @@ class PrivateMessageOnly(CheckFailure):
This inherits from :exc:`CheckFailure`
"""
def __init__(self, message: Optional[str] = None) -> None:
super().__init__(message or 'This command can only be used in private messages.')
super().__init__(
message or "This command can only be used in private messages."
)
class NoPrivateMessage(CheckFailure):
"""Exception raised when an operation does not work in private message
@@ -228,15 +254,18 @@ class NoPrivateMessage(CheckFailure):
"""
def __init__(self, message: Optional[str] = None) -> None:
super().__init__(message or 'This command cannot be used in private messages.')
super().__init__(message or "This command cannot be used in private messages.")
class NotOwner(CheckFailure):
"""Exception raised when the message author is not the owner of the bot.
This inherits from :exc:`CheckFailure`
"""
pass
class ObjectNotFound(BadArgument):
"""Exception raised when the argument provided did not match the format
of an ID or a mention.
@@ -250,9 +279,11 @@ class ObjectNotFound(BadArgument):
argument: :class:`str`
The argument supplied by the caller that was not matched
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'{argument!r} does not follow a valid ID or mention format.')
super().__init__(f"{argument!r} does not follow a valid ID or mention format.")
class MemberNotFound(BadArgument):
"""Exception raised when the member provided was not found in the bot's
@@ -267,10 +298,12 @@ class MemberNotFound(BadArgument):
argument: :class:`str`
The member supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Member "{argument}" not found.')
class GuildNotFound(BadArgument):
"""Exception raised when the guild provided was not found in the bot's cache.
@@ -283,10 +316,12 @@ class GuildNotFound(BadArgument):
argument: :class:`str`
The guild supplied by the called that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Guild "{argument}" not found.')
class UserNotFound(BadArgument):
"""Exception raised when the user provided was not found in the bot's
cache.
@@ -300,10 +335,12 @@ class UserNotFound(BadArgument):
argument: :class:`str`
The user supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'User "{argument}" not found.')
class MessageNotFound(BadArgument):
"""Exception raised when the message provided was not found in the channel.
@@ -316,10 +353,12 @@ class MessageNotFound(BadArgument):
argument: :class:`str`
The message supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Message "{argument}" not found.')
class ChannelNotReadable(BadArgument):
"""Exception raised when the bot does not have permission to read messages
in the channel.
@@ -333,10 +372,12 @@ class ChannelNotReadable(BadArgument):
argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel supplied by the caller that was not readable
"""
def __init__(self, argument: Union[GuildChannel, Thread]) -> None:
self.argument: Union[GuildChannel, Thread] = argument
super().__init__(f"Can't read messages in {argument.mention}.")
class ChannelNotFound(BadArgument):
"""Exception raised when the bot can not find the channel.
@@ -349,10 +390,12 @@ class ChannelNotFound(BadArgument):
argument: :class:`str`
The channel supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Channel "{argument}" not found.')
class ThreadNotFound(BadArgument):
"""Exception raised when the bot can not find the thread.
@@ -365,10 +408,12 @@ class ThreadNotFound(BadArgument):
argument: :class:`str`
The thread supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Thread "{argument}" not found.')
class BadColourArgument(BadArgument):
"""Exception raised when the colour is not valid.
@@ -381,12 +426,15 @@ class BadColourArgument(BadArgument):
argument: :class:`str`
The colour supplied by the caller that was not valid
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Colour "{argument}" is invalid.')
BadColorArgument = BadColourArgument
class RoleNotFound(BadArgument):
"""Exception raised when the bot can not find the role.
@@ -399,10 +447,12 @@ class RoleNotFound(BadArgument):
argument: :class:`str`
The role supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Role "{argument}" not found.')
class BadInviteArgument(BadArgument):
"""Exception raised when the invite is invalid or expired.
@@ -410,10 +460,12 @@ class BadInviteArgument(BadArgument):
.. versionadded:: 1.5
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Invite "{argument}" is invalid or expired.')
class EmojiNotFound(BadArgument):
"""Exception raised when the bot can not find the emoji.
@@ -426,10 +478,12 @@ class EmojiNotFound(BadArgument):
argument: :class:`str`
The emoji supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Emoji "{argument}" not found.')
class PartialEmojiConversionFailure(BadArgument):
"""Exception raised when the emoji provided does not match the correct
format.
@@ -443,10 +497,12 @@ class PartialEmojiConversionFailure(BadArgument):
argument: :class:`str`
The emoji supplied by the caller that did not match the regex
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.')
class GuildStickerNotFound(BadArgument):
"""Exception raised when the bot can not find the sticker.
@@ -459,10 +515,12 @@ class GuildStickerNotFound(BadArgument):
argument: :class:`str`
The sticker supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Sticker "{argument}" not found.')
class BadBoolArgument(BadArgument):
"""Exception raised when a boolean argument was not convertable.
@@ -475,17 +533,21 @@ class BadBoolArgument(BadArgument):
argument: :class:`str`
The boolean argument supplied by the caller that is not in the predefined list
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'{argument} is not a recognised boolean option')
super().__init__(f"{argument} is not a recognised boolean option")
class DisabledCommand(CommandError):
"""Exception raised when the command being invoked is disabled.
This inherits from :exc:`CommandError`
"""
pass
class CommandInvokeError(CommandError):
"""Exception raised when the command being invoked raised an exception.
@@ -497,9 +559,11 @@ class CommandInvokeError(CommandError):
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, e: Exception) -> None:
self.original: Exception = e
super().__init__(f'Command raised an exception: {e.__class__.__name__}: {e}')
super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}")
class CommandOnCooldown(CommandError):
"""Exception raised when the command being invoked is on cooldown.
@@ -516,11 +580,15 @@ class CommandOnCooldown(CommandError):
retry_after: :class:`float`
The amount of seconds to wait before you can retry again.
"""
def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None:
def __init__(
self, cooldown: Cooldown, retry_after: float, type: BucketType
) -> None:
self.cooldown: Cooldown = cooldown
self.retry_after: float = retry_after
self.type: BucketType = type
super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s')
super().__init__(f"You are on cooldown. Try again in {retry_after:.2f}s")
class MaxConcurrencyReached(CommandError):
"""Exception raised when the command being invoked has reached its maximum concurrency.
@@ -539,10 +607,13 @@ class MaxConcurrencyReached(CommandError):
self.number: int = number
self.per: BucketType = per
name = per.name
suffix = 'per %s' % name if per.name != 'default' else 'globally'
plural = '%s times %s' if number > 1 else '%s time %s'
suffix = "per %s" % name if per.name != "default" else "globally"
plural = "%s times %s" if number > 1 else "%s time %s"
fmt = plural % (number, suffix)
super().__init__(f'Too many people are using this command. It can only be used {fmt} concurrently.')
super().__init__(
f"Too many people are using this command. It can only be used {fmt} concurrently."
)
class MissingRole(CheckFailure):
"""Exception raised when the command invoker lacks a role to run a command.
@@ -557,11 +628,13 @@ class MissingRole(CheckFailure):
The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`.
"""
def __init__(self, missing_role: Snowflake) -> None:
self.missing_role: Snowflake = missing_role
message = f'Role {missing_role!r} is required to run this command.'
message = f"Role {missing_role!r} is required to run this command."
super().__init__(message)
class BotMissingRole(CheckFailure):
"""Exception raised when the bot's member lacks a role to run a command.
@@ -575,11 +648,13 @@ class BotMissingRole(CheckFailure):
The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`.
"""
def __init__(self, missing_role: Snowflake) -> None:
self.missing_role: Snowflake = missing_role
message = f'Bot requires the role {missing_role!r} to run this command'
message = f"Bot requires the role {missing_role!r} to run this command"
super().__init__(message)
class MissingAnyRole(CheckFailure):
"""Exception raised when the command invoker lacks any of
the roles specified to run a command.
@@ -594,15 +669,16 @@ class MissingAnyRole(CheckFailure):
The roles that the invoker is missing.
These are the parameters passed to :func:`~.commands.has_any_role`.
"""
def __init__(self, missing_roles: SnowflakeList) -> None:
self.missing_roles: SnowflakeList = missing_roles
missing = [f"'{role}'" for role in missing_roles]
if len(missing) > 2:
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' or '.join(missing)
fmt = " or ".join(missing)
message = f"You are missing at least one of the required roles: {fmt}"
super().__init__(message)
@@ -623,19 +699,21 @@ class BotMissingAnyRole(CheckFailure):
These are the parameters passed to :func:`~.commands.has_any_role`.
"""
def __init__(self, missing_roles: SnowflakeList) -> None:
self.missing_roles: SnowflakeList = missing_roles
missing = [f"'{role}'" for role in missing_roles]
if len(missing) > 2:
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' or '.join(missing)
fmt = " or ".join(missing)
message = f"Bot is missing at least one of the required roles: {fmt}"
super().__init__(message)
class NSFWChannelRequired(CheckFailure):
"""Exception raised when a channel does not have the required NSFW setting.
@@ -648,9 +726,13 @@ class NSFWChannelRequired(CheckFailure):
channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel that does not have NSFW enabled.
"""
def __init__(self, channel: Union[GuildChannel, Thread]) -> None:
self.channel: Union[GuildChannel, Thread] = channel
super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.")
super().__init__(
f"Channel '{channel}' needs to be NSFW for this command to work."
)
class MissingPermissions(CheckFailure):
"""Exception raised when the command invoker lacks permissions to run a
@@ -663,18 +745,23 @@ class MissingPermissions(CheckFailure):
missing_permissions: List[:class:`str`]
The required permissions that are missing.
"""
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
missing = [
perm.replace("_", " ").replace("guild", "server").title()
for perm in missing_permissions
]
if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' and '.join(missing)
message = f'You are missing {fmt} permission(s) to run this command.'
fmt = " and ".join(missing)
message = f"You are missing {fmt} permission(s) to run this command."
super().__init__(message, *args)
class BotMissingPermissions(CheckFailure):
"""Exception raised when the bot's member lacks permissions to run a
command.
@@ -686,18 +773,23 @@ class BotMissingPermissions(CheckFailure):
missing_permissions: List[:class:`str`]
The required permissions that are missing.
"""
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
missing = [
perm.replace("_", " ").replace("guild", "server").title()
for perm in missing_permissions
]
if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' and '.join(missing)
message = f'Bot requires {fmt} permission(s) to run this command.'
fmt = " and ".join(missing)
message = f"Bot requires {fmt} permission(s) to run this command."
super().__init__(message, *args)
class BadUnionArgument(UserInputError):
"""Exception raised when a :data:`typing.Union` converter fails for all
its associated types.
@@ -713,7 +805,10 @@ class BadUnionArgument(UserInputError):
errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None:
def __init__(
self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]
) -> None:
self.param: Parameter = param
self.converters: Tuple[Type, ...] = converters
self.errors: List[CommandError] = errors
@@ -722,18 +817,19 @@ class BadUnionArgument(UserInputError):
try:
return x.__name__
except AttributeError:
if hasattr(x, '__origin__'):
if hasattr(x, "__origin__"):
return repr(x)
return x.__class__.__name__
to_string = [_get_name(x) for x in converters]
if len(to_string) > 2:
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1])
else:
fmt = ' or '.join(to_string)
fmt = " or ".join(to_string)
super().__init__(f'Could not convert "{param.name}" into {fmt}.')
class BadLiteralArgument(UserInputError):
"""Exception raised when a :data:`typing.Literal` converter fails for all
its associated values.
@@ -751,19 +847,23 @@ class BadLiteralArgument(UserInputError):
errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None:
def __init__(
self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]
) -> None:
self.param: Parameter = param
self.literals: Tuple[Any, ...] = literals
self.errors: List[CommandError] = errors
to_string = [repr(l) for l in literals]
if len(to_string) > 2:
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1])
else:
fmt = ' or '.join(to_string)
fmt = " or ".join(to_string)
super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.')
class ArgumentParsingError(UserInputError):
"""An exception raised when the parser fails to parse a user's input.
@@ -772,8 +872,10 @@ class ArgumentParsingError(UserInputError):
There are child classes that implement more granular parsing errors for
i18n purposes.
"""
pass
class UnexpectedQuoteError(ArgumentParsingError):
"""An exception raised when the parser encounters a quote mark inside a non-quoted string.
@@ -784,9 +886,11 @@ class UnexpectedQuoteError(ArgumentParsingError):
quote: :class:`str`
The quote mark that was found inside the non-quoted string.
"""
def __init__(self, quote: str) -> None:
self.quote: str = quote
super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string')
super().__init__(f"Unexpected quote mark, {quote!r}, in non-quoted string")
class InvalidEndOfQuotedStringError(ArgumentParsingError):
"""An exception raised when a space is expected after the closing quote in a string
@@ -799,9 +903,13 @@ class InvalidEndOfQuotedStringError(ArgumentParsingError):
char: :class:`str`
The character found instead of the expected string.
"""
def __init__(self, char: str) -> None:
self.char: str = char
super().__init__(f'Expected space after closing quotation but received {char!r}')
super().__init__(
f"Expected space after closing quotation but received {char!r}"
)
class ExpectedClosingQuoteError(ArgumentParsingError):
"""An exception raised when a quote character is expected but not found.
@@ -816,7 +924,8 @@ class ExpectedClosingQuoteError(ArgumentParsingError):
def __init__(self, close_quote: str) -> None:
self.close_quote: str = close_quote
super().__init__(f'Expected closing {close_quote}.')
super().__init__(f"Expected closing {close_quote}.")
class ExtensionError(DiscordException):
"""Base exception for extension related errors.
@@ -828,37 +937,47 @@ class ExtensionError(DiscordException):
name: :class:`str`
The extension that had an error.
"""
def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None:
self.name: str = name
message = message or f'Extension {name!r} had an error.'
message = message or f"Extension {name!r} had an error."
# clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
m = message.replace("@everyone", "@\u200beveryone").replace(
"@here", "@\u200bhere"
)
super().__init__(m, *args)
class ExtensionAlreadyLoaded(ExtensionError):
"""An exception raised when an extension has already been loaded.
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name: str) -> None:
super().__init__(f'Extension {name!r} is already loaded.', name=name)
super().__init__(f"Extension {name!r} is already loaded.", name=name)
class ExtensionNotLoaded(ExtensionError):
"""An exception raised when an extension was not loaded.
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name: str) -> None:
super().__init__(f'Extension {name!r} has not been loaded.', name=name)
super().__init__(f"Extension {name!r} has not been loaded.", name=name)
class NoEntryPointError(ExtensionError):
"""An exception raised when an extension does not have a ``setup`` entry point function.
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name: str) -> None:
super().__init__(f"Extension {name!r} has no 'setup' function.", name=name)
class ExtensionFailed(ExtensionError):
"""An exception raised when an extension failed to load during execution of the module or ``setup`` entry point.
@@ -872,11 +991,13 @@ class ExtensionFailed(ExtensionError):
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, name: str, original: Exception) -> None:
self.original: Exception = original
msg = f'Extension {name!r} raised an error: {original.__class__.__name__}: {original}'
msg = f"Extension {name!r} raised an error: {original.__class__.__name__}: {original}"
super().__init__(msg, name=name)
class ExtensionNotFound(ExtensionError):
"""An exception raised when an extension is not found.
@@ -890,10 +1011,12 @@ class ExtensionNotFound(ExtensionError):
name: :class:`str`
The extension that had the error.
"""
def __init__(self, name: str) -> None:
msg = f'Extension {name!r} could not be loaded.'
msg = f"Extension {name!r} could not be loaded."
super().__init__(msg, name=name)
class CommandRegistrationError(ClientException):
"""An exception raised when the command can't be added
because the name is already taken by a different command.
@@ -909,11 +1032,13 @@ class CommandRegistrationError(ClientException):
alias_conflict: :class:`bool`
Whether the name that conflicts is an alias of the command we try to add.
"""
def __init__(self, name: str, *, alias_conflict: bool = False) -> None:
self.name: str = name
self.alias_conflict: bool = alias_conflict
type_ = 'alias' if alias_conflict else 'command'
super().__init__(f'The {type_} {name} is already an existing command or alias.')
type_ = "alias" if alias_conflict else "command"
super().__init__(f"The {type_} {name} is already an existing command or alias.")
class FlagError(BadArgument):
"""The base exception type for all flag parsing related errors.
@@ -922,8 +1047,10 @@ class FlagError(BadArgument):
.. versionadded:: 2.0
"""
pass
class TooManyFlags(FlagError):
"""An exception raised when a flag has received too many values.
@@ -938,10 +1065,14 @@ class TooManyFlags(FlagError):
values: List[:class:`str`]
The values that were passed.
"""
def __init__(self, flag: Flag, values: List[str]) -> None:
self.flag: Flag = flag
self.values: List[str] = values
super().__init__(f'Too many flag values, expected {flag.max_args} but received {len(values)}.')
super().__init__(
f"Too many flag values, expected {flag.max_args} but received {len(values)}."
)
class BadFlagArgument(FlagError):
"""An exception raised when a flag failed to convert a value.
@@ -955,6 +1086,7 @@ class BadFlagArgument(FlagError):
flag: :class:`~discord.ext.commands.Flag`
The flag that failed to convert.
"""
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
try:
@@ -962,7 +1094,8 @@ class BadFlagArgument(FlagError):
except AttributeError:
name = flag.annotation.__class__.__name__
super().__init__(f'Could not convert to {name!r} for flag {flag.name!r}')
super().__init__(f"Could not convert to {name!r} for flag {flag.name!r}")
class MissingRequiredFlag(FlagError):
"""An exception raised when a required flag was not given.
@@ -976,9 +1109,11 @@ class MissingRequiredFlag(FlagError):
flag: :class:`~discord.ext.commands.Flag`
The required flag that was not found.
"""
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} is required and missing')
super().__init__(f"Flag {flag.name!r} is required and missing")
class MissingFlagArgument(FlagError):
"""An exception raised when a flag did not get a value.
@@ -992,6 +1127,7 @@ class MissingFlagArgument(FlagError):
flag: :class:`~discord.ext.commands.Flag`
The flag that did not get a value.
"""
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} does not have an argument')
super().__init__(f"Flag {flag.name!r} does not have an argument")

View File

@@ -59,9 +59,9 @@ import sys
import re
__all__ = (
'Flag',
'flag',
'FlagConverter',
"Flag",
"flag",
"FlagConverter",
)
@@ -143,25 +143,35 @@ def flag(
Whether multiple given values overrides the previous value. The default
value depends on the annotation given.
"""
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
return Flag(
name=name,
aliases=aliases,
default=default,
max_args=max_args,
override=override,
)
def validate_flag_name(name: str, forbidden: Set[str]):
if not name:
raise ValueError('flag names should not be empty')
raise ValueError("flag names should not be empty")
for ch in name:
if ch.isspace():
raise ValueError(f'flag name {name!r} cannot have spaces')
if ch == '\\':
raise ValueError(f'flag name {name!r} cannot have backslashes')
raise ValueError(f"flag name {name!r} cannot have spaces")
if ch == "\\":
raise ValueError(f"flag name {name!r} cannot have backslashes")
if ch in forbidden:
raise ValueError(f'flag name {name!r} cannot have any of {forbidden!r} within them')
raise ValueError(
f"flag name {name!r} cannot have any of {forbidden!r} within them"
)
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
annotations = namespace.get('__annotations__', {})
case_insensitive = namespace['__commands_flag_case_insensitive__']
def get_flags(
namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]
) -> Dict[str, Flag]:
annotations = namespace.get("__annotations__", {})
case_insensitive = namespace["__commands_flag_case_insensitive__"]
flags: Dict[str, Flag] = {}
cache: Dict[str, Any] = {}
names: Set[str] = set()
@@ -176,9 +186,15 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
if flag.name is MISSING:
flag.name = name
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
annotation = flag.annotation = resolve_annotation(
flag.annotation, globals, locals, cache
)
if flag.default is MISSING and hasattr(annotation, '__commands_is_flag__') and annotation._can_be_constructible():
if (
flag.default is MISSING
and hasattr(annotation, "__commands_is_flag__")
and annotation._can_be_constructible()
):
flag.default = annotation._construct_default
if flag.aliases is MISSING:
@@ -229,7 +245,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
if flag.max_args is MISSING:
flag.max_args = 1
else:
raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag')
raise TypeError(
f"Unsupported typing annotation {annotation!r} for {flag.name!r} flag"
)
if flag.override is MISSING:
flag.override = False
@@ -237,7 +255,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
# Validate flag names are unique
name = flag.name.casefold() if case_insensitive else flag.name
if name in names:
raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.')
raise TypeError(
f"{flag.name!r} flag conflicts with previous flag or alias."
)
else:
names.add(name)
@@ -245,7 +265,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
# Validate alias is unique
alias = alias.casefold() if case_insensitive else alias
if alias in names:
raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.')
raise TypeError(
f"{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias."
)
else:
names.add(alias)
@@ -274,10 +296,10 @@ class FlagsMeta(type):
delimiter: str = MISSING,
prefix: str = MISSING,
):
attrs['__commands_is_flag__'] = True
attrs["__commands_is_flag__"] = True
try:
global_ns = sys.modules[attrs['__module__']].__dict__
global_ns = sys.modules[attrs["__module__"]].__dict__
except KeyError:
global_ns = {}
@@ -296,26 +318,32 @@ class FlagsMeta(type):
flags: Dict[str, Flag] = {}
aliases: Dict[str, str] = {}
for base in reversed(bases):
if base.__dict__.get('__commands_is_flag__', False):
flags.update(base.__dict__['__commands_flags__'])
aliases.update(base.__dict__['__commands_flag_aliases__'])
if base.__dict__.get("__commands_is_flag__", False):
flags.update(base.__dict__["__commands_flags__"])
aliases.update(base.__dict__["__commands_flag_aliases__"])
if case_insensitive is MISSING:
attrs['__commands_flag_case_insensitive__'] = base.__dict__['__commands_flag_case_insensitive__']
attrs["__commands_flag_case_insensitive__"] = base.__dict__[
"__commands_flag_case_insensitive__"
]
if delimiter is MISSING:
attrs['__commands_flag_delimiter__'] = base.__dict__['__commands_flag_delimiter__']
attrs["__commands_flag_delimiter__"] = base.__dict__[
"__commands_flag_delimiter__"
]
if prefix is MISSING:
attrs['__commands_flag_prefix__'] = base.__dict__['__commands_flag_prefix__']
attrs["__commands_flag_prefix__"] = base.__dict__[
"__commands_flag_prefix__"
]
if case_insensitive is not MISSING:
attrs['__commands_flag_case_insensitive__'] = case_insensitive
attrs["__commands_flag_case_insensitive__"] = case_insensitive
if delimiter is not MISSING:
attrs['__commands_flag_delimiter__'] = delimiter
attrs["__commands_flag_delimiter__"] = delimiter
if prefix is not MISSING:
attrs['__commands_flag_prefix__'] = prefix
attrs["__commands_flag_prefix__"] = prefix
case_insensitive = attrs.setdefault('__commands_flag_case_insensitive__', False)
delimiter = attrs.setdefault('__commands_flag_delimiter__', ':')
prefix = attrs.setdefault('__commands_flag_prefix__', '')
case_insensitive = attrs.setdefault("__commands_flag_case_insensitive__", False)
delimiter = attrs.setdefault("__commands_flag_delimiter__", ":")
prefix = attrs.setdefault("__commands_flag_prefix__", "")
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
flags[flag_name] = flag
@@ -330,23 +358,30 @@ class FlagsMeta(type):
regex_flags = 0
if case_insensitive:
flags = {key.casefold(): value for key, value in flags.items()}
aliases = {key.casefold(): value.casefold() for key, value in aliases.items()}
aliases = {
key.casefold(): value.casefold() for key, value in aliases.items()
}
regex_flags = re.IGNORECASE
keys = list(re.escape(k) for k in flags)
keys.extend(re.escape(a) for a in aliases)
keys = sorted(keys, key=lambda t: len(t), reverse=True)
joined = '|'.join(keys)
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags)
attrs['__commands_flag_regex__'] = pattern
attrs['__commands_flags__'] = flags
attrs['__commands_flag_aliases__'] = aliases
joined = "|".join(keys)
pattern = re.compile(
f"(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})",
regex_flags,
)
attrs["__commands_flag_regex__"] = pattern
attrs["__commands_flags__"] = flags
attrs["__commands_flag_aliases__"] = aliases
return type.__new__(cls, name, bases, attrs)
async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
async def tuple_convert_all(
ctx: Context, argument: str, flag: Flag, converter: Any
) -> Tuple[Any, ...]:
view = StringView(argument)
results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore
@@ -371,7 +406,9 @@ async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter:
return tuple(results)
async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
async def tuple_convert_flag(
ctx: Context, argument: str, flag: Flag, converters: Any
) -> Tuple[Any, ...]:
view = StringView(argument)
results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore
@@ -409,9 +446,13 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -
else:
if origin is tuple:
if annotation.__args__[-1] is Ellipsis:
return await tuple_convert_all(ctx, argument, flag, annotation.__args__[0])
return await tuple_convert_all(
ctx, argument, flag, annotation.__args__[0]
)
else:
return await tuple_convert_flag(ctx, argument, flag, annotation.__args__)
return await tuple_convert_flag(
ctx, argument, flag, annotation.__args__
)
elif origin is list:
# typing.List[x]
annotation = annotation.__args__[0]
@@ -432,7 +473,7 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -
raise BadFlagArgument(flag) from e
F = TypeVar('F', bound='FlagConverter')
F = TypeVar("F", bound="FlagConverter")
class FlagConverter(metaclass=FlagsMeta):
@@ -493,8 +534,13 @@ class FlagConverter(metaclass=FlagsMeta):
return self
def __repr__(self) -> str:
pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()])
return f'<{self.__class__.__name__} {pairs}>'
pairs = " ".join(
[
f"{flag.attribute}={getattr(self, flag.attribute)!r}"
for flag in self.get_flags().values()
]
)
return f"<{self.__class__.__name__} {pairs}>"
@classmethod
def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
@@ -507,7 +553,7 @@ class FlagConverter(metaclass=FlagsMeta):
case_insensitive = cls.__commands_flag_case_insensitive__
for match in cls.__commands_flag_regex__.finditer(argument):
begin, end = match.span(0)
key = match.group('flag')
key = match.group("flag")
if case_insensitive:
key = key.casefold()

View File

@@ -39,10 +39,10 @@ if TYPE_CHECKING:
from .context import Context
__all__ = (
'Paginator',
'HelpCommand',
'DefaultHelpCommand',
'MinimalHelpCommand',
"Paginator",
"HelpCommand",
"DefaultHelpCommand",
"MinimalHelpCommand",
)
# help -> shows info of bot on top/bottom and lists subcommands
@@ -89,7 +89,7 @@ class Paginator:
.. versionadded:: 1.7
"""
def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'):
def __init__(self, prefix="```", suffix="```", max_size=2000, linesep="\n"):
self.prefix = prefix
self.suffix = suffix
self.max_size = max_size
@@ -118,7 +118,7 @@ class Paginator:
def _linesep_len(self):
return len(self.linesep)
def add_line(self, line='', *, empty=False):
def add_line(self, line="", *, empty=False):
"""Adds a line to the current page.
If the line exceeds the :attr:`max_size` then an exception
@@ -136,18 +136,23 @@ class Paginator:
RuntimeError
The line was too big for the current :attr:`max_size`.
"""
max_page_size = self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len
max_page_size = (
self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len
)
if len(line) > max_page_size:
raise RuntimeError(f'Line exceeds maximum page size {max_page_size}')
raise RuntimeError(f"Line exceeds maximum page size {max_page_size}")
if self._count + len(line) + self._linesep_len > self.max_size - self._suffix_len:
if (
self._count + len(line) + self._linesep_len
> self.max_size - self._suffix_len
):
self.close_page()
self._count += len(line) + self._linesep_len
self._current_page.append(line)
if empty:
self._current_page.append('')
self._current_page.append("")
self._count += self._linesep_len
def close_page(self):
@@ -176,7 +181,7 @@ class Paginator:
return self._pages
def __repr__(self):
fmt = '<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>'
fmt = "<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>"
return fmt.format(self)
@@ -197,7 +202,7 @@ class _HelpCommandImpl(Command):
self.callback = injected.command_callback
on_error = injected.on_help_command_error
if not hasattr(on_error, '__help_command_not_overriden__'):
if not hasattr(on_error, "__help_command_not_overriden__"):
if self.cog is not None:
self.on_error = self._on_error_cog_implementation
else:
@@ -224,7 +229,7 @@ class _HelpCommandImpl(Command):
try:
del result[next(iter(result))]
except StopIteration:
raise ValueError('Missing context parameter') from None
raise ValueError("Missing context parameter") from None
else:
return result
@@ -296,13 +301,13 @@ class HelpCommand:
"""
MENTION_TRANSFORMS = {
'@everyone': '@\u200beveryone',
'@here': '@\u200bhere',
r'<@!?[0-9]{17,22}>': '@deleted-user',
r'<@&[0-9]{17,22}>': '@deleted-role',
"@everyone": "@\u200beveryone",
"@here": "@\u200bhere",
r"<@!?[0-9]{17,22}>": "@deleted-user",
r"<@&[0-9]{17,22}>": "@deleted-role",
}
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys()))
MENTION_PATTERN = re.compile("|".join(MENTION_TRANSFORMS.keys()))
def __new__(cls, *args, **kwargs):
# To prevent race conditions of a single instance while also allowing
@@ -321,11 +326,11 @@ class HelpCommand:
return self
def __init__(self, **options):
self.show_hidden = options.pop('show_hidden', False)
self.verify_checks = options.pop('verify_checks', True)
self.command_attrs = attrs = options.pop('command_attrs', {})
attrs.setdefault('name', 'help')
attrs.setdefault('help', 'Shows this message')
self.show_hidden = options.pop("show_hidden", False)
self.verify_checks = options.pop("verify_checks", True)
self.command_attrs = attrs = options.pop("command_attrs", {})
attrs.setdefault("name", "help")
attrs.setdefault("help", "Shows this message")
self.context: Context = discord.utils.MISSING
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
@@ -398,7 +403,11 @@ class HelpCommand:
"""
command_name = self._command_impl.name
ctx = self.context
if ctx is None or ctx.command is None or ctx.command.qualified_name != command_name:
if (
ctx is None
or ctx.command is None
or ctx.command.qualified_name != command_name
):
return command_name
return ctx.invoked_with
@@ -422,20 +431,20 @@ class HelpCommand:
if not parent.signature or parent.invoke_without_command:
entries.append(parent.name)
else:
entries.append(parent.name + ' ' + parent.signature)
entries.append(parent.name + " " + parent.signature)
parent = parent.parent
parent_sig = ' '.join(reversed(entries))
parent_sig = " ".join(reversed(entries))
if len(command.aliases) > 0:
aliases = '|'.join(command.aliases)
fmt = f'[{command.name}|{aliases}]'
aliases = "|".join(command.aliases)
fmt = f"[{command.name}|{aliases}]"
if parent_sig:
fmt = parent_sig + ' ' + fmt
fmt = parent_sig + " " + fmt
alias = fmt
else:
alias = command.name if not parent_sig else parent_sig + ' ' + command.name
alias = command.name if not parent_sig else parent_sig + " " + command.name
return f'{self.context.clean_prefix}{alias} {command.signature}'
return f"{self.context.clean_prefix}{alias} {command.signature}"
def remove_mentions(self, string):
"""Removes mentions from the string to prevent abuse.
@@ -449,7 +458,7 @@ class HelpCommand:
"""
def replace(obj, *, transforms=self.MENTION_TRANSFORMS):
return transforms.get(obj.group(0), '@invalid')
return transforms.get(obj.group(0), "@invalid")
return self.MENTION_PATTERN.sub(replace, string)
@@ -527,7 +536,9 @@ class HelpCommand:
The string to use when the command did not have the subcommand requested.
"""
if isinstance(command, Group) and len(command.all_commands) > 0:
return f'Command "{command.qualified_name}" has no subcommand named {string}'
return (
f'Command "{command.qualified_name}" has no subcommand named {string}'
)
return f'Command "{command.qualified_name}" has no subcommands.'
async def filter_commands(self, commands, *, sort=False, key=None):
@@ -558,7 +569,9 @@ class HelpCommand:
if sort and key is None:
key = lambda c: c.name
iterator = commands if self.show_hidden else filter(lambda c: not c.hidden, commands)
iterator = (
commands if self.show_hidden else filter(lambda c: not c.hidden, commands)
)
if self.verify_checks is False:
# if we do not need to verify the checks then we can just
@@ -846,21 +859,27 @@ class HelpCommand:
# Since we want to have detailed errors when someone
# passes an invalid subcommand, we need to walk through
# the command group chain ourselves.
keys = command.split(' ')
keys = command.split(" ")
cmd = bot.all_commands.get(keys[0])
if cmd is None:
string = await maybe_coro(self.command_not_found, self.remove_mentions(keys[0]))
string = await maybe_coro(
self.command_not_found, self.remove_mentions(keys[0])
)
return await self.send_error_message(string)
for key in keys[1:]:
try:
found = cmd.all_commands.get(key)
except AttributeError:
string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key))
string = await maybe_coro(
self.subcommand_not_found, cmd, self.remove_mentions(key)
)
return await self.send_error_message(string)
else:
if found is None:
string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key))
string = await maybe_coro(
self.subcommand_not_found, cmd, self.remove_mentions(key)
)
return await self.send_error_message(string)
cmd = found
@@ -907,14 +926,14 @@ class DefaultHelpCommand(HelpCommand):
"""
def __init__(self, **options):
self.width = options.pop('width', 80)
self.indent = options.pop('indent', 2)
self.sort_commands = options.pop('sort_commands', True)
self.dm_help = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
self.commands_heading = options.pop('commands_heading', "Commands:")
self.no_category = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None)
self.width = options.pop("width", 80)
self.indent = options.pop("indent", 2)
self.sort_commands = options.pop("sort_commands", True)
self.dm_help = options.pop("dm_help", False)
self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
self.commands_heading = options.pop("commands_heading", "Commands:")
self.no_category = options.pop("no_category", "No Category")
self.paginator = options.pop("paginator", None)
if self.paginator is None:
self.paginator = Paginator()
@@ -924,7 +943,7 @@ class DefaultHelpCommand(HelpCommand):
def shorten_text(self, text):
""":class:`str`: Shortens text to fit into the :attr:`width`."""
if len(text) > self.width:
return text[:self.width - 3].rstrip() + '...'
return text[: self.width - 3].rstrip() + "..."
return text
def get_ending_note(self):
@@ -1021,11 +1040,11 @@ class DefaultHelpCommand(HelpCommand):
# <description> portion
self.paginator.add_line(bot.description, empty=True)
no_category = f'\u200b{self.no_category}:'
no_category = f"\u200b{self.no_category}:"
def get_category(command, *, no_category=no_category):
cog = command.cog
return cog.qualified_name + ':' if cog is not None else no_category
return cog.qualified_name + ":" if cog is not None else no_category
filtered = await self.filter_commands(bot.commands, sort=True, key=get_category)
max_size = self.get_max_size(filtered)
@@ -1033,7 +1052,11 @@ class DefaultHelpCommand(HelpCommand):
# Now we can add the commands to the page.
for category, commands in to_iterate:
commands = sorted(commands, key=lambda c: c.name) if self.sort_commands else list(commands)
commands = (
sorted(commands, key=lambda c: c.name)
if self.sort_commands
else list(commands)
)
self.add_indented_commands(commands, heading=category, max_size=max_size)
note = self.get_ending_note()
@@ -1066,7 +1089,9 @@ class DefaultHelpCommand(HelpCommand):
if cog.description:
self.paginator.add_line(cog.description, empty=True)
filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands)
filtered = await self.filter_commands(
cog.get_commands(), sort=self.sort_commands
)
self.add_indented_commands(filtered, heading=self.commands_heading)
note = self.get_ending_note()
@@ -1110,13 +1135,13 @@ class MinimalHelpCommand(HelpCommand):
"""
def __init__(self, **options):
self.sort_commands = options.pop('sort_commands', True)
self.commands_heading = options.pop('commands_heading', "Commands")
self.dm_help = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
self.aliases_heading = options.pop('aliases_heading', "Aliases:")
self.no_category = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None)
self.sort_commands = options.pop("sort_commands", True)
self.commands_heading = options.pop("commands_heading", "Commands")
self.dm_help = options.pop("dm_help", False)
self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
self.aliases_heading = options.pop("aliases_heading", "Aliases:")
self.no_category = options.pop("no_category", "No Category")
self.paginator = options.pop("paginator", None)
if self.paginator is None:
self.paginator = Paginator(suffix=None, prefix=None)
@@ -1149,7 +1174,9 @@ class MinimalHelpCommand(HelpCommand):
)
def get_command_signature(self, command):
return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}'
return (
f"{self.context.clean_prefix}{command.qualified_name} {command.signature}"
)
def get_ending_note(self):
"""Return the help command's ending note. This is mainly useful to override for i18n purposes.
@@ -1180,8 +1207,8 @@ class MinimalHelpCommand(HelpCommand):
"""
if commands:
# U+2002 Middle Dot
joined = '\u2002'.join(c.name for c in commands)
self.paginator.add_line(f'__**{heading}**__')
joined = "\u2002".join(c.name for c in commands)
self.paginator.add_line(f"__**{heading}**__")
self.paginator.add_line(joined)
def add_subcommand_formatting(self, command):
@@ -1197,8 +1224,12 @@ class MinimalHelpCommand(HelpCommand):
command: :class:`Command`
The command to show information of.
"""
fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}'
self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc))
fmt = "{0}{1} \N{EN DASH} {2}" if command.short_doc else "{0}{1}"
self.paginator.add_line(
fmt.format(
self.context.clean_prefix, command.qualified_name, command.short_doc
)
)
def add_aliases_formatting(self, aliases):
"""Adds the formatting information on a command's aliases.
@@ -1215,7 +1246,9 @@ class MinimalHelpCommand(HelpCommand):
aliases: Sequence[:class:`str`]
A list of aliases to format.
"""
self.paginator.add_line(f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True)
self.paginator.add_line(
f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True
)
def add_command_formatting(self, command):
"""A utility function to format commands and groups.
@@ -1268,7 +1301,7 @@ class MinimalHelpCommand(HelpCommand):
if note:
self.paginator.add_line(note, empty=True)
no_category = f'\u200b{self.no_category}'
no_category = f"\u200b{self.no_category}"
def get_category(command, *, no_category=no_category):
cog = command.cog
@@ -1278,7 +1311,11 @@ class MinimalHelpCommand(HelpCommand):
to_iterate = itertools.groupby(filtered, key=get_category)
for category, commands in to_iterate:
commands = sorted(commands, key=lambda c: c.name) if self.sort_commands else list(commands)
commands = (
sorted(commands, key=lambda c: c.name)
if self.sort_commands
else list(commands)
)
self.add_bot_commands_formatting(commands, category)
note = self.get_ending_note()
@@ -1300,9 +1337,11 @@ class MinimalHelpCommand(HelpCommand):
if cog.description:
self.paginator.add_line(cog.description, empty=True)
filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands)
filtered = await self.filter_commands(
cog.get_commands(), sort=self.sort_commands
)
if filtered:
self.paginator.add_line(f'**{cog.qualified_name} {self.commands_heading}**')
self.paginator.add_line(f"**{cog.qualified_name} {self.commands_heading}**")
for command in filtered:
self.add_subcommand_formatting(command)
@@ -1322,7 +1361,7 @@ class MinimalHelpCommand(HelpCommand):
if note:
self.paginator.add_line(note, empty=True)
self.paginator.add_line(f'**{self.commands_heading}**')
self.paginator.add_line(f"**{self.commands_heading}**")
for command in filtered:
self.add_subcommand_formatting(command)

View File

@@ -22,7 +22,11 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
from .errors import (
UnexpectedQuoteError,
InvalidEndOfQuotedStringError,
ExpectedClosingQuoteError,
)
# map from opening quotes to closing quotes
_quotes = {
@@ -46,6 +50,7 @@ _quotes = {
}
_all_quotes = set(_quotes.keys()) | set(_quotes.values())
class StringView:
def __init__(self, buffer):
self.index = 0
@@ -81,20 +86,20 @@ class StringView:
def skip_string(self, string):
strlen = len(string)
if self.buffer[self.index:self.index + strlen] == string:
if self.buffer[self.index : self.index + strlen] == string:
self.previous = self.index
self.index += strlen
return True
return False
def read_rest(self):
result = self.buffer[self.index:]
result = self.buffer[self.index :]
self.previous = self.index
self.index = self.end
return result
def read(self, n):
result = self.buffer[self.index:self.index + n]
result = self.buffer[self.index : self.index + n]
self.previous = self.index
self.index += n
return result
@@ -120,7 +125,7 @@ class StringView:
except IndexError:
break
self.previous = self.index
result = self.buffer[self.index:self.index + pos]
result = self.buffer[self.index : self.index + pos]
self.index += pos
return result
@@ -144,11 +149,11 @@ class StringView:
if is_quoted:
# unexpected EOF
raise ExpectedClosingQuoteError(close_quote)
return ''.join(result)
return "".join(result)
# currently we accept strings in the format of "hello world"
# to embed a quote inside the string you must escape it: "a \"world\""
if current == '\\':
if current == "\\":
next_char = self.get()
if not next_char:
# string ends with \ and no character after it
@@ -156,7 +161,7 @@ class StringView:
# if we're quoted then we're expecting a closing quote
raise ExpectedClosingQuoteError(close_quote)
# if we aren't then we just let it through
return ''.join(result)
return "".join(result)
if next_char in _escaped_quotes:
# escaped quote
@@ -179,14 +184,13 @@ class StringView:
raise InvalidEndOfQuotedStringError(next_char)
# we're quoted so it's okay
return ''.join(result)
return "".join(result)
if current.isspace() and not is_quoted:
# end of word found
return ''.join(result)
return "".join(result)
result.append(current)
def __repr__(self):
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>'
return f"<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>"

View File

@@ -48,21 +48,21 @@ from collections.abc import Sequence
from discord.backoff import ExponentialBackoff
from discord.utils import MISSING
__all__ = (
'loop',
)
__all__ = ("loop",)
T = TypeVar('T')
T = TypeVar("T")
_func = Callable[..., Awaitable[Any]]
LF = TypeVar('LF', bound=_func)
FT = TypeVar('FT', bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
LF = TypeVar("LF", bound=_func)
FT = TypeVar("FT", bound=_func)
ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]])
class SleepHandle:
__slots__ = ('future', 'loop', 'handle')
__slots__ = ("future", "loop", "handle")
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
def __init__(
self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop
) -> None:
self.loop = loop
self.future = future = loop.create_future()
relative_delta = discord.utils.compute_timedelta(dt)
@@ -124,7 +124,7 @@ class Loop(Generic[LF]):
self._stop_next_iteration = False
if self.count is not None and self.count <= 0:
raise ValueError('count must be greater than 0 or None.')
raise ValueError("count must be greater than 0 or None.")
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
self._last_iteration_failed = False
@@ -132,10 +132,12 @@ class Loop(Generic[LF]):
self._next_iteration = None
if not inspect.iscoroutinefunction(self.coro):
raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.')
raise TypeError(
f"Expected coroutine function, not {type(self.coro).__name__!r}."
)
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
coro = getattr(self, '_' + name)
coro = getattr(self, "_" + name)
if coro is None:
return
@@ -150,7 +152,7 @@ class Loop(Generic[LF]):
async def _loop(self, *args: Any, **kwargs: Any) -> None:
backoff = ExponentialBackoff()
await self._call_loop_function('before_loop')
await self._call_loop_function("before_loop")
self._last_iteration_failed = False
if self._time is not MISSING:
# the time index should be prepared every time the internal loop is started
@@ -193,10 +195,10 @@ class Loop(Generic[LF]):
raise
except Exception as exc:
self._has_failed = True
await self._call_loop_function('error', exc)
await self._call_loop_function("error", exc)
raise exc
finally:
await self._call_loop_function('after_loop')
await self._call_loop_function("after_loop")
self._handle.cancel()
self._is_being_cancelled = False
self._current_loop = 0
@@ -323,7 +325,7 @@ class Loop(Generic[LF]):
"""
if self._task is not MISSING and not self._task.done():
raise RuntimeError('Task is already launched and is not completed.')
raise RuntimeError("Task is already launched and is not completed.")
if self._injected is not None:
args = (self._injected, *args)
@@ -356,7 +358,9 @@ class Loop(Generic[LF]):
self._stop_next_iteration = True
def _can_be_cancelled(self) -> bool:
return bool(not self._is_being_cancelled and self._task and not self._task.done())
return bool(
not self._is_being_cancelled and self._task and not self._task.done()
)
def cancel(self) -> None:
"""Cancels the internal task, if it is running."""
@@ -379,7 +383,9 @@ class Loop(Generic[LF]):
The keyword arguments to use.
"""
def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None:
def restart_when_over(
fut: Any, *, args: Any = args, kwargs: Any = kwargs
) -> None:
self._task.remove_done_callback(restart_when_over)
self.start(*args, **kwargs)
@@ -410,9 +416,9 @@ class Loop(Generic[LF]):
for exc in exceptions:
if not inspect.isclass(exc):
raise TypeError(f'{exc!r} must be a class.')
raise TypeError(f"{exc!r} must be a class.")
if not issubclass(exc, BaseException):
raise TypeError(f'{exc!r} must inherit from BaseException.')
raise TypeError(f"{exc!r} must inherit from BaseException.")
self._valid_exception = (*self._valid_exception, *exceptions)
@@ -439,7 +445,9 @@ class Loop(Generic[LF]):
Whether all exceptions were successfully removed.
"""
old_length = len(self._valid_exception)
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
self._valid_exception = tuple(
x for x in self._valid_exception if x not in exceptions
)
return len(self._valid_exception) == old_length - len(exceptions)
def get_task(self) -> Optional[asyncio.Task[None]]:
@@ -466,8 +474,13 @@ class Loop(Generic[LF]):
async def _error(self, *args: Any) -> None:
exception: Exception = args[-1]
print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
print(
f"Unhandled exception in internal background task {self.coro.__name__!r}.",
file=sys.stderr,
)
traceback.print_exception(
type(exception), exception, exception.__traceback__, file=sys.stderr
)
def before_loop(self, coro: FT) -> FT:
"""A decorator that registers a coroutine to be called before the loop starts running.
@@ -489,7 +502,9 @@ class Loop(Generic[LF]):
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(
f"Expected coroutine function, received {coro.__class__.__name__!r}."
)
self._before_loop = coro
return coro
@@ -517,7 +532,9 @@ class Loop(Generic[LF]):
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(
f"Expected coroutine function, received {coro.__class__.__name__!r}."
)
self._after_loop = coro
return coro
@@ -543,7 +560,9 @@ class Loop(Generic[LF]):
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(
f"Expected coroutine function, received {coro.__class__.__name__!r}."
)
self._error = coro # type: ignore
return coro
@@ -557,14 +576,18 @@ class Loop(Generic[LF]):
if self._current_loop == 0:
# if we're at the last index on the first iteration, we need to sleep until tomorrow
return datetime.datetime.combine(
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(days=1),
self._time[0],
)
next_time = self._time[self._time_index]
if self._current_loop == 0:
self._time_index += 1
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
return datetime.datetime.combine(
datetime.datetime.now(datetime.timezone.utc), next_time
)
next_date = self._last_iteration
if self._time_index == 0:
@@ -580,7 +603,9 @@ class Loop(Generic[LF]):
# pre-condition: self._time is set
time_now = (
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
now
if now is not MISSING
else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
).timetz()
for idx, time in enumerate(self._time):
if time >= time_now:
@@ -601,16 +626,16 @@ class Loop(Generic[LF]):
return [inner]
if not isinstance(time, Sequence):
raise TypeError(
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.'
f"Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead."
)
if not time:
raise ValueError('time parameter must not be an empty sequence.')
raise ValueError("time parameter must not be an empty sequence.")
ret: List[datetime.time] = []
for index, t in enumerate(time):
if not isinstance(t, dt):
raise TypeError(
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.'
f"Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead."
)
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
@@ -663,7 +688,7 @@ class Loop(Generic[LF]):
hours = hours or 0
sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
if sleep < 0:
raise ValueError('Total number of seconds cannot be less than zero.')
raise ValueError("Total number of seconds cannot be less than zero.")
self._sleep = sleep
self._seconds = float(seconds)
@@ -672,7 +697,7 @@ class Loop(Generic[LF]):
self._time: List[datetime.time] = MISSING
else:
if any((seconds, minutes, hours)):
raise TypeError('Cannot mix explicit time with relative time')
raise TypeError("Cannot mix explicit time with relative time")
self._time = self._get_time_parameter(time)
self._sleep = self._seconds = self._minutes = self._hours = MISSING