mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-10-24 18:13:00 +00:00
Reimplement zlib streaming.
This time with less bugs. It turned out that the crash was due to a synchronisation issue between the pending reads and the actual shard polling mechanism. Essentially the pending reads would be cancelled via a simple bool but there would still be a pass left and thus we would have a single pending read left before or after running the polling mechanism and this would cause a race condition. Now the pending read mechanism is properly waited for before returning control back to the caller.
This commit is contained in:
@@ -186,6 +186,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
# ws related stuff
|
||||
self.session_id = None
|
||||
self.sequence = None
|
||||
self._zlib = zlib.decompressobj()
|
||||
self._buffer = bytearray()
|
||||
|
||||
@classmethod
|
||||
@asyncio.coroutine
|
||||
@@ -312,8 +314,17 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
self._dispatch('socket_raw_receive', msg)
|
||||
|
||||
if isinstance(msg, bytes):
|
||||
msg = zlib.decompress(msg, 15, 10490000) # This is 10 MiB
|
||||
self._buffer.extend(msg)
|
||||
|
||||
if len(msg) >= 4:
|
||||
if msg[-4:] == b'\x00\x00\xff\xff':
|
||||
msg = self._zlib.decompress(self._buffer)
|
||||
msg = msg.decode('utf-8')
|
||||
self._buffer = bytearray()
|
||||
else:
|
||||
return
|
||||
else:
|
||||
return
|
||||
|
||||
msg = json.loads(msg)
|
||||
|
||||
|
@@ -739,21 +739,29 @@ class HTTPClient:
|
||||
return self.request(Route('GET', '/oauth2/applications/@me'))
|
||||
|
||||
@asyncio.coroutine
|
||||
def get_gateway(self):
|
||||
def get_gateway(self, *, encoding='json', v=6, zlib=True):
|
||||
try:
|
||||
data = yield from self.request(Route('GET', '/gateway'))
|
||||
except HTTPException as e:
|
||||
raise GatewayNotFound() from e
|
||||
return data.get('url') + '?encoding=json&v=6'
|
||||
if zlib:
|
||||
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
|
||||
else:
|
||||
value = '{0}?encoding={1}&v={2}'
|
||||
return value.format(data['url'], encoding, v)
|
||||
|
||||
@asyncio.coroutine
|
||||
def get_bot_gateway(self):
|
||||
def get_bot_gateway(self, *, encoding='json', v=6, zlib=True):
|
||||
try:
|
||||
data = yield from self.request(Route('GET', '/gateway/bot'))
|
||||
except HTTPException as e:
|
||||
raise GatewayNotFound() from e
|
||||
|
||||
if zlib:
|
||||
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
|
||||
else:
|
||||
return data['shards'], data['url'] + '?encoding=json&v=6'
|
||||
value = '{0}?encoding={1}&v={2}'
|
||||
return data['shards'], value.format(data['url'], encoding, v)
|
||||
|
||||
def get_user_info(self, user_id):
|
||||
return self.request(Route('GET', '/users/{user_id}', user_id=user_id))
|
||||
|
@@ -28,7 +28,7 @@ from .state import AutoShardedConnectionState
|
||||
from .client import Client
|
||||
from .gateway import *
|
||||
from .errors import ClientException, InvalidArgument
|
||||
from . import compat
|
||||
from . import compat, utils
|
||||
from .enums import Status
|
||||
|
||||
import asyncio
|
||||
@@ -45,11 +45,32 @@ class Shard:
|
||||
self.loop = self._client.loop
|
||||
self._current = compat.create_future(self.loop)
|
||||
self._current.set_result(None) # we just need an already done future
|
||||
self._pending = asyncio.Event(loop=self.loop)
|
||||
self._pending_task = None
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self.ws.shard_id
|
||||
|
||||
def is_pending(self):
|
||||
return not self._pending.is_set()
|
||||
|
||||
def complete_pending_reads(self):
|
||||
self._pending.set()
|
||||
|
||||
def _pending_reads(self):
|
||||
try:
|
||||
while self.is_pending():
|
||||
yield from self.poll()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def launch_pending_reads(self):
|
||||
self._pending_task = compat.create_task(self._pending_reads(), loop=self.loop)
|
||||
|
||||
def wait(self):
|
||||
return self._pending_task
|
||||
|
||||
@asyncio.coroutine
|
||||
def poll(self):
|
||||
try:
|
||||
@@ -127,7 +148,6 @@ class AutoShardedClient(Client):
|
||||
return self.shards[i].ws
|
||||
|
||||
self._connection._get_websocket = _get_websocket
|
||||
self._still_sharding = True
|
||||
|
||||
@asyncio.coroutine
|
||||
def _chunker(self, guild, *, shard_id=None):
|
||||
@@ -199,14 +219,6 @@ class AutoShardedClient(Client):
|
||||
sub_guilds = list(sub_guilds)
|
||||
yield from self._connection.request_offline_members(sub_guilds, shard_id=shard_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def pending_reads(self, shard):
|
||||
try:
|
||||
while self._still_sharding:
|
||||
yield from shard.poll()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
@asyncio.coroutine
|
||||
def launch_shard(self, gateway, shard_id):
|
||||
try:
|
||||
@@ -235,7 +247,7 @@ class AutoShardedClient(Client):
|
||||
|
||||
# keep reading the shard while others connect
|
||||
self.shards[shard_id] = ret = Shard(ws, self)
|
||||
compat.create_task(self.pending_reads(ret), loop=self.loop)
|
||||
ret.launch_pending_reads()
|
||||
yield from asyncio.sleep(5.0, loop=self.loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
@@ -252,7 +264,13 @@ class AutoShardedClient(Client):
|
||||
for shard_id in shard_ids:
|
||||
yield from self.launch_shard(gateway, shard_id)
|
||||
|
||||
self._still_sharding = False
|
||||
shards_to_wait_for = []
|
||||
for shard in self.shards.values():
|
||||
shard.complete_pending_reads()
|
||||
shards_to_wait_for.append(shard.wait())
|
||||
|
||||
# wait for all pending tasks to finish
|
||||
yield from utils.sane_wait_for(shards_to_wait_for, timeout=300.0, loop=self.loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _connect(self):
|
||||
|
Reference in New Issue
Block a user