Most slash command support completed, needs some debugging (and reindent)

This commit is contained in:
Gnome
2021-08-30 16:14:44 +01:00
parent 45d498c1b7
commit a19e43675f
20 changed files with 238 additions and 45 deletions

View File

@@ -28,17 +28,23 @@ from __future__ import annotations
import asyncio
import collections
import collections.abc
import inspect
import importlib.util
import sys
import traceback
import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
from typing import Any, Callable, cast, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
import discord
from discord.types.interactions import (
ApplicationCommandInteractionData,
_ApplicationCommandInteractionDataOptionString
)
from .core import GroupMixin
from .view import StringView
from .converter import Greedy
from .view import StringView, supported_quotes
from .context import Context
from . import errors
from .help import HelpCommand, DefaultHelpCommand
@@ -66,6 +72,13 @@ T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context')
class _FakeSlashMessage(discord.PartialMessage):
activity = application = edited_at = reference = webhook_id = None
attachments = components = reactions = stickers = []
author: Union[discord.User, discord.Member]
tts = False
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned.
@@ -120,9 +133,17 @@ class _DefaultRepr:
_default = _DefaultRepr()
class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, **options):
def __init__(self,
command_prefix,
help_command=_default,
description=None,
message_commands: bool = True,
slash_commands: bool = False, **options
):
super().__init__(**options)
self.command_prefix = command_prefix
self.slash_commands = slash_commands
self.message_commands = message_commands
self.extra_events: Dict[str, List[CoroFunc]] = {}
self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {}
@@ -142,11 +163,17 @@ class BotBase(GroupMixin):
if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection):
raise TypeError(f'owner_ids must be a collection not {self.owner_ids.__class__!r}')
if not (message_commands or slash_commands):
raise TypeError("Both message_commands and slash_commands are disabled.")
elif slash_commands:
self.slash_command_guild = options['slash_command_guild']
if help_command is _default:
self.help_command = DefaultHelpCommand()
else:
self.help_command = help_command
# internal helpers
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
@@ -1031,7 +1058,91 @@ class BotBase(GroupMixin):
await self.invoke(ctx)
async def on_message(self, message):
await self.process_commands(message)
if self.message_commands:
await self.process_commands(message)
async def on_interaction(self, interaction: discord.Interaction):
if not self.slash_commands or interaction.type != discord.InteractionType.application_command:
return
assert interaction.user is not None
interaction.data = cast(ApplicationCommandInteractionData, interaction.data)
# Ensure the interaction channel is usable
channel = interaction.channel
if channel is None or isinstance(channel, discord.PartialMessageable):
if interaction.guild is None:
channel = await interaction.user.create_dm()
elif interaction.channel_id is not None:
channel = await interaction.guild.fetch_channel(interaction.channel_id)
else:
return # cannot do anything without stable channel
# Fetch out subcommands from the options
command_name = interaction.data['name']
command_options = interaction.data.get('options') or []
for option in command_options:
if option['type'] in {1, 2}:
command_name = option['name']
command_options = option.get('options') or []
command_name += f'{command_name} '
command = self.get_command(command_name)
if command is None:
raise errors.CommandNotFound(f'Command "{command_name}" is not found')
message: discord.Message = _FakeSlashMessage(id=interaction.id, channel=channel) # type: ignore
message.author = interaction.user
# Fetch a valid prefix, so process_commands can function
prefix = await self.get_prefix(message)
if isinstance(prefix, list):
prefix = prefix[0]
# Add arguments to fake message content, in the right order
message.content = f'{prefix}{command_name} '
for name, param in command.clean_params.items():
option = next((o for o in command_options if o['name'] == name), None) # type: ignore
print(name, param, option)
if option is None:
if not command._is_typing_optional(param.annotation):
raise errors.MissingRequiredArgument(param)
elif (
option["type"] == 3
and " " in option["value"] # type: ignore
and param.kind != param.KEYWORD_ONLY
and not isinstance(param.annotation, Greedy)
):
# String with space in without "consume rest"
option = cast(_ApplicationCommandInteractionDataOptionString, option)
# we need to quote this string otherwise we may spill into
# other parameters and cause all kinds of trouble, as many
# quotes are supported and some may be in the option, we
# loop through all supported quotes and if neither open or
# close are in the string, we add them
quoted = False
string = option['value']
for open, close in supported_quotes.items():
if not (open in string or close in string):
message.content += f"{open}{string}{close} "
quoted = True
break
# all supported quotes are in the message and we cannot add any
# safely, very unlikely but still got to be covered
if not quoted:
raise errors.UnexpectedQuoteError(string)
else:
message.content += f'{option.get("value", "")} '
ctx = await self.get_context(message)
ctx.interaction = interaction
await self.invoke(ctx)
class Bot(BotBase, discord.Client):
"""Represents a discord bot.
@@ -1103,7 +1214,20 @@ class Bot(BotBase, discord.Client):
.. versionadded:: 1.7
"""
pass
# Needs to be moved to somewhere else, preferably BotBase
async def login(self, token: str) -> None:
await super().login(token=token)
await self._ready_commands()
async def _ready_commands(self):
if not self.slash_commands:
return
application = self.application_id or (await self.application_info()).id
commands = [scmd for cmd in self.commands if (scmd := cmd.to_application_command()) is not None]
await self.http.bulk_upsert_guild_commands(application, self.slash_command_guild, payload=commands)
class AutoShardedBot(BotBase, discord.AutoShardedClient):
"""This is similar to :class:`.Bot` except that it is inherited from

View File

@@ -41,6 +41,7 @@ if TYPE_CHECKING:
from discord.member import Member
from discord.state import ConnectionState
from discord.user import ClientUser, User
from discord.interactions import Interaction
from discord.voice_client import VoiceProtocol
from .bot import Bot, AutoShardedBot
@@ -121,6 +122,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
A boolean that indicates if the command failed to be parsed, checked,
or invoked.
"""
interaction: Optional[Interaction] = None
def __init__(self,
*,

View File

@@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import (
Any,
Callable,
@@ -44,6 +45,7 @@ import asyncio
import functools
import inspect
import datetime
from operator import itemgetter
import discord
@@ -59,6 +61,7 @@ if TYPE_CHECKING:
from typing_extensions import Concatenate, ParamSpec, TypeGuard
from discord.message import Message
from discord.types.interactions import EditApplicationCommand
from ._types import (
Coro,
@@ -106,6 +109,16 @@ ContextT = TypeVar('ContextT', bound='Context')
GroupT = TypeVar('GroupT', bound='Group')
HookT = TypeVar('HookT', bound='Hook')
ErrorT = TypeVar('ErrorT', bound='Error')
application_option_type_lookup = {
str: 3,
bool: 5,
int: 4,
(discord.Member, discord.User): 6, # Preferably discord.abc.User, but 'Protocols with non-method members don't support issubclass()'
(discord.abc.GuildChannel, discord.DMChannel): 7,
discord.Role: 8,
discord.Object: 9,
float: 10
}
if TYPE_CHECKING:
P = ParamSpec('P')
@@ -269,8 +282,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
which calls converters. If ``False`` then cooldown processing is done
first and then the converters are called second. Defaults to ``False``.
extras: :class:`dict`
A dict of user provided extras to attach to the Command.
A dict of user provided extras to attach to the Command.
.. note::
This object may be copied by the library.
@@ -309,6 +322,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.callback = func
self.enabled: bool = kwargs.get('enabled', True)
self.slash_command: Optional[bool] = kwargs.get("slash_command", None)
self.normal_command: Optional[bool] = kwargs.get("normal_command", None)
help_doc = kwargs.get('help')
if help_doc is not None:
@@ -344,7 +359,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
cooldown = func.__commands_cooldown__
except AttributeError:
cooldown = kwargs.get('cooldown')
if cooldown is None:
buckets = CooldownMapping(cooldown, BucketType.default)
elif isinstance(cooldown, CooldownMapping):
@@ -1098,7 +1113,13 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
A boolean indicating if the command can be invoked.
"""
if not self.enabled:
if not self.enabled or (
ctx.interaction is not None
and self.slash_command is False
) or (
ctx.interaction is None
and self.normal_command is False
):
raise DisabledCommand(f'{self.name} command is disabled')
original = ctx.command
@@ -1125,6 +1146,54 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
finally:
ctx.command = original
def to_application_command(self) -> Optional[EditApplicationCommand]:
if self.slash_command is False:
return
payload = {
"name": self.name,
"description": self.short_doc or "no description",
"options": []
}
option_descriptions = self.extras.get("option_descriptions", {})
for name, param in self.clean_params.items():
annotation: Type[Any] = param.annotation if param.annotation is not param.empty else str
origin = getattr(param.annotation, "__origin__", None)
if origin is None and isinstance(annotation, Greedy):
annotation = annotation.converter
origin = Greedy
option: Dict[str, Any] = {
"name": name,
"required": not self._is_typing_optional(annotation),
"description": option_descriptions.get(name, "no description"),
}
if not option["required"] and origin is not None and len(annotation.__args__) == 2:
# Unpack Optional[T] (Union[T, None]) into just T
annotation, origin = annotation.__args__[0], None
if origin is None:
option["type"] = next(
(num for t, num in application_option_type_lookup.items()
if issubclass(annotation, t)), str
)
elif origin is Literal and len(origin.__args__) <= 25: # type: ignore
option["choices"] = [{
"name": literal_value,
"value": literal_value
} for literal_value in origin.__args__] # type: ignore
else:
option["type"] = 3 # STRING
payload["options"].append(option)
# Now we have all options, make sure required is before optional.
payload["options"] = sorted(payload["options"], key=itemgetter("required"), reverse=True)
return payload # type: ignore
class GroupMixin(Generic[CogT]):
"""A mixin that implements common functionality for classes that behave
similar to :class:`.Group` and are allowed to register commands.

View File

@@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
# map from opening quotes to closing quotes
_quotes = {
supported_quotes = {
'"': '"',
"": "",
"": "",
@@ -44,7 +44,7 @@ _quotes = {
"": "",
"": "",
}
_all_quotes = set(_quotes.keys()) | set(_quotes.values())
_all_quotes = set(supported_quotes.keys()) | set(supported_quotes.values())
class StringView:
def __init__(self, buffer):
@@ -129,7 +129,7 @@ class StringView:
if current is None:
return None
close_quote = _quotes.get(current)
close_quote = supported_quotes.get(current)
is_quoted = bool(close_quote)
if is_quoted:
result = []