mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-18 15:06:07 +00:00
Add zstd gateway compression to speed profile
This commit is contained in:
parent
d10e70e04c
commit
91f300a28a
@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
@ -32,7 +33,6 @@ import sys
|
||||
import time
|
||||
import threading
|
||||
import traceback
|
||||
import zlib
|
||||
|
||||
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple
|
||||
|
||||
@ -325,8 +325,7 @@ class DiscordWebSocket:
|
||||
# ws related stuff
|
||||
self.session_id: Optional[str] = None
|
||||
self.sequence: Optional[int] = None
|
||||
self._zlib: zlib._Decompress = zlib.decompressobj()
|
||||
self._buffer: bytearray = bytearray()
|
||||
self._decompressor: utils._DecompressionContext = utils._ActiveDecompressionContext()
|
||||
self._close_code: Optional[int] = None
|
||||
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
|
||||
|
||||
@ -355,7 +354,7 @@ class DiscordWebSocket:
|
||||
sequence: Optional[int] = None,
|
||||
resume: bool = False,
|
||||
encoding: str = 'json',
|
||||
zlib: bool = True,
|
||||
compress: bool = True,
|
||||
) -> Self:
|
||||
"""Creates a main websocket for Discord from a :class:`Client`.
|
||||
|
||||
@ -366,10 +365,12 @@ class DiscordWebSocket:
|
||||
|
||||
gateway = gateway or cls.DEFAULT_GATEWAY
|
||||
|
||||
if zlib:
|
||||
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream')
|
||||
else:
|
||||
if not compress:
|
||||
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding)
|
||||
else:
|
||||
url = gateway.with_query(
|
||||
v=INTERNAL_API_VERSION, encoding=encoding, compress=utils._ActiveDecompressionContext.COMPRESSION_TYPE
|
||||
)
|
||||
|
||||
socket = await client.http.ws_connect(str(url))
|
||||
ws = cls(socket, loop=client.loop)
|
||||
@ -488,13 +489,11 @@ class DiscordWebSocket:
|
||||
|
||||
async def received_message(self, msg: Any, /) -> None:
|
||||
if type(msg) is bytes:
|
||||
self._buffer.extend(msg)
|
||||
msg = self._decompressor.decompress(msg)
|
||||
|
||||
if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff':
|
||||
# Received a partial gateway message
|
||||
if msg is None:
|
||||
return
|
||||
msg = self._zlib.decompress(self._buffer)
|
||||
msg = msg.decode('utf-8')
|
||||
self._buffer = bytearray()
|
||||
|
||||
self.log_receive(msg)
|
||||
msg = utils._from_json(msg)
|
||||
|
@ -2701,28 +2701,13 @@ class HTTPClient:
|
||||
|
||||
# Misc
|
||||
|
||||
async def get_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> str:
|
||||
try:
|
||||
data = await self.request(Route('GET', '/gateway'))
|
||||
except HTTPException as exc:
|
||||
raise GatewayNotFound() from exc
|
||||
if zlib:
|
||||
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
|
||||
else:
|
||||
value = '{0}?encoding={1}&v={2}'
|
||||
return value.format(data['url'], encoding, INTERNAL_API_VERSION)
|
||||
|
||||
async def get_bot_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> Tuple[int, str]:
|
||||
async def get_bot_gateway(self) -> Tuple[int, str]:
|
||||
try:
|
||||
data = await self.request(Route('GET', '/gateway/bot'))
|
||||
except HTTPException as exc:
|
||||
raise GatewayNotFound() from exc
|
||||
|
||||
if zlib:
|
||||
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
|
||||
else:
|
||||
value = '{0}?encoding={1}&v={2}'
|
||||
return data['shards'], value.format(data['url'], encoding, INTERNAL_API_VERSION)
|
||||
return data['shards'], data['url']
|
||||
|
||||
def get_user(self, user_id: Snowflake) -> Response[user.User]:
|
||||
return self.request(Route('GET', '/users/{user_id}', user_id=user_id))
|
||||
|
@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import array
|
||||
@ -41,7 +42,6 @@ from typing import (
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
@ -71,6 +71,7 @@ import types
|
||||
import typing
|
||||
import warnings
|
||||
import logging
|
||||
import zlib
|
||||
|
||||
import yarl
|
||||
|
||||
@ -81,6 +82,12 @@ except ModuleNotFoundError:
|
||||
else:
|
||||
HAS_ORJSON = True
|
||||
|
||||
try:
|
||||
import zstandard # type: ignore
|
||||
except ImportError:
|
||||
_HAS_ZSTD = False
|
||||
else:
|
||||
_HAS_ZSTD = True
|
||||
|
||||
__all__ = (
|
||||
'oauth_url',
|
||||
@ -148,8 +155,11 @@ if TYPE_CHECKING:
|
||||
from .invite import Invite
|
||||
from .template import Template
|
||||
|
||||
class _RequestLike(Protocol):
|
||||
headers: Mapping[str, Any]
|
||||
class _DecompressionContext(Protocol):
|
||||
COMPRESSION_TYPE: str
|
||||
|
||||
def decompress(self, data: bytes, /) -> str | None:
|
||||
...
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
@ -1416,3 +1426,45 @@ def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'o
|
||||
return f'{seq[0]} {final} {seq[1]}'
|
||||
|
||||
return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}'
|
||||
|
||||
|
||||
if _HAS_ZSTD:
|
||||
|
||||
class _ZstdDecompressionContext:
|
||||
__slots__ = ('context',)
|
||||
|
||||
COMPRESSION_TYPE: str = 'zstd-stream'
|
||||
|
||||
def __init__(self) -> None:
|
||||
decompressor = zstandard.ZstdDecompressor()
|
||||
self.context = decompressor.decompressobj()
|
||||
|
||||
def decompress(self, data: bytes, /) -> str | None:
|
||||
# Each WS message is a complete gateway message
|
||||
return self.context.decompress(data).decode('utf-8')
|
||||
|
||||
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZstdDecompressionContext
|
||||
else:
|
||||
|
||||
class _ZlibDecompressionContext:
|
||||
__slots__ = ('context', 'buffer')
|
||||
|
||||
COMPRESSION_TYPE: str = 'zlib-stream'
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.buffer: bytearray = bytearray()
|
||||
self.context = zlib.decompressobj()
|
||||
|
||||
def decompress(self, data: bytes, /) -> str | None:
|
||||
self.buffer.extend(data)
|
||||
|
||||
# Check whether ending is Z_SYNC_FLUSH
|
||||
if len(data) < 4 or data[-4:] != b'\x00\x00\xff\xff':
|
||||
return
|
||||
|
||||
msg = self.context.decompress(self.buffer)
|
||||
self.buffer = bytearray()
|
||||
|
||||
return msg.decode('utf-8')
|
||||
|
||||
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZlibDecompressionContext
|
||||
|
@ -56,6 +56,7 @@ speed = [
|
||||
"aiodns>=1.1; sys_platform != 'win32'",
|
||||
"Brotli",
|
||||
"cchardet==2.1.7; python_version < '3.10'",
|
||||
"zstandard>=0.23.0"
|
||||
]
|
||||
test = [
|
||||
"coverage[toml]",
|
||||
|
Loading…
x
Reference in New Issue
Block a user