diff --git a/discord/gateway.py b/discord/gateway.py index dcef3eb5..74c60057 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -49,9 +49,15 @@ if TYPE_CHECKING: from .state import ConnectionState from .voice_client import VoiceClient + T = TypeVar('T') DWS = TypeVar('DWS', bound='DiscordWebSocket') DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket') + Coro = Callable[..., Coroutine[Any, Any, Any]] + Predicate = Callable[[Dict[str, Any]], bool] + DataCallable = Callable[[Dict[str, Any]], T] + Result = Optional[DataCallable[Any]] + log: logging.Logger = logging.getLogger(__name__) __all__ = ( @@ -82,9 +88,9 @@ class WebSocketClosure(Exception): class EventListener(NamedTuple): - predicate: Callable[[Dict[str, Any]], bool] + predicate: Predicate event: str - result: Optional[Callable[[Dict[str, Any]], Any]] + result: Result future: asyncio.Future @@ -317,9 +323,9 @@ class DiscordWebSocket: # attributes that get set in from_client self.token: str = utils.MISSING self._connection: ConnectionState = utils.MISSING - self._discord_parsers: Dict[str, Callable[[Dict[str, Any]], None]] = utils.MISSING + self._discord_parsers: Dict[str, DataCallable[None]] = utils.MISSING self.gateway: str = utils.MISSING - self.call_hooks: Callable[..., Coroutine[Any, Any, None]] = utils.MISSING + self.call_hooks: Coro = utils.MISSING self._initial_identify: bool = utils.MISSING self.shard_id: Optional[int] = utils.MISSING self.shard_count: Optional[int] = utils.MISSING @@ -384,7 +390,7 @@ class DiscordWebSocket: await ws.resume() return ws - def wait_for(self, event: str, predicate: Callable[[Dict[str, Any]], bool], result: Optional[Callable[[Dict[str, Any]], Any]] = None) -> asyncio.Future: + def wait_for(self, event: str, predicate: Predicate, result: Result = None) -> asyncio.Future: """Waits for a DISPATCH'd event that meets the predicate. Parameters @@ -763,7 +769,7 @@ class DiscordVoiceWebSocket: CLIENT_CONNECT = 12 CLIENT_DISCONNECT = 13 - def __init__(self, socket: aiohttp.ClientWebSocketResponse, loop: asyncio.AbstractEventLoop, *, hook: Optional[Callable[..., Any]] = None) -> None: + def __init__(self, socket: aiohttp.ClientWebSocketResponse, loop: asyncio.AbstractEventLoop, *, hook: Optional[Coro] = None) -> None: self.ws: aiohttp.ClientWebSocketResponse = socket self.loop: asyncio.AbstractEventLoop = loop self._keep_alive: VoiceKeepAliveHandler = utils.MISSING @@ -812,7 +818,7 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) @classmethod - async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Any]] = None) -> DVWS: + async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume: bool = False, hook: Optional[Coro] = None) -> DVWS: """Creates a voice websocket for the :class:`VoiceClient`.""" gateway = 'wss://' + client.endpoint + '/?v=4' http = client._state.http