Add support for context menu commands

This commit is contained in:
Rapptz
2022-02-26 10:51:16 -05:00
parent 0d2db90028
commit dffd72da58
4 changed files with 465 additions and 98 deletions

View File

@@ -24,12 +24,12 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import inspect
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
from typing import Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, Union, overload
from .namespace import Namespace
from .models import AppCommand
from .commands import Command, Group, _shorten
from .commands import Command, ContextMenu, Group, _shorten
from .enums import AppCommandType
from .errors import CommandAlreadyRegistered, CommandNotFound, CommandSignatureMismatch
from ..errors import ClientException
@@ -40,7 +40,7 @@ if TYPE_CHECKING:
from ..interactions import Interaction
from ..client import Client
from ..abc import Snowflake
from .commands import CommandCallback, P, T
from .commands import ContextMenuCallback, CommandCallback, P, T
__all__ = ('CommandTree',)
@@ -65,7 +65,7 @@ class CommandTree:
# The above two mappings can use this structure too but we need fast retrieval
# by name and guild_id in the above case while here it isn't as important since
# it's uncommon and N=5 anyway.
self._context_menus: Dict[Tuple[str, Optional[int], int], Command] = {}
self._context_menus: Dict[Tuple[str, Optional[int], int], ContextMenu] = {}
async def fetch_commands(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]:
"""|coro|
@@ -75,6 +75,10 @@ class CommandTree:
If no guild is passed then global commands are fetched, otherwise
the guild's commands are fetched instead.
.. note::
This includes context menu commands.
Parameters
-----------
guild: Optional[:class:`abc.Snowflake`]
@@ -103,7 +107,14 @@ class CommandTree:
return [AppCommand(data=data, state=self._state) for data in commands]
def add_command(self, command: Union[Command, Group], /, *, guild: Optional[Snowflake] = None, override: bool = False):
def add_command(
self,
command: Union[Command, ContextMenu, Group],
/,
*,
guild: Optional[Snowflake] = None,
override: bool = False,
):
"""Adds an application command to the tree.
This only adds the command locally -- in order to sync the commands
@@ -133,7 +144,20 @@ class CommandTree:
This is currently 100 for slash commands and 5 for context menu commands.
"""
if not isinstance(command, (Command, Group)):
if isinstance(command, ContextMenu):
guild_id = None if guild is None else guild.id
type = command.type.value
key = (command.name, guild_id, type)
found = key in self._context_menus
if found and not override:
raise CommandAlreadyRegistered(command.name, guild_id)
total = sum(1 for _, g, t in self._context_menus if g == guild_id and t == type)
if total + found > 5:
raise ValueError('maximum number of context menu commands exceeded (5)')
self._context_menus[key] = command
return
elif not isinstance(command, (Command, Group)):
raise TypeError(f'Expected a application command, received {command.__class__!r} instead')
# todo: validate application command groups having children (required)
@@ -156,7 +180,36 @@ class CommandTree:
raise ValueError('maximum number of slash commands exceeded (100)')
self._global_commands[name] = root
def remove_command(self, command: str, /, *, guild: Optional[Snowflake] = None) -> Optional[Union[Command, Group]]:
@overload
def remove_command(
self,
command: str,
/,
*,
guild: Optional[Snowflake] = ...,
type: Literal[AppCommandType.message, AppCommandType.user] = ...,
) -> Optional[ContextMenu]:
...
@overload
def remove_command(
self,
command: str,
/,
*,
guild: Optional[Snowflake] = ...,
type: Literal[AppCommandType.chat_input] = ...,
) -> Optional[Union[Command, Group]]:
...
def remove_command(
self,
command: str,
/,
*,
guild: Optional[Snowflake] = None,
type: AppCommandType = AppCommandType.chat_input,
) -> Optional[Union[Command, ContextMenu, Group]]:
"""Removes an application command from the tree.
This only removes the command locally -- in order to sync the commands
@@ -169,31 +222,64 @@ class CommandTree:
guild: Optional[:class:`abc.Snowflake`]
The guild to remove the command from. If not given then it
removes a global command instead.
type: :class:`AppCommandType`
The type of command to remove. Defaults to :attr:`AppCommandType.chat_input`,
i.e. slash commands.
Returns
---------
Optional[Union[:class:`Command`, :class:`Group`]]
Optional[Union[:class:`Command`, :class:`ContextMenu`, :class:`Group`]]
The application command that got removed.
If nothing was removed then ``None`` is returned instead.
"""
if guild is None:
return self._global_commands.pop(command, None)
else:
try:
commands = self._guild_commands[guild.id]
except KeyError:
return None
if type is AppCommandType.chat_input:
if guild is None:
return self._global_commands.pop(command, None)
else:
return commands.pop(command, None)
try:
commands = self._guild_commands[guild.id]
except KeyError:
return None
else:
return commands.pop(command, None)
elif type in (AppCommandType.user, AppCommandType.message):
guild_id = None if guild is None else guild.id
key = (command, guild_id, type.value)
return self._context_menus.pop(key, None)
def get_command(self, command: str, /, *, guild: Optional[Snowflake] = None) -> Optional[Union[Command, Group]]:
@overload
def get_command(
self,
command: str,
/,
*,
guild: Optional[Snowflake] = ...,
type: Literal[AppCommandType.message, AppCommandType.user] = ...,
) -> Optional[ContextMenu]:
...
@overload
def get_command(
self,
command: str,
/,
*,
guild: Optional[Snowflake] = ...,
type: Literal[AppCommandType.chat_input] = ...,
) -> Optional[Union[Command, Group]]:
...
def get_command(
self,
command: str,
/,
*,
guild: Optional[Snowflake] = None,
type: AppCommandType = AppCommandType.chat_input,
) -> Optional[Union[Command, ContextMenu, Group]]:
"""Gets a application command from the tree.
.. note::
This does *not* include context menu commands.
Parameters
-----------
command: :class:`str`
@@ -201,52 +287,103 @@ class CommandTree:
guild: Optional[:class:`abc.Snowflake`]
The guild to get the command from. If not given then it
gets a global command instead.
type: :class:`AppCommandType`
The type of command to get. Defaults to :attr:`AppCommandType.chat_input`,
i.e. slash commands.
Returns
---------
Optional[Union[:class:`Command`, :class:`Group`]]
Optional[Union[:class:`Command`, :class:`ContextMenu`, :class:`Group`]]
The application command that was found.
If nothing was found then ``None`` is returned instead.
"""
if guild is None:
return self._global_commands.get(command)
else:
try:
commands = self._guild_commands[guild.id]
except KeyError:
return None
if type is AppCommandType.chat_input:
if guild is None:
return self._global_commands.get(command)
else:
return commands.get(command)
try:
commands = self._guild_commands[guild.id]
except KeyError:
return None
else:
return commands.get(command)
elif type in (AppCommandType.user, AppCommandType.message):
guild_id = None if guild is None else guild.id
key = (command, guild_id, type.value)
return self._context_menus.get(key)
def get_commands(self, *, guild: Optional[Snowflake] = None) -> List[Union[Command, Group]]:
@overload
def get_commands(
self,
*,
guild: Optional[Snowflake] = ...,
type: Literal[AppCommandType.message, AppCommandType.user] = ...,
) -> List[ContextMenu]:
...
@overload
def get_commands(
self,
*,
guild: Optional[Snowflake] = ...,
type: Literal[AppCommandType.chat_input] = ...,
) -> List[Union[Command, Group]]:
...
def get_commands(
self,
*,
guild: Optional[Snowflake] = None,
type: AppCommandType = AppCommandType.chat_input,
) -> Union[List[Union[Command, Group]], List[ContextMenu]]:
"""Gets all application commands from the tree.
.. note::
This does *not* retrieve context menu commands.
Parameters
-----------
guild: Optional[:class:`~discord.abc.Snowflake`]
The guild to get the commands from. If not given then it
gets all global commands instead.
type: :class:`AppCommandType`
The type of commands to get. Defaults to :attr:`AppCommandType.chat_input`,
i.e. slash commands.
Returns
---------
List[Union[:class:`Command`, :class:`Group`]]
Union[List[:class:`ContextMenu`], List[Union[:class:`Command`, :class:`Group`]]
The application commands from the tree.
"""
if type is AppCommandType.chat_input:
if guild is None:
return list(self._global_commands.values())
else:
try:
commands = self._guild_commands[guild.id]
except KeyError:
return []
else:
return list(commands.values())
else:
guild_id = None if guild is None else guild.id
value = type.value
return [command for ((_, g, t), command) in self._context_menus.items() if g == guild_id and t == value]
def _get_all_commands(self, *, guild: Optional[Snowflake] = None) -> List[Union[Command, Group, ContextMenu]]:
if guild is None:
return list(self._global_commands.values())
base: List[Union[Command, Group, ContextMenu]] = list(self._global_commands.values())
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g is None)
return base
else:
try:
commands = self._guild_commands[guild.id]
except KeyError:
return []
return [cmd for ((_, g, _), cmd) in self._context_menus.items() if g is None]
else:
return list(commands.values())
base: List[Union[Command, Group, ContextMenu]] = list(commands.values())
guild_id = guild.id
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id)
return base
def command(
self,
@@ -266,7 +403,7 @@ class CommandTree:
The description of the application command. This shows up in the UI to describe
the application command. If not given, it defaults to the first line of the docstring
of the callback shortened to 100 characters.
guild: Optional[:class:`Snowflake`]
guild: Optional[:class:`.abc.Snowflake`]
The guild to add the command to. If not given then it
becomes a global command instead.
"""
@@ -287,7 +424,6 @@ class CommandTree:
name=name if name is not MISSING else func.__name__,
description=desc,
callback=func,
type=AppCommandType.chat_input,
parent=None,
)
self.add_command(command, guild=guild)
@@ -295,6 +431,49 @@ class CommandTree:
return decorator
def context_menu(
self, *, name: str = MISSING, guild: Optional[Snowflake] = None
) -> Callable[[ContextMenuCallback], ContextMenu]:
"""Creates a application command context menu from a regular function directly under this tree.
This function must have a signature of :class:`~discord.Interaction` as its first parameter
and taking either a :class:`~discord.Member`, :class:`~discord.User`, or :class:`~discord.Message`,
or a :obj:`typing.Union` of ``Member`` and ``User`` as its second parameter.
Examples
---------
.. code-block:: python3
@app_commands.context_menu()
async def react(interaction: discord.Interaction, message: discord.Message):
await interaction.response.send_message('Very cool message!', ephemeral=True)
@app_commands.context_menu()
async def ban(interaction: discord.Interaction, user: discord.Member):
await interaction.response.send_message(f'Should I actually ban {user}...', ephemeral=True)
Parameters
------------
name: :class:`str`
The name of the context menu command. If not given, it defaults to a title-case
version of the callback name. Note that unlike regular slash commands this can
have spaces and upper case characters in the name.
guild: Optional[:class:`.abc.Snowflake`]
The guild to add the command to. If not given then it
becomes a global command instead.
"""
def decorator(func: ContextMenuCallback) -> ContextMenu:
if not inspect.iscoroutinefunction(func):
raise TypeError('context menu function must be a coroutine function')
context_menu = ContextMenu._from_decorator(func, name=name)
self.add_command(context_menu, guild=guild)
return context_menu
return decorator
async def sync(self, *, guild: Optional[Snowflake]) -> List[AppCommand]:
"""|coro|
@@ -327,7 +506,7 @@ class CommandTree:
if self.client.application_id is None:
raise ClientException('Client does not have an application ID set')
commands = self.get_commands(guild=guild)
commands = self._get_all_commands(guild=guild)
payload = [command.to_dict() for command in commands]
if guild is None:
data = await self._http.bulk_upsert_global_commands(self.client.application_id, payload=payload)
@@ -345,6 +524,25 @@ class CommandTree:
self.client.loop.create_task(wrapper(), name='CommandTree-invoker')
async def _call_context_menu(self, interaction: Interaction, data: ApplicationCommandInteractionData, type: int):
name = data['name']
guild_id = interaction.guild_id
ctx_menu = self._context_menus.get((name, guild_id, type))
if ctx_menu is None:
raise CommandNotFound(name, [], AppCommandType(type))
resolved = Namespace._get_resolved_items(interaction, data.get('resolved', {}))
# This will always work at runtime
value = resolved.get(data.get('target_id')) # type: ignore
if ctx_menu.type.value != type:
raise CommandSignatureMismatch(ctx_menu)
if value is None:
raise RuntimeError('This should not happen if Discord sent well-formed data.')
# I assume I don't have to type check here.
await ctx_menu._invoke(interaction, value)
async def call(self, interaction: Interaction):
"""|coro|
@@ -367,6 +565,12 @@ class CommandTree:
application command definition.
"""
data: ApplicationCommandInteractionData = interaction.data # type: ignore
type = data.get('type', 1)
if type != 1:
# Context menu command...
await self._call_context_menu(interaction, data, type)
return
parents: List[str] = []
name = data['name']
command = self._global_commands.get(name)