[Zstandard] Decompress even when discord doesn't encode size information

This commit is contained in:
Michael H
2025-11-07 18:57:01 -05:00
committed by GitHub
parent 0ace5f8b51
commit 8b15475496

View File

@@ -81,19 +81,20 @@ except ModuleNotFoundError:
else:
HAS_ORJSON = True
_ZSTD_SOURCE: Literal['zstandard', 'compression.zstd'] | None = None
try:
from zstandard import ZstdDecompressor # type: ignore
_HAS_ZSTD = True
_ZSTD_SOURCE = 'zstandard'
except ImportError:
try:
from compression.zstd import ZstdDecompressor # type: ignore
_ZSTD_SOURCE = 'compression.zstd'
except ImportError:
import zlib
_HAS_ZSTD = False
else:
_HAS_ZSTD = True
__all__ = (
'oauth_url',
@@ -1432,7 +1433,7 @@ def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'o
return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}'
if _HAS_ZSTD:
if _ZSTD_SOURCE is not None:
class _ZstdDecompressionContext:
__slots__ = ('decompressor',)
@@ -1441,6 +1442,12 @@ if _HAS_ZSTD:
def __init__(self) -> None:
self.decompressor = ZstdDecompressor()
if _ZSTD_SOURCE == 'zstandard':
# The default API for zstandard requires a size hint when
# the size is not included in the zstandard frame.
# This constructs an instance of zstandard.ZstdDecompressionObj
# which dynamically allocates a buffer, matching stdlib module's behavior.
self.decompressor = self.decompressor.decompressobj()
def decompress(self, data: bytes, /) -> str | None:
# Each WS message is a complete gateway message