Add support for default_values field on selects

This commit is contained in:
Soheab_
2023-09-29 21:55:20 +02:00
committed by GitHub
parent 9f8f9bf56b
commit c5ecc42c72
6 changed files with 342 additions and 7 deletions

View File

@ -22,21 +22,42 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Any, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Callable, Union, Dict, overload
from typing import (
Any,
List,
Literal,
Optional,
TYPE_CHECKING,
Tuple,
Type,
TypeVar,
Callable,
Union,
Dict,
overload,
Sequence,
)
from contextvars import ContextVar
import inspect
import os
from .item import Item, ItemCallbackType
from ..enums import ChannelType, ComponentType
from ..enums import ChannelType, ComponentType, SelectDefaultValueType
from ..partial_emoji import PartialEmoji
from ..emoji import Emoji
from ..utils import MISSING
from ..components import (
SelectOption,
SelectMenu,
SelectDefaultValue,
)
from ..app_commands.namespace import Namespace
from ..member import Member
from ..object import Object
from ..role import Role
from ..user import User
from ..abc import GuildChannel
from ..threads import Thread
__all__ = (
'Select',
@ -54,9 +75,6 @@ if TYPE_CHECKING:
from ..types.components import SelectMenu as SelectMenuPayload
from ..types.interactions import SelectMessageComponentInteractionData
from ..app_commands import AppCommandChannel, AppCommandThread
from ..member import Member
from ..role import Role
from ..user import User
from ..interactions import Interaction
ValidSelectType: TypeAlias = Literal[
@ -69,6 +87,17 @@ if TYPE_CHECKING:
PossibleValue: TypeAlias = Union[
str, User, Member, Role, AppCommandChannel, AppCommandThread, Union[Role, Member], Union[Role, User]
]
ValidDefaultValues: TypeAlias = Union[
SelectDefaultValue,
Object,
Role,
Member,
User,
GuildChannel,
AppCommandChannel,
AppCommandThread,
Thread,
]
V = TypeVar('V', bound='View', covariant=True)
BaseSelectT = TypeVar('BaseSelectT', bound='BaseSelect[Any]')
@ -82,6 +111,73 @@ SelectCallbackDecorator: TypeAlias = Callable[[ItemCallbackType[V, BaseSelectT]]
selected_values: ContextVar[Dict[str, List[PossibleValue]]] = ContextVar('selected_values')
def _handle_select_defaults(
defaults: Sequence[ValidDefaultValues],
component_type: Literal[
ComponentType.user_select,
ComponentType.role_select,
ComponentType.channel_select,
ComponentType.mentionable_select,
],
) -> List[SelectDefaultValue]:
if not defaults or defaults is MISSING:
return []
from ..app_commands import AppCommandChannel, AppCommandThread
cls_to_type: Dict[Type[ValidDefaultValues], SelectDefaultValueType] = {
User: SelectDefaultValueType.user,
Member: SelectDefaultValueType.user,
Role: SelectDefaultValueType.role,
GuildChannel: SelectDefaultValueType.channel,
AppCommandChannel: SelectDefaultValueType.channel,
AppCommandThread: SelectDefaultValueType.channel,
Thread: SelectDefaultValueType.channel,
}
type_to_supported_classes: Dict[ValidSelectType, Tuple[Type[ValidDefaultValues], ...]] = {
ComponentType.user_select: (User, Member, Object),
ComponentType.role_select: (Role, Object),
ComponentType.channel_select: (GuildChannel, AppCommandChannel, AppCommandThread, Thread, Object),
ComponentType.mentionable_select: (User, Member, Role, Object),
}
values: List[SelectDefaultValue] = []
for obj in defaults:
if isinstance(obj, SelectDefaultValue):
values.append(obj)
continue
object_type = obj.__class__ if not isinstance(obj, Object) else obj.type
if object_type not in type_to_supported_classes[component_type]:
# TODO: split this into a util function
supported_classes = [c.__name__ for c in type_to_supported_classes[component_type]]
if len(supported_classes) > 2:
supported_classes = ', '.join(supported_classes[:-1]) + f', or {supported_classes[-1]}'
elif len(supported_classes) == 2:
supported_classes = f'{supported_classes[0]} or {supported_classes[1]}'
else:
supported_classes = supported_classes[0]
raise TypeError(f'Expected an instance of {supported_classes} not {object_type.__name__}')
if object_type is Object:
if component_type is ComponentType.mentionable_select:
raise ValueError(
'Object must have a type specified for the chosen select type. Please pass one using the `type`` kwarg.'
)
elif component_type is ComponentType.user_select:
object_type = User
elif component_type is ComponentType.role_select:
object_type = Role
elif component_type is ComponentType.channel_select:
object_type = GuildChannel
values.append(SelectDefaultValue(id=obj.id, type=cls_to_type[object_type]))
return values
class BaseSelect(Item[V]):
"""The base Select model that all other Select models inherit from.
@ -128,6 +224,7 @@ class BaseSelect(Item[V]):
disabled: bool = False,
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = MISSING,
default_values: Sequence[SelectDefaultValue] = MISSING,
) -> None:
super().__init__()
self._provided_custom_id = custom_id is not MISSING
@ -144,6 +241,7 @@ class BaseSelect(Item[V]):
disabled=disabled,
channel_types=[] if channel_types is MISSING else channel_types,
options=[] if options is MISSING else options,
default_values=[] if default_values is MISSING else default_values,
)
self.row = row
@ -410,6 +508,10 @@ class UserSelect(BaseSelect[V]):
Defaults to 1 and must be between 1 and 25.
disabled: :class:`bool`
Whether the select is disabled or not.
default_values: Sequence[:class:`~discord.abc.Snowflake`]
A list of objects representing the users that should be selected by default.
.. versionadded:: 2.4
row: Optional[:class:`int`]
The relative row this select menu belongs to. A Discord component can only have 5
rows. By default, items are arranged automatically into those 5 rows. If you'd
@ -418,6 +520,8 @@ class UserSelect(BaseSelect[V]):
ordering. The row number must be between 0 and 4 (i.e. zero indexed).
"""
__item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ('default_values',)
def __init__(
self,
*,
@ -427,6 +531,7 @@ class UserSelect(BaseSelect[V]):
max_values: int = 1,
disabled: bool = False,
row: Optional[int] = None,
default_values: Sequence[ValidDefaultValues] = MISSING,
) -> None:
super().__init__(
self.type,
@ -436,6 +541,7 @@ class UserSelect(BaseSelect[V]):
max_values=max_values,
disabled=disabled,
row=row,
default_values=_handle_select_defaults(default_values, self.type),
)
@property
@ -456,6 +562,18 @@ class UserSelect(BaseSelect[V]):
"""
return super().values # type: ignore
@property
def default_values(self) -> List[SelectDefaultValue]:
"""List[:class:`discord.SelectDefaultValue`]: A list of default values for the select menu.
.. versionadded:: 2.4
"""
return self._underlying.default_values
@default_values.setter
def default_values(self, value: Sequence[ValidDefaultValues]) -> None:
self._underlying.default_values = _handle_select_defaults(value, self.type)
class RoleSelect(BaseSelect[V]):
"""Represents a UI select menu with a list of predefined options with the current roles of the guild.
@ -479,6 +597,10 @@ class RoleSelect(BaseSelect[V]):
Defaults to 1 and must be between 1 and 25.
disabled: :class:`bool`
Whether the select is disabled or not.
default_values: Sequence[:class:`~discord.abc.Snowflake`]
A list of objects representing the users that should be selected by default.
.. versionadded:: 2.4
row: Optional[:class:`int`]
The relative row this select menu belongs to. A Discord component can only have 5
rows. By default, items are arranged automatically into those 5 rows. If you'd
@ -487,6 +609,8 @@ class RoleSelect(BaseSelect[V]):
ordering. The row number must be between 0 and 4 (i.e. zero indexed).
"""
__item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ('default_values',)
def __init__(
self,
*,
@ -496,6 +620,7 @@ class RoleSelect(BaseSelect[V]):
max_values: int = 1,
disabled: bool = False,
row: Optional[int] = None,
default_values: Sequence[ValidDefaultValues] = MISSING,
) -> None:
super().__init__(
self.type,
@ -505,6 +630,7 @@ class RoleSelect(BaseSelect[V]):
max_values=max_values,
disabled=disabled,
row=row,
default_values=_handle_select_defaults(default_values, self.type),
)
@property
@ -517,6 +643,18 @@ class RoleSelect(BaseSelect[V]):
"""List[:class:`discord.Role`]: A list of roles that have been selected by the user."""
return super().values # type: ignore
@property
def default_values(self) -> List[SelectDefaultValue]:
"""List[:class:`discord.SelectDefaultValue`]: A list of default values for the select menu.
.. versionadded:: 2.4
"""
return self._underlying.default_values
@default_values.setter
def default_values(self, value: Sequence[ValidDefaultValues]) -> None:
self._underlying.default_values = _handle_select_defaults(value, self.type)
class MentionableSelect(BaseSelect[V]):
"""Represents a UI select menu with a list of predefined options with the current members and roles in the guild.
@ -543,6 +681,11 @@ class MentionableSelect(BaseSelect[V]):
Defaults to 1 and must be between 1 and 25.
disabled: :class:`bool`
Whether the select is disabled or not.
default_values: Sequence[:class:`~discord.abc.Snowflake`]
A list of objects representing the users/roles that should be selected by default.
if :class:`.Object` is passed, then the type must be specified in the constructor.
.. versionadded:: 2.4
row: Optional[:class:`int`]
The relative row this select menu belongs to. A Discord component can only have 5
rows. By default, items are arranged automatically into those 5 rows. If you'd
@ -551,6 +694,8 @@ class MentionableSelect(BaseSelect[V]):
ordering. The row number must be between 0 and 4 (i.e. zero indexed).
"""
__item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ('default_values',)
def __init__(
self,
*,
@ -560,6 +705,7 @@ class MentionableSelect(BaseSelect[V]):
max_values: int = 1,
disabled: bool = False,
row: Optional[int] = None,
default_values: Sequence[ValidDefaultValues] = MISSING,
) -> None:
super().__init__(
self.type,
@ -569,6 +715,7 @@ class MentionableSelect(BaseSelect[V]):
max_values=max_values,
disabled=disabled,
row=row,
default_values=_handle_select_defaults(default_values, self.type),
)
@property
@ -589,6 +736,18 @@ class MentionableSelect(BaseSelect[V]):
"""
return super().values # type: ignore
@property
def default_values(self) -> List[SelectDefaultValue]:
"""List[:class:`discord.SelectDefaultValue`]: A list of default values for the select menu.
.. versionadded:: 2.4
"""
return self._underlying.default_values
@default_values.setter
def default_values(self, value: Sequence[ValidDefaultValues]) -> None:
self._underlying.default_values = _handle_select_defaults(value, self.type)
class ChannelSelect(BaseSelect[V]):
"""Represents a UI select menu with a list of predefined options with the current channels in the guild.
@ -614,6 +773,10 @@ class ChannelSelect(BaseSelect[V]):
Defaults to 1 and must be between 1 and 25.
disabled: :class:`bool`
Whether the select is disabled or not.
default_values: Sequence[:class:`~discord.abc.Snowflake`]
A list of objects representing the channels that should be selected by default.
.. versionadded:: 2.4
row: Optional[:class:`int`]
The relative row this select menu belongs to. A Discord component can only have 5
rows. By default, items are arranged automatically into those 5 rows. If you'd
@ -622,7 +785,10 @@ class ChannelSelect(BaseSelect[V]):
ordering. The row number must be between 0 and 4 (i.e. zero indexed).
"""
__item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ('channel_types',)
__item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + (
'channel_types',
'default_values',
)
def __init__(
self,
@ -634,6 +800,7 @@ class ChannelSelect(BaseSelect[V]):
max_values: int = 1,
disabled: bool = False,
row: Optional[int] = None,
default_values: Sequence[ValidDefaultValues] = MISSING,
) -> None:
super().__init__(
self.type,
@ -644,6 +811,7 @@ class ChannelSelect(BaseSelect[V]):
disabled=disabled,
row=row,
channel_types=channel_types,
default_values=_handle_select_defaults(default_values, self.type),
)
@property
@ -670,6 +838,18 @@ class ChannelSelect(BaseSelect[V]):
"""List[Union[:class:`~discord.app_commands.AppCommandChannel`, :class:`~discord.app_commands.AppCommandThread`]]: A list of channels selected by the user."""
return super().values # type: ignore
@property
def default_values(self) -> List[SelectDefaultValue]:
"""List[:class:`discord.SelectDefaultValue`]: A list of default values for the select menu.
.. versionadded:: 2.4
"""
return self._underlying.default_values
@default_values.setter
def default_values(self, value: Sequence[ValidDefaultValues]) -> None:
self._underlying.default_values = _handle_select_defaults(value, self.type)
@overload
def select(
@ -698,6 +878,7 @@ def select(
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
default_values: Sequence[ValidDefaultValues] = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, UserSelectT]:
...
@ -714,6 +895,7 @@ def select(
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
default_values: Sequence[ValidDefaultValues] = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, RoleSelectT]:
...
@ -730,6 +912,7 @@ def select(
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
default_values: Sequence[ValidDefaultValues] = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, ChannelSelectT]:
...
@ -746,6 +929,7 @@ def select(
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
default_values: Sequence[ValidDefaultValues] = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, MentionableSelectT]:
...
@ -761,6 +945,7 @@ def select(
min_values: int = 1,
max_values: int = 1,
disabled: bool = False,
default_values: Sequence[ValidDefaultValues] = MISSING,
row: Optional[int] = None,
) -> SelectCallbackDecorator[V, BaseSelectT]:
"""A decorator that attaches a select menu to a component.
@ -832,6 +1017,12 @@ def select(
with :class:`ChannelSelect` instances.
disabled: :class:`bool`
Whether the select is disabled or not. Defaults to ``False``.
default_values: Sequence[:class:`~discord.abc.Snowflake`]
A list of objects representing the default values for the select menu. This cannot be used with regular :class:`Select` instances.
If ``cls`` is :class:`MentionableSelect` and :class:`.Object` is passed, then the type must be specified in the constructor.
if `cls` is :class:`MentionableSelect` and :class:`.Object` is passed, then the type must be specified in the constructor.
.. versionadded:: 2.4
"""
def decorator(func: ItemCallbackType[V, BaseSelectT]) -> ItemCallbackType[V, BaseSelectT]:
@ -855,6 +1046,24 @@ def select(
func.__discord_ui_model_kwargs__['options'] = options
if issubclass(callback_cls, ChannelSelect):
func.__discord_ui_model_kwargs__['channel_types'] = channel_types
if not issubclass(callback_cls, Select):
cls_to_type: Dict[
Type[BaseSelect],
Literal[
ComponentType.user_select,
ComponentType.channel_select,
ComponentType.role_select,
ComponentType.mentionable_select,
],
] = {
UserSelect: ComponentType.user_select,
RoleSelect: ComponentType.role_select,
MentionableSelect: ComponentType.mentionable_select,
ChannelSelect: ComponentType.channel_select,
}
func.__discord_ui_model_kwargs__['default_values'] = (
MISSING if default_values is MISSING else _handle_select_defaults(default_values, cls_to_type[callback_cls])
)
return func