[commands] Rework help command to avoid a deepcopy on invoke

This commit is contained in:
Josh 2022-03-19 20:34:19 +10:00 committed by GitHub
parent 94f4da9248
commit fafc5b13f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 99 additions and 193 deletions

View File

@ -354,6 +354,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
""" """
from .core import Group, Command, wrap_callback from .core import Group, Command, wrap_callback
from .errors import CommandError from .errors import CommandError
from .help import _context
bot = self.bot bot = self.bot
cmd = bot.help_command cmd = bot.help_command
@ -361,8 +362,8 @@ class Context(discord.abc.Messageable, Generic[BotT]):
if cmd is None: if cmd is None:
return None return None
cmd = cmd.copy() _context.set(self)
cmd.context = self # type: ignore
if len(args) == 0: if len(args) == 0:
await cmd.prepare_help_command(self, None) await cmd.prepare_help_command(self, None)
mapping = cmd.get_bot_mapping() mapping = cmd.get_bot_mapping()

View File

@ -24,8 +24,8 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
from contextvars import ContextVar
import itertools import itertools
import copy
import functools import functools
import re import re
@ -33,12 +33,12 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Optional, Optional,
Generator, Generator,
Generic,
List, List,
TypeVar, TypeVar,
Callable, Callable,
Any, Any,
Dict, Dict,
Tuple,
Iterable, Iterable,
Sequence, Sequence,
Mapping, Mapping,
@ -50,7 +50,6 @@ from .core import Group, Command, get_signature_parameters
from .errors import CommandError from .errors import CommandError
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
import inspect import inspect
import discord.abc import discord.abc
@ -59,13 +58,6 @@ if TYPE_CHECKING:
from .context import Context from .context import Context
from .cog import Cog from .cog import Cog
from ._types import (
Check,
ContextT,
BotT,
_Bot,
)
__all__ = ( __all__ = (
'Paginator', 'Paginator',
'HelpCommand', 'HelpCommand',
@ -73,7 +65,11 @@ __all__ = (
'MinimalHelpCommand', 'MinimalHelpCommand',
) )
T = TypeVar('T')
ContextT = TypeVar('ContextT', bound='Context')
FuncT = TypeVar('FuncT', bound=Callable[..., Any]) FuncT = TypeVar('FuncT', bound=Callable[..., Any])
HelpCommandCommand = Command[Optional['Cog'], ... if TYPE_CHECKING else Any, Any]
MISSING: Any = discord.utils.MISSING MISSING: Any = discord.utils.MISSING
@ -219,92 +215,12 @@ def _not_overridden(f: FuncT) -> FuncT:
return f return f
class _HelpCommandImpl(Command): _context: ContextVar[Optional[Context]] = ContextVar('context', default=None)
def __init__(self, inject: HelpCommand, *args: Any, **kwargs: Any) -> None:
super().__init__(inject.command_callback, *args, **kwargs)
self._original: HelpCommand = inject
self._injected: HelpCommand = inject
self.params: Dict[str, inspect.Parameter] = get_signature_parameters(
inject.command_callback, globals(), skip_parameters=1
)
async def prepare(self, ctx: Context[Any]) -> None:
self._injected = injected = self._original.copy()
injected.context = ctx
self.callback = injected.command_callback
self.params = get_signature_parameters(injected.command_callback, globals(), skip_parameters=1)
on_error = injected.on_help_command_error
if not hasattr(on_error, '__help_command_not_overridden__'):
if self.cog is not None:
self.on_error = self._on_error_cog_implementation
else:
self.on_error = on_error
await super().prepare(ctx)
async def _parse_arguments(self, ctx: Context[BotT]) -> None:
# Make the parser think we don't have a cog so it doesn't
# inject the parameter into `ctx.args`.
original_cog = self.cog
self.cog = None
try:
await super()._parse_arguments(ctx)
finally:
self.cog = original_cog
async def _on_error_cog_implementation(self, _, ctx: Context[BotT], error: CommandError) -> None:
await self._injected.on_help_command_error(ctx, error)
def _inject_into_cog(self, cog: Cog) -> None:
# Warning: hacky
# Make the cog think that get_commands returns this command
# as well if we inject it without modifying __cog_commands__
# since that's used for the injection and ejection of cogs.
def wrapped_get_commands(
*, _original: Callable[[], List[Command[Any, ..., Any]]] = cog.get_commands
) -> List[Command[Any, ..., Any]]:
ret = _original()
ret.append(self)
return ret
# Ditto here
def wrapped_walk_commands(
*, _original: Callable[[], Generator[Command[Any, ..., Any], None, None]] = cog.walk_commands
):
yield from _original()
yield self
functools.update_wrapper(wrapped_get_commands, cog.get_commands)
functools.update_wrapper(wrapped_walk_commands, cog.walk_commands)
cog.get_commands = wrapped_get_commands
cog.walk_commands = wrapped_walk_commands
self.cog = cog
def _eject_cog(self) -> None:
if self.cog is None:
return
# revert back into their original methods
cog = self.cog
cog.get_commands = cog.get_commands.__wrapped__
cog.walk_commands = cog.walk_commands.__wrapped__
self.cog = None
class HelpCommand: class HelpCommand(HelpCommandCommand, Generic[ContextT]):
r"""The base implementation for help command formatting. r"""The base implementation for help command formatting.
.. note::
Internally instances of this class are deep copied every time
the command itself is invoked to prevent a race condition
mentioned in :issue:`2123`.
This means that relying on the state of this class to be
the same between command invocations would not work as expected.
Attributes Attributes
------------ ------------
context: Optional[:class:`Context`] context: Optional[:class:`Context`]
@ -336,88 +252,53 @@ class HelpCommand:
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys())) MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys()))
if TYPE_CHECKING: def __init__(
__original_kwargs__: Dict[str, Any] self,
__original_args__: Tuple[Any, ...] *,
show_hidden: bool = False,
def __new__(cls, *args: Any, **kwargs: Any) -> Self: verify_checks: bool = True,
# To prevent race conditions of a single instance while also allowing command_attrs: Dict[str, Any] = MISSING,
# for settings to be passed the original arguments passed must be assigned ) -> None:
# to allow for easier copies (which will be made when the help command is actually called) self.show_hidden: bool = show_hidden
# see issue 2123 self.verify_checks: bool = verify_checks
self = super().__new__(cls) self.command_attrs = attrs = command_attrs if command_attrs is not MISSING else {}
# Shallow copies cannot be used in this case since it is not unusual to pass
# instances that need state, e.g. Paginator or what have you into the function
# The keys can be safely copied as-is since they're 99.99% certain of being
# string keys
deepcopy = copy.deepcopy
self.__original_kwargs__ = {k: deepcopy(v) for k, v in kwargs.items()}
self.__original_args__ = deepcopy(args)
return self
def __init__(self, **options: Any) -> None:
self.show_hidden: bool = options.pop('show_hidden', False)
self.verify_checks: bool = options.pop('verify_checks', True)
self.command_attrs: Dict[str, Any]
self.command_attrs = attrs = options.pop('command_attrs', {})
attrs.setdefault('name', 'help') attrs.setdefault('name', 'help')
attrs.setdefault('help', 'Shows this message') attrs.setdefault('help', 'Shows this message')
self.context: Context[_Bot] = MISSING self._cog: Optional[Cog] = None
self._command_impl = _HelpCommandImpl(self, **self.command_attrs) super().__init__(self._set_context, **attrs)
self.params: Dict[str, inspect.Parameter] = get_signature_parameters(
self.command_callback, globals(), skip_parameters=1
)
def copy(self) -> Self: async def __call__(self, context: ContextT, *args: Any, **kwargs: Any) -> Any:
obj = self.__class__(*self.__original_args__, **self.__original_kwargs__) return await self.command_callback(context, *args, **kwargs)
obj._command_impl = self._command_impl
return obj async def _set_context(self, context: ContextT, *args: Any, **kwargs: Any) -> Any:
_context.set(context)
return await self.command_callback(context, *args, **kwargs)
@property
def context(self) -> ContextT:
ctx = _context.get()
if ctx is None:
raise AttributeError('context attribute cannot be accessed in non command-invocation contexts.')
return ctx # type: ignore
def _add_to_bot(self, bot: BotBase) -> None: def _add_to_bot(self, bot: BotBase) -> None:
command = _HelpCommandImpl(self, **self.command_attrs) bot.add_command(self) # type: ignore
bot.add_command(command)
self._command_impl = command
def _remove_from_bot(self, bot: BotBase) -> None: def _remove_from_bot(self, bot: BotBase) -> None:
bot.remove_command(self._command_impl.name) bot.remove_command(self) # type: ignore
self._command_impl._eject_cog() self._eject_cog()
def add_check(self, func: Check[ContextT], /) -> None: async def invoke(self, ctx: ContextT) -> None:
""" # we need to temporarily set the cog to None to prevent the cog
Adds a check to the help command. # from being passed into the command callback.
cog = self._cog
.. versionadded:: 1.4 self._cog = None
await self.prepare(ctx)
.. versionchanged:: 2.0 self._cog = cog
await self.callback(*ctx.args, **ctx.kwargs)
``func`` parameter is now positional-only.
Parameters
----------
func
The function that will be used as a check.
"""
self._command_impl.add_check(func)
def remove_check(self, func: Check[ContextT], /) -> None:
"""
Removes a check from the help command.
This function is idempotent and will not raise an exception if
the function is not in the command's checks.
.. versionadded:: 1.4
.. versionchanged:: 2.0
``func`` parameter is now positional-only.
Parameters
----------
func
The function to remove from the checks.
"""
self._command_impl.remove_check(func)
def get_bot_mapping(self) -> Dict[Optional[Cog], List[Command[Any, ..., Any]]]: def get_bot_mapping(self) -> Dict[Optional[Cog], List[Command[Any, ..., Any]]]:
"""Retrieves the bot mapping passed to :meth:`send_bot_help`.""" """Retrieves the bot mapping passed to :meth:`send_bot_help`."""
@ -441,7 +322,7 @@ class HelpCommand:
Optional[:class:`str`] Optional[:class:`str`]
The command name that triggered this invocation. The command name that triggered this invocation.
""" """
command_name = self._command_impl.name command_name = self.name
ctx = self.context ctx = self.context
if ctx is MISSING or ctx.command is None or ctx.command.qualified_name != command_name: if ctx is MISSING or ctx.command is None or ctx.command.qualified_name != command_name:
return command_name return command_name
@ -498,31 +379,54 @@ class HelpCommand:
return self.MENTION_PATTERN.sub(replace, string) return self.MENTION_PATTERN.sub(replace, string)
def _inject_into_cog(self, cog: Cog) -> None:
# Warning: hacky
# Make the cog think that get_commands returns this command
# as well if we inject it without modifying __cog_commands__
# since that's used for the injection and ejection of cogs.
def wrapped_get_commands(
*, _original: Callable[[], List[Command[Any, ..., Any]]] = cog.get_commands
) -> List[Command[Any, ..., Any]]:
ret = _original()
ret.append(self)
return ret
# Ditto here
def wrapped_walk_commands(
*, _original: Callable[[], Generator[Command[Any, ..., Any], None, None]] = cog.walk_commands
):
yield from _original()
yield self
functools.update_wrapper(wrapped_get_commands, cog.get_commands)
functools.update_wrapper(wrapped_walk_commands, cog.walk_commands)
cog.get_commands = wrapped_get_commands
cog.walk_commands = wrapped_walk_commands
self._cog = cog
def _eject_cog(self) -> None:
if self._cog is None:
return
# revert back into their original methods
cog = self._cog
cog.get_commands = cog.get_commands.__wrapped__
cog.walk_commands = cog.walk_commands.__wrapped__
self._cog = None
@property @property
def cog(self) -> Optional[Cog]: def cog(self) -> Optional[Cog]:
"""A property for retrieving or setting the cog for the help command. return self._cog
When a cog is set for the help command, it is as-if the help command
belongs to that cog. All cog special methods will apply to the help
command and it will be automatically unset on unload.
To unbind the cog from the help command, you can set it to ``None``.
Returns
--------
Optional[:class:`Cog`]
The cog that is currently set for the help command.
"""
return self._command_impl.cog
@cog.setter @cog.setter
def cog(self, cog: Optional[Cog]) -> None: def cog(self, cog: Optional[Cog]) -> None:
# Remove whatever cog is currently valid, if any # Remove whatever cog is currently valid, if any
self._command_impl._eject_cog() self._eject_cog()
# If a new cog is set then inject it. # If a new cog is set then inject it.
if cog is not None: if cog is not None:
self._command_impl._inject_into_cog(cog) self._inject_into_cog(cog)
def command_not_found(self, string: str) -> str: def command_not_found(self, string: str) -> str:
"""|maybecoro| """|maybecoro|
@ -693,7 +597,7 @@ class HelpCommand:
await destination.send(error) await destination.send(error)
@_not_overridden @_not_overridden
async def on_help_command_error(self, ctx: Context[BotT], error: CommandError) -> None: async def on_help_command_error(self, ctx: ContextT, error: CommandError) -> None:
"""|coro| """|coro|
The help command's error handler, as specified by :ref:`ext_commands_error_handler`. The help command's error handler, as specified by :ref:`ext_commands_error_handler`.
@ -836,7 +740,7 @@ class HelpCommand:
""" """
return None return None
async def prepare_help_command(self, ctx: Context[BotT], command: Optional[str] = None) -> None: async def prepare_help_command(self, ctx: ContextT, command: Optional[str] = None) -> None:
"""|coro| """|coro|
A low level method that can be used to prepare the help command A low level method that can be used to prepare the help command
@ -860,7 +764,7 @@ class HelpCommand:
""" """
pass pass
async def command_callback(self, ctx: Context[BotT], *, command: Optional[str] = None) -> None: async def command_callback(self, ctx: ContextT, *, command: Optional[str] = None) -> Any:
"""|coro| """|coro|
The actual implementation of the help command. The actual implementation of the help command.
@ -880,6 +784,7 @@ class HelpCommand:
- :meth:`prepare_help_command` - :meth:`prepare_help_command`
""" """
await self.prepare_help_command(ctx, command) await self.prepare_help_command(ctx, command)
bot = ctx.bot bot = ctx.bot
if command is None: if command is None:
@ -905,7 +810,7 @@ class HelpCommand:
for key in keys[1:]: for key in keys[1:]:
try: try:
found = cmd.all_commands.get(key) # type: ignore found = cmd.all_commands.get(key)
except AttributeError: except AttributeError:
string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key)) string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key))
return await self.send_error_message(string) return await self.send_error_message(string)
@ -921,7 +826,7 @@ class HelpCommand:
return await self.send_command_help(cmd) return await self.send_command_help(cmd)
class DefaultHelpCommand(HelpCommand): class DefaultHelpCommand(HelpCommand[ContextT]):
"""The implementation of the default help command. """The implementation of the default help command.
This inherits from :class:`HelpCommand`. This inherits from :class:`HelpCommand`.
@ -1062,7 +967,7 @@ class DefaultHelpCommand(HelpCommand):
else: else:
return ctx.channel return ctx.channel
async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None: async def prepare_help_command(self, ctx: ContextT, command: str) -> None:
self.paginator.clear() self.paginator.clear()
await super().prepare_help_command(ctx, command) await super().prepare_help_command(ctx, command)
@ -1130,7 +1035,7 @@ class DefaultHelpCommand(HelpCommand):
await self.send_pages() await self.send_pages()
class MinimalHelpCommand(HelpCommand): class MinimalHelpCommand(HelpCommand[ContextT]):
"""An implementation of a help command with minimal output. """An implementation of a help command with minimal output.
This inherits from :class:`HelpCommand`. This inherits from :class:`HelpCommand`.
@ -1306,7 +1211,7 @@ class MinimalHelpCommand(HelpCommand):
else: else:
return ctx.channel return ctx.channel
async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None: async def prepare_help_command(self, ctx: ContextT, command: str) -> None:
self.paginator.clear() self.paginator.clear()
await super().prepare_help_command(ctx, command) await super().prepare_help_command(ctx, command)