Parse gateway URL as an actual URL using yarl

Discord has changed the URL format to make it infeasible to edit it
using basic string interpolation.
This commit is contained in:
Rapptz 2022-09-17 22:49:29 -04:00
parent 46d194df57
commit 8aaeb6acfa
2 changed files with 13 additions and 10 deletions

View File

@ -37,6 +37,7 @@ import zlib
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar
import aiohttp import aiohttp
import yarl
from . import utils from . import utils
from .activity import BaseActivity from .activity import BaseActivity
@ -287,11 +288,11 @@ class DiscordWebSocket:
_initial_identify: bool _initial_identify: bool
shard_id: Optional[int] shard_id: Optional[int]
shard_count: Optional[int] shard_count: Optional[int]
gateway: str gateway: yarl.URL
_max_heartbeat_timeout: float _max_heartbeat_timeout: float
# fmt: off # fmt: off
DEFAULT_GATEWAY = 'wss://gateway.discord.gg/' DEFAULT_GATEWAY = yarl.URL('wss://gateway.discord.gg/')
DISPATCH = 0 DISPATCH = 0
HEARTBEAT = 1 HEARTBEAT = 1
IDENTIFY = 2 IDENTIFY = 2
@ -346,7 +347,7 @@ class DiscordWebSocket:
client: Client, client: Client,
*, *,
initial: bool = False, initial: bool = False,
gateway: Optional[str] = None, gateway: Optional[yarl.URL] = None,
shard_id: Optional[int] = None, shard_id: Optional[int] = None,
session: Optional[str] = None, session: Optional[str] = None,
sequence: Optional[int] = None, sequence: Optional[int] = None,
@ -364,11 +365,11 @@ class DiscordWebSocket:
gateway = gateway or cls.DEFAULT_GATEWAY gateway = gateway or cls.DEFAULT_GATEWAY
if zlib: if zlib:
url = f'{gateway}?v={INTERNAL_API_VERSION}&encoding={encoding}&compress=zlib-stream' url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream')
else: else:
url = f'{gateway}?v={INTERNAL_API_VERSION}&encoding={encoding}' url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding)
socket = await client.http.ws_connect(url) socket = await client.http.ws_connect(str(url))
ws = cls(socket, loop=client.loop) ws = cls(socket, loop=client.loop)
# dynamically add attributes needed # dynamically add attributes needed
@ -556,7 +557,7 @@ class DiscordWebSocket:
if event == 'READY': if event == 'READY':
self.sequence = msg['s'] self.sequence = msg['s']
self.session_id = data['session_id'] self.session_id = data['session_id']
self.gateway = data['resume_gateway_url'] self.gateway = yarl.URL(data['resume_gateway_url'])
_log.info('Shard ID %s has connected to Gateway (Session ID: %s).', self.shard_id, self.session_id) _log.info('Shard ID %s has connected to Gateway (Session ID: %s).', self.shard_id, self.session_id)
elif event == 'RESUMED': elif event == 'RESUMED':

View File

@ -28,6 +28,7 @@ import asyncio
import logging import logging
import aiohttp import aiohttp
import yarl
from .state import AutoShardedConnectionState from .state import AutoShardedConnectionState
from .client import Client from .client import Client
@ -403,7 +404,7 @@ class AutoShardedClient(Client):
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object.""" """Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()} return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()}
async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None: async def launch_shard(self, gateway: yarl.URL, shard_id: int, *, initial: bool = False) -> None:
try: try:
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id) coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
ws = await asyncio.wait_for(coro, timeout=180.0) ws = await asyncio.wait_for(coro, timeout=180.0)
@ -422,9 +423,10 @@ class AutoShardedClient(Client):
if self.shard_count is None: if self.shard_count is None:
self.shard_count: int self.shard_count: int
self.shard_count, gateway = await self.http.get_bot_gateway() self.shard_count, gateway_url = await self.http.get_bot_gateway()
gateway = yarl.URL(gateway_url)
else: else:
gateway = await self.http.get_gateway() gateway = DiscordWebSocket.DEFAULT_GATEWAY
self._connection.shard_count = self.shard_count self._connection.shard_count = self.shard_count