[commands] Fix (Partial)MessageConverter to work with thread messages
This commit is contained in:
		| @@ -86,7 +86,8 @@ if TYPE_CHECKING: | |||||||
|         OverwriteType, |         OverwriteType, | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     MessageableChannel = Union[TextChannel, Thread, DMChannel, GroupChannel] |     PartialMessageableChannel = Union[TextChannel, Thread, DMChannel] | ||||||
|  |     MessageableChannel = Union[PartialMessageableChannel, GroupChannel] | ||||||
|     SnowflakeTime = Union["Snowflake", datetime] |     SnowflakeTime = Union["Snowflake", datetime] | ||||||
|  |  | ||||||
| MISSING = utils.MISSING | MISSING = utils.MISSING | ||||||
|   | |||||||
| @@ -48,6 +48,7 @@ from .errors import * | |||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     from .context import Context |     from .context import Context | ||||||
|  |     from discord.message import PartialMessageableChannel | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ( | __all__ = ( | ||||||
| @@ -349,11 +350,11 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): | |||||||
|         return guild_id, message_id, channel_id |         return guild_id, message_id, channel_id | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def _resolve_channel(ctx, guild_id, channel_id): |     def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]: | ||||||
|         if guild_id is not None: |         if guild_id is not None: | ||||||
|             guild = ctx.bot.get_guild(guild_id) |             guild = ctx.bot.get_guild(guild_id) | ||||||
|             if guild is not None and channel_id is not None: |             if guild is not None and channel_id is not None: | ||||||
|                 return guild.get_channel(channel_id) |                 return guild._resolve_channel(channel_id)  # type: ignore | ||||||
|             else: |             else: | ||||||
|                 return None |                 return None | ||||||
|         else: |         else: | ||||||
| @@ -470,6 +471,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): | |||||||
|  |  | ||||||
|         return result |         return result | ||||||
|  |  | ||||||
|  |  | ||||||
| class TextChannelConverter(IDConverter[discord.TextChannel]): | class TextChannelConverter(IDConverter[discord.TextChannel]): | ||||||
|     """Converts to a :class:`~discord.TextChannel`. |     """Converts to a :class:`~discord.TextChannel`. | ||||||
|  |  | ||||||
| @@ -567,6 +569,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): | |||||||
|     async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel: |     async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel: | ||||||
|         return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel) |         return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ThreadConverter(IDConverter[discord.Thread]): | class ThreadConverter(IDConverter[discord.Thread]): | ||||||
|     """Coverts to a :class:`~discord.Thread`. |     """Coverts to a :class:`~discord.Thread`. | ||||||
|  |  | ||||||
| @@ -584,6 +587,7 @@ class ThreadConverter(IDConverter[discord.Thread]): | |||||||
|     async def convert(self, ctx: Context, argument: str) -> discord.Thread: |     async def convert(self, ctx: Context, argument: str) -> discord.Thread: | ||||||
|         return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread) |         return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ColourConverter(Converter[discord.Colour]): | class ColourConverter(Converter[discord.Colour]): | ||||||
|     """Converts to a :class:`~discord.Colour`. |     """Converts to a :class:`~discord.Colour`. | ||||||
|  |  | ||||||
| @@ -844,7 +848,7 @@ class clean_content(Converter[str]): | |||||||
|         fix_channel_mentions: bool = False, |         fix_channel_mentions: bool = False, | ||||||
|         use_nicknames: bool = True, |         use_nicknames: bool = True, | ||||||
|         escape_markdown: bool = False, |         escape_markdown: bool = False, | ||||||
|         remove_markdown: bool = False |         remove_markdown: bool = False, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         self.fix_channel_mentions = fix_channel_mentions |         self.fix_channel_mentions = fix_channel_mentions | ||||||
|         self.use_nicknames = use_nicknames |         self.use_nicknames = use_nicknames | ||||||
| @@ -855,6 +859,7 @@ class clean_content(Converter[str]): | |||||||
|         msg = ctx.message |         msg = ctx.message | ||||||
|  |  | ||||||
|         if ctx.guild: |         if ctx.guild: | ||||||
|  |  | ||||||
|             def resolve_member(id: int) -> str: |             def resolve_member(id: int) -> str: | ||||||
|                 m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id) |                 m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id) | ||||||
|                 return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user' |                 return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user' | ||||||
| @@ -862,7 +867,9 @@ class clean_content(Converter[str]): | |||||||
|             def resolve_role(id: int) -> str: |             def resolve_role(id: int) -> str: | ||||||
|                 r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id) |                 r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id) | ||||||
|                 return f'@{r.name}' if r else '@deleted-role' |                 return f'@{r.name}' if r else '@deleted-role' | ||||||
|  |  | ||||||
|         else: |         else: | ||||||
|  |  | ||||||
|             def resolve_member(id: int) -> str: |             def resolve_member(id: int) -> str: | ||||||
|                 m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id) |                 m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id) | ||||||
|                 return f'@{m.name}' if m else '@deleted-user' |                 return f'@{m.name}' if m else '@deleted-user' | ||||||
| @@ -871,10 +878,13 @@ class clean_content(Converter[str]): | |||||||
|                 return '@deleted-role' |                 return '@deleted-role' | ||||||
|  |  | ||||||
|         if self.fix_channel_mentions and ctx.guild: |         if self.fix_channel_mentions and ctx.guild: | ||||||
|  |  | ||||||
|             def resolve_channel(id: int) -> str: |             def resolve_channel(id: int) -> str: | ||||||
|                 c = ctx.guild.get_channel(id) |                 c = ctx.guild.get_channel(id) | ||||||
|                 return f'#{c.name}' if c else '#deleted-channel' |                 return f'#{c.name}' if c else '#deleted-channel' | ||||||
|  |  | ||||||
|         else: |         else: | ||||||
|  |  | ||||||
|             def resolve_channel(id: int) -> str: |             def resolve_channel(id: int) -> str: | ||||||
|                 return f'<#{id}>' |                 return f'<#{id}>' | ||||||
|  |  | ||||||
|   | |||||||
| @@ -67,7 +67,7 @@ if TYPE_CHECKING: | |||||||
|     from .types.user import User as UserPayload |     from .types.user import User as UserPayload | ||||||
|     from .types.embed import Embed as EmbedPayload |     from .types.embed import Embed as EmbedPayload | ||||||
|     from .abc import Snowflake |     from .abc import Snowflake | ||||||
|     from .abc import GuildChannel |     from .abc import GuildChannel, PartialMessageableChannel, MessageableChannel | ||||||
|     from .components import Component |     from .components import Component | ||||||
|     from .state import ConnectionState |     from .state import ConnectionState | ||||||
|     from .channel import TextChannel, GroupChannel, DMChannel |     from .channel import TextChannel, GroupChannel, DMChannel | ||||||
| @@ -657,7 +657,7 @@ class Message(Hashable): | |||||||
|         self.embeds: List[Embed] = [Embed.from_dict(a) for a in data['embeds']] |         self.embeds: List[Embed] = [Embed.from_dict(a) for a in data['embeds']] | ||||||
|         self.application: Optional[MessageApplicationPayload] = data.get('application') |         self.application: Optional[MessageApplicationPayload] = data.get('application') | ||||||
|         self.activity: Optional[MessageActivityPayload] = data.get('activity') |         self.activity: Optional[MessageActivityPayload] = data.get('activity') | ||||||
|         self.channel: Union[TextChannel, Thread, DMChannel, GroupChannel] = channel |         self.channel: MessageableChannel = channel | ||||||
|         self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data['edited_timestamp']) |         self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data['edited_timestamp']) | ||||||
|         self.type: MessageType = try_enum(MessageType, data['type']) |         self.type: MessageType = try_enum(MessageType, data['type']) | ||||||
|         self.pinned: bool = data['pinned'] |         self.pinned: bool = data['pinned'] | ||||||
| @@ -1557,8 +1557,11 @@ class PartialMessage(Hashable): | |||||||
|     a message and channel ID are present. |     a message and channel ID are present. | ||||||
|  |  | ||||||
|     There are two ways to construct this class. The first one is through |     There are two ways to construct this class. The first one is through | ||||||
|     the constructor itself, and the second is via |     the constructor itself, and the second is via the following: | ||||||
|     :meth:`TextChannel.get_partial_message` or :meth:`DMChannel.get_partial_message`. |  | ||||||
|  |     - :meth:`TextChannel.get_partial_message` | ||||||
|  |     - :meth:`Thread.get_partial_message` | ||||||
|  |     - :meth:`DMChannel.get_partial_message` | ||||||
|  |  | ||||||
|     Note that this class is trimmed down and has no rich attributes. |     Note that this class is trimmed down and has no rich attributes. | ||||||
|  |  | ||||||
| @@ -1580,7 +1583,7 @@ class PartialMessage(Hashable): | |||||||
|  |  | ||||||
|     Attributes |     Attributes | ||||||
|     ----------- |     ----------- | ||||||
|     channel: Union[:class:`TextChannel`, :class:`DMChannel`] |     channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`] | ||||||
|         The channel associated with this partial message. |         The channel associated with this partial message. | ||||||
|     id: :class:`int` |     id: :class:`int` | ||||||
|         The message ID. |         The message ID. | ||||||
| @@ -1601,11 +1604,11 @@ class PartialMessage(Hashable): | |||||||
|     to_reference = Message.to_reference |     to_reference = Message.to_reference | ||||||
|     to_message_reference_dict = Message.to_message_reference_dict |     to_message_reference_dict = Message.to_message_reference_dict | ||||||
|  |  | ||||||
|     def __init__(self, *, channel: Union[TextChannel, DMChannel], id: int): |     def __init__(self, *, channel: PartialMessageableChannel, id: int): | ||||||
|         if channel.type not in (ChannelType.text, ChannelType.news, ChannelType.private): |         if channel.type not in (ChannelType.text, ChannelType.news, ChannelType.private): | ||||||
|             raise TypeError(f'Expected TextChannel or DMChannel not {type(channel)!r}') |             raise TypeError(f'Expected TextChannel or DMChannel not {type(channel)!r}') | ||||||
|  |  | ||||||
|         self.channel: Union[TextChannel, DMChannel] = channel |         self.channel: PartialMessageableChannel = channel | ||||||
|         self._state: ConnectionState = channel._state |         self._state: ConnectionState = channel._state | ||||||
|         self.id: int = id |         self.id: int = id | ||||||
|  |  | ||||||
|   | |||||||
| @@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE. | |||||||
| """ | """ | ||||||
|  |  | ||||||
| from __future__ import annotations | from __future__ import annotations | ||||||
|  |  | ||||||
| from typing import Callable, Dict, Iterable, List, Optional, Union, TYPE_CHECKING | from typing import Callable, Dict, Iterable, List, Optional, Union, TYPE_CHECKING | ||||||
| import time | import time | ||||||
| import asyncio | import asyncio | ||||||
| @@ -48,7 +49,7 @@ if TYPE_CHECKING: | |||||||
|     from .guild import Guild |     from .guild import Guild | ||||||
|     from .channel import TextChannel |     from .channel import TextChannel | ||||||
|     from .member import Member |     from .member import Member | ||||||
|     from .message import Message |     from .message import Message, PartialMessage | ||||||
|     from .abc import Snowflake, SnowflakeTime |     from .abc import Snowflake, SnowflakeTime | ||||||
|     from .role import Role |     from .role import Role | ||||||
|     from .permissions import Permissions |     from .permissions import Permissions | ||||||
| @@ -191,6 +192,7 @@ class Thread(Messageable, Hashable): | |||||||
|             self._unroll_metadata(data['thread_metadata']) |             self._unroll_metadata(data['thread_metadata']) | ||||||
|         except KeyError: |         except KeyError: | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def type(self) -> ChannelType: |     def type(self) -> ChannelType: | ||||||
|         """:class:`ChannelType`: The channel's Discord type.""" |         """:class:`ChannelType`: The channel's Discord type.""" | ||||||
| @@ -626,6 +628,29 @@ class Thread(Messageable, Hashable): | |||||||
|         """ |         """ | ||||||
|         await self._state.http.delete_channel(self.id) |         await self._state.http.delete_channel(self.id) | ||||||
|  |  | ||||||
|  |     def get_partial_message(self, message_id: int, /) -> PartialMessage: | ||||||
|  |         """Creates a :class:`PartialMessage` from the message ID. | ||||||
|  |  | ||||||
|  |         This is useful if you want to work with a message and only have its ID without | ||||||
|  |         doing an unnecessary API call. | ||||||
|  |  | ||||||
|  |         .. versionadded:: 2.0 | ||||||
|  |  | ||||||
|  |         Parameters | ||||||
|  |         ------------ | ||||||
|  |         message_id: :class:`int` | ||||||
|  |             The message ID to create a partial message for. | ||||||
|  |  | ||||||
|  |         Returns | ||||||
|  |         --------- | ||||||
|  |         :class:`PartialMessage` | ||||||
|  |             The partial message. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         from .message import PartialMessage | ||||||
|  |  | ||||||
|  |         return PartialMessage(channel=self, id=message_id) | ||||||
|  |  | ||||||
|     def _add_member(self, member: ThreadMember) -> None: |     def _add_member(self, member: ThreadMember) -> None: | ||||||
|         self._members[member.id] = member |         self._members[member.id] = member | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user