mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-22 00:34:06 +00:00
Add asynchronous context manager support for Client
This commit is contained in:
parent
93af158b0c
commit
c02a3c0bb2
@ -41,6 +41,7 @@ from typing import (
|
||||
Sequence,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
@ -76,6 +77,8 @@ from .threads import Thread
|
||||
from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
from types import TracebackType
|
||||
from .types.guild import Guild as GuildPayload
|
||||
from .abc import SnowflakeTime, Snowflake, PrivateChannel
|
||||
from .guild import GuildChannel
|
||||
@ -180,10 +183,7 @@ class Client:
|
||||
The websocket gateway the client is currently connected to. Could be ``None``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**options: Any,
|
||||
):
|
||||
def __init__(self, **options: Any) -> None:
|
||||
self.loop: asyncio.AbstractEventLoop = MISSING
|
||||
# self.ws is set in the connect method
|
||||
self.ws: DiscordWebSocket = None # type: ignore
|
||||
@ -216,6 +216,19 @@ class Client:
|
||||
VoiceClient.warn_nacl = False
|
||||
_log.warning("PyNaCl is not installed, voice will NOT be supported")
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
self.loop = asyncio.get_running_loop()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
if not self.is_closed():
|
||||
await self.close()
|
||||
|
||||
# internals
|
||||
|
||||
def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket:
|
||||
@ -601,12 +614,8 @@ class Client:
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.http.loop = self.loop
|
||||
self._connection.loop = self.loop
|
||||
try:
|
||||
await self.login(token)
|
||||
await self.connect(reconnect=reconnect)
|
||||
finally:
|
||||
if not self.is_closed():
|
||||
await self.close()
|
||||
await self.login(token)
|
||||
await self.connect(reconnect=reconnect)
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
"""|coro|
|
||||
@ -645,8 +654,13 @@ class Client:
|
||||
is blocking. That means that registration of events or anything being
|
||||
called after this function call will not execute until it returns.
|
||||
"""
|
||||
|
||||
async def runner():
|
||||
async with self:
|
||||
await self.start(*args, **kwargs)
|
||||
|
||||
try:
|
||||
asyncio.run(self.start(*args, **kwargs))
|
||||
asyncio.run(runner())
|
||||
except KeyboardInterrupt:
|
||||
# nothing to do here
|
||||
# `asyncio.run` handles the loop cleanup
|
||||
|
Loading…
x
Reference in New Issue
Block a user