Replace wait_for_* with a generic Client.wait_for
This commit is contained in:
		| @@ -41,7 +41,7 @@ import aiohttp | ||||
| import websockets | ||||
|  | ||||
| import logging, traceback | ||||
| import sys, re, io, enum | ||||
| import sys, re, io | ||||
| import itertools | ||||
| import datetime | ||||
| from collections import namedtuple | ||||
| @@ -51,7 +51,6 @@ PY35 = sys.version_info >= (3, 5) | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
| AppInfo = namedtuple('AppInfo', 'id name description icon owner') | ||||
| WaitedReaction = namedtuple('WaitedReaction', 'reaction user') | ||||
|  | ||||
| def app_info_icon_url(self): | ||||
|     """Retrieves the application's icon_url if it exists. Empty string otherwise.""" | ||||
| @@ -62,10 +61,6 @@ def app_info_icon_url(self): | ||||
|  | ||||
| AppInfo.icon_url = property(app_info_icon_url) | ||||
|  | ||||
| class WaitForType(enum.Enum): | ||||
|     message  = 0 | ||||
|     reaction = 1 | ||||
|  | ||||
| class Client: | ||||
|     """Represents a client connection that connects to Discord. | ||||
|     This class is used to interact with the Discord WebSocket and API. | ||||
| @@ -113,7 +108,7 @@ class Client: | ||||
|         self.ws = None | ||||
|         self.email = None | ||||
|         self.loop = asyncio.get_event_loop() if loop is None else loop | ||||
|         self._listeners = [] | ||||
|         self._listeners = {} | ||||
|         self.shard_id = options.get('shard_id') | ||||
|         self.shard_count = options.get('shard_count') | ||||
|  | ||||
| @@ -125,8 +120,6 @@ class Client: | ||||
|  | ||||
|         self.connection.shard_count = self.shard_count | ||||
|         self._closed = asyncio.Event(loop=self.loop) | ||||
|         self._is_logged_in = asyncio.Event(loop=self.loop) | ||||
|         self._is_ready = asyncio.Event(loop=self.loop) | ||||
|  | ||||
|         # if VoiceClient.warn_nacl: | ||||
|         #     VoiceClient.warn_nacl = False | ||||
| @@ -156,57 +149,6 @@ class Client: | ||||
|  | ||||
|         yield from self.ws.send_as_json(payload) | ||||
|  | ||||
|     def handle_reaction_add(self, reaction, user): | ||||
|         removed = [] | ||||
|         for i, (condition, future, event_type) in enumerate(self._listeners): | ||||
|             if event_type is not WaitForType.reaction: | ||||
|                 continue | ||||
|  | ||||
|             if future.cancelled(): | ||||
|                 removed.append(i) | ||||
|                 continue | ||||
|  | ||||
|             try: | ||||
|                 result = condition(reaction, user) | ||||
|             except Exception as e: | ||||
|                 future.set_exception(e) | ||||
|                 removed.append(i) | ||||
|             else: | ||||
|                 if result: | ||||
|                     future.set_result(WaitedReaction(reaction, user)) | ||||
|                     removed.append(i) | ||||
|  | ||||
|  | ||||
|         for idx in reversed(removed): | ||||
|             del self._listeners[idx] | ||||
|  | ||||
|     def handle_message(self, message): | ||||
|         removed = [] | ||||
|         for i, (condition, future, event_type) in enumerate(self._listeners): | ||||
|             if event_type is not WaitForType.message: | ||||
|                 continue | ||||
|  | ||||
|             if future.cancelled(): | ||||
|                 removed.append(i) | ||||
|                 continue | ||||
|  | ||||
|             try: | ||||
|                 result = condition(message) | ||||
|             except Exception as e: | ||||
|                 future.set_exception(e) | ||||
|                 removed.append(i) | ||||
|             else: | ||||
|                 if result: | ||||
|                     future.set_result(message) | ||||
|                     removed.append(i) | ||||
|  | ||||
|  | ||||
|         for idx in reversed(removed): | ||||
|             del self._listeners[idx] | ||||
|  | ||||
|     def handle_ready(self): | ||||
|         self._is_ready.set() | ||||
|  | ||||
|     def _resolve_invite(self, invite): | ||||
|         if isinstance(invite, Invite) or isinstance(invite, Object): | ||||
|             return invite.id | ||||
| @@ -264,6 +206,35 @@ class Client: | ||||
|         method = 'on_' + event | ||||
|         handler = 'handle_' + event | ||||
|  | ||||
|         listeners = self._listeners.get(event) | ||||
|         if listeners: | ||||
|             removed = [] | ||||
|             for i, (future, condition) in enumerate(listeners): | ||||
|                 if future.cancelled(): | ||||
|                     removed.append(i) | ||||
|                     continue | ||||
|  | ||||
|                 try: | ||||
|                     result = condition(*args) | ||||
|                 except Exception as e: | ||||
|                     future.set_exception(e) | ||||
|                     removed.append(i) | ||||
|                 else: | ||||
|                     if result: | ||||
|                         if len(args) == 0: | ||||
|                             future.set_result(None) | ||||
|                         elif len(args) == 1: | ||||
|                             future.set_result(args[0]) | ||||
|                         else: | ||||
|                             future.set_result(args) | ||||
|                         removed.append(i) | ||||
|  | ||||
|             if len(removed) == len(listeners): | ||||
|                 self._listeners.pop(event) | ||||
|             else: | ||||
|                 for idx in reversed(removed): | ||||
|                     del listeners[idx] | ||||
|  | ||||
|         try: | ||||
|             actual_handler = getattr(self, handler) | ||||
|         except AttributeError: | ||||
| @@ -353,7 +324,6 @@ class Client: | ||||
|         data = yield from self.http.static_login(token, bot=bot) | ||||
|         self.email = data.get('email', None) | ||||
|         self.connection.is_bot = bot | ||||
|         self._is_logged_in.set() | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def logout(self): | ||||
| @@ -362,7 +332,6 @@ class Client: | ||||
|         Logs out of Discord and closes all connections. | ||||
|         """ | ||||
|         yield from self.close() | ||||
|         self._is_logged_in.clear() | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def connect(self): | ||||
| @@ -420,7 +389,6 @@ class Client: | ||||
|  | ||||
|         yield from self.http.close() | ||||
|         self._closed.set() | ||||
|         self._is_ready.clear() | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def start(self, *args, **kwargs): | ||||
| @@ -474,12 +442,7 @@ class Client: | ||||
|         finally: | ||||
|             self.loop.close() | ||||
|  | ||||
|         # properties | ||||
|  | ||||
|     @property | ||||
|     def is_logged_in(self): | ||||
|         """bool: Indicates if the client has logged in successfully.""" | ||||
|         return self._is_logged_in.is_set() | ||||
|     # properties | ||||
|  | ||||
|     @property | ||||
|     def is_closed(self): | ||||
| @@ -550,250 +513,83 @@ class Client: | ||||
|  | ||||
|     # listeners/waiters | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def wait_until_ready(self): | ||||
|     def wait_for(self, event, *, check=None, timeout=None): | ||||
|         """|coro| | ||||
|  | ||||
|         This coroutine waits until the client is all ready. This could be considered | ||||
|         another way of asking for :func:`discord.on_ready` except meant for your own | ||||
|         background tasks. | ||||
|         """ | ||||
|         yield from self._is_ready.wait() | ||||
|         Waits for a WebSocket event to be dispatched. | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def wait_until_login(self): | ||||
|         """|coro| | ||||
|         This could be used to wait for a user to reply to a message, | ||||
|         or to react to a message, or to edit a message in a self-contained | ||||
|         way. | ||||
|  | ||||
|         This coroutine waits until the client is logged on successfully. This | ||||
|         is different from waiting until the client's state is all ready. For | ||||
|         that check :func:`discord.on_ready` and :meth:`wait_until_ready`. | ||||
|         """ | ||||
|         yield from self._is_logged_in.wait() | ||||
|         The ``timeout`` parameter is passed onto `asyncio.wait_for`_. By default, | ||||
|         it does not timeout. Note that this does propagate the | ||||
|         ``asyncio.TimeoutError`` for you in case of timeout and is provided for | ||||
|         ease of use. | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def wait_for_message(self, timeout=None, *, author=None, channel=None, content=None, check=None): | ||||
|         """|coro| | ||||
|         In case the event returns multiple arguments, a tuple containing those | ||||
|         arguments is returned instead. Please check the | ||||
|         :ref:`documentation <discord-api-events>` for a list of events and their | ||||
|         parameters. | ||||
|  | ||||
|         Waits for a message reply from Discord. This could be seen as another | ||||
|         :func:`discord.on_message` event outside of the actual event. This could | ||||
|         also be used for follow-ups and easier user interactions. | ||||
|  | ||||
|         The keyword arguments passed into this function are combined using the logical and | ||||
|         operator. The ``check`` keyword argument can be used to pass in more complicated | ||||
|         checks and must be a regular function (not a coroutine). | ||||
|  | ||||
|         The ``timeout`` parameter is passed into `asyncio.wait_for`_. By default, it | ||||
|         does not timeout. Instead of throwing ``asyncio.TimeoutError`` the coroutine | ||||
|         catches the exception and returns ``None`` instead of a :class:`Message`. | ||||
|  | ||||
|         If the ``check`` predicate throws an exception, then the exception is propagated. | ||||
|  | ||||
|         This function returns the **first message that meets the requirements**. | ||||
|  | ||||
|         .. _asyncio.wait_for: https://docs.python.org/3/library/asyncio-task.html#asyncio.wait_for | ||||
|  | ||||
|         Examples | ||||
|         ---------- | ||||
|  | ||||
|         Basic example: | ||||
|  | ||||
|         .. code-block:: python | ||||
|             :emphasize-lines: 5 | ||||
|  | ||||
|             @client.event | ||||
|             async def on_message(message): | ||||
|                 if message.content.startswith('$greet'): | ||||
|                     await message.channel.send('Say hello') | ||||
|                     msg = await client.wait_for_message(author=message.author, content='hello') | ||||
|                     await message.channel.send('Hello.') | ||||
|  | ||||
|         Asking for a follow-up question: | ||||
|  | ||||
|         .. code-block:: python | ||||
|             :emphasize-lines: 6 | ||||
|  | ||||
|             @client.event | ||||
|             async def on_message(message): | ||||
|                 if message.content.startswith('$start'): | ||||
|                     await message.channel.send('Type $stop 4 times.') | ||||
|                     for i in range(4): | ||||
|                         msg = await client.wait_for_message(author=message.author, content='$stop') | ||||
|                         fmt = '{} left to go...' | ||||
|                         await message.channel.send(fmt.format(3 - i)) | ||||
|  | ||||
|                     await message.channel.send('Good job!') | ||||
|  | ||||
|         Advanced filters using ``check``: | ||||
|  | ||||
|         .. code-block:: python | ||||
|             :emphasize-lines: 9 | ||||
|  | ||||
|             @client.event | ||||
|             async def on_message(message): | ||||
|                 if message.content.startswith('$cool'): | ||||
|                     await message.channel.send('Who is cool? Type $name namehere') | ||||
|  | ||||
|                     def check(msg): | ||||
|                         return msg.content.startswith('$name') | ||||
|  | ||||
|                     message = await client.wait_for_message(author=message.author, check=check) | ||||
|                     name = message.content[len('$name'):].strip() | ||||
|                     await message.channel.send('{} is cool indeed'.format(name)) | ||||
|  | ||||
|  | ||||
|         Parameters | ||||
|         ----------- | ||||
|         timeout : float | ||||
|             The number of seconds to wait before returning ``None``. | ||||
|         author : :class:`Member` or :class:`User` | ||||
|             The author the message must be from. | ||||
|         channel : :class:`Channel` or :class:`PrivateChannel` or :class:`Object` | ||||
|             The channel the message must be from. | ||||
|         content : str | ||||
|             The exact content the message must have. | ||||
|         check : function | ||||
|             A predicate for other complicated checks. The predicate must take | ||||
|             a :class:`Message` as its only parameter. | ||||
|  | ||||
|         Returns | ||||
|         -------- | ||||
|         :class:`Message` | ||||
|             The message that you requested for. | ||||
|         """ | ||||
|  | ||||
|         def predicate(message): | ||||
|             result = True | ||||
|             if author is not None: | ||||
|                 result = result and message.author == author | ||||
|  | ||||
|             if content is not None: | ||||
|                 result = result and message.content == content | ||||
|  | ||||
|             if channel is not None: | ||||
|                 result = result and message.channel.id == channel.id | ||||
|  | ||||
|             if callable(check): | ||||
|                 # the exception thrown by check is propagated through the future. | ||||
|                 result = result and check(message) | ||||
|  | ||||
|             return result | ||||
|  | ||||
|         future = compat.create_future(self.loop) | ||||
|         self._listeners.append((predicate, future, WaitForType.message)) | ||||
|         try: | ||||
|             message = yield from asyncio.wait_for(future, timeout, loop=self.loop) | ||||
|         except asyncio.TimeoutError: | ||||
|             message = None | ||||
|         return message | ||||
|  | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def wait_for_reaction(self, emoji=None, *, user=None, timeout=None, message=None, check=None): | ||||
|         """|coro| | ||||
|  | ||||
|         Waits for a message reaction from Discord. This is similar to :meth:`wait_for_message` | ||||
|         and could be seen as another :func:`on_reaction_add` event outside of the actual event. | ||||
|         This could be used for follow up situations. | ||||
|  | ||||
|         Similar to :meth:`wait_for_message`, the keyword arguments are combined using logical | ||||
|         AND operator. The ``check`` keyword argument can be used to pass in more complicated | ||||
|         checks and must a regular function taking in two arguments, ``(reaction, user)``. It | ||||
|         must not be a coroutine. | ||||
|  | ||||
|         The ``timeout`` parameter is passed into asyncio.wait_for. By default, it | ||||
|         does not timeout. Instead of throwing ``asyncio.TimeoutError`` the coroutine | ||||
|         catches the exception and returns ``None`` instead of a the ``(reaction, user)`` | ||||
|         tuple. | ||||
|  | ||||
|         If the ``check`` predicate throws an exception, then the exception is propagated. | ||||
|  | ||||
|         The ``emoji`` parameter can be either a :class:`Emoji`, a ``str`` representing | ||||
|         an emoji, or a sequence of either type. If the ``emoji`` parameter is a sequence | ||||
|         then the first reaction emoji that is in the list is returned. If ``None`` is | ||||
|         passed then the first reaction emoji used is returned. | ||||
|  | ||||
|         This function returns the **first reaction that meets the requirements**. | ||||
|         This function returns the **first event that meets the requirements**. | ||||
|  | ||||
|         Examples | ||||
|         --------- | ||||
|  | ||||
|         Basic Example: | ||||
|  | ||||
|         .. code-block:: python | ||||
|         Waiting for a user reply: :: | ||||
|  | ||||
|             @client.event | ||||
|             async def on_message(message): | ||||
|                 if message.content.startswith('$react'): | ||||
|                     msg = await message.channel.send('React with thumbs up or thumbs down.') | ||||
|                     res = await client.wait_for_reaction(['\N{THUMBS UP SIGN}', '\N{THUMBS DOWN SIGN}'], message=msg) | ||||
|                     await message.channel.send('{0.user} reacted with {0.reaction.emoji}!'.format(res)) | ||||
|                 if message.content.startswith('$greet'): | ||||
|                     await message.channel.send('Say hello!') | ||||
|  | ||||
|         Checking for reaction emoji regardless of skin tone: | ||||
|                     def check(m): | ||||
|                         return m.content == 'hello' and m.channel == message.channel | ||||
|  | ||||
|         .. code-block:: python | ||||
|  | ||||
|             @client.event | ||||
|             async def on_message(message): | ||||
|                 if message.content.startswith('$react'): | ||||
|                     msg = await message.channel.send('React with thumbs up or thumbs down.') | ||||
|  | ||||
|                     def check(reaction, user): | ||||
|                         e = str(reaction.emoji) | ||||
|                         return e.startswith(('\N{THUMBS UP SIGN}', '\N{THUMBS DOWN SIGN}')) | ||||
|  | ||||
|                     res = await client.wait_for_reaction(message=msg, check=check) | ||||
|                     await message.channel.send('{0.user} reacted with {0.reaction.emoji}!'.format(res)) | ||||
|                     msg = await client.wait_for('message', check=check) | ||||
|                     await message.channel.send('Hello {.author}!'.format(msg)) | ||||
|  | ||||
|         Parameters | ||||
|         ----------- | ||||
|         timeout: float | ||||
|             The number of seconds to wait before returning ``None``. | ||||
|         user: :class:`Member` or :class:`User` | ||||
|             The user the reaction must be from. | ||||
|         emoji: str or :class:`Emoji` or sequence | ||||
|             The emoji that we are waiting to react with. | ||||
|         message: :class:`Message` | ||||
|             The message that we want the reaction to be from. | ||||
|         check: function | ||||
|             A predicate for other complicated checks. The predicate must take | ||||
|             ``(reaction, user)`` as its two parameters, which ``reaction`` being a | ||||
|             :class:`Reaction` and ``user`` being either a :class:`User` or a | ||||
|             :class:`Member`. | ||||
|         ------------ | ||||
|         event: str | ||||
|             The event name, similar to the :ref:`event reference <discord-api-events>`, | ||||
|             but without the ``on_`` prefix, to wait for. | ||||
|         check: Optional[predicate] | ||||
|             A predicate to check what to wait for. The arguments must meet the | ||||
|             parameters of the event being waited for. | ||||
|         timeout: Optional[float] | ||||
|             The number of seconds to wait before timing out and raising | ||||
|             ``asyncio.TimeoutError``\. | ||||
|  | ||||
|         Raises | ||||
|         ------- | ||||
|         asyncio.TimeoutError | ||||
|             If a timeout is provided and it was reached. | ||||
|  | ||||
|         Returns | ||||
|         -------- | ||||
|         namedtuple | ||||
|             A namedtuple with attributes ``reaction`` and ``user`` similar to :func:`on_reaction_add`. | ||||
|         Any | ||||
|             Returns no arguments, a single argument, or a tuple of multiple | ||||
|             arguments that mirrors the parameters passed in the | ||||
|             :ref:`event reference <discord-api-events>`. | ||||
|         """ | ||||
|  | ||||
|         if emoji is None: | ||||
|             emoji_check = lambda r: True | ||||
|         elif isinstance(emoji, (str, Emoji)): | ||||
|             emoji_check = lambda r: r.emoji == emoji | ||||
|         else: | ||||
|             emoji_check = lambda r: r.emoji in emoji | ||||
|  | ||||
|         def predicate(reaction, reaction_user): | ||||
|             result = emoji_check(reaction) | ||||
|  | ||||
|             if message is not None: | ||||
|                 result = result and message.id == reaction.message.id | ||||
|  | ||||
|             if user is not None: | ||||
|                 result = result and user.id == reaction_user.id | ||||
|  | ||||
|             if callable(check): | ||||
|                 # the exception thrown by check is propagated through the future. | ||||
|                 result = result and check(reaction, reaction_user) | ||||
|  | ||||
|             return result | ||||
|  | ||||
|         future = compat.create_future(self.loop) | ||||
|         self._listeners.append((predicate, future, WaitForType.reaction)) | ||||
|         if check is None: | ||||
|             def _check(*args): | ||||
|                 return True | ||||
|             check = _check | ||||
|  | ||||
|         ev = event.lower() | ||||
|         try: | ||||
|             return (yield from asyncio.wait_for(future, timeout, loop=self.loop)) | ||||
|         except asyncio.TimeoutError: | ||||
|             return None | ||||
|             listeners = self._listeners[ev] | ||||
|         except KeyError: | ||||
|             listeners = [] | ||||
|             self._listeners[ev] = listeners | ||||
|  | ||||
|         listeners.append((future, check)) | ||||
|         return asyncio.wait_for(future, timeout, loop=self.loop) | ||||
|  | ||||
|     # event registration | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user