diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 2bda7313..72ece175 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -36,16 +36,17 @@ import sys import traceback import types from collections import defaultdict -from typing import Any, Callable, Iterable, cast, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union +from typing import Any, Callable, Iterable, Tuple, cast, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union import discord from discord.types.interactions import ( ApplicationCommandInteractionData, + ApplicationCommandInteractionDataOption, EditApplicationCommand, _ApplicationCommandInteractionDataOptionString ) -from .core import GroupMixin +from .core import Command, GroupMixin from .converter import Greedy from .view import StringView, supported_quotes from .context import Context @@ -136,6 +137,18 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M def _is_submodule(parent: str, child: str) -> bool: return parent == child or child.startswith(parent + ".") +def _unwrap_slash_groups(data: ApplicationCommandInteractionData) -> Tuple[str, List[ApplicationCommandInteractionDataOption]]: + command_name = data['name'] + command_options = data.get('options') or [] + while any(o["type"] in {1, 2} for o in command_options): # type: ignore + for option in command_options: # type: ignore + if option['type'] in {1, 2}: # type: ignore + command_name += f' {option["name"]}' # type: ignore + command_options = option.get('options') or [] + + return command_name, command_options + + class _DefaultRepr: def __repr__(self): return '' @@ -1110,21 +1123,22 @@ class BotBase(GroupMixin): ctx = await self.get_context(message) await self.invoke(ctx) - async def on_message(self, message): - if self.message_commands: - await self.process_commands(message) - async def on_interaction(self, interaction: discord.Interaction): - if not self.slash_commands or interaction.type != discord.InteractionType.application_command: - return - - assert interaction.user is not None + async def process_slash_commands(self, interaction: discord.Interaction): interaction.data = cast(ApplicationCommandInteractionData, interaction.data) + command_name, command_options = _unwrap_slash_groups(interaction.data) + + command = self.get_command(command_name) + if command is None: + raise errors.CommandNotFound(f'Command "{command_name}" is not found') + elif not command.slash_command: + return # Ensure the interaction channel is usable channel = interaction.channel if channel is None or isinstance(channel, discord.PartialMessageable): if interaction.guild is None: + assert interaction.user is not None channel = await interaction.user.create_dm() elif interaction.channel_id is not None: channel = await interaction.guild.fetch_channel(interaction.channel_id) @@ -1134,19 +1148,6 @@ class BotBase(GroupMixin): interaction.channel = channel # type: ignore del channel - # Fetch out subcommands from the options - command_name = interaction.data['name'] - command_options = interaction.data.get('options') or [] - while any(o["type"] in {1, 2} for o in command_options): - for option in command_options: - if option['type'] in {1, 2}: - command_name += f' {option["name"]}' - command_options = option.get('options') or [] - - command = self.get_command(command_name) - if command is None: - raise errors.CommandNotFound(f'Command "{command_name}" is not found') - # Fetch a valid prefix, so process_commands can function message = _FakeSlashMessage.from_interaction(interaction) prefix = await self.get_prefix(message) @@ -1157,7 +1158,6 @@ class BotBase(GroupMixin): message.content = f'{prefix}{command_name} ' for name, param in command.clean_params.items(): option = next((o for o in command_options if o['name'] == name), None) # type: ignore - if option is None: if param.default is param.empty and not command._is_typing_optional(param.annotation): raise errors.MissingRequiredArgument(param) @@ -1178,7 +1178,7 @@ class BotBase(GroupMixin): quoted = False string = option['value'] for open, close in supported_quotes.items(): - if not (open in string or close in string): + if open not in string and close not in string: message.content += f"{open}{string}{close} " quoted = True break @@ -1195,6 +1195,15 @@ class BotBase(GroupMixin): await self.invoke(ctx) + async def on_message(self, message): + if self.message_commands: + await self.process_commands(message) + + async def on_interaction(self, interaction: discord.Interaction): + if self.slash_commands and interaction.type == discord.InteractionType.application_command: + await self.process_slash_commands(interaction) + + class Bot(BotBase, discord.Client): """Represents a discord bot.