Refactor AsyncIter to use 3.6+ asynchronous generators

This commit is contained in:
Kaylynn Morgan
2022-02-20 13:58:13 +11:00
committed by GitHub
parent dc19c6c7d5
commit 588cda0996
8 changed files with 386 additions and 930 deletions

View File

@@ -26,8 +26,10 @@ from __future__ import annotations
import copy
import asyncio
from datetime import datetime
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
List,
@@ -42,7 +44,7 @@ from typing import (
runtime_checkable,
)
from .iterators import HistoryIterator
from .object import OLDEST_OBJECT, Object
from .context_managers import Typing
from .enums import ChannelType
from .errors import InvalidArgument, ClientException
@@ -68,8 +70,6 @@ __all__ = (
T = TypeVar('T', bound=VoiceProtocol)
if TYPE_CHECKING:
from datetime import datetime
from .client import Client
from .user import ClientUser
from .asset import Asset
@@ -1465,7 +1465,7 @@ class Messageable:
data = await state.http.pins_from(channel.id)
return [state.create_message(channel=channel, data=m) for m in data]
def history(
async def history(
self,
*,
limit: Optional[int] = 100,
@@ -1473,8 +1473,8 @@ class Messageable:
after: Optional[SnowflakeTime] = None,
around: Optional[SnowflakeTime] = None,
oldest_first: Optional[bool] = None,
) -> HistoryIterator:
"""Returns an :class:`~discord.AsyncIterator` that enables receiving the destination's message history.
) -> AsyncIterator[Message]:
"""Returns an :term:`asynchronous iterator` that enables receiving the destination's message history.
You must have :attr:`~discord.Permissions.read_message_history` permissions to use this.
@@ -1490,7 +1490,7 @@ class Messageable:
Flattening into a list: ::
messages = await channel.history(limit=123).flatten()
messages = [message async for message in channel.history(limit=123)]
# messages is now a list of Message...
All parameters are optional.
@@ -1531,7 +1531,101 @@ class Messageable:
:class:`~discord.Message`
The message with the message data parsed.
"""
return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first)
async def _around_strategy(retrieve, around, limit):
if not around:
return []
around_id = around.id if around else None
data = await self._state.http.logs_from(channel.id, retrieve, around=around_id)
return data, None, limit
async def _after_strategy(retrieve, after, limit):
after_id = after.id if after else None
data = await self._state.http.logs_from(channel.id, retrieve, after=after_id)
if data:
if limit is not None:
limit -= len(data)
after = Object(id=int(data[0]['id']))
return data, after, limit
async def _before_strategy(retrieve, before, limit):
before_id = before.id if before else None
data = await self._state.http.logs_from(channel.id, retrieve, before=before_id)
if data:
if limit is not None:
limit -= len(data)
before = Object(id=int(data[-1]['id']))
return data, before, limit
if isinstance(before, datetime):
before = Object(id=utils.time_snowflake(before, high=False))
if isinstance(after, datetime):
after = Object(id=utils.time_snowflake(after, high=True))
if isinstance(around, datetime):
around = Object(id=utils.time_snowflake(around))
if oldest_first is None:
reverse = after is not None
else:
reverse = oldest_first
after = after or OLDEST_OBJECT
predicate = None
if around:
if limit is None:
raise ValueError('history does not support around with limit=None')
if limit > 101:
raise ValueError("history max limit 101 when specifying around parameter")
# Strange Discord quirk
limit = 100 if limit == 101 else limit
strategy, state = _around_strategy, around
if before and after:
predicate = lambda m: after.id < int(m['id']) < before.id
elif before:
predicate = lambda m: int(m['id']) < before.id
elif after:
predicate = lambda m: after.id < int(m['id'])
elif reverse:
strategy, state = _after_strategy, after
if before:
predicate = lambda m: int(m['id']) < before.id
else:
strategy, state = _before_strategy, before
if after and after != OLDEST_OBJECT:
predicate = lambda m: int(m['id']) > after.id
channel = await self._get_channel()
while True:
retrieve = min(100 if limit is None else limit, 100)
if retrieve < 1:
return
data, state, limit = await strategy(retrieve, state, limit)
# Terminate loop on next iteration; there's no data left after this
if len(data) < 100:
limit = 0
if reverse:
data = reversed(data)
if predicate:
data = filter(predicate, data)
for raw_message in data:
yield self._state.create_message(channel=channel, data=raw_message)
class Connectable(Protocol):