Fix typing issues and improve typing completeness across the library

Co-authored-by: Danny <Rapptz@users.noreply.github.com>
Co-authored-by: Josh <josh.ja.butt@gmail.com>
This commit is contained in:
Stocker
2022-03-13 23:52:10 -04:00
committed by GitHub
parent 603681940f
commit 5aa696ccfa
66 changed files with 1071 additions and 802 deletions

View File

@ -44,6 +44,7 @@ if TYPE_CHECKING:
from .view import View
from ..emoji import Emoji
from ..types.components import ButtonComponent as ButtonComponentPayload
V = TypeVar('V', bound='View', covariant=True)
@ -124,7 +125,7 @@ class Button(Item[V]):
style=style,
emoji=emoji,
)
self.row = row
self.row: Optional[int] = row
@property
def style(self) -> ButtonStyle:
@ -132,7 +133,7 @@ class Button(Item[V]):
return self._underlying.style
@style.setter
def style(self, value: ButtonStyle):
def style(self, value: ButtonStyle) -> None:
self._underlying.style = value
@property
@ -144,7 +145,7 @@ class Button(Item[V]):
return self._underlying.custom_id
@custom_id.setter
def custom_id(self, value: Optional[str]):
def custom_id(self, value: Optional[str]) -> None:
if value is not None and not isinstance(value, str):
raise TypeError('custom_id must be None or str')
@ -156,7 +157,7 @@ class Button(Item[V]):
return self._underlying.url
@url.setter
def url(self, value: Optional[str]):
def url(self, value: Optional[str]) -> None:
if value is not None and not isinstance(value, str):
raise TypeError('url must be None or str')
self._underlying.url = value
@ -167,7 +168,7 @@ class Button(Item[V]):
return self._underlying.disabled
@disabled.setter
def disabled(self, value: bool):
def disabled(self, value: bool) -> None:
self._underlying.disabled = bool(value)
@property
@ -176,7 +177,7 @@ class Button(Item[V]):
return self._underlying.label
@label.setter
def label(self, value: Optional[str]):
def label(self, value: Optional[str]) -> None:
self._underlying.label = str(value) if value is not None else value
@property
@ -185,7 +186,7 @@ class Button(Item[V]):
return self._underlying.emoji
@emoji.setter
def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore
def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]) -> None:
if value is not None:
if isinstance(value, str):
self._underlying.emoji = PartialEmoji.from_str(value)
@ -212,7 +213,7 @@ class Button(Item[V]):
def type(self) -> ComponentType:
return self._underlying.type
def to_component_dict(self):
def to_component_dict(self) -> ButtonComponentPayload:
return self._underlying.to_dict()
def is_dispatchable(self) -> bool:

View File

@ -101,7 +101,7 @@ class Item(Generic[V]):
return self._row
@row.setter
def row(self, value: Optional[int]):
def row(self, value: Optional[int]) -> None:
if value is None:
self._row = None
elif 5 > value >= 0:
@ -118,7 +118,7 @@ class Item(Generic[V]):
"""Optional[:class:`View`]: The underlying view for this item."""
return self._view
async def callback(self, interaction: Interaction):
async def callback(self, interaction: Interaction) -> Any:
"""|coro|
The callback associated with this UI item.

View File

@ -38,6 +38,8 @@ from .item import Item
from .view import View
if TYPE_CHECKING:
from typing_extensions import Self
from ..interactions import Interaction
from ..types.interactions import ModalSubmitComponentInteractionData as ModalSubmitComponentInteractionDataPayload
@ -101,7 +103,7 @@ class Modal(View):
title: str
__discord_ui_modal__ = True
__modal_children_items__: ClassVar[Dict[str, Item]] = {}
__modal_children_items__: ClassVar[Dict[str, Item[Self]]] = {}
def __init_subclass__(cls, *, title: str = MISSING) -> None:
if title is not MISSING:
@ -139,7 +141,7 @@ class Modal(View):
super().__init__(timeout=timeout)
async def on_submit(self, interaction: Interaction):
async def on_submit(self, interaction: Interaction) -> None:
"""|coro|
Called when the modal is submitted.
@ -169,7 +171,7 @@ class Modal(View):
print(f'Ignoring exception in modal {self}:', file=sys.stderr)
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
def refresh(self, components: Sequence[ModalSubmitComponentInteractionDataPayload]):
def refresh(self, components: Sequence[ModalSubmitComponentInteractionDataPayload]) -> None:
for component in components:
if component['type'] == 1:
self.refresh(component['components'])

View File

