Move socket and connection state out of Client
Move the socket message handling and Discord connection state tracking out of the Client class. The WebSocket class handles the ws4py based WebSocket to Discord, maintains the keepalive and dispatches socket_<events> based on activity. The ConnectionSTate class maintains the state associated with the WebSocket connection with Discord. In a reconnect and switch gateway scenario this state can be kept for a faster and less disruptive recovery.
This commit is contained in:
parent
c47e31c82e
commit
5e671a0d0d
@ -37,7 +37,7 @@ import requests
|
||||
import json, re, time, copy
|
||||
from collections import deque
|
||||
import threading
|
||||
from ws4py.client.threadedclient import WebSocketClient
|
||||
from ws4py.client import WebSocketBaseClient
|
||||
import sys
|
||||
import logging
|
||||
|
||||
@ -71,67 +71,63 @@ class KeepAliveHandler(threading.Thread):
|
||||
log.debug(msg.format(payload['d']))
|
||||
self.socket.send(json.dumps(payload))
|
||||
|
||||
class Client(object):
|
||||
"""Represents a client connection that connects to Discord.
|
||||
This class is used to interact with the Discord WebSocket and API.
|
||||
class WebSocket(WebSocketBaseClient):
|
||||
def __init__(self, dispatch, url):
|
||||
WebSocketBaseClient.__init__(self, url,
|
||||
protocols=['http-only', 'chat'])
|
||||
self.dispatch = dispatch
|
||||
self.keep_alive = None
|
||||
|
||||
A number of options can be passed to the :class:`Client` via keyword arguments.
|
||||
def opened(self):
|
||||
log.info('Opened at {}'.format(int(time.time())))
|
||||
self.dispatch('socket_opened')
|
||||
|
||||
:param int max_length: The maximum number of messages to store in :attr:`messages`. Defaults to 5000.
|
||||
def closed(self, code, reason=None):
|
||||
if self.keep_alive is not None:
|
||||
self.keep_alive.stop.set()
|
||||
log.info('Closed with {} ("{}") at {}'.format(code, reason,
|
||||
int(time.time())))
|
||||
self.dispatch('socket_closed')
|
||||
|
||||
Instance attributes:
|
||||
def handshake_ok(self):
|
||||
pass
|
||||
|
||||
.. attribute:: user
|
||||
def received_message(self, msg):
|
||||
response = json.loads(str(msg))
|
||||
log.debug('WebSocket Event: {}'.format(response))
|
||||
if response.get('op') != 0:
|
||||
log.info("Unhandled op {}".format(response.get('op')))
|
||||
return # What about op 7?
|
||||
|
||||
A :class:`User` that represents the connected client. None if not logged in.
|
||||
.. attribute:: servers
|
||||
self.dispatch('socket_response', response)
|
||||
event = response.get('t')
|
||||
data = response.get('d')
|
||||
|
||||
A list of :class:`Server` that the connected client has available.
|
||||
.. attribute:: private_channels
|
||||
if event == 'READY':
|
||||
interval = data['heartbeat_interval'] / 1000.0
|
||||
self.keep_alive = KeepAliveHandler(interval, self)
|
||||
self.keep_alive.start()
|
||||
|
||||
A list of :class:`PrivateChannel` that the connected client is participating on.
|
||||
.. attribute:: messages
|
||||
|
||||
A deque_ of :class:`Message` that the client has received from all servers and private messages.
|
||||
.. attribute:: email
|
||||
if event in ('READY', 'MESSAGE_CREATE', 'MESSAGE_DELETE',
|
||||
'MESSAGE_UPDATE', 'PRESENCE_UPDATE', 'USER_UPDATE',
|
||||
'CHANNEL_DELETE', 'CHANNEL_UPDATE', 'CHANNEL_CREATE',
|
||||
'GUILD_MEMBER_ADD', 'GUILD_MEMBER_REMOVE',
|
||||
'GUILD_MEMBER_UPDATE', 'GUILD_CREATE', 'GUILD_DELETE'):
|
||||
self.dispatch('socket_update', event, data)
|
||||
|
||||
The email used to login. This is only set if login is successful, otherwise it's None.
|
||||
else:
|
||||
log.info("Unhandled event {}".format(event))
|
||||
|
||||
.. _deque: https://docs.python.org/3.4/library/collections.html#collections.deque
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._is_logged_in = False
|
||||
class ConnectionState(object):
|
||||
def __init__(self, dispatch, **kwargs):
|
||||
self.dispatch = dispatch
|
||||
self.user = None
|
||||
self.email = None
|
||||
self.servers = []
|
||||
self.private_channels = []
|
||||
self.token = ''
|
||||
self.messages = deque([], maxlen=kwargs.get('max_length', 5000))
|
||||
self.events = {
|
||||
'on_ready': _null_event,
|
||||
'on_disconnect': _null_event,
|
||||
'on_error': _null_event,
|
||||
'on_response': _null_event,
|
||||
'on_message': _null_event,
|
||||
'on_message_delete': _null_event,
|
||||
'on_message_edit': _null_event,
|
||||
'on_status': _null_event,
|
||||
'on_channel_delete': _null_event,
|
||||
'on_channel_create': _null_event,
|
||||
'on_channel_update': _null_event,
|
||||
'on_member_join': _null_event,
|
||||
'on_member_remove': _null_event,
|
||||
'on_member_update': _null_event,
|
||||
'on_server_create': _null_event,
|
||||
'on_server_delete': _null_event,
|
||||
}
|
||||
|
||||
# the actual headers for the request...
|
||||
# we only override 'authorization' since the rest could use the defaults.
|
||||
self.headers = {
|
||||
'authorization': self.token,
|
||||
}
|
||||
|
||||
def _get_message(self, msg_id):
|
||||
return utils.find(lambda m: m.id == msg_id, self.messages)
|
||||
@ -169,24 +165,225 @@ class Client(object):
|
||||
for member in server.members:
|
||||
member.server = server
|
||||
|
||||
channels = [Channel(server=server, **channel) for channel in guild['channels']]
|
||||
channels = [Channel(server=server, **channel)
|
||||
for channel in guild['channels']]
|
||||
server.channels = channels
|
||||
self.servers.append(server)
|
||||
|
||||
def handle_ready(self, data):
|
||||
self.user = User(**data['user'])
|
||||
guilds = data.get('guilds')
|
||||
|
||||
for guild in guilds:
|
||||
self._add_server(guild)
|
||||
|
||||
for pm in data.get('private_channels'):
|
||||
self.private_channels.append(PrivateChannel(id=pm['id'],
|
||||
user=User(**pm['recipient'])))
|
||||
|
||||
# we're all ready
|
||||
self.dispatch('ready')
|
||||
|
||||
def handle_message_create(self, data):
|
||||
channel = self.get_channel(data.get('channel_id'))
|
||||
message = Message(channel=channel, **data)
|
||||
self.dispatch('message', message)
|
||||
self.messages.append(message)
|
||||
|
||||
def handle_message_delete(self, data):
|
||||
channel = self.get_channel(data.get('channel_id'))
|
||||
message_id = data.get('id')
|
||||
found = self._get_message(message_id)
|
||||
if found is not None:
|
||||
self.dispatch('message_delete', found)
|
||||
self.messages.remove(found)
|
||||
|
||||
def handle_message_update(self, data):
|
||||
older_message = self._get_message(data.get('id'))
|
||||
if older_message is not None:
|
||||
# create a copy of the new message
|
||||
message = copy.deepcopy(older_message)
|
||||
# update the new update
|
||||
for attr in data:
|
||||
if attr == 'channel_id' or attr == 'author':
|
||||
continue
|
||||
value = data[attr]
|
||||
if 'time' in attr:
|
||||
setattr(message, attr, utils.parse_time(value))
|
||||
else:
|
||||
setattr(message, attr, value)
|
||||
self.dispatch('message_edit', older_message, message)
|
||||
# update the older message
|
||||
older_message = message
|
||||
|
||||
def handle_presence_update(self, data):
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
if server is not None:
|
||||
status = data.get('status')
|
||||
user = data['user']
|
||||
member_id = user['id']
|
||||
member = utils.find(lambda m: m.id == member_id, server.members)
|
||||
if member is not None:
|
||||
member.status = data.get('status')
|
||||
member.game_id = data.get('game_id')
|
||||
member.name = user.get('username', member.name)
|
||||
member.avatar = user.get('avatar', member.avatar)
|
||||
|
||||
# call the event now
|
||||
self.dispatch('status', member)
|
||||
self.dispatch('member_update', member)
|
||||
|
||||
def handle_user_update(self, data):
|
||||
self.user = User(**data)
|
||||
|
||||
def handle_channel_delete(self, data):
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
if server is not None:
|
||||
channel_id = data.get('id')
|
||||
channel = utils.find(lambda c: c.id == channel_id, server.channels)
|
||||
server.channels.remove(channel)
|
||||
self.dispatch('channel_delete', channel)
|
||||
|
||||
def handle_channel_update(self, data):
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
if server is not None:
|
||||
channel_id = data.get('id')
|
||||
channel = utils.find(lambda c: c.id == channel_id, server.channels)
|
||||
channel.update(server=server, **data)
|
||||
self.dispatch('channel_update', channel)
|
||||
|
||||
def handle_channel_create(self, data):
|
||||
is_private = data.get('is_private', False)
|
||||
channel = None
|
||||
if is_private:
|
||||
recipient = User(**data.get('recipient'))
|
||||
pm_id = data.get('id')
|
||||
channel = PrivateChannel(id=pm_id, user=recipient)
|
||||
self.private_channels.append(channel)
|
||||
else:
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
if server is not None:
|
||||
channel = Channel(server=server, **data)
|
||||
server.channels.append(channel)
|
||||
|
||||
self.dispatch('channel_create', channel)
|
||||
|
||||
def handle_guild_member_add(self, data):
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
member = Member(server=server, deaf=False, mute=False, **data)
|
||||
server.members.append(member)
|
||||
self.dispatch('member_join', member)
|
||||
|
||||
def handle_guild_member_remove(self, data):
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
user_id = data['user']['id']
|
||||
member = utils.find(lambda m: m.id == user_id, server.members)
|
||||
server.members.remove(member)
|
||||
self.dispatch('member_remove', member)
|
||||
|
||||
def handle_guild_member_update(self, data):
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
user_id = data['user']['id']
|
||||
member = utils.find(lambda m: m.id == user_id, server.members)
|
||||
if member is not None:
|
||||
user = data['user']
|
||||
member.name = user['username']
|
||||
member.discriminator = user['discriminator']
|
||||
member.avatar = user['avatar']
|
||||
member.roles = []
|
||||
# update the roles
|
||||
for role in server.roles:
|
||||
if role.id in data['roles']:
|
||||
member.roles.append(role)
|
||||
|
||||
self.dispatch('member_update', member)
|
||||
|
||||
def handle_guild_create(self, data):
|
||||
self._add_server(data)
|
||||
self.dispatch('server_create', self.servers[-1])
|
||||
|
||||
def handle_guild_delete(self, data):
|
||||
server = self._get_server(data.get('id'))
|
||||
self.servers.remove(server)
|
||||
self.dispatch('server_delete', server)
|
||||
|
||||
def get_channel(self, id):
|
||||
if id is None:
|
||||
return None
|
||||
|
||||
for server in self.servers:
|
||||
for channel in server.channels:
|
||||
if channel.id == id:
|
||||
return channel
|
||||
|
||||
for pm in self.private_channels:
|
||||
if pm.id == id:
|
||||
return pm
|
||||
|
||||
|
||||
class Client(object):
|
||||
"""Represents a client connection that connects to Discord.
|
||||
This class is used to interact with the Discord WebSocket and API.
|
||||
|
||||
A number of options can be passed to the :class:`Client` via keyword arguments.
|
||||
|
||||
:param int max_length: The maximum number of messages to store in :attr:`messages`. Defaults to 5000.
|
||||
|
||||
Instance attributes:
|
||||
|
||||
.. attribute:: user
|
||||
|
||||
A :class:`User` that represents the connected client. None if not logged in.
|
||||
.. attribute:: servers
|
||||
|
||||
A list of :class:`Server` that the connected client has available.
|
||||
.. attribute:: private_channels
|
||||
|
||||
A list of :class:`PrivateChannel` that the connected client is participating on.
|
||||
.. attribute:: messages
|
||||
|
||||
A deque_ of :class:`Message` that the client has received from all servers and private messages.
|
||||
.. attribute:: email
|
||||
|
||||
The email used to login. This is only set if login is successful, otherwise it's None.
|
||||
|
||||
.. _deque: https://docs.python.org/3.4/library/collections.html#collections.deque
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._is_logged_in = False
|
||||
self.connection = ConnectionState(self.dispatch, **kwargs)
|
||||
self.token = ''
|
||||
self.events = {
|
||||
'on_ready': _null_event,
|
||||
'on_disconnect': _null_event,
|
||||
'on_error': _null_event,
|
||||
'on_response': _null_event,
|
||||
'on_message': _null_event,
|
||||
'on_message_delete': _null_event,
|
||||
'on_message_edit': _null_event,
|
||||
'on_status': _null_event,
|
||||
'on_channel_delete': _null_event,
|
||||
'on_channel_create': _null_event,
|
||||
'on_channel_update': _null_event,
|
||||
'on_member_join': _null_event,
|
||||
'on_member_remove': _null_event,
|
||||
'on_member_update': _null_event,
|
||||
'on_server_create': _null_event,
|
||||
'on_server_delete': _null_event,
|
||||
}
|
||||
|
||||
# the actual headers for the request...
|
||||
# we only override 'authorization' since the rest could use the defaults.
|
||||
self.headers = {
|
||||
'authorization': self.token,
|
||||
}
|
||||
|
||||
def _create_websocket(self, url, reconnect=False):
|
||||
if url is None:
|
||||
raise GatewayNotFound()
|
||||
log.info('websocket gateway found')
|
||||
self.ws = WebSocketClient(url, protocols=['http-only', 'chat'])
|
||||
|
||||
# this is kind of hacky, but it's to avoid deadlocks.
|
||||
# i.e. python does not allow me to have the current thread running if it's self
|
||||
# it throws a 'cannot join current thread' RuntimeError
|
||||
# So instead of doing a basic inheritance scheme, we're overriding the member functions.
|
||||
|
||||
self.ws.opened = self._opened
|
||||
self.ws.closed = self._closed
|
||||
self.ws.received_message = self._received_message
|
||||
self.ws = WebSocket(self.dispatch, url)
|
||||
self.ws.connect()
|
||||
log.info('websocket has connected')
|
||||
|
||||
@ -220,6 +417,23 @@ class Client(object):
|
||||
msg = 'Caught exception in {} with args (*{}, **{})'
|
||||
log.exception(msg.format(event_method, args, kwargs))
|
||||
|
||||
# Compatibility shim
|
||||
def __getattr__(self, name):
|
||||
if name in ('user', 'email', 'servers', 'private_channels', 'messages',
|
||||
'get_channel'):
|
||||
return getattr(self.connection, name)
|
||||
else:
|
||||
msg = "'{}' object has no attribute '{}'"
|
||||
raise AttributeError(msg.format(self.__class__, name))
|
||||
|
||||
# Compatibility shim
|
||||
def __setattr__(self, name, value):
|
||||
if name in ('user', 'email', 'servers', 'private_channels',
|
||||
'messages'):
|
||||
return setattr(self.connection, name, value)
|
||||
else:
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
def dispatch(self, event, *args, **kwargs):
|
||||
log.debug("Dispatching event {}".format(event))
|
||||
handle_method = '_'.join(('handle', event))
|
||||
@ -242,156 +456,14 @@ class Client(object):
|
||||
log.error('an error ({}) occurred in event {} so on_error is invoked instead'.format(type(e).__name__, event_name))
|
||||
self.events['on_error'](event_name, *sys.exc_info())
|
||||
|
||||
def _received_message(self, msg):
|
||||
response = json.loads(str(msg))
|
||||
log.debug('WebSocket Event: {}'.format(response))
|
||||
if response.get('op') != 0:
|
||||
return
|
||||
|
||||
self.dispatch('response', response)
|
||||
event = response.get('t')
|
||||
data = response.get('d')
|
||||
|
||||
if event == 'READY':
|
||||
self.user = User(**data['user'])
|
||||
guilds = data.get('guilds')
|
||||
|
||||
for guild in guilds:
|
||||
self._add_server(guild)
|
||||
|
||||
for pm in data.get('private_channels'):
|
||||
self.private_channels.append(PrivateChannel(id=pm['id'], user=User(**pm['recipient'])))
|
||||
|
||||
# set the keep alive interval..
|
||||
interval = data.get('heartbeat_interval') / 1000.0
|
||||
self.keep_alive = KeepAliveHandler(interval, self.ws)
|
||||
self.keep_alive.start()
|
||||
|
||||
# we're all ready
|
||||
self.dispatch('ready')
|
||||
elif event == 'MESSAGE_CREATE':
|
||||
channel = self.get_channel(data.get('channel_id'))
|
||||
message = Message(channel=channel, **data)
|
||||
self.dispatch('message', message)
|
||||
self.messages.append(message)
|
||||
elif event == 'MESSAGE_DELETE':
|
||||
channel = self.get_channel(data.get('channel_id'))
|
||||
message_id = data.get('id')
|
||||
found = self._get_message(message_id)
|
||||
if found is not None:
|
||||
self.dispatch('message_delete', found)
|
||||
self.messages.remove(found)
|
||||
elif event == 'MESSAGE_UPDATE':
|
||||
older_message = self._get_message(data.get('id'))
|
||||
if older_message is not None:
|
||||
# create a copy of the new message
|
||||
message = copy.deepcopy(older_message)
|
||||
# update the new update
|
||||
for attr in data:
|
||||
if attr == 'channel_id' or attr == 'author':
|
||||
continue
|
||||
value = data[attr]
|
||||
if 'time' in attr:
|
||||
setattr(message, attr, utils.parse_time(value))
|
||||
else:
|
||||
setattr(message, attr, value)
|
||||
self.dispatch('message_edit', older_message, message)
|
||||
# update the older message
|
||||
older_message = message
|
||||
|
||||
elif event == 'PRESENCE_UPDATE':
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
if server is not None:
|
||||
status = data.get('status')
|
||||
user = data['user']
|
||||
member_id = user['id']
|
||||
member = utils.find(lambda m: m.id == member_id, server.members)
|
||||
if member is not None:
|
||||
member.status = data.get('status')
|
||||
member.game_id = data.get('game_id')
|
||||
member.name = user.get('username', member.name)
|
||||
member.avatar = user.get('avatar', member.avatar)
|
||||
|
||||
# call the event now
|
||||
self.dispatch('status', member)
|
||||
self.dispatch('member_update', member)
|
||||
elif event == 'USER_UPDATE':
|
||||
self.user = User(**data)
|
||||
elif event == 'CHANNEL_DELETE':
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
if server is not None:
|
||||
channel_id = data.get('id')
|
||||
channel = utils.find(lambda c: c.id == channel_id, server.channels)
|
||||
server.channels.remove(channel)
|
||||
self.dispatch('channel_delete', channel)
|
||||
elif event == 'CHANNEL_UPDATE':
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
if server is not None:
|
||||
channel_id = data.get('id')
|
||||
channel = utils.find(lambda c: c.id == channel_id, server.channels)
|
||||
channel.update(server=server, **data)
|
||||
self.dispatch('channel_update', channel)
|
||||
elif event == 'CHANNEL_CREATE':
|
||||
is_private = data.get('is_private', False)
|
||||
channel = None
|
||||
if is_private:
|
||||
recipient = User(**data.get('recipient'))
|
||||
pm_id = data.get('id')
|
||||
channel = PrivateChannel(id=pm_id, user=recipient)
|
||||
self.private_channels.append(channel)
|
||||
else:
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
if server is not None:
|
||||
channel = Channel(server=server, **data)
|
||||
server.channels.append(channel)
|
||||
|
||||
self.dispatch('channel_create', channel)
|
||||
elif event == 'GUILD_MEMBER_ADD':
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
member = Member(server=server, deaf=False, mute=False, **data)
|
||||
server.members.append(member)
|
||||
self.dispatch('member_join', member)
|
||||
elif event == 'GUILD_MEMBER_REMOVE':
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
user_id = data['user']['id']
|
||||
member = utils.find(lambda m: m.id == user_id, server.members)
|
||||
server.members.remove(member)
|
||||
self.dispatch('member_remove', member)
|
||||
elif event == 'GUILD_MEMBER_UPDATE':
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
user_id = data['user']['id']
|
||||
member = utils.find(lambda m: m.id == user_id, server.members)
|
||||
if member is not None:
|
||||
user = data['user']
|
||||
member.name = user['username']
|
||||
member.discriminator = user['discriminator']
|
||||
member.avatar = user['avatar']
|
||||
member.roles = []
|
||||
# update the roles
|
||||
for role in server.roles:
|
||||
if role.id in data['roles']:
|
||||
member.roles.append(role)
|
||||
|
||||
self.dispatch('member_update', member)
|
||||
elif event == 'GUILD_CREATE':
|
||||
self._add_server(data)
|
||||
self.dispatch('server_create', self.servers[-1])
|
||||
elif event == 'GUILD_DELETE':
|
||||
server = self._get_server(data.get('id'))
|
||||
self.servers.remove(server)
|
||||
self.dispatch('server_delete', server)
|
||||
|
||||
def _opened(self):
|
||||
log.info('Opened at {}'.format(int(time.time())))
|
||||
|
||||
def _closed(self, code, reason=None):
|
||||
log.info('Closed with {} ("{}") at {}'.format(code, reason, int(time.time())))
|
||||
self.dispatch('disconnect')
|
||||
def handle_socket_update(self, event, data):
|
||||
method = '_'.join(('handle', event.lower()))
|
||||
getattr(self.connection, method)(data)
|
||||
|
||||
def run(self):
|
||||
"""Runs the client and allows it to receive messages and events."""
|
||||
log.info('Client is being run')
|
||||
self.ws.run_forever()
|
||||
self.ws.run()
|
||||
|
||||
@property
|
||||
def is_logged_in(self):
|
||||
@ -399,18 +471,10 @@ class Client(object):
|
||||
return self._is_logged_in
|
||||
|
||||
def get_channel(self, id):
|
||||
"""Returns a :class:`Channel` or :class:`PrivateChannel` with the following ID. If not found, returns None."""
|
||||
if id is None:
|
||||
return None
|
||||
|
||||
for server in self.servers:
|
||||
for channel in server.channels:
|
||||
if channel.id == id:
|
||||
return channel
|
||||
|
||||
for pm in self.private_channels:
|
||||
if pm.id == id:
|
||||
return pm
|
||||
"""Returns a :class:`Channel` or :class:`PrivateChannel` with the
|
||||
following ID. If not found, returns None.
|
||||
"""
|
||||
return self.connection.get_channel(id)
|
||||
|
||||
def start_private_message(self, user):
|
||||
"""Starts a private message with the user. This allows you to :meth:`send_message` to it.
|
||||
@ -578,7 +642,6 @@ class Client(object):
|
||||
response = requests.post(endpoints.LOGOUT)
|
||||
self.ws.close()
|
||||
self._is_logged_in = False
|
||||
self.keep_alive.stop.set()
|
||||
log.debug(request_logging_format.format(name='logout', response=response))
|
||||
|
||||
def logs_from(self, channel, limit=500):
|
||||
|
Loading…
x
Reference in New Issue
Block a user