Support for v5 Gateway.
This commit is contained in:
parent
c3c9db7777
commit
8b3617111a
@ -108,7 +108,7 @@ class Channel(Hashable):
|
||||
|
||||
self._permission_overwrites = []
|
||||
everyone_index = 0
|
||||
everyone_id = self.server.default_role.id
|
||||
everyone_id = self.server.id
|
||||
|
||||
for index, overridden in enumerate(kwargs.get('permission_overwrites', [])):
|
||||
overridden_id = overridden['id']
|
||||
|
@ -138,7 +138,8 @@ class Client:
|
||||
if max_messages is None or max_messages < 100:
|
||||
max_messages = 5000
|
||||
|
||||
self.connection = ConnectionState(self.dispatch, self.request_offline_members, max_messages, loop=self.loop)
|
||||
self.connection = ConnectionState(self.dispatch, self.request_offline_members,
|
||||
self._syncer, max_messages, loop=self.loop)
|
||||
|
||||
connector = options.pop('connector', None)
|
||||
self.http = HTTPClient(connector, loop=self.loop)
|
||||
@ -149,6 +150,10 @@ class Client:
|
||||
|
||||
# internals
|
||||
|
||||
@asyncio.coroutine
|
||||
def _syncer(self, guilds):
|
||||
yield from self.ws.request_sync(guilds)
|
||||
|
||||
def _get_cache_filename(self, email):
|
||||
filename = hashlib.md5(email.encode('utf-8')).hexdigest()
|
||||
return os.path.join(tempfile.gettempdir(), 'discord_py', filename)
|
||||
@ -295,12 +300,16 @@ class Client:
|
||||
@asyncio.coroutine
|
||||
def _login_1(self, token, **kwargs):
|
||||
log.info('logging in using static token')
|
||||
yield from self.http.static_login(token, bot=kwargs.pop('bot', True))
|
||||
is_bot = kwargs.pop('bot', True)
|
||||
yield from self.http.static_login(token, bot=is_bot)
|
||||
self.connection.is_bot = is_bot
|
||||
self._is_logged_in.set()
|
||||
|
||||
@asyncio.coroutine
|
||||
def _login_2(self, email, password, **kwargs):
|
||||
# attempt to read the token from cache
|
||||
self.connection.is_bot = False
|
||||
|
||||
if self.cache_auth:
|
||||
token = self._get_cache_token(email, password)
|
||||
try:
|
||||
|
@ -127,6 +127,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
INVALIDATE_SESSION
|
||||
Receive only. Tells the client to invalidate the session and IDENTIFY
|
||||
again.
|
||||
HELLO
|
||||
Receive only. Tells the client the heartbeat interval.
|
||||
HEARTBEAT_ACK
|
||||
Receive only. Confirms receiving of a heartbeat. Not having it implies
|
||||
a connection issue.
|
||||
GUILD_SYNC
|
||||
Send only. Requests a guild sync.
|
||||
gateway
|
||||
The gateway we are currently connected to.
|
||||
token
|
||||
@ -143,6 +150,9 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
RECONNECT = 7
|
||||
REQUEST_MEMBERS = 8
|
||||
INVALIDATE_SESSION = 9
|
||||
HELLO = 10
|
||||
HEARTBEAT_ACK = 11
|
||||
GUILD_SYNC = 12
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, max_size=None, **kwargs)
|
||||
@ -172,6 +182,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
client.connection._update_references(ws)
|
||||
|
||||
log.info('Created websocket connected to {}'.format(gateway))
|
||||
|
||||
# poll the event for OP HELLO
|
||||
yield from ws.poll_event()
|
||||
|
||||
if not resume:
|
||||
yield from ws.identify()
|
||||
log.info('sent the identify payload to create the websocket')
|
||||
@ -232,6 +246,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
'v': 3
|
||||
}
|
||||
}
|
||||
|
||||
if not self._connection.is_bot:
|
||||
payload['d']['synced_guilds'] = []
|
||||
|
||||
yield from self.send_as_json(payload)
|
||||
|
||||
@asyncio.coroutine
|
||||
@ -277,6 +295,12 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
yield from self.close()
|
||||
raise ReconnectWebSocket()
|
||||
|
||||
if op == self.HELLO:
|
||||
interval = data['heartbeat_interval'] / 1000.0
|
||||
self._keep_alive = KeepAliveHandler(ws=self, interval=interval)
|
||||
self._keep_alive.start()
|
||||
return
|
||||
|
||||
if op == self.INVALIDATE_SESSION:
|
||||
state.sequence = None
|
||||
state.session_id = None
|
||||
@ -298,11 +322,6 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
state.sequence = msg['s']
|
||||
state.session_id = data['session_id']
|
||||
|
||||
if is_ready or event == 'RESUMED':
|
||||
interval = data['heartbeat_interval'] / 1000.0
|
||||
self._keep_alive = KeepAliveHandler(ws=self, interval=interval)
|
||||
self._keep_alive.start()
|
||||
|
||||
parser = 'parse_' + event.lower()
|
||||
|
||||
try:
|
||||
@ -400,6 +419,14 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
status = Status.idle if idle_since else Status.online
|
||||
me.status = status
|
||||
|
||||
@asyncio.coroutine
|
||||
def request_sync(self, guild_ids):
|
||||
payload = {
|
||||
'op': self.GUILD_SYNC,
|
||||
'd': list(guild_ids)
|
||||
}
|
||||
yield from self.send_as_json(payload)
|
||||
|
||||
@asyncio.coroutine
|
||||
def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
|
||||
payload = {
|
||||
|
@ -497,4 +497,4 @@ class HTTPClient:
|
||||
data = yield from self.get(self.GATEWAY, bucket=_func_())
|
||||
except HTTPException as e:
|
||||
raise GatewayNotFound() from e
|
||||
return data.get('url') + '?encoding=json&v=4'
|
||||
return data.get('url') + '?encoding=json&v=5'
|
||||
|
@ -169,7 +169,6 @@ class Server(Hashable):
|
||||
self._member_count = member_count
|
||||
|
||||
self.name = guild.get('name')
|
||||
self.large = guild.get('large', None if member_count is None else self._member_count > 250)
|
||||
self.region = guild.get('region')
|
||||
try:
|
||||
self.region = ServerRegion(self.region)
|
||||
@ -181,24 +180,36 @@ class Server(Hashable):
|
||||
self.unavailable = guild.get('unavailable', False)
|
||||
self.id = guild['id']
|
||||
self.roles = [Role(server=self, **r) for r in guild.get('roles', [])]
|
||||
|
||||
for data in guild.get('members', []):
|
||||
roles = [self.default_role]
|
||||
for role_id in data['roles']:
|
||||
role = utils.find(lambda r: r.id == role_id, self.roles)
|
||||
if role is not None:
|
||||
roles.append(role)
|
||||
|
||||
data['roles'] = roles
|
||||
member = Member(**data)
|
||||
member.server = self
|
||||
self._add_member(member)
|
||||
self._sync(guild)
|
||||
self.large = None if member_count is None else self._member_count > 250
|
||||
|
||||
if 'owner_id' in guild:
|
||||
self.owner_id = guild['owner_id']
|
||||
self.owner = self.get_member(self.owner_id)
|
||||
|
||||
for presence in guild.get('presences', []):
|
||||
afk_id = guild.get('afk_channel_id')
|
||||
self.afk_channel = self.get_channel(afk_id)
|
||||
|
||||
for obj in guild.get('voice_states', []):
|
||||
self._update_voice_state(obj)
|
||||
|
||||
def _sync(self, data):
|
||||
if 'large' in data:
|
||||
self.large = data['large']
|
||||
|
||||
for mdata in data.get('members', []):
|
||||
roles = [self.default_role]
|
||||
for role_id in mdata['roles']:
|
||||
role = utils.find(lambda r: r.id == role_id, self.roles)
|
||||
if role is not None:
|
||||
roles.append(role)
|
||||
|
||||
mdata['roles'] = roles
|
||||
member = Member(**mdata)
|
||||
member.server = self
|
||||
self._add_member(member)
|
||||
|
||||
for presence in data.get('presences', []):
|
||||
user_id = presence['user']['id']
|
||||
member = self.get_member(user_id)
|
||||
if member is not None:
|
||||
@ -210,17 +221,12 @@ class Server(Hashable):
|
||||
game = presence.get('game', {})
|
||||
member.game = Game(**game) if game else None
|
||||
|
||||
if 'channels' in guild:
|
||||
channels = guild['channels']
|
||||
if 'channels' in data:
|
||||
channels = data['channels']
|
||||
for c in channels:
|
||||
channel = Channel(server=self, **c)
|
||||
self._add_channel(channel)
|
||||
|
||||
afk_id = guild.get('afk_channel_id')
|
||||
self.afk_channel = self.get_channel(afk_id)
|
||||
|
||||
for obj in guild.get('voice_states', []):
|
||||
self._update_voice_state(obj)
|
||||
|
||||
@utils.cached_slot_property('_default_role')
|
||||
def default_role(self):
|
||||
|
@ -49,11 +49,13 @@ log = logging.getLogger(__name__)
|
||||
ReadyState = namedtuple('ReadyState', ('launch', 'servers'))
|
||||
|
||||
class ConnectionState:
|
||||
def __init__(self, dispatch, chunker, max_messages, *, loop):
|
||||
def __init__(self, dispatch, chunker, syncer, max_messages, *, loop):
|
||||
self.loop = loop
|
||||
self.max_messages = max_messages
|
||||
self.dispatch = dispatch
|
||||
self.chunker = chunker
|
||||
self.syncer = syncer
|
||||
self.is_bot = None
|
||||
self._listeners = []
|
||||
self.clear()
|
||||
|
||||
@ -165,8 +167,9 @@ class ConnectionState:
|
||||
launch.set()
|
||||
yield from asyncio.sleep(2)
|
||||
|
||||
# get all the chunks
|
||||
servers = self._ready_state.servers
|
||||
|
||||
# get all the chunks
|
||||
chunks = []
|
||||
for server in servers:
|
||||
chunks.extend(self.chunks_needed(server))
|
||||
@ -194,9 +197,12 @@ class ConnectionState:
|
||||
servers = self._ready_state.servers
|
||||
for guild in guilds:
|
||||
server = self._add_server_from_data(guild)
|
||||
if server.large:
|
||||
if server.large or not self.is_bot:
|
||||
servers.append(server)
|
||||
|
||||
if not self.is_bot:
|
||||
compat.create_task(self.syncer([s.id for s in self.servers]), loop=self.loop)
|
||||
|
||||
for pm in data.get('private_channels'):
|
||||
self._add_private_channel(PrivateChannel(id=pm['id'],
|
||||
user=User(**pm['recipient'])))
|
||||
@ -427,6 +433,10 @@ class ConnectionState:
|
||||
else:
|
||||
self.dispatch('server_join', server)
|
||||
|
||||
def parse_guild_sync(self, data):
|
||||
server = self._get_server(data.get('id'))
|
||||
server._sync(data)
|
||||
|
||||
def parse_guild_update(self, data):
|
||||
server = self._get_server(data.get('id'))
|
||||
if server is not None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user