Add an exposed way to extract shard-specific information.

Closes #2654
This commit is contained in:
Rapptz
2020-07-25 09:37:48 -04:00
parent a42bebe581
commit 7ed26db3b3
3 changed files with 104 additions and 16 deletions

View File

@ -103,6 +103,10 @@ class Shard:
self._cancel_task()
await self.ws.close(code=1000)
async def disconnect(self):
await self.close()
self._dispatch('shard_disconnect', self.id)
async def _handle_disconnect(self, e):
self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id)
@ -178,6 +182,70 @@ class Shard:
else:
self.launch()
class ShardInfo:
"""A class that gives information and control over a specific shard.
You can retrieve this object via :meth:`AutoShardedClient.get_shard`
or :attr:`AutoShardedClient.shards`.
.. versionadded:: 1.4
Attributes
------------
id: :class:`int`
The shard ID for this shard.
shard_count: Optional[:class:`int`]
The shard count for this cluster. If this is ``None`` then the bot has not started yet.
"""
__slots__ = ('_parent', 'id', 'shard_count')
def __init__(self, parent, shard_count):
self._parent = parent
self.id = parent.id
self.shard_count = shard_count
def is_closed(self):
""":class:`bool`: Whether the shard connection is currently closed."""
return not self._parent.ws.open
async def disconnect(self):
"""|coro|
Disconnects a shard. When this is called, the shard connection will no
longer be open.
If the shard is already disconnected this does nothing.
"""
if self.is_closed():
return
await self._parent.disconnect()
async def reconnect(self):
"""|coro|
Disconnects and then connects the shard again.
"""
if not self.is_closed():
await self._parent.disconnect()
await self._parent.reconnect()
async def connect(self):
"""|coro|
Connects a shard. If the shard is already connected this does nothing.
"""
if not self.is_closed():
return
await self._parent.reconnect()
@property
def latency(self):
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds for this shard."""
return self._parent.ws.latency
class AutoShardedClient(Client):
"""A client similar to :class:`Client` except it handles the complications
of sharding for the user into a more manageable and transparent single
@ -221,14 +289,14 @@ class AutoShardedClient(Client):
# instead of a single websocket, we have multiple
# the key is the shard_id
self.shards = {}
self.__shards = {}
self._connection._get_websocket = self._get_websocket
self._queue = asyncio.PriorityQueue()
self.__queue = asyncio.PriorityQueue()
def _get_websocket(self, guild_id=None, *, shard_id=None):
if shard_id is None:
shard_id = (guild_id >> 22) % self.shard_count
return self.shards[shard_id].ws
return self.__shards[shard_id].ws
@property
def latency(self):
@ -238,9 +306,9 @@ class AutoShardedClient(Client):
latency of every shard's latency. To get a list of shard latency, check the
:attr:`latencies` property. Returns ``nan`` if there are no shards ready.
"""
if not self.shards:
if not self.__shards:
return float('nan')
return sum(latency for _, latency in self.latencies) / len(self.shards)
return sum(latency for _, latency in self.latencies) / len(self.__shards)
@property
def latencies(self):
@ -248,7 +316,21 @@ class AutoShardedClient(Client):
This returns a list of tuples with elements ``(shard_id, latency)``.
"""
return [(shard_id, shard.ws.latency) for shard_id, shard in self.shards.items()]
return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()]
def get_shard(self, shard_id):
"""Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found."""
try:
parent = self.__shards[shard_id]
except KeyError:
return None
else:
return ShardInfo(parent, self.shard_count)
@utils.cached_property
def shards(self):
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
return { shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items() }
async def request_offline_members(self, *guilds):
r"""|coro|
@ -291,7 +373,7 @@ class AutoShardedClient(Client):
return await self.launch_shard(gateway, shard_id)
# keep reading the shard while others connect
self.shards[shard_id] = ret = Shard(ws, self)
self.__shards[shard_id] = ret = Shard(ws, self)
ret.launch()
async def launch_shards(self):
@ -316,7 +398,7 @@ class AutoShardedClient(Client):
await self.launch_shards()
while not self.is_closed():
item = await self._queue.get()
item = await self.__queue.get()
if item.type == EventType.close:
await self.close()
if isinstance(item.error, ConnectionClosed) and item.error.code != 1000:
@ -346,7 +428,7 @@ class AutoShardedClient(Client):
except Exception:
pass
to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.shards.values()]
to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
if to_close:
await asyncio.wait(to_close)
@ -395,12 +477,12 @@ class AutoShardedClient(Client):
status = str(status)
if shard_id is None:
for shard in self.shards.values():
for shard in self.__shards.values():
await shard.ws.change_presence(activity=activity, status=status, afk=afk)
guilds = self._connection.guilds
else:
shard = self.shards[shard_id]
shard = self.__shards[shard_id]
await shard.ws.change_presence(activity=activity, status=status, afk=afk)
guilds = [g for g in self._connection.guilds if g.shard_id == shard_id]