mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-05-16 18:59:09 +00:00
Move chunking logic back into ConnectionState.
This allows for a nicer design when dealing with parsers that could end up being coroutines.
This commit is contained in:
parent
f437ffe44e
commit
425bd2c091
@ -51,7 +51,7 @@ import logging, traceback
|
|||||||
import sys, time, re, json
|
import sys, time, re, json
|
||||||
import tempfile, os, hashlib
|
import tempfile, os, hashlib
|
||||||
import itertools
|
import itertools
|
||||||
import zlib, math
|
import zlib
|
||||||
from random import randint as random_integer
|
from random import randint as random_integer
|
||||||
|
|
||||||
PY35 = sys.version_info >= (3, 5)
|
PY35 = sys.version_info >= (3, 5)
|
||||||
@ -122,7 +122,7 @@ class Client:
|
|||||||
if max_messages is None or max_messages < 100:
|
if max_messages is None or max_messages < 100:
|
||||||
max_messages = 5000
|
max_messages = 5000
|
||||||
|
|
||||||
self.connection = ConnectionState(self.dispatch, max_messages, loop=self.loop)
|
self.connection = ConnectionState(self.dispatch, self.request_offline_members, max_messages, loop=self.loop)
|
||||||
|
|
||||||
# Blame Jake for this
|
# Blame Jake for this
|
||||||
user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}'
|
user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}'
|
||||||
@ -145,28 +145,6 @@ class Client:
|
|||||||
|
|
||||||
# internals
|
# internals
|
||||||
|
|
||||||
def _get_all_chunks(self):
|
|
||||||
# a chunk has a maximum of 1000 members.
|
|
||||||
# we need to find out how many futures we're actually waiting for
|
|
||||||
large_servers = filter(lambda s: s.large, self.servers)
|
|
||||||
futures = []
|
|
||||||
for server in large_servers:
|
|
||||||
chunks_needed = math.ceil(server._member_count / 1000)
|
|
||||||
for chunk in range(chunks_needed):
|
|
||||||
futures.append(self.connection.receive_chunk(server.id))
|
|
||||||
|
|
||||||
return futures
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def _fill_offline(self):
|
|
||||||
yield from self.request_offline_members(filter(lambda s: s.large, self.servers))
|
|
||||||
chunks = self._get_all_chunks()
|
|
||||||
|
|
||||||
if chunks:
|
|
||||||
yield from asyncio.wait(chunks)
|
|
||||||
|
|
||||||
self.dispatch('ready')
|
|
||||||
|
|
||||||
def _get_cache_filename(self, email):
|
def _get_cache_filename(self, email):
|
||||||
filename = hashlib.md5(email.encode('utf-8')).hexdigest()
|
filename = hashlib.md5(email.encode('utf-8')).hexdigest()
|
||||||
return os.path.join(tempfile.gettempdir(), 'discord_py', filename)
|
return os.path.join(tempfile.gettempdir(), 'discord_py', filename)
|
||||||
@ -392,11 +370,10 @@ class Client:
|
|||||||
func = getattr(self.connection, parser)
|
func = getattr(self.connection, parser)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
log.info('Unhandled event {}'.format(event))
|
log.info('Unhandled event {}'.format(event))
|
||||||
else:
|
|
||||||
func(data)
|
|
||||||
|
|
||||||
if is_ready:
|
result = func(data)
|
||||||
utils.create_task(self._fill_offline(), loop=self.loop)
|
if asyncio.iscoroutine(result):
|
||||||
|
utils.create_task(result, loop=self.loop)
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def _make_websocket(self, initial=True):
|
def _make_websocket(self, initial=True):
|
||||||
|
@ -36,10 +36,9 @@ from .enums import Status
|
|||||||
|
|
||||||
|
|
||||||
from collections import deque, namedtuple
|
from collections import deque, namedtuple
|
||||||
import copy
|
import copy, enum, math
|
||||||
import datetime
|
import datetime
|
||||||
import asyncio
|
import asyncio
|
||||||
import enum
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
class ListenerType(enum.Enum):
|
class ListenerType(enum.Enum):
|
||||||
@ -49,10 +48,11 @@ Listener = namedtuple('Listener', ('type', 'future', 'predicate'))
|
|||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
class ConnectionState:
|
class ConnectionState:
|
||||||
def __init__(self, dispatch, max_messages, *, loop):
|
def __init__(self, dispatch, chunker, max_messages, *, loop):
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.max_messages = max_messages
|
self.max_messages = max_messages
|
||||||
self.dispatch = dispatch
|
self.dispatch = dispatch
|
||||||
|
self.chunker = chunker
|
||||||
self._listeners = []
|
self._listeners = []
|
||||||
self.clear()
|
self.clear()
|
||||||
|
|
||||||
@ -128,6 +128,7 @@ class ConnectionState:
|
|||||||
self._add_server(server)
|
self._add_server(server)
|
||||||
return server
|
return server
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def parse_ready(self, data):
|
def parse_ready(self, data):
|
||||||
self.user = User(**data['user'])
|
self.user = User(**data['user'])
|
||||||
guilds = data.get('guilds')
|
guilds = data.get('guilds')
|
||||||
@ -139,6 +140,23 @@ class ConnectionState:
|
|||||||
self._add_private_channel(PrivateChannel(id=pm['id'],
|
self._add_private_channel(PrivateChannel(id=pm['id'],
|
||||||
user=User(**pm['recipient'])))
|
user=User(**pm['recipient'])))
|
||||||
|
|
||||||
|
# a chunk has a maximum of 1000 members.
|
||||||
|
# we need to find out how many futures we're actually waiting for
|
||||||
|
|
||||||
|
large_servers = [s for s in self.servers if s.large]
|
||||||
|
yield from self.chunker(large_servers)
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
for server in large_servers:
|
||||||
|
chunks_needed = math.ceil(server._member_count / 1000)
|
||||||
|
for chunk in range(chunks_needed):
|
||||||
|
chunks.append(self.receive_chunk(server.id))
|
||||||
|
|
||||||
|
if chunks:
|
||||||
|
yield from asyncio.wait(chunks)
|
||||||
|
|
||||||
|
self.dispatch('ready')
|
||||||
|
|
||||||
def parse_message_create(self, data):
|
def parse_message_create(self, data):
|
||||||
channel = self.get_channel(data.get('channel_id'))
|
channel = self.get_channel(data.get('channel_id'))
|
||||||
message = Message(channel=channel, **data)
|
message = Message(channel=channel, **data)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user