conflict fix

This commit is contained in:
iDutchy
2020-10-28 21:00:48 -05:00
33 changed files with 12183 additions and 3671 deletions

View File

@ -15,7 +15,7 @@ __title__ = 'discord'
__author__ = 'Rapptz'
__license__ = 'MIT'
__copyright__ = 'Copyright 2015-2020 Rapptz'
__version__ = '1.5.0.3'
__version__ = '1.5.1.4'
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
@ -61,7 +61,7 @@ from .team import *
VersionInfo = namedtuple('VersionInfo', 'major minor micro releaselevel serial')
version_info = VersionInfo(major=1, minor=5, micro=0, releaselevel='final', serial=0)
version_info = VersionInfo(major=1, minor=5, micro=1, releaselevel='final', serial=0)
try:
from logging import NullHandler

View File

@ -230,7 +230,7 @@ class GuildChannel:
# not there somehow lol
return
else:
index = next((i for i, c in enumerate(channels) if c.position >= position), -1)
index = next((i for i, c in enumerate(channels) if c.position >= position), len(channels))
# add ourselves at our designated position
channels.insert(index, self)

View File

@ -243,8 +243,8 @@ class Colour:
@classmethod
def dark_theme(cls):
"""A factory method that returns a :class:'Colour' with a value of ``0x36393F``.
Will appear transparent on Discord's dark theme and be the text colour on Discord's light theme.
"""A factory method that returns a :class:`Colour` with a value of ``0x36393F``.
This will appear transparent on Discord's dark theme.
.. versionadded:: 1.5
"""

View File

@ -122,13 +122,49 @@ class MemberConverter(IDConverter):
.. versionchanged:: 1.5
Raise :exc:`.MemberNotFound` instead of generic :exc:`.BadArgument`
.. versionchanged:: 1.5.1
This converter now lazily fetches members from the gateway and HTTP APIs,
optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled.
"""
async def query_member_named(self, guild, argument):
cache = guild._state._member_cache_flags.joined
if len(argument) > 5 and argument[-5] == '#':
username, _, discriminator = argument.rpartition('#')
members = await guild.query_members(username, limit=100, cache=cache)
return discord.utils.get(members, name=username, discriminator=discriminator)
else:
members = await guild.query_members(argument, limit=100, cache=cache)
return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members)
async def query_member_by_id(self, bot, guild, user_id):
ws = bot._get_websocket(shard_id=guild.shard_id)
cache = guild._state._member_cache_flags.joined
if ws.is_ratelimited():
# If we're being rate limited on the WS, then fall back to using the HTTP API
# So we don't have to wait ~60 seconds for the query to finish
try:
member = await guild.fetch_member(user_id)
except discord.HTTPException:
return None
if cache:
guild._add_member(member)
return member
# If we're not being rate limited then we can use the websocket to actually query
members = await guild.query_members(limit=1, user_ids=[user_id], cache=cache)
if not members:
return None
return members[0]
async def convert(self, ctx, argument):
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument)
guild = ctx.guild
result = None
user_id = None
if match is None:
# not a mention...
if guild:
@ -143,7 +179,16 @@ class MemberConverter(IDConverter):
result = _get_from_guilds(bot, 'get_member', user_id)
if result is None:
raise MemberNotFound(argument)
if guild is None:
raise MemberNotFound(argument)
if user_id is not None:
result = await self.query_member_by_id(bot, guild, user_id)
else:
result = await self.query_member_named(guild, argument)
if not result:
raise MemberNotFound(argument)
return result

View File

@ -611,7 +611,7 @@ class Intents(BaseFlags):
"""
return 1 << 8
@flag_value
@alias_flag_value
def messages(self):
""":class:`bool`: Whether guild and direct message related events are enabled.
@ -694,7 +694,7 @@ class Intents(BaseFlags):
"""
return 1 << 12
@flag_value
@alias_flag_value
def reactions(self):
""":class:`bool`: Whether guild and direct message reaction related events are enabled.
@ -757,7 +757,7 @@ class Intents(BaseFlags):
"""
return 1 << 13
@flag_value
@alias_flag_value
def typing(self):
""":class:`bool`: Whether guild and direct message typing related events are enabled.

View File

