mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-11-04 07:22:50 +00:00 
			
		
		
		
	Implement AutoShardedClient for transparent sharding.
This allows people to run their >2,500 guild bot in a single process without the headaches of IPC/RPC or much difficulty.
This commit is contained in:
		@@ -37,6 +37,7 @@ from . import utils, opus, compat, abc
 | 
			
		||||
from .enums import ChannelType, GuildRegion, Status, MessageType, VerificationLevel
 | 
			
		||||
from collections import namedtuple
 | 
			
		||||
from .embeds import Embed
 | 
			
		||||
from .shard import AutoShardedClient
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -142,6 +142,7 @@ class Client:
 | 
			
		||||
        self.connection = ConnectionState(dispatch=self.dispatch, chunker=self.request_offline_members,
 | 
			
		||||
                                          syncer=self._syncer, http=self.http, loop=self.loop, **options)
 | 
			
		||||
 | 
			
		||||
        self.connection.shard_count = self.shard_count
 | 
			
		||||
        self._closed = asyncio.Event(loop=self.loop)
 | 
			
		||||
        self._is_logged_in = asyncio.Event(loop=self.loop)
 | 
			
		||||
        self._is_ready = asyncio.Event(loop=self.loop)
 | 
			
		||||
@@ -405,11 +406,14 @@ class Client:
 | 
			
		||||
 | 
			
		||||
        while not self.is_closed:
 | 
			
		||||
            try:
 | 
			
		||||
                yield from self.ws.poll_event()
 | 
			
		||||
                yield from ws.poll_event()
 | 
			
		||||
            except (ReconnectWebSocket, ResumeWebSocket) as e:
 | 
			
		||||
                resume = type(e) is ResumeWebSocket
 | 
			
		||||
                log.info('Got ' + type(e).__name__)
 | 
			
		||||
                self.ws = yield from DiscordWebSocket.from_client(self, resume=resume)
 | 
			
		||||
                self.ws = yield from DiscordWebSocket.from_client(self, shard_id=self.shard_id,
 | 
			
		||||
                                                                        session=self.ws.session_id,
 | 
			
		||||
                                                                        sequence=self.ws.sequence,
 | 
			
		||||
                                                                        resume=resume)
 | 
			
		||||
            except ConnectionClosed as e:
 | 
			
		||||
                yield from self.close()
 | 
			
		||||
                if e.code != 1000:
 | 
			
		||||
 
 | 
			
		||||
@@ -118,14 +118,17 @@ class ConnectionClosed(ClientException):
 | 
			
		||||
 | 
			
		||||
    Attributes
 | 
			
		||||
    -----------
 | 
			
		||||
    code : int
 | 
			
		||||
    code: int
 | 
			
		||||
        The close code of the websocket.
 | 
			
		||||
    reason : str
 | 
			
		||||
    reason: str
 | 
			
		||||
        The reason provided for the closure.
 | 
			
		||||
    shard_id: Optional[int]
 | 
			
		||||
        The shard ID that got closed if applicable.
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self, original):
 | 
			
		||||
    def __init__(self, original, *, shard_id):
 | 
			
		||||
        # This exception is just the same exception except
 | 
			
		||||
        # reconfigured to subclass ClientException for users
 | 
			
		||||
        self.code = original.code
 | 
			
		||||
        self.reason = original.reason
 | 
			
		||||
        self.shard_id = shard_id
 | 
			
		||||
        super().__init__(str(original))
 | 
			
		||||
 
 | 
			
		||||
