Merge branch '2.0' into pr7268
# Conflicts: # discord/raw_models.py
This commit is contained in:
@@ -60,13 +60,15 @@ from .interactions import *
|
||||
from .components import *
|
||||
from .threads import *
|
||||
|
||||
class VersionInfo(NamedTuple):
|
||||
major: int
|
||||
minor: int
|
||||
micro: int
|
||||
releaselevel: Literal["alpha", "beta", "candidate", "final"]
|
||||
serial: int
|
||||
|
||||
version_info = VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=0)
|
||||
class VersionInfo(NamedTuple):
|
||||
major: int
|
||||
minor: int
|
||||
micro: int
|
||||
releaselevel: Literal["alpha", "beta", "candidate", "final"]
|
||||
serial: int
|
||||
|
||||
|
||||
version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=0)
|
||||
|
||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||
|
||||
@@ -51,7 +51,7 @@ def core(parser, args):
|
||||
if args.version:
|
||||
show_version()
|
||||
|
||||
bot_template = """#!/usr/bin/env python3
|
||||
_bot_template = """#!/usr/bin/env python3
|
||||
|
||||
from discord.ext import commands
|
||||
import discord
|
||||
@@ -77,7 +77,7 @@ bot = Bot()
|
||||
bot.run(config.token)
|
||||
"""
|
||||
|
||||
gitignore_template = """# Byte-compiled / optimized / DLL files
|
||||
_gitignore_template = """# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
@@ -107,7 +107,7 @@ var/
|
||||
config.py
|
||||
"""
|
||||
|
||||
cog_template = '''from discord.ext import commands
|
||||
_cog_template = '''from discord.ext import commands
|
||||
import discord
|
||||
|
||||
class {name}(commands.Cog{attrs}):
|
||||
@@ -120,7 +120,7 @@ def setup(bot):
|
||||
bot.add_cog({name}(bot))
|
||||
'''
|
||||
|
||||
cog_extras = '''
|
||||
_cog_extras = '''
|
||||
def cog_unload(self):
|
||||
# clean up logic goes here
|
||||
pass
|
||||
@@ -170,7 +170,7 @@ _base_table = {
|
||||
# NUL (0) and 1-31 are disallowed
|
||||
_base_table.update((chr(i), None) for i in range(32))
|
||||
|
||||
translation_table = str.maketrans(_base_table)
|
||||
_translation_table = str.maketrans(_base_table)
|
||||
|
||||
def to_path(parser, name, *, replace_spaces=False):
|
||||
if isinstance(name, Path):
|
||||
@@ -182,7 +182,7 @@ def to_path(parser, name, *, replace_spaces=False):
|
||||
if len(name) <= 4 and name.upper() in forbidden:
|
||||
parser.error('invalid directory name given, use a different one')
|
||||
|
||||
name = name.translate(translation_table)
|
||||
name = name.translate(_translation_table)
|
||||
if replace_spaces:
|
||||
name = name.replace(' ', '-')
|
||||
return Path(name)
|
||||
@@ -215,14 +215,14 @@ def newbot(parser, args):
|
||||
try:
|
||||
with open(str(new_directory / 'bot.py'), 'w', encoding='utf-8') as fp:
|
||||
base = 'Bot' if not args.sharded else 'AutoShardedBot'
|
||||
fp.write(bot_template.format(base=base, prefix=args.prefix))
|
||||
fp.write(_bot_template.format(base=base, prefix=args.prefix))
|
||||
except OSError as exc:
|
||||
parser.error(f'could not create bot file ({exc})')
|
||||
|
||||
if not args.no_git:
|
||||
try:
|
||||
with open(str(new_directory / '.gitignore'), 'w', encoding='utf-8') as fp:
|
||||
fp.write(gitignore_template)
|
||||
fp.write(_gitignore_template)
|
||||
except OSError as exc:
|
||||
print(f'warning: could not create .gitignore file ({exc})')
|
||||
|
||||
@@ -240,7 +240,7 @@ def newcog(parser, args):
|
||||
try:
|
||||
with open(str(directory), 'w', encoding='utf-8') as fp:
|
||||
attrs = ''
|
||||
extra = cog_extras if args.full else ''
|
||||
extra = _cog_extras if args.full else ''
|
||||
if args.class_name:
|
||||
name = args.class_name
|
||||
else:
|
||||
@@ -255,7 +255,7 @@ def newcog(parser, args):
|
||||
attrs += f', name="{args.display_name}"'
|
||||
if args.hide_commands:
|
||||
attrs += ', command_attrs=dict(hidden=True)'
|
||||
fp.write(cog_template.format(name=name, extra=extra, attrs=attrs))
|
||||
fp.write(_cog_template.format(name=name, extra=extra, attrs=attrs))
|
||||
except OSError as exc:
|
||||
parser.error(f'could not create cog file ({exc})')
|
||||
else:
|
||||
|
||||
@@ -28,14 +28,14 @@ import copy
|
||||
import asyncio
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
@@ -52,6 +52,7 @@ from .role import Role
|
||||
from .invite import Invite
|
||||
from .file import File
|
||||
from .voice_client import VoiceClient, VoiceProtocol
|
||||
from .sticker import GuildSticker, StickerItem
|
||||
from . import utils
|
||||
|
||||
__all__ = (
|
||||
@@ -68,6 +69,7 @@ T = TypeVar('T', bound=VoiceProtocol)
|
||||
if TYPE_CHECKING:
|
||||
from datetime import datetime
|
||||
|
||||
from .client import Client
|
||||
from .user import ClientUser
|
||||
from .asset import Asset
|
||||
from .state import ConnectionState
|
||||
@@ -76,17 +78,18 @@ if TYPE_CHECKING:
|
||||
from .channel import CategoryChannel
|
||||
from .embeds import Embed
|
||||
from .message import Message, MessageReference, PartialMessage
|
||||
from .channel import TextChannel, DMChannel, GroupChannel
|
||||
from .channel import TextChannel, DMChannel, GroupChannel, PartialMessageable
|
||||
from .threads import Thread
|
||||
from .enums import InviteTarget
|
||||
from .ui.view import View
|
||||
from .types.channel import (
|
||||
PermissionOverwrite as PermissionOverwritePayload,
|
||||
Channel as ChannelPayload,
|
||||
GuildChannel as GuildChannelPayload,
|
||||
OverwriteType,
|
||||
)
|
||||
|
||||
PartialMessageableChannel = Union[TextChannel, Thread, DMChannel]
|
||||
PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable]
|
||||
MessageableChannel = Union[PartialMessageableChannel, GroupChannel]
|
||||
SnowflakeTime = Union["Snowflake", datetime]
|
||||
|
||||
@@ -120,11 +123,6 @@ class Snowflake(Protocol):
|
||||
__slots__ = ()
|
||||
id: int
|
||||
|
||||
@property
|
||||
def created_at(self) -> datetime:
|
||||
""":class:`datetime.datetime`: Returns the model's creation time as an aware datetime in UTC."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class User(Snowflake, Protocol):
|
||||
@@ -305,11 +303,8 @@ class GuildChannel:
|
||||
payload.append(d)
|
||||
|
||||
await http.bulk_channel_update(self.guild.id, payload, reason=reason)
|
||||
self.position = position
|
||||
if parent_id is not _undefined:
|
||||
self.category_id = int(parent_id) if parent_id else None
|
||||
|
||||
async def _edit(self, options: Dict[str, Any], reason: Optional[str]):
|
||||
async def _edit(self, options: Dict[str, Any], reason: Optional[str]) -> Optional[ChannelPayload]:
|
||||
try:
|
||||
parent = options.pop('category')
|
||||
except KeyError:
|
||||
@@ -388,8 +383,7 @@ class GuildChannel:
|
||||
options['type'] = ch_type.value
|
||||
|
||||
if options:
|
||||
data = await self._state.http.edit_channel(self.id, reason=reason, **options)
|
||||
self._update(self.guild, data)
|
||||
return await self._state.http.edit_channel(self.id, reason=reason, **options)
|
||||
|
||||
def _fill_overwrites(self, data: GuildChannelPayload) -> None:
|
||||
self._overwrites = []
|
||||
@@ -473,7 +467,7 @@ class GuildChannel:
|
||||
return PermissionOverwrite()
|
||||
|
||||
@property
|
||||
def overwrites(self) -> Mapping[Union[Role, Member], PermissionOverwrite]:
|
||||
def overwrites(self) -> Dict[Union[Role, Member], PermissionOverwrite]:
|
||||
"""Returns all of the channel's overwrites.
|
||||
|
||||
This is returned as a dictionary where the key contains the target which
|
||||
@@ -482,7 +476,7 @@ class GuildChannel:
|
||||
|
||||
Returns
|
||||
--------
|
||||
Mapping[Union[:class:`~discord.Role`, :class:`~discord.Member`], :class:`~discord.PermissionOverwrite`]
|
||||
Dict[Union[:class:`~discord.Role`, :class:`~discord.Member`], :class:`~discord.PermissionOverwrite`]
|
||||
The channel's permission overwrites.
|
||||
"""
|
||||
ret = {}
|
||||
@@ -1146,6 +1140,7 @@ class Messageable:
|
||||
- :class:`~discord.User`
|
||||
- :class:`~discord.Member`
|
||||
- :class:`~discord.ext.commands.Context`
|
||||
- :class:`~discord.Thread`
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
@@ -1162,6 +1157,7 @@ class Messageable:
|
||||
tts: bool = ...,
|
||||
embed: Embed = ...,
|
||||
file: File = ...,
|
||||
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
|
||||
delete_after: float = ...,
|
||||
nonce: Union[str, int] = ...,
|
||||
allowed_mentions: AllowedMentions = ...,
|
||||
@@ -1179,6 +1175,7 @@ class Messageable:
|
||||
tts: bool = ...,
|
||||
embed: Embed = ...,
|
||||
files: List[File] = ...,
|
||||
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
|
||||
delete_after: float = ...,
|
||||
nonce: Union[str, int] = ...,
|
||||
allowed_mentions: AllowedMentions = ...,
|
||||
@@ -1196,6 +1193,7 @@ class Messageable:
|
||||
tts: bool = ...,
|
||||
embeds: List[Embed] = ...,
|
||||
file: File = ...,
|
||||
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
|
||||
delete_after: float = ...,
|
||||
nonce: Union[str, int] = ...,
|
||||
allowed_mentions: AllowedMentions = ...,
|
||||
@@ -1213,6 +1211,7 @@ class Messageable:
|
||||
tts: bool = ...,
|
||||
embeds: List[Embed] = ...,
|
||||
files: List[File] = ...,
|
||||
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
|
||||
delete_after: float = ...,
|
||||
nonce: Union[str, int] = ...,
|
||||
allowed_mentions: AllowedMentions = ...,
|
||||
@@ -1231,6 +1230,7 @@ class Messageable:
|
||||
embeds=None,
|
||||
file=None,
|
||||
files=None,
|
||||
stickers=None,
|
||||
delete_after=None,
|
||||
nonce=None,
|
||||
allowed_mentions=None,
|
||||
@@ -1302,6 +1302,10 @@ class Messageable:
|
||||
embeds: List[:class:`~discord.Embed`]
|
||||
A list of embeds to upload. Must be a maximum of 10.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
stickers: Sequence[Union[:class:`~discord.GuildSticker`, :class:`~discord.StickerItem`]]
|
||||
A list of stickers to upload. Must be a maximum of 3.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Raises
|
||||
@@ -1338,6 +1342,9 @@ class Messageable:
|
||||
raise InvalidArgument('embeds parameter must be a list of up to 10 elements')
|
||||
embeds = [embed.to_dict() for embed in embeds]
|
||||
|
||||
if stickers is not None:
|
||||
stickers = [sticker.id for sticker in stickers]
|
||||
|
||||
if allowed_mentions is not None:
|
||||
if state.allowed_mentions is not None:
|
||||
allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict()
|
||||
@@ -1382,6 +1389,7 @@ class Messageable:
|
||||
embeds=embeds,
|
||||
nonce=nonce,
|
||||
message_reference=reference,
|
||||
stickers=stickers,
|
||||
components=components,
|
||||
)
|
||||
finally:
|
||||
@@ -1404,6 +1412,7 @@ class Messageable:
|
||||
nonce=nonce,
|
||||
allowed_mentions=allowed_mentions,
|
||||
message_reference=reference,
|
||||
stickers=stickers,
|
||||
components=components,
|
||||
)
|
||||
finally:
|
||||
@@ -1419,6 +1428,7 @@ class Messageable:
|
||||
nonce=nonce,
|
||||
allowed_mentions=allowed_mentions,
|
||||
message_reference=reference,
|
||||
stickers=stickers,
|
||||
components=components,
|
||||
)
|
||||
|
||||
@@ -1452,6 +1462,7 @@ class Messageable:
|
||||
This means that both ``with`` and ``async with`` work with this.
|
||||
|
||||
Example Usage: ::
|
||||
|
||||
async with channel.typing():
|
||||
# simulate something heavy
|
||||
await asyncio.sleep(10)
|
||||
@@ -1610,12 +1621,20 @@ class Connectable(Protocol):
|
||||
def _get_voice_state_pair(self) -> Tuple[int, int]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def connect(self, *, timeout: float = 60.0, reconnect: bool = True, cls: Type[T] = VoiceClient) -> T:
|
||||
async def connect(
|
||||
self,
|
||||
*,
|
||||
timeout: float = 60.0,
|
||||
reconnect: bool = True,
|
||||
cls: Callable[[Client, Connectable], T] = VoiceClient,
|
||||
) -> T:
|
||||
"""|coro|
|
||||
|
||||
Connects to voice and creates a :class:`VoiceClient` to establish
|
||||
your connection to the voice server.
|
||||
|
||||
This requires :attr:`Intents.voice_states`.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
timeout: :class:`float`
|
||||
|
||||
@@ -830,10 +830,12 @@ def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
|
||||
except KeyError:
|
||||
return Activity(**data)
|
||||
else:
|
||||
return CustomActivity(name=name, **data)
|
||||
# we removed the name key from data already
|
||||
return CustomActivity(name=name, **data) # type: ignore
|
||||
elif game_type is ActivityType.streaming:
|
||||
if 'url' in data:
|
||||
return Streaming(**data)
|
||||
# the url won't be None here
|
||||
return Streaming(**data) # type: ignore
|
||||
return Activity(**data)
|
||||
elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data:
|
||||
return Spotify(**data)
|
||||
|
||||
@@ -146,7 +146,7 @@ class AppInfo:
|
||||
self.rpc_origins: List[str] = data['rpc_origins']
|
||||
self.bot_public: bool = data['bot_public']
|
||||
self.bot_require_code_grant: bool = data['bot_require_code_grant']
|
||||
self.owner: User = state.store_user(data['owner'])
|
||||
self.owner: User = state.create_user(data['owner'])
|
||||
|
||||
team: Optional[TeamPayload] = data.get('team')
|
||||
self.team: Optional[Team] = Team(state, team) if team else None
|
||||
|
||||
@@ -177,6 +177,17 @@ class Asset(AssetMixin):
|
||||
animated=animated,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset:
|
||||
animated = avatar.startswith('a_')
|
||||
format = 'gif' if animated else 'png'
|
||||
return cls(
|
||||
state,
|
||||
url=f"{cls.BASE}/guilds/{guild_id}/users/{member_id}/avatars/{avatar}.{format}?size=1024",
|
||||
key=avatar,
|
||||
animated=animated,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
|
||||
return cls(
|
||||
@@ -216,14 +227,25 @@ class Asset(AssetMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_sticker(cls, state, sticker_id: int, sticker_hash: str) -> Asset:
|
||||
def _from_sticker_banner(cls, state, banner: int) -> Asset:
|
||||
return cls(
|
||||
state,
|
||||
url=f'{cls.BASE}/stickers/{sticker_id}/{sticker_hash}.png?size=1024',
|
||||
key=sticker_hash,
|
||||
url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png',
|
||||
key=str(banner),
|
||||
animated=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:
|
||||
animated = banner_hash.startswith('a_')
|
||||
format = 'gif' if animated else 'png'
|
||||
return cls(
|
||||
state,
|
||||
url=f'{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512',
|
||||
key=banner_hash,
|
||||
animated=animated
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._url
|
||||
|
||||
@@ -291,10 +313,11 @@ class Asset(AssetMixin):
|
||||
if self._animated:
|
||||
if format not in VALID_ASSET_FORMATS:
|
||||
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}')
|
||||
else:
|
||||
url = url.with_path(f'{path}.{format}')
|
||||
elif static_format is MISSING:
|
||||
if format not in VALID_STATIC_FORMATS:
|
||||
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
|
||||
url = url.with_path(f'{path}.{format}')
|
||||
url = url.with_path(f'{path}.{format}')
|
||||
|
||||
if static_format is not MISSING and not self._animated:
|
||||
if static_format not in VALID_STATIC_FORMATS:
|
||||
|
||||
@@ -49,12 +49,17 @@ if TYPE_CHECKING:
|
||||
from .guild import Guild
|
||||
from .member import Member
|
||||
from .role import Role
|
||||
from .types.audit_log import AuditLogChange as AuditLogChangePayload
|
||||
from .types.audit_log import AuditLogEntry as AuditLogEntryPayload
|
||||
from .types.audit_log import (
|
||||
AuditLogChange as AuditLogChangePayload,
|
||||
AuditLogEntry as AuditLogEntryPayload,
|
||||
)
|
||||
from .types.channel import PermissionOverwrite as PermissionOverwritePayload
|
||||
from .types.role import Role as RolePayload
|
||||
from .types.snowflake import Snowflake
|
||||
from .user import User
|
||||
from .stage_instance import StageInstance
|
||||
from .sticker import GuildSticker
|
||||
from .threads import Thread
|
||||
|
||||
|
||||
def _transform_permissions(entry: AuditLogEntry, data: str) -> Permissions:
|
||||
@@ -69,22 +74,21 @@ def _transform_snowflake(entry: AuditLogEntry, data: Snowflake) -> int:
|
||||
return int(data)
|
||||
|
||||
|
||||
def _transform_channel(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Object]:
|
||||
def _transform_channel(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Union[abc.GuildChannel, Object]]:
|
||||
if data is None:
|
||||
return None
|
||||
return entry.guild.get_channel(int(data)) or Object(id=data)
|
||||
|
||||
|
||||
def _transform_owner_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]:
|
||||
def _transform_member_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]:
|
||||
if data is None:
|
||||
return None
|
||||
return entry._get_member(int(data))
|
||||
|
||||
|
||||
def _transform_inviter_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]:
|
||||
def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Guild]:
|
||||
if data is None:
|
||||
return None
|
||||
return entry._get_member(int(data))
|
||||
return entry._state._get_guild(data)
|
||||
|
||||
|
||||
def _transform_overwrites(
|
||||
@@ -142,6 +146,11 @@ def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]:
|
||||
|
||||
return _transform
|
||||
|
||||
def _transform_type(entry: AuditLogEntry, data: Union[int]) -> Union[enums.ChannelType, enums.StickerType]:
|
||||
if entry.action.name.startswith('sticker_'):
|
||||
return enums.try_enum(enums.StickerType, data)
|
||||
else:
|
||||
return enums.try_enum(enums.ChannelType, data)
|
||||
|
||||
class AuditLogDiff:
|
||||
def __len__(self) -> int:
|
||||
@@ -176,8 +185,8 @@ class AuditLogChanges:
|
||||
'permissions': (None, _transform_permissions),
|
||||
'id': (None, _transform_snowflake),
|
||||
'color': ('colour', _transform_color),
|
||||
'owner_id': ('owner', _transform_owner_id),
|
||||
'inviter_id': ('inviter', _transform_inviter_id),
|
||||
'owner_id': ('owner', _transform_member_id),
|
||||
'inviter_id': ('inviter', _transform_member_id),
|
||||
'channel_id': ('channel', _transform_channel),
|
||||
'afk_channel_id': ('afk_channel', _transform_channel),
|
||||
'system_channel_id': ('system_channel', _transform_channel),
|
||||
@@ -191,12 +200,15 @@ class AuditLogChanges:
|
||||
'icon_hash': ('icon', _transform_icon),
|
||||
'avatar_hash': ('avatar', _transform_avatar),
|
||||
'rate_limit_per_user': ('slowmode_delay', None),
|
||||
'guild_id': ('guild', _transform_guild_id),
|
||||
'tags': ('emoji', None),
|
||||
'default_message_notifications': ('default_notifications', _enum_transformer(enums.NotificationLevel)),
|
||||
'region': (None, _enum_transformer(enums.VoiceRegion)),
|
||||
'rtc_region': (None, _enum_transformer(enums.VoiceRegion)),
|
||||
'video_quality_mode': (None, _enum_transformer(enums.VideoQualityMode)),
|
||||
'privacy_level': (None, _enum_transformer(enums.StagePrivacyLevel)),
|
||||
'type': (None, _enum_transformer(enums.ChannelType)),
|
||||
'format_type': (None, _enum_transformer(enums.StickerFormatType)),
|
||||
'type': (None, _transform_type),
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
@@ -318,6 +330,10 @@ class AuditLogEntry(Hashable):
|
||||
|
||||
Returns the entry's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the entry's ID.
|
||||
|
||||
.. versionchanged:: 1.7
|
||||
Audit log entries are now comparable and hashable.
|
||||
|
||||
@@ -434,7 +450,7 @@ class AuditLogEntry(Hashable):
|
||||
return utils.snowflake_time(self.id)
|
||||
|
||||
@utils.cached_property
|
||||
def target(self) -> Union[Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, Object, None]:
|
||||
def target(self) -> Union[Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None]:
|
||||
try:
|
||||
converter = getattr(self, '_convert_target_' + self.action.target_type)
|
||||
except AttributeError:
|
||||
@@ -501,3 +517,12 @@ class AuditLogEntry(Hashable):
|
||||
|
||||
def _convert_target_message(self, target_id: int) -> Union[Member, User, None]:
|
||||
return self._get_member(target_id)
|
||||
|
||||
def _convert_target_stage_instance(self, target_id: int) -> Union[StageInstance, Object]:
|
||||
return self.guild.get_stage_instance(target_id) or Object(id=target_id)
|
||||
|
||||
def _convert_target_sticker(self, target_id: int) -> Union[GuildSticker, Object]:
|
||||
return self._state.get_sticker(target_id) or Object(id=target_id)
|
||||
|
||||
def _convert_target_thread(self, target_id: int) -> Union[Thread, Object]:
|
||||
return self.guild.get_thread(target_id) or Object(id=target_id)
|
||||
|
||||
@@ -26,13 +26,28 @@ from __future__ import annotations
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union, overload
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
import datetime
|
||||
|
||||
import discord.abc
|
||||
from .permissions import PermissionOverwrite, Permissions
|
||||
from .enums import ChannelType, StagePrivacyLevel, try_enum, VoiceRegion, VideoQualityMode
|
||||
from .mixins import Hashable
|
||||
from .object import Object
|
||||
from . import utils
|
||||
from .utils import MISSING
|
||||
from .asset import Asset
|
||||
@@ -49,6 +64,7 @@ __all__ = (
|
||||
'CategoryChannel',
|
||||
'StoreChannel',
|
||||
'GroupChannel',
|
||||
'PartialMessageable',
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -70,6 +86,7 @@ if TYPE_CHECKING:
|
||||
StoreChannel as StoreChannelPayload,
|
||||
GroupDMChannel as GroupChannelPayload,
|
||||
)
|
||||
from .types.snowflake import SnowflakeList
|
||||
|
||||
|
||||
async def _single_delete_strategy(messages: Iterable[Message]):
|
||||
@@ -98,6 +115,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
|
||||
Returns the channel's name.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the channel's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@@ -127,6 +148,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
.. note::
|
||||
|
||||
To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead.
|
||||
default_auto_archive_duration: :class:`int`
|
||||
The default auto archive duration in minutes for threads created in this channel.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
@@ -142,6 +167,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
'_overwrites',
|
||||
'_type',
|
||||
'last_message_id',
|
||||
'default_auto_archive_duration',
|
||||
)
|
||||
|
||||
def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload):
|
||||
@@ -171,6 +197,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
self.nsfw: bool = data.get('nsfw', False)
|
||||
# Does this need coercion into `int`? No idea yet.
|
||||
self.slowmode_delay: int = data.get('rate_limit_per_user', 0)
|
||||
self.default_auto_archive_duration: ThreadArchiveDuration = data.get('default_auto_archive_duration', 1440)
|
||||
self._type: int = data.get('type', self._type)
|
||||
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id')
|
||||
self._fill_overwrites(data)
|
||||
@@ -207,7 +234,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return [thread for thread in self.guild.threads if thread.parent_id == self.id]
|
||||
return [thread for thread in self.guild._threads.values() if thread.parent_id == self.id]
|
||||
|
||||
def is_nsfw(self) -> bool:
|
||||
""":class:`bool`: Checks if the channel is NSFW."""
|
||||
@@ -250,13 +277,14 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
sync_permissions: bool = ...,
|
||||
category: Optional[CategoryChannel] = ...,
|
||||
slowmode_delay: int = ...,
|
||||
default_auto_archive_duration: ThreadArchiveDuration = ...,
|
||||
type: ChannelType = ...,
|
||||
overwrites: Dict[Union[Role, Member, Snowflake], PermissionOverwrite] = ...,
|
||||
) -> None:
|
||||
overwrites: Mapping[Union[Role, Member, Snowflake], PermissionOverwrite] = ...,
|
||||
) -> Optional[TextChannel]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def edit(self) -> None:
|
||||
async def edit(self) -> Optional[TextChannel]:
|
||||
...
|
||||
|
||||
async def edit(self, *, reason=None, **options):
|
||||
@@ -273,6 +301,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
.. versionchanged:: 1.4
|
||||
The ``type`` keyword-only parameter was added.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Edits are no longer in-place, the newly edited channel is returned instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: :class:`str`
|
||||
@@ -298,9 +329,12 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
is only available to guilds that contain ``NEWS`` in :attr:`Guild.features`.
|
||||
reason: Optional[:class:`str`]
|
||||
The reason for editing this channel. Shows up on the audit log.
|
||||
overwrites: :class:`dict`
|
||||
A :class:`dict` of target (either a role or a member) to
|
||||
overwrites: :class:`Mapping`
|
||||
A :class:`Mapping` of target (either a role or a member) to
|
||||
:class:`PermissionOverwrite` to apply to the channel.
|
||||
default_auto_archive_duration: :class:`int`
|
||||
The new default auto archive duration in minutes for threads created in this channel.
|
||||
Must be one of ``60``, ``1440``, ``4320``, or ``10080``.
|
||||
|
||||
Raises
|
||||
------
|
||||
@@ -311,8 +345,18 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
You do not have permissions to edit the channel.
|
||||
HTTPException
|
||||
Editing the channel failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`.TextChannel`]
|
||||
The newly edited text channel. If the edit was only positional
|
||||
then ``None`` is returned instead.
|
||||
"""
|
||||
await self._edit(options, reason=reason)
|
||||
|
||||
payload = await self._edit(options, reason=reason)
|
||||
if payload is not None:
|
||||
# the payload will always be the proper channel payload
|
||||
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
|
||||
|
||||
@utils.copy_doc(discord.abc.GuildChannel.clone)
|
||||
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> TextChannel:
|
||||
@@ -366,7 +410,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
if len(messages) > 100:
|
||||
raise ClientException('Can only bulk delete messages up to 100 messages')
|
||||
|
||||
message_ids: List[int] = [m.id for m in messages]
|
||||
message_ids: SnowflakeList = [m.id for m in messages]
|
||||
await self._state.http.delete_messages(self.id, message_ids)
|
||||
|
||||
async def purge(
|
||||
@@ -631,64 +675,73 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
"""
|
||||
return self.guild.get_thread(thread_id)
|
||||
|
||||
async def start_thread(
|
||||
async def create_thread(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
message: Optional[Snowflake] = None,
|
||||
auto_archive_duration: ThreadArchiveDuration = 1440,
|
||||
auto_archive_duration: ThreadArchiveDuration = MISSING,
|
||||
type: Optional[ChannelType] = None,
|
||||
reason: Optional[str] = None,
|
||||
) -> Thread:
|
||||
"""|coro|
|
||||
|
||||
Starts a thread in this text channel.
|
||||
Creates a thread in this text channel.
|
||||
|
||||
If no starter message is passed with the ``message`` parameter then
|
||||
you must have :attr:`~discord.Permissions.send_messages` and
|
||||
:attr:`~discord.Permissions.use_private_threads` in order to start the thread.
|
||||
To create a public thread, you must have :attr:`~discord.Permissions.create_public_threads`.
|
||||
For a private thread, :attr:`~discord.Permissions.create_private_threads` is needed instead.
|
||||
|
||||
If a starter message is passed with the ``message`` parameter then
|
||||
you must have :attr:`~discord.Permissions.send_messages` and
|
||||
:attr:`~discord.Permissions.use_threads` in order to start the thread.
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
name: :class:`str`
|
||||
The name of the thread.
|
||||
message: Optional[:class:`abc.Snowflake`]
|
||||
A snowflake representing the message to start the thread with.
|
||||
If ``None`` is passed then a private thread is started.
|
||||
A snowflake representing the message to create the thread with.
|
||||
If ``None`` is passed then a private thread is created.
|
||||
Defaults to ``None``.
|
||||
auto_archive_duration: :class:`int`
|
||||
The duration in minutes before a thread is automatically archived for inactivity.
|
||||
Defaults to ``1440`` or 24 hours.
|
||||
If not provided, the channel's default auto archive duration is used.
|
||||
type: Optional[:class:`ChannelType`]
|
||||
The type of thread to create. If a ``message`` is passed then this parameter
|
||||
is ignored, as a thread created with a message is always a public thread.
|
||||
By default this creates a private thread if this is ``None``.
|
||||
reason: :class:`str`
|
||||
The reason for creating a new thread. Shows up on the audit log.
|
||||
|
||||
Raises
|
||||
-------
|
||||
Forbidden
|
||||
You do not have permissions to start a thread.
|
||||
You do not have permissions to create a thread.
|
||||
HTTPException
|
||||
Starting the thread failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Thread`
|
||||
The started thread
|
||||
The created thread
|
||||
"""
|
||||
|
||||
if type is None:
|
||||
type = ChannelType.private_thread
|
||||
|
||||
if message is None:
|
||||
data = await self._state.http.start_private_thread(
|
||||
data = await self._state.http.start_thread_without_message(
|
||||
self.id,
|
||||
name=name,
|
||||
auto_archive_duration=auto_archive_duration,
|
||||
type=ChannelType.private_thread.value,
|
||||
auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration,
|
||||
type=type.value,
|
||||
reason=reason,
|
||||
)
|
||||
else:
|
||||
data = await self._state.http.start_public_thread(
|
||||
data = await self._state.http.start_thread_with_message(
|
||||
self.id,
|
||||
message.id,
|
||||
name=name,
|
||||
auto_archive_duration=auto_archive_duration,
|
||||
type=ChannelType.public_thread.value,
|
||||
auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
return Thread(guild=self.guild, state=self._state, data=data)
|
||||
@@ -706,6 +759,8 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
You must have :attr:`~Permissions.read_message_history` to use this. If iterating over private threads
|
||||
then :attr:`~Permissions.manage_threads` is also required.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
limit: Optional[:class:`bool`]
|
||||
@@ -734,27 +789,6 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
"""
|
||||
return ArchivedThreadIterator(self.id, self.guild, limit=limit, joined=joined, private=private, before=before)
|
||||
|
||||
async def active_threads(self) -> List[Thread]:
|
||||
"""|coro|
|
||||
|
||||
Returns a list of active :class:`Thread` that the client can access.
|
||||
|
||||
This includes both private and public threads.
|
||||
|
||||
Raises
|
||||
------
|
||||
HTTPException
|
||||
The request to get the active threads failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
List[:class:`Thread`]
|
||||
The archived threads
|
||||
"""
|
||||
data = await self._state.http.get_active_threads(self.id)
|
||||
# TODO: thread members?
|
||||
return [Thread(guild=self.guild, state=self._state, data=d) for d in data.get('threads', [])]
|
||||
|
||||
|
||||
class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
|
||||
__slots__ = (
|
||||
@@ -930,15 +964,15 @@ class VoiceChannel(VocalGuildChannel):
|
||||
position: int = ...,
|
||||
sync_permissions: int = ...,
|
||||
category: Optional[CategoryChannel] = ...,
|
||||
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
|
||||
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
|
||||
rtc_region: Optional[VoiceRegion] = ...,
|
||||
video_quality_mode: VideoQualityMode = ...,
|
||||
reason: Optional[str] = ...,
|
||||
) -> None:
|
||||
) -> Optional[VoiceChannel]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def edit(self) -> None:
|
||||
async def edit(self) -> Optional[VoiceChannel]:
|
||||
...
|
||||
|
||||
async def edit(self, *, reason=None, **options):
|
||||
@@ -952,6 +986,9 @@ class VoiceChannel(VocalGuildChannel):
|
||||
.. versionchanged:: 1.3
|
||||
The ``overwrites`` keyword-only parameter was added.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Edits are no longer in-place, the newly edited channel is returned instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: :class:`str`
|
||||
@@ -970,8 +1007,8 @@ class VoiceChannel(VocalGuildChannel):
|
||||
category.
|
||||
reason: Optional[:class:`str`]
|
||||
The reason for editing this channel. Shows up on the audit log.
|
||||
overwrites: :class:`dict`
|
||||
A :class:`dict` of target (either a role or a member) to
|
||||
overwrites: :class:`Mapping`
|
||||
A :class:`Mapping` of target (either a role or a member) to
|
||||
:class:`PermissionOverwrite` to apply to the channel.
|
||||
rtc_region: Optional[:class:`VoiceRegion`]
|
||||
The new region for the voice channel's voice communication.
|
||||
@@ -991,9 +1028,18 @@ class VoiceChannel(VocalGuildChannel):
|
||||
You do not have permissions to edit the channel.
|
||||
HTTPException
|
||||
Editing the channel failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`.VoiceChannel`]
|
||||
The newly edited voice channel. If the edit was only positional
|
||||
then ``None`` is returned instead.
|
||||
"""
|
||||
|
||||
await self._edit(options, reason=reason)
|
||||
payload = await self._edit(options, reason=reason)
|
||||
if payload is not None:
|
||||
# the payload will always be the proper channel payload
|
||||
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
|
||||
|
||||
|
||||
class StageChannel(VocalGuildChannel):
|
||||
@@ -1119,7 +1165,9 @@ class StageChannel(VocalGuildChannel):
|
||||
"""
|
||||
return utils.get(self.guild.stage_instances, channel_id=self.id)
|
||||
|
||||
async def create_instance(self, *, topic: str, privacy_level: StagePrivacyLevel = MISSING) -> StageInstance:
|
||||
async def create_instance(
|
||||
self, *, topic: str, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None
|
||||
) -> StageInstance:
|
||||
"""|coro|
|
||||
|
||||
Create a stage instance.
|
||||
@@ -1135,6 +1183,8 @@ class StageChannel(VocalGuildChannel):
|
||||
The stage instance's topic.
|
||||
privacy_level: :class:`StagePrivacyLevel`
|
||||
The stage instance's privacy level. Defaults to :attr:`StagePrivacyLevel.guild_only`.
|
||||
reason: :class:`str`
|
||||
The reason the stage instance was created. Shows up on the audit log.
|
||||
|
||||
Raises
|
||||
------
|
||||
@@ -1159,7 +1209,7 @@ class StageChannel(VocalGuildChannel):
|
||||
|
||||
payload['privacy_level'] = privacy_level.value
|
||||
|
||||
data = await self._state.http.create_stage_instance(**payload)
|
||||
data = await self._state.http.create_stage_instance(**payload, reason=reason)
|
||||
return StageInstance(guild=self.guild, state=self._state, data=data)
|
||||
|
||||
async def fetch_instance(self) -> StageInstance:
|
||||
@@ -1193,15 +1243,15 @@ class StageChannel(VocalGuildChannel):
|
||||
position: int = ...,
|
||||
sync_permissions: int = ...,
|
||||
category: Optional[CategoryChannel] = ...,
|
||||
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
|
||||
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
|
||||
rtc_region: Optional[VoiceRegion] = ...,
|
||||
video_quality_mode: VideoQualityMode = ...,
|
||||
reason: Optional[str] = ...,
|
||||
) -> None:
|
||||
) -> Optional[StageChannel]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def edit(self) -> None:
|
||||
async def edit(self) -> Optional[StageChannel]:
|
||||
...
|
||||
|
||||
async def edit(self, *, reason=None, **options):
|
||||
@@ -1215,6 +1265,9 @@ class StageChannel(VocalGuildChannel):
|
||||
.. versionchanged:: 2.0
|
||||
The ``topic`` parameter must now be set via :attr:`create_instance`.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Edits are no longer in-place, the newly edited channel is returned instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: :class:`str`
|
||||
@@ -1229,8 +1282,8 @@ class StageChannel(VocalGuildChannel):
|
||||
category.
|
||||
reason: Optional[:class:`str`]
|
||||
The reason for editing this channel. Shows up on the audit log.
|
||||
overwrites: :class:`dict`
|
||||
A :class:`dict` of target (either a role or a member) to
|
||||
overwrites: :class:`Mapping`
|
||||
A :class:`Mapping` of target (either a role or a member) to
|
||||
:class:`PermissionOverwrite` to apply to the channel.
|
||||
rtc_region: Optional[:class:`VoiceRegion`]
|
||||
The new region for the stage channel's voice communication.
|
||||
@@ -1248,9 +1301,18 @@ class StageChannel(VocalGuildChannel):
|
||||
You do not have permissions to edit the channel.
|
||||
HTTPException
|
||||
Editing the channel failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`.StageChannel`]
|
||||
The newly edited stage channel. If the edit was only positional
|
||||
then ``None`` is returned instead.
|
||||
"""
|
||||
|
||||
await self._edit(options, reason=reason)
|
||||
payload = await self._edit(options, reason=reason)
|
||||
if payload is not None:
|
||||
# the payload will always be the proper channel payload
|
||||
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
|
||||
|
||||
|
||||
class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
@@ -1276,6 +1338,10 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
|
||||
Returns the category's name.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the category's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@@ -1337,13 +1403,13 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
name: str = ...,
|
||||
position: int = ...,
|
||||
nsfw: bool = ...,
|
||||
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
|
||||
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
|
||||
reason: Optional[str] = ...,
|
||||
) -> None:
|
||||
) -> Optional[CategoryChannel]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def edit(self) -> None:
|
||||
async def edit(self) -> Optional[CategoryChannel]:
|
||||
...
|
||||
|
||||
async def edit(self, *, reason=None, **options):
|
||||
@@ -1357,6 +1423,9 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
.. versionchanged:: 1.3
|
||||
The ``overwrites`` keyword-only parameter was added.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Edits are no longer in-place, the newly edited channel is returned instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: :class:`str`
|
||||
@@ -1367,8 +1436,8 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
To mark the category as NSFW or not.
|
||||
reason: Optional[:class:`str`]
|
||||
The reason for editing this category. Shows up on the audit log.
|
||||
overwrites: :class:`dict`
|
||||
A :class:`dict` of target (either a role or a member) to
|
||||
overwrites: :class:`Mapping`
|
||||
A :class:`Mapping` of target (either a role or a member) to
|
||||
:class:`PermissionOverwrite` to apply to the channel.
|
||||
|
||||
Raises
|
||||
@@ -1379,9 +1448,18 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
You do not have permissions to edit the category.
|
||||
HTTPException
|
||||
Editing the category failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`.CategoryChannel`]
|
||||
The newly edited category channel. If the edit was only positional
|
||||
then ``None`` is returned instead.
|
||||
"""
|
||||
|
||||
await self._edit(options=options, reason=reason)
|
||||
payload = await self._edit(options, reason=reason)
|
||||
if payload is not None:
|
||||
# the payload will always be the proper channel payload
|
||||
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
|
||||
|
||||
@utils.copy_doc(discord.abc.GuildChannel.move)
|
||||
async def move(self, **kwargs):
|
||||
@@ -1486,6 +1564,10 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
|
||||
|
||||
Returns the channel's name.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the channel's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@@ -1570,12 +1652,12 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
|
||||
sync_permissions: bool = ...,
|
||||
category: Optional[CategoryChannel],
|
||||
reason: Optional[str],
|
||||
overwrites: Dict[Union[Role, Member], PermissionOverwrite],
|
||||
) -> None:
|
||||
overwrites: Mapping[Union[Role, Member], PermissionOverwrite],
|
||||
) -> Optional[StoreChannel]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def edit(self) -> None:
|
||||
async def edit(self) -> Optional[StoreChannel]:
|
||||
...
|
||||
|
||||
async def edit(self, *, reason=None, **options):
|
||||
@@ -1586,6 +1668,9 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
|
||||
You must have the :attr:`~Permissions.manage_channels` permission to
|
||||
use this.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Edits are no longer in-place, the newly edited channel is returned instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: :class:`str`
|
||||
@@ -1602,8 +1687,8 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
|
||||
category.
|
||||
reason: Optional[:class:`str`]
|
||||
The reason for editing this channel. Shows up on the audit log.
|
||||
overwrites: :class:`dict`
|
||||
A :class:`dict` of target (either a role or a member) to
|
||||
overwrites: :class:`Mapping`
|
||||
A :class:`Mapping` of target (either a role or a member) to
|
||||
:class:`PermissionOverwrite` to apply to the channel.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
@@ -1617,8 +1702,18 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
|
||||
You do not have permissions to edit the channel.
|
||||
HTTPException
|
||||
Editing the channel failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`.StoreChannel`]
|
||||
The newly edited store channel. If the edit was only positional
|
||||
then ``None`` is returned instead.
|
||||
"""
|
||||
await self._edit(options, reason=reason)
|
||||
|
||||
payload = await self._edit(options, reason=reason)
|
||||
if payload is not None:
|
||||
# the payload will always be the proper channel payload
|
||||
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
|
||||
|
||||
|
||||
DMC = TypeVar('DMC', bound='DMChannel')
|
||||
@@ -1645,6 +1740,10 @@ class DMChannel(discord.abc.Messageable, Hashable):
|
||||
|
||||
Returns a string representation of the channel
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the channel's ID.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
recipient: Optional[:class:`User`]
|
||||
@@ -1682,6 +1781,7 @@ class DMChannel(discord.abc.Messageable, Hashable):
|
||||
self._state = state
|
||||
self.id = channel_id
|
||||
self.recipient = None
|
||||
# state.user won't be None here
|
||||
self.me = state.user # type: ignore
|
||||
return self
|
||||
|
||||
@@ -1770,6 +1870,10 @@ class GroupChannel(discord.abc.Messageable, Hashable):
|
||||
|
||||
Returns a string representation of the channel
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the channel's ID.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
recipients: List[:class:`User`]
|
||||
@@ -1892,6 +1996,73 @@ class GroupChannel(discord.abc.Messageable, Hashable):
|
||||
await self._state.http.leave_group(self.id)
|
||||
|
||||
|
||||
class PartialMessageable(discord.abc.Messageable, Hashable):
|
||||
"""Represents a partial messageable to aid with working messageable channels when
|
||||
only a channel ID are present.
|
||||
|
||||
The only way to construct this class is through :meth:`Client.get_partial_messageable`.
|
||||
|
||||
Note that this class is trimmed down and has no rich attributes.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. container:: operations
|
||||
|
||||
.. describe:: x == y
|
||||
|
||||
Checks if two partial messageables are equal.
|
||||
|
||||
.. describe:: x != y
|
||||
|
||||
Checks if two partial messageables are not equal.
|
||||
|
||||
.. describe:: hash(x)
|
||||
|
||||
Returns the partial messageable's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the messageable's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
id: :class:`int`
|
||||
The channel ID associated with this partial messageable.
|
||||
type: Optional[:class:`ChannelType`]
|
||||
The channel type associated with this partial messageable, if given.
|
||||
"""
|
||||
|
||||
def __init__(self, state: ConnectionState, id: int, type: Optional[ChannelType] = None):
|
||||
self._state: ConnectionState = state
|
||||
self._channel: Object = Object(id=id)
|
||||
self.id: int = id
|
||||
self.type: Optional[ChannelType] = type
|
||||
|
||||
async def _get_channel(self) -> Object:
|
||||
return self._channel
|
||||
|
||||
def get_partial_message(self, message_id: int, /) -> PartialMessage:
|
||||
"""Creates a :class:`PartialMessage` from the message ID.
|
||||
|
||||
This is useful if you want to work with a message and only have its ID without
|
||||
doing an unnecessary API call.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
message_id: :class:`int`
|
||||
The message ID to create a partial message for.
|
||||
|
||||
Returns
|
||||
---------
|
||||
:class:`PartialMessage`
|
||||
The partial message.
|
||||
"""
|
||||
|
||||
from .message import PartialMessage
|
||||
|
||||
return PartialMessage(channel=self, id=message_id)
|
||||
|
||||
|
||||
def _guild_channel_factory(channel_type: int):
|
||||
value = try_enum(ChannelType, channel_type)
|
||||
if value is ChannelType.text:
|
||||
@@ -1919,8 +2090,16 @@ def _channel_factory(channel_type: int):
|
||||
else:
|
||||
return cls, value
|
||||
|
||||
|
||||
def _threaded_channel_factory(channel_type: int):
|
||||
cls, value = _channel_factory(channel_type)
|
||||
if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread):
|
||||
return Thread, value
|
||||
return cls, value
|
||||
|
||||
|
||||
def _threaded_guild_channel_factory(channel_type: int):
|
||||
cls, value = _guild_channel_factory(channel_type)
|
||||
if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread):
|
||||
return Thread, value
|
||||
return cls, value
|
||||
|
||||
@@ -29,17 +29,17 @@ import logging
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
|
||||
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .user import User
|
||||
from .user import User, ClientUser
|
||||
from .invite import Invite
|
||||
from .template import Template
|
||||
from .widget import Widget
|
||||
from .guild import Guild
|
||||
from .emoji import Emoji
|
||||
from .channel import _threaded_channel_factory
|
||||
from .channel import _threaded_channel_factory, PartialMessageable
|
||||
from .enums import ChannelType
|
||||
from .mentions import AllowedMentions
|
||||
from .errors import *
|
||||
@@ -60,11 +60,11 @@ from .appinfo import AppInfo
|
||||
from .ui.view import View
|
||||
from .stage_instance import StageInstance
|
||||
from .threads import Thread
|
||||
from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake
|
||||
from .channel import DMChannel
|
||||
from .user import ClientUser
|
||||
from .message import Message
|
||||
from .member import Member
|
||||
from .voice_client import VoiceProtocol
|
||||
@@ -76,7 +76,7 @@ __all__ = (
|
||||
Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
|
||||
|
||||
|
||||
log: logging.Logger = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
|
||||
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
|
||||
@@ -84,12 +84,12 @@ def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
|
||||
if not tasks:
|
||||
return
|
||||
|
||||
log.info('Cleaning up after %d tasks.', len(tasks))
|
||||
_log.info('Cleaning up after %d tasks.', len(tasks))
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
|
||||
log.info('All tasks finished cancelling.')
|
||||
_log.info('All tasks finished cancelling.')
|
||||
|
||||
for task in tasks:
|
||||
if task.cancelled():
|
||||
@@ -106,7 +106,7 @@ def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||
_cancel_tasks(loop)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
finally:
|
||||
log.info('Closing the event loop.')
|
||||
_log.info('Closing the event loop.')
|
||||
loop.close()
|
||||
|
||||
class Client:
|
||||
@@ -142,7 +142,6 @@ class Client:
|
||||
intents: :class:`Intents`
|
||||
The intents that you want to enable for the session. This is a way of
|
||||
disabling and enabling certain gateway events from triggering and being sent.
|
||||
If not given, defaults to a regularly constructed :class:`Intents` class.
|
||||
|
||||
.. versionadded:: 1.5
|
||||
member_cache_flags: :class:`MemberCacheFlags`
|
||||
@@ -184,6 +183,14 @@ class Client:
|
||||
sync your system clock to Google's NTP server.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
enable_debug_events: :class:`bool`
|
||||
Whether to enable events that are useful only for debugging gateway related information.
|
||||
|
||||
Right now this involves :func:`on_socket_raw_receive` and :func:`on_socket_raw_send`. If
|
||||
this is ``False`` then those events will not be dispatched (due to performance considerations).
|
||||
To enable these events, this must be set to ``True``. Defaults to ``False``.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
@@ -195,9 +202,13 @@ class Client:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
intents: Intents,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
**options: Any,
|
||||
):
|
||||
options["intents"] = intents
|
||||
|
||||
# self.ws is set in the connect method
|
||||
self.ws: DiscordWebSocket = None # type: ignore
|
||||
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
|
||||
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
|
||||
@@ -218,6 +229,7 @@ class Client:
|
||||
'before_identify': self._call_before_identify_hook
|
||||
}
|
||||
|
||||
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
|
||||
self._connection: ConnectionState = self._get_state(**options)
|
||||
self._connection.shard_count = self.shard_count
|
||||
self._closed: bool = False
|
||||
@@ -227,7 +239,7 @@ class Client:
|
||||
|
||||
if VoiceClient.warn_nacl:
|
||||
VoiceClient.warn_nacl = False
|
||||
log.warning("PyNaCl is not installed, voice will NOT be supported")
|
||||
_log.warning("PyNaCl is not installed, voice will NOT be supported")
|
||||
|
||||
# internals
|
||||
|
||||
@@ -277,6 +289,14 @@ class Client:
|
||||
"""List[:class:`.Emoji`]: The emojis that the connected client has."""
|
||||
return self._connection.emojis
|
||||
|
||||
@property
|
||||
def stickers(self) -> List[GuildSticker]:
|
||||
"""List[:class:`.GuildSticker`]: The stickers that the connected client has.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self._connection.stickers
|
||||
|
||||
@property
|
||||
def cached_messages(self) -> Sequence[Message]:
|
||||
"""Sequence[:class:`.Message`]: Read-only list of messages the connected client has cached.
|
||||
@@ -311,6 +331,8 @@ class Client:
|
||||
If this is not passed via ``__init__`` then this is retrieved
|
||||
through the gateway when an event contains the data. Usually
|
||||
after :func:`~discord.on_connect` is called.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self._connection.application_id
|
||||
|
||||
@@ -318,7 +340,7 @@ class Client:
|
||||
def application_flags(self) -> ApplicationFlags:
|
||||
""":class:`~discord.ApplicationFlags`: The client's application flags.
|
||||
|
||||
.. versionadded: 2.0
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self._connection.application_flags # type: ignore
|
||||
|
||||
@@ -343,7 +365,7 @@ class Client:
|
||||
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
|
||||
|
||||
def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
|
||||
log.debug('Dispatching event %s', event)
|
||||
_log.debug('Dispatching event %s', event)
|
||||
method = 'on_' + event
|
||||
|
||||
listeners = self._listeners.get(event)
|
||||
@@ -448,8 +470,10 @@ class Client:
|
||||
passing status code.
|
||||
"""
|
||||
|
||||
log.info('logging in using static token')
|
||||
await self.http.static_login(token.strip())
|
||||
_log.info('logging in using static token')
|
||||
|
||||
data = await self.http.static_login(token.strip())
|
||||
self._connection.user = ClientUser(state=self._connection, data=data)
|
||||
|
||||
async def connect(self, *, reconnect: bool = True) -> None:
|
||||
"""|coro|
|
||||
@@ -489,7 +513,7 @@ class Client:
|
||||
while True:
|
||||
await self.ws.poll_event()
|
||||
except ReconnectWebSocket as e:
|
||||
log.info('Got a request to %s the websocket.', e.op)
|
||||
_log.info('Got a request to %s the websocket.', e.op)
|
||||
self.dispatch('disconnect')
|
||||
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
|
||||
continue
|
||||
@@ -528,7 +552,7 @@ class Client:
|
||||
raise
|
||||
|
||||
retry = backoff.delay()
|
||||
log.exception("Attempting a reconnect in %.2fs", retry)
|
||||
_log.exception("Attempting a reconnect in %.2fs", retry)
|
||||
await asyncio.sleep(retry)
|
||||
# Always try to RESUME the connection
|
||||
# If the connection is not RESUME-able then the gateway will invalidate the session.
|
||||
@@ -630,10 +654,10 @@ class Client:
|
||||
try:
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
log.info('Received signal to terminate bot and event loop.')
|
||||
_log.info('Received signal to terminate bot and event loop.')
|
||||
finally:
|
||||
future.remove_done_callback(stop_loop_on_completion)
|
||||
log.info('Cleaning up tasks.')
|
||||
_log.info('Cleaning up tasks.')
|
||||
_cleanup_loop(loop)
|
||||
|
||||
if not future.cancelled():
|
||||
@@ -661,9 +685,30 @@ class Client:
|
||||
if value is None:
|
||||
self._connection._activity = None
|
||||
elif isinstance(value, BaseActivity):
|
||||
self._connection._activity = value.to_dict()
|
||||
# ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any]
|
||||
self._connection._activity = value.to_dict() # type: ignore
|
||||
else:
|
||||
raise TypeError('activity must derive from BaseActivity.')
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
""":class:`.Status`:
|
||||
The status being used upon logging on to Discord.
|
||||
|
||||
.. versionadded: 2.0
|
||||
"""
|
||||
if self._connection._status in set(state.value for state in Status):
|
||||
return Status(self._connection._status)
|
||||
return Status.online
|
||||
|
||||
@status.setter
|
||||
def status(self, value):
|
||||
if value is Status.offline:
|
||||
self._connection._status = 'invisible'
|
||||
elif isinstance(value, Status):
|
||||
self._connection._status = str(value)
|
||||
else:
|
||||
raise TypeError('status must derive from Status.')
|
||||
|
||||
@property
|
||||
def allowed_mentions(self) -> Optional[AllowedMentions]:
|
||||
@@ -695,8 +740,8 @@ class Client:
|
||||
"""List[:class:`~discord.User`]: Returns a list of all the users the bot can see."""
|
||||
return list(self._connection._users.values())
|
||||
|
||||
def get_channel(self, id: int) -> Optional[Union[GuildChannel, PrivateChannel]]:
|
||||
"""Returns a channel with the given ID.
|
||||
def get_channel(self, id: int, /) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]:
|
||||
"""Returns a channel or thread with the given ID.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
@@ -705,12 +750,34 @@ class Client:
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`]]
|
||||
Optional[Union[:class:`.abc.GuildChannel`, :class:`.Thread`, :class:`.abc.PrivateChannel`]]
|
||||
The returned channel or ``None`` if not found.
|
||||
"""
|
||||
return self._connection.get_channel(id)
|
||||
|
||||
def get_stage_instance(self, id) -> Optional[StageInstance]:
|
||||
def get_partial_messageable(self, id: int, *, type: Optional[ChannelType] = None) -> PartialMessageable:
|
||||
"""Returns a partial messageable with the given channel ID.
|
||||
|
||||
This is useful if you have a channel_id but don't want to do an API call
|
||||
to send messages to it.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
id: :class:`int`
|
||||
The channel ID to create a partial messageable for.
|
||||
type: Optional[:class:`.ChannelType`]
|
||||
The underlying channel type for the partial messageable.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`.PartialMessageable`
|
||||
The partial messageable
|
||||
"""
|
||||
return PartialMessageable(state=self._connection, id=id, type=type)
|
||||
|
||||
def get_stage_instance(self, id: int, /) -> Optional[StageInstance]:
|
||||
"""Returns a stage instance with the given stage channel ID.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
@@ -732,7 +799,7 @@ class Client:
|
||||
if isinstance(channel, StageChannel):
|
||||
return channel.instance
|
||||
|
||||
def get_guild(self, id) -> Optional[Guild]:
|
||||
def get_guild(self, id: int, /) -> Optional[Guild]:
|
||||
"""Returns a guild with the given ID.
|
||||
|
||||
Parameters
|
||||
@@ -747,7 +814,7 @@ class Client:
|
||||
"""
|
||||
return self._connection._get_guild(id)
|
||||
|
||||
def get_user(self, id) -> Optional[User]:
|
||||
def get_user(self, id: int, /) -> Optional[User]:
|
||||
"""Returns a user with the given ID.
|
||||
|
||||
Parameters
|
||||
@@ -762,7 +829,7 @@ class Client:
|
||||
"""
|
||||
return self._connection.get_user(id)
|
||||
|
||||
def get_emoji(self, id) -> Optional[Emoji]:
|
||||
def get_emoji(self, id: int, /) -> Optional[Emoji]:
|
||||
"""Returns an emoji with the given ID.
|
||||
|
||||
Parameters
|
||||
@@ -777,6 +844,23 @@ class Client:
|
||||
"""
|
||||
return self._connection.get_emoji(id)
|
||||
|
||||
def get_sticker(self, id: int, /) -> Optional[GuildSticker]:
|
||||
"""Returns a guild sticker with the given ID.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. note::
|
||||
|
||||
To retrieve standard stickers, use :meth:`.fetch_sticker`.
|
||||
or :meth:`.fetch_premium_sticker_packs`.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`.GuildSticker`]
|
||||
The sticker or ``None`` if not found.
|
||||
"""
|
||||
return self._connection.get_sticker(id)
|
||||
|
||||
def get_all_channels(self) -> Generator[GuildChannel, None, None]:
|
||||
"""A generator that retrieves every :class:`.abc.GuildChannel` the client can 'access'.
|
||||
|
||||
@@ -959,7 +1043,7 @@ class Client:
|
||||
raise TypeError('event registered must be a coroutine function')
|
||||
|
||||
setattr(self, coro.__name__, coro)
|
||||
log.debug('%s has successfully been registered as an event', coro.__name__)
|
||||
_log.debug('%s has successfully been registered as an event', coro.__name__)
|
||||
return coro
|
||||
|
||||
async def change_presence(
|
||||
@@ -1109,7 +1193,7 @@ class Client:
|
||||
data = await self.http.get_template(code)
|
||||
return Template(data=data, state=self._connection) # type: ignore
|
||||
|
||||
async def fetch_guild(self, guild_id: int) -> Guild:
|
||||
async def fetch_guild(self, guild_id: int, /) -> Guild:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a :class:`.Guild` from an ID.
|
||||
@@ -1198,7 +1282,7 @@ class Client:
|
||||
data = await self.http.create_guild(name, region_value, icon_base64)
|
||||
return Guild(data=data, state=self._connection)
|
||||
|
||||
async def fetch_stage_instance(self, channel_id: int) -> StageInstance:
|
||||
async def fetch_stage_instance(self, channel_id: int, /) -> StageInstance:
|
||||
"""|coro|
|
||||
|
||||
Gets a :class:`.StageInstance` for a stage channel id.
|
||||
@@ -1298,7 +1382,7 @@ class Client:
|
||||
|
||||
# Miscellaneous stuff
|
||||
|
||||
async def fetch_widget(self, guild_id: int) -> Widget:
|
||||
async def fetch_widget(self, guild_id: int, /) -> Widget:
|
||||
"""|coro|
|
||||
|
||||
Gets a :class:`.Widget` from a guild ID.
|
||||
@@ -1348,7 +1432,7 @@ class Client:
|
||||
data['rpc_origins'] = None
|
||||
return AppInfo(self._connection, data)
|
||||
|
||||
async def fetch_user(self, user_id: int) -> User:
|
||||
async def fetch_user(self, user_id: int, /) -> User:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a :class:`~discord.User` based on their ID.
|
||||
@@ -1379,7 +1463,7 @@ class Client:
|
||||
data = await self.http.get_user(user_id)
|
||||
return User(state=self._connection, data=data)
|
||||
|
||||
async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel, Thread]:
|
||||
async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, PrivateChannel, Thread]:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID.
|
||||
@@ -1413,15 +1497,18 @@ class Client:
|
||||
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))
|
||||
|
||||
if ch_type in (ChannelType.group, ChannelType.private):
|
||||
channel = factory(me=self.user, data=data, state=self._connection)
|
||||
# the factory will be a DMChannel or GroupChannel here
|
||||
channel = factory(me=self.user, data=data, state=self._connection) # type: ignore
|
||||
else:
|
||||
guild_id = int(data['guild_id'])
|
||||
# the factory can't be a DMChannel or GroupChannel here
|
||||
guild_id = int(data['guild_id']) # type: ignore
|
||||
guild = self.get_guild(guild_id) or Object(id=guild_id)
|
||||
channel = factory(guild=guild, state=self._connection, data=data)
|
||||
# GuildChannels expect a Guild, we may be passing an Object
|
||||
channel = factory(guild=guild, state=self._connection, data=data) # type: ignore
|
||||
|
||||
return channel
|
||||
|
||||
async def fetch_webhook(self, webhook_id: int) -> Webhook:
|
||||
async def fetch_webhook(self, webhook_id: int, /) -> Webhook:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a :class:`.Webhook` with the specified ID.
|
||||
@@ -1443,6 +1530,49 @@ class Client:
|
||||
data = await self.http.get_webhook(webhook_id)
|
||||
return Webhook.from_state(data, state=self._connection)
|
||||
|
||||
async def fetch_sticker(self, sticker_id: int, /) -> Union[StandardSticker, GuildSticker]:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a :class:`.Sticker` with the specified ID.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Raises
|
||||
--------
|
||||
:exc:`.HTTPException`
|
||||
Retrieving the sticker failed.
|
||||
:exc:`.NotFound`
|
||||
Invalid sticker ID.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Union[:class:`.StandardSticker`, :class:`.GuildSticker`]
|
||||
The sticker you requested.
|
||||
"""
|
||||
data = await self.http.get_sticker(sticker_id)
|
||||
cls, _ = _sticker_factory(data['type']) # type: ignore
|
||||
return cls(state=self._connection, data=data) # type: ignore
|
||||
|
||||
async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
|
||||
"""|coro|
|
||||
|
||||
Retrieves all available premium sticker packs.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Raises
|
||||
-------
|
||||
:exc:`.HTTPException`
|
||||
Retrieving the sticker packs failed.
|
||||
|
||||
Returns
|
||||
---------
|
||||
List[:class:`.StickerPack`]
|
||||
All available premium sticker packs.
|
||||
"""
|
||||
data = await self.http.list_premium_sticker_packs()
|
||||
return [StickerPack(state=self._connection, data=pack) for pack in data['sticker_packs']]
|
||||
|
||||
async def create_dm(self, user: Snowflake) -> DMChannel:
|
||||
"""|coro|
|
||||
|
||||
@@ -1476,6 +1606,8 @@ class Client:
|
||||
|
||||
This method should be used for when a view is comprised of components
|
||||
that last longer than the lifecycle of the program.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Parameters
|
||||
------------
|
||||
@@ -1505,5 +1637,8 @@ class Client:
|
||||
|
||||
@property
|
||||
def persistent_views(self) -> Sequence[View]:
|
||||
"""Sequence[:class:`.View`]: A sequence of persistent views added to the client."""
|
||||
"""Sequence[:class:`.View`]: A sequence of persistent views added to the client.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self._connection.persistent_views
|
||||
|
||||
@@ -78,7 +78,7 @@ class Colour:
|
||||
|
||||
__slots__ = ('value',)
|
||||
|
||||
def __init__(self, value):
|
||||
def __init__(self, value: int):
|
||||
if not isinstance(value, int):
|
||||
raise TypeError(f'Expected int parameter, received {value.__class__.__name__} instead.')
|
||||
|
||||
@@ -171,6 +171,14 @@ class Colour:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x11806a``."""
|
||||
return cls(0x11806a)
|
||||
|
||||
@classmethod
|
||||
def brand_green(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x57F287``.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return cls(0x57F287)
|
||||
|
||||
@classmethod
|
||||
def green(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``."""
|
||||
@@ -231,6 +239,14 @@ class Colour:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0xa84300``."""
|
||||
return cls(0xa84300)
|
||||
|
||||
@classmethod
|
||||
def brand_red(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0xED4245``.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return cls(0xED4245)
|
||||
|
||||
@classmethod
|
||||
def red(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``."""
|
||||
@@ -308,6 +324,15 @@ class Colour:
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return cls(0xFEE75C)
|
||||
|
||||
@classmethod
|
||||
def dark_blurple(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x4E5D94``.
|
||||
This is the original Dark Blurple branding.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return cls(0x4E5D94)
|
||||
|
||||
|
||||
Color = Colour
|
||||
|
||||
@@ -226,6 +226,8 @@ class SelectMenu(Component):
|
||||
Defaults to 1 and must be between 1 and 25.
|
||||
options: List[:class:`SelectOption`]
|
||||
A list of options that can be selected in this menu.
|
||||
disabled: :class:`bool`
|
||||
Whether the select is disabled or not.
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = (
|
||||
@@ -234,6 +236,7 @@ class SelectMenu(Component):
|
||||
'min_values',
|
||||
'max_values',
|
||||
'options',
|
||||
'disabled',
|
||||
)
|
||||
|
||||
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
|
||||
@@ -245,6 +248,7 @@ class SelectMenu(Component):
|
||||
self.min_values: int = data.get('min_values', 1)
|
||||
self.max_values: int = data.get('max_values', 1)
|
||||
self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])]
|
||||
self.disabled: bool = data.get('disabled', False)
|
||||
|
||||
def to_dict(self) -> SelectMenuPayload:
|
||||
payload: SelectMenuPayload = {
|
||||
@@ -253,6 +257,7 @@ class SelectMenu(Component):
|
||||
'min_values': self.min_values,
|
||||
'max_values': self.max_values,
|
||||
'options': [op.to_dict() for op in self.options],
|
||||
'disabled': self.disabled,
|
||||
}
|
||||
|
||||
if self.placeholder:
|
||||
@@ -272,14 +277,14 @@ class SelectOption:
|
||||
-----------
|
||||
label: :class:`str`
|
||||
The label of the option. This is displayed to users.
|
||||
Can only be up to 25 characters.
|
||||
Can only be up to 100 characters.
|
||||
value: :class:`str`
|
||||
The value of the option. This is not displayed to users.
|
||||
If not provided when constructed then it defaults to the
|
||||
label. Can only be up to 100 characters.
|
||||
description: Optional[:class:`str`]
|
||||
An additional description of the option, if any.
|
||||
Can only be up to 50 characters.
|
||||
Can only be up to 100 characters.
|
||||
emoji: Optional[Union[:class:`str`, :class:`Emoji`, :class:`PartialEmoji`]]
|
||||
The emoji of the option, if available.
|
||||
default: :class:`bool`
|
||||
|
||||
@@ -22,13 +22,23 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, TypeVar, Optional, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .abc import Messageable
|
||||
|
||||
from types import TracebackType
|
||||
|
||||
TypingT = TypeVar('TypingT', bound='Typing')
|
||||
|
||||
__all__ = (
|
||||
'Typing',
|
||||
)
|
||||
|
||||
def _typing_done_callback(fut):
|
||||
def _typing_done_callback(fut: asyncio.Future) -> None:
|
||||
# just retrieve any exception and call it a day
|
||||
try:
|
||||
fut.exception()
|
||||
@@ -36,11 +46,11 @@ def _typing_done_callback(fut):
|
||||
pass
|
||||
|
||||
class Typing:
|
||||
def __init__(self, messageable):
|
||||
self.loop = messageable._state.loop
|
||||
self.messageable = messageable
|
||||
def __init__(self, messageable: Messageable) -> None:
|
||||
self.loop: asyncio.AbstractEventLoop = messageable._state.loop
|
||||
self.messageable: Messageable = messageable
|
||||
|
||||
async def do_typing(self):
|
||||
async def do_typing(self) -> None:
|
||||
try:
|
||||
channel = self._channel
|
||||
except AttributeError:
|
||||
@@ -52,18 +62,26 @@ class Typing:
|
||||
await typing(channel.id)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
def __enter__(self):
|
||||
self.task = asyncio.ensure_future(self.do_typing(), loop=self.loop)
|
||||
def __enter__(self: TypingT) -> TypingT:
|
||||
self.task: asyncio.Task = self.loop.create_task(self.do_typing())
|
||||
self.task.add_done_callback(_typing_done_callback)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
def __exit__(self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.task.cancel()
|
||||
|
||||
async def __aenter__(self):
|
||||
async def __aenter__(self: TypingT) -> TypingT:
|
||||
self._channel = channel = await self.messageable._get_channel()
|
||||
await channel._state.http.send_typing(channel.id)
|
||||
return self.__enter__()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
async def __aexit__(self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.task.cancel()
|
||||
|
||||
@@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import Any, Dict, Final, List, Protocol, TYPE_CHECKING, Type, TypeVar, Union
|
||||
from typing import Any, Dict, Final, List, Mapping, Protocol, TYPE_CHECKING, Type, TypeVar, Union
|
||||
|
||||
from . import utils
|
||||
from .colour import Colour
|
||||
@@ -72,30 +72,36 @@ if TYPE_CHECKING:
|
||||
T = TypeVar('T')
|
||||
MaybeEmpty = Union[T, _EmptyEmbed]
|
||||
|
||||
|
||||
class _EmbedFooterProxy(Protocol):
|
||||
text: MaybeEmpty[str]
|
||||
icon_url: MaybeEmpty[str]
|
||||
|
||||
|
||||
class _EmbedFieldProxy(Protocol):
|
||||
name: MaybeEmpty[str]
|
||||
value: MaybeEmpty[str]
|
||||
inline: bool
|
||||
|
||||
|
||||
class _EmbedMediaProxy(Protocol):
|
||||
url: MaybeEmpty[str]
|
||||
proxy_url: MaybeEmpty[str]
|
||||
height: MaybeEmpty[int]
|
||||
width: MaybeEmpty[int]
|
||||
|
||||
|
||||
class _EmbedVideoProxy(Protocol):
|
||||
url: MaybeEmpty[str]
|
||||
height: MaybeEmpty[int]
|
||||
width: MaybeEmpty[int]
|
||||
|
||||
|
||||
class _EmbedProviderProxy(Protocol):
|
||||
name: MaybeEmpty[str]
|
||||
url: MaybeEmpty[str]
|
||||
|
||||
|
||||
class _EmbedAuthorProxy(Protocol):
|
||||
name: MaybeEmpty[str]
|
||||
url: MaybeEmpty[str]
|
||||
@@ -175,15 +181,15 @@ class Embed:
|
||||
Empty: Final = EmptyEmbed
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
title: MaybeEmpty[Any] = EmptyEmbed,
|
||||
type: EmbedType = 'rich',
|
||||
url: MaybeEmpty[Any] = EmptyEmbed,
|
||||
description: MaybeEmpty[Any] = EmptyEmbed,
|
||||
timestamp: datetime.datetime = None,
|
||||
self,
|
||||
*,
|
||||
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
title: MaybeEmpty[Any] = EmptyEmbed,
|
||||
type: EmbedType = 'rich',
|
||||
url: MaybeEmpty[Any] = EmptyEmbed,
|
||||
description: MaybeEmpty[Any] = EmptyEmbed,
|
||||
timestamp: datetime.datetime = None,
|
||||
):
|
||||
|
||||
self.colour = colour if colour is not EmptyEmbed else color
|
||||
@@ -205,7 +211,7 @@ class Embed:
|
||||
self.timestamp = timestamp
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[E], data: EmbedData) -> E:
|
||||
def from_dict(cls: Type[E], data: Mapping[str, Any]) -> E:
|
||||
"""Converts a :class:`dict` to a :class:`Embed` provided it is in the
|
||||
format that Discord expects it to be in.
|
||||
|
||||
@@ -366,7 +372,7 @@ class Embed:
|
||||
self._footer['icon_url'] = str(icon_url)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def remove_footer(self: E) -> E:
|
||||
"""Clears embed's footer information.
|
||||
|
||||
@@ -381,7 +387,7 @@ class Embed:
|
||||
pass
|
||||
|
||||
return self
|
||||
|
||||
|
||||
@property
|
||||
def image(self) -> _EmbedMediaProxy:
|
||||
"""Returns an ``EmbedProxy`` denoting the image contents.
|
||||
@@ -397,6 +403,19 @@ class Embed:
|
||||
"""
|
||||
return EmbedProxy(getattr(self, '_image', {})) # type: ignore
|
||||
|
||||
@image.setter
|
||||
def image(self: E, *, url: Any):
|
||||
self._image = {
|
||||
'url': str(url),
|
||||
}
|
||||
|
||||
@image.deleter
|
||||
def image(self: E):
|
||||
try:
|
||||
del self._image
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def set_image(self: E, *, url: MaybeEmpty[Any]) -> E:
|
||||
"""Sets the image for the embed content.
|
||||
|
||||
@@ -413,14 +432,9 @@ class Embed:
|
||||
"""
|
||||
|
||||
if url is EmptyEmbed:
|
||||
try:
|
||||
del self._image
|
||||
except AttributeError:
|
||||
pass
|
||||
del self.image
|
||||
else:
|
||||
self._image = {
|
||||
'url': str(url),
|
||||
}
|
||||
self.image = url
|
||||
|
||||
return self
|
||||
|
||||
@@ -439,7 +453,25 @@ class Embed:
|
||||
"""
|
||||
return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore
|
||||
|
||||
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]) -> E:
|
||||
@thumbnail.setter
|
||||
def thumbnail(self: E, *, url: Any):
|
||||
"""Sets the thumbnail for the embed content.
|
||||
"""
|
||||
|
||||
self._thumbnail = {
|
||||
'url': str(url),
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
@thumbnail.deleter
|
||||
def thumbnail(self):
|
||||
try:
|
||||
del self.thumbnail
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]):
|
||||
"""Sets the thumbnail for the embed content.
|
||||
|
||||
This function returns the class instance to allow for fluent-style
|
||||
@@ -453,16 +485,10 @@ class Embed:
|
||||
url: :class:`str`
|
||||
The source URL for the thumbnail. Only HTTP(S) is supported.
|
||||
"""
|
||||
|
||||
if url is EmptyEmbed:
|
||||
try:
|
||||
del self._thumbnail
|
||||
except AttributeError:
|
||||
pass
|
||||
del self.thumbnail
|
||||
else:
|
||||
self._thumbnail = {
|
||||
'url': str(url),
|
||||
}
|
||||
self.thumbnail = url
|
||||
|
||||
return self
|
||||
|
||||
|
||||
@@ -72,6 +72,10 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
|
||||
Returns the emoji rendered for discord.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the emoji ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@@ -137,6 +141,9 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
return f'<a:{self.name}:{self.id}>'
|
||||
return f'<:{self.name}:{self.id}>'
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>'
|
||||
|
||||
@@ -212,7 +219,7 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
|
||||
await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason)
|
||||
|
||||
async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> None:
|
||||
async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> Emoji:
|
||||
r"""|coro|
|
||||
|
||||
Edits the custom emoji.
|
||||
@@ -220,6 +227,9 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
You must have :attr:`~Permissions.manage_emojis` permission to
|
||||
do this.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The newly updated emoji is returned.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@@ -235,6 +245,11 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
You are not allowed to edit emojis.
|
||||
HTTPException
|
||||
An error occurred editing the emoji.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Emoji`
|
||||
The newly updated emoji.
|
||||
"""
|
||||
|
||||
payload = {}
|
||||
@@ -243,4 +258,5 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
if roles is not MISSING:
|
||||
payload['roles'] = [role.id for role in roles]
|
||||
|
||||
await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason)
|
||||
data = await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason)
|
||||
return Emoji(guild=self.guild, data=data, state=self._state)
|
||||
|
||||
@@ -46,6 +46,7 @@ __all__ = (
|
||||
'ExpireBehaviour',
|
||||
'ExpireBehavior',
|
||||
'StickerType',
|
||||
'StickerFormatType',
|
||||
'InviteTarget',
|
||||
'VideoQualityMode',
|
||||
'ComponentType',
|
||||
@@ -57,13 +58,17 @@ __all__ = (
|
||||
)
|
||||
|
||||
|
||||
def _create_value_cls(name):
|
||||
def _create_value_cls(name, comparable):
|
||||
cls = namedtuple('_EnumValue_' + name, 'name value')
|
||||
cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>'
|
||||
cls.__str__ = lambda self: f'{name}.{self.name}'
|
||||
if comparable:
|
||||
cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value
|
||||
cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value
|
||||
cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value
|
||||
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value
|
||||
return cls
|
||||
|
||||
|
||||
def _is_descriptor(obj):
|
||||
return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__')
|
||||
|
||||
@@ -75,12 +80,12 @@ class EnumMeta(type):
|
||||
_enum_member_map_: ClassVar[Dict[str, Any]]
|
||||
_enum_value_map_: ClassVar[Dict[Any, Any]]
|
||||
|
||||
def __new__(cls, name, bases, attrs):
|
||||
def __new__(cls, name, bases, attrs, *, comparable: bool = False):
|
||||
value_mapping = {}
|
||||
member_mapping = {}
|
||||
member_names = []
|
||||
|
||||
value_cls = _create_value_cls(name)
|
||||
value_cls = _create_value_cls(name, comparable)
|
||||
for key, value in list(attrs.items()):
|
||||
is_descriptor = _is_descriptor(value)
|
||||
if key[0] == '_' and not is_descriptor:
|
||||
@@ -251,7 +256,7 @@ class SpeakingState(Enum):
|
||||
return self.value
|
||||
|
||||
|
||||
class VerificationLevel(Enum):
|
||||
class VerificationLevel(Enum, comparable=True):
|
||||
none = 0
|
||||
low = 1
|
||||
medium = 2
|
||||
@@ -262,7 +267,7 @@ class VerificationLevel(Enum):
|
||||
return self.name
|
||||
|
||||
|
||||
class ContentFilter(Enum):
|
||||
class ContentFilter(Enum, comparable=True):
|
||||
disabled = 0
|
||||
no_role = 1
|
||||
all_members = 2
|
||||
@@ -295,7 +300,7 @@ class DefaultAvatar(Enum):
|
||||
return self.name
|
||||
|
||||
|
||||
class NotificationLevel(Enum):
|
||||
class NotificationLevel(Enum, comparable=True):
|
||||
all_messages = 0
|
||||
only_mentions = 1
|
||||
|
||||
@@ -346,6 +351,12 @@ class AuditLogAction(Enum):
|
||||
stage_instance_create = 83
|
||||
stage_instance_update = 84
|
||||
stage_instance_delete = 85
|
||||
sticker_create = 90
|
||||
sticker_update = 91
|
||||
sticker_delete = 92
|
||||
thread_create = 110
|
||||
thread_update = 111
|
||||
thread_delete = 112
|
||||
# fmt: on
|
||||
|
||||
@property
|
||||
@@ -390,6 +401,12 @@ class AuditLogAction(Enum):
|
||||
AuditLogAction.stage_instance_create: AuditLogActionCategory.create,
|
||||
AuditLogAction.stage_instance_update: AuditLogActionCategory.update,
|
||||
AuditLogAction.stage_instance_delete: AuditLogActionCategory.delete,
|
||||
AuditLogAction.sticker_create: AuditLogActionCategory.create,
|
||||
AuditLogAction.sticker_update: AuditLogActionCategory.update,
|
||||
AuditLogAction.sticker_delete: AuditLogActionCategory.delete,
|
||||
AuditLogAction.thread_create: AuditLogActionCategory.create,
|
||||
AuditLogAction.thread_update: AuditLogActionCategory.update,
|
||||
AuditLogAction.thread_delete: AuditLogActionCategory.delete,
|
||||
}
|
||||
# fmt: on
|
||||
return lookup[self]
|
||||
@@ -421,6 +438,10 @@ class AuditLogAction(Enum):
|
||||
return 'integration'
|
||||
elif v < 90:
|
||||
return 'stage_instance'
|
||||
elif v < 93:
|
||||
return 'sticker'
|
||||
elif v < 113:
|
||||
return 'thread'
|
||||
|
||||
|
||||
class UserFlags(Enum):
|
||||
@@ -476,10 +497,26 @@ ExpireBehavior = ExpireBehaviour
|
||||
|
||||
|
||||
class StickerType(Enum):
|
||||
standard = 1
|
||||
guild = 2
|
||||
|
||||
|
||||
class StickerFormatType(Enum):
|
||||
png = 1
|
||||
apng = 2
|
||||
lottie = 3
|
||||
|
||||
@property
|
||||
def file_extension(self) -> str:
|
||||
# fmt: off
|
||||
lookup: Dict[StickerFormatType, str] = {
|
||||
StickerFormatType.png: 'png',
|
||||
StickerFormatType.apng: 'png',
|
||||
StickerFormatType.lottie: 'json',
|
||||
}
|
||||
# fmt: on
|
||||
return lookup[self]
|
||||
|
||||
|
||||
class InviteTarget(Enum):
|
||||
unknown = 0
|
||||
@@ -545,7 +582,7 @@ class StagePrivacyLevel(Enum):
|
||||
guild_only = 2
|
||||
|
||||
|
||||
class NSFWLevel(Enum):
|
||||
class NSFWLevel(Enum, comparable=True):
|
||||
default = 0
|
||||
explicit = 1
|
||||
safe = 2
|
||||
|
||||
@@ -31,9 +31,9 @@ if TYPE_CHECKING:
|
||||
try:
|
||||
from requests import Response
|
||||
|
||||
ResponseType = Union[ClientResponse, Response]
|
||||
_ResponseType = Union[ClientResponse, Response]
|
||||
except ModuleNotFoundError:
|
||||
ResponseType = ClientResponse
|
||||
_ResponseType = ClientResponse
|
||||
|
||||
from .interactions import Interaction
|
||||
|
||||
@@ -123,8 +123,8 @@ class HTTPException(DiscordException):
|
||||
The Discord specific error code for the failure.
|
||||
"""
|
||||
|
||||
def __init__(self, response: ResponseType, message: Optional[Union[str, Dict[str, Any]]]):
|
||||
self.response: ResponseType = response
|
||||
def __init__(self, response: _ResponseType, message: Optional[Union[str, Dict[str, Any]]]):
|
||||
self.response: _ResponseType = response
|
||||
self.status: int = response.status # type: ignore
|
||||
self.code: int
|
||||
self.text: str
|
||||
|
||||
@@ -22,6 +22,26 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
|
||||
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
from .cog import Cog
|
||||
from .errors import CommandError
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
Coro = Coroutine[Any, Any, T]
|
||||
MaybeCoro = Union[T, Coro[T]]
|
||||
CoroFunc = Callable[..., Coro[Any]]
|
||||
|
||||
Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
|
||||
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
|
||||
Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]]
|
||||
|
||||
|
||||
# This is merely a tag type to avoid circular import issues.
|
||||
# Yes, this is a terrible solution but ultimately it is the only solution.
|
||||
class _BaseCommand:
|
||||
|
||||
@@ -22,13 +22,18 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import collections.abc
|
||||
import inspect
|
||||
import importlib.util
|
||||
import sys
|
||||
import traceback
|
||||
import types
|
||||
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
|
||||
|
||||
import discord
|
||||
|
||||
@@ -39,6 +44,15 @@ from . import errors
|
||||
from .help import HelpCommand, DefaultHelpCommand
|
||||
from .cog import Cog
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import importlib.machinery
|
||||
|
||||
from discord.message import Message
|
||||
from ._types import (
|
||||
Check,
|
||||
CoroFunc,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
'when_mentioned',
|
||||
'when_mentioned_or',
|
||||
@@ -46,14 +60,21 @@ __all__ = (
|
||||
'AutoShardedBot',
|
||||
)
|
||||
|
||||
def when_mentioned(bot, msg):
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
T = TypeVar('T')
|
||||
CFT = TypeVar('CFT', bound='CoroFunc')
|
||||
CXT = TypeVar('CXT', bound='Context')
|
||||
|
||||
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
|
||||
"""A callable that implements a command prefix equivalent to being mentioned.
|
||||
|
||||
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
|
||||
"""
|
||||
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> ']
|
||||
# bot.user will never be None when this is called
|
||||
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
|
||||
|
||||
def when_mentioned_or(*prefixes):
|
||||
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
|
||||
"""A callable that implements when mentioned or other prefixes provided.
|
||||
|
||||
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
|
||||
@@ -89,7 +110,7 @@ def when_mentioned_or(*prefixes):
|
||||
|
||||
return inner
|
||||
|
||||
def _is_submodule(parent, child):
|
||||
def _is_submodule(parent: str, child: str) -> bool:
|
||||
return parent == child or child.startswith(parent + ".")
|
||||
|
||||
class _DefaultRepr:
|
||||
@@ -99,13 +120,13 @@ class _DefaultRepr:
|
||||
_default = _DefaultRepr()
|
||||
|
||||
class BotBase(GroupMixin):
|
||||
def __init__(self, command_prefix, help_command=_default, description=None, **options):
|
||||
super().__init__(**options)
|
||||
def __init__(self, command_prefix, help_command=_default, description=None, *, intents: discord.Intents, **options):
|
||||
super().__init__(**options, intents=intents)
|
||||
self.command_prefix = command_prefix
|
||||
self.extra_events = {}
|
||||
self.__cogs = {}
|
||||
self.__extensions = {}
|
||||
self._checks = []
|
||||
self.extra_events: Dict[str, List[CoroFunc]] = {}
|
||||
self.__cogs: Dict[str, Cog] = {}
|
||||
self.__extensions: Dict[str, types.ModuleType] = {}
|
||||
self._checks: List[Check] = []
|
||||
self._check_once = []
|
||||
self._before_invoke = None
|
||||
self._after_invoke = None
|
||||
@@ -128,13 +149,15 @@ class BotBase(GroupMixin):
|
||||
|
||||
# internal helpers
|
||||
|
||||
def dispatch(self, event_name, *args, **kwargs):
|
||||
super().dispatch(event_name, *args, **kwargs)
|
||||
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
|
||||
# super() will resolve to Client
|
||||
super().dispatch(event_name, *args, **kwargs) # type: ignore
|
||||
ev = 'on_' + event_name
|
||||
for event in self.extra_events.get(ev, []):
|
||||
self._schedule_event(event, ev, *args, **kwargs)
|
||||
self._schedule_event(event, ev, *args, **kwargs) # type: ignore
|
||||
|
||||
async def close(self):
|
||||
@discord.utils.copy_doc(discord.Client.close)
|
||||
async def close(self) -> None:
|
||||
for extension in tuple(self.__extensions):
|
||||
try:
|
||||
self.unload_extension(extension)
|
||||
@@ -147,9 +170,9 @@ class BotBase(GroupMixin):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await super().close()
|
||||
await super().close() # type: ignore
|
||||
|
||||
async def on_command_error(self, context, exception):
|
||||
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
|
||||
"""|coro|
|
||||
|
||||
The default command error handler provided by the bot.
|
||||
@@ -175,7 +198,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
# global check registration
|
||||
|
||||
def check(self, func):
|
||||
def check(self, func: T) -> T:
|
||||
r"""A decorator that adds a global check to the bot.
|
||||
|
||||
A global check is similar to a :func:`.check` that is applied
|
||||
@@ -200,10 +223,11 @@ class BotBase(GroupMixin):
|
||||
return ctx.command.qualified_name in allowed_commands
|
||||
|
||||
"""
|
||||
self.add_check(func)
|
||||
# T was used instead of Check to ensure the type matches on return
|
||||
self.add_check(func) # type: ignore
|
||||
return func
|
||||
|
||||
def add_check(self, func, *, call_once=False):
|
||||
def add_check(self, func: Check, *, call_once: bool = False) -> None:
|
||||
"""Adds a global check to the bot.
|
||||
|
||||
This is the non-decorator interface to :meth:`.check`
|
||||
@@ -223,7 +247,7 @@ class BotBase(GroupMixin):
|
||||
else:
|
||||
self._checks.append(func)
|
||||
|
||||
def remove_check(self, func, *, call_once=False):
|
||||
def remove_check(self, func: Check, *, call_once: bool = False) -> None:
|
||||
"""Removes a global check from the bot.
|
||||
|
||||
This function is idempotent and will not raise an exception
|
||||
@@ -244,7 +268,7 @@ class BotBase(GroupMixin):
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def check_once(self, func):
|
||||
def check_once(self, func: CFT) -> CFT:
|
||||
r"""A decorator that adds a "call once" global check to the bot.
|
||||
|
||||
Unlike regular global checks, this one is called only once
|
||||
@@ -282,15 +306,16 @@ class BotBase(GroupMixin):
|
||||
self.add_check(func, call_once=True)
|
||||
return func
|
||||
|
||||
async def can_run(self, ctx, *, call_once=False):
|
||||
async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool:
|
||||
data = self._check_once if call_once else self._checks
|
||||
|
||||
if len(data) == 0:
|
||||
return True
|
||||
|
||||
return await discord.utils.async_all(f(ctx) for f in data)
|
||||
# type-checker doesn't distinguish between functions and methods
|
||||
return await discord.utils.async_all(f(ctx) for f in data) # type: ignore
|
||||
|
||||
async def is_owner(self, user):
|
||||
async def is_owner(self, user: discord.User) -> bool:
|
||||
"""|coro|
|
||||
|
||||
Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of
|
||||
@@ -319,7 +344,8 @@ class BotBase(GroupMixin):
|
||||
elif self.owner_ids:
|
||||
return user.id in self.owner_ids
|
||||
else:
|
||||
app = await self.application_info()
|
||||
|
||||
app = await self.application_info() # type: ignore
|
||||
if app.team:
|
||||
self.owner_ids = ids = {m.id for m in app.team.members}
|
||||
return user.id in ids
|
||||
@@ -327,7 +353,7 @@ class BotBase(GroupMixin):
|
||||
self.owner_id = owner_id = app.owner.id
|
||||
return user.id == owner_id
|
||||
|
||||
def before_invoke(self, coro):
|
||||
def before_invoke(self, coro: CFT) -> CFT:
|
||||
"""A decorator that registers a coroutine as a pre-invoke hook.
|
||||
|
||||
A pre-invoke hook is called directly before the command is
|
||||
@@ -359,7 +385,7 @@ class BotBase(GroupMixin):
|
||||
self._before_invoke = coro
|
||||
return coro
|
||||
|
||||
def after_invoke(self, coro):
|
||||
def after_invoke(self, coro: CFT) -> CFT:
|
||||
r"""A decorator that registers a coroutine as a post-invoke hook.
|
||||
|
||||
A post-invoke hook is called directly after the command is
|
||||
@@ -394,14 +420,14 @@ class BotBase(GroupMixin):
|
||||
|
||||
# listener registration
|
||||
|
||||
def add_listener(self, func, name=None):
|
||||
def add_listener(self, func: CoroFunc, name: str = MISSING) -> None:
|
||||
"""The non decorator alternative to :meth:`.listen`.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
func: :ref:`coroutine <coroutine>`
|
||||
The function to call.
|
||||
name: Optional[:class:`str`]
|
||||
name: :class:`str`
|
||||
The name of the event to listen for. Defaults to ``func.__name__``.
|
||||
|
||||
Example
|
||||
@@ -416,7 +442,7 @@ class BotBase(GroupMixin):
|
||||
bot.add_listener(my_message, 'on_message')
|
||||
|
||||
"""
|
||||
name = func.__name__ if name is None else name
|
||||
name = func.__name__ if name is MISSING else name
|
||||
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError('Listeners must be coroutines')
|
||||
@@ -426,7 +452,7 @@ class BotBase(GroupMixin):
|
||||
else:
|
||||
self.extra_events[name] = [func]
|
||||
|
||||
def remove_listener(self, func, name=None):
|
||||
def remove_listener(self, func: CoroFunc, name: str = MISSING) -> None:
|
||||
"""Removes a listener from the pool of listeners.
|
||||
|
||||
Parameters
|
||||
@@ -438,7 +464,7 @@ class BotBase(GroupMixin):
|
||||
``func.__name__``.
|
||||
"""
|
||||
|
||||
name = func.__name__ if name is None else name
|
||||
name = func.__name__ if name is MISSING else name
|
||||
|
||||
if name in self.extra_events:
|
||||
try:
|
||||
@@ -446,7 +472,7 @@ class BotBase(GroupMixin):
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def listen(self, name=None):
|
||||
def listen(self, name: str = MISSING) -> Callable[[CFT], CFT]:
|
||||
"""A decorator that registers another function as an external
|
||||
event listener. Basically this allows you to listen to multiple
|
||||
events from different places e.g. such as :func:`.on_ready`
|
||||
@@ -476,7 +502,7 @@ class BotBase(GroupMixin):
|
||||
The function being listened to is not a coroutine.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
def decorator(func: CFT) -> CFT:
|
||||
self.add_listener(func, name)
|
||||
return func
|
||||
|
||||
@@ -528,7 +554,7 @@ class BotBase(GroupMixin):
|
||||
cog = cog._inject(self)
|
||||
self.__cogs[cog_name] = cog
|
||||
|
||||
def get_cog(self, name):
|
||||
def get_cog(self, name: str) -> Optional[Cog]:
|
||||
"""Gets the cog instance requested.
|
||||
|
||||
If the cog is not found, ``None`` is returned instead.
|
||||
@@ -547,8 +573,8 @@ class BotBase(GroupMixin):
|
||||
"""
|
||||
return self.__cogs.get(name)
|
||||
|
||||
def remove_cog(self, name):
|
||||
"""Removes a cog from the bot.
|
||||
def remove_cog(self, name: str) -> Optional[Cog]:
|
||||
"""Removes a cog from the bot and returns it.
|
||||
|
||||
All registered commands and event listeners that the
|
||||
cog has registered will be removed as well.
|
||||
@@ -559,6 +585,11 @@ class BotBase(GroupMixin):
|
||||
-----------
|
||||
name: :class:`str`
|
||||
The name of the cog to remove.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[:class:`.Cog`]
|
||||
The cog that was removed. ``None`` if not found.
|
||||
"""
|
||||
|
||||
cog = self.__cogs.pop(name, None)
|
||||
@@ -570,14 +601,16 @@ class BotBase(GroupMixin):
|
||||
help_command.cog = None
|
||||
cog._eject(self)
|
||||
|
||||
return cog
|
||||
|
||||
@property
|
||||
def cogs(self):
|
||||
def cogs(self) -> Mapping[str, Cog]:
|
||||
"""Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog."""
|
||||
return types.MappingProxyType(self.__cogs)
|
||||
|
||||
# extensions
|
||||
|
||||
def _remove_module_references(self, name):
|
||||
def _remove_module_references(self, name: str) -> None:
|
||||
# find all references to the module
|
||||
# remove the cogs registered from the module
|
||||
for cogname, cog in self.__cogs.copy().items():
|
||||
@@ -601,7 +634,7 @@ class BotBase(GroupMixin):
|
||||
for index in reversed(remove):
|
||||
del event_list[index]
|
||||
|
||||
def _call_module_finalizers(self, lib, key):
|
||||
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
|
||||
try:
|
||||
func = getattr(lib, 'teardown')
|
||||
except AttributeError:
|
||||
@@ -619,12 +652,12 @@ class BotBase(GroupMixin):
|
||||
if _is_submodule(name, module):
|
||||
del sys.modules[module]
|
||||
|
||||
def _load_from_module_spec(self, spec, key):
|
||||
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
|
||||
# precondition: key not in self.__extensions
|
||||
lib = importlib.util.module_from_spec(spec)
|
||||
sys.modules[key] = lib
|
||||
try:
|
||||
spec.loader.exec_module(lib)
|
||||
spec.loader.exec_module(lib) # type: ignore
|
||||
except Exception as e:
|
||||
del sys.modules[key]
|
||||
raise errors.ExtensionFailed(key, e) from e
|
||||
@@ -645,13 +678,13 @@ class BotBase(GroupMixin):
|
||||
else:
|
||||
self.__extensions[key] = lib
|
||||
|
||||
def _resolve_name(self, name, package):
|
||||
def _resolve_name(self, name: str, package: Optional[str]) -> str:
|
||||
try:
|
||||
return importlib.util.resolve_name(name, package)
|
||||
except ImportError:
|
||||
raise errors.ExtensionNotFound(name)
|
||||
|
||||
def load_extension(self, name, *, package=None):
|
||||
def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
|
||||
"""Loads an extension.
|
||||
|
||||
An extension is a python module that contains commands, cogs, or
|
||||
@@ -698,7 +731,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
self._load_from_module_spec(spec, name)
|
||||
|
||||
def unload_extension(self, name, *, package=None):
|
||||
def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
|
||||
"""Unloads an extension.
|
||||
|
||||
When the extension is unloaded, all commands, listeners, and cogs are
|
||||
@@ -739,7 +772,7 @@ class BotBase(GroupMixin):
|
||||
self._remove_module_references(lib.__name__)
|
||||
self._call_module_finalizers(lib, name)
|
||||
|
||||
def reload_extension(self, name, *, package=None):
|
||||
def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
|
||||
"""Atomically reloads an extension.
|
||||
|
||||
This replaces the extension with the same extension, only refreshed. This is
|
||||
@@ -795,7 +828,7 @@ class BotBase(GroupMixin):
|
||||
# if the load failed, the remnants should have been
|
||||
# cleaned from the load_extension function call
|
||||
# so let's load it from our old compiled library.
|
||||
lib.setup(self)
|
||||
lib.setup(self) # type: ignore
|
||||
self.__extensions[name] = lib
|
||||
|
||||
# revert sys.modules back to normal and raise back to caller
|
||||
@@ -803,18 +836,18 @@ class BotBase(GroupMixin):
|
||||
raise
|
||||
|
||||
@property
|
||||
def extensions(self):
|
||||
def extensions(self) -> Mapping[str, types.ModuleType]:
|
||||
"""Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension."""
|
||||
return types.MappingProxyType(self.__extensions)
|
||||
|
||||
# help command stuff
|
||||
|
||||
@property
|
||||
def help_command(self):
|
||||
def help_command(self) -> Optional[HelpCommand]:
|
||||
return self._help_command
|
||||
|
||||
@help_command.setter
|
||||
def help_command(self, value):
|
||||
def help_command(self, value: Optional[HelpCommand]) -> None:
|
||||
if value is not None:
|
||||
if not isinstance(value, HelpCommand):
|
||||
raise TypeError('help_command must be a subclass of HelpCommand')
|
||||
@@ -830,7 +863,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
# command processing
|
||||
|
||||
async def get_prefix(self, message):
|
||||
async def get_prefix(self, message: Message) -> Union[List[str], str]:
|
||||
"""|coro|
|
||||
|
||||
Retrieves the prefix the bot is listening to
|
||||
@@ -868,7 +901,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
return ret
|
||||
|
||||
async def get_context(self, message, *, cls=Context):
|
||||
async def get_context(self, message: Message, *, cls: Type[CXT] = Context) -> CXT:
|
||||
r"""|coro|
|
||||
|
||||
Returns the invocation context from the message.
|
||||
@@ -901,7 +934,7 @@ class BotBase(GroupMixin):
|
||||
view = StringView(message.content)
|
||||
ctx = cls(prefix=None, view=view, bot=self, message=message)
|
||||
|
||||
if message.author.id == self.user.id:
|
||||
if message.author.id == self.user.id: # type: ignore
|
||||
return ctx
|
||||
|
||||
prefix = await self.get_prefix(message)
|
||||
@@ -938,11 +971,12 @@ class BotBase(GroupMixin):
|
||||
|
||||
invoker = view.get_word()
|
||||
ctx.invoked_with = invoker
|
||||
ctx.prefix = invoked_prefix
|
||||
# type-checker fails to narrow invoked_prefix type.
|
||||
ctx.prefix = invoked_prefix # type: ignore
|
||||
ctx.command = self.all_commands.get(invoker)
|
||||
return ctx
|
||||
|
||||
async def invoke(self, ctx):
|
||||
async def invoke(self, ctx: Context) -> None:
|
||||
"""|coro|
|
||||
|
||||
Invokes the command given under the invocation context and
|
||||
@@ -968,7 +1002,7 @@ class BotBase(GroupMixin):
|
||||
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
|
||||
self.dispatch('command_error', ctx, exc)
|
||||
|
||||
async def process_commands(self, message):
|
||||
async def process_commands(self, message: Message) -> None:
|
||||
"""|coro|
|
||||
|
||||
This function processes the commands that have been registered
|
||||
|
||||
@@ -21,15 +21,30 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import discord.utils
|
||||
|
||||
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type
|
||||
|
||||
from ._types import _BaseCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .bot import BotBase
|
||||
from .context import Context
|
||||
from .core import Command
|
||||
|
||||
__all__ = (
|
||||
'CogMeta',
|
||||
'Cog',
|
||||
)
|
||||
|
||||
CogT = TypeVar('CogT', bound='Cog')
|
||||
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
|
||||
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
class CogMeta(type):
|
||||
"""A metaclass for defining a cog.
|
||||
|
||||
@@ -89,8 +104,12 @@ class CogMeta(type):
|
||||
async def bar(self, ctx):
|
||||
pass # hidden -> False
|
||||
"""
|
||||
__cog_name__: str
|
||||
__cog_settings__: Dict[str, Any]
|
||||
__cog_commands__: List[Command]
|
||||
__cog_listeners__: List[Tuple[str, str]]
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
|
||||
name, bases, attrs = args
|
||||
attrs['__cog_name__'] = kwargs.pop('name', name)
|
||||
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
|
||||
@@ -143,14 +162,14 @@ class CogMeta(type):
|
||||
new_cls.__cog_listeners__ = listeners_as_list
|
||||
return new_cls
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args)
|
||||
|
||||
@classmethod
|
||||
def qualified_name(cls):
|
||||
def qualified_name(cls) -> str:
|
||||
return cls.__cog_name__
|
||||
|
||||
def _cog_special_method(func):
|
||||
def _cog_special_method(func: FuncT) -> FuncT:
|
||||
func.__cog_special_method__ = None
|
||||
return func
|
||||
|
||||
@@ -164,8 +183,12 @@ class Cog(metaclass=CogMeta):
|
||||
When inheriting from this class, the options shown in :class:`CogMeta`
|
||||
are equally valid here.
|
||||
"""
|
||||
__cog_name__: ClassVar[str]
|
||||
__cog_settings__: ClassVar[Dict[str, Any]]
|
||||
__cog_commands__: ClassVar[List[Command]]
|
||||
__cog_listeners__: ClassVar[List[Tuple[str, str]]]
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
def __new__(cls: Type[CogT], *args: Any, **kwargs: Any) -> CogT:
|
||||
# For issue 426, we need to store a copy of the command objects
|
||||
# since we modify them to inject `self` to them.
|
||||
# To do this, we need to interfere with the Cog creation process.
|
||||
@@ -173,7 +196,8 @@ class Cog(metaclass=CogMeta):
|
||||
cmd_attrs = cls.__cog_settings__
|
||||
|
||||
# Either update the command with the cog provided defaults or copy it.
|
||||
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__)
|
||||
# r.e type ignore, type-checker complains about overriding a ClassVar
|
||||
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
|
||||
|
||||
lookup = {
|
||||
cmd.qualified_name: cmd
|
||||
@@ -186,15 +210,15 @@ class Cog(metaclass=CogMeta):
|
||||
parent = command.parent
|
||||
if parent is not None:
|
||||
# Get the latest parent reference
|
||||
parent = lookup[parent.qualified_name]
|
||||
parent = lookup[parent.qualified_name] # type: ignore
|
||||
|
||||
# Update our parent's reference to our self
|
||||
parent.remove_command(command.name)
|
||||
parent.add_command(command)
|
||||
parent.remove_command(command.name) # type: ignore
|
||||
parent.add_command(command) # type: ignore
|
||||
|
||||
return self
|
||||
|
||||
def get_commands(self):
|
||||
def get_commands(self) -> List[Command]:
|
||||
r"""
|
||||
Returns
|
||||
--------
|
||||
@@ -209,20 +233,20 @@ class Cog(metaclass=CogMeta):
|
||||
return [c for c in self.__cog_commands__ if c.parent is None]
|
||||
|
||||
@property
|
||||
def qualified_name(self):
|
||||
def qualified_name(self) -> str:
|
||||
""":class:`str`: Returns the cog's specified name, not the class name."""
|
||||
return self.__cog_name__
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
def description(self) -> str:
|
||||
""":class:`str`: Returns the cog's description, typically the cleaned docstring."""
|
||||
return self.__cog_description__
|
||||
|
||||
@description.setter
|
||||
def description(self, description):
|
||||
def description(self, description: str) -> None:
|
||||
self.__cog_description__ = description
|
||||
|
||||
def walk_commands(self):
|
||||
def walk_commands(self) -> Generator[Command, None, None]:
|
||||
"""An iterator that recursively walks through this cog's commands and subcommands.
|
||||
|
||||
Yields
|
||||
@@ -237,7 +261,7 @@ class Cog(metaclass=CogMeta):
|
||||
if isinstance(command, GroupMixin):
|
||||
yield from command.walk_commands()
|
||||
|
||||
def get_listeners(self):
|
||||
def get_listeners(self) -> List[Tuple[str, Callable[..., Any]]]:
|
||||
"""Returns a :class:`list` of (name, function) listener pairs that are defined in this cog.
|
||||
|
||||
Returns
|
||||
@@ -248,12 +272,12 @@ class Cog(metaclass=CogMeta):
|
||||
return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__]
|
||||
|
||||
@classmethod
|
||||
def _get_overridden_method(cls, method):
|
||||
def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]:
|
||||
"""Return None if the method is not overridden. Otherwise returns the overridden method."""
|
||||
return getattr(method.__func__, '__cog_special_method__', method)
|
||||
|
||||
@classmethod
|
||||
def listener(cls, name=None):
|
||||
def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]:
|
||||
"""A decorator that marks a function as a listener.
|
||||
|
||||
This is the cog equivalent of :meth:`.Bot.listen`.
|
||||
@@ -271,10 +295,10 @@ class Cog(metaclass=CogMeta):
|
||||
the name.
|
||||
"""
|
||||
|
||||
if name is not None and not isinstance(name, str):
|
||||
if name is not MISSING and not isinstance(name, str):
|
||||
raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.')
|
||||
|
||||
def decorator(func):
|
||||
def decorator(func: FuncT) -> FuncT:
|
||||
actual = func
|
||||
if isinstance(actual, staticmethod):
|
||||
actual = actual.__func__
|
||||
@@ -293,7 +317,7 @@ class Cog(metaclass=CogMeta):
|
||||
return func
|
||||
return decorator
|
||||
|
||||
def has_error_handler(self):
|
||||
def has_error_handler(self) -> bool:
|
||||
""":class:`bool`: Checks whether the cog has an error handler.
|
||||
|
||||
.. versionadded:: 1.7
|
||||
@@ -301,7 +325,7 @@ class Cog(metaclass=CogMeta):
|
||||
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
|
||||
|
||||
@_cog_special_method
|
||||
def cog_unload(self):
|
||||
def cog_unload(self) -> None:
|
||||
"""A special method that is called when the cog gets removed.
|
||||
|
||||
This function **cannot** be a coroutine. It must be a regular
|
||||
@@ -312,7 +336,7 @@ class Cog(metaclass=CogMeta):
|
||||
pass
|
||||
|
||||
@_cog_special_method
|
||||
def bot_check_once(self, ctx):
|
||||
def bot_check_once(self, ctx: Context) -> bool:
|
||||
"""A special method that registers as a :meth:`.Bot.check_once`
|
||||
check.
|
||||
|
||||
@@ -322,7 +346,7 @@ class Cog(metaclass=CogMeta):
|
||||
return True
|
||||
|
||||
@_cog_special_method
|
||||
def bot_check(self, ctx):
|
||||
def bot_check(self, ctx: Context) -> bool:
|
||||
"""A special method that registers as a :meth:`.Bot.check`
|
||||
check.
|
||||
|
||||
@@ -332,7 +356,7 @@ class Cog(metaclass=CogMeta):
|
||||
return True
|
||||
|
||||
@_cog_special_method
|
||||
def cog_check(self, ctx):
|
||||
def cog_check(self, ctx: Context) -> bool:
|
||||
"""A special method that registers as a :func:`~discord.ext.commands.check`
|
||||
for every command and subcommand in this cog.
|
||||
|
||||
@@ -342,7 +366,7 @@ class Cog(metaclass=CogMeta):
|
||||
return True
|
||||
|
||||
@_cog_special_method
|
||||
async def cog_command_error(self, ctx, error):
|
||||
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
|
||||
"""A special method that is called whenever an error
|
||||
is dispatched inside this cog.
|
||||
|
||||
@@ -361,7 +385,7 @@ class Cog(metaclass=CogMeta):
|
||||
pass
|
||||
|
||||
@_cog_special_method
|
||||
async def cog_before_invoke(self, ctx):
|
||||
async def cog_before_invoke(self, ctx: Context) -> None:
|
||||
"""A special method that acts as a cog local pre-invoke hook.
|
||||
|
||||
This is similar to :meth:`.Command.before_invoke`.
|
||||
@@ -376,7 +400,7 @@ class Cog(metaclass=CogMeta):
|
||||
pass
|
||||
|
||||
@_cog_special_method
|
||||
async def cog_after_invoke(self, ctx):
|
||||
async def cog_after_invoke(self, ctx: Context) -> None:
|
||||
"""A special method that acts as a cog local post-invoke hook.
|
||||
|
||||
This is similar to :meth:`.Command.after_invoke`.
|
||||
@@ -390,7 +414,7 @@ class Cog(metaclass=CogMeta):
|
||||
"""
|
||||
pass
|
||||
|
||||
def _inject(self, bot):
|
||||
def _inject(self: CogT, bot: BotBase) -> CogT:
|
||||
cls = self.__class__
|
||||
|
||||
# realistically, the only thing that can cause loading errors
|
||||
@@ -425,7 +449,7 @@ class Cog(metaclass=CogMeta):
|
||||
|
||||
return self
|
||||
|
||||
def _eject(self, bot):
|
||||
def _eject(self, bot: BotBase) -> None:
|
||||
cls = self.__class__
|
||||
|
||||
try:
|
||||
|
||||
@@ -21,16 +21,52 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import re
|
||||
|
||||
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
import discord.abc
|
||||
import discord.utils
|
||||
import re
|
||||
|
||||
from discord.message import Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from discord.abc import MessageableChannel
|
||||
from discord.guild import Guild
|
||||
from discord.member import Member
|
||||
from discord.state import ConnectionState
|
||||
from discord.user import ClientUser, User
|
||||
from discord.voice_client import VoiceProtocol
|
||||
|
||||
from .bot import Bot, AutoShardedBot
|
||||
from .cog import Cog
|
||||
from .core import Command
|
||||
from .help import HelpCommand
|
||||
from .view import StringView
|
||||
|
||||
__all__ = (
|
||||
'Context',
|
||||
)
|
||||
|
||||
class Context(discord.abc.Messageable):
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
|
||||
CogT = TypeVar('CogT', bound="Cog")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
P = ParamSpec('P')
|
||||
else:
|
||||
P = TypeVar('P')
|
||||
|
||||
|
||||
class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
r"""Represents the context in which a command is being invoked under.
|
||||
|
||||
This class contains a lot of meta data to help you understand more about
|
||||
@@ -58,11 +94,11 @@ class Context(discord.abc.Messageable):
|
||||
This is only of use for within converters.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
prefix: :class:`str`
|
||||
prefix: Optional[:class:`str`]
|
||||
The prefix that was used to invoke the command.
|
||||
command: :class:`Command`
|
||||
command: Optional[:class:`Command`]
|
||||
The command that is being invoked currently.
|
||||
invoked_with: :class:`str`
|
||||
invoked_with: Optional[:class:`str`]
|
||||
The command name that triggered this invocation. Useful for finding out
|
||||
which alias called the command.
|
||||
invoked_parents: List[:class:`str`]
|
||||
@@ -73,7 +109,7 @@ class Context(discord.abc.Messageable):
|
||||
|
||||
.. versionadded:: 1.7
|
||||
|
||||
invoked_subcommand: :class:`Command`
|
||||
invoked_subcommand: Optional[:class:`Command`]
|
||||
The subcommand that was invoked.
|
||||
If no valid subcommand was invoked then this is equal to ``None``.
|
||||
subcommand_passed: Optional[:class:`str`]
|
||||
@@ -86,23 +122,38 @@ class Context(discord.abc.Messageable):
|
||||
or invoked.
|
||||
"""
|
||||
|
||||
def __init__(self, **attrs):
|
||||
self.message = attrs.pop('message', None)
|
||||
self.bot = attrs.pop('bot', None)
|
||||
self.args = attrs.pop('args', [])
|
||||
self.kwargs = attrs.pop('kwargs', {})
|
||||
self.prefix = attrs.pop('prefix')
|
||||
self.command = attrs.pop('command', None)
|
||||
self.view = attrs.pop('view', None)
|
||||
self.invoked_with = attrs.pop('invoked_with', None)
|
||||
self.invoked_parents = attrs.pop('invoked_parents', [])
|
||||
self.invoked_subcommand = attrs.pop('invoked_subcommand', None)
|
||||
self.subcommand_passed = attrs.pop('subcommand_passed', None)
|
||||
self.command_failed = attrs.pop('command_failed', False)
|
||||
self.current_parameter = attrs.pop('current_parameter', None)
|
||||
self._state = self.message._state
|
||||
def __init__(self,
|
||||
*,
|
||||
message: Message,
|
||||
bot: BotT,
|
||||
view: StringView,
|
||||
args: List[Any] = MISSING,
|
||||
kwargs: Dict[str, Any] = MISSING,
|
||||
prefix: Optional[str] = None,
|
||||
command: Optional[Command] = None,
|
||||
invoked_with: Optional[str] = None,
|
||||
invoked_parents: List[str] = MISSING,
|
||||
invoked_subcommand: Optional[Command] = None,
|
||||
subcommand_passed: Optional[str] = None,
|
||||
command_failed: bool = False,
|
||||
current_parameter: Optional[inspect.Parameter] = None,
|
||||
):
|
||||
self.message: Message = message
|
||||
self.bot: BotT = bot
|
||||
self.args: List[Any] = args or []
|
||||
self.kwargs: Dict[str, Any] = kwargs or {}
|
||||
self.prefix: Optional[str] = prefix
|
||||
self.command: Optional[Command] = command
|
||||
self.view: StringView = view
|
||||
self.invoked_with: Optional[str] = invoked_with
|
||||
self.invoked_parents: List[str] = invoked_parents or []
|
||||
self.invoked_subcommand: Optional[Command] = invoked_subcommand
|
||||
self.subcommand_passed: Optional[str] = subcommand_passed
|
||||
self.command_failed: bool = command_failed
|
||||
self.current_parameter: Optional[inspect.Parameter] = current_parameter
|
||||
self._state: ConnectionState = self.message._state
|
||||
|
||||
async def invoke(self, command, /, *args, **kwargs):
|
||||
async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
r"""|coro|
|
||||
|
||||
Calls a command with the arguments given.
|
||||
@@ -124,7 +175,7 @@ class Context(discord.abc.Messageable):
|
||||
command: :class:`.Command`
|
||||
The command that is going to be called.
|
||||
\*args
|
||||
The arguments to to use.
|
||||
The arguments to use.
|
||||
\*\*kwargs
|
||||
The keyword arguments to use.
|
||||
|
||||
@@ -133,17 +184,9 @@ class Context(discord.abc.Messageable):
|
||||
TypeError
|
||||
The command argument to invoke is missing.
|
||||
"""
|
||||
arguments = []
|
||||
if command.cog is not None:
|
||||
arguments.append(command.cog)
|
||||
return await command(self, *args, **kwargs)
|
||||
|
||||
arguments.append(self)
|
||||
arguments.extend(args)
|
||||
|
||||
ret = await command.callback(*arguments, **kwargs)
|
||||
return ret
|
||||
|
||||
async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True):
|
||||
async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> None:
|
||||
"""|coro|
|
||||
|
||||
Calls the command again.
|
||||
@@ -187,7 +230,7 @@ class Context(discord.abc.Messageable):
|
||||
|
||||
if restart:
|
||||
to_call = cmd.root_parent or cmd
|
||||
view.index = len(self.prefix)
|
||||
view.index = len(self.prefix or '')
|
||||
view.previous = 0
|
||||
self.invoked_parents = []
|
||||
self.invoked_with = view.get_word() # advance to get the root command
|
||||
@@ -206,20 +249,23 @@ class Context(discord.abc.Messageable):
|
||||
self.subcommand_passed = subcommand_passed
|
||||
|
||||
@property
|
||||
def valid(self):
|
||||
def valid(self) -> bool:
|
||||
""":class:`bool`: Checks if the invocation context is valid to be invoked with."""
|
||||
return self.prefix is not None and self.command is not None
|
||||
|
||||
async def _get_channel(self):
|
||||
async def _get_channel(self) -> discord.abc.Messageable:
|
||||
return self.channel
|
||||
|
||||
@property
|
||||
def clean_prefix(self):
|
||||
def clean_prefix(self) -> str:
|
||||
""":class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
user = self.guild.me if self.guild else self.bot.user
|
||||
if self.prefix is None:
|
||||
return ''
|
||||
|
||||
user = self.me
|
||||
# this breaks if the prefix mention is not the bot itself but I
|
||||
# consider this to be an *incredibly* strange use case. I'd rather go
|
||||
# for this common use case rather than waste performance for the
|
||||
@@ -228,7 +274,7 @@ class Context(discord.abc.Messageable):
|
||||
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix)
|
||||
|
||||
@property
|
||||
def cog(self):
|
||||
def cog(self) -> Optional[Cog]:
|
||||
"""Optional[:class:`.Cog`]: Returns the cog associated with this context's command. None if it does not exist."""
|
||||
|
||||
if self.command is None:
|
||||
@@ -236,38 +282,39 @@ class Context(discord.abc.Messageable):
|
||||
return self.command.cog
|
||||
|
||||
@discord.utils.cached_property
|
||||
def guild(self):
|
||||
def guild(self) -> Optional[Guild]:
|
||||
"""Optional[:class:`.Guild`]: Returns the guild associated with this context's command. None if not available."""
|
||||
return self.message.guild
|
||||
|
||||
@discord.utils.cached_property
|
||||
def channel(self):
|
||||
def channel(self) -> MessageableChannel:
|
||||
"""Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command.
|
||||
Shorthand for :attr:`.Message.channel`.
|
||||
"""
|
||||
return self.message.channel
|
||||
|
||||
@discord.utils.cached_property
|
||||
def author(self):
|
||||
def author(self) -> Union[User, Member]:
|
||||
"""Union[:class:`~discord.User`, :class:`.Member`]:
|
||||
Returns the author associated with this context's command. Shorthand for :attr:`.Message.author`
|
||||
"""
|
||||
return self.message.author
|
||||
|
||||
@discord.utils.cached_property
|
||||
def me(self):
|
||||
def me(self) -> Union[Member, ClientUser]:
|
||||
"""Union[:class:`.Member`, :class:`.ClientUser`]:
|
||||
Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message contexts.
|
||||
"""
|
||||
return self.guild.me if self.guild is not None else self.bot.user
|
||||
# bot.user will never be None at this point.
|
||||
return self.guild.me if self.guild is not None else self.bot.user # type: ignore
|
||||
|
||||
@property
|
||||
def voice_client(self):
|
||||
def voice_client(self) -> Optional[VoiceProtocol]:
|
||||
r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable."""
|
||||
g = self.guild
|
||||
return g.voice_client if g else None
|
||||
|
||||
async def send_help(self, *args):
|
||||
async def send_help(self, *args: Any) -> Any:
|
||||
"""send_help(entity=<bot>)
|
||||
|
||||
|coro|
|
||||
@@ -319,12 +366,12 @@ class Context(discord.abc.Messageable):
|
||||
return None
|
||||
|
||||
entity = args[0]
|
||||
if entity is None:
|
||||
return None
|
||||
|
||||
if isinstance(entity, str):
|
||||
entity = bot.get_cog(entity) or bot.get_command(entity)
|
||||
|
||||
if entity is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
entity.qualified_name
|
||||
except AttributeError:
|
||||
@@ -348,6 +395,6 @@ class Context(discord.abc.Messageable):
|
||||
except CommandError as e:
|
||||
await cmd.on_help_command_error(self, e)
|
||||
|
||||
@discord.utils.copy_doc(discord.Message.reply)
|
||||
async def reply(self, content=None, **kwargs):
|
||||
@discord.utils.copy_doc(Message.reply)
|
||||
async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message:
|
||||
return await self.message.reply(content, **kwargs)
|
||||
|
||||
@@ -74,6 +74,7 @@ __all__ = (
|
||||
'StoreChannelConverter',
|
||||
'ThreadConverter',
|
||||
'GuildChannelConverter',
|
||||
'GuildStickerConverter',
|
||||
'clean_content',
|
||||
'Greedy',
|
||||
'run_converters',
|
||||
@@ -823,6 +824,45 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
|
||||
raise PartialEmojiConversionFailure(argument)
|
||||
|
||||
|
||||
class GuildStickerConverter(IDConverter[discord.GuildSticker]):
|
||||
"""Converts to a :class:`~discord.GuildSticker`.
|
||||
|
||||
All lookups are done for the local guild first, if available. If that lookup
|
||||
fails, then it checks the client's global cache.
|
||||
|
||||
The lookup strategy is as follows (in order):
|
||||
|
||||
1. Lookup by ID.
|
||||
3. Lookup by name
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.GuildSticker:
|
||||
match = self._get_id_match(argument)
|
||||
result = None
|
||||
bot = ctx.bot
|
||||
guild = ctx.guild
|
||||
|
||||
if match is None:
|
||||
# Try to get the sticker by name. Try local guild first.
|
||||
if guild:
|
||||
result = discord.utils.get(guild.stickers, name=argument)
|
||||
|
||||
if result is None:
|
||||
result = discord.utils.get(bot.stickers, name=argument)
|
||||
else:
|
||||
sticker_id = int(match.group(1))
|
||||
|
||||
# Try to look up sticker by id.
|
||||
result = bot.get_sticker(sticker_id)
|
||||
|
||||
if result is None:
|
||||
raise GuildStickerNotFound(argument)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class clean_content(Converter[str]):
|
||||
"""Converts the argument to mention scrubbed version of
|
||||
said content.
|
||||
@@ -1012,6 +1052,7 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = {
|
||||
discord.StoreChannel: StoreChannelConverter,
|
||||
discord.Thread: ThreadConverter,
|
||||
discord.abc.GuildChannel: GuildChannelConverter,
|
||||
discord.GuildSticker: GuildStickerConverter,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -22,6 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
from typing import Any, Callable, Deque, Dict, Optional, Type, TypeVar, TYPE_CHECKING
|
||||
from discord.enums import Enum
|
||||
import time
|
||||
import asyncio
|
||||
@@ -30,6 +34,9 @@ from collections import deque
|
||||
from ...abc import PrivateChannel
|
||||
from .errors import MaxConcurrencyReached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...message import Message
|
||||
|
||||
__all__ = (
|
||||
'BucketType',
|
||||
'Cooldown',
|
||||
@@ -38,6 +45,9 @@ __all__ = (
|
||||
'MaxConcurrency',
|
||||
)
|
||||
|
||||
C = TypeVar('C', bound='CooldownMapping')
|
||||
MC = TypeVar('MC', bound='MaxConcurrency')
|
||||
|
||||
class BucketType(Enum):
|
||||
default = 0
|
||||
user = 1
|
||||
@@ -47,7 +57,7 @@ class BucketType(Enum):
|
||||
category = 5
|
||||
role = 6
|
||||
|
||||
def get_key(self, msg):
|
||||
def get_key(self, msg: Message) -> Any:
|
||||
if self is BucketType.user:
|
||||
return msg.author.id
|
||||
elif self is BucketType.guild:
|
||||
@@ -57,29 +67,52 @@ class BucketType(Enum):
|
||||
elif self is BucketType.member:
|
||||
return ((msg.guild and msg.guild.id), msg.author.id)
|
||||
elif self is BucketType.category:
|
||||
return (msg.channel.category or msg.channel).id
|
||||
return (msg.channel.category or msg.channel).id # type: ignore
|
||||
elif self is BucketType.role:
|
||||
# we return the channel id of a private-channel as there are only roles in guilds
|
||||
# and that yields the same result as for a guild with only the @everyone role
|
||||
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
|
||||
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
|
||||
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id
|
||||
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore
|
||||
|
||||
def __call__(self, msg):
|
||||
def __call__(self, msg: Message) -> Any:
|
||||
return self.get_key(msg)
|
||||
|
||||
|
||||
class Cooldown:
|
||||
"""Represents a cooldown for a command.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
rate: :class:`int`
|
||||
The total number of tokens available per :attr:`per` seconds.
|
||||
per: :class:`float`
|
||||
The length of the cooldown period in seconds.
|
||||
"""
|
||||
|
||||
__slots__ = ('rate', 'per', '_window', '_tokens', '_last')
|
||||
|
||||
def __init__(self, rate, per):
|
||||
self.rate = int(rate)
|
||||
self.per = float(per)
|
||||
self._window = 0.0
|
||||
self._tokens = self.rate
|
||||
self._last = 0.0
|
||||
def __init__(self, rate: float, per: float) -> None:
|
||||
self.rate: int = int(rate)
|
||||
self.per: float = float(per)
|
||||
self._window: float = 0.0
|
||||
self._tokens: int = self.rate
|
||||
self._last: float = 0.0
|
||||
|
||||
def get_tokens(self, current=None):
|
||||
def get_tokens(self, current: Optional[float] = None) -> int:
|
||||
"""Returns the number of available tokens before rate limiting is applied.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
current: Optional[:class:`float`]
|
||||
The time in seconds since Unix epoch to calculate tokens at.
|
||||
If not supplied then :func:`time.time()` is used.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`int`
|
||||
The number of tokens available before the cooldown is to be applied.
|
||||
"""
|
||||
if not current:
|
||||
current = time.time()
|
||||
|
||||
@@ -89,7 +122,20 @@ class Cooldown:
|
||||
tokens = self.rate
|
||||
return tokens
|
||||
|
||||
def get_retry_after(self, current=None):
|
||||
def get_retry_after(self, current: Optional[float] = None) -> float:
|
||||
"""Returns the time in seconds until the cooldown will be reset.
|
||||
|
||||
Parameters
|
||||
-------------
|
||||
current: Optional[:class:`float`]
|
||||
The current time in seconds since Unix epoch.
|
||||
If not supplied, then :func:`time.time()` is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class:`float`
|
||||
The number of seconds to wait before this cooldown will be reset.
|
||||
"""
|
||||
current = current or time.time()
|
||||
tokens = self.get_tokens(current)
|
||||
|
||||
@@ -98,7 +144,20 @@ class Cooldown:
|
||||
|
||||
return 0.0
|
||||
|
||||
def update_rate_limit(self, current=None):
|
||||
def update_rate_limit(self, current: Optional[float] = None) -> Optional[float]:
|
||||
"""Updates the cooldown rate limit.
|
||||
|
||||
Parameters
|
||||
-------------
|
||||
current: Optional[:class:`float`]
|
||||
The time in seconds since Unix epoch to update the rate limit at.
|
||||
If not supplied, then :func:`time.time()` is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[:class:`float`]
|
||||
The retry-after time in seconds if rate limited.
|
||||
"""
|
||||
current = current or time.time()
|
||||
self._last = current
|
||||
|
||||
@@ -115,46 +174,58 @@ class Cooldown:
|
||||
# we're not so decrement our tokens
|
||||
self._tokens -= 1
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""Reset the cooldown to its initial state."""
|
||||
self._tokens = self.rate
|
||||
self._last = 0.0
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> Cooldown:
|
||||
"""Creates a copy of this cooldown.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Cooldown`
|
||||
A new instance of this cooldown.
|
||||
"""
|
||||
return Cooldown(self.rate, self.per)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
|
||||
|
||||
class CooldownMapping:
|
||||
def __init__(self, original, type):
|
||||
def __init__(
|
||||
self,
|
||||
original: Optional[Cooldown],
|
||||
type: Callable[[Message], Any],
|
||||
) -> None:
|
||||
if not callable(type):
|
||||
raise TypeError('Cooldown type must be a BucketType or callable')
|
||||
|
||||
self._cache = {}
|
||||
self._cooldown = original
|
||||
self._type = type
|
||||
self._cache: Dict[Any, Cooldown] = {}
|
||||
self._cooldown: Optional[Cooldown] = original
|
||||
self._type: Callable[[Message], Any] = type
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> CooldownMapping:
|
||||
ret = CooldownMapping(self._cooldown, self._type)
|
||||
ret._cache = self._cache.copy()
|
||||
return ret
|
||||
|
||||
@property
|
||||
def valid(self):
|
||||
def valid(self) -> bool:
|
||||
return self._cooldown is not None
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
def type(self) -> Callable[[Message], Any]:
|
||||
return self._type
|
||||
|
||||
@classmethod
|
||||
def from_cooldown(cls, rate, per, type):
|
||||
def from_cooldown(cls: Type[C], rate, per, type) -> C:
|
||||
return cls(Cooldown(rate, per), type)
|
||||
|
||||
def _bucket_key(self, msg):
|
||||
def _bucket_key(self, msg: Message) -> Any:
|
||||
return self._type(msg)
|
||||
|
||||
def _verify_cache_integrity(self, current=None):
|
||||
def _verify_cache_integrity(self, current: Optional[float] = None) -> None:
|
||||
# we want to delete all cache objects that haven't been used
|
||||
# in a cooldown window. e.g. if we have a command that has a
|
||||
# cooldown of 60s and it has not been used in 60s then that key should be deleted
|
||||
@@ -163,12 +234,12 @@ class CooldownMapping:
|
||||
for k in dead_keys:
|
||||
del self._cache[k]
|
||||
|
||||
def create_bucket(self, message):
|
||||
return self._cooldown.copy()
|
||||
def create_bucket(self, message: Message) -> Cooldown:
|
||||
return self._cooldown.copy() # type: ignore
|
||||
|
||||
def get_bucket(self, message, current=None):
|
||||
def get_bucket(self, message: Message, current: Optional[float] = None) -> Cooldown:
|
||||
if self._type is BucketType.default:
|
||||
return self._cooldown
|
||||
return self._cooldown # type: ignore
|
||||
|
||||
self._verify_cache_integrity(current)
|
||||
key = self._bucket_key(message)
|
||||
@@ -181,26 +252,30 @@ class CooldownMapping:
|
||||
|
||||
return bucket
|
||||
|
||||
def update_rate_limit(self, message, current=None):
|
||||
def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]:
|
||||
bucket = self.get_bucket(message, current)
|
||||
return bucket.update_rate_limit(current)
|
||||
|
||||
class DynamicCooldownMapping(CooldownMapping):
|
||||
|
||||
def __init__(self, factory, type):
|
||||
def __init__(
|
||||
self,
|
||||
factory: Callable[[Message], Cooldown],
|
||||
type: Callable[[Message], Any]
|
||||
) -> None:
|
||||
super().__init__(None, type)
|
||||
self._factory = factory
|
||||
self._factory: Callable[[Message], Cooldown] = factory
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> DynamicCooldownMapping:
|
||||
ret = DynamicCooldownMapping(self._factory, self._type)
|
||||
ret._cache = self._cache.copy()
|
||||
return ret
|
||||
|
||||
@property
|
||||
def valid(self):
|
||||
def valid(self) -> bool:
|
||||
return True
|
||||
|
||||
def create_bucket(self, message):
|
||||
def create_bucket(self, message: Message) -> Cooldown:
|
||||
return self._factory(message)
|
||||
|
||||
class _Semaphore:
|
||||
@@ -218,28 +293,28 @@ class _Semaphore:
|
||||
|
||||
__slots__ = ('value', 'loop', '_waiters')
|
||||
|
||||
def __init__(self, number):
|
||||
self.value = number
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self._waiters = deque()
|
||||
def __init__(self, number: int) -> None:
|
||||
self.value: int = number
|
||||
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||
self._waiters: Deque[asyncio.Future] = deque()
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>'
|
||||
|
||||
def locked(self):
|
||||
def locked(self) -> bool:
|
||||
return self.value == 0
|
||||
|
||||
def is_active(self):
|
||||
def is_active(self) -> bool:
|
||||
return len(self._waiters) > 0
|
||||
|
||||
def wake_up(self):
|
||||
def wake_up(self) -> None:
|
||||
while self._waiters:
|
||||
future = self._waiters.popleft()
|
||||
if not future.done():
|
||||
future.set_result(None)
|
||||
return
|
||||
|
||||
async def acquire(self, *, wait=False):
|
||||
async def acquire(self, *, wait: bool = False) -> bool:
|
||||
if not wait and self.value <= 0:
|
||||
# signal that we're not acquiring
|
||||
return False
|
||||
@@ -258,18 +333,18 @@ class _Semaphore:
|
||||
self.value -= 1
|
||||
return True
|
||||
|
||||
def release(self):
|
||||
def release(self) -> None:
|
||||
self.value += 1
|
||||
self.wake_up()
|
||||
|
||||
class MaxConcurrency:
|
||||
__slots__ = ('number', 'per', 'wait', '_mapping')
|
||||
|
||||
def __init__(self, number, *, per, wait):
|
||||
self._mapping = {}
|
||||
self.per = per
|
||||
self.number = number
|
||||
self.wait = wait
|
||||
def __init__(self, number: int, *, per: BucketType, wait: bool) -> None:
|
||||
self._mapping: Dict[Any, _Semaphore] = {}
|
||||
self.per: BucketType = per
|
||||
self.number: int = number
|
||||
self.wait: bool = wait
|
||||
|
||||
if number <= 0:
|
||||
raise ValueError('max_concurrency \'number\' cannot be less than 1')
|
||||
@@ -277,16 +352,16 @@ class MaxConcurrency:
|
||||
if not isinstance(per, BucketType):
|
||||
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}')
|
||||
|
||||
def copy(self):
|
||||
def copy(self: MC) -> MC:
|
||||
return self.__class__(self.number, per=self.per, wait=self.wait)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>'
|
||||
|
||||
def get_key(self, message):
|
||||
def get_key(self, message: Message) -> Any:
|
||||
return self.per.get_key(message)
|
||||
|
||||
async def acquire(self, message):
|
||||
async def acquire(self, message: Message) -> None:
|
||||
key = self.get_key(message)
|
||||
|
||||
try:
|
||||
@@ -298,7 +373,7 @@ class MaxConcurrency:
|
||||
if not acquired:
|
||||
raise MaxConcurrencyReached(self.number, self.per)
|
||||
|
||||
async def release(self, message):
|
||||
async def release(self, message: Message) -> None:
|
||||
# Technically there's no reason for this function to be async
|
||||
# But it might be more useful in the future
|
||||
key = self.get_key(message)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -22,8 +22,23 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Any, TYPE_CHECKING, List, Callable, Type, Tuple, Union
|
||||
|
||||
from discord.errors import ClientException, DiscordException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inspect import Parameter
|
||||
|
||||
from .converter import Converter
|
||||
from .context import Context
|
||||
from .cooldowns import Cooldown, BucketType
|
||||
from .flags import Flag
|
||||
from discord.abc import GuildChannel
|
||||
from discord.threads import Thread
|
||||
from discord.types.snowflake import Snowflake, SnowflakeList
|
||||
|
||||
|
||||
__all__ = (
|
||||
'CommandError',
|
||||
@@ -54,6 +69,7 @@ __all__ = (
|
||||
'RoleNotFound',
|
||||
'BadInviteArgument',
|
||||
'EmojiNotFound',
|
||||
'GuildStickerNotFound',
|
||||
'PartialEmojiConversionFailure',
|
||||
'BadBoolArgument',
|
||||
'MissingRole',
|
||||
@@ -93,7 +109,7 @@ class CommandError(DiscordException):
|
||||
in a special way as they are caught and passed into a special event
|
||||
from :class:`.Bot`\, :func:`.on_command_error`.
|
||||
"""
|
||||
def __init__(self, message=None, *args):
|
||||
def __init__(self, message: Optional[str] = None, *args: Any) -> None:
|
||||
if message is not None:
|
||||
# clean-up @everyone and @here mentions
|
||||
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
|
||||
@@ -114,9 +130,9 @@ class ConversionError(CommandError):
|
||||
The original exception that was raised. You can also get this via
|
||||
the ``__cause__`` attribute.
|
||||
"""
|
||||
def __init__(self, converter, original):
|
||||
self.converter = converter
|
||||
self.original = original
|
||||
def __init__(self, converter: Converter, original: Exception) -> None:
|
||||
self.converter: Converter = converter
|
||||
self.original: Exception = original
|
||||
|
||||
class UserInputError(CommandError):
|
||||
"""The base exception type for errors that involve errors
|
||||
@@ -148,8 +164,8 @@ class MissingRequiredArgument(UserInputError):
|
||||
param: :class:`inspect.Parameter`
|
||||
The argument that is missing.
|
||||
"""
|
||||
def __init__(self, param):
|
||||
self.param = param
|
||||
def __init__(self, param: Parameter) -> None:
|
||||
self.param: Parameter = param
|
||||
super().__init__(f'{param.name} is a required argument that is missing.')
|
||||
|
||||
class TooManyArguments(UserInputError):
|
||||
@@ -190,9 +206,9 @@ class CheckAnyFailure(CheckFailure):
|
||||
A list of check predicates that failed.
|
||||
"""
|
||||
|
||||
def __init__(self, checks, errors):
|
||||
self.checks = checks
|
||||
self.errors = errors
|
||||
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None:
|
||||
self.checks: List[CheckFailure] = checks
|
||||
self.errors: List[Callable[[Context], bool]] = errors
|
||||
super().__init__('You do not have permission to run this command.')
|
||||
|
||||
class PrivateMessageOnly(CheckFailure):
|
||||
@@ -201,7 +217,7 @@ class PrivateMessageOnly(CheckFailure):
|
||||
|
||||
This inherits from :exc:`CheckFailure`
|
||||
"""
|
||||
def __init__(self, message=None):
|
||||
def __init__(self, message: Optional[str] = None) -> None:
|
||||
super().__init__(message or 'This command can only be used in private messages.')
|
||||
|
||||
class NoPrivateMessage(CheckFailure):
|
||||
@@ -211,7 +227,7 @@ class NoPrivateMessage(CheckFailure):
|
||||
This inherits from :exc:`CheckFailure`
|
||||
"""
|
||||
|
||||
def __init__(self, message=None):
|
||||
def __init__(self, message: Optional[str] = None) -> None:
|
||||
super().__init__(message or 'This command cannot be used in private messages.')
|
||||
|
||||
class NotOwner(CheckFailure):
|
||||
@@ -234,8 +250,8 @@ class ObjectNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The argument supplied by the caller that was not matched
|
||||
"""
|
||||
def __init__(self, argument):
|
||||
self.argument = argument
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'{argument!r} does not follow a valid ID or mention format.')
|
||||
|
||||
class MemberNotFound(BadArgument):
|
||||
@@ -251,8 +267,8 @@ class MemberNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The member supplied by the caller that was not found
|
||||
"""
|
||||
def __init__(self, argument):
|
||||
self.argument = argument
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Member "{argument}" not found.')
|
||||
|
||||
class GuildNotFound(BadArgument):
|
||||
@@ -267,8 +283,8 @@ class GuildNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The guild supplied by the called that was not found
|
||||
"""
|
||||
def __init__(self, argument):
|
||||
self.argument = argument
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Guild "{argument}" not found.')
|
||||
|
||||
class UserNotFound(BadArgument):
|
||||
@@ -284,8 +300,8 @@ class UserNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The user supplied by the caller that was not found
|
||||
"""
|
||||
def __init__(self, argument):
|
||||
self.argument = argument
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'User "{argument}" not found.')
|
||||
|
||||
class MessageNotFound(BadArgument):
|
||||
@@ -300,8 +316,8 @@ class MessageNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The message supplied by the caller that was not found
|
||||
"""
|
||||
def __init__(self, argument):
|
||||
self.argument = argument
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Message "{argument}" not found.')
|
||||
|
||||
class ChannelNotReadable(BadArgument):
|
||||
@@ -314,11 +330,11 @@ class ChannelNotReadable(BadArgument):
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
argument: :class:`.abc.GuildChannel`
|
||||
argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
|
||||
The channel supplied by the caller that was not readable
|
||||
"""
|
||||
def __init__(self, argument):
|
||||
self.argument = argument
|
||||
def __init__(self, argument: Union[GuildChannel, Thread]) -> None:
|
||||
self.argument: Union[GuildChannel, Thread] = argument
|
||||
super().__init__(f"Can't read messages in {argument.mention}.")
|
||||
|
||||
class ChannelNotFound(BadArgument):
|
||||
@@ -333,8 +349,8 @@ class ChannelNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The channel supplied by the caller that was not found
|
||||
"""
|
||||
def __init__(self, argument):
|
||||
self.argument = argument
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Channel "{argument}" not found.')
|
||||
|
||||
class ThreadNotFound(BadArgument):
|
||||
@@ -349,8 +365,8 @@ class ThreadNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The thread supplied by the caller that was not found
|
||||
"""
|
||||
def __init__(self, argument):
|
||||
self.argument = argument
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Thread "{argument}" not found.')
|
||||
|
||||
class BadColourArgument(BadArgument):
|
||||
@@ -365,8 +381,8 @@ class BadColourArgument(BadArgument):
|
||||
argument: :class:`str`
|
||||
The colour supplied by the caller that was not valid
|
||||
"""
|
||||
def __init__(self, argument):
|
||||
self.argument = argument
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Colour "{argument}" is invalid.')
|
||||
|
||||
BadColorArgument = BadColourArgument
|
||||
@@ -383,8 +399,8 @@ class RoleNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The role supplied by the caller that was not found
|
||||
"""
|
||||
def __init__(self, argument):
|
||||
self.argument = argument
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Role "{argument}" not found.')
|
||||
|
||||
class BadInviteArgument(BadArgument):
|
||||
@@ -394,8 +410,8 @@ class BadInviteArgument(BadArgument):
|
||||
|
||||
.. versionadded:: 1.5
|
||||
"""
|
||||
def __init__(self, argument):
|
||||
self.argument = argument
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Invite "{argument}" is invalid or expired.')
|
||||
|
||||
class EmojiNotFound(BadArgument):
|
||||
@@ -410,8 +426,8 @@ class EmojiNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The emoji supplied by the caller that was not found
|
||||
"""
|
||||
def __init__(self, argument):
|
||||
self.argument = argument
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Emoji "{argument}" not found.')
|
||||
|
||||
class PartialEmojiConversionFailure(BadArgument):
|
||||
@@ -427,10 +443,26 @@ class PartialEmojiConversionFailure(BadArgument):
|
||||
argument: :class:`str`
|
||||
The emoji supplied by the caller that did not match the regex
|
||||
"""
|
||||
def __init__(self, argument):
|
||||
self.argument = argument
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.')
|
||||
|
||||
class GuildStickerNotFound(BadArgument):
|
||||
"""Exception raised when the bot can not find the sticker.
|
||||
|
||||
This inherits from :exc:`BadArgument`
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
argument: :class:`str`
|
||||
The sticker supplied by the caller that was not found
|
||||
"""
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Sticker "{argument}" not found.')
|
||||
|
||||
class BadBoolArgument(BadArgument):
|
||||
"""Exception raised when a boolean argument was not convertable.
|
||||
|
||||
@@ -443,8 +475,8 @@ class BadBoolArgument(BadArgument):
|
||||
argument: :class:`str`
|
||||
The boolean argument supplied by the caller that is not in the predefined list
|
||||
"""
|
||||
def __init__(self, argument):
|
||||
self.argument = argument
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'{argument} is not a recognised boolean option')
|
||||
|
||||
class DisabledCommand(CommandError):
|
||||
@@ -465,8 +497,8 @@ class CommandInvokeError(CommandError):
|
||||
The original exception that was raised. You can also get this via
|
||||
the ``__cause__`` attribute.
|
||||
"""
|
||||
def __init__(self, e):
|
||||
self.original = e
|
||||
def __init__(self, e: Exception) -> None:
|
||||
self.original: Exception = e
|
||||
super().__init__(f'Command raised an exception: {e.__class__.__name__}: {e}')
|
||||
|
||||
class CommandOnCooldown(CommandError):
|
||||
@@ -476,7 +508,7 @@ class CommandOnCooldown(CommandError):
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
cooldown: ``Cooldown``
|
||||
cooldown: :class:`.Cooldown`
|
||||
A class with attributes ``rate`` and ``per`` similar to the
|
||||
:func:`.cooldown` decorator.
|
||||
type: :class:`BucketType`
|
||||
@@ -484,10 +516,10 @@ class CommandOnCooldown(CommandError):
|
||||
retry_after: :class:`float`
|
||||
The amount of seconds to wait before you can retry again.
|
||||
"""
|
||||
def __init__(self, cooldown, retry_after, type):
|
||||
self.cooldown = cooldown
|
||||
self.retry_after = retry_after
|
||||
self.type = type
|
||||
def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None:
|
||||
self.cooldown: Cooldown = cooldown
|
||||
self.retry_after: float = retry_after
|
||||
self.type: BucketType = type
|
||||
super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s')
|
||||
|
||||
class MaxConcurrencyReached(CommandError):
|
||||
@@ -503,9 +535,9 @@ class MaxConcurrencyReached(CommandError):
|
||||
The bucket type passed to the :func:`.max_concurrency` decorator.
|
||||
"""
|
||||
|
||||
def __init__(self, number, per):
|
||||
self.number = number
|
||||
self.per = per
|
||||
def __init__(self, number: int, per: BucketType) -> None:
|
||||
self.number: int = number
|
||||
self.per: BucketType = per
|
||||
name = per.name
|
||||
suffix = 'per %s' % name if per.name != 'default' else 'globally'
|
||||
plural = '%s times %s' if number > 1 else '%s time %s'
|
||||
@@ -525,8 +557,8 @@ class MissingRole(CheckFailure):
|
||||
The required role that is missing.
|
||||
This is the parameter passed to :func:`~.commands.has_role`.
|
||||
"""
|
||||
def __init__(self, missing_role):
|
||||
self.missing_role = missing_role
|
||||
def __init__(self, missing_role: Snowflake) -> None:
|
||||
self.missing_role: Snowflake = missing_role
|
||||
message = f'Role {missing_role!r} is required to run this command.'
|
||||
super().__init__(message)
|
||||
|
||||
@@ -543,8 +575,8 @@ class BotMissingRole(CheckFailure):
|
||||
The required role that is missing.
|
||||
This is the parameter passed to :func:`~.commands.has_role`.
|
||||
"""
|
||||
def __init__(self, missing_role):
|
||||
self.missing_role = missing_role
|
||||
def __init__(self, missing_role: Snowflake) -> None:
|
||||
self.missing_role: Snowflake = missing_role
|
||||
message = f'Bot requires the role {missing_role!r} to run this command'
|
||||
super().__init__(message)
|
||||
|
||||
@@ -562,8 +594,8 @@ class MissingAnyRole(CheckFailure):
|
||||
The roles that the invoker is missing.
|
||||
These are the parameters passed to :func:`~.commands.has_any_role`.
|
||||
"""
|
||||
def __init__(self, missing_roles):
|
||||
self.missing_roles = missing_roles
|
||||
def __init__(self, missing_roles: SnowflakeList) -> None:
|
||||
self.missing_roles: SnowflakeList = missing_roles
|
||||
|
||||
missing = [f"'{role}'" for role in missing_roles]
|
||||
|
||||
@@ -591,8 +623,8 @@ class BotMissingAnyRole(CheckFailure):
|
||||
These are the parameters passed to :func:`~.commands.has_any_role`.
|
||||
|
||||
"""
|
||||
def __init__(self, missing_roles):
|
||||
self.missing_roles = missing_roles
|
||||
def __init__(self, missing_roles: SnowflakeList) -> None:
|
||||
self.missing_roles: SnowflakeList = missing_roles
|
||||
|
||||
missing = [f"'{role}'" for role in missing_roles]
|
||||
|
||||
@@ -613,11 +645,11 @@ class NSFWChannelRequired(CheckFailure):
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
channel: :class:`discord.abc.GuildChannel`
|
||||
channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
|
||||
The channel that does not have NSFW enabled.
|
||||
"""
|
||||
def __init__(self, channel):
|
||||
self.channel = channel
|
||||
def __init__(self, channel: Union[GuildChannel, Thread]) -> None:
|
||||
self.channel: Union[GuildChannel, Thread] = channel
|
||||
super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.")
|
||||
|
||||
class MissingPermissions(CheckFailure):
|
||||
@@ -628,11 +660,11 @@ class MissingPermissions(CheckFailure):
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
missing_permissions: :class:`list`
|
||||
missing_permissions: List[:class:`str`]
|
||||
The required permissions that are missing.
|
||||
"""
|
||||
def __init__(self, missing_permissions, *args):
|
||||
self.missing_permissions = missing_permissions
|
||||
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
|
||||
self.missing_permissions: List[str] = missing_permissions
|
||||
|
||||
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
|
||||
|
||||
@@ -651,11 +683,11 @@ class BotMissingPermissions(CheckFailure):
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
missing_permissions: :class:`list`
|
||||
missing_permissions: List[:class:`str`]
|
||||
The required permissions that are missing.
|
||||
"""
|
||||
def __init__(self, missing_permissions, *args):
|
||||
self.missing_permissions = missing_permissions
|
||||
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
|
||||
self.missing_permissions: List[str] = missing_permissions
|
||||
|
||||
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
|
||||
|
||||
@@ -681,10 +713,10 @@ class BadUnionArgument(UserInputError):
|
||||
errors: List[:class:`CommandError`]
|
||||
A list of errors that were caught from failing the conversion.
|
||||
"""
|
||||
def __init__(self, param, converters, errors):
|
||||
self.param = param
|
||||
self.converters = converters
|
||||
self.errors = errors
|
||||
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None:
|
||||
self.param: Parameter = param
|
||||
self.converters: Tuple[Type, ...] = converters
|
||||
self.errors: List[CommandError] = errors
|
||||
|
||||
def _get_name(x):
|
||||
try:
|
||||
@@ -719,10 +751,10 @@ class BadLiteralArgument(UserInputError):
|
||||
errors: List[:class:`CommandError`]
|
||||
A list of errors that were caught from failing the conversion.
|
||||
"""
|
||||
def __init__(self, param, literals, errors):
|
||||
self.param = param
|
||||
self.literals = literals
|
||||
self.errors = errors
|
||||
def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None:
|
||||
self.param: Parameter = param
|
||||
self.literals: Tuple[Any, ...] = literals
|
||||
self.errors: List[CommandError] = errors
|
||||
|
||||
to_string = [repr(l) for l in literals]
|
||||
if len(to_string) > 2:
|
||||
@@ -752,8 +784,8 @@ class UnexpectedQuoteError(ArgumentParsingError):
|
||||
quote: :class:`str`
|
||||
The quote mark that was found inside the non-quoted string.
|
||||
"""
|
||||
def __init__(self, quote):
|
||||
self.quote = quote
|
||||
def __init__(self, quote: str) -> None:
|
||||
self.quote: str = quote
|
||||
super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string')
|
||||
|
||||
class InvalidEndOfQuotedStringError(ArgumentParsingError):
|
||||
@@ -767,8 +799,8 @@ class InvalidEndOfQuotedStringError(ArgumentParsingError):
|
||||
char: :class:`str`
|
||||
The character found instead of the expected string.
|
||||
"""
|
||||
def __init__(self, char):
|
||||
self.char = char
|
||||
def __init__(self, char: str) -> None:
|
||||
self.char: str = char
|
||||
super().__init__(f'Expected space after closing quotation but received {char!r}')
|
||||
|
||||
class ExpectedClosingQuoteError(ArgumentParsingError):
|
||||
@@ -782,8 +814,8 @@ class ExpectedClosingQuoteError(ArgumentParsingError):
|
||||
The quote character expected.
|
||||
"""
|
||||
|
||||
def __init__(self, close_quote):
|
||||
self.close_quote = close_quote
|
||||
def __init__(self, close_quote: str) -> None:
|
||||
self.close_quote: str = close_quote
|
||||
super().__init__(f'Expected closing {close_quote}.')
|
||||
|
||||
class ExtensionError(DiscordException):
|
||||
@@ -796,8 +828,8 @@ class ExtensionError(DiscordException):
|
||||
name: :class:`str`
|
||||
The extension that had an error.
|
||||
"""
|
||||
def __init__(self, message=None, *args, name):
|
||||
self.name = name
|
||||
def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None:
|
||||
self.name: str = name
|
||||
message = message or f'Extension {name!r} had an error.'
|
||||
# clean-up @everyone and @here mentions
|
||||
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
|
||||
@@ -808,7 +840,7 @@ class ExtensionAlreadyLoaded(ExtensionError):
|
||||
|
||||
This inherits from :exc:`ExtensionError`
|
||||
"""
|
||||
def __init__(self, name):
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__(f'Extension {name!r} is already loaded.', name=name)
|
||||
|
||||
class ExtensionNotLoaded(ExtensionError):
|
||||
@@ -816,7 +848,7 @@ class ExtensionNotLoaded(ExtensionError):
|
||||
|
||||
This inherits from :exc:`ExtensionError`
|
||||
"""
|
||||
def __init__(self, name):
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__(f'Extension {name!r} has not been loaded.', name=name)
|
||||
|
||||
class NoEntryPointError(ExtensionError):
|
||||
@@ -824,7 +856,7 @@ class NoEntryPointError(ExtensionError):
|
||||
|
||||
This inherits from :exc:`ExtensionError`
|
||||
"""
|
||||
def __init__(self, name):
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__(f"Extension {name!r} has no 'setup' function.", name=name)
|
||||
|
||||
class ExtensionFailed(ExtensionError):
|
||||
@@ -840,8 +872,8 @@ class ExtensionFailed(ExtensionError):
|
||||
The original exception that was raised. You can also get this via
|
||||
the ``__cause__`` attribute.
|
||||
"""
|
||||
def __init__(self, name, original):
|
||||
self.original = original
|
||||
def __init__(self, name: str, original: Exception) -> None:
|
||||
self.original: Exception = original
|
||||
msg = f'Extension {name!r} raised an error: {original.__class__.__name__}: {original}'
|
||||
super().__init__(msg, name=name)
|
||||
|
||||
@@ -858,7 +890,7 @@ class ExtensionNotFound(ExtensionError):
|
||||
name: :class:`str`
|
||||
The extension that had the error.
|
||||
"""
|
||||
def __init__(self, name):
|
||||
def __init__(self, name: str) -> None:
|
||||
msg = f'Extension {name!r} could not be loaded.'
|
||||
super().__init__(msg, name=name)
|
||||
|
||||
@@ -877,9 +909,9 @@ class CommandRegistrationError(ClientException):
|
||||
alias_conflict: :class:`bool`
|
||||
Whether the name that conflicts is an alias of the command we try to add.
|
||||
"""
|
||||
def __init__(self, name, *, alias_conflict=False):
|
||||
self.name = name
|
||||
self.alias_conflict = alias_conflict
|
||||
def __init__(self, name: str, *, alias_conflict: bool = False) -> None:
|
||||
self.name: str = name
|
||||
self.alias_conflict: bool = alias_conflict
|
||||
type_ = 'alias' if alias_conflict else 'command'
|
||||
super().__init__(f'The {type_} {name} is already an existing command or alias.')
|
||||
|
||||
@@ -906,17 +938,25 @@ class TooManyFlags(FlagError):
|
||||
values: List[:class:`str`]
|
||||
The values that were passed.
|
||||
"""
|
||||
def __init__(self, flag, values):
|
||||
self.flag = flag
|
||||
self.values = values
|
||||
def __init__(self, flag: Flag, values: List[str]) -> None:
|
||||
self.flag: Flag = flag
|
||||
self.values: List[str] = values
|
||||
super().__init__(f'Too many flag values, expected {flag.max_args} but received {len(values)}.')
|
||||
|
||||
class BadFlagArgument(FlagError):
|
||||
"""An exception raised when a flag failed to convert a value.
|
||||
|
||||
This inherits from :exc:`FlagError`
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
flag: :class:`~discord.ext.commands.Flag`
|
||||
The flag that failed to convert.
|
||||
"""
|
||||
def __init__(self, flag):
|
||||
self.flag = flag
|
||||
def __init__(self, flag: Flag) -> None:
|
||||
self.flag: Flag = flag
|
||||
try:
|
||||
name = flag.annotation.__name__
|
||||
except AttributeError:
|
||||
@@ -936,8 +976,8 @@ class MissingRequiredFlag(FlagError):
|
||||
flag: :class:`~discord.ext.commands.Flag`
|
||||
The required flag that was not found.
|
||||
"""
|
||||
def __init__(self, flag):
|
||||
self.flag = flag
|
||||
def __init__(self, flag: Flag) -> None:
|
||||
self.flag: Flag = flag
|
||||
super().__init__(f'Flag {flag.name!r} is required and missing')
|
||||
|
||||
class MissingFlagArgument(FlagError):
|
||||
@@ -952,6 +992,6 @@ class MissingFlagArgument(FlagError):
|
||||
flag: :class:`~discord.ext.commands.Flag`
|
||||
The flag that did not get a value.
|
||||
"""
|
||||
def __init__(self, flag):
|
||||
self.flag = flag
|
||||
def __init__(self, flag: Flag) -> None:
|
||||
self.flag: Flag = flag
|
||||
super().__init__(f'Flag {flag.name!r} does not have an argument')
|
||||
|
||||
@@ -27,11 +27,17 @@ import copy
|
||||
import functools
|
||||
import inspect
|
||||
import re
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import discord.utils
|
||||
|
||||
from .core import Group, Command
|
||||
from .errors import CommandError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
|
||||
__all__ = (
|
||||
'Paginator',
|
||||
'HelpCommand',
|
||||
@@ -320,7 +326,7 @@ class HelpCommand:
|
||||
self.command_attrs = attrs = options.pop('command_attrs', {})
|
||||
attrs.setdefault('name', 'help')
|
||||
attrs.setdefault('help', 'Shows this message')
|
||||
self.context = None
|
||||
self.context: Context = discord.utils.MISSING
|
||||
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
|
||||
|
||||
def copy(self):
|
||||
|
||||
@@ -27,22 +27,20 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
@@ -50,8 +48,6 @@ from collections.abc import Sequence
|
||||
from discord.backoff import ExponentialBackoff
|
||||
from discord.utils import MISSING
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = (
|
||||
'loop',
|
||||
)
|
||||
@@ -61,7 +57,6 @@ _func = Callable[..., Awaitable[Any]]
|
||||
LF = TypeVar('LF', bound=_func)
|
||||
FT = TypeVar('FT', bound=_func)
|
||||
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
|
||||
LT = TypeVar('LT', bound='Loop')
|
||||
|
||||
|
||||
class SleepHandle:
|
||||
@@ -78,7 +73,7 @@ class SleepHandle:
|
||||
relative_delta = discord.utils.compute_timedelta(dt)
|
||||
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
|
||||
|
||||
def wait(self) -> asyncio.Future:
|
||||
def wait(self) -> asyncio.Future[Any]:
|
||||
return self.future
|
||||
|
||||
def done(self) -> bool:
|
||||
@@ -94,7 +89,9 @@ class Loop(Generic[LF]):
|
||||
|
||||
The main interface to create this is through :func:`loop`.
|
||||
"""
|
||||
def __init__(self,
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
coro: LF,
|
||||
seconds: float,
|
||||
hours: float,
|
||||
@@ -102,15 +99,15 @@ class Loop(Generic[LF]):
|
||||
time: Union[datetime.time, Sequence[datetime.time]],
|
||||
count: Optional[int],
|
||||
reconnect: bool,
|
||||
loop: Optional[asyncio.AbstractEventLoop],
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
) -> None:
|
||||
self.coro: LF = coro
|
||||
self.reconnect: bool = reconnect
|
||||
self.loop: Optional[asyncio.AbstractEventLoop] = loop
|
||||
self.loop: asyncio.AbstractEventLoop = loop
|
||||
self.count: Optional[int] = count
|
||||
self._current_loop = 0
|
||||
self._handle = None
|
||||
self._task = None
|
||||
self._handle: SleepHandle = MISSING
|
||||
self._task: asyncio.Task[None] = MISSING
|
||||
self._injected = None
|
||||
self._valid_exception = (
|
||||
OSError,
|
||||
@@ -131,7 +128,7 @@ class Loop(Generic[LF]):
|
||||
|
||||
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
|
||||
self._last_iteration_failed = False
|
||||
self._last_iteration = None
|
||||
self._last_iteration: datetime.datetime = MISSING
|
||||
self._next_iteration = None
|
||||
|
||||
if not inspect.iscoroutinefunction(self.coro):
|
||||
@@ -147,9 +144,8 @@ class Loop(Generic[LF]):
|
||||
else:
|
||||
await coro(*args, **kwargs)
|
||||
|
||||
|
||||
def _try_sleep_until(self, dt: datetime.datetime):
|
||||
self._handle = SleepHandle(dt=dt, loop=self.loop) # type: ignore
|
||||
self._handle = SleepHandle(dt=dt, loop=self.loop)
|
||||
return self._handle.wait()
|
||||
|
||||
async def _loop(self, *args: Any, **kwargs: Any) -> None:
|
||||
@@ -178,7 +174,7 @@ class Loop(Generic[LF]):
|
||||
await asyncio.sleep(backoff.delay())
|
||||
else:
|
||||
await self._try_sleep_until(self._next_iteration)
|
||||
|
||||
|
||||
if self._stop_next_iteration:
|
||||
return
|
||||
|
||||
@@ -211,14 +207,14 @@ class Loop(Generic[LF]):
|
||||
if obj is None:
|
||||
return self
|
||||
|
||||
copy = Loop(
|
||||
self.coro,
|
||||
seconds=self._seconds,
|
||||
hours=self._hours,
|
||||
copy: Loop[LF] = Loop(
|
||||
self.coro,
|
||||
seconds=self._seconds,
|
||||
hours=self._hours,
|
||||
minutes=self._minutes,
|
||||
time=self._time,
|
||||
time=self._time,
|
||||
count=self.count,
|
||||
reconnect=self.reconnect,
|
||||
reconnect=self.reconnect,
|
||||
loop=self.loop,
|
||||
)
|
||||
copy._injected = obj
|
||||
@@ -237,7 +233,7 @@ class Loop(Generic[LF]):
|
||||
"""
|
||||
if self._seconds is not MISSING:
|
||||
return self._seconds
|
||||
|
||||
|
||||
@property
|
||||
def minutes(self) -> Optional[float]:
|
||||
"""Optional[:class:`float`]: Read-only value for the number of minutes
|
||||
@@ -247,7 +243,7 @@ class Loop(Generic[LF]):
|
||||
"""
|
||||
if self._minutes is not MISSING:
|
||||
return self._minutes
|
||||
|
||||
|
||||
@property
|
||||
def hours(self) -> Optional[float]:
|
||||
"""Optional[:class:`float`]: Read-only value for the number of hours
|
||||
@@ -279,7 +275,7 @@ class Loop(Generic[LF]):
|
||||
|
||||
.. versionadded:: 1.3
|
||||
"""
|
||||
if self._task is None:
|
||||
if self._task is MISSING:
|
||||
return None
|
||||
elif self._task and self._task.done() or self._stop_next_iteration:
|
||||
return None
|
||||
@@ -305,7 +301,7 @@ class Loop(Generic[LF]):
|
||||
|
||||
return await self.coro(*args, **kwargs)
|
||||
|
||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task:
|
||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
||||
r"""Starts the internal task in the event loop.
|
||||
|
||||
Parameters
|
||||
@@ -326,13 +322,13 @@ class Loop(Generic[LF]):
|
||||
The task that has been created.
|
||||
"""
|
||||
|
||||
if self._task is not None and not self._task.done():
|
||||
if self._task is not MISSING and not self._task.done():
|
||||
raise RuntimeError('Task is already launched and is not completed.')
|
||||
|
||||
if self._injected is not None:
|
||||
args = (self._injected, *args)
|
||||
|
||||
if self.loop is None:
|
||||
if self.loop is MISSING:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
self._task = self.loop.create_task(self._loop(*args, **kwargs))
|
||||
@@ -356,7 +352,7 @@ class Loop(Generic[LF]):
|
||||
|
||||
.. versionadded:: 1.2
|
||||
"""
|
||||
if self._task and not self._task.done():
|
||||
if self._task is not MISSING and not self._task.done():
|
||||
self._stop_next_iteration = True
|
||||
|
||||
def _can_be_cancelled(self) -> bool:
|
||||
@@ -383,7 +379,7 @@ class Loop(Generic[LF]):
|
||||
The keyword arguments to use.
|
||||
"""
|
||||
|
||||
def restart_when_over(fut, *, args=args, kwargs=kwargs):
|
||||
def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None:
|
||||
self._task.remove_done_callback(restart_when_over)
|
||||
self.start(*args, **kwargs)
|
||||
|
||||
@@ -446,9 +442,9 @@ class Loop(Generic[LF]):
|
||||
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
|
||||
return len(self._valid_exception) == old_length - len(exceptions)
|
||||
|
||||
def get_task(self) -> Optional[asyncio.Task]:
|
||||
def get_task(self) -> Optional[asyncio.Task[None]]:
|
||||
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
|
||||
return self._task
|
||||
return self._task if self._task is not MISSING else None
|
||||
|
||||
def is_being_cancelled(self) -> bool:
|
||||
"""Whether the task is being cancelled."""
|
||||
@@ -466,7 +462,7 @@ class Loop(Generic[LF]):
|
||||
|
||||
.. versionadded:: 1.4
|
||||
"""
|
||||
return not bool(self._task.done()) if self._task else False
|
||||
return not bool(self._task.done()) if self._task is not MISSING else False
|
||||
|
||||
async def _error(self, *args: Any) -> None:
|
||||
exception: Exception = args[-1]
|
||||
@@ -560,7 +556,9 @@ class Loop(Generic[LF]):
|
||||
self._time_index = 0
|
||||
if self._current_loop == 0:
|
||||
# if we're at the last index on the first iteration, we need to sleep until tomorrow
|
||||
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0])
|
||||
return datetime.datetime.combine(
|
||||
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
|
||||
)
|
||||
|
||||
next_time = self._time[self._time_index]
|
||||
|
||||
@@ -568,7 +566,7 @@ class Loop(Generic[LF]):
|
||||
self._time_index += 1
|
||||
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
|
||||
|
||||
next_date = cast(datetime.datetime, self._last_iteration)
|
||||
next_date = self._last_iteration
|
||||
if self._time_index == 0:
|
||||
# we can assume that the earliest time should be scheduled for "tomorrow"
|
||||
next_date += datetime.timedelta(days=1)
|
||||
@@ -576,12 +574,14 @@ class Loop(Generic[LF]):
|
||||
self._time_index += 1
|
||||
return datetime.datetime.combine(next_date, next_time)
|
||||
|
||||
def _prepare_time_index(self, now: Optional[datetime.datetime] = None) -> None:
|
||||
def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
|
||||
# now kwarg should be a datetime.datetime representing the time "now"
|
||||
# to calculate the next time index from
|
||||
|
||||
# pre-condition: self._time is set
|
||||
time_now = (now or datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)).timetz()
|
||||
time_now = (
|
||||
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
|
||||
).timetz()
|
||||
for idx, time in enumerate(self._time):
|
||||
if time >= time_now:
|
||||
self._time_index = idx
|
||||
@@ -597,20 +597,24 @@ class Loop(Generic[LF]):
|
||||
utc: datetime.timezone = datetime.timezone.utc,
|
||||
) -> List[datetime.time]:
|
||||
if isinstance(time, dt):
|
||||
ret = time if time.tzinfo is not None else time.replace(tzinfo=utc)
|
||||
return [ret]
|
||||
inner = time if time.tzinfo is not None else time.replace(tzinfo=utc)
|
||||
return [inner]
|
||||
if not isinstance(time, Sequence):
|
||||
raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.')
|
||||
raise TypeError(
|
||||
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.'
|
||||
)
|
||||
if not time:
|
||||
raise ValueError('time parameter must not be an empty sequence.')
|
||||
|
||||
ret = []
|
||||
ret: List[datetime.time] = []
|
||||
for index, t in enumerate(time):
|
||||
if not isinstance(t, dt):
|
||||
raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.')
|
||||
raise TypeError(
|
||||
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.'
|
||||
)
|
||||
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
|
||||
|
||||
ret = sorted(set(ret)) # de-dupe and sort times
|
||||
ret = sorted(set(ret)) # de-dupe and sort times
|
||||
return ret
|
||||
|
||||
def change_interval(
|
||||
@@ -691,7 +695,7 @@ def loop(
|
||||
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
|
||||
count: Optional[int] = None,
|
||||
reconnect: bool = True,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
loop: asyncio.AbstractEventLoop = MISSING,
|
||||
) -> Callable[[LF], Loop[LF]]:
|
||||
"""A decorator that schedules a task in the background for you with
|
||||
optional reconnect logic. The decorator returns a :class:`Loop`.
|
||||
@@ -707,7 +711,7 @@ def loop(
|
||||
time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
|
||||
The exact times to run this loop at. Either a non-empty list or a single
|
||||
value of :class:`datetime.time` should be passed. Timezones are supported.
|
||||
If no timezone is given for the times, it is assumed to represent UTC time.
|
||||
If no timezone is given for the times, it is assumed to represent UTC time.
|
||||
|
||||
This cannot be used in conjunction with the relative time parameters.
|
||||
|
||||
@@ -724,7 +728,7 @@ def loop(
|
||||
Whether to handle errors and restart the task
|
||||
using an exponential back-off algorithm similar to the
|
||||
one used in :meth:`discord.Client.connect`.
|
||||
loop: Optional[:class:`asyncio.AbstractEventLoop`]
|
||||
loop: :class:`asyncio.AbstractEventLoop`
|
||||
The loop to use to register the task, if not given
|
||||
defaults to :func:`asyncio.get_event_loop`.
|
||||
|
||||
@@ -736,15 +740,17 @@ def loop(
|
||||
The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
|
||||
or ``time`` parameter was passed in conjunction with relative time parameters.
|
||||
"""
|
||||
|
||||
def decorator(func: LF) -> Loop[LF]:
|
||||
kwargs = {
|
||||
'seconds': seconds,
|
||||
'minutes': minutes,
|
||||
'hours': hours,
|
||||
'count': count,
|
||||
'time': time,
|
||||
'reconnect': reconnect,
|
||||
'loop': loop,
|
||||
}
|
||||
return Loop(func, **kwargs)
|
||||
return Loop[LF](
|
||||
func,
|
||||
seconds=seconds,
|
||||
minutes=minutes,
|
||||
hours=hours,
|
||||
count=count,
|
||||
time=time,
|
||||
reconnect=reconnect,
|
||||
loop=loop,
|
||||
)
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -41,7 +41,7 @@ FV = TypeVar('FV', bound='flag_value')
|
||||
BF = TypeVar('BF', bound='BaseFlags')
|
||||
|
||||
|
||||
class flag_value(Generic[BF]):
|
||||
class flag_value:
|
||||
def __init__(self, func: Callable[[Any], int]):
|
||||
self.flag = func(None)
|
||||
self.__doc__ = func.__doc__
|
||||
@@ -205,7 +205,7 @@ class SystemChannelFlags(BaseFlags):
|
||||
|
||||
@flag_value
|
||||
def premium_subscriptions(self):
|
||||
""":class:`bool`: Returns ``True`` if the system channel is used for Nitro boosting notifications."""
|
||||
""":class:`bool`: Returns ``True`` if the system channel is used for "Nitro boosting" notifications."""
|
||||
return 2
|
||||
|
||||
@flag_value
|
||||
@@ -480,16 +480,6 @@ class Intents(BaseFlags):
|
||||
self.value = self.DEFAULT_VALUE
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def default(cls: Type[Intents]) -> Intents:
|
||||
"""A factory method that creates a :class:`Intents` with everything enabled
|
||||
except :attr:`presences` and :attr:`members`.
|
||||
"""
|
||||
self = cls.all()
|
||||
self.presences = False
|
||||
self.members = False
|
||||
return self
|
||||
|
||||
@flag_value
|
||||
def guilds(self):
|
||||
""":class:`bool`: Whether guild related events are enabled.
|
||||
@@ -566,18 +556,34 @@ class Intents(BaseFlags):
|
||||
|
||||
@flag_value
|
||||
def emojis(self):
|
||||
""":class:`bool`: Whether guild emoji related events are enabled.
|
||||
""":class:`bool`: Alias of :attr:`.emojis_and_stickers`.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Changed to an alias.
|
||||
"""
|
||||
return 1 << 3
|
||||
|
||||
@alias_flag_value
|
||||
def emojis_and_stickers(self):
|
||||
""":class:`bool`: Whether guild emoji and sticker related events are enabled.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
This corresponds to the following events:
|
||||
|
||||
- :func:`on_guild_emojis_update`
|
||||
- :func:`on_guild_stickers_update`
|
||||
|
||||
This also corresponds to the following attributes and classes in terms of cache:
|
||||
|
||||
- :class:`Emoji`
|
||||
- :class:`GuildSticker`
|
||||
- :meth:`Client.get_emoji`
|
||||
- :meth:`Client.get_sticker`
|
||||
- :meth:`Client.emojis`
|
||||
- :meth:`Client.stickers`
|
||||
- :attr:`Guild.emojis`
|
||||
- :attr:`Guild.stickers`
|
||||
"""
|
||||
return 1 << 3
|
||||
|
||||
@@ -634,6 +640,10 @@ class Intents(BaseFlags):
|
||||
- :attr:`VoiceChannel.members`
|
||||
- :attr:`VoiceChannel.voice_states`
|
||||
- :attr:`Member.voice`
|
||||
|
||||
.. note::
|
||||
|
||||
This intent is required to connect to voice.
|
||||
"""
|
||||
return 1 << 7
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ from .activity import BaseActivity
|
||||
from .enums import SpeakingState
|
||||
from .errors import ConnectionClosed, InvalidArgument
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = (
|
||||
'DiscordWebSocket',
|
||||
@@ -101,7 +101,7 @@ class GatewayRatelimiter:
|
||||
async with self.lock:
|
||||
delta = self.get_delay()
|
||||
if delta:
|
||||
log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta)
|
||||
_log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta)
|
||||
await asyncio.sleep(delta)
|
||||
|
||||
|
||||
@@ -129,20 +129,20 @@ class KeepAliveHandler(threading.Thread):
|
||||
def run(self):
|
||||
while not self._stop_ev.wait(self.interval):
|
||||
if self._last_recv + self.heartbeat_timeout < time.perf_counter():
|
||||
log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
|
||||
_log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
|
||||
coro = self.ws.close(4000)
|
||||
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
|
||||
|
||||
try:
|
||||
f.result()
|
||||
except Exception:
|
||||
log.exception('An error occurred while stopping the gateway. Ignoring.')
|
||||
_log.exception('An error occurred while stopping the gateway. Ignoring.')
|
||||
finally:
|
||||
self.stop()
|
||||
return
|
||||
|
||||
data = self.get_payload()
|
||||
log.debug(self.msg, self.shard_id, data['d'])
|
||||
_log.debug(self.msg, self.shard_id, data['d'])
|
||||
coro = self.ws.send_heartbeat(data)
|
||||
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
|
||||
try:
|
||||
@@ -161,7 +161,7 @@ class KeepAliveHandler(threading.Thread):
|
||||
else:
|
||||
stack = ''.join(traceback.format_stack(frame))
|
||||
msg = f'{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}'
|
||||
log.warning(msg, self.shard_id, total)
|
||||
_log.warning(msg, self.shard_id, total)
|
||||
|
||||
except Exception:
|
||||
self.stop()
|
||||
@@ -185,7 +185,7 @@ class KeepAliveHandler(threading.Thread):
|
||||
self._last_ack = ack_time
|
||||
self.latency = ack_time - self._last_send
|
||||
if self.latency > 10:
|
||||
log.warning(self.behind_msg, self.shard_id, self.latency)
|
||||
_log.warning(self.behind_msg, self.shard_id, self.latency)
|
||||
|
||||
class VoiceKeepAliveHandler(KeepAliveHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -293,6 +293,12 @@ class DiscordWebSocket:
|
||||
def is_ratelimited(self):
|
||||
return self._rate_limiter.is_ratelimited()
|
||||
|
||||
def debug_log_receive(self, data, /):
|
||||
self._dispatch('socket_raw_receive', data)
|
||||
|
||||
def log_receive(self, _, /):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
|
||||
"""Creates a main websocket for Discord from a :class:`Client`.
|
||||
@@ -318,9 +324,13 @@ class DiscordWebSocket:
|
||||
ws.sequence = sequence
|
||||
ws._max_heartbeat_timeout = client._connection.heartbeat_timeout
|
||||
|
||||
if client._enable_debug_events:
|
||||
ws.send = ws.debug_send
|
||||
ws.log_receive = ws.debug_log_receive
|
||||
|
||||
client._connection._update_references(ws)
|
||||
|
||||
log.debug('Created websocket connected to %s', gateway)
|
||||
_log.debug('Created websocket connected to %s', gateway)
|
||||
|
||||
# poll event for OP Hello
|
||||
await ws.poll_event()
|
||||
@@ -393,7 +403,7 @@ class DiscordWebSocket:
|
||||
|
||||
await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify)
|
||||
await self.send_as_json(payload)
|
||||
log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
|
||||
_log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
|
||||
|
||||
async def resume(self):
|
||||
"""Sends the RESUME packet."""
|
||||
@@ -407,11 +417,9 @@ class DiscordWebSocket:
|
||||
}
|
||||
|
||||
await self.send_as_json(payload)
|
||||
log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
|
||||
|
||||
async def received_message(self, msg):
|
||||
self._dispatch('socket_raw_receive', msg)
|
||||
_log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
|
||||
|
||||
async def received_message(self, msg, /):
|
||||
if type(msg) is bytes:
|
||||
self._buffer.extend(msg)
|
||||
|
||||
@@ -420,10 +428,14 @@ class DiscordWebSocket:
|
||||
msg = self._zlib.decompress(self._buffer)
|
||||
msg = msg.decode('utf-8')
|
||||
self._buffer = bytearray()
|
||||
msg = utils.from_json(msg)
|
||||
|
||||
log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg)
|
||||
self._dispatch('socket_response', msg)
|
||||
self.log_receive(msg)
|
||||
msg = utils._from_json(msg)
|
||||
|
||||
_log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg)
|
||||
event = msg.get('t')
|
||||
if event:
|
||||
self._dispatch('socket_event_type', event)
|
||||
|
||||
op = msg.get('op')
|
||||
data = msg.get('d')
|
||||
@@ -439,7 +451,7 @@ class DiscordWebSocket:
|
||||
# "reconnect" can only be handled by the Client
|
||||
# so we terminate our connection and raise an
|
||||
# internal exception signalling to reconnect.
|
||||
log.debug('Received RECONNECT opcode.')
|
||||
_log.debug('Received RECONNECT opcode.')
|
||||
await self.close()
|
||||
raise ReconnectWebSocket(self.shard_id)
|
||||
|
||||
@@ -469,35 +481,33 @@ class DiscordWebSocket:
|
||||
|
||||
self.sequence = None
|
||||
self.session_id = None
|
||||
log.info('Shard ID %s session has been invalidated.', self.shard_id)
|
||||
_log.info('Shard ID %s session has been invalidated.', self.shard_id)
|
||||
await self.close(code=1000)
|
||||
raise ReconnectWebSocket(self.shard_id, resume=False)
|
||||
|
||||
log.warning('Unknown OP code %s.', op)
|
||||
_log.warning('Unknown OP code %s.', op)
|
||||
return
|
||||
|
||||
event = msg.get('t')
|
||||
|
||||
if event == 'READY':
|
||||
self._trace = trace = data.get('_trace', [])
|
||||
self.sequence = msg['s']
|
||||
self.session_id = data['session_id']
|
||||
# pass back shard ID to ready handler
|
||||
data['__shard_id__'] = self.shard_id
|
||||
log.info('Shard ID %s has connected to Gateway: %s (Session ID: %s).',
|
||||
_log.info('Shard ID %s has connected to Gateway: %s (Session ID: %s).',
|
||||
self.shard_id, ', '.join(trace), self.session_id)
|
||||
|
||||
elif event == 'RESUMED':
|
||||
self._trace = trace = data.get('_trace', [])
|
||||
# pass back the shard ID to the resumed handler
|
||||
data['__shard_id__'] = self.shard_id
|
||||
log.info('Shard ID %s has successfully RESUMED session %s under trace %s.',
|
||||
_log.info('Shard ID %s has successfully RESUMED session %s under trace %s.',
|
||||
self.shard_id, self.session_id, ', '.join(trace))
|
||||
|
||||
try:
|
||||
func = self._discord_parsers[event]
|
||||
except KeyError:
|
||||
log.debug('Unknown event %s.', event)
|
||||
_log.debug('Unknown event %s.', event)
|
||||
else:
|
||||
func(data)
|
||||
|
||||
@@ -551,10 +561,10 @@ class DiscordWebSocket:
|
||||
elif msg.type is aiohttp.WSMsgType.BINARY:
|
||||
await self.received_message(msg.data)
|
||||
elif msg.type is aiohttp.WSMsgType.ERROR:
|
||||
log.debug('Received %s', msg)
|
||||
_log.debug('Received %s', msg)
|
||||
raise msg.data
|
||||
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE):
|
||||
log.debug('Received %s', msg)
|
||||
_log.debug('Received %s', msg)
|
||||
raise WebSocketClosure
|
||||
except (asyncio.TimeoutError, WebSocketClosure) as e:
|
||||
# Ensure the keep alive handler is closed
|
||||
@@ -563,25 +573,29 @@ class DiscordWebSocket:
|
||||
self._keep_alive = None
|
||||
|
||||
if isinstance(e, asyncio.TimeoutError):
|
||||
log.info('Timed out receiving packet. Attempting a reconnect.')
|
||||
_log.info('Timed out receiving packet. Attempting a reconnect.')
|
||||
raise ReconnectWebSocket(self.shard_id) from None
|
||||
|
||||
code = self._close_code or self.socket.close_code
|
||||
if self._can_handle_close():
|
||||
log.info('Websocket closed with %s, attempting a reconnect.', code)
|
||||
_log.info('Websocket closed with %s, attempting a reconnect.', code)
|
||||
raise ReconnectWebSocket(self.shard_id) from None
|
||||
else:
|
||||
log.info('Websocket closed with %s, cannot reconnect.', code)
|
||||
_log.info('Websocket closed with %s, cannot reconnect.', code)
|
||||
raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None
|
||||
|
||||
async def send(self, data):
|
||||
async def debug_send(self, data, /):
|
||||
await self._rate_limiter.block()
|
||||
self._dispatch('socket_raw_send', data)
|
||||
await self.socket.send_str(data)
|
||||
|
||||
async def send(self, data, /):
|
||||
await self._rate_limiter.block()
|
||||
await self.socket.send_str(data)
|
||||
|
||||
async def send_as_json(self, data):
|
||||
try:
|
||||
await self.send(utils.to_json(data))
|
||||
await self.send(utils._to_json(data))
|
||||
except RuntimeError as exc:
|
||||
if not self._can_handle_close():
|
||||
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
|
||||
@@ -589,7 +603,7 @@ class DiscordWebSocket:
|
||||
async def send_heartbeat(self, data):
|
||||
# This bypasses the rate limit handling code since it has a higher priority
|
||||
try:
|
||||
await self.socket.send_str(utils.to_json(data))
|
||||
await self.socket.send_str(utils._to_json(data))
|
||||
except RuntimeError as exc:
|
||||
if not self._can_handle_close():
|
||||
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
|
||||
@@ -615,8 +629,8 @@ class DiscordWebSocket:
|
||||
}
|
||||
}
|
||||
|
||||
sent = utils.to_json(payload)
|
||||
log.debug('Sending "%s" to change status', sent)
|
||||
sent = utils._to_json(payload)
|
||||
_log.debug('Sending "%s" to change status', sent)
|
||||
await self.send(sent)
|
||||
|
||||
async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None):
|
||||
@@ -652,7 +666,7 @@ class DiscordWebSocket:
|
||||
}
|
||||
}
|
||||
|
||||
log.debug('Updating our voice state to %s.', payload)
|
||||
_log.debug('Updating our voice state to %s.', payload)
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def close(self, code=4000):
|
||||
@@ -720,8 +734,8 @@ class DiscordVoiceWebSocket:
|
||||
pass
|
||||
|
||||
async def send_as_json(self, data):
|
||||
log.debug('Sending voice websocket frame: %s.', data)
|
||||
await self.ws.send_str(utils.to_json(data))
|
||||
_log.debug('Sending voice websocket frame: %s.', data)
|
||||
await self.ws.send_str(utils._to_json(data))
|
||||
|
||||
send_heartbeat = send_as_json
|
||||
|
||||
@@ -806,7 +820,7 @@ class DiscordVoiceWebSocket:
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def received_message(self, msg):
|
||||
log.debug('Voice websocket frame received: %s', msg)
|
||||
_log.debug('Voice websocket frame received: %s', msg)
|
||||
op = msg['op']
|
||||
data = msg.get('d')
|
||||
|
||||
@@ -815,7 +829,7 @@ class DiscordVoiceWebSocket:
|
||||
elif op == self.HEARTBEAT_ACK:
|
||||
self._keep_alive.ack()
|
||||
elif op == self.RESUMED:
|
||||
log.info('Voice RESUME succeeded.')
|
||||
_log.info('Voice RESUME succeeded.')
|
||||
elif op == self.SESSION_DESCRIPTION:
|
||||
self._connection.mode = data['mode']
|
||||
await self.load_secret_key(data)
|
||||
@@ -838,7 +852,7 @@ class DiscordVoiceWebSocket:
|
||||
struct.pack_into('>I', packet, 4, state.ssrc)
|
||||
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
|
||||
recv = await self.loop.sock_recv(state.socket, 70)
|
||||
log.debug('received packet in initial_connection: %s', recv)
|
||||
_log.debug('received packet in initial_connection: %s', recv)
|
||||
|
||||
# the ip is ascii starting at the 4th byte and ending at the first null
|
||||
ip_start = 4
|
||||
@@ -846,15 +860,15 @@ class DiscordVoiceWebSocket:
|
||||
state.ip = recv[ip_start:ip_end].decode('ascii')
|
||||
|
||||
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
|
||||
log.debug('detected ip: %s port: %s', state.ip, state.port)
|
||||
_log.debug('detected ip: %s port: %s', state.ip, state.port)
|
||||
|
||||
# there *should* always be at least one supported mode (xsalsa20_poly1305)
|
||||
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
|
||||
log.debug('received supported encryption modes: %s', ", ".join(modes))
|
||||
_log.debug('received supported encryption modes: %s', ", ".join(modes))
|
||||
|
||||
mode = modes[0]
|
||||
await self.select_protocol(state.ip, state.port, mode)
|
||||
log.info('selected the voice protocol for use (%s)', mode)
|
||||
_log.info('selected the voice protocol for use (%s)', mode)
|
||||
|
||||
@property
|
||||
def latency(self):
|
||||
@@ -872,7 +886,7 @@ class DiscordVoiceWebSocket:
|
||||
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
|
||||
|
||||
async def load_secret_key(self, data):
|
||||
log.info('received secret key for voice connection')
|
||||
_log.info('received secret key for voice connection')
|
||||
self.secret_key = self._connection.secret_key = data.get('secret_key')
|
||||
await self.speak()
|
||||
await self.speak(False)
|
||||
@@ -881,12 +895,12 @@ class DiscordVoiceWebSocket:
|
||||
# This exception is handled up the chain
|
||||
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
|
||||
if msg.type is aiohttp.WSMsgType.TEXT:
|
||||
await self.received_message(utils.from_json(msg.data))
|
||||
await self.received_message(utils._from_json(msg.data))
|
||||
elif msg.type is aiohttp.WSMsgType.ERROR:
|
||||
log.debug('Received %s', msg)
|
||||
_log.debug('Received %s', msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
|
||||
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
|
||||
log.debug('Received %s', msg)
|
||||
_log.debug('Received %s', msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
|
||||
|
||||
async def close(self, code=1000):
|
||||
|
||||
269
discord/guild.py
269
discord/guild.py
@@ -25,6 +25,7 @@ DEALINGS IN THE SOFTWARE.
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import unicodedata
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
@@ -51,6 +52,7 @@ from .colour import Colour
|
||||
from .errors import InvalidArgument, ClientException
|
||||
from .channel import *
|
||||
from .channel import _guild_channel_factory
|
||||
from .channel import _threaded_guild_channel_factory
|
||||
from .enums import (
|
||||
AuditLogAction,
|
||||
VideoQualityMode,
|
||||
@@ -71,7 +73,10 @@ from .asset import Asset
|
||||
from .flags import SystemChannelFlags
|
||||
from .integrations import Integration, _integration_factory
|
||||
from .stage_instance import StageInstance
|
||||
from .threads import Thread
|
||||
from .threads import Thread, ThreadMember
|
||||
from .sticker import GuildSticker
|
||||
from .file import File
|
||||
|
||||
|
||||
__all__ = (
|
||||
'Guild',
|
||||
@@ -107,6 +112,7 @@ class BanEntry(NamedTuple):
|
||||
|
||||
class _GuildLimit(NamedTuple):
|
||||
emoji: int
|
||||
stickers: int
|
||||
bitrate: float
|
||||
filesize: int
|
||||
|
||||
@@ -134,12 +140,20 @@ class Guild(Hashable):
|
||||
|
||||
Returns the guild's name.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the guild's ID.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name: :class:`str`
|
||||
The guild name.
|
||||
emojis: Tuple[:class:`Emoji`, ...]
|
||||
All emojis that the guild owns.
|
||||
stickers: Tuple[:class:`GuildSticker`, ...]
|
||||
All stickers that the guild owns.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
region: :class:`VoiceRegion`
|
||||
The region the guild belongs on. There is a chance that the region
|
||||
will be a :class:`str` if the value is not recognised by the enumerator.
|
||||
@@ -234,6 +248,7 @@ class Guild(Hashable):
|
||||
'owner_id',
|
||||
'mfa_level',
|
||||
'emojis',
|
||||
'stickers',
|
||||
'features',
|
||||
'verification_level',
|
||||
'explicit_content_filter',
|
||||
@@ -266,11 +281,11 @@ class Guild(Hashable):
|
||||
)
|
||||
|
||||
_PREMIUM_GUILD_LIMITS: ClassVar[Dict[Optional[int], _GuildLimit]] = {
|
||||
None: _GuildLimit(emoji=50, bitrate=96e3, filesize=8388608),
|
||||
0: _GuildLimit(emoji=50, bitrate=96e3, filesize=8388608),
|
||||
1: _GuildLimit(emoji=100, bitrate=128e3, filesize=8388608),
|
||||
2: _GuildLimit(emoji=150, bitrate=256e3, filesize=52428800),
|
||||
3: _GuildLimit(emoji=250, bitrate=384e3, filesize=104857600),
|
||||
None: _GuildLimit(emoji=50, stickers=0, bitrate=96e3, filesize=8388608),
|
||||
0: _GuildLimit(emoji=50, stickers=0, bitrate=96e3, filesize=8388608),
|
||||
1: _GuildLimit(emoji=100, stickers=15, bitrate=128e3, filesize=8388608),
|
||||
2: _GuildLimit(emoji=150, stickers=30, bitrate=256e3, filesize=52428800),
|
||||
3: _GuildLimit(emoji=250, stickers=60, bitrate=384e3, filesize=104857600),
|
||||
}
|
||||
|
||||
def __init__(self, *, data: GuildPayload, state: ConnectionState):
|
||||
@@ -412,6 +427,9 @@ class Guild(Hashable):
|
||||
|
||||
self.mfa_level: MFALevel = guild.get('mfa_level')
|
||||
self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get('emojis', [])))
|
||||
self.stickers: Tuple[GuildSticker, ...] = tuple(
|
||||
map(lambda d: state.store_sticker(self, d), guild.get('stickers', []))
|
||||
)
|
||||
self.features: List[GuildFeature] = guild.get('features', [])
|
||||
self._splash: Optional[str] = guild.get('splash')
|
||||
self._system_channel_id: Optional[int] = utils._get_as_snowflake(guild, 'system_channel_id')
|
||||
@@ -599,6 +617,23 @@ class Guild(Hashable):
|
||||
|
||||
return self._channels.get(id) or self._threads.get(id)
|
||||
|
||||
def get_channel_or_thread(self, channel_id: int, /) -> Optional[Union[Thread, GuildChannel]]:
|
||||
"""Returns a channel or thread with the given ID.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
channel_id: :class:`int`
|
||||
The ID to search for.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[Union[:class:`Thread`, :class:`.abc.GuildChannel`]]
|
||||
The returned channel or thread or ``None`` if not found.
|
||||
"""
|
||||
return self._channels.get(channel_id) or self._threads.get(channel_id)
|
||||
|
||||
def get_channel(self, channel_id: int, /) -> Optional[GuildChannel]:
|
||||
"""Returns a channel with the given ID.
|
||||
|
||||
@@ -680,6 +715,15 @@ class Guild(Hashable):
|
||||
more_emoji = 200 if 'MORE_EMOJI' in self.features else 50
|
||||
return max(more_emoji, self._PREMIUM_GUILD_LIMITS[self.premium_tier].emoji)
|
||||
|
||||
@property
|
||||
def sticker_limit(self) -> int:
|
||||
""":class:`int`: The maximum number of sticker slots this guild has.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
more_stickers = 60 if 'MORE_STICKERS' in self.features else 0
|
||||
return max(more_stickers, self._PREMIUM_GUILD_LIMITS[self.premium_tier].stickers)
|
||||
|
||||
@property
|
||||
def bitrate_limit(self) -> float:
|
||||
""":class:`float`: The maximum bitrate for voice channels this guild can have."""
|
||||
@@ -696,7 +740,21 @@ class Guild(Hashable):
|
||||
"""List[:class:`Member`]: A list of members that belong to this guild."""
|
||||
return list(self._members.values())
|
||||
|
||||
def get_member(self, user_id: int) -> Optional[Member]:
|
||||
@property
|
||||
def humans(self) -> List[Member]:
|
||||
"""List[:class:`Member`]: A list of human members that belong to this guild.
|
||||
|
||||
.. versionadded:: 2.0 """
|
||||
return [member for member in self.members if not member.bot]
|
||||
|
||||
@property
|
||||
def bots(self) -> List[Member]:
|
||||
"""List[:class:`Member`]: A list of bots that belong to this guild.
|
||||
|
||||
.. versionadded:: 2.0 """
|
||||
return [member for member in self.members if member.bot]
|
||||
|
||||
def get_member(self, user_id: int, /) -> Optional[Member]:
|
||||
"""Returns a member with the given ID.
|
||||
|
||||
Parameters
|
||||
@@ -1316,7 +1374,7 @@ class Guild(Hashable):
|
||||
preferred_locale: str = MISSING,
|
||||
rules_channel: Optional[TextChannel] = MISSING,
|
||||
public_updates_channel: Optional[TextChannel] = MISSING,
|
||||
) -> None:
|
||||
) -> Guild:
|
||||
r"""|coro|
|
||||
|
||||
Edits the guild.
|
||||
@@ -1330,6 +1388,9 @@ class Guild(Hashable):
|
||||
.. versionchanged:: 2.0
|
||||
The `discovery_splash` and `community` keyword-only parameters were added.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The newly updated guild is returned.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: :class:`str`
|
||||
@@ -1403,6 +1464,12 @@ class Guild(Hashable):
|
||||
The image format passed in to ``icon`` is invalid. It must be
|
||||
PNG or JPG. This is also raised if you are not the owner of the
|
||||
guild and request an ownership transfer.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Guild`
|
||||
The newly updated guild. Note that this has the same limitations as
|
||||
mentioned in :meth:`Client.fetch_guild` and may not have full data.
|
||||
"""
|
||||
|
||||
http = self._state.http
|
||||
@@ -1515,7 +1582,8 @@ class Guild(Hashable):
|
||||
|
||||
fields['features'] = features
|
||||
|
||||
await http.edit_guild(self.id, reason=reason, **fields)
|
||||
data = await http.edit_guild(self.id, reason=reason, **fields)
|
||||
return Guild(data=data, state=self._state)
|
||||
|
||||
async def fetch_channels(self) -> Sequence[GuildChannel]:
|
||||
"""|coro|
|
||||
@@ -1552,6 +1620,35 @@ class Guild(Hashable):
|
||||
|
||||
return [convert(d) for d in data]
|
||||
|
||||
async def active_threads(self) -> List[Thread]:
|
||||
"""|coro|
|
||||
|
||||
Returns a list of active :class:`Thread` that the client can access.
|
||||
|
||||
This includes both private and public threads.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Raises
|
||||
------
|
||||
HTTPException
|
||||
The request to get the active threads failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
List[:class:`Thread`]
|
||||
The active threads
|
||||
"""
|
||||
data = await self._state.http.get_active_threads(self.id)
|
||||
threads = [Thread(guild=self, state=self._state, data=d) for d in data.get('threads', [])]
|
||||
thread_lookup: Dict[int, Thread] = {thread.id: thread for thread in threads}
|
||||
for member in data.get('members', []):
|
||||
thread = thread_lookup.get(int(member['id']))
|
||||
if thread is not None:
|
||||
thread._add_member(ThreadMember(parent=thread, data=member))
|
||||
|
||||
return threads
|
||||
|
||||
# TODO: Remove Optional typing here when async iterators are refactored
|
||||
def fetch_members(self, *, limit: int = 1000, after: Optional[SnowflakeTime] = None) -> MemberIterator:
|
||||
"""Retrieves an :class:`.AsyncIterator` that enables receiving the guild's members. In order to use this,
|
||||
@@ -1665,14 +1762,14 @@ class Guild(Hashable):
|
||||
data: BanPayload = await self._state.http.get_ban(user.id, self.id)
|
||||
return BanEntry(user=User(state=self._state, data=data['user']), reason=data['reason'])
|
||||
|
||||
async def fetch_channel(self, channel_id: int, /) -> GuildChannel:
|
||||
async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread]:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a :class:`.abc.GuildChannel` with the specified ID.
|
||||
Retrieves a :class:`.abc.GuildChannel` or :class:`.Thread` with the specified ID.
|
||||
|
||||
.. note::
|
||||
|
||||
This method is an API call. For general usage, consider :meth:`get_channel` instead.
|
||||
This method is an API call. For general usage, consider :meth:`get_channel_or_thread` instead.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
@@ -1691,12 +1788,12 @@ class Guild(Hashable):
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`.abc.GuildChannel`
|
||||
Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
|
||||
The channel from the ID.
|
||||
"""
|
||||
data = await self._state.http.get_channel(channel_id)
|
||||
|
||||
factory, ch_type = _guild_channel_factory(data['type'])
|
||||
factory, ch_type = _threaded_guild_channel_factory(data['type'])
|
||||
if factory is None:
|
||||
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))
|
||||
|
||||
@@ -2009,6 +2106,150 @@ class Guild(Hashable):
|
||||
|
||||
return [convert(d) for d in data]
|
||||
|
||||
async def fetch_stickers(self) -> List[GuildSticker]:
|
||||
r"""|coro|
|
||||
|
||||
Retrieves a list of all :class:`Sticker`\s for the guild.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. note::
|
||||
|
||||
This method is an API call. For general usage, consider :attr:`stickers` instead.
|
||||
|
||||
Raises
|
||||
---------
|
||||
HTTPException
|
||||
An error occurred fetching the stickers.
|
||||
|
||||
Returns
|
||||
--------
|
||||
List[:class:`GuildSticker`]
|
||||
The retrieved stickers.
|
||||
"""
|
||||
data = await self._state.http.get_all_guild_stickers(self.id)
|
||||
return [GuildSticker(state=self._state, data=d) for d in data]
|
||||
|
||||
async def fetch_sticker(self, sticker_id: int, /) -> GuildSticker:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a custom :class:`Sticker` from the guild.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. note::
|
||||
|
||||
This method is an API call.
|
||||
For general usage, consider iterating over :attr:`stickers` instead.
|
||||
|
||||
Parameters
|
||||
-------------
|
||||
sticker_id: :class:`int`
|
||||
The sticker's ID.
|
||||
|
||||
Raises
|
||||
---------
|
||||
NotFound
|
||||
The sticker requested could not be found.
|
||||
HTTPException
|
||||
An error occurred fetching the sticker.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`GuildSticker`
|
||||
The retrieved sticker.
|
||||
"""
|
||||
data = await self._state.http.get_guild_sticker(self.id, sticker_id)
|
||||
return GuildSticker(state=self._state, data=data)
|
||||
|
||||
async def create_sticker(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
description: Optional[str] = None,
|
||||
emoji: str,
|
||||
file: File,
|
||||
reason: Optional[str] = None,
|
||||
) -> GuildSticker:
|
||||
"""|coro|
|
||||
|
||||
Creates a :class:`Sticker` for the guild.
|
||||
|
||||
You must have :attr:`~Permissions.manage_emojis_and_stickers` permission to
|
||||
do this.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
name: :class:`str`
|
||||
The sticker name. Must be at least 2 characters.
|
||||
description: Optional[:class:`str`]
|
||||
The sticker's description. Can be ``None``.
|
||||
emoji: :class:`str`
|
||||
The name of a unicode emoji that represents the sticker's expression.
|
||||
file: :class:`File`
|
||||
The file of the sticker to upload.
|
||||
reason: :class:`str`
|
||||
The reason for creating this sticker. Shows up on the audit log.
|
||||
|
||||
Raises
|
||||
-------
|
||||
Forbidden
|
||||
You are not allowed to create stickers.
|
||||
HTTPException
|
||||
An error occurred creating a sticker.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`GuildSticker`
|
||||
The created sticker.
|
||||
"""
|
||||
payload = {
|
||||
'name': name,
|
||||
}
|
||||
|
||||
if description:
|
||||
payload['description'] = description
|
||||
|
||||
try:
|
||||
emoji = unicodedata.name(emoji)
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
emoji = emoji.replace(' ', '_')
|
||||
|
||||
payload['tags'] = emoji
|
||||
|
||||
data = await self._state.http.create_guild_sticker(self.id, payload, file, reason)
|
||||
return self._state.store_sticker(self, data)
|
||||
|
||||
async def delete_sticker(self, sticker: Snowflake, *, reason: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
Deletes the custom :class:`Sticker` from the guild.
|
||||
|
||||
You must have :attr:`~Permissions.manage_emojis_and_stickers` permission to
|
||||
do this.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
sticker: :class:`abc.Snowflake`
|
||||
The sticker you are deleting.
|
||||
reason: Optional[:class:`str`]
|
||||
The reason for deleting this sticker. Shows up on the audit log.
|
||||
|
||||
Raises
|
||||
-------
|
||||
Forbidden
|
||||
You are not allowed to delete stickers.
|
||||
HTTPException
|
||||
An error occurred deleting the sticker.
|
||||
"""
|
||||
await self._state.http.delete_guild_sticker(self.id, sticker.id, reason)
|
||||
|
||||
async def fetch_emojis(self) -> List[Emoji]:
|
||||
r"""|coro|
|
||||
|
||||
|
||||
246
discord/http.py
246
discord/http.py
@@ -33,7 +33,6 @@ from typing import (
|
||||
ClassVar,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Final,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
@@ -49,12 +48,12 @@ import weakref
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordServerError, GatewayNotFound
|
||||
from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordServerError, GatewayNotFound, InvalidArgument
|
||||
from .gateway import DiscordClientWebSocketResponse
|
||||
from . import __version__, utils
|
||||
from .utils import MISSING
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .file import File
|
||||
@@ -84,6 +83,7 @@ if TYPE_CHECKING:
|
||||
widget,
|
||||
threads,
|
||||
voice,
|
||||
sticker,
|
||||
)
|
||||
from .types.snowflake import Snowflake, SnowflakeList
|
||||
|
||||
@@ -99,7 +99,7 @@ async def json_or_text(response: aiohttp.ClientResponse) -> Union[Dict[str, Any]
|
||||
text = await response.text(encoding='utf-8')
|
||||
try:
|
||||
if response.headers['content-type'] == 'application/json':
|
||||
return utils.from_json(text)
|
||||
return utils._from_json(text)
|
||||
except KeyError:
|
||||
# Thanks Cloudflare
|
||||
pass
|
||||
@@ -141,7 +141,8 @@ class MaybeUnlock:
|
||||
def defer(self) -> None:
|
||||
self._unlock = False
|
||||
|
||||
def __exit__(self,
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BE]],
|
||||
exc: Optional[BE],
|
||||
traceback: Optional[TracebackType],
|
||||
@@ -152,15 +153,12 @@ class MaybeUnlock:
|
||||
|
||||
# For some reason, the Discord voice websocket expects this header to be
|
||||
# completely lowercase while aiohttp respects spec and does it as case-insensitive
|
||||
aiohttp.hdrs.WEBSOCKET = 'websocket' #type: ignore
|
||||
aiohttp.hdrs.WEBSOCKET = 'websocket' # type: ignore
|
||||
|
||||
|
||||
class HTTPClient:
|
||||
"""Represents an HTTP client sending HTTP requests to the Discord API."""
|
||||
|
||||
SUCCESS_LOG: Final[ClassVar[str]] = '{method} {url} has received {text}'
|
||||
REQUEST_LOG: Final[ClassVar[str]] = '{method} {url} with {json} has returned {status}'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: Optional[aiohttp.BaseConnector] = None,
|
||||
@@ -168,7 +166,7 @@ class HTTPClient:
|
||||
proxy: Optional[str] = None,
|
||||
proxy_auth: Optional[aiohttp.BasicAuth] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
unsync_clock: bool = True
|
||||
unsync_clock: bool = True,
|
||||
) -> None:
|
||||
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
|
||||
self.connector = connector
|
||||
@@ -212,7 +210,7 @@ class HTTPClient:
|
||||
*,
|
||||
files: Optional[Sequence[File]] = None,
|
||||
form: Optional[Iterable[Dict[str, Any]]] = None,
|
||||
**kwargs: Any
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
bucket = route.bucket
|
||||
method = route.method
|
||||
@@ -234,7 +232,7 @@ class HTTPClient:
|
||||
# some checking if it's a JSON request
|
||||
if 'json' in kwargs:
|
||||
headers['Content-Type'] = 'application/json'
|
||||
kwargs['data'] = utils.to_json(kwargs.pop('json'))
|
||||
kwargs['data'] = utils._to_json(kwargs.pop('json'))
|
||||
|
||||
try:
|
||||
reason = kwargs.pop('reason')
|
||||
@@ -273,7 +271,7 @@ class HTTPClient:
|
||||
|
||||
try:
|
||||
async with self.__session.request(method, url, **kwargs) as response:
|
||||
log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), response.status)
|
||||
_log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), response.status)
|
||||
|
||||
# even errors have text involved in them so this is safe to call
|
||||
data = await json_or_text(response)
|
||||
@@ -283,13 +281,13 @@ class HTTPClient:
|
||||
if remaining == '0' and response.status != 429:
|
||||
# we've depleted our current bucket
|
||||
delta = utils._parse_ratelimit_header(response, use_clock=self.use_clock)
|
||||
log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta)
|
||||
_log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta)
|
||||
maybe_lock.defer()
|
||||
self.loop.call_later(delta, lock.release)
|
||||
|
||||
# the request was successful so just return the text/json
|
||||
if 300 > response.status >= 200:
|
||||
log.debug('%s %s has received %s', method, url, data)
|
||||
_log.debug('%s %s has received %s', method, url, data)
|
||||
return data
|
||||
|
||||
# we are being rate limited
|
||||
@@ -302,22 +300,22 @@ class HTTPClient:
|
||||
|
||||
# sleep a bit
|
||||
retry_after: float = data['retry_after']
|
||||
log.warning(fmt, retry_after, bucket)
|
||||
_log.warning(fmt, retry_after, bucket)
|
||||
|
||||
# check if it's a global rate limit
|
||||
is_global = data.get('global', False)
|
||||
if is_global:
|
||||
log.warning('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after)
|
||||
_log.warning('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after)
|
||||
self._global_over.clear()
|
||||
|
||||
await asyncio.sleep(retry_after)
|
||||
log.debug('Done sleeping for the rate limit. Retrying...')
|
||||
_log.debug('Done sleeping for the rate limit. Retrying...')
|
||||
|
||||
# release the global lock now that the
|
||||
# global rate limit has passed
|
||||
if is_global:
|
||||
self._global_over.set()
|
||||
log.debug('Global rate limit is now over.')
|
||||
_log.debug('Global rate limit is now over.')
|
||||
|
||||
continue
|
||||
|
||||
@@ -415,14 +413,15 @@ class HTTPClient:
|
||||
def send_message(
|
||||
self,
|
||||
channel_id: Snowflake,
|
||||
content: str,
|
||||
content: Optional[str],
|
||||
*,
|
||||
tts: bool = False,
|
||||
embed: Optional[embed.Embed] = None,
|
||||
embeds: Optional[List[embed.Embed]] = None,
|
||||
nonce: Optional[str] = None,
|
||||
nonce: Optional[str] = None,
|
||||
allowed_mentions: Optional[message.AllowedMentions] = None,
|
||||
message_reference: Optional[message.MessageReference] = None,
|
||||
stickers: Optional[List[sticker.StickerItem]] = None,
|
||||
components: Optional[List[components.Component]] = None,
|
||||
) -> Response[message.Message]:
|
||||
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
|
||||
@@ -436,7 +435,7 @@ class HTTPClient:
|
||||
|
||||
if embed:
|
||||
payload['embeds'] = [embed]
|
||||
|
||||
|
||||
if embeds:
|
||||
payload['embeds'] = embeds
|
||||
|
||||
@@ -452,6 +451,9 @@ class HTTPClient:
|
||||
if components:
|
||||
payload['components'] = components
|
||||
|
||||
if stickers:
|
||||
payload['sticker_ids'] = stickers
|
||||
|
||||
return self.request(r, json=payload)
|
||||
|
||||
def send_typing(self, channel_id: Snowflake) -> Response[None]:
|
||||
@@ -465,10 +467,11 @@ class HTTPClient:
|
||||
content: Optional[str] = None,
|
||||
tts: bool = False,
|
||||
embed: Optional[embed.Embed] = None,
|
||||
embeds: Iterable[Optional[embed.Embed]] = None,
|
||||
embeds: Optional[Iterable[Optional[embed.Embed]]] = None,
|
||||
nonce: Optional[str] = None,
|
||||
allowed_mentions: Optional[message.AllowedMentions] = None,
|
||||
message_reference: Optional[message.MessageReference] = None,
|
||||
stickers: Optional[List[sticker.StickerItem]] = None,
|
||||
components: Optional[List[components.Component]] = None,
|
||||
) -> Response[message.Message]:
|
||||
form = []
|
||||
@@ -488,8 +491,10 @@ class HTTPClient:
|
||||
payload['message_reference'] = message_reference
|
||||
if components:
|
||||
payload['components'] = components
|
||||
if stickers:
|
||||
payload['sticker_ids'] = stickers
|
||||
|
||||
form.append({'name': 'payload_json', 'value': utils.to_json(payload)})
|
||||
form.append({'name': 'payload_json', 'value': utils._to_json(payload)})
|
||||
if len(files) == 1:
|
||||
file = files[0]
|
||||
form.append(
|
||||
@@ -525,6 +530,7 @@ class HTTPClient:
|
||||
nonce: Optional[str] = None,
|
||||
allowed_mentions: Optional[message.AllowedMentions] = None,
|
||||
message_reference: Optional[message.MessageReference] = None,
|
||||
stickers: Optional[List[sticker.StickerItem]] = None,
|
||||
components: Optional[List[components.Component]] = None,
|
||||
) -> Response[message.Message]:
|
||||
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
|
||||
@@ -538,14 +544,19 @@ class HTTPClient:
|
||||
nonce=nonce,
|
||||
allowed_mentions=allowed_mentions,
|
||||
message_reference=message_reference,
|
||||
stickers=stickers,
|
||||
components=components,
|
||||
)
|
||||
|
||||
def delete_message(self, channel_id: Snowflake, message_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]:
|
||||
def delete_message(
|
||||
self, channel_id: Snowflake, message_id: Snowflake, *, reason: Optional[str] = None
|
||||
) -> Response[None]:
|
||||
r = Route('DELETE', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id)
|
||||
return self.request(r, reason=reason)
|
||||
|
||||
def delete_messages(self, channel_id: Snowflake, message_ids: SnowflakeList, *, reason: Optional[str] = None) -> Response[None]:
|
||||
def delete_messages(
|
||||
self, channel_id: Snowflake, message_ids: SnowflakeList, *, reason: Optional[str] = None
|
||||
) -> Response[None]:
|
||||
r = Route('POST', '/channels/{channel_id}/messages/bulk-delete', channel_id=channel_id)
|
||||
payload = {
|
||||
'messages': message_ids,
|
||||
@@ -567,7 +578,9 @@ class HTTPClient:
|
||||
)
|
||||
return self.request(r)
|
||||
|
||||
def remove_reaction(self, channel_id: Snowflake, message_id: Snowflake, emoji: str, member_id: Snowflake) -> Response[None]:
|
||||
def remove_reaction(
|
||||
self, channel_id: Snowflake, message_id: Snowflake, emoji: str, member_id: Snowflake
|
||||
) -> Response[None]:
|
||||
r = Route(
|
||||
'DELETE',
|
||||
'/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/{member_id}',
|
||||
@@ -713,11 +726,7 @@ class HTTPClient:
|
||||
'delete_message_days': delete_message_days,
|
||||
}
|
||||
|
||||
if reason:
|
||||
# thanks aiohttp
|
||||
r.url = f'{r.url}?reason={_uriquote(reason)}'
|
||||
|
||||
return self.request(r, params=params)
|
||||
return self.request(r, params=params, reason=reason)
|
||||
|
||||
def unban(self, user_id: Snowflake, guild_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]:
|
||||
r = Route('DELETE', '/guilds/{guild_id}/bans/{user_id}', guild_id=guild_id, user_id=user_id)
|
||||
@@ -772,11 +781,11 @@ class HTTPClient:
|
||||
}
|
||||
return self.request(r, json=payload, reason=reason)
|
||||
|
||||
def edit_my_voice_state(self, guild_id: Snowflake, payload: voice.VoiceState) -> Response[None]:
|
||||
def edit_my_voice_state(self, guild_id: Snowflake, payload: Dict[str, Any]) -> Response[None]:
|
||||
r = Route('PATCH', '/guilds/{guild_id}/voice-states/@me', guild_id=guild_id)
|
||||
return self.request(r, json=payload)
|
||||
|
||||
def edit_voice_state(self, guild_id: Snowflake, user_id: Snowflake, payload: voice.VoiceState) -> Response[None]:
|
||||
def edit_voice_state(self, guild_id: Snowflake, user_id: Snowflake, payload: Dict[str, Any]) -> Response[None]:
|
||||
r = Route('PATCH', '/guilds/{guild_id}/voice-states/{user_id}', guild_id=guild_id, user_id=user_id)
|
||||
return self.request(r, json=payload)
|
||||
|
||||
@@ -787,7 +796,7 @@ class HTTPClient:
|
||||
*,
|
||||
reason: Optional[str] = None,
|
||||
**fields: Any,
|
||||
) -> Response[member.Member]:
|
||||
) -> Response[member.MemberWithUser]:
|
||||
r = Route('PATCH', '/guilds/{guild_id}/members/{user_id}', guild_id=guild_id, user_id=user_id)
|
||||
return self.request(r, json=fields, reason=reason)
|
||||
|
||||
@@ -817,6 +826,8 @@ class HTTPClient:
|
||||
'archived',
|
||||
'auto_archive_duration',
|
||||
'locked',
|
||||
'invitable',
|
||||
'default_auto_archive_duration',
|
||||
)
|
||||
payload = {k: v for k, v in options.items() if k in valid_keys}
|
||||
return self.request(r, reason=reason, json=payload)
|
||||
@@ -871,42 +882,44 @@ class HTTPClient:
|
||||
|
||||
# Thread management
|
||||
|
||||
def start_public_thread(
|
||||
def start_thread_with_message(
|
||||
self,
|
||||
channel_id: Snowflake,
|
||||
message_id: Snowflake,
|
||||
*,
|
||||
name: str,
|
||||
auto_archive_duration: threads.ThreadArchiveDuration,
|
||||
type: threads.ThreadType,
|
||||
reason: Optional[str] = None,
|
||||
) -> Response[threads.Thread]:
|
||||
payload = {
|
||||
'name': name,
|
||||
'auto_archive_duration': auto_archive_duration,
|
||||
'type': type,
|
||||
}
|
||||
|
||||
route = Route(
|
||||
'POST', '/channels/{channel_id}/messages/{message_id}/threads', channel_id=channel_id, message_id=message_id
|
||||
)
|
||||
return self.request(route, json=payload)
|
||||
return self.request(route, json=payload, reason=reason)
|
||||
|
||||
def start_private_thread(
|
||||
def start_thread_without_message(
|
||||
self,
|
||||
channel_id: Snowflake,
|
||||
*,
|
||||
name: str,
|
||||
auto_archive_duration: threads.ThreadArchiveDuration,
|
||||
type: threads.ThreadType,
|
||||
invitable: bool = True,
|
||||
reason: Optional[str] = None,
|
||||
) -> Response[threads.Thread]:
|
||||
payload = {
|
||||
'name': name,
|
||||
'auto_archive_duration': auto_archive_duration,
|
||||
'type': type,
|
||||
'invitable': invitable,
|
||||
}
|
||||
|
||||
route = Route('POST', '/channels/{channel_id}/threads', channel_id=channel_id)
|
||||
return self.request(route, json=payload)
|
||||
return self.request(route, json=payload, reason=reason)
|
||||
|
||||
def join_thread(self, channel_id: Snowflake) -> Response[None]:
|
||||
return self.request(Route('POST', '/channels/{channel_id}/thread-members/@me', channel_id=channel_id))
|
||||
@@ -955,8 +968,8 @@ class HTTPClient:
|
||||
params['limit'] = limit
|
||||
return self.request(route, params=params)
|
||||
|
||||
def get_active_threads(self, channel_id: Snowflake) -> Response[threads.ThreadPaginationPayload]:
|
||||
route = Route('GET', '/channels/{channel_id}/threads/active', channel_id=channel_id)
|
||||
def get_active_threads(self, guild_id: Snowflake) -> Response[threads.ThreadPaginationPayload]:
|
||||
route = Route('GET', '/guilds/{guild_id}/threads/active', guild_id=guild_id)
|
||||
return self.request(route)
|
||||
|
||||
def get_thread_members(self, channel_id: Snowflake) -> Response[List[threads.ThreadMember]]:
|
||||
@@ -1119,7 +1132,9 @@ class HTTPClient:
|
||||
def get_all_guild_channels(self, guild_id: Snowflake) -> Response[List[guild.GuildChannel]]:
|
||||
return self.request(Route('GET', '/guilds/{guild_id}/channels', guild_id=guild_id))
|
||||
|
||||
def get_members(self, guild_id: Snowflake, limit: int, after: Optional[Snowflake]) -> Response[List[member.Member]]:
|
||||
def get_members(
|
||||
self, guild_id: Snowflake, limit: int, after: Optional[Snowflake]
|
||||
) -> Response[List[member.MemberWithUser]]:
|
||||
params: Dict[str, Any] = {
|
||||
'limit': limit,
|
||||
}
|
||||
@@ -1129,7 +1144,7 @@ class HTTPClient:
|
||||
r = Route('GET', '/guilds/{guild_id}/members', guild_id=guild_id)
|
||||
return self.request(r, params=params)
|
||||
|
||||
def get_member(self, guild_id: Snowflake, member_id: Snowflake) -> Response[member.Member]:
|
||||
def get_member(self, guild_id: Snowflake, member_id: Snowflake) -> Response[member.MemberWithUser]:
|
||||
return self.request(Route('GET', '/guilds/{guild_id}/members/{member_id}', guild_id=guild_id, member_id=member_id))
|
||||
|
||||
def prune_members(
|
||||
@@ -1164,6 +1179,71 @@ class HTTPClient:
|
||||
|
||||
return self.request(Route('GET', '/guilds/{guild_id}/prune', guild_id=guild_id), params=params)
|
||||
|
||||
def get_sticker(self, sticker_id: Snowflake) -> Response[sticker.Sticker]:
|
||||
return self.request(Route('GET', '/stickers/{sticker_id}', sticker_id=sticker_id))
|
||||
|
||||
def list_premium_sticker_packs(self) -> Response[sticker.ListPremiumStickerPacks]:
|
||||
return self.request(Route('GET', '/sticker-packs'))
|
||||
|
||||
def get_all_guild_stickers(self, guild_id: Snowflake) -> Response[List[sticker.GuildSticker]]:
|
||||
return self.request(Route('GET', '/guilds/{guild_id}/stickers', guild_id=guild_id))
|
||||
|
||||
def get_guild_sticker(self, guild_id: Snowflake, sticker_id: Snowflake) -> Response[sticker.GuildSticker]:
|
||||
return self.request(
|
||||
Route('GET', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id)
|
||||
)
|
||||
|
||||
def create_guild_sticker(
|
||||
self, guild_id: Snowflake, payload: sticker.CreateGuildSticker, file: File, reason: str
|
||||
) -> Response[sticker.GuildSticker]:
|
||||
initial_bytes = file.fp.read(16)
|
||||
|
||||
try:
|
||||
mime_type = utils._get_mime_type_for_image(initial_bytes)
|
||||
except InvalidArgument:
|
||||
if initial_bytes.startswith(b'{'):
|
||||
mime_type = 'application/json'
|
||||
else:
|
||||
mime_type = 'application/octet-stream'
|
||||
finally:
|
||||
file.reset()
|
||||
|
||||
form: List[Dict[str, Any]] = [
|
||||
{
|
||||
'name': 'file',
|
||||
'value': file.fp,
|
||||
'filename': file.filename,
|
||||
'content_type': mime_type,
|
||||
}
|
||||
]
|
||||
|
||||
for k, v in payload.items():
|
||||
form.append(
|
||||
{
|
||||
'name': k,
|
||||
'value': v,
|
||||
}
|
||||
)
|
||||
|
||||
return self.request(
|
||||
Route('POST', '/guilds/{guild_id}/stickers', guild_id=guild_id), form=form, files=[file], reason=reason
|
||||
)
|
||||
|
||||
def modify_guild_sticker(
|
||||
self, guild_id: Snowflake, sticker_id: Snowflake, payload: sticker.EditGuildSticker, reason: Optional[str],
|
||||
) -> Response[sticker.GuildSticker]:
|
||||
return self.request(
|
||||
Route('PATCH', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id),
|
||||
json=payload,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
def delete_guild_sticker(self, guild_id: Snowflake, sticker_id: Snowflake, reason: Optional[str]) -> Response[None]:
|
||||
return self.request(
|
||||
Route('DELETE', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id),
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
def get_all_custom_emojis(self, guild_id: Snowflake) -> Response[List[emoji.Emoji]]:
|
||||
return self.request(Route('GET', '/guilds/{guild_id}/emojis', guild_id=guild_id))
|
||||
|
||||
@@ -1237,12 +1317,14 @@ class HTTPClient:
|
||||
|
||||
return self.request(r)
|
||||
|
||||
def delete_integration(self, guild_id: Snowflake, integration_id: Snowflake) -> Response[None]:
|
||||
def delete_integration(
|
||||
self, guild_id: Snowflake, integration_id: Snowflake, *, reason: Optional[str] = None
|
||||
) -> Response[None]:
|
||||
r = Route(
|
||||
'DELETE', '/guilds/{guild_id}/integrations/{integration_id}', guild_id=guild_id, integration_id=integration_id
|
||||
)
|
||||
|
||||
return self.request(r)
|
||||
return self.request(r, reason=reason)
|
||||
|
||||
def get_audit_logs(
|
||||
self,
|
||||
@@ -1285,7 +1367,7 @@ class HTTPClient:
|
||||
unique: bool = True,
|
||||
target_type: Optional[invite.InviteTargetType] = None,
|
||||
target_user_id: Optional[Snowflake] = None,
|
||||
target_application_id: Optional[Snowflake] = None
|
||||
target_application_id: Optional[Snowflake] = None,
|
||||
) -> Response[invite.Invite]:
|
||||
r = Route('POST', '/channels/{channel_id}/invites', channel_id=channel_id)
|
||||
payload = {
|
||||
@@ -1306,7 +1388,9 @@ class HTTPClient:
|
||||
|
||||
return self.request(r, reason=reason, json=payload)
|
||||
|
||||
def get_invite(self, invite_id: str, *, with_counts: bool = True, with_expiration: bool = True) -> Response[invite.Invite]:
|
||||
def get_invite(
|
||||
self, invite_id: str, *, with_counts: bool = True, with_expiration: bool = True
|
||||
) -> Response[invite.Invite]:
|
||||
params = {
|
||||
'with_counts': int(with_counts),
|
||||
'with_expiration': int(with_expiration),
|
||||
@@ -1327,7 +1411,9 @@ class HTTPClient:
|
||||
def get_roles(self, guild_id: Snowflake) -> Response[List[role.Role]]:
|
||||
return self.request(Route('GET', '/guilds/{guild_id}/roles', guild_id=guild_id))
|
||||
|
||||
def edit_role(self, guild_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None, **fields: Any) -> Response[role.Role]:
|
||||
def edit_role(
|
||||
self, guild_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None, **fields: Any
|
||||
) -> Response[role.Role]:
|
||||
r = Route('PATCH', '/guilds/{guild_id}/roles/{role_id}', guild_id=guild_id, role_id=role_id)
|
||||
valid_keys = ('name', 'permissions', 'color', 'hoist', 'mentionable')
|
||||
payload = {k: v for k, v in fields.items() if k in valid_keys}
|
||||
@@ -1344,7 +1430,7 @@ class HTTPClient:
|
||||
role_ids: List[int],
|
||||
*,
|
||||
reason: Optional[str] = None,
|
||||
) -> Response[member.Member]:
|
||||
) -> Response[member.MemberWithUser]:
|
||||
return self.edit_member(guild_id=guild_id, user_id=user_id, roles=role_ids, reason=reason)
|
||||
|
||||
def create_role(self, guild_id: Snowflake, *, reason: Optional[str] = None, **fields: Any) -> Response[role.Role]:
|
||||
@@ -1361,7 +1447,9 @@ class HTTPClient:
|
||||
r = Route('PATCH', '/guilds/{guild_id}/roles', guild_id=guild_id)
|
||||
return self.request(r, json=positions, reason=reason)
|
||||
|
||||
def add_role(self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]:
|
||||
def add_role(
|
||||
self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None
|
||||
) -> Response[None]:
|
||||
r = Route(
|
||||
'PUT',
|
||||
'/guilds/{guild_id}/members/{user_id}/roles/{role_id}',
|
||||
@@ -1371,7 +1459,9 @@ class HTTPClient:
|
||||
)
|
||||
return self.request(r, reason=reason)
|
||||
|
||||
def remove_role(self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]:
|
||||
def remove_role(
|
||||
self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None
|
||||
) -> Response[None]:
|
||||
r = Route(
|
||||
'DELETE',
|
||||
'/guilds/{guild_id}/members/{user_id}/roles/{role_id}',
|
||||
@@ -1396,11 +1486,7 @@ class HTTPClient:
|
||||
return self.request(r, json=payload, reason=reason)
|
||||
|
||||
def delete_channel_permissions(
|
||||
self,
|
||||
channel_id: Snowflake,
|
||||
target: channel.OverwriteType,
|
||||
*,
|
||||
reason: Optional[str] = None
|
||||
self, channel_id: Snowflake, target: channel.OverwriteType, *, reason: Optional[str] = None
|
||||
) -> Response[None]:
|
||||
r = Route('DELETE', '/channels/{channel_id}/permissions/{target}', channel_id=channel_id, target=target)
|
||||
return self.request(r, reason=reason)
|
||||
@@ -1414,7 +1500,7 @@ class HTTPClient:
|
||||
channel_id: Snowflake,
|
||||
*,
|
||||
reason: Optional[str] = None,
|
||||
) -> Response[member.Member]:
|
||||
) -> Response[member.MemberWithUser]:
|
||||
return self.edit_member(guild_id=guild_id, user_id=user_id, channel_id=channel_id, reason=reason)
|
||||
|
||||
# Stage instance management
|
||||
@@ -1422,7 +1508,7 @@ class HTTPClient:
|
||||
def get_stage_instance(self, channel_id: Snowflake) -> Response[channel.StageInstance]:
|
||||
return self.request(Route('GET', '/stage-instances/{channel_id}', channel_id=channel_id))
|
||||
|
||||
def create_stage_instance(self, **payload) -> Response[channel.StageInstance]:
|
||||
def create_stage_instance(self, *, reason: Optional[str], **payload: Any) -> Response[channel.StageInstance]:
|
||||
valid_keys = (
|
||||
'channel_id',
|
||||
'topic',
|
||||
@@ -1430,26 +1516,30 @@ class HTTPClient:
|
||||
)
|
||||
payload = {k: v for k, v in payload.items() if k in valid_keys}
|
||||
|
||||
return self.request(Route('POST', '/stage-instances'), json=payload)
|
||||
return self.request(Route('POST', '/stage-instances'), json=payload, reason=reason)
|
||||
|
||||
def edit_stage_instance(self, channel_id: Snowflake, **payload) -> Response[None]:
|
||||
def edit_stage_instance(self, channel_id: Snowflake, *, reason: Optional[str] = None, **payload: Any) -> Response[None]:
|
||||
valid_keys = (
|
||||
'topic',
|
||||
'privacy_level',
|
||||
)
|
||||
payload = {k: v for k, v in payload.items() if k in valid_keys}
|
||||
|
||||
return self.request(Route('PATCH', '/stage-instances/{channel_id}', channel_id=channel_id), json=payload)
|
||||
return self.request(
|
||||
Route('PATCH', '/stage-instances/{channel_id}', channel_id=channel_id), json=payload, reason=reason
|
||||
)
|
||||
|
||||
def delete_stage_instance(self, channel_id: Snowflake) -> Response[None]:
|
||||
return self.request(Route('DELETE', '/stage-instances/{channel_id}', channel_id=channel_id))
|
||||
def delete_stage_instance(self, channel_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]:
|
||||
return self.request(Route('DELETE', '/stage-instances/{channel_id}', channel_id=channel_id), reason=reason)
|
||||
|
||||
# Application commands (global)
|
||||
|
||||
def get_global_commands(self, application_id: Snowflake) -> Response[List[interactions.ApplicationCommand]]:
|
||||
return self.request(Route('GET', '/applications/{application_id}/commands', application_id=application_id))
|
||||
|
||||
def get_global_command(self, application_id: Snowflake, command_id: Snowflake) -> Response[interactions.ApplicationCommand]:
|
||||
def get_global_command(
|
||||
self, application_id: Snowflake, command_id: Snowflake
|
||||
) -> Response[interactions.ApplicationCommand]:
|
||||
r = Route(
|
||||
'GET',
|
||||
'/applications/{application_id}/commands/{command_id}',
|
||||
@@ -1462,7 +1552,8 @@ class HTTPClient:
|
||||
r = Route('POST', '/applications/{application_id}/commands', application_id=application_id)
|
||||
return self.request(r, json=payload)
|
||||
|
||||
def edit_global_command(self,
|
||||
def edit_global_command(
|
||||
self,
|
||||
application_id: Snowflake,
|
||||
command_id: Snowflake,
|
||||
payload: interactions.EditApplicationCommand,
|
||||
@@ -1490,13 +1581,17 @@ class HTTPClient:
|
||||
)
|
||||
return self.request(r)
|
||||
|
||||
def bulk_upsert_global_commands(self, application_id: Snowflake, payload) -> Response[List[interactions.ApplicationCommand]]:
|
||||
def bulk_upsert_global_commands(
|
||||
self, application_id: Snowflake, payload
|
||||
) -> Response[List[interactions.ApplicationCommand]]:
|
||||
r = Route('PUT', '/applications/{application_id}/commands', application_id=application_id)
|
||||
return self.request(r, json=payload)
|
||||
|
||||
# Application commands (guild)
|
||||
|
||||
def get_guild_commands(self, application_id: Snowflake, guild_id: Snowflake) -> Response[List[interactions.ApplicationCommand]]:
|
||||
def get_guild_commands(
|
||||
self, application_id: Snowflake, guild_id: Snowflake
|
||||
) -> Response[List[interactions.ApplicationCommand]]:
|
||||
r = Route(
|
||||
'GET',
|
||||
'/applications/{application_id}/guilds/{guild_id}/commands',
|
||||
@@ -1534,7 +1629,8 @@ class HTTPClient:
|
||||
)
|
||||
return self.request(r, json=payload)
|
||||
|
||||
def edit_guild_command(self,
|
||||
def edit_guild_command(
|
||||
self,
|
||||
application_id: Snowflake,
|
||||
guild_id: Snowflake,
|
||||
command_id: Snowflake,
|
||||
@@ -1571,9 +1667,9 @@ class HTTPClient:
|
||||
return self.request(r)
|
||||
|
||||
def bulk_upsert_guild_commands(
|
||||
self,
|
||||
self,
|
||||
application_id: Snowflake,
|
||||
guild_id: Snowflake,
|
||||
guild_id: Snowflake,
|
||||
payload: List[interactions.EditApplicationCommand],
|
||||
) -> Response[List[interactions.ApplicationCommand]]:
|
||||
r = Route(
|
||||
@@ -1606,7 +1702,7 @@ class HTTPClient:
|
||||
form: List[Dict[str, Any]] = [
|
||||
{
|
||||
'name': 'payload_json',
|
||||
'value': utils.to_json(payload),
|
||||
'value': utils._to_json(payload),
|
||||
}
|
||||
]
|
||||
|
||||
@@ -1628,7 +1724,7 @@ class HTTPClient:
|
||||
token: str,
|
||||
*,
|
||||
type: InteractionResponseType,
|
||||
data: Optional[interactions.InteractionApplicationCommandCallbackData] = None
|
||||
data: Optional[interactions.InteractionApplicationCommandCallbackData] = None,
|
||||
) -> Response[None]:
|
||||
r = Route(
|
||||
'POST',
|
||||
@@ -1718,7 +1814,7 @@ class HTTPClient:
|
||||
content: Optional[str] = None,
|
||||
embeds: Optional[List[embed.Embed]] = None,
|
||||
allowed_mentions: Optional[message.AllowedMentions] = None,
|
||||
)-> Response[message.Message]:
|
||||
) -> Response[message.Message]:
|
||||
r = Route(
|
||||
'PATCH',
|
||||
'/webhooks/{application_id}/{interaction_token}/messages/{message_id}',
|
||||
|
||||
@@ -127,7 +127,7 @@ class Integration:
|
||||
self.user = User(state=self._state, data=user) if user else None
|
||||
self.enabled: bool = data['enabled']
|
||||
|
||||
async def delete(self) -> None:
|
||||
async def delete(self, *, reason: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
Deletes the integration.
|
||||
@@ -135,6 +135,13 @@ class Integration:
|
||||
You must have the :attr:`~Permissions.manage_guild` permission to
|
||||
do this.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
reason: :class:`str`
|
||||
The reason the integration was deleted. Shows up on the audit log.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Raises
|
||||
-------
|
||||
Forbidden
|
||||
@@ -142,7 +149,7 @@ class Integration:
|
||||
HTTPException
|
||||
Deleting the integration failed.
|
||||
"""
|
||||
await self._state.http.delete_integration(self.guild.id, self.id)
|
||||
await self._state.http.delete_integration(self.guild.id, self.id, reason=reason)
|
||||
|
||||
|
||||
class StreamIntegration(Integration):
|
||||
@@ -255,17 +262,10 @@ class StreamIntegration(Integration):
|
||||
if enable_emoticons is not MISSING:
|
||||
payload['enable_emoticons'] = enable_emoticons
|
||||
|
||||
# This endpoint is undocumented.
|
||||
# Unsure if it returns the data or not as a result
|
||||
await self._state.http.edit_integration(self.guild.id, self.id, **payload)
|
||||
|
||||
if expire_behaviour is not MISSING:
|
||||
self.expire_behaviour = expire_behaviour
|
||||
|
||||
if enable_emoticons is not MISSING:
|
||||
self.enable_emoticons = enable_emoticons
|
||||
|
||||
if expire_grace_period is not MISSING:
|
||||
self.expire_grace_period = expire_grace_period
|
||||
|
||||
async def sync(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ import asyncio
|
||||
from . import utils
|
||||
from .enums import try_enum, InteractionType, InteractionResponseType
|
||||
from .errors import InteractionResponded, HTTPException, ClientException
|
||||
from .channel import PartialMessageable, ChannelType
|
||||
|
||||
from .user import User
|
||||
from .member import Member
|
||||
@@ -57,10 +58,12 @@ if TYPE_CHECKING:
|
||||
from aiohttp import ClientSession
|
||||
from .embeds import Embed
|
||||
from .ui.view import View
|
||||
from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel
|
||||
from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, PartialMessageable
|
||||
from .threads import Thread
|
||||
|
||||
InteractionChannel = Union[VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, Thread]
|
||||
InteractionChannel = Union[
|
||||
VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable
|
||||
]
|
||||
|
||||
MISSING: Any = utils.MISSING
|
||||
|
||||
@@ -92,6 +95,8 @@ class Interaction:
|
||||
token: :class:`str`
|
||||
The token to continue the interaction. These are valid
|
||||
for 15 minutes.
|
||||
data: :class:`dict`
|
||||
The raw interaction data.
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = (
|
||||
@@ -111,6 +116,7 @@ class Interaction:
|
||||
'_original_message',
|
||||
'_cs_response',
|
||||
'_cs_followup',
|
||||
'_cs_channel',
|
||||
)
|
||||
|
||||
def __init__(self, *, data: InteractionPayload, state: ConnectionState):
|
||||
@@ -129,10 +135,9 @@ class Interaction:
|
||||
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
|
||||
self.application_id: int = int(data['application_id'])
|
||||
|
||||
channel = self.channel or Object(id=self.channel_id) # type: ignore
|
||||
self.message: Optional[Message]
|
||||
try:
|
||||
self.message = Message(state=self._state, channel=channel, data=data['message']) # type: ignore
|
||||
self.message = Message(state=self._state, channel=self.channel, data=data['message']) # type: ignore
|
||||
except KeyError:
|
||||
self.message = None
|
||||
|
||||
@@ -160,15 +165,21 @@ class Interaction:
|
||||
"""Optional[:class:`Guild`]: The guild the interaction was sent from."""
|
||||
return self._state and self._state._get_guild(self.guild_id)
|
||||
|
||||
@property
|
||||
@utils.cached_slot_property('_cs_channel')
|
||||
def channel(self) -> Optional[InteractionChannel]:
|
||||
"""Optional[Union[:class:`abc.GuildChannel`, :class:`Thread`]]: The channel the interaction was sent from.
|
||||
"""Optional[Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]]: The channel the interaction was sent from.
|
||||
|
||||
Note that due to a Discord limitation, DM channels are not resolved since there is
|
||||
no data to complete them.
|
||||
no data to complete them. These are :class:`PartialMessageable` instead.
|
||||
"""
|
||||
guild = self.guild
|
||||
return guild and guild._resolve_channel(self.channel_id)
|
||||
channel = guild and guild._resolve_channel(self.channel_id)
|
||||
if channel is None:
|
||||
if self.channel_id is not None:
|
||||
type = ChannelType.text if self.guild_id is not None else ChannelType.private
|
||||
return PartialMessageable(state=self._state, id=self.channel_id, type=type)
|
||||
return None
|
||||
return channel
|
||||
|
||||
@property
|
||||
def permissions(self) -> Permissions:
|
||||
@@ -250,7 +261,7 @@ class Interaction:
|
||||
files: List[File] = MISSING,
|
||||
view: Optional[View] = MISSING,
|
||||
allowed_mentions: Optional[AllowedMentions] = None,
|
||||
):
|
||||
) -> InteractionMessage:
|
||||
"""|coro|
|
||||
|
||||
Edits the original interaction response message.
|
||||
@@ -291,7 +302,12 @@ class Interaction:
|
||||
TypeError
|
||||
You specified both ``embed`` and ``embeds`` or ``file`` and ``files``
|
||||
ValueError
|
||||
The length of ``embeds`` was invalid
|
||||
The length of ``embeds`` was invalid.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`InteractionMessage`
|
||||
The newly edited message.
|
||||
"""
|
||||
|
||||
previous_mentions: Optional[AllowedMentions] = self._state.allowed_mentions
|
||||
@@ -315,8 +331,11 @@ class Interaction:
|
||||
files=params.files,
|
||||
)
|
||||
|
||||
# The message channel types should always match
|
||||
message = InteractionMessage(state=self._state, channel=self.channel, data=data) # type: ignore
|
||||
if view and not view.is_finished():
|
||||
self._state.store_view(view, int(data['id']))
|
||||
self._state.store_view(view, message.id)
|
||||
return message
|
||||
|
||||
async def delete_original_message(self) -> None:
|
||||
"""|coro|
|
||||
@@ -626,6 +645,9 @@ class _InteractionMessageState:
|
||||
def store_user(self, data):
|
||||
return self._parent.store_user(data)
|
||||
|
||||
def create_user(self, data):
|
||||
return self._parent.create_user(data)
|
||||
|
||||
@property
|
||||
def http(self):
|
||||
return self._parent.http
|
||||
@@ -658,7 +680,7 @@ class InteractionMessage(Message):
|
||||
files: List[File] = MISSING,
|
||||
view: Optional[View] = MISSING,
|
||||
allowed_mentions: Optional[AllowedMentions] = None,
|
||||
):
|
||||
) -> InteractionMessage:
|
||||
"""|coro|
|
||||
|
||||
Edits the message.
|
||||
@@ -693,9 +715,14 @@ class InteractionMessage(Message):
|
||||
TypeError
|
||||
You specified both ``embed`` and ``embeds`` or ``file`` and ``files``
|
||||
ValueError
|
||||
The length of ``embeds`` was invalid
|
||||
The length of ``embeds`` was invalid.
|
||||
|
||||
Returns
|
||||
---------
|
||||
:class:`InteractionMessage`
|
||||
The newly edited message.
|
||||
"""
|
||||
await self._state._interaction.edit_original_message(
|
||||
return await self._state._interaction.edit_original_message(
|
||||
content=content,
|
||||
embeds=embeds,
|
||||
embed=embed,
|
||||
|
||||
@@ -230,6 +230,7 @@ class Invite(Hashable):
|
||||
|
||||
Returns the invite URL.
|
||||
|
||||
|
||||
The following table illustrates what methods will obtain the attributes:
|
||||
|
||||
+------------------------------------+------------------------------------------------------------+
|
||||
@@ -257,7 +258,7 @@ class Invite(Hashable):
|
||||
Attributes
|
||||
-----------
|
||||
max_age: :class:`int`
|
||||
How long the before the invite expires in seconds.
|
||||
How long before the invite expires in seconds.
|
||||
A value of ``0`` indicates that it doesn't expire.
|
||||
code: :class:`str`
|
||||
The URL fragment used for the invite.
|
||||
@@ -352,12 +353,12 @@ class Invite(Hashable):
|
||||
self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None
|
||||
|
||||
inviter_data = data.get('inviter')
|
||||
self.inviter: Optional[User] = None if inviter_data is None else self._state.store_user(inviter_data)
|
||||
self.inviter: Optional[User] = None if inviter_data is None else self._state.create_user(inviter_data)
|
||||
|
||||
self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get('channel'), channel)
|
||||
|
||||
target_user_data = data.get('target_user')
|
||||
self.target_user: Optional[User] = None if target_user_data is None else self._state.store_user(target_user_data)
|
||||
self.target_user: Optional[User] = None if target_user_data is None else self._state.create_user(target_user_data)
|
||||
|
||||
self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0))
|
||||
|
||||
@@ -433,6 +434,9 @@ class Invite(Hashable):
|
||||
def __str__(self) -> str:
|
||||
return self.url
|
||||
|
||||
def __int__(self) -> int:
|
||||
return 0 # To keep the object compatible with the hashable abc.
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<Invite code={self.code!r} guild={self.guild!r} '
|
||||
|
||||
@@ -34,6 +34,7 @@ from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Typ
|
||||
import discord.abc
|
||||
|
||||
from . import utils
|
||||
from .asset import Asset
|
||||
from .utils import MISSING
|
||||
from .user import BaseUser, User, _UserTag
|
||||
from .activity import create_activity, ActivityTypes
|
||||
@@ -48,11 +49,13 @@ __all__ = (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .channel import VoiceChannel, StageChannel
|
||||
from .asset import Asset
|
||||
from .channel import DMChannel, VoiceChannel, StageChannel
|
||||
from .flags import PublicUserFlags
|
||||
from .guild import Guild
|
||||
from .types.activity import PartialPresenceUpdate
|
||||
from .types.member import (
|
||||
GatewayMember as GatewayMemberPayload,
|
||||
MemberWithUser as MemberWithUserPayload,
|
||||
Member as MemberPayload,
|
||||
UserWithMember as UserWithMemberPayload,
|
||||
)
|
||||
@@ -223,6 +226,10 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
|
||||
Returns the member's name with the discriminator.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the user's ID.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
joined_at: Optional[:class:`datetime.datetime`]
|
||||
@@ -247,7 +254,7 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
.. versionadded:: 1.6
|
||||
premium_since: Optional[:class:`datetime.datetime`]
|
||||
An aware datetime object that specifies the date and time in UTC when the member used their
|
||||
Nitro boost on the guild, if available. This could be ``None``.
|
||||
"Nitro boost" on the guild, if available. This could be ``None``.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
@@ -261,6 +268,7 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
'_client_status',
|
||||
'_user',
|
||||
'_state',
|
||||
'_avatar',
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -270,14 +278,17 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
bot: bool
|
||||
system: bool
|
||||
created_at: datetime.datetime
|
||||
default_avatar = User.default_avatar
|
||||
avatar = User.avatar
|
||||
dm_channel = User.dm_channel
|
||||
default_avatar: Asset
|
||||
avatar: Optional[Asset]
|
||||
dm_channel: Optional[DMChannel]
|
||||
create_dm = User.create_dm
|
||||
mutual_guilds = User.mutual_guilds
|
||||
public_flags = User.public_flags
|
||||
mutual_guilds: List[Guild]
|
||||
public_flags: PublicUserFlags
|
||||
banner: Optional[Asset]
|
||||
accent_color: Optional[Colour]
|
||||
accent_colour: Optional[Colour]
|
||||
|
||||
def __init__(self, *, data: GatewayMemberPayload, guild: Guild, state: ConnectionState):
|
||||
def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState):
|
||||
self._state: ConnectionState = state
|
||||
self._user: User = state.store_user(data['user'])
|
||||
self.guild: Guild = guild
|
||||
@@ -288,10 +299,14 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
self.activities: Tuple[ActivityTypes, ...] = tuple()
|
||||
self.nick: Optional[str] = data.get('nick', None)
|
||||
self.pending: bool = data.get('pending', False)
|
||||
self._avatar: Optional[str] = data.get('avatar')
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self._user)
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}'
|
||||
@@ -326,7 +341,7 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
try:
|
||||
member_data = data.pop('member')
|
||||
except KeyError:
|
||||
return state.store_user(data)
|
||||
return state.create_user(data)
|
||||
else:
|
||||
member_data['user'] = data # type: ignore
|
||||
return cls(data=member_data, guild=guild, state=state) # type: ignore
|
||||
@@ -344,6 +359,7 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
self.pending = member.pending
|
||||
self.activities = member.activities
|
||||
self._state = member._state
|
||||
self._avatar = member._avatar
|
||||
|
||||
# Reference will not be copied unless necessary by PRESENCE_UPDATE
|
||||
# See below
|
||||
@@ -369,6 +385,7 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
|
||||
self.premium_since = utils.parse_time(data.get('premium_since'))
|
||||
self._roles = utils.SnowflakeList(map(int, data['roles']))
|
||||
self._avatar = data.get('avatar')
|
||||
|
||||
def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]:
|
||||
self.activities = tuple(map(create_activity, data['activities']))
|
||||
@@ -493,6 +510,29 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
"""
|
||||
return self.nick or self.name
|
||||
|
||||
@property
|
||||
def display_avatar(self) -> Asset:
|
||||
""":class:`Asset`: Returns the member's display avatar.
|
||||
|
||||
For regular members this is just their avatar, but
|
||||
if they have a guild specific avatar then that
|
||||
is returned instead.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self.guild_avatar or self._user.avatar or self._user.default_avatar
|
||||
|
||||
@property
|
||||
def guild_avatar(self) -> Optional[Asset]:
|
||||
"""Optional[:class:`Asset`]: Returns an :class:`Asset` for the guild avatar
|
||||
the member has. If unavailable, ``None`` is returned.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
if self._avatar is None:
|
||||
return None
|
||||
return Asset._from_guild_avatar(self._state, self.guild.id, self.id, self._avatar)
|
||||
|
||||
@property
|
||||
def activity(self) -> Optional[ActivityTypes]:
|
||||
"""Optional[Union[:class:`BaseActivity`, :class:`Spotify`]]: Returns the primary
|
||||
@@ -611,7 +651,7 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
roles: List[discord.abc.Snowflake] = MISSING,
|
||||
voice_channel: Optional[VocalGuildChannel] = MISSING,
|
||||
reason: Optional[str] = None,
|
||||
) -> None:
|
||||
) -> Optional[Member]:
|
||||
"""|coro|
|
||||
|
||||
Edits the member's data.
|
||||
@@ -637,6 +677,9 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
.. versionchanged:: 1.1
|
||||
Can now pass ``None`` to ``voice_channel`` to kick a member from voice.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The newly member is now optionally returned, if applicable.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
nick: Optional[:class:`str`]
|
||||
@@ -664,6 +707,12 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
You do not have the proper permissions to the action requested.
|
||||
HTTPException
|
||||
The operation failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`.Member`]
|
||||
The newly updated member, if applicable. This is only returned
|
||||
when certain fields are updated.
|
||||
"""
|
||||
http = self._state.http
|
||||
guild_id = self.guild.id
|
||||
@@ -706,7 +755,8 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
payload['roles'] = tuple(r.id for r in roles)
|
||||
|
||||
if payload:
|
||||
await http.edit_member(guild_id, self.id, reason=reason, **payload)
|
||||
data = await http.edit_member(guild_id, self.id, reason=reason, **payload)
|
||||
return Member(data=data, guild=self.guild, state=self._state)
|
||||
|
||||
async def request_to_speak(self) -> None:
|
||||
"""|coro|
|
||||
@@ -847,7 +897,7 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
for role in roles:
|
||||
await req(guild_id, user_id, role.id, reason=reason)
|
||||
|
||||
def get_role(self, role_id: int) -> Optional[Role]:
|
||||
def get_role(self, role_id: int, /) -> Optional[Role]:
|
||||
"""Returns a role with the given ID from roles which the member has.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
@@ -29,7 +29,7 @@ import datetime
|
||||
import re
|
||||
import io
|
||||
from os import PathLike
|
||||
from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload
|
||||
from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload, TypeVar, Type
|
||||
|
||||
from . import utils
|
||||
from .reaction import Reaction
|
||||
@@ -45,7 +45,7 @@ from .file import File
|
||||
from .utils import escape_mentions, MISSING
|
||||
from .guild import Guild
|
||||
from .mixins import Hashable
|
||||
from .sticker import Sticker
|
||||
from .sticker import StickerItem
|
||||
from .threads import Thread
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -70,12 +70,13 @@ if TYPE_CHECKING:
|
||||
from .abc import GuildChannel, PartialMessageableChannel, MessageableChannel
|
||||
from .components import Component
|
||||
from .state import ConnectionState
|
||||
from .channel import TextChannel, GroupChannel, DMChannel
|
||||
from .channel import TextChannel, GroupChannel, DMChannel, PartialMessageable
|
||||
from .mentions import AllowedMentions
|
||||
from .user import User
|
||||
from .role import Role
|
||||
from .ui.view import View
|
||||
|
||||
MR = TypeVar('MR', bound='MessageReference')
|
||||
EmojiInputType = Union[Emoji, PartialEmoji, str]
|
||||
|
||||
__all__ = (
|
||||
@@ -124,6 +125,10 @@ class Attachment(Hashable):
|
||||
|
||||
Returns the hash of the attachment.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the attachment's ID.
|
||||
|
||||
.. versionchanged:: 1.7
|
||||
Attachment can now be casted to :class:`str` and is hashable.
|
||||
|
||||
@@ -341,7 +346,8 @@ class DeletedReferencedMessage:
|
||||
@property
|
||||
def id(self) -> int:
|
||||
""":class:`int`: The message ID of the deleted referenced message."""
|
||||
return self._parent.message_id
|
||||
# the parent's message id won't be None here
|
||||
return self._parent.message_id # type: ignore
|
||||
|
||||
@property
|
||||
def channel_id(self) -> int:
|
||||
@@ -393,13 +399,13 @@ class MessageReference:
|
||||
def __init__(self, *, message_id: int, channel_id: int, guild_id: Optional[int] = None, fail_if_not_exists: bool = True):
|
||||
self._state: Optional[ConnectionState] = None
|
||||
self.resolved: Optional[Union[Message, DeletedReferencedMessage]] = None
|
||||
self.message_id: int = message_id
|
||||
self.message_id: Optional[int] = message_id
|
||||
self.channel_id: int = channel_id
|
||||
self.guild_id: Optional[int] = guild_id
|
||||
self.fail_if_not_exists: bool = fail_if_not_exists
|
||||
|
||||
@classmethod
|
||||
def with_state(cls, state: ConnectionState, data: MessageReferencePayload) -> MessageReference:
|
||||
def with_state(cls: Type[MR], state: ConnectionState, data: MessageReferencePayload) -> MR:
|
||||
self = cls.__new__(cls)
|
||||
self.message_id = utils._get_as_snowflake(data, 'message_id')
|
||||
self.channel_id = int(data.pop('channel_id'))
|
||||
@@ -410,7 +416,7 @@ class MessageReference:
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, message: Message, *, fail_if_not_exists: bool = True) -> MessageReference:
|
||||
def from_message(cls: Type[MR], message: Message, *, fail_if_not_exists: bool = True) -> MR:
|
||||
"""Creates a :class:`MessageReference` from an existing :class:`~discord.Message`.
|
||||
|
||||
.. versionadded:: 1.6
|
||||
@@ -457,13 +463,13 @@ class MessageReference:
|
||||
return f'<MessageReference message_id={self.message_id!r} channel_id={self.channel_id!r} guild_id={self.guild_id!r}>'
|
||||
|
||||
def to_dict(self) -> MessageReferencePayload:
|
||||
result = {'message_id': self.message_id} if self.message_id is not None else {}
|
||||
result: MessageReferencePayload = {'message_id': self.message_id} if self.message_id is not None else {}
|
||||
result['channel_id'] = self.channel_id
|
||||
if self.guild_id is not None:
|
||||
result['guild_id'] = self.guild_id
|
||||
if self.fail_if_not_exists is not None:
|
||||
result['fail_if_not_exists'] = self.fail_if_not_exists
|
||||
return result # type: ignore
|
||||
return result
|
||||
|
||||
to_message_reference_dict = to_dict
|
||||
|
||||
@@ -501,6 +507,14 @@ class Message(Hashable):
|
||||
|
||||
Returns the message's hash.
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the message's content.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the message's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
tts: :class:`bool`
|
||||
@@ -520,7 +534,7 @@ class Message(Hashable):
|
||||
This is not stored long term within Discord's servers and is only used ephemerally.
|
||||
embeds: List[:class:`Embed`]
|
||||
A list of embeds the message has.
|
||||
channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`, :class:`GroupChannel`]
|
||||
channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`, :class:`GroupChannel`, :class:`PartialMessageable`]
|
||||
The :class:`TextChannel` or :class:`Thread` that the message was sent from.
|
||||
Could be a :class:`DMChannel` or :class:`GroupChannel` if it's a private message.
|
||||
reference: Optional[:class:`~discord.MessageReference`]
|
||||
@@ -588,8 +602,8 @@ class Message(Hashable):
|
||||
- ``description``: A string representing the application's description.
|
||||
- ``icon``: A string representing the icon ID of the application.
|
||||
- ``cover_image``: A string representing the embed's image asset ID.
|
||||
stickers: List[:class:`Sticker`]
|
||||
A list of stickers given to the message.
|
||||
stickers: List[:class:`StickerItem`]
|
||||
A list of sticker items given to the message.
|
||||
|
||||
.. versionadded:: 1.6
|
||||
components: List[:class:`Component`]
|
||||
@@ -637,7 +651,7 @@ class Message(Hashable):
|
||||
_HANDLERS: ClassVar[List[Tuple[str, Callable[..., None]]]]
|
||||
_CACHED_SLOTS: ClassVar[List[str]]
|
||||
guild: Optional[Guild]
|
||||
ref: Optional[MessageReference]
|
||||
reference: Optional[MessageReference]
|
||||
mentions: List[Union[User, Member]]
|
||||
author: Union[User, Member]
|
||||
role_mentions: List[Role]
|
||||
@@ -646,7 +660,7 @@ class Message(Hashable):
|
||||
self,
|
||||
*,
|
||||
state: ConnectionState,
|
||||
channel: Union[TextChannel, Thread, DMChannel, GroupChannel],
|
||||
channel: MessageableChannel,
|
||||
data: MessagePayload,
|
||||
):
|
||||
self._state: ConnectionState = state
|
||||
@@ -666,10 +680,11 @@ class Message(Hashable):
|
||||
self.tts: bool = data['tts']
|
||||
self.content: str = data['content']
|
||||
self.nonce: Optional[Union[int, str]] = data.get('nonce')
|
||||
self.stickers: List[Sticker] = [Sticker(data=d, state=state) for d in data.get('stickers', [])]
|
||||
self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])]
|
||||
self.components: List[Component] = [_component_factory(d) for d in data.get('components', [])]
|
||||
|
||||
try:
|
||||
# if the channel doesn't have a guild attribute, we handle that
|
||||
self.guild = channel.guild # type: ignore
|
||||
except AttributeError:
|
||||
self.guild = state._get_guild(utils._get_as_snowflake(data, 'guild_id'))
|
||||
@@ -694,7 +709,8 @@ class Message(Hashable):
|
||||
else:
|
||||
chan, _ = state._get_guild_channel(resolved)
|
||||
|
||||
ref.resolved = self.__class__(channel=chan, data=resolved, state=state)
|
||||
# the channel will be the correct type here
|
||||
ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore
|
||||
|
||||
for handler in ('author', 'member', 'mentions', 'mention_roles'):
|
||||
try:
|
||||
@@ -708,6 +724,10 @@ class Message(Hashable):
|
||||
f'<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>'
|
||||
)
|
||||
|
||||
|
||||
def __str__(self) -> Optional[str]:
|
||||
return self.content
|
||||
|
||||
def _try_patch(self, data, key, transform=None) -> None:
|
||||
try:
|
||||
value = data[key]
|
||||
@@ -977,9 +997,17 @@ class Message(Hashable):
|
||||
def is_system(self) -> bool:
|
||||
""":class:`bool`: Whether the message is a system message.
|
||||
|
||||
A system message is a message that is constructed entirely by the Discord API
|
||||
in response to something.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
"""
|
||||
return self.type is not MessageType.default
|
||||
return self.type not in (
|
||||
MessageType.default,
|
||||
MessageType.reply,
|
||||
MessageType.application_command,
|
||||
MessageType.thread_starter_message,
|
||||
)
|
||||
|
||||
@utils.cached_slot_property('_cs_system_content')
|
||||
def system_content(self):
|
||||
@@ -994,21 +1022,27 @@ class Message(Hashable):
|
||||
if self.type is MessageType.default:
|
||||
return self.content
|
||||
|
||||
if self.type is MessageType.pins_add:
|
||||
return f'{self.author.name} pinned a message to this channel.'
|
||||
|
||||
if self.type is MessageType.recipient_add:
|
||||
return f'{self.author.name} added {self.mentions[0].name} to the group.'
|
||||
if self.channel.type is ChannelType.group:
|
||||
return f'{self.author.name} added {self.mentions[0].name} to the group.'
|
||||
else:
|
||||
return f'{self.author.name} added {self.mentions[0].name} to the thread.'
|
||||
|
||||
if self.type is MessageType.recipient_remove:
|
||||
return f'{self.author.name} removed {self.mentions[0].name} from the group.'
|
||||
if self.channel.type is ChannelType.group:
|
||||
return f'{self.author.name} removed {self.mentions[0].name} from the group.'
|
||||
else:
|
||||
return f'{self.author.name} removed {self.mentions[0].name} from the thread.'
|
||||
|
||||
if self.type is MessageType.channel_name_change:
|
||||
return f'{self.author.name} changed the channel name: {self.content}'
|
||||
return f'{self.author.name} changed the channel name: **{self.content}**'
|
||||
|
||||
if self.type is MessageType.channel_icon_change:
|
||||
return f'{self.author.name} changed the channel icon.'
|
||||
|
||||
if self.type is MessageType.pins_add:
|
||||
return f'{self.author.name} pinned a message to this channel.'
|
||||
|
||||
if self.type is MessageType.new_member:
|
||||
formats = [
|
||||
"{0} joined the party.",
|
||||
@@ -1030,21 +1064,34 @@ class Message(Hashable):
|
||||
return formats[created_at_ms % len(formats)].format(self.author.name)
|
||||
|
||||
if self.type is MessageType.premium_guild_subscription:
|
||||
return f'{self.author.name} just boosted the server!'
|
||||
if not self.content:
|
||||
return f'{self.author.name} just boosted the server!'
|
||||
else:
|
||||
return f'{self.author.name} just boosted the server **{self.content}** times!'
|
||||
|
||||
if self.type is MessageType.premium_guild_tier_1:
|
||||
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**'
|
||||
if not self.content:
|
||||
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**'
|
||||
else:
|
||||
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 1!**'
|
||||
|
||||
if self.type is MessageType.premium_guild_tier_2:
|
||||
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**'
|
||||
if not self.content:
|
||||
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**'
|
||||
else:
|
||||
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 2!**'
|
||||
|
||||
if self.type is MessageType.premium_guild_tier_3:
|
||||
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**'
|
||||
if not self.content:
|
||||
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**'
|
||||
else:
|
||||
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 3!**'
|
||||
|
||||
if self.type is MessageType.channel_follow_add:
|
||||
return f'{self.author.name} has added {self.content} to this channel'
|
||||
|
||||
if self.type is MessageType.guild_stream:
|
||||
# the author will be a Member
|
||||
return f'{self.author.name} is live! Now streaming {self.author.activity.name}' # type: ignore
|
||||
|
||||
if self.type is MessageType.guild_discovery_disqualified:
|
||||
@@ -1059,13 +1106,23 @@ class Message(Hashable):
|
||||
if self.type is MessageType.guild_discovery_grace_period_final_warning:
|
||||
return 'This server has failed Discovery activity requirements for 3 weeks in a row. If this server fails for 1 more week, it will be removed from Discovery.'
|
||||
|
||||
if self.type is MessageType.thread_created:
|
||||
return f'{self.author.name} started a thread: **{self.content}**. See all **threads**.'
|
||||
|
||||
if self.type is MessageType.reply:
|
||||
return self.content
|
||||
|
||||
if self.type is MessageType.thread_starter_message:
|
||||
if self.reference is None or self.reference.resolved is None:
|
||||
return 'Sorry, we couldn\'t load the first message in this thread'
|
||||
|
||||
# the resolved message for the reference will be a Message
|
||||
return self.reference.resolved.content # type: ignore
|
||||
|
||||
if self.type is MessageType.guild_invite_reminder:
|
||||
return 'Wondering who to invite?\nStart by inviting anyone who can help you build the server!'
|
||||
|
||||
async def delete(self, *, delay: Optional[float] = None) -> None:
|
||||
async def delete(self, *, delay: Optional[float] = None, silent: bool = False) -> None:
|
||||
"""|coro|
|
||||
|
||||
Deletes the message.
|
||||
@@ -1076,12 +1133,17 @@ class Message(Hashable):
|
||||
|
||||
.. versionchanged:: 1.1
|
||||
Added the new ``delay`` keyword-only parameter.
|
||||
.. versionchanged:: 2.0
|
||||
Added the new ``silent`` keyword-only parameter.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
delay: Optional[:class:`float`]
|
||||
If provided, the number of seconds to wait in the background
|
||||
before deleting the message. If the deletion fails then it is silently ignored.
|
||||
silent: :class:`bool`
|
||||
If silent is set to ``True``, the error will not be raised, it will be ignored.
|
||||
This defaults to ``False``
|
||||
|
||||
Raises
|
||||
------
|
||||
@@ -1103,7 +1165,11 @@ class Message(Hashable):
|
||||
|
||||
asyncio.create_task(delete(delay))
|
||||
else:
|
||||
await self._state.http.delete_message(self.channel.id, self.id)
|
||||
try:
|
||||
await self._state.http.delete_message(self.channel.id, self.id)
|
||||
except Exception:
|
||||
if not silent:
|
||||
raise
|
||||
|
||||
@overload
|
||||
async def edit(
|
||||
@@ -1116,7 +1182,7 @@ class Message(Hashable):
|
||||
delete_after: Optional[float] = ...,
|
||||
allowed_mentions: Optional[AllowedMentions] = ...,
|
||||
view: Optional[View] = ...,
|
||||
) -> None:
|
||||
) -> Message:
|
||||
...
|
||||
|
||||
@overload
|
||||
@@ -1130,7 +1196,7 @@ class Message(Hashable):
|
||||
delete_after: Optional[float] = ...,
|
||||
allowed_mentions: Optional[AllowedMentions] = ...,
|
||||
view: Optional[View] = ...,
|
||||
) -> None:
|
||||
) -> Message:
|
||||
...
|
||||
|
||||
async def edit(
|
||||
@@ -1143,7 +1209,7 @@ class Message(Hashable):
|
||||
delete_after: Optional[float] = None,
|
||||
allowed_mentions: Optional[AllowedMentions] = MISSING,
|
||||
view: Optional[View] = MISSING,
|
||||
) -> None:
|
||||
) -> Message:
|
||||
"""|coro|
|
||||
|
||||
Edits the message.
|
||||
@@ -1245,9 +1311,8 @@ class Message(Hashable):
|
||||
else:
|
||||
payload['components'] = []
|
||||
|
||||
if payload:
|
||||
data = await self._state.http.edit_message(self.channel.id, self.id, **payload)
|
||||
self._update(data)
|
||||
data = await self._state.http.edit_message(self.channel.id, self.id, **payload)
|
||||
message = Message(state=self._state, channel=self.channel, data=data)
|
||||
|
||||
if view and not view.is_finished():
|
||||
self._state.store_view(view, self.id)
|
||||
@@ -1255,6 +1320,8 @@ class Message(Hashable):
|
||||
if delete_after is not None:
|
||||
await self.delete(delay=delete_after)
|
||||
|
||||
return message
|
||||
|
||||
async def publish(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
@@ -1449,49 +1516,51 @@ class Message(Hashable):
|
||||
"""
|
||||
await self._state.http.clear_reactions(self.channel.id, self.id)
|
||||
|
||||
async def start_thread(self, *, name: str, auto_archive_duration: ThreadArchiveDuration = 1440) -> Thread:
|
||||
async def create_thread(self, *, name: str, auto_archive_duration: ThreadArchiveDuration = MISSING) -> Thread:
|
||||
"""|coro|
|
||||
|
||||
Starts a public thread from this message.
|
||||
Creates a public thread from this message.
|
||||
|
||||
You must have :attr:`~discord.Permissions.send_messages` and
|
||||
:attr:`~discord.Permissions.use_threads` in order to start a thread.
|
||||
You must have :attr:`~discord.Permissions.create_public_threads` in order to
|
||||
create a public thread from a message.
|
||||
|
||||
The channel this message belongs in must be a :class:`TextChannel`.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
name: :class:`str`
|
||||
The name of the thread.
|
||||
auto_archive_duration: :class:`int`
|
||||
The duration in minutes before a thread is automatically archived for inactivity.
|
||||
Defaults to ``1440`` or 24 hours.
|
||||
If not provided, the channel's default auto archive duration is used.
|
||||
|
||||
Raises
|
||||
-------
|
||||
Forbidden
|
||||
You do not have permissions to start a thread.
|
||||
You do not have permissions to create a thread.
|
||||
HTTPException
|
||||
Starting the thread failed.
|
||||
Creating the thread failed.
|
||||
InvalidArgument
|
||||
This message does not have guild info attached.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`.Thread`
|
||||
The started thread.
|
||||
The created thread.
|
||||
"""
|
||||
if self.guild is None:
|
||||
raise InvalidArgument('This message does not have guild info attached.')
|
||||
|
||||
data = await self._state.http.start_public_thread(
|
||||
default_auto_archive_duration: ThreadArchiveDuration = getattr(self.channel, 'default_auto_archive_duration', 1440)
|
||||
data = await self._state.http.start_thread_with_message(
|
||||
self.channel.id,
|
||||
self.id,
|
||||
name=name,
|
||||
auto_archive_duration=auto_archive_duration,
|
||||
type=ChannelType.public_thread.value,
|
||||
auto_archive_duration=auto_archive_duration or default_auto_archive_duration,
|
||||
)
|
||||
return Thread(guild=self.guild, state=self._state, data=data) # type: ignore
|
||||
return Thread(guild=self.guild, state=self._state, data=data)
|
||||
|
||||
async def reply(self, content: Optional[str] = None, **kwargs) -> Message:
|
||||
"""|coro|
|
||||
@@ -1581,6 +1650,10 @@ class PartialMessage(Hashable):
|
||||
|
||||
Returns the partial message's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the partial message's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`]
|
||||
@@ -1605,8 +1678,15 @@ class PartialMessage(Hashable):
|
||||
to_message_reference_dict = Message.to_message_reference_dict
|
||||
|
||||
def __init__(self, *, channel: PartialMessageableChannel, id: int):
|
||||
if channel.type not in (ChannelType.text, ChannelType.news, ChannelType.private):
|
||||
raise TypeError(f'Expected TextChannel or DMChannel not {type(channel)!r}')
|
||||
if channel.type not in (
|
||||
ChannelType.text,
|
||||
ChannelType.news,
|
||||
ChannelType.private,
|
||||
ChannelType.news_thread,
|
||||
ChannelType.public_thread,
|
||||
ChannelType.private_thread,
|
||||
):
|
||||
raise TypeError(f'Expected TextChannel, DMChannel or Thread not {type(channel)!r}')
|
||||
|
||||
self.channel: PartialMessageableChannel = channel
|
||||
self._state: ConnectionState = channel._state
|
||||
@@ -1730,7 +1810,7 @@ class PartialMessage(Hashable):
|
||||
fields['embed'] = embed.to_dict()
|
||||
|
||||
try:
|
||||
suppress = fields.pop('suppress')
|
||||
suppress: bool = fields.pop('suppress')
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
@@ -1768,9 +1848,10 @@ class PartialMessage(Hashable):
|
||||
data = await self._state.http.edit_message(self.channel.id, self.id, **fields)
|
||||
|
||||
if delete_after is not None:
|
||||
await self.delete(delay=delete_after) # type: ignore
|
||||
await self.delete(delay=delete_after)
|
||||
|
||||
if fields:
|
||||
# data isn't unbound
|
||||
msg = self._state.create_message(channel=self.channel, data=data) # type: ignore
|
||||
if view and not view.is_finished():
|
||||
self._state.store_view(view, self.id)
|
||||
|
||||
@@ -22,24 +22,20 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
__all__ = (
|
||||
'EqualityComparable',
|
||||
'Hashable',
|
||||
)
|
||||
|
||||
E = TypeVar('E', bound='EqualityComparable')
|
||||
|
||||
class EqualityComparable:
|
||||
__slots__ = ()
|
||||
|
||||
id: int
|
||||
|
||||
def __eq__(self: E, other: E) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__) and other.id == self.id
|
||||
|
||||
def __ne__(self: E, other: E) -> bool:
|
||||
def __ne__(self, other: object) -> bool:
|
||||
if isinstance(other, self.__class__):
|
||||
return other.id != self.id
|
||||
return True
|
||||
@@ -47,5 +43,8 @@ class EqualityComparable:
|
||||
class Hashable(EqualityComparable):
|
||||
__slots__ = ()
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.id >> 22
|
||||
|
||||
@@ -69,6 +69,10 @@ class Object(Hashable):
|
||||
|
||||
Returns the object's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the object's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
id: :class:`int`
|
||||
|
||||
@@ -22,8 +22,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar, IO, Generator, Tuple, Optional
|
||||
|
||||
from .errors import DiscordException
|
||||
|
||||
__all__ = (
|
||||
@@ -40,22 +44,29 @@ class OggError(DiscordException):
|
||||
# https://tools.ietf.org/html/rfc7845
|
||||
|
||||
class OggPage:
|
||||
_header = struct.Struct('<xBQIIIB')
|
||||
_header: ClassVar[struct.Struct] = struct.Struct('<xBQIIIB')
|
||||
if TYPE_CHECKING:
|
||||
flag: int
|
||||
gran_pos: int
|
||||
serial: int
|
||||
pagenum: int
|
||||
crc: int
|
||||
segnum: int
|
||||
|
||||
def __init__(self, stream):
|
||||
def __init__(self, stream: IO[bytes]) -> None:
|
||||
try:
|
||||
header = stream.read(struct.calcsize(self._header.format))
|
||||
|
||||
self.flag, self.gran_pos, self.serial, \
|
||||
self.pagenum, self.crc, self.segnum = self._header.unpack(header)
|
||||
|
||||
self.segtable = stream.read(self.segnum)
|
||||
self.segtable: bytes = stream.read(self.segnum)
|
||||
bodylen = sum(struct.unpack('B'*self.segnum, self.segtable))
|
||||
self.data = stream.read(bodylen)
|
||||
self.data: bytes = stream.read(bodylen)
|
||||
except Exception:
|
||||
raise OggError('bad data stream') from None
|
||||
|
||||
def iter_packets(self):
|
||||
def iter_packets(self) -> Generator[Tuple[bytes, bool], None, None]:
|
||||
packetlen = offset = 0
|
||||
partial = True
|
||||
|
||||
@@ -74,10 +85,10 @@ class OggPage:
|
||||
yield self.data[offset:], False
|
||||
|
||||
class OggStream:
|
||||
def __init__(self, stream):
|
||||
self.stream = stream
|
||||
def __init__(self, stream: IO[bytes]) -> None:
|
||||
self.stream: IO[bytes] = stream
|
||||
|
||||
def _next_page(self):
|
||||
def _next_page(self) -> Optional[OggPage]:
|
||||
head = self.stream.read(4)
|
||||
if head == b'OggS':
|
||||
return OggPage(self.stream)
|
||||
@@ -86,13 +97,13 @@ class OggStream:
|
||||
else:
|
||||
raise OggError('invalid header magic')
|
||||
|
||||
def _iter_pages(self):
|
||||
def _iter_pages(self) -> Generator[OggPage, None, None]:
|
||||
page = self._next_page()
|
||||
while page:
|
||||
yield page
|
||||
page = self._next_page()
|
||||
|
||||
def iter_packets(self):
|
||||
def iter_packets(self) -> Generator[bytes, None, None]:
|
||||
partial = b''
|
||||
for page in self._iter_pages():
|
||||
for data, complete in page.iter_packets():
|
||||
|
||||
125
discord/opus.py
125
discord/opus.py
@@ -22,6 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Tuple, TypedDict, Any, TYPE_CHECKING, Callable, TypeVar, Literal, Optional, overload
|
||||
|
||||
import array
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
@@ -31,7 +35,24 @@ import os.path
|
||||
import struct
|
||||
import sys
|
||||
|
||||
from .errors import DiscordException
|
||||
from .errors import DiscordException, InvalidArgument
|
||||
|
||||
if TYPE_CHECKING:
|
||||
T = TypeVar('T')
|
||||
BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full']
|
||||
SIGNAL_CTL = Literal['auto', 'voice', 'music']
|
||||
|
||||
class BandCtl(TypedDict):
|
||||
narrow: int
|
||||
medium: int
|
||||
wide: int
|
||||
superwide: int
|
||||
full: int
|
||||
|
||||
class SignalCtl(TypedDict):
|
||||
auto: int
|
||||
voice: int
|
||||
music: int
|
||||
|
||||
__all__ = (
|
||||
'Encoder',
|
||||
@@ -39,7 +60,7 @@ __all__ = (
|
||||
'OpusNotLoaded',
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
c_int_ptr = ctypes.POINTER(ctypes.c_int)
|
||||
c_int16_ptr = ctypes.POINTER(ctypes.c_int16)
|
||||
@@ -76,7 +97,7 @@ CTL_SET_SIGNAL = 4024
|
||||
CTL_SET_GAIN = 4034
|
||||
CTL_LAST_PACKET_DURATION = 4039
|
||||
|
||||
band_ctl = {
|
||||
band_ctl: BandCtl = {
|
||||
'narrow': 1101,
|
||||
'medium': 1102,
|
||||
'wide': 1103,
|
||||
@@ -84,22 +105,22 @@ band_ctl = {
|
||||
'full': 1105,
|
||||
}
|
||||
|
||||
signal_ctl = {
|
||||
signal_ctl: SignalCtl = {
|
||||
'auto': -1000,
|
||||
'voice': 3001,
|
||||
'music': 3002,
|
||||
}
|
||||
|
||||
def _err_lt(result, func, args):
|
||||
def _err_lt(result: int, func: Callable, args: List) -> int:
|
||||
if result < OK:
|
||||
log.info('error has happened in %s', func.__name__)
|
||||
_log.info('error has happened in %s', func.__name__)
|
||||
raise OpusError(result)
|
||||
return result
|
||||
|
||||
def _err_ne(result, func, args):
|
||||
def _err_ne(result: T, func: Callable, args: List) -> T:
|
||||
ret = args[-1]._obj
|
||||
if ret.value != OK:
|
||||
log.info('error has happened in %s', func.__name__)
|
||||
_log.info('error has happened in %s', func.__name__)
|
||||
raise OpusError(ret.value)
|
||||
return result
|
||||
|
||||
@@ -108,7 +129,7 @@ def _err_ne(result, func, args):
|
||||
# The second one are the types of arguments it takes.
|
||||
# The third is the result type.
|
||||
# The fourth is the error handler.
|
||||
exported_functions = [
|
||||
exported_functions: List[Tuple[Any, ...]] = [
|
||||
# Generic
|
||||
('opus_get_version_string',
|
||||
None, ctypes.c_char_p, None),
|
||||
@@ -158,7 +179,7 @@ exported_functions = [
|
||||
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
|
||||
]
|
||||
|
||||
def libopus_loader(name):
|
||||
def libopus_loader(name: str) -> Any:
|
||||
# create the library...
|
||||
lib = ctypes.cdll.LoadLibrary(name)
|
||||
|
||||
@@ -178,11 +199,11 @@ def libopus_loader(name):
|
||||
if item[3]:
|
||||
func.errcheck = item[3]
|
||||
except KeyError:
|
||||
log.exception("Error assigning check function to %s", func)
|
||||
_log.exception("Error assigning check function to %s", func)
|
||||
|
||||
return lib
|
||||
|
||||
def _load_default():
|
||||
def _load_default() -> bool:
|
||||
global _lib
|
||||
try:
|
||||
if sys.platform == 'win32':
|
||||
@@ -198,7 +219,7 @@ def _load_default():
|
||||
|
||||
return _lib is not None
|
||||
|
||||
def load_opus(name):
|
||||
def load_opus(name: str) -> None:
|
||||
"""Loads the libopus shared library for use with voice.
|
||||
|
||||
If this function is not called then the library uses the function
|
||||
@@ -236,7 +257,7 @@ def load_opus(name):
|
||||
global _lib
|
||||
_lib = libopus_loader(name)
|
||||
|
||||
def is_loaded():
|
||||
def is_loaded() -> bool:
|
||||
"""Function to check if opus lib is successfully loaded either
|
||||
via the :func:`ctypes.util.find_library` call of :func:`load_opus`.
|
||||
|
||||
@@ -259,10 +280,10 @@ class OpusError(DiscordException):
|
||||
The error code returned.
|
||||
"""
|
||||
|
||||
def __init__(self, code):
|
||||
self.code = code
|
||||
def __init__(self, code: int):
|
||||
self.code: int = code
|
||||
msg = _lib.opus_strerror(self.code).decode('utf-8')
|
||||
log.info('"%s" has happened', msg)
|
||||
_log.info('"%s" has happened', msg)
|
||||
super().__init__(msg)
|
||||
|
||||
class OpusNotLoaded(DiscordException):
|
||||
@@ -286,92 +307,96 @@ class _OpusStruct:
|
||||
return _lib.opus_get_version_string().decode('utf-8')
|
||||
|
||||
class Encoder(_OpusStruct):
|
||||
def __init__(self, application=APPLICATION_AUDIO):
|
||||
def __init__(self, application: int = APPLICATION_AUDIO):
|
||||
_OpusStruct.get_opus_version()
|
||||
|
||||
self.application = application
|
||||
self._state = self._create_state()
|
||||
self.application: int = application
|
||||
self._state: EncoderStruct = self._create_state()
|
||||
self.set_bitrate(128)
|
||||
self.set_fec(True)
|
||||
self.set_expected_packet_loss_percent(0.15)
|
||||
self.set_bandwidth('full')
|
||||
self.set_signal_type('auto')
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
if hasattr(self, '_state'):
|
||||
_lib.opus_encoder_destroy(self._state)
|
||||
self._state = None
|
||||
# This is a destructor, so it's okay to assign None
|
||||
self._state = None # type: ignore
|
||||
|
||||
def _create_state(self):
|
||||
def _create_state(self) -> EncoderStruct:
|
||||
ret = ctypes.c_int()
|
||||
return _lib.opus_encoder_create(self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret))
|
||||
|
||||
def set_bitrate(self, kbps):
|
||||
def set_bitrate(self, kbps: int) -> int:
|
||||
kbps = min(512, max(16, int(kbps)))
|
||||
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_BITRATE, kbps * 1024)
|
||||
return kbps
|
||||
|
||||
def set_bandwidth(self, req):
|
||||
def set_bandwidth(self, req: BAND_CTL) -> None:
|
||||
if req not in band_ctl:
|
||||
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}')
|
||||
|
||||
k = band_ctl[req]
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_BANDWIDTH, k)
|
||||
|
||||
def set_signal_type(self, req):
|
||||
def set_signal_type(self, req: SIGNAL_CTL) -> None:
|
||||
if req not in signal_ctl:
|
||||
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}')
|
||||
|
||||
k = signal_ctl[req]
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k)
|
||||
|
||||
def set_fec(self, enabled=True):
|
||||
def set_fec(self, enabled: bool = True) -> None:
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0)
|
||||
|
||||
def set_expected_packet_loss_percent(self, percentage):
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100))))
|
||||
def set_expected_packet_loss_percent(self, percentage: float) -> None:
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore
|
||||
|
||||
def encode(self, pcm, frame_size):
|
||||
def encode(self, pcm: bytes, frame_size: int) -> bytes:
|
||||
max_data_bytes = len(pcm)
|
||||
pcm = ctypes.cast(pcm, c_int16_ptr)
|
||||
# bytes can be used to reference pointer
|
||||
pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore
|
||||
data = (ctypes.c_char * max_data_bytes)()
|
||||
|
||||
ret = _lib.opus_encode(self._state, pcm, frame_size, data, max_data_bytes)
|
||||
ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes)
|
||||
|
||||
return array.array('b', data[:ret]).tobytes()
|
||||
# array can be initialized with bytes but mypy doesn't know
|
||||
return array.array('b', data[:ret]).tobytes() # type: ignore
|
||||
|
||||
class Decoder(_OpusStruct):
|
||||
def __init__(self):
|
||||
_OpusStruct.get_opus_version()
|
||||
|
||||
self._state = self._create_state()
|
||||
self._state: DecoderStruct = self._create_state()
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
if hasattr(self, '_state'):
|
||||
_lib.opus_decoder_destroy(self._state)
|
||||
self._state = None
|
||||
# This is a destructor, so it's okay to assign None
|
||||
self._state = None # type: ignore
|
||||
|
||||
def _create_state(self):
|
||||
def _create_state(self) -> DecoderStruct:
|
||||
ret = ctypes.c_int()
|
||||
return _lib.opus_decoder_create(self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret))
|
||||
|
||||
@staticmethod
|
||||
def packet_get_nb_frames(data):
|
||||
def packet_get_nb_frames(data: bytes) -> int:
|
||||
"""Gets the number of frames in an Opus packet"""
|
||||
return _lib.opus_packet_get_nb_frames(data, len(data))
|
||||
|
||||
@staticmethod
|
||||
def packet_get_nb_channels(data):
|
||||
def packet_get_nb_channels(data: bytes) -> int:
|
||||
"""Gets the number of channels in an Opus packet"""
|
||||
return _lib.opus_packet_get_nb_channels(data)
|
||||
|
||||
@classmethod
|
||||
def packet_get_samples_per_frame(cls, data):
|
||||
def packet_get_samples_per_frame(cls, data: bytes) -> int:
|
||||
"""Gets the number of samples per frame from an Opus packet"""
|
||||
return _lib.opus_packet_get_samples_per_frame(data, cls.SAMPLING_RATE)
|
||||
|
||||
def _set_gain(self, adjustment):
|
||||
def _set_gain(self, adjustment: int) -> int:
|
||||
"""Configures decoder gain adjustment.
|
||||
|
||||
Scales the decoded output by a factor specified in Q8 dB units.
|
||||
@@ -383,26 +408,34 @@ class Decoder(_OpusStruct):
|
||||
"""
|
||||
return _lib.opus_decoder_ctl(self._state, CTL_SET_GAIN, adjustment)
|
||||
|
||||
def set_gain(self, dB):
|
||||
def set_gain(self, dB: float) -> int:
|
||||
"""Sets the decoder gain in dB, from -128 to 128."""
|
||||
|
||||
dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8)
|
||||
return self._set_gain(dB_Q8)
|
||||
|
||||
def set_volume(self, mult):
|
||||
def set_volume(self, mult: float) -> int:
|
||||
"""Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc."""
|
||||
return self.set_gain(20 * math.log10(mult)) # amplitude ratio
|
||||
|
||||
def _get_last_packet_duration(self):
|
||||
def _get_last_packet_duration(self) -> int:
|
||||
"""Gets the duration (in samples) of the last packet successfully decoded or concealed."""
|
||||
|
||||
ret = ctypes.c_int32()
|
||||
_lib.opus_decoder_ctl(self._state, CTL_LAST_PACKET_DURATION, ctypes.byref(ret))
|
||||
return ret.value
|
||||
|
||||
def decode(self, data, *, fec=False):
|
||||
@overload
|
||||
def decode(self, data: bytes, *, fec: bool) -> bytes:
|
||||
...
|
||||
|
||||
@overload
|
||||
def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes:
|
||||
...
|
||||
|
||||
def decode(self, data: Optional[bytes], *, fec: bool = False) -> bytes:
|
||||
if data is None and fec:
|
||||
raise OpusError("Invalid arguments: FEC cannot be used with null data")
|
||||
raise InvalidArgument("Invalid arguments: FEC cannot be used with null data")
|
||||
|
||||
if data is None:
|
||||
frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME
|
||||
|
||||
@@ -147,7 +147,7 @@ class Permissions(BaseFlags):
|
||||
"""A factory method that creates a :class:`Permissions` with all
|
||||
permissions set to ``True``.
|
||||
"""
|
||||
return cls(0b111111111111111111111111111111111111)
|
||||
return cls(0b111111111111111111111111111111111111111)
|
||||
|
||||
@classmethod
|
||||
def all_channel(cls: Type[P]) -> P:
|
||||
@@ -167,8 +167,13 @@ class Permissions(BaseFlags):
|
||||
|
||||
.. versionchanged:: 1.7
|
||||
Added :attr:`stream`, :attr:`priority_speaker` and :attr:`use_slash_commands` permissions.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Added :attr:`create_public_threads`, :attr:`create_private_threads`, :attr:`manage_threads`,
|
||||
:attr:`use_external_stickers`, :attr:`send_messages_in_threads` and
|
||||
:attr:`request_to_speak` permissions.
|
||||
"""
|
||||
return cls(0b10110011111101111111111101010001)
|
||||
return cls(0b111110110110011111101111111111101010001)
|
||||
|
||||
@classmethod
|
||||
def general(cls: Type[P]) -> P:
|
||||
@@ -200,8 +205,12 @@ class Permissions(BaseFlags):
|
||||
.. versionchanged:: 1.7
|
||||
Permission :attr:`read_messages` is no longer part of the text permissions.
|
||||
Added :attr:`use_slash_commands` permission.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Added :attr:`create_public_threads`, :attr:`create_private_threads`, :attr:`manage_threads`,
|
||||
:attr:`send_messages_in_threads` and :attr:`use_external_stickers` permissions.
|
||||
"""
|
||||
return cls(0b10000000000001111111100001000000)
|
||||
return cls(0b111110010000000000001111111100001000000)
|
||||
|
||||
@classmethod
|
||||
def voice(cls: Type[P]) -> P:
|
||||
@@ -462,6 +471,14 @@ class Permissions(BaseFlags):
|
||||
""":class:`bool`: Returns ``True`` if a user can create, edit, or delete emojis."""
|
||||
return 1 << 30
|
||||
|
||||
@make_permission_alias('manage_emojis')
|
||||
def manage_emojis_and_stickers(self) -> int:
|
||||
""":class:`bool`: An alias for :attr:`manage_emojis`.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return 1 << 30
|
||||
|
||||
@flag_value
|
||||
def use_slash_commands(self) -> int:
|
||||
""":class:`bool`: Returns ``True`` if a user can use slash commands.
|
||||
@@ -495,21 +512,45 @@ class Permissions(BaseFlags):
|
||||
return 1 << 34
|
||||
|
||||
@flag_value
|
||||
def use_threads(self) -> int:
|
||||
""":class:`bool`: Returns ``True`` if a user can create and participate in public threads.
|
||||
def create_public_threads(self) -> int:
|
||||
""":class:`bool`: Returns ``True`` if a user can create public threads.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return 1 << 35
|
||||
|
||||
@flag_value
|
||||
def use_private_threads(self) -> int:
|
||||
""":class:`bool`: Returns ``True`` if a user can create and participate in private threads.
|
||||
def create_private_threads(self) -> int:
|
||||
""":class:`bool`: Returns ``True`` if a user can create private threads.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return 1 << 36
|
||||
|
||||
@flag_value
|
||||
def external_stickers(self) -> int:
|
||||
""":class:`bool`: Returns ``True`` if a user can use stickers from other guilds.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return 1 << 37
|
||||
|
||||
@make_permission_alias('external_stickers')
|
||||
def use_external_stickers(self) -> int:
|
||||
""":class:`bool`: An alias for :attr:`external_stickers`.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return 1 << 37
|
||||
|
||||
@flag_value
|
||||
def send_messages_in_threads(self) -> int:
|
||||
""":class:`bool`: Returns ``True`` if a user can send messages in threads.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return 1 << 38
|
||||
|
||||
PO = TypeVar('PO', bound='PermissionOverwrite')
|
||||
|
||||
def _augment_from_permissions(cls):
|
||||
@@ -613,12 +654,16 @@ class PermissionOverwrite:
|
||||
manage_permissions: Optional[bool]
|
||||
manage_webhooks: Optional[bool]
|
||||
manage_emojis: Optional[bool]
|
||||
manage_emojis_and_stickers: Optional[bool]
|
||||
use_slash_commands: Optional[bool]
|
||||
request_to_speak: Optional[bool]
|
||||
manage_events: Optional[bool]
|
||||
manage_threads: Optional[bool]
|
||||
use_threads: Optional[bool]
|
||||
use_private_threads: Optional[bool]
|
||||
create_public_threads: Optional[bool]
|
||||
create_private_threads: Optional[bool]
|
||||
send_messages_in_threads: Optional[bool]
|
||||
external_stickers: Optional[bool]
|
||||
use_external_stickers: Optional[bool]
|
||||
|
||||
def __init__(self, **kwargs: Optional[bool]):
|
||||
self._values: Dict[str, Optional[bool]] = {}
|
||||
@@ -641,7 +686,7 @@ class PermissionOverwrite:
|
||||
else:
|
||||
self._values[key] = value
|
||||
|
||||
def pair(self):
|
||||
def pair(self) -> Tuple[Permissions, Permissions]:
|
||||
"""Tuple[:class:`Permissions`, :class:`Permissions`]: Returns the (allow, deny) pair from this overwrite."""
|
||||
|
||||
allow = Permissions.none()
|
||||
|
||||
@@ -50,7 +50,7 @@ if TYPE_CHECKING:
|
||||
AT = TypeVar('AT', bound='AudioSource')
|
||||
FT = TypeVar('FT', bound='FFmpegOpusAudio')
|
||||
|
||||
log: logging.Logger = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = (
|
||||
'AudioSource',
|
||||
@@ -140,13 +140,25 @@ class FFmpegAudio(AudioSource):
|
||||
.. versionadded:: 1.3
|
||||
"""
|
||||
|
||||
def __init__(self, source: str, *, executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any):
|
||||
def __init__(self, source: Union[str, io.BufferedIOBase], *, executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any):
|
||||
piping = subprocess_kwargs.get('stdin') == subprocess.PIPE
|
||||
if piping and isinstance(source, str):
|
||||
raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin")
|
||||
|
||||
args = [executable, *args]
|
||||
kwargs = {'stdout': subprocess.PIPE}
|
||||
kwargs.update(subprocess_kwargs)
|
||||
|
||||
self._process: subprocess.Popen = self._spawn_process(args, **kwargs)
|
||||
self._stdout: IO[bytes] = self._process.stdout # type: ignore
|
||||
self._stdin: Optional[IO[Bytes]] = None
|
||||
self._pipe_thread: Optional[threading.Thread] = None
|
||||
|
||||
if piping:
|
||||
n = f'popen-stdin-writer:{id(self):#x}'
|
||||
self._stdin = self._process.stdin
|
||||
self._pipe_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n)
|
||||
self._pipe_thread.start()
|
||||
|
||||
def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen:
|
||||
process = None
|
||||
@@ -160,26 +172,44 @@ class FFmpegAudio(AudioSource):
|
||||
else:
|
||||
return process
|
||||
|
||||
def cleanup(self) -> None:
|
||||
def _kill_process(self) -> None:
|
||||
proc = self._process
|
||||
if proc is MISSING:
|
||||
return
|
||||
|
||||
log.info('Preparing to terminate ffmpeg process %s.', proc.pid)
|
||||
_log.info('Preparing to terminate ffmpeg process %s.', proc.pid)
|
||||
|
||||
try:
|
||||
proc.kill()
|
||||
except Exception:
|
||||
log.exception("Ignoring error attempting to kill ffmpeg process %s", proc.pid)
|
||||
_log.exception('Ignoring error attempting to kill ffmpeg process %s', proc.pid)
|
||||
|
||||
if proc.poll() is None:
|
||||
log.info('ffmpeg process %s has not terminated. Waiting to terminate...', proc.pid)
|
||||
_log.info('ffmpeg process %s has not terminated. Waiting to terminate...', proc.pid)
|
||||
proc.communicate()
|
||||
log.info('ffmpeg process %s should have terminated with a return code of %s.', proc.pid, proc.returncode)
|
||||
_log.info('ffmpeg process %s should have terminated with a return code of %s.', proc.pid, proc.returncode)
|
||||
else:
|
||||
log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode)
|
||||
_log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode)
|
||||
|
||||
self._process = self._stdout = MISSING
|
||||
|
||||
def _pipe_writer(self, source: io.BufferedIOBase) -> None:
|
||||
while self._process:
|
||||
# arbitrarily large read size
|
||||
data = source.read(8192)
|
||||
if not data:
|
||||
self._process.terminate()
|
||||
return
|
||||
try:
|
||||
self._stdin.write(data)
|
||||
except Exception:
|
||||
_log.debug('Write error for %s, this is probably not a problem', self, exc_info=True)
|
||||
# at this point the source data is either exhausted or the process is fubar
|
||||
self._process.terminate()
|
||||
return
|
||||
|
||||
def cleanup(self) -> None:
|
||||
self._kill_process()
|
||||
self._process = self._stdout = self._stdin = MISSING
|
||||
|
||||
class FFmpegPCMAudio(FFmpegAudio):
|
||||
"""An audio source from FFmpeg (or AVConv).
|
||||
@@ -218,16 +248,16 @@ class FFmpegPCMAudio(FFmpegAudio):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source: str,
|
||||
source: Union[str, io.BufferedIOBase],
|
||||
*,
|
||||
executable: str = 'ffmpeg',
|
||||
pipe: bool = False,
|
||||
stderr: Optional[IO[str]] = None,
|
||||
before_options: Optional[str] = None,
|
||||
before_options: Optional[str] = None,
|
||||
options: Optional[str] = None
|
||||
) -> None:
|
||||
args = []
|
||||
subprocess_kwargs = {'stdin': source if pipe else subprocess.DEVNULL, 'stderr': stderr}
|
||||
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
|
||||
|
||||
if isinstance(before_options, str):
|
||||
args.extend(shlex.split(before_options))
|
||||
@@ -315,7 +345,7 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source: str,
|
||||
source: Union[str, io.BufferedIOBase],
|
||||
*,
|
||||
bitrate: int = 128,
|
||||
codec: Optional[str] = None,
|
||||
@@ -327,7 +357,7 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
) -> None:
|
||||
|
||||
args = []
|
||||
subprocess_kwargs = {'stdin': source if pipe else subprocess.DEVNULL, 'stderr': stderr}
|
||||
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
|
||||
|
||||
if isinstance(before_options, str):
|
||||
args.extend(shlex.split(before_options))
|
||||
@@ -384,7 +414,6 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
|
||||
def custom_probe(source, executable):
|
||||
# some analysis code here
|
||||
|
||||
return codec, bitrate
|
||||
|
||||
source = await discord.FFmpegOpusAudio.from_probe("song.webm", method=custom_probe)
|
||||
@@ -480,18 +509,18 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable)) # type: ignore
|
||||
except Exception:
|
||||
if not fallback:
|
||||
log.exception("Probe '%s' using '%s' failed", method, executable)
|
||||
_log.exception("Probe '%s' using '%s' failed", method, executable)
|
||||
return # type: ignore
|
||||
|
||||
log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable)
|
||||
_log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable)
|
||||
try:
|
||||
codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) # type: ignore
|
||||
except Exception:
|
||||
log.exception("Fallback probe using '%s' failed", executable)
|
||||
_log.exception("Fallback probe using '%s' failed", executable)
|
||||
else:
|
||||
log.info("Fallback probe found codec=%s, bitrate=%s", codec, bitrate)
|
||||
_log.info("Fallback probe found codec=%s, bitrate=%s", codec, bitrate)
|
||||
else:
|
||||
log.info("Probe found codec=%s, bitrate=%s", codec, bitrate)
|
||||
_log.info("Probe found codec=%s, bitrate=%s", codec, bitrate)
|
||||
finally:
|
||||
return codec, bitrate
|
||||
|
||||
@@ -656,12 +685,12 @@ class AudioPlayer(threading.Thread):
|
||||
try:
|
||||
self.after(error)
|
||||
except Exception as exc:
|
||||
log.exception('Calling the after function failed.')
|
||||
_log.exception('Calling the after function failed.')
|
||||
exc.__context__ = error
|
||||
traceback.print_exception(type(exc), exc, exc.__traceback__)
|
||||
elif error:
|
||||
msg = f'Exception in voice thread {self.name}'
|
||||
log.exception(msg, exc_info=error)
|
||||
_log.exception(msg, exc_info=error)
|
||||
print(msg, file=sys.stderr)
|
||||
traceback.print_exception(type(error), error, error.__traceback__)
|
||||
|
||||
@@ -698,4 +727,4 @@ class AudioPlayer(threading.Thread):
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.loop)
|
||||
except Exception as e:
|
||||
log.info("Speaking call in player failed: %s", e)
|
||||
_log.info("Speaking call in player failed: %s", e)
|
||||
|
||||
0
discord/py.typed
Normal file
0
discord/py.typed
Normal file
@@ -22,6 +22,25 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Set, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .types.raw_models import (
|
||||
MessageDeleteEvent,
|
||||
BulkMessageDeleteEvent,
|
||||
ReactionActionEvent,
|
||||
MessageUpdateEvent,
|
||||
ReactionClearEvent,
|
||||
ReactionClearEmojiEvent,
|
||||
IntegrationDeleteEvent
|
||||
)
|
||||
from .message import Message
|
||||
from .partial_emoji import PartialEmoji
|
||||
from .member import Member
|
||||
|
||||
|
||||
from .enums import ChannelType, try_enum
|
||||
|
||||
__all__ = (
|
||||
@@ -35,11 +54,13 @@ __all__ = (
|
||||
'RawThreadDeleteEvent',
|
||||
)
|
||||
|
||||
|
||||
class _RawReprMixin:
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
value = ' '.join(f'{attr}={getattr(self, attr)!r}' for attr in self.__slots__)
|
||||
return f'<{self.__class__.__name__} {value}>'
|
||||
|
||||
|
||||
class RawMessageDeleteEvent(_RawReprMixin):
|
||||
"""Represents the event payload for a :func:`on_raw_message_delete` event.
|
||||
|
||||
@@ -57,14 +78,15 @@ class RawMessageDeleteEvent(_RawReprMixin):
|
||||
|
||||
__slots__ = ('message_id', 'channel_id', 'guild_id', 'cached_message')
|
||||
|
||||
def __init__(self, data):
|
||||
self.message_id = int(data['id'])
|
||||
self.channel_id = int(data['channel_id'])
|
||||
self.cached_message = None
|
||||
def __init__(self, data: MessageDeleteEvent) -> None:
|
||||
self.message_id: int = int(data['id'])
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
self.cached_message: Optional[Message] = None
|
||||
try:
|
||||
self.guild_id = int(data['guild_id'])
|
||||
self.guild_id: Optional[int] = int(data['guild_id'])
|
||||
except KeyError:
|
||||
self.guild_id = None
|
||||
self.guild_id: Optional[int] = None
|
||||
|
||||
|
||||
class RawBulkMessageDeleteEvent(_RawReprMixin):
|
||||
"""Represents the event payload for a :func:`on_raw_bulk_message_delete` event.
|
||||
@@ -83,15 +105,16 @@ class RawBulkMessageDeleteEvent(_RawReprMixin):
|
||||
|
||||
__slots__ = ('message_ids', 'channel_id', 'guild_id', 'cached_messages')
|
||||
|
||||
def __init__(self, data):
|
||||
self.message_ids = {int(x) for x in data.get('ids', [])}
|
||||
self.channel_id = int(data['channel_id'])
|
||||
self.cached_messages = []
|
||||
def __init__(self, data: BulkMessageDeleteEvent) -> None:
|
||||
self.message_ids: Set[int] = {int(x) for x in data.get('ids', [])}
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
self.cached_messages: List[Message] = []
|
||||
|
||||
try:
|
||||
self.guild_id = int(data['guild_id'])
|
||||
self.guild_id: Optional[int] = int(data['guild_id'])
|
||||
except KeyError:
|
||||
self.guild_id = None
|
||||
self.guild_id: Optional[int] = None
|
||||
|
||||
|
||||
class RawMessageUpdateEvent(_RawReprMixin):
|
||||
"""Represents the payload for a :func:`on_raw_message_edit` event.
|
||||
@@ -118,16 +141,17 @@ class RawMessageUpdateEvent(_RawReprMixin):
|
||||
|
||||
__slots__ = ('message_id', 'channel_id', 'guild_id', 'data', 'cached_message')
|
||||
|
||||
def __init__(self, data):
|
||||
self.message_id = int(data['id'])
|
||||
self.channel_id = int(data['channel_id'])
|
||||
self.data = data
|
||||
self.cached_message = None
|
||||
def __init__(self, data: MessageUpdateEvent) -> None:
|
||||
self.message_id: int = int(data['id'])
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
self.data: MessageUpdateEvent = data
|
||||
self.cached_message: Optional[Message] = None
|
||||
|
||||
try:
|
||||
self.guild_id = int(data['guild_id'])
|
||||
self.guild_id: Optional[int] = int(data['guild_id'])
|
||||
except KeyError:
|
||||
self.guild_id = None
|
||||
self.guild_id: Optional[int] = None
|
||||
|
||||
|
||||
class RawReactionActionEvent(_RawReprMixin):
|
||||
"""Represents the payload for a :func:`on_raw_reaction_add` or
|
||||
@@ -161,18 +185,19 @@ class RawReactionActionEvent(_RawReprMixin):
|
||||
__slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji',
|
||||
'event_type', 'member')
|
||||
|
||||
def __init__(self, data, emoji, event_type):
|
||||
self.message_id = int(data['message_id'])
|
||||
self.channel_id = int(data['channel_id'])
|
||||
self.user_id = int(data['user_id'])
|
||||
self.emoji = emoji
|
||||
self.event_type = event_type
|
||||
self.member = None
|
||||
def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None:
|
||||
self.message_id: int = int(data['message_id'])
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
self.user_id: int = int(data['user_id'])
|
||||
self.emoji: PartialEmoji = emoji
|
||||
self.event_type: str = event_type
|
||||
self.member: Optional[Member] = None
|
||||
|
||||
try:
|
||||
self.guild_id = int(data['guild_id'])
|
||||
self.guild_id: Optional[int] = int(data['guild_id'])
|
||||
except KeyError:
|
||||
self.guild_id = None
|
||||
self.guild_id: Optional[int] = None
|
||||
|
||||
|
||||
class RawReactionClearEvent(_RawReprMixin):
|
||||
"""Represents the payload for a :func:`on_raw_reaction_clear` event.
|
||||
@@ -189,14 +214,15 @@ class RawReactionClearEvent(_RawReprMixin):
|
||||
|
||||
__slots__ = ('message_id', 'channel_id', 'guild_id')
|
||||
|
||||
def __init__(self, data):
|
||||
self.message_id = int(data['message_id'])
|
||||
self.channel_id = int(data['channel_id'])
|
||||
def __init__(self, data: ReactionClearEvent) -> None:
|
||||
self.message_id: int = int(data['message_id'])
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
|
||||
try:
|
||||
self.guild_id = int(data['guild_id'])
|
||||
self.guild_id: Optional[int] = int(data['guild_id'])
|
||||
except KeyError:
|
||||
self.guild_id = None
|
||||
self.guild_id: Optional[int] = None
|
||||
|
||||
|
||||
class RawReactionClearEmojiEvent(_RawReprMixin):
|
||||
"""Represents the payload for a :func:`on_raw_reaction_clear_emoji` event.
|
||||
@@ -217,15 +243,16 @@ class RawReactionClearEmojiEvent(_RawReprMixin):
|
||||
|
||||
__slots__ = ('message_id', 'channel_id', 'guild_id', 'emoji')
|
||||
|
||||
def __init__(self, data, emoji):
|
||||
self.emoji = emoji
|
||||
self.message_id = int(data['message_id'])
|
||||
self.channel_id = int(data['channel_id'])
|
||||
def __init__(self, data: ReactionClearEmojiEvent, emoji: PartialEmoji) -> None:
|
||||
self.emoji: PartialEmoji = emoji
|
||||
self.message_id: int = int(data['message_id'])
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
|
||||
try:
|
||||
self.guild_id = int(data['guild_id'])
|
||||
self.guild_id: Optional[int] = int(data['guild_id'])
|
||||
except KeyError:
|
||||
self.guild_id = None
|
||||
self.guild_id: Optional[int] = None
|
||||
|
||||
|
||||
class RawIntegrationDeleteEvent(_RawReprMixin):
|
||||
"""Represents the payload for a :func:`on_raw_integration_delete` event.
|
||||
@@ -244,14 +271,14 @@ class RawIntegrationDeleteEvent(_RawReprMixin):
|
||||
|
||||
__slots__ = ('integration_id', 'application_id', 'guild_id')
|
||||
|
||||
def __init__(self, data):
|
||||
self.integration_id = int(data['id'])
|
||||
self.guild_id = int(data['guild_id'])
|
||||
def __init__(self, data: IntegrationDeleteEvent) -> None:
|
||||
self.integration_id: int = int(data['id'])
|
||||
self.guild_id: int = int(data['guild_id'])
|
||||
|
||||
try:
|
||||
self.application_id = int(data['application_id'])
|
||||
self.application_id: Optional[int] = int(data['application_id'])
|
||||
except KeyError:
|
||||
self.application_id = None
|
||||
self.application_id: Optional[int] = None
|
||||
|
||||
class RawThreadDeleteEvent(_RawReprMixin):
|
||||
"""Represents the payload for a :func:`on_raw_thread_delete` event.
|
||||
|
||||
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
|
||||
Role as RolePayload,
|
||||
RoleTags as RoleTagPayload,
|
||||
)
|
||||
from .types.guild import RolePositionUpdate
|
||||
from .guild import Guild
|
||||
from .member import Member
|
||||
from .state import ConnectionState
|
||||
@@ -140,6 +141,14 @@ class Role(Hashable):
|
||||
|
||||
Returns the role's name.
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the role's ID.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the role's ID.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
id: :class:`int`
|
||||
@@ -194,6 +203,9 @@ class Role(Hashable):
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Role id={self.id} name={self.name!r}>'
|
||||
|
||||
@@ -336,7 +348,7 @@ class Role(Hashable):
|
||||
else:
|
||||
roles.append(self.id)
|
||||
|
||||
payload = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)]
|
||||
payload: List[RolePositionUpdate] = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)]
|
||||
await http.move_role_position(self.guild.id, payload, reason=reason)
|
||||
|
||||
async def edit(
|
||||
@@ -350,7 +362,7 @@ class Role(Hashable):
|
||||
mentionable: bool = MISSING,
|
||||
position: int = MISSING,
|
||||
reason: Optional[str] = MISSING,
|
||||
) -> None:
|
||||
) -> Optional[Role]:
|
||||
"""|coro|
|
||||
|
||||
Edits the role.
|
||||
@@ -363,6 +375,9 @@ class Role(Hashable):
|
||||
.. versionchanged:: 1.4
|
||||
Can now pass ``int`` to ``colour`` keyword-only parameter.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Edits are no longer in-place, the newly edited role is returned instead.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@@ -390,11 +405,14 @@ class Role(Hashable):
|
||||
InvalidArgument
|
||||
An invalid position was given or the default
|
||||
role was asked to be moved.
|
||||
"""
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Role`
|
||||
The newly edited role.
|
||||
"""
|
||||
if position is not MISSING:
|
||||
await self._move(position, reason=reason)
|
||||
self.position = position
|
||||
|
||||
payload: Dict[str, Any] = {}
|
||||
if color is not MISSING:
|
||||
@@ -419,7 +437,7 @@ class Role(Hashable):
|
||||
payload['mentionable'] = mentionable
|
||||
|
||||
data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload)
|
||||
self._update(data)
|
||||
return Role(guild=self.guild, data=data, state=self._state)
|
||||
|
||||
async def delete(self, *, reason: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
180
discord/shard.py
180
discord/shard.py
@@ -22,8 +22,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import logging
|
||||
|
||||
import aiohttp
|
||||
@@ -34,22 +35,30 @@ from .backoff import ExponentialBackoff
|
||||
from .gateway import *
|
||||
from .errors import (
|
||||
ClientException,
|
||||
InvalidArgument,
|
||||
HTTPException,
|
||||
GatewayNotFound,
|
||||
ConnectionClosed,
|
||||
PrivilegedIntentsRequired,
|
||||
)
|
||||
|
||||
from . import utils
|
||||
from .enums import Status
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .gateway import DiscordWebSocket
|
||||
from .activity import BaseActivity
|
||||
from .enums import Status
|
||||
|
||||
EI = TypeVar('EI', bound='EventItem')
|
||||
|
||||
__all__ = (
|
||||
'AutoShardedClient',
|
||||
'ShardInfo',
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventType:
|
||||
close = 0
|
||||
@@ -59,39 +68,41 @@ class EventType:
|
||||
terminate = 4
|
||||
clean_close = 5
|
||||
|
||||
|
||||
class EventItem:
|
||||
__slots__ = ('type', 'shard', 'error')
|
||||
|
||||
def __init__(self, etype, shard, error):
|
||||
self.type = etype
|
||||
self.shard = shard
|
||||
self.error = error
|
||||
def __init__(self, etype: int, shard: Optional['Shard'], error: Optional[Exception]) -> None:
|
||||
self.type: int = etype
|
||||
self.shard: Optional['Shard'] = shard
|
||||
self.error: Optional[Exception] = error
|
||||
|
||||
def __lt__(self, other):
|
||||
def __lt__(self: EI, other: EI) -> bool:
|
||||
if not isinstance(other, EventItem):
|
||||
return NotImplemented
|
||||
return self.type < other.type
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self: EI, other: EI) -> bool:
|
||||
if not isinstance(other, EventItem):
|
||||
return NotImplemented
|
||||
return self.type == other.type
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.type)
|
||||
|
||||
|
||||
class Shard:
|
||||
def __init__(self, ws, client, queue_put):
|
||||
self.ws = ws
|
||||
self._client = client
|
||||
self._dispatch = client.dispatch
|
||||
self._queue_put = queue_put
|
||||
self.loop = self._client.loop
|
||||
self._disconnect = False
|
||||
def __init__(self, ws: DiscordWebSocket, client: AutoShardedClient, queue_put: Callable[[EventItem], None]) -> None:
|
||||
self.ws: DiscordWebSocket = ws
|
||||
self._client: Client = client
|
||||
self._dispatch: Callable[..., None] = client.dispatch
|
||||
self._queue_put: Callable[[EventItem], None] = queue_put
|
||||
self.loop: asyncio.AbstractEventLoop = self._client.loop
|
||||
self._disconnect: bool = False
|
||||
self._reconnect = client._reconnect
|
||||
self._backoff = ExponentialBackoff()
|
||||
self._task = None
|
||||
self._handled_exceptions = (
|
||||
self._backoff: ExponentialBackoff = ExponentialBackoff()
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._handled_exceptions: Tuple[Type[Exception], ...] = (
|
||||
OSError,
|
||||
HTTPException,
|
||||
GatewayNotFound,
|
||||
@@ -101,25 +112,26 @@ class Shard:
|
||||
)
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self.ws.shard_id
|
||||
def id(self) -> int:
|
||||
# DiscordWebSocket.shard_id is set in the from_client classmethod
|
||||
return self.ws.shard_id # type: ignore
|
||||
|
||||
def launch(self):
|
||||
def launch(self) -> None:
|
||||
self._task = self.loop.create_task(self.worker())
|
||||
|
||||
def _cancel_task(self):
|
||||
def _cancel_task(self) -> None:
|
||||
if self._task is not None and not self._task.done():
|
||||
self._task.cancel()
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
self._cancel_task()
|
||||
await self.ws.close(code=1000)
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
await self.close()
|
||||
self._dispatch('shard_disconnect', self.id)
|
||||
|
||||
async def _handle_disconnect(self, e):
|
||||
async def _handle_disconnect(self, e: Exception) -> None:
|
||||
self._dispatch('disconnect')
|
||||
self._dispatch('shard_disconnect', self.id)
|
||||
if not self._reconnect:
|
||||
@@ -144,11 +156,11 @@ class Shard:
|
||||
return
|
||||
|
||||
retry = self._backoff.delay()
|
||||
log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e)
|
||||
_log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e)
|
||||
await asyncio.sleep(retry)
|
||||
self._queue_put(EventItem(EventType.reconnect, self, e))
|
||||
|
||||
async def worker(self):
|
||||
async def worker(self) -> None:
|
||||
while not self._client.is_closed():
|
||||
try:
|
||||
await self.ws.poll_event()
|
||||
@@ -165,14 +177,19 @@ class Shard:
|
||||
self._queue_put(EventItem(EventType.terminate, self, e))
|
||||
break
|
||||
|
||||
async def reidentify(self, exc):
|
||||
async def reidentify(self, exc: ReconnectWebSocket) -> None:
|
||||
self._cancel_task()
|
||||
self._dispatch('disconnect')
|
||||
self._dispatch('shard_disconnect', self.id)
|
||||
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
|
||||
_log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
|
||||
try:
|
||||
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
|
||||
session=self.ws.session_id, sequence=self.ws.sequence)
|
||||
coro = DiscordWebSocket.from_client(
|
||||
self._client,
|
||||
resume=exc.resume,
|
||||
shard_id=self.id,
|
||||
session=self.ws.session_id,
|
||||
sequence=self.ws.sequence,
|
||||
)
|
||||
self.ws = await asyncio.wait_for(coro, timeout=60.0)
|
||||
except self._handled_exceptions as e:
|
||||
await self._handle_disconnect(e)
|
||||
@@ -183,7 +200,7 @@ class Shard:
|
||||
else:
|
||||
self.launch()
|
||||
|
||||
async def reconnect(self):
|
||||
async def reconnect(self) -> None:
|
||||
self._cancel_task()
|
||||
try:
|
||||
coro = DiscordWebSocket.from_client(self._client, shard_id=self.id)
|
||||
@@ -197,6 +214,7 @@ class Shard:
|
||||
else:
|
||||
self.launch()
|
||||
|
||||
|
||||
class ShardInfo:
|
||||
"""A class that gives information and control over a specific shard.
|
||||
|
||||
@@ -215,16 +233,16 @@ class ShardInfo:
|
||||
|
||||
__slots__ = ('_parent', 'id', 'shard_count')
|
||||
|
||||
def __init__(self, parent, shard_count):
|
||||
self._parent = parent
|
||||
self.id = parent.id
|
||||
self.shard_count = shard_count
|
||||
def __init__(self, parent: Shard, shard_count: Optional[int]) -> None:
|
||||
self._parent: Shard = parent
|
||||
self.id: int = parent.id
|
||||
self.shard_count: Optional[int] = shard_count
|
||||
|
||||
def is_closed(self):
|
||||
def is_closed(self) -> bool:
|
||||
""":class:`bool`: Whether the shard connection is currently closed."""
|
||||
return not self._parent.ws.open
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Disconnects a shard. When this is called, the shard connection will no
|
||||
@@ -237,7 +255,7 @@ class ShardInfo:
|
||||
|
||||
await self._parent.disconnect()
|
||||
|
||||
async def reconnect(self):
|
||||
async def reconnect(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Disconnects and then connects the shard again.
|
||||
@@ -246,7 +264,7 @@ class ShardInfo:
|
||||
await self._parent.disconnect()
|
||||
await self._parent.reconnect()
|
||||
|
||||
async def connect(self):
|
||||
async def connect(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Connects a shard. If the shard is already connected this does nothing.
|
||||
@@ -257,11 +275,11 @@ class ShardInfo:
|
||||
await self._parent.reconnect()
|
||||
|
||||
@property
|
||||
def latency(self):
|
||||
def latency(self) -> float:
|
||||
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds for this shard."""
|
||||
return self._parent.ws.latency
|
||||
|
||||
def is_ws_ratelimited(self):
|
||||
def is_ws_ratelimited(self) -> bool:
|
||||
""":class:`bool`: Whether the websocket is currently rate limited.
|
||||
|
||||
This can be useful to know when deciding whether you should query members
|
||||
@@ -271,6 +289,7 @@ class ShardInfo:
|
||||
"""
|
||||
return self._parent.ws.is_ratelimited()
|
||||
|
||||
|
||||
class AutoShardedClient(Client):
|
||||
"""A client similar to :class:`Client` except it handles the complications
|
||||
of sharding for the user into a more manageable and transparent single
|
||||
@@ -297,9 +316,13 @@ class AutoShardedClient(Client):
|
||||
shard_ids: Optional[List[:class:`int`]]
|
||||
An optional list of shard_ids to launch the shards with.
|
||||
"""
|
||||
def __init__(self, *args, loop=None, **kwargs):
|
||||
|
||||
if TYPE_CHECKING:
|
||||
_connection: AutoShardedConnectionState
|
||||
|
||||
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None:
|
||||
kwargs.pop('shard_id', None)
|
||||
self.shard_ids = kwargs.pop('shard_ids', None)
|
||||
self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None)
|
||||
super().__init__(*args, loop=loop, **kwargs)
|
||||
|
||||
if self.shard_ids is not None:
|
||||
@@ -315,18 +338,24 @@ class AutoShardedClient(Client):
|
||||
self._connection._get_client = lambda: self
|
||||
self.__queue = asyncio.PriorityQueue()
|
||||
|
||||
def _get_websocket(self, guild_id=None, *, shard_id=None):
|
||||
def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket:
|
||||
if shard_id is None:
|
||||
shard_id = (guild_id >> 22) % self.shard_count
|
||||
# guild_id won't be None if shard_id is None and shard_count won't be None here
|
||||
shard_id = (guild_id >> 22) % self.shard_count # type: ignore
|
||||
return self.__shards[shard_id].ws
|
||||
|
||||
def _get_state(self, **options):
|
||||
return AutoShardedConnectionState(dispatch=self.dispatch,
|
||||
handlers=self._handlers,
|
||||
hooks=self._hooks, http=self.http, loop=self.loop, **options)
|
||||
def _get_state(self, **options: Any) -> AutoShardedConnectionState:
|
||||
return AutoShardedConnectionState(
|
||||
dispatch=self.dispatch,
|
||||
handlers=self._handlers,
|
||||
hooks=self._hooks,
|
||||
http=self.http,
|
||||
loop=self.loop,
|
||||
**options,
|
||||
)
|
||||
|
||||
@property
|
||||
def latency(self):
|
||||
def latency(self) -> float:
|
||||
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
|
||||
|
||||
This operates similarly to :meth:`Client.latency` except it uses the average
|
||||
@@ -338,14 +367,14 @@ class AutoShardedClient(Client):
|
||||
return sum(latency for _, latency in self.latencies) / len(self.__shards)
|
||||
|
||||
@property
|
||||
def latencies(self):
|
||||
def latencies(self) -> List[Tuple[int, float]]:
|
||||
"""List[Tuple[:class:`int`, :class:`float`]]: A list of latencies between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
|
||||
|
||||
This returns a list of tuples with elements ``(shard_id, latency)``.
|
||||
"""
|
||||
return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()]
|
||||
|
||||
def get_shard(self, shard_id):
|
||||
def get_shard(self, shard_id: int) -> Optional[ShardInfo]:
|
||||
"""Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found."""
|
||||
try:
|
||||
parent = self.__shards[shard_id]
|
||||
@@ -355,16 +384,16 @@ class AutoShardedClient(Client):
|
||||
return ShardInfo(parent, self.shard_count)
|
||||
|
||||
@property
|
||||
def shards(self):
|
||||
def shards(self) -> Dict[int, ShardInfo]:
|
||||
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
|
||||
return { shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items() }
|
||||
return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()}
|
||||
|
||||
async def launch_shard(self, gateway, shard_id, *, initial=False):
|
||||
async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None:
|
||||
try:
|
||||
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
|
||||
ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||
except Exception:
|
||||
log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id)
|
||||
_log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id)
|
||||
await asyncio.sleep(5.0)
|
||||
return await self.launch_shard(gateway, shard_id)
|
||||
|
||||
@@ -372,7 +401,7 @@ class AutoShardedClient(Client):
|
||||
self.__shards[shard_id] = ret = Shard(ws, self, self.__queue.put_nowait)
|
||||
ret.launch()
|
||||
|
||||
async def launch_shards(self):
|
||||
async def launch_shards(self) -> None:
|
||||
if self.shard_count is None:
|
||||
self.shard_count, gateway = await self.http.get_bot_gateway()
|
||||
else:
|
||||
@@ -389,7 +418,7 @@ class AutoShardedClient(Client):
|
||||
|
||||
self._connection.shards_launched.set()
|
||||
|
||||
async def connect(self, *, reconnect=True):
|
||||
async def connect(self, *, reconnect: bool = True) -> None:
|
||||
self._reconnect = reconnect
|
||||
await self.launch_shards()
|
||||
|
||||
@@ -413,7 +442,7 @@ class AutoShardedClient(Client):
|
||||
elif item.type == EventType.clean_close:
|
||||
return
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Closes the connection to Discord.
|
||||
@@ -425,7 +454,7 @@ class AutoShardedClient(Client):
|
||||
|
||||
for vc in self.voice_clients:
|
||||
try:
|
||||
await vc.disconnect()
|
||||
await vc.disconnect(force=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -436,7 +465,13 @@ class AutoShardedClient(Client):
|
||||
await self.http.close()
|
||||
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
|
||||
|
||||
async def change_presence(self, *, activity=None, status=None, shard_id=None):
|
||||
async def change_presence(
|
||||
self,
|
||||
*,
|
||||
activity: Optional[BaseActivity] = None,
|
||||
status: Optional[Status] = None,
|
||||
shard_id: int = None,
|
||||
) -> None:
|
||||
"""|coro|
|
||||
|
||||
Changes the client's presence.
|
||||
@@ -468,23 +503,23 @@ class AutoShardedClient(Client):
|
||||
"""
|
||||
|
||||
if status is None:
|
||||
status = 'online'
|
||||
status_value = 'online'
|
||||
status_enum = Status.online
|
||||
elif status is Status.offline:
|
||||
status = 'invisible'
|
||||
status_value = 'invisible'
|
||||
status_enum = Status.offline
|
||||
else:
|
||||
status_enum = status
|
||||
status = str(status)
|
||||
status_value = str(status)
|
||||
|
||||
if shard_id is None:
|
||||
for shard in self.__shards.values():
|
||||
await shard.ws.change_presence(activity=activity, status=status)
|
||||
await shard.ws.change_presence(activity=activity, status=status_value)
|
||||
|
||||
guilds = self._connection.guilds
|
||||
else:
|
||||
shard = self.__shards[shard_id]
|
||||
await shard.ws.change_presence(activity=activity, status=status)
|
||||
await shard.ws.change_presence(activity=activity, status=status_value)
|
||||
guilds = [g for g in self._connection.guilds if g.shard_id == shard_id]
|
||||
|
||||
activities = () if activity is None else (activity,)
|
||||
@@ -493,10 +528,11 @@ class AutoShardedClient(Client):
|
||||
if me is None:
|
||||
continue
|
||||
|
||||
me.activities = activities
|
||||
# Member.activities is typehinted as Tuple[ActivityType, ...], we may be setting it as Tuple[BaseActivity, ...]
|
||||
me.activities = activities # type: ignore
|
||||
me.status = status_enum
|
||||
|
||||
def is_ws_ratelimited(self):
|
||||
def is_ws_ratelimited(self) -> bool:
|
||||
""":class:`bool`: Whether the websocket is currently rate limited.
|
||||
|
||||
This can be useful to know when deciding whether you should query members
|
||||
|
||||
@@ -61,6 +61,10 @@ class StageInstance(Hashable):
|
||||
|
||||
Returns the stage instance's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the stage instance's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
id: :class:`int`
|
||||
@@ -74,7 +78,7 @@ class StageInstance(Hashable):
|
||||
privacy_level: :class:`StagePrivacyLevel`
|
||||
The privacy level of the stage instance.
|
||||
discoverable_disabled: :class:`bool`
|
||||
Whether the stage instance is discoverable.
|
||||
Whether discoverability for the stage instance is disabled.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
@@ -97,21 +101,22 @@ class StageInstance(Hashable):
|
||||
self.id: int = int(data['id'])
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
self.topic: str = data['topic']
|
||||
self.privacy_level = try_enum(StagePrivacyLevel, data['privacy_level'])
|
||||
self.discoverable_disabled = data['discoverable_disabled']
|
||||
self.privacy_level: StagePrivacyLevel = try_enum(StagePrivacyLevel, data['privacy_level'])
|
||||
self.discoverable_disabled: bool = data.get('discoverable_disabled', False)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>'
|
||||
|
||||
@cached_slot_property('_cs_channel')
|
||||
def channel(self) -> Optional[StageChannel]:
|
||||
"""Optional[:class:`StageChannel`: The guild that stage instance is running in."""
|
||||
return self._state.get_channel(self.channel_id)
|
||||
"""Optional[:class:`StageChannel`]: The channel that stage instance is running in."""
|
||||
# the returned channel will always be a StageChannel or None
|
||||
return self._state.get_channel(self.channel_id) # type: ignore
|
||||
|
||||
def is_public(self) -> bool:
|
||||
return self.privacy_level is StagePrivacyLevel.public
|
||||
|
||||
async def edit(self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING) -> None:
|
||||
async def edit(self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
Edits the stage instance.
|
||||
@@ -125,6 +130,8 @@ class StageInstance(Hashable):
|
||||
The stage instance's new topic.
|
||||
privacy_level: :class:`StagePrivacyLevel`
|
||||
The stage instance's new privacy level.
|
||||
reason: :class:`str`
|
||||
The reason the stage instance was edited. Shows up on the audit log.
|
||||
|
||||
Raises
|
||||
------
|
||||
@@ -148,9 +155,9 @@ class StageInstance(Hashable):
|
||||
payload['privacy_level'] = privacy_level.value
|
||||
|
||||
if payload:
|
||||
await self._state.http.edit_stage_instance(self.channel_id, **payload)
|
||||
await self._state.http.edit_stage_instance(self.channel_id, **payload, reason=reason)
|
||||
|
||||
async def delete(self) -> None:
|
||||
async def delete(self, *, reason: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
Deletes the stage instance.
|
||||
@@ -158,6 +165,11 @@ class StageInstance(Hashable):
|
||||
You must have the :attr:`~Permissions.manage_channels` permission to
|
||||
use this.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
reason: :class:`str`
|
||||
The reason the stage instance was deleted. Shows up on the audit log.
|
||||
|
||||
Raises
|
||||
------
|
||||
Forbidden
|
||||
@@ -165,4 +177,4 @@ class StageInstance(Hashable):
|
||||
HTTPException
|
||||
Deleting the stage instance failed.
|
||||
"""
|
||||
await self._state.http.delete_stage_instance(self.channel_id)
|
||||
await self._state.http.delete_stage_instance(self.channel_id, reason=reason)
|
||||
|
||||
639
discord/state.py
639
discord/state.py
File diff suppressed because it is too large
Load Diff
@@ -23,24 +23,224 @@ DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import Literal, TYPE_CHECKING, List, Optional, Tuple, Type, Union
|
||||
import unicodedata
|
||||
|
||||
from .mixins import Hashable
|
||||
from .asset import Asset
|
||||
from .utils import snowflake_time
|
||||
from .enums import StickerType, try_enum
|
||||
from .asset import Asset, AssetMixin
|
||||
from .utils import cached_slot_property, find, snowflake_time, get, MISSING
|
||||
from .errors import InvalidData
|
||||
from .enums import StickerType, StickerFormatType, try_enum
|
||||
|
||||
__all__ = (
|
||||
'StickerPack',
|
||||
'StickerItem',
|
||||
'Sticker',
|
||||
'StandardSticker',
|
||||
'GuildSticker',
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import datetime
|
||||
from .state import ConnectionState
|
||||
from .types.message import Sticker as StickerPayload
|
||||
from .user import User
|
||||
from .guild import Guild
|
||||
from .types.sticker import (
|
||||
StickerPack as StickerPackPayload,
|
||||
StickerItem as StickerItemPayload,
|
||||
Sticker as StickerPayload,
|
||||
StandardSticker as StandardStickerPayload,
|
||||
GuildSticker as GuildStickerPayload,
|
||||
ListPremiumStickerPacks as ListPremiumStickerPacksPayload,
|
||||
EditGuildSticker,
|
||||
)
|
||||
|
||||
|
||||
class Sticker(Hashable):
|
||||
class StickerPack(Hashable):
|
||||
"""Represents a sticker pack.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. container:: operations
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the name of the sticker pack.
|
||||
|
||||
.. describe:: hash(x)
|
||||
|
||||
Returns the hash of the sticker pack.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the ID of the sticker pack.
|
||||
|
||||
.. describe:: x == y
|
||||
|
||||
Checks if the sticker pack is equal to another sticker pack.
|
||||
|
||||
.. describe:: x != y
|
||||
|
||||
Checks if the sticker pack is not equal to another sticker pack.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
The name of the sticker pack.
|
||||
description: :class:`str`
|
||||
The description of the sticker pack.
|
||||
id: :class:`int`
|
||||
The id of the sticker pack.
|
||||
stickers: List[:class:`StandardSticker`]
|
||||
The stickers of this sticker pack.
|
||||
sku_id: :class:`int`
|
||||
The SKU ID of the sticker pack.
|
||||
cover_sticker_id: :class:`int`
|
||||
The ID of the sticker used for the cover of the sticker pack.
|
||||
cover_sticker: :class:`StandardSticker`
|
||||
The sticker used for the cover of the sticker pack.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'_state',
|
||||
'id',
|
||||
'stickers',
|
||||
'name',
|
||||
'sku_id',
|
||||
'cover_sticker_id',
|
||||
'cover_sticker',
|
||||
'description',
|
||||
'_banner',
|
||||
)
|
||||
|
||||
def __init__(self, *, state: ConnectionState, data: StickerPackPayload) -> None:
|
||||
self._state: ConnectionState = state
|
||||
self._from_data(data)
|
||||
|
||||
def _from_data(self, data: StickerPackPayload) -> None:
|
||||
self.id: int = int(data['id'])
|
||||
stickers = data['stickers']
|
||||
self.stickers: List[StandardSticker] = [StandardSticker(state=self._state, data=sticker) for sticker in stickers]
|
||||
self.name: str = data['name']
|
||||
self.sku_id: int = int(data['sku_id'])
|
||||
self.cover_sticker_id: int = int(data['cover_sticker_id'])
|
||||
self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore
|
||||
self.description: str = data['description']
|
||||
self._banner: int = int(data['banner_asset_id'])
|
||||
|
||||
@property
|
||||
def banner(self) -> Asset:
|
||||
""":class:`Asset`: The banner asset of the sticker pack."""
|
||||
return Asset._from_sticker_banner(self._state, self._banner)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<StickerPack id={self.id} name={self.name!r} description={self.description!r}>'
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
class _StickerTag(Hashable, AssetMixin):
|
||||
__slots__ = ()
|
||||
|
||||
id: int
|
||||
format: StickerFormatType
|
||||
|
||||
async def read(self) -> bytes:
|
||||
"""|coro|
|
||||
|
||||
Retrieves the content of this sticker as a :class:`bytes` object.
|
||||
|
||||
.. note::
|
||||
|
||||
Stickers that use the :attr:`StickerFormatType.lottie` format cannot be read.
|
||||
|
||||
Raises
|
||||
------
|
||||
HTTPException
|
||||
Downloading the asset failed.
|
||||
NotFound
|
||||
The asset was deleted.
|
||||
TypeError
|
||||
The sticker is a lottie type.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class:`bytes`
|
||||
The content of the asset.
|
||||
"""
|
||||
if self.format is StickerFormatType.lottie:
|
||||
raise TypeError('Cannot read stickers of format "lottie".')
|
||||
return await super().read()
|
||||
|
||||
|
||||
class StickerItem(_StickerTag):
|
||||
"""Represents a sticker item.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. container:: operations
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the name of the sticker item.
|
||||
|
||||
.. describe:: x == y
|
||||
|
||||
Checks if the sticker item is equal to another sticker item.
|
||||
|
||||
.. describe:: x != y
|
||||
|
||||
Checks if the sticker item is not equal to another sticker item.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
The sticker's name.
|
||||
id: :class:`int`
|
||||
The id of the sticker.
|
||||
format: :class:`StickerFormatType`
|
||||
The format for the sticker's image.
|
||||
url: :class:`str`
|
||||
The URL for the sticker's image.
|
||||
"""
|
||||
|
||||
__slots__ = ('_state', 'name', 'id', 'format', 'url')
|
||||
|
||||
def __init__(self, *, state: ConnectionState, data: StickerItemPayload):
|
||||
self._state: ConnectionState = state
|
||||
self.name: str = data['name']
|
||||
self.id: int = int(data['id'])
|
||||
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type'])
|
||||
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<StickerItem id={self.id} name={self.name!r} format={self.format}>'
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
async def fetch(self) -> Union[Sticker, StandardSticker, GuildSticker]:
|
||||
"""|coro|
|
||||
|
||||
Attempts to retrieve the full sticker data of the sticker item.
|
||||
|
||||
Raises
|
||||
--------
|
||||
HTTPException
|
||||
Retrieving the sticker failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Union[:class:`StandardSticker`, :class:`GuildSticker`]
|
||||
The retrieved sticker.
|
||||
"""
|
||||
data: StickerPayload = await self._state.http.get_sticker(self.id)
|
||||
cls, _ = _sticker_factory(data['type']) # type: ignore
|
||||
return cls(state=self._state, data=data)
|
||||
|
||||
|
||||
class Sticker(_StickerTag):
|
||||
"""Represents a sticker.
|
||||
|
||||
.. versionadded:: 1.6
|
||||
@@ -69,30 +269,27 @@ class Sticker(Hashable):
|
||||
The description of the sticker.
|
||||
pack_id: :class:`int`
|
||||
The id of the sticker's pack.
|
||||
format: :class:`StickerType`
|
||||
format: :class:`StickerFormatType`
|
||||
The format for the sticker's image.
|
||||
tags: List[:class:`str`]
|
||||
A list of tags for the sticker.
|
||||
url: :class:`str`
|
||||
The URL for the sticker's image.
|
||||
"""
|
||||
|
||||
__slots__ = ('_state', 'id', 'name', 'description', 'pack_id', 'format', '_image', 'tags')
|
||||
__slots__ = ('_state', 'id', 'name', 'description', 'format', 'url')
|
||||
|
||||
def __init__(self, *, state: ConnectionState, data: StickerPayload):
|
||||
def __init__(self, *, state: ConnectionState, data: StickerPayload) -> None:
|
||||
self._state: ConnectionState = state
|
||||
self._from_data(data)
|
||||
|
||||
def _from_data(self, data: StickerPayload) -> None:
|
||||
self.id: int = int(data['id'])
|
||||
self.name: str = data['name']
|
||||
self.description: str = data['description']
|
||||
self.pack_id: int = int(data.get('pack_id', 0))
|
||||
self.format: StickerType = try_enum(StickerType, data['format_type'])
|
||||
self._image: str = data['asset']
|
||||
|
||||
try:
|
||||
self.tags: List[str] = [tag.strip() for tag in data['tags'].split(',')]
|
||||
except KeyError:
|
||||
self.tags = []
|
||||
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type'])
|
||||
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__} id={self.id} name={self.name!r}>'
|
||||
return f'<Sticker id={self.id} name={self.name!r}>'
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
@@ -102,19 +299,233 @@ class Sticker(Hashable):
|
||||
""":class:`datetime.datetime`: Returns the sticker's creation time in UTC."""
|
||||
return snowflake_time(self.id)
|
||||
|
||||
@property
|
||||
def image(self) -> Optional[Asset]:
|
||||
"""Returns an :class:`Asset` for the sticker's image.
|
||||
|
||||
.. note::
|
||||
This will return ``None`` if the format is ``StickerType.lottie``.
|
||||
class StandardSticker(Sticker):
|
||||
"""Represents a sticker that is found in a standard sticker pack.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. container:: operations
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the name of the sticker.
|
||||
|
||||
.. describe:: x == y
|
||||
|
||||
Checks if the sticker is equal to another sticker.
|
||||
|
||||
.. describe:: x != y
|
||||
|
||||
Checks if the sticker is not equal to another sticker.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name: :class:`str`
|
||||
The sticker's name.
|
||||
id: :class:`int`
|
||||
The id of the sticker.
|
||||
description: :class:`str`
|
||||
The description of the sticker.
|
||||
pack_id: :class:`int`
|
||||
The id of the sticker's pack.
|
||||
format: :class:`StickerFormatType`
|
||||
The format for the sticker's image.
|
||||
tags: List[:class:`str`]
|
||||
A list of tags for the sticker.
|
||||
sort_value: :class:`int`
|
||||
The sticker's sort order within its pack.
|
||||
"""
|
||||
|
||||
__slots__ = ('sort_value', 'pack_id', 'type', 'tags')
|
||||
|
||||
def _from_data(self, data: StandardStickerPayload) -> None:
|
||||
super()._from_data(data)
|
||||
self.sort_value: int = data['sort_value']
|
||||
self.pack_id: int = int(data['pack_id'])
|
||||
self.type: StickerType = StickerType.standard
|
||||
|
||||
try:
|
||||
self.tags: List[str] = [tag.strip() for tag in data['tags'].split(',')]
|
||||
except KeyError:
|
||||
self.tags = []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<StandardSticker id={self.id} name={self.name!r} pack_id={self.pack_id}>'
|
||||
|
||||
async def pack(self) -> StickerPack:
|
||||
"""|coro|
|
||||
|
||||
Retrieves the sticker pack that this sticker belongs to.
|
||||
|
||||
Raises
|
||||
--------
|
||||
InvalidData
|
||||
The corresponding sticker pack was not found.
|
||||
HTTPException
|
||||
Retrieving the sticker pack failed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[:class:`Asset`]
|
||||
The resulting CDN asset.
|
||||
--------
|
||||
:class:`StickerPack`
|
||||
The retrieved sticker pack.
|
||||
"""
|
||||
if self.format is StickerType.lottie:
|
||||
return None
|
||||
data: ListPremiumStickerPacksPayload = await self._state.http.list_premium_sticker_packs()
|
||||
packs = data['sticker_packs']
|
||||
pack = find(lambda d: int(d['id']) == self.pack_id, packs)
|
||||
|
||||
return Asset._from_sticker(self._state, self.id, self._image)
|
||||
if pack:
|
||||
return StickerPack(state=self._state, data=pack)
|
||||
raise InvalidData(f'Could not find corresponding sticker pack for {self!r}')
|
||||
|
||||
|
||||
class GuildSticker(Sticker):
|
||||
"""Represents a sticker that belongs to a guild.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. container:: operations
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the name of the sticker.
|
||||
|
||||
.. describe:: x == y
|
||||
|
||||
Checks if the sticker is equal to another sticker.
|
||||
|
||||
.. describe:: x != y
|
||||
|
||||
Checks if the sticker is not equal to another sticker.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name: :class:`str`
|
||||
The sticker's name.
|
||||
id: :class:`int`
|
||||
The id of the sticker.
|
||||
description: :class:`str`
|
||||
The description of the sticker.
|
||||
format: :class:`StickerFormatType`
|
||||
The format for the sticker's image.
|
||||
available: :class:`bool`
|
||||
Whether this sticker is available for use.
|
||||
guild_id: :class:`int`
|
||||
The ID of the guild that this sticker is from.
|
||||
user: Optional[:class:`User`]
|
||||
The user that created this sticker. This can only be retrieved using :meth:`Guild.fetch_sticker` and
|
||||
having the :attr:`~Permissions.manage_emojis_and_stickers` permission.
|
||||
emoji: :class:`str`
|
||||
The name of a unicode emoji that represents this sticker.
|
||||
"""
|
||||
|
||||
__slots__ = ('available', 'guild_id', 'user', 'emoji', 'type', '_cs_guild')
|
||||
|
||||
def _from_data(self, data: GuildStickerPayload) -> None:
|
||||
super()._from_data(data)
|
||||
self.available: bool = data['available']
|
||||
self.guild_id: int = int(data['guild_id'])
|
||||
user = data.get('user')
|
||||
self.user: Optional[User] = self._state.store_user(user) if user else None
|
||||
self.emoji: str = data['tags']
|
||||
self.type: StickerType = StickerType.guild
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<GuildSticker name={self.name!r} id={self.id} guild_id={self.guild_id} user={self.user!r}>'
|
||||
|
||||
@cached_slot_property('_cs_guild')
|
||||
def guild(self) -> Optional[Guild]:
|
||||
"""Optional[:class:`Guild`]: The guild that this sticker is from.
|
||||
Could be ``None`` if the bot is not in the guild.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self._state._get_guild(self.guild_id)
|
||||
|
||||
async def edit(
|
||||
self,
|
||||
*,
|
||||
name: str = MISSING,
|
||||
description: str = MISSING,
|
||||
emoji: str = MISSING,
|
||||
reason: Optional[str] = None,
|
||||
) -> GuildSticker:
|
||||
"""|coro|
|
||||
|
||||
Edits a :class:`GuildSticker` for the guild.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
name: :class:`str`
|
||||
The sticker's new name. Must be at least 2 characters.
|
||||
description: Optional[:class:`str`]
|
||||
The sticker's new description. Can be ``None``.
|
||||
emoji: :class:`str`
|
||||
The name of a unicode emoji that represents the sticker's expression.
|
||||
reason: :class:`str`
|
||||
The reason for editing this sticker. Shows up on the audit log.
|
||||
|
||||
Raises
|
||||
-------
|
||||
Forbidden
|
||||
You are not allowed to edit stickers.
|
||||
HTTPException
|
||||
An error occurred editing the sticker.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`GuildSticker`
|
||||
The newly modified sticker.
|
||||
"""
|
||||
payload: EditGuildSticker = {}
|
||||
|
||||
if name is not MISSING:
|
||||
payload['name'] = name
|
||||
|
||||
if description is not MISSING:
|
||||
payload['description'] = description
|
||||
|
||||
if emoji is not MISSING:
|
||||
try:
|
||||
emoji = unicodedata.name(emoji)
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
emoji = emoji.replace(' ', '_')
|
||||
|
||||
payload['tags'] = emoji
|
||||
|
||||
data: GuildStickerPayload = await self._state.http.modify_guild_sticker(self.guild_id, self.id, payload, reason)
|
||||
return GuildSticker(state=self._state, data=data)
|
||||
|
||||
async def delete(self, *, reason: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
Deletes the custom :class:`Sticker` from the guild.
|
||||
|
||||
You must have :attr:`~Permissions.manage_emojis_and_stickers` permission to
|
||||
do this.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
reason: Optional[:class:`str`]
|
||||
The reason for deleting this sticker. Shows up on the audit log.
|
||||
|
||||
Raises
|
||||
-------
|
||||
Forbidden
|
||||
You are not allowed to delete stickers.
|
||||
HTTPException
|
||||
An error occurred deleting the sticker.
|
||||
"""
|
||||
await self._state.http.delete_guild_sticker(self.guild_id, self.id, reason)
|
||||
|
||||
|
||||
def _sticker_factory(sticker_type: Literal[1, 2]) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]:
|
||||
value = try_enum(StickerType, sticker_type)
|
||||
if value == StickerType.standard:
|
||||
return StandardSticker, value
|
||||
elif value == StickerType.guild:
|
||||
return GuildSticker, value
|
||||
else:
|
||||
return Sticker, value
|
||||
|
||||
@@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, TYPE_CHECKING, overload
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data, MISSING
|
||||
from .enums import VoiceRegion
|
||||
from .guild import Guild
|
||||
@@ -34,7 +34,10 @@ __all__ = (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import datetime
|
||||
from .types.template import Template as TemplatePayload
|
||||
from .state import ConnectionState
|
||||
from .user import User
|
||||
|
||||
|
||||
class _FriendlyHttpAttributeErrorHelper:
|
||||
@@ -77,7 +80,7 @@ class _PartialTemplateState:
|
||||
def _get_guild(self, id):
|
||||
return self.__state._get_guild(id)
|
||||
|
||||
async def query_members(self, **kwargs):
|
||||
async def query_members(self, **kwargs: Any):
|
||||
return []
|
||||
|
||||
def __getattr__(self, attr):
|
||||
@@ -127,33 +130,35 @@ class Template:
|
||||
'_state',
|
||||
)
|
||||
|
||||
def __init__(self, *, state, data: TemplatePayload):
|
||||
def __init__(self, *, state: ConnectionState, data: TemplatePayload) -> None:
|
||||
self._state = state
|
||||
self._store(data)
|
||||
|
||||
def _store(self, data: TemplatePayload):
|
||||
self.code = data['code']
|
||||
self.uses = data['usage_count']
|
||||
self.name = data['name']
|
||||
self.description = data['description']
|
||||
def _store(self, data: TemplatePayload) -> None:
|
||||
self.code: str = data['code']
|
||||
self.uses: int = data['usage_count']
|
||||
self.name: str = data['name']
|
||||
self.description: Optional[str] = data['description']
|
||||
creator_data = data.get('creator')
|
||||
self.creator = None if creator_data is None else self._state.store_user(creator_data)
|
||||
self.creator: Optional[User] = None if creator_data is None else self._state.create_user(creator_data)
|
||||
|
||||
self.created_at = parse_time(data.get('created_at'))
|
||||
self.updated_at = parse_time(data.get('updated_at'))
|
||||
self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at'))
|
||||
self.updated_at: Optional[datetime.datetime] = parse_time(data.get('updated_at'))
|
||||
|
||||
id = _get_as_snowflake(data, 'source_guild_id')
|
||||
guild_id = int(data['source_guild_id'])
|
||||
guild: Optional[Guild] = self._state._get_guild(guild_id)
|
||||
|
||||
guild = self._state._get_guild(id)
|
||||
|
||||
if guild is None and id:
|
||||
self.source_guild: Guild
|
||||
if guild is None:
|
||||
source_serialised = data['serialized_source_guild']
|
||||
source_serialised['id'] = id
|
||||
source_serialised['id'] = guild_id
|
||||
state = _PartialTemplateState(state=self._state)
|
||||
guild = Guild(data=source_serialised, state=state) # type: ignore
|
||||
# Guild expects a ConnectionState, we're passing a _PartialTemplateState
|
||||
self.source_guild = Guild(data=source_serialised, state=state) # type: ignore
|
||||
else:
|
||||
self.source_guild = guild
|
||||
|
||||
self.source_guild = guild
|
||||
self.is_dirty = data.get('is_dirty', None)
|
||||
self.is_dirty: Optional[bool] = data.get('is_dirty', None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
@@ -161,7 +166,7 @@ class Template:
|
||||
f' creator={self.creator!r} source_guild={self.source_guild!r} is_dirty={self.is_dirty}>'
|
||||
)
|
||||
|
||||
async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None):
|
||||
async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None) -> Guild:
|
||||
"""|coro|
|
||||
|
||||
Creates a :class:`.Guild` using the template.
|
||||
@@ -201,7 +206,7 @@ class Template:
|
||||
data = await self._state.http.create_from_template(self.code, name, region_value, icon)
|
||||
return Guild(data=data, state=self._state)
|
||||
|
||||
async def sync(self) -> None:
|
||||
async def sync(self) -> Template:
|
||||
"""|coro|
|
||||
|
||||
Sync the template to the guild's current state.
|
||||
@@ -211,6 +216,9 @@ class Template:
|
||||
|
||||
.. versionadded:: 1.7
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The template is no longer edited in-place, instead it is returned.
|
||||
|
||||
Raises
|
||||
-------
|
||||
HTTPException
|
||||
@@ -219,17 +227,22 @@ class Template:
|
||||
You don't have permissions to edit the template.
|
||||
NotFound
|
||||
This template does not exist.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Template`
|
||||
The newly edited template.
|
||||
"""
|
||||
|
||||
data = await self._state.http.sync_template(self.source_guild.id, self.code)
|
||||
self._store(data)
|
||||
return Template(state=self._state, data=data)
|
||||
|
||||
async def edit(
|
||||
self,
|
||||
*,
|
||||
name: str = MISSING,
|
||||
description: Optional[str] = MISSING,
|
||||
) -> None:
|
||||
) -> Template:
|
||||
"""|coro|
|
||||
|
||||
Edit the template metadata.
|
||||
@@ -239,6 +252,9 @@ class Template:
|
||||
|
||||
.. versionadded:: 1.7
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The template is no longer edited in-place, instead it is returned.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
name: :class:`str`
|
||||
@@ -254,6 +270,11 @@ class Template:
|
||||
You don't have permissions to edit the template.
|
||||
NotFound
|
||||
This template does not exist.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Template`
|
||||
The newly edited template.
|
||||
"""
|
||||
payload = {}
|
||||
|
||||
@@ -263,7 +284,7 @@ class Template:
|
||||
payload['description'] = description
|
||||
|
||||
data = await self._state.http.edit_template(self.source_guild.id, self.code, payload)
|
||||
self._store(data)
|
||||
return Template(state=self._state, data=data)
|
||||
|
||||
async def delete(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
@@ -46,8 +46,9 @@ if TYPE_CHECKING:
|
||||
ThreadMetadata,
|
||||
ThreadArchiveDuration,
|
||||
)
|
||||
from .types.snowflake import SnowflakeList
|
||||
from .guild import Guild
|
||||
from .channel import TextChannel
|
||||
from .channel import TextChannel, CategoryChannel
|
||||
from .member import Member
|
||||
from .message import Message, PartialMessage
|
||||
from .abc import Snowflake, SnowflakeTime
|
||||
@@ -73,6 +74,10 @@ class Thread(Messageable, Hashable):
|
||||
|
||||
Returns the thread's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the thread's ID.
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the thread's name.
|
||||
@@ -110,6 +115,9 @@ class Thread(Messageable, Hashable):
|
||||
Whether the thread is archived.
|
||||
locked: :class:`bool`
|
||||
Whether the thread is locked.
|
||||
invitable: :class:`bool`
|
||||
Whether non-moderators can add other non-moderators to this thread.
|
||||
This is always ``True`` for public threads.
|
||||
archiver_id: Optional[:class:`int`]
|
||||
The user's ID that archived this thread.
|
||||
auto_archive_duration: :class:`int`
|
||||
@@ -135,6 +143,7 @@ class Thread(Messageable, Hashable):
|
||||
'me',
|
||||
'locked',
|
||||
'archived',
|
||||
'invitable',
|
||||
'archiver_id',
|
||||
'auto_archive_duration',
|
||||
'archive_timestamp',
|
||||
@@ -166,6 +175,8 @@ class Thread(Messageable, Hashable):
|
||||
self._type = try_enum(ChannelType, data['type'])
|
||||
self.last_message_id = _get_as_snowflake(data, 'last_message_id')
|
||||
self.slowmode_delay = data.get('rate_limit_per_user', 0)
|
||||
self.message_count = data['message_count']
|
||||
self.member_count = data['member_count']
|
||||
self._unroll_metadata(data['thread_metadata'])
|
||||
|
||||
try:
|
||||
@@ -181,6 +192,7 @@ class Thread(Messageable, Hashable):
|
||||
self.auto_archive_duration = data['auto_archive_duration']
|
||||
self.archive_timestamp = parse_time(data['archive_timestamp'])
|
||||
self.locked = data.get('locked', False)
|
||||
self.invitable = data.get('invitable', True)
|
||||
|
||||
def _update(self, data):
|
||||
try:
|
||||
@@ -215,6 +227,16 @@ class Thread(Messageable, Hashable):
|
||||
""":class:`str`: The string that allows you to mention the thread."""
|
||||
return f'<#{self.id}>'
|
||||
|
||||
@property
|
||||
def members(self) -> List[ThreadMember]:
|
||||
"""List[:class:`ThreadMember`]: A list of thread members in this thread.
|
||||
|
||||
This requires :attr:`Intents.members` to be properly filled. Most of the time however,
|
||||
this data is not provided by the gateway and a call to :meth:`fetch_members` is
|
||||
needed.
|
||||
"""
|
||||
return list(self._members.values())
|
||||
|
||||
@property
|
||||
def last_message(self) -> Optional[Message]:
|
||||
"""Fetches the last message from this channel in cache.
|
||||
@@ -236,6 +258,26 @@ class Thread(Messageable, Hashable):
|
||||
"""
|
||||
return self._state._get_message(self.last_message_id) if self.last_message_id else None
|
||||
|
||||
@property
|
||||
def category(self) -> Optional[CategoryChannel]:
|
||||
"""The category channel the parent channel belongs to, if applicable.
|
||||
|
||||
Raises
|
||||
-------
|
||||
ClientException
|
||||
The parent channel was not cached and returned ``None``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[:class:`CategoryChannel`]
|
||||
The parent channel's category.
|
||||
"""
|
||||
|
||||
parent = self.parent
|
||||
if parent is None:
|
||||
raise ClientException('Parent channel not found')
|
||||
return parent.category
|
||||
|
||||
@property
|
||||
def category_id(self) -> Optional[int]:
|
||||
"""The category channel ID the parent channel belongs to, if applicable.
|
||||
@@ -362,7 +404,7 @@ class Thread(Messageable, Hashable):
|
||||
if len(messages) > 100:
|
||||
raise ClientException('Can only bulk delete messages up to 100 messages')
|
||||
|
||||
message_ids = [m.id for m in messages]
|
||||
message_ids: SnowflakeList = [m.id for m in messages]
|
||||
await self._state.http.delete_messages(self.id, message_ids)
|
||||
|
||||
async def purge(
|
||||
@@ -488,9 +530,10 @@ class Thread(Messageable, Hashable):
|
||||
name: str = MISSING,
|
||||
archived: bool = MISSING,
|
||||
locked: bool = MISSING,
|
||||
invitable: bool = MISSING,
|
||||
slowmode_delay: int = MISSING,
|
||||
auto_archive_duration: ThreadArchiveDuration = MISSING,
|
||||
):
|
||||
) -> Thread:
|
||||
"""|coro|
|
||||
|
||||
Edits the thread.
|
||||
@@ -510,8 +553,11 @@ class Thread(Messageable, Hashable):
|
||||
Whether to archive the thread or not.
|
||||
locked: :class:`bool`
|
||||
Whether to lock the thread or not.
|
||||
invitable: :class:`bool`
|
||||
Whether non-moderators can add other non-moderators to this thread.
|
||||
Only available for private threads.
|
||||
auto_archive_duration: :class:`int`
|
||||
The new duration to auto archive threads for inactivity.
|
||||
The new duration in minutes before a thread is automatically archived for inactivity.
|
||||
Must be one of ``60``, ``1440``, ``4320``, or ``10080``.
|
||||
slowmode_delay: :class:`int`
|
||||
Specifies the slowmode rate limit for user in this thread, in seconds.
|
||||
@@ -523,6 +569,11 @@ class Thread(Messageable, Hashable):
|
||||
You do not have permissions to edit the thread.
|
||||
HTTPException
|
||||
Editing the thread failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Thread`
|
||||
The newly edited thread.
|
||||
"""
|
||||
payload = {}
|
||||
if name is not MISSING:
|
||||
@@ -533,20 +584,22 @@ class Thread(Messageable, Hashable):
|
||||
payload['auto_archive_duration'] = auto_archive_duration
|
||||
if locked is not MISSING:
|
||||
payload['locked'] = locked
|
||||
if invitable is not MISSING:
|
||||
payload['invitable'] = invitable
|
||||
if slowmode_delay is not MISSING:
|
||||
payload['rate_limit_per_user'] = slowmode_delay
|
||||
|
||||
await self._state.http.edit_channel(self.id, **payload)
|
||||
data = await self._state.http.edit_channel(self.id, **payload)
|
||||
# The data payload will always be a Thread payload
|
||||
return Thread(data=data, state=self._state, guild=self.guild) # type: ignore
|
||||
|
||||
async def join(self):
|
||||
"""|coro|
|
||||
|
||||
Joins this thread.
|
||||
|
||||
You must have :attr:`~Permissions.send_messages` and :attr:`~Permissions.use_threads`
|
||||
to join a public thread. If the thread is private then :attr:`~Permissions.send_messages`
|
||||
and either :attr:`~Permissions.use_private_threads` or :attr:`~Permissions.manage_messages`
|
||||
is required to join the thread.
|
||||
You must have :attr:`~Permissions.send_messages_in_threads` to join a thread.
|
||||
If the thread is private, :attr:`~Permissions.manage_threads` is also needed.
|
||||
|
||||
Raises
|
||||
-------
|
||||
@@ -614,6 +667,28 @@ class Thread(Messageable, Hashable):
|
||||
"""
|
||||
await self._state.http.remove_user_from_thread(self.id, user.id)
|
||||
|
||||
async def fetch_members(self) -> List[ThreadMember]:
|
||||
"""|coro|
|
||||
|
||||
Retrieves all :class:`ThreadMember` that are in this thread.
|
||||
|
||||
This requires :attr:`Intents.members` to get information about members
|
||||
other than yourself.
|
||||
|
||||
Raises
|
||||
-------
|
||||
HTTPException
|
||||
Retrieving the members failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
List[:class:`ThreadMember`]
|
||||
All thread members in the thread.
|
||||
"""
|
||||
|
||||
members = await self._state.http.get_thread_members(self.id)
|
||||
return [ThreadMember(parent=self, data=data) for data in members]
|
||||
|
||||
async def delete(self):
|
||||
"""|coro|
|
||||
|
||||
@@ -677,6 +752,10 @@ class ThreadMember(Hashable):
|
||||
|
||||
Returns the thread member's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the thread member's ID.
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the thread member's name.
|
||||
|
||||
@@ -61,6 +61,7 @@ class _PartialAppInfoOptional(TypedDict, total=False):
|
||||
terms_of_service_url: str
|
||||
privacy_policy_url: str
|
||||
max_participants: int
|
||||
flags: int
|
||||
|
||||
class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo):
|
||||
pass
|
||||
|
||||
@@ -32,6 +32,7 @@ from .user import User
|
||||
from .snowflake import Snowflake
|
||||
from .role import Role
|
||||
from .channel import ChannelType, VideoQualityMode, PermissionOverwrite
|
||||
from .threads import Thread
|
||||
|
||||
AuditLogEvent = Literal[
|
||||
1,
|
||||
@@ -69,19 +70,28 @@ AuditLogEvent = Literal[
|
||||
80,
|
||||
81,
|
||||
82,
|
||||
83,
|
||||
84,
|
||||
85,
|
||||
90,
|
||||
91,
|
||||
92,
|
||||
110,
|
||||
111,
|
||||
112,
|
||||
]
|
||||
|
||||
|
||||
class _AuditLogChange_Str(TypedDict):
|
||||
key: Literal[
|
||||
'name', 'description', 'preferred_locale', 'vanity_url_code', 'topic', 'code', 'allow', 'deny', 'permissions'
|
||||
'name', 'description', 'preferred_locale', 'vanity_url_code', 'topic', 'code', 'allow', 'deny', 'permissions', 'tags'
|
||||
]
|
||||
new_value: str
|
||||
old_value: str
|
||||
|
||||
|
||||
class _AuditLogChange_AssetHash(TypedDict):
|
||||
key: Literal['icon_hash', 'splash_hash', 'discovery_splash_hash', 'banner_hash', 'avatar_hash']
|
||||
key: Literal['icon_hash', 'splash_hash', 'discovery_splash_hash', 'banner_hash', 'avatar_hash', 'asset']
|
||||
new_value: str
|
||||
old_value: str
|
||||
|
||||
@@ -98,6 +108,7 @@ class _AuditLogChange_Snowflake(TypedDict):
|
||||
'application_id',
|
||||
'channel_id',
|
||||
'inviter_id',
|
||||
'guild_id',
|
||||
]
|
||||
new_value: Snowflake
|
||||
old_value: Snowflake
|
||||
@@ -116,6 +127,9 @@ class _AuditLogChange_Bool(TypedDict):
|
||||
'enabled_emoticons',
|
||||
'region',
|
||||
'rtc_region',
|
||||
'available',
|
||||
'archived',
|
||||
'locked',
|
||||
]
|
||||
new_value: bool
|
||||
old_value: bool
|
||||
@@ -132,6 +146,8 @@ class _AuditLogChange_Int(TypedDict):
|
||||
'max_uses',
|
||||
'max_age',
|
||||
'user_limit',
|
||||
'auto_archive_duration',
|
||||
'default_auto_archive_duration',
|
||||
]
|
||||
new_value: int
|
||||
old_value: int
|
||||
@@ -238,3 +254,4 @@ class AuditLog(TypedDict):
|
||||
users: List[User]
|
||||
audit_log_entries: List[AuditLogEntry]
|
||||
integrations: List[PartialIntegration]
|
||||
threads: List[Thread]
|
||||
|
||||
@@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
|
||||
from typing import List, Literal, Optional, TypedDict, Union
|
||||
from .user import PartialUser
|
||||
from .snowflake import Snowflake
|
||||
from .threads import ThreadMetadata, ThreadMember
|
||||
from .threads import ThreadMetadata, ThreadMember, ThreadArchiveDuration
|
||||
|
||||
|
||||
OverwriteType = Literal[0, 1]
|
||||
@@ -63,6 +63,7 @@ class _TextChannelOptional(TypedDict, total=False):
|
||||
last_message_id: Optional[Snowflake]
|
||||
last_pin_timestamp: str
|
||||
rate_limit_per_user: int
|
||||
default_auto_archive_duration: ThreadArchiveDuration
|
||||
|
||||
|
||||
class TextChannel(_BaseGuildChannel, _TextChannelOptional):
|
||||
@@ -115,7 +116,7 @@ class _ThreadChannelOptional(TypedDict, total=False):
|
||||
|
||||
|
||||
class ThreadChannel(_BaseChannel, _ThreadChannelOptional):
|
||||
type: Literal[11, 12]
|
||||
type: Literal[10, 11, 12]
|
||||
guild_id: Snowflake
|
||||
parent_id: Snowflake
|
||||
owner_id: Snowflake
|
||||
|
||||
@@ -53,6 +53,7 @@ class _SelectMenuOptional(TypedDict, total=False):
|
||||
placeholder: str
|
||||
min_values: int
|
||||
max_values: int
|
||||
disabled: bool
|
||||
|
||||
|
||||
class _SelectOptionsOptional(TypedDict, total=False):
|
||||
|
||||
@@ -37,8 +37,11 @@ if TYPE_CHECKING:
|
||||
from .message import AllowedMentions, Message
|
||||
|
||||
|
||||
ApplicationCommandType = Literal[1, 2, 3]
|
||||
|
||||
class _ApplicationCommandOptional(TypedDict, total=False):
|
||||
options: List[ApplicationCommandOption]
|
||||
type: ApplicationCommandType
|
||||
|
||||
|
||||
class ApplicationCommand(_ApplicationCommandOptional):
|
||||
@@ -53,7 +56,7 @@ class _ApplicationCommandOptionOptional(TypedDict, total=False):
|
||||
options: List[ApplicationCommandOption]
|
||||
|
||||
|
||||
ApplicationCommandOptionType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
ApplicationCommandOptionType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
|
||||
|
||||
class ApplicationCommandOption(_ApplicationCommandOptionOptional):
|
||||
@@ -122,12 +125,18 @@ class _ApplicationCommandInteractionDataOptionSnowflake(_ApplicationCommandInter
|
||||
value: Snowflake
|
||||
|
||||
|
||||
class _ApplicationCommandInteractionDataOptionNumber(_ApplicationCommandInteractionDataOption):
|
||||
type: Literal[10]
|
||||
value: float
|
||||
|
||||
|
||||
ApplicationCommandInteractionDataOption = Union[
|
||||
_ApplicationCommandInteractionDataOptionString,
|
||||
_ApplicationCommandInteractionDataOptionInteger,
|
||||
_ApplicationCommandInteractionDataOptionSubcommand,
|
||||
_ApplicationCommandInteractionDataOptionBoolean,
|
||||
_ApplicationCommandInteractionDataOptionSnowflake,
|
||||
_ApplicationCommandInteractionDataOptionNumber,
|
||||
]
|
||||
|
||||
|
||||
@@ -148,6 +157,8 @@ class ApplicationCommandInteractionDataResolved(TypedDict, total=False):
|
||||
class _ApplicationCommandInteractionDataOptional(TypedDict, total=False):
|
||||
options: List[ApplicationCommandInteractionDataOption]
|
||||
resolved: ApplicationCommandInteractionDataResolved
|
||||
target_id: Snowflake
|
||||
type: ApplicationCommandType
|
||||
|
||||
|
||||
class ApplicationCommandInteractionData(_ApplicationCommandInteractionDataOptional):
|
||||
@@ -211,8 +222,15 @@ class MessageInteraction(TypedDict):
|
||||
user: User
|
||||
|
||||
|
||||
class EditApplicationCommand(TypedDict):
|
||||
name: str
|
||||
|
||||
|
||||
|
||||
class _EditApplicationCommandOptional(TypedDict, total=False):
|
||||
description: str
|
||||
options: Optional[List[ApplicationCommandOption]]
|
||||
type: ApplicationCommandType
|
||||
|
||||
|
||||
class EditApplicationCommand(_EditApplicationCommandOptional):
|
||||
name: str
|
||||
default_permission: bool
|
||||
|
||||
@@ -39,6 +39,7 @@ class PartialMember(TypedDict):
|
||||
|
||||
|
||||
class Member(PartialMember, total=False):
|
||||
avatar: str
|
||||
user: User
|
||||
nick: str
|
||||
premium_since: str
|
||||
@@ -46,16 +47,17 @@ class Member(PartialMember, total=False):
|
||||
permissions: str
|
||||
|
||||
|
||||
class _OptionalGatewayMember(PartialMember, total=False):
|
||||
class _OptionalMemberWithUser(PartialMember, total=False):
|
||||
avatar: str
|
||||
nick: str
|
||||
premium_since: str
|
||||
pending: bool
|
||||
permissions: str
|
||||
|
||||
|
||||
class GatewayMember(_OptionalGatewayMember):
|
||||
class MemberWithUser(_OptionalMemberWithUser):
|
||||
user: User
|
||||
|
||||
|
||||
class UserWithMember(User, total=False):
|
||||
member: _OptionalGatewayMember
|
||||
member: _OptionalMemberWithUser
|
||||
|
||||
@@ -33,6 +33,7 @@ from .embed import Embed
|
||||
from .channel import ChannelType
|
||||
from .components import Component
|
||||
from .interactions import MessageInteraction
|
||||
from .sticker import StickerItem
|
||||
|
||||
|
||||
class ChannelMention(TypedDict):
|
||||
@@ -89,22 +90,6 @@ class MessageReference(TypedDict, total=False):
|
||||
fail_if_not_exists: bool
|
||||
|
||||
|
||||
class _StickerOptional(TypedDict, total=False):
|
||||
tags: str
|
||||
|
||||
|
||||
StickerFormatType = Literal[1, 2, 3]
|
||||
|
||||
|
||||
class Sticker(_StickerOptional):
|
||||
id: Snowflake
|
||||
pack_id: Snowflake
|
||||
name: str
|
||||
description: str
|
||||
asset: str
|
||||
format_type: StickerFormatType
|
||||
|
||||
|
||||
class _MessageOptional(TypedDict, total=False):
|
||||
guild_id: Snowflake
|
||||
member: Member
|
||||
@@ -117,7 +102,7 @@ class _MessageOptional(TypedDict, total=False):
|
||||
application_id: Snowflake
|
||||
message_reference: MessageReference
|
||||
flags: int
|
||||
stickers: List[Sticker]
|
||||
sticker_items: List[StickerItem]
|
||||
referenced_message: Optional[Message]
|
||||
interaction: MessageInteraction
|
||||
components: List[Component]
|
||||
|
||||
87
discord/types/raw_models.py
Normal file
87
discord/types/raw_models.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-present Rapptz
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from typing import TypedDict, List
|
||||
from .snowflake import Snowflake
|
||||
from .member import Member
|
||||
from .emoji import PartialEmoji
|
||||
|
||||
|
||||
class _MessageEventOptional(TypedDict, total=False):
|
||||
guild_id: Snowflake
|
||||
|
||||
|
||||
class MessageDeleteEvent(_MessageEventOptional):
|
||||
id: Snowflake
|
||||
channel_id: Snowflake
|
||||
|
||||
|
||||
class BulkMessageDeleteEvent(_MessageEventOptional):
|
||||
ids: List[Snowflake]
|
||||
channel_id: Snowflake
|
||||
|
||||
|
||||
class _ReactionActionEventOptional(TypedDict, total=False):
|
||||
guild_id: Snowflake
|
||||
member: Member
|
||||
|
||||
|
||||
class MessageUpdateEvent(_MessageEventOptional):
|
||||
id: Snowflake
|
||||
channel_id: Snowflake
|
||||
|
||||
|
||||
class ReactionActionEvent(_ReactionActionEventOptional):
|
||||
user_id: Snowflake
|
||||
channel_id: Snowflake
|
||||
message_id: Snowflake
|
||||
emoji: PartialEmoji
|
||||
|
||||
|
||||
class _ReactionClearEventOptional(TypedDict, total=False):
|
||||
guild_id: Snowflake
|
||||
|
||||
|
||||
class ReactionClearEvent(_ReactionClearEventOptional):
|
||||
channel_id: Snowflake
|
||||
message_id: Snowflake
|
||||
|
||||
|
||||
class _ReactionClearEmojiEventOptional(TypedDict, total=False):
|
||||
guild_id: Snowflake
|
||||
|
||||
|
||||
class ReactionClearEmojiEvent(_ReactionClearEmojiEventOptional):
|
||||
channel_id: int
|
||||
message_id: int
|
||||
emoji: PartialEmoji
|
||||
|
||||
|
||||
class _IntegrationDeleteEventOptional(TypedDict, total=False):
|
||||
application_id: Snowflake
|
||||
|
||||
|
||||
class IntegrationDeleteEvent(_IntegrationDeleteEventOptional):
|
||||
id: Snowflake
|
||||
guild_id: Snowflake
|
||||
93
discord/types/sticker.py
Normal file
93
discord/types/sticker.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-present Rapptz
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Literal, TypedDict, Union
|
||||
from .snowflake import Snowflake
|
||||
from .user import User
|
||||
|
||||
StickerFormatType = Literal[1, 2, 3]
|
||||
|
||||
|
||||
class StickerItem(TypedDict):
|
||||
id: Snowflake
|
||||
name: str
|
||||
format_type: StickerFormatType
|
||||
|
||||
|
||||
class BaseSticker(TypedDict):
|
||||
id: Snowflake
|
||||
name: str
|
||||
description: str
|
||||
tags: str
|
||||
format_type: StickerFormatType
|
||||
|
||||
|
||||
class StandardSticker(BaseSticker):
|
||||
type: Literal[1]
|
||||
sort_value: int
|
||||
pack_id: Snowflake
|
||||
|
||||
|
||||
class _GuildStickerOptional(TypedDict, total=False):
|
||||
user: User
|
||||
|
||||
|
||||
class GuildSticker(BaseSticker, _GuildStickerOptional):
|
||||
type: Literal[2]
|
||||
available: bool
|
||||
guild_id: Snowflake
|
||||
|
||||
|
||||
Sticker = Union[BaseSticker, StandardSticker, GuildSticker]
|
||||
|
||||
|
||||
class StickerPack(TypedDict):
|
||||
id: Snowflake
|
||||
stickers: List[StandardSticker]
|
||||
name: str
|
||||
sku_id: Snowflake
|
||||
cover_sticker_id: Snowflake
|
||||
description: str
|
||||
banner_asset_id: Snowflake
|
||||
|
||||
|
||||
class _CreateGuildStickerOptional(TypedDict, total=False):
|
||||
description: str
|
||||
|
||||
|
||||
class CreateGuildSticker(_CreateGuildStickerOptional):
|
||||
name: str
|
||||
tags: str
|
||||
|
||||
|
||||
class EditGuildSticker(TypedDict, total=False):
|
||||
name: str
|
||||
tags: str
|
||||
description: str
|
||||
|
||||
|
||||
class ListPremiumStickerPacks(TypedDict):
|
||||
sticker_packs: List[StickerPack]
|
||||
@@ -41,6 +41,7 @@ class ThreadMember(TypedDict):
|
||||
class _ThreadMetadataOptional(TypedDict, total=False):
|
||||
archiver_id: Snowflake
|
||||
locked: bool
|
||||
invitable: bool
|
||||
|
||||
|
||||
class ThreadMetadata(_ThreadMetadataOptional):
|
||||
|
||||
@@ -24,14 +24,14 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from typing import Optional, TypedDict, List, Literal
|
||||
from .snowflake import Snowflake
|
||||
from .member import Member
|
||||
from .member import MemberWithUser
|
||||
|
||||
|
||||
SupportedModes = Literal['xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305']
|
||||
|
||||
|
||||
class _PartialVoiceStateOptional(TypedDict, total=False):
|
||||
member: Member
|
||||
member: MemberWithUser
|
||||
self_stream: bool
|
||||
|
||||
|
||||
|
||||
@@ -78,6 +78,8 @@ class Select(Item[V]):
|
||||
Defaults to 1 and must be between 1 and 25.
|
||||
options: List[:class:`discord.SelectOption`]
|
||||
A list of options that can be selected in this menu.
|
||||
disabled: :class:`bool`
|
||||
Whether the select is disabled or not.
|
||||
row: Optional[:class:`int`]
|
||||
The relative row this select menu belongs to. A Discord component can only have 5
|
||||
rows. By default, items are arranged automatically into those 5 rows. If you'd
|
||||
@@ -91,6 +93,7 @@ class Select(Item[V]):
|
||||
'min_values',
|
||||
'max_values',
|
||||
'options',
|
||||
'disabled',
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@@ -101,8 +104,10 @@ class Select(Item[V]):
|
||||
min_values: int = 1,
|
||||
max_values: int = 1,
|
||||
options: List[SelectOption] = MISSING,
|
||||
disabled: bool = False,
|
||||
row: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._selected_values: List[str] = []
|
||||
self._provided_custom_id = custom_id is not MISSING
|
||||
custom_id = os.urandom(16).hex() if custom_id is MISSING else custom_id
|
||||
@@ -114,6 +119,7 @@ class Select(Item[V]):
|
||||
min_values=min_values,
|
||||
max_values=max_values,
|
||||
options=options,
|
||||
disabled=disabled,
|
||||
)
|
||||
self.row = row
|
||||
|
||||
@@ -191,13 +197,13 @@ class Select(Item[V]):
|
||||
-----------
|
||||
label: :class:`str`
|
||||
The label of the option. This is displayed to users.
|
||||
Can only be up to 25 characters.
|
||||
Can only be up to 100 characters.
|
||||
value: :class:`str`
|
||||
The value of the option. This is not displayed to users.
|
||||
If not given, defaults to the label. Can only be up to 100 characters.
|
||||
description: Optional[:class:`str`]
|
||||
An additional description of the option, if any.
|
||||
Can only be up to 50 characters.
|
||||
Can only be up to 100 characters.
|
||||
emoji: Optional[Union[:class:`str`, :class:`.Emoji`, :class:`.PartialEmoji`]]
|
||||
The emoji of the option, if available. This can either be a string representing
|
||||
the custom or unicode emoji or an instance of :class:`.PartialEmoji` or :class:`.Emoji`.
|
||||
@@ -240,6 +246,15 @@ class Select(Item[V]):
|
||||
|
||||
self._underlying.options.append(option)
|
||||
|
||||
@property
|
||||
def disabled(self) -> bool:
|
||||
""":class:`bool`: Whether the select is disabled or not."""
|
||||
return self._underlying.disabled
|
||||
|
||||
@disabled.setter
|
||||
def disabled(self, value: bool):
|
||||
self._underlying.disabled = bool(value)
|
||||
|
||||
@property
|
||||
def values(self) -> List[str]:
|
||||
"""List[:class:`str`]: A list of values that have been selected by the user."""
|
||||
@@ -267,6 +282,7 @@ class Select(Item[V]):
|
||||
min_values=component.min_values,
|
||||
max_values=component.max_values,
|
||||
options=component.options,
|
||||
disabled=component.disabled,
|
||||
row=None,
|
||||
)
|
||||
|
||||
@@ -285,6 +301,7 @@ def select(
|
||||
min_values: int = 1,
|
||||
max_values: int = 1,
|
||||
options: List[SelectOption] = MISSING,
|
||||
disabled: bool = False,
|
||||
row: Optional[int] = None,
|
||||
) -> Callable[[ItemCallbackType], ItemCallbackType]:
|
||||
"""A decorator that attaches a select menu to a component.
|
||||
@@ -317,11 +334,13 @@ def select(
|
||||
Defaults to 1 and must be between 1 and 25.
|
||||
options: List[:class:`discord.SelectOption`]
|
||||
A list of options that can be selected in this menu.
|
||||
disabled: :class:`bool`
|
||||
Whether the select is disabled or not. Defaults to ``False``.
|
||||
"""
|
||||
|
||||
def decorator(func: ItemCallbackType) -> ItemCallbackType:
|
||||
if not inspect.iscoroutinefunction(func):
|
||||
raise TypeError('button function must be a coroutine function')
|
||||
raise TypeError('select function must be a coroutine function')
|
||||
|
||||
func.__discord_ui_model_type__ = Select
|
||||
func.__discord_ui_model_kwargs__ = {
|
||||
@@ -331,6 +350,7 @@ def select(
|
||||
'min_values': min_values,
|
||||
'max_values': max_values,
|
||||
'options': options,
|
||||
'disabled': disabled,
|
||||
}
|
||||
return func
|
||||
|
||||
|
||||
@@ -456,8 +456,8 @@ class View:
|
||||
|
||||
class ViewStore:
|
||||
def __init__(self, state: ConnectionState):
|
||||
# (component_type, custom_id): (View, Item)
|
||||
self._views: Dict[Tuple[int, str], Tuple[View, Item]] = {}
|
||||
# (component_type, message_id, custom_id): (View, Item)
|
||||
self._views: Dict[Tuple[int, Optional[int], str], Tuple[View, Item]] = {}
|
||||
# message_id: View
|
||||
self._synced_message_views: Dict[int, View] = {}
|
||||
self._state: ConnectionState = state
|
||||
@@ -474,8 +474,7 @@ class ViewStore:
|
||||
return list(views.values())
|
||||
|
||||
def __verify_integrity(self):
|
||||
to_remove: List[Tuple[int, str]] = []
|
||||
now = time.monotonic()
|
||||
to_remove: List[Tuple[int, Optional[int], str]] = []
|
||||
for (k, (view, _)) in self._views.items():
|
||||
if view.is_finished():
|
||||
to_remove.append(k)
|
||||
@@ -489,7 +488,7 @@ class ViewStore:
|
||||
view._start_listening_from_store(self)
|
||||
for item in view.children:
|
||||
if item.is_dispatchable():
|
||||
self._views[(item.type.value, item.custom_id)] = (view, item) # type: ignore
|
||||
self._views[(item.type.value, message_id, item.custom_id)] = (view, item) # type: ignore
|
||||
|
||||
if message_id is not None:
|
||||
self._synced_message_views[message_id] = view
|
||||
@@ -506,8 +505,11 @@ class ViewStore:
|
||||
|
||||
def dispatch(self, component_type: int, custom_id: str, interaction: Interaction):
|
||||
self.__verify_integrity()
|
||||
key = (component_type, custom_id)
|
||||
value = self._views.get(key)
|
||||
message_id: Optional[int] = interaction.message and interaction.message.id
|
||||
key = (component_type, message_id, custom_id)
|
||||
# Fallback to None message_id searches in case a persistent view
|
||||
# was added without an associated message_id
|
||||
value = self._views.get(key) or self._views.get((component_type, None, custom_id))
|
||||
if value is None:
|
||||
return
|
||||
|
||||
|
||||
200
discord/user.py
200
discord/user.py
@@ -22,19 +22,35 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, TYPE_CHECKING
|
||||
|
||||
import discord.abc
|
||||
from .asset import Asset
|
||||
from .colour import Colour
|
||||
from .enums import DefaultAvatar
|
||||
from .flags import PublicUserFlags
|
||||
from .utils import snowflake_time, _bytes_to_base64_data, MISSING
|
||||
from .enums import DefaultAvatar
|
||||
from .colour import Colour
|
||||
from .asset import Asset
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import datetime
|
||||
|
||||
from .channel import DMChannel
|
||||
from .guild import Guild
|
||||
from .message import Message
|
||||
from .state import ConnectionState
|
||||
from .types.channel import DMChannel as DMChannelPayload
|
||||
from .types.user import User as UserPayload
|
||||
|
||||
|
||||
__all__ = (
|
||||
'User',
|
||||
'ClientUser',
|
||||
)
|
||||
|
||||
BU = TypeVar('BU', bound='BaseUser')
|
||||
|
||||
|
||||
class _UserTag:
|
||||
__slots__ = ()
|
||||
@@ -42,7 +58,18 @@ class _UserTag:
|
||||
|
||||
|
||||
class BaseUser(_UserTag):
|
||||
__slots__ = ('name', 'id', 'discriminator', '_avatar', 'bot', 'system', '_public_flags', '_state')
|
||||
__slots__ = (
|
||||
'name',
|
||||
'id',
|
||||
'discriminator',
|
||||
'_avatar',
|
||||
'_banner',
|
||||
'_accent_colour',
|
||||
'bot',
|
||||
'system',
|
||||
'_public_flags',
|
||||
'_state',
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
name: str
|
||||
@@ -50,53 +77,65 @@ class BaseUser(_UserTag):
|
||||
discriminator: str
|
||||
bot: bool
|
||||
system: bool
|
||||
_state: ConnectionState
|
||||
_avatar: Optional[str]
|
||||
_banner: Optional[str]
|
||||
_accent_colour: Optional[str]
|
||||
_public_flags: int
|
||||
|
||||
def __init__(self, *, state, data):
|
||||
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
|
||||
self._state = state
|
||||
self._update(data)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<BaseUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}"
|
||||
f" bot={self.bot} system={self.system}>"
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f'{self.name}#{self.discriminator}'
|
||||
|
||||
def __eq__(self, other):
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, _UserTag) and other.id == self.id
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return self.id >> 22
|
||||
|
||||
def _update(self, data):
|
||||
def _update(self, data: UserPayload) -> None:
|
||||
self.name = data['username']
|
||||
self.id = int(data['id'])
|
||||
self.discriminator = data['discriminator']
|
||||
self._avatar = data['avatar']
|
||||
self._banner = data.get('banner', None)
|
||||
self._accent_colour = data.get('accent_color', None)
|
||||
self._public_flags = data.get('public_flags', 0)
|
||||
self.bot = data.get('bot', False)
|
||||
self.system = data.get('system', False)
|
||||
|
||||
@classmethod
|
||||
def _copy(cls, user):
|
||||
def _copy(cls: Type[BU], user: BU) -> BU:
|
||||
self = cls.__new__(cls) # bypass __init__
|
||||
|
||||
self.name = user.name
|
||||
self.id = user.id
|
||||
self.discriminator = user.discriminator
|
||||
self._avatar = user._avatar
|
||||
self._banner = user._banner
|
||||
self._accent_colour = user._accent_colour
|
||||
self.bot = user.bot
|
||||
self._state = user._state
|
||||
self._public_flags = user._public_flags
|
||||
|
||||
return self
|
||||
|
||||
def _to_minimal_user_json(self):
|
||||
def _to_minimal_user_json(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'username': self.name,
|
||||
'id': self.id,
|
||||
@@ -106,29 +145,82 @@ class BaseUser(_UserTag):
|
||||
}
|
||||
|
||||
@property
|
||||
def public_flags(self):
|
||||
def public_flags(self) -> PublicUserFlags:
|
||||
""":class:`PublicUserFlags`: The publicly available flags the user has."""
|
||||
return PublicUserFlags._from_value(self._public_flags)
|
||||
|
||||
@property
|
||||
def avatar(self):
|
||||
""":class:`Asset`: Returns an :class:`Asset` for the avatar the user has.
|
||||
def avatar(self) -> Optional[Asset]:
|
||||
"""Optional[:class:`Asset`]: Returns an :class:`Asset` for the avatar the user has.
|
||||
|
||||
If the user does not have a traditional avatar, an asset for
|
||||
the default avatar is returned instead.
|
||||
If the user does not have a traditional avatar, ``None`` is returned.
|
||||
If you want the avatar that a user has displayed, consider :attr:`display_avatar`.
|
||||
"""
|
||||
if self._avatar is None:
|
||||
return Asset._from_default_avatar(self._state, int(self.discriminator) % len(DefaultAvatar))
|
||||
else:
|
||||
if self._avatar is not None:
|
||||
return Asset._from_avatar(self._state, self.id, self._avatar)
|
||||
return None
|
||||
|
||||
@property
|
||||
def default_avatar(self):
|
||||
def default_avatar(self) -> Asset:
|
||||
""":class:`Asset`: Returns the default avatar for a given user. This is calculated by the user's discriminator."""
|
||||
return Asset._from_default_avatar(self._state, int(self.discriminator) % len(DefaultAvatar))
|
||||
|
||||
@property
|
||||
def colour(self):
|
||||
def display_avatar(self) -> Asset:
|
||||
""":class:`Asset`: Returns the user's display avatar.
|
||||
|
||||
For regular users this is just their default avatar or uploaded avatar.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self.avatar or self.default_avatar
|
||||
|
||||
@property
|
||||
def banner(self) -> Optional[Asset]:
|
||||
"""Optional[:class:`Asset`]: Returns the user's banner asset, if available.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
|
||||
.. note::
|
||||
This information is only available via :meth:`Client.fetch_user`.
|
||||
"""
|
||||
if self._banner is None:
|
||||
return None
|
||||
return Asset._from_user_banner(self._state, self.id, self._banner)
|
||||
|
||||
@property
|
||||
def accent_colour(self) -> Optional[Colour]:
|
||||
"""Optional[:class:`Colour`]: Returns the user's accent colour, if applicable.
|
||||
|
||||
There is an alias for this named :attr:`accent_color`.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. note::
|
||||
|
||||
This information is only available via :meth:`Client.fetch_user`.
|
||||
"""
|
||||
if self._accent_colour is None:
|
||||
return None
|
||||
return Colour(self._accent_colour)
|
||||
|
||||
@property
|
||||
def accent_color(self) -> Optional[Colour]:
|
||||
"""Optional[:class:`Colour`]: Returns the user's accent color, if applicable.
|
||||
|
||||
There is an alias for this named :attr:`accent_colour`.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. note::
|
||||
|
||||
This information is only available via :meth:`Client.fetch_user`.
|
||||
"""
|
||||
return self.accent_colour
|
||||
|
||||
@property
|
||||
def colour(self) -> Colour:
|
||||
""":class:`Colour`: A property that returns a colour denoting the rendered colour
|
||||
for the user. This always returns :meth:`Colour.default`.
|
||||
|
||||
@@ -137,7 +229,7 @@ class BaseUser(_UserTag):
|
||||
return Colour.default()
|
||||
|
||||
@property
|
||||
def color(self):
|
||||
def color(self) -> Colour:
|
||||
""":class:`Colour`: A property that returns a color denoting the rendered color
|
||||
for the user. This always returns :meth:`Colour.default`.
|
||||
|
||||
@@ -146,12 +238,12 @@ class BaseUser(_UserTag):
|
||||
return self.colour
|
||||
|
||||
@property
|
||||
def mention(self):
|
||||
def mention(self) -> str:
|
||||
""":class:`str`: Returns a string that allows you to mention the given user."""
|
||||
return f'<@{self.id}>'
|
||||
|
||||
@property
|
||||
def created_at(self):
|
||||
def created_at(self) -> datetime:
|
||||
""":class:`datetime.datetime`: Returns the user's creation time in UTC.
|
||||
|
||||
This is when the user's Discord account was created.
|
||||
@@ -159,7 +251,7 @@ class BaseUser(_UserTag):
|
||||
return snowflake_time(self.id)
|
||||
|
||||
@property
|
||||
def display_name(self):
|
||||
def display_name(self) -> str:
|
||||
""":class:`str`: Returns the user's display name.
|
||||
|
||||
For regular users this is just their username, but
|
||||
@@ -168,7 +260,7 @@ class BaseUser(_UserTag):
|
||||
"""
|
||||
return self.name
|
||||
|
||||
def mentioned_in(self, message):
|
||||
def mentioned_in(self, message: Message) -> bool:
|
||||
"""Checks if the user is mentioned in the specified message.
|
||||
|
||||
Parameters
|
||||
@@ -234,16 +326,22 @@ class ClientUser(BaseUser):
|
||||
|
||||
__slots__ = ('locale', '_flags', 'verified', 'mfa_enabled', '__weakref__')
|
||||
|
||||
def __init__(self, *, state, data):
|
||||
if TYPE_CHECKING:
|
||||
verified: bool
|
||||
locale: Optional[str]
|
||||
mfa_enabled: bool
|
||||
_flags: int
|
||||
|
||||
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
|
||||
super().__init__(state=state, data=data)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<ClientUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}'
|
||||
f' bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>'
|
||||
)
|
||||
|
||||
def _update(self, data):
|
||||
def _update(self, data: UserPayload) -> None:
|
||||
super()._update(data)
|
||||
# There's actually an Optional[str] phone field as well but I won't use it
|
||||
self.verified = data.get('verified', False)
|
||||
@@ -251,7 +349,7 @@ class ClientUser(BaseUser):
|
||||
self._flags = data.get('flags', 0)
|
||||
self.mfa_enabled = data.get('mfa_enabled', False)
|
||||
|
||||
async def edit(self, *, username: str = MISSING, avatar: bytes = MISSING) -> None:
|
||||
async def edit(self, *, username: str = MISSING, avatar: bytes = MISSING) -> ClientUser:
|
||||
"""|coro|
|
||||
|
||||
Edits the current profile of the client.
|
||||
@@ -265,6 +363,9 @@ class ClientUser(BaseUser):
|
||||
|
||||
The only image formats supported for uploading is JPEG and PNG.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The edit is no longer in-place, instead the newly edited client user is returned.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
username: :class:`str`
|
||||
@@ -279,6 +380,11 @@ class ClientUser(BaseUser):
|
||||
Editing your profile failed.
|
||||
InvalidArgument
|
||||
Wrong image format passed for ``avatar``.
|
||||
|
||||
Returns
|
||||
---------
|
||||
:class:`ClientUser`
|
||||
The newly edited client user.
|
||||
"""
|
||||
payload: Dict[str, Any] = {}
|
||||
if username is not MISSING:
|
||||
@@ -287,8 +393,8 @@ class ClientUser(BaseUser):
|
||||
if avatar is not MISSING:
|
||||
payload['avatar'] = _bytes_to_base64_data(avatar)
|
||||
|
||||
data = await self._state.http.edit_profile(payload)
|
||||
self._update(data)
|
||||
data: UserPayload = await self._state.http.edit_profile(payload)
|
||||
return ClientUser(state=self._state, data=data)
|
||||
|
||||
|
||||
class User(BaseUser, discord.abc.Messageable):
|
||||
@@ -312,6 +418,10 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
|
||||
Returns the user's name with discriminator.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the user's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@@ -328,11 +438,11 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
|
||||
__slots__ = ('_stored',)
|
||||
|
||||
def __init__(self, *, state, data):
|
||||
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
|
||||
super().__init__(state=state, data=data)
|
||||
self._stored = False
|
||||
self._stored: bool = False
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>'
|
||||
|
||||
def __del__(self) -> None:
|
||||
@@ -343,17 +453,17 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _copy(cls, user):
|
||||
def _copy(cls, user: User):
|
||||
self = super()._copy(user)
|
||||
self._stored = getattr(user, '_stored', False)
|
||||
self._stored = False
|
||||
return self
|
||||
|
||||
async def _get_channel(self):
|
||||
async def _get_channel(self) -> DMChannel:
|
||||
ch = await self.create_dm()
|
||||
return ch
|
||||
|
||||
@property
|
||||
def dm_channel(self):
|
||||
def dm_channel(self) -> Optional[DMChannel]:
|
||||
"""Optional[:class:`DMChannel`]: Returns the channel associated with this user if it exists.
|
||||
|
||||
If this returns ``None``, you can create a DM channel by calling the
|
||||
@@ -362,7 +472,7 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
return self._state._get_private_channel_by_user(self.id)
|
||||
|
||||
@property
|
||||
def mutual_guilds(self):
|
||||
def mutual_guilds(self) -> List[Guild]:
|
||||
"""List[:class:`Guild`]: The guilds that the user shares with the client.
|
||||
|
||||
.. note::
|
||||
@@ -373,7 +483,7 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
"""
|
||||
return [guild for guild in self._state._guilds.values() if guild.get_member(self.id)]
|
||||
|
||||
async def create_dm(self):
|
||||
async def create_dm(self) -> DMChannel:
|
||||
"""|coro|
|
||||
|
||||
Creates a :class:`DMChannel` with this user.
|
||||
@@ -391,5 +501,5 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
return found
|
||||
|
||||
state = self._state
|
||||
data = await state.http.start_private_message(self.id)
|
||||
data: DMChannelPayload = await state.http.start_private_message(self.id)
|
||||
return state.add_dm_channel(data)
|
||||
|
||||
@@ -120,6 +120,9 @@ class _cached_property:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from functools import cached_property as cached_property
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from .permissions import Permissions
|
||||
from .abc import Snowflake
|
||||
from .invite import Invite
|
||||
@@ -129,6 +132,8 @@ if TYPE_CHECKING:
|
||||
headers: Mapping[str, Any]
|
||||
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
else:
|
||||
cached_property = _cached_property
|
||||
|
||||
@@ -231,8 +236,8 @@ def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]:
|
||||
return None
|
||||
|
||||
|
||||
def copy_doc(original: Callable[..., Any]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(overriden: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def copy_doc(original: Callable) -> Callable[[T], T]:
|
||||
def decorator(overriden: T) -> T:
|
||||
overriden.__doc__ = original.__doc__
|
||||
overriden.__signature__ = _signature(original) # type: ignore
|
||||
return overriden
|
||||
@@ -240,10 +245,10 @@ def copy_doc(original: Callable[..., Any]) -> Callable[[Callable[..., Any]], Cal
|
||||
return decorator
|
||||
|
||||
|
||||
def deprecated(instead: Optional[str] = None) -> Callable[[Callable[..., T]], Callable[..., T]]:
|
||||
def actual_decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
def deprecated(instead: Optional[str] = None) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||
def actual_decorator(func: Callable[P, T]) -> Callable[P, T]:
|
||||
@functools.wraps(func)
|
||||
def decorated(*args, **kwargs) -> T:
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
warnings.simplefilter('always', DeprecationWarning) # turn off filter
|
||||
if instead:
|
||||
fmt = "{0.__name__} is deprecated, use {1} instead."
|
||||
@@ -267,7 +272,7 @@ def oauth_url(
|
||||
redirect_uri: str = MISSING,
|
||||
scopes: Iterable[str] = MISSING,
|
||||
disable_guild_select: bool = False,
|
||||
):
|
||||
) -> str:
|
||||
"""A helper function that returns the OAuth2 URL for inviting the bot
|
||||
into guilds.
|
||||
|
||||
@@ -479,17 +484,17 @@ def _bytes_to_base64_data(data: bytes) -> str:
|
||||
|
||||
if HAS_ORJSON:
|
||||
|
||||
def to_json(obj: Any) -> str: # type: ignore
|
||||
def _to_json(obj: Any) -> str: # type: ignore
|
||||
return orjson.dumps(obj).decode('utf-8')
|
||||
|
||||
from_json = orjson.loads # type: ignore
|
||||
_from_json = orjson.loads # type: ignore
|
||||
|
||||
else:
|
||||
|
||||
def to_json(obj: Any) -> str:
|
||||
def _to_json(obj: Any) -> str:
|
||||
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
|
||||
|
||||
from_json = json.loads
|
||||
_from_json = json.loads
|
||||
|
||||
|
||||
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
|
||||
@@ -916,7 +921,7 @@ def evaluate_annotation(
|
||||
is_literal = False
|
||||
args = tp.__args__
|
||||
if not hasattr(tp, '__origin__'):
|
||||
if PY_310 and tp.__class__ is types.Union: # type: ignore
|
||||
if PY_310 and tp.__class__ is types.UnionType: # type: ignore
|
||||
converted = Union[args] # type: ignore
|
||||
return evaluate_annotation(converted, globals, locals, cache)
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ __all__ = (
|
||||
|
||||
|
||||
|
||||
log: logging.Logger = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
class VoiceProtocol:
|
||||
"""A class that represents the Discord voice protocol.
|
||||
@@ -301,7 +301,7 @@ class VoiceClient(VoiceProtocol):
|
||||
|
||||
async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
|
||||
if self._voice_server_complete.is_set():
|
||||
log.info('Ignoring extraneous voice server update.')
|
||||
_log.info('Ignoring extraneous voice server update.')
|
||||
return
|
||||
|
||||
self.token = data.get('token')
|
||||
@@ -309,7 +309,7 @@ class VoiceClient(VoiceProtocol):
|
||||
endpoint = data.get('endpoint')
|
||||
|
||||
if endpoint is None or self.token is None:
|
||||
log.warning('Awaiting endpoint... This requires waiting. ' \
|
||||
_log.warning('Awaiting endpoint... This requires waiting. ' \
|
||||
'If timeout occurred considering raising the timeout and reconnecting.')
|
||||
return
|
||||
|
||||
@@ -335,18 +335,18 @@ class VoiceClient(VoiceProtocol):
|
||||
await self.channel.guild.change_voice_state(channel=self.channel)
|
||||
|
||||
async def voice_disconnect(self) -> None:
|
||||
log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
|
||||
_log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
|
||||
await self.channel.guild.change_voice_state(channel=None)
|
||||
|
||||
def prepare_handshake(self) -> None:
|
||||
self._voice_state_complete.clear()
|
||||
self._voice_server_complete.clear()
|
||||
self._handshaking = True
|
||||
log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
|
||||
_log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
|
||||
self._connections += 1
|
||||
|
||||
def finish_handshake(self) -> None:
|
||||
log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
|
||||
_log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
|
||||
self._handshaking = False
|
||||
self._voice_server_complete.clear()
|
||||
self._voice_state_complete.clear()
|
||||
@@ -360,7 +360,7 @@ class VoiceClient(VoiceProtocol):
|
||||
return ws
|
||||
|
||||
async def connect(self, *, reconnect: bool, timeout: float) ->None:
|
||||
log.info('Connecting to voice...')
|
||||
_log.info('Connecting to voice...')
|
||||
self.timeout = timeout
|
||||
|
||||
for i in range(5):
|
||||
@@ -388,7 +388,7 @@ class VoiceClient(VoiceProtocol):
|
||||
break
|
||||
except (ConnectionClosed, asyncio.TimeoutError):
|
||||
if reconnect:
|
||||
log.exception('Failed to connect to voice... Retrying...')
|
||||
_log.exception('Failed to connect to voice... Retrying...')
|
||||
await asyncio.sleep(1 + i * 2.0)
|
||||
await self.voice_disconnect()
|
||||
continue
|
||||
@@ -453,14 +453,14 @@ class VoiceClient(VoiceProtocol):
|
||||
# 4014 - voice channel has been deleted.
|
||||
# 4015 - voice server has crashed
|
||||
if exc.code in (1000, 4015):
|
||||
log.info('Disconnecting from voice normally, close code %d.', exc.code)
|
||||
_log.info('Disconnecting from voice normally, close code %d.', exc.code)
|
||||
await self.disconnect()
|
||||
break
|
||||
if exc.code == 4014:
|
||||
log.info('Disconnected from voice by force... potentially reconnecting.')
|
||||
_log.info('Disconnected from voice by force... potentially reconnecting.')
|
||||
successful = await self.potential_reconnect()
|
||||
if not successful:
|
||||
log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
|
||||
_log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
|
||||
await self.disconnect()
|
||||
break
|
||||
else:
|
||||
@@ -471,7 +471,7 @@ class VoiceClient(VoiceProtocol):
|
||||
raise
|
||||
|
||||
retry = backoff.delay()
|
||||
log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
|
||||
_log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
|
||||
self._connected.clear()
|
||||
await asyncio.sleep(retry)
|
||||
await self.voice_disconnect()
|
||||
@@ -479,7 +479,7 @@ class VoiceClient(VoiceProtocol):
|
||||
await self.connect(reconnect=True, timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
# at this point we've retried 5 times... let's continue the loop.
|
||||
log.warning('Could not connect to voice... Retrying...')
|
||||
_log.warning('Could not connect to voice... Retrying...')
|
||||
continue
|
||||
|
||||
async def disconnect(self, *, force: bool = False) -> None:
|
||||
@@ -671,6 +671,6 @@ class VoiceClient(VoiceProtocol):
|
||||
try:
|
||||
self.socket.sendto(packet, (self.endpoint_ip, self.voice_port))
|
||||
except BlockingIOError:
|
||||
log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
|
||||
_log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
|
||||
|
||||
self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295)
|
||||
|
||||
@@ -24,7 +24,6 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import asyncio
|
||||
import json
|
||||
@@ -32,6 +31,7 @@ import re
|
||||
|
||||
from urllib.parse import quote as urlquote
|
||||
from typing import Any, Dict, List, Literal, NamedTuple, Optional, TYPE_CHECKING, Tuple, Union, overload
|
||||
from contextvars import ContextVar
|
||||
|
||||
import aiohttp
|
||||
|
||||
@@ -43,7 +43,7 @@ from ..user import BaseUser, User
|
||||
from ..asset import Asset
|
||||
from ..http import Route
|
||||
from ..mixins import Hashable
|
||||
from ..object import Object
|
||||
from ..channel import PartialMessageable
|
||||
|
||||
__all__ = (
|
||||
'Webhook',
|
||||
@@ -52,15 +52,20 @@ __all__ = (
|
||||
'PartialWebhookGuild',
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..file import File
|
||||
from ..embeds import Embed
|
||||
from ..mentions import AllowedMentions
|
||||
from ..state import ConnectionState
|
||||
from ..http import Response
|
||||
from ..types.webhook import (
|
||||
Webhook as WebhookPayload,
|
||||
)
|
||||
from ..types.message import (
|
||||
Message as MessagePayload,
|
||||
)
|
||||
from ..guild import Guild
|
||||
from ..channel import TextChannel
|
||||
from ..abc import Snowflake
|
||||
@@ -116,7 +121,7 @@ class AsyncWebhookAdapter:
|
||||
|
||||
if payload is not None:
|
||||
headers['Content-Type'] = 'application/json'
|
||||
to_send = utils.to_json(payload)
|
||||
to_send = utils._to_json(payload)
|
||||
|
||||
if auth_token is not None:
|
||||
headers['Authorization'] = f'Bot {auth_token}'
|
||||
@@ -143,7 +148,7 @@ class AsyncWebhookAdapter:
|
||||
|
||||
try:
|
||||
async with session.request(method, url, data=to_send, headers=headers, params=params) as response:
|
||||
log.debug(
|
||||
_log.debug(
|
||||
'Webhook ID %s with %s %s has returned status code %s',
|
||||
webhook_id,
|
||||
method,
|
||||
@@ -157,7 +162,7 @@ class AsyncWebhookAdapter:
|
||||
remaining = response.headers.get('X-Ratelimit-Remaining')
|
||||
if remaining == '0' and response.status != 429:
|
||||
delta = utils._parse_ratelimit_header(response)
|
||||
log.debug(
|
||||
_log.debug(
|
||||
'Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds', webhook_id, delta
|
||||
)
|
||||
lock.delay_by(delta)
|
||||
@@ -170,7 +175,7 @@ class AsyncWebhookAdapter:
|
||||
raise HTTPException(response, data)
|
||||
|
||||
retry_after: float = data['retry_after'] # type: ignore
|
||||
log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
|
||||
_log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
|
||||
@@ -205,7 +210,7 @@ class AsyncWebhookAdapter:
|
||||
token: Optional[str] = None,
|
||||
session: aiohttp.ClientSession,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
) -> Response[None]:
|
||||
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
|
||||
return self.request(route, session, reason=reason, auth_token=token)
|
||||
|
||||
@@ -216,7 +221,7 @@ class AsyncWebhookAdapter:
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
) -> Response[None]:
|
||||
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
return self.request(route, session, reason=reason)
|
||||
|
||||
@@ -228,7 +233,7 @@ class AsyncWebhookAdapter:
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
) -> Response[WebhookPayload]:
|
||||
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
|
||||
return self.request(route, session, reason=reason, payload=payload, auth_token=token)
|
||||
|
||||
@@ -240,7 +245,7 @@ class AsyncWebhookAdapter:
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
) -> Response[WebhookPayload]:
|
||||
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
return self.request(route, session, reason=reason, payload=payload)
|
||||
|
||||
@@ -255,7 +260,7 @@ class AsyncWebhookAdapter:
|
||||
files: Optional[List[File]] = None,
|
||||
thread_id: Optional[int] = None,
|
||||
wait: bool = False,
|
||||
):
|
||||
) -> Response[Optional[MessagePayload]]:
|
||||
params = {'wait': int(wait)}
|
||||
if thread_id:
|
||||
params['thread_id'] = thread_id
|
||||
@@ -269,7 +274,7 @@ class AsyncWebhookAdapter:
|
||||
message_id: int,
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
):
|
||||
) -> Response[MessagePayload]:
|
||||
route = Route(
|
||||
'GET',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
|
||||
@@ -289,7 +294,7 @@ class AsyncWebhookAdapter:
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
multipart: Optional[List[Dict[str, Any]]] = None,
|
||||
files: Optional[List[File]] = None,
|
||||
):
|
||||
) -> Response[Message]:
|
||||
route = Route(
|
||||
'PATCH',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
|
||||
@@ -306,7 +311,7 @@ class AsyncWebhookAdapter:
|
||||
message_id: int,
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
):
|
||||
) -> Response[None]:
|
||||
route = Route(
|
||||
'DELETE',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
|
||||
@@ -322,7 +327,7 @@ class AsyncWebhookAdapter:
|
||||
token: str,
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
):
|
||||
) -> Response[WebhookPayload]:
|
||||
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
|
||||
return self.request(route, session=session, auth_token=token)
|
||||
|
||||
@@ -332,7 +337,7 @@ class AsyncWebhookAdapter:
|
||||
token: str,
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
):
|
||||
) -> Response[WebhookPayload]:
|
||||
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
return self.request(route, session=session)
|
||||
|
||||
@@ -344,7 +349,7 @@ class AsyncWebhookAdapter:
|
||||
session: aiohttp.ClientSession,
|
||||
type: int,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> Response[None]:
|
||||
payload: Dict[str, Any] = {
|
||||
'type': type,
|
||||
}
|
||||
@@ -367,7 +372,7 @@ class AsyncWebhookAdapter:
|
||||
token: str,
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
):
|
||||
) -> Response[MessagePayload]:
|
||||
r = Route(
|
||||
'GET',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/@original',
|
||||
@@ -385,7 +390,7 @@ class AsyncWebhookAdapter:
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
multipart: Optional[List[Dict[str, Any]]] = None,
|
||||
files: Optional[List[File]] = None,
|
||||
):
|
||||
) -> Response[MessagePayload]:
|
||||
r = Route(
|
||||
'PATCH',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/@original',
|
||||
@@ -400,7 +405,7 @@ class AsyncWebhookAdapter:
|
||||
token: str,
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
):
|
||||
) -> Response[None]:
|
||||
r = Route(
|
||||
'DELETE',
|
||||
'/webhooks/{webhook_id}/{wehook_token}/messages/@original',
|
||||
@@ -420,7 +425,7 @@ def handle_message_parameters(
|
||||
content: Optional[str] = MISSING,
|
||||
*,
|
||||
username: str = MISSING,
|
||||
avatar_url: str = MISSING,
|
||||
avatar_url: Any = MISSING,
|
||||
tts: bool = False,
|
||||
ephemeral: bool = False,
|
||||
file: File = MISSING,
|
||||
@@ -481,7 +486,7 @@ def handle_message_parameters(
|
||||
files = [file]
|
||||
|
||||
if files:
|
||||
multipart.append({'name': 'payload_json', 'value': utils.to_json(payload)})
|
||||
multipart.append({'name': 'payload_json', 'value': utils._to_json(payload)})
|
||||
payload = None
|
||||
if len(files) == 1:
|
||||
file = files[0]
|
||||
@@ -507,7 +512,7 @@ def handle_message_parameters(
|
||||
return ExecuteWebhookParameters(payload=payload, multipart=multipart, files=files)
|
||||
|
||||
|
||||
async_context = contextvars.ContextVar('async_webhook_context', default=AsyncWebhookAdapter())
|
||||
async_context: ContextVar[AsyncWebhookAdapter] = ContextVar('async_webhook_context', default=AsyncWebhookAdapter())
|
||||
|
||||
|
||||
class PartialWebhookChannel(Hashable):
|
||||
@@ -579,10 +584,11 @@ class _FriendlyHttpAttributeErrorHelper:
|
||||
class _WebhookState:
|
||||
__slots__ = ('_parent', '_webhook')
|
||||
|
||||
def __init__(self, webhook, parent):
|
||||
self._webhook = webhook
|
||||
def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]):
|
||||
self._webhook: Any = webhook
|
||||
|
||||
if isinstance(parent, self.__class__):
|
||||
self._parent: Optional[ConnectionState]
|
||||
if isinstance(parent, _WebhookState):
|
||||
self._parent = None
|
||||
else:
|
||||
self._parent = parent
|
||||
@@ -595,7 +601,12 @@ class _WebhookState:
|
||||
def store_user(self, data):
|
||||
if self._parent is not None:
|
||||
return self._parent.store_user(data)
|
||||
return BaseUser(state=self, data=data)
|
||||
# state parameter is artificial
|
||||
return BaseUser(state=self, data=data) # type: ignore
|
||||
|
||||
def create_user(self, data):
|
||||
# state parameter is artificial
|
||||
return BaseUser(state=self, data=data) # type: ignore
|
||||
|
||||
@property
|
||||
def http(self):
|
||||
@@ -636,13 +647,16 @@ class WebhookMessage(Message):
|
||||
files: List[File] = MISSING,
|
||||
view: Optional[View] = MISSING,
|
||||
allowed_mentions: Optional[AllowedMentions] = None,
|
||||
):
|
||||
) -> WebhookMessage:
|
||||
"""|coro|
|
||||
|
||||
Edits the message.
|
||||
|
||||
.. versionadded:: 1.6
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The edit is no longer in-place, instead the newly edited message is returned.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
content: Optional[:class:`str`]
|
||||
@@ -682,8 +696,13 @@ class WebhookMessage(Message):
|
||||
The length of ``embeds`` was invalid
|
||||
InvalidArgument
|
||||
There was no token associated with this webhook.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`WebhookMessage`
|
||||
The newly edited message.
|
||||
"""
|
||||
await self._state._webhook.edit_message(
|
||||
return await self._state._webhook.edit_message(
|
||||
self.id,
|
||||
content=content,
|
||||
embeds=embeds,
|
||||
@@ -745,9 +764,9 @@ class BaseWebhook(Hashable):
|
||||
'_state',
|
||||
)
|
||||
|
||||
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state=None):
|
||||
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None):
|
||||
self.auth_token: Optional[str] = token
|
||||
self._state = state or _WebhookState(self, parent=state)
|
||||
self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(self, parent=state)
|
||||
self._update(data)
|
||||
|
||||
def _update(self, data: WebhookPayload):
|
||||
@@ -762,10 +781,8 @@ class BaseWebhook(Hashable):
|
||||
user = data.get('user')
|
||||
self.user: Optional[Union[BaseUser, User]] = None
|
||||
if user is not None:
|
||||
if self._state is None:
|
||||
self.user = BaseUser(state=None, data=user)
|
||||
else:
|
||||
self.user = User(state=self._state, data=user)
|
||||
# state parameter may be _WebhookState
|
||||
self.user = User(state=self._state, data=user) # type: ignore
|
||||
|
||||
source_channel = data.get('source_channel')
|
||||
if source_channel:
|
||||
@@ -869,6 +886,10 @@ class Webhook(BaseWebhook):
|
||||
|
||||
Returns the webhooks's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the webhooks's ID.
|
||||
|
||||
.. versionchanged:: 1.4
|
||||
Webhooks are now comparable and hashable.
|
||||
|
||||
@@ -916,12 +937,12 @@ class Webhook(BaseWebhook):
|
||||
return f'<Webhook id={self.id!r}>'
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
def url(self) -> str:
|
||||
""":class:`str` : Returns the webhook's url."""
|
||||
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
|
||||
|
||||
@classmethod
|
||||
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None):
|
||||
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
|
||||
"""Creates a partial :class:`Webhook`.
|
||||
|
||||
Parameters
|
||||
@@ -957,7 +978,7 @@ class Webhook(BaseWebhook):
|
||||
return cls(data, session, token=bot_token)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None):
|
||||
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
|
||||
"""Creates a partial :class:`Webhook` from a webhook URL.
|
||||
|
||||
Parameters
|
||||
@@ -996,7 +1017,7 @@ class Webhook(BaseWebhook):
|
||||
return cls(data, session, token=bot_token) # type: ignore
|
||||
|
||||
@classmethod
|
||||
def _as_follower(cls, data, *, channel, user):
|
||||
def _as_follower(cls, data, *, channel, user) -> Webhook:
|
||||
name = f"{channel.guild} #{channel}"
|
||||
feed: WebhookPayload = {
|
||||
'id': data['webhook_id'],
|
||||
@@ -1012,7 +1033,7 @@ class Webhook(BaseWebhook):
|
||||
return cls(feed, session=session, state=state, token=state.http.token)
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, data, state):
|
||||
def from_state(cls, data, state) -> Webhook:
|
||||
session = state.http._HTTPClient__session
|
||||
return cls(data, session=session, state=state, token=state.http.token)
|
||||
|
||||
@@ -1108,7 +1129,7 @@ class Webhook(BaseWebhook):
|
||||
avatar: Optional[bytes] = MISSING,
|
||||
channel: Optional[Snowflake] = None,
|
||||
prefer_auth: bool = True,
|
||||
):
|
||||
) -> Webhook:
|
||||
"""|coro|
|
||||
|
||||
Edits this Webhook.
|
||||
@@ -1155,6 +1176,7 @@ class Webhook(BaseWebhook):
|
||||
|
||||
adapter = async_context.get()
|
||||
|
||||
data: Optional[WebhookPayload] = None
|
||||
# If a channel is given, always use the authenticated endpoint
|
||||
if channel is not None:
|
||||
if self.auth_token is None:
|
||||
@@ -1162,21 +1184,24 @@ class Webhook(BaseWebhook):
|
||||
|
||||
payload['channel_id'] = channel.id
|
||||
data = await adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
|
||||
self._update(data)
|
||||
return
|
||||
|
||||
if prefer_auth and self.auth_token:
|
||||
data = await adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
|
||||
self._update(data)
|
||||
elif self.token:
|
||||
data = await adapter.edit_webhook_with_token(
|
||||
self.id, self.token, payload=payload, session=self.session, reason=reason
|
||||
)
|
||||
self._update(data)
|
||||
|
||||
if data is None:
|
||||
raise RuntimeError('Unreachable code hit: data was not assigned')
|
||||
|
||||
return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state)
|
||||
|
||||
def _create_message(self, data):
|
||||
state = _WebhookState(self, parent=self._state)
|
||||
channel = self.channel or Object(id=int(data['channel_id']))
|
||||
# state may be artificial (unlikely at this point...)
|
||||
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
|
||||
# state is artificial
|
||||
return WebhookMessage(data=data, state=state, channel=channel) # type: ignore
|
||||
|
||||
@overload
|
||||
@@ -1185,7 +1210,7 @@ class Webhook(BaseWebhook):
|
||||
content: str = MISSING,
|
||||
*,
|
||||
username: str = MISSING,
|
||||
avatar_url: str = MISSING,
|
||||
avatar_url: Any = MISSING,
|
||||
tts: bool = MISSING,
|
||||
ephemeral: bool = MISSING,
|
||||
file: File = MISSING,
|
||||
@@ -1205,7 +1230,7 @@ class Webhook(BaseWebhook):
|
||||
content: str = MISSING,
|
||||
*,
|
||||
username: str = MISSING,
|
||||
avatar_url: str = MISSING,
|
||||
avatar_url: Any = MISSING,
|
||||
tts: bool = MISSING,
|
||||
ephemeral: bool = MISSING,
|
||||
file: File = MISSING,
|
||||
@@ -1224,7 +1249,7 @@ class Webhook(BaseWebhook):
|
||||
content: str = MISSING,
|
||||
*,
|
||||
username: str = MISSING,
|
||||
avatar_url: str = MISSING,
|
||||
avatar_url: Any = MISSING,
|
||||
tts: bool = False,
|
||||
ephemeral: bool = False,
|
||||
file: File = MISSING,
|
||||
@@ -1261,9 +1286,10 @@ class Webhook(BaseWebhook):
|
||||
username: :class:`str`
|
||||
The username to send with this message. If no username is provided
|
||||
then the default username for the webhook is used.
|
||||
avatar_url: Union[:class:`str`, :class:`Asset`]
|
||||
avatar_url: :class:`str`
|
||||
The avatar URL to send with this message. If no avatar URL is provided
|
||||
then the default avatar for the webhook is used.
|
||||
then the default avatar for the webhook is used. If this is not a
|
||||
string then it is explicitly cast using ``str``.
|
||||
tts: :class:`bool`
|
||||
Indicates if the message should be sent using text-to-speech.
|
||||
ephemeral: :class:`bool`
|
||||
@@ -1435,7 +1461,7 @@ class Webhook(BaseWebhook):
|
||||
files: List[File] = MISSING,
|
||||
view: Optional[View] = MISSING,
|
||||
allowed_mentions: Optional[AllowedMentions] = None,
|
||||
):
|
||||
) -> WebhookMessage:
|
||||
"""|coro|
|
||||
|
||||
Edits a message owned by this webhook.
|
||||
@@ -1445,6 +1471,9 @@ class Webhook(BaseWebhook):
|
||||
|
||||
.. versionadded:: 1.6
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The edit is no longer in-place, instead the newly edited message is returned.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
message_id: :class:`int`
|
||||
@@ -1488,6 +1517,11 @@ class Webhook(BaseWebhook):
|
||||
InvalidArgument
|
||||
There was no token associated with this webhook or the webhook had
|
||||
no state.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`WebhookMessage`
|
||||
The newly edited webhook message.
|
||||
"""
|
||||
|
||||
if self.token is None:
|
||||
@@ -1511,7 +1545,7 @@ class Webhook(BaseWebhook):
|
||||
previous_allowed_mentions=previous_mentions,
|
||||
)
|
||||
adapter = async_context.get()
|
||||
await adapter.edit_webhook_message(
|
||||
data = await adapter.edit_webhook_message(
|
||||
self.id,
|
||||
self.token,
|
||||
message_id,
|
||||
@@ -1521,10 +1555,12 @@ class Webhook(BaseWebhook):
|
||||
files=params.files,
|
||||
)
|
||||
|
||||
message = self._create_message(data)
|
||||
if view and not view.is_finished():
|
||||
self._state.store_view(view, message_id)
|
||||
return message
|
||||
|
||||
async def delete_message(self, message_id: int):
|
||||
async def delete_message(self, message_id: int, /) -> None:
|
||||
"""|coro|
|
||||
|
||||
Deletes a message owned by this webhook.
|
||||
|
||||
@@ -43,7 +43,7 @@ from .. import utils
|
||||
from ..errors import InvalidArgument, HTTPException, Forbidden, NotFound, DiscordServerError
|
||||
from ..message import Message
|
||||
from ..http import Route
|
||||
from ..object import Object
|
||||
from ..channel import PartialMessageable
|
||||
|
||||
from .async_ import BaseWebhook, handle_message_parameters, _WebhookState
|
||||
|
||||
@@ -52,7 +52,7 @@ __all__ = (
|
||||
'SyncWebhookMessage',
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..file import File
|
||||
@@ -117,7 +117,7 @@ class WebhookAdapter:
|
||||
|
||||
if payload is not None:
|
||||
headers['Content-Type'] = 'application/json'
|
||||
to_send = utils.to_json(payload)
|
||||
to_send = utils._to_json(payload)
|
||||
|
||||
if auth_token is not None:
|
||||
headers['Authorization'] = f'Bot {auth_token}'
|
||||
@@ -142,13 +142,15 @@ class WebhookAdapter:
|
||||
for p in multipart:
|
||||
name = p['name']
|
||||
if name == 'payload_json':
|
||||
to_send = { 'payload_json': p['value'] }
|
||||
to_send = {'payload_json': p['value']}
|
||||
else:
|
||||
file_data[name] = (p['filename'], p['value'], p['content_type'])
|
||||
|
||||
try:
|
||||
with session.request(method, url, data=to_send, files=file_data, headers=headers, params=params) as response:
|
||||
log.debug(
|
||||
with session.request(
|
||||
method, url, data=to_send, files=file_data, headers=headers, params=params
|
||||
) as response:
|
||||
_log.debug(
|
||||
'Webhook ID %s with %s %s has returned status code %s',
|
||||
webhook_id,
|
||||
method,
|
||||
@@ -166,7 +168,7 @@ class WebhookAdapter:
|
||||
remaining = response.headers.get('X-Ratelimit-Remaining')
|
||||
if remaining == '0' and response.status_code != 429:
|
||||
delta = utils._parse_ratelimit_header(response)
|
||||
log.debug(
|
||||
_log.debug(
|
||||
'Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds', webhook_id, delta
|
||||
)
|
||||
lock.delay_by(delta)
|
||||
@@ -179,7 +181,7 @@ class WebhookAdapter:
|
||||
raise HTTPException(response, data)
|
||||
|
||||
retry_after: float = data['retry_after'] # type: ignore
|
||||
log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
|
||||
_log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
|
||||
time.sleep(retry_after)
|
||||
continue
|
||||
|
||||
@@ -346,8 +348,17 @@ class WebhookAdapter:
|
||||
return self.request(route, session=session)
|
||||
|
||||
|
||||
_context = threading.local()
|
||||
_context.adapter = WebhookAdapter()
|
||||
class _WebhookContext(threading.local):
|
||||
adapter: Optional[WebhookAdapter] = None
|
||||
|
||||
|
||||
_context = _WebhookContext()
|
||||
|
||||
|
||||
def _get_webhook_adapter() -> WebhookAdapter:
|
||||
if _context.adapter is None:
|
||||
_context.adapter = WebhookAdapter()
|
||||
return _context.adapter
|
||||
|
||||
|
||||
class SyncWebhookMessage(Message):
|
||||
@@ -362,6 +373,8 @@ class SyncWebhookMessage(Message):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
_state: _WebhookState
|
||||
|
||||
def edit(
|
||||
self,
|
||||
content: Optional[str] = MISSING,
|
||||
@@ -370,7 +383,7 @@ class SyncWebhookMessage(Message):
|
||||
file: File = MISSING,
|
||||
files: List[File] = MISSING,
|
||||
allowed_mentions: Optional[AllowedMentions] = None,
|
||||
):
|
||||
) -> SyncWebhookMessage:
|
||||
"""Edits the message.
|
||||
|
||||
Parameters
|
||||
@@ -403,8 +416,13 @@ class SyncWebhookMessage(Message):
|
||||
The length of ``embeds`` was invalid
|
||||
InvalidArgument
|
||||
There was no token associated with this webhook.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`SyncWebhookMessage`
|
||||
The newly edited message.
|
||||
"""
|
||||
self._state._webhook.edit_message(
|
||||
return self._state._webhook.edit_message(
|
||||
self.id,
|
||||
content=content,
|
||||
embeds=embeds,
|
||||
@@ -457,6 +475,10 @@ class SyncWebhook(BaseWebhook):
|
||||
|
||||
Returns the webhooks's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the webhooks's ID.
|
||||
|
||||
.. versionchanged:: 1.4
|
||||
Webhooks are now comparable and hashable.
|
||||
|
||||
@@ -504,12 +526,12 @@ class SyncWebhook(BaseWebhook):
|
||||
return f'<Webhook id={self.id!r}>'
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
def url(self) -> str:
|
||||
""":class:`str` : Returns the webhook's url."""
|
||||
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
|
||||
|
||||
@classmethod
|
||||
def partial(cls, id: int, token: str, *, session: Session = MISSING, bot_token: Optional[str] = None):
|
||||
def partial(cls, id: int, token: str, *, session: Session = MISSING, bot_token: Optional[str] = None) -> SyncWebhook:
|
||||
"""Creates a partial :class:`Webhook`.
|
||||
|
||||
Parameters
|
||||
@@ -548,7 +570,7 @@ class SyncWebhook(BaseWebhook):
|
||||
return cls(data, session, token=bot_token)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, *, session: Session = MISSING, bot_token: Optional[str] = None):
|
||||
def from_url(cls, url: str, *, session: Session = MISSING, bot_token: Optional[str] = None) -> SyncWebhook:
|
||||
"""Creates a partial :class:`Webhook` from a webhook URL.
|
||||
|
||||
Parameters
|
||||
@@ -621,7 +643,7 @@ class SyncWebhook(BaseWebhook):
|
||||
:class:`SyncWebhook`
|
||||
The fetched webhook.
|
||||
"""
|
||||
adapter: WebhookAdapter = _context.adapter
|
||||
adapter: WebhookAdapter = _get_webhook_adapter()
|
||||
|
||||
if prefer_auth and self.auth_token:
|
||||
data = adapter.fetch_webhook(self.id, self.auth_token, session=self.session)
|
||||
@@ -632,7 +654,7 @@ class SyncWebhook(BaseWebhook):
|
||||
|
||||
return SyncWebhook(data, self.session, token=self.auth_token, state=self._state)
|
||||
|
||||
def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True):
|
||||
def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True) -> None:
|
||||
"""Deletes this Webhook.
|
||||
|
||||
Parameters
|
||||
@@ -659,7 +681,7 @@ class SyncWebhook(BaseWebhook):
|
||||
if self.token is None and self.auth_token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
|
||||
adapter: WebhookAdapter = _context.adapter
|
||||
adapter: WebhookAdapter = _get_webhook_adapter()
|
||||
|
||||
if prefer_auth and self.auth_token:
|
||||
adapter.delete_webhook(self.id, token=self.auth_token, session=self.session, reason=reason)
|
||||
@@ -674,7 +696,7 @@ class SyncWebhook(BaseWebhook):
|
||||
avatar: Optional[bytes] = MISSING,
|
||||
channel: Optional[Snowflake] = None,
|
||||
prefer_auth: bool = True,
|
||||
):
|
||||
) -> SyncWebhook:
|
||||
"""Edits this Webhook.
|
||||
|
||||
Parameters
|
||||
@@ -702,6 +724,11 @@ class SyncWebhook(BaseWebhook):
|
||||
InvalidArgument
|
||||
This webhook does not have a token associated with it
|
||||
or it tried editing a channel without authentication.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`SyncWebhook`
|
||||
The newly edited webhook.
|
||||
"""
|
||||
if self.token is None and self.auth_token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
@@ -713,8 +740,9 @@ class SyncWebhook(BaseWebhook):
|
||||
if avatar is not MISSING:
|
||||
payload['avatar'] = utils._bytes_to_base64_data(avatar) if avatar is not None else None
|
||||
|
||||
adapter: WebhookAdapter = _context.adapter
|
||||
adapter: WebhookAdapter = _get_webhook_adapter()
|
||||
|
||||
data: Optional[WebhookPayload] = None
|
||||
# If a channel is given, always use the authenticated endpoint
|
||||
if channel is not None:
|
||||
if self.auth_token is None:
|
||||
@@ -722,20 +750,23 @@ class SyncWebhook(BaseWebhook):
|
||||
|
||||
payload['channel_id'] = channel.id
|
||||
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
|
||||
self._update(data)
|
||||
return
|
||||
|
||||
if prefer_auth and self.auth_token:
|
||||
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
|
||||
self._update(data)
|
||||
elif self.token:
|
||||
data = adapter.edit_webhook_with_token(self.id, self.token, payload=payload, session=self.session, reason=reason)
|
||||
self._update(data)
|
||||
|
||||
if data is None:
|
||||
raise RuntimeError('Unreachable code hit: data was not assigned')
|
||||
|
||||
return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state)
|
||||
|
||||
def _create_message(self, data):
|
||||
state = _WebhookState(self, parent=self._state)
|
||||
channel = self.channel or Object(id=int(data['channel_id']))
|
||||
return SyncWebhookMessage(data=data, state=state, channel=channel)
|
||||
# state may be artificial (unlikely at this point...)
|
||||
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
|
||||
# state is artificial
|
||||
return SyncWebhookMessage(data=data, state=state, channel=channel) # type: ignore
|
||||
|
||||
@overload
|
||||
def send(
|
||||
@@ -743,7 +774,7 @@ class SyncWebhook(BaseWebhook):
|
||||
content: str = MISSING,
|
||||
*,
|
||||
username: str = MISSING,
|
||||
avatar_url: str = MISSING,
|
||||
avatar_url: Any = MISSING,
|
||||
tts: bool = MISSING,
|
||||
file: File = MISSING,
|
||||
files: List[File] = MISSING,
|
||||
@@ -760,7 +791,7 @@ class SyncWebhook(BaseWebhook):
|
||||
content: str = MISSING,
|
||||
*,
|
||||
username: str = MISSING,
|
||||
avatar_url: str = MISSING,
|
||||
avatar_url: Any = MISSING,
|
||||
tts: bool = MISSING,
|
||||
file: File = MISSING,
|
||||
files: List[File] = MISSING,
|
||||
@@ -776,7 +807,7 @@ class SyncWebhook(BaseWebhook):
|
||||
content: str = MISSING,
|
||||
*,
|
||||
username: str = MISSING,
|
||||
avatar_url: str = MISSING,
|
||||
avatar_url: Any = MISSING,
|
||||
tts: bool = False,
|
||||
file: File = MISSING,
|
||||
files: List[File] = MISSING,
|
||||
@@ -808,9 +839,10 @@ class SyncWebhook(BaseWebhook):
|
||||
username: :class:`str`
|
||||
The username to send with this message. If no username is provided
|
||||
then the default username for the webhook is used.
|
||||
avatar_url: Union[:class:`str`, :class:`Asset`]
|
||||
avatar_url: :class:`str`
|
||||
The avatar URL to send with this message. If no avatar URL is provided
|
||||
then the default avatar for the webhook is used.
|
||||
then the default avatar for the webhook is used. If this is not a
|
||||
string then it is explicitly cast using ``str``.
|
||||
tts: :class:`bool`
|
||||
Indicates if the message should be sent using text-to-speech.
|
||||
file: :class:`File`
|
||||
@@ -873,7 +905,7 @@ class SyncWebhook(BaseWebhook):
|
||||
allowed_mentions=allowed_mentions,
|
||||
previous_allowed_mentions=previous_mentions,
|
||||
)
|
||||
adapter: WebhookAdapter = _context.adapter
|
||||
adapter: WebhookAdapter = _get_webhook_adapter()
|
||||
thread_id: Optional[int] = None
|
||||
if thread is not MISSING:
|
||||
thread_id = thread.id
|
||||
@@ -891,7 +923,7 @@ class SyncWebhook(BaseWebhook):
|
||||
if wait:
|
||||
return self._create_message(data)
|
||||
|
||||
def fetch_message(self, id: int) -> SyncWebhookMessage:
|
||||
def fetch_message(self, id: int, /) -> SyncWebhookMessage:
|
||||
"""Retrieves a single :class:`~discord.SyncWebhookMessage` owned by this webhook.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
@@ -921,7 +953,7 @@ class SyncWebhook(BaseWebhook):
|
||||
if self.token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
|
||||
adapter: WebhookAdapter = _context.adapter
|
||||
adapter: WebhookAdapter = _get_webhook_adapter()
|
||||
data = adapter.get_webhook_message(
|
||||
self.id,
|
||||
self.token,
|
||||
@@ -940,7 +972,7 @@ class SyncWebhook(BaseWebhook):
|
||||
file: File = MISSING,
|
||||
files: List[File] = MISSING,
|
||||
allowed_mentions: Optional[AllowedMentions] = None,
|
||||
):
|
||||
) -> SyncWebhookMessage:
|
||||
"""Edits a message owned by this webhook.
|
||||
|
||||
This is a lower level interface to :meth:`WebhookMessage.edit` in case
|
||||
@@ -995,8 +1027,8 @@ class SyncWebhook(BaseWebhook):
|
||||
allowed_mentions=allowed_mentions,
|
||||
previous_allowed_mentions=previous_mentions,
|
||||
)
|
||||
adapter: WebhookAdapter = _context.adapter
|
||||
adapter.edit_webhook_message(
|
||||
adapter: WebhookAdapter = _get_webhook_adapter()
|
||||
data = adapter.edit_webhook_message(
|
||||
self.id,
|
||||
self.token,
|
||||
message_id,
|
||||
@@ -1005,8 +1037,9 @@ class SyncWebhook(BaseWebhook):
|
||||
multipart=params.multipart,
|
||||
files=params.files,
|
||||
)
|
||||
return self._create_message(data)
|
||||
|
||||
def delete_message(self, message_id: int):
|
||||
def delete_message(self, message_id: int, /) -> None:
|
||||
"""Deletes a message owned by this webhook.
|
||||
|
||||
This is a lower level interface to :meth:`WebhookMessage.delete` in case
|
||||
@@ -1029,7 +1062,7 @@ class SyncWebhook(BaseWebhook):
|
||||
if self.token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
|
||||
adapter: WebhookAdapter = _context.adapter
|
||||
adapter: WebhookAdapter = _get_webhook_adapter()
|
||||
adapter.delete_webhook_message(
|
||||
self.id,
|
||||
self.token,
|
||||
|
||||
Reference in New Issue
Block a user