mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-19 07:26:17 +00:00
Rewrite of AutoShardedClient to prevent overlapping identify
This is experimental and I'm unsure if it actually works
This commit is contained in:
parent
1c3b0c02f8
commit
09ecb16680
@ -453,11 +453,14 @@ class Client:
|
||||
while True:
|
||||
try:
|
||||
await self.ws.poll_event()
|
||||
except ResumeWebSocket:
|
||||
log.info('Got a request to RESUME the websocket.')
|
||||
except ReconnectWebSocket as e:
|
||||
log.info('Got a request to %s the websocket.', e.op)
|
||||
self.dispatch('disconnect')
|
||||
if not e.resume:
|
||||
await asyncio.sleep(5.0)
|
||||
|
||||
coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id, session=self.ws.session_id,
|
||||
sequence=self.ws.sequence, resume=True)
|
||||
sequence=self.ws.sequence, resume=e.resume)
|
||||
self.ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||
|
||||
async def connect(self, *, reconnect=True):
|
||||
|
@ -50,13 +50,15 @@ __all__ = (
|
||||
'KeepAliveHandler',
|
||||
'VoiceKeepAliveHandler',
|
||||
'DiscordVoiceWebSocket',
|
||||
'ResumeWebSocket',
|
||||
'ReconnectWebSocket',
|
||||
)
|
||||
|
||||
class ResumeWebSocket(Exception):
|
||||
"""Signals to initialise via RESUME opcode instead of IDENTIFY."""
|
||||
def __init__(self, shard_id):
|
||||
class ReconnectWebSocket(Exception):
|
||||
"""Signals to safely reconnect the websocket."""
|
||||
def __init__(self, shard_id, *, resume=True):
|
||||
self.shard_id = shard_id
|
||||
self.resume = resume
|
||||
self.op = 'RESUME' if resume else 'IDENTIFY'
|
||||
|
||||
EventListener = namedtuple('EventListener', 'predicate event result future')
|
||||
|
||||
@ -385,7 +387,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
# internal exception signalling to reconnect.
|
||||
log.debug('Received RECONNECT opcode.')
|
||||
await self.close()
|
||||
raise ResumeWebSocket(self.shard_id)
|
||||
raise ReconnectWebSocket(self.shard_id)
|
||||
|
||||
if op == self.HEARTBEAT_ACK:
|
||||
self._keep_alive.ack()
|
||||
@ -406,16 +408,14 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
|
||||
if op == self.INVALIDATE_SESSION:
|
||||
if data is True:
|
||||
await asyncio.sleep(5.0)
|
||||
await self.close()
|
||||
raise ResumeWebSocket(self.shard_id)
|
||||
raise ReconnectWebSocket(self.shard_id)
|
||||
|
||||
self.sequence = None
|
||||
self.session_id = None
|
||||
log.info('Shard ID %s session has been invalidated.', self.shard_id)
|
||||
await asyncio.sleep(5.0)
|
||||
await self.identify()
|
||||
return
|
||||
await self.close(code=1000)
|
||||
raise ReconnectWebSocket(self.shard_id, resume=False)
|
||||
|
||||
log.warning('Unknown OP code %s.', op)
|
||||
return
|
||||
@ -489,7 +489,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
except websockets.exceptions.ConnectionClosed as exc:
|
||||
if self._can_handle_close(exc.code):
|
||||
log.info('Websocket closed with %s (%s), attempting a reconnect.', exc.code, exc.reason)
|
||||
raise ResumeWebSocket(self.shard_id) from exc
|
||||
raise ReconnectWebSocket(self.shard_id) from exc
|
||||
else:
|
||||
log.info('Websocket closed with %s (%s), cannot reconnect.', exc.code, exc.reason)
|
||||
raise ConnectionClosed(exc, shard_id=self.shard_id) from exc
|
||||
|
@ -33,61 +33,58 @@ import websockets
|
||||
from .state import AutoShardedConnectionState
|
||||
from .client import Client
|
||||
from .gateway import *
|
||||
from .errors import ClientException, InvalidArgument
|
||||
from .errors import ClientException, InvalidArgument, ConnectionClosed
|
||||
from . import utils
|
||||
from .enums import Status
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
class EventType:
|
||||
Close = 0
|
||||
Resume = 1
|
||||
Identify = 2
|
||||
|
||||
class Shard:
|
||||
def __init__(self, ws, client):
|
||||
self.ws = ws
|
||||
self._client = client
|
||||
self._dispatch = client.dispatch
|
||||
self._queue = client._queue
|
||||
self.loop = self._client.loop
|
||||
self._current = self.loop.create_future()
|
||||
self._current.set_result(None) # we just need an already done future
|
||||
self._pending = asyncio.Event()
|
||||
self._pending_task = None
|
||||
self._task = None
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self.ws.shard_id
|
||||
|
||||
def is_pending(self):
|
||||
return not self._pending.is_set()
|
||||
def launch(self):
|
||||
self._task = self.loop.create_task(self.worker())
|
||||
|
||||
def complete_pending_reads(self):
|
||||
self._pending.set()
|
||||
async def worker(self):
|
||||
while True:
|
||||
try:
|
||||
await self.ws.poll_event()
|
||||
except ReconnectWebSocket as e:
|
||||
etype = EventType.resume if e.resume else EventType.identify
|
||||
self._queue.put_nowait((etype, self, e))
|
||||
break
|
||||
except ConnectionClosed as e:
|
||||
self._queue.put_nowait((EventType.close, self, e))
|
||||
break
|
||||
|
||||
async def _pending_reads(self):
|
||||
try:
|
||||
while self.is_pending():
|
||||
await self.poll()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
async def reconnect(self, exc):
|
||||
if self._task is not None and not self._task.done():
|
||||
self._task.cancel()
|
||||
|
||||
def launch_pending_reads(self):
|
||||
self._pending_task = asyncio.ensure_future(self._pending_reads(), loop=self.loop)
|
||||
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
|
||||
if not exc.resume:
|
||||
await asyncio.sleep(5.0)
|
||||
|
||||
def wait(self):
|
||||
return self._pending_task
|
||||
|
||||
async def poll(self):
|
||||
try:
|
||||
await self.ws.poll_event()
|
||||
except ResumeWebSocket:
|
||||
log.info('Got a request to RESUME the websocket at Shard ID %s.', self.id)
|
||||
coro = DiscordWebSocket.from_client(self._client, resume=True, shard_id=self.id,
|
||||
session=self.ws.session_id, sequence=self.ws.sequence)
|
||||
self._dispatch('disconnect')
|
||||
self.ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||
|
||||
def get_future(self):
|
||||
if self._current.done():
|
||||
self._current = asyncio.ensure_future(self.poll(), loop=self.loop)
|
||||
|
||||
return self._current
|
||||
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
|
||||
session=self.ws.session_id, sequence=self.ws.sequence)
|
||||
self._dispatch('disconnect')
|
||||
self.ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||
self.launch()
|
||||
|
||||
class AutoShardedClient(Client):
|
||||
"""A client similar to :class:`Client` except it handles the complications
|
||||
@ -134,6 +131,7 @@ class AutoShardedClient(Client):
|
||||
# the key is the shard_id
|
||||
self.shards = {}
|
||||
self._connection._get_websocket = self._get_websocket
|
||||
self._queue = asyncio.PriorityQueue()
|
||||
|
||||
def _get_websocket(self, guild_id=None, *, shard_id=None):
|
||||
if shard_id is None:
|
||||
@ -220,8 +218,10 @@ class AutoShardedClient(Client):
|
||||
|
||||
# keep reading the shard while others connect
|
||||
self.shards[shard_id] = ret = Shard(ws, self)
|
||||
ret.launch_pending_reads()
|
||||
await asyncio.sleep(5.0)
|
||||
ret.launch()
|
||||
|
||||
if len(self.shards) == self.shard_count:
|
||||
self._connection.shards_launched.set()
|
||||
|
||||
async def launch_shards(self):
|
||||
if self.shard_count is None:
|
||||
@ -234,26 +234,29 @@ class AutoShardedClient(Client):
|
||||
shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count)
|
||||
self._connection.shard_ids = shard_ids
|
||||
|
||||
last_shard_id = shard_ids[-1]
|
||||
for shard_id in shard_ids:
|
||||
await self.launch_shard(gateway, shard_id)
|
||||
if shard_id != last_shard_id:
|
||||
await asyncio.sleep(5.0)
|
||||
|
||||
shards_to_wait_for = []
|
||||
for shard in self.shards.values():
|
||||
shard.complete_pending_reads()
|
||||
shards_to_wait_for.append(shard.wait())
|
||||
# shards_to_wait_for = []
|
||||
# for shard in self.shards.values():
|
||||
# shard.complete_pending_reads()
|
||||
# shards_to_wait_for.append(shard.wait())
|
||||
|
||||
# wait for all pending tasks to finish
|
||||
await utils.sane_wait_for(shards_to_wait_for, timeout=300.0)
|
||||
# # wait for all pending tasks to finish
|
||||
# await utils.sane_wait_for(shards_to_wait_for, timeout=300.0)
|
||||
|
||||
async def _connect(self):
|
||||
await self.launch_shards()
|
||||
|
||||
while True:
|
||||
pollers = [shard.get_future() for shard in self.shards.values()]
|
||||
done, _ = await asyncio.wait(pollers, return_when=asyncio.FIRST_COMPLETED)
|
||||
for f in done:
|
||||
# we wanna re-raise to the main Client.connect handler if applicable
|
||||
f.result()
|
||||
etype, shard, exc = await self._queue.get()
|
||||
if etype == EventType.close:
|
||||
raise exc
|
||||
elif etype in (EventType.identify, EventType.resume):
|
||||
await shard.reconnect(exc)
|
||||
|
||||
async def close(self):
|
||||
"""|coro|
|
||||
|
@ -1047,6 +1047,7 @@ class AutoShardedConnectionState(ConnectionState):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._ready_task = None
|
||||
self.shard_ids = ()
|
||||
self.shards_launched = asyncio.Event()
|
||||
|
||||
async def chunker(self, guild_id, query='', limit=0, *, shard_id=None, nonce=None):
|
||||
ws = self._get_websocket(guild_id, shard_id=shard_id)
|
||||
@ -1073,6 +1074,7 @@ class AutoShardedConnectionState(ConnectionState):
|
||||
log.info('Finished requesting guild member chunks for %d guilds.', len(guilds))
|
||||
|
||||
async def _delay_ready(self):
|
||||
await self.shards_launched.wait()
|
||||
launch = self._ready_state.launch
|
||||
while True:
|
||||
# this snippet of code is basically waiting 2 * shard_ids seconds
|
||||
|
Loading…
x
Reference in New Issue
Block a user