Refactor loop code to allow usage of asyncio.run

This commit is contained in:
Han Seung Min - 한승민
2022-03-13 14:24:14 +05:30
committed by GitHub
parent 196db33e9f
commit 93af158b0c
9 changed files with 44 additions and 136 deletions

View File

@@ -27,7 +27,6 @@ from __future__ import annotations
import asyncio
import datetime
import logging
import signal
import sys
import traceback
from typing import (
@@ -97,41 +96,6 @@ Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
_log = logging.getLogger(__name__)
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
if not tasks:
return
_log.info('Cleaning up after %d tasks.', len(tasks))
for task in tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
_log.info('All tasks finished cancelling.')
for task in tasks:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
'message': 'Unhandled exception during Client.run shutdown.',
'exception': task.exception(),
'task': task,
}
)
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
try:
_cancel_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
_log.info('Closing the event loop.')
loop.close()
class Client:
r"""Represents a client connection that connects to Discord.
This class is used to interact with the Discord WebSocket and API.
@@ -146,12 +110,6 @@ class Client:
.. versionchanged:: 1.3
Allow disabling the message cache and change the default size to ``1000``.
loop: Optional[:class:`asyncio.AbstractEventLoop`]
The :class:`asyncio.AbstractEventLoop` to use for asynchronous operations.
Defaults to ``None``, in which case the default event loop is used via
:func:`asyncio.get_event_loop()`.
connector: Optional[:class:`aiohttp.BaseConnector`]
The connector to use for connection pooling.
proxy: Optional[:class:`str`]
Proxy URL.
proxy_auth: Optional[:class:`aiohttp.BasicAuth`]
@@ -220,30 +178,23 @@ class Client:
-----------
ws
The websocket gateway the client is currently connected to. Could be ``None``.
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the client uses for asynchronous operations.
"""
def __init__(
self,
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
**options: Any,
):
self.loop: asyncio.AbstractEventLoop = MISSING
# self.ws is set in the connect method
self.ws: DiscordWebSocket = None # type: ignore
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
self.shard_id: Optional[int] = options.get('shard_id')
self.shard_count: Optional[int] = options.get('shard_count')
connector: Optional[aiohttp.BaseConnector] = options.pop('connector', None)
proxy: Optional[str] = options.pop('proxy', None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
unsync_clock: bool = options.pop('assume_unsync_clock', True)
self.http: HTTPClient = HTTPClient(
connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop
)
self.http: HTTPClient = HTTPClient(self.loop, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock)
self._handlers: Dict[str, Callable] = {
'ready': self._handle_ready,
@@ -399,7 +350,7 @@ class Client:
) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs)
# Schedules the task
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
return self.loop.create_task(wrapped, name=f'discord.py: {event_name}')
def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
_log.debug('Dispatching event %s', event)
@@ -623,6 +574,7 @@ class Client:
await self.http.close()
self._ready.clear()
self.loop = MISSING
def clear(self) -> None:
"""Clears the internal state of the bot.
@@ -646,8 +598,15 @@ class Client:
TypeError
An unexpected keyword argument was received.
"""
await self.login(token)
await self.connect(reconnect=reconnect)
self.loop = asyncio.get_running_loop()
self.http.loop = self.loop
self._connection.loop = self.loop
try:
await self.login(token)
await self.connect(reconnect=reconnect)
finally:
if not self.is_closed():
await self.close()
async def setup_hook(self) -> None:
"""|coro|
@@ -676,12 +635,9 @@ class Client:
Roughly Equivalent to: ::
try:
loop.run_until_complete(start(*args, **kwargs))
asyncio.run(self.start(*args, **kwargs))
except KeyboardInterrupt:
loop.run_until_complete(close())
# cancel all tasks lingering
finally:
loop.close()
return
.. warning::
@@ -689,41 +645,13 @@ class Client:
is blocking. That means that registration of events or anything being
called after this function call will not execute until it returns.
"""
loop = self.loop
try:
loop.add_signal_handler(signal.SIGINT, lambda: loop.stop())
loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop())
except NotImplementedError:
pass
async def runner():
try:
await self.start(*args, **kwargs)
finally:
if not self.is_closed():
await self.close()
def stop_loop_on_completion(f):
loop.stop()
future = asyncio.ensure_future(runner(), loop=loop)
future.add_done_callback(stop_loop_on_completion)
try:
loop.run_forever()
asyncio.run(self.start(*args, **kwargs))
except KeyboardInterrupt:
_log.info('Received signal to terminate bot and event loop.')
finally:
future.remove_done_callback(stop_loop_on_completion)
_log.info('Cleaning up tasks.')
_cleanup_loop(loop)
if not future.cancelled():
try:
return future.result()
except KeyboardInterrupt:
# I am unsure why this gets raised here but suppress it anyway
return None
# nothing to do here
# `asyncio.run` handles the loop cleanup
# and `self.start` closes all sockets and the HTTPClient instance.
return
# properties
@@ -1324,7 +1252,7 @@ class Client:
"""
code = utils.resolve_template(code)
data = await self.http.get_template(code)
return Template(data=data, state=self._connection) # type: ignore
return Template(data=data, state=self._connection)
async def fetch_guild(self, guild_id: int, /, *, with_counts: bool = True) -> Guild:
"""|coro|