Fix typing issues and improve typing completeness across the library

Co-authored-by: Danny <Rapptz@users.noreply.github.com>
Co-authored-by: Josh <josh.ja.butt@gmail.com>
This commit is contained in:
Stocker
2022-03-13 23:52:10 -04:00
committed by GitHub
parent 603681940f
commit 5aa696ccfa
66 changed files with 1071 additions and 802 deletions

View File

@ -41,7 +41,6 @@ from typing import (
Tuple,
Union,
runtime_checkable,
overload,
)
import discord
@ -51,9 +50,8 @@ if TYPE_CHECKING:
from .context import Context
from discord.state import Channel
from discord.threads import Thread
from .bot import Bot, AutoShardedBot
_Bot = TypeVar('_Bot', bound=Union[Bot, AutoShardedBot])
from ._types import BotT, _Bot
__all__ = (
@ -87,7 +85,7 @@ __all__ = (
)
def _get_from_guilds(bot, getter, argument):
def _get_from_guilds(bot: _Bot, getter: str, argument: Any) -> Any:
result = None
for guild in bot.guilds:
result = getattr(guild, getter)(argument)
@ -115,7 +113,7 @@ class Converter(Protocol[T_co]):
method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`.
"""
async def convert(self, ctx: Context, argument: str) -> T_co:
async def convert(self, ctx: Context[BotT], argument: str) -> T_co:
"""|coro|
The method to override to do conversion logic.
@ -163,7 +161,7 @@ class ObjectConverter(IDConverter[discord.Object]):
2. Lookup by member, role, or channel mention.
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Object:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Object:
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument)
if match is None:
@ -196,7 +194,7 @@ class MemberConverter(IDConverter[discord.Member]):
optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled.
"""
async def query_member_named(self, guild, argument):
async def query_member_named(self, guild: discord.Guild, argument: str) -> Optional[discord.Member]:
cache = guild._state.member_cache_flags.joined
if len(argument) > 5 and argument[-5] == '#':
username, _, discriminator = argument.rpartition('#')
@ -206,7 +204,7 @@ class MemberConverter(IDConverter[discord.Member]):
members = await guild.query_members(argument, limit=100, cache=cache)
return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members)
async def query_member_by_id(self, bot, guild, user_id):
async def query_member_by_id(self, bot: _Bot, guild: discord.Guild, user_id: int) -> Optional[discord.Member]:
ws = bot._get_websocket(shard_id=guild.shard_id)
cache = guild._state.member_cache_flags.joined
if ws.is_ratelimited():
@ -227,7 +225,7 @@ class MemberConverter(IDConverter[discord.Member]):
return None
return members[0]
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Member:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Member:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
guild = ctx.guild
@ -281,7 +279,7 @@ class UserConverter(IDConverter[discord.User]):
and it's not available in cache.
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.User:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.User:
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
result = None
state = ctx._state
@ -359,7 +357,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod
def _resolve_channel(
ctx: Context[_Bot], guild_id: Optional[int], channel_id: Optional[int]
ctx: Context[BotT], guild_id: Optional[int], channel_id: Optional[int]
) -> Optional[Union[Channel, Thread]]:
if channel_id is None:
# we were passed just a message id so we can assume the channel is the current context channel
@ -373,7 +371,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
return ctx.bot.get_channel(channel_id)
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialMessage:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialMessage:
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
channel = self._resolve_channel(ctx, guild_id, channel_id)
if not channel or not isinstance(channel, discord.abc.Messageable):
@ -396,7 +394,7 @@ class MessageConverter(IDConverter[discord.Message]):
Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Message:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Message:
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument)
message = ctx.bot._connection._get_message(message_id)
if message:
@ -427,11 +425,11 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.abc.GuildChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.abc.GuildChannel:
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel)
@staticmethod
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT:
def _resolve_channel(ctx: Context[BotT], argument: str, attribute: str, type: Type[CT]) -> CT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
@ -448,7 +446,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
def check(c):
return isinstance(c, type) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
result = discord.utils.find(check, bot.get_all_channels()) # type: ignore
else:
channel_id = int(match.group(1))
if guild:
@ -463,7 +461,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
return result
@staticmethod
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT:
def _resolve_thread(ctx: Context[BotT], argument: str, attribute: str, type: Type[TT]) -> TT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
@ -502,7 +500,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.TextChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.TextChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel)
@ -522,7 +520,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.VoiceChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.VoiceChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel)
@ -541,7 +539,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
3. Lookup by name
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StageChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.StageChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel)
@ -561,7 +559,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.CategoryChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.CategoryChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel)
@ -580,7 +578,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
.. versionadded:: 1.7
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StoreChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.StoreChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel)
@ -598,7 +596,7 @@ class ThreadConverter(IDConverter[discord.Thread]):
.. versionadded: 2.0
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Thread:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Thread:
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread)
@ -630,7 +628,7 @@ class ColourConverter(Converter[discord.Colour]):
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
def parse_hex_number(self, argument):
def parse_hex_number(self, argument: str) -> discord.Colour:
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
try:
value = int(arg, base=16)
@ -641,7 +639,7 @@ class ColourConverter(Converter[discord.Colour]):
else:
return discord.Color(value=value)
def parse_rgb_number(self, argument, number):
def parse_rgb_number(self, argument: str, number: str) -> int:
if number[-1] == '%':
value = int(number[:-1])
if not (0 <= value <= 100):
@ -653,7 +651,7 @@ class ColourConverter(Converter[discord.Colour]):
raise BadColourArgument(argument)
return value
def parse_rgb(self, argument, *, regex=RGB_REGEX):
def parse_rgb(self, argument: str, *, regex: re.Pattern[str] = RGB_REGEX) -> discord.Colour:
match = regex.match(argument)
if match is None:
raise BadColourArgument(argument)
@ -663,7 +661,7 @@ class ColourConverter(Converter[discord.Colour]):
blue = self.parse_rgb_number(argument, match.group('b'))
return discord.Color.from_rgb(red, green, blue)
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Colour:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Colour:
if argument[0] == '#':
return self.parse_hex_number(argument[1:])
@ -704,7 +702,7 @@ class RoleConverter(IDConverter[discord.Role]):
Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Role:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Role:
guild = ctx.guild
if not guild:
raise NoPrivateMessage()
@ -723,7 +721,7 @@ class RoleConverter(IDConverter[discord.Role]):
class GameConverter(Converter[discord.Game]):
"""Converts to :class:`~discord.Game`."""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Game:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Game:
return discord.Game(name=argument)
@ -736,7 +734,7 @@ class InviteConverter(Converter[discord.Invite]):
Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Invite:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Invite:
try:
invite = await ctx.bot.fetch_invite(argument)
return invite
@ -755,7 +753,7 @@ class GuildConverter(IDConverter[discord.Guild]):
.. versionadded:: 1.7
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Guild:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Guild:
match = self._get_id_match(argument)
result = None
@ -787,7 +785,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Emoji:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Emoji:
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument)
result = None
bot = ctx.bot
@ -821,7 +819,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialEmoji:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialEmoji:
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument)
if match:
@ -850,7 +848,7 @@ class GuildStickerConverter(IDConverter[discord.GuildSticker]):
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.GuildSticker:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.GuildSticker:
match = self._get_id_match(argument)
result = None
bot = ctx.bot
@ -890,7 +888,7 @@ class ScheduledEventConverter(IDConverter[discord.ScheduledEvent]):
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.ScheduledEvent:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.ScheduledEvent:
guild = ctx.guild
match = self._get_id_match(argument)
result = None
@ -967,7 +965,7 @@ class clean_content(Converter[str]):
self.escape_markdown = escape_markdown
self.remove_markdown = remove_markdown
async def convert(self, ctx: Context[_Bot], argument: str) -> str:
async def convert(self, ctx: Context[BotT], argument: str) -> str:
msg = ctx.message
if ctx.guild:
@ -1047,10 +1045,10 @@ class Greedy(List[T]):
__slots__ = ('converter',)
def __init__(self, *, converter: T):
self.converter = converter
def __init__(self, *, converter: T) -> None:
self.converter: T = converter
def __repr__(self):
def __repr__(self) -> str:
converter = getattr(self.converter, '__name__', repr(self.converter))
return f'Greedy[{converter}]'
@ -1099,11 +1097,11 @@ def get_converter(param: inspect.Parameter) -> Any:
_GenericAlias = type(List[T])
def is_generic_type(tp: Any, *, _GenericAlias: Type = _GenericAlias) -> bool:
return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias) # type: ignore
def is_generic_type(tp: Any, *, _GenericAlias: type = _GenericAlias) -> bool:
return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias)
CONVERTER_MAPPING: Dict[Type[Any], Any] = {
CONVERTER_MAPPING: Dict[type, Any] = {
discord.Object: ObjectConverter,
discord.Member: MemberConverter,
discord.User: UserConverter,
@ -1128,7 +1126,7 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = {
}
async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter):
async def _actual_conversion(ctx: Context[BotT], converter, argument: str, param: inspect.Parameter):
if converter is bool:
return _convert_to_bool(argument)
@ -1166,7 +1164,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter):
async def run_converters(ctx: Context[BotT], converter: Any, argument: str, param: inspect.Parameter) -> Any:
"""|coro|
Runs converters for a given converter, argument, and parameter.