mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-10-25 02:23:04 +00:00
Add support for context menu commands
This commit is contained in:
@@ -24,12 +24,12 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
|
||||
from typing import Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, Union, overload
|
||||
|
||||
|
||||
from .namespace import Namespace
|
||||
from .models import AppCommand
|
||||
from .commands import Command, Group, _shorten
|
||||
from .commands import Command, ContextMenu, Group, _shorten
|
||||
from .enums import AppCommandType
|
||||
from .errors import CommandAlreadyRegistered, CommandNotFound, CommandSignatureMismatch
|
||||
from ..errors import ClientException
|
||||
@@ -40,7 +40,7 @@ if TYPE_CHECKING:
|
||||
from ..interactions import Interaction
|
||||
from ..client import Client
|
||||
from ..abc import Snowflake
|
||||
from .commands import CommandCallback, P, T
|
||||
from .commands import ContextMenuCallback, CommandCallback, P, T
|
||||
|
||||
__all__ = ('CommandTree',)
|
||||
|
||||
@@ -65,7 +65,7 @@ class CommandTree:
|
||||
# The above two mappings can use this structure too but we need fast retrieval
|
||||
# by name and guild_id in the above case while here it isn't as important since
|
||||
# it's uncommon and N=5 anyway.
|
||||
self._context_menus: Dict[Tuple[str, Optional[int], int], Command] = {}
|
||||
self._context_menus: Dict[Tuple[str, Optional[int], int], ContextMenu] = {}
|
||||
|
||||
async def fetch_commands(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]:
|
||||
"""|coro|
|
||||
@@ -75,6 +75,10 @@ class CommandTree:
|
||||
If no guild is passed then global commands are fetched, otherwise
|
||||
the guild's commands are fetched instead.
|
||||
|
||||
.. note::
|
||||
|
||||
This includes context menu commands.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
guild: Optional[:class:`abc.Snowflake`]
|
||||
@@ -103,7 +107,14 @@ class CommandTree:
|
||||
|
||||
return [AppCommand(data=data, state=self._state) for data in commands]
|
||||
|
||||
def add_command(self, command: Union[Command, Group], /, *, guild: Optional[Snowflake] = None, override: bool = False):
|
||||
def add_command(
|
||||
self,
|
||||
command: Union[Command, ContextMenu, Group],
|
||||
/,
|
||||
*,
|
||||
guild: Optional[Snowflake] = None,
|
||||
override: bool = False,
|
||||
):
|
||||
"""Adds an application command to the tree.
|
||||
|
||||
This only adds the command locally -- in order to sync the commands
|
||||
@@ -133,7 +144,20 @@ class CommandTree:
|
||||
This is currently 100 for slash commands and 5 for context menu commands.
|
||||
"""
|
||||
|
||||
if not isinstance(command, (Command, Group)):
|
||||
if isinstance(command, ContextMenu):
|
||||
guild_id = None if guild is None else guild.id
|
||||
type = command.type.value
|
||||
key = (command.name, guild_id, type)
|
||||
found = key in self._context_menus
|
||||
if found and not override:
|
||||
raise CommandAlreadyRegistered(command.name, guild_id)
|
||||
|
||||
total = sum(1 for _, g, t in self._context_menus if g == guild_id and t == type)
|
||||
if total + found > 5:
|
||||
raise ValueError('maximum number of context menu commands exceeded (5)')
|
||||
self._context_menus[key] = command
|
||||
return
|
||||
elif not isinstance(command, (Command, Group)):
|
||||
raise TypeError(f'Expected a application command, received {command.__class__!r} instead')
|
||||
|
||||
# todo: validate application command groups having children (required)
|
||||
@@ -156,7 +180,36 @@ class CommandTree:
|
||||
raise ValueError('maximum number of slash commands exceeded (100)')
|
||||
self._global_commands[name] = root
|
||||
|
||||
def remove_command(self, command: str, /, *, guild: Optional[Snowflake] = None) -> Optional[Union[Command, Group]]:
|
||||
@overload
|
||||
def remove_command(
|
||||
self,
|
||||
command: str,
|
||||
/,
|
||||
*,
|
||||
guild: Optional[Snowflake] = ...,
|
||||
type: Literal[AppCommandType.message, AppCommandType.user] = ...,
|
||||
) -> Optional[ContextMenu]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def remove_command(
|
||||
self,
|
||||
command: str,
|
||||
/,
|
||||
*,
|
||||
guild: Optional[Snowflake] = ...,
|
||||
type: Literal[AppCommandType.chat_input] = ...,
|
||||
) -> Optional[Union[Command, Group]]:
|
||||
...
|
||||
|
||||
def remove_command(
|
||||
self,
|
||||
command: str,
|
||||
/,
|
||||
*,
|
||||
guild: Optional[Snowflake] = None,
|
||||
type: AppCommandType = AppCommandType.chat_input,
|
||||
) -> Optional[Union[Command, ContextMenu, Group]]:
|
||||
"""Removes an application command from the tree.
|
||||
|
||||
This only removes the command locally -- in order to sync the commands
|
||||
@@ -169,31 +222,64 @@ class CommandTree:
|
||||
guild: Optional[:class:`abc.Snowflake`]
|
||||
The guild to remove the command from. If not given then it
|
||||
removes a global command instead.
|
||||
type: :class:`AppCommandType`
|
||||
The type of command to remove. Defaults to :attr:`AppCommandType.chat_input`,
|
||||
i.e. slash commands.
|
||||
|
||||
Returns
|
||||
---------
|
||||
Optional[Union[:class:`Command`, :class:`Group`]]
|
||||
Optional[Union[:class:`Command`, :class:`ContextMenu`, :class:`Group`]]
|
||||
The application command that got removed.
|
||||
If nothing was removed then ``None`` is returned instead.
|
||||
"""
|
||||
|
||||
if guild is None:
|
||||
return self._global_commands.pop(command, None)
|
||||
else:
|
||||
try:
|
||||
commands = self._guild_commands[guild.id]
|
||||
except KeyError:
|
||||
return None
|
||||
if type is AppCommandType.chat_input:
|
||||
if guild is None:
|
||||
return self._global_commands.pop(command, None)
|
||||
else:
|
||||
return commands.pop(command, None)
|
||||
try:
|
||||
commands = self._guild_commands[guild.id]
|
||||
except KeyError:
|
||||
return None
|
||||
else:
|
||||
return commands.pop(command, None)
|
||||
elif type in (AppCommandType.user, AppCommandType.message):
|
||||
guild_id = None if guild is None else guild.id
|
||||
key = (command, guild_id, type.value)
|
||||
return self._context_menus.pop(key, None)
|
||||
|
||||
def get_command(self, command: str, /, *, guild: Optional[Snowflake] = None) -> Optional[Union[Command, Group]]:
|
||||
@overload
|
||||
def get_command(
|
||||
self,
|
||||
command: str,
|
||||
/,
|
||||
*,
|
||||
guild: Optional[Snowflake] = ...,
|
||||
type: Literal[AppCommandType.message, AppCommandType.user] = ...,
|
||||
) -> Optional[ContextMenu]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_command(
|
||||
self,
|
||||
command: str,
|
||||
/,
|
||||
*,
|
||||
guild: Optional[Snowflake] = ...,
|
||||
type: Literal[AppCommandType.chat_input] = ...,
|
||||
) -> Optional[Union[Command, Group]]:
|
||||
...
|
||||
|
||||
def get_command(
|
||||
self,
|
||||
command: str,
|
||||
/,
|
||||
*,
|
||||
guild: Optional[Snowflake] = None,
|
||||
type: AppCommandType = AppCommandType.chat_input,
|
||||
) -> Optional[Union[Command, ContextMenu, Group]]:
|
||||
"""Gets a application command from the tree.
|
||||
|
||||
.. note::
|
||||
|
||||
This does *not* include context menu commands.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
command: :class:`str`
|
||||
@@ -201,52 +287,103 @@ class CommandTree:
|
||||
guild: Optional[:class:`abc.Snowflake`]
|
||||
The guild to get the command from. If not given then it
|
||||
gets a global command instead.
|
||||
type: :class:`AppCommandType`
|
||||
The type of command to get. Defaults to :attr:`AppCommandType.chat_input`,
|
||||
i.e. slash commands.
|
||||
|
||||
Returns
|
||||
---------
|
||||
Optional[Union[:class:`Command`, :class:`Group`]]
|
||||
Optional[Union[:class:`Command`, :class:`ContextMenu`, :class:`Group`]]
|
||||
The application command that was found.
|
||||
If nothing was found then ``None`` is returned instead.
|
||||
"""
|
||||
|
||||
if guild is None:
|
||||
return self._global_commands.get(command)
|
||||
else:
|
||||
try:
|
||||
commands = self._guild_commands[guild.id]
|
||||
except KeyError:
|
||||
return None
|
||||
if type is AppCommandType.chat_input:
|
||||
if guild is None:
|
||||
return self._global_commands.get(command)
|
||||
else:
|
||||
return commands.get(command)
|
||||
try:
|
||||
commands = self._guild_commands[guild.id]
|
||||
except KeyError:
|
||||
return None
|
||||
else:
|
||||
return commands.get(command)
|
||||
elif type in (AppCommandType.user, AppCommandType.message):
|
||||
guild_id = None if guild is None else guild.id
|
||||
key = (command, guild_id, type.value)
|
||||
return self._context_menus.get(key)
|
||||
|
||||
def get_commands(self, *, guild: Optional[Snowflake] = None) -> List[Union[Command, Group]]:
|
||||
@overload
|
||||
def get_commands(
|
||||
self,
|
||||
*,
|
||||
guild: Optional[Snowflake] = ...,
|
||||
type: Literal[AppCommandType.message, AppCommandType.user] = ...,
|
||||
) -> List[ContextMenu]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_commands(
|
||||
self,
|
||||
*,
|
||||
guild: Optional[Snowflake] = ...,
|
||||
type: Literal[AppCommandType.chat_input] = ...,
|
||||
) -> List[Union[Command, Group]]:
|
||||
...
|
||||
|
||||
def get_commands(
|
||||
self,
|
||||
*,
|
||||
guild: Optional[Snowflake] = None,
|
||||
type: AppCommandType = AppCommandType.chat_input,
|
||||
) -> Union[List[Union[Command, Group]], List[ContextMenu]]:
|
||||
"""Gets all application commands from the tree.
|
||||
|
||||
.. note::
|
||||
|
||||
This does *not* retrieve context menu commands.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
guild: Optional[:class:`~discord.abc.Snowflake`]
|
||||
The guild to get the commands from. If not given then it
|
||||
gets all global commands instead.
|
||||
type: :class:`AppCommandType`
|
||||
The type of commands to get. Defaults to :attr:`AppCommandType.chat_input`,
|
||||
i.e. slash commands.
|
||||
|
||||
Returns
|
||||
---------
|
||||
List[Union[:class:`Command`, :class:`Group`]]
|
||||
Union[List[:class:`ContextMenu`], List[Union[:class:`Command`, :class:`Group`]]
|
||||
The application commands from the tree.
|
||||
"""
|
||||
|
||||
if type is AppCommandType.chat_input:
|
||||
if guild is None:
|
||||
return list(self._global_commands.values())
|
||||
else:
|
||||
try:
|
||||
commands = self._guild_commands[guild.id]
|
||||
except KeyError:
|
||||
return []
|
||||
else:
|
||||
return list(commands.values())
|
||||
else:
|
||||
guild_id = None if guild is None else guild.id
|
||||
value = type.value
|
||||
return [command for ((_, g, t), command) in self._context_menus.items() if g == guild_id and t == value]
|
||||
|
||||
def _get_all_commands(self, *, guild: Optional[Snowflake] = None) -> List[Union[Command, Group, ContextMenu]]:
|
||||
if guild is None:
|
||||
return list(self._global_commands.values())
|
||||
base: List[Union[Command, Group, ContextMenu]] = list(self._global_commands.values())
|
||||
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g is None)
|
||||
return base
|
||||
else:
|
||||
try:
|
||||
commands = self._guild_commands[guild.id]
|
||||
except KeyError:
|
||||
return []
|
||||
return [cmd for ((_, g, _), cmd) in self._context_menus.items() if g is None]
|
||||
else:
|
||||
return list(commands.values())
|
||||
base: List[Union[Command, Group, ContextMenu]] = list(commands.values())
|
||||
guild_id = guild.id
|
||||
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id)
|
||||
return base
|
||||
|
||||
def command(
|
||||
self,
|
||||
@@ -266,7 +403,7 @@ class CommandTree:
|
||||
The description of the application command. This shows up in the UI to describe
|
||||
the application command. If not given, it defaults to the first line of the docstring
|
||||
of the callback shortened to 100 characters.
|
||||
guild: Optional[:class:`Snowflake`]
|
||||
guild: Optional[:class:`.abc.Snowflake`]
|
||||
The guild to add the command to. If not given then it
|
||||
becomes a global command instead.
|
||||
"""
|
||||
@@ -287,7 +424,6 @@ class CommandTree:
|
||||
name=name if name is not MISSING else func.__name__,
|
||||
description=desc,
|
||||
callback=func,
|
||||
type=AppCommandType.chat_input,
|
||||
parent=None,
|
||||
)
|
||||
self.add_command(command, guild=guild)
|
||||
@@ -295,6 +431,49 @@ class CommandTree:
|
||||
|
||||
return decorator
|
||||
|
||||
def context_menu(
|
||||
self, *, name: str = MISSING, guild: Optional[Snowflake] = None
|
||||
) -> Callable[[ContextMenuCallback], ContextMenu]:
|
||||
"""Creates a application command context menu from a regular function directly under this tree.
|
||||
|
||||
This function must have a signature of :class:`~discord.Interaction` as its first parameter
|
||||
and taking either a :class:`~discord.Member`, :class:`~discord.User`, or :class:`~discord.Message`,
|
||||
or a :obj:`typing.Union` of ``Member`` and ``User`` as its second parameter.
|
||||
|
||||
Examples
|
||||
---------
|
||||
|
||||
.. code-block:: python3
|
||||
|
||||
@app_commands.context_menu()
|
||||
async def react(interaction: discord.Interaction, message: discord.Message):
|
||||
await interaction.response.send_message('Very cool message!', ephemeral=True)
|
||||
|
||||
@app_commands.context_menu()
|
||||
async def ban(interaction: discord.Interaction, user: discord.Member):
|
||||
await interaction.response.send_message(f'Should I actually ban {user}...', ephemeral=True)
|
||||
|
||||
Parameters
|
||||
------------
|
||||
name: :class:`str`
|
||||
The name of the context menu command. If not given, it defaults to a title-case
|
||||
version of the callback name. Note that unlike regular slash commands this can
|
||||
have spaces and upper case characters in the name.
|
||||
guild: Optional[:class:`.abc.Snowflake`]
|
||||
The guild to add the command to. If not given then it
|
||||
becomes a global command instead.
|
||||
"""
|
||||
|
||||
def decorator(func: ContextMenuCallback) -> ContextMenu:
|
||||
if not inspect.iscoroutinefunction(func):
|
||||
raise TypeError('context menu function must be a coroutine function')
|
||||
|
||||
context_menu = ContextMenu._from_decorator(func, name=name)
|
||||
self.add_command(context_menu, guild=guild)
|
||||
return context_menu
|
||||
|
||||
return decorator
|
||||
|
||||
async def sync(self, *, guild: Optional[Snowflake]) -> List[AppCommand]:
|
||||
"""|coro|
|
||||
|
||||
@@ -327,7 +506,7 @@ class CommandTree:
|
||||
if self.client.application_id is None:
|
||||
raise ClientException('Client does not have an application ID set')
|
||||
|
||||
commands = self.get_commands(guild=guild)
|
||||
commands = self._get_all_commands(guild=guild)
|
||||
payload = [command.to_dict() for command in commands]
|
||||
if guild is None:
|
||||
data = await self._http.bulk_upsert_global_commands(self.client.application_id, payload=payload)
|
||||
@@ -345,6 +524,25 @@ class CommandTree:
|
||||
|
||||
self.client.loop.create_task(wrapper(), name='CommandTree-invoker')
|
||||
|
||||
async def _call_context_menu(self, interaction: Interaction, data: ApplicationCommandInteractionData, type: int):
|
||||
name = data['name']
|
||||
guild_id = interaction.guild_id
|
||||
ctx_menu = self._context_menus.get((name, guild_id, type))
|
||||
if ctx_menu is None:
|
||||
raise CommandNotFound(name, [], AppCommandType(type))
|
||||
|
||||
resolved = Namespace._get_resolved_items(interaction, data.get('resolved', {}))
|
||||
# This will always work at runtime
|
||||
value = resolved.get(data.get('target_id')) # type: ignore
|
||||
if ctx_menu.type.value != type:
|
||||
raise CommandSignatureMismatch(ctx_menu)
|
||||
|
||||
if value is None:
|
||||
raise RuntimeError('This should not happen if Discord sent well-formed data.')
|
||||
|
||||
# I assume I don't have to type check here.
|
||||
await ctx_menu._invoke(interaction, value)
|
||||
|
||||
async def call(self, interaction: Interaction):
|
||||
"""|coro|
|
||||
|
||||
@@ -367,6 +565,12 @@ class CommandTree:
|
||||
application command definition.
|
||||
"""
|
||||
data: ApplicationCommandInteractionData = interaction.data # type: ignore
|
||||
type = data.get('type', 1)
|
||||
if type != 1:
|
||||
# Context menu command...
|
||||
await self._call_context_menu(interaction, data, type)
|
||||
return
|
||||
|
||||
parents: List[str] = []
|
||||
name = data['name']
|
||||
command = self._global_commands.get(name)
|
||||
|
||||
Reference in New Issue
Block a user