mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-06-07 20:28:38 +00:00
Ensure Client.close() has finished in __aexit__
This wraps the closing behavior in a task. Subsequent callers of .close() now await that same close finishing rather than short circuiting. This prevents a user-called close outside of __aexit__ from not finishing before no longer having a running event loop.
This commit is contained in:
parent
8fd1fd805a
commit
88f62d85d2
@ -287,7 +287,7 @@ class Client:
|
|||||||
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
|
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
|
||||||
self._connection: ConnectionState[Self] = self._get_state(intents=intents, **options)
|
self._connection: ConnectionState[Self] = self._get_state(intents=intents, **options)
|
||||||
self._connection.shard_count = self.shard_count
|
self._connection.shard_count = self.shard_count
|
||||||
self._closed: bool = False
|
self._closing_task: Optional[asyncio.Task[None]] = None
|
||||||
self._ready: asyncio.Event = MISSING
|
self._ready: asyncio.Event = MISSING
|
||||||
self._application: Optional[AppInfo] = None
|
self._application: Optional[AppInfo] = None
|
||||||
self._connection._get_websocket = self._get_websocket
|
self._connection._get_websocket = self._get_websocket
|
||||||
@ -307,7 +307,10 @@ class Client:
|
|||||||
exc_value: Optional[BaseException],
|
exc_value: Optional[BaseException],
|
||||||
traceback: Optional[TracebackType],
|
traceback: Optional[TracebackType],
|
||||||
) -> None:
|
) -> None:
|
||||||
if not self.is_closed():
|
# This avoids double-calling a user-provided .close()
|
||||||
|
if self._closing_task:
|
||||||
|
await self._closing_task
|
||||||
|
else:
|
||||||
await self.close()
|
await self.close()
|
||||||
|
|
||||||
# internals
|
# internals
|
||||||
@ -726,11 +729,10 @@ class Client:
|
|||||||
|
|
||||||
Closes the connection to Discord.
|
Closes the connection to Discord.
|
||||||
"""
|
"""
|
||||||
if self._closed:
|
if self._closing_task:
|
||||||
return
|
return await self._closing_task
|
||||||
|
|
||||||
self._closed = True
|
|
||||||
|
|
||||||
|
async def _close():
|
||||||
await self._connection.close()
|
await self._connection.close()
|
||||||
|
|
||||||
if self.ws is not None and self.ws.open:
|
if self.ws is not None and self.ws.open:
|
||||||
@ -743,6 +745,9 @@ class Client:
|
|||||||
|
|
||||||
self.loop = MISSING
|
self.loop = MISSING
|
||||||
|
|
||||||
|
self._closing_task = asyncio.create_task(_close())
|
||||||
|
await self._closing_task
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clears the internal state of the bot.
|
"""Clears the internal state of the bot.
|
||||||
|
|
||||||
@ -750,7 +755,7 @@ class Client:
|
|||||||
and :meth:`is_ready` both return ``False`` along with the bot's internal
|
and :meth:`is_ready` both return ``False`` along with the bot's internal
|
||||||
cache cleared.
|
cache cleared.
|
||||||
"""
|
"""
|
||||||
self._closed = False
|
self._closing_task = None
|
||||||
self._ready.clear()
|
self._ready.clear()
|
||||||
self._connection.clear()
|
self._connection.clear()
|
||||||
self.http.clear()
|
self.http.clear()
|
||||||
@ -870,7 +875,7 @@ class Client:
|
|||||||
|
|
||||||
def is_closed(self) -> bool:
|
def is_closed(self) -> bool:
|
||||||
""":class:`bool`: Indicates if the websocket connection is closed."""
|
""":class:`bool`: Indicates if the websocket connection is closed."""
|
||||||
return self._closed
|
return self._closing_task is not None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def activity(self) -> Optional[ActivityTypes]:
|
def activity(self) -> Optional[ActivityTypes]:
|
||||||
|
@ -481,10 +481,10 @@ class AutoShardedClient(Client):
|
|||||||
|
|
||||||
Closes the connection to Discord.
|
Closes the connection to Discord.
|
||||||
"""
|
"""
|
||||||
if self.is_closed():
|
if self._closing_task:
|
||||||
return
|
return await self._closing_task
|
||||||
|
|
||||||
self._closed = True
|
async def _close():
|
||||||
await self._connection.close()
|
await self._connection.close()
|
||||||
|
|
||||||
to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
|
to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
|
||||||
@ -494,6 +494,9 @@ class AutoShardedClient(Client):
|
|||||||
await self.http.close()
|
await self.http.close()
|
||||||
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
|
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
|
||||||
|
|
||||||
|
self._closing_task = asyncio.create_task(_close())
|
||||||
|
await self._closing_task
|
||||||
|
|
||||||
async def change_presence(
|
async def change_presence(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user