Implement AutoShardedClient for transparent sharding.
This allows people to run their >2,500 guild bot in a single process without the headaches of IPC/RPC or much difficulty.
This commit is contained in:
parent
d54d7f7ac0
commit
20041ea756
@ -37,6 +37,7 @@ from . import utils, opus, compat, abc
|
|||||||
from .enums import ChannelType, GuildRegion, Status, MessageType, VerificationLevel
|
from .enums import ChannelType, GuildRegion, Status, MessageType, VerificationLevel
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from .embeds import Embed
|
from .embeds import Embed
|
||||||
|
from .shard import AutoShardedClient
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -142,6 +142,7 @@ class Client:
|
|||||||
self.connection = ConnectionState(dispatch=self.dispatch, chunker=self.request_offline_members,
|
self.connection = ConnectionState(dispatch=self.dispatch, chunker=self.request_offline_members,
|
||||||
syncer=self._syncer, http=self.http, loop=self.loop, **options)
|
syncer=self._syncer, http=self.http, loop=self.loop, **options)
|
||||||
|
|
||||||
|
self.connection.shard_count = self.shard_count
|
||||||
self._closed = asyncio.Event(loop=self.loop)
|
self._closed = asyncio.Event(loop=self.loop)
|
||||||
self._is_logged_in = asyncio.Event(loop=self.loop)
|
self._is_logged_in = asyncio.Event(loop=self.loop)
|
||||||
self._is_ready = asyncio.Event(loop=self.loop)
|
self._is_ready = asyncio.Event(loop=self.loop)
|
||||||
@ -405,11 +406,14 @@ class Client:
|
|||||||
|
|
||||||
while not self.is_closed:
|
while not self.is_closed:
|
||||||
try:
|
try:
|
||||||
yield from self.ws.poll_event()
|
yield from ws.poll_event()
|
||||||
except (ReconnectWebSocket, ResumeWebSocket) as e:
|
except (ReconnectWebSocket, ResumeWebSocket) as e:
|
||||||
resume = type(e) is ResumeWebSocket
|
resume = type(e) is ResumeWebSocket
|
||||||
log.info('Got ' + type(e).__name__)
|
log.info('Got ' + type(e).__name__)
|
||||||
self.ws = yield from DiscordWebSocket.from_client(self, resume=resume)
|
self.ws = yield from DiscordWebSocket.from_client(self, shard_id=self.shard_id,
|
||||||
|
session=self.ws.session_id,
|
||||||
|
sequence=self.ws.sequence,
|
||||||
|
resume=resume)
|
||||||
except ConnectionClosed as e:
|
except ConnectionClosed as e:
|
||||||
yield from self.close()
|
yield from self.close()
|
||||||
if e.code != 1000:
|
if e.code != 1000:
|
||||||
|
@ -118,14 +118,17 @@ class ConnectionClosed(ClientException):
|
|||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
-----------
|
-----------
|
||||||
code : int
|
code: int
|
||||||
The close code of the websocket.
|
The close code of the websocket.
|
||||||
reason : str
|
reason: str
|
||||||
The reason provided for the closure.
|
The reason provided for the closure.
|
||||||
|
shard_id: Optional[int]
|
||||||
|
The shard ID that got closed if applicable.
|
||||||
"""
|
"""
|
||||||
def __init__(self, original):
|
def __init__(self, original, *, shard_id):
|
||||||
# This exception is just the same exception except
|
# This exception is just the same exception except
|
||||||
# reconfigured to subclass ClientException for users
|
# reconfigured to subclass ClientException for users
|
||||||
self.code = original.code
|
self.code = original.code
|
||||||
self.reason = original.reason
|
self.reason = original.reason
|
||||||
|
self.shard_id = shard_id
|
||||||
super().__init__(str(original))
|
super().__init__(str(original))
|
||||||
|
@ -47,11 +47,13 @@ __all__ = [ 'ReconnectWebSocket', 'DiscordWebSocket',
|
|||||||
|
|
||||||
class ReconnectWebSocket(Exception):
|
class ReconnectWebSocket(Exception):
|
||||||
"""Signals to handle the RECONNECT opcode."""
|
"""Signals to handle the RECONNECT opcode."""
|
||||||
pass
|
def __init__(self, shard_id):
|
||||||
|
self.shard_id = shard_id
|
||||||
|
|
||||||
class ResumeWebSocket(Exception):
|
class ResumeWebSocket(Exception):
|
||||||
"""Signals to initialise via RESUME opcode instead of IDENTIFY."""
|
"""Signals to initialise via RESUME opcode instead of IDENTIFY."""
|
||||||
pass
|
def __init__(self, shard_id):
|
||||||
|
self.shard_id = shard_id
|
||||||
|
|
||||||
EventListener = namedtuple('EventListener', 'predicate event result future')
|
EventListener = namedtuple('EventListener', 'predicate event result future')
|
||||||
|
|
||||||
@ -81,7 +83,7 @@ class KeepAliveHandler(threading.Thread):
|
|||||||
def get_payload(self):
|
def get_payload(self):
|
||||||
return {
|
return {
|
||||||
'op': self.ws.HEARTBEAT,
|
'op': self.ws.HEARTBEAT,
|
||||||
'd': self.ws._connection.sequence
|
'd': self.ws.sequence
|
||||||
}
|
}
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
@ -165,9 +167,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
# the keep alive
|
# the keep alive
|
||||||
self._keep_alive = None
|
self._keep_alive = None
|
||||||
|
|
||||||
|
# ws related stuff
|
||||||
|
self.session_id = None
|
||||||
|
self.sequence = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def from_client(cls, client, *, resume=False):
|
def from_client(cls, client, *, shard_id=None, session=None, sequence=None, resume=False):
|
||||||
"""Creates a main websocket for Discord from a :class:`Client`.
|
"""Creates a main websocket for Discord from a :class:`Client`.
|
||||||
|
|
||||||
This is for internal use only.
|
This is for internal use only.
|
||||||
@ -180,8 +186,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
ws._connection = client.connection
|
ws._connection = client.connection
|
||||||
ws._dispatch = client.dispatch
|
ws._dispatch = client.dispatch
|
||||||
ws.gateway = gateway
|
ws.gateway = gateway
|
||||||
ws.shard_id = client.shard_id
|
ws.shard_id = shard_id
|
||||||
ws.shard_count = client.shard_count
|
ws.shard_count = client.connection.shard_count
|
||||||
|
ws.session_id = session
|
||||||
|
ws.sequence = sequence
|
||||||
|
|
||||||
client.connection._update_references(ws)
|
client.connection._update_references(ws)
|
||||||
|
|
||||||
@ -206,6 +214,35 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
else:
|
else:
|
||||||
return ws
|
return ws
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@asyncio.coroutine
|
||||||
|
def from_sharded_client(cls, client):
|
||||||
|
if client.shard_count is None:
|
||||||
|
client.shard_count, gateway = yield from client.http.get_bot_gateway()
|
||||||
|
else:
|
||||||
|
gateway = yield from client.http.get_gateway()
|
||||||
|
|
||||||
|
ret = []
|
||||||
|
client.connection.shard_count = client.shard_count
|
||||||
|
|
||||||
|
for shard_id in range(client.shard_count):
|
||||||
|
ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
|
||||||
|
ws.token = client.http.token
|
||||||
|
ws._connection = client.connection
|
||||||
|
ws._dispatch = client.dispatch
|
||||||
|
ws.gateway = gateway
|
||||||
|
ws.shard_id = shard_id
|
||||||
|
ws.shard_count = client.shard_count
|
||||||
|
|
||||||
|
# OP HELLO
|
||||||
|
yield from ws.poll_event()
|
||||||
|
yield from ws.identify()
|
||||||
|
ret.append(ws)
|
||||||
|
log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id)
|
||||||
|
yield from asyncio.sleep(5.0, loop=client.loop)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
def wait_for(self, event, predicate, result=None):
|
def wait_for(self, event, predicate, result=None):
|
||||||
"""Waits for a DISPATCH'd event that meets the predicate.
|
"""Waits for a DISPATCH'd event that meets the predicate.
|
||||||
|
|
||||||
@ -262,12 +299,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def resume(self):
|
def resume(self):
|
||||||
"""Sends the RESUME packet."""
|
"""Sends the RESUME packet."""
|
||||||
state = self._connection
|
|
||||||
payload = {
|
payload = {
|
||||||
'op': self.RESUME,
|
'op': self.RESUME,
|
||||||
'd': {
|
'd': {
|
||||||
'seq': state.sequence,
|
'seq': self.sequence,
|
||||||
'session_id': state.session_id,
|
'session_id': self.session_id,
|
||||||
'token': self.token
|
'token': self.token
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -283,16 +319,15 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
msg = msg.decode('utf-8')
|
msg = msg.decode('utf-8')
|
||||||
|
|
||||||
msg = json.loads(msg)
|
msg = json.loads(msg)
|
||||||
state = self._connection
|
|
||||||
|
|
||||||
log.debug('WebSocket Event: {}'.format(msg))
|
log.debug('For Shard ID {}: WebSocket Event: {}'.format(self.shard_id, msg))
|
||||||
self._dispatch('socket_response', msg)
|
self._dispatch('socket_response', msg)
|
||||||
|
|
||||||
op = msg.get('op')
|
op = msg.get('op')
|
||||||
data = msg.get('d')
|
data = msg.get('d')
|
||||||
seq = msg.get('s')
|
seq = msg.get('s')
|
||||||
if seq is not None:
|
if seq is not None:
|
||||||
state.sequence = seq
|
self.sequence = seq
|
||||||
|
|
||||||
if op == self.RECONNECT:
|
if op == self.RECONNECT:
|
||||||
# "reconnect" can only be handled by the Client
|
# "reconnect" can only be handled by the Client
|
||||||
@ -300,7 +335,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
# internal exception signalling to reconnect.
|
# internal exception signalling to reconnect.
|
||||||
log.info('Received RECONNECT opcode.')
|
log.info('Received RECONNECT opcode.')
|
||||||
yield from self.close()
|
yield from self.close()
|
||||||
raise ReconnectWebSocket()
|
raise ReconnectWebSocket(self.shard_id)
|
||||||
|
|
||||||
if op == self.HEARTBEAT_ACK:
|
if op == self.HEARTBEAT_ACK:
|
||||||
return # disable noisy logging for now
|
return # disable noisy logging for now
|
||||||
@ -317,11 +352,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if op == self.INVALIDATE_SESSION:
|
if op == self.INVALIDATE_SESSION:
|
||||||
state.sequence = None
|
self.sequence = None
|
||||||
state.session_id = None
|
self.session_id = None
|
||||||
if data == True:
|
if data == True:
|
||||||
yield from self.close()
|
yield from self.close()
|
||||||
raise ResumeWebSocket()
|
raise ResumeWebSocket(self.shard_id)
|
||||||
|
|
||||||
yield from self.identify()
|
yield from self.identify()
|
||||||
return
|
return
|
||||||
@ -334,9 +369,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
is_ready = event == 'READY'
|
is_ready = event == 'READY'
|
||||||
|
|
||||||
if is_ready:
|
if is_ready:
|
||||||
state.clear()
|
self.sequence = msg['s']
|
||||||
state.sequence = msg['s']
|
self.session_id = data['session_id']
|
||||||
state.session_id = data['session_id']
|
|
||||||
|
|
||||||
parser = 'parse_' + event.lower()
|
parser = 'parse_' + event.lower()
|
||||||
|
|
||||||
@ -389,9 +423,9 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
except websockets.exceptions.ConnectionClosed as e:
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
if self._can_handle_close(e.code):
|
if self._can_handle_close(e.code):
|
||||||
log.info('Websocket closed with {0.code} ({0.reason}), attempting a reconnect.'.format(e))
|
log.info('Websocket closed with {0.code} ({0.reason}), attempting a reconnect.'.format(e))
|
||||||
raise ResumeWebSocket() from e
|
raise ResumeWebSocket(self.shard_id) from e
|
||||||
else:
|
else:
|
||||||
raise ConnectionClosed(e) from e
|
raise ConnectionClosed(e, shard_id=self.shard_id) from e
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def send(self, data):
|
def send(self, data):
|
||||||
@ -404,7 +438,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
yield from super().send(utils.to_json(data))
|
yield from super().send(utils.to_json(data))
|
||||||
except websockets.exceptions.ConnectionClosed as e:
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
if not self._can_handle_close(e.code):
|
if not self._can_handle_close(e.code):
|
||||||
raise ConnectionClosed(e) from e
|
raise ConnectionClosed(e, shard_id=self.shard_id) from e
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def change_presence(self, *, game=None, status=None, afk=False, since=0.0, idle=None):
|
def change_presence(self, *, game=None, status=None, afk=False, since=0.0, idle=None):
|
||||||
@ -615,7 +649,7 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
msg = yield from asyncio.wait_for(self.recv(), timeout=30.0, loop=self.loop)
|
msg = yield from asyncio.wait_for(self.recv(), timeout=30.0, loop=self.loop)
|
||||||
yield from self.received_message(json.loads(msg))
|
yield from self.received_message(json.loads(msg))
|
||||||
except websockets.exceptions.ConnectionClosed as e:
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
raise ConnectionClosed(e) from e
|
raise ConnectionClosed(e, shard_id=None) from e
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def close_connection(self, force=False):
|
def close_connection(self, force=False):
|
||||||
|
@ -324,6 +324,14 @@ class Guild(Hashable):
|
|||||||
"""Returns the true member count regardless of it being loaded fully or not."""
|
"""Returns the true member count regardless of it being loaded fully or not."""
|
||||||
return self._member_count
|
return self._member_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shard_id(self):
|
||||||
|
"""Returns the shard ID for this guild if applicable."""
|
||||||
|
count = self._state.shard_count
|
||||||
|
if count is None:
|
||||||
|
return None
|
||||||
|
return (self.id >> 22) % count
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def created_at(self):
|
def created_at(self):
|
||||||
"""Returns the guild's creation time in UTC."""
|
"""Returns the guild's creation time in UTC."""
|
||||||
|
@ -588,5 +588,14 @@ class HTTPClient:
|
|||||||
raise GatewayNotFound() from e
|
raise GatewayNotFound() from e
|
||||||
return data.get('url') + '?encoding=json&v=6'
|
return data.get('url') + '?encoding=json&v=6'
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def get_bot_gateway(self):
|
||||||
|
try:
|
||||||
|
data = yield from self.get(self.GATEWAY + '/bot', bucket=_func_())
|
||||||
|
except HTTPException as e:
|
||||||
|
raise GatewayNotFound() from e
|
||||||
|
else:
|
||||||
|
return data['shards'], data['url'] + '?encoding=json&v=6'
|
||||||
|
|
||||||
def get_user_info(self, user_id):
|
def get_user_info(self, user_id):
|
||||||
return self.get('{0.USERS}/{1}'.format(self, user_id), bucket=_func_())
|
return self.get('{0.USERS}/{1}'.format(self, user_id), bucket=_func_())
|
||||||
|
174
discord/shard.py
Normal file
174
discord/shard.py
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2015-2016 Rapptz
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
copy of this software and associated documentation files (the "Software"),
|
||||||
|
to deal in the Software without restriction, including without limitation
|
||||||
|
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
Software is furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in
|
||||||
|
all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||||
|
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
DEALINGS IN THE SOFTWARE.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .state import AutoShardedConnectionState
|
||||||
|
from .client import Client
|
||||||
|
from .gateway import *
|
||||||
|
from .errors import ConnectionClosed
|
||||||
|
from . import compat
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class Shard:
|
||||||
|
def __init__(self, ws, client):
|
||||||
|
self.ws = ws
|
||||||
|
self._client = client
|
||||||
|
self.loop = self._client.loop
|
||||||
|
self._current = asyncio.Future(loop=self.loop)
|
||||||
|
self._current.set_result(None) # we just need an already done future
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self):
|
||||||
|
return self.ws.shard_id
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def poll(self):
|
||||||
|
try:
|
||||||
|
yield from self.ws.poll_event()
|
||||||
|
except (ReconnectWebSocket, ResumeWebSocket) as e:
|
||||||
|
resume = type(e) is ResumeWebSocket
|
||||||
|
log.info('Got ' + type(e).__name__)
|
||||||
|
self.ws = yield from DiscordWebSocket.from_client(self._client, resume=resume,
|
||||||
|
shard_id=self.id,
|
||||||
|
session=self.ws.session_id,
|
||||||
|
sequence=self.ws.sequence)
|
||||||
|
except ConnectionClosed as e:
|
||||||
|
yield from self._client.close()
|
||||||
|
if e.code != 1000:
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_future(self):
|
||||||
|
if self._current.done():
|
||||||
|
self._current = compat.create_task(self.poll(), loop=self.loop)
|
||||||
|
|
||||||
|
return self._current
|
||||||
|
|
||||||
|
class AutoShardedClient(Client):
|
||||||
|
"""A client similar to :class:`Client` except it handles the complications
|
||||||
|
of sharding for the user into a more manageable and transparent single
|
||||||
|
process bot.
|
||||||
|
|
||||||
|
When using this client, you will be able to use it as-if it was a regular
|
||||||
|
:class:`Client` with a single shard when implementation wise internally it
|
||||||
|
is split up into multiple shards. This allows you to not have to deal with
|
||||||
|
IPC or other complicated infrastructure.
|
||||||
|
|
||||||
|
It is recommended to use this client only if you have surpassed at least
|
||||||
|
1000 guilds.
|
||||||
|
|
||||||
|
If no :attr:`shard_count` is provided, then the library will use the
|
||||||
|
Bot Gateway endpoint call to figure out how many shards to use.
|
||||||
|
"""
|
||||||
|
def __init__(self, *args, loop=None, **kwargs):
|
||||||
|
kwargs.pop('shard_id', None)
|
||||||
|
super().__init__(*args, loop=loop, **kwargs)
|
||||||
|
|
||||||
|
self.connection = AutoShardedConnectionState(dispatch=self.dispatch, chunker=self.request_offline_members,
|
||||||
|
syncer=self._syncer, http=self.http, loop=self.loop, **kwargs)
|
||||||
|
|
||||||
|
# instead of a single websocket, we have multiple
|
||||||
|
# the index is the shard_id
|
||||||
|
self.shards = []
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def request_offline_members(self, guild, *, shard_id=None):
|
||||||
|
"""|coro|
|
||||||
|
|
||||||
|
Requests previously offline members from the guild to be filled up
|
||||||
|
into the :attr:`Guild.members` cache. This function is usually not
|
||||||
|
called.
|
||||||
|
|
||||||
|
When the client logs on and connects to the websocket, Discord does
|
||||||
|
not provide the library with offline members if the number of members
|
||||||
|
in the guild is larger than 250. You can check if a guild is large
|
||||||
|
if :attr:`Guild.large` is ``True``.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
guild: :class:`Guild` or list
|
||||||
|
The guild to request offline members for. If this parameter is a
|
||||||
|
list then it is interpreted as a list of guilds to request offline
|
||||||
|
members for.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
guild_id = guild.id
|
||||||
|
shard_id = shard_id or guild.shard_id
|
||||||
|
except AttributeError:
|
||||||
|
guild_id = [s.id for s in guild]
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
'op': 8,
|
||||||
|
'd': {
|
||||||
|
'guild_id': guild_id,
|
||||||
|
'query': '',
|
||||||
|
'limit': 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ws = self.shards[shard_id].ws
|
||||||
|
yield from ws.send_as_json(payload)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def connect(self):
|
||||||
|
"""|coro|
|
||||||
|
|
||||||
|
Creates a websocket connection and lets the websocket listen
|
||||||
|
to messages from discord.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
-------
|
||||||
|
GatewayNotFound
|
||||||
|
If the gateway to connect to discord is not found. Usually if this
|
||||||
|
is thrown then there is a discord API outage.
|
||||||
|
ConnectionClosed
|
||||||
|
The websocket connection has been terminated.
|
||||||
|
"""
|
||||||
|
ret = yield from DiscordWebSocket.from_sharded_client(self)
|
||||||
|
self.shards = [Shard(ws, self) for ws in ret]
|
||||||
|
|
||||||
|
while not self.is_closed:
|
||||||
|
pollers = [shard.get_future() for shard in self.shards]
|
||||||
|
yield from asyncio.wait(pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def close(self):
|
||||||
|
"""|coro|
|
||||||
|
|
||||||
|
Closes the connection to discord.
|
||||||
|
"""
|
||||||
|
if self.is_closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
for shard in self.shards:
|
||||||
|
yield from shard.ws.close()
|
||||||
|
|
||||||
|
yield from self.http.close()
|
||||||
|
self._closed.set()
|
||||||
|
self._is_ready.clear()
|
@ -43,6 +43,7 @@ import datetime
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import weakref
|
import weakref
|
||||||
|
import itertools
|
||||||
|
|
||||||
class ListenerType(enum.Enum):
|
class ListenerType(enum.Enum):
|
||||||
chunk = 0
|
chunk = 0
|
||||||
@ -60,13 +61,12 @@ class ConnectionState:
|
|||||||
self.chunker = chunker
|
self.chunker = chunker
|
||||||
self.syncer = syncer
|
self.syncer = syncer
|
||||||
self.is_bot = None
|
self.is_bot = None
|
||||||
|
self.shard_count = None
|
||||||
self._listeners = []
|
self._listeners = []
|
||||||
self.clear()
|
self.clear()
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.user = None
|
self.user = None
|
||||||
self.sequence = None
|
|
||||||
self.session_id = None
|
|
||||||
self._users = weakref.WeakValueDictionary()
|
self._users = weakref.WeakValueDictionary()
|
||||||
self._calls = {}
|
self._calls = {}
|
||||||
self._emojis = {}
|
self._emojis = {}
|
||||||
@ -355,7 +355,8 @@ class ConnectionState:
|
|||||||
# the reason we're doing this is so it's also removed from the
|
# the reason we're doing this is so it's also removed from the
|
||||||
# private channel by user cache as well
|
# private channel by user cache as well
|
||||||
channel = self._get_private_channel(channel_id)
|
channel = self._get_private_channel(channel_id)
|
||||||
self._remove_private_channel(channel)
|
if channel is not None:
|
||||||
|
self._remove_private_channel(channel)
|
||||||
|
|
||||||
def parse_channel_update(self, data):
|
def parse_channel_update(self, data):
|
||||||
channel_type = try_enum(ChannelType, data.get('type'))
|
channel_type = try_enum(ChannelType, data.get('type'))
|
||||||
@ -701,3 +702,76 @@ class ConnectionState:
|
|||||||
listener = Listener(ListenerType.chunk, future, lambda s: s.id == guild_id)
|
listener = Listener(ListenerType.chunk, future, lambda s: s.id == guild_id)
|
||||||
self._listeners.append(listener)
|
self._listeners.append(listener)
|
||||||
return future
|
return future
|
||||||
|
|
||||||
|
class AutoShardedConnectionState(ConnectionState):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[])
|
||||||
|
self._ready_task = None
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def _delay_ready(self):
|
||||||
|
launch = self._ready_state.launch
|
||||||
|
while not launch.is_set():
|
||||||
|
# this snippet of code is basically waiting 2 seconds
|
||||||
|
# until the last GUILD_CREATE was sent
|
||||||
|
launch.set()
|
||||||
|
yield from asyncio.sleep(2.0 * self.shard_count, loop=self.loop)
|
||||||
|
|
||||||
|
guilds = sorted(self._ready_state.guilds, key=lambda g: g.shard_id)
|
||||||
|
|
||||||
|
# we only want to request ~75 guilds per chunk request.
|
||||||
|
# we also want to split the chunks per shard_id
|
||||||
|
for shard_id, sub_guilds in itertools.groupby(guilds, key=lambda g: g.shard_id):
|
||||||
|
sub_guilds = list(sub_guilds)
|
||||||
|
|
||||||
|
# split chunks by shard ID
|
||||||
|
chunks = []
|
||||||
|
for guild in sub_guilds:
|
||||||
|
chunks.extend(self.chunks_needed(guild))
|
||||||
|
|
||||||
|
splits = [sub_guilds[i:i + 75] for i in range(0, len(sub_guilds), 75)]
|
||||||
|
for split in splits:
|
||||||
|
yield from self.chunker(split, shard_id=shard_id)
|
||||||
|
|
||||||
|
# wait for the chunks
|
||||||
|
if chunks:
|
||||||
|
try:
|
||||||
|
yield from asyncio.wait(chunks, timeout=len(chunks) * 30.0, loop=self.loop)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
log.info('Somehow timed out waiting for chunks for %s shard_id' % shard_id)
|
||||||
|
|
||||||
|
self.dispatch('shard_ready', shard_id)
|
||||||
|
|
||||||
|
# sleep a second for every shard ID.
|
||||||
|
# yield from asyncio.sleep(1.0, loop=self.loop)
|
||||||
|
|
||||||
|
# remove the state
|
||||||
|
try:
|
||||||
|
del self._ready_state
|
||||||
|
except AttributeError:
|
||||||
|
pass # already been deleted somehow
|
||||||
|
|
||||||
|
# regular users cannot shard so we won't worry about it here.
|
||||||
|
|
||||||
|
# dispatch the event
|
||||||
|
self.dispatch('ready')
|
||||||
|
|
||||||
|
def parse_ready(self, data):
|
||||||
|
if not hasattr(self, '_ready_state'):
|
||||||
|
self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[])
|
||||||
|
|
||||||
|
self.user = self.store_user(data['user'])
|
||||||
|
|
||||||
|
guilds = self._ready_state.guilds
|
||||||
|
for guild_data in data['guilds']:
|
||||||
|
guild = self._add_guild_from_data(guild_data)
|
||||||
|
if not self.is_bot or guild.large:
|
||||||
|
guilds.append(guild)
|
||||||
|
|
||||||
|
for pm in data.get('private_channels', []):
|
||||||
|
factory, _ = _channel_factory(pm['type'])
|
||||||
|
self._add_private_channel(factory(me=self.user, data=pm, state=self))
|
||||||
|
|
||||||
|
if self._ready_task is None:
|
||||||
|
self._ready_task = compat.create_task(self._delay_ready(), loop=self.loop)
|
||||||
|
@ -37,6 +37,9 @@ Client
|
|||||||
.. autoclass:: Client
|
.. autoclass:: Client
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
.. autoclass:: AutoShardedClient
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
Voice
|
Voice
|
||||||
-----
|
-----
|
||||||
|
Loading…
x
Reference in New Issue
Block a user