Parse gateway URL as an actual URL using yarl

Discord has changed the URL format to make it infeasible to edit it
using basic string interpolation.
This commit is contained in:
Rapptz 2022-09-17 22:49:29 -04:00
parent 46d194df57
commit 8aaeb6acfa
2 changed files with 13 additions and 10 deletions

View File

@ -37,6 +37,7 @@ import zlib
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar
import aiohttp
import yarl
from . import utils
from .activity import BaseActivity
@ -287,11 +288,11 @@ class DiscordWebSocket:
_initial_identify: bool
shard_id: Optional[int]
shard_count: Optional[int]
gateway: str
gateway: yarl.URL
_max_heartbeat_timeout: float
# fmt: off
DEFAULT_GATEWAY = 'wss://gateway.discord.gg/'
DEFAULT_GATEWAY = yarl.URL('wss://gateway.discord.gg/')
DISPATCH = 0
HEARTBEAT = 1
IDENTIFY = 2
@ -346,7 +347,7 @@ class DiscordWebSocket:
client: Client,
*,
initial: bool = False,
gateway: Optional[str] = None,
gateway: Optional[yarl.URL] = None,
shard_id: Optional[int] = None,
session: Optional[str] = None,
sequence: Optional[int] = None,
@ -364,11 +365,11 @@ class DiscordWebSocket:
gateway = gateway or cls.DEFAULT_GATEWAY
if zlib:
url = f'{gateway}?v={INTERNAL_API_VERSION}&encoding={encoding}&compress=zlib-stream'
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream')
else:
url = f'{gateway}?v={INTERNAL_API_VERSION}&encoding={encoding}'
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding)
socket = await client.http.ws_connect(url)
socket = await client.http.ws_connect(str(url))
ws = cls(socket, loop=client.loop)
# dynamically add attributes needed
@ -556,7 +557,7 @@ class DiscordWebSocket:
if event == 'READY':
self.sequence = msg['s']
self.session_id = data['session_id']
self.gateway = data['resume_gateway_url']
self.gateway = yarl.URL(data['resume_gateway_url'])
_log.info('Shard ID %s has connected to Gateway (Session ID: %s).', self.shard_id, self.session_id)
elif event == 'RESUMED':

View File

@ -28,6 +28,7 @@ import asyncio
import logging
import aiohttp
import yarl
from .state import AutoShardedConnectionState
from .client import Client
@ -403,7 +404,7 @@ class AutoShardedClient(Client):
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()}
async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None:
async def launch_shard(self, gateway: yarl.URL, shard_id: int, *, initial: bool = False) -> None:
try:
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
ws = await asyncio.wait_for(coro, timeout=180.0)
@ -422,9 +423,10 @@ class AutoShardedClient(Client):
if self.shard_count is None:
self.shard_count: int
self.shard_count, gateway = await self.http.get_bot_gateway()
self.shard_count, gateway_url = await self.http.get_bot_gateway()
gateway = yarl.URL(gateway_url)
else:
gateway = await self.http.get_gateway()
gateway = DiscordWebSocket.DEFAULT_GATEWAY
self._connection.shard_count = self.shard_count