mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-09-05 09:26:10 +00:00
Add zstd gateway compression to speed profile
This commit is contained in:
committed by
GitHub
parent
d10e70e04c
commit
91f300a28a
@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import array
|
||||
@ -41,7 +42,6 @@ from typing import (
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
@ -71,6 +71,7 @@ import types
|
||||
import typing
|
||||
import warnings
|
||||
import logging
|
||||
import zlib
|
||||
|
||||
import yarl
|
||||
|
||||
@ -81,6 +82,12 @@ except ModuleNotFoundError:
|
||||
else:
|
||||
HAS_ORJSON = True
|
||||
|
||||
try:
|
||||
import zstandard # type: ignore
|
||||
except ImportError:
|
||||
_HAS_ZSTD = False
|
||||
else:
|
||||
_HAS_ZSTD = True
|
||||
|
||||
__all__ = (
|
||||
'oauth_url',
|
||||
@ -148,8 +155,11 @@ if TYPE_CHECKING:
|
||||
from .invite import Invite
|
||||
from .template import Template
|
||||
|
||||
class _RequestLike(Protocol):
|
||||
headers: Mapping[str, Any]
|
||||
class _DecompressionContext(Protocol):
|
||||
COMPRESSION_TYPE: str
|
||||
|
||||
def decompress(self, data: bytes, /) -> str | None:
|
||||
...
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
@ -1416,3 +1426,45 @@ def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'o
|
||||
return f'{seq[0]} {final} {seq[1]}'
|
||||
|
||||
return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}'
|
||||
|
||||
|
||||
if _HAS_ZSTD:
|
||||
|
||||
class _ZstdDecompressionContext:
|
||||
__slots__ = ('context',)
|
||||
|
||||
COMPRESSION_TYPE: str = 'zstd-stream'
|
||||
|
||||
def __init__(self) -> None:
|
||||
decompressor = zstandard.ZstdDecompressor()
|
||||
self.context = decompressor.decompressobj()
|
||||
|
||||
def decompress(self, data: bytes, /) -> str | None:
|
||||
# Each WS message is a complete gateway message
|
||||
return self.context.decompress(data).decode('utf-8')
|
||||
|
||||
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZstdDecompressionContext
|
||||
else:
|
||||
|
||||
class _ZlibDecompressionContext:
|
||||
__slots__ = ('context', 'buffer')
|
||||
|
||||
COMPRESSION_TYPE: str = 'zlib-stream'
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.buffer: bytearray = bytearray()
|
||||
self.context = zlib.decompressobj()
|
||||
|
||||
def decompress(self, data: bytes, /) -> str | None:
|
||||
self.buffer.extend(data)
|
||||
|
||||
# Check whether ending is Z_SYNC_FLUSH
|
||||
if len(data) < 4 or data[-4:] != b'\x00\x00\xff\xff':
|
||||
return
|
||||
|
||||
msg = self.context.decompress(self.buffer)
|
||||
self.buffer = bytearray()
|
||||
|
||||
return msg.decode('utf-8')
|
||||
|
||||
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZlibDecompressionContext
|
||||
|
Reference in New Issue
Block a user