Add Client.wait_for_reaction to wait for a reaction from a user.
This commit is contained in:
		| @@ -50,7 +50,7 @@ import aiohttp | |||||||
| import websockets | import websockets | ||||||
|  |  | ||||||
| import logging, traceback | import logging, traceback | ||||||
| import sys, re, io | import sys, re, io, enum | ||||||
| import tempfile, os, hashlib | import tempfile, os, hashlib | ||||||
| import itertools | import itertools | ||||||
| import datetime | import datetime | ||||||
| @@ -70,6 +70,10 @@ def app_info_icon_url(self): | |||||||
|  |  | ||||||
| AppInfo.icon_url = property(app_info_icon_url) | AppInfo.icon_url = property(app_info_icon_url) | ||||||
|  |  | ||||||
|  | class WaitForType(enum.Enum): | ||||||
|  |     message  = 0 | ||||||
|  |     reaction = 1 | ||||||
|  |  | ||||||
| ChannelPermissions = namedtuple('ChannelPermissions', 'target overwrite') | ChannelPermissions = namedtuple('ChannelPermissions', 'target overwrite') | ||||||
| ChannelPermissions.__new__.__defaults__ = (PermissionOverwrite(),) | ChannelPermissions.__new__.__defaults__ = (PermissionOverwrite(),) | ||||||
|  |  | ||||||
| @@ -194,9 +198,36 @@ class Client: | |||||||
|             log.info('a problem occurred while updating the login cache') |             log.info('a problem occurred while updating the login cache') | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|  |     def handle_reaction_add(self, reaction, user): | ||||||
|  |         removed = [] | ||||||
|  |         for i, (condition, future, event_type) in enumerate(self._listeners): | ||||||
|  |             if event_type is not WaitForType.reaction: | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             if future.cancelled(): | ||||||
|  |                 removed.append(i) | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             try: | ||||||
|  |                 result = condition(reaction, user) | ||||||
|  |             except Exception as e: | ||||||
|  |                 future.set_exception(e) | ||||||
|  |                 removed.append(i) | ||||||
|  |             else: | ||||||
|  |                 if result: | ||||||
|  |                     future.set_result((reaction, user)) | ||||||
|  |                     removed.append(i) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         for idx in reversed(removed): | ||||||
|  |             del self._listeners[idx] | ||||||
|  |  | ||||||
|     def handle_message(self, message): |     def handle_message(self, message): | ||||||
|         removed = [] |         removed = [] | ||||||
|         for i, (condition, future) in enumerate(self._listeners): |         for i, (condition, future, event_type) in enumerate(self._listeners): | ||||||
|  |             if event_type is not WaitForType.message: | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|             if future.cancelled(): |             if future.cancelled(): | ||||||
|                 removed.append(i) |                 removed.append(i) | ||||||
|                 continue |                 continue | ||||||
| @@ -614,45 +645,45 @@ class Client: | |||||||
|         .. code-block:: python |         .. code-block:: python | ||||||
|             :emphasize-lines: 5 |             :emphasize-lines: 5 | ||||||
|  |  | ||||||
|             @client.async_event |             @client.event | ||||||
|             def on_message(message): |             async def on_message(message): | ||||||
|                 if message.content.startswith('$greet'): |                 if message.content.startswith('$greet'): | ||||||
|                     yield from client.send_message(message.channel, 'Say hello') |                     await client.send_message(message.channel, 'Say hello') | ||||||
|                     msg = yield from client.wait_for_message(author=message.author, content='hello') |                     msg = await client.wait_for_message(author=message.author, content='hello') | ||||||
|                     yield from client.send_message(message.channel, 'Hello.') |                     await client.send_message(message.channel, 'Hello.') | ||||||
|  |  | ||||||
|         Asking for a follow-up question: |         Asking for a follow-up question: | ||||||
|  |  | ||||||
|         .. code-block:: python |         .. code-block:: python | ||||||
|             :emphasize-lines: 6 |             :emphasize-lines: 6 | ||||||
|  |  | ||||||
|             @client.async_event |             @client.event | ||||||
|             def on_message(message): |             async def on_message(message): | ||||||
|                 if message.content.startswith('$start'): |                 if message.content.startswith('$start'): | ||||||
|                     yield from client.send_message(message.channel, 'Type $stop 4 times.') |                     await client.send_message(message.channel, 'Type $stop 4 times.') | ||||||
|                     for i in range(4): |                     for i in range(4): | ||||||
|                         msg = yield from client.wait_for_message(author=message.author, content='$stop') |                         msg = await client.wait_for_message(author=message.author, content='$stop') | ||||||
|                         fmt = '{} left to go...' |                         fmt = '{} left to go...' | ||||||
|                         yield from client.send_message(message.channel, fmt.format(3 - i)) |                         await client.send_message(message.channel, fmt.format(3 - i)) | ||||||
|  |  | ||||||
|                     yield from client.send_message(message.channel, 'Good job!') |                     await client.send_message(message.channel, 'Good job!') | ||||||
|  |  | ||||||
|         Advanced filters using ``check``: |         Advanced filters using ``check``: | ||||||
|  |  | ||||||
|         .. code-block:: python |         .. code-block:: python | ||||||
|             :emphasize-lines: 9 |             :emphasize-lines: 9 | ||||||
|  |  | ||||||
|             @client.async_event |             @client.event | ||||||
|             def on_message(message): |             async def on_message(message): | ||||||
|                 if message.content.startswith('$cool'): |                 if message.content.startswith('$cool'): | ||||||
|                     yield from client.send_message(message.channel, 'Who is cool? Type $name namehere') |                     await client.send_message(message.channel, 'Who is cool? Type $name namehere') | ||||||
|  |  | ||||||
|                     def check(msg): |                     def check(msg): | ||||||
|                         return msg.content.startswith('$name') |                         return msg.content.startswith('$name') | ||||||
|  |  | ||||||
|                     message = yield from client.wait_for_message(author=message.author, check=check) |                     message = await client.wait_for_message(author=message.author, check=check) | ||||||
|                     name = message.content[len('$name'):].strip() |                     name = message.content[len('$name'):].strip() | ||||||
|                     yield from client.send_message(message.channel, '{} is cool indeed'.format(name)) |                     await client.send_message(message.channel, '{} is cool indeed'.format(name)) | ||||||
|  |  | ||||||
|  |  | ||||||
|         Parameters |         Parameters | ||||||
| @@ -693,13 +724,107 @@ class Client: | |||||||
|             return result |             return result | ||||||
|  |  | ||||||
|         future = asyncio.Future(loop=self.loop) |         future = asyncio.Future(loop=self.loop) | ||||||
|         self._listeners.append((predicate, future)) |         self._listeners.append((predicate, future, WaitForType.message)) | ||||||
|         try: |         try: | ||||||
|             message = yield from asyncio.wait_for(future, timeout, loop=self.loop) |             message = yield from asyncio.wait_for(future, timeout, loop=self.loop) | ||||||
|         except asyncio.TimeoutError: |         except asyncio.TimeoutError: | ||||||
|             message = None |             message = None | ||||||
|         return message |         return message | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     @asyncio.coroutine | ||||||
|  |     def wait_for_reaction(self, emoji=None, *, user=None, timeout=None, message=None, check=None): | ||||||
|  |         """|coro| | ||||||
|  |  | ||||||
|  |         Waits for a message reaction from Discord. This is similar to :meth:`wait_for_message` | ||||||
|  |         and could be seen as another :func:`on_reaction_add` event outside of the actual event. | ||||||
|  |         This could be used for follow up situations. | ||||||
|  |  | ||||||
|  |         Similar to :meth:`wait_for_message`, the keyword arguments are combined using logical | ||||||
|  |         AND operator. The ``check`` keyword argument can be used to pass in more complicated | ||||||
|  |         checks and must a regular function taking in two arguments, ``(reaction, user)``. It | ||||||
|  |         must not be a coroutine. | ||||||
|  |  | ||||||
|  |         The ``timeout`` parameter is passed into asyncio.wait_for. By default, it | ||||||
|  |         does not timeout. Instead of throwing ``asyncio.TimeoutError`` the coroutine | ||||||
|  |         catches the exception and returns ``None`` instead of a the ``(reaction, user)`` | ||||||
|  |         tuple. | ||||||
|  |  | ||||||
|  |         If the ``check`` predicate throws an exception, then the exception is propagated. | ||||||
|  |  | ||||||
|  |         The ``emoji`` parameter can be either a :class:`Emoji`, a ``str`` representing | ||||||
|  |         an emoji, or a sequence of either type. If the ``emoji`` parameter is a sequence | ||||||
|  |         then the first reaction emoji that is in the list is returned. If ``None`` is | ||||||
|  |         passed then the first reaction emoji used is returned. | ||||||
|  |  | ||||||
|  |         This function returns the **first reaction that meets the requirements**. | ||||||
|  |  | ||||||
|  |         Examples | ||||||
|  |         --------- | ||||||
|  |  | ||||||
|  |         Basic Example: | ||||||
|  |  | ||||||
|  |         .. code-block:: python | ||||||
|  |  | ||||||
|  |             @client.event | ||||||
|  |             async def on_message(message): | ||||||
|  |                 if message.content.startswith('$react'): | ||||||
|  |                     msg = await client.send_message(message.channel, 'React with thumbs up or thumbs down.') | ||||||
|  |                     (reaction, user) = await client.wait_for_reaction(['\N{THUMBS UP SIGN}', | ||||||
|  |                                                                        '\N{THUMBS DOWN SIGN}'], | ||||||
|  |                                                                       message=msg) | ||||||
|  |                     await client.send_message(message.channel, '{} reacted with {.emoji}!'.format(user, reaction)) | ||||||
|  |  | ||||||
|  |         Parameters | ||||||
|  |         ----------- | ||||||
|  |         timeout: float | ||||||
|  |             The number of seconds to wait before returning ``None``. | ||||||
|  |         user: :class:`Member` or :class:`User` | ||||||
|  |             The user the reaction must be from. | ||||||
|  |         emoji: str or :class:`Emoji` or sequence | ||||||
|  |             The emoji that we are waiting to react with. | ||||||
|  |         message: :class:`Message` | ||||||
|  |             The message that we want the reaction to be from. | ||||||
|  |         check: function | ||||||
|  |             A predicate for other complicated checks. The predicate must take | ||||||
|  |             ``(reaction, user)`` as its two parameters, which ``reaction`` being a | ||||||
|  |             :class:`Reaction` and ``user`` being either a :class:`User` or a | ||||||
|  |             :class:`Member`. | ||||||
|  |  | ||||||
|  |         Returns | ||||||
|  |         -------- | ||||||
|  |         tuple | ||||||
|  |             A tuple of ``(reaction, user)`` similar to :func:`on_reaction_add`. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         emoji_check = lambda r: True | ||||||
|  |         if isinstance(emoji, (str, Emoji)): | ||||||
|  |             emoji_check = lambda r: r.emoji == emoji | ||||||
|  |         else: | ||||||
|  |             emoji_check = lambda r: r.emoji in emoji | ||||||
|  |  | ||||||
|  |         def predicate(reaction, reaction_user): | ||||||
|  |             result = emoji_check(reaction) | ||||||
|  |  | ||||||
|  |             if message is not None: | ||||||
|  |                 result = result and message.id == reaction.message.id | ||||||
|  |  | ||||||
|  |             if user is not None: | ||||||
|  |                 result = result and user.id == reaction_user.id | ||||||
|  |  | ||||||
|  |             if callable(check): | ||||||
|  |                 # the exception thrown by check is propagated through the future. | ||||||
|  |                 result = result and check(reaction, reaction_user) | ||||||
|  |  | ||||||
|  |             return result | ||||||
|  |  | ||||||
|  |         future = asyncio.Future(loop=self.loop) | ||||||
|  |         self._listeners.append((predicate, future, WaitForType.reaction)) | ||||||
|  |         try: | ||||||
|  |             return (yield from asyncio.wait_for(future, timeout, loop=self.loop)) | ||||||
|  |         except asyncio.TimeoutError: | ||||||
|  |             return None | ||||||
|  |  | ||||||
|     # event registration |     # event registration | ||||||
|  |  | ||||||
|     def event(self, coro): |     def event(self, coro): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user