Rewrite of AutoShardedClient to prevent overlapping identify

This is experimental and I'm unsure if it actually works
This commit is contained in:
Rapptz
2020-04-06 21:34:55 -04:00
parent 1c3b0c02f8
commit 09ecb16680
4 changed files with 70 additions and 62 deletions

View File

@ -33,61 +33,58 @@ import websockets
from .state import AutoShardedConnectionState
from .client import Client
from .gateway import *
from .errors import ClientException, InvalidArgument
from .errors import ClientException, InvalidArgument, ConnectionClosed
from . import utils
from .enums import Status
log = logging.getLogger(__name__)
class EventType:
Close = 0
Resume = 1
Identify = 2
class Shard:
def __init__(self, ws, client):
self.ws = ws
self._client = client
self._dispatch = client.dispatch
self._queue = client._queue
self.loop = self._client.loop
self._current = self.loop.create_future()
self._current.set_result(None) # we just need an already done future
self._pending = asyncio.Event()
self._pending_task = None
self._task = None
@property
def id(self):
return self.ws.shard_id
def is_pending(self):
return not self._pending.is_set()
def launch(self):
self._task = self.loop.create_task(self.worker())
def complete_pending_reads(self):
self._pending.set()
async def worker(self):
while True:
try:
await self.ws.poll_event()
except ReconnectWebSocket as e:
etype = EventType.resume if e.resume else EventType.identify
self._queue.put_nowait((etype, self, e))
break
except ConnectionClosed as e:
self._queue.put_nowait((EventType.close, self, e))
break
async def _pending_reads(self):
try:
while self.is_pending():
await self.poll()
except asyncio.CancelledError:
pass
async def reconnect(self, exc):
if self._task is not None and not self._task.done():
self._task.cancel()
def launch_pending_reads(self):
self._pending_task = asyncio.ensure_future(self._pending_reads(), loop=self.loop)
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
if not exc.resume:
await asyncio.sleep(5.0)
def wait(self):
return self._pending_task
async def poll(self):
try:
await self.ws.poll_event()
except ResumeWebSocket:
log.info('Got a request to RESUME the websocket at Shard ID %s.', self.id)
coro = DiscordWebSocket.from_client(self._client, resume=True, shard_id=self.id,
session=self.ws.session_id, sequence=self.ws.sequence)
self._dispatch('disconnect')
self.ws = await asyncio.wait_for(coro, timeout=180.0)
def get_future(self):
if self._current.done():
self._current = asyncio.ensure_future(self.poll(), loop=self.loop)
return self._current
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
session=self.ws.session_id, sequence=self.ws.sequence)
self._dispatch('disconnect')
self.ws = await asyncio.wait_for(coro, timeout=180.0)
self.launch()
class AutoShardedClient(Client):
"""A client similar to :class:`Client` except it handles the complications
@ -134,6 +131,7 @@ class AutoShardedClient(Client):
# the key is the shard_id
self.shards = {}
self._connection._get_websocket = self._get_websocket
self._queue = asyncio.PriorityQueue()
def _get_websocket(self, guild_id=None, *, shard_id=None):
if shard_id is None:
@ -220,8 +218,10 @@ class AutoShardedClient(Client):
# keep reading the shard while others connect
self.shards[shard_id] = ret = Shard(ws, self)
ret.launch_pending_reads()
await asyncio.sleep(5.0)
ret.launch()
if len(self.shards) == self.shard_count:
self._connection.shards_launched.set()
async def launch_shards(self):
if self.shard_count is None:
@ -234,26 +234,29 @@ class AutoShardedClient(Client):
shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count)
self._connection.shard_ids = shard_ids
last_shard_id = shard_ids[-1]
for shard_id in shard_ids:
await self.launch_shard(gateway, shard_id)
if shard_id != last_shard_id:
await asyncio.sleep(5.0)
shards_to_wait_for = []
for shard in self.shards.values():
shard.complete_pending_reads()
shards_to_wait_for.append(shard.wait())
# shards_to_wait_for = []
# for shard in self.shards.values():
# shard.complete_pending_reads()
# shards_to_wait_for.append(shard.wait())
# wait for all pending tasks to finish
await utils.sane_wait_for(shards_to_wait_for, timeout=300.0)
# # wait for all pending tasks to finish
# await utils.sane_wait_for(shards_to_wait_for, timeout=300.0)
async def _connect(self):
await self.launch_shards()
while True:
pollers = [shard.get_future() for shard in self.shards.values()]
done, _ = await asyncio.wait(pollers, return_when=asyncio.FIRST_COMPLETED)
for f in done:
# we wanna re-raise to the main Client.connect handler if applicable
f.result()
etype, shard, exc = await self._queue.get()
if etype == EventType.close:
raise exc
elif etype in (EventType.identify, EventType.resume):
await shard.reconnect(exc)
async def close(self):
"""|coro|