Move socket and connection state out of Client

Move the socket message handling and Discord connection state tracking
out of the Client class.  The WebSocket class handles the ws4py based
WebSocket to Discord, maintains the keepalive and dispatches
socket_<events> based on activity.  The ConnectionSTate class maintains
the state associated with the WebSocket connection with Discord.  In a
reconnect and switch gateway scenario this state can be kept for a
faster and less disruptive recovery.
This commit is contained in:
Hornwitser 2015-09-29 08:47:37 +02:00
parent c47e31c82e
commit 5e671a0d0d

View File

@ -37,7 +37,7 @@ import requests
import json, re, time, copy import json, re, time, copy
from collections import deque from collections import deque
import threading import threading
from ws4py.client.threadedclient import WebSocketClient from ws4py.client import WebSocketBaseClient
import sys import sys
import logging import logging
@ -71,67 +71,63 @@ class KeepAliveHandler(threading.Thread):
log.debug(msg.format(payload['d'])) log.debug(msg.format(payload['d']))
self.socket.send(json.dumps(payload)) self.socket.send(json.dumps(payload))
class Client(object): class WebSocket(WebSocketBaseClient):
"""Represents a client connection that connects to Discord. def __init__(self, dispatch, url):
This class is used to interact with the Discord WebSocket and API. WebSocketBaseClient.__init__(self, url,
protocols=['http-only', 'chat'])
self.dispatch = dispatch
self.keep_alive = None
A number of options can be passed to the :class:`Client` via keyword arguments. def opened(self):
log.info('Opened at {}'.format(int(time.time())))
self.dispatch('socket_opened')
:param int max_length: The maximum number of messages to store in :attr:`messages`. Defaults to 5000. def closed(self, code, reason=None):
if self.keep_alive is not None:
self.keep_alive.stop.set()
log.info('Closed with {} ("{}") at {}'.format(code, reason,
int(time.time())))
self.dispatch('socket_closed')
Instance attributes: def handshake_ok(self):
pass
.. attribute:: user def received_message(self, msg):
response = json.loads(str(msg))
log.debug('WebSocket Event: {}'.format(response))
if response.get('op') != 0:
log.info("Unhandled op {}".format(response.get('op')))
return # What about op 7?
A :class:`User` that represents the connected client. None if not logged in. self.dispatch('socket_response', response)
.. attribute:: servers event = response.get('t')
data = response.get('d')
A list of :class:`Server` that the connected client has available. if event == 'READY':
.. attribute:: private_channels interval = data['heartbeat_interval'] / 1000.0
self.keep_alive = KeepAliveHandler(interval, self)
self.keep_alive.start()
A list of :class:`PrivateChannel` that the connected client is participating on.
.. attribute:: messages
A deque_ of :class:`Message` that the client has received from all servers and private messages. if event in ('READY', 'MESSAGE_CREATE', 'MESSAGE_DELETE',
.. attribute:: email 'MESSAGE_UPDATE', 'PRESENCE_UPDATE', 'USER_UPDATE',
'CHANNEL_DELETE', 'CHANNEL_UPDATE', 'CHANNEL_CREATE',
'GUILD_MEMBER_ADD', 'GUILD_MEMBER_REMOVE',
'GUILD_MEMBER_UPDATE', 'GUILD_CREATE', 'GUILD_DELETE'):
self.dispatch('socket_update', event, data)
The email used to login. This is only set if login is successful, otherwise it's None. else:
log.info("Unhandled event {}".format(event))
.. _deque: https://docs.python.org/3.4/library/collections.html#collections.deque
"""
def __init__(self, **kwargs): class ConnectionState(object):
self._is_logged_in = False def __init__(self, dispatch, **kwargs):
self.dispatch = dispatch
self.user = None self.user = None
self.email = None self.email = None
self.servers = [] self.servers = []
self.private_channels = [] self.private_channels = []
self.token = ''
self.messages = deque([], maxlen=kwargs.get('max_length', 5000)) self.messages = deque([], maxlen=kwargs.get('max_length', 5000))
self.events = {
'on_ready': _null_event,
'on_disconnect': _null_event,
'on_error': _null_event,
'on_response': _null_event,
'on_message': _null_event,
'on_message_delete': _null_event,
'on_message_edit': _null_event,
'on_status': _null_event,
'on_channel_delete': _null_event,
'on_channel_create': _null_event,
'on_channel_update': _null_event,
'on_member_join': _null_event,
'on_member_remove': _null_event,
'on_member_update': _null_event,
'on_server_create': _null_event,
'on_server_delete': _null_event,
}
# the actual headers for the request...
# we only override 'authorization' since the rest could use the defaults.
self.headers = {
'authorization': self.token,
}
def _get_message(self, msg_id): def _get_message(self, msg_id):
return utils.find(lambda m: m.id == msg_id, self.messages) return utils.find(lambda m: m.id == msg_id, self.messages)
@ -169,24 +165,225 @@ class Client(object):
for member in server.members: for member in server.members:
member.server = server member.server = server
channels = [Channel(server=server, **channel) for channel in guild['channels']] channels = [Channel(server=server, **channel)
for channel in guild['channels']]
server.channels = channels server.channels = channels
self.servers.append(server) self.servers.append(server)
def handle_ready(self, data):
self.user = User(**data['user'])
guilds = data.get('guilds')
for guild in guilds:
self._add_server(guild)
for pm in data.get('private_channels'):
self.private_channels.append(PrivateChannel(id=pm['id'],
user=User(**pm['recipient'])))
# we're all ready
self.dispatch('ready')
def handle_message_create(self, data):
channel = self.get_channel(data.get('channel_id'))
message = Message(channel=channel, **data)
self.dispatch('message', message)
self.messages.append(message)
def handle_message_delete(self, data):
channel = self.get_channel(data.get('channel_id'))
message_id = data.get('id')
found = self._get_message(message_id)
if found is not None:
self.dispatch('message_delete', found)
self.messages.remove(found)
def handle_message_update(self, data):
older_message = self._get_message(data.get('id'))
if older_message is not None:
# create a copy of the new message
message = copy.deepcopy(older_message)
# update the new update
for attr in data:
if attr == 'channel_id' or attr == 'author':
continue
value = data[attr]
if 'time' in attr:
setattr(message, attr, utils.parse_time(value))
else:
setattr(message, attr, value)
self.dispatch('message_edit', older_message, message)
# update the older message
older_message = message
def handle_presence_update(self, data):
server = self._get_server(data.get('guild_id'))
if server is not None:
status = data.get('status')
user = data['user']
member_id = user['id']
member = utils.find(lambda m: m.id == member_id, server.members)
if member is not None:
member.status = data.get('status')
member.game_id = data.get('game_id')
member.name = user.get('username', member.name)
member.avatar = user.get('avatar', member.avatar)
# call the event now
self.dispatch('status', member)
self.dispatch('member_update', member)
def handle_user_update(self, data):
self.user = User(**data)
def handle_channel_delete(self, data):
server = self._get_server(data.get('guild_id'))
if server is not None:
channel_id = data.get('id')
channel = utils.find(lambda c: c.id == channel_id, server.channels)
server.channels.remove(channel)
self.dispatch('channel_delete', channel)
def handle_channel_update(self, data):
server = self._get_server(data.get('guild_id'))
if server is not None:
channel_id = data.get('id')
channel = utils.find(lambda c: c.id == channel_id, server.channels)
channel.update(server=server, **data)
self.dispatch('channel_update', channel)
def handle_channel_create(self, data):
is_private = data.get('is_private', False)
channel = None
if is_private:
recipient = User(**data.get('recipient'))
pm_id = data.get('id')
channel = PrivateChannel(id=pm_id, user=recipient)
self.private_channels.append(channel)
else:
server = self._get_server(data.get('guild_id'))
if server is not None:
channel = Channel(server=server, **data)
server.channels.append(channel)
self.dispatch('channel_create', channel)
def handle_guild_member_add(self, data):
server = self._get_server(data.get('guild_id'))
member = Member(server=server, deaf=False, mute=False, **data)
server.members.append(member)
self.dispatch('member_join', member)
def handle_guild_member_remove(self, data):
server = self._get_server(data.get('guild_id'))
user_id = data['user']['id']
member = utils.find(lambda m: m.id == user_id, server.members)
server.members.remove(member)
self.dispatch('member_remove', member)
def handle_guild_member_update(self, data):
server = self._get_server(data.get('guild_id'))
user_id = data['user']['id']
member = utils.find(lambda m: m.id == user_id, server.members)
if member is not None:
user = data['user']
member.name = user['username']
member.discriminator = user['discriminator']
member.avatar = user['avatar']
member.roles = []
# update the roles
for role in server.roles:
if role.id in data['roles']:
member.roles.append(role)
self.dispatch('member_update', member)
def handle_guild_create(self, data):
self._add_server(data)
self.dispatch('server_create', self.servers[-1])
def handle_guild_delete(self, data):
server = self._get_server(data.get('id'))
self.servers.remove(server)
self.dispatch('server_delete', server)
def get_channel(self, id):
if id is None:
return None
for server in self.servers:
for channel in server.channels:
if channel.id == id:
return channel
for pm in self.private_channels:
if pm.id == id:
return pm
class Client(object):
"""Represents a client connection that connects to Discord.
This class is used to interact with the Discord WebSocket and API.
A number of options can be passed to the :class:`Client` via keyword arguments.
:param int max_length: The maximum number of messages to store in :attr:`messages`. Defaults to 5000.
Instance attributes:
.. attribute:: user
A :class:`User` that represents the connected client. None if not logged in.
.. attribute:: servers
A list of :class:`Server` that the connected client has available.
.. attribute:: private_channels
A list of :class:`PrivateChannel` that the connected client is participating on.
.. attribute:: messages
A deque_ of :class:`Message` that the client has received from all servers and private messages.
.. attribute:: email
The email used to login. This is only set if login is successful, otherwise it's None.
.. _deque: https://docs.python.org/3.4/library/collections.html#collections.deque
"""
def __init__(self, **kwargs):
self._is_logged_in = False
self.connection = ConnectionState(self.dispatch, **kwargs)
self.token = ''
self.events = {
'on_ready': _null_event,
'on_disconnect': _null_event,
'on_error': _null_event,
'on_response': _null_event,
'on_message': _null_event,
'on_message_delete': _null_event,
'on_message_edit': _null_event,
'on_status': _null_event,
'on_channel_delete': _null_event,
'on_channel_create': _null_event,
'on_channel_update': _null_event,
'on_member_join': _null_event,
'on_member_remove': _null_event,
'on_member_update': _null_event,
'on_server_create': _null_event,
'on_server_delete': _null_event,
}
# the actual headers for the request...
# we only override 'authorization' since the rest could use the defaults.
self.headers = {
'authorization': self.token,
}
def _create_websocket(self, url, reconnect=False): def _create_websocket(self, url, reconnect=False):
if url is None: if url is None:
raise GatewayNotFound() raise GatewayNotFound()
log.info('websocket gateway found') log.info('websocket gateway found')
self.ws = WebSocketClient(url, protocols=['http-only', 'chat']) self.ws = WebSocket(self.dispatch, url)
# this is kind of hacky, but it's to avoid deadlocks.
# i.e. python does not allow me to have the current thread running if it's self
# it throws a 'cannot join current thread' RuntimeError
# So instead of doing a basic inheritance scheme, we're overriding the member functions.
self.ws.opened = self._opened
self.ws.closed = self._closed
self.ws.received_message = self._received_message
self.ws.connect() self.ws.connect()
log.info('websocket has connected') log.info('websocket has connected')
@ -220,6 +417,23 @@ class Client(object):
msg = 'Caught exception in {} with args (*{}, **{})' msg = 'Caught exception in {} with args (*{}, **{})'
log.exception(msg.format(event_method, args, kwargs)) log.exception(msg.format(event_method, args, kwargs))
# Compatibility shim
def __getattr__(self, name):
if name in ('user', 'email', 'servers', 'private_channels', 'messages',
'get_channel'):
return getattr(self.connection, name)
else:
msg = "'{}' object has no attribute '{}'"
raise AttributeError(msg.format(self.__class__, name))
# Compatibility shim
def __setattr__(self, name, value):
if name in ('user', 'email', 'servers', 'private_channels',
'messages'):
return setattr(self.connection, name, value)
else:
object.__setattr__(self, name, value)
def dispatch(self, event, *args, **kwargs): def dispatch(self, event, *args, **kwargs):
log.debug("Dispatching event {}".format(event)) log.debug("Dispatching event {}".format(event))
handle_method = '_'.join(('handle', event)) handle_method = '_'.join(('handle', event))
@ -242,156 +456,14 @@ class Client(object):
log.error('an error ({}) occurred in event {} so on_error is invoked instead'.format(type(e).__name__, event_name)) log.error('an error ({}) occurred in event {} so on_error is invoked instead'.format(type(e).__name__, event_name))
self.events['on_error'](event_name, *sys.exc_info()) self.events['on_error'](event_name, *sys.exc_info())
def _received_message(self, msg): def handle_socket_update(self, event, data):
response = json.loads(str(msg)) method = '_'.join(('handle', event.lower()))
log.debug('WebSocket Event: {}'.format(response)) getattr(self.connection, method)(data)
if response.get('op') != 0:
return
self.dispatch('response', response)
event = response.get('t')
data = response.get('d')
if event == 'READY':
self.user = User(**data['user'])
guilds = data.get('guilds')
for guild in guilds:
self._add_server(guild)
for pm in data.get('private_channels'):
self.private_channels.append(PrivateChannel(id=pm['id'], user=User(**pm['recipient'])))
# set the keep alive interval..
interval = data.get('heartbeat_interval') / 1000.0
self.keep_alive = KeepAliveHandler(interval, self.ws)
self.keep_alive.start()
# we're all ready
self.dispatch('ready')
elif event == 'MESSAGE_CREATE':
channel = self.get_channel(data.get('channel_id'))
message = Message(channel=channel, **data)
self.dispatch('message', message)
self.messages.append(message)
elif event == 'MESSAGE_DELETE':
channel = self.get_channel(data.get('channel_id'))
message_id = data.get('id')
found = self._get_message(message_id)
if found is not None:
self.dispatch('message_delete', found)
self.messages.remove(found)
elif event == 'MESSAGE_UPDATE':
older_message = self._get_message(data.get('id'))
if older_message is not None:
# create a copy of the new message
message = copy.deepcopy(older_message)
# update the new update
for attr in data:
if attr == 'channel_id' or attr == 'author':
continue
value = data[attr]
if 'time' in attr:
setattr(message, attr, utils.parse_time(value))
else:
setattr(message, attr, value)
self.dispatch('message_edit', older_message, message)
# update the older message
older_message = message
elif event == 'PRESENCE_UPDATE':
server = self._get_server(data.get('guild_id'))
if server is not None:
status = data.get('status')
user = data['user']
member_id = user['id']
member = utils.find(lambda m: m.id == member_id, server.members)
if member is not None:
member.status = data.get('status')
member.game_id = data.get('game_id')
member.name = user.get('username', member.name)
member.avatar = user.get('avatar', member.avatar)
# call the event now
self.dispatch('status', member)
self.dispatch('member_update', member)
elif event == 'USER_UPDATE':
self.user = User(**data)
elif event == 'CHANNEL_DELETE':
server = self._get_server(data.get('guild_id'))
if server is not None:
channel_id = data.get('id')
channel = utils.find(lambda c: c.id == channel_id, server.channels)
server.channels.remove(channel)
self.dispatch('channel_delete', channel)
elif event == 'CHANNEL_UPDATE':
server = self._get_server(data.get('guild_id'))
if server is not None:
channel_id = data.get('id')
channel = utils.find(lambda c: c.id == channel_id, server.channels)
channel.update(server=server, **data)
self.dispatch('channel_update', channel)
elif event == 'CHANNEL_CREATE':
is_private = data.get('is_private', False)
channel = None
if is_private:
recipient = User(**data.get('recipient'))
pm_id = data.get('id')
channel = PrivateChannel(id=pm_id, user=recipient)
self.private_channels.append(channel)
else:
server = self._get_server(data.get('guild_id'))
if server is not None:
channel = Channel(server=server, **data)
server.channels.append(channel)
self.dispatch('channel_create', channel)
elif event == 'GUILD_MEMBER_ADD':
server = self._get_server(data.get('guild_id'))
member = Member(server=server, deaf=False, mute=False, **data)
server.members.append(member)
self.dispatch('member_join', member)
elif event == 'GUILD_MEMBER_REMOVE':
server = self._get_server(data.get('guild_id'))
user_id = data['user']['id']
member = utils.find(lambda m: m.id == user_id, server.members)
server.members.remove(member)
self.dispatch('member_remove', member)
elif event == 'GUILD_MEMBER_UPDATE':
server = self._get_server(data.get('guild_id'))
user_id = data['user']['id']
member = utils.find(lambda m: m.id == user_id, server.members)
if member is not None:
user = data['user']
member.name = user['username']
member.discriminator = user['discriminator']
member.avatar = user['avatar']
member.roles = []
# update the roles
for role in server.roles:
if role.id in data['roles']:
member.roles.append(role)
self.dispatch('member_update', member)
elif event == 'GUILD_CREATE':
self._add_server(data)
self.dispatch('server_create', self.servers[-1])
elif event == 'GUILD_DELETE':
server = self._get_server(data.get('id'))
self.servers.remove(server)
self.dispatch('server_delete', server)
def _opened(self):
log.info('Opened at {}'.format(int(time.time())))
def _closed(self, code, reason=None):
log.info('Closed with {} ("{}") at {}'.format(code, reason, int(time.time())))
self.dispatch('disconnect')
def run(self): def run(self):
"""Runs the client and allows it to receive messages and events.""" """Runs the client and allows it to receive messages and events."""
log.info('Client is being run') log.info('Client is being run')
self.ws.run_forever() self.ws.run()
@property @property
def is_logged_in(self): def is_logged_in(self):
@ -399,18 +471,10 @@ class Client(object):
return self._is_logged_in return self._is_logged_in
def get_channel(self, id): def get_channel(self, id):
"""Returns a :class:`Channel` or :class:`PrivateChannel` with the following ID. If not found, returns None.""" """Returns a :class:`Channel` or :class:`PrivateChannel` with the
if id is None: following ID. If not found, returns None.
return None """
return self.connection.get_channel(id)
for server in self.servers:
for channel in server.channels:
if channel.id == id:
return channel
for pm in self.private_channels:
if pm.id == id:
return pm
def start_private_message(self, user): def start_private_message(self, user):
"""Starts a private message with the user. This allows you to :meth:`send_message` to it. """Starts a private message with the user. This allows you to :meth:`send_message` to it.
@ -578,7 +642,6 @@ class Client(object):
response = requests.post(endpoints.LOGOUT) response = requests.post(endpoints.LOGOUT)
self.ws.close() self.ws.close()
self._is_logged_in = False self._is_logged_in = False
self.keep_alive.stop.set()
log.debug(request_logging_format.format(name='logout', response=response)) log.debug(request_logging_format.format(name='logout', response=response))
def logs_from(self, channel, limit=500): def logs_from(self, channel, limit=500):