diff --git a/discord/client.py b/discord/client.py index b6198d10..250e9da1 100644 --- a/discord/client.py +++ b/discord/client.py @@ -329,7 +329,7 @@ class Client: If this is not passed via ``__init__`` then this is retrieved through the gateway when an event contains the data. Usually after :func:`~discord.on_connect` is called. - + .. versionadded:: 2.0 """ return self._connection.application_id @@ -687,7 +687,7 @@ class Client: self._connection._activity = value.to_dict() # type: ignore else: raise TypeError('activity must derive from BaseActivity.') - + @property def status(self): """:class:`.Status`: @@ -758,7 +758,7 @@ class Client: This is useful if you have a channel_id but don't want to do an API call to send messages to it. - + .. versionadded:: 2.0 Parameters @@ -1604,7 +1604,7 @@ class Client: This method should be used for when a view is comprised of components that last longer than the lifecycle of the program. - + .. versionadded:: 2.0 Parameters @@ -1636,7 +1636,7 @@ class Client: @property def persistent_views(self) -> Sequence[View]: """Sequence[:class:`.View`]: A sequence of persistent views added to the client. - + .. versionadded:: 2.0 """ return self._connection.persistent_views diff --git a/discord/embeds.py b/discord/embeds.py index 7033a10e..d332b4e6 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -366,7 +366,7 @@ class Embed: self._footer['icon_url'] = str(icon_url) return self - + def remove_footer(self: E) -> E: """Clears embed's footer information. @@ -381,7 +381,7 @@ class Embed: pass return self - + @property def image(self) -> _EmbedMediaProxy: """Returns an ``EmbedProxy`` denoting the image contents. diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index b4da6100..92c00ea3 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -28,17 +28,23 @@ from __future__ import annotations import asyncio import collections import collections.abc + import inspect import importlib.util import sys import traceback import types -from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union +from typing import Any, Callable, cast, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union import discord +from discord.types.interactions import ( + ApplicationCommandInteractionData, + _ApplicationCommandInteractionDataOptionString +) from .core import GroupMixin -from .view import StringView +from .converter import Greedy +from .view import StringView, supported_quotes from .context import Context from . import errors from .help import HelpCommand, DefaultHelpCommand @@ -66,6 +72,13 @@ T = TypeVar('T') CFT = TypeVar('CFT', bound='CoroFunc') CXT = TypeVar('CXT', bound='Context') +class _FakeSlashMessage(discord.PartialMessage): + activity = application = edited_at = reference = webhook_id = None + attachments = components = reactions = stickers = [] + author: Union[discord.User, discord.Member] + tts = False + + def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: """A callable that implements a command prefix equivalent to being mentioned. @@ -120,9 +133,17 @@ class _DefaultRepr: _default = _DefaultRepr() class BotBase(GroupMixin): - def __init__(self, command_prefix, help_command=_default, description=None, **options): + def __init__(self, + command_prefix, + help_command=_default, + description=None, + message_commands: bool = True, + slash_commands: bool = False, **options + ): super().__init__(**options) self.command_prefix = command_prefix + self.slash_commands = slash_commands + self.message_commands = message_commands self.extra_events: Dict[str, List[CoroFunc]] = {} self.__cogs: Dict[str, Cog] = {} self.__extensions: Dict[str, types.ModuleType] = {} @@ -142,11 +163,17 @@ class BotBase(GroupMixin): if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection): raise TypeError(f'owner_ids must be a collection not {self.owner_ids.__class__!r}') + if not (message_commands or slash_commands): + raise TypeError("Both message_commands and slash_commands are disabled.") + elif slash_commands: + self.slash_command_guild = options['slash_command_guild'] + if help_command is _default: self.help_command = DefaultHelpCommand() else: self.help_command = help_command + # internal helpers def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None: @@ -1031,7 +1058,91 @@ class BotBase(GroupMixin): await self.invoke(ctx) async def on_message(self, message): - await self.process_commands(message) + if self.message_commands: + await self.process_commands(message) + + async def on_interaction(self, interaction: discord.Interaction): + if not self.slash_commands or interaction.type != discord.InteractionType.application_command: + return + + assert interaction.user is not None + + interaction.data = cast(ApplicationCommandInteractionData, interaction.data) + + # Ensure the interaction channel is usable + channel = interaction.channel + if channel is None or isinstance(channel, discord.PartialMessageable): + if interaction.guild is None: + channel = await interaction.user.create_dm() + elif interaction.channel_id is not None: + channel = await interaction.guild.fetch_channel(interaction.channel_id) + else: + return # cannot do anything without stable channel + + # Fetch out subcommands from the options + command_name = interaction.data['name'] + command_options = interaction.data.get('options') or [] + for option in command_options: + if option['type'] in {1, 2}: + command_name = option['name'] + command_options = option.get('options') or [] + + command_name += f'{command_name} ' + + command = self.get_command(command_name) + if command is None: + raise errors.CommandNotFound(f'Command "{command_name}" is not found') + + message: discord.Message = _FakeSlashMessage(id=interaction.id, channel=channel) # type: ignore + message.author = interaction.user + + # Fetch a valid prefix, so process_commands can function + prefix = await self.get_prefix(message) + if isinstance(prefix, list): + prefix = prefix[0] + + # Add arguments to fake message content, in the right order + message.content = f'{prefix}{command_name} ' + for name, param in command.clean_params.items(): + option = next((o for o in command_options if o['name'] == name), None) # type: ignore + print(name, param, option) + + if option is None: + if not command._is_typing_optional(param.annotation): + raise errors.MissingRequiredArgument(param) + elif ( + option["type"] == 3 + and " " in option["value"] # type: ignore + and param.kind != param.KEYWORD_ONLY + and not isinstance(param.annotation, Greedy) + ): + # String with space in without "consume rest" + option = cast(_ApplicationCommandInteractionDataOptionString, option) + + # we need to quote this string otherwise we may spill into + # other parameters and cause all kinds of trouble, as many + # quotes are supported and some may be in the option, we + # loop through all supported quotes and if neither open or + # close are in the string, we add them + quoted = False + string = option['value'] + for open, close in supported_quotes.items(): + if not (open in string or close in string): + message.content += f"{open}{string}{close} " + quoted = True + break + + # all supported quotes are in the message and we cannot add any + # safely, very unlikely but still got to be covered + if not quoted: + raise errors.UnexpectedQuoteError(string) + else: + message.content += f'{option.get("value", "")} ' + + ctx = await self.get_context(message) + ctx.interaction = interaction + await self.invoke(ctx) + class Bot(BotBase, discord.Client): """Represents a discord bot. @@ -1103,7 +1214,20 @@ class Bot(BotBase, discord.Client): .. versionadded:: 1.7 """ - pass + # Needs to be moved to somewhere else, preferably BotBase + async def login(self, token: str) -> None: + await super().login(token=token) + await self._ready_commands() + + async def _ready_commands(self): + if not self.slash_commands: + return + + application = self.application_id or (await self.application_info()).id + commands = [scmd for cmd in self.commands if (scmd := cmd.to_application_command()) is not None] + + await self.http.bulk_upsert_guild_commands(application, self.slash_command_guild, payload=commands) + class AutoShardedBot(BotBase, discord.AutoShardedClient): """This is similar to :class:`.Bot` except that it is inherited from diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 38a24d1d..a751c6bb 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -41,6 +41,7 @@ if TYPE_CHECKING: from discord.member import Member from discord.state import ConnectionState from discord.user import ClientUser, User + from discord.interactions import Interaction from discord.voice_client import VoiceProtocol from .bot import Bot, AutoShardedBot @@ -121,6 +122,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): A boolean that indicates if the command failed to be parsed, checked, or invoked. """ + interaction: Optional[Interaction] = None def __init__(self, *, diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 35b7e840..cb411f2c 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations + from typing import ( Any, Callable, @@ -44,6 +45,7 @@ import asyncio import functools import inspect import datetime +from operator import itemgetter import discord @@ -59,6 +61,7 @@ if TYPE_CHECKING: from typing_extensions import Concatenate, ParamSpec, TypeGuard from discord.message import Message + from discord.types.interactions import EditApplicationCommand from ._types import ( Coro, @@ -106,6 +109,16 @@ ContextT = TypeVar('ContextT', bound='Context') GroupT = TypeVar('GroupT', bound='Group') HookT = TypeVar('HookT', bound='Hook') ErrorT = TypeVar('ErrorT', bound='Error') +application_option_type_lookup = { + str: 3, + bool: 5, + int: 4, + (discord.Member, discord.User): 6, # Preferably discord.abc.User, but 'Protocols with non-method members don't support issubclass()' + (discord.abc.GuildChannel, discord.DMChannel): 7, + discord.Role: 8, + discord.Object: 9, + float: 10 +} if TYPE_CHECKING: P = ParamSpec('P') @@ -269,8 +282,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]): which calls converters. If ``False`` then cooldown processing is done first and then the converters are called second. Defaults to ``False``. extras: :class:`dict` - A dict of user provided extras to attach to the Command. - + A dict of user provided extras to attach to the Command. + .. note:: This object may be copied by the library. @@ -309,6 +322,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.callback = func self.enabled: bool = kwargs.get('enabled', True) + self.slash_command: Optional[bool] = kwargs.get("slash_command", None) + self.normal_command: Optional[bool] = kwargs.get("normal_command", None) help_doc = kwargs.get('help') if help_doc is not None: @@ -344,7 +359,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): cooldown = func.__commands_cooldown__ except AttributeError: cooldown = kwargs.get('cooldown') - + if cooldown is None: buckets = CooldownMapping(cooldown, BucketType.default) elif isinstance(cooldown, CooldownMapping): @@ -1098,7 +1113,13 @@ class Command(_BaseCommand, Generic[CogT, P, T]): A boolean indicating if the command can be invoked. """ - if not self.enabled: + if not self.enabled or ( + ctx.interaction is not None + and self.slash_command is False + ) or ( + ctx.interaction is None + and self.normal_command is False + ): raise DisabledCommand(f'{self.name} command is disabled') original = ctx.command @@ -1125,6 +1146,54 @@ class Command(_BaseCommand, Generic[CogT, P, T]): finally: ctx.command = original + def to_application_command(self) -> Optional[EditApplicationCommand]: + if self.slash_command is False: + return + + payload = { + "name": self.name, + "description": self.short_doc or "no description", + "options": [] + } + + option_descriptions = self.extras.get("option_descriptions", {}) + for name, param in self.clean_params.items(): + annotation: Type[Any] = param.annotation if param.annotation is not param.empty else str + origin = getattr(param.annotation, "__origin__", None) + + if origin is None and isinstance(annotation, Greedy): + annotation = annotation.converter + origin = Greedy + + option: Dict[str, Any] = { + "name": name, + "required": not self._is_typing_optional(annotation), + "description": option_descriptions.get(name, "no description"), + } + + if not option["required"] and origin is not None and len(annotation.__args__) == 2: + # Unpack Optional[T] (Union[T, None]) into just T + annotation, origin = annotation.__args__[0], None + + if origin is None: + option["type"] = next( + (num for t, num in application_option_type_lookup.items() + if issubclass(annotation, t)), str + ) + elif origin is Literal and len(origin.__args__) <= 25: # type: ignore + option["choices"] = [{ + "name": literal_value, + "value": literal_value + } for literal_value in origin.__args__] # type: ignore + else: + option["type"] = 3 # STRING + + payload["options"].append(option) + + # Now we have all options, make sure required is before optional. + payload["options"] = sorted(payload["options"], key=itemgetter("required"), reverse=True) + return payload # type: ignore + class GroupMixin(Generic[CogT]): """A mixin that implements common functionality for classes that behave similar to :class:`.Group` and are allowed to register commands. diff --git a/discord/ext/commands/view.py b/discord/ext/commands/view.py index a7dc7236..a613dbfe 100644 --- a/discord/ext/commands/view.py +++ b/discord/ext/commands/view.py @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError # map from opening quotes to closing quotes -_quotes = { +supported_quotes = { '"': '"', "‘": "’", "‚": "‛", @@ -44,7 +44,7 @@ _quotes = { "《": "》", "〈": "〉", } -_all_quotes = set(_quotes.keys()) | set(_quotes.values()) +_all_quotes = set(supported_quotes.keys()) | set(supported_quotes.values()) class StringView: def __init__(self, buffer): @@ -129,7 +129,7 @@ class StringView: if current is None: return None - close_quote = _quotes.get(current) + close_quote = supported_quotes.get(current) is_quoted = bool(close_quote) if is_quoted: result = [] diff --git a/discord/interactions.py b/discord/interactions.py index b89d49f5..f4849a2f 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -58,11 +58,11 @@ if TYPE_CHECKING: from aiohttp import ClientSession from .embeds import Embed from .ui.view import View - from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, PartialMessageable + from .channel import TextChannel, CategoryChannel, StoreChannel, PartialMessageable from .threads import Thread InteractionChannel = Union[ - VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable + TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable ] MISSING: Any = utils.MISSING @@ -179,7 +179,7 @@ class Interaction: type = ChannelType.text if self.guild_id is not None else ChannelType.private return PartialMessageable(state=self._state, id=self.channel_id, type=type) return None - return channel + return channel # type: ignore @property def permissions(self) -> Permissions: diff --git a/discord/opus.py b/discord/opus.py index 97d437a3..16bf1384 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -428,7 +428,7 @@ class Decoder(_OpusStruct): @overload def decode(self, data: bytes, *, fec: bool) -> bytes: ... - + @overload def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes: ... diff --git a/discord/template.py b/discord/template.py index 30af3a4d..449a0110 100644 --- a/discord/template.py +++ b/discord/template.py @@ -310,7 +310,7 @@ class Template: @property def url(self) -> str: """:class:`str`: The template url. - + .. versionadded:: 2.0 """ return f'https://discord.new/{self.code}' diff --git a/discord/threads.py b/discord/threads.py index 892910d9..c49e8f78 100644 --- a/discord/threads.py +++ b/discord/threads.py @@ -273,7 +273,7 @@ class Thread(Messageable, Hashable): if parent is None: raise ClientException('Parent channel not found') return parent.category - + @property def category_id(self) -> Optional[int]: """The category channel ID the parent channel belongs to, if applicable. diff --git a/discord/types/interactions.py b/discord/types/interactions.py index b0ce156b..74f58a17 100644 --- a/discord/types/interactions.py +++ b/discord/types/interactions.py @@ -229,8 +229,7 @@ class _EditApplicationCommandOptional(TypedDict, total=False): description: str options: Optional[List[ApplicationCommandOption]] type: ApplicationCommandType - + default_permission: bool class EditApplicationCommand(_EditApplicationCommandOptional): name: str - default_permission: bool diff --git a/discord/voice_client.py b/discord/voice_client.py index d382a74d..eba4f47c 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -66,7 +66,7 @@ if TYPE_CHECKING: VoiceServerUpdate as VoiceServerUpdatePayload, SupportedModes, ) - + has_nacl: bool diff --git a/docs/conf.py b/docs/conf.py index 03f69c19..5a03014c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -355,7 +355,7 @@ texinfo_documents = [ #texinfo_no_detailmenu = False def setup(app): - if app.config.language == 'ja': - app.config.intersphinx_mapping['py'] = ('https://docs.python.org/ja/3', None) - app.config.html_context['discord_invite'] = 'https://discord.gg/nXzj3dg' - app.config.resource_links['discord'] = 'https://discord.gg/nXzj3dg' + if app.config.language == 'ja': + app.config.intersphinx_mapping['py'] = ('https://docs.python.org/ja/3', None) + app.config.html_context['discord_invite'] = 'https://discord.gg/nXzj3dg' + app.config.resource_links['discord'] = 'https://discord.gg/nXzj3dg' diff --git a/docs/extensions/details.py b/docs/extensions/details.py index 96f39d5b..ba6f5b70 100644 --- a/docs/extensions/details.py +++ b/docs/extensions/details.py @@ -52,4 +52,3 @@ def setup(app): app.add_node(details, html=(visit_details_node, depart_details_node)) app.add_node(summary, html=(visit_summary_node, depart_summary_node)) app.add_directive('details', DetailsDirective) - diff --git a/docs/extensions/nitpick_file_ignorer.py b/docs/extensions/nitpick_file_ignorer.py index f5dff1d1..dda44c9c 100644 --- a/docs/extensions/nitpick_file_ignorer.py +++ b/docs/extensions/nitpick_file_ignorer.py @@ -5,7 +5,7 @@ from sphinx.util import logging as sphinx_logging class NitpickFileIgnorer(logging.Filter): - + def __init__(self, app: Sphinx) -> None: self.app = app super().__init__() diff --git a/examples/converters.py b/examples/converters.py index 9bd8ae06..1e5cf7e7 100644 --- a/examples/converters.py +++ b/examples/converters.py @@ -78,7 +78,7 @@ class ChannelOrMemberConverter(commands.Converter): async def notify(ctx: commands.Context, target: ChannelOrMemberConverter): # This command signature utilises the custom converter written above # What will happen during command invocation is that the `target` above will be passed to - # the `argument` parameter of the `ChannelOrMemberConverter.convert` method and + # the `argument` parameter of the `ChannelOrMemberConverter.convert` method and # the conversion will go through the process defined there. await target.send(f'Hello, {target.name}!') diff --git a/examples/custom_context.py b/examples/custom_context.py index d3a5b94b..e970c2b9 100644 --- a/examples/custom_context.py +++ b/examples/custom_context.py @@ -27,7 +27,7 @@ class MyBot(commands.Bot): # subclass to the super() method, which tells the bot to # use the new MyContext class return await super().get_context(message, cls=cls) - + bot = MyBot(command_prefix='!') @@ -43,7 +43,7 @@ async def guess(ctx, number: int): await ctx.tick(number == value) # IMPORTANT: You shouldn't hard code your token -# these are very important, and leaking them can +# these are very important, and leaking them can # let people do very malicious things with your # bot. Try to use a file or something to keep # them private, and don't commit it to GitHub diff --git a/examples/secret.py b/examples/secret.py index 9246c68f..a12e8978 100644 --- a/examples/secret.py +++ b/examples/secret.py @@ -5,7 +5,7 @@ from discord.ext import commands bot = commands.Bot(command_prefix=commands.when_mentioned, description="Nothing to see here!") -# the `hidden` keyword argument hides it from the help command. +# the `hidden` keyword argument hides it from the help command. @bot.group(hidden=True) async def secret(ctx: commands.Context): """What is this "secret" you speak of?""" @@ -13,7 +13,7 @@ async def secret(ctx: commands.Context): await ctx.send('Shh!', delete_after=5) def create_overwrites(ctx, *objects): - """This is just a helper function that creates the overwrites for the + """This is just a helper function that creates the overwrites for the voice/text channels. A `discord.PermissionOverwrite` allows you to determine the permissions @@ -45,10 +45,10 @@ def create_overwrites(ctx, *objects): @secret.command() @commands.guild_only() async def text(ctx: commands.Context, name: str, *objects: typing.Union[discord.Role, discord.Member]): - """This makes a text channel with a specified name + """This makes a text channel with a specified name that is only visible to roles or members that are specified. """ - + overwrites = create_overwrites(ctx, *objects) await ctx.guild.create_text_channel( diff --git a/examples/views/dropdown.py b/examples/views/dropdown.py index db6d699a..40606481 100644 --- a/examples/views/dropdown.py +++ b/examples/views/dropdown.py @@ -24,7 +24,7 @@ class Dropdown(discord.ui.Select): async def callback(self, interaction: discord.Interaction): # Use the interaction object to send a response message containing # the user's favourite colour or choice. The self object refers to the - # Select object, and the values attribute gets a list of the user's + # Select object, and the values attribute gets a list of the user's # selected options. We only want the first one. await interaction.response.send_message(f'Your favourite colour is {self.values[0]}') @@ -44,8 +44,8 @@ class Bot(commands.Bot): async def on_ready(self): print(f'Logged in as {self.user} (ID: {self.user.id})') print('------') - - + + bot = Bot() diff --git a/setup.py b/setup.py index 9ffd24ce..be000875 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ import re requirements = [] with open('requirements.txt') as f: - requirements = f.read().splitlines() + requirements = f.read().splitlines() version = '' with open('discord/__init__.py') as f: