diff --git a/discord/client.py b/discord/client.py index dbd40e397..c3594638f 100644 --- a/discord/client.py +++ b/discord/client.py @@ -978,7 +978,7 @@ class Client: yield from self.http.delete_messages(channel.id, message_ids, guild_id) @asyncio.coroutine - def purge_from(self, channel, *, limit=100, check=None, before=None, after=None): + def purge_from(self, channel, *, limit=100, check=None, before=None, after=None, around=None): """|coro| Purges a list of messages that meet the criteria given by the predicate @@ -1007,6 +1007,9 @@ class Client: after : :class:`Message` or `datetime` The message or date after which all deleted messages must be. If a date is provided it must be a timezone-naive datetime representing UTC time. + around : :class:`Message` or `datetime` + The message or date around which all deleted messages must be. + If a date is provided it must be a timezone-naive datetime representing UTC time. Raises ------- @@ -1040,8 +1043,10 @@ class Client: before = Object(utils.time_snowflake(before, high=False)) if isinstance(after, datetime.datetime): after = Object(utils.time_snowflake(after, high=True)) + if isinstance(around, datetime.datetime): + around = Object(utils.time_snowflake(around, high=True)) - iterator = LogsFromIterator(self, channel, limit, before=before, after=after) + iterator = LogsFromIterator(self, channel, limit, before=before, after=after, around=around) ret = [] count = 0 @@ -1209,7 +1214,7 @@ class Client: data = yield from self.http.pins_from(channel.id) return [Message(channel=channel, **m) for m in data] - def _logs_from(self, channel, limit=100, before=None, after=None): + def _logs_from(self, channel, limit=100, before=None, after=None, around=None): """|coro| This coroutine returns a generator that obtains logs from a specified channel. @@ -1226,6 +1231,9 @@ class Client: after : :class:`Message` or `datetime` The message or date after which all returned messages must be. If a date is provided it must be a timezone-naive datetime representing UTC time. + around : :class:`Message` or `datetime` + The message or date around which all returned messages must be. + If a date is provided it must be a timezone-naive datetime representing UTC time. Raises ------ @@ -1261,17 +1269,20 @@ class Client: """ before = getattr(before, 'id', None) after = getattr(after, 'id', None) + around = getattr(around, 'id', None) - return self.http.logs_from(channel.id, limit, before=before, after=after) + return self.http.logs_from(channel.id, limit, before=before, after=after, around=around) if PY35: - def logs_from(self, channel, limit=100, *, before=None, after=None, reverse=False): + def logs_from(self, channel, limit=100, *, before=None, after=None, around=None, reverse=False): if isinstance(before, datetime.datetime): before = Object(utils.time_snowflake(before, high=False)) if isinstance(after, datetime.datetime): after = Object(utils.time_snowflake(after, high=True)) + if isinstance(around, datetime.datetime): + around = Object(utils.time_snowflake(around)) - return LogsFromIterator(self, channel, limit, before=before, after=after, reverse=reverse) + return LogsFromIterator(self, channel, limit, before=before, after=after, around=around, reverse=reverse) else: @asyncio.coroutine def logs_from(self, channel, limit=100, *, before=None, after=None): diff --git a/discord/http.py b/discord/http.py index ada71f7c1..26824fcb2 100644 --- a/discord/http.py +++ b/discord/http.py @@ -265,7 +265,7 @@ class HTTPClient: url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id) return self.get(url, bucket=_func_()) - def logs_from(self, channel_id, limit, before=None, after=None): + def logs_from(self, channel_id, limit, before=None, after=None, around=None): url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id) params = { 'limit': limit @@ -275,6 +275,8 @@ class HTTPClient: params['before'] = before if after: params['after'] = after + if around: + params['around'] = around return self.get(url, params=params, bucket=_func_()) diff --git a/discord/iterators.py b/discord/iterators.py index 5b3dd2022..fbf1a72c6 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -60,6 +60,9 @@ class LogsFromIterator: Message before which all messages must be. after : :class:`Message` or id-like Message after which all messages must be. + around : :class:`Message` or id-like + Message around which all messages must be. Limit max 101. Note that if + limit is an even number, this will return at most limit+1 messages. reverse : bool If set to true, return messages in oldest->newest order. Recommended when using with "after" queries with limit over 100, otherwise messages @@ -67,17 +70,33 @@ class LogsFromIterator: """ def __init__(self, client, channel, limit, - before=None, after=None, reverse=False): + before=None, after=None, around=None, reverse=False): self.client = client self.channel = channel self.limit = limit self.before = before self.after = after + self.around = around self.reverse = reverse self._filter = None # message dict -> bool self.messages = asyncio.Queue() - if self.before and self.after: + if self.around: + if self.limit > 101: + raise ValueError("LogsFrom max limit 101 when specifying around parameter") + elif self.limit == 101: + self.limit = 100 # Thanks discord + elif self.limit == 1: + raise ValueError("Use get_message.") + + self._retrieve_messages = self._retrieve_messages_around_strategy + if self.before and self.after: + self._filter = lambda m: int(self.after.id) < int(m['id']) < int(self.before.id) + elif self.before: + self._filter = lambda m: int(m['id']) < int(self.before.id) + elif self.after: + self._filter = lambda m: int(self.after.id) < int(m['id']) + elif self.before and self.after: if self.reverse: self._retrieve_messages = self._retrieve_messages_after_strategy self._filter = lambda m: int(m['id']) < int(self.before.id) @@ -131,6 +150,15 @@ class LogsFromIterator: self.after = Object(id=data[0]['id']) return data + @asyncio.coroutine + def _retrieve_messages_around_strategy(self, retrieve): + """Retrieve messages using around parameter.""" + if self.around: + data = yield from self.client._logs_from(self.channel, retrieve, around=self.around) + self.around = None + return data + return [] + if PY35: @asyncio.coroutine def __aiter__(self):