[commands] Automatically unload top level app commands in extensions

This commit is contained in:
Rapptz
2022-03-12 09:24:26 -05:00
parent a672455ca9
commit 0ef369c0fa
5 changed files with 55 additions and 7 deletions

View File

@ -365,6 +365,7 @@ class Command(Generic[GroupT, P, T]):
self.parent: Optional[Group] = parent
self.binding: Optional[GroupT] = None
self.on_error: Optional[Error[GroupT]] = None
self.module: Optional[str] = callback.__module__
# Unwrap __self__ for bound methods
try:
@ -626,6 +627,7 @@ class ContextMenu:
raise ValueError(f'context menu callback implies a type of {actual_type} but {type} was passed.')
self._param_name = param
self._annotation = annotation
self.module: Optional[str] = callback.__module__
@property
def callback(self) -> ContextMenuCallback:
@ -642,6 +644,7 @@ class ContextMenu:
self.type = type
self._param_name = param
self._annotation = annotation
self.module = callback.__module__
return self
def to_dict(self) -> Dict[str, Any]:
@ -683,6 +686,7 @@ class Group:
__discord_app_commands_skip_init_binding__: bool = False
__discord_app_commands_group_name__: str = MISSING
__discord_app_commands_group_description__: str = MISSING
__discord_app_commands_has_module__: bool = False
def __init_subclass__(cls, *, name: str = MISSING, description: str = MISSING) -> None:
if not cls.__discord_app_commands_group_children__:
@ -712,6 +716,9 @@ class Group:
else:
cls.__discord_app_commands_group_description__ = description
if cls.__module__ != __name__:
cls.__discord_app_commands_has_module__ = True
def __init__(
self,
*,
@ -730,6 +737,16 @@ class Group:
raise TypeError('groups must have a description')
self.parent: Optional[Group] = parent
self.module: Optional[str]
if cls.__discord_app_commands_has_module__:
self.module = cls.__module__
else:
try:
# This is pretty hacky
# It allows the module to be fetched if someone just constructs a bare Group object though.
self.module = inspect.currentframe().f_back.f_globals['__name__'] # type: ignore
except (AttributeError, IndexError):
self.module = None
self._children: Dict[str, Union[Command, Group]] = {}
@ -745,6 +762,7 @@ class Group:
def __set_name__(self, owner: Type[Any], name: str) -> None:
self._attr = name
self.module = owner.__module__
def _copy_with_binding(self, binding: Union[Group, Cog]) -> Group:
cls = self.__class__

View File

@ -40,7 +40,7 @@ from .errors import (
)
from ..errors import ClientException
from ..enums import AppCommandType, InteractionType
from ..utils import MISSING, _get_as_snowflake
from ..utils import MISSING, _get_as_snowflake, _is_submodule
if TYPE_CHECKING:
from ..types.interactions import ApplicationCommandInteractionData, ApplicationCommandInteractionDataOption
@ -489,6 +489,32 @@ class CommandTree(Generic[ClientT]):
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id)
return base
def _remove_with_module(self, name: str) -> None:
remove: List[Any] = []
for key, cmd in self._context_menus.items():
if cmd.module is not None and _is_submodule(name, cmd.module):
remove.append(key)
for key in remove:
del self._context_menus[key]
remove = []
for key, cmd in self._global_commands.items():
if cmd.module is not None and _is_submodule(name, cmd.module):
remove.append(key)
for key in remove:
del self._global_commands[key]
for mapping in self._guild_commands.values():
remove = []
for key, cmd in mapping.items():
if cmd.module is not None and _is_submodule(name, cmd.module):
remove.append(key)
for key in remove:
del mapping[key]
async def on_error(
self,
interaction: Interaction,