Allow overriding the shard_ids used for initial shard launch.

This commit is contained in:
Rapptz 2017-01-08 02:05:21 -05:00
parent 4bc6625739
commit 92c1637921
3 changed files with 22 additions and 6 deletions

View File

@ -406,7 +406,7 @@ class Client:
while not self.is_closed:
try:
yield from ws.poll_event()
yield from self.ws.poll_event()
except (ReconnectWebSocket, ResumeWebSocket) as e:
resume = type(e) is ResumeWebSocket
log.info('Got ' + type(e).__name__)

View File

@ -27,7 +27,7 @@ DEALINGS IN THE SOFTWARE.
from .state import AutoShardedConnectionState
from .client import Client
from .gateway import *
from .errors import ConnectionClosed
from .errors import ConnectionClosed, ClientException
from . import compat
from .enums import Status
@ -86,11 +86,28 @@ class AutoShardedClient(Client):
If no :attr:`shard_count` is provided, then the library will use the
Bot Gateway endpoint call to figure out how many shards to use.
If a ``shard_ids`` parameter is given, then those shard IDs will be used
to launch the internal shards. Note that :attr:`shard_count` must be provided
if this is used. By default, when omitted, the client will launch shards from
0 to ``shard_count - ``\.
Attributes
------------
shard_ids: Optional[List[int]]
An optional list of shard_ids to launch the shards with.
"""
def __init__(self, *args, loop=None, **kwargs):
kwargs.pop('shard_id', None)
self.shard_ids = kwargs.pop('shard_ids', None)
super().__init__(*args, loop=loop, **kwargs)
if self.shard_ids is not None:
if self.shard_count is None:
raise ClientException('When passing manual shard_ids, you must provide a shard_count.')
elif not isinstance(self.shard_ids, (list, tuple)):
raise ClientException('shard_ids parameter must be a list or a tuple.')
self.connection = AutoShardedConnectionState(dispatch=self.dispatch, chunker=self.request_offline_members,
syncer=self._syncer, http=self.http, loop=self.loop, **kwargs)
@ -184,7 +201,9 @@ class AutoShardedClient(Client):
self.connection.shard_count = self.shard_count
for shard_id in range(self.shard_count):
shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count)
for shard_id in shard_ids:
yield from self.launch_shard(gateway, shard_id)
self._still_sharding = False

View File

@ -743,9 +743,6 @@ class AutoShardedConnectionState(ConnectionState):
self.dispatch('shard_ready', shard_id)
# sleep a second for every shard ID.
# yield from asyncio.sleep(1.0, loop=self.loop)
# remove the state
try:
del self._ready_state