Merge branch 'iDevision:2.0' into 2.0

This commit is contained in:
Gnome! 2021-09-05 15:57:14 +01:00 committed by GitHub
commit 6af5399936
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 1033 additions and 247 deletions

View File

@ -1,5 +1,7 @@
## Contributing to discord.py
Credits to the `original lib` by Rapptz <https://github.com/Rapptz/discord.py>
First off, thanks for taking the time to contribute. It makes the library substantially better. :+1:
The following is a set of guidelines for contributing to the repository. These are guidelines, not hard rules.
@ -8,9 +10,9 @@ The following is a set of guidelines for contributing to the repository. These a
Generally speaking questions are better suited in our resources below.
- The official support server: https://discord.gg/r3sSKJJ
- The official support server: https://discord.gg/TvqYBrGXEm
- The Discord API server under #python_discord-py: https://discord.gg/discord-api
- [The FAQ in the documentation](https://discordpy.readthedocs.io/en/latest/faq.html)
- [The FAQ in the documentation](https://enhanced-dpy.readthedocs.io/en/latest/faq.html)
- [StackOverflow's `discord.py` tag](https://stackoverflow.com/questions/tagged/discord.py)
Please try your best not to ask questions in our issue tracker. Most of them don't belong there unless they provide value to a larger audience.

View File

@ -1,5 +1,5 @@
discord.py
==========
enhanced-discord.py
===================
.. image:: https://discord.com/api/guilds/514232441498763279/embed.png
:target: https://discord.gg/PYAfZzpsjG
@ -59,7 +59,7 @@ To install the development version, do the following:
.. code:: sh
$ git clone https://github.com/iDevision/enhanced-discord.py
$ cd discord.py
$ cd enhanced-discord.py
$ python3 -m pip install -U .[voice]

View File

@ -40,6 +40,7 @@ from .colour import *
from .integrations import *
from .invite import *
from .template import *
from .welcome_screen import *
from .widget import *
from .object import *
from .reaction import *

View File

@ -794,12 +794,12 @@ class CustomActivity(BaseActivity):
return hash((self.name, str(self.emoji)))
def __str__(self) -> str:
if self.emoji:
if not self.emoji:
return str(self.name)
if self.name:
return f'{self.emoji} {self.name}'
return str(self.emoji)
else:
return str(self.name)
def __repr__(self) -> str:
return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>'

View File

@ -313,7 +313,8 @@ class Asset(AssetMixin):
if self._animated:
if format not in VALID_ASSET_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}')
else:
url = url.with_path(f'{path}.{format}')
elif static_format is MISSING:
if format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
url = url.with_path(f'{path}.{format}')

View File

@ -330,6 +330,10 @@ class AuditLogEntry(Hashable):
Returns the entry's hash.
.. describe:: int(x)
Returns the entry's ID.
.. versionchanged:: 1.7
Audit log entries are now comparable and hashable.

View File

@ -115,6 +115,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
Returns the channel's name.
.. describe:: int(x)
Returns the channel's ID.
Attributes
-----------
name: :class:`str`
@ -224,6 +228,16 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""List[:class:`Member`]: Returns all members that can see this channel."""
return [m for m in self.guild.members if self.permissions_for(m).read_messages]
@property
def bots(self) -> List[Member]:
"""List[:class:`Member`]: Returns all bots that can see this channel."""
return [m for m in self.guild.members if m.bot and self.permissions_for(m).read_messages]
@property
def humans(self) -> List[Member]:
"""List[:class:`Member`]: Returns all human members that can see this channel."""
return [m for m in self.guild.members if not m.bot and self.permissions_for(m).read_messages]
@property
def threads(self) -> List[Thread]:
"""List[:class:`Thread`]: Returns all the threads that you can see.
@ -1334,6 +1348,10 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
Returns the category's name.
.. describe:: int(x)
Returns the category's ID.
Attributes
-----------
name: :class:`str`
@ -1556,6 +1574,10 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
Returns the channel's name.
.. describe:: int(x)
Returns the channel's ID.
Attributes
-----------
name: :class:`str`
@ -1728,6 +1750,10 @@ class DMChannel(discord.abc.Messageable, Hashable):
Returns a string representation of the channel
.. describe:: int(x)
Returns the channel's ID.
Attributes
----------
recipient: Optional[:class:`User`]
@ -1854,6 +1880,10 @@ class GroupChannel(discord.abc.Messageable, Hashable):
Returns a string representation of the channel
.. describe:: int(x)
Returns the channel's ID.
Attributes
----------
recipients: List[:class:`User`]
@ -2000,6 +2030,10 @@ class PartialMessageable(discord.abc.Messageable, Hashable):
Returns the partial messageable's hash.
.. describe:: int(x)
Returns the messageable's ID.
Attributes
-----------
id: :class:`int`

View File

@ -842,6 +842,38 @@ class Client:
"""
return self._connection.get_user(id)
async def try_user(self, id: int, /) -> Optional[User]:
"""|coro|
Returns a user with the given ID. If not from cache, the user will be requested from the API.
You do not have to share any guilds with the user to get this information from the API,
however many operations do require that you do.
.. note::
This method is an API call. If you have :attr:`discord.Intents.members` and member cache enabled, consider :meth:`get_user` instead.
.. versionadded:: 2.0
Parameters
-----------
id: :class:`int`
The ID to search for.
Returns
--------
Optional[:class:`~discord.User`]
The user or ``None`` if not found.
"""
maybe_user = self.get_user(id)
if maybe_user is not None:
return maybe_user
try:
return await self.fetch_user(id)
except NotFound:
return None
def get_emoji(self, id: int, /) -> Optional[Emoji]:
"""Returns an emoji with the given ID.

View File

@ -252,6 +252,13 @@ class Colour:
"""A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``."""
return cls(0xe74c3c)
@classmethod
def nitro_booster(cls):
"""A factory method that returns a :class:`Colour` with a value of ``0xf47fff``.
.. versionadded:: 2.0"""
return cls(0xf47fff)
@classmethod
def dark_red(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x992d22``."""
@ -325,5 +332,14 @@ class Colour:
"""
return cls(0xFEE75C)
@classmethod
def dark_blurple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x4E5D94``.
This is the original Dark Blurple branding.
.. versionadded:: 2.0
"""
return cls(0x4E5D94)
Color = Colour

View File

@ -72,30 +72,36 @@ if TYPE_CHECKING:
T = TypeVar('T')
MaybeEmpty = Union[T, _EmptyEmbed]
class _EmbedFooterProxy(Protocol):
text: MaybeEmpty[str]
icon_url: MaybeEmpty[str]
class _EmbedFieldProxy(Protocol):
name: MaybeEmpty[str]
value: MaybeEmpty[str]
inline: bool
class _EmbedMediaProxy(Protocol):
url: MaybeEmpty[str]
proxy_url: MaybeEmpty[str]
height: MaybeEmpty[int]
width: MaybeEmpty[int]
class _EmbedVideoProxy(Protocol):
url: MaybeEmpty[str]
height: MaybeEmpty[int]
width: MaybeEmpty[int]
class _EmbedProviderProxy(Protocol):
name: MaybeEmpty[str]
url: MaybeEmpty[str]
class _EmbedAuthorProxy(Protocol):
name: MaybeEmpty[str]
url: MaybeEmpty[str]
@ -397,6 +403,22 @@ class Embed:
"""
return EmbedProxy(getattr(self, '_image', {})) # type: ignore
@image.setter
def image(self: E, url: Any):
if url is EmptyEmbed:
del self._image
else:
self._image = {
'url': str(url),
}
@image.deleter
def image(self: E):
try:
del self._image
except AttributeError:
pass
def set_image(self: E, *, url: MaybeEmpty[Any]) -> E:
"""Sets the image for the embed content.
@ -412,15 +434,7 @@ class Embed:
The source URL for the image. Only HTTP(S) is supported.
"""
if url is EmptyEmbed:
try:
del self._image
except AttributeError:
pass
else:
self._image = {
'url': str(url),
}
self.image = url
return self
@ -439,7 +453,23 @@ class Embed:
"""
return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]) -> E:
@thumbnail.setter
def thumbnail(self: E, url: Any):
if url is EmptyEmbed:
del self._thumbnail
else:
self._thumbnail = {
'url': str(url),
}
@thumbnail.deleter
def thumbnail(self):
try:
del self._thumbnail
except AttributeError:
pass
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]):
"""Sets the thumbnail for the embed content.
This function returns the class instance to allow for fluent-style
@ -454,15 +484,7 @@ class Embed:
The source URL for the thumbnail. Only HTTP(S) is supported.
"""
if url is EmptyEmbed:
try:
del self._thumbnail
except AttributeError:
pass
else:
self._thumbnail = {
'url': str(url),
}
self.thumbnail = url
return self

View File

@ -72,6 +72,10 @@ class Emoji(_EmojiTag, AssetMixin):
Returns the emoji rendered for discord.
.. describe:: int(x)
Returns the emoji ID.
Attributes
-----------
name: :class:`str`
@ -137,6 +141,9 @@ class Emoji(_EmojiTag, AssetMixin):
return f'<a:{self.name}:{self.id}>'
return f'<:{self.name}:{self.id}>'
def __int__(self) -> int:
return self.id
def __repr__(self) -> str:
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>'

View File

@ -53,6 +53,7 @@ from .context import Context
from . import errors
from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog
from discord.utils import raise_expected_coro
if TYPE_CHECKING:
import importlib.machinery
@ -453,14 +454,59 @@ class BotBase(GroupMixin):
elif self.owner_ids:
return user.id in self.owner_ids
else:
# Populate the used fields, then retry the check. This is only done at-most once in the bot lifetime.
await self.populate_owners()
return await self.is_owner(user)
async def try_owners(self) -> List[discord.User]:
"""|coro|
Returns a list of :class:`~discord.User` representing the owners of the bot.
It uses the :attr:`owner_id` and :attr:`owner_ids`, if set.
.. versionadded:: 2.0
The function also checks if the application is team-owned if
:attr:`owner_ids` is not set.
Returns
--------
List[:class:`~discord.User`]
List of owners of the bot.
"""
if self.owner_id:
owner = await self.try_user(self.owner_id)
if owner:
return [owner]
else:
return []
elif self.owner_ids:
owners = []
for owner_id in self.owner_ids:
owner = await self.try_user(owner_id)
if owner:
owners.append(owner)
return owners
else:
# We didn't have owners cached yet, cache them and retry.
await self.populate_owners()
return await self.try_owners()
async def populate_owners(self):
"""|coro|
Populate the :attr:`owner_id` and :attr:`owner_ids` through the use of :meth:`~.Bot.application_info`.
.. versionadded:: 2.0
"""
app = await self.application_info() # type: ignore
if app.team:
self.owner_ids = ids = {m.id for m in app.team.members}
return user.id in ids
self.owner_ids = {m.id for m in app.team.members}
else:
self.owner_id = owner_id = app.owner.id
return user.id == owner_id
self.owner_id = app.owner.id
def before_invoke(self, coro: CFT) -> CFT:
"""A decorator that registers a coroutine as a pre-invoke hook.
@ -488,11 +534,9 @@ class BotBase(GroupMixin):
TypeError
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The pre-invoke hook must be a coroutine.')
self._before_invoke = coro
return coro
return raise_expected_coro(
coro, 'The pre-invoke hook must be a coroutine.'
)
def after_invoke(self, coro: CFT) -> CFT:
r"""A decorator that registers a coroutine as a post-invoke hook.
@ -521,11 +565,10 @@ class BotBase(GroupMixin):
TypeError
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The post-invoke hook must be a coroutine.')
return raise_expected_coro(
coro, 'The post-invoke hook must be a coroutine.'
)
self._after_invoke = coro
return coro
# listener registration
@ -1266,7 +1309,7 @@ class Bot(BotBase, discord.Client):
when passing an empty string, it should always be last as no prefix
after it will be matched.
case_insensitive: :class:`bool`
Whether the commands should be case insensitive. Defaults to ``False``. This
Whether the commands should be case insensitive. Defaults to ``True``. This
attribute does not carry over to groups. You must set it to every group if
you require group commands to be case insensitive as well.
description: :class:`str`

View File

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import inspect
@ -32,6 +33,7 @@ import discord.abc
import discord.utils
from discord.message import Message
from discord import Permissions
if TYPE_CHECKING:
from typing_extensions import ParamSpec
@ -62,10 +64,7 @@ T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog")
if TYPE_CHECKING:
P = ParamSpec('P')
else:
P = TypeVar('P')
P = ParamSpec('P') if TYPE_CHECKING else TypeVar('P')
class Context(discord.abc.Messageable, Generic[BotT]):
@ -318,6 +317,13 @@ class Context(discord.abc.Messageable, Generic[BotT]):
g = self.guild
return g.voice_client if g else None
def author_permissions(self) -> Permissions:
"""Returns the author permissions in the given channel.
.. versionadded:: 2.0
"""
return self.channel.permissions_for(self.author)
async def send_help(self, *args: Any) -> Any:
"""send_help(entity=<bot>)

View File

@ -356,14 +356,14 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod
def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]:
if guild_id is not None:
if guild_id is None:
return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel
guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None:
return guild._resolve_channel(channel_id) # type: ignore
else:
return None
else:
return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel
async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage:
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
@ -942,8 +942,7 @@ class clean_content(Converter[str]):
def repl(match: re.Match) -> str:
type = match[1]
id = int(match[2])
transformed = transforms[type](id)
return transformed
return transforms[type](id)
result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument)
if self.escape_markdown:

View File

@ -1264,10 +1264,10 @@ class GroupMixin(Generic[CogT]):
A mapping of command name to :class:`.Command`
objects.
case_insensitive: :class:`bool`
Whether the commands should be case insensitive. Defaults to ``False``.
Whether the commands should be case insensitive. Defaults to ``True``.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
case_insensitive = kwargs.get('case_insensitive', False)
case_insensitive = kwargs.get('case_insensitive', True)
self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {}
self.case_insensitive: bool = case_insensitive
super().__init__(*args, **kwargs)

View File

@ -82,9 +82,7 @@ class StringView:
def skip_string(self, string):
strlen = len(string)
if self.buffer[self.index:self.index + strlen] == string:
self.previous = self.index
self.index += strlen
return True
return self._return_index(strlen, True)
return False
def read_rest(self):
@ -95,9 +93,7 @@ class StringView:
def read(self, n):
result = self.buffer[self.index:self.index + n]
self.previous = self.index
self.index += n
return result
return self._return_index(n, result)
def get(self):
try:
@ -105,9 +101,12 @@ class StringView:
except IndexError:
result = None
return self._return_index(1, result)
def _return_index(self, arg0, arg1):
self.previous = self.index
self.index += 1
return result
self.index += arg0
return arg1
def get_word(self):
pos = 0

View File

@ -46,7 +46,9 @@ import traceback
from collections.abc import Sequence
from discord.backoff import ExponentialBackoff
from discord.utils import MISSING
from discord.utils import MISSING, raise_expected_coro
__all__ = (
'loop',
@ -488,11 +490,7 @@ class Loop(Generic[LF]):
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._before_loop = coro
return coro
return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
def after_loop(self, coro: FT) -> FT:
"""A decorator that register a coroutine to be called after the loop finished running.
@ -516,11 +514,7 @@ class Loop(Generic[LF]):
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._after_loop = coro
return coro
return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
def error(self, coro: ET) -> ET:
"""A decorator that registers a coroutine to be called if the task encounters an unhandled exception.
@ -542,11 +536,7 @@ class Loop(Generic[LF]):
TypeError
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._error = coro # type: ignore
return coro
return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
def _get_next_sleep_time(self) -> datetime.datetime:
if self._sleep is not MISSING:
@ -614,8 +604,7 @@ class Loop(Generic[LF]):
)
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
ret = sorted(set(ret)) # de-dupe and sort times
return ret
return sorted(set(ret))
def change_interval(
self,

View File

@ -22,8 +22,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, TypedDict, Any, Optional, List, TypeVar, Type, Dict, Callable, Coroutine, NamedTuple, Deque
import asyncio
from collections import namedtuple, deque
from collections import deque
import concurrent.futures
import logging
import struct
@ -40,7 +44,23 @@ from .activity import BaseActivity
from .enums import SpeakingState
from .errors import ConnectionClosed, InvalidArgument
_log = logging.getLogger(__name__)
if TYPE_CHECKING:
from .client import Client
from .state import ConnectionState
from .voice_client import VoiceClient
T = TypeVar('T')
DWS = TypeVar('DWS', bound='DiscordWebSocket')
DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket')
Coro = Callable[..., Coroutine[Any, Any, Any]]
Predicate = Callable[[Dict[str, Any]], bool]
DataCallable = Callable[[Dict[str, Any]], T]
Result = Optional[DataCallable[Any]]
_log: logging.Logger = logging.getLogger(__name__)
__all__ = (
'DiscordWebSocket',
@ -50,36 +70,49 @@ __all__ = (
'ReconnectWebSocket',
)
class Heartbeat(TypedDict):
op: int
d: int
class ReconnectWebSocket(Exception):
"""Signals to safely reconnect the websocket."""
def __init__(self, shard_id, *, resume=True):
self.shard_id = shard_id
self.resume = resume
def __init__(self, shard_id: Optional[int], *, resume: bool = True) -> None:
self.shard_id: Optional[int] = shard_id
self.resume: bool = resume
self.op = 'RESUME' if resume else 'IDENTIFY'
class WebSocketClosure(Exception):
"""An exception to make up for the fact that aiohttp doesn't signal closure."""
pass
EventListener = namedtuple('EventListener', 'predicate event result future')
class EventListener(NamedTuple):
predicate: Predicate
event: str
result: Result
future: asyncio.Future
class GatewayRatelimiter:
def __init__(self, count=110, per=60.0):
def __init__(self, count: int = 110, per: float = 60.0) -> None:
# The default is 110 to give room for at least 10 heartbeats per minute
self.max = count
self.remaining = count
self.window = 0.0
self.per = per
self.lock = asyncio.Lock()
self.shard_id = None
self.max: int = count
self.remaining: int = count
self.window: float = 0.0
self.per: float = per
self.lock: asyncio.Lock = asyncio.Lock()
self.shard_id: Optional[int] = None
def is_ratelimited(self):
def is_ratelimited(self) -> bool:
current = time.time()
if current > self.window + self.per:
return False
return self.remaining == 0
def get_delay(self):
def get_delay(self) -> float:
current = time.time()
if current > self.window + self.per:
@ -97,7 +130,7 @@ class GatewayRatelimiter:
return 0.0
async def block(self):
async def block(self) -> None:
async with self.lock:
delta = self.get_delay()
if delta:
@ -106,27 +139,27 @@ class GatewayRatelimiter:
class KeepAliveHandler(threading.Thread):
def __init__(self, *args, **kwargs):
ws = kwargs.pop('ws', None)
def __init__(self, *args: Any, **kwargs: Any) -> None:
ws = kwargs.pop('ws')
interval = kwargs.pop('interval', None)
shard_id = kwargs.pop('shard_id', None)
threading.Thread.__init__(self, *args, **kwargs)
self.ws = ws
self._main_thread_id = ws.thread_id
self.interval = interval
self.daemon = True
self.shard_id = shard_id
self.msg = 'Keeping shard ID %s websocket alive with sequence %s.'
self.block_msg = 'Shard ID %s heartbeat blocked for more than %s seconds.'
self.behind_msg = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.'
self._stop_ev = threading.Event()
self._last_ack = time.perf_counter()
self._last_send = time.perf_counter()
self._last_recv = time.perf_counter()
self.latency = float('inf')
self.heartbeat_timeout = ws._max_heartbeat_timeout
self.ws: DiscordWebSocket = ws
self._main_thread_id: int = ws.thread_id
self.interval: Optional[float] = interval
self.daemon: bool = True
self.shard_id: Optional[int] = shard_id
self.msg: str = 'Keeping shard ID %s websocket alive with sequence %s.'
self.block_msg: str = 'Shard ID %s heartbeat blocked for more than %s seconds.'
self.behind_msg: str = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.'
self._stop_ev: threading.Event = threading.Event()
self._last_ack: float = time.perf_counter()
self._last_send: float = time.perf_counter()
self._last_recv: float = time.perf_counter()
self.latency: float = float('inf')
self.heartbeat_timeout: float = ws._max_heartbeat_timeout
def run(self):
def run(self) -> None:
while not self._stop_ev.wait(self.interval):
if self._last_recv + self.heartbeat_timeout < time.perf_counter():
_log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
@ -168,19 +201,20 @@ class KeepAliveHandler(threading.Thread):
else:
self._last_send = time.perf_counter()
def get_payload(self):
def get_payload(self) -> Heartbeat:
return {
'op': self.ws.HEARTBEAT,
'd': self.ws.sequence
# the websocket's sequence won't be None here
'd': self.ws.sequence # type: ignore
}
def stop(self):
def stop(self) -> None:
self._stop_ev.set()
def tick(self):
def tick(self) -> None:
self._last_recv = time.perf_counter()
def ack(self):
def ack(self) -> None:
ack_time = time.perf_counter()
self._last_ack = ack_time
self.latency = ack_time - self._last_send
@ -188,30 +222,32 @@ class KeepAliveHandler(threading.Thread):
_log.warning(self.behind_msg, self.shard_id, self.latency)
class VoiceKeepAliveHandler(KeepAliveHandler):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.recent_ack_latencies = deque(maxlen=20)
self.recent_ack_latencies: Deque[float] = deque(maxlen=20)
self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.'
self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds'
self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind'
def get_payload(self):
def get_payload(self) -> Heartbeat:
return {
'op': self.ws.HEARTBEAT,
'd': int(time.time() * 1000)
}
def ack(self):
def ack(self) -> None:
ack_time = time.perf_counter()
self._last_ack = ack_time
self._last_recv = ack_time
self.latency = ack_time - self._last_send
self.recent_ack_latencies.append(self.latency)
class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse):
async def close(self, *, code: int = 4000, message: bytes = b'') -> bool:
return await super().close(code=code, message=message)
class DiscordWebSocket:
"""Implements a WebSocket for Discord's gateway v6.
@ -266,41 +302,53 @@ class DiscordWebSocket:
HEARTBEAT_ACK = 11
GUILD_SYNC = 12
def __init__(self, socket, *, loop):
self.socket = socket
self.loop = loop
def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None:
self.socket: aiohttp.ClientWebSocketResponse = socket
self.loop: asyncio.AbstractEventLoop = loop
# an empty dispatcher to prevent crashes
self._dispatch = lambda *args: None
# generic event listeners
self._dispatch_listeners = []
self._dispatch_listeners: List[EventListener] = []
# the keep alive
self._keep_alive = None
self.thread_id = threading.get_ident()
self._keep_alive: Optional[KeepAliveHandler] = None
self.thread_id: int = threading.get_ident()
# ws related stuff
self.session_id = None
self.sequence = None
self.session_id: Optional[str] = None
self.sequence: Optional[int] = None
self._zlib = zlib.decompressobj()
self._buffer = bytearray()
self._close_code = None
self._rate_limiter = GatewayRatelimiter()
self._buffer: bytearray = bytearray()
self._close_code: Optional[int] = None
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
# attributes that get set in from_client
self.token: str = utils.MISSING
self._connection: ConnectionState = utils.MISSING
self._discord_parsers: Dict[str, DataCallable[None]] = utils.MISSING
self.gateway: str = utils.MISSING
self.call_hooks: Coro = utils.MISSING
self._initial_identify: bool = utils.MISSING
self.shard_id: Optional[int] = utils.MISSING
self.shard_count: Optional[int] = utils.MISSING
self.session_id: Optional[str] = utils.MISSING
self._max_heartbeat_timeout: float = utils.MISSING
@property
def open(self):
def open(self) -> bool:
return not self.socket.closed
def is_ratelimited(self):
def is_ratelimited(self) -> bool:
return self._rate_limiter.is_ratelimited()
def debug_log_receive(self, data, /):
def debug_log_receive(self, data, /) -> None:
self._dispatch('socket_raw_receive', data)
def log_receive(self, _, /):
def log_receive(self, _, /) -> None:
pass
@classmethod
async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
async def from_client(cls: Type[DWS], client: Client, *, initial: bool = False, gateway: Optional[str] = None, shard_id: Optional[int] = None, session: Optional[str] = None, sequence: Optional[int] = None, resume: bool = False) -> DWS:
"""Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only.
@ -310,7 +358,9 @@ class DiscordWebSocket:
ws = cls(socket, loop=client.loop)
# dynamically add attributes needed
ws.token = client.http.token
# the token won't be None here
ws.token = client.http.token # type: ignore
ws._connection = client._connection
ws._discord_parsers = client._connection.parsers
ws._dispatch = client.dispatch
@ -342,7 +392,7 @@ class DiscordWebSocket:
await ws.resume()
return ws
def wait_for(self, event, predicate, result=None):
def wait_for(self, event: str, predicate: Predicate, result: Result = None) -> asyncio.Future:
"""Waits for a DISPATCH'd event that meets the predicate.
Parameters
@ -367,7 +417,7 @@ class DiscordWebSocket:
self._dispatch_listeners.append(entry)
return future
async def identify(self):
async def identify(self) -> None:
"""Sends the IDENTIFY packet."""
payload = {
'op': self.IDENTIFY,
@ -405,7 +455,7 @@ class DiscordWebSocket:
await self.send_as_json(payload)
_log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
async def resume(self):
async def resume(self) -> None:
"""Sends the RESUME packet."""
payload = {
'op': self.RESUME,
@ -419,7 +469,8 @@ class DiscordWebSocket:
await self.send_as_json(payload)
_log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
async def received_message(self, msg, /):
async def received_message(self, msg, /) -> None:
if type(msg) is bytes:
self._buffer.extend(msg)
@ -537,16 +588,16 @@ class DiscordWebSocket:
del self._dispatch_listeners[index]
@property
def latency(self):
def latency(self) -> float:
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds."""
heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency
def _can_handle_close(self):
def _can_handle_close(self) -> bool:
code = self._close_code or self.socket.close_code
return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014)
async def poll_event(self):
async def poll_event(self) -> None:
"""Polls for a DISPATCH event and handles the general gateway loop.
Raises
@ -584,23 +635,23 @@ class DiscordWebSocket:
_log.info('Websocket closed with %s, cannot reconnect.', code)
raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None
async def debug_send(self, data, /):
async def debug_send(self, data, /) -> None:
await self._rate_limiter.block()
self._dispatch('socket_raw_send', data)
await self.socket.send_str(data)
async def send(self, data, /):
async def send(self, data, /) -> None:
await self._rate_limiter.block()
await self.socket.send_str(data)
async def send_as_json(self, data):
async def send_as_json(self, data) -> None:
try:
await self.send(utils._to_json(data))
except RuntimeError as exc:
if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def send_heartbeat(self, data):
async def send_heartbeat(self, data: Heartbeat) -> None:
# This bypasses the rate limit handling code since it has a higher priority
try:
await self.socket.send_str(utils._to_json(data))
@ -608,13 +659,13 @@ class DiscordWebSocket:
if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def change_presence(self, *, activity=None, status=None, since=0.0):
async def change_presence(self, *, activity: Optional[BaseActivity] = None, status: Optional[str] = None, since: float = 0.0) -> None:
if activity is not None:
if not isinstance(activity, BaseActivity):
raise InvalidArgument('activity must derive from BaseActivity.')
activity = [activity.to_dict()]
activities = [activity.to_dict()]
else:
activity = []
activities = []
if status == 'idle':
since = int(time.time() * 1000)
@ -622,7 +673,7 @@ class DiscordWebSocket:
payload = {
'op': self.PRESENCE,
'd': {
'activities': activity,
'activities': activities,
'afk': False,
'since': since,
'status': status
@ -633,7 +684,7 @@ class DiscordWebSocket:
_log.debug('Sending "%s" to change status', sent)
await self.send(sent)
async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None):
async def request_chunks(self, guild_id: int, query: Optional[str] = None, *, limit: int, user_ids: Optional[List[int]] = None, presences: bool = False, nonce: Optional[int] = None) -> None:
payload = {
'op': self.REQUEST_MEMBERS,
'd': {
@ -655,7 +706,7 @@ class DiscordWebSocket:
await self.send_as_json(payload)
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
async def voice_state(self, guild_id: int, channel_id: int, self_mute: bool = False, self_deaf: bool = False) -> None:
payload = {
'op': self.VOICE_STATE,
'd': {
@ -669,7 +720,7 @@ class DiscordWebSocket:
_log.debug('Updating our voice state to %s.', payload)
await self.send_as_json(payload)
async def close(self, code=4000):
async def close(self, code: int = 4000) -> None:
if self._keep_alive:
self._keep_alive.stop()
self._keep_alive = None
@ -721,25 +772,31 @@ class DiscordVoiceWebSocket:
CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13
def __init__(self, socket, loop, *, hook=None):
self.ws = socket
self.loop = loop
self._keep_alive = None
self._close_code = None
self.secret_key = None
def __init__(self, socket: aiohttp.ClientWebSocketResponse, loop: asyncio.AbstractEventLoop, *, hook: Optional[Coro] = None) -> None:
self.ws: aiohttp.ClientWebSocketResponse = socket
self.loop: asyncio.AbstractEventLoop = loop
self._keep_alive: VoiceKeepAliveHandler = utils.MISSING
self._close_code: Optional[int] = None
self.secret_key: Optional[List[int]] = None
self.gateway: str = utils.MISSING
self._connection: VoiceClient = utils.MISSING
self._max_heartbeat_timeout: float = utils.MISSING
self.thread_id: int = utils.MISSING
if hook:
self._hook = hook
# we want to redeclare self._hook
self._hook = hook # type: ignore
async def _hook(self, *args):
async def _hook(self, *args: Any) -> Any:
pass
async def send_as_json(self, data):
async def send_as_json(self, data) -> None:
_log.debug('Sending voice websocket frame: %s.', data)
await self.ws.send_str(utils._to_json(data))
send_heartbeat = send_as_json
async def resume(self):
async def resume(self) -> None:
state = self._connection
payload = {
'op': self.RESUME,
@ -765,7 +822,7 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload)
@classmethod
async def from_client(cls, client, *, resume=False, hook=None):
async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume: bool = False, hook: Optional[Coro] = None) -> DVWS:
"""Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4'
http = client._state.http
@ -783,7 +840,7 @@ class DiscordVoiceWebSocket:
return ws
async def select_protocol(self, ip, port, mode):
async def select_protocol(self, ip, port, mode) -> None:
payload = {
'op': self.SELECT_PROTOCOL,
'd': {
@ -798,7 +855,7 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload)
async def client_connect(self):
async def client_connect(self) -> None:
payload = {
'op': self.CLIENT_CONNECT,
'd': {
@ -808,7 +865,7 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload)
async def speak(self, state=SpeakingState.voice):
async def speak(self, state=SpeakingState.voice) -> None:
payload = {
'op': self.SPEAKING,
'd': {
@ -819,7 +876,8 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload)
async def received_message(self, msg):
async def received_message(self, msg) -> None:
_log.debug('Voice websocket frame received: %s', msg)
op = msg['op']
data = msg.get('d')
@ -840,7 +898,7 @@ class DiscordVoiceWebSocket:
await self._hook(self, msg)
async def initial_connection(self, data):
async def initial_connection(self, data) -> None:
state = self._connection
state.ssrc = data['ssrc']
state.voice_port = data['port']
@ -871,13 +929,13 @@ class DiscordVoiceWebSocket:
_log.info('selected the voice protocol for use (%s)', mode)
@property
def latency(self):
def latency(self) -> float:
""":class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds."""
heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency
@property
def average_latency(self):
def average_latency(self) -> float:
""":class:`list`: Average of last 20 HEARTBEAT latencies."""
heartbeat = self._keep_alive
if heartbeat is None or not heartbeat.recent_ack_latencies:
@ -885,13 +943,14 @@ class DiscordVoiceWebSocket:
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
async def load_secret_key(self, data):
async def load_secret_key(self, data) -> None:
_log.info('received secret key for voice connection')
self.secret_key = self._connection.secret_key = data.get('secret_key')
await self.speak()
await self.speak(False)
async def poll_event(self):
async def poll_event(self) -> None:
# This exception is handled up the chain
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
if msg.type is aiohttp.WSMsgType.TEXT:
@ -903,7 +962,7 @@ class DiscordVoiceWebSocket:
_log.debug('Received %s', msg)
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
async def close(self, code=1000):
async def close(self, code: int = 1000) -> None:
if self._keep_alive is not None:
self._keep_alive.stop()

View File

@ -46,7 +46,7 @@ from . import utils, abc
from .role import Role
from .member import Member, VoiceState
from .emoji import Emoji
from .errors import InvalidData
from .errors import InvalidData, NotFound
from .permissions import PermissionOverwrite
from .colour import Colour
from .errors import InvalidArgument, ClientException
@ -76,6 +76,7 @@ from .stage_instance import StageInstance
from .threads import Thread, ThreadMember
from .sticker import GuildSticker
from .file import File
from .welcome_screen import WelcomeScreen, WelcomeChannel
__all__ = (
@ -140,6 +141,10 @@ class Guild(Hashable):
Returns the guild's name.
.. describe:: int(x)
Returns the guild's ID.
Attributes
----------
name: :class:`str`
@ -738,12 +743,16 @@ class Guild(Hashable):
@property
def humans(self) -> List[Member]:
"""List[:class:`Member`]: A list of human members that belong to this guild."""
"""List[:class:`Member`]: A list of human members that belong to this guild.
.. versionadded:: 2.0 """
return [member for member in self.members if not member.bot]
@property
def bots(self) -> List[Member]:
"""List[:class:`Member`]: A list of bots that belong to this guild."""
"""List[:class:`Member`]: A list of bots that belong to this guild.
.. versionadded:: 2.0 """
return [member for member in self.members if member.bot]
def get_member(self, user_id: int, /) -> Optional[Member]:
@ -1715,6 +1724,8 @@ class Guild(Hashable):
You do not have access to the guild.
HTTPException
Fetching the member failed.
NotFound
A member with that ID does not exist.
Returns
--------
@ -1724,6 +1735,34 @@ class Guild(Hashable):
data = await self._state.http.get_member(self.id, member_id)
return Member(data=data, state=self._state, guild=self)
async def try_member(self, member_id: int, /) -> Optional[Member]:
"""|coro|
Returns a member with the given ID. This uses the cache first, and if not found, it'll request using :meth:`fetch_member`.
.. note::
This method might result in an API call.
Parameters
-----------
member_id: :class:`int`
The ID to search for.
Returns
--------
Optional[:class:`Member`]
The member or ``None`` if not found.
"""
member = self.get_member(member_id)
if member:
return member
else:
try:
return await self.fetch_member(member_id)
except NotFound:
return None
async def fetch_ban(self, user: Snowflake) -> BanEntry:
"""|coro|
@ -2566,6 +2605,81 @@ class Guild(Hashable):
return roles
async def welcome_screen(self) -> WelcomeScreen:
"""|coro|
Returns the guild's welcome screen.
The guild must have ``COMMUNITY`` in :attr:`~Guild.features`.
You must have the :attr:`~Permissions.manage_guild` permission to use
this as well.
.. versionadded:: 2.0
Raises
-------
Forbidden
You do not have the proper permissions to get this.
HTTPException
Retrieving the welcome screen failed.
Returns
--------
:class:`WelcomeScreen`
The welcome screen.
"""
data = await self._state.http.get_welcome_screen(self.id)
return WelcomeScreen(data=data, guild=self)
@overload
async def edit_welcome_screen(
self,
*,
description: Optional[str] = ...,
welcome_channels: Optional[List[WelcomeChannel]] = ...,
enabled: Optional[bool] = ...,
) -> WelcomeScreen:
...
@overload
async def edit_welcome_screen(self) -> None:
...
async def edit_welcome_screen(self, **kwargs):
"""|coro|
A shorthand method of :attr:`WelcomeScreen.edit` without needing
to fetch the welcome screen beforehand.
The guild must have ``COMMUNITY`` in :attr:`~Guild.features`.
You must have the :attr:`~Permissions.manage_guild` permission to use
this as well.
.. versionadded:: 2.0
Returns
--------
:class:`WelcomeScreen`
The edited welcome screen.
"""
try:
welcome_channels = kwargs['welcome_channels']
except KeyError:
pass
else:
welcome_channels_serialised = []
for wc in welcome_channels:
if not isinstance(wc, WelcomeChannel):
raise InvalidArgument('welcome_channels parameter must be a list of WelcomeChannel')
welcome_channels_serialised.append(wc.to_dict())
kwargs['welcome_channels'] = welcome_channels_serialised
if kwargs:
data = await self._state.http.edit_welcome_screen(self.id, kwargs)
return WelcomeScreen(data=data, guild=self)
async def kick(self, user: Snowflake, *, reason: Optional[str] = None) -> None:
"""|coro|

View File

@ -84,6 +84,7 @@ if TYPE_CHECKING:
threads,
voice,
sticker,
welcome_screen,
)
from .types.snowflake import Snowflake, SnowflakeList
@ -1116,6 +1117,20 @@ class HTTPClient:
payload['icon'] = icon
return self.request(Route('POST', '/guilds/templates/{code}', code=code), json=payload)
def get_welcome_screen(self, guild_id: Snowflake) -> Response[welcome_screen.WelcomeScreen]:
return self.request(Route('GET', '/guilds/{guild_id}/welcome-screen', guild_id=guild_id))
def edit_welcome_screen(self, guild_id: Snowflake, payload: Any) -> Response[welcome_screen.WelcomeScreen]:
valid_keys = (
'description',
'welcome_channels',
'enabled',
)
payload = {
k: v for k, v in payload.items() if k in valid_keys
}
return self.request(Route('PATCH', '/guilds/{guild_id}/welcome-screen', guild_id=guild_id), json=payload)
def get_bans(self, guild_id: Snowflake) -> Response[List[guild.Ban]]:
return self.request(Route('GET', '/guilds/{guild_id}/bans', guild_id=guild_id))

View File

@ -230,6 +230,7 @@ class Invite(Hashable):
Returns the invite URL.
The following table illustrates what methods will obtain the attributes:
+------------------------------------+------------------------------------------------------------+
@ -433,6 +434,9 @@ class Invite(Hashable):
def __str__(self) -> str:
return self.url
def __int__(self) -> int:
return 0 # To keep the object compatible with the hashable abc.
def __repr__(self) -> str:
return (
f'<Invite code={self.code!r} guild={self.guild!r} '

View File

@ -226,6 +226,10 @@ class Member(discord.abc.Messageable, _UserTag):
Returns the member's name with the discriminator.
.. describe:: int(x)
Returns the user's ID.
Attributes
----------
joined_at: Optional[:class:`datetime.datetime`]
@ -300,6 +304,9 @@ class Member(discord.abc.Messageable, _UserTag):
def __str__(self) -> str:
return str(self._user)
def __int__(self) -> int:
return self.id
def __repr__(self) -> str:
return (
f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}'

View File

@ -125,6 +125,10 @@ class Attachment(Hashable):
Returns the hash of the attachment.
.. describe:: int(x)
Returns the attachment's ID.
.. versionchanged:: 1.7
Attachment can now be casted to :class:`str` and is hashable.
@ -503,6 +507,14 @@ class Message(Hashable):
Returns the message's hash.
.. describe:: str(x)
Returns the message's content.
.. describe:: int(x)
Returns the message's ID.
Attributes
-----------
tts: :class:`bool`
@ -712,6 +724,10 @@ class Message(Hashable):
f'<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>'
)
def __str__(self) -> Optional[str]:
return self.content
def _try_patch(self, data, key, transform=None) -> None:
try:
value = data[key]
@ -1634,6 +1650,10 @@ class PartialMessage(Hashable):
Returns the partial message's hash.
.. describe:: int(x)
Returns the partial message's ID.
Attributes
-----------
channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`]

View File

@ -43,5 +43,8 @@ class EqualityComparable:
class Hashable(EqualityComparable):
__slots__ = ()
def __int__(self) -> int:
return self.id
def __hash__(self) -> int:
return self.id >> 22

View File

@ -69,6 +69,10 @@ class Object(Hashable):
Returns the object's hash.
.. describe:: int(x)
Returns the object's ID.
Attributes
-----------
id: :class:`int`

View File

@ -299,6 +299,13 @@ class Permissions(BaseFlags):
"""
return 1 << 3
@make_permission_alias('administrator')
def admin(self) -> int:
""":class:`bool`: An alias for :attr:`administrator`.
.. versionadded:: 2.0
"""
return 1 << 3
@flag_value
def manage_channels(self) -> int:
""":class:`bool`: Returns ``True`` if a user can edit, delete, or create channels in the guild.

View File

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import threading
@ -63,10 +64,7 @@ __all__ = (
CREATE_NO_WINDOW: int
if sys.platform != 'win32':
CREATE_NO_WINDOW = 0
else:
CREATE_NO_WINDOW = 0x08000000
CREATE_NO_WINDOW = 0 if sys.platform != 'win32' else 0x08000000
class AudioSource:
"""Represents an audio stream.
@ -526,7 +524,12 @@ class FFmpegOpusAudio(FFmpegAudio):
@staticmethod
def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]:
exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable
exe = (
executable[:2] + 'probe'
if executable in {'ffmpeg', 'avconv'}
else executable
)
args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source]
output = subprocess.check_output(args, timeout=20)
codec = bitrate = None

View File

@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import datetime
from typing import TYPE_CHECKING, Optional, Set, List
if TYPE_CHECKING:
@ -34,7 +35,8 @@ if TYPE_CHECKING:
MessageUpdateEvent,
ReactionClearEvent,
ReactionClearEmojiEvent,
IntegrationDeleteEvent
IntegrationDeleteEvent,
TypingEvent
)
from .message import Message
from .partial_emoji import PartialEmoji
@ -49,6 +51,7 @@ __all__ = (
'RawReactionClearEvent',
'RawReactionClearEmojiEvent',
'RawIntegrationDeleteEvent',
'RawTypingEvent'
)
@ -276,3 +279,36 @@ class RawIntegrationDeleteEvent(_RawReprMixin):
self.application_id: Optional[int] = int(data['application_id'])
except KeyError:
self.application_id: Optional[int] = None
class RawTypingEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_typing` event.
.. versionadded:: 2.0
Attributes
-----------
channel_id: :class:`int`
The channel ID where the typing originated from.
user_id: :class:`int`
The ID of the user that started typing.
when: :class:`datetime.datetime`
When the typing started as an aware datetime in UTC.
guild_id: Optional[:class:`int`]
The guild ID where the typing originated from, if applicable.
member: Optional[:class:`Member`]
The member who started typing. Only available if the member started typing in a guild.
"""
__slots__ = ("channel_id", "user_id", "when", "guild_id", "member")
def __init__(self, data: TypingEvent) -> None:
self.channel_id: int = int(data['channel_id'])
self.user_id: int = int(data['user_id'])
self.when: datetime.datetime = datetime.datetime.fromtimestamp(data.get('timestamp'), tz=datetime.timezone.utc)
self.member: Optional[Member] = None
try:
self.guild_id: Optional[int] = int(data['guild_id'])
except KeyError:
self.guild_id: Optional[int] = None

View File

@ -141,6 +141,14 @@ class Role(Hashable):
Returns the role's name.
.. describe:: str(x)
Returns the role's ID.
.. describe:: int(x)
Returns the role's ID.
Attributes
----------
id: :class:`int`
@ -195,6 +203,9 @@ class Role(Hashable):
def __str__(self) -> str:
return self.name
def __int__(self) -> int:
return self.id
def __repr__(self) -> str:
return f'<Role id={self.id} name={self.name!r}>'

View File

@ -61,6 +61,10 @@ class StageInstance(Hashable):
Returns the stage instance's hash.
.. describe:: int(x)
Returns the stage instance's ID.
Attributes
-----------
id: :class:`int`

View File

@ -1327,28 +1327,37 @@ class ConnectionState:
asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler'))
def parse_typing_start(self, data) -> None:
channel, guild = self._get_guild_channel(data)
if channel is not None:
member = None
user_id = utils._get_as_snowflake(data, 'user_id')
if isinstance(channel, DMChannel):
member = channel.recipient
raw = RawTypingEvent(data)
elif isinstance(channel, (Thread, TextChannel)) and guild is not None:
# user_id won't be None
member = guild.get_member(user_id) # type: ignore
if member is None:
member_data = data.get('member')
if member_data:
member = Member(data=member_data, state=self, guild=guild)
guild = self._get_guild(raw.guild_id)
if guild is not None:
raw.member = Member(data=member_data, guild=guild, state=self)
else:
raw.member = None
else:
raw.member = None
self.dispatch('raw_typing', raw)
channel, guild = self._get_guild_channel(data)
if channel is not None:
user = raw.member or self._get_typing_user(channel, raw.user_id)
if user is not None:
self.dispatch('typing', channel, user, raw.when)
def _get_typing_user(self, channel: Optional[MessageableChannel], user_id: int) -> Optional[Union[User, Member]]:
if isinstance(channel, DMChannel):
return channel.recipient
elif isinstance(channel, (Thread, TextChannel)) and channel.guild is not None:
return channel.guild.get_member(user_id) # type: ignore
elif isinstance(channel, GroupChannel):
member = utils.find(lambda x: x.id == user_id, channel.recipients)
return utils.find(lambda x: x.id == user_id, channel.recipients)
if member is not None:
timestamp = datetime.datetime.fromtimestamp(data.get('timestamp'), tz=datetime.timezone.utc)
self.dispatch('typing', channel, member, timestamp)
return self.get_user(user_id)
def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]:
if isinstance(channel, TextChannel):

View File

@ -67,6 +67,14 @@ class StickerPack(Hashable):
Returns the name of the sticker pack.
.. describe:: hash(x)
Returns the hash of the sticker pack.
.. describe:: int(x)
Returns the ID of the sticker pack.
.. describe:: x == y
Checks if the sticker pack is equal to another sticker pack.

View File

@ -74,6 +74,10 @@ class Thread(Messageable, Hashable):
Returns the thread's hash.
.. describe:: int(x)
Returns the thread's ID.
.. describe:: str(x)
Returns the thread's name.
@ -748,6 +752,10 @@ class ThreadMember(Hashable):
Returns the thread member's hash.
.. describe:: int(x)
Returns the thread member's ID.
.. describe:: str(x)
Returns the thread member's name.
@ -800,3 +808,39 @@ class ThreadMember(Hashable):
def thread(self) -> Thread:
""":class:`Thread`: The thread this member belongs to."""
return self.parent
async def fetch_member(self) -> Member:
"""|coro|
Retrieves a :class:`Member` from the ThreadMember object.
.. note::
This method is an API call. If you have :attr:`Intents.members` and member cache enabled, consider :meth:`get_member` instead.
Raises
-------
Forbidden
You do not have access to the guild.
HTTPException
Fetching the member failed.
Returns
--------
:class:`Member`
The member.
"""
return await self.thread.guild.fetch_member(self.id)
def get_member(self) -> Optional[Member]:
"""
Get the :class:`Member` from cache for the ThreadMember object.
Returns
--------
Optional[:class:`Member`]
The member or ``None`` if not found.
"""
return self.thread.guild.get_member(self.id)

View File

@ -85,3 +85,13 @@ class _IntegrationDeleteEventOptional(TypedDict, total=False):
class IntegrationDeleteEvent(_IntegrationDeleteEventOptional):
id: Snowflake
guild_id: Snowflake
class _TypingEventOptional(TypedDict, total=False):
guild_id: Snowflake
member: Member
class TypingEvent(_TypingEventOptional):
channel_id: Snowflake
user_id: Snowflake
timestamp: int

View File

@ -185,15 +185,15 @@ class Button(Item[V]):
@emoji.setter
def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore
if value is not None:
if isinstance(value, str):
if value is None:
self._underlying.emoji = None
elif isinstance(value, str):
self._underlying.emoji = PartialEmoji.from_str(value)
elif isinstance(value, _EmojiTag):
self._underlying.emoji = value._to_partial()
else:
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead')
else:
self._underlying.emoji = None
@classmethod
def from_component(cls: Type[B], button: ButtonComponent) -> B:

View File

@ -96,6 +96,9 @@ class BaseUser(_UserTag):
def __str__(self) -> str:
return f'{self.name}#{self.discriminator}'
def __int__(self) -> int:
return self.id
def __eq__(self, other: Any) -> bool:
return isinstance(other, _UserTag) and other.id == self.id
@ -415,6 +418,10 @@ class User(BaseUser, discord.abc.Messageable):
Returns the user's name with discriminator.
.. describe:: int(x)
Returns the user's ID.
Attributes
-----------
name: :class:`str`

View File

@ -499,13 +499,13 @@ else:
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After')
if use_clock or not reset_after:
if not use_clock and reset_after:
return float(reset_after)
utc = datetime.timezone.utc
now = datetime.datetime.now(utc)
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc)
return (reset - now).total_seconds()
else:
return float(reset_after)
async def maybe_coroutine(f, *args, **kwargs):
@ -659,7 +659,6 @@ def resolve_invite(invite: Union[Invite, str]) -> str:
if isinstance(invite, Invite):
return invite.code
else:
rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)'
m = re.match(rx, invite)
if m:
@ -687,7 +686,6 @@ def resolve_template(code: Union[Template, str]) -> str:
if isinstance(code, Template):
return code.code
else:
rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)'
m = re.match(rx, code)
if m:
@ -1017,3 +1015,9 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None)
if style is None:
return f'<t:{int(dt.timestamp())}>'
return f'<t:{int(dt.timestamp())}:{style}>'
def raise_expected_coro(coro, error: str)-> TypeError:
if not asyncio.iscoroutinefunction(coro):
raise TypeError(error)
return coro

View File

@ -255,6 +255,9 @@ class VoiceClient(VoiceProtocol):
self.encoder: Encoder = MISSING
self._lite_nonce: int = 0
self.ws: DiscordVoiceWebSocket = MISSING
self.ip: str = MISSING
self.port: Tuple[Any, ...] = MISSING
warn_nacl = not has_nacl
supported_modes: Tuple[SupportedModes, ...] = (

View File

@ -886,6 +886,10 @@ class Webhook(BaseWebhook):
Returns the webhooks's hash.
.. describe:: int(x)
Returns the webhooks's ID.
.. versionchanged:: 1.4
Webhooks are now comparable and hashable.

View File

@ -475,6 +475,10 @@ class SyncWebhook(BaseWebhook):
Returns the webhooks's hash.
.. describe:: int(x)
Returns the webhooks's ID.
.. versionchanged:: 1.4
Webhooks are now comparable and hashable.

216
discord/welcome_screen.py Normal file
View File

@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Dict, List, Optional, TYPE_CHECKING, Union, overload
from .utils import _get_as_snowflake, get
from .errors import InvalidArgument
from .partial_emoji import _EmojiTag
__all__ = (
'WelcomeChannel',
'WelcomeScreen',
)
if TYPE_CHECKING:
from .types.welcome_screen import (
WelcomeScreen as WelcomeScreenPayload,
WelcomeScreenChannel as WelcomeScreenChannelPayload,
)
from .abc import Snowflake
from .guild import Guild
from .partial_emoji import PartialEmoji
from .emoji import Emoji
class WelcomeChannel:
"""Represents a :class:`WelcomeScreen` welcome channel.
.. versionadded:: 2.0
Attributes
-----------
channel: :class:`abc.Snowflake`
The guild channel that is being referenced.
description: :class:`str`
The description shown of the channel.
emoji: Optional[:class:`PartialEmoji`, :class:`Emoji`, :class:`str`]
The emoji used beside the channel description.
"""
def __init__(self, *, channel: Snowflake, description: str, emoji: Union[PartialEmoji, Emoji, str] = None):
self.channel = channel
self.description = description
self.emoji = emoji
def __repr__(self) -> str:
return f'<WelcomeChannel channel={self.channel!r} description={self.description!r} emoji={self.emoji!r}>'
@classmethod
def _from_dict(cls, *, data: WelcomeScreenChannelPayload, guild: Guild) -> WelcomeChannel:
channel_id = _get_as_snowflake(data, 'channel_id')
channel = guild.get_channel(channel_id)
description = data['description']
_emoji_id = _get_as_snowflake(data, 'emoji_id')
_emoji_name = data['emoji_name']
if _emoji_id:
# custom
emoji = get(guild.emojis, id=_emoji_id)
else:
# unicode or None
emoji = _emoji_name
return cls(channel=channel, description=description, emoji=emoji) # type: ignore
def to_dict(self) -> WelcomeScreenChannelPayload:
ret: WelcomeScreenChannelPayload = {
'channel_id': self.channel.id,
'description': self.description,
'emoji_id': None,
'emoji_name': None,
}
if isinstance(self.emoji, _EmojiTag):
ret['emoji_id'] = self.emoji.id # type: ignore
ret['emoji_name'] = self.emoji.name # type: ignore
else:
# unicode or None
ret['emoji_name'] = self.emoji
return ret
class WelcomeScreen:
"""Represents a :class:`Guild` welcome screen.
.. versionadded:: 2.0
Attributes
-----------
description: :class:`str`
The description shown on the welcome screen.
welcome_channels: List[:class:`WelcomeChannel`]
The channels shown on the welcome screen.
"""
def __init__(self, *, data: WelcomeScreenPayload, guild: Guild):
self._state = guild._state
self._guild = guild
self._store(data)
def _store(self, data: WelcomeScreenPayload) -> None:
self.description = data['description']
welcome_channels = data.get('welcome_channels', [])
self.welcome_channels = [WelcomeChannel._from_dict(data=wc, guild=self._guild) for wc in welcome_channels]
def __repr__(self) -> str:
return f'<WelcomeScreen description={self.description!r} welcome_channels={self.welcome_channels!r} enabled={self.enabled}>'
@property
def enabled(self) -> bool:
""":class:`bool`: Whether the welcome screen is displayed.
This is equivalent to checking if ``WELCOME_SCREEN_ENABLED``
is present in :attr:`Guild.features`.
"""
return 'WELCOME_SCREEN_ENABLED' in self._guild.features
@overload
async def edit(
self,
*,
description: Optional[str] = ...,
welcome_channels: Optional[List[WelcomeChannel]] = ...,
enabled: Optional[bool] = ...,
) -> None:
...
@overload
async def edit(self) -> None:
...
async def edit(self, **kwargs):
"""|coro|
Edit the welcome screen.
You must have the :attr:`~Permissions.manage_guild` permission in the
guild to do this.
Usage: ::
rules_channel = guild.get_channel(12345678)
announcements_channel = guild.get_channel(87654321)
custom_emoji = utils.get(guild.emojis, name='loudspeaker')
await welcome_screen.edit(
description='This is a very cool community server!',
welcome_channels=[
WelcomeChannel(channel=rules_channel, description='Read the rules!', emoji='👨‍🏫'),
WelcomeChannel(channel=announcements_channel, description='Watch out for announcements!', emoji=custom_emoji),
]
)
.. note::
Welcome channels can only accept custom emojis if :attr:`~Guild.premium_tier` is level 2 or above.
Parameters
------------
description: Optional[:class:`str`]
The template's description.
welcome_channels: Optional[List[:class:`WelcomeChannel`]]
The welcome channels, in their respective order.
enabled: Optional[:class:`bool`]
Whether the welcome screen should be displayed.
Raises
-------
HTTPException
Editing the welcome screen failed failed.
Forbidden
You don't have permissions to edit the welcome screen.
NotFound
This welcome screen does not exist.
"""
try:
welcome_channels = kwargs['welcome_channels']
except KeyError:
pass
else:
welcome_channels_serialised = []
for wc in welcome_channels:
if not isinstance(wc, WelcomeChannel):
raise InvalidArgument('welcome_channels parameter must be a list of WelcomeChannel')
welcome_channels_serialised.append(wc.to_dict())
kwargs['welcome_channels'] = welcome_channels_serialised
if kwargs:
data = await self._state.http.edit_welcome_screen(self._guild.id, kwargs)
self._store(data)

View File

@ -369,6 +369,17 @@ to handle it, which defaults to print a traceback and ignoring the exception.
:param when: When the typing started as an aware datetime in UTC.
:type when: :class:`datetime.datetime`
.. function:: on_raw_typing(payload)
Called when someone begins typing a message. Unlike :func:`on_typing`, this is
called regardless if the user can be found or not. This most often happens
when a user types in DMs.
This requires :attr:`Intents.typing` to be enabled.
:param payload: The raw typing payload.
:type payload: :class:`RawTypingEvent`
.. function:: on_message(message)
Called when a :class:`Message` is created and sent.
@ -3781,6 +3792,22 @@ Template
.. autoclass:: Template()
:members:
WelcomeScreen
~~~~~~~~~~~~~~~
.. attributetable:: WelcomeScreen
.. autoclass:: WelcomeScreen()
:members:
WelcomeChannel
~~~~~~~~~~~~~~~
.. attributetable:: WelcomeChannel
.. autoclass:: WelcomeChannel()
:members:
WidgetChannel
~~~~~~~~~~~~~~~
@ -3846,6 +3873,14 @@ GuildSticker
.. autoclass:: GuildSticker()
:members:
RawTypingEvent
~~~~~~~~~~~~~~~~~~~~~~~
.. attributetable:: RawTypingEvent
.. autoclass:: RawTypingEvent()
:members:
RawMessageDeleteEvent
~~~~~~~~~~~~~~~~~~~~~~~