Rework view timeouts to work as documented
This commit is contained in:
		| @@ -162,14 +162,32 @@ class View: | ||||
|  | ||||
|         self.__weights = _ViewWeights(self.children) | ||||
|         loop = asyncio.get_running_loop() | ||||
|         self.id = os.urandom(16).hex() | ||||
|         self._cancel_callback: Optional[Callable[[View], None]] = None | ||||
|         self._timeout_handler: Optional[asyncio.TimerHandle] = None | ||||
|         self._stopped = loop.create_future() | ||||
|         self.id: str = os.urandom(16).hex() | ||||
|         self.__cancel_callback: Optional[Callable[[View], None]] = None | ||||
|         self.__timeout_expiry: Optional[float] = None | ||||
|         self.__timeout_task: Optional[asyncio.Task[None]] = None | ||||
|         self.__stopped: asyncio.Future[bool] = loop.create_future() | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         return f'<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>' | ||||
|  | ||||
|     async def __timeout_task_impl(self) -> None: | ||||
|         while True: | ||||
|             # Guard just in case someone changes the value of the timeout at runtime | ||||
|             if self.timeout is None: | ||||
|                 return | ||||
|  | ||||
|             if self.__timeout_expiry is None: | ||||
|                 return self._dispatch_timeout() | ||||
|  | ||||
|             # Check if we've elapsed our currently set timeout | ||||
|             now = time.monotonic() | ||||
|             if now >= self.__timeout_expiry: | ||||
|                 return self._dispatch_timeout() | ||||
|  | ||||
|             # Wait N seconds to see if timeout data has been refreshed | ||||
|             await asyncio.sleep(self.__timeout_expiry - now) | ||||
|  | ||||
|     def to_components(self) -> List[Dict[str, Any]]: | ||||
|         def key(item: Item) -> int: | ||||
|             return item._rendered_row or 0 | ||||
| @@ -328,8 +346,11 @@ class View: | ||||
|         print(f'Ignoring exception in view {self} for item {item}:', file=sys.stderr) | ||||
|         traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr) | ||||
|  | ||||
|     async def _scheduled_task(self, state: Any, item: Item, interaction: Interaction): | ||||
|     async def _scheduled_task(self, item: Item, interaction: Interaction): | ||||
|         try: | ||||
|             if self.timeout: | ||||
|                 self.__timeout_expiry = time.monotonic() + self.timeout | ||||
|  | ||||
|             allow = await self.interaction_check(interaction) | ||||
|             if not allow: | ||||
|                 return | ||||
| @@ -340,21 +361,28 @@ class View: | ||||
|         except Exception as e: | ||||
|             return await self.on_error(e, item, interaction) | ||||
|  | ||||
|     def _start_listening(self, store: ViewStore) -> None: | ||||
|         self._cancel_callback = partial(store.remove_view) | ||||
|     def _start_listening_from_store(self, store: ViewStore) -> None: | ||||
|         self.__cancel_callback = partial(store.remove_view) | ||||
|         if self.timeout: | ||||
|             loop = asyncio.get_running_loop() | ||||
|             self._timeout_handler = loop.call_later(self.timeout, self.dispatch_timeout) | ||||
|             if self.__timeout_task is not None: | ||||
|                 self.__timeout_task.cancel() | ||||
|  | ||||
|     def dispatch_timeout(self): | ||||
|         if self._stopped.done(): | ||||
|             self.__timeout_expiry = time.monotonic() + self.timeout | ||||
|             self.__timeout_task = loop.create_task(self.__timeout_task_impl()) | ||||
|  | ||||
|     def _dispatch_timeout(self): | ||||
|         if self.__stopped.done(): | ||||
|             return | ||||
|  | ||||
|         self._stopped.set_result(True) | ||||
|         self.__stopped.set_result(True) | ||||
|         asyncio.create_task(self.on_timeout(), name=f'discord-ui-view-timeout-{self.id}') | ||||
|  | ||||
|     def dispatch(self, state: Any, item: Item, interaction: Interaction): | ||||
|         asyncio.create_task(self._scheduled_task(state, item, interaction), name=f'discord-ui-view-dispatch-{self.id}') | ||||
|     def _dispatch_item(self, item: Item, interaction: Interaction): | ||||
|         if self.__stopped.done(): | ||||
|             return | ||||
|  | ||||
|         asyncio.create_task(self._scheduled_task(item, interaction), name=f'discord-ui-view-dispatch-{self.id}') | ||||
|  | ||||
|     def refresh(self, components: List[Component]): | ||||
|         # This is pretty hacky at the moment | ||||
| @@ -382,23 +410,25 @@ class View: | ||||
|  | ||||
|         This operation cannot be undone. | ||||
|         """ | ||||
|         if not self._stopped.done(): | ||||
|             self._stopped.set_result(False) | ||||
|         if not self.__stopped.done(): | ||||
|             self.__stopped.set_result(False) | ||||
|  | ||||
|         if self._timeout_handler: | ||||
|             self._timeout_handler.cancel() | ||||
|         self.__timeout_expiry = None | ||||
|         if self.__timeout_task is not None: | ||||
|             self.__timeout_task.cancel() | ||||
|             self.__timeout_task = None | ||||
|  | ||||
|         if self._cancel_callback: | ||||
|             self._cancel_callback(self) | ||||
|             self._cancel_callback = None | ||||
|         if self.__cancel_callback: | ||||
|             self.__cancel_callback(self) | ||||
|             self.__cancel_callback = None | ||||
|  | ||||
|     def is_finished(self) -> bool: | ||||
|         """:class:`bool`: Whether the view has finished interacting.""" | ||||
|         return self._stopped.done() | ||||
|         return self.__stopped.done() | ||||
|  | ||||
|     def is_dispatching(self) -> bool: | ||||
|         """:class:`bool`: Whether the view has been added for dispatching purposes.""" | ||||
|         return self._cancel_callback is not None | ||||
|         return self.__cancel_callback is not None | ||||
|  | ||||
|     def is_persistent(self) -> bool: | ||||
|         """:class:`bool`: Whether the view is set up as persistent. | ||||
| @@ -420,13 +450,13 @@ class View: | ||||
|             If ``True``, then the view timed out. If ``False`` then | ||||
|             the view finished normally. | ||||
|         """ | ||||
|         return await self._stopped | ||||
|         return await self.__stopped | ||||
|  | ||||
|  | ||||
| class ViewStore: | ||||
|     def __init__(self, state: ConnectionState): | ||||
|         # (component_type, custom_id): (View, Item, Expiry) | ||||
|         self._views: Dict[Tuple[int, str], Tuple[View, Item, Optional[float]]] = {} | ||||
|         # (component_type, custom_id): (View, Item) | ||||
|         self._views: Dict[Tuple[int, str], Tuple[View, Item]] = {} | ||||
|         # message_id: View | ||||
|         self._synced_message_views: Dict[int, View] = {} | ||||
|         self._state: ConnectionState = state | ||||
| @@ -436,7 +466,7 @@ class ViewStore: | ||||
|         # fmt: off | ||||
|         views = { | ||||
|             view.id: view | ||||
|             for (_, (view, _, _)) in self._views.items() | ||||
|             for (_, (view, _)) in self._views.items() | ||||
|             if view.is_persistent() | ||||
|         } | ||||
|         # fmt: on | ||||
| @@ -445,8 +475,8 @@ class ViewStore: | ||||
|     def __verify_integrity(self): | ||||
|         to_remove: List[Tuple[int, str]] = [] | ||||
|         now = time.monotonic() | ||||
|         for (k, (_, _, expiry)) in self._views.items(): | ||||
|             if expiry is not None and now >= expiry: | ||||
|         for (k, (view, _)) in self._views.items(): | ||||
|             if view.is_finished(): | ||||
|                 to_remove.append(k) | ||||
|  | ||||
|         for k in to_remove: | ||||
| @@ -455,11 +485,10 @@ class ViewStore: | ||||
|     def add_view(self, view: View, message_id: Optional[int] = None): | ||||
|         self.__verify_integrity() | ||||
|  | ||||
|         expiry = view._expires_at | ||||
|         view._start_listening(self) | ||||
|         view._start_listening_from_store(self) | ||||
|         for item in view.children: | ||||
|             if item.is_dispatchable(): | ||||
|                 self._views[(item.type.value, item.custom_id)] = (view, item, expiry)  # type: ignore | ||||
|                 self._views[(item.type.value, item.custom_id)] = (view, item)  # type: ignore | ||||
|  | ||||
|         if message_id is not None: | ||||
|             self._synced_message_views[message_id] = view | ||||
| @@ -481,10 +510,10 @@ class ViewStore: | ||||
|         if value is None: | ||||
|             return | ||||
|  | ||||
|         view, item, _ = value | ||||
|         self._views[key] = (view, item, view._expires_at) | ||||
|         view, item = value | ||||
|         self._views[key] = (view, item) | ||||
|         item.refresh_state(interaction) | ||||
|         view.dispatch(self._state, item, interaction) | ||||
|         view._dispatch_item(item, interaction) | ||||
|  | ||||
|     def is_message_tracked(self, message_id: int): | ||||
|         return message_id in self._synced_message_views | ||||
|   | ||||
		Reference in New Issue
	
	Block a user