Add RESUME support.

This commit is contained in:
Rapptz
2016-06-01 05:13:15 -04:00
parent 20e86973ea
commit e0a91df32b
4 changed files with 54 additions and 12 deletions

View File

@@ -401,9 +401,10 @@ class Client:
while not self.is_closed:
try:
yield from self.ws.poll_event()
except ReconnectWebSocket:
log.info('Reconnecting the websocket.')
self.ws = yield from DiscordWebSocket.from_client(self)
except (ReconnectWebSocket, ResumeWebSocket) as e:
resume = type(e) is ResumeWebSocket
log.info('Got ' + type(e).__name__)
self.ws = yield from DiscordWebSocket.from_client(self, resume=resume)
except ConnectionClosed as e:
yield from self.close()
if e.code != 1000:

View File

@@ -42,12 +42,16 @@ log = logging.getLogger(__name__)
__all__ = [ 'ReconnectWebSocket', 'get_gateway', 'DiscordWebSocket',
'KeepAliveHandler', 'VoiceKeepAliveHandler',
'DiscordVoiceWebSocket' ]
'DiscordVoiceWebSocket', 'ResumeWebSocket' ]
class ReconnectWebSocket(Exception):
"""Signals to handle the RECONNECT opcode."""
pass
class ResumeWebSocket(Exception):
"""Signals to initialise via RESUME opcode instead of IDENTIFY."""
pass
EventListener = namedtuple('EventListener', 'predicate event result future')
class KeepAliveHandler(threading.Thread):
@@ -179,10 +183,9 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# the keep alive
self._keep_alive = None
@classmethod
@asyncio.coroutine
def from_client(cls, client):
def from_client(cls, client, *, resume=False):
"""Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only.
@@ -197,9 +200,21 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
ws.gateway = gateway
log.info('Created websocket connected to {}'.format(gateway))
yield from ws.identify()
log.info('sent the identify payload to create the websocket')
return ws
if not resume:
yield from ws.identify()
log.info('sent the identify payload to create the websocket')
return ws
yield from ws.resume()
log.info('sent the resume payload to create the websocket')
try:
yield from ws.ensure_open()
except websockets.exceptions.ConnectionClosed:
# ws got closed so let's just do a regular IDENTIFY connect.
log.info('RESUME failure.')
return (yield from cls.from_client(client))
else:
return ws
def wait_for(self, event, predicate, result=None):
"""Waits for a DISPATCH'd event that meets the predicate.
@@ -247,6 +262,21 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
}
yield from self.send_as_json(payload)
@asyncio.coroutine
def resume(self):
"""Sends the RESUME packet."""
state = self._connection
payload = {
'op': self.RESUME,
'd': {
'seq': state.sequence,
'session_id': state.session_id,
'token': self.token
}
}
yield from self.send_as_json(payload)
@asyncio.coroutine
def received_message(self, msg):
self._dispatch('socket_raw_receive', msg)
@@ -271,13 +301,14 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# "reconnect" can only be handled by the Client
# so we terminate our connection and raise an
# internal exception signalling to reconnect.
log.info('Receivede RECONNECT opcode.')
log.info('Received RECONNECT opcode.')
yield from self.close()
raise ReconnectWebSocket()
if op == self.INVALIDATE_SESSION:
state.sequence = None
state.session_id = None
yield from self.identify()
return
if op != self.DISPATCH:
@@ -347,8 +378,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
yield from self.received_message(msg)
except websockets.exceptions.ConnectionClosed as e:
if self._can_handle_close(e.code):
log.info('Websocket closed with {0.code}, attempting a reconnect.'.format(e))
raise ReconnectWebSocket() from e
log.info('Websocket closed with {0.code} ({0.reason}), attempting a reconnect.'.format(e))
if e.code == 4006:
raise ReconnectWebSocket() from e
else:
raise ResumeWebSocket() from e
else:
raise ConnectionClosed(e) from e

View File

@@ -199,6 +199,9 @@ class ConnectionState:
compat.create_task(self._delay_ready(), loop=self.loop)
def parse_resumed(self, data):
self.dispatch('resumed')
def parse_message_create(self, data):
channel = self.get_channel(data.get('channel_id'))
message = Message(channel=channel, **data)