@ -121,7 +121,7 @@ class Select(Item[V]):
options=options,
disabled=disabled,
)
self.row = row
self.row: Optional[int] = row
@property
def custom_id(self) -> str:
@ -129,7 +129,7 @@ class Select(Item[V]):
return self._underlying.custom_id
@custom_id.setter
def custom_id(self, value: str):
def custom_id(self, value: str) -> None:
if not isinstance(value, str):
raise TypeError('custom_id must be None or str')
@ -141,7 +141,7 @@ class Select(Item[V]):
return self._underlying.placeholder
@placeholder.setter
def placeholder(self, value: Optional[str]):
def placeholder(self, value: Optional[str]) -> None:
if value is not None and not isinstance(value, str):
raise TypeError('placeholder must be None or str')
@ -153,7 +153,7 @@ class Select(Item[V]):
return self._underlying.min_values
@min_values.setter
def min_values(self, value: int):
def min_values(self, value: int) -> None:
self._underlying.min_values = int(value)
@property
@ -162,7 +162,7 @@ class Select(Item[V]):
return self._underlying.max_values
@max_values.setter
def max_values(self, value: int):
def max_values(self, value: int) -> None:
self._underlying.max_values = int(value)
@property
@ -171,7 +171,7 @@ class Select(Item[V]):
return self._underlying.options
@options.setter
def options(self, value: List[SelectOption]):
def options(self, value: List[SelectOption]) -> None:
if not isinstance(value, list):
raise TypeError('options must be a list of SelectOption')
if not all(isinstance(obj, SelectOption) for obj in value):
@ -187,7 +187,7 @@ class Select(Item[V]):
description: Optional[str] = None,
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
default: bool = False,
):
) -> None:
"""Adds an option to the select menu.
To append a pre-existing :class:`discord.SelectOption` use the
@ -226,7 +226,7 @@ class Select(Item[V]):
self.append_option(option)
def append_option(self, option: SelectOption):
def append_option(self, option: SelectOption) -> None:
"""Appends an option to the select menu.
Parameters
@ -251,7 +251,7 @@ class Select(Item[V]):
return self._underlying.disabled
@disabled.setter
def disabled(self, value: bool):
def disabled(self, value: bool) -> None:
self._underlying.disabled = bool(value)
@property

View File

@ -50,6 +50,8 @@ __all__ = (
if TYPE_CHECKING:
from typing_extensions import Self
from ..interactions import Interaction
from ..message import Message
from ..types.components import Component as ComponentPayload
@ -163,7 +165,7 @@ class View:
cls.__view_children_items__ = children
def _init_children(self) -> List[Item]:
def _init_children(self) -> List[Item[Self]]:
children = []
for func in self.__view_children_items__:
item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__)
@ -175,7 +177,7 @@ class View:
def __init__(self, *, timeout: Optional[float] = 180.0):
self.timeout = timeout
self.children: List[Item] = self._init_children()
self.children: List[Item[Self]] = self._init_children()
self.__weights = _ViewWeights(self.children)
self.id: str = os.urandom(16).hex()
self.__cancel_callback: Optional[Callable[[View], None]] = None
@ -250,7 +252,7 @@ class View:
view.add_item(_component_to_item(component))
return view
def add_item(self, item: Item) -> None:
def add_item(self, item: Item[Any]) -> None:
"""Adds an item to the view.
Parameters
@ -278,7 +280,7 @@ class View:
item._view = self
self.children.append(item)
def remove_item(self, item: Item) -> None:
def remove_item(self, item: Item[Any]) -> None:
"""Removes an item from the view.
Parameters
@ -334,7 +336,7 @@ class View:
"""
pass
async def on_error(self, error: Exception, item: Item, interaction: Interaction) -> None:
async def on_error(self, error: Exception, item: Item[Any], interaction: Interaction) -> None:
"""|coro|
A callback that is called when an item's callback or :meth:`interaction_check`
@ -395,16 +397,16 @@ class View:
asyncio.create_task(self._scheduled_task(item, interaction), name=f'discord-ui-view-dispatch-{self.id}')
def refresh(self, components: List[Component]):
def refresh(self, components: List[Component]) -> None:
# This is pretty hacky at the moment
# fmt: off
old_state: Dict[Tuple[int, str], Item] = {
old_state: Dict[Tuple[int, str], Item[Any]] = {
(item.type.value, item.custom_id): item # type: ignore
for item in self.children
if item.is_dispatchable()
}
# fmt: on
children: List[Item] = []
children: List[Item[Any]] = []
for component in _walk_all_components(components):
try:
older = old_state[(component.type.value, component.custom_id)] # type: ignore
@ -494,7 +496,7 @@ class ViewStore:
for k in to_remove:
del self._views[k]
def add_view(self, view: View, message_id: Optional[int] = None):
def add_view(self, view: View, message_id: Optional[int] = None) -> None:
view._start_listening_from_store(self)
if view.__discord_ui_modal__:
self._modals[view.custom_id] = view # type: ignore
@ -509,7 +511,7 @@ class ViewStore:
if message_id is not None:
self._synced_message_views[message_id] = view
def remove_view(self, view: View):
def remove_view(self, view: View) -> None:
if view.__discord_ui_modal__:
self._modals.pop(view.custom_id, None) # type: ignore
return
@ -523,7 +525,7 @@ class ViewStore:
del self._synced_message_views[key]
break
def dispatch_view(self, component_type: int, custom_id: str, interaction: Interaction):
def dispatch_view(self, component_type: int, custom_id: str, interaction: Interaction) -> None:
self.__verify_integrity()
message_id: Optional[int] = interaction.message and interaction.message.id
key = (component_type, message_id, custom_id)
@ -542,7 +544,7 @@ class ViewStore:
custom_id: str,
interaction: Interaction,
components: List[ModalSubmitComponentInteractionDataPayload],
):
) -> None:
modal = self._modals.get(custom_id)
if modal is None:
_log.debug("Modal interaction referencing unknown custom_id %s. Discarding", custom_id)
@ -551,13 +553,13 @@ class ViewStore:
modal.refresh(components)
modal._dispatch_submit(interaction)
def is_message_tracked(self, message_id: int):
def is_message_tracked(self, message_id: int) -> bool:
return message_id in self._synced_message_views
def remove_message_tracking(self, message_id: int) -> Optional[View]:
return self._synced_message_views.pop(message_id, None)
def update_from_message(self, message_id: int, components: List[ComponentPayload]):
def update_from_message(self, message_id: int, components: List[ComponentPayload]) -> None:
# pre-req: is_message_tracked == true
view = self._synced_message_views[message_id]
view.refresh([_component_factory(d) for d in components])