Improve typing of app command transformers

This allows subclasses of transformers to specify a specialization for
interaction without violating covariance of parameter types
This commit is contained in:
Michael H 2024-10-09 17:27:55 -04:00 committed by GitHub
parent 053f29c96c
commit 3e168a93bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -34,6 +34,7 @@ from typing import (
ClassVar, ClassVar,
Coroutine, Coroutine,
Dict, Dict,
Generic,
List, List,
Literal, Literal,
Optional, Optional,
@ -56,6 +57,7 @@ from ..user import User
from ..role import Role from ..role import Role
from ..member import Member from ..member import Member
from ..message import Attachment from ..message import Attachment
from .._types import ClientT
__all__ = ( __all__ = (
'Transformer', 'Transformer',
@ -191,7 +193,7 @@ class CommandParameter:
return self.name if self._rename is MISSING else str(self._rename) return self.name if self._rename is MISSING else str(self._rename)
class Transformer: class Transformer(Generic[ClientT]):
"""The base class that allows a type annotation in an application command parameter """The base class that allows a type annotation in an application command parameter
to map into a :class:`~discord.AppCommandOptionType` and transform the raw value into one to map into a :class:`~discord.AppCommandOptionType` and transform the raw value into one
from this type. from this type.
@ -304,7 +306,7 @@ class Transformer:
else: else:
return name return name
async def transform(self, interaction: Interaction, value: Any, /) -> Any: async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any:
"""|maybecoro| """|maybecoro|
Transforms the converted option value into another value. Transforms the converted option value into another value.
@ -324,7 +326,7 @@ class Transformer:
raise NotImplementedError('Derived classes need to implement this.') raise NotImplementedError('Derived classes need to implement this.')
async def autocomplete( async def autocomplete(
self, interaction: Interaction, value: Union[int, float, str], / self, interaction: Interaction[ClientT], value: Union[int, float, str], /
) -> List[Choice[Union[int, float, str]]]: ) -> List[Choice[Union[int, float, str]]]:
"""|coro| """|coro|
@ -352,7 +354,7 @@ class Transformer:
raise NotImplementedError('Derived classes can implement this.') raise NotImplementedError('Derived classes can implement this.')
class IdentityTransformer(Transformer): class IdentityTransformer(Transformer[ClientT]):
def __init__(self, type: AppCommandOptionType) -> None: def __init__(self, type: AppCommandOptionType) -> None:
self._type = type self._type = type
@ -360,7 +362,7 @@ class IdentityTransformer(Transformer):
def type(self) -> AppCommandOptionType: def type(self) -> AppCommandOptionType:
return self._type return self._type
async def transform(self, interaction: Interaction, value: Any, /) -> Any: async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any:
return value return value
@ -489,7 +491,7 @@ class EnumNameTransformer(Transformer):
return self._enum[value] return self._enum[value]
class InlineTransformer(Transformer): class InlineTransformer(Transformer[ClientT]):
def __init__(self, annotation: Any) -> None: def __init__(self, annotation: Any) -> None:
super().__init__() super().__init__()
self.annotation: Any = annotation self.annotation: Any = annotation
@ -502,7 +504,7 @@ class InlineTransformer(Transformer):
def type(self) -> AppCommandOptionType: def type(self) -> AppCommandOptionType:
return AppCommandOptionType.string return AppCommandOptionType.string
async def transform(self, interaction: Interaction, value: Any, /) -> Any: async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any:
return await self.annotation.transform(interaction, value) return await self.annotation.transform(interaction, value)
@ -611,18 +613,18 @@ else:
return transformer return transformer
class MemberTransformer(Transformer): class MemberTransformer(Transformer[ClientT]):
@property @property
def type(self) -> AppCommandOptionType: def type(self) -> AppCommandOptionType:
return AppCommandOptionType.user return AppCommandOptionType.user
async def transform(self, interaction: Interaction, value: Any, /) -> Member: async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Member:
if not isinstance(value, Member): if not isinstance(value, Member):
raise TransformerError(value, self.type, self) raise TransformerError(value, self.type, self)
return value return value
class BaseChannelTransformer(Transformer): class BaseChannelTransformer(Transformer[ClientT]):
def __init__(self, *channel_types: Type[Any]) -> None: def __init__(self, *channel_types: Type[Any]) -> None:
super().__init__() super().__init__()
if len(channel_types) == 1: if len(channel_types) == 1:
@ -654,22 +656,22 @@ class BaseChannelTransformer(Transformer):
def channel_types(self) -> List[ChannelType]: def channel_types(self) -> List[ChannelType]:
return self._channel_types return self._channel_types
async def transform(self, interaction: Interaction, value: Any, /): async def transform(self, interaction: Interaction[ClientT], value: Any, /):
resolved = value.resolve() resolved = value.resolve()
if resolved is None or not isinstance(resolved, self._types): if resolved is None or not isinstance(resolved, self._types):
raise TransformerError(value, AppCommandOptionType.channel, self) raise TransformerError(value, AppCommandOptionType.channel, self)
return resolved return resolved
class RawChannelTransformer(BaseChannelTransformer): class RawChannelTransformer(BaseChannelTransformer[ClientT]):
async def transform(self, interaction: Interaction, value: Any, /): async def transform(self, interaction: Interaction[ClientT], value: Any, /):
if not isinstance(value, self._types): if not isinstance(value, self._types):
raise TransformerError(value, AppCommandOptionType.channel, self) raise TransformerError(value, AppCommandOptionType.channel, self)
return value return value
class UnionChannelTransformer(BaseChannelTransformer): class UnionChannelTransformer(BaseChannelTransformer[ClientT]):
async def transform(self, interaction: Interaction, value: Any, /): async def transform(self, interaction: Interaction[ClientT], value: Any, /):
if isinstance(value, self._types): if isinstance(value, self._types):
return value return value