mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-10-31 21:43:01 +00:00 
			
		
		
		
	Add zstd gateway compression to speed profile
This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							d10e70e04c
						
					
				
				
					commit
					91f300a28a
				
			| @@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | ||||
| FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER | ||||
| DEALINGS IN THE SOFTWARE. | ||||
| """ | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| import asyncio | ||||
| @@ -32,7 +33,6 @@ import sys | ||||
| import time | ||||
| import threading | ||||
| import traceback | ||||
| import zlib | ||||
|  | ||||
| from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple | ||||
|  | ||||
| @@ -325,8 +325,7 @@ class DiscordWebSocket: | ||||
|         # ws related stuff | ||||
|         self.session_id: Optional[str] = None | ||||
|         self.sequence: Optional[int] = None | ||||
|         self._zlib: zlib._Decompress = zlib.decompressobj() | ||||
|         self._buffer: bytearray = bytearray() | ||||
|         self._decompressor: utils._DecompressionContext = utils._ActiveDecompressionContext() | ||||
|         self._close_code: Optional[int] = None | ||||
|         self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter() | ||||
|  | ||||
| @@ -355,7 +354,7 @@ class DiscordWebSocket: | ||||
|         sequence: Optional[int] = None, | ||||
|         resume: bool = False, | ||||
|         encoding: str = 'json', | ||||
|         zlib: bool = True, | ||||
|         compress: bool = True, | ||||
|     ) -> Self: | ||||
|         """Creates a main websocket for Discord from a :class:`Client`. | ||||
|  | ||||
| @@ -366,10 +365,12 @@ class DiscordWebSocket: | ||||
|  | ||||
|         gateway = gateway or cls.DEFAULT_GATEWAY | ||||
|  | ||||
|         if zlib: | ||||
|             url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream') | ||||
|         else: | ||||
|         if not compress: | ||||
|             url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding) | ||||
|         else: | ||||
|             url = gateway.with_query( | ||||
|                 v=INTERNAL_API_VERSION, encoding=encoding, compress=utils._ActiveDecompressionContext.COMPRESSION_TYPE | ||||
|             ) | ||||
|  | ||||
|         socket = await client.http.ws_connect(str(url)) | ||||
|         ws = cls(socket, loop=client.loop) | ||||
| @@ -488,13 +489,11 @@ class DiscordWebSocket: | ||||
|  | ||||
|     async def received_message(self, msg: Any, /) -> None: | ||||
|         if type(msg) is bytes: | ||||
|             self._buffer.extend(msg) | ||||
|             msg = self._decompressor.decompress(msg) | ||||
|  | ||||
|             if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff': | ||||
|             # Received a partial gateway message | ||||
|             if msg is None: | ||||
|                 return | ||||
|             msg = self._zlib.decompress(self._buffer) | ||||
|             msg = msg.decode('utf-8') | ||||
|             self._buffer = bytearray() | ||||
|  | ||||
|         self.log_receive(msg) | ||||
|         msg = utils._from_json(msg) | ||||
|   | ||||
| @@ -2701,28 +2701,13 @@ class HTTPClient: | ||||
|  | ||||
|     # Misc | ||||
|  | ||||
|     async def get_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> str: | ||||
|         try: | ||||
|             data = await self.request(Route('GET', '/gateway')) | ||||
|         except HTTPException as exc: | ||||
|             raise GatewayNotFound() from exc | ||||
|         if zlib: | ||||
|             value = '{0}?encoding={1}&v={2}&compress=zlib-stream' | ||||
|         else: | ||||
|             value = '{0}?encoding={1}&v={2}' | ||||
|         return value.format(data['url'], encoding, INTERNAL_API_VERSION) | ||||
|  | ||||
|     async def get_bot_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> Tuple[int, str]: | ||||
|     async def get_bot_gateway(self) -> Tuple[int, str]: | ||||
|         try: | ||||
|             data = await self.request(Route('GET', '/gateway/bot')) | ||||
|         except HTTPException as exc: | ||||
|             raise GatewayNotFound() from exc | ||||
|  | ||||
|         if zlib: | ||||
|             value = '{0}?encoding={1}&v={2}&compress=zlib-stream' | ||||
|         else: | ||||
|             value = '{0}?encoding={1}&v={2}' | ||||
|         return data['shards'], value.format(data['url'], encoding, INTERNAL_API_VERSION) | ||||
|         return data['shards'], data['url'] | ||||
|  | ||||
|     def get_user(self, user_id: Snowflake) -> Response[user.User]: | ||||
|         return self.request(Route('GET', '/users/{user_id}', user_id=user_id)) | ||||
|   | ||||
| @@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | ||||
| FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER | ||||
| DEALINGS IN THE SOFTWARE. | ||||
| """ | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| import array | ||||
| @@ -41,7 +42,6 @@ from typing import ( | ||||
|     Iterator, | ||||
|     List, | ||||
|     Literal, | ||||
|     Mapping, | ||||
|     NamedTuple, | ||||
|     Optional, | ||||
|     Protocol, | ||||
| @@ -71,6 +71,7 @@ import types | ||||
| import typing | ||||
| import warnings | ||||
| import logging | ||||
| import zlib | ||||
|  | ||||
| import yarl | ||||
|  | ||||
| @@ -81,6 +82,12 @@ except ModuleNotFoundError: | ||||
| else: | ||||
|     HAS_ORJSON = True | ||||
|  | ||||
| try: | ||||
|     import zstandard  # type: ignore | ||||
| except ImportError: | ||||
|     _HAS_ZSTD = False | ||||
| else: | ||||
|     _HAS_ZSTD = True | ||||
|  | ||||
| __all__ = ( | ||||
|     'oauth_url', | ||||
| @@ -148,8 +155,11 @@ if TYPE_CHECKING: | ||||
|     from .invite import Invite | ||||
|     from .template import Template | ||||
|  | ||||
|     class _RequestLike(Protocol): | ||||
|         headers: Mapping[str, Any] | ||||
|     class _DecompressionContext(Protocol): | ||||
|         COMPRESSION_TYPE: str | ||||
|  | ||||
|         def decompress(self, data: bytes, /) -> str | None: | ||||
|             ... | ||||
|  | ||||
|     P = ParamSpec('P') | ||||
|  | ||||
| @@ -1416,3 +1426,45 @@ def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'o | ||||
|         return f'{seq[0]} {final} {seq[1]}' | ||||
|  | ||||
|     return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}' | ||||
|  | ||||
|  | ||||
| if _HAS_ZSTD: | ||||
|  | ||||
|     class _ZstdDecompressionContext: | ||||
|         __slots__ = ('context',) | ||||
|  | ||||
|         COMPRESSION_TYPE: str = 'zstd-stream' | ||||
|  | ||||
|         def __init__(self) -> None: | ||||
|             decompressor = zstandard.ZstdDecompressor() | ||||
|             self.context = decompressor.decompressobj() | ||||
|  | ||||
|         def decompress(self, data: bytes, /) -> str | None: | ||||
|             # Each WS message is a complete gateway message | ||||
|             return self.context.decompress(data).decode('utf-8') | ||||
|  | ||||
|     _ActiveDecompressionContext: Type[_DecompressionContext] = _ZstdDecompressionContext | ||||
| else: | ||||
|  | ||||
|     class _ZlibDecompressionContext: | ||||
|         __slots__ = ('context', 'buffer') | ||||
|  | ||||
|         COMPRESSION_TYPE: str = 'zlib-stream' | ||||
|  | ||||
|         def __init__(self) -> None: | ||||
|             self.buffer: bytearray = bytearray() | ||||
|             self.context = zlib.decompressobj() | ||||
|  | ||||
|         def decompress(self, data: bytes, /) -> str | None: | ||||
|             self.buffer.extend(data) | ||||
|  | ||||
|             # Check whether ending is Z_SYNC_FLUSH | ||||
|             if len(data) < 4 or data[-4:] != b'\x00\x00\xff\xff': | ||||
|                 return | ||||
|  | ||||
|             msg = self.context.decompress(self.buffer) | ||||
|             self.buffer = bytearray() | ||||
|  | ||||
|             return msg.decode('utf-8') | ||||
|  | ||||
|     _ActiveDecompressionContext: Type[_DecompressionContext] = _ZlibDecompressionContext | ||||
|   | ||||
| @@ -56,6 +56,7 @@ speed = [ | ||||
|     "aiodns>=1.1; sys_platform != 'win32'", | ||||
|     "Brotli", | ||||
|     "cchardet==2.1.7; python_version < '3.10'", | ||||
|     "zstandard>=0.23.0" | ||||
| ] | ||||
| test = [ | ||||
|     "coverage[toml]", | ||||
|   | ||||
		Reference in New Issue
	
	Block a user