From 44d1d297080d2536fc432aa7841be69f941ccb8b Mon Sep 17 00:00:00 2001 From: Rapptz Date: Tue, 29 Jun 2021 23:55:19 -0400 Subject: [PATCH] Add explicit types to variables in Message types --- discord/message.py | 86 +++++++++++++++++++++++++--------------------- 1 file changed, 47 insertions(+), 39 deletions(-) diff --git a/discord/message.py b/discord/message.py index 7d6cdaa09..e641028b6 100644 --- a/discord/message.py +++ b/discord/message.py @@ -65,9 +65,12 @@ if TYPE_CHECKING: from .types.embed import Embed as EmbedPayload from .abc import Snowflake from .abc import GuildChannel + from .components import Component from .state import ConnectionState from .channel import TextChannel, GroupChannel, DMChannel from .mentions import AllowedMentions + from .user import User + from .role import Role from .ui.view import View EmojiInputType = Union[Emoji, PartialEmoji, str] @@ -149,15 +152,15 @@ class Attachment(Hashable): __slots__ = ('id', 'size', 'height', 'width', 'filename', 'url', 'proxy_url', '_http', 'content_type') def __init__(self, *, data: AttachmentPayload, state: ConnectionState): - self.id = int(data['id']) - self.size = data['size'] - self.height = data.get('height') - self.width = data.get('width') - self.filename = data['filename'] - self.url = data.get('url') - self.proxy_url = data.get('proxy_url') + self.id: int = int(data['id']) + self.size: int = data['size'] + self.height: Optional[int] = data.get('height') + self.width: Optional[int] = data.get('width') + self.filename: str = data['filename'] + self.url: str = data.get('url') + self.proxy_url: str = data.get('proxy_url') self._http = state.http - self.content_type = data.get('content_type') + self.content_type: Optional[str] = data.get('content_type') def is_spoiler(self) -> bool: """:class:`bool`: Whether this attachment contains a spoiler.""" @@ -327,7 +330,7 @@ class DeletedReferencedMessage: __slots__ = ('_parent',) def __init__(self, parent: MessageReference): - self._parent = parent + self._parent: MessageReference = parent def __repr__(self) -> str: return f"" @@ -387,10 +390,10 @@ class MessageReference: def __init__(self, *, message_id: int, channel_id: int, guild_id: Optional[int] = None, fail_if_not_exists: bool = True): self._state: Optional[ConnectionState] = None self.resolved: Optional[Union[Message, DeletedReferencedMessage]] = None - self.message_id = message_id - self.channel_id = channel_id - self.guild_id = guild_id - self.fail_if_not_exists = fail_if_not_exists + self.message_id: int = message_id + self.channel_id: int = channel_id + self.guild_id: Optional[int] = guild_id + self.fail_if_not_exists: bool = fail_if_not_exists @classmethod def with_state(cls, state: ConnectionState, data: MessageReferencePayload) -> MessageReference: @@ -509,7 +512,7 @@ class Message(Hashable): private channel or the user has the left the guild, then it is a :class:`User` instead. content: :class:`str` The actual contents of the message. - nonce: Union[:class:`str`, :class:`int`] + nonce: Optional[Union[:class:`str`, :class:`int`]] The value used by the discord guild and the client to verify that the message is successfully sent. This is not stored long term within Discord's servers and is only used ephemerally. embeds: List[:class:`Embed`] @@ -630,6 +633,11 @@ class Message(Hashable): if TYPE_CHECKING: _HANDLERS: ClassVar[List[Tuple[str, Callable[..., None]]]] _CACHED_SLOTS: ClassVar[List[str]] + guild: Optional[Guild] + ref: Optional[MessageReference] + mentions: List[Union[User, Member]] + author: Union[User, Member] + role_mentions: List[Role] def __init__( self, @@ -638,28 +646,28 @@ class Message(Hashable): channel: Union[TextChannel, Thread, DMChannel, GroupChannel], data: MessagePayload, ): - self._state = state - self.id = int(data['id']) - self.webhook_id = utils._get_as_snowflake(data, 'webhook_id') - self.reactions = [Reaction(message=self, data=d) for d in data.get('reactions', [])] - self.attachments = [Attachment(data=a, state=self._state) for a in data['attachments']] - self.embeds = [Embed.from_dict(a) for a in data['embeds']] - self.application = data.get('application') - self.activity = data.get('activity') - self.channel = channel - self._edited_timestamp = utils.parse_time(data['edited_timestamp']) - self.type = try_enum(MessageType, data['type']) - self.pinned = data['pinned'] - self.flags = MessageFlags._from_value(data.get('flags', 0)) - self.mention_everyone = data['mention_everyone'] - self.tts = data['tts'] - self.content = data['content'] - self.nonce = data.get('nonce') - self.stickers = [Sticker(data=d, state=state) for d in data.get('stickers', [])] - self.components = [_component_factory(d) for d in data.get('components', [])] + self._state: ConnectionState = state + self.id: int = int(data['id']) + self.webhook_id: Optional[int] = utils._get_as_snowflake(data, 'webhook_id') + self.reactions: List[Reaction] = [Reaction(message=self, data=d) for d in data.get('reactions', [])] + self.attachments: List[Attachment] = [Attachment(data=a, state=self._state) for a in data['attachments']] + self.embeds: List[Embed] = [Embed.from_dict(a) for a in data['embeds']] + self.application: Optional[MessageApplicationPayload] = data.get('application') + self.activity: Optional[MessageActivityPayload] = data.get('activity') + self.channel: Union[TextChannel, Thread, DMChannel, GroupChannel] = channel + self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data['edited_timestamp']) + self.type: MessageType = try_enum(MessageType, data['type']) + self.pinned: bool = data['pinned'] + self.flags: MessageFlags = MessageFlags._from_value(data.get('flags', 0)) + self.mention_everyone: bool = data['mention_everyone'] + self.tts: bool = data['tts'] + self.content: str = data['content'] + self.nonce: Optional[Union[int, str]] = data.get('nonce') + self.stickers: List[Sticker] = [Sticker(data=d, state=state) for d in data.get('stickers', [])] + self.components: List[Component] = [_component_factory(d) for d in data.get('components', [])] try: - self.guild = channel.guild + self.guild = channel.guild # type: ignore except AttributeError: self.guild = state._get_guild(utils._get_as_snowflake(data, 'guild_id')) @@ -685,7 +693,7 @@ class Message(Hashable): ref.resolved = self.__class__(channel=chan, data=resolved, state=state) - for handler in ('author', 'member', 'mentions', 'mention_roles', 'flags'): + for handler in ('author', 'member', 'mentions', 'mention_roles'): try: getattr(self, f'_handle_{handler}')(data[handler]) except KeyError: @@ -775,7 +783,7 @@ class Message(Hashable): def _handle_edited_timestamp(self, value: str) -> None: self._edited_timestamp = utils.parse_time(value) - def _handle_pinned(self, value: int) -> None: + def _handle_pinned(self, value: bool) -> None: self.pinned = value def _handle_flags(self, value: int) -> None: @@ -1589,9 +1597,9 @@ class PartialMessage(Hashable): if channel.type not in (ChannelType.text, ChannelType.news, ChannelType.private): raise TypeError(f'Expected TextChannel or DMChannel not {type(channel)!r}') - self.channel = channel - self._state = channel._state - self.id = id + self.channel: Union[TextChannel, DMChannel] = channel + self._state: ConnectionState = channel._state + self.id: int = id def _update(self, data) -> None: # This is used for duck typing purposes.