Merge branch '2.0' into pr7422
This commit is contained in:
@@ -60,13 +60,15 @@ from .interactions import *
|
||||
from .components import *
|
||||
from .threads import *
|
||||
|
||||
class VersionInfo(NamedTuple):
|
||||
major: int
|
||||
minor: int
|
||||
micro: int
|
||||
releaselevel: Literal["alpha", "beta", "candidate", "final"]
|
||||
serial: int
|
||||
|
||||
version_info = VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=0)
|
||||
class VersionInfo(NamedTuple):
|
||||
major: int
|
||||
minor: int
|
||||
micro: int
|
||||
releaselevel: Literal["alpha", "beta", "candidate", "final"]
|
||||
serial: int
|
||||
|
||||
|
||||
version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=0)
|
||||
|
||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||
|
||||
@@ -51,7 +51,7 @@ def core(parser, args):
|
||||
if args.version:
|
||||
show_version()
|
||||
|
||||
bot_template = """#!/usr/bin/env python3
|
||||
_bot_template = """#!/usr/bin/env python3
|
||||
|
||||
from discord.ext import commands
|
||||
import discord
|
||||
@@ -77,7 +77,7 @@ bot = Bot()
|
||||
bot.run(config.token)
|
||||
"""
|
||||
|
||||
gitignore_template = """# Byte-compiled / optimized / DLL files
|
||||
_gitignore_template = """# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
@@ -107,7 +107,7 @@ var/
|
||||
config.py
|
||||
"""
|
||||
|
||||
cog_template = '''from discord.ext import commands
|
||||
_cog_template = '''from discord.ext import commands
|
||||
import discord
|
||||
|
||||
class {name}(commands.Cog{attrs}):
|
||||
@@ -120,7 +120,7 @@ def setup(bot):
|
||||
bot.add_cog({name}(bot))
|
||||
'''
|
||||
|
||||
cog_extras = '''
|
||||
_cog_extras = '''
|
||||
def cog_unload(self):
|
||||
# clean up logic goes here
|
||||
pass
|
||||
@@ -170,7 +170,7 @@ _base_table = {
|
||||
# NUL (0) and 1-31 are disallowed
|
||||
_base_table.update((chr(i), None) for i in range(32))
|
||||
|
||||
translation_table = str.maketrans(_base_table)
|
||||
_translation_table = str.maketrans(_base_table)
|
||||
|
||||
def to_path(parser, name, *, replace_spaces=False):
|
||||
if isinstance(name, Path):
|
||||
@@ -182,7 +182,7 @@ def to_path(parser, name, *, replace_spaces=False):
|
||||
if len(name) <= 4 and name.upper() in forbidden:
|
||||
parser.error('invalid directory name given, use a different one')
|
||||
|
||||
name = name.translate(translation_table)
|
||||
name = name.translate(_translation_table)
|
||||
if replace_spaces:
|
||||
name = name.replace(' ', '-')
|
||||
return Path(name)
|
||||
@@ -215,14 +215,14 @@ def newbot(parser, args):
|
||||
try:
|
||||
with open(str(new_directory / 'bot.py'), 'w', encoding='utf-8') as fp:
|
||||
base = 'Bot' if not args.sharded else 'AutoShardedBot'
|
||||
fp.write(bot_template.format(base=base, prefix=args.prefix))
|
||||
fp.write(_bot_template.format(base=base, prefix=args.prefix))
|
||||
except OSError as exc:
|
||||
parser.error(f'could not create bot file ({exc})')
|
||||
|
||||
if not args.no_git:
|
||||
try:
|
||||
with open(str(new_directory / '.gitignore'), 'w', encoding='utf-8') as fp:
|
||||
fp.write(gitignore_template)
|
||||
fp.write(_gitignore_template)
|
||||
except OSError as exc:
|
||||
print(f'warning: could not create .gitignore file ({exc})')
|
||||
|
||||
@@ -240,7 +240,7 @@ def newcog(parser, args):
|
||||
try:
|
||||
with open(str(directory), 'w', encoding='utf-8') as fp:
|
||||
attrs = ''
|
||||
extra = cog_extras if args.full else ''
|
||||
extra = _cog_extras if args.full else ''
|
||||
if args.class_name:
|
||||
name = args.class_name
|
||||
else:
|
||||
@@ -255,7 +255,7 @@ def newcog(parser, args):
|
||||
attrs += f', name="{args.display_name}"'
|
||||
if args.hide_commands:
|
||||
attrs += ', command_attrs=dict(hidden=True)'
|
||||
fp.write(cog_template.format(name=name, extra=extra, attrs=attrs))
|
||||
fp.write(_cog_template.format(name=name, extra=extra, attrs=attrs))
|
||||
except OSError as exc:
|
||||
parser.error(f'could not create cog file ({exc})')
|
||||
else:
|
||||
|
||||
@@ -84,6 +84,7 @@ if TYPE_CHECKING:
|
||||
from .ui.view import View
|
||||
from .types.channel import (
|
||||
PermissionOverwrite as PermissionOverwritePayload,
|
||||
Channel as ChannelPayload,
|
||||
GuildChannel as GuildChannelPayload,
|
||||
OverwriteType,
|
||||
)
|
||||
@@ -122,11 +123,6 @@ class Snowflake(Protocol):
|
||||
__slots__ = ()
|
||||
id: int
|
||||
|
||||
@property
|
||||
def created_at(self) -> datetime:
|
||||
""":class:`datetime.datetime`: Returns the model's creation time as an aware datetime in UTC."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class User(Snowflake, Protocol):
|
||||
@@ -307,11 +303,8 @@ class GuildChannel:
|
||||
payload.append(d)
|
||||
|
||||
await http.bulk_channel_update(self.guild.id, payload, reason=reason)
|
||||
self.position = position
|
||||
if parent_id is not _undefined:
|
||||
self.category_id = int(parent_id) if parent_id else None
|
||||
|
||||
async def _edit(self, options: Dict[str, Any], reason: Optional[str]):
|
||||
async def _edit(self, options: Dict[str, Any], reason: Optional[str]) -> Optional[ChannelPayload]:
|
||||
try:
|
||||
parent = options.pop('category')
|
||||
except KeyError:
|
||||
@@ -390,8 +383,7 @@ class GuildChannel:
|
||||
options['type'] = ch_type.value
|
||||
|
||||
if options:
|
||||
data = await self._state.http.edit_channel(self.id, reason=reason, **options)
|
||||
self._update(self.guild, data)
|
||||
return await self._state.http.edit_channel(self.id, reason=reason, **options)
|
||||
|
||||
def _fill_overwrites(self, data: GuildChannelPayload) -> None:
|
||||
self._overwrites = []
|
||||
@@ -1641,6 +1633,8 @@ class Connectable(Protocol):
|
||||
Connects to voice and creates a :class:`VoiceClient` to establish
|
||||
your connection to the voice server.
|
||||
|
||||
This requires :attr:`Intents.voice_states`.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
timeout: :class:`float`
|
||||
|
||||
@@ -830,10 +830,12 @@ def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
|
||||
except KeyError:
|
||||
return Activity(**data)
|
||||
else:
|
||||
return CustomActivity(name=name, **data)
|
||||
# we removed the name key from data already
|
||||
return CustomActivity(name=name, **data) # type: ignore
|
||||
elif game_type is ActivityType.streaming:
|
||||
if 'url' in data:
|
||||
return Streaming(**data)
|
||||
# the url won't be None here
|
||||
return Streaming(**data) # type: ignore
|
||||
return Activity(**data)
|
||||
elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data:
|
||||
return Spotify(**data)
|
||||
|
||||
@@ -177,6 +177,17 @@ class Asset(AssetMixin):
|
||||
animated=animated,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset:
|
||||
animated = avatar.startswith('a_')
|
||||
format = 'gif' if animated else 'png'
|
||||
return cls(
|
||||
state,
|
||||
url=f"{cls.BASE}/guilds/{guild_id}/users/{member_id}/avatars/{avatar}.{format}?size=1024",
|
||||
key=avatar,
|
||||
animated=animated,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
|
||||
return cls(
|
||||
@@ -302,10 +313,11 @@ class Asset(AssetMixin):
|
||||
if self._animated:
|
||||
if format not in VALID_ASSET_FORMATS:
|
||||
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}')
|
||||
else:
|
||||
url = url.with_path(f'{path}.{format}')
|
||||
elif static_format is MISSING:
|
||||
if format not in VALID_STATIC_FORMATS:
|
||||
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
|
||||
url = url.with_path(f'{path}.{format}')
|
||||
url = url.with_path(f'{path}.{format}')
|
||||
|
||||
if static_format is not MISSING and not self._animated:
|
||||
if static_format not in VALID_STATIC_FORMATS:
|
||||
|
||||
@@ -330,6 +330,10 @@ class AuditLogEntry(Hashable):
|
||||
|
||||
Returns the entry's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the entry's ID.
|
||||
|
||||
.. versionchanged:: 1.7
|
||||
Audit log entries are now comparable and hashable.
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ if TYPE_CHECKING:
|
||||
StoreChannel as StoreChannelPayload,
|
||||
GroupDMChannel as GroupChannelPayload,
|
||||
)
|
||||
from .types.snowflake import SnowflakeList
|
||||
|
||||
|
||||
async def _single_delete_strategy(messages: Iterable[Message]):
|
||||
@@ -114,6 +115,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
|
||||
Returns the channel's name.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the channel's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@@ -229,7 +234,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return [thread for thread in self.guild.threads if thread.parent_id == self.id]
|
||||
return [thread for thread in self.guild._threads.values() if thread.parent_id == self.id]
|
||||
|
||||
def is_nsfw(self) -> bool:
|
||||
""":class:`bool`: Checks if the channel is NSFW."""
|
||||
@@ -275,11 +280,11 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
default_auto_archive_duration: ThreadArchiveDuration = ...,
|
||||
type: ChannelType = ...,
|
||||
overwrites: Mapping[Union[Role, Member, Snowflake], PermissionOverwrite] = ...,
|
||||
) -> None:
|
||||
) -> Optional[TextChannel]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def edit(self) -> None:
|
||||
async def edit(self) -> Optional[TextChannel]:
|
||||
...
|
||||
|
||||
async def edit(self, *, reason=None, **options):
|
||||
@@ -296,6 +301,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
.. versionchanged:: 1.4
|
||||
The ``type`` keyword-only parameter was added.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Edits are no longer in-place, the newly edited channel is returned instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: :class:`str`
|
||||
@@ -337,8 +345,18 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
You do not have permissions to edit the channel.
|
||||
HTTPException
|
||||
Editing the channel failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`.TextChannel`]
|
||||
The newly edited text channel. If the edit was only positional
|
||||
then ``None`` is returned instead.
|
||||
"""
|
||||
await self._edit(options, reason=reason)
|
||||
|
||||
payload = await self._edit(options, reason=reason)
|
||||
if payload is not None:
|
||||
# the payload will always be the proper channel payload
|
||||
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
|
||||
|
||||
@utils.copy_doc(discord.abc.GuildChannel.clone)
|
||||
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> TextChannel:
|
||||
@@ -392,7 +410,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
if len(messages) > 100:
|
||||
raise ClientException('Can only bulk delete messages up to 100 messages')
|
||||
|
||||
message_ids: List[int] = [m.id for m in messages]
|
||||
message_ids: SnowflakeList = [m.id for m in messages]
|
||||
await self._state.http.delete_messages(self.id, message_ids)
|
||||
|
||||
async def purge(
|
||||
@@ -670,15 +688,8 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
|
||||
Creates a thread in this text channel.
|
||||
|
||||
If no starter message is passed with the ``message`` parameter then
|
||||
you must have :attr:`~discord.Permissions.send_messages` and
|
||||
:attr:`~discord.Permissions.use_private_threads` in order to create the thread
|
||||
if the ``type`` parameter is :attr:`~discord.ChannelType.private_thread`.
|
||||
Otherwise :attr:`~discord.Permissions.use_threads` is needed.
|
||||
|
||||
If a starter message is passed with the ``message`` parameter then
|
||||
you must have :attr:`~discord.Permissions.send_messages` and
|
||||
:attr:`~discord.Permissions.use_threads` in order to create the thread.
|
||||
To create a public thread, you must have :attr:`~discord.Permissions.create_public_threads`.
|
||||
For a private thread, :attr:`~discord.Permissions.create_private_threads` is needed instead.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
@@ -957,11 +968,11 @@ class VoiceChannel(VocalGuildChannel):
|
||||
rtc_region: Optional[VoiceRegion] = ...,
|
||||
video_quality_mode: VideoQualityMode = ...,
|
||||
reason: Optional[str] = ...,
|
||||
) -> None:
|
||||
) -> Optional[VoiceChannel]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def edit(self) -> None:
|
||||
async def edit(self) -> Optional[VoiceChannel]:
|
||||
...
|
||||
|
||||
async def edit(self, *, reason=None, **options):
|
||||
@@ -975,6 +986,9 @@ class VoiceChannel(VocalGuildChannel):
|
||||
.. versionchanged:: 1.3
|
||||
The ``overwrites`` keyword-only parameter was added.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Edits are no longer in-place, the newly edited channel is returned instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: :class:`str`
|
||||
@@ -1014,9 +1028,18 @@ class VoiceChannel(VocalGuildChannel):
|
||||
You do not have permissions to edit the channel.
|
||||
HTTPException
|
||||
Editing the channel failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`.VoiceChannel`]
|
||||
The newly edited voice channel. If the edit was only positional
|
||||
then ``None`` is returned instead.
|
||||
"""
|
||||
|
||||
await self._edit(options, reason=reason)
|
||||
payload = await self._edit(options, reason=reason)
|
||||
if payload is not None:
|
||||
# the payload will always be the proper channel payload
|
||||
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
|
||||
|
||||
|
||||
class StageChannel(VocalGuildChannel):
|
||||
@@ -1224,11 +1247,11 @@ class StageChannel(VocalGuildChannel):
|
||||
rtc_region: Optional[VoiceRegion] = ...,
|
||||
video_quality_mode: VideoQualityMode = ...,
|
||||
reason: Optional[str] = ...,
|
||||
) -> None:
|
||||
) -> Optional[StageChannel]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def edit(self) -> None:
|
||||
async def edit(self) -> Optional[StageChannel]:
|
||||
...
|
||||
|
||||
async def edit(self, *, reason=None, **options):
|
||||
@@ -1242,6 +1265,9 @@ class StageChannel(VocalGuildChannel):
|
||||
.. versionchanged:: 2.0
|
||||
The ``topic`` parameter must now be set via :attr:`create_instance`.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Edits are no longer in-place, the newly edited channel is returned instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: :class:`str`
|
||||
@@ -1275,9 +1301,18 @@ class StageChannel(VocalGuildChannel):
|
||||
You do not have permissions to edit the channel.
|
||||
HTTPException
|
||||
Editing the channel failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`.StageChannel`]
|
||||
The newly edited stage channel. If the edit was only positional
|
||||
then ``None`` is returned instead.
|
||||
"""
|
||||
|
||||
await self._edit(options, reason=reason)
|
||||
payload = await self._edit(options, reason=reason)
|
||||
if payload is not None:
|
||||
# the payload will always be the proper channel payload
|
||||
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
|
||||
|
||||
|
||||
class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
@@ -1303,6 +1338,10 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
|
||||
Returns the category's name.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the category's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@@ -1366,11 +1405,11 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
nsfw: bool = ...,
|
||||
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
|
||||
reason: Optional[str] = ...,
|
||||
) -> None:
|
||||
) -> Optional[CategoryChannel]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def edit(self) -> None:
|
||||
async def edit(self) -> Optional[CategoryChannel]:
|
||||
...
|
||||
|
||||
async def edit(self, *, reason=None, **options):
|
||||
@@ -1384,6 +1423,9 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
.. versionchanged:: 1.3
|
||||
The ``overwrites`` keyword-only parameter was added.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Edits are no longer in-place, the newly edited channel is returned instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: :class:`str`
|
||||
@@ -1406,9 +1448,18 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
You do not have permissions to edit the category.
|
||||
HTTPException
|
||||
Editing the category failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`.CategoryChannel`]
|
||||
The newly edited category channel. If the edit was only positional
|
||||
then ``None`` is returned instead.
|
||||
"""
|
||||
|
||||
await self._edit(options=options, reason=reason)
|
||||
payload = await self._edit(options, reason=reason)
|
||||
if payload is not None:
|
||||
# the payload will always be the proper channel payload
|
||||
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
|
||||
|
||||
@utils.copy_doc(discord.abc.GuildChannel.move)
|
||||
async def move(self, **kwargs):
|
||||
@@ -1513,6 +1564,10 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
|
||||
|
||||
Returns the channel's name.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the channel's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@@ -1598,11 +1653,11 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
|
||||
category: Optional[CategoryChannel],
|
||||
reason: Optional[str],
|
||||
overwrites: Mapping[Union[Role, Member], PermissionOverwrite],
|
||||
) -> None:
|
||||
) -> Optional[StoreChannel]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def edit(self) -> None:
|
||||
async def edit(self) -> Optional[StoreChannel]:
|
||||
...
|
||||
|
||||
async def edit(self, *, reason=None, **options):
|
||||
@@ -1613,6 +1668,9 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
|
||||
You must have the :attr:`~Permissions.manage_channels` permission to
|
||||
use this.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Edits are no longer in-place, the newly edited channel is returned instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: :class:`str`
|
||||
@@ -1644,8 +1702,18 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
|
||||
You do not have permissions to edit the channel.
|
||||
HTTPException
|
||||
Editing the channel failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`.StoreChannel`]
|
||||
The newly edited store channel. If the edit was only positional
|
||||
then ``None`` is returned instead.
|
||||
"""
|
||||
await self._edit(options, reason=reason)
|
||||
|
||||
payload = await self._edit(options, reason=reason)
|
||||
if payload is not None:
|
||||
# the payload will always be the proper channel payload
|
||||
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
|
||||
|
||||
|
||||
DMC = TypeVar('DMC', bound='DMChannel')
|
||||
@@ -1672,6 +1740,10 @@ class DMChannel(discord.abc.Messageable, Hashable):
|
||||
|
||||
Returns a string representation of the channel
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the channel's ID.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
recipient: Optional[:class:`User`]
|
||||
@@ -1709,6 +1781,7 @@ class DMChannel(discord.abc.Messageable, Hashable):
|
||||
self._state = state
|
||||
self.id = channel_id
|
||||
self.recipient = None
|
||||
# state.user won't be None here
|
||||
self.me = state.user # type: ignore
|
||||
return self
|
||||
|
||||
@@ -1797,6 +1870,10 @@ class GroupChannel(discord.abc.Messageable, Hashable):
|
||||
|
||||
Returns a string representation of the channel
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the channel's ID.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
recipients: List[:class:`User`]
|
||||
@@ -1943,6 +2020,10 @@ class PartialMessageable(discord.abc.Messageable, Hashable):
|
||||
|
||||
Returns the partial messageable's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the messageable's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
id: :class:`int`
|
||||
|
||||
@@ -29,7 +29,7 @@ import logging
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
|
||||
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
|
||||
|
||||
import aiohttp
|
||||
|
||||
@@ -76,7 +76,7 @@ __all__ = (
|
||||
Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
|
||||
|
||||
|
||||
log: logging.Logger = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
|
||||
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
|
||||
@@ -84,12 +84,12 @@ def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
|
||||
if not tasks:
|
||||
return
|
||||
|
||||
log.info('Cleaning up after %d tasks.', len(tasks))
|
||||
_log.info('Cleaning up after %d tasks.', len(tasks))
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
|
||||
log.info('All tasks finished cancelling.')
|
||||
_log.info('All tasks finished cancelling.')
|
||||
|
||||
for task in tasks:
|
||||
if task.cancelled():
|
||||
@@ -106,7 +106,7 @@ def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||
_cancel_tasks(loop)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
finally:
|
||||
log.info('Closing the event loop.')
|
||||
_log.info('Closing the event loop.')
|
||||
loop.close()
|
||||
|
||||
class Client:
|
||||
@@ -142,7 +142,6 @@ class Client:
|
||||
intents: :class:`Intents`
|
||||
The intents that you want to enable for the session. This is a way of
|
||||
disabling and enabling certain gateway events from triggering and being sent.
|
||||
If not given, defaults to a regularly constructed :class:`Intents` class.
|
||||
|
||||
.. versionadded:: 1.5
|
||||
member_cache_flags: :class:`MemberCacheFlags`
|
||||
@@ -187,7 +186,7 @@ class Client:
|
||||
enable_debug_events: :class:`bool`
|
||||
Whether to enable events that are useful only for debugging gateway related information.
|
||||
|
||||
Right now this involves :func:`on_socket_raw_receive` and :func`:`on_socket_raw_send`. If
|
||||
Right now this involves :func:`on_socket_raw_receive` and :func:`on_socket_raw_send`. If
|
||||
this is ``False`` then those events will not be dispatched (due to performance considerations).
|
||||
To enable these events, this must be set to ``True``. Defaults to ``False``.
|
||||
|
||||
@@ -203,9 +202,13 @@ class Client:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
intents: Intents,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
**options: Any,
|
||||
):
|
||||
options["intents"] = intents
|
||||
|
||||
# self.ws is set in the connect method
|
||||
self.ws: DiscordWebSocket = None # type: ignore
|
||||
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
|
||||
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
|
||||
@@ -236,7 +239,7 @@ class Client:
|
||||
|
||||
if VoiceClient.warn_nacl:
|
||||
VoiceClient.warn_nacl = False
|
||||
log.warning("PyNaCl is not installed, voice will NOT be supported")
|
||||
_log.warning("PyNaCl is not installed, voice will NOT be supported")
|
||||
|
||||
# internals
|
||||
|
||||
@@ -362,7 +365,7 @@ class Client:
|
||||
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
|
||||
|
||||
def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
|
||||
log.debug('Dispatching event %s', event)
|
||||
_log.debug('Dispatching event %s', event)
|
||||
method = 'on_' + event
|
||||
|
||||
listeners = self._listeners.get(event)
|
||||
@@ -467,7 +470,7 @@ class Client:
|
||||
passing status code.
|
||||
"""
|
||||
|
||||
log.info('logging in using static token')
|
||||
_log.info('logging in using static token')
|
||||
|
||||
data = await self.http.static_login(token.strip())
|
||||
self._connection.user = ClientUser(state=self._connection, data=data)
|
||||
@@ -510,7 +513,7 @@ class Client:
|
||||
while True:
|
||||
await self.ws.poll_event()
|
||||
except ReconnectWebSocket as e:
|
||||
log.info('Got a request to %s the websocket.', e.op)
|
||||
_log.info('Got a request to %s the websocket.', e.op)
|
||||
self.dispatch('disconnect')
|
||||
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
|
||||
continue
|
||||
@@ -549,7 +552,7 @@ class Client:
|
||||
raise
|
||||
|
||||
retry = backoff.delay()
|
||||
log.exception("Attempting a reconnect in %.2fs", retry)
|
||||
_log.exception("Attempting a reconnect in %.2fs", retry)
|
||||
await asyncio.sleep(retry)
|
||||
# Always try to RESUME the connection
|
||||
# If the connection is not RESUME-able then the gateway will invalidate the session.
|
||||
@@ -651,10 +654,10 @@ class Client:
|
||||
try:
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
log.info('Received signal to terminate bot and event loop.')
|
||||
_log.info('Received signal to terminate bot and event loop.')
|
||||
finally:
|
||||
future.remove_done_callback(stop_loop_on_completion)
|
||||
log.info('Cleaning up tasks.')
|
||||
_log.info('Cleaning up tasks.')
|
||||
_cleanup_loop(loop)
|
||||
|
||||
if not future.cancelled():
|
||||
@@ -682,9 +685,30 @@ class Client:
|
||||
if value is None:
|
||||
self._connection._activity = None
|
||||
elif isinstance(value, BaseActivity):
|
||||
self._connection._activity = value.to_dict()
|
||||
# ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any]
|
||||
self._connection._activity = value.to_dict() # type: ignore
|
||||
else:
|
||||
raise TypeError('activity must derive from BaseActivity.')
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
""":class:`.Status`:
|
||||
The status being used upon logging on to Discord.
|
||||
|
||||
.. versionadded: 2.0
|
||||
"""
|
||||
if self._connection._status in set(state.value for state in Status):
|
||||
return Status(self._connection._status)
|
||||
return Status.online
|
||||
|
||||
@status.setter
|
||||
def status(self, value):
|
||||
if value is Status.offline:
|
||||
self._connection._status = 'invisible'
|
||||
elif isinstance(value, Status):
|
||||
self._connection._status = str(value)
|
||||
else:
|
||||
raise TypeError('status must derive from Status.')
|
||||
|
||||
@property
|
||||
def allowed_mentions(self) -> Optional[AllowedMentions]:
|
||||
@@ -716,8 +740,8 @@ class Client:
|
||||
"""List[:class:`~discord.User`]: Returns a list of all the users the bot can see."""
|
||||
return list(self._connection._users.values())
|
||||
|
||||
def get_channel(self, id: int) -> Optional[Union[GuildChannel, PrivateChannel]]:
|
||||
"""Returns a channel with the given ID.
|
||||
def get_channel(self, id: int, /) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]:
|
||||
"""Returns a channel or thread with the given ID.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
@@ -726,7 +750,7 @@ class Client:
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`]]
|
||||
Optional[Union[:class:`.abc.GuildChannel`, :class:`.Thread`, :class:`.abc.PrivateChannel`]]
|
||||
The returned channel or ``None`` if not found.
|
||||
"""
|
||||
return self._connection.get_channel(id)
|
||||
@@ -743,7 +767,7 @@ class Client:
|
||||
-----------
|
||||
id: :class:`int`
|
||||
The channel ID to create a partial messageable for.
|
||||
type: Optional[:class:`ChannelType`]
|
||||
type: Optional[:class:`.ChannelType`]
|
||||
The underlying channel type for the partial messageable.
|
||||
|
||||
Returns
|
||||
@@ -753,7 +777,7 @@ class Client:
|
||||
"""
|
||||
return PartialMessageable(state=self._connection, id=id, type=type)
|
||||
|
||||
def get_stage_instance(self, id) -> Optional[StageInstance]:
|
||||
def get_stage_instance(self, id: int, /) -> Optional[StageInstance]:
|
||||
"""Returns a stage instance with the given stage channel ID.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
@@ -775,7 +799,7 @@ class Client:
|
||||
if isinstance(channel, StageChannel):
|
||||
return channel.instance
|
||||
|
||||
def get_guild(self, id) -> Optional[Guild]:
|
||||
def get_guild(self, id: int, /) -> Optional[Guild]:
|
||||
"""Returns a guild with the given ID.
|
||||
|
||||
Parameters
|
||||
@@ -790,7 +814,7 @@ class Client:
|
||||
"""
|
||||
return self._connection._get_guild(id)
|
||||
|
||||
def get_user(self, id) -> Optional[User]:
|
||||
def get_user(self, id: int, /) -> Optional[User]:
|
||||
"""Returns a user with the given ID.
|
||||
|
||||
Parameters
|
||||
@@ -805,7 +829,7 @@ class Client:
|
||||
"""
|
||||
return self._connection.get_user(id)
|
||||
|
||||
def get_emoji(self, id) -> Optional[Emoji]:
|
||||
def get_emoji(self, id: int, /) -> Optional[Emoji]:
|
||||
"""Returns an emoji with the given ID.
|
||||
|
||||
Parameters
|
||||
@@ -820,7 +844,7 @@ class Client:
|
||||
"""
|
||||
return self._connection.get_emoji(id)
|
||||
|
||||
def get_sticker(self, id: int) -> Optional[GuildSticker]:
|
||||
def get_sticker(self, id: int, /) -> Optional[GuildSticker]:
|
||||
"""Returns a guild sticker with the given ID.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
@@ -1019,7 +1043,7 @@ class Client:
|
||||
raise TypeError('event registered must be a coroutine function')
|
||||
|
||||
setattr(self, coro.__name__, coro)
|
||||
log.debug('%s has successfully been registered as an event', coro.__name__)
|
||||
_log.debug('%s has successfully been registered as an event', coro.__name__)
|
||||
return coro
|
||||
|
||||
async def change_presence(
|
||||
@@ -1169,7 +1193,7 @@ class Client:
|
||||
data = await self.http.get_template(code)
|
||||
return Template(data=data, state=self._connection) # type: ignore
|
||||
|
||||
async def fetch_guild(self, guild_id: int) -> Guild:
|
||||
async def fetch_guild(self, guild_id: int, /) -> Guild:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a :class:`.Guild` from an ID.
|
||||
@@ -1258,7 +1282,7 @@ class Client:
|
||||
data = await self.http.create_guild(name, region_value, icon_base64)
|
||||
return Guild(data=data, state=self._connection)
|
||||
|
||||
async def fetch_stage_instance(self, channel_id: int) -> StageInstance:
|
||||
async def fetch_stage_instance(self, channel_id: int, /) -> StageInstance:
|
||||
"""|coro|
|
||||
|
||||
Gets a :class:`.StageInstance` for a stage channel id.
|
||||
@@ -1358,7 +1382,7 @@ class Client:
|
||||
|
||||
# Miscellaneous stuff
|
||||
|
||||
async def fetch_widget(self, guild_id: int) -> Widget:
|
||||
async def fetch_widget(self, guild_id: int, /) -> Widget:
|
||||
"""|coro|
|
||||
|
||||
Gets a :class:`.Widget` from a guild ID.
|
||||
@@ -1408,7 +1432,7 @@ class Client:
|
||||
data['rpc_origins'] = None
|
||||
return AppInfo(self._connection, data)
|
||||
|
||||
async def fetch_user(self, user_id: int) -> User:
|
||||
async def fetch_user(self, user_id: int, /) -> User:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a :class:`~discord.User` based on their ID.
|
||||
@@ -1439,7 +1463,7 @@ class Client:
|
||||
data = await self.http.get_user(user_id)
|
||||
return User(state=self._connection, data=data)
|
||||
|
||||
async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel, Thread]:
|
||||
async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, PrivateChannel, Thread]:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID.
|
||||
@@ -1473,15 +1497,18 @@ class Client:
|
||||
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))
|
||||
|
||||
if ch_type in (ChannelType.group, ChannelType.private):
|
||||
channel = factory(me=self.user, data=data, state=self._connection)
|
||||
# the factory will be a DMChannel or GroupChannel here
|
||||
channel = factory(me=self.user, data=data, state=self._connection) # type: ignore
|
||||
else:
|
||||
guild_id = int(data['guild_id'])
|
||||
# the factory can't be a DMChannel or GroupChannel here
|
||||
guild_id = int(data['guild_id']) # type: ignore
|
||||
guild = self.get_guild(guild_id) or Object(id=guild_id)
|
||||
channel = factory(guild=guild, state=self._connection, data=data)
|
||||
# GuildChannels expect a Guild, we may be passing an Object
|
||||
channel = factory(guild=guild, state=self._connection, data=data) # type: ignore
|
||||
|
||||
return channel
|
||||
|
||||
async def fetch_webhook(self, webhook_id: int) -> Webhook:
|
||||
async def fetch_webhook(self, webhook_id: int, /) -> Webhook:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a :class:`.Webhook` with the specified ID.
|
||||
@@ -1503,7 +1530,7 @@ class Client:
|
||||
data = await self.http.get_webhook(webhook_id)
|
||||
return Webhook.from_state(data, state=self._connection)
|
||||
|
||||
async def fetch_sticker(self, sticker_id: int) -> Union[StandardSticker, GuildSticker]:
|
||||
async def fetch_sticker(self, sticker_id: int, /) -> Union[StandardSticker, GuildSticker]:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a :class:`.Sticker` with the specified ID.
|
||||
|
||||
@@ -78,7 +78,7 @@ class Colour:
|
||||
|
||||
__slots__ = ('value',)
|
||||
|
||||
def __init__(self, value):
|
||||
def __init__(self, value: int):
|
||||
if not isinstance(value, int):
|
||||
raise TypeError(f'Expected int parameter, received {value.__class__.__name__} instead.')
|
||||
|
||||
@@ -171,6 +171,14 @@ class Colour:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x11806a``."""
|
||||
return cls(0x11806a)
|
||||
|
||||
@classmethod
|
||||
def brand_green(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x57F287``.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return cls(0x57F287)
|
||||
|
||||
@classmethod
|
||||
def green(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``."""
|
||||
@@ -231,6 +239,14 @@ class Colour:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0xa84300``."""
|
||||
return cls(0xa84300)
|
||||
|
||||
@classmethod
|
||||
def brand_red(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0xED4245``.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return cls(0xED4245)
|
||||
|
||||
@classmethod
|
||||
def red(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``."""
|
||||
@@ -308,6 +324,15 @@ class Colour:
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return cls(0xFEE75C)
|
||||
|
||||
@classmethod
|
||||
def dark_blurple(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x4E5D94``.
|
||||
This is the original Dark Blurple branding.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return cls(0x4E5D94)
|
||||
|
||||
|
||||
Color = Colour
|
||||
|
||||
@@ -277,14 +277,14 @@ class SelectOption:
|
||||
-----------
|
||||
label: :class:`str`
|
||||
The label of the option. This is displayed to users.
|
||||
Can only be up to 25 characters.
|
||||
Can only be up to 100 characters.
|
||||
value: :class:`str`
|
||||
The value of the option. This is not displayed to users.
|
||||
If not provided when constructed then it defaults to the
|
||||
label. Can only be up to 100 characters.
|
||||
description: Optional[:class:`str`]
|
||||
An additional description of the option, if any.
|
||||
Can only be up to 50 characters.
|
||||
Can only be up to 100 characters.
|
||||
emoji: Optional[Union[:class:`str`, :class:`Emoji`, :class:`PartialEmoji`]]
|
||||
The emoji of the option, if available.
|
||||
default: :class:`bool`
|
||||
|
||||
@@ -22,13 +22,23 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, TypeVar, Optional, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .abc import Messageable
|
||||
|
||||
from types import TracebackType
|
||||
|
||||
TypingT = TypeVar('TypingT', bound='Typing')
|
||||
|
||||
__all__ = (
|
||||
'Typing',
|
||||
)
|
||||
|
||||
def _typing_done_callback(fut):
|
||||
def _typing_done_callback(fut: asyncio.Future) -> None:
|
||||
# just retrieve any exception and call it a day
|
||||
try:
|
||||
fut.exception()
|
||||
@@ -36,11 +46,11 @@ def _typing_done_callback(fut):
|
||||
pass
|
||||
|
||||
class Typing:
|
||||
def __init__(self, messageable):
|
||||
self.loop = messageable._state.loop
|
||||
self.messageable = messageable
|
||||
def __init__(self, messageable: Messageable) -> None:
|
||||
self.loop: asyncio.AbstractEventLoop = messageable._state.loop
|
||||
self.messageable: Messageable = messageable
|
||||
|
||||
async def do_typing(self):
|
||||
async def do_typing(self) -> None:
|
||||
try:
|
||||
channel = self._channel
|
||||
except AttributeError:
|
||||
@@ -52,18 +62,26 @@ class Typing:
|
||||
await typing(channel.id)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
def __enter__(self):
|
||||
self.task = asyncio.ensure_future(self.do_typing(), loop=self.loop)
|
||||
def __enter__(self: TypingT) -> TypingT:
|
||||
self.task: asyncio.Task = self.loop.create_task(self.do_typing())
|
||||
self.task.add_done_callback(_typing_done_callback)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
def __exit__(self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.task.cancel()
|
||||
|
||||
async def __aenter__(self):
|
||||
async def __aenter__(self: TypingT) -> TypingT:
|
||||
self._channel = channel = await self.messageable._get_channel()
|
||||
await channel._state.http.send_typing(channel.id)
|
||||
return self.__enter__()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
async def __aexit__(self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.task.cancel()
|
||||
|
||||
@@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import Any, Dict, Final, List, Protocol, TYPE_CHECKING, Type, TypeVar, Union
|
||||
from typing import Any, Dict, Final, List, Mapping, Protocol, TYPE_CHECKING, Type, TypeVar, Union
|
||||
|
||||
from . import utils
|
||||
from .colour import Colour
|
||||
@@ -72,30 +72,36 @@ if TYPE_CHECKING:
|
||||
T = TypeVar('T')
|
||||
MaybeEmpty = Union[T, _EmptyEmbed]
|
||||
|
||||
|
||||
class _EmbedFooterProxy(Protocol):
|
||||
text: MaybeEmpty[str]
|
||||
icon_url: MaybeEmpty[str]
|
||||
|
||||
|
||||
class _EmbedFieldProxy(Protocol):
|
||||
name: MaybeEmpty[str]
|
||||
value: MaybeEmpty[str]
|
||||
inline: bool
|
||||
|
||||
|
||||
class _EmbedMediaProxy(Protocol):
|
||||
url: MaybeEmpty[str]
|
||||
proxy_url: MaybeEmpty[str]
|
||||
height: MaybeEmpty[int]
|
||||
width: MaybeEmpty[int]
|
||||
|
||||
|
||||
class _EmbedVideoProxy(Protocol):
|
||||
url: MaybeEmpty[str]
|
||||
height: MaybeEmpty[int]
|
||||
width: MaybeEmpty[int]
|
||||
|
||||
|
||||
class _EmbedProviderProxy(Protocol):
|
||||
name: MaybeEmpty[str]
|
||||
url: MaybeEmpty[str]
|
||||
|
||||
|
||||
class _EmbedAuthorProxy(Protocol):
|
||||
name: MaybeEmpty[str]
|
||||
url: MaybeEmpty[str]
|
||||
@@ -175,15 +181,15 @@ class Embed:
|
||||
Empty: Final = EmptyEmbed
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
title: MaybeEmpty[Any] = EmptyEmbed,
|
||||
type: EmbedType = 'rich',
|
||||
url: MaybeEmpty[Any] = EmptyEmbed,
|
||||
description: MaybeEmpty[Any] = EmptyEmbed,
|
||||
timestamp: datetime.datetime = None,
|
||||
self,
|
||||
*,
|
||||
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
title: MaybeEmpty[Any] = EmptyEmbed,
|
||||
type: EmbedType = 'rich',
|
||||
url: MaybeEmpty[Any] = EmptyEmbed,
|
||||
description: MaybeEmpty[Any] = EmptyEmbed,
|
||||
timestamp: datetime.datetime = None,
|
||||
):
|
||||
|
||||
self.colour = colour if colour is not EmptyEmbed else color
|
||||
@@ -205,7 +211,7 @@ class Embed:
|
||||
self.timestamp = timestamp
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[E], data: EmbedData) -> E:
|
||||
def from_dict(cls: Type[E], data: Mapping[str, Any]) -> E:
|
||||
"""Converts a :class:`dict` to a :class:`Embed` provided it is in the
|
||||
format that Discord expects it to be in.
|
||||
|
||||
@@ -366,7 +372,7 @@ class Embed:
|
||||
self._footer['icon_url'] = str(icon_url)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def remove_footer(self: E) -> E:
|
||||
"""Clears embed's footer information.
|
||||
|
||||
@@ -381,7 +387,7 @@ class Embed:
|
||||
pass
|
||||
|
||||
return self
|
||||
|
||||
|
||||
@property
|
||||
def image(self) -> _EmbedMediaProxy:
|
||||
"""Returns an ``EmbedProxy`` denoting the image contents.
|
||||
@@ -397,6 +403,19 @@ class Embed:
|
||||
"""
|
||||
return EmbedProxy(getattr(self, '_image', {})) # type: ignore
|
||||
|
||||
@image.setter
|
||||
def image(self: E, *, url: Any):
|
||||
self._image = {
|
||||
'url': str(url),
|
||||
}
|
||||
|
||||
@image.deleter
|
||||
def image(self: E):
|
||||
try:
|
||||
del self._image
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def set_image(self: E, *, url: MaybeEmpty[Any]) -> E:
|
||||
"""Sets the image for the embed content.
|
||||
|
||||
@@ -413,14 +432,9 @@ class Embed:
|
||||
"""
|
||||
|
||||
if url is EmptyEmbed:
|
||||
try:
|
||||
del self._image
|
||||
except AttributeError:
|
||||
pass
|
||||
del self.image
|
||||
else:
|
||||
self._image = {
|
||||
'url': str(url),
|
||||
}
|
||||
self.image = url
|
||||
|
||||
return self
|
||||
|
||||
@@ -439,7 +453,25 @@ class Embed:
|
||||
"""
|
||||
return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore
|
||||
|
||||
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]) -> E:
|
||||
@thumbnail.setter
|
||||
def thumbnail(self: E, *, url: Any):
|
||||
"""Sets the thumbnail for the embed content.
|
||||
"""
|
||||
|
||||
self._thumbnail = {
|
||||
'url': str(url),
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
@thumbnail.deleter
|
||||
def thumbnail(self):
|
||||
try:
|
||||
del self.thumbnail
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]):
|
||||
"""Sets the thumbnail for the embed content.
|
||||
|
||||
This function returns the class instance to allow for fluent-style
|
||||
@@ -453,16 +485,10 @@ class Embed:
|
||||
url: :class:`str`
|
||||
The source URL for the thumbnail. Only HTTP(S) is supported.
|
||||
"""
|
||||
|
||||
if url is EmptyEmbed:
|
||||
try:
|
||||
del self._thumbnail
|
||||
except AttributeError:
|
||||
pass
|
||||
del self.thumbnail
|
||||
else:
|
||||
self._thumbnail = {
|
||||
'url': str(url),
|
||||
}
|
||||
self.thumbnail = url
|
||||
|
||||
return self
|
||||
|
||||
|
||||
@@ -72,6 +72,10 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
|
||||
Returns the emoji rendered for discord.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the emoji ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@@ -137,6 +141,9 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
return f'<a:{self.name}:{self.id}>'
|
||||
return f'<:{self.name}:{self.id}>'
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>'
|
||||
|
||||
@@ -212,7 +219,7 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
|
||||
await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason)
|
||||
|
||||
async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> None:
|
||||
async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> Emoji:
|
||||
r"""|coro|
|
||||
|
||||
Edits the custom emoji.
|
||||
@@ -220,6 +227,9 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
You must have :attr:`~Permissions.manage_emojis` permission to
|
||||
do this.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The newly updated emoji is returned.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@@ -235,6 +245,11 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
You are not allowed to edit emojis.
|
||||
HTTPException
|
||||
An error occurred editing the emoji.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Emoji`
|
||||
The newly updated emoji.
|
||||
"""
|
||||
|
||||
payload = {}
|
||||
@@ -243,4 +258,5 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
if roles is not MISSING:
|
||||
payload['roles'] = [role.id for role in roles]
|
||||
|
||||
await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason)
|
||||
data = await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason)
|
||||
return Emoji(guild=self.guild, data=data, state=self._state)
|
||||
|
||||
@@ -58,13 +58,17 @@ __all__ = (
|
||||
)
|
||||
|
||||
|
||||
def _create_value_cls(name):
|
||||
def _create_value_cls(name, comparable):
|
||||
cls = namedtuple('_EnumValue_' + name, 'name value')
|
||||
cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>'
|
||||
cls.__str__ = lambda self: f'{name}.{self.name}'
|
||||
if comparable:
|
||||
cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value
|
||||
cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value
|
||||
cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value
|
||||
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value
|
||||
return cls
|
||||
|
||||
|
||||
def _is_descriptor(obj):
|
||||
return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__')
|
||||
|
||||
@@ -76,12 +80,12 @@ class EnumMeta(type):
|
||||
_enum_member_map_: ClassVar[Dict[str, Any]]
|
||||
_enum_value_map_: ClassVar[Dict[Any, Any]]
|
||||
|
||||
def __new__(cls, name, bases, attrs):
|
||||
def __new__(cls, name, bases, attrs, *, comparable: bool = False):
|
||||
value_mapping = {}
|
||||
member_mapping = {}
|
||||
member_names = []
|
||||
|
||||
value_cls = _create_value_cls(name)
|
||||
value_cls = _create_value_cls(name, comparable)
|
||||
for key, value in list(attrs.items()):
|
||||
is_descriptor = _is_descriptor(value)
|
||||
if key[0] == '_' and not is_descriptor:
|
||||
@@ -252,7 +256,7 @@ class SpeakingState(Enum):
|
||||
return self.value
|
||||
|
||||
|
||||
class VerificationLevel(Enum):
|
||||
class VerificationLevel(Enum, comparable=True):
|
||||
none = 0
|
||||
low = 1
|
||||
medium = 2
|
||||
@@ -263,7 +267,7 @@ class VerificationLevel(Enum):
|
||||
return self.name
|
||||
|
||||
|
||||
class ContentFilter(Enum):
|
||||
class ContentFilter(Enum, comparable=True):
|
||||
disabled = 0
|
||||
no_role = 1
|
||||
all_members = 2
|
||||
@@ -296,7 +300,7 @@ class DefaultAvatar(Enum):
|
||||
return self.name
|
||||
|
||||
|
||||
class NotificationLevel(Enum):
|
||||
class NotificationLevel(Enum, comparable=True):
|
||||
all_messages = 0
|
||||
only_mentions = 1
|
||||
|
||||
@@ -578,7 +582,7 @@ class StagePrivacyLevel(Enum):
|
||||
guild_only = 2
|
||||
|
||||
|
||||
class NSFWLevel(Enum):
|
||||
class NSFWLevel(Enum, comparable=True):
|
||||
default = 0
|
||||
explicit = 1
|
||||
safe = 2
|
||||
|
||||
@@ -31,9 +31,9 @@ if TYPE_CHECKING:
|
||||
try:
|
||||
from requests import Response
|
||||
|
||||
ResponseType = Union[ClientResponse, Response]
|
||||
_ResponseType = Union[ClientResponse, Response]
|
||||
except ModuleNotFoundError:
|
||||
ResponseType = ClientResponse
|
||||
_ResponseType = ClientResponse
|
||||
|
||||
from .interactions import Interaction
|
||||
|
||||
@@ -123,8 +123,8 @@ class HTTPException(DiscordException):
|
||||
The Discord specific error code for the failure.
|
||||
"""
|
||||
|
||||
def __init__(self, response: ResponseType, message: Optional[Union[str, Dict[str, Any]]]):
|
||||
self.response: ResponseType = response
|
||||
def __init__(self, response: _ResponseType, message: Optional[Union[str, Dict[str, Any]]]):
|
||||
self.response: _ResponseType = response
|
||||
self.status: int = response.status # type: ignore
|
||||
self.code: int
|
||||
self.text: str
|
||||
|
||||
@@ -22,6 +22,26 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
|
||||
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
from .cog import Cog
|
||||
from .errors import CommandError
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
Coro = Coroutine[Any, Any, T]
|
||||
MaybeCoro = Union[T, Coro[T]]
|
||||
CoroFunc = Callable[..., Coro[Any]]
|
||||
|
||||
Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
|
||||
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
|
||||
Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]]
|
||||
|
||||
|
||||
# This is merely a tag type to avoid circular import issues.
|
||||
# Yes, this is a terrible solution but ultimately it is the only solution.
|
||||
class _BaseCommand:
|
||||
|
||||
@@ -22,13 +22,18 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import collections.abc
|
||||
import inspect
|
||||
import importlib.util
|
||||
import sys
|
||||
import traceback
|
||||
import types
|
||||
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
|
||||
|
||||
import discord
|
||||
|
||||
@@ -39,6 +44,15 @@ from . import errors
|
||||
from .help import HelpCommand, DefaultHelpCommand
|
||||
from .cog import Cog
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import importlib.machinery
|
||||
|
||||
from discord.message import Message
|
||||
from ._types import (
|
||||
Check,
|
||||
CoroFunc,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
'when_mentioned',
|
||||
'when_mentioned_or',
|
||||
@@ -46,14 +60,21 @@ __all__ = (
|
||||
'AutoShardedBot',
|
||||
)
|
||||
|
||||
def when_mentioned(bot, msg):
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
T = TypeVar('T')
|
||||
CFT = TypeVar('CFT', bound='CoroFunc')
|
||||
CXT = TypeVar('CXT', bound='Context')
|
||||
|
||||
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
|
||||
"""A callable that implements a command prefix equivalent to being mentioned.
|
||||
|
||||
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
|
||||
"""
|
||||
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> ']
|
||||
# bot.user will never be None when this is called
|
||||
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
|
||||
|
||||
def when_mentioned_or(*prefixes):
|
||||
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
|
||||
"""A callable that implements when mentioned or other prefixes provided.
|
||||
|
||||
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
|
||||
@@ -89,7 +110,7 @@ def when_mentioned_or(*prefixes):
|
||||
|
||||
return inner
|
||||
|
||||
def _is_submodule(parent, child):
|
||||
def _is_submodule(parent: str, child: str) -> bool:
|
||||
return parent == child or child.startswith(parent + ".")
|
||||
|
||||
class _DefaultRepr:
|
||||
@@ -99,13 +120,13 @@ class _DefaultRepr:
|
||||
_default = _DefaultRepr()
|
||||
|
||||
class BotBase(GroupMixin):
|
||||
def __init__(self, command_prefix, help_command=_default, description=None, **options):
|
||||
super().__init__(**options)
|
||||
def __init__(self, command_prefix, help_command=_default, description=None, *, intents: discord.Intents, **options):
|
||||
super().__init__(**options, intents=intents)
|
||||
self.command_prefix = command_prefix
|
||||
self.extra_events = {}
|
||||
self.__cogs = {}
|
||||
self.__extensions = {}
|
||||
self._checks = []
|
||||
self.extra_events: Dict[str, List[CoroFunc]] = {}
|
||||
self.__cogs: Dict[str, Cog] = {}
|
||||
self.__extensions: Dict[str, types.ModuleType] = {}
|
||||
self._checks: List[Check] = []
|
||||
self._check_once = []
|
||||
self._before_invoke = None
|
||||
self._after_invoke = None
|
||||
@@ -128,13 +149,15 @@ class BotBase(GroupMixin):
|
||||
|
||||
# internal helpers
|
||||
|
||||
def dispatch(self, event_name, *args, **kwargs):
|
||||
super().dispatch(event_name, *args, **kwargs)
|
||||
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
|
||||
# super() will resolve to Client
|
||||
super().dispatch(event_name, *args, **kwargs) # type: ignore
|
||||
ev = 'on_' + event_name
|
||||
for event in self.extra_events.get(ev, []):
|
||||
self._schedule_event(event, ev, *args, **kwargs)
|
||||
self._schedule_event(event, ev, *args, **kwargs) # type: ignore
|
||||
|
||||
async def close(self):
|
||||
@discord.utils.copy_doc(discord.Client.close)
|
||||
async def close(self) -> None:
|
||||
for extension in tuple(self.__extensions):
|
||||
try:
|
||||
self.unload_extension(extension)
|
||||
@@ -147,9 +170,9 @@ class BotBase(GroupMixin):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await super().close()
|
||||
await super().close() # type: ignore
|
||||
|
||||
async def on_command_error(self, context, exception):
|
||||
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
|
||||
"""|coro|
|
||||
|
||||
The default command error handler provided by the bot.
|
||||
@@ -175,7 +198,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
# global check registration
|
||||
|
||||
def check(self, func):
|
||||
def check(self, func: T) -> T:
|
||||
r"""A decorator that adds a global check to the bot.
|
||||
|
||||
A global check is similar to a :func:`.check` that is applied
|
||||
@@ -200,10 +223,11 @@ class BotBase(GroupMixin):
|
||||
return ctx.command.qualified_name in allowed_commands
|
||||
|
||||
"""
|
||||
self.add_check(func)
|
||||
# T was used instead of Check to ensure the type matches on return
|
||||
self.add_check(func) # type: ignore
|
||||
return func
|
||||
|
||||
def add_check(self, func, *, call_once=False):
|
||||
def add_check(self, func: Check, *, call_once: bool = False) -> None:
|
||||
"""Adds a global check to the bot.
|
||||
|
||||
This is the non-decorator interface to :meth:`.check`
|
||||
@@ -223,7 +247,7 @@ class BotBase(GroupMixin):
|
||||
else:
|
||||
self._checks.append(func)
|
||||
|
||||
def remove_check(self, func, *, call_once=False):
|
||||
def remove_check(self, func: Check, *, call_once: bool = False) -> None:
|
||||
"""Removes a global check from the bot.
|
||||
|
||||
This function is idempotent and will not raise an exception
|
||||
@@ -244,7 +268,7 @@ class BotBase(GroupMixin):
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def check_once(self, func):
|
||||
def check_once(self, func: CFT) -> CFT:
|
||||
r"""A decorator that adds a "call once" global check to the bot.
|
||||
|
||||
Unlike regular global checks, this one is called only once
|
||||
@@ -282,15 +306,16 @@ class BotBase(GroupMixin):
|
||||
self.add_check(func, call_once=True)
|
||||
return func
|
||||
|
||||
async def can_run(self, ctx, *, call_once=False):
|
||||
async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool:
|
||||
data = self._check_once if call_once else self._checks
|
||||
|
||||
if len(data) == 0:
|
||||
return True
|
||||
|
||||
return await discord.utils.async_all(f(ctx) for f in data)
|
||||
# type-checker doesn't distinguish between functions and methods
|
||||
return await discord.utils.async_all(f(ctx) for f in data) # type: ignore
|
||||
|
||||
async def is_owner(self, user):
|
||||
async def is_owner(self, user: discord.User) -> bool:
|
||||
"""|coro|
|
||||
|
||||
Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of
|
||||
@@ -319,7 +344,8 @@ class BotBase(GroupMixin):
|
||||
elif self.owner_ids:
|
||||
return user.id in self.owner_ids
|
||||
else:
|
||||
app = await self.application_info()
|
||||
|
||||
app = await self.application_info() # type: ignore
|
||||
if app.team:
|
||||
self.owner_ids = ids = {m.id for m in app.team.members}
|
||||
return user.id in ids
|
||||
@@ -327,7 +353,7 @@ class BotBase(GroupMixin):
|
||||
self.owner_id = owner_id = app.owner.id
|
||||
return user.id == owner_id
|
||||
|
||||
def before_invoke(self, coro):
|
||||
def before_invoke(self, coro: CFT) -> CFT:
|
||||
"""A decorator that registers a coroutine as a pre-invoke hook.
|
||||
|
||||
A pre-invoke hook is called directly before the command is
|
||||
@@ -359,7 +385,7 @@ class BotBase(GroupMixin):
|
||||
self._before_invoke = coro
|
||||
return coro
|
||||
|
||||
def after_invoke(self, coro):
|
||||
def after_invoke(self, coro: CFT) -> CFT:
|
||||
r"""A decorator that registers a coroutine as a post-invoke hook.
|
||||
|
||||
A post-invoke hook is called directly after the command is
|
||||
@@ -394,14 +420,14 @@ class BotBase(GroupMixin):
|
||||
|
||||
# listener registration
|
||||
|
||||
def add_listener(self, func, name=None):
|
||||
def add_listener(self, func: CoroFunc, name: str = MISSING) -> None:
|
||||
"""The non decorator alternative to :meth:`.listen`.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
func: :ref:`coroutine <coroutine>`
|
||||
The function to call.
|
||||
name: Optional[:class:`str`]
|
||||
name: :class:`str`
|
||||
The name of the event to listen for. Defaults to ``func.__name__``.
|
||||
|
||||
Example
|
||||
@@ -416,7 +442,7 @@ class BotBase(GroupMixin):
|
||||
bot.add_listener(my_message, 'on_message')
|
||||
|
||||
"""
|
||||
name = func.__name__ if name is None else name
|
||||
name = func.__name__ if name is MISSING else name
|
||||
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError('Listeners must be coroutines')
|
||||
@@ -426,7 +452,7 @@ class BotBase(GroupMixin):
|
||||
else:
|
||||
self.extra_events[name] = [func]
|
||||
|
||||
def remove_listener(self, func, name=None):
|
||||
def remove_listener(self, func: CoroFunc, name: str = MISSING) -> None:
|
||||
"""Removes a listener from the pool of listeners.
|
||||
|
||||
Parameters
|
||||
@@ -438,7 +464,7 @@ class BotBase(GroupMixin):
|
||||
``func.__name__``.
|
||||
"""
|
||||
|
||||
name = func.__name__ if name is None else name
|
||||
name = func.__name__ if name is MISSING else name
|
||||
|
||||
if name in self.extra_events:
|
||||
try:
|
||||
@@ -446,7 +472,7 @@ class BotBase(GroupMixin):
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def listen(self, name=None):
|
||||
def listen(self, name: str = MISSING) -> Callable[[CFT], CFT]:
|
||||
"""A decorator that registers another function as an external
|
||||
event listener. Basically this allows you to listen to multiple
|
||||
events from different places e.g. such as :func:`.on_ready`
|
||||
@@ -476,7 +502,7 @@ class BotBase(GroupMixin):
|
||||
The function being listened to is not a coroutine.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
def decorator(func: CFT) -> CFT:
|
||||
self.add_listener(func, name)
|
||||
return func
|
||||
|
||||
@@ -528,7 +554,7 @@ class BotBase(GroupMixin):
|
||||
cog = cog._inject(self)
|
||||
self.__cogs[cog_name] = cog
|
||||
|
||||
def get_cog(self, name):
|
||||
def get_cog(self, name: str) -> Optional[Cog]:
|
||||
"""Gets the cog instance requested.
|
||||
|
||||
If the cog is not found, ``None`` is returned instead.
|
||||
@@ -547,7 +573,7 @@ class BotBase(GroupMixin):
|
||||
"""
|
||||
return self.__cogs.get(name)
|
||||
|
||||
def remove_cog(self, name):
|
||||
def remove_cog(self, name: str) -> Optional[Cog]:
|
||||
"""Removes a cog from the bot and returns it.
|
||||
|
||||
All registered commands and event listeners that the
|
||||
@@ -578,13 +604,13 @@ class BotBase(GroupMixin):
|
||||
return cog
|
||||
|
||||
@property
|
||||
def cogs(self):
|
||||
def cogs(self) -> Mapping[str, Cog]:
|
||||
"""Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog."""
|
||||
return types.MappingProxyType(self.__cogs)
|
||||
|
||||
# extensions
|
||||
|
||||
def _remove_module_references(self, name):
|
||||
def _remove_module_references(self, name: str) -> None:
|
||||
# find all references to the module
|
||||
# remove the cogs registered from the module
|
||||
for cogname, cog in self.__cogs.copy().items():
|
||||
@@ -608,7 +634,7 @@ class BotBase(GroupMixin):
|
||||
for index in reversed(remove):
|
||||
del event_list[index]
|
||||
|
||||
def _call_module_finalizers(self, lib, key):
|
||||
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
|
||||
try:
|
||||
func = getattr(lib, 'teardown')
|
||||
except AttributeError:
|
||||
@@ -626,12 +652,12 @@ class BotBase(GroupMixin):
|
||||
if _is_submodule(name, module):
|
||||
del sys.modules[module]
|
||||
|
||||
def _load_from_module_spec(self, spec, key):
|
||||
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
|
||||
# precondition: key not in self.__extensions
|
||||
lib = importlib.util.module_from_spec(spec)
|
||||
sys.modules[key] = lib
|
||||
try:
|
||||
spec.loader.exec_module(lib)
|
||||
spec.loader.exec_module(lib) # type: ignore
|
||||
except Exception as e:
|
||||
del sys.modules[key]
|
||||
raise errors.ExtensionFailed(key, e) from e
|
||||
@@ -652,13 +678,13 @@ class BotBase(GroupMixin):
|
||||
else:
|
||||
self.__extensions[key] = lib
|
||||
|
||||
def _resolve_name(self, name, package):
|
||||
def _resolve_name(self, name: str, package: Optional[str]) -> str:
|
||||
try:
|
||||
return importlib.util.resolve_name(name, package)
|
||||
except ImportError:
|
||||
raise errors.ExtensionNotFound(name)
|
||||
|
||||
def load_extension(self, name, *, package=None):
|
||||
def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
|
||||
"""Loads an extension.
|
||||
|
||||
An extension is a python module that contains commands, cogs, or
|
||||
@@ -705,7 +731,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
self._load_from_module_spec(spec, name)
|
||||
|
||||
def unload_extension(self, name, *, package=None):
|
||||
def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
|
||||
"""Unloads an extension.
|
||||
|
||||
When the extension is unloaded, all commands, listeners, and cogs are
|
||||
@@ -746,7 +772,7 @@ class BotBase(GroupMixin):
|
||||
self._remove_module_references(lib.__name__)
|
||||
self._call_module_finalizers(lib, name)
|
||||
|
||||
def reload_extension(self, name, *, package=None):
|
||||
def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
|
||||
"""Atomically reloads an extension.
|
||||
|
||||
This replaces the extension with the same extension, only refreshed. This is
|
||||
@@ -802,7 +828,7 @@ class BotBase(GroupMixin):
|
||||
# if the load failed, the remnants should have been
|
||||
# cleaned from the load_extension function call
|
||||
# so let's load it from our old compiled library.
|
||||
lib.setup(self)
|
||||
lib.setup(self) # type: ignore
|
||||
self.__extensions[name] = lib
|
||||
|
||||
# revert sys.modules back to normal and raise back to caller
|
||||
@@ -810,18 +836,18 @@ class BotBase(GroupMixin):
|
||||
raise
|
||||
|
||||
@property
|
||||
def extensions(self):
|
||||
def extensions(self) -> Mapping[str, types.ModuleType]:
|
||||
"""Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension."""
|
||||
return types.MappingProxyType(self.__extensions)
|
||||
|
||||
# help command stuff
|
||||
|
||||
@property
|
||||
def help_command(self):
|
||||
def help_command(self) -> Optional[HelpCommand]:
|
||||
return self._help_command
|
||||
|
||||
@help_command.setter
|
||||
def help_command(self, value):
|
||||
def help_command(self, value: Optional[HelpCommand]) -> None:
|
||||
if value is not None:
|
||||
if not isinstance(value, HelpCommand):
|
||||
raise TypeError('help_command must be a subclass of HelpCommand')
|
||||
@@ -837,7 +863,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
# command processing
|
||||
|
||||
async def get_prefix(self, message):
|
||||
async def get_prefix(self, message: Message) -> Union[List[str], str]:
|
||||
"""|coro|
|
||||
|
||||
Retrieves the prefix the bot is listening to
|
||||
@@ -875,7 +901,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
return ret
|
||||
|
||||
async def get_context(self, message, *, cls=Context):
|
||||
async def get_context(self, message: Message, *, cls: Type[CXT] = Context) -> CXT:
|
||||
r"""|coro|
|
||||
|
||||
Returns the invocation context from the message.
|
||||
@@ -908,7 +934,7 @@ class BotBase(GroupMixin):
|
||||
view = StringView(message.content)
|
||||
ctx = cls(prefix=None, view=view, bot=self, message=message)
|
||||
|
||||
if message.author.id == self.user.id:
|
||||
if message.author.id == self.user.id: # type: ignore
|
||||
return ctx
|
||||
|
||||
prefix = await self.get_prefix(message)
|
||||
@@ -945,11 +971,12 @@ class BotBase(GroupMixin):
|
||||
|
||||
invoker = view.get_word()
|
||||
ctx.invoked_with = invoker
|
||||
ctx.prefix = invoked_prefix
|
||||
# type-checker fails to narrow invoked_prefix type.
|
||||
ctx.prefix = invoked_prefix # type: ignore
|
||||
ctx.command = self.all_commands.get(invoker)
|
||||
return ctx
|
||||
|
||||
async def invoke(self, ctx):
|
||||
async def invoke(self, ctx: Context) -> None:
|
||||
"""|coro|
|
||||
|
||||
Invokes the command given under the invocation context and
|
||||
@@ -975,7 +1002,7 @@ class BotBase(GroupMixin):
|
||||
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
|
||||
self.dispatch('command_error', ctx, exc)
|
||||
|
||||
async def process_commands(self, message):
|
||||
async def process_commands(self, message: Message) -> None:
|
||||
"""|coro|
|
||||
|
||||
This function processes the commands that have been registered
|
||||
|
||||
@@ -21,15 +21,30 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import discord.utils
|
||||
|
||||
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type
|
||||
|
||||
from ._types import _BaseCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .bot import BotBase
|
||||
from .context import Context
|
||||
from .core import Command
|
||||
|
||||
__all__ = (
|
||||
'CogMeta',
|
||||
'Cog',
|
||||
)
|
||||
|
||||
CogT = TypeVar('CogT', bound='Cog')
|
||||
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
|
||||
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
class CogMeta(type):
|
||||
"""A metaclass for defining a cog.
|
||||
|
||||
@@ -89,8 +104,12 @@ class CogMeta(type):
|
||||
async def bar(self, ctx):
|
||||
pass # hidden -> False
|
||||
"""
|
||||
__cog_name__: str
|
||||
__cog_settings__: Dict[str, Any]
|
||||
__cog_commands__: List[Command]
|
||||
__cog_listeners__: List[Tuple[str, str]]
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
|
||||
name, bases, attrs = args
|
||||
attrs['__cog_name__'] = kwargs.pop('name', name)
|
||||
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
|
||||
@@ -143,14 +162,14 @@ class CogMeta(type):
|
||||
new_cls.__cog_listeners__ = listeners_as_list
|
||||
return new_cls
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args)
|
||||
|
||||
@classmethod
|
||||
def qualified_name(cls):
|
||||
def qualified_name(cls) -> str:
|
||||
return cls.__cog_name__
|
||||
|
||||
def _cog_special_method(func):
|
||||
def _cog_special_method(func: FuncT) -> FuncT:
|
||||
func.__cog_special_method__ = None
|
||||
return func
|
||||
|
||||
@@ -164,8 +183,12 @@ class Cog(metaclass=CogMeta):
|
||||
When inheriting from this class, the options shown in :class:`CogMeta`
|
||||
are equally valid here.
|
||||
"""
|
||||
__cog_name__: ClassVar[str]
|
||||
__cog_settings__: ClassVar[Dict[str, Any]]
|
||||
__cog_commands__: ClassVar[List[Command]]
|
||||
__cog_listeners__: ClassVar[List[Tuple[str, str]]]
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
def __new__(cls: Type[CogT], *args: Any, **kwargs: Any) -> CogT:
|
||||
# For issue 426, we need to store a copy of the command objects
|
||||
# since we modify them to inject `self` to them.
|
||||
# To do this, we need to interfere with the Cog creation process.
|
||||
@@ -173,7 +196,8 @@ class Cog(metaclass=CogMeta):
|
||||
cmd_attrs = cls.__cog_settings__
|
||||
|
||||
# Either update the command with the cog provided defaults or copy it.
|
||||
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__)
|
||||
# r.e type ignore, type-checker complains about overriding a ClassVar
|
||||
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
|
||||
|
||||
lookup = {
|
||||
cmd.qualified_name: cmd
|
||||
@@ -186,15 +210,15 @@ class Cog(metaclass=CogMeta):
|
||||
parent = command.parent
|
||||
if parent is not None:
|
||||
# Get the latest parent reference
|
||||
parent = lookup[parent.qualified_name]
|
||||
parent = lookup[parent.qualified_name] # type: ignore
|
||||
|
||||
# Update our parent's reference to our self
|
||||
parent.remove_command(command.name)
|
||||
parent.add_command(command)
|
||||
parent.remove_command(command.name) # type: ignore
|
||||
parent.add_command(command) # type: ignore
|
||||
|
||||
return self
|
||||
|
||||
def get_commands(self):
|
||||
def get_commands(self) -> List[Command]:
|
||||
r"""
|
||||
Returns
|
||||
--------
|
||||
@@ -209,20 +233,20 @@ class Cog(metaclass=CogMeta):
|
||||
return [c for c in self.__cog_commands__ if c.parent is None]
|
||||
|
||||
@property
|
||||
def qualified_name(self):
|
||||
def qualified_name(self) -> str:
|
||||
""":class:`str`: Returns the cog's specified name, not the class name."""
|
||||
return self.__cog_name__
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
def description(self) -> str:
|
||||
""":class:`str`: Returns the cog's description, typically the cleaned docstring."""
|
||||
return self.__cog_description__
|
||||
|
||||
@description.setter
|
||||
def description(self, description):
|
||||
def description(self, description: str) -> None:
|
||||
self.__cog_description__ = description
|
||||
|
||||
def walk_commands(self):
|
||||
def walk_commands(self) -> Generator[Command, None, None]:
|
||||
"""An iterator that recursively walks through this cog's commands and subcommands.
|
||||
|
||||
Yields
|
||||
@@ -237,7 +261,7 @@ class Cog(metaclass=CogMeta):
|
||||
if isinstance(command, GroupMixin):
|
||||
yield from command.walk_commands()
|
||||
|
||||
def get_listeners(self):
|
||||
def get_listeners(self) -> List[Tuple[str, Callable[..., Any]]]:
|
||||
"""Returns a :class:`list` of (name, function) listener pairs that are defined in this cog.
|
||||
|
||||
Returns
|
||||
@@ -248,12 +272,12 @@ class Cog(metaclass=CogMeta):
|
||||
return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__]
|
||||
|
||||
@classmethod
|
||||
def _get_overridden_method(cls, method):
|
||||
def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]:
|
||||
"""Return None if the method is not overridden. Otherwise returns the overridden method."""
|
||||
return getattr(method.__func__, '__cog_special_method__', method)
|
||||
|
||||
@classmethod
|
||||
def listener(cls, name=None):
|
||||
def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]:
|
||||
"""A decorator that marks a function as a listener.
|
||||
|
||||
This is the cog equivalent of :meth:`.Bot.listen`.
|
||||
@@ -271,10 +295,10 @@ class Cog(metaclass=CogMeta):
|
||||
the name.
|
||||
"""
|
||||
|
||||
if name is not None and not isinstance(name, str):
|
||||
if name is not MISSING and not isinstance(name, str):
|
||||
raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.')
|
||||
|
||||
def decorator(func):
|
||||
def decorator(func: FuncT) -> FuncT:
|
||||
actual = func
|
||||
if isinstance(actual, staticmethod):
|
||||
actual = actual.__func__
|
||||
@@ -293,7 +317,7 @@ class Cog(metaclass=CogMeta):
|
||||
return func
|
||||
return decorator
|
||||
|
||||
def has_error_handler(self):
|
||||
def has_error_handler(self) -> bool:
|
||||
""":class:`bool`: Checks whether the cog has an error handler.
|
||||
|
||||
.. versionadded:: 1.7
|
||||
@@ -301,7 +325,7 @@ class Cog(metaclass=CogMeta):
|
||||
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
|
||||
|
||||
@_cog_special_method
|
||||
def cog_unload(self):
|
||||
def cog_unload(self) -> None:
|
||||
"""A special method that is called when the cog gets removed.
|
||||
|
||||
This function **cannot** be a coroutine. It must be a regular
|
||||
@@ -312,7 +336,7 @@ class Cog(metaclass=CogMeta):
|
||||
pass
|
||||
|
||||
@_cog_special_method
|
||||
def bot_check_once(self, ctx):
|
||||
def bot_check_once(self, ctx: Context) -> bool:
|
||||
"""A special method that registers as a :meth:`.Bot.check_once`
|
||||
check.
|
||||
|
||||
@@ -322,7 +346,7 @@ class Cog(metaclass=CogMeta):
|
||||
return True
|
||||
|
||||
@_cog_special_method
|
||||
def bot_check(self, ctx):
|
||||
def bot_check(self, ctx: Context) -> bool:
|
||||
"""A special method that registers as a :meth:`.Bot.check`
|
||||
check.
|
||||
|
||||
@@ -332,7 +356,7 @@ class Cog(metaclass=CogMeta):
|
||||
return True
|
||||
|
||||
@_cog_special_method
|
||||
def cog_check(self, ctx):
|
||||
def cog_check(self, ctx: Context) -> bool:
|
||||
"""A special method that registers as a :func:`~discord.ext.commands.check`
|
||||
for every command and subcommand in this cog.
|
||||
|
||||
@@ -342,7 +366,7 @@ class Cog(metaclass=CogMeta):
|
||||
return True
|
||||
|
||||
@_cog_special_method
|
||||
async def cog_command_error(self, ctx, error):
|
||||
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
|
||||
"""A special method that is called whenever an error
|
||||
is dispatched inside this cog.
|
||||
|
||||
@@ -361,7 +385,7 @@ class Cog(metaclass=CogMeta):
|
||||
pass
|
||||
|
||||
@_cog_special_method
|
||||
async def cog_before_invoke(self, ctx):
|
||||
async def cog_before_invoke(self, ctx: Context) -> None:
|
||||
"""A special method that acts as a cog local pre-invoke hook.
|
||||
|
||||
This is similar to :meth:`.Command.before_invoke`.
|
||||
@@ -376,7 +400,7 @@ class Cog(metaclass=CogMeta):
|
||||
pass
|
||||
|
||||
@_cog_special_method
|
||||
async def cog_after_invoke(self, ctx):
|
||||
async def cog_after_invoke(self, ctx: Context) -> None:
|
||||
"""A special method that acts as a cog local post-invoke hook.
|
||||
|
||||
This is similar to :meth:`.Command.after_invoke`.
|
||||
@@ -390,7 +414,7 @@ class Cog(metaclass=CogMeta):
|
||||
"""
|
||||
pass
|
||||
|
||||
def _inject(self, bot):
|
||||
def _inject(self: CogT, bot: BotBase) -> CogT:
|
||||
cls = self.__class__
|
||||
|
||||
# realistically, the only thing that can cause loading errors
|
||||
@@ -425,7 +449,7 @@ class Cog(metaclass=CogMeta):
|
||||
|
||||
return self
|
||||
|
||||
def _eject(self, bot):
|
||||
def _eject(self, bot: BotBase) -> None:
|
||||
cls = self.__class__
|
||||
|
||||
try:
|
||||
|
||||
@@ -21,16 +21,52 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import re
|
||||
|
||||
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
import discord.abc
|
||||
import discord.utils
|
||||
import re
|
||||
|
||||
from discord.message import Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from discord.abc import MessageableChannel
|
||||
from discord.guild import Guild
|
||||
from discord.member import Member
|
||||
from discord.state import ConnectionState
|
||||
from discord.user import ClientUser, User
|
||||
from discord.voice_client import VoiceProtocol
|
||||
|
||||
from .bot import Bot, AutoShardedBot
|
||||
from .cog import Cog
|
||||
from .core import Command
|
||||
from .help import HelpCommand
|
||||
from .view import StringView
|
||||
|
||||
__all__ = (
|
||||
'Context',
|
||||
)
|
||||
|
||||
class Context(discord.abc.Messageable):
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
|
||||
CogT = TypeVar('CogT', bound="Cog")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
P = ParamSpec('P')
|
||||
else:
|
||||
P = TypeVar('P')
|
||||
|
||||
|
||||
class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
r"""Represents the context in which a command is being invoked under.
|
||||
|
||||
This class contains a lot of meta data to help you understand more about
|
||||
@@ -58,11 +94,11 @@ class Context(discord.abc.Messageable):
|
||||
This is only of use for within converters.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
prefix: :class:`str`
|
||||
prefix: Optional[:class:`str`]
|
||||
The prefix that was used to invoke the command.
|
||||
command: :class:`Command`
|
||||
command: Optional[:class:`Command`]
|
||||
The command that is being invoked currently.
|
||||
invoked_with: :class:`str`
|
||||
invoked_with: Optional[:class:`str`]
|
||||
The command name that triggered this invocation. Useful for finding out
|
||||
which alias called the command.
|
||||
invoked_parents: List[:class:`str`]
|
||||
@@ -73,7 +109,7 @@ class Context(discord.abc.Messageable):
|
||||
|
||||
.. versionadded:: 1.7
|
||||
|
||||
invoked_subcommand: :class:`Command`
|
||||
invoked_subcommand: Optional[:class:`Command`]
|
||||
The subcommand that was invoked.
|
||||
If no valid subcommand was invoked then this is equal to ``None``.
|
||||
subcommand_passed: Optional[:class:`str`]
|
||||
@@ -86,23 +122,38 @@ class Context(discord.abc.Messageable):
|
||||
or invoked.
|
||||
"""
|
||||
|
||||
def __init__(self, **attrs):
|
||||
self.message = attrs.pop('message', None)
|
||||
self.bot = attrs.pop('bot', None)
|
||||
self.args = attrs.pop('args', [])
|
||||
self.kwargs = attrs.pop('kwargs', {})
|
||||
self.prefix = attrs.pop('prefix')
|
||||
self.command = attrs.pop('command', None)
|
||||
self.view = attrs.pop('view', None)
|
||||
self.invoked_with = attrs.pop('invoked_with', None)
|
||||
self.invoked_parents = attrs.pop('invoked_parents', [])
|
||||
self.invoked_subcommand = attrs.pop('invoked_subcommand', None)
|
||||
self.subcommand_passed = attrs.pop('subcommand_passed', None)
|
||||
self.command_failed = attrs.pop('command_failed', False)
|
||||
self.current_parameter = attrs.pop('current_parameter', None)
|
||||
self._state = self.message._state
|
||||
def __init__(self,
|
||||
*,
|
||||
message: Message,
|
||||
bot: BotT,
|
||||
view: StringView,
|
||||
args: List[Any] = MISSING,
|
||||
kwargs: Dict[str, Any] = MISSING,
|
||||
prefix: Optional[str] = None,
|
||||
command: Optional[Command] = None,
|
||||
invoked_with: Optional[str] = None,
|
||||
invoked_parents: List[str] = MISSING,
|
||||
invoked_subcommand: Optional[Command] = None,
|
||||
subcommand_passed: Optional[str] = None,
|
||||
command_failed: bool = False,
|
||||
current_parameter: Optional[inspect.Parameter] = None,
|
||||
):
|
||||
self.message: Message = message
|
||||
self.bot: BotT = bot
|
||||
self.args: List[Any] = args or []
|
||||
self.kwargs: Dict[str, Any] = kwargs or {}
|
||||
self.prefix: Optional[str] = prefix
|
||||
self.command: Optional[Command] = command
|
||||
self.view: StringView = view
|
||||
self.invoked_with: Optional[str] = invoked_with
|
||||
self.invoked_parents: List[str] = invoked_parents or []
|
||||
self.invoked_subcommand: Optional[Command] = invoked_subcommand
|
||||
self.subcommand_passed: Optional[str] = subcommand_passed
|
||||
self.command_failed: bool = command_failed
|
||||
self.current_parameter: Optional[inspect.Parameter] = current_parameter
|
||||
self._state: ConnectionState = self.message._state
|
||||
|
||||
async def invoke(self, command, /, *args, **kwargs):
|
||||
async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
r"""|coro|
|
||||
|
||||
Calls a command with the arguments given.
|
||||
@@ -124,7 +175,7 @@ class Context(discord.abc.Messageable):
|
||||
command: :class:`.Command`
|
||||
The command that is going to be called.
|
||||
\*args
|
||||
The arguments to to use.
|
||||
The arguments to use.
|
||||
\*\*kwargs
|
||||
The keyword arguments to use.
|
||||
|
||||
@@ -133,17 +184,9 @@ class Context(discord.abc.Messageable):
|
||||
TypeError
|
||||
The command argument to invoke is missing.
|
||||
"""
|
||||
arguments = []
|
||||
if command.cog is not None:
|
||||
arguments.append(command.cog)
|
||||
return await command(self, *args, **kwargs)
|
||||
|
||||
arguments.append(self)
|
||||
arguments.extend(args)
|
||||
|
||||
ret = await command.callback(*arguments, **kwargs)
|
||||
return ret
|
||||
|
||||
async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True):
|
||||
async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> None:
|
||||
"""|coro|
|
||||
|
||||
Calls the command again.
|
||||
@@ -187,7 +230,7 @@ class Context(discord.abc.Messageable):
|
||||
|
||||
if restart:
|
||||
to_call = cmd.root_parent or cmd
|
||||
view.index = len(self.prefix)
|
||||
view.index = len(self.prefix or '')
|
||||
view.previous = 0
|
||||
self.invoked_parents = []
|
||||
self.invoked_with = view.get_word() # advance to get the root command
|
||||
@@ -206,20 +249,23 @@ class Context(discord.abc.Messageable):
|
||||
self.subcommand_passed = subcommand_passed
|
||||
|
||||
@property
|
||||
def valid(self):
|
||||
def valid(self) -> bool:
|
||||
""":class:`bool`: Checks if the invocation context is valid to be invoked with."""
|
||||
return self.prefix is not None and self.command is not None
|
||||
|
||||
async def _get_channel(self):
|
||||
async def _get_channel(self) -> discord.abc.Messageable:
|
||||
return self.channel
|
||||
|
||||
@property
|
||||
def clean_prefix(self):
|
||||
def clean_prefix(self) -> str:
|
||||
""":class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
user = self.guild.me if self.guild else self.bot.user
|
||||
if self.prefix is None:
|
||||
return ''
|
||||
|
||||
user = self.me
|
||||
# this breaks if the prefix mention is not the bot itself but I
|
||||
# consider this to be an *incredibly* strange use case. I'd rather go
|
||||
# for this common use case rather than waste performance for the
|
||||
@@ -228,7 +274,7 @@ class Context(discord.abc.Messageable):
|
||||
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix)
|
||||
|
||||
@property
|
||||
def cog(self):
|
||||
def cog(self) -> Optional[Cog]:
|
||||
"""Optional[:class:`.Cog`]: Returns the cog associated with this context's command. None if it does not exist."""
|
||||
|
||||
if self.command is None:
|
||||
@@ -236,38 +282,39 @@ class Context(discord.abc.Messageable):
|
||||
return self.command.cog
|
||||
|
||||
@discord.utils.cached_property
|
||||
def guild(self):
|
||||
def guild(self) -> Optional[Guild]:
|
||||
"""Optional[:class:`.Guild`]: Returns the guild associated with this context's command. None if not available."""
|
||||
return self.message.guild
|
||||
|
||||
@discord.utils.cached_property
|
||||
def channel(self):
|
||||
def channel(self) -> MessageableChannel:
|
||||
"""Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command.
|
||||
Shorthand for :attr:`.Message.channel`.
|
||||
"""
|
||||
return self.message.channel
|
||||
|
||||
@discord.utils.cached_property
|
||||
def author(self):
|
||||
def author(self) -> Union[User, Member]:
|
||||
"""Union[:class:`~discord.User`, :class:`.Member`]:
|
||||
Returns the author associated with this context's command. Shorthand for :attr:`.Message.author`
|
||||
"""
|
||||
return self.message.author
|
||||
|
||||
@discord.utils.cached_property
|
||||
def me(self):
|
||||
def me(self) -> Union[Member, ClientUser]:
|
||||
"""Union[:class:`.Member`, :class:`.ClientUser`]:
|
||||
Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message contexts.
|
||||
"""
|
||||
return self.guild.me if self.guild is not None else self.bot.user
|
||||
# bot.user will never be None at this point.
|
||||
return self.guild.me if self.guild is not None else self.bot.user # type: ignore
|
||||
|
||||
@property
|
||||
def voice_client(self):
|
||||
def voice_client(self) -> Optional[VoiceProtocol]:
|
||||
r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable."""
|
||||
g = self.guild
|
||||
return g.voice_client if g else None
|
||||
|
||||
async def send_help(self, *args):
|
||||
async def send_help(self, *args: Any) -> Any:
|
||||
"""send_help(entity=<bot>)
|
||||
|
||||
|coro|
|
||||
@@ -319,12 +366,12 @@ class Context(discord.abc.Messageable):
|
||||
return None
|
||||
|
||||
entity = args[0]
|
||||
if entity is None:
|
||||
return None
|
||||
|
||||
if isinstance(entity, str):
|
||||
entity = bot.get_cog(entity) or bot.get_command(entity)
|
||||
|
||||
if entity is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
entity.qualified_name
|
||||
except AttributeError:
|
||||
@@ -348,6 +395,6 @@ class Context(discord.abc.Messageable):
|
||||
except CommandError as e:
|
||||
await cmd.on_help_command_error(self, e)
|
||||
|
||||
@discord.utils.copy_doc(discord.Message.reply)
|
||||
async def reply(self, content=None, **kwargs):
|
||||
@discord.utils.copy_doc(Message.reply)
|
||||
async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message:
|
||||
return await self.message.reply(content, **kwargs)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -330,7 +330,7 @@ class ChannelNotReadable(BadArgument):
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
argument: Union[:class:`.abc.GuildChannel`, :class:`Thread`]
|
||||
argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
|
||||
The channel supplied by the caller that was not readable
|
||||
"""
|
||||
def __init__(self, argument: Union[GuildChannel, Thread]) -> None:
|
||||
@@ -645,7 +645,7 @@ class NSFWChannelRequired(CheckFailure):
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
channel: Union[:class:`.abc.GuildChannel`, :class:`Thread`]
|
||||
channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
|
||||
The channel that does not have NSFW enabled.
|
||||
"""
|
||||
def __init__(self, channel: Union[GuildChannel, Thread]) -> None:
|
||||
|
||||
@@ -27,11 +27,17 @@ import copy
|
||||
import functools
|
||||
import inspect
|
||||
import re
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import discord.utils
|
||||
|
||||
from .core import Group, Command
|
||||
from .errors import CommandError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
|
||||
__all__ = (
|
||||
'Paginator',
|
||||
'HelpCommand',
|
||||
@@ -320,7 +326,7 @@ class HelpCommand:
|
||||
self.command_attrs = attrs = options.pop('command_attrs', {})
|
||||
attrs.setdefault('name', 'help')
|
||||
attrs.setdefault('help', 'Shows this message')
|
||||
self.context = None
|
||||
self.context: Context = discord.utils.MISSING
|
||||
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
|
||||
|
||||
def copy(self):
|
||||
|
||||
@@ -27,22 +27,20 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
@@ -50,8 +48,6 @@ from collections.abc import Sequence
|
||||
from discord.backoff import ExponentialBackoff
|
||||
from discord.utils import MISSING
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = (
|
||||
'loop',
|
||||
)
|
||||
@@ -61,7 +57,6 @@ _func = Callable[..., Awaitable[Any]]
|
||||
LF = TypeVar('LF', bound=_func)
|
||||
FT = TypeVar('FT', bound=_func)
|
||||
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
|
||||
LT = TypeVar('LT', bound='Loop')
|
||||
|
||||
|
||||
class SleepHandle:
|
||||
@@ -78,7 +73,7 @@ class SleepHandle:
|
||||
relative_delta = discord.utils.compute_timedelta(dt)
|
||||
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
|
||||
|
||||
def wait(self) -> asyncio.Future:
|
||||
def wait(self) -> asyncio.Future[Any]:
|
||||
return self.future
|
||||
|
||||
def done(self) -> bool:
|
||||
@@ -94,7 +89,9 @@ class Loop(Generic[LF]):
|
||||
|
||||
The main interface to create this is through :func:`loop`.
|
||||
"""
|
||||
def __init__(self,
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
coro: LF,
|
||||
seconds: float,
|
||||
hours: float,
|
||||
@@ -102,15 +99,15 @@ class Loop(Generic[LF]):
|
||||
time: Union[datetime.time, Sequence[datetime.time]],
|
||||
count: Optional[int],
|
||||
reconnect: bool,
|
||||
loop: Optional[asyncio.AbstractEventLoop],
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
) -> None:
|
||||
self.coro: LF = coro
|
||||
self.reconnect: bool = reconnect
|
||||
self.loop: Optional[asyncio.AbstractEventLoop] = loop
|
||||
self.loop: asyncio.AbstractEventLoop = loop
|
||||
self.count: Optional[int] = count
|
||||
self._current_loop = 0
|
||||
self._handle = None
|
||||
self._task = None
|
||||
self._handle: SleepHandle = MISSING
|
||||
self._task: asyncio.Task[None] = MISSING
|
||||
self._injected = None
|
||||
self._valid_exception = (
|
||||
OSError,
|
||||
@@ -131,7 +128,7 @@ class Loop(Generic[LF]):
|
||||
|
||||
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
|
||||
self._last_iteration_failed = False
|
||||
self._last_iteration = None
|
||||
self._last_iteration: datetime.datetime = MISSING
|
||||
self._next_iteration = None
|
||||
|
||||
if not inspect.iscoroutinefunction(self.coro):
|
||||
@@ -147,9 +144,8 @@ class Loop(Generic[LF]):
|
||||
else:
|
||||
await coro(*args, **kwargs)
|
||||
|
||||
|
||||
def _try_sleep_until(self, dt: datetime.datetime):
|
||||
self._handle = SleepHandle(dt=dt, loop=self.loop) # type: ignore
|
||||
self._handle = SleepHandle(dt=dt, loop=self.loop)
|
||||
return self._handle.wait()
|
||||
|
||||
async def _loop(self, *args: Any, **kwargs: Any) -> None:
|
||||
@@ -178,7 +174,7 @@ class Loop(Generic[LF]):
|
||||
await asyncio.sleep(backoff.delay())
|
||||
else:
|
||||
await self._try_sleep_until(self._next_iteration)
|
||||
|
||||
|
||||
if self._stop_next_iteration:
|
||||
return
|
||||
|
||||
@@ -211,14 +207,14 @@ class Loop(Generic[LF]):
|
||||
if obj is None:
|
||||
return self
|
||||
|
||||
copy = Loop(
|
||||
self.coro,
|
||||
seconds=self._seconds,
|
||||
hours=self._hours,
|
||||
copy: Loop[LF] = Loop(
|
||||
self.coro,
|
||||
seconds=self._seconds,
|
||||
hours=self._hours,
|
||||
minutes=self._minutes,
|
||||
time=self._time,
|
||||
time=self._time,
|
||||
count=self.count,
|
||||
reconnect=self.reconnect,
|
||||
reconnect=self.reconnect,
|
||||
loop=self.loop,
|
||||
)
|
||||
copy._injected = obj
|
||||
@@ -237,7 +233,7 @@ class Loop(Generic[LF]):
|
||||
"""
|
||||
if self._seconds is not MISSING:
|
||||
return self._seconds
|
||||
|
||||
|
||||
@property
|
||||
def minutes(self) -> Optional[float]:
|
||||
"""Optional[:class:`float`]: Read-only value for the number of minutes
|
||||
@@ -247,7 +243,7 @@ class Loop(Generic[LF]):
|
||||
"""
|
||||
if self._minutes is not MISSING:
|
||||
return self._minutes
|
||||
|
||||
|
||||
@property
|
||||
def hours(self) -> Optional[float]:
|
||||
"""Optional[:class:`float`]: Read-only value for the number of hours
|
||||
@@ -279,7 +275,7 @@ class Loop(Generic[LF]):
|
||||
|
||||
.. versionadded:: 1.3
|
||||
"""
|
||||
if self._task is None:
|
||||
if self._task is MISSING:
|
||||
return None
|
||||
elif self._task and self._task.done() or self._stop_next_iteration:
|
||||
return None
|
||||
@@ -305,7 +301,7 @@ class Loop(Generic[LF]):
|
||||
|
||||
return await self.coro(*args, **kwargs)
|
||||
|
||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task:
|
||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
||||
r"""Starts the internal task in the event loop.
|
||||
|
||||
Parameters
|
||||
@@ -326,13 +322,13 @@ class Loop(Generic[LF]):
|
||||
The task that has been created.
|
||||
"""
|
||||
|
||||
if self._task is not None and not self._task.done():
|
||||
if self._task is not MISSING and not self._task.done():
|
||||
raise RuntimeError('Task is already launched and is not completed.')
|
||||
|
||||
if self._injected is not None:
|
||||
args = (self._injected, *args)
|
||||
|
||||
if self.loop is None:
|
||||
if self.loop is MISSING:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
self._task = self.loop.create_task(self._loop(*args, **kwargs))
|
||||
@@ -356,7 +352,7 @@ class Loop(Generic[LF]):
|
||||
|
||||
.. versionadded:: 1.2
|
||||
"""
|
||||
if self._task and not self._task.done():
|
||||
if self._task is not MISSING and not self._task.done():
|
||||
self._stop_next_iteration = True
|
||||
|
||||
def _can_be_cancelled(self) -> bool:
|
||||
@@ -383,7 +379,7 @@ class Loop(Generic[LF]):
|
||||
The keyword arguments to use.
|
||||
"""
|
||||
|
||||
def restart_when_over(fut, *, args=args, kwargs=kwargs):
|
||||
def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None:
|
||||
self._task.remove_done_callback(restart_when_over)
|
||||
self.start(*args, **kwargs)
|
||||
|
||||
@@ -446,9 +442,9 @@ class Loop(Generic[LF]):
|
||||
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
|
||||
return len(self._valid_exception) == old_length - len(exceptions)
|
||||
|
||||
def get_task(self) -> Optional[asyncio.Task]:
|
||||
def get_task(self) -> Optional[asyncio.Task[None]]:
|
||||
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
|
||||
return self._task
|
||||
return self._task if self._task is not MISSING else None
|
||||
|
||||
def is_being_cancelled(self) -> bool:
|
||||
"""Whether the task is being cancelled."""
|
||||
@@ -466,7 +462,7 @@ class Loop(Generic[LF]):
|
||||
|
||||
.. versionadded:: 1.4
|
||||
"""
|
||||
return not bool(self._task.done()) if self._task else False
|
||||
return not bool(self._task.done()) if self._task is not MISSING else False
|
||||
|
||||
async def _error(self, *args: Any) -> None:
|
||||
exception: Exception = args[-1]
|
||||
@@ -560,7 +556,9 @@ class Loop(Generic[LF]):
|
||||
self._time_index = 0
|
||||
if self._current_loop == 0:
|
||||
# if we're at the last index on the first iteration, we need to sleep until tomorrow
|
||||
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0])
|
||||
return datetime.datetime.combine(
|
||||
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
|
||||
)
|
||||
|
||||
next_time = self._time[self._time_index]
|
||||
|
||||
@@ -568,7 +566,7 @@ class Loop(Generic[LF]):
|
||||
self._time_index += 1
|
||||
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
|
||||
|
||||
next_date = cast(datetime.datetime, self._last_iteration)
|
||||
next_date = self._last_iteration
|
||||
if self._time_index == 0:
|
||||
# we can assume that the earliest time should be scheduled for "tomorrow"
|
||||
next_date += datetime.timedelta(days=1)
|
||||
@@ -576,12 +574,14 @@ class Loop(Generic[LF]):
|
||||
self._time_index += 1
|
||||
return datetime.datetime.combine(next_date, next_time)
|
||||
|
||||
def _prepare_time_index(self, now: Optional[datetime.datetime] = None) -> None:
|
||||
def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
|
||||
# now kwarg should be a datetime.datetime representing the time "now"
|
||||
# to calculate the next time index from
|
||||
|
||||
# pre-condition: self._time is set
|
||||
time_now = (now or datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)).timetz()
|
||||
time_now = (
|
||||
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
|
||||
).timetz()
|
||||
for idx, time in enumerate(self._time):
|
||||
if time >= time_now:
|
||||
self._time_index = idx
|
||||
@@ -597,20 +597,24 @@ class Loop(Generic[LF]):
|
||||
utc: datetime.timezone = datetime.timezone.utc,
|
||||
) -> List[datetime.time]:
|
||||
if isinstance(time, dt):
|
||||
ret = time if time.tzinfo is not None else time.replace(tzinfo=utc)
|
||||
return [ret]
|
||||
inner = time if time.tzinfo is not None else time.replace(tzinfo=utc)
|
||||
return [inner]
|
||||
if not isinstance(time, Sequence):
|
||||
raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.')
|
||||
raise TypeError(
|
||||
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.'
|
||||
)
|
||||
if not time:
|
||||
raise ValueError('time parameter must not be an empty sequence.')
|
||||
|
||||
ret = []
|
||||
ret: List[datetime.time] = []
|
||||
for index, t in enumerate(time):
|
||||
if not isinstance(t, dt):
|
||||
raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.')
|
||||
raise TypeError(
|
||||
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.'
|
||||
)
|
||||
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
|
||||
|
||||
ret = sorted(set(ret)) # de-dupe and sort times
|
||||
ret = sorted(set(ret)) # de-dupe and sort times
|
||||
return ret
|
||||
|
||||
def change_interval(
|
||||
@@ -691,7 +695,7 @@ def loop(
|
||||
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
|
||||
count: Optional[int] = None,
|
||||
reconnect: bool = True,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
loop: asyncio.AbstractEventLoop = MISSING,
|
||||
) -> Callable[[LF], Loop[LF]]:
|
||||
"""A decorator that schedules a task in the background for you with
|
||||
optional reconnect logic. The decorator returns a :class:`Loop`.
|
||||
@@ -707,7 +711,7 @@ def loop(
|
||||
time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
|
||||
The exact times to run this loop at. Either a non-empty list or a single
|
||||
value of :class:`datetime.time` should be passed. Timezones are supported.
|
||||
If no timezone is given for the times, it is assumed to represent UTC time.
|
||||
If no timezone is given for the times, it is assumed to represent UTC time.
|
||||
|
||||
This cannot be used in conjunction with the relative time parameters.
|
||||
|
||||
@@ -724,7 +728,7 @@ def loop(
|
||||
Whether to handle errors and restart the task
|
||||
using an exponential back-off algorithm similar to the
|
||||
one used in :meth:`discord.Client.connect`.
|
||||
loop: Optional[:class:`asyncio.AbstractEventLoop`]
|
||||
loop: :class:`asyncio.AbstractEventLoop`
|
||||
The loop to use to register the task, if not given
|
||||
defaults to :func:`asyncio.get_event_loop`.
|
||||
|
||||
@@ -736,15 +740,17 @@ def loop(
|
||||
The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
|
||||
or ``time`` parameter was passed in conjunction with relative time parameters.
|
||||
"""
|
||||
|
||||
def decorator(func: LF) -> Loop[LF]:
|
||||
kwargs = {
|
||||
'seconds': seconds,
|
||||
'minutes': minutes,
|
||||
'hours': hours,
|
||||
'count': count,
|
||||
'time': time,
|
||||
'reconnect': reconnect,
|
||||
'loop': loop,
|
||||
}
|
||||
return Loop(func, **kwargs)
|
||||
return Loop[LF](
|
||||
func,
|
||||
seconds=seconds,
|
||||
minutes=minutes,
|
||||
hours=hours,
|
||||
count=count,
|
||||
time=time,
|
||||
reconnect=reconnect,
|
||||
loop=loop,
|
||||
)
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -480,16 +480,6 @@ class Intents(BaseFlags):
|
||||
self.value = self.DEFAULT_VALUE
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def default(cls: Type[Intents]) -> Intents:
|
||||
"""A factory method that creates a :class:`Intents` with everything enabled
|
||||
except :attr:`presences` and :attr:`members`.
|
||||
"""
|
||||
self = cls.all()
|
||||
self.presences = False
|
||||
self.members = False
|
||||
return self
|
||||
|
||||
@flag_value
|
||||
def guilds(self):
|
||||
""":class:`bool`: Whether guild related events are enabled.
|
||||
@@ -650,6 +640,10 @@ class Intents(BaseFlags):
|
||||
- :attr:`VoiceChannel.members`
|
||||
- :attr:`VoiceChannel.voice_states`
|
||||
- :attr:`Member.voice`
|
||||
|
||||
.. note::
|
||||
|
||||
This intent is required to connect to voice.
|
||||
"""
|
||||
return 1 << 7
|
||||
|
||||
|
||||
@@ -58,7 +58,9 @@ if TYPE_CHECKING:
|
||||
DataCallable = Callable[[Dict[str, Any]], T]
|
||||
Result = Optional[DataCallable[Any]]
|
||||
|
||||
log: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
_log: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
__all__ = (
|
||||
'DiscordWebSocket',
|
||||
@@ -132,7 +134,7 @@ class GatewayRatelimiter:
|
||||
async with self.lock:
|
||||
delta = self.get_delay()
|
||||
if delta:
|
||||
log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta)
|
||||
_log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta)
|
||||
await asyncio.sleep(delta)
|
||||
|
||||
|
||||
@@ -160,20 +162,20 @@ class KeepAliveHandler(threading.Thread):
|
||||
def run(self) -> None:
|
||||
while not self._stop_ev.wait(self.interval):
|
||||
if self._last_recv + self.heartbeat_timeout < time.perf_counter():
|
||||
log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
|
||||
_log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
|
||||
coro = self.ws.close(4000)
|
||||
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
|
||||
|
||||
try:
|
||||
f.result()
|
||||
except Exception:
|
||||
log.exception('An error occurred while stopping the gateway. Ignoring.')
|
||||
_log.exception('An error occurred while stopping the gateway. Ignoring.')
|
||||
finally:
|
||||
self.stop()
|
||||
return
|
||||
|
||||
data = self.get_payload()
|
||||
log.debug(self.msg, self.shard_id, data['d'])
|
||||
_log.debug(self.msg, self.shard_id, data['d'])
|
||||
coro = self.ws.send_heartbeat(data)
|
||||
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
|
||||
try:
|
||||
@@ -192,7 +194,7 @@ class KeepAliveHandler(threading.Thread):
|
||||
else:
|
||||
stack = ''.join(traceback.format_stack(frame))
|
||||
msg = f'{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}'
|
||||
log.warning(msg, self.shard_id, total)
|
||||
_log.warning(msg, self.shard_id, total)
|
||||
|
||||
except Exception:
|
||||
self.stop()
|
||||
@@ -217,7 +219,7 @@ class KeepAliveHandler(threading.Thread):
|
||||
self._last_ack = ack_time
|
||||
self.latency = ack_time - self._last_send
|
||||
if self.latency > 10:
|
||||
log.warning(self.behind_msg, self.shard_id, self.latency)
|
||||
_log.warning(self.behind_msg, self.shard_id, self.latency)
|
||||
|
||||
class VoiceKeepAliveHandler(KeepAliveHandler):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
@@ -378,7 +380,7 @@ class DiscordWebSocket:
|
||||
|
||||
client._connection._update_references(ws)
|
||||
|
||||
log.debug('Created websocket connected to %s', gateway)
|
||||
_log.debug('Created websocket connected to %s', gateway)
|
||||
|
||||
# poll event for OP Hello
|
||||
await ws.poll_event()
|
||||
@@ -451,7 +453,7 @@ class DiscordWebSocket:
|
||||
|
||||
await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify)
|
||||
await self.send_as_json(payload)
|
||||
log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
|
||||
_log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
|
||||
|
||||
async def resume(self) -> None:
|
||||
"""Sends the RESUME packet."""
|
||||
@@ -465,11 +467,10 @@ class DiscordWebSocket:
|
||||
}
|
||||
|
||||
await self.send_as_json(payload)
|
||||
log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
|
||||
_log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
|
||||
|
||||
async def received_message(self, msg, /) -> None:
|
||||
self.log_receive(msg)
|
||||
|
||||
async def received_message(self, msg, /) -> None:
|
||||
if type(msg) is bytes:
|
||||
self._buffer.extend(msg)
|
||||
|
||||
@@ -478,9 +479,11 @@ class DiscordWebSocket:
|
||||
msg = self._zlib.decompress(self._buffer)
|
||||
msg = msg.decode('utf-8')
|
||||
self._buffer = bytearray()
|
||||
msg = utils.from_json(msg)
|
||||
|
||||
log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg)
|
||||
self.log_receive(msg)
|
||||
msg = utils._from_json(msg)
|
||||
|
||||
_log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg)
|
||||
event = msg.get('t')
|
||||
if event:
|
||||
self._dispatch('socket_event_type', event)
|
||||
@@ -499,7 +502,7 @@ class DiscordWebSocket:
|
||||
# "reconnect" can only be handled by the Client
|
||||
# so we terminate our connection and raise an
|
||||
# internal exception signalling to reconnect.
|
||||
log.debug('Received RECONNECT opcode.')
|
||||
_log.debug('Received RECONNECT opcode.')
|
||||
await self.close()
|
||||
raise ReconnectWebSocket(self.shard_id)
|
||||
|
||||
@@ -529,11 +532,11 @@ class DiscordWebSocket:
|
||||
|
||||
self.sequence = None
|
||||
self.session_id = None
|
||||
log.info('Shard ID %s session has been invalidated.', self.shard_id)
|
||||
_log.info('Shard ID %s session has been invalidated.', self.shard_id)
|
||||
await self.close(code=1000)
|
||||
raise ReconnectWebSocket(self.shard_id, resume=False)
|
||||
|
||||
log.warning('Unknown OP code %s.', op)
|
||||
_log.warning('Unknown OP code %s.', op)
|
||||
return
|
||||
|
||||
if event == 'READY':
|
||||
@@ -542,20 +545,20 @@ class DiscordWebSocket:
|
||||
self.session_id = data['session_id']
|
||||
# pass back shard ID to ready handler
|
||||
data['__shard_id__'] = self.shard_id
|
||||
log.info('Shard ID %s has connected to Gateway: %s (Session ID: %s).',
|
||||
_log.info('Shard ID %s has connected to Gateway: %s (Session ID: %s).',
|
||||
self.shard_id, ', '.join(trace), self.session_id)
|
||||
|
||||
elif event == 'RESUMED':
|
||||
self._trace = trace = data.get('_trace', [])
|
||||
# pass back the shard ID to the resumed handler
|
||||
data['__shard_id__'] = self.shard_id
|
||||
log.info('Shard ID %s has successfully RESUMED session %s under trace %s.',
|
||||
_log.info('Shard ID %s has successfully RESUMED session %s under trace %s.',
|
||||
self.shard_id, self.session_id, ', '.join(trace))
|
||||
|
||||
try:
|
||||
func = self._discord_parsers[event]
|
||||
except KeyError:
|
||||
log.debug('Unknown event %s.', event)
|
||||
_log.debug('Unknown event %s.', event)
|
||||
else:
|
||||
func(data)
|
||||
|
||||
@@ -609,10 +612,10 @@ class DiscordWebSocket:
|
||||
elif msg.type is aiohttp.WSMsgType.BINARY:
|
||||
await self.received_message(msg.data)
|
||||
elif msg.type is aiohttp.WSMsgType.ERROR:
|
||||
log.debug('Received %s', msg)
|
||||
_log.debug('Received %s', msg)
|
||||
raise msg.data
|
||||
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE):
|
||||
log.debug('Received %s', msg)
|
||||
_log.debug('Received %s', msg)
|
||||
raise WebSocketClosure
|
||||
except (asyncio.TimeoutError, WebSocketClosure) as e:
|
||||
# Ensure the keep alive handler is closed
|
||||
@@ -621,15 +624,15 @@ class DiscordWebSocket:
|
||||
self._keep_alive = None
|
||||
|
||||
if isinstance(e, asyncio.TimeoutError):
|
||||
log.info('Timed out receiving packet. Attempting a reconnect.')
|
||||
_log.info('Timed out receiving packet. Attempting a reconnect.')
|
||||
raise ReconnectWebSocket(self.shard_id) from None
|
||||
|
||||
code = self._close_code or self.socket.close_code
|
||||
if self._can_handle_close():
|
||||
log.info('Websocket closed with %s, attempting a reconnect.', code)
|
||||
_log.info('Websocket closed with %s, attempting a reconnect.', code)
|
||||
raise ReconnectWebSocket(self.shard_id) from None
|
||||
else:
|
||||
log.info('Websocket closed with %s, cannot reconnect.', code)
|
||||
_log.info('Websocket closed with %s, cannot reconnect.', code)
|
||||
raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None
|
||||
|
||||
async def debug_send(self, data, /) -> None:
|
||||
@@ -643,7 +646,7 @@ class DiscordWebSocket:
|
||||
|
||||
async def send_as_json(self, data) -> None:
|
||||
try:
|
||||
await self.send(utils.to_json(data))
|
||||
await self.send(utils._to_json(data))
|
||||
except RuntimeError as exc:
|
||||
if not self._can_handle_close():
|
||||
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
|
||||
@@ -651,7 +654,7 @@ class DiscordWebSocket:
|
||||
async def send_heartbeat(self, data: Heartbeat) -> None:
|
||||
# This bypasses the rate limit handling code since it has a higher priority
|
||||
try:
|
||||
await self.socket.send_str(utils.to_json(data))
|
||||
await self.socket.send_str(utils._to_json(data))
|
||||
except RuntimeError as exc:
|
||||
if not self._can_handle_close():
|
||||
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
|
||||
@@ -677,8 +680,8 @@ class DiscordWebSocket:
|
||||
}
|
||||
}
|
||||
|
||||
sent = utils.to_json(payload)
|
||||
log.debug('Sending "%s" to change status', sent)
|
||||
sent = utils._to_json(payload)
|
||||
_log.debug('Sending "%s" to change status', sent)
|
||||
await self.send(sent)
|
||||
|
||||
async def request_chunks(self, guild_id: int, query: Optional[str] = None, *, limit: int, user_ids: Optional[List[int]] = None, presences: bool = False, nonce: Optional[int] = None) -> None:
|
||||
@@ -714,7 +717,7 @@ class DiscordWebSocket:
|
||||
}
|
||||
}
|
||||
|
||||
log.debug('Updating our voice state to %s.', payload)
|
||||
_log.debug('Updating our voice state to %s.', payload)
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def close(self, code: int = 4000) -> None:
|
||||
@@ -786,9 +789,10 @@ class DiscordVoiceWebSocket:
|
||||
async def _hook(self, *args: Any) -> Any:
|
||||
pass
|
||||
|
||||
|
||||
async def send_as_json(self, data) -> None:
|
||||
log.debug('Sending voice websocket frame: %s.', data)
|
||||
await self.ws.send_str(utils.to_json(data))
|
||||
_log.debug('Sending voice websocket frame: %s.', data)
|
||||
await self.ws.send_str(utils._to_json(data))
|
||||
|
||||
send_heartbeat = send_as_json
|
||||
|
||||
@@ -872,8 +876,9 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def received_message(self, msg) -> None:
|
||||
log.debug('Voice websocket frame received: %s', msg)
|
||||
|
||||
async def received_message(self, msg) -> None:
|
||||
_log.debug('Voice websocket frame received: %s', msg)
|
||||
op = msg['op']
|
||||
data = msg.get('d')
|
||||
|
||||
@@ -882,7 +887,7 @@ class DiscordVoiceWebSocket:
|
||||
elif op == self.HEARTBEAT_ACK:
|
||||
self._keep_alive.ack()
|
||||
elif op == self.RESUMED:
|
||||
log.info('Voice RESUME succeeded.')
|
||||
_log.info('Voice RESUME succeeded.')
|
||||
elif op == self.SESSION_DESCRIPTION:
|
||||
self._connection.mode = data['mode']
|
||||
await self.load_secret_key(data)
|
||||
@@ -905,7 +910,7 @@ class DiscordVoiceWebSocket:
|
||||
struct.pack_into('>I', packet, 4, state.ssrc)
|
||||
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
|
||||
recv = await self.loop.sock_recv(state.socket, 70)
|
||||
log.debug('received packet in initial_connection: %s', recv)
|
||||
_log.debug('received packet in initial_connection: %s', recv)
|
||||
|
||||
# the ip is ascii starting at the 4th byte and ending at the first null
|
||||
ip_start = 4
|
||||
@@ -913,15 +918,15 @@ class DiscordVoiceWebSocket:
|
||||
state.ip = recv[ip_start:ip_end].decode('ascii')
|
||||
|
||||
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
|
||||
log.debug('detected ip: %s port: %s', state.ip, state.port)
|
||||
_log.debug('detected ip: %s port: %s', state.ip, state.port)
|
||||
|
||||
# there *should* always be at least one supported mode (xsalsa20_poly1305)
|
||||
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
|
||||
log.debug('received supported encryption modes: %s', ", ".join(modes))
|
||||
_log.debug('received supported encryption modes: %s', ", ".join(modes))
|
||||
|
||||
mode = modes[0]
|
||||
await self.select_protocol(state.ip, state.port, mode)
|
||||
log.info('selected the voice protocol for use (%s)', mode)
|
||||
_log.info('selected the voice protocol for use (%s)', mode)
|
||||
|
||||
@property
|
||||
def latency(self) -> float:
|
||||
@@ -938,8 +943,9 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
|
||||
|
||||
|
||||
async def load_secret_key(self, data) -> None:
|
||||
log.info('received secret key for voice connection')
|
||||
_log.info('received secret key for voice connection')
|
||||
self.secret_key = self._connection.secret_key = data.get('secret_key')
|
||||
await self.speak()
|
||||
await self.speak(False)
|
||||
@@ -948,12 +954,12 @@ class DiscordVoiceWebSocket:
|
||||
# This exception is handled up the chain
|
||||
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
|
||||
if msg.type is aiohttp.WSMsgType.TEXT:
|
||||
await self.received_message(utils.from_json(msg.data))
|
||||
await self.received_message(utils._from_json(msg.data))
|
||||
elif msg.type is aiohttp.WSMsgType.ERROR:
|
||||
log.debug('Received %s', msg)
|
||||
_log.debug('Received %s', msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
|
||||
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
|
||||
log.debug('Received %s', msg)
|
||||
_log.debug('Received %s', msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
|
||||
|
||||
async def close(self, code: int = 1000) -> None:
|
||||
|
||||
@@ -140,6 +140,10 @@ class Guild(Hashable):
|
||||
|
||||
Returns the guild's name.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the guild's ID.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name: :class:`str`
|
||||
@@ -717,7 +721,7 @@ class Guild(Hashable):
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
more_stickers = 60 if 'MORE_STICKERS' in self.features else 15
|
||||
more_stickers = 60 if 'MORE_STICKERS' in self.features else 0
|
||||
return max(more_stickers, self._PREMIUM_GUILD_LIMITS[self.premium_tier].stickers)
|
||||
|
||||
@property
|
||||
@@ -736,7 +740,21 @@ class Guild(Hashable):
|
||||
"""List[:class:`Member`]: A list of members that belong to this guild."""
|
||||
return list(self._members.values())
|
||||
|
||||
def get_member(self, user_id: int) -> Optional[Member]:
|
||||
@property
|
||||
def humans(self) -> List[Member]:
|
||||
"""List[:class:`Member`]: A list of human members that belong to this guild.
|
||||
|
||||
.. versionadded:: 2.0 """
|
||||
return [member for member in self.members if not member.bot]
|
||||
|
||||
@property
|
||||
def bots(self) -> List[Member]:
|
||||
"""List[:class:`Member`]: A list of bots that belong to this guild.
|
||||
|
||||
.. versionadded:: 2.0 """
|
||||
return [member for member in self.members if member.bot]
|
||||
|
||||
def get_member(self, user_id: int, /) -> Optional[Member]:
|
||||
"""Returns a member with the given ID.
|
||||
|
||||
Parameters
|
||||
@@ -1356,7 +1374,7 @@ class Guild(Hashable):
|
||||
preferred_locale: str = MISSING,
|
||||
rules_channel: Optional[TextChannel] = MISSING,
|
||||
public_updates_channel: Optional[TextChannel] = MISSING,
|
||||
) -> None:
|
||||
) -> Guild:
|
||||
r"""|coro|
|
||||
|
||||
Edits the guild.
|
||||
@@ -1370,6 +1388,9 @@ class Guild(Hashable):
|
||||
.. versionchanged:: 2.0
|
||||
The `discovery_splash` and `community` keyword-only parameters were added.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The newly updated guild is returned.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: :class:`str`
|
||||
@@ -1443,6 +1464,12 @@ class Guild(Hashable):
|
||||
The image format passed in to ``icon`` is invalid. It must be
|
||||
PNG or JPG. This is also raised if you are not the owner of the
|
||||
guild and request an ownership transfer.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Guild`
|
||||
The newly updated guild. Note that this has the same limitations as
|
||||
mentioned in :meth:`Client.fetch_guild` and may not have full data.
|
||||
"""
|
||||
|
||||
http = self._state.http
|
||||
@@ -1555,7 +1582,8 @@ class Guild(Hashable):
|
||||
|
||||
fields['features'] = features
|
||||
|
||||
await http.edit_guild(self.id, reason=reason, **fields)
|
||||
data = await http.edit_guild(self.id, reason=reason, **fields)
|
||||
return Guild(data=data, state=self._state)
|
||||
|
||||
async def fetch_channels(self) -> Sequence[GuildChannel]:
|
||||
"""|coro|
|
||||
|
||||
159
discord/http.py
159
discord/http.py
@@ -53,7 +53,7 @@ from .gateway import DiscordClientWebSocketResponse
|
||||
from . import __version__, utils
|
||||
from .utils import MISSING
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .file import File
|
||||
@@ -99,7 +99,7 @@ async def json_or_text(response: aiohttp.ClientResponse) -> Union[Dict[str, Any]
|
||||
text = await response.text(encoding='utf-8')
|
||||
try:
|
||||
if response.headers['content-type'] == 'application/json':
|
||||
return utils.from_json(text)
|
||||
return utils._from_json(text)
|
||||
except KeyError:
|
||||
# Thanks Cloudflare
|
||||
pass
|
||||
@@ -141,7 +141,8 @@ class MaybeUnlock:
|
||||
def defer(self) -> None:
|
||||
self._unlock = False
|
||||
|
||||
def __exit__(self,
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BE]],
|
||||
exc: Optional[BE],
|
||||
traceback: Optional[TracebackType],
|
||||
@@ -152,7 +153,7 @@ class MaybeUnlock:
|
||||
|
||||
# For some reason, the Discord voice websocket expects this header to be
|
||||
# completely lowercase while aiohttp respects spec and does it as case-insensitive
|
||||
aiohttp.hdrs.WEBSOCKET = 'websocket' #type: ignore
|
||||
aiohttp.hdrs.WEBSOCKET = 'websocket' # type: ignore
|
||||
|
||||
|
||||
class HTTPClient:
|
||||
@@ -165,7 +166,7 @@ class HTTPClient:
|
||||
proxy: Optional[str] = None,
|
||||
proxy_auth: Optional[aiohttp.BasicAuth] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
unsync_clock: bool = True
|
||||
unsync_clock: bool = True,
|
||||
) -> None:
|
||||
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
|
||||
self.connector = connector
|
||||
@@ -209,7 +210,7 @@ class HTTPClient:
|
||||
*,
|
||||
files: Optional[Sequence[File]] = None,
|
||||
form: Optional[Iterable[Dict[str, Any]]] = None,
|
||||
**kwargs: Any
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
bucket = route.bucket
|
||||
method = route.method
|
||||
@@ -231,7 +232,7 @@ class HTTPClient:
|
||||
# some checking if it's a JSON request
|
||||
if 'json' in kwargs:
|
||||
headers['Content-Type'] = 'application/json'
|
||||
kwargs['data'] = utils.to_json(kwargs.pop('json'))
|
||||
kwargs['data'] = utils._to_json(kwargs.pop('json'))
|
||||
|
||||
try:
|
||||
reason = kwargs.pop('reason')
|
||||
@@ -270,7 +271,7 @@ class HTTPClient:
|
||||
|
||||
try:
|
||||
async with self.__session.request(method, url, **kwargs) as response:
|
||||
log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), response.status)
|
||||
_log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), response.status)
|
||||
|
||||
# even errors have text involved in them so this is safe to call
|
||||
data = await json_or_text(response)
|
||||
@@ -280,13 +281,13 @@ class HTTPClient:
|
||||
if remaining == '0' and response.status != 429:
|
||||
# we've depleted our current bucket
|
||||
delta = utils._parse_ratelimit_header(response, use_clock=self.use_clock)
|
||||
log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta)
|
||||
_log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta)
|
||||
maybe_lock.defer()
|
||||
self.loop.call_later(delta, lock.release)
|
||||
|
||||
# the request was successful so just return the text/json
|
||||
if 300 > response.status >= 200:
|
||||
log.debug('%s %s has received %s', method, url, data)
|
||||
_log.debug('%s %s has received %s', method, url, data)
|
||||
return data
|
||||
|
||||
# we are being rate limited
|
||||
@@ -299,22 +300,22 @@ class HTTPClient:
|
||||
|
||||
# sleep a bit
|
||||
retry_after: float = data['retry_after']
|
||||
log.warning(fmt, retry_after, bucket)
|
||||
_log.warning(fmt, retry_after, bucket)
|
||||
|
||||
# check if it's a global rate limit
|
||||
is_global = data.get('global', False)
|
||||
if is_global:
|
||||
log.warning('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after)
|
||||
_log.warning('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after)
|
||||
self._global_over.clear()
|
||||
|
||||
await asyncio.sleep(retry_after)
|
||||
log.debug('Done sleeping for the rate limit. Retrying...')
|
||||
_log.debug('Done sleeping for the rate limit. Retrying...')
|
||||
|
||||
# release the global lock now that the
|
||||
# global rate limit has passed
|
||||
if is_global:
|
||||
self._global_over.set()
|
||||
log.debug('Global rate limit is now over.')
|
||||
_log.debug('Global rate limit is now over.')
|
||||
|
||||
continue
|
||||
|
||||
@@ -412,7 +413,7 @@ class HTTPClient:
|
||||
def send_message(
|
||||
self,
|
||||
channel_id: Snowflake,
|
||||
content: str,
|
||||
content: Optional[str],
|
||||
*,
|
||||
tts: bool = False,
|
||||
embed: Optional[embed.Embed] = None,
|
||||
@@ -493,7 +494,7 @@ class HTTPClient:
|
||||
if stickers:
|
||||
payload['sticker_ids'] = stickers
|
||||
|
||||
form.append({'name': 'payload_json', 'value': utils.to_json(payload)})
|
||||
form.append({'name': 'payload_json', 'value': utils._to_json(payload)})
|
||||
if len(files) == 1:
|
||||
file = files[0]
|
||||
form.append(
|
||||
@@ -547,11 +548,15 @@ class HTTPClient:
|
||||
components=components,
|
||||
)
|
||||
|
||||
def delete_message(self, channel_id: Snowflake, message_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]:
|
||||
def delete_message(
|
||||
self, channel_id: Snowflake, message_id: Snowflake, *, reason: Optional[str] = None
|
||||
) -> Response[None]:
|
||||
r = Route('DELETE', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id)
|
||||
return self.request(r, reason=reason)
|
||||
|
||||
def delete_messages(self, channel_id: Snowflake, message_ids: SnowflakeList, *, reason: Optional[str] = None) -> Response[None]:
|
||||
def delete_messages(
|
||||
self, channel_id: Snowflake, message_ids: SnowflakeList, *, reason: Optional[str] = None
|
||||
) -> Response[None]:
|
||||
r = Route('POST', '/channels/{channel_id}/messages/bulk-delete', channel_id=channel_id)
|
||||
payload = {
|
||||
'messages': message_ids,
|
||||
@@ -573,7 +578,9 @@ class HTTPClient:
|
||||
)
|
||||
return self.request(r)
|
||||
|
||||
def remove_reaction(self, channel_id: Snowflake, message_id: Snowflake, emoji: str, member_id: Snowflake) -> Response[None]:
|
||||
def remove_reaction(
|
||||
self, channel_id: Snowflake, message_id: Snowflake, emoji: str, member_id: Snowflake
|
||||
) -> Response[None]:
|
||||
r = Route(
|
||||
'DELETE',
|
||||
'/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/{member_id}',
|
||||
@@ -774,11 +781,11 @@ class HTTPClient:
|
||||
}
|
||||
return self.request(r, json=payload, reason=reason)
|
||||
|
||||
def edit_my_voice_state(self, guild_id: Snowflake, payload: voice.VoiceState) -> Response[None]:
|
||||
def edit_my_voice_state(self, guild_id: Snowflake, payload: Dict[str, Any]) -> Response[None]:
|
||||
r = Route('PATCH', '/guilds/{guild_id}/voice-states/@me', guild_id=guild_id)
|
||||
return self.request(r, json=payload)
|
||||
|
||||
def edit_voice_state(self, guild_id: Snowflake, user_id: Snowflake, payload: voice.VoiceState) -> Response[None]:
|
||||
def edit_voice_state(self, guild_id: Snowflake, user_id: Snowflake, payload: Dict[str, Any]) -> Response[None]:
|
||||
r = Route('PATCH', '/guilds/{guild_id}/voice-states/{user_id}', guild_id=guild_id, user_id=user_id)
|
||||
return self.request(r, json=payload)
|
||||
|
||||
@@ -789,7 +796,7 @@ class HTTPClient:
|
||||
*,
|
||||
reason: Optional[str] = None,
|
||||
**fields: Any,
|
||||
) -> Response[member.Member]:
|
||||
) -> Response[member.MemberWithUser]:
|
||||
r = Route('PATCH', '/guilds/{guild_id}/members/{user_id}', guild_id=guild_id, user_id=user_id)
|
||||
return self.request(r, json=fields, reason=reason)
|
||||
|
||||
@@ -819,6 +826,7 @@ class HTTPClient:
|
||||
'archived',
|
||||
'auto_archive_duration',
|
||||
'locked',
|
||||
'invitable',
|
||||
'default_auto_archive_duration',
|
||||
)
|
||||
payload = {k: v for k, v in options.items() if k in valid_keys}
|
||||
@@ -900,12 +908,14 @@ class HTTPClient:
|
||||
name: str,
|
||||
auto_archive_duration: threads.ThreadArchiveDuration,
|
||||
type: threads.ThreadType,
|
||||
invitable: bool = True,
|
||||
reason: Optional[str] = None,
|
||||
) -> Response[threads.Thread]:
|
||||
payload = {
|
||||
'name': name,
|
||||
'auto_archive_duration': auto_archive_duration,
|
||||
'type': type,
|
||||
'invitable': invitable,
|
||||
}
|
||||
|
||||
route = Route('POST', '/channels/{channel_id}/threads', channel_id=channel_id)
|
||||
@@ -1122,7 +1132,9 @@ class HTTPClient:
|
||||
def get_all_guild_channels(self, guild_id: Snowflake) -> Response[List[guild.GuildChannel]]:
|
||||
return self.request(Route('GET', '/guilds/{guild_id}/channels', guild_id=guild_id))
|
||||
|
||||
def get_members(self, guild_id: Snowflake, limit: int, after: Optional[Snowflake]) -> Response[List[member.Member]]:
|
||||
def get_members(
|
||||
self, guild_id: Snowflake, limit: int, after: Optional[Snowflake]
|
||||
) -> Response[List[member.MemberWithUser]]:
|
||||
params: Dict[str, Any] = {
|
||||
'limit': limit,
|
||||
}
|
||||
@@ -1132,7 +1144,7 @@ class HTTPClient:
|
||||
r = Route('GET', '/guilds/{guild_id}/members', guild_id=guild_id)
|
||||
return self.request(r, params=params)
|
||||
|
||||
def get_member(self, guild_id: Snowflake, member_id: Snowflake) -> Response[member.Member]:
|
||||
def get_member(self, guild_id: Snowflake, member_id: Snowflake) -> Response[member.MemberWithUser]:
|
||||
return self.request(Route('GET', '/guilds/{guild_id}/members/{member_id}', guild_id=guild_id, member_id=member_id))
|
||||
|
||||
def prune_members(
|
||||
@@ -1177,9 +1189,13 @@ class HTTPClient:
|
||||
return self.request(Route('GET', '/guilds/{guild_id}/stickers', guild_id=guild_id))
|
||||
|
||||
def get_guild_sticker(self, guild_id: Snowflake, sticker_id: Snowflake) -> Response[sticker.GuildSticker]:
|
||||
return self.request(Route('GET', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id))
|
||||
return self.request(
|
||||
Route('GET', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id)
|
||||
)
|
||||
|
||||
def create_guild_sticker(self, guild_id: Snowflake, payload: sticker.CreateGuildSticker, file: File, reason: str) -> Response[sticker.GuildSticker]:
|
||||
def create_guild_sticker(
|
||||
self, guild_id: Snowflake, payload: sticker.CreateGuildSticker, file: File, reason: str
|
||||
) -> Response[sticker.GuildSticker]:
|
||||
initial_bytes = file.fp.read(16)
|
||||
|
||||
try:
|
||||
@@ -1202,18 +1218,31 @@ class HTTPClient:
|
||||
]
|
||||
|
||||
for k, v in payload.items():
|
||||
form.append({
|
||||
'name': k,
|
||||
'value': v,
|
||||
})
|
||||
form.append(
|
||||
{
|
||||
'name': k,
|
||||
'value': v,
|
||||
}
|
||||
)
|
||||
|
||||
return self.request(Route('POST', '/guilds/{guild_id}/stickers', guild_id=guild_id), form=form, files=[file], reason=reason)
|
||||
return self.request(
|
||||
Route('POST', '/guilds/{guild_id}/stickers', guild_id=guild_id), form=form, files=[file], reason=reason
|
||||
)
|
||||
|
||||
def modify_guild_sticker(self, guild_id: Snowflake, sticker_id: Snowflake, payload: sticker.EditGuildSticker, reason: str) -> Response[sticker.GuildSticker]:
|
||||
return self.request(Route('PATCH', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id), json=payload, reason=reason)
|
||||
def modify_guild_sticker(
|
||||
self, guild_id: Snowflake, sticker_id: Snowflake, payload: sticker.EditGuildSticker, reason: Optional[str],
|
||||
) -> Response[sticker.GuildSticker]:
|
||||
return self.request(
|
||||
Route('PATCH', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id),
|
||||
json=payload,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
def delete_guild_sticker(self, guild_id: Snowflake, sticker_id: Snowflake, reason: str) -> Response[None]:
|
||||
return self.request(Route('DELETE', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id), reason=reason)
|
||||
def delete_guild_sticker(self, guild_id: Snowflake, sticker_id: Snowflake, reason: Optional[str]) -> Response[None]:
|
||||
return self.request(
|
||||
Route('DELETE', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id),
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
def get_all_custom_emojis(self, guild_id: Snowflake) -> Response[List[emoji.Emoji]]:
|
||||
return self.request(Route('GET', '/guilds/{guild_id}/emojis', guild_id=guild_id))
|
||||
@@ -1288,7 +1317,9 @@ class HTTPClient:
|
||||
|
||||
return self.request(r)
|
||||
|
||||
def delete_integration(self, guild_id: Snowflake, integration_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]:
|
||||
def delete_integration(
|
||||
self, guild_id: Snowflake, integration_id: Snowflake, *, reason: Optional[str] = None
|
||||
) -> Response[None]:
|
||||
r = Route(
|
||||
'DELETE', '/guilds/{guild_id}/integrations/{integration_id}', guild_id=guild_id, integration_id=integration_id
|
||||
)
|
||||
@@ -1336,7 +1367,7 @@ class HTTPClient:
|
||||
unique: bool = True,
|
||||
target_type: Optional[invite.InviteTargetType] = None,
|
||||
target_user_id: Optional[Snowflake] = None,
|
||||
target_application_id: Optional[Snowflake] = None
|
||||
target_application_id: Optional[Snowflake] = None,
|
||||
) -> Response[invite.Invite]:
|
||||
r = Route('POST', '/channels/{channel_id}/invites', channel_id=channel_id)
|
||||
payload = {
|
||||
@@ -1357,7 +1388,9 @@ class HTTPClient:
|
||||
|
||||
return self.request(r, reason=reason, json=payload)
|
||||
|
||||
def get_invite(self, invite_id: str, *, with_counts: bool = True, with_expiration: bool = True) -> Response[invite.Invite]:
|
||||
def get_invite(
|
||||
self, invite_id: str, *, with_counts: bool = True, with_expiration: bool = True
|
||||
) -> Response[invite.Invite]:
|
||||
params = {
|
||||
'with_counts': int(with_counts),
|
||||
'with_expiration': int(with_expiration),
|
||||
@@ -1378,7 +1411,9 @@ class HTTPClient:
|
||||
def get_roles(self, guild_id: Snowflake) -> Response[List[role.Role]]:
|
||||
return self.request(Route('GET', '/guilds/{guild_id}/roles', guild_id=guild_id))
|
||||
|
||||
def edit_role(self, guild_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None, **fields: Any) -> Response[role.Role]:
|
||||
def edit_role(
|
||||
self, guild_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None, **fields: Any
|
||||
) -> Response[role.Role]:
|
||||
r = Route('PATCH', '/guilds/{guild_id}/roles/{role_id}', guild_id=guild_id, role_id=role_id)
|
||||
valid_keys = ('name', 'permissions', 'color', 'hoist', 'mentionable')
|
||||
payload = {k: v for k, v in fields.items() if k in valid_keys}
|
||||
@@ -1395,7 +1430,7 @@ class HTTPClient:
|
||||
role_ids: List[int],
|
||||
*,
|
||||
reason: Optional[str] = None,
|
||||
) -> Response[member.Member]:
|
||||
) -> Response[member.MemberWithUser]:
|
||||
return self.edit_member(guild_id=guild_id, user_id=user_id, roles=role_ids, reason=reason)
|
||||
|
||||
def create_role(self, guild_id: Snowflake, *, reason: Optional[str] = None, **fields: Any) -> Response[role.Role]:
|
||||
@@ -1412,7 +1447,9 @@ class HTTPClient:
|
||||
r = Route('PATCH', '/guilds/{guild_id}/roles', guild_id=guild_id)
|
||||
return self.request(r, json=positions, reason=reason)
|
||||
|
||||
def add_role(self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]:
|
||||
def add_role(
|
||||
self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None
|
||||
) -> Response[None]:
|
||||
r = Route(
|
||||
'PUT',
|
||||
'/guilds/{guild_id}/members/{user_id}/roles/{role_id}',
|
||||
@@ -1422,7 +1459,9 @@ class HTTPClient:
|
||||
)
|
||||
return self.request(r, reason=reason)
|
||||
|
||||
def remove_role(self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]:
|
||||
def remove_role(
|
||||
self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None
|
||||
) -> Response[None]:
|
||||
r = Route(
|
||||
'DELETE',
|
||||
'/guilds/{guild_id}/members/{user_id}/roles/{role_id}',
|
||||
@@ -1447,11 +1486,7 @@ class HTTPClient:
|
||||
return self.request(r, json=payload, reason=reason)
|
||||
|
||||
def delete_channel_permissions(
|
||||
self,
|
||||
channel_id: Snowflake,
|
||||
target: channel.OverwriteType,
|
||||
*,
|
||||
reason: Optional[str] = None
|
||||
self, channel_id: Snowflake, target: channel.OverwriteType, *, reason: Optional[str] = None
|
||||
) -> Response[None]:
|
||||
r = Route('DELETE', '/channels/{channel_id}/permissions/{target}', channel_id=channel_id, target=target)
|
||||
return self.request(r, reason=reason)
|
||||
@@ -1465,7 +1500,7 @@ class HTTPClient:
|
||||
channel_id: Snowflake,
|
||||
*,
|
||||
reason: Optional[str] = None,
|
||||
) -> Response[member.Member]:
|
||||
) -> Response[member.MemberWithUser]:
|
||||
return self.edit_member(guild_id=guild_id, user_id=user_id, channel_id=channel_id, reason=reason)
|
||||
|
||||
# Stage instance management
|
||||
@@ -1490,7 +1525,9 @@ class HTTPClient:
|
||||
)
|
||||
payload = {k: v for k, v in payload.items() if k in valid_keys}
|
||||
|
||||
return self.request(Route('PATCH', '/stage-instances/{channel_id}', channel_id=channel_id), json=payload, reason=reason)
|
||||
return self.request(
|
||||
Route('PATCH', '/stage-instances/{channel_id}', channel_id=channel_id), json=payload, reason=reason
|
||||
)
|
||||
|
||||
def delete_stage_instance(self, channel_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]:
|
||||
return self.request(Route('DELETE', '/stage-instances/{channel_id}', channel_id=channel_id), reason=reason)
|
||||
@@ -1500,7 +1537,9 @@ class HTTPClient:
|
||||
def get_global_commands(self, application_id: Snowflake) -> Response[List[interactions.ApplicationCommand]]:
|
||||
return self.request(Route('GET', '/applications/{application_id}/commands', application_id=application_id))
|
||||
|
||||
def get_global_command(self, application_id: Snowflake, command_id: Snowflake) -> Response[interactions.ApplicationCommand]:
|
||||
def get_global_command(
|
||||
self, application_id: Snowflake, command_id: Snowflake
|
||||
) -> Response[interactions.ApplicationCommand]:
|
||||
r = Route(
|
||||
'GET',
|
||||
'/applications/{application_id}/commands/{command_id}',
|
||||
@@ -1513,7 +1552,8 @@ class HTTPClient:
|
||||
r = Route('POST', '/applications/{application_id}/commands', application_id=application_id)
|
||||
return self.request(r, json=payload)
|
||||
|
||||
def edit_global_command(self,
|
||||
def edit_global_command(
|
||||
self,
|
||||
application_id: Snowflake,
|
||||
command_id: Snowflake,
|
||||
payload: interactions.EditApplicationCommand,
|
||||
@@ -1541,13 +1581,17 @@ class HTTPClient:
|
||||
)
|
||||
return self.request(r)
|
||||
|
||||
def bulk_upsert_global_commands(self, application_id: Snowflake, payload) -> Response[List[interactions.ApplicationCommand]]:
|
||||
def bulk_upsert_global_commands(
|
||||
self, application_id: Snowflake, payload
|
||||
) -> Response[List[interactions.ApplicationCommand]]:
|
||||
r = Route('PUT', '/applications/{application_id}/commands', application_id=application_id)
|
||||
return self.request(r, json=payload)
|
||||
|
||||
# Application commands (guild)
|
||||
|
||||
def get_guild_commands(self, application_id: Snowflake, guild_id: Snowflake) -> Response[List[interactions.ApplicationCommand]]:
|
||||
def get_guild_commands(
|
||||
self, application_id: Snowflake, guild_id: Snowflake
|
||||
) -> Response[List[interactions.ApplicationCommand]]:
|
||||
r = Route(
|
||||
'GET',
|
||||
'/applications/{application_id}/guilds/{guild_id}/commands',
|
||||
@@ -1585,7 +1629,8 @@ class HTTPClient:
|
||||
)
|
||||
return self.request(r, json=payload)
|
||||
|
||||
def edit_guild_command(self,
|
||||
def edit_guild_command(
|
||||
self,
|
||||
application_id: Snowflake,
|
||||
guild_id: Snowflake,
|
||||
command_id: Snowflake,
|
||||
@@ -1657,7 +1702,7 @@ class HTTPClient:
|
||||
form: List[Dict[str, Any]] = [
|
||||
{
|
||||
'name': 'payload_json',
|
||||
'value': utils.to_json(payload),
|
||||
'value': utils._to_json(payload),
|
||||
}
|
||||
]
|
||||
|
||||
@@ -1679,7 +1724,7 @@ class HTTPClient:
|
||||
token: str,
|
||||
*,
|
||||
type: InteractionResponseType,
|
||||
data: Optional[interactions.InteractionApplicationCommandCallbackData] = None
|
||||
data: Optional[interactions.InteractionApplicationCommandCallbackData] = None,
|
||||
) -> Response[None]:
|
||||
r = Route(
|
||||
'POST',
|
||||
@@ -1769,7 +1814,7 @@ class HTTPClient:
|
||||
content: Optional[str] = None,
|
||||
embeds: Optional[List[embed.Embed]] = None,
|
||||
allowed_mentions: Optional[message.AllowedMentions] = None,
|
||||
)-> Response[message.Message]:
|
||||
) -> Response[message.Message]:
|
||||
r = Route(
|
||||
'PATCH',
|
||||
'/webhooks/{application_id}/{interaction_token}/messages/{message_id}',
|
||||
|
||||
@@ -262,17 +262,10 @@ class StreamIntegration(Integration):
|
||||
if enable_emoticons is not MISSING:
|
||||
payload['enable_emoticons'] = enable_emoticons
|
||||
|
||||
# This endpoint is undocumented.
|
||||
# Unsure if it returns the data or not as a result
|
||||
await self._state.http.edit_integration(self.guild.id, self.id, **payload)
|
||||
|
||||
if expire_behaviour is not MISSING:
|
||||
self.expire_behaviour = expire_behaviour
|
||||
|
||||
if enable_emoticons is not MISSING:
|
||||
self.enable_emoticons = enable_emoticons
|
||||
|
||||
if expire_grace_period is not MISSING:
|
||||
self.expire_grace_period = expire_grace_period
|
||||
|
||||
async def sync(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
|
||||
@@ -261,7 +261,7 @@ class Interaction:
|
||||
files: List[File] = MISSING,
|
||||
view: Optional[View] = MISSING,
|
||||
allowed_mentions: Optional[AllowedMentions] = None,
|
||||
):
|
||||
) -> InteractionMessage:
|
||||
"""|coro|
|
||||
|
||||
Edits the original interaction response message.
|
||||
@@ -302,7 +302,12 @@ class Interaction:
|
||||
TypeError
|
||||
You specified both ``embed`` and ``embeds`` or ``file`` and ``files``
|
||||
ValueError
|
||||
The length of ``embeds`` was invalid
|
||||
The length of ``embeds`` was invalid.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`InteractionMessage`
|
||||
The newly edited message.
|
||||
"""
|
||||
|
||||
previous_mentions: Optional[AllowedMentions] = self._state.allowed_mentions
|
||||
@@ -326,8 +331,11 @@ class Interaction:
|
||||
files=params.files,
|
||||
)
|
||||
|
||||
# The message channel types should always match
|
||||
message = InteractionMessage(state=self._state, channel=self.channel, data=data) # type: ignore
|
||||
if view and not view.is_finished():
|
||||
self._state.store_view(view, int(data['id']))
|
||||
self._state.store_view(view, message.id)
|
||||
return message
|
||||
|
||||
async def delete_original_message(self) -> None:
|
||||
"""|coro|
|
||||
@@ -672,7 +680,7 @@ class InteractionMessage(Message):
|
||||
files: List[File] = MISSING,
|
||||
view: Optional[View] = MISSING,
|
||||
allowed_mentions: Optional[AllowedMentions] = None,
|
||||
):
|
||||
) -> InteractionMessage:
|
||||
"""|coro|
|
||||
|
||||
Edits the message.
|
||||
@@ -707,9 +715,14 @@ class InteractionMessage(Message):
|
||||
TypeError
|
||||
You specified both ``embed`` and ``embeds`` or ``file`` and ``files``
|
||||
ValueError
|
||||
The length of ``embeds`` was invalid
|
||||
The length of ``embeds`` was invalid.
|
||||
|
||||
Returns
|
||||
---------
|
||||
:class:`InteractionMessage`
|
||||
The newly edited message.
|
||||
"""
|
||||
await self._state._interaction.edit_original_message(
|
||||
return await self._state._interaction.edit_original_message(
|
||||
content=content,
|
||||
embeds=embeds,
|
||||
embed=embed,
|
||||
|
||||
@@ -230,6 +230,7 @@ class Invite(Hashable):
|
||||
|
||||
Returns the invite URL.
|
||||
|
||||
|
||||
The following table illustrates what methods will obtain the attributes:
|
||||
|
||||
+------------------------------------+------------------------------------------------------------+
|
||||
@@ -433,6 +434,9 @@ class Invite(Hashable):
|
||||
def __str__(self) -> str:
|
||||
return self.url
|
||||
|
||||
def __int__(self) -> int:
|
||||
return 0 # To keep the object compatible with the hashable abc.
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<Invite code={self.code!r} guild={self.guild!r} '
|
||||
|
||||
@@ -34,6 +34,7 @@ from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Typ
|
||||
import discord.abc
|
||||
|
||||
from . import utils
|
||||
from .asset import Asset
|
||||
from .utils import MISSING
|
||||
from .user import BaseUser, User, _UserTag
|
||||
from .activity import create_activity, ActivityTypes
|
||||
@@ -48,11 +49,13 @@ __all__ = (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .channel import VoiceChannel, StageChannel
|
||||
from .asset import Asset
|
||||
from .channel import DMChannel, VoiceChannel, StageChannel
|
||||
from .flags import PublicUserFlags
|
||||
from .guild import Guild
|
||||
from .types.activity import PartialPresenceUpdate
|
||||
from .types.member import (
|
||||
GatewayMember as GatewayMemberPayload,
|
||||
MemberWithUser as MemberWithUserPayload,
|
||||
Member as MemberPayload,
|
||||
UserWithMember as UserWithMemberPayload,
|
||||
)
|
||||
@@ -223,6 +226,10 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
|
||||
Returns the member's name with the discriminator.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the user's ID.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
joined_at: Optional[:class:`datetime.datetime`]
|
||||
@@ -261,6 +268,7 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
'_client_status',
|
||||
'_user',
|
||||
'_state',
|
||||
'_avatar',
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -270,14 +278,17 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
bot: bool
|
||||
system: bool
|
||||
created_at: datetime.datetime
|
||||
default_avatar = User.default_avatar
|
||||
avatar = User.avatar
|
||||
dm_channel = User.dm_channel
|
||||
default_avatar: Asset
|
||||
avatar: Optional[Asset]
|
||||
dm_channel: Optional[DMChannel]
|
||||
create_dm = User.create_dm
|
||||
mutual_guilds = User.mutual_guilds
|
||||
public_flags = User.public_flags
|
||||
mutual_guilds: List[Guild]
|
||||
public_flags: PublicUserFlags
|
||||
banner: Optional[Asset]
|
||||
accent_color: Optional[Colour]
|
||||
accent_colour: Optional[Colour]
|
||||
|
||||
def __init__(self, *, data: GatewayMemberPayload, guild: Guild, state: ConnectionState):
|
||||
def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState):
|
||||
self._state: ConnectionState = state
|
||||
self._user: User = state.store_user(data['user'])
|
||||
self.guild: Guild = guild
|
||||
@@ -288,10 +299,14 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
self.activities: Tuple[ActivityTypes, ...] = tuple()
|
||||
self.nick: Optional[str] = data.get('nick', None)
|
||||
self.pending: bool = data.get('pending', False)
|
||||
self._avatar: Optional[str] = data.get('avatar')
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self._user)
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}'
|
||||
@@ -344,6 +359,7 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
self.pending = member.pending
|
||||
self.activities = member.activities
|
||||
self._state = member._state
|
||||
self._avatar = member._avatar
|
||||
|
||||
# Reference will not be copied unless necessary by PRESENCE_UPDATE
|
||||
# See below
|
||||
@@ -369,6 +385,7 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
|
||||
self.premium_since = utils.parse_time(data.get('premium_since'))
|
||||
self._roles = utils.SnowflakeList(map(int, data['roles']))
|
||||
self._avatar = data.get('avatar')
|
||||
|
||||
def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]:
|
||||
self.activities = tuple(map(create_activity, data['activities']))
|
||||
@@ -493,6 +510,29 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
"""
|
||||
return self.nick or self.name
|
||||
|
||||
@property
|
||||
def display_avatar(self) -> Asset:
|
||||
""":class:`Asset`: Returns the member's display avatar.
|
||||
|
||||
For regular members this is just their avatar, but
|
||||
if they have a guild specific avatar then that
|
||||
is returned instead.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self.guild_avatar or self._user.avatar or self._user.default_avatar
|
||||
|
||||
@property
|
||||
def guild_avatar(self) -> Optional[Asset]:
|
||||
"""Optional[:class:`Asset`]: Returns an :class:`Asset` for the guild avatar
|
||||
the member has. If unavailable, ``None`` is returned.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
if self._avatar is None:
|
||||
return None
|
||||
return Asset._from_guild_avatar(self._state, self.guild.id, self.id, self._avatar)
|
||||
|
||||
@property
|
||||
def activity(self) -> Optional[ActivityTypes]:
|
||||
"""Optional[Union[:class:`BaseActivity`, :class:`Spotify`]]: Returns the primary
|
||||
@@ -611,7 +651,7 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
roles: List[discord.abc.Snowflake] = MISSING,
|
||||
voice_channel: Optional[VocalGuildChannel] = MISSING,
|
||||
reason: Optional[str] = None,
|
||||
) -> None:
|
||||
) -> Optional[Member]:
|
||||
"""|coro|
|
||||
|
||||
Edits the member's data.
|
||||
@@ -637,6 +677,9 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
.. versionchanged:: 1.1
|
||||
Can now pass ``None`` to ``voice_channel`` to kick a member from voice.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The newly member is now optionally returned, if applicable.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
nick: Optional[:class:`str`]
|
||||
@@ -664,6 +707,12 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
You do not have the proper permissions to the action requested.
|
||||
HTTPException
|
||||
The operation failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`.Member`]
|
||||
The newly updated member, if applicable. This is only returned
|
||||
when certain fields are updated.
|
||||
"""
|
||||
http = self._state.http
|
||||
guild_id = self.guild.id
|
||||
@@ -706,7 +755,8 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
payload['roles'] = tuple(r.id for r in roles)
|
||||
|
||||
if payload:
|
||||
await http.edit_member(guild_id, self.id, reason=reason, **payload)
|
||||
data = await http.edit_member(guild_id, self.id, reason=reason, **payload)
|
||||
return Member(data=data, guild=self.guild, state=self._state)
|
||||
|
||||
async def request_to_speak(self) -> None:
|
||||
"""|coro|
|
||||
@@ -847,7 +897,7 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
for role in roles:
|
||||
await req(guild_id, user_id, role.id, reason=reason)
|
||||
|
||||
def get_role(self, role_id: int) -> Optional[Role]:
|
||||
def get_role(self, role_id: int, /) -> Optional[Role]:
|
||||
"""Returns a role with the given ID from roles which the member has.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
@@ -29,7 +29,7 @@ import datetime
|
||||
import re
|
||||
import io
|
||||
from os import PathLike
|
||||
from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload
|
||||
from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload, TypeVar, Type
|
||||
|
||||
from . import utils
|
||||
from .reaction import Reaction
|
||||
@@ -76,6 +76,7 @@ if TYPE_CHECKING:
|
||||
from .role import Role
|
||||
from .ui.view import View
|
||||
|
||||
MR = TypeVar('MR', bound='MessageReference')
|
||||
EmojiInputType = Union[Emoji, PartialEmoji, str]
|
||||
|
||||
__all__ = (
|
||||
@@ -124,6 +125,10 @@ class Attachment(Hashable):
|
||||
|
||||
Returns the hash of the attachment.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the attachment's ID.
|
||||
|
||||
.. versionchanged:: 1.7
|
||||
Attachment can now be casted to :class:`str` and is hashable.
|
||||
|
||||
@@ -341,7 +346,8 @@ class DeletedReferencedMessage:
|
||||
@property
|
||||
def id(self) -> int:
|
||||
""":class:`int`: The message ID of the deleted referenced message."""
|
||||
return self._parent.message_id
|
||||
# the parent's message id won't be None here
|
||||
return self._parent.message_id # type: ignore
|
||||
|
||||
@property
|
||||
def channel_id(self) -> int:
|
||||
@@ -393,13 +399,13 @@ class MessageReference:
|
||||
def __init__(self, *, message_id: int, channel_id: int, guild_id: Optional[int] = None, fail_if_not_exists: bool = True):
|
||||
self._state: Optional[ConnectionState] = None
|
||||
self.resolved: Optional[Union[Message, DeletedReferencedMessage]] = None
|
||||
self.message_id: int = message_id
|
||||
self.message_id: Optional[int] = message_id
|
||||
self.channel_id: int = channel_id
|
||||
self.guild_id: Optional[int] = guild_id
|
||||
self.fail_if_not_exists: bool = fail_if_not_exists
|
||||
|
||||
@classmethod
|
||||
def with_state(cls, state: ConnectionState, data: MessageReferencePayload) -> MessageReference:
|
||||
def with_state(cls: Type[MR], state: ConnectionState, data: MessageReferencePayload) -> MR:
|
||||
self = cls.__new__(cls)
|
||||
self.message_id = utils._get_as_snowflake(data, 'message_id')
|
||||
self.channel_id = int(data.pop('channel_id'))
|
||||
@@ -410,7 +416,7 @@ class MessageReference:
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, message: Message, *, fail_if_not_exists: bool = True) -> MessageReference:
|
||||
def from_message(cls: Type[MR], message: Message, *, fail_if_not_exists: bool = True) -> MR:
|
||||
"""Creates a :class:`MessageReference` from an existing :class:`~discord.Message`.
|
||||
|
||||
.. versionadded:: 1.6
|
||||
@@ -457,13 +463,13 @@ class MessageReference:
|
||||
return f'<MessageReference message_id={self.message_id!r} channel_id={self.channel_id!r} guild_id={self.guild_id!r}>'
|
||||
|
||||
def to_dict(self) -> MessageReferencePayload:
|
||||
result = {'message_id': self.message_id} if self.message_id is not None else {}
|
||||
result: MessageReferencePayload = {'message_id': self.message_id} if self.message_id is not None else {}
|
||||
result['channel_id'] = self.channel_id
|
||||
if self.guild_id is not None:
|
||||
result['guild_id'] = self.guild_id
|
||||
if self.fail_if_not_exists is not None:
|
||||
result['fail_if_not_exists'] = self.fail_if_not_exists
|
||||
return result # type: ignore
|
||||
return result
|
||||
|
||||
to_message_reference_dict = to_dict
|
||||
|
||||
@@ -501,6 +507,14 @@ class Message(Hashable):
|
||||
|
||||
Returns the message's hash.
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the message's content.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the message's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
tts: :class:`bool`
|
||||
@@ -637,7 +651,7 @@ class Message(Hashable):
|
||||
_HANDLERS: ClassVar[List[Tuple[str, Callable[..., None]]]]
|
||||
_CACHED_SLOTS: ClassVar[List[str]]
|
||||
guild: Optional[Guild]
|
||||
ref: Optional[MessageReference]
|
||||
reference: Optional[MessageReference]
|
||||
mentions: List[Union[User, Member]]
|
||||
author: Union[User, Member]
|
||||
role_mentions: List[Role]
|
||||
@@ -646,7 +660,7 @@ class Message(Hashable):
|
||||
self,
|
||||
*,
|
||||
state: ConnectionState,
|
||||
channel: Union[TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable],
|
||||
channel: MessageableChannel,
|
||||
data: MessagePayload,
|
||||
):
|
||||
self._state: ConnectionState = state
|
||||
@@ -670,6 +684,7 @@ class Message(Hashable):
|
||||
self.components: List[Component] = [_component_factory(d) for d in data.get('components', [])]
|
||||
|
||||
try:
|
||||
# if the channel doesn't have a guild attribute, we handle that
|
||||
self.guild = channel.guild # type: ignore
|
||||
except AttributeError:
|
||||
self.guild = state._get_guild(utils._get_as_snowflake(data, 'guild_id'))
|
||||
@@ -694,7 +709,8 @@ class Message(Hashable):
|
||||
else:
|
||||
chan, _ = state._get_guild_channel(resolved)
|
||||
|
||||
ref.resolved = self.__class__(channel=chan, data=resolved, state=state)
|
||||
# the channel will be the correct type here
|
||||
ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore
|
||||
|
||||
for handler in ('author', 'member', 'mentions', 'mention_roles'):
|
||||
try:
|
||||
@@ -708,6 +724,10 @@ class Message(Hashable):
|
||||
f'<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>'
|
||||
)
|
||||
|
||||
|
||||
def __str__(self) -> Optional[str]:
|
||||
return self.content
|
||||
|
||||
def _try_patch(self, data, key, transform=None) -> None:
|
||||
try:
|
||||
value = data[key]
|
||||
@@ -1071,6 +1091,7 @@ class Message(Hashable):
|
||||
return f'{self.author.name} has added {self.content} to this channel'
|
||||
|
||||
if self.type is MessageType.guild_stream:
|
||||
# the author will be a Member
|
||||
return f'{self.author.name} is live! Now streaming {self.author.activity.name}' # type: ignore
|
||||
|
||||
if self.type is MessageType.guild_discovery_disqualified:
|
||||
@@ -1095,12 +1116,13 @@ class Message(Hashable):
|
||||
if self.reference is None or self.reference.resolved is None:
|
||||
return 'Sorry, we couldn\'t load the first message in this thread'
|
||||
|
||||
# the resolved message for the reference will be a Message
|
||||
return self.reference.resolved.content # type: ignore
|
||||
|
||||
if self.type is MessageType.guild_invite_reminder:
|
||||
return 'Wondering who to invite?\nStart by inviting anyone who can help you build the server!'
|
||||
|
||||
async def delete(self, *, delay: Optional[float] = None) -> None:
|
||||
async def delete(self, *, delay: Optional[float] = None, silent: bool = False) -> None:
|
||||
"""|coro|
|
||||
|
||||
Deletes the message.
|
||||
@@ -1111,12 +1133,17 @@ class Message(Hashable):
|
||||
|
||||
.. versionchanged:: 1.1
|
||||
Added the new ``delay`` keyword-only parameter.
|
||||
.. versionchanged:: 2.0
|
||||
Added the new ``silent`` keyword-only parameter.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
delay: Optional[:class:`float`]
|
||||
If provided, the number of seconds to wait in the background
|
||||
before deleting the message. If the deletion fails then it is silently ignored.
|
||||
silent: :class:`bool`
|
||||
If silent is set to ``True``, the error will not be raised, it will be ignored.
|
||||
This defaults to ``False``
|
||||
|
||||
Raises
|
||||
------
|
||||
@@ -1138,7 +1165,11 @@ class Message(Hashable):
|
||||
|
||||
asyncio.create_task(delete(delay))
|
||||
else:
|
||||
await self._state.http.delete_message(self.channel.id, self.id)
|
||||
try:
|
||||
await self._state.http.delete_message(self.channel.id, self.id)
|
||||
except Exception:
|
||||
if not silent:
|
||||
raise
|
||||
|
||||
@overload
|
||||
async def edit(
|
||||
@@ -1151,7 +1182,7 @@ class Message(Hashable):
|
||||
delete_after: Optional[float] = ...,
|
||||
allowed_mentions: Optional[AllowedMentions] = ...,
|
||||
view: Optional[View] = ...,
|
||||
) -> None:
|
||||
) -> Message:
|
||||
...
|
||||
|
||||
@overload
|
||||
@@ -1165,7 +1196,7 @@ class Message(Hashable):
|
||||
delete_after: Optional[float] = ...,
|
||||
allowed_mentions: Optional[AllowedMentions] = ...,
|
||||
view: Optional[View] = ...,
|
||||
) -> None:
|
||||
) -> Message:
|
||||
...
|
||||
|
||||
async def edit(
|
||||
@@ -1178,7 +1209,7 @@ class Message(Hashable):
|
||||
delete_after: Optional[float] = None,
|
||||
allowed_mentions: Optional[AllowedMentions] = MISSING,
|
||||
view: Optional[View] = MISSING,
|
||||
) -> None:
|
||||
) -> Message:
|
||||
"""|coro|
|
||||
|
||||
Edits the message.
|
||||
@@ -1280,9 +1311,8 @@ class Message(Hashable):
|
||||
else:
|
||||
payload['components'] = []
|
||||
|
||||
if payload:
|
||||
data = await self._state.http.edit_message(self.channel.id, self.id, **payload)
|
||||
self._update(data)
|
||||
data = await self._state.http.edit_message(self.channel.id, self.id, **payload)
|
||||
message = Message(state=self._state, channel=self.channel, data=data)
|
||||
|
||||
if view and not view.is_finished():
|
||||
self._state.store_view(view, self.id)
|
||||
@@ -1290,6 +1320,8 @@ class Message(Hashable):
|
||||
if delete_after is not None:
|
||||
await self.delete(delay=delete_after)
|
||||
|
||||
return message
|
||||
|
||||
async def publish(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
@@ -1489,8 +1521,8 @@ class Message(Hashable):
|
||||
|
||||
Creates a public thread from this message.
|
||||
|
||||
You must have :attr:`~discord.Permissions.send_messages` and
|
||||
:attr:`~discord.Permissions.use_threads` in order to create a thread.
|
||||
You must have :attr:`~discord.Permissions.create_public_threads` in order to
|
||||
create a public thread from a message.
|
||||
|
||||
The channel this message belongs in must be a :class:`TextChannel`.
|
||||
|
||||
@@ -1521,13 +1553,14 @@ class Message(Hashable):
|
||||
if self.guild is None:
|
||||
raise InvalidArgument('This message does not have guild info attached.')
|
||||
|
||||
default_auto_archive_duration: ThreadArchiveDuration = getattr(self.channel, 'default_auto_archive_duration', 1440)
|
||||
data = await self._state.http.start_thread_with_message(
|
||||
self.channel.id,
|
||||
self.id,
|
||||
name=name,
|
||||
auto_archive_duration=auto_archive_duration or self.channel.default_auto_archive_duration,
|
||||
auto_archive_duration=auto_archive_duration or default_auto_archive_duration,
|
||||
)
|
||||
return Thread(guild=self.guild, state=self._state, data=data) # type: ignore
|
||||
return Thread(guild=self.guild, state=self._state, data=data)
|
||||
|
||||
async def reply(self, content: Optional[str] = None, **kwargs) -> Message:
|
||||
"""|coro|
|
||||
@@ -1617,6 +1650,10 @@ class PartialMessage(Hashable):
|
||||
|
||||
Returns the partial message's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the partial message's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`]
|
||||
@@ -1773,7 +1810,7 @@ class PartialMessage(Hashable):
|
||||
fields['embed'] = embed.to_dict()
|
||||
|
||||
try:
|
||||
suppress = fields.pop('suppress')
|
||||
suppress: bool = fields.pop('suppress')
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
@@ -1811,9 +1848,10 @@ class PartialMessage(Hashable):
|
||||
data = await self._state.http.edit_message(self.channel.id, self.id, **fields)
|
||||
|
||||
if delete_after is not None:
|
||||
await self.delete(delay=delete_after) # type: ignore
|
||||
await self.delete(delay=delete_after)
|
||||
|
||||
if fields:
|
||||
# data isn't unbound
|
||||
msg = self._state.create_message(channel=self.channel, data=data) # type: ignore
|
||||
if view and not view.is_finished():
|
||||
self._state.store_view(view, self.id)
|
||||
|
||||
@@ -43,5 +43,8 @@ class EqualityComparable:
|
||||
class Hashable(EqualityComparable):
|
||||
__slots__ = ()
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.id >> 22
|
||||
|
||||
@@ -69,6 +69,10 @@ class Object(Hashable):
|
||||
|
||||
Returns the object's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the object's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
id: :class:`int`
|
||||
|
||||
@@ -22,8 +22,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar, IO, Generator, Tuple, Optional
|
||||
|
||||
from .errors import DiscordException
|
||||
|
||||
__all__ = (
|
||||
@@ -40,22 +44,29 @@ class OggError(DiscordException):
|
||||
# https://tools.ietf.org/html/rfc7845
|
||||
|
||||
class OggPage:
|
||||
_header = struct.Struct('<xBQIIIB')
|
||||
_header: ClassVar[struct.Struct] = struct.Struct('<xBQIIIB')
|
||||
if TYPE_CHECKING:
|
||||
flag: int
|
||||
gran_pos: int
|
||||
serial: int
|
||||
pagenum: int
|
||||
crc: int
|
||||
segnum: int
|
||||
|
||||
def __init__(self, stream):
|
||||
def __init__(self, stream: IO[bytes]) -> None:
|
||||
try:
|
||||
header = stream.read(struct.calcsize(self._header.format))
|
||||
|
||||
self.flag, self.gran_pos, self.serial, \
|
||||
self.pagenum, self.crc, self.segnum = self._header.unpack(header)
|
||||
|
||||
self.segtable = stream.read(self.segnum)
|
||||
self.segtable: bytes = stream.read(self.segnum)
|
||||
bodylen = sum(struct.unpack('B'*self.segnum, self.segtable))
|
||||
self.data = stream.read(bodylen)
|
||||
self.data: bytes = stream.read(bodylen)
|
||||
except Exception:
|
||||
raise OggError('bad data stream') from None
|
||||
|
||||
def iter_packets(self):
|
||||
def iter_packets(self) -> Generator[Tuple[bytes, bool], None, None]:
|
||||
packetlen = offset = 0
|
||||
partial = True
|
||||
|
||||
@@ -74,10 +85,10 @@ class OggPage:
|
||||
yield self.data[offset:], False
|
||||
|
||||
class OggStream:
|
||||
def __init__(self, stream):
|
||||
self.stream = stream
|
||||
def __init__(self, stream: IO[bytes]) -> None:
|
||||
self.stream: IO[bytes] = stream
|
||||
|
||||
def _next_page(self):
|
||||
def _next_page(self) -> Optional[OggPage]:
|
||||
head = self.stream.read(4)
|
||||
if head == b'OggS':
|
||||
return OggPage(self.stream)
|
||||
@@ -86,13 +97,13 @@ class OggStream:
|
||||
else:
|
||||
raise OggError('invalid header magic')
|
||||
|
||||
def _iter_pages(self):
|
||||
def _iter_pages(self) -> Generator[OggPage, None, None]:
|
||||
page = self._next_page()
|
||||
while page:
|
||||
yield page
|
||||
page = self._next_page()
|
||||
|
||||
def iter_packets(self):
|
||||
def iter_packets(self) -> Generator[bytes, None, None]:
|
||||
partial = b''
|
||||
for page in self._iter_pages():
|
||||
for data, complete in page.iter_packets():
|
||||
|
||||
125
discord/opus.py
125
discord/opus.py
@@ -22,6 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Tuple, TypedDict, Any, TYPE_CHECKING, Callable, TypeVar, Literal, Optional, overload
|
||||
|
||||
import array
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
@@ -31,7 +35,24 @@ import os.path
|
||||
import struct
|
||||
import sys
|
||||
|
||||
from .errors import DiscordException
|
||||
from .errors import DiscordException, InvalidArgument
|
||||
|
||||
if TYPE_CHECKING:
|
||||
T = TypeVar('T')
|
||||
BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full']
|
||||
SIGNAL_CTL = Literal['auto', 'voice', 'music']
|
||||
|
||||
class BandCtl(TypedDict):
|
||||
narrow: int
|
||||
medium: int
|
||||
wide: int
|
||||
superwide: int
|
||||
full: int
|
||||
|
||||
class SignalCtl(TypedDict):
|
||||
auto: int
|
||||
voice: int
|
||||
music: int
|
||||
|
||||
__all__ = (
|
||||
'Encoder',
|
||||
@@ -39,7 +60,7 @@ __all__ = (
|
||||
'OpusNotLoaded',
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
c_int_ptr = ctypes.POINTER(ctypes.c_int)
|
||||
c_int16_ptr = ctypes.POINTER(ctypes.c_int16)
|
||||
@@ -76,7 +97,7 @@ CTL_SET_SIGNAL = 4024
|
||||
CTL_SET_GAIN = 4034
|
||||
CTL_LAST_PACKET_DURATION = 4039
|
||||
|
||||
band_ctl = {
|
||||
band_ctl: BandCtl = {
|
||||
'narrow': 1101,
|
||||
'medium': 1102,
|
||||
'wide': 1103,
|
||||
@@ -84,22 +105,22 @@ band_ctl = {
|
||||
'full': 1105,
|
||||
}
|
||||
|
||||
signal_ctl = {
|
||||
signal_ctl: SignalCtl = {
|
||||
'auto': -1000,
|
||||
'voice': 3001,
|
||||
'music': 3002,
|
||||
}
|
||||
|
||||
def _err_lt(result, func, args):
|
||||
def _err_lt(result: int, func: Callable, args: List) -> int:
|
||||
if result < OK:
|
||||
log.info('error has happened in %s', func.__name__)
|
||||
_log.info('error has happened in %s', func.__name__)
|
||||
raise OpusError(result)
|
||||
return result
|
||||
|
||||
def _err_ne(result, func, args):
|
||||
def _err_ne(result: T, func: Callable, args: List) -> T:
|
||||
ret = args[-1]._obj
|
||||
if ret.value != OK:
|
||||
log.info('error has happened in %s', func.__name__)
|
||||
_log.info('error has happened in %s', func.__name__)
|
||||
raise OpusError(ret.value)
|
||||
return result
|
||||
|
||||
@@ -108,7 +129,7 @@ def _err_ne(result, func, args):
|
||||
# The second one are the types of arguments it takes.
|
||||
# The third is the result type.
|
||||
# The fourth is the error handler.
|
||||
exported_functions = [
|
||||
exported_functions: List[Tuple[Any, ...]] = [
|
||||
# Generic
|
||||
('opus_get_version_string',
|
||||
None, ctypes.c_char_p, None),
|
||||
@@ -158,7 +179,7 @@ exported_functions = [
|
||||
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
|
||||
]
|
||||
|
||||
def libopus_loader(name):
|
||||
def libopus_loader(name: str) -> Any:
|
||||
# create the library...
|
||||
lib = ctypes.cdll.LoadLibrary(name)
|
||||
|
||||
@@ -178,11 +199,11 @@ def libopus_loader(name):
|
||||
if item[3]:
|
||||
func.errcheck = item[3]
|
||||
except KeyError:
|
||||
log.exception("Error assigning check function to %s", func)
|
||||
_log.exception("Error assigning check function to %s", func)
|
||||
|
||||
return lib
|
||||
|
||||
def _load_default():
|
||||
def _load_default() -> bool:
|
||||
global _lib
|
||||
try:
|
||||
if sys.platform == 'win32':
|
||||
@@ -198,7 +219,7 @@ def _load_default():
|
||||
|
||||
return _lib is not None
|
||||
|
||||
def load_opus(name):
|
||||
def load_opus(name: str) -> None:
|
||||
"""Loads the libopus shared library for use with voice.
|
||||
|
||||
If this function is not called then the library uses the function
|
||||
@@ -236,7 +257,7 @@ def load_opus(name):
|
||||
global _lib
|
||||
_lib = libopus_loader(name)
|
||||
|
||||
def is_loaded():
|
||||
def is_loaded() -> bool:
|
||||
"""Function to check if opus lib is successfully loaded either
|
||||
via the :func:`ctypes.util.find_library` call of :func:`load_opus`.
|
||||
|
||||
@@ -259,10 +280,10 @@ class OpusError(DiscordException):
|
||||
The error code returned.
|
||||
"""
|
||||
|
||||
def __init__(self, code):
|
||||
self.code = code
|
||||
def __init__(self, code: int):
|
||||
self.code: int = code
|
||||
msg = _lib.opus_strerror(self.code).decode('utf-8')
|
||||
log.info('"%s" has happened', msg)
|
||||
_log.info('"%s" has happened', msg)
|
||||
super().__init__(msg)
|
||||
|
||||
class OpusNotLoaded(DiscordException):
|
||||
@@ -286,92 +307,96 @@ class _OpusStruct:
|
||||
return _lib.opus_get_version_string().decode('utf-8')
|
||||
|
||||
class Encoder(_OpusStruct):
|
||||
def __init__(self, application=APPLICATION_AUDIO):
|
||||
def __init__(self, application: int = APPLICATION_AUDIO):
|
||||
_OpusStruct.get_opus_version()
|
||||
|
||||
self.application = application
|
||||
self._state = self._create_state()
|
||||
self.application: int = application
|
||||
self._state: EncoderStruct = self._create_state()
|
||||
self.set_bitrate(128)
|
||||
self.set_fec(True)
|
||||
self.set_expected_packet_loss_percent(0.15)
|
||||
self.set_bandwidth('full')
|
||||
self.set_signal_type('auto')
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
if hasattr(self, '_state'):
|
||||
_lib.opus_encoder_destroy(self._state)
|
||||
self._state = None
|
||||
# This is a destructor, so it's okay to assign None
|
||||
self._state = None # type: ignore
|
||||
|
||||
def _create_state(self):
|
||||
def _create_state(self) -> EncoderStruct:
|
||||
ret = ctypes.c_int()
|
||||
return _lib.opus_encoder_create(self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret))
|
||||
|
||||
def set_bitrate(self, kbps):
|
||||
def set_bitrate(self, kbps: int) -> int:
|
||||
kbps = min(512, max(16, int(kbps)))
|
||||
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_BITRATE, kbps * 1024)
|
||||
return kbps
|
||||
|
||||
def set_bandwidth(self, req):
|
||||
def set_bandwidth(self, req: BAND_CTL) -> None:
|
||||
if req not in band_ctl:
|
||||
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}')
|
||||
|
||||
k = band_ctl[req]
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_BANDWIDTH, k)
|
||||
|
||||
def set_signal_type(self, req):
|
||||
def set_signal_type(self, req: SIGNAL_CTL) -> None:
|
||||
if req not in signal_ctl:
|
||||
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}')
|
||||
|
||||
k = signal_ctl[req]
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k)
|
||||
|
||||
def set_fec(self, enabled=True):
|
||||
def set_fec(self, enabled: bool = True) -> None:
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0)
|
||||
|
||||
def set_expected_packet_loss_percent(self, percentage):
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100))))
|
||||
def set_expected_packet_loss_percent(self, percentage: float) -> None:
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore
|
||||
|
||||
def encode(self, pcm, frame_size):
|
||||
def encode(self, pcm: bytes, frame_size: int) -> bytes:
|
||||
max_data_bytes = len(pcm)
|
||||
pcm = ctypes.cast(pcm, c_int16_ptr)
|
||||
# bytes can be used to reference pointer
|
||||
pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore
|
||||
data = (ctypes.c_char * max_data_bytes)()
|
||||
|
||||
ret = _lib.opus_encode(self._state, pcm, frame_size, data, max_data_bytes)
|
||||
ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes)
|
||||
|
||||
return array.array('b', data[:ret]).tobytes()
|
||||
# array can be initialized with bytes but mypy doesn't know
|
||||
return array.array('b', data[:ret]).tobytes() # type: ignore
|
||||
|
||||
class Decoder(_OpusStruct):
|
||||
def __init__(self):
|
||||
_OpusStruct.get_opus_version()
|
||||
|
||||
self._state = self._create_state()
|
||||
self._state: DecoderStruct = self._create_state()
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
if hasattr(self, '_state'):
|
||||
_lib.opus_decoder_destroy(self._state)
|
||||
self._state = None
|
||||
# This is a destructor, so it's okay to assign None
|
||||
self._state = None # type: ignore
|
||||
|
||||
def _create_state(self):
|
||||
def _create_state(self) -> DecoderStruct:
|
||||
ret = ctypes.c_int()
|
||||
return _lib.opus_decoder_create(self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret))
|
||||
|
||||
@staticmethod
|
||||
def packet_get_nb_frames(data):
|
||||
def packet_get_nb_frames(data: bytes) -> int:
|
||||
"""Gets the number of frames in an Opus packet"""
|
||||
return _lib.opus_packet_get_nb_frames(data, len(data))
|
||||
|
||||
@staticmethod
|
||||
def packet_get_nb_channels(data):
|
||||
def packet_get_nb_channels(data: bytes) -> int:
|
||||
"""Gets the number of channels in an Opus packet"""
|
||||
return _lib.opus_packet_get_nb_channels(data)
|
||||
|
||||
@classmethod
|
||||
def packet_get_samples_per_frame(cls, data):
|
||||
def packet_get_samples_per_frame(cls, data: bytes) -> int:
|
||||
"""Gets the number of samples per frame from an Opus packet"""
|
||||
return _lib.opus_packet_get_samples_per_frame(data, cls.SAMPLING_RATE)
|
||||
|
||||
def _set_gain(self, adjustment):
|
||||
def _set_gain(self, adjustment: int) -> int:
|
||||
"""Configures decoder gain adjustment.
|
||||
|
||||
Scales the decoded output by a factor specified in Q8 dB units.
|
||||
@@ -383,26 +408,34 @@ class Decoder(_OpusStruct):
|
||||
"""
|
||||
return _lib.opus_decoder_ctl(self._state, CTL_SET_GAIN, adjustment)
|
||||
|
||||
def set_gain(self, dB):
|
||||
def set_gain(self, dB: float) -> int:
|
||||
"""Sets the decoder gain in dB, from -128 to 128."""
|
||||
|
||||
dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8)
|
||||
return self._set_gain(dB_Q8)
|
||||
|
||||
def set_volume(self, mult):
|
||||
def set_volume(self, mult: float) -> int:
|
||||
"""Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc."""
|
||||
return self.set_gain(20 * math.log10(mult)) # amplitude ratio
|
||||
|
||||
def _get_last_packet_duration(self):
|
||||
def _get_last_packet_duration(self) -> int:
|
||||
"""Gets the duration (in samples) of the last packet successfully decoded or concealed."""
|
||||
|
||||
ret = ctypes.c_int32()
|
||||
_lib.opus_decoder_ctl(self._state, CTL_LAST_PACKET_DURATION, ctypes.byref(ret))
|
||||
return ret.value
|
||||
|
||||
def decode(self, data, *, fec=False):
|
||||
@overload
|
||||
def decode(self, data: bytes, *, fec: bool) -> bytes:
|
||||
...
|
||||
|
||||
@overload
|
||||
def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes:
|
||||
...
|
||||
|
||||
def decode(self, data: Optional[bytes], *, fec: bool = False) -> bytes:
|
||||
if data is None and fec:
|
||||
raise OpusError("Invalid arguments: FEC cannot be used with null data")
|
||||
raise InvalidArgument("Invalid arguments: FEC cannot be used with null data")
|
||||
|
||||
if data is None:
|
||||
frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME
|
||||
|
||||
@@ -147,7 +147,7 @@ class Permissions(BaseFlags):
|
||||
"""A factory method that creates a :class:`Permissions` with all
|
||||
permissions set to ``True``.
|
||||
"""
|
||||
return cls(0b11111111111111111111111111111111111111)
|
||||
return cls(0b111111111111111111111111111111111111111)
|
||||
|
||||
@classmethod
|
||||
def all_channel(cls: Type[P]) -> P:
|
||||
@@ -169,10 +169,11 @@ class Permissions(BaseFlags):
|
||||
Added :attr:`stream`, :attr:`priority_speaker` and :attr:`use_slash_commands` permissions.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Added :attr:`use_threads`, :attr:`use_private_threads`, :attr:`manage_threads`,
|
||||
:attr:`use_external_stickers` and :attr:`request_to_speak` permissions.
|
||||
Added :attr:`create_public_threads`, :attr:`create_private_threads`, :attr:`manage_threads`,
|
||||
:attr:`use_external_stickers`, :attr:`send_messages_in_threads` and
|
||||
:attr:`request_to_speak` permissions.
|
||||
"""
|
||||
return cls(0b11110110110011111101111111111101010001)
|
||||
return cls(0b111110110110011111101111111111101010001)
|
||||
|
||||
@classmethod
|
||||
def general(cls: Type[P]) -> P:
|
||||
@@ -206,10 +207,10 @@ class Permissions(BaseFlags):
|
||||
Added :attr:`use_slash_commands` permission.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Added :attr:`use_threads`, :attr:`use_private_threads`, :attr:`manage_threads`
|
||||
and :attr:`use_external_stickers` permissions.
|
||||
Added :attr:`create_public_threads`, :attr:`create_private_threads`, :attr:`manage_threads`,
|
||||
:attr:`send_messages_in_threads` and :attr:`use_external_stickers` permissions.
|
||||
"""
|
||||
return cls(0b11110010000000000001111111100001000000)
|
||||
return cls(0b111110010000000000001111111100001000000)
|
||||
|
||||
@classmethod
|
||||
def voice(cls: Type[P]) -> P:
|
||||
@@ -471,7 +472,7 @@ class Permissions(BaseFlags):
|
||||
return 1 << 30
|
||||
|
||||
@make_permission_alias('manage_emojis')
|
||||
def manage_emojis_and_stickers(self):
|
||||
def manage_emojis_and_stickers(self) -> int:
|
||||
""":class:`bool`: An alias for :attr:`manage_emojis`.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
@@ -511,16 +512,16 @@ class Permissions(BaseFlags):
|
||||
return 1 << 34
|
||||
|
||||
@flag_value
|
||||
def use_threads(self) -> int:
|
||||
""":class:`bool`: Returns ``True`` if a user can create and participate in public threads.
|
||||
def create_public_threads(self) -> int:
|
||||
""":class:`bool`: Returns ``True`` if a user can create public threads.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return 1 << 35
|
||||
|
||||
@flag_value
|
||||
def use_private_threads(self) -> int:
|
||||
""":class:`bool`: Returns ``True`` if a user can create and participate in private threads.
|
||||
def create_private_threads(self) -> int:
|
||||
""":class:`bool`: Returns ``True`` if a user can create private threads.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
@@ -542,6 +543,14 @@ class Permissions(BaseFlags):
|
||||
"""
|
||||
return 1 << 37
|
||||
|
||||
@flag_value
|
||||
def send_messages_in_threads(self) -> int:
|
||||
""":class:`bool`: Returns ``True`` if a user can send messages in threads.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return 1 << 38
|
||||
|
||||
PO = TypeVar('PO', bound='PermissionOverwrite')
|
||||
|
||||
def _augment_from_permissions(cls):
|
||||
@@ -645,12 +654,16 @@ class PermissionOverwrite:
|
||||
manage_permissions: Optional[bool]
|
||||
manage_webhooks: Optional[bool]
|
||||
manage_emojis: Optional[bool]
|
||||
manage_emojis_and_stickers: Optional[bool]
|
||||
use_slash_commands: Optional[bool]
|
||||
request_to_speak: Optional[bool]
|
||||
manage_events: Optional[bool]
|
||||
manage_threads: Optional[bool]
|
||||
use_threads: Optional[bool]
|
||||
use_private_threads: Optional[bool]
|
||||
create_public_threads: Optional[bool]
|
||||
create_private_threads: Optional[bool]
|
||||
send_messages_in_threads: Optional[bool]
|
||||
external_stickers: Optional[bool]
|
||||
use_external_stickers: Optional[bool]
|
||||
|
||||
def __init__(self, **kwargs: Optional[bool]):
|
||||
self._values: Dict[str, Optional[bool]] = {}
|
||||
@@ -673,7 +686,7 @@ class PermissionOverwrite:
|
||||
else:
|
||||
self._values[key] = value
|
||||
|
||||
def pair(self):
|
||||
def pair(self) -> Tuple[Permissions, Permissions]:
|
||||
"""Tuple[:class:`Permissions`, :class:`Permissions`]: Returns the (allow, deny) pair from this overwrite."""
|
||||
|
||||
allow = Permissions.none()
|
||||
|
||||
@@ -50,7 +50,7 @@ if TYPE_CHECKING:
|
||||
AT = TypeVar('AT', bound='AudioSource')
|
||||
FT = TypeVar('FT', bound='FFmpegOpusAudio')
|
||||
|
||||
log: logging.Logger = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = (
|
||||
'AudioSource',
|
||||
@@ -140,13 +140,25 @@ class FFmpegAudio(AudioSource):
|
||||
.. versionadded:: 1.3
|
||||
"""
|
||||
|
||||
def __init__(self, source: str, *, executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any):
|
||||
def __init__(self, source: Union[str, io.BufferedIOBase], *, executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any):
|
||||
piping = subprocess_kwargs.get('stdin') == subprocess.PIPE
|
||||
if piping and isinstance(source, str):
|
||||
raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin")
|
||||
|
||||
args = [executable, *args]
|
||||
kwargs = {'stdout': subprocess.PIPE}
|
||||
kwargs.update(subprocess_kwargs)
|
||||
|
||||
self._process: subprocess.Popen = self._spawn_process(args, **kwargs)
|
||||
self._stdout: IO[bytes] = self._process.stdout # type: ignore
|
||||
self._stdin: Optional[IO[Bytes]] = None
|
||||
self._pipe_thread: Optional[threading.Thread] = None
|
||||
|
||||
if piping:
|
||||
n = f'popen-stdin-writer:{id(self):#x}'
|
||||
self._stdin = self._process.stdin
|
||||
self._pipe_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n)
|
||||
self._pipe_thread.start()
|
||||
|
||||
def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen:
|
||||
process = None
|
||||
@@ -160,26 +172,44 @@ class FFmpegAudio(AudioSource):
|
||||
else:
|
||||
return process
|
||||
|
||||
def cleanup(self) -> None:
|
||||
def _kill_process(self) -> None:
|
||||
proc = self._process
|
||||
if proc is MISSING:
|
||||
return
|
||||
|
||||
log.info('Preparing to terminate ffmpeg process %s.', proc.pid)
|
||||
_log.info('Preparing to terminate ffmpeg process %s.', proc.pid)
|
||||
|
||||
try:
|
||||
proc.kill()
|
||||
except Exception:
|
||||
log.exception("Ignoring error attempting to kill ffmpeg process %s", proc.pid)
|
||||
_log.exception('Ignoring error attempting to kill ffmpeg process %s', proc.pid)
|
||||
|
||||
if proc.poll() is None:
|
||||
log.info('ffmpeg process %s has not terminated. Waiting to terminate...', proc.pid)
|
||||
_log.info('ffmpeg process %s has not terminated. Waiting to terminate...', proc.pid)
|
||||
proc.communicate()
|
||||
log.info('ffmpeg process %s should have terminated with a return code of %s.', proc.pid, proc.returncode)
|
||||
_log.info('ffmpeg process %s should have terminated with a return code of %s.', proc.pid, proc.returncode)
|
||||
else:
|
||||
log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode)
|
||||
_log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode)
|
||||
|
||||
self._process = self._stdout = MISSING
|
||||
|
||||
def _pipe_writer(self, source: io.BufferedIOBase) -> None:
|
||||
while self._process:
|
||||
# arbitrarily large read size
|
||||
data = source.read(8192)
|
||||
if not data:
|
||||
self._process.terminate()
|
||||
return
|
||||
try:
|
||||
self._stdin.write(data)
|
||||
except Exception:
|
||||
_log.debug('Write error for %s, this is probably not a problem', self, exc_info=True)
|
||||
# at this point the source data is either exhausted or the process is fubar
|
||||
self._process.terminate()
|
||||
return
|
||||
|
||||
def cleanup(self) -> None:
|
||||
self._kill_process()
|
||||
self._process = self._stdout = self._stdin = MISSING
|
||||
|
||||
class FFmpegPCMAudio(FFmpegAudio):
|
||||
"""An audio source from FFmpeg (or AVConv).
|
||||
@@ -218,16 +248,16 @@ class FFmpegPCMAudio(FFmpegAudio):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source: str,
|
||||
source: Union[str, io.BufferedIOBase],
|
||||
*,
|
||||
executable: str = 'ffmpeg',
|
||||
pipe: bool = False,
|
||||
stderr: Optional[IO[str]] = None,
|
||||
before_options: Optional[str] = None,
|
||||
before_options: Optional[str] = None,
|
||||
options: Optional[str] = None
|
||||
) -> None:
|
||||
args = []
|
||||
subprocess_kwargs = {'stdin': source if pipe else subprocess.DEVNULL, 'stderr': stderr}
|
||||
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
|
||||
|
||||
if isinstance(before_options, str):
|
||||
args.extend(shlex.split(before_options))
|
||||
@@ -315,7 +345,7 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source: str,
|
||||
source: Union[str, io.BufferedIOBase],
|
||||
*,
|
||||
bitrate: int = 128,
|
||||
codec: Optional[str] = None,
|
||||
@@ -327,7 +357,7 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
) -> None:
|
||||
|
||||
args = []
|
||||
subprocess_kwargs = {'stdin': source if pipe else subprocess.DEVNULL, 'stderr': stderr}
|
||||
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
|
||||
|
||||
if isinstance(before_options, str):
|
||||
args.extend(shlex.split(before_options))
|
||||
@@ -384,7 +414,6 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
|
||||
def custom_probe(source, executable):
|
||||
# some analysis code here
|
||||
|
||||
return codec, bitrate
|
||||
|
||||
source = await discord.FFmpegOpusAudio.from_probe("song.webm", method=custom_probe)
|
||||
@@ -480,18 +509,18 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable)) # type: ignore
|
||||
except Exception:
|
||||
if not fallback:
|
||||
log.exception("Probe '%s' using '%s' failed", method, executable)
|
||||
_log.exception("Probe '%s' using '%s' failed", method, executable)
|
||||
return # type: ignore
|
||||
|
||||
log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable)
|
||||
_log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable)
|
||||
try:
|
||||
codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) # type: ignore
|
||||
except Exception:
|
||||
log.exception("Fallback probe using '%s' failed", executable)
|
||||
_log.exception("Fallback probe using '%s' failed", executable)
|
||||
else:
|
||||
log.info("Fallback probe found codec=%s, bitrate=%s", codec, bitrate)
|
||||
_log.info("Fallback probe found codec=%s, bitrate=%s", codec, bitrate)
|
||||
else:
|
||||
log.info("Probe found codec=%s, bitrate=%s", codec, bitrate)
|
||||
_log.info("Probe found codec=%s, bitrate=%s", codec, bitrate)
|
||||
finally:
|
||||
return codec, bitrate
|
||||
|
||||
@@ -656,12 +685,12 @@ class AudioPlayer(threading.Thread):
|
||||
try:
|
||||
self.after(error)
|
||||
except Exception as exc:
|
||||
log.exception('Calling the after function failed.')
|
||||
_log.exception('Calling the after function failed.')
|
||||
exc.__context__ = error
|
||||
traceback.print_exception(type(exc), exc, exc.__traceback__)
|
||||
elif error:
|
||||
msg = f'Exception in voice thread {self.name}'
|
||||
log.exception(msg, exc_info=error)
|
||||
_log.exception(msg, exc_info=error)
|
||||
print(msg, file=sys.stderr)
|
||||
traceback.print_exception(type(error), error, error.__traceback__)
|
||||
|
||||
@@ -698,4 +727,4 @@ class AudioPlayer(threading.Thread):
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.loop)
|
||||
except Exception as e:
|
||||
log.info("Speaking call in player failed: %s", e)
|
||||
_log.info("Speaking call in player failed: %s", e)
|
||||
|
||||
0
discord/py.typed
Normal file
0
discord/py.typed
Normal file
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
|
||||
Role as RolePayload,
|
||||
RoleTags as RoleTagPayload,
|
||||
)
|
||||
from .types.guild import RolePositionUpdate
|
||||
from .guild import Guild
|
||||
from .member import Member
|
||||
from .state import ConnectionState
|
||||
@@ -140,6 +141,14 @@ class Role(Hashable):
|
||||
|
||||
Returns the role's name.
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the role's ID.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the role's ID.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
id: :class:`int`
|
||||
@@ -194,6 +203,9 @@ class Role(Hashable):
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Role id={self.id} name={self.name!r}>'
|
||||
|
||||
@@ -336,7 +348,7 @@ class Role(Hashable):
|
||||
else:
|
||||
roles.append(self.id)
|
||||
|
||||
payload = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)]
|
||||
payload: List[RolePositionUpdate] = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)]
|
||||
await http.move_role_position(self.guild.id, payload, reason=reason)
|
||||
|
||||
async def edit(
|
||||
@@ -350,7 +362,7 @@ class Role(Hashable):
|
||||
mentionable: bool = MISSING,
|
||||
position: int = MISSING,
|
||||
reason: Optional[str] = MISSING,
|
||||
) -> None:
|
||||
) -> Optional[Role]:
|
||||
"""|coro|
|
||||
|
||||
Edits the role.
|
||||
@@ -363,6 +375,9 @@ class Role(Hashable):
|
||||
.. versionchanged:: 1.4
|
||||
Can now pass ``int`` to ``colour`` keyword-only parameter.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Edits are no longer in-place, the newly edited role is returned instead.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@@ -390,11 +405,14 @@ class Role(Hashable):
|
||||
InvalidArgument
|
||||
An invalid position was given or the default
|
||||
role was asked to be moved.
|
||||
"""
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Role`
|
||||
The newly edited role.
|
||||
"""
|
||||
if position is not MISSING:
|
||||
await self._move(position, reason=reason)
|
||||
self.position = position
|
||||
|
||||
payload: Dict[str, Any] = {}
|
||||
if color is not MISSING:
|
||||
@@ -419,7 +437,7 @@ class Role(Hashable):
|
||||
payload['mentionable'] = mentionable
|
||||
|
||||
data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload)
|
||||
self._update(data)
|
||||
return Role(guild=self.guild, data=data, state=self._state)
|
||||
|
||||
async def delete(self, *, reason: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
180
discord/shard.py
180
discord/shard.py
@@ -22,8 +22,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import logging
|
||||
|
||||
import aiohttp
|
||||
@@ -34,22 +35,30 @@ from .backoff import ExponentialBackoff
|
||||
from .gateway import *
|
||||
from .errors import (
|
||||
ClientException,
|
||||
InvalidArgument,
|
||||
HTTPException,
|
||||
GatewayNotFound,
|
||||
ConnectionClosed,
|
||||
PrivilegedIntentsRequired,
|
||||
)
|
||||
|
||||
from . import utils
|
||||
from .enums import Status
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .gateway import DiscordWebSocket
|
||||
from .activity import BaseActivity
|
||||
from .enums import Status
|
||||
|
||||
EI = TypeVar('EI', bound='EventItem')
|
||||
|
||||
__all__ = (
|
||||
'AutoShardedClient',
|
||||
'ShardInfo',
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventType:
|
||||
close = 0
|
||||
@@ -59,39 +68,41 @@ class EventType:
|
||||
terminate = 4
|
||||
clean_close = 5
|
||||
|
||||
|
||||
class EventItem:
|
||||
__slots__ = ('type', 'shard', 'error')
|
||||
|
||||
def __init__(self, etype, shard, error):
|
||||
self.type = etype
|
||||
self.shard = shard
|
||||
self.error = error
|
||||
def __init__(self, etype: int, shard: Optional['Shard'], error: Optional[Exception]) -> None:
|
||||
self.type: int = etype
|
||||
self.shard: Optional['Shard'] = shard
|
||||
self.error: Optional[Exception] = error
|
||||
|
||||
def __lt__(self, other):
|
||||
def __lt__(self: EI, other: EI) -> bool:
|
||||
if not isinstance(other, EventItem):
|
||||
return NotImplemented
|
||||
return self.type < other.type
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self: EI, other: EI) -> bool:
|
||||
if not isinstance(other, EventItem):
|
||||
return NotImplemented
|
||||
return self.type == other.type
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.type)
|
||||
|
||||
|
||||
class Shard:
|
||||
def __init__(self, ws, client, queue_put):
|
||||
self.ws = ws
|
||||
self._client = client
|
||||
self._dispatch = client.dispatch
|
||||
self._queue_put = queue_put
|
||||
self.loop = self._client.loop
|
||||
self._disconnect = False
|
||||
def __init__(self, ws: DiscordWebSocket, client: AutoShardedClient, queue_put: Callable[[EventItem], None]) -> None:
|
||||
self.ws: DiscordWebSocket = ws
|
||||
self._client: Client = client
|
||||
self._dispatch: Callable[..., None] = client.dispatch
|
||||
self._queue_put: Callable[[EventItem], None] = queue_put
|
||||
self.loop: asyncio.AbstractEventLoop = self._client.loop
|
||||
self._disconnect: bool = False
|
||||
self._reconnect = client._reconnect
|
||||
self._backoff = ExponentialBackoff()
|
||||
self._task = None
|
||||
self._handled_exceptions = (
|
||||
self._backoff: ExponentialBackoff = ExponentialBackoff()
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._handled_exceptions: Tuple[Type[Exception], ...] = (
|
||||
OSError,
|
||||
HTTPException,
|
||||
GatewayNotFound,
|
||||
@@ -101,25 +112,26 @@ class Shard:
|
||||
)
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self.ws.shard_id
|
||||
def id(self) -> int:
|
||||
# DiscordWebSocket.shard_id is set in the from_client classmethod
|
||||
return self.ws.shard_id # type: ignore
|
||||
|
||||
def launch(self):
|
||||
def launch(self) -> None:
|
||||
self._task = self.loop.create_task(self.worker())
|
||||
|
||||
def _cancel_task(self):
|
||||
def _cancel_task(self) -> None:
|
||||
if self._task is not None and not self._task.done():
|
||||
self._task.cancel()
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
self._cancel_task()
|
||||
await self.ws.close(code=1000)
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
await self.close()
|
||||
self._dispatch('shard_disconnect', self.id)
|
||||
|
||||
async def _handle_disconnect(self, e):
|
||||
async def _handle_disconnect(self, e: Exception) -> None:
|
||||
self._dispatch('disconnect')
|
||||
self._dispatch('shard_disconnect', self.id)
|
||||
if not self._reconnect:
|
||||
@@ -144,11 +156,11 @@ class Shard:
|
||||
return
|
||||
|
||||
retry = self._backoff.delay()
|
||||
log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e)
|
||||
_log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e)
|
||||
await asyncio.sleep(retry)
|
||||
self._queue_put(EventItem(EventType.reconnect, self, e))
|
||||
|
||||
async def worker(self):
|
||||
async def worker(self) -> None:
|
||||
while not self._client.is_closed():
|
||||
try:
|
||||
await self.ws.poll_event()
|
||||
@@ -165,14 +177,19 @@ class Shard:
|
||||
self._queue_put(EventItem(EventType.terminate, self, e))
|
||||
break
|
||||
|
||||
async def reidentify(self, exc):
|
||||
async def reidentify(self, exc: ReconnectWebSocket) -> None:
|
||||
self._cancel_task()
|
||||
self._dispatch('disconnect')
|
||||
self._dispatch('shard_disconnect', self.id)
|
||||
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
|
||||
_log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
|
||||
try:
|
||||
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
|
||||
session=self.ws.session_id, sequence=self.ws.sequence)
|
||||
coro = DiscordWebSocket.from_client(
|
||||
self._client,
|
||||
resume=exc.resume,
|
||||
shard_id=self.id,
|
||||
session=self.ws.session_id,
|
||||
sequence=self.ws.sequence,
|
||||
)
|
||||
self.ws = await asyncio.wait_for(coro, timeout=60.0)
|
||||
except self._handled_exceptions as e:
|
||||
await self._handle_disconnect(e)
|
||||
@@ -183,7 +200,7 @@ class Shard:
|
||||
else:
|
||||
self.launch()
|
||||
|
||||
async def reconnect(self):
|
||||
async def reconnect(self) -> None:
|
||||
self._cancel_task()
|
||||
try:
|
||||
coro = DiscordWebSocket.from_client(self._client, shard_id=self.id)
|
||||
@@ -197,6 +214,7 @@ class Shard:
|
||||
else:
|
||||
self.launch()
|
||||
|
||||
|
||||
class ShardInfo:
|
||||
"""A class that gives information and control over a specific shard.
|
||||
|
||||
@@ -215,16 +233,16 @@ class ShardInfo:
|
||||
|
||||
__slots__ = ('_parent', 'id', 'shard_count')
|
||||
|
||||
def __init__(self, parent, shard_count):
|
||||
self._parent = parent
|
||||
self.id = parent.id
|
||||
self.shard_count = shard_count
|
||||
def __init__(self, parent: Shard, shard_count: Optional[int]) -> None:
|
||||
self._parent: Shard = parent
|
||||
self.id: int = parent.id
|
||||
self.shard_count: Optional[int] = shard_count
|
||||
|
||||
def is_closed(self):
|
||||
def is_closed(self) -> bool:
|
||||
""":class:`bool`: Whether the shard connection is currently closed."""
|
||||
return not self._parent.ws.open
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Disconnects a shard. When this is called, the shard connection will no
|
||||
@@ -237,7 +255,7 @@ class ShardInfo:
|
||||
|
||||
await self._parent.disconnect()
|
||||
|
||||
async def reconnect(self):
|
||||
async def reconnect(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Disconnects and then connects the shard again.
|
||||
@@ -246,7 +264,7 @@ class ShardInfo:
|
||||
await self._parent.disconnect()
|
||||
await self._parent.reconnect()
|
||||
|
||||
async def connect(self):
|
||||
async def connect(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Connects a shard. If the shard is already connected this does nothing.
|
||||
@@ -257,11 +275,11 @@ class ShardInfo:
|
||||
await self._parent.reconnect()
|
||||
|
||||
@property
|
||||
def latency(self):
|
||||
def latency(self) -> float:
|
||||
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds for this shard."""
|
||||
return self._parent.ws.latency
|
||||
|
||||
def is_ws_ratelimited(self):
|
||||
def is_ws_ratelimited(self) -> bool:
|
||||
""":class:`bool`: Whether the websocket is currently rate limited.
|
||||
|
||||
This can be useful to know when deciding whether you should query members
|
||||
@@ -271,6 +289,7 @@ class ShardInfo:
|
||||
"""
|
||||
return self._parent.ws.is_ratelimited()
|
||||
|
||||
|
||||
class AutoShardedClient(Client):
|
||||
"""A client similar to :class:`Client` except it handles the complications
|
||||
of sharding for the user into a more manageable and transparent single
|
||||
@@ -297,9 +316,13 @@ class AutoShardedClient(Client):
|
||||
shard_ids: Optional[List[:class:`int`]]
|
||||
An optional list of shard_ids to launch the shards with.
|
||||
"""
|
||||
def __init__(self, *args, loop=None, **kwargs):
|
||||
|
||||
if TYPE_CHECKING:
|
||||
_connection: AutoShardedConnectionState
|
||||
|
||||
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None:
|
||||
kwargs.pop('shard_id', None)
|
||||
self.shard_ids = kwargs.pop('shard_ids', None)
|
||||
self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None)
|
||||
super().__init__(*args, loop=loop, **kwargs)
|
||||
|
||||
if self.shard_ids is not None:
|
||||
@@ -315,18 +338,24 @@ class AutoShardedClient(Client):
|
||||
self._connection._get_client = lambda: self
|
||||
self.__queue = asyncio.PriorityQueue()
|
||||
|
||||
def _get_websocket(self, guild_id=None, *, shard_id=None):
|
||||
def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket:
|
||||
if shard_id is None:
|
||||
shard_id = (guild_id >> 22) % self.shard_count
|
||||
# guild_id won't be None if shard_id is None and shard_count won't be None here
|
||||
shard_id = (guild_id >> 22) % self.shard_count # type: ignore
|
||||
return self.__shards[shard_id].ws
|
||||
|
||||
def _get_state(self, **options):
|
||||
return AutoShardedConnectionState(dispatch=self.dispatch,
|
||||
handlers=self._handlers,
|
||||
hooks=self._hooks, http=self.http, loop=self.loop, **options)
|
||||
def _get_state(self, **options: Any) -> AutoShardedConnectionState:
|
||||
return AutoShardedConnectionState(
|
||||
dispatch=self.dispatch,
|
||||
handlers=self._handlers,
|
||||
hooks=self._hooks,
|
||||
http=self.http,
|
||||
loop=self.loop,
|
||||
**options,
|
||||
)
|
||||
|
||||
@property
|
||||
def latency(self):
|
||||
def latency(self) -> float:
|
||||
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
|
||||
|
||||
This operates similarly to :meth:`Client.latency` except it uses the average
|
||||
@@ -338,14 +367,14 @@ class AutoShardedClient(Client):
|
||||
return sum(latency for _, latency in self.latencies) / len(self.__shards)
|
||||
|
||||
@property
|
||||
def latencies(self):
|
||||
def latencies(self) -> List[Tuple[int, float]]:
|
||||
"""List[Tuple[:class:`int`, :class:`float`]]: A list of latencies between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
|
||||
|
||||
This returns a list of tuples with elements ``(shard_id, latency)``.
|
||||
"""
|
||||
return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()]
|
||||
|
||||
def get_shard(self, shard_id):
|
||||
def get_shard(self, shard_id: int) -> Optional[ShardInfo]:
|
||||
"""Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found."""
|
||||
try:
|
||||
parent = self.__shards[shard_id]
|
||||
@@ -355,16 +384,16 @@ class AutoShardedClient(Client):
|
||||
return ShardInfo(parent, self.shard_count)
|
||||
|
||||
@property
|
||||
def shards(self):
|
||||
def shards(self) -> Dict[int, ShardInfo]:
|
||||
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
|
||||
return { shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items() }
|
||||
return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()}
|
||||
|
||||
async def launch_shard(self, gateway, shard_id, *, initial=False):
|
||||
async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None:
|
||||
try:
|
||||
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
|
||||
ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||
except Exception:
|
||||
log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id)
|
||||
_log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id)
|
||||
await asyncio.sleep(5.0)
|
||||
return await self.launch_shard(gateway, shard_id)
|
||||
|
||||
@@ -372,7 +401,7 @@ class AutoShardedClient(Client):
|
||||
self.__shards[shard_id] = ret = Shard(ws, self, self.__queue.put_nowait)
|
||||
ret.launch()
|
||||
|
||||
async def launch_shards(self):
|
||||
async def launch_shards(self) -> None:
|
||||
if self.shard_count is None:
|
||||
self.shard_count, gateway = await self.http.get_bot_gateway()
|
||||
else:
|
||||
@@ -389,7 +418,7 @@ class AutoShardedClient(Client):
|
||||
|
||||
self._connection.shards_launched.set()
|
||||
|
||||
async def connect(self, *, reconnect=True):
|
||||
async def connect(self, *, reconnect: bool = True) -> None:
|
||||
self._reconnect = reconnect
|
||||
await self.launch_shards()
|
||||
|
||||
@@ -413,7 +442,7 @@ class AutoShardedClient(Client):
|
||||
elif item.type == EventType.clean_close:
|
||||
return
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Closes the connection to Discord.
|
||||
@@ -425,7 +454,7 @@ class AutoShardedClient(Client):
|
||||
|
||||
for vc in self.voice_clients:
|
||||
try:
|
||||
await vc.disconnect()
|
||||
await vc.disconnect(force=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -436,7 +465,13 @@ class AutoShardedClient(Client):
|
||||
await self.http.close()
|
||||
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
|
||||
|
||||
async def change_presence(self, *, activity=None, status=None, shard_id=None):
|
||||
async def change_presence(
|
||||
self,
|
||||
*,
|
||||
activity: Optional[BaseActivity] = None,
|
||||
status: Optional[Status] = None,
|
||||
shard_id: int = None,
|
||||
) -> None:
|
||||
"""|coro|
|
||||
|
||||
Changes the client's presence.
|
||||
@@ -468,23 +503,23 @@ class AutoShardedClient(Client):
|
||||
"""
|
||||
|
||||
if status is None:
|
||||
status = 'online'
|
||||
status_value = 'online'
|
||||
status_enum = Status.online
|
||||
elif status is Status.offline:
|
||||
status = 'invisible'
|
||||
status_value = 'invisible'
|
||||
status_enum = Status.offline
|
||||
else:
|
||||
status_enum = status
|
||||
status = str(status)
|
||||
status_value = str(status)
|
||||
|
||||
if shard_id is None:
|
||||
for shard in self.__shards.values():
|
||||
await shard.ws.change_presence(activity=activity, status=status)
|
||||
await shard.ws.change_presence(activity=activity, status=status_value)
|
||||
|
||||
guilds = self._connection.guilds
|
||||
else:
|
||||
shard = self.__shards[shard_id]
|
||||
await shard.ws.change_presence(activity=activity, status=status)
|
||||
await shard.ws.change_presence(activity=activity, status=status_value)
|
||||
guilds = [g for g in self._connection.guilds if g.shard_id == shard_id]
|
||||
|
||||
activities = () if activity is None else (activity,)
|
||||
@@ -493,10 +528,11 @@ class AutoShardedClient(Client):
|
||||
if me is None:
|
||||
continue
|
||||
|
||||
me.activities = activities
|
||||
# Member.activities is typehinted as Tuple[ActivityType, ...], we may be setting it as Tuple[BaseActivity, ...]
|
||||
me.activities = activities # type: ignore
|
||||
me.status = status_enum
|
||||
|
||||
def is_ws_ratelimited(self):
|
||||
def is_ws_ratelimited(self) -> bool:
|
||||
""":class:`bool`: Whether the websocket is currently rate limited.
|
||||
|
||||
This can be useful to know when deciding whether you should query members
|
||||
|
||||
@@ -61,6 +61,10 @@ class StageInstance(Hashable):
|
||||
|
||||
Returns the stage instance's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the stage instance's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
id: :class:`int`
|
||||
@@ -106,7 +110,8 @@ class StageInstance(Hashable):
|
||||
@cached_slot_property('_cs_channel')
|
||||
def channel(self) -> Optional[StageChannel]:
|
||||
"""Optional[:class:`StageChannel`]: The channel that stage instance is running in."""
|
||||
return self._state.get_channel(self.channel_id)
|
||||
# the returned channel will always be a StageChannel or None
|
||||
return self._state.get_channel(self.channel_id) # type: ignore
|
||||
|
||||
def is_public(self) -> bool:
|
||||
return self.privacy_level is StagePrivacyLevel.public
|
||||
|
||||
579
discord/state.py
579
discord/state.py
File diff suppressed because it is too large
Load Diff
@@ -51,7 +51,8 @@ if TYPE_CHECKING:
|
||||
Sticker as StickerPayload,
|
||||
StandardSticker as StandardStickerPayload,
|
||||
GuildSticker as GuildStickerPayload,
|
||||
ListPremiumStickerPacks as ListPremiumStickerPacksPayload
|
||||
ListPremiumStickerPacks as ListPremiumStickerPacksPayload,
|
||||
EditGuildSticker,
|
||||
)
|
||||
|
||||
|
||||
@@ -66,6 +67,14 @@ class StickerPack(Hashable):
|
||||
|
||||
Returns the name of the sticker pack.
|
||||
|
||||
.. describe:: hash(x)
|
||||
|
||||
Returns the hash of the sticker pack.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the ID of the sticker pack.
|
||||
|
||||
.. describe:: x == y
|
||||
|
||||
Checks if the sticker pack is equal to another sticker pack.
|
||||
@@ -440,10 +449,10 @@ class GuildSticker(Sticker):
|
||||
description: str = MISSING,
|
||||
emoji: str = MISSING,
|
||||
reason: Optional[str] = None,
|
||||
) -> None:
|
||||
) -> GuildSticker:
|
||||
"""|coro|
|
||||
|
||||
Edits a :class:`Sticker` for the guild.
|
||||
Edits a :class:`GuildSticker` for the guild.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
@@ -462,8 +471,13 @@ class GuildSticker(Sticker):
|
||||
You are not allowed to edit stickers.
|
||||
HTTPException
|
||||
An error occurred editing the sticker.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`GuildSticker`
|
||||
The newly modified sticker.
|
||||
"""
|
||||
payload = {}
|
||||
payload: EditGuildSticker = {}
|
||||
|
||||
if name is not MISSING:
|
||||
payload['name'] = name
|
||||
@@ -482,8 +496,7 @@ class GuildSticker(Sticker):
|
||||
payload['tags'] = emoji
|
||||
|
||||
data: GuildStickerPayload = await self._state.http.modify_guild_sticker(self.guild_id, self.id, payload, reason)
|
||||
|
||||
self._from_data(data)
|
||||
return GuildSticker(state=self._state, data=data)
|
||||
|
||||
async def delete(self, *, reason: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
@@ -206,7 +206,7 @@ class Template:
|
||||
data = await self._state.http.create_from_template(self.code, name, region_value, icon)
|
||||
return Guild(data=data, state=self._state)
|
||||
|
||||
async def sync(self) -> None:
|
||||
async def sync(self) -> Template:
|
||||
"""|coro|
|
||||
|
||||
Sync the template to the guild's current state.
|
||||
@@ -216,6 +216,9 @@ class Template:
|
||||
|
||||
.. versionadded:: 1.7
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The template is no longer edited in-place, instead it is returned.
|
||||
|
||||
Raises
|
||||
-------
|
||||
HTTPException
|
||||
@@ -224,17 +227,22 @@ class Template:
|
||||
You don't have permissions to edit the template.
|
||||
NotFound
|
||||
This template does not exist.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Template`
|
||||
The newly edited template.
|
||||
"""
|
||||
|
||||
data = await self._state.http.sync_template(self.source_guild.id, self.code)
|
||||
self._store(data)
|
||||
return Template(state=self._state, data=data)
|
||||
|
||||
async def edit(
|
||||
self,
|
||||
*,
|
||||
name: str = MISSING,
|
||||
description: Optional[str] = MISSING,
|
||||
) -> None:
|
||||
) -> Template:
|
||||
"""|coro|
|
||||
|
||||
Edit the template metadata.
|
||||
@@ -244,6 +252,9 @@ class Template:
|
||||
|
||||
.. versionadded:: 1.7
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The template is no longer edited in-place, instead it is returned.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
name: :class:`str`
|
||||
@@ -259,6 +270,11 @@ class Template:
|
||||
You don't have permissions to edit the template.
|
||||
NotFound
|
||||
This template does not exist.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Template`
|
||||
The newly edited template.
|
||||
"""
|
||||
payload = {}
|
||||
|
||||
@@ -268,7 +284,7 @@ class Template:
|
||||
payload['description'] = description
|
||||
|
||||
data = await self._state.http.edit_template(self.source_guild.id, self.code, payload)
|
||||
self._store(data)
|
||||
return Template(state=self._state, data=data)
|
||||
|
||||
async def delete(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
@@ -46,6 +46,7 @@ if TYPE_CHECKING:
|
||||
ThreadMetadata,
|
||||
ThreadArchiveDuration,
|
||||
)
|
||||
from .types.snowflake import SnowflakeList
|
||||
from .guild import Guild
|
||||
from .channel import TextChannel, CategoryChannel
|
||||
from .member import Member
|
||||
@@ -73,6 +74,10 @@ class Thread(Messageable, Hashable):
|
||||
|
||||
Returns the thread's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the thread's ID.
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the thread's name.
|
||||
@@ -110,6 +115,9 @@ class Thread(Messageable, Hashable):
|
||||
Whether the thread is archived.
|
||||
locked: :class:`bool`
|
||||
Whether the thread is locked.
|
||||
invitable: :class:`bool`
|
||||
Whether non-moderators can add other non-moderators to this thread.
|
||||
This is always ``True`` for public threads.
|
||||
archiver_id: Optional[:class:`int`]
|
||||
The user's ID that archived this thread.
|
||||
auto_archive_duration: :class:`int`
|
||||
@@ -135,6 +143,7 @@ class Thread(Messageable, Hashable):
|
||||
'me',
|
||||
'locked',
|
||||
'archived',
|
||||
'invitable',
|
||||
'archiver_id',
|
||||
'auto_archive_duration',
|
||||
'archive_timestamp',
|
||||
@@ -183,6 +192,7 @@ class Thread(Messageable, Hashable):
|
||||
self.auto_archive_duration = data['auto_archive_duration']
|
||||
self.archive_timestamp = parse_time(data['archive_timestamp'])
|
||||
self.locked = data.get('locked', False)
|
||||
self.invitable = data.get('invitable', True)
|
||||
|
||||
def _update(self, data):
|
||||
try:
|
||||
@@ -394,7 +404,7 @@ class Thread(Messageable, Hashable):
|
||||
if len(messages) > 100:
|
||||
raise ClientException('Can only bulk delete messages up to 100 messages')
|
||||
|
||||
message_ids = [m.id for m in messages]
|
||||
message_ids: SnowflakeList = [m.id for m in messages]
|
||||
await self._state.http.delete_messages(self.id, message_ids)
|
||||
|
||||
async def purge(
|
||||
@@ -520,9 +530,10 @@ class Thread(Messageable, Hashable):
|
||||
name: str = MISSING,
|
||||
archived: bool = MISSING,
|
||||
locked: bool = MISSING,
|
||||
invitable: bool = MISSING,
|
||||
slowmode_delay: int = MISSING,
|
||||
auto_archive_duration: ThreadArchiveDuration = MISSING,
|
||||
):
|
||||
) -> Thread:
|
||||
"""|coro|
|
||||
|
||||
Edits the thread.
|
||||
@@ -542,6 +553,9 @@ class Thread(Messageable, Hashable):
|
||||
Whether to archive the thread or not.
|
||||
locked: :class:`bool`
|
||||
Whether to lock the thread or not.
|
||||
invitable: :class:`bool`
|
||||
Whether non-moderators can add other non-moderators to this thread.
|
||||
Only available for private threads.
|
||||
auto_archive_duration: :class:`int`
|
||||
The new duration in minutes before a thread is automatically archived for inactivity.
|
||||
Must be one of ``60``, ``1440``, ``4320``, or ``10080``.
|
||||
@@ -555,6 +569,11 @@ class Thread(Messageable, Hashable):
|
||||
You do not have permissions to edit the thread.
|
||||
HTTPException
|
||||
Editing the thread failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Thread`
|
||||
The newly edited thread.
|
||||
"""
|
||||
payload = {}
|
||||
if name is not MISSING:
|
||||
@@ -565,20 +584,22 @@ class Thread(Messageable, Hashable):
|
||||
payload['auto_archive_duration'] = auto_archive_duration
|
||||
if locked is not MISSING:
|
||||
payload['locked'] = locked
|
||||
if invitable is not MISSING:
|
||||
payload['invitable'] = invitable
|
||||
if slowmode_delay is not MISSING:
|
||||
payload['rate_limit_per_user'] = slowmode_delay
|
||||
|
||||
await self._state.http.edit_channel(self.id, **payload)
|
||||
data = await self._state.http.edit_channel(self.id, **payload)
|
||||
# The data payload will always be a Thread payload
|
||||
return Thread(data=data, state=self._state, guild=self.guild) # type: ignore
|
||||
|
||||
async def join(self):
|
||||
"""|coro|
|
||||
|
||||
Joins this thread.
|
||||
|
||||
You must have :attr:`~Permissions.send_messages` and :attr:`~Permissions.use_threads`
|
||||
to join a public thread. If the thread is private then :attr:`~Permissions.send_messages`
|
||||
and either :attr:`~Permissions.use_private_threads` or :attr:`~Permissions.manage_messages`
|
||||
is required to join the thread.
|
||||
You must have :attr:`~Permissions.send_messages_in_threads` to join a thread.
|
||||
If the thread is private, :attr:`~Permissions.manage_threads` is also needed.
|
||||
|
||||
Raises
|
||||
-------
|
||||
@@ -731,6 +752,10 @@ class ThreadMember(Hashable):
|
||||
|
||||
Returns the thread member's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the thread member's ID.
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the thread member's name.
|
||||
|
||||
@@ -61,6 +61,7 @@ class _PartialAppInfoOptional(TypedDict, total=False):
|
||||
terms_of_service_url: str
|
||||
privacy_policy_url: str
|
||||
max_participants: int
|
||||
flags: int
|
||||
|
||||
class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo):
|
||||
pass
|
||||
|
||||
@@ -39,6 +39,7 @@ class PartialMember(TypedDict):
|
||||
|
||||
|
||||
class Member(PartialMember, total=False):
|
||||
avatar: str
|
||||
user: User
|
||||
nick: str
|
||||
premium_since: str
|
||||
@@ -46,16 +47,17 @@ class Member(PartialMember, total=False):
|
||||
permissions: str
|
||||
|
||||
|
||||
class _OptionalGatewayMember(PartialMember, total=False):
|
||||
class _OptionalMemberWithUser(PartialMember, total=False):
|
||||
avatar: str
|
||||
nick: str
|
||||
premium_since: str
|
||||
pending: bool
|
||||
permissions: str
|
||||
|
||||
|
||||
class GatewayMember(_OptionalGatewayMember):
|
||||
class MemberWithUser(_OptionalMemberWithUser):
|
||||
user: User
|
||||
|
||||
|
||||
class UserWithMember(User, total=False):
|
||||
member: _OptionalGatewayMember
|
||||
member: _OptionalMemberWithUser
|
||||
|
||||
@@ -41,6 +41,7 @@ class ThreadMember(TypedDict):
|
||||
class _ThreadMetadataOptional(TypedDict, total=False):
|
||||
archiver_id: Snowflake
|
||||
locked: bool
|
||||
invitable: bool
|
||||
|
||||
|
||||
class ThreadMetadata(_ThreadMetadataOptional):
|
||||
|
||||
@@ -24,14 +24,14 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from typing import Optional, TypedDict, List, Literal
|
||||
from .snowflake import Snowflake
|
||||
from .member import Member
|
||||
from .member import MemberWithUser
|
||||
|
||||
|
||||
SupportedModes = Literal['xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305']
|
||||
|
||||
|
||||
class _PartialVoiceStateOptional(TypedDict, total=False):
|
||||
member: Member
|
||||
member: MemberWithUser
|
||||
self_stream: bool
|
||||
|
||||
|
||||
|
||||
@@ -197,13 +197,13 @@ class Select(Item[V]):
|
||||
-----------
|
||||
label: :class:`str`
|
||||
The label of the option. This is displayed to users.
|
||||
Can only be up to 25 characters.
|
||||
Can only be up to 100 characters.
|
||||
value: :class:`str`
|
||||
The value of the option. This is not displayed to users.
|
||||
If not given, defaults to the label. Can only be up to 100 characters.
|
||||
description: Optional[:class:`str`]
|
||||
An additional description of the option, if any.
|
||||
Can only be up to 50 characters.
|
||||
Can only be up to 100 characters.
|
||||
emoji: Optional[Union[:class:`str`, :class:`.Emoji`, :class:`.PartialEmoji`]]
|
||||
The emoji of the option, if available. This can either be a string representing
|
||||
the custom or unicode emoji or an instance of :class:`.PartialEmoji` or :class:`.Emoji`.
|
||||
|
||||
@@ -35,7 +35,7 @@ from .utils import snowflake_time, _bytes_to_base64_data, MISSING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
from .channel import DMChannel
|
||||
from .guild import Guild
|
||||
from .message import Message
|
||||
@@ -49,7 +49,6 @@ __all__ = (
|
||||
'ClientUser',
|
||||
)
|
||||
|
||||
U = TypeVar('U', bound='User')
|
||||
BU = TypeVar('BU', bound='BaseUser')
|
||||
|
||||
|
||||
@@ -59,7 +58,18 @@ class _UserTag:
|
||||
|
||||
|
||||
class BaseUser(_UserTag):
|
||||
__slots__ = ('name', 'id', 'discriminator', '_avatar', '_banner', '_accent_colour', 'bot', 'system', '_public_flags', '_state')
|
||||
__slots__ = (
|
||||
'name',
|
||||
'id',
|
||||
'discriminator',
|
||||
'_avatar',
|
||||
'_banner',
|
||||
'_accent_colour',
|
||||
'bot',
|
||||
'system',
|
||||
'_public_flags',
|
||||
'_state',
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
name: str
|
||||
@@ -68,7 +78,7 @@ class BaseUser(_UserTag):
|
||||
bot: bool
|
||||
system: bool
|
||||
_state: ConnectionState
|
||||
_avatar: str
|
||||
_avatar: Optional[str]
|
||||
_banner: Optional[str]
|
||||
_accent_colour: Optional[str]
|
||||
_public_flags: int
|
||||
@@ -86,6 +96,9 @@ class BaseUser(_UserTag):
|
||||
def __str__(self) -> str:
|
||||
return f'{self.name}#{self.discriminator}'
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, _UserTag) and other.id == self.id
|
||||
|
||||
@@ -137,22 +150,31 @@ class BaseUser(_UserTag):
|
||||
return PublicUserFlags._from_value(self._public_flags)
|
||||
|
||||
@property
|
||||
def avatar(self) -> Asset:
|
||||
""":class:`Asset`: Returns an :class:`Asset` for the avatar the user has.
|
||||
def avatar(self) -> Optional[Asset]:
|
||||
"""Optional[:class:`Asset`]: Returns an :class:`Asset` for the avatar the user has.
|
||||
|
||||
If the user does not have a traditional avatar, an asset for
|
||||
the default avatar is returned instead.
|
||||
If the user does not have a traditional avatar, ``None`` is returned.
|
||||
If you want the avatar that a user has displayed, consider :attr:`display_avatar`.
|
||||
"""
|
||||
if self._avatar is None:
|
||||
return Asset._from_default_avatar(self._state, int(self.discriminator) % len(DefaultAvatar))
|
||||
else:
|
||||
if self._avatar is not None:
|
||||
return Asset._from_avatar(self._state, self.id, self._avatar)
|
||||
return None
|
||||
|
||||
@property
|
||||
def default_avatar(self) -> Asset:
|
||||
""":class:`Asset`: Returns the default avatar for a given user. This is calculated by the user's discriminator."""
|
||||
return Asset._from_default_avatar(self._state, int(self.discriminator) % len(DefaultAvatar))
|
||||
|
||||
@property
|
||||
def display_avatar(self) -> Asset:
|
||||
""":class:`Asset`: Returns the user's display avatar.
|
||||
|
||||
For regular users this is just their default avatar or uploaded avatar.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self.avatar or self.default_avatar
|
||||
|
||||
@property
|
||||
def banner(self) -> Optional[Asset]:
|
||||
"""Optional[:class:`Asset`]: Returns the user's banner asset, if available.
|
||||
@@ -309,7 +331,7 @@ class ClientUser(BaseUser):
|
||||
locale: Optional[str]
|
||||
mfa_enabled: bool
|
||||
_flags: int
|
||||
|
||||
|
||||
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
|
||||
super().__init__(state=state, data=data)
|
||||
|
||||
@@ -327,7 +349,7 @@ class ClientUser(BaseUser):
|
||||
self._flags = data.get('flags', 0)
|
||||
self.mfa_enabled = data.get('mfa_enabled', False)
|
||||
|
||||
async def edit(self, *, username: str = MISSING, avatar: bytes = MISSING) -> None:
|
||||
async def edit(self, *, username: str = MISSING, avatar: bytes = MISSING) -> ClientUser:
|
||||
"""|coro|
|
||||
|
||||
Edits the current profile of the client.
|
||||
@@ -341,6 +363,9 @@ class ClientUser(BaseUser):
|
||||
|
||||
The only image formats supported for uploading is JPEG and PNG.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The edit is no longer in-place, instead the newly edited client user is returned.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
username: :class:`str`
|
||||
@@ -355,6 +380,11 @@ class ClientUser(BaseUser):
|
||||
Editing your profile failed.
|
||||
InvalidArgument
|
||||
Wrong image format passed for ``avatar``.
|
||||
|
||||
Returns
|
||||
---------
|
||||
:class:`ClientUser`
|
||||
The newly edited client user.
|
||||
"""
|
||||
payload: Dict[str, Any] = {}
|
||||
if username is not MISSING:
|
||||
@@ -364,7 +394,7 @@ class ClientUser(BaseUser):
|
||||
payload['avatar'] = _bytes_to_base64_data(avatar)
|
||||
|
||||
data: UserPayload = await self._state.http.edit_profile(payload)
|
||||
self._update(data)
|
||||
return ClientUser(state=self._state, data=data)
|
||||
|
||||
|
||||
class User(BaseUser, discord.abc.Messageable):
|
||||
@@ -388,6 +418,10 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
|
||||
Returns the user's name with discriminator.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the user's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@@ -419,7 +453,7 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _copy(cls: Type[U], user: U) -> U:
|
||||
def _copy(cls, user: User):
|
||||
self = super()._copy(user)
|
||||
self._stored = False
|
||||
return self
|
||||
|
||||
@@ -120,6 +120,9 @@ class _cached_property:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from functools import cached_property as cached_property
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from .permissions import Permissions
|
||||
from .abc import Snowflake
|
||||
from .invite import Invite
|
||||
@@ -129,6 +132,8 @@ if TYPE_CHECKING:
|
||||
headers: Mapping[str, Any]
|
||||
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
else:
|
||||
cached_property = _cached_property
|
||||
|
||||
@@ -231,8 +236,8 @@ def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]:
|
||||
return None
|
||||
|
||||
|
||||
def copy_doc(original: Callable[..., Any]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(overriden: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def copy_doc(original: Callable) -> Callable[[T], T]:
|
||||
def decorator(overriden: T) -> T:
|
||||
overriden.__doc__ = original.__doc__
|
||||
overriden.__signature__ = _signature(original) # type: ignore
|
||||
return overriden
|
||||
@@ -240,10 +245,10 @@ def copy_doc(original: Callable[..., Any]) -> Callable[[Callable[..., Any]], Cal
|
||||
return decorator
|
||||
|
||||
|
||||
def deprecated(instead: Optional[str] = None) -> Callable[[Callable[..., T]], Callable[..., T]]:
|
||||
def actual_decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
def deprecated(instead: Optional[str] = None) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||
def actual_decorator(func: Callable[P, T]) -> Callable[P, T]:
|
||||
@functools.wraps(func)
|
||||
def decorated(*args, **kwargs) -> T:
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
warnings.simplefilter('always', DeprecationWarning) # turn off filter
|
||||
if instead:
|
||||
fmt = "{0.__name__} is deprecated, use {1} instead."
|
||||
@@ -267,7 +272,7 @@ def oauth_url(
|
||||
redirect_uri: str = MISSING,
|
||||
scopes: Iterable[str] = MISSING,
|
||||
disable_guild_select: bool = False,
|
||||
):
|
||||
) -> str:
|
||||
"""A helper function that returns the OAuth2 URL for inviting the bot
|
||||
into guilds.
|
||||
|
||||
@@ -479,17 +484,17 @@ def _bytes_to_base64_data(data: bytes) -> str:
|
||||
|
||||
if HAS_ORJSON:
|
||||
|
||||
def to_json(obj: Any) -> str: # type: ignore
|
||||
def _to_json(obj: Any) -> str: # type: ignore
|
||||
return orjson.dumps(obj).decode('utf-8')
|
||||
|
||||
from_json = orjson.loads # type: ignore
|
||||
_from_json = orjson.loads # type: ignore
|
||||
|
||||
else:
|
||||
|
||||
def to_json(obj: Any) -> str:
|
||||
def _to_json(obj: Any) -> str:
|
||||
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
|
||||
|
||||
from_json = json.loads
|
||||
_from_json = json.loads
|
||||
|
||||
|
||||
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
|
||||
|
||||
@@ -84,7 +84,7 @@ __all__ = (
|
||||
|
||||
|
||||
|
||||
log: logging.Logger = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
class VoiceProtocol:
|
||||
"""A class that represents the Discord voice protocol.
|
||||
@@ -304,7 +304,7 @@ class VoiceClient(VoiceProtocol):
|
||||
|
||||
async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
|
||||
if self._voice_server_complete.is_set():
|
||||
log.info('Ignoring extraneous voice server update.')
|
||||
_log.info('Ignoring extraneous voice server update.')
|
||||
return
|
||||
|
||||
self.token = data.get('token')
|
||||
@@ -312,7 +312,7 @@ class VoiceClient(VoiceProtocol):
|
||||
endpoint = data.get('endpoint')
|
||||
|
||||
if endpoint is None or self.token is None:
|
||||
log.warning('Awaiting endpoint... This requires waiting. ' \
|
||||
_log.warning('Awaiting endpoint... This requires waiting. ' \
|
||||
'If timeout occurred considering raising the timeout and reconnecting.')
|
||||
return
|
||||
|
||||
@@ -338,18 +338,18 @@ class VoiceClient(VoiceProtocol):
|
||||
await self.channel.guild.change_voice_state(channel=self.channel)
|
||||
|
||||
async def voice_disconnect(self) -> None:
|
||||
log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
|
||||
_log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
|
||||
await self.channel.guild.change_voice_state(channel=None)
|
||||
|
||||
def prepare_handshake(self) -> None:
|
||||
self._voice_state_complete.clear()
|
||||
self._voice_server_complete.clear()
|
||||
self._handshaking = True
|
||||
log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
|
||||
_log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
|
||||
self._connections += 1
|
||||
|
||||
def finish_handshake(self) -> None:
|
||||
log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
|
||||
_log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
|
||||
self._handshaking = False
|
||||
self._voice_server_complete.clear()
|
||||
self._voice_state_complete.clear()
|
||||
@@ -363,7 +363,7 @@ class VoiceClient(VoiceProtocol):
|
||||
return ws
|
||||
|
||||
async def connect(self, *, reconnect: bool, timeout: float) ->None:
|
||||
log.info('Connecting to voice...')
|
||||
_log.info('Connecting to voice...')
|
||||
self.timeout = timeout
|
||||
|
||||
for i in range(5):
|
||||
@@ -391,7 +391,7 @@ class VoiceClient(VoiceProtocol):
|
||||
break
|
||||
except (ConnectionClosed, asyncio.TimeoutError):
|
||||
if reconnect:
|
||||
log.exception('Failed to connect to voice... Retrying...')
|
||||
_log.exception('Failed to connect to voice... Retrying...')
|
||||
await asyncio.sleep(1 + i * 2.0)
|
||||
await self.voice_disconnect()
|
||||
continue
|
||||
@@ -456,14 +456,14 @@ class VoiceClient(VoiceProtocol):
|
||||
# 4014 - voice channel has been deleted.
|
||||
# 4015 - voice server has crashed
|
||||
if exc.code in (1000, 4015):
|
||||
log.info('Disconnecting from voice normally, close code %d.', exc.code)
|
||||
_log.info('Disconnecting from voice normally, close code %d.', exc.code)
|
||||
await self.disconnect()
|
||||
break
|
||||
if exc.code == 4014:
|
||||
log.info('Disconnected from voice by force... potentially reconnecting.')
|
||||
_log.info('Disconnected from voice by force... potentially reconnecting.')
|
||||
successful = await self.potential_reconnect()
|
||||
if not successful:
|
||||
log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
|
||||
_log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
|
||||
await self.disconnect()
|
||||
break
|
||||
else:
|
||||
@@ -474,7 +474,7 @@ class VoiceClient(VoiceProtocol):
|
||||
raise
|
||||
|
||||
retry = backoff.delay()
|
||||
log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
|
||||
_log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
|
||||
self._connected.clear()
|
||||
await asyncio.sleep(retry)
|
||||
await self.voice_disconnect()
|
||||
@@ -482,7 +482,7 @@ class VoiceClient(VoiceProtocol):
|
||||
await self.connect(reconnect=True, timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
# at this point we've retried 5 times... let's continue the loop.
|
||||
log.warning('Could not connect to voice... Retrying...')
|
||||
_log.warning('Could not connect to voice... Retrying...')
|
||||
continue
|
||||
|
||||
async def disconnect(self, *, force: bool = False) -> None:
|
||||
@@ -674,6 +674,6 @@ class VoiceClient(VoiceProtocol):
|
||||
try:
|
||||
self.socket.sendto(packet, (self.endpoint_ip, self.voice_port))
|
||||
except BlockingIOError:
|
||||
log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
|
||||
_log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
|
||||
|
||||
self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295)
|
||||
|
||||
@@ -24,7 +24,6 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import asyncio
|
||||
import json
|
||||
@@ -32,6 +31,7 @@ import re
|
||||
|
||||
from urllib.parse import quote as urlquote
|
||||
from typing import Any, Dict, List, Literal, NamedTuple, Optional, TYPE_CHECKING, Tuple, Union, overload
|
||||
from contextvars import ContextVar
|
||||
|
||||
import aiohttp
|
||||
|
||||
@@ -43,7 +43,7 @@ from ..user import BaseUser, User
|
||||
from ..asset import Asset
|
||||
from ..http import Route
|
||||
from ..mixins import Hashable
|
||||
from ..object import Object
|
||||
from ..channel import PartialMessageable
|
||||
|
||||
__all__ = (
|
||||
'Webhook',
|
||||
@@ -52,15 +52,20 @@ __all__ = (
|
||||
'PartialWebhookGuild',
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..file import File
|
||||
from ..embeds import Embed
|
||||
from ..mentions import AllowedMentions
|
||||
from ..state import ConnectionState
|
||||
from ..http import Response
|
||||
from ..types.webhook import (
|
||||
Webhook as WebhookPayload,
|
||||
)
|
||||
from ..types.message import (
|
||||
Message as MessagePayload,
|
||||
)
|
||||
from ..guild import Guild
|
||||
from ..channel import TextChannel
|
||||
from ..abc import Snowflake
|
||||
@@ -116,7 +121,7 @@ class AsyncWebhookAdapter:
|
||||
|
||||
if payload is not None:
|
||||
headers['Content-Type'] = 'application/json'
|
||||
to_send = utils.to_json(payload)
|
||||
to_send = utils._to_json(payload)
|
||||
|
||||
if auth_token is not None:
|
||||
headers['Authorization'] = f'Bot {auth_token}'
|
||||
@@ -143,7 +148,7 @@ class AsyncWebhookAdapter:
|
||||
|
||||
try:
|
||||
async with session.request(method, url, data=to_send, headers=headers, params=params) as response:
|
||||
log.debug(
|
||||
_log.debug(
|
||||
'Webhook ID %s with %s %s has returned status code %s',
|
||||
webhook_id,
|
||||
method,
|
||||
@@ -157,7 +162,7 @@ class AsyncWebhookAdapter:
|
||||
remaining = response.headers.get('X-Ratelimit-Remaining')
|
||||
if remaining == '0' and response.status != 429:
|
||||
delta = utils._parse_ratelimit_header(response)
|
||||
log.debug(
|
||||
_log.debug(
|
||||
'Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds', webhook_id, delta
|
||||
)
|
||||
lock.delay_by(delta)
|
||||
@@ -170,7 +175,7 @@ class AsyncWebhookAdapter:
|
||||
raise HTTPException(response, data)
|
||||
|
||||
retry_after: float = data['retry_after'] # type: ignore
|
||||
log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
|
||||
_log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
|
||||
@@ -205,7 +210,7 @@ class AsyncWebhookAdapter:
|
||||
token: Optional[str] = None,
|
||||
session: aiohttp.ClientSession,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
) -> Response[None]:
|
||||
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
|
||||
return self.request(route, session, reason=reason, auth_token=token)
|
||||
|
||||
@@ -216,7 +221,7 @@ class AsyncWebhookAdapter:
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
) -> Response[None]:
|
||||
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
return self.request(route, session, reason=reason)
|
||||
|
||||
@@ -228,7 +233,7 @@ class AsyncWebhookAdapter:
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
) -> Response[WebhookPayload]:
|
||||
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
|
||||
return self.request(route, session, reason=reason, payload=payload, auth_token=token)
|
||||
|
||||
@@ -240,7 +245,7 @@ class AsyncWebhookAdapter:
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
) -> Response[WebhookPayload]:
|
||||
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
return self.request(route, session, reason=reason, payload=payload)
|
||||
|
||||
@@ -255,7 +260,7 @@ class AsyncWebhookAdapter:
|
||||
files: Optional[List[File]] = None,
|
||||
thread_id: Optional[int] = None,
|
||||
wait: bool = False,
|
||||
):
|
||||
) -> Response[Optional[MessagePayload]]:
|
||||
params = {'wait': int(wait)}
|
||||
if thread_id:
|
||||
params['thread_id'] = thread_id
|
||||
@@ -269,7 +274,7 @@ class AsyncWebhookAdapter:
|
||||
message_id: int,
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
):
|
||||
) -> Response[MessagePayload]:
|
||||
route = Route(
|
||||
'GET',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
|
||||
@@ -289,7 +294,7 @@ class AsyncWebhookAdapter:
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
multipart: Optional[List[Dict[str, Any]]] = None,
|
||||
files: Optional[List[File]] = None,
|
||||
):
|
||||
) -> Response[Message]:
|
||||
route = Route(
|
||||
'PATCH',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
|
||||
@@ -306,7 +311,7 @@ class AsyncWebhookAdapter:
|
||||
message_id: int,
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
):
|
||||
) -> Response[None]:
|
||||
route = Route(
|
||||
'DELETE',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
|
||||
@@ -322,7 +327,7 @@ class AsyncWebhookAdapter:
|
||||
token: str,
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
):
|
||||
) -> Response[WebhookPayload]:
|
||||
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
|
||||
return self.request(route, session=session, auth_token=token)
|
||||
|
||||
@@ -332,7 +337,7 @@ class AsyncWebhookAdapter:
|
||||
token: str,
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
):
|
||||
) -> Response[WebhookPayload]:
|
||||
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
return self.request(route, session=session)
|
||||
|
||||
@@ -344,7 +349,7 @@ class AsyncWebhookAdapter:
|
||||
session: aiohttp.ClientSession,
|
||||
type: int,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> Response[None]:
|
||||
payload: Dict[str, Any] = {
|
||||
'type': type,
|
||||
}
|
||||
@@ -367,7 +372,7 @@ class AsyncWebhookAdapter:
|
||||
token: str,
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
):
|
||||
) -> Response[MessagePayload]:
|
||||
r = Route(
|
||||
'GET',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/@original',
|
||||
@@ -385,7 +390,7 @@ class AsyncWebhookAdapter:
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
multipart: Optional[List[Dict[str, Any]]] = None,
|
||||
files: Optional[List[File]] = None,
|
||||
):
|
||||
) -> Response[MessagePayload]:
|
||||
r = Route(
|
||||
'PATCH',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/@original',
|
||||
@@ -400,7 +405,7 @@ class AsyncWebhookAdapter:
|
||||
token: str,
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
):
|
||||
) -> Response[None]:
|
||||
r = Route(
|
||||
'DELETE',
|
||||
'/webhooks/{webhook_id}/{wehook_token}/messages/@original',
|
||||
@@ -420,7 +425,7 @@ def handle_message_parameters(
|
||||
content: Optional[str] = MISSING,
|
||||
*,
|
||||
username: str = MISSING,
|
||||
avatar_url: str = MISSING,
|
||||
avatar_url: Any = MISSING,
|
||||
tts: bool = False,
|
||||
ephemeral: bool = False,
|
||||
file: File = MISSING,
|
||||
@@ -481,7 +486,7 @@ def handle_message_parameters(
|
||||
files = [file]
|
||||
|
||||
if files:
|
||||
multipart.append({'name': 'payload_json', 'value': utils.to_json(payload)})
|
||||
multipart.append({'name': 'payload_json', 'value': utils._to_json(payload)})
|
||||
payload = None
|
||||
if len(files) == 1:
|
||||
file = files[0]
|
||||
@@ -507,7 +512,7 @@ def handle_message_parameters(
|
||||
return ExecuteWebhookParameters(payload=payload, multipart=multipart, files=files)
|
||||
|
||||
|
||||
async_context = contextvars.ContextVar('async_webhook_context', default=AsyncWebhookAdapter())
|
||||
async_context: ContextVar[AsyncWebhookAdapter] = ContextVar('async_webhook_context', default=AsyncWebhookAdapter())
|
||||
|
||||
|
||||
class PartialWebhookChannel(Hashable):
|
||||
@@ -579,10 +584,11 @@ class _FriendlyHttpAttributeErrorHelper:
|
||||
class _WebhookState:
|
||||
__slots__ = ('_parent', '_webhook')
|
||||
|
||||
def __init__(self, webhook, parent):
|
||||
self._webhook = webhook
|
||||
def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]):
|
||||
self._webhook: Any = webhook
|
||||
|
||||
if isinstance(parent, self.__class__):
|
||||
self._parent: Optional[ConnectionState]
|
||||
if isinstance(parent, _WebhookState):
|
||||
self._parent = None
|
||||
else:
|
||||
self._parent = parent
|
||||
@@ -595,10 +601,12 @@ class _WebhookState:
|
||||
def store_user(self, data):
|
||||
if self._parent is not None:
|
||||
return self._parent.store_user(data)
|
||||
return BaseUser(state=self, data=data)
|
||||
# state parameter is artificial
|
||||
return BaseUser(state=self, data=data) # type: ignore
|
||||
|
||||
def create_user(self, data):
|
||||
return BaseUser(state=self, data=data)
|
||||
# state parameter is artificial
|
||||
return BaseUser(state=self, data=data) # type: ignore
|
||||
|
||||
@property
|
||||
def http(self):
|
||||
@@ -639,13 +647,16 @@ class WebhookMessage(Message):
|
||||
files: List[File] = MISSING,
|
||||
view: Optional[View] = MISSING,
|
||||
allowed_mentions: Optional[AllowedMentions] = None,
|
||||
):
|
||||
) -> WebhookMessage:
|
||||
"""|coro|
|
||||
|
||||
Edits the message.
|
||||
|
||||
.. versionadded:: 1.6
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The edit is no longer in-place, instead the newly edited message is returned.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
content: Optional[:class:`str`]
|
||||
@@ -685,8 +696,13 @@ class WebhookMessage(Message):
|
||||
The length of ``embeds`` was invalid
|
||||
InvalidArgument
|
||||
There was no token associated with this webhook.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`WebhookMessage`
|
||||
The newly edited message.
|
||||
"""
|
||||
await self._state._webhook.edit_message(
|
||||
return await self._state._webhook.edit_message(
|
||||
self.id,
|
||||
content=content,
|
||||
embeds=embeds,
|
||||
@@ -748,9 +764,9 @@ class BaseWebhook(Hashable):
|
||||
'_state',
|
||||
)
|
||||
|
||||
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state=None):
|
||||
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None):
|
||||
self.auth_token: Optional[str] = token
|
||||
self._state = state or _WebhookState(self, parent=state)
|
||||
self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(self, parent=state)
|
||||
self._update(data)
|
||||
|
||||
def _update(self, data: WebhookPayload):
|
||||
@@ -765,10 +781,8 @@ class BaseWebhook(Hashable):
|
||||
user = data.get('user')
|
||||
self.user: Optional[Union[BaseUser, User]] = None
|
||||
if user is not None:
|
||||
if self._state is None:
|
||||
self.user = BaseUser(state=None, data=user)
|
||||
else:
|
||||
self.user = User(state=self._state, data=user)
|
||||
# state parameter may be _WebhookState
|
||||
self.user = User(state=self._state, data=user) # type: ignore
|
||||
|
||||
source_channel = data.get('source_channel')
|
||||
if source_channel:
|
||||
@@ -872,6 +886,10 @@ class Webhook(BaseWebhook):
|
||||
|
||||
Returns the webhooks's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the webhooks's ID.
|
||||
|
||||
.. versionchanged:: 1.4
|
||||
Webhooks are now comparable and hashable.
|
||||
|
||||
@@ -919,12 +937,12 @@ class Webhook(BaseWebhook):
|
||||
return f'<Webhook id={self.id!r}>'
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
def url(self) -> str:
|
||||
""":class:`str` : Returns the webhook's url."""
|
||||
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
|
||||
|
||||
@classmethod
|
||||
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None):
|
||||
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
|
||||
"""Creates a partial :class:`Webhook`.
|
||||
|
||||
Parameters
|
||||
@@ -960,7 +978,7 @@ class Webhook(BaseWebhook):
|
||||
return cls(data, session, token=bot_token)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None):
|
||||
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
|
||||
"""Creates a partial :class:`Webhook` from a webhook URL.
|
||||
|
||||
Parameters
|
||||
@@ -999,7 +1017,7 @@ class Webhook(BaseWebhook):
|
||||
return cls(data, session, token=bot_token) # type: ignore
|
||||
|
||||
@classmethod
|
||||
def _as_follower(cls, data, *, channel, user):
|
||||
def _as_follower(cls, data, *, channel, user) -> Webhook:
|
||||
name = f"{channel.guild} #{channel}"
|
||||
feed: WebhookPayload = {
|
||||
'id': data['webhook_id'],
|
||||
@@ -1015,7 +1033,7 @@ class Webhook(BaseWebhook):
|
||||
return cls(feed, session=session, state=state, token=state.http.token)
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, data, state):
|
||||
def from_state(cls, data, state) -> Webhook:
|
||||
session = state.http._HTTPClient__session
|
||||
return cls(data, session=session, state=state, token=state.http.token)
|
||||
|
||||
@@ -1111,7 +1129,7 @@ class Webhook(BaseWebhook):
|
||||
avatar: Optional[bytes] = MISSING,
|
||||
channel: Optional[Snowflake] = None,
|
||||
prefer_auth: bool = True,
|
||||
):
|
||||
) -> Webhook:
|
||||
"""|coro|
|
||||
|
||||
Edits this Webhook.
|
||||
@@ -1158,6 +1176,7 @@ class Webhook(BaseWebhook):
|
||||
|
||||
adapter = async_context.get()
|
||||
|
||||
data: Optional[WebhookPayload] = None
|
||||
# If a channel is given, always use the authenticated endpoint
|
||||
if channel is not None:
|
||||
if self.auth_token is None:
|
||||
@@ -1165,21 +1184,24 @@ class Webhook(BaseWebhook):
|
||||
|
||||
payload['channel_id'] = channel.id
|
||||
data = await adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
|
||||
self._update(data)
|
||||
return
|
||||
|
||||
if prefer_auth and self.auth_token:
|
||||
data = await adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
|
||||
self._update(data)
|
||||
elif self.token:
|
||||
data = await adapter.edit_webhook_with_token(
|
||||
self.id, self.token, payload=payload, session=self.session, reason=reason
|
||||
)
|
||||
self._update(data)
|
||||
|
||||
if data is None:
|
||||
raise RuntimeError('Unreachable code hit: data was not assigned')
|
||||
|
||||
return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state)
|
||||
|
||||
def _create_message(self, data):
|
||||
state = _WebhookState(self, parent=self._state)
|
||||
channel = self.channel or Object(id=int(data['channel_id']))
|
||||
# state may be artificial (unlikely at this point...)
|
||||
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
|
||||
# state is artificial
|
||||
return WebhookMessage(data=data, state=state, channel=channel) # type: ignore
|
||||
|
||||
@overload
|
||||
@@ -1188,7 +1210,7 @@ class Webhook(BaseWebhook):
|
||||
content: str = MISSING,
|
||||
*,
|
||||
username: str = MISSING,
|
||||
avatar_url: str = MISSING,
|
||||
avatar_url: Any = MISSING,
|
||||
tts: bool = MISSING,
|
||||
ephemeral: bool = MISSING,
|
||||
file: File = MISSING,
|
||||
@@ -1208,7 +1230,7 @@ class Webhook(BaseWebhook):
|
||||
content: str = MISSING,
|
||||
*,
|
||||
username: str = MISSING,
|
||||
avatar_url: str = MISSING,
|
||||
avatar_url: Any = MISSING,
|
||||
tts: bool = MISSING,
|
||||
ephemeral: bool = MISSING,
|
||||
file: File = MISSING,
|
||||
@@ -1227,7 +1249,7 @@ class Webhook(BaseWebhook):
|
||||
content: str = MISSING,
|
||||
*,
|
||||
username: str = MISSING,
|
||||
avatar_url: str = MISSING,
|
||||
avatar_url: Any = MISSING,
|
||||
tts: bool = False,
|
||||
ephemeral: bool = False,
|
||||
file: File = MISSING,
|
||||
@@ -1264,9 +1286,10 @@ class Webhook(BaseWebhook):
|
||||
username: :class:`str`
|
||||
The username to send with this message. If no username is provided
|
||||
then the default username for the webhook is used.
|
||||
avatar_url: Union[:class:`str`, :class:`Asset`]
|
||||
avatar_url: :class:`str`
|
||||
The avatar URL to send with this message. If no avatar URL is provided
|
||||
then the default avatar for the webhook is used.
|
||||
then the default avatar for the webhook is used. If this is not a
|
||||
string then it is explicitly cast using ``str``.
|
||||
tts: :class:`bool`
|
||||
Indicates if the message should be sent using text-to-speech.
|
||||
ephemeral: :class:`bool`
|
||||
@@ -1438,7 +1461,7 @@ class Webhook(BaseWebhook):
|
||||
files: List[File] = MISSING,
|
||||
view: Optional[View] = MISSING,
|
||||
allowed_mentions: Optional[AllowedMentions] = None,
|
||||
):
|
||||
) -> WebhookMessage:
|
||||
"""|coro|
|
||||
|
||||
Edits a message owned by this webhook.
|
||||
@@ -1448,6 +1471,9 @@ class Webhook(BaseWebhook):
|
||||
|
||||
.. versionadded:: 1.6
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
The edit is no longer in-place, instead the newly edited message is returned.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
message_id: :class:`int`
|
||||
@@ -1491,6 +1517,11 @@ class Webhook(BaseWebhook):
|
||||
InvalidArgument
|
||||
There was no token associated with this webhook or the webhook had
|
||||
no state.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`WebhookMessage`
|
||||
The newly edited webhook message.
|
||||
"""
|
||||
|
||||
if self.token is None:
|
||||
@@ -1514,7 +1545,7 @@ class Webhook(BaseWebhook):
|
||||
previous_allowed_mentions=previous_mentions,
|
||||
)
|
||||
adapter = async_context.get()
|
||||
await adapter.edit_webhook_message(
|
||||
data = await adapter.edit_webhook_message(
|
||||
self.id,
|
||||
self.token,
|
||||
message_id,
|
||||
@@ -1524,10 +1555,12 @@ class Webhook(BaseWebhook):
|
||||
files=params.files,
|
||||
)
|
||||
|
||||
message = self._create_message(data)
|
||||
if view and not view.is_finished():
|
||||
self._state.store_view(view, message_id)
|
||||
return message
|
||||
|
||||
async def delete_message(self, message_id: int):
|
||||
async def delete_message(self, message_id: int, /) -> None:
|
||||
"""|coro|
|
||||
|
||||
Deletes a message owned by this webhook.
|
||||
|
||||
@@ -43,7 +43,7 @@ from .. import utils
|
||||
from ..errors import InvalidArgument, HTTPException, Forbidden, NotFound, DiscordServerError
|
||||
from ..message import Message
|
||||
from ..http import Route
|
||||
from ..object import Object
|
||||
from ..channel import PartialMessageable
|
||||
|
||||
from .async_ import BaseWebhook, handle_message_parameters, _WebhookState
|
||||
|
||||
@@ -52,7 +52,7 @@ __all__ = (
|
||||
'SyncWebhookMessage',
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..file import File
|
||||
@@ -117,7 +117,7 @@ class WebhookAdapter:
|
||||
|
||||
if payload is not None:
|
||||
headers['Content-Type'] = 'application/json'
|
||||
to_send = utils.to_json(payload)
|
||||
to_send = utils._to_json(payload)
|
||||
|
||||
if auth_token is not None:
|
||||
headers['Authorization'] = f'Bot {auth_token}'
|
||||
@@ -150,7 +150,7 @@ class WebhookAdapter:
|
||||
with session.request(
|
||||
method, url, data=to_send, files=file_data, headers=headers, params=params
|
||||
) as response:
|
||||
log.debug(
|
||||
_log.debug(
|
||||
'Webhook ID %s with %s %s has returned status code %s',
|
||||
webhook_id,
|
||||
method,
|
||||
@@ -168,7 +168,7 @@ class WebhookAdapter:
|
||||
remaining = response.headers.get('X-Ratelimit-Remaining')
|
||||
if remaining == '0' and response.status_code != 429:
|
||||
delta = utils._parse_ratelimit_header(response)
|
||||
log.debug(
|
||||
_log.debug(
|
||||
'Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds', webhook_id, delta
|
||||
)
|
||||
lock.delay_by(delta)
|
||||
@@ -181,7 +181,7 @@ class WebhookAdapter:
|
||||
raise HTTPException(response, data)
|
||||
|
||||
retry_after: float = data['retry_after'] # type: ignore
|
||||
log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
|
||||
_log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
|
||||
time.sleep(retry_after)
|
||||
continue
|
||||
|
||||
@@ -373,6 +373,8 @@ class SyncWebhookMessage(Message):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
_state: _WebhookState
|
||||
|
||||
def edit(
|
||||
self,
|
||||
content: Optional[str] = MISSING,
|
||||
@@ -381,7 +383,7 @@ class SyncWebhookMessage(Message):
|
||||
file: File = MISSING,
|
||||
files: List[File] = MISSING,
|
||||
allowed_mentions: Optional[AllowedMentions] = None,
|
||||
):
|
||||
) -> SyncWebhookMessage:
|
||||
"""Edits the message.
|
||||
|
||||
Parameters
|
||||
@@ -414,8 +416,13 @@ class SyncWebhookMessage(Message):
|
||||
The length of ``embeds`` was invalid
|
||||
InvalidArgument
|
||||
There was no token associated with this webhook.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`SyncWebhookMessage`
|
||||
The newly edited message.
|
||||
"""
|
||||
self._state._webhook.edit_message(
|
||||
return self._state._webhook.edit_message(
|
||||
self.id,
|
||||
content=content,
|
||||
embeds=embeds,
|
||||
@@ -468,6 +475,10 @@ class SyncWebhook(BaseWebhook):
|
||||
|
||||
Returns the webhooks's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the webhooks's ID.
|
||||
|
||||
.. versionchanged:: 1.4
|
||||
Webhooks are now comparable and hashable.
|
||||
|
||||
@@ -515,12 +526,12 @@ class SyncWebhook(BaseWebhook):
|
||||
return f'<Webhook id={self.id!r}>'
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
def url(self) -> str:
|
||||
""":class:`str` : Returns the webhook's url."""
|
||||
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
|
||||
|
||||
@classmethod
|
||||
def partial(cls, id: int, token: str, *, session: Session = MISSING, bot_token: Optional[str] = None):
|
||||
def partial(cls, id: int, token: str, *, session: Session = MISSING, bot_token: Optional[str] = None) -> SyncWebhook:
|
||||
"""Creates a partial :class:`Webhook`.
|
||||
|
||||
Parameters
|
||||
@@ -559,7 +570,7 @@ class SyncWebhook(BaseWebhook):
|
||||
return cls(data, session, token=bot_token)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, *, session: Session = MISSING, bot_token: Optional[str] = None):
|
||||
def from_url(cls, url: str, *, session: Session = MISSING, bot_token: Optional[str] = None) -> SyncWebhook:
|
||||
"""Creates a partial :class:`Webhook` from a webhook URL.
|
||||
|
||||
Parameters
|
||||
@@ -643,7 +654,7 @@ class SyncWebhook(BaseWebhook):
|
||||
|
||||
return SyncWebhook(data, self.session, token=self.auth_token, state=self._state)
|
||||
|
||||
def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True):
|
||||
def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True) -> None:
|
||||
"""Deletes this Webhook.
|
||||
|
||||
Parameters
|
||||
@@ -685,7 +696,7 @@ class SyncWebhook(BaseWebhook):
|
||||
avatar: Optional[bytes] = MISSING,
|
||||
channel: Optional[Snowflake] = None,
|
||||
prefer_auth: bool = True,
|
||||
):
|
||||
) -> SyncWebhook:
|
||||
"""Edits this Webhook.
|
||||
|
||||
Parameters
|
||||
@@ -713,6 +724,11 @@ class SyncWebhook(BaseWebhook):
|
||||
InvalidArgument
|
||||
This webhook does not have a token associated with it
|
||||
or it tried editing a channel without authentication.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`SyncWebhook`
|
||||
The newly edited webhook.
|
||||
"""
|
||||
if self.token is None and self.auth_token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
@@ -726,6 +742,7 @@ class SyncWebhook(BaseWebhook):
|
||||
|
||||
adapter: WebhookAdapter = _get_webhook_adapter()
|
||||
|
||||
data: Optional[WebhookPayload] = None
|
||||
# If a channel is given, always use the authenticated endpoint
|
||||
if channel is not None:
|
||||
if self.auth_token is None:
|
||||
@@ -733,20 +750,23 @@ class SyncWebhook(BaseWebhook):
|
||||
|
||||
payload['channel_id'] = channel.id
|
||||
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
|
||||
self._update(data)
|
||||
return
|
||||
|
||||
if prefer_auth and self.auth_token:
|
||||
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
|
||||
self._update(data)
|
||||
elif self.token:
|
||||
data = adapter.edit_webhook_with_token(self.id, self.token, payload=payload, session=self.session, reason=reason)
|
||||
self._update(data)
|
||||
|
||||
if data is None:
|
||||
raise RuntimeError('Unreachable code hit: data was not assigned')
|
||||
|
||||
return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state)
|
||||
|
||||
def _create_message(self, data):
|
||||
state = _WebhookState(self, parent=self._state)
|
||||
channel = self.channel or Object(id=int(data['channel_id']))
|
||||
return SyncWebhookMessage(data=data, state=state, channel=channel)
|
||||
# state may be artificial (unlikely at this point...)
|
||||
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
|
||||
# state is artificial
|
||||
return SyncWebhookMessage(data=data, state=state, channel=channel) # type: ignore
|
||||
|
||||
@overload
|
||||
def send(
|
||||
@@ -754,7 +774,7 @@ class SyncWebhook(BaseWebhook):
|
||||
content: str = MISSING,
|
||||
*,
|
||||
username: str = MISSING,
|
||||
avatar_url: str = MISSING,
|
||||
avatar_url: Any = MISSING,
|
||||
tts: bool = MISSING,
|
||||
file: File = MISSING,
|
||||
files: List[File] = MISSING,
|
||||
@@ -771,7 +791,7 @@ class SyncWebhook(BaseWebhook):
|
||||
content: str = MISSING,
|
||||
*,
|
||||
username: str = MISSING,
|
||||
avatar_url: str = MISSING,
|
||||
avatar_url: Any = MISSING,
|
||||
tts: bool = MISSING,
|
||||
file: File = MISSING,
|
||||
files: List[File] = MISSING,
|
||||
@@ -787,7 +807,7 @@ class SyncWebhook(BaseWebhook):
|
||||
content: str = MISSING,
|
||||
*,
|
||||
username: str = MISSING,
|
||||
avatar_url: str = MISSING,
|
||||
avatar_url: Any = MISSING,
|
||||
tts: bool = False,
|
||||
file: File = MISSING,
|
||||
files: List[File] = MISSING,
|
||||
@@ -819,9 +839,10 @@ class SyncWebhook(BaseWebhook):
|
||||
username: :class:`str`
|
||||
The username to send with this message. If no username is provided
|
||||
then the default username for the webhook is used.
|
||||
avatar_url: Union[:class:`str`, :class:`Asset`]
|
||||
avatar_url: :class:`str`
|
||||
The avatar URL to send with this message. If no avatar URL is provided
|
||||
then the default avatar for the webhook is used.
|
||||
then the default avatar for the webhook is used. If this is not a
|
||||
string then it is explicitly cast using ``str``.
|
||||
tts: :class:`bool`
|
||||
Indicates if the message should be sent using text-to-speech.
|
||||
file: :class:`File`
|
||||
@@ -902,7 +923,7 @@ class SyncWebhook(BaseWebhook):
|
||||
if wait:
|
||||
return self._create_message(data)
|
||||
|
||||
def fetch_message(self, id: int) -> SyncWebhookMessage:
|
||||
def fetch_message(self, id: int, /) -> SyncWebhookMessage:
|
||||
"""Retrieves a single :class:`~discord.SyncWebhookMessage` owned by this webhook.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
@@ -951,7 +972,7 @@ class SyncWebhook(BaseWebhook):
|
||||
file: File = MISSING,
|
||||
files: List[File] = MISSING,
|
||||
allowed_mentions: Optional[AllowedMentions] = None,
|
||||
):
|
||||
) -> SyncWebhookMessage:
|
||||
"""Edits a message owned by this webhook.
|
||||
|
||||
This is a lower level interface to :meth:`WebhookMessage.edit` in case
|
||||
@@ -1007,7 +1028,7 @@ class SyncWebhook(BaseWebhook):
|
||||
previous_allowed_mentions=previous_mentions,
|
||||
)
|
||||
adapter: WebhookAdapter = _get_webhook_adapter()
|
||||
adapter.edit_webhook_message(
|
||||
data = adapter.edit_webhook_message(
|
||||
self.id,
|
||||
self.token,
|
||||
message_id,
|
||||
@@ -1016,8 +1037,9 @@ class SyncWebhook(BaseWebhook):
|
||||
multipart=params.multipart,
|
||||
files=params.files,
|
||||
)
|
||||
return self._create_message(data)
|
||||
|
||||
def delete_message(self, message_id: int):
|
||||
def delete_message(self, message_id: int, /) -> None:
|
||||
"""Deletes a message owned by this webhook.
|
||||
|
||||
This is a lower level interface to :meth:`WebhookMessage.delete` in case
|
||||
|
||||
Reference in New Issue
Block a user