mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-10-25 18:43:00 +00:00 
			
		
		
		
	Add Interaction.client property
This commit is contained in:
		| @@ -26,7 +26,7 @@ from __future__ import annotations | ||||
| import inspect | ||||
| import sys | ||||
| import traceback | ||||
| from typing import Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload | ||||
| from typing import Callable, Dict, Generic, List, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar, Union, overload | ||||
|  | ||||
|  | ||||
| from .namespace import Namespace, ResolveKey | ||||
| @@ -52,8 +52,10 @@ if TYPE_CHECKING: | ||||
|  | ||||
| __all__ = ('CommandTree',) | ||||
|  | ||||
| ClientT = TypeVar('ClientT', bound='Client') | ||||
|  | ||||
| class CommandTree: | ||||
|  | ||||
| class CommandTree(Generic[ClientT]): | ||||
|     """Represents a container that holds application command information. | ||||
|  | ||||
|     Parameters | ||||
| @@ -62,8 +64,8 @@ class CommandTree: | ||||
|         The client instance to get application command information from. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, client: Client): | ||||
|         self.client = client | ||||
|     def __init__(self, client: ClientT): | ||||
|         self.client: ClientT = client | ||||
|         self._http = client.http | ||||
|         self._state = client._connection | ||||
|         self._state._command_tree = self | ||||
|   | ||||
| @@ -77,6 +77,7 @@ from .threads import Thread | ||||
| from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from typing_extensions import Self | ||||
|     from .types.guild import Guild as GuildPayload | ||||
|     from .abc import SnowflakeTime, Snowflake, PrivateChannel | ||||
|     from .guild import GuildChannel | ||||
| @@ -254,7 +255,7 @@ class Client: | ||||
|         } | ||||
|  | ||||
|         self._enable_debug_events: bool = options.pop('enable_debug_events', False) | ||||
|         self._connection: ConnectionState = self._get_state(**options) | ||||
|         self._connection: ConnectionState[Self] = self._get_state(**options) | ||||
|         self._connection.shard_count = self.shard_count | ||||
|         self._closed: bool = False | ||||
|         self._ready: asyncio.Event = asyncio.Event() | ||||
|   | ||||
| @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. | ||||
| """ | ||||
|  | ||||
| from __future__ import annotations | ||||
| from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union | ||||
| from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union | ||||
| import asyncio | ||||
|  | ||||
| from . import utils | ||||
| @@ -53,6 +53,7 @@ if TYPE_CHECKING: | ||||
|         Interaction as InteractionPayload, | ||||
|         InteractionData, | ||||
|     ) | ||||
|     from .client import Client | ||||
|     from .guild import Guild | ||||
|     from .state import ConnectionState | ||||
|     from .file import File | ||||
| @@ -70,9 +71,10 @@ if TYPE_CHECKING: | ||||
|     ] | ||||
|  | ||||
| MISSING: Any = utils.MISSING | ||||
| ClientT = TypeVar('ClientT', bound='Client') | ||||
|  | ||||
|  | ||||
| class Interaction: | ||||
| class Interaction(Generic[ClientT]): | ||||
|     """Represents a Discord interaction. | ||||
|  | ||||
|     An interaction happens when a user does an action that needs to | ||||
| @@ -116,6 +118,7 @@ class Interaction: | ||||
|         'version', | ||||
|         '_permissions', | ||||
|         '_state', | ||||
|         '_client', | ||||
|         '_session', | ||||
|         '_original_message', | ||||
|         '_cs_response', | ||||
| @@ -123,8 +126,9 @@ class Interaction: | ||||
|         '_cs_channel', | ||||
|     ) | ||||
|  | ||||
|     def __init__(self, *, data: InteractionPayload, state: ConnectionState): | ||||
|         self._state: ConnectionState = state | ||||
|     def __init__(self, *, data: InteractionPayload, state: ConnectionState[ClientT]): | ||||
|         self._state: ConnectionState[ClientT] = state | ||||
|         self._client: ClientT = state._get_client() | ||||
|         self._session: ClientSession = state.http._HTTPClient__session  # type: ignore - Mangled attribute for __session | ||||
|         self._original_message: Optional[InteractionMessage] = None | ||||
|         self._from_data(data) | ||||
| @@ -166,6 +170,11 @@ class Interaction: | ||||
|             except KeyError: | ||||
|                 pass | ||||
|  | ||||
|     @property | ||||
|     def client(self) -> ClientT: | ||||
|         """:class:`Client`: The client that is handling this interaction.""" | ||||
|         return self._client | ||||
|  | ||||
|     @property | ||||
|     def guild(self) -> Optional[Guild]: | ||||
|         """Optional[:class:`Guild`]: The guild the interaction was sent from.""" | ||||
|   | ||||
| @@ -46,6 +46,7 @@ from .enums import Status | ||||
| from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from typing_extensions import Self | ||||
|     from .gateway import DiscordWebSocket | ||||
|     from .activity import BaseActivity | ||||
|     from .enums import Status | ||||
| @@ -316,7 +317,7 @@ class AutoShardedClient(Client): | ||||
|     """ | ||||
|  | ||||
|     if TYPE_CHECKING: | ||||
|         _connection: AutoShardedConnectionState | ||||
|         _connection: AutoShardedConnectionState[Self] | ||||
|  | ||||
|     def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None: | ||||
|         kwargs.pop('shard_id', None) | ||||
|   | ||||
| @@ -30,7 +30,21 @@ import copy | ||||
| import datetime | ||||
| import itertools | ||||
| import logging | ||||
| from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque | ||||
| from typing import ( | ||||
|     Dict, | ||||
|     Generic, | ||||
|     Optional, | ||||
|     TYPE_CHECKING, | ||||
|     Union, | ||||
|     Callable, | ||||
|     Any, | ||||
|     List, | ||||
|     TypeVar, | ||||
|     Coroutine, | ||||
|     Sequence, | ||||
|     Tuple, | ||||
|     Deque, | ||||
| ) | ||||
| import weakref | ||||
| import inspect | ||||
|  | ||||
| @@ -84,6 +98,8 @@ if TYPE_CHECKING: | ||||
|     T = TypeVar('T') | ||||
|     Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable] | ||||
|  | ||||
| ClientT = TypeVar('ClientT', bound='Client') | ||||
|  | ||||
|  | ||||
| class ChunkRequest: | ||||
|     def __init__( | ||||
| @@ -143,10 +159,10 @@ async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) -> | ||||
|         _log.exception('Exception occurred during %s', info) | ||||
|  | ||||
|  | ||||
| class ConnectionState: | ||||
| class ConnectionState(Generic[ClientT]): | ||||
|     if TYPE_CHECKING: | ||||
|         _get_websocket: Callable[..., DiscordWebSocket] | ||||
|         _get_client: Callable[..., Client] | ||||
|         _get_client: Callable[..., ClientT] | ||||
|         _parsers: Dict[str, Callable[[Dict[str, Any]], None]] | ||||
|  | ||||
|     def __init__( | ||||
| @@ -1471,7 +1487,7 @@ class ConnectionState: | ||||
|         return Message(state=self, channel=channel, data=data) | ||||
|  | ||||
|  | ||||
| class AutoShardedConnectionState(ConnectionState): | ||||
| class AutoShardedConnectionState(ConnectionState[ClientT]): | ||||
|     def __init__(self, *args: Any, **kwargs: Any) -> None: | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.shard_ids: Union[List[int], range] = [] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user