Fix code style issues with Black
This commit is contained in:
@ -31,7 +31,7 @@ if TYPE_CHECKING:
|
||||
from .cog import Cog
|
||||
from .errors import CommandError
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
Coro = Coroutine[Any, Any, T]
|
||||
MaybeCoro = Union[T, Coro[T]]
|
||||
@ -39,7 +39,9 @@ CoroFunc = Callable[..., Coro[Any]]
|
||||
|
||||
Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
|
||||
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
|
||||
Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]]
|
||||
Error = Union[
|
||||
Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]
|
||||
]
|
||||
|
||||
|
||||
# This is merely a tag type to avoid circular import issues.
|
||||
|
@ -54,17 +54,18 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
'when_mentioned',
|
||||
'when_mentioned_or',
|
||||
'Bot',
|
||||
'AutoShardedBot',
|
||||
"when_mentioned",
|
||||
"when_mentioned_or",
|
||||
"Bot",
|
||||
"AutoShardedBot",
|
||||
)
|
||||
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
T = TypeVar('T')
|
||||
CFT = TypeVar('CFT', bound='CoroFunc')
|
||||
CXT = TypeVar('CXT', bound='Context')
|
||||
T = TypeVar("T")
|
||||
CFT = TypeVar("CFT", bound="CoroFunc")
|
||||
CXT = TypeVar("CXT", bound="Context")
|
||||
|
||||
|
||||
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
|
||||
"""A callable that implements a command prefix equivalent to being mentioned.
|
||||
@ -72,7 +73,8 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
|
||||
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
|
||||
"""
|
||||
# bot.user will never be None when this is called
|
||||
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
|
||||
return [f"<@{bot.user.id}> ", f"<@!{bot.user.id}> "] # type: ignore
|
||||
|
||||
|
||||
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
|
||||
"""A callable that implements when mentioned or other prefixes provided.
|
||||
@ -103,6 +105,7 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
|
||||
----------
|
||||
:func:`.when_mentioned`
|
||||
"""
|
||||
|
||||
def inner(bot, msg):
|
||||
r = list(prefixes)
|
||||
r = when_mentioned(bot, msg) + r
|
||||
@ -110,15 +113,19 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def _is_submodule(parent: str, child: str) -> bool:
|
||||
return parent == child or child.startswith(parent + ".")
|
||||
|
||||
|
||||
class _DefaultRepr:
|
||||
def __repr__(self):
|
||||
return '<default-help-command>'
|
||||
return "<default-help-command>"
|
||||
|
||||
|
||||
_default = _DefaultRepr()
|
||||
|
||||
|
||||
class BotBase(GroupMixin):
|
||||
def __init__(self, command_prefix, help_command=_default, description=None, *, intents: discord.Intents, **options):
|
||||
super().__init__(**options, intents=intents)
|
||||
@ -131,16 +138,16 @@ class BotBase(GroupMixin):
|
||||
self._before_invoke = None
|
||||
self._after_invoke = None
|
||||
self._help_command = None
|
||||
self.description = inspect.cleandoc(description) if description else ''
|
||||
self.owner_id = options.get('owner_id')
|
||||
self.owner_ids = options.get('owner_ids', set())
|
||||
self.strip_after_prefix = options.get('strip_after_prefix', False)
|
||||
self.description = inspect.cleandoc(description) if description else ""
|
||||
self.owner_id = options.get("owner_id")
|
||||
self.owner_ids = options.get("owner_ids", set())
|
||||
self.strip_after_prefix = options.get("strip_after_prefix", False)
|
||||
|
||||
if self.owner_id and self.owner_ids:
|
||||
raise TypeError('Both owner_id and owner_ids are set.')
|
||||
raise TypeError("Both owner_id and owner_ids are set.")
|
||||
|
||||
if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection):
|
||||
raise TypeError(f'owner_ids must be a collection not {self.owner_ids.__class__!r}')
|
||||
raise TypeError(f"owner_ids must be a collection not {self.owner_ids.__class__!r}")
|
||||
|
||||
if help_command is _default:
|
||||
self.help_command = DefaultHelpCommand()
|
||||
@ -152,7 +159,7 @@ class BotBase(GroupMixin):
|
||||
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
|
||||
# super() will resolve to Client
|
||||
super().dispatch(event_name, *args, **kwargs) # type: ignore
|
||||
ev = 'on_' + event_name
|
||||
ev = "on_" + event_name
|
||||
for event in self.extra_events.get(ev, []):
|
||||
self._schedule_event(event, ev, *args, **kwargs) # type: ignore
|
||||
|
||||
@ -182,7 +189,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
This only fires if you do not specify any listeners for command error.
|
||||
"""
|
||||
if self.extra_events.get('on_command_error', None):
|
||||
if self.extra_events.get("on_command_error", None):
|
||||
return
|
||||
|
||||
command = context.command
|
||||
@ -193,7 +200,7 @@ class BotBase(GroupMixin):
|
||||
if cog and cog.has_error_handler():
|
||||
return
|
||||
|
||||
print(f'Ignoring exception in command {context.command}:', file=sys.stderr)
|
||||
print(f"Ignoring exception in command {context.command}:", file=sys.stderr)
|
||||
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
|
||||
|
||||
# global check registration
|
||||
@ -425,7 +432,7 @@ class BotBase(GroupMixin):
|
||||
The coroutine passed is not actually a coroutine.
|
||||
"""
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise TypeError('The pre-invoke hook must be a coroutine.')
|
||||
raise TypeError("The pre-invoke hook must be a coroutine.")
|
||||
|
||||
self._before_invoke = coro
|
||||
return coro
|
||||
@ -458,7 +465,7 @@ class BotBase(GroupMixin):
|
||||
The coroutine passed is not actually a coroutine.
|
||||
"""
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise TypeError('The post-invoke hook must be a coroutine.')
|
||||
raise TypeError("The post-invoke hook must be a coroutine.")
|
||||
|
||||
self._after_invoke = coro
|
||||
return coro
|
||||
@ -490,7 +497,7 @@ class BotBase(GroupMixin):
|
||||
name = func.__name__ if name is MISSING else name
|
||||
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError('Listeners must be coroutines')
|
||||
raise TypeError("Listeners must be coroutines")
|
||||
|
||||
if name in self.extra_events:
|
||||
self.extra_events[name].append(func)
|
||||
@ -586,14 +593,14 @@ class BotBase(GroupMixin):
|
||||
"""
|
||||
|
||||
if not isinstance(cog, Cog):
|
||||
raise TypeError('cogs must derive from Cog')
|
||||
raise TypeError("cogs must derive from Cog")
|
||||
|
||||
cog_name = cog.__cog_name__
|
||||
existing = self.__cogs.get(cog_name)
|
||||
|
||||
if existing is not None:
|
||||
if not override:
|
||||
raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
|
||||
raise discord.ClientException(f"Cog named {cog_name!r} already loaded")
|
||||
self.remove_cog(cog_name)
|
||||
|
||||
cog = cog._inject(self)
|
||||
@ -681,7 +688,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
|
||||
try:
|
||||
func = getattr(lib, 'teardown')
|
||||
func = getattr(lib, "teardown")
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
@ -708,7 +715,7 @@ class BotBase(GroupMixin):
|
||||
raise errors.ExtensionFailed(key, e) from e
|
||||
|
||||
try:
|
||||
setup = getattr(lib, 'setup')
|
||||
setup = getattr(lib, "setup")
|
||||
except AttributeError:
|
||||
del sys.modules[key]
|
||||
raise errors.NoEntryPointError(key)
|
||||
@ -858,11 +865,7 @@ class BotBase(GroupMixin):
|
||||
raise errors.ExtensionNotLoaded(name)
|
||||
|
||||
# get the previous module states from sys modules
|
||||
modules = {
|
||||
name: module
|
||||
for name, module in sys.modules.items()
|
||||
if _is_submodule(lib.__name__, name)
|
||||
}
|
||||
modules = {name: module for name, module in sys.modules.items() if _is_submodule(lib.__name__, name)}
|
||||
|
||||
try:
|
||||
# Unload and then load the module...
|
||||
@ -895,7 +898,7 @@ class BotBase(GroupMixin):
|
||||
def help_command(self, value: Optional[HelpCommand]) -> None:
|
||||
if value is not None:
|
||||
if not isinstance(value, HelpCommand):
|
||||
raise TypeError('help_command must be a subclass of HelpCommand')
|
||||
raise TypeError("help_command must be a subclass of HelpCommand")
|
||||
if self._help_command is not None:
|
||||
self._help_command._remove_from_bot(self)
|
||||
self._help_command = value
|
||||
@ -938,8 +941,10 @@ class BotBase(GroupMixin):
|
||||
if isinstance(ret, collections.abc.Iterable):
|
||||
raise
|
||||
|
||||
raise TypeError("command_prefix must be plain string, iterable of strings, or callable "
|
||||
f"returning either of these, not {ret.__class__.__name__}")
|
||||
raise TypeError(
|
||||
"command_prefix must be plain string, iterable of strings, or callable "
|
||||
f"returning either of these, not {ret.__class__.__name__}"
|
||||
)
|
||||
|
||||
if not ret:
|
||||
raise ValueError("Iterable command_prefix must contain at least one prefix")
|
||||
@ -999,14 +1004,18 @@ class BotBase(GroupMixin):
|
||||
|
||||
except TypeError:
|
||||
if not isinstance(prefix, list):
|
||||
raise TypeError("get_prefix must return either a string or a list of string, "
|
||||
f"not {prefix.__class__.__name__}")
|
||||
raise TypeError(
|
||||
"get_prefix must return either a string or a list of string, "
|
||||
f"not {prefix.__class__.__name__}"
|
||||
)
|
||||
|
||||
# It's possible a bad command_prefix got us here.
|
||||
for value in prefix:
|
||||
if not isinstance(value, str):
|
||||
raise TypeError("Iterable command_prefix or list returned from get_prefix must "
|
||||
f"contain only strings, not {value.__class__.__name__}")
|
||||
raise TypeError(
|
||||
"Iterable command_prefix or list returned from get_prefix must "
|
||||
f"contain only strings, not {value.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Getting here shouldn't happen
|
||||
raise
|
||||
@ -1033,19 +1042,19 @@ class BotBase(GroupMixin):
|
||||
The invocation context to invoke.
|
||||
"""
|
||||
if ctx.command is not None:
|
||||
self.dispatch('command', ctx)
|
||||
self.dispatch("command", ctx)
|
||||
try:
|
||||
if await self.can_run(ctx, call_once=True):
|
||||
await ctx.command.invoke(ctx)
|
||||
else:
|
||||
raise errors.CheckFailure('The global check once functions failed.')
|
||||
raise errors.CheckFailure("The global check once functions failed.")
|
||||
except errors.CommandError as exc:
|
||||
await ctx.command.dispatch_error(ctx, exc)
|
||||
else:
|
||||
self.dispatch('command_completion', ctx)
|
||||
self.dispatch("command_completion", ctx)
|
||||
elif ctx.invoked_with:
|
||||
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
|
||||
self.dispatch('command_error', ctx, exc)
|
||||
self.dispatch("command_error", ctx, exc)
|
||||
|
||||
async def process_commands(self, message: Message) -> None:
|
||||
"""|coro|
|
||||
@ -1078,6 +1087,7 @@ class BotBase(GroupMixin):
|
||||
async def on_message(self, message):
|
||||
await self.process_commands(message)
|
||||
|
||||
|
||||
class Bot(BotBase, discord.Client):
|
||||
"""Represents a discord bot.
|
||||
|
||||
@ -1148,10 +1158,13 @@ class Bot(BotBase, discord.Client):
|
||||
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AutoShardedBot(BotBase, discord.AutoShardedClient):
|
||||
"""This is similar to :class:`.Bot` except that it is inherited from
|
||||
:class:`discord.AutoShardedClient` instead.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
@ -36,15 +36,16 @@ if TYPE_CHECKING:
|
||||
from .core import Command
|
||||
|
||||
__all__ = (
|
||||
'CogMeta',
|
||||
'Cog',
|
||||
"CogMeta",
|
||||
"Cog",
|
||||
)
|
||||
|
||||
CogT = TypeVar('CogT', bound='Cog')
|
||||
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
|
||||
CogT = TypeVar("CogT", bound="Cog")
|
||||
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
|
||||
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
|
||||
class CogMeta(type):
|
||||
"""A metaclass for defining a cog.
|
||||
|
||||
@ -104,6 +105,7 @@ class CogMeta(type):
|
||||
async def bar(self, ctx):
|
||||
pass # hidden -> False
|
||||
"""
|
||||
|
||||
__cog_name__: str
|
||||
__cog_settings__: Dict[str, Any]
|
||||
__cog_commands__: List[Command]
|
||||
@ -111,17 +113,17 @@ class CogMeta(type):
|
||||
|
||||
def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
|
||||
name, bases, attrs = args
|
||||
attrs['__cog_name__'] = kwargs.pop('name', name)
|
||||
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
|
||||
attrs["__cog_name__"] = kwargs.pop("name", name)
|
||||
attrs["__cog_settings__"] = kwargs.pop("command_attrs", {})
|
||||
|
||||
description = kwargs.pop('description', None)
|
||||
description = kwargs.pop("description", None)
|
||||
if description is None:
|
||||
description = inspect.cleandoc(attrs.get('__doc__', ''))
|
||||
attrs['__cog_description__'] = description
|
||||
description = inspect.cleandoc(attrs.get("__doc__", ""))
|
||||
attrs["__cog_description__"] = description
|
||||
|
||||
commands = {}
|
||||
listeners = {}
|
||||
no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})'
|
||||
no_bot_cog = "Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})"
|
||||
|
||||
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
|
||||
for base in reversed(new_cls.__mro__):
|
||||
@ -136,21 +138,21 @@ class CogMeta(type):
|
||||
value = value.__func__
|
||||
if isinstance(value, _BaseCommand):
|
||||
if is_static_method:
|
||||
raise TypeError(f'Command in method {base}.{elem!r} must not be staticmethod.')
|
||||
if elem.startswith(('cog_', 'bot_')):
|
||||
raise TypeError(f"Command in method {base}.{elem!r} must not be staticmethod.")
|
||||
if elem.startswith(("cog_", "bot_")):
|
||||
raise TypeError(no_bot_cog.format(base, elem))
|
||||
commands[elem] = value
|
||||
elif inspect.iscoroutinefunction(value):
|
||||
try:
|
||||
getattr(value, '__cog_listener__')
|
||||
getattr(value, "__cog_listener__")
|
||||
except AttributeError:
|
||||
continue
|
||||
else:
|
||||
if elem.startswith(('cog_', 'bot_')):
|
||||
if elem.startswith(("cog_", "bot_")):
|
||||
raise TypeError(no_bot_cog.format(base, elem))
|
||||
listeners[elem] = value
|
||||
|
||||
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
|
||||
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
|
||||
|
||||
listeners_as_list = []
|
||||
for listener in listeners.values():
|
||||
@ -169,10 +171,12 @@ class CogMeta(type):
|
||||
def qualified_name(cls) -> str:
|
||||
return cls.__cog_name__
|
||||
|
||||
|
||||
def _cog_special_method(func: FuncT) -> FuncT:
|
||||
func.__cog_special_method__ = None
|
||||
return func
|
||||
|
||||
|
||||
class Cog(metaclass=CogMeta):
|
||||
"""The base class that all cogs must inherit from.
|
||||
|
||||
@ -183,6 +187,7 @@ class Cog(metaclass=CogMeta):
|
||||
When inheriting from this class, the options shown in :class:`CogMeta`
|
||||
are equally valid here.
|
||||
"""
|
||||
|
||||
__cog_name__: ClassVar[str]
|
||||
__cog_settings__: ClassVar[Dict[str, Any]]
|
||||
__cog_commands__: ClassVar[List[Command]]
|
||||
@ -199,10 +204,7 @@ class Cog(metaclass=CogMeta):
|
||||
# r.e type ignore, type-checker complains about overriding a ClassVar
|
||||
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
|
||||
|
||||
lookup = {
|
||||
cmd.qualified_name: cmd
|
||||
for cmd in self.__cog_commands__
|
||||
}
|
||||
lookup = {cmd.qualified_name: cmd for cmd in self.__cog_commands__}
|
||||
|
||||
# Update the Command instances dynamically as well
|
||||
for command in self.__cog_commands__:
|
||||
@ -255,6 +257,7 @@ class Cog(metaclass=CogMeta):
|
||||
A command or group from the cog.
|
||||
"""
|
||||
from .core import GroupMixin
|
||||
|
||||
for command in self.__cog_commands__:
|
||||
if command.parent is None:
|
||||
yield command
|
||||
@ -274,7 +277,7 @@ class Cog(metaclass=CogMeta):
|
||||
@classmethod
|
||||
def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]:
|
||||
"""Return None if the method is not overridden. Otherwise returns the overridden method."""
|
||||
return getattr(method.__func__, '__cog_special_method__', method)
|
||||
return getattr(method.__func__, "__cog_special_method__", method)
|
||||
|
||||
@classmethod
|
||||
def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]:
|
||||
@ -296,14 +299,14 @@ class Cog(metaclass=CogMeta):
|
||||
"""
|
||||
|
||||
if name is not MISSING and not isinstance(name, str):
|
||||
raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.')
|
||||
raise TypeError(f"Cog.listener expected str but received {name.__class__.__name__!r} instead.")
|
||||
|
||||
def decorator(func: FuncT) -> FuncT:
|
||||
actual = func
|
||||
if isinstance(actual, staticmethod):
|
||||
actual = actual.__func__
|
||||
if not inspect.iscoroutinefunction(actual):
|
||||
raise TypeError('Listener function must be a coroutine function.')
|
||||
raise TypeError("Listener function must be a coroutine function.")
|
||||
actual.__cog_listener__ = True
|
||||
to_assign = name or actual.__name__
|
||||
try:
|
||||
@ -315,6 +318,7 @@ class Cog(metaclass=CogMeta):
|
||||
# to pick it up but the metaclass unfurls the function and
|
||||
# thus the assignments need to be on the actual function
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def has_error_handler(self) -> bool:
|
||||
@ -322,7 +326,7 @@ class Cog(metaclass=CogMeta):
|
||||
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
|
||||
return not hasattr(self.cog_command_error.__func__, "__cog_special_method__")
|
||||
|
||||
@_cog_special_method
|
||||
def cog_unload(self) -> None:
|
||||
|
@ -50,21 +50,19 @@ if TYPE_CHECKING:
|
||||
from .help import HelpCommand
|
||||
from .view import StringView
|
||||
|
||||
__all__ = (
|
||||
'Context',
|
||||
)
|
||||
__all__ = ("Context",)
|
||||
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
|
||||
CogT = TypeVar('CogT', bound="Cog")
|
||||
T = TypeVar("T")
|
||||
BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]")
|
||||
CogT = TypeVar("CogT", bound="Cog")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
P = ParamSpec('P')
|
||||
P = ParamSpec("P")
|
||||
else:
|
||||
P = TypeVar('P')
|
||||
P = TypeVar("P")
|
||||
|
||||
|
||||
class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
@ -123,7 +121,8 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
or invoked.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
message: Message,
|
||||
bot: BotT,
|
||||
@ -220,7 +219,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
cmd = self.command
|
||||
view = self.view
|
||||
if cmd is None:
|
||||
raise ValueError('This context is not valid.')
|
||||
raise ValueError("This context is not valid.")
|
||||
|
||||
# some state to revert to when we're done
|
||||
index, previous = view.index, view.previous
|
||||
@ -231,10 +230,10 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
|
||||
if restart:
|
||||
to_call = cmd.root_parent or cmd
|
||||
view.index = len(self.prefix or '')
|
||||
view.index = len(self.prefix or "")
|
||||
view.previous = 0
|
||||
self.invoked_parents = []
|
||||
self.invoked_with = view.get_word() # advance to get the root command
|
||||
self.invoked_with = view.get_word() # advance to get the root command
|
||||
else:
|
||||
to_call = cmd
|
||||
|
||||
@ -264,7 +263,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
if self.prefix is None:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
user = self.me
|
||||
# this breaks if the prefix mention is not the bot itself but I
|
||||
@ -272,7 +271,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
# for this common use case rather than waste performance for the
|
||||
# odd one.
|
||||
pattern = re.compile(r"<@!?%s>" % user.id)
|
||||
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix)
|
||||
return pattern.sub("@%s" % user.display_name.replace("\\", r"\\"), self.prefix)
|
||||
|
||||
@property
|
||||
def cog(self) -> Optional[Cog]:
|
||||
@ -389,7 +388,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
await cmd.prepare_help_command(self, entity.qualified_name)
|
||||
|
||||
try:
|
||||
if hasattr(entity, '__cog_commands__'):
|
||||
if hasattr(entity, "__cog_commands__"):
|
||||
injected = wrap_callback(cmd.send_cog_help)
|
||||
return await injected(entity)
|
||||
elif isinstance(entity, Group):
|
||||
|
@ -52,32 +52,32 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
__all__ = (
|
||||
'Converter',
|
||||
'ObjectConverter',
|
||||
'MemberConverter',
|
||||
'UserConverter',
|
||||
'MessageConverter',
|
||||
'PartialMessageConverter',
|
||||
'TextChannelConverter',
|
||||
'InviteConverter',
|
||||
'GuildConverter',
|
||||
'RoleConverter',
|
||||
'GameConverter',
|
||||
'ColourConverter',
|
||||
'ColorConverter',
|
||||
'VoiceChannelConverter',
|
||||
'StageChannelConverter',
|
||||
'EmojiConverter',
|
||||
'PartialEmojiConverter',
|
||||
'CategoryChannelConverter',
|
||||
'IDConverter',
|
||||
'StoreChannelConverter',
|
||||
'ThreadConverter',
|
||||
'GuildChannelConverter',
|
||||
'GuildStickerConverter',
|
||||
'clean_content',
|
||||
'Greedy',
|
||||
'run_converters',
|
||||
"Converter",
|
||||
"ObjectConverter",
|
||||
"MemberConverter",
|
||||
"UserConverter",
|
||||
"MessageConverter",
|
||||
"PartialMessageConverter",
|
||||
"TextChannelConverter",
|
||||
"InviteConverter",
|
||||
"GuildConverter",
|
||||
"RoleConverter",
|
||||
"GameConverter",
|
||||
"ColourConverter",
|
||||
"ColorConverter",
|
||||
"VoiceChannelConverter",
|
||||
"StageChannelConverter",
|
||||
"EmojiConverter",
|
||||
"PartialEmojiConverter",
|
||||
"CategoryChannelConverter",
|
||||
"IDConverter",
|
||||
"StoreChannelConverter",
|
||||
"ThreadConverter",
|
||||
"GuildChannelConverter",
|
||||
"GuildStickerConverter",
|
||||
"clean_content",
|
||||
"Greedy",
|
||||
"run_converters",
|
||||
)
|
||||
|
||||
|
||||
@ -91,10 +91,10 @@ def _get_from_guilds(bot, getter, argument):
|
||||
|
||||
|
||||
_utils_get = discord.utils.get
|
||||
T = TypeVar('T')
|
||||
T_co = TypeVar('T_co', covariant=True)
|
||||
CT = TypeVar('CT', bound=discord.abc.GuildChannel)
|
||||
TT = TypeVar('TT', bound=discord.Thread)
|
||||
T = TypeVar("T")
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
CT = TypeVar("CT", bound=discord.abc.GuildChannel)
|
||||
TT = TypeVar("TT", bound=discord.Thread)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@ -132,10 +132,10 @@ class Converter(Protocol[T_co]):
|
||||
:exc:`.BadArgument`
|
||||
The converter failed to convert the argument.
|
||||
"""
|
||||
raise NotImplementedError('Derived classes need to implement this.')
|
||||
raise NotImplementedError("Derived classes need to implement this.")
|
||||
|
||||
|
||||
_ID_REGEX = re.compile(r'([0-9]{15,20})$')
|
||||
_ID_REGEX = re.compile(r"([0-9]{15,20})$")
|
||||
|
||||
|
||||
class IDConverter(Converter[T_co]):
|
||||
@ -158,7 +158,7 @@ class ObjectConverter(IDConverter[discord.Object]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.Object:
|
||||
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument)
|
||||
match = self._get_id_match(argument) or re.match(r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument)
|
||||
|
||||
if match is None:
|
||||
raise ObjectNotFound(argument)
|
||||
@ -192,8 +192,8 @@ class MemberConverter(IDConverter[discord.Member]):
|
||||
|
||||
async def query_member_named(self, guild, argument):
|
||||
cache = guild._state.member_cache_flags.joined
|
||||
if len(argument) > 5 and argument[-5] == '#':
|
||||
username, _, discriminator = argument.rpartition('#')
|
||||
if len(argument) > 5 and argument[-5] == "#":
|
||||
username, _, discriminator = argument.rpartition("#")
|
||||
members = await guild.query_members(username, limit=100, cache=cache)
|
||||
return discord.utils.get(members, name=username, discriminator=discriminator)
|
||||
else:
|
||||
@ -223,7 +223,7 @@ class MemberConverter(IDConverter[discord.Member]):
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.Member:
|
||||
bot = ctx.bot
|
||||
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
|
||||
match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument)
|
||||
guild = ctx.guild
|
||||
result = None
|
||||
user_id = None
|
||||
@ -232,13 +232,13 @@ class MemberConverter(IDConverter[discord.Member]):
|
||||
if guild:
|
||||
result = guild.get_member_named(argument)
|
||||
else:
|
||||
result = _get_from_guilds(bot, 'get_member_named', argument)
|
||||
result = _get_from_guilds(bot, "get_member_named", argument)
|
||||
else:
|
||||
user_id = int(match.group(1))
|
||||
if guild:
|
||||
result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id)
|
||||
else:
|
||||
result = _get_from_guilds(bot, 'get_member', user_id)
|
||||
result = _get_from_guilds(bot, "get_member", user_id)
|
||||
|
||||
if result is None:
|
||||
if guild is None:
|
||||
@ -276,7 +276,7 @@ class UserConverter(IDConverter[discord.User]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.User:
|
||||
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
|
||||
match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument)
|
||||
result = None
|
||||
state = ctx._state
|
||||
|
||||
@ -294,12 +294,12 @@ class UserConverter(IDConverter[discord.User]):
|
||||
arg = argument
|
||||
|
||||
# Remove the '@' character if this is the first character from the argument
|
||||
if arg[0] == '@':
|
||||
if arg[0] == "@":
|
||||
# Remove first character
|
||||
arg = arg[1:]
|
||||
|
||||
# check for discriminator if it exists,
|
||||
if len(arg) > 5 and arg[-5] == '#':
|
||||
if len(arg) > 5 and arg[-5] == "#":
|
||||
discrim = arg[-4:]
|
||||
name = arg[:-5]
|
||||
predicate = lambda u: u.name == name and u.discriminator == discrim
|
||||
@ -330,22 +330,22 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
|
||||
|
||||
@staticmethod
|
||||
def _get_id_matches(ctx, argument):
|
||||
id_regex = re.compile(r'(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$')
|
||||
id_regex = re.compile(r"(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$")
|
||||
link_regex = re.compile(
|
||||
r'https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/'
|
||||
r'(?P<guild_id>[0-9]{15,20}|@me)'
|
||||
r'/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$'
|
||||
r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/"
|
||||
r"(?P<guild_id>[0-9]{15,20}|@me)"
|
||||
r"/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$"
|
||||
)
|
||||
match = id_regex.match(argument) or link_regex.match(argument)
|
||||
if not match:
|
||||
raise MessageNotFound(argument)
|
||||
data = match.groupdict()
|
||||
channel_id = discord.utils._get_as_snowflake(data, 'channel_id')
|
||||
message_id = int(data['message_id'])
|
||||
guild_id = data.get('guild_id')
|
||||
channel_id = discord.utils._get_as_snowflake(data, "channel_id")
|
||||
message_id = int(data["message_id"])
|
||||
guild_id = data.get("guild_id")
|
||||
if guild_id is None:
|
||||
guild_id = ctx.guild and ctx.guild.id
|
||||
elif guild_id == '@me':
|
||||
elif guild_id == "@me":
|
||||
guild_id = None
|
||||
else:
|
||||
guild_id = int(guild_id)
|
||||
@ -417,13 +417,13 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel:
|
||||
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel)
|
||||
return self._resolve_channel(ctx, argument, "channels", discord.abc.GuildChannel)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT:
|
||||
bot = ctx.bot
|
||||
|
||||
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
|
||||
match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument)
|
||||
result = None
|
||||
guild = ctx.guild
|
||||
|
||||
@ -443,7 +443,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
|
||||
if guild:
|
||||
result = guild.get_channel(channel_id)
|
||||
else:
|
||||
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
||||
result = _get_from_guilds(bot, "get_channel", channel_id)
|
||||
|
||||
if not isinstance(result, type):
|
||||
raise ChannelNotFound(argument)
|
||||
@ -454,7 +454,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
|
||||
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT:
|
||||
bot = ctx.bot
|
||||
|
||||
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
|
||||
match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument)
|
||||
result = None
|
||||
guild = ctx.guild
|
||||
|
||||
@ -491,7 +491,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel)
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, "text_channels", discord.TextChannel)
|
||||
|
||||
|
||||
class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
|
||||
@ -511,7 +511,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel)
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, "voice_channels", discord.VoiceChannel)
|
||||
|
||||
|
||||
class StageChannelConverter(IDConverter[discord.StageChannel]):
|
||||
@ -530,7 +530,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel)
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, "stage_channels", discord.StageChannel)
|
||||
|
||||
|
||||
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
|
||||
@ -550,7 +550,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel)
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, "categories", discord.CategoryChannel)
|
||||
|
||||
|
||||
class StoreChannelConverter(IDConverter[discord.StoreChannel]):
|
||||
@ -569,7 +569,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel)
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, "channels", discord.StoreChannel)
|
||||
|
||||
|
||||
class ThreadConverter(IDConverter[discord.Thread]):
|
||||
@ -587,7 +587,7 @@ class ThreadConverter(IDConverter[discord.Thread]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.Thread:
|
||||
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread)
|
||||
return GuildChannelConverter._resolve_thread(ctx, argument, "threads", discord.Thread)
|
||||
|
||||
|
||||
class ColourConverter(Converter[discord.Colour]):
|
||||
@ -616,10 +616,10 @@ class ColourConverter(Converter[discord.Colour]):
|
||||
Added support for ``rgb`` function and 3-digit hex shortcuts
|
||||
"""
|
||||
|
||||
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
|
||||
RGB_REGEX = re.compile(r"rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)")
|
||||
|
||||
def parse_hex_number(self, argument):
|
||||
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
|
||||
arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument
|
||||
try:
|
||||
value = int(arg, base=16)
|
||||
if not (0 <= value <= 0xFFFFFF):
|
||||
@ -630,7 +630,7 @@ class ColourConverter(Converter[discord.Colour]):
|
||||
return discord.Color(value=value)
|
||||
|
||||
def parse_rgb_number(self, argument, number):
|
||||
if number[-1] == '%':
|
||||
if number[-1] == "%":
|
||||
value = int(number[:-1])
|
||||
if not (0 <= value <= 100):
|
||||
raise BadColourArgument(argument)
|
||||
@ -646,29 +646,29 @@ class ColourConverter(Converter[discord.Colour]):
|
||||
if match is None:
|
||||
raise BadColourArgument(argument)
|
||||
|
||||
red = self.parse_rgb_number(argument, match.group('r'))
|
||||
green = self.parse_rgb_number(argument, match.group('g'))
|
||||
blue = self.parse_rgb_number(argument, match.group('b'))
|
||||
red = self.parse_rgb_number(argument, match.group("r"))
|
||||
green = self.parse_rgb_number(argument, match.group("g"))
|
||||
blue = self.parse_rgb_number(argument, match.group("b"))
|
||||
return discord.Color.from_rgb(red, green, blue)
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.Colour:
|
||||
if argument[0] == '#':
|
||||
if argument[0] == "#":
|
||||
return self.parse_hex_number(argument[1:])
|
||||
|
||||
if argument[0:2] == '0x':
|
||||
if argument[0:2] == "0x":
|
||||
rest = argument[2:]
|
||||
# Legacy backwards compatible syntax
|
||||
if rest.startswith('#'):
|
||||
if rest.startswith("#"):
|
||||
return self.parse_hex_number(rest[1:])
|
||||
return self.parse_hex_number(rest)
|
||||
|
||||
arg = argument.lower()
|
||||
if arg[0:3] == 'rgb':
|
||||
if arg[0:3] == "rgb":
|
||||
return self.parse_rgb(arg)
|
||||
|
||||
arg = arg.replace(' ', '_')
|
||||
arg = arg.replace(" ", "_")
|
||||
method = getattr(discord.Colour, arg, None)
|
||||
if arg.startswith('from_') or method is None or not inspect.ismethod(method):
|
||||
if arg.startswith("from_") or method is None or not inspect.ismethod(method):
|
||||
raise BadColourArgument(arg)
|
||||
return method()
|
||||
|
||||
@ -697,7 +697,7 @@ class RoleConverter(IDConverter[discord.Role]):
|
||||
if not guild:
|
||||
raise NoPrivateMessage()
|
||||
|
||||
match = self._get_id_match(argument) or re.match(r'<@&([0-9]{15,20})>$', argument)
|
||||
match = self._get_id_match(argument) or re.match(r"<@&([0-9]{15,20})>$", argument)
|
||||
if match:
|
||||
result = guild.get_role(int(match.group(1)))
|
||||
else:
|
||||
@ -776,7 +776,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.Emoji:
|
||||
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument)
|
||||
match = self._get_id_match(argument) or re.match(r"<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$", argument)
|
||||
result = None
|
||||
bot = ctx.bot
|
||||
guild = ctx.guild
|
||||
@ -810,7 +810,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji:
|
||||
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument)
|
||||
match = re.match(r"<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$", argument)
|
||||
|
||||
if match:
|
||||
emoji_animated = bool(match.group(1))
|
||||
@ -903,37 +903,37 @@ class clean_content(Converter[str]):
|
||||
|
||||
def resolve_member(id: int) -> str:
|
||||
m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id)
|
||||
return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user'
|
||||
return f"@{m.display_name if self.use_nicknames else m.name}" if m else "@deleted-user"
|
||||
|
||||
def resolve_role(id: int) -> str:
|
||||
r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id)
|
||||
return f'@{r.name}' if r else '@deleted-role'
|
||||
return f"@{r.name}" if r else "@deleted-role"
|
||||
|
||||
else:
|
||||
|
||||
def resolve_member(id: int) -> str:
|
||||
m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id)
|
||||
return f'@{m.name}' if m else '@deleted-user'
|
||||
return f"@{m.name}" if m else "@deleted-user"
|
||||
|
||||
def resolve_role(id: int) -> str:
|
||||
return '@deleted-role'
|
||||
return "@deleted-role"
|
||||
|
||||
if self.fix_channel_mentions and ctx.guild:
|
||||
|
||||
def resolve_channel(id: int) -> str:
|
||||
c = ctx.guild.get_channel(id)
|
||||
return f'#{c.name}' if c else '#deleted-channel'
|
||||
return f"#{c.name}" if c else "#deleted-channel"
|
||||
|
||||
else:
|
||||
|
||||
def resolve_channel(id: int) -> str:
|
||||
return f'<#{id}>'
|
||||
return f"<#{id}>"
|
||||
|
||||
transforms = {
|
||||
'@': resolve_member,
|
||||
'@!': resolve_member,
|
||||
'#': resolve_channel,
|
||||
'@&': resolve_role,
|
||||
"@": resolve_member,
|
||||
"@!": resolve_member,
|
||||
"#": resolve_channel,
|
||||
"@&": resolve_role,
|
||||
}
|
||||
|
||||
def repl(match: re.Match) -> str:
|
||||
@ -942,7 +942,7 @@ class clean_content(Converter[str]):
|
||||
transformed = transforms[type](id)
|
||||
return transformed
|
||||
|
||||
result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument)
|
||||
result = re.sub(r"<(@[!&]?|#)([0-9]{15,20})>", repl, argument)
|
||||
if self.escape_markdown:
|
||||
result = discord.utils.escape_markdown(result)
|
||||
elif self.remove_markdown:
|
||||
@ -974,42 +974,42 @@ class Greedy(List[T]):
|
||||
For more information, check :ref:`ext_commands_special_converters`.
|
||||
"""
|
||||
|
||||
__slots__ = ('converter',)
|
||||
__slots__ = ("converter",)
|
||||
|
||||
def __init__(self, *, converter: T):
|
||||
self.converter = converter
|
||||
|
||||
def __repr__(self):
|
||||
converter = getattr(self.converter, '__name__', repr(self.converter))
|
||||
return f'Greedy[{converter}]'
|
||||
converter = getattr(self.converter, "__name__", repr(self.converter))
|
||||
return f"Greedy[{converter}]"
|
||||
|
||||
def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]:
|
||||
if not isinstance(params, tuple):
|
||||
params = (params,)
|
||||
if len(params) != 1:
|
||||
raise TypeError('Greedy[...] only takes a single argument')
|
||||
raise TypeError("Greedy[...] only takes a single argument")
|
||||
converter = params[0]
|
||||
|
||||
origin = getattr(converter, '__origin__', None)
|
||||
args = getattr(converter, '__args__', ())
|
||||
origin = getattr(converter, "__origin__", None)
|
||||
args = getattr(converter, "__args__", ())
|
||||
|
||||
if not (callable(converter) or isinstance(converter, Converter) or origin is not None):
|
||||
raise TypeError('Greedy[...] expects a type or a Converter instance.')
|
||||
raise TypeError("Greedy[...] expects a type or a Converter instance.")
|
||||
|
||||
if converter in (str, type(None)) or origin is Greedy:
|
||||
raise TypeError(f'Greedy[{converter.__name__}] is invalid.')
|
||||
raise TypeError(f"Greedy[{converter.__name__}] is invalid.")
|
||||
|
||||
if origin is Union and type(None) in args:
|
||||
raise TypeError(f'Greedy[{converter!r}] is invalid.')
|
||||
raise TypeError(f"Greedy[{converter!r}] is invalid.")
|
||||
|
||||
return cls(converter=converter)
|
||||
|
||||
|
||||
def _convert_to_bool(argument: str) -> bool:
|
||||
lowered = argument.lower()
|
||||
if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
|
||||
if lowered in ("yes", "y", "true", "t", "1", "enable", "on"):
|
||||
return True
|
||||
elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'):
|
||||
elif lowered in ("no", "n", "false", "f", "0", "disable", "off"):
|
||||
return False
|
||||
else:
|
||||
raise BadBoolArgument(lowered)
|
||||
@ -1065,7 +1065,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
if module is not None and (module.startswith('discord.') and not module.endswith('converter')):
|
||||
if module is not None and (module.startswith("discord.") and not module.endswith("converter")):
|
||||
converter = CONVERTER_MAPPING.get(converter, converter)
|
||||
|
||||
try:
|
||||
@ -1124,7 +1124,7 @@ async def run_converters(ctx: Context, converter, argument: str, param: inspect.
|
||||
Any
|
||||
The resulting conversion.
|
||||
"""
|
||||
origin = getattr(converter, '__origin__', None)
|
||||
origin = getattr(converter, "__origin__", None)
|
||||
|
||||
if origin is Union:
|
||||
errors = []
|
||||
|
@ -38,24 +38,25 @@ if TYPE_CHECKING:
|
||||
from ...message import Message
|
||||
|
||||
__all__ = (
|
||||
'BucketType',
|
||||
'Cooldown',
|
||||
'CooldownMapping',
|
||||
'DynamicCooldownMapping',
|
||||
'MaxConcurrency',
|
||||
"BucketType",
|
||||
"Cooldown",
|
||||
"CooldownMapping",
|
||||
"DynamicCooldownMapping",
|
||||
"MaxConcurrency",
|
||||
)
|
||||
|
||||
C = TypeVar('C', bound='CooldownMapping')
|
||||
MC = TypeVar('MC', bound='MaxConcurrency')
|
||||
C = TypeVar("C", bound="CooldownMapping")
|
||||
MC = TypeVar("MC", bound="MaxConcurrency")
|
||||
|
||||
|
||||
class BucketType(Enum):
|
||||
default = 0
|
||||
user = 1
|
||||
guild = 2
|
||||
channel = 3
|
||||
member = 4
|
||||
default = 0
|
||||
user = 1
|
||||
guild = 2
|
||||
channel = 3
|
||||
member = 4
|
||||
category = 5
|
||||
role = 6
|
||||
role = 6
|
||||
|
||||
def get_key(self, msg: Message) -> Any:
|
||||
if self is BucketType.user:
|
||||
@ -90,7 +91,7 @@ class Cooldown:
|
||||
The length of the cooldown period in seconds.
|
||||
"""
|
||||
|
||||
__slots__ = ('rate', 'per', '_window', '_tokens', '_last')
|
||||
__slots__ = ("rate", "per", "_window", "_tokens", "_last")
|
||||
|
||||
def __init__(self, rate: float, per: float) -> None:
|
||||
self.rate: int = int(rate)
|
||||
@ -190,7 +191,8 @@ class Cooldown:
|
||||
return Cooldown(self.rate, self.per)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
|
||||
return f"<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>"
|
||||
|
||||
|
||||
class CooldownMapping:
|
||||
def __init__(
|
||||
@ -199,7 +201,7 @@ class CooldownMapping:
|
||||
type: Callable[[Message], Any],
|
||||
) -> None:
|
||||
if not callable(type):
|
||||
raise TypeError('Cooldown type must be a BucketType or callable')
|
||||
raise TypeError("Cooldown type must be a BucketType or callable")
|
||||
|
||||
self._cache: Dict[Any, Cooldown] = {}
|
||||
self._cooldown: Optional[Cooldown] = original
|
||||
@ -256,13 +258,9 @@ class CooldownMapping:
|
||||
bucket = self.get_bucket(message, current)
|
||||
return bucket.update_rate_limit(current)
|
||||
|
||||
class DynamicCooldownMapping(CooldownMapping):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
factory: Callable[[Message], Cooldown],
|
||||
type: Callable[[Message], Any]
|
||||
) -> None:
|
||||
class DynamicCooldownMapping(CooldownMapping):
|
||||
def __init__(self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]) -> None:
|
||||
super().__init__(None, type)
|
||||
self._factory: Callable[[Message], Cooldown] = factory
|
||||
|
||||
@ -278,6 +276,7 @@ class DynamicCooldownMapping(CooldownMapping):
|
||||
def create_bucket(self, message: Message) -> Cooldown:
|
||||
return self._factory(message)
|
||||
|
||||
|
||||
class _Semaphore:
|
||||
"""This class is a version of a semaphore.
|
||||
|
||||
@ -291,7 +290,7 @@ class _Semaphore:
|
||||
overkill for what is basically a counter.
|
||||
"""
|
||||
|
||||
__slots__ = ('value', 'loop', '_waiters')
|
||||
__slots__ = ("value", "loop", "_waiters")
|
||||
|
||||
def __init__(self, number: int) -> None:
|
||||
self.value: int = number
|
||||
@ -299,7 +298,7 @@ class _Semaphore:
|
||||
self._waiters: Deque[asyncio.Future] = deque()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>'
|
||||
return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>"
|
||||
|
||||
def locked(self) -> bool:
|
||||
return self.value == 0
|
||||
@ -337,8 +336,9 @@ class _Semaphore:
|
||||
self.value += 1
|
||||
self.wake_up()
|
||||
|
||||
|
||||
class MaxConcurrency:
|
||||
__slots__ = ('number', 'per', 'wait', '_mapping')
|
||||
__slots__ = ("number", "per", "wait", "_mapping")
|
||||
|
||||
def __init__(self, number: int, *, per: BucketType, wait: bool) -> None:
|
||||
self._mapping: Dict[Any, _Semaphore] = {}
|
||||
@ -347,16 +347,16 @@ class MaxConcurrency:
|
||||
self.wait: bool = wait
|
||||
|
||||
if number <= 0:
|
||||
raise ValueError('max_concurrency \'number\' cannot be less than 1')
|
||||
raise ValueError("max_concurrency 'number' cannot be less than 1")
|
||||
|
||||
if not isinstance(per, BucketType):
|
||||
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}')
|
||||
raise TypeError(f"max_concurrency 'per' must be of type BucketType not {type(per)!r}")
|
||||
|
||||
def copy(self: MC) -> MC:
|
||||
return self.__class__(self.number, per=self.per, wait=self.wait)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>'
|
||||
return f"<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>"
|
||||
|
||||
def get_key(self, message: Message) -> Any:
|
||||
return self.per.get_key(message)
|
||||
|
@ -70,52 +70,53 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
__all__ = (
|
||||
'Command',
|
||||
'Group',
|
||||
'GroupMixin',
|
||||
'command',
|
||||
'group',
|
||||
'has_role',
|
||||
'has_permissions',
|
||||
'has_any_role',
|
||||
'check',
|
||||
'check_any',
|
||||
'before_invoke',
|
||||
'after_invoke',
|
||||
'bot_has_role',
|
||||
'bot_has_permissions',
|
||||
'bot_has_any_role',
|
||||
'cooldown',
|
||||
'dynamic_cooldown',
|
||||
'max_concurrency',
|
||||
'dm_only',
|
||||
'guild_only',
|
||||
'is_owner',
|
||||
'is_nsfw',
|
||||
'has_guild_permissions',
|
||||
'bot_has_guild_permissions'
|
||||
"Command",
|
||||
"Group",
|
||||
"GroupMixin",
|
||||
"command",
|
||||
"group",
|
||||
"has_role",
|
||||
"has_permissions",
|
||||
"has_any_role",
|
||||
"check",
|
||||
"check_any",
|
||||
"before_invoke",
|
||||
"after_invoke",
|
||||
"bot_has_role",
|
||||
"bot_has_permissions",
|
||||
"bot_has_any_role",
|
||||
"cooldown",
|
||||
"dynamic_cooldown",
|
||||
"max_concurrency",
|
||||
"dm_only",
|
||||
"guild_only",
|
||||
"is_owner",
|
||||
"is_nsfw",
|
||||
"has_guild_permissions",
|
||||
"bot_has_guild_permissions",
|
||||
)
|
||||
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
T = TypeVar('T')
|
||||
CogT = TypeVar('CogT', bound='Cog')
|
||||
CommandT = TypeVar('CommandT', bound='Command')
|
||||
ContextT = TypeVar('ContextT', bound='Context')
|
||||
T = TypeVar("T")
|
||||
CogT = TypeVar("CogT", bound="Cog")
|
||||
CommandT = TypeVar("CommandT", bound="Command")
|
||||
ContextT = TypeVar("ContextT", bound="Context")
|
||||
# CHT = TypeVar('CHT', bound='Check')
|
||||
GroupT = TypeVar('GroupT', bound='Group')
|
||||
HookT = TypeVar('HookT', bound='Hook')
|
||||
ErrorT = TypeVar('ErrorT', bound='Error')
|
||||
GroupT = TypeVar("GroupT", bound="Group")
|
||||
HookT = TypeVar("HookT", bound="Hook")
|
||||
ErrorT = TypeVar("ErrorT", bound="Error")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
P = ParamSpec('P')
|
||||
P = ParamSpec("P")
|
||||
else:
|
||||
P = TypeVar('P')
|
||||
P = TypeVar("P")
|
||||
|
||||
|
||||
def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
|
||||
partial = functools.partial
|
||||
while True:
|
||||
if hasattr(function, '__wrapped__'):
|
||||
if hasattr(function, "__wrapped__"):
|
||||
function = function.__wrapped__
|
||||
elif isinstance(function, partial):
|
||||
function = function.func
|
||||
@ -139,7 +140,7 @@ def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, A
|
||||
|
||||
annotation = eval_annotation(annotation, globalns, globalns, cache)
|
||||
if annotation is Greedy:
|
||||
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
|
||||
raise TypeError("Unparameterized Greedy[...] is disallowed in signature.")
|
||||
|
||||
params[name] = parameter.replace(annotation=annotation)
|
||||
|
||||
@ -158,8 +159,10 @@ def wrap_callback(coro):
|
||||
except Exception as exc:
|
||||
raise CommandInvokeError(exc) from exc
|
||||
return ret
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def hooked_wrapped_callback(command, ctx, coro):
|
||||
@functools.wraps(coro)
|
||||
async def wrapped(*args, **kwargs):
|
||||
@ -180,6 +183,7 @@ def hooked_wrapped_callback(command, ctx, coro):
|
||||
|
||||
await command.call_after_hooks(ctx)
|
||||
return ret
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
@ -202,6 +206,7 @@ class _CaseInsensitiveDict(dict):
|
||||
def __setitem__(self, k, v):
|
||||
super().__setitem__(k.casefold(), v)
|
||||
|
||||
|
||||
class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
r"""A class that implements the protocol for a bot text command.
|
||||
|
||||
@ -269,8 +274,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
which calls converters. If ``False`` then cooldown processing is done
|
||||
first and then the converters are called second. Defaults to ``False``.
|
||||
extras: :class:`dict`
|
||||
A dict of user provided extras to attach to the Command.
|
||||
|
||||
A dict of user provided extras to attach to the Command.
|
||||
|
||||
.. note::
|
||||
This object may be copied by the library.
|
||||
|
||||
@ -295,56 +300,60 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
self.__original_kwargs__ = kwargs.copy()
|
||||
return self
|
||||
|
||||
def __init__(self, func: Union[
|
||||
def __init__(
|
||||
self,
|
||||
func: Union[
|
||||
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
|
||||
Callable[Concatenate[ContextT, P], Coro[T]],
|
||||
], **kwargs: Any):
|
||||
],
|
||||
**kwargs: Any,
|
||||
):
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError('Callback must be a coroutine.')
|
||||
raise TypeError("Callback must be a coroutine.")
|
||||
|
||||
name = kwargs.get('name') or func.__name__
|
||||
name = kwargs.get("name") or func.__name__
|
||||
if not isinstance(name, str):
|
||||
raise TypeError('Name of a command must be a string.')
|
||||
raise TypeError("Name of a command must be a string.")
|
||||
self.name: str = name
|
||||
|
||||
self.callback = func
|
||||
self.enabled: bool = kwargs.get('enabled', True)
|
||||
self.enabled: bool = kwargs.get("enabled", True)
|
||||
|
||||
help_doc = kwargs.get('help')
|
||||
help_doc = kwargs.get("help")
|
||||
if help_doc is not None:
|
||||
help_doc = inspect.cleandoc(help_doc)
|
||||
else:
|
||||
help_doc = inspect.getdoc(func)
|
||||
if isinstance(help_doc, bytes):
|
||||
help_doc = help_doc.decode('utf-8')
|
||||
help_doc = help_doc.decode("utf-8")
|
||||
|
||||
self.help: Optional[str] = help_doc
|
||||
|
||||
self.brief: Optional[str] = kwargs.get('brief')
|
||||
self.usage: Optional[str] = kwargs.get('usage')
|
||||
self.rest_is_raw: bool = kwargs.get('rest_is_raw', False)
|
||||
self.aliases: Union[List[str], Tuple[str]] = kwargs.get('aliases', [])
|
||||
self.extras: Dict[str, Any] = kwargs.get('extras', {})
|
||||
self.brief: Optional[str] = kwargs.get("brief")
|
||||
self.usage: Optional[str] = kwargs.get("usage")
|
||||
self.rest_is_raw: bool = kwargs.get("rest_is_raw", False)
|
||||
self.aliases: Union[List[str], Tuple[str]] = kwargs.get("aliases", [])
|
||||
self.extras: Dict[str, Any] = kwargs.get("extras", {})
|
||||
|
||||
if not isinstance(self.aliases, (list, tuple)):
|
||||
raise TypeError("Aliases of a command must be a list or a tuple of strings.")
|
||||
|
||||
self.description: str = inspect.cleandoc(kwargs.get('description', ''))
|
||||
self.hidden: bool = kwargs.get('hidden', False)
|
||||
self.description: str = inspect.cleandoc(kwargs.get("description", ""))
|
||||
self.hidden: bool = kwargs.get("hidden", False)
|
||||
|
||||
try:
|
||||
checks = func.__commands_checks__
|
||||
checks.reverse()
|
||||
except AttributeError:
|
||||
checks = kwargs.get('checks', [])
|
||||
checks = kwargs.get("checks", [])
|
||||
|
||||
self.checks: List[Check] = checks
|
||||
|
||||
try:
|
||||
cooldown = func.__commands_cooldown__
|
||||
except AttributeError:
|
||||
cooldown = kwargs.get('cooldown')
|
||||
|
||||
cooldown = kwargs.get("cooldown")
|
||||
|
||||
if cooldown is None:
|
||||
buckets = CooldownMapping(cooldown, BucketType.default)
|
||||
elif isinstance(cooldown, CooldownMapping):
|
||||
@ -356,17 +365,17 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
try:
|
||||
max_concurrency = func.__commands_max_concurrency__
|
||||
except AttributeError:
|
||||
max_concurrency = kwargs.get('max_concurrency')
|
||||
max_concurrency = kwargs.get("max_concurrency")
|
||||
|
||||
self._max_concurrency: Optional[MaxConcurrency] = max_concurrency
|
||||
|
||||
self.require_var_positional: bool = kwargs.get('require_var_positional', False)
|
||||
self.ignore_extra: bool = kwargs.get('ignore_extra', True)
|
||||
self.cooldown_after_parsing: bool = kwargs.get('cooldown_after_parsing', False)
|
||||
self.require_var_positional: bool = kwargs.get("require_var_positional", False)
|
||||
self.ignore_extra: bool = kwargs.get("ignore_extra", True)
|
||||
self.cooldown_after_parsing: bool = kwargs.get("cooldown_after_parsing", False)
|
||||
self.cog: Optional[CogT] = None
|
||||
|
||||
# bandaid for the fact that sometimes parent can be the bot instance
|
||||
parent = kwargs.get('parent')
|
||||
parent = kwargs.get("parent")
|
||||
self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore
|
||||
|
||||
self._before_invoke: Optional[Hook] = None
|
||||
@ -386,17 +395,19 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
self.after_invoke(after_invoke)
|
||||
|
||||
@property
|
||||
def callback(self) -> Union[
|
||||
Callable[Concatenate[CogT, Context, P], Coro[T]],
|
||||
Callable[Concatenate[Context, P], Coro[T]],
|
||||
]:
|
||||
def callback(
|
||||
self,
|
||||
) -> Union[Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]],]:
|
||||
return self._callback
|
||||
|
||||
@callback.setter
|
||||
def callback(self, function: Union[
|
||||
def callback(
|
||||
self,
|
||||
function: Union[
|
||||
Callable[Concatenate[CogT, Context, P], Coro[T]],
|
||||
Callable[Concatenate[Context, P], Coro[T]],
|
||||
]) -> None:
|
||||
],
|
||||
) -> None:
|
||||
self._callback = function
|
||||
unwrap = unwrap_function(function)
|
||||
self.module = unwrap.__module__
|
||||
@ -527,7 +538,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
wrapped = wrap_callback(local)
|
||||
await wrapped(ctx, error)
|
||||
finally:
|
||||
ctx.bot.dispatch('command_error', ctx, error)
|
||||
ctx.bot.dispatch("command_error", ctx, error)
|
||||
|
||||
async def transform(self, ctx: Context, param: inspect.Parameter) -> Any:
|
||||
required = param.default is param.empty
|
||||
@ -551,11 +562,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
|
||||
if view.eof:
|
||||
if param.kind == param.VAR_POSITIONAL:
|
||||
raise RuntimeError() # break the loop
|
||||
raise RuntimeError() # break the loop
|
||||
if required:
|
||||
if self._is_typing_optional(param.annotation):
|
||||
return None
|
||||
if hasattr(converter, '__commands_is_flag__') and converter._can_be_constructible():
|
||||
if hasattr(converter, "__commands_is_flag__") and converter._can_be_constructible():
|
||||
return await converter._construct_default(ctx)
|
||||
raise MissingRequiredArgument(param)
|
||||
return param.default
|
||||
@ -577,7 +588,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
# type-checker fails to narrow argument
|
||||
return await run_converters(ctx, converter, argument, param) # type: ignore
|
||||
|
||||
async def _transform_greedy_pos(self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any) -> Any:
|
||||
async def _transform_greedy_pos(
|
||||
self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any
|
||||
) -> Any:
|
||||
view = ctx.view
|
||||
result = []
|
||||
while not view.eof:
|
||||
@ -606,7 +619,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
value = await run_converters(ctx, converter, argument, param) # type: ignore
|
||||
except (CommandError, ArgumentParsingError):
|
||||
view.index = previous
|
||||
raise RuntimeError() from None # break loop
|
||||
raise RuntimeError() from None # break loop
|
||||
else:
|
||||
return value
|
||||
|
||||
@ -643,11 +656,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
entries = []
|
||||
command = self
|
||||
# command.parent is type-hinted as GroupMixin some attributes are resolved via MRO
|
||||
while command.parent is not None: # type: ignore
|
||||
command = command.parent # type: ignore
|
||||
entries.append(command.name) # type: ignore
|
||||
while command.parent is not None: # type: ignore
|
||||
command = command.parent # type: ignore
|
||||
entries.append(command.name) # type: ignore
|
||||
|
||||
return ' '.join(reversed(entries))
|
||||
return " ".join(reversed(entries))
|
||||
|
||||
@property
|
||||
def parents(self) -> List[Group]:
|
||||
@ -661,8 +674,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
"""
|
||||
entries = []
|
||||
command = self
|
||||
while command.parent is not None: # type: ignore
|
||||
command = command.parent # type: ignore
|
||||
while command.parent is not None: # type: ignore
|
||||
command = command.parent # type: ignore
|
||||
entries.append(command)
|
||||
|
||||
return entries
|
||||
@ -690,7 +703,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
|
||||
parent = self.full_parent_name
|
||||
if parent:
|
||||
return parent + ' ' + self.name
|
||||
return parent + " " + self.name
|
||||
else:
|
||||
return self.name
|
||||
|
||||
@ -745,7 +758,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
break
|
||||
|
||||
if not self.ignore_extra and not view.eof:
|
||||
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
|
||||
raise TooManyArguments("Too many arguments passed to " + self.qualified_name)
|
||||
|
||||
async def call_before_hooks(self, ctx: Context) -> None:
|
||||
# now that we're done preparing we can call the pre-command hooks
|
||||
@ -753,7 +766,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
cog = self.cog
|
||||
if self._before_invoke is not None:
|
||||
# should be cog if @commands.before_invoke is used
|
||||
instance = getattr(self._before_invoke, '__self__', cog)
|
||||
instance = getattr(self._before_invoke, "__self__", cog)
|
||||
# __self__ only exists for methods, not functions
|
||||
# however, if @command.before_invoke is used, it will be a function
|
||||
if instance:
|
||||
@ -775,7 +788,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
async def call_after_hooks(self, ctx: Context) -> None:
|
||||
cog = self.cog
|
||||
if self._after_invoke is not None:
|
||||
instance = getattr(self._after_invoke, '__self__', cog)
|
||||
instance = getattr(self._after_invoke, "__self__", cog)
|
||||
if instance:
|
||||
await self._after_invoke(instance, ctx) # type: ignore
|
||||
else:
|
||||
@ -805,7 +818,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
ctx.command = self
|
||||
|
||||
if not await self.can_run(ctx):
|
||||
raise CheckFailure(f'The check functions for command {self.qualified_name} failed.')
|
||||
raise CheckFailure(f"The check functions for command {self.qualified_name} failed.")
|
||||
|
||||
if self._max_concurrency is not None:
|
||||
# For this application, context can be duck-typed as a Message
|
||||
@ -929,7 +942,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
"""
|
||||
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise TypeError('The error handler must be a coroutine.')
|
||||
raise TypeError("The error handler must be a coroutine.")
|
||||
|
||||
self.on_error: Error = coro
|
||||
return coro
|
||||
@ -939,7 +952,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
return hasattr(self, 'on_error')
|
||||
return hasattr(self, "on_error")
|
||||
|
||||
def before_invoke(self, coro: HookT) -> HookT:
|
||||
"""A decorator that registers a coroutine as a pre-invoke hook.
|
||||
@ -963,7 +976,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
The coroutine passed is not actually a coroutine.
|
||||
"""
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise TypeError('The pre-invoke hook must be a coroutine.')
|
||||
raise TypeError("The pre-invoke hook must be a coroutine.")
|
||||
|
||||
self._before_invoke = coro
|
||||
return coro
|
||||
@ -990,7 +1003,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
The coroutine passed is not actually a coroutine.
|
||||
"""
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise TypeError('The post-invoke hook must be a coroutine.')
|
||||
raise TypeError("The post-invoke hook must be a coroutine.")
|
||||
|
||||
self._after_invoke = coro
|
||||
return coro
|
||||
@ -1011,11 +1024,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
if self.brief is not None:
|
||||
return self.brief
|
||||
if self.help is not None:
|
||||
return self.help.split('\n', 1)[0]
|
||||
return ''
|
||||
return self.help.split("\n", 1)[0]
|
||||
return ""
|
||||
|
||||
def _is_typing_optional(self, annotation: Union[T, Optional[T]]) -> TypeGuard[Optional[T]]:
|
||||
return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__ # type: ignore
|
||||
return getattr(annotation, "__origin__", None) is Union and type(None) in annotation.__args__ # type: ignore
|
||||
|
||||
@property
|
||||
def signature(self) -> str:
|
||||
@ -1025,7 +1038,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
|
||||
params = self.clean_params
|
||||
if not params:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
result = []
|
||||
for name, param in params.items():
|
||||
@ -1035,41 +1048,40 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
|
||||
# parameter signature is a literal list of it's values
|
||||
annotation = param.annotation.converter if greedy else param.annotation
|
||||
origin = getattr(annotation, '__origin__', None)
|
||||
origin = getattr(annotation, "__origin__", None)
|
||||
if not greedy and origin is Union:
|
||||
none_cls = type(None)
|
||||
union_args = annotation.__args__
|
||||
optional = union_args[-1] is none_cls
|
||||
if len(union_args) == 2 and optional:
|
||||
annotation = union_args[0]
|
||||
origin = getattr(annotation, '__origin__', None)
|
||||
origin = getattr(annotation, "__origin__", None)
|
||||
|
||||
if origin is Literal:
|
||||
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
|
||||
name = "|".join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
|
||||
if param.default is not param.empty:
|
||||
# We don't want None or '' to trigger the [name=value] case and instead it should
|
||||
# do [name] since [name=None] or [name=] are not exactly useful for the user.
|
||||
should_print = param.default if isinstance(param.default, str) else param.default is not None
|
||||
if should_print:
|
||||
result.append(f'[{name}={param.default}]' if not greedy else
|
||||
f'[{name}={param.default}]...')
|
||||
result.append(f"[{name}={param.default}]" if not greedy else f"[{name}={param.default}]...")
|
||||
continue
|
||||
else:
|
||||
result.append(f'[{name}]')
|
||||
result.append(f"[{name}]")
|
||||
|
||||
elif param.kind == param.VAR_POSITIONAL:
|
||||
if self.require_var_positional:
|
||||
result.append(f'<{name}...>')
|
||||
result.append(f"<{name}...>")
|
||||
else:
|
||||
result.append(f'[{name}...]')
|
||||
result.append(f"[{name}...]")
|
||||
elif greedy:
|
||||
result.append(f'[{name}]...')
|
||||
result.append(f"[{name}]...")
|
||||
elif optional:
|
||||
result.append(f'[{name}]')
|
||||
result.append(f"[{name}]")
|
||||
else:
|
||||
result.append(f'<{name}>')
|
||||
result.append(f"<{name}>")
|
||||
|
||||
return ' '.join(result)
|
||||
return " ".join(result)
|
||||
|
||||
async def can_run(self, ctx: Context) -> bool:
|
||||
"""|coro|
|
||||
@ -1099,14 +1111,14 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
"""
|
||||
|
||||
if not self.enabled:
|
||||
raise DisabledCommand(f'{self.name} command is disabled')
|
||||
raise DisabledCommand(f"{self.name} command is disabled")
|
||||
|
||||
original = ctx.command
|
||||
ctx.command = self
|
||||
|
||||
try:
|
||||
if not await ctx.bot.can_run(ctx):
|
||||
raise CheckFailure(f'The global check functions for command {self.qualified_name} failed.')
|
||||
raise CheckFailure(f"The global check functions for command {self.qualified_name} failed.")
|
||||
|
||||
cog = self.cog
|
||||
if cog is not None:
|
||||
@ -1125,6 +1137,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
finally:
|
||||
ctx.command = original
|
||||
|
||||
|
||||
class GroupMixin(Generic[CogT]):
|
||||
"""A mixin that implements common functionality for classes that behave
|
||||
similar to :class:`.Group` and are allowed to register commands.
|
||||
@ -1137,8 +1150,9 @@ class GroupMixin(Generic[CogT]):
|
||||
case_insensitive: :class:`bool`
|
||||
Whether the commands should be case insensitive. Defaults to ``True``.
|
||||
"""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
case_insensitive = kwargs.get('case_insensitive', True)
|
||||
case_insensitive = kwargs.get("case_insensitive", True)
|
||||
self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {}
|
||||
self.case_insensitive: bool = case_insensitive
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -1177,7 +1191,7 @@ class GroupMixin(Generic[CogT]):
|
||||
"""
|
||||
|
||||
if not isinstance(command, Command):
|
||||
raise TypeError('The command passed must be a subclass of Command')
|
||||
raise TypeError("The command passed must be a subclass of Command")
|
||||
|
||||
if isinstance(self, Command):
|
||||
command.parent = self
|
||||
@ -1267,7 +1281,7 @@ class GroupMixin(Generic[CogT]):
|
||||
"""
|
||||
|
||||
# fast path, no space in name.
|
||||
if ' ' not in name:
|
||||
if " " not in name:
|
||||
return self.all_commands.get(name)
|
||||
|
||||
names = name.split()
|
||||
@ -1298,7 +1312,9 @@ class GroupMixin(Generic[CogT]):
|
||||
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
|
||||
Callable[Concatenate[ContextT, P], Coro[T]],
|
||||
]
|
||||
], Command[CogT, P, T]]:
|
||||
],
|
||||
Command[CogT, P, T],
|
||||
]:
|
||||
...
|
||||
|
||||
@overload
|
||||
@ -1326,8 +1342,9 @@ class GroupMixin(Generic[CogT]):
|
||||
Callable[..., :class:`Command`]
|
||||
A decorator that converts the provided method into a Command, adds it to the bot, then returns it.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> CommandT:
|
||||
kwargs.setdefault('parent', self)
|
||||
kwargs.setdefault("parent", self)
|
||||
result = command(name=name, cls=cls, *args, **kwargs)(func)
|
||||
self.add_command(result)
|
||||
return result
|
||||
@ -1341,12 +1358,10 @@ class GroupMixin(Generic[CogT]):
|
||||
cls: Type[Group[CogT, P, T]] = ...,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Callable[[
|
||||
Union[
|
||||
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
|
||||
Callable[Concatenate[ContextT, P], Coro[T]]
|
||||
]
|
||||
], Group[CogT, P, T]]:
|
||||
) -> Callable[
|
||||
[Union[Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]]]],
|
||||
Group[CogT, P, T],
|
||||
]:
|
||||
...
|
||||
|
||||
@overload
|
||||
@ -1374,14 +1389,16 @@ class GroupMixin(Generic[CogT]):
|
||||
Callable[..., :class:`Group`]
|
||||
A decorator that converts the provided method into a Group, adds it to the bot, then returns it.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> GroupT:
|
||||
kwargs.setdefault('parent', self)
|
||||
kwargs.setdefault("parent", self)
|
||||
result = group(name=name, cls=cls, *args, **kwargs)(func)
|
||||
self.add_command(result)
|
||||
return result
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class Group(GroupMixin[CogT], Command[CogT, P, T]):
|
||||
"""A class that implements a grouping protocol for commands to be
|
||||
executed as subcommands.
|
||||
@ -1404,8 +1421,9 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
|
||||
Indicates if the group's commands should be case insensitive.
|
||||
Defaults to ``False``.
|
||||
"""
|
||||
|
||||
def __init__(self, *args: Any, **attrs: Any) -> None:
|
||||
self.invoke_without_command: bool = attrs.pop('invoke_without_command', False)
|
||||
self.invoke_without_command: bool = attrs.pop("invoke_without_command", False)
|
||||
super().__init__(*args, **attrs)
|
||||
|
||||
def copy(self: GroupT) -> GroupT:
|
||||
@ -1492,8 +1510,10 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
|
||||
view.previous = previous
|
||||
await super().reinvoke(ctx, call_hooks=call_hooks)
|
||||
|
||||
|
||||
# Decorators
|
||||
|
||||
|
||||
@overload
|
||||
def command(
|
||||
name: str = ...,
|
||||
@ -1505,10 +1525,12 @@ def command(
|
||||
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
|
||||
Callable[Concatenate[ContextT, P], Coro[T]],
|
||||
]
|
||||
]
|
||||
, Command[CogT, P, T]]:
|
||||
],
|
||||
Command[CogT, P, T],
|
||||
]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def command(
|
||||
name: str = ...,
|
||||
@ -1520,22 +1542,23 @@ def command(
|
||||
Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
|
||||
Callable[Concatenate[ContextT, P], Coro[Any]],
|
||||
]
|
||||
]
|
||||
, CommandT]:
|
||||
],
|
||||
CommandT,
|
||||
]:
|
||||
...
|
||||
|
||||
|
||||
def command(
|
||||
name: str = MISSING,
|
||||
cls: Type[CommandT] = MISSING,
|
||||
**attrs: Any
|
||||
name: str = MISSING, cls: Type[CommandT] = MISSING, **attrs: Any
|
||||
) -> Callable[
|
||||
[
|
||||
Union[
|
||||
Callable[Concatenate[ContextT, P], Coro[Any]],
|
||||
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
|
||||
]
|
||||
]
|
||||
, Union[Command[CogT, P, T], CommandT]]:
|
||||
],
|
||||
Union[Command[CogT, P, T], CommandT],
|
||||
]:
|
||||
"""A decorator that transforms a function into a :class:`.Command`
|
||||
or if called with :func:`.group`, :class:`.Group`.
|
||||
|
||||
@ -1568,16 +1591,19 @@ def command(
|
||||
if cls is MISSING:
|
||||
cls = Command # type: ignore
|
||||
|
||||
def decorator(func: Union[
|
||||
def decorator(
|
||||
func: Union[
|
||||
Callable[Concatenate[ContextT, P], Coro[Any]],
|
||||
Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
|
||||
]) -> CommandT:
|
||||
]
|
||||
) -> CommandT:
|
||||
if isinstance(func, Command):
|
||||
raise TypeError('Callback is already a command.')
|
||||
raise TypeError("Callback is already a command.")
|
||||
return cls(func, name=name, **attrs)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@overload
|
||||
def group(
|
||||
name: str = ...,
|
||||
@ -1589,10 +1615,12 @@ def group(
|
||||
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
|
||||
Callable[Concatenate[ContextT, P], Coro[T]],
|
||||
]
|
||||
]
|
||||
, Group[CogT, P, T]]:
|
||||
],
|
||||
Group[CogT, P, T],
|
||||
]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def group(
|
||||
name: str = ...,
|
||||
@ -1604,10 +1632,12 @@ def group(
|
||||
Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
|
||||
Callable[Concatenate[ContextT, P], Coro[Any]],
|
||||
]
|
||||
]
|
||||
, GroupT]:
|
||||
],
|
||||
GroupT,
|
||||
]:
|
||||
...
|
||||
|
||||
|
||||
def group(
|
||||
name: str = MISSING,
|
||||
cls: Type[GroupT] = MISSING,
|
||||
@ -1618,8 +1648,9 @@ def group(
|
||||
Callable[Concatenate[ContextT, P], Coro[Any]],
|
||||
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
|
||||
]
|
||||
]
|
||||
, Union[Group[CogT, P, T], GroupT]]:
|
||||
],
|
||||
Union[Group[CogT, P, T], GroupT],
|
||||
]:
|
||||
"""A decorator that transforms a function into a :class:`.Group`.
|
||||
|
||||
This is similar to the :func:`.command` decorator but the ``cls``
|
||||
@ -1632,6 +1663,7 @@ def group(
|
||||
cls = Group # type: ignore
|
||||
return command(name=name, cls=cls, **attrs) # type: ignore
|
||||
|
||||
|
||||
def check(predicate: Check) -> Callable[[T], T]:
|
||||
r"""A decorator that adds a check to the :class:`.Command` or its
|
||||
subclasses. These checks could be accessed via :attr:`.Command.checks`.
|
||||
@ -1707,7 +1739,7 @@ def check(predicate: Check) -> Callable[[T], T]:
|
||||
if isinstance(func, Command):
|
||||
func.checks.append(predicate)
|
||||
else:
|
||||
if not hasattr(func, '__commands_checks__'):
|
||||
if not hasattr(func, "__commands_checks__"):
|
||||
func.__commands_checks__ = []
|
||||
|
||||
func.__commands_checks__.append(predicate)
|
||||
@ -1717,13 +1749,16 @@ def check(predicate: Check) -> Callable[[T], T]:
|
||||
if inspect.iscoroutinefunction(predicate):
|
||||
decorator.predicate = predicate
|
||||
else:
|
||||
|
||||
@functools.wraps(predicate)
|
||||
async def wrapper(ctx):
|
||||
return predicate(ctx) # type: ignore
|
||||
|
||||
decorator.predicate = wrapper
|
||||
|
||||
return decorator # type: ignore
|
||||
|
||||
|
||||
def check_any(*checks: Check) -> Callable[[T], T]:
|
||||
r"""A :func:`check` that is added that checks if any of the checks passed
|
||||
will pass, i.e. using logical OR.
|
||||
@ -1773,7 +1808,7 @@ def check_any(*checks: Check) -> Callable[[T], T]:
|
||||
try:
|
||||
pred = wrapped.predicate
|
||||
except AttributeError:
|
||||
raise TypeError(f'{wrapped!r} must be wrapped by commands.check decorator') from None
|
||||
raise TypeError(f"{wrapped!r} must be wrapped by commands.check decorator") from None
|
||||
else:
|
||||
unwrapped.append(pred)
|
||||
|
||||
@ -1792,6 +1827,7 @@ def check_any(*checks: Check) -> Callable[[T], T]:
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def has_role(item: Union[int, str]) -> Callable[[T], T]:
|
||||
"""A :func:`.check` that is added that checks if the member invoking the
|
||||
command has the role specified via the name or ID specified.
|
||||
@ -1834,6 +1870,7 @@ def has_role(item: Union[int, str]) -> Callable[[T], T]:
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
|
||||
r"""A :func:`.check` that is added that checks if the member invoking the
|
||||
command has **any** of the roles specified. This means that if they have
|
||||
@ -1865,18 +1902,22 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
|
||||
async def cool(ctx):
|
||||
await ctx.send('You are cool indeed')
|
||||
"""
|
||||
|
||||
def predicate(ctx):
|
||||
if ctx.guild is None:
|
||||
raise NoPrivateMessage()
|
||||
|
||||
# ctx.guild is None doesn't narrow ctx.author to Member
|
||||
getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore
|
||||
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
|
||||
if any(
|
||||
getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items
|
||||
):
|
||||
return True
|
||||
raise MissingAnyRole(list(items))
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def bot_has_role(item: int) -> Callable[[T], T]:
|
||||
"""Similar to :func:`.has_role` except checks if the bot itself has the
|
||||
role.
|
||||
@ -1903,8 +1944,10 @@ def bot_has_role(item: int) -> Callable[[T], T]:
|
||||
if role is None:
|
||||
raise BotMissingRole(item)
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def bot_has_any_role(*items: int) -> Callable[[T], T]:
|
||||
"""Similar to :func:`.has_any_role` except checks if the bot itself has
|
||||
any of the roles listed.
|
||||
@ -1918,17 +1961,22 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
|
||||
Raise :exc:`.BotMissingAnyRole` or :exc:`.NoPrivateMessage`
|
||||
instead of generic checkfailure
|
||||
"""
|
||||
|
||||
def predicate(ctx):
|
||||
if ctx.guild is None:
|
||||
raise NoPrivateMessage()
|
||||
|
||||
me = ctx.me
|
||||
getter = functools.partial(discord.utils.get, me.roles)
|
||||
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
|
||||
if any(
|
||||
getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items
|
||||
):
|
||||
return True
|
||||
raise BotMissingAnyRole(list(items))
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def has_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
"""A :func:`.check` that is added that checks if the member has all of
|
||||
the permissions necessary.
|
||||
@ -1976,6 +2024,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
"""Similar to :func:`.has_permissions` except checks if the bot itself has
|
||||
the permissions listed.
|
||||
@ -2002,6 +2051,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
"""Similar to :func:`.has_permissions`, but operates on guild wide
|
||||
permissions instead of the current channel permissions.
|
||||
@ -2030,6 +2080,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
"""Similar to :func:`.has_guild_permissions`, but checks the bot
|
||||
members guild permissions.
|
||||
@ -2055,6 +2106,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def dm_only() -> Callable[[T], T]:
|
||||
"""A :func:`.check` that indicates this command must only be used in a
|
||||
DM context. Only private messages are allowed when
|
||||
@ -2073,6 +2125,7 @@ def dm_only() -> Callable[[T], T]:
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def guild_only() -> Callable[[T], T]:
|
||||
"""A :func:`.check` that indicates this command must only be used in a
|
||||
guild context only. Basically, no private messages are allowed when
|
||||
@ -2089,6 +2142,7 @@ def guild_only() -> Callable[[T], T]:
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def is_owner() -> Callable[[T], T]:
|
||||
"""A :func:`.check` that checks if the person invoking this command is the
|
||||
owner of the bot.
|
||||
@ -2101,11 +2155,12 @@ def is_owner() -> Callable[[T], T]:
|
||||
|
||||
async def predicate(ctx: Context) -> bool:
|
||||
if not await ctx.bot.is_owner(ctx.author):
|
||||
raise NotOwner('You do not own this bot.')
|
||||
raise NotOwner("You do not own this bot.")
|
||||
return True
|
||||
|
||||
return check(predicate)
|
||||
|
||||
|
||||
def is_nsfw() -> Callable[[T], T]:
|
||||
"""A :func:`.check` that checks if the channel is a NSFW channel.
|
||||
|
||||
@ -2117,14 +2172,19 @@ def is_nsfw() -> Callable[[T], T]:
|
||||
Raise :exc:`.NSFWChannelRequired` instead of generic :exc:`.CheckFailure`.
|
||||
DM channels will also now pass this check.
|
||||
"""
|
||||
|
||||
def pred(ctx: Context) -> bool:
|
||||
ch = ctx.channel
|
||||
if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()):
|
||||
return True
|
||||
raise NSFWChannelRequired(ch) # type: ignore
|
||||
|
||||
return check(pred)
|
||||
|
||||
def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default) -> Callable[[T], T]:
|
||||
|
||||
def cooldown(
|
||||
rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default
|
||||
) -> Callable[[T], T]:
|
||||
"""A decorator that adds a cooldown to a :class:`.Command`
|
||||
|
||||
A cooldown allows a command to only be used a specific amount
|
||||
@ -2157,9 +2217,13 @@ def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message],
|
||||
else:
|
||||
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type)
|
||||
return func
|
||||
|
||||
return decorator # type: ignore
|
||||
|
||||
def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default) -> Callable[[T], T]:
|
||||
|
||||
def dynamic_cooldown(
|
||||
cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default
|
||||
) -> Callable[[T], T]:
|
||||
"""A decorator that adds a dynamic cooldown to a :class:`.Command`
|
||||
|
||||
This differs from :func:`.cooldown` in that it takes a function that
|
||||
@ -2197,8 +2261,10 @@ def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type
|
||||
else:
|
||||
func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type)
|
||||
return func
|
||||
|
||||
return decorator # type: ignore
|
||||
|
||||
|
||||
def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]:
|
||||
"""A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses.
|
||||
|
||||
@ -2230,8 +2296,10 @@ def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait:
|
||||
else:
|
||||
func.__commands_max_concurrency__ = value
|
||||
return func
|
||||
|
||||
return decorator # type: ignore
|
||||
|
||||
|
||||
def before_invoke(coro) -> Callable[[T], T]:
|
||||
"""A decorator that registers a coroutine as a pre-invoke hook.
|
||||
|
||||
@ -2270,14 +2338,17 @@ def before_invoke(coro) -> Callable[[T], T]:
|
||||
|
||||
bot.add_cog(What())
|
||||
"""
|
||||
|
||||
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
|
||||
if isinstance(func, Command):
|
||||
func.before_invoke(coro)
|
||||
else:
|
||||
func.__before_invoke__ = coro
|
||||
return func
|
||||
|
||||
return decorator # type: ignore
|
||||
|
||||
|
||||
def after_invoke(coro) -> Callable[[T], T]:
|
||||
"""A decorator that registers a coroutine as a post-invoke hook.
|
||||
|
||||
@ -2286,10 +2357,12 @@ def after_invoke(coro) -> Callable[[T], T]:
|
||||
|
||||
.. versionadded:: 1.4
|
||||
"""
|
||||
|
||||
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
|
||||
if isinstance(func, Command):
|
||||
func.after_invoke(coro)
|
||||
else:
|
||||
func.__after_invoke__ = coro
|
||||
return func
|
||||
|
||||
return decorator # type: ignore
|
||||
|
@ -41,65 +41,66 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
__all__ = (
|
||||
'CommandError',
|
||||
'MissingRequiredArgument',
|
||||
'BadArgument',
|
||||
'PrivateMessageOnly',
|
||||
'NoPrivateMessage',
|
||||
'CheckFailure',
|
||||
'CheckAnyFailure',
|
||||
'CommandNotFound',
|
||||
'DisabledCommand',
|
||||
'CommandInvokeError',
|
||||
'TooManyArguments',
|
||||
'UserInputError',
|
||||
'CommandOnCooldown',
|
||||
'MaxConcurrencyReached',
|
||||
'NotOwner',
|
||||
'MessageNotFound',
|
||||
'ObjectNotFound',
|
||||
'MemberNotFound',
|
||||
'GuildNotFound',
|
||||
'UserNotFound',
|
||||
'ChannelNotFound',
|
||||
'ThreadNotFound',
|
||||
'ChannelNotReadable',
|
||||
'BadColourArgument',
|
||||
'BadColorArgument',
|
||||
'RoleNotFound',
|
||||
'BadInviteArgument',
|
||||
'EmojiNotFound',
|
||||
'GuildStickerNotFound',
|
||||
'PartialEmojiConversionFailure',
|
||||
'BadBoolArgument',
|
||||
'MissingRole',
|
||||
'BotMissingRole',
|
||||
'MissingAnyRole',
|
||||
'BotMissingAnyRole',
|
||||
'MissingPermissions',
|
||||
'BotMissingPermissions',
|
||||
'NSFWChannelRequired',
|
||||
'ConversionError',
|
||||
'BadUnionArgument',
|
||||
'BadLiteralArgument',
|
||||
'ArgumentParsingError',
|
||||
'UnexpectedQuoteError',
|
||||
'InvalidEndOfQuotedStringError',
|
||||
'ExpectedClosingQuoteError',
|
||||
'ExtensionError',
|
||||
'ExtensionAlreadyLoaded',
|
||||
'ExtensionNotLoaded',
|
||||
'NoEntryPointError',
|
||||
'ExtensionFailed',
|
||||
'ExtensionNotFound',
|
||||
'CommandRegistrationError',
|
||||
'FlagError',
|
||||
'BadFlagArgument',
|
||||
'MissingFlagArgument',
|
||||
'TooManyFlags',
|
||||
'MissingRequiredFlag',
|
||||
"CommandError",
|
||||
"MissingRequiredArgument",
|
||||
"BadArgument",
|
||||
"PrivateMessageOnly",
|
||||
"NoPrivateMessage",
|
||||
"CheckFailure",
|
||||
"CheckAnyFailure",
|
||||
"CommandNotFound",
|
||||
"DisabledCommand",
|
||||
"CommandInvokeError",
|
||||
"TooManyArguments",
|
||||
"UserInputError",
|
||||
"CommandOnCooldown",
|
||||
"MaxConcurrencyReached",
|
||||
"NotOwner",
|
||||
"MessageNotFound",
|
||||
"ObjectNotFound",
|
||||
"MemberNotFound",
|
||||
"GuildNotFound",
|
||||
"UserNotFound",
|
||||
"ChannelNotFound",
|
||||
"ThreadNotFound",
|
||||
"ChannelNotReadable",
|
||||
"BadColourArgument",
|
||||
"BadColorArgument",
|
||||
"RoleNotFound",
|
||||
"BadInviteArgument",
|
||||
"EmojiNotFound",
|
||||
"GuildStickerNotFound",
|
||||
"PartialEmojiConversionFailure",
|
||||
"BadBoolArgument",
|
||||
"MissingRole",
|
||||
"BotMissingRole",
|
||||
"MissingAnyRole",
|
||||
"BotMissingAnyRole",
|
||||
"MissingPermissions",
|
||||
"BotMissingPermissions",
|
||||
"NSFWChannelRequired",
|
||||
"ConversionError",
|
||||
"BadUnionArgument",
|
||||
"BadLiteralArgument",
|
||||
"ArgumentParsingError",
|
||||
"UnexpectedQuoteError",
|
||||
"InvalidEndOfQuotedStringError",
|
||||
"ExpectedClosingQuoteError",
|
||||
"ExtensionError",
|
||||
"ExtensionAlreadyLoaded",
|
||||
"ExtensionNotLoaded",
|
||||
"NoEntryPointError",
|
||||
"ExtensionFailed",
|
||||
"ExtensionNotFound",
|
||||
"CommandRegistrationError",
|
||||
"FlagError",
|
||||
"BadFlagArgument",
|
||||
"MissingFlagArgument",
|
||||
"TooManyFlags",
|
||||
"MissingRequiredFlag",
|
||||
)
|
||||
|
||||
|
||||
class CommandError(DiscordException):
|
||||
r"""The base exception type for all command related errors.
|
||||
|
||||
@ -109,14 +110,16 @@ class CommandError(DiscordException):
|
||||
in a special way as they are caught and passed into a special event
|
||||
from :class:`.Bot`\, :func:`.on_command_error`.
|
||||
"""
|
||||
|
||||
def __init__(self, message: Optional[str] = None, *args: Any) -> None:
|
||||
if message is not None:
|
||||
# clean-up @everyone and @here mentions
|
||||
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
|
||||
m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere")
|
||||
super().__init__(m, *args)
|
||||
else:
|
||||
super().__init__(*args)
|
||||
|
||||
|
||||
class ConversionError(CommandError):
|
||||
"""Exception raised when a Converter class raises non-CommandError.
|
||||
|
||||
@ -130,18 +133,22 @@ class ConversionError(CommandError):
|
||||
The original exception that was raised. You can also get this via
|
||||
the ``__cause__`` attribute.
|
||||
"""
|
||||
|
||||
def __init__(self, converter: Converter, original: Exception) -> None:
|
||||
self.converter: Converter = converter
|
||||
self.original: Exception = original
|
||||
|
||||
|
||||
class UserInputError(CommandError):
|
||||
"""The base exception type for errors that involve errors
|
||||
regarding user input.
|
||||
|
||||
This inherits from :exc:`CommandError`.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CommandNotFound(CommandError):
|
||||
"""Exception raised when a command is attempted to be invoked
|
||||
but no command under that name is found.
|
||||
@ -151,8 +158,10 @@ class CommandNotFound(CommandError):
|
||||
|
||||
This inherits from :exc:`CommandError`.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MissingRequiredArgument(UserInputError):
|
||||
"""Exception raised when parsing a command and a parameter
|
||||
that is required is not encountered.
|
||||
@ -164,9 +173,11 @@ class MissingRequiredArgument(UserInputError):
|
||||
param: :class:`inspect.Parameter`
|
||||
The argument that is missing.
|
||||
"""
|
||||
|
||||
def __init__(self, param: Parameter) -> None:
|
||||
self.param: Parameter = param
|
||||
super().__init__(f'{param.name} is a required argument that is missing.')
|
||||
super().__init__(f"{param.name} is a required argument that is missing.")
|
||||
|
||||
|
||||
class TooManyArguments(UserInputError):
|
||||
"""Exception raised when the command was passed too many arguments and its
|
||||
@ -174,23 +185,29 @@ class TooManyArguments(UserInputError):
|
||||
|
||||
This inherits from :exc:`UserInputError`
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BadArgument(UserInputError):
|
||||
"""Exception raised when a parsing or conversion failure is encountered
|
||||
on an argument to pass into a command.
|
||||
|
||||
This inherits from :exc:`UserInputError`
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CheckFailure(CommandError):
|
||||
"""Exception raised when the predicates in :attr:`.Command.checks` have failed.
|
||||
|
||||
This inherits from :exc:`CommandError`
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CheckAnyFailure(CheckFailure):
|
||||
"""Exception raised when all predicates in :func:`check_any` fail.
|
||||
|
||||
@ -209,7 +226,8 @@ class CheckAnyFailure(CheckFailure):
|
||||
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None:
|
||||
self.checks: List[CheckFailure] = checks
|
||||
self.errors: List[Callable[[Context], bool]] = errors
|
||||
super().__init__('You do not have permission to run this command.')
|
||||
super().__init__("You do not have permission to run this command.")
|
||||
|
||||
|
||||
class PrivateMessageOnly(CheckFailure):
|
||||
"""Exception raised when an operation does not work outside of private
|
||||
@ -217,8 +235,10 @@ class PrivateMessageOnly(CheckFailure):
|
||||
|
||||
This inherits from :exc:`CheckFailure`
|
||||
"""
|
||||
|
||||
def __init__(self, message: Optional[str] = None) -> None:
|
||||
super().__init__(message or 'This command can only be used in private messages.')
|
||||
super().__init__(message or "This command can only be used in private messages.")
|
||||
|
||||
|
||||
class NoPrivateMessage(CheckFailure):
|
||||
"""Exception raised when an operation does not work in private message
|
||||
@ -228,15 +248,18 @@ class NoPrivateMessage(CheckFailure):
|
||||
"""
|
||||
|
||||
def __init__(self, message: Optional[str] = None) -> None:
|
||||
super().__init__(message or 'This command cannot be used in private messages.')
|
||||
super().__init__(message or "This command cannot be used in private messages.")
|
||||
|
||||
|
||||
class NotOwner(CheckFailure):
|
||||
"""Exception raised when the message author is not the owner of the bot.
|
||||
|
||||
This inherits from :exc:`CheckFailure`
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ObjectNotFound(BadArgument):
|
||||
"""Exception raised when the argument provided did not match the format
|
||||
of an ID or a mention.
|
||||
@ -250,9 +273,11 @@ class ObjectNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The argument supplied by the caller that was not matched
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'{argument!r} does not follow a valid ID or mention format.')
|
||||
super().__init__(f"{argument!r} does not follow a valid ID or mention format.")
|
||||
|
||||
|
||||
class MemberNotFound(BadArgument):
|
||||
"""Exception raised when the member provided was not found in the bot's
|
||||
@ -267,10 +292,12 @@ class MemberNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The member supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Member "{argument}" not found.')
|
||||
|
||||
|
||||
class GuildNotFound(BadArgument):
|
||||
"""Exception raised when the guild provided was not found in the bot's cache.
|
||||
|
||||
@ -283,10 +310,12 @@ class GuildNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The guild supplied by the called that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Guild "{argument}" not found.')
|
||||
|
||||
|
||||
class UserNotFound(BadArgument):
|
||||
"""Exception raised when the user provided was not found in the bot's
|
||||
cache.
|
||||
@ -300,10 +329,12 @@ class UserNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The user supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'User "{argument}" not found.')
|
||||
|
||||
|
||||
class MessageNotFound(BadArgument):
|
||||
"""Exception raised when the message provided was not found in the channel.
|
||||
|
||||
@ -316,10 +347,12 @@ class MessageNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The message supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Message "{argument}" not found.')
|
||||
|
||||
|
||||
class ChannelNotReadable(BadArgument):
|
||||
"""Exception raised when the bot does not have permission to read messages
|
||||
in the channel.
|
||||
@ -333,10 +366,12 @@ class ChannelNotReadable(BadArgument):
|
||||
argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
|
||||
The channel supplied by the caller that was not readable
|
||||
"""
|
||||
|
||||
def __init__(self, argument: Union[GuildChannel, Thread]) -> None:
|
||||
self.argument: Union[GuildChannel, Thread] = argument
|
||||
super().__init__(f"Can't read messages in {argument.mention}.")
|
||||
|
||||
|
||||
class ChannelNotFound(BadArgument):
|
||||
"""Exception raised when the bot can not find the channel.
|
||||
|
||||
@ -349,10 +384,12 @@ class ChannelNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The channel supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Channel "{argument}" not found.')
|
||||
|
||||
|
||||
class ThreadNotFound(BadArgument):
|
||||
"""Exception raised when the bot can not find the thread.
|
||||
|
||||
@ -365,10 +402,12 @@ class ThreadNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The thread supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Thread "{argument}" not found.')
|
||||
|
||||
|
||||
class BadColourArgument(BadArgument):
|
||||
"""Exception raised when the colour is not valid.
|
||||
|
||||
@ -381,12 +420,15 @@ class BadColourArgument(BadArgument):
|
||||
argument: :class:`str`
|
||||
The colour supplied by the caller that was not valid
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Colour "{argument}" is invalid.')
|
||||
|
||||
|
||||
BadColorArgument = BadColourArgument
|
||||
|
||||
|
||||
class RoleNotFound(BadArgument):
|
||||
"""Exception raised when the bot can not find the role.
|
||||
|
||||
@ -399,10 +441,12 @@ class RoleNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The role supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Role "{argument}" not found.')
|
||||
|
||||
|
||||
class BadInviteArgument(BadArgument):
|
||||
"""Exception raised when the invite is invalid or expired.
|
||||
|
||||
@ -410,10 +454,12 @@ class BadInviteArgument(BadArgument):
|
||||
|
||||
.. versionadded:: 1.5
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Invite "{argument}" is invalid or expired.')
|
||||
|
||||
|
||||
class EmojiNotFound(BadArgument):
|
||||
"""Exception raised when the bot can not find the emoji.
|
||||
|
||||
@ -426,10 +472,12 @@ class EmojiNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The emoji supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Emoji "{argument}" not found.')
|
||||
|
||||
|
||||
class PartialEmojiConversionFailure(BadArgument):
|
||||
"""Exception raised when the emoji provided does not match the correct
|
||||
format.
|
||||
@ -443,10 +491,12 @@ class PartialEmojiConversionFailure(BadArgument):
|
||||
argument: :class:`str`
|
||||
The emoji supplied by the caller that did not match the regex
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.')
|
||||
|
||||
|
||||
class GuildStickerNotFound(BadArgument):
|
||||
"""Exception raised when the bot can not find the sticker.
|
||||
|
||||
@ -459,10 +509,12 @@ class GuildStickerNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The sticker supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Sticker "{argument}" not found.')
|
||||
|
||||
|
||||
class BadBoolArgument(BadArgument):
|
||||
"""Exception raised when a boolean argument was not convertable.
|
||||
|
||||
@ -475,17 +527,21 @@ class BadBoolArgument(BadArgument):
|
||||
argument: :class:`str`
|
||||
The boolean argument supplied by the caller that is not in the predefined list
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'{argument} is not a recognised boolean option')
|
||||
super().__init__(f"{argument} is not a recognised boolean option")
|
||||
|
||||
|
||||
class DisabledCommand(CommandError):
|
||||
"""Exception raised when the command being invoked is disabled.
|
||||
|
||||
This inherits from :exc:`CommandError`
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CommandInvokeError(CommandError):
|
||||
"""Exception raised when the command being invoked raised an exception.
|
||||
|
||||
@ -497,9 +553,11 @@ class CommandInvokeError(CommandError):
|
||||
The original exception that was raised. You can also get this via
|
||||
the ``__cause__`` attribute.
|
||||
"""
|
||||
|
||||
def __init__(self, e: Exception) -> None:
|
||||
self.original: Exception = e
|
||||
super().__init__(f'Command raised an exception: {e.__class__.__name__}: {e}')
|
||||
super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}")
|
||||
|
||||
|
||||
class CommandOnCooldown(CommandError):
|
||||
"""Exception raised when the command being invoked is on cooldown.
|
||||
@ -516,11 +574,13 @@ class CommandOnCooldown(CommandError):
|
||||
retry_after: :class:`float`
|
||||
The amount of seconds to wait before you can retry again.
|
||||
"""
|
||||
|
||||
def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None:
|
||||
self.cooldown: Cooldown = cooldown
|
||||
self.retry_after: float = retry_after
|
||||
self.type: BucketType = type
|
||||
super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s')
|
||||
super().__init__(f"You are on cooldown. Try again in {retry_after:.2f}s")
|
||||
|
||||
|
||||
class MaxConcurrencyReached(CommandError):
|
||||
"""Exception raised when the command being invoked has reached its maximum concurrency.
|
||||
@ -539,10 +599,11 @@ class MaxConcurrencyReached(CommandError):
|
||||
self.number: int = number
|
||||
self.per: BucketType = per
|
||||
name = per.name
|
||||
suffix = 'per %s' % name if per.name != 'default' else 'globally'
|
||||
plural = '%s times %s' if number > 1 else '%s time %s'
|
||||
suffix = "per %s" % name if per.name != "default" else "globally"
|
||||
plural = "%s times %s" if number > 1 else "%s time %s"
|
||||
fmt = plural % (number, suffix)
|
||||
super().__init__(f'Too many people are using this command. It can only be used {fmt} concurrently.')
|
||||
super().__init__(f"Too many people are using this command. It can only be used {fmt} concurrently.")
|
||||
|
||||
|
||||
class MissingRole(CheckFailure):
|
||||
"""Exception raised when the command invoker lacks a role to run a command.
|
||||
@ -557,11 +618,13 @@ class MissingRole(CheckFailure):
|
||||
The required role that is missing.
|
||||
This is the parameter passed to :func:`~.commands.has_role`.
|
||||
"""
|
||||
|
||||
def __init__(self, missing_role: Snowflake) -> None:
|
||||
self.missing_role: Snowflake = missing_role
|
||||
message = f'Role {missing_role!r} is required to run this command.'
|
||||
message = f"Role {missing_role!r} is required to run this command."
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BotMissingRole(CheckFailure):
|
||||
"""Exception raised when the bot's member lacks a role to run a command.
|
||||
|
||||
@ -575,11 +638,13 @@ class BotMissingRole(CheckFailure):
|
||||
The required role that is missing.
|
||||
This is the parameter passed to :func:`~.commands.has_role`.
|
||||
"""
|
||||
|
||||
def __init__(self, missing_role: Snowflake) -> None:
|
||||
self.missing_role: Snowflake = missing_role
|
||||
message = f'Bot requires the role {missing_role!r} to run this command'
|
||||
message = f"Bot requires the role {missing_role!r} to run this command"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class MissingAnyRole(CheckFailure):
|
||||
"""Exception raised when the command invoker lacks any of
|
||||
the roles specified to run a command.
|
||||
@ -594,15 +659,16 @@ class MissingAnyRole(CheckFailure):
|
||||
The roles that the invoker is missing.
|
||||
These are the parameters passed to :func:`~.commands.has_any_role`.
|
||||
"""
|
||||
|
||||
def __init__(self, missing_roles: SnowflakeList) -> None:
|
||||
self.missing_roles: SnowflakeList = missing_roles
|
||||
|
||||
missing = [f"'{role}'" for role in missing_roles]
|
||||
|
||||
if len(missing) > 2:
|
||||
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1])
|
||||
fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1])
|
||||
else:
|
||||
fmt = ' or '.join(missing)
|
||||
fmt = " or ".join(missing)
|
||||
|
||||
message = f"You are missing at least one of the required roles: {fmt}"
|
||||
super().__init__(message)
|
||||
@ -623,19 +689,21 @@ class BotMissingAnyRole(CheckFailure):
|
||||
These are the parameters passed to :func:`~.commands.has_any_role`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, missing_roles: SnowflakeList) -> None:
|
||||
self.missing_roles: SnowflakeList = missing_roles
|
||||
|
||||
missing = [f"'{role}'" for role in missing_roles]
|
||||
|
||||
if len(missing) > 2:
|
||||
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1])
|
||||
fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1])
|
||||
else:
|
||||
fmt = ' or '.join(missing)
|
||||
fmt = " or ".join(missing)
|
||||
|
||||
message = f"Bot is missing at least one of the required roles: {fmt}"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class NSFWChannelRequired(CheckFailure):
|
||||
"""Exception raised when a channel does not have the required NSFW setting.
|
||||
|
||||
@ -648,10 +716,12 @@ class NSFWChannelRequired(CheckFailure):
|
||||
channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
|
||||
The channel that does not have NSFW enabled.
|
||||
"""
|
||||
|
||||
def __init__(self, channel: Union[GuildChannel, Thread]) -> None:
|
||||
self.channel: Union[GuildChannel, Thread] = channel
|
||||
super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.")
|
||||
|
||||
|
||||
class MissingPermissions(CheckFailure):
|
||||
"""Exception raised when the command invoker lacks permissions to run a
|
||||
command.
|
||||
@ -663,18 +733,20 @@ class MissingPermissions(CheckFailure):
|
||||
missing_permissions: List[:class:`str`]
|
||||
The required permissions that are missing.
|
||||
"""
|
||||
|
||||
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
|
||||
self.missing_permissions: List[str] = missing_permissions
|
||||
|
||||
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
|
||||
missing = [perm.replace("_", " ").replace("guild", "server").title() for perm in missing_permissions]
|
||||
|
||||
if len(missing) > 2:
|
||||
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
|
||||
fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1])
|
||||
else:
|
||||
fmt = ' and '.join(missing)
|
||||
message = f'You are missing {fmt} permission(s) to run this command.'
|
||||
fmt = " and ".join(missing)
|
||||
message = f"You are missing {fmt} permission(s) to run this command."
|
||||
super().__init__(message, *args)
|
||||
|
||||
|
||||
class BotMissingPermissions(CheckFailure):
|
||||
"""Exception raised when the bot's member lacks permissions to run a
|
||||
command.
|
||||
@ -686,18 +758,20 @@ class BotMissingPermissions(CheckFailure):
|
||||
missing_permissions: List[:class:`str`]
|
||||
The required permissions that are missing.
|
||||
"""
|
||||
|
||||
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
|
||||
self.missing_permissions: List[str] = missing_permissions
|
||||
|
||||
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
|
||||
missing = [perm.replace("_", " ").replace("guild", "server").title() for perm in missing_permissions]
|
||||
|
||||
if len(missing) > 2:
|
||||
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
|
||||
fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1])
|
||||
else:
|
||||
fmt = ' and '.join(missing)
|
||||
message = f'Bot requires {fmt} permission(s) to run this command.'
|
||||
fmt = " and ".join(missing)
|
||||
message = f"Bot requires {fmt} permission(s) to run this command."
|
||||
super().__init__(message, *args)
|
||||
|
||||
|
||||
class BadUnionArgument(UserInputError):
|
||||
"""Exception raised when a :data:`typing.Union` converter fails for all
|
||||
its associated types.
|
||||
@ -713,6 +787,7 @@ class BadUnionArgument(UserInputError):
|
||||
errors: List[:class:`CommandError`]
|
||||
A list of errors that were caught from failing the conversion.
|
||||
"""
|
||||
|
||||
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None:
|
||||
self.param: Parameter = param
|
||||
self.converters: Tuple[Type, ...] = converters
|
||||
@ -722,18 +797,19 @@ class BadUnionArgument(UserInputError):
|
||||
try:
|
||||
return x.__name__
|
||||
except AttributeError:
|
||||
if hasattr(x, '__origin__'):
|
||||
if hasattr(x, "__origin__"):
|
||||
return repr(x)
|
||||
return x.__class__.__name__
|
||||
|
||||
to_string = [_get_name(x) for x in converters]
|
||||
if len(to_string) > 2:
|
||||
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
|
||||
fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1])
|
||||
else:
|
||||
fmt = ' or '.join(to_string)
|
||||
fmt = " or ".join(to_string)
|
||||
|
||||
super().__init__(f'Could not convert "{param.name}" into {fmt}.')
|
||||
|
||||
|
||||
class BadLiteralArgument(UserInputError):
|
||||
"""Exception raised when a :data:`typing.Literal` converter fails for all
|
||||
its associated values.
|
||||
@ -751,6 +827,7 @@ class BadLiteralArgument(UserInputError):
|
||||
errors: List[:class:`CommandError`]
|
||||
A list of errors that were caught from failing the conversion.
|
||||
"""
|
||||
|
||||
def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None:
|
||||
self.param: Parameter = param
|
||||
self.literals: Tuple[Any, ...] = literals
|
||||
@ -758,12 +835,13 @@ class BadLiteralArgument(UserInputError):
|
||||
|
||||
to_string = [repr(l) for l in literals]
|
||||
if len(to_string) > 2:
|
||||
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
|
||||
fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1])
|
||||
else:
|
||||
fmt = ' or '.join(to_string)
|
||||
fmt = " or ".join(to_string)
|
||||
|
||||
super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.')
|
||||
|
||||
|
||||
class ArgumentParsingError(UserInputError):
|
||||
"""An exception raised when the parser fails to parse a user's input.
|
||||
|
||||
@ -772,8 +850,10 @@ class ArgumentParsingError(UserInputError):
|
||||
There are child classes that implement more granular parsing errors for
|
||||
i18n purposes.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UnexpectedQuoteError(ArgumentParsingError):
|
||||
"""An exception raised when the parser encounters a quote mark inside a non-quoted string.
|
||||
|
||||
@ -784,9 +864,11 @@ class UnexpectedQuoteError(ArgumentParsingError):
|
||||
quote: :class:`str`
|
||||
The quote mark that was found inside the non-quoted string.
|
||||
"""
|
||||
|
||||
def __init__(self, quote: str) -> None:
|
||||
self.quote: str = quote
|
||||
super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string')
|
||||
super().__init__(f"Unexpected quote mark, {quote!r}, in non-quoted string")
|
||||
|
||||
|
||||
class InvalidEndOfQuotedStringError(ArgumentParsingError):
|
||||
"""An exception raised when a space is expected after the closing quote in a string
|
||||
@ -799,9 +881,11 @@ class InvalidEndOfQuotedStringError(ArgumentParsingError):
|
||||
char: :class:`str`
|
||||
The character found instead of the expected string.
|
||||
"""
|
||||
|
||||
def __init__(self, char: str) -> None:
|
||||
self.char: str = char
|
||||
super().__init__(f'Expected space after closing quotation but received {char!r}')
|
||||
super().__init__(f"Expected space after closing quotation but received {char!r}")
|
||||
|
||||
|
||||
class ExpectedClosingQuoteError(ArgumentParsingError):
|
||||
"""An exception raised when a quote character is expected but not found.
|
||||
@ -816,7 +900,8 @@ class ExpectedClosingQuoteError(ArgumentParsingError):
|
||||
|
||||
def __init__(self, close_quote: str) -> None:
|
||||
self.close_quote: str = close_quote
|
||||
super().__init__(f'Expected closing {close_quote}.')
|
||||
super().__init__(f"Expected closing {close_quote}.")
|
||||
|
||||
|
||||
class ExtensionError(DiscordException):
|
||||
"""Base exception for extension related errors.
|
||||
@ -828,37 +913,45 @@ class ExtensionError(DiscordException):
|
||||
name: :class:`str`
|
||||
The extension that had an error.
|
||||
"""
|
||||
|
||||
def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None:
|
||||
self.name: str = name
|
||||
message = message or f'Extension {name!r} had an error.'
|
||||
message = message or f"Extension {name!r} had an error."
|
||||
# clean-up @everyone and @here mentions
|
||||
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
|
||||
m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere")
|
||||
super().__init__(m, *args)
|
||||
|
||||
|
||||
class ExtensionAlreadyLoaded(ExtensionError):
|
||||
"""An exception raised when an extension has already been loaded.
|
||||
|
||||
This inherits from :exc:`ExtensionError`
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__(f'Extension {name!r} is already loaded.', name=name)
|
||||
super().__init__(f"Extension {name!r} is already loaded.", name=name)
|
||||
|
||||
|
||||
class ExtensionNotLoaded(ExtensionError):
|
||||
"""An exception raised when an extension was not loaded.
|
||||
|
||||
This inherits from :exc:`ExtensionError`
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__(f'Extension {name!r} has not been loaded.', name=name)
|
||||
super().__init__(f"Extension {name!r} has not been loaded.", name=name)
|
||||
|
||||
|
||||
class NoEntryPointError(ExtensionError):
|
||||
"""An exception raised when an extension does not have a ``setup`` entry point function.
|
||||
|
||||
This inherits from :exc:`ExtensionError`
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__(f"Extension {name!r} has no 'setup' function.", name=name)
|
||||
|
||||
|
||||
class ExtensionFailed(ExtensionError):
|
||||
"""An exception raised when an extension failed to load during execution of the module or ``setup`` entry point.
|
||||
|
||||
@ -872,11 +965,13 @@ class ExtensionFailed(ExtensionError):
|
||||
The original exception that was raised. You can also get this via
|
||||
the ``__cause__`` attribute.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, original: Exception) -> None:
|
||||
self.original: Exception = original
|
||||
msg = f'Extension {name!r} raised an error: {original.__class__.__name__}: {original}'
|
||||
msg = f"Extension {name!r} raised an error: {original.__class__.__name__}: {original}"
|
||||
super().__init__(msg, name=name)
|
||||
|
||||
|
||||
class ExtensionNotFound(ExtensionError):
|
||||
"""An exception raised when an extension is not found.
|
||||
|
||||
@ -890,10 +985,12 @@ class ExtensionNotFound(ExtensionError):
|
||||
name: :class:`str`
|
||||
The extension that had the error.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
msg = f'Extension {name!r} could not be loaded.'
|
||||
msg = f"Extension {name!r} could not be loaded."
|
||||
super().__init__(msg, name=name)
|
||||
|
||||
|
||||
class CommandRegistrationError(ClientException):
|
||||
"""An exception raised when the command can't be added
|
||||
because the name is already taken by a different command.
|
||||
@ -909,11 +1006,13 @@ class CommandRegistrationError(ClientException):
|
||||
alias_conflict: :class:`bool`
|
||||
Whether the name that conflicts is an alias of the command we try to add.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, *, alias_conflict: bool = False) -> None:
|
||||
self.name: str = name
|
||||
self.alias_conflict: bool = alias_conflict
|
||||
type_ = 'alias' if alias_conflict else 'command'
|
||||
super().__init__(f'The {type_} {name} is already an existing command or alias.')
|
||||
type_ = "alias" if alias_conflict else "command"
|
||||
super().__init__(f"The {type_} {name} is already an existing command or alias.")
|
||||
|
||||
|
||||
class FlagError(BadArgument):
|
||||
"""The base exception type for all flag parsing related errors.
|
||||
@ -922,8 +1021,10 @@ class FlagError(BadArgument):
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TooManyFlags(FlagError):
|
||||
"""An exception raised when a flag has received too many values.
|
||||
|
||||
@ -938,10 +1039,12 @@ class TooManyFlags(FlagError):
|
||||
values: List[:class:`str`]
|
||||
The values that were passed.
|
||||
"""
|
||||
|
||||
def __init__(self, flag: Flag, values: List[str]) -> None:
|
||||
self.flag: Flag = flag
|
||||
self.values: List[str] = values
|
||||
super().__init__(f'Too many flag values, expected {flag.max_args} but received {len(values)}.')
|
||||
super().__init__(f"Too many flag values, expected {flag.max_args} but received {len(values)}.")
|
||||
|
||||
|
||||
class BadFlagArgument(FlagError):
|
||||
"""An exception raised when a flag failed to convert a value.
|
||||
@ -955,6 +1058,7 @@ class BadFlagArgument(FlagError):
|
||||
flag: :class:`~discord.ext.commands.Flag`
|
||||
The flag that failed to convert.
|
||||
"""
|
||||
|
||||
def __init__(self, flag: Flag) -> None:
|
||||
self.flag: Flag = flag
|
||||
try:
|
||||
@ -962,7 +1066,8 @@ class BadFlagArgument(FlagError):
|
||||
except AttributeError:
|
||||
name = flag.annotation.__class__.__name__
|
||||
|
||||
super().__init__(f'Could not convert to {name!r} for flag {flag.name!r}')
|
||||
super().__init__(f"Could not convert to {name!r} for flag {flag.name!r}")
|
||||
|
||||
|
||||
class MissingRequiredFlag(FlagError):
|
||||
"""An exception raised when a required flag was not given.
|
||||
@ -976,9 +1081,11 @@ class MissingRequiredFlag(FlagError):
|
||||
flag: :class:`~discord.ext.commands.Flag`
|
||||
The required flag that was not found.
|
||||
"""
|
||||
|
||||
def __init__(self, flag: Flag) -> None:
|
||||
self.flag: Flag = flag
|
||||
super().__init__(f'Flag {flag.name!r} is required and missing')
|
||||
super().__init__(f"Flag {flag.name!r} is required and missing")
|
||||
|
||||
|
||||
class MissingFlagArgument(FlagError):
|
||||
"""An exception raised when a flag did not get a value.
|
||||
@ -992,6 +1099,7 @@ class MissingFlagArgument(FlagError):
|
||||
flag: :class:`~discord.ext.commands.Flag`
|
||||
The flag that did not get a value.
|
||||
"""
|
||||
|
||||
def __init__(self, flag: Flag) -> None:
|
||||
self.flag: Flag = flag
|
||||
super().__init__(f'Flag {flag.name!r} does not have an argument')
|
||||
super().__init__(f"Flag {flag.name!r} does not have an argument")
|
||||
|
@ -59,9 +59,9 @@ import sys
|
||||
import re
|
||||
|
||||
__all__ = (
|
||||
'Flag',
|
||||
'flag',
|
||||
'FlagConverter',
|
||||
"Flag",
|
||||
"flag",
|
||||
"FlagConverter",
|
||||
)
|
||||
|
||||
|
||||
@ -148,20 +148,20 @@ def flag(
|
||||
|
||||
def validate_flag_name(name: str, forbidden: Set[str]):
|
||||
if not name:
|
||||
raise ValueError('flag names should not be empty')
|
||||
raise ValueError("flag names should not be empty")
|
||||
|
||||
for ch in name:
|
||||
if ch.isspace():
|
||||
raise ValueError(f'flag name {name!r} cannot have spaces')
|
||||
if ch == '\\':
|
||||
raise ValueError(f'flag name {name!r} cannot have backslashes')
|
||||
raise ValueError(f"flag name {name!r} cannot have spaces")
|
||||
if ch == "\\":
|
||||
raise ValueError(f"flag name {name!r} cannot have backslashes")
|
||||
if ch in forbidden:
|
||||
raise ValueError(f'flag name {name!r} cannot have any of {forbidden!r} within them')
|
||||
raise ValueError(f"flag name {name!r} cannot have any of {forbidden!r} within them")
|
||||
|
||||
|
||||
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
|
||||
annotations = namespace.get('__annotations__', {})
|
||||
case_insensitive = namespace['__commands_flag_case_insensitive__']
|
||||
annotations = namespace.get("__annotations__", {})
|
||||
case_insensitive = namespace["__commands_flag_case_insensitive__"]
|
||||
flags: Dict[str, Flag] = {}
|
||||
cache: Dict[str, Any] = {}
|
||||
names: Set[str] = set()
|
||||
@ -178,7 +178,11 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
|
||||
|
||||
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
|
||||
|
||||
if flag.default is MISSING and hasattr(annotation, '__commands_is_flag__') and annotation._can_be_constructible():
|
||||
if (
|
||||
flag.default is MISSING
|
||||
and hasattr(annotation, "__commands_is_flag__")
|
||||
and annotation._can_be_constructible()
|
||||
):
|
||||
flag.default = annotation._construct_default
|
||||
|
||||
if flag.aliases is MISSING:
|
||||
@ -229,7 +233,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
|
||||
if flag.max_args is MISSING:
|
||||
flag.max_args = 1
|
||||
else:
|
||||
raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag')
|
||||
raise TypeError(f"Unsupported typing annotation {annotation!r} for {flag.name!r} flag")
|
||||
|
||||
if flag.override is MISSING:
|
||||
flag.override = False
|
||||
@ -237,7 +241,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
|
||||
# Validate flag names are unique
|
||||
name = flag.name.casefold() if case_insensitive else flag.name
|
||||
if name in names:
|
||||
raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.')
|
||||
raise TypeError(f"{flag.name!r} flag conflicts with previous flag or alias.")
|
||||
else:
|
||||
names.add(name)
|
||||
|
||||
@ -245,7 +249,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
|
||||
# Validate alias is unique
|
||||
alias = alias.casefold() if case_insensitive else alias
|
||||
if alias in names:
|
||||
raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.')
|
||||
raise TypeError(f"{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.")
|
||||
else:
|
||||
names.add(alias)
|
||||
|
||||
@ -274,10 +278,10 @@ class FlagsMeta(type):
|
||||
delimiter: str = MISSING,
|
||||
prefix: str = MISSING,
|
||||
):
|
||||
attrs['__commands_is_flag__'] = True
|
||||
attrs["__commands_is_flag__"] = True
|
||||
|
||||
try:
|
||||
global_ns = sys.modules[attrs['__module__']].__dict__
|
||||
global_ns = sys.modules[attrs["__module__"]].__dict__
|
||||
except KeyError:
|
||||
global_ns = {}
|
||||
|
||||
@ -296,26 +300,26 @@ class FlagsMeta(type):
|
||||
flags: Dict[str, Flag] = {}
|
||||
aliases: Dict[str, str] = {}
|
||||
for base in reversed(bases):
|
||||
if base.__dict__.get('__commands_is_flag__', False):
|
||||
flags.update(base.__dict__['__commands_flags__'])
|
||||
aliases.update(base.__dict__['__commands_flag_aliases__'])
|
||||
if base.__dict__.get("__commands_is_flag__", False):
|
||||
flags.update(base.__dict__["__commands_flags__"])
|
||||
aliases.update(base.__dict__["__commands_flag_aliases__"])
|
||||
if case_insensitive is MISSING:
|
||||
attrs['__commands_flag_case_insensitive__'] = base.__dict__['__commands_flag_case_insensitive__']
|
||||
attrs["__commands_flag_case_insensitive__"] = base.__dict__["__commands_flag_case_insensitive__"]
|
||||
if delimiter is MISSING:
|
||||
attrs['__commands_flag_delimiter__'] = base.__dict__['__commands_flag_delimiter__']
|
||||
attrs["__commands_flag_delimiter__"] = base.__dict__["__commands_flag_delimiter__"]
|
||||
if prefix is MISSING:
|
||||
attrs['__commands_flag_prefix__'] = base.__dict__['__commands_flag_prefix__']
|
||||
attrs["__commands_flag_prefix__"] = base.__dict__["__commands_flag_prefix__"]
|
||||
|
||||
if case_insensitive is not MISSING:
|
||||
attrs['__commands_flag_case_insensitive__'] = case_insensitive
|
||||
attrs["__commands_flag_case_insensitive__"] = case_insensitive
|
||||
if delimiter is not MISSING:
|
||||
attrs['__commands_flag_delimiter__'] = delimiter
|
||||
attrs["__commands_flag_delimiter__"] = delimiter
|
||||
if prefix is not MISSING:
|
||||
attrs['__commands_flag_prefix__'] = prefix
|
||||
attrs["__commands_flag_prefix__"] = prefix
|
||||
|
||||
case_insensitive = attrs.setdefault('__commands_flag_case_insensitive__', False)
|
||||
delimiter = attrs.setdefault('__commands_flag_delimiter__', ':')
|
||||
prefix = attrs.setdefault('__commands_flag_prefix__', '')
|
||||
case_insensitive = attrs.setdefault("__commands_flag_case_insensitive__", False)
|
||||
delimiter = attrs.setdefault("__commands_flag_delimiter__", ":")
|
||||
prefix = attrs.setdefault("__commands_flag_prefix__", "")
|
||||
|
||||
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
|
||||
flags[flag_name] = flag
|
||||
@ -337,11 +341,11 @@ class FlagsMeta(type):
|
||||
keys.extend(re.escape(a) for a in aliases)
|
||||
keys = sorted(keys, key=lambda t: len(t), reverse=True)
|
||||
|
||||
joined = '|'.join(keys)
|
||||
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags)
|
||||
attrs['__commands_flag_regex__'] = pattern
|
||||
attrs['__commands_flags__'] = flags
|
||||
attrs['__commands_flag_aliases__'] = aliases
|
||||
joined = "|".join(keys)
|
||||
pattern = re.compile(f"(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})", regex_flags)
|
||||
attrs["__commands_flag_regex__"] = pattern
|
||||
attrs["__commands_flags__"] = flags
|
||||
attrs["__commands_flag_aliases__"] = aliases
|
||||
|
||||
return type.__new__(cls, name, bases, attrs)
|
||||
|
||||
@ -432,7 +436,7 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -
|
||||
raise BadFlagArgument(flag) from e
|
||||
|
||||
|
||||
F = TypeVar('F', bound='FlagConverter')
|
||||
F = TypeVar("F", bound="FlagConverter")
|
||||
|
||||
|
||||
class FlagConverter(metaclass=FlagsMeta):
|
||||
@ -493,8 +497,8 @@ class FlagConverter(metaclass=FlagsMeta):
|
||||
return self
|
||||
|
||||
def __repr__(self) -> str:
|
||||
pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()])
|
||||
return f'<{self.__class__.__name__} {pairs}>'
|
||||
pairs = " ".join([f"{flag.attribute}={getattr(self, flag.attribute)!r}" for flag in self.get_flags().values()])
|
||||
return f"<{self.__class__.__name__} {pairs}>"
|
||||
|
||||
@classmethod
|
||||
def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
|
||||
@ -507,7 +511,7 @@ class FlagConverter(metaclass=FlagsMeta):
|
||||
case_insensitive = cls.__commands_flag_case_insensitive__
|
||||
for match in cls.__commands_flag_regex__.finditer(argument):
|
||||
begin, end = match.span(0)
|
||||
key = match.group('flag')
|
||||
key = match.group("flag")
|
||||
if case_insensitive:
|
||||
key = key.casefold()
|
||||
|
||||
|
@ -39,10 +39,10 @@ if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
|
||||
__all__ = (
|
||||
'Paginator',
|
||||
'HelpCommand',
|
||||
'DefaultHelpCommand',
|
||||
'MinimalHelpCommand',
|
||||
"Paginator",
|
||||
"HelpCommand",
|
||||
"DefaultHelpCommand",
|
||||
"MinimalHelpCommand",
|
||||
)
|
||||
|
||||
# help -> shows info of bot on top/bottom and lists subcommands
|
||||
@ -89,7 +89,7 @@ class Paginator:
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
|
||||
def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'):
|
||||
def __init__(self, prefix="```", suffix="```", max_size=2000, linesep="\n"):
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
self.max_size = max_size
|
||||
@ -118,7 +118,7 @@ class Paginator:
|
||||
def _linesep_len(self):
|
||||
return len(self.linesep)
|
||||
|
||||
def add_line(self, line='', *, empty=False):
|
||||
def add_line(self, line="", *, empty=False):
|
||||
"""Adds a line to the current page.
|
||||
|
||||
If the line exceeds the :attr:`max_size` then an exception
|
||||
@ -138,7 +138,7 @@ class Paginator:
|
||||
"""
|
||||
max_page_size = self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len
|
||||
if len(line) > max_page_size:
|
||||
raise RuntimeError(f'Line exceeds maximum page size {max_page_size}')
|
||||
raise RuntimeError(f"Line exceeds maximum page size {max_page_size}")
|
||||
|
||||
if self._count + len(line) + self._linesep_len > self.max_size - self._suffix_len:
|
||||
self.close_page()
|
||||
@ -147,7 +147,7 @@ class Paginator:
|
||||
self._current_page.append(line)
|
||||
|
||||
if empty:
|
||||
self._current_page.append('')
|
||||
self._current_page.append("")
|
||||
self._count += self._linesep_len
|
||||
|
||||
def close_page(self):
|
||||
@ -176,7 +176,7 @@ class Paginator:
|
||||
return self._pages
|
||||
|
||||
def __repr__(self):
|
||||
fmt = '<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>'
|
||||
fmt = "<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>"
|
||||
return fmt.format(self)
|
||||
|
||||
|
||||
@ -197,7 +197,7 @@ class _HelpCommandImpl(Command):
|
||||
self.callback = injected.command_callback
|
||||
|
||||
on_error = injected.on_help_command_error
|
||||
if not hasattr(on_error, '__help_command_not_overriden__'):
|
||||
if not hasattr(on_error, "__help_command_not_overriden__"):
|
||||
if self.cog is not None:
|
||||
self.on_error = self._on_error_cog_implementation
|
||||
else:
|
||||
@ -224,7 +224,7 @@ class _HelpCommandImpl(Command):
|
||||
try:
|
||||
del result[next(iter(result))]
|
||||
except StopIteration:
|
||||
raise ValueError('Missing context parameter') from None
|
||||
raise ValueError("Missing context parameter") from None
|
||||
else:
|
||||
return result
|
||||
|
||||
@ -296,13 +296,13 @@ class HelpCommand:
|
||||
"""
|
||||
|
||||
MENTION_TRANSFORMS = {
|
||||
'@everyone': '@\u200beveryone',
|
||||
'@here': '@\u200bhere',
|
||||
r'<@!?[0-9]{17,22}>': '@deleted-user',
|
||||
r'<@&[0-9]{17,22}>': '@deleted-role',
|
||||
"@everyone": "@\u200beveryone",
|
||||
"@here": "@\u200bhere",
|
||||
r"<@!?[0-9]{17,22}>": "@deleted-user",
|
||||
r"<@&[0-9]{17,22}>": "@deleted-role",
|
||||
}
|
||||
|
||||
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys()))
|
||||
MENTION_PATTERN = re.compile("|".join(MENTION_TRANSFORMS.keys()))
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# To prevent race conditions of a single instance while also allowing
|
||||
@ -321,11 +321,11 @@ class HelpCommand:
|
||||
return self
|
||||
|
||||
def __init__(self, **options):
|
||||
self.show_hidden = options.pop('show_hidden', False)
|
||||
self.verify_checks = options.pop('verify_checks', True)
|
||||
self.command_attrs = attrs = options.pop('command_attrs', {})
|
||||
attrs.setdefault('name', 'help')
|
||||
attrs.setdefault('help', 'Shows this message')
|
||||
self.show_hidden = options.pop("show_hidden", False)
|
||||
self.verify_checks = options.pop("verify_checks", True)
|
||||
self.command_attrs = attrs = options.pop("command_attrs", {})
|
||||
attrs.setdefault("name", "help")
|
||||
attrs.setdefault("help", "Shows this message")
|
||||
self.context: Context = discord.utils.MISSING
|
||||
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
|
||||
|
||||
@ -422,20 +422,20 @@ class HelpCommand:
|
||||
if not parent.signature or parent.invoke_without_command:
|
||||
entries.append(parent.name)
|
||||
else:
|
||||
entries.append(parent.name + ' ' + parent.signature)
|
||||
entries.append(parent.name + " " + parent.signature)
|
||||
parent = parent.parent
|
||||
parent_sig = ' '.join(reversed(entries))
|
||||
parent_sig = " ".join(reversed(entries))
|
||||
|
||||
if len(command.aliases) > 0:
|
||||
aliases = '|'.join(command.aliases)
|
||||
fmt = f'[{command.name}|{aliases}]'
|
||||
aliases = "|".join(command.aliases)
|
||||
fmt = f"[{command.name}|{aliases}]"
|
||||
if parent_sig:
|
||||
fmt = parent_sig + ' ' + fmt
|
||||
fmt = parent_sig + " " + fmt
|
||||
alias = fmt
|
||||
else:
|
||||
alias = command.name if not parent_sig else parent_sig + ' ' + command.name
|
||||
alias = command.name if not parent_sig else parent_sig + " " + command.name
|
||||
|
||||
return f'{self.context.clean_prefix}{alias} {command.signature}'
|
||||
return f"{self.context.clean_prefix}{alias} {command.signature}"
|
||||
|
||||
def remove_mentions(self, string):
|
||||
"""Removes mentions from the string to prevent abuse.
|
||||
@ -449,7 +449,7 @@ class HelpCommand:
|
||||
"""
|
||||
|
||||
def replace(obj, *, transforms=self.MENTION_TRANSFORMS):
|
||||
return transforms.get(obj.group(0), '@invalid')
|
||||
return transforms.get(obj.group(0), "@invalid")
|
||||
|
||||
return self.MENTION_PATTERN.sub(replace, string)
|
||||
|
||||
@ -846,7 +846,7 @@ class HelpCommand:
|
||||
# Since we want to have detailed errors when someone
|
||||
# passes an invalid subcommand, we need to walk through
|
||||
# the command group chain ourselves.
|
||||
keys = command.split(' ')
|
||||
keys = command.split(" ")
|
||||
cmd = bot.all_commands.get(keys[0])
|
||||
if cmd is None:
|
||||
string = await maybe_coro(self.command_not_found, self.remove_mentions(keys[0]))
|
||||
@ -907,14 +907,14 @@ class DefaultHelpCommand(HelpCommand):
|
||||
"""
|
||||
|
||||
def __init__(self, **options):
|
||||
self.width = options.pop('width', 80)
|
||||
self.indent = options.pop('indent', 2)
|
||||
self.sort_commands = options.pop('sort_commands', True)
|
||||
self.dm_help = options.pop('dm_help', False)
|
||||
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
|
||||
self.commands_heading = options.pop('commands_heading', "Commands:")
|
||||
self.no_category = options.pop('no_category', 'No Category')
|
||||
self.paginator = options.pop('paginator', None)
|
||||
self.width = options.pop("width", 80)
|
||||
self.indent = options.pop("indent", 2)
|
||||
self.sort_commands = options.pop("sort_commands", True)
|
||||
self.dm_help = options.pop("dm_help", False)
|
||||
self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
|
||||
self.commands_heading = options.pop("commands_heading", "Commands:")
|
||||
self.no_category = options.pop("no_category", "No Category")
|
||||
self.paginator = options.pop("paginator", None)
|
||||
|
||||
if self.paginator is None:
|
||||
self.paginator = Paginator()
|
||||
@ -924,7 +924,7 @@ class DefaultHelpCommand(HelpCommand):
|
||||
def shorten_text(self, text):
|
||||
""":class:`str`: Shortens text to fit into the :attr:`width`."""
|
||||
if len(text) > self.width:
|
||||
return text[:self.width - 3].rstrip() + '...'
|
||||
return text[: self.width - 3].rstrip() + "..."
|
||||
return text
|
||||
|
||||
def get_ending_note(self):
|
||||
@ -1021,11 +1021,11 @@ class DefaultHelpCommand(HelpCommand):
|
||||
# <description> portion
|
||||
self.paginator.add_line(bot.description, empty=True)
|
||||
|
||||
no_category = f'\u200b{self.no_category}:'
|
||||
no_category = f"\u200b{self.no_category}:"
|
||||
|
||||
def get_category(command, *, no_category=no_category):
|
||||
cog = command.cog
|
||||
return cog.qualified_name + ':' if cog is not None else no_category
|
||||
return cog.qualified_name + ":" if cog is not None else no_category
|
||||
|
||||
filtered = await self.filter_commands(bot.commands, sort=True, key=get_category)
|
||||
max_size = self.get_max_size(filtered)
|
||||
@ -1110,13 +1110,13 @@ class MinimalHelpCommand(HelpCommand):
|
||||
"""
|
||||
|
||||
def __init__(self, **options):
|
||||
self.sort_commands = options.pop('sort_commands', True)
|
||||
self.commands_heading = options.pop('commands_heading', "Commands")
|
||||
self.dm_help = options.pop('dm_help', False)
|
||||
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
|
||||
self.aliases_heading = options.pop('aliases_heading', "Aliases:")
|
||||
self.no_category = options.pop('no_category', 'No Category')
|
||||
self.paginator = options.pop('paginator', None)
|
||||
self.sort_commands = options.pop("sort_commands", True)
|
||||
self.commands_heading = options.pop("commands_heading", "Commands")
|
||||
self.dm_help = options.pop("dm_help", False)
|
||||
self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
|
||||
self.aliases_heading = options.pop("aliases_heading", "Aliases:")
|
||||
self.no_category = options.pop("no_category", "No Category")
|
||||
self.paginator = options.pop("paginator", None)
|
||||
|
||||
if self.paginator is None:
|
||||
self.paginator = Paginator(suffix=None, prefix=None)
|
||||
@ -1149,7 +1149,7 @@ class MinimalHelpCommand(HelpCommand):
|
||||
)
|
||||
|
||||
def get_command_signature(self, command):
|
||||
return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}'
|
||||
return f"{self.context.clean_prefix}{command.qualified_name} {command.signature}"
|
||||
|
||||
def get_ending_note(self):
|
||||
"""Return the help command's ending note. This is mainly useful to override for i18n purposes.
|
||||
@ -1180,8 +1180,8 @@ class MinimalHelpCommand(HelpCommand):
|
||||
"""
|
||||
if commands:
|
||||
# U+2002 Middle Dot
|
||||
joined = '\u2002'.join(c.name for c in commands)
|
||||
self.paginator.add_line(f'__**{heading}**__')
|
||||
joined = "\u2002".join(c.name for c in commands)
|
||||
self.paginator.add_line(f"__**{heading}**__")
|
||||
self.paginator.add_line(joined)
|
||||
|
||||
def add_subcommand_formatting(self, command):
|
||||
@ -1197,7 +1197,7 @@ class MinimalHelpCommand(HelpCommand):
|
||||
command: :class:`Command`
|
||||
The command to show information of.
|
||||
"""
|
||||
fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}'
|
||||
fmt = "{0}{1} \N{EN DASH} {2}" if command.short_doc else "{0}{1}"
|
||||
self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc))
|
||||
|
||||
def add_aliases_formatting(self, aliases):
|
||||
@ -1268,7 +1268,7 @@ class MinimalHelpCommand(HelpCommand):
|
||||
if note:
|
||||
self.paginator.add_line(note, empty=True)
|
||||
|
||||
no_category = f'\u200b{self.no_category}'
|
||||
no_category = f"\u200b{self.no_category}"
|
||||
|
||||
def get_category(command, *, no_category=no_category):
|
||||
cog = command.cog
|
||||
@ -1302,7 +1302,7 @@ class MinimalHelpCommand(HelpCommand):
|
||||
|
||||
filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands)
|
||||
if filtered:
|
||||
self.paginator.add_line(f'**{cog.qualified_name} {self.commands_heading}**')
|
||||
self.paginator.add_line(f"**{cog.qualified_name} {self.commands_heading}**")
|
||||
for command in filtered:
|
||||
self.add_subcommand_formatting(command)
|
||||
|
||||
@ -1322,7 +1322,7 @@ class MinimalHelpCommand(HelpCommand):
|
||||
if note:
|
||||
self.paginator.add_line(note, empty=True)
|
||||
|
||||
self.paginator.add_line(f'**{self.commands_heading}**')
|
||||
self.paginator.add_line(f"**{self.commands_heading}**")
|
||||
for command in filtered:
|
||||
self.add_subcommand_formatting(command)
|
||||
|
||||
|
@ -46,6 +46,7 @@ _quotes = {
|
||||
}
|
||||
_all_quotes = set(_quotes.keys()) | set(_quotes.values())
|
||||
|
||||
|
||||
class StringView:
|
||||
def __init__(self, buffer):
|
||||
self.index = 0
|
||||
@ -81,20 +82,20 @@ class StringView:
|
||||
|
||||
def skip_string(self, string):
|
||||
strlen = len(string)
|
||||
if self.buffer[self.index:self.index + strlen] == string:
|
||||
if self.buffer[self.index : self.index + strlen] == string:
|
||||
self.previous = self.index
|
||||
self.index += strlen
|
||||
return True
|
||||
return False
|
||||
|
||||
def read_rest(self):
|
||||
result = self.buffer[self.index:]
|
||||
result = self.buffer[self.index :]
|
||||
self.previous = self.index
|
||||
self.index = self.end
|
||||
return result
|
||||
|
||||
def read(self, n):
|
||||
result = self.buffer[self.index:self.index + n]
|
||||
result = self.buffer[self.index : self.index + n]
|
||||
self.previous = self.index
|
||||
self.index += n
|
||||
return result
|
||||
@ -120,7 +121,7 @@ class StringView:
|
||||
except IndexError:
|
||||
break
|
||||
self.previous = self.index
|
||||
result = self.buffer[self.index:self.index + pos]
|
||||
result = self.buffer[self.index : self.index + pos]
|
||||
self.index += pos
|
||||
return result
|
||||
|
||||
@ -144,11 +145,11 @@ class StringView:
|
||||
if is_quoted:
|
||||
# unexpected EOF
|
||||
raise ExpectedClosingQuoteError(close_quote)
|
||||
return ''.join(result)
|
||||
return "".join(result)
|
||||
|
||||
# currently we accept strings in the format of "hello world"
|
||||
# to embed a quote inside the string you must escape it: "a \"world\""
|
||||
if current == '\\':
|
||||
if current == "\\":
|
||||
next_char = self.get()
|
||||
if not next_char:
|
||||
# string ends with \ and no character after it
|
||||
@ -156,7 +157,7 @@ class StringView:
|
||||
# if we're quoted then we're expecting a closing quote
|
||||
raise ExpectedClosingQuoteError(close_quote)
|
||||
# if we aren't then we just let it through
|
||||
return ''.join(result)
|
||||
return "".join(result)
|
||||
|
||||
if next_char in _escaped_quotes:
|
||||
# escaped quote
|
||||
@ -179,14 +180,13 @@ class StringView:
|
||||
raise InvalidEndOfQuotedStringError(next_char)
|
||||
|
||||
# we're quoted so it's okay
|
||||
return ''.join(result)
|
||||
return "".join(result)
|
||||
|
||||
if current.isspace() and not is_quoted:
|
||||
# end of word found
|
||||
return ''.join(result)
|
||||
return "".join(result)
|
||||
|
||||
result.append(current)
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>'
|
||||
return f"<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>"
|
||||
|
@ -48,19 +48,17 @@ from collections.abc import Sequence
|
||||
from discord.backoff import ExponentialBackoff
|
||||
from discord.utils import MISSING
|
||||
|
||||
__all__ = (
|
||||
'loop',
|
||||
)
|
||||
__all__ = ("loop",)
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
_func = Callable[..., Awaitable[Any]]
|
||||
LF = TypeVar('LF', bound=_func)
|
||||
FT = TypeVar('FT', bound=_func)
|
||||
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
|
||||
LF = TypeVar("LF", bound=_func)
|
||||
FT = TypeVar("FT", bound=_func)
|
||||
ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]])
|
||||
|
||||
|
||||
class SleepHandle:
|
||||
__slots__ = ('future', 'loop', 'handle')
|
||||
__slots__ = ("future", "loop", "handle")
|
||||
|
||||
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self.loop = loop
|
||||
@ -124,7 +122,7 @@ class Loop(Generic[LF]):
|
||||
self._stop_next_iteration = False
|
||||
|
||||
if self.count is not None and self.count <= 0:
|
||||
raise ValueError('count must be greater than 0 or None.')
|
||||
raise ValueError("count must be greater than 0 or None.")
|
||||
|
||||
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
|
||||
self._last_iteration_failed = False
|
||||
@ -132,10 +130,10 @@ class Loop(Generic[LF]):
|
||||
self._next_iteration = None
|
||||
|
||||
if not inspect.iscoroutinefunction(self.coro):
|
||||
raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.')
|
||||
raise TypeError(f"Expected coroutine function, not {type(self.coro).__name__!r}.")
|
||||
|
||||
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
|
||||
coro = getattr(self, '_' + name)
|
||||
coro = getattr(self, "_" + name)
|
||||
if coro is None:
|
||||
return
|
||||
|
||||
@ -150,7 +148,7 @@ class Loop(Generic[LF]):
|
||||
|
||||
async def _loop(self, *args: Any, **kwargs: Any) -> None:
|
||||
backoff = ExponentialBackoff()
|
||||
await self._call_loop_function('before_loop')
|
||||
await self._call_loop_function("before_loop")
|
||||
self._last_iteration_failed = False
|
||||
if self._time is not MISSING:
|
||||
# the time index should be prepared every time the internal loop is started
|
||||
@ -193,10 +191,10 @@ class Loop(Generic[LF]):
|
||||
raise
|
||||
except Exception as exc:
|
||||
self._has_failed = True
|
||||
await self._call_loop_function('error', exc)
|
||||
await self._call_loop_function("error", exc)
|
||||
raise exc
|
||||
finally:
|
||||
await self._call_loop_function('after_loop')
|
||||
await self._call_loop_function("after_loop")
|
||||
self._handle.cancel()
|
||||
self._is_being_cancelled = False
|
||||
self._current_loop = 0
|
||||
@ -323,7 +321,7 @@ class Loop(Generic[LF]):
|
||||
"""
|
||||
|
||||
if self._task is not MISSING and not self._task.done():
|
||||
raise RuntimeError('Task is already launched and is not completed.')
|
||||
raise RuntimeError("Task is already launched and is not completed.")
|
||||
|
||||
if self._injected is not None:
|
||||
args = (self._injected, *args)
|
||||
@ -410,9 +408,9 @@ class Loop(Generic[LF]):
|
||||
|
||||
for exc in exceptions:
|
||||
if not inspect.isclass(exc):
|
||||
raise TypeError(f'{exc!r} must be a class.')
|
||||
raise TypeError(f"{exc!r} must be a class.")
|
||||
if not issubclass(exc, BaseException):
|
||||
raise TypeError(f'{exc!r} must inherit from BaseException.')
|
||||
raise TypeError(f"{exc!r} must inherit from BaseException.")
|
||||
|
||||
self._valid_exception = (*self._valid_exception, *exceptions)
|
||||
|
||||
@ -466,7 +464,7 @@ class Loop(Generic[LF]):
|
||||
|
||||
async def _error(self, *args: Any) -> None:
|
||||
exception: Exception = args[-1]
|
||||
print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr)
|
||||
print(f"Unhandled exception in internal background task {self.coro.__name__!r}.", file=sys.stderr)
|
||||
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
|
||||
|
||||
def before_loop(self, coro: FT) -> FT:
|
||||
@ -489,7 +487,7 @@ class Loop(Generic[LF]):
|
||||
"""
|
||||
|
||||
if not inspect.iscoroutinefunction(coro):
|
||||
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
|
||||
raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
|
||||
|
||||
self._before_loop = coro
|
||||
return coro
|
||||
@ -517,7 +515,7 @@ class Loop(Generic[LF]):
|
||||
"""
|
||||
|
||||
if not inspect.iscoroutinefunction(coro):
|
||||
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
|
||||
raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
|
||||
|
||||
self._after_loop = coro
|
||||
return coro
|
||||
@ -543,7 +541,7 @@ class Loop(Generic[LF]):
|
||||
The function was not a coroutine.
|
||||
"""
|
||||
if not inspect.iscoroutinefunction(coro):
|
||||
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
|
||||
raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
|
||||
|
||||
self._error = coro # type: ignore
|
||||
return coro
|
||||
@ -601,16 +599,16 @@ class Loop(Generic[LF]):
|
||||
return [inner]
|
||||
if not isinstance(time, Sequence):
|
||||
raise TypeError(
|
||||
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.'
|
||||
f"Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead."
|
||||
)
|
||||
if not time:
|
||||
raise ValueError('time parameter must not be an empty sequence.')
|
||||
raise ValueError("time parameter must not be an empty sequence.")
|
||||
|
||||
ret: List[datetime.time] = []
|
||||
for index, t in enumerate(time):
|
||||
if not isinstance(t, dt):
|
||||
raise TypeError(
|
||||
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.'
|
||||
f"Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead."
|
||||
)
|
||||
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
|
||||
|
||||
@ -663,7 +661,7 @@ class Loop(Generic[LF]):
|
||||
hours = hours or 0
|
||||
sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
|
||||
if sleep < 0:
|
||||
raise ValueError('Total number of seconds cannot be less than zero.')
|
||||
raise ValueError("Total number of seconds cannot be less than zero.")
|
||||
|
||||
self._sleep = sleep
|
||||
self._seconds = float(seconds)
|
||||
@ -672,7 +670,7 @@ class Loop(Generic[LF]):
|
||||
self._time: List[datetime.time] = MISSING
|
||||
else:
|
||||
if any((seconds, minutes, hours)):
|
||||
raise TypeError('Cannot mix explicit time with relative time')
|
||||
raise TypeError("Cannot mix explicit time with relative time")
|
||||
self._time = self._get_time_parameter(time)
|
||||
self._sleep = self._seconds = self._minutes = self._hours = MISSING
|
||||
|
||||
|
Reference in New Issue
Block a user