Begin working on the rewrite.
This commit is contained in:
@ -47,18 +47,20 @@ class ListenerType(enum.Enum):
|
||||
chunk = 0
|
||||
|
||||
Listener = namedtuple('Listener', ('type', 'future', 'predicate'))
|
||||
StateContext = namedtuple('StateContext', 'try_insert_user http')
|
||||
log = logging.getLogger(__name__)
|
||||
ReadyState = namedtuple('ReadyState', ('launch', 'servers'))
|
||||
|
||||
class ConnectionState:
|
||||
def __init__(self, dispatch, chunker, syncer, max_messages, *, loop):
|
||||
def __init__(self, *, dispatch, chunker, syncer, http, loop, **options):
|
||||
self.loop = loop
|
||||
self.max_messages = max_messages
|
||||
self.max_messages = max(options.get('max_messages', 5000), 100)
|
||||
self.dispatch = dispatch
|
||||
self.chunker = chunker
|
||||
self.syncer = syncer
|
||||
self.is_bot = None
|
||||
self._listeners = []
|
||||
self.ctx = StateContext(try_insert_user=self.try_insert_user, http=http)
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
@ -66,6 +68,7 @@ class ConnectionState:
|
||||
self.sequence = None
|
||||
self.session_id = None
|
||||
self._calls = {}
|
||||
self._users = {}
|
||||
self._servers = {}
|
||||
self._voice_clients = {}
|
||||
self._private_channels = {}
|
||||
@ -116,6 +119,15 @@ class ConnectionState:
|
||||
for vc in self.voice_clients:
|
||||
vc.main_ws = ws
|
||||
|
||||
def try_insert_user(self, data):
|
||||
# this way is 300% faster than `dict.setdefault`.
|
||||
user_id = data['id']
|
||||
try:
|
||||
return self._users[user_id]
|
||||
except KeyError:
|
||||
self._users[user_id] = user = User(state=self.ctx, data=data)
|
||||
return user
|
||||
|
||||
@property
|
||||
def servers(self):
|
||||
return self._servers.values()
|
||||
@ -153,7 +165,7 @@ class ConnectionState:
|
||||
return utils.find(lambda m: m.id == msg_id, self.messages)
|
||||
|
||||
def _add_server_from_data(self, guild):
|
||||
server = Server(**guild)
|
||||
server = Server(data=guild, state=self.ctx)
|
||||
Server.me = property(lambda s: s.get_member(self.user.id))
|
||||
Server.voice_client = property(lambda s: self._get_voice_client(s.id))
|
||||
self._add_server(server)
|
||||
@ -207,7 +219,7 @@ class ConnectionState:
|
||||
|
||||
def parse_ready(self, data):
|
||||
self._ready_state = ReadyState(launch=asyncio.Event(), servers=[])
|
||||
self.user = User(**data['user'])
|
||||
self.user = self.try_insert_user(data['user'])
|
||||
guilds = data.get('guilds')
|
||||
|
||||
servers = self._ready_state.servers
|
||||
@ -217,7 +229,7 @@ class ConnectionState:
|
||||
servers.append(server)
|
||||
|
||||
for pm in data.get('private_channels'):
|
||||
self._add_private_channel(PrivateChannel(self.user, **pm))
|
||||
self._add_private_channel(PrivateChannel(me=self.user, data=pm, state=self.ctx))
|
||||
|
||||
compat.create_task(self._delay_ready(), loop=self.loop)
|
||||
|
||||
@ -226,7 +238,7 @@ class ConnectionState:
|
||||
|
||||
def parse_message_create(self, data):
|
||||
channel = self.get_channel(data.get('channel_id'))
|
||||
message = self._create_message(channel=channel, **data)
|
||||
message = Message(channel=channel, data=data, state=self.ctx)
|
||||
self.dispatch('message', message)
|
||||
self.messages.append(message)
|
||||
|
||||
@ -255,7 +267,7 @@ class ConnectionState:
|
||||
# embed only edit
|
||||
message.embeds = data['embeds']
|
||||
else:
|
||||
message._update(channel=message.channel, **data)
|
||||
message._update(channel=message.channel, data=data)
|
||||
|
||||
self.dispatch('message_edit', older_message, message)
|
||||
|
||||
@ -329,22 +341,11 @@ class ConnectionState:
|
||||
server._add_member(member)
|
||||
|
||||
old_member = member._copy()
|
||||
member.status = data.get('status')
|
||||
try:
|
||||
member.status = Status(member.status)
|
||||
except:
|
||||
pass
|
||||
|
||||
game = data.get('game', {})
|
||||
member.game = Game(**game) if game else None
|
||||
member.name = user.get('username', member.name)
|
||||
member.avatar = user.get('avatar', member.avatar)
|
||||
member.discriminator = user.get('discriminator', member.discriminator)
|
||||
|
||||
member._presence_update(data=data, user=user)
|
||||
self.dispatch('member_update', old_member, member)
|
||||
|
||||
def parse_user_update(self, data):
|
||||
self.user = User(**data)
|
||||
self.user = User(state=self.ctx, data=data)
|
||||
|
||||
def parse_channel_delete(self, data):
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
@ -361,7 +362,7 @@ class ConnectionState:
|
||||
if channel_type is ChannelType.group:
|
||||
channel = self._get_private_channel(channel_id)
|
||||
old_channel = copy.copy(channel)
|
||||
channel._update_group(**data)
|
||||
channel._update_group(data)
|
||||
self.dispatch('channel_update', old_channel, channel)
|
||||
return
|
||||
|
||||
@ -370,32 +371,32 @@ class ConnectionState:
|
||||
channel = server.get_channel(channel_id)
|
||||
if channel is not None:
|
||||
old_channel = copy.copy(channel)
|
||||
channel._update(server=server, **data)
|
||||
channel._update(server, data)
|
||||
self.dispatch('channel_update', old_channel, channel)
|
||||
|
||||
def parse_channel_create(self, data):
|
||||
ch_type = try_enum(ChannelType, data.get('type'))
|
||||
channel = None
|
||||
if ch_type in (ChannelType.group, ChannelType.private):
|
||||
channel = PrivateChannel(self.user, **data)
|
||||
channel = PrivateChannel(me=self.user, data=data, state=self.ctx)
|
||||
self._add_private_channel(channel)
|
||||
else:
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
if server is not None:
|
||||
channel = Channel(server=server, **data)
|
||||
channel = Channel(server=server, state=self.ctx, data=data)
|
||||
server._add_channel(channel)
|
||||
|
||||
self.dispatch('channel_create', channel)
|
||||
|
||||
def parse_channel_recipient_add(self, data):
|
||||
channel = self._get_private_channel(data.get('channel_id'))
|
||||
user = User(**data.get('user', {}))
|
||||
user = self.try_insert_user(data['user'])
|
||||
channel.recipients.append(user)
|
||||
self.dispatch('group_join', channel, user)
|
||||
|
||||
def parse_channel_recipient_remove(self, data):
|
||||
channel = self._get_private_channel(data.get('channel_id'))
|
||||
user = User(**data.get('user', {}))
|
||||
user = self.try_insert_user(data['user'])
|
||||
try:
|
||||
channel.recipients.remove(user)
|
||||
except ValueError:
|
||||
@ -411,7 +412,7 @@ class ConnectionState:
|
||||
roles.append(role)
|
||||
|
||||
data['roles'] = sorted(roles, key=lambda r: int(r.id))
|
||||
return Member(server=server, **data)
|
||||
return Member(server=server, data=data, state=self.ctx)
|
||||
|
||||
def parse_guild_member_add(self, data):
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
@ -441,35 +442,18 @@ class ConnectionState:
|
||||
|
||||
def parse_guild_member_update(self, data):
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
user_id = data['user']['id']
|
||||
user = data['user']
|
||||
user_id = user['id']
|
||||
member = server.get_member(user_id)
|
||||
if member is not None:
|
||||
user = data['user']
|
||||
old_member = member._copy()
|
||||
member.name = user['username']
|
||||
member.discriminator = user['discriminator']
|
||||
member.avatar = user['avatar']
|
||||
member.bot = user.get('bot', False)
|
||||
|
||||
# the nickname change is optional,
|
||||
# if it isn't in the payload then it didn't change
|
||||
if 'nick' in data:
|
||||
member.nick = data['nick']
|
||||
|
||||
# update the roles
|
||||
member.roles = [server.default_role]
|
||||
for role in server.roles:
|
||||
if role.id in data['roles']:
|
||||
member.roles.append(role)
|
||||
|
||||
# sort the roles by ID since they can be "randomised"
|
||||
member.roles.sort(key=lambda r: int(r.id))
|
||||
member._update(data, user)
|
||||
self.dispatch('member_update', old_member, member)
|
||||
|
||||
def parse_guild_emojis_update(self, data):
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
before_emojis = server.emojis
|
||||
server.emojis = [Emoji(server=server, **e) for e in data.get('emojis', [])]
|
||||
server.emojis = [Emoji(server=server, data=e, state=self.ctx) for e in data.get('emojis', [])]
|
||||
self.dispatch('server_emojis_update', before_emojis, server.emojis)
|
||||
|
||||
def _get_create_server(self, data):
|
||||
@ -584,13 +568,13 @@ class ConnectionState:
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
if server is not None:
|
||||
if 'user' in data:
|
||||
user = User(**data['user'])
|
||||
user = self.try_insert_user(data['user'])
|
||||
self.dispatch('member_unban', server, user)
|
||||
|
||||
def parse_guild_role_create(self, data):
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
role_data = data.get('role', {})
|
||||
role = Role(server=server, **role_data)
|
||||
server = self._get_server(data['guild_id'])
|
||||
role_data = data['role']
|
||||
role = Role(server=server, data=role_data, state=self.ctx)
|
||||
server._add_role(role)
|
||||
self.dispatch('server_role_create', role)
|
||||
|
||||
@ -609,11 +593,12 @@ class ConnectionState:
|
||||
def parse_guild_role_update(self, data):
|
||||
server = self._get_server(data.get('guild_id'))
|
||||
if server is not None:
|
||||
role_id = data['role']['id']
|
||||
role_data = data['role']
|
||||
role_id = role_data['id']
|
||||
role = utils.find(lambda r: r.id == role_id, server.roles)
|
||||
if role is not None:
|
||||
old_role = copy.copy(role)
|
||||
role._update(**data['role'])
|
||||
role._update(role_data)
|
||||
self.dispatch('server_role_update', old_role, role)
|
||||
|
||||
def parse_guild_members_chunk(self, data):
|
||||
|
Reference in New Issue
Block a user