logs_from improvements for after param.
- Properly support 'after' alone - Properly support both 'before' and 'after' - Add optional 'reverse' parameter to sort messages oldest->newest to 1) provide a sorted result set for 'after' 2) give flexibility when using both 'before' and 'after'
This commit is contained in:
		
				
					committed by
					
						 Rapptz
						Rapptz
					
				
			
			
				
	
			
			
			
						parent
						
							8e5347f4ed
						
					
				
				
					commit
					492c9afffb
				
			| @@ -1213,15 +1213,9 @@ class Client: | ||||
|         } | ||||
|  | ||||
|         if before: | ||||
|             if isinstance(before, datetime.datetime): | ||||
|                 params['before'] = utils.time_snowflake(before, high=False) | ||||
|             else: | ||||
|                 params['before'] = before.id | ||||
|             params['before'] = before.id | ||||
|         if after: | ||||
|             if isinstance(after, datetime.datetime): | ||||
|                 params['after'] = utils.time_snowflake(after, high=True) | ||||
|             else: | ||||
|                 params['after'] = after.id | ||||
|             params['after'] = after.id | ||||
|  | ||||
|         response = yield from self.session.get(url, params=params, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='GET', response=response)) | ||||
| @@ -1230,11 +1224,21 @@ class Client: | ||||
|         return messages | ||||
|  | ||||
|     if PY35: | ||||
|         def logs_from(self, channel, limit=100, *, before=None, after=None): | ||||
|             return LogsFromIterator(self, channel, limit, before, after) | ||||
|         def logs_from(self, channel, limit=100, *, before=None, after=None, reverse=False): | ||||
|             if isinstance(before, datetime.datetime): | ||||
|                 before = Object(utils.time_snowflake(before, high=False)) | ||||
|             if isinstance(after, datetime.datetime): | ||||
|                 after = Object(utils.time_snowflake(after, high=True)) | ||||
|  | ||||
|             return LogsFromIterator.create(self, channel, limit, before=before, after=after, reverse=reverse) | ||||
|     else: | ||||
|         @asyncio.coroutine | ||||
|         def logs_from(self, channel, limit=100, *, before=None, after=None): | ||||
|             if isinstance(before, datetime.datetime): | ||||
|                 before = Object(utils.time_snowflake(before, high=False)) | ||||
|             if isinstance(after, datetime.datetime): | ||||
|                 after = Object(utils.time_snowflake(after, high=True)) | ||||
|  | ||||
|             def generator(data): | ||||
|                 for message in data: | ||||
|                     yield Message(channel=channel, **message) | ||||
|   | ||||
| @@ -33,25 +33,50 @@ from .object import Object | ||||
| PY35 = sys.version_info >= (3, 5) | ||||
|  | ||||
| class LogsFromIterator: | ||||
|     def __init__(self, client, channel, limit, before, after): | ||||
|     @staticmethod | ||||
|     def create(client, channel, limit, *, before=None, after=None, reverse=False): | ||||
|         """Create a proper iterator depending on parameters. | ||||
|  | ||||
|         The messages endpoint has two behaviors: | ||||
|             If `before` is specified, it returns the `limit` newest messages before `before`, sorted with newest first. | ||||
|               - Fill strategy - update 'before' to oldest message | ||||
|             If `after` is specified, it returns the `limit` oldest messages after `after`, sorted with newest first. | ||||
|               - Fill strategy - update 'after' to newest message | ||||
|               - If messages are not reversed, they will be out of order (99-0, 199-100, so on) | ||||
|  | ||||
|         A note that if both before and after are specified, before is ignored by the messages endpoint. | ||||
|  | ||||
|         Parameters | ||||
|         ----------- | ||||
|         client : class:`Client` | ||||
|         channel : class:`Channel` | ||||
|             Channel from which to request logs | ||||
|         limit : int | ||||
|             Maximum number of messages to retrieve | ||||
|         before : :class:`Message` or id-like | ||||
|             Message before which all messages must be. | ||||
|         after : :class:`Message` or id-like | ||||
|             Message after which all messages must be. | ||||
|         reverse : bool | ||||
|             If set to true, return messages in oldest->newest order. Recommended when using with "after" queries, | ||||
|             otherwise messages will be out of order. Defaults to False for backwards compatability. | ||||
|         """ | ||||
|         if before and after: | ||||
|             if reverse: | ||||
|                 return LogsFromBeforeAfterReversedIterator(client, channel, limit, before, after) | ||||
|             else: | ||||
|                 return LogsFromBeforeAfterIterator(client, channel, limit, before, after) | ||||
|         elif after: | ||||
|             return LogsFromAfterIterator(client, channel, limit, after, reverse=reverse) | ||||
|         else: | ||||
|             return LogsFromBeforeIterator(client, channel, limit, before) | ||||
|  | ||||
|     def __init__(self, client, channel, limit): | ||||
|         self.client = client | ||||
|         self.channel = channel | ||||
|         self.limit = limit | ||||
|         self.before = before | ||||
|         self.after = after | ||||
|         self.messages = asyncio.Queue() | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def fill_messages(self): | ||||
|         if self.limit > 0: | ||||
|             retrieve = self.limit if self.limit <= 100 else 100 | ||||
|             data = yield from self.client._logs_from(self.channel, retrieve, self.before, self.after) | ||||
|             if len(data): | ||||
|                 self.limit -= retrieve | ||||
|                 self.before = Object(id=data[-1]['id']) | ||||
|                 for element in data: | ||||
|                     yield from self.messages.put(Message(channel=self.channel, **element)) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def iterate(self): | ||||
|         if self.messages.empty(): | ||||
| @@ -73,3 +98,88 @@ class LogsFromIterator: | ||||
|                 # if we're still empty at this point... | ||||
|                 # we didn't get any new messages so stop looping | ||||
|                 raise StopAsyncIteration() | ||||
|  | ||||
| class LogsFromBeforeIterator(LogsFromIterator): | ||||
|     def __init__(self, client, channel, limit, before): | ||||
|         super().__init__(client, channel, limit) | ||||
|         self.before = before | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def fill_messages(self): | ||||
|         if self.limit > 0: | ||||
|             retrieve = self.limit if self.limit <= 100 else 100 | ||||
|  | ||||
|             data = yield from self.client._logs_from(self.channel, retrieve, before=self.before) | ||||
|             if len(data): | ||||
|                 self.limit -= retrieve | ||||
|                 self.before = Object(id=data[-1]['id']) | ||||
|                 for element in data: | ||||
|                     yield from self.messages.put(Message(channel=self.channel, **element)) | ||||
|  | ||||
| class LogsFromAfterIterator(LogsFromIterator): | ||||
|     """Iterator for retrieving "after" style responses. | ||||
|  | ||||
|     Recommended to use with reverse=True - this will return messages oldest to newest. | ||||
|     With reverse=False, you'll recieve messages 99-0, 199-100, etc.""" | ||||
|     def __init__(self, client, channel, limit, after, *, reverse=False): | ||||
|         super().__init__(client, channel, limit) | ||||
|         self.after = after | ||||
|         self.reverse = reverse | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def fill_messages(self): | ||||
|         if self.limit > 0: | ||||
|             retrieve = self.limit if self.limit <= 100 else 100 | ||||
|  | ||||
|             data = yield from self.client._logs_from(self.channel, retrieve, after=self.after) | ||||
|             if len(data): | ||||
|                 self.limit -= retrieve | ||||
|                 self.after = Object(id=data[0]['id']) | ||||
|                 for element in (data if not self.reverse else reversed(data)): | ||||
|                     yield from self.messages.put(Message(channel=self.channel, **element)) | ||||
|  | ||||
| class LogsFromBeforeAfterIterator(LogsFromIterator): | ||||
|     """Newest -> Oldest.""" | ||||
|     def __init__(self, client, channel, limit, before, after): | ||||
|         super().__init__(client, channel, limit) | ||||
|         self.before = before | ||||
|         self.after = after | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def fill_messages(self): | ||||
|         if self.limit > 0: | ||||
|             retrieve = self.limit if self.limit <= 100 else 100 | ||||
|  | ||||
|             data = yield from self.client._logs_from(self.channel, retrieve, before=self.before) | ||||
|             if len(data): | ||||
|                 self.limit -= retrieve | ||||
|                 self.before = Object(id=data[-1]['id']) | ||||
|                 # Only filter if the oldest message is not after our endpoint | ||||
|                 if int(data[-1]['id']) <= int(self.after.id): | ||||
|                     data = filter(lambda d: int(d['id']) > int(self.after.id), data) | ||||
|                 for element in data: | ||||
|                         yield from self.messages.put(Message(channel=self.channel, **element)) | ||||
|  | ||||
| class LogsFromBeforeAfterReversedIterator(LogsFromIterator): | ||||
|     """Oldest -> Newest.""" | ||||
|     def __init__(self, client, channel, limit, before, after): | ||||
|         super().__init__(client, channel, limit) | ||||
|         self.before = before | ||||
|         self.after = after | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def fill_messages(self): | ||||
|         if self.limit > 0: | ||||
|             retrieve = self.limit if self.limit <= 100 else 100 | ||||
|  | ||||
|             data = yield from self.client._logs_from(self.channel, retrieve, after=self.after) | ||||
|             if len(data): | ||||
|                 self.limit -= retrieve | ||||
|                 self.after = Object(id=data[0]['id']) | ||||
|                 # Only filter if the newest is not before our endpoint | ||||
|                 if int(data[0]['id']) >= int(self.before.id): | ||||
|                     data = filter(lambda d: int(d['id']) < int(self.before.id), reversed(data)) | ||||
|                 else: | ||||
|                     data = reversed(data) | ||||
|                 for element in data: | ||||
|                     yield from self.messages.put(Message(channel=self.channel, **element)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user