mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-18 23:15:48 +00:00
Fix typing issues and improve typing completeness across the library
Co-authored-by: Danny <Rapptz@users.noreply.github.com> Co-authored-by: Josh <josh.ja.butt@gmail.com>
This commit is contained in:
parent
603681940f
commit
5aa696ccfa
@ -23,7 +23,8 @@ DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Dict, Optional
|
||||
|
||||
from typing import Optional, Tuple, Dict
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
@ -35,7 +36,7 @@ import aiohttp
|
||||
import platform
|
||||
|
||||
|
||||
def show_version():
|
||||
def show_version() -> None:
|
||||
entries = []
|
||||
|
||||
entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info))
|
||||
@ -52,7 +53,7 @@ def show_version():
|
||||
print('\n'.join(entries))
|
||||
|
||||
|
||||
def core(parser, args):
|
||||
def core(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
|
||||
if args.version:
|
||||
show_version()
|
||||
|
||||
@ -185,7 +186,7 @@ _base_table.update((chr(i), None) for i in range(32))
|
||||
_translation_table = str.maketrans(_base_table)
|
||||
|
||||
|
||||
def to_path(parser, name, *, replace_spaces=False):
|
||||
def to_path(parser: argparse.ArgumentParser, name: str, *, replace_spaces: bool = False) -> Path:
|
||||
if isinstance(name, Path):
|
||||
return name
|
||||
|
||||
@ -223,7 +224,7 @@ def to_path(parser, name, *, replace_spaces=False):
|
||||
return Path(name)
|
||||
|
||||
|
||||
def newbot(parser, args):
|
||||
def newbot(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
|
||||
new_directory = to_path(parser, args.directory) / to_path(parser, args.name)
|
||||
|
||||
# as a note exist_ok for Path is a 3.5+ only feature
|
||||
@ -265,7 +266,7 @@ def newbot(parser, args):
|
||||
print('successfully made bot at', new_directory)
|
||||
|
||||
|
||||
def newcog(parser, args):
|
||||
def newcog(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
|
||||
cog_dir = to_path(parser, args.directory)
|
||||
try:
|
||||
cog_dir.mkdir(exist_ok=True)
|
||||
@ -299,7 +300,7 @@ def newcog(parser, args):
|
||||
print('successfully made cog at', directory)
|
||||
|
||||
|
||||
def add_newbot_args(subparser):
|
||||
def add_newbot_args(subparser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
|
||||
parser = subparser.add_parser('newbot', help='creates a command bot project quickly')
|
||||
parser.set_defaults(func=newbot)
|
||||
|
||||
@ -310,7 +311,7 @@ def add_newbot_args(subparser):
|
||||
parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git')
|
||||
|
||||
|
||||
def add_newcog_args(subparser):
|
||||
def add_newcog_args(subparser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
|
||||
parser = subparser.add_parser('newcog', help='creates a new cog template quickly')
|
||||
parser.set_defaults(func=newcog)
|
||||
|
||||
@ -322,7 +323,7 @@ def add_newcog_args(subparser):
|
||||
parser.add_argument('--full', help='add all special methods as well', action='store_true')
|
||||
|
||||
|
||||
def parse_args():
|
||||
def parse_args() -> Tuple[argparse.ArgumentParser, argparse.Namespace]:
|
||||
parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py')
|
||||
parser.add_argument('-v', '--version', action='store_true', help='shows the library version')
|
||||
parser.set_defaults(func=core)
|
||||
@ -333,7 +334,7 @@ def parse_args():
|
||||
return parser, parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
parser, args = parse_args()
|
||||
args.func(parser, args)
|
||||
|
||||
|
@ -91,6 +91,9 @@ if TYPE_CHECKING:
|
||||
GuildChannel as GuildChannelPayload,
|
||||
OverwriteType,
|
||||
)
|
||||
from .types.snowflake import (
|
||||
SnowflakeList,
|
||||
)
|
||||
|
||||
PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable]
|
||||
MessageableChannel = Union[PartialMessageableChannel, GroupChannel]
|
||||
@ -708,7 +711,14 @@ class GuildChannel:
|
||||
) -> None:
|
||||
...
|
||||
|
||||
async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions):
|
||||
async def set_permissions(
|
||||
self,
|
||||
target: Union[Member, Role],
|
||||
*,
|
||||
overwrite: Any = _undefined,
|
||||
reason: Optional[str] = None,
|
||||
**permissions: bool,
|
||||
) -> None:
|
||||
r"""|coro|
|
||||
|
||||
Sets the channel specific permission overwrites for a target in the
|
||||
@ -917,7 +927,7 @@ class GuildChannel:
|
||||
) -> None:
|
||||
...
|
||||
|
||||
async def move(self, **kwargs) -> None:
|
||||
async def move(self, **kwargs: Any) -> None:
|
||||
"""|coro|
|
||||
|
||||
A rich interface to help move a channel relative to other channels.
|
||||
@ -1248,22 +1258,22 @@ class Messageable:
|
||||
|
||||
async def send(
|
||||
self,
|
||||
content=None,
|
||||
content: Optional[str] = None,
|
||||
*,
|
||||
tts=False,
|
||||
embed=None,
|
||||
embeds=None,
|
||||
file=None,
|
||||
files=None,
|
||||
stickers=None,
|
||||
delete_after=None,
|
||||
nonce=None,
|
||||
allowed_mentions=None,
|
||||
reference=None,
|
||||
mention_author=None,
|
||||
view=None,
|
||||
suppress_embeds=False,
|
||||
):
|
||||
tts: bool = False,
|
||||
embed: Optional[Embed] = None,
|
||||
embeds: Optional[List[Embed]] = None,
|
||||
file: Optional[File] = None,
|
||||
files: Optional[List[File]] = None,
|
||||
stickers: Optional[Sequence[Union[GuildSticker, StickerItem]]] = None,
|
||||
delete_after: Optional[float] = None,
|
||||
nonce: Optional[Union[str, int]] = None,
|
||||
allowed_mentions: Optional[AllowedMentions] = None,
|
||||
reference: Optional[Union[Message, MessageReference, PartialMessage]] = None,
|
||||
mention_author: Optional[bool] = None,
|
||||
view: Optional[View] = None,
|
||||
suppress_embeds: bool = False,
|
||||
) -> Message:
|
||||
"""|coro|
|
||||
|
||||
Sends a message to the destination with the content given.
|
||||
@ -1368,17 +1378,17 @@ class Messageable:
|
||||
previous_allowed_mention = state.allowed_mentions
|
||||
|
||||
if stickers is not None:
|
||||
stickers = [sticker.id for sticker in stickers]
|
||||
sticker_ids: SnowflakeList = [sticker.id for sticker in stickers]
|
||||
else:
|
||||
stickers = MISSING
|
||||
sticker_ids = MISSING
|
||||
|
||||
if reference is not None:
|
||||
try:
|
||||
reference = reference.to_message_reference_dict()
|
||||
reference_dict = reference.to_message_reference_dict()
|
||||
except AttributeError:
|
||||
raise TypeError('reference parameter must be Message, MessageReference, or PartialMessage') from None
|
||||
else:
|
||||
reference = MISSING
|
||||
reference_dict = MISSING
|
||||
|
||||
if view and not hasattr(view, '__discord_ui_view__'):
|
||||
raise TypeError(f'view parameter must be View not {view.__class__!r}')
|
||||
@ -1399,10 +1409,10 @@ class Messageable:
|
||||
embeds=embeds if embeds is not None else MISSING,
|
||||
nonce=nonce,
|
||||
allowed_mentions=allowed_mentions,
|
||||
message_reference=reference,
|
||||
message_reference=reference_dict,
|
||||
previous_allowed_mentions=previous_allowed_mention,
|
||||
mention_author=mention_author,
|
||||
stickers=stickers,
|
||||
stickers=sticker_ids,
|
||||
view=view,
|
||||
flags=flags,
|
||||
) as params:
|
||||
|
@ -123,7 +123,7 @@ class BaseActivity:
|
||||
|
||||
__slots__ = ('_created_at',)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self._created_at: Optional[float] = kwargs.pop('created_at', None)
|
||||
|
||||
@property
|
||||
@ -218,7 +218,7 @@ class Activity(BaseActivity):
|
||||
'buttons',
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.state: Optional[str] = kwargs.pop('state', None)
|
||||
self.details: Optional[str] = kwargs.pop('details', None)
|
||||
@ -363,7 +363,7 @@ class Game(BaseActivity):
|
||||
|
||||
__slots__ = ('name', '_end', '_start')
|
||||
|
||||
def __init__(self, name: str, **extra):
|
||||
def __init__(self, name: str, **extra: Any) -> None:
|
||||
super().__init__(**extra)
|
||||
self.name: str = name
|
||||
|
||||
@ -420,10 +420,10 @@ class Game(BaseActivity):
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, Game) and other.name == self.name
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
@ -477,7 +477,7 @@ class Streaming(BaseActivity):
|
||||
|
||||
__slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets')
|
||||
|
||||
def __init__(self, *, name: Optional[str], url: str, **extra: Any):
|
||||
def __init__(self, *, name: Optional[str], url: str, **extra: Any) -> None:
|
||||
super().__init__(**extra)
|
||||
self.platform: Optional[str] = name
|
||||
self.name: Optional[str] = extra.pop('details', name)
|
||||
@ -501,7 +501,7 @@ class Streaming(BaseActivity):
|
||||
return f'<Streaming name={self.name!r}>'
|
||||
|
||||
@property
|
||||
def twitch_name(self):
|
||||
def twitch_name(self) -> Optional[str]:
|
||||
"""Optional[:class:`str`]: If provided, the twitch name of the user streaming.
|
||||
|
||||
This corresponds to the ``large_image`` key of the :attr:`Streaming.assets`
|
||||
@ -528,10 +528,10 @@ class Streaming(BaseActivity):
|
||||
ret['details'] = self.details
|
||||
return ret
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, Streaming) and other.name == self.name and other.url == self.url
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
@ -563,14 +563,14 @@ class Spotify:
|
||||
|
||||
__slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at')
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
self._state: str = data.pop('state', '')
|
||||
self._details: str = data.pop('details', '')
|
||||
self._timestamps: Dict[str, int] = data.pop('timestamps', {})
|
||||
self._timestamps: ActivityTimestamps = data.pop('timestamps', {})
|
||||
self._assets: ActivityAssets = data.pop('assets', {})
|
||||
self._party: ActivityParty = data.pop('party', {})
|
||||
self._sync_id: str = data.pop('sync_id')
|
||||
self._session_id: str = data.pop('session_id')
|
||||
self._sync_id: str = data.pop('sync_id', '')
|
||||
self._session_id: Optional[str] = data.pop('session_id')
|
||||
self._created_at: Optional[float] = data.pop('created_at', None)
|
||||
|
||||
@property
|
||||
@ -622,7 +622,7 @@ class Spotify:
|
||||
""":class:`str`: The activity's name. This will always return "Spotify"."""
|
||||
return 'Spotify'
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, Spotify)
|
||||
and other._session_id == self._session_id
|
||||
@ -630,7 +630,7 @@ class Spotify:
|
||||
and other.start == self.start
|
||||
)
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
@ -691,12 +691,14 @@ class Spotify:
|
||||
@property
|
||||
def start(self) -> datetime.datetime:
|
||||
""":class:`datetime.datetime`: When the user started playing this song in UTC."""
|
||||
return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc)
|
||||
# the start key will be present here
|
||||
return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc) # type: ignore
|
||||
|
||||
@property
|
||||
def end(self) -> datetime.datetime:
|
||||
""":class:`datetime.datetime`: When the user will stop playing this song in UTC."""
|
||||
return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc)
|
||||
# the end key will be present here
|
||||
return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc) # type: ignore
|
||||
|
||||
@property
|
||||
def duration(self) -> datetime.timedelta:
|
||||
@ -742,7 +744,7 @@ class CustomActivity(BaseActivity):
|
||||
|
||||
__slots__ = ('name', 'emoji', 'state')
|
||||
|
||||
def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any):
|
||||
def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any) -> None:
|
||||
super().__init__(**extra)
|
||||
self.name: Optional[str] = name
|
||||
self.state: Optional[str] = extra.pop('state', None)
|
||||
@ -786,10 +788,10 @@ class CustomActivity(BaseActivity):
|
||||
o['emoji'] = self.emoji.to_dict()
|
||||
return o
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
|
@ -166,7 +166,7 @@ def _validate_auto_complete_callback(
|
||||
return callback
|
||||
|
||||
|
||||
def _context_menu_annotation(annotation: Any, *, _none=NoneType) -> AppCommandType:
|
||||
def _context_menu_annotation(annotation: Any, *, _none: type = NoneType) -> AppCommandType:
|
||||
if annotation is Message:
|
||||
return AppCommandType.message
|
||||
|
||||
@ -686,7 +686,7 @@ class Group:
|
||||
The parent group. ``None`` if there isn't one.
|
||||
"""
|
||||
|
||||
__discord_app_commands_group_children__: ClassVar[List[Union[Command, Group]]] = []
|
||||
__discord_app_commands_group_children__: ClassVar[List[Union[Command[Any, ..., Any], Group]]] = []
|
||||
__discord_app_commands_skip_init_binding__: bool = False
|
||||
__discord_app_commands_group_name__: str = MISSING
|
||||
__discord_app_commands_group_description__: str = MISSING
|
||||
@ -694,10 +694,12 @@ class Group:
|
||||
|
||||
def __init_subclass__(cls, *, name: str = MISSING, description: str = MISSING) -> None:
|
||||
if not cls.__discord_app_commands_group_children__:
|
||||
cls.__discord_app_commands_group_children__ = children = [
|
||||
children: List[Union[Command[Any, ..., Any], Group]] = [
|
||||
member for member in cls.__dict__.values() if isinstance(member, (Group, Command)) and member.parent is None
|
||||
]
|
||||
|
||||
cls.__discord_app_commands_group_children__ = children
|
||||
|
||||
found = set()
|
||||
for child in children:
|
||||
if child.name in found:
|
||||
@ -796,15 +798,15 @@ class Group:
|
||||
"""Optional[:class:`Group`]: The parent of this group."""
|
||||
return self.parent
|
||||
|
||||
def _get_internal_command(self, name: str) -> Optional[Union[Command, Group]]:
|
||||
def _get_internal_command(self, name: str) -> Optional[Union[Command[Any, ..., Any], Group]]:
|
||||
return self._children.get(name)
|
||||
|
||||
@property
|
||||
def commands(self) -> List[Union[Command, Group]]:
|
||||
def commands(self) -> List[Union[Command[Any, ..., Any], Group]]:
|
||||
"""List[Union[:class:`Command`, :class:`Group`]]: The commands that this group contains."""
|
||||
return list(self._children.values())
|
||||
|
||||
async def on_error(self, interaction: Interaction, command: Command, error: AppCommandError) -> None:
|
||||
async def on_error(self, interaction: Interaction, command: Command[Any, ..., Any], error: AppCommandError) -> None:
|
||||
"""|coro|
|
||||
|
||||
A callback that is called when a child's command raises an :exc:`AppCommandError`.
|
||||
@ -823,7 +825,7 @@ class Group:
|
||||
|
||||
pass
|
||||
|
||||
def add_command(self, command: Union[Command, Group], /, *, override: bool = False):
|
||||
def add_command(self, command: Union[Command[Any, ..., Any], Group], /, *, override: bool = False) -> None:
|
||||
"""Adds a command or group to this group's internal list of commands.
|
||||
|
||||
Parameters
|
||||
@ -855,7 +857,7 @@ class Group:
|
||||
if len(self._children) > 25:
|
||||
raise ValueError('maximum number of child commands exceeded')
|
||||
|
||||
def remove_command(self, name: str, /) -> Optional[Union[Command, Group]]:
|
||||
def remove_command(self, name: str, /) -> Optional[Union[Command[Any, ..., Any], Group]]:
|
||||
"""Removes a command or group from the internal list of commands.
|
||||
|
||||
Parameters
|
||||
@ -872,7 +874,7 @@ class Group:
|
||||
|
||||
self._children.pop(name, None)
|
||||
|
||||
def get_command(self, name: str, /) -> Optional[Union[Command, Group]]:
|
||||
def get_command(self, name: str, /) -> Optional[Union[Command[Any, ..., Any], Group]]:
|
||||
"""Retrieves a command or group from its name.
|
||||
|
||||
Parameters
|
||||
@ -1046,7 +1048,7 @@ def describe(**parameters: str) -> Callable[[T], T]:
|
||||
return decorator
|
||||
|
||||
|
||||
def choices(**parameters: List[Choice]) -> Callable[[T], T]:
|
||||
def choices(**parameters: List[Choice[ChoiceT]]) -> Callable[[T], T]:
|
||||
r"""Instructs the given parameters by their name to use the given choices for their choices.
|
||||
|
||||
Example:
|
||||
|
@ -79,9 +79,9 @@ class CommandInvokeError(AppCommandError):
|
||||
The command that failed.
|
||||
"""
|
||||
|
||||
def __init__(self, command: Union[Command, ContextMenu], e: Exception) -> None:
|
||||
def __init__(self, command: Union[Command[Any, ..., Any], ContextMenu], e: Exception) -> None:
|
||||
self.original: Exception = e
|
||||
self.command: Union[Command, ContextMenu] = command
|
||||
self.command: Union[Command[Any, ..., Any], ContextMenu] = command
|
||||
super().__init__(f'Command {command.name!r} raised an exception: {e.__class__.__name__}: {e}')
|
||||
|
||||
|
||||
@ -191,8 +191,8 @@ class CommandSignatureMismatch(AppCommandError):
|
||||
The command that had the signature mismatch.
|
||||
"""
|
||||
|
||||
def __init__(self, command: Union[Command, ContextMenu, Group]):
|
||||
self.command: Union[Command, ContextMenu, Group] = command
|
||||
def __init__(self, command: Union[Command[Any, ..., Any], ContextMenu, Group]):
|
||||
self.command: Union[Command[Any, ..., Any], ContextMenu, Group] = command
|
||||
msg = (
|
||||
f'The signature for command {command.name!r} is different from the one provided by Discord. '
|
||||
'This can happen because either your code is out of date or you have not synced the '
|
||||
|
@ -58,7 +58,10 @@ if TYPE_CHECKING:
|
||||
PartialChannel,
|
||||
PartialThread,
|
||||
)
|
||||
from ..types.threads import ThreadMetadata
|
||||
from ..types.threads import (
|
||||
ThreadMetadata,
|
||||
ThreadArchiveDuration,
|
||||
)
|
||||
from ..state import ConnectionState
|
||||
from ..guild import GuildChannel, Guild
|
||||
from ..channel import TextChannel
|
||||
@ -117,17 +120,19 @@ class AppCommand(Hashable):
|
||||
'_state',
|
||||
)
|
||||
|
||||
def __init__(self, *, data: ApplicationCommandPayload, state=None):
|
||||
self._state = state
|
||||
def __init__(self, *, data: ApplicationCommandPayload, state: Optional[ConnectionState] = None) -> None:
|
||||
self._state: Optional[ConnectionState] = state
|
||||
self._from_data(data)
|
||||
|
||||
def _from_data(self, data: ApplicationCommandPayload):
|
||||
def _from_data(self, data: ApplicationCommandPayload) -> None:
|
||||
self.id: int = int(data['id'])
|
||||
self.application_id: int = int(data['application_id'])
|
||||
self.name: str = data['name']
|
||||
self.description: str = data['description']
|
||||
self.type: AppCommandType = try_enum(AppCommandType, data.get('type', 1))
|
||||
self.options = [app_command_option_factory(data=d, parent=self, state=self._state) for d in data.get('options', [])]
|
||||
self.options: List[Union[Argument, AppCommandGroup]] = [
|
||||
app_command_option_factory(data=d, parent=self, state=self._state) for d in data.get('options', [])
|
||||
]
|
||||
|
||||
def to_dict(self) -> ApplicationCommandPayload:
|
||||
return {
|
||||
@ -262,12 +267,12 @@ class AppCommandChannel(Hashable):
|
||||
data: PartialChannel,
|
||||
guild_id: int,
|
||||
):
|
||||
self._state = state
|
||||
self.guild_id = guild_id
|
||||
self.id = int(data['id'])
|
||||
self.type = try_enum(ChannelType, data['type'])
|
||||
self.name = data['name']
|
||||
self.permissions = Permissions(int(data['permissions']))
|
||||
self._state: ConnectionState = state
|
||||
self.guild_id: int = guild_id
|
||||
self.id: int = int(data['id'])
|
||||
self.type: ChannelType = try_enum(ChannelType, data['type'])
|
||||
self.name: str = data['name']
|
||||
self.permissions: Permissions = Permissions(int(data['permissions']))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
@ -405,13 +410,13 @@ class AppCommandThread(Hashable):
|
||||
data: PartialThread,
|
||||
guild_id: int,
|
||||
):
|
||||
self._state = state
|
||||
self.guild_id = guild_id
|
||||
self.id = int(data['id'])
|
||||
self.parent_id = int(data['parent_id'])
|
||||
self.type = try_enum(ChannelType, data['type'])
|
||||
self.name = data['name']
|
||||
self.permissions = Permissions(int(data['permissions']))
|
||||
self._state: ConnectionState = state
|
||||
self.guild_id: int = guild_id
|
||||
self.id: int = int(data['id'])
|
||||
self.parent_id: int = int(data['parent_id'])
|
||||
self.type: ChannelType = try_enum(ChannelType, data['type'])
|
||||
self.name: str = data['name']
|
||||
self.permissions: Permissions = Permissions(int(data['permissions']))
|
||||
self._unroll_metadata(data['thread_metadata'])
|
||||
|
||||
def __str__(self) -> str:
|
||||
@ -425,14 +430,14 @@ class AppCommandThread(Hashable):
|
||||
"""Optional[:class:`~discord.Guild`]: The channel's guild, from cache, if found."""
|
||||
return self._state._get_guild(self.guild_id)
|
||||
|
||||
def _unroll_metadata(self, data: ThreadMetadata):
|
||||
self.archived = data['archived']
|
||||
self.archiver_id = _get_as_snowflake(data, 'archiver_id')
|
||||
self.auto_archive_duration = data['auto_archive_duration']
|
||||
self.archive_timestamp = parse_time(data['archive_timestamp'])
|
||||
self.locked = data.get('locked', False)
|
||||
self.invitable = data.get('invitable', True)
|
||||
self._created_at = parse_time(data.get('create_timestamp'))
|
||||
def _unroll_metadata(self, data: ThreadMetadata) -> None:
|
||||
self.archived: bool = data['archived']
|
||||
self.archiver_id: Optional[int] = _get_as_snowflake(data, 'archiver_id')
|
||||
self.auto_archive_duration: ThreadArchiveDuration = data['auto_archive_duration']
|
||||
self.archive_timestamp: datetime = parse_time(data['archive_timestamp'])
|
||||
self.locked: bool = data.get('locked', False)
|
||||
self.invitable: bool = data.get('invitable', True)
|
||||
self._created_at: Optional[datetime] = parse_time(data.get('create_timestamp'))
|
||||
|
||||
@property
|
||||
def parent(self) -> Optional[TextChannel]:
|
||||
@ -522,20 +527,24 @@ class Argument:
|
||||
'_state',
|
||||
)
|
||||
|
||||
def __init__(self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state=None):
|
||||
self._state = state
|
||||
self.parent = parent
|
||||
def __init__(
|
||||
self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state: Optional[ConnectionState] = None
|
||||
) -> None:
|
||||
self._state: Optional[ConnectionState] = state
|
||||
self.parent: ApplicationCommandParent = parent
|
||||
self._from_data(data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__} name={self.name!r} type={self.type!r} required={self.required}>'
|
||||
|
||||
def _from_data(self, data: ApplicationCommandOption):
|
||||
def _from_data(self, data: ApplicationCommandOption) -> None:
|
||||
self.type: AppCommandOptionType = try_enum(AppCommandOptionType, data['type'])
|
||||
self.name: str = data['name']
|
||||
self.description: str = data['description']
|
||||
self.required: bool = data.get('required', False)
|
||||
self.choices: List[Choice] = [Choice(name=d['name'], value=d['value']) for d in data.get('choices', [])]
|
||||
self.choices: List[Choice[Union[int, float, str]]] = [
|
||||
Choice(name=d['name'], value=d['value']) for d in data.get('choices', [])
|
||||
]
|
||||
|
||||
def to_dict(self) -> ApplicationCommandOption:
|
||||
return {
|
||||
@ -582,20 +591,24 @@ class AppCommandGroup:
|
||||
'_state',
|
||||
)
|
||||
|
||||
def __init__(self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state=None):
|
||||
self.parent = parent
|
||||
self._state = state
|
||||
def __init__(
|
||||
self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state: Optional[ConnectionState] = None
|
||||
) -> None:
|
||||
self.parent: ApplicationCommandParent = parent
|
||||
self._state: Optional[ConnectionState] = state
|
||||
self._from_data(data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__} name={self.name!r} type={self.type!r} required={self.required}>'
|
||||
|
||||
def _from_data(self, data: ApplicationCommandOption):
|
||||
def _from_data(self, data: ApplicationCommandOption) -> None:
|
||||
self.type: AppCommandOptionType = try_enum(AppCommandOptionType, data['type'])
|
||||
self.name: str = data['name']
|
||||
self.description: str = data['description']
|
||||
self.required: bool = data.get('required', False)
|
||||
self.choices: List[Choice] = [Choice(name=d['name'], value=d['value']) for d in data.get('choices', [])]
|
||||
self.choices: List[Choice[Union[int, float, str]]] = [
|
||||
Choice(name=d['name'], value=d['value']) for d in data.get('choices', [])
|
||||
]
|
||||
self.arguments: List[Argument] = [
|
||||
Argument(parent=self, state=self._state, data=d)
|
||||
for d in data.get('options', [])
|
||||
@ -614,7 +627,7 @@ class AppCommandGroup:
|
||||
|
||||
|
||||
def app_command_option_factory(
|
||||
parent: ApplicationCommandParent, data: ApplicationCommandOption, *, state=None
|
||||
parent: ApplicationCommandParent, data: ApplicationCommandOption, *, state: Optional[ConnectionState] = None
|
||||
) -> Union[Argument, AppCommandGroup]:
|
||||
if is_app_command_argument_type(data['type']):
|
||||
return Argument(parent=parent, data=data, state=state)
|
||||
|
@ -95,7 +95,7 @@ class CommandParameter:
|
||||
description: str = MISSING
|
||||
required: bool = MISSING
|
||||
default: Any = MISSING
|
||||
choices: List[Choice] = MISSING
|
||||
choices: List[Choice[Union[str, int, float]]] = MISSING
|
||||
type: AppCommandOptionType = MISSING
|
||||
channel_types: List[ChannelType] = MISSING
|
||||
min_value: Optional[Union[int, float]] = None
|
||||
@ -549,7 +549,7 @@ ALLOWED_DEFAULTS: Dict[AppCommandOptionType, Tuple[Type[Any], ...]] = {
|
||||
def get_supported_annotation(
|
||||
annotation: Any,
|
||||
*,
|
||||
_none=NoneType,
|
||||
_none: type = NoneType,
|
||||
_mapping: Dict[Any, Type[Transformer]] = BUILT_IN_TRANSFORMERS,
|
||||
) -> Tuple[Any, Any]:
|
||||
"""Returns an appropriate, yet supported, annotation along with an optional default value.
|
||||
|
@ -26,7 +26,22 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, TYPE_CHECKING, Set, Tuple, TypeVar, Union, overload
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
TYPE_CHECKING,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
from collections import Counter
|
||||
|
||||
|
||||
@ -194,13 +209,13 @@ class CommandTree(Generic[ClientT]):
|
||||
|
||||
def add_command(
|
||||
self,
|
||||
command: Union[Command, ContextMenu, Group],
|
||||
command: Union[Command[Any, ..., Any], ContextMenu, Group],
|
||||
/,
|
||||
*,
|
||||
guild: Optional[Snowflake] = MISSING,
|
||||
guilds: List[Snowflake] = MISSING,
|
||||
override: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""Adds an application command to the tree.
|
||||
|
||||
This only adds the command locally -- in order to sync the commands
|
||||
@ -317,7 +332,7 @@ class CommandTree(Generic[ClientT]):
|
||||
*,
|
||||
guild: Optional[Snowflake] = ...,
|
||||
type: Literal[AppCommandType.chat_input] = ...,
|
||||
) -> Optional[Union[Command, Group]]:
|
||||
) -> Optional[Union[Command[Any, ..., Any], Group]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
@ -328,7 +343,7 @@ class CommandTree(Generic[ClientT]):
|
||||
*,
|
||||
guild: Optional[Snowflake] = ...,
|
||||
type: AppCommandType = ...,
|
||||
) -> Optional[Union[Command, ContextMenu, Group]]:
|
||||
) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]:
|
||||
...
|
||||
|
||||
def remove_command(
|
||||
@ -338,7 +353,7 @@ class CommandTree(Generic[ClientT]):
|
||||
*,
|
||||
guild: Optional[Snowflake] = None,
|
||||
type: AppCommandType = AppCommandType.chat_input,
|
||||
) -> Optional[Union[Command, ContextMenu, Group]]:
|
||||
) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]:
|
||||
"""Removes an application command from the tree.
|
||||
|
||||
This only removes the command locally -- in order to sync the commands
|
||||
@ -396,7 +411,7 @@ class CommandTree(Generic[ClientT]):
|
||||
*,
|
||||
guild: Optional[Snowflake] = ...,
|
||||
type: Literal[AppCommandType.chat_input] = ...,
|
||||
) -> Optional[Union[Command, Group]]:
|
||||
) -> Optional[Union[Command[Any, ..., Any], Group]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
@ -407,7 +422,7 @@ class CommandTree(Generic[ClientT]):
|
||||
*,
|
||||
guild: Optional[Snowflake] = ...,
|
||||
type: AppCommandType = ...,
|
||||
) -> Optional[Union[Command, ContextMenu, Group]]:
|
||||
) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]:
|
||||
...
|
||||
|
||||
def get_command(
|
||||
@ -417,7 +432,7 @@ class CommandTree(Generic[ClientT]):
|
||||
*,
|
||||
guild: Optional[Snowflake] = None,
|
||||
type: AppCommandType = AppCommandType.chat_input,
|
||||
) -> Optional[Union[Command, ContextMenu, Group]]:
|
||||
) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]:
|
||||
"""Gets a application command from the tree.
|
||||
|
||||
Parameters
|
||||
@ -468,7 +483,7 @@ class CommandTree(Generic[ClientT]):
|
||||
*,
|
||||
guild: Optional[Snowflake] = ...,
|
||||
type: Literal[AppCommandType.chat_input] = ...,
|
||||
) -> List[Union[Command, Group]]:
|
||||
) -> List[Union[Command[Any, ..., Any], Group]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
@ -477,7 +492,7 @@ class CommandTree(Generic[ClientT]):
|
||||
*,
|
||||
guild: Optional[Snowflake] = ...,
|
||||
type: AppCommandType = ...,
|
||||
) -> Union[List[Union[Command, Group]], List[ContextMenu]]:
|
||||
) -> Union[List[Union[Command[Any, ..., Any], Group]], List[ContextMenu]]:
|
||||
...
|
||||
|
||||
def get_commands(
|
||||
@ -485,7 +500,7 @@ class CommandTree(Generic[ClientT]):
|
||||
*,
|
||||
guild: Optional[Snowflake] = None,
|
||||
type: AppCommandType = AppCommandType.chat_input,
|
||||
) -> Union[List[Union[Command, Group]], List[ContextMenu]]:
|
||||
) -> Union[List[Union[Command[Any, ..., Any], Group]], List[ContextMenu]]:
|
||||
"""Gets all application commands from the tree.
|
||||
|
||||
Parameters
|
||||
@ -518,9 +533,11 @@ class CommandTree(Generic[ClientT]):
|
||||
value = type.value
|
||||
return [command for ((_, g, t), command) in self._context_menus.items() if g == guild_id and t == value]
|
||||
|
||||
def _get_all_commands(self, *, guild: Optional[Snowflake] = None) -> List[Union[Command, Group, ContextMenu]]:
|
||||
def _get_all_commands(
|
||||
self, *, guild: Optional[Snowflake] = None
|
||||
) -> List[Union[Command[Any, ..., Any], Group, ContextMenu]]:
|
||||
if guild is None:
|
||||
base: List[Union[Command, Group, ContextMenu]] = list(self._global_commands.values())
|
||||
base: List[Union[Command[Any, ..., Any], Group, ContextMenu]] = list(self._global_commands.values())
|
||||
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g is None)
|
||||
return base
|
||||
else:
|
||||
@ -530,7 +547,7 @@ class CommandTree(Generic[ClientT]):
|
||||
guild_id = guild.id
|
||||
return [cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id]
|
||||
else:
|
||||
base: List[Union[Command, Group, ContextMenu]] = list(commands.values())
|
||||
base: List[Union[Command[Any, ..., Any], Group, ContextMenu]] = list(commands.values())
|
||||
guild_id = guild.id
|
||||
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id)
|
||||
return base
|
||||
@ -564,7 +581,7 @@ class CommandTree(Generic[ClientT]):
|
||||
async def on_error(
|
||||
self,
|
||||
interaction: Interaction,
|
||||
command: Optional[Union[ContextMenu, Command]],
|
||||
command: Optional[Union[ContextMenu, Command[Any, ..., Any]]],
|
||||
error: AppCommandError,
|
||||
) -> None:
|
||||
"""|coro|
|
||||
@ -742,7 +759,7 @@ class CommandTree(Generic[ClientT]):
|
||||
|
||||
self.client.loop.create_task(wrapper(), name='CommandTree-invoker')
|
||||
|
||||
async def _call_context_menu(self, interaction: Interaction, data: ApplicationCommandInteractionData, type: int):
|
||||
async def _call_context_menu(self, interaction: Interaction, data: ApplicationCommandInteractionData, type: int) -> None:
|
||||
name = data['name']
|
||||
guild_id = _get_as_snowflake(data, 'guild_id')
|
||||
ctx_menu = self._context_menus.get((name, guild_id, type))
|
||||
@ -770,7 +787,7 @@ class CommandTree(Generic[ClientT]):
|
||||
except AppCommandError as e:
|
||||
await self.on_error(interaction, ctx_menu, e)
|
||||
|
||||
async def call(self, interaction: Interaction):
|
||||
async def call(self, interaction: Interaction) -> None:
|
||||
"""|coro|
|
||||
|
||||
Given an :class:`~discord.Interaction`, calls the matching
|
||||
|
@ -39,6 +39,13 @@ __all__ = (
|
||||
# fmt: on
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
|
||||
from .state import ConnectionState
|
||||
from .webhook.async_ import _WebhookState
|
||||
|
||||
_State = Union[ConnectionState, _WebhookState]
|
||||
|
||||
ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png']
|
||||
ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif']
|
||||
|
||||
@ -77,7 +84,7 @@ class AssetMixin:
|
||||
|
||||
return await self._state.http.get_from_cdn(self.url)
|
||||
|
||||
async def save(self, fp: Union[str, bytes, os.PathLike, io.BufferedIOBase], *, seek_begin: bool = True) -> int:
|
||||
async def save(self, fp: Union[str, bytes, os.PathLike[Any], io.BufferedIOBase], *, seek_begin: bool = True) -> int:
|
||||
"""|coro|
|
||||
|
||||
Saves this asset into a file-like object.
|
||||
@ -153,14 +160,14 @@ class Asset(AssetMixin):
|
||||
|
||||
BASE = 'https://cdn.discordapp.com'
|
||||
|
||||
def __init__(self, state, *, url: str, key: str, animated: bool = False):
|
||||
self._state = state
|
||||
self._url = url
|
||||
self._animated = animated
|
||||
self._key = key
|
||||
def __init__(self, state: _State, *, url: str, key: str, animated: bool = False) -> None:
|
||||
self._state: _State = state
|
||||
self._url: str = url
|
||||
self._animated: bool = animated
|
||||
self._key: str = key
|
||||
|
||||
@classmethod
|
||||
def _from_default_avatar(cls, state, index: int) -> Asset:
|
||||
def _from_default_avatar(cls, state: _State, index: int) -> Self:
|
||||
return cls(
|
||||
state,
|
||||
url=f'{cls.BASE}/embed/avatars/{index}.png',
|
||||
@ -169,7 +176,7 @@ class Asset(AssetMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset:
|
||||
def _from_avatar(cls, state: _State, user_id: int, avatar: str) -> Self:
|
||||
animated = avatar.startswith('a_')
|
||||
format = 'gif' if animated else 'png'
|
||||
return cls(
|
||||
@ -180,7 +187,7 @@ class Asset(AssetMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset:
|
||||
def _from_guild_avatar(cls, state: _State, guild_id: int, member_id: int, avatar: str) -> Self:
|
||||
animated = avatar.startswith('a_')
|
||||
format = 'gif' if animated else 'png'
|
||||
return cls(
|
||||
@ -191,7 +198,7 @@ class Asset(AssetMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
|
||||
def _from_icon(cls, state: _State, object_id: int, icon_hash: str, path: str) -> Self:
|
||||
return cls(
|
||||
state,
|
||||
url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024',
|
||||
@ -200,7 +207,7 @@ class Asset(AssetMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset:
|
||||
def _from_cover_image(cls, state: _State, object_id: int, cover_image_hash: str) -> Self:
|
||||
return cls(
|
||||
state,
|
||||
url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024',
|
||||
@ -209,7 +216,7 @@ class Asset(AssetMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_scheduled_event_cover_image(cls, state, scheduled_event_id: int, cover_image_hash: str) -> Asset:
|
||||
def _from_scheduled_event_cover_image(cls, state: _State, scheduled_event_id: int, cover_image_hash: str) -> Self:
|
||||
return cls(
|
||||
state,
|
||||
url=f'{cls.BASE}/guild-events/{scheduled_event_id}/{cover_image_hash}.png?size=1024',
|
||||
@ -218,7 +225,7 @@ class Asset(AssetMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset:
|
||||
def _from_guild_image(cls, state: _State, guild_id: int, image: str, path: str) -> Self:
|
||||
animated = image.startswith('a_')
|
||||
format = 'gif' if animated else 'png'
|
||||
return cls(
|
||||
@ -229,7 +236,7 @@ class Asset(AssetMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset:
|
||||
def _from_guild_icon(cls, state: _State, guild_id: int, icon_hash: str) -> Self:
|
||||
animated = icon_hash.startswith('a_')
|
||||
format = 'gif' if animated else 'png'
|
||||
return cls(
|
||||
@ -240,7 +247,7 @@ class Asset(AssetMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_sticker_banner(cls, state, banner: int) -> Asset:
|
||||
def _from_sticker_banner(cls, state: _State, banner: int) -> Self:
|
||||
return cls(
|
||||
state,
|
||||
url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png',
|
||||
@ -249,7 +256,7 @@ class Asset(AssetMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:
|
||||
def _from_user_banner(cls, state: _State, user_id: int, banner_hash: str) -> Self:
|
||||
animated = banner_hash.startswith('a_')
|
||||
format = 'gif' if animated else 'png'
|
||||
return cls(
|
||||
@ -265,14 +272,14 @@ class Asset(AssetMixin):
|
||||
def __len__(self) -> int:
|
||||
return len(self._url)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
shorten = self._url.replace(self.BASE, '')
|
||||
return f'<Asset url={shorten!r}>'
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, Asset) and self._url == other._url
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash(self._url)
|
||||
|
||||
@property
|
||||
@ -295,7 +302,7 @@ class Asset(AssetMixin):
|
||||
size: int = MISSING,
|
||||
format: ValidAssetFormatTypes = MISSING,
|
||||
static_format: ValidStaticFormatTypes = MISSING,
|
||||
) -> Asset:
|
||||
) -> Self:
|
||||
"""Returns a new asset with the passed components replaced.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
@ -350,7 +357,7 @@ class Asset(AssetMixin):
|
||||
url = str(url)
|
||||
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
|
||||
|
||||
def with_size(self, size: int, /) -> Asset:
|
||||
def with_size(self, size: int, /) -> Self:
|
||||
"""Returns a new asset with the specified size.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
@ -378,7 +385,7 @@ class Asset(AssetMixin):
|
||||
url = str(yarl.URL(self._url).with_query(size=size))
|
||||
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
|
||||
|
||||
def with_format(self, format: ValidAssetFormatTypes, /) -> Asset:
|
||||
def with_format(self, format: ValidAssetFormatTypes, /) -> Self:
|
||||
"""Returns a new asset with the specified format.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
@ -413,7 +420,7 @@ class Asset(AssetMixin):
|
||||
url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string))
|
||||
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
|
||||
|
||||
def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset:
|
||||
def with_static_format(self, format: ValidStaticFormatTypes, /) -> Self:
|
||||
"""Returns a new asset with the specified static format.
|
||||
|
||||
This only changes the format if the underlying asset is
|
||||
|
@ -50,12 +50,12 @@ if TYPE_CHECKING:
|
||||
from .member import Member
|
||||
from .role import Role
|
||||
from .scheduled_event import ScheduledEvent
|
||||
from .state import ConnectionState
|
||||
from .types.audit_log import (
|
||||
AuditLogChange as AuditLogChangePayload,
|
||||
AuditLogEntry as AuditLogEntryPayload,
|
||||
)
|
||||
from .types.channel import (
|
||||
PartialChannel as PartialChannelPayload,
|
||||
PermissionOverwrite as PermissionOverwritePayload,
|
||||
)
|
||||
from .types.invite import Invite as InvitePayload
|
||||
@ -242,8 +242,8 @@ class AuditLogChanges:
|
||||
# fmt: on
|
||||
|
||||
def __init__(self, entry: AuditLogEntry, data: List[AuditLogChangePayload]):
|
||||
self.before = AuditLogDiff()
|
||||
self.after = AuditLogDiff()
|
||||
self.before: AuditLogDiff = AuditLogDiff()
|
||||
self.after: AuditLogDiff = AuditLogDiff()
|
||||
|
||||
for elem in data:
|
||||
attr = elem['key']
|
||||
@ -390,17 +390,17 @@ class AuditLogEntry(Hashable):
|
||||
"""
|
||||
|
||||
def __init__(self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild):
|
||||
self._state = guild._state
|
||||
self.guild = guild
|
||||
self._users = users
|
||||
self._state: ConnectionState = guild._state
|
||||
self.guild: Guild = guild
|
||||
self._users: Dict[int, User] = users
|
||||
self._from_data(data)
|
||||
|
||||
def _from_data(self, data: AuditLogEntryPayload) -> None:
|
||||
self.action = enums.try_enum(enums.AuditLogAction, data['action_type'])
|
||||
self.id = int(data['id'])
|
||||
self.action: enums.AuditLogAction = enums.try_enum(enums.AuditLogAction, data['action_type'])
|
||||
self.id: int = int(data['id'])
|
||||
|
||||
# this key is technically not usually present
|
||||
self.reason = data.get('reason')
|
||||
self.reason: Optional[str] = data.get('reason')
|
||||
extra = data.get('options')
|
||||
|
||||
# fmt: off
|
||||
@ -464,10 +464,13 @@ class AuditLogEntry(Hashable):
|
||||
self._changes = data.get('changes', [])
|
||||
|
||||
user_id = utils._get_as_snowflake(data, 'user_id')
|
||||
self.user = user_id and self._get_member(user_id)
|
||||
self.user: Optional[Union[User, Member]] = self._get_member(user_id)
|
||||
self._target_id = utils._get_as_snowflake(data, 'target_id')
|
||||
|
||||
def _get_member(self, user_id: int) -> Union[Member, User, None]:
|
||||
def _get_member(self, user_id: Optional[int]) -> Union[Member, User, None]:
|
||||
if user_id is None:
|
||||
return None
|
||||
|
||||
return self.guild.get_member(user_id) or self._users.get(user_id)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
@ -198,7 +198,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id')
|
||||
self._fill_overwrites(data)
|
||||
|
||||
async def _get_channel(self):
|
||||
async def _get_channel(self) -> Self:
|
||||
return self
|
||||
|
||||
@property
|
||||
@ -283,7 +283,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
async def edit(self) -> Optional[TextChannel]:
|
||||
...
|
||||
|
||||
async def edit(self, *, reason=None, **options):
|
||||
async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[TextChannel]:
|
||||
"""|coro|
|
||||
|
||||
Edits the channel.
|
||||
@ -908,7 +908,7 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha
|
||||
return self.guild.id, self.id
|
||||
|
||||
def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None:
|
||||
self.guild = guild
|
||||
self.guild: Guild = guild
|
||||
self.name: str = data['name']
|
||||
self.rtc_region: Optional[str] = data.get('rtc_region')
|
||||
self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1))
|
||||
@ -1076,7 +1076,7 @@ class VoiceChannel(VocalGuildChannel):
|
||||
async def edit(self) -> Optional[VoiceChannel]:
|
||||
...
|
||||
|
||||
async def edit(self, *, reason=None, **options):
|
||||
async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[VoiceChannel]:
|
||||
"""|coro|
|
||||
|
||||
Edits the channel.
|
||||
@ -1220,7 +1220,7 @@ class StageChannel(VocalGuildChannel):
|
||||
|
||||
def _update(self, guild: Guild, data: StageChannelPayload) -> None:
|
||||
super()._update(guild, data)
|
||||
self.topic = data.get('topic')
|
||||
self.topic: Optional[str] = data.get('topic')
|
||||
|
||||
@property
|
||||
def requesting_to_speak(self) -> List[Member]:
|
||||
@ -1361,7 +1361,7 @@ class StageChannel(VocalGuildChannel):
|
||||
async def edit(self) -> Optional[StageChannel]:
|
||||
...
|
||||
|
||||
async def edit(self, *, reason=None, **options):
|
||||
async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[StageChannel]:
|
||||
"""|coro|
|
||||
|
||||
Edits the channel.
|
||||
@ -1522,7 +1522,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
async def edit(self) -> Optional[CategoryChannel]:
|
||||
...
|
||||
|
||||
async def edit(self, *, reason=None, **options):
|
||||
async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[CategoryChannel]:
|
||||
"""|coro|
|
||||
|
||||
Edits the channel.
|
||||
@ -1578,7 +1578,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
|
||||
|
||||
@utils.copy_doc(discord.abc.GuildChannel.move)
|
||||
async def move(self, **kwargs):
|
||||
async def move(self, **kwargs: Any) -> None:
|
||||
kwargs.pop('category', None)
|
||||
await super().move(**kwargs)
|
||||
|
||||
@ -1772,7 +1772,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
|
||||
async def edit(self) -> Optional[StoreChannel]:
|
||||
...
|
||||
|
||||
async def edit(self, *, reason=None, **options):
|
||||
async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[StoreChannel]:
|
||||
"""|coro|
|
||||
|
||||
Edits the channel.
|
||||
@ -1874,7 +1874,7 @@ class DMChannel(discord.abc.Messageable, Hashable):
|
||||
self.me: ClientUser = me
|
||||
self.id: int = int(data['id'])
|
||||
|
||||
async def _get_channel(self):
|
||||
async def _get_channel(self) -> Self:
|
||||
return self
|
||||
|
||||
def __str__(self) -> str:
|
||||
@ -2026,7 +2026,7 @@ class GroupChannel(discord.abc.Messageable, Hashable):
|
||||
else:
|
||||
self.owner = utils.find(lambda u: u.id == self.owner_id, self.recipients)
|
||||
|
||||
async def _get_channel(self):
|
||||
async def _get_channel(self) -> Self:
|
||||
return self
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
@ -196,11 +196,11 @@ class Client:
|
||||
unsync_clock: bool = options.pop('assume_unsync_clock', True)
|
||||
self.http: HTTPClient = HTTPClient(self.loop, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock)
|
||||
|
||||
self._handlers: Dict[str, Callable] = {
|
||||
self._handlers: Dict[str, Callable[..., None]] = {
|
||||
'ready': self._handle_ready,
|
||||
}
|
||||
|
||||
self._hooks: Dict[str, Callable] = {
|
||||
self._hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = {
|
||||
'before_identify': self._call_before_identify_hook,
|
||||
}
|
||||
|
||||
@ -698,7 +698,7 @@ class Client:
|
||||
raise TypeError('activity must derive from BaseActivity.')
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
def status(self) -> Status:
|
||||
""":class:`.Status`:
|
||||
The status being used upon logging on to Discord.
|
||||
|
||||
@ -709,7 +709,7 @@ class Client:
|
||||
return Status.online
|
||||
|
||||
@status.setter
|
||||
def status(self, value):
|
||||
def status(self, value: Status) -> None:
|
||||
if value is Status.offline:
|
||||
self._connection._status = 'invisible'
|
||||
elif isinstance(value, Status):
|
||||
@ -1077,7 +1077,7 @@ class Client:
|
||||
*,
|
||||
activity: Optional[BaseActivity] = None,
|
||||
status: Optional[Status] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""|coro|
|
||||
|
||||
Changes the client's presence.
|
||||
|
@ -32,7 +32,6 @@ from typing import (
|
||||
Callable,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -90,10 +89,10 @@ class Colour:
|
||||
def _get_byte(self, byte: int) -> int:
|
||||
return (self.value >> (8 * byte)) & 0xFF
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, Colour) and self.value == other.value
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __str__(self) -> str:
|
||||
@ -265,28 +264,28 @@ class Colour:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``."""
|
||||
return cls(0x95A5A6)
|
||||
|
||||
lighter_gray: Callable[[Type[Self]], Self] = lighter_grey
|
||||
lighter_gray = lighter_grey
|
||||
|
||||
@classmethod
|
||||
def dark_grey(cls) -> Self:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x607d8b``."""
|
||||
return cls(0x607D8B)
|
||||
|
||||
dark_gray: Callable[[Type[Self]], Self] = dark_grey
|
||||
dark_gray = dark_grey
|
||||
|
||||
@classmethod
|
||||
def light_grey(cls) -> Self:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x979c9f``."""
|
||||
return cls(0x979C9F)
|
||||
|
||||
light_gray: Callable[[Type[Self]], Self] = light_grey
|
||||
light_gray = light_grey
|
||||
|
||||
@classmethod
|
||||
def darker_grey(cls) -> Self:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x546e7a``."""
|
||||
return cls(0x546E7A)
|
||||
|
||||
darker_gray: Callable[[Type[Self]], Self] = darker_grey
|
||||
darker_gray = darker_grey
|
||||
|
||||
@classmethod
|
||||
def og_blurple(cls) -> Self:
|
||||
|
@ -310,9 +310,9 @@ class SelectOption:
|
||||
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
|
||||
default: bool = False,
|
||||
) -> None:
|
||||
self.label = label
|
||||
self.value = label if value is MISSING else value
|
||||
self.description = description
|
||||
self.label: str = label
|
||||
self.value: str = label if value is MISSING else value
|
||||
self.description: Optional[str] = description
|
||||
|
||||
if emoji is not None:
|
||||
if isinstance(emoji, str):
|
||||
@ -322,8 +322,8 @@ class SelectOption:
|
||||
else:
|
||||
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
|
||||
|
||||
self.emoji = emoji
|
||||
self.default = default
|
||||
self.emoji: Optional[Union[str, Emoji, PartialEmoji]] = emoji
|
||||
self.default: bool = default
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
|
@ -25,13 +25,15 @@ DEALINGS IN THE SOFTWARE.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Optional, Type
|
||||
from typing import TYPE_CHECKING, Optional, Type, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .abc import Messageable
|
||||
|
||||
from types import TracebackType
|
||||
|
||||
BE = TypeVar('BE', bound=BaseException)
|
||||
|
||||
# fmt: off
|
||||
__all__ = (
|
||||
'Typing',
|
||||
@ -67,13 +69,13 @@ class Typing:
|
||||
async def __aenter__(self) -> None:
|
||||
self._channel = channel = await self.messageable._get_channel()
|
||||
await channel._state.http.send_typing(channel.id)
|
||||
self.task: asyncio.Task = self.loop.create_task(self.do_typing())
|
||||
self.task: asyncio.Task[None] = self.loop.create_task(self.do_typing())
|
||||
self.task.add_done_callback(_typing_done_callback)
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
exc_type: Optional[Type[BE]],
|
||||
exc: Optional[BE],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.task.cancel()
|
||||
|
@ -189,10 +189,10 @@ class Embed:
|
||||
):
|
||||
|
||||
self.colour = colour if colour is not EmptyEmbed else color
|
||||
self.title = title
|
||||
self.type = type
|
||||
self.url = url
|
||||
self.description = description
|
||||
self.title: MaybeEmpty[str] = title
|
||||
self.type: EmbedType = type
|
||||
self.url: MaybeEmpty[str] = url
|
||||
self.description: MaybeEmpty[str] = description
|
||||
|
||||
if self.title is not EmptyEmbed:
|
||||
self.title = str(self.title)
|
||||
@ -311,7 +311,7 @@ class Embed:
|
||||
return getattr(self, '_colour', EmptyEmbed)
|
||||
|
||||
@colour.setter
|
||||
def colour(self, value: Union[int, Colour, _EmptyEmbed]):
|
||||
def colour(self, value: Union[int, Colour, _EmptyEmbed]) -> None:
|
||||
if isinstance(value, (Colour, _EmptyEmbed)):
|
||||
self._colour = value
|
||||
elif isinstance(value, int):
|
||||
@ -326,7 +326,7 @@ class Embed:
|
||||
return getattr(self, '_timestamp', EmptyEmbed)
|
||||
|
||||
@timestamp.setter
|
||||
def timestamp(self, value: MaybeEmpty[datetime.datetime]):
|
||||
def timestamp(self, value: MaybeEmpty[datetime.datetime]) -> None:
|
||||
if isinstance(value, datetime.datetime):
|
||||
if value.tzinfo is None:
|
||||
value = value.astimezone()
|
||||
|
@ -142,10 +142,10 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
def __repr__(self) -> str:
|
||||
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>'
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, _EmojiTag) and self.id == other.id
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
|
@ -25,7 +25,7 @@ from __future__ import annotations
|
||||
|
||||
import types
|
||||
from collections import namedtuple
|
||||
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
|
||||
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Iterator, Mapping
|
||||
|
||||
__all__ = (
|
||||
'Enum',
|
||||
@ -131,38 +131,38 @@ class EnumMeta(type):
|
||||
value_cls._actual_enum_cls_ = actual_cls # type: ignore - Runtime attribute isn't understood
|
||||
return actual_cls
|
||||
|
||||
def __iter__(cls):
|
||||
def __iter__(cls) -> Iterator[Any]:
|
||||
return (cls._enum_member_map_[name] for name in cls._enum_member_names_)
|
||||
|
||||
def __reversed__(cls):
|
||||
def __reversed__(cls) -> Iterator[Any]:
|
||||
return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_))
|
||||
|
||||
def __len__(cls):
|
||||
def __len__(cls) -> int:
|
||||
return len(cls._enum_member_names_)
|
||||
|
||||
def __repr__(cls):
|
||||
def __repr__(cls) -> str:
|
||||
return f'<enum {cls.__name__}>'
|
||||
|
||||
@property
|
||||
def __members__(cls):
|
||||
def __members__(cls) -> Mapping[str, Any]:
|
||||
return types.MappingProxyType(cls._enum_member_map_)
|
||||
|
||||
def __call__(cls, value):
|
||||
def __call__(cls, value: str) -> Any:
|
||||
try:
|
||||
return cls._enum_value_map_[value]
|
||||
except (KeyError, TypeError):
|
||||
raise ValueError(f"{value!r} is not a valid {cls.__name__}")
|
||||
|
||||
def __getitem__(cls, key):
|
||||
def __getitem__(cls, key: str) -> Any:
|
||||
return cls._enum_member_map_[key]
|
||||
|
||||
def __setattr__(cls, name, value):
|
||||
def __setattr__(cls, name: str, value: Any) -> None:
|
||||
raise TypeError('Enums are immutable.')
|
||||
|
||||
def __delattr__(cls, attr):
|
||||
def __delattr__(cls, attr: str) -> None:
|
||||
raise TypeError('Enums are immutable')
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
def __instancecheck__(self, instance: Any) -> bool:
|
||||
# isinstance(x, Y)
|
||||
# -> __instancecheck__(Y, x)
|
||||
try:
|
||||
@ -197,7 +197,7 @@ class ChannelType(Enum):
|
||||
private_thread = 12
|
||||
stage_voice = 13
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
@ -233,10 +233,10 @@ class SpeakingState(Enum):
|
||||
soundshare = 2
|
||||
priority = 4
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def __int__(self):
|
||||
def __int__(self) -> int:
|
||||
return self.value
|
||||
|
||||
|
||||
@ -247,7 +247,7 @@ class VerificationLevel(Enum, comparable=True):
|
||||
high = 3
|
||||
highest = 4
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
@ -256,7 +256,7 @@ class ContentFilter(Enum, comparable=True):
|
||||
no_role = 1
|
||||
all_members = 2
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
@ -268,7 +268,7 @@ class Status(Enum):
|
||||
do_not_disturb = 'dnd'
|
||||
invisible = 'invisible'
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@ -280,7 +280,7 @@ class DefaultAvatar(Enum):
|
||||
orange = 3
|
||||
red = 4
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
@ -467,7 +467,7 @@ class ActivityType(Enum):
|
||||
custom = 4
|
||||
competing = 5
|
||||
|
||||
def __int__(self):
|
||||
def __int__(self) -> int:
|
||||
return self.value
|
||||
|
||||
|
||||
@ -542,7 +542,7 @@ class VideoQualityMode(Enum):
|
||||
auto = 1
|
||||
full = 2
|
||||
|
||||
def __int__(self):
|
||||
def __int__(self) -> int:
|
||||
return self.value
|
||||
|
||||
|
||||
@ -552,7 +552,7 @@ class ComponentType(Enum):
|
||||
select = 3
|
||||
text_input = 4
|
||||
|
||||
def __int__(self):
|
||||
def __int__(self) -> int:
|
||||
return self.value
|
||||
|
||||
|
||||
@ -571,7 +571,7 @@ class ButtonStyle(Enum):
|
||||
red = 4
|
||||
url = 5
|
||||
|
||||
def __int__(self):
|
||||
def __int__(self) -> int:
|
||||
return self.value
|
||||
|
||||
|
||||
|
@ -23,21 +23,35 @@ DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
|
||||
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union
|
||||
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from .bot import Bot, AutoShardedBot
|
||||
from .context import Context
|
||||
from .cog import Cog
|
||||
from .errors import CommandError
|
||||
|
||||
T = TypeVar('T')
|
||||
P = ParamSpec('P')
|
||||
MaybeCoroFunc = Union[
|
||||
Callable[P, 'Coro[T]'],
|
||||
Callable[P, T],
|
||||
]
|
||||
else:
|
||||
P = TypeVar('P')
|
||||
MaybeCoroFunc = Tuple[P, T]
|
||||
|
||||
Coro = Coroutine[Any, Any, T]
|
||||
MaybeCoro = Union[T, Coro[T]]
|
||||
CoroFunc = Callable[..., Coro[Any]]
|
||||
|
||||
ContextT = TypeVar('ContextT', bound='Context')
|
||||
_Bot = Union['Bot', 'AutoShardedBot']
|
||||
BotT = TypeVar('BotT', bound=_Bot)
|
||||
|
||||
Check = Union[Callable[["Cog", "ContextT"], MaybeCoro[bool]], Callable[["ContextT"], MaybeCoro[bool]]]
|
||||
Hook = Union[Callable[["Cog", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]]
|
||||
|
@ -33,7 +33,21 @@ import importlib.util
|
||||
import sys
|
||||
import traceback
|
||||
import types
|
||||
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union, overload
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Mapping,
|
||||
List,
|
||||
Dict,
|
||||
TYPE_CHECKING,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Type,
|
||||
Union,
|
||||
Iterable,
|
||||
Collection,
|
||||
overload,
|
||||
)
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
@ -55,10 +69,18 @@ if TYPE_CHECKING:
|
||||
from discord.message import Message
|
||||
from discord.abc import User, Snowflake
|
||||
from ._types import (
|
||||
_Bot,
|
||||
BotT,
|
||||
Check,
|
||||
CoroFunc,
|
||||
ContextT,
|
||||
MaybeCoroFunc,
|
||||
)
|
||||
|
||||
_Prefix = Union[Iterable[str], str]
|
||||
_PrefixCallable = MaybeCoroFunc[[BotT, Message], _Prefix]
|
||||
PrefixType = Union[_Prefix, _PrefixCallable[BotT]]
|
||||
|
||||
__all__ = (
|
||||
'when_mentioned',
|
||||
'when_mentioned_or',
|
||||
@ -68,11 +90,9 @@ __all__ = (
|
||||
|
||||
T = TypeVar('T')
|
||||
CFT = TypeVar('CFT', bound='CoroFunc')
|
||||
CXT = TypeVar('CXT', bound='Context')
|
||||
BT = TypeVar('BT', bound='Union[Bot, AutoShardedBot]')
|
||||
|
||||
|
||||
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
|
||||
def when_mentioned(bot: _Bot, msg: Message) -> List[str]:
|
||||
"""A callable that implements a command prefix equivalent to being mentioned.
|
||||
|
||||
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
|
||||
@ -81,7 +101,7 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
|
||||
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
|
||||
|
||||
|
||||
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
|
||||
def when_mentioned_or(*prefixes: str) -> Callable[[_Bot, Message], List[str]]:
|
||||
"""A callable that implements when mentioned or other prefixes provided.
|
||||
|
||||
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
|
||||
@ -124,27 +144,33 @@ class _DefaultRepr:
|
||||
return '<default-help-command>'
|
||||
|
||||
|
||||
_default = _DefaultRepr()
|
||||
_default: Any = _DefaultRepr()
|
||||
|
||||
|
||||
class BotBase(GroupMixin):
|
||||
def __init__(self, command_prefix, help_command=_default, description=None, **options):
|
||||
class BotBase(GroupMixin[None]):
|
||||
def __init__(
|
||||
self,
|
||||
command_prefix: PrefixType[BotT],
|
||||
help_command: HelpCommand = _default,
|
||||
description: Optional[str] = None,
|
||||
**options: Any,
|
||||
) -> None:
|
||||
super().__init__(**options)
|
||||
self.command_prefix = command_prefix
|
||||
self.command_prefix: PrefixType[BotT] = command_prefix
|
||||
self.extra_events: Dict[str, List[CoroFunc]] = {}
|
||||
# Self doesn't have the ClientT bound, but since this is a mixin it technically does
|
||||
self.__tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) # type: ignore
|
||||
self.__cogs: Dict[str, Cog] = {}
|
||||
self.__extensions: Dict[str, types.ModuleType] = {}
|
||||
self._checks: List[Check] = []
|
||||
self._check_once = []
|
||||
self._before_invoke = None
|
||||
self._after_invoke = None
|
||||
self._help_command = None
|
||||
self.description = inspect.cleandoc(description) if description else ''
|
||||
self.owner_id = options.get('owner_id')
|
||||
self.owner_ids = options.get('owner_ids', set())
|
||||
self.strip_after_prefix = options.get('strip_after_prefix', False)
|
||||
self._check_once: List[Check] = []
|
||||
self._before_invoke: Optional[CoroFunc] = None
|
||||
self._after_invoke: Optional[CoroFunc] = None
|
||||
self._help_command: Optional[HelpCommand] = None
|
||||
self.description: str = inspect.cleandoc(description) if description else ''
|
||||
self.owner_id: Optional[int] = options.get('owner_id')
|
||||
self.owner_ids: Optional[Collection[int]] = options.get('owner_ids', set())
|
||||
self.strip_after_prefix: bool = options.get('strip_after_prefix', False)
|
||||
|
||||
if self.owner_id and self.owner_ids:
|
||||
raise TypeError('Both owner_id and owner_ids are set.')
|
||||
@ -182,7 +208,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
await super().close() # type: ignore
|
||||
|
||||
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
|
||||
async def on_command_error(self, context: Context[BotT], exception: errors.CommandError) -> None:
|
||||
"""|coro|
|
||||
|
||||
The default command error handler provided by the bot.
|
||||
@ -237,7 +263,7 @@ class BotBase(GroupMixin):
|
||||
self.add_check(func) # type: ignore
|
||||
return func
|
||||
|
||||
def add_check(self, func: Check, /, *, call_once: bool = False) -> None:
|
||||
def add_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
|
||||
"""Adds a global check to the bot.
|
||||
|
||||
This is the non-decorator interface to :meth:`.check`
|
||||
@ -261,7 +287,7 @@ class BotBase(GroupMixin):
|
||||
else:
|
||||
self._checks.append(func)
|
||||
|
||||
def remove_check(self, func: Check, /, *, call_once: bool = False) -> None:
|
||||
def remove_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
|
||||
"""Removes a global check from the bot.
|
||||
|
||||
This function is idempotent and will not raise an exception
|
||||
@ -324,7 +350,7 @@ class BotBase(GroupMixin):
|
||||
self.add_check(func, call_once=True)
|
||||
return func
|
||||
|
||||
async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool:
|
||||
async def can_run(self, ctx: Context[BotT], *, call_once: bool = False) -> bool:
|
||||
data = self._check_once if call_once else self._checks
|
||||
|
||||
if len(data) == 0:
|
||||
@ -947,7 +973,7 @@ class BotBase(GroupMixin):
|
||||
# if the load failed, the remnants should have been
|
||||
# cleaned from the load_extension function call
|
||||
# so let's load it from our old compiled library.
|
||||
await lib.setup(self) # type: ignore
|
||||
await lib.setup(self)
|
||||
self.__extensions[name] = lib
|
||||
|
||||
# revert sys.modules back to normal and raise back to caller
|
||||
@ -1015,11 +1041,12 @@ class BotBase(GroupMixin):
|
||||
"""
|
||||
prefix = ret = self.command_prefix
|
||||
if callable(prefix):
|
||||
ret = await discord.utils.maybe_coroutine(prefix, self, message)
|
||||
# self will be a Bot or AutoShardedBot
|
||||
ret = await discord.utils.maybe_coroutine(prefix, self, message) # type: ignore
|
||||
|
||||
if not isinstance(ret, str):
|
||||
try:
|
||||
ret = list(ret)
|
||||
ret = list(ret) # type: ignore
|
||||
except TypeError:
|
||||
# It's possible that a generator raised this exception. Don't
|
||||
# replace it with our own error if that's the case.
|
||||
@ -1048,15 +1075,15 @@ class BotBase(GroupMixin):
|
||||
self,
|
||||
message: Message,
|
||||
*,
|
||||
cls: Type[CXT] = ...,
|
||||
) -> CXT: # type: ignore
|
||||
cls: Type[ContextT] = ...,
|
||||
) -> ContextT:
|
||||
...
|
||||
|
||||
async def get_context(
|
||||
self,
|
||||
message: Message,
|
||||
*,
|
||||
cls: Type[CXT] = MISSING,
|
||||
cls: Type[ContextT] = MISSING,
|
||||
) -> Any:
|
||||
r"""|coro|
|
||||
|
||||
@ -1137,7 +1164,7 @@ class BotBase(GroupMixin):
|
||||
ctx.command = self.all_commands.get(invoker)
|
||||
return ctx
|
||||
|
||||
async def invoke(self, ctx: Context) -> None:
|
||||
async def invoke(self, ctx: Context[BotT]) -> None:
|
||||
"""|coro|
|
||||
|
||||
Invokes the command given under the invocation context and
|
||||
@ -1189,9 +1216,10 @@ class BotBase(GroupMixin):
|
||||
return
|
||||
|
||||
ctx = await self.get_context(message)
|
||||
await self.invoke(ctx)
|
||||
# the type of the invocation context's bot attribute will be correct
|
||||
await self.invoke(ctx) # type: ignore
|
||||
|
||||
async def on_message(self, message):
|
||||
async def on_message(self, message: Message) -> None:
|
||||
await self.process_commands(message)
|
||||
|
||||
|
||||
|
@ -30,7 +30,7 @@ from discord.utils import maybe_coroutine
|
||||
|
||||
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union
|
||||
|
||||
from ._types import _BaseCommand
|
||||
from ._types import _BaseCommand, BotT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
@ -112,7 +112,7 @@ class CogMeta(type):
|
||||
|
||||
__cog_name__: str
|
||||
__cog_settings__: Dict[str, Any]
|
||||
__cog_commands__: List[Command]
|
||||
__cog_commands__: List[Command[Any, ..., Any]]
|
||||
__cog_is_app_commands_group__: bool
|
||||
__cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]]
|
||||
__cog_listeners__: List[Tuple[str, str]]
|
||||
@ -406,7 +406,7 @@ class Cog(metaclass=CogMeta):
|
||||
pass
|
||||
|
||||
@_cog_special_method
|
||||
def bot_check_once(self, ctx: Context) -> bool:
|
||||
def bot_check_once(self, ctx: Context[BotT]) -> bool:
|
||||
"""A special method that registers as a :meth:`.Bot.check_once`
|
||||
check.
|
||||
|
||||
@ -416,7 +416,7 @@ class Cog(metaclass=CogMeta):
|
||||
return True
|
||||
|
||||
@_cog_special_method
|
||||
def bot_check(self, ctx: Context) -> bool:
|
||||
def bot_check(self, ctx: Context[BotT]) -> bool:
|
||||
"""A special method that registers as a :meth:`.Bot.check`
|
||||
check.
|
||||
|
||||
@ -426,7 +426,7 @@ class Cog(metaclass=CogMeta):
|
||||
return True
|
||||
|
||||
@_cog_special_method
|
||||
def cog_check(self, ctx: Context) -> bool:
|
||||
def cog_check(self, ctx: Context[BotT]) -> bool:
|
||||
"""A special method that registers as a :func:`~discord.ext.commands.check`
|
||||
for every command and subcommand in this cog.
|
||||
|
||||
@ -436,7 +436,7 @@ class Cog(metaclass=CogMeta):
|
||||
return True
|
||||
|
||||
@_cog_special_method
|
||||
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
|
||||
async def cog_command_error(self, ctx: Context[BotT], error: Exception) -> None:
|
||||
"""A special method that is called whenever an error
|
||||
is dispatched inside this cog.
|
||||
|
||||
@ -455,7 +455,7 @@ class Cog(metaclass=CogMeta):
|
||||
pass
|
||||
|
||||
@_cog_special_method
|
||||
async def cog_before_invoke(self, ctx: Context) -> None:
|
||||
async def cog_before_invoke(self, ctx: Context[BotT]) -> None:
|
||||
"""A special method that acts as a cog local pre-invoke hook.
|
||||
|
||||
This is similar to :meth:`.Command.before_invoke`.
|
||||
@ -470,7 +470,7 @@ class Cog(metaclass=CogMeta):
|
||||
pass
|
||||
|
||||
@_cog_special_method
|
||||
async def cog_after_invoke(self, ctx: Context) -> None:
|
||||
async def cog_after_invoke(self, ctx: Context[BotT]) -> None:
|
||||
"""A special method that acts as a cog local post-invoke hook.
|
||||
|
||||
This is similar to :meth:`.Command.after_invoke`.
|
||||
|
@ -28,6 +28,8 @@ import re
|
||||
|
||||
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
from ._types import BotT
|
||||
|
||||
import discord.abc
|
||||
import discord.utils
|
||||
|
||||
@ -59,7 +61,6 @@ MISSING: Any = discord.utils.MISSING
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
|
||||
CogT = TypeVar('CogT', bound="Cog")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -133,10 +134,10 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
args: List[Any] = MISSING,
|
||||
kwargs: Dict[str, Any] = MISSING,
|
||||
prefix: Optional[str] = None,
|
||||
command: Optional[Command] = None,
|
||||
command: Optional[Command[Any, ..., Any]] = None,
|
||||
invoked_with: Optional[str] = None,
|
||||
invoked_parents: List[str] = MISSING,
|
||||
invoked_subcommand: Optional[Command] = None,
|
||||
invoked_subcommand: Optional[Command[Any, ..., Any]] = None,
|
||||
subcommand_passed: Optional[str] = None,
|
||||
command_failed: bool = False,
|
||||
current_parameter: Optional[inspect.Parameter] = None,
|
||||
@ -146,11 +147,11 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
self.args: List[Any] = args or []
|
||||
self.kwargs: Dict[str, Any] = kwargs or {}
|
||||
self.prefix: Optional[str] = prefix
|
||||
self.command: Optional[Command] = command
|
||||
self.command: Optional[Command[Any, ..., Any]] = command
|
||||
self.view: StringView = view
|
||||
self.invoked_with: Optional[str] = invoked_with
|
||||
self.invoked_parents: List[str] = invoked_parents or []
|
||||
self.invoked_subcommand: Optional[Command] = invoked_subcommand
|
||||
self.invoked_subcommand: Optional[Command[Any, ..., Any]] = invoked_subcommand
|
||||
self.subcommand_passed: Optional[str] = subcommand_passed
|
||||
self.command_failed: bool = command_failed
|
||||
self.current_parameter: Optional[inspect.Parameter] = current_parameter
|
||||
@ -361,7 +362,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
return None
|
||||
|
||||
cmd = cmd.copy()
|
||||
cmd.context = self
|
||||
cmd.context = self # type: ignore
|
||||
if len(args) == 0:
|
||||
await cmd.prepare_help_command(self, None)
|
||||
mapping = cmd.get_bot_mapping()
|
||||
@ -390,7 +391,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
try:
|
||||
if hasattr(entity, '__cog_commands__'):
|
||||
injected = wrap_callback(cmd.send_cog_help)
|
||||
return await injected(entity)
|
||||
return await injected(entity) # type: ignore
|
||||
elif isinstance(entity, Group):
|
||||
injected = wrap_callback(cmd.send_group_help)
|
||||
return await injected(entity)
|
||||
|
@ -41,7 +41,6 @@ from typing import (
|
||||
Tuple,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
overload,
|
||||
)
|
||||
|
||||
import discord
|
||||
@ -51,9 +50,8 @@ if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
from discord.state import Channel
|
||||
from discord.threads import Thread
|
||||
from .bot import Bot, AutoShardedBot
|
||||
|
||||
_Bot = TypeVar('_Bot', bound=Union[Bot, AutoShardedBot])
|
||||
from ._types import BotT, _Bot
|
||||
|
||||
|
||||
__all__ = (
|
||||
@ -87,7 +85,7 @@ __all__ = (
|
||||
)
|
||||
|
||||
|
||||
def _get_from_guilds(bot, getter, argument):
|
||||
def _get_from_guilds(bot: _Bot, getter: str, argument: Any) -> Any:
|
||||
result = None
|
||||
for guild in bot.guilds:
|
||||
result = getattr(guild, getter)(argument)
|
||||
@ -115,7 +113,7 @@ class Converter(Protocol[T_co]):
|
||||
method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`.
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> T_co:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> T_co:
|
||||
"""|coro|
|
||||
|
||||
The method to override to do conversion logic.
|
||||
@ -163,7 +161,7 @@ class ObjectConverter(IDConverter[discord.Object]):
|
||||
2. Lookup by member, role, or channel mention.
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Object:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Object:
|
||||
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument)
|
||||
|
||||
if match is None:
|
||||
@ -196,7 +194,7 @@ class MemberConverter(IDConverter[discord.Member]):
|
||||
optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled.
|
||||
"""
|
||||
|
||||
async def query_member_named(self, guild, argument):
|
||||
async def query_member_named(self, guild: discord.Guild, argument: str) -> Optional[discord.Member]:
|
||||
cache = guild._state.member_cache_flags.joined
|
||||
if len(argument) > 5 and argument[-5] == '#':
|
||||
username, _, discriminator = argument.rpartition('#')
|
||||
@ -206,7 +204,7 @@ class MemberConverter(IDConverter[discord.Member]):
|
||||
members = await guild.query_members(argument, limit=100, cache=cache)
|
||||
return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members)
|
||||
|
||||
async def query_member_by_id(self, bot, guild, user_id):
|
||||
async def query_member_by_id(self, bot: _Bot, guild: discord.Guild, user_id: int) -> Optional[discord.Member]:
|
||||
ws = bot._get_websocket(shard_id=guild.shard_id)
|
||||
cache = guild._state.member_cache_flags.joined
|
||||
if ws.is_ratelimited():
|
||||
@ -227,7 +225,7 @@ class MemberConverter(IDConverter[discord.Member]):
|
||||
return None
|
||||
return members[0]
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Member:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Member:
|
||||
bot = ctx.bot
|
||||
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
|
||||
guild = ctx.guild
|
||||
@ -281,7 +279,7 @@ class UserConverter(IDConverter[discord.User]):
|
||||
and it's not available in cache.
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.User:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.User:
|
||||
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
|
||||
result = None
|
||||
state = ctx._state
|
||||
@ -359,7 +357,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
|
||||
|
||||
@staticmethod
|
||||
def _resolve_channel(
|
||||
ctx: Context[_Bot], guild_id: Optional[int], channel_id: Optional[int]
|
||||
ctx: Context[BotT], guild_id: Optional[int], channel_id: Optional[int]
|
||||
) -> Optional[Union[Channel, Thread]]:
|
||||
if channel_id is None:
|
||||
# we were passed just a message id so we can assume the channel is the current context channel
|
||||
@ -373,7 +371,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
|
||||
|
||||
return ctx.bot.get_channel(channel_id)
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialMessage:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialMessage:
|
||||
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
|
||||
channel = self._resolve_channel(ctx, guild_id, channel_id)
|
||||
if not channel or not isinstance(channel, discord.abc.Messageable):
|
||||
@ -396,7 +394,7 @@ class MessageConverter(IDConverter[discord.Message]):
|
||||
Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Message:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Message:
|
||||
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument)
|
||||
message = ctx.bot._connection._get_message(message_id)
|
||||
if message:
|
||||
@ -427,11 +425,11 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.abc.GuildChannel:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.abc.GuildChannel:
|
||||
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT:
|
||||
def _resolve_channel(ctx: Context[BotT], argument: str, attribute: str, type: Type[CT]) -> CT:
|
||||
bot = ctx.bot
|
||||
|
||||
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
|
||||
@ -448,7 +446,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
|
||||
def check(c):
|
||||
return isinstance(c, type) and c.name == argument
|
||||
|
||||
result = discord.utils.find(check, bot.get_all_channels())
|
||||
result = discord.utils.find(check, bot.get_all_channels()) # type: ignore
|
||||
else:
|
||||
channel_id = int(match.group(1))
|
||||
if guild:
|
||||
@ -463,7 +461,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT:
|
||||
def _resolve_thread(ctx: Context[BotT], argument: str, attribute: str, type: Type[TT]) -> TT:
|
||||
bot = ctx.bot
|
||||
|
||||
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
|
||||
@ -502,7 +500,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
|
||||
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.TextChannel:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.TextChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel)
|
||||
|
||||
|
||||
@ -522,7 +520,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
|
||||
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.VoiceChannel:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.VoiceChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel)
|
||||
|
||||
|
||||
@ -541,7 +539,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
|
||||
3. Lookup by name
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StageChannel:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.StageChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel)
|
||||
|
||||
|
||||
@ -561,7 +559,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
|
||||
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.CategoryChannel:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.CategoryChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel)
|
||||
|
||||
|
||||
@ -580,7 +578,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StoreChannel:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.StoreChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel)
|
||||
|
||||
|
||||
@ -598,7 +596,7 @@ class ThreadConverter(IDConverter[discord.Thread]):
|
||||
.. versionadded: 2.0
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Thread:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Thread:
|
||||
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread)
|
||||
|
||||
|
||||
@ -630,7 +628,7 @@ class ColourConverter(Converter[discord.Colour]):
|
||||
|
||||
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
|
||||
|
||||
def parse_hex_number(self, argument):
|
||||
def parse_hex_number(self, argument: str) -> discord.Colour:
|
||||
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
|
||||
try:
|
||||
value = int(arg, base=16)
|
||||
@ -641,7 +639,7 @@ class ColourConverter(Converter[discord.Colour]):
|
||||
else:
|
||||
return discord.Color(value=value)
|
||||
|
||||
def parse_rgb_number(self, argument, number):
|
||||
def parse_rgb_number(self, argument: str, number: str) -> int:
|
||||
if number[-1] == '%':
|
||||
value = int(number[:-1])
|
||||
if not (0 <= value <= 100):
|
||||
@ -653,7 +651,7 @@ class ColourConverter(Converter[discord.Colour]):
|
||||
raise BadColourArgument(argument)
|
||||
return value
|
||||
|
||||
def parse_rgb(self, argument, *, regex=RGB_REGEX):
|
||||
def parse_rgb(self, argument: str, *, regex: re.Pattern[str] = RGB_REGEX) -> discord.Colour:
|
||||
match = regex.match(argument)
|
||||
if match is None:
|
||||
raise BadColourArgument(argument)
|
||||
@ -663,7 +661,7 @@ class ColourConverter(Converter[discord.Colour]):
|
||||
blue = self.parse_rgb_number(argument, match.group('b'))
|
||||
return discord.Color.from_rgb(red, green, blue)
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Colour:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Colour:
|
||||
if argument[0] == '#':
|
||||
return self.parse_hex_number(argument[1:])
|
||||
|
||||
@ -704,7 +702,7 @@ class RoleConverter(IDConverter[discord.Role]):
|
||||
Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Role:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Role:
|
||||
guild = ctx.guild
|
||||
if not guild:
|
||||
raise NoPrivateMessage()
|
||||
@ -723,7 +721,7 @@ class RoleConverter(IDConverter[discord.Role]):
|
||||
class GameConverter(Converter[discord.Game]):
|
||||
"""Converts to :class:`~discord.Game`."""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Game:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Game:
|
||||
return discord.Game(name=argument)
|
||||
|
||||
|
||||
@ -736,7 +734,7 @@ class InviteConverter(Converter[discord.Invite]):
|
||||
Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Invite:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Invite:
|
||||
try:
|
||||
invite = await ctx.bot.fetch_invite(argument)
|
||||
return invite
|
||||
@ -755,7 +753,7 @@ class GuildConverter(IDConverter[discord.Guild]):
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Guild:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Guild:
|
||||
match = self._get_id_match(argument)
|
||||
result = None
|
||||
|
||||
@ -787,7 +785,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
|
||||
Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Emoji:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Emoji:
|
||||
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument)
|
||||
result = None
|
||||
bot = ctx.bot
|
||||
@ -821,7 +819,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
|
||||
Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument`
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialEmoji:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialEmoji:
|
||||
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument)
|
||||
|
||||
if match:
|
||||
@ -850,7 +848,7 @@ class GuildStickerConverter(IDConverter[discord.GuildSticker]):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.GuildSticker:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.GuildSticker:
|
||||
match = self._get_id_match(argument)
|
||||
result = None
|
||||
bot = ctx.bot
|
||||
@ -890,7 +888,7 @@ class ScheduledEventConverter(IDConverter[discord.ScheduledEvent]):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.ScheduledEvent:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> discord.ScheduledEvent:
|
||||
guild = ctx.guild
|
||||
match = self._get_id_match(argument)
|
||||
result = None
|
||||
@ -967,7 +965,7 @@ class clean_content(Converter[str]):
|
||||
self.escape_markdown = escape_markdown
|
||||
self.remove_markdown = remove_markdown
|
||||
|
||||
async def convert(self, ctx: Context[_Bot], argument: str) -> str:
|
||||
async def convert(self, ctx: Context[BotT], argument: str) -> str:
|
||||
msg = ctx.message
|
||||
|
||||
if ctx.guild:
|
||||
@ -1047,10 +1045,10 @@ class Greedy(List[T]):
|
||||
|
||||
__slots__ = ('converter',)
|
||||
|
||||
def __init__(self, *, converter: T):
|
||||
self.converter = converter
|
||||
def __init__(self, *, converter: T) -> None:
|
||||
self.converter: T = converter
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
converter = getattr(self.converter, '__name__', repr(self.converter))
|
||||
return f'Greedy[{converter}]'
|
||||
|
||||
@ -1099,11 +1097,11 @@ def get_converter(param: inspect.Parameter) -> Any:
|
||||
_GenericAlias = type(List[T])
|
||||
|
||||
|
||||
def is_generic_type(tp: Any, *, _GenericAlias: Type = _GenericAlias) -> bool:
|
||||
return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias) # type: ignore
|
||||
def is_generic_type(tp: Any, *, _GenericAlias: type = _GenericAlias) -> bool:
|
||||
return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias)
|
||||
|
||||
|
||||
CONVERTER_MAPPING: Dict[Type[Any], Any] = {
|
||||
CONVERTER_MAPPING: Dict[type, Any] = {
|
||||
discord.Object: ObjectConverter,
|
||||
discord.Member: MemberConverter,
|
||||
discord.User: UserConverter,
|
||||
@ -1128,7 +1126,7 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = {
|
||||
}
|
||||
|
||||
|
||||
async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter):
|
||||
async def _actual_conversion(ctx: Context[BotT], converter, argument: str, param: inspect.Parameter):
|
||||
if converter is bool:
|
||||
return _convert_to_bool(argument)
|
||||
|
||||
@ -1166,7 +1164,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
|
||||
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
|
||||
|
||||
|
||||
async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter):
|
||||
async def run_converters(ctx: Context[BotT], converter: Any, argument: str, param: inspect.Parameter) -> Any:
|
||||
"""|coro|
|
||||
|
||||
Runs converters for a given converter, argument, and parameter.
|
||||
|
@ -220,7 +220,7 @@ class CooldownMapping:
|
||||
return self._type
|
||||
|
||||
@classmethod
|
||||
def from_cooldown(cls, rate, per, type) -> Self:
|
||||
def from_cooldown(cls, rate: float, per: float, type: Callable[[Message], Any]) -> Self:
|
||||
return cls(Cooldown(rate, per), type)
|
||||
|
||||
def _bucket_key(self, msg: Message) -> Any:
|
||||
|
@ -61,6 +61,8 @@ if TYPE_CHECKING:
|
||||
from discord.message import Message
|
||||
|
||||
from ._types import (
|
||||
BotT,
|
||||
ContextT,
|
||||
Coro,
|
||||
CoroFunc,
|
||||
Check,
|
||||
@ -101,7 +103,6 @@ MISSING: Any = discord.utils.MISSING
|
||||
T = TypeVar('T')
|
||||
CogT = TypeVar('CogT', bound='Optional[Cog]')
|
||||
CommandT = TypeVar('CommandT', bound='Command')
|
||||
ContextT = TypeVar('ContextT', bound='Context')
|
||||
# CHT = TypeVar('CHT', bound='Check')
|
||||
GroupT = TypeVar('GroupT', bound='Group')
|
||||
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
|
||||
@ -159,9 +160,9 @@ def get_signature_parameters(
|
||||
return params
|
||||
|
||||
|
||||
def wrap_callback(coro):
|
||||
def wrap_callback(coro: Callable[P, Coro[T]]) -> Callable[P, Coro[Optional[T]]]:
|
||||
@functools.wraps(coro)
|
||||
async def wrapped(*args, **kwargs):
|
||||
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
|
||||
try:
|
||||
ret = await coro(*args, **kwargs)
|
||||
except CommandError:
|
||||
@ -175,9 +176,11 @@ def wrap_callback(coro):
|
||||
return wrapped
|
||||
|
||||
|
||||
def hooked_wrapped_callback(command, ctx, coro):
|
||||
def hooked_wrapped_callback(
|
||||
command: Command[Any, ..., Any], ctx: Context[BotT], coro: Callable[P, Coro[T]]
|
||||
) -> Callable[P, Coro[Optional[T]]]:
|
||||
@functools.wraps(coro)
|
||||
async def wrapped(*args, **kwargs):
|
||||
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
|
||||
try:
|
||||
ret = await coro(*args, **kwargs)
|
||||
except CommandError:
|
||||
@ -191,7 +194,7 @@ def hooked_wrapped_callback(command, ctx, coro):
|
||||
raise CommandInvokeError(exc) from exc
|
||||
finally:
|
||||
if command._max_concurrency is not None:
|
||||
await command._max_concurrency.release(ctx)
|
||||
await command._max_concurrency.release(ctx.message)
|
||||
|
||||
await command.call_after_hooks(ctx)
|
||||
return ret
|
||||
@ -359,7 +362,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
except AttributeError:
|
||||
checks = kwargs.get('checks', [])
|
||||
|
||||
self.checks: List[Check] = checks
|
||||
self.checks: List[Check[ContextT]] = checks
|
||||
|
||||
try:
|
||||
cooldown = func.__commands_cooldown__
|
||||
@ -387,8 +390,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
self.cog: CogT = None
|
||||
|
||||
# bandaid for the fact that sometimes parent can be the bot instance
|
||||
parent = kwargs.get('parent')
|
||||
self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore
|
||||
parent: Optional[GroupMixin[Any]] = kwargs.get('parent')
|
||||
self.parent: Optional[GroupMixin[Any]] = parent if isinstance(parent, _BaseCommand) else None
|
||||
|
||||
self._before_invoke: Optional[Hook] = None
|
||||
try:
|
||||
@ -422,16 +425,16 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
) -> None:
|
||||
self._callback = function
|
||||
unwrap = unwrap_function(function)
|
||||
self.module = unwrap.__module__
|
||||
self.module: str = unwrap.__module__
|
||||
|
||||
try:
|
||||
globalns = unwrap.__globals__
|
||||
except AttributeError:
|
||||
globalns = {}
|
||||
|
||||
self.params = get_signature_parameters(function, globalns)
|
||||
self.params: Dict[str, inspect.Parameter] = get_signature_parameters(function, globalns)
|
||||
|
||||
def add_check(self, func: Check, /) -> None:
|
||||
def add_check(self, func: Check[ContextT], /) -> None:
|
||||
"""Adds a check to the command.
|
||||
|
||||
This is the non-decorator interface to :func:`.check`.
|
||||
@ -450,7 +453,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
|
||||
self.checks.append(func)
|
||||
|
||||
def remove_check(self, func: Check, /) -> None:
|
||||
def remove_check(self, func: Check[ContextT], /) -> None:
|
||||
"""Removes a check from the command.
|
||||
|
||||
This function is idempotent and will not raise an exception
|
||||
@ -484,7 +487,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs))
|
||||
self.cog = cog
|
||||
|
||||
async def __call__(self, context: Context, *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
async def __call__(self, context: Context[BotT], *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
"""|coro|
|
||||
|
||||
Calls the internal callback that the command holds.
|
||||
@ -539,7 +542,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
else:
|
||||
return self.copy()
|
||||
|
||||
async def dispatch_error(self, ctx: Context, error: Exception) -> None:
|
||||
async def dispatch_error(self, ctx: Context[BotT], error: CommandError) -> None:
|
||||
ctx.command_failed = True
|
||||
cog = self.cog
|
||||
try:
|
||||
@ -549,7 +552,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
else:
|
||||
injected = wrap_callback(coro)
|
||||
if cog is not None:
|
||||
await injected(cog, ctx, error)
|
||||
await injected(cog, ctx, error) # type: ignore
|
||||
else:
|
||||
await injected(ctx, error)
|
||||
|
||||
@ -562,7 +565,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
finally:
|
||||
ctx.bot.dispatch('command_error', ctx, error)
|
||||
|
||||
async def transform(self, ctx: Context, param: inspect.Parameter) -> Any:
|
||||
async def transform(self, ctx: Context[BotT], param: inspect.Parameter) -> Any:
|
||||
required = param.default is param.empty
|
||||
converter = get_converter(param)
|
||||
consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw
|
||||
@ -610,7 +613,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
# type-checker fails to narrow argument
|
||||
return await run_converters(ctx, converter, argument, param) # type: ignore
|
||||
|
||||
async def _transform_greedy_pos(self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any) -> Any:
|
||||
async def _transform_greedy_pos(
|
||||
self, ctx: Context[BotT], param: inspect.Parameter, required: bool, converter: Any
|
||||
) -> Any:
|
||||
view = ctx.view
|
||||
result = []
|
||||
while not view.eof:
|
||||
@ -631,7 +636,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
return param.default
|
||||
return result
|
||||
|
||||
async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any:
|
||||
async def _transform_greedy_var_pos(self, ctx: Context[BotT], param: inspect.Parameter, converter: Any) -> Any:
|
||||
view = ctx.view
|
||||
previous = view.index
|
||||
try:
|
||||
@ -669,7 +674,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
return ' '.join(reversed(entries))
|
||||
|
||||
@property
|
||||
def parents(self) -> List[Group]:
|
||||
def parents(self) -> List[Group[Any, ..., Any]]:
|
||||
"""List[:class:`Group`]: Retrieves the parents of this command.
|
||||
|
||||
If the command has no parents then it returns an empty :class:`list`.
|
||||
@ -687,7 +692,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
return entries
|
||||
|
||||
@property
|
||||
def root_parent(self) -> Optional[Group]:
|
||||
def root_parent(self) -> Optional[Group[Any, ..., Any]]:
|
||||
"""Optional[:class:`Group`]: Retrieves the root parent of this command.
|
||||
|
||||
If the command has no parents then it returns ``None``.
|
||||
@ -716,7 +721,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
def __str__(self) -> str:
|
||||
return self.qualified_name
|
||||
|
||||
async def _parse_arguments(self, ctx: Context) -> None:
|
||||
async def _parse_arguments(self, ctx: Context[BotT]) -> None:
|
||||
ctx.args = [ctx] if self.cog is None else [self.cog, ctx]
|
||||
ctx.kwargs = {}
|
||||
args = ctx.args
|
||||
@ -752,7 +757,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
if not self.ignore_extra and not view.eof:
|
||||
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
|
||||
|
||||
async def call_before_hooks(self, ctx: Context) -> None:
|
||||
async def call_before_hooks(self, ctx: Context[BotT]) -> None:
|
||||
# now that we're done preparing we can call the pre-command hooks
|
||||
# first, call the command local hook:
|
||||
cog = self.cog
|
||||
@ -777,7 +782,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
if hook is not None:
|
||||
await hook(ctx)
|
||||
|
||||
async def call_after_hooks(self, ctx: Context) -> None:
|
||||
async def call_after_hooks(self, ctx: Context[BotT]) -> None:
|
||||
cog = self.cog
|
||||
if self._after_invoke is not None:
|
||||
instance = getattr(self._after_invoke, '__self__', cog)
|
||||
@ -796,7 +801,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
if hook is not None:
|
||||
await hook(ctx)
|
||||
|
||||
def _prepare_cooldowns(self, ctx: Context) -> None:
|
||||
def _prepare_cooldowns(self, ctx: Context[BotT]) -> None:
|
||||
if self._buckets.valid:
|
||||
dt = ctx.message.edited_at or ctx.message.created_at
|
||||
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
|
||||
@ -806,7 +811,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
if retry_after:
|
||||
raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore
|
||||
|
||||
async def prepare(self, ctx: Context) -> None:
|
||||
async def prepare(self, ctx: Context[BotT]) -> None:
|
||||
ctx.command = self
|
||||
|
||||
if not await self.can_run(ctx):
|
||||
@ -830,7 +835,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
await self._max_concurrency.release(ctx) # type: ignore
|
||||
raise
|
||||
|
||||
def is_on_cooldown(self, ctx: Context) -> bool:
|
||||
def is_on_cooldown(self, ctx: Context[BotT]) -> bool:
|
||||
"""Checks whether the command is currently on cooldown.
|
||||
|
||||
Parameters
|
||||
@ -851,7 +856,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
|
||||
return bucket.get_tokens(current) == 0
|
||||
|
||||
def reset_cooldown(self, ctx: Context) -> None:
|
||||
def reset_cooldown(self, ctx: Context[BotT]) -> None:
|
||||
"""Resets the cooldown on this command.
|
||||
|
||||
Parameters
|
||||
@ -863,7 +868,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
bucket = self._buckets.get_bucket(ctx.message)
|
||||
bucket.reset()
|
||||
|
||||
def get_cooldown_retry_after(self, ctx: Context) -> float:
|
||||
def get_cooldown_retry_after(self, ctx: Context[BotT]) -> float:
|
||||
"""Retrieves the amount of seconds before this command can be tried again.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
@ -887,7 +892,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
|
||||
return 0.0
|
||||
|
||||
async def invoke(self, ctx: Context) -> None:
|
||||
async def invoke(self, ctx: Context[BotT]) -> None:
|
||||
await self.prepare(ctx)
|
||||
|
||||
# terminate the invoked_subcommand chain.
|
||||
@ -896,9 +901,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
ctx.invoked_subcommand = None
|
||||
ctx.subcommand_passed = None
|
||||
injected = hooked_wrapped_callback(self, ctx, self.callback)
|
||||
await injected(*ctx.args, **ctx.kwargs)
|
||||
await injected(*ctx.args, **ctx.kwargs) # type: ignore
|
||||
|
||||
async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None:
|
||||
async def reinvoke(self, ctx: Context[BotT], *, call_hooks: bool = False) -> None:
|
||||
ctx.command = self
|
||||
await self._parse_arguments(ctx)
|
||||
|
||||
@ -936,7 +941,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise TypeError('The error handler must be a coroutine.')
|
||||
|
||||
self.on_error: Error = coro
|
||||
self.on_error: Error[Any] = coro
|
||||
return coro
|
||||
|
||||
def has_error_handler(self) -> bool:
|
||||
@ -1075,7 +1080,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
|
||||
return ' '.join(result)
|
||||
|
||||
async def can_run(self, ctx: Context) -> bool:
|
||||
async def can_run(self, ctx: Context[BotT]) -> bool:
|
||||
"""|coro|
|
||||
|
||||
Checks if the command can be executed by checking all the predicates
|
||||
@ -1341,7 +1346,7 @@ class GroupMixin(Generic[CogT]):
|
||||
def command(
|
||||
self,
|
||||
name: str = MISSING,
|
||||
cls: Type[Command] = MISSING,
|
||||
cls: Type[Command[Any, ..., Any]] = MISSING,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@ -1401,7 +1406,7 @@ class GroupMixin(Generic[CogT]):
|
||||
def group(
|
||||
self,
|
||||
name: str = MISSING,
|
||||
cls: Type[Group] = MISSING,
|
||||
cls: Type[Group[Any, ..., Any]] = MISSING,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@ -1461,9 +1466,9 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
|
||||
ret = super().copy()
|
||||
for cmd in self.commands:
|
||||
ret.add_command(cmd.copy())
|
||||
return ret # type: ignore
|
||||
return ret
|
||||
|
||||
async def invoke(self, ctx: Context) -> None:
|
||||
async def invoke(self, ctx: Context[BotT]) -> None:
|
||||
ctx.invoked_subcommand = None
|
||||
ctx.subcommand_passed = None
|
||||
early_invoke = not self.invoke_without_command
|
||||
@ -1481,7 +1486,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
|
||||
|
||||
if early_invoke:
|
||||
injected = hooked_wrapped_callback(self, ctx, self.callback)
|
||||
await injected(*ctx.args, **ctx.kwargs)
|
||||
await injected(*ctx.args, **ctx.kwargs) # type: ignore
|
||||
|
||||
ctx.invoked_parents.append(ctx.invoked_with) # type: ignore
|
||||
|
||||
@ -1494,7 +1499,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
|
||||
view.previous = previous
|
||||
await super().invoke(ctx)
|
||||
|
||||
async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None:
|
||||
async def reinvoke(self, ctx: Context[BotT], *, call_hooks: bool = False) -> None:
|
||||
ctx.invoked_subcommand = None
|
||||
early_invoke = not self.invoke_without_command
|
||||
if early_invoke:
|
||||
@ -1592,7 +1597,7 @@ def command(
|
||||
|
||||
def command(
|
||||
name: str = MISSING,
|
||||
cls: Type[Command] = MISSING,
|
||||
cls: Type[Command[Any, ..., Any]] = MISSING,
|
||||
**attrs: Any,
|
||||
) -> Any:
|
||||
"""A decorator that transforms a function into a :class:`.Command`
|
||||
@ -1662,7 +1667,7 @@ def group(
|
||||
|
||||
def group(
|
||||
name: str = MISSING,
|
||||
cls: Type[Group] = MISSING,
|
||||
cls: Type[Group[Any, ..., Any]] = MISSING,
|
||||
**attrs: Any,
|
||||
) -> Any:
|
||||
"""A decorator that transforms a function into a :class:`.Group`.
|
||||
@ -1679,7 +1684,7 @@ def group(
|
||||
return command(name=name, cls=cls, **attrs)
|
||||
|
||||
|
||||
def check(predicate: Check) -> Callable[[T], T]:
|
||||
def check(predicate: Check[ContextT]) -> Callable[[T], T]:
|
||||
r"""A decorator that adds a check to the :class:`.Command` or its
|
||||
subclasses. These checks could be accessed via :attr:`.Command.checks`.
|
||||
|
||||
@ -1774,7 +1779,7 @@ def check(predicate: Check) -> Callable[[T], T]:
|
||||
return decorator # type: ignore
|
||||
|
||||
|
||||
def check_any(*checks: Check) -> Callable[[T], T]:
|
||||
def check_any(*checks: Check[ContextT]) -> Callable[[T], T]:
|
||||
r"""A :func:`check` that is added that checks if any of the checks passed
|
||||
will pass, i.e. using logical OR.
|
||||
|
||||
@ -1827,7 +1832,7 @@ def check_any(*checks: Check) -> Callable[[T], T]:
|
||||
else:
|
||||
unwrapped.append(pred)
|
||||
|
||||
async def predicate(ctx: Context) -> bool:
|
||||
async def predicate(ctx: Context[BotT]) -> bool:
|
||||
errors = []
|
||||
for func in unwrapped:
|
||||
try:
|
||||
@ -1870,7 +1875,7 @@ def has_role(item: Union[int, str]) -> Callable[[T], T]:
|
||||
The name or ID of the role to check.
|
||||
"""
|
||||
|
||||
def predicate(ctx: Context) -> bool:
|
||||
def predicate(ctx: Context[BotT]) -> bool:
|
||||
if ctx.guild is None:
|
||||
raise NoPrivateMessage()
|
||||
|
||||
@ -1923,7 +1928,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
|
||||
raise NoPrivateMessage()
|
||||
|
||||
# ctx.guild is None doesn't narrow ctx.author to Member
|
||||
getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore
|
||||
getter = functools.partial(discord.utils.get, ctx.author.roles)
|
||||
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
|
||||
return True
|
||||
raise MissingAnyRole(list(items))
|
||||
@ -2022,7 +2027,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
if invalid:
|
||||
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
|
||||
|
||||
def predicate(ctx: Context) -> bool:
|
||||
def predicate(ctx: Context[BotT]) -> bool:
|
||||
ch = ctx.channel
|
||||
permissions = ch.permissions_for(ctx.author) # type: ignore
|
||||
|
||||
@ -2048,7 +2053,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
if invalid:
|
||||
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
|
||||
|
||||
def predicate(ctx: Context) -> bool:
|
||||
def predicate(ctx: Context[BotT]) -> bool:
|
||||
guild = ctx.guild
|
||||
me = guild.me if guild is not None else ctx.bot.user
|
||||
permissions = ctx.channel.permissions_for(me) # type: ignore
|
||||
@ -2077,7 +2082,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
if invalid:
|
||||
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
|
||||
|
||||
def predicate(ctx: Context) -> bool:
|
||||
def predicate(ctx: Context[BotT]) -> bool:
|
||||
if not ctx.guild:
|
||||
raise NoPrivateMessage
|
||||
|
||||
@ -2103,7 +2108,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
|
||||
if invalid:
|
||||
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
|
||||
|
||||
def predicate(ctx: Context) -> bool:
|
||||
def predicate(ctx: Context[BotT]) -> bool:
|
||||
if not ctx.guild:
|
||||
raise NoPrivateMessage
|
||||
|
||||
@ -2129,7 +2134,7 @@ def dm_only() -> Callable[[T], T]:
|
||||
.. versionadded:: 1.1
|
||||
"""
|
||||
|
||||
def predicate(ctx: Context) -> bool:
|
||||
def predicate(ctx: Context[BotT]) -> bool:
|
||||
if ctx.guild is not None:
|
||||
raise PrivateMessageOnly()
|
||||
return True
|
||||
@ -2146,7 +2151,7 @@ def guild_only() -> Callable[[T], T]:
|
||||
that is inherited from :exc:`.CheckFailure`.
|
||||
"""
|
||||
|
||||
def predicate(ctx: Context) -> bool:
|
||||
def predicate(ctx: Context[BotT]) -> bool:
|
||||
if ctx.guild is None:
|
||||
raise NoPrivateMessage()
|
||||
return True
|
||||
@ -2164,7 +2169,7 @@ def is_owner() -> Callable[[T], T]:
|
||||
from :exc:`.CheckFailure`.
|
||||
"""
|
||||
|
||||
async def predicate(ctx: Context) -> bool:
|
||||
async def predicate(ctx: Context[BotT]) -> bool:
|
||||
if not await ctx.bot.is_owner(ctx.author):
|
||||
raise NotOwner('You do not own this bot.')
|
||||
return True
|
||||
@ -2184,7 +2189,7 @@ def is_nsfw() -> Callable[[T], T]:
|
||||
DM channels will also now pass this check.
|
||||
"""
|
||||
|
||||
def pred(ctx: Context) -> bool:
|
||||
def pred(ctx: Context[BotT]) -> bool:
|
||||
ch = ctx.channel
|
||||
if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()):
|
||||
return True
|
||||
|
@ -39,6 +39,8 @@ if TYPE_CHECKING:
|
||||
from discord.threads import Thread
|
||||
from discord.types.snowflake import Snowflake, SnowflakeList
|
||||
|
||||
from ._types import BotT
|
||||
|
||||
|
||||
__all__ = (
|
||||
'CommandError',
|
||||
@ -135,8 +137,8 @@ class ConversionError(CommandError):
|
||||
the ``__cause__`` attribute.
|
||||
"""
|
||||
|
||||
def __init__(self, converter: Converter, original: Exception) -> None:
|
||||
self.converter: Converter = converter
|
||||
def __init__(self, converter: Converter[Any], original: Exception) -> None:
|
||||
self.converter: Converter[Any] = converter
|
||||
self.original: Exception = original
|
||||
|
||||
|
||||
@ -224,9 +226,9 @@ class CheckAnyFailure(CheckFailure):
|
||||
A list of check predicates that failed.
|
||||
"""
|
||||
|
||||
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None:
|
||||
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context[BotT]], bool]]) -> None:
|
||||
self.checks: List[CheckFailure] = checks
|
||||
self.errors: List[Callable[[Context], bool]] = errors
|
||||
self.errors: List[Callable[[Context[BotT]], bool]] = errors
|
||||
super().__init__('You do not have permission to run this command.')
|
||||
|
||||
|
||||
@ -807,9 +809,9 @@ class BadUnionArgument(UserInputError):
|
||||
A list of errors that were caught from failing the conversion.
|
||||
"""
|
||||
|
||||
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None:
|
||||
def __init__(self, param: Parameter, converters: Tuple[type, ...], errors: List[CommandError]) -> None:
|
||||
self.param: Parameter = param
|
||||
self.converters: Tuple[Type, ...] = converters
|
||||
self.converters: Tuple[type, ...] = converters
|
||||
self.errors: List[CommandError] = errors
|
||||
|
||||
def _get_name(x):
|
||||
|
@ -49,8 +49,6 @@ from typing import (
|
||||
Tuple,
|
||||
List,
|
||||
Any,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -70,6 +68,8 @@ if TYPE_CHECKING:
|
||||
|
||||
from .context import Context
|
||||
|
||||
from ._types import BotT
|
||||
|
||||
|
||||
@dataclass
|
||||
class Flag:
|
||||
@ -148,7 +148,7 @@ def flag(
|
||||
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
|
||||
|
||||
|
||||
def validate_flag_name(name: str, forbidden: Set[str]):
|
||||
def validate_flag_name(name: str, forbidden: Set[str]) -> None:
|
||||
if not name:
|
||||
raise ValueError('flag names should not be empty')
|
||||
|
||||
@ -348,7 +348,7 @@ class FlagsMeta(type):
|
||||
return type.__new__(cls, name, bases, attrs)
|
||||
|
||||
|
||||
async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
|
||||
async def tuple_convert_all(ctx: Context[BotT], argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
|
||||
view = StringView(argument)
|
||||
results = []
|
||||
param: inspect.Parameter = ctx.current_parameter # type: ignore
|
||||
@ -373,7 +373,7 @@ async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter:
|
||||
return tuple(results)
|
||||
|
||||
|
||||
async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
|
||||
async def tuple_convert_flag(ctx: Context[BotT], argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
|
||||
view = StringView(argument)
|
||||
results = []
|
||||
param: inspect.Parameter = ctx.current_parameter # type: ignore
|
||||
@ -401,7 +401,7 @@ async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters
|
||||
return tuple(results)
|
||||
|
||||
|
||||
async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -> Any:
|
||||
async def convert_flag(ctx: Context[BotT], argument: str, flag: Flag, annotation: Any = None) -> Any:
|
||||
param: inspect.Parameter = ctx.current_parameter # type: ignore
|
||||
annotation = annotation or flag.annotation
|
||||
try:
|
||||
@ -480,7 +480,7 @@ class FlagConverter(metaclass=FlagsMeta):
|
||||
yield (flag.name, getattr(self, flag.attribute))
|
||||
|
||||
@classmethod
|
||||
async def _construct_default(cls, ctx: Context) -> Self:
|
||||
async def _construct_default(cls, ctx: Context[BotT]) -> Self:
|
||||
self = cls.__new__(cls)
|
||||
flags = cls.__commands_flags__
|
||||
for flag in flags.values():
|
||||
@ -546,7 +546,7 @@ class FlagConverter(metaclass=FlagsMeta):
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def convert(cls, ctx: Context, argument: str) -> Self:
|
||||
async def convert(cls, ctx: Context[BotT], argument: str) -> Self:
|
||||
"""|coro|
|
||||
|
||||
The method that actually converters an argument to the flag mapping.
|
||||
@ -610,7 +610,7 @@ class FlagConverter(metaclass=FlagsMeta):
|
||||
values = [await convert_flag(ctx, value, flag) for value in values]
|
||||
|
||||
if flag.cast_to_dict:
|
||||
values = dict(values) # type: ignore
|
||||
values = dict(values)
|
||||
|
||||
setattr(self, flag.attribute, values)
|
||||
|
||||
|
@ -22,13 +22,27 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
import re
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Optional,
|
||||
Generator,
|
||||
List,
|
||||
TypeVar,
|
||||
Callable,
|
||||
Any,
|
||||
Dict,
|
||||
Tuple,
|
||||
Iterable,
|
||||
Sequence,
|
||||
Mapping,
|
||||
)
|
||||
|
||||
import discord.utils
|
||||
|
||||
@ -36,7 +50,21 @@ from .core import Group, Command, get_signature_parameters
|
||||
from .errors import CommandError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
import inspect
|
||||
|
||||
import discord.abc
|
||||
|
||||
from .bot import BotBase
|
||||
from .context import Context
|
||||
from .cog import Cog
|
||||
|
||||
from ._types import (
|
||||
Check,
|
||||
ContextT,
|
||||
BotT,
|
||||
_Bot,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
'Paginator',
|
||||
@ -45,7 +73,9 @@ __all__ = (
|
||||
'MinimalHelpCommand',
|
||||
)
|
||||
|
||||
MISSING = discord.utils.MISSING
|
||||
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
|
||||
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
# help -> shows info of bot on top/bottom and lists subcommands
|
||||
# help command -> shows detailed info of command
|
||||
@ -80,10 +110,10 @@ class Paginator:
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
prefix: :class:`str`
|
||||
The prefix inserted to every page. e.g. three backticks.
|
||||
suffix: :class:`str`
|
||||
The suffix appended at the end of every page. e.g. three backticks.
|
||||
prefix: Optional[:class:`str`]
|
||||
The prefix inserted to every page. e.g. three backticks, if any.
|
||||
suffix: Optional[:class:`str`]
|
||||
The suffix appended at the end of every page. e.g. three backticks, if any.
|
||||
max_size: :class:`int`
|
||||
The maximum amount of codepoints allowed in a page.
|
||||
linesep: :class:`str`
|
||||
@ -91,36 +121,38 @@ class Paginator:
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
|
||||
def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'):
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
self.max_size = max_size
|
||||
self.linesep = linesep
|
||||
def __init__(
|
||||
self, prefix: Optional[str] = '```', suffix: Optional[str] = '```', max_size: int = 2000, linesep: str = '\n'
|
||||
) -> None:
|
||||
self.prefix: Optional[str] = prefix
|
||||
self.suffix: Optional[str] = suffix
|
||||
self.max_size: int = max_size
|
||||
self.linesep: str = linesep
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
"""Clears the paginator to have no pages."""
|
||||
if self.prefix is not None:
|
||||
self._current_page = [self.prefix]
|
||||
self._count = len(self.prefix) + self._linesep_len # prefix + newline
|
||||
self._current_page: List[str] = [self.prefix]
|
||||
self._count: int = len(self.prefix) + self._linesep_len # prefix + newline
|
||||
else:
|
||||
self._current_page = []
|
||||
self._count = 0
|
||||
self._pages = []
|
||||
self._pages: List[str] = []
|
||||
|
||||
@property
|
||||
def _prefix_len(self):
|
||||
def _prefix_len(self) -> int:
|
||||
return len(self.prefix) if self.prefix else 0
|
||||
|
||||
@property
|
||||
def _suffix_len(self):
|
||||
def _suffix_len(self) -> int:
|
||||
return len(self.suffix) if self.suffix else 0
|
||||
|
||||
@property
|
||||
def _linesep_len(self):
|
||||
def _linesep_len(self) -> int:
|
||||
return len(self.linesep)
|
||||
|
||||
def add_line(self, line='', *, empty=False):
|
||||
def add_line(self, line: str = '', *, empty: bool = False) -> None:
|
||||
"""Adds a line to the current page.
|
||||
|
||||
If the line exceeds the :attr:`max_size` then an exception
|
||||
@ -152,7 +184,7 @@ class Paginator:
|
||||
self._current_page.append('')
|
||||
self._count += self._linesep_len
|
||||
|
||||
def close_page(self):
|
||||
def close_page(self) -> None:
|
||||
"""Prematurely terminate a page."""
|
||||
if self.suffix is not None:
|
||||
self._current_page.append(self.suffix)
|
||||
@ -165,36 +197,38 @@ class Paginator:
|
||||
self._current_page = []
|
||||
self._count = 0
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
total = sum(len(p) for p in self._pages)
|
||||
return total + self._count
|
||||
|
||||
@property
|
||||
def pages(self):
|
||||
def pages(self) -> List[str]:
|
||||
"""List[:class:`str`]: Returns the rendered list of pages."""
|
||||
# we have more than just the prefix in our current page
|
||||
if len(self._current_page) > (0 if self.prefix is None else 1):
|
||||
self.close_page()
|
||||
return self._pages
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
fmt = '<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>'
|
||||
return fmt.format(self)
|
||||
|
||||
|
||||
def _not_overridden(f):
|
||||
def _not_overridden(f: FuncT) -> FuncT:
|
||||
f.__help_command_not_overridden__ = True
|
||||
return f
|
||||
|
||||
|
||||
class _HelpCommandImpl(Command):
|
||||
def __init__(self, inject, *args, **kwargs):
|
||||
def __init__(self, inject: HelpCommand, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(inject.command_callback, *args, **kwargs)
|
||||
self._original = inject
|
||||
self._injected = inject
|
||||
self.params = get_signature_parameters(inject.command_callback, globals(), skip_parameters=1)
|
||||
self._original: HelpCommand = inject
|
||||
self._injected: HelpCommand = inject
|
||||
self.params: Dict[str, inspect.Parameter] = get_signature_parameters(
|
||||
inject.command_callback, globals(), skip_parameters=1
|
||||
)
|
||||
|
||||
async def prepare(self, ctx):
|
||||
async def prepare(self, ctx: Context[Any]) -> None:
|
||||
self._injected = injected = self._original.copy()
|
||||
injected.context = ctx
|
||||
self.callback = injected.command_callback
|
||||
@ -209,7 +243,7 @@ class _HelpCommandImpl(Command):
|
||||
|
||||
await super().prepare(ctx)
|
||||
|
||||
async def _parse_arguments(self, ctx):
|
||||
async def _parse_arguments(self, ctx: Context[BotT]) -> None:
|
||||
# Make the parser think we don't have a cog so it doesn't
|
||||
# inject the parameter into `ctx.args`.
|
||||
original_cog = self.cog
|
||||
@ -219,22 +253,26 @@ class _HelpCommandImpl(Command):
|
||||
finally:
|
||||
self.cog = original_cog
|
||||
|
||||
async def _on_error_cog_implementation(self, dummy, ctx, error):
|
||||
async def _on_error_cog_implementation(self, _, ctx: Context[BotT], error: CommandError) -> None:
|
||||
await self._injected.on_help_command_error(ctx, error)
|
||||
|
||||
def _inject_into_cog(self, cog):
|
||||
def _inject_into_cog(self, cog: Cog) -> None:
|
||||
# Warning: hacky
|
||||
|
||||
# Make the cog think that get_commands returns this command
|
||||
# as well if we inject it without modifying __cog_commands__
|
||||
# since that's used for the injection and ejection of cogs.
|
||||
def wrapped_get_commands(*, _original=cog.get_commands):
|
||||
def wrapped_get_commands(
|
||||
*, _original: Callable[[], List[Command[Any, ..., Any]]] = cog.get_commands
|
||||
) -> List[Command[Any, ..., Any]]:
|
||||
ret = _original()
|
||||
ret.append(self)
|
||||
return ret
|
||||
|
||||
# Ditto here
|
||||
def wrapped_walk_commands(*, _original=cog.walk_commands):
|
||||
def wrapped_walk_commands(
|
||||
*, _original: Callable[[], Generator[Command[Any, ..., Any], None, None]] = cog.walk_commands
|
||||
):
|
||||
yield from _original()
|
||||
yield self
|
||||
|
||||
@ -244,7 +282,7 @@ class _HelpCommandImpl(Command):
|
||||
cog.walk_commands = wrapped_walk_commands
|
||||
self.cog = cog
|
||||
|
||||
def _eject_cog(self):
|
||||
def _eject_cog(self) -> None:
|
||||
if self.cog is None:
|
||||
return
|
||||
|
||||
@ -298,7 +336,11 @@ class HelpCommand:
|
||||
|
||||
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys()))
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if TYPE_CHECKING:
|
||||
__original_kwargs__: Dict[str, Any]
|
||||
__original_args__: Tuple[Any, ...]
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
|
||||
# To prevent race conditions of a single instance while also allowing
|
||||
# for settings to be passed the original arguments passed must be assigned
|
||||
# to allow for easier copies (which will be made when the help command is actually called)
|
||||
@ -314,30 +356,31 @@ class HelpCommand:
|
||||
self.__original_args__ = deepcopy(args)
|
||||
return self
|
||||
|
||||
def __init__(self, **options):
|
||||
self.show_hidden = options.pop('show_hidden', False)
|
||||
self.verify_checks = options.pop('verify_checks', True)
|
||||
def __init__(self, **options: Any) -> None:
|
||||
self.show_hidden: bool = options.pop('show_hidden', False)
|
||||
self.verify_checks: bool = options.pop('verify_checks', True)
|
||||
self.command_attrs: Dict[str, Any]
|
||||
self.command_attrs = attrs = options.pop('command_attrs', {})
|
||||
attrs.setdefault('name', 'help')
|
||||
attrs.setdefault('help', 'Shows this message')
|
||||
self.context: Context = MISSING
|
||||
self.context: Context[_Bot] = MISSING
|
||||
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> Self:
|
||||
obj = self.__class__(*self.__original_args__, **self.__original_kwargs__)
|
||||
obj._command_impl = self._command_impl
|
||||
return obj
|
||||
|
||||
def _add_to_bot(self, bot):
|
||||
def _add_to_bot(self, bot: BotBase) -> None:
|
||||
command = _HelpCommandImpl(self, **self.command_attrs)
|
||||
bot.add_command(command)
|
||||
self._command_impl = command
|
||||
|
||||
def _remove_from_bot(self, bot):
|
||||
def _remove_from_bot(self, bot: BotBase) -> None:
|
||||
bot.remove_command(self._command_impl.name)
|
||||
self._command_impl._eject_cog()
|
||||
|
||||
def add_check(self, func, /):
|
||||
def add_check(self, func: Check[ContextT], /) -> None:
|
||||
"""
|
||||
Adds a check to the help command.
|
||||
|
||||
@ -355,7 +398,7 @@ class HelpCommand:
|
||||
|
||||
self._command_impl.add_check(func)
|
||||
|
||||
def remove_check(self, func, /):
|
||||
def remove_check(self, func: Check[ContextT], /) -> None:
|
||||
"""
|
||||
Removes a check from the help command.
|
||||
|
||||
@ -376,15 +419,15 @@ class HelpCommand:
|
||||
|
||||
self._command_impl.remove_check(func)
|
||||
|
||||
def get_bot_mapping(self):
|
||||
def get_bot_mapping(self) -> Dict[Optional[Cog], List[Command[Any, ..., Any]]]:
|
||||
"""Retrieves the bot mapping passed to :meth:`send_bot_help`."""
|
||||
bot = self.context.bot
|
||||
mapping = {cog: cog.get_commands() for cog in bot.cogs.values()}
|
||||
mapping: Dict[Optional[Cog], List[Command[Any, ..., Any]]] = {cog: cog.get_commands() for cog in bot.cogs.values()}
|
||||
mapping[None] = [c for c in bot.commands if c.cog is None]
|
||||
return mapping
|
||||
|
||||
@property
|
||||
def invoked_with(self):
|
||||
def invoked_with(self) -> Optional[str]:
|
||||
"""Similar to :attr:`Context.invoked_with` except properly handles
|
||||
the case where :meth:`Context.send_help` is used.
|
||||
|
||||
@ -395,7 +438,7 @@ class HelpCommand:
|
||||
|
||||
Returns
|
||||
---------
|
||||
:class:`str`
|
||||
Optional[:class:`str`]
|
||||
The command name that triggered this invocation.
|
||||
"""
|
||||
command_name = self._command_impl.name
|
||||
@ -404,7 +447,7 @@ class HelpCommand:
|
||||
return command_name
|
||||
return ctx.invoked_with
|
||||
|
||||
def get_command_signature(self, command):
|
||||
def get_command_signature(self, command: Command[Any, ..., Any]) -> str:
|
||||
"""Retrieves the signature portion of the help page.
|
||||
|
||||
Parameters
|
||||
@ -418,14 +461,14 @@ class HelpCommand:
|
||||
The signature for the command.
|
||||
"""
|
||||
|
||||
parent = command.parent
|
||||
parent: Optional[Group[Any, ..., Any]] = command.parent # type: ignore - the parent will be a Group
|
||||
entries = []
|
||||
while parent is not None:
|
||||
if not parent.signature or parent.invoke_without_command:
|
||||
entries.append(parent.name)
|
||||
else:
|
||||
entries.append(parent.name + ' ' + parent.signature)
|
||||
parent = parent.parent
|
||||
parent = parent.parent # type: ignore
|
||||
parent_sig = ' '.join(reversed(entries))
|
||||
|
||||
if len(command.aliases) > 0:
|
||||
@ -439,7 +482,7 @@ class HelpCommand:
|
||||
|
||||
return f'{self.context.clean_prefix}{alias} {command.signature}'
|
||||
|
||||
def remove_mentions(self, string):
|
||||
def remove_mentions(self, string: str) -> str:
|
||||
"""Removes mentions from the string to prevent abuse.
|
||||
|
||||
This includes ``@everyone``, ``@here``, member mentions and role mentions.
|
||||
@ -450,13 +493,13 @@ class HelpCommand:
|
||||
The string with mentions removed.
|
||||
"""
|
||||
|
||||
def replace(obj, *, transforms=self.MENTION_TRANSFORMS):
|
||||
def replace(obj: re.Match, *, transforms: Dict[str, str] = self.MENTION_TRANSFORMS) -> str:
|
||||
return transforms.get(obj.group(0), '@invalid')
|
||||
|
||||
return self.MENTION_PATTERN.sub(replace, string)
|
||||
|
||||
@property
|
||||
def cog(self):
|
||||
def cog(self) -> Optional[Cog]:
|
||||
"""A property for retrieving or setting the cog for the help command.
|
||||
|
||||
When a cog is set for the help command, it is as-if the help command
|
||||
@ -473,7 +516,7 @@ class HelpCommand:
|
||||
return self._command_impl.cog
|
||||
|
||||
@cog.setter
|
||||
def cog(self, cog):
|
||||
def cog(self, cog: Optional[Cog]) -> None:
|
||||
# Remove whatever cog is currently valid, if any
|
||||
self._command_impl._eject_cog()
|
||||
|
||||
@ -481,7 +524,7 @@ class HelpCommand:
|
||||
if cog is not None:
|
||||
self._command_impl._inject_into_cog(cog)
|
||||
|
||||
def command_not_found(self, string):
|
||||
def command_not_found(self, string: str) -> str:
|
||||
"""|maybecoro|
|
||||
|
||||
A method called when a command is not found in the help command.
|
||||
@ -502,7 +545,7 @@ class HelpCommand:
|
||||
"""
|
||||
return f'No command called "{string}" found.'
|
||||
|
||||
def subcommand_not_found(self, command, string):
|
||||
def subcommand_not_found(self, command: Command[Any, ..., Any], string: str) -> str:
|
||||
"""|maybecoro|
|
||||
|
||||
A method called when a command did not have a subcommand requested in the help command.
|
||||
@ -532,7 +575,13 @@ class HelpCommand:
|
||||
return f'Command "{command.qualified_name}" has no subcommand named {string}'
|
||||
return f'Command "{command.qualified_name}" has no subcommands.'
|
||||
|
||||
async def filter_commands(self, commands, *, sort=False, key=None):
|
||||
async def filter_commands(
|
||||
self,
|
||||
commands: Iterable[Command[Any, ..., Any]],
|
||||
*,
|
||||
sort: bool = False,
|
||||
key: Optional[Callable[[Command[Any, ..., Any]], Any]] = None,
|
||||
) -> List[Command[Any, ..., Any]]:
|
||||
"""|coro|
|
||||
|
||||
Returns a filtered list of commands and optionally sorts them.
|
||||
@ -546,7 +595,7 @@ class HelpCommand:
|
||||
An iterable of commands that are getting filtered.
|
||||
sort: :class:`bool`
|
||||
Whether to sort the result.
|
||||
key: Optional[Callable[:class:`Command`, Any]]
|
||||
key: Optional[Callable[[:class:`Command`], Any]]
|
||||
An optional key function to pass to :func:`py:sorted` that
|
||||
takes a :class:`Command` as its sole parameter. If ``sort`` is
|
||||
passed as ``True`` then this will default as the command name.
|
||||
@ -565,14 +614,14 @@ class HelpCommand:
|
||||
if self.verify_checks is False:
|
||||
# if we do not need to verify the checks then we can just
|
||||
# run it straight through normally without using await.
|
||||
return sorted(iterator, key=key) if sort else list(iterator)
|
||||
return sorted(iterator, key=key) if sort else list(iterator) # type: ignore - the key shouldn't be None
|
||||
|
||||
if self.verify_checks is None and not self.context.guild:
|
||||
# if verify_checks is None and we're in a DM, don't verify
|
||||
return sorted(iterator, key=key) if sort else list(iterator)
|
||||
return sorted(iterator, key=key) if sort else list(iterator) # type: ignore
|
||||
|
||||
# if we're here then we need to check every command if it can run
|
||||
async def predicate(cmd):
|
||||
async def predicate(cmd: Command[Any, ..., Any]) -> bool:
|
||||
try:
|
||||
return await cmd.can_run(self.context)
|
||||
except CommandError:
|
||||
@ -588,7 +637,7 @@ class HelpCommand:
|
||||
ret.sort(key=key)
|
||||
return ret
|
||||
|
||||
def get_max_size(self, commands):
|
||||
def get_max_size(self, commands: Sequence[Command[Any, ..., Any]]) -> int:
|
||||
"""Returns the largest name length of the specified command list.
|
||||
|
||||
Parameters
|
||||
@ -605,7 +654,7 @@ class HelpCommand:
|
||||
as_lengths = (discord.utils._string_width(c.name) for c in commands)
|
||||
return max(as_lengths, default=0)
|
||||
|
||||
def get_destination(self):
|
||||
def get_destination(self) -> discord.abc.MessageableChannel:
|
||||
"""Returns the :class:`~discord.abc.Messageable` where the help command will be output.
|
||||
|
||||
You can override this method to customise the behaviour.
|
||||
@ -619,7 +668,7 @@ class HelpCommand:
|
||||
"""
|
||||
return self.context.channel
|
||||
|
||||
async def send_error_message(self, error):
|
||||
async def send_error_message(self, error: str) -> None:
|
||||
"""|coro|
|
||||
|
||||
Handles the implementation when an error happens in the help command.
|
||||
@ -644,7 +693,7 @@ class HelpCommand:
|
||||
await destination.send(error)
|
||||
|
||||
@_not_overridden
|
||||
async def on_help_command_error(self, ctx, error):
|
||||
async def on_help_command_error(self, ctx: Context[BotT], error: CommandError) -> None:
|
||||
"""|coro|
|
||||
|
||||
The help command's error handler, as specified by :ref:`ext_commands_error_handler`.
|
||||
@ -664,7 +713,7 @@ class HelpCommand:
|
||||
"""
|
||||
pass
|
||||
|
||||
async def send_bot_help(self, mapping):
|
||||
async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, ..., Any]]]) -> None:
|
||||
"""|coro|
|
||||
|
||||
Handles the implementation of the bot command page in the help command.
|
||||
@ -693,7 +742,7 @@ class HelpCommand:
|
||||
"""
|
||||
return None
|
||||
|
||||
async def send_cog_help(self, cog):
|
||||
async def send_cog_help(self, cog: Cog) -> None:
|
||||
"""|coro|
|
||||
|
||||
Handles the implementation of the cog page in the help command.
|
||||
@ -721,7 +770,7 @@ class HelpCommand:
|
||||
"""
|
||||
return None
|
||||
|
||||
async def send_group_help(self, group):
|
||||
async def send_group_help(self, group: Group[Any, ..., Any]) -> None:
|
||||
"""|coro|
|
||||
|
||||
Handles the implementation of the group page in the help command.
|
||||
@ -749,7 +798,7 @@ class HelpCommand:
|
||||
"""
|
||||
return None
|
||||
|
||||
async def send_command_help(self, command):
|
||||
async def send_command_help(self, command: Command[Any, ..., Any]) -> None:
|
||||
"""|coro|
|
||||
|
||||
Handles the implementation of the single command page in the help command.
|
||||
@ -787,7 +836,7 @@ class HelpCommand:
|
||||
"""
|
||||
return None
|
||||
|
||||
async def prepare_help_command(self, ctx, command=None):
|
||||
async def prepare_help_command(self, ctx: Context[BotT], command: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
A low level method that can be used to prepare the help command
|
||||
@ -811,7 +860,7 @@ class HelpCommand:
|
||||
"""
|
||||
pass
|
||||
|
||||
async def command_callback(self, ctx, *, command=None):
|
||||
async def command_callback(self, ctx: Context[BotT], *, command: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
The actual implementation of the help command.
|
||||
@ -856,7 +905,7 @@ class HelpCommand:
|
||||
|
||||
for key in keys[1:]:
|
||||
try:
|
||||
found = cmd.all_commands.get(key)
|
||||
found = cmd.all_commands.get(key) # type: ignore
|
||||
except AttributeError:
|
||||
string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key))
|
||||
return await self.send_error_message(string)
|
||||
@ -908,28 +957,28 @@ class DefaultHelpCommand(HelpCommand):
|
||||
The paginator used to paginate the help command output.
|
||||
"""
|
||||
|
||||
def __init__(self, **options):
|
||||
self.width = options.pop('width', 80)
|
||||
self.indent = options.pop('indent', 2)
|
||||
self.sort_commands = options.pop('sort_commands', True)
|
||||
self.dm_help = options.pop('dm_help', False)
|
||||
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
|
||||
self.commands_heading = options.pop('commands_heading', "Commands:")
|
||||
self.no_category = options.pop('no_category', 'No Category')
|
||||
self.paginator = options.pop('paginator', None)
|
||||
def __init__(self, **options: Any) -> None:
|
||||
self.width: int = options.pop('width', 80)
|
||||
self.indent: int = options.pop('indent', 2)
|
||||
self.sort_commands: bool = options.pop('sort_commands', True)
|
||||
self.dm_help: bool = options.pop('dm_help', False)
|
||||
self.dm_help_threshold: int = options.pop('dm_help_threshold', 1000)
|
||||
self.commands_heading: str = options.pop('commands_heading', "Commands:")
|
||||
self.no_category: str = options.pop('no_category', 'No Category')
|
||||
self.paginator: Paginator = options.pop('paginator', None)
|
||||
|
||||
if self.paginator is None:
|
||||
self.paginator = Paginator()
|
||||
self.paginator: Paginator = Paginator()
|
||||
|
||||
super().__init__(**options)
|
||||
|
||||
def shorten_text(self, text):
|
||||
def shorten_text(self, text: str) -> str:
|
||||
""":class:`str`: Shortens text to fit into the :attr:`width`."""
|
||||
if len(text) > self.width:
|
||||
return text[: self.width - 3].rstrip() + '...'
|
||||
return text
|
||||
|
||||
def get_ending_note(self):
|
||||
def get_ending_note(self) -> str:
|
||||
""":class:`str`: Returns help command's ending note. This is mainly useful to override for i18n purposes."""
|
||||
command_name = self.invoked_with
|
||||
return (
|
||||
@ -937,7 +986,9 @@ class DefaultHelpCommand(HelpCommand):
|
||||
f"You can also type {self.context.clean_prefix}{command_name} category for more info on a category."
|
||||
)
|
||||
|
||||
def add_indented_commands(self, commands, *, heading, max_size=None):
|
||||
def add_indented_commands(
|
||||
self, commands: Sequence[Command[Any, ..., Any]], *, heading: str, max_size: Optional[int] = None
|
||||
) -> None:
|
||||
"""Indents a list of commands after the specified heading.
|
||||
|
||||
The formatting is added to the :attr:`paginator`.
|
||||
@ -973,13 +1024,13 @@ class DefaultHelpCommand(HelpCommand):
|
||||
entry = f'{self.indent * " "}{name:<{width}} {command.short_doc}'
|
||||
self.paginator.add_line(self.shorten_text(entry))
|
||||
|
||||
async def send_pages(self):
|
||||
async def send_pages(self) -> None:
|
||||
"""A helper utility to send the page output from :attr:`paginator` to the destination."""
|
||||
destination = self.get_destination()
|
||||
for page in self.paginator.pages:
|
||||
await destination.send(page)
|
||||
|
||||
def add_command_formatting(self, command):
|
||||
def add_command_formatting(self, command: Command[Any, ..., Any]) -> None:
|
||||
"""A utility function to format the non-indented block of commands and groups.
|
||||
|
||||
Parameters
|
||||
@ -1002,7 +1053,7 @@ class DefaultHelpCommand(HelpCommand):
|
||||
self.paginator.add_line(line)
|
||||
self.paginator.add_line()
|
||||
|
||||
def get_destination(self):
|
||||
def get_destination(self) -> discord.abc.Messageable:
|
||||
ctx = self.context
|
||||
if self.dm_help is True:
|
||||
return ctx.author
|
||||
@ -1011,11 +1062,11 @@ class DefaultHelpCommand(HelpCommand):
|
||||
else:
|
||||
return ctx.channel
|
||||
|
||||
async def prepare_help_command(self, ctx, command):
|
||||
async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None:
|
||||
self.paginator.clear()
|
||||
await super().prepare_help_command(ctx, command)
|
||||
|
||||
async def send_bot_help(self, mapping):
|
||||
async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, ..., Any]]]) -> None:
|
||||
ctx = self.context
|
||||
bot = ctx.bot
|
||||
|
||||
@ -1045,12 +1096,12 @@ class DefaultHelpCommand(HelpCommand):
|
||||
|
||||
await self.send_pages()
|
||||
|
||||
async def send_command_help(self, command):
|
||||
async def send_command_help(self, command: Command[Any, ..., Any]) -> None:
|
||||
self.add_command_formatting(command)
|
||||
self.paginator.close_page()
|
||||
await self.send_pages()
|
||||
|
||||
async def send_group_help(self, group):
|
||||
async def send_group_help(self, group: Group[Any, ..., Any]) -> None:
|
||||
self.add_command_formatting(group)
|
||||
|
||||
filtered = await self.filter_commands(group.commands, sort=self.sort_commands)
|
||||
@ -1064,7 +1115,7 @@ class DefaultHelpCommand(HelpCommand):
|
||||
|
||||
await self.send_pages()
|
||||
|
||||
async def send_cog_help(self, cog):
|
||||
async def send_cog_help(self, cog: Cog) -> None:
|
||||
if cog.description:
|
||||
self.paginator.add_line(cog.description, empty=True)
|
||||
|
||||
@ -1111,27 +1162,27 @@ class MinimalHelpCommand(HelpCommand):
|
||||
The paginator used to paginate the help command output.
|
||||
"""
|
||||
|
||||
def __init__(self, **options):
|
||||
self.sort_commands = options.pop('sort_commands', True)
|
||||
self.commands_heading = options.pop('commands_heading', "Commands")
|
||||
self.dm_help = options.pop('dm_help', False)
|
||||
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
|
||||
self.aliases_heading = options.pop('aliases_heading', "Aliases:")
|
||||
self.no_category = options.pop('no_category', 'No Category')
|
||||
self.paginator = options.pop('paginator', None)
|
||||
def __init__(self, **options: Any) -> None:
|
||||
self.sort_commands: bool = options.pop('sort_commands', True)
|
||||
self.commands_heading: str = options.pop('commands_heading', "Commands")
|
||||
self.dm_help: bool = options.pop('dm_help', False)
|
||||
self.dm_help_threshold: int = options.pop('dm_help_threshold', 1000)
|
||||
self.aliases_heading: str = options.pop('aliases_heading', "Aliases:")
|
||||
self.no_category: str = options.pop('no_category', 'No Category')
|
||||
self.paginator: Paginator = options.pop('paginator', None)
|
||||
|
||||
if self.paginator is None:
|
||||
self.paginator = Paginator(suffix=None, prefix=None)
|
||||
self.paginator: Paginator = Paginator(suffix=None, prefix=None)
|
||||
|
||||
super().__init__(**options)
|
||||
|
||||
async def send_pages(self):
|
||||
async def send_pages(self) -> None:
|
||||
"""A helper utility to send the page output from :attr:`paginator` to the destination."""
|
||||
destination = self.get_destination()
|
||||
for page in self.paginator.pages:
|
||||
await destination.send(page)
|
||||
|
||||
def get_opening_note(self):
|
||||
def get_opening_note(self) -> str:
|
||||
"""Returns help command's opening note. This is mainly useful to override for i18n purposes.
|
||||
|
||||
The default implementation returns ::
|
||||
@ -1150,10 +1201,10 @@ class MinimalHelpCommand(HelpCommand):
|
||||
f"You can also use `{self.context.clean_prefix}{command_name} [category]` for more info on a category."
|
||||
)
|
||||
|
||||
def get_command_signature(self, command):
|
||||
def get_command_signature(self, command: Command[Any, ..., Any]) -> str:
|
||||
return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}'
|
||||
|
||||
def get_ending_note(self):
|
||||
def get_ending_note(self) -> str:
|
||||
"""Return the help command's ending note. This is mainly useful to override for i18n purposes.
|
||||
|
||||
The default implementation does nothing.
|
||||
@ -1163,9 +1214,9 @@ class MinimalHelpCommand(HelpCommand):
|
||||
:class:`str`
|
||||
The help command ending note.
|
||||
"""
|
||||
return None
|
||||
return ''
|
||||
|
||||
def add_bot_commands_formatting(self, commands, heading):
|
||||
def add_bot_commands_formatting(self, commands: Sequence[Command[Any, ..., Any]], heading: str) -> None:
|
||||
"""Adds the minified bot heading with commands to the output.
|
||||
|
||||
The formatting should be added to the :attr:`paginator`.
|
||||
@ -1186,7 +1237,7 @@ class MinimalHelpCommand(HelpCommand):
|
||||
self.paginator.add_line(f'__**{heading}**__')
|
||||
self.paginator.add_line(joined)
|
||||
|
||||
def add_subcommand_formatting(self, command):
|
||||
def add_subcommand_formatting(self, command: Command[Any, ..., Any]) -> None:
|
||||
"""Adds formatting information on a subcommand.
|
||||
|
||||
The formatting should be added to the :attr:`paginator`.
|
||||
@ -1202,7 +1253,7 @@ class MinimalHelpCommand(HelpCommand):
|
||||
fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}'
|
||||
self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc))
|
||||
|
||||
def add_aliases_formatting(self, aliases):
|
||||
def add_aliases_formatting(self, aliases: Sequence[str]) -> None:
|
||||
"""Adds the formatting information on a command's aliases.
|
||||
|
||||
The formatting should be added to the :attr:`paginator`.
|
||||
@ -1219,7 +1270,7 @@ class MinimalHelpCommand(HelpCommand):
|
||||
"""
|
||||
self.paginator.add_line(f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True)
|
||||
|
||||
def add_command_formatting(self, command):
|
||||
def add_command_formatting(self, command: Command[Any, ..., Any]) -> None:
|
||||
"""A utility function to format commands and groups.
|
||||
|
||||
Parameters
|
||||
@ -1246,7 +1297,7 @@ class MinimalHelpCommand(HelpCommand):
|
||||
self.paginator.add_line(line)
|
||||
self.paginator.add_line()
|
||||
|
||||
def get_destination(self):
|
||||
def get_destination(self) -> discord.abc.Messageable:
|
||||
ctx = self.context
|
||||
if self.dm_help is True:
|
||||
return ctx.author
|
||||
@ -1255,11 +1306,11 @@ class MinimalHelpCommand(HelpCommand):
|
||||
else:
|
||||
return ctx.channel
|
||||
|
||||
async def prepare_help_command(self, ctx, command):
|
||||
async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None:
|
||||
self.paginator.clear()
|
||||
await super().prepare_help_command(ctx, command)
|
||||
|
||||
async def send_bot_help(self, mapping):
|
||||
async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, ..., Any]]]) -> None:
|
||||
ctx = self.context
|
||||
bot = ctx.bot
|
||||
|
||||
@ -1272,7 +1323,7 @@ class MinimalHelpCommand(HelpCommand):
|
||||
|
||||
no_category = f'\u200b{self.no_category}'
|
||||
|
||||
def get_category(command, *, no_category=no_category):
|
||||
def get_category(command: Command[Any, ..., Any], *, no_category: str = no_category) -> str:
|
||||
cog = command.cog
|
||||
return cog.qualified_name if cog is not None else no_category
|
||||
|
||||
@ -1290,7 +1341,7 @@ class MinimalHelpCommand(HelpCommand):
|
||||
|
||||
await self.send_pages()
|
||||
|
||||
async def send_cog_help(self, cog):
|
||||
async def send_cog_help(self, cog: Cog) -> None:
|
||||
bot = self.context.bot
|
||||
if bot.description:
|
||||
self.paginator.add_line(bot.description, empty=True)
|
||||
@ -1315,7 +1366,7 @@ class MinimalHelpCommand(HelpCommand):
|
||||
|
||||
await self.send_pages()
|
||||
|
||||
async def send_group_help(self, group):
|
||||
async def send_group_help(self, group: Group[Any, ..., Any]) -> None:
|
||||
self.add_command_formatting(group)
|
||||
|
||||
filtered = await self.filter_commands(group.commands, sort=self.sort_commands)
|
||||
@ -1335,7 +1386,7 @@ class MinimalHelpCommand(HelpCommand):
|
||||
|
||||
await self.send_pages()
|
||||
|
||||
async def send_command_help(self, command):
|
||||
async def send_command_help(self, command: Command[Any, ..., Any]) -> None:
|
||||
self.add_command_formatting(command)
|
||||
self.paginator.close_page()
|
||||
await self.send_pages()
|
||||
|
@ -21,6 +21,11 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
|
||||
|
||||
# map from opening quotes to closing quotes
|
||||
@ -47,24 +52,24 @@ _all_quotes = set(_quotes.keys()) | set(_quotes.values())
|
||||
|
||||
|
||||
class StringView:
|
||||
def __init__(self, buffer):
|
||||
self.index = 0
|
||||
self.buffer = buffer
|
||||
self.end = len(buffer)
|
||||
def __init__(self, buffer: str) -> None:
|
||||
self.index: int = 0
|
||||
self.buffer: str = buffer
|
||||
self.end: int = len(buffer)
|
||||
self.previous = 0
|
||||
|
||||
@property
|
||||
def current(self):
|
||||
def current(self) -> Optional[str]:
|
||||
return None if self.eof else self.buffer[self.index]
|
||||
|
||||
@property
|
||||
def eof(self):
|
||||
def eof(self) -> bool:
|
||||
return self.index >= self.end
|
||||
|
||||
def undo(self):
|
||||
def undo(self) -> None:
|
||||
self.index = self.previous
|
||||
|
||||
def skip_ws(self):
|
||||
def skip_ws(self) -> bool:
|
||||
pos = 0
|
||||
while not self.eof:
|
||||
try:
|
||||
@ -79,7 +84,7 @@ class StringView:
|
||||
self.index += pos
|
||||
return self.previous != self.index
|
||||
|
||||
def skip_string(self, string):
|
||||
def skip_string(self, string: str) -> bool:
|
||||
strlen = len(string)
|
||||
if self.buffer[self.index : self.index + strlen] == string:
|
||||
self.previous = self.index
|
||||
@ -87,19 +92,19 @@ class StringView:
|
||||
return True
|
||||
return False
|
||||
|
||||
def read_rest(self):
|
||||
def read_rest(self) -> str:
|
||||
result = self.buffer[self.index :]
|
||||
self.previous = self.index
|
||||
self.index = self.end
|
||||
return result
|
||||
|
||||
def read(self, n):
|
||||
def read(self, n: int) -> str:
|
||||
result = self.buffer[self.index : self.index + n]
|
||||
self.previous = self.index
|
||||
self.index += n
|
||||
return result
|
||||
|
||||
def get(self):
|
||||
def get(self) -> Optional[str]:
|
||||
try:
|
||||
result = self.buffer[self.index + 1]
|
||||
except IndexError:
|
||||
@ -109,7 +114,7 @@ class StringView:
|
||||
self.index += 1
|
||||
return result
|
||||
|
||||
def get_word(self):
|
||||
def get_word(self) -> str:
|
||||
pos = 0
|
||||
while not self.eof:
|
||||
try:
|
||||
@ -119,12 +124,12 @@ class StringView:
|
||||
pos += 1
|
||||
except IndexError:
|
||||
break
|
||||
self.previous = self.index
|
||||
self.previous: int = self.index
|
||||
result = self.buffer[self.index : self.index + pos]
|
||||
self.index += pos
|
||||
return result
|
||||
|
||||
def get_quoted_word(self):
|
||||
def get_quoted_word(self) -> Optional[str]:
|
||||
current = self.current
|
||||
if current is None:
|
||||
return None
|
||||
@ -187,5 +192,5 @@ class StringView:
|
||||
|
||||
result.append(current)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>'
|
||||
|
@ -110,15 +110,15 @@ class SleepHandle:
|
||||
__slots__ = ('future', 'loop', 'handle')
|
||||
|
||||
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self.loop = loop
|
||||
self.future = future = loop.create_future()
|
||||
self.loop: asyncio.AbstractEventLoop = loop
|
||||
self.future: asyncio.Future[None] = loop.create_future()
|
||||
relative_delta = discord.utils.compute_timedelta(dt)
|
||||
self.handle = loop.call_later(relative_delta, future.set_result, True)
|
||||
self.handle = loop.call_later(relative_delta, self.future.set_result, True)
|
||||
|
||||
def recalculate(self, dt: datetime.datetime) -> None:
|
||||
self.handle.cancel()
|
||||
relative_delta = discord.utils.compute_timedelta(dt)
|
||||
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
|
||||
self.handle: asyncio.TimerHandle = self.loop.call_later(relative_delta, self.future.set_result, True)
|
||||
|
||||
def wait(self) -> asyncio.Future[Any]:
|
||||
return self.future
|
||||
|
@ -74,7 +74,7 @@ class File:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fp: Union[str, bytes, os.PathLike, io.BufferedIOBase],
|
||||
fp: Union[str, bytes, os.PathLike[Any], io.BufferedIOBase],
|
||||
filename: Optional[str] = None,
|
||||
*,
|
||||
spoiler: bool = False,
|
||||
|
@ -46,8 +46,8 @@ BF = TypeVar('BF', bound='BaseFlags')
|
||||
|
||||
class flag_value:
|
||||
def __init__(self, func: Callable[[Any], int]):
|
||||
self.flag = func(None)
|
||||
self.__doc__ = func.__doc__
|
||||
self.flag: int = func(None)
|
||||
self.__doc__: Optional[str] = func.__doc__
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: None, owner: Type[BF]) -> Self:
|
||||
@ -65,7 +65,7 @@ class flag_value:
|
||||
def __set__(self, instance: BaseFlags, value: bool) -> None:
|
||||
instance._set_flag(self.flag, value)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'<flag_value flag={self.flag!r}>'
|
||||
|
||||
|
||||
@ -73,8 +73,8 @@ class alias_flag_value(flag_value):
|
||||
pass
|
||||
|
||||
|
||||
def fill_with_flags(*, inverted: bool = False):
|
||||
def decorator(cls: Type[BF]):
|
||||
def fill_with_flags(*, inverted: bool = False) -> Callable[[Type[BF]], Type[BF]]:
|
||||
def decorator(cls: Type[BF]) -> Type[BF]:
|
||||
# fmt: off
|
||||
cls.VALID_FLAGS = {
|
||||
name: value.flag
|
||||
@ -116,10 +116,10 @@ class BaseFlags:
|
||||
self.value = value
|
||||
return self
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__) and self.value == other.value
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
@ -504,8 +504,8 @@ class Intents(BaseFlags):
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, **kwargs: bool):
|
||||
self.value = self.DEFAULT_VALUE
|
||||
def __init__(self, **kwargs: bool) -> None:
|
||||
self.value: int = self.DEFAULT_VALUE
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.VALID_FLAGS:
|
||||
raise TypeError(f'{key!r} is not a valid flag name.')
|
||||
@ -1005,7 +1005,7 @@ class MemberCacheFlags(BaseFlags):
|
||||
|
||||
def __init__(self, **kwargs: bool):
|
||||
bits = max(self.VALID_FLAGS.values()).bit_length()
|
||||
self.value = (1 << bits) - 1
|
||||
self.value: int = (1 << bits) - 1
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.VALID_FLAGS:
|
||||
raise TypeError(f'{key!r} is not a valid flag name.')
|
||||
|
@ -54,6 +54,8 @@ __all__ = (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
|
||||
from .client import Client
|
||||
from .state import ConnectionState
|
||||
from .voice_client import VoiceClient
|
||||
@ -62,10 +64,10 @@ if TYPE_CHECKING:
|
||||
class ReconnectWebSocket(Exception):
|
||||
"""Signals to safely reconnect the websocket."""
|
||||
|
||||
def __init__(self, shard_id, *, resume=True):
|
||||
self.shard_id = shard_id
|
||||
self.resume = resume
|
||||
self.op = 'RESUME' if resume else 'IDENTIFY'
|
||||
def __init__(self, shard_id: Optional[int], *, resume: bool = True) -> None:
|
||||
self.shard_id: Optional[int] = shard_id
|
||||
self.resume: bool = resume
|
||||
self.op: str = 'RESUME' if resume else 'IDENTIFY'
|
||||
|
||||
|
||||
class WebSocketClosure(Exception):
|
||||
@ -225,7 +227,7 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
|
||||
ack_time = time.perf_counter()
|
||||
self._last_ack = ack_time
|
||||
self._last_recv = ack_time
|
||||
self.latency = ack_time - self._last_send
|
||||
self.latency: float = ack_time - self._last_send
|
||||
self.recent_ack_latencies.append(self.latency)
|
||||
|
||||
|
||||
@ -339,7 +341,7 @@ class DiscordWebSocket:
|
||||
|
||||
@classmethod
|
||||
async def from_client(
|
||||
cls: Type[DWS],
|
||||
cls,
|
||||
client: Client,
|
||||
*,
|
||||
initial: bool = False,
|
||||
@ -348,7 +350,7 @@ class DiscordWebSocket:
|
||||
session: Optional[str] = None,
|
||||
sequence: Optional[int] = None,
|
||||
resume: bool = False,
|
||||
) -> DWS:
|
||||
) -> Self:
|
||||
"""Creates a main websocket for Discord from a :class:`Client`.
|
||||
|
||||
This is for internal use only.
|
||||
@ -821,11 +823,11 @@ class DiscordVoiceWebSocket:
|
||||
*,
|
||||
hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
|
||||
) -> None:
|
||||
self.ws = socket
|
||||
self.loop = loop
|
||||
self._keep_alive = None
|
||||
self._close_code = None
|
||||
self.secret_key = None
|
||||
self.ws: aiohttp.ClientWebSocketResponse = socket
|
||||
self.loop: asyncio.AbstractEventLoop = loop
|
||||
self._keep_alive: Optional[VoiceKeepAliveHandler] = None
|
||||
self._close_code: Optional[int] = None
|
||||
self.secret_key: Optional[str] = None
|
||||
if hook:
|
||||
self._hook = hook # type: ignore - type-checker doesn't like overriding methods
|
||||
|
||||
@ -864,7 +866,9 @@ class DiscordVoiceWebSocket:
|
||||
await self.send_as_json(payload)
|
||||
|
||||
@classmethod
|
||||
async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume=False, hook=None) -> DVWS:
|
||||
async def from_client(
|
||||
cls, client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None
|
||||
) -> Self:
|
||||
"""Creates a voice websocket for the :class:`VoiceClient`."""
|
||||
gateway = 'wss://' + client.endpoint + '/?v=4'
|
||||
http = client._state.http
|
||||
|
@ -123,6 +123,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .types.integration import IntegrationType
|
||||
from .types.snowflake import SnowflakeList
|
||||
from .types.widget import EditWidgetSettings
|
||||
|
||||
VocalGuildChannel = Union[VoiceChannel, StageChannel]
|
||||
GuildChannel = Union[VocalGuildChannel, TextChannel, CategoryChannel, StoreChannel]
|
||||
@ -3379,7 +3380,7 @@ class Guild(Hashable):
|
||||
HTTPException
|
||||
Editing the widget failed.
|
||||
"""
|
||||
payload = {}
|
||||
payload: EditWidgetSettings = {}
|
||||
if channel is not MISSING:
|
||||
payload['channel_id'] = None if channel is None else channel.id
|
||||
if enabled is not MISSING:
|
||||
@ -3492,7 +3493,7 @@ class Guild(Hashable):
|
||||
|
||||
async def change_voice_state(
|
||||
self, *, channel: Optional[abc.Snowflake], self_mute: bool = False, self_deaf: bool = False
|
||||
):
|
||||
) -> None:
|
||||
"""|coro|
|
||||
|
||||
Changes client's voice state in the guild.
|
||||
|
@ -76,12 +76,9 @@ if TYPE_CHECKING:
|
||||
audit_log,
|
||||
channel,
|
||||
command,
|
||||
components,
|
||||
emoji,
|
||||
embed,
|
||||
guild,
|
||||
integration,
|
||||
interactions,
|
||||
invite,
|
||||
member,
|
||||
message,
|
||||
@ -92,7 +89,6 @@ if TYPE_CHECKING:
|
||||
channel,
|
||||
widget,
|
||||
threads,
|
||||
voice,
|
||||
scheduled_event,
|
||||
sticker,
|
||||
)
|
||||
@ -122,7 +118,7 @@ class MultipartParameters(NamedTuple):
|
||||
multipart: Optional[List[Dict[str, Any]]]
|
||||
files: Optional[List[File]]
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
@ -577,7 +573,7 @@ class HTTPClient:
|
||||
|
||||
return self.request(Route('POST', '/users/{user_id}/channels', user_id=user_id), json=payload)
|
||||
|
||||
def leave_group(self, channel_id) -> Response[None]:
|
||||
def leave_group(self, channel_id: Snowflake) -> Response[None]:
|
||||
return self.request(Route('DELETE', '/channels/{channel_id}', channel_id=channel_id))
|
||||
|
||||
# Message management
|
||||
@ -1160,7 +1156,7 @@ class HTTPClient:
|
||||
def sync_template(self, guild_id: Snowflake, code: str) -> Response[template.Template]:
|
||||
return self.request(Route('PUT', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code))
|
||||
|
||||
def edit_template(self, guild_id: Snowflake, code: str, payload) -> Response[template.Template]:
|
||||
def edit_template(self, guild_id: Snowflake, code: str, payload: Dict[str, Any]) -> Response[template.Template]:
|
||||
valid_keys = (
|
||||
'name',
|
||||
'description',
|
||||
@ -1420,7 +1416,9 @@ class HTTPClient:
|
||||
def get_widget(self, guild_id: Snowflake) -> Response[widget.Widget]:
|
||||
return self.request(Route('GET', '/guilds/{guild_id}/widget.json', guild_id=guild_id))
|
||||
|
||||
def edit_widget(self, guild_id: Snowflake, payload, reason: Optional[str] = None) -> Response[widget.WidgetSettings]:
|
||||
def edit_widget(
|
||||
self, guild_id: Snowflake, payload: widget.EditWidgetSettings, reason: Optional[str] = None
|
||||
) -> Response[widget.WidgetSettings]:
|
||||
return self.request(Route('PATCH', '/guilds/{guild_id}/widget', guild_id=guild_id), json=payload, reason=reason)
|
||||
|
||||
# Invite management
|
||||
@ -1812,7 +1810,9 @@ class HTTPClient:
|
||||
)
|
||||
return self.request(r)
|
||||
|
||||
def upsert_global_command(self, application_id: Snowflake, payload) -> Response[command.ApplicationCommand]:
|
||||
def upsert_global_command(
|
||||
self, application_id: Snowflake, payload: command.ApplicationCommand
|
||||
) -> Response[command.ApplicationCommand]:
|
||||
r = Route('POST', '/applications/{application_id}/commands', application_id=application_id)
|
||||
return self.request(r, json=payload)
|
||||
|
||||
@ -1845,7 +1845,9 @@ class HTTPClient:
|
||||
)
|
||||
return self.request(r)
|
||||
|
||||
def bulk_upsert_global_commands(self, application_id: Snowflake, payload) -> Response[List[command.ApplicationCommand]]:
|
||||
def bulk_upsert_global_commands(
|
||||
self, application_id: Snowflake, payload: List[Dict[str, Any]]
|
||||
) -> Response[List[command.ApplicationCommand]]:
|
||||
r = Route('PUT', '/applications/{application_id}/commands', application_id=application_id)
|
||||
return self.request(r, json=payload)
|
||||
|
||||
|
@ -39,6 +39,9 @@ __all__ = (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .guild import Guild
|
||||
from .role import Role
|
||||
from .state import ConnectionState
|
||||
from .types.integration import (
|
||||
IntegrationAccount as IntegrationAccountPayload,
|
||||
Integration as IntegrationPayload,
|
||||
@ -47,8 +50,6 @@ if TYPE_CHECKING:
|
||||
IntegrationType,
|
||||
IntegrationApplication as IntegrationApplicationPayload,
|
||||
)
|
||||
from .guild import Guild
|
||||
from .role import Role
|
||||
|
||||
|
||||
class IntegrationAccount:
|
||||
@ -109,11 +110,11 @@ class Integration:
|
||||
)
|
||||
|
||||
def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None:
|
||||
self.guild = guild
|
||||
self._state = guild._state
|
||||
self.guild: Guild = guild
|
||||
self._state: ConnectionState = guild._state
|
||||
self._from_data(data)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>"
|
||||
|
||||
def _from_data(self, data: IntegrationPayload) -> None:
|
||||
@ -123,7 +124,7 @@ class Integration:
|
||||
self.account: IntegrationAccount = IntegrationAccount(data['account'])
|
||||
|
||||
user = data.get('user')
|
||||
self.user = User(state=self._state, data=user) if user else None
|
||||
self.user: Optional[User] = User(state=self._state, data=user) if user else None
|
||||
self.enabled: bool = data['enabled']
|
||||
|
||||
async def delete(self, *, reason: Optional[str] = None) -> None:
|
||||
@ -319,7 +320,7 @@ class IntegrationApplication:
|
||||
'user',
|
||||
)
|
||||
|
||||
def __init__(self, *, data: IntegrationApplicationPayload, state):
|
||||
def __init__(self, *, data: IntegrationApplicationPayload, state: ConnectionState) -> None:
|
||||
self.id: int = int(data['id'])
|
||||
self.name: str = data['name']
|
||||
self.icon: Optional[str] = data['icon']
|
||||
@ -358,7 +359,7 @@ class BotIntegration(Integration):
|
||||
|
||||
def _from_data(self, data: BotIntegrationPayload) -> None:
|
||||
super()._from_data(data)
|
||||
self.application = IntegrationApplication(data=data['application'], state=self._state)
|
||||
self.application: IntegrationApplication = IntegrationApplication(data=data['application'], state=self._state)
|
||||
|
||||
|
||||
def _integration_factory(value: str) -> Tuple[Type[Integration], str]:
|
||||
|
@ -54,6 +54,9 @@ if TYPE_CHECKING:
|
||||
Interaction as InteractionPayload,
|
||||
InteractionData,
|
||||
)
|
||||
from .types.webhook import (
|
||||
Webhook as WebhookPayload,
|
||||
)
|
||||
from .client import Client
|
||||
from .guild import Guild
|
||||
from .state import ConnectionState
|
||||
@ -229,7 +232,7 @@ class Interaction:
|
||||
@utils.cached_slot_property('_cs_followup')
|
||||
def followup(self) -> Webhook:
|
||||
""":class:`Webhook`: Returns the follow up webhook for follow up interactions."""
|
||||
payload = {
|
||||
payload: WebhookPayload = {
|
||||
'id': self.application_id,
|
||||
'type': 3,
|
||||
'token': self.token,
|
||||
@ -703,7 +706,7 @@ class InteractionResponse:
|
||||
|
||||
self._responded = True
|
||||
|
||||
async def send_modal(self, modal: Modal, /):
|
||||
async def send_modal(self, modal: Modal, /) -> None:
|
||||
"""|coro|
|
||||
|
||||
Responds to this interaction by sending a modal.
|
||||
|
@ -456,7 +456,7 @@ class Invite(Hashable):
|
||||
guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id)
|
||||
channel_id = int(data['channel_id'])
|
||||
if guild is not None:
|
||||
channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore
|
||||
channel = guild.get_channel(channel_id) or Object(id=channel_id)
|
||||
else:
|
||||
guild = Object(id=guild_id) if guild_id is not None else None
|
||||
channel = Object(id=channel_id)
|
||||
@ -539,7 +539,7 @@ class Invite(Hashable):
|
||||
|
||||
return self
|
||||
|
||||
async def delete(self, *, reason: Optional[str] = None):
|
||||
async def delete(self, *, reason: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
Revokes the instant invite.
|
||||
|
@ -27,9 +27,8 @@ from __future__ import annotations
|
||||
import datetime
|
||||
import inspect
|
||||
import itertools
|
||||
import sys
|
||||
from operator import attrgetter
|
||||
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, Type
|
||||
|
||||
import discord.abc
|
||||
|
||||
@ -207,7 +206,7 @@ class _ClientStatus:
|
||||
return self
|
||||
|
||||
|
||||
def flatten_user(cls):
|
||||
def flatten_user(cls: Any) -> Type[Member]:
|
||||
for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()):
|
||||
# ignore private/special methods
|
||||
if attr.startswith('_'):
|
||||
@ -333,7 +332,7 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
default_avatar: Asset
|
||||
avatar: Optional[Asset]
|
||||
dm_channel: Optional[DMChannel]
|
||||
create_dm = User.create_dm
|
||||
create_dm: Callable[[], Coroutine[Any, Any, DMChannel]]
|
||||
mutual_guilds: List[Guild]
|
||||
public_flags: PublicUserFlags
|
||||
banner: Optional[Asset]
|
||||
@ -369,10 +368,10 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>'
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, _UserTag) and other.id == self.id
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
@ -425,7 +424,7 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
self._user = member._user
|
||||
return self
|
||||
|
||||
async def _get_channel(self):
|
||||
async def _get_channel(self) -> DMChannel:
|
||||
ch = await self.create_dm()
|
||||
return ch
|
||||
|
||||
|
@ -92,10 +92,10 @@ class AllowedMentions:
|
||||
roles: Union[bool, List[Snowflake]] = default,
|
||||
replied_user: bool = default,
|
||||
):
|
||||
self.everyone = everyone
|
||||
self.users = users
|
||||
self.roles = roles
|
||||
self.replied_user = replied_user
|
||||
self.everyone: bool = everyone
|
||||
self.users: Union[bool, List[Snowflake]] = users
|
||||
self.roles: Union[bool, List[Snowflake]] = roles
|
||||
self.replied_user: bool = replied_user
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Self:
|
||||
|
@ -40,6 +40,7 @@ from typing import (
|
||||
Tuple,
|
||||
ClassVar,
|
||||
Optional,
|
||||
Type,
|
||||
overload,
|
||||
)
|
||||
|
||||
@ -71,7 +72,6 @@ if TYPE_CHECKING:
|
||||
MessageReference as MessageReferencePayload,
|
||||
MessageApplication as MessageApplicationPayload,
|
||||
MessageActivity as MessageActivityPayload,
|
||||
Reaction as ReactionPayload,
|
||||
)
|
||||
|
||||
from .types.components import Component as ComponentPayload
|
||||
@ -87,7 +87,7 @@ if TYPE_CHECKING:
|
||||
from .abc import GuildChannel, PartialMessageableChannel, MessageableChannel
|
||||
from .components import Component
|
||||
from .state import ConnectionState
|
||||
from .channel import TextChannel, GroupChannel, DMChannel
|
||||
from .channel import TextChannel
|
||||
from .mentions import AllowedMentions
|
||||
from .user import User
|
||||
from .role import Role
|
||||
@ -95,6 +95,7 @@ if TYPE_CHECKING:
|
||||
|
||||
EmojiInputType = Union[Emoji, PartialEmoji, str]
|
||||
|
||||
|
||||
__all__ = (
|
||||
'Attachment',
|
||||
'Message',
|
||||
@ -104,7 +105,7 @@ __all__ = (
|
||||
)
|
||||
|
||||
|
||||
def convert_emoji_reaction(emoji):
|
||||
def convert_emoji_reaction(emoji: Union[EmojiInputType, Reaction]) -> str:
|
||||
if isinstance(emoji, Reaction):
|
||||
emoji = emoji.emoji
|
||||
|
||||
@ -216,7 +217,7 @@ class Attachment(Hashable):
|
||||
|
||||
async def save(
|
||||
self,
|
||||
fp: Union[io.BufferedIOBase, PathLike],
|
||||
fp: Union[io.BufferedIOBase, PathLike[Any]],
|
||||
*,
|
||||
seek_begin: bool = True,
|
||||
use_cached: bool = False,
|
||||
@ -510,7 +511,7 @@ class MessageReference:
|
||||
to_message_reference_dict = to_dict
|
||||
|
||||
|
||||
def flatten_handlers(cls):
|
||||
def flatten_handlers(cls: Type[Message]) -> Type[Message]:
|
||||
prefix = len('_handle_')
|
||||
handlers = [
|
||||
(key[prefix:], value)
|
||||
@ -1036,7 +1037,7 @@ class Message(Hashable):
|
||||
)
|
||||
|
||||
@utils.cached_slot_property('_cs_system_content')
|
||||
def system_content(self):
|
||||
def system_content(self) -> Optional[str]:
|
||||
r""":class:`str`: A property that returns the content that is rendered
|
||||
regardless of the :attr:`Message.type`.
|
||||
|
||||
@ -1657,7 +1658,7 @@ class Message(Hashable):
|
||||
)
|
||||
return Thread(guild=self.guild, state=self._state, data=data)
|
||||
|
||||
async def reply(self, content: Optional[str] = None, **kwargs) -> Message:
|
||||
async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message:
|
||||
"""|coro|
|
||||
|
||||
A shortcut method to :meth:`.abc.Messageable.send` to reply to the
|
||||
@ -1798,7 +1799,7 @@ class PartialMessage(Hashable):
|
||||
|
||||
# Also needed for duck typing purposes
|
||||
# n.b. not exposed
|
||||
pinned = property(None, lambda x, y: None)
|
||||
pinned: Any = property(None, lambda x, y: None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<PartialMessage id={self.id} channel={self.channel!r}>'
|
||||
|
@ -363,7 +363,7 @@ class Encoder(_OpusStruct):
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0)
|
||||
|
||||
def set_expected_packet_loss_percent(self, percentage: float) -> None:
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100))))
|
||||
|
||||
def encode(self, pcm: bytes, frame_size: int) -> bytes:
|
||||
max_data_bytes = len(pcm)
|
||||
@ -373,8 +373,7 @@ class Encoder(_OpusStruct):
|
||||
|
||||
ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes)
|
||||
|
||||
# array can be initialized with bytes but mypy doesn't know
|
||||
return array.array('b', data[:ret]).tobytes() # type: ignore
|
||||
return array.array('b', data[:ret]).tobytes()
|
||||
|
||||
|
||||
class Decoder(_OpusStruct):
|
||||
|
@ -42,6 +42,7 @@ if TYPE_CHECKING:
|
||||
from .state import ConnectionState
|
||||
from datetime import datetime
|
||||
from .types.message import PartialEmoji as PartialEmojiPayload
|
||||
from .types.activity import ActivityEmoji
|
||||
|
||||
|
||||
class _EmojiTag:
|
||||
@ -99,13 +100,13 @@ class PartialEmoji(_EmojiTag, AssetMixin):
|
||||
id: Optional[int]
|
||||
|
||||
def __init__(self, *, name: str, animated: bool = False, id: Optional[int] = None):
|
||||
self.animated = animated
|
||||
self.name = name
|
||||
self.id = id
|
||||
self.animated: bool = animated
|
||||
self.name: str = name
|
||||
self.id: Optional[int] = id
|
||||
self._state: Optional[ConnectionState] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Union[PartialEmojiPayload, Dict[str, Any]]) -> Self:
|
||||
def from_dict(cls, data: Union[PartialEmojiPayload, ActivityEmoji, Dict[str, Any]]) -> Self:
|
||||
return cls(
|
||||
animated=data.get('animated', False),
|
||||
id=utils._get_as_snowflake(data, 'id'),
|
||||
@ -178,10 +179,10 @@ class PartialEmoji(_EmojiTag, AssetMixin):
|
||||
return f'<a:{self.name}:{self.id}>'
|
||||
return f'<:{self.name}:{self.id}>'
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>'
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if self.is_unicode_emoji():
|
||||
return isinstance(other, PartialEmoji) and self.name == other.name
|
||||
|
||||
@ -189,7 +190,7 @@ class PartialEmoji(_EmojiTag, AssetMixin):
|
||||
return self.id == other.id
|
||||
return False
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
|
@ -276,7 +276,7 @@ class Permissions(BaseFlags):
|
||||
# So 0000 OP2 0101 -> 0101
|
||||
# The OP is base & ~denied.
|
||||
# The OP2 is base | allowed.
|
||||
self.value = (self.value & ~deny) | allow
|
||||
self.value: int = (self.value & ~deny) | allow
|
||||
|
||||
@flag_value
|
||||
def create_instant_invite(self) -> int:
|
||||
@ -691,7 +691,7 @@ class PermissionOverwrite:
|
||||
|
||||
setattr(self, key, value)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, PermissionOverwrite) and self._values == other._values
|
||||
|
||||
def _set(self, key: str, value: Optional[bool]) -> None:
|
||||
|
@ -365,12 +365,11 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
bitrate: Optional[int] = None,
|
||||
codec: Optional[str] = None,
|
||||
executable: str = 'ffmpeg',
|
||||
pipe=False,
|
||||
stderr=None,
|
||||
before_options=None,
|
||||
options=None,
|
||||
pipe: bool = False,
|
||||
stderr: Optional[IO[bytes]] = None,
|
||||
before_options: Optional[str] = None,
|
||||
options: Optional[str] = None,
|
||||
) -> None:
|
||||
|
||||
args = []
|
||||
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
|
||||
|
||||
@ -635,7 +634,13 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
|
||||
class AudioPlayer(threading.Thread):
|
||||
DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0
|
||||
|
||||
def __init__(self, source: AudioSource, client: VoiceClient, *, after=None):
|
||||
def __init__(
|
||||
self,
|
||||
source: AudioSource,
|
||||
client: VoiceClient,
|
||||
*,
|
||||
after: Optional[Callable[[Optional[Exception]], Any]] = None,
|
||||
) -> None:
|
||||
threading.Thread.__init__(self)
|
||||
self.daemon: bool = True
|
||||
self.source: AudioSource = source
|
||||
@ -724,8 +729,8 @@ class AudioPlayer(threading.Thread):
|
||||
self._speak(SpeakingState.none)
|
||||
|
||||
def resume(self, *, update_speaking: bool = True) -> None:
|
||||
self.loops = 0
|
||||
self._start = time.perf_counter()
|
||||
self.loops: int = 0
|
||||
self._start: float = time.perf_counter()
|
||||
self._resumed.set()
|
||||
if update_speaking:
|
||||
self._speak(SpeakingState.voice)
|
||||
|
@ -94,10 +94,10 @@ class Reaction:
|
||||
""":class:`bool`: If this is a custom emoji."""
|
||||
return not isinstance(self.emoji, str)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__) and other.emoji == self.emoji
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
def __ne__(self, other: object) -> bool:
|
||||
if isinstance(other, self.__class__):
|
||||
return other.emoji != self.emoji
|
||||
return True
|
||||
|
@ -211,7 +211,7 @@ class Role(Hashable):
|
||||
def __repr__(self) -> str:
|
||||
return f'<Role id={self.id} name={self.name!r}>'
|
||||
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if not isinstance(other, Role) or not isinstance(self, Role):
|
||||
return NotImplemented
|
||||
|
||||
@ -241,7 +241,7 @@ class Role(Hashable):
|
||||
def __gt__(self, other: Any) -> bool:
|
||||
return Role.__lt__(other, self)
|
||||
|
||||
def __ge__(self, other: Any) -> bool:
|
||||
def __ge__(self, other: object) -> bool:
|
||||
r = Role.__lt__(self, other)
|
||||
if r is NotImplemented:
|
||||
return NotImplemented
|
||||
|
@ -132,7 +132,7 @@ class ScheduledEvent(Hashable):
|
||||
self.guild_id: int = int(data['guild_id'])
|
||||
self.name: str = data['name']
|
||||
self.description: Optional[str] = data.get('description')
|
||||
self.entity_type = try_enum(EntityType, data['entity_type'])
|
||||
self.entity_type: EntityType = try_enum(EntityType, data['entity_type'])
|
||||
self.entity_id: Optional[int] = _get_as_snowflake(data, 'entity_id')
|
||||
self.start_time: datetime = parse_time(data['scheduled_start_time'])
|
||||
self.privacy_level: PrivacyLevel = try_enum(PrivacyLevel, data['status'])
|
||||
@ -153,7 +153,7 @@ class ScheduledEvent(Hashable):
|
||||
self.location: Optional[str] = data.get('location') if data else None
|
||||
|
||||
@classmethod
|
||||
def from_creation(cls, *, state: ConnectionState, data: GuildScheduledEventPayload):
|
||||
def from_creation(cls, *, state: ConnectionState, data: GuildScheduledEventPayload) -> None:
|
||||
creator_id = data.get('creator_id')
|
||||
self = cls(state=state, data=data)
|
||||
if creator_id:
|
||||
@ -180,7 +180,7 @@ class ScheduledEvent(Hashable):
|
||||
return self.guild.get_channel(self.channel_id) # type: ignore
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
def url(self) -> str:
|
||||
""":class:`str`: The url for the scheduled event."""
|
||||
return f'https://discord.com/events/{self.guild_id}/{self.id}'
|
||||
|
||||
|
@ -75,12 +75,12 @@ class EventItem:
|
||||
self.shard: Optional['Shard'] = shard
|
||||
self.error: Optional[Exception] = error
|
||||
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if not isinstance(other, EventItem):
|
||||
return NotImplemented
|
||||
return self.type < other.type
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, EventItem):
|
||||
return NotImplemented
|
||||
return self.type == other.type
|
||||
@ -409,6 +409,7 @@ class AutoShardedClient(Client):
|
||||
|
||||
async def launch_shards(self) -> None:
|
||||
if self.shard_count is None:
|
||||
self.shard_count: int
|
||||
self.shard_count, gateway = await self.http.get_bot_gateway()
|
||||
else:
|
||||
gateway = await self.http.get_gateway()
|
||||
|
@ -97,11 +97,11 @@ class StageInstance(Hashable):
|
||||
)
|
||||
|
||||
def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None:
|
||||
self._state = state
|
||||
self.guild = guild
|
||||
self._state: ConnectionState = state
|
||||
self.guild: Guild = guild
|
||||
self._update(data)
|
||||
|
||||
def _update(self, data: StageInstancePayload):
|
||||
def _update(self, data: StageInstancePayload) -> None:
|
||||
self.id: int = int(data['id'])
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
self.topic: str = data['topic']
|
||||
|
@ -43,6 +43,8 @@ from typing import (
|
||||
Sequence,
|
||||
Tuple,
|
||||
Deque,
|
||||
Literal,
|
||||
overload,
|
||||
)
|
||||
import weakref
|
||||
import inspect
|
||||
@ -88,7 +90,7 @@ if TYPE_CHECKING:
|
||||
from .types.activity import Activity as ActivityPayload
|
||||
from .types.channel import DMChannel as DMChannelPayload
|
||||
from .types.user import User as UserPayload, PartialUser as PartialUserPayload
|
||||
from .types.emoji import Emoji as EmojiPayload
|
||||
from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
|
||||
from .types.sticker import GuildSticker as GuildStickerPayload
|
||||
from .types.guild import Guild as GuildPayload
|
||||
from .types.message import Message as MessagePayload, PartialMessage as PartialMessagePayload
|
||||
@ -165,9 +167,9 @@ class ConnectionState:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dispatch: Callable,
|
||||
handlers: Dict[str, Callable],
|
||||
hooks: Dict[str, Callable],
|
||||
dispatch: Callable[..., Any],
|
||||
handlers: Dict[str, Callable[..., Any]],
|
||||
hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]],
|
||||
http: HTTPClient,
|
||||
**options: Any,
|
||||
) -> None:
|
||||
@ -178,9 +180,9 @@ class ConnectionState:
|
||||
if self.max_messages is not None and self.max_messages <= 0:
|
||||
self.max_messages = 1000
|
||||
|
||||
self.dispatch: Callable = dispatch
|
||||
self.handlers: Dict[str, Callable] = handlers
|
||||
self.hooks: Dict[str, Callable] = hooks
|
||||
self.dispatch: Callable[..., Any] = dispatch
|
||||
self.handlers: Dict[str, Callable[..., Any]] = handlers
|
||||
self.hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = hooks
|
||||
self.shard_count: Optional[int] = None
|
||||
self._ready_task: Optional[asyncio.Task] = None
|
||||
self.application_id: Optional[int] = utils._get_as_snowflake(options, 'application_id')
|
||||
@ -245,6 +247,7 @@ class ConnectionState:
|
||||
if not intents.members or cache_flags._empty:
|
||||
self.store_user = self.store_user_no_intents # type: ignore - This reassignment is on purpose
|
||||
|
||||
self.parsers: Dict[str, Callable[[Any], None]]
|
||||
self.parsers = parsers = {}
|
||||
for attr, func in inspect.getmembers(self):
|
||||
if attr.startswith('parse_'):
|
||||
@ -343,13 +346,13 @@ class ConnectionState:
|
||||
self._users[user_id] = user
|
||||
return user
|
||||
|
||||
def store_user_no_intents(self, data):
|
||||
def store_user_no_intents(self, data: Union[UserPayload, PartialUserPayload]) -> User:
|
||||
return User(state=self, data=data)
|
||||
|
||||
def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> User:
|
||||
return User(state=self, data=data)
|
||||
|
||||
def get_user(self, id):
|
||||
def get_user(self, id: int) -> Optional[User]:
|
||||
return self._users.get(id)
|
||||
|
||||
def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji:
|
||||
@ -571,8 +574,7 @@ class ConnectionState:
|
||||
pass
|
||||
else:
|
||||
self.application_id = utils._get_as_snowflake(application, 'id')
|
||||
# flags will always be present here
|
||||
self.application_flags = ApplicationFlags._from_value(application['flags'])
|
||||
self.application_flags: ApplicationFlags = ApplicationFlags._from_value(application['flags'])
|
||||
|
||||
for guild_data in data['guilds']:
|
||||
self._add_guild_from_data(guild_data) # type: ignore
|
||||
@ -743,7 +745,7 @@ class ConnectionState:
|
||||
|
||||
self.dispatch('presence_update', old_member, member)
|
||||
|
||||
def parse_user_update(self, data: gw.UserUpdateEvent):
|
||||
def parse_user_update(self, data: gw.UserUpdateEvent) -> None:
|
||||
if self.user:
|
||||
self.user._update(data)
|
||||
|
||||
@ -1050,7 +1052,7 @@ class ConnectionState:
|
||||
guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers']))
|
||||
self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers)
|
||||
|
||||
def _get_create_guild(self, data):
|
||||
def _get_create_guild(self, data: gw.GuildCreateEvent) -> Guild:
|
||||
if data.get('unavailable') is False:
|
||||
# GUILD_CREATE with unavailable in the response
|
||||
# usually means that the guild has become available
|
||||
@ -1063,10 +1065,22 @@ class ConnectionState:
|
||||
|
||||
return self._add_guild_from_data(data)
|
||||
|
||||
def is_guild_evicted(self, guild) -> bool:
|
||||
def is_guild_evicted(self, guild: Guild) -> bool:
|
||||
return guild.id not in self._guilds
|
||||
|
||||
async def chunk_guild(self, guild, *, wait=True, cache=None):
|
||||
@overload
|
||||
async def chunk_guild(self, guild: Guild, *, wait: Literal[True] = ..., cache: Optional[bool] = ...) -> List[Member]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def chunk_guild(
|
||||
self, guild: Guild, *, wait: Literal[False] = ..., cache: Optional[bool] = ...
|
||||
) -> asyncio.Future[List[Member]]:
|
||||
...
|
||||
|
||||
async def chunk_guild(
|
||||
self, guild: Guild, *, wait: bool = True, cache: Optional[bool] = None
|
||||
) -> Union[List[Member], asyncio.Future[List[Member]]]:
|
||||
cache = cache or self.member_cache_flags.joined
|
||||
request = self._chunk_requests.get(guild.id)
|
||||
if request is None:
|
||||
@ -1445,16 +1459,19 @@ class ConnectionState:
|
||||
return channel.guild.get_member(user_id)
|
||||
return self.get_user(user_id)
|
||||
|
||||
def get_reaction_emoji(self, data) -> Union[Emoji, PartialEmoji]:
|
||||
def get_reaction_emoji(self, data: PartialEmojiPayload) -> Union[Emoji, PartialEmoji, str]:
|
||||
emoji_id = utils._get_as_snowflake(data, 'id')
|
||||
|
||||
if not emoji_id:
|
||||
return data['name']
|
||||
# the name key will be a str
|
||||
return data['name'] # type: ignore
|
||||
|
||||
try:
|
||||
return self._emojis[emoji_id]
|
||||
except KeyError:
|
||||
return PartialEmoji.with_state(self, animated=data.get('animated', False), id=emoji_id, name=data['name'])
|
||||
return PartialEmoji.with_state(
|
||||
self, animated=data.get('animated', False), id=emoji_id, name=data['name'] # type: ignore
|
||||
)
|
||||
|
||||
def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmoji, str]:
|
||||
emoji_id = emoji.id
|
||||
@ -1589,6 +1606,7 @@ class AutoShardedConnectionState(ConnectionState):
|
||||
if not hasattr(self, '_ready_state'):
|
||||
self._ready_state = asyncio.Queue()
|
||||
|
||||
self.user: Optional[ClientUser]
|
||||
self.user = user = ClientUser(state=self, data=data['user'])
|
||||
# self._users is a list of Users, we're setting a ClientUser
|
||||
self._users[user.id] = user # type: ignore
|
||||
@ -1599,8 +1617,8 @@ class AutoShardedConnectionState(ConnectionState):
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
self.application_id = utils._get_as_snowflake(application, 'id')
|
||||
self.application_flags = ApplicationFlags._from_value(application['flags'])
|
||||
self.application_id: Optional[int] = utils._get_as_snowflake(application, 'id')
|
||||
self.application_flags: ApplicationFlags = ApplicationFlags._from_value(application['flags'])
|
||||
|
||||
for guild_data in data['guilds']:
|
||||
self._add_guild_from_data(guild_data) # type: ignore - _add_guild_from_data requires a complete Guild payload
|
||||
|
@ -228,7 +228,7 @@ class StickerItem(_StickerTag):
|
||||
The retrieved sticker.
|
||||
"""
|
||||
data: StickerPayload = await self._state.http.get_sticker(self.id)
|
||||
cls, _ = _sticker_factory(data['type']) # type: ignore
|
||||
cls, _ = _sticker_factory(data['type'])
|
||||
return cls(state=self._state, data=data)
|
||||
|
||||
|
||||
|
@ -41,6 +41,8 @@ __all__ = (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
|
||||
from .types.threads import (
|
||||
Thread as ThreadPayload,
|
||||
ThreadMember as ThreadMemberPayload,
|
||||
@ -147,13 +149,13 @@ class Thread(Messageable, Hashable):
|
||||
'_created_at',
|
||||
)
|
||||
|
||||
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
|
||||
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload) -> None:
|
||||
self._state: ConnectionState = state
|
||||
self.guild = guild
|
||||
self.guild: Guild = guild
|
||||
self._members: Dict[int, ThreadMember] = {}
|
||||
self._from_data(data)
|
||||
|
||||
async def _get_channel(self):
|
||||
async def _get_channel(self) -> Self:
|
||||
return self
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@ -166,17 +168,18 @@ class Thread(Messageable, Hashable):
|
||||
return self.name
|
||||
|
||||
def _from_data(self, data: ThreadPayload):
|
||||
self.id = int(data['id'])
|
||||
self.parent_id = int(data['parent_id'])
|
||||
self.owner_id = int(data['owner_id'])
|
||||
self.name = data['name']
|
||||
self._type = try_enum(ChannelType, data['type'])
|
||||
self.last_message_id = _get_as_snowflake(data, 'last_message_id')
|
||||
self.slowmode_delay = data.get('rate_limit_per_user', 0)
|
||||
self.message_count = data['message_count']
|
||||
self.member_count = data['member_count']
|
||||
self.id: int = int(data['id'])
|
||||
self.parent_id: int = int(data['parent_id'])
|
||||
self.owner_id: int = int(data['owner_id'])
|
||||
self.name: str = data['name']
|
||||
self._type: ChannelType = try_enum(ChannelType, data['type'])
|
||||
self.last_message_id: Optional[int] = _get_as_snowflake(data, 'last_message_id')
|
||||
self.slowmode_delay: int = data.get('rate_limit_per_user', 0)
|
||||
self.message_count: int = data['message_count']
|
||||
self.member_count: int = data['member_count']
|
||||
self._unroll_metadata(data['thread_metadata'])
|
||||
|
||||
self.me: Optional[ThreadMember]
|
||||
try:
|
||||
member = data['member']
|
||||
except KeyError:
|
||||
@ -185,15 +188,15 @@ class Thread(Messageable, Hashable):
|
||||
self.me = ThreadMember(self, member)
|
||||
|
||||
def _unroll_metadata(self, data: ThreadMetadata):
|
||||
self.archived = data['archived']
|
||||
self.archiver_id = _get_as_snowflake(data, 'archiver_id')
|
||||
self.auto_archive_duration = data['auto_archive_duration']
|
||||
self.archive_timestamp = parse_time(data['archive_timestamp'])
|
||||
self.locked = data.get('locked', False)
|
||||
self.invitable = data.get('invitable', True)
|
||||
self._created_at = parse_time(data.get('create_timestamp'))
|
||||
self.archived: bool = data['archived']
|
||||
self.archiver_id: Optional[int] = _get_as_snowflake(data, 'archiver_id')
|
||||
self.auto_archive_duration: int = data['auto_archive_duration']
|
||||
self.archive_timestamp: datetime = parse_time(data['archive_timestamp'])
|
||||
self.locked: bool = data.get('locked', False)
|
||||
self.invitable: bool = data.get('invitable', True)
|
||||
self._created_at: Optional[datetime] = parse_time(data.get('create_timestamp'))
|
||||
|
||||
def _update(self, data):
|
||||
def _update(self, data: ThreadPayload) -> None:
|
||||
try:
|
||||
self.name = data['name']
|
||||
except KeyError:
|
||||
@ -602,7 +605,7 @@ class Thread(Messageable, Hashable):
|
||||
# The data payload will always be a Thread payload
|
||||
return Thread(data=data, state=self._state, guild=self.guild) # type: ignore
|
||||
|
||||
async def join(self):
|
||||
async def join(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Joins this thread.
|
||||
@ -619,7 +622,7 @@ class Thread(Messageable, Hashable):
|
||||
"""
|
||||
await self._state.http.join_thread(self.id)
|
||||
|
||||
async def leave(self):
|
||||
async def leave(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Leaves this thread.
|
||||
@ -631,7 +634,7 @@ class Thread(Messageable, Hashable):
|
||||
"""
|
||||
await self._state.http.leave_thread(self.id)
|
||||
|
||||
async def add_user(self, user: Snowflake, /):
|
||||
async def add_user(self, user: Snowflake, /) -> None:
|
||||
"""|coro|
|
||||
|
||||
Adds a user to this thread.
|
||||
@ -654,7 +657,7 @@ class Thread(Messageable, Hashable):
|
||||
"""
|
||||
await self._state.http.add_user_to_thread(self.id, user.id)
|
||||
|
||||
async def remove_user(self, user: Snowflake, /):
|
||||
async def remove_user(self, user: Snowflake, /) -> None:
|
||||
"""|coro|
|
||||
|
||||
Removes a user from this thread.
|
||||
@ -718,7 +721,7 @@ class Thread(Messageable, Hashable):
|
||||
members = await self._state.http.get_thread_members(self.id)
|
||||
return [ThreadMember(parent=self, data=data) for data in members]
|
||||
|
||||
async def delete(self):
|
||||
async def delete(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Deletes this thread.
|
||||
@ -806,28 +809,28 @@ class ThreadMember(Hashable):
|
||||
'parent',
|
||||
)
|
||||
|
||||
def __init__(self, parent: Thread, data: ThreadMemberPayload):
|
||||
self.parent = parent
|
||||
self._state = parent._state
|
||||
def __init__(self, parent: Thread, data: ThreadMemberPayload) -> None:
|
||||
self.parent: Thread = parent
|
||||
self._state: ConnectionState = parent._state
|
||||
self._from_data(data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>'
|
||||
|
||||
def _from_data(self, data: ThreadMemberPayload):
|
||||
def _from_data(self, data: ThreadMemberPayload) -> None:
|
||||
try:
|
||||
self.id = int(data['user_id'])
|
||||
except KeyError:
|
||||
assert self._state.self_id is not None
|
||||
self.id = self._state.self_id
|
||||
self.id = self._state.self_id # type: ignore
|
||||
|
||||
self.thread_id: int
|
||||
try:
|
||||
self.thread_id = int(data['id'])
|
||||
except KeyError:
|
||||
self.thread_id = self.parent.id
|
||||
|
||||
self.joined_at = parse_time(data['join_timestamp'])
|
||||
self.flags = data['flags']
|
||||
self.joined_at: datetime = parse_time(data['join_timestamp'])
|
||||
self.flags: int = data['flags']
|
||||
|
||||
@property
|
||||
def thread(self) -> Thread:
|
||||
|
@ -112,3 +112,4 @@ class Activity(_BaseActivity, total=False):
|
||||
session_id: Optional[str]
|
||||
instance: bool
|
||||
buttons: List[ActivityButton]
|
||||
sync_id: str
|
||||
|
@ -58,3 +58,8 @@ class Widget(TypedDict):
|
||||
class WidgetSettings(TypedDict):
|
||||
enabled: bool
|
||||
channel_id: Optional[Snowflake]
|
||||
|
||||
|
||||
class EditWidgetSettings(TypedDict, total=False):
|
||||
enabled: bool
|
||||
channel_id: Optional[Snowflake]
|
||||
|
@ -44,6 +44,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from .view import View
|
||||
from ..emoji import Emoji
|
||||
from ..types.components import ButtonComponent as ButtonComponentPayload
|
||||
|
||||
V = TypeVar('V', bound='View', covariant=True)
|
||||
|
||||
@ -124,7 +125,7 @@ class Button(Item[V]):
|
||||
style=style,
|
||||
emoji=emoji,
|
||||
)
|
||||
self.row = row
|
||||
self.row: Optional[int] = row
|
||||
|
||||
@property
|
||||
def style(self) -> ButtonStyle:
|
||||
@ -132,7 +133,7 @@ class Button(Item[V]):
|
||||
return self._underlying.style
|
||||
|
||||
@style.setter
|
||||
def style(self, value: ButtonStyle):
|
||||
def style(self, value: ButtonStyle) -> None:
|
||||
self._underlying.style = value
|
||||
|
||||
@property
|
||||
@ -144,7 +145,7 @@ class Button(Item[V]):
|
||||
return self._underlying.custom_id
|
||||
|
||||
@custom_id.setter
|
||||
def custom_id(self, value: Optional[str]):
|
||||
def custom_id(self, value: Optional[str]) -> None:
|
||||
if value is not None and not isinstance(value, str):
|
||||
raise TypeError('custom_id must be None or str')
|
||||
|
||||
@ -156,7 +157,7 @@ class Button(Item[V]):
|
||||
return self._underlying.url
|
||||
|
||||
@url.setter
|
||||
def url(self, value: Optional[str]):
|
||||
def url(self, value: Optional[str]) -> None:
|
||||
if value is not None and not isinstance(value, str):
|
||||
raise TypeError('url must be None or str')
|
||||
self._underlying.url = value
|
||||
@ -167,7 +168,7 @@ class Button(Item[V]):
|
||||
return self._underlying.disabled
|
||||
|
||||
@disabled.setter
|
||||
def disabled(self, value: bool):
|
||||
def disabled(self, value: bool) -> None:
|
||||
self._underlying.disabled = bool(value)
|
||||
|
||||
@property
|
||||
@ -176,7 +177,7 @@ class Button(Item[V]):
|
||||
return self._underlying.label
|
||||
|
||||
@label.setter
|
||||
def label(self, value: Optional[str]):
|
||||
def label(self, value: Optional[str]) -> None:
|
||||
self._underlying.label = str(value) if value is not None else value
|
||||
|
||||
@property
|
||||
@ -185,7 +186,7 @@ class Button(Item[V]):
|
||||
return self._underlying.emoji
|
||||
|
||||
@emoji.setter
|
||||
def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore
|
||||
def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]) -> None:
|
||||
if value is not None:
|
||||
if isinstance(value, str):
|
||||
self._underlying.emoji = PartialEmoji.from_str(value)
|
||||
@ -212,7 +213,7 @@ class Button(Item[V]):
|
||||
def type(self) -> ComponentType:
|
||||
return self._underlying.type
|
||||
|
||||
def to_component_dict(self):
|
||||
def to_component_dict(self) -> ButtonComponentPayload:
|
||||
return self._underlying.to_dict()
|
||||
|
||||
def is_dispatchable(self) -> bool:
|
||||
|
@ -101,7 +101,7 @@ class Item(Generic[V]):
|
||||
return self._row
|
||||
|
||||
@row.setter
|
||||
def row(self, value: Optional[int]):
|
||||
def row(self, value: Optional[int]) -> None:
|
||||
if value is None:
|
||||
self._row = None
|
||||
elif 5 > value >= 0:
|
||||
@ -118,7 +118,7 @@ class Item(Generic[V]):
|
||||
"""Optional[:class:`View`]: The underlying view for this item."""
|
||||
return self._view
|
||||
|
||||
async def callback(self, interaction: Interaction):
|
||||
async def callback(self, interaction: Interaction) -> Any:
|
||||
"""|coro|
|
||||
|
||||
The callback associated with this UI item.
|
||||
|
@ -38,6 +38,8 @@ from .item import Item
|
||||
from .view import View
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..interactions import Interaction
|
||||
from ..types.interactions import ModalSubmitComponentInteractionData as ModalSubmitComponentInteractionDataPayload
|
||||
|
||||
@ -101,7 +103,7 @@ class Modal(View):
|
||||
title: str
|
||||
|
||||
__discord_ui_modal__ = True
|
||||
__modal_children_items__: ClassVar[Dict[str, Item]] = {}
|
||||
__modal_children_items__: ClassVar[Dict[str, Item[Self]]] = {}
|
||||
|
||||
def __init_subclass__(cls, *, title: str = MISSING) -> None:
|
||||
if title is not MISSING:
|
||||
@ -139,7 +141,7 @@ class Modal(View):
|
||||
|
||||
super().__init__(timeout=timeout)
|
||||
|
||||
async def on_submit(self, interaction: Interaction):
|
||||
async def on_submit(self, interaction: Interaction) -> None:
|
||||
"""|coro|
|
||||
|
||||
Called when the modal is submitted.
|
||||
@ -169,7 +171,7 @@ class Modal(View):
|
||||
print(f'Ignoring exception in modal {self}:', file=sys.stderr)
|
||||
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
|
||||
|
||||
def refresh(self, components: Sequence[ModalSubmitComponentInteractionDataPayload]):
|
||||
def refresh(self, components: Sequence[ModalSubmitComponentInteractionDataPayload]) -> None:
|
||||
for component in components:
|
||||
if component['type'] == 1:
|
||||
self.refresh(component['components'])
|
||||
|
@ -121,7 +121,7 @@ class Select(Item[V]):
|
||||
options=options,
|
||||
disabled=disabled,
|
||||
)
|
||||
self.row = row
|
||||
self.row: Optional[int] = row
|
||||
|
||||
@property
|
||||
def custom_id(self) -> str:
|
||||
@ -129,7 +129,7 @@ class Select(Item[V]):
|
||||
return self._underlying.custom_id
|
||||
|
||||
@custom_id.setter
|
||||
def custom_id(self, value: str):
|
||||
def custom_id(self, value: str) -> None:
|
||||
if not isinstance(value, str):
|
||||
raise TypeError('custom_id must be None or str')
|
||||
|
||||
@ -141,7 +141,7 @@ class Select(Item[V]):
|
||||
return self._underlying.placeholder
|
||||
|
||||
@placeholder.setter
|
||||
def placeholder(self, value: Optional[str]):
|
||||
def placeholder(self, value: Optional[str]) -> None:
|
||||
if value is not None and not isinstance(value, str):
|
||||
raise TypeError('placeholder must be None or str')
|
||||
|
||||
@ -153,7 +153,7 @@ class Select(Item[V]):
|
||||
return self._underlying.min_values
|
||||
|
||||
@min_values.setter
|
||||
def min_values(self, value: int):
|
||||
def min_values(self, value: int) -> None:
|
||||
self._underlying.min_values = int(value)
|
||||
|
||||
@property
|
||||
@ -162,7 +162,7 @@ class Select(Item[V]):
|
||||
return self._underlying.max_values
|
||||
|
||||
@max_values.setter
|
||||
def max_values(self, value: int):
|
||||
def max_values(self, value: int) -> None:
|
||||
self._underlying.max_values = int(value)
|
||||
|
||||
@property
|
||||
@ -171,7 +171,7 @@ class Select(Item[V]):
|
||||
return self._underlying.options
|
||||
|
||||
@options.setter
|
||||
def options(self, value: List[SelectOption]):
|
||||
def options(self, value: List[SelectOption]) -> None:
|
||||
if not isinstance(value, list):
|
||||
raise TypeError('options must be a list of SelectOption')
|
||||
if not all(isinstance(obj, SelectOption) for obj in value):
|
||||
@ -187,7 +187,7 @@ class Select(Item[V]):
|
||||
description: Optional[str] = None,
|
||||
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
|
||||
default: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""Adds an option to the select menu.
|
||||
|
||||
To append a pre-existing :class:`discord.SelectOption` use the
|
||||
@ -226,7 +226,7 @@ class Select(Item[V]):
|
||||
|
||||
self.append_option(option)
|
||||
|
||||
def append_option(self, option: SelectOption):
|
||||
def append_option(self, option: SelectOption) -> None:
|
||||
"""Appends an option to the select menu.
|
||||
|
||||
Parameters
|
||||
@ -251,7 +251,7 @@ class Select(Item[V]):
|
||||
return self._underlying.disabled
|
||||
|
||||
@disabled.setter
|
||||
def disabled(self, value: bool):
|
||||
def disabled(self, value: bool) -> None:
|
||||
self._underlying.disabled = bool(value)
|
||||
|
||||
@property
|
||||
|
@ -50,6 +50,8 @@ __all__ = (
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..interactions import Interaction
|
||||
from ..message import Message
|
||||
from ..types.components import Component as ComponentPayload
|
||||
@ -163,7 +165,7 @@ class View:
|
||||
|
||||
cls.__view_children_items__ = children
|
||||
|
||||
def _init_children(self) -> List[Item]:
|
||||
def _init_children(self) -> List[Item[Self]]:
|
||||
children = []
|
||||
for func in self.__view_children_items__:
|
||||
item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__)
|
||||
@ -175,7 +177,7 @@ class View:
|
||||
|
||||
def __init__(self, *, timeout: Optional[float] = 180.0):
|
||||
self.timeout = timeout
|
||||
self.children: List[Item] = self._init_children()
|
||||
self.children: List[Item[Self]] = self._init_children()
|
||||
self.__weights = _ViewWeights(self.children)
|
||||
self.id: str = os.urandom(16).hex()
|
||||
self.__cancel_callback: Optional[Callable[[View], None]] = None
|
||||
@ -250,7 +252,7 @@ class View:
|
||||
view.add_item(_component_to_item(component))
|
||||
return view
|
||||
|
||||
def add_item(self, item: Item) -> None:
|
||||
def add_item(self, item: Item[Any]) -> None:
|
||||
"""Adds an item to the view.
|
||||
|
||||
Parameters
|
||||
@ -278,7 +280,7 @@ class View:
|
||||
item._view = self
|
||||
self.children.append(item)
|
||||
|
||||
def remove_item(self, item: Item) -> None:
|
||||
def remove_item(self, item: Item[Any]) -> None:
|
||||
"""Removes an item from the view.
|
||||
|
||||
Parameters
|
||||
@ -334,7 +336,7 @@ class View:
|
||||
"""
|
||||
pass
|
||||
|
||||
async def on_error(self, error: Exception, item: Item, interaction: Interaction) -> None:
|
||||
async def on_error(self, error: Exception, item: Item[Any], interaction: Interaction) -> None:
|
||||
"""|coro|
|
||||
|
||||
A callback that is called when an item's callback or :meth:`interaction_check`
|
||||
@ -395,16 +397,16 @@ class View:
|
||||
|
||||
asyncio.create_task(self._scheduled_task(item, interaction), name=f'discord-ui-view-dispatch-{self.id}')
|
||||
|
||||
def refresh(self, components: List[Component]):
|
||||
def refresh(self, components: List[Component]) -> None:
|
||||
# This is pretty hacky at the moment
|
||||
# fmt: off
|
||||
old_state: Dict[Tuple[int, str], Item] = {
|
||||
old_state: Dict[Tuple[int, str], Item[Any]] = {
|
||||
(item.type.value, item.custom_id): item # type: ignore
|
||||
for item in self.children
|
||||
if item.is_dispatchable()
|
||||
}
|
||||
# fmt: on
|
||||
children: List[Item] = []
|
||||
children: List[Item[Any]] = []
|
||||
for component in _walk_all_components(components):
|
||||
try:
|
||||
older = old_state[(component.type.value, component.custom_id)] # type: ignore
|
||||
@ -494,7 +496,7 @@ class ViewStore:
|
||||
for k in to_remove:
|
||||
del self._views[k]
|
||||
|
||||
def add_view(self, view: View, message_id: Optional[int] = None):
|
||||
def add_view(self, view: View, message_id: Optional[int] = None) -> None:
|
||||
view._start_listening_from_store(self)
|
||||
if view.__discord_ui_modal__:
|
||||
self._modals[view.custom_id] = view # type: ignore
|
||||
@ -509,7 +511,7 @@ class ViewStore:
|
||||
if message_id is not None:
|
||||
self._synced_message_views[message_id] = view
|
||||
|
||||
def remove_view(self, view: View):
|
||||
def remove_view(self, view: View) -> None:
|
||||
if view.__discord_ui_modal__:
|
||||
self._modals.pop(view.custom_id, None) # type: ignore
|
||||
return
|
||||
@ -523,7 +525,7 @@ class ViewStore:
|
||||
del self._synced_message_views[key]
|
||||
break
|
||||
|
||||
def dispatch_view(self, component_type: int, custom_id: str, interaction: Interaction):
|
||||
def dispatch_view(self, component_type: int, custom_id: str, interaction: Interaction) -> None:
|
||||
self.__verify_integrity()
|
||||
message_id: Optional[int] = interaction.message and interaction.message.id
|
||||
key = (component_type, message_id, custom_id)
|
||||
@ -542,7 +544,7 @@ class ViewStore:
|
||||
custom_id: str,
|
||||
interaction: Interaction,
|
||||
components: List[ModalSubmitComponentInteractionDataPayload],
|
||||
):
|
||||
) -> None:
|
||||
modal = self._modals.get(custom_id)
|
||||
if modal is None:
|
||||
_log.debug("Modal interaction referencing unknown custom_id %s. Discarding", custom_id)
|
||||
@ -551,13 +553,13 @@ class ViewStore:
|
||||
modal.refresh(components)
|
||||
modal._dispatch_submit(interaction)
|
||||
|
||||
def is_message_tracked(self, message_id: int):
|
||||
def is_message_tracked(self, message_id: int) -> bool:
|
||||
return message_id in self._synced_message_views
|
||||
|
||||
def remove_message_tracking(self, message_id: int) -> Optional[View]:
|
||||
return self._synced_message_views.pop(message_id, None)
|
||||
|
||||
def update_from_message(self, message_id: int, components: List[ComponentPayload]):
|
||||
def update_from_message(self, message_id: int, components: List[ComponentPayload]) -> None:
|
||||
# pre-req: is_message_tracked == true
|
||||
view = self._synced_message_views[message_id]
|
||||
view.refresh([_component_factory(d) for d in components])
|
||||
|
@ -99,10 +99,10 @@ class BaseUser(_UserTag):
|
||||
def __str__(self) -> str:
|
||||
return f'{self.name}#{self.discriminator}'
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, _UserTag) and other.id == self.id
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
@ -444,7 +444,7 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
def __repr__(self) -> str:
|
||||
return f'<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>'
|
||||
|
||||
async def _get_channel(self):
|
||||
async def _get_channel(self) -> DMChannel:
|
||||
ch = await self.create_dm()
|
||||
return ch
|
||||
|
||||
|
@ -29,6 +29,7 @@ from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
@ -42,6 +43,7 @@ from typing import (
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
Set,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
@ -66,7 +68,7 @@ import warnings
|
||||
import yarl
|
||||
|
||||
try:
|
||||
import orjson
|
||||
import orjson # type: ignore
|
||||
except ModuleNotFoundError:
|
||||
HAS_ORJSON = False
|
||||
else:
|
||||
@ -123,7 +125,7 @@ class _cached_property:
|
||||
if TYPE_CHECKING:
|
||||
from functools import cached_property as cached_property
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
from typing_extensions import ParamSpec, Self
|
||||
|
||||
from .permissions import Permissions
|
||||
from .abc import Snowflake
|
||||
@ -135,8 +137,16 @@ if TYPE_CHECKING:
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
MaybeCoroFunc = Union[
|
||||
Callable[P, Coroutine[Any, Any, 'T']],
|
||||
Callable[P, 'T'],
|
||||
]
|
||||
|
||||
_SnowflakeListBase = array.array[int]
|
||||
|
||||
else:
|
||||
cached_property = _cached_property
|
||||
_SnowflakeListBase = array.array
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
@ -178,7 +188,7 @@ class classproperty(Generic[T_co]):
|
||||
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
|
||||
return self.fget(owner)
|
||||
|
||||
def __set__(self, instance, value) -> None:
|
||||
def __set__(self, instance: Optional[Any], value: Any) -> None:
|
||||
raise AttributeError('cannot set attribute')
|
||||
|
||||
|
||||
@ -210,7 +220,7 @@ class SequenceProxy(Sequence[T_co]):
|
||||
def __reversed__(self) -> Iterator[T_co]:
|
||||
return reversed(self.__proxied)
|
||||
|
||||
def index(self, value: Any, *args, **kwargs) -> int:
|
||||
def index(self, value: Any, *args: Any, **kwargs: Any) -> int:
|
||||
return self.__proxied.index(value, *args, **kwargs)
|
||||
|
||||
def count(self, value: Any) -> int:
|
||||
@ -578,7 +588,7 @@ def _is_submodule(parent: str, child: str) -> bool:
|
||||
|
||||
if HAS_ORJSON:
|
||||
|
||||
def _to_json(obj: Any) -> str: # type: ignore
|
||||
def _to_json(obj: Any) -> str:
|
||||
return orjson.dumps(obj).decode('utf-8')
|
||||
|
||||
_from_json = orjson.loads # type: ignore
|
||||
@ -602,15 +612,15 @@ def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
|
||||
return float(reset_after)
|
||||
|
||||
|
||||
async def maybe_coroutine(f, *args, **kwargs):
|
||||
async def maybe_coroutine(f: MaybeCoroFunc[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
value = f(*args, **kwargs)
|
||||
if _isawaitable(value):
|
||||
return await value
|
||||
else:
|
||||
return value
|
||||
return value # type: ignore
|
||||
|
||||
|
||||
async def async_all(gen, *, check=_isawaitable):
|
||||
async def async_all(gen: Iterable[Awaitable[T]], *, check: Callable[[T], bool] = _isawaitable) -> bool:
|
||||
for elem in gen:
|
||||
if check(elem):
|
||||
elem = await elem
|
||||
@ -619,7 +629,7 @@ async def async_all(gen, *, check=_isawaitable):
|
||||
return True
|
||||
|
||||
|
||||
async def sane_wait_for(futures, *, timeout):
|
||||
async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: Optional[float]) -> Set[asyncio.Task[T]]:
|
||||
ensured = [asyncio.ensure_future(fut) for fut in futures]
|
||||
done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED)
|
||||
|
||||
@ -637,7 +647,7 @@ def get_slots(cls: Type[Any]) -> Iterator[str]:
|
||||
continue
|
||||
|
||||
|
||||
def compute_timedelta(dt: datetime.datetime):
|
||||
def compute_timedelta(dt: datetime.datetime) -> float:
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.astimezone()
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
@ -686,7 +696,7 @@ def valid_icon_size(size: int) -> bool:
|
||||
return not size & (size - 1) and 4096 >= size >= 16
|
||||
|
||||
|
||||
class SnowflakeList(array.array):
|
||||
class SnowflakeList(_SnowflakeListBase):
|
||||
"""Internal data storage class to efficiently store a list of snowflakes.
|
||||
|
||||
This should have the following characteristics:
|
||||
@ -705,7 +715,7 @@ class SnowflakeList(array.array):
|
||||
def __init__(self, data: Iterable[int], *, is_sorted: bool = False):
|
||||
...
|
||||
|
||||
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False):
|
||||
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False) -> Self:
|
||||
return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore
|
||||
|
||||
def add(self, element: int) -> None:
|
||||
@ -1010,7 +1020,7 @@ def evaluate_annotation(
|
||||
cache: Dict[str, Any],
|
||||
*,
|
||||
implicit_str: bool = True,
|
||||
):
|
||||
) -> Any:
|
||||
if isinstance(tp, ForwardRef):
|
||||
tp = tp.__forward_arg__
|
||||
# ForwardRefs always evaluate their internals
|
||||
|
@ -262,7 +262,7 @@ class VoiceClient(VoiceProtocol):
|
||||
self._lite_nonce: int = 0
|
||||
self.ws: DiscordVoiceWebSocket = MISSING
|
||||
|
||||
warn_nacl = not has_nacl
|
||||
warn_nacl: bool = not has_nacl
|
||||
supported_modes: Tuple[SupportedModes, ...] = (
|
||||
'xsalsa20_poly1305_lite',
|
||||
'xsalsa20_poly1305_suffix',
|
||||
@ -279,7 +279,7 @@ class VoiceClient(VoiceProtocol):
|
||||
""":class:`ClientUser`: The user connected to voice (i.e. ourselves)."""
|
||||
return self._state.user # type: ignore - user can't be None after login
|
||||
|
||||
def checked_add(self, attr, value, limit):
|
||||
def checked_add(self, attr: str, value: int, limit: int) -> None:
|
||||
val = getattr(self, attr)
|
||||
if val + value > limit:
|
||||
setattr(self, attr, 0)
|
||||
@ -289,7 +289,7 @@ class VoiceClient(VoiceProtocol):
|
||||
# connection related
|
||||
|
||||
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
|
||||
self.session_id = data['session_id']
|
||||
self.session_id: str = data['session_id']
|
||||
channel_id = data['channel_id']
|
||||
|
||||
if not self._handshaking or self._potentially_reconnecting:
|
||||
@ -323,12 +323,12 @@ class VoiceClient(VoiceProtocol):
|
||||
self.endpoint, _, _ = endpoint.rpartition(':')
|
||||
if self.endpoint.startswith('wss://'):
|
||||
# Just in case, strip it off since we're going to add it later
|
||||
self.endpoint = self.endpoint[6:]
|
||||
self.endpoint: str = self.endpoint[6:]
|
||||
|
||||
# This gets set later
|
||||
self.endpoint_ip = MISSING
|
||||
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.socket.setblocking(False)
|
||||
|
||||
if not self._handshaking:
|
||||
|
@ -30,7 +30,7 @@ import json
|
||||
import re
|
||||
|
||||
from urllib.parse import quote as urlquote
|
||||
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload
|
||||
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, TypeVar, Type, overload
|
||||
from contextvars import ContextVar
|
||||
import weakref
|
||||
|
||||
@ -43,7 +43,7 @@ from ..enums import try_enum, WebhookType
|
||||
from ..user import BaseUser, User
|
||||
from ..flags import MessageFlags
|
||||
from ..asset import Asset
|
||||
from ..http import Route, handle_message_parameters, MultipartParameters
|
||||
from ..http import Route, handle_message_parameters, MultipartParameters, HTTPClient
|
||||
from ..mixins import Hashable
|
||||
from ..channel import PartialMessageable
|
||||
from ..file import File
|
||||
@ -58,24 +58,38 @@ __all__ = (
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
from types import TracebackType
|
||||
|
||||
from ..embeds import Embed
|
||||
from ..mentions import AllowedMentions
|
||||
from ..message import Attachment
|
||||
from ..state import ConnectionState
|
||||
from ..http import Response
|
||||
from ..types.webhook import (
|
||||
Webhook as WebhookPayload,
|
||||
)
|
||||
from ..types.message import (
|
||||
Message as MessagePayload,
|
||||
)
|
||||
from ..guild import Guild
|
||||
from ..channel import TextChannel
|
||||
from ..abc import Snowflake
|
||||
from ..ui.view import View
|
||||
import datetime
|
||||
from ..types.webhook import (
|
||||
Webhook as WebhookPayload,
|
||||
SourceGuild as SourceGuildPayload,
|
||||
)
|
||||
from ..types.message import (
|
||||
Message as MessagePayload,
|
||||
)
|
||||
from ..types.user import (
|
||||
User as UserPayload,
|
||||
PartialUser as PartialUserPayload,
|
||||
)
|
||||
from ..types.channel import (
|
||||
PartialChannel as PartialChannelPayload,
|
||||
)
|
||||
|
||||
MISSING = utils.MISSING
|
||||
BE = TypeVar('BE', bound=BaseException)
|
||||
_State = Union[ConnectionState, '_WebhookState']
|
||||
|
||||
MISSING: Any = utils.MISSING
|
||||
|
||||
|
||||
class AsyncDeferredLock:
|
||||
@ -83,14 +97,19 @@ class AsyncDeferredLock:
|
||||
self.lock = lock
|
||||
self.delta: Optional[float] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
async def __aenter__(self) -> Self:
|
||||
await self.lock.acquire()
|
||||
return self
|
||||
|
||||
def delay_by(self, delta: float) -> None:
|
||||
self.delta = delta
|
||||
|
||||
async def __aexit__(self, type, value, traceback):
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BE]],
|
||||
exc: Optional[BE],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
if self.delta:
|
||||
await asyncio.sleep(self.delta)
|
||||
self.lock.release()
|
||||
@ -545,11 +564,11 @@ class PartialWebhookChannel(Hashable):
|
||||
|
||||
__slots__ = ('id', 'name')
|
||||
|
||||
def __init__(self, *, data):
|
||||
self.id = int(data['id'])
|
||||
self.name = data['name']
|
||||
def __init__(self, *, data: PartialChannelPayload) -> None:
|
||||
self.id: int = int(data['id'])
|
||||
self.name: str = data['name']
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'<PartialWebhookChannel name={self.name!r} id={self.id}>'
|
||||
|
||||
|
||||
@ -570,13 +589,13 @@ class PartialWebhookGuild(Hashable):
|
||||
|
||||
__slots__ = ('id', 'name', '_icon', '_state')
|
||||
|
||||
def __init__(self, *, data, state):
|
||||
self._state = state
|
||||
self.id = int(data['id'])
|
||||
self.name = data['name']
|
||||
self._icon = data['icon']
|
||||
def __init__(self, *, data: SourceGuildPayload, state: _State) -> None:
|
||||
self._state: _State = state
|
||||
self.id: int = int(data['id'])
|
||||
self.name: str = data['name']
|
||||
self._icon: str = data['icon']
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'<PartialWebhookGuild name={self.name!r} id={self.id}>'
|
||||
|
||||
@property
|
||||
@ -590,14 +609,14 @@ class PartialWebhookGuild(Hashable):
|
||||
class _FriendlyHttpAttributeErrorHelper:
|
||||
__slots__ = ()
|
||||
|
||||
def __getattr__(self, attr):
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
raise AttributeError('PartialWebhookState does not support http methods.')
|
||||
|
||||
|
||||
class _WebhookState:
|
||||
__slots__ = ('_parent', '_webhook')
|
||||
|
||||
def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]):
|
||||
def __init__(self, webhook: Any, parent: Optional[_State]):
|
||||
self._webhook: Any = webhook
|
||||
|
||||
self._parent: Optional[ConnectionState]
|
||||
@ -606,23 +625,23 @@ class _WebhookState:
|
||||
else:
|
||||
self._parent = parent
|
||||
|
||||
def _get_guild(self, guild_id):
|
||||
def _get_guild(self, guild_id: Optional[int]) -> Optional[Guild]:
|
||||
if self._parent is not None:
|
||||
return self._parent._get_guild(guild_id)
|
||||
return None
|
||||
|
||||
def store_user(self, data):
|
||||
def store_user(self, data: Union[UserPayload, PartialUserPayload]) -> BaseUser:
|
||||
if self._parent is not None:
|
||||
return self._parent.store_user(data)
|
||||
# state parameter is artificial
|
||||
return BaseUser(state=self, data=data) # type: ignore
|
||||
|
||||
def create_user(self, data):
|
||||
def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> BaseUser:
|
||||
# state parameter is artificial
|
||||
return BaseUser(state=self, data=data) # type: ignore
|
||||
|
||||
@property
|
||||
def http(self):
|
||||
def http(self) -> Union[HTTPClient, _FriendlyHttpAttributeErrorHelper]:
|
||||
if self._parent is not None:
|
||||
return self._parent.http
|
||||
|
||||
@ -630,7 +649,7 @@ class _WebhookState:
|
||||
# however, using it should result in a late-binding error.
|
||||
return _FriendlyHttpAttributeErrorHelper()
|
||||
|
||||
def __getattr__(self, attr):
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
if self._parent is not None:
|
||||
return getattr(self._parent, attr)
|
||||
|
||||
@ -830,19 +849,24 @@ class BaseWebhook(Hashable):
|
||||
'_state',
|
||||
)
|
||||
|
||||
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None):
|
||||
def __init__(
|
||||
self,
|
||||
data: WebhookPayload,
|
||||
token: Optional[str] = None,
|
||||
state: Optional[_State] = None,
|
||||
) -> None:
|
||||
self.auth_token: Optional[str] = token
|
||||
self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(self, parent=state)
|
||||
self._state: _State = state or _WebhookState(self, parent=state)
|
||||
self._update(data)
|
||||
|
||||
def _update(self, data: WebhookPayload):
|
||||
self.id = int(data['id'])
|
||||
self.type = try_enum(WebhookType, int(data['type']))
|
||||
self.channel_id = utils._get_as_snowflake(data, 'channel_id')
|
||||
self.guild_id = utils._get_as_snowflake(data, 'guild_id')
|
||||
self.name = data.get('name')
|
||||
self._avatar = data.get('avatar')
|
||||
self.token = data.get('token')
|
||||
def _update(self, data: WebhookPayload) -> None:
|
||||
self.id: int = int(data['id'])
|
||||
self.type: WebhookType = try_enum(WebhookType, int(data['type']))
|
||||
self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id')
|
||||
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
|
||||
self.name: Optional[str] = data.get('name')
|
||||
self._avatar: Optional[str] = data.get('avatar')
|
||||
self.token: Optional[str] = data.get('token')
|
||||
|
||||
user = data.get('user')
|
||||
self.user: Optional[Union[BaseUser, User]] = None
|
||||
@ -1010,11 +1034,17 @@ class Webhook(BaseWebhook):
|
||||
|
||||
__slots__: Tuple[str, ...] = ('session',)
|
||||
|
||||
def __init__(self, data: WebhookPayload, session: aiohttp.ClientSession, token: Optional[str] = None, state=None):
|
||||
def __init__(
|
||||
self,
|
||||
data: WebhookPayload,
|
||||
session: aiohttp.ClientSession,
|
||||
token: Optional[str] = None,
|
||||
state: Optional[_State] = None,
|
||||
) -> None:
|
||||
super().__init__(data, token, state)
|
||||
self.session = session
|
||||
self.session: aiohttp.ClientSession = session
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'<Webhook id={self.id!r}>'
|
||||
|
||||
@property
|
||||
@ -1023,7 +1053,7 @@ class Webhook(BaseWebhook):
|
||||
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
|
||||
|
||||
@classmethod
|
||||
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
|
||||
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Self:
|
||||
"""Creates a partial :class:`Webhook`.
|
||||
|
||||
Parameters
|
||||
@ -1059,7 +1089,7 @@ class Webhook(BaseWebhook):
|
||||
return cls(data, session, token=bot_token)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
|
||||
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Self:
|
||||
"""Creates a partial :class:`Webhook` from a webhook URL.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
@ -1102,7 +1132,7 @@ class Webhook(BaseWebhook):
|
||||
return cls(data, session, token=bot_token) # type: ignore
|
||||
|
||||
@classmethod
|
||||
def _as_follower(cls, data, *, channel, user) -> Webhook:
|
||||
def _as_follower(cls, data, *, channel, user) -> Self:
|
||||
name = f"{channel.guild} #{channel}"
|
||||
feed: WebhookPayload = {
|
||||
'id': data['webhook_id'],
|
||||
@ -1118,8 +1148,8 @@ class Webhook(BaseWebhook):
|
||||
return cls(feed, session=session, state=state, token=state.http.token)
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, data, state) -> Webhook:
|
||||
session = state.http._HTTPClient__session
|
||||
def from_state(cls, data: WebhookPayload, state: ConnectionState) -> Self:
|
||||
session = state.http._HTTPClient__session # type: ignore
|
||||
return cls(data, session=session, state=state, token=state.http.token)
|
||||
|
||||
async def fetch(self, *, prefer_auth: bool = True) -> Webhook:
|
||||
@ -1168,7 +1198,7 @@ class Webhook(BaseWebhook):
|
||||
|
||||
return Webhook(data, self.session, token=self.auth_token, state=self._state)
|
||||
|
||||
async def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True):
|
||||
async def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True) -> None:
|
||||
"""|coro|
|
||||
|
||||
Deletes this Webhook.
|
||||
|
@ -37,7 +37,7 @@ import time
|
||||
import re
|
||||
|
||||
from urllib.parse import quote as urlquote
|
||||
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload
|
||||
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, TypeVar, Type, overload
|
||||
import weakref
|
||||
|
||||
from .. import utils
|
||||
@ -56,36 +56,50 @@ __all__ = (
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
from types import TracebackType
|
||||
|
||||
from ..file import File
|
||||
from ..embeds import Embed
|
||||
from ..mentions import AllowedMentions
|
||||
from ..message import Attachment
|
||||
from ..abc import Snowflake
|
||||
from ..state import ConnectionState
|
||||
from ..types.webhook import (
|
||||
Webhook as WebhookPayload,
|
||||
)
|
||||
from ..abc import Snowflake
|
||||
from ..types.message import (
|
||||
Message as MessagePayload,
|
||||
)
|
||||
|
||||
BE = TypeVar('BE', bound=BaseException)
|
||||
|
||||
try:
|
||||
from requests import Session, Response
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
MISSING = utils.MISSING
|
||||
MISSING: Any = utils.MISSING
|
||||
|
||||
|
||||
class DeferredLock:
|
||||
def __init__(self, lock: threading.Lock):
|
||||
self.lock = lock
|
||||
def __init__(self, lock: threading.Lock) -> None:
|
||||
self.lock: threading.Lock = lock
|
||||
self.delta: Optional[float] = None
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> Self:
|
||||
self.lock.acquire()
|
||||
return self
|
||||
|
||||
def delay_by(self, delta: float) -> None:
|
||||
self.delta = delta
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BE]],
|
||||
exc: Optional[BE],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
if self.delta:
|
||||
time.sleep(self.delta)
|
||||
self.lock.release()
|
||||
@ -218,7 +232,7 @@ class WebhookAdapter:
|
||||
token: Optional[str] = None,
|
||||
session: Session,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
) -> None:
|
||||
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
|
||||
return self.request(route, session, reason=reason, auth_token=token)
|
||||
|
||||
@ -229,7 +243,7 @@ class WebhookAdapter:
|
||||
*,
|
||||
session: Session,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
) -> None:
|
||||
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
return self.request(route, session, reason=reason)
|
||||
|
||||
@ -241,7 +255,7 @@ class WebhookAdapter:
|
||||
*,
|
||||
session: Session,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
) -> WebhookPayload:
|
||||
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
|
||||
return self.request(route, session, reason=reason, payload=payload, auth_token=token)
|
||||
|
||||
@ -253,7 +267,7 @@ class WebhookAdapter:
|
||||
*,
|
||||
session: Session,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
) -> WebhookPayload:
|
||||
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
return self.request(route, session, reason=reason, payload=payload)
|
||||
|
||||
@ -268,7 +282,7 @@ class WebhookAdapter:
|
||||
files: Optional[List[File]] = None,
|
||||
thread_id: Optional[int] = None,
|
||||
wait: bool = False,
|
||||
):
|
||||
) -> MessagePayload:
|
||||
params = {'wait': int(wait)}
|
||||
if thread_id:
|
||||
params['thread_id'] = thread_id
|
||||
@ -282,7 +296,7 @@ class WebhookAdapter:
|
||||
message_id: int,
|
||||
*,
|
||||
session: Session,
|
||||
):
|
||||
) -> MessagePayload:
|
||||
route = Route(
|
||||
'GET',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
|
||||
@ -302,7 +316,7 @@ class WebhookAdapter:
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
multipart: Optional[List[Dict[str, Any]]] = None,
|
||||
files: Optional[List[File]] = None,
|
||||
):
|
||||
) -> MessagePayload:
|
||||
route = Route(
|
||||
'PATCH',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
|
||||
@ -319,7 +333,7 @@ class WebhookAdapter:
|
||||
message_id: int,
|
||||
*,
|
||||
session: Session,
|
||||
):
|
||||
) -> None:
|
||||
route = Route(
|
||||
'DELETE',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
|
||||
@ -335,7 +349,7 @@ class WebhookAdapter:
|
||||
token: str,
|
||||
*,
|
||||
session: Session,
|
||||
):
|
||||
) -> WebhookPayload:
|
||||
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
|
||||
return self.request(route, session=session, auth_token=token)
|
||||
|
||||
@ -345,7 +359,7 @@ class WebhookAdapter:
|
||||
token: str,
|
||||
*,
|
||||
session: Session,
|
||||
):
|
||||
) -> WebhookPayload:
|
||||
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
return self.request(route, session=session)
|
||||
|
||||
@ -569,11 +583,17 @@ class SyncWebhook(BaseWebhook):
|
||||
|
||||
__slots__: Tuple[str, ...] = ('session',)
|
||||
|
||||
def __init__(self, data: WebhookPayload, session: Session, token: Optional[str] = None, state=None):
|
||||
def __init__(
|
||||
self,
|
||||
data: WebhookPayload,
|
||||
session: Session,
|
||||
token: Optional[str] = None,
|
||||
state: Optional[Union[ConnectionState, _WebhookState]] = None,
|
||||
) -> None:
|
||||
super().__init__(data, token, state)
|
||||
self.session = session
|
||||
self.session: Session = session
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'<Webhook id={self.id!r}>'
|
||||
|
||||
@property
|
||||
@ -812,7 +832,7 @@ class SyncWebhook(BaseWebhook):
|
||||
|
||||
return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state)
|
||||
|
||||
def _create_message(self, data):
|
||||
def _create_message(self, data: MessagePayload) -> SyncWebhookMessage:
|
||||
state = _WebhookState(self, parent=self._state)
|
||||
# state may be artificial (unlikely at this point...)
|
||||
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
|
||||
|
@ -278,7 +278,7 @@ class Widget:
|
||||
def __str__(self) -> str:
|
||||
return self.json_url
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, Widget):
|
||||
return self.id == other.id
|
||||
return False
|
||||
|
Loading…
x
Reference in New Issue
Block a user