Replace wait_for_* with a generic Client.wait_for

This commit is contained in:
Rapptz 2017-01-25 22:26:49 -05:00
parent b876133e87
commit e5cb7d295c

View File

@ -41,7 +41,7 @@ import aiohttp
import websockets import websockets
import logging, traceback import logging, traceback
import sys, re, io, enum import sys, re, io
import itertools import itertools
import datetime import datetime
from collections import namedtuple from collections import namedtuple
@ -51,7 +51,6 @@ PY35 = sys.version_info >= (3, 5)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
AppInfo = namedtuple('AppInfo', 'id name description icon owner') AppInfo = namedtuple('AppInfo', 'id name description icon owner')
WaitedReaction = namedtuple('WaitedReaction', 'reaction user')
def app_info_icon_url(self): def app_info_icon_url(self):
"""Retrieves the application's icon_url if it exists. Empty string otherwise.""" """Retrieves the application's icon_url if it exists. Empty string otherwise."""
@ -62,10 +61,6 @@ def app_info_icon_url(self):
AppInfo.icon_url = property(app_info_icon_url) AppInfo.icon_url = property(app_info_icon_url)
class WaitForType(enum.Enum):
message = 0
reaction = 1
class Client: class Client:
"""Represents a client connection that connects to Discord. """Represents a client connection that connects to Discord.
This class is used to interact with the Discord WebSocket and API. This class is used to interact with the Discord WebSocket and API.
@ -113,7 +108,7 @@ class Client:
self.ws = None self.ws = None
self.email = None self.email = None
self.loop = asyncio.get_event_loop() if loop is None else loop self.loop = asyncio.get_event_loop() if loop is None else loop
self._listeners = [] self._listeners = {}
self.shard_id = options.get('shard_id') self.shard_id = options.get('shard_id')
self.shard_count = options.get('shard_count') self.shard_count = options.get('shard_count')
@ -125,8 +120,6 @@ class Client:
self.connection.shard_count = self.shard_count self.connection.shard_count = self.shard_count
self._closed = asyncio.Event(loop=self.loop) self._closed = asyncio.Event(loop=self.loop)
self._is_logged_in = asyncio.Event(loop=self.loop)
self._is_ready = asyncio.Event(loop=self.loop)
# if VoiceClient.warn_nacl: # if VoiceClient.warn_nacl:
# VoiceClient.warn_nacl = False # VoiceClient.warn_nacl = False
@ -156,57 +149,6 @@ class Client:
yield from self.ws.send_as_json(payload) yield from self.ws.send_as_json(payload)
def handle_reaction_add(self, reaction, user):
removed = []
for i, (condition, future, event_type) in enumerate(self._listeners):
if event_type is not WaitForType.reaction:
continue
if future.cancelled():
removed.append(i)
continue
try:
result = condition(reaction, user)
except Exception as e:
future.set_exception(e)
removed.append(i)
else:
if result:
future.set_result(WaitedReaction(reaction, user))
removed.append(i)
for idx in reversed(removed):
del self._listeners[idx]
def handle_message(self, message):
removed = []
for i, (condition, future, event_type) in enumerate(self._listeners):
if event_type is not WaitForType.message:
continue
if future.cancelled():
removed.append(i)
continue
try:
result = condition(message)
except Exception as e:
future.set_exception(e)
removed.append(i)
else:
if result:
future.set_result(message)
removed.append(i)
for idx in reversed(removed):
del self._listeners[idx]
def handle_ready(self):
self._is_ready.set()
def _resolve_invite(self, invite): def _resolve_invite(self, invite):
if isinstance(invite, Invite) or isinstance(invite, Object): if isinstance(invite, Invite) or isinstance(invite, Object):
return invite.id return invite.id
@ -264,6 +206,35 @@ class Client:
method = 'on_' + event method = 'on_' + event
handler = 'handle_' + event handler = 'handle_' + event
listeners = self._listeners.get(event)
if listeners:
removed = []
for i, (future, condition) in enumerate(listeners):
if future.cancelled():
removed.append(i)
continue
try:
result = condition(*args)
except Exception as e:
future.set_exception(e)
removed.append(i)
else:
if result:
if len(args) == 0:
future.set_result(None)
elif len(args) == 1:
future.set_result(args[0])
else:
future.set_result(args)
removed.append(i)
if len(removed) == len(listeners):
self._listeners.pop(event)
else:
for idx in reversed(removed):
del listeners[idx]
try: try:
actual_handler = getattr(self, handler) actual_handler = getattr(self, handler)
except AttributeError: except AttributeError:
@ -353,7 +324,6 @@ class Client:
data = yield from self.http.static_login(token, bot=bot) data = yield from self.http.static_login(token, bot=bot)
self.email = data.get('email', None) self.email = data.get('email', None)
self.connection.is_bot = bot self.connection.is_bot = bot
self._is_logged_in.set()
@asyncio.coroutine @asyncio.coroutine
def logout(self): def logout(self):
@ -362,7 +332,6 @@ class Client:
Logs out of Discord and closes all connections. Logs out of Discord and closes all connections.
""" """
yield from self.close() yield from self.close()
self._is_logged_in.clear()
@asyncio.coroutine @asyncio.coroutine
def connect(self): def connect(self):
@ -420,7 +389,6 @@ class Client:
yield from self.http.close() yield from self.http.close()
self._closed.set() self._closed.set()
self._is_ready.clear()
@asyncio.coroutine @asyncio.coroutine
def start(self, *args, **kwargs): def start(self, *args, **kwargs):
@ -476,11 +444,6 @@ class Client:
# properties # properties
@property
def is_logged_in(self):
"""bool: Indicates if the client has logged in successfully."""
return self._is_logged_in.is_set()
@property @property
def is_closed(self): def is_closed(self):
"""bool: Indicates if the websocket connection is closed.""" """bool: Indicates if the websocket connection is closed."""
@ -550,250 +513,83 @@ class Client:
# listeners/waiters # listeners/waiters
@asyncio.coroutine def wait_for(self, event, *, check=None, timeout=None):
def wait_until_ready(self):
"""|coro| """|coro|
This coroutine waits until the client is all ready. This could be considered Waits for a WebSocket event to be dispatched.
another way of asking for :func:`discord.on_ready` except meant for your own
background tasks.
"""
yield from self._is_ready.wait()
@asyncio.coroutine This could be used to wait for a user to reply to a message,
def wait_until_login(self): or to react to a message, or to edit a message in a self-contained
"""|coro| way.
This coroutine waits until the client is logged on successfully. This The ``timeout`` parameter is passed onto `asyncio.wait_for`_. By default,
is different from waiting until the client's state is all ready. For it does not timeout. Note that this does propagate the
that check :func:`discord.on_ready` and :meth:`wait_until_ready`. ``asyncio.TimeoutError`` for you in case of timeout and is provided for
""" ease of use.
yield from self._is_logged_in.wait()
@asyncio.coroutine In case the event returns multiple arguments, a tuple containing those
def wait_for_message(self, timeout=None, *, author=None, channel=None, content=None, check=None): arguments is returned instead. Please check the
"""|coro| :ref:`documentation <discord-api-events>` for a list of events and their
parameters.
Waits for a message reply from Discord. This could be seen as another This function returns the **first event that meets the requirements**.
:func:`discord.on_message` event outside of the actual event. This could
also be used for follow-ups and easier user interactions.
The keyword arguments passed into this function are combined using the logical and
operator. The ``check`` keyword argument can be used to pass in more complicated
checks and must be a regular function (not a coroutine).
The ``timeout`` parameter is passed into `asyncio.wait_for`_. By default, it
does not timeout. Instead of throwing ``asyncio.TimeoutError`` the coroutine
catches the exception and returns ``None`` instead of a :class:`Message`.
If the ``check`` predicate throws an exception, then the exception is propagated.
This function returns the **first message that meets the requirements**.
.. _asyncio.wait_for: https://docs.python.org/3/library/asyncio-task.html#asyncio.wait_for
Examples
----------
Basic example:
.. code-block:: python
:emphasize-lines: 5
@client.event
async def on_message(message):
if message.content.startswith('$greet'):
await message.channel.send('Say hello')
msg = await client.wait_for_message(author=message.author, content='hello')
await message.channel.send('Hello.')
Asking for a follow-up question:
.. code-block:: python
:emphasize-lines: 6
@client.event
async def on_message(message):
if message.content.startswith('$start'):
await message.channel.send('Type $stop 4 times.')
for i in range(4):
msg = await client.wait_for_message(author=message.author, content='$stop')
fmt = '{} left to go...'
await message.channel.send(fmt.format(3 - i))
await message.channel.send('Good job!')
Advanced filters using ``check``:
.. code-block:: python
:emphasize-lines: 9
@client.event
async def on_message(message):
if message.content.startswith('$cool'):
await message.channel.send('Who is cool? Type $name namehere')
def check(msg):
return msg.content.startswith('$name')
message = await client.wait_for_message(author=message.author, check=check)
name = message.content[len('$name'):].strip()
await message.channel.send('{} is cool indeed'.format(name))
Parameters
-----------
timeout : float
The number of seconds to wait before returning ``None``.
author : :class:`Member` or :class:`User`
The author the message must be from.
channel : :class:`Channel` or :class:`PrivateChannel` or :class:`Object`
The channel the message must be from.
content : str
The exact content the message must have.
check : function
A predicate for other complicated checks. The predicate must take
a :class:`Message` as its only parameter.
Returns
--------
:class:`Message`
The message that you requested for.
"""
def predicate(message):
result = True
if author is not None:
result = result and message.author == author
if content is not None:
result = result and message.content == content
if channel is not None:
result = result and message.channel.id == channel.id
if callable(check):
# the exception thrown by check is propagated through the future.
result = result and check(message)
return result
future = compat.create_future(self.loop)
self._listeners.append((predicate, future, WaitForType.message))
try:
message = yield from asyncio.wait_for(future, timeout, loop=self.loop)
except asyncio.TimeoutError:
message = None
return message
@asyncio.coroutine
def wait_for_reaction(self, emoji=None, *, user=None, timeout=None, message=None, check=None):
"""|coro|
Waits for a message reaction from Discord. This is similar to :meth:`wait_for_message`
and could be seen as another :func:`on_reaction_add` event outside of the actual event.
This could be used for follow up situations.
Similar to :meth:`wait_for_message`, the keyword arguments are combined using logical
AND operator. The ``check`` keyword argument can be used to pass in more complicated
checks and must a regular function taking in two arguments, ``(reaction, user)``. It
must not be a coroutine.
The ``timeout`` parameter is passed into asyncio.wait_for. By default, it
does not timeout. Instead of throwing ``asyncio.TimeoutError`` the coroutine
catches the exception and returns ``None`` instead of a the ``(reaction, user)``
tuple.
If the ``check`` predicate throws an exception, then the exception is propagated.
The ``emoji`` parameter can be either a :class:`Emoji`, a ``str`` representing
an emoji, or a sequence of either type. If the ``emoji`` parameter is a sequence
then the first reaction emoji that is in the list is returned. If ``None`` is
passed then the first reaction emoji used is returned.
This function returns the **first reaction that meets the requirements**.
Examples Examples
--------- ---------
Basic Example: Waiting for a user reply: ::
.. code-block:: python
@client.event @client.event
async def on_message(message): async def on_message(message):
if message.content.startswith('$react'): if message.content.startswith('$greet'):
msg = await message.channel.send('React with thumbs up or thumbs down.') await message.channel.send('Say hello!')
res = await client.wait_for_reaction(['\N{THUMBS UP SIGN}', '\N{THUMBS DOWN SIGN}'], message=msg)
await message.channel.send('{0.user} reacted with {0.reaction.emoji}!'.format(res))
Checking for reaction emoji regardless of skin tone: def check(m):
return m.content == 'hello' and m.channel == message.channel
.. code-block:: python msg = await client.wait_for('message', check=check)
await message.channel.send('Hello {.author}!'.format(msg))
@client.event
async def on_message(message):
if message.content.startswith('$react'):
msg = await message.channel.send('React with thumbs up or thumbs down.')
def check(reaction, user):
e = str(reaction.emoji)
return e.startswith(('\N{THUMBS UP SIGN}', '\N{THUMBS DOWN SIGN}'))
res = await client.wait_for_reaction(message=msg, check=check)
await message.channel.send('{0.user} reacted with {0.reaction.emoji}!'.format(res))
Parameters Parameters
----------- ------------
timeout: float event: str
The number of seconds to wait before returning ``None``. The event name, similar to the :ref:`event reference <discord-api-events>`,
user: :class:`Member` or :class:`User` but without the ``on_`` prefix, to wait for.
The user the reaction must be from. check: Optional[predicate]
emoji: str or :class:`Emoji` or sequence A predicate to check what to wait for. The arguments must meet the
The emoji that we are waiting to react with. parameters of the event being waited for.
message: :class:`Message` timeout: Optional[float]
The message that we want the reaction to be from. The number of seconds to wait before timing out and raising
check: function ``asyncio.TimeoutError``\.
A predicate for other complicated checks. The predicate must take
``(reaction, user)`` as its two parameters, which ``reaction`` being a Raises
:class:`Reaction` and ``user`` being either a :class:`User` or a -------
:class:`Member`. asyncio.TimeoutError
If a timeout is provided and it was reached.
Returns Returns
-------- --------
namedtuple Any
A namedtuple with attributes ``reaction`` and ``user`` similar to :func:`on_reaction_add`. Returns no arguments, a single argument, or a tuple of multiple
arguments that mirrors the parameters passed in the
:ref:`event reference <discord-api-events>`.
""" """
if emoji is None:
emoji_check = lambda r: True
elif isinstance(emoji, (str, Emoji)):
emoji_check = lambda r: r.emoji == emoji
else:
emoji_check = lambda r: r.emoji in emoji
def predicate(reaction, reaction_user):
result = emoji_check(reaction)
if message is not None:
result = result and message.id == reaction.message.id
if user is not None:
result = result and user.id == reaction_user.id
if callable(check):
# the exception thrown by check is propagated through the future.
result = result and check(reaction, reaction_user)
return result
future = compat.create_future(self.loop) future = compat.create_future(self.loop)
self._listeners.append((predicate, future, WaitForType.reaction)) if check is None:
def _check(*args):
return True
check = _check
ev = event.lower()
try: try:
return (yield from asyncio.wait_for(future, timeout, loop=self.loop)) listeners = self._listeners[ev]
except asyncio.TimeoutError: except KeyError:
return None listeners = []
self._listeners[ev] = listeners
listeners.append((future, check))
return asyncio.wait_for(future, timeout, loop=self.loop)
# event registration # event registration