Revert "Infer select type from callback annotation

This commit is contained in:
Zephyrkul 2022-12-18 02:17:41 -08:00 committed by GitHub
parent b671958e11
commit 7cf3cd51a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -22,7 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Callable, Union, Dict
from typing import List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Callable, Union, Dict, overload
from contextvars import ContextVar
import inspect
import os
@ -31,8 +31,11 @@ from .item import Item, ItemCallbackType
from ..enums import ChannelType, ComponentType
from ..partial_emoji import PartialEmoji
from ..emoji import Emoji
from ..utils import MISSING, resolve_annotation
from ..components import SelectOption, SelectMenu
from ..utils import MISSING
from ..components import (
SelectOption,
SelectMenu,
)
from ..app_commands.namespace import Namespace
__all__ = (
@ -69,6 +72,11 @@ if TYPE_CHECKING:
V = TypeVar('V', bound='View', covariant=True)
BaseSelectT = TypeVar('BaseSelectT', bound='BaseSelect')
SelectT = TypeVar('SelectT', bound='Select')
UserSelectT = TypeVar('UserSelectT', bound='UserSelect')
RoleSelectT = TypeVar('RoleSelectT', bound='RoleSelect')
ChannelSelectT = TypeVar('ChannelSelectT', bound='ChannelSelect')
MentionableSelectT = TypeVar('MentionableSelectT', bound='MentionableSelect')
SelectCallbackDecorator: TypeAlias = Callable[[ItemCallbackType[V, BaseSelectT]], BaseSelectT]
selected_values: ContextVar[Dict[str, List[PossibleValue]]] = ContextVar('selected_values')
@ -662,32 +670,89 @@ class ChannelSelect(BaseSelect[V]):
return super().values # type: ignore
def _get_select_callback_parameter(func: ItemCallbackType[V, BaseSelectT]) -> Type[BaseSelect]:
params = inspect.signature(func).parameters
if len(params) != 3:
raise TypeError(
f'select menu callback {func.__qualname__!r} requires 3 parameters, '
'the view instance (self), the discord.Interaction, and the select menu itself'
)
@overload
def select(
*,
cls: Type[SelectT] = Select,
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = ...,
placeholder: Optional[str] = ...,
custom_id: str = ...,
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, SelectT]:
...
iterator = iter(params.values())
parameter = next(iterator)
for parameter in iterator:
pass
if parameter.annotation is parameter.empty:
return Select
@overload
def select(
*,
cls: Type[UserSelectT],
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = ...,
placeholder: Optional[str] = ...,
custom_id: str = ...,
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, UserSelectT]:
...
resolved = resolve_annotation(parameter.annotation, func.__globals__, func.__globals__, {})
origin = getattr(resolved, '__origin__', resolved)
if origin is BaseSelect or not isinstance(origin, type) or not issubclass(origin, BaseSelect):
return Select
return origin
@overload
def select(
*,
cls: Type[RoleSelectT],
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = ...,
placeholder: Optional[str] = ...,
custom_id: str = ...,
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, RoleSelectT]:
...
@overload
def select(
*,
cls: Type[ChannelSelectT],
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = ...,
placeholder: Optional[str] = ...,
custom_id: str = ...,
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, ChannelSelectT]:
...
@overload
def select(
*,
cls: Type[MentionableSelectT],
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = MISSING,
placeholder: Optional[str] = ...,
custom_id: str = ...,
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, MentionableSelectT]:
...
def select(
*,
cls: Type[BaseSelectT] = Select if TYPE_CHECKING else MISSING,
cls: Type[BaseSelectT] = Select,
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = MISSING,
placeholder: Optional[str] = None,
@ -722,10 +787,7 @@ def select(
.. versionchanged:: 2.1
Added the following keyword-arguments: ``cls``, ``channel_types``
.. versionchanged:: 2.2
Now infers ``cls`` based on the callback if not supplied.
Example
---------
.. code-block:: python3
@ -740,11 +802,10 @@ def select(
------------
cls: Union[Type[:class:`discord.ui.Select`], Type[:class:`discord.ui.UserSelect`], Type[:class:`discord.ui.RoleSelect`], \
Type[:class:`discord.ui.MentionableSelect`], Type[:class:`discord.ui.ChannelSelect`]]
The class to use for the select menu. Defaults to inferring the type from the
callback if available; otherwise defaults to :class:`discord.ui.Select`.
You can use other select types to display different select menus to the user.
See the table above for the different values you can get from each select type.
Subclasses work as well, however the callback in the subclass will get overridden.
The class to use for the select menu. Defaults to :class:`discord.ui.Select`. You can use other
select types to display different select menus to the user. See the table above for the different
values you can get from each select type. Subclasses work as well, however the callback in the subclass will
get overridden.
placeholder: Optional[:class:`str`]
The placeholder text that is shown if nothing is selected, if any.
custom_id: :class:`str`
@ -775,13 +836,10 @@ def select(
def decorator(func: ItemCallbackType[V, BaseSelectT]) -> ItemCallbackType[V, BaseSelectT]:
if not inspect.iscoroutinefunction(func):
raise TypeError('select function must be a coroutine function')
if cls is MISSING:
callback_cls = _get_select_callback_parameter(func)
else:
callback_cls = getattr(cls, '__origin__', cls)
callback_cls = getattr(cls, '__origin__', cls)
if not issubclass(callback_cls, BaseSelect):
supported_classes = ", ".join(["ChannelSelect", "MentionableSelect", "RoleSelect", "Select", "UserSelect"])
raise TypeError(f'cls must be one of {supported_classes} or a subclass of one of them, not {callback_cls!r}.')
raise TypeError(f'cls must be one of {supported_classes} or a subclass of one of them, not {cls!r}.')
func.__discord_ui_model_type__ = callback_cls
func.__discord_ui_model_kwargs__ = {