Move slash command processing to BotBase.process_slash_commands
This commit is contained in:
parent
84b1d7d0cd
commit
caa5f39c0f
@ -36,16 +36,17 @@ import sys
|
|||||||
import traceback
|
import traceback
|
||||||
import types
|
import types
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Callable, Iterable, cast, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
|
from typing import Any, Callable, Iterable, Tuple, cast, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.types.interactions import (
|
from discord.types.interactions import (
|
||||||
ApplicationCommandInteractionData,
|
ApplicationCommandInteractionData,
|
||||||
|
ApplicationCommandInteractionDataOption,
|
||||||
EditApplicationCommand,
|
EditApplicationCommand,
|
||||||
_ApplicationCommandInteractionDataOptionString
|
_ApplicationCommandInteractionDataOptionString
|
||||||
)
|
)
|
||||||
|
|
||||||
from .core import GroupMixin
|
from .core import Command, GroupMixin
|
||||||
from .converter import Greedy
|
from .converter import Greedy
|
||||||
from .view import StringView, supported_quotes
|
from .view import StringView, supported_quotes
|
||||||
from .context import Context
|
from .context import Context
|
||||||
@ -136,6 +137,18 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
|
|||||||
def _is_submodule(parent: str, child: str) -> bool:
|
def _is_submodule(parent: str, child: str) -> bool:
|
||||||
return parent == child or child.startswith(parent + ".")
|
return parent == child or child.startswith(parent + ".")
|
||||||
|
|
||||||
|
def _unwrap_slash_groups(data: ApplicationCommandInteractionData) -> Tuple[str, List[ApplicationCommandInteractionDataOption]]:
|
||||||
|
command_name = data['name']
|
||||||
|
command_options = data.get('options') or []
|
||||||
|
while any(o["type"] in {1, 2} for o in command_options): # type: ignore
|
||||||
|
for option in command_options: # type: ignore
|
||||||
|
if option['type'] in {1, 2}: # type: ignore
|
||||||
|
command_name += f' {option["name"]}' # type: ignore
|
||||||
|
command_options = option.get('options') or []
|
||||||
|
|
||||||
|
return command_name, command_options
|
||||||
|
|
||||||
|
|
||||||
class _DefaultRepr:
|
class _DefaultRepr:
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<default-help-command>'
|
return '<default-help-command>'
|
||||||
@ -1110,21 +1123,22 @@ class BotBase(GroupMixin):
|
|||||||
ctx = await self.get_context(message)
|
ctx = await self.get_context(message)
|
||||||
await self.invoke(ctx)
|
await self.invoke(ctx)
|
||||||
|
|
||||||
async def on_message(self, message):
|
|
||||||
if self.message_commands:
|
|
||||||
await self.process_commands(message)
|
|
||||||
|
|
||||||
async def on_interaction(self, interaction: discord.Interaction):
|
async def process_slash_commands(self, interaction: discord.Interaction):
|
||||||
if not self.slash_commands or interaction.type != discord.InteractionType.application_command:
|
|
||||||
return
|
|
||||||
|
|
||||||
assert interaction.user is not None
|
|
||||||
interaction.data = cast(ApplicationCommandInteractionData, interaction.data)
|
interaction.data = cast(ApplicationCommandInteractionData, interaction.data)
|
||||||
|
command_name, command_options = _unwrap_slash_groups(interaction.data)
|
||||||
|
|
||||||
|
command = self.get_command(command_name)
|
||||||
|
if command is None:
|
||||||
|
raise errors.CommandNotFound(f'Command "{command_name}" is not found')
|
||||||
|
elif not command.slash_command:
|
||||||
|
return
|
||||||
|
|
||||||
# Ensure the interaction channel is usable
|
# Ensure the interaction channel is usable
|
||||||
channel = interaction.channel
|
channel = interaction.channel
|
||||||
if channel is None or isinstance(channel, discord.PartialMessageable):
|
if channel is None or isinstance(channel, discord.PartialMessageable):
|
||||||
if interaction.guild is None:
|
if interaction.guild is None:
|
||||||
|
assert interaction.user is not None
|
||||||
channel = await interaction.user.create_dm()
|
channel = await interaction.user.create_dm()
|
||||||
elif interaction.channel_id is not None:
|
elif interaction.channel_id is not None:
|
||||||
channel = await interaction.guild.fetch_channel(interaction.channel_id)
|
channel = await interaction.guild.fetch_channel(interaction.channel_id)
|
||||||
@ -1134,19 +1148,6 @@ class BotBase(GroupMixin):
|
|||||||
interaction.channel = channel # type: ignore
|
interaction.channel = channel # type: ignore
|
||||||
del channel
|
del channel
|
||||||
|
|
||||||
# Fetch out subcommands from the options
|
|
||||||
command_name = interaction.data['name']
|
|
||||||
command_options = interaction.data.get('options') or []
|
|
||||||
while any(o["type"] in {1, 2} for o in command_options):
|
|
||||||
for option in command_options:
|
|
||||||
if option['type'] in {1, 2}:
|
|
||||||
command_name += f' {option["name"]}'
|
|
||||||
command_options = option.get('options') or []
|
|
||||||
|
|
||||||
command = self.get_command(command_name)
|
|
||||||
if command is None:
|
|
||||||
raise errors.CommandNotFound(f'Command "{command_name}" is not found')
|
|
||||||
|
|
||||||
# Fetch a valid prefix, so process_commands can function
|
# Fetch a valid prefix, so process_commands can function
|
||||||
message = _FakeSlashMessage.from_interaction(interaction)
|
message = _FakeSlashMessage.from_interaction(interaction)
|
||||||
prefix = await self.get_prefix(message)
|
prefix = await self.get_prefix(message)
|
||||||
@ -1157,7 +1158,6 @@ class BotBase(GroupMixin):
|
|||||||
message.content = f'{prefix}{command_name} '
|
message.content = f'{prefix}{command_name} '
|
||||||
for name, param in command.clean_params.items():
|
for name, param in command.clean_params.items():
|
||||||
option = next((o for o in command_options if o['name'] == name), None) # type: ignore
|
option = next((o for o in command_options if o['name'] == name), None) # type: ignore
|
||||||
|
|
||||||
if option is None:
|
if option is None:
|
||||||
if param.default is param.empty and not command._is_typing_optional(param.annotation):
|
if param.default is param.empty and not command._is_typing_optional(param.annotation):
|
||||||
raise errors.MissingRequiredArgument(param)
|
raise errors.MissingRequiredArgument(param)
|
||||||
@ -1178,7 +1178,7 @@ class BotBase(GroupMixin):
|
|||||||
quoted = False
|
quoted = False
|
||||||
string = option['value']
|
string = option['value']
|
||||||
for open, close in supported_quotes.items():
|
for open, close in supported_quotes.items():
|
||||||
if not (open in string or close in string):
|
if open not in string and close not in string:
|
||||||
message.content += f"{open}{string}{close} "
|
message.content += f"{open}{string}{close} "
|
||||||
quoted = True
|
quoted = True
|
||||||
break
|
break
|
||||||
@ -1195,6 +1195,15 @@ class BotBase(GroupMixin):
|
|||||||
await self.invoke(ctx)
|
await self.invoke(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
async def on_message(self, message):
|
||||||
|
if self.message_commands:
|
||||||
|
await self.process_commands(message)
|
||||||
|
|
||||||
|
async def on_interaction(self, interaction: discord.Interaction):
|
||||||
|
if self.slash_commands and interaction.type == discord.InteractionType.application_command:
|
||||||
|
await self.process_slash_commands(interaction)
|
||||||
|
|
||||||
|
|
||||||
class Bot(BotBase, discord.Client):
|
class Bot(BotBase, discord.Client):
|
||||||
"""Represents a discord bot.
|
"""Represents a discord bot.
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user