Implement a least breaking approach to slash commands (#39)

* Most slash command support completed, needs some debugging (and reindent)

* Implement a ctx.send helper for slash commands

* Add group command support

* Add Option converter, fix default optional, fix help command

* Add client.setup and move readying commands to that

* Implement _FakeSlashMessage.from_interaction

* Rename normmal_command to message_command

* Add docs for added params

* Add slash_command_guilds to bot and decos

* Fix merge conflict

* Remove name from commands.Option, wasn't used

* Move slash command processing to BotBase.process_slash_commands

* Create slash_only.py

Basic example for slash commands

* Create slash_and_message.py

Basic example for mixed commands

* Fix slash_command and normal_command bools

* Add some basic error handling for registration

* Fixed converter upload errors

* Fix some logic and make an actual example

* Thanks Safety Jim

* docstrings, *args, and error changes

* Add proper literal support

* Add basic documentation on slash commands

* Fix non-slash command interactions

* Fix ctx.reply in slash command context

* Fix typing on Context.reply

* Fix multiple optional argument sorting

* Update ctx.message docs to mention error instead of warning

* Move slash command creation to BotBase

* Fix code style issues with Black

* Rearrange some stuff and add flag support

* Change some errors and fix interaction.channel fixing

* Fix slash command quoting for *args

Co-authored-by: iDutchy <42503862+iDutchy@users.noreply.github.com>
Co-authored-by: Lint Action <lint-action@samuelmeuli.com>
This commit is contained in:
Gnome!
2021-09-19 00:28:11 +01:00
committed by GitHub
parent 75a23351c4
commit 1957fa6011
14 changed files with 662 additions and 39 deletions

View File

@@ -28,18 +28,43 @@ from __future__ import annotations
import asyncio
import collections
import collections.abc
import inspect
import importlib.util
import sys
import traceback
import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
from collections import defaultdict
from discord.http import HTTPClient
from typing import (
Any,
Callable,
Iterable,
Tuple,
cast,
Mapping,
List,
Dict,
TYPE_CHECKING,
Optional,
TypeVar,
Type,
Union,
)
import discord
from discord.types.interactions import (
ApplicationCommandInteractionData,
ApplicationCommandInteractionDataOption,
EditApplicationCommand,
_ApplicationCommandInteractionDataOptionString,
)
from .core import GroupMixin
from .view import StringView
from .converter import Greedy
from .view import StringView, supported_quotes
from .context import Context
from .flags import FlagConverter
from . import errors
from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog
@@ -67,6 +92,23 @@ CFT = TypeVar("CFT", bound="CoroFunc")
CXT = TypeVar("CXT", bound="Context")
class _FakeSlashMessage(discord.PartialMessage):
activity = application = edited_at = reference = webhook_id = None
attachments = components = reactions = stickers = mentions = []
author: Union[discord.User, discord.Member]
tts = False
@classmethod
def from_interaction(
cls, interaction: discord.Interaction, channel: Union[discord.TextChannel, discord.DMChannel, discord.Thread]
):
self = cls(channel=channel, id=interaction.id)
assert interaction.user is not None
self.author = interaction.user
return self
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned.
@@ -118,6 +160,35 @@ def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".")
def _unwrap_slash_groups(
data: ApplicationCommandInteractionData,
) -> Tuple[str, List[ApplicationCommandInteractionDataOption]]:
command_name = data["name"]
command_options = data.get("options") or []
while any(o["type"] in {1, 2} for o in command_options): # type: ignore
for option in command_options: # type: ignore
if option["type"] in {1, 2}: # type: ignore
command_name += f' {option["name"]}' # type: ignore
command_options = option.get("options") or []
return command_name, command_options
def _quote_string_safe(string: str) -> str:
# we need to quote this string otherwise we may spill into
# other parameters and cause all kinds of trouble, as many
# quotes are supported and some may be in the option, we
# loop through all supported quotes and if neither open or
# close are in the string, we add them
for open, close in supported_quotes.items():
if open not in string and close not in string:
return f"{open}{string}{close}"
# all supported quotes are in the message and we cannot add any
# safely, very unlikely but still got to be covered
raise errors.UnexpectedQuoteError(string)
class _DefaultRepr:
def __repr__(self):
return "<default-help-command>"
@@ -127,9 +198,22 @@ _default = _DefaultRepr()
class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, *, intents: discord.Intents, **options):
def __init__(
self,
command_prefix,
help_command=_default,
description=None,
*,
intents: discord.Intents,
message_commands: bool = True,
slash_commands: bool = False,
**options,
):
super().__init__(**options, intents=intents)
self.command_prefix = command_prefix
self.slash_commands = slash_commands
self.message_commands = message_commands
self.extra_events: Dict[str, List[CoroFunc]] = {}
self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {}
@@ -142,6 +226,7 @@ class BotBase(GroupMixin):
self.owner_id = options.get("owner_id")
self.owner_ids = options.get("owner_ids", set())
self.strip_after_prefix = options.get("strip_after_prefix", False)
self.slash_command_guilds: Optional[Iterable[int]] = options.get("slash_command_guilds", None)
if self.owner_id and self.owner_ids:
raise TypeError("Both owner_id and owner_ids are set.")
@@ -149,6 +234,9 @@ class BotBase(GroupMixin):
if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection):
raise TypeError(f"owner_ids must be a collection not {self.owner_ids.__class__!r}")
if not (message_commands or slash_commands):
raise ValueError("Both message_commands and slash_commands are disabled.")
if help_command is _default:
self.help_command = DefaultHelpCommand()
else:
@@ -163,6 +251,55 @@ class BotBase(GroupMixin):
for event in self.extra_events.get(ev, []):
self._schedule_event(event, ev, *args, **kwargs) # type: ignore
async def setup(self):
await self.create_slash_commands()
async def create_slash_commands(self):
commands: defaultdict[Optional[int], List[EditApplicationCommand]] = defaultdict(list)
for command in self.commands:
if command.hidden or (command.slash_command is None and not self.slash_commands):
continue
try:
payload = command.to_application_command()
except Exception:
raise errors.ApplicationCommandRegistrationError(command)
if payload is None:
continue
guilds = command.slash_command_guilds or self.slash_command_guilds
if guilds is None:
commands[None].append(payload)
else:
for guild in guilds:
commands[guild].append(payload)
http: HTTPClient = self.http # type: ignore
global_commands = commands.pop(None, None)
application_id = self.application_id or (await self.application_info()).id # type: ignore
if global_commands is not None:
if self.slash_command_guilds is None:
await http.bulk_upsert_global_commands(
payload=global_commands,
application_id=application_id,
)
else:
for guild in self.slash_command_guilds:
await http.bulk_upsert_guild_commands(
guild_id=guild,
payload=global_commands,
application_id=application_id,
)
for guild, guild_commands in commands.items():
assert guild is not None
await http.bulk_upsert_guild_commands(
guild_id=guild,
payload=guild_commands,
application_id=application_id,
)
@discord.utils.copy_doc(discord.Client.close)
async def close(self) -> None:
for extension in tuple(self.__extensions):
@@ -1084,9 +1221,97 @@ class BotBase(GroupMixin):
ctx = await self.get_context(message)
await self.invoke(ctx)
async def process_slash_commands(self, interaction: discord.Interaction):
"""|coro|
This function processes a slash command interaction into a usable
message and calls :meth:`.process_commands` based on it. Without this
coroutine slash commands will not be triggered.
By default, this coroutine is called inside the :func:`.on_interaction`
event. If you choose to override the :func:`.on_interaction` event,
then you should invoke this coroutine as well.
.. versionadded:: 2.0
Parameters
-----------
interaction: :class:`discord.Interaction`
The interaction to process slash commands for.
"""
if interaction.type != discord.InteractionType.application_command:
return
interaction.data = cast(ApplicationCommandInteractionData, interaction.data)
command_name, command_options = _unwrap_slash_groups(interaction.data)
command = self.get_command(command_name)
if command is None:
raise errors.CommandNotFound(f'Command "{command_name}" is not found')
# Ensure the interaction channel is usable
channel = interaction.channel
if channel is None or isinstance(channel, discord.PartialMessageable):
if interaction.guild is None:
assert interaction.user is not None
channel = await interaction.user.create_dm()
elif interaction.channel_id is not None:
channel = await interaction.guild.fetch_channel(interaction.channel_id)
else:
return # cannot do anything without stable channel
# Fetch a valid prefix, so process_commands can function
message: discord.Message = _FakeSlashMessage.from_interaction(interaction, channel) # type: ignore
prefix = await self.get_prefix(message)
if isinstance(prefix, list):
prefix = prefix[0]
# Add arguments to fake message content, in the right order
ignore_params: List[inspect.Parameter] = []
message.content = f"{prefix}{command_name} "
for name, param in command.clean_params.items():
if inspect.isclass(param.annotation) and issubclass(param.annotation, FlagConverter):
for name, flag in param.annotation.get_flags().items():
option = next((o for o in command_options if o["name"] == name), None)
if option is None:
if flag.required:
raise errors.MissingRequiredFlag(flag)
else:
prefix = param.annotation.__commands_flag_prefix__
delimiter = param.annotation.__commands_flag_delimiter__
message.content += f"{prefix}{name} {option['value']}{delimiter}" # type: ignore
continue
option = next((o for o in command_options if o["name"] == name), None)
if option is None:
if param.default is param.empty and not command._is_typing_optional(param.annotation):
raise errors.MissingRequiredArgument(param)
else:
ignore_params.append(param)
elif (
option["type"] == 3
and not isinstance(param.annotation, Greedy)
and param.kind in {param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY}
):
# String with space in without "consume rest"
option = cast(_ApplicationCommandInteractionDataOptionString, option)
message.content += f"{_quote_string_safe(option['value'])} "
else:
message.content += f'{option.get("value", "")} '
ctx = await self.get_context(message)
ctx._ignored_params = ignore_params
ctx.interaction = interaction
await self.invoke(ctx)
async def on_message(self, message):
await self.process_commands(message)
async def on_interaction(self, interaction: discord.Interaction):
await self.process_slash_commands(interaction)
class Bot(BotBase, discord.Client):
"""Represents a discord bot.
@@ -1157,6 +1382,28 @@ class Bot(BotBase, discord.Client):
the ``command_prefix`` is set to ``!``. Defaults to ``False``.
.. versionadded:: 1.7
message_commands: Optional[:class:`bool`]
Whether to process commands based on messages.
Can be overwritten per command in the command decorators or when making
a :class:`Command` object via the ``message_command`` parameter
.. versionadded:: 2.0
slash_commands: Optional[:class:`bool`]
Whether to upload and process slash commands.
Can be overwritten per command in the command decorators or when making
a :class:`Command` object via the ``slash_command`` parameter
.. versionadded:: 2.0
slash_command_guilds: Optional[:class:`List[int]`]
If this is set, only upload slash commands to these guild IDs.
Can be overwritten per command in the command decorators or when making
a :class:`Command` object via the ``slash_command_guilds`` parameter
.. versionadded:: 2.0
"""
pass