Fix typing issues and improve typing completeness across the library

Co-authored-by: Danny <Rapptz@users.noreply.github.com>
Co-authored-by: Josh <josh.ja.butt@gmail.com>
This commit is contained in:
Stocker
2022-03-13 23:52:10 -04:00
committed by GitHub
parent 603681940f
commit 5aa696ccfa
66 changed files with 1071 additions and 802 deletions

View File

@@ -23,21 +23,35 @@ DEALINGS IN THE SOFTWARE.
"""
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple
T = TypeVar('T')
if TYPE_CHECKING:
from typing_extensions import ParamSpec
from .bot import Bot, AutoShardedBot
from .context import Context
from .cog import Cog
from .errors import CommandError
T = TypeVar('T')
P = ParamSpec('P')
MaybeCoroFunc = Union[
Callable[P, 'Coro[T]'],
Callable[P, T],
]
else:
P = TypeVar('P')
MaybeCoroFunc = Tuple[P, T]
Coro = Coroutine[Any, Any, T]
MaybeCoro = Union[T, Coro[T]]
CoroFunc = Callable[..., Coro[Any]]
ContextT = TypeVar('ContextT', bound='Context')
_Bot = Union['Bot', 'AutoShardedBot']
BotT = TypeVar('BotT', bound=_Bot)
Check = Union[Callable[["Cog", "ContextT"], MaybeCoro[bool]], Callable[["ContextT"], MaybeCoro[bool]]]
Hook = Union[Callable[["Cog", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]]

View File

@@ -33,7 +33,21 @@ import importlib.util
import sys
import traceback
import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union, overload
from typing import (
Any,
Callable,
Mapping,
List,
Dict,
TYPE_CHECKING,
Optional,
TypeVar,
Type,
Union,
Iterable,
Collection,
overload,
)
import discord
from discord import app_commands
@@ -55,10 +69,18 @@ if TYPE_CHECKING:
from discord.message import Message
from discord.abc import User, Snowflake
from ._types import (
_Bot,
BotT,
Check,
CoroFunc,
ContextT,
MaybeCoroFunc,
)
_Prefix = Union[Iterable[str], str]
_PrefixCallable = MaybeCoroFunc[[BotT, Message], _Prefix]
PrefixType = Union[_Prefix, _PrefixCallable[BotT]]
__all__ = (
'when_mentioned',
'when_mentioned_or',
@@ -68,11 +90,9 @@ __all__ = (
T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context')
BT = TypeVar('BT', bound='Union[Bot, AutoShardedBot]')
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
def when_mentioned(bot: _Bot, msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@@ -81,7 +101,7 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
def when_mentioned_or(*prefixes: str) -> Callable[[_Bot, Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@@ -124,27 +144,33 @@ class _DefaultRepr:
return '<default-help-command>'
_default = _DefaultRepr()
_default: Any = _DefaultRepr()
class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, **options):
class BotBase(GroupMixin[None]):
def __init__(
self,
command_prefix: PrefixType[BotT],
help_command: HelpCommand = _default,
description: Optional[str] = None,
**options: Any,
) -> None:
super().__init__(**options)
self.command_prefix = command_prefix
self.command_prefix: PrefixType[BotT] = command_prefix
self.extra_events: Dict[str, List[CoroFunc]] = {}
# Self doesn't have the ClientT bound, but since this is a mixin it technically does
self.__tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) # type: ignore
self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {}
self._checks: List[Check] = []
self._check_once = []
self._before_invoke = None
self._after_invoke = None
self._help_command = None
self.description = inspect.cleandoc(description) if description else ''
self.owner_id = options.get('owner_id')
self.owner_ids = options.get('owner_ids', set())
self.strip_after_prefix = options.get('strip_after_prefix', False)
self._check_once: List[Check] = []
self._before_invoke: Optional[CoroFunc] = None
self._after_invoke: Optional[CoroFunc] = None
self._help_command: Optional[HelpCommand] = None
self.description: str = inspect.cleandoc(description) if description else ''
self.owner_id: Optional[int] = options.get('owner_id')
self.owner_ids: Optional[Collection[int]] = options.get('owner_ids', set())
self.strip_after_prefix: bool = options.get('strip_after_prefix', False)
if self.owner_id and self.owner_ids:
raise TypeError('Both owner_id and owner_ids are set.')
@@ -182,7 +208,7 @@ class BotBase(GroupMixin):
await super().close() # type: ignore
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
async def on_command_error(self, context: Context[BotT], exception: errors.CommandError) -> None:
"""|coro|
The default command error handler provided by the bot.
@@ -237,7 +263,7 @@ class BotBase(GroupMixin):
self.add_check(func) # type: ignore
return func
def add_check(self, func: Check, /, *, call_once: bool = False) -> None:
def add_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
"""Adds a global check to the bot.
This is the non-decorator interface to :meth:`.check`
@@ -261,7 +287,7 @@ class BotBase(GroupMixin):
else:
self._checks.append(func)
def remove_check(self, func: Check, /, *, call_once: bool = False) -> None:
def remove_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
"""Removes a global check from the bot.
This function is idempotent and will not raise an exception
@@ -324,7 +350,7 @@ class BotBase(GroupMixin):
self.add_check(func, call_once=True)
return func
async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool:
async def can_run(self, ctx: Context[BotT], *, call_once: bool = False) -> bool:
data = self._check_once if call_once else self._checks
if len(data) == 0:
@@ -947,7 +973,7 @@ class BotBase(GroupMixin):
# if the load failed, the remnants should have been
# cleaned from the load_extension function call
# so let's load it from our old compiled library.
await lib.setup(self) # type: ignore
await lib.setup(self)
self.__extensions[name] = lib
# revert sys.modules back to normal and raise back to caller
@@ -1015,11 +1041,12 @@ class BotBase(GroupMixin):
"""
prefix = ret = self.command_prefix
if callable(prefix):
ret = await discord.utils.maybe_coroutine(prefix, self, message)
# self will be a Bot or AutoShardedBot
ret = await discord.utils.maybe_coroutine(prefix, self, message) # type: ignore
if not isinstance(ret, str):
try:
ret = list(ret)
ret = list(ret) # type: ignore
except TypeError:
# It's possible that a generator raised this exception. Don't
# replace it with our own error if that's the case.
@@ -1048,15 +1075,15 @@ class BotBase(GroupMixin):
self,
message: Message,
*,
cls: Type[CXT] = ...,
) -> CXT: # type: ignore
cls: Type[ContextT] = ...,
) -> ContextT:
...
async def get_context(
self,
message: Message,
*,
cls: Type[CXT] = MISSING,
cls: Type[ContextT] = MISSING,
) -> Any:
r"""|coro|
@@ -1137,7 +1164,7 @@ class BotBase(GroupMixin):
ctx.command = self.all_commands.get(invoker)
return ctx
async def invoke(self, ctx: Context) -> None:
async def invoke(self, ctx: Context[BotT]) -> None:
"""|coro|
Invokes the command given under the invocation context and
@@ -1189,9 +1216,10 @@ class BotBase(GroupMixin):
return
ctx = await self.get_context(message)
await self.invoke(ctx)
# the type of the invocation context's bot attribute will be correct
await self.invoke(ctx) # type: ignore
async def on_message(self, message):
async def on_message(self, message: Message) -> None:
await self.process_commands(message)

View File

@@ -30,7 +30,7 @@ from discord.utils import maybe_coroutine
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union
from ._types import _BaseCommand
from ._types import _BaseCommand, BotT
if TYPE_CHECKING:
from typing_extensions import Self
@@ -112,7 +112,7 @@ class CogMeta(type):
__cog_name__: str
__cog_settings__: Dict[str, Any]
__cog_commands__: List[Command]
__cog_commands__: List[Command[Any, ..., Any]]
__cog_is_app_commands_group__: bool
__cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]]
__cog_listeners__: List[Tuple[str, str]]
@@ -406,7 +406,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
def bot_check_once(self, ctx: Context) -> bool:
def bot_check_once(self, ctx: Context[BotT]) -> bool:
"""A special method that registers as a :meth:`.Bot.check_once`
check.
@@ -416,7 +416,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
def bot_check(self, ctx: Context) -> bool:
def bot_check(self, ctx: Context[BotT]) -> bool:
"""A special method that registers as a :meth:`.Bot.check`
check.
@@ -426,7 +426,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
def cog_check(self, ctx: Context) -> bool:
def cog_check(self, ctx: Context[BotT]) -> bool:
"""A special method that registers as a :func:`~discord.ext.commands.check`
for every command and subcommand in this cog.
@@ -436,7 +436,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
async def cog_command_error(self, ctx: Context[BotT], error: Exception) -> None:
"""A special method that is called whenever an error
is dispatched inside this cog.
@@ -455,7 +455,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
async def cog_before_invoke(self, ctx: Context) -> None:
async def cog_before_invoke(self, ctx: Context[BotT]) -> None:
"""A special method that acts as a cog local pre-invoke hook.
This is similar to :meth:`.Command.before_invoke`.
@@ -470,7 +470,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
async def cog_after_invoke(self, ctx: Context) -> None:
async def cog_after_invoke(self, ctx: Context[BotT]) -> None:
"""A special method that acts as a cog local post-invoke hook.
This is similar to :meth:`.Command.after_invoke`.

View File

@@ -28,6 +28,8 @@ import re
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union
from ._types import BotT
import discord.abc
import discord.utils
@@ -59,7 +61,6 @@ MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog")
if TYPE_CHECKING:
@@ -133,10 +134,10 @@ class Context(discord.abc.Messageable, Generic[BotT]):
args: List[Any] = MISSING,
kwargs: Dict[str, Any] = MISSING,
prefix: Optional[str] = None,
command: Optional[Command] = None,
command: Optional[Command[Any, ..., Any]] = None,
invoked_with: Optional[str] = None,
invoked_parents: List[str] = MISSING,
invoked_subcommand: Optional[Command] = None,
invoked_subcommand: Optional[Command[Any, ..., Any]] = None,
subcommand_passed: Optional[str] = None,
command_failed: bool = False,
current_parameter: Optional[inspect.Parameter] = None,
@@ -146,11 +147,11 @@ class Context(discord.abc.Messageable, Generic[BotT]):
self.args: List[Any] = args or []
self.kwargs: Dict[str, Any] = kwargs or {}
self.prefix: Optional[str] = prefix
self.command: Optional[Command] = command
self.command: Optional[Command[Any, ..., Any]] = command
self.view: StringView = view
self.invoked_with: Optional[str] = invoked_with
self.invoked_parents: List[str] = invoked_parents or []
self.invoked_subcommand: Optional[Command] = invoked_subcommand
self.invoked_subcommand: Optional[Command[Any, ..., Any]] = invoked_subcommand
self.subcommand_passed: Optional[str] = subcommand_passed
self.command_failed: bool = command_failed
self.current_parameter: Optional[inspect.Parameter] = current_parameter
@@ -361,7 +362,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
return None
cmd = cmd.copy()
cmd.context = self
cmd.context = self # type: ignore
if len(args) == 0:
await cmd.prepare_help_command(self, None)
mapping = cmd.get_bot_mapping()
@@ -390,7 +391,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
try:
if hasattr(entity, '__cog_commands__'):
injected = wrap_callback(cmd.send_cog_help)
return await injected(entity)
return await injected(entity) # type: ignore
elif isinstance(entity, Group):
injected = wrap_callback(cmd.send_group_help)
return await injected(entity)

View File

@@ -41,7 +41,6 @@ from typing import (
Tuple,
Union,
runtime_checkable,
overload,
)
import discord
@@ -51,9 +50,8 @@ if TYPE_CHECKING:
from .context import Context
from discord.state import Channel
from discord.threads import Thread
from .bot import Bot, AutoShardedBot
_Bot = TypeVar('_Bot', bound=Union[Bot, AutoShardedBot])
from ._types import BotT, _Bot
__all__ = (
@@ -87,7 +85,7 @@ __all__ = (
)
def _get_from_guilds(bot, getter, argument):
def _get_from_guilds(bot: _Bot, getter: str, argument: Any) -> Any:
result = None
for guild in bot.guilds:
result = getattr(guild, getter)(argument)
@@ -115,7 +113,7 @@ class Converter(Protocol[T_co]):
method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`.
"""
async def convert(self, ctx: Context, argument: str) -> T_co:
async def convert(self, ctx: Context[BotT], argument: str) -> T_co:
"""|coro|
The method to override to do conversion logic.
@@ -163,7 +161,7 @@ class ObjectConverter(IDConverter[discord.Object]):
2. Lookup by member, role, or channel mention.
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Object:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Object:
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument)
if match is None:
@@ -196,7 +194,7 @@ class MemberConverter(IDConverter[discord.Member]):
optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled.
"""
async def query_member_named(self, guild, argument):
async def query_member_named(self, guild: discord.Guild, argument: str) -> Optional[discord.Member]:
cache = guild._state.member_cache_flags.joined
if len(argument) > 5 and argument[-5] == '#':
username, _, discriminator = argument.rpartition('#')
@@ -206,7 +204,7 @@ class MemberConverter(IDConverter[discord.Member]):
members = await guild.query_members(argument, limit=100, cache=cache)
return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members)
async def query_member_by_id(self, bot, guild, user_id):
async def query_member_by_id(self, bot: _Bot, guild: discord.Guild, user_id: int) -> Optional[discord.Member]:
ws = bot._get_websocket(shard_id=guild.shard_id)
cache = guild._state.member_cache_flags.joined
if ws.is_ratelimited():
@@ -227,7 +225,7 @@ class MemberConverter(IDConverter[discord.Member]):
return None
return members[0]
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Member:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Member:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
guild = ctx.guild
@@ -281,7 +279,7 @@ class UserConverter(IDConverter[discord.User]):
and it's not available in cache.
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.User:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.User:
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
result = None
state = ctx._state
@@ -359,7 +357,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod
def _resolve_channel(
ctx: Context[_Bot], guild_id: Optional[int], channel_id: Optional[int]
ctx: Context[BotT], guild_id: Optional[int], channel_id: Optional[int]
) -> Optional[Union[Channel, Thread]]:
if channel_id is None:
# we were passed just a message id so we can assume the channel is the current context channel
@@ -373,7 +371,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
return ctx.bot.get_channel(channel_id)
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialMessage:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialMessage:
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
channel = self._resolve_channel(ctx, guild_id, channel_id)
if not channel or not isinstance(channel, discord.abc.Messageable):
@@ -396,7 +394,7 @@ class MessageConverter(IDConverter[discord.Message]):
Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Message:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Message:
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument)
message = ctx.bot._connection._get_message(message_id)
if message:
@@ -427,11 +425,11 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.abc.GuildChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.abc.GuildChannel:
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel)
@staticmethod
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT:
def _resolve_channel(ctx: Context[BotT], argument: str, attribute: str, type: Type[CT]) -> CT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
@@ -448,7 +446,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
def check(c):
return isinstance(c, type) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
result = discord.utils.find(check, bot.get_all_channels()) # type: ignore
else:
channel_id = int(match.group(1))
if guild:
@@ -463,7 +461,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
return result
@staticmethod
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT:
def _resolve_thread(ctx: Context[BotT], argument: str, attribute: str, type: Type[TT]) -> TT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
@@ -502,7 +500,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.TextChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.TextChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel)
@@ -522,7 +520,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.VoiceChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.VoiceChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel)
@@ -541,7 +539,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
3. Lookup by name
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StageChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.StageChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel)
@@ -561,7 +559,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.CategoryChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.CategoryChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel)
@@ -580,7 +578,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
.. versionadded:: 1.7
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StoreChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.StoreChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel)
@@ -598,7 +596,7 @@ class ThreadConverter(IDConverter[discord.Thread]):
.. versionadded: 2.0
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Thread:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Thread:
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread)
@@ -630,7 +628,7 @@ class ColourConverter(Converter[discord.Colour]):
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
def parse_hex_number(self, argument):
def parse_hex_number(self, argument: str) -> discord.Colour:
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
try:
value = int(arg, base=16)
@@ -641,7 +639,7 @@ class ColourConverter(Converter[discord.Colour]):
else:
return discord.Color(value=value)
def parse_rgb_number(self, argument, number):
def parse_rgb_number(self, argument: str, number: str) -> int:
if number[-1] == '%':
value = int(number[:-1])
if not (0 <= value <= 100):
@@ -653,7 +651,7 @@ class ColourConverter(Converter[discord.Colour]):
raise BadColourArgument(argument)
return value
def parse_rgb(self, argument, *, regex=RGB_REGEX):
def parse_rgb(self, argument: str, *, regex: re.Pattern[str] = RGB_REGEX) -> discord.Colour:
match = regex.match(argument)
if match is None:
raise BadColourArgument(argument)
@@ -663,7 +661,7 @@ class ColourConverter(Converter[discord.Colour]):
blue = self.parse_rgb_number(argument, match.group('b'))
return discord.Color.from_rgb(red, green, blue)
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Colour:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Colour:
if argument[0] == '#':
return self.parse_hex_number(argument[1:])
@@ -704,7 +702,7 @@ class RoleConverter(IDConverter[discord.Role]):
Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Role:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Role:
guild = ctx.guild
if not guild:
raise NoPrivateMessage()
@@ -723,7 +721,7 @@ class RoleConverter(IDConverter[discord.Role]):
class GameConverter(Converter[discord.Game]):
"""Converts to :class:`~discord.Game`."""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Game:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Game:
return discord.Game(name=argument)
@@ -736,7 +734,7 @@ class InviteConverter(Converter[discord.Invite]):
Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Invite:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Invite:
try:
invite = await ctx.bot.fetch_invite(argument)
return invite
@@ -755,7 +753,7 @@ class GuildConverter(IDConverter[discord.Guild]):
.. versionadded:: 1.7
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Guild:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Guild:
match = self._get_id_match(argument)
result = None
@@ -787,7 +785,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Emoji:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Emoji:
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument)
result = None
bot = ctx.bot
@@ -821,7 +819,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialEmoji:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialEmoji:
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument)
if match:
@@ -850,7 +848,7 @@ class GuildStickerConverter(IDConverter[discord.GuildSticker]):
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.GuildSticker:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.GuildSticker:
match = self._get_id_match(argument)
result = None
bot = ctx.bot
@@ -890,7 +888,7 @@ class ScheduledEventConverter(IDConverter[discord.ScheduledEvent]):
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.ScheduledEvent:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.ScheduledEvent:
guild = ctx.guild
match = self._get_id_match(argument)
result = None
@@ -967,7 +965,7 @@ class clean_content(Converter[str]):
self.escape_markdown = escape_markdown
self.remove_markdown = remove_markdown
async def convert(self, ctx: Context[_Bot], argument: str) -> str:
async def convert(self, ctx: Context[BotT], argument: str) -> str:
msg = ctx.message
if ctx.guild:
@@ -1047,10 +1045,10 @@ class Greedy(List[T]):
__slots__ = ('converter',)
def __init__(self, *, converter: T):
self.converter = converter
def __init__(self, *, converter: T) -> None:
self.converter: T = converter
def __repr__(self):
def __repr__(self) -> str:
converter = getattr(self.converter, '__name__', repr(self.converter))
return f'Greedy[{converter}]'
@@ -1099,11 +1097,11 @@ def get_converter(param: inspect.Parameter) -> Any:
_GenericAlias = type(List[T])
def is_generic_type(tp: Any, *, _GenericAlias: Type = _GenericAlias) -> bool:
return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias) # type: ignore
def is_generic_type(tp: Any, *, _GenericAlias: type = _GenericAlias) -> bool:
return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias)
CONVERTER_MAPPING: Dict[Type[Any], Any] = {
CONVERTER_MAPPING: Dict[type, Any] = {
discord.Object: ObjectConverter,
discord.Member: MemberConverter,
discord.User: UserConverter,
@@ -1128,7 +1126,7 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = {
}
async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter):
async def _actual_conversion(ctx: Context[BotT], converter, argument: str, param: inspect.Parameter):
if converter is bool:
return _convert_to_bool(argument)
@@ -1166,7 +1164,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter):
async def run_converters(ctx: Context[BotT], converter: Any, argument: str, param: inspect.Parameter) -> Any:
"""|coro|
Runs converters for a given converter, argument, and parameter.

View File

@@ -220,7 +220,7 @@ class CooldownMapping:
return self._type
@classmethod
def from_cooldown(cls, rate, per, type) -> Self:
def from_cooldown(cls, rate: float, per: float, type: Callable[[Message], Any]) -> Self:
return cls(Cooldown(rate, per), type)
def _bucket_key(self, msg: Message) -> Any:

View File

@@ -61,6 +61,8 @@ if TYPE_CHECKING:
from discord.message import Message
from ._types import (
BotT,
ContextT,
Coro,
CoroFunc,
Check,
@@ -101,7 +103,6 @@ MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CogT = TypeVar('CogT', bound='Optional[Cog]')
CommandT = TypeVar('CommandT', bound='Command')
ContextT = TypeVar('ContextT', bound='Context')
# CHT = TypeVar('CHT', bound='Check')
GroupT = TypeVar('GroupT', bound='Group')
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
@@ -159,9 +160,9 @@ def get_signature_parameters(
return params
def wrap_callback(coro):
def wrap_callback(coro: Callable[P, Coro[T]]) -> Callable[P, Coro[Optional[T]]]:
@functools.wraps(coro)
async def wrapped(*args, **kwargs):
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
try:
ret = await coro(*args, **kwargs)
except CommandError:
@@ -175,9 +176,11 @@ def wrap_callback(coro):
return wrapped
def hooked_wrapped_callback(command, ctx, coro):
def hooked_wrapped_callback(
command: Command[Any, ..., Any], ctx: Context[BotT], coro: Callable[P, Coro[T]]
) -> Callable[P, Coro[Optional[T]]]:
@functools.wraps(coro)
async def wrapped(*args, **kwargs):
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
try:
ret = await coro(*args, **kwargs)
except CommandError:
@@ -191,7 +194,7 @@ def hooked_wrapped_callback(command, ctx, coro):
raise CommandInvokeError(exc) from exc
finally:
if command._max_concurrency is not None:
await command._max_concurrency.release(ctx)
await command._max_concurrency.release(ctx.message)
await command.call_after_hooks(ctx)
return ret
@@ -359,7 +362,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
except AttributeError:
checks = kwargs.get('checks', [])
self.checks: List[Check] = checks
self.checks: List[Check[ContextT]] = checks
try:
cooldown = func.__commands_cooldown__
@@ -387,8 +390,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.cog: CogT = None
# bandaid for the fact that sometimes parent can be the bot instance
parent = kwargs.get('parent')
self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore
parent: Optional[GroupMixin[Any]] = kwargs.get('parent')
self.parent: Optional[GroupMixin[Any]] = parent if isinstance(parent, _BaseCommand) else None
self._before_invoke: Optional[Hook] = None
try:
@@ -422,16 +425,16 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
) -> None:
self._callback = function
unwrap = unwrap_function(function)
self.module = unwrap.__module__
self.module: str = unwrap.__module__
try:
globalns = unwrap.__globals__
except AttributeError:
globalns = {}
self.params = get_signature_parameters(function, globalns)
self.params: Dict[str, inspect.Parameter] = get_signature_parameters(function, globalns)
def add_check(self, func: Check, /) -> None:
def add_check(self, func: Check[ContextT], /) -> None:
"""Adds a check to the command.
This is the non-decorator interface to :func:`.check`.
@@ -450,7 +453,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.checks.append(func)
def remove_check(self, func: Check, /) -> None:
def remove_check(self, func: Check[ContextT], /) -> None:
"""Removes a check from the command.
This function is idempotent and will not raise an exception
@@ -484,7 +487,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs))
self.cog = cog
async def __call__(self, context: Context, *args: P.args, **kwargs: P.kwargs) -> T:
async def __call__(self, context: Context[BotT], *args: P.args, **kwargs: P.kwargs) -> T:
"""|coro|
Calls the internal callback that the command holds.
@@ -539,7 +542,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
else:
return self.copy()
async def dispatch_error(self, ctx: Context, error: Exception) -> None:
async def dispatch_error(self, ctx: Context[BotT], error: CommandError) -> None:
ctx.command_failed = True
cog = self.cog
try:
@@ -549,7 +552,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
else:
injected = wrap_callback(coro)
if cog is not None:
await injected(cog, ctx, error)
await injected(cog, ctx, error) # type: ignore
else:
await injected(ctx, error)
@@ -562,7 +565,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
finally:
ctx.bot.dispatch('command_error', ctx, error)
async def transform(self, ctx: Context, param: inspect.Parameter) -> Any:
async def transform(self, ctx: Context[BotT], param: inspect.Parameter) -> Any:
required = param.default is param.empty
converter = get_converter(param)
consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw
@@ -610,7 +613,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
# type-checker fails to narrow argument
return await run_converters(ctx, converter, argument, param) # type: ignore
async def _transform_greedy_pos(self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any) -> Any:
async def _transform_greedy_pos(
self, ctx: Context[BotT], param: inspect.Parameter, required: bool, converter: Any
) -> Any:
view = ctx.view
result = []
while not view.eof:
@@ -631,7 +636,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return param.default
return result
async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any:
async def _transform_greedy_var_pos(self, ctx: Context[BotT], param: inspect.Parameter, converter: Any) -> Any:
view = ctx.view
previous = view.index
try:
@@ -669,7 +674,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return ' '.join(reversed(entries))
@property
def parents(self) -> List[Group]:
def parents(self) -> List[Group[Any, ..., Any]]:
"""List[:class:`Group`]: Retrieves the parents of this command.
If the command has no parents then it returns an empty :class:`list`.
@@ -687,7 +692,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return entries
@property
def root_parent(self) -> Optional[Group]:
def root_parent(self) -> Optional[Group[Any, ..., Any]]:
"""Optional[:class:`Group`]: Retrieves the root parent of this command.
If the command has no parents then it returns ``None``.
@@ -716,7 +721,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
def __str__(self) -> str:
return self.qualified_name
async def _parse_arguments(self, ctx: Context) -> None:
async def _parse_arguments(self, ctx: Context[BotT]) -> None:
ctx.args = [ctx] if self.cog is None else [self.cog, ctx]
ctx.kwargs = {}
args = ctx.args
@@ -752,7 +757,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if not self.ignore_extra and not view.eof:
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
async def call_before_hooks(self, ctx: Context) -> None:
async def call_before_hooks(self, ctx: Context[BotT]) -> None:
# now that we're done preparing we can call the pre-command hooks
# first, call the command local hook:
cog = self.cog
@@ -777,7 +782,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if hook is not None:
await hook(ctx)
async def call_after_hooks(self, ctx: Context) -> None:
async def call_after_hooks(self, ctx: Context[BotT]) -> None:
cog = self.cog
if self._after_invoke is not None:
instance = getattr(self._after_invoke, '__self__', cog)
@@ -796,7 +801,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if hook is not None:
await hook(ctx)
def _prepare_cooldowns(self, ctx: Context) -> None:
def _prepare_cooldowns(self, ctx: Context[BotT]) -> None:
if self._buckets.valid:
dt = ctx.message.edited_at or ctx.message.created_at
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
@@ -806,7 +811,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if retry_after:
raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore
async def prepare(self, ctx: Context) -> None:
async def prepare(self, ctx: Context[BotT]) -> None:
ctx.command = self
if not await self.can_run(ctx):
@@ -830,7 +835,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
await self._max_concurrency.release(ctx) # type: ignore
raise
def is_on_cooldown(self, ctx: Context) -> bool:
def is_on_cooldown(self, ctx: Context[BotT]) -> bool:
"""Checks whether the command is currently on cooldown.
Parameters
@@ -851,7 +856,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
return bucket.get_tokens(current) == 0
def reset_cooldown(self, ctx: Context) -> None:
def reset_cooldown(self, ctx: Context[BotT]) -> None:
"""Resets the cooldown on this command.
Parameters
@@ -863,7 +868,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
bucket = self._buckets.get_bucket(ctx.message)
bucket.reset()
def get_cooldown_retry_after(self, ctx: Context) -> float:
def get_cooldown_retry_after(self, ctx: Context[BotT]) -> float:
"""Retrieves the amount of seconds before this command can be tried again.
.. versionadded:: 1.4
@@ -887,7 +892,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return 0.0
async def invoke(self, ctx: Context) -> None:
async def invoke(self, ctx: Context[BotT]) -> None:
await self.prepare(ctx)
# terminate the invoked_subcommand chain.
@@ -896,9 +901,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
ctx.invoked_subcommand = None
ctx.subcommand_passed = None
injected = hooked_wrapped_callback(self, ctx, self.callback)
await injected(*ctx.args, **ctx.kwargs)
await injected(*ctx.args, **ctx.kwargs) # type: ignore
async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None:
async def reinvoke(self, ctx: Context[BotT], *, call_hooks: bool = False) -> None:
ctx.command = self
await self._parse_arguments(ctx)
@@ -936,7 +941,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.')
self.on_error: Error = coro
self.on_error: Error[Any] = coro
return coro
def has_error_handler(self) -> bool:
@@ -1075,7 +1080,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return ' '.join(result)
async def can_run(self, ctx: Context) -> bool:
async def can_run(self, ctx: Context[BotT]) -> bool:
"""|coro|
Checks if the command can be executed by checking all the predicates
@@ -1341,7 +1346,7 @@ class GroupMixin(Generic[CogT]):
def command(
self,
name: str = MISSING,
cls: Type[Command] = MISSING,
cls: Type[Command[Any, ..., Any]] = MISSING,
*args: Any,
**kwargs: Any,
) -> Any:
@@ -1401,7 +1406,7 @@ class GroupMixin(Generic[CogT]):
def group(
self,
name: str = MISSING,
cls: Type[Group] = MISSING,
cls: Type[Group[Any, ..., Any]] = MISSING,
*args: Any,
**kwargs: Any,
) -> Any:
@@ -1461,9 +1466,9 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
ret = super().copy()
for cmd in self.commands:
ret.add_command(cmd.copy())
return ret # type: ignore
return ret
async def invoke(self, ctx: Context) -> None:
async def invoke(self, ctx: Context[BotT]) -> None:
ctx.invoked_subcommand = None
ctx.subcommand_passed = None
early_invoke = not self.invoke_without_command
@@ -1481,7 +1486,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
if early_invoke:
injected = hooked_wrapped_callback(self, ctx, self.callback)
await injected(*ctx.args, **ctx.kwargs)
await injected(*ctx.args, **ctx.kwargs) # type: ignore
ctx.invoked_parents.append(ctx.invoked_with) # type: ignore
@@ -1494,7 +1499,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
view.previous = previous
await super().invoke(ctx)
async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None:
async def reinvoke(self, ctx: Context[BotT], *, call_hooks: bool = False) -> None:
ctx.invoked_subcommand = None
early_invoke = not self.invoke_without_command
if early_invoke:
@@ -1592,7 +1597,7 @@ def command(
def command(
name: str = MISSING,
cls: Type[Command] = MISSING,
cls: Type[Command[Any, ..., Any]] = MISSING,
**attrs: Any,
) -> Any:
"""A decorator that transforms a function into a :class:`.Command`
@@ -1662,7 +1667,7 @@ def group(
def group(
name: str = MISSING,
cls: Type[Group] = MISSING,
cls: Type[Group[Any, ..., Any]] = MISSING,
**attrs: Any,
) -> Any:
"""A decorator that transforms a function into a :class:`.Group`.
@@ -1679,7 +1684,7 @@ def group(
return command(name=name, cls=cls, **attrs)
def check(predicate: Check) -> Callable[[T], T]:
def check(predicate: Check[ContextT]) -> Callable[[T], T]:
r"""A decorator that adds a check to the :class:`.Command` or its
subclasses. These checks could be accessed via :attr:`.Command.checks`.
@@ -1774,7 +1779,7 @@ def check(predicate: Check) -> Callable[[T], T]:
return decorator # type: ignore
def check_any(*checks: Check) -> Callable[[T], T]:
def check_any(*checks: Check[ContextT]) -> Callable[[T], T]:
r"""A :func:`check` that is added that checks if any of the checks passed
will pass, i.e. using logical OR.
@@ -1827,7 +1832,7 @@ def check_any(*checks: Check) -> Callable[[T], T]:
else:
unwrapped.append(pred)
async def predicate(ctx: Context) -> bool:
async def predicate(ctx: Context[BotT]) -> bool:
errors = []
for func in unwrapped:
try:
@@ -1870,7 +1875,7 @@ def has_role(item: Union[int, str]) -> Callable[[T], T]:
The name or ID of the role to check.
"""
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
if ctx.guild is None:
raise NoPrivateMessage()
@@ -1923,7 +1928,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
raise NoPrivateMessage()
# ctx.guild is None doesn't narrow ctx.author to Member
getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore
getter = functools.partial(discord.utils.get, ctx.author.roles)
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
return True
raise MissingAnyRole(list(items))
@@ -2022,7 +2027,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
ch = ctx.channel
permissions = ch.permissions_for(ctx.author) # type: ignore
@@ -2048,7 +2053,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
guild = ctx.guild
me = guild.me if guild is not None else ctx.bot.user
permissions = ctx.channel.permissions_for(me) # type: ignore
@@ -2077,7 +2082,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
if not ctx.guild:
raise NoPrivateMessage
@@ -2103,7 +2108,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
if not ctx.guild:
raise NoPrivateMessage
@@ -2129,7 +2134,7 @@ def dm_only() -> Callable[[T], T]:
.. versionadded:: 1.1
"""
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
if ctx.guild is not None:
raise PrivateMessageOnly()
return True
@@ -2146,7 +2151,7 @@ def guild_only() -> Callable[[T], T]:
that is inherited from :exc:`.CheckFailure`.
"""
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
if ctx.guild is None:
raise NoPrivateMessage()
return True
@@ -2164,7 +2169,7 @@ def is_owner() -> Callable[[T], T]:
from :exc:`.CheckFailure`.
"""
async def predicate(ctx: Context) -> bool:
async def predicate(ctx: Context[BotT]) -> bool:
if not await ctx.bot.is_owner(ctx.author):
raise NotOwner('You do not own this bot.')
return True
@@ -2184,7 +2189,7 @@ def is_nsfw() -> Callable[[T], T]:
DM channels will also now pass this check.
"""
def pred(ctx: Context) -> bool:
def pred(ctx: Context[BotT]) -> bool:
ch = ctx.channel
if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()):
return True

View File

@@ -39,6 +39,8 @@ if TYPE_CHECKING:
from discord.threads import Thread
from discord.types.snowflake import Snowflake, SnowflakeList
from ._types import BotT
__all__ = (
'CommandError',
@@ -135,8 +137,8 @@ class ConversionError(CommandError):
the ``__cause__`` attribute.
"""
def __init__(self, converter: Converter, original: Exception) -> None:
self.converter: Converter = converter
def __init__(self, converter: Converter[Any], original: Exception) -> None:
self.converter: Converter[Any] = converter
self.original: Exception = original
@@ -224,9 +226,9 @@ class CheckAnyFailure(CheckFailure):
A list of check predicates that failed.
"""
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None:
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context[BotT]], bool]]) -> None:
self.checks: List[CheckFailure] = checks
self.errors: List[Callable[[Context], bool]] = errors
self.errors: List[Callable[[Context[BotT]], bool]] = errors
super().__init__('You do not have permission to run this command.')
@@ -807,9 +809,9 @@ class BadUnionArgument(UserInputError):
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None:
def __init__(self, param: Parameter, converters: Tuple[type, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param
self.converters: Tuple[Type, ...] = converters
self.converters: Tuple[type, ...] = converters
self.errors: List[CommandError] = errors
def _get_name(x):

View File

@@ -49,8 +49,6 @@ from typing import (
Tuple,
List,
Any,
Type,
TypeVar,
Union,
)
@@ -70,6 +68,8 @@ if TYPE_CHECKING:
from .context import Context
from ._types import BotT
@dataclass
class Flag:
@@ -148,7 +148,7 @@ def flag(
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
def validate_flag_name(name: str, forbidden: Set[str]):
def validate_flag_name(name: str, forbidden: Set[str]) -> None:
if not name:
raise ValueError('flag names should not be empty')
@@ -348,7 +348,7 @@ class FlagsMeta(type):
return type.__new__(cls, name, bases, attrs)
async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
async def tuple_convert_all(ctx: Context[BotT], argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
view = StringView(argument)
results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore
@@ -373,7 +373,7 @@ async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter:
return tuple(results)
async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
async def tuple_convert_flag(ctx: Context[BotT], argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
view = StringView(argument)
results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore
@@ -401,7 +401,7 @@ async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters
return tuple(results)
async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -> Any:
async def convert_flag(ctx: Context[BotT], argument: str, flag: Flag, annotation: Any = None) -> Any:
param: inspect.Parameter = ctx.current_parameter # type: ignore
annotation = annotation or flag.annotation
try:
@@ -480,7 +480,7 @@ class FlagConverter(metaclass=FlagsMeta):
yield (flag.name, getattr(self, flag.attribute))
@classmethod
async def _construct_default(cls, ctx: Context) -> Self:
async def _construct_default(cls, ctx: Context[BotT]) -> Self:
self = cls.__new__(cls)
flags = cls.__commands_flags__
for flag in flags.values():
@@ -546,7 +546,7 @@ class FlagConverter(metaclass=FlagsMeta):
return result
@classmethod
async def convert(cls, ctx: Context, argument: str) -> Self:
async def convert(cls, ctx: Context[BotT], argument: str) -> Self:
"""|coro|
The method that actually converters an argument to the flag mapping.
@@ -610,7 +610,7 @@ class FlagConverter(metaclass=FlagsMeta):
values = [await convert_flag(ctx, value, flag) for value in values]
if flag.cast_to_dict:
values = dict(values) # type: ignore
values = dict(values)
setattr(self, flag.attribute, values)

View File

@@ -22,13 +22,27 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import itertools
import copy
import functools
import inspect
import re
from typing import Optional, TYPE_CHECKING
from typing import (
TYPE_CHECKING,
Optional,
Generator,
List,
TypeVar,
Callable,
Any,
Dict,
Tuple,
Iterable,
Sequence,
Mapping,
)
import discord.utils
@@ -36,7 +50,21 @@ from .core import Group, Command, get_signature_parameters
from .errors import CommandError
if TYPE_CHECKING:
from typing_extensions import Self
import inspect
import discord.abc
from .bot import BotBase
from .context import Context
from .cog import Cog
from ._types import (
Check,
ContextT,
BotT,
_Bot,
)
__all__ = (
'Paginator',
@@ -45,7 +73,9 @@ __all__ = (
'MinimalHelpCommand',
)
MISSING = discord.utils.MISSING
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
MISSING: Any = discord.utils.MISSING
# help -> shows info of bot on top/bottom and lists subcommands
# help command -> shows detailed info of command
@@ -80,10 +110,10 @@ class Paginator:
Attributes
-----------
prefix: :class:`str`
The prefix inserted to every page. e.g. three backticks.
suffix: :class:`str`
The suffix appended at the end of every page. e.g. three backticks.
prefix: Optional[:class:`str`]
The prefix inserted to every page. e.g. three backticks, if any.
suffix: Optional[:class:`str`]
The suffix appended at the end of every page. e.g. three backticks, if any.
max_size: :class:`int`
The maximum amount of codepoints allowed in a page.
linesep: :class:`str`
@@ -91,36 +121,38 @@ class Paginator:
.. versionadded:: 1.7
"""
def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'):
self.prefix = prefix
self.suffix = suffix
self.max_size = max_size
self.linesep = linesep
def __init__(
self, prefix: Optional[str] = '```', suffix: Optional[str] = '```', max_size: int = 2000, linesep: str = '\n'
) -> None:
self.prefix: Optional[str] = prefix
self.suffix: Optional[str] = suffix
self.max_size: int = max_size
self.linesep: str = linesep
self.clear()
def clear(self):
def clear(self) -> None:
"""Clears the paginator to have no pages."""
if self.prefix is not None:
self._current_page = [self.prefix]
self._count = len(self.prefix) + self._linesep_len # prefix + newline
self._current_page: List[str] = [self.prefix]
self._count: int = len(self.prefix) + self._linesep_len # prefix + newline
else:
self._current_page = []
self._count = 0
self._pages = []
self._pages: List[str] = []
@property
def _prefix_len(self):
def _prefix_len(self) -> int:
return len(self.prefix) if self.prefix else 0
@property
def _suffix_len(self):
def _suffix_len(self) -> int:
return len(self.suffix) if self.suffix else 0
@property
def _linesep_len(self):
def _linesep_len(self) -> int:
return len(self.linesep)
def add_line(self, line='', *, empty=False):
def add_line(self, line: str = '', *, empty: bool = False) -> None:
"""Adds a line to the current page.
If the line exceeds the :attr:`max_size` then an exception
@@ -152,7 +184,7 @@ class Paginator:
self._current_page.append('')
self._count += self._linesep_len
def close_page(self):
def close_page(self) -> None:
"""Prematurely terminate a page."""
if self.suffix is not None:
self._current_page.append(self.suffix)
@@ -165,36 +197,38 @@ class Paginator:
self._current_page = []
self._count = 0
def __len__(self):
def __len__(self) -> int:
total = sum(len(p) for p in self._pages)
return total + self._count
@property
def pages(self):
def pages(self) -> List[str]:
"""List[:class:`str`]: Returns the rendered list of pages."""
# we have more than just the prefix in our current page
if len(self._current_page) > (0 if self.prefix is None else 1):
self.close_page()
return self._pages
def __repr__(self):
def __repr__(self) -> str:
fmt = '<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>'
return fmt.format(self)
def _not_overridden(f):
def _not_overridden(f: FuncT) -> FuncT:
f.__help_command_not_overridden__ = True
return f
class _HelpCommandImpl(Command):
def __init__(self, inject, *args, **kwargs):
def __init__(self, inject: HelpCommand, *args: Any, **kwargs: Any) -> None:
super().__init__(inject.command_callback, *args, **kwargs)
self._original = inject
self._injected = inject
self.params = get_signature_parameters(inject.command_callback, globals(), skip_parameters=1)
self._original: HelpCommand = inject
self._injected: HelpCommand = inject
self.params: Dict[str, inspect.Parameter] = get_signature_parameters(
inject.command_callback, globals(), skip_parameters=1
)
async def prepare(self, ctx):
async def prepare(self, ctx: Context[Any]) -> None:
self._injected = injected = self._original.copy()
injected.context = ctx
self.callback = injected.command_callback
@@ -209,7 +243,7 @@ class _HelpCommandImpl(Command):
await super().prepare(ctx)
async def _parse_arguments(self, ctx):
async def _parse_arguments(self, ctx: Context[BotT]) -> None:
# Make the parser think we don't have a cog so it doesn't
# inject the parameter into `ctx.args`.
original_cog = self.cog
@@ -219,22 +253,26 @@ class _HelpCommandImpl(Command):
finally:
self.cog = original_cog
async def _on_error_cog_implementation(self, dummy, ctx, error):
async def _on_error_cog_implementation(self, _, ctx: Context[BotT], error: CommandError) -> None:
await self._injected.on_help_command_error(ctx, error)
def _inject_into_cog(self, cog):
def _inject_into_cog(self, cog: Cog) -> None:
# Warning: hacky
# Make the cog think that get_commands returns this command
# as well if we inject it without modifying __cog_commands__
# since that's used for the injection and ejection of cogs.
def wrapped_get_commands(*, _original=cog.get_commands):
def wrapped_get_commands(
*, _original: Callable[[], List[Command[Any, ..., Any]]] = cog.get_commands
) -> List[Command[Any, ..., Any]]:
ret = _original()
ret.append(self)
return ret
# Ditto here
def wrapped_walk_commands(*, _original=cog.walk_commands):
def wrapped_walk_commands(
*, _original: Callable[[], Generator[Command[Any, ..., Any], None, None]] = cog.walk_commands
):
yield from _original()
yield self
@@ -244,7 +282,7 @@ class _HelpCommandImpl(Command):
cog.walk_commands = wrapped_walk_commands
self.cog = cog
def _eject_cog(self):
def _eject_cog(self) -> None:
if self.cog is None:
return
@@ -298,7 +336,11 @@ class HelpCommand:
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys()))
def __new__(cls, *args, **kwargs):
if TYPE_CHECKING:
__original_kwargs__: Dict[str, Any]
__original_args__: Tuple[Any, ...]
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
# To prevent race conditions of a single instance while also allowing
# for settings to be passed the original arguments passed must be assigned
# to allow for easier copies (which will be made when the help command is actually called)
@@ -314,30 +356,31 @@ class HelpCommand:
self.__original_args__ = deepcopy(args)
return self
def __init__(self, **options):
self.show_hidden = options.pop('show_hidden', False)
self.verify_checks = options.pop('verify_checks', True)
def __init__(self, **options: Any) -> None:
self.show_hidden: bool = options.pop('show_hidden', False)
self.verify_checks: bool = options.pop('verify_checks', True)
self.command_attrs: Dict[str, Any]
self.command_attrs = attrs = options.pop('command_attrs', {})
attrs.setdefault('name', 'help')
attrs.setdefault('help', 'Shows this message')
self.context: Context = MISSING
self.context: Context[_Bot] = MISSING
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
def copy(self):
def copy(self) -> Self:
obj = self.__class__(*self.__original_args__, **self.__original_kwargs__)
obj._command_impl = self._command_impl
return obj
def _add_to_bot(self, bot):
def _add_to_bot(self, bot: BotBase) -> None:
command = _HelpCommandImpl(self, **self.command_attrs)
bot.add_command(command)
self._command_impl = command
def _remove_from_bot(self, bot):
def _remove_from_bot(self, bot: BotBase) -> None:
bot.remove_command(self._command_impl.name)
self._command_impl._eject_cog()
def add_check(self, func, /):
def add_check(self, func: Check[ContextT], /) -> None:
"""
Adds a check to the help command.
@@ -355,7 +398,7 @@ class HelpCommand:
self._command_impl.add_check(func)
def remove_check(self, func, /):
def remove_check(self, func: Check[ContextT], /) -> None:
"""
Removes a check from the help command.
@@ -376,15 +419,15 @@ class HelpCommand:
self._command_impl.remove_check(func)
def get_bot_mapping(self):
def get_bot_mapping(self) -> Dict[Optional[Cog], List[Command[Any, ..., Any]]]:
"""Retrieves the bot mapping passed to :meth:`send_bot_help`."""
bot = self.context.bot
mapping = {cog: cog.get_commands() for cog in bot.cogs.values()}
mapping: Dict[Optional[Cog], List[Command[Any, ..., Any]]] = {cog: cog.get_commands() for cog in bot.cogs.values()}
mapping[None] = [c for c in bot.commands if c.cog is None]
return mapping
@property
def invoked_with(self):
def invoked_with(self) -> Optional[str]:
"""Similar to :attr:`Context.invoked_with` except properly handles
the case where :meth:`Context.send_help` is used.
@@ -395,7 +438,7 @@ class HelpCommand:
Returns
---------
:class:`str`
Optional[:class:`str`]
The command name that triggered this invocation.
"""
command_name = self._command_impl.name
@@ -404,7 +447,7 @@ class HelpCommand:
return command_name
return ctx.invoked_with
def get_command_signature(self, command):
def get_command_signature(self, command: Command[Any, ..., Any]) -> str:
"""Retrieves the signature portion of the help page.
Parameters
@@ -418,14 +461,14 @@ class HelpCommand:
The signature for the command.
"""
parent = command.parent
parent: Optional[Group[Any, ..., Any]] = command.parent # type: ignore - the parent will be a Group
entries = []
while parent is not None:
if not parent.signature or parent.invoke_without_command:
entries.append(parent.name)
else:
entries.append(parent.name + ' ' + parent.signature)
parent = parent.parent
parent = parent.parent # type: ignore
parent_sig = ' '.join(reversed(entries))
if len(command.aliases) > 0:
@@ -439,7 +482,7 @@ class HelpCommand:
return f'{self.context.clean_prefix}{alias} {command.signature}'
def remove_mentions(self, string):
def remove_mentions(self, string: str) -> str:
"""Removes mentions from the string to prevent abuse.
This includes ``@everyone``, ``@here``, member mentions and role mentions.
@@ -450,13 +493,13 @@ class HelpCommand:
The string with mentions removed.
"""
def replace(obj, *, transforms=self.MENTION_TRANSFORMS):
def replace(obj: re.Match, *, transforms: Dict[str, str] = self.MENTION_TRANSFORMS) -> str:
return transforms.get(obj.group(0), '@invalid')
return self.MENTION_PATTERN.sub(replace, string)
@property
def cog(self):
def cog(self) -> Optional[Cog]:
"""A property for retrieving or setting the cog for the help command.
When a cog is set for the help command, it is as-if the help command
@@ -473,7 +516,7 @@ class HelpCommand:
return self._command_impl.cog
@cog.setter
def cog(self, cog):
def cog(self, cog: Optional[Cog]) -> None:
# Remove whatever cog is currently valid, if any
self._command_impl._eject_cog()
@@ -481,7 +524,7 @@ class HelpCommand:
if cog is not None:
self._command_impl._inject_into_cog(cog)
def command_not_found(self, string):
def command_not_found(self, string: str) -> str:
"""|maybecoro|
A method called when a command is not found in the help command.
@@ -502,7 +545,7 @@ class HelpCommand:
"""
return f'No command called "{string}" found.'
def subcommand_not_found(self, command, string):
def subcommand_not_found(self, command: Command[Any, ..., Any], string: str) -> str:
"""|maybecoro|
A method called when a command did not have a subcommand requested in the help command.
@@ -532,7 +575,13 @@ class HelpCommand:
return f'Command "{command.qualified_name}" has no subcommand named {string}'
return f'Command "{command.qualified_name}" has no subcommands.'
async def filter_commands(self, commands, *, sort=False, key=None):
async def filter_commands(
self,
commands: Iterable[Command[Any, ..., Any]],
*,
sort: bool = False,
key: Optional[Callable[[Command[Any, ..., Any]], Any]] = None,
) -> List[Command[Any, ..., Any]]:
"""|coro|
Returns a filtered list of commands and optionally sorts them.
@@ -546,7 +595,7 @@ class HelpCommand:
An iterable of commands that are getting filtered.
sort: :class:`bool`
Whether to sort the result.
key: Optional[Callable[:class:`Command`, Any]]
key: Optional[Callable[[:class:`Command`], Any]]
An optional key function to pass to :func:`py:sorted` that
takes a :class:`Command` as its sole parameter. If ``sort`` is
passed as ``True`` then this will default as the command name.
@@ -565,14 +614,14 @@ class HelpCommand:
if self.verify_checks is False:
# if we do not need to verify the checks then we can just
# run it straight through normally without using await.
return sorted(iterator, key=key) if sort else list(iterator)
return sorted(iterator, key=key) if sort else list(iterator) # type: ignore - the key shouldn't be None
if self.verify_checks is None and not self.context.guild:
# if verify_checks is None and we're in a DM, don't verify
return sorted(iterator, key=key) if sort else list(iterator)
return sorted(iterator, key=key) if sort else list(iterator) # type: ignore
# if we're here then we need to check every command if it can run
async def predicate(cmd):
async def predicate(cmd: Command[Any, ..., Any]) -> bool:
try:
return await cmd.can_run(self.context)
except CommandError:
@@ -588,7 +637,7 @@ class HelpCommand:
ret.sort(key=key)
return ret
def get_max_size(self, commands):
def get_max_size(self, commands: Sequence[Command[Any, ..., Any]]) -> int:
"""Returns the largest name length of the specified command list.
Parameters
@@ -605,7 +654,7 @@ class HelpCommand:
as_lengths = (discord.utils._string_width(c.name) for c in commands)
return max(as_lengths, default=0)
def get_destination(self):
def get_destination(self) -> discord.abc.MessageableChannel:
"""Returns the :class:`~discord.abc.Messageable` where the help command will be output.
You can override this method to customise the behaviour.
@@ -619,7 +668,7 @@ class HelpCommand:
"""
return self.context.channel
async def send_error_message(self, error):
async def send_error_message(self, error: str) -> None:
"""|coro|
Handles the implementation when an error happens in the help command.
@@ -644,7 +693,7 @@ class HelpCommand:
await destination.send(error)
@_not_overridden
async def on_help_command_error(self, ctx, error):
async def on_help_command_error(self, ctx: Context[BotT], error: CommandError) -> None:
"""|coro|
The help command's error handler, as specified by :ref:`ext_commands_error_handler`.
@@ -664,7 +713,7 @@ class HelpCommand:
"""
pass
async def send_bot_help(self, mapping):
async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, ..., Any]]]) -> None:
"""|coro|
Handles the implementation of the bot command page in the help command.
@@ -693,7 +742,7 @@ class HelpCommand:
"""
return None
async def send_cog_help(self, cog):
async def send_cog_help(self, cog: Cog) -> None:
"""|coro|
Handles the implementation of the cog page in the help command.
@@ -721,7 +770,7 @@ class HelpCommand:
"""
return None
async def send_group_help(self, group):
async def send_group_help(self, group: Group[Any, ..., Any]) -> None:
"""|coro|
Handles the implementation of the group page in the help command.
@@ -749,7 +798,7 @@ class HelpCommand:
"""
return None
async def send_command_help(self, command):
async def send_command_help(self, command: Command[Any, ..., Any]) -> None:
"""|coro|
Handles the implementation of the single command page in the help command.
@@ -787,7 +836,7 @@ class HelpCommand:
"""
return None
async def prepare_help_command(self, ctx, command=None):
async def prepare_help_command(self, ctx: Context[BotT], command: Optional[str] = None) -> None:
"""|coro|
A low level method that can be used to prepare the help command
@@ -811,7 +860,7 @@ class HelpCommand:
"""
pass
async def command_callback(self, ctx, *, command=None):
async def command_callback(self, ctx: Context[BotT], *, command: Optional[str] = None) -> None:
"""|coro|
The actual implementation of the help command.
@@ -856,7 +905,7 @@ class HelpCommand:
for key in keys[1:]:
try:
found = cmd.all_commands.get(key)
found = cmd.all_commands.get(key) # type: ignore
except AttributeError:
string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key))
return await self.send_error_message(string)
@@ -908,28 +957,28 @@ class DefaultHelpCommand(HelpCommand):
The paginator used to paginate the help command output.
"""
def __init__(self, **options):
self.width = options.pop('width', 80)
self.indent = options.pop('indent', 2)
self.sort_commands = options.pop('sort_commands', True)
self.dm_help = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
self.commands_heading = options.pop('commands_heading', "Commands:")
self.no_category = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None)
def __init__(self, **options: Any) -> None:
self.width: int = options.pop('width', 80)
self.indent: int = options.pop('indent', 2)
self.sort_commands: bool = options.pop('sort_commands', True)
self.dm_help: bool = options.pop('dm_help', False)
self.dm_help_threshold: int = options.pop('dm_help_threshold', 1000)
self.commands_heading: str = options.pop('commands_heading', "Commands:")
self.no_category: str = options.pop('no_category', 'No Category')
self.paginator: Paginator = options.pop('paginator', None)
if self.paginator is None:
self.paginator = Paginator()
self.paginator: Paginator = Paginator()
super().__init__(**options)
def shorten_text(self, text):
def shorten_text(self, text: str) -> str:
""":class:`str`: Shortens text to fit into the :attr:`width`."""
if len(text) > self.width:
return text[: self.width - 3].rstrip() + '...'
return text
def get_ending_note(self):
def get_ending_note(self) -> str:
""":class:`str`: Returns help command's ending note. This is mainly useful to override for i18n purposes."""
command_name = self.invoked_with
return (
@@ -937,7 +986,9 @@ class DefaultHelpCommand(HelpCommand):
f"You can also type {self.context.clean_prefix}{command_name} category for more info on a category."
)
def add_indented_commands(self, commands, *, heading, max_size=None):
def add_indented_commands(
self, commands: Sequence[Command[Any, ..., Any]], *, heading: str, max_size: Optional[int] = None
) -> None:
"""Indents a list of commands after the specified heading.
The formatting is added to the :attr:`paginator`.
@@ -973,13 +1024,13 @@ class DefaultHelpCommand(HelpCommand):
entry = f'{self.indent * " "}{name:<{width}} {command.short_doc}'
self.paginator.add_line(self.shorten_text(entry))
async def send_pages(self):
async def send_pages(self) -> None:
"""A helper utility to send the page output from :attr:`paginator` to the destination."""
destination = self.get_destination()
for page in self.paginator.pages:
await destination.send(page)
def add_command_formatting(self, command):
def add_command_formatting(self, command: Command[Any, ..., Any]) -> None:
"""A utility function to format the non-indented block of commands and groups.
Parameters
@@ -1002,7 +1053,7 @@ class DefaultHelpCommand(HelpCommand):
self.paginator.add_line(line)
self.paginator.add_line()
def get_destination(self):
def get_destination(self) -> discord.abc.Messageable:
ctx = self.context
if self.dm_help is True:
return ctx.author
@@ -1011,11 +1062,11 @@ class DefaultHelpCommand(HelpCommand):
else:
return ctx.channel
async def prepare_help_command(self, ctx, command):
async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None:
self.paginator.clear()
await super().prepare_help_command(ctx, command)
async def send_bot_help(self, mapping):
async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, ..., Any]]]) -> None:
ctx = self.context
bot = ctx.bot
@@ -1045,12 +1096,12 @@ class DefaultHelpCommand(HelpCommand):
await self.send_pages()
async def send_command_help(self, command):
async def send_command_help(self, command: Command[Any, ..., Any]) -> None:
self.add_command_formatting(command)
self.paginator.close_page()
await self.send_pages()
async def send_group_help(self, group):
async def send_group_help(self, group: Group[Any, ..., Any]) -> None:
self.add_command_formatting(group)
filtered = await self.filter_commands(group.commands, sort=self.sort_commands)
@@ -1064,7 +1115,7 @@ class DefaultHelpCommand(HelpCommand):
await self.send_pages()
async def send_cog_help(self, cog):
async def send_cog_help(self, cog: Cog) -> None:
if cog.description:
self.paginator.add_line(cog.description, empty=True)
@@ -1111,27 +1162,27 @@ class MinimalHelpCommand(HelpCommand):
The paginator used to paginate the help command output.
"""
def __init__(self, **options):
self.sort_commands = options.pop('sort_commands', True)
self.commands_heading = options.pop('commands_heading', "Commands")
self.dm_help = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
self.aliases_heading = options.pop('aliases_heading', "Aliases:")
self.no_category = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None)
def __init__(self, **options: Any) -> None:
self.sort_commands: bool = options.pop('sort_commands', True)
self.commands_heading: str = options.pop('commands_heading', "Commands")
self.dm_help: bool = options.pop('dm_help', False)
self.dm_help_threshold: int = options.pop('dm_help_threshold', 1000)
self.aliases_heading: str = options.pop('aliases_heading', "Aliases:")
self.no_category: str = options.pop('no_category', 'No Category')
self.paginator: Paginator = options.pop('paginator', None)
if self.paginator is None:
self.paginator = Paginator(suffix=None, prefix=None)
self.paginator: Paginator = Paginator(suffix=None, prefix=None)
super().__init__(**options)
async def send_pages(self):
async def send_pages(self) -> None:
"""A helper utility to send the page output from :attr:`paginator` to the destination."""
destination = self.get_destination()
for page in self.paginator.pages:
await destination.send(page)
def get_opening_note(self):
def get_opening_note(self) -> str:
"""Returns help command's opening note. This is mainly useful to override for i18n purposes.
The default implementation returns ::
@@ -1150,10 +1201,10 @@ class MinimalHelpCommand(HelpCommand):
f"You can also use `{self.context.clean_prefix}{command_name} [category]` for more info on a category."
)
def get_command_signature(self, command):
def get_command_signature(self, command: Command[Any, ..., Any]) -> str:
return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}'
def get_ending_note(self):
def get_ending_note(self) -> str:
"""Return the help command's ending note. This is mainly useful to override for i18n purposes.
The default implementation does nothing.
@@ -1163,9 +1214,9 @@ class MinimalHelpCommand(HelpCommand):
:class:`str`
The help command ending note.
"""
return None
return ''
def add_bot_commands_formatting(self, commands, heading):
def add_bot_commands_formatting(self, commands: Sequence[Command[Any, ..., Any]], heading: str) -> None:
"""Adds the minified bot heading with commands to the output.
The formatting should be added to the :attr:`paginator`.
@@ -1186,7 +1237,7 @@ class MinimalHelpCommand(HelpCommand):
self.paginator.add_line(f'__**{heading}**__')
self.paginator.add_line(joined)
def add_subcommand_formatting(self, command):
def add_subcommand_formatting(self, command: Command[Any, ..., Any]) -> None:
"""Adds formatting information on a subcommand.
The formatting should be added to the :attr:`paginator`.
@@ -1202,7 +1253,7 @@ class MinimalHelpCommand(HelpCommand):
fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}'
self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc))
def add_aliases_formatting(self, aliases):
def add_aliases_formatting(self, aliases: Sequence[str]) -> None:
"""Adds the formatting information on a command's aliases.
The formatting should be added to the :attr:`paginator`.
@@ -1219,7 +1270,7 @@ class MinimalHelpCommand(HelpCommand):
"""
self.paginator.add_line(f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True)
def add_command_formatting(self, command):
def add_command_formatting(self, command: Command[Any, ..., Any]) -> None:
"""A utility function to format commands and groups.
Parameters
@@ -1246,7 +1297,7 @@ class MinimalHelpCommand(HelpCommand):
self.paginator.add_line(line)
self.paginator.add_line()
def get_destination(self):
def get_destination(self) -> discord.abc.Messageable:
ctx = self.context
if self.dm_help is True:
return ctx.author
@@ -1255,11 +1306,11 @@ class MinimalHelpCommand(HelpCommand):
else:
return ctx.channel
async def prepare_help_command(self, ctx, command):
async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None:
self.paginator.clear()
await super().prepare_help_command(ctx, command)
async def send_bot_help(self, mapping):
async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, ..., Any]]]) -> None:
ctx = self.context
bot = ctx.bot
@@ -1272,7 +1323,7 @@ class MinimalHelpCommand(HelpCommand):
no_category = f'\u200b{self.no_category}'
def get_category(command, *, no_category=no_category):
def get_category(command: Command[Any, ..., Any], *, no_category: str = no_category) -> str:
cog = command.cog
return cog.qualified_name if cog is not None else no_category
@@ -1290,7 +1341,7 @@ class MinimalHelpCommand(HelpCommand):
await self.send_pages()
async def send_cog_help(self, cog):
async def send_cog_help(self, cog: Cog) -> None:
bot = self.context.bot
if bot.description:
self.paginator.add_line(bot.description, empty=True)
@@ -1315,7 +1366,7 @@ class MinimalHelpCommand(HelpCommand):
await self.send_pages()
async def send_group_help(self, group):
async def send_group_help(self, group: Group[Any, ..., Any]) -> None:
self.add_command_formatting(group)
filtered = await self.filter_commands(group.commands, sort=self.sort_commands)
@@ -1335,7 +1386,7 @@ class MinimalHelpCommand(HelpCommand):
await self.send_pages()
async def send_command_help(self, command):
async def send_command_help(self, command: Command[Any, ..., Any]) -> None:
self.add_command_formatting(command)
self.paginator.close_page()
await self.send_pages()

View File

@@ -21,6 +21,11 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Optional
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
# map from opening quotes to closing quotes
@@ -47,24 +52,24 @@ _all_quotes = set(_quotes.keys()) | set(_quotes.values())
class StringView:
def __init__(self, buffer):
self.index = 0
self.buffer = buffer
self.end = len(buffer)
def __init__(self, buffer: str) -> None:
self.index: int = 0
self.buffer: str = buffer
self.end: int = len(buffer)
self.previous = 0
@property
def current(self):
def current(self) -> Optional[str]:
return None if self.eof else self.buffer[self.index]
@property
def eof(self):
def eof(self) -> bool:
return self.index >= self.end
def undo(self):
def undo(self) -> None:
self.index = self.previous
def skip_ws(self):
def skip_ws(self) -> bool:
pos = 0
while not self.eof:
try:
@@ -79,7 +84,7 @@ class StringView:
self.index += pos
return self.previous != self.index
def skip_string(self, string):
def skip_string(self, string: str) -> bool:
strlen = len(string)
if self.buffer[self.index : self.index + strlen] == string:
self.previous = self.index
@@ -87,19 +92,19 @@ class StringView:
return True
return False
def read_rest(self):
def read_rest(self) -> str:
result = self.buffer[self.index :]
self.previous = self.index
self.index = self.end
return result
def read(self, n):
def read(self, n: int) -> str:
result = self.buffer[self.index : self.index + n]
self.previous = self.index
self.index += n
return result
def get(self):
def get(self) -> Optional[str]:
try:
result = self.buffer[self.index + 1]
except IndexError:
@@ -109,7 +114,7 @@ class StringView:
self.index += 1
return result
def get_word(self):
def get_word(self) -> str:
pos = 0
while not self.eof:
try:
@@ -119,12 +124,12 @@ class StringView:
pos += 1
except IndexError:
break
self.previous = self.index
self.previous: int = self.index
result = self.buffer[self.index : self.index + pos]
self.index += pos
return result
def get_quoted_word(self):
def get_quoted_word(self) -> Optional[str]:
current = self.current
if current is None:
return None
@@ -187,5 +192,5 @@ class StringView:
result.append(current)
def __repr__(self):
def __repr__(self) -> str:
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>'

View File

@@ -110,15 +110,15 @@ class SleepHandle:
__slots__ = ('future', 'loop', 'handle')
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop
self.future = future = loop.create_future()
self.loop: asyncio.AbstractEventLoop = loop
self.future: asyncio.Future[None] = loop.create_future()
relative_delta = discord.utils.compute_timedelta(dt)
self.handle = loop.call_later(relative_delta, future.set_result, True)
self.handle = loop.call_later(relative_delta, self.future.set_result, True)
def recalculate(self, dt: datetime.datetime) -> None:
self.handle.cancel()
relative_delta = discord.utils.compute_timedelta(dt)
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
self.handle: asyncio.TimerHandle = self.loop.call_later(relative_delta, self.future.set_result, True)
def wait(self) -> asyncio.Future[Any]:
return self.future