Fix Client.fetch_channel not returning Thread

This commit is contained in:
Alex Nørgaard 2021-07-04 02:35:31 +01:00 committed by GitHub
parent 097b6064f1
commit d1dc41ec2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 22 additions and 15 deletions

View File

@ -691,7 +691,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
type=ChannelType.public_thread.value, type=ChannelType.public_thread.value,
) )
return Thread(guild=self.guild, data=data) return Thread(guild=self.guild, state=self._state, data=data)
def archived_threads( def archived_threads(
self, self,
@ -753,7 +753,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
""" """
data = await self._state.http.get_active_threads(self.id) data = await self._state.http.get_active_threads(self.id)
# TODO: thread members? # TODO: thread members?
return [Thread(guild=self.guild, data=d) for d in data.get('threads', [])] return [Thread(guild=self.guild, state=self._state, data=d) for d in data.get('threads', [])]
class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
@ -1924,3 +1924,9 @@ def _channel_factory(channel_type: Union[ChannelType, int]):
return GroupChannel, value return GroupChannel, value
else: else:
return cls, value return cls, value
def _threaded_channel_factory(channel_type: Union[ChannelType, int]):
cls, value = _channel_factory(channel_type)
if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread):
return Thread, value
return cls, value

View File

@ -39,7 +39,7 @@ from .template import Template
from .widget import Widget from .widget import Widget
from .guild import Guild from .guild import Guild
from .emoji import Emoji from .emoji import Emoji
from .channel import _channel_factory from .channel import _threaded_channel_factory
from .enums import ChannelType from .enums import ChannelType
from .mentions import AllowedMentions from .mentions import AllowedMentions
from .errors import * from .errors import *
@ -58,6 +58,7 @@ from .iterators import GuildIterator
from .appinfo import AppInfo from .appinfo import AppInfo
from .ui.view import View from .ui.view import View
from .stage_instance import StageInstance from .stage_instance import StageInstance
from .threads import Thread
if TYPE_CHECKING: if TYPE_CHECKING:
from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake
@ -1371,10 +1372,10 @@ class Client:
data = await self.http.get_user(user_id) data = await self.http.get_user(user_id)
return User(state=self._connection, data=data) return User(state=self._connection, data=data)
async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel]: async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel, Thread]:
"""|coro| """|coro|
Retrieves a :class:`.abc.GuildChannel` or :class:`.abc.PrivateChannel` with the specified ID. Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID.
.. note:: .. note::
@ -1395,12 +1396,12 @@ class Client:
Returns Returns
-------- --------
Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`] Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, :class:`.Thread`]
The channel from the ID. The channel from the ID.
""" """
data = await self.http.get_channel(channel_id) data = await self.http.get_channel(channel_id)
factory, ch_type = _channel_factory(data['type']) factory, ch_type = _threaded_channel_factory(data['type'])
if factory is None: if factory is None:
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data)) raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))

View File

@ -287,7 +287,7 @@ class Guild(Hashable):
self._members[member.id] = member self._members[member.id] = member
def _store_thread(self, payload: ThreadPayload, /) -> Thread: def _store_thread(self, payload: ThreadPayload, /) -> Thread:
thread = Thread(guild=self, data=payload) thread = Thread(guild=self, state=self._state, data=payload)
self._threads[thread.id] = thread self._threads[thread.id] = thread
return thread return thread
@ -466,7 +466,7 @@ class Guild(Hashable):
if 'threads' in data: if 'threads' in data:
threads = data['threads'] threads = data['threads']
for thread in threads: for thread in threads:
self._add_thread(Thread(guild=self, data=thread)) self._add_thread(Thread(guild=self, state=self._state, data=thread))
@property @property
def channels(self) -> List[GuildChannel]: def channels(self) -> List[GuildChannel]:

View File

@ -750,4 +750,4 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
def create_thread(self, data: ThreadPayload) -> Thread: def create_thread(self, data: ThreadPayload) -> Thread:
from .threads import Thread from .threads import Thread
return Thread(guild=self.guild, data=data) return Thread(guild=self.guild, state=self.guild._state, data=data)

View File

@ -1491,7 +1491,7 @@ class Message(Hashable):
auto_archive_duration=auto_archive_duration, auto_archive_duration=auto_archive_duration,
type=ChannelType.public_thread.value, type=ChannelType.public_thread.value,
) )
return Thread(guild=self.guild, data=data) # type: ignore return Thread(guild=self.guild, state=self._state, data=data) # type: ignore
async def reply(self, content: Optional[str] = None, **kwargs) -> Message: async def reply(self, content: Optional[str] = None, **kwargs) -> Message:
"""|coro| """|coro|

View File

@ -715,7 +715,7 @@ class ConnectionState:
log.debug('THREAD_CREATE referencing an unknown guild ID: %s. Discarding', guild_id) log.debug('THREAD_CREATE referencing an unknown guild ID: %s. Discarding', guild_id)
return return
thread = Thread(guild=guild, data=data) thread = Thread(guild=guild, state=guild._state, data=data)
has_thread = guild.get_thread(thread.id) has_thread = guild.get_thread(thread.id)
guild._add_thread(thread) guild._add_thread(thread)
if not has_thread: if not has_thread:
@ -735,7 +735,7 @@ class ConnectionState:
thread._update(data) thread._update(data)
self.dispatch('thread_update', old, thread) self.dispatch('thread_update', old, thread)
else: else:
thread = Thread(guild=guild, data=data) thread = Thread(guild=guild, state=guild._state, data=data)
guild._add_thread(thread) guild._add_thread(thread)
self.dispatch('thread_join', thread) self.dispatch('thread_join', thread)

View File

@ -139,8 +139,8 @@ class Thread(Messageable, Hashable):
'archive_timestamp', 'archive_timestamp',
) )
def __init__(self, *, guild: Guild, data: ThreadPayload): def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
self._state: ConnectionState = guild._state self._state: ConnectionState = state
self.guild = guild self.guild = guild
self._members: Dict[int, ThreadMember] = {} self._members: Dict[int, ThreadMember] = {}
self._from_data(data) self._from_data(data)