Merge branch 'iDevision:2.0' into 2.0

This commit is contained in:
Gnome! 2021-09-05 15:57:14 +01:00 committed by GitHub
commit 6af5399936
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 1033 additions and 247 deletions

View File

@ -1,5 +1,7 @@
## Contributing to discord.py ## Contributing to discord.py
Credits to the `original lib` by Rapptz <https://github.com/Rapptz/discord.py>
First off, thanks for taking the time to contribute. It makes the library substantially better. :+1: First off, thanks for taking the time to contribute. It makes the library substantially better. :+1:
The following is a set of guidelines for contributing to the repository. These are guidelines, not hard rules. The following is a set of guidelines for contributing to the repository. These are guidelines, not hard rules.
@ -8,9 +10,9 @@ The following is a set of guidelines for contributing to the repository. These a
Generally speaking questions are better suited in our resources below. Generally speaking questions are better suited in our resources below.
- The official support server: https://discord.gg/r3sSKJJ - The official support server: https://discord.gg/TvqYBrGXEm
- The Discord API server under #python_discord-py: https://discord.gg/discord-api - The Discord API server under #python_discord-py: https://discord.gg/discord-api
- [The FAQ in the documentation](https://discordpy.readthedocs.io/en/latest/faq.html) - [The FAQ in the documentation](https://enhanced-dpy.readthedocs.io/en/latest/faq.html)
- [StackOverflow's `discord.py` tag](https://stackoverflow.com/questions/tagged/discord.py) - [StackOverflow's `discord.py` tag](https://stackoverflow.com/questions/tagged/discord.py)
Please try your best not to ask questions in our issue tracker. Most of them don't belong there unless they provide value to a larger audience. Please try your best not to ask questions in our issue tracker. Most of them don't belong there unless they provide value to a larger audience.

View File

@ -1,5 +1,5 @@
discord.py enhanced-discord.py
========== ===================
.. image:: https://discord.com/api/guilds/514232441498763279/embed.png .. image:: https://discord.com/api/guilds/514232441498763279/embed.png
:target: https://discord.gg/PYAfZzpsjG :target: https://discord.gg/PYAfZzpsjG
@ -59,7 +59,7 @@ To install the development version, do the following:
.. code:: sh .. code:: sh
$ git clone https://github.com/iDevision/enhanced-discord.py $ git clone https://github.com/iDevision/enhanced-discord.py
$ cd discord.py $ cd enhanced-discord.py
$ python3 -m pip install -U .[voice] $ python3 -m pip install -U .[voice]

View File

@ -40,6 +40,7 @@ from .colour import *
from .integrations import * from .integrations import *
from .invite import * from .invite import *
from .template import * from .template import *
from .welcome_screen import *
from .widget import * from .widget import *
from .object import * from .object import *
from .reaction import * from .reaction import *

View File

@ -794,13 +794,13 @@ class CustomActivity(BaseActivity):
return hash((self.name, str(self.emoji))) return hash((self.name, str(self.emoji)))
def __str__(self) -> str: def __str__(self) -> str:
if self.emoji: if not self.emoji:
if self.name:
return f'{self.emoji} {self.name}'
return str(self.emoji)
else:
return str(self.name) return str(self.name)
if self.name:
return f'{self.emoji} {self.name}'
return str(self.emoji)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>' return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>'

View File

@ -313,10 +313,11 @@ class Asset(AssetMixin):
if self._animated: if self._animated:
if format not in VALID_ASSET_FORMATS: if format not in VALID_ASSET_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}') raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}')
else: url = url.with_path(f'{path}.{format}')
elif static_format is MISSING:
if format not in VALID_STATIC_FORMATS: if format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}') raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
url = url.with_path(f'{path}.{format}') url = url.with_path(f'{path}.{format}')
if static_format is not MISSING and not self._animated: if static_format is not MISSING and not self._animated:
if static_format not in VALID_STATIC_FORMATS: if static_format not in VALID_STATIC_FORMATS:

View File

@ -330,6 +330,10 @@ class AuditLogEntry(Hashable):
Returns the entry's hash. Returns the entry's hash.
.. describe:: int(x)
Returns the entry's ID.
.. versionchanged:: 1.7 .. versionchanged:: 1.7
Audit log entries are now comparable and hashable. Audit log entries are now comparable and hashable.

View File

