mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-10-24 10:02:56 +00:00
Refactor Client.run to use asyncio.run
This also adds asynchronous context manager support to allow for idiomatic asyncio usage for the lower-level counterpart. At first I wanted to remove Client.run but I figured that a lot of beginners would have been confused or not enjoyed the verbosity of the newer approach of using async-with.
This commit is contained in:
@@ -26,10 +26,24 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import signal
|
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Coroutine,
|
||||||
|
Dict,
|
||||||
|
Generator,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
@@ -68,6 +82,7 @@ if TYPE_CHECKING:
|
|||||||
from .message import Message
|
from .message import Message
|
||||||
from .member import Member
|
from .member import Member
|
||||||
from .voice_client import VoiceProtocol
|
from .voice_client import VoiceProtocol
|
||||||
|
from types import TracebackType
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
'Client',
|
'Client',
|
||||||
@@ -78,36 +93,8 @@ Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
|
|||||||
|
|
||||||
log: logging.Logger = logging.getLogger(__name__)
|
log: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
|
C = TypeVar('C', bound='Client')
|
||||||
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
|
|
||||||
|
|
||||||
if not tasks:
|
|
||||||
return
|
|
||||||
|
|
||||||
log.info('Cleaning up after %d tasks.', len(tasks))
|
|
||||||
for task in tasks:
|
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
|
|
||||||
log.info('All tasks finished cancelling.')
|
|
||||||
|
|
||||||
for task in tasks:
|
|
||||||
if task.cancelled():
|
|
||||||
continue
|
|
||||||
if task.exception() is not None:
|
|
||||||
loop.call_exception_handler({
|
|
||||||
'message': 'Unhandled exception during Client.run shutdown.',
|
|
||||||
'exception': task.exception(),
|
|
||||||
'task': task
|
|
||||||
})
|
|
||||||
|
|
||||||
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
|
|
||||||
try:
|
|
||||||
_cancel_tasks(loop)
|
|
||||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
|
||||||
finally:
|
|
||||||
log.info('Closing the event loop.')
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
r"""Represents a client connection that connects to Discord.
|
r"""Represents a client connection that connects to Discord.
|
||||||
@@ -200,6 +187,7 @@ class Client:
|
|||||||
loop: :class:`asyncio.AbstractEventLoop`
|
loop: :class:`asyncio.AbstractEventLoop`
|
||||||
The event loop that the client uses for asynchronous operations.
|
The event loop that the client uses for asynchronous operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -207,7 +195,8 @@ class Client:
|
|||||||
**options: Any,
|
**options: Any,
|
||||||
):
|
):
|
||||||
self.ws: DiscordWebSocket = None # type: ignore
|
self.ws: DiscordWebSocket = None # type: ignore
|
||||||
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
|
# this is filled in later
|
||||||
|
self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop
|
||||||
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
|
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
|
||||||
self.shard_id: Optional[int] = options.get('shard_id')
|
self.shard_id: Optional[int] = options.get('shard_id')
|
||||||
self.shard_count: Optional[int] = options.get('shard_count')
|
self.shard_count: Optional[int] = options.get('shard_count')
|
||||||
@@ -216,14 +205,16 @@ class Client:
|
|||||||
proxy: Optional[str] = options.pop('proxy', None)
|
proxy: Optional[str] = options.pop('proxy', None)
|
||||||
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
|
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
|
||||||
unsync_clock: bool = options.pop('assume_unsync_clock', True)
|
unsync_clock: bool = options.pop('assume_unsync_clock', True)
|
||||||
self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop)
|
self.http: HTTPClient = HTTPClient(
|
||||||
|
connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=loop
|
||||||
|
)
|
||||||
|
|
||||||
self._handlers: Dict[str, Callable] = {
|
self._handlers: Dict[str, Callable] = {
|
||||||
'ready': self._handle_ready
|
'ready': self._handle_ready,
|
||||||
}
|
}
|
||||||
|
|
||||||
self._hooks: Dict[str, Callable] = {
|
self._hooks: Dict[str, Callable] = {
|
||||||
'before_identify': self._call_before_identify_hook
|
'before_identify': self._call_before_identify_hook,
|
||||||
}
|
}
|
||||||
|
|
||||||
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
|
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
|
||||||
@@ -244,8 +235,9 @@ class Client:
|
|||||||
return self.ws
|
return self.ws
|
||||||
|
|
||||||
def _get_state(self, **options: Any) -> ConnectionState:
|
def _get_state(self, **options: Any) -> ConnectionState:
|
||||||
return ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
|
return ConnectionState(
|
||||||
hooks=self._hooks, http=self.http, loop=self.loop, **options)
|
dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, loop=self.loop, **options
|
||||||
|
)
|
||||||
|
|
||||||
def _handle_ready(self) -> None:
|
def _handle_ready(self) -> None:
|
||||||
self._ready.set()
|
self._ready.set()
|
||||||
@@ -343,7 +335,9 @@ class Client:
|
|||||||
""":class:`bool`: Specifies if the client's internal cache is ready for use."""
|
""":class:`bool`: Specifies if the client's internal cache is ready for use."""
|
||||||
return self._ready.is_set()
|
return self._ready.is_set()
|
||||||
|
|
||||||
async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None:
|
async def _run_event(
|
||||||
|
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
try:
|
try:
|
||||||
await coro(*args, **kwargs)
|
await coro(*args, **kwargs)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@@ -354,7 +348,9 @@ class Client:
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task:
|
def _schedule_event(
|
||||||
|
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
|
||||||
|
) -> asyncio.Task:
|
||||||
wrapped = self._run_event(coro, event_name, *args, **kwargs)
|
wrapped = self._run_event(coro, event_name, *args, **kwargs)
|
||||||
# Schedules the task
|
# Schedules the task
|
||||||
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
|
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
|
||||||
@@ -466,7 +462,8 @@ class Client:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
log.info('logging in using static token')
|
log.info('logging in using static token')
|
||||||
|
self.loop = loop = asyncio.get_running_loop()
|
||||||
|
self._connection.loop = loop
|
||||||
data = await self.http.static_login(token.strip())
|
data = await self.http.static_login(token.strip())
|
||||||
self._connection.user = ClientUser(state=self._connection, data=data)
|
self._connection.user = ClientUser(state=self._connection, data=data)
|
||||||
|
|
||||||
@@ -512,12 +509,14 @@ class Client:
|
|||||||
self.dispatch('disconnect')
|
self.dispatch('disconnect')
|
||||||
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
|
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
|
||||||
continue
|
continue
|
||||||
except (OSError,
|
except (
|
||||||
HTTPException,
|
OSError,
|
||||||
GatewayNotFound,
|
HTTPException,
|
||||||
ConnectionClosed,
|
GatewayNotFound,
|
||||||
aiohttp.ClientError,
|
ConnectionClosed,
|
||||||
asyncio.TimeoutError) as exc:
|
aiohttp.ClientError,
|
||||||
|
asyncio.TimeoutError,
|
||||||
|
) as exc:
|
||||||
|
|
||||||
self.dispatch('disconnect')
|
self.dispatch('disconnect')
|
||||||
if not reconnect:
|
if not reconnect:
|
||||||
@@ -558,6 +557,22 @@ class Client:
|
|||||||
"""|coro|
|
"""|coro|
|
||||||
|
|
||||||
Closes the connection to Discord.
|
Closes the connection to Discord.
|
||||||
|
|
||||||
|
Instead of calling this directly, it is recommended to use the asynchronous context
|
||||||
|
manager to allow resources to be cleaned up automatically:
|
||||||
|
|
||||||
|
.. code-block:: python3
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with Client() as client:
|
||||||
|
await client.login(token)
|
||||||
|
await client.connect()
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
|
|
||||||
|
|
||||||
|
.. versionchanged:: 2.0
|
||||||
|
The client can now be closed with an asynchronous context manager
|
||||||
"""
|
"""
|
||||||
if self._closed:
|
if self._closed:
|
||||||
return
|
return
|
||||||
@@ -589,36 +604,47 @@ class Client:
|
|||||||
self._connection.clear()
|
self._connection.clear()
|
||||||
self.http.recreate()
|
self.http.recreate()
|
||||||
|
|
||||||
|
async def __aenter__(self: C) -> C:
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[Type[BaseException]],
|
||||||
|
exc_value: Optional[BaseException],
|
||||||
|
traceback: Optional[TracebackType],
|
||||||
|
) -> None:
|
||||||
|
await self.close()
|
||||||
|
|
||||||
async def start(self, token: str, *, reconnect: bool = True) -> None:
|
async def start(self, token: str, *, reconnect: bool = True) -> None:
|
||||||
"""|coro|
|
"""|coro|
|
||||||
|
|
||||||
A shorthand coroutine for :meth:`login` + :meth:`connect`.
|
A shorthand function equivalent to the following:
|
||||||
|
|
||||||
Raises
|
.. code-block:: python3
|
||||||
-------
|
|
||||||
TypeError
|
async with client:
|
||||||
An unexpected keyword argument was received.
|
await client.login(token)
|
||||||
|
await client.connect()
|
||||||
|
|
||||||
|
This closes the client when it returns.
|
||||||
"""
|
"""
|
||||||
await self.login(token)
|
try:
|
||||||
await self.connect(reconnect=reconnect)
|
await self.login(token)
|
||||||
|
await self.connect(reconnect=reconnect)
|
||||||
|
finally:
|
||||||
|
await self.close()
|
||||||
|
|
||||||
def run(self, *args: Any, **kwargs: Any) -> None:
|
def run(self, *args: Any, **kwargs: Any) -> None:
|
||||||
"""A blocking call that abstracts away the event loop
|
"""A convenience blocking call that abstracts away the event loop
|
||||||
initialisation from you.
|
initialisation from you.
|
||||||
|
|
||||||
If you want more control over the event loop then this
|
If you want more control over the event loop then this
|
||||||
function should not be used. Use :meth:`start` coroutine
|
function should not be used. Use :meth:`start` coroutine
|
||||||
or :meth:`connect` + :meth:`login`.
|
or :meth:`connect` + :meth:`login`.
|
||||||
|
|
||||||
Roughly Equivalent to: ::
|
Equivalent to: ::
|
||||||
|
|
||||||
try:
|
asyncio.run(bot.start(*args, **kwargs))
|
||||||
loop.run_until_complete(start(*args, **kwargs))
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
loop.run_until_complete(close())
|
|
||||||
# cancel all tasks lingering
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
.. warning::
|
.. warning::
|
||||||
|
|
||||||
@@ -626,41 +652,7 @@ class Client:
|
|||||||
is blocking. That means that registration of events or anything being
|
is blocking. That means that registration of events or anything being
|
||||||
called after this function call will not execute until it returns.
|
called after this function call will not execute until it returns.
|
||||||
"""
|
"""
|
||||||
loop = self.loop
|
asyncio.run(self.start(*args, **kwargs))
|
||||||
|
|
||||||
try:
|
|
||||||
loop.add_signal_handler(signal.SIGINT, lambda: loop.stop())
|
|
||||||
loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop())
|
|
||||||
except NotImplementedError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def runner():
|
|
||||||
try:
|
|
||||||
await self.start(*args, **kwargs)
|
|
||||||
finally:
|
|
||||||
if not self.is_closed():
|
|
||||||
await self.close()
|
|
||||||
|
|
||||||
def stop_loop_on_completion(f):
|
|
||||||
loop.stop()
|
|
||||||
|
|
||||||
future = asyncio.ensure_future(runner(), loop=loop)
|
|
||||||
future.add_done_callback(stop_loop_on_completion)
|
|
||||||
try:
|
|
||||||
loop.run_forever()
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
log.info('Received signal to terminate bot and event loop.')
|
|
||||||
finally:
|
|
||||||
future.remove_done_callback(stop_loop_on_completion)
|
|
||||||
log.info('Cleaning up tasks.')
|
|
||||||
_cleanup_loop(loop)
|
|
||||||
|
|
||||||
if not future.cancelled():
|
|
||||||
try:
|
|
||||||
return future.result()
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
# I am unsure why this gets raised here but suppress it anyway
|
|
||||||
return None
|
|
||||||
|
|
||||||
# properties
|
# properties
|
||||||
|
|
||||||
@@ -973,8 +965,10 @@ class Client:
|
|||||||
|
|
||||||
future = self.loop.create_future()
|
future = self.loop.create_future()
|
||||||
if check is None:
|
if check is None:
|
||||||
|
|
||||||
def _check(*args):
|
def _check(*args):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
check = _check
|
check = _check
|
||||||
|
|
||||||
ev = event.lower()
|
ev = event.lower()
|
||||||
@@ -1083,7 +1077,7 @@ class Client:
|
|||||||
*,
|
*,
|
||||||
limit: Optional[int] = 100,
|
limit: Optional[int] = 100,
|
||||||
before: SnowflakeTime = None,
|
before: SnowflakeTime = None,
|
||||||
after: SnowflakeTime = None
|
after: SnowflakeTime = None,
|
||||||
) -> GuildIterator:
|
) -> GuildIterator:
|
||||||
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
|
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
|
||||||
|
|
||||||
@@ -1163,7 +1157,7 @@ class Client:
|
|||||||
"""
|
"""
|
||||||
code = utils.resolve_template(code)
|
code = utils.resolve_template(code)
|
||||||
data = await self.http.get_template(code)
|
data = await self.http.get_template(code)
|
||||||
return Template(data=data, state=self._connection) # type: ignore
|
return Template(data=data, state=self._connection) # type: ignore
|
||||||
|
|
||||||
async def fetch_guild(self, guild_id: int) -> Guild:
|
async def fetch_guild(self, guild_id: int) -> Guild:
|
||||||
"""|coro|
|
"""|coro|
|
||||||
@@ -1284,7 +1278,9 @@ class Client:
|
|||||||
|
|
||||||
# Invite management
|
# Invite management
|
||||||
|
|
||||||
async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite:
|
async def fetch_invite(
|
||||||
|
self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True
|
||||||
|
) -> Invite:
|
||||||
"""|coro|
|
"""|coro|
|
||||||
|
|
||||||
Gets an :class:`.Invite` from a discord.gg URL or ID.
|
Gets an :class:`.Invite` from a discord.gg URL or ID.
|
||||||
@@ -1520,7 +1516,7 @@ class Client:
|
|||||||
"""
|
"""
|
||||||
data = await self.http.get_sticker(sticker_id)
|
data = await self.http.get_sticker(sticker_id)
|
||||||
cls, _ = _sticker_factory(data['type']) # type: ignore
|
cls, _ = _sticker_factory(data['type']) # type: ignore
|
||||||
return cls(state=self._connection, data=data) # type: ignore
|
return cls(state=self._connection, data=data) # type: ignore
|
||||||
|
|
||||||
async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
|
async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
|
||||||
"""|coro|
|
"""|coro|
|
||||||
|
@@ -167,7 +167,7 @@ class HTTPClient:
|
|||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||||
unsync_clock: bool = True
|
unsync_clock: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
|
self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop # filled in static_login
|
||||||
self.connector = connector
|
self.connector = connector
|
||||||
self.__session: aiohttp.ClientSession = MISSING # filled in static_login
|
self.__session: aiohttp.ClientSession = MISSING # filled in static_login
|
||||||
self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
|
self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
|
||||||
@@ -371,6 +371,7 @@ class HTTPClient:
|
|||||||
|
|
||||||
async def static_login(self, token: str) -> user.User:
|
async def static_login(self, token: str) -> user.User:
|
||||||
# Necessary to get aiohttp to stop complaining about session creation
|
# Necessary to get aiohttp to stop complaining about session creation
|
||||||
|
self.loop = asyncio.get_running_loop()
|
||||||
self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse)
|
self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse)
|
||||||
old_token = self.token
|
old_token = self.token
|
||||||
self.token = token
|
self.token = token
|
||||||
|
Reference in New Issue
Block a user