looks like I changed stuff?
This commit is contained in:
161
discord/state.py
161
discord/state.py
@@ -1,5 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
@@ -44,7 +42,6 @@ from .emoji import Emoji
|
||||
from .mentions import AllowedMentions
|
||||
from .partial_emoji import PartialEmoji
|
||||
from .message import Message
|
||||
from .relationship import Relationship
|
||||
from .channel import *
|
||||
from .raw_models import *
|
||||
from .member import Member
|
||||
@@ -54,6 +51,7 @@ from . import utils
|
||||
from .flags import Intents, MemberCacheFlags
|
||||
from .object import Object
|
||||
from .invite import Invite
|
||||
from .interactions import Interaction
|
||||
|
||||
class ChunkRequest:
|
||||
def __init__(self, guild_id, loop, resolver, *, cache=True):
|
||||
@@ -104,7 +102,7 @@ async def logging_coroutine(coroutine, *, info):
|
||||
log.exception('Exception occurred during %s', info)
|
||||
|
||||
class ConnectionState:
|
||||
def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options):
|
||||
def __init__(self, *, dispatch, handlers, hooks, http, loop, **options):
|
||||
self.loop = loop
|
||||
self.http = http
|
||||
self.max_messages = options.get('max_messages', 1000)
|
||||
@@ -112,12 +110,11 @@ class ConnectionState:
|
||||
self.max_messages = 1000
|
||||
|
||||
self.dispatch = dispatch
|
||||
self.syncer = syncer
|
||||
self.is_bot = None
|
||||
self.handlers = handlers
|
||||
self.hooks = hooks
|
||||
self.shard_count = None
|
||||
self._ready_task = None
|
||||
self.application_id = utils._get_as_snowflake(options, 'application_id')
|
||||
self.heartbeat_timeout = options.get('heartbeat_timeout', 60.0)
|
||||
self.guild_ready_timeout = options.get('guild_ready_timeout', 2.0)
|
||||
if self.guild_ready_timeout < 0:
|
||||
@@ -149,7 +146,7 @@ class ConnectionState:
|
||||
intents = options.get('intents', None)
|
||||
if intents is not None:
|
||||
if not isinstance(intents, Intents):
|
||||
raise TypeError('intents parameter must be Intent not %r' % type(intents))
|
||||
raise TypeError(f'intents parameter must be Intent not {type(intents)!r}')
|
||||
else:
|
||||
intents = Intents.default()
|
||||
|
||||
@@ -175,7 +172,7 @@ class ConnectionState:
|
||||
cache_flags = MemberCacheFlags.from_intents(intents)
|
||||
else:
|
||||
if not isinstance(cache_flags, MemberCacheFlags):
|
||||
raise TypeError('member_cache_flags parameter must be MemberCacheFlags not %r' % type(cache_flags))
|
||||
raise TypeError(f'member_cache_flags parameter must be MemberCacheFlags not {type(cache_flags)!r}')
|
||||
|
||||
cache_flags._verify_intents(intents)
|
||||
|
||||
@@ -198,7 +195,6 @@ class ConnectionState:
|
||||
self.user = None
|
||||
self._users = weakref.WeakValueDictionary()
|
||||
self._emojis = {}
|
||||
self._calls = {}
|
||||
self._guilds = {}
|
||||
self._voice_clients = {}
|
||||
|
||||
@@ -340,7 +336,7 @@ class ConnectionState:
|
||||
channel_id = channel.id
|
||||
self._private_channels[channel_id] = channel
|
||||
|
||||
if self.is_bot and len(self._private_channels) > 128:
|
||||
if len(self._private_channels) > 128:
|
||||
_, to_remove = self._private_channels.popitem(last=False)
|
||||
if isinstance(to_remove, DMChannel):
|
||||
self._private_channels_by_user.pop(to_remove.recipient.id, None)
|
||||
@@ -405,36 +401,34 @@ class ConnectionState:
|
||||
|
||||
async def _delay_ready(self):
|
||||
try:
|
||||
# only real bots wait for GUILD_CREATE streaming
|
||||
if self.is_bot:
|
||||
states = []
|
||||
while True:
|
||||
# this snippet of code is basically waiting N seconds
|
||||
# until the last GUILD_CREATE was sent
|
||||
try:
|
||||
guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
states = []
|
||||
while True:
|
||||
# this snippet of code is basically waiting N seconds
|
||||
# until the last GUILD_CREATE was sent
|
||||
try:
|
||||
guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
else:
|
||||
if self._guild_needs_chunking(guild):
|
||||
future = await self.chunk_guild(guild, wait=False)
|
||||
states.append((guild, future))
|
||||
else:
|
||||
if self._guild_needs_chunking(guild):
|
||||
future = await self.chunk_guild(guild, wait=False)
|
||||
states.append((guild, future))
|
||||
if guild.unavailable is False:
|
||||
self.dispatch('guild_available', guild)
|
||||
else:
|
||||
if guild.unavailable is False:
|
||||
self.dispatch('guild_available', guild)
|
||||
else:
|
||||
self.dispatch('guild_join', guild)
|
||||
self.dispatch('guild_join', guild)
|
||||
|
||||
for guild, future in states:
|
||||
try:
|
||||
await asyncio.wait_for(future, timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
log.warning('Shard ID %s timed out waiting for chunks for guild_id %s.', guild.shard_id, guild.id)
|
||||
for guild, future in states:
|
||||
try:
|
||||
await asyncio.wait_for(future, timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
log.warning('Shard ID %s timed out waiting for chunks for guild_id %s.', guild.shard_id, guild.id)
|
||||
|
||||
if guild.unavailable is False:
|
||||
self.dispatch('guild_available', guild)
|
||||
else:
|
||||
self.dispatch('guild_join', guild)
|
||||
if guild.unavailable is False:
|
||||
self.dispatch('guild_available', guild)
|
||||
else:
|
||||
self.dispatch('guild_join', guild)
|
||||
|
||||
# remove the state
|
||||
try:
|
||||
@@ -442,10 +436,6 @@ class ConnectionState:
|
||||
except AttributeError:
|
||||
pass # already been deleted somehow
|
||||
|
||||
# call GUILD_SYNC after we're done chunking
|
||||
if not self.is_bot:
|
||||
log.info('Requesting GUILD_SYNC for %s guilds', len(self.guilds))
|
||||
await self.syncer([s.id for s in self.guilds])
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
else:
|
||||
@@ -464,23 +454,19 @@ class ConnectionState:
|
||||
self.user = user = ClientUser(state=self, data=data['user'])
|
||||
self._users[user.id] = user
|
||||
|
||||
if self.application_id is None:
|
||||
try:
|
||||
application = data['application']
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
self.application_id = utils._get_as_snowflake(application, 'id')
|
||||
|
||||
for guild_data in data['guilds']:
|
||||
self._add_guild_from_data(guild_data)
|
||||
|
||||
for relationship in data.get('relationships', []):
|
||||
try:
|
||||
r_id = int(relationship['id'])
|
||||
except KeyError:
|
||||
continue
|
||||
else:
|
||||
user._relationships[r_id] = Relationship(state=self, data=relationship)
|
||||
|
||||
for pm in data.get('private_channels', []):
|
||||
factory, _ = _channel_factory(pm['type'])
|
||||
self._add_private_channel(factory(me=user, data=pm, state=self))
|
||||
|
||||
self.dispatch('connect')
|
||||
self._ready_task = asyncio.ensure_future(self._delay_ready(), loop=self.loop)
|
||||
self._ready_task = asyncio.create_task(self._delay_ready())
|
||||
|
||||
def parse_resumed(self, data):
|
||||
self.dispatch('resumed')
|
||||
@@ -601,6 +587,10 @@ class ConnectionState:
|
||||
if reaction:
|
||||
self.dispatch('reaction_clear_emoji', reaction)
|
||||
|
||||
def parse_interaction_create(self, data):
|
||||
interaction = Interaction(data=data, state=self)
|
||||
self.dispatch('interaction', interaction)
|
||||
|
||||
def parse_presence_update(self, data):
|
||||
guild_id = utils._get_as_snowflake(data, 'guild_id')
|
||||
guild = self._get_guild(guild_id)
|
||||
@@ -724,22 +714,6 @@ class ConnectionState:
|
||||
else:
|
||||
self.dispatch('guild_channel_pins_update', channel, last_pin)
|
||||
|
||||
def parse_channel_recipient_add(self, data):
|
||||
channel = self._get_private_channel(int(data['channel_id']))
|
||||
user = self.store_user(data['user'])
|
||||
channel.recipients.append(user)
|
||||
self.dispatch('group_join', channel, user)
|
||||
|
||||
def parse_channel_recipient_remove(self, data):
|
||||
channel = self._get_private_channel(int(data['channel_id']))
|
||||
user = self.store_user(data['user'])
|
||||
try:
|
||||
channel.recipients.remove(user)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
self.dispatch('group_remove', channel, user)
|
||||
|
||||
def parse_guild_member_add(self, data):
|
||||
guild = self._get_guild(int(data['guild_id']))
|
||||
if guild is None:
|
||||
@@ -871,7 +845,7 @@ class ConnectionState:
|
||||
|
||||
# check if it requires chunking
|
||||
if self._guild_needs_chunking(guild):
|
||||
asyncio.ensure_future(self._chunk_and_dispatch(guild, unavailable), loop=self.loop)
|
||||
asyncio.create_task(self._chunk_and_dispatch(guild, unavailable))
|
||||
return
|
||||
|
||||
# Dispatch available if newly available
|
||||
@@ -880,10 +854,6 @@ class ConnectionState:
|
||||
else:
|
||||
self.dispatch('guild_join', guild)
|
||||
|
||||
def parse_guild_sync(self, data):
|
||||
guild = self._get_guild(int(data['id']))
|
||||
guild._sync(data)
|
||||
|
||||
def parse_guild_update(self, data):
|
||||
guild = self._get_guild(int(data['id']))
|
||||
if guild is not None:
|
||||
@@ -1015,7 +985,7 @@ class ConnectionState:
|
||||
voice = self._get_voice_client(guild.id)
|
||||
if voice is not None:
|
||||
coro = voice.on_voice_state_update(data)
|
||||
asyncio.ensure_future(logging_coroutine(coro, info='Voice Protocol voice state update handler'))
|
||||
asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice state update handler'))
|
||||
|
||||
member, before, after = guild._update_voice_state(data, channel_id)
|
||||
if member is not None:
|
||||
@@ -1029,11 +999,6 @@ class ConnectionState:
|
||||
self.dispatch('voice_state_update', member, before, after)
|
||||
else:
|
||||
log.debug('VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.', data['user_id'])
|
||||
else:
|
||||
# in here we're either at private or group calls
|
||||
call = self._calls.get(channel_id)
|
||||
if call is not None:
|
||||
call._update_voice_state(data)
|
||||
|
||||
def parse_voice_server_update(self, data):
|
||||
try:
|
||||
@@ -1044,7 +1009,7 @@ class ConnectionState:
|
||||
vc = self._get_voice_client(key_id)
|
||||
if vc is not None:
|
||||
coro = vc.on_voice_server_update(data)
|
||||
asyncio.ensure_future(logging_coroutine(coro, info='Voice Protocol voice server update handler'))
|
||||
asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler'))
|
||||
|
||||
def parse_typing_start(self, data):
|
||||
channel, guild = self._get_guild_channel(data)
|
||||
@@ -1065,27 +1030,9 @@ class ConnectionState:
|
||||
|
||||
if member is not None:
|
||||
timestamp = datetime.datetime.utcfromtimestamp(data.get('timestamp'))
|
||||
timestamp = timestamp.replace(tzinfo=datetime.timezone.utc)
|
||||
self.dispatch('typing', channel, member, timestamp)
|
||||
|
||||
def parse_relationship_add(self, data):
|
||||
key = int(data['id'])
|
||||
old = self.user.get_relationship(key)
|
||||
new = Relationship(state=self, data=data)
|
||||
self.user._relationships[key] = new
|
||||
if old is not None:
|
||||
self.dispatch('relationship_update', old, new)
|
||||
else:
|
||||
self.dispatch('relationship_add', new)
|
||||
|
||||
def parse_relationship_remove(self, data):
|
||||
key = int(data['id'])
|
||||
try:
|
||||
old = self.user._relationships.pop(key)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
self.dispatch('relationship_remove', old)
|
||||
|
||||
def _get_reaction_user(self, channel, user_id):
|
||||
if isinstance(channel, TextChannel):
|
||||
return channel.guild.get_member(user_id)
|
||||
@@ -1223,16 +1170,20 @@ class AutoShardedConnectionState(ConnectionState):
|
||||
self.user = user = ClientUser(state=self, data=data['user'])
|
||||
self._users[user.id] = user
|
||||
|
||||
if self.application_id is None:
|
||||
try:
|
||||
application = data['application']
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
self.application_id = utils._get_as_snowflake(application, 'id')
|
||||
|
||||
for guild_data in data['guilds']:
|
||||
self._add_guild_from_data(guild_data)
|
||||
|
||||
if self._messages:
|
||||
self._update_message_references()
|
||||
|
||||
for pm in data.get('private_channels', []):
|
||||
factory, _ = _channel_factory(pm['type'])
|
||||
self._add_private_channel(factory(me=user, data=pm, state=self))
|
||||
|
||||
self.dispatch('connect')
|
||||
self.dispatch('shard_connect', data['__shard_id__'])
|
||||
|
||||
@@ -1245,7 +1196,7 @@ class AutoShardedConnectionState(ConnectionState):
|
||||
gc.collect()
|
||||
|
||||
if self._ready_task is None:
|
||||
self._ready_task = asyncio.ensure_future(self._delay_ready(), loop=self.loop)
|
||||
self._ready_task = asyncio.create_task(self._delay_ready())
|
||||
|
||||
def parse_resumed(self, data):
|
||||
self.dispatch('resumed')
|
||||
|
||||
Reference in New Issue
Block a user