[commands] Rework help command to avoid a deepcopy on invoke

This commit is contained in:
Josh
2022-03-19 20:34:19 +10:00
committed by GitHub
parent 94f4da9248
commit fafc5b13f6
2 changed files with 99 additions and 193 deletions

View File

@@ -354,6 +354,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
"""
from .core import Group, Command, wrap_callback
from .errors import CommandError
from .help import _context
bot = self.bot
cmd = bot.help_command
@@ -361,8 +362,8 @@ class Context(discord.abc.Messageable, Generic[BotT]):
if cmd is None:
return None
cmd = cmd.copy()
cmd.context = self # type: ignore
_context.set(self)
if len(args) == 0:
await cmd.prepare_help_command(self, None)
mapping = cmd.get_bot_mapping()

View File

@@ -24,8 +24,8 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from contextvars import ContextVar
import itertools
import copy
import functools
import re
@@ -33,12 +33,12 @@ from typing import (
TYPE_CHECKING,
Optional,
Generator,
Generic,
List,
TypeVar,
Callable,
Any,
Dict,
Tuple,
Iterable,
Sequence,
Mapping,
@@ -50,7 +50,6 @@ from .core import Group, Command, get_signature_parameters
from .errors import CommandError
if TYPE_CHECKING:
from typing_extensions import Self
import inspect
import discord.abc
@@ -59,13 +58,6 @@ if TYPE_CHECKING:
from .context import Context
from .cog import Cog
from ._types import (
Check,
ContextT,
BotT,
_Bot,
)
__all__ = (
'Paginator',
'HelpCommand',
@@ -73,7 +65,11 @@ __all__ = (
'MinimalHelpCommand',
)
T = TypeVar('T')
ContextT = TypeVar('ContextT', bound='Context')
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
HelpCommandCommand = Command[Optional['Cog'], ... if TYPE_CHECKING else Any, Any]
MISSING: Any = discord.utils.MISSING
@@ -219,92 +215,12 @@ def _not_overridden(f: FuncT) -> FuncT:
return f
class _HelpCommandImpl(Command):
def __init__(self, inject: HelpCommand, *args: Any, **kwargs: Any) -> None:
super().__init__(inject.command_callback, *args, **kwargs)
self._original: HelpCommand = inject
self._injected: HelpCommand = inject
self.params: Dict[str, inspect.Parameter] = get_signature_parameters(
inject.command_callback, globals(), skip_parameters=1
)
async def prepare(self, ctx: Context[Any]) -> None:
self._injected = injected = self._original.copy()
injected.context = ctx
self.callback = injected.command_callback
self.params = get_signature_parameters(injected.command_callback, globals(), skip_parameters=1)
on_error = injected.on_help_command_error
if not hasattr(on_error, '__help_command_not_overridden__'):
if self.cog is not None:
self.on_error = self._on_error_cog_implementation
else:
self.on_error = on_error
await super().prepare(ctx)
async def _parse_arguments(self, ctx: Context[BotT]) -> None:
# Make the parser think we don't have a cog so it doesn't
# inject the parameter into `ctx.args`.
original_cog = self.cog
self.cog = None
try:
await super()._parse_arguments(ctx)
finally:
self.cog = original_cog
async def _on_error_cog_implementation(self, _, ctx: Context[BotT], error: CommandError) -> None:
await self._injected.on_help_command_error(ctx, error)
def _inject_into_cog(self, cog: Cog) -> None:
# Warning: hacky
# Make the cog think that get_commands returns this command
# as well if we inject it without modifying __cog_commands__
# since that's used for the injection and ejection of cogs.
def wrapped_get_commands(
*, _original: Callable[[], List[Command[Any, ..., Any]]] = cog.get_commands
) -> List[Command[Any, ..., Any]]:
ret = _original()
ret.append(self)
return ret
# Ditto here
def wrapped_walk_commands(
*, _original: Callable[[], Generator[Command[Any, ..., Any], None, None]] = cog.walk_commands
):
yield from _original()
yield self
functools.update_wrapper(wrapped_get_commands, cog.get_commands)
functools.update_wrapper(wrapped_walk_commands, cog.walk_commands)
cog.get_commands = wrapped_get_commands
cog.walk_commands = wrapped_walk_commands
self.cog = cog
def _eject_cog(self) -> None:
if self.cog is None:
return
# revert back into their original methods
cog = self.cog
cog.get_commands = cog.get_commands.__wrapped__
cog.walk_commands = cog.walk_commands.__wrapped__
self.cog = None
_context: ContextVar[Optional[Context]] = ContextVar('context', default=None)
class HelpCommand:
class HelpCommand(HelpCommandCommand, Generic[ContextT]):
r"""The base implementation for help command formatting.
.. note::
Internally instances of this class are deep copied every time
the command itself is invoked to prevent a race condition
mentioned in :issue:`2123`.
This means that relying on the state of this class to be
the same between command invocations would not work as expected.
Attributes
------------
context: Optional[:class:`Context`]
@@ -336,88 +252,53 @@ class HelpCommand:
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys()))
if TYPE_CHECKING:
__original_kwargs__: Dict[str, Any]
__original_args__: Tuple[Any, ...]
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
# To prevent race conditions of a single instance while also allowing
# for settings to be passed the original arguments passed must be assigned
# to allow for easier copies (which will be made when the help command is actually called)
# see issue 2123
self = super().__new__(cls)
# Shallow copies cannot be used in this case since it is not unusual to pass
# instances that need state, e.g. Paginator or what have you into the function
# The keys can be safely copied as-is since they're 99.99% certain of being
# string keys
deepcopy = copy.deepcopy
self.__original_kwargs__ = {k: deepcopy(v) for k, v in kwargs.items()}
self.__original_args__ = deepcopy(args)
return self
def __init__(self, **options: Any) -> None:
self.show_hidden: bool = options.pop('show_hidden', False)
self.verify_checks: bool = options.pop('verify_checks', True)
self.command_attrs: Dict[str, Any]
self.command_attrs = attrs = options.pop('command_attrs', {})
def __init__(
self,
*,
show_hidden: bool = False,
verify_checks: bool = True,
command_attrs: Dict[str, Any] = MISSING,
) -> None:
self.show_hidden: bool = show_hidden
self.verify_checks: bool = verify_checks
self.command_attrs = attrs = command_attrs if command_attrs is not MISSING else {}
attrs.setdefault('name', 'help')
attrs.setdefault('help', 'Shows this message')
self.context: Context[_Bot] = MISSING
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
self._cog: Optional[Cog] = None
super().__init__(self._set_context, **attrs)
self.params: Dict[str, inspect.Parameter] = get_signature_parameters(
self.command_callback, globals(), skip_parameters=1
)
def copy(self) -> Self:
obj = self.__class__(*self.__original_args__, **self.__original_kwargs__)
obj._command_impl = self._command_impl
return obj
async def __call__(self, context: ContextT, *args: Any, **kwargs: Any) -> Any:
return await self.command_callback(context, *args, **kwargs)
async def _set_context(self, context: ContextT, *args: Any, **kwargs: Any) -> Any:
_context.set(context)
return await self.command_callback(context, *args, **kwargs)
@property
def context(self) -> ContextT:
ctx = _context.get()
if ctx is None:
raise AttributeError('context attribute cannot be accessed in non command-invocation contexts.')
return ctx # type: ignore
def _add_to_bot(self, bot: BotBase) -> None:
command = _HelpCommandImpl(self, **self.command_attrs)
bot.add_command(command)
self._command_impl = command
bot.add_command(self) # type: ignore
def _remove_from_bot(self, bot: BotBase) -> None:
bot.remove_command(self._command_impl.name)
self._command_impl._eject_cog()
bot.remove_command(self) # type: ignore
self._eject_cog()
def add_check(self, func: Check[ContextT], /) -> None:
"""
Adds a check to the help command.
.. versionadded:: 1.4
.. versionchanged:: 2.0
``func`` parameter is now positional-only.
Parameters
----------
func
The function that will be used as a check.
"""
self._command_impl.add_check(func)
def remove_check(self, func: Check[ContextT], /) -> None:
"""
Removes a check from the help command.
This function is idempotent and will not raise an exception if
the function is not in the command's checks.
.. versionadded:: 1.4
.. versionchanged:: 2.0
``func`` parameter is now positional-only.
Parameters
----------
func
The function to remove from the checks.
"""
self._command_impl.remove_check(func)
async def invoke(self, ctx: ContextT) -> None:
# we need to temporarily set the cog to None to prevent the cog
# from being passed into the command callback.
cog = self._cog
self._cog = None
await self.prepare(ctx)
self._cog = cog
await self.callback(*ctx.args, **ctx.kwargs)
def get_bot_mapping(self) -> Dict[Optional[Cog], List[Command[Any, ..., Any]]]:
"""Retrieves the bot mapping passed to :meth:`send_bot_help`."""
@@ -441,7 +322,7 @@ class HelpCommand:
Optional[:class:`str`]
The command name that triggered this invocation.
"""
command_name = self._command_impl.name
command_name = self.name
ctx = self.context
if ctx is MISSING or ctx.command is None or ctx.command.qualified_name != command_name:
return command_name
@@ -498,31 +379,54 @@ class HelpCommand:
return self.MENTION_PATTERN.sub(replace, string)
def _inject_into_cog(self, cog: Cog) -> None:
# Warning: hacky
# Make the cog think that get_commands returns this command
# as well if we inject it without modifying __cog_commands__
# since that's used for the injection and ejection of cogs.
def wrapped_get_commands(
*, _original: Callable[[], List[Command[Any, ..., Any]]] = cog.get_commands
) -> List[Command[Any, ..., Any]]:
ret = _original()
ret.append(self)
return ret
# Ditto here
def wrapped_walk_commands(
*, _original: Callable[[], Generator[Command[Any, ..., Any], None, None]] = cog.walk_commands
):
yield from _original()
yield self
functools.update_wrapper(wrapped_get_commands, cog.get_commands)
functools.update_wrapper(wrapped_walk_commands, cog.walk_commands)
cog.get_commands = wrapped_get_commands
cog.walk_commands = wrapped_walk_commands
self._cog = cog
def _eject_cog(self) -> None:
if self._cog is None:
return
# revert back into their original methods
cog = self._cog
cog.get_commands = cog.get_commands.__wrapped__
cog.walk_commands = cog.walk_commands.__wrapped__
self._cog = None
@property
def cog(self) -> Optional[Cog]:
"""A property for retrieving or setting the cog for the help command.
When a cog is set for the help command, it is as-if the help command
belongs to that cog. All cog special methods will apply to the help
command and it will be automatically unset on unload.
To unbind the cog from the help command, you can set it to ``None``.
Returns
--------
Optional[:class:`Cog`]
The cog that is currently set for the help command.
"""
return self._command_impl.cog
return self._cog
@cog.setter
def cog(self, cog: Optional[Cog]) -> None:
# Remove whatever cog is currently valid, if any
self._command_impl._eject_cog()
self._eject_cog()
# If a new cog is set then inject it.
if cog is not None:
self._command_impl._inject_into_cog(cog)
self._inject_into_cog(cog)
def command_not_found(self, string: str) -> str:
"""|maybecoro|
@@ -693,7 +597,7 @@ class HelpCommand:
await destination.send(error)
@_not_overridden
async def on_help_command_error(self, ctx: Context[BotT], error: CommandError) -> None:
async def on_help_command_error(self, ctx: ContextT, error: CommandError) -> None:
"""|coro|
The help command's error handler, as specified by :ref:`ext_commands_error_handler`.
@@ -836,7 +740,7 @@ class HelpCommand:
"""
return None
async def prepare_help_command(self, ctx: Context[BotT], command: Optional[str] = None) -> None:
async def prepare_help_command(self, ctx: ContextT, command: Optional[str] = None) -> None:
"""|coro|
A low level method that can be used to prepare the help command
@@ -860,7 +764,7 @@ class HelpCommand:
"""
pass
async def command_callback(self, ctx: Context[BotT], *, command: Optional[str] = None) -> None:
async def command_callback(self, ctx: ContextT, *, command: Optional[str] = None) -> Any:
"""|coro|
The actual implementation of the help command.
@@ -880,6 +784,7 @@ class HelpCommand:
- :meth:`prepare_help_command`
"""
await self.prepare_help_command(ctx, command)
bot = ctx.bot
if command is None:
@@ -905,7 +810,7 @@ class HelpCommand:
for key in keys[1:]:
try:
found = cmd.all_commands.get(key) # type: ignore
found = cmd.all_commands.get(key)
except AttributeError:
string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key))
return await self.send_error_message(string)
@@ -921,7 +826,7 @@ class HelpCommand:
return await self.send_command_help(cmd)
class DefaultHelpCommand(HelpCommand):
class DefaultHelpCommand(HelpCommand[ContextT]):
"""The implementation of the default help command.
This inherits from :class:`HelpCommand`.
@@ -1062,7 +967,7 @@ class DefaultHelpCommand(HelpCommand):
else:
return ctx.channel
async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None:
async def prepare_help_command(self, ctx: ContextT, command: str) -> None:
self.paginator.clear()
await super().prepare_help_command(ctx, command)
@@ -1130,7 +1035,7 @@ class DefaultHelpCommand(HelpCommand):
await self.send_pages()
class MinimalHelpCommand(HelpCommand):
class MinimalHelpCommand(HelpCommand[ContextT]):
"""An implementation of a help command with minimal output.
This inherits from :class:`HelpCommand`.
@@ -1306,7 +1211,7 @@ class MinimalHelpCommand(HelpCommand):
else:
return ctx.channel
async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None:
async def prepare_help_command(self, ctx: ContextT, command: str) -> None:
self.paginator.clear()
await super().prepare_help_command(ctx, command)