@ -115,6 +115,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
Returns the channel's name. Returns the channel's name.
.. describe:: int(x)
Returns the channel's ID.
Attributes Attributes
----------- -----------
name: :class:`str` name: :class:`str`
@ -224,6 +228,16 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""List[:class:`Member`]: Returns all members that can see this channel.""" """List[:class:`Member`]: Returns all members that can see this channel."""
return [m for m in self.guild.members if self.permissions_for(m).read_messages] return [m for m in self.guild.members if self.permissions_for(m).read_messages]
@property
def bots(self) -> List[Member]:
"""List[:class:`Member`]: Returns all bots that can see this channel."""
return [m for m in self.guild.members if m.bot and self.permissions_for(m).read_messages]
@property
def humans(self) -> List[Member]:
"""List[:class:`Member`]: Returns all human members that can see this channel."""
return [m for m in self.guild.members if not m.bot and self.permissions_for(m).read_messages]
@property @property
def threads(self) -> List[Thread]: def threads(self) -> List[Thread]:
"""List[:class:`Thread`]: Returns all the threads that you can see. """List[:class:`Thread`]: Returns all the threads that you can see.
@ -1334,6 +1348,10 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
Returns the category's name. Returns the category's name.
.. describe:: int(x)
Returns the category's ID.
Attributes Attributes
----------- -----------
name: :class:`str` name: :class:`str`
@ -1556,6 +1574,10 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
Returns the channel's name. Returns the channel's name.
.. describe:: int(x)
Returns the channel's ID.
Attributes Attributes
----------- -----------
name: :class:`str` name: :class:`str`
@ -1728,6 +1750,10 @@ class DMChannel(discord.abc.Messageable, Hashable):
Returns a string representation of the channel Returns a string representation of the channel
.. describe:: int(x)
Returns the channel's ID.
Attributes Attributes
---------- ----------
recipient: Optional[:class:`User`] recipient: Optional[:class:`User`]
@ -1854,6 +1880,10 @@ class GroupChannel(discord.abc.Messageable, Hashable):
Returns a string representation of the channel Returns a string representation of the channel
.. describe:: int(x)
Returns the channel's ID.
Attributes Attributes
---------- ----------
recipients: List[:class:`User`] recipients: List[:class:`User`]
@ -2000,6 +2030,10 @@ class PartialMessageable(discord.abc.Messageable, Hashable):
Returns the partial messageable's hash. Returns the partial messageable's hash.
.. describe:: int(x)
Returns the messageable's ID.
Attributes Attributes
----------- -----------
id: :class:`int` id: :class:`int`

View File

@ -842,6 +842,38 @@ class Client:
""" """
return self._connection.get_user(id) return self._connection.get_user(id)
async def try_user(self, id: int, /) -> Optional[User]:
"""|coro|
Returns a user with the given ID. If not from cache, the user will be requested from the API.
You do not have to share any guilds with the user to get this information from the API,
however many operations do require that you do.
.. note::
This method is an API call. If you have :attr:`discord.Intents.members` and member cache enabled, consider :meth:`get_user` instead.
.. versionadded:: 2.0
Parameters
-----------
id: :class:`int`
The ID to search for.
Returns
--------
Optional[:class:`~discord.User`]
The user or ``None`` if not found.
"""
maybe_user = self.get_user(id)
if maybe_user is not None:
return maybe_user
try:
return await self.fetch_user(id)
except NotFound:
return None
def get_emoji(self, id: int, /) -> Optional[Emoji]: def get_emoji(self, id: int, /) -> Optional[Emoji]:
"""Returns an emoji with the given ID. """Returns an emoji with the given ID.

View File

@ -251,6 +251,13 @@ class Colour:
def red(cls: Type[CT]) -> CT: def red(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``.""" """A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``."""
return cls(0xe74c3c) return cls(0xe74c3c)
@classmethod
def nitro_booster(cls):
"""A factory method that returns a :class:`Colour` with a value of ``0xf47fff``.
.. versionadded:: 2.0"""
return cls(0xf47fff)
@classmethod @classmethod
def dark_red(cls: Type[CT]) -> CT: def dark_red(cls: Type[CT]) -> CT:
@ -324,6 +331,15 @@ class Colour:
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
return cls(0xFEE75C) return cls(0xFEE75C)
@classmethod
def dark_blurple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x4E5D94``.
This is the original Dark Blurple branding.
.. versionadded:: 2.0
"""
return cls(0x4E5D94)
Color = Colour Color = Colour

View File

@ -72,30 +72,36 @@ if TYPE_CHECKING:
T = TypeVar('T') T = TypeVar('T')
MaybeEmpty = Union[T, _EmptyEmbed] MaybeEmpty = Union[T, _EmptyEmbed]
class _EmbedFooterProxy(Protocol): class _EmbedFooterProxy(Protocol):
text: MaybeEmpty[str] text: MaybeEmpty[str]
icon_url: MaybeEmpty[str] icon_url: MaybeEmpty[str]
class _EmbedFieldProxy(Protocol): class _EmbedFieldProxy(Protocol):
name: MaybeEmpty[str] name: MaybeEmpty[str]
value: MaybeEmpty[str] value: MaybeEmpty[str]
inline: bool inline: bool
class _EmbedMediaProxy(Protocol): class _EmbedMediaProxy(Protocol):
url: MaybeEmpty[str] url: MaybeEmpty[str]
proxy_url: MaybeEmpty[str] proxy_url: MaybeEmpty[str]
height: MaybeEmpty[int] height: MaybeEmpty[int]
width: MaybeEmpty[int] width: MaybeEmpty[int]
class _EmbedVideoProxy(Protocol): class _EmbedVideoProxy(Protocol):
url: MaybeEmpty[str] url: MaybeEmpty[str]
height: MaybeEmpty[int] height: MaybeEmpty[int]
width: MaybeEmpty[int] width: MaybeEmpty[int]
class _EmbedProviderProxy(Protocol): class _EmbedProviderProxy(Protocol):
name: MaybeEmpty[str] name: MaybeEmpty[str]
url: MaybeEmpty[str] url: MaybeEmpty[str]
class _EmbedAuthorProxy(Protocol): class _EmbedAuthorProxy(Protocol):
name: MaybeEmpty[str] name: MaybeEmpty[str]
url: MaybeEmpty[str] url: MaybeEmpty[str]
@ -175,15 +181,15 @@ class Embed:
Empty: Final = EmptyEmbed Empty: Final = EmptyEmbed
def __init__( def __init__(
self, self,
*, *,
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
title: MaybeEmpty[Any] = EmptyEmbed, title: MaybeEmpty[Any] = EmptyEmbed,
type: EmbedType = 'rich', type: EmbedType = 'rich',
url: MaybeEmpty[Any] = EmptyEmbed, url: MaybeEmpty[Any] = EmptyEmbed,
description: MaybeEmpty[Any] = EmptyEmbed, description: MaybeEmpty[Any] = EmptyEmbed,
timestamp: datetime.datetime = None, timestamp: datetime.datetime = None,
): ):
self.colour = colour if colour is not EmptyEmbed else color self.colour = colour if colour is not EmptyEmbed else color
@ -397,6 +403,22 @@ class Embed:
""" """
return EmbedProxy(getattr(self, '_image', {})) # type: ignore return EmbedProxy(getattr(self, '_image', {})) # type: ignore
@image.setter
def image(self: E, url: Any):
if url is EmptyEmbed:
del self._image
else:
self._image = {
'url': str(url),
}
@image.deleter
def image(self: E):
try:
del self._image
except AttributeError:
pass
def set_image(self: E, *, url: MaybeEmpty[Any]) -> E: def set_image(self: E, *, url: MaybeEmpty[Any]) -> E:
"""Sets the image for the embed content. """Sets the image for the embed content.
@ -412,15 +434,7 @@ class Embed:
The source URL for the image. Only HTTP(S) is supported. The source URL for the image. Only HTTP(S) is supported.
""" """
if url is EmptyEmbed: self.image = url
try:
del self._image
except AttributeError:
pass
else:
self._image = {
'url': str(url),
}
return self return self
@ -439,7 +453,23 @@ class Embed:
""" """
return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]) -> E: @thumbnail.setter
def thumbnail(self: E, url: Any):
if url is EmptyEmbed:
del self._thumbnail
else:
self._thumbnail = {
'url': str(url),
}
@thumbnail.deleter
def thumbnail(self):
try:
del self._thumbnail
except AttributeError:
pass
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]):
"""Sets the thumbnail for the embed content. """Sets the thumbnail for the embed content.
This function returns the class instance to allow for fluent-style This function returns the class instance to allow for fluent-style
@ -454,15 +484,7 @@ class Embed:
The source URL for the thumbnail. Only HTTP(S) is supported. The source URL for the thumbnail. Only HTTP(S) is supported.
""" """
if url is EmptyEmbed: self.thumbnail = url
try:
del self._thumbnail
except AttributeError:
pass
else:
self._thumbnail = {
'url': str(url),
}
return self return self

View File

@ -72,6 +72,10 @@ class Emoji(_EmojiTag, AssetMixin):
Returns the emoji rendered for discord. Returns the emoji rendered for discord.
.. describe:: int(x)
Returns the emoji ID.
Attributes Attributes
----------- -----------
name: :class:`str` name: :class:`str`
@ -137,6 +141,9 @@ class Emoji(_EmojiTag, AssetMixin):
return f'<a:{self.name}:{self.id}>' return f'<a:{self.name}:{self.id}>'
return f'<:{self.name}:{self.id}>' return f'<:{self.name}:{self.id}>'
def __int__(self) -> int:
return self.id
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>' return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>'

View File

@ -53,6 +53,7 @@ from .context import Context
from . import errors from . import errors
from .help import HelpCommand, DefaultHelpCommand from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog from .cog import Cog
from discord.utils import raise_expected_coro
if TYPE_CHECKING: if TYPE_CHECKING:
import importlib.machinery import importlib.machinery
@ -453,14 +454,59 @@ class BotBase(GroupMixin):
elif self.owner_ids: elif self.owner_ids:
return user.id in self.owner_ids return user.id in self.owner_ids
else: else:
# Populate the used fields, then retry the check. This is only done at-most once in the bot lifetime.
await self.populate_owners()
return await self.is_owner(user)
app = await self.application_info() # type: ignore async def try_owners(self) -> List[discord.User]:
if app.team: """|coro|
self.owner_ids = ids = {m.id for m in app.team.members}
return user.id in ids Returns a list of :class:`~discord.User` representing the owners of the bot.
It uses the :attr:`owner_id` and :attr:`owner_ids`, if set.
.. versionadded:: 2.0
The function also checks if the application is team-owned if
:attr:`owner_ids` is not set.
Returns
--------
List[:class:`~discord.User`]
List of owners of the bot.
"""
if self.owner_id:
owner = await self.try_user(self.owner_id)
if owner:
return [owner]
else: else:
self.owner_id = owner_id = app.owner.id return []
return user.id == owner_id
elif self.owner_ids:
owners = []
for owner_id in self.owner_ids:
owner = await self.try_user(owner_id)
if owner:
owners.append(owner)
return owners
else:
# We didn't have owners cached yet, cache them and retry.
await self.populate_owners()
return await self.try_owners()
async def populate_owners(self):
"""|coro|
Populate the :attr:`owner_id` and :attr:`owner_ids` through the use of :meth:`~.Bot.application_info`.
.. versionadded:: 2.0
"""
app = await self.application_info() # type: ignore
if app.team:
self.owner_ids = {m.id for m in app.team.members}
else:
self.owner_id = app.owner.id
def before_invoke(self, coro: CFT) -> CFT: def before_invoke(self, coro: CFT) -> CFT:
"""A decorator that registers a coroutine as a pre-invoke hook. """A decorator that registers a coroutine as a pre-invoke hook.
@ -488,11 +534,9 @@ class BotBase(GroupMixin):
TypeError TypeError
The coroutine passed is not actually a coroutine. The coroutine passed is not actually a coroutine.
""" """
if not asyncio.iscoroutinefunction(coro): return raise_expected_coro(
raise TypeError('The pre-invoke hook must be a coroutine.') coro, 'The pre-invoke hook must be a coroutine.'
)
self._before_invoke = coro
return coro
def after_invoke(self, coro: CFT) -> CFT: def after_invoke(self, coro: CFT) -> CFT:
r"""A decorator that registers a coroutine as a post-invoke hook. r"""A decorator that registers a coroutine as a post-invoke hook.
@ -521,11 +565,10 @@ class BotBase(GroupMixin):
TypeError TypeError
The coroutine passed is not actually a coroutine. The coroutine passed is not actually a coroutine.
""" """
if not asyncio.iscoroutinefunction(coro): return raise_expected_coro(
raise TypeError('The post-invoke hook must be a coroutine.') coro, 'The post-invoke hook must be a coroutine.'
)
self._after_invoke = coro
return coro
# listener registration # listener registration
@ -1266,7 +1309,7 @@ class Bot(BotBase, discord.Client):
when passing an empty string, it should always be last as no prefix when passing an empty string, it should always be last as no prefix
after it will be matched. after it will be matched.
case_insensitive: :class:`bool` case_insensitive: :class:`bool`
Whether the commands should be case insensitive. Defaults to ``False``. This Whether the commands should be case insensitive. Defaults to ``True``. This
attribute does not carry over to groups. You must set it to every group if attribute does not carry over to groups. You must set it to every group if
you require group commands to be case insensitive as well. you require group commands to be case insensitive as well.
description: :class:`str` description: :class:`str`

View File

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
import inspect import inspect
@ -32,6 +33,7 @@ import discord.abc
import discord.utils import discord.utils
from discord.message import Message from discord.message import Message
from discord import Permissions
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
@ -62,10 +64,7 @@ T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]") BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog") CogT = TypeVar('CogT', bound="Cog")
if TYPE_CHECKING: P = ParamSpec('P') if TYPE_CHECKING else TypeVar('P')
P = ParamSpec('P')
else:
P = TypeVar('P')
class Context(discord.abc.Messageable, Generic[BotT]): class Context(discord.abc.Messageable, Generic[BotT]):
@ -318,6 +317,13 @@ class Context(discord.abc.Messageable, Generic[BotT]):
g = self.guild g = self.guild
return g.voice_client if g else None return g.voice_client if g else None
def author_permissions(self) -> Permissions:
"""Returns the author permissions in the given channel.
.. versionadded:: 2.0
"""
return self.channel.permissions_for(self.author)
async def send_help(self, *args: Any) -> Any: async def send_help(self, *args: Any) -> Any:
"""send_help(entity=<bot>) """send_help(entity=<bot>)

View File

@ -356,15 +356,15 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod @staticmethod
def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]: def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]:
if guild_id is not None: if guild_id is None:
guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None:
return guild._resolve_channel(channel_id) # type: ignore
else:
return None
else:
return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel
guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None:
return guild._resolve_channel(channel_id) # type: ignore
else:
return None
async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage: async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage:
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument) guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
channel = self._resolve_channel(ctx, guild_id, channel_id) channel = self._resolve_channel(ctx, guild_id, channel_id)
@ -757,8 +757,8 @@ class GuildConverter(IDConverter[discord.Guild]):
if result is None: if result is None:
result = discord.utils.get(ctx.bot.guilds, name=argument) result = discord.utils.get(ctx.bot.guilds, name=argument)
if result is None: if result is None:
raise GuildNotFound(argument) raise GuildNotFound(argument)
return result return result
@ -942,8 +942,7 @@ class clean_content(Converter[str]):
def repl(match: re.Match) -> str: def repl(match: re.Match) -> str:
type = match[1] type = match[1]
id = int(match[2]) id = int(match[2])
transformed = transforms[type](id) return transforms[type](id)
return transformed
result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument) result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument)
if self.escape_markdown: if self.escape_markdown:

View File

@ -1264,10 +1264,10 @@ class GroupMixin(Generic[CogT]):
A mapping of command name to :class:`.Command` A mapping of command name to :class:`.Command`
objects. objects.
case_insensitive: :class:`bool` case_insensitive: :class:`bool`
Whether the commands should be case insensitive. Defaults to ``False``. Whether the commands should be case insensitive. Defaults to ``True``.
""" """
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
case_insensitive = kwargs.get('case_insensitive', False) case_insensitive = kwargs.get('case_insensitive', True)
self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {} self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {}
self.case_insensitive: bool = case_insensitive self.case_insensitive: bool = case_insensitive
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)

View File

@ -82,9 +82,7 @@ class StringView:
def skip_string(self, string): def skip_string(self, string):
strlen = len(string) strlen = len(string)
if self.buffer[self.index:self.index + strlen] == string: if self.buffer[self.index:self.index + strlen] == string:
self.previous = self.index return self._return_index(strlen, True)
self.index += strlen
return True
return False return False
def read_rest(self): def read_rest(self):
@ -95,9 +93,7 @@ class StringView:
def read(self, n): def read(self, n):
result = self.buffer[self.index:self.index + n] result = self.buffer[self.index:self.index + n]
self.previous = self.index return self._return_index(n, result)
self.index += n
return result
def get(self): def get(self):
try: try:
@ -105,9 +101,12 @@ class StringView:
except IndexError: except IndexError:
result = None result = None
return self._return_index(1, result)
def _return_index(self, arg0, arg1):
self.previous = self.index self.previous = self.index
self.index += 1 self.index += arg0
return result return arg1
def get_word(self): def get_word(self):
pos = 0 pos = 0

View File

@ -46,7 +46,9 @@ import traceback
from collections.abc import Sequence from collections.abc import Sequence
from discord.backoff import ExponentialBackoff from discord.backoff import ExponentialBackoff
from discord.utils import MISSING from discord.utils import MISSING, raise_expected_coro
__all__ = ( __all__ = (
'loop', 'loop',
@ -488,11 +490,7 @@ class Loop(Generic[LF]):
The function was not a coroutine. The function was not a coroutine.
""" """
if not inspect.iscoroutinefunction(coro): return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._before_loop = coro
return coro
def after_loop(self, coro: FT) -> FT: def after_loop(self, coro: FT) -> FT:
"""A decorator that register a coroutine to be called after the loop finished running. """A decorator that register a coroutine to be called after the loop finished running.
@ -516,11 +514,7 @@ class Loop(Generic[LF]):
The function was not a coroutine. The function was not a coroutine.
""" """
if not inspect.iscoroutinefunction(coro): return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._after_loop = coro
return coro
def error(self, coro: ET) -> ET: def error(self, coro: ET) -> ET:
"""A decorator that registers a coroutine to be called if the task encounters an unhandled exception. """A decorator that registers a coroutine to be called if the task encounters an unhandled exception.
@ -542,11 +536,7 @@ class Loop(Generic[LF]):
TypeError TypeError
The function was not a coroutine. The function was not a coroutine.
""" """
if not inspect.iscoroutinefunction(coro): return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._error = coro # type: ignore
return coro
def _get_next_sleep_time(self) -> datetime.datetime: def _get_next_sleep_time(self) -> datetime.datetime:
if self._sleep is not MISSING: if self._sleep is not MISSING:
@ -614,8 +604,7 @@ class Loop(Generic[LF]):
) )
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
ret = sorted(set(ret)) # de-dupe and sort times return sorted(set(ret))
return ret
def change_interval( def change_interval(
self, self,

View File

@ -22,8 +22,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import TYPE_CHECKING, TypedDict, Any, Optional, List, TypeVar, Type, Dict, Callable, Coroutine, NamedTuple, Deque
import asyncio import asyncio
from collections import namedtuple, deque from collections import deque
import concurrent.futures import concurrent.futures
import logging import logging
import struct import struct
@ -38,9 +42,25 @@ import aiohttp
from . import utils from . import utils
from .activity import BaseActivity from .activity import BaseActivity
from .enums import SpeakingState from .enums import SpeakingState
from .errors import ConnectionClosed, InvalidArgument from .errors import ConnectionClosed, InvalidArgument
if TYPE_CHECKING:
from .client import Client
from .state import ConnectionState
from .voice_client import VoiceClient
T = TypeVar('T')
DWS = TypeVar('DWS', bound='DiscordWebSocket')
DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket')
Coro = Callable[..., Coroutine[Any, Any, Any]]
Predicate = Callable[[Dict[str, Any]], bool]
DataCallable = Callable[[Dict[str, Any]], T]
Result = Optional[DataCallable[Any]]
_log: logging.Logger = logging.getLogger(__name__)
_log = logging.getLogger(__name__)
__all__ = ( __all__ = (
'DiscordWebSocket', 'DiscordWebSocket',
@ -50,36 +70,49 @@ __all__ = (
'ReconnectWebSocket', 'ReconnectWebSocket',
) )
class Heartbeat(TypedDict):
op: int
d: int
class ReconnectWebSocket(Exception): class ReconnectWebSocket(Exception):
"""Signals to safely reconnect the websocket.""" """Signals to safely reconnect the websocket."""
def __init__(self, shard_id, *, resume=True): def __init__(self, shard_id: Optional[int], *, resume: bool = True) -> None:
self.shard_id = shard_id self.shard_id: Optional[int] = shard_id
self.resume = resume self.resume: bool = resume
self.op = 'RESUME' if resume else 'IDENTIFY' self.op = 'RESUME' if resume else 'IDENTIFY'
class WebSocketClosure(Exception): class WebSocketClosure(Exception):
"""An exception to make up for the fact that aiohttp doesn't signal closure.""" """An exception to make up for the fact that aiohttp doesn't signal closure."""
pass pass
EventListener = namedtuple('EventListener', 'predicate event result future')
class EventListener(NamedTuple):
predicate: Predicate
event: str
result: Result
future: asyncio.Future
class GatewayRatelimiter: class GatewayRatelimiter:
def __init__(self, count=110, per=60.0): def __init__(self, count: int = 110, per: float = 60.0) -> None:
# The default is 110 to give room for at least 10 heartbeats per minute # The default is 110 to give room for at least 10 heartbeats per minute
self.max = count self.max: int = count
self.remaining = count self.remaining: int = count
self.window = 0.0 self.window: float = 0.0
self.per = per self.per: float = per
self.lock = asyncio.Lock() self.lock: asyncio.Lock = asyncio.Lock()
self.shard_id = None self.shard_id: Optional[int] = None
def is_ratelimited(self): def is_ratelimited(self) -> bool:
current = time.time() current = time.time()
if current > self.window + self.per: if current > self.window + self.per:
return False return False
return self.remaining == 0 return self.remaining == 0
def get_delay(self): def get_delay(self) -> float:
current = time.time() current = time.time()
if current > self.window + self.per: if current > self.window + self.per:
@ -97,7 +130,7 @@ class GatewayRatelimiter:
return 0.0 return 0.0
async def block(self): async def block(self) -> None:
async with self.lock: async with self.lock:
delta = self.get_delay() delta = self.get_delay()
if delta: if delta:
@ -106,27 +139,27 @@ class GatewayRatelimiter:
class KeepAliveHandler(threading.Thread): class KeepAliveHandler(threading.Thread):
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
ws = kwargs.pop('ws', None) ws = kwargs.pop('ws')
interval = kwargs.pop('interval', None) interval = kwargs.pop('interval', None)
shard_id = kwargs.pop('shard_id', None) shard_id = kwargs.pop('shard_id', None)
threading.Thread.__init__(self, *args, **kwargs) threading.Thread.__init__(self, *args, **kwargs)
self.ws = ws self.ws: DiscordWebSocket = ws
self._main_thread_id = ws.thread_id self._main_thread_id: int = ws.thread_id
self.interval = interval self.interval: Optional[float] = interval
self.daemon = True self.daemon: bool = True
self.shard_id = shard_id self.shard_id: Optional[int] = shard_id
self.msg = 'Keeping shard ID %s websocket alive with sequence %s.' self.msg: str = 'Keeping shard ID %s websocket alive with sequence %s.'
self.block_msg = 'Shard ID %s heartbeat blocked for more than %s seconds.' self.block_msg: str = 'Shard ID %s heartbeat blocked for more than %s seconds.'
self.behind_msg = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.' self.behind_msg: str = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.'
self._stop_ev = threading.Event() self._stop_ev: threading.Event = threading.Event()
self._last_ack = time.perf_counter() self._last_ack: float = time.perf_counter()
self._last_send = time.perf_counter() self._last_send: float = time.perf_counter()
self._last_recv = time.perf_counter() self._last_recv: float = time.perf_counter()
self.latency = float('inf') self.latency: float = float('inf')
self.heartbeat_timeout = ws._max_heartbeat_timeout self.heartbeat_timeout: float = ws._max_heartbeat_timeout
def run(self): def run(self) -> None:
while not self._stop_ev.wait(self.interval): while not self._stop_ev.wait(self.interval):
if self._last_recv + self.heartbeat_timeout < time.perf_counter(): if self._last_recv + self.heartbeat_timeout < time.perf_counter():
_log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id) _log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
@ -168,19 +201,20 @@ class KeepAliveHandler(threading.Thread):
else: else:
self._last_send = time.perf_counter() self._last_send = time.perf_counter()
def get_payload(self): def get_payload(self) -> Heartbeat:
return { return {
'op': self.ws.HEARTBEAT, 'op': self.ws.HEARTBEAT,
'd': self.ws.sequence # the websocket's sequence won't be None here
'd': self.ws.sequence # type: ignore
} }
def stop(self): def stop(self) -> None:
self._stop_ev.set() self._stop_ev.set()
def tick(self): def tick(self) -> None:
self._last_recv = time.perf_counter() self._last_recv = time.perf_counter()
def ack(self): def ack(self) -> None:
ack_time = time.perf_counter() ack_time = time.perf_counter()
self._last_ack = ack_time self._last_ack = ack_time
self.latency = ack_time - self._last_send self.latency = ack_time - self._last_send
@ -188,30 +222,32 @@ class KeepAliveHandler(threading.Thread):
_log.warning(self.behind_msg, self.shard_id, self.latency) _log.warning(self.behind_msg, self.shard_id, self.latency)
class VoiceKeepAliveHandler(KeepAliveHandler): class VoiceKeepAliveHandler(KeepAliveHandler):
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.recent_ack_latencies = deque(maxlen=20) self.recent_ack_latencies: Deque[float] = deque(maxlen=20)
self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.' self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.'
self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds' self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds'
self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind' self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind'
def get_payload(self): def get_payload(self) -> Heartbeat:
return { return {
'op': self.ws.HEARTBEAT, 'op': self.ws.HEARTBEAT,
'd': int(time.time() * 1000) 'd': int(time.time() * 1000)
} }
def ack(self): def ack(self) -> None:
ack_time = time.perf_counter() ack_time = time.perf_counter()
self._last_ack = ack_time self._last_ack = ack_time
self._last_recv = ack_time self._last_recv = ack_time
self.latency = ack_time - self._last_send self.latency = ack_time - self._last_send
self.recent_ack_latencies.append(self.latency) self.recent_ack_latencies.append(self.latency)
class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse): class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse):
async def close(self, *, code: int = 4000, message: bytes = b'') -> bool: async def close(self, *, code: int = 4000, message: bytes = b'') -> bool:
return await super().close(code=code, message=message) return await super().close(code=code, message=message)
class DiscordWebSocket: class DiscordWebSocket:
"""Implements a WebSocket for Discord's gateway v6. """Implements a WebSocket for Discord's gateway v6.
@ -266,41 +302,53 @@ class DiscordWebSocket:
HEARTBEAT_ACK = 11 HEARTBEAT_ACK = 11
GUILD_SYNC = 12 GUILD_SYNC = 12
def __init__(self, socket, *, loop): def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None:
self.socket = socket self.socket: aiohttp.ClientWebSocketResponse = socket
self.loop = loop self.loop: asyncio.AbstractEventLoop = loop
# an empty dispatcher to prevent crashes # an empty dispatcher to prevent crashes
self._dispatch = lambda *args: None self._dispatch = lambda *args: None
# generic event listeners # generic event listeners
self._dispatch_listeners = [] self._dispatch_listeners: List[EventListener] = []
# the keep alive # the keep alive
self._keep_alive = None self._keep_alive: Optional[KeepAliveHandler] = None
self.thread_id = threading.get_ident() self.thread_id: int = threading.get_ident()
# ws related stuff # ws related stuff
self.session_id = None self.session_id: Optional[str] = None
self.sequence = None self.sequence: Optional[int] = None
self._zlib = zlib.decompressobj() self._zlib = zlib.decompressobj()
self._buffer = bytearray() self._buffer: bytearray = bytearray()
self._close_code = None self._close_code: Optional[int] = None
self._rate_limiter = GatewayRatelimiter() self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
# attributes that get set in from_client
self.token: str = utils.MISSING
self._connection: ConnectionState = utils.MISSING
self._discord_parsers: Dict[str, DataCallable[None]] = utils.MISSING
self.gateway: str = utils.MISSING
self.call_hooks: Coro = utils.MISSING
self._initial_identify: bool = utils.MISSING
self.shard_id: Optional[int] = utils.MISSING
self.shard_count: Optional[int] = utils.MISSING
self.session_id: Optional[str] = utils.MISSING
self._max_heartbeat_timeout: float = utils.MISSING
@property @property
def open(self): def open(self) -> bool:
return not self.socket.closed return not self.socket.closed
def is_ratelimited(self): def is_ratelimited(self) -> bool:
return self._rate_limiter.is_ratelimited() return self._rate_limiter.is_ratelimited()
def debug_log_receive(self, data, /): def debug_log_receive(self, data, /) -> None:
self._dispatch('socket_raw_receive', data) self._dispatch('socket_raw_receive', data)
def log_receive(self, _, /): def log_receive(self, _, /) -> None:
pass pass
@classmethod @classmethod
async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False): async def from_client(cls: Type[DWS], client: Client, *, initial: bool = False, gateway: Optional[str] = None, shard_id: Optional[int] = None, session: Optional[str] = None, sequence: Optional[int] = None, resume: bool = False) -> DWS:
"""Creates a main websocket for Discord from a :class:`Client`. """Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only. This is for internal use only.
@ -310,7 +358,9 @@ class DiscordWebSocket:
ws = cls(socket, loop=client.loop) ws = cls(socket, loop=client.loop)
# dynamically add attributes needed # dynamically add attributes needed
ws.token = client.http.token
# the token won't be None here
ws.token = client.http.token # type: ignore
ws._connection = client._connection ws._connection = client._connection
ws._discord_parsers = client._connection.parsers ws._discord_parsers = client._connection.parsers
ws._dispatch = client.dispatch ws._dispatch = client.dispatch
@ -342,7 +392,7 @@ class DiscordWebSocket:
await ws.resume() await ws.resume()
return ws return ws
def wait_for(self, event, predicate, result=None): def wait_for(self, event: str, predicate: Predicate, result: Result = None) -> asyncio.Future:
"""Waits for a DISPATCH'd event that meets the predicate. """Waits for a DISPATCH'd event that meets the predicate.
Parameters Parameters
@ -367,7 +417,7 @@ class DiscordWebSocket:
self._dispatch_listeners.append(entry) self._dispatch_listeners.append(entry)
return future return future
async def identify(self): async def identify(self) -> None:
"""Sends the IDENTIFY packet.""" """Sends the IDENTIFY packet."""
payload = { payload = {
'op': self.IDENTIFY, 'op': self.IDENTIFY,
@ -405,7 +455,7 @@ class DiscordWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
_log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id) _log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
async def resume(self): async def resume(self) -> None:
"""Sends the RESUME packet.""" """Sends the RESUME packet."""
payload = { payload = {
'op': self.RESUME, 'op': self.RESUME,
@ -419,7 +469,8 @@ class DiscordWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
_log.info('Shard ID %s has sent the RESUME payload.', self.shard_id) _log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
async def received_message(self, msg, /):
async def received_message(self, msg, /) -> None:
if type(msg) is bytes: if type(msg) is bytes:
self._buffer.extend(msg) self._buffer.extend(msg)
@ -537,16 +588,16 @@ class DiscordWebSocket:
del self._dispatch_listeners[index] del self._dispatch_listeners[index]
@property @property
def latency(self): def latency(self) -> float:
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.""" """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds."""
heartbeat = self._keep_alive heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency return float('inf') if heartbeat is None else heartbeat.latency
def _can_handle_close(self): def _can_handle_close(self) -> bool:
code = self._close_code or self.socket.close_code code = self._close_code or self.socket.close_code
return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014) return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014)
async def poll_event(self): async def poll_event(self) -> None:
"""Polls for a DISPATCH event and handles the general gateway loop. """Polls for a DISPATCH event and handles the general gateway loop.
Raises Raises
@ -584,23 +635,23 @@ class DiscordWebSocket:
_log.info('Websocket closed with %s, cannot reconnect.', code) _log.info('Websocket closed with %s, cannot reconnect.', code)
raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None
async def debug_send(self, data, /): async def debug_send(self, data, /) -> None:
await self._rate_limiter.block() await self._rate_limiter.block()
self._dispatch('socket_raw_send', data) self._dispatch('socket_raw_send', data)
await self.socket.send_str(data) await self.socket.send_str(data)
async def send(self, data, /): async def send(self, data, /) -> None:
await self._rate_limiter.block() await self._rate_limiter.block()
await self.socket.send_str(data) await self.socket.send_str(data)
async def send_as_json(self, data): async def send_as_json(self, data) -> None:
try: try:
await self.send(utils._to_json(data)) await self.send(utils._to_json(data))
except RuntimeError as exc: except RuntimeError as exc:
if not self._can_handle_close(): if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def send_heartbeat(self, data): async def send_heartbeat(self, data: Heartbeat) -> None:
# This bypasses the rate limit handling code since it has a higher priority # This bypasses the rate limit handling code since it has a higher priority
try: try:
await self.socket.send_str(utils._to_json(data)) await self.socket.send_str(utils._to_json(data))
@ -608,13 +659,13 @@ class DiscordWebSocket:
if not self._can_handle_close(): if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def change_presence(self, *, activity=None, status=None, since=0.0): async def change_presence(self, *, activity: Optional[BaseActivity] = None, status: Optional[str] = None, since: float = 0.0) -> None:
if activity is not None: if activity is not None:
if not isinstance(activity, BaseActivity): if not isinstance(activity, BaseActivity):
raise InvalidArgument('activity must derive from BaseActivity.') raise InvalidArgument('activity must derive from BaseActivity.')
activity = [activity.to_dict()] activities = [activity.to_dict()]
else: else:
activity = [] activities = []
if status == 'idle': if status == 'idle':
since = int(time.time() * 1000) since = int(time.time() * 1000)
@ -622,7 +673,7 @@ class DiscordWebSocket:
payload = { payload = {
'op': self.PRESENCE, 'op': self.PRESENCE,
'd': { 'd': {
'activities': activity, 'activities': activities,
'afk': False, 'afk': False,
'since': since, 'since': since,
'status': status 'status': status
@ -633,7 +684,7 @@ class DiscordWebSocket:
_log.debug('Sending "%s" to change status', sent) _log.debug('Sending "%s" to change status', sent)
await self.send(sent) await self.send(sent)
async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None): async def request_chunks(self, guild_id: int, query: Optional[str] = None, *, limit: int, user_ids: Optional[List[int]] = None, presences: bool = False, nonce: Optional[int] = None) -> None:
payload = { payload = {
'op': self.REQUEST_MEMBERS, 'op': self.REQUEST_MEMBERS,
'd': { 'd': {
@ -655,7 +706,7 @@ class DiscordWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): async def voice_state(self, guild_id: int, channel_id: int, self_mute: bool = False, self_deaf: bool = False) -> None:
payload = { payload = {
'op': self.VOICE_STATE, 'op': self.VOICE_STATE,
'd': { 'd': {
@ -669,7 +720,7 @@ class DiscordWebSocket:
_log.debug('Updating our voice state to %s.', payload) _log.debug('Updating our voice state to %s.', payload)
await self.send_as_json(payload) await self.send_as_json(payload)
async def close(self, code=4000): async def close(self, code: int = 4000) -> None:
if self._keep_alive: if self._keep_alive:
self._keep_alive.stop() self._keep_alive.stop()
self._keep_alive = None self._keep_alive = None
@ -721,25 +772,31 @@ class DiscordVoiceWebSocket:
CLIENT_CONNECT = 12 CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13 CLIENT_DISCONNECT = 13
def __init__(self, socket, loop, *, hook=None): def __init__(self, socket: aiohttp.ClientWebSocketResponse, loop: asyncio.AbstractEventLoop, *, hook: Optional[Coro] = None) -> None:
self.ws = socket self.ws: aiohttp.ClientWebSocketResponse = socket
self.loop = loop self.loop: asyncio.AbstractEventLoop = loop
self._keep_alive = None self._keep_alive: VoiceKeepAliveHandler = utils.MISSING
self._close_code = None self._close_code: Optional[int] = None
self.secret_key = None self.secret_key: Optional[List[int]] = None
self.gateway: str = utils.MISSING
self._connection: VoiceClient = utils.MISSING
self._max_heartbeat_timeout: float = utils.MISSING
self.thread_id: int = utils.MISSING
if hook: if hook:
self._hook = hook # we want to redeclare self._hook
self._hook = hook # type: ignore
async def _hook(self, *args): async def _hook(self, *args: Any) -> Any:
pass pass
async def send_as_json(self, data):
async def send_as_json(self, data) -> None:
_log.debug('Sending voice websocket frame: %s.', data) _log.debug('Sending voice websocket frame: %s.', data)
await self.ws.send_str(utils._to_json(data)) await self.ws.send_str(utils._to_json(data))
send_heartbeat = send_as_json send_heartbeat = send_as_json
async def resume(self): async def resume(self) -> None:
state = self._connection state = self._connection
payload = { payload = {
'op': self.RESUME, 'op': self.RESUME,
@ -765,7 +822,7 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
@classmethod @classmethod
async def from_client(cls, client, *, resume=False, hook=None): async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume: bool = False, hook: Optional[Coro] = None) -> DVWS:
"""Creates a voice websocket for the :class:`VoiceClient`.""" """Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4' gateway = 'wss://' + client.endpoint + '/?v=4'
http = client._state.http http = client._state.http
@ -783,7 +840,7 @@ class DiscordVoiceWebSocket:
return ws return ws
async def select_protocol(self, ip, port, mode): async def select_protocol(self, ip, port, mode) -> None:
payload = { payload = {
'op': self.SELECT_PROTOCOL, 'op': self.SELECT_PROTOCOL,
'd': { 'd': {
@ -798,7 +855,7 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
async def client_connect(self): async def client_connect(self) -> None:
payload = { payload = {
'op': self.CLIENT_CONNECT, 'op': self.CLIENT_CONNECT,
'd': { 'd': {
@ -808,7 +865,7 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
async def speak(self, state=SpeakingState.voice): async def speak(self, state=SpeakingState.voice) -> None:
payload = { payload = {
'op': self.SPEAKING, 'op': self.SPEAKING,
'd': { 'd': {
@ -819,7 +876,8 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
async def received_message(self, msg):
async def received_message(self, msg) -> None:
_log.debug('Voice websocket frame received: %s', msg) _log.debug('Voice websocket frame received: %s', msg)
op = msg['op'] op = msg['op']
data = msg.get('d') data = msg.get('d')
@ -840,7 +898,7 @@ class DiscordVoiceWebSocket:
await self._hook(self, msg) await self._hook(self, msg)
async def initial_connection(self, data): async def initial_connection(self, data) -> None:
state = self._connection state = self._connection
state.ssrc = data['ssrc'] state.ssrc = data['ssrc']
state.voice_port = data['port'] state.voice_port = data['port']
@ -871,13 +929,13 @@ class DiscordVoiceWebSocket:
_log.info('selected the voice protocol for use (%s)', mode) _log.info('selected the voice protocol for use (%s)', mode)
@property @property
def latency(self): def latency(self) -> float:
""":class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds.""" """:class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds."""
heartbeat = self._keep_alive heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency return float('inf') if heartbeat is None else heartbeat.latency
@property @property
def average_latency(self): def average_latency(self) -> float:
""":class:`list`: Average of last 20 HEARTBEAT latencies.""" """:class:`list`: Average of last 20 HEARTBEAT latencies."""
heartbeat = self._keep_alive heartbeat = self._keep_alive
if heartbeat is None or not heartbeat.recent_ack_latencies: if heartbeat is None or not heartbeat.recent_ack_latencies:
@ -885,13 +943,14 @@ class DiscordVoiceWebSocket:
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies) return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
async def load_secret_key(self, data):
async def load_secret_key(self, data) -> None:
_log.info('received secret key for voice connection') _log.info('received secret key for voice connection')
self.secret_key = self._connection.secret_key = data.get('secret_key') self.secret_key = self._connection.secret_key = data.get('secret_key')
await self.speak() await self.speak()
await self.speak(False) await self.speak(False)
async def poll_event(self): async def poll_event(self) -> None:
# This exception is handled up the chain # This exception is handled up the chain
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
if msg.type is aiohttp.WSMsgType.TEXT: if msg.type is aiohttp.WSMsgType.TEXT:
@ -903,7 +962,7 @@ class DiscordVoiceWebSocket:
_log.debug('Received %s', msg) _log.debug('Received %s', msg)
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
async def close(self, code=1000): async def close(self, code: int = 1000) -> None:
if self._keep_alive is not None: if self._keep_alive is not None:
self._keep_alive.stop() self._keep_alive.stop()

View File

@ -46,7 +46,7 @@ from . import utils, abc
from .role import Role from .role import Role
from .member import Member, VoiceState from .member import Member, VoiceState
from .emoji import Emoji from .emoji import Emoji
from .errors import InvalidData from .errors import InvalidData, NotFound
from .permissions import PermissionOverwrite from .permissions import PermissionOverwrite
from .colour import Colour from .colour import Colour
from .errors import InvalidArgument, ClientException from .errors import InvalidArgument, ClientException
@ -76,6 +76,7 @@ from .stage_instance import StageInstance
from .threads import Thread, ThreadMember from .threads import Thread, ThreadMember
from .sticker import GuildSticker from .sticker import GuildSticker
from .file import File from .file import File
from .welcome_screen import WelcomeScreen, WelcomeChannel
__all__ = ( __all__ = (
@ -140,6 +141,10 @@ class Guild(Hashable):
Returns the guild's name. Returns the guild's name.
.. describe:: int(x)
Returns the guild's ID.
Attributes Attributes
---------- ----------
name: :class:`str` name: :class:`str`
@ -738,12 +743,16 @@ class Guild(Hashable):
@property @property
def humans(self) -> List[Member]: def humans(self) -> List[Member]:
"""List[:class:`Member`]: A list of human members that belong to this guild.""" """List[:class:`Member`]: A list of human members that belong to this guild.
.. versionadded:: 2.0 """
return [member for member in self.members if not member.bot] return [member for member in self.members if not member.bot]
@property @property
def bots(self) -> List[Member]: def bots(self) -> List[Member]:
"""List[:class:`Member`]: A list of bots that belong to this guild.""" """List[:class:`Member`]: A list of bots that belong to this guild.
.. versionadded:: 2.0 """
return [member for member in self.members if member.bot] return [member for member in self.members if member.bot]
def get_member(self, user_id: int, /) -> Optional[Member]: def get_member(self, user_id: int, /) -> Optional[Member]:
@ -1715,6 +1724,8 @@ class Guild(Hashable):
You do not have access to the guild. You do not have access to the guild.
HTTPException HTTPException
Fetching the member failed. Fetching the member failed.
NotFound
A member with that ID does not exist.
Returns Returns
-------- --------
@ -1724,6 +1735,34 @@ class Guild(Hashable):
data = await self._state.http.get_member(self.id, member_id) data = await self._state.http.get_member(self.id, member_id)
return Member(data=data, state=self._state, guild=self) return Member(data=data, state=self._state, guild=self)
async def try_member(self, member_id: int, /) -> Optional[Member]:
"""|coro|
Returns a member with the given ID. This uses the cache first, and if not found, it'll request using :meth:`fetch_member`.
.. note::
This method might result in an API call.
Parameters
-----------
member_id: :class:`int`
The ID to search for.
Returns
--------
Optional[:class:`Member`]
The member or ``None`` if not found.
"""
member = self.get_member(member_id)
if member:
return member
else:
try:
return await self.fetch_member(member_id)
except NotFound:
return None
async def fetch_ban(self, user: Snowflake) -> BanEntry: async def fetch_ban(self, user: Snowflake) -> BanEntry:
"""|coro| """|coro|
@ -2566,6 +2605,81 @@ class Guild(Hashable):
return roles return roles
async def welcome_screen(self) -> WelcomeScreen:
"""|coro|
Returns the guild's welcome screen.
The guild must have ``COMMUNITY`` in :attr:`~Guild.features`.
You must have the :attr:`~Permissions.manage_guild` permission to use
this as well.
.. versionadded:: 2.0
Raises
-------
Forbidden
You do not have the proper permissions to get this.
HTTPException
Retrieving the welcome screen failed.
Returns
--------
:class:`WelcomeScreen`
The welcome screen.
"""
data = await self._state.http.get_welcome_screen(self.id)
return WelcomeScreen(data=data, guild=self)
@overload
async def edit_welcome_screen(
self,
*,
description: Optional[str] = ...,
welcome_channels: Optional[List[WelcomeChannel]] = ...,
enabled: Optional[bool] = ...,
) -> WelcomeScreen:
...
@overload
async def edit_welcome_screen(self) -> None:
...
async def edit_welcome_screen(self, **kwargs):
"""|coro|
A shorthand method of :attr:`WelcomeScreen.edit` without needing
to fetch the welcome screen beforehand.
The guild must have ``COMMUNITY`` in :attr:`~Guild.features`.
You must have the :attr:`~Permissions.manage_guild` permission to use
this as well.
.. versionadded:: 2.0
Returns
--------
:class:`WelcomeScreen`
The edited welcome screen.
"""
try:
welcome_channels = kwargs['welcome_channels']
except KeyError:
pass
else:
welcome_channels_serialised = []
for wc in welcome_channels:
if not isinstance(wc, WelcomeChannel):
raise InvalidArgument('welcome_channels parameter must be a list of WelcomeChannel')
welcome_channels_serialised.append(wc.to_dict())
kwargs['welcome_channels'] = welcome_channels_serialised
if kwargs:
data = await self._state.http.edit_welcome_screen(self.id, kwargs)
return WelcomeScreen(data=data, guild=self)
async def kick(self, user: Snowflake, *, reason: Optional[str] = None) -> None: async def kick(self, user: Snowflake, *, reason: Optional[str] = None) -> None:
"""|coro| """|coro|

View File

@ -84,6 +84,7 @@ if TYPE_CHECKING:
threads, threads,
voice, voice,
sticker, sticker,
welcome_screen,
) )
from .types.snowflake import Snowflake, SnowflakeList from .types.snowflake import Snowflake, SnowflakeList
@ -1116,6 +1117,20 @@ class HTTPClient:
payload['icon'] = icon payload['icon'] = icon
return self.request(Route('POST', '/guilds/templates/{code}', code=code), json=payload) return self.request(Route('POST', '/guilds/templates/{code}', code=code), json=payload)
def get_welcome_screen(self, guild_id: Snowflake) -> Response[welcome_screen.WelcomeScreen]:
return self.request(Route('GET', '/guilds/{guild_id}/welcome-screen', guild_id=guild_id))
def edit_welcome_screen(self, guild_id: Snowflake, payload: Any) -> Response[welcome_screen.WelcomeScreen]:
valid_keys = (
'description',
'welcome_channels',
'enabled',
)
payload = {
k: v for k, v in payload.items() if k in valid_keys
}
return self.request(Route('PATCH', '/guilds/{guild_id}/welcome-screen', guild_id=guild_id), json=payload)
def get_bans(self, guild_id: Snowflake) -> Response[List[guild.Ban]]: def get_bans(self, guild_id: Snowflake) -> Response[List[guild.Ban]]:
return self.request(Route('GET', '/guilds/{guild_id}/bans', guild_id=guild_id)) return self.request(Route('GET', '/guilds/{guild_id}/bans', guild_id=guild_id))

View File

@ -230,6 +230,7 @@ class Invite(Hashable):
Returns the invite URL. Returns the invite URL.
The following table illustrates what methods will obtain the attributes: The following table illustrates what methods will obtain the attributes:
+------------------------------------+------------------------------------------------------------+ +------------------------------------+------------------------------------------------------------+
@ -433,6 +434,9 @@ class Invite(Hashable):
def __str__(self) -> str: def __str__(self) -> str:
return self.url return self.url
def __int__(self) -> int:
return 0 # To keep the object compatible with the hashable abc.
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f'<Invite code={self.code!r} guild={self.guild!r} ' f'<Invite code={self.code!r} guild={self.guild!r} '

View File

@ -226,6 +226,10 @@ class Member(discord.abc.Messageable, _UserTag):
Returns the member's name with the discriminator. Returns the member's name with the discriminator.
.. describe:: int(x)
Returns the user's ID.
Attributes Attributes
---------- ----------
joined_at: Optional[:class:`datetime.datetime`] joined_at: Optional[:class:`datetime.datetime`]
@ -300,6 +304,9 @@ class Member(discord.abc.Messageable, _UserTag):
def __str__(self) -> str: def __str__(self) -> str:
return str(self._user) return str(self._user)
def __int__(self) -> int:
return self.id
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}' f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}'

View File

@ -125,6 +125,10 @@ class Attachment(Hashable):
Returns the hash of the attachment. Returns the hash of the attachment.
.. describe:: int(x)
Returns the attachment's ID.
.. versionchanged:: 1.7 .. versionchanged:: 1.7
Attachment can now be casted to :class:`str` and is hashable. Attachment can now be casted to :class:`str` and is hashable.
@ -503,6 +507,14 @@ class Message(Hashable):
Returns the message's hash. Returns the message's hash.
.. describe:: str(x)
Returns the message's content.
.. describe:: int(x)
Returns the message's ID.
Attributes Attributes
----------- -----------
tts: :class:`bool` tts: :class:`bool`
@ -712,6 +724,10 @@ class Message(Hashable):
f'<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>' f'<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>'
) )
def __str__(self) -> Optional[str]:
return self.content
def _try_patch(self, data, key, transform=None) -> None: def _try_patch(self, data, key, transform=None) -> None:
try: try:
value = data[key] value = data[key]
@ -1634,6 +1650,10 @@ class PartialMessage(Hashable):
Returns the partial message's hash. Returns the partial message's hash.
.. describe:: int(x)
Returns the partial message's ID.
Attributes Attributes
----------- -----------
channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`] channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`]

View File

@ -43,5 +43,8 @@ class EqualityComparable:
class Hashable(EqualityComparable): class Hashable(EqualityComparable):
__slots__ = () __slots__ = ()
def __int__(self) -> int:
return self.id
def __hash__(self) -> int: def __hash__(self) -> int:
return self.id >> 22 return self.id >> 22

View File

@ -69,6 +69,10 @@ class Object(Hashable):
Returns the object's hash. Returns the object's hash.
.. describe:: int(x)
Returns the object's ID.
Attributes Attributes
----------- -----------
id: :class:`int` id: :class:`int`

View File

@ -299,6 +299,13 @@ class Permissions(BaseFlags):
""" """
return 1 << 3 return 1 << 3
@make_permission_alias('administrator')
def admin(self) -> int:
""":class:`bool`: An alias for :attr:`administrator`.
.. versionadded:: 2.0
"""
return 1 << 3
@flag_value @flag_value
def manage_channels(self) -> int: def manage_channels(self) -> int:
""":class:`bool`: Returns ``True`` if a user can edit, delete, or create channels in the guild. """:class:`bool`: Returns ``True`` if a user can edit, delete, or create channels in the guild.

View File

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
import threading import threading
@ -63,10 +64,7 @@ __all__ = (
CREATE_NO_WINDOW: int CREATE_NO_WINDOW: int
if sys.platform != 'win32': CREATE_NO_WINDOW = 0 if sys.platform != 'win32' else 0x08000000
CREATE_NO_WINDOW = 0
else:
CREATE_NO_WINDOW = 0x08000000
class AudioSource: class AudioSource:
"""Represents an audio stream. """Represents an audio stream.
@ -526,7 +524,12 @@ class FFmpegOpusAudio(FFmpegAudio):
@staticmethod @staticmethod
def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]:
exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable exe = (
executable[:2] + 'probe'
if executable in {'ffmpeg', 'avconv'}
else executable
)
args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source] args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source]
output = subprocess.check_output(args, timeout=20) output = subprocess.check_output(args, timeout=20)
codec = bitrate = None codec = bitrate = None

View File

@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
import datetime
from typing import TYPE_CHECKING, Optional, Set, List from typing import TYPE_CHECKING, Optional, Set, List
if TYPE_CHECKING: if TYPE_CHECKING:
@ -34,7 +35,8 @@ if TYPE_CHECKING:
MessageUpdateEvent, MessageUpdateEvent,
ReactionClearEvent, ReactionClearEvent,
ReactionClearEmojiEvent, ReactionClearEmojiEvent,
IntegrationDeleteEvent IntegrationDeleteEvent,
TypingEvent
) )
from .message import Message from .message import Message
from .partial_emoji import PartialEmoji from .partial_emoji import PartialEmoji
@ -49,6 +51,7 @@ __all__ = (
'RawReactionClearEvent', 'RawReactionClearEvent',
'RawReactionClearEmojiEvent', 'RawReactionClearEmojiEvent',
'RawIntegrationDeleteEvent', 'RawIntegrationDeleteEvent',
'RawTypingEvent'
) )
@ -276,3 +279,36 @@ class RawIntegrationDeleteEvent(_RawReprMixin):
self.application_id: Optional[int] = int(data['application_id']) self.application_id: Optional[int] = int(data['application_id'])
except KeyError: except KeyError:
self.application_id: Optional[int] = None self.application_id: Optional[int] = None
class RawTypingEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_typing` event.
.. versionadded:: 2.0
Attributes
-----------
channel_id: :class:`int`
The channel ID where the typing originated from.
user_id: :class:`int`
The ID of the user that started typing.
when: :class:`datetime.datetime`
When the typing started as an aware datetime in UTC.
guild_id: Optional[:class:`int`]
The guild ID where the typing originated from, if applicable.
member: Optional[:class:`Member`]
The member who started typing. Only available if the member started typing in a guild.
"""
__slots__ = ("channel_id", "user_id", "when", "guild_id", "member")
def __init__(self, data: TypingEvent) -> None:
self.channel_id: int = int(data['channel_id'])
self.user_id: int = int(data['user_id'])
self.when: datetime.datetime = datetime.datetime.fromtimestamp(data.get('timestamp'), tz=datetime.timezone.utc)
self.member: Optional[Member] = None
try:
self.guild_id: Optional[int] = int(data['guild_id'])
except KeyError:
self.guild_id: Optional[int] = None

View File

@ -141,6 +141,14 @@ class Role(Hashable):
Returns the role's name. Returns the role's name.
.. describe:: str(x)
Returns the role's ID.
.. describe:: int(x)
Returns the role's ID.
Attributes Attributes
---------- ----------
id: :class:`int` id: :class:`int`
@ -195,6 +203,9 @@ class Role(Hashable):
def __str__(self) -> str: def __str__(self) -> str:
return self.name return self.name
def __int__(self) -> int:
return self.id
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Role id={self.id} name={self.name!r}>' return f'<Role id={self.id} name={self.name!r}>'

View File

@ -61,6 +61,10 @@ class StageInstance(Hashable):
Returns the stage instance's hash. Returns the stage instance's hash.
.. describe:: int(x)
Returns the stage instance's ID.
Attributes Attributes
----------- -----------
id: :class:`int` id: :class:`int`

View File

@ -1327,28 +1327,37 @@ class ConnectionState:
asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler')) asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler'))
def parse_typing_start(self, data) -> None: def parse_typing_start(self, data) -> None:
raw = RawTypingEvent(data)
member_data = data.get('member')
if member_data:
guild = self._get_guild(raw.guild_id)
if guild is not None:
raw.member = Member(data=member_data, guild=guild, state=self)
else:
raw.member = None
else:
raw.member = None
self.dispatch('raw_typing', raw)
channel, guild = self._get_guild_channel(data) channel, guild = self._get_guild_channel(data)
if channel is not None: if channel is not None:
member = None user = raw.member or self._get_typing_user(channel, raw.user_id)
user_id = utils._get_as_snowflake(data, 'user_id')
if isinstance(channel, DMChannel):
member = channel.recipient
elif isinstance(channel, (Thread, TextChannel)) and guild is not None: if user is not None:
# user_id won't be None self.dispatch('typing', channel, user, raw.when)
member = guild.get_member(user_id) # type: ignore
if member is None: def _get_typing_user(self, channel: Optional[MessageableChannel], user_id: int) -> Optional[Union[User, Member]]:
member_data = data.get('member') if isinstance(channel, DMChannel):
if member_data: return channel.recipient
member = Member(data=member_data, state=self, guild=guild)
elif isinstance(channel, GroupChannel): elif isinstance(channel, (Thread, TextChannel)) and channel.guild is not None:
member = utils.find(lambda x: x.id == user_id, channel.recipients) return channel.guild.get_member(user_id) # type: ignore
if member is not None: elif isinstance(channel, GroupChannel):
timestamp = datetime.datetime.fromtimestamp(data.get('timestamp'), tz=datetime.timezone.utc) return utils.find(lambda x: x.id == user_id, channel.recipients)
self.dispatch('typing', channel, member, timestamp)
return self.get_user(user_id)
def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]: def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]:
if isinstance(channel, TextChannel): if isinstance(channel, TextChannel):

View File

@ -67,6 +67,14 @@ class StickerPack(Hashable):
Returns the name of the sticker pack. Returns the name of the sticker pack.
.. describe:: hash(x)
Returns the hash of the sticker pack.
.. describe:: int(x)
Returns the ID of the sticker pack.
.. describe:: x == y .. describe:: x == y
Checks if the sticker pack is equal to another sticker pack. Checks if the sticker pack is equal to another sticker pack.

View File

@ -74,6 +74,10 @@ class Thread(Messageable, Hashable):
Returns the thread's hash. Returns the thread's hash.
.. describe:: int(x)
Returns the thread's ID.
.. describe:: str(x) .. describe:: str(x)
Returns the thread's name. Returns the thread's name.
@ -748,6 +752,10 @@ class ThreadMember(Hashable):
Returns the thread member's hash. Returns the thread member's hash.
.. describe:: int(x)
Returns the thread member's ID.
.. describe:: str(x) .. describe:: str(x)
Returns the thread member's name. Returns the thread member's name.
@ -800,3 +808,39 @@ class ThreadMember(Hashable):
def thread(self) -> Thread: def thread(self) -> Thread:
""":class:`Thread`: The thread this member belongs to.""" """:class:`Thread`: The thread this member belongs to."""
return self.parent return self.parent
async def fetch_member(self) -> Member:
"""|coro|
Retrieves a :class:`Member` from the ThreadMember object.
.. note::
This method is an API call. If you have :attr:`Intents.members` and member cache enabled, consider :meth:`get_member` instead.
Raises
-------
Forbidden
You do not have access to the guild.
HTTPException
Fetching the member failed.
Returns
--------
:class:`Member`
The member.
"""
return await self.thread.guild.fetch_member(self.id)
def get_member(self) -> Optional[Member]:
"""
Get the :class:`Member` from cache for the ThreadMember object.
Returns
--------
Optional[:class:`Member`]
The member or ``None`` if not found.
"""
return self.thread.guild.get_member(self.id)

View File

@ -85,3 +85,13 @@ class _IntegrationDeleteEventOptional(TypedDict, total=False):
class IntegrationDeleteEvent(_IntegrationDeleteEventOptional): class IntegrationDeleteEvent(_IntegrationDeleteEventOptional):
id: Snowflake id: Snowflake
guild_id: Snowflake guild_id: Snowflake
class _TypingEventOptional(TypedDict, total=False):
guild_id: Snowflake
member: Member
class TypingEvent(_TypingEventOptional):
channel_id: Snowflake
user_id: Snowflake
timestamp: int

View File

@ -185,16 +185,16 @@ class Button(Item[V]):
@emoji.setter @emoji.setter
def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore
if value is not None: if value is None:
if isinstance(value, str):
self._underlying.emoji = PartialEmoji.from_str(value)
elif isinstance(value, _EmojiTag):
self._underlying.emoji = value._to_partial()
else:
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead')
else:
self._underlying.emoji = None self._underlying.emoji = None
elif isinstance(value, str):
self._underlying.emoji = PartialEmoji.from_str(value)
elif isinstance(value, _EmojiTag):
self._underlying.emoji = value._to_partial()
else:
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead')
@classmethod @classmethod
def from_component(cls: Type[B], button: ButtonComponent) -> B: def from_component(cls: Type[B], button: ButtonComponent) -> B:
return cls( return cls(

View File

@ -96,6 +96,9 @@ class BaseUser(_UserTag):
def __str__(self) -> str: def __str__(self) -> str:
return f'{self.name}#{self.discriminator}' return f'{self.name}#{self.discriminator}'
def __int__(self) -> int:
return self.id
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
return isinstance(other, _UserTag) and other.id == self.id return isinstance(other, _UserTag) and other.id == self.id
@ -415,6 +418,10 @@ class User(BaseUser, discord.abc.Messageable):
Returns the user's name with discriminator. Returns the user's name with discriminator.
.. describe:: int(x)
Returns the user's ID.
Attributes Attributes
----------- -----------
name: :class:`str` name: :class:`str`

View File

@ -499,14 +499,14 @@ else:
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After') reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After')
if use_clock or not reset_after: if not use_clock and reset_after:
utc = datetime.timezone.utc
now = datetime.datetime.now(utc)
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc)
return (reset - now).total_seconds()
else:
return float(reset_after) return float(reset_after)
utc = datetime.timezone.utc
now = datetime.datetime.now(utc)
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc)
return (reset - now).total_seconds()
async def maybe_coroutine(f, *args, **kwargs): async def maybe_coroutine(f, *args, **kwargs):
value = f(*args, **kwargs) value = f(*args, **kwargs)
@ -659,11 +659,10 @@ def resolve_invite(invite: Union[Invite, str]) -> str:
if isinstance(invite, Invite): if isinstance(invite, Invite):
return invite.code return invite.code
else: rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)'
rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)' m = re.match(rx, invite)
m = re.match(rx, invite) if m:
if m: return m.group(1)
return m.group(1)
return invite return invite
@ -687,11 +686,10 @@ def resolve_template(code: Union[Template, str]) -> str:
if isinstance(code, Template): if isinstance(code, Template):
return code.code return code.code
else: rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)'
rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)' m = re.match(rx, code)
m = re.match(rx, code) if m:
if m: return m.group(1)
return m.group(1)
return code return code
@ -1017,3 +1015,9 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None)
if style is None: if style is None:
return f'<t:{int(dt.timestamp())}>' return f'<t:{int(dt.timestamp())}>'
return f'<t:{int(dt.timestamp())}:{style}>' return f'<t:{int(dt.timestamp())}:{style}>'
def raise_expected_coro(coro, error: str)-> TypeError:
if not asyncio.iscoroutinefunction(coro):
raise TypeError(error)
return coro

View File

@ -255,6 +255,9 @@ class VoiceClient(VoiceProtocol):
self.encoder: Encoder = MISSING self.encoder: Encoder = MISSING
self._lite_nonce: int = 0 self._lite_nonce: int = 0
self.ws: DiscordVoiceWebSocket = MISSING self.ws: DiscordVoiceWebSocket = MISSING
self.ip: str = MISSING
self.port: Tuple[Any, ...] = MISSING
warn_nacl = not has_nacl warn_nacl = not has_nacl
supported_modes: Tuple[SupportedModes, ...] = ( supported_modes: Tuple[SupportedModes, ...] = (

View File

@ -886,6 +886,10 @@ class Webhook(BaseWebhook):
Returns the webhooks's hash. Returns the webhooks's hash.
.. describe:: int(x)
Returns the webhooks's ID.
.. versionchanged:: 1.4 .. versionchanged:: 1.4
Webhooks are now comparable and hashable. Webhooks are now comparable and hashable.

View File

@ -475,6 +475,10 @@ class SyncWebhook(BaseWebhook):
Returns the webhooks's hash. Returns the webhooks's hash.
.. describe:: int(x)
Returns the webhooks's ID.
.. versionchanged:: 1.4 .. versionchanged:: 1.4
Webhooks are now comparable and hashable. Webhooks are now comparable and hashable.

216
discord/welcome_screen.py Normal file
View File

@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Dict, List, Optional, TYPE_CHECKING, Union, overload
from .utils import _get_as_snowflake, get
from .errors import InvalidArgument
from .partial_emoji import _EmojiTag
__all__ = (
'WelcomeChannel',
'WelcomeScreen',
)
if TYPE_CHECKING:
from .types.welcome_screen import (
WelcomeScreen as WelcomeScreenPayload,
WelcomeScreenChannel as WelcomeScreenChannelPayload,
)
from .abc import Snowflake
from .guild import Guild
from .partial_emoji import PartialEmoji
from .emoji import Emoji
class WelcomeChannel:
"""Represents a :class:`WelcomeScreen` welcome channel.
.. versionadded:: 2.0
Attributes
-----------
channel: :class:`abc.Snowflake`
The guild channel that is being referenced.
description: :class:`str`
The description shown of the channel.
emoji: Optional[:class:`PartialEmoji`, :class:`Emoji`, :class:`str`]
The emoji used beside the channel description.
"""
def __init__(self, *, channel: Snowflake, description: str, emoji: Union[PartialEmoji, Emoji, str] = None):
self.channel = channel
self.description = description
self.emoji = emoji
def __repr__(self) -> str:
return f'<WelcomeChannel channel={self.channel!r} description={self.description!r} emoji={self.emoji!r}>'
@classmethod
def _from_dict(cls, *, data: WelcomeScreenChannelPayload, guild: Guild) -> WelcomeChannel:
channel_id = _get_as_snowflake(data, 'channel_id')
channel = guild.get_channel(channel_id)
description = data['description']
_emoji_id = _get_as_snowflake(data, 'emoji_id')
_emoji_name = data['emoji_name']
if _emoji_id:
# custom
emoji = get(guild.emojis, id=_emoji_id)
else:
# unicode or None
emoji = _emoji_name
return cls(channel=channel, description=description, emoji=emoji) # type: ignore
def to_dict(self) -> WelcomeScreenChannelPayload:
ret: WelcomeScreenChannelPayload = {
'channel_id': self.channel.id,
'description': self.description,
'emoji_id': None,
'emoji_name': None,
}
if isinstance(self.emoji, _EmojiTag):
ret['emoji_id'] = self.emoji.id # type: ignore
ret['emoji_name'] = self.emoji.name # type: ignore
else:
# unicode or None
ret['emoji_name'] = self.emoji
return ret
class WelcomeScreen:
"""Represents a :class:`Guild` welcome screen.
.. versionadded:: 2.0
Attributes
-----------
description: :class:`str`
The description shown on the welcome screen.
welcome_channels: List[:class:`WelcomeChannel`]
The channels shown on the welcome screen.
"""
def __init__(self, *, data: WelcomeScreenPayload, guild: Guild):
self._state = guild._state
self._guild = guild
self._store(data)
def _store(self, data: WelcomeScreenPayload) -> None:
self.description = data['description']
welcome_channels = data.get('welcome_channels', [])
self.welcome_channels = [WelcomeChannel._from_dict(data=wc, guild=self._guild) for wc in welcome_channels]
def __repr__(self) -> str:
return f'<WelcomeScreen description={self.description!r} welcome_channels={self.welcome_channels!r} enabled={self.enabled}>'
@property
def enabled(self) -> bool:
""":class:`bool`: Whether the welcome screen is displayed.
This is equivalent to checking if ``WELCOME_SCREEN_ENABLED``
is present in :attr:`Guild.features`.
"""
return 'WELCOME_SCREEN_ENABLED' in self._guild.features
@overload
async def edit(
self,
*,
description: Optional[str] = ...,
welcome_channels: Optional[List[WelcomeChannel]] = ...,
enabled: Optional[bool] = ...,
) -> None:
...
@overload
async def edit(self) -> None:
...
async def edit(self, **kwargs):
"""|coro|
Edit the welcome screen.
You must have the :attr:`~Permissions.manage_guild` permission in the
guild to do this.
Usage: ::
rules_channel = guild.get_channel(12345678)
announcements_channel = guild.get_channel(87654321)
custom_emoji = utils.get(guild.emojis, name='loudspeaker')
await welcome_screen.edit(
description='This is a very cool community server!',
welcome_channels=[
WelcomeChannel(channel=rules_channel, description='Read the rules!', emoji='👨‍🏫'),
WelcomeChannel(channel=announcements_channel, description='Watch out for announcements!', emoji=custom_emoji),
]
)
.. note::
Welcome channels can only accept custom emojis if :attr:`~Guild.premium_tier` is level 2 or above.
Parameters
------------
description: Optional[:class:`str`]
The template's description.
welcome_channels: Optional[List[:class:`WelcomeChannel`]]
The welcome channels, in their respective order.
enabled: Optional[:class:`bool`]
Whether the welcome screen should be displayed.
Raises
-------
HTTPException
Editing the welcome screen failed failed.
Forbidden
You don't have permissions to edit the welcome screen.
NotFound
This welcome screen does not exist.
"""
try:
welcome_channels = kwargs['welcome_channels']
except KeyError:
pass
else:
welcome_channels_serialised = []
for wc in welcome_channels:
if not isinstance(wc, WelcomeChannel):
raise InvalidArgument('welcome_channels parameter must be a list of WelcomeChannel')
welcome_channels_serialised.append(wc.to_dict())
kwargs['welcome_channels'] = welcome_channels_serialised
if kwargs:
data = await self._state.http.edit_welcome_screen(self._guild.id, kwargs)
self._store(data)

View File

@ -369,6 +369,17 @@ to handle it, which defaults to print a traceback and ignoring the exception.
:param when: When the typing started as an aware datetime in UTC. :param when: When the typing started as an aware datetime in UTC.
:type when: :class:`datetime.datetime` :type when: :class:`datetime.datetime`
.. function:: on_raw_typing(payload)
Called when someone begins typing a message. Unlike :func:`on_typing`, this is
called regardless if the user can be found or not. This most often happens
when a user types in DMs.
This requires :attr:`Intents.typing` to be enabled.
:param payload: The raw typing payload.
:type payload: :class:`RawTypingEvent`
.. function:: on_message(message) .. function:: on_message(message)
Called when a :class:`Message` is created and sent. Called when a :class:`Message` is created and sent.
@ -3781,6 +3792,22 @@ Template
.. autoclass:: Template() .. autoclass:: Template()
:members: :members:
WelcomeScreen
~~~~~~~~~~~~~~~
.. attributetable:: WelcomeScreen
.. autoclass:: WelcomeScreen()
:members:
WelcomeChannel
~~~~~~~~~~~~~~~
.. attributetable:: WelcomeChannel
.. autoclass:: WelcomeChannel()
:members:
WidgetChannel WidgetChannel
~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~
@ -3846,6 +3873,14 @@ GuildSticker
.. autoclass:: GuildSticker() .. autoclass:: GuildSticker()
:members: :members:
RawTypingEvent
~~~~~~~~~~~~~~~~~~~~~~~
.. attributetable:: RawTypingEvent
.. autoclass:: RawTypingEvent()
:members:
RawMessageDeleteEvent RawMessageDeleteEvent
~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~