mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-22 08:44:10 +00:00
Change the way shards are launched in AutoShardedClient.
This commit is contained in:
parent
d93067ca0f
commit
b5bed9ef33
@ -214,35 +214,6 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
else:
|
||||
return ws
|
||||
|
||||
@classmethod
|
||||
@asyncio.coroutine
|
||||
def from_sharded_client(cls, client):
|
||||
if client.shard_count is None:
|
||||
client.shard_count, gateway = yield from client.http.get_bot_gateway()
|
||||
else:
|
||||
gateway = yield from client.http.get_gateway()
|
||||
|
||||
ret = []
|
||||
client.connection.shard_count = client.shard_count
|
||||
|
||||
for shard_id in range(client.shard_count):
|
||||
ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
|
||||
ws.token = client.http.token
|
||||
ws._connection = client.connection
|
||||
ws._dispatch = client.dispatch
|
||||
ws.gateway = gateway
|
||||
ws.shard_id = shard_id
|
||||
ws.shard_count = client.shard_count
|
||||
|
||||
# OP HELLO
|
||||
yield from ws.poll_event()
|
||||
yield from ws.identify()
|
||||
ret.append(ws)
|
||||
log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id)
|
||||
yield from asyncio.sleep(5.0, loop=client.loop)
|
||||
|
||||
return ret
|
||||
|
||||
def wait_for(self, event, predicate, result=None):
|
||||
"""Waits for a DISPATCH'd event that meets the predicate.
|
||||
|
||||
|
@ -32,6 +32,7 @@ from . import compat
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import websockets
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -93,8 +94,10 @@ class AutoShardedClient(Client):
|
||||
syncer=self._syncer, http=self.http, loop=self.loop, **kwargs)
|
||||
|
||||
# instead of a single websocket, we have multiple
|
||||
# the index is the shard_id
|
||||
self.shards = []
|
||||
# the key is the shard_id
|
||||
self.shards = {}
|
||||
|
||||
self._still_sharding = True
|
||||
|
||||
@asyncio.coroutine
|
||||
def request_offline_members(self, guild, *, shard_id=None):
|
||||
@ -135,6 +138,56 @@ class AutoShardedClient(Client):
|
||||
ws = self.shards[shard_id].ws
|
||||
yield from ws.send_as_json(payload)
|
||||
|
||||
@asyncio.coroutine
|
||||
def pending_reads(self, shard):
|
||||
try:
|
||||
while self._still_sharding:
|
||||
yield from shard.poll()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
@asyncio.coroutine
|
||||
def launch_shard(self, gateway, shard_id):
|
||||
try:
|
||||
ws = yield from websockets.connect(gateway, loop=self.loop, klass=DiscordWebSocket)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
log.info('Failed to connect for shard_id: %s. Retrying...' % shard_id)
|
||||
yield from asyncio.sleep(5.0, loop=self.loop)
|
||||
yield from self.launch_shard(gateway, shard_id)
|
||||
|
||||
ws.token = self.http.token
|
||||
ws._connection = self.connection
|
||||
ws._dispatch = self.dispatch
|
||||
ws.gateway = gateway
|
||||
ws.shard_id = shard_id
|
||||
ws.shard_count = self.shard_count
|
||||
|
||||
# OP HELLO
|
||||
yield from ws.poll_event()
|
||||
yield from ws.identify()
|
||||
log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id)
|
||||
|
||||
# keep reading the shard while others connect
|
||||
self.shards[shard_id] = ret = Shard(ws, self)
|
||||
compat.create_task(self.pending_reads(ret), loop=self.loop)
|
||||
yield from asyncio.sleep(5.0, loop=self.loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def launch_shards(self):
|
||||
if self.shard_count is None:
|
||||
self.shard_count, gateway = yield from self.http.get_bot_gateway()
|
||||
else:
|
||||
gateway = yield from self.http.get_gateway()
|
||||
|
||||
self.connection.shard_count = self.shard_count
|
||||
|
||||
for shard_id in range(self.shard_count):
|
||||
yield from self.launch_shard(gateway, shard_id)
|
||||
|
||||
self._still_sharding = False
|
||||
|
||||
@asyncio.coroutine
|
||||
def connect(self):
|
||||
"""|coro|
|
||||
@ -150,11 +203,10 @@ class AutoShardedClient(Client):
|
||||
ConnectionClosed
|
||||
The websocket connection has been terminated.
|
||||
"""
|
||||
ret = yield from DiscordWebSocket.from_sharded_client(self)
|
||||
self.shards = [Shard(ws, self) for ws in ret]
|
||||
yield from self.launch_shards()
|
||||
|
||||
while not self.is_closed:
|
||||
pollers = [shard.get_future() for shard in self.shards]
|
||||
pollers = [shard.get_future() for shard in self.shards.values()]
|
||||
yield from asyncio.wait(pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
@asyncio.coroutine
|
||||
@ -166,7 +218,7 @@ class AutoShardedClient(Client):
|
||||
if self.is_closed:
|
||||
return
|
||||
|
||||
for shard in self.shards:
|
||||
for shard in self.shards.values():
|
||||
yield from shard.ws.close()
|
||||
|
||||
yield from self.http.close()
|
||||
|
Loading…
x
Reference in New Issue
Block a user