mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-10-22 16:32:59 +00:00 
			
		
		
		
	Improve TranslationContext type narrowing using a tagged union
This commit is contained in:
		| @@ -52,7 +52,7 @@ from ..enums import AppCommandOptionType, AppCommandType, ChannelType, Locale | |||||||
| from .models import Choice | from .models import Choice | ||||||
| from .transformers import annotation_to_parameter, CommandParameter, NoneType | from .transformers import annotation_to_parameter, CommandParameter, NoneType | ||||||
| from .errors import AppCommandError, CheckFailure, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered | from .errors import AppCommandError, CheckFailure, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered | ||||||
| from .translator import TranslationContext, TranslationContextLocation, Translator, locale_str | from .translator import TranslationContextLocation, TranslationContext, Translator, locale_str | ||||||
| from ..message import Message | from ..message import Message | ||||||
| from ..user import User | from ..user import User | ||||||
| from ..member import Member | from ..member import Member | ||||||
|   | |||||||
| @@ -52,7 +52,7 @@ __all__ = ( | |||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     from .commands import Command, Group, ContextMenu |     from .commands import Command, Group, ContextMenu | ||||||
|     from .transformers import Transformer |     from .transformers import Transformer | ||||||
|     from .translator import TranslationContext, locale_str |     from .translator import TranslationContextTypes, locale_str | ||||||
|     from ..types.snowflake import Snowflake, SnowflakeList |     from ..types.snowflake import Snowflake, SnowflakeList | ||||||
|     from .checks import Cooldown |     from .checks import Cooldown | ||||||
|  |  | ||||||
| @@ -164,11 +164,11 @@ class TranslationError(AppCommandError): | |||||||
|         *msg: str, |         *msg: str, | ||||||
|         string: Optional[Union[str, locale_str]] = None, |         string: Optional[Union[str, locale_str]] = None, | ||||||
|         locale: Optional[Locale] = None, |         locale: Optional[Locale] = None, | ||||||
|         context: TranslationContext, |         context: TranslationContextTypes, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         self.string: Optional[Union[str, locale_str]] = string |         self.string: Optional[Union[str, locale_str]] = string | ||||||
|         self.locale: Optional[Locale] = locale |         self.locale: Optional[Locale] = locale | ||||||
|         self.context: TranslationContext = context |         self.context: TranslationContextTypes = context | ||||||
|  |  | ||||||
|         if msg: |         if msg: | ||||||
|             super().__init__(*msg) |             super().__init__(*msg) | ||||||
|   | |||||||
| @@ -26,7 +26,7 @@ from __future__ import annotations | |||||||
| from datetime import datetime | from datetime import datetime | ||||||
|  |  | ||||||
| from .errors import MissingApplicationID | from .errors import MissingApplicationID | ||||||
| from .translator import TranslationContextLocation, Translator, TranslationContext, locale_str | from .translator import TranslationContextLocation, TranslationContext, locale_str, Translator | ||||||
| from ..permissions import Permissions | from ..permissions import Permissions | ||||||
| from ..enums import AppCommandOptionType, AppCommandType, AppCommandPermissionType, ChannelType, Locale, try_enum | from ..enums import AppCommandOptionType, AppCommandType, AppCommandPermissionType, ChannelType, Locale, try_enum | ||||||
| from ..mixins import Hashable | from ..mixins import Hashable | ||||||
|   | |||||||
| @@ -46,7 +46,7 @@ from typing import ( | |||||||
|  |  | ||||||
| from .errors import AppCommandError, TransformerError | from .errors import AppCommandError, TransformerError | ||||||
| from .models import AppCommandChannel, AppCommandThread, Choice | from .models import AppCommandChannel, AppCommandThread, Choice | ||||||
| from .translator import TranslationContextLocation, locale_str, Translator, TranslationContext | from .translator import TranslationContextLocation, TranslationContext, Translator, locale_str | ||||||
| from ..channel import StageChannel, VoiceChannel, TextChannel, CategoryChannel | from ..channel import StageChannel, VoiceChannel, TextChannel, CategoryChannel | ||||||
| from ..abc import GuildChannel | from ..abc import GuildChannel | ||||||
| from ..threads import Thread | from ..threads import Thread | ||||||
|   | |||||||
| @@ -23,13 +23,19 @@ DEALINGS IN THE SOFTWARE. | |||||||
| """ | """ | ||||||
|  |  | ||||||
| from __future__ import annotations | from __future__ import annotations | ||||||
| from typing import TYPE_CHECKING, Any, Literal, Optional, Union | from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, Union, overload | ||||||
| from .errors import TranslationError | from .errors import TranslationError | ||||||
| from ..enums import Enum, Locale | from ..enums import Enum, Locale | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from .commands import Command, ContextMenu, Group, Parameter | ||||||
|  |     from .models import Choice | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ( | __all__ = ( | ||||||
|     'TranslationContextLocation', |     'TranslationContextLocation', | ||||||
|  |     'TranslationContextTypes', | ||||||
|     'TranslationContext', |     'TranslationContext', | ||||||
|     'Translator', |     'Translator', | ||||||
|     'locale_str', |     'locale_str', | ||||||
| @@ -47,7 +53,11 @@ class TranslationContextLocation(Enum): | |||||||
|     other = 7 |     other = 7 | ||||||
|  |  | ||||||
|  |  | ||||||
| class TranslationContext:  # type: ignore  # See below | _L = TypeVar('_L', bound=TranslationContextLocation) | ||||||
|  | _D = TypeVar('_D') | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TranslationContext(Generic[_L, _D]): | ||||||
|     """A class that provides context for the :class:`locale_str` being translated. |     """A class that provides context for the :class:`locale_str` being translated. | ||||||
|  |  | ||||||
|     This is useful to determine where exactly the string is located and aid in looking |     This is useful to determine where exactly the string is located and aid in looking | ||||||
| @@ -63,61 +73,78 @@ class TranslationContext:  # type: ignore  # See below | |||||||
|  |  | ||||||
|     __slots__ = ('location', 'data') |     __slots__ = ('location', 'data') | ||||||
|  |  | ||||||
|     def __init__(self, location: TranslationContextLocation, data: Any) -> None: |     @overload | ||||||
|         self.location: TranslationContextLocation = location |     def __init__( | ||||||
|         self.data: Any = data |         self, location: Literal[TranslationContextLocation.command_name], data: Union[Command[Any, ..., Any], ContextMenu] | ||||||
|  |     ) -> None: | ||||||
|  |  | ||||||
| if TYPE_CHECKING: |  | ||||||
|     # For type checking purposes, it makes sense to allow the user to leverage type narrowing |  | ||||||
|     # So code like this works as expected: |  | ||||||
|     # if context.type is TranslationContextLocation.command_name: |  | ||||||
|     #    reveal_type(context.data)  # Revealed type is Command | ContextMenu |  | ||||||
|     # |  | ||||||
|     # Unfortunately doing a trick like this requires lying to the type checker so |  | ||||||
|     # this is what the code below enables. |  | ||||||
|     # |  | ||||||
|     # Should this trick stop working then it might be fair to remove this code. |  | ||||||
|     # It's purely here for convenience. |  | ||||||
|  |  | ||||||
|     from .commands import Command, ContextMenu, Group, Parameter |  | ||||||
|     from .models import Choice |  | ||||||
|  |  | ||||||
|     class _CommandNameTranslationContext: |  | ||||||
|         location: Literal[TranslationContextLocation.command_name] |  | ||||||
|         data: Union[Command[Any, ..., Any], ContextMenu] |  | ||||||
|  |  | ||||||
|     class _CommandDescriptionTranslationContext: |  | ||||||
|         location: Literal[TranslationContextLocation.command_description] |  | ||||||
|         data: Command[Any, ..., Any] |  | ||||||
|  |  | ||||||
|     class _GroupTranslationContext: |  | ||||||
|         location: Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description] |  | ||||||
|         data: Group |  | ||||||
|  |  | ||||||
|     class _ParameterTranslationContext: |  | ||||||
|         location: Literal[TranslationContextLocation.parameter_description, TranslationContextLocation.parameter_name] |  | ||||||
|         data: Parameter |  | ||||||
|  |  | ||||||
|     class _ChoiceTranslationContext: |  | ||||||
|         location: Literal[TranslationContextLocation.choice_name] |  | ||||||
|         data: Choice[Union[int, str, float]] |  | ||||||
|  |  | ||||||
|     class _OtherTranslationContext: |  | ||||||
|         location: Literal[TranslationContextLocation.other] |  | ||||||
|         data: Any |  | ||||||
|  |  | ||||||
|     class TranslationContext( |  | ||||||
|         _CommandNameTranslationContext, |  | ||||||
|         _CommandDescriptionTranslationContext, |  | ||||||
|         _GroupTranslationContext, |  | ||||||
|         _ParameterTranslationContext, |  | ||||||
|         _ChoiceTranslationContext, |  | ||||||
|         _OtherTranslationContext, |  | ||||||
|     ): |  | ||||||
|         def __init__(self, location: TranslationContextLocation, data: Any) -> None: |  | ||||||
|         ... |         ... | ||||||
|  |  | ||||||
|  |     @overload | ||||||
|  |     def __init__( | ||||||
|  |         self, location: Literal[TranslationContextLocation.command_description], data: Command[Any, ..., Any] | ||||||
|  |     ) -> None: | ||||||
|  |         ... | ||||||
|  |  | ||||||
|  |     @overload | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         location: Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description], | ||||||
|  |         data: Group, | ||||||
|  |     ) -> None: | ||||||
|  |         ... | ||||||
|  |  | ||||||
|  |     @overload | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         location: Literal[TranslationContextLocation.parameter_name, TranslationContextLocation.parameter_description], | ||||||
|  |         data: Parameter, | ||||||
|  |     ) -> None: | ||||||
|  |         ... | ||||||
|  |  | ||||||
|  |     @overload | ||||||
|  |     def __init__(self, location: Literal[TranslationContextLocation.choice_name], data: Choice[Any]) -> None: | ||||||
|  |         ... | ||||||
|  |  | ||||||
|  |     @overload | ||||||
|  |     def __init__(self, location: Literal[TranslationContextLocation.other], data: Any) -> None: | ||||||
|  |         ... | ||||||
|  |  | ||||||
|  |     def __init__(self, location: _L, data: _D) -> None: | ||||||
|  |         self.location: _L = location | ||||||
|  |         self.data: _D = data | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # For type checking purposes, it makes sense to allow the user to leverage type narrowing | ||||||
|  | # So code like this works as expected: | ||||||
|  | # | ||||||
|  | # if context.type == TranslationContextLocation.command_name: | ||||||
|  | #    reveal_type(context.data)  # Revealed type is Command | ContextMenu | ||||||
|  | # | ||||||
|  | # This requires a union of types | ||||||
|  | CommandNameTranslationContext = TranslationContext[ | ||||||
|  |     Literal[TranslationContextLocation.command_name], Union['Command[Any, ..., Any]', 'ContextMenu'] | ||||||
|  | ] | ||||||
|  | CommandDescriptionTranslationContext = TranslationContext[ | ||||||
|  |     Literal[TranslationContextLocation.command_description], 'Command[Any, ..., Any]' | ||||||
|  | ] | ||||||
|  | GroupTranslationContext = TranslationContext[ | ||||||
|  |     Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description], 'Group' | ||||||
|  | ] | ||||||
|  | ParameterTranslationContext = TranslationContext[ | ||||||
|  |     Literal[TranslationContextLocation.parameter_name, TranslationContextLocation.parameter_description], 'Parameter' | ||||||
|  | ] | ||||||
|  | ChoiceTranslationContext = TranslationContext[Literal[TranslationContextLocation.choice_name], 'Choice[Any]'] | ||||||
|  | OtherTranslationContext = TranslationContext[Literal[TranslationContextLocation.other], Any] | ||||||
|  |  | ||||||
|  | TranslationContextTypes = Union[ | ||||||
|  |     CommandNameTranslationContext, | ||||||
|  |     CommandDescriptionTranslationContext, | ||||||
|  |     GroupTranslationContext, | ||||||
|  |     ParameterTranslationContext, | ||||||
|  |     ChoiceTranslationContext, | ||||||
|  |     OtherTranslationContext, | ||||||
|  | ] | ||||||
|  |  | ||||||
|  |  | ||||||
| class Translator: | class Translator: | ||||||
|     """A class that handles translations for commands, parameters, and choices. |     """A class that handles translations for commands, parameters, and choices. | ||||||
| @@ -162,7 +189,9 @@ class Translator: | |||||||
|         """ |         """ | ||||||
|         pass |         pass | ||||||
|  |  | ||||||
|     async def _checked_translate(self, string: locale_str, locale: Locale, context: TranslationContext) -> Optional[str]: |     async def _checked_translate( | ||||||
|  |         self, string: locale_str, locale: Locale, context: TranslationContextTypes | ||||||
|  |     ) -> Optional[str]: | ||||||
|         try: |         try: | ||||||
|             return await self.translate(string, locale, context) |             return await self.translate(string, locale, context) | ||||||
|         except TranslationError: |         except TranslationError: | ||||||
| @@ -170,7 +199,7 @@ class Translator: | |||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             raise TranslationError(string=string, locale=locale, context=context) from e |             raise TranslationError(string=string, locale=locale, context=context) from e | ||||||
|  |  | ||||||
|     async def translate(self, string: locale_str, locale: Locale, context: TranslationContext) -> Optional[str]: |     async def translate(self, string: locale_str, locale: Locale, context: TranslationContextTypes) -> Optional[str]: | ||||||
|         """|coro| |         """|coro| | ||||||
|  |  | ||||||
|         Translates the given string to the specified locale. |         Translates the given string to the specified locale. | ||||||
| @@ -190,6 +219,9 @@ class Translator: | |||||||
|             The locale being requested for translation. |             The locale being requested for translation. | ||||||
|         context: :class:`TranslationContext` |         context: :class:`TranslationContext` | ||||||
|             The translation context where the string originated from. |             The translation context where the string originated from. | ||||||
|  |             For better type checking ergonomics, the ``TranslationContextTypes`` | ||||||
|  |             type can be used instead to aid with type narrowing. It is functionally | ||||||
|  |             equivalent to :class:`TranslationContext`. | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         return None |         return None | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user