Implement a least breaking approach to slash commands (#39)
* Most slash command support completed, needs some debugging (and reindent) * Implement a ctx.send helper for slash commands * Add group command support * Add Option converter, fix default optional, fix help command * Add client.setup and move readying commands to that * Implement _FakeSlashMessage.from_interaction * Rename normmal_command to message_command * Add docs for added params * Add slash_command_guilds to bot and decos * Fix merge conflict * Remove name from commands.Option, wasn't used * Move slash command processing to BotBase.process_slash_commands * Create slash_only.py Basic example for slash commands * Create slash_and_message.py Basic example for mixed commands * Fix slash_command and normal_command bools * Add some basic error handling for registration * Fixed converter upload errors * Fix some logic and make an actual example * Thanks Safety Jim * docstrings, *args, and error changes * Add proper literal support * Add basic documentation on slash commands * Fix non-slash command interactions * Fix ctx.reply in slash command context * Fix typing on Context.reply * Fix multiple optional argument sorting * Update ctx.message docs to mention error instead of warning * Move slash command creation to BotBase * Fix code style issues with Black * Rearrange some stuff and add flag support * Change some errors and fix interaction.channel fixing * Fix slash command quoting for *args Co-authored-by: iDutchy <42503862+iDutchy@users.noreply.github.com> Co-authored-by: Lint Action <lint-action@samuelmeuli.com>
This commit is contained in:
@@ -28,18 +28,43 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import collections
|
||||
import collections.abc
|
||||
|
||||
import inspect
|
||||
import importlib.util
|
||||
import sys
|
||||
import traceback
|
||||
import types
|
||||
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
|
||||
from collections import defaultdict
|
||||
from discord.http import HTTPClient
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Iterable,
|
||||
Tuple,
|
||||
cast,
|
||||
Mapping,
|
||||
List,
|
||||
Dict,
|
||||
TYPE_CHECKING,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import discord
|
||||
from discord.types.interactions import (
|
||||
ApplicationCommandInteractionData,
|
||||
ApplicationCommandInteractionDataOption,
|
||||
EditApplicationCommand,
|
||||
_ApplicationCommandInteractionDataOptionString,
|
||||
)
|
||||
|
||||
from .core import GroupMixin
|
||||
from .view import StringView
|
||||
from .converter import Greedy
|
||||
from .view import StringView, supported_quotes
|
||||
from .context import Context
|
||||
from .flags import FlagConverter
|
||||
from . import errors
|
||||
from .help import HelpCommand, DefaultHelpCommand
|
||||
from .cog import Cog
|
||||
@@ -67,6 +92,23 @@ CFT = TypeVar("CFT", bound="CoroFunc")
|
||||
CXT = TypeVar("CXT", bound="Context")
|
||||
|
||||
|
||||
class _FakeSlashMessage(discord.PartialMessage):
|
||||
activity = application = edited_at = reference = webhook_id = None
|
||||
attachments = components = reactions = stickers = mentions = []
|
||||
author: Union[discord.User, discord.Member]
|
||||
tts = False
|
||||
|
||||
@classmethod
|
||||
def from_interaction(
|
||||
cls, interaction: discord.Interaction, channel: Union[discord.TextChannel, discord.DMChannel, discord.Thread]
|
||||
):
|
||||
self = cls(channel=channel, id=interaction.id)
|
||||
assert interaction.user is not None
|
||||
self.author = interaction.user
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
|
||||
"""A callable that implements a command prefix equivalent to being mentioned.
|
||||
|
||||
@@ -118,6 +160,35 @@ def _is_submodule(parent: str, child: str) -> bool:
|
||||
return parent == child or child.startswith(parent + ".")
|
||||
|
||||
|
||||
def _unwrap_slash_groups(
|
||||
data: ApplicationCommandInteractionData,
|
||||
) -> Tuple[str, List[ApplicationCommandInteractionDataOption]]:
|
||||
command_name = data["name"]
|
||||
command_options = data.get("options") or []
|
||||
while any(o["type"] in {1, 2} for o in command_options): # type: ignore
|
||||
for option in command_options: # type: ignore
|
||||
if option["type"] in {1, 2}: # type: ignore
|
||||
command_name += f' {option["name"]}' # type: ignore
|
||||
command_options = option.get("options") or []
|
||||
|
||||
return command_name, command_options
|
||||
|
||||
|
||||
def _quote_string_safe(string: str) -> str:
|
||||
# we need to quote this string otherwise we may spill into
|
||||
# other parameters and cause all kinds of trouble, as many
|
||||
# quotes are supported and some may be in the option, we
|
||||
# loop through all supported quotes and if neither open or
|
||||
# close are in the string, we add them
|
||||
for open, close in supported_quotes.items():
|
||||
if open not in string and close not in string:
|
||||
return f"{open}{string}{close}"
|
||||
|
||||
# all supported quotes are in the message and we cannot add any
|
||||
# safely, very unlikely but still got to be covered
|
||||
raise errors.UnexpectedQuoteError(string)
|
||||
|
||||
|
||||
class _DefaultRepr:
|
||||
def __repr__(self):
|
||||
return "<default-help-command>"
|
||||
@@ -127,9 +198,22 @@ _default = _DefaultRepr()
|
||||
|
||||
|
||||
class BotBase(GroupMixin):
|
||||
def __init__(self, command_prefix, help_command=_default, description=None, *, intents: discord.Intents, **options):
|
||||
def __init__(
|
||||
self,
|
||||
command_prefix,
|
||||
help_command=_default,
|
||||
description=None,
|
||||
*,
|
||||
intents: discord.Intents,
|
||||
message_commands: bool = True,
|
||||
slash_commands: bool = False,
|
||||
**options,
|
||||
):
|
||||
super().__init__(**options, intents=intents)
|
||||
|
||||
self.command_prefix = command_prefix
|
||||
self.slash_commands = slash_commands
|
||||
self.message_commands = message_commands
|
||||
self.extra_events: Dict[str, List[CoroFunc]] = {}
|
||||
self.__cogs: Dict[str, Cog] = {}
|
||||
self.__extensions: Dict[str, types.ModuleType] = {}
|
||||
@@ -142,6 +226,7 @@ class BotBase(GroupMixin):
|
||||
self.owner_id = options.get("owner_id")
|
||||
self.owner_ids = options.get("owner_ids", set())
|
||||
self.strip_after_prefix = options.get("strip_after_prefix", False)
|
||||
self.slash_command_guilds: Optional[Iterable[int]] = options.get("slash_command_guilds", None)
|
||||
|
||||
if self.owner_id and self.owner_ids:
|
||||
raise TypeError("Both owner_id and owner_ids are set.")
|
||||
@@ -149,6 +234,9 @@ class BotBase(GroupMixin):
|
||||
if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection):
|
||||
raise TypeError(f"owner_ids must be a collection not {self.owner_ids.__class__!r}")
|
||||
|
||||
if not (message_commands or slash_commands):
|
||||
raise ValueError("Both message_commands and slash_commands are disabled.")
|
||||
|
||||
if help_command is _default:
|
||||
self.help_command = DefaultHelpCommand()
|
||||
else:
|
||||
@@ -163,6 +251,55 @@ class BotBase(GroupMixin):
|
||||
for event in self.extra_events.get(ev, []):
|
||||
self._schedule_event(event, ev, *args, **kwargs) # type: ignore
|
||||
|
||||
async def setup(self):
|
||||
await self.create_slash_commands()
|
||||
|
||||
async def create_slash_commands(self):
|
||||
commands: defaultdict[Optional[int], List[EditApplicationCommand]] = defaultdict(list)
|
||||
for command in self.commands:
|
||||
if command.hidden or (command.slash_command is None and not self.slash_commands):
|
||||
continue
|
||||
|
||||
try:
|
||||
payload = command.to_application_command()
|
||||
except Exception:
|
||||
raise errors.ApplicationCommandRegistrationError(command)
|
||||
|
||||
if payload is None:
|
||||
continue
|
||||
|
||||
guilds = command.slash_command_guilds or self.slash_command_guilds
|
||||
if guilds is None:
|
||||
commands[None].append(payload)
|
||||
else:
|
||||
for guild in guilds:
|
||||
commands[guild].append(payload)
|
||||
|
||||
http: HTTPClient = self.http # type: ignore
|
||||
global_commands = commands.pop(None, None)
|
||||
application_id = self.application_id or (await self.application_info()).id # type: ignore
|
||||
if global_commands is not None:
|
||||
if self.slash_command_guilds is None:
|
||||
await http.bulk_upsert_global_commands(
|
||||
payload=global_commands,
|
||||
application_id=application_id,
|
||||
)
|
||||
else:
|
||||
for guild in self.slash_command_guilds:
|
||||
await http.bulk_upsert_guild_commands(
|
||||
guild_id=guild,
|
||||
payload=global_commands,
|
||||
application_id=application_id,
|
||||
)
|
||||
|
||||
for guild, guild_commands in commands.items():
|
||||
assert guild is not None
|
||||
await http.bulk_upsert_guild_commands(
|
||||
guild_id=guild,
|
||||
payload=guild_commands,
|
||||
application_id=application_id,
|
||||
)
|
||||
|
||||
@discord.utils.copy_doc(discord.Client.close)
|
||||
async def close(self) -> None:
|
||||
for extension in tuple(self.__extensions):
|
||||
@@ -1084,9 +1221,97 @@ class BotBase(GroupMixin):
|
||||
ctx = await self.get_context(message)
|
||||
await self.invoke(ctx)
|
||||
|
||||
async def process_slash_commands(self, interaction: discord.Interaction):
|
||||
"""|coro|
|
||||
|
||||
This function processes a slash command interaction into a usable
|
||||
message and calls :meth:`.process_commands` based on it. Without this
|
||||
coroutine slash commands will not be triggered.
|
||||
|
||||
By default, this coroutine is called inside the :func:`.on_interaction`
|
||||
event. If you choose to override the :func:`.on_interaction` event,
|
||||
then you should invoke this coroutine as well.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
interaction: :class:`discord.Interaction`
|
||||
The interaction to process slash commands for.
|
||||
|
||||
"""
|
||||
if interaction.type != discord.InteractionType.application_command:
|
||||
return
|
||||
|
||||
interaction.data = cast(ApplicationCommandInteractionData, interaction.data)
|
||||
command_name, command_options = _unwrap_slash_groups(interaction.data)
|
||||
|
||||
command = self.get_command(command_name)
|
||||
if command is None:
|
||||
raise errors.CommandNotFound(f'Command "{command_name}" is not found')
|
||||
|
||||
# Ensure the interaction channel is usable
|
||||
channel = interaction.channel
|
||||
if channel is None or isinstance(channel, discord.PartialMessageable):
|
||||
if interaction.guild is None:
|
||||
assert interaction.user is not None
|
||||
channel = await interaction.user.create_dm()
|
||||
elif interaction.channel_id is not None:
|
||||
channel = await interaction.guild.fetch_channel(interaction.channel_id)
|
||||
else:
|
||||
return # cannot do anything without stable channel
|
||||
|
||||
# Fetch a valid prefix, so process_commands can function
|
||||
message: discord.Message = _FakeSlashMessage.from_interaction(interaction, channel) # type: ignore
|
||||
prefix = await self.get_prefix(message)
|
||||
if isinstance(prefix, list):
|
||||
prefix = prefix[0]
|
||||
|
||||
# Add arguments to fake message content, in the right order
|
||||
ignore_params: List[inspect.Parameter] = []
|
||||
message.content = f"{prefix}{command_name} "
|
||||
for name, param in command.clean_params.items():
|
||||
if inspect.isclass(param.annotation) and issubclass(param.annotation, FlagConverter):
|
||||
for name, flag in param.annotation.get_flags().items():
|
||||
option = next((o for o in command_options if o["name"] == name), None)
|
||||
|
||||
if option is None:
|
||||
if flag.required:
|
||||
raise errors.MissingRequiredFlag(flag)
|
||||
else:
|
||||
prefix = param.annotation.__commands_flag_prefix__
|
||||
delimiter = param.annotation.__commands_flag_delimiter__
|
||||
message.content += f"{prefix}{name} {option['value']}{delimiter}" # type: ignore
|
||||
continue
|
||||
|
||||
option = next((o for o in command_options if o["name"] == name), None)
|
||||
if option is None:
|
||||
if param.default is param.empty and not command._is_typing_optional(param.annotation):
|
||||
raise errors.MissingRequiredArgument(param)
|
||||
else:
|
||||
ignore_params.append(param)
|
||||
elif (
|
||||
option["type"] == 3
|
||||
and not isinstance(param.annotation, Greedy)
|
||||
and param.kind in {param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY}
|
||||
):
|
||||
# String with space in without "consume rest"
|
||||
option = cast(_ApplicationCommandInteractionDataOptionString, option)
|
||||
message.content += f"{_quote_string_safe(option['value'])} "
|
||||
else:
|
||||
message.content += f'{option.get("value", "")} '
|
||||
|
||||
ctx = await self.get_context(message)
|
||||
ctx._ignored_params = ignore_params
|
||||
ctx.interaction = interaction
|
||||
await self.invoke(ctx)
|
||||
|
||||
async def on_message(self, message):
|
||||
await self.process_commands(message)
|
||||
|
||||
async def on_interaction(self, interaction: discord.Interaction):
|
||||
await self.process_slash_commands(interaction)
|
||||
|
||||
|
||||
class Bot(BotBase, discord.Client):
|
||||
"""Represents a discord bot.
|
||||
@@ -1157,6 +1382,28 @@ class Bot(BotBase, discord.Client):
|
||||
the ``command_prefix`` is set to ``!``. Defaults to ``False``.
|
||||
|
||||
.. versionadded:: 1.7
|
||||
message_commands: Optional[:class:`bool`]
|
||||
Whether to process commands based on messages.
|
||||
|
||||
Can be overwritten per command in the command decorators or when making
|
||||
a :class:`Command` object via the ``message_command`` parameter
|
||||
|
||||
.. versionadded:: 2.0
|
||||
slash_commands: Optional[:class:`bool`]
|
||||
Whether to upload and process slash commands.
|
||||
|
||||
Can be overwritten per command in the command decorators or when making
|
||||
a :class:`Command` object via the ``slash_command`` parameter
|
||||
|
||||
.. versionadded:: 2.0
|
||||
slash_command_guilds: Optional[:class:`List[int]`]
|
||||
If this is set, only upload slash commands to these guild IDs.
|
||||
|
||||
Can be overwritten per command in the command decorators or when making
|
||||
a :class:`Command` object via the ``slash_command_guilds`` parameter
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
"""
|
||||
|
||||
pass
|
||||
|
Reference in New Issue
Block a user