Add around parameter to LogsFromIterator.
This commit is contained in:
parent
e2667a6f66
commit
158ac6bb50
@ -978,7 +978,7 @@ class Client:
|
|||||||
yield from self.http.delete_messages(channel.id, message_ids, guild_id)
|
yield from self.http.delete_messages(channel.id, message_ids, guild_id)
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def purge_from(self, channel, *, limit=100, check=None, before=None, after=None):
|
def purge_from(self, channel, *, limit=100, check=None, before=None, after=None, around=None):
|
||||||
"""|coro|
|
"""|coro|
|
||||||
|
|
||||||
Purges a list of messages that meet the criteria given by the predicate
|
Purges a list of messages that meet the criteria given by the predicate
|
||||||
@ -1007,6 +1007,9 @@ class Client:
|
|||||||
after : :class:`Message` or `datetime`
|
after : :class:`Message` or `datetime`
|
||||||
The message or date after which all deleted messages must be.
|
The message or date after which all deleted messages must be.
|
||||||
If a date is provided it must be a timezone-naive datetime representing UTC time.
|
If a date is provided it must be a timezone-naive datetime representing UTC time.
|
||||||
|
around : :class:`Message` or `datetime`
|
||||||
|
The message or date around which all deleted messages must be.
|
||||||
|
If a date is provided it must be a timezone-naive datetime representing UTC time.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
-------
|
-------
|
||||||
@ -1040,8 +1043,10 @@ class Client:
|
|||||||
before = Object(utils.time_snowflake(before, high=False))
|
before = Object(utils.time_snowflake(before, high=False))
|
||||||
if isinstance(after, datetime.datetime):
|
if isinstance(after, datetime.datetime):
|
||||||
after = Object(utils.time_snowflake(after, high=True))
|
after = Object(utils.time_snowflake(after, high=True))
|
||||||
|
if isinstance(around, datetime.datetime):
|
||||||
|
around = Object(utils.time_snowflake(around, high=True))
|
||||||
|
|
||||||
iterator = LogsFromIterator(self, channel, limit, before=before, after=after)
|
iterator = LogsFromIterator(self, channel, limit, before=before, after=after, around=around)
|
||||||
ret = []
|
ret = []
|
||||||
count = 0
|
count = 0
|
||||||
|
|
||||||
@ -1209,7 +1214,7 @@ class Client:
|
|||||||
data = yield from self.http.pins_from(channel.id)
|
data = yield from self.http.pins_from(channel.id)
|
||||||
return [Message(channel=channel, **m) for m in data]
|
return [Message(channel=channel, **m) for m in data]
|
||||||
|
|
||||||
def _logs_from(self, channel, limit=100, before=None, after=None):
|
def _logs_from(self, channel, limit=100, before=None, after=None, around=None):
|
||||||
"""|coro|
|
"""|coro|
|
||||||
|
|
||||||
This coroutine returns a generator that obtains logs from a specified channel.
|
This coroutine returns a generator that obtains logs from a specified channel.
|
||||||
@ -1226,6 +1231,9 @@ class Client:
|
|||||||
after : :class:`Message` or `datetime`
|
after : :class:`Message` or `datetime`
|
||||||
The message or date after which all returned messages must be.
|
The message or date after which all returned messages must be.
|
||||||
If a date is provided it must be a timezone-naive datetime representing UTC time.
|
If a date is provided it must be a timezone-naive datetime representing UTC time.
|
||||||
|
around : :class:`Message` or `datetime`
|
||||||
|
The message or date around which all returned messages must be.
|
||||||
|
If a date is provided it must be a timezone-naive datetime representing UTC time.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
@ -1261,17 +1269,20 @@ class Client:
|
|||||||
"""
|
"""
|
||||||
before = getattr(before, 'id', None)
|
before = getattr(before, 'id', None)
|
||||||
after = getattr(after, 'id', None)
|
after = getattr(after, 'id', None)
|
||||||
|
around = getattr(around, 'id', None)
|
||||||
|
|
||||||
return self.http.logs_from(channel.id, limit, before=before, after=after)
|
return self.http.logs_from(channel.id, limit, before=before, after=after, around=around)
|
||||||
|
|
||||||
if PY35:
|
if PY35:
|
||||||
def logs_from(self, channel, limit=100, *, before=None, after=None, reverse=False):
|
def logs_from(self, channel, limit=100, *, before=None, after=None, around=None, reverse=False):
|
||||||
if isinstance(before, datetime.datetime):
|
if isinstance(before, datetime.datetime):
|
||||||
before = Object(utils.time_snowflake(before, high=False))
|
before = Object(utils.time_snowflake(before, high=False))
|
||||||
if isinstance(after, datetime.datetime):
|
if isinstance(after, datetime.datetime):
|
||||||
after = Object(utils.time_snowflake(after, high=True))
|
after = Object(utils.time_snowflake(after, high=True))
|
||||||
|
if isinstance(around, datetime.datetime):
|
||||||
|
around = Object(utils.time_snowflake(around))
|
||||||
|
|
||||||
return LogsFromIterator(self, channel, limit, before=before, after=after, reverse=reverse)
|
return LogsFromIterator(self, channel, limit, before=before, after=after, around=around, reverse=reverse)
|
||||||
else:
|
else:
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def logs_from(self, channel, limit=100, *, before=None, after=None):
|
def logs_from(self, channel, limit=100, *, before=None, after=None):
|
||||||
|
@ -265,7 +265,7 @@ class HTTPClient:
|
|||||||
url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id)
|
url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id)
|
||||||
return self.get(url, bucket=_func_())
|
return self.get(url, bucket=_func_())
|
||||||
|
|
||||||
def logs_from(self, channel_id, limit, before=None, after=None):
|
def logs_from(self, channel_id, limit, before=None, after=None, around=None):
|
||||||
url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id)
|
url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id)
|
||||||
params = {
|
params = {
|
||||||
'limit': limit
|
'limit': limit
|
||||||
@ -275,6 +275,8 @@ class HTTPClient:
|
|||||||
params['before'] = before
|
params['before'] = before
|
||||||
if after:
|
if after:
|
||||||
params['after'] = after
|
params['after'] = after
|
||||||
|
if around:
|
||||||
|
params['around'] = around
|
||||||
|
|
||||||
return self.get(url, params=params, bucket=_func_())
|
return self.get(url, params=params, bucket=_func_())
|
||||||
|
|
||||||
|
@ -60,6 +60,9 @@ class LogsFromIterator:
|
|||||||
Message before which all messages must be.
|
Message before which all messages must be.
|
||||||
after : :class:`Message` or id-like
|
after : :class:`Message` or id-like
|
||||||
Message after which all messages must be.
|
Message after which all messages must be.
|
||||||
|
around : :class:`Message` or id-like
|
||||||
|
Message around which all messages must be. Limit max 101. Note that if
|
||||||
|
limit is an even number, this will return at most limit+1 messages.
|
||||||
reverse : bool
|
reverse : bool
|
||||||
If set to true, return messages in oldest->newest order. Recommended
|
If set to true, return messages in oldest->newest order. Recommended
|
||||||
when using with "after" queries with limit over 100, otherwise messages
|
when using with "after" queries with limit over 100, otherwise messages
|
||||||
@ -67,17 +70,33 @@ class LogsFromIterator:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, client, channel, limit,
|
def __init__(self, client, channel, limit,
|
||||||
before=None, after=None, reverse=False):
|
before=None, after=None, around=None, reverse=False):
|
||||||
self.client = client
|
self.client = client
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
self.limit = limit
|
self.limit = limit
|
||||||
self.before = before
|
self.before = before
|
||||||
self.after = after
|
self.after = after
|
||||||
|
self.around = around
|
||||||
self.reverse = reverse
|
self.reverse = reverse
|
||||||
self._filter = None # message dict -> bool
|
self._filter = None # message dict -> bool
|
||||||
self.messages = asyncio.Queue()
|
self.messages = asyncio.Queue()
|
||||||
|
|
||||||
|
if self.around:
|
||||||
|
if self.limit > 101:
|
||||||
|
raise ValueError("LogsFrom max limit 101 when specifying around parameter")
|
||||||
|
elif self.limit == 101:
|
||||||
|
self.limit = 100 # Thanks discord
|
||||||
|
elif self.limit == 1:
|
||||||
|
raise ValueError("Use get_message.")
|
||||||
|
|
||||||
|
self._retrieve_messages = self._retrieve_messages_around_strategy
|
||||||
if self.before and self.after:
|
if self.before and self.after:
|
||||||
|
self._filter = lambda m: int(self.after.id) < int(m['id']) < int(self.before.id)
|
||||||
|
elif self.before:
|
||||||
|
self._filter = lambda m: int(m['id']) < int(self.before.id)
|
||||||
|
elif self.after:
|
||||||
|
self._filter = lambda m: int(self.after.id) < int(m['id'])
|
||||||
|
elif self.before and self.after:
|
||||||
if self.reverse:
|
if self.reverse:
|
||||||
self._retrieve_messages = self._retrieve_messages_after_strategy
|
self._retrieve_messages = self._retrieve_messages_after_strategy
|
||||||
self._filter = lambda m: int(m['id']) < int(self.before.id)
|
self._filter = lambda m: int(m['id']) < int(self.before.id)
|
||||||
@ -131,6 +150,15 @@ class LogsFromIterator:
|
|||||||
self.after = Object(id=data[0]['id'])
|
self.after = Object(id=data[0]['id'])
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def _retrieve_messages_around_strategy(self, retrieve):
|
||||||
|
"""Retrieve messages using around parameter."""
|
||||||
|
if self.around:
|
||||||
|
data = yield from self.client._logs_from(self.channel, retrieve, around=self.around)
|
||||||
|
self.around = None
|
||||||
|
return data
|
||||||
|
return []
|
||||||
|
|
||||||
if PY35:
|
if PY35:
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def __aiter__(self):
|
def __aiter__(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user