Change the way shards are launched in AutoShardedClient.
This commit is contained in:
		| @@ -214,35 +214,6 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | ||||
|         else: | ||||
|             return ws | ||||
|  | ||||
|     @classmethod | ||||
|     @asyncio.coroutine | ||||
|     def from_sharded_client(cls, client): | ||||
|         if client.shard_count is None: | ||||
|             client.shard_count, gateway = yield from client.http.get_bot_gateway() | ||||
|         else: | ||||
|             gateway = yield from client.http.get_gateway() | ||||
|  | ||||
|         ret = [] | ||||
|         client.connection.shard_count = client.shard_count | ||||
|  | ||||
|         for shard_id in range(client.shard_count): | ||||
|             ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls) | ||||
|             ws.token = client.http.token | ||||
|             ws._connection = client.connection | ||||
|             ws._dispatch = client.dispatch | ||||
|             ws.gateway = gateway | ||||
|             ws.shard_id = shard_id | ||||
|             ws.shard_count = client.shard_count | ||||
|  | ||||
|             # OP HELLO | ||||
|             yield from ws.poll_event() | ||||
|             yield from ws.identify() | ||||
|             ret.append(ws) | ||||
|             log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id) | ||||
|             yield from asyncio.sleep(5.0, loop=client.loop) | ||||
|  | ||||
|         return ret | ||||
|  | ||||
|     def wait_for(self, event, predicate, result=None): | ||||
|         """Waits for a DISPATCH'd event that meets the predicate. | ||||
|  | ||||
|   | ||||
| @@ -32,6 +32,7 @@ from . import compat | ||||
|  | ||||
| import asyncio | ||||
| import logging | ||||
| import websockets | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
| @@ -93,8 +94,10 @@ class AutoShardedClient(Client): | ||||
|                                                      syncer=self._syncer, http=self.http, loop=self.loop, **kwargs) | ||||
|  | ||||
|         # instead of a single websocket, we have multiple | ||||
|         # the index is the shard_id | ||||
|         self.shards = [] | ||||
|         # the key is the shard_id | ||||
|         self.shards = {} | ||||
|  | ||||
|         self._still_sharding = True | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def request_offline_members(self, guild, *, shard_id=None): | ||||
| @@ -135,6 +138,56 @@ class AutoShardedClient(Client): | ||||
|         ws = self.shards[shard_id].ws | ||||
|         yield from ws.send_as_json(payload) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def pending_reads(self, shard): | ||||
|         try: | ||||
|             while self._still_sharding: | ||||
|                 yield from shard.poll() | ||||
|         except asyncio.CancelledError: | ||||
|             pass | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def launch_shard(self, gateway, shard_id): | ||||
|         try: | ||||
|             ws = yield from websockets.connect(gateway, loop=self.loop, klass=DiscordWebSocket) | ||||
|         except Exception as e: | ||||
|             import traceback | ||||
|             traceback.print_exc() | ||||
|             log.info('Failed to connect for shard_id: %s. Retrying...' % shard_id) | ||||
|             yield from asyncio.sleep(5.0, loop=self.loop) | ||||
|             yield from self.launch_shard(gateway, shard_id) | ||||
|  | ||||
|         ws.token = self.http.token | ||||
|         ws._connection = self.connection | ||||
|         ws._dispatch = self.dispatch | ||||
|         ws.gateway = gateway | ||||
|         ws.shard_id = shard_id | ||||
|         ws.shard_count = self.shard_count | ||||
|  | ||||
|         # OP HELLO | ||||
|         yield from ws.poll_event() | ||||
|         yield from ws.identify() | ||||
|         log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id) | ||||
|  | ||||
|         # keep reading the shard while others connect | ||||
|         self.shards[shard_id] = ret = Shard(ws, self) | ||||
|         compat.create_task(self.pending_reads(ret), loop=self.loop) | ||||
|         yield from asyncio.sleep(5.0, loop=self.loop) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def launch_shards(self): | ||||
|         if self.shard_count is None: | ||||
|             self.shard_count, gateway = yield from self.http.get_bot_gateway() | ||||
|         else: | ||||
|             gateway = yield from self.http.get_gateway() | ||||
|  | ||||
|         self.connection.shard_count = self.shard_count | ||||
|  | ||||
|         for shard_id in range(self.shard_count): | ||||
|             yield from self.launch_shard(gateway, shard_id) | ||||
|  | ||||
|         self._still_sharding = False | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def connect(self): | ||||
|         """|coro| | ||||
| @@ -150,11 +203,10 @@ class AutoShardedClient(Client): | ||||
|         ConnectionClosed | ||||
|             The websocket connection has been terminated. | ||||
|         """ | ||||
|         ret = yield from DiscordWebSocket.from_sharded_client(self) | ||||
|         self.shards = [Shard(ws, self) for ws in ret] | ||||
|         yield from self.launch_shards() | ||||
|  | ||||
|         while not self.is_closed: | ||||
|             pollers = [shard.get_future() for shard in self.shards] | ||||
|             pollers = [shard.get_future() for shard in self.shards.values()] | ||||
|             yield from asyncio.wait(pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
| @@ -166,7 +218,7 @@ class AutoShardedClient(Client): | ||||
|         if self.is_closed: | ||||
|             return | ||||
|  | ||||
|         for shard in self.shards: | ||||
|         for shard in self.shards.values(): | ||||
|             yield from shard.ws.close() | ||||
|  | ||||
|         yield from self.http.close() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user