mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-19 15:36:02 +00:00
Change the way callbacks are defined to allow deriving
This should hopefully make these work more consistently as other functions do.
This commit is contained in:
parent
cc56f31bcd
commit
4c0ebc9221
@ -87,8 +87,6 @@ class Button(Item):
|
||||
The emoji of the button, if available.
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = Item.__slots__ + ('_underlying',)
|
||||
|
||||
__item_repr_attributes__: Tuple[str, ...] = (
|
||||
'style',
|
||||
'url',
|
||||
@ -192,19 +190,6 @@ class Button(Item):
|
||||
else:
|
||||
self._underlying.emoji = None
|
||||
|
||||
def copy(self: B) -> B:
|
||||
button = self.__class__(
|
||||
style=self.style,
|
||||
label=self.label,
|
||||
disabled=self.disabled,
|
||||
custom_id=self.custom_id,
|
||||
url=self.url,
|
||||
emoji=self.emoji,
|
||||
group=self.group_id,
|
||||
)
|
||||
button.callback = self.callback
|
||||
return button
|
||||
|
||||
@classmethod
|
||||
def from_component(cls: Type[B], button: ButtonComponent) -> B:
|
||||
return cls(
|
||||
@ -239,7 +224,7 @@ def button(
|
||||
style: ButtonStyle = ButtonStyle.grey,
|
||||
emoji: Optional[Union[str, PartialEmoji]] = None,
|
||||
group: Optional[int] = None,
|
||||
) -> Callable[[ItemCallbackType], Button]:
|
||||
) -> Callable[[ItemCallbackType], ItemCallbackType]:
|
||||
"""A decorator that attaches a button to a component.
|
||||
|
||||
The function being decorated should have three parameters, ``self`` representing
|
||||
@ -275,14 +260,22 @@ def button(
|
||||
ordering.
|
||||
"""
|
||||
|
||||
def decorator(func: ItemCallbackType) -> Button:
|
||||
def decorator(func: ItemCallbackType) -> ItemCallbackType:
|
||||
nonlocal custom_id
|
||||
if not inspect.iscoroutinefunction(func):
|
||||
raise TypeError('button function must be a coroutine function')
|
||||
|
||||
custom_id = custom_id or os.urandom(32).hex()
|
||||
button = Button(style=style, custom_id=custom_id, url=None, disabled=disabled, label=label, emoji=emoji, group=group)
|
||||
button.callback = func
|
||||
return button
|
||||
func.__discord_ui_model_type__ = Button
|
||||
func.__discord_ui_model_kwargs__ = {
|
||||
'style': style,
|
||||
'custom_id': custom_id,
|
||||
'url': None,
|
||||
'disabled': disabled,
|
||||
'label': label,
|
||||
'emoji': emoji,
|
||||
'group': group,
|
||||
}
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
@ -24,8 +24,7 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Coroutine, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
|
||||
import inspect
|
||||
from typing import Any, Callable, Coroutine, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
|
||||
|
||||
from ..interactions import Interaction
|
||||
|
||||
@ -50,25 +49,15 @@ class Item:
|
||||
- :class:`discord.ui.Button`
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = (
|
||||
'_callback',
|
||||
'_pass_view_arg',
|
||||
'group_id',
|
||||
)
|
||||
|
||||
__item_repr_attributes__: Tuple[str, ...] = ('group_id',)
|
||||
|
||||
def __init__(self):
|
||||
self._callback: Optional[ItemCallbackType] = None
|
||||
self._pass_view_arg = True
|
||||
self._view: Optional[View] = None
|
||||
self.group_id: Optional[int] = None
|
||||
|
||||
def to_component_dict(self) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def copy(self: I) -> I:
|
||||
raise NotImplementedError
|
||||
|
||||
def refresh_state(self, component: Component) -> None:
|
||||
return None
|
||||
|
||||
@ -88,53 +77,20 @@ class Item:
|
||||
return f'<{self.__class__.__name__} {attrs}>'
|
||||
|
||||
@property
|
||||
def callback(self) -> Optional[ItemCallbackType]:
|
||||
"""Returns the underlying callback associated with this interaction."""
|
||||
return self._callback
|
||||
def view(self) -> Optional[View]:
|
||||
"""Optional[:class:`View`]: The underlying view for this item."""
|
||||
return self._view
|
||||
|
||||
@callback.setter
|
||||
def callback(self, value: Optional[ItemCallbackType]):
|
||||
if value is None:
|
||||
self._callback = None
|
||||
return
|
||||
async def callback(self, interaction: Interaction):
|
||||
"""|coro|
|
||||
|
||||
# Check if it's a partial function
|
||||
try:
|
||||
partial = value.func
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
if not inspect.iscoroutinefunction(value.func):
|
||||
raise TypeError(f'inner partial function must be a coroutine')
|
||||
The callback associated with this UI item.
|
||||
|
||||
# Check if the partial is bound
|
||||
try:
|
||||
bound_partial = partial.__self__
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self._pass_view_arg = not hasattr(bound_partial, '__discord_ui_view__')
|
||||
This can be overriden by subclasses.
|
||||
|
||||
self._callback = value
|
||||
return
|
||||
|
||||
try:
|
||||
func_self = value.__self__
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
if not isinstance(func_self, Item):
|
||||
raise TypeError(f'callback bound method must be from Item not {func_self!r}')
|
||||
else:
|
||||
value = value.__func__
|
||||
|
||||
if not inspect.iscoroutinefunction(value):
|
||||
raise TypeError(f'callback must be a coroutine not {value!r}')
|
||||
|
||||
self._callback = value
|
||||
|
||||
async def _do_call(self, view: View, interaction: Interaction):
|
||||
if self._pass_view_arg:
|
||||
await self._callback(view, self, interaction)
|
||||
else:
|
||||
await self._callback(self, interaction) # type: ignore
|
||||
Parameters
|
||||
-----------
|
||||
interaction: :class:`Interaction`
|
||||
The interaction that triggered this UI item.
|
||||
"""
|
||||
pass
|
||||
|
@ -31,7 +31,7 @@ import asyncio
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
from .item import Item
|
||||
from .item import Item, ItemCallbackType
|
||||
from ..enums import ComponentType
|
||||
from ..components import (
|
||||
Component,
|
||||
@ -95,13 +95,13 @@ class View:
|
||||
__discord_ui_view__: ClassVar[bool] = True
|
||||
|
||||
if TYPE_CHECKING:
|
||||
__view_children_items__: ClassVar[List[Item]]
|
||||
__view_children_items__: ClassVar[List[ItemCallbackType]]
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
children: List[Item] = []
|
||||
children: List[ItemCallbackType] = []
|
||||
for base in reversed(cls.__mro__):
|
||||
for member in base.__dict__.values():
|
||||
if isinstance(member, Item):
|
||||
if hasattr(member, '__discord_ui_model_type__'):
|
||||
children.append(member)
|
||||
|
||||
if len(children) > 25:
|
||||
@ -111,7 +111,13 @@ class View:
|
||||
|
||||
def __init__(self, timeout: Optional[float] = 180.0):
|
||||
self.timeout = timeout
|
||||
self.children: List[Item] = [i.copy() for i in self.__view_children_items__]
|
||||
self.children: List[Item] = []
|
||||
for func in self.__view_children_items__:
|
||||
item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__)
|
||||
item.callback = partial(func, self, item)
|
||||
item._view = self
|
||||
self.children.append(item)
|
||||
|
||||
self.id = os.urandom(16).hex()
|
||||
self._cancel_callback: Optional[Callable[[View], None]] = None
|
||||
|
||||
@ -171,11 +177,12 @@ class View:
|
||||
if not isinstance(item, Item):
|
||||
raise TypeError(f'expected Item not {item.__class__!r}')
|
||||
|
||||
item._view = self
|
||||
self.children.append(item)
|
||||
|
||||
async def _scheduled_task(self, state: Any, item: Item, interaction: Interaction):
|
||||
await state.http.create_interaction_response(interaction.id, interaction.token, type=6)
|
||||
await item._do_call(self, interaction)
|
||||
await item.callback(interaction)
|
||||
|
||||
def dispatch(self, state: Any, item: Item, interaction: Interaction):
|
||||
asyncio.create_task(self._scheduled_task(state, item, interaction), name=f'discord-ui-view-dispatch-{self.id}')
|
||||
|
Loading…
x
Reference in New Issue
Block a user