Refactor Client.run to use asyncio.run
This also adds asynchronous context manager support to allow for idiomatic asyncio usage for the lower-level counterpart. At first I wanted to remove Client.run but I figured that a lot of beginners would have been confused or not enjoyed the verbosity of the newer approach of using async-with.
This commit is contained in:
		| @@ -26,10 +26,24 @@ from __future__ import annotations | |||||||
|  |  | ||||||
| import asyncio | import asyncio | ||||||
| import logging | import logging | ||||||
| import signal |  | ||||||
| import sys | import sys | ||||||
| import traceback | import traceback | ||||||
| from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union | from typing import ( | ||||||
|  |     Any, | ||||||
|  |     Callable, | ||||||
|  |     Coroutine, | ||||||
|  |     Dict, | ||||||
|  |     Generator, | ||||||
|  |     Iterable, | ||||||
|  |     List, | ||||||
|  |     Optional, | ||||||
|  |     Sequence, | ||||||
|  |     TYPE_CHECKING, | ||||||
|  |     Tuple, | ||||||
|  |     TypeVar, | ||||||
|  |     Type, | ||||||
|  |     Union, | ||||||
|  | ) | ||||||
|  |  | ||||||
| import aiohttp | import aiohttp | ||||||
|  |  | ||||||
| @@ -68,6 +82,7 @@ if TYPE_CHECKING: | |||||||
|     from .message import Message |     from .message import Message | ||||||
|     from .member import Member |     from .member import Member | ||||||
|     from .voice_client import VoiceProtocol |     from .voice_client import VoiceProtocol | ||||||
|  |     from types import TracebackType | ||||||
|  |  | ||||||
| __all__ = ( | __all__ = ( | ||||||
|     'Client', |     'Client', | ||||||
| @@ -78,36 +93,8 @@ Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]]) | |||||||
|  |  | ||||||
| log: logging.Logger = logging.getLogger(__name__) | log: logging.Logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None: | C = TypeVar('C', bound='Client') | ||||||
|     tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()} |  | ||||||
|  |  | ||||||
|     if not tasks: |  | ||||||
|         return |  | ||||||
|  |  | ||||||
|     log.info('Cleaning up after %d tasks.', len(tasks)) |  | ||||||
|     for task in tasks: |  | ||||||
|         task.cancel() |  | ||||||
|  |  | ||||||
|     loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) |  | ||||||
|     log.info('All tasks finished cancelling.') |  | ||||||
|  |  | ||||||
|     for task in tasks: |  | ||||||
|         if task.cancelled(): |  | ||||||
|             continue |  | ||||||
|         if task.exception() is not None: |  | ||||||
|             loop.call_exception_handler({ |  | ||||||
|                 'message': 'Unhandled exception during Client.run shutdown.', |  | ||||||
|                 'exception': task.exception(), |  | ||||||
|                 'task': task |  | ||||||
|             }) |  | ||||||
|  |  | ||||||
| def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: |  | ||||||
|     try: |  | ||||||
|         _cancel_tasks(loop) |  | ||||||
|         loop.run_until_complete(loop.shutdown_asyncgens()) |  | ||||||
|     finally: |  | ||||||
|         log.info('Closing the event loop.') |  | ||||||
|         loop.close() |  | ||||||
|  |  | ||||||
| class Client: | class Client: | ||||||
|     r"""Represents a client connection that connects to Discord. |     r"""Represents a client connection that connects to Discord. | ||||||
| @@ -200,6 +187,7 @@ class Client: | |||||||
|     loop: :class:`asyncio.AbstractEventLoop` |     loop: :class:`asyncio.AbstractEventLoop` | ||||||
|         The event loop that the client uses for asynchronous operations. |         The event loop that the client uses for asynchronous operations. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         *, |         *, | ||||||
| @@ -207,7 +195,8 @@ class Client: | |||||||
|         **options: Any, |         **options: Any, | ||||||
|     ): |     ): | ||||||
|         self.ws: DiscordWebSocket = None  # type: ignore |         self.ws: DiscordWebSocket = None  # type: ignore | ||||||
|         self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop |         # this is filled in later | ||||||
|  |         self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop | ||||||
|         self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} |         self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} | ||||||
|         self.shard_id: Optional[int] = options.get('shard_id') |         self.shard_id: Optional[int] = options.get('shard_id') | ||||||
|         self.shard_count: Optional[int] = options.get('shard_count') |         self.shard_count: Optional[int] = options.get('shard_count') | ||||||
| @@ -216,14 +205,16 @@ class Client: | |||||||
|         proxy: Optional[str] = options.pop('proxy', None) |         proxy: Optional[str] = options.pop('proxy', None) | ||||||
|         proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) |         proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) | ||||||
|         unsync_clock: bool = options.pop('assume_unsync_clock', True) |         unsync_clock: bool = options.pop('assume_unsync_clock', True) | ||||||
|         self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop) |         self.http: HTTPClient = HTTPClient( | ||||||
|  |             connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=loop | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         self._handlers: Dict[str, Callable] = { |         self._handlers: Dict[str, Callable] = { | ||||||
|             'ready': self._handle_ready |             'ready': self._handle_ready, | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         self._hooks: Dict[str, Callable] = { |         self._hooks: Dict[str, Callable] = { | ||||||
|             'before_identify': self._call_before_identify_hook |             'before_identify': self._call_before_identify_hook, | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         self._enable_debug_events: bool = options.pop('enable_debug_events', False) |         self._enable_debug_events: bool = options.pop('enable_debug_events', False) | ||||||
| @@ -244,8 +235,9 @@ class Client: | |||||||
|         return self.ws |         return self.ws | ||||||
|  |  | ||||||
|     def _get_state(self, **options: Any) -> ConnectionState: |     def _get_state(self, **options: Any) -> ConnectionState: | ||||||
|         return ConnectionState(dispatch=self.dispatch, handlers=self._handlers, |         return ConnectionState( | ||||||
|                                hooks=self._hooks, http=self.http, loop=self.loop, **options) |             dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, loop=self.loop, **options | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def _handle_ready(self) -> None: |     def _handle_ready(self) -> None: | ||||||
|         self._ready.set() |         self._ready.set() | ||||||
| @@ -343,7 +335,9 @@ class Client: | |||||||
|         """:class:`bool`: Specifies if the client's internal cache is ready for use.""" |         """:class:`bool`: Specifies if the client's internal cache is ready for use.""" | ||||||
|         return self._ready.is_set() |         return self._ready.is_set() | ||||||
|  |  | ||||||
|     async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None: |     async def _run_event( | ||||||
|  |         self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any | ||||||
|  |     ) -> None: | ||||||
|         try: |         try: | ||||||
|             await coro(*args, **kwargs) |             await coro(*args, **kwargs) | ||||||
|         except asyncio.CancelledError: |         except asyncio.CancelledError: | ||||||
| @@ -354,7 +348,9 @@ class Client: | |||||||
|             except asyncio.CancelledError: |             except asyncio.CancelledError: | ||||||
|                 pass |                 pass | ||||||
|  |  | ||||||
|     def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task: |     def _schedule_event( | ||||||
|  |         self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any | ||||||
|  |     ) -> asyncio.Task: | ||||||
|         wrapped = self._run_event(coro, event_name, *args, **kwargs) |         wrapped = self._run_event(coro, event_name, *args, **kwargs) | ||||||
|         # Schedules the task |         # Schedules the task | ||||||
|         return asyncio.create_task(wrapped, name=f'discord.py: {event_name}') |         return asyncio.create_task(wrapped, name=f'discord.py: {event_name}') | ||||||
| @@ -466,7 +462,8 @@ class Client: | |||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         log.info('logging in using static token') |         log.info('logging in using static token') | ||||||
|  |         self.loop = loop = asyncio.get_running_loop() | ||||||
|  |         self._connection.loop = loop | ||||||
|         data = await self.http.static_login(token.strip()) |         data = await self.http.static_login(token.strip()) | ||||||
|         self._connection.user = ClientUser(state=self._connection, data=data) |         self._connection.user = ClientUser(state=self._connection, data=data) | ||||||
|  |  | ||||||
| @@ -512,12 +509,14 @@ class Client: | |||||||
|                 self.dispatch('disconnect') |                 self.dispatch('disconnect') | ||||||
|                 ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id) |                 ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id) | ||||||
|                 continue |                 continue | ||||||
|             except (OSError, |             except ( | ||||||
|  |                 OSError, | ||||||
|                 HTTPException, |                 HTTPException, | ||||||
|                 GatewayNotFound, |                 GatewayNotFound, | ||||||
|                 ConnectionClosed, |                 ConnectionClosed, | ||||||
|                 aiohttp.ClientError, |                 aiohttp.ClientError, | ||||||
|                     asyncio.TimeoutError) as exc: |                 asyncio.TimeoutError, | ||||||
|  |             ) as exc: | ||||||
|  |  | ||||||
|                 self.dispatch('disconnect') |                 self.dispatch('disconnect') | ||||||
|                 if not reconnect: |                 if not reconnect: | ||||||
| @@ -558,6 +557,22 @@ class Client: | |||||||
|         """|coro| |         """|coro| | ||||||
|  |  | ||||||
|         Closes the connection to Discord. |         Closes the connection to Discord. | ||||||
|  |  | ||||||
|  |         Instead of calling this directly, it is recommended to use the asynchronous context | ||||||
|  |         manager to allow resources to be cleaned up automatically: | ||||||
|  |  | ||||||
|  |         .. code-block:: python3 | ||||||
|  |  | ||||||
|  |             async def main(): | ||||||
|  |                 async with Client() as client: | ||||||
|  |                     await client.login(token) | ||||||
|  |                     await client.connect() | ||||||
|  |  | ||||||
|  |             asyncio.run(main()) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         .. versionchanged:: 2.0 | ||||||
|  |             The client can now be closed with an asynchronous context manager | ||||||
|         """ |         """ | ||||||
|         if self._closed: |         if self._closed: | ||||||
|             return |             return | ||||||
| @@ -589,36 +604,47 @@ class Client: | |||||||
|         self._connection.clear() |         self._connection.clear() | ||||||
|         self.http.recreate() |         self.http.recreate() | ||||||
|  |  | ||||||
|  |     async def __aenter__(self: C) -> C: | ||||||
|  |         return self | ||||||
|  |  | ||||||
|  |     async def __aexit__( | ||||||
|  |         self, | ||||||
|  |         exc_type: Optional[Type[BaseException]], | ||||||
|  |         exc_value: Optional[BaseException], | ||||||
|  |         traceback: Optional[TracebackType], | ||||||
|  |     ) -> None: | ||||||
|  |         await self.close() | ||||||
|  |  | ||||||
|     async def start(self, token: str, *, reconnect: bool = True) -> None: |     async def start(self, token: str, *, reconnect: bool = True) -> None: | ||||||
|         """|coro| |         """|coro| | ||||||
|  |  | ||||||
|         A shorthand coroutine for :meth:`login` + :meth:`connect`. |         A shorthand function equivalent to the following: | ||||||
|  |  | ||||||
|         Raises |         .. code-block:: python3 | ||||||
|         ------- |  | ||||||
|         TypeError |             async with client: | ||||||
|             An unexpected keyword argument was received. |                 await client.login(token) | ||||||
|  |                 await client.connect() | ||||||
|  |  | ||||||
|  |         This closes the client when it returns. | ||||||
|         """ |         """ | ||||||
|  |         try: | ||||||
|             await self.login(token) |             await self.login(token) | ||||||
|             await self.connect(reconnect=reconnect) |             await self.connect(reconnect=reconnect) | ||||||
|  |         finally: | ||||||
|  |             await self.close() | ||||||
|  |  | ||||||
|     def run(self, *args: Any, **kwargs: Any) -> None: |     def run(self, *args: Any, **kwargs: Any) -> None: | ||||||
|         """A blocking call that abstracts away the event loop |         """A convenience blocking call that abstracts away the event loop | ||||||
|         initialisation from you. |         initialisation from you. | ||||||
|  |  | ||||||
|         If you want more control over the event loop then this |         If you want more control over the event loop then this | ||||||
|         function should not be used. Use :meth:`start` coroutine |         function should not be used. Use :meth:`start` coroutine | ||||||
|         or :meth:`connect` + :meth:`login`. |         or :meth:`connect` + :meth:`login`. | ||||||
|  |  | ||||||
|         Roughly Equivalent to: :: |         Equivalent to: :: | ||||||
|  |  | ||||||
|             try: |             asyncio.run(bot.start(*args, **kwargs)) | ||||||
|                 loop.run_until_complete(start(*args, **kwargs)) |  | ||||||
|             except KeyboardInterrupt: |  | ||||||
|                 loop.run_until_complete(close()) |  | ||||||
|                 # cancel all tasks lingering |  | ||||||
|             finally: |  | ||||||
|                 loop.close() |  | ||||||
|  |  | ||||||
|         .. warning:: |         .. warning:: | ||||||
|  |  | ||||||
| @@ -626,41 +652,7 @@ class Client: | |||||||
|             is blocking. That means that registration of events or anything being |             is blocking. That means that registration of events or anything being | ||||||
|             called after this function call will not execute until it returns. |             called after this function call will not execute until it returns. | ||||||
|         """ |         """ | ||||||
|         loop = self.loop |         asyncio.run(self.start(*args, **kwargs)) | ||||||
|  |  | ||||||
|         try: |  | ||||||
|             loop.add_signal_handler(signal.SIGINT, lambda: loop.stop()) |  | ||||||
|             loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop()) |  | ||||||
|         except NotImplementedError: |  | ||||||
|             pass |  | ||||||
|  |  | ||||||
|         async def runner(): |  | ||||||
|             try: |  | ||||||
|                 await self.start(*args, **kwargs) |  | ||||||
|             finally: |  | ||||||
|                 if not self.is_closed(): |  | ||||||
|                     await self.close() |  | ||||||
|  |  | ||||||
|         def stop_loop_on_completion(f): |  | ||||||
|             loop.stop() |  | ||||||
|  |  | ||||||
|         future = asyncio.ensure_future(runner(), loop=loop) |  | ||||||
|         future.add_done_callback(stop_loop_on_completion) |  | ||||||
|         try: |  | ||||||
|             loop.run_forever() |  | ||||||
|         except KeyboardInterrupt: |  | ||||||
|             log.info('Received signal to terminate bot and event loop.') |  | ||||||
|         finally: |  | ||||||
|             future.remove_done_callback(stop_loop_on_completion) |  | ||||||
|             log.info('Cleaning up tasks.') |  | ||||||
|             _cleanup_loop(loop) |  | ||||||
|  |  | ||||||
|         if not future.cancelled(): |  | ||||||
|             try: |  | ||||||
|                 return future.result() |  | ||||||
|             except KeyboardInterrupt: |  | ||||||
|                 # I am unsure why this gets raised here but suppress it anyway |  | ||||||
|                 return None |  | ||||||
|  |  | ||||||
|     # properties |     # properties | ||||||
|  |  | ||||||
| @@ -973,8 +965,10 @@ class Client: | |||||||
|  |  | ||||||
|         future = self.loop.create_future() |         future = self.loop.create_future() | ||||||
|         if check is None: |         if check is None: | ||||||
|  |  | ||||||
|             def _check(*args): |             def _check(*args): | ||||||
|                 return True |                 return True | ||||||
|  |  | ||||||
|             check = _check |             check = _check | ||||||
|  |  | ||||||
|         ev = event.lower() |         ev = event.lower() | ||||||
| @@ -1083,7 +1077,7 @@ class Client: | |||||||
|         *, |         *, | ||||||
|         limit: Optional[int] = 100, |         limit: Optional[int] = 100, | ||||||
|         before: SnowflakeTime = None, |         before: SnowflakeTime = None, | ||||||
|         after: SnowflakeTime = None |         after: SnowflakeTime = None, | ||||||
|     ) -> GuildIterator: |     ) -> GuildIterator: | ||||||
|         """Retrieves an :class:`.AsyncIterator` that enables receiving your guilds. |         """Retrieves an :class:`.AsyncIterator` that enables receiving your guilds. | ||||||
|  |  | ||||||
| @@ -1284,7 +1278,9 @@ class Client: | |||||||
|  |  | ||||||
|     # Invite management |     # Invite management | ||||||
|  |  | ||||||
|     async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite: |     async def fetch_invite( | ||||||
|  |         self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True | ||||||
|  |     ) -> Invite: | ||||||
|         """|coro| |         """|coro| | ||||||
|  |  | ||||||
|         Gets an :class:`.Invite` from a discord.gg URL or ID. |         Gets an :class:`.Invite` from a discord.gg URL or ID. | ||||||
|   | |||||||
| @@ -167,7 +167,7 @@ class HTTPClient: | |||||||
|         loop: Optional[asyncio.AbstractEventLoop] = None, |         loop: Optional[asyncio.AbstractEventLoop] = None, | ||||||
|         unsync_clock: bool = True |         unsync_clock: bool = True | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop |         self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop  # filled in static_login | ||||||
|         self.connector = connector |         self.connector = connector | ||||||
|         self.__session: aiohttp.ClientSession = MISSING  # filled in static_login |         self.__session: aiohttp.ClientSession = MISSING  # filled in static_login | ||||||
|         self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() |         self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() | ||||||
| @@ -371,6 +371,7 @@ class HTTPClient: | |||||||
|  |  | ||||||
|     async def static_login(self, token: str) -> user.User: |     async def static_login(self, token: str) -> user.User: | ||||||
|         # Necessary to get aiohttp to stop complaining about session creation |         # Necessary to get aiohttp to stop complaining about session creation | ||||||
|  |         self.loop = asyncio.get_running_loop() | ||||||
|         self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse) |         self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse) | ||||||
|         old_token = self.token |         old_token = self.token | ||||||
|         self.token = token |         self.token = token | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user