logs_from improvements for after param.
- Properly support 'after' alone - Properly support both 'before' and 'after' - Add optional 'reverse' parameter to sort messages oldest->newest to 1) provide a sorted result set for 'after' 2) give flexibility when using both 'before' and 'after'
This commit is contained in:
		
				
					committed by
					
						 Rapptz
						Rapptz
					
				
			
			
				
	
			
			
			
						parent
						
							8e5347f4ed
						
					
				
				
					commit
					492c9afffb
				
			| @@ -1213,15 +1213,9 @@ class Client: | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         if before: |         if before: | ||||||
|             if isinstance(before, datetime.datetime): |             params['before'] = before.id | ||||||
|                 params['before'] = utils.time_snowflake(before, high=False) |  | ||||||
|             else: |  | ||||||
|                 params['before'] = before.id |  | ||||||
|         if after: |         if after: | ||||||
|             if isinstance(after, datetime.datetime): |             params['after'] = after.id | ||||||
|                 params['after'] = utils.time_snowflake(after, high=True) |  | ||||||
|             else: |  | ||||||
|                 params['after'] = after.id |  | ||||||
|  |  | ||||||
|         response = yield from self.session.get(url, params=params, headers=self.headers) |         response = yield from self.session.get(url, params=params, headers=self.headers) | ||||||
|         log.debug(request_logging_format.format(method='GET', response=response)) |         log.debug(request_logging_format.format(method='GET', response=response)) | ||||||
| @@ -1230,11 +1224,21 @@ class Client: | |||||||
|         return messages |         return messages | ||||||
|  |  | ||||||
|     if PY35: |     if PY35: | ||||||
|         def logs_from(self, channel, limit=100, *, before=None, after=None): |         def logs_from(self, channel, limit=100, *, before=None, after=None, reverse=False): | ||||||
|             return LogsFromIterator(self, channel, limit, before, after) |             if isinstance(before, datetime.datetime): | ||||||
|  |                 before = Object(utils.time_snowflake(before, high=False)) | ||||||
|  |             if isinstance(after, datetime.datetime): | ||||||
|  |                 after = Object(utils.time_snowflake(after, high=True)) | ||||||
|  |  | ||||||
|  |             return LogsFromIterator.create(self, channel, limit, before=before, after=after, reverse=reverse) | ||||||
|     else: |     else: | ||||||
|         @asyncio.coroutine |         @asyncio.coroutine | ||||||
|         def logs_from(self, channel, limit=100, *, before=None, after=None): |         def logs_from(self, channel, limit=100, *, before=None, after=None): | ||||||
|  |             if isinstance(before, datetime.datetime): | ||||||
|  |                 before = Object(utils.time_snowflake(before, high=False)) | ||||||
|  |             if isinstance(after, datetime.datetime): | ||||||
|  |                 after = Object(utils.time_snowflake(after, high=True)) | ||||||
|  |  | ||||||
|             def generator(data): |             def generator(data): | ||||||
|                 for message in data: |                 for message in data: | ||||||
|                     yield Message(channel=channel, **message) |                     yield Message(channel=channel, **message) | ||||||
|   | |||||||
| @@ -33,25 +33,50 @@ from .object import Object | |||||||
| PY35 = sys.version_info >= (3, 5) | PY35 = sys.version_info >= (3, 5) | ||||||
|  |  | ||||||
| class LogsFromIterator: | class LogsFromIterator: | ||||||
|     def __init__(self, client, channel, limit, before, after): |     @staticmethod | ||||||
|  |     def create(client, channel, limit, *, before=None, after=None, reverse=False): | ||||||
|  |         """Create a proper iterator depending on parameters. | ||||||
|  |  | ||||||
|  |         The messages endpoint has two behaviors: | ||||||
|  |             If `before` is specified, it returns the `limit` newest messages before `before`, sorted with newest first. | ||||||
|  |               - Fill strategy - update 'before' to oldest message | ||||||
|  |             If `after` is specified, it returns the `limit` oldest messages after `after`, sorted with newest first. | ||||||
|  |               - Fill strategy - update 'after' to newest message | ||||||
|  |               - If messages are not reversed, they will be out of order (99-0, 199-100, so on) | ||||||
|  |  | ||||||
|  |         A note that if both before and after are specified, before is ignored by the messages endpoint. | ||||||
|  |  | ||||||
|  |         Parameters | ||||||
|  |         ----------- | ||||||
|  |         client : class:`Client` | ||||||
|  |         channel : class:`Channel` | ||||||
|  |             Channel from which to request logs | ||||||
|  |         limit : int | ||||||
|  |             Maximum number of messages to retrieve | ||||||
|  |         before : :class:`Message` or id-like | ||||||
|  |             Message before which all messages must be. | ||||||
|  |         after : :class:`Message` or id-like | ||||||
|  |             Message after which all messages must be. | ||||||
|  |         reverse : bool | ||||||
|  |             If set to true, return messages in oldest->newest order. Recommended when using with "after" queries, | ||||||
|  |             otherwise messages will be out of order. Defaults to False for backwards compatability. | ||||||
|  |         """ | ||||||
|  |         if before and after: | ||||||
|  |             if reverse: | ||||||
|  |                 return LogsFromBeforeAfterReversedIterator(client, channel, limit, before, after) | ||||||
|  |             else: | ||||||
|  |                 return LogsFromBeforeAfterIterator(client, channel, limit, before, after) | ||||||
|  |         elif after: | ||||||
|  |             return LogsFromAfterIterator(client, channel, limit, after, reverse=reverse) | ||||||
|  |         else: | ||||||
|  |             return LogsFromBeforeIterator(client, channel, limit, before) | ||||||
|  |  | ||||||
|  |     def __init__(self, client, channel, limit): | ||||||
|         self.client = client |         self.client = client | ||||||
|         self.channel = channel |         self.channel = channel | ||||||
|         self.limit = limit |         self.limit = limit | ||||||
|         self.before = before |  | ||||||
|         self.after = after |  | ||||||
|         self.messages = asyncio.Queue() |         self.messages = asyncio.Queue() | ||||||
|  |  | ||||||
|     @asyncio.coroutine |  | ||||||
|     def fill_messages(self): |  | ||||||
|         if self.limit > 0: |  | ||||||
|             retrieve = self.limit if self.limit <= 100 else 100 |  | ||||||
|             data = yield from self.client._logs_from(self.channel, retrieve, self.before, self.after) |  | ||||||
|             if len(data): |  | ||||||
|                 self.limit -= retrieve |  | ||||||
|                 self.before = Object(id=data[-1]['id']) |  | ||||||
|                 for element in data: |  | ||||||
|                     yield from self.messages.put(Message(channel=self.channel, **element)) |  | ||||||
|  |  | ||||||
|     @asyncio.coroutine |     @asyncio.coroutine | ||||||
|     def iterate(self): |     def iterate(self): | ||||||
|         if self.messages.empty(): |         if self.messages.empty(): | ||||||
| @@ -73,3 +98,88 @@ class LogsFromIterator: | |||||||
|                 # if we're still empty at this point... |                 # if we're still empty at this point... | ||||||
|                 # we didn't get any new messages so stop looping |                 # we didn't get any new messages so stop looping | ||||||
|                 raise StopAsyncIteration() |                 raise StopAsyncIteration() | ||||||
|  |  | ||||||
|  | class LogsFromBeforeIterator(LogsFromIterator): | ||||||
|  |     def __init__(self, client, channel, limit, before): | ||||||
|  |         super().__init__(client, channel, limit) | ||||||
|  |         self.before = before | ||||||
|  |  | ||||||
|  |     @asyncio.coroutine | ||||||
|  |     def fill_messages(self): | ||||||
|  |         if self.limit > 0: | ||||||
|  |             retrieve = self.limit if self.limit <= 100 else 100 | ||||||
|  |  | ||||||
|  |             data = yield from self.client._logs_from(self.channel, retrieve, before=self.before) | ||||||
|  |             if len(data): | ||||||
|  |                 self.limit -= retrieve | ||||||
|  |                 self.before = Object(id=data[-1]['id']) | ||||||
|  |                 for element in data: | ||||||
|  |                     yield from self.messages.put(Message(channel=self.channel, **element)) | ||||||
|  |  | ||||||
|  | class LogsFromAfterIterator(LogsFromIterator): | ||||||
|  |     """Iterator for retrieving "after" style responses. | ||||||
|  |  | ||||||
|  |     Recommended to use with reverse=True - this will return messages oldest to newest. | ||||||
|  |     With reverse=False, you'll recieve messages 99-0, 199-100, etc.""" | ||||||
|  |     def __init__(self, client, channel, limit, after, *, reverse=False): | ||||||
|  |         super().__init__(client, channel, limit) | ||||||
|  |         self.after = after | ||||||
|  |         self.reverse = reverse | ||||||
|  |  | ||||||
|  |     @asyncio.coroutine | ||||||
|  |     def fill_messages(self): | ||||||
|  |         if self.limit > 0: | ||||||
|  |             retrieve = self.limit if self.limit <= 100 else 100 | ||||||
|  |  | ||||||
|  |             data = yield from self.client._logs_from(self.channel, retrieve, after=self.after) | ||||||
|  |             if len(data): | ||||||
|  |                 self.limit -= retrieve | ||||||
|  |                 self.after = Object(id=data[0]['id']) | ||||||
|  |                 for element in (data if not self.reverse else reversed(data)): | ||||||
|  |                     yield from self.messages.put(Message(channel=self.channel, **element)) | ||||||
|  |  | ||||||
|  | class LogsFromBeforeAfterIterator(LogsFromIterator): | ||||||
|  |     """Newest -> Oldest.""" | ||||||
|  |     def __init__(self, client, channel, limit, before, after): | ||||||
|  |         super().__init__(client, channel, limit) | ||||||
|  |         self.before = before | ||||||
|  |         self.after = after | ||||||
|  |  | ||||||
|  |     @asyncio.coroutine | ||||||
|  |     def fill_messages(self): | ||||||
|  |         if self.limit > 0: | ||||||
|  |             retrieve = self.limit if self.limit <= 100 else 100 | ||||||
|  |  | ||||||
|  |             data = yield from self.client._logs_from(self.channel, retrieve, before=self.before) | ||||||
|  |             if len(data): | ||||||
|  |                 self.limit -= retrieve | ||||||
|  |                 self.before = Object(id=data[-1]['id']) | ||||||
|  |                 # Only filter if the oldest message is not after our endpoint | ||||||
|  |                 if int(data[-1]['id']) <= int(self.after.id): | ||||||
|  |                     data = filter(lambda d: int(d['id']) > int(self.after.id), data) | ||||||
|  |                 for element in data: | ||||||
|  |                         yield from self.messages.put(Message(channel=self.channel, **element)) | ||||||
|  |  | ||||||
|  | class LogsFromBeforeAfterReversedIterator(LogsFromIterator): | ||||||
|  |     """Oldest -> Newest.""" | ||||||
|  |     def __init__(self, client, channel, limit, before, after): | ||||||
|  |         super().__init__(client, channel, limit) | ||||||
|  |         self.before = before | ||||||
|  |         self.after = after | ||||||
|  |  | ||||||
|  |     @asyncio.coroutine | ||||||
|  |     def fill_messages(self): | ||||||
|  |         if self.limit > 0: | ||||||
|  |             retrieve = self.limit if self.limit <= 100 else 100 | ||||||
|  |  | ||||||
|  |             data = yield from self.client._logs_from(self.channel, retrieve, after=self.after) | ||||||
|  |             if len(data): | ||||||
|  |                 self.limit -= retrieve | ||||||
|  |                 self.after = Object(id=data[0]['id']) | ||||||
|  |                 # Only filter if the newest is not before our endpoint | ||||||
|  |                 if int(data[0]['id']) >= int(self.before.id): | ||||||
|  |                     data = filter(lambda d: int(d['id']) < int(self.before.id), reversed(data)) | ||||||
|  |                 else: | ||||||
|  |                     data = reversed(data) | ||||||
|  |                 for element in data: | ||||||
|  |                     yield from self.messages.put(Message(channel=self.channel, **element)) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user