Improve TranslationContext type narrowing using a tagged union

This commit is contained in:
Bryan Forbes 2022-08-15 07:17:41 -05:00 committed by GitHub
parent 49e6fe9a0c
commit 63b32994f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 93 additions and 61 deletions

View File

@ -52,7 +52,7 @@ from ..enums import AppCommandOptionType, AppCommandType, ChannelType, Locale
from .models import Choice from .models import Choice
from .transformers import annotation_to_parameter, CommandParameter, NoneType from .transformers import annotation_to_parameter, CommandParameter, NoneType
from .errors import AppCommandError, CheckFailure, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered from .errors import AppCommandError, CheckFailure, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered
from .translator import TranslationContext, TranslationContextLocation, Translator, locale_str from .translator import TranslationContextLocation, TranslationContext, Translator, locale_str
from ..message import Message from ..message import Message
from ..user import User from ..user import User
from ..member import Member from ..member import Member

View File

@ -52,7 +52,7 @@ __all__ = (
if TYPE_CHECKING: if TYPE_CHECKING:
from .commands import Command, Group, ContextMenu from .commands import Command, Group, ContextMenu
from .transformers import Transformer from .transformers import Transformer
from .translator import TranslationContext, locale_str from .translator import TranslationContextTypes, locale_str
from ..types.snowflake import Snowflake, SnowflakeList from ..types.snowflake import Snowflake, SnowflakeList
from .checks import Cooldown from .checks import Cooldown
@ -164,11 +164,11 @@ class TranslationError(AppCommandError):
*msg: str, *msg: str,
string: Optional[Union[str, locale_str]] = None, string: Optional[Union[str, locale_str]] = None,
locale: Optional[Locale] = None, locale: Optional[Locale] = None,
context: TranslationContext, context: TranslationContextTypes,
) -> None: ) -> None:
self.string: Optional[Union[str, locale_str]] = string self.string: Optional[Union[str, locale_str]] = string
self.locale: Optional[Locale] = locale self.locale: Optional[Locale] = locale
self.context: TranslationContext = context self.context: TranslationContextTypes = context
if msg: if msg:
super().__init__(*msg) super().__init__(*msg)

View File

@ -26,7 +26,7 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from .errors import MissingApplicationID from .errors import MissingApplicationID
from .translator import TranslationContextLocation, Translator, TranslationContext, locale_str from .translator import TranslationContextLocation, TranslationContext, locale_str, Translator
from ..permissions import Permissions from ..permissions import Permissions
from ..enums import AppCommandOptionType, AppCommandType, AppCommandPermissionType, ChannelType, Locale, try_enum from ..enums import AppCommandOptionType, AppCommandType, AppCommandPermissionType, ChannelType, Locale, try_enum
from ..mixins import Hashable from ..mixins import Hashable

View File

