Add around parameter to LogsFromIterator.

This commit is contained in:
khazhyk
2016-10-16 17:41:11 -07:00
parent e2667a6f66
commit 158ac6bb50
3 changed files with 50 additions and 9 deletions

View File

@ -60,6 +60,9 @@ class LogsFromIterator:
Message before which all messages must be.
after : :class:`Message` or id-like
Message after which all messages must be.
around : :class:`Message` or id-like
Message around which all messages must be. Limit max 101. Note that if
limit is an even number, this will return at most limit+1 messages.
reverse : bool
If set to true, return messages in oldest->newest order. Recommended
when using with "after" queries with limit over 100, otherwise messages
@ -67,17 +70,33 @@ class LogsFromIterator:
"""
def __init__(self, client, channel, limit,
before=None, after=None, reverse=False):
before=None, after=None, around=None, reverse=False):
self.client = client
self.channel = channel
self.limit = limit
self.before = before
self.after = after
self.around = around
self.reverse = reverse
self._filter = None # message dict -> bool
self.messages = asyncio.Queue()
if self.before and self.after:
if self.around:
if self.limit > 101:
raise ValueError("LogsFrom max limit 101 when specifying around parameter")
elif self.limit == 101:
self.limit = 100 # Thanks discord
elif self.limit == 1:
raise ValueError("Use get_message.")
self._retrieve_messages = self._retrieve_messages_around_strategy
if self.before and self.after:
self._filter = lambda m: int(self.after.id) < int(m['id']) < int(self.before.id)
elif self.before:
self._filter = lambda m: int(m['id']) < int(self.before.id)
elif self.after:
self._filter = lambda m: int(self.after.id) < int(m['id'])
elif self.before and self.after:
if self.reverse:
self._retrieve_messages = self._retrieve_messages_after_strategy
self._filter = lambda m: int(m['id']) < int(self.before.id)
@ -131,6 +150,15 @@ class LogsFromIterator:
self.after = Object(id=data[0]['id'])
return data
@asyncio.coroutine
def _retrieve_messages_around_strategy(self, retrieve):
"""Retrieve messages using around parameter."""
if self.around:
data = yield from self.client._logs_from(self.channel, retrieve, around=self.around)
self.around = None
return data
return []
if PY35:
@asyncio.coroutine
def __aiter__(self):