[commands] Add initial implementation of hybrid commands

Hybrid commands allow a regular command to also double as a slash
command, assuming it meets the subset required to function.
This commit is contained in:
Rapptz
2022-04-10 17:14:38 -04:00
parent 151806ec94
commit 840eb577d4
10 changed files with 919 additions and 28 deletions

View File

@ -67,6 +67,7 @@ if TYPE_CHECKING:
import importlib.machinery
from discord.message import Message
from discord.interactions import Interaction
from discord.abc import User, Snowflake
from ._types import (
_Bot,
@ -76,6 +77,7 @@ if TYPE_CHECKING:
ContextT,
MaybeAwaitableFunc,
)
from .core import Command
_Prefix = Union[Iterable[str], str]
_PrefixCallable = MaybeAwaitableFunc[[BotT, Message], _Prefix]
@ -215,6 +217,38 @@ class BotBase(GroupMixin[None]):
await super().close() # type: ignore
# GroupMixin overrides
@discord.utils.copy_doc(GroupMixin.add_command)
def add_command(self, command: Command[Any, ..., Any], /) -> None:
super().add_command(command)
if hasattr(command, '__commands_is_hybrid__'):
# If a cog is also inheriting from app_commands.Group then it'll also
# add the hybrid commands as text commands, which would recursively add the
# hybrid commands as slash commands. This check just terminates that recursion
# from happening
if command.cog is None or not command.cog.__cog_is_app_commands_group__:
self.tree.add_command(command.app_command) # type: ignore
@discord.utils.copy_doc(GroupMixin.remove_command)
def remove_command(self, name: str, /) -> Optional[Command[Any, ..., Any]]:
cmd = super().remove_command(name)
if cmd is not None and hasattr(cmd, '__commands_is_hybrid__'):
# See above
if cmd.cog is not None and cmd.cog.__cog_is_app_commands_group__:
return cmd
guild_ids: Optional[List[int]] = cmd.app_command._guild_ids # type: ignore
if guild_ids is None:
self.__tree.remove_command(name)
else:
for guild_id in guild_ids:
self.__tree.remove_command(name, guild=discord.Object(id=guild_id))
return cmd
# Error handler
async def on_command_error(self, context: Context[BotT], exception: errors.CommandError, /) -> None:
"""|coro|
@ -1107,7 +1141,7 @@ class BotBase(GroupMixin[None]):
@overload
async def get_context(
self,
message: Message,
origin: Union[Message, Interaction],
/,
) -> Context[Self]: # type: ignore
...
@ -1115,23 +1149,23 @@ class BotBase(GroupMixin[None]):
@overload
async def get_context(
self,
message: Message,
origin: Union[Message, Interaction],
/,
*,
cls: Type[ContextT] = ...,
cls: Type[ContextT],
) -> ContextT:
...
async def get_context(
self,
message: Message,
origin: Union[Message, Interaction],
/,
*,
cls: Type[ContextT] = MISSING,
) -> Any:
r"""|coro|
Returns the invocation context from the message.
Returns the invocation context from the message or interaction.
This is a more low-level counter-part for :meth:`.process_commands`
to allow users more fine grained control over the processing.
@ -1141,14 +1175,20 @@ class BotBase(GroupMixin[None]):
If the context is not valid then it is not a valid candidate to be
invoked under :meth:`~.Bot.invoke`.
.. note::
In order for the custom context to be used inside an interaction-based
context (such as :class:`HybridCommand`) then this method must be
overridden to return that class.
.. versionchanged:: 2.0
``message`` parameter is now positional-only.
``message`` parameter is now positional-only and renamed to ``origin``.
Parameters
-----------
message: :class:`discord.Message`
The message to get the invocation context from.
origin: Union[:class:`discord.Message`, :class:`discord.Interaction`]
The message or interaction to get the invocation context from.
cls
The factory class that will be used to create the context.
By default, this is :class:`.Context`. Should a custom
@ -1164,13 +1204,16 @@ class BotBase(GroupMixin[None]):
if cls is MISSING:
cls = Context # type: ignore
view = StringView(message.content)
ctx = cls(prefix=None, view=view, bot=self, message=message)
if isinstance(origin, discord.Interaction):
return await cls.from_interaction(origin)
if message.author.id == self.user.id: # type: ignore
view = StringView(origin.content)
ctx = cls(prefix=None, view=view, bot=self, message=origin)
if origin.author.id == self.user.id: # type: ignore
return ctx
prefix = await self.get_prefix(message)
prefix = await self.get_prefix(origin)
invoked_prefix = prefix
if isinstance(prefix, str):
@ -1180,7 +1223,7 @@ class BotBase(GroupMixin[None]):
try:
# if the context class' __init__ consumes something from the view this
# will be wrong. That seems unreasonable though.
if message.content.startswith(tuple(prefix)):
if origin.content.startswith(tuple(prefix)):
invoked_prefix = discord.utils.find(view.skip_string, prefix)
else:
return ctx