Merge branch 'iDevision:2.0' into 2.0
This commit is contained in:
commit
6af5399936
6
.github/CONTRIBUTING.md
vendored
6
.github/CONTRIBUTING.md
vendored
@ -1,5 +1,7 @@
|
||||
## Contributing to discord.py
|
||||
|
||||
Credits to the `original lib` by Rapptz <https://github.com/Rapptz/discord.py>
|
||||
|
||||
First off, thanks for taking the time to contribute. It makes the library substantially better. :+1:
|
||||
|
||||
The following is a set of guidelines for contributing to the repository. These are guidelines, not hard rules.
|
||||
@ -8,9 +10,9 @@ The following is a set of guidelines for contributing to the repository. These a
|
||||
|
||||
Generally speaking questions are better suited in our resources below.
|
||||
|
||||
- The official support server: https://discord.gg/r3sSKJJ
|
||||
- The official support server: https://discord.gg/TvqYBrGXEm
|
||||
- The Discord API server under #python_discord-py: https://discord.gg/discord-api
|
||||
- [The FAQ in the documentation](https://discordpy.readthedocs.io/en/latest/faq.html)
|
||||
- [The FAQ in the documentation](https://enhanced-dpy.readthedocs.io/en/latest/faq.html)
|
||||
- [StackOverflow's `discord.py` tag](https://stackoverflow.com/questions/tagged/discord.py)
|
||||
|
||||
Please try your best not to ask questions in our issue tracker. Most of them don't belong there unless they provide value to a larger audience.
|
||||
|
@ -1,5 +1,5 @@
|
||||
discord.py
|
||||
==========
|
||||
enhanced-discord.py
|
||||
===================
|
||||
|
||||
.. image:: https://discord.com/api/guilds/514232441498763279/embed.png
|
||||
:target: https://discord.gg/PYAfZzpsjG
|
||||
@ -59,7 +59,7 @@ To install the development version, do the following:
|
||||
.. code:: sh
|
||||
|
||||
$ git clone https://github.com/iDevision/enhanced-discord.py
|
||||
$ cd discord.py
|
||||
$ cd enhanced-discord.py
|
||||
$ python3 -m pip install -U .[voice]
|
||||
|
||||
|
||||
|
@ -40,6 +40,7 @@ from .colour import *
|
||||
from .integrations import *
|
||||
from .invite import *
|
||||
from .template import *
|
||||
from .welcome_screen import *
|
||||
from .widget import *
|
||||
from .object import *
|
||||
from .reaction import *
|
||||
|
@ -794,13 +794,13 @@ class CustomActivity(BaseActivity):
|
||||
return hash((self.name, str(self.emoji)))
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.emoji:
|
||||
if self.name:
|
||||
return f'{self.emoji} {self.name}'
|
||||
return str(self.emoji)
|
||||
else:
|
||||
if not self.emoji:
|
||||
return str(self.name)
|
||||
|
||||
if self.name:
|
||||
return f'{self.emoji} {self.name}'
|
||||
return str(self.emoji)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>'
|
||||
|
||||
|
@ -313,10 +313,11 @@ class Asset(AssetMixin):
|
||||
if self._animated:
|
||||
if format not in VALID_ASSET_FORMATS:
|
||||
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}')
|
||||
else:
|
||||
url = url.with_path(f'{path}.{format}')
|
||||
elif static_format is MISSING:
|
||||
if format not in VALID_STATIC_FORMATS:
|
||||
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
|
||||
url = url.with_path(f'{path}.{format}')
|
||||
url = url.with_path(f'{path}.{format}')
|
||||
|
||||
if static_format is not MISSING and not self._animated:
|
||||
if static_format not in VALID_STATIC_FORMATS:
|
||||
|
@ -330,6 +330,10 @@ class AuditLogEntry(Hashable):
|
||||
|
||||
Returns the entry's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the entry's ID.
|
||||
|
||||
.. versionchanged:: 1.7
|
||||
Audit log entries are now comparable and hashable.
|
||||
|
||||
|
@ -115,6 +115,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
|
||||
Returns the channel's name.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the channel's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@ -224,6 +228,16 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
"""List[:class:`Member`]: Returns all members that can see this channel."""
|
||||
return [m for m in self.guild.members if self.permissions_for(m).read_messages]
|
||||
|
||||
@property
|
||||
def bots(self) -> List[Member]:
|
||||
"""List[:class:`Member`]: Returns all bots that can see this channel."""
|
||||
return [m for m in self.guild.members if m.bot and self.permissions_for(m).read_messages]
|
||||
|
||||
@property
|
||||
def humans(self) -> List[Member]:
|
||||
"""List[:class:`Member`]: Returns all human members that can see this channel."""
|
||||
return [m for m in self.guild.members if not m.bot and self.permissions_for(m).read_messages]
|
||||
|
||||
@property
|
||||
def threads(self) -> List[Thread]:
|
||||
"""List[:class:`Thread`]: Returns all the threads that you can see.
|
||||
@ -1334,6 +1348,10 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
|
||||
Returns the category's name.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the category's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@ -1556,6 +1574,10 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
|
||||
|
||||
Returns the channel's name.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the channel's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@ -1728,6 +1750,10 @@ class DMChannel(discord.abc.Messageable, Hashable):
|
||||
|
||||
Returns a string representation of the channel
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the channel's ID.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
recipient: Optional[:class:`User`]
|
||||
@ -1854,6 +1880,10 @@ class GroupChannel(discord.abc.Messageable, Hashable):
|
||||
|
||||
Returns a string representation of the channel
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the channel's ID.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
recipients: List[:class:`User`]
|
||||
@ -2000,6 +2030,10 @@ class PartialMessageable(discord.abc.Messageable, Hashable):
|
||||
|
||||
Returns the partial messageable's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the messageable's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
id: :class:`int`
|
||||
|
@ -842,6 +842,38 @@ class Client:
|
||||
"""
|
||||
return self._connection.get_user(id)
|
||||
|
||||
async def try_user(self, id: int, /) -> Optional[User]:
|
||||
"""|coro|
|
||||
Returns a user with the given ID. If not from cache, the user will be requested from the API.
|
||||
|
||||
You do not have to share any guilds with the user to get this information from the API,
|
||||
however many operations do require that you do.
|
||||
|
||||
.. note::
|
||||
This method is an API call. If you have :attr:`discord.Intents.members` and member cache enabled, consider :meth:`get_user` instead.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
id: :class:`int`
|
||||
The ID to search for.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`~discord.User`]
|
||||
The user or ``None`` if not found.
|
||||
"""
|
||||
maybe_user = self.get_user(id)
|
||||
|
||||
if maybe_user is not None:
|
||||
return maybe_user
|
||||
|
||||
try:
|
||||
return await self.fetch_user(id)
|
||||
except NotFound:
|
||||
return None
|
||||
|
||||
def get_emoji(self, id: int, /) -> Optional[Emoji]:
|
||||
"""Returns an emoji with the given ID.
|
||||
|
||||
|
@ -251,6 +251,13 @@ class Colour:
|
||||
def red(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``."""
|
||||
return cls(0xe74c3c)
|
||||
|
||||
@classmethod
|
||||
def nitro_booster(cls):
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0xf47fff``.
|
||||
|
||||
.. versionadded:: 2.0"""
|
||||
return cls(0xf47fff)
|
||||
|
||||
@classmethod
|
||||
def dark_red(cls: Type[CT]) -> CT:
|
||||
@ -324,6 +331,15 @@ class Colour:
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return cls(0xFEE75C)
|
||||
|
||||
@classmethod
|
||||
def dark_blurple(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x4E5D94``.
|
||||
This is the original Dark Blurple branding.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return cls(0x4E5D94)
|
||||
|
||||
|
||||
Color = Colour
|
||||
|
@ -72,30 +72,36 @@ if TYPE_CHECKING:
|
||||
T = TypeVar('T')
|
||||
MaybeEmpty = Union[T, _EmptyEmbed]
|
||||
|
||||
|
||||
class _EmbedFooterProxy(Protocol):
|
||||
text: MaybeEmpty[str]
|
||||
icon_url: MaybeEmpty[str]
|
||||
|
||||
|
||||
class _EmbedFieldProxy(Protocol):
|
||||
name: MaybeEmpty[str]
|
||||
value: MaybeEmpty[str]
|
||||
inline: bool
|
||||
|
||||
|
||||
class _EmbedMediaProxy(Protocol):
|
||||
url: MaybeEmpty[str]
|
||||
proxy_url: MaybeEmpty[str]
|
||||
height: MaybeEmpty[int]
|
||||
width: MaybeEmpty[int]
|
||||
|
||||
|
||||
class _EmbedVideoProxy(Protocol):
|
||||
url: MaybeEmpty[str]
|
||||
height: MaybeEmpty[int]
|
||||
width: MaybeEmpty[int]
|
||||
|
||||
|
||||
class _EmbedProviderProxy(Protocol):
|
||||
name: MaybeEmpty[str]
|
||||
url: MaybeEmpty[str]
|
||||
|
||||
|
||||
class _EmbedAuthorProxy(Protocol):
|
||||
name: MaybeEmpty[str]
|
||||
url: MaybeEmpty[str]
|
||||
@ -175,15 +181,15 @@ class Embed:
|
||||
Empty: Final = EmptyEmbed
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
title: MaybeEmpty[Any] = EmptyEmbed,
|
||||
type: EmbedType = 'rich',
|
||||
url: MaybeEmpty[Any] = EmptyEmbed,
|
||||
description: MaybeEmpty[Any] = EmptyEmbed,
|
||||
timestamp: datetime.datetime = None,
|
||||
self,
|
||||
*,
|
||||
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
title: MaybeEmpty[Any] = EmptyEmbed,
|
||||
type: EmbedType = 'rich',
|
||||
url: MaybeEmpty[Any] = EmptyEmbed,
|
||||
description: MaybeEmpty[Any] = EmptyEmbed,
|
||||
timestamp: datetime.datetime = None,
|
||||
):
|
||||
|
||||
self.colour = colour if colour is not EmptyEmbed else color
|
||||
@ -397,6 +403,22 @@ class Embed:
|
||||
"""
|
||||
return EmbedProxy(getattr(self, '_image', {})) # type: ignore
|
||||
|
||||
@image.setter
|
||||
def image(self: E, url: Any):
|
||||
if url is EmptyEmbed:
|
||||
del self._image
|
||||
else:
|
||||
self._image = {
|
||||
'url': str(url),
|
||||
}
|
||||
|
||||
@image.deleter
|
||||
def image(self: E):
|
||||
try:
|
||||
del self._image
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def set_image(self: E, *, url: MaybeEmpty[Any]) -> E:
|
||||
"""Sets the image for the embed content.
|
||||
|
||||
@ -412,15 +434,7 @@ class Embed:
|
||||
The source URL for the image. Only HTTP(S) is supported.
|
||||
"""
|
||||
|
||||
if url is EmptyEmbed:
|
||||
try:
|
||||
del self._image
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self._image = {
|
||||
'url': str(url),
|
||||
}
|
||||
self.image = url
|
||||
|
||||
return self
|
||||
|
||||
@ -439,7 +453,23 @@ class Embed:
|
||||
"""
|
||||
return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore
|
||||
|
||||
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]) -> E:
|
||||
@thumbnail.setter
|
||||
def thumbnail(self: E, url: Any):
|
||||
if url is EmptyEmbed:
|
||||
del self._thumbnail
|
||||
else:
|
||||
self._thumbnail = {
|
||||
'url': str(url),
|
||||
}
|
||||
|
||||
@thumbnail.deleter
|
||||
def thumbnail(self):
|
||||
try:
|
||||
del self._thumbnail
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]):
|
||||
"""Sets the thumbnail for the embed content.
|
||||
|
||||
This function returns the class instance to allow for fluent-style
|
||||
@ -454,15 +484,7 @@ class Embed:
|
||||
The source URL for the thumbnail. Only HTTP(S) is supported.
|
||||
"""
|
||||
|
||||
if url is EmptyEmbed:
|
||||
try:
|
||||
del self._thumbnail
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self._thumbnail = {
|
||||
'url': str(url),
|
||||
}
|
||||
self.thumbnail = url
|
||||
|
||||
return self
|
||||
|
||||
|
@ -72,6 +72,10 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
|
||||
Returns the emoji rendered for discord.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the emoji ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@ -137,6 +141,9 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
return f'<a:{self.name}:{self.id}>'
|
||||
return f'<:{self.name}:{self.id}>'
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>'
|
||||
|
||||
|
@ -53,6 +53,7 @@ from .context import Context
|
||||
from . import errors
|
||||
from .help import HelpCommand, DefaultHelpCommand
|
||||
from .cog import Cog
|
||||
from discord.utils import raise_expected_coro
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import importlib.machinery
|
||||
@ -453,14 +454,59 @@ class BotBase(GroupMixin):
|
||||
elif self.owner_ids:
|
||||
return user.id in self.owner_ids
|
||||
else:
|
||||
# Populate the used fields, then retry the check. This is only done at-most once in the bot lifetime.
|
||||
await self.populate_owners()
|
||||
return await self.is_owner(user)
|
||||
|
||||
app = await self.application_info() # type: ignore
|
||||
if app.team:
|
||||
self.owner_ids = ids = {m.id for m in app.team.members}
|
||||
return user.id in ids
|
||||
async def try_owners(self) -> List[discord.User]:
|
||||
"""|coro|
|
||||
|
||||
Returns a list of :class:`~discord.User` representing the owners of the bot.
|
||||
It uses the :attr:`owner_id` and :attr:`owner_ids`, if set.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
The function also checks if the application is team-owned if
|
||||
:attr:`owner_ids` is not set.
|
||||
|
||||
Returns
|
||||
--------
|
||||
List[:class:`~discord.User`]
|
||||
List of owners of the bot.
|
||||
"""
|
||||
if self.owner_id:
|
||||
owner = await self.try_user(self.owner_id)
|
||||
|
||||
if owner:
|
||||
return [owner]
|
||||
else:
|
||||
self.owner_id = owner_id = app.owner.id
|
||||
return user.id == owner_id
|
||||
return []
|
||||
|
||||
elif self.owner_ids:
|
||||
owners = []
|
||||
|
||||
for owner_id in self.owner_ids:
|
||||
owner = await self.try_user(owner_id)
|
||||
if owner:
|
||||
owners.append(owner)
|
||||
|
||||
return owners
|
||||
else:
|
||||
# We didn't have owners cached yet, cache them and retry.
|
||||
await self.populate_owners()
|
||||
return await self.try_owners()
|
||||
|
||||
async def populate_owners(self):
|
||||
"""|coro|
|
||||
|
||||
Populate the :attr:`owner_id` and :attr:`owner_ids` through the use of :meth:`~.Bot.application_info`.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
app = await self.application_info() # type: ignore
|
||||
if app.team:
|
||||
self.owner_ids = {m.id for m in app.team.members}
|
||||
else:
|
||||
self.owner_id = app.owner.id
|
||||
|
||||
def before_invoke(self, coro: CFT) -> CFT:
|
||||
"""A decorator that registers a coroutine as a pre-invoke hook.
|
||||
@ -488,11 +534,9 @@ class BotBase(GroupMixin):
|
||||
TypeError
|
||||
The coroutine passed is not actually a coroutine.
|
||||
"""
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise TypeError('The pre-invoke hook must be a coroutine.')
|
||||
|
||||
self._before_invoke = coro
|
||||
return coro
|
||||
return raise_expected_coro(
|
||||
coro, 'The pre-invoke hook must be a coroutine.'
|
||||
)
|
||||
|
||||
def after_invoke(self, coro: CFT) -> CFT:
|
||||
r"""A decorator that registers a coroutine as a post-invoke hook.
|
||||
@ -521,11 +565,10 @@ class BotBase(GroupMixin):
|
||||
TypeError
|
||||
The coroutine passed is not actually a coroutine.
|
||||
"""
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise TypeError('The post-invoke hook must be a coroutine.')
|
||||
return raise_expected_coro(
|
||||
coro, 'The post-invoke hook must be a coroutine.'
|
||||
)
|
||||
|
||||
self._after_invoke = coro
|
||||
return coro
|
||||
|
||||
# listener registration
|
||||
|
||||
@ -1266,7 +1309,7 @@ class Bot(BotBase, discord.Client):
|
||||
when passing an empty string, it should always be last as no prefix
|
||||
after it will be matched.
|
||||
case_insensitive: :class:`bool`
|
||||
Whether the commands should be case insensitive. Defaults to ``False``. This
|
||||
Whether the commands should be case insensitive. Defaults to ``True``. This
|
||||
attribute does not carry over to groups. You must set it to every group if
|
||||
you require group commands to be case insensitive as well.
|
||||
description: :class:`str`
|
||||
|
@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
@ -32,6 +33,7 @@ import discord.abc
|
||||
import discord.utils
|
||||
|
||||
from discord.message import Message
|
||||
from discord import Permissions
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import ParamSpec
|
||||
@ -62,10 +64,7 @@ T = TypeVar('T')
|
||||
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
|
||||
CogT = TypeVar('CogT', bound="Cog")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
P = ParamSpec('P')
|
||||
else:
|
||||
P = TypeVar('P')
|
||||
P = ParamSpec('P') if TYPE_CHECKING else TypeVar('P')
|
||||
|
||||
|
||||
class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
@ -318,6 +317,13 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
g = self.guild
|
||||
return g.voice_client if g else None
|
||||
|
||||
def author_permissions(self) -> Permissions:
|
||||
"""Returns the author permissions in the given channel.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self.channel.permissions_for(self.author)
|
||||
|
||||
async def send_help(self, *args: Any) -> Any:
|
||||
"""send_help(entity=<bot>)
|
||||
|
||||
|
@ -356,15 +356,15 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
|
||||
|
||||
@staticmethod
|
||||
def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]:
|
||||
if guild_id is not None:
|
||||
guild = ctx.bot.get_guild(guild_id)
|
||||
if guild is not None and channel_id is not None:
|
||||
return guild._resolve_channel(channel_id) # type: ignore
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
if guild_id is None:
|
||||
return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel
|
||||
|
||||
guild = ctx.bot.get_guild(guild_id)
|
||||
if guild is not None and channel_id is not None:
|
||||
return guild._resolve_channel(channel_id) # type: ignore
|
||||
else:
|
||||
return None
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage:
|
||||
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
|
||||
channel = self._resolve_channel(ctx, guild_id, channel_id)
|
||||
@ -757,8 +757,8 @@ class GuildConverter(IDConverter[discord.Guild]):
|
||||
if result is None:
|
||||
result = discord.utils.get(ctx.bot.guilds, name=argument)
|
||||
|
||||
if result is None:
|
||||
raise GuildNotFound(argument)
|
||||
if result is None:
|
||||
raise GuildNotFound(argument)
|
||||
return result
|
||||
|
||||
|
||||
@ -942,8 +942,7 @@ class clean_content(Converter[str]):
|
||||
def repl(match: re.Match) -> str:
|
||||
type = match[1]
|
||||
id = int(match[2])
|
||||
transformed = transforms[type](id)
|
||||
return transformed
|
||||
return transforms[type](id)
|
||||
|
||||
result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument)
|
||||
if self.escape_markdown:
|
||||
|
@ -1264,10 +1264,10 @@ class GroupMixin(Generic[CogT]):
|
||||
A mapping of command name to :class:`.Command`
|
||||
objects.
|
||||
case_insensitive: :class:`bool`
|
||||
Whether the commands should be case insensitive. Defaults to ``False``.
|
||||
Whether the commands should be case insensitive. Defaults to ``True``.
|
||||
"""
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
case_insensitive = kwargs.get('case_insensitive', False)
|
||||
case_insensitive = kwargs.get('case_insensitive', True)
|
||||
self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {}
|
||||
self.case_insensitive: bool = case_insensitive
|
||||
super().__init__(*args, **kwargs)
|
||||
|
@ -82,9 +82,7 @@ class StringView:
|
||||
def skip_string(self, string):
|
||||
strlen = len(string)
|
||||
if self.buffer[self.index:self.index + strlen] == string:
|
||||
self.previous = self.index
|
||||
self.index += strlen
|
||||
return True
|
||||
return self._return_index(strlen, True)
|
||||
return False
|
||||
|
||||
def read_rest(self):
|
||||
@ -95,9 +93,7 @@ class StringView:
|
||||
|
||||
def read(self, n):
|
||||
result = self.buffer[self.index:self.index + n]
|
||||
self.previous = self.index
|
||||
self.index += n
|
||||
return result
|
||||
return self._return_index(n, result)
|
||||
|
||||
def get(self):
|
||||
try:
|
||||
@ -105,9 +101,12 @@ class StringView:
|
||||
except IndexError:
|
||||
result = None
|
||||
|
||||
return self._return_index(1, result)
|
||||
|
||||
def _return_index(self, arg0, arg1):
|
||||
self.previous = self.index
|
||||
self.index += 1
|
||||
return result
|
||||
self.index += arg0
|
||||
return arg1
|
||||
|
||||
def get_word(self):
|
||||
pos = 0
|
||||
|
@ -46,7 +46,9 @@ import traceback
|
||||
|
||||
from collections.abc import Sequence
|
||||
from discord.backoff import ExponentialBackoff
|
||||
from discord.utils import MISSING
|
||||
from discord.utils import MISSING, raise_expected_coro
|
||||
|
||||
|
||||
|
||||
__all__ = (
|
||||
'loop',
|
||||
@ -488,11 +490,7 @@ class Loop(Generic[LF]):
|
||||
The function was not a coroutine.
|
||||
"""
|
||||
|
||||
if not inspect.iscoroutinefunction(coro):
|
||||
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
|
||||
|
||||
self._before_loop = coro
|
||||
return coro
|
||||
return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
|
||||
|
||||
def after_loop(self, coro: FT) -> FT:
|
||||
"""A decorator that register a coroutine to be called after the loop finished running.
|
||||
@ -516,11 +514,7 @@ class Loop(Generic[LF]):
|
||||
The function was not a coroutine.
|
||||
"""
|
||||
|
||||
if not inspect.iscoroutinefunction(coro):
|
||||
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
|
||||
|
||||
self._after_loop = coro
|
||||
return coro
|
||||
return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
|
||||
|
||||
def error(self, coro: ET) -> ET:
|
||||
"""A decorator that registers a coroutine to be called if the task encounters an unhandled exception.
|
||||
@ -542,11 +536,7 @@ class Loop(Generic[LF]):
|
||||
TypeError
|
||||
The function was not a coroutine.
|
||||
"""
|
||||
if not inspect.iscoroutinefunction(coro):
|
||||
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
|
||||
|
||||
self._error = coro # type: ignore
|
||||
return coro
|
||||
return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
|
||||
|
||||
def _get_next_sleep_time(self) -> datetime.datetime:
|
||||
if self._sleep is not MISSING:
|
||||
@ -614,8 +604,7 @@ class Loop(Generic[LF]):
|
||||
)
|
||||
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
|
||||
|
||||
ret = sorted(set(ret)) # de-dupe and sort times
|
||||
return ret
|
||||
return sorted(set(ret))
|
||||
|
||||
def change_interval(
|
||||
self,
|
||||
|
@ -22,8 +22,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict, Any, Optional, List, TypeVar, Type, Dict, Callable, Coroutine, NamedTuple, Deque
|
||||
|
||||
import asyncio
|
||||
from collections import namedtuple, deque
|
||||
from collections import deque
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import struct
|
||||
@ -38,9 +42,25 @@ import aiohttp
|
||||
from . import utils
|
||||
from .activity import BaseActivity
|
||||
from .enums import SpeakingState
|
||||
from .errors import ConnectionClosed, InvalidArgument
|
||||
from .errors import ConnectionClosed, InvalidArgument
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client
|
||||
from .state import ConnectionState
|
||||
from .voice_client import VoiceClient
|
||||
|
||||
T = TypeVar('T')
|
||||
DWS = TypeVar('DWS', bound='DiscordWebSocket')
|
||||
DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket')
|
||||
|
||||
Coro = Callable[..., Coroutine[Any, Any, Any]]
|
||||
Predicate = Callable[[Dict[str, Any]], bool]
|
||||
DataCallable = Callable[[Dict[str, Any]], T]
|
||||
Result = Optional[DataCallable[Any]]
|
||||
|
||||
|
||||
_log: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = (
|
||||
'DiscordWebSocket',
|
||||
@ -50,36 +70,49 @@ __all__ = (
|
||||
'ReconnectWebSocket',
|
||||
)
|
||||
|
||||
|
||||
class Heartbeat(TypedDict):
|
||||
op: int
|
||||
d: int
|
||||
|
||||
|
||||
class ReconnectWebSocket(Exception):
|
||||
"""Signals to safely reconnect the websocket."""
|
||||
def __init__(self, shard_id, *, resume=True):
|
||||
self.shard_id = shard_id
|
||||
self.resume = resume
|
||||
def __init__(self, shard_id: Optional[int], *, resume: bool = True) -> None:
|
||||
self.shard_id: Optional[int] = shard_id
|
||||
self.resume: bool = resume
|
||||
self.op = 'RESUME' if resume else 'IDENTIFY'
|
||||
|
||||
|
||||
class WebSocketClosure(Exception):
|
||||
"""An exception to make up for the fact that aiohttp doesn't signal closure."""
|
||||
pass
|
||||
|
||||
EventListener = namedtuple('EventListener', 'predicate event result future')
|
||||
|
||||
class EventListener(NamedTuple):
|
||||
predicate: Predicate
|
||||
event: str
|
||||
result: Result
|
||||
future: asyncio.Future
|
||||
|
||||
|
||||
class GatewayRatelimiter:
|
||||
def __init__(self, count=110, per=60.0):
|
||||
def __init__(self, count: int = 110, per: float = 60.0) -> None:
|
||||
# The default is 110 to give room for at least 10 heartbeats per minute
|
||||
self.max = count
|
||||
self.remaining = count
|
||||
self.window = 0.0
|
||||
self.per = per
|
||||
self.lock = asyncio.Lock()
|
||||
self.shard_id = None
|
||||
self.max: int = count
|
||||
self.remaining: int = count
|
||||
self.window: float = 0.0
|
||||
self.per: float = per
|
||||
self.lock: asyncio.Lock = asyncio.Lock()
|
||||
self.shard_id: Optional[int] = None
|
||||
|
||||
def is_ratelimited(self):
|
||||
def is_ratelimited(self) -> bool:
|
||||
current = time.time()
|
||||
if current > self.window + self.per:
|
||||
return False
|
||||
return self.remaining == 0
|
||||
|
||||
def get_delay(self):
|
||||
def get_delay(self) -> float:
|
||||
current = time.time()
|
||||
|
||||
if current > self.window + self.per:
|
||||
@ -97,7 +130,7 @@ class GatewayRatelimiter:
|
||||
|
||||
return 0.0
|
||||
|
||||
async def block(self):
|
||||
async def block(self) -> None:
|
||||
async with self.lock:
|
||||
delta = self.get_delay()
|
||||
if delta:
|
||||
@ -106,27 +139,27 @@ class GatewayRatelimiter:
|
||||
|
||||
|
||||
class KeepAliveHandler(threading.Thread):
|
||||
def __init__(self, *args, **kwargs):
|
||||
ws = kwargs.pop('ws', None)
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
ws = kwargs.pop('ws')
|
||||
interval = kwargs.pop('interval', None)
|
||||
shard_id = kwargs.pop('shard_id', None)
|
||||
threading.Thread.__init__(self, *args, **kwargs)
|
||||
self.ws = ws
|
||||
self._main_thread_id = ws.thread_id
|
||||
self.interval = interval
|
||||
self.daemon = True
|
||||
self.shard_id = shard_id
|
||||
self.msg = 'Keeping shard ID %s websocket alive with sequence %s.'
|
||||
self.block_msg = 'Shard ID %s heartbeat blocked for more than %s seconds.'
|
||||
self.behind_msg = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.'
|
||||
self._stop_ev = threading.Event()
|
||||
self._last_ack = time.perf_counter()
|
||||
self._last_send = time.perf_counter()
|
||||
self._last_recv = time.perf_counter()
|
||||
self.latency = float('inf')
|
||||
self.heartbeat_timeout = ws._max_heartbeat_timeout
|
||||
self.ws: DiscordWebSocket = ws
|
||||
self._main_thread_id: int = ws.thread_id
|
||||
self.interval: Optional[float] = interval
|
||||
self.daemon: bool = True
|
||||
self.shard_id: Optional[int] = shard_id
|
||||
self.msg: str = 'Keeping shard ID %s websocket alive with sequence %s.'
|
||||
self.block_msg: str = 'Shard ID %s heartbeat blocked for more than %s seconds.'
|
||||
self.behind_msg: str = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.'
|
||||
self._stop_ev: threading.Event = threading.Event()
|
||||
self._last_ack: float = time.perf_counter()
|
||||
self._last_send: float = time.perf_counter()
|
||||
self._last_recv: float = time.perf_counter()
|
||||
self.latency: float = float('inf')
|
||||
self.heartbeat_timeout: float = ws._max_heartbeat_timeout
|
||||
|
||||
def run(self):
|
||||
def run(self) -> None:
|
||||
while not self._stop_ev.wait(self.interval):
|
||||
if self._last_recv + self.heartbeat_timeout < time.perf_counter():
|
||||
_log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
|
||||
@ -168,19 +201,20 @@ class KeepAliveHandler(threading.Thread):
|
||||
else:
|
||||
self._last_send = time.perf_counter()
|
||||
|
||||
def get_payload(self):
|
||||
def get_payload(self) -> Heartbeat:
|
||||
return {
|
||||
'op': self.ws.HEARTBEAT,
|
||||
'd': self.ws.sequence
|
||||
# the websocket's sequence won't be None here
|
||||
'd': self.ws.sequence # type: ignore
|
||||
}
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
self._stop_ev.set()
|
||||
|
||||
def tick(self):
|
||||
def tick(self) -> None:
|
||||
self._last_recv = time.perf_counter()
|
||||
|
||||
def ack(self):
|
||||
def ack(self) -> None:
|
||||
ack_time = time.perf_counter()
|
||||
self._last_ack = ack_time
|
||||
self.latency = ack_time - self._last_send
|
||||
@ -188,30 +222,32 @@ class KeepAliveHandler(threading.Thread):
|
||||
_log.warning(self.behind_msg, self.shard_id, self.latency)
|
||||
|
||||
class VoiceKeepAliveHandler(KeepAliveHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.recent_ack_latencies = deque(maxlen=20)
|
||||
self.recent_ack_latencies: Deque[float] = deque(maxlen=20)
|
||||
self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.'
|
||||
self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds'
|
||||
self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind'
|
||||
|
||||
def get_payload(self):
|
||||
def get_payload(self) -> Heartbeat:
|
||||
return {
|
||||
'op': self.ws.HEARTBEAT,
|
||||
'd': int(time.time() * 1000)
|
||||
}
|
||||
|
||||
def ack(self):
|
||||
def ack(self) -> None:
|
||||
ack_time = time.perf_counter()
|
||||
self._last_ack = ack_time
|
||||
self._last_recv = ack_time
|
||||
self.latency = ack_time - self._last_send
|
||||
self.recent_ack_latencies.append(self.latency)
|
||||
|
||||
|
||||
class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse):
|
||||
async def close(self, *, code: int = 4000, message: bytes = b'') -> bool:
|
||||
return await super().close(code=code, message=message)
|
||||
|
||||
|
||||
class DiscordWebSocket:
|
||||
"""Implements a WebSocket for Discord's gateway v6.
|
||||
|
||||
@ -266,41 +302,53 @@ class DiscordWebSocket:
|
||||
HEARTBEAT_ACK = 11
|
||||
GUILD_SYNC = 12
|
||||
|
||||
def __init__(self, socket, *, loop):
|
||||
self.socket = socket
|
||||
self.loop = loop
|
||||
def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self.socket: aiohttp.ClientWebSocketResponse = socket
|
||||
self.loop: asyncio.AbstractEventLoop = loop
|
||||
|
||||
# an empty dispatcher to prevent crashes
|
||||
self._dispatch = lambda *args: None
|
||||
# generic event listeners
|
||||
self._dispatch_listeners = []
|
||||
self._dispatch_listeners: List[EventListener] = []
|
||||
# the keep alive
|
||||
self._keep_alive = None
|
||||
self.thread_id = threading.get_ident()
|
||||
self._keep_alive: Optional[KeepAliveHandler] = None
|
||||
self.thread_id: int = threading.get_ident()
|
||||
|
||||
# ws related stuff
|
||||
self.session_id = None
|
||||
self.sequence = None
|
||||
self.session_id: Optional[str] = None
|
||||
self.sequence: Optional[int] = None
|
||||
self._zlib = zlib.decompressobj()
|
||||
self._buffer = bytearray()
|
||||
self._close_code = None
|
||||
self._rate_limiter = GatewayRatelimiter()
|
||||
self._buffer: bytearray = bytearray()
|
||||
self._close_code: Optional[int] = None
|
||||
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
|
||||
|
||||
# attributes that get set in from_client
|
||||
self.token: str = utils.MISSING
|
||||
self._connection: ConnectionState = utils.MISSING
|
||||
self._discord_parsers: Dict[str, DataCallable[None]] = utils.MISSING
|
||||
self.gateway: str = utils.MISSING
|
||||
self.call_hooks: Coro = utils.MISSING
|
||||
self._initial_identify: bool = utils.MISSING
|
||||
self.shard_id: Optional[int] = utils.MISSING
|
||||
self.shard_count: Optional[int] = utils.MISSING
|
||||
self.session_id: Optional[str] = utils.MISSING
|
||||
self._max_heartbeat_timeout: float = utils.MISSING
|
||||
|
||||
@property
|
||||
def open(self):
|
||||
def open(self) -> bool:
|
||||
return not self.socket.closed
|
||||
|
||||
def is_ratelimited(self):
|
||||
def is_ratelimited(self) -> bool:
|
||||
return self._rate_limiter.is_ratelimited()
|
||||
|
||||
def debug_log_receive(self, data, /):
|
||||
def debug_log_receive(self, data, /) -> None:
|
||||
self._dispatch('socket_raw_receive', data)
|
||||
|
||||
def log_receive(self, _, /):
|
||||
def log_receive(self, _, /) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
|
||||
async def from_client(cls: Type[DWS], client: Client, *, initial: bool = False, gateway: Optional[str] = None, shard_id: Optional[int] = None, session: Optional[str] = None, sequence: Optional[int] = None, resume: bool = False) -> DWS:
|
||||
"""Creates a main websocket for Discord from a :class:`Client`.
|
||||
|
||||
This is for internal use only.
|
||||
@ -310,7 +358,9 @@ class DiscordWebSocket:
|
||||
ws = cls(socket, loop=client.loop)
|
||||
|
||||
# dynamically add attributes needed
|
||||
ws.token = client.http.token
|
||||
|
||||
# the token won't be None here
|
||||
ws.token = client.http.token # type: ignore
|
||||
ws._connection = client._connection
|
||||
ws._discord_parsers = client._connection.parsers
|
||||
ws._dispatch = client.dispatch
|
||||
@ -342,7 +392,7 @@ class DiscordWebSocket:
|
||||
await ws.resume()
|
||||
return ws
|
||||
|
||||
def wait_for(self, event, predicate, result=None):
|
||||
def wait_for(self, event: str, predicate: Predicate, result: Result = None) -> asyncio.Future:
|
||||
"""Waits for a DISPATCH'd event that meets the predicate.
|
||||
|
||||
Parameters
|
||||
@ -367,7 +417,7 @@ class DiscordWebSocket:
|
||||
self._dispatch_listeners.append(entry)
|
||||
return future
|
||||
|
||||
async def identify(self):
|
||||
async def identify(self) -> None:
|
||||
"""Sends the IDENTIFY packet."""
|
||||
payload = {
|
||||
'op': self.IDENTIFY,
|
||||
@ -405,7 +455,7 @@ class DiscordWebSocket:
|
||||
await self.send_as_json(payload)
|
||||
_log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
|
||||
|
||||
async def resume(self):
|
||||
async def resume(self) -> None:
|
||||
"""Sends the RESUME packet."""
|
||||
payload = {
|
||||
'op': self.RESUME,
|
||||
@ -419,7 +469,8 @@ class DiscordWebSocket:
|
||||
await self.send_as_json(payload)
|
||||
_log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
|
||||
|
||||
async def received_message(self, msg, /):
|
||||
|
||||
async def received_message(self, msg, /) -> None:
|
||||
if type(msg) is bytes:
|
||||
self._buffer.extend(msg)
|
||||
|
||||
@ -537,16 +588,16 @@ class DiscordWebSocket:
|
||||
del self._dispatch_listeners[index]
|
||||
|
||||
@property
|
||||
def latency(self):
|
||||
def latency(self) -> float:
|
||||
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds."""
|
||||
heartbeat = self._keep_alive
|
||||
return float('inf') if heartbeat is None else heartbeat.latency
|
||||
|
||||
def _can_handle_close(self):
|
||||
def _can_handle_close(self) -> bool:
|
||||
code = self._close_code or self.socket.close_code
|
||||
return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014)
|
||||
|
||||
async def poll_event(self):
|
||||
async def poll_event(self) -> None:
|
||||
"""Polls for a DISPATCH event and handles the general gateway loop.
|
||||
|
||||
Raises
|
||||
@ -584,23 +635,23 @@ class DiscordWebSocket:
|
||||
_log.info('Websocket closed with %s, cannot reconnect.', code)
|
||||
raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None
|
||||
|
||||
async def debug_send(self, data, /):
|
||||
async def debug_send(self, data, /) -> None:
|
||||
await self._rate_limiter.block()
|
||||
self._dispatch('socket_raw_send', data)
|
||||
await self.socket.send_str(data)
|
||||
|
||||
async def send(self, data, /):
|
||||
async def send(self, data, /) -> None:
|
||||
await self._rate_limiter.block()
|
||||
await self.socket.send_str(data)
|
||||
|
||||
async def send_as_json(self, data):
|
||||
async def send_as_json(self, data) -> None:
|
||||
try:
|
||||
await self.send(utils._to_json(data))
|
||||
except RuntimeError as exc:
|
||||
if not self._can_handle_close():
|
||||
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
|
||||
|
||||
async def send_heartbeat(self, data):
|
||||
async def send_heartbeat(self, data: Heartbeat) -> None:
|
||||
# This bypasses the rate limit handling code since it has a higher priority
|
||||
try:
|
||||
await self.socket.send_str(utils._to_json(data))
|
||||
@ -608,13 +659,13 @@ class DiscordWebSocket:
|
||||
if not self._can_handle_close():
|
||||
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
|
||||
|
||||
async def change_presence(self, *, activity=None, status=None, since=0.0):
|
||||
async def change_presence(self, *, activity: Optional[BaseActivity] = None, status: Optional[str] = None, since: float = 0.0) -> None:
|
||||
if activity is not None:
|
||||
if not isinstance(activity, BaseActivity):
|
||||
raise InvalidArgument('activity must derive from BaseActivity.')
|
||||
activity = [activity.to_dict()]
|
||||
activities = [activity.to_dict()]
|
||||
else:
|
||||
activity = []
|
||||
activities = []
|
||||
|
||||
if status == 'idle':
|
||||
since = int(time.time() * 1000)
|
||||
@ -622,7 +673,7 @@ class DiscordWebSocket:
|
||||
payload = {
|
||||
'op': self.PRESENCE,
|
||||
'd': {
|
||||
'activities': activity,
|
||||
'activities': activities,
|
||||
'afk': False,
|
||||
'since': since,
|
||||
'status': status
|
||||
@ -633,7 +684,7 @@ class DiscordWebSocket:
|
||||
_log.debug('Sending "%s" to change status', sent)
|
||||
await self.send(sent)
|
||||
|
||||
async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None):
|
||||
async def request_chunks(self, guild_id: int, query: Optional[str] = None, *, limit: int, user_ids: Optional[List[int]] = None, presences: bool = False, nonce: Optional[int] = None) -> None:
|
||||
payload = {
|
||||
'op': self.REQUEST_MEMBERS,
|
||||
'd': {
|
||||
@ -655,7 +706,7 @@ class DiscordWebSocket:
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
|
||||
async def voice_state(self, guild_id: int, channel_id: int, self_mute: bool = False, self_deaf: bool = False) -> None:
|
||||
payload = {
|
||||
'op': self.VOICE_STATE,
|
||||
'd': {
|
||||
@ -669,7 +720,7 @@ class DiscordWebSocket:
|
||||
_log.debug('Updating our voice state to %s.', payload)
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def close(self, code=4000):
|
||||
async def close(self, code: int = 4000) -> None:
|
||||
if self._keep_alive:
|
||||
self._keep_alive.stop()
|
||||
self._keep_alive = None
|
||||
@ -721,25 +772,31 @@ class DiscordVoiceWebSocket:
|
||||
CLIENT_CONNECT = 12
|
||||
CLIENT_DISCONNECT = 13
|
||||
|
||||
def __init__(self, socket, loop, *, hook=None):
|
||||
self.ws = socket
|
||||
self.loop = loop
|
||||
self._keep_alive = None
|
||||
self._close_code = None
|
||||
self.secret_key = None
|
||||
def __init__(self, socket: aiohttp.ClientWebSocketResponse, loop: asyncio.AbstractEventLoop, *, hook: Optional[Coro] = None) -> None:
|
||||
self.ws: aiohttp.ClientWebSocketResponse = socket
|
||||
self.loop: asyncio.AbstractEventLoop = loop
|
||||
self._keep_alive: VoiceKeepAliveHandler = utils.MISSING
|
||||
self._close_code: Optional[int] = None
|
||||
self.secret_key: Optional[List[int]] = None
|
||||
self.gateway: str = utils.MISSING
|
||||
self._connection: VoiceClient = utils.MISSING
|
||||
self._max_heartbeat_timeout: float = utils.MISSING
|
||||
self.thread_id: int = utils.MISSING
|
||||
if hook:
|
||||
self._hook = hook
|
||||
# we want to redeclare self._hook
|
||||
self._hook = hook # type: ignore
|
||||
|
||||
async def _hook(self, *args):
|
||||
async def _hook(self, *args: Any) -> Any:
|
||||
pass
|
||||
|
||||
async def send_as_json(self, data):
|
||||
|
||||
async def send_as_json(self, data) -> None:
|
||||
_log.debug('Sending voice websocket frame: %s.', data)
|
||||
await self.ws.send_str(utils._to_json(data))
|
||||
|
||||
send_heartbeat = send_as_json
|
||||
|
||||
async def resume(self):
|
||||
async def resume(self) -> None:
|
||||
state = self._connection
|
||||
payload = {
|
||||
'op': self.RESUME,
|
||||
@ -765,7 +822,7 @@ class DiscordVoiceWebSocket:
|
||||
await self.send_as_json(payload)
|
||||
|
||||
@classmethod
|
||||
async def from_client(cls, client, *, resume=False, hook=None):
|
||||
async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume: bool = False, hook: Optional[Coro] = None) -> DVWS:
|
||||
"""Creates a voice websocket for the :class:`VoiceClient`."""
|
||||
gateway = 'wss://' + client.endpoint + '/?v=4'
|
||||
http = client._state.http
|
||||
@ -783,7 +840,7 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
return ws
|
||||
|
||||
async def select_protocol(self, ip, port, mode):
|
||||
async def select_protocol(self, ip, port, mode) -> None:
|
||||
payload = {
|
||||
'op': self.SELECT_PROTOCOL,
|
||||
'd': {
|
||||
@ -798,7 +855,7 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def client_connect(self):
|
||||
async def client_connect(self) -> None:
|
||||
payload = {
|
||||
'op': self.CLIENT_CONNECT,
|
||||
'd': {
|
||||
@ -808,7 +865,7 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def speak(self, state=SpeakingState.voice):
|
||||
async def speak(self, state=SpeakingState.voice) -> None:
|
||||
payload = {
|
||||
'op': self.SPEAKING,
|
||||
'd': {
|
||||
@ -819,7 +876,8 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def received_message(self, msg):
|
||||
|
||||
async def received_message(self, msg) -> None:
|
||||
_log.debug('Voice websocket frame received: %s', msg)
|
||||
op = msg['op']
|
||||
data = msg.get('d')
|
||||
@ -840,7 +898,7 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
await self._hook(self, msg)
|
||||
|
||||
async def initial_connection(self, data):
|
||||
async def initial_connection(self, data) -> None:
|
||||
state = self._connection
|
||||
state.ssrc = data['ssrc']
|
||||
state.voice_port = data['port']
|
||||
@ -871,13 +929,13 @@ class DiscordVoiceWebSocket:
|
||||
_log.info('selected the voice protocol for use (%s)', mode)
|
||||
|
||||
@property
|
||||
def latency(self):
|
||||
def latency(self) -> float:
|
||||
""":class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds."""
|
||||
heartbeat = self._keep_alive
|
||||
return float('inf') if heartbeat is None else heartbeat.latency
|
||||
|
||||
@property
|
||||
def average_latency(self):
|
||||
def average_latency(self) -> float:
|
||||
""":class:`list`: Average of last 20 HEARTBEAT latencies."""
|
||||
heartbeat = self._keep_alive
|
||||
if heartbeat is None or not heartbeat.recent_ack_latencies:
|
||||
@ -885,13 +943,14 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
|
||||
|
||||
async def load_secret_key(self, data):
|
||||
|
||||
async def load_secret_key(self, data) -> None:
|
||||
_log.info('received secret key for voice connection')
|
||||
self.secret_key = self._connection.secret_key = data.get('secret_key')
|
||||
await self.speak()
|
||||
await self.speak(False)
|
||||
|
||||
async def poll_event(self):
|
||||
async def poll_event(self) -> None:
|
||||
# This exception is handled up the chain
|
||||
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
|
||||
if msg.type is aiohttp.WSMsgType.TEXT:
|
||||
@ -903,7 +962,7 @@ class DiscordVoiceWebSocket:
|
||||
_log.debug('Received %s', msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
|
||||
|
||||
async def close(self, code=1000):
|
||||
async def close(self, code: int = 1000) -> None:
|
||||
if self._keep_alive is not None:
|
||||
self._keep_alive.stop()
|
||||
|
||||
|
120
discord/guild.py
120
discord/guild.py
@ -46,7 +46,7 @@ from . import utils, abc
|
||||
from .role import Role
|
||||
from .member import Member, VoiceState
|
||||
from .emoji import Emoji
|
||||
from .errors import InvalidData
|
||||
from .errors import InvalidData, NotFound
|
||||
from .permissions import PermissionOverwrite
|
||||
from .colour import Colour
|
||||
from .errors import InvalidArgument, ClientException
|
||||
@ -76,6 +76,7 @@ from .stage_instance import StageInstance
|
||||
from .threads import Thread, ThreadMember
|
||||
from .sticker import GuildSticker
|
||||
from .file import File
|
||||
from .welcome_screen import WelcomeScreen, WelcomeChannel
|
||||
|
||||
|
||||
__all__ = (
|
||||
@ -140,6 +141,10 @@ class Guild(Hashable):
|
||||
|
||||
Returns the guild's name.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the guild's ID.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name: :class:`str`
|
||||
@ -738,12 +743,16 @@ class Guild(Hashable):
|
||||
|
||||
@property
|
||||
def humans(self) -> List[Member]:
|
||||
"""List[:class:`Member`]: A list of human members that belong to this guild."""
|
||||
"""List[:class:`Member`]: A list of human members that belong to this guild.
|
||||
|
||||
.. versionadded:: 2.0 """
|
||||
return [member for member in self.members if not member.bot]
|
||||
|
||||
@property
|
||||
def bots(self) -> List[Member]:
|
||||
"""List[:class:`Member`]: A list of bots that belong to this guild."""
|
||||
"""List[:class:`Member`]: A list of bots that belong to this guild.
|
||||
|
||||
.. versionadded:: 2.0 """
|
||||
return [member for member in self.members if member.bot]
|
||||
|
||||
def get_member(self, user_id: int, /) -> Optional[Member]:
|
||||
@ -1715,6 +1724,8 @@ class Guild(Hashable):
|
||||
You do not have access to the guild.
|
||||
HTTPException
|
||||
Fetching the member failed.
|
||||
NotFound
|
||||
A member with that ID does not exist.
|
||||
|
||||
Returns
|
||||
--------
|
||||
@ -1724,6 +1735,34 @@ class Guild(Hashable):
|
||||
data = await self._state.http.get_member(self.id, member_id)
|
||||
return Member(data=data, state=self._state, guild=self)
|
||||
|
||||
async def try_member(self, member_id: int, /) -> Optional[Member]:
|
||||
"""|coro|
|
||||
|
||||
Returns a member with the given ID. This uses the cache first, and if not found, it'll request using :meth:`fetch_member`.
|
||||
|
||||
.. note::
|
||||
This method might result in an API call.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
member_id: :class:`int`
|
||||
The ID to search for.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`Member`]
|
||||
The member or ``None`` if not found.
|
||||
"""
|
||||
member = self.get_member(member_id)
|
||||
|
||||
if member:
|
||||
return member
|
||||
else:
|
||||
try:
|
||||
return await self.fetch_member(member_id)
|
||||
except NotFound:
|
||||
return None
|
||||
|
||||
async def fetch_ban(self, user: Snowflake) -> BanEntry:
|
||||
"""|coro|
|
||||
|
||||
@ -2566,6 +2605,81 @@ class Guild(Hashable):
|
||||
|
||||
return roles
|
||||
|
||||
async def welcome_screen(self) -> WelcomeScreen:
|
||||
"""|coro|
|
||||
|
||||
Returns the guild's welcome screen.
|
||||
|
||||
The guild must have ``COMMUNITY`` in :attr:`~Guild.features`.
|
||||
|
||||
You must have the :attr:`~Permissions.manage_guild` permission to use
|
||||
this as well.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Raises
|
||||
-------
|
||||
Forbidden
|
||||
You do not have the proper permissions to get this.
|
||||
HTTPException
|
||||
Retrieving the welcome screen failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`WelcomeScreen`
|
||||
The welcome screen.
|
||||
"""
|
||||
data = await self._state.http.get_welcome_screen(self.id)
|
||||
return WelcomeScreen(data=data, guild=self)
|
||||
|
||||
@overload
|
||||
async def edit_welcome_screen(
|
||||
self,
|
||||
*,
|
||||
description: Optional[str] = ...,
|
||||
welcome_channels: Optional[List[WelcomeChannel]] = ...,
|
||||
enabled: Optional[bool] = ...,
|
||||
) -> WelcomeScreen:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def edit_welcome_screen(self) -> None:
|
||||
...
|
||||
|
||||
async def edit_welcome_screen(self, **kwargs):
|
||||
"""|coro|
|
||||
|
||||
A shorthand method of :attr:`WelcomeScreen.edit` without needing
|
||||
to fetch the welcome screen beforehand.
|
||||
|
||||
The guild must have ``COMMUNITY`` in :attr:`~Guild.features`.
|
||||
|
||||
You must have the :attr:`~Permissions.manage_guild` permission to use
|
||||
this as well.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`WelcomeScreen`
|
||||
The edited welcome screen.
|
||||
"""
|
||||
try:
|
||||
welcome_channels = kwargs['welcome_channels']
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
welcome_channels_serialised = []
|
||||
for wc in welcome_channels:
|
||||
if not isinstance(wc, WelcomeChannel):
|
||||
raise InvalidArgument('welcome_channels parameter must be a list of WelcomeChannel')
|
||||
welcome_channels_serialised.append(wc.to_dict())
|
||||
kwargs['welcome_channels'] = welcome_channels_serialised
|
||||
|
||||
if kwargs:
|
||||
data = await self._state.http.edit_welcome_screen(self.id, kwargs)
|
||||
return WelcomeScreen(data=data, guild=self)
|
||||
|
||||
async def kick(self, user: Snowflake, *, reason: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
|
@ -84,6 +84,7 @@ if TYPE_CHECKING:
|
||||
threads,
|
||||
voice,
|
||||
sticker,
|
||||
welcome_screen,
|
||||
)
|
||||
from .types.snowflake import Snowflake, SnowflakeList
|
||||
|
||||
@ -1116,6 +1117,20 @@ class HTTPClient:
|
||||
payload['icon'] = icon
|
||||
return self.request(Route('POST', '/guilds/templates/{code}', code=code), json=payload)
|
||||
|
||||
def get_welcome_screen(self, guild_id: Snowflake) -> Response[welcome_screen.WelcomeScreen]:
|
||||
return self.request(Route('GET', '/guilds/{guild_id}/welcome-screen', guild_id=guild_id))
|
||||
|
||||
def edit_welcome_screen(self, guild_id: Snowflake, payload: Any) -> Response[welcome_screen.WelcomeScreen]:
|
||||
valid_keys = (
|
||||
'description',
|
||||
'welcome_channels',
|
||||
'enabled',
|
||||
)
|
||||
payload = {
|
||||
k: v for k, v in payload.items() if k in valid_keys
|
||||
}
|
||||
return self.request(Route('PATCH', '/guilds/{guild_id}/welcome-screen', guild_id=guild_id), json=payload)
|
||||
|
||||
def get_bans(self, guild_id: Snowflake) -> Response[List[guild.Ban]]:
|
||||
return self.request(Route('GET', '/guilds/{guild_id}/bans', guild_id=guild_id))
|
||||
|
||||
|
@ -230,6 +230,7 @@ class Invite(Hashable):
|
||||
|
||||
Returns the invite URL.
|
||||
|
||||
|
||||
The following table illustrates what methods will obtain the attributes:
|
||||
|
||||
+------------------------------------+------------------------------------------------------------+
|
||||
@ -433,6 +434,9 @@ class Invite(Hashable):
|
||||
def __str__(self) -> str:
|
||||
return self.url
|
||||
|
||||
def __int__(self) -> int:
|
||||
return 0 # To keep the object compatible with the hashable abc.
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<Invite code={self.code!r} guild={self.guild!r} '
|
||||
|
@ -226,6 +226,10 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
|
||||
Returns the member's name with the discriminator.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the user's ID.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
joined_at: Optional[:class:`datetime.datetime`]
|
||||
@ -300,6 +304,9 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
def __str__(self) -> str:
|
||||
return str(self._user)
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}'
|
||||
|
@ -125,6 +125,10 @@ class Attachment(Hashable):
|
||||
|
||||
Returns the hash of the attachment.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the attachment's ID.
|
||||
|
||||
.. versionchanged:: 1.7
|
||||
Attachment can now be casted to :class:`str` and is hashable.
|
||||
|
||||
@ -503,6 +507,14 @@ class Message(Hashable):
|
||||
|
||||
Returns the message's hash.
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the message's content.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the message's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
tts: :class:`bool`
|
||||
@ -712,6 +724,10 @@ class Message(Hashable):
|
||||
f'<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>'
|
||||
)
|
||||
|
||||
|
||||
def __str__(self) -> Optional[str]:
|
||||
return self.content
|
||||
|
||||
def _try_patch(self, data, key, transform=None) -> None:
|
||||
try:
|
||||
value = data[key]
|
||||
@ -1634,6 +1650,10 @@ class PartialMessage(Hashable):
|
||||
|
||||
Returns the partial message's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the partial message's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`]
|
||||
|
@ -43,5 +43,8 @@ class EqualityComparable:
|
||||
class Hashable(EqualityComparable):
|
||||
__slots__ = ()
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.id >> 22
|
||||
|
@ -69,6 +69,10 @@ class Object(Hashable):
|
||||
|
||||
Returns the object's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the object's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
id: :class:`int`
|
||||
|
@ -299,6 +299,13 @@ class Permissions(BaseFlags):
|
||||
"""
|
||||
return 1 << 3
|
||||
|
||||
@make_permission_alias('administrator')
|
||||
def admin(self) -> int:
|
||||
""":class:`bool`: An alias for :attr:`administrator`.
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return 1 << 3
|
||||
|
||||
@flag_value
|
||||
def manage_channels(self) -> int:
|
||||
""":class:`bool`: Returns ``True`` if a user can edit, delete, or create channels in the guild.
|
||||
|
@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
@ -63,10 +64,7 @@ __all__ = (
|
||||
|
||||
CREATE_NO_WINDOW: int
|
||||
|
||||
if sys.platform != 'win32':
|
||||
CREATE_NO_WINDOW = 0
|
||||
else:
|
||||
CREATE_NO_WINDOW = 0x08000000
|
||||
CREATE_NO_WINDOW = 0 if sys.platform != 'win32' else 0x08000000
|
||||
|
||||
class AudioSource:
|
||||
"""Represents an audio stream.
|
||||
@ -526,7 +524,12 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
|
||||
@staticmethod
|
||||
def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]:
|
||||
exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable
|
||||
exe = (
|
||||
executable[:2] + 'probe'
|
||||
if executable in {'ffmpeg', 'avconv'}
|
||||
else executable
|
||||
)
|
||||
|
||||
args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source]
|
||||
output = subprocess.check_output(args, timeout=20)
|
||||
codec = bitrate = None
|
||||
|
@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING, Optional, Set, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -34,7 +35,8 @@ if TYPE_CHECKING:
|
||||
MessageUpdateEvent,
|
||||
ReactionClearEvent,
|
||||
ReactionClearEmojiEvent,
|
||||
IntegrationDeleteEvent
|
||||
IntegrationDeleteEvent,
|
||||
TypingEvent
|
||||
)
|
||||
from .message import Message
|
||||
from .partial_emoji import PartialEmoji
|
||||
@ -49,6 +51,7 @@ __all__ = (
|
||||
'RawReactionClearEvent',
|
||||
'RawReactionClearEmojiEvent',
|
||||
'RawIntegrationDeleteEvent',
|
||||
'RawTypingEvent'
|
||||
)
|
||||
|
||||
|
||||
@ -276,3 +279,36 @@ class RawIntegrationDeleteEvent(_RawReprMixin):
|
||||
self.application_id: Optional[int] = int(data['application_id'])
|
||||
except KeyError:
|
||||
self.application_id: Optional[int] = None
|
||||
|
||||
|
||||
class RawTypingEvent(_RawReprMixin):
|
||||
"""Represents the payload for a :func:`on_raw_typing` event.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
channel_id: :class:`int`
|
||||
The channel ID where the typing originated from.
|
||||
user_id: :class:`int`
|
||||
The ID of the user that started typing.
|
||||
when: :class:`datetime.datetime`
|
||||
When the typing started as an aware datetime in UTC.
|
||||
guild_id: Optional[:class:`int`]
|
||||
The guild ID where the typing originated from, if applicable.
|
||||
member: Optional[:class:`Member`]
|
||||
The member who started typing. Only available if the member started typing in a guild.
|
||||
"""
|
||||
|
||||
__slots__ = ("channel_id", "user_id", "when", "guild_id", "member")
|
||||
|
||||
def __init__(self, data: TypingEvent) -> None:
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
self.user_id: int = int(data['user_id'])
|
||||
self.when: datetime.datetime = datetime.datetime.fromtimestamp(data.get('timestamp'), tz=datetime.timezone.utc)
|
||||
self.member: Optional[Member] = None
|
||||
|
||||
try:
|
||||
self.guild_id: Optional[int] = int(data['guild_id'])
|
||||
except KeyError:
|
||||
self.guild_id: Optional[int] = None
|
@ -141,6 +141,14 @@ class Role(Hashable):
|
||||
|
||||
Returns the role's name.
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the role's ID.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the role's ID.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
id: :class:`int`
|
||||
@ -195,6 +203,9 @@ class Role(Hashable):
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Role id={self.id} name={self.name!r}>'
|
||||
|
||||
|
@ -61,6 +61,10 @@ class StageInstance(Hashable):
|
||||
|
||||
Returns the stage instance's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the stage instance's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
id: :class:`int`
|
||||
|
@ -1327,28 +1327,37 @@ class ConnectionState:
|
||||
asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler'))
|
||||
|
||||
def parse_typing_start(self, data) -> None:
|
||||
raw = RawTypingEvent(data)
|
||||
|
||||
member_data = data.get('member')
|
||||
if member_data:
|
||||
guild = self._get_guild(raw.guild_id)
|
||||
if guild is not None:
|
||||
raw.member = Member(data=member_data, guild=guild, state=self)
|
||||
else:
|
||||
raw.member = None
|
||||
else:
|
||||
raw.member = None
|
||||
self.dispatch('raw_typing', raw)
|
||||
|
||||
channel, guild = self._get_guild_channel(data)
|
||||
if channel is not None:
|
||||
member = None
|
||||
user_id = utils._get_as_snowflake(data, 'user_id')
|
||||
if isinstance(channel, DMChannel):
|
||||
member = channel.recipient
|
||||
user = raw.member or self._get_typing_user(channel, raw.user_id)
|
||||
|
||||
elif isinstance(channel, (Thread, TextChannel)) and guild is not None:
|
||||
# user_id won't be None
|
||||
member = guild.get_member(user_id) # type: ignore
|
||||
if user is not None:
|
||||
self.dispatch('typing', channel, user, raw.when)
|
||||
|
||||
if member is None:
|
||||
member_data = data.get('member')
|
||||
if member_data:
|
||||
member = Member(data=member_data, state=self, guild=guild)
|
||||
def _get_typing_user(self, channel: Optional[MessageableChannel], user_id: int) -> Optional[Union[User, Member]]:
|
||||
if isinstance(channel, DMChannel):
|
||||
return channel.recipient
|
||||
|
||||
elif isinstance(channel, GroupChannel):
|
||||
member = utils.find(lambda x: x.id == user_id, channel.recipients)
|
||||
elif isinstance(channel, (Thread, TextChannel)) and channel.guild is not None:
|
||||
return channel.guild.get_member(user_id) # type: ignore
|
||||
|
||||
if member is not None:
|
||||
timestamp = datetime.datetime.fromtimestamp(data.get('timestamp'), tz=datetime.timezone.utc)
|
||||
self.dispatch('typing', channel, member, timestamp)
|
||||
elif isinstance(channel, GroupChannel):
|
||||
return utils.find(lambda x: x.id == user_id, channel.recipients)
|
||||
|
||||
return self.get_user(user_id)
|
||||
|
||||
def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]:
|
||||
if isinstance(channel, TextChannel):
|
||||
|
@ -67,6 +67,14 @@ class StickerPack(Hashable):
|
||||
|
||||
Returns the name of the sticker pack.
|
||||
|
||||
.. describe:: hash(x)
|
||||
|
||||
Returns the hash of the sticker pack.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the ID of the sticker pack.
|
||||
|
||||
.. describe:: x == y
|
||||
|
||||
Checks if the sticker pack is equal to another sticker pack.
|
||||
|
@ -74,6 +74,10 @@ class Thread(Messageable, Hashable):
|
||||
|
||||
Returns the thread's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the thread's ID.
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the thread's name.
|
||||
@ -748,6 +752,10 @@ class ThreadMember(Hashable):
|
||||
|
||||
Returns the thread member's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the thread member's ID.
|
||||
|
||||
.. describe:: str(x)
|
||||
|
||||
Returns the thread member's name.
|
||||
@ -800,3 +808,39 @@ class ThreadMember(Hashable):
|
||||
def thread(self) -> Thread:
|
||||
""":class:`Thread`: The thread this member belongs to."""
|
||||
return self.parent
|
||||
|
||||
async def fetch_member(self) -> Member:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a :class:`Member` from the ThreadMember object.
|
||||
|
||||
.. note::
|
||||
|
||||
This method is an API call. If you have :attr:`Intents.members` and member cache enabled, consider :meth:`get_member` instead.
|
||||
|
||||
Raises
|
||||
-------
|
||||
Forbidden
|
||||
You do not have access to the guild.
|
||||
HTTPException
|
||||
Fetching the member failed.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Member`
|
||||
The member.
|
||||
"""
|
||||
|
||||
return await self.thread.guild.fetch_member(self.id)
|
||||
|
||||
def get_member(self) -> Optional[Member]:
|
||||
"""
|
||||
Get the :class:`Member` from cache for the ThreadMember object.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Optional[:class:`Member`]
|
||||
The member or ``None`` if not found.
|
||||
"""
|
||||
|
||||
return self.thread.guild.get_member(self.id)
|
||||
|
@ -85,3 +85,13 @@ class _IntegrationDeleteEventOptional(TypedDict, total=False):
|
||||
class IntegrationDeleteEvent(_IntegrationDeleteEventOptional):
|
||||
id: Snowflake
|
||||
guild_id: Snowflake
|
||||
|
||||
|
||||
class _TypingEventOptional(TypedDict, total=False):
|
||||
guild_id: Snowflake
|
||||
member: Member
|
||||
|
||||
class TypingEvent(_TypingEventOptional):
|
||||
channel_id: Snowflake
|
||||
user_id: Snowflake
|
||||
timestamp: int
|
@ -185,16 +185,16 @@ class Button(Item[V]):
|
||||
|
||||
@emoji.setter
|
||||
def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore
|
||||
if value is not None:
|
||||
if isinstance(value, str):
|
||||
self._underlying.emoji = PartialEmoji.from_str(value)
|
||||
elif isinstance(value, _EmojiTag):
|
||||
self._underlying.emoji = value._to_partial()
|
||||
else:
|
||||
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead')
|
||||
else:
|
||||
if value is None:
|
||||
self._underlying.emoji = None
|
||||
|
||||
elif isinstance(value, str):
|
||||
self._underlying.emoji = PartialEmoji.from_str(value)
|
||||
elif isinstance(value, _EmojiTag):
|
||||
self._underlying.emoji = value._to_partial()
|
||||
else:
|
||||
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead')
|
||||
|
||||
@classmethod
|
||||
def from_component(cls: Type[B], button: ButtonComponent) -> B:
|
||||
return cls(
|
||||
|
@ -96,6 +96,9 @@ class BaseUser(_UserTag):
|
||||
def __str__(self) -> str:
|
||||
return f'{self.name}#{self.discriminator}'
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, _UserTag) and other.id == self.id
|
||||
|
||||
@ -415,6 +418,10 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
|
||||
Returns the user's name with discriminator.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the user's ID.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
|
@ -499,14 +499,14 @@ else:
|
||||
|
||||
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
|
||||
reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After')
|
||||
if use_clock or not reset_after:
|
||||
utc = datetime.timezone.utc
|
||||
now = datetime.datetime.now(utc)
|
||||
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc)
|
||||
return (reset - now).total_seconds()
|
||||
else:
|
||||
if not use_clock and reset_after:
|
||||
return float(reset_after)
|
||||
|
||||
utc = datetime.timezone.utc
|
||||
now = datetime.datetime.now(utc)
|
||||
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc)
|
||||
return (reset - now).total_seconds()
|
||||
|
||||
|
||||
async def maybe_coroutine(f, *args, **kwargs):
|
||||
value = f(*args, **kwargs)
|
||||
@ -659,11 +659,10 @@ def resolve_invite(invite: Union[Invite, str]) -> str:
|
||||
|
||||
if isinstance(invite, Invite):
|
||||
return invite.code
|
||||
else:
|
||||
rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)'
|
||||
m = re.match(rx, invite)
|
||||
if m:
|
||||
return m.group(1)
|
||||
rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)'
|
||||
m = re.match(rx, invite)
|
||||
if m:
|
||||
return m.group(1)
|
||||
return invite
|
||||
|
||||
|
||||
@ -687,11 +686,10 @@ def resolve_template(code: Union[Template, str]) -> str:
|
||||
|
||||
if isinstance(code, Template):
|
||||
return code.code
|
||||
else:
|
||||
rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)'
|
||||
m = re.match(rx, code)
|
||||
if m:
|
||||
return m.group(1)
|
||||
rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)'
|
||||
m = re.match(rx, code)
|
||||
if m:
|
||||
return m.group(1)
|
||||
return code
|
||||
|
||||
|
||||
@ -1017,3 +1015,9 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None)
|
||||
if style is None:
|
||||
return f'<t:{int(dt.timestamp())}>'
|
||||
return f'<t:{int(dt.timestamp())}:{style}>'
|
||||
|
||||
|
||||
def raise_expected_coro(coro, error: str)-> TypeError:
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise TypeError(error)
|
||||
return coro
|
||||
|
@ -255,6 +255,9 @@ class VoiceClient(VoiceProtocol):
|
||||
self.encoder: Encoder = MISSING
|
||||
self._lite_nonce: int = 0
|
||||
self.ws: DiscordVoiceWebSocket = MISSING
|
||||
self.ip: str = MISSING
|
||||
self.port: Tuple[Any, ...] = MISSING
|
||||
|
||||
|
||||
warn_nacl = not has_nacl
|
||||
supported_modes: Tuple[SupportedModes, ...] = (
|
||||
|
@ -886,6 +886,10 @@ class Webhook(BaseWebhook):
|
||||
|
||||
Returns the webhooks's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the webhooks's ID.
|
||||
|
||||
.. versionchanged:: 1.4
|
||||
Webhooks are now comparable and hashable.
|
||||
|
||||
|
@ -475,6 +475,10 @@ class SyncWebhook(BaseWebhook):
|
||||
|
||||
Returns the webhooks's hash.
|
||||
|
||||
.. describe:: int(x)
|
||||
|
||||
Returns the webhooks's ID.
|
||||
|
||||
.. versionchanged:: 1.4
|
||||
Webhooks are now comparable and hashable.
|
||||
|
||||
|
216
discord/welcome_screen.py
Normal file
216
discord/welcome_screen.py
Normal file
@ -0,0 +1,216 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-present Rapptz
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING, Union, overload
|
||||
from .utils import _get_as_snowflake, get
|
||||
from .errors import InvalidArgument
|
||||
from .partial_emoji import _EmojiTag
|
||||
|
||||
__all__ = (
|
||||
'WelcomeChannel',
|
||||
'WelcomeScreen',
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .types.welcome_screen import (
|
||||
WelcomeScreen as WelcomeScreenPayload,
|
||||
WelcomeScreenChannel as WelcomeScreenChannelPayload,
|
||||
)
|
||||
from .abc import Snowflake
|
||||
from .guild import Guild
|
||||
from .partial_emoji import PartialEmoji
|
||||
from .emoji import Emoji
|
||||
|
||||
|
||||
class WelcomeChannel:
|
||||
"""Represents a :class:`WelcomeScreen` welcome channel.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
channel: :class:`abc.Snowflake`
|
||||
The guild channel that is being referenced.
|
||||
description: :class:`str`
|
||||
The description shown of the channel.
|
||||
emoji: Optional[:class:`PartialEmoji`, :class:`Emoji`, :class:`str`]
|
||||
The emoji used beside the channel description.
|
||||
"""
|
||||
|
||||
def __init__(self, *, channel: Snowflake, description: str, emoji: Union[PartialEmoji, Emoji, str] = None):
|
||||
self.channel = channel
|
||||
self.description = description
|
||||
self.emoji = emoji
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<WelcomeChannel channel={self.channel!r} description={self.description!r} emoji={self.emoji!r}>'
|
||||
|
||||
@classmethod
|
||||
def _from_dict(cls, *, data: WelcomeScreenChannelPayload, guild: Guild) -> WelcomeChannel:
|
||||
channel_id = _get_as_snowflake(data, 'channel_id')
|
||||
channel = guild.get_channel(channel_id)
|
||||
description = data['description']
|
||||
_emoji_id = _get_as_snowflake(data, 'emoji_id')
|
||||
_emoji_name = data['emoji_name']
|
||||
|
||||
if _emoji_id:
|
||||
# custom
|
||||
emoji = get(guild.emojis, id=_emoji_id)
|
||||
else:
|
||||
# unicode or None
|
||||
emoji = _emoji_name
|
||||
|
||||
return cls(channel=channel, description=description, emoji=emoji) # type: ignore
|
||||
|
||||
def to_dict(self) -> WelcomeScreenChannelPayload:
|
||||
ret: WelcomeScreenChannelPayload = {
|
||||
'channel_id': self.channel.id,
|
||||
'description': self.description,
|
||||
'emoji_id': None,
|
||||
'emoji_name': None,
|
||||
}
|
||||
|
||||
if isinstance(self.emoji, _EmojiTag):
|
||||
ret['emoji_id'] = self.emoji.id # type: ignore
|
||||
ret['emoji_name'] = self.emoji.name # type: ignore
|
||||
else:
|
||||
# unicode or None
|
||||
ret['emoji_name'] = self.emoji
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class WelcomeScreen:
|
||||
"""Represents a :class:`Guild` welcome screen.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
description: :class:`str`
|
||||
The description shown on the welcome screen.
|
||||
welcome_channels: List[:class:`WelcomeChannel`]
|
||||
The channels shown on the welcome screen.
|
||||
"""
|
||||
|
||||
def __init__(self, *, data: WelcomeScreenPayload, guild: Guild):
|
||||
self._state = guild._state
|
||||
self._guild = guild
|
||||
self._store(data)
|
||||
|
||||
def _store(self, data: WelcomeScreenPayload) -> None:
|
||||
self.description = data['description']
|
||||
welcome_channels = data.get('welcome_channels', [])
|
||||
self.welcome_channels = [WelcomeChannel._from_dict(data=wc, guild=self._guild) for wc in welcome_channels]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<WelcomeScreen description={self.description!r} welcome_channels={self.welcome_channels!r} enabled={self.enabled}>'
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
""":class:`bool`: Whether the welcome screen is displayed.
|
||||
|
||||
This is equivalent to checking if ``WELCOME_SCREEN_ENABLED``
|
||||
is present in :attr:`Guild.features`.
|
||||
"""
|
||||
return 'WELCOME_SCREEN_ENABLED' in self._guild.features
|
||||
|
||||
@overload
|
||||
async def edit(
|
||||
self,
|
||||
*,
|
||||
description: Optional[str] = ...,
|
||||
welcome_channels: Optional[List[WelcomeChannel]] = ...,
|
||||
enabled: Optional[bool] = ...,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def edit(self) -> None:
|
||||
...
|
||||
|
||||
async def edit(self, **kwargs):
|
||||
"""|coro|
|
||||
|
||||
Edit the welcome screen.
|
||||
|
||||
You must have the :attr:`~Permissions.manage_guild` permission in the
|
||||
guild to do this.
|
||||
|
||||
Usage: ::
|
||||
|
||||
rules_channel = guild.get_channel(12345678)
|
||||
announcements_channel = guild.get_channel(87654321)
|
||||
|
||||
custom_emoji = utils.get(guild.emojis, name='loudspeaker')
|
||||
|
||||
await welcome_screen.edit(
|
||||
description='This is a very cool community server!',
|
||||
welcome_channels=[
|
||||
WelcomeChannel(channel=rules_channel, description='Read the rules!', emoji='👨🏫'),
|
||||
WelcomeChannel(channel=announcements_channel, description='Watch out for announcements!', emoji=custom_emoji),
|
||||
]
|
||||
)
|
||||
|
||||
.. note::
|
||||
|
||||
Welcome channels can only accept custom emojis if :attr:`~Guild.premium_tier` is level 2 or above.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
description: Optional[:class:`str`]
|
||||
The template's description.
|
||||
welcome_channels: Optional[List[:class:`WelcomeChannel`]]
|
||||
The welcome channels, in their respective order.
|
||||
enabled: Optional[:class:`bool`]
|
||||
Whether the welcome screen should be displayed.
|
||||
|
||||
Raises
|
||||
-------
|
||||
HTTPException
|
||||
Editing the welcome screen failed failed.
|
||||
Forbidden
|
||||
You don't have permissions to edit the welcome screen.
|
||||
NotFound
|
||||
This welcome screen does not exist.
|
||||
"""
|
||||
try:
|
||||
welcome_channels = kwargs['welcome_channels']
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
welcome_channels_serialised = []
|
||||
for wc in welcome_channels:
|
||||
if not isinstance(wc, WelcomeChannel):
|
||||
raise InvalidArgument('welcome_channels parameter must be a list of WelcomeChannel')
|
||||
welcome_channels_serialised.append(wc.to_dict())
|
||||
kwargs['welcome_channels'] = welcome_channels_serialised
|
||||
|
||||
if kwargs:
|
||||
data = await self._state.http.edit_welcome_screen(self._guild.id, kwargs)
|
||||
self._store(data)
|
35
docs/api.rst
35
docs/api.rst
@ -369,6 +369,17 @@ to handle it, which defaults to print a traceback and ignoring the exception.
|
||||
:param when: When the typing started as an aware datetime in UTC.
|
||||
:type when: :class:`datetime.datetime`
|
||||
|
||||
.. function:: on_raw_typing(payload)
|
||||
|
||||
Called when someone begins typing a message. Unlike :func:`on_typing`, this is
|
||||
called regardless if the user can be found or not. This most often happens
|
||||
when a user types in DMs.
|
||||
|
||||
This requires :attr:`Intents.typing` to be enabled.
|
||||
|
||||
:param payload: The raw typing payload.
|
||||
:type payload: :class:`RawTypingEvent`
|
||||
|
||||
.. function:: on_message(message)
|
||||
|
||||
Called when a :class:`Message` is created and sent.
|
||||
@ -3781,6 +3792,22 @@ Template
|
||||
.. autoclass:: Template()
|
||||
:members:
|
||||
|
||||
WelcomeScreen
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
.. attributetable:: WelcomeScreen
|
||||
|
||||
.. autoclass:: WelcomeScreen()
|
||||
:members:
|
||||
|
||||
WelcomeChannel
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
.. attributetable:: WelcomeChannel
|
||||
|
||||
.. autoclass:: WelcomeChannel()
|
||||
:members:
|
||||
|
||||
WidgetChannel
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
@ -3846,6 +3873,14 @@ GuildSticker
|
||||
.. autoclass:: GuildSticker()
|
||||
:members:
|
||||
|
||||
RawTypingEvent
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. attributetable:: RawTypingEvent
|
||||
|
||||
.. autoclass:: RawTypingEvent()
|
||||
:members:
|
||||
|
||||
RawMessageDeleteEvent
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user