mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-09-05 01:16:21 +00:00
Fix typing issues and improve typing completeness across the library
Co-authored-by: Danny <Rapptz@users.noreply.github.com> Co-authored-by: Josh <josh.ja.butt@gmail.com>
This commit is contained in:
@ -41,7 +41,6 @@ from typing import (
|
||||
Tuple,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
overload,
|
||||
)
|
||||
|
||||
import discord
|
||||
@ -51,9 +50,8 @@ if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
from discord.state import Channel
|
||||
from discord.threads import Thread
|
||||
from .bot import Bot, AutoShardedBot
|
||||
|
||||
_Bot = TypeVar('_Bot', bound=Union[Bot, AutoShardedBot])
|
||||
from ._types import BotT, _Bot
|
||||
|
||||
|
||||
__all__ = (
|
||||
@ -87,7 +85,7 @@ __all__ = (
|
||||
)
|
||||
|
||||
|
||||
def _get_from_guilds(bot, getter, argument):
|
||||
def _get_from_guilds(bot: _Bot, getter: str, argument: Any) -> Any:
|
||||
result = None
|
||||
for guild in bot.guilds:
|
||||
result = getattr(guild, getter)(argument)
|
||||
@ -115,7 +113,7 @@ class Converter(Protocol[T_co]):
|
||||
method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`.
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> T_co:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> T_co:
|
||||
"""|coro|
|
||||
|
||||
The method to override to do conversion logic.
|
||||
@ -163,7 +161,7 @@ class ObjectConverter(IDConverter[discord.Object]):
|
||||
2. Lookup by member, role, or channel mention.
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Object:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Object:
|
||||
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument)
|
||||
|
||||
if match is None:
|
||||
@ -196,7 +194,7 @@ class MemberConverter(IDConverter[discord.Member]):
|
||||
optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled.
|
||||
"""
|
||||
|
||||
async def query_member_named(self, guild, argument):
|
||||
async def query_member_named(self, guild: discord.Guild, argument: str) -> Optional[discord.Member]:
|
||||
cache = guild._state.member_cache_flags.joined
|
||||
if len(argument) > 5 and argument[-5] == '#':
|
||||
username, _, discriminator = argument.rpartition('#')
|
||||
@ -206,7 +204,7 @@ class MemberConverter(IDConverter[discord.Member]):
|
||||
members = await guild.query_members(argument, limit=100, cache=cache)
|
||||
return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members)
|
||||
|
||||
async def query_member_by_id(self, bot, guild, user_id):
|
||||
async def query_member_by_id(self, bot: _Bot, guild: discord.Guild, user_id: int) -> Optional[discord.Member]:
|
||||
ws = bot._get_websocket(shard_id=guild.shard_id)
|
||||
cache = guild._state.member_cache_flags.joined
|
||||
if ws.is_ratelimited():
|
||||
@ -227,7 +225,7 @@ class MemberConverter(IDConverter[discord.Member]):
|
||||
return None
|
||||
return members[0]
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Member:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Member:
|
||||
bot = ctx.bot
|
||||
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
|
||||
guild = ctx.guild
|
||||
@ -281,7 +279,7 @@ class UserConverter(IDConverter[discord.User]):
|
||||
and it's not available in cache.
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.User:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.User:
|
||||
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
|
||||
result = None
|
||||
state = ctx._state
|
||||
@ -359,7 +357,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
|
||||
|
||||
@staticmethod
|
||||
def _resolve_channel(
|
||||
ctx: Context[_Bot], guild_id: Optional[int], channel_id: Optional[int]
|
||||
ctx: Context[BotT], guild_id: Optional[int], channel_id: Optional[int]
|
||||
) -> Optional[Union[Channel, Thread]]:
|
||||
if channel_id is None:
|
||||
# we were passed just a message id so we can assume the channel is the current context channel
|
||||
@ -373,7 +371,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
|
||||
|
||||
return ctx.bot.get_channel(channel_id)
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialMessage:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialMessage:
|
||||
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
|
||||
channel = self._resolve_channel(ctx, guild_id, channel_id)
|
||||
if not channel or not isinstance(channel, discord.abc.Messageable):
|
||||
@ -396,7 +394,7 @@ class MessageConverter(IDConverter[discord.Message]):
|
||||
Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Message:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Message:
|
||||
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument)
|
||||
message = ctx.bot._connection._get_message(message_id)
|
||||
if message:
|
||||
@ -427,11 +425,11 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.abc.GuildChannel:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.abc.GuildChannel:
|
||||
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT:
|
||||
def _resolve_channel(ctx: Context[BotT], argument: str, attribute: str, type: Type[CT]) -> CT:
|
||||
bot = ctx.bot
|
||||
|
||||
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
|
||||
@ -448,7 +446,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
|
||||
def check(c):
|
||||
return isinstance(c, type) and c.name == argument
|
||||
|
||||
result = discord.utils.find(check, bot.get_all_channels())
|
||||
result = discord.utils.find(check, bot.get_all_channels()) # type: ignore
|
||||
else:
|
||||
channel_id = int(match.group(1))
|
||||
if guild:
|
||||
@ -463,7 +461,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT:
|
||||
def _resolve_thread(ctx: Context[BotT], argument: str, attribute: str, type: Type[TT]) -> TT:
|
||||
bot = ctx.bot
|
||||
|
||||
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
|
||||
@ -502,7 +500,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
|
||||
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.TextChannel:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.TextChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel)
|
||||
|
||||
|
||||
@ -522,7 +520,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
|
||||
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.VoiceChannel:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.VoiceChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel)
|
||||
|
||||
|
||||
@ -541,7 +539,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
|
||||
3. Lookup by name
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StageChannel:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.StageChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel)
|
||||
|
||||
|
||||
@ -561,7 +559,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
|
||||
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.CategoryChannel:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.CategoryChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel)
|
||||
|
||||
|
||||
@ -580,7 +578,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StoreChannel:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.StoreChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel)
|
||||
|
||||
|
||||
@ -598,7 +596,7 @@ class ThreadConverter(IDConverter[discord.Thread]):
|
||||
.. versionadded: 2.0
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Thread:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Thread:
|
||||
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread)
|
||||
|
||||
|
||||
@ -630,7 +628,7 @@ class ColourConverter(Converter[discord.Colour]):
|
||||
|
||||
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
|
||||
|
||||
def parse_hex_number(self, argument):
|
||||
def parse_hex_number(self, argument: str) -> discord.Colour:
|
||||
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
|
||||
try:
|
||||
value = int(arg, base=16)
|
||||
@ -641,7 +639,7 @@ class ColourConverter(Converter[discord.Colour]):
|
||||
else:
|
||||
return discord.Color(value=value)
|
||||
|
||||
def parse_rgb_number(self, argument, number):
|
||||
def parse_rgb_number(self, argument: str, number: str) -> int:
|
||||
if number[-1] == '%':
|
||||
value = int(number[:-1])
|
||||
if not (0 <= value <= 100):
|
||||
@ -653,7 +651,7 @@ class ColourConverter(Converter[discord.Colour]):
|
||||
raise BadColourArgument(argument)
|
||||
return value
|
||||
|
||||
def parse_rgb(self, argument, *, regex=RGB_REGEX):
|
||||
def parse_rgb(self, argument: str, *, regex: re.Pattern[str] = RGB_REGEX) -> discord.Colour:
|
||||
match = regex.match(argument)
|
||||
if match is None:
|
||||
raise BadColourArgument(argument)
|
||||
@ -663,7 +661,7 @@ class ColourConverter(Converter[discord.Colour]):
|
||||
blue = self.parse_rgb_number(argument, match.group('b'))
|
||||
return discord.Color.from_rgb(red, green, blue)
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Colour:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Colour:
|
||||
if argument[0] == '#':
|
||||
return self.parse_hex_number(argument[1:])
|
||||
|
||||
@ -704,7 +702,7 @@ class RoleConverter(IDConverter[discord.Role]):
|
||||
Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Role:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Role:
|
||||
guild = ctx.guild
|
||||
if not guild:
|
||||
raise NoPrivateMessage()
|
||||
@ -723,7 +721,7 @@ class RoleConverter(IDConverter[discord.Role]):
|
||||
class GameConverter(Converter[discord.Game]):
|
||||
"""Converts to :class:`~discord.Game`."""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Game:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Game:
|
||||
return discord.Game(name=argument)
|
||||
|
||||
|
||||
@ -736,7 +734,7 @@ class InviteConverter(Converter[discord.Invite]):
|
||||
Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Invite:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Invite:
|
||||
try:
|
||||
invite = await ctx.bot.fetch_invite(argument)
|
||||
return invite
|
||||
@ -755,7 +753,7 @@ class GuildConverter(IDConverter[discord.Guild]):
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Guild:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Guild:
|
||||
match = self._get_id_match(argument)
|
||||
result = None
|
||||
|
||||
@ -787,7 +785,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
|
||||
Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Emoji:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Emoji:
|
||||
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument)
|
||||
result = None
|
||||
bot = ctx.bot
|
||||
@ -821,7 +819,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
|
||||
Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialEmoji:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialEmoji:
|
||||
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument)
|
||||
|
||||
if match:
|
||||
@ -850,7 +848,7 @@ class GuildStickerConverter(IDConverter[discord.GuildSticker]):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.GuildSticker:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.GuildSticker:
|
||||
match = self._get_id_match(argument)
|
||||
result = None
|
||||
bot = ctx.bot
|
||||
@ -890,7 +888,7 @@ class ScheduledEventConverter(IDConverter[discord.ScheduledEvent]):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.ScheduledEvent:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.ScheduledEvent:
|
||||
guild = ctx.guild
|
||||
match = self._get_id_match(argument)
|
||||
result = None
|
||||
@ -967,7 +965,7 @@ class clean_content(Converter[str]):
|
||||
self.escape_markdown = escape_markdown
|
||||
self.remove_markdown = remove_markdown
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> str:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> str:
|
||||
msg = ctx.message
|
||||
|
||||
if ctx.guild:
|
||||
@ -1047,10 +1045,10 @@ class Greedy(List[T]):
|
||||
|
||||
__slots__ = ('converter',)
|
||||
|
||||
def __init__(self, *, converter: T):
|
||||
self.converter = converter
|
||||
def __init__(self, *, converter: T) -> None:
|
||||
self.converter: T = converter
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
converter = getattr(self.converter, '__name__', repr(self.converter))
|
||||
return f'Greedy[{converter}]'
|
||||
|
||||
@ -1099,11 +1097,11 @@ def get_converter(param: inspect.Parameter) -> Any:
|
||||
_GenericAlias = type(List[T])
|
||||
|
||||
|
||||
def is_generic_type(tp: Any, *, _GenericAlias: Type = _GenericAlias) -> bool:
|
||||
return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias) # type: ignore
|
||||
def is_generic_type(tp: Any, *, _GenericAlias: type = _GenericAlias) -> bool:
|
||||
return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias)
|
||||
|
||||
|
||||
CONVERTER_MAPPING: Dict[Type[Any], Any] = {
|
||||
CONVERTER_MAPPING: Dict[type, Any] = {
|
||||
discord.Object: ObjectConverter,
|
||||
discord.Member: MemberConverter,
|
||||
discord.User: UserConverter,
|
||||
@ -1128,7 +1126,7 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = {
|
||||
}
|
||||
|
||||
|
||||
async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter):
|
||||
async def _actual_conversion(ctx: Context[BotT], converter, argument: str, param: inspect.Parameter):
|
||||
if converter is bool:
|
||||
return _convert_to_bool(argument)
|
||||
|
||||
@ -1166,7 +1164,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
|
||||
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
|
||||
|
||||
|
||||
async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter):
|
||||
async def run_converters(ctx: Context[BotT], converter: Any, argument: str, param: inspect.Parameter) -> Any:
|
||||
"""|coro|
|
||||
|
||||
Runs converters for a given converter, argument, and parameter.
|
||||
|
Reference in New Issue
Block a user