Reimplement zlib streaming.
This time with less bugs. It turned out that the crash was due to a synchronisation issue between the pending reads and the actual shard polling mechanism. Essentially the pending reads would be cancelled via a simple bool but there would still be a pass left and thus we would have a single pending read left before or after running the polling mechanism and this would cause a race condition. Now the pending read mechanism is properly waited for before returning control back to the caller.
This commit is contained in:
parent
c3a727ac7e
commit
47a58d354d
@ -186,6 +186,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
# ws related stuff
|
# ws related stuff
|
||||||
self.session_id = None
|
self.session_id = None
|
||||||
self.sequence = None
|
self.sequence = None
|
||||||
|
self._zlib = zlib.decompressobj()
|
||||||
|
self._buffer = bytearray()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
@ -312,8 +314,17 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
self._dispatch('socket_raw_receive', msg)
|
self._dispatch('socket_raw_receive', msg)
|
||||||
|
|
||||||
if isinstance(msg, bytes):
|
if isinstance(msg, bytes):
|
||||||
msg = zlib.decompress(msg, 15, 10490000) # This is 10 MiB
|
self._buffer.extend(msg)
|
||||||
msg = msg.decode('utf-8')
|
|
||||||
|
if len(msg) >= 4:
|
||||||
|
if msg[-4:] == b'\x00\x00\xff\xff':
|
||||||
|
msg = self._zlib.decompress(self._buffer)
|
||||||
|
msg = msg.decode('utf-8')
|
||||||
|
self._buffer = bytearray()
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
msg = json.loads(msg)
|
msg = json.loads(msg)
|
||||||
|
|
||||||
|
@ -739,21 +739,29 @@ class HTTPClient:
|
|||||||
return self.request(Route('GET', '/oauth2/applications/@me'))
|
return self.request(Route('GET', '/oauth2/applications/@me'))
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def get_gateway(self):
|
def get_gateway(self, *, encoding='json', v=6, zlib=True):
|
||||||
try:
|
try:
|
||||||
data = yield from self.request(Route('GET', '/gateway'))
|
data = yield from self.request(Route('GET', '/gateway'))
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
raise GatewayNotFound() from e
|
raise GatewayNotFound() from e
|
||||||
return data.get('url') + '?encoding=json&v=6'
|
if zlib:
|
||||||
|
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
|
||||||
|
else:
|
||||||
|
value = '{0}?encoding={1}&v={2}'
|
||||||
|
return value.format(data['url'], encoding, v)
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def get_bot_gateway(self):
|
def get_bot_gateway(self, *, encoding='json', v=6, zlib=True):
|
||||||
try:
|
try:
|
||||||
data = yield from self.request(Route('GET', '/gateway/bot'))
|
data = yield from self.request(Route('GET', '/gateway/bot'))
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
raise GatewayNotFound() from e
|
raise GatewayNotFound() from e
|
||||||
|
|
||||||
|
if zlib:
|
||||||
|
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
|
||||||
else:
|
else:
|
||||||
return data['shards'], data['url'] + '?encoding=json&v=6'
|
value = '{0}?encoding={1}&v={2}'
|
||||||
|
return data['shards'], value.format(data['url'], encoding, v)
|
||||||
|
|
||||||
def get_user_info(self, user_id):
|
def get_user_info(self, user_id):
|
||||||
return self.request(Route('GET', '/users/{user_id}', user_id=user_id))
|
return self.request(Route('GET', '/users/{user_id}', user_id=user_id))
|
||||||
|
@ -28,7 +28,7 @@ from .state import AutoShardedConnectionState
|
|||||||
from .client import Client
|
from .client import Client
|
||||||
from .gateway import *
|
from .gateway import *
|
||||||
from .errors import ClientException, InvalidArgument
|
from .errors import ClientException, InvalidArgument
|
||||||
from . import compat
|
from . import compat, utils
|
||||||
from .enums import Status
|
from .enums import Status
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -45,11 +45,32 @@ class Shard:
|
|||||||
self.loop = self._client.loop
|
self.loop = self._client.loop
|
||||||
self._current = compat.create_future(self.loop)
|
self._current = compat.create_future(self.loop)
|
||||||
self._current.set_result(None) # we just need an already done future
|
self._current.set_result(None) # we just need an already done future
|
||||||
|
self._pending = asyncio.Event(loop=self.loop)
|
||||||
|
self._pending_task = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self):
|
def id(self):
|
||||||
return self.ws.shard_id
|
return self.ws.shard_id
|
||||||
|
|
||||||
|
def is_pending(self):
|
||||||
|
return not self._pending.is_set()
|
||||||
|
|
||||||
|
def complete_pending_reads(self):
|
||||||
|
self._pending.set()
|
||||||
|
|
||||||
|
def _pending_reads(self):
|
||||||
|
try:
|
||||||
|
while self.is_pending():
|
||||||
|
yield from self.poll()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def launch_pending_reads(self):
|
||||||
|
self._pending_task = compat.create_task(self._pending_reads(), loop=self.loop)
|
||||||
|
|
||||||
|
def wait(self):
|
||||||
|
return self._pending_task
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def poll(self):
|
def poll(self):
|
||||||
try:
|
try:
|
||||||
@ -127,7 +148,6 @@ class AutoShardedClient(Client):
|
|||||||
return self.shards[i].ws
|
return self.shards[i].ws
|
||||||
|
|
||||||
self._connection._get_websocket = _get_websocket
|
self._connection._get_websocket = _get_websocket
|
||||||
self._still_sharding = True
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def _chunker(self, guild, *, shard_id=None):
|
def _chunker(self, guild, *, shard_id=None):
|
||||||
@ -199,14 +219,6 @@ class AutoShardedClient(Client):
|
|||||||
sub_guilds = list(sub_guilds)
|
sub_guilds = list(sub_guilds)
|
||||||
yield from self._connection.request_offline_members(sub_guilds, shard_id=shard_id)
|
yield from self._connection.request_offline_members(sub_guilds, shard_id=shard_id)
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def pending_reads(self, shard):
|
|
||||||
try:
|
|
||||||
while self._still_sharding:
|
|
||||||
yield from shard.poll()
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def launch_shard(self, gateway, shard_id):
|
def launch_shard(self, gateway, shard_id):
|
||||||
try:
|
try:
|
||||||
@ -235,7 +247,7 @@ class AutoShardedClient(Client):
|
|||||||
|
|
||||||
# keep reading the shard while others connect
|
# keep reading the shard while others connect
|
||||||
self.shards[shard_id] = ret = Shard(ws, self)
|
self.shards[shard_id] = ret = Shard(ws, self)
|
||||||
compat.create_task(self.pending_reads(ret), loop=self.loop)
|
ret.launch_pending_reads()
|
||||||
yield from asyncio.sleep(5.0, loop=self.loop)
|
yield from asyncio.sleep(5.0, loop=self.loop)
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
@ -252,7 +264,13 @@ class AutoShardedClient(Client):
|
|||||||
for shard_id in shard_ids:
|
for shard_id in shard_ids:
|
||||||
yield from self.launch_shard(gateway, shard_id)
|
yield from self.launch_shard(gateway, shard_id)
|
||||||
|
|
||||||
self._still_sharding = False
|
shards_to_wait_for = []
|
||||||
|
for shard in self.shards.values():
|
||||||
|
shard.complete_pending_reads()
|
||||||
|
shards_to_wait_for.append(shard.wait())
|
||||||
|
|
||||||
|
# wait for all pending tasks to finish
|
||||||
|
yield from utils.sane_wait_for(shards_to_wait_for, timeout=300.0, loop=self.loop)
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def _connect(self):
|
def _connect(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user