@ -46,7 +46,7 @@ from typing import (
from .errors import AppCommandError, TransformerError from .errors import AppCommandError, TransformerError
from .models import AppCommandChannel, AppCommandThread, Choice from .models import AppCommandChannel, AppCommandThread, Choice
from .translator import TranslationContextLocation, locale_str, Translator, TranslationContext from .translator import TranslationContextLocation, TranslationContext, Translator, locale_str
from ..channel import StageChannel, VoiceChannel, TextChannel, CategoryChannel from ..channel import StageChannel, VoiceChannel, TextChannel, CategoryChannel
from ..abc import GuildChannel from ..abc import GuildChannel
from ..threads import Thread from ..threads import Thread

View File

@ -23,13 +23,19 @@ DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal, Optional, Union from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, Union, overload
from .errors import TranslationError from .errors import TranslationError
from ..enums import Enum, Locale from ..enums import Enum, Locale
if TYPE_CHECKING:
from .commands import Command, ContextMenu, Group, Parameter
from .models import Choice
__all__ = ( __all__ = (
'TranslationContextLocation', 'TranslationContextLocation',
'TranslationContextTypes',
'TranslationContext', 'TranslationContext',
'Translator', 'Translator',
'locale_str', 'locale_str',
@ -47,7 +53,11 @@ class TranslationContextLocation(Enum):
other = 7 other = 7
class TranslationContext: # type: ignore # See below _L = TypeVar('_L', bound=TranslationContextLocation)
_D = TypeVar('_D')
class TranslationContext(Generic[_L, _D]):
"""A class that provides context for the :class:`locale_str` being translated. """A class that provides context for the :class:`locale_str` being translated.
This is useful to determine where exactly the string is located and aid in looking This is useful to determine where exactly the string is located and aid in looking
@ -63,60 +73,77 @@ class TranslationContext: # type: ignore # See below
__slots__ = ('location', 'data') __slots__ = ('location', 'data')
def __init__(self, location: TranslationContextLocation, data: Any) -> None: @overload
self.location: TranslationContextLocation = location def __init__(
self.data: Any = data self, location: Literal[TranslationContextLocation.command_name], data: Union[Command[Any, ..., Any], ContextMenu]
) -> None:
...
@overload
def __init__(
self, location: Literal[TranslationContextLocation.command_description], data: Command[Any, ..., Any]
) -> None:
...
@overload
def __init__(
self,
location: Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description],
data: Group,
) -> None:
...
@overload
def __init__(
self,
location: Literal[TranslationContextLocation.parameter_name, TranslationContextLocation.parameter_description],
data: Parameter,
) -> None:
...
@overload
def __init__(self, location: Literal[TranslationContextLocation.choice_name], data: Choice[Any]) -> None:
...
@overload
def __init__(self, location: Literal[TranslationContextLocation.other], data: Any) -> None:
...
def __init__(self, location: _L, data: _D) -> None:
self.location: _L = location
self.data: _D = data
if TYPE_CHECKING: # For type checking purposes, it makes sense to allow the user to leverage type narrowing
# For type checking purposes, it makes sense to allow the user to leverage type narrowing # So code like this works as expected:
# So code like this works as expected: #
# if context.type is TranslationContextLocation.command_name: # if context.type == TranslationContextLocation.command_name:
# reveal_type(context.data) # Revealed type is Command | ContextMenu # reveal_type(context.data) # Revealed type is Command | ContextMenu
# #
# Unfortunately doing a trick like this requires lying to the type checker so # This requires a union of types
# this is what the code below enables. CommandNameTranslationContext = TranslationContext[
# Literal[TranslationContextLocation.command_name], Union['Command[Any, ..., Any]', 'ContextMenu']
# Should this trick stop working then it might be fair to remove this code. ]
# It's purely here for convenience. CommandDescriptionTranslationContext = TranslationContext[
Literal[TranslationContextLocation.command_description], 'Command[Any, ..., Any]'
]
GroupTranslationContext = TranslationContext[
Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description], 'Group'
]
ParameterTranslationContext = TranslationContext[
Literal[TranslationContextLocation.parameter_name, TranslationContextLocation.parameter_description], 'Parameter'
]
ChoiceTranslationContext = TranslationContext[Literal[TranslationContextLocation.choice_name], 'Choice[Any]']
OtherTranslationContext = TranslationContext[Literal[TranslationContextLocation.other], Any]
from .commands import Command, ContextMenu, Group, Parameter TranslationContextTypes = Union[
from .models import Choice CommandNameTranslationContext,
CommandDescriptionTranslationContext,
class _CommandNameTranslationContext: GroupTranslationContext,
location: Literal[TranslationContextLocation.command_name] ParameterTranslationContext,
data: Union[Command[Any, ..., Any], ContextMenu] ChoiceTranslationContext,
OtherTranslationContext,
class _CommandDescriptionTranslationContext: ]
location: Literal[TranslationContextLocation.command_description]
data: Command[Any, ..., Any]
class _GroupTranslationContext:
location: Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description]
data: Group
class _ParameterTranslationContext:
location: Literal[TranslationContextLocation.parameter_description, TranslationContextLocation.parameter_name]
data: Parameter
class _ChoiceTranslationContext:
location: Literal[TranslationContextLocation.choice_name]
data: Choice[Union[int, str, float]]
class _OtherTranslationContext:
location: Literal[TranslationContextLocation.other]
data: Any
class TranslationContext(
_CommandNameTranslationContext,
_CommandDescriptionTranslationContext,
_GroupTranslationContext,
_ParameterTranslationContext,
_ChoiceTranslationContext,
_OtherTranslationContext,
):
def __init__(self, location: TranslationContextLocation, data: Any) -> None:
...
class Translator: class Translator:
@ -162,7 +189,9 @@ class Translator:
""" """
pass pass
async def _checked_translate(self, string: locale_str, locale: Locale, context: TranslationContext) -> Optional[str]: async def _checked_translate(
self, string: locale_str, locale: Locale, context: TranslationContextTypes
) -> Optional[str]:
try: try:
return await self.translate(string, locale, context) return await self.translate(string, locale, context)
except TranslationError: except TranslationError:
@ -170,7 +199,7 @@ class Translator:
except Exception as e: except Exception as e:
raise TranslationError(string=string, locale=locale, context=context) from e raise TranslationError(string=string, locale=locale, context=context) from e
async def translate(self, string: locale_str, locale: Locale, context: TranslationContext) -> Optional[str]: async def translate(self, string: locale_str, locale: Locale, context: TranslationContextTypes) -> Optional[str]:
"""|coro| """|coro|
Translates the given string to the specified locale. Translates the given string to the specified locale.
@ -190,6 +219,9 @@ class Translator:
The locale being requested for translation. The locale being requested for translation.
context: :class:`TranslationContext` context: :class:`TranslationContext`
The translation context where the string originated from. The translation context where the string originated from.
For better type checking ergonomics, the ``TranslationContextTypes``
type can be used instead to aid with type narrowing. It is functionally
equivalent to :class:`TranslationContext`.
""" """
return None return None