diff --git a/discord/utils.py b/discord/utils.py index e2bc96bb1..e015cddb0 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -81,19 +81,20 @@ except ModuleNotFoundError: else: HAS_ORJSON = True +_ZSTD_SOURCE: Literal['zstandard', 'compression.zstd'] | None = None + try: from zstandard import ZstdDecompressor # type: ignore - _HAS_ZSTD = True + _ZSTD_SOURCE = 'zstandard' except ImportError: try: from compression.zstd import ZstdDecompressor # type: ignore + + _ZSTD_SOURCE = 'compression.zstd' except ImportError: import zlib - _HAS_ZSTD = False - else: - _HAS_ZSTD = True __all__ = ( 'oauth_url', @@ -1432,7 +1433,7 @@ def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'o return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}' -if _HAS_ZSTD: +if _ZSTD_SOURCE is not None: class _ZstdDecompressionContext: __slots__ = ('decompressor',) @@ -1441,6 +1442,12 @@ if _HAS_ZSTD: def __init__(self) -> None: self.decompressor = ZstdDecompressor() + if _ZSTD_SOURCE == 'zstandard': + # The default API for zstandard requires a size hint when + # the size is not included in the zstandard frame. + # This constructs an instance of zstandard.ZstdDecompressionObj + # which dynamically allocates a buffer, matching stdlib module's behavior. + self.decompressor = self.decompressor.decompressobj() def decompress(self, data: bytes, /) -> str | None: # Each WS message is a complete gateway message