mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-05-15 18:29:52 +00:00
Make every shard maintain its own reconnect loop
Previously if a disconnect happened the client would get in a bad state and certain shards would be double sending due to unhandled exceptions raising back to Client.connect and causing all shards to be reconnected again. This new code overrides Client.connect to have more finer control and allow each individual shard to maintain its own reconnect loop and then serially request reconnection to ensure that IDENTIFYs are not overlapping.
This commit is contained in:
parent
394b514cc9
commit
f658fcf164
@ -28,10 +28,13 @@ import asyncio
|
|||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
from .state import AutoShardedConnectionState
|
from .state import AutoShardedConnectionState
|
||||||
from .client import Client
|
from .client import Client
|
||||||
|
from .backoff import ExponentialBackoff
|
||||||
from .gateway import *
|
from .gateway import *
|
||||||
from .errors import ClientException, InvalidArgument, ConnectionClosed
|
from .errors import ClientException, InvalidArgument, HTTPException, GatewayNotFound, ConnectionClosed
|
||||||
from . import utils
|
from . import utils
|
||||||
from .enums import Status
|
from .enums import Status
|
||||||
|
|
||||||
@ -39,8 +42,9 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class EventType:
|
class EventType:
|
||||||
close = 0
|
close = 0
|
||||||
resume = 1
|
reconnect = 1
|
||||||
identify = 2
|
resume = 2
|
||||||
|
identify = 3
|
||||||
|
|
||||||
class EventItem:
|
class EventItem:
|
||||||
__slots__ = ('type', 'shard', 'error')
|
__slots__ = ('type', 'shard', 'error')
|
||||||
@ -70,7 +74,18 @@ class Shard:
|
|||||||
self._dispatch = client.dispatch
|
self._dispatch = client.dispatch
|
||||||
self._queue = client._queue
|
self._queue = client._queue
|
||||||
self.loop = self._client.loop
|
self.loop = self._client.loop
|
||||||
|
self._disconnect = False
|
||||||
|
self._reconnect = client._reconnect
|
||||||
|
self._backoff = ExponentialBackoff()
|
||||||
self._task = None
|
self._task = None
|
||||||
|
self._handled_exceptions = (
|
||||||
|
OSError,
|
||||||
|
HTTPException,
|
||||||
|
GatewayNotFound,
|
||||||
|
ConnectionClosed,
|
||||||
|
aiohttp.ClientError,
|
||||||
|
asyncio.TimeoutError,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self):
|
def id(self):
|
||||||
@ -79,6 +94,33 @@ class Shard:
|
|||||||
def launch(self):
|
def launch(self):
|
||||||
self._task = self.loop.create_task(self.worker())
|
self._task = self.loop.create_task(self.worker())
|
||||||
|
|
||||||
|
def _cancel_task(self):
|
||||||
|
if self._task is not None and not self._task.done():
|
||||||
|
self._task.cancel()
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
self._cancel_task()
|
||||||
|
await self.ws.close(code=1000)
|
||||||
|
|
||||||
|
async def _handle_disconnect(self, e):
|
||||||
|
self._dispatch('disconnect')
|
||||||
|
if not self._reconnect:
|
||||||
|
self._queue.put_nowait(EventItem(EventType.close, self, e))
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._client.is_closed():
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(e, ConnectionClosed):
|
||||||
|
if e.code != 1000:
|
||||||
|
self._queue.put_nowait(EventItem(EventType.close, self, e))
|
||||||
|
return
|
||||||
|
|
||||||
|
retry = self._backoff.delay()
|
||||||
|
log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e)
|
||||||
|
await asyncio.sleep(retry)
|
||||||
|
self._queue.put_nowait(EventItem(EventType.reconnect, self, e))
|
||||||
|
|
||||||
async def worker(self):
|
async def worker(self):
|
||||||
while not self._client.is_closed():
|
while not self._client.is_closed():
|
||||||
try:
|
try:
|
||||||
@ -87,14 +129,12 @@ class Shard:
|
|||||||
etype = EventType.resume if e.resume else EventType.identify
|
etype = EventType.resume if e.resume else EventType.identify
|
||||||
self._queue.put_nowait(EventItem(etype, self, e))
|
self._queue.put_nowait(EventItem(etype, self, e))
|
||||||
break
|
break
|
||||||
except ConnectionClosed as e:
|
except self._handled_exceptions as e:
|
||||||
self._queue.put_nowait(EventItem(EventType.close, self, e))
|
await self._handle_disconnect(e)
|
||||||
break
|
break
|
||||||
|
|
||||||
async def reconnect(self, exc):
|
async def reidentify(self, exc):
|
||||||
if self._task is not None and not self._task.done():
|
self._cancel_task()
|
||||||
self._task.cancel()
|
|
||||||
|
|
||||||
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
|
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
|
||||||
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
|
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
|
||||||
session=self.ws.session_id, sequence=self.ws.sequence)
|
session=self.ws.session_id, sequence=self.ws.sequence)
|
||||||
@ -102,6 +142,16 @@ class Shard:
|
|||||||
self.ws = await asyncio.wait_for(coro, timeout=180.0)
|
self.ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||||
self.launch()
|
self.launch()
|
||||||
|
|
||||||
|
async def reconnect(self):
|
||||||
|
self._cancel_task()
|
||||||
|
try:
|
||||||
|
coro = DiscordWebSocket.from_client(self._client, shard_id=self.id)
|
||||||
|
self.ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||||
|
except self._handled_exceptions as e:
|
||||||
|
await self._handle_disconnect(e)
|
||||||
|
else:
|
||||||
|
self.launch()
|
||||||
|
|
||||||
class AutoShardedClient(Client):
|
class AutoShardedClient(Client):
|
||||||
"""A client similar to :class:`Client` except it handles the complications
|
"""A client similar to :class:`Client` except it handles the complications
|
||||||
of sharding for the user into a more manageable and transparent single
|
of sharding for the user into a more manageable and transparent single
|
||||||
@ -235,15 +285,21 @@ class AutoShardedClient(Client):
|
|||||||
|
|
||||||
self._connection.shards_launched.set()
|
self._connection.shards_launched.set()
|
||||||
|
|
||||||
async def _connect(self):
|
async def connect(self, *, reconnect=True):
|
||||||
|
self._reconnect = reconnect
|
||||||
await self.launch_shards()
|
await self.launch_shards()
|
||||||
|
|
||||||
while True:
|
while not self.is_closed():
|
||||||
item = await self._queue.get()
|
item = await self._queue.get()
|
||||||
if item.type == EventType.close:
|
if item.type == EventType.close:
|
||||||
raise item.error
|
await self.close()
|
||||||
|
if isinstance(item.error, ConnectionClosed) and item.error.code != 1000:
|
||||||
|
raise item.error
|
||||||
|
return
|
||||||
elif item.type in (EventType.identify, EventType.resume):
|
elif item.type in (EventType.identify, EventType.resume):
|
||||||
await item.shard.reconnect(item.error)
|
await item.shard.reidentify(item.error)
|
||||||
|
elif item.type == EventType.reconnect:
|
||||||
|
await item.shard.reconnect()
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""|coro|
|
"""|coro|
|
||||||
@ -261,7 +317,7 @@ class AutoShardedClient(Client):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
to_close = [asyncio.ensure_future(shard.ws.close(code=1000), loop=self.loop) for shard in self.shards.values()]
|
to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.shards.values()]
|
||||||
if to_close:
|
if to_close:
|
||||||
await asyncio.wait(to_close)
|
await asyncio.wait(to_close)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user