@@ -47,11 +47,13 @@ __all__ = [ 'ReconnectWebSocket', 'DiscordWebSocket',
 | 
			
		||||
 | 
			
		||||
class ReconnectWebSocket(Exception):
 | 
			
		||||
    """Signals to handle the RECONNECT opcode."""
 | 
			
		||||
    pass
 | 
			
		||||
    def __init__(self, shard_id):
 | 
			
		||||
        self.shard_id = shard_id
 | 
			
		||||
 | 
			
		||||
class ResumeWebSocket(Exception):
 | 
			
		||||
    """Signals to initialise via RESUME opcode instead of IDENTIFY."""
 | 
			
		||||
    pass
 | 
			
		||||
    def __init__(self, shard_id):
 | 
			
		||||
        self.shard_id = shard_id
 | 
			
		||||
 | 
			
		||||
EventListener = namedtuple('EventListener', 'predicate event result future')
 | 
			
		||||
 | 
			
		||||
@@ -81,7 +83,7 @@ class KeepAliveHandler(threading.Thread):
 | 
			
		||||
    def get_payload(self):
 | 
			
		||||
        return {
 | 
			
		||||
            'op': self.ws.HEARTBEAT,
 | 
			
		||||
            'd': self.ws._connection.sequence
 | 
			
		||||
            'd': self.ws.sequence
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    def stop(self):
 | 
			
		||||
@@ -165,9 +167,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
 | 
			
		||||
        # the keep alive
 | 
			
		||||
        self._keep_alive = None
 | 
			
		||||
 | 
			
		||||
        # ws related stuff
 | 
			
		||||
        self.session_id = None
 | 
			
		||||
        self.sequence = None
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    @asyncio.coroutine
 | 
			
		||||
    def from_client(cls, client, *, resume=False):
 | 
			
		||||
    def from_client(cls, client, *, shard_id=None, session=None, sequence=None, resume=False):
 | 
			
		||||
        """Creates a main websocket for Discord from a :class:`Client`.
 | 
			
		||||
 | 
			
		||||
        This is for internal use only.
 | 
			
		||||
@@ -180,8 +186,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
 | 
			
		||||
        ws._connection = client.connection
 | 
			
		||||
        ws._dispatch = client.dispatch
 | 
			
		||||
        ws.gateway = gateway
 | 
			
		||||
        ws.shard_id = client.shard_id
 | 
			
		||||
        ws.shard_count = client.shard_count
 | 
			
		||||
        ws.shard_id = shard_id
 | 
			
		||||
        ws.shard_count = client.connection.shard_count
 | 
			
		||||
        ws.session_id = session
 | 
			
		||||
        ws.sequence = sequence
 | 
			
		||||
 | 
			
		||||
        client.connection._update_references(ws)
 | 
			
		||||
 | 
			
		||||
@@ -206,6 +214,35 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
 | 
			
		||||
        else:
 | 
			
		||||
            return ws
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    @asyncio.coroutine
 | 
			
		||||
    def from_sharded_client(cls, client):
 | 
			
		||||
        if client.shard_count is None:
 | 
			
		||||
            client.shard_count, gateway = yield from client.http.get_bot_gateway()
 | 
			
		||||
        else:
 | 
			
		||||
            gateway = yield from client.http.get_gateway()
 | 
			
		||||
 | 
			
		||||
        ret = []
 | 
			
		||||
        client.connection.shard_count = client.shard_count
 | 
			
		||||
 | 
			
		||||
        for shard_id in range(client.shard_count):
 | 
			
		||||
            ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
 | 
			
		||||
            ws.token = client.http.token
 | 
			
		||||
            ws._connection = client.connection
 | 
			
		||||
            ws._dispatch = client.dispatch
 | 
			
		||||
            ws.gateway = gateway
 | 
			
		||||
            ws.shard_id = shard_id
 | 
			
		||||
            ws.shard_count = client.shard_count
 | 
			
		||||
 | 
			
		||||
            # OP HELLO
 | 
			
		||||
            yield from ws.poll_event()
 | 
			
		||||
            yield from ws.identify()
 | 
			
		||||
            ret.append(ws)
 | 
			
		||||
            log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id)
 | 
			
		||||
            yield from asyncio.sleep(5.0, loop=client.loop)
 | 
			
		||||
 | 
			
		||||
        return ret
 | 
			
		||||
 | 
			
		||||
    def wait_for(self, event, predicate, result=None):
 | 
			
		||||
        """Waits for a DISPATCH'd event that meets the predicate.
 | 
			
		||||
 | 
			
		||||
@@ -262,12 +299,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
 | 
			
		||||
    @asyncio.coroutine
 | 
			
		||||
    def resume(self):
 | 
			
		||||
        """Sends the RESUME packet."""
 | 
			
		||||
        state = self._connection
 | 
			
		||||
        payload = {
 | 
			
		||||
            'op': self.RESUME,
 | 
			
		||||
            'd': {
 | 
			
		||||
                'seq': state.sequence,
 | 
			
		||||
                'session_id': state.session_id,
 | 
			
		||||
                'seq': self.sequence,
 | 
			
		||||
                'session_id': self.session_id,
 | 
			
		||||
                'token': self.token
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
@@ -283,16 +319,15 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
 | 
			
		||||
            msg = msg.decode('utf-8')
 | 
			
		||||
 | 
			
		||||
        msg = json.loads(msg)
 | 
			
		||||
        state = self._connection
 | 
			
		||||
 | 
			
		||||
        log.debug('WebSocket Event: {}'.format(msg))
 | 
			
		||||
        log.debug('For Shard ID {}: WebSocket Event: {}'.format(self.shard_id, msg))
 | 
			
		||||
        self._dispatch('socket_response', msg)
 | 
			
		||||
 | 
			
		||||
        op = msg.get('op')
 | 
			
		||||
        data = msg.get('d')
 | 
			
		||||
        seq = msg.get('s')
 | 
			
		||||
        if seq is not None:
 | 
			
		||||
            state.sequence = seq
 | 
			
		||||
            self.sequence = seq
 | 
			
		||||
 | 
			
		||||
        if op == self.RECONNECT:
 | 
			
		||||
            # "reconnect" can only be handled by the Client
 | 
			
		||||
@@ -300,7 +335,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
 | 
			
		||||
            # internal exception signalling to reconnect.
 | 
			
		||||
            log.info('Received RECONNECT opcode.')
 | 
			
		||||
            yield from self.close()
 | 
			
		||||
            raise ReconnectWebSocket()
 | 
			
		||||
            raise ReconnectWebSocket(self.shard_id)
 | 
			
		||||
 | 
			
		||||
        if op == self.HEARTBEAT_ACK:
 | 
			
		||||
            return # disable noisy logging for now
 | 
			
		||||
@@ -317,11 +352,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        if op == self.INVALIDATE_SESSION:
 | 
			
		||||
            state.sequence = None
 | 
			
		||||
            state.session_id = None
 | 
			
		||||
            self.sequence = None
 | 
			
		||||
            self.session_id = None
 | 
			
		||||
            if data == True:
 | 
			
		||||
                yield from self.close()
 | 
			
		||||
                raise ResumeWebSocket()
 | 
			
		||||
                raise ResumeWebSocket(self.shard_id)
 | 
			
		||||
 | 
			
		||||
            yield from self.identify()
 | 
			
		||||
            return
 | 
			
		||||
@@ -334,9 +369,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
 | 
			
		||||
        is_ready = event == 'READY'
 | 
			
		||||
 | 
			
		||||
        if is_ready:
 | 
			
		||||
            state.clear()
 | 
			
		||||
            state.sequence = msg['s']
 | 
			
		||||
            state.session_id = data['session_id']
 | 
			
		||||
            self.sequence = msg['s']
 | 
			
		||||
            self.session_id = data['session_id']
 | 
			
		||||
 | 
			
		||||
        parser = 'parse_' + event.lower()
 | 
			
		||||
 | 
			
		||||
@@ -389,9 +423,9 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
 | 
			
		||||
        except websockets.exceptions.ConnectionClosed as e:
 | 
			
		||||
            if self._can_handle_close(e.code):
 | 
			
		||||
                log.info('Websocket closed with {0.code} ({0.reason}), attempting a reconnect.'.format(e))
 | 
			
		||||
                raise ResumeWebSocket() from e
 | 
			
		||||
                raise ResumeWebSocket(self.shard_id) from e
 | 
			
		||||
            else:
 | 
			
		||||
                raise ConnectionClosed(e) from e
 | 
			
		||||
                raise ConnectionClosed(e, shard_id=self.shard_id) from e
 | 
			
		||||
 | 
			
		||||
    @asyncio.coroutine
 | 
			
		||||
    def send(self, data):
 | 
			
		||||
@@ -404,7 +438,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
 | 
			
		||||
            yield from super().send(utils.to_json(data))
 | 
			
		||||
        except websockets.exceptions.ConnectionClosed as e:
 | 
			
		||||
            if not self._can_handle_close(e.code):
 | 
			
		||||
                raise ConnectionClosed(e) from e
 | 
			
		||||
                raise ConnectionClosed(e, shard_id=self.shard_id) from e
 | 
			
		||||
 | 
			
		||||
    @asyncio.coroutine
 | 
			
		||||
    def change_presence(self, *, game=None, status=None, afk=False, since=0.0, idle=None):
 | 
			
		||||
@@ -615,7 +649,7 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
 | 
			
		||||
            msg = yield from asyncio.wait_for(self.recv(), timeout=30.0, loop=self.loop)
 | 
			
		||||
            yield from self.received_message(json.loads(msg))
 | 
			
		||||
        except websockets.exceptions.ConnectionClosed as e:
 | 
			
		||||
            raise ConnectionClosed(e) from e
 | 
			
		||||
            raise ConnectionClosed(e, shard_id=None) from e
 | 
			
		||||
 | 
			
		||||
    @asyncio.coroutine
 | 
			
		||||
    def close_connection(self, force=False):
 | 
			
		||||
 
 | 
			
		||||
@@ -324,6 +324,14 @@ class Guild(Hashable):
 | 
			
		||||
        """Returns the true member count regardless of it being loaded fully or not."""
 | 
			
		||||
        return self._member_count
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def shard_id(self):
 | 
			
		||||
        """Returns the shard ID for this guild if applicable."""
 | 
			
		||||
        count = self._state.shard_count
 | 
			
		||||
        if count is None:
 | 
			
		||||
            return None
 | 
			
		||||
        return (self.id >> 22) % count
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def created_at(self):
 | 
			
		||||
        """Returns the guild's creation time in UTC."""
 | 
			
		||||
 
 | 
			
		||||
@@ -588,5 +588,14 @@ class HTTPClient:
 | 
			
		||||
            raise GatewayNotFound() from e
 | 
			
		||||
        return data.get('url') + '?encoding=json&v=6'
 | 
			
		||||
 | 
			
		||||
    @asyncio.coroutine
 | 
			
		||||
    def get_bot_gateway(self):
 | 
			
		||||
        try:
 | 
			
		||||
            data = yield from self.get(self.GATEWAY + '/bot', bucket=_func_())
 | 
			
		||||
        except HTTPException as e:
 | 
			
		||||
            raise GatewayNotFound() from e
 | 
			
		||||
        else:
 | 
			
		||||
            return data['shards'], data['url'] + '?encoding=json&v=6'
 | 
			
		||||
 | 
			
		||||
    def get_user_info(self, user_id):
 | 
			
		||||
        return self.get('{0.USERS}/{1}'.format(self, user_id), bucket=_func_())
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										174
									
								
								discord/shard.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										174
									
								
								discord/shard.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,174 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
The MIT License (MIT)
 | 
			
		||||
 | 
			
		||||
Copyright (c) 2015-2016 Rapptz
 | 
			
		||||
 | 
			
		||||
Permission is hereby granted, free of charge, to any person obtaining a
 | 
			
		||||
copy of this software and associated documentation files (the "Software"),
 | 
			
		||||
to deal in the Software without restriction, including without limitation
 | 
			
		||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
 | 
			
		||||
and/or sell copies of the Software, and to permit persons to whom the
 | 
			
		||||
Software is furnished to do so, subject to the following conditions:
 | 
			
		||||
 | 
			
		||||
The above copyright notice and this permission notice shall be included in
 | 
			
		||||
all copies or substantial portions of the Software.
 | 
			
		||||
 | 
			
		||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 | 
			
		||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 | 
			
		||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 | 
			
		||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 | 
			
		||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 | 
			
		||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 | 
			
		||||
DEALINGS IN THE SOFTWARE.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
from .state import AutoShardedConnectionState
 | 
			
		||||
from .client import Client
 | 
			
		||||
from .gateway import *
 | 
			
		||||
from .errors import ConnectionClosed
 | 
			
		||||
from . import compat
 | 
			
		||||
 | 
			
		||||
import asyncio
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
log = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
class Shard:
 | 
			
		||||
    def __init__(self, ws, client):
 | 
			
		||||
        self.ws = ws
 | 
			
		||||
        self._client = client
 | 
			
		||||
        self.loop = self._client.loop
 | 
			
		||||
        self._current = asyncio.Future(loop=self.loop)
 | 
			
		||||
        self._current.set_result(None) # we just need an already done future
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def id(self):
 | 
			
		||||
        return self.ws.shard_id
 | 
			
		||||
 | 
			
		||||
    @asyncio.coroutine
 | 
			
		||||
    def poll(self):
 | 
			
		||||
        try:
 | 
			
		||||
            yield from self.ws.poll_event()
 | 
			
		||||
        except (ReconnectWebSocket, ResumeWebSocket) as e:
 | 
			
		||||
            resume = type(e) is ResumeWebSocket
 | 
			
		||||
            log.info('Got ' + type(e).__name__)
 | 
			
		||||
            self.ws = yield from DiscordWebSocket.from_client(self._client, resume=resume,
 | 
			
		||||
                                                                            shard_id=self.id,
 | 
			
		||||
                                                                            session=self.ws.session_id,
 | 
			
		||||
                                                                            sequence=self.ws.sequence)
 | 
			
		||||
        except ConnectionClosed as e:
 | 
			
		||||
            yield from self._client.close()
 | 
			
		||||
            if e.code != 1000:
 | 
			
		||||
                raise
 | 
			
		||||
 | 
			
		||||
    def get_future(self):
 | 
			
		||||
        if self._current.done():
 | 
			
		||||
            self._current = compat.create_task(self.poll(), loop=self.loop)
 | 
			
		||||
 | 
			
		||||
        return self._current
 | 
			
		||||
 | 
			
		||||
class AutoShardedClient(Client):
 | 
			
		||||
    """A client similar to :class:`Client` except it handles the complications
 | 
			
		||||
    of sharding for the user into a more manageable and transparent single
 | 
			
		||||
    process bot.
 | 
			
		||||
 | 
			
		||||
    When using this client, you will be able to use it as-if it was a regular
 | 
			
		||||
    :class:`Client` with a single shard when implementation wise internally it
 | 
			
		||||
    is split up into multiple shards. This allows you to not have to deal with
 | 
			
		||||
    IPC or other complicated infrastructure.
 | 
			
		||||
 | 
			
		||||
    It is recommended to use this client only if you have surpassed at least
 | 
			
		||||
    1000 guilds.
 | 
			
		||||
 | 
			
		||||
    If no :attr:`shard_count` is provided, then the library will use the
 | 
			
		||||
    Bot Gateway endpoint call to figure out how many shards to use.
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self, *args, loop=None, **kwargs):
 | 
			
		||||
        kwargs.pop('shard_id', None)
 | 
			
		||||
        super().__init__(*args, loop=loop, **kwargs)
 | 
			
		||||
 | 
			
		||||
        self.connection = AutoShardedConnectionState(dispatch=self.dispatch, chunker=self.request_offline_members,
 | 
			
		||||
                                                     syncer=self._syncer, http=self.http, loop=self.loop, **kwargs)
 | 
			
		||||
 | 
			
		||||
        # instead of a single websocket, we have multiple
 | 
			
		||||
        # the index is the shard_id
 | 
			
		||||
        self.shards = []
 | 
			
		||||
 | 
			
		||||
    @asyncio.coroutine
 | 
			
		||||
    def request_offline_members(self, guild, *, shard_id=None):
 | 
			
		||||
        """|coro|
 | 
			
		||||
 | 
			
		||||
        Requests previously offline members from the guild to be filled up
 | 
			
		||||
        into the :attr:`Guild.members` cache. This function is usually not
 | 
			
		||||
        called.
 | 
			
		||||
 | 
			
		||||
        When the client logs on and connects to the websocket, Discord does
 | 
			
		||||
        not provide the library with offline members if the number of members
 | 
			
		||||
        in the guild is larger than 250. You can check if a guild is large
 | 
			
		||||
        if :attr:`Guild.large` is ``True``.
 | 
			
		||||
 | 
			
		||||
        Parameters
 | 
			
		||||
        -----------
 | 
			
		||||
        guild: :class:`Guild` or list
 | 
			
		||||
            The guild to request offline members for. If this parameter is a
 | 
			
		||||
            list then it is interpreted as a list of guilds to request offline
 | 
			
		||||
            members for.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            guild_id = guild.id
 | 
			
		||||
            shard_id = shard_id or guild.shard_id
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            guild_id = [s.id for s in guild]
 | 
			
		||||
 | 
			
		||||
        payload = {
 | 
			
		||||
            'op': 8,
 | 
			
		||||
            'd': {
 | 
			
		||||
                'guild_id': guild_id,
 | 
			
		||||
                'query': '',
 | 
			
		||||
                'limit': 0
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        ws = self.shards[shard_id].ws
 | 
			
		||||
        yield from ws.send_as_json(payload)
 | 
			
		||||
 | 
			
		||||
    @asyncio.coroutine
 | 
			
		||||
    def connect(self):
 | 
			
		||||
        """|coro|
 | 
			
		||||
 | 
			
		||||
        Creates a websocket connection and lets the websocket listen
 | 
			
		||||
        to messages from discord.
 | 
			
		||||
 | 
			
		||||
        Raises
 | 
			
		||||
        -------
 | 
			
		||||
        GatewayNotFound
 | 
			
		||||
            If the gateway to connect to discord is not found. Usually if this
 | 
			
		||||
            is thrown then there is a discord API outage.
 | 
			
		||||
        ConnectionClosed
 | 
			
		||||
            The websocket connection has been terminated.
 | 
			
		||||
        """
 | 
			
		||||
        ret = yield from DiscordWebSocket.from_sharded_client(self)
 | 
			
		||||
        self.shards = [Shard(ws, self) for ws in ret]
 | 
			
		||||
 | 
			
		||||
        while not self.is_closed:
 | 
			
		||||
            pollers = [shard.get_future() for shard in self.shards]
 | 
			
		||||
            yield from asyncio.wait(pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED)
 | 
			
		||||
 | 
			
		||||
    @asyncio.coroutine
 | 
			
		||||
    def close(self):
 | 
			
		||||
        """|coro|
 | 
			
		||||
 | 
			
		||||
        Closes the connection to discord.
 | 
			
		||||
        """
 | 
			
		||||
        if self.is_closed:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        for shard in self.shards:
 | 
			
		||||
            yield from shard.ws.close()
 | 
			
		||||
 | 
			
		||||
        yield from self.http.close()
 | 
			
		||||
        self._closed.set()
 | 
			
		||||
        self._is_ready.clear()
 | 
			
		||||
@@ -43,6 +43,7 @@ import datetime
 | 
			
		||||
import asyncio
 | 
			
		||||
import logging
 | 
			
		||||
import weakref
 | 
			
		||||
import itertools
 | 
			
		||||
 | 
			
		||||
class ListenerType(enum.Enum):
 | 
			
		||||
    chunk = 0
 | 
			
		||||
@@ -60,13 +61,12 @@ class ConnectionState:
 | 
			
		||||
        self.chunker = chunker
 | 
			
		||||
        self.syncer = syncer
 | 
			
		||||
        self.is_bot = None
 | 
			
		||||
        self.shard_count = None
 | 
			
		||||
        self._listeners = []
 | 
			
		||||
        self.clear()
 | 
			
		||||
 | 
			
		||||
    def clear(self):
 | 
			
		||||
        self.user = None
 | 
			
		||||
        self.sequence = None
 | 
			
		||||
        self.session_id = None
 | 
			
		||||
        self._users = weakref.WeakValueDictionary()
 | 
			
		||||
        self._calls = {}
 | 
			
		||||
        self._emojis = {}
 | 
			
		||||
@@ -355,7 +355,8 @@ class ConnectionState:
 | 
			
		||||
            # the reason we're doing this is so it's also removed from the
 | 
			
		||||
            # private channel by user cache as well
 | 
			
		||||
            channel = self._get_private_channel(channel_id)
 | 
			
		||||
            self._remove_private_channel(channel)
 | 
			
		||||
            if channel is not None:
 | 
			
		||||
                self._remove_private_channel(channel)
 | 
			
		||||
 | 
			
		||||
    def parse_channel_update(self, data):
 | 
			
		||||
        channel_type = try_enum(ChannelType, data.get('type'))
 | 
			
		||||
@@ -701,3 +702,76 @@ class ConnectionState:
 | 
			
		||||
        listener = Listener(ListenerType.chunk, future, lambda s: s.id == guild_id)
 | 
			
		||||
        self._listeners.append(listener)
 | 
			
		||||
        return future
 | 
			
		||||
 | 
			
		||||
class AutoShardedConnectionState(ConnectionState):
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
        self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[])
 | 
			
		||||
        self._ready_task = None
 | 
			
		||||
 | 
			
		||||
    @asyncio.coroutine
 | 
			
		||||
    def _delay_ready(self):
 | 
			
		||||
        launch = self._ready_state.launch
 | 
			
		||||
        while not launch.is_set():
 | 
			
		||||
            # this snippet of code is basically waiting 2 seconds
 | 
			
		||||
            # until the last GUILD_CREATE was sent
 | 
			
		||||
            launch.set()
 | 
			
		||||
            yield from asyncio.sleep(2.0 * self.shard_count, loop=self.loop)
 | 
			
		||||
 | 
			
		||||
        guilds = sorted(self._ready_state.guilds, key=lambda g: g.shard_id)
 | 
			
		||||
 | 
			
		||||
        # we only want to request ~75 guilds per chunk request.
 | 
			
		||||
        # we also want to split the chunks per shard_id
 | 
			
		||||
        for shard_id, sub_guilds in itertools.groupby(guilds, key=lambda g: g.shard_id):
 | 
			
		||||
            sub_guilds = list(sub_guilds)
 | 
			
		||||
 | 
			
		||||
            # split chunks by shard ID
 | 
			
		||||
            chunks = []
 | 
			
		||||
            for guild in sub_guilds:
 | 
			
		||||
                chunks.extend(self.chunks_needed(guild))
 | 
			
		||||
 | 
			
		||||
            splits = [sub_guilds[i:i + 75] for i in range(0, len(sub_guilds), 75)]
 | 
			
		||||
            for split in splits:
 | 
			
		||||
                yield from self.chunker(split, shard_id=shard_id)
 | 
			
		||||
 | 
			
		||||
            # wait for the chunks
 | 
			
		||||
            if chunks:
 | 
			
		||||
                try:
 | 
			
		||||
                    yield from asyncio.wait(chunks, timeout=len(chunks) * 30.0, loop=self.loop)
 | 
			
		||||
                except asyncio.TimeoutError:
 | 
			
		||||
                    log.info('Somehow timed out waiting for chunks for %s shard_id' % shard_id)
 | 
			
		||||
 | 
			
		||||
            self.dispatch('shard_ready', shard_id)
 | 
			
		||||
 | 
			
		||||
            # sleep a second for every shard ID.
 | 
			
		||||
            # yield from asyncio.sleep(1.0, loop=self.loop)
 | 
			
		||||
 | 
			
		||||
        # remove the state
 | 
			
		||||
        try:
 | 
			
		||||
            del self._ready_state
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            pass # already been deleted somehow
 | 
			
		||||
 | 
			
		||||
        # regular users cannot shard so we won't worry about it here.
 | 
			
		||||
 | 
			
		||||
        # dispatch the event
 | 
			
		||||
        self.dispatch('ready')
 | 
			
		||||
 | 
			
		||||
    def parse_ready(self, data):
 | 
			
		||||
        if not hasattr(self, '_ready_state'):
 | 
			
		||||
            self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[])
 | 
			
		||||
 | 
			
		||||
        self.user = self.store_user(data['user'])
 | 
			
		||||
 | 
			
		||||
        guilds = self._ready_state.guilds
 | 
			
		||||
        for guild_data in data['guilds']:
 | 
			
		||||
            guild = self._add_guild_from_data(guild_data)
 | 
			
		||||
            if not self.is_bot or guild.large:
 | 
			
		||||
                guilds.append(guild)
 | 
			
		||||
 | 
			
		||||
        for pm in data.get('private_channels', []):
 | 
			
		||||
            factory, _ = _channel_factory(pm['type'])
 | 
			
		||||
            self._add_private_channel(factory(me=self.user, data=pm, state=self))
 | 
			
		||||
 | 
			
		||||
        if self._ready_task is None:
 | 
			
		||||
            self._ready_task = compat.create_task(self._delay_ready(), loop=self.loop)
 | 
			
		||||
 
 | 
			
		||||
@@ -37,6 +37,9 @@ Client
 | 
			
		||||
.. autoclass:: Client
 | 
			
		||||
    :members:
 | 
			
		||||
 | 
			
		||||
.. autoclass:: AutoShardedClient
 | 
			
		||||
    :members:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Voice
 | 
			
		||||
-----
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user