Refactor LogsFromIterator
This commit is contained in:
		| @@ -1041,7 +1041,7 @@ class Client: | ||||
|         if isinstance(after, datetime.datetime): | ||||
|             after = Object(utils.time_snowflake(after, high=True)) | ||||
|  | ||||
|         iterator = LogsFromIterator.create(self, channel, limit, before=before, after=after) | ||||
|         iterator = LogsFromIterator(self, channel, limit, before=before, after=after) | ||||
|         ret = [] | ||||
|         count = 0 | ||||
|  | ||||
| @@ -1271,7 +1271,7 @@ class Client: | ||||
|             if isinstance(after, datetime.datetime): | ||||
|                 after = Object(utils.time_snowflake(after, high=True)) | ||||
|  | ||||
|             return LogsFromIterator.create(self, channel, limit, before=before, after=after, reverse=reverse) | ||||
|             return LogsFromIterator(self, channel, limit, before=before, after=after, reverse=reverse) | ||||
|     else: | ||||
|         @asyncio.coroutine | ||||
|         def logs_from(self, channel, limit=100, *, before=None, after=None): | ||||
|   | ||||
| @@ -32,19 +32,22 @@ from .object import Object | ||||
|  | ||||
| PY35 = sys.version_info >= (3, 5) | ||||
|  | ||||
|  | ||||
| class LogsFromIterator: | ||||
|     @staticmethod | ||||
|     def create(client, channel, limit, *, before=None, after=None, reverse=False): | ||||
|         """Create a proper iterator depending on parameters. | ||||
|     """Iterator for recieving logs. | ||||
|  | ||||
|         The messages endpoint has two behaviors: | ||||
|             If `before` is specified, it returns the `limit` newest messages before `before`, sorted with newest first. | ||||
|               - Fill strategy - update 'before' to oldest message | ||||
|             If `after` is specified, it returns the `limit` oldest messages after `after`, sorted with newest first. | ||||
|               - Fill strategy - update 'after' to newest message | ||||
|               - If messages are not reversed, they will be out of order (99-0, 199-100, so on) | ||||
|     The messages endpoint has two behaviors we care about here: | ||||
|     If `before` is specified, the messages endpoint returns the `limit` | ||||
|     newest messages before `before`, sorted with newest first. For filling over | ||||
|     100 messages, update the `before` parameter to the oldest message recieved. | ||||
|     Messages will be returned in order by time. | ||||
|     If `after` is specified, it returns the `limit` oldest messages after | ||||
|     `after`, sorted with newest first. For filling over 100 messages, update the | ||||
|     `after` parameter to the newest message recieved. If messages are not | ||||
|     reversed, they will be out of order (99-0, 199-100, so on) | ||||
|  | ||||
|         A note that if both before and after are specified, before is ignored by the messages endpoint. | ||||
|     A note that if both before and after are specified, before is ignored by the | ||||
|     messages endpoint. | ||||
|  | ||||
|     Parameters | ||||
|     ----------- | ||||
| @@ -58,25 +61,34 @@ class LogsFromIterator: | ||||
|     after : :class:`Message` or id-like | ||||
|         Message after which all messages must be. | ||||
|     reverse : bool | ||||
|             If set to true, return messages in oldest->newest order. Recommended when using with "after" queries, | ||||
|             otherwise messages will be out of order. Defaults to False for backwards compatability. | ||||
|         If set to true, return messages in oldest->newest order. Recommended | ||||
|         when using with "after" queries with limit over 100, otherwise messages | ||||
|         will be out of order. Defaults to False for backwards compatability. | ||||
|     """ | ||||
|         if before and after: | ||||
|             if reverse: | ||||
|                 return LogsFromBeforeAfterReversedIterator(client, channel, limit, before, after) | ||||
|             else: | ||||
|                 return LogsFromBeforeAfterIterator(client, channel, limit, before, after) | ||||
|         elif after: | ||||
|             return LogsFromAfterIterator(client, channel, limit, after, reverse=reverse) | ||||
|         else: | ||||
|             return LogsFromBeforeIterator(client, channel, limit, before) | ||||
|  | ||||
|     def __init__(self, client, channel, limit): | ||||
|     def __init__(self, client, channel, limit, | ||||
|                  before=None, after=None, reverse=False): | ||||
|         self.client = client | ||||
|         self.channel = channel | ||||
|         self.limit = limit | ||||
|         self.before = before | ||||
|         self.after = after | ||||
|         self.reverse = reverse | ||||
|         self._filter = None  # message dict -> bool | ||||
|         self.messages = asyncio.Queue() | ||||
|  | ||||
|         if self.before and self.after: | ||||
|             if self.reverse: | ||||
|                 self._retrieve_messages = self._retrieve_messages_after_strategy | ||||
|                 self._filter = lambda m: int(m['id']) < int(self.before.id) | ||||
|             else: | ||||
|                 self._retrieve_messages = self._retrieve_messages_before_strategy | ||||
|                 self._filter = lambda m: int(m['id']) > int(self.after.id) | ||||
|         elif self.after: | ||||
|             self._retrieve_messages = self._retrieve_messages_after_strategy | ||||
|         else: | ||||
|             self._retrieve_messages = self._retrieve_messages_before_strategy | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def iterate(self): | ||||
|         if self.messages.empty(): | ||||
| @@ -84,6 +96,41 @@ class LogsFromIterator: | ||||
|  | ||||
|         return self.messages.get_nowait() | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def fill_messages(self): | ||||
|         if self.limit > 0: | ||||
|             retrieve = self.limit if self.limit <= 100 else 100 | ||||
|             data = yield from self._retrieve_messages(retrieve) | ||||
|             if self.reverse: | ||||
|                 data = reversed(data) | ||||
|             if self._filter: | ||||
|                 data = filter(self._filter, data) | ||||
|             for element in data: | ||||
|                 yield from self.messages.put(Message(channel=self.channel, **element)) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def _retrieve_messages(self, retrieve): | ||||
|         """Retrieve messages and update next parameters.""" | ||||
|         pass | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def _retrieve_messages_before_strategy(self, retrieve): | ||||
|         """Retrieve messages using before parameter.""" | ||||
|         data = yield from self.client._logs_from(self.channel, retrieve, before=self.before) | ||||
|         if len(data): | ||||
|             self.limit -= retrieve | ||||
|             self.before = Object(id=data[-1]['id']) | ||||
|         return data | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def _retrieve_messages_after_strategy(self, retrieve): | ||||
|         """Retrieve messages using after parameter.""" | ||||
|         data = yield from self.client._logs_from(self.channel, retrieve, after=self.after) | ||||
|         if len(data): | ||||
|             self.limit -= retrieve | ||||
|             self.after = Object(id=data[0]['id']) | ||||
|         return data | ||||
|  | ||||
|     if PY35: | ||||
|         @asyncio.coroutine | ||||
|         def __aiter__(self): | ||||
| @@ -98,88 +145,3 @@ class LogsFromIterator: | ||||
|                 # if we're still empty at this point... | ||||
|                 # we didn't get any new messages so stop looping | ||||
|                 raise StopAsyncIteration() | ||||
|  | ||||
| class LogsFromBeforeIterator(LogsFromIterator): | ||||
|     def __init__(self, client, channel, limit, before): | ||||
|         super().__init__(client, channel, limit) | ||||
|         self.before = before | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def fill_messages(self): | ||||
|         if self.limit > 0: | ||||
|             retrieve = self.limit if self.limit <= 100 else 100 | ||||
|  | ||||
|             data = yield from self.client._logs_from(self.channel, retrieve, before=self.before) | ||||
|             if len(data): | ||||
|                 self.limit -= retrieve | ||||
|                 self.before = Object(id=data[-1]['id']) | ||||
|                 for element in data: | ||||
|                     yield from self.messages.put(Message(channel=self.channel, **element)) | ||||
|  | ||||
| class LogsFromAfterIterator(LogsFromIterator): | ||||
|     """Iterator for retrieving "after" style responses. | ||||
|  | ||||
|     Recommended to use with reverse=True - this will return messages oldest to newest. | ||||
|     With reverse=False, you'll recieve messages 99-0, 199-100, etc.""" | ||||
|     def __init__(self, client, channel, limit, after, *, reverse=False): | ||||
|         super().__init__(client, channel, limit) | ||||
|         self.after = after | ||||
|         self.reverse = reverse | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def fill_messages(self): | ||||
|         if self.limit > 0: | ||||
|             retrieve = self.limit if self.limit <= 100 else 100 | ||||
|  | ||||
|             data = yield from self.client._logs_from(self.channel, retrieve, after=self.after) | ||||
|             if len(data): | ||||
|                 self.limit -= retrieve | ||||
|                 self.after = Object(id=data[0]['id']) | ||||
|                 for element in (data if not self.reverse else reversed(data)): | ||||
|                     yield from self.messages.put(Message(channel=self.channel, **element)) | ||||
|  | ||||
| class LogsFromBeforeAfterIterator(LogsFromIterator): | ||||
|     """Newest -> Oldest.""" | ||||
|     def __init__(self, client, channel, limit, before, after): | ||||
|         super().__init__(client, channel, limit) | ||||
|         self.before = before | ||||
|         self.after = after | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def fill_messages(self): | ||||
|         if self.limit > 0: | ||||
|             retrieve = self.limit if self.limit <= 100 else 100 | ||||
|  | ||||
|             data = yield from self.client._logs_from(self.channel, retrieve, before=self.before) | ||||
|             if len(data): | ||||
|                 self.limit -= retrieve | ||||
|                 self.before = Object(id=data[-1]['id']) | ||||
|                 # Only filter if the oldest message is not after our endpoint | ||||
|                 if int(data[-1]['id']) <= int(self.after.id): | ||||
|                     data = filter(lambda d: int(d['id']) > int(self.after.id), data) | ||||
|                 for element in data: | ||||
|                         yield from self.messages.put(Message(channel=self.channel, **element)) | ||||
|  | ||||
| class LogsFromBeforeAfterReversedIterator(LogsFromIterator): | ||||
|     """Oldest -> Newest.""" | ||||
|     def __init__(self, client, channel, limit, before, after): | ||||
|         super().__init__(client, channel, limit) | ||||
|         self.before = before | ||||
|         self.after = after | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def fill_messages(self): | ||||
|         if self.limit > 0: | ||||
|             retrieve = self.limit if self.limit <= 100 else 100 | ||||
|  | ||||
|             data = yield from self.client._logs_from(self.channel, retrieve, after=self.after) | ||||
|             if len(data): | ||||
|                 self.limit -= retrieve | ||||
|                 self.after = Object(id=data[0]['id']) | ||||
|                 # Only filter if the newest is not before our endpoint | ||||
|                 if int(data[0]['id']) >= int(self.before.id): | ||||
|                     data = filter(lambda d: int(d['id']) < int(self.before.id), reversed(data)) | ||||
|                 else: | ||||
|                     data = reversed(data) | ||||
|                 for element in data: | ||||
|                     yield from self.messages.put(Message(channel=self.channel, **element)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user