mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-05-09 23:39:50 +00:00
[commands] Allow Cog and app_commands interopability
This changeset allows app commands defined inside Cog to work as expected. Likewise, by deriving app_commands.Group and Cog you can make the cog function as a top level command on Discord.
This commit is contained in:
parent
5741ad9368
commit
446bfa78b0
@ -61,6 +61,11 @@ if TYPE_CHECKING:
|
|||||||
from .namespace import Namespace
|
from .namespace import Namespace
|
||||||
from .models import ChoiceT
|
from .models import ChoiceT
|
||||||
|
|
||||||
|
# Generally, these two libraries are supposed to be separate from each other.
|
||||||
|
# However, for type hinting purposes it's unfortunately necessary for one to
|
||||||
|
# reference the other to prevent type checking errors in callbacks
|
||||||
|
from discord.ext.commands import Cog
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
'Command',
|
'Command',
|
||||||
'ContextMenu',
|
'ContextMenu',
|
||||||
@ -79,7 +84,7 @@ else:
|
|||||||
P = TypeVar('P')
|
P = TypeVar('P')
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
GroupT = TypeVar('GroupT', bound='Group')
|
GroupT = TypeVar('GroupT', bound='Union[Group, Cog]')
|
||||||
Coro = Coroutine[Any, Any, T]
|
Coro = Coroutine[Any, Any, T]
|
||||||
Error = Union[
|
Error = Union[
|
||||||
Callable[[GroupT, Interaction, AppCommandError], Coro[Any]],
|
Callable[[GroupT, Interaction, AppCommandError], Coro[Any]],
|
||||||
@ -628,15 +633,14 @@ class Group:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
__discord_app_commands_group_children__: ClassVar[List[Union[Command, Group]]] = []
|
__discord_app_commands_group_children__: ClassVar[List[Union[Command, Group]]] = []
|
||||||
|
__discord_app_commands_skip_init_binding__: bool = False
|
||||||
__discord_app_commands_group_name__: str = MISSING
|
__discord_app_commands_group_name__: str = MISSING
|
||||||
__discord_app_commands_group_description__: str = MISSING
|
__discord_app_commands_group_description__: str = MISSING
|
||||||
|
|
||||||
def __init_subclass__(cls, *, name: str = MISSING, description: str = MISSING) -> None:
|
def __init_subclass__(cls, *, name: str = MISSING, description: str = MISSING) -> None:
|
||||||
if not cls.__discord_app_commands_group_children__:
|
if not cls.__discord_app_commands_group_children__:
|
||||||
cls.__discord_app_commands_group_children__ = children = [
|
cls.__discord_app_commands_group_children__ = children = [
|
||||||
member
|
member for member in cls.__dict__.values() if isinstance(member, (Group, Command)) and member.parent is None
|
||||||
for member in cls.__dict__.values()
|
|
||||||
if isinstance(member, (Group, Command)) and member.parent is None
|
|
||||||
]
|
]
|
||||||
|
|
||||||
found = set()
|
found = set()
|
||||||
@ -661,7 +665,6 @@ class Group:
|
|||||||
else:
|
else:
|
||||||
cls.__discord_app_commands_group_description__ = description
|
cls.__discord_app_commands_group_description__ = description
|
||||||
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@ -683,10 +686,10 @@ class Group:
|
|||||||
self._children: Dict[str, Union[Command, Group]] = {}
|
self._children: Dict[str, Union[Command, Group]] = {}
|
||||||
|
|
||||||
for child in self.__discord_app_commands_group_children__:
|
for child in self.__discord_app_commands_group_children__:
|
||||||
child = child._copy_with_binding(self)
|
child = child._copy_with_binding(self) if not cls.__discord_app_commands_skip_init_binding__ else child
|
||||||
child.parent = self
|
child.parent = self
|
||||||
self._children[child.name] = child
|
self._children[child.name] = child
|
||||||
if child._attr:
|
if child._attr and not cls.__discord_app_commands_skip_init_binding__:
|
||||||
setattr(self, child._attr, child)
|
setattr(self, child._attr, child)
|
||||||
|
|
||||||
if parent is not None and parent.parent is not None:
|
if parent is not None and parent.parent is not None:
|
||||||
@ -695,7 +698,7 @@ class Group:
|
|||||||
def __set_name__(self, owner: Type[Any], name: str) -> None:
|
def __set_name__(self, owner: Type[Any], name: str) -> None:
|
||||||
self._attr = name
|
self._attr = name
|
||||||
|
|
||||||
def _copy_with_binding(self, binding: Group) -> Group:
|
def _copy_with_binding(self, binding: Union[Group, Cog]) -> Group:
|
||||||
cls = self.__class__
|
cls = self.__class__
|
||||||
copy = cls.__new__(cls)
|
copy = cls.__new__(cls)
|
||||||
copy.name = self.name
|
copy.name = self.name
|
||||||
|
@ -36,6 +36,8 @@ import types
|
|||||||
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union, overload
|
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union, overload
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
from discord import app_commands
|
||||||
|
from discord.app_commands.tree import _retrieve_guild_ids
|
||||||
|
|
||||||
from .core import GroupMixin
|
from .core import GroupMixin
|
||||||
from .view import StringView
|
from .view import StringView
|
||||||
@ -50,7 +52,7 @@ if TYPE_CHECKING:
|
|||||||
import importlib.machinery
|
import importlib.machinery
|
||||||
|
|
||||||
from discord.message import Message
|
from discord.message import Message
|
||||||
from discord.abc import User
|
from discord.abc import User, Snowflake
|
||||||
from ._types import (
|
from ._types import (
|
||||||
Check,
|
Check,
|
||||||
CoroFunc,
|
CoroFunc,
|
||||||
@ -135,6 +137,8 @@ class BotBase(GroupMixin):
|
|||||||
super().__init__(**options)
|
super().__init__(**options)
|
||||||
self.command_prefix = command_prefix
|
self.command_prefix = command_prefix
|
||||||
self.extra_events: Dict[str, List[CoroFunc]] = {}
|
self.extra_events: Dict[str, List[CoroFunc]] = {}
|
||||||
|
# Self doesn't have the ClientT bound, but since this is a mixin it technically does
|
||||||
|
self.__tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) # type: ignore
|
||||||
self.__cogs: Dict[str, Cog] = {}
|
self.__cogs: Dict[str, Cog] = {}
|
||||||
self.__extensions: Dict[str, types.ModuleType] = {}
|
self.__extensions: Dict[str, types.ModuleType] = {}
|
||||||
self._checks: List[Check] = []
|
self._checks: List[Check] = []
|
||||||
@ -529,11 +533,22 @@ class BotBase(GroupMixin):
|
|||||||
|
|
||||||
# cogs
|
# cogs
|
||||||
|
|
||||||
def add_cog(self, cog: Cog, /, *, override: bool = False) -> None:
|
def add_cog(
|
||||||
|
self,
|
||||||
|
cog: Cog,
|
||||||
|
/,
|
||||||
|
*,
|
||||||
|
override: bool = False,
|
||||||
|
guild: Optional[Snowflake] = MISSING,
|
||||||
|
guilds: List[Snowflake] = MISSING,
|
||||||
|
) -> None:
|
||||||
"""Adds a "cog" to the bot.
|
"""Adds a "cog" to the bot.
|
||||||
|
|
||||||
A cog is a class that has its own event listeners and commands.
|
A cog is a class that has its own event listeners and commands.
|
||||||
|
|
||||||
|
If the cog is a :class:`.app_commands.Group` then it is added to
|
||||||
|
the bot's :class:`~discord.app_commands.CommandTree` as well.
|
||||||
|
|
||||||
.. versionchanged:: 2.0
|
.. versionchanged:: 2.0
|
||||||
|
|
||||||
:exc:`.ClientException` is raised when a cog with the same name
|
:exc:`.ClientException` is raised when a cog with the same name
|
||||||
@ -551,6 +566,19 @@ class BotBase(GroupMixin):
|
|||||||
If a previously loaded cog with the same name should be ejected
|
If a previously loaded cog with the same name should be ejected
|
||||||
instead of raising an error.
|
instead of raising an error.
|
||||||
|
|
||||||
|
.. versionadded:: 2.0
|
||||||
|
guild: Optional[:class:`~discord.abc.Snowflake`]
|
||||||
|
If the cog is an application command group, then this would be the
|
||||||
|
guild where the cog group would be added to. If not given then
|
||||||
|
it becomes a global command instead.
|
||||||
|
|
||||||
|
.. versionadded:: 2.0
|
||||||
|
guilds: List[:class:`~discord.abc.Snowflake`]
|
||||||
|
If the cog is an application command group, then this would be the
|
||||||
|
guilds where the cog group would be added to. If not given then
|
||||||
|
it becomes a global command instead. Cannot be mixed with
|
||||||
|
``guild``.
|
||||||
|
|
||||||
.. versionadded:: 2.0
|
.. versionadded:: 2.0
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
@ -572,7 +600,10 @@ class BotBase(GroupMixin):
|
|||||||
if existing is not None:
|
if existing is not None:
|
||||||
if not override:
|
if not override:
|
||||||
raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
|
raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
|
||||||
self.remove_cog(cog_name)
|
self.remove_cog(cog_name, guild=guild, guilds=guilds)
|
||||||
|
|
||||||
|
if isinstance(cog, app_commands.Group):
|
||||||
|
self.__tree.add_command(cog, override=override, guild=guild, guilds=guilds)
|
||||||
|
|
||||||
cog = cog._inject(self)
|
cog = cog._inject(self)
|
||||||
self.__cogs[cog_name] = cog
|
self.__cogs[cog_name] = cog
|
||||||
@ -600,7 +631,13 @@ class BotBase(GroupMixin):
|
|||||||
"""
|
"""
|
||||||
return self.__cogs.get(name)
|
return self.__cogs.get(name)
|
||||||
|
|
||||||
def remove_cog(self, name: str, /) -> Optional[Cog]:
|
def remove_cog(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
/,
|
||||||
|
guild: Optional[Snowflake] = MISSING,
|
||||||
|
guilds: List[Snowflake] = MISSING,
|
||||||
|
) -> Optional[Cog]:
|
||||||
"""Removes a cog from the bot and returns it.
|
"""Removes a cog from the bot and returns it.
|
||||||
|
|
||||||
All registered commands and event listeners that the
|
All registered commands and event listeners that the
|
||||||
@ -616,6 +653,19 @@ class BotBase(GroupMixin):
|
|||||||
-----------
|
-----------
|
||||||
name: :class:`str`
|
name: :class:`str`
|
||||||
The name of the cog to remove.
|
The name of the cog to remove.
|
||||||
|
guild: Optional[:class:`~discord.abc.Snowflake`]
|
||||||
|
If the cog is an application command group, then this would be the
|
||||||
|
guild where the cog group would be removed from. If not given then
|
||||||
|
a global command is removed instead instead.
|
||||||
|
|
||||||
|
.. versionadded:: 2.0
|
||||||
|
guilds: List[:class:`~discord.abc.Snowflake`]
|
||||||
|
If the cog is an application command group, then this would be the
|
||||||
|
guilds where the cog group would be removed from. If not given then
|
||||||
|
a global command is removed instead instead. Cannot be mixed with
|
||||||
|
``guild``.
|
||||||
|
|
||||||
|
.. versionadded:: 2.0
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -630,6 +680,15 @@ class BotBase(GroupMixin):
|
|||||||
help_command = self._help_command
|
help_command = self._help_command
|
||||||
if help_command and help_command.cog is cog:
|
if help_command and help_command.cog is cog:
|
||||||
help_command.cog = None
|
help_command.cog = None
|
||||||
|
|
||||||
|
if isinstance(cog, app_commands.Group):
|
||||||
|
guild_ids = _retrieve_guild_ids(cog, guild, guilds)
|
||||||
|
if guild_ids is None:
|
||||||
|
self.__tree.remove_command(name)
|
||||||
|
else:
|
||||||
|
for guild_id in guild_ids:
|
||||||
|
self.__tree.remove_command(name, guild=discord.Object(guild_id))
|
||||||
|
|
||||||
cog._eject(self)
|
cog._eject(self)
|
||||||
|
|
||||||
return cog
|
return cog
|
||||||
@ -894,6 +953,20 @@ class BotBase(GroupMixin):
|
|||||||
else:
|
else:
|
||||||
self._help_command = None
|
self._help_command = None
|
||||||
|
|
||||||
|
# application command interop
|
||||||
|
|
||||||
|
# As mentioned above, this is a mixin so the Self type hint fails here.
|
||||||
|
# However, since the only classes that can use this are subclasses of Client
|
||||||
|
# anyway, then this is sound.
|
||||||
|
@property
|
||||||
|
def tree(self) -> app_commands.CommandTree[Self]: # type: ignore
|
||||||
|
""":class:`~discord.app_commands.CommandTree`: The command tree responsible for handling the application commands
|
||||||
|
in this bot.
|
||||||
|
|
||||||
|
.. versionadded:: 2.0
|
||||||
|
"""
|
||||||
|
return self.__tree
|
||||||
|
|
||||||
# command processing
|
# command processing
|
||||||
|
|
||||||
async def get_prefix(self, message: Message) -> Union[List[str], str]:
|
async def get_prefix(self, message: Message) -> Union[List[str], str]:
|
||||||
|
@ -24,14 +24,15 @@ DEALINGS IN THE SOFTWARE.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import discord.utils
|
import discord
|
||||||
|
from discord import app_commands
|
||||||
|
|
||||||
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type
|
from typing import Any, Callable, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union, Type
|
||||||
|
|
||||||
from ._types import _BaseCommand
|
from ._types import _BaseCommand
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self, TypeGuard
|
||||||
|
|
||||||
from .bot import BotBase
|
from .bot import BotBase
|
||||||
from .context import Context
|
from .context import Context
|
||||||
@ -110,19 +111,33 @@ class CogMeta(type):
|
|||||||
__cog_name__: str
|
__cog_name__: str
|
||||||
__cog_settings__: Dict[str, Any]
|
__cog_settings__: Dict[str, Any]
|
||||||
__cog_commands__: List[Command]
|
__cog_commands__: List[Command]
|
||||||
|
__cog_is_app_commands_group__: bool
|
||||||
|
__cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]]
|
||||||
__cog_listeners__: List[Tuple[str, str]]
|
__cog_listeners__: List[Tuple[str, str]]
|
||||||
|
|
||||||
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
|
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
|
||||||
name, bases, attrs = args
|
name, bases, attrs = args
|
||||||
attrs['__cog_name__'] = kwargs.pop('name', name)
|
attrs['__cog_name__'] = kwargs.get('name', name)
|
||||||
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
|
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
|
||||||
|
attrs['__cog_is_app_commands_group__'] = is_parent = app_commands.Group in bases
|
||||||
|
|
||||||
description = kwargs.pop('description', None)
|
description = kwargs.get('description', None)
|
||||||
if description is None:
|
if description is None:
|
||||||
description = inspect.cleandoc(attrs.get('__doc__', ''))
|
description = inspect.cleandoc(attrs.get('__doc__', ''))
|
||||||
attrs['__cog_description__'] = description
|
attrs['__cog_description__'] = description
|
||||||
|
|
||||||
|
if is_parent:
|
||||||
|
attrs['__discord_app_commands_skip_init_binding__'] = True
|
||||||
|
# This is hacky, but it signals the Group not to process this info.
|
||||||
|
# It's overridden later.
|
||||||
|
attrs['__discord_app_commands_group_children__'] = True
|
||||||
|
else:
|
||||||
|
# Remove the extraneous keyword arguments we're using
|
||||||
|
kwargs.pop('name', None)
|
||||||
|
kwargs.pop('description', None)
|
||||||
|
|
||||||
commands = {}
|
commands = {}
|
||||||
|
cog_app_commands = {}
|
||||||
listeners = {}
|
listeners = {}
|
||||||
no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})'
|
no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})'
|
||||||
|
|
||||||
@ -143,6 +158,8 @@ class CogMeta(type):
|
|||||||
if elem.startswith(('cog_', 'bot_')):
|
if elem.startswith(('cog_', 'bot_')):
|
||||||
raise TypeError(no_bot_cog.format(base, elem))
|
raise TypeError(no_bot_cog.format(base, elem))
|
||||||
commands[elem] = value
|
commands[elem] = value
|
||||||
|
elif isinstance(value, (app_commands.Group, app_commands.Command)) and value.parent is None:
|
||||||
|
cog_app_commands[elem] = value
|
||||||
elif inspect.iscoroutinefunction(value):
|
elif inspect.iscoroutinefunction(value):
|
||||||
try:
|
try:
|
||||||
getattr(value, '__cog_listener__')
|
getattr(value, '__cog_listener__')
|
||||||
@ -154,6 +171,13 @@ class CogMeta(type):
|
|||||||
listeners[elem] = value
|
listeners[elem] = value
|
||||||
|
|
||||||
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
|
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
|
||||||
|
new_cls.__cog_app_commands__ = list(cog_app_commands.values())
|
||||||
|
|
||||||
|
if is_parent:
|
||||||
|
# Prefill the app commands for the Group as well..
|
||||||
|
# The type checker doesn't like runtime attribute modification and this one's
|
||||||
|
# optional so it can't be cheesed.
|
||||||
|
new_cls.__discord_app_commands_group_children__ = cog_app_commands # type: ignore
|
||||||
|
|
||||||
listeners_as_list = []
|
listeners_as_list = []
|
||||||
for listener in listeners.values():
|
for listener in listeners.values():
|
||||||
@ -189,10 +213,11 @@ class Cog(metaclass=CogMeta):
|
|||||||
are equally valid here.
|
are equally valid here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__cog_name__: ClassVar[str]
|
__cog_name__: str
|
||||||
__cog_settings__: ClassVar[Dict[str, Any]]
|
__cog_settings__: Dict[str, Any]
|
||||||
__cog_commands__: ClassVar[List[Command[Self, ..., Any]]]
|
__cog_commands__: List[Command[Self, ..., Any]]
|
||||||
__cog_listeners__: ClassVar[List[Tuple[str, str]]]
|
__cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]]
|
||||||
|
__cog_listeners__: List[Tuple[str, str]]
|
||||||
|
|
||||||
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
|
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
|
||||||
# For issue 426, we need to store a copy of the command objects
|
# For issue 426, we need to store a copy of the command objects
|
||||||
@ -219,6 +244,25 @@ class Cog(metaclass=CogMeta):
|
|||||||
parent.remove_command(command.name) # type: ignore
|
parent.remove_command(command.name) # type: ignore
|
||||||
parent.add_command(command) # type: ignore
|
parent.add_command(command) # type: ignore
|
||||||
|
|
||||||
|
# Register the application commands
|
||||||
|
children: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] = []
|
||||||
|
for command in cls.__cog_app_commands__:
|
||||||
|
copy = command._copy_with_binding(self)
|
||||||
|
|
||||||
|
if cls.__cog_is_app_commands_group__:
|
||||||
|
# Type checker doesn't understand this type of narrowing.
|
||||||
|
# Not even with TypeGuard somehow.
|
||||||
|
copy.parent = self # type: ignore
|
||||||
|
|
||||||
|
children.append(copy)
|
||||||
|
if command._attr:
|
||||||
|
setattr(self, command._attr, copy)
|
||||||
|
|
||||||
|
self.__cog_app_commands__ = children
|
||||||
|
if cls.__cog_is_app_commands_group__:
|
||||||
|
# Dynamic attribute setting
|
||||||
|
self.__discord_app_commands_group_children__ = children # type: ignore
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get_commands(self) -> List[Command[Self, ..., Any]]:
|
def get_commands(self) -> List[Command[Self, ..., Any]]:
|
||||||
@ -452,6 +496,12 @@ class Cog(metaclass=CogMeta):
|
|||||||
for name, method_name in self.__cog_listeners__:
|
for name, method_name in self.__cog_listeners__:
|
||||||
bot.add_listener(getattr(self, method_name), name)
|
bot.add_listener(getattr(self, method_name), name)
|
||||||
|
|
||||||
|
# Only do this if these are "top level" commands
|
||||||
|
if not cls.__cog_is_app_commands_group__:
|
||||||
|
for command in self.__cog_app_commands__:
|
||||||
|
# This is already atomic
|
||||||
|
bot.tree.add_command(command)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _eject(self, bot: BotBase) -> None:
|
def _eject(self, bot: BotBase) -> None:
|
||||||
@ -462,6 +512,16 @@ class Cog(metaclass=CogMeta):
|
|||||||
if command.parent is None:
|
if command.parent is None:
|
||||||
bot.remove_command(command.name)
|
bot.remove_command(command.name)
|
||||||
|
|
||||||
|
if not cls.__cog_is_app_commands_group__:
|
||||||
|
for command in self.__cog_app_commands__:
|
||||||
|
try:
|
||||||
|
guild_ids = command.__discord_app_commands_default_guilds__
|
||||||
|
except AttributeError:
|
||||||
|
bot.tree.remove_command(command.name)
|
||||||
|
else:
|
||||||
|
for guild_id in guild_ids:
|
||||||
|
bot.tree.remove_command(command.name, guild=discord.Object(id=guild_id))
|
||||||
|
|
||||||
for name, method_name in self.__cog_listeners__:
|
for name, method_name in self.__cog_listeners__:
|
||||||
bot.remove_listener(getattr(self, method_name), name)
|
bot.remove_listener(getattr(self, method_name), name)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user