530 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			530 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""
 | 
						|
The MIT License (MIT)
 | 
						|
 | 
						|
Copyright (c) 2015-present Rapptz
 | 
						|
 | 
						|
Permission is hereby granted, free of charge, to any person obtaining a
 | 
						|
copy of this software and associated documentation files (the "Software"),
 | 
						|
to deal in the Software without restriction, including without limitation
 | 
						|
the rights to use, copy, modify, merge, publish, distribute, sublicense,
 | 
						|
and/or sell copies of the Software, and to permit persons to whom the
 | 
						|
Software is furnished to do so, subject to the following conditions:
 | 
						|
 | 
						|
The above copyright notice and this permission notice shall be included in
 | 
						|
all copies or substantial portions of the Software.
 | 
						|
 | 
						|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 | 
						|
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 | 
						|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 | 
						|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 | 
						|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 | 
						|
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 | 
						|
DEALINGS IN THE SOFTWARE.
 | 
						|
"""
 | 
						|
 | 
						|
from __future__ import annotations
 | 
						|
from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Sequence, TYPE_CHECKING, Tuple
 | 
						|
from functools import partial
 | 
						|
from itertools import groupby
 | 
						|
 | 
						|
import traceback
 | 
						|
import asyncio
 | 
						|
import sys
 | 
						|
import time
 | 
						|
import os
 | 
						|
from .item import Item, ItemCallbackType
 | 
						|
from ..components import (
 | 
						|
    Component,
 | 
						|
    ActionRow as ActionRowComponent,
 | 
						|
    _component_factory,
 | 
						|
    Button as ButtonComponent,
 | 
						|
    SelectMenu as SelectComponent,
 | 
						|
)
 | 
						|
 | 
						|
__all__ = (
 | 
						|
    'View',
 | 
						|
)
 | 
						|
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    from ..interactions import Interaction
 | 
						|
    from ..message import Message
 | 
						|
    from ..types.components import Component as ComponentPayload
 | 
						|
    from ..state import ConnectionState
 | 
						|
 | 
						|
 | 
						|
def _walk_all_components(components: List[Component]) -> Iterator[Component]:
 | 
						|
    for item in components:
 | 
						|
        if isinstance(item, ActionRowComponent):
 | 
						|
            yield from item.children
 | 
						|
        else:
 | 
						|
            yield item
 | 
						|
 | 
						|
 | 
						|
def _component_to_item(component: Component) -> Item:
    """Wrap a read-only message component in its ``discord.ui`` counterpart.

    Button components become :class:`Button` and select menus become
    :class:`Select`. Anything else falls back to :meth:`Item.from_component`.
    """
    # The concrete item classes are imported inside the branches instead of at
    # module level (presumably to avoid a circular import — verify).
    if isinstance(component, ButtonComponent):
        from .button import Button as item_cls
    elif isinstance(component, SelectComponent):
        from .select import Select as item_cls
    else:
        item_cls = Item
    return item_cls.from_component(component)
 | 
						|
 | 
						|
 | 
						|
class _ViewWeights:
 | 
						|
    __slots__ = (
 | 
						|
        'weights',
 | 
						|
    )
 | 
						|
 | 
						|
    def __init__(self, children: List[Item]):
 | 
						|
        self.weights: List[int] = [0, 0, 0, 0, 0]
 | 
						|
 | 
						|
        key = lambda i: sys.maxsize if i.row is None else i.row
 | 
						|
        children = sorted(children, key=key)
 | 
						|
        for row, group in groupby(children, key=key):
 | 
						|
            for item in group:
 | 
						|
                self.add_item(item)
 | 
						|
 | 
						|
    def find_open_space(self, item: Item) -> int:
 | 
						|
        for index, weight in enumerate(self.weights):
 | 
						|
            if weight + item.width <= 5:
 | 
						|
                return index
 | 
						|
 | 
						|
        raise ValueError('could not find open space for item')
 | 
						|
 | 
						|
    def add_item(self, item: Item) -> None:
 | 
						|
        if item.row is not None:
 | 
						|
            total = self.weights[item.row] + item.width
 | 
						|
            if total > 5:
 | 
						|
                raise ValueError(f'item would not fit at row {item.row} ({total} > 5 width)')
 | 
						|
            self.weights[item.row] = total
 | 
						|
            item._rendered_row = item.row
 | 
						|
        else:
 | 
						|
            index = self.find_open_space(item)
 | 
						|
            self.weights[index] += item.width
 | 
						|
            item._rendered_row = index
 | 
						|
 | 
						|
    def remove_item(self, item: Item) -> None:
 | 
						|
        if item._rendered_row is not None:
 | 
						|
            self.weights[item._rendered_row] -= item.width
 | 
						|
            item._rendered_row = None
 | 
						|
 | 
						|
    def clear(self) -> None:
 | 
						|
        self.weights = [0, 0, 0, 0, 0]
 | 
						|
 | 
						|
 | 
						|
class View:
    """Represents a UI view.

    This object must be inherited to create a UI within Discord.

    .. versionadded:: 2.0

    Parameters
    -----------
    timeout: Optional[:class:`float`]
        Timeout in seconds from last interaction with the UI before no longer accepting input.
        If ``None`` then there is no timeout.

    Attributes
    ------------
    timeout: Optional[:class:`float`]
        Timeout from last interaction with the UI before no longer accepting input.
        If ``None`` then there is no timeout.
    children: List[:class:`Item`]
        The list of children attached to this view.
    """

    __discord_ui_view__: ClassVar[bool] = True
    __view_children_items__: ClassVar[List[ItemCallbackType]] = []

    def __init_subclass__(cls) -> None:
        # Collect decorated item callbacks across the MRO, base classes first.
        # Keying by attribute name means a subclass overriding a decorated
        # callback *replaces* the base version. Previously both were kept,
        # producing a duplicate item. The override keeps the base's position,
        # so row ordering stays stable.
        children: Dict[str, ItemCallbackType] = {}
        for base in reversed(cls.__mro__):
            for name, member in base.__dict__.items():
                if hasattr(member, '__discord_ui_model_type__'):
                    children[name] = member

        if len(children) > 25:
            raise TypeError('View cannot have more than 25 children')

        cls.__view_children_items__ = list(children.values())

    def __init__(self, *, timeout: Optional[float] = 180.0):
        self.timeout = timeout
        self.children: List[Item] = []
        # Instantiate one item per decorated callback. The callback is bound
        # to (view, item) and exposed on the instance under the function's name.
        for func in self.__view_children_items__:
            item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__)
            item.callback = partial(func, self, item)
            item._view = self
            setattr(self, func.__name__, item)
            self.children.append(item)

        self.__weights = _ViewWeights(self.children)
        # Requires a running event loop: the "stopped" future is bound to it.
        loop = asyncio.get_running_loop()
        self.id: str = os.urandom(16).hex()
        self.__cancel_callback: Optional[Callable[[View], None]] = None
        self.__timeout_expiry: Optional[float] = None
        self.__timeout_task: Optional[asyncio.Task[None]] = None
        self.__stopped: asyncio.Future[bool] = loop.create_future()

    def __repr__(self) -> str:
        return f'<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>'

    async def __timeout_task_impl(self) -> None:
        while True:
            # Guard just in case someone changes the value of the timeout at runtime
            if self.timeout is None:
                return

            if self.__timeout_expiry is None:
                return self._dispatch_timeout()

            # Check if we've elapsed our currently set timeout
            now = time.monotonic()
            if now >= self.__timeout_expiry:
                return self._dispatch_timeout()

            # Wait N seconds to see if timeout data has been refreshed
            await asyncio.sleep(self.__timeout_expiry - now)

    def to_components(self) -> List[Dict[str, Any]]:
        """Serialise the children into action-row payloads, one row per rendered row.

        Items without a rendered row are grouped with row 0.
        """

        def key(item: Item) -> int:
            return item._rendered_row or 0

        components: List[Dict[str, Any]] = []
        # groupby requires its input to be sorted by the same key.
        for _, group in groupby(sorted(self.children, key=key), key=key):
            components.append(
                {
                    'type': 1,
                    'components': [item.to_component_dict() for item in group],
                }
            )

        return components

    @classmethod
    def from_message(cls, message: Message, /, *, timeout: Optional[float] = 180.0) -> View:
        """Converts a message's components into a :class:`View`.

        The :attr:`.Message.components` of a message are read-only
        and separate types from those in the ``discord.ui`` namespace.
        In order to modify and edit message components they must be
        converted into a :class:`View` first.

        Parameters
        -----------
        message: :class:`discord.Message`
            The message with components to convert into a view.
        timeout: Optional[:class:`float`]
            The timeout of the converted view.

        Returns
        --------
        :class:`View`
            The converted view. This always returns a :class:`View` and not
            one of its subclasses.
        """
        view = View(timeout=timeout)
        for component in _walk_all_components(message.components):
            view.add_item(_component_to_item(component))
        return view

    @property
    def _expires_at(self) -> Optional[float]:
        # Monotonic time at which the view would expire if started now.
        if self.timeout:
            return time.monotonic() + self.timeout
        return None

    def add_item(self, item: Item) -> None:
        """Adds an item to the view.

        Parameters
        -----------
        item: :class:`Item`
            The item to add to the view.

        Raises
        --------
        TypeError
            An :class:`Item` was not passed.
        ValueError
            Maximum number of children has been exceeded (25)
            or the row the item is trying to be added to is full.
        """

        # A view holds at most 25 children, so adding is refused once 25 are
        # present. The previous `> 25` check allowed a 26th item.
        if len(self.children) >= 25:
            raise ValueError('maximum number of children exceeded')

        if not isinstance(item, Item):
            raise TypeError(f'expected Item not {item.__class__!r}')

        # Reserve row space first so a full row raises before the item is attached.
        self.__weights.add_item(item)

        item._view = self
        self.children.append(item)

    def remove_item(self, item: Item) -> None:
        """Removes an item from the view.

        Removing an item that is not in the view is a no-op.

        Parameters
        -----------
        item: :class:`Item`
            The item to remove from the view.
        """

        try:
            self.children.remove(item)
        except ValueError:
            pass
        else:
            self.__weights.remove_item(item)

    def clear_items(self) -> None:
        """Removes all items from the view."""
        self.children.clear()
        self.__weights.clear()

    async def interaction_check(self, interaction: Interaction) -> bool:
        """|coro|

        A callback that is called when an interaction happens within the view
        that checks whether the view should process item callbacks for the interaction.

        This is useful to override if, for example, you want to ensure that the
        interaction author is a given user.

        The default implementation of this returns ``True``.

        .. note::

            If an exception occurs within the body then the check
            is considered a failure and :meth:`on_error` is called.

        Parameters
        -----------
        interaction: :class:`~discord.Interaction`
            The interaction that occurred.

        Returns
        ---------
        :class:`bool`
            Whether the view children's callbacks should be called.
        """
        return True

    async def on_timeout(self) -> None:
        """|coro|

        A callback that is called when a view's timeout elapses without being explicitly stopped.
        """
        pass

    async def on_error(self, error: Exception, item: Item, interaction: Interaction) -> None:
        """|coro|

        A callback that is called when an item's callback or :meth:`interaction_check`
        fails with an error.

        The default implementation prints the traceback to stderr.

        Parameters
        -----------
        error: :class:`Exception`
            The exception that was raised.
        item: :class:`Item`
            The item that failed the dispatch.
        interaction: :class:`~discord.Interaction`
            The interaction that led to the failure.
        """
        print(f'Ignoring exception in view {self} for item {item}:', file=sys.stderr)
        traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)

    async def _scheduled_task(self, item: Item, interaction: Interaction):
        # Runs one item callback. Any interaction pushes the timeout back,
        # an unanswered interaction is deferred, and errors go to on_error.
        try:
            if self.timeout:
                self.__timeout_expiry = time.monotonic() + self.timeout

            allow = await self.interaction_check(interaction)
            if not allow:
                return

            await item.callback(interaction)
            if not interaction.response._responded:
                await interaction.response.defer()
        except Exception as e:
            return await self.on_error(e, item, interaction)

    def _start_listening_from_store(self, store: ViewStore) -> None:
        # stop() calls this to unregister the view from the store.
        self.__cancel_callback = store.remove_view
        if self.timeout:
            loop = asyncio.get_running_loop()
            if self.__timeout_task is not None:
                self.__timeout_task.cancel()

            self.__timeout_expiry = time.monotonic() + self.timeout
            self.__timeout_task = loop.create_task(self.__timeout_task_impl())

    def _dispatch_timeout(self):
        # Resolve wait() with True (timed out) exactly once, then run on_timeout.
        if self.__stopped.done():
            return

        self.__stopped.set_result(True)
        asyncio.create_task(self.on_timeout(), name=f'discord-ui-view-timeout-{self.id}')

    def _dispatch_item(self, item: Item, interaction: Interaction):
        # Finished views ignore further interactions.
        if self.__stopped.done():
            return

        asyncio.create_task(self._scheduled_task(item, interaction), name=f'discord-ui-view-dispatch-{self.id}')

    def refresh(self, components: List[Component]):
        # This is pretty hacky at the moment
        # Existing dispatchable items are reused (keyed by type + custom_id) so
        # their callbacks survive. Unknown components become fresh items.
        # fmt: off
        old_state: Dict[Tuple[int, str], Item] = {
            (item.type.value, item.custom_id): item  # type: ignore
            for item in self.children
            if item.is_dispatchable()
        }
        # fmt: on
        children: List[Item] = []
        for component in _walk_all_components(components):
            try:
                older = old_state[(component.type.value, component.custom_id)]  # type: ignore
            except (KeyError, AttributeError):
                children.append(_component_to_item(component))
            else:
                older.refresh_component(component)
                children.append(older)

        self.children = children

    def stop(self) -> None:
        """Stops listening to interaction events from this view.

        This operation cannot be undone.
        """
        if not self.__stopped.done():
            self.__stopped.set_result(False)

        self.__timeout_expiry = None
        if self.__timeout_task is not None:
            self.__timeout_task.cancel()
            self.__timeout_task = None

        if self.__cancel_callback:
            self.__cancel_callback(self)
            self.__cancel_callback = None

    def is_finished(self) -> bool:
        """:class:`bool`: Whether the view has finished interacting."""
        return self.__stopped.done()

    def is_dispatching(self) -> bool:
        """:class:`bool`: Whether the view has been added for dispatching purposes."""
        return self.__cancel_callback is not None

    def is_persistent(self) -> bool:
        """:class:`bool`: Whether the view is set up as persistent.

        A persistent view has all their components with a set ``custom_id`` and
        a :attr:`timeout` set to ``None``.
        """
        return self.timeout is None and all(item.is_persistent() for item in self.children)

    async def wait(self) -> bool:
        """Waits until the view has finished interacting.

        A view is considered finished when :meth:`stop` is called
        or it times out.

        Returns
        --------
        :class:`bool`
            If ``True``, then the view timed out. If ``False`` then
            the view finished normally.
        """
        return await self.__stopped
 | 
						|
 | 
						|
 | 
						|
class ViewStore:
    """Routes component interactions to the views that registered them.

    Dispatch entries are keyed by ``(component_type, message_id, custom_id)``.
    A view added without a message id is stored under ``message_id=None``.
    :meth:`dispatch` falls back to that key when no message-specific entry
    exists.
    """

    def __init__(self, state: ConnectionState):
        # (component_type, message_id, custom_id): (View, Item)
        self._views: Dict[Tuple[int, Optional[int], str], Tuple[View, Item]] = {}
        # message_id: View
        self._synced_message_views: Dict[int, View] = {}
        self._state: ConnectionState = state

    @property
    def persistent_views(self) -> Sequence[View]:
        # A view owns one entry per dispatchable item, so dedupe by view id.
        # fmt: off
        views = {
            view.id: view
            for (_, (view, _)) in self._views.items()
            if view.is_persistent()
        }
        # fmt: on
        return list(views.values())

    def __verify_integrity(self):
        # Drop dispatch entries belonging to views that have stopped or timed out.
        to_remove: List[Tuple[int, Optional[int], str]] = [
            key for key, (view, _) in self._views.items() if view.is_finished()
        ]
        for key in to_remove:
            del self._views[key]

    def add_view(self, view: View, message_id: Optional[int] = None):
        """Register *view* for dispatch, optionally tied to one message."""
        self.__verify_integrity()

        view._start_listening_from_store(self)
        for item in view.children:
            if item.is_dispatchable():
                self._views[(item.type.value, message_id, item.custom_id)] = (view, item)  # type: ignore

        if message_id is not None:
            self._synced_message_views[message_id] = view

    def remove_view(self, view: View):
        """Unregister every dispatch entry and message binding owned by *view*."""
        # Keys include the message id the view was registered under. Popping
        # rebuilt ``(type, custom_id)`` keys never matched anything, so match
        # on the owning view instead. This also covers items that
        # View.refresh may have replaced in ``view.children``.
        stale = [key for key, (owner, _) in self._views.items() if owner.id == view.id]
        for key in stale:
            del self._views[key]

        for key, value in self._synced_message_views.items():
            if value.id == view.id:
                # Safe despite mutating during iteration: we break immediately.
                del self._synced_message_views[key]
                break

    def dispatch(self, component_type: int, custom_id: str, interaction: Interaction):
        """Hand *interaction* to the view/item registered for it, if any."""
        self.__verify_integrity()
        message_id: Optional[int] = interaction.message and interaction.message.id
        key = (component_type, message_id, custom_id)
        # Fallback to None message_id searches in case a persistent view
        # was added without an associated message_id
        value = self._views.get(key) or self._views.get((component_type, None, custom_id))
        if value is None:
            return

        view, item = value
        item.refresh_state(interaction)
        view._dispatch_item(item, interaction)

    def is_message_tracked(self, message_id: int):
        """Whether a view is bound to the message with *message_id*."""
        return message_id in self._synced_message_views

    def remove_message_tracking(self, message_id: int) -> Optional[View]:
        """Unbind and return the view tied to *message_id*, or ``None``."""
        return self._synced_message_views.pop(message_id, None)

    def update_from_message(self, message_id: int, components: List[ComponentPayload]):
        # pre-req: is_message_tracked == true
        view = self._synced_message_views[message_id]
        view.refresh([_component_factory(d) for d in components])
 |