mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-10-26 02:53:07 +00:00 
			
		
		
		
	Add Interaction.client property
This commit is contained in:
		| @@ -26,7 +26,7 @@ from __future__ import annotations | |||||||
| import inspect | import inspect | ||||||
| import sys | import sys | ||||||
| import traceback | import traceback | ||||||
| from typing import Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload | from typing import Callable, Dict, Generic, List, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar, Union, overload | ||||||
|  |  | ||||||
|  |  | ||||||
| from .namespace import Namespace, ResolveKey | from .namespace import Namespace, ResolveKey | ||||||
| @@ -52,8 +52,10 @@ if TYPE_CHECKING: | |||||||
|  |  | ||||||
| __all__ = ('CommandTree',) | __all__ = ('CommandTree',) | ||||||
|  |  | ||||||
|  | ClientT = TypeVar('ClientT', bound='Client') | ||||||
|  |  | ||||||
| class CommandTree: |  | ||||||
|  | class CommandTree(Generic[ClientT]): | ||||||
|     """Represents a container that holds application command information. |     """Represents a container that holds application command information. | ||||||
|  |  | ||||||
|     Parameters |     Parameters | ||||||
| @@ -62,8 +64,8 @@ class CommandTree: | |||||||
|         The client instance to get application command information from. |         The client instance to get application command information from. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, client: Client): |     def __init__(self, client: ClientT): | ||||||
|         self.client = client |         self.client: ClientT = client | ||||||
|         self._http = client.http |         self._http = client.http | ||||||
|         self._state = client._connection |         self._state = client._connection | ||||||
|         self._state._command_tree = self |         self._state._command_tree = self | ||||||
|   | |||||||
| @@ -77,6 +77,7 @@ from .threads import Thread | |||||||
| from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory | from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|  |     from typing_extensions import Self | ||||||
|     from .types.guild import Guild as GuildPayload |     from .types.guild import Guild as GuildPayload | ||||||
|     from .abc import SnowflakeTime, Snowflake, PrivateChannel |     from .abc import SnowflakeTime, Snowflake, PrivateChannel | ||||||
|     from .guild import GuildChannel |     from .guild import GuildChannel | ||||||
| @@ -254,7 +255,7 @@ class Client: | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         self._enable_debug_events: bool = options.pop('enable_debug_events', False) |         self._enable_debug_events: bool = options.pop('enable_debug_events', False) | ||||||
|         self._connection: ConnectionState = self._get_state(**options) |         self._connection: ConnectionState[Self] = self._get_state(**options) | ||||||
|         self._connection.shard_count = self.shard_count |         self._connection.shard_count = self.shard_count | ||||||
|         self._closed: bool = False |         self._closed: bool = False | ||||||
|         self._ready: asyncio.Event = asyncio.Event() |         self._ready: asyncio.Event = asyncio.Event() | ||||||
|   | |||||||
| @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. | |||||||
| """ | """ | ||||||
|  |  | ||||||
| from __future__ import annotations | from __future__ import annotations | ||||||
| from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union | from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union | ||||||
| import asyncio | import asyncio | ||||||
|  |  | ||||||
| from . import utils | from . import utils | ||||||
| @@ -53,6 +53,7 @@ if TYPE_CHECKING: | |||||||
|         Interaction as InteractionPayload, |         Interaction as InteractionPayload, | ||||||
|         InteractionData, |         InteractionData, | ||||||
|     ) |     ) | ||||||
|  |     from .client import Client | ||||||
|     from .guild import Guild |     from .guild import Guild | ||||||
|     from .state import ConnectionState |     from .state import ConnectionState | ||||||
|     from .file import File |     from .file import File | ||||||
| @@ -70,9 +71,10 @@ if TYPE_CHECKING: | |||||||
|     ] |     ] | ||||||
|  |  | ||||||
| MISSING: Any = utils.MISSING | MISSING: Any = utils.MISSING | ||||||
|  | ClientT = TypeVar('ClientT', bound='Client') | ||||||
|  |  | ||||||
|  |  | ||||||
| class Interaction: | class Interaction(Generic[ClientT]): | ||||||
|     """Represents a Discord interaction. |     """Represents a Discord interaction. | ||||||
|  |  | ||||||
|     An interaction happens when a user does an action that needs to |     An interaction happens when a user does an action that needs to | ||||||
| @@ -116,6 +118,7 @@ class Interaction: | |||||||
|         'version', |         'version', | ||||||
|         '_permissions', |         '_permissions', | ||||||
|         '_state', |         '_state', | ||||||
|  |         '_client', | ||||||
|         '_session', |         '_session', | ||||||
|         '_original_message', |         '_original_message', | ||||||
|         '_cs_response', |         '_cs_response', | ||||||
| @@ -123,8 +126,9 @@ class Interaction: | |||||||
|         '_cs_channel', |         '_cs_channel', | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     def __init__(self, *, data: InteractionPayload, state: ConnectionState): |     def __init__(self, *, data: InteractionPayload, state: ConnectionState[ClientT]): | ||||||
|         self._state: ConnectionState = state |         self._state: ConnectionState[ClientT] = state | ||||||
|  |         self._client: ClientT = state._get_client() | ||||||
|         self._session: ClientSession = state.http._HTTPClient__session  # type: ignore - Mangled attribute for __session |         self._session: ClientSession = state.http._HTTPClient__session  # type: ignore - Mangled attribute for __session | ||||||
|         self._original_message: Optional[InteractionMessage] = None |         self._original_message: Optional[InteractionMessage] = None | ||||||
|         self._from_data(data) |         self._from_data(data) | ||||||
| @@ -166,6 +170,11 @@ class Interaction: | |||||||
|             except KeyError: |             except KeyError: | ||||||
|                 pass |                 pass | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def client(self) -> ClientT: | ||||||
|  |         """:class:`Client`: The client that is handling this interaction.""" | ||||||
|  |         return self._client | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def guild(self) -> Optional[Guild]: |     def guild(self) -> Optional[Guild]: | ||||||
|         """Optional[:class:`Guild`]: The guild the interaction was sent from.""" |         """Optional[:class:`Guild`]: The guild the interaction was sent from.""" | ||||||
|   | |||||||
| @@ -46,6 +46,7 @@ from .enums import Status | |||||||
| from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict | from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|  |     from typing_extensions import Self | ||||||
|     from .gateway import DiscordWebSocket |     from .gateway import DiscordWebSocket | ||||||
|     from .activity import BaseActivity |     from .activity import BaseActivity | ||||||
|     from .enums import Status |     from .enums import Status | ||||||
| @@ -316,7 +317,7 @@ class AutoShardedClient(Client): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     if TYPE_CHECKING: |     if TYPE_CHECKING: | ||||||
|         _connection: AutoShardedConnectionState |         _connection: AutoShardedConnectionState[Self] | ||||||
|  |  | ||||||
|     def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None: |     def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None: | ||||||
|         kwargs.pop('shard_id', None) |         kwargs.pop('shard_id', None) | ||||||
|   | |||||||
| @@ -30,7 +30,21 @@ import copy | |||||||
| import datetime | import datetime | ||||||
| import itertools | import itertools | ||||||
| import logging | import logging | ||||||
| from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque | from typing import ( | ||||||
|  |     Dict, | ||||||
|  |     Generic, | ||||||
|  |     Optional, | ||||||
|  |     TYPE_CHECKING, | ||||||
|  |     Union, | ||||||
|  |     Callable, | ||||||
|  |     Any, | ||||||
|  |     List, | ||||||
|  |     TypeVar, | ||||||
|  |     Coroutine, | ||||||
|  |     Sequence, | ||||||
|  |     Tuple, | ||||||
|  |     Deque, | ||||||
|  | ) | ||||||
| import weakref | import weakref | ||||||
| import inspect | import inspect | ||||||
|  |  | ||||||
| @@ -84,6 +98,8 @@ if TYPE_CHECKING: | |||||||
|     T = TypeVar('T') |     T = TypeVar('T') | ||||||
|     Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable] |     Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable] | ||||||
|  |  | ||||||
|  | ClientT = TypeVar('ClientT', bound='Client') | ||||||
|  |  | ||||||
|  |  | ||||||
| class ChunkRequest: | class ChunkRequest: | ||||||
|     def __init__( |     def __init__( | ||||||
| @@ -143,10 +159,10 @@ async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) -> | |||||||
|         _log.exception('Exception occurred during %s', info) |         _log.exception('Exception occurred during %s', info) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ConnectionState: | class ConnectionState(Generic[ClientT]): | ||||||
|     if TYPE_CHECKING: |     if TYPE_CHECKING: | ||||||
|         _get_websocket: Callable[..., DiscordWebSocket] |         _get_websocket: Callable[..., DiscordWebSocket] | ||||||
|         _get_client: Callable[..., Client] |         _get_client: Callable[..., ClientT] | ||||||
|         _parsers: Dict[str, Callable[[Dict[str, Any]], None]] |         _parsers: Dict[str, Callable[[Dict[str, Any]], None]] | ||||||
|  |  | ||||||
|     def __init__( |     def __init__( | ||||||
| @@ -1471,7 +1487,7 @@ class ConnectionState: | |||||||
|         return Message(state=self, channel=channel, data=data) |         return Message(state=self, channel=channel, data=data) | ||||||
|  |  | ||||||
|  |  | ||||||
| class AutoShardedConnectionState(ConnectionState): | class AutoShardedConnectionState(ConnectionState[ClientT]): | ||||||
|     def __init__(self, *args: Any, **kwargs: Any) -> None: |     def __init__(self, *args: Any, **kwargs: Any) -> None: | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|         self.shard_ids: Union[List[int], range] = [] |         self.shard_ids: Union[List[int], range] = [] | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user