Merge branch '2.0' into pr7268

# Conflicts:
#	discord/raw_models.py
This commit is contained in:
Arthur Jovart
2021-09-01 22:27:34 +02:00
115 changed files with 5173 additions and 37532 deletions

View File

@ -60,13 +60,15 @@ from .interactions import *
from .components import *
from .threads import *
class VersionInfo(NamedTuple):
major: int
minor: int
micro: int
releaselevel: Literal["alpha", "beta", "candidate", "final"]
serial: int
version_info = VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=0)
class VersionInfo(NamedTuple):
major: int
minor: int
micro: int
releaselevel: Literal["alpha", "beta", "candidate", "final"]
serial: int
version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=0)
logging.getLogger(__name__).addHandler(logging.NullHandler())

View File

@ -51,7 +51,7 @@ def core(parser, args):
if args.version:
show_version()
bot_template = """#!/usr/bin/env python3
_bot_template = """#!/usr/bin/env python3
from discord.ext import commands
import discord
@ -77,7 +77,7 @@ bot = Bot()
bot.run(config.token)
"""
gitignore_template = """# Byte-compiled / optimized / DLL files
_gitignore_template = """# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
@ -107,7 +107,7 @@ var/
config.py
"""
cog_template = '''from discord.ext import commands
_cog_template = '''from discord.ext import commands
import discord
class {name}(commands.Cog{attrs}):
@ -120,7 +120,7 @@ def setup(bot):
bot.add_cog({name}(bot))
'''
cog_extras = '''
_cog_extras = '''
def cog_unload(self):
# clean up logic goes here
pass
@ -170,7 +170,7 @@ _base_table = {
# NUL (0) and 1-31 are disallowed
_base_table.update((chr(i), None) for i in range(32))
translation_table = str.maketrans(_base_table)
_translation_table = str.maketrans(_base_table)
def to_path(parser, name, *, replace_spaces=False):
if isinstance(name, Path):
@ -182,7 +182,7 @@ def to_path(parser, name, *, replace_spaces=False):
if len(name) <= 4 and name.upper() in forbidden:
parser.error('invalid directory name given, use a different one')
name = name.translate(translation_table)
name = name.translate(_translation_table)
if replace_spaces:
name = name.replace(' ', '-')
return Path(name)
@ -215,14 +215,14 @@ def newbot(parser, args):
try:
with open(str(new_directory / 'bot.py'), 'w', encoding='utf-8') as fp:
base = 'Bot' if not args.sharded else 'AutoShardedBot'
fp.write(bot_template.format(base=base, prefix=args.prefix))
fp.write(_bot_template.format(base=base, prefix=args.prefix))
except OSError as exc:
parser.error(f'could not create bot file ({exc})')
if not args.no_git:
try:
with open(str(new_directory / '.gitignore'), 'w', encoding='utf-8') as fp:
fp.write(gitignore_template)
fp.write(_gitignore_template)
except OSError as exc:
print(f'warning: could not create .gitignore file ({exc})')
@ -240,7 +240,7 @@ def newcog(parser, args):
try:
with open(str(directory), 'w', encoding='utf-8') as fp:
attrs = ''
extra = cog_extras if args.full else ''
extra = _cog_extras if args.full else ''
if args.class_name:
name = args.class_name
else:
@ -255,7 +255,7 @@ def newcog(parser, args):
attrs += f', name="{args.display_name}"'
if args.hide_commands:
attrs += ', command_attrs=dict(hidden=True)'
fp.write(cog_template.format(name=name, extra=extra, attrs=attrs))
fp.write(_cog_template.format(name=name, extra=extra, attrs=attrs))
except OSError as exc:
parser.error(f'could not create cog file ({exc})')
else:

View File

@ -28,14 +28,14 @@ import copy
import asyncio
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
TYPE_CHECKING,
Protocol,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
@ -52,6 +52,7 @@ from .role import Role
from .invite import Invite
from .file import File
from .voice_client import VoiceClient, VoiceProtocol
from .sticker import GuildSticker, StickerItem
from . import utils
__all__ = (
@ -68,6 +69,7 @@ T = TypeVar('T', bound=VoiceProtocol)
if TYPE_CHECKING:
from datetime import datetime
from .client import Client
from .user import ClientUser
from .asset import Asset
from .state import ConnectionState
@ -76,17 +78,18 @@ if TYPE_CHECKING:
from .channel import CategoryChannel
from .embeds import Embed
from .message import Message, MessageReference, PartialMessage
from .channel import TextChannel, DMChannel, GroupChannel
from .channel import TextChannel, DMChannel, GroupChannel, PartialMessageable
from .threads import Thread
from .enums import InviteTarget
from .ui.view import View
from .types.channel import (
PermissionOverwrite as PermissionOverwritePayload,
Channel as ChannelPayload,
GuildChannel as GuildChannelPayload,
OverwriteType,
)
PartialMessageableChannel = Union[TextChannel, Thread, DMChannel]
PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable]
MessageableChannel = Union[PartialMessageableChannel, GroupChannel]
SnowflakeTime = Union["Snowflake", datetime]
@ -120,11 +123,6 @@ class Snowflake(Protocol):
__slots__ = ()
id: int
@property
def created_at(self) -> datetime:
""":class:`datetime.datetime`: Returns the model's creation time as an aware datetime in UTC."""
raise NotImplementedError
@runtime_checkable
class User(Snowflake, Protocol):
@ -305,11 +303,8 @@ class GuildChannel:
payload.append(d)
await http.bulk_channel_update(self.guild.id, payload, reason=reason)
self.position = position
if parent_id is not _undefined:
self.category_id = int(parent_id) if parent_id else None
async def _edit(self, options: Dict[str, Any], reason: Optional[str]):
async def _edit(self, options: Dict[str, Any], reason: Optional[str]) -> Optional[ChannelPayload]:
try:
parent = options.pop('category')
except KeyError:
@ -388,8 +383,7 @@ class GuildChannel:
options['type'] = ch_type.value
if options:
data = await self._state.http.edit_channel(self.id, reason=reason, **options)
self._update(self.guild, data)
return await self._state.http.edit_channel(self.id, reason=reason, **options)
def _fill_overwrites(self, data: GuildChannelPayload) -> None:
self._overwrites = []
@ -473,7 +467,7 @@ class GuildChannel:
return PermissionOverwrite()
@property
def overwrites(self) -> Mapping[Union[Role, Member], PermissionOverwrite]:
def overwrites(self) -> Dict[Union[Role, Member], PermissionOverwrite]:
"""Returns all of the channel's overwrites.
This is returned as a dictionary where the key contains the target which
@ -482,7 +476,7 @@ class GuildChannel:
Returns
--------
Mapping[Union[:class:`~discord.Role`, :class:`~discord.Member`], :class:`~discord.PermissionOverwrite`]
Dict[Union[:class:`~discord.Role`, :class:`~discord.Member`], :class:`~discord.PermissionOverwrite`]
The channel's permission overwrites.
"""
ret = {}
@ -1146,6 +1140,7 @@ class Messageable:
- :class:`~discord.User`
- :class:`~discord.Member`
- :class:`~discord.ext.commands.Context`
- :class:`~discord.Thread`
"""
__slots__ = ()
@ -1162,6 +1157,7 @@ class Messageable:
tts: bool = ...,
embed: Embed = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
@ -1179,6 +1175,7 @@ class Messageable:
tts: bool = ...,
embed: Embed = ...,
files: List[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
@ -1196,6 +1193,7 @@ class Messageable:
tts: bool = ...,
embeds: List[Embed] = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
@ -1213,6 +1211,7 @@ class Messageable:
tts: bool = ...,
embeds: List[Embed] = ...,
files: List[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
@ -1231,6 +1230,7 @@ class Messageable:
embeds=None,
file=None,
files=None,
stickers=None,
delete_after=None,
nonce=None,
allowed_mentions=None,
@ -1302,6 +1302,10 @@ class Messageable:
embeds: List[:class:`~discord.Embed`]
A list of embeds to upload. Must be a maximum of 10.
.. versionadded:: 2.0
stickers: Sequence[Union[:class:`~discord.GuildSticker`, :class:`~discord.StickerItem`]]
A list of stickers to upload. Must be a maximum of 3.
.. versionadded:: 2.0
Raises
@ -1338,6 +1342,9 @@ class Messageable:
raise InvalidArgument('embeds parameter must be a list of up to 10 elements')
embeds = [embed.to_dict() for embed in embeds]
if stickers is not None:
stickers = [sticker.id for sticker in stickers]
if allowed_mentions is not None:
if state.allowed_mentions is not None:
allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict()
@ -1382,6 +1389,7 @@ class Messageable:
embeds=embeds,
nonce=nonce,
message_reference=reference,
stickers=stickers,
components=components,
)
finally:
@ -1404,6 +1412,7 @@ class Messageable:
nonce=nonce,
allowed_mentions=allowed_mentions,
message_reference=reference,
stickers=stickers,
components=components,
)
finally:
@ -1419,6 +1428,7 @@ class Messageable:
nonce=nonce,
allowed_mentions=allowed_mentions,
message_reference=reference,
stickers=stickers,
components=components,
)
@ -1452,6 +1462,7 @@ class Messageable:
This means that both ``with`` and ``async with`` work with this.
Example Usage: ::
async with channel.typing():
# simulate something heavy
await asyncio.sleep(10)
@ -1610,12 +1621,20 @@ class Connectable(Protocol):
def _get_voice_state_pair(self) -> Tuple[int, int]:
raise NotImplementedError
async def connect(self, *, timeout: float = 60.0, reconnect: bool = True, cls: Type[T] = VoiceClient) -> T:
async def connect(
self,
*,
timeout: float = 60.0,
reconnect: bool = True,
cls: Callable[[Client, Connectable], T] = VoiceClient,
) -> T:
"""|coro|
Connects to voice and creates a :class:`VoiceClient` to establish
your connection to the voice server.
This requires :attr:`Intents.voice_states`.
Parameters
-----------
timeout: :class:`float`

View File

@ -830,10 +830,12 @@ def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
except KeyError:
return Activity(**data)
else:
return CustomActivity(name=name, **data)
# we removed the name key from data already
return CustomActivity(name=name, **data) # type: ignore
elif game_type is ActivityType.streaming:
if 'url' in data:
return Streaming(**data)
# the url won't be None here
return Streaming(**data) # type: ignore
return Activity(**data)
elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data:
return Spotify(**data)

View File

@ -146,7 +146,7 @@ class AppInfo:
self.rpc_origins: List[str] = data['rpc_origins']
self.bot_public: bool = data['bot_public']
self.bot_require_code_grant: bool = data['bot_require_code_grant']
self.owner: User = state.store_user(data['owner'])
self.owner: User = state.create_user(data['owner'])
team: Optional[TeamPayload] = data.get('team')
self.team: Optional[Team] = Team(state, team) if team else None

View File

@ -177,6 +177,17 @@ class Asset(AssetMixin):
animated=animated,
)
@classmethod
def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset:
animated = avatar.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
state,
url=f"{cls.BASE}/guilds/{guild_id}/users/{member_id}/avatars/{avatar}.{format}?size=1024",
key=avatar,
animated=animated,
)
@classmethod
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
return cls(
@ -216,14 +227,25 @@ class Asset(AssetMixin):
)
@classmethod
def _from_sticker(cls, state, sticker_id: int, sticker_hash: str) -> Asset:
def _from_sticker_banner(cls, state, banner: int) -> Asset:
return cls(
state,
url=f'{cls.BASE}/stickers/{sticker_id}/{sticker_hash}.png?size=1024',
key=sticker_hash,
url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png',
key=str(banner),
animated=False,
)
@classmethod
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:
animated = banner_hash.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
state,
url=f'{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512',
key=banner_hash,
animated=animated
)
def __str__(self) -> str:
return self._url
@ -291,10 +313,11 @@ class Asset(AssetMixin):
if self._animated:
if format not in VALID_ASSET_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}')
else:
url = url.with_path(f'{path}.{format}')
elif static_format is MISSING:
if format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
url = url.with_path(f'{path}.{format}')
url = url.with_path(f'{path}.{format}')
if static_format is not MISSING and not self._animated:
if static_format not in VALID_STATIC_FORMATS:

View File

@ -49,12 +49,17 @@ if TYPE_CHECKING:
from .guild import Guild
from .member import Member
from .role import Role
from .types.audit_log import AuditLogChange as AuditLogChangePayload
from .types.audit_log import AuditLogEntry as AuditLogEntryPayload
from .types.audit_log import (
AuditLogChange as AuditLogChangePayload,
AuditLogEntry as AuditLogEntryPayload,
)
from .types.channel import PermissionOverwrite as PermissionOverwritePayload
from .types.role import Role as RolePayload
from .types.snowflake import Snowflake
from .user import User
from .stage_instance import StageInstance
from .sticker import GuildSticker
from .threads import Thread
def _transform_permissions(entry: AuditLogEntry, data: str) -> Permissions:
@ -69,22 +74,21 @@ def _transform_snowflake(entry: AuditLogEntry, data: Snowflake) -> int:
return int(data)
def _transform_channel(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Object]:
def _transform_channel(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Union[abc.GuildChannel, Object]]:
if data is None:
return None
return entry.guild.get_channel(int(data)) or Object(id=data)
def _transform_owner_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]:
def _transform_member_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]:
if data is None:
return None
return entry._get_member(int(data))
def _transform_inviter_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]:
def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Guild]:
if data is None:
return None
return entry._get_member(int(data))
return entry._state._get_guild(data)
def _transform_overwrites(
@ -142,6 +146,11 @@ def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]:
return _transform
def _transform_type(entry: AuditLogEntry, data: Union[int]) -> Union[enums.ChannelType, enums.StickerType]:
if entry.action.name.startswith('sticker_'):
return enums.try_enum(enums.StickerType, data)
else:
return enums.try_enum(enums.ChannelType, data)
class AuditLogDiff:
def __len__(self) -> int:
@ -176,8 +185,8 @@ class AuditLogChanges:
'permissions': (None, _transform_permissions),
'id': (None, _transform_snowflake),
'color': ('colour', _transform_color),
'owner_id': ('owner', _transform_owner_id),
'inviter_id': ('inviter', _transform_inviter_id),
'owner_id': ('owner', _transform_member_id),
'inviter_id': ('inviter', _transform_member_id),
'channel_id': ('channel', _transform_channel),
'afk_channel_id': ('afk_channel', _transform_channel),
'system_channel_id': ('system_channel', _transform_channel),
@ -191,12 +200,15 @@ class AuditLogChanges:
'icon_hash': ('icon', _transform_icon),
'avatar_hash': ('avatar', _transform_avatar),
'rate_limit_per_user': ('slowmode_delay', None),
'guild_id': ('guild', _transform_guild_id),
'tags': ('emoji', None),
'default_message_notifications': ('default_notifications', _enum_transformer(enums.NotificationLevel)),
'region': (None, _enum_transformer(enums.VoiceRegion)),
'rtc_region': (None, _enum_transformer(enums.VoiceRegion)),
'video_quality_mode': (None, _enum_transformer(enums.VideoQualityMode)),
'privacy_level': (None, _enum_transformer(enums.StagePrivacyLevel)),
'type': (None, _enum_transformer(enums.ChannelType)),
'format_type': (None, _enum_transformer(enums.StickerFormatType)),
'type': (None, _transform_type),
}
# fmt: on
@ -318,6 +330,10 @@ class AuditLogEntry(Hashable):
Returns the entry's hash.
.. describe:: int(x)
Returns the entry's ID.
.. versionchanged:: 1.7
Audit log entries are now comparable and hashable.
@ -434,7 +450,7 @@ class AuditLogEntry(Hashable):
return utils.snowflake_time(self.id)
@utils.cached_property
def target(self) -> Union[Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, Object, None]:
def target(self) -> Union[Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None]:
try:
converter = getattr(self, '_convert_target_' + self.action.target_type)
except AttributeError:
@ -501,3 +517,12 @@ class AuditLogEntry(Hashable):
def _convert_target_message(self, target_id: int) -> Union[Member, User, None]:
return self._get_member(target_id)
def _convert_target_stage_instance(self, target_id: int) -> Union[StageInstance, Object]:
return self.guild.get_stage_instance(target_id) or Object(id=target_id)
def _convert_target_sticker(self, target_id: int) -> Union[GuildSticker, Object]:
return self._state.get_sticker(target_id) or Object(id=target_id)
def _convert_target_thread(self, target_id: int) -> Union[Thread, Object]:
return self.guild.get_thread(target_id) or Object(id=target_id)

View File

@ -26,13 +26,28 @@ from __future__ import annotations
import time
import asyncio
from typing import Any, Callable, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union, overload
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
TYPE_CHECKING,
Tuple,
Type,
TypeVar,
Union,
overload,
)
import datetime
import discord.abc
from .permissions import PermissionOverwrite, Permissions
from .enums import ChannelType, StagePrivacyLevel, try_enum, VoiceRegion, VideoQualityMode
from .mixins import Hashable
from .object import Object
from . import utils
from .utils import MISSING
from .asset import Asset
@ -49,6 +64,7 @@ __all__ = (
'CategoryChannel',
'StoreChannel',
'GroupChannel',
'PartialMessageable',
)
if TYPE_CHECKING:
@ -70,6 +86,7 @@ if TYPE_CHECKING:
StoreChannel as StoreChannelPayload,
GroupDMChannel as GroupChannelPayload,
)
from .types.snowflake import SnowflakeList
async def _single_delete_strategy(messages: Iterable[Message]):
@ -98,6 +115,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
Returns the channel's name.
.. describe:: int(x)
Returns the channel's ID.
Attributes
-----------
name: :class:`str`
@ -127,6 +148,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
.. note::
To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead.
default_auto_archive_duration: :class:`int`
The default auto archive duration in minutes for threads created in this channel.
.. versionadded:: 2.0
"""
__slots__ = (
@ -142,6 +167,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
'_overwrites',
'_type',
'last_message_id',
'default_auto_archive_duration',
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload):
@ -171,6 +197,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
self.nsfw: bool = data.get('nsfw', False)
# Does this need coercion into `int`? No idea yet.
self.slowmode_delay: int = data.get('rate_limit_per_user', 0)
self.default_auto_archive_duration: ThreadArchiveDuration = data.get('default_auto_archive_duration', 1440)
self._type: int = data.get('type', self._type)
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id')
self._fill_overwrites(data)
@ -207,7 +234,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
.. versionadded:: 2.0
"""
return [thread for thread in self.guild.threads if thread.parent_id == self.id]
return [thread for thread in self.guild._threads.values() if thread.parent_id == self.id]
def is_nsfw(self) -> bool:
""":class:`bool`: Checks if the channel is NSFW."""
@ -250,13 +277,14 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
sync_permissions: bool = ...,
category: Optional[CategoryChannel] = ...,
slowmode_delay: int = ...,
default_auto_archive_duration: ThreadArchiveDuration = ...,
type: ChannelType = ...,
overwrites: Dict[Union[Role, Member, Snowflake], PermissionOverwrite] = ...,
) -> None:
overwrites: Mapping[Union[Role, Member, Snowflake], PermissionOverwrite] = ...,
) -> Optional[TextChannel]:
...
@overload
async def edit(self) -> None:
async def edit(self) -> Optional[TextChannel]:
...
async def edit(self, *, reason=None, **options):
@ -273,6 +301,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
.. versionchanged:: 1.4
The ``type`` keyword-only parameter was added.
.. versionchanged:: 2.0
Edits are no longer in-place, the newly edited channel is returned instead.
Parameters
----------
name: :class:`str`
@ -298,9 +329,12 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
is only available to guilds that contain ``NEWS`` in :attr:`Guild.features`.
reason: Optional[:class:`str`]
The reason for editing this channel. Shows up on the audit log.
overwrites: :class:`dict`
A :class:`dict` of target (either a role or a member) to
overwrites: :class:`Mapping`
A :class:`Mapping` of target (either a role or a member) to
:class:`PermissionOverwrite` to apply to the channel.
default_auto_archive_duration: :class:`int`
The new default auto archive duration in minutes for threads created in this channel.
Must be one of ``60``, ``1440``, ``4320``, or ``10080``.
Raises
------
@ -311,8 +345,18 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
You do not have permissions to edit the channel.
HTTPException
Editing the channel failed.
Returns
--------
Optional[:class:`.TextChannel`]
The newly edited text channel. If the edit was only positional
then ``None`` is returned instead.
"""
await self._edit(options, reason=reason)
payload = await self._edit(options, reason=reason)
if payload is not None:
# the payload will always be the proper channel payload
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> TextChannel:
@ -366,7 +410,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
if len(messages) > 100:
raise ClientException('Can only bulk delete messages up to 100 messages')
message_ids: List[int] = [m.id for m in messages]
message_ids: SnowflakeList = [m.id for m in messages]
await self._state.http.delete_messages(self.id, message_ids)
async def purge(
@ -631,64 +675,73 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""
return self.guild.get_thread(thread_id)
async def start_thread(
async def create_thread(
self,
*,
name: str,
message: Optional[Snowflake] = None,
auto_archive_duration: ThreadArchiveDuration = 1440,
auto_archive_duration: ThreadArchiveDuration = MISSING,
type: Optional[ChannelType] = None,
reason: Optional[str] = None,
) -> Thread:
"""|coro|
Starts a thread in this text channel.
Creates a thread in this text channel.
If no starter message is passed with the ``message`` parameter then
you must have :attr:`~discord.Permissions.send_messages` and
:attr:`~discord.Permissions.use_private_threads` in order to start the thread.
To create a public thread, you must have :attr:`~discord.Permissions.create_public_threads`.
For a private thread, :attr:`~discord.Permissions.create_private_threads` is needed instead.
If a starter message is passed with the ``message`` parameter then
you must have :attr:`~discord.Permissions.send_messages` and
:attr:`~discord.Permissions.use_threads` in order to start the thread.
.. versionadded:: 2.0
Parameters
-----------
name: :class:`str`
The name of the thread.
message: Optional[:class:`abc.Snowflake`]
A snowflake representing the message to start the thread with.
If ``None`` is passed then a private thread is started.
A snowflake representing the message to create the thread with.
If ``None`` is passed then a private thread is created.
Defaults to ``None``.
auto_archive_duration: :class:`int`
The duration in minutes before a thread is automatically archived for inactivity.
Defaults to ``1440`` or 24 hours.
If not provided, the channel's default auto archive duration is used.
type: Optional[:class:`ChannelType`]
The type of thread to create. If a ``message`` is passed then this parameter
is ignored, as a thread created with a message is always a public thread.
By default this creates a private thread if this is ``None``.
reason: :class:`str`
The reason for creating a new thread. Shows up on the audit log.
Raises
-------
Forbidden
You do not have permissions to start a thread.
You do not have permissions to create a thread.
HTTPException
Starting the thread failed.
Returns
--------
:class:`Thread`
The started thread
The created thread
"""
if type is None:
type = ChannelType.private_thread
if message is None:
data = await self._state.http.start_private_thread(
data = await self._state.http.start_thread_without_message(
self.id,
name=name,
auto_archive_duration=auto_archive_duration,
type=ChannelType.private_thread.value,
auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration,
type=type.value,
reason=reason,
)
else:
data = await self._state.http.start_public_thread(
data = await self._state.http.start_thread_with_message(
self.id,
message.id,
name=name,
auto_archive_duration=auto_archive_duration,
type=ChannelType.public_thread.value,
auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration,
reason=reason,
)
return Thread(guild=self.guild, state=self._state, data=data)
@ -706,6 +759,8 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
You must have :attr:`~Permissions.read_message_history` to use this. If iterating over private threads
then :attr:`~Permissions.manage_threads` is also required.
.. versionadded:: 2.0
Parameters
-----------
limit: Optional[:class:`bool`]
@ -734,27 +789,6 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""
return ArchivedThreadIterator(self.id, self.guild, limit=limit, joined=joined, private=private, before=before)
async def active_threads(self) -> List[Thread]:
"""|coro|
Returns a list of active :class:`Thread` that the client can access.
This includes both private and public threads.
Raises
------
HTTPException
The request to get the active threads failed.
Returns
--------
List[:class:`Thread`]
The archived threads
"""
data = await self._state.http.get_active_threads(self.id)
# TODO: thread members?
return [Thread(guild=self.guild, state=self._state, data=d) for d in data.get('threads', [])]
class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
__slots__ = (
@ -930,15 +964,15 @@ class VoiceChannel(VocalGuildChannel):
position: int = ...,
sync_permissions: int = ...,
category: Optional[CategoryChannel] = ...,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
rtc_region: Optional[VoiceRegion] = ...,
video_quality_mode: VideoQualityMode = ...,
reason: Optional[str] = ...,
) -> None:
) -> Optional[VoiceChannel]:
...
@overload
async def edit(self) -> None:
async def edit(self) -> Optional[VoiceChannel]:
...
async def edit(self, *, reason=None, **options):
@ -952,6 +986,9 @@ class VoiceChannel(VocalGuildChannel):
.. versionchanged:: 1.3
The ``overwrites`` keyword-only parameter was added.
.. versionchanged:: 2.0
Edits are no longer in-place, the newly edited channel is returned instead.
Parameters
----------
name: :class:`str`
@ -970,8 +1007,8 @@ class VoiceChannel(VocalGuildChannel):
category.
reason: Optional[:class:`str`]
The reason for editing this channel. Shows up on the audit log.
overwrites: :class:`dict`
A :class:`dict` of target (either a role or a member) to
overwrites: :class:`Mapping`
A :class:`Mapping` of target (either a role or a member) to
:class:`PermissionOverwrite` to apply to the channel.
rtc_region: Optional[:class:`VoiceRegion`]
The new region for the voice channel's voice communication.
@ -991,9 +1028,18 @@ class VoiceChannel(VocalGuildChannel):
You do not have permissions to edit the channel.
HTTPException
Editing the channel failed.
Returns
--------
Optional[:class:`.VoiceChannel`]
The newly edited voice channel. If the edit was only positional
then ``None`` is returned instead.
"""
await self._edit(options, reason=reason)
payload = await self._edit(options, reason=reason)
if payload is not None:
# the payload will always be the proper channel payload
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
class StageChannel(VocalGuildChannel):
@ -1119,7 +1165,9 @@ class StageChannel(VocalGuildChannel):
"""
return utils.get(self.guild.stage_instances, channel_id=self.id)
async def create_instance(self, *, topic: str, privacy_level: StagePrivacyLevel = MISSING) -> StageInstance:
async def create_instance(
self, *, topic: str, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None
) -> StageInstance:
"""|coro|
Create a stage instance.
@ -1135,6 +1183,8 @@ class StageChannel(VocalGuildChannel):
The stage instance's topic.
privacy_level: :class:`StagePrivacyLevel`
The stage instance's privacy level. Defaults to :attr:`StagePrivacyLevel.guild_only`.
reason: :class:`str`
The reason the stage instance was created. Shows up on the audit log.
Raises
------
@ -1159,7 +1209,7 @@ class StageChannel(VocalGuildChannel):
payload['privacy_level'] = privacy_level.value
data = await self._state.http.create_stage_instance(**payload)
data = await self._state.http.create_stage_instance(**payload, reason=reason)
return StageInstance(guild=self.guild, state=self._state, data=data)
async def fetch_instance(self) -> StageInstance:
@ -1193,15 +1243,15 @@ class StageChannel(VocalGuildChannel):
position: int = ...,
sync_permissions: int = ...,
category: Optional[CategoryChannel] = ...,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
rtc_region: Optional[VoiceRegion] = ...,
video_quality_mode: VideoQualityMode = ...,
reason: Optional[str] = ...,
) -> None:
) -> Optional[StageChannel]:
...
@overload
async def edit(self) -> None:
async def edit(self) -> Optional[StageChannel]:
...
async def edit(self, *, reason=None, **options):
@ -1215,6 +1265,9 @@ class StageChannel(VocalGuildChannel):
.. versionchanged:: 2.0
The ``topic`` parameter must now be set via :attr:`create_instance`.
.. versionchanged:: 2.0
Edits are no longer in-place, the newly edited channel is returned instead.
Parameters
----------
name: :class:`str`
@ -1229,8 +1282,8 @@ class StageChannel(VocalGuildChannel):
category.
reason: Optional[:class:`str`]
The reason for editing this channel. Shows up on the audit log.
overwrites: :class:`dict`
A :class:`dict` of target (either a role or a member) to
overwrites: :class:`Mapping`
A :class:`Mapping` of target (either a role or a member) to
:class:`PermissionOverwrite` to apply to the channel.
rtc_region: Optional[:class:`VoiceRegion`]
The new region for the stage channel's voice communication.
@ -1248,9 +1301,18 @@ class StageChannel(VocalGuildChannel):
You do not have permissions to edit the channel.
HTTPException
Editing the channel failed.
Returns
--------
Optional[:class:`.StageChannel`]
The newly edited stage channel. If the edit was only positional
then ``None`` is returned instead.
"""
await self._edit(options, reason=reason)
payload = await self._edit(options, reason=reason)
if payload is not None:
# the payload will always be the proper channel payload
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
class CategoryChannel(discord.abc.GuildChannel, Hashable):
@ -1276,6 +1338,10 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
Returns the category's name.
.. describe:: int(x)
Returns the category's ID.
Attributes
-----------
name: :class:`str`
@ -1337,13 +1403,13 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
name: str = ...,
position: int = ...,
nsfw: bool = ...,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
reason: Optional[str] = ...,
) -> None:
) -> Optional[CategoryChannel]:
...
@overload
async def edit(self) -> None:
async def edit(self) -> Optional[CategoryChannel]:
...
async def edit(self, *, reason=None, **options):
@ -1357,6 +1423,9 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
.. versionchanged:: 1.3
The ``overwrites`` keyword-only parameter was added.
.. versionchanged:: 2.0
Edits are no longer in-place, the newly edited channel is returned instead.
Parameters
----------
name: :class:`str`
@ -1367,8 +1436,8 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
To mark the category as NSFW or not.
reason: Optional[:class:`str`]
The reason for editing this category. Shows up on the audit log.
overwrites: :class:`dict`
A :class:`dict` of target (either a role or a member) to
overwrites: :class:`Mapping`
A :class:`Mapping` of target (either a role or a member) to
:class:`PermissionOverwrite` to apply to the channel.
Raises
@ -1379,9 +1448,18 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
You do not have permissions to edit the category.
HTTPException
Editing the category failed.
Returns
--------
Optional[:class:`.CategoryChannel`]
The newly edited category channel. If the edit was only positional
then ``None`` is returned instead.
"""
await self._edit(options=options, reason=reason)
payload = await self._edit(options, reason=reason)
if payload is not None:
# the payload will always be the proper channel payload
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
@utils.copy_doc(discord.abc.GuildChannel.move)
async def move(self, **kwargs):
@ -1486,6 +1564,10 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
Returns the channel's name.
.. describe:: int(x)
Returns the channel's ID.
Attributes
-----------
name: :class:`str`
@ -1570,12 +1652,12 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
sync_permissions: bool = ...,
category: Optional[CategoryChannel],
reason: Optional[str],
overwrites: Dict[Union[Role, Member], PermissionOverwrite],
) -> None:
overwrites: Mapping[Union[Role, Member], PermissionOverwrite],
) -> Optional[StoreChannel]:
...
@overload
async def edit(self) -> None:
async def edit(self) -> Optional[StoreChannel]:
...
async def edit(self, *, reason=None, **options):
@ -1586,6 +1668,9 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
You must have the :attr:`~Permissions.manage_channels` permission to
use this.
.. versionchanged:: 2.0
Edits are no longer in-place, the newly edited channel is returned instead.
Parameters
----------
name: :class:`str`
@ -1602,8 +1687,8 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
category.
reason: Optional[:class:`str`]
The reason for editing this channel. Shows up on the audit log.
overwrites: :class:`dict`
A :class:`dict` of target (either a role or a member) to
overwrites: :class:`Mapping`
A :class:`Mapping` of target (either a role or a member) to
:class:`PermissionOverwrite` to apply to the channel.
.. versionadded:: 1.3
@ -1617,8 +1702,18 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
You do not have permissions to edit the channel.
HTTPException
Editing the channel failed.
Returns
--------
Optional[:class:`.StoreChannel`]
The newly edited store channel. If the edit was only positional
then ``None`` is returned instead.
"""
await self._edit(options, reason=reason)
payload = await self._edit(options, reason=reason)
if payload is not None:
# the payload will always be the proper channel payload
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
DMC = TypeVar('DMC', bound='DMChannel')
@ -1645,6 +1740,10 @@ class DMChannel(discord.abc.Messageable, Hashable):
Returns a string representation of the channel
.. describe:: int(x)
Returns the channel's ID.
Attributes
----------
recipient: Optional[:class:`User`]
@ -1682,6 +1781,7 @@ class DMChannel(discord.abc.Messageable, Hashable):
self._state = state
self.id = channel_id
self.recipient = None
# state.user won't be None here
self.me = state.user # type: ignore
return self
@ -1770,6 +1870,10 @@ class GroupChannel(discord.abc.Messageable, Hashable):
Returns a string representation of the channel
.. describe:: int(x)
Returns the channel's ID.
Attributes
----------
recipients: List[:class:`User`]
@ -1892,6 +1996,73 @@ class GroupChannel(discord.abc.Messageable, Hashable):
await self._state.http.leave_group(self.id)
class PartialMessageable(discord.abc.Messageable, Hashable):
"""Represents a partial messageable to aid with working messageable channels when
only a channel ID are present.
The only way to construct this class is through :meth:`Client.get_partial_messageable`.
Note that this class is trimmed down and has no rich attributes.
.. versionadded:: 2.0
.. container:: operations
.. describe:: x == y
Checks if two partial messageables are equal.
.. describe:: x != y
Checks if two partial messageables are not equal.
.. describe:: hash(x)
Returns the partial messageable's hash.
.. describe:: int(x)
Returns the messageable's ID.
Attributes
-----------
id: :class:`int`
The channel ID associated with this partial messageable.
type: Optional[:class:`ChannelType`]
The channel type associated with this partial messageable, if given.
"""
def __init__(self, state: ConnectionState, id: int, type: Optional[ChannelType] = None):
self._state: ConnectionState = state
self._channel: Object = Object(id=id)
self.id: int = id
self.type: Optional[ChannelType] = type
async def _get_channel(self) -> Object:
return self._channel
def get_partial_message(self, message_id: int, /) -> PartialMessage:
"""Creates a :class:`PartialMessage` from the message ID.
This is useful if you want to work with a message and only have its ID without
doing an unnecessary API call.
Parameters
------------
message_id: :class:`int`
The message ID to create a partial message for.
Returns
---------
:class:`PartialMessage`
The partial message.
"""
from .message import PartialMessage
return PartialMessage(channel=self, id=message_id)
def _guild_channel_factory(channel_type: int):
value = try_enum(ChannelType, channel_type)
if value is ChannelType.text:
@ -1919,8 +2090,16 @@ def _channel_factory(channel_type: int):
else:
return cls, value
def _threaded_channel_factory(channel_type: int):
cls, value = _channel_factory(channel_type)
if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread):
return Thread, value
return cls, value
def _threaded_guild_channel_factory(channel_type: int):
cls, value = _guild_channel_factory(channel_type)
if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread):
return Thread, value
return cls, value

View File

@ -29,17 +29,17 @@ import logging
import signal
import sys
import traceback
from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
import aiohttp
from .user import User
from .user import User, ClientUser
from .invite import Invite
from .template import Template
from .widget import Widget
from .guild import Guild
from .emoji import Emoji
from .channel import _threaded_channel_factory
from .channel import _threaded_channel_factory, PartialMessageable
from .enums import ChannelType
from .mentions import AllowedMentions
from .errors import *
@ -60,11 +60,11 @@ from .appinfo import AppInfo
from .ui.view import View
from .stage_instance import StageInstance
from .threads import Thread
from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory
if TYPE_CHECKING:
from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake
from .channel import DMChannel
from .user import ClientUser
from .message import Message
from .member import Member
from .voice_client import VoiceProtocol
@ -76,7 +76,7 @@ __all__ = (
Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
log: logging.Logger = logging.getLogger(__name__)
_log = logging.getLogger(__name__)
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
@ -84,12 +84,12 @@ def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
if not tasks:
return
log.info('Cleaning up after %d tasks.', len(tasks))
_log.info('Cleaning up after %d tasks.', len(tasks))
for task in tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
log.info('All tasks finished cancelling.')
_log.info('All tasks finished cancelling.')
for task in tasks:
if task.cancelled():
@ -106,7 +106,7 @@ def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
_cancel_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
log.info('Closing the event loop.')
_log.info('Closing the event loop.')
loop.close()
class Client:
@ -142,7 +142,6 @@ class Client:
intents: :class:`Intents`
The intents that you want to enable for the session. This is a way of
disabling and enabling certain gateway events from triggering and being sent.
If not given, defaults to a regularly constructed :class:`Intents` class.
.. versionadded:: 1.5
member_cache_flags: :class:`MemberCacheFlags`
@ -184,6 +183,14 @@ class Client:
sync your system clock to Google's NTP server.
.. versionadded:: 1.3
enable_debug_events: :class:`bool`
Whether to enable events that are useful only for debugging gateway related information.
Right now this involves :func:`on_socket_raw_receive` and :func:`on_socket_raw_send`. If
this is ``False`` then those events will not be dispatched (due to performance considerations).
To enable these events, this must be set to ``True``. Defaults to ``False``.
.. versionadded:: 2.0
Attributes
-----------
@ -195,9 +202,13 @@ class Client:
def __init__(
self,
*,
intents: Intents,
loop: Optional[asyncio.AbstractEventLoop] = None,
**options: Any,
):
options["intents"] = intents
# self.ws is set in the connect method
self.ws: DiscordWebSocket = None # type: ignore
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
@ -218,6 +229,7 @@ class Client:
'before_identify': self._call_before_identify_hook
}
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
self._connection: ConnectionState = self._get_state(**options)
self._connection.shard_count = self.shard_count
self._closed: bool = False
@ -227,7 +239,7 @@ class Client:
if VoiceClient.warn_nacl:
VoiceClient.warn_nacl = False
log.warning("PyNaCl is not installed, voice will NOT be supported")
_log.warning("PyNaCl is not installed, voice will NOT be supported")
# internals
@ -277,6 +289,14 @@ class Client:
"""List[:class:`.Emoji`]: The emojis that the connected client has."""
return self._connection.emojis
@property
def stickers(self) -> List[GuildSticker]:
"""List[:class:`.GuildSticker`]: The stickers that the connected client has.
.. versionadded:: 2.0
"""
return self._connection.stickers
@property
def cached_messages(self) -> Sequence[Message]:
"""Sequence[:class:`.Message`]: Read-only list of messages the connected client has cached.
@ -311,6 +331,8 @@ class Client:
If this is not passed via ``__init__`` then this is retrieved
through the gateway when an event contains the data. Usually
after :func:`~discord.on_connect` is called.
.. versionadded:: 2.0
"""
return self._connection.application_id
@ -318,7 +340,7 @@ class Client:
def application_flags(self) -> ApplicationFlags:
""":class:`~discord.ApplicationFlags`: The client's application flags.
.. versionadded: 2.0
.. versionadded:: 2.0
"""
return self._connection.application_flags # type: ignore
@ -343,7 +365,7 @@ class Client:
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
log.debug('Dispatching event %s', event)
_log.debug('Dispatching event %s', event)
method = 'on_' + event
listeners = self._listeners.get(event)
@ -448,8 +470,10 @@ class Client:
passing status code.
"""
log.info('logging in using static token')
await self.http.static_login(token.strip())
_log.info('logging in using static token')
data = await self.http.static_login(token.strip())
self._connection.user = ClientUser(state=self._connection, data=data)
async def connect(self, *, reconnect: bool = True) -> None:
"""|coro|
@ -489,7 +513,7 @@ class Client:
while True:
await self.ws.poll_event()
except ReconnectWebSocket as e:
log.info('Got a request to %s the websocket.', e.op)
_log.info('Got a request to %s the websocket.', e.op)
self.dispatch('disconnect')
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
continue
@ -528,7 +552,7 @@ class Client:
raise
retry = backoff.delay()
log.exception("Attempting a reconnect in %.2fs", retry)
_log.exception("Attempting a reconnect in %.2fs", retry)
await asyncio.sleep(retry)
# Always try to RESUME the connection
# If the connection is not RESUME-able then the gateway will invalidate the session.
@ -630,10 +654,10 @@ class Client:
try:
loop.run_forever()
except KeyboardInterrupt:
log.info('Received signal to terminate bot and event loop.')
_log.info('Received signal to terminate bot and event loop.')
finally:
future.remove_done_callback(stop_loop_on_completion)
log.info('Cleaning up tasks.')
_log.info('Cleaning up tasks.')
_cleanup_loop(loop)
if not future.cancelled():
@ -661,9 +685,30 @@ class Client:
if value is None:
self._connection._activity = None
elif isinstance(value, BaseActivity):
self._connection._activity = value.to_dict()
# ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any]
self._connection._activity = value.to_dict() # type: ignore
else:
raise TypeError('activity must derive from BaseActivity.')
@property
def status(self):
""":class:`.Status`:
The status being used upon logging on to Discord.
.. versionadded: 2.0
"""
if self._connection._status in set(state.value for state in Status):
return Status(self._connection._status)
return Status.online
@status.setter
def status(self, value):
if value is Status.offline:
self._connection._status = 'invisible'
elif isinstance(value, Status):
self._connection._status = str(value)
else:
raise TypeError('status must derive from Status.')
@property
def allowed_mentions(self) -> Optional[AllowedMentions]:
@ -695,8 +740,8 @@ class Client:
"""List[:class:`~discord.User`]: Returns a list of all the users the bot can see."""
return list(self._connection._users.values())
def get_channel(self, id: int) -> Optional[Union[GuildChannel, PrivateChannel]]:
"""Returns a channel with the given ID.
def get_channel(self, id: int, /) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]:
"""Returns a channel or thread with the given ID.
Parameters
-----------
@ -705,12 +750,34 @@ class Client:
Returns
--------
Optional[Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`]]
Optional[Union[:class:`.abc.GuildChannel`, :class:`.Thread`, :class:`.abc.PrivateChannel`]]
The returned channel or ``None`` if not found.
"""
return self._connection.get_channel(id)
def get_stage_instance(self, id) -> Optional[StageInstance]:
def get_partial_messageable(self, id: int, *, type: Optional[ChannelType] = None) -> PartialMessageable:
"""Returns a partial messageable with the given channel ID.
This is useful if you have a channel_id but don't want to do an API call
to send messages to it.
.. versionadded:: 2.0
Parameters
-----------
id: :class:`int`
The channel ID to create a partial messageable for.
type: Optional[:class:`.ChannelType`]
The underlying channel type for the partial messageable.
Returns
--------
:class:`.PartialMessageable`
The partial messageable
"""
return PartialMessageable(state=self._connection, id=id, type=type)
def get_stage_instance(self, id: int, /) -> Optional[StageInstance]:
"""Returns a stage instance with the given stage channel ID.
.. versionadded:: 2.0
@ -732,7 +799,7 @@ class Client:
if isinstance(channel, StageChannel):
return channel.instance
def get_guild(self, id) -> Optional[Guild]:
def get_guild(self, id: int, /) -> Optional[Guild]:
"""Returns a guild with the given ID.
Parameters
@ -747,7 +814,7 @@ class Client:
"""
return self._connection._get_guild(id)
def get_user(self, id) -> Optional[User]:
def get_user(self, id: int, /) -> Optional[User]:
"""Returns a user with the given ID.
Parameters
@ -762,7 +829,7 @@ class Client:
"""
return self._connection.get_user(id)
def get_emoji(self, id) -> Optional[Emoji]:
def get_emoji(self, id: int, /) -> Optional[Emoji]:
"""Returns an emoji with the given ID.
Parameters
@ -777,6 +844,23 @@ class Client:
"""
return self._connection.get_emoji(id)
def get_sticker(self, id: int, /) -> Optional[GuildSticker]:
"""Returns a guild sticker with the given ID.
.. versionadded:: 2.0
.. note::
To retrieve standard stickers, use :meth:`.fetch_sticker`.
or :meth:`.fetch_premium_sticker_packs`.
Returns
--------
Optional[:class:`.GuildSticker`]
The sticker or ``None`` if not found.
"""
return self._connection.get_sticker(id)
def get_all_channels(self) -> Generator[GuildChannel, None, None]:
"""A generator that retrieves every :class:`.abc.GuildChannel` the client can 'access'.
@ -959,7 +1043,7 @@ class Client:
raise TypeError('event registered must be a coroutine function')
setattr(self, coro.__name__, coro)
log.debug('%s has successfully been registered as an event', coro.__name__)
_log.debug('%s has successfully been registered as an event', coro.__name__)
return coro
async def change_presence(
@ -1109,7 +1193,7 @@ class Client:
data = await self.http.get_template(code)
return Template(data=data, state=self._connection) # type: ignore
async def fetch_guild(self, guild_id: int) -> Guild:
async def fetch_guild(self, guild_id: int, /) -> Guild:
"""|coro|
Retrieves a :class:`.Guild` from an ID.
@ -1198,7 +1282,7 @@ class Client:
data = await self.http.create_guild(name, region_value, icon_base64)
return Guild(data=data, state=self._connection)
async def fetch_stage_instance(self, channel_id: int) -> StageInstance:
async def fetch_stage_instance(self, channel_id: int, /) -> StageInstance:
"""|coro|
Gets a :class:`.StageInstance` for a stage channel id.
@ -1298,7 +1382,7 @@ class Client:
# Miscellaneous stuff
async def fetch_widget(self, guild_id: int) -> Widget:
async def fetch_widget(self, guild_id: int, /) -> Widget:
"""|coro|
Gets a :class:`.Widget` from a guild ID.
@ -1348,7 +1432,7 @@ class Client:
data['rpc_origins'] = None
return AppInfo(self._connection, data)
async def fetch_user(self, user_id: int) -> User:
async def fetch_user(self, user_id: int, /) -> User:
"""|coro|
Retrieves a :class:`~discord.User` based on their ID.
@ -1379,7 +1463,7 @@ class Client:
data = await self.http.get_user(user_id)
return User(state=self._connection, data=data)
async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel, Thread]:
async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, PrivateChannel, Thread]:
"""|coro|
Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID.
@ -1413,15 +1497,18 @@ class Client:
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))
if ch_type in (ChannelType.group, ChannelType.private):
channel = factory(me=self.user, data=data, state=self._connection)
# the factory will be a DMChannel or GroupChannel here
channel = factory(me=self.user, data=data, state=self._connection) # type: ignore
else:
guild_id = int(data['guild_id'])
# the factory can't be a DMChannel or GroupChannel here
guild_id = int(data['guild_id']) # type: ignore
guild = self.get_guild(guild_id) or Object(id=guild_id)
channel = factory(guild=guild, state=self._connection, data=data)
# GuildChannels expect a Guild, we may be passing an Object
channel = factory(guild=guild, state=self._connection, data=data) # type: ignore
return channel
async def fetch_webhook(self, webhook_id: int) -> Webhook:
async def fetch_webhook(self, webhook_id: int, /) -> Webhook:
"""|coro|
Retrieves a :class:`.Webhook` with the specified ID.
@ -1443,6 +1530,49 @@ class Client:
data = await self.http.get_webhook(webhook_id)
return Webhook.from_state(data, state=self._connection)
async def fetch_sticker(self, sticker_id: int, /) -> Union[StandardSticker, GuildSticker]:
"""|coro|
Retrieves a :class:`.Sticker` with the specified ID.
.. versionadded:: 2.0
Raises
--------
:exc:`.HTTPException`
Retrieving the sticker failed.
:exc:`.NotFound`
Invalid sticker ID.
Returns
--------
Union[:class:`.StandardSticker`, :class:`.GuildSticker`]
The sticker you requested.
"""
data = await self.http.get_sticker(sticker_id)
cls, _ = _sticker_factory(data['type']) # type: ignore
return cls(state=self._connection, data=data) # type: ignore
async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
"""|coro|
Retrieves all available premium sticker packs.
.. versionadded:: 2.0
Raises
-------
:exc:`.HTTPException`
Retrieving the sticker packs failed.
Returns
---------
List[:class:`.StickerPack`]
All available premium sticker packs.
"""
data = await self.http.list_premium_sticker_packs()
return [StickerPack(state=self._connection, data=pack) for pack in data['sticker_packs']]
async def create_dm(self, user: Snowflake) -> DMChannel:
"""|coro|
@ -1476,6 +1606,8 @@ class Client:
This method should be used for when a view is comprised of components
that last longer than the lifecycle of the program.
.. versionadded:: 2.0
Parameters
------------
@ -1505,5 +1637,8 @@ class Client:
@property
def persistent_views(self) -> Sequence[View]:
"""Sequence[:class:`.View`]: A sequence of persistent views added to the client."""
"""Sequence[:class:`.View`]: A sequence of persistent views added to the client.
.. versionadded:: 2.0
"""
return self._connection.persistent_views

View File

@ -78,7 +78,7 @@ class Colour:
__slots__ = ('value',)
def __init__(self, value):
def __init__(self, value: int):
if not isinstance(value, int):
raise TypeError(f'Expected int parameter, received {value.__class__.__name__} instead.')
@ -171,6 +171,14 @@ class Colour:
"""A factory method that returns a :class:`Colour` with a value of ``0x11806a``."""
return cls(0x11806a)
@classmethod
def brand_green(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x57F287``.
.. versionadded:: 2.0
"""
return cls(0x57F287)
@classmethod
def green(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``."""
@ -231,6 +239,14 @@ class Colour:
"""A factory method that returns a :class:`Colour` with a value of ``0xa84300``."""
return cls(0xa84300)
@classmethod
def brand_red(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xED4245``.
.. versionadded:: 2.0
"""
return cls(0xED4245)
@classmethod
def red(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``."""
@ -308,6 +324,15 @@ class Colour:
.. versionadded:: 2.0
"""
return cls(0xFEE75C)
@classmethod
def dark_blurple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x4E5D94``.
This is the original Dark Blurple branding.
.. versionadded:: 2.0
"""
return cls(0x4E5D94)
Color = Colour

View File

@ -226,6 +226,8 @@ class SelectMenu(Component):
Defaults to 1 and must be between 1 and 25.
options: List[:class:`SelectOption`]
A list of options that can be selected in this menu.
disabled: :class:`bool`
Whether the select is disabled or not.
"""
__slots__: Tuple[str, ...] = (
@ -234,6 +236,7 @@ class SelectMenu(Component):
'min_values',
'max_values',
'options',
'disabled',
)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
@ -245,6 +248,7 @@ class SelectMenu(Component):
self.min_values: int = data.get('min_values', 1)
self.max_values: int = data.get('max_values', 1)
self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])]
self.disabled: bool = data.get('disabled', False)
def to_dict(self) -> SelectMenuPayload:
payload: SelectMenuPayload = {
@ -253,6 +257,7 @@ class SelectMenu(Component):
'min_values': self.min_values,
'max_values': self.max_values,
'options': [op.to_dict() for op in self.options],
'disabled': self.disabled,
}
if self.placeholder:
@ -272,14 +277,14 @@ class SelectOption:
-----------
label: :class:`str`
The label of the option. This is displayed to users.
Can only be up to 25 characters.
Can only be up to 100 characters.
value: :class:`str`
The value of the option. This is not displayed to users.
If not provided when constructed then it defaults to the
label. Can only be up to 100 characters.
description: Optional[:class:`str`]
An additional description of the option, if any.
Can only be up to 50 characters.
Can only be up to 100 characters.
emoji: Optional[Union[:class:`str`, :class:`Emoji`, :class:`PartialEmoji`]]
The emoji of the option, if available.
default: :class:`bool`

View File

@ -22,13 +22,23 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING, TypeVar, Optional, Type
if TYPE_CHECKING:
from .abc import Messageable
from types import TracebackType
TypingT = TypeVar('TypingT', bound='Typing')
__all__ = (
'Typing',
)
def _typing_done_callback(fut):
def _typing_done_callback(fut: asyncio.Future) -> None:
# just retrieve any exception and call it a day
try:
fut.exception()
@ -36,11 +46,11 @@ def _typing_done_callback(fut):
pass
class Typing:
def __init__(self, messageable):
self.loop = messageable._state.loop
self.messageable = messageable
def __init__(self, messageable: Messageable) -> None:
self.loop: asyncio.AbstractEventLoop = messageable._state.loop
self.messageable: Messageable = messageable
async def do_typing(self):
async def do_typing(self) -> None:
try:
channel = self._channel
except AttributeError:
@ -52,18 +62,26 @@ class Typing:
await typing(channel.id)
await asyncio.sleep(5)
def __enter__(self):
self.task = asyncio.ensure_future(self.do_typing(), loop=self.loop)
def __enter__(self: TypingT) -> TypingT:
self.task: asyncio.Task = self.loop.create_task(self.do_typing())
self.task.add_done_callback(_typing_done_callback)
return self
def __exit__(self, exc_type, exc, tb):
def __exit__(self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.task.cancel()
async def __aenter__(self):
async def __aenter__(self: TypingT) -> TypingT:
self._channel = channel = await self.messageable._get_channel()
await channel._state.http.send_typing(channel.id)
return self.__enter__()
async def __aexit__(self, exc_type, exc, tb):
async def __aexit__(self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.task.cancel()

View File

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import datetime
from typing import Any, Dict, Final, List, Protocol, TYPE_CHECKING, Type, TypeVar, Union
from typing import Any, Dict, Final, List, Mapping, Protocol, TYPE_CHECKING, Type, TypeVar, Union
from . import utils
from .colour import Colour
@ -72,30 +72,36 @@ if TYPE_CHECKING:
T = TypeVar('T')
MaybeEmpty = Union[T, _EmptyEmbed]
class _EmbedFooterProxy(Protocol):
text: MaybeEmpty[str]
icon_url: MaybeEmpty[str]
class _EmbedFieldProxy(Protocol):
name: MaybeEmpty[str]
value: MaybeEmpty[str]
inline: bool
class _EmbedMediaProxy(Protocol):
url: MaybeEmpty[str]
proxy_url: MaybeEmpty[str]
height: MaybeEmpty[int]
width: MaybeEmpty[int]
class _EmbedVideoProxy(Protocol):
url: MaybeEmpty[str]
height: MaybeEmpty[int]
width: MaybeEmpty[int]
class _EmbedProviderProxy(Protocol):
name: MaybeEmpty[str]
url: MaybeEmpty[str]
class _EmbedAuthorProxy(Protocol):
name: MaybeEmpty[str]
url: MaybeEmpty[str]
@ -175,15 +181,15 @@ class Embed:
Empty: Final = EmptyEmbed
def __init__(
self,
*,
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
title: MaybeEmpty[Any] = EmptyEmbed,
type: EmbedType = 'rich',
url: MaybeEmpty[Any] = EmptyEmbed,
description: MaybeEmpty[Any] = EmptyEmbed,
timestamp: datetime.datetime = None,
self,
*,
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
title: MaybeEmpty[Any] = EmptyEmbed,
type: EmbedType = 'rich',
url: MaybeEmpty[Any] = EmptyEmbed,
description: MaybeEmpty[Any] = EmptyEmbed,
timestamp: datetime.datetime = None,
):
self.colour = colour if colour is not EmptyEmbed else color
@ -205,7 +211,7 @@ class Embed:
self.timestamp = timestamp
@classmethod
def from_dict(cls: Type[E], data: EmbedData) -> E:
def from_dict(cls: Type[E], data: Mapping[str, Any]) -> E:
"""Converts a :class:`dict` to a :class:`Embed` provided it is in the
format that Discord expects it to be in.
@ -366,7 +372,7 @@ class Embed:
self._footer['icon_url'] = str(icon_url)
return self
def remove_footer(self: E) -> E:
"""Clears embed's footer information.
@ -381,7 +387,7 @@ class Embed:
pass
return self
@property
def image(self) -> _EmbedMediaProxy:
"""Returns an ``EmbedProxy`` denoting the image contents.
@ -397,6 +403,19 @@ class Embed:
"""
return EmbedProxy(getattr(self, '_image', {})) # type: ignore
@image.setter
def image(self: E, *, url: Any):
self._image = {
'url': str(url),
}
@image.deleter
def image(self: E):
try:
del self._image
except AttributeError:
pass
def set_image(self: E, *, url: MaybeEmpty[Any]) -> E:
"""Sets the image for the embed content.
@ -413,14 +432,9 @@ class Embed:
"""
if url is EmptyEmbed:
try:
del self._image
except AttributeError:
pass
del self.image
else:
self._image = {
'url': str(url),
}
self.image = url
return self
@ -439,7 +453,25 @@ class Embed:
"""
return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]) -> E:
@thumbnail.setter
def thumbnail(self: E, *, url: Any):
"""Sets the thumbnail for the embed content.
"""
self._thumbnail = {
'url': str(url),
}
return
@thumbnail.deleter
def thumbnail(self):
try:
del self.thumbnail
except AttributeError:
pass
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]):
"""Sets the thumbnail for the embed content.
This function returns the class instance to allow for fluent-style
@ -453,16 +485,10 @@ class Embed:
url: :class:`str`
The source URL for the thumbnail. Only HTTP(S) is supported.
"""
if url is EmptyEmbed:
try:
del self._thumbnail
except AttributeError:
pass
del self.thumbnail
else:
self._thumbnail = {
'url': str(url),
}
self.thumbnail = url
return self

View File

@ -72,6 +72,10 @@ class Emoji(_EmojiTag, AssetMixin):
Returns the emoji rendered for discord.
.. describe:: int(x)
Returns the emoji ID.
Attributes
-----------
name: :class:`str`
@ -137,6 +141,9 @@ class Emoji(_EmojiTag, AssetMixin):
return f'<a:{self.name}:{self.id}>'
return f'<:{self.name}:{self.id}>'
def __int__(self) -> int:
return self.id
def __repr__(self) -> str:
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>'
@ -212,7 +219,7 @@ class Emoji(_EmojiTag, AssetMixin):
await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason)
async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> None:
async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> Emoji:
r"""|coro|
Edits the custom emoji.
@ -220,6 +227,9 @@ class Emoji(_EmojiTag, AssetMixin):
You must have :attr:`~Permissions.manage_emojis` permission to
do this.
.. versionchanged:: 2.0
The newly updated emoji is returned.
Parameters
-----------
name: :class:`str`
@ -235,6 +245,11 @@ class Emoji(_EmojiTag, AssetMixin):
You are not allowed to edit emojis.
HTTPException
An error occurred editing the emoji.
Returns
--------
:class:`Emoji`
The newly updated emoji.
"""
payload = {}
@ -243,4 +258,5 @@ class Emoji(_EmojiTag, AssetMixin):
if roles is not MISSING:
payload['roles'] = [role.id for role in roles]
await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason)
data = await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason)
return Emoji(guild=self.guild, data=data, state=self._state)

View File

@ -46,6 +46,7 @@ __all__ = (
'ExpireBehaviour',
'ExpireBehavior',
'StickerType',
'StickerFormatType',
'InviteTarget',
'VideoQualityMode',
'ComponentType',
@ -57,13 +58,17 @@ __all__ = (
)
def _create_value_cls(name):
def _create_value_cls(name, comparable):
cls = namedtuple('_EnumValue_' + name, 'name value')
cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>'
cls.__str__ = lambda self: f'{name}.{self.name}'
if comparable:
cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value
cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value
cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value
return cls
def _is_descriptor(obj):
return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__')
@ -75,12 +80,12 @@ class EnumMeta(type):
_enum_member_map_: ClassVar[Dict[str, Any]]
_enum_value_map_: ClassVar[Dict[Any, Any]]
def __new__(cls, name, bases, attrs):
def __new__(cls, name, bases, attrs, *, comparable: bool = False):
value_mapping = {}
member_mapping = {}
member_names = []
value_cls = _create_value_cls(name)
value_cls = _create_value_cls(name, comparable)
for key, value in list(attrs.items()):
is_descriptor = _is_descriptor(value)
if key[0] == '_' and not is_descriptor:
@ -251,7 +256,7 @@ class SpeakingState(Enum):
return self.value
class VerificationLevel(Enum):
class VerificationLevel(Enum, comparable=True):
none = 0
low = 1
medium = 2
@ -262,7 +267,7 @@ class VerificationLevel(Enum):
return self.name
class ContentFilter(Enum):
class ContentFilter(Enum, comparable=True):
disabled = 0
no_role = 1
all_members = 2
@ -295,7 +300,7 @@ class DefaultAvatar(Enum):
return self.name
class NotificationLevel(Enum):
class NotificationLevel(Enum, comparable=True):
all_messages = 0
only_mentions = 1
@ -346,6 +351,12 @@ class AuditLogAction(Enum):
stage_instance_create = 83
stage_instance_update = 84
stage_instance_delete = 85
sticker_create = 90
sticker_update = 91
sticker_delete = 92
thread_create = 110
thread_update = 111
thread_delete = 112
# fmt: on
@property
@ -390,6 +401,12 @@ class AuditLogAction(Enum):
AuditLogAction.stage_instance_create: AuditLogActionCategory.create,
AuditLogAction.stage_instance_update: AuditLogActionCategory.update,
AuditLogAction.stage_instance_delete: AuditLogActionCategory.delete,
AuditLogAction.sticker_create: AuditLogActionCategory.create,
AuditLogAction.sticker_update: AuditLogActionCategory.update,
AuditLogAction.sticker_delete: AuditLogActionCategory.delete,
AuditLogAction.thread_create: AuditLogActionCategory.create,
AuditLogAction.thread_update: AuditLogActionCategory.update,
AuditLogAction.thread_delete: AuditLogActionCategory.delete,
}
# fmt: on
return lookup[self]
@ -421,6 +438,10 @@ class AuditLogAction(Enum):
return 'integration'
elif v < 90:
return 'stage_instance'
elif v < 93:
return 'sticker'
elif v < 113:
return 'thread'
class UserFlags(Enum):
@ -476,10 +497,26 @@ ExpireBehavior = ExpireBehaviour
class StickerType(Enum):
standard = 1
guild = 2
class StickerFormatType(Enum):
png = 1
apng = 2
lottie = 3
@property
def file_extension(self) -> str:
# fmt: off
lookup: Dict[StickerFormatType, str] = {
StickerFormatType.png: 'png',
StickerFormatType.apng: 'png',
StickerFormatType.lottie: 'json',
}
# fmt: on
return lookup[self]
class InviteTarget(Enum):
unknown = 0
@ -545,7 +582,7 @@ class StagePrivacyLevel(Enum):
guild_only = 2
class NSFWLevel(Enum):
class NSFWLevel(Enum, comparable=True):
default = 0
explicit = 1
safe = 2

View File

@ -31,9 +31,9 @@ if TYPE_CHECKING:
try:
from requests import Response
ResponseType = Union[ClientResponse, Response]
_ResponseType = Union[ClientResponse, Response]
except ModuleNotFoundError:
ResponseType = ClientResponse
_ResponseType = ClientResponse
from .interactions import Interaction
@ -123,8 +123,8 @@ class HTTPException(DiscordException):
The Discord specific error code for the failure.
"""
def __init__(self, response: ResponseType, message: Optional[Union[str, Dict[str, Any]]]):
self.response: ResponseType = response
def __init__(self, response: _ResponseType, message: Optional[Union[str, Dict[str, Any]]]):
self.response: _ResponseType = response
self.status: int = response.status # type: ignore
self.code: int
self.text: str

View File

@ -22,6 +22,26 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union
if TYPE_CHECKING:
from .context import Context
from .cog import Cog
from .errors import CommandError
T = TypeVar('T')
Coro = Coroutine[Any, Any, T]
MaybeCoro = Union[T, Coro[T]]
CoroFunc = Callable[..., Coro[Any]]
Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]]
# This is merely a tag type to avoid circular import issues.
# Yes, this is a terrible solution but ultimately it is the only solution.
class _BaseCommand:

View File

@ -22,13 +22,18 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
import collections
import collections.abc
import inspect
import importlib.util
import sys
import traceback
import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
import discord
@ -39,6 +44,15 @@ from . import errors
from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog
if TYPE_CHECKING:
import importlib.machinery
from discord.message import Message
from ._types import (
Check,
CoroFunc,
)
__all__ = (
'when_mentioned',
'when_mentioned_or',
@ -46,14 +60,21 @@ __all__ = (
'AutoShardedBot',
)
def when_mentioned(bot, msg):
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context')
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
"""
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> ']
# bot.user will never be None when this is called
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
def when_mentioned_or(*prefixes):
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@ -89,7 +110,7 @@ def when_mentioned_or(*prefixes):
return inner
def _is_submodule(parent, child):
def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".")
class _DefaultRepr:
@ -99,13 +120,13 @@ class _DefaultRepr:
_default = _DefaultRepr()
class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, **options):
super().__init__(**options)
def __init__(self, command_prefix, help_command=_default, description=None, *, intents: discord.Intents, **options):
super().__init__(**options, intents=intents)
self.command_prefix = command_prefix
self.extra_events = {}
self.__cogs = {}
self.__extensions = {}
self._checks = []
self.extra_events: Dict[str, List[CoroFunc]] = {}
self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {}
self._checks: List[Check] = []
self._check_once = []
self._before_invoke = None
self._after_invoke = None
@ -128,13 +149,15 @@ class BotBase(GroupMixin):
# internal helpers
def dispatch(self, event_name, *args, **kwargs):
super().dispatch(event_name, *args, **kwargs)
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
# super() will resolve to Client
super().dispatch(event_name, *args, **kwargs) # type: ignore
ev = 'on_' + event_name
for event in self.extra_events.get(ev, []):
self._schedule_event(event, ev, *args, **kwargs)
self._schedule_event(event, ev, *args, **kwargs) # type: ignore
async def close(self):
@discord.utils.copy_doc(discord.Client.close)
async def close(self) -> None:
for extension in tuple(self.__extensions):
try:
self.unload_extension(extension)
@ -147,9 +170,9 @@ class BotBase(GroupMixin):
except Exception:
pass
await super().close()
await super().close() # type: ignore
async def on_command_error(self, context, exception):
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
"""|coro|
The default command error handler provided by the bot.
@ -175,7 +198,7 @@ class BotBase(GroupMixin):
# global check registration
def check(self, func):
def check(self, func: T) -> T:
r"""A decorator that adds a global check to the bot.
A global check is similar to a :func:`.check` that is applied
@ -200,10 +223,11 @@ class BotBase(GroupMixin):
return ctx.command.qualified_name in allowed_commands
"""
self.add_check(func)
# T was used instead of Check to ensure the type matches on return
self.add_check(func) # type: ignore
return func
def add_check(self, func, *, call_once=False):
def add_check(self, func: Check, *, call_once: bool = False) -> None:
"""Adds a global check to the bot.
This is the non-decorator interface to :meth:`.check`
@ -223,7 +247,7 @@ class BotBase(GroupMixin):
else:
self._checks.append(func)
def remove_check(self, func, *, call_once=False):
def remove_check(self, func: Check, *, call_once: bool = False) -> None:
"""Removes a global check from the bot.
This function is idempotent and will not raise an exception
@ -244,7 +268,7 @@ class BotBase(GroupMixin):
except ValueError:
pass
def check_once(self, func):
def check_once(self, func: CFT) -> CFT:
r"""A decorator that adds a "call once" global check to the bot.
Unlike regular global checks, this one is called only once
@ -282,15 +306,16 @@ class BotBase(GroupMixin):
self.add_check(func, call_once=True)
return func
async def can_run(self, ctx, *, call_once=False):
async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool:
data = self._check_once if call_once else self._checks
if len(data) == 0:
return True
return await discord.utils.async_all(f(ctx) for f in data)
# type-checker doesn't distinguish between functions and methods
return await discord.utils.async_all(f(ctx) for f in data) # type: ignore
async def is_owner(self, user):
async def is_owner(self, user: discord.User) -> bool:
"""|coro|
Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of
@ -319,7 +344,8 @@ class BotBase(GroupMixin):
elif self.owner_ids:
return user.id in self.owner_ids
else:
app = await self.application_info()
app = await self.application_info() # type: ignore
if app.team:
self.owner_ids = ids = {m.id for m in app.team.members}
return user.id in ids
@ -327,7 +353,7 @@ class BotBase(GroupMixin):
self.owner_id = owner_id = app.owner.id
return user.id == owner_id
def before_invoke(self, coro):
def before_invoke(self, coro: CFT) -> CFT:
"""A decorator that registers a coroutine as a pre-invoke hook.
A pre-invoke hook is called directly before the command is
@ -359,7 +385,7 @@ class BotBase(GroupMixin):
self._before_invoke = coro
return coro
def after_invoke(self, coro):
def after_invoke(self, coro: CFT) -> CFT:
r"""A decorator that registers a coroutine as a post-invoke hook.
A post-invoke hook is called directly after the command is
@ -394,14 +420,14 @@ class BotBase(GroupMixin):
# listener registration
def add_listener(self, func, name=None):
def add_listener(self, func: CoroFunc, name: str = MISSING) -> None:
"""The non decorator alternative to :meth:`.listen`.
Parameters
-----------
func: :ref:`coroutine <coroutine>`
The function to call.
name: Optional[:class:`str`]
name: :class:`str`
The name of the event to listen for. Defaults to ``func.__name__``.
Example
@ -416,7 +442,7 @@ class BotBase(GroupMixin):
bot.add_listener(my_message, 'on_message')
"""
name = func.__name__ if name is None else name
name = func.__name__ if name is MISSING else name
if not asyncio.iscoroutinefunction(func):
raise TypeError('Listeners must be coroutines')
@ -426,7 +452,7 @@ class BotBase(GroupMixin):
else:
self.extra_events[name] = [func]
def remove_listener(self, func, name=None):
def remove_listener(self, func: CoroFunc, name: str = MISSING) -> None:
"""Removes a listener from the pool of listeners.
Parameters
@ -438,7 +464,7 @@ class BotBase(GroupMixin):
``func.__name__``.
"""
name = func.__name__ if name is None else name
name = func.__name__ if name is MISSING else name
if name in self.extra_events:
try:
@ -446,7 +472,7 @@ class BotBase(GroupMixin):
except ValueError:
pass
def listen(self, name=None):
def listen(self, name: str = MISSING) -> Callable[[CFT], CFT]:
"""A decorator that registers another function as an external
event listener. Basically this allows you to listen to multiple
events from different places e.g. such as :func:`.on_ready`
@ -476,7 +502,7 @@ class BotBase(GroupMixin):
The function being listened to is not a coroutine.
"""
def decorator(func):
def decorator(func: CFT) -> CFT:
self.add_listener(func, name)
return func
@ -528,7 +554,7 @@ class BotBase(GroupMixin):
cog = cog._inject(self)
self.__cogs[cog_name] = cog
def get_cog(self, name):
def get_cog(self, name: str) -> Optional[Cog]:
"""Gets the cog instance requested.
If the cog is not found, ``None`` is returned instead.
@ -547,8 +573,8 @@ class BotBase(GroupMixin):
"""
return self.__cogs.get(name)
def remove_cog(self, name):
"""Removes a cog from the bot.
def remove_cog(self, name: str) -> Optional[Cog]:
"""Removes a cog from the bot and returns it.
All registered commands and event listeners that the
cog has registered will be removed as well.
@ -559,6 +585,11 @@ class BotBase(GroupMixin):
-----------
name: :class:`str`
The name of the cog to remove.
Returns
-------
Optional[:class:`.Cog`]
The cog that was removed. ``None`` if not found.
"""
cog = self.__cogs.pop(name, None)
@ -570,14 +601,16 @@ class BotBase(GroupMixin):
help_command.cog = None
cog._eject(self)
return cog
@property
def cogs(self):
def cogs(self) -> Mapping[str, Cog]:
"""Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog."""
return types.MappingProxyType(self.__cogs)
# extensions
def _remove_module_references(self, name):
def _remove_module_references(self, name: str) -> None:
# find all references to the module
# remove the cogs registered from the module
for cogname, cog in self.__cogs.copy().items():
@ -601,7 +634,7 @@ class BotBase(GroupMixin):
for index in reversed(remove):
del event_list[index]
def _call_module_finalizers(self, lib, key):
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
try:
func = getattr(lib, 'teardown')
except AttributeError:
@ -619,12 +652,12 @@ class BotBase(GroupMixin):
if _is_submodule(name, module):
del sys.modules[module]
def _load_from_module_spec(self, spec, key):
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
# precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec)
sys.modules[key] = lib
try:
spec.loader.exec_module(lib)
spec.loader.exec_module(lib) # type: ignore
except Exception as e:
del sys.modules[key]
raise errors.ExtensionFailed(key, e) from e
@ -645,13 +678,13 @@ class BotBase(GroupMixin):
else:
self.__extensions[key] = lib
def _resolve_name(self, name, package):
def _resolve_name(self, name: str, package: Optional[str]) -> str:
try:
return importlib.util.resolve_name(name, package)
except ImportError:
raise errors.ExtensionNotFound(name)
def load_extension(self, name, *, package=None):
def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Loads an extension.
An extension is a python module that contains commands, cogs, or
@ -698,7 +731,7 @@ class BotBase(GroupMixin):
self._load_from_module_spec(spec, name)
def unload_extension(self, name, *, package=None):
def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Unloads an extension.
When the extension is unloaded, all commands, listeners, and cogs are
@ -739,7 +772,7 @@ class BotBase(GroupMixin):
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name)
def reload_extension(self, name, *, package=None):
def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Atomically reloads an extension.
This replaces the extension with the same extension, only refreshed. This is
@ -795,7 +828,7 @@ class BotBase(GroupMixin):
# if the load failed, the remnants should have been
# cleaned from the load_extension function call
# so let's load it from our old compiled library.
lib.setup(self)
lib.setup(self) # type: ignore
self.__extensions[name] = lib
# revert sys.modules back to normal and raise back to caller
@ -803,18 +836,18 @@ class BotBase(GroupMixin):
raise
@property
def extensions(self):
def extensions(self) -> Mapping[str, types.ModuleType]:
"""Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension."""
return types.MappingProxyType(self.__extensions)
# help command stuff
@property
def help_command(self):
def help_command(self) -> Optional[HelpCommand]:
return self._help_command
@help_command.setter
def help_command(self, value):
def help_command(self, value: Optional[HelpCommand]) -> None:
if value is not None:
if not isinstance(value, HelpCommand):
raise TypeError('help_command must be a subclass of HelpCommand')
@ -830,7 +863,7 @@ class BotBase(GroupMixin):
# command processing
async def get_prefix(self, message):
async def get_prefix(self, message: Message) -> Union[List[str], str]:
"""|coro|
Retrieves the prefix the bot is listening to
@ -868,7 +901,7 @@ class BotBase(GroupMixin):
return ret
async def get_context(self, message, *, cls=Context):
async def get_context(self, message: Message, *, cls: Type[CXT] = Context) -> CXT:
r"""|coro|
Returns the invocation context from the message.
@ -901,7 +934,7 @@ class BotBase(GroupMixin):
view = StringView(message.content)
ctx = cls(prefix=None, view=view, bot=self, message=message)
if message.author.id == self.user.id:
if message.author.id == self.user.id: # type: ignore
return ctx
prefix = await self.get_prefix(message)
@ -938,11 +971,12 @@ class BotBase(GroupMixin):
invoker = view.get_word()
ctx.invoked_with = invoker
ctx.prefix = invoked_prefix
# type-checker fails to narrow invoked_prefix type.
ctx.prefix = invoked_prefix # type: ignore
ctx.command = self.all_commands.get(invoker)
return ctx
async def invoke(self, ctx):
async def invoke(self, ctx: Context) -> None:
"""|coro|
Invokes the command given under the invocation context and
@ -968,7 +1002,7 @@ class BotBase(GroupMixin):
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
self.dispatch('command_error', ctx, exc)
async def process_commands(self, message):
async def process_commands(self, message: Message) -> None:
"""|coro|
This function processes the commands that have been registered

View File

@ -21,15 +21,30 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import inspect
import discord.utils
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type
from ._types import _BaseCommand
if TYPE_CHECKING:
from .bot import BotBase
from .context import Context
from .core import Command
__all__ = (
'CogMeta',
'Cog',
)
CogT = TypeVar('CogT', bound='Cog')
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
MISSING: Any = discord.utils.MISSING
class CogMeta(type):
"""A metaclass for defining a cog.
@ -89,8 +104,12 @@ class CogMeta(type):
async def bar(self, ctx):
pass # hidden -> False
"""
__cog_name__: str
__cog_settings__: Dict[str, Any]
__cog_commands__: List[Command]
__cog_listeners__: List[Tuple[str, str]]
def __new__(cls, *args, **kwargs):
def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
name, bases, attrs = args
attrs['__cog_name__'] = kwargs.pop('name', name)
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
@ -143,14 +162,14 @@ class CogMeta(type):
new_cls.__cog_listeners__ = listeners_as_list
return new_cls
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args)
@classmethod
def qualified_name(cls):
def qualified_name(cls) -> str:
return cls.__cog_name__
def _cog_special_method(func):
def _cog_special_method(func: FuncT) -> FuncT:
func.__cog_special_method__ = None
return func
@ -164,8 +183,12 @@ class Cog(metaclass=CogMeta):
When inheriting from this class, the options shown in :class:`CogMeta`
are equally valid here.
"""
__cog_name__: ClassVar[str]
__cog_settings__: ClassVar[Dict[str, Any]]
__cog_commands__: ClassVar[List[Command]]
__cog_listeners__: ClassVar[List[Tuple[str, str]]]
def __new__(cls, *args, **kwargs):
def __new__(cls: Type[CogT], *args: Any, **kwargs: Any) -> CogT:
# For issue 426, we need to store a copy of the command objects
# since we modify them to inject `self` to them.
# To do this, we need to interfere with the Cog creation process.
@ -173,7 +196,8 @@ class Cog(metaclass=CogMeta):
cmd_attrs = cls.__cog_settings__
# Either update the command with the cog provided defaults or copy it.
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__)
# r.e type ignore, type-checker complains about overriding a ClassVar
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
lookup = {
cmd.qualified_name: cmd
@ -186,15 +210,15 @@ class Cog(metaclass=CogMeta):
parent = command.parent
if parent is not None:
# Get the latest parent reference
parent = lookup[parent.qualified_name]
parent = lookup[parent.qualified_name] # type: ignore
# Update our parent's reference to our self
parent.remove_command(command.name)
parent.add_command(command)
parent.remove_command(command.name) # type: ignore
parent.add_command(command) # type: ignore
return self
def get_commands(self):
def get_commands(self) -> List[Command]:
r"""
Returns
--------
@ -209,20 +233,20 @@ class Cog(metaclass=CogMeta):
return [c for c in self.__cog_commands__ if c.parent is None]
@property
def qualified_name(self):
def qualified_name(self) -> str:
""":class:`str`: Returns the cog's specified name, not the class name."""
return self.__cog_name__
@property
def description(self):
def description(self) -> str:
""":class:`str`: Returns the cog's description, typically the cleaned docstring."""
return self.__cog_description__
@description.setter
def description(self, description):
def description(self, description: str) -> None:
self.__cog_description__ = description
def walk_commands(self):
def walk_commands(self) -> Generator[Command, None, None]:
"""An iterator that recursively walks through this cog's commands and subcommands.
Yields
@ -237,7 +261,7 @@ class Cog(metaclass=CogMeta):
if isinstance(command, GroupMixin):
yield from command.walk_commands()
def get_listeners(self):
def get_listeners(self) -> List[Tuple[str, Callable[..., Any]]]:
"""Returns a :class:`list` of (name, function) listener pairs that are defined in this cog.
Returns
@ -248,12 +272,12 @@ class Cog(metaclass=CogMeta):
return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__]
@classmethod
def _get_overridden_method(cls, method):
def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]:
"""Return None if the method is not overridden. Otherwise returns the overridden method."""
return getattr(method.__func__, '__cog_special_method__', method)
@classmethod
def listener(cls, name=None):
def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]:
"""A decorator that marks a function as a listener.
This is the cog equivalent of :meth:`.Bot.listen`.
@ -271,10 +295,10 @@ class Cog(metaclass=CogMeta):
the name.
"""
if name is not None and not isinstance(name, str):
if name is not MISSING and not isinstance(name, str):
raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.')
def decorator(func):
def decorator(func: FuncT) -> FuncT:
actual = func
if isinstance(actual, staticmethod):
actual = actual.__func__
@ -293,7 +317,7 @@ class Cog(metaclass=CogMeta):
return func
return decorator
def has_error_handler(self):
def has_error_handler(self) -> bool:
""":class:`bool`: Checks whether the cog has an error handler.
.. versionadded:: 1.7
@ -301,7 +325,7 @@ class Cog(metaclass=CogMeta):
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
@_cog_special_method
def cog_unload(self):
def cog_unload(self) -> None:
"""A special method that is called when the cog gets removed.
This function **cannot** be a coroutine. It must be a regular
@ -312,7 +336,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
def bot_check_once(self, ctx):
def bot_check_once(self, ctx: Context) -> bool:
"""A special method that registers as a :meth:`.Bot.check_once`
check.
@ -322,7 +346,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
def bot_check(self, ctx):
def bot_check(self, ctx: Context) -> bool:
"""A special method that registers as a :meth:`.Bot.check`
check.
@ -332,7 +356,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
def cog_check(self, ctx):
def cog_check(self, ctx: Context) -> bool:
"""A special method that registers as a :func:`~discord.ext.commands.check`
for every command and subcommand in this cog.
@ -342,7 +366,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
async def cog_command_error(self, ctx, error):
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
"""A special method that is called whenever an error
is dispatched inside this cog.
@ -361,7 +385,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
async def cog_before_invoke(self, ctx):
async def cog_before_invoke(self, ctx: Context) -> None:
"""A special method that acts as a cog local pre-invoke hook.
This is similar to :meth:`.Command.before_invoke`.
@ -376,7 +400,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
async def cog_after_invoke(self, ctx):
async def cog_after_invoke(self, ctx: Context) -> None:
"""A special method that acts as a cog local post-invoke hook.
This is similar to :meth:`.Command.after_invoke`.
@ -390,7 +414,7 @@ class Cog(metaclass=CogMeta):
"""
pass
def _inject(self, bot):
def _inject(self: CogT, bot: BotBase) -> CogT:
cls = self.__class__
# realistically, the only thing that can cause loading errors
@ -425,7 +449,7 @@ class Cog(metaclass=CogMeta):
return self
def _eject(self, bot):
def _eject(self, bot: BotBase) -> None:
cls = self.__class__
try:

View File

@ -21,16 +21,52 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import inspect
import re
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union
import discord.abc
import discord.utils
import re
from discord.message import Message
if TYPE_CHECKING:
from typing_extensions import ParamSpec
from discord.abc import MessageableChannel
from discord.guild import Guild
from discord.member import Member
from discord.state import ConnectionState
from discord.user import ClientUser, User
from discord.voice_client import VoiceProtocol
from .bot import Bot, AutoShardedBot
from .cog import Cog
from .core import Command
from .help import HelpCommand
from .view import StringView
__all__ = (
'Context',
)
class Context(discord.abc.Messageable):
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog")
if TYPE_CHECKING:
P = ParamSpec('P')
else:
P = TypeVar('P')
class Context(discord.abc.Messageable, Generic[BotT]):
r"""Represents the context in which a command is being invoked under.
This class contains a lot of meta data to help you understand more about
@ -58,11 +94,11 @@ class Context(discord.abc.Messageable):
This is only of use for within converters.
.. versionadded:: 2.0
prefix: :class:`str`
prefix: Optional[:class:`str`]
The prefix that was used to invoke the command.
command: :class:`Command`
command: Optional[:class:`Command`]
The command that is being invoked currently.
invoked_with: :class:`str`
invoked_with: Optional[:class:`str`]
The command name that triggered this invocation. Useful for finding out
which alias called the command.
invoked_parents: List[:class:`str`]
@ -73,7 +109,7 @@ class Context(discord.abc.Messageable):
.. versionadded:: 1.7
invoked_subcommand: :class:`Command`
invoked_subcommand: Optional[:class:`Command`]
The subcommand that was invoked.
If no valid subcommand was invoked then this is equal to ``None``.
subcommand_passed: Optional[:class:`str`]
@ -86,23 +122,38 @@ class Context(discord.abc.Messageable):
or invoked.
"""
def __init__(self, **attrs):
self.message = attrs.pop('message', None)
self.bot = attrs.pop('bot', None)
self.args = attrs.pop('args', [])
self.kwargs = attrs.pop('kwargs', {})
self.prefix = attrs.pop('prefix')
self.command = attrs.pop('command', None)
self.view = attrs.pop('view', None)
self.invoked_with = attrs.pop('invoked_with', None)
self.invoked_parents = attrs.pop('invoked_parents', [])
self.invoked_subcommand = attrs.pop('invoked_subcommand', None)
self.subcommand_passed = attrs.pop('subcommand_passed', None)
self.command_failed = attrs.pop('command_failed', False)
self.current_parameter = attrs.pop('current_parameter', None)
self._state = self.message._state
def __init__(self,
*,
message: Message,
bot: BotT,
view: StringView,
args: List[Any] = MISSING,
kwargs: Dict[str, Any] = MISSING,
prefix: Optional[str] = None,
command: Optional[Command] = None,
invoked_with: Optional[str] = None,
invoked_parents: List[str] = MISSING,
invoked_subcommand: Optional[Command] = None,
subcommand_passed: Optional[str] = None,
command_failed: bool = False,
current_parameter: Optional[inspect.Parameter] = None,
):
self.message: Message = message
self.bot: BotT = bot
self.args: List[Any] = args or []
self.kwargs: Dict[str, Any] = kwargs or {}
self.prefix: Optional[str] = prefix
self.command: Optional[Command] = command
self.view: StringView = view
self.invoked_with: Optional[str] = invoked_with
self.invoked_parents: List[str] = invoked_parents or []
self.invoked_subcommand: Optional[Command] = invoked_subcommand
self.subcommand_passed: Optional[str] = subcommand_passed
self.command_failed: bool = command_failed
self.current_parameter: Optional[inspect.Parameter] = current_parameter
self._state: ConnectionState = self.message._state
async def invoke(self, command, /, *args, **kwargs):
async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
r"""|coro|
Calls a command with the arguments given.
@ -124,7 +175,7 @@ class Context(discord.abc.Messageable):
command: :class:`.Command`
The command that is going to be called.
\*args
The arguments to to use.
The arguments to use.
\*\*kwargs
The keyword arguments to use.
@ -133,17 +184,9 @@ class Context(discord.abc.Messageable):
TypeError
The command argument to invoke is missing.
"""
arguments = []
if command.cog is not None:
arguments.append(command.cog)
return await command(self, *args, **kwargs)
arguments.append(self)
arguments.extend(args)
ret = await command.callback(*arguments, **kwargs)
return ret
async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True):
async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> None:
"""|coro|
Calls the command again.
@ -187,7 +230,7 @@ class Context(discord.abc.Messageable):
if restart:
to_call = cmd.root_parent or cmd
view.index = len(self.prefix)
view.index = len(self.prefix or '')
view.previous = 0
self.invoked_parents = []
self.invoked_with = view.get_word() # advance to get the root command
@ -206,20 +249,23 @@ class Context(discord.abc.Messageable):
self.subcommand_passed = subcommand_passed
@property
def valid(self):
def valid(self) -> bool:
""":class:`bool`: Checks if the invocation context is valid to be invoked with."""
return self.prefix is not None and self.command is not None
async def _get_channel(self):
async def _get_channel(self) -> discord.abc.Messageable:
return self.channel
@property
def clean_prefix(self):
def clean_prefix(self) -> str:
""":class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``.
.. versionadded:: 2.0
"""
user = self.guild.me if self.guild else self.bot.user
if self.prefix is None:
return ''
user = self.me
# this breaks if the prefix mention is not the bot itself but I
# consider this to be an *incredibly* strange use case. I'd rather go
# for this common use case rather than waste performance for the
@ -228,7 +274,7 @@ class Context(discord.abc.Messageable):
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix)
@property
def cog(self):
def cog(self) -> Optional[Cog]:
"""Optional[:class:`.Cog`]: Returns the cog associated with this context's command. None if it does not exist."""
if self.command is None:
@ -236,38 +282,39 @@ class Context(discord.abc.Messageable):
return self.command.cog
@discord.utils.cached_property
def guild(self):
def guild(self) -> Optional[Guild]:
"""Optional[:class:`.Guild`]: Returns the guild associated with this context's command. None if not available."""
return self.message.guild
@discord.utils.cached_property
def channel(self):
def channel(self) -> MessageableChannel:
"""Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command.
Shorthand for :attr:`.Message.channel`.
"""
return self.message.channel
@discord.utils.cached_property
def author(self):
def author(self) -> Union[User, Member]:
"""Union[:class:`~discord.User`, :class:`.Member`]:
Returns the author associated with this context's command. Shorthand for :attr:`.Message.author`
"""
return self.message.author
@discord.utils.cached_property
def me(self):
def me(self) -> Union[Member, ClientUser]:
"""Union[:class:`.Member`, :class:`.ClientUser`]:
Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message contexts.
"""
return self.guild.me if self.guild is not None else self.bot.user
# bot.user will never be None at this point.
return self.guild.me if self.guild is not None else self.bot.user # type: ignore
@property
def voice_client(self):
def voice_client(self) -> Optional[VoiceProtocol]:
r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable."""
g = self.guild
return g.voice_client if g else None
async def send_help(self, *args):
async def send_help(self, *args: Any) -> Any:
"""send_help(entity=<bot>)
|coro|
@ -319,12 +366,12 @@ class Context(discord.abc.Messageable):
return None
entity = args[0]
if entity is None:
return None
if isinstance(entity, str):
entity = bot.get_cog(entity) or bot.get_command(entity)
if entity is None:
return None
try:
entity.qualified_name
except AttributeError:
@ -348,6 +395,6 @@ class Context(discord.abc.Messageable):
except CommandError as e:
await cmd.on_help_command_error(self, e)
@discord.utils.copy_doc(discord.Message.reply)
async def reply(self, content=None, **kwargs):
@discord.utils.copy_doc(Message.reply)
async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message:
return await self.message.reply(content, **kwargs)

View File

@ -74,6 +74,7 @@ __all__ = (
'StoreChannelConverter',
'ThreadConverter',
'GuildChannelConverter',
'GuildStickerConverter',
'clean_content',
'Greedy',
'run_converters',
@ -823,6 +824,45 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
raise PartialEmojiConversionFailure(argument)
class GuildStickerConverter(IDConverter[discord.GuildSticker]):
"""Converts to a :class:`~discord.GuildSticker`.
All lookups are done for the local guild first, if available. If that lookup
fails, then it checks the client's global cache.
The lookup strategy is as follows (in order):
1. Lookup by ID.
3. Lookup by name
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context, argument: str) -> discord.GuildSticker:
match = self._get_id_match(argument)
result = None
bot = ctx.bot
guild = ctx.guild
if match is None:
# Try to get the sticker by name. Try local guild first.
if guild:
result = discord.utils.get(guild.stickers, name=argument)
if result is None:
result = discord.utils.get(bot.stickers, name=argument)
else:
sticker_id = int(match.group(1))
# Try to look up sticker by id.
result = bot.get_sticker(sticker_id)
if result is None:
raise GuildStickerNotFound(argument)
return result
class clean_content(Converter[str]):
"""Converts the argument to mention scrubbed version of
said content.
@ -1012,6 +1052,7 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = {
discord.StoreChannel: StoreChannelConverter,
discord.Thread: ThreadConverter,
discord.abc.GuildChannel: GuildChannelConverter,
discord.GuildSticker: GuildStickerConverter,
}

View File

@ -22,6 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Any, Callable, Deque, Dict, Optional, Type, TypeVar, TYPE_CHECKING
from discord.enums import Enum
import time
import asyncio
@ -30,6 +34,9 @@ from collections import deque
from ...abc import PrivateChannel
from .errors import MaxConcurrencyReached
if TYPE_CHECKING:
from ...message import Message
__all__ = (
'BucketType',
'Cooldown',
@ -38,6 +45,9 @@ __all__ = (
'MaxConcurrency',
)
C = TypeVar('C', bound='CooldownMapping')
MC = TypeVar('MC', bound='MaxConcurrency')
class BucketType(Enum):
default = 0
user = 1
@ -47,7 +57,7 @@ class BucketType(Enum):
category = 5
role = 6
def get_key(self, msg):
def get_key(self, msg: Message) -> Any:
if self is BucketType.user:
return msg.author.id
elif self is BucketType.guild:
@ -57,29 +67,52 @@ class BucketType(Enum):
elif self is BucketType.member:
return ((msg.guild and msg.guild.id), msg.author.id)
elif self is BucketType.category:
return (msg.channel.category or msg.channel).id
return (msg.channel.category or msg.channel).id # type: ignore
elif self is BucketType.role:
# we return the channel id of a private-channel as there are only roles in guilds
# and that yields the same result as for a guild with only the @everyone role
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore
def __call__(self, msg):
def __call__(self, msg: Message) -> Any:
return self.get_key(msg)
class Cooldown:
"""Represents a cooldown for a command.
Attributes
-----------
rate: :class:`int`
The total number of tokens available per :attr:`per` seconds.
per: :class:`float`
The length of the cooldown period in seconds.
"""
__slots__ = ('rate', 'per', '_window', '_tokens', '_last')
def __init__(self, rate, per):
self.rate = int(rate)
self.per = float(per)
self._window = 0.0
self._tokens = self.rate
self._last = 0.0
def __init__(self, rate: float, per: float) -> None:
self.rate: int = int(rate)
self.per: float = float(per)
self._window: float = 0.0
self._tokens: int = self.rate
self._last: float = 0.0
def get_tokens(self, current=None):
def get_tokens(self, current: Optional[float] = None) -> int:
"""Returns the number of available tokens before rate limiting is applied.
Parameters
------------
current: Optional[:class:`float`]
The time in seconds since Unix epoch to calculate tokens at.
If not supplied then :func:`time.time()` is used.
Returns
--------
:class:`int`
The number of tokens available before the cooldown is to be applied.
"""
if not current:
current = time.time()
@ -89,7 +122,20 @@ class Cooldown:
tokens = self.rate
return tokens
def get_retry_after(self, current=None):
def get_retry_after(self, current: Optional[float] = None) -> float:
"""Returns the time in seconds until the cooldown will be reset.
Parameters
-------------
current: Optional[:class:`float`]
The current time in seconds since Unix epoch.
If not supplied, then :func:`time.time()` is used.
Returns
-------
:class:`float`
The number of seconds to wait before this cooldown will be reset.
"""
current = current or time.time()
tokens = self.get_tokens(current)
@ -98,7 +144,20 @@ class Cooldown:
return 0.0
def update_rate_limit(self, current=None):
def update_rate_limit(self, current: Optional[float] = None) -> Optional[float]:
"""Updates the cooldown rate limit.
Parameters
-------------
current: Optional[:class:`float`]
The time in seconds since Unix epoch to update the rate limit at.
If not supplied, then :func:`time.time()` is used.
Returns
-------
Optional[:class:`float`]
The retry-after time in seconds if rate limited.
"""
current = current or time.time()
self._last = current
@ -115,46 +174,58 @@ class Cooldown:
# we're not so decrement our tokens
self._tokens -= 1
def reset(self):
def reset(self) -> None:
"""Reset the cooldown to its initial state."""
self._tokens = self.rate
self._last = 0.0
def copy(self):
def copy(self) -> Cooldown:
"""Creates a copy of this cooldown.
Returns
--------
:class:`Cooldown`
A new instance of this cooldown.
"""
return Cooldown(self.rate, self.per)
def __repr__(self):
def __repr__(self) -> str:
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
class CooldownMapping:
def __init__(self, original, type):
def __init__(
self,
original: Optional[Cooldown],
type: Callable[[Message], Any],
) -> None:
if not callable(type):
raise TypeError('Cooldown type must be a BucketType or callable')
self._cache = {}
self._cooldown = original
self._type = type
self._cache: Dict[Any, Cooldown] = {}
self._cooldown: Optional[Cooldown] = original
self._type: Callable[[Message], Any] = type
def copy(self):
def copy(self) -> CooldownMapping:
ret = CooldownMapping(self._cooldown, self._type)
ret._cache = self._cache.copy()
return ret
@property
def valid(self):
def valid(self) -> bool:
return self._cooldown is not None
@property
def type(self):
def type(self) -> Callable[[Message], Any]:
return self._type
@classmethod
def from_cooldown(cls, rate, per, type):
def from_cooldown(cls: Type[C], rate, per, type) -> C:
return cls(Cooldown(rate, per), type)
def _bucket_key(self, msg):
def _bucket_key(self, msg: Message) -> Any:
return self._type(msg)
def _verify_cache_integrity(self, current=None):
def _verify_cache_integrity(self, current: Optional[float] = None) -> None:
# we want to delete all cache objects that haven't been used
# in a cooldown window. e.g. if we have a command that has a
# cooldown of 60s and it has not been used in 60s then that key should be deleted
@ -163,12 +234,12 @@ class CooldownMapping:
for k in dead_keys:
del self._cache[k]
def create_bucket(self, message):
return self._cooldown.copy()
def create_bucket(self, message: Message) -> Cooldown:
return self._cooldown.copy() # type: ignore
def get_bucket(self, message, current=None):
def get_bucket(self, message: Message, current: Optional[float] = None) -> Cooldown:
if self._type is BucketType.default:
return self._cooldown
return self._cooldown # type: ignore
self._verify_cache_integrity(current)
key = self._bucket_key(message)
@ -181,26 +252,30 @@ class CooldownMapping:
return bucket
def update_rate_limit(self, message, current=None):
def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]:
bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current)
class DynamicCooldownMapping(CooldownMapping):
def __init__(self, factory, type):
def __init__(
self,
factory: Callable[[Message], Cooldown],
type: Callable[[Message], Any]
) -> None:
super().__init__(None, type)
self._factory = factory
self._factory: Callable[[Message], Cooldown] = factory
def copy(self):
def copy(self) -> DynamicCooldownMapping:
ret = DynamicCooldownMapping(self._factory, self._type)
ret._cache = self._cache.copy()
return ret
@property
def valid(self):
def valid(self) -> bool:
return True
def create_bucket(self, message):
def create_bucket(self, message: Message) -> Cooldown:
return self._factory(message)
class _Semaphore:
@ -218,28 +293,28 @@ class _Semaphore:
__slots__ = ('value', 'loop', '_waiters')
def __init__(self, number):
self.value = number
self.loop = asyncio.get_event_loop()
self._waiters = deque()
def __init__(self, number: int) -> None:
self.value: int = number
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
self._waiters: Deque[asyncio.Future] = deque()
def __repr__(self):
def __repr__(self) -> str:
return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>'
def locked(self):
def locked(self) -> bool:
return self.value == 0
def is_active(self):
def is_active(self) -> bool:
return len(self._waiters) > 0
def wake_up(self):
def wake_up(self) -> None:
while self._waiters:
future = self._waiters.popleft()
if not future.done():
future.set_result(None)
return
async def acquire(self, *, wait=False):
async def acquire(self, *, wait: bool = False) -> bool:
if not wait and self.value <= 0:
# signal that we're not acquiring
return False
@ -258,18 +333,18 @@ class _Semaphore:
self.value -= 1
return True
def release(self):
def release(self) -> None:
self.value += 1
self.wake_up()
class MaxConcurrency:
__slots__ = ('number', 'per', 'wait', '_mapping')
def __init__(self, number, *, per, wait):
self._mapping = {}
self.per = per
self.number = number
self.wait = wait
def __init__(self, number: int, *, per: BucketType, wait: bool) -> None:
self._mapping: Dict[Any, _Semaphore] = {}
self.per: BucketType = per
self.number: int = number
self.wait: bool = wait
if number <= 0:
raise ValueError('max_concurrency \'number\' cannot be less than 1')
@ -277,16 +352,16 @@ class MaxConcurrency:
if not isinstance(per, BucketType):
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}')
def copy(self):
def copy(self: MC) -> MC:
return self.__class__(self.number, per=self.per, wait=self.wait)
def __repr__(self):
def __repr__(self) -> str:
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>'
def get_key(self, message):
def get_key(self, message: Message) -> Any:
return self.per.get_key(message)
async def acquire(self, message):
async def acquire(self, message: Message) -> None:
key = self.get_key(message)
try:
@ -298,7 +373,7 @@ class MaxConcurrency:
if not acquired:
raise MaxConcurrencyReached(self.number, self.per)
async def release(self, message):
async def release(self, message: Message) -> None:
# Technically there's no reason for this function to be async
# But it might be more useful in the future
key = self.get_key(message)

File diff suppressed because it is too large Load Diff

View File

@ -22,8 +22,23 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Optional, Any, TYPE_CHECKING, List, Callable, Type, Tuple, Union
from discord.errors import ClientException, DiscordException
if TYPE_CHECKING:
from inspect import Parameter
from .converter import Converter
from .context import Context
from .cooldowns import Cooldown, BucketType
from .flags import Flag
from discord.abc import GuildChannel
from discord.threads import Thread
from discord.types.snowflake import Snowflake, SnowflakeList
__all__ = (
'CommandError',
@ -54,6 +69,7 @@ __all__ = (
'RoleNotFound',
'BadInviteArgument',
'EmojiNotFound',
'GuildStickerNotFound',
'PartialEmojiConversionFailure',
'BadBoolArgument',
'MissingRole',
@ -93,7 +109,7 @@ class CommandError(DiscordException):
in a special way as they are caught and passed into a special event
from :class:`.Bot`\, :func:`.on_command_error`.
"""
def __init__(self, message=None, *args):
def __init__(self, message: Optional[str] = None, *args: Any) -> None:
if message is not None:
# clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
@ -114,9 +130,9 @@ class ConversionError(CommandError):
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, converter, original):
self.converter = converter
self.original = original
def __init__(self, converter: Converter, original: Exception) -> None:
self.converter: Converter = converter
self.original: Exception = original
class UserInputError(CommandError):
"""The base exception type for errors that involve errors
@ -148,8 +164,8 @@ class MissingRequiredArgument(UserInputError):
param: :class:`inspect.Parameter`
The argument that is missing.
"""
def __init__(self, param):
self.param = param
def __init__(self, param: Parameter) -> None:
self.param: Parameter = param
super().__init__(f'{param.name} is a required argument that is missing.')
class TooManyArguments(UserInputError):
@ -190,9 +206,9 @@ class CheckAnyFailure(CheckFailure):
A list of check predicates that failed.
"""
def __init__(self, checks, errors):
self.checks = checks
self.errors = errors
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None:
self.checks: List[CheckFailure] = checks
self.errors: List[Callable[[Context], bool]] = errors
super().__init__('You do not have permission to run this command.')
class PrivateMessageOnly(CheckFailure):
@ -201,7 +217,7 @@ class PrivateMessageOnly(CheckFailure):
This inherits from :exc:`CheckFailure`
"""
def __init__(self, message=None):
def __init__(self, message: Optional[str] = None) -> None:
super().__init__(message or 'This command can only be used in private messages.')
class NoPrivateMessage(CheckFailure):
@ -211,7 +227,7 @@ class NoPrivateMessage(CheckFailure):
This inherits from :exc:`CheckFailure`
"""
def __init__(self, message=None):
def __init__(self, message: Optional[str] = None) -> None:
super().__init__(message or 'This command cannot be used in private messages.')
class NotOwner(CheckFailure):
@ -234,8 +250,8 @@ class ObjectNotFound(BadArgument):
argument: :class:`str`
The argument supplied by the caller that was not matched
"""
def __init__(self, argument):
self.argument = argument
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'{argument!r} does not follow a valid ID or mention format.')
class MemberNotFound(BadArgument):
@ -251,8 +267,8 @@ class MemberNotFound(BadArgument):
argument: :class:`str`
The member supplied by the caller that was not found
"""
def __init__(self, argument):
self.argument = argument
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Member "{argument}" not found.')
class GuildNotFound(BadArgument):
@ -267,8 +283,8 @@ class GuildNotFound(BadArgument):
argument: :class:`str`
The guild supplied by the called that was not found
"""
def __init__(self, argument):
self.argument = argument
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Guild "{argument}" not found.')
class UserNotFound(BadArgument):
@ -284,8 +300,8 @@ class UserNotFound(BadArgument):
argument: :class:`str`
The user supplied by the caller that was not found
"""
def __init__(self, argument):
self.argument = argument
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'User "{argument}" not found.')
class MessageNotFound(BadArgument):
@ -300,8 +316,8 @@ class MessageNotFound(BadArgument):
argument: :class:`str`
The message supplied by the caller that was not found
"""
def __init__(self, argument):
self.argument = argument
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Message "{argument}" not found.')
class ChannelNotReadable(BadArgument):
@ -314,11 +330,11 @@ class ChannelNotReadable(BadArgument):
Attributes
-----------
argument: :class:`.abc.GuildChannel`
argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel supplied by the caller that was not readable
"""
def __init__(self, argument):
self.argument = argument
def __init__(self, argument: Union[GuildChannel, Thread]) -> None:
self.argument: Union[GuildChannel, Thread] = argument
super().__init__(f"Can't read messages in {argument.mention}.")
class ChannelNotFound(BadArgument):
@ -333,8 +349,8 @@ class ChannelNotFound(BadArgument):
argument: :class:`str`
The channel supplied by the caller that was not found
"""
def __init__(self, argument):
self.argument = argument
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Channel "{argument}" not found.')
class ThreadNotFound(BadArgument):
@ -349,8 +365,8 @@ class ThreadNotFound(BadArgument):
argument: :class:`str`
The thread supplied by the caller that was not found
"""
def __init__(self, argument):
self.argument = argument
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Thread "{argument}" not found.')
class BadColourArgument(BadArgument):
@ -365,8 +381,8 @@ class BadColourArgument(BadArgument):
argument: :class:`str`
The colour supplied by the caller that was not valid
"""
def __init__(self, argument):
self.argument = argument
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Colour "{argument}" is invalid.')
BadColorArgument = BadColourArgument
@ -383,8 +399,8 @@ class RoleNotFound(BadArgument):
argument: :class:`str`
The role supplied by the caller that was not found
"""
def __init__(self, argument):
self.argument = argument
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Role "{argument}" not found.')
class BadInviteArgument(BadArgument):
@ -394,8 +410,8 @@ class BadInviteArgument(BadArgument):
.. versionadded:: 1.5
"""
def __init__(self, argument):
self.argument = argument
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Invite "{argument}" is invalid or expired.')
class EmojiNotFound(BadArgument):
@ -410,8 +426,8 @@ class EmojiNotFound(BadArgument):
argument: :class:`str`
The emoji supplied by the caller that was not found
"""
def __init__(self, argument):
self.argument = argument
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Emoji "{argument}" not found.')
class PartialEmojiConversionFailure(BadArgument):
@ -427,10 +443,26 @@ class PartialEmojiConversionFailure(BadArgument):
argument: :class:`str`
The emoji supplied by the caller that did not match the regex
"""
def __init__(self, argument):
self.argument = argument
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.')
class GuildStickerNotFound(BadArgument):
"""Exception raised when the bot can not find the sticker.
This inherits from :exc:`BadArgument`
.. versionadded:: 2.0
Attributes
-----------
argument: :class:`str`
The sticker supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Sticker "{argument}" not found.')
class BadBoolArgument(BadArgument):
"""Exception raised when a boolean argument was not convertable.
@ -443,8 +475,8 @@ class BadBoolArgument(BadArgument):
argument: :class:`str`
The boolean argument supplied by the caller that is not in the predefined list
"""
def __init__(self, argument):
self.argument = argument
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'{argument} is not a recognised boolean option')
class DisabledCommand(CommandError):
@ -465,8 +497,8 @@ class CommandInvokeError(CommandError):
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, e):
self.original = e
def __init__(self, e: Exception) -> None:
self.original: Exception = e
super().__init__(f'Command raised an exception: {e.__class__.__name__}: {e}')
class CommandOnCooldown(CommandError):
@ -476,7 +508,7 @@ class CommandOnCooldown(CommandError):
Attributes
-----------
cooldown: ``Cooldown``
cooldown: :class:`.Cooldown`
A class with attributes ``rate`` and ``per`` similar to the
:func:`.cooldown` decorator.
type: :class:`BucketType`
@ -484,10 +516,10 @@ class CommandOnCooldown(CommandError):
retry_after: :class:`float`
The amount of seconds to wait before you can retry again.
"""
def __init__(self, cooldown, retry_after, type):
self.cooldown = cooldown
self.retry_after = retry_after
self.type = type
def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None:
self.cooldown: Cooldown = cooldown
self.retry_after: float = retry_after
self.type: BucketType = type
super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s')
class MaxConcurrencyReached(CommandError):
@ -503,9 +535,9 @@ class MaxConcurrencyReached(CommandError):
The bucket type passed to the :func:`.max_concurrency` decorator.
"""
def __init__(self, number, per):
self.number = number
self.per = per
def __init__(self, number: int, per: BucketType) -> None:
self.number: int = number
self.per: BucketType = per
name = per.name
suffix = 'per %s' % name if per.name != 'default' else 'globally'
plural = '%s times %s' if number > 1 else '%s time %s'
@ -525,8 +557,8 @@ class MissingRole(CheckFailure):
The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`.
"""
def __init__(self, missing_role):
self.missing_role = missing_role
def __init__(self, missing_role: Snowflake) -> None:
self.missing_role: Snowflake = missing_role
message = f'Role {missing_role!r} is required to run this command.'
super().__init__(message)
@ -543,8 +575,8 @@ class BotMissingRole(CheckFailure):
The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`.
"""
def __init__(self, missing_role):
self.missing_role = missing_role
def __init__(self, missing_role: Snowflake) -> None:
self.missing_role: Snowflake = missing_role
message = f'Bot requires the role {missing_role!r} to run this command'
super().__init__(message)
@ -562,8 +594,8 @@ class MissingAnyRole(CheckFailure):
The roles that the invoker is missing.
These are the parameters passed to :func:`~.commands.has_any_role`.
"""
def __init__(self, missing_roles):
self.missing_roles = missing_roles
def __init__(self, missing_roles: SnowflakeList) -> None:
self.missing_roles: SnowflakeList = missing_roles
missing = [f"'{role}'" for role in missing_roles]
@ -591,8 +623,8 @@ class BotMissingAnyRole(CheckFailure):
These are the parameters passed to :func:`~.commands.has_any_role`.
"""
def __init__(self, missing_roles):
self.missing_roles = missing_roles
def __init__(self, missing_roles: SnowflakeList) -> None:
self.missing_roles: SnowflakeList = missing_roles
missing = [f"'{role}'" for role in missing_roles]
@ -613,11 +645,11 @@ class NSFWChannelRequired(CheckFailure):
Parameters
-----------
channel: :class:`discord.abc.GuildChannel`
channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel that does not have NSFW enabled.
"""
def __init__(self, channel):
self.channel = channel
def __init__(self, channel: Union[GuildChannel, Thread]) -> None:
self.channel: Union[GuildChannel, Thread] = channel
super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.")
class MissingPermissions(CheckFailure):
@ -628,11 +660,11 @@ class MissingPermissions(CheckFailure):
Attributes
-----------
missing_permissions: :class:`list`
missing_permissions: List[:class:`str`]
The required permissions that are missing.
"""
def __init__(self, missing_permissions, *args):
self.missing_permissions = missing_permissions
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
@ -651,11 +683,11 @@ class BotMissingPermissions(CheckFailure):
Attributes
-----------
missing_permissions: :class:`list`
missing_permissions: List[:class:`str`]
The required permissions that are missing.
"""
def __init__(self, missing_permissions, *args):
self.missing_permissions = missing_permissions
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
@ -681,10 +713,10 @@ class BadUnionArgument(UserInputError):
errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param, converters, errors):
self.param = param
self.converters = converters
self.errors = errors
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param
self.converters: Tuple[Type, ...] = converters
self.errors: List[CommandError] = errors
def _get_name(x):
try:
@ -719,10 +751,10 @@ class BadLiteralArgument(UserInputError):
errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param, literals, errors):
self.param = param
self.literals = literals
self.errors = errors
def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param
self.literals: Tuple[Any, ...] = literals
self.errors: List[CommandError] = errors
to_string = [repr(l) for l in literals]
if len(to_string) > 2:
@ -752,8 +784,8 @@ class UnexpectedQuoteError(ArgumentParsingError):
quote: :class:`str`
The quote mark that was found inside the non-quoted string.
"""
def __init__(self, quote):
self.quote = quote
def __init__(self, quote: str) -> None:
self.quote: str = quote
super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string')
class InvalidEndOfQuotedStringError(ArgumentParsingError):
@ -767,8 +799,8 @@ class InvalidEndOfQuotedStringError(ArgumentParsingError):
char: :class:`str`
The character found instead of the expected string.
"""
def __init__(self, char):
self.char = char
def __init__(self, char: str) -> None:
self.char: str = char
super().__init__(f'Expected space after closing quotation but received {char!r}')
class ExpectedClosingQuoteError(ArgumentParsingError):
@ -782,8 +814,8 @@ class ExpectedClosingQuoteError(ArgumentParsingError):
The quote character expected.
"""
def __init__(self, close_quote):
self.close_quote = close_quote
def __init__(self, close_quote: str) -> None:
self.close_quote: str = close_quote
super().__init__(f'Expected closing {close_quote}.')
class ExtensionError(DiscordException):
@ -796,8 +828,8 @@ class ExtensionError(DiscordException):
name: :class:`str`
The extension that had an error.
"""
def __init__(self, message=None, *args, name):
self.name = name
def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None:
self.name: str = name
message = message or f'Extension {name!r} had an error.'
# clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
@ -808,7 +840,7 @@ class ExtensionAlreadyLoaded(ExtensionError):
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name):
def __init__(self, name: str) -> None:
super().__init__(f'Extension {name!r} is already loaded.', name=name)
class ExtensionNotLoaded(ExtensionError):
@ -816,7 +848,7 @@ class ExtensionNotLoaded(ExtensionError):
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name):
def __init__(self, name: str) -> None:
super().__init__(f'Extension {name!r} has not been loaded.', name=name)
class NoEntryPointError(ExtensionError):
@ -824,7 +856,7 @@ class NoEntryPointError(ExtensionError):
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name):
def __init__(self, name: str) -> None:
super().__init__(f"Extension {name!r} has no 'setup' function.", name=name)
class ExtensionFailed(ExtensionError):
@ -840,8 +872,8 @@ class ExtensionFailed(ExtensionError):
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, name, original):
self.original = original
def __init__(self, name: str, original: Exception) -> None:
self.original: Exception = original
msg = f'Extension {name!r} raised an error: {original.__class__.__name__}: {original}'
super().__init__(msg, name=name)
@ -858,7 +890,7 @@ class ExtensionNotFound(ExtensionError):
name: :class:`str`
The extension that had the error.
"""
def __init__(self, name):
def __init__(self, name: str) -> None:
msg = f'Extension {name!r} could not be loaded.'
super().__init__(msg, name=name)
@ -877,9 +909,9 @@ class CommandRegistrationError(ClientException):
alias_conflict: :class:`bool`
Whether the name that conflicts is an alias of the command we try to add.
"""
def __init__(self, name, *, alias_conflict=False):
self.name = name
self.alias_conflict = alias_conflict
def __init__(self, name: str, *, alias_conflict: bool = False) -> None:
self.name: str = name
self.alias_conflict: bool = alias_conflict
type_ = 'alias' if alias_conflict else 'command'
super().__init__(f'The {type_} {name} is already an existing command or alias.')
@ -906,17 +938,25 @@ class TooManyFlags(FlagError):
values: List[:class:`str`]
The values that were passed.
"""
def __init__(self, flag, values):
self.flag = flag
self.values = values
def __init__(self, flag: Flag, values: List[str]) -> None:
self.flag: Flag = flag
self.values: List[str] = values
super().__init__(f'Too many flag values, expected {flag.max_args} but received {len(values)}.')
class BadFlagArgument(FlagError):
"""An exception raised when a flag failed to convert a value.
This inherits from :exc:`FlagError`
.. versionadded:: 2.0
Attributes
-----------
flag: :class:`~discord.ext.commands.Flag`
The flag that failed to convert.
"""
def __init__(self, flag):
self.flag = flag
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
try:
name = flag.annotation.__name__
except AttributeError:
@ -936,8 +976,8 @@ class MissingRequiredFlag(FlagError):
flag: :class:`~discord.ext.commands.Flag`
The required flag that was not found.
"""
def __init__(self, flag):
self.flag = flag
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} is required and missing')
class MissingFlagArgument(FlagError):
@ -952,6 +992,6 @@ class MissingFlagArgument(FlagError):
flag: :class:`~discord.ext.commands.Flag`
The flag that did not get a value.
"""
def __init__(self, flag):
self.flag = flag
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} does not have an argument')

View File

@ -27,11 +27,17 @@ import copy
import functools
import inspect
import re
from typing import Optional, TYPE_CHECKING
import discord.utils
from .core import Group, Command
from .errors import CommandError
if TYPE_CHECKING:
from .context import Context
__all__ = (
'Paginator',
'HelpCommand',
@ -320,7 +326,7 @@ class HelpCommand:
self.command_attrs = attrs = options.pop('command_attrs', {})
attrs.setdefault('name', 'help')
attrs.setdefault('help', 'Shows this message')
self.context = None
self.context: Context = discord.utils.MISSING
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
def copy(self):

View File

@ -27,22 +27,20 @@ from __future__ import annotations
import asyncio
import datetime
from typing import (
Any,
Awaitable,
Callable,
Any,
Awaitable,
Callable,
Generic,
List,
Optional,
Type,
List,
Optional,
Type,
TypeVar,
Union,
cast,
)
import aiohttp
import discord
import inspect
import logging
import sys
import traceback
@ -50,8 +48,6 @@ from collections.abc import Sequence
from discord.backoff import ExponentialBackoff
from discord.utils import MISSING
log = logging.getLogger(__name__)
__all__ = (
'loop',
)
@ -61,7 +57,6 @@ _func = Callable[..., Awaitable[Any]]
LF = TypeVar('LF', bound=_func)
FT = TypeVar('FT', bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
LT = TypeVar('LT', bound='Loop')
class SleepHandle:
@ -78,7 +73,7 @@ class SleepHandle:
relative_delta = discord.utils.compute_timedelta(dt)
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
def wait(self) -> asyncio.Future:
def wait(self) -> asyncio.Future[Any]:
return self.future
def done(self) -> bool:
@ -94,7 +89,9 @@ class Loop(Generic[LF]):
The main interface to create this is through :func:`loop`.
"""
def __init__(self,
def __init__(
self,
coro: LF,
seconds: float,
hours: float,
@ -102,15 +99,15 @@ class Loop(Generic[LF]):
time: Union[datetime.time, Sequence[datetime.time]],
count: Optional[int],
reconnect: bool,
loop: Optional[asyncio.AbstractEventLoop],
loop: asyncio.AbstractEventLoop,
) -> None:
self.coro: LF = coro
self.reconnect: bool = reconnect
self.loop: Optional[asyncio.AbstractEventLoop] = loop
self.loop: asyncio.AbstractEventLoop = loop
self.count: Optional[int] = count
self._current_loop = 0
self._handle = None
self._task = None
self._handle: SleepHandle = MISSING
self._task: asyncio.Task[None] = MISSING
self._injected = None
self._valid_exception = (
OSError,
@ -131,7 +128,7 @@ class Loop(Generic[LF]):
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
self._last_iteration_failed = False
self._last_iteration = None
self._last_iteration: datetime.datetime = MISSING
self._next_iteration = None
if not inspect.iscoroutinefunction(self.coro):
@ -147,9 +144,8 @@ class Loop(Generic[LF]):
else:
await coro(*args, **kwargs)
def _try_sleep_until(self, dt: datetime.datetime):
self._handle = SleepHandle(dt=dt, loop=self.loop) # type: ignore
self._handle = SleepHandle(dt=dt, loop=self.loop)
return self._handle.wait()
async def _loop(self, *args: Any, **kwargs: Any) -> None:
@ -178,7 +174,7 @@ class Loop(Generic[LF]):
await asyncio.sleep(backoff.delay())
else:
await self._try_sleep_until(self._next_iteration)
if self._stop_next_iteration:
return
@ -211,14 +207,14 @@ class Loop(Generic[LF]):
if obj is None:
return self
copy = Loop(
self.coro,
seconds=self._seconds,
hours=self._hours,
copy: Loop[LF] = Loop(
self.coro,
seconds=self._seconds,
hours=self._hours,
minutes=self._minutes,
time=self._time,
time=self._time,
count=self.count,
reconnect=self.reconnect,
reconnect=self.reconnect,
loop=self.loop,
)
copy._injected = obj
@ -237,7 +233,7 @@ class Loop(Generic[LF]):
"""
if self._seconds is not MISSING:
return self._seconds
@property
def minutes(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of minutes
@ -247,7 +243,7 @@ class Loop(Generic[LF]):
"""
if self._minutes is not MISSING:
return self._minutes
@property
def hours(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of hours
@ -279,7 +275,7 @@ class Loop(Generic[LF]):
.. versionadded:: 1.3
"""
if self._task is None:
if self._task is MISSING:
return None
elif self._task and self._task.done() or self._stop_next_iteration:
return None
@ -305,7 +301,7 @@ class Loop(Generic[LF]):
return await self.coro(*args, **kwargs)
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task:
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
r"""Starts the internal task in the event loop.
Parameters
@ -326,13 +322,13 @@ class Loop(Generic[LF]):
The task that has been created.
"""
if self._task is not None and not self._task.done():
if self._task is not MISSING and not self._task.done():
raise RuntimeError('Task is already launched and is not completed.')
if self._injected is not None:
args = (self._injected, *args)
if self.loop is None:
if self.loop is MISSING:
self.loop = asyncio.get_event_loop()
self._task = self.loop.create_task(self._loop(*args, **kwargs))
@ -356,7 +352,7 @@ class Loop(Generic[LF]):
.. versionadded:: 1.2
"""
if self._task and not self._task.done():
if self._task is not MISSING and not self._task.done():
self._stop_next_iteration = True
def _can_be_cancelled(self) -> bool:
@ -383,7 +379,7 @@ class Loop(Generic[LF]):
The keyword arguments to use.
"""
def restart_when_over(fut, *, args=args, kwargs=kwargs):
def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None:
self._task.remove_done_callback(restart_when_over)
self.start(*args, **kwargs)
@ -446,9 +442,9 @@ class Loop(Generic[LF]):
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
return len(self._valid_exception) == old_length - len(exceptions)
def get_task(self) -> Optional[asyncio.Task]:
def get_task(self) -> Optional[asyncio.Task[None]]:
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
return self._task
return self._task if self._task is not MISSING else None
def is_being_cancelled(self) -> bool:
"""Whether the task is being cancelled."""
@ -466,7 +462,7 @@ class Loop(Generic[LF]):
.. versionadded:: 1.4
"""
return not bool(self._task.done()) if self._task else False
return not bool(self._task.done()) if self._task is not MISSING else False
async def _error(self, *args: Any) -> None:
exception: Exception = args[-1]
@ -560,7 +556,9 @@ class Loop(Generic[LF]):
self._time_index = 0
if self._current_loop == 0:
# if we're at the last index on the first iteration, we need to sleep until tomorrow
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0])
return datetime.datetime.combine(
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
)
next_time = self._time[self._time_index]
@ -568,7 +566,7 @@ class Loop(Generic[LF]):
self._time_index += 1
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
next_date = cast(datetime.datetime, self._last_iteration)
next_date = self._last_iteration
if self._time_index == 0:
# we can assume that the earliest time should be scheduled for "tomorrow"
next_date += datetime.timedelta(days=1)
@ -576,12 +574,14 @@ class Loop(Generic[LF]):
self._time_index += 1
return datetime.datetime.combine(next_date, next_time)
def _prepare_time_index(self, now: Optional[datetime.datetime] = None) -> None:
def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
# now kwarg should be a datetime.datetime representing the time "now"
# to calculate the next time index from
# pre-condition: self._time is set
time_now = (now or datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)).timetz()
time_now = (
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
).timetz()
for idx, time in enumerate(self._time):
if time >= time_now:
self._time_index = idx
@ -597,20 +597,24 @@ class Loop(Generic[LF]):
utc: datetime.timezone = datetime.timezone.utc,
) -> List[datetime.time]:
if isinstance(time, dt):
ret = time if time.tzinfo is not None else time.replace(tzinfo=utc)
return [ret]
inner = time if time.tzinfo is not None else time.replace(tzinfo=utc)
return [inner]
if not isinstance(time, Sequence):
raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.')
raise TypeError(
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.'
)
if not time:
raise ValueError('time parameter must not be an empty sequence.')
ret = []
ret: List[datetime.time] = []
for index, t in enumerate(time):
if not isinstance(t, dt):
raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.')
raise TypeError(
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.'
)
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
ret = sorted(set(ret)) # de-dupe and sort times
ret = sorted(set(ret)) # de-dupe and sort times
return ret
def change_interval(
@ -691,7 +695,7 @@ def loop(
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
count: Optional[int] = None,
reconnect: bool = True,
loop: Optional[asyncio.AbstractEventLoop] = None,
loop: asyncio.AbstractEventLoop = MISSING,
) -> Callable[[LF], Loop[LF]]:
"""A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`.
@ -707,7 +711,7 @@ def loop(
time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
The exact times to run this loop at. Either a non-empty list or a single
value of :class:`datetime.time` should be passed. Timezones are supported.
If no timezone is given for the times, it is assumed to represent UTC time.
If no timezone is given for the times, it is assumed to represent UTC time.
This cannot be used in conjunction with the relative time parameters.
@ -724,7 +728,7 @@ def loop(
Whether to handle errors and restart the task
using an exponential back-off algorithm similar to the
one used in :meth:`discord.Client.connect`.
loop: Optional[:class:`asyncio.AbstractEventLoop`]
loop: :class:`asyncio.AbstractEventLoop`
The loop to use to register the task, if not given
defaults to :func:`asyncio.get_event_loop`.
@ -736,15 +740,17 @@ def loop(
The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
or ``time`` parameter was passed in conjunction with relative time parameters.
"""
def decorator(func: LF) -> Loop[LF]:
kwargs = {
'seconds': seconds,
'minutes': minutes,
'hours': hours,
'count': count,
'time': time,
'reconnect': reconnect,
'loop': loop,
}
return Loop(func, **kwargs)
return Loop[LF](
func,
seconds=seconds,
minutes=minutes,
hours=hours,
count=count,
time=time,
reconnect=reconnect,
loop=loop,
)
return decorator

View File

@ -41,7 +41,7 @@ FV = TypeVar('FV', bound='flag_value')
BF = TypeVar('BF', bound='BaseFlags')
class flag_value(Generic[BF]):
class flag_value:
def __init__(self, func: Callable[[Any], int]):
self.flag = func(None)
self.__doc__ = func.__doc__
@ -205,7 +205,7 @@ class SystemChannelFlags(BaseFlags):
@flag_value
def premium_subscriptions(self):
""":class:`bool`: Returns ``True`` if the system channel is used for Nitro boosting notifications."""
""":class:`bool`: Returns ``True`` if the system channel is used for "Nitro boosting" notifications."""
return 2
@flag_value
@ -480,16 +480,6 @@ class Intents(BaseFlags):
self.value = self.DEFAULT_VALUE
return self
@classmethod
def default(cls: Type[Intents]) -> Intents:
"""A factory method that creates a :class:`Intents` with everything enabled
except :attr:`presences` and :attr:`members`.
"""
self = cls.all()
self.presences = False
self.members = False
return self
@flag_value
def guilds(self):
""":class:`bool`: Whether guild related events are enabled.
@ -566,18 +556,34 @@ class Intents(BaseFlags):
@flag_value
def emojis(self):
""":class:`bool`: Whether guild emoji related events are enabled.
""":class:`bool`: Alias of :attr:`.emojis_and_stickers`.
.. versionchanged:: 2.0
Changed to an alias.
"""
return 1 << 3
@alias_flag_value
def emojis_and_stickers(self):
""":class:`bool`: Whether guild emoji and sticker related events are enabled.
.. versionadded:: 2.0
This corresponds to the following events:
- :func:`on_guild_emojis_update`
- :func:`on_guild_stickers_update`
This also corresponds to the following attributes and classes in terms of cache:
- :class:`Emoji`
- :class:`GuildSticker`
- :meth:`Client.get_emoji`
- :meth:`Client.get_sticker`
- :meth:`Client.emojis`
- :meth:`Client.stickers`
- :attr:`Guild.emojis`
- :attr:`Guild.stickers`
"""
return 1 << 3
@ -634,6 +640,10 @@ class Intents(BaseFlags):
- :attr:`VoiceChannel.members`
- :attr:`VoiceChannel.voice_states`
- :attr:`Member.voice`
.. note::
This intent is required to connect to voice.
"""
return 1 << 7

View File

@ -40,7 +40,7 @@ from .activity import BaseActivity
from .enums import SpeakingState
from .errors import ConnectionClosed, InvalidArgument
log = logging.getLogger(__name__)
_log = logging.getLogger(__name__)
__all__ = (
'DiscordWebSocket',
@ -101,7 +101,7 @@ class GatewayRatelimiter:
async with self.lock:
delta = self.get_delay()
if delta:
log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta)
_log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta)
await asyncio.sleep(delta)
@ -129,20 +129,20 @@ class KeepAliveHandler(threading.Thread):
def run(self):
while not self._stop_ev.wait(self.interval):
if self._last_recv + self.heartbeat_timeout < time.perf_counter():
log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
_log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
coro = self.ws.close(4000)
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
try:
f.result()
except Exception:
log.exception('An error occurred while stopping the gateway. Ignoring.')
_log.exception('An error occurred while stopping the gateway. Ignoring.')
finally:
self.stop()
return
data = self.get_payload()
log.debug(self.msg, self.shard_id, data['d'])
_log.debug(self.msg, self.shard_id, data['d'])
coro = self.ws.send_heartbeat(data)
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
try:
@ -161,7 +161,7 @@ class KeepAliveHandler(threading.Thread):
else:
stack = ''.join(traceback.format_stack(frame))
msg = f'{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}'
log.warning(msg, self.shard_id, total)
_log.warning(msg, self.shard_id, total)
except Exception:
self.stop()
@ -185,7 +185,7 @@ class KeepAliveHandler(threading.Thread):
self._last_ack = ack_time
self.latency = ack_time - self._last_send
if self.latency > 10:
log.warning(self.behind_msg, self.shard_id, self.latency)
_log.warning(self.behind_msg, self.shard_id, self.latency)
class VoiceKeepAliveHandler(KeepAliveHandler):
def __init__(self, *args, **kwargs):
@ -293,6 +293,12 @@ class DiscordWebSocket:
def is_ratelimited(self):
return self._rate_limiter.is_ratelimited()
def debug_log_receive(self, data, /):
self._dispatch('socket_raw_receive', data)
def log_receive(self, _, /):
pass
@classmethod
async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
"""Creates a main websocket for Discord from a :class:`Client`.
@ -318,9 +324,13 @@ class DiscordWebSocket:
ws.sequence = sequence
ws._max_heartbeat_timeout = client._connection.heartbeat_timeout
if client._enable_debug_events:
ws.send = ws.debug_send
ws.log_receive = ws.debug_log_receive
client._connection._update_references(ws)
log.debug('Created websocket connected to %s', gateway)
_log.debug('Created websocket connected to %s', gateway)
# poll event for OP Hello
await ws.poll_event()
@ -393,7 +403,7 @@ class DiscordWebSocket:
await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify)
await self.send_as_json(payload)
log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
_log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
async def resume(self):
"""Sends the RESUME packet."""
@ -407,11 +417,9 @@ class DiscordWebSocket:
}
await self.send_as_json(payload)
log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
async def received_message(self, msg):
self._dispatch('socket_raw_receive', msg)
_log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
async def received_message(self, msg, /):
if type(msg) is bytes:
self._buffer.extend(msg)
@ -420,10 +428,14 @@ class DiscordWebSocket:
msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8')
self._buffer = bytearray()
msg = utils.from_json(msg)
log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg)
self._dispatch('socket_response', msg)
self.log_receive(msg)
msg = utils._from_json(msg)
_log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg)
event = msg.get('t')
if event:
self._dispatch('socket_event_type', event)
op = msg.get('op')
data = msg.get('d')
@ -439,7 +451,7 @@ class DiscordWebSocket:
# "reconnect" can only be handled by the Client
# so we terminate our connection and raise an
# internal exception signalling to reconnect.
log.debug('Received RECONNECT opcode.')
_log.debug('Received RECONNECT opcode.')
await self.close()
raise ReconnectWebSocket(self.shard_id)
@ -469,35 +481,33 @@ class DiscordWebSocket:
self.sequence = None
self.session_id = None
log.info('Shard ID %s session has been invalidated.', self.shard_id)
_log.info('Shard ID %s session has been invalidated.', self.shard_id)
await self.close(code=1000)
raise ReconnectWebSocket(self.shard_id, resume=False)
log.warning('Unknown OP code %s.', op)
_log.warning('Unknown OP code %s.', op)
return
event = msg.get('t')
if event == 'READY':
self._trace = trace = data.get('_trace', [])
self.sequence = msg['s']
self.session_id = data['session_id']
# pass back shard ID to ready handler
data['__shard_id__'] = self.shard_id
log.info('Shard ID %s has connected to Gateway: %s (Session ID: %s).',
_log.info('Shard ID %s has connected to Gateway: %s (Session ID: %s).',
self.shard_id, ', '.join(trace), self.session_id)
elif event == 'RESUMED':
self._trace = trace = data.get('_trace', [])
# pass back the shard ID to the resumed handler
data['__shard_id__'] = self.shard_id
log.info('Shard ID %s has successfully RESUMED session %s under trace %s.',
_log.info('Shard ID %s has successfully RESUMED session %s under trace %s.',
self.shard_id, self.session_id, ', '.join(trace))
try:
func = self._discord_parsers[event]
except KeyError:
log.debug('Unknown event %s.', event)
_log.debug('Unknown event %s.', event)
else:
func(data)
@ -551,10 +561,10 @@ class DiscordWebSocket:
elif msg.type is aiohttp.WSMsgType.BINARY:
await self.received_message(msg.data)
elif msg.type is aiohttp.WSMsgType.ERROR:
log.debug('Received %s', msg)
_log.debug('Received %s', msg)
raise msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE):
log.debug('Received %s', msg)
_log.debug('Received %s', msg)
raise WebSocketClosure
except (asyncio.TimeoutError, WebSocketClosure) as e:
# Ensure the keep alive handler is closed
@ -563,25 +573,29 @@ class DiscordWebSocket:
self._keep_alive = None
if isinstance(e, asyncio.TimeoutError):
log.info('Timed out receiving packet. Attempting a reconnect.')
_log.info('Timed out receiving packet. Attempting a reconnect.')
raise ReconnectWebSocket(self.shard_id) from None
code = self._close_code or self.socket.close_code
if self._can_handle_close():
log.info('Websocket closed with %s, attempting a reconnect.', code)
_log.info('Websocket closed with %s, attempting a reconnect.', code)
raise ReconnectWebSocket(self.shard_id) from None
else:
log.info('Websocket closed with %s, cannot reconnect.', code)
_log.info('Websocket closed with %s, cannot reconnect.', code)
raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None
async def send(self, data):
async def debug_send(self, data, /):
await self._rate_limiter.block()
self._dispatch('socket_raw_send', data)
await self.socket.send_str(data)
async def send(self, data, /):
await self._rate_limiter.block()
await self.socket.send_str(data)
async def send_as_json(self, data):
try:
await self.send(utils.to_json(data))
await self.send(utils._to_json(data))
except RuntimeError as exc:
if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
@ -589,7 +603,7 @@ class DiscordWebSocket:
async def send_heartbeat(self, data):
# This bypasses the rate limit handling code since it has a higher priority
try:
await self.socket.send_str(utils.to_json(data))
await self.socket.send_str(utils._to_json(data))
except RuntimeError as exc:
if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
@ -615,8 +629,8 @@ class DiscordWebSocket:
}
}
sent = utils.to_json(payload)
log.debug('Sending "%s" to change status', sent)
sent = utils._to_json(payload)
_log.debug('Sending "%s" to change status', sent)
await self.send(sent)
async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None):
@ -652,7 +666,7 @@ class DiscordWebSocket:
}
}
log.debug('Updating our voice state to %s.', payload)
_log.debug('Updating our voice state to %s.', payload)
await self.send_as_json(payload)
async def close(self, code=4000):
@ -720,8 +734,8 @@ class DiscordVoiceWebSocket:
pass
async def send_as_json(self, data):
log.debug('Sending voice websocket frame: %s.', data)
await self.ws.send_str(utils.to_json(data))
_log.debug('Sending voice websocket frame: %s.', data)
await self.ws.send_str(utils._to_json(data))
send_heartbeat = send_as_json
@ -806,7 +820,7 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload)
async def received_message(self, msg):
log.debug('Voice websocket frame received: %s', msg)
_log.debug('Voice websocket frame received: %s', msg)
op = msg['op']
data = msg.get('d')
@ -815,7 +829,7 @@ class DiscordVoiceWebSocket:
elif op == self.HEARTBEAT_ACK:
self._keep_alive.ack()
elif op == self.RESUMED:
log.info('Voice RESUME succeeded.')
_log.info('Voice RESUME succeeded.')
elif op == self.SESSION_DESCRIPTION:
self._connection.mode = data['mode']
await self.load_secret_key(data)
@ -838,7 +852,7 @@ class DiscordVoiceWebSocket:
struct.pack_into('>I', packet, 4, state.ssrc)
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
recv = await self.loop.sock_recv(state.socket, 70)
log.debug('received packet in initial_connection: %s', recv)
_log.debug('received packet in initial_connection: %s', recv)
# the ip is ascii starting at the 4th byte and ending at the first null
ip_start = 4
@ -846,15 +860,15 @@ class DiscordVoiceWebSocket:
state.ip = recv[ip_start:ip_end].decode('ascii')
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
log.debug('detected ip: %s port: %s', state.ip, state.port)
_log.debug('detected ip: %s port: %s', state.ip, state.port)
# there *should* always be at least one supported mode (xsalsa20_poly1305)
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
log.debug('received supported encryption modes: %s', ", ".join(modes))
_log.debug('received supported encryption modes: %s', ", ".join(modes))
mode = modes[0]
await self.select_protocol(state.ip, state.port, mode)
log.info('selected the voice protocol for use (%s)', mode)
_log.info('selected the voice protocol for use (%s)', mode)
@property
def latency(self):
@ -872,7 +886,7 @@ class DiscordVoiceWebSocket:
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
async def load_secret_key(self, data):
log.info('received secret key for voice connection')
_log.info('received secret key for voice connection')
self.secret_key = self._connection.secret_key = data.get('secret_key')
await self.speak()
await self.speak(False)
@ -881,12 +895,12 @@ class DiscordVoiceWebSocket:
# This exception is handled up the chain
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
if msg.type is aiohttp.WSMsgType.TEXT:
await self.received_message(utils.from_json(msg.data))
await self.received_message(utils._from_json(msg.data))
elif msg.type is aiohttp.WSMsgType.ERROR:
log.debug('Received %s', msg)
_log.debug('Received %s', msg)
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
log.debug('Received %s', msg)
_log.debug('Received %s', msg)
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
async def close(self, code=1000):

View File

@ -25,6 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import copy
import unicodedata
from typing import (
Any,
ClassVar,
@ -51,6 +52,7 @@ from .colour import Colour
from .errors import InvalidArgument, ClientException
from .channel import *
from .channel import _guild_channel_factory
from .channel import _threaded_guild_channel_factory
from .enums import (
AuditLogAction,
VideoQualityMode,
@ -71,7 +73,10 @@ from .asset import Asset
from .flags import SystemChannelFlags
from .integrations import Integration, _integration_factory
from .stage_instance import StageInstance
from .threads import Thread
from .threads import Thread, ThreadMember
from .sticker import GuildSticker
from .file import File
__all__ = (
'Guild',
@ -107,6 +112,7 @@ class BanEntry(NamedTuple):
class _GuildLimit(NamedTuple):
emoji: int
stickers: int
bitrate: float
filesize: int
@ -134,12 +140,20 @@ class Guild(Hashable):
Returns the guild's name.
.. describe:: int(x)
Returns the guild's ID.
Attributes
----------
name: :class:`str`
The guild name.
emojis: Tuple[:class:`Emoji`, ...]
All emojis that the guild owns.
stickers: Tuple[:class:`GuildSticker`, ...]
All stickers that the guild owns.
.. versionadded:: 2.0
region: :class:`VoiceRegion`
The region the guild belongs on. There is a chance that the region
will be a :class:`str` if the value is not recognised by the enumerator.
@ -234,6 +248,7 @@ class Guild(Hashable):
'owner_id',
'mfa_level',
'emojis',
'stickers',
'features',
'verification_level',
'explicit_content_filter',
@ -266,11 +281,11 @@ class Guild(Hashable):
)
_PREMIUM_GUILD_LIMITS: ClassVar[Dict[Optional[int], _GuildLimit]] = {
None: _GuildLimit(emoji=50, bitrate=96e3, filesize=8388608),
0: _GuildLimit(emoji=50, bitrate=96e3, filesize=8388608),
1: _GuildLimit(emoji=100, bitrate=128e3, filesize=8388608),
2: _GuildLimit(emoji=150, bitrate=256e3, filesize=52428800),
3: _GuildLimit(emoji=250, bitrate=384e3, filesize=104857600),
None: _GuildLimit(emoji=50, stickers=0, bitrate=96e3, filesize=8388608),
0: _GuildLimit(emoji=50, stickers=0, bitrate=96e3, filesize=8388608),
1: _GuildLimit(emoji=100, stickers=15, bitrate=128e3, filesize=8388608),
2: _GuildLimit(emoji=150, stickers=30, bitrate=256e3, filesize=52428800),
3: _GuildLimit(emoji=250, stickers=60, bitrate=384e3, filesize=104857600),
}
def __init__(self, *, data: GuildPayload, state: ConnectionState):
@ -412,6 +427,9 @@ class Guild(Hashable):
self.mfa_level: MFALevel = guild.get('mfa_level')
self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get('emojis', [])))
self.stickers: Tuple[GuildSticker, ...] = tuple(
map(lambda d: state.store_sticker(self, d), guild.get('stickers', []))
)
self.features: List[GuildFeature] = guild.get('features', [])
self._splash: Optional[str] = guild.get('splash')
self._system_channel_id: Optional[int] = utils._get_as_snowflake(guild, 'system_channel_id')
@ -599,6 +617,23 @@ class Guild(Hashable):
return self._channels.get(id) or self._threads.get(id)
def get_channel_or_thread(self, channel_id: int, /) -> Optional[Union[Thread, GuildChannel]]:
"""Returns a channel or thread with the given ID.
.. versionadded:: 2.0
Parameters
-----------
channel_id: :class:`int`
The ID to search for.
Returns
--------
Optional[Union[:class:`Thread`, :class:`.abc.GuildChannel`]]
The returned channel or thread or ``None`` if not found.
"""
return self._channels.get(channel_id) or self._threads.get(channel_id)
def get_channel(self, channel_id: int, /) -> Optional[GuildChannel]:
"""Returns a channel with the given ID.
@ -680,6 +715,15 @@ class Guild(Hashable):
more_emoji = 200 if 'MORE_EMOJI' in self.features else 50
return max(more_emoji, self._PREMIUM_GUILD_LIMITS[self.premium_tier].emoji)
@property
def sticker_limit(self) -> int:
""":class:`int`: The maximum number of sticker slots this guild has.
.. versionadded:: 2.0
"""
more_stickers = 60 if 'MORE_STICKERS' in self.features else 0
return max(more_stickers, self._PREMIUM_GUILD_LIMITS[self.premium_tier].stickers)
@property
def bitrate_limit(self) -> float:
""":class:`float`: The maximum bitrate for voice channels this guild can have."""
@ -696,7 +740,21 @@ class Guild(Hashable):
"""List[:class:`Member`]: A list of members that belong to this guild."""
return list(self._members.values())
def get_member(self, user_id: int) -> Optional[Member]:
@property
def humans(self) -> List[Member]:
"""List[:class:`Member`]: A list of human members that belong to this guild.
.. versionadded:: 2.0 """
return [member for member in self.members if not member.bot]
@property
def bots(self) -> List[Member]:
"""List[:class:`Member`]: A list of bots that belong to this guild.
.. versionadded:: 2.0 """
return [member for member in self.members if member.bot]
def get_member(self, user_id: int, /) -> Optional[Member]:
"""Returns a member with the given ID.
Parameters
@ -1316,7 +1374,7 @@ class Guild(Hashable):
preferred_locale: str = MISSING,
rules_channel: Optional[TextChannel] = MISSING,
public_updates_channel: Optional[TextChannel] = MISSING,
) -> None:
) -> Guild:
r"""|coro|
Edits the guild.
@ -1330,6 +1388,9 @@ class Guild(Hashable):
.. versionchanged:: 2.0
The `discovery_splash` and `community` keyword-only parameters were added.
.. versionchanged:: 2.0
The newly updated guild is returned.
Parameters
----------
name: :class:`str`
@ -1403,6 +1464,12 @@ class Guild(Hashable):
The image format passed in to ``icon`` is invalid. It must be
PNG or JPG. This is also raised if you are not the owner of the
guild and request an ownership transfer.
Returns
--------
:class:`Guild`
The newly updated guild. Note that this has the same limitations as
mentioned in :meth:`Client.fetch_guild` and may not have full data.
"""
http = self._state.http
@ -1515,7 +1582,8 @@ class Guild(Hashable):
fields['features'] = features
await http.edit_guild(self.id, reason=reason, **fields)
data = await http.edit_guild(self.id, reason=reason, **fields)
return Guild(data=data, state=self._state)
async def fetch_channels(self) -> Sequence[GuildChannel]:
"""|coro|
@ -1552,6 +1620,35 @@ class Guild(Hashable):
return [convert(d) for d in data]
async def active_threads(self) -> List[Thread]:
"""|coro|
Returns a list of active :class:`Thread` that the client can access.
This includes both private and public threads.
.. versionadded:: 2.0
Raises
------
HTTPException
The request to get the active threads failed.
Returns
--------
List[:class:`Thread`]
The active threads
"""
data = await self._state.http.get_active_threads(self.id)
threads = [Thread(guild=self, state=self._state, data=d) for d in data.get('threads', [])]
thread_lookup: Dict[int, Thread] = {thread.id: thread for thread in threads}
for member in data.get('members', []):
thread = thread_lookup.get(int(member['id']))
if thread is not None:
thread._add_member(ThreadMember(parent=thread, data=member))
return threads
# TODO: Remove Optional typing here when async iterators are refactored
def fetch_members(self, *, limit: int = 1000, after: Optional[SnowflakeTime] = None) -> MemberIterator:
"""Retrieves an :class:`.AsyncIterator` that enables receiving the guild's members. In order to use this,
@ -1665,14 +1762,14 @@ class Guild(Hashable):
data: BanPayload = await self._state.http.get_ban(user.id, self.id)
return BanEntry(user=User(state=self._state, data=data['user']), reason=data['reason'])
async def fetch_channel(self, channel_id: int, /) -> GuildChannel:
async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread]:
"""|coro|
Retrieves a :class:`.abc.GuildChannel` with the specified ID.
Retrieves a :class:`.abc.GuildChannel` or :class:`.Thread` with the specified ID.
.. note::
This method is an API call. For general usage, consider :meth:`get_channel` instead.
This method is an API call. For general usage, consider :meth:`get_channel_or_thread` instead.
.. versionadded:: 2.0
@ -1691,12 +1788,12 @@ class Guild(Hashable):
Returns
--------
:class:`.abc.GuildChannel`
Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel from the ID.
"""
data = await self._state.http.get_channel(channel_id)
factory, ch_type = _guild_channel_factory(data['type'])
factory, ch_type = _threaded_guild_channel_factory(data['type'])
if factory is None:
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))
@ -2009,6 +2106,150 @@ class Guild(Hashable):
return [convert(d) for d in data]
async def fetch_stickers(self) -> List[GuildSticker]:
r"""|coro|
Retrieves a list of all :class:`Sticker`\s for the guild.
.. versionadded:: 2.0
.. note::
This method is an API call. For general usage, consider :attr:`stickers` instead.
Raises
---------
HTTPException
An error occurred fetching the stickers.
Returns
--------
List[:class:`GuildSticker`]
The retrieved stickers.
"""
data = await self._state.http.get_all_guild_stickers(self.id)
return [GuildSticker(state=self._state, data=d) for d in data]
async def fetch_sticker(self, sticker_id: int, /) -> GuildSticker:
"""|coro|
Retrieves a custom :class:`Sticker` from the guild.
.. versionadded:: 2.0
.. note::
This method is an API call.
For general usage, consider iterating over :attr:`stickers` instead.
Parameters
-------------
sticker_id: :class:`int`
The sticker's ID.
Raises
---------
NotFound
The sticker requested could not be found.
HTTPException
An error occurred fetching the sticker.
Returns
--------
:class:`GuildSticker`
The retrieved sticker.
"""
data = await self._state.http.get_guild_sticker(self.id, sticker_id)
return GuildSticker(state=self._state, data=data)
async def create_sticker(
self,
*,
name: str,
description: Optional[str] = None,
emoji: str,
file: File,
reason: Optional[str] = None,
) -> GuildSticker:
"""|coro|
Creates a :class:`Sticker` for the guild.
You must have :attr:`~Permissions.manage_emojis_and_stickers` permission to
do this.
.. versionadded:: 2.0
Parameters
-----------
name: :class:`str`
The sticker name. Must be at least 2 characters.
description: Optional[:class:`str`]
The sticker's description. Can be ``None``.
emoji: :class:`str`
The name of a unicode emoji that represents the sticker's expression.
file: :class:`File`
The file of the sticker to upload.
reason: :class:`str`
The reason for creating this sticker. Shows up on the audit log.
Raises
-------
Forbidden
You are not allowed to create stickers.
HTTPException
An error occurred creating a sticker.
Returns
--------
:class:`GuildSticker`
The created sticker.
"""
payload = {
'name': name,
}
if description:
payload['description'] = description
try:
emoji = unicodedata.name(emoji)
except TypeError:
pass
else:
emoji = emoji.replace(' ', '_')
payload['tags'] = emoji
data = await self._state.http.create_guild_sticker(self.id, payload, file, reason)
return self._state.store_sticker(self, data)
async def delete_sticker(self, sticker: Snowflake, *, reason: Optional[str] = None) -> None:
"""|coro|
Deletes the custom :class:`Sticker` from the guild.
You must have :attr:`~Permissions.manage_emojis_and_stickers` permission to
do this.
.. versionadded:: 2.0
Parameters
-----------
sticker: :class:`abc.Snowflake`
The sticker you are deleting.
reason: Optional[:class:`str`]
The reason for deleting this sticker. Shows up on the audit log.
Raises
-------
Forbidden
You are not allowed to delete stickers.
HTTPException
An error occurred deleting the sticker.
"""
await self._state.http.delete_guild_sticker(self.id, sticker.id, reason)
async def fetch_emojis(self) -> List[Emoji]:
r"""|coro|

View File

@ -33,7 +33,6 @@ from typing import (
ClassVar,
Coroutine,
Dict,
Final,
Iterable,
List,
Optional,
@ -49,12 +48,12 @@ import weakref
import aiohttp
from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordServerError, GatewayNotFound
from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordServerError, GatewayNotFound, InvalidArgument
from .gateway import DiscordClientWebSocketResponse
from . import __version__, utils
from .utils import MISSING
log = logging.getLogger(__name__)
_log = logging.getLogger(__name__)
if TYPE_CHECKING:
from .file import File
@ -84,6 +83,7 @@ if TYPE_CHECKING:
widget,
threads,
voice,
sticker,
)
from .types.snowflake import Snowflake, SnowflakeList
@ -99,7 +99,7 @@ async def json_or_text(response: aiohttp.ClientResponse) -> Union[Dict[str, Any]
text = await response.text(encoding='utf-8')
try:
if response.headers['content-type'] == 'application/json':
return utils.from_json(text)
return utils._from_json(text)
except KeyError:
# Thanks Cloudflare
pass
@ -141,7 +141,8 @@ class MaybeUnlock:
def defer(self) -> None:
self._unlock = False
def __exit__(self,
def __exit__(
self,
exc_type: Optional[Type[BE]],
exc: Optional[BE],
traceback: Optional[TracebackType],
@ -152,15 +153,12 @@ class MaybeUnlock:
# For some reason, the Discord voice websocket expects this header to be
# completely lowercase while aiohttp respects spec and does it as case-insensitive
aiohttp.hdrs.WEBSOCKET = 'websocket' #type: ignore
aiohttp.hdrs.WEBSOCKET = 'websocket' # type: ignore
class HTTPClient:
"""Represents an HTTP client sending HTTP requests to the Discord API."""
SUCCESS_LOG: Final[ClassVar[str]] = '{method} {url} has received {text}'
REQUEST_LOG: Final[ClassVar[str]] = '{method} {url} with {json} has returned {status}'
def __init__(
self,
connector: Optional[aiohttp.BaseConnector] = None,
@ -168,7 +166,7 @@ class HTTPClient:
proxy: Optional[str] = None,
proxy_auth: Optional[aiohttp.BasicAuth] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
unsync_clock: bool = True
unsync_clock: bool = True,
) -> None:
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self.connector = connector
@ -212,7 +210,7 @@ class HTTPClient:
*,
files: Optional[Sequence[File]] = None,
form: Optional[Iterable[Dict[str, Any]]] = None,
**kwargs: Any
**kwargs: Any,
) -> Any:
bucket = route.bucket
method = route.method
@ -234,7 +232,7 @@ class HTTPClient:
# some checking if it's a JSON request
if 'json' in kwargs:
headers['Content-Type'] = 'application/json'
kwargs['data'] = utils.to_json(kwargs.pop('json'))
kwargs['data'] = utils._to_json(kwargs.pop('json'))
try:
reason = kwargs.pop('reason')
@ -273,7 +271,7 @@ class HTTPClient:
try:
async with self.__session.request(method, url, **kwargs) as response:
log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), response.status)
_log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), response.status)
# even errors have text involved in them so this is safe to call
data = await json_or_text(response)
@ -283,13 +281,13 @@ class HTTPClient:
if remaining == '0' and response.status != 429:
# we've depleted our current bucket
delta = utils._parse_ratelimit_header(response, use_clock=self.use_clock)
log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta)
_log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta)
maybe_lock.defer()
self.loop.call_later(delta, lock.release)
# the request was successful so just return the text/json
if 300 > response.status >= 200:
log.debug('%s %s has received %s', method, url, data)
_log.debug('%s %s has received %s', method, url, data)
return data
# we are being rate limited
@ -302,22 +300,22 @@ class HTTPClient:
# sleep a bit
retry_after: float = data['retry_after']
log.warning(fmt, retry_after, bucket)
_log.warning(fmt, retry_after, bucket)
# check if it's a global rate limit
is_global = data.get('global', False)
if is_global:
log.warning('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after)
_log.warning('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after)
self._global_over.clear()
await asyncio.sleep(retry_after)
log.debug('Done sleeping for the rate limit. Retrying...')
_log.debug('Done sleeping for the rate limit. Retrying...')
# release the global lock now that the
# global rate limit has passed
if is_global:
self._global_over.set()
log.debug('Global rate limit is now over.')
_log.debug('Global rate limit is now over.')
continue
@ -415,14 +413,15 @@ class HTTPClient:
def send_message(
self,
channel_id: Snowflake,
content: str,
content: Optional[str],
*,
tts: bool = False,
embed: Optional[embed.Embed] = None,
embeds: Optional[List[embed.Embed]] = None,
nonce: Optional[str] = None,
nonce: Optional[str] = None,
allowed_mentions: Optional[message.AllowedMentions] = None,
message_reference: Optional[message.MessageReference] = None,
stickers: Optional[List[sticker.StickerItem]] = None,
components: Optional[List[components.Component]] = None,
) -> Response[message.Message]:
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
@ -436,7 +435,7 @@ class HTTPClient:
if embed:
payload['embeds'] = [embed]
if embeds:
payload['embeds'] = embeds
@ -452,6 +451,9 @@ class HTTPClient:
if components:
payload['components'] = components
if stickers:
payload['sticker_ids'] = stickers
return self.request(r, json=payload)
def send_typing(self, channel_id: Snowflake) -> Response[None]:
@ -465,10 +467,11 @@ class HTTPClient:
content: Optional[str] = None,
tts: bool = False,
embed: Optional[embed.Embed] = None,
embeds: Iterable[Optional[embed.Embed]] = None,
embeds: Optional[Iterable[Optional[embed.Embed]]] = None,
nonce: Optional[str] = None,
allowed_mentions: Optional[message.AllowedMentions] = None,
message_reference: Optional[message.MessageReference] = None,
stickers: Optional[List[sticker.StickerItem]] = None,
components: Optional[List[components.Component]] = None,
) -> Response[message.Message]:
form = []
@ -488,8 +491,10 @@ class HTTPClient:
payload['message_reference'] = message_reference
if components:
payload['components'] = components
if stickers:
payload['sticker_ids'] = stickers
form.append({'name': 'payload_json', 'value': utils.to_json(payload)})
form.append({'name': 'payload_json', 'value': utils._to_json(payload)})
if len(files) == 1:
file = files[0]
form.append(
@ -525,6 +530,7 @@ class HTTPClient:
nonce: Optional[str] = None,
allowed_mentions: Optional[message.AllowedMentions] = None,
message_reference: Optional[message.MessageReference] = None,
stickers: Optional[List[sticker.StickerItem]] = None,
components: Optional[List[components.Component]] = None,
) -> Response[message.Message]:
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
@ -538,14 +544,19 @@ class HTTPClient:
nonce=nonce,
allowed_mentions=allowed_mentions,
message_reference=message_reference,
stickers=stickers,
components=components,
)
def delete_message(self, channel_id: Snowflake, message_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]:
def delete_message(
self, channel_id: Snowflake, message_id: Snowflake, *, reason: Optional[str] = None
) -> Response[None]:
r = Route('DELETE', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id)
return self.request(r, reason=reason)
def delete_messages(self, channel_id: Snowflake, message_ids: SnowflakeList, *, reason: Optional[str] = None) -> Response[None]:
def delete_messages(
self, channel_id: Snowflake, message_ids: SnowflakeList, *, reason: Optional[str] = None
) -> Response[None]:
r = Route('POST', '/channels/{channel_id}/messages/bulk-delete', channel_id=channel_id)
payload = {
'messages': message_ids,
@ -567,7 +578,9 @@ class HTTPClient:
)
return self.request(r)
def remove_reaction(self, channel_id: Snowflake, message_id: Snowflake, emoji: str, member_id: Snowflake) -> Response[None]:
def remove_reaction(
self, channel_id: Snowflake, message_id: Snowflake, emoji: str, member_id: Snowflake
) -> Response[None]:
r = Route(
'DELETE',
'/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/{member_id}',
@ -713,11 +726,7 @@ class HTTPClient:
'delete_message_days': delete_message_days,
}
if reason:
# thanks aiohttp
r.url = f'{r.url}?reason={_uriquote(reason)}'
return self.request(r, params=params)
return self.request(r, params=params, reason=reason)
def unban(self, user_id: Snowflake, guild_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]:
r = Route('DELETE', '/guilds/{guild_id}/bans/{user_id}', guild_id=guild_id, user_id=user_id)
@ -772,11 +781,11 @@ class HTTPClient:
}
return self.request(r, json=payload, reason=reason)
def edit_my_voice_state(self, guild_id: Snowflake, payload: voice.VoiceState) -> Response[None]:
def edit_my_voice_state(self, guild_id: Snowflake, payload: Dict[str, Any]) -> Response[None]:
r = Route('PATCH', '/guilds/{guild_id}/voice-states/@me', guild_id=guild_id)
return self.request(r, json=payload)
def edit_voice_state(self, guild_id: Snowflake, user_id: Snowflake, payload: voice.VoiceState) -> Response[None]:
def edit_voice_state(self, guild_id: Snowflake, user_id: Snowflake, payload: Dict[str, Any]) -> Response[None]:
r = Route('PATCH', '/guilds/{guild_id}/voice-states/{user_id}', guild_id=guild_id, user_id=user_id)
return self.request(r, json=payload)
@ -787,7 +796,7 @@ class HTTPClient:
*,
reason: Optional[str] = None,
**fields: Any,
) -> Response[member.Member]:
) -> Response[member.MemberWithUser]:
r = Route('PATCH', '/guilds/{guild_id}/members/{user_id}', guild_id=guild_id, user_id=user_id)
return self.request(r, json=fields, reason=reason)
@ -817,6 +826,8 @@ class HTTPClient:
'archived',
'auto_archive_duration',
'locked',
'invitable',
'default_auto_archive_duration',
)
payload = {k: v for k, v in options.items() if k in valid_keys}
return self.request(r, reason=reason, json=payload)
@ -871,42 +882,44 @@ class HTTPClient:
# Thread management
def start_public_thread(
def start_thread_with_message(
self,
channel_id: Snowflake,
message_id: Snowflake,
*,
name: str,
auto_archive_duration: threads.ThreadArchiveDuration,
type: threads.ThreadType,
reason: Optional[str] = None,
) -> Response[threads.Thread]:
payload = {
'name': name,
'auto_archive_duration': auto_archive_duration,
'type': type,
}
route = Route(
'POST', '/channels/{channel_id}/messages/{message_id}/threads', channel_id=channel_id, message_id=message_id
)
return self.request(route, json=payload)
return self.request(route, json=payload, reason=reason)
def start_private_thread(
def start_thread_without_message(
self,
channel_id: Snowflake,
*,
name: str,
auto_archive_duration: threads.ThreadArchiveDuration,
type: threads.ThreadType,
invitable: bool = True,
reason: Optional[str] = None,
) -> Response[threads.Thread]:
payload = {
'name': name,
'auto_archive_duration': auto_archive_duration,
'type': type,
'invitable': invitable,
}
route = Route('POST', '/channels/{channel_id}/threads', channel_id=channel_id)
return self.request(route, json=payload)
return self.request(route, json=payload, reason=reason)
def join_thread(self, channel_id: Snowflake) -> Response[None]:
return self.request(Route('POST', '/channels/{channel_id}/thread-members/@me', channel_id=channel_id))
@ -955,8 +968,8 @@ class HTTPClient:
params['limit'] = limit
return self.request(route, params=params)
def get_active_threads(self, channel_id: Snowflake) -> Response[threads.ThreadPaginationPayload]:
route = Route('GET', '/channels/{channel_id}/threads/active', channel_id=channel_id)
def get_active_threads(self, guild_id: Snowflake) -> Response[threads.ThreadPaginationPayload]:
route = Route('GET', '/guilds/{guild_id}/threads/active', guild_id=guild_id)
return self.request(route)
def get_thread_members(self, channel_id: Snowflake) -> Response[List[threads.ThreadMember]]:
@ -1119,7 +1132,9 @@ class HTTPClient:
def get_all_guild_channels(self, guild_id: Snowflake) -> Response[List[guild.GuildChannel]]:
return self.request(Route('GET', '/guilds/{guild_id}/channels', guild_id=guild_id))
def get_members(self, guild_id: Snowflake, limit: int, after: Optional[Snowflake]) -> Response[List[member.Member]]:
def get_members(
self, guild_id: Snowflake, limit: int, after: Optional[Snowflake]
) -> Response[List[member.MemberWithUser]]:
params: Dict[str, Any] = {
'limit': limit,
}
@ -1129,7 +1144,7 @@ class HTTPClient:
r = Route('GET', '/guilds/{guild_id}/members', guild_id=guild_id)
return self.request(r, params=params)
def get_member(self, guild_id: Snowflake, member_id: Snowflake) -> Response[member.Member]:
def get_member(self, guild_id: Snowflake, member_id: Snowflake) -> Response[member.MemberWithUser]:
return self.request(Route('GET', '/guilds/{guild_id}/members/{member_id}', guild_id=guild_id, member_id=member_id))
def prune_members(
@ -1164,6 +1179,71 @@ class HTTPClient:
return self.request(Route('GET', '/guilds/{guild_id}/prune', guild_id=guild_id), params=params)
def get_sticker(self, sticker_id: Snowflake) -> Response[sticker.Sticker]:
return self.request(Route('GET', '/stickers/{sticker_id}', sticker_id=sticker_id))
def list_premium_sticker_packs(self) -> Response[sticker.ListPremiumStickerPacks]:
return self.request(Route('GET', '/sticker-packs'))
def get_all_guild_stickers(self, guild_id: Snowflake) -> Response[List[sticker.GuildSticker]]:
return self.request(Route('GET', '/guilds/{guild_id}/stickers', guild_id=guild_id))
def get_guild_sticker(self, guild_id: Snowflake, sticker_id: Snowflake) -> Response[sticker.GuildSticker]:
return self.request(
Route('GET', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id)
)
def create_guild_sticker(
self, guild_id: Snowflake, payload: sticker.CreateGuildSticker, file: File, reason: str
) -> Response[sticker.GuildSticker]:
initial_bytes = file.fp.read(16)
try:
mime_type = utils._get_mime_type_for_image(initial_bytes)
except InvalidArgument:
if initial_bytes.startswith(b'{'):
mime_type = 'application/json'
else:
mime_type = 'application/octet-stream'
finally:
file.reset()
form: List[Dict[str, Any]] = [
{
'name': 'file',
'value': file.fp,
'filename': file.filename,
'content_type': mime_type,
}
]
for k, v in payload.items():
form.append(
{
'name': k,
'value': v,
}
)
return self.request(
Route('POST', '/guilds/{guild_id}/stickers', guild_id=guild_id), form=form, files=[file], reason=reason
)
def modify_guild_sticker(
self, guild_id: Snowflake, sticker_id: Snowflake, payload: sticker.EditGuildSticker, reason: Optional[str],
) -> Response[sticker.GuildSticker]:
return self.request(
Route('PATCH', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id),
json=payload,
reason=reason,
)
def delete_guild_sticker(self, guild_id: Snowflake, sticker_id: Snowflake, reason: Optional[str]) -> Response[None]:
return self.request(
Route('DELETE', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id),
reason=reason,
)
def get_all_custom_emojis(self, guild_id: Snowflake) -> Response[List[emoji.Emoji]]:
return self.request(Route('GET', '/guilds/{guild_id}/emojis', guild_id=guild_id))
@ -1237,12 +1317,14 @@ class HTTPClient:
return self.request(r)
def delete_integration(self, guild_id: Snowflake, integration_id: Snowflake) -> Response[None]:
def delete_integration(
self, guild_id: Snowflake, integration_id: Snowflake, *, reason: Optional[str] = None
) -> Response[None]:
r = Route(
'DELETE', '/guilds/{guild_id}/integrations/{integration_id}', guild_id=guild_id, integration_id=integration_id
)
return self.request(r)
return self.request(r, reason=reason)
def get_audit_logs(
self,
@ -1285,7 +1367,7 @@ class HTTPClient:
unique: bool = True,
target_type: Optional[invite.InviteTargetType] = None,
target_user_id: Optional[Snowflake] = None,
target_application_id: Optional[Snowflake] = None
target_application_id: Optional[Snowflake] = None,
) -> Response[invite.Invite]:
r = Route('POST', '/channels/{channel_id}/invites', channel_id=channel_id)
payload = {
@ -1306,7 +1388,9 @@ class HTTPClient:
return self.request(r, reason=reason, json=payload)
def get_invite(self, invite_id: str, *, with_counts: bool = True, with_expiration: bool = True) -> Response[invite.Invite]:
def get_invite(
self, invite_id: str, *, with_counts: bool = True, with_expiration: bool = True
) -> Response[invite.Invite]:
params = {
'with_counts': int(with_counts),
'with_expiration': int(with_expiration),
@ -1327,7 +1411,9 @@ class HTTPClient:
def get_roles(self, guild_id: Snowflake) -> Response[List[role.Role]]:
return self.request(Route('GET', '/guilds/{guild_id}/roles', guild_id=guild_id))
def edit_role(self, guild_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None, **fields: Any) -> Response[role.Role]:
def edit_role(
self, guild_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None, **fields: Any
) -> Response[role.Role]:
r = Route('PATCH', '/guilds/{guild_id}/roles/{role_id}', guild_id=guild_id, role_id=role_id)
valid_keys = ('name', 'permissions', 'color', 'hoist', 'mentionable')
payload = {k: v for k, v in fields.items() if k in valid_keys}
@ -1344,7 +1430,7 @@ class HTTPClient:
role_ids: List[int],
*,
reason: Optional[str] = None,
) -> Response[member.Member]:
) -> Response[member.MemberWithUser]:
return self.edit_member(guild_id=guild_id, user_id=user_id, roles=role_ids, reason=reason)
def create_role(self, guild_id: Snowflake, *, reason: Optional[str] = None, **fields: Any) -> Response[role.Role]:
@ -1361,7 +1447,9 @@ class HTTPClient:
r = Route('PATCH', '/guilds/{guild_id}/roles', guild_id=guild_id)
return self.request(r, json=positions, reason=reason)
def add_role(self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]:
def add_role(
self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None
) -> Response[None]:
r = Route(
'PUT',
'/guilds/{guild_id}/members/{user_id}/roles/{role_id}',
@ -1371,7 +1459,9 @@ class HTTPClient:
)
return self.request(r, reason=reason)
def remove_role(self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]:
def remove_role(
self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None
) -> Response[None]:
r = Route(
'DELETE',
'/guilds/{guild_id}/members/{user_id}/roles/{role_id}',
@ -1396,11 +1486,7 @@ class HTTPClient:
return self.request(r, json=payload, reason=reason)
def delete_channel_permissions(
self,
channel_id: Snowflake,
target: channel.OverwriteType,
*,
reason: Optional[str] = None
self, channel_id: Snowflake, target: channel.OverwriteType, *, reason: Optional[str] = None
) -> Response[None]:
r = Route('DELETE', '/channels/{channel_id}/permissions/{target}', channel_id=channel_id, target=target)
return self.request(r, reason=reason)
@ -1414,7 +1500,7 @@ class HTTPClient:
channel_id: Snowflake,
*,
reason: Optional[str] = None,
) -> Response[member.Member]:
) -> Response[member.MemberWithUser]:
return self.edit_member(guild_id=guild_id, user_id=user_id, channel_id=channel_id, reason=reason)
# Stage instance management
@ -1422,7 +1508,7 @@ class HTTPClient:
def get_stage_instance(self, channel_id: Snowflake) -> Response[channel.StageInstance]:
return self.request(Route('GET', '/stage-instances/{channel_id}', channel_id=channel_id))
def create_stage_instance(self, **payload) -> Response[channel.StageInstance]:
def create_stage_instance(self, *, reason: Optional[str], **payload: Any) -> Response[channel.StageInstance]:
valid_keys = (
'channel_id',
'topic',
@ -1430,26 +1516,30 @@ class HTTPClient:
)
payload = {k: v for k, v in payload.items() if k in valid_keys}
return self.request(Route('POST', '/stage-instances'), json=payload)
return self.request(Route('POST', '/stage-instances'), json=payload, reason=reason)
def edit_stage_instance(self, channel_id: Snowflake, **payload) -> Response[None]:
def edit_stage_instance(self, channel_id: Snowflake, *, reason: Optional[str] = None, **payload: Any) -> Response[None]:
valid_keys = (
'topic',
'privacy_level',
)
payload = {k: v for k, v in payload.items() if k in valid_keys}
return self.request(Route('PATCH', '/stage-instances/{channel_id}', channel_id=channel_id), json=payload)
return self.request(
Route('PATCH', '/stage-instances/{channel_id}', channel_id=channel_id), json=payload, reason=reason
)
def delete_stage_instance(self, channel_id: Snowflake) -> Response[None]:
return self.request(Route('DELETE', '/stage-instances/{channel_id}', channel_id=channel_id))
def delete_stage_instance(self, channel_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]:
return self.request(Route('DELETE', '/stage-instances/{channel_id}', channel_id=channel_id), reason=reason)
# Application commands (global)
def get_global_commands(self, application_id: Snowflake) -> Response[List[interactions.ApplicationCommand]]:
return self.request(Route('GET', '/applications/{application_id}/commands', application_id=application_id))
def get_global_command(self, application_id: Snowflake, command_id: Snowflake) -> Response[interactions.ApplicationCommand]:
def get_global_command(
self, application_id: Snowflake, command_id: Snowflake
) -> Response[interactions.ApplicationCommand]:
r = Route(
'GET',
'/applications/{application_id}/commands/{command_id}',
@ -1462,7 +1552,8 @@ class HTTPClient:
r = Route('POST', '/applications/{application_id}/commands', application_id=application_id)
return self.request(r, json=payload)
def edit_global_command(self,
def edit_global_command(
self,
application_id: Snowflake,
command_id: Snowflake,
payload: interactions.EditApplicationCommand,
@ -1490,13 +1581,17 @@ class HTTPClient:
)
return self.request(r)
def bulk_upsert_global_commands(self, application_id: Snowflake, payload) -> Response[List[interactions.ApplicationCommand]]:
def bulk_upsert_global_commands(
self, application_id: Snowflake, payload
) -> Response[List[interactions.ApplicationCommand]]:
r = Route('PUT', '/applications/{application_id}/commands', application_id=application_id)
return self.request(r, json=payload)
# Application commands (guild)
def get_guild_commands(self, application_id: Snowflake, guild_id: Snowflake) -> Response[List[interactions.ApplicationCommand]]:
def get_guild_commands(
self, application_id: Snowflake, guild_id: Snowflake
) -> Response[List[interactions.ApplicationCommand]]:
r = Route(
'GET',
'/applications/{application_id}/guilds/{guild_id}/commands',
@ -1534,7 +1629,8 @@ class HTTPClient:
)
return self.request(r, json=payload)
def edit_guild_command(self,
def edit_guild_command(
self,
application_id: Snowflake,
guild_id: Snowflake,
command_id: Snowflake,
@ -1571,9 +1667,9 @@ class HTTPClient:
return self.request(r)
def bulk_upsert_guild_commands(
self,
self,
application_id: Snowflake,
guild_id: Snowflake,
guild_id: Snowflake,
payload: List[interactions.EditApplicationCommand],
) -> Response[List[interactions.ApplicationCommand]]:
r = Route(
@ -1606,7 +1702,7 @@ class HTTPClient:
form: List[Dict[str, Any]] = [
{
'name': 'payload_json',
'value': utils.to_json(payload),
'value': utils._to_json(payload),
}
]
@ -1628,7 +1724,7 @@ class HTTPClient:
token: str,
*,
type: InteractionResponseType,
data: Optional[interactions.InteractionApplicationCommandCallbackData] = None
data: Optional[interactions.InteractionApplicationCommandCallbackData] = None,
) -> Response[None]:
r = Route(
'POST',
@ -1718,7 +1814,7 @@ class HTTPClient:
content: Optional[str] = None,
embeds: Optional[List[embed.Embed]] = None,
allowed_mentions: Optional[message.AllowedMentions] = None,
)-> Response[message.Message]:
) -> Response[message.Message]:
r = Route(
'PATCH',
'/webhooks/{application_id}/{interaction_token}/messages/{message_id}',

View File

@ -127,7 +127,7 @@ class Integration:
self.user = User(state=self._state, data=user) if user else None
self.enabled: bool = data['enabled']
async def delete(self) -> None:
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|
Deletes the integration.
@ -135,6 +135,13 @@ class Integration:
You must have the :attr:`~Permissions.manage_guild` permission to
do this.
Parameters
-----------
reason: :class:`str`
The reason the integration was deleted. Shows up on the audit log.
.. versionadded:: 2.0
Raises
-------
Forbidden
@ -142,7 +149,7 @@ class Integration:
HTTPException
Deleting the integration failed.
"""
await self._state.http.delete_integration(self.guild.id, self.id)
await self._state.http.delete_integration(self.guild.id, self.id, reason=reason)
class StreamIntegration(Integration):
@ -255,17 +262,10 @@ class StreamIntegration(Integration):
if enable_emoticons is not MISSING:
payload['enable_emoticons'] = enable_emoticons
# This endpoint is undocumented.
# Unsure if it returns the data or not as a result
await self._state.http.edit_integration(self.guild.id, self.id, **payload)
if expire_behaviour is not MISSING:
self.expire_behaviour = expire_behaviour
if enable_emoticons is not MISSING:
self.enable_emoticons = enable_emoticons
if expire_grace_period is not MISSING:
self.expire_grace_period = expire_grace_period
async def sync(self) -> None:
"""|coro|

View File

@ -31,6 +31,7 @@ import asyncio
from . import utils
from .enums import try_enum, InteractionType, InteractionResponseType
from .errors import InteractionResponded, HTTPException, ClientException
from .channel import PartialMessageable, ChannelType
from .user import User
from .member import Member
@ -57,10 +58,12 @@ if TYPE_CHECKING:
from aiohttp import ClientSession
from .embeds import Embed
from .ui.view import View
from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel
from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, PartialMessageable
from .threads import Thread
InteractionChannel = Union[VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, Thread]
InteractionChannel = Union[
VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable
]
MISSING: Any = utils.MISSING
@ -92,6 +95,8 @@ class Interaction:
token: :class:`str`
The token to continue the interaction. These are valid
for 15 minutes.
data: :class:`dict`
The raw interaction data.
"""
__slots__: Tuple[str, ...] = (
@ -111,6 +116,7 @@ class Interaction:
'_original_message',
'_cs_response',
'_cs_followup',
'_cs_channel',
)
def __init__(self, *, data: InteractionPayload, state: ConnectionState):
@ -129,10 +135,9 @@ class Interaction:
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
self.application_id: int = int(data['application_id'])
channel = self.channel or Object(id=self.channel_id) # type: ignore
self.message: Optional[Message]
try:
self.message = Message(state=self._state, channel=channel, data=data['message']) # type: ignore
self.message = Message(state=self._state, channel=self.channel, data=data['message']) # type: ignore
except KeyError:
self.message = None
@ -160,15 +165,21 @@ class Interaction:
"""Optional[:class:`Guild`]: The guild the interaction was sent from."""
return self._state and self._state._get_guild(self.guild_id)
@property
@utils.cached_slot_property('_cs_channel')
def channel(self) -> Optional[InteractionChannel]:
"""Optional[Union[:class:`abc.GuildChannel`, :class:`Thread`]]: The channel the interaction was sent from.
"""Optional[Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]]: The channel the interaction was sent from.
Note that due to a Discord limitation, DM channels are not resolved since there is
no data to complete them.
no data to complete them. These are :class:`PartialMessageable` instead.
"""
guild = self.guild
return guild and guild._resolve_channel(self.channel_id)
channel = guild and guild._resolve_channel(self.channel_id)
if channel is None:
if self.channel_id is not None:
type = ChannelType.text if self.guild_id is not None else ChannelType.private
return PartialMessageable(state=self._state, id=self.channel_id, type=type)
return None
return channel
@property
def permissions(self) -> Permissions:
@ -250,7 +261,7 @@ class Interaction:
files: List[File] = MISSING,
view: Optional[View] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
):
) -> InteractionMessage:
"""|coro|
Edits the original interaction response message.
@ -291,7 +302,12 @@ class Interaction:
TypeError
You specified both ``embed`` and ``embeds`` or ``file`` and ``files``
ValueError
The length of ``embeds`` was invalid
The length of ``embeds`` was invalid.
Returns
--------
:class:`InteractionMessage`
The newly edited message.
"""
previous_mentions: Optional[AllowedMentions] = self._state.allowed_mentions
@ -315,8 +331,11 @@ class Interaction:
files=params.files,
)
# The message channel types should always match
message = InteractionMessage(state=self._state, channel=self.channel, data=data) # type: ignore
if view and not view.is_finished():
self._state.store_view(view, int(data['id']))
self._state.store_view(view, message.id)
return message
async def delete_original_message(self) -> None:
"""|coro|
@ -626,6 +645,9 @@ class _InteractionMessageState:
def store_user(self, data):
return self._parent.store_user(data)
def create_user(self, data):
return self._parent.create_user(data)
@property
def http(self):
return self._parent.http
@ -658,7 +680,7 @@ class InteractionMessage(Message):
files: List[File] = MISSING,
view: Optional[View] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
):
) -> InteractionMessage:
"""|coro|
Edits the message.
@ -693,9 +715,14 @@ class InteractionMessage(Message):
TypeError
You specified both ``embed`` and ``embeds`` or ``file`` and ``files``
ValueError
The length of ``embeds`` was invalid
The length of ``embeds`` was invalid.
Returns
---------
:class:`InteractionMessage`
The newly edited message.
"""
await self._state._interaction.edit_original_message(
return await self._state._interaction.edit_original_message(
content=content,
embeds=embeds,
embed=embed,

View File

@ -230,6 +230,7 @@ class Invite(Hashable):
Returns the invite URL.
The following table illustrates what methods will obtain the attributes:
+------------------------------------+------------------------------------------------------------+
@ -257,7 +258,7 @@ class Invite(Hashable):
Attributes
-----------
max_age: :class:`int`
How long the before the invite expires in seconds.
How long before the invite expires in seconds.
A value of ``0`` indicates that it doesn't expire.
code: :class:`str`
The URL fragment used for the invite.
@ -352,12 +353,12 @@ class Invite(Hashable):
self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None
inviter_data = data.get('inviter')
self.inviter: Optional[User] = None if inviter_data is None else self._state.store_user(inviter_data)
self.inviter: Optional[User] = None if inviter_data is None else self._state.create_user(inviter_data)
self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get('channel'), channel)
target_user_data = data.get('target_user')
self.target_user: Optional[User] = None if target_user_data is None else self._state.store_user(target_user_data)
self.target_user: Optional[User] = None if target_user_data is None else self._state.create_user(target_user_data)
self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0))
@ -433,6 +434,9 @@ class Invite(Hashable):
def __str__(self) -> str:
return self.url
def __int__(self) -> int:
return 0 # To keep the object compatible with the hashable abc.
def __repr__(self) -> str:
return (
f'<Invite code={self.code!r} guild={self.guild!r} '

View File

@ -34,6 +34,7 @@ from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Typ
import discord.abc
from . import utils
from .asset import Asset
from .utils import MISSING
from .user import BaseUser, User, _UserTag
from .activity import create_activity, ActivityTypes
@ -48,11 +49,13 @@ __all__ = (
)
if TYPE_CHECKING:
from .channel import VoiceChannel, StageChannel
from .asset import Asset
from .channel import DMChannel, VoiceChannel, StageChannel
from .flags import PublicUserFlags
from .guild import Guild
from .types.activity import PartialPresenceUpdate
from .types.member import (
GatewayMember as GatewayMemberPayload,
MemberWithUser as MemberWithUserPayload,
Member as MemberPayload,
UserWithMember as UserWithMemberPayload,
)
@ -223,6 +226,10 @@ class Member(discord.abc.Messageable, _UserTag):
Returns the member's name with the discriminator.
.. describe:: int(x)
Returns the user's ID.
Attributes
----------
joined_at: Optional[:class:`datetime.datetime`]
@ -247,7 +254,7 @@ class Member(discord.abc.Messageable, _UserTag):
.. versionadded:: 1.6
premium_since: Optional[:class:`datetime.datetime`]
An aware datetime object that specifies the date and time in UTC when the member used their
Nitro boost on the guild, if available. This could be ``None``.
"Nitro boost" on the guild, if available. This could be ``None``.
"""
__slots__ = (
@ -261,6 +268,7 @@ class Member(discord.abc.Messageable, _UserTag):
'_client_status',
'_user',
'_state',
'_avatar',
)
if TYPE_CHECKING:
@ -270,14 +278,17 @@ class Member(discord.abc.Messageable, _UserTag):
bot: bool
system: bool
created_at: datetime.datetime
default_avatar = User.default_avatar
avatar = User.avatar
dm_channel = User.dm_channel
default_avatar: Asset
avatar: Optional[Asset]
dm_channel: Optional[DMChannel]
create_dm = User.create_dm
mutual_guilds = User.mutual_guilds
public_flags = User.public_flags
mutual_guilds: List[Guild]
public_flags: PublicUserFlags
banner: Optional[Asset]
accent_color: Optional[Colour]
accent_colour: Optional[Colour]
def __init__(self, *, data: GatewayMemberPayload, guild: Guild, state: ConnectionState):
def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState):
self._state: ConnectionState = state
self._user: User = state.store_user(data['user'])
self.guild: Guild = guild
@ -288,10 +299,14 @@ class Member(discord.abc.Messageable, _UserTag):
self.activities: Tuple[ActivityTypes, ...] = tuple()
self.nick: Optional[str] = data.get('nick', None)
self.pending: bool = data.get('pending', False)
self._avatar: Optional[str] = data.get('avatar')
def __str__(self) -> str:
return str(self._user)
def __int__(self) -> int:
return self.id
def __repr__(self) -> str:
return (
f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}'
@ -326,7 +341,7 @@ class Member(discord.abc.Messageable, _UserTag):
try:
member_data = data.pop('member')
except KeyError:
return state.store_user(data)
return state.create_user(data)
else:
member_data['user'] = data # type: ignore
return cls(data=member_data, guild=guild, state=state) # type: ignore
@ -344,6 +359,7 @@ class Member(discord.abc.Messageable, _UserTag):
self.pending = member.pending
self.activities = member.activities
self._state = member._state
self._avatar = member._avatar
# Reference will not be copied unless necessary by PRESENCE_UPDATE
# See below
@ -369,6 +385,7 @@ class Member(discord.abc.Messageable, _UserTag):
self.premium_since = utils.parse_time(data.get('premium_since'))
self._roles = utils.SnowflakeList(map(int, data['roles']))
self._avatar = data.get('avatar')
def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]:
self.activities = tuple(map(create_activity, data['activities']))
@ -493,6 +510,29 @@ class Member(discord.abc.Messageable, _UserTag):
"""
return self.nick or self.name
@property
def display_avatar(self) -> Asset:
""":class:`Asset`: Returns the member's display avatar.
For regular members this is just their avatar, but
if they have a guild specific avatar then that
is returned instead.
.. versionadded:: 2.0
"""
return self.guild_avatar or self._user.avatar or self._user.default_avatar
@property
def guild_avatar(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns an :class:`Asset` for the guild avatar
the member has. If unavailable, ``None`` is returned.
.. versionadded:: 2.0
"""
if self._avatar is None:
return None
return Asset._from_guild_avatar(self._state, self.guild.id, self.id, self._avatar)
@property
def activity(self) -> Optional[ActivityTypes]:
"""Optional[Union[:class:`BaseActivity`, :class:`Spotify`]]: Returns the primary
@ -611,7 +651,7 @@ class Member(discord.abc.Messageable, _UserTag):
roles: List[discord.abc.Snowflake] = MISSING,
voice_channel: Optional[VocalGuildChannel] = MISSING,
reason: Optional[str] = None,
) -> None:
) -> Optional[Member]:
"""|coro|
Edits the member's data.
@ -637,6 +677,9 @@ class Member(discord.abc.Messageable, _UserTag):
.. versionchanged:: 1.1
Can now pass ``None`` to ``voice_channel`` to kick a member from voice.
.. versionchanged:: 2.0
The newly member is now optionally returned, if applicable.
Parameters
-----------
nick: Optional[:class:`str`]
@ -664,6 +707,12 @@ class Member(discord.abc.Messageable, _UserTag):
You do not have the proper permissions to the action requested.
HTTPException
The operation failed.
Returns
--------
Optional[:class:`.Member`]
The newly updated member, if applicable. This is only returned
when certain fields are updated.
"""
http = self._state.http
guild_id = self.guild.id
@ -706,7 +755,8 @@ class Member(discord.abc.Messageable, _UserTag):
payload['roles'] = tuple(r.id for r in roles)
if payload:
await http.edit_member(guild_id, self.id, reason=reason, **payload)
data = await http.edit_member(guild_id, self.id, reason=reason, **payload)
return Member(data=data, guild=self.guild, state=self._state)
async def request_to_speak(self) -> None:
"""|coro|
@ -847,7 +897,7 @@ class Member(discord.abc.Messageable, _UserTag):
for role in roles:
await req(guild_id, user_id, role.id, reason=reason)
def get_role(self, role_id: int) -> Optional[Role]:
def get_role(self, role_id: int, /) -> Optional[Role]:
"""Returns a role with the given ID from roles which the member has.
.. versionadded:: 2.0

View File

@ -29,7 +29,7 @@ import datetime
import re
import io
from os import PathLike
from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload
from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload, TypeVar, Type
from . import utils
from .reaction import Reaction
@ -45,7 +45,7 @@ from .file import File
from .utils import escape_mentions, MISSING
from .guild import Guild
from .mixins import Hashable
from .sticker import Sticker
from .sticker import StickerItem
from .threads import Thread
if TYPE_CHECKING:
@ -70,12 +70,13 @@ if TYPE_CHECKING:
from .abc import GuildChannel, PartialMessageableChannel, MessageableChannel
from .components import Component
from .state import ConnectionState
from .channel import TextChannel, GroupChannel, DMChannel
from .channel import TextChannel, GroupChannel, DMChannel, PartialMessageable
from .mentions import AllowedMentions
from .user import User
from .role import Role
from .ui.view import View
MR = TypeVar('MR', bound='MessageReference')
EmojiInputType = Union[Emoji, PartialEmoji, str]
__all__ = (
@ -124,6 +125,10 @@ class Attachment(Hashable):
Returns the hash of the attachment.
.. describe:: int(x)
Returns the attachment's ID.
.. versionchanged:: 1.7
Attachment can now be casted to :class:`str` and is hashable.
@ -341,7 +346,8 @@ class DeletedReferencedMessage:
@property
def id(self) -> int:
""":class:`int`: The message ID of the deleted referenced message."""
return self._parent.message_id
# the parent's message id won't be None here
return self._parent.message_id # type: ignore
@property
def channel_id(self) -> int:
@ -393,13 +399,13 @@ class MessageReference:
def __init__(self, *, message_id: int, channel_id: int, guild_id: Optional[int] = None, fail_if_not_exists: bool = True):
self._state: Optional[ConnectionState] = None
self.resolved: Optional[Union[Message, DeletedReferencedMessage]] = None
self.message_id: int = message_id
self.message_id: Optional[int] = message_id
self.channel_id: int = channel_id
self.guild_id: Optional[int] = guild_id
self.fail_if_not_exists: bool = fail_if_not_exists
@classmethod
def with_state(cls, state: ConnectionState, data: MessageReferencePayload) -> MessageReference:
def with_state(cls: Type[MR], state: ConnectionState, data: MessageReferencePayload) -> MR:
self = cls.__new__(cls)
self.message_id = utils._get_as_snowflake(data, 'message_id')
self.channel_id = int(data.pop('channel_id'))
@ -410,7 +416,7 @@ class MessageReference:
return self
@classmethod
def from_message(cls, message: Message, *, fail_if_not_exists: bool = True) -> MessageReference:
def from_message(cls: Type[MR], message: Message, *, fail_if_not_exists: bool = True) -> MR:
"""Creates a :class:`MessageReference` from an existing :class:`~discord.Message`.
.. versionadded:: 1.6
@ -457,13 +463,13 @@ class MessageReference:
return f'<MessageReference message_id={self.message_id!r} channel_id={self.channel_id!r} guild_id={self.guild_id!r}>'
def to_dict(self) -> MessageReferencePayload:
result = {'message_id': self.message_id} if self.message_id is not None else {}
result: MessageReferencePayload = {'message_id': self.message_id} if self.message_id is not None else {}
result['channel_id'] = self.channel_id
if self.guild_id is not None:
result['guild_id'] = self.guild_id
if self.fail_if_not_exists is not None:
result['fail_if_not_exists'] = self.fail_if_not_exists
return result # type: ignore
return result
to_message_reference_dict = to_dict
@ -501,6 +507,14 @@ class Message(Hashable):
Returns the message's hash.
.. describe:: str(x)
Returns the message's content.
.. describe:: int(x)
Returns the message's ID.
Attributes
-----------
tts: :class:`bool`
@ -520,7 +534,7 @@ class Message(Hashable):
This is not stored long term within Discord's servers and is only used ephemerally.
embeds: List[:class:`Embed`]
A list of embeds the message has.
channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`, :class:`GroupChannel`]
channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`, :class:`GroupChannel`, :class:`PartialMessageable`]
The :class:`TextChannel` or :class:`Thread` that the message was sent from.
Could be a :class:`DMChannel` or :class:`GroupChannel` if it's a private message.
reference: Optional[:class:`~discord.MessageReference`]
@ -588,8 +602,8 @@ class Message(Hashable):
- ``description``: A string representing the application's description.
- ``icon``: A string representing the icon ID of the application.
- ``cover_image``: A string representing the embed's image asset ID.
stickers: List[:class:`Sticker`]
A list of stickers given to the message.
stickers: List[:class:`StickerItem`]
A list of sticker items given to the message.
.. versionadded:: 1.6
components: List[:class:`Component`]
@ -637,7 +651,7 @@ class Message(Hashable):
_HANDLERS: ClassVar[List[Tuple[str, Callable[..., None]]]]
_CACHED_SLOTS: ClassVar[List[str]]
guild: Optional[Guild]
ref: Optional[MessageReference]
reference: Optional[MessageReference]
mentions: List[Union[User, Member]]
author: Union[User, Member]
role_mentions: List[Role]
@ -646,7 +660,7 @@ class Message(Hashable):
self,
*,
state: ConnectionState,
channel: Union[TextChannel, Thread, DMChannel, GroupChannel],
channel: MessageableChannel,
data: MessagePayload,
):
self._state: ConnectionState = state
@ -666,10 +680,11 @@ class Message(Hashable):
self.tts: bool = data['tts']
self.content: str = data['content']
self.nonce: Optional[Union[int, str]] = data.get('nonce')
self.stickers: List[Sticker] = [Sticker(data=d, state=state) for d in data.get('stickers', [])]
self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])]
self.components: List[Component] = [_component_factory(d) for d in data.get('components', [])]
try:
# if the channel doesn't have a guild attribute, we handle that
self.guild = channel.guild # type: ignore
except AttributeError:
self.guild = state._get_guild(utils._get_as_snowflake(data, 'guild_id'))
@ -694,7 +709,8 @@ class Message(Hashable):
else:
chan, _ = state._get_guild_channel(resolved)
ref.resolved = self.__class__(channel=chan, data=resolved, state=state)
# the channel will be the correct type here
ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore
for handler in ('author', 'member', 'mentions', 'mention_roles'):
try:
@ -708,6 +724,10 @@ class Message(Hashable):
f'<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>'
)
def __str__(self) -> Optional[str]:
return self.content
def _try_patch(self, data, key, transform=None) -> None:
try:
value = data[key]
@ -977,9 +997,17 @@ class Message(Hashable):
def is_system(self) -> bool:
""":class:`bool`: Whether the message is a system message.
A system message is a message that is constructed entirely by the Discord API
in response to something.
.. versionadded:: 1.3
"""
return self.type is not MessageType.default
return self.type not in (
MessageType.default,
MessageType.reply,
MessageType.application_command,
MessageType.thread_starter_message,
)
@utils.cached_slot_property('_cs_system_content')
def system_content(self):
@ -994,21 +1022,27 @@ class Message(Hashable):
if self.type is MessageType.default:
return self.content
if self.type is MessageType.pins_add:
return f'{self.author.name} pinned a message to this channel.'
if self.type is MessageType.recipient_add:
return f'{self.author.name} added {self.mentions[0].name} to the group.'
if self.channel.type is ChannelType.group:
return f'{self.author.name} added {self.mentions[0].name} to the group.'
else:
return f'{self.author.name} added {self.mentions[0].name} to the thread.'
if self.type is MessageType.recipient_remove:
return f'{self.author.name} removed {self.mentions[0].name} from the group.'
if self.channel.type is ChannelType.group:
return f'{self.author.name} removed {self.mentions[0].name} from the group.'
else:
return f'{self.author.name} removed {self.mentions[0].name} from the thread.'
if self.type is MessageType.channel_name_change:
return f'{self.author.name} changed the channel name: {self.content}'
return f'{self.author.name} changed the channel name: **{self.content}**'
if self.type is MessageType.channel_icon_change:
return f'{self.author.name} changed the channel icon.'
if self.type is MessageType.pins_add:
return f'{self.author.name} pinned a message to this channel.'
if self.type is MessageType.new_member:
formats = [
"{0} joined the party.",
@ -1030,21 +1064,34 @@ class Message(Hashable):
return formats[created_at_ms % len(formats)].format(self.author.name)
if self.type is MessageType.premium_guild_subscription:
return f'{self.author.name} just boosted the server!'
if not self.content:
return f'{self.author.name} just boosted the server!'
else:
return f'{self.author.name} just boosted the server **{self.content}** times!'
if self.type is MessageType.premium_guild_tier_1:
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**'
if not self.content:
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**'
else:
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 1!**'
if self.type is MessageType.premium_guild_tier_2:
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**'
if not self.content:
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**'
else:
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 2!**'
if self.type is MessageType.premium_guild_tier_3:
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**'
if not self.content:
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**'
else:
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 3!**'
if self.type is MessageType.channel_follow_add:
return f'{self.author.name} has added {self.content} to this channel'
if self.type is MessageType.guild_stream:
# the author will be a Member
return f'{self.author.name} is live! Now streaming {self.author.activity.name}' # type: ignore
if self.type is MessageType.guild_discovery_disqualified:
@ -1059,13 +1106,23 @@ class Message(Hashable):
if self.type is MessageType.guild_discovery_grace_period_final_warning:
return 'This server has failed Discovery activity requirements for 3 weeks in a row. If this server fails for 1 more week, it will be removed from Discovery.'
if self.type is MessageType.thread_created:
return f'{self.author.name} started a thread: **{self.content}**. See all **threads**.'
if self.type is MessageType.reply:
return self.content
if self.type is MessageType.thread_starter_message:
if self.reference is None or self.reference.resolved is None:
return 'Sorry, we couldn\'t load the first message in this thread'
# the resolved message for the reference will be a Message
return self.reference.resolved.content # type: ignore
if self.type is MessageType.guild_invite_reminder:
return 'Wondering who to invite?\nStart by inviting anyone who can help you build the server!'
async def delete(self, *, delay: Optional[float] = None) -> None:
async def delete(self, *, delay: Optional[float] = None, silent: bool = False) -> None:
"""|coro|
Deletes the message.
@ -1076,12 +1133,17 @@ class Message(Hashable):
.. versionchanged:: 1.1
Added the new ``delay`` keyword-only parameter.
.. versionchanged:: 2.0
Added the new ``silent`` keyword-only parameter.
Parameters
-----------
delay: Optional[:class:`float`]
If provided, the number of seconds to wait in the background
before deleting the message. If the deletion fails then it is silently ignored.
silent: :class:`bool`
If silent is set to ``True``, the error will not be raised, it will be ignored.
This defaults to ``False``
Raises
------
@ -1103,7 +1165,11 @@ class Message(Hashable):
asyncio.create_task(delete(delay))
else:
await self._state.http.delete_message(self.channel.id, self.id)
try:
await self._state.http.delete_message(self.channel.id, self.id)
except Exception:
if not silent:
raise
@overload
async def edit(
@ -1116,7 +1182,7 @@ class Message(Hashable):
delete_after: Optional[float] = ...,
allowed_mentions: Optional[AllowedMentions] = ...,
view: Optional[View] = ...,
) -> None:
) -> Message:
...
@overload
@ -1130,7 +1196,7 @@ class Message(Hashable):
delete_after: Optional[float] = ...,
allowed_mentions: Optional[AllowedMentions] = ...,
view: Optional[View] = ...,
) -> None:
) -> Message:
...
async def edit(
@ -1143,7 +1209,7 @@ class Message(Hashable):
delete_after: Optional[float] = None,
allowed_mentions: Optional[AllowedMentions] = MISSING,
view: Optional[View] = MISSING,
) -> None:
) -> Message:
"""|coro|
Edits the message.
@ -1245,9 +1311,8 @@ class Message(Hashable):
else:
payload['components'] = []
if payload:
data = await self._state.http.edit_message(self.channel.id, self.id, **payload)
self._update(data)
data = await self._state.http.edit_message(self.channel.id, self.id, **payload)
message = Message(state=self._state, channel=self.channel, data=data)
if view and not view.is_finished():
self._state.store_view(view, self.id)
@ -1255,6 +1320,8 @@ class Message(Hashable):
if delete_after is not None:
await self.delete(delay=delete_after)
return message
async def publish(self) -> None:
"""|coro|
@ -1449,49 +1516,51 @@ class Message(Hashable):
"""
await self._state.http.clear_reactions(self.channel.id, self.id)
async def start_thread(self, *, name: str, auto_archive_duration: ThreadArchiveDuration = 1440) -> Thread:
async def create_thread(self, *, name: str, auto_archive_duration: ThreadArchiveDuration = MISSING) -> Thread:
"""|coro|
Starts a public thread from this message.
Creates a public thread from this message.
You must have :attr:`~discord.Permissions.send_messages` and
:attr:`~discord.Permissions.use_threads` in order to start a thread.
You must have :attr:`~discord.Permissions.create_public_threads` in order to
create a public thread from a message.
The channel this message belongs in must be a :class:`TextChannel`.
.. versionadded:: 2.0
Parameters
-----------
name: :class:`str`
The name of the thread.
auto_archive_duration: :class:`int`
The duration in minutes before a thread is automatically archived for inactivity.
Defaults to ``1440`` or 24 hours.
If not provided, the channel's default auto archive duration is used.
Raises
-------
Forbidden
You do not have permissions to start a thread.
You do not have permissions to create a thread.
HTTPException
Starting the thread failed.
Creating the thread failed.
InvalidArgument
This message does not have guild info attached.
Returns
--------
:class:`.Thread`
The started thread.
The created thread.
"""
if self.guild is None:
raise InvalidArgument('This message does not have guild info attached.')
data = await self._state.http.start_public_thread(
default_auto_archive_duration: ThreadArchiveDuration = getattr(self.channel, 'default_auto_archive_duration', 1440)
data = await self._state.http.start_thread_with_message(
self.channel.id,
self.id,
name=name,
auto_archive_duration=auto_archive_duration,
type=ChannelType.public_thread.value,
auto_archive_duration=auto_archive_duration or default_auto_archive_duration,
)
return Thread(guild=self.guild, state=self._state, data=data) # type: ignore
return Thread(guild=self.guild, state=self._state, data=data)
async def reply(self, content: Optional[str] = None, **kwargs) -> Message:
"""|coro|
@ -1581,6 +1650,10 @@ class PartialMessage(Hashable):
Returns the partial message's hash.
.. describe:: int(x)
Returns the partial message's ID.
Attributes
-----------
channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`]
@ -1605,8 +1678,15 @@ class PartialMessage(Hashable):
to_message_reference_dict = Message.to_message_reference_dict
def __init__(self, *, channel: PartialMessageableChannel, id: int):
if channel.type not in (ChannelType.text, ChannelType.news, ChannelType.private):
raise TypeError(f'Expected TextChannel or DMChannel not {type(channel)!r}')
if channel.type not in (
ChannelType.text,
ChannelType.news,
ChannelType.private,
ChannelType.news_thread,
ChannelType.public_thread,
ChannelType.private_thread,
):
raise TypeError(f'Expected TextChannel, DMChannel or Thread not {type(channel)!r}')
self.channel: PartialMessageableChannel = channel
self._state: ConnectionState = channel._state
@ -1730,7 +1810,7 @@ class PartialMessage(Hashable):
fields['embed'] = embed.to_dict()
try:
suppress = fields.pop('suppress')
suppress: bool = fields.pop('suppress')
except KeyError:
pass
else:
@ -1768,9 +1848,10 @@ class PartialMessage(Hashable):
data = await self._state.http.edit_message(self.channel.id, self.id, **fields)
if delete_after is not None:
await self.delete(delay=delete_after) # type: ignore
await self.delete(delay=delete_after)
if fields:
# data isn't unbound
msg = self._state.create_message(channel=self.channel, data=data) # type: ignore
if view and not view.is_finished():
self._state.store_view(view, self.id)

View File

@ -22,24 +22,20 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import TypeVar
__all__ = (
'EqualityComparable',
'Hashable',
)
E = TypeVar('E', bound='EqualityComparable')
class EqualityComparable:
__slots__ = ()
id: int
def __eq__(self: E, other: E) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and other.id == self.id
def __ne__(self: E, other: E) -> bool:
def __ne__(self, other: object) -> bool:
if isinstance(other, self.__class__):
return other.id != self.id
return True
@ -47,5 +43,8 @@ class EqualityComparable:
class Hashable(EqualityComparable):
__slots__ = ()
def __int__(self) -> int:
return self.id
def __hash__(self) -> int:
return self.id >> 22

View File

@ -69,6 +69,10 @@ class Object(Hashable):
Returns the object's hash.
.. describe:: int(x)
Returns the object's ID.
Attributes
-----------
id: :class:`int`

View File

@ -22,8 +22,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import struct
from typing import TYPE_CHECKING, ClassVar, IO, Generator, Tuple, Optional
from .errors import DiscordException
__all__ = (
@ -40,22 +44,29 @@ class OggError(DiscordException):
# https://tools.ietf.org/html/rfc7845
class OggPage:
_header = struct.Struct('<xBQIIIB')
_header: ClassVar[struct.Struct] = struct.Struct('<xBQIIIB')
if TYPE_CHECKING:
flag: int
gran_pos: int
serial: int
pagenum: int
crc: int
segnum: int
def __init__(self, stream):
def __init__(self, stream: IO[bytes]) -> None:
try:
header = stream.read(struct.calcsize(self._header.format))
self.flag, self.gran_pos, self.serial, \
self.pagenum, self.crc, self.segnum = self._header.unpack(header)
self.segtable = stream.read(self.segnum)
self.segtable: bytes = stream.read(self.segnum)
bodylen = sum(struct.unpack('B'*self.segnum, self.segtable))
self.data = stream.read(bodylen)
self.data: bytes = stream.read(bodylen)
except Exception:
raise OggError('bad data stream') from None
def iter_packets(self):
def iter_packets(self) -> Generator[Tuple[bytes, bool], None, None]:
packetlen = offset = 0
partial = True
@ -74,10 +85,10 @@ class OggPage:
yield self.data[offset:], False
class OggStream:
def __init__(self, stream):
self.stream = stream
def __init__(self, stream: IO[bytes]) -> None:
self.stream: IO[bytes] = stream
def _next_page(self):
def _next_page(self) -> Optional[OggPage]:
head = self.stream.read(4)
if head == b'OggS':
return OggPage(self.stream)
@ -86,13 +97,13 @@ class OggStream:
else:
raise OggError('invalid header magic')
def _iter_pages(self):
def _iter_pages(self) -> Generator[OggPage, None, None]:
page = self._next_page()
while page:
yield page
page = self._next_page()
def iter_packets(self):
def iter_packets(self) -> Generator[bytes, None, None]:
partial = b''
for page in self._iter_pages():
for data, complete in page.iter_packets():

View File

@ -22,6 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Tuple, TypedDict, Any, TYPE_CHECKING, Callable, TypeVar, Literal, Optional, overload
import array
import ctypes
import ctypes.util
@ -31,7 +35,24 @@ import os.path
import struct
import sys
from .errors import DiscordException
from .errors import DiscordException, InvalidArgument
if TYPE_CHECKING:
T = TypeVar('T')
BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full']
SIGNAL_CTL = Literal['auto', 'voice', 'music']
class BandCtl(TypedDict):
narrow: int
medium: int
wide: int
superwide: int
full: int
class SignalCtl(TypedDict):
auto: int
voice: int
music: int
__all__ = (
'Encoder',
@ -39,7 +60,7 @@ __all__ = (
'OpusNotLoaded',
)
log = logging.getLogger(__name__)
_log = logging.getLogger(__name__)
c_int_ptr = ctypes.POINTER(ctypes.c_int)
c_int16_ptr = ctypes.POINTER(ctypes.c_int16)
@ -76,7 +97,7 @@ CTL_SET_SIGNAL = 4024
CTL_SET_GAIN = 4034
CTL_LAST_PACKET_DURATION = 4039
band_ctl = {
band_ctl: BandCtl = {
'narrow': 1101,
'medium': 1102,
'wide': 1103,
@ -84,22 +105,22 @@ band_ctl = {
'full': 1105,
}
signal_ctl = {
signal_ctl: SignalCtl = {
'auto': -1000,
'voice': 3001,
'music': 3002,
}
def _err_lt(result, func, args):
def _err_lt(result: int, func: Callable, args: List) -> int:
if result < OK:
log.info('error has happened in %s', func.__name__)
_log.info('error has happened in %s', func.__name__)
raise OpusError(result)
return result
def _err_ne(result, func, args):
def _err_ne(result: T, func: Callable, args: List) -> T:
ret = args[-1]._obj
if ret.value != OK:
log.info('error has happened in %s', func.__name__)
_log.info('error has happened in %s', func.__name__)
raise OpusError(ret.value)
return result
@ -108,7 +129,7 @@ def _err_ne(result, func, args):
# The second one are the types of arguments it takes.
# The third is the result type.
# The fourth is the error handler.
exported_functions = [
exported_functions: List[Tuple[Any, ...]] = [
# Generic
('opus_get_version_string',
None, ctypes.c_char_p, None),
@ -158,7 +179,7 @@ exported_functions = [
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
]
def libopus_loader(name):
def libopus_loader(name: str) -> Any:
# create the library...
lib = ctypes.cdll.LoadLibrary(name)
@ -178,11 +199,11 @@ def libopus_loader(name):
if item[3]:
func.errcheck = item[3]
except KeyError:
log.exception("Error assigning check function to %s", func)
_log.exception("Error assigning check function to %s", func)
return lib
def _load_default():
def _load_default() -> bool:
global _lib
try:
if sys.platform == 'win32':
@ -198,7 +219,7 @@ def _load_default():
return _lib is not None
def load_opus(name):
def load_opus(name: str) -> None:
"""Loads the libopus shared library for use with voice.
If this function is not called then the library uses the function
@ -236,7 +257,7 @@ def load_opus(name):
global _lib
_lib = libopus_loader(name)
def is_loaded():
def is_loaded() -> bool:
"""Function to check if opus lib is successfully loaded either
via the :func:`ctypes.util.find_library` call of :func:`load_opus`.
@ -259,10 +280,10 @@ class OpusError(DiscordException):
The error code returned.
"""
def __init__(self, code):
self.code = code
def __init__(self, code: int):
self.code: int = code
msg = _lib.opus_strerror(self.code).decode('utf-8')
log.info('"%s" has happened', msg)
_log.info('"%s" has happened', msg)
super().__init__(msg)
class OpusNotLoaded(DiscordException):
@ -286,92 +307,96 @@ class _OpusStruct:
return _lib.opus_get_version_string().decode('utf-8')
class Encoder(_OpusStruct):
def __init__(self, application=APPLICATION_AUDIO):
def __init__(self, application: int = APPLICATION_AUDIO):
_OpusStruct.get_opus_version()
self.application = application
self._state = self._create_state()
self.application: int = application
self._state: EncoderStruct = self._create_state()
self.set_bitrate(128)
self.set_fec(True)
self.set_expected_packet_loss_percent(0.15)
self.set_bandwidth('full')
self.set_signal_type('auto')
def __del__(self):
def __del__(self) -> None:
if hasattr(self, '_state'):
_lib.opus_encoder_destroy(self._state)
self._state = None
# This is a destructor, so it's okay to assign None
self._state = None # type: ignore
def _create_state(self):
def _create_state(self) -> EncoderStruct:
ret = ctypes.c_int()
return _lib.opus_encoder_create(self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret))
def set_bitrate(self, kbps):
def set_bitrate(self, kbps: int) -> int:
kbps = min(512, max(16, int(kbps)))
_lib.opus_encoder_ctl(self._state, CTL_SET_BITRATE, kbps * 1024)
return kbps
def set_bandwidth(self, req):
def set_bandwidth(self, req: BAND_CTL) -> None:
if req not in band_ctl:
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}')
k = band_ctl[req]
_lib.opus_encoder_ctl(self._state, CTL_SET_BANDWIDTH, k)
def set_signal_type(self, req):
def set_signal_type(self, req: SIGNAL_CTL) -> None:
if req not in signal_ctl:
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}')
k = signal_ctl[req]
_lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k)
def set_fec(self, enabled=True):
def set_fec(self, enabled: bool = True) -> None:
_lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0)
def set_expected_packet_loss_percent(self, percentage):
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100))))
def set_expected_packet_loss_percent(self, percentage: float) -> None:
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore
def encode(self, pcm, frame_size):
def encode(self, pcm: bytes, frame_size: int) -> bytes:
max_data_bytes = len(pcm)
pcm = ctypes.cast(pcm, c_int16_ptr)
# bytes can be used to reference pointer
pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore
data = (ctypes.c_char * max_data_bytes)()
ret = _lib.opus_encode(self._state, pcm, frame_size, data, max_data_bytes)
ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes)
return array.array('b', data[:ret]).tobytes()
# array can be initialized with bytes but mypy doesn't know
return array.array('b', data[:ret]).tobytes() # type: ignore
class Decoder(_OpusStruct):
def __init__(self):
_OpusStruct.get_opus_version()
self._state = self._create_state()
self._state: DecoderStruct = self._create_state()
def __del__(self):
def __del__(self) -> None:
if hasattr(self, '_state'):
_lib.opus_decoder_destroy(self._state)
self._state = None
# This is a destructor, so it's okay to assign None
self._state = None # type: ignore
def _create_state(self):
def _create_state(self) -> DecoderStruct:
ret = ctypes.c_int()
return _lib.opus_decoder_create(self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret))
@staticmethod
def packet_get_nb_frames(data):
def packet_get_nb_frames(data: bytes) -> int:
"""Gets the number of frames in an Opus packet"""
return _lib.opus_packet_get_nb_frames(data, len(data))
@staticmethod
def packet_get_nb_channels(data):
def packet_get_nb_channels(data: bytes) -> int:
"""Gets the number of channels in an Opus packet"""
return _lib.opus_packet_get_nb_channels(data)
@classmethod
def packet_get_samples_per_frame(cls, data):
def packet_get_samples_per_frame(cls, data: bytes) -> int:
"""Gets the number of samples per frame from an Opus packet"""
return _lib.opus_packet_get_samples_per_frame(data, cls.SAMPLING_RATE)
def _set_gain(self, adjustment):
def _set_gain(self, adjustment: int) -> int:
"""Configures decoder gain adjustment.
Scales the decoded output by a factor specified in Q8 dB units.
@ -383,26 +408,34 @@ class Decoder(_OpusStruct):
"""
return _lib.opus_decoder_ctl(self._state, CTL_SET_GAIN, adjustment)
def set_gain(self, dB):
def set_gain(self, dB: float) -> int:
"""Sets the decoder gain in dB, from -128 to 128."""
dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8)
return self._set_gain(dB_Q8)
def set_volume(self, mult):
def set_volume(self, mult: float) -> int:
"""Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc."""
return self.set_gain(20 * math.log10(mult)) # amplitude ratio
def _get_last_packet_duration(self):
def _get_last_packet_duration(self) -> int:
"""Gets the duration (in samples) of the last packet successfully decoded or concealed."""
ret = ctypes.c_int32()
_lib.opus_decoder_ctl(self._state, CTL_LAST_PACKET_DURATION, ctypes.byref(ret))
return ret.value
def decode(self, data, *, fec=False):
@overload
def decode(self, data: bytes, *, fec: bool) -> bytes:
...
@overload
def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes:
...
def decode(self, data: Optional[bytes], *, fec: bool = False) -> bytes:
if data is None and fec:
raise OpusError("Invalid arguments: FEC cannot be used with null data")
raise InvalidArgument("Invalid arguments: FEC cannot be used with null data")
if data is None:
frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME

View File

@ -147,7 +147,7 @@ class Permissions(BaseFlags):
"""A factory method that creates a :class:`Permissions` with all
permissions set to ``True``.
"""
return cls(0b111111111111111111111111111111111111)
return cls(0b111111111111111111111111111111111111111)
@classmethod
def all_channel(cls: Type[P]) -> P:
@ -167,8 +167,13 @@ class Permissions(BaseFlags):
.. versionchanged:: 1.7
Added :attr:`stream`, :attr:`priority_speaker` and :attr:`use_slash_commands` permissions.
.. versionchanged:: 2.0
Added :attr:`create_public_threads`, :attr:`create_private_threads`, :attr:`manage_threads`,
:attr:`use_external_stickers`, :attr:`send_messages_in_threads` and
:attr:`request_to_speak` permissions.
"""
return cls(0b10110011111101111111111101010001)
return cls(0b111110110110011111101111111111101010001)
@classmethod
def general(cls: Type[P]) -> P:
@ -200,8 +205,12 @@ class Permissions(BaseFlags):
.. versionchanged:: 1.7
Permission :attr:`read_messages` is no longer part of the text permissions.
Added :attr:`use_slash_commands` permission.
.. versionchanged:: 2.0
Added :attr:`create_public_threads`, :attr:`create_private_threads`, :attr:`manage_threads`,
:attr:`send_messages_in_threads` and :attr:`use_external_stickers` permissions.
"""
return cls(0b10000000000001111111100001000000)
return cls(0b111110010000000000001111111100001000000)
@classmethod
def voice(cls: Type[P]) -> P:
@ -462,6 +471,14 @@ class Permissions(BaseFlags):
""":class:`bool`: Returns ``True`` if a user can create, edit, or delete emojis."""
return 1 << 30
@make_permission_alias('manage_emojis')
def manage_emojis_and_stickers(self) -> int:
""":class:`bool`: An alias for :attr:`manage_emojis`.
.. versionadded:: 2.0
"""
return 1 << 30
@flag_value
def use_slash_commands(self) -> int:
""":class:`bool`: Returns ``True`` if a user can use slash commands.
@ -495,21 +512,45 @@ class Permissions(BaseFlags):
return 1 << 34
@flag_value
def use_threads(self) -> int:
""":class:`bool`: Returns ``True`` if a user can create and participate in public threads.
def create_public_threads(self) -> int:
""":class:`bool`: Returns ``True`` if a user can create public threads.
.. versionadded:: 2.0
"""
return 1 << 35
@flag_value
def use_private_threads(self) -> int:
""":class:`bool`: Returns ``True`` if a user can create and participate in private threads.
def create_private_threads(self) -> int:
""":class:`bool`: Returns ``True`` if a user can create private threads.
.. versionadded:: 2.0
"""
return 1 << 36
@flag_value
def external_stickers(self) -> int:
""":class:`bool`: Returns ``True`` if a user can use stickers from other guilds.
.. versionadded:: 2.0
"""
return 1 << 37
@make_permission_alias('external_stickers')
def use_external_stickers(self) -> int:
""":class:`bool`: An alias for :attr:`external_stickers`.
.. versionadded:: 2.0
"""
return 1 << 37
@flag_value
def send_messages_in_threads(self) -> int:
""":class:`bool`: Returns ``True`` if a user can send messages in threads.
.. versionadded:: 2.0
"""
return 1 << 38
PO = TypeVar('PO', bound='PermissionOverwrite')
def _augment_from_permissions(cls):
@ -613,12 +654,16 @@ class PermissionOverwrite:
manage_permissions: Optional[bool]
manage_webhooks: Optional[bool]
manage_emojis: Optional[bool]
manage_emojis_and_stickers: Optional[bool]
use_slash_commands: Optional[bool]
request_to_speak: Optional[bool]
manage_events: Optional[bool]
manage_threads: Optional[bool]
use_threads: Optional[bool]
use_private_threads: Optional[bool]
create_public_threads: Optional[bool]
create_private_threads: Optional[bool]
send_messages_in_threads: Optional[bool]
external_stickers: Optional[bool]
use_external_stickers: Optional[bool]
def __init__(self, **kwargs: Optional[bool]):
self._values: Dict[str, Optional[bool]] = {}
@ -641,7 +686,7 @@ class PermissionOverwrite:
else:
self._values[key] = value
def pair(self):
def pair(self) -> Tuple[Permissions, Permissions]:
"""Tuple[:class:`Permissions`, :class:`Permissions`]: Returns the (allow, deny) pair from this overwrite."""
allow = Permissions.none()

View File

@ -50,7 +50,7 @@ if TYPE_CHECKING:
AT = TypeVar('AT', bound='AudioSource')
FT = TypeVar('FT', bound='FFmpegOpusAudio')
log: logging.Logger = logging.getLogger(__name__)
_log = logging.getLogger(__name__)
__all__ = (
'AudioSource',
@ -140,13 +140,25 @@ class FFmpegAudio(AudioSource):
.. versionadded:: 1.3
"""
def __init__(self, source: str, *, executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any):
def __init__(self, source: Union[str, io.BufferedIOBase], *, executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any):
piping = subprocess_kwargs.get('stdin') == subprocess.PIPE
if piping and isinstance(source, str):
raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin")
args = [executable, *args]
kwargs = {'stdout': subprocess.PIPE}
kwargs.update(subprocess_kwargs)
self._process: subprocess.Popen = self._spawn_process(args, **kwargs)
self._stdout: IO[bytes] = self._process.stdout # type: ignore
self._stdin: Optional[IO[Bytes]] = None
self._pipe_thread: Optional[threading.Thread] = None
if piping:
n = f'popen-stdin-writer:{id(self):#x}'
self._stdin = self._process.stdin
self._pipe_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n)
self._pipe_thread.start()
def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen:
process = None
@ -160,26 +172,44 @@ class FFmpegAudio(AudioSource):
else:
return process
def cleanup(self) -> None:
def _kill_process(self) -> None:
proc = self._process
if proc is MISSING:
return
log.info('Preparing to terminate ffmpeg process %s.', proc.pid)
_log.info('Preparing to terminate ffmpeg process %s.', proc.pid)
try:
proc.kill()
except Exception:
log.exception("Ignoring error attempting to kill ffmpeg process %s", proc.pid)
_log.exception('Ignoring error attempting to kill ffmpeg process %s', proc.pid)
if proc.poll() is None:
log.info('ffmpeg process %s has not terminated. Waiting to terminate...', proc.pid)
_log.info('ffmpeg process %s has not terminated. Waiting to terminate...', proc.pid)
proc.communicate()
log.info('ffmpeg process %s should have terminated with a return code of %s.', proc.pid, proc.returncode)
_log.info('ffmpeg process %s should have terminated with a return code of %s.', proc.pid, proc.returncode)
else:
log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode)
_log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode)
self._process = self._stdout = MISSING
def _pipe_writer(self, source: io.BufferedIOBase) -> None:
while self._process:
# arbitrarily large read size
data = source.read(8192)
if not data:
self._process.terminate()
return
try:
self._stdin.write(data)
except Exception:
_log.debug('Write error for %s, this is probably not a problem', self, exc_info=True)
# at this point the source data is either exhausted or the process is fubar
self._process.terminate()
return
def cleanup(self) -> None:
self._kill_process()
self._process = self._stdout = self._stdin = MISSING
class FFmpegPCMAudio(FFmpegAudio):
"""An audio source from FFmpeg (or AVConv).
@ -218,16 +248,16 @@ class FFmpegPCMAudio(FFmpegAudio):
def __init__(
self,
source: str,
source: Union[str, io.BufferedIOBase],
*,
executable: str = 'ffmpeg',
pipe: bool = False,
stderr: Optional[IO[str]] = None,
before_options: Optional[str] = None,
before_options: Optional[str] = None,
options: Optional[str] = None
) -> None:
args = []
subprocess_kwargs = {'stdin': source if pipe else subprocess.DEVNULL, 'stderr': stderr}
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
if isinstance(before_options, str):
args.extend(shlex.split(before_options))
@ -315,7 +345,7 @@ class FFmpegOpusAudio(FFmpegAudio):
def __init__(
self,
source: str,
source: Union[str, io.BufferedIOBase],
*,
bitrate: int = 128,
codec: Optional[str] = None,
@ -327,7 +357,7 @@ class FFmpegOpusAudio(FFmpegAudio):
) -> None:
args = []
subprocess_kwargs = {'stdin': source if pipe else subprocess.DEVNULL, 'stderr': stderr}
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
if isinstance(before_options, str):
args.extend(shlex.split(before_options))
@ -384,7 +414,6 @@ class FFmpegOpusAudio(FFmpegAudio):
def custom_probe(source, executable):
# some analysis code here
return codec, bitrate
source = await discord.FFmpegOpusAudio.from_probe("song.webm", method=custom_probe)
@ -480,18 +509,18 @@ class FFmpegOpusAudio(FFmpegAudio):
codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable)) # type: ignore
except Exception:
if not fallback:
log.exception("Probe '%s' using '%s' failed", method, executable)
_log.exception("Probe '%s' using '%s' failed", method, executable)
return # type: ignore
log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable)
_log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable)
try:
codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) # type: ignore
except Exception:
log.exception("Fallback probe using '%s' failed", executable)
_log.exception("Fallback probe using '%s' failed", executable)
else:
log.info("Fallback probe found codec=%s, bitrate=%s", codec, bitrate)
_log.info("Fallback probe found codec=%s, bitrate=%s", codec, bitrate)
else:
log.info("Probe found codec=%s, bitrate=%s", codec, bitrate)
_log.info("Probe found codec=%s, bitrate=%s", codec, bitrate)
finally:
return codec, bitrate
@ -656,12 +685,12 @@ class AudioPlayer(threading.Thread):
try:
self.after(error)
except Exception as exc:
log.exception('Calling the after function failed.')
_log.exception('Calling the after function failed.')
exc.__context__ = error
traceback.print_exception(type(exc), exc, exc.__traceback__)
elif error:
msg = f'Exception in voice thread {self.name}'
log.exception(msg, exc_info=error)
_log.exception(msg, exc_info=error)
print(msg, file=sys.stderr)
traceback.print_exception(type(error), error, error.__traceback__)
@ -698,4 +727,4 @@ class AudioPlayer(threading.Thread):
try:
asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.loop)
except Exception as e:
log.info("Speaking call in player failed: %s", e)
_log.info("Speaking call in player failed: %s", e)

0
discord/py.typed Normal file
View File

View File

@ -22,6 +22,25 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Set, List
if TYPE_CHECKING:
from .types.raw_models import (
MessageDeleteEvent,
BulkMessageDeleteEvent,
ReactionActionEvent,
MessageUpdateEvent,
ReactionClearEvent,
ReactionClearEmojiEvent,
IntegrationDeleteEvent
)
from .message import Message
from .partial_emoji import PartialEmoji
from .member import Member
from .enums import ChannelType, try_enum
__all__ = (
@ -35,11 +54,13 @@ __all__ = (
'RawThreadDeleteEvent',
)
class _RawReprMixin:
def __repr__(self):
def __repr__(self) -> str:
value = ' '.join(f'{attr}={getattr(self, attr)!r}' for attr in self.__slots__)
return f'<{self.__class__.__name__} {value}>'
class RawMessageDeleteEvent(_RawReprMixin):
"""Represents the event payload for a :func:`on_raw_message_delete` event.
@ -57,14 +78,15 @@ class RawMessageDeleteEvent(_RawReprMixin):
__slots__ = ('message_id', 'channel_id', 'guild_id', 'cached_message')
def __init__(self, data):
self.message_id = int(data['id'])
self.channel_id = int(data['channel_id'])
self.cached_message = None
def __init__(self, data: MessageDeleteEvent) -> None:
self.message_id: int = int(data['id'])
self.channel_id: int = int(data['channel_id'])
self.cached_message: Optional[Message] = None
try:
self.guild_id = int(data['guild_id'])
self.guild_id: Optional[int] = int(data['guild_id'])
except KeyError:
self.guild_id = None
self.guild_id: Optional[int] = None
class RawBulkMessageDeleteEvent(_RawReprMixin):
"""Represents the event payload for a :func:`on_raw_bulk_message_delete` event.
@ -83,15 +105,16 @@ class RawBulkMessageDeleteEvent(_RawReprMixin):
__slots__ = ('message_ids', 'channel_id', 'guild_id', 'cached_messages')
def __init__(self, data):
self.message_ids = {int(x) for x in data.get('ids', [])}
self.channel_id = int(data['channel_id'])
self.cached_messages = []
def __init__(self, data: BulkMessageDeleteEvent) -> None:
self.message_ids: Set[int] = {int(x) for x in data.get('ids', [])}
self.channel_id: int = int(data['channel_id'])
self.cached_messages: List[Message] = []
try:
self.guild_id = int(data['guild_id'])
self.guild_id: Optional[int] = int(data['guild_id'])
except KeyError:
self.guild_id = None
self.guild_id: Optional[int] = None
class RawMessageUpdateEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_message_edit` event.
@ -118,16 +141,17 @@ class RawMessageUpdateEvent(_RawReprMixin):
__slots__ = ('message_id', 'channel_id', 'guild_id', 'data', 'cached_message')
def __init__(self, data):
self.message_id = int(data['id'])
self.channel_id = int(data['channel_id'])
self.data = data
self.cached_message = None
def __init__(self, data: MessageUpdateEvent) -> None:
self.message_id: int = int(data['id'])
self.channel_id: int = int(data['channel_id'])
self.data: MessageUpdateEvent = data
self.cached_message: Optional[Message] = None
try:
self.guild_id = int(data['guild_id'])
self.guild_id: Optional[int] = int(data['guild_id'])
except KeyError:
self.guild_id = None
self.guild_id: Optional[int] = None
class RawReactionActionEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_reaction_add` or
@ -161,18 +185,19 @@ class RawReactionActionEvent(_RawReprMixin):
__slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji',
'event_type', 'member')
def __init__(self, data, emoji, event_type):
self.message_id = int(data['message_id'])
self.channel_id = int(data['channel_id'])
self.user_id = int(data['user_id'])
self.emoji = emoji
self.event_type = event_type
self.member = None
def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None:
self.message_id: int = int(data['message_id'])
self.channel_id: int = int(data['channel_id'])
self.user_id: int = int(data['user_id'])
self.emoji: PartialEmoji = emoji
self.event_type: str = event_type
self.member: Optional[Member] = None
try:
self.guild_id = int(data['guild_id'])
self.guild_id: Optional[int] = int(data['guild_id'])
except KeyError:
self.guild_id = None
self.guild_id: Optional[int] = None
class RawReactionClearEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_reaction_clear` event.
@ -189,14 +214,15 @@ class RawReactionClearEvent(_RawReprMixin):
__slots__ = ('message_id', 'channel_id', 'guild_id')
def __init__(self, data):
self.message_id = int(data['message_id'])
self.channel_id = int(data['channel_id'])
def __init__(self, data: ReactionClearEvent) -> None:
self.message_id: int = int(data['message_id'])
self.channel_id: int = int(data['channel_id'])
try:
self.guild_id = int(data['guild_id'])
self.guild_id: Optional[int] = int(data['guild_id'])
except KeyError:
self.guild_id = None
self.guild_id: Optional[int] = None
class RawReactionClearEmojiEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_reaction_clear_emoji` event.
@ -217,15 +243,16 @@ class RawReactionClearEmojiEvent(_RawReprMixin):
__slots__ = ('message_id', 'channel_id', 'guild_id', 'emoji')
def __init__(self, data, emoji):
self.emoji = emoji
self.message_id = int(data['message_id'])
self.channel_id = int(data['channel_id'])
def __init__(self, data: ReactionClearEmojiEvent, emoji: PartialEmoji) -> None:
self.emoji: PartialEmoji = emoji
self.message_id: int = int(data['message_id'])
self.channel_id: int = int(data['channel_id'])
try:
self.guild_id = int(data['guild_id'])
self.guild_id: Optional[int] = int(data['guild_id'])
except KeyError:
self.guild_id = None
self.guild_id: Optional[int] = None
class RawIntegrationDeleteEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_integration_delete` event.
@ -244,14 +271,14 @@ class RawIntegrationDeleteEvent(_RawReprMixin):
__slots__ = ('integration_id', 'application_id', 'guild_id')
def __init__(self, data):
self.integration_id = int(data['id'])
self.guild_id = int(data['guild_id'])
def __init__(self, data: IntegrationDeleteEvent) -> None:
self.integration_id: int = int(data['id'])
self.guild_id: int = int(data['guild_id'])
try:
self.application_id = int(data['application_id'])
self.application_id: Optional[int] = int(data['application_id'])
except KeyError:
self.application_id = None
self.application_id: Optional[int] = None
class RawThreadDeleteEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_thread_delete` event.

View File

@ -42,6 +42,7 @@ if TYPE_CHECKING:
Role as RolePayload,
RoleTags as RoleTagPayload,
)
from .types.guild import RolePositionUpdate
from .guild import Guild
from .member import Member
from .state import ConnectionState
@ -140,6 +141,14 @@ class Role(Hashable):
Returns the role's name.
.. describe:: str(x)
Returns the role's ID.
.. describe:: int(x)
Returns the role's ID.
Attributes
----------
id: :class:`int`
@ -194,6 +203,9 @@ class Role(Hashable):
def __str__(self) -> str:
return self.name
def __int__(self) -> int:
return self.id
def __repr__(self) -> str:
return f'<Role id={self.id} name={self.name!r}>'
@ -336,7 +348,7 @@ class Role(Hashable):
else:
roles.append(self.id)
payload = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)]
payload: List[RolePositionUpdate] = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)]
await http.move_role_position(self.guild.id, payload, reason=reason)
async def edit(
@ -350,7 +362,7 @@ class Role(Hashable):
mentionable: bool = MISSING,
position: int = MISSING,
reason: Optional[str] = MISSING,
) -> None:
) -> Optional[Role]:
"""|coro|
Edits the role.
@ -363,6 +375,9 @@ class Role(Hashable):
.. versionchanged:: 1.4
Can now pass ``int`` to ``colour`` keyword-only parameter.
.. versionchanged:: 2.0
Edits are no longer in-place, the newly edited role is returned instead.
Parameters
-----------
name: :class:`str`
@ -390,11 +405,14 @@ class Role(Hashable):
InvalidArgument
An invalid position was given or the default
role was asked to be moved.
"""
Returns
--------
:class:`Role`
The newly edited role.
"""
if position is not MISSING:
await self._move(position, reason=reason)
self.position = position
payload: Dict[str, Any] = {}
if color is not MISSING:
@ -419,7 +437,7 @@ class Role(Hashable):
payload['mentionable'] = mentionable
data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload)
self._update(data)
return Role(guild=self.guild, data=data, state=self._state)
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|

View File

@ -22,8 +22,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
import itertools
import logging
import aiohttp
@ -34,22 +35,30 @@ from .backoff import ExponentialBackoff
from .gateway import *
from .errors import (
ClientException,
InvalidArgument,
HTTPException,
GatewayNotFound,
ConnectionClosed,
PrivilegedIntentsRequired,
)
from . import utils
from .enums import Status
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict, TypeVar
if TYPE_CHECKING:
from .gateway import DiscordWebSocket
from .activity import BaseActivity
from .enums import Status
EI = TypeVar('EI', bound='EventItem')
__all__ = (
'AutoShardedClient',
'ShardInfo',
)
log = logging.getLogger(__name__)
_log = logging.getLogger(__name__)
class EventType:
close = 0
@ -59,39 +68,41 @@ class EventType:
terminate = 4
clean_close = 5
class EventItem:
__slots__ = ('type', 'shard', 'error')
def __init__(self, etype, shard, error):
self.type = etype
self.shard = shard
self.error = error
def __init__(self, etype: int, shard: Optional['Shard'], error: Optional[Exception]) -> None:
self.type: int = etype
self.shard: Optional['Shard'] = shard
self.error: Optional[Exception] = error
def __lt__(self, other):
def __lt__(self: EI, other: EI) -> bool:
if not isinstance(other, EventItem):
return NotImplemented
return self.type < other.type
def __eq__(self, other):
def __eq__(self: EI, other: EI) -> bool:
if not isinstance(other, EventItem):
return NotImplemented
return self.type == other.type
def __hash__(self):
def __hash__(self) -> int:
return hash(self.type)
class Shard:
def __init__(self, ws, client, queue_put):
self.ws = ws
self._client = client
self._dispatch = client.dispatch
self._queue_put = queue_put
self.loop = self._client.loop
self._disconnect = False
def __init__(self, ws: DiscordWebSocket, client: AutoShardedClient, queue_put: Callable[[EventItem], None]) -> None:
self.ws: DiscordWebSocket = ws
self._client: Client = client
self._dispatch: Callable[..., None] = client.dispatch
self._queue_put: Callable[[EventItem], None] = queue_put
self.loop: asyncio.AbstractEventLoop = self._client.loop
self._disconnect: bool = False
self._reconnect = client._reconnect
self._backoff = ExponentialBackoff()
self._task = None
self._handled_exceptions = (
self._backoff: ExponentialBackoff = ExponentialBackoff()
self._task: Optional[asyncio.Task] = None
self._handled_exceptions: Tuple[Type[Exception], ...] = (
OSError,
HTTPException,
GatewayNotFound,
@ -101,25 +112,26 @@ class Shard:
)
@property
def id(self):
return self.ws.shard_id
def id(self) -> int:
# DiscordWebSocket.shard_id is set in the from_client classmethod
return self.ws.shard_id # type: ignore
def launch(self):
def launch(self) -> None:
self._task = self.loop.create_task(self.worker())
def _cancel_task(self):
def _cancel_task(self) -> None:
if self._task is not None and not self._task.done():
self._task.cancel()
async def close(self):
async def close(self) -> None:
self._cancel_task()
await self.ws.close(code=1000)
async def disconnect(self):
async def disconnect(self) -> None:
await self.close()
self._dispatch('shard_disconnect', self.id)
async def _handle_disconnect(self, e):
async def _handle_disconnect(self, e: Exception) -> None:
self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id)
if not self._reconnect:
@ -144,11 +156,11 @@ class Shard:
return
retry = self._backoff.delay()
log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e)
_log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e)
await asyncio.sleep(retry)
self._queue_put(EventItem(EventType.reconnect, self, e))
async def worker(self):
async def worker(self) -> None:
while not self._client.is_closed():
try:
await self.ws.poll_event()
@ -165,14 +177,19 @@ class Shard:
self._queue_put(EventItem(EventType.terminate, self, e))
break
async def reidentify(self, exc):
async def reidentify(self, exc: ReconnectWebSocket) -> None:
self._cancel_task()
self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id)
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
_log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
try:
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
session=self.ws.session_id, sequence=self.ws.sequence)
coro = DiscordWebSocket.from_client(
self._client,
resume=exc.resume,
shard_id=self.id,
session=self.ws.session_id,
sequence=self.ws.sequence,
)
self.ws = await asyncio.wait_for(coro, timeout=60.0)
except self._handled_exceptions as e:
await self._handle_disconnect(e)
@ -183,7 +200,7 @@ class Shard:
else:
self.launch()
async def reconnect(self):
async def reconnect(self) -> None:
self._cancel_task()
try:
coro = DiscordWebSocket.from_client(self._client, shard_id=self.id)
@ -197,6 +214,7 @@ class Shard:
else:
self.launch()
class ShardInfo:
"""A class that gives information and control over a specific shard.
@ -215,16 +233,16 @@ class ShardInfo:
__slots__ = ('_parent', 'id', 'shard_count')
def __init__(self, parent, shard_count):
self._parent = parent
self.id = parent.id
self.shard_count = shard_count
def __init__(self, parent: Shard, shard_count: Optional[int]) -> None:
self._parent: Shard = parent
self.id: int = parent.id
self.shard_count: Optional[int] = shard_count
def is_closed(self):
def is_closed(self) -> bool:
""":class:`bool`: Whether the shard connection is currently closed."""
return not self._parent.ws.open
async def disconnect(self):
async def disconnect(self) -> None:
"""|coro|
Disconnects a shard. When this is called, the shard connection will no
@ -237,7 +255,7 @@ class ShardInfo:
await self._parent.disconnect()
async def reconnect(self):
async def reconnect(self) -> None:
"""|coro|
Disconnects and then connects the shard again.
@ -246,7 +264,7 @@ class ShardInfo:
await self._parent.disconnect()
await self._parent.reconnect()
async def connect(self):
async def connect(self) -> None:
"""|coro|
Connects a shard. If the shard is already connected this does nothing.
@ -257,11 +275,11 @@ class ShardInfo:
await self._parent.reconnect()
@property
def latency(self):
def latency(self) -> float:
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds for this shard."""
return self._parent.ws.latency
def is_ws_ratelimited(self):
def is_ws_ratelimited(self) -> bool:
""":class:`bool`: Whether the websocket is currently rate limited.
This can be useful to know when deciding whether you should query members
@ -271,6 +289,7 @@ class ShardInfo:
"""
return self._parent.ws.is_ratelimited()
class AutoShardedClient(Client):
"""A client similar to :class:`Client` except it handles the complications
of sharding for the user into a more manageable and transparent single
@ -297,9 +316,13 @@ class AutoShardedClient(Client):
shard_ids: Optional[List[:class:`int`]]
An optional list of shard_ids to launch the shards with.
"""
def __init__(self, *args, loop=None, **kwargs):
if TYPE_CHECKING:
_connection: AutoShardedConnectionState
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None:
kwargs.pop('shard_id', None)
self.shard_ids = kwargs.pop('shard_ids', None)
self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None)
super().__init__(*args, loop=loop, **kwargs)
if self.shard_ids is not None:
@ -315,18 +338,24 @@ class AutoShardedClient(Client):
self._connection._get_client = lambda: self
self.__queue = asyncio.PriorityQueue()
def _get_websocket(self, guild_id=None, *, shard_id=None):
def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket:
if shard_id is None:
shard_id = (guild_id >> 22) % self.shard_count
# guild_id won't be None if shard_id is None and shard_count won't be None here
shard_id = (guild_id >> 22) % self.shard_count # type: ignore
return self.__shards[shard_id].ws
def _get_state(self, **options):
return AutoShardedConnectionState(dispatch=self.dispatch,
handlers=self._handlers,
hooks=self._hooks, http=self.http, loop=self.loop, **options)
def _get_state(self, **options: Any) -> AutoShardedConnectionState:
return AutoShardedConnectionState(
dispatch=self.dispatch,
handlers=self._handlers,
hooks=self._hooks,
http=self.http,
loop=self.loop,
**options,
)
@property
def latency(self):
def latency(self) -> float:
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This operates similarly to :meth:`Client.latency` except it uses the average
@ -338,14 +367,14 @@ class AutoShardedClient(Client):
return sum(latency for _, latency in self.latencies) / len(self.__shards)
@property
def latencies(self):
def latencies(self) -> List[Tuple[int, float]]:
"""List[Tuple[:class:`int`, :class:`float`]]: A list of latencies between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This returns a list of tuples with elements ``(shard_id, latency)``.
"""
return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()]
def get_shard(self, shard_id):
def get_shard(self, shard_id: int) -> Optional[ShardInfo]:
"""Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found."""
try:
parent = self.__shards[shard_id]
@ -355,16 +384,16 @@ class AutoShardedClient(Client):
return ShardInfo(parent, self.shard_count)
@property
def shards(self):
def shards(self) -> Dict[int, ShardInfo]:
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
return { shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items() }
return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()}
async def launch_shard(self, gateway, shard_id, *, initial=False):
async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None:
try:
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
ws = await asyncio.wait_for(coro, timeout=180.0)
except Exception:
log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id)
_log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id)
await asyncio.sleep(5.0)
return await self.launch_shard(gateway, shard_id)
@ -372,7 +401,7 @@ class AutoShardedClient(Client):
self.__shards[shard_id] = ret = Shard(ws, self, self.__queue.put_nowait)
ret.launch()
async def launch_shards(self):
async def launch_shards(self) -> None:
if self.shard_count is None:
self.shard_count, gateway = await self.http.get_bot_gateway()
else:
@ -389,7 +418,7 @@ class AutoShardedClient(Client):
self._connection.shards_launched.set()
async def connect(self, *, reconnect=True):
async def connect(self, *, reconnect: bool = True) -> None:
self._reconnect = reconnect
await self.launch_shards()
@ -413,7 +442,7 @@ class AutoShardedClient(Client):
elif item.type == EventType.clean_close:
return
async def close(self):
async def close(self) -> None:
"""|coro|
Closes the connection to Discord.
@ -425,7 +454,7 @@ class AutoShardedClient(Client):
for vc in self.voice_clients:
try:
await vc.disconnect()
await vc.disconnect(force=True)
except Exception:
pass
@ -436,7 +465,13 @@ class AutoShardedClient(Client):
await self.http.close()
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
async def change_presence(self, *, activity=None, status=None, shard_id=None):
async def change_presence(
self,
*,
activity: Optional[BaseActivity] = None,
status: Optional[Status] = None,
shard_id: int = None,
) -> None:
"""|coro|
Changes the client's presence.
@ -468,23 +503,23 @@ class AutoShardedClient(Client):
"""
if status is None:
status = 'online'
status_value = 'online'
status_enum = Status.online
elif status is Status.offline:
status = 'invisible'
status_value = 'invisible'
status_enum = Status.offline
else:
status_enum = status
status = str(status)
status_value = str(status)
if shard_id is None:
for shard in self.__shards.values():
await shard.ws.change_presence(activity=activity, status=status)
await shard.ws.change_presence(activity=activity, status=status_value)
guilds = self._connection.guilds
else:
shard = self.__shards[shard_id]
await shard.ws.change_presence(activity=activity, status=status)
await shard.ws.change_presence(activity=activity, status=status_value)
guilds = [g for g in self._connection.guilds if g.shard_id == shard_id]
activities = () if activity is None else (activity,)
@ -493,10 +528,11 @@ class AutoShardedClient(Client):
if me is None:
continue
me.activities = activities
# Member.activities is typehinted as Tuple[ActivityType, ...], we may be setting it as Tuple[BaseActivity, ...]
me.activities = activities # type: ignore
me.status = status_enum
def is_ws_ratelimited(self):
def is_ws_ratelimited(self) -> bool:
""":class:`bool`: Whether the websocket is currently rate limited.
This can be useful to know when deciding whether you should query members

View File

@ -61,6 +61,10 @@ class StageInstance(Hashable):
Returns the stage instance's hash.
.. describe:: int(x)
Returns the stage instance's ID.
Attributes
-----------
id: :class:`int`
@ -74,7 +78,7 @@ class StageInstance(Hashable):
privacy_level: :class:`StagePrivacyLevel`
The privacy level of the stage instance.
discoverable_disabled: :class:`bool`
Whether the stage instance is discoverable.
Whether discoverability for the stage instance is disabled.
"""
__slots__ = (
@ -97,21 +101,22 @@ class StageInstance(Hashable):
self.id: int = int(data['id'])
self.channel_id: int = int(data['channel_id'])
self.topic: str = data['topic']
self.privacy_level = try_enum(StagePrivacyLevel, data['privacy_level'])
self.discoverable_disabled = data['discoverable_disabled']
self.privacy_level: StagePrivacyLevel = try_enum(StagePrivacyLevel, data['privacy_level'])
self.discoverable_disabled: bool = data.get('discoverable_disabled', False)
def __repr__(self) -> str:
return f'<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>'
@cached_slot_property('_cs_channel')
def channel(self) -> Optional[StageChannel]:
"""Optional[:class:`StageChannel`: The guild that stage instance is running in."""
return self._state.get_channel(self.channel_id)
"""Optional[:class:`StageChannel`]: The channel that stage instance is running in."""
# the returned channel will always be a StageChannel or None
return self._state.get_channel(self.channel_id) # type: ignore
def is_public(self) -> bool:
return self.privacy_level is StagePrivacyLevel.public
async def edit(self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING) -> None:
async def edit(self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None) -> None:
"""|coro|
Edits the stage instance.
@ -125,6 +130,8 @@ class StageInstance(Hashable):
The stage instance's new topic.
privacy_level: :class:`StagePrivacyLevel`
The stage instance's new privacy level.
reason: :class:`str`
The reason the stage instance was edited. Shows up on the audit log.
Raises
------
@ -148,9 +155,9 @@ class StageInstance(Hashable):
payload['privacy_level'] = privacy_level.value
if payload:
await self._state.http.edit_stage_instance(self.channel_id, **payload)
await self._state.http.edit_stage_instance(self.channel_id, **payload, reason=reason)
async def delete(self) -> None:
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|
Deletes the stage instance.
@ -158,6 +165,11 @@ class StageInstance(Hashable):
You must have the :attr:`~Permissions.manage_channels` permission to
use this.
Parameters
-----------
reason: :class:`str`
The reason the stage instance was deleted. Shows up on the audit log.
Raises
------
Forbidden
@ -165,4 +177,4 @@ class StageInstance(Hashable):
HTTPException
Deleting the stage instance failed.
"""
await self._state.http.delete_stage_instance(self.channel_id)
await self._state.http.delete_stage_instance(self.channel_id, reason=reason)

File diff suppressed because it is too large Load Diff

View File

@ -23,24 +23,224 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from typing import Literal, TYPE_CHECKING, List, Optional, Tuple, Type, Union
import unicodedata
from .mixins import Hashable
from .asset import Asset
from .utils import snowflake_time
from .enums import StickerType, try_enum
from .asset import Asset, AssetMixin
from .utils import cached_slot_property, find, snowflake_time, get, MISSING
from .errors import InvalidData
from .enums import StickerType, StickerFormatType, try_enum
__all__ = (
'StickerPack',
'StickerItem',
'Sticker',
'StandardSticker',
'GuildSticker',
)
if TYPE_CHECKING:
import datetime
from .state import ConnectionState
from .types.message import Sticker as StickerPayload
from .user import User
from .guild import Guild
from .types.sticker import (
StickerPack as StickerPackPayload,
StickerItem as StickerItemPayload,
Sticker as StickerPayload,
StandardSticker as StandardStickerPayload,
GuildSticker as GuildStickerPayload,
ListPremiumStickerPacks as ListPremiumStickerPacksPayload,
EditGuildSticker,
)
class Sticker(Hashable):
class StickerPack(Hashable):
"""Represents a sticker pack.
.. versionadded:: 2.0
.. container:: operations
.. describe:: str(x)
Returns the name of the sticker pack.
.. describe:: hash(x)
Returns the hash of the sticker pack.
.. describe:: int(x)
Returns the ID of the sticker pack.
.. describe:: x == y
Checks if the sticker pack is equal to another sticker pack.
.. describe:: x != y
Checks if the sticker pack is not equal to another sticker pack.
Attributes
-----------
name: :class:`str`
The name of the sticker pack.
description: :class:`str`
The description of the sticker pack.
id: :class:`int`
The id of the sticker pack.
stickers: List[:class:`StandardSticker`]
The stickers of this sticker pack.
sku_id: :class:`int`
The SKU ID of the sticker pack.
cover_sticker_id: :class:`int`
The ID of the sticker used for the cover of the sticker pack.
cover_sticker: :class:`StandardSticker`
The sticker used for the cover of the sticker pack.
"""
__slots__ = (
'_state',
'id',
'stickers',
'name',
'sku_id',
'cover_sticker_id',
'cover_sticker',
'description',
'_banner',
)
def __init__(self, *, state: ConnectionState, data: StickerPackPayload) -> None:
self._state: ConnectionState = state
self._from_data(data)
def _from_data(self, data: StickerPackPayload) -> None:
self.id: int = int(data['id'])
stickers = data['stickers']
self.stickers: List[StandardSticker] = [StandardSticker(state=self._state, data=sticker) for sticker in stickers]
self.name: str = data['name']
self.sku_id: int = int(data['sku_id'])
self.cover_sticker_id: int = int(data['cover_sticker_id'])
self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore
self.description: str = data['description']
self._banner: int = int(data['banner_asset_id'])
@property
def banner(self) -> Asset:
""":class:`Asset`: The banner asset of the sticker pack."""
return Asset._from_sticker_banner(self._state, self._banner)
def __repr__(self) -> str:
return f'<StickerPack id={self.id} name={self.name!r} description={self.description!r}>'
def __str__(self) -> str:
return self.name
class _StickerTag(Hashable, AssetMixin):
__slots__ = ()
id: int
format: StickerFormatType
async def read(self) -> bytes:
"""|coro|
Retrieves the content of this sticker as a :class:`bytes` object.
.. note::
Stickers that use the :attr:`StickerFormatType.lottie` format cannot be read.
Raises
------
HTTPException
Downloading the asset failed.
NotFound
The asset was deleted.
TypeError
The sticker is a lottie type.
Returns
-------
:class:`bytes`
The content of the asset.
"""
if self.format is StickerFormatType.lottie:
raise TypeError('Cannot read stickers of format "lottie".')
return await super().read()
class StickerItem(_StickerTag):
"""Represents a sticker item.
.. versionadded:: 2.0
.. container:: operations
.. describe:: str(x)
Returns the name of the sticker item.
.. describe:: x == y
Checks if the sticker item is equal to another sticker item.
.. describe:: x != y
Checks if the sticker item is not equal to another sticker item.
Attributes
-----------
name: :class:`str`
The sticker's name.
id: :class:`int`
The id of the sticker.
format: :class:`StickerFormatType`
The format for the sticker's image.
url: :class:`str`
The URL for the sticker's image.
"""
__slots__ = ('_state', 'name', 'id', 'format', 'url')
def __init__(self, *, state: ConnectionState, data: StickerItemPayload):
self._state: ConnectionState = state
self.name: str = data['name']
self.id: int = int(data['id'])
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type'])
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
def __repr__(self) -> str:
return f'<StickerItem id={self.id} name={self.name!r} format={self.format}>'
def __str__(self) -> str:
return self.name
async def fetch(self) -> Union[Sticker, StandardSticker, GuildSticker]:
"""|coro|
Attempts to retrieve the full sticker data of the sticker item.
Raises
--------
HTTPException
Retrieving the sticker failed.
Returns
--------
Union[:class:`StandardSticker`, :class:`GuildSticker`]
The retrieved sticker.
"""
data: StickerPayload = await self._state.http.get_sticker(self.id)
cls, _ = _sticker_factory(data['type']) # type: ignore
return cls(state=self._state, data=data)
class Sticker(_StickerTag):
"""Represents a sticker.
.. versionadded:: 1.6
@ -69,30 +269,27 @@ class Sticker(Hashable):
The description of the sticker.
pack_id: :class:`int`
The id of the sticker's pack.
format: :class:`StickerType`
format: :class:`StickerFormatType`
The format for the sticker's image.
tags: List[:class:`str`]
A list of tags for the sticker.
url: :class:`str`
The URL for the sticker's image.
"""
__slots__ = ('_state', 'id', 'name', 'description', 'pack_id', 'format', '_image', 'tags')
__slots__ = ('_state', 'id', 'name', 'description', 'format', 'url')
def __init__(self, *, state: ConnectionState, data: StickerPayload):
def __init__(self, *, state: ConnectionState, data: StickerPayload) -> None:
self._state: ConnectionState = state
self._from_data(data)
def _from_data(self, data: StickerPayload) -> None:
self.id: int = int(data['id'])
self.name: str = data['name']
self.description: str = data['description']
self.pack_id: int = int(data.get('pack_id', 0))
self.format: StickerType = try_enum(StickerType, data['format_type'])
self._image: str = data['asset']
try:
self.tags: List[str] = [tag.strip() for tag in data['tags'].split(',')]
except KeyError:
self.tags = []
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type'])
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} name={self.name!r}>'
return f'<Sticker id={self.id} name={self.name!r}>'
def __str__(self) -> str:
return self.name
@ -102,19 +299,233 @@ class Sticker(Hashable):
""":class:`datetime.datetime`: Returns the sticker's creation time in UTC."""
return snowflake_time(self.id)
@property
def image(self) -> Optional[Asset]:
"""Returns an :class:`Asset` for the sticker's image.
.. note::
This will return ``None`` if the format is ``StickerType.lottie``.
class StandardSticker(Sticker):
"""Represents a sticker that is found in a standard sticker pack.
.. versionadded:: 2.0
.. container:: operations
.. describe:: str(x)
Returns the name of the sticker.
.. describe:: x == y
Checks if the sticker is equal to another sticker.
.. describe:: x != y
Checks if the sticker is not equal to another sticker.
Attributes
----------
name: :class:`str`
The sticker's name.
id: :class:`int`
The id of the sticker.
description: :class:`str`
The description of the sticker.
pack_id: :class:`int`
The id of the sticker's pack.
format: :class:`StickerFormatType`
The format for the sticker's image.
tags: List[:class:`str`]
A list of tags for the sticker.
sort_value: :class:`int`
The sticker's sort order within its pack.
"""
__slots__ = ('sort_value', 'pack_id', 'type', 'tags')
def _from_data(self, data: StandardStickerPayload) -> None:
super()._from_data(data)
self.sort_value: int = data['sort_value']
self.pack_id: int = int(data['pack_id'])
self.type: StickerType = StickerType.standard
try:
self.tags: List[str] = [tag.strip() for tag in data['tags'].split(',')]
except KeyError:
self.tags = []
def __repr__(self) -> str:
return f'<StandardSticker id={self.id} name={self.name!r} pack_id={self.pack_id}>'
async def pack(self) -> StickerPack:
"""|coro|
Retrieves the sticker pack that this sticker belongs to.
Raises
--------
InvalidData
The corresponding sticker pack was not found.
HTTPException
Retrieving the sticker pack failed.
Returns
-------
Optional[:class:`Asset`]
The resulting CDN asset.
--------
:class:`StickerPack`
The retrieved sticker pack.
"""
if self.format is StickerType.lottie:
return None
data: ListPremiumStickerPacksPayload = await self._state.http.list_premium_sticker_packs()
packs = data['sticker_packs']
pack = find(lambda d: int(d['id']) == self.pack_id, packs)
return Asset._from_sticker(self._state, self.id, self._image)
if pack:
return StickerPack(state=self._state, data=pack)
raise InvalidData(f'Could not find corresponding sticker pack for {self!r}')
class GuildSticker(Sticker):
"""Represents a sticker that belongs to a guild.
.. versionadded:: 2.0
.. container:: operations
.. describe:: str(x)
Returns the name of the sticker.
.. describe:: x == y
Checks if the sticker is equal to another sticker.
.. describe:: x != y
Checks if the sticker is not equal to another sticker.
Attributes
----------
name: :class:`str`
The sticker's name.
id: :class:`int`
The id of the sticker.
description: :class:`str`
The description of the sticker.
format: :class:`StickerFormatType`
The format for the sticker's image.
available: :class:`bool`
Whether this sticker is available for use.
guild_id: :class:`int`
The ID of the guild that this sticker is from.
user: Optional[:class:`User`]
The user that created this sticker. This can only be retrieved using :meth:`Guild.fetch_sticker` and
having the :attr:`~Permissions.manage_emojis_and_stickers` permission.
emoji: :class:`str`
The name of a unicode emoji that represents this sticker.
"""
__slots__ = ('available', 'guild_id', 'user', 'emoji', 'type', '_cs_guild')
def _from_data(self, data: GuildStickerPayload) -> None:
super()._from_data(data)
self.available: bool = data['available']
self.guild_id: int = int(data['guild_id'])
user = data.get('user')
self.user: Optional[User] = self._state.store_user(user) if user else None
self.emoji: str = data['tags']
self.type: StickerType = StickerType.guild
def __repr__(self) -> str:
return f'<GuildSticker name={self.name!r} id={self.id} guild_id={self.guild_id} user={self.user!r}>'
@cached_slot_property('_cs_guild')
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild that this sticker is from.
Could be ``None`` if the bot is not in the guild.
.. versionadded:: 2.0
"""
return self._state._get_guild(self.guild_id)
async def edit(
self,
*,
name: str = MISSING,
description: str = MISSING,
emoji: str = MISSING,
reason: Optional[str] = None,
) -> GuildSticker:
"""|coro|
Edits a :class:`GuildSticker` for the guild.
Parameters
-----------
name: :class:`str`
The sticker's new name. Must be at least 2 characters.
description: Optional[:class:`str`]
The sticker's new description. Can be ``None``.
emoji: :class:`str`
The name of a unicode emoji that represents the sticker's expression.
reason: :class:`str`
The reason for editing this sticker. Shows up on the audit log.
Raises
-------
Forbidden
You are not allowed to edit stickers.
HTTPException
An error occurred editing the sticker.
Returns
--------
:class:`GuildSticker`
The newly modified sticker.
"""
payload: EditGuildSticker = {}
if name is not MISSING:
payload['name'] = name
if description is not MISSING:
payload['description'] = description
if emoji is not MISSING:
try:
emoji = unicodedata.name(emoji)
except TypeError:
pass
else:
emoji = emoji.replace(' ', '_')
payload['tags'] = emoji
data: GuildStickerPayload = await self._state.http.modify_guild_sticker(self.guild_id, self.id, payload, reason)
return GuildSticker(state=self._state, data=data)
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|
Deletes the custom :class:`Sticker` from the guild.
You must have :attr:`~Permissions.manage_emojis_and_stickers` permission to
do this.
Parameters
-----------
reason: Optional[:class:`str`]
The reason for deleting this sticker. Shows up on the audit log.
Raises
-------
Forbidden
You are not allowed to delete stickers.
HTTPException
An error occurred deleting the sticker.
"""
await self._state.http.delete_guild_sticker(self.guild_id, self.id, reason)
def _sticker_factory(sticker_type: Literal[1, 2]) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]:
value = try_enum(StickerType, sticker_type)
if value == StickerType.standard:
return StandardSticker, value
elif value == StickerType.guild:
return GuildSticker, value
else:
return Sticker, value

View File

@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, Optional, TYPE_CHECKING, overload
from typing import Any, Optional, TYPE_CHECKING
from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data, MISSING
from .enums import VoiceRegion
from .guild import Guild
@ -34,7 +34,10 @@ __all__ = (
)
if TYPE_CHECKING:
import datetime
from .types.template import Template as TemplatePayload
from .state import ConnectionState
from .user import User
class _FriendlyHttpAttributeErrorHelper:
@ -77,7 +80,7 @@ class _PartialTemplateState:
def _get_guild(self, id):
return self.__state._get_guild(id)
async def query_members(self, **kwargs):
async def query_members(self, **kwargs: Any):
return []
def __getattr__(self, attr):
@ -127,33 +130,35 @@ class Template:
'_state',
)
def __init__(self, *, state, data: TemplatePayload):
def __init__(self, *, state: ConnectionState, data: TemplatePayload) -> None:
self._state = state
self._store(data)
def _store(self, data: TemplatePayload):
self.code = data['code']
self.uses = data['usage_count']
self.name = data['name']
self.description = data['description']
def _store(self, data: TemplatePayload) -> None:
self.code: str = data['code']
self.uses: int = data['usage_count']
self.name: str = data['name']
self.description: Optional[str] = data['description']
creator_data = data.get('creator')
self.creator = None if creator_data is None else self._state.store_user(creator_data)
self.creator: Optional[User] = None if creator_data is None else self._state.create_user(creator_data)
self.created_at = parse_time(data.get('created_at'))
self.updated_at = parse_time(data.get('updated_at'))
self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at'))
self.updated_at: Optional[datetime.datetime] = parse_time(data.get('updated_at'))
id = _get_as_snowflake(data, 'source_guild_id')
guild_id = int(data['source_guild_id'])
guild: Optional[Guild] = self._state._get_guild(guild_id)
guild = self._state._get_guild(id)
if guild is None and id:
self.source_guild: Guild
if guild is None:
source_serialised = data['serialized_source_guild']
source_serialised['id'] = id
source_serialised['id'] = guild_id
state = _PartialTemplateState(state=self._state)
guild = Guild(data=source_serialised, state=state) # type: ignore
# Guild expects a ConnectionState, we're passing a _PartialTemplateState
self.source_guild = Guild(data=source_serialised, state=state) # type: ignore
else:
self.source_guild = guild
self.source_guild = guild
self.is_dirty = data.get('is_dirty', None)
self.is_dirty: Optional[bool] = data.get('is_dirty', None)
def __repr__(self) -> str:
return (
@ -161,7 +166,7 @@ class Template:
f' creator={self.creator!r} source_guild={self.source_guild!r} is_dirty={self.is_dirty}>'
)
async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None):
async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None) -> Guild:
"""|coro|
Creates a :class:`.Guild` using the template.
@ -201,7 +206,7 @@ class Template:
data = await self._state.http.create_from_template(self.code, name, region_value, icon)
return Guild(data=data, state=self._state)
async def sync(self) -> None:
async def sync(self) -> Template:
"""|coro|
Sync the template to the guild's current state.
@ -211,6 +216,9 @@ class Template:
.. versionadded:: 1.7
.. versionchanged:: 2.0
The template is no longer edited in-place, instead it is returned.
Raises
-------
HTTPException
@ -219,17 +227,22 @@ class Template:
You don't have permissions to edit the template.
NotFound
This template does not exist.
Returns
--------
:class:`Template`
The newly edited template.
"""
data = await self._state.http.sync_template(self.source_guild.id, self.code)
self._store(data)
return Template(state=self._state, data=data)
async def edit(
self,
*,
name: str = MISSING,
description: Optional[str] = MISSING,
) -> None:
) -> Template:
"""|coro|
Edit the template metadata.
@ -239,6 +252,9 @@ class Template:
.. versionadded:: 1.7
.. versionchanged:: 2.0
The template is no longer edited in-place, instead it is returned.
Parameters
------------
name: :class:`str`
@ -254,6 +270,11 @@ class Template:
You don't have permissions to edit the template.
NotFound
This template does not exist.
Returns
--------
:class:`Template`
The newly edited template.
"""
payload = {}
@ -263,7 +284,7 @@ class Template:
payload['description'] = description
data = await self._state.http.edit_template(self.source_guild.id, self.code, payload)
self._store(data)
return Template(state=self._state, data=data)
async def delete(self) -> None:
"""|coro|

View File

@ -46,8 +46,9 @@ if TYPE_CHECKING:
ThreadMetadata,
ThreadArchiveDuration,
)
from .types.snowflake import SnowflakeList
from .guild import Guild
from .channel import TextChannel
from .channel import TextChannel, CategoryChannel
from .member import Member
from .message import Message, PartialMessage
from .abc import Snowflake, SnowflakeTime
@ -73,6 +74,10 @@ class Thread(Messageable, Hashable):
Returns the thread's hash.
.. describe:: int(x)
Returns the thread's ID.
.. describe:: str(x)
Returns the thread's name.
@ -110,6 +115,9 @@ class Thread(Messageable, Hashable):
Whether the thread is archived.
locked: :class:`bool`
Whether the thread is locked.
invitable: :class:`bool`
Whether non-moderators can add other non-moderators to this thread.
This is always ``True`` for public threads.
archiver_id: Optional[:class:`int`]
The user's ID that archived this thread.
auto_archive_duration: :class:`int`
@ -135,6 +143,7 @@ class Thread(Messageable, Hashable):
'me',
'locked',
'archived',
'invitable',
'archiver_id',
'auto_archive_duration',
'archive_timestamp',
@ -166,6 +175,8 @@ class Thread(Messageable, Hashable):
self._type = try_enum(ChannelType, data['type'])
self.last_message_id = _get_as_snowflake(data, 'last_message_id')
self.slowmode_delay = data.get('rate_limit_per_user', 0)
self.message_count = data['message_count']
self.member_count = data['member_count']
self._unroll_metadata(data['thread_metadata'])
try:
@ -181,6 +192,7 @@ class Thread(Messageable, Hashable):
self.auto_archive_duration = data['auto_archive_duration']
self.archive_timestamp = parse_time(data['archive_timestamp'])
self.locked = data.get('locked', False)
self.invitable = data.get('invitable', True)
def _update(self, data):
try:
@ -215,6 +227,16 @@ class Thread(Messageable, Hashable):
""":class:`str`: The string that allows you to mention the thread."""
return f'<#{self.id}>'
@property
def members(self) -> List[ThreadMember]:
"""List[:class:`ThreadMember`]: A list of thread members in this thread.
This requires :attr:`Intents.members` to be properly filled. Most of the time however,
this data is not provided by the gateway and a call to :meth:`fetch_members` is
needed.
"""
return list(self._members.values())
@property
def last_message(self) -> Optional[Message]:
"""Fetches the last message from this channel in cache.
@ -236,6 +258,26 @@ class Thread(Messageable, Hashable):
"""
return self._state._get_message(self.last_message_id) if self.last_message_id else None
@property
def category(self) -> Optional[CategoryChannel]:
"""The category channel the parent channel belongs to, if applicable.
Raises
-------
ClientException
The parent channel was not cached and returned ``None``.
Returns
-------
Optional[:class:`CategoryChannel`]
The parent channel's category.
"""
parent = self.parent
if parent is None:
raise ClientException('Parent channel not found')
return parent.category
@property
def category_id(self) -> Optional[int]:
"""The category channel ID the parent channel belongs to, if applicable.
@ -362,7 +404,7 @@ class Thread(Messageable, Hashable):
if len(messages) > 100:
raise ClientException('Can only bulk delete messages up to 100 messages')
message_ids = [m.id for m in messages]
message_ids: SnowflakeList = [m.id for m in messages]
await self._state.http.delete_messages(self.id, message_ids)
async def purge(
@ -488,9 +530,10 @@ class Thread(Messageable, Hashable):
name: str = MISSING,
archived: bool = MISSING,
locked: bool = MISSING,
invitable: bool = MISSING,
slowmode_delay: int = MISSING,
auto_archive_duration: ThreadArchiveDuration = MISSING,
):
) -> Thread:
"""|coro|
Edits the thread.
@ -510,8 +553,11 @@ class Thread(Messageable, Hashable):
Whether to archive the thread or not.
locked: :class:`bool`
Whether to lock the thread or not.
invitable: :class:`bool`
Whether non-moderators can add other non-moderators to this thread.
Only available for private threads.
auto_archive_duration: :class:`int`
The new duration to auto archive threads for inactivity.
The new duration in minutes before a thread is automatically archived for inactivity.
Must be one of ``60``, ``1440``, ``4320``, or ``10080``.
slowmode_delay: :class:`int`
Specifies the slowmode rate limit for user in this thread, in seconds.
@ -523,6 +569,11 @@ class Thread(Messageable, Hashable):
You do not have permissions to edit the thread.
HTTPException
Editing the thread failed.
Returns
--------
:class:`Thread`
The newly edited thread.
"""
payload = {}
if name is not MISSING:
@ -533,20 +584,22 @@ class Thread(Messageable, Hashable):
payload['auto_archive_duration'] = auto_archive_duration
if locked is not MISSING:
payload['locked'] = locked
if invitable is not MISSING:
payload['invitable'] = invitable
if slowmode_delay is not MISSING:
payload['rate_limit_per_user'] = slowmode_delay
await self._state.http.edit_channel(self.id, **payload)
data = await self._state.http.edit_channel(self.id, **payload)
# The data payload will always be a Thread payload
return Thread(data=data, state=self._state, guild=self.guild) # type: ignore
async def join(self):
"""|coro|
Joins this thread.
You must have :attr:`~Permissions.send_messages` and :attr:`~Permissions.use_threads`
to join a public thread. If the thread is private then :attr:`~Permissions.send_messages`
and either :attr:`~Permissions.use_private_threads` or :attr:`~Permissions.manage_messages`
is required to join the thread.
You must have :attr:`~Permissions.send_messages_in_threads` to join a thread.
If the thread is private, :attr:`~Permissions.manage_threads` is also needed.
Raises
-------
@ -614,6 +667,28 @@ class Thread(Messageable, Hashable):
"""
await self._state.http.remove_user_from_thread(self.id, user.id)
async def fetch_members(self) -> List[ThreadMember]:
"""|coro|
Retrieves all :class:`ThreadMember` that are in this thread.
This requires :attr:`Intents.members` to get information about members
other than yourself.
Raises
-------
HTTPException
Retrieving the members failed.
Returns
--------
List[:class:`ThreadMember`]
All thread members in the thread.
"""
members = await self._state.http.get_thread_members(self.id)
return [ThreadMember(parent=self, data=data) for data in members]
async def delete(self):
"""|coro|
@ -677,6 +752,10 @@ class ThreadMember(Hashable):
Returns the thread member's hash.
.. describe:: int(x)
Returns the thread member's ID.
.. describe:: str(x)
Returns the thread member's name.

View File

@ -61,6 +61,7 @@ class _PartialAppInfoOptional(TypedDict, total=False):
terms_of_service_url: str
privacy_policy_url: str
max_participants: int
flags: int
class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo):
pass

View File

@ -32,6 +32,7 @@ from .user import User
from .snowflake import Snowflake
from .role import Role
from .channel import ChannelType, VideoQualityMode, PermissionOverwrite
from .threads import Thread
AuditLogEvent = Literal[
1,
@ -69,19 +70,28 @@ AuditLogEvent = Literal[
80,
81,
82,
83,
84,
85,
90,
91,
92,
110,
111,
112,
]
class _AuditLogChange_Str(TypedDict):
key: Literal[
'name', 'description', 'preferred_locale', 'vanity_url_code', 'topic', 'code', 'allow', 'deny', 'permissions'
'name', 'description', 'preferred_locale', 'vanity_url_code', 'topic', 'code', 'allow', 'deny', 'permissions', 'tags'
]
new_value: str
old_value: str
class _AuditLogChange_AssetHash(TypedDict):
key: Literal['icon_hash', 'splash_hash', 'discovery_splash_hash', 'banner_hash', 'avatar_hash']
key: Literal['icon_hash', 'splash_hash', 'discovery_splash_hash', 'banner_hash', 'avatar_hash', 'asset']
new_value: str
old_value: str
@ -98,6 +108,7 @@ class _AuditLogChange_Snowflake(TypedDict):
'application_id',
'channel_id',
'inviter_id',
'guild_id',
]
new_value: Snowflake
old_value: Snowflake
@ -116,6 +127,9 @@ class _AuditLogChange_Bool(TypedDict):
'enabled_emoticons',
'region',
'rtc_region',
'available',
'archived',
'locked',
]
new_value: bool
old_value: bool
@ -132,6 +146,8 @@ class _AuditLogChange_Int(TypedDict):
'max_uses',
'max_age',
'user_limit',
'auto_archive_duration',
'default_auto_archive_duration',
]
new_value: int
old_value: int
@ -238,3 +254,4 @@ class AuditLog(TypedDict):
users: List[User]
audit_log_entries: List[AuditLogEntry]
integrations: List[PartialIntegration]
threads: List[Thread]

View File

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from typing import List, Literal, Optional, TypedDict, Union
from .user import PartialUser
from .snowflake import Snowflake
from .threads import ThreadMetadata, ThreadMember
from .threads import ThreadMetadata, ThreadMember, ThreadArchiveDuration
OverwriteType = Literal[0, 1]
@ -63,6 +63,7 @@ class _TextChannelOptional(TypedDict, total=False):
last_message_id: Optional[Snowflake]
last_pin_timestamp: str
rate_limit_per_user: int
default_auto_archive_duration: ThreadArchiveDuration
class TextChannel(_BaseGuildChannel, _TextChannelOptional):
@ -115,7 +116,7 @@ class _ThreadChannelOptional(TypedDict, total=False):
class ThreadChannel(_BaseChannel, _ThreadChannelOptional):
type: Literal[11, 12]
type: Literal[10, 11, 12]
guild_id: Snowflake
parent_id: Snowflake
owner_id: Snowflake

View File

@ -53,6 +53,7 @@ class _SelectMenuOptional(TypedDict, total=False):
placeholder: str
min_values: int
max_values: int
disabled: bool
class _SelectOptionsOptional(TypedDict, total=False):

View File

@ -37,8 +37,11 @@ if TYPE_CHECKING:
from .message import AllowedMentions, Message
ApplicationCommandType = Literal[1, 2, 3]
class _ApplicationCommandOptional(TypedDict, total=False):
options: List[ApplicationCommandOption]
type: ApplicationCommandType
class ApplicationCommand(_ApplicationCommandOptional):
@ -53,7 +56,7 @@ class _ApplicationCommandOptionOptional(TypedDict, total=False):
options: List[ApplicationCommandOption]
ApplicationCommandOptionType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9]
ApplicationCommandOptionType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
class ApplicationCommandOption(_ApplicationCommandOptionOptional):
@ -122,12 +125,18 @@ class _ApplicationCommandInteractionDataOptionSnowflake(_ApplicationCommandInter
value: Snowflake
class _ApplicationCommandInteractionDataOptionNumber(_ApplicationCommandInteractionDataOption):
type: Literal[10]
value: float
ApplicationCommandInteractionDataOption = Union[
_ApplicationCommandInteractionDataOptionString,
_ApplicationCommandInteractionDataOptionInteger,
_ApplicationCommandInteractionDataOptionSubcommand,
_ApplicationCommandInteractionDataOptionBoolean,
_ApplicationCommandInteractionDataOptionSnowflake,
_ApplicationCommandInteractionDataOptionNumber,
]
@ -148,6 +157,8 @@ class ApplicationCommandInteractionDataResolved(TypedDict, total=False):
class _ApplicationCommandInteractionDataOptional(TypedDict, total=False):
options: List[ApplicationCommandInteractionDataOption]
resolved: ApplicationCommandInteractionDataResolved
target_id: Snowflake
type: ApplicationCommandType
class ApplicationCommandInteractionData(_ApplicationCommandInteractionDataOptional):
@ -211,8 +222,15 @@ class MessageInteraction(TypedDict):
user: User
class EditApplicationCommand(TypedDict):
name: str
class _EditApplicationCommandOptional(TypedDict, total=False):
description: str
options: Optional[List[ApplicationCommandOption]]
type: ApplicationCommandType
class EditApplicationCommand(_EditApplicationCommandOptional):
name: str
default_permission: bool

View File

@ -39,6 +39,7 @@ class PartialMember(TypedDict):
class Member(PartialMember, total=False):
avatar: str
user: User
nick: str
premium_since: str
@ -46,16 +47,17 @@ class Member(PartialMember, total=False):
permissions: str
class _OptionalGatewayMember(PartialMember, total=False):
class _OptionalMemberWithUser(PartialMember, total=False):
avatar: str
nick: str
premium_since: str
pending: bool
permissions: str
class GatewayMember(_OptionalGatewayMember):
class MemberWithUser(_OptionalMemberWithUser):
user: User
class UserWithMember(User, total=False):
member: _OptionalGatewayMember
member: _OptionalMemberWithUser

View File

@ -33,6 +33,7 @@ from .embed import Embed
from .channel import ChannelType
from .components import Component
from .interactions import MessageInteraction
from .sticker import StickerItem
class ChannelMention(TypedDict):
@ -89,22 +90,6 @@ class MessageReference(TypedDict, total=False):
fail_if_not_exists: bool
class _StickerOptional(TypedDict, total=False):
tags: str
StickerFormatType = Literal[1, 2, 3]
class Sticker(_StickerOptional):
id: Snowflake
pack_id: Snowflake
name: str
description: str
asset: str
format_type: StickerFormatType
class _MessageOptional(TypedDict, total=False):
guild_id: Snowflake
member: Member
@ -117,7 +102,7 @@ class _MessageOptional(TypedDict, total=False):
application_id: Snowflake
message_reference: MessageReference
flags: int
stickers: List[Sticker]
sticker_items: List[StickerItem]
referenced_message: Optional[Message]
interaction: MessageInteraction
components: List[Component]

View File

@ -0,0 +1,87 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import TypedDict, List
from .snowflake import Snowflake
from .member import Member
from .emoji import PartialEmoji
class _MessageEventOptional(TypedDict, total=False):
guild_id: Snowflake
class MessageDeleteEvent(_MessageEventOptional):
id: Snowflake
channel_id: Snowflake
class BulkMessageDeleteEvent(_MessageEventOptional):
ids: List[Snowflake]
channel_id: Snowflake
class _ReactionActionEventOptional(TypedDict, total=False):
guild_id: Snowflake
member: Member
class MessageUpdateEvent(_MessageEventOptional):
id: Snowflake
channel_id: Snowflake
class ReactionActionEvent(_ReactionActionEventOptional):
user_id: Snowflake
channel_id: Snowflake
message_id: Snowflake
emoji: PartialEmoji
class _ReactionClearEventOptional(TypedDict, total=False):
guild_id: Snowflake
class ReactionClearEvent(_ReactionClearEventOptional):
channel_id: Snowflake
message_id: Snowflake
class _ReactionClearEmojiEventOptional(TypedDict, total=False):
guild_id: Snowflake
class ReactionClearEmojiEvent(_ReactionClearEmojiEventOptional):
channel_id: int
message_id: int
emoji: PartialEmoji
class _IntegrationDeleteEventOptional(TypedDict, total=False):
application_id: Snowflake
class IntegrationDeleteEvent(_IntegrationDeleteEventOptional):
id: Snowflake
guild_id: Snowflake

93
discord/types/sticker.py Normal file
View File

@ -0,0 +1,93 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Literal, TypedDict, Union
from .snowflake import Snowflake
from .user import User
StickerFormatType = Literal[1, 2, 3]
class StickerItem(TypedDict):
id: Snowflake
name: str
format_type: StickerFormatType
class BaseSticker(TypedDict):
id: Snowflake
name: str
description: str
tags: str
format_type: StickerFormatType
class StandardSticker(BaseSticker):
type: Literal[1]
sort_value: int
pack_id: Snowflake
class _GuildStickerOptional(TypedDict, total=False):
user: User
class GuildSticker(BaseSticker, _GuildStickerOptional):
type: Literal[2]
available: bool
guild_id: Snowflake
Sticker = Union[BaseSticker, StandardSticker, GuildSticker]
class StickerPack(TypedDict):
id: Snowflake
stickers: List[StandardSticker]
name: str
sku_id: Snowflake
cover_sticker_id: Snowflake
description: str
banner_asset_id: Snowflake
class _CreateGuildStickerOptional(TypedDict, total=False):
description: str
class CreateGuildSticker(_CreateGuildStickerOptional):
name: str
tags: str
class EditGuildSticker(TypedDict, total=False):
name: str
tags: str
description: str
class ListPremiumStickerPacks(TypedDict):
sticker_packs: List[StickerPack]

View File

@ -41,6 +41,7 @@ class ThreadMember(TypedDict):
class _ThreadMetadataOptional(TypedDict, total=False):
archiver_id: Snowflake
locked: bool
invitable: bool
class ThreadMetadata(_ThreadMetadataOptional):

View File

@ -24,14 +24,14 @@ DEALINGS IN THE SOFTWARE.
from typing import Optional, TypedDict, List, Literal
from .snowflake import Snowflake
from .member import Member
from .member import MemberWithUser
SupportedModes = Literal['xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305']
class _PartialVoiceStateOptional(TypedDict, total=False):
member: Member
member: MemberWithUser
self_stream: bool

View File

@ -78,6 +78,8 @@ class Select(Item[V]):
Defaults to 1 and must be between 1 and 25.
options: List[:class:`discord.SelectOption`]
A list of options that can be selected in this menu.
disabled: :class:`bool`
Whether the select is disabled or not.
row: Optional[:class:`int`]
The relative row this select menu belongs to. A Discord component can only have 5
rows. By default, items are arranged automatically into those 5 rows. If you'd
@ -91,6 +93,7 @@ class Select(Item[V]):
'min_values',
'max_values',
'options',
'disabled',
)
def __init__(
@ -101,8 +104,10 @@ class Select(Item[V]):
min_values: int = 1,
max_values: int = 1,
options: List[SelectOption] = MISSING,
disabled: bool = False,
row: Optional[int] = None,
) -> None:
super().__init__()
self._selected_values: List[str] = []
self._provided_custom_id = custom_id is not MISSING
custom_id = os.urandom(16).hex() if custom_id is MISSING else custom_id
@ -114,6 +119,7 @@ class Select(Item[V]):
min_values=min_values,
max_values=max_values,
options=options,
disabled=disabled,
)
self.row = row
@ -191,13 +197,13 @@ class Select(Item[V]):
-----------
label: :class:`str`
The label of the option. This is displayed to users.
Can only be up to 25 characters.
Can only be up to 100 characters.
value: :class:`str`
The value of the option. This is not displayed to users.
If not given, defaults to the label. Can only be up to 100 characters.
description: Optional[:class:`str`]
An additional description of the option, if any.
Can only be up to 50 characters.
Can only be up to 100 characters.
emoji: Optional[Union[:class:`str`, :class:`.Emoji`, :class:`.PartialEmoji`]]
The emoji of the option, if available. This can either be a string representing
the custom or unicode emoji or an instance of :class:`.PartialEmoji` or :class:`.Emoji`.
@ -240,6 +246,15 @@ class Select(Item[V]):
self._underlying.options.append(option)
@property
def disabled(self) -> bool:
""":class:`bool`: Whether the select is disabled or not."""
return self._underlying.disabled
@disabled.setter
def disabled(self, value: bool):
self._underlying.disabled = bool(value)
@property
def values(self) -> List[str]:
"""List[:class:`str`]: A list of values that have been selected by the user."""
@ -267,6 +282,7 @@ class Select(Item[V]):
min_values=component.min_values,
max_values=component.max_values,
options=component.options,
disabled=component.disabled,
row=None,
)
@ -285,6 +301,7 @@ def select(
min_values: int = 1,
max_values: int = 1,
options: List[SelectOption] = MISSING,
disabled: bool = False,
row: Optional[int] = None,
) -> Callable[[ItemCallbackType], ItemCallbackType]:
"""A decorator that attaches a select menu to a component.
@ -317,11 +334,13 @@ def select(
Defaults to 1 and must be between 1 and 25.
options: List[:class:`discord.SelectOption`]
A list of options that can be selected in this menu.
disabled: :class:`bool`
Whether the select is disabled or not. Defaults to ``False``.
"""
def decorator(func: ItemCallbackType) -> ItemCallbackType:
if not inspect.iscoroutinefunction(func):
raise TypeError('button function must be a coroutine function')
raise TypeError('select function must be a coroutine function')
func.__discord_ui_model_type__ = Select
func.__discord_ui_model_kwargs__ = {
@ -331,6 +350,7 @@ def select(
'min_values': min_values,
'max_values': max_values,
'options': options,
'disabled': disabled,
}
return func

View File

@ -456,8 +456,8 @@ class View:
class ViewStore:
def __init__(self, state: ConnectionState):
# (component_type, custom_id): (View, Item)
self._views: Dict[Tuple[int, str], Tuple[View, Item]] = {}
# (component_type, message_id, custom_id): (View, Item)
self._views: Dict[Tuple[int, Optional[int], str], Tuple[View, Item]] = {}
# message_id: View
self._synced_message_views: Dict[int, View] = {}
self._state: ConnectionState = state
@ -474,8 +474,7 @@ class ViewStore:
return list(views.values())
def __verify_integrity(self):
to_remove: List[Tuple[int, str]] = []
now = time.monotonic()
to_remove: List[Tuple[int, Optional[int], str]] = []
for (k, (view, _)) in self._views.items():
if view.is_finished():
to_remove.append(k)
@ -489,7 +488,7 @@ class ViewStore:
view._start_listening_from_store(self)
for item in view.children:
if item.is_dispatchable():
self._views[(item.type.value, item.custom_id)] = (view, item) # type: ignore
self._views[(item.type.value, message_id, item.custom_id)] = (view, item) # type: ignore
if message_id is not None:
self._synced_message_views[message_id] = view
@ -506,8 +505,11 @@ class ViewStore:
def dispatch(self, component_type: int, custom_id: str, interaction: Interaction):
self.__verify_integrity()
key = (component_type, custom_id)
value = self._views.get(key)
message_id: Optional[int] = interaction.message and interaction.message.id
key = (component_type, message_id, custom_id)
# Fallback to None message_id searches in case a persistent view
# was added without an associated message_id
value = self._views.get(key) or self._views.get((component_type, None, custom_id))
if value is None:
return

View File

@ -22,19 +22,35 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import Any, Dict, Optional, TYPE_CHECKING
from __future__ import annotations
from typing import Any, Dict, List, Optional, Type, TypeVar, TYPE_CHECKING
import discord.abc
from .asset import Asset
from .colour import Colour
from .enums import DefaultAvatar
from .flags import PublicUserFlags
from .utils import snowflake_time, _bytes_to_base64_data, MISSING
from .enums import DefaultAvatar
from .colour import Colour
from .asset import Asset
if TYPE_CHECKING:
from datetime import datetime
from .channel import DMChannel
from .guild import Guild
from .message import Message
from .state import ConnectionState
from .types.channel import DMChannel as DMChannelPayload
from .types.user import User as UserPayload
__all__ = (
'User',
'ClientUser',
)
BU = TypeVar('BU', bound='BaseUser')
class _UserTag:
__slots__ = ()
@ -42,7 +58,18 @@ class _UserTag:
class BaseUser(_UserTag):
__slots__ = ('name', 'id', 'discriminator', '_avatar', 'bot', 'system', '_public_flags', '_state')
__slots__ = (
'name',
'id',
'discriminator',
'_avatar',
'_banner',
'_accent_colour',
'bot',
'system',
'_public_flags',
'_state',
)
if TYPE_CHECKING:
name: str
@ -50,53 +77,65 @@ class BaseUser(_UserTag):
discriminator: str
bot: bool
system: bool
_state: ConnectionState
_avatar: Optional[str]
_banner: Optional[str]
_accent_colour: Optional[str]
_public_flags: int
def __init__(self, *, state, data):
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
self._state = state
self._update(data)
def __repr__(self):
def __repr__(self) -> str:
return (
f"<BaseUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}"
f" bot={self.bot} system={self.system}>"
)
def __str__(self):
def __str__(self) -> str:
return f'{self.name}#{self.discriminator}'
def __eq__(self, other):
def __int__(self) -> int:
return self.id
def __eq__(self, other: Any) -> bool:
return isinstance(other, _UserTag) and other.id == self.id
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
def __hash__(self):
def __hash__(self) -> int:
return self.id >> 22
def _update(self, data):
def _update(self, data: UserPayload) -> None:
self.name = data['username']
self.id = int(data['id'])
self.discriminator = data['discriminator']
self._avatar = data['avatar']
self._banner = data.get('banner', None)
self._accent_colour = data.get('accent_color', None)
self._public_flags = data.get('public_flags', 0)
self.bot = data.get('bot', False)
self.system = data.get('system', False)
@classmethod
def _copy(cls, user):
def _copy(cls: Type[BU], user: BU) -> BU:
self = cls.__new__(cls) # bypass __init__
self.name = user.name
self.id = user.id
self.discriminator = user.discriminator
self._avatar = user._avatar
self._banner = user._banner
self._accent_colour = user._accent_colour
self.bot = user.bot
self._state = user._state
self._public_flags = user._public_flags
return self
def _to_minimal_user_json(self):
def _to_minimal_user_json(self) -> Dict[str, Any]:
return {
'username': self.name,
'id': self.id,
@ -106,29 +145,82 @@ class BaseUser(_UserTag):
}
@property
def public_flags(self):
def public_flags(self) -> PublicUserFlags:
""":class:`PublicUserFlags`: The publicly available flags the user has."""
return PublicUserFlags._from_value(self._public_flags)
@property
def avatar(self):
""":class:`Asset`: Returns an :class:`Asset` for the avatar the user has.
def avatar(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns an :class:`Asset` for the avatar the user has.
If the user does not have a traditional avatar, an asset for
the default avatar is returned instead.
If the user does not have a traditional avatar, ``None`` is returned.
If you want the avatar that a user has displayed, consider :attr:`display_avatar`.
"""
if self._avatar is None:
return Asset._from_default_avatar(self._state, int(self.discriminator) % len(DefaultAvatar))
else:
if self._avatar is not None:
return Asset._from_avatar(self._state, self.id, self._avatar)
return None
@property
def default_avatar(self):
def default_avatar(self) -> Asset:
""":class:`Asset`: Returns the default avatar for a given user. This is calculated by the user's discriminator."""
return Asset._from_default_avatar(self._state, int(self.discriminator) % len(DefaultAvatar))
@property
def colour(self):
def display_avatar(self) -> Asset:
""":class:`Asset`: Returns the user's display avatar.
For regular users this is just their default avatar or uploaded avatar.
.. versionadded:: 2.0
"""
return self.avatar or self.default_avatar
@property
def banner(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns the user's banner asset, if available.
.. versionadded:: 2.0
.. note::
This information is only available via :meth:`Client.fetch_user`.
"""
if self._banner is None:
return None
return Asset._from_user_banner(self._state, self.id, self._banner)
@property
def accent_colour(self) -> Optional[Colour]:
"""Optional[:class:`Colour`]: Returns the user's accent colour, if applicable.
There is an alias for this named :attr:`accent_color`.
.. versionadded:: 2.0
.. note::
This information is only available via :meth:`Client.fetch_user`.
"""
if self._accent_colour is None:
return None
return Colour(self._accent_colour)
@property
def accent_color(self) -> Optional[Colour]:
"""Optional[:class:`Colour`]: Returns the user's accent color, if applicable.
There is an alias for this named :attr:`accent_colour`.
.. versionadded:: 2.0
.. note::
This information is only available via :meth:`Client.fetch_user`.
"""
return self.accent_colour
@property
def colour(self) -> Colour:
""":class:`Colour`: A property that returns a colour denoting the rendered colour
for the user. This always returns :meth:`Colour.default`.
@ -137,7 +229,7 @@ class BaseUser(_UserTag):
return Colour.default()
@property
def color(self):
def color(self) -> Colour:
""":class:`Colour`: A property that returns a color denoting the rendered color
for the user. This always returns :meth:`Colour.default`.
@ -146,12 +238,12 @@ class BaseUser(_UserTag):
return self.colour
@property
def mention(self):
def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention the given user."""
return f'<@{self.id}>'
@property
def created_at(self):
def created_at(self) -> datetime:
""":class:`datetime.datetime`: Returns the user's creation time in UTC.
This is when the user's Discord account was created.
@ -159,7 +251,7 @@ class BaseUser(_UserTag):
return snowflake_time(self.id)
@property
def display_name(self):
def display_name(self) -> str:
""":class:`str`: Returns the user's display name.
For regular users this is just their username, but
@ -168,7 +260,7 @@ class BaseUser(_UserTag):
"""
return self.name
def mentioned_in(self, message):
def mentioned_in(self, message: Message) -> bool:
"""Checks if the user is mentioned in the specified message.
Parameters
@ -234,16 +326,22 @@ class ClientUser(BaseUser):
__slots__ = ('locale', '_flags', 'verified', 'mfa_enabled', '__weakref__')
def __init__(self, *, state, data):
if TYPE_CHECKING:
verified: bool
locale: Optional[str]
mfa_enabled: bool
_flags: int
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
super().__init__(state=state, data=data)
def __repr__(self):
def __repr__(self) -> str:
return (
f'<ClientUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}'
f' bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>'
)
def _update(self, data):
def _update(self, data: UserPayload) -> None:
super()._update(data)
# There's actually an Optional[str] phone field as well but I won't use it
self.verified = data.get('verified', False)
@ -251,7 +349,7 @@ class ClientUser(BaseUser):
self._flags = data.get('flags', 0)
self.mfa_enabled = data.get('mfa_enabled', False)
async def edit(self, *, username: str = MISSING, avatar: bytes = MISSING) -> None:
async def edit(self, *, username: str = MISSING, avatar: bytes = MISSING) -> ClientUser:
"""|coro|
Edits the current profile of the client.
@ -265,6 +363,9 @@ class ClientUser(BaseUser):
The only image formats supported for uploading is JPEG and PNG.
.. versionchanged:: 2.0
The edit is no longer in-place, instead the newly edited client user is returned.
Parameters
-----------
username: :class:`str`
@ -279,6 +380,11 @@ class ClientUser(BaseUser):
Editing your profile failed.
InvalidArgument
Wrong image format passed for ``avatar``.
Returns
---------
:class:`ClientUser`
The newly edited client user.
"""
payload: Dict[str, Any] = {}
if username is not MISSING:
@ -287,8 +393,8 @@ class ClientUser(BaseUser):
if avatar is not MISSING:
payload['avatar'] = _bytes_to_base64_data(avatar)
data = await self._state.http.edit_profile(payload)
self._update(data)
data: UserPayload = await self._state.http.edit_profile(payload)
return ClientUser(state=self._state, data=data)
class User(BaseUser, discord.abc.Messageable):
@ -312,6 +418,10 @@ class User(BaseUser, discord.abc.Messageable):
Returns the user's name with discriminator.
.. describe:: int(x)
Returns the user's ID.
Attributes
-----------
name: :class:`str`
@ -328,11 +438,11 @@ class User(BaseUser, discord.abc.Messageable):
__slots__ = ('_stored',)
def __init__(self, *, state, data):
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
super().__init__(state=state, data=data)
self._stored = False
self._stored: bool = False
def __repr__(self):
def __repr__(self) -> str:
return f'<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>'
def __del__(self) -> None:
@ -343,17 +453,17 @@ class User(BaseUser, discord.abc.Messageable):
pass
@classmethod
def _copy(cls, user):
def _copy(cls, user: User):
self = super()._copy(user)
self._stored = getattr(user, '_stored', False)
self._stored = False
return self
async def _get_channel(self):
async def _get_channel(self) -> DMChannel:
ch = await self.create_dm()
return ch
@property
def dm_channel(self):
def dm_channel(self) -> Optional[DMChannel]:
"""Optional[:class:`DMChannel`]: Returns the channel associated with this user if it exists.
If this returns ``None``, you can create a DM channel by calling the
@ -362,7 +472,7 @@ class User(BaseUser, discord.abc.Messageable):
return self._state._get_private_channel_by_user(self.id)
@property
def mutual_guilds(self):
def mutual_guilds(self) -> List[Guild]:
"""List[:class:`Guild`]: The guilds that the user shares with the client.
.. note::
@ -373,7 +483,7 @@ class User(BaseUser, discord.abc.Messageable):
"""
return [guild for guild in self._state._guilds.values() if guild.get_member(self.id)]
async def create_dm(self):
async def create_dm(self) -> DMChannel:
"""|coro|
Creates a :class:`DMChannel` with this user.
@ -391,5 +501,5 @@ class User(BaseUser, discord.abc.Messageable):
return found
state = self._state
data = await state.http.start_private_message(self.id)
data: DMChannelPayload = await state.http.start_private_message(self.id)
return state.add_dm_channel(data)

View File

@ -120,6 +120,9 @@ class _cached_property:
if TYPE_CHECKING:
from functools import cached_property as cached_property
from typing_extensions import ParamSpec
from .permissions import Permissions
from .abc import Snowflake
from .invite import Invite
@ -129,6 +132,8 @@ if TYPE_CHECKING:
headers: Mapping[str, Any]
P = ParamSpec('P')
else:
cached_property = _cached_property
@ -231,8 +236,8 @@ def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]:
return None
def copy_doc(original: Callable[..., Any]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(overriden: Callable[..., Any]) -> Callable[..., Any]:
def copy_doc(original: Callable) -> Callable[[T], T]:
def decorator(overriden: T) -> T:
overriden.__doc__ = original.__doc__
overriden.__signature__ = _signature(original) # type: ignore
return overriden
@ -240,10 +245,10 @@ def copy_doc(original: Callable[..., Any]) -> Callable[[Callable[..., Any]], Cal
return decorator
def deprecated(instead: Optional[str] = None) -> Callable[[Callable[..., T]], Callable[..., T]]:
def actual_decorator(func: Callable[..., T]) -> Callable[..., T]:
def deprecated(instead: Optional[str] = None) -> Callable[[Callable[P, T]], Callable[P, T]]:
def actual_decorator(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func)
def decorated(*args, **kwargs) -> T:
def decorated(*args: P.args, **kwargs: P.kwargs) -> T:
warnings.simplefilter('always', DeprecationWarning) # turn off filter
if instead:
fmt = "{0.__name__} is deprecated, use {1} instead."
@ -267,7 +272,7 @@ def oauth_url(
redirect_uri: str = MISSING,
scopes: Iterable[str] = MISSING,
disable_guild_select: bool = False,
):
) -> str:
"""A helper function that returns the OAuth2 URL for inviting the bot
into guilds.
@ -479,17 +484,17 @@ def _bytes_to_base64_data(data: bytes) -> str:
if HAS_ORJSON:
def to_json(obj: Any) -> str: # type: ignore
def _to_json(obj: Any) -> str: # type: ignore
return orjson.dumps(obj).decode('utf-8')
from_json = orjson.loads # type: ignore
_from_json = orjson.loads # type: ignore
else:
def to_json(obj: Any) -> str:
def _to_json(obj: Any) -> str:
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
from_json = json.loads
_from_json = json.loads
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
@ -916,7 +921,7 @@ def evaluate_annotation(
is_literal = False
args = tp.__args__
if not hasattr(tp, '__origin__'):
if PY_310 and tp.__class__ is types.Union: # type: ignore
if PY_310 and tp.__class__ is types.UnionType: # type: ignore
converted = Union[args] # type: ignore
return evaluate_annotation(converted, globals, locals, cache)

View File

@ -84,7 +84,7 @@ __all__ = (
log: logging.Logger = logging.getLogger(__name__)
_log = logging.getLogger(__name__)
class VoiceProtocol:
"""A class that represents the Discord voice protocol.
@ -301,7 +301,7 @@ class VoiceClient(VoiceProtocol):
async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
if self._voice_server_complete.is_set():
log.info('Ignoring extraneous voice server update.')
_log.info('Ignoring extraneous voice server update.')
return
self.token = data.get('token')
@ -309,7 +309,7 @@ class VoiceClient(VoiceProtocol):
endpoint = data.get('endpoint')
if endpoint is None or self.token is None:
log.warning('Awaiting endpoint... This requires waiting. ' \
_log.warning('Awaiting endpoint... This requires waiting. ' \
'If timeout occurred considering raising the timeout and reconnecting.')
return
@ -335,18 +335,18 @@ class VoiceClient(VoiceProtocol):
await self.channel.guild.change_voice_state(channel=self.channel)
async def voice_disconnect(self) -> None:
log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
_log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
await self.channel.guild.change_voice_state(channel=None)
def prepare_handshake(self) -> None:
self._voice_state_complete.clear()
self._voice_server_complete.clear()
self._handshaking = True
log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
_log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
self._connections += 1
def finish_handshake(self) -> None:
log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
_log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
self._handshaking = False
self._voice_server_complete.clear()
self._voice_state_complete.clear()
@ -360,7 +360,7 @@ class VoiceClient(VoiceProtocol):
return ws
async def connect(self, *, reconnect: bool, timeout: float) ->None:
log.info('Connecting to voice...')
_log.info('Connecting to voice...')
self.timeout = timeout
for i in range(5):
@ -388,7 +388,7 @@ class VoiceClient(VoiceProtocol):
break
except (ConnectionClosed, asyncio.TimeoutError):
if reconnect:
log.exception('Failed to connect to voice... Retrying...')
_log.exception('Failed to connect to voice... Retrying...')
await asyncio.sleep(1 + i * 2.0)
await self.voice_disconnect()
continue
@ -453,14 +453,14 @@ class VoiceClient(VoiceProtocol):
# 4014 - voice channel has been deleted.
# 4015 - voice server has crashed
if exc.code in (1000, 4015):
log.info('Disconnecting from voice normally, close code %d.', exc.code)
_log.info('Disconnecting from voice normally, close code %d.', exc.code)
await self.disconnect()
break
if exc.code == 4014:
log.info('Disconnected from voice by force... potentially reconnecting.')
_log.info('Disconnected from voice by force... potentially reconnecting.')
successful = await self.potential_reconnect()
if not successful:
log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
_log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
await self.disconnect()
break
else:
@ -471,7 +471,7 @@ class VoiceClient(VoiceProtocol):
raise
retry = backoff.delay()
log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
_log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
self._connected.clear()
await asyncio.sleep(retry)
await self.voice_disconnect()
@ -479,7 +479,7 @@ class VoiceClient(VoiceProtocol):
await self.connect(reconnect=True, timeout=self.timeout)
except asyncio.TimeoutError:
# at this point we've retried 5 times... let's continue the loop.
log.warning('Could not connect to voice... Retrying...')
_log.warning('Could not connect to voice... Retrying...')
continue
async def disconnect(self, *, force: bool = False) -> None:
@ -671,6 +671,6 @@ class VoiceClient(VoiceProtocol):
try:
self.socket.sendto(packet, (self.endpoint_ip, self.voice_port))
except BlockingIOError:
log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
_log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295)

View File

@ -24,7 +24,6 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import contextvars
import logging
import asyncio
import json
@ -32,6 +31,7 @@ import re
from urllib.parse import quote as urlquote
from typing import Any, Dict, List, Literal, NamedTuple, Optional, TYPE_CHECKING, Tuple, Union, overload
from contextvars import ContextVar
import aiohttp
@ -43,7 +43,7 @@ from ..user import BaseUser, User
from ..asset import Asset
from ..http import Route
from ..mixins import Hashable
from ..object import Object
from ..channel import PartialMessageable
__all__ = (
'Webhook',
@ -52,15 +52,20 @@ __all__ = (
'PartialWebhookGuild',
)
log = logging.getLogger(__name__)
_log = logging.getLogger(__name__)
if TYPE_CHECKING:
from ..file import File
from ..embeds import Embed
from ..mentions import AllowedMentions
from ..state import ConnectionState
from ..http import Response
from ..types.webhook import (
Webhook as WebhookPayload,
)
from ..types.message import (
Message as MessagePayload,
)
from ..guild import Guild
from ..channel import TextChannel
from ..abc import Snowflake
@ -116,7 +121,7 @@ class AsyncWebhookAdapter:
if payload is not None:
headers['Content-Type'] = 'application/json'
to_send = utils.to_json(payload)
to_send = utils._to_json(payload)
if auth_token is not None:
headers['Authorization'] = f'Bot {auth_token}'
@ -143,7 +148,7 @@ class AsyncWebhookAdapter:
try:
async with session.request(method, url, data=to_send, headers=headers, params=params) as response:
log.debug(
_log.debug(
'Webhook ID %s with %s %s has returned status code %s',
webhook_id,
method,
@ -157,7 +162,7 @@ class AsyncWebhookAdapter:
remaining = response.headers.get('X-Ratelimit-Remaining')
if remaining == '0' and response.status != 429:
delta = utils._parse_ratelimit_header(response)
log.debug(
_log.debug(
'Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds', webhook_id, delta
)
lock.delay_by(delta)
@ -170,7 +175,7 @@ class AsyncWebhookAdapter:
raise HTTPException(response, data)
retry_after: float = data['retry_after'] # type: ignore
log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
_log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
await asyncio.sleep(retry_after)
continue
@ -205,7 +210,7 @@ class AsyncWebhookAdapter:
token: Optional[str] = None,
session: aiohttp.ClientSession,
reason: Optional[str] = None,
):
) -> Response[None]:
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, auth_token=token)
@ -216,7 +221,7 @@ class AsyncWebhookAdapter:
*,
session: aiohttp.ClientSession,
reason: Optional[str] = None,
):
) -> Response[None]:
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason)
@ -228,7 +233,7 @@ class AsyncWebhookAdapter:
*,
session: aiohttp.ClientSession,
reason: Optional[str] = None,
):
) -> Response[WebhookPayload]:
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, payload=payload, auth_token=token)
@ -240,7 +245,7 @@ class AsyncWebhookAdapter:
*,
session: aiohttp.ClientSession,
reason: Optional[str] = None,
):
) -> Response[WebhookPayload]:
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason, payload=payload)
@ -255,7 +260,7 @@ class AsyncWebhookAdapter:
files: Optional[List[File]] = None,
thread_id: Optional[int] = None,
wait: bool = False,
):
) -> Response[Optional[MessagePayload]]:
params = {'wait': int(wait)}
if thread_id:
params['thread_id'] = thread_id
@ -269,7 +274,7 @@ class AsyncWebhookAdapter:
message_id: int,
*,
session: aiohttp.ClientSession,
):
) -> Response[MessagePayload]:
route = Route(
'GET',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -289,7 +294,7 @@ class AsyncWebhookAdapter:
payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None,
):
) -> Response[Message]:
route = Route(
'PATCH',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -306,7 +311,7 @@ class AsyncWebhookAdapter:
message_id: int,
*,
session: aiohttp.ClientSession,
):
) -> Response[None]:
route = Route(
'DELETE',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -322,7 +327,7 @@ class AsyncWebhookAdapter:
token: str,
*,
session: aiohttp.ClientSession,
):
) -> Response[WebhookPayload]:
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session=session, auth_token=token)
@ -332,7 +337,7 @@ class AsyncWebhookAdapter:
token: str,
*,
session: aiohttp.ClientSession,
):
) -> Response[WebhookPayload]:
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session=session)
@ -344,7 +349,7 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
type: int,
data: Optional[Dict[str, Any]] = None,
):
) -> Response[None]:
payload: Dict[str, Any] = {
'type': type,
}
@ -367,7 +372,7 @@ class AsyncWebhookAdapter:
token: str,
*,
session: aiohttp.ClientSession,
):
) -> Response[MessagePayload]:
r = Route(
'GET',
'/webhooks/{webhook_id}/{webhook_token}/messages/@original',
@ -385,7 +390,7 @@ class AsyncWebhookAdapter:
payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None,
):
) -> Response[MessagePayload]:
r = Route(
'PATCH',
'/webhooks/{webhook_id}/{webhook_token}/messages/@original',
@ -400,7 +405,7 @@ class AsyncWebhookAdapter:
token: str,
*,
session: aiohttp.ClientSession,
):
) -> Response[None]:
r = Route(
'DELETE',
'/webhooks/{webhook_id}/{wehook_token}/messages/@original',
@ -420,7 +425,7 @@ def handle_message_parameters(
content: Optional[str] = MISSING,
*,
username: str = MISSING,
avatar_url: str = MISSING,
avatar_url: Any = MISSING,
tts: bool = False,
ephemeral: bool = False,
file: File = MISSING,
@ -481,7 +486,7 @@ def handle_message_parameters(
files = [file]
if files:
multipart.append({'name': 'payload_json', 'value': utils.to_json(payload)})
multipart.append({'name': 'payload_json', 'value': utils._to_json(payload)})
payload = None
if len(files) == 1:
file = files[0]
@ -507,7 +512,7 @@ def handle_message_parameters(
return ExecuteWebhookParameters(payload=payload, multipart=multipart, files=files)
async_context = contextvars.ContextVar('async_webhook_context', default=AsyncWebhookAdapter())
async_context: ContextVar[AsyncWebhookAdapter] = ContextVar('async_webhook_context', default=AsyncWebhookAdapter())
class PartialWebhookChannel(Hashable):
@ -579,10 +584,11 @@ class _FriendlyHttpAttributeErrorHelper:
class _WebhookState:
__slots__ = ('_parent', '_webhook')
def __init__(self, webhook, parent):
self._webhook = webhook
def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]):
self._webhook: Any = webhook
if isinstance(parent, self.__class__):
self._parent: Optional[ConnectionState]
if isinstance(parent, _WebhookState):
self._parent = None
else:
self._parent = parent
@ -595,7 +601,12 @@ class _WebhookState:
def store_user(self, data):
if self._parent is not None:
return self._parent.store_user(data)
return BaseUser(state=self, data=data)
# state parameter is artificial
return BaseUser(state=self, data=data) # type: ignore
def create_user(self, data):
# state parameter is artificial
return BaseUser(state=self, data=data) # type: ignore
@property
def http(self):
@ -636,13 +647,16 @@ class WebhookMessage(Message):
files: List[File] = MISSING,
view: Optional[View] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
):
) -> WebhookMessage:
"""|coro|
Edits the message.
.. versionadded:: 1.6
.. versionchanged:: 2.0
The edit is no longer in-place, instead the newly edited message is returned.
Parameters
------------
content: Optional[:class:`str`]
@ -682,8 +696,13 @@ class WebhookMessage(Message):
The length of ``embeds`` was invalid
InvalidArgument
There was no token associated with this webhook.
Returns
--------
:class:`WebhookMessage`
The newly edited message.
"""
await self._state._webhook.edit_message(
return await self._state._webhook.edit_message(
self.id,
content=content,
embeds=embeds,
@ -745,9 +764,9 @@ class BaseWebhook(Hashable):
'_state',
)
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state=None):
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None):
self.auth_token: Optional[str] = token
self._state = state or _WebhookState(self, parent=state)
self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(self, parent=state)
self._update(data)
def _update(self, data: WebhookPayload):
@ -762,10 +781,8 @@ class BaseWebhook(Hashable):
user = data.get('user')
self.user: Optional[Union[BaseUser, User]] = None
if user is not None:
if self._state is None:
self.user = BaseUser(state=None, data=user)
else:
self.user = User(state=self._state, data=user)
# state parameter may be _WebhookState
self.user = User(state=self._state, data=user) # type: ignore
source_channel = data.get('source_channel')
if source_channel:
@ -869,6 +886,10 @@ class Webhook(BaseWebhook):
Returns the webhooks's hash.
.. describe:: int(x)
Returns the webhooks's ID.
.. versionchanged:: 1.4
Webhooks are now comparable and hashable.
@ -916,12 +937,12 @@ class Webhook(BaseWebhook):
return f'<Webhook id={self.id!r}>'
@property
def url(self):
def url(self) -> str:
""":class:`str` : Returns the webhook's url."""
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
@classmethod
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None):
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
"""Creates a partial :class:`Webhook`.
Parameters
@ -957,7 +978,7 @@ class Webhook(BaseWebhook):
return cls(data, session, token=bot_token)
@classmethod
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None):
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
"""Creates a partial :class:`Webhook` from a webhook URL.
Parameters
@ -996,7 +1017,7 @@ class Webhook(BaseWebhook):
return cls(data, session, token=bot_token) # type: ignore
@classmethod
def _as_follower(cls, data, *, channel, user):
def _as_follower(cls, data, *, channel, user) -> Webhook:
name = f"{channel.guild} #{channel}"
feed: WebhookPayload = {
'id': data['webhook_id'],
@ -1012,7 +1033,7 @@ class Webhook(BaseWebhook):
return cls(feed, session=session, state=state, token=state.http.token)
@classmethod
def from_state(cls, data, state):
def from_state(cls, data, state) -> Webhook:
session = state.http._HTTPClient__session
return cls(data, session=session, state=state, token=state.http.token)
@ -1108,7 +1129,7 @@ class Webhook(BaseWebhook):
avatar: Optional[bytes] = MISSING,
channel: Optional[Snowflake] = None,
prefer_auth: bool = True,
):
) -> Webhook:
"""|coro|
Edits this Webhook.
@ -1155,6 +1176,7 @@ class Webhook(BaseWebhook):
adapter = async_context.get()
data: Optional[WebhookPayload] = None
# If a channel is given, always use the authenticated endpoint
if channel is not None:
if self.auth_token is None:
@ -1162,21 +1184,24 @@ class Webhook(BaseWebhook):
payload['channel_id'] = channel.id
data = await adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
self._update(data)
return
if prefer_auth and self.auth_token:
data = await adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
self._update(data)
elif self.token:
data = await adapter.edit_webhook_with_token(
self.id, self.token, payload=payload, session=self.session, reason=reason
)
self._update(data)
if data is None:
raise RuntimeError('Unreachable code hit: data was not assigned')
return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state)
def _create_message(self, data):
state = _WebhookState(self, parent=self._state)
channel = self.channel or Object(id=int(data['channel_id']))
# state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
# state is artificial
return WebhookMessage(data=data, state=state, channel=channel) # type: ignore
@overload
@ -1185,7 +1210,7 @@ class Webhook(BaseWebhook):
content: str = MISSING,
*,
username: str = MISSING,
avatar_url: str = MISSING,
avatar_url: Any = MISSING,
tts: bool = MISSING,
ephemeral: bool = MISSING,
file: File = MISSING,
@ -1205,7 +1230,7 @@ class Webhook(BaseWebhook):
content: str = MISSING,
*,
username: str = MISSING,
avatar_url: str = MISSING,
avatar_url: Any = MISSING,
tts: bool = MISSING,
ephemeral: bool = MISSING,
file: File = MISSING,
@ -1224,7 +1249,7 @@ class Webhook(BaseWebhook):
content: str = MISSING,
*,
username: str = MISSING,
avatar_url: str = MISSING,
avatar_url: Any = MISSING,
tts: bool = False,
ephemeral: bool = False,
file: File = MISSING,
@ -1261,9 +1286,10 @@ class Webhook(BaseWebhook):
username: :class:`str`
The username to send with this message. If no username is provided
then the default username for the webhook is used.
avatar_url: Union[:class:`str`, :class:`Asset`]
avatar_url: :class:`str`
The avatar URL to send with this message. If no avatar URL is provided
then the default avatar for the webhook is used.
then the default avatar for the webhook is used. If this is not a
string then it is explicitly cast using ``str``.
tts: :class:`bool`
Indicates if the message should be sent using text-to-speech.
ephemeral: :class:`bool`
@ -1435,7 +1461,7 @@ class Webhook(BaseWebhook):
files: List[File] = MISSING,
view: Optional[View] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
):
) -> WebhookMessage:
"""|coro|
Edits a message owned by this webhook.
@ -1445,6 +1471,9 @@ class Webhook(BaseWebhook):
.. versionadded:: 1.6
.. versionchanged:: 2.0
The edit is no longer in-place, instead the newly edited message is returned.
Parameters
------------
message_id: :class:`int`
@ -1488,6 +1517,11 @@ class Webhook(BaseWebhook):
InvalidArgument
There was no token associated with this webhook or the webhook had
no state.
Returns
--------
:class:`WebhookMessage`
The newly edited webhook message.
"""
if self.token is None:
@ -1511,7 +1545,7 @@ class Webhook(BaseWebhook):
previous_allowed_mentions=previous_mentions,
)
adapter = async_context.get()
await adapter.edit_webhook_message(
data = await adapter.edit_webhook_message(
self.id,
self.token,
message_id,
@ -1521,10 +1555,12 @@ class Webhook(BaseWebhook):
files=params.files,
)
message = self._create_message(data)
if view and not view.is_finished():
self._state.store_view(view, message_id)
return message
async def delete_message(self, message_id: int):
async def delete_message(self, message_id: int, /) -> None:
"""|coro|
Deletes a message owned by this webhook.

View File

@ -43,7 +43,7 @@ from .. import utils
from ..errors import InvalidArgument, HTTPException, Forbidden, NotFound, DiscordServerError
from ..message import Message
from ..http import Route
from ..object import Object
from ..channel import PartialMessageable
from .async_ import BaseWebhook, handle_message_parameters, _WebhookState
@ -52,7 +52,7 @@ __all__ = (
'SyncWebhookMessage',
)
log = logging.getLogger(__name__)
_log = logging.getLogger(__name__)
if TYPE_CHECKING:
from ..file import File
@ -117,7 +117,7 @@ class WebhookAdapter:
if payload is not None:
headers['Content-Type'] = 'application/json'
to_send = utils.to_json(payload)
to_send = utils._to_json(payload)
if auth_token is not None:
headers['Authorization'] = f'Bot {auth_token}'
@ -142,13 +142,15 @@ class WebhookAdapter:
for p in multipart:
name = p['name']
if name == 'payload_json':
to_send = { 'payload_json': p['value'] }
to_send = {'payload_json': p['value']}
else:
file_data[name] = (p['filename'], p['value'], p['content_type'])
try:
with session.request(method, url, data=to_send, files=file_data, headers=headers, params=params) as response:
log.debug(
with session.request(
method, url, data=to_send, files=file_data, headers=headers, params=params
) as response:
_log.debug(
'Webhook ID %s with %s %s has returned status code %s',
webhook_id,
method,
@ -166,7 +168,7 @@ class WebhookAdapter:
remaining = response.headers.get('X-Ratelimit-Remaining')
if remaining == '0' and response.status_code != 429:
delta = utils._parse_ratelimit_header(response)
log.debug(
_log.debug(
'Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds', webhook_id, delta
)
lock.delay_by(delta)
@ -179,7 +181,7 @@ class WebhookAdapter:
raise HTTPException(response, data)
retry_after: float = data['retry_after'] # type: ignore
log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
_log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
time.sleep(retry_after)
continue
@ -346,8 +348,17 @@ class WebhookAdapter:
return self.request(route, session=session)
_context = threading.local()
_context.adapter = WebhookAdapter()
class _WebhookContext(threading.local):
adapter: Optional[WebhookAdapter] = None
_context = _WebhookContext()
def _get_webhook_adapter() -> WebhookAdapter:
if _context.adapter is None:
_context.adapter = WebhookAdapter()
return _context.adapter
class SyncWebhookMessage(Message):
@ -362,6 +373,8 @@ class SyncWebhookMessage(Message):
.. versionadded:: 2.0
"""
_state: _WebhookState
def edit(
self,
content: Optional[str] = MISSING,
@ -370,7 +383,7 @@ class SyncWebhookMessage(Message):
file: File = MISSING,
files: List[File] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
):
) -> SyncWebhookMessage:
"""Edits the message.
Parameters
@ -403,8 +416,13 @@ class SyncWebhookMessage(Message):
The length of ``embeds`` was invalid
InvalidArgument
There was no token associated with this webhook.
Returns
--------
:class:`SyncWebhookMessage`
The newly edited message.
"""
self._state._webhook.edit_message(
return self._state._webhook.edit_message(
self.id,
content=content,
embeds=embeds,
@ -457,6 +475,10 @@ class SyncWebhook(BaseWebhook):
Returns the webhooks's hash.
.. describe:: int(x)
Returns the webhooks's ID.
.. versionchanged:: 1.4
Webhooks are now comparable and hashable.
@ -504,12 +526,12 @@ class SyncWebhook(BaseWebhook):
return f'<Webhook id={self.id!r}>'
@property
def url(self):
def url(self) -> str:
""":class:`str` : Returns the webhook's url."""
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
@classmethod
def partial(cls, id: int, token: str, *, session: Session = MISSING, bot_token: Optional[str] = None):
def partial(cls, id: int, token: str, *, session: Session = MISSING, bot_token: Optional[str] = None) -> SyncWebhook:
"""Creates a partial :class:`Webhook`.
Parameters
@ -548,7 +570,7 @@ class SyncWebhook(BaseWebhook):
return cls(data, session, token=bot_token)
@classmethod
def from_url(cls, url: str, *, session: Session = MISSING, bot_token: Optional[str] = None):
def from_url(cls, url: str, *, session: Session = MISSING, bot_token: Optional[str] = None) -> SyncWebhook:
"""Creates a partial :class:`Webhook` from a webhook URL.
Parameters
@ -621,7 +643,7 @@ class SyncWebhook(BaseWebhook):
:class:`SyncWebhook`
The fetched webhook.
"""
adapter: WebhookAdapter = _context.adapter
adapter: WebhookAdapter = _get_webhook_adapter()
if prefer_auth and self.auth_token:
data = adapter.fetch_webhook(self.id, self.auth_token, session=self.session)
@ -632,7 +654,7 @@ class SyncWebhook(BaseWebhook):
return SyncWebhook(data, self.session, token=self.auth_token, state=self._state)
def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True):
def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True) -> None:
"""Deletes this Webhook.
Parameters
@ -659,7 +681,7 @@ class SyncWebhook(BaseWebhook):
if self.token is None and self.auth_token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
adapter: WebhookAdapter = _context.adapter
adapter: WebhookAdapter = _get_webhook_adapter()
if prefer_auth and self.auth_token:
adapter.delete_webhook(self.id, token=self.auth_token, session=self.session, reason=reason)
@ -674,7 +696,7 @@ class SyncWebhook(BaseWebhook):
avatar: Optional[bytes] = MISSING,
channel: Optional[Snowflake] = None,
prefer_auth: bool = True,
):
) -> SyncWebhook:
"""Edits this Webhook.
Parameters
@ -702,6 +724,11 @@ class SyncWebhook(BaseWebhook):
InvalidArgument
This webhook does not have a token associated with it
or it tried editing a channel without authentication.
Returns
--------
:class:`SyncWebhook`
The newly edited webhook.
"""
if self.token is None and self.auth_token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
@ -713,8 +740,9 @@ class SyncWebhook(BaseWebhook):
if avatar is not MISSING:
payload['avatar'] = utils._bytes_to_base64_data(avatar) if avatar is not None else None
adapter: WebhookAdapter = _context.adapter
adapter: WebhookAdapter = _get_webhook_adapter()
data: Optional[WebhookPayload] = None
# If a channel is given, always use the authenticated endpoint
if channel is not None:
if self.auth_token is None:
@ -722,20 +750,23 @@ class SyncWebhook(BaseWebhook):
payload['channel_id'] = channel.id
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
self._update(data)
return
if prefer_auth and self.auth_token:
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
self._update(data)
elif self.token:
data = adapter.edit_webhook_with_token(self.id, self.token, payload=payload, session=self.session, reason=reason)
self._update(data)
if data is None:
raise RuntimeError('Unreachable code hit: data was not assigned')
return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state)
def _create_message(self, data):
state = _WebhookState(self, parent=self._state)
channel = self.channel or Object(id=int(data['channel_id']))
return SyncWebhookMessage(data=data, state=state, channel=channel)
# state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
# state is artificial
return SyncWebhookMessage(data=data, state=state, channel=channel) # type: ignore
@overload
def send(
@ -743,7 +774,7 @@ class SyncWebhook(BaseWebhook):
content: str = MISSING,
*,
username: str = MISSING,
avatar_url: str = MISSING,
avatar_url: Any = MISSING,
tts: bool = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
@ -760,7 +791,7 @@ class SyncWebhook(BaseWebhook):
content: str = MISSING,
*,
username: str = MISSING,
avatar_url: str = MISSING,
avatar_url: Any = MISSING,
tts: bool = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
@ -776,7 +807,7 @@ class SyncWebhook(BaseWebhook):
content: str = MISSING,
*,
username: str = MISSING,
avatar_url: str = MISSING,
avatar_url: Any = MISSING,
tts: bool = False,
file: File = MISSING,
files: List[File] = MISSING,
@ -808,9 +839,10 @@ class SyncWebhook(BaseWebhook):
username: :class:`str`
The username to send with this message. If no username is provided
then the default username for the webhook is used.
avatar_url: Union[:class:`str`, :class:`Asset`]
avatar_url: :class:`str`
The avatar URL to send with this message. If no avatar URL is provided
then the default avatar for the webhook is used.
then the default avatar for the webhook is used. If this is not a
string then it is explicitly cast using ``str``.
tts: :class:`bool`
Indicates if the message should be sent using text-to-speech.
file: :class:`File`
@ -873,7 +905,7 @@ class SyncWebhook(BaseWebhook):
allowed_mentions=allowed_mentions,
previous_allowed_mentions=previous_mentions,
)
adapter: WebhookAdapter = _context.adapter
adapter: WebhookAdapter = _get_webhook_adapter()
thread_id: Optional[int] = None
if thread is not MISSING:
thread_id = thread.id
@ -891,7 +923,7 @@ class SyncWebhook(BaseWebhook):
if wait:
return self._create_message(data)
def fetch_message(self, id: int) -> SyncWebhookMessage:
def fetch_message(self, id: int, /) -> SyncWebhookMessage:
"""Retrieves a single :class:`~discord.SyncWebhookMessage` owned by this webhook.
.. versionadded:: 2.0
@ -921,7 +953,7 @@ class SyncWebhook(BaseWebhook):
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
adapter: WebhookAdapter = _context.adapter
adapter: WebhookAdapter = _get_webhook_adapter()
data = adapter.get_webhook_message(
self.id,
self.token,
@ -940,7 +972,7 @@ class SyncWebhook(BaseWebhook):
file: File = MISSING,
files: List[File] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
):
) -> SyncWebhookMessage:
"""Edits a message owned by this webhook.
This is a lower level interface to :meth:`WebhookMessage.edit` in case
@ -995,8 +1027,8 @@ class SyncWebhook(BaseWebhook):
allowed_mentions=allowed_mentions,
previous_allowed_mentions=previous_mentions,
)
adapter: WebhookAdapter = _context.adapter
adapter.edit_webhook_message(
adapter: WebhookAdapter = _get_webhook_adapter()
data = adapter.edit_webhook_message(
self.id,
self.token,
message_id,
@ -1005,8 +1037,9 @@ class SyncWebhook(BaseWebhook):
multipart=params.multipart,
files=params.files,
)
return self._create_message(data)
def delete_message(self, message_id: int):
def delete_message(self, message_id: int, /) -> None:
"""Deletes a message owned by this webhook.
This is a lower level interface to :meth:`WebhookMessage.delete` in case
@ -1029,7 +1062,7 @@ class SyncWebhook(BaseWebhook):
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
adapter: WebhookAdapter = _context.adapter
adapter: WebhookAdapter = _get_webhook_adapter()
adapter.delete_webhook_message(
self.id,
self.token,