mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-10-24 10:02:56 +00:00
Fix Client.fetch_channel not returning Thread
This commit is contained in:
@@ -691,7 +691,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
type=ChannelType.public_thread.value,
|
||||
)
|
||||
|
||||
return Thread(guild=self.guild, data=data)
|
||||
return Thread(guild=self.guild, state=self._state, data=data)
|
||||
|
||||
def archived_threads(
|
||||
self,
|
||||
@@ -753,7 +753,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
"""
|
||||
data = await self._state.http.get_active_threads(self.id)
|
||||
# TODO: thread members?
|
||||
return [Thread(guild=self.guild, data=d) for d in data.get('threads', [])]
|
||||
return [Thread(guild=self.guild, state=self._state, data=d) for d in data.get('threads', [])]
|
||||
|
||||
|
||||
class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
|
||||
@@ -1924,3 +1924,9 @@ def _channel_factory(channel_type: Union[ChannelType, int]):
|
||||
return GroupChannel, value
|
||||
else:
|
||||
return cls, value
|
||||
|
||||
def _threaded_channel_factory(channel_type: Union[ChannelType, int]):
|
||||
cls, value = _channel_factory(channel_type)
|
||||
if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread):
|
||||
return Thread, value
|
||||
return cls, value
|
||||
|
@@ -39,7 +39,7 @@ from .template import Template
|
||||
from .widget import Widget
|
||||
from .guild import Guild
|
||||
from .emoji import Emoji
|
||||
from .channel import _channel_factory
|
||||
from .channel import _threaded_channel_factory
|
||||
from .enums import ChannelType
|
||||
from .mentions import AllowedMentions
|
||||
from .errors import *
|
||||
@@ -58,6 +58,7 @@ from .iterators import GuildIterator
|
||||
from .appinfo import AppInfo
|
||||
from .ui.view import View
|
||||
from .stage_instance import StageInstance
|
||||
from .threads import Thread
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake
|
||||
@@ -1371,10 +1372,10 @@ class Client:
|
||||
data = await self.http.get_user(user_id)
|
||||
return User(state=self._connection, data=data)
|
||||
|
||||
async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel]:
|
||||
async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel, Thread]:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a :class:`.abc.GuildChannel` or :class:`.abc.PrivateChannel` with the specified ID.
|
||||
Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID.
|
||||
|
||||
.. note::
|
||||
|
||||
@@ -1395,12 +1396,12 @@ class Client:
|
||||
|
||||
Returns
|
||||
--------
|
||||
Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`]
|
||||
Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, :class:`.Thread`]
|
||||
The channel from the ID.
|
||||
"""
|
||||
data = await self.http.get_channel(channel_id)
|
||||
|
||||
factory, ch_type = _channel_factory(data['type'])
|
||||
factory, ch_type = _threaded_channel_factory(data['type'])
|
||||
if factory is None:
|
||||
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))
|
||||
|
||||
|
@@ -287,7 +287,7 @@ class Guild(Hashable):
|
||||
self._members[member.id] = member
|
||||
|
||||
def _store_thread(self, payload: ThreadPayload, /) -> Thread:
|
||||
thread = Thread(guild=self, data=payload)
|
||||
thread = Thread(guild=self, state=self._state, data=payload)
|
||||
self._threads[thread.id] = thread
|
||||
return thread
|
||||
|
||||
@@ -466,7 +466,7 @@ class Guild(Hashable):
|
||||
if 'threads' in data:
|
||||
threads = data['threads']
|
||||
for thread in threads:
|
||||
self._add_thread(Thread(guild=self, data=thread))
|
||||
self._add_thread(Thread(guild=self, state=self._state, data=thread))
|
||||
|
||||
@property
|
||||
def channels(self) -> List[GuildChannel]:
|
||||
|
@@ -750,4 +750,4 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
|
||||
|
||||
def create_thread(self, data: ThreadPayload) -> Thread:
|
||||
from .threads import Thread
|
||||
return Thread(guild=self.guild, data=data)
|
||||
return Thread(guild=self.guild, state=self.guild._state, data=data)
|
||||
|
@@ -1491,7 +1491,7 @@ class Message(Hashable):
|
||||
auto_archive_duration=auto_archive_duration,
|
||||
type=ChannelType.public_thread.value,
|
||||
)
|
||||
return Thread(guild=self.guild, data=data) # type: ignore
|
||||
return Thread(guild=self.guild, state=self._state, data=data) # type: ignore
|
||||
|
||||
async def reply(self, content: Optional[str] = None, **kwargs) -> Message:
|
||||
"""|coro|
|
||||
|
@@ -715,7 +715,7 @@ class ConnectionState:
|
||||
log.debug('THREAD_CREATE referencing an unknown guild ID: %s. Discarding', guild_id)
|
||||
return
|
||||
|
||||
thread = Thread(guild=guild, data=data)
|
||||
thread = Thread(guild=guild, state=guild._state, data=data)
|
||||
has_thread = guild.get_thread(thread.id)
|
||||
guild._add_thread(thread)
|
||||
if not has_thread:
|
||||
@@ -735,7 +735,7 @@ class ConnectionState:
|
||||
thread._update(data)
|
||||
self.dispatch('thread_update', old, thread)
|
||||
else:
|
||||
thread = Thread(guild=guild, data=data)
|
||||
thread = Thread(guild=guild, state=guild._state, data=data)
|
||||
guild._add_thread(thread)
|
||||
self.dispatch('thread_join', thread)
|
||||
|
||||
|
@@ -139,8 +139,8 @@ class Thread(Messageable, Hashable):
|
||||
'archive_timestamp',
|
||||
)
|
||||
|
||||
def __init__(self, *, guild: Guild, data: ThreadPayload):
|
||||
self._state: ConnectionState = guild._state
|
||||
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
|
||||
self._state: ConnectionState = state
|
||||
self.guild = guild
|
||||
self._members: Dict[int, ThreadMember] = {}
|
||||
self._from_data(data)
|
||||
|
Reference in New Issue
Block a user