mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-05-09 23:39:50 +00:00
[commands] Allow Cog and app_commands interopability
This changeset allows app commands defined inside Cog to work as expected. Likewise, by deriving app_commands.Group and Cog you can make the cog function as a top level command on Discord.
This commit is contained in:
parent
5741ad9368
commit
446bfa78b0
@ -61,6 +61,11 @@ if TYPE_CHECKING:
|
||||
from .namespace import Namespace
|
||||
from .models import ChoiceT
|
||||
|
||||
# Generally, these two libraries are supposed to be separate from each other.
|
||||
# However, for type hinting purposes it's unfortunately necessary for one to
|
||||
# reference the other to prevent type checking errors in callbacks
|
||||
from discord.ext.commands import Cog
|
||||
|
||||
__all__ = (
|
||||
'Command',
|
||||
'ContextMenu',
|
||||
@ -79,7 +84,7 @@ else:
|
||||
P = TypeVar('P')
|
||||
|
||||
T = TypeVar('T')
|
||||
GroupT = TypeVar('GroupT', bound='Group')
|
||||
GroupT = TypeVar('GroupT', bound='Union[Group, Cog]')
|
||||
Coro = Coroutine[Any, Any, T]
|
||||
Error = Union[
|
||||
Callable[[GroupT, Interaction, AppCommandError], Coro[Any]],
|
||||
@ -628,15 +633,14 @@ class Group:
|
||||
"""
|
||||
|
||||
__discord_app_commands_group_children__: ClassVar[List[Union[Command, Group]]] = []
|
||||
__discord_app_commands_skip_init_binding__: bool = False
|
||||
__discord_app_commands_group_name__: str = MISSING
|
||||
__discord_app_commands_group_description__: str = MISSING
|
||||
|
||||
def __init_subclass__(cls, *, name: str = MISSING, description: str = MISSING) -> None:
|
||||
if not cls.__discord_app_commands_group_children__:
|
||||
cls.__discord_app_commands_group_children__ = children = [
|
||||
member
|
||||
for member in cls.__dict__.values()
|
||||
if isinstance(member, (Group, Command)) and member.parent is None
|
||||
member for member in cls.__dict__.values() if isinstance(member, (Group, Command)) and member.parent is None
|
||||
]
|
||||
|
||||
found = set()
|
||||
@ -661,7 +665,6 @@ class Group:
|
||||
else:
|
||||
cls.__discord_app_commands_group_description__ = description
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@ -683,10 +686,10 @@ class Group:
|
||||
self._children: Dict[str, Union[Command, Group]] = {}
|
||||
|
||||
for child in self.__discord_app_commands_group_children__:
|
||||
child = child._copy_with_binding(self)
|
||||
child = child._copy_with_binding(self) if not cls.__discord_app_commands_skip_init_binding__ else child
|
||||
child.parent = self
|
||||
self._children[child.name] = child
|
||||
if child._attr:
|
||||
if child._attr and not cls.__discord_app_commands_skip_init_binding__:
|
||||
setattr(self, child._attr, child)
|
||||
|
||||
if parent is not None and parent.parent is not None:
|
||||
@ -695,7 +698,7 @@ class Group:
|
||||
def __set_name__(self, owner: Type[Any], name: str) -> None:
|
||||
self._attr = name
|
||||
|
||||
def _copy_with_binding(self, binding: Group) -> Group:
|
||||
def _copy_with_binding(self, binding: Union[Group, Cog]) -> Group:
|
||||
cls = self.__class__
|
||||
copy = cls.__new__(cls)
|
||||
copy.name = self.name
|
||||
|
@ -36,6 +36,8 @@ import types
|
||||
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union, overload
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.app_commands.tree import _retrieve_guild_ids
|
||||
|
||||
from .core import GroupMixin
|
||||
from .view import StringView
|
||||
@ -50,7 +52,7 @@ if TYPE_CHECKING:
|
||||
import importlib.machinery
|
||||
|
||||
from discord.message import Message
|
||||
from discord.abc import User
|
||||
from discord.abc import User, Snowflake
|
||||
from ._types import (
|
||||
Check,
|
||||
CoroFunc,
|
||||
@ -135,6 +137,8 @@ class BotBase(GroupMixin):
|
||||
super().__init__(**options)
|
||||
self.command_prefix = command_prefix
|
||||
self.extra_events: Dict[str, List[CoroFunc]] = {}
|
||||
# Self doesn't have the ClientT bound, but since this is a mixin it technically does
|
||||
self.__tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) # type: ignore
|
||||
self.__cogs: Dict[str, Cog] = {}
|
||||
self.__extensions: Dict[str, types.ModuleType] = {}
|
||||
self._checks: List[Check] = []
|
||||
@ -529,11 +533,22 @@ class BotBase(GroupMixin):
|
||||
|
||||
# cogs
|
||||
|
||||
def add_cog(self, cog: Cog, /, *, override: bool = False) -> None:
|
||||
def add_cog(
|
||||
self,
|
||||
cog: Cog,
|
||||
/,
|
||||
*,
|
||||
override: bool = False,
|
||||
guild: Optional[Snowflake] = MISSING,
|
||||
guilds: List[Snowflake] = MISSING,
|
||||
) -> None:
|
||||
"""Adds a "cog" to the bot.
|
||||
|
||||
A cog is a class that has its own event listeners and commands.
|
||||
|
||||
If the cog is a :class:`.app_commands.Group` then it is added to
|
||||
the bot's :class:`~discord.app_commands.CommandTree` as well.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
|
||||
:exc:`.ClientException` is raised when a cog with the same name
|
||||
@ -551,6 +566,19 @@ class BotBase(GroupMixin):
|
||||
If a previously loaded cog with the same name should be ejected
|
||||
instead of raising an error.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
guild: Optional[:class:`~discord.abc.Snowflake`]
|
||||
If the cog is an application command group, then this would be the
|
||||
guild where the cog group would be added to. If not given then
|
||||
it becomes a global command instead.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
guilds: List[:class:`~discord.abc.Snowflake`]
|
||||
If the cog is an application command group, then this would be the
|
||||
guilds where the cog group would be added to. If not given then
|
||||
it becomes a global command instead. Cannot be mixed with
|
||||
``guild``.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Raises
|
||||
@ -572,7 +600,10 @@ class BotBase(GroupMixin):
|
||||
if existing is not None:
|
||||
if not override:
|
||||
raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
|
||||
self.remove_cog(cog_name)
|
||||
self.remove_cog(cog_name, guild=guild, guilds=guilds)
|
||||
|
||||
if isinstance(cog, app_commands.Group):
|
||||
self.__tree.add_command(cog, override=override, guild=guild, guilds=guilds)
|
||||
|
||||
cog = cog._inject(self)
|
||||
self.__cogs[cog_name] = cog
|
||||
@ -600,7 +631,13 @@ class BotBase(GroupMixin):
|
||||
"""
|
||||
return self.__cogs.get(name)
|
||||
|
||||
def remove_cog(self, name: str, /) -> Optional[Cog]:
|
||||
def remove_cog(
|
||||
self,
|
||||
name: str,
|
||||
/,
|
||||
guild: Optional[Snowflake] = MISSING,
|
||||
guilds: List[Snowflake] = MISSING,
|
||||
) -> Optional[Cog]:
|
||||
"""Removes a cog from the bot and returns it.
|
||||
|
||||
All registered commands and event listeners that the
|
||||
@ -616,6 +653,19 @@ class BotBase(GroupMixin):
|
||||
-----------
|
||||
name: :class:`str`
|
||||
The name of the cog to remove.
|
||||
guild: Optional[:class:`~discord.abc.Snowflake`]
|
||||
If the cog is an application command group, then this would be the
|
||||
guild where the cog group would be removed from. If not given then
|
||||
a global command is removed instead instead.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
guilds: List[:class:`~discord.abc.Snowflake`]
|
||||
If the cog is an application command group, then this would be the
|
||||
guilds where the cog group would be removed from. If not given then
|
||||
a global command is removed instead instead. Cannot be mixed with
|
||||
``guild``.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -630,6 +680,15 @@ class BotBase(GroupMixin):
|
||||
help_command = self._help_command
|
||||
if help_command and help_command.cog is cog:
|
||||
help_command.cog = None
|
||||
|
||||
if isinstance(cog, app_commands.Group):
|
||||
guild_ids = _retrieve_guild_ids(cog, guild, guilds)
|
||||
if guild_ids is None:
|
||||
self.__tree.remove_command(name)
|
||||
else:
|
||||
for guild_id in guild_ids:
|
||||
self.__tree.remove_command(name, guild=discord.Object(guild_id))
|
||||
|
||||
cog._eject(self)
|
||||
|
||||
return cog
|
||||
@ -894,6 +953,20 @@ class BotBase(GroupMixin):
|
||||
else:
|
||||
self._help_command = None
|
||||
|
||||
# application command interop
|
||||
|
||||
# As mentioned above, this is a mixin so the Self type hint fails here.
|
||||
# However, since the only classes that can use this are subclasses of Client
|
||||
# anyway, then this is sound.
|
||||
@property
|
||||
def tree(self) -> app_commands.CommandTree[Self]: # type: ignore
|
||||
""":class:`~discord.app_commands.CommandTree`: The command tree responsible for handling the application commands
|
||||
in this bot.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self.__tree
|
||||
|
||||
# command processing
|
||||
|
||||
async def get_prefix(self, message: Message) -> Union[List[str], str]:
|
||||
|
@ -24,14 +24,15 @@ DEALINGS IN THE SOFTWARE.
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import discord.utils
|
||||
import discord
|
||||
from discord import app_commands
|
||||
|
||||
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union, Type
|
||||
|
||||
from ._types import _BaseCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
from typing_extensions import Self, TypeGuard
|
||||
|
||||
from .bot import BotBase
|
||||
from .context import Context
|
||||
@ -110,19 +111,33 @@ class CogMeta(type):
|
||||
__cog_name__: str
|
||||
__cog_settings__: Dict[str, Any]
|
||||
__cog_commands__: List[Command]
|
||||
__cog_is_app_commands_group__: bool
|
||||
__cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]]
|
||||
__cog_listeners__: List[Tuple[str, str]]
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
|
||||
name, bases, attrs = args
|
||||
attrs['__cog_name__'] = kwargs.pop('name', name)
|
||||
attrs['__cog_name__'] = kwargs.get('name', name)
|
||||
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
|
||||
attrs['__cog_is_app_commands_group__'] = is_parent = app_commands.Group in bases
|
||||
|
||||
description = kwargs.pop('description', None)
|
||||
description = kwargs.get('description', None)
|
||||
if description is None:
|
||||
description = inspect.cleandoc(attrs.get('__doc__', ''))
|
||||
attrs['__cog_description__'] = description
|
||||
|
||||
if is_parent:
|
||||
attrs['__discord_app_commands_skip_init_binding__'] = True
|
||||
# This is hacky, but it signals the Group not to process this info.
|
||||
# It's overridden later.
|
||||
attrs['__discord_app_commands_group_children__'] = True
|
||||
else:
|
||||
# Remove the extraneous keyword arguments we're using
|
||||
kwargs.pop('name', None)
|
||||
kwargs.pop('description', None)
|
||||
|
||||
commands = {}
|
||||
cog_app_commands = {}
|
||||
listeners = {}
|
||||
no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})'
|
||||
|
||||
@ -143,6 +158,8 @@ class CogMeta(type):
|
||||
if elem.startswith(('cog_', 'bot_')):
|
||||
raise TypeError(no_bot_cog.format(base, elem))
|
||||
commands[elem] = value
|
||||
elif isinstance(value, (app_commands.Group, app_commands.Command)) and value.parent is None:
|
||||
cog_app_commands[elem] = value
|
||||
elif inspect.iscoroutinefunction(value):
|
||||
try:
|
||||
getattr(value, '__cog_listener__')
|
||||
@ -154,6 +171,13 @@ class CogMeta(type):
|
||||
listeners[elem] = value
|
||||
|
||||
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
|
||||
new_cls.__cog_app_commands__ = list(cog_app_commands.values())
|
||||
|
||||
if is_parent:
|
||||
# Prefill the app commands for the Group as well..
|
||||
# The type checker doesn't like runtime attribute modification and this one's
|
||||
# optional so it can't be cheesed.
|
||||
new_cls.__discord_app_commands_group_children__ = cog_app_commands # type: ignore
|
||||
|
||||
listeners_as_list = []
|
||||
for listener in listeners.values():
|
||||
@ -189,10 +213,11 @@ class Cog(metaclass=CogMeta):
|
||||
are equally valid here.
|
||||
"""
|
||||
|
||||
__cog_name__: ClassVar[str]
|
||||
__cog_settings__: ClassVar[Dict[str, Any]]
|
||||
__cog_commands__: ClassVar[List[Command[Self, ..., Any]]]
|
||||
__cog_listeners__: ClassVar[List[Tuple[str, str]]]
|
||||
__cog_name__: str
|
||||
__cog_settings__: Dict[str, Any]
|
||||
__cog_commands__: List[Command[Self, ..., Any]]
|
||||
__cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]]
|
||||
__cog_listeners__: List[Tuple[str, str]]
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
|
||||
# For issue 426, we need to store a copy of the command objects
|
||||
@ -219,6 +244,25 @@ class Cog(metaclass=CogMeta):
|
||||
parent.remove_command(command.name) # type: ignore
|
||||
parent.add_command(command) # type: ignore
|
||||
|
||||
# Register the application commands
|
||||
children: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] = []
|
||||
for command in cls.__cog_app_commands__:
|
||||
copy = command._copy_with_binding(self)
|
||||
|
||||
if cls.__cog_is_app_commands_group__:
|
||||
# Type checker doesn't understand this type of narrowing.
|
||||
# Not even with TypeGuard somehow.
|
||||
copy.parent = self # type: ignore
|
||||
|
||||
children.append(copy)
|
||||
if command._attr:
|
||||
setattr(self, command._attr, copy)
|
||||
|
||||
self.__cog_app_commands__ = children
|
||||
if cls.__cog_is_app_commands_group__:
|
||||
# Dynamic attribute setting
|
||||
self.__discord_app_commands_group_children__ = children # type: ignore
|
||||
|
||||
return self
|
||||
|
||||
def get_commands(self) -> List[Command[Self, ..., Any]]:
|
||||
@ -452,6 +496,12 @@ class Cog(metaclass=CogMeta):
|
||||
for name, method_name in self.__cog_listeners__:
|
||||
bot.add_listener(getattr(self, method_name), name)
|
||||
|
||||
# Only do this if these are "top level" commands
|
||||
if not cls.__cog_is_app_commands_group__:
|
||||
for command in self.__cog_app_commands__:
|
||||
# This is already atomic
|
||||
bot.tree.add_command(command)
|
||||
|
||||
return self
|
||||
|
||||
def _eject(self, bot: BotBase) -> None:
|
||||
@ -462,6 +512,16 @@ class Cog(metaclass=CogMeta):
|
||||
if command.parent is None:
|
||||
bot.remove_command(command.name)
|
||||
|
||||
if not cls.__cog_is_app_commands_group__:
|
||||
for command in self.__cog_app_commands__:
|
||||
try:
|
||||
guild_ids = command.__discord_app_commands_default_guilds__
|
||||
except AttributeError:
|
||||
bot.tree.remove_command(command.name)
|
||||
else:
|
||||
for guild_id in guild_ids:
|
||||
bot.tree.remove_command(command.name, guild=discord.Object(id=guild_id))
|
||||
|
||||
for name, method_name in self.__cog_listeners__:
|
||||
bot.remove_listener(getattr(self, method_name), name)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user