Parse remaining thread events.

This commit is contained in:
Rapptz
2021-05-09 22:23:21 -04:00
parent 9adf94e6b1
commit bd369c76ea
4 changed files with 115 additions and 10 deletions

View File

@ -26,7 +26,7 @@ from __future__ import annotations
import copy
from collections import namedtuple
from typing import Dict, List, Literal, Optional, TYPE_CHECKING, Union, overload
from typing import Dict, List, Set, Literal, Optional, TYPE_CHECKING, Union, overload
from . import utils, abc
from .role import Role
@ -227,6 +227,20 @@ class Guild(Hashable):
def _remove_thread(self, thread):
self._threads.pop(thread.id, None)
def _clear_threads(self):
self._threads.clear()
def _remove_threads_by_channel(self, channel_id: int):
to_remove = [k for k, t in self._threads.items() if t.parent_id == channel_id]
for k in to_remove:
del self._threads[k]
def _filter_threads(self, channel_ids: Set[int]) -> Dict[int, Thread]:
to_remove: Dict[int, Thread] = {k: t for k, t in self._threads.items() if t.parent_id in channel_ids}
for k in to_remove:
del self._threads[k]
return to_remove
def __str__(self):
return self.name or ''

View File

@ -716,7 +716,7 @@ class ConnectionState:
thread = Thread(guild=guild, data=data)
guild._add_thread(thread)
self.dispatch('thread_create', thread)
self.dispatch('thread_join', thread)
def parse_thread_update(self, data):
guild_id = int(data['guild_id'])
@ -752,6 +752,16 @@ class ConnectionState:
log.debug('THREAD_LIST_SYNC referencing an unknown guild ID: %s. Discarding', guild_id)
return
try:
channel_ids = set(data['channel_ids'])
except KeyError:
# If not provided, then the entire guild is being synced
# So all previous thread data should be overwritten
previous_threads = guild._threads.copy()
guild._clear_threads()
else:
previous_threads = guild._filter_threads(channel_ids)
threads = {
d['id']: guild._store_thread(d)
for d in data.get('threads', [])
@ -766,7 +776,13 @@ class ConnectionState:
else:
thread._add_member(ThreadMember(thread, member))
# TODO: dispatch?
for thread in threads.values():
old = previous_threads.pop(thread.id, None)
if old is None:
self.dispatch('thread_join', thread)
for thread in previous_threads.values():
self.dispatch('thread_remove', thread)
def parse_thread_member_update(self, data):
guild_id = int(data['guild_id'])
@ -776,15 +792,44 @@ class ConnectionState:
return
thread_id = int(data['id'])
thread = guild.get_thread(thread_id)
thread: Optional[Thread] = guild.get_thread(thread_id)
if thread is None:
log.debug('THREAD_MEMBER_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id)
return
member = ThreadMember(thread, data)
thread._add_member(member)
thread.me = member
# TODO: dispatch
def parse_thread_members_update(self, data):
guild_id = int(data['guild_id'])
guild: Optional[Guild] = self._get_guild(guild_id)
if guild is None:
log.debug('THREAD_MEMBERS_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id)
return
thread_id = int(data['id'])
thread: Optional[Thread] = guild.get_thread(thread_id)
if thread is None:
log.debug('THREAD_MEMBERS_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id)
return
added_members = [ThreadMember(thread, d) for d in data.get('added_members', [])]
removed_member_ids = data.get('removed_member_ids', [])
self_id = self.self_id
for member in added_members:
if member.id != self_id:
thread._add_member(member)
self.dispatch('thread_member_join', member)
else:
thread.me = member
self.dispatch('thread_join', thread)
for member_id in removed_member_ids:
if member_id != self_id:
member = thread._pop_member(member_id)
self.dispatch('thread_member_leave', member)
else:
self.dispatch('thread_remove', thread)
def parse_guild_member_add(self, data):
guild = self._get_guild(int(data['guild_id']))

View File

@ -383,6 +383,8 @@ class Thread(Messageable, Hashable):
def _add_member(self, member: ThreadMember) -> None:
self._members[member.id] = member
def _pop_member(self, member_id: int) -> Optional[ThreadMember]:
return self._members.pop(member_id, None)
class ThreadMember(Hashable):
"""Represents a Discord thread member.