Add zstd gateway compression to speed profile

This commit is contained in:
Lilly Rose Berner
2024-10-09 23:30:03 +02:00
committed by GitHub
parent d10e70e04c
commit 91f300a28a
4 changed files with 69 additions and 32 deletions

View File

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
@ -32,7 +33,6 @@ import sys
import time
import threading
import traceback
import zlib
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple
@ -325,8 +325,7 @@ class DiscordWebSocket:
# ws related stuff
self.session_id: Optional[str] = None
self.sequence: Optional[int] = None
self._zlib: zlib._Decompress = zlib.decompressobj()
self._buffer: bytearray = bytearray()
self._decompressor: utils._DecompressionContext = utils._ActiveDecompressionContext()
self._close_code: Optional[int] = None
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
@ -355,7 +354,7 @@ class DiscordWebSocket:
sequence: Optional[int] = None,
resume: bool = False,
encoding: str = 'json',
zlib: bool = True,
compress: bool = True,
) -> Self:
"""Creates a main websocket for Discord from a :class:`Client`.
@ -366,10 +365,12 @@ class DiscordWebSocket:
gateway = gateway or cls.DEFAULT_GATEWAY
if zlib:
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream')
else:
if not compress:
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding)
else:
url = gateway.with_query(
v=INTERNAL_API_VERSION, encoding=encoding, compress=utils._ActiveDecompressionContext.COMPRESSION_TYPE
)
socket = await client.http.ws_connect(str(url))
ws = cls(socket, loop=client.loop)
@ -488,13 +489,11 @@ class DiscordWebSocket:
async def received_message(self, msg: Any, /) -> None:
if type(msg) is bytes:
self._buffer.extend(msg)
msg = self._decompressor.decompress(msg)
if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff':
# Received a partial gateway message
if msg is None:
return
msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8')
self._buffer = bytearray()
self.log_receive(msg)
msg = utils._from_json(msg)