Change the way callbacks are defined to allow deriving
This should hopefully make these work more consistently as other functions do.
This commit is contained in:
		| @@ -87,8 +87,6 @@ class Button(Item): | ||||
|         The emoji of the button, if available. | ||||
|     """ | ||||
|  | ||||
|     __slots__: Tuple[str, ...] = Item.__slots__ + ('_underlying',) | ||||
|  | ||||
|     __item_repr_attributes__: Tuple[str, ...] = ( | ||||
|         'style', | ||||
|         'url', | ||||
| @@ -192,19 +190,6 @@ class Button(Item): | ||||
|         else: | ||||
|             self._underlying.emoji = None | ||||
|  | ||||
|     def copy(self: B) -> B: | ||||
|         button = self.__class__( | ||||
|             style=self.style, | ||||
|             label=self.label, | ||||
|             disabled=self.disabled, | ||||
|             custom_id=self.custom_id, | ||||
|             url=self.url, | ||||
|             emoji=self.emoji, | ||||
|             group=self.group_id, | ||||
|         ) | ||||
|         button.callback = self.callback | ||||
|         return button | ||||
|  | ||||
|     @classmethod | ||||
|     def from_component(cls: Type[B], button: ButtonComponent) -> B: | ||||
|         return cls( | ||||
| @@ -239,7 +224,7 @@ def button( | ||||
|     style: ButtonStyle = ButtonStyle.grey, | ||||
|     emoji: Optional[Union[str, PartialEmoji]] = None, | ||||
|     group: Optional[int] = None, | ||||
| ) -> Callable[[ItemCallbackType], Button]: | ||||
| ) -> Callable[[ItemCallbackType], ItemCallbackType]: | ||||
|     """A decorator that attaches a button to a component. | ||||
|  | ||||
|     The function being decorated should have three parameters, ``self`` representing | ||||
| @@ -275,14 +260,22 @@ def button( | ||||
|         ordering. | ||||
|     """ | ||||
|  | ||||
|     def decorator(func: ItemCallbackType) -> Button: | ||||
|     def decorator(func: ItemCallbackType) -> ItemCallbackType: | ||||
|         nonlocal custom_id | ||||
|         if not inspect.iscoroutinefunction(func): | ||||
|             raise TypeError('button function must be a coroutine function') | ||||
|  | ||||
|         custom_id = custom_id or os.urandom(32).hex() | ||||
|         button = Button(style=style, custom_id=custom_id, url=None, disabled=disabled, label=label, emoji=emoji, group=group) | ||||
|         button.callback = func | ||||
|         return button | ||||
|         func.__discord_ui_model_type__ = Button | ||||
|         func.__discord_ui_model_kwargs__ = { | ||||
|             'style': style, | ||||
|             'custom_id': custom_id, | ||||
|             'url': None, | ||||
|             'disabled': disabled, | ||||
|             'label': label, | ||||
|             'emoji': emoji, | ||||
|             'group': group, | ||||
|         } | ||||
|         return func | ||||
|  | ||||
|     return decorator | ||||
|   | ||||
| @@ -24,8 +24,7 @@ DEALINGS IN THE SOFTWARE. | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| from typing import Any, Callable, Coroutine, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union | ||||
| import inspect | ||||
| from typing import Any, Callable, Coroutine, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar | ||||
|  | ||||
| from ..interactions import Interaction | ||||
|  | ||||
| @@ -50,25 +49,15 @@ class Item: | ||||
|     - :class:`discord.ui.Button` | ||||
|     """ | ||||
|  | ||||
|     __slots__: Tuple[str, ...] = ( | ||||
|         '_callback', | ||||
|         '_pass_view_arg', | ||||
|         'group_id', | ||||
|     ) | ||||
|  | ||||
|     __item_repr_attributes__: Tuple[str, ...] = ('group_id',) | ||||
|  | ||||
|     def __init__(self): | ||||
|         self._callback: Optional[ItemCallbackType] = None | ||||
|         self._pass_view_arg = True | ||||
|         self._view: Optional[View] = None | ||||
|         self.group_id: Optional[int] = None | ||||
|  | ||||
|     def to_component_dict(self) -> Dict[str, Any]: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def copy(self: I) -> I: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def refresh_state(self, component: Component) -> None: | ||||
|         return None | ||||
|  | ||||
| @@ -88,53 +77,20 @@ class Item: | ||||
|         return f'<{self.__class__.__name__} {attrs}>' | ||||
|  | ||||
|     @property | ||||
|     def callback(self) -> Optional[ItemCallbackType]: | ||||
|         """Returns the underlying callback associated with this interaction.""" | ||||
|         return self._callback | ||||
|     def view(self) -> Optional[View]: | ||||
|         """Optional[:class:`View`]: The underlying view for this item.""" | ||||
|         return self._view | ||||
|  | ||||
|     @callback.setter | ||||
|     def callback(self, value: Optional[ItemCallbackType]): | ||||
|         if value is None: | ||||
|             self._callback = None | ||||
|             return | ||||
|     async def callback(self, interaction: Interaction): | ||||
|         """|coro| | ||||
|  | ||||
|         # Check if it's a partial function | ||||
|         try: | ||||
|             partial = value.func | ||||
|         except AttributeError: | ||||
|             pass | ||||
|         else: | ||||
|             if not inspect.iscoroutinefunction(value.func): | ||||
|                 raise TypeError(f'inner partial function must be a coroutine') | ||||
|         The callback associated with this UI item. | ||||
|  | ||||
|             # Check if the partial is bound | ||||
|             try: | ||||
|                 bound_partial = partial.__self__ | ||||
|             except AttributeError: | ||||
|                 pass | ||||
|             else: | ||||
|                 self._pass_view_arg = not hasattr(bound_partial, '__discord_ui_view__') | ||||
|         This can be overriden by subclasses. | ||||
|  | ||||
|             self._callback = value | ||||
|             return | ||||
|  | ||||
|         try: | ||||
|             func_self = value.__self__ | ||||
|         except AttributeError: | ||||
|             pass | ||||
|         else: | ||||
|             if not isinstance(func_self, Item): | ||||
|                 raise TypeError(f'callback bound method must be from Item not {func_self!r}') | ||||
|             else: | ||||
|                 value = value.__func__ | ||||
|  | ||||
|         if not inspect.iscoroutinefunction(value): | ||||
|             raise TypeError(f'callback must be a coroutine not {value!r}') | ||||
|  | ||||
|         self._callback = value | ||||
|  | ||||
|     async def _do_call(self, view: View, interaction: Interaction): | ||||
|         if self._pass_view_arg: | ||||
|             await self._callback(view, self, interaction) | ||||
|         else: | ||||
|             await self._callback(self, interaction)  # type: ignore | ||||
|         Parameters | ||||
|         ----------- | ||||
|         interaction: :class:`Interaction` | ||||
|             The interaction that triggered this UI item. | ||||
|         """ | ||||
|         pass | ||||
|   | ||||
| @@ -31,7 +31,7 @@ import asyncio | ||||
| import sys | ||||
| import time | ||||
| import os | ||||
| from .item import Item | ||||
| from .item import Item, ItemCallbackType | ||||
| from ..enums import ComponentType | ||||
| from ..components import ( | ||||
|     Component, | ||||
| @@ -95,13 +95,13 @@ class View: | ||||
|     __discord_ui_view__: ClassVar[bool] = True | ||||
|  | ||||
|     if TYPE_CHECKING: | ||||
|         __view_children_items__: ClassVar[List[Item]] | ||||
|         __view_children_items__: ClassVar[List[ItemCallbackType]] | ||||
|  | ||||
|     def __init_subclass__(cls) -> None: | ||||
|         children: List[Item] = [] | ||||
|         children: List[ItemCallbackType] = [] | ||||
|         for base in reversed(cls.__mro__): | ||||
|             for member in base.__dict__.values(): | ||||
|                 if isinstance(member, Item): | ||||
|                 if hasattr(member, '__discord_ui_model_type__'): | ||||
|                     children.append(member) | ||||
|  | ||||
|         if len(children) > 25: | ||||
| @@ -111,7 +111,13 @@ class View: | ||||
|  | ||||
|     def __init__(self, timeout: Optional[float] = 180.0): | ||||
|         self.timeout = timeout | ||||
|         self.children: List[Item] = [i.copy() for i in self.__view_children_items__] | ||||
|         self.children: List[Item] = [] | ||||
|         for func in self.__view_children_items__: | ||||
|             item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__) | ||||
|             item.callback = partial(func, self, item) | ||||
|             item._view = self | ||||
|             self.children.append(item) | ||||
|  | ||||
|         self.id = os.urandom(16).hex() | ||||
|         self._cancel_callback: Optional[Callable[[View], None]] = None | ||||
|  | ||||
| @@ -171,11 +177,12 @@ class View: | ||||
|         if not isinstance(item, Item): | ||||
|             raise TypeError(f'expected Item not {item.__class__!r}') | ||||
|  | ||||
|         item._view = self | ||||
|         self.children.append(item) | ||||
|  | ||||
|     async def _scheduled_task(self, state: Any, item: Item, interaction: Interaction): | ||||
|         await state.http.create_interaction_response(interaction.id, interaction.token, type=6) | ||||
|         await item._do_call(self, interaction) | ||||
|         await item.callback(interaction) | ||||
|  | ||||
|     def dispatch(self, state: Any, item: Item, interaction: Interaction): | ||||
|         asyncio.create_task(self._scheduled_task(state, item, interaction), name=f'discord-ui-view-dispatch-{self.id}') | ||||
|   | ||||
		Reference in New Issue
	
	Block a user