Change Messageable channel getter to be a coroutine.
This commit is contained in:
		| @@ -467,6 +467,7 @@ class GuildChannel: | ||||
| class Messageable(metaclass=abc.ABCMeta): | ||||
|     __slots__ = () | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     @abc.abstractmethod | ||||
|     def _get_channel(self): | ||||
|         raise NotImplementedError | ||||
| @@ -534,7 +535,7 @@ class Messageable(metaclass=abc.ABCMeta): | ||||
|             The message that was sent. | ||||
|         """ | ||||
|  | ||||
|         channel = self._get_channel() | ||||
|         channel = yield from self._get_channel() | ||||
|         guild_id = self._get_guild_id() | ||||
|         state = self._state | ||||
|         content = str(content) if content else None | ||||
| @@ -576,7 +577,7 @@ class Messageable(metaclass=abc.ABCMeta): | ||||
|         *Typing* indicator will go away after 10 seconds, or after a message is sent. | ||||
|         """ | ||||
|  | ||||
|         channel = self._get_channel() | ||||
|         channel = yield from self._get_channel() | ||||
|         yield from self._state.http.send_typing(channel.id) | ||||
|  | ||||
|     def typing(self): | ||||
| @@ -596,7 +597,8 @@ class Messageable(metaclass=abc.ABCMeta): | ||||
|                 await channel.send_message('done!') | ||||
|  | ||||
|         """ | ||||
|         return Typing(self._get_channel()) | ||||
|         channel = yield from self._get_channel() | ||||
|         return Typing(channel) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def get_message(self, id): | ||||
| @@ -626,7 +628,7 @@ class Messageable(metaclass=abc.ABCMeta): | ||||
|             Retrieving the message failed. | ||||
|         """ | ||||
|  | ||||
|         channel = self._get_channel() | ||||
|         channel = yield from self._get_channel() | ||||
|         data = yield from self._state.http.get_message(channel.id, id) | ||||
|         return state.create_message(channel=channel, data=data) | ||||
|  | ||||
| @@ -660,7 +662,7 @@ class Messageable(metaclass=abc.ABCMeta): | ||||
|             raise ClientException('Can only delete messages in the range of [2, 100]') | ||||
|  | ||||
|         message_ids = [m.id for m in messages] | ||||
|         channel = self._get_channel() | ||||
|         channel = yield from self._get_channel() | ||||
|         guild_id = self._get_guild_id() | ||||
|  | ||||
|         yield from self._state.http.delete_messages(channel.id, message_ids, guild_id) | ||||
| @@ -677,7 +679,7 @@ class Messageable(metaclass=abc.ABCMeta): | ||||
|             Retrieving the pinned messages failed. | ||||
|         """ | ||||
|  | ||||
|         channel = self._get_channel() | ||||
|         channel = yield from self._get_channel() | ||||
|         state = self._state | ||||
|         data = yield from state.http.pins_from(channel.id) | ||||
|         return [state.create_message(channel=channel, data=m) for m in data] | ||||
| @@ -745,7 +747,7 @@ class Messageable(metaclass=abc.ABCMeta): | ||||
|                     if message.author == client.user: | ||||
|                         counter += 1 | ||||
|         """ | ||||
|         return LogsFromIterator(self._get_channel(), limit=limit, before=before, after=after, around=around, reverse=reverse) | ||||
|         return LogsFromIterator(self, limit=limit, before=before, after=after, around=around, reverse=reverse) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def purge(self, *, limit=100, check=None, before=None, after=None, around=None): | ||||
|   | ||||
| @@ -88,6 +88,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): | ||||
|         self.position = data['position'] | ||||
|         self._fill_overwrites(data) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def _get_channel(self): | ||||
|         return self | ||||
|  | ||||
| @@ -262,6 +263,7 @@ class DMChannel(discord.abc.Messageable, Hashable): | ||||
|         self.me = me | ||||
|         self.id = int(data['id']) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def _get_channel(self): | ||||
|         return self | ||||
|  | ||||
| @@ -360,6 +362,7 @@ class GroupChannel(discord.abc.Messageable, Hashable): | ||||
|         else: | ||||
|             self.owner = utils.find(lambda u: u.id == owner_id, self.recipients) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def _get_channel(self): | ||||
|         return self | ||||
|  | ||||
|   | ||||
| @@ -117,6 +117,7 @@ class Context(discord.abc.Messageable): | ||||
|         ret = yield from command.callback(*arguments, **kwargs) | ||||
|         return ret | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def _get_channel(self): | ||||
|         return self.channel | ||||
|  | ||||
|   | ||||
| @@ -70,7 +70,7 @@ class LogsFromIterator: | ||||
|         will be out of order. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, channel, limit, | ||||
|     def __init__(self, messageable, limit, | ||||
|                  before=None, after=None, around=None, reverse=None): | ||||
|  | ||||
|         if isinstance(before, datetime.datetime): | ||||
| @@ -80,9 +80,7 @@ class LogsFromIterator: | ||||
|         if isinstance(around, datetime.datetime): | ||||
|             around = Object(id=time_snowflake(around)) | ||||
|  | ||||
|         self.channel = channel | ||||
|         self.ctx = channel._state | ||||
|         self.logs_from = channel._state.http.logs_from | ||||
|         self.messageable = messageable | ||||
|         self.limit = limit | ||||
|         self.before = before | ||||
|         self.after = after | ||||
| @@ -135,6 +133,13 @@ class LogsFromIterator: | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def fill_messages(self): | ||||
|         if not hasattr(self, 'channel'): | ||||
|             # do the required set up | ||||
|             channel = yield from self.messageable._get_channel() | ||||
|             self.channel = channel | ||||
|             self.state = channel._state | ||||
|             self.logs_from = channel._state.http.logs_from | ||||
|  | ||||
|         if self.limit > 0: | ||||
|             retrieve = self.limit if self.limit <= 100 else 100 | ||||
|             data = yield from self._retrieve_messages(retrieve) | ||||
| @@ -144,9 +149,8 @@ class LogsFromIterator: | ||||
|                 data = filter(self._filter, data) | ||||
|  | ||||
|             channel = self.channel | ||||
|             state = self.ctx | ||||
|             for element in data: | ||||
|                 yield from self.messages.put(state.create_message(channel=channel, data=element)) | ||||
|                 yield from self.messages.put(self.state.create_message(channel=channel, data=element)) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def _retrieve_messages(self, retrieve): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user