Add zstd gateway compression to speed profile

This commit is contained in:
Lilly Rose Berner
2024-10-09 23:30:03 +02:00
committed by GitHub
parent d10e70e04c
commit 91f300a28a
4 changed files with 69 additions and 32 deletions

View File

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import array
@ -41,7 +42,6 @@ from typing import (
Iterator,
List,
Literal,
Mapping,
NamedTuple,
Optional,
Protocol,
@ -71,6 +71,7 @@ import types
import typing
import warnings
import logging
import zlib
import yarl
@ -81,6 +82,12 @@ except ModuleNotFoundError:
else:
HAS_ORJSON = True
try:
import zstandard # type: ignore
except ImportError:
_HAS_ZSTD = False
else:
_HAS_ZSTD = True
__all__ = (
'oauth_url',
@ -148,8 +155,11 @@ if TYPE_CHECKING:
from .invite import Invite
from .template import Template
class _RequestLike(Protocol):
headers: Mapping[str, Any]
class _DecompressionContext(Protocol):
COMPRESSION_TYPE: str
def decompress(self, data: bytes, /) -> str | None:
...
P = ParamSpec('P')
@ -1416,3 +1426,45 @@ def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'o
return f'{seq[0]} {final} {seq[1]}'
return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}'
if _HAS_ZSTD:
class _ZstdDecompressionContext:
__slots__ = ('context',)
COMPRESSION_TYPE: str = 'zstd-stream'
def __init__(self) -> None:
decompressor = zstandard.ZstdDecompressor()
self.context = decompressor.decompressobj()
def decompress(self, data: bytes, /) -> str | None:
# Each WS message is a complete gateway message
return self.context.decompress(data).decode('utf-8')
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZstdDecompressionContext
else:
class _ZlibDecompressionContext:
__slots__ = ('context', 'buffer')
COMPRESSION_TYPE: str = 'zlib-stream'
def __init__(self) -> None:
self.buffer: bytearray = bytearray()
self.context = zlib.decompressobj()
def decompress(self, data: bytes, /) -> str | None:
self.buffer.extend(data)
# Check whether ending is Z_SYNC_FLUSH
if len(data) < 4 or data[-4:] != b'\x00\x00\xff\xff':
return
msg = self.context.decompress(self.buffer)
self.buffer = bytearray()
return msg.decode('utf-8')
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZlibDecompressionContext