diff --git a/discord/client.py b/discord/client.py index 40f82443..3d0216d8 100644 --- a/discord/client.py +++ b/discord/client.py @@ -624,7 +624,7 @@ class Client: async def start(self, token: str, *, reconnect: bool = True) -> None: """|coro| - A shorthand coroutine for :meth:`login` + :meth:`connect`. + A shorthand coroutine for :meth:`login` + :meth:`setup` + :meth:`connect`. Raises ------- @@ -632,8 +632,21 @@ class Client: An unexpected keyword argument was received. """ await self.login(token) + await self.setup() await self.connect(reconnect=reconnect) + async def setup(self) -> Any: + """|coro| + + A coroutine to be called to setup the bot, by default this is blank. + + To perform asynchronous setup after the bot is logged in but before + it has connected to the Websocket, overwrite this coroutine. + + .. versionadded:: 2.0 + """ + pass + def run(self, *args: Any, **kwargs: Any) -> None: """A blocking call that abstracts away the event loop initialisation from you. @@ -722,7 +735,7 @@ class Client: """:class:`.Status`: The status being used upon logging on to Discord. - .. versionadded: 2.0 + .. versionadded:: 2.0 """ if self._connection._status in set(state.value for state in Status): return Status(self._connection._status) diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 8a341f90..7abd6c8a 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -28,18 +28,43 @@ from __future__ import annotations import asyncio import collections import collections.abc + import inspect import importlib.util import sys import traceback import types -from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union +from collections import defaultdict +from discord.http import HTTPClient +from typing import ( + Any, + Callable, + Iterable, + Tuple, + cast, + Mapping, + List, + Dict, + TYPE_CHECKING, + Optional, + TypeVar, + Type, + Union, +) import discord +from discord.types.interactions import ( + ApplicationCommandInteractionData, + ApplicationCommandInteractionDataOption, + EditApplicationCommand, + _ApplicationCommandInteractionDataOptionString, +) from .core import GroupMixin -from .view import StringView +from .converter import Greedy +from .view import StringView, supported_quotes from .context import Context +from .flags import FlagConverter from . import errors from .help import HelpCommand, DefaultHelpCommand from .cog import Cog @@ -67,6 +92,23 @@ CFT = TypeVar("CFT", bound="CoroFunc") CXT = TypeVar("CXT", bound="Context") +class _FakeSlashMessage(discord.PartialMessage): + activity = application = edited_at = reference = webhook_id = None + attachments = components = reactions = stickers = mentions = [] + author: Union[discord.User, discord.Member] + tts = False + + @classmethod + def from_interaction( + cls, interaction: discord.Interaction, channel: Union[discord.TextChannel, discord.DMChannel, discord.Thread] + ): + self = cls(channel=channel, id=interaction.id) + assert interaction.user is not None + self.author = interaction.user + + return self + + def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: """A callable that implements a command prefix equivalent to being mentioned. @@ -118,6 +160,35 @@ def _is_submodule(parent: str, child: str) -> bool: return parent == child or child.startswith(parent + ".") +def _unwrap_slash_groups( + data: ApplicationCommandInteractionData, +) -> Tuple[str, List[ApplicationCommandInteractionDataOption]]: + command_name = data["name"] + command_options = data.get("options") or [] + while any(o["type"] in {1, 2} for o in command_options): # type: ignore + for option in command_options: # type: ignore + if option["type"] in {1, 2}: # type: ignore + command_name += f' {option["name"]}' # type: ignore + command_options = option.get("options") or [] + + return command_name, command_options + + +def _quote_string_safe(string: str) -> str: + # we need to quote this string otherwise we may spill into + # other parameters and cause all kinds of trouble, as many + # quotes are supported and some may be in the option, we + # loop through all supported quotes and if neither open or + # close are in the string, we add them + for open, close in supported_quotes.items(): + if open not in string and close not in string: + return f"{open}{string}{close}" + + # all supported quotes are in the message and we cannot add any + # safely, very unlikely but still got to be covered + raise errors.UnexpectedQuoteError(string) + + class _DefaultRepr: def __repr__(self): return "" @@ -127,9 +198,22 @@ _default = _DefaultRepr() class BotBase(GroupMixin): - def __init__(self, command_prefix, help_command=_default, description=None, *, intents: discord.Intents, **options): + def __init__( + self, + command_prefix, + help_command=_default, + description=None, + *, + intents: discord.Intents, + message_commands: bool = True, + slash_commands: bool = False, + **options, + ): super().__init__(**options, intents=intents) + self.command_prefix = command_prefix + self.slash_commands = slash_commands + self.message_commands = message_commands self.extra_events: Dict[str, List[CoroFunc]] = {} self.__cogs: Dict[str, Cog] = {} self.__extensions: Dict[str, types.ModuleType] = {} @@ -142,6 +226,7 @@ class BotBase(GroupMixin): self.owner_id = options.get("owner_id") self.owner_ids = options.get("owner_ids", set()) self.strip_after_prefix = options.get("strip_after_prefix", False) + self.slash_command_guilds: Optional[Iterable[int]] = options.get("slash_command_guilds", None) if self.owner_id and self.owner_ids: raise TypeError("Both owner_id and owner_ids are set.") @@ -149,6 +234,9 @@ class BotBase(GroupMixin): if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection): raise TypeError(f"owner_ids must be a collection not {self.owner_ids.__class__!r}") + if not (message_commands or slash_commands): + raise ValueError("Both message_commands and slash_commands are disabled.") + if help_command is _default: self.help_command = DefaultHelpCommand() else: @@ -163,6 +251,55 @@ class BotBase(GroupMixin): for event in self.extra_events.get(ev, []): self._schedule_event(event, ev, *args, **kwargs) # type: ignore + async def setup(self): + await self.create_slash_commands() + + async def create_slash_commands(self): + commands: defaultdict[Optional[int], List[EditApplicationCommand]] = defaultdict(list) + for command in self.commands: + if command.hidden or (command.slash_command is None and not self.slash_commands): + continue + + try: + payload = command.to_application_command() + except Exception: + raise errors.ApplicationCommandRegistrationError(command) + + if payload is None: + continue + + guilds = command.slash_command_guilds or self.slash_command_guilds + if guilds is None: + commands[None].append(payload) + else: + for guild in guilds: + commands[guild].append(payload) + + http: HTTPClient = self.http # type: ignore + global_commands = commands.pop(None, None) + application_id = self.application_id or (await self.application_info()).id # type: ignore + if global_commands is not None: + if self.slash_command_guilds is None: + await http.bulk_upsert_global_commands( + payload=global_commands, + application_id=application_id, + ) + else: + for guild in self.slash_command_guilds: + await http.bulk_upsert_guild_commands( + guild_id=guild, + payload=global_commands, + application_id=application_id, + ) + + for guild, guild_commands in commands.items(): + assert guild is not None + await http.bulk_upsert_guild_commands( + guild_id=guild, + payload=guild_commands, + application_id=application_id, + ) + @discord.utils.copy_doc(discord.Client.close) async def close(self) -> None: for extension in tuple(self.__extensions): @@ -1084,9 +1221,97 @@ class BotBase(GroupMixin): ctx = await self.get_context(message) await self.invoke(ctx) + async def process_slash_commands(self, interaction: discord.Interaction): + """|coro| + + This function processes a slash command interaction into a usable + message and calls :meth:`.process_commands` based on it. Without this + coroutine slash commands will not be triggered. + + By default, this coroutine is called inside the :func:`.on_interaction` + event. If you choose to override the :func:`.on_interaction` event, + then you should invoke this coroutine as well. + + .. versionadded:: 2.0 + + Parameters + ----------- + interaction: :class:`discord.Interaction` + The interaction to process slash commands for. + + """ + if interaction.type != discord.InteractionType.application_command: + return + + interaction.data = cast(ApplicationCommandInteractionData, interaction.data) + command_name, command_options = _unwrap_slash_groups(interaction.data) + + command = self.get_command(command_name) + if command is None: + raise errors.CommandNotFound(f'Command "{command_name}" is not found') + + # Ensure the interaction channel is usable + channel = interaction.channel + if channel is None or isinstance(channel, discord.PartialMessageable): + if interaction.guild is None: + assert interaction.user is not None + channel = await interaction.user.create_dm() + elif interaction.channel_id is not None: + channel = await interaction.guild.fetch_channel(interaction.channel_id) + else: + return # cannot do anything without stable channel + + # Fetch a valid prefix, so process_commands can function + message: discord.Message = _FakeSlashMessage.from_interaction(interaction, channel) # type: ignore + prefix = await self.get_prefix(message) + if isinstance(prefix, list): + prefix = prefix[0] + + # Add arguments to fake message content, in the right order + ignore_params: List[inspect.Parameter] = [] + message.content = f"{prefix}{command_name} " + for name, param in command.clean_params.items(): + if inspect.isclass(param.annotation) and issubclass(param.annotation, FlagConverter): + for name, flag in param.annotation.get_flags().items(): + option = next((o for o in command_options if o["name"] == name), None) + + if option is None: + if flag.required: + raise errors.MissingRequiredFlag(flag) + else: + prefix = param.annotation.__commands_flag_prefix__ + delimiter = param.annotation.__commands_flag_delimiter__ + message.content += f"{prefix}{name} {option['value']}{delimiter}" # type: ignore + continue + + option = next((o for o in command_options if o["name"] == name), None) + if option is None: + if param.default is param.empty and not command._is_typing_optional(param.annotation): + raise errors.MissingRequiredArgument(param) + else: + ignore_params.append(param) + elif ( + option["type"] == 3 + and not isinstance(param.annotation, Greedy) + and param.kind in {param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY} + ): + # String with space in without "consume rest" + option = cast(_ApplicationCommandInteractionDataOptionString, option) + message.content += f"{_quote_string_safe(option['value'])} " + else: + message.content += f'{option.get("value", "")} ' + + ctx = await self.get_context(message) + ctx._ignored_params = ignore_params + ctx.interaction = interaction + await self.invoke(ctx) + async def on_message(self, message): await self.process_commands(message) + async def on_interaction(self, interaction: discord.Interaction): + await self.process_slash_commands(interaction) + class Bot(BotBase, discord.Client): """Represents a discord bot. @@ -1157,6 +1382,28 @@ class Bot(BotBase, discord.Client): the ``command_prefix`` is set to ``!``. Defaults to ``False``. .. versionadded:: 1.7 + message_commands: Optional[:class:`bool`] + Whether to process commands based on messages. + + Can be overwritten per command in the command decorators or when making + a :class:`Command` object via the ``message_command`` parameter + + .. versionadded:: 2.0 + slash_commands: Optional[:class:`bool`] + Whether to upload and process slash commands. + + Can be overwritten per command in the command decorators or when making + a :class:`Command` object via the ``slash_command`` parameter + + .. versionadded:: 2.0 + slash_command_guilds: Optional[:class:`List[int]`] + If this is set, only upload slash commands to these guild IDs. + + Can be overwritten per command in the command decorators or when making + a :class:`Command` object via the ``slash_command_guilds`` parameter + + .. versionadded:: 2.0 + """ pass diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index a4135793..6bcce31b 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -25,8 +25,8 @@ from __future__ import annotations import inspect import re - -from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union +from datetime import timedelta +from typing import Any, Dict, Generic, List, Literal, Optional, TYPE_CHECKING, TypeVar, Union, overload import discord.abc import discord.utils @@ -42,6 +42,8 @@ if TYPE_CHECKING: from discord.member import Member from discord.state import ConnectionState from discord.user import ClientUser, User + from discord.webhook import WebhookMessage + from discord.interactions import Interaction from discord.voice_client import VoiceProtocol from .bot import Bot, AutoShardedBot @@ -120,6 +122,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): A boolean that indicates if the command failed to be parsed, checked, or invoked. """ + interaction: Optional[Interaction] = None def __init__( self, @@ -151,6 +154,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): self.subcommand_passed: Optional[str] = subcommand_passed self.command_failed: bool = command_failed self.current_parameter: Optional[inspect.Parameter] = current_parameter + self._ignored_params: List[inspect.Parameter] = [] self._state: ConnectionState = self.message._state async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: @@ -402,6 +406,97 @@ class Context(discord.abc.Messageable, Generic[BotT]): except CommandError as e: await cmd.on_help_command_error(self, e) + @overload + async def send( + self, + content: Optional[str] = None, + return_message: Literal[False] = False, + ephemeral: bool = False, + **kwargs: Any, + ) -> Optional[Union[Message, WebhookMessage]]: + ... + + @overload + async def send( + self, + content: Optional[str] = None, + return_message: Literal[True] = True, + ephemeral: bool = False, + **kwargs: Any, + ) -> Union[Message, WebhookMessage]: + ... + + async def send( + self, content: Optional[str] = None, return_message: bool = True, ephemeral: bool = False, **kwargs: Any + ) -> Optional[Union[Message, WebhookMessage]]: + """ + |coro| + + A shortcut method to :meth:`.abc.Messageable.send` with interaction helpers. + + This function takes all the parameters of :meth:`.abc.Messageable.send` plus the following: + + Parameters + ------------ + return_message: :class:`bool` + Ignored if not in a slash command context. + If this is set to False more native interaction methods will be used. + ephemeral: :class:`bool` + Ignored if not in a slash command context. + Indicates if the message should only be visible to the user who started the interaction. + If a view is sent with an ephemeral message and it has no timeout set then the timeout + is set to 15 minutes. + + Returns + -------- + Optional[Union[:class:`.Message`, :class:`.WebhookMessage`]] + In a slash command context, the message that was sent if return_message is True. + + In a normal context, it always returns a :class:`.Message` + """ + + if self.interaction is None or ( + self.interaction.response.responded_at is not None + and discord.utils.utcnow() - self.interaction.response.responded_at >= timedelta(minutes=15) + ): + return await super().send(content, **kwargs) + + # Remove unsupported arguments from kwargs + kwargs.pop("nonce", None) + kwargs.pop("stickers", None) + kwargs.pop("reference", None) + kwargs.pop("delete_after", None) + kwargs.pop("mention_author", None) + + if not ( + return_message + or self.interaction.response.is_done() + or any(arg in kwargs for arg in ("file", "files", "allowed_mentions")) + ): + send = self.interaction.response.send_message + else: + # We have to defer in order to use the followup webhook + if not self.interaction.response.is_done(): + await self.interaction.response.defer(ephemeral=ephemeral) + + send = self.interaction.followup.send + + return await send(content, ephemeral=ephemeral, **kwargs) # type: ignore + + @overload + async def reply( + self, content: Optional[str] = None, return_message: Literal[False] = False, **kwargs: Any + ) -> Optional[Union[Message, WebhookMessage]]: + ... + + @overload + async def reply( + self, content: Optional[str] = None, return_message: Literal[True] = True, **kwargs: Any + ) -> Union[Message, WebhookMessage]: + ... + @discord.utils.copy_doc(Message.reply) - async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message: - return await self.message.reply(content, **kwargs) + async def reply( + self, content: Optional[str] = None, return_message: bool = True, **kwargs: Any + ) -> Optional[Union[Message, WebhookMessage]]: + return await self.send(content, return_message=return_message, reference=self.message, **kwargs) # type: ignore diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index a4791b8f..8504f4a0 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -77,6 +77,7 @@ __all__ = ( "GuildStickerConverter", "clean_content", "Greedy", + "Option", "run_converters", ) @@ -96,6 +97,8 @@ T_co = TypeVar("T_co", covariant=True) CT = TypeVar("CT", bound=discord.abc.GuildChannel) TT = TypeVar("TT", bound=discord.Thread) +DT = TypeVar("DT", bound=str) + @runtime_checkable class Converter(Protocol[T_co]): @@ -583,7 +586,7 @@ class ThreadConverter(IDConverter[discord.Thread]): 2. Lookup by mention. 3. Lookup by name. - .. versionadded: 2.0 + .. versionadded:: 2.0 """ async def convert(self, ctx: Context, argument: str) -> discord.Thread: @@ -1005,6 +1008,27 @@ class Greedy(List[T]): return cls(converter=converter) +if TYPE_CHECKING: + + def Option(default: T = inspect.Parameter.empty, *, description: str) -> T: + ... + + +else: + + class Option(Generic[T, DT]): + description: DT + default: Union[T, inspect.Parameter.empty] + __slots__ = ( + "default", + "description", + ) + + def __init__(self, default: T = inspect.Parameter.empty, *, description: DT) -> None: + self.description = description + self.default = default + + def _convert_to_bool(argument: str) -> bool: lowered = argument.lower() if lowered in ("yes", "y", "true", "t", "1", "enable", "on"): diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 836b799a..fa1e4212 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -23,12 +23,14 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations + from typing import ( Any, Callable, Dict, Generator, Generic, + Iterable, Literal, List, Optional, @@ -38,27 +40,32 @@ from typing import ( TypeVar, Type, TYPE_CHECKING, + cast, overload, ) import asyncio import functools import inspect import datetime +from collections import defaultdict +from operator import itemgetter import discord from .errors import * from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping -from .converter import run_converters, get_converter, Greedy +from .converter import CONVERTER_MAPPING, Converter, run_converters, get_converter, Greedy, Option from ._types import _BaseCommand from .cog import Cog from .context import Context +from .flags import FlagConverter if TYPE_CHECKING: from typing_extensions import Concatenate, ParamSpec, TypeGuard from discord.message import Message + from discord.types.interactions import EditApplicationCommand, ApplicationCommandInteractionDataOption from ._types import ( Coro, @@ -107,6 +114,21 @@ GroupT = TypeVar("GroupT", bound="Group") HookT = TypeVar("HookT", bound="Hook") ErrorT = TypeVar("ErrorT", bound="Error") +REVERSED_CONVERTER_MAPPING = {v: k for k, v in CONVERTER_MAPPING.items()} +application_option_type_lookup = { + str: 3, + bool: 5, + int: 4, + ( + discord.Member, + discord.User, + ): 6, # Preferably discord.abc.User, but 'Protocols with non-method members don't support issubclass()' + (discord.abc.GuildChannel, discord.DMChannel): 7, + discord.Role: 8, + discord.Object: 9, + float: 10, +} + if TYPE_CHECKING: P = ParamSpec("P") else: @@ -124,13 +146,21 @@ def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: return function -def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, Any]) -> Dict[str, inspect.Parameter]: +def get_signature_parameters( + function: Callable[..., Any], globalns: Dict[str, Any] +) -> Tuple[Dict[str, inspect.Parameter], Dict[str, str]]: signature = inspect.signature(function) params = {} cache: Dict[str, Any] = {} + descriptions = defaultdict(lambda: "no description") eval_annotation = discord.utils.evaluate_annotation for name, parameter in signature.parameters.items(): annotation = parameter.annotation + if isinstance(parameter.default, Option): # type: ignore + option = parameter.default + descriptions[name] = option.description + parameter = parameter.replace(default=option.default) + if annotation is parameter.empty: params[name] = parameter continue @@ -144,7 +174,7 @@ def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, A params[name] = parameter.replace(annotation=annotation) - return params + return params, descriptions def wrap_callback(coro): @@ -276,9 +306,26 @@ class Command(_BaseCommand, Generic[CogT, P, T]): extras: :class:`dict` A dict of user provided extras to attach to the Command. + .. versionadded:: 2.0 + .. note:: This object may be copied by the library. + message_command: Optional[:class:`bool`] + Whether to process this command based on messages. + This overwrites the global ``message_commands`` parameter of :class:`.Bot`. + + .. versionadded:: 2.0 + slash_command: Optional[:class:`bool`] + Whether to upload and process this command as a slash command. + + This overwrites the global ``slash_commands`` parameter of :class:`.Bot`. + + .. versionadded:: 2.0 + slash_command_guilds: Optional[:class:`List[int]`] + If this is set, only upload this slash command to these guild IDs. + + This overwrites the global ``slash_command_guilds`` parameter of :class:`.Bot`. .. versionadded:: 2.0 """ @@ -319,6 +366,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.callback = func self.enabled: bool = kwargs.get("enabled", True) + self.slash_command: Optional[bool] = kwargs.get("slash_command", None) + self.message_command: Optional[bool] = kwargs.get("message_command", None) + self.slash_command_guilds: Optional[Iterable[int]] = kwargs.get("slash_command_guilds", None) + help_doc = kwargs.get("help") if help_doc is not None: help_doc = inspect.cleandoc(help_doc) @@ -377,6 +428,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]): # bandaid for the fact that sometimes parent can be the bot instance parent = kwargs.get("parent") self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore + if self.slash_command_guilds is not None and self.parent is not None: + raise ValueError( + "Cannot set specific guilds for a subcommand. They are inherited from the top level group." + ) self._before_invoke: Optional[Hook] = None try: @@ -417,7 +472,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): except AttributeError: globalns = {} - self.params = get_signature_parameters(function, globalns) + self.params, self.option_descriptions = get_signature_parameters(function, globalns) def add_check(self, func: Check) -> None: """Adds a check to the command. @@ -541,6 +596,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]): ctx.bot.dispatch("command_error", ctx, error) async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: + if param in ctx._ignored_params: + # in a slash command, we need a way to mark a param as default so ctx._ignored_params is used + return param.default if param.default is not param.empty else None + required = param.default is param.empty converter = get_converter(param) consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw @@ -1109,10 +1168,19 @@ class Command(_BaseCommand, Generic[CogT, P, T]): :class:`bool` A boolean indicating if the command can be invoked. """ - if not self.enabled: raise DisabledCommand(f"{self.name} command is disabled") + if ctx.interaction is None and ( + self.message_command is False or (self.message_command is None and not ctx.bot.message_commands) + ): + raise DisabledCommand(f"{self.name} command cannot be run as a message command") + + if ctx.interaction is not None and ( + self.slash_command is False or (self.slash_command is None and not ctx.bot.slash_commands) + ): + raise DisabledCommand(f"{self.name} command cannot be run as a slash command") + original = ctx.command ctx.command = self @@ -1137,6 +1205,90 @@ class Command(_BaseCommand, Generic[CogT, P, T]): finally: ctx.command = original + def _param_to_options( + self, name: str, annotation: Any, required: bool, varadic: bool + ) -> List[Optional[ApplicationCommandInteractionDataOption]]: + + origin = getattr(annotation, "__origin__", None) + if inspect.isclass(annotation) and issubclass(annotation, FlagConverter): + return [ + param + for name, flag in annotation.get_flags().items() + for param in self._param_to_options( + name, flag.annotation, required=flag.required, varadic=flag.annotation is tuple + ) + ] + + if varadic: + annotation = str + origin = None + + if not required and origin is not None and len(annotation.__args__) == 2: + # Unpack Optional[T] (Union[T, None]) into just T + annotation, origin = annotation.__args__[0], None + + option: Dict[str, Any] = { + "name": name, + "required": required, + "description": self.option_descriptions[name], + } + + if origin is None: + if not inspect.isclass(annotation): + annotation = type(annotation) + + if issubclass(annotation, Converter): + # If this is a converter, we want to check if it is a native + # one, in which we can get the original type, eg, (MemberConverter -> Member) + annotation = REVERSED_CONVERTER_MAPPING.get(annotation, annotation) + + option["type"] = 3 + for python_type, discord_type in application_option_type_lookup.items(): + if issubclass(annotation, python_type): + option["type"] = discord_type + break + + elif origin is Literal: + literal_values = annotation.__args__ + python_type = type(literal_values[0]) + if ( + all(type(value) == python_type for value in literal_values) + and python_type in application_option_type_lookup.keys() + ): + + option["type"] = application_option_type_lookup[python_type] + option["choices"] = [ + {"name": literal_value, "value": literal_value} for literal_value in annotation.__args__ + ] + + option.setdefault("type", 3) # STRING + return [option] # type: ignore + + def to_application_command(self, nested: int = 0) -> Optional[EditApplicationCommand]: + if self.slash_command is False: + return + elif nested == 3: + raise ApplicationCommandRegistrationError(self, f"{self.qualified_name} is too deeply nested!") + + payload = {"name": self.name, "description": self.short_doc or "no description", "options": []} + if nested != 0: + payload["type"] = 1 + + for name, param in self.clean_params.items(): + options = self._param_to_options( + name, + param.annotation if param.annotation is not param.empty else str, + varadic=param.kind == param.KEYWORD_ONLY or isinstance(param.annotation, Greedy), + required=(param.default is param.empty and not self._is_typing_optional(param.annotation)) + or param.kind == param.VAR_POSITIONAL, + ) + if options is not None: + payload["options"].extend(option for option in options if option is not None) + + # Now we have all options, make sure required is before optional. + payload["options"] = sorted(payload["options"], key=itemgetter("required"), reverse=True) + return payload # type: ignore + class GroupMixin(Generic[CogT]): """A mixin that implements common functionality for classes that behave @@ -1510,6 +1662,19 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]): view.previous = previous await super().reinvoke(ctx, call_hooks=call_hooks) + def to_application_command(self, nested: int = 0) -> Optional[EditApplicationCommand]: + if self.slash_command is False: + return + elif nested == 2: + raise ApplicationCommandRegistrationError(self, f"{self.qualified_name} is too deeply nested!") + + return { # type: ignore + "name": self.name, + "type": int(not (nested - 1)) + 1, + "description": self.short_doc or "no description", + "options": [cmd.to_application_command(nested=nested + 1) for cmd in self.commands], + } + # Decorators diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 7f6f5cb4..77aa4a8f 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -33,6 +33,7 @@ if TYPE_CHECKING: from .converter import Converter from .context import Context + from .core import Command from .cooldowns import Cooldown, BucketType from .flags import Flag from discord.abc import GuildChannel @@ -93,6 +94,7 @@ __all__ = ( "ExtensionFailed", "ExtensionNotFound", "CommandRegistrationError", + "ApplicationCommandRegistrationError", "FlagError", "BadFlagArgument", "MissingFlagArgument", @@ -1014,6 +1016,25 @@ class CommandRegistrationError(ClientException): super().__init__(f"The {type_} {name} is already an existing command or alias.") +class ApplicationCommandRegistrationError(ClientException): + """An exception raised when a command cannot be converted to an + application command. + + This inherits from :exc:`discord.ClientException` + + .. versionadded:: 2.0 + + Attributes + ---------- + command: :class:`Command` + The command that failed to be converted. + """ + + def __init__(self, command: Command, msg: str = None) -> None: + self.command = command + super().__init__(msg or f"{command.qualified_name} failed to converted to an application command.") + + class FlagError(BadArgument): """The base exception type for all flag parsing related errors. diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index 0630ea81..334d0fb1 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -615,7 +615,7 @@ class HelpCommand: :class:`.abc.Messageable` The destination where the help command will be output. """ - return self.context.channel + return self.context async def send_error_message(self, error): """|coro| @@ -977,6 +977,10 @@ class DefaultHelpCommand(HelpCommand): for page in self.paginator.pages: await destination.send(page) + interaction = self.context.interaction + if interaction is not None and destination == self.context.author and not interaction.response.is_done(): + await interaction.response.send_message("Sent help to your DMs!", ephemeral=True) + def add_command_formatting(self, command): """A utility function to format the non-indented block of commands and groups. @@ -1007,7 +1011,7 @@ class DefaultHelpCommand(HelpCommand): elif self.dm_help is None and len(self.paginator) > self.dm_help_threshold: return ctx.author else: - return ctx.channel + return ctx async def prepare_help_command(self, ctx, command): self.paginator.clear() diff --git a/discord/ext/commands/view.py b/discord/ext/commands/view.py index 9c503ac4..443f5bb6 100644 --- a/discord/ext/commands/view.py +++ b/discord/ext/commands/view.py @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError # map from opening quotes to closing quotes -_quotes = { +supported_quotes = { '"': '"', "‘": "’", "‚": "‛", @@ -44,7 +44,7 @@ _quotes = { "《": "》", "〈": "〉", } -_all_quotes = set(_quotes.keys()) | set(_quotes.values()) +_all_quotes = set(supported_quotes.keys()) | set(supported_quotes.values()) class StringView: @@ -130,7 +130,7 @@ class StringView: if current is None: return None - close_quote = _quotes.get(current) + close_quote = supported_quotes.get(current) is_quoted = bool(close_quote) if is_quoted: result = [] diff --git a/discord/interactions.py b/discord/interactions.py index 0a9c7383..83a61a3b 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -47,6 +47,8 @@ __all__ = ( ) if TYPE_CHECKING: + from datetime import datetime + from .types.interactions import ( Interaction as InteractionPayload, InteractionData, @@ -58,12 +60,10 @@ if TYPE_CHECKING: from aiohttp import ClientSession from .embeds import Embed from .ui.view import View - from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, PartialMessageable + from .channel import TextChannel, CategoryChannel, StoreChannel, PartialMessageable from .threads import Thread - InteractionChannel = Union[ - VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable - ] + InteractionChannel = Union[TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable] MISSING: Any = utils.MISSING @@ -179,7 +179,7 @@ class Interaction: type = ChannelType.text if self.guild_id is not None else ChannelType.private return PartialMessageable(state=self._state, id=self.channel_id, type=type) return None - return channel + return channel # type: ignore @property def permissions(self) -> Permissions: @@ -369,20 +369,20 @@ class InteractionResponse: """ __slots__: Tuple[str, ...] = ( - "_responded", + "responded_at", "_parent", ) def __init__(self, parent: Interaction): + self.responded_at: Optional[datetime] = None self._parent: Interaction = parent - self._responded: bool = False def is_done(self) -> bool: """:class:`bool`: Indicates whether an interaction response has been done before. An interaction can only be responded to once. """ - return self._responded + return self.responded_at is not None async def defer(self, *, ephemeral: bool = False) -> None: """|coro| @@ -405,7 +405,7 @@ class InteractionResponse: InteractionResponded This interaction has already been responded to before. """ - if self._responded: + if self.is_done(): raise InteractionResponded(self._parent) defer_type: int = 0 @@ -423,7 +423,8 @@ class InteractionResponse: await adapter.create_interaction_response( parent.id, parent.token, session=parent._session, type=defer_type, data=data ) - self._responded = True + + self.responded_at = utils.utcnow() async def pong(self) -> None: """|coro| @@ -439,7 +440,7 @@ class InteractionResponse: InteractionResponded This interaction has already been responded to before. """ - if self._responded: + if self.is_done(): raise InteractionResponded(self._parent) parent = self._parent @@ -448,7 +449,7 @@ class InteractionResponse: await adapter.create_interaction_response( parent.id, parent.token, session=parent._session, type=InteractionResponseType.pong.value ) - self._responded = True + self.responded_at = utils.utcnow() async def send_message( self, @@ -494,7 +495,7 @@ class InteractionResponse: InteractionResponded This interaction has already been responded to before. """ - if self._responded: + if self.is_done(): raise InteractionResponded(self._parent) payload: Dict[str, Any] = { @@ -537,7 +538,7 @@ class InteractionResponse: self._parent._state.store_view(view) - self._responded = True + self.responded_at = utils.utcnow() async def edit_message( self, @@ -578,7 +579,7 @@ class InteractionResponse: InteractionResponded This interaction has already been responded to before. """ - if self._responded: + if self.is_done(): raise InteractionResponded(self._parent) parent = self._parent @@ -629,7 +630,7 @@ class InteractionResponse: if view and not view.is_finished(): state.store_view(view, message_id) - self._responded = True + self.responded_at = utils.utcnow() class _InteractionMessageState: diff --git a/discord/types/interactions.py b/discord/types/interactions.py index 652d1902..1c58af67 100644 --- a/discord/types/interactions.py +++ b/discord/types/interactions.py @@ -227,8 +227,8 @@ class _EditApplicationCommandOptional(TypedDict, total=False): description: str options: Optional[List[ApplicationCommandOption]] type: ApplicationCommandType + default_permission: bool class EditApplicationCommand(_EditApplicationCommandOptional): name: str - default_permission: bool diff --git a/discord/ui/view.py b/discord/ui/view.py index 34a28d85..87899e0c 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -353,7 +353,7 @@ class View: return await item.callback(interaction) - if not interaction.response._responded: + if not interaction.response.is_done(): await interaction.response.defer() except Exception as e: return await self.on_error(e, item, interaction) diff --git a/docs/ext/commands/commands.rst b/docs/ext/commands/commands.rst index f205be2a..0f73f050 100644 --- a/docs/ext/commands/commands.rst +++ b/docs/ext/commands/commands.rst @@ -61,6 +61,13 @@ the name to something other than the function would be as simple as doing this: async def _list(ctx, arg): pass +Slash Commands +-------------- +Slash Commands can be enabled in the :class:`.Bot` constructor or :class:`.Command` constructor, using +``slash_commands=True`` or ``slash_command=True`` respectfully. All features of the commands extension +should work with these options enabled, however many will not have direct discord counterparts and therefore +will be subsituted for supported versions when uploaded to discord. + Parameters ------------ @@ -179,6 +186,11 @@ know how the command was executed. It contains a lot of useful information: The context implements the :class:`abc.Messageable` interface, so anything you can do on a :class:`abc.Messageable` you can do on the :class:`~ext.commands.Context`. +.. warning:: + :attr:`.Context.message` will be fake if in a slash command, it is not + recommended to access if :attr:`.Context.interaction` is not None as most + methods will error due to the message not actually existing. + Converters ------------ diff --git a/docs/whats_new.rst b/docs/whats_new.rst index bdbf1d75..bc283d22 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -43,7 +43,7 @@ Breaking Changes - :attr:`GroupChannel.owner` is now Optional - ``edit`` methods now only accept None if it actually means something (e.g. clearing it) - ``timeout`` parameter for ``ui.View.__init__`` is now keyword only -- When an interaction has already been responded and another one is sent, :exc:`InteractionResponded`is now raised. +- When an interaction has already been responded and another one is sent, :exc:`InteractionResponded` is now raised. - Discord's API only allows a single :attr:`interaction.response`. - Separate :func:`on_member_update` and :func:`on_presence_update` - The new event :func:`on_presence_update` is now called when status/activity is changed. diff --git a/examples/slash_commands.py b/examples/slash_commands.py new file mode 100644 index 00000000..79025b40 --- /dev/null +++ b/examples/slash_commands.py @@ -0,0 +1,41 @@ +import discord +from discord.ext import commands + +# Set slash commands=True when constructing your bot to enable all slash commands +# if your bot is only for a couple of servers, you can use the parameter +# `slash_command_guilds=[list, of, guild, ids]` to specify this, +# then the commands will be much faster to upload. +bot = commands.Bot("!", intents=discord.Intents(guilds=True, messages=True), slash_commands=True) + + +@bot.event +async def on_ready(): + print(f"Logged in as {bot.user} (ID: {bot.user.id})") + print("------") + + +@bot.command() +# You can use commands.Option to define descriptions for your options, and converters will still work fine. +async def ping( + ctx: commands.Context, emoji: bool = commands.Option(description="whether to use an emoji when responding") +): + # This command can be used with slash commands or message commands + if emoji: + await ctx.send("\U0001f3d3") + else: + await ctx.send("Pong!") + + +@bot.command(message_command=False) +async def only_slash(ctx: commands.Context): + # This command can only be used with slash commands + await ctx.send("Hello from slash commands!") + + +@bot.command(slash_command=False) +async def only_message(ctx: commands.Context): + # This command can only be used with a message + await ctx.send("Hello from message commands!") + + +bot.run("token")