Refactor internal message sending and editing parameter passing

This reduces some repetition in many functions and is ripped out of
the webhook code. This also removes the unused HTTP functions for
interaction responses since those belong in the webhook code rather
than the HTTPClient.
This commit is contained in:
Rapptz
2022-02-18 06:59:34 -05:00
parent 4248bb3717
commit 9c066a8cf6
6 changed files with 231 additions and 530 deletions

View File

@@ -35,14 +35,15 @@ from typing import (
Iterable,
List,
Literal,
NamedTuple,
Optional,
overload,
Sequence,
TYPE_CHECKING,
Tuple,
TYPE_CHECKING,
Type,
TypeVar,
Union,
overload,
)
from urllib.parse import quote as _uriquote
import weakref
@@ -58,6 +59,11 @@ _log = logging.getLogger(__name__)
if TYPE_CHECKING:
from .file import File
from .ui.view import View
from .embeds import Embed
from .mentions import AllowedMentions
from .message import Attachment
from .flags import MessageFlags
from .enums import (
AuditLogAction,
InteractionResponseType,
@@ -110,6 +116,149 @@ async def json_or_text(response: aiohttp.ClientResponse) -> Union[Dict[str, Any]
return text
class MultipartParameters(NamedTuple):
payload: Optional[Dict[str, Any]]
multipart: Optional[List[Dict[str, Any]]]
files: Optional[List[File]]
def __enter__(self):
return self
def __exit__(
self,
exc_type: Optional[Type[BE]],
exc: Optional[BE],
traceback: Optional[TracebackType],
) -> None:
if self.files:
for file in self.files:
file.close()
def handle_message_parameters(
content: Optional[str] = MISSING,
*,
username: str = MISSING,
avatar_url: Any = MISSING,
tts: bool = False,
nonce: Optional[Union[int, str]] = None,
flags: MessageFlags = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
embed: Optional[Embed] = MISSING,
embeds: List[Embed] = MISSING,
attachments: List[Attachment] = MISSING,
view: Optional[View] = MISSING,
allowed_mentions: Optional[AllowedMentions] = MISSING,
message_reference: Optional[message.MessageReference] = MISSING,
stickers: Optional[SnowflakeList] = MISSING,
previous_allowed_mentions: Optional[AllowedMentions] = None,
mention_author: Optional[bool] = None,
) -> MultipartParameters:
if files is not MISSING and file is not MISSING:
raise TypeError('Cannot mix file and files keyword arguments.')
if embeds is not MISSING and embed is not MISSING:
raise TypeError('Cannot mix embed and embeds keyword arguments.')
payload = {}
if embeds is not MISSING:
if len(embeds) > 10:
raise InvalidArgument('embeds has a maximum of 10 elements.')
payload['embeds'] = [e.to_dict() for e in embeds]
if embed is not MISSING:
if embed is None:
payload['embeds'] = []
else:
payload['embeds'] = [embed.to_dict()]
if content is not MISSING:
if content is not None:
payload['content'] = str(content)
else:
payload['content'] = None
if view is not MISSING:
if view is not None:
payload['components'] = view.to_components()
else:
payload['components'] = []
if nonce is not MISSING:
payload['nonce'] = str(nonce)
if message_reference is not MISSING:
payload['message_reference'] = message_reference
if attachments is not MISSING:
# Note: This will be overwritten if file or files is provided
# However, right now this is only passed via Message.edit not Messageable.send
payload['attachments'] = [a.to_dict() for a in attachments]
if stickers is not MISSING:
if stickers is not None:
payload['sticker_ids'] = stickers
else:
payload['sticker_ids'] = []
payload['tts'] = tts
if avatar_url:
payload['avatar_url'] = str(avatar_url)
if username:
payload['username'] = username
if flags is not MISSING:
payload['flags'] = flags.value
if allowed_mentions:
if previous_allowed_mentions is not None:
payload['allowed_mentions'] = previous_allowed_mentions.merge(allowed_mentions).to_dict()
else:
payload['allowed_mentions'] = allowed_mentions.to_dict()
elif previous_allowed_mentions is not None:
payload['allowed_mentions'] = previous_allowed_mentions.to_dict()
if mention_author is not None:
try:
payload['allowed_mentions']['replied_user'] = mention_author
except KeyError:
pass
multipart = []
if file is not MISSING:
files = [file]
if files:
for index, file in enumerate(files):
attachments_payload = []
for index, file in enumerate(files):
attachment = {
'id': index,
'filename': file.filename,
}
if file.description is not None:
attachment['description'] = file.description
attachments_payload.append(attachment)
payload['attachments'] = attachments_payload
multipart.append({'name': 'payload_json', 'value': utils._to_json(payload)})
payload = None
for index, file in enumerate(files):
multipart.append(
{
'name': f'files[{index}]',
'value': file.fp,
'filename': file.filename,
'content_type': 'application/octet-stream',
}
)
return MultipartParameters(payload=payload, multipart=multipart, files=files)
class Route:
BASE: ClassVar[str] = 'https://discord.com/api/v8'
@@ -268,7 +417,7 @@ class HTTPClient:
if form:
# with quote_fields=True '[' and ']' in file field names are escaped, which discord does not support
form_data = aiohttp.FormData(quote_fields=False)
form_data = aiohttp.FormData(quote_fields=False)
for params in form:
form_data.add_field(**params)
kwargs['data'] = form_data
@@ -417,144 +566,18 @@ class HTTPClient:
def send_message(
self,
channel_id: Snowflake,
content: Optional[str],
*,
tts: bool = False,
embed: Optional[embed.Embed] = None,
embeds: Optional[List[embed.Embed]] = None,
nonce: Optional[str] = None,
allowed_mentions: Optional[message.AllowedMentions] = None,
message_reference: Optional[message.MessageReference] = None,
stickers: Optional[List[sticker.StickerItem]] = None,
components: Optional[List[components.Component]] = None,
params: MultipartParameters,
) -> Response[message.Message]:
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
payload = {}
if content:
payload['content'] = content
if tts:
payload['tts'] = True
if embed:
payload['embeds'] = [embed]
if embeds:
payload['embeds'] = embeds
if nonce:
payload['nonce'] = nonce
if allowed_mentions:
payload['allowed_mentions'] = allowed_mentions
if message_reference:
payload['message_reference'] = message_reference
if components:
payload['components'] = components
if stickers:
payload['sticker_ids'] = stickers
return self.request(r, json=payload)
if params.files:
return self.request(r, files=params.files, form=params.multipart)
else:
return self.request(r, json=params.payload)
def send_typing(self, channel_id: Snowflake) -> Response[None]:
return self.request(Route('POST', '/channels/{channel_id}/typing', channel_id=channel_id))
def send_multipart_helper(
self,
route: Route,
*,
files: Sequence[File],
content: Optional[str] = None,
tts: bool = False,
embed: Optional[embed.Embed] = None,
embeds: Optional[Iterable[Optional[embed.Embed]]] = None,
nonce: Optional[str] = None,
allowed_mentions: Optional[message.AllowedMentions] = None,
message_reference: Optional[message.MessageReference] = None,
stickers: Optional[List[sticker.StickerItem]] = None,
components: Optional[List[components.Component]] = None,
) -> Response[message.Message]:
form = []
payload: Dict[str, Any] = {'tts': tts}
if content:
payload['content'] = content
if embed:
payload['embeds'] = [embed]
if embeds:
payload['embeds'] = embeds
if nonce:
payload['nonce'] = nonce
if allowed_mentions:
payload['allowed_mentions'] = allowed_mentions
if message_reference:
payload['message_reference'] = message_reference
if components:
payload['components'] = components
if stickers:
payload['sticker_ids'] = stickers
if files:
attachments = []
for index, file in enumerate(files):
attachment = {
"id": index,
"filename": file.filename,
}
if file.description is not None:
attachment["description"] = file.description
attachments.append(attachment)
payload['attachments'] = attachments
form.append({'name': 'payload_json', 'value': utils._to_json(payload)})
for index, file in enumerate(files):
form.append(
{
'name': f'files[{index}]',
'value': file.fp,
'filename': file.filename,
'content_type': 'image/png',
}
)
return self.request(route, form=form, files=files)
def send_files(
self,
channel_id: Snowflake,
*,
files: Sequence[File],
content: Optional[str] = None,
tts: bool = False,
embed: Optional[embed.Embed] = None,
embeds: Optional[List[embed.Embed]] = None,
nonce: Optional[str] = None,
allowed_mentions: Optional[message.AllowedMentions] = None,
message_reference: Optional[message.MessageReference] = None,
stickers: Optional[List[sticker.StickerItem]] = None,
components: Optional[List[components.Component]] = None,
) -> Response[message.Message]:
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
return self.send_multipart_helper(
r,
files=files,
content=content,
tts=tts,
embed=embed,
embeds=embeds,
nonce=nonce,
allowed_mentions=allowed_mentions,
message_reference=message_reference,
stickers=stickers,
components=components,
)
def delete_message(
self, channel_id: Snowflake, message_id: Snowflake, *, reason: Optional[str] = None
) -> Response[None]:
@@ -571,9 +594,9 @@ class HTTPClient:
return self.request(r, json=payload, reason=reason)
def edit_message(self, channel_id: Snowflake, message_id: Snowflake, **fields: Any) -> Response[message.Message]:
def edit_message(self, channel_id: Snowflake, message_id: Snowflake, *, params: MultipartParameters) -> Response[message.Message]:
r = Route('PATCH', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id)
return self.request(r, json=fields)
return self.request(r, json=params.payload)
def add_reaction(self, channel_id: Snowflake, message_id: Snowflake, emoji: str) -> Response[None]:
r = Route(
@@ -1241,7 +1264,11 @@ class HTTPClient:
)
def modify_guild_sticker(
self, guild_id: Snowflake, sticker_id: Snowflake, payload: sticker.EditGuildSticker, reason: Optional[str],
self,
guild_id: Snowflake,
sticker_id: Snowflake,
payload: sticker.EditGuildSticker,
reason: Optional[str],
) -> Response[sticker.GuildSticker]:
return self.request(
Route('PATCH', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id),
@@ -1706,9 +1733,7 @@ class HTTPClient:
def get_global_commands(self, application_id: Snowflake) -> Response[List[command.ApplicationCommand]]:
return self.request(Route('GET', '/applications/{application_id}/commands', application_id=application_id))
def get_global_command(
self, application_id: Snowflake, command_id: Snowflake
) -> Response[command.ApplicationCommand]:
def get_global_command(self, application_id: Snowflake, command_id: Snowflake) -> Response[command.ApplicationCommand]:
r = Route(
'GET',
'/applications/{application_id}/commands/{command_id}',
@@ -1750,9 +1775,7 @@ class HTTPClient:
)
return self.request(r)
def bulk_upsert_global_commands(
self, application_id: Snowflake, payload
) -> Response[List[command.ApplicationCommand]]:
def bulk_upsert_global_commands(self, application_id: Snowflake, payload) -> Response[List[command.ApplicationCommand]]:
r = Route('PUT', '/applications/{application_id}/commands', application_id=application_id)
return self.request(r, json=payload)
@@ -1849,160 +1872,6 @@ class HTTPClient:
)
return self.request(r, json=payload)
# Interaction responses
def _edit_webhook_helper(
self,
route: Route,
file: Optional[File] = None,
content: Optional[str] = None,
embeds: Optional[List[embed.Embed]] = None,
allowed_mentions: Optional[message.AllowedMentions] = None,
):
payload: Dict[str, Any] = {}
if content:
payload['content'] = content
if embeds:
payload['embeds'] = embeds
if allowed_mentions:
payload['allowed_mentions'] = allowed_mentions
form: List[Dict[str, Any]] = [
{
'name': 'payload_json',
'value': utils._to_json(payload),
}
]
if file:
form.append(
{
'name': 'file',
'value': file.fp,
'filename': file.filename,
'content_type': 'application/octet-stream',
}
)
return self.request(route, form=form, files=[file] if file else None)
def create_interaction_response(
self,
interaction_id: Snowflake,
token: str,
*,
type: InteractionResponseType,
data: Optional[Dict[str, Any]] = None,
) -> Response[None]:
r = Route(
'POST',
'/interactions/{interaction_id}/{interaction_token}/callback',
interaction_id=interaction_id,
interaction_token=token,
)
payload: Dict[str, Any] = {
'type': type,
}
if data is not None:
payload['data'] = data
return self.request(r, json=payload)
def get_original_interaction_response(
self,
application_id: Snowflake,
token: str,
) -> Response[message.Message]:
r = Route(
'GET',
'/webhooks/{application_id}/{interaction_token}/messages/@original',
application_id=application_id,
interaction_token=token,
)
return self.request(r)
def edit_original_interaction_response(
self,
application_id: Snowflake,
token: str,
file: Optional[File] = None,
content: Optional[str] = None,
embeds: Optional[List[embed.Embed]] = None,
allowed_mentions: Optional[message.AllowedMentions] = None,
) -> Response[message.Message]:
r = Route(
'PATCH',
'/webhooks/{application_id}/{interaction_token}/messages/@original',
application_id=application_id,
interaction_token=token,
)
return self._edit_webhook_helper(r, file=file, content=content, embeds=embeds, allowed_mentions=allowed_mentions)
def delete_original_interaction_response(self, application_id: Snowflake, token: str) -> Response[None]:
r = Route(
'DELETE',
'/webhooks/{application_id}/{interaction_token}/messages/@original',
application_id=application_id,
interaction_token=token,
)
return self.request(r)
def create_followup_message(
self,
application_id: Snowflake,
token: str,
files: List[File] = [],
content: Optional[str] = None,
tts: bool = False,
embeds: Optional[List[embed.Embed]] = None,
allowed_mentions: Optional[message.AllowedMentions] = None,
) -> Response[message.Message]:
r = Route(
'POST',
'/webhooks/{application_id}/{interaction_token}',
application_id=application_id,
interaction_token=token,
)
return self.send_multipart_helper(
r,
content=content,
files=files,
tts=tts,
embeds=embeds,
allowed_mentions=allowed_mentions,
)
def edit_followup_message(
self,
application_id: Snowflake,
token: str,
message_id: Snowflake,
file: Optional[File] = None,
content: Optional[str] = None,
embeds: Optional[List[embed.Embed]] = None,
allowed_mentions: Optional[message.AllowedMentions] = None,
) -> Response[message.Message]:
r = Route(
'PATCH',
'/webhooks/{application_id}/{interaction_token}/messages/{message_id}',
application_id=application_id,
interaction_token=token,
message_id=message_id,
)
return self._edit_webhook_helper(r, file=file, content=content, embeds=embeds, allowed_mentions=allowed_mentions)
def delete_followup_message(self, application_id: Snowflake, token: str, message_id: Snowflake) -> Response[None]:
r = Route(
'DELETE',
'/webhooks/{application_id}/{interaction_token}/messages/{message_id}',
application_id=application_id,
interaction_token=token,
message_id=message_id,
)
return self.request(r)
def get_guild_application_command_permissions(
self,
application_id: Snowflake,