@ -76,6 +76,12 @@ class GatewayRatelimiter:
self.lock = asyncio.Lock()
self.shard_id = None
def is_ratelimited(self):
current = time.time()
if current > self.window + self.per:
return False
return self.remaining == 0
def get_delay(self):
current = time.time()
@ -287,6 +293,9 @@ class DiscordWebSocket:
def open(self):
return not self.socket.closed
def is_ratelimited(self):
return self._rate_limiter.is_ratelimited()
@classmethod
async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
"""Creates a main websocket for Discord from a :class:`Client`.
@ -719,6 +728,7 @@ class DiscordVoiceWebSocket:
self.loop = loop
self._keep_alive = None
self._close_code = None
self.secret_key = None
async def send_as_json(self, data):
log.debug('Sending voice websocket frame: %s.', data)
@ -872,7 +882,7 @@ class DiscordVoiceWebSocket:
async def load_secret_key(self, data):
log.info('received secret key for voice connection')
self._connection.secret_key = data.get('secret_key')
self.secret_key = self._connection.secret_key = data.get('secret_key')
await self.speak()
await self.speak(False)

View File

@ -42,6 +42,7 @@ from .flags import MessageFlags
from .file import File
from .utils import escape_mentions
from .guild import Guild
from .mixins import Hashable
class Attachment:
@ -255,7 +256,7 @@ def flatten_handlers(cls):
return cls
@flatten_handlers
class Message:
class Message(Hashable):
r"""Represents a message from Discord.
There should be no need to create one of these manually.
@ -397,9 +398,6 @@ class Message:
def __repr__(self):
return '<Message id={0.id} channel={0.channel!r} type={0.type!r} author={0.author!r} flags={0.flags!r}>'.format(self)
def __eq__(self, other):
return isinstance(other, self.__class__) and self.id == other.id
def _try_patch(self, data, key, transform=None):
try:
value = data[key]

View File

@ -484,7 +484,10 @@ class PermissionOverwrite:
if value not in (True, None, False):
raise TypeError('Expected bool or NoneType, received {0.__class__.__name__}'.format(value))
self._values[key] = value
if value is None:
self._values.pop(key, None)
else:
self._values[key] = value
def pair(self):
"""Tuple[:class:`Permissions`, :class:`Permissions`]: Returns the (allow, deny) pair from this overwrite."""
@ -519,13 +522,13 @@ class PermissionOverwrite:
An empty permission overwrite is one that has no overwrites set
to ``True`` or ``False``.
Returns
-------
:class:`bool`
Indicates if the overwrite is empty.
"""
return all(x is None for x in self._values.values())
return len(self._values) == 0
def update(self, **kwargs):
r"""Bulk updates this permission overwrite object.

View File

@ -58,13 +58,14 @@ from .object import Object
from .invite import Invite
class ChunkRequest:
def __init__(self, guild_id, future, resolver, *, cache=True):
def __init__(self, guild_id, loop, resolver, *, cache=True):
self.guild_id = guild_id
self.resolver = resolver
self.loop = loop
self.cache = cache
self.nonce = os.urandom(16).hex()
self.future = future
self.buffer = [] # List[Member]
self.waiters = []
def add_members(self, members):
self.buffer.extend(members)
@ -78,8 +79,23 @@ class ChunkRequest:
if existing is None or existing.joined_at is None:
guild._add_member(member)
async def wait(self):
future = self.loop.create_future()
self.waiters.append(future)
try:
return await future
finally:
self.waiters.remove(future)
def get_future(self):
future = self.loop.create_future()
self.waiters.append(future)
return future
def done(self):
self.future.set_result(self.buffer)
for future in self.waiters:
if not future.done():
future.set_result(self.buffer)
log = logging.getLogger(__name__)
@ -116,7 +132,7 @@ class ConnectionState:
raise TypeError('allowed_mentions parameter must be AllowedMentions')
self.allowed_mentions = allowed_mentions
self._chunk_requests = []
self._chunk_requests = {} # Dict[Union[int, str], ChunkRequest]
activity = options.get('activity', None)
if activity:
@ -198,20 +214,15 @@ class ConnectionState:
def process_chunk_requests(self, guild_id, nonce, members, complete):
removed = []
for i, request in enumerate(self._chunk_requests):
future = request.future
if future.cancelled():
removed.append(i)
continue
for key, request in self._chunk_requests.items():
if request.guild_id == guild_id and request.nonce == nonce:
request.add_members(members)
if complete:
request.done()
removed.append(i)
removed.append(key)
for index in reversed(removed):
del self._chunk_requests[index]
for key in removed:
del self._chunk_requests[key]
def call_handlers(self, key, *args, **kwargs):
try:
@ -377,14 +388,13 @@ class ConnectionState:
if ws is None:
raise RuntimeError('Somehow do not have a websocket for this guild_id')
future = self.loop.create_future()
request = ChunkRequest(guild.id, future, self._get_guild, cache=cache)
self._chunk_requests.append(request)
request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache)
self._chunk_requests[request.nonce] = request
try:
# start the query operation
await ws.request_chunks(guild_id, query=query, limit=limit, user_ids=user_ids, nonce=request.nonce)
return await asyncio.wait_for(future, timeout=30.0)
return await asyncio.wait_for(request.wait(), timeout=30.0)
except asyncio.TimeoutError:
log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id)
raise
@ -610,7 +620,7 @@ class ConnectionState:
if user_update:
self.dispatch('user_update', user_update[0], user_update[1])
if flags._online_only and member.raw_status == 'offline':
if member.id != self.self_id and flags._online_only and member.raw_status == 'offline':
guild._remove_member(member)
self.dispatch('member_update', old_member, member)
@ -776,6 +786,9 @@ class ConnectionState:
self.dispatch('member_update', old_member, member)
else:
if self._member_cache_flags.joined:
member = Member(data=data, guild=guild, state=self)
guild._add_member(member)
log.debug('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.', user_id)
def parse_guild_emojis_update(self, data):
@ -805,13 +818,14 @@ class ConnectionState:
async def chunk_guild(self, guild, *, wait=True, cache=None):
cache = cache or self._member_cache_flags.joined
future = self.loop.create_future()
request = ChunkRequest(guild.id, future, self._get_guild, cache=cache)
self._chunk_requests.append(request)
await self.chunker(guild.id, nonce=request.nonce)
request = self._chunk_requests.get(guild.id)
if request is None:
self._chunk_requests[guild.id] = request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache)
await self.chunker(guild.id, nonce=request.nonce)
if wait:
return await request.future
return request.future
return await request.wait()
return request.get_future()
async def _chunk_and_dispatch(self, guild, unavailable):
try:
@ -971,8 +985,9 @@ class ConnectionState:
guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id'))
channel_id = utils._get_as_snowflake(data, 'channel_id')
flags = self._member_cache_flags
self_id = self.user.id
if guild is not None:
if int(data['user_id']) == self.user.id:
if int(data['user_id']) == self_id:
voice = self._get_voice_client(guild.id)
if voice is not None:
coro = voice.on_voice_state_update(data)
@ -981,10 +996,10 @@ class ConnectionState:
member, before, after = guild._update_voice_state(data, channel_id)
if member is not None:
if flags.voice:
if channel_id is None and flags.value == MemberCacheFlags.voice.flag:
if channel_id is None and flags._voice_only and member.id != self_id:
# Only remove from cache iff we only have the voice flag enabled
guild._remove_member(member)
else:
elif channel_id is not None:
guild._add_member(member)
self.dispatch('voice_state_update', member, before, after)
@ -1125,7 +1140,7 @@ class AutoShardedConnectionState(ConnectionState):
await utils.sane_wait_for(current_bucket, timeout=max_concurrency * 70.0)
except asyncio.TimeoutError:
fmt = 'Shard ID %s failed to wait for chunks from a sub-bucket with length %d'
log.warning(fmt, self.shard_id, len(current_bucket))
log.warning(fmt, guild.shard_id, len(current_bucket))
finally:
current_bucket = []
@ -1146,7 +1161,7 @@ class AutoShardedConnectionState(ConnectionState):
try:
await utils.sane_wait_for(futures, timeout=timeout)
except asyncio.TimeoutError:
log.warning('Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', self.shard_id,
log.warning('Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', shard_id,
timeout,
len(guilds))
for guild in children:

View File

@ -482,7 +482,7 @@ _MARKDOWN_ESCAPE_SUBREGEX = '|'.join(r'\{0}(?=([\s\S]*((?<!\{0})\{0})))'.format(
_MARKDOWN_ESCAPE_COMMON = r'^>(?:>>)?\s|\[.+\]\(.+\)'
_MARKDOWN_ESCAPE_REGEX = re.compile(r'(?P<markdown>%s|%s)' % (_MARKDOWN_ESCAPE_SUBREGEX, _MARKDOWN_ESCAPE_COMMON))
_MARKDOWN_ESCAPE_REGEX = re.compile(r'(?P<markdown>%s|%s)' % (_MARKDOWN_ESCAPE_SUBREGEX, _MARKDOWN_ESCAPE_COMMON), re.MULTILINE)
def escape_markdown(text, *, as_needed=False, ignore_links=True):
r"""A helper function that escapes Discord's markdown.
@ -521,7 +521,7 @@ def escape_markdown(text, *, as_needed=False, ignore_links=True):
regex = r'(?P<markdown>[_\\~|\*`]|%s)' % _MARKDOWN_ESCAPE_COMMON
if ignore_links:
regex = '(?:%s|%s)' % (url_regex, regex)
return re.sub(regex, replacement, text)
return re.sub(regex, replacement, text, 0, re.MULTILINE)
else:
text = re.sub(r'\\', r'\\\\', text)
return _MARKDOWN_ESCAPE_REGEX.sub(r'\\\1', text)

View File

@ -208,6 +208,7 @@ class VoiceClient(VoiceProtocol):
self._connected = threading.Event()
self._handshaking = False
self._potentially_reconnecting = False
self._voice_state_complete = asyncio.Event()
self._voice_server_complete = asyncio.Event()
@ -250,8 +251,10 @@ class VoiceClient(VoiceProtocol):
self.session_id = data['session_id']
channel_id = data['channel_id']
if not self._handshaking:
if not self._handshaking or self._potentially_reconnecting:
# If we're done handshaking then we just need to update ourselves
# If we're potentially reconnecting due to a 4014, then we need to differentiate
# a channel move and an actual force disconnect
if channel_id is None:
# We're being disconnected so cleanup
await self.disconnect()
@ -294,26 +297,39 @@ class VoiceClient(VoiceProtocol):
self._voice_server_complete.set()
async def voice_connect(self):
self._connections += 1
await self.channel.guild.change_voice_state(channel=self.channel)
async def voice_disconnect(self):
log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
await self.channel.guild.change_voice_state(channel=None)
def prepare_handshake(self):
self._voice_state_complete.clear()
self._voice_server_complete.clear()
self._handshaking = True
log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
self._connections += 1
def finish_handshake(self):
log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
self._handshaking = False
self._voice_server_complete.clear()
self._voice_state_complete.clear()
async def connect_websocket(self):
ws = await DiscordVoiceWebSocket.from_client(self)
self._connected.clear()
while ws.secret_key is None:
await ws.poll_event()
self._connected.set()
return ws
async def connect(self, *, reconnect, timeout):
log.info('Connecting to voice...')
self.timeout = timeout
try:
del self.secret_key
except AttributeError:
pass
for i in range(5):
self._voice_state_complete.clear()
self._voice_server_complete.clear()
self._handshaking = True
self.prepare_handshake()
# This has to be created before we start the flow.
futures = [
@ -322,7 +338,6 @@ class VoiceClient(VoiceProtocol):
]
# Start the connection flow
log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
await self.voice_connect()
try:
@ -331,17 +346,10 @@ class VoiceClient(VoiceProtocol):
await self.disconnect(force=True)
raise
log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
self._handshaking = False
self._voice_server_complete.clear()
self._voice_state_complete.clear()
self.finish_handshake()
try:
self.ws = await DiscordVoiceWebSocket.from_client(self)
self._connected.clear()
while not hasattr(self, 'secret_key'):
await self.ws.poll_event()
self._connected.set()
self.ws = await self.connect_websocket()
break
except (ConnectionClosed, asyncio.TimeoutError):
if reconnect:
@ -355,6 +363,26 @@ class VoiceClient(VoiceProtocol):
if self._runner is None:
self._runner = self.loop.create_task(self.poll_voice_ws(reconnect))
async def potential_reconnect(self):
self.prepare_handshake()
self._potentially_reconnecting = True
try:
# We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected
await asyncio.wait_for(self._voice_server_complete.wait(), timeout=self.timeout)
except asyncio.TimeoutError:
self._potentially_reconnecting = False
await self.disconnect(force=True)
return False
self.finish_handshake()
self._potentially_reconnecting = False
try:
self.ws = await self.connect_websocket()
except (ConnectionClosed, asyncio.TimeoutError):
return False
else:
return True
@property
def latency(self):
""":class:`float`: Latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
@ -387,10 +415,19 @@ class VoiceClient(VoiceProtocol):
# 1000 - normal closure (obviously)
# 4014 - voice channel has been deleted.
# 4015 - voice server has crashed
if exc.code in (1000, 4014, 4015):
if exc.code in (1000, 4015):
log.info('Disconnecting from voice normally, close code %d.', exc.code)
await self.disconnect()
break
if exc.code == 4014:
log.info('Disconnected from voice by force... potentially reconnecting.')
successful = await self.potential_reconnect()
if not successful:
log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
await self.disconnect()
break
else:
continue
if not reconnect:
await self.disconnect()