Implement AutoShardedClient for transparent sharding.
This allows people to run their >2,500 guild bot in a single process without the headaches of IPC/RPC or much difficulty.
This commit is contained in:
@ -47,11 +47,13 @@ __all__ = [ 'ReconnectWebSocket', 'DiscordWebSocket',
|
||||
|
||||
class ReconnectWebSocket(Exception):
|
||||
"""Signals to handle the RECONNECT opcode."""
|
||||
pass
|
||||
def __init__(self, shard_id):
|
||||
self.shard_id = shard_id
|
||||
|
||||
class ResumeWebSocket(Exception):
|
||||
"""Signals to initialise via RESUME opcode instead of IDENTIFY."""
|
||||
pass
|
||||
def __init__(self, shard_id):
|
||||
self.shard_id = shard_id
|
||||
|
||||
EventListener = namedtuple('EventListener', 'predicate event result future')
|
||||
|
||||
@ -81,7 +83,7 @@ class KeepAliveHandler(threading.Thread):
|
||||
def get_payload(self):
|
||||
return {
|
||||
'op': self.ws.HEARTBEAT,
|
||||
'd': self.ws._connection.sequence
|
||||
'd': self.ws.sequence
|
||||
}
|
||||
|
||||
def stop(self):
|
||||
@ -165,9 +167,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
# the keep alive
|
||||
self._keep_alive = None
|
||||
|
||||
# ws related stuff
|
||||
self.session_id = None
|
||||
self.sequence = None
|
||||
|
||||
@classmethod
|
||||
@asyncio.coroutine
|
||||
def from_client(cls, client, *, resume=False):
|
||||
def from_client(cls, client, *, shard_id=None, session=None, sequence=None, resume=False):
|
||||
"""Creates a main websocket for Discord from a :class:`Client`.
|
||||
|
||||
This is for internal use only.
|
||||
@ -180,8 +186,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
ws._connection = client.connection
|
||||
ws._dispatch = client.dispatch
|
||||
ws.gateway = gateway
|
||||
ws.shard_id = client.shard_id
|
||||
ws.shard_count = client.shard_count
|
||||
ws.shard_id = shard_id
|
||||
ws.shard_count = client.connection.shard_count
|
||||
ws.session_id = session
|
||||
ws.sequence = sequence
|
||||
|
||||
client.connection._update_references(ws)
|
||||
|
||||
@ -206,6 +214,35 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
else:
|
||||
return ws
|
||||
|
||||
@classmethod
|
||||
@asyncio.coroutine
|
||||
def from_sharded_client(cls, client):
|
||||
if client.shard_count is None:
|
||||
client.shard_count, gateway = yield from client.http.get_bot_gateway()
|
||||
else:
|
||||
gateway = yield from client.http.get_gateway()
|
||||
|
||||
ret = []
|
||||
client.connection.shard_count = client.shard_count
|
||||
|
||||
for shard_id in range(client.shard_count):
|
||||
ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
|
||||
ws.token = client.http.token
|
||||
ws._connection = client.connection
|
||||
ws._dispatch = client.dispatch
|
||||
ws.gateway = gateway
|
||||
ws.shard_id = shard_id
|
||||
ws.shard_count = client.shard_count
|
||||
|
||||
# OP HELLO
|
||||
yield from ws.poll_event()
|
||||
yield from ws.identify()
|
||||
ret.append(ws)
|
||||
log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id)
|
||||
yield from asyncio.sleep(5.0, loop=client.loop)
|
||||
|
||||
return ret
|
||||
|
||||
def wait_for(self, event, predicate, result=None):
|
||||
"""Waits for a DISPATCH'd event that meets the predicate.
|
||||
|
||||
@ -262,12 +299,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
@asyncio.coroutine
|
||||
def resume(self):
|
||||
"""Sends the RESUME packet."""
|
||||
state = self._connection
|
||||
payload = {
|
||||
'op': self.RESUME,
|
||||
'd': {
|
||||
'seq': state.sequence,
|
||||
'session_id': state.session_id,
|
||||
'seq': self.sequence,
|
||||
'session_id': self.session_id,
|
||||
'token': self.token
|
||||
}
|
||||
}
|
||||
@ -283,16 +319,15 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
msg = msg.decode('utf-8')
|
||||
|
||||
msg = json.loads(msg)
|
||||
state = self._connection
|
||||
|
||||
log.debug('WebSocket Event: {}'.format(msg))
|
||||
log.debug('For Shard ID {}: WebSocket Event: {}'.format(self.shard_id, msg))
|
||||
self._dispatch('socket_response', msg)
|
||||
|
||||
op = msg.get('op')
|
||||
data = msg.get('d')
|
||||
seq = msg.get('s')
|
||||
if seq is not None:
|
||||
state.sequence = seq
|
||||
self.sequence = seq
|
||||
|
||||
if op == self.RECONNECT:
|
||||
# "reconnect" can only be handled by the Client
|
||||
@ -300,7 +335,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
# internal exception signalling to reconnect.
|
||||
log.info('Received RECONNECT opcode.')
|
||||
yield from self.close()
|
||||
raise ReconnectWebSocket()
|
||||
raise ReconnectWebSocket(self.shard_id)
|
||||
|
||||
if op == self.HEARTBEAT_ACK:
|
||||
return # disable noisy logging for now
|
||||
@ -317,11 +352,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
return
|
||||
|
||||
if op == self.INVALIDATE_SESSION:
|
||||
state.sequence = None
|
||||
state.session_id = None
|
||||
self.sequence = None
|
||||
self.session_id = None
|
||||
if data == True:
|
||||
yield from self.close()
|
||||
raise ResumeWebSocket()
|
||||
raise ResumeWebSocket(self.shard_id)
|
||||
|
||||
yield from self.identify()
|
||||
return
|
||||
@ -334,9 +369,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
is_ready = event == 'READY'
|
||||
|
||||
if is_ready:
|
||||
state.clear()
|
||||
state.sequence = msg['s']
|
||||
state.session_id = data['session_id']
|
||||
self.sequence = msg['s']
|
||||
self.session_id = data['session_id']
|
||||
|
||||
parser = 'parse_' + event.lower()
|
||||
|
||||
@ -389,9 +423,9 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
if self._can_handle_close(e.code):
|
||||
log.info('Websocket closed with {0.code} ({0.reason}), attempting a reconnect.'.format(e))
|
||||
raise ResumeWebSocket() from e
|
||||
raise ResumeWebSocket(self.shard_id) from e
|
||||
else:
|
||||
raise ConnectionClosed(e) from e
|
||||
raise ConnectionClosed(e, shard_id=self.shard_id) from e
|
||||
|
||||
@asyncio.coroutine
|
||||
def send(self, data):
|
||||
@ -404,7 +438,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
yield from super().send(utils.to_json(data))
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
if not self._can_handle_close(e.code):
|
||||
raise ConnectionClosed(e) from e
|
||||
raise ConnectionClosed(e, shard_id=self.shard_id) from e
|
||||
|
||||
@asyncio.coroutine
|
||||
def change_presence(self, *, game=None, status=None, afk=False, since=0.0, idle=None):
|
||||
@ -615,7 +649,7 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
msg = yield from asyncio.wait_for(self.recv(), timeout=30.0, loop=self.loop)
|
||||
yield from self.received_message(json.loads(msg))
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
raise ConnectionClosed(e) from e
|
||||
raise ConnectionClosed(e, shard_id=None) from e
|
||||
|
||||
@asyncio.coroutine
|
||||
def close_connection(self, force=False):
|
||||
|
Reference in New Issue
Block a user