Refactor AsyncIter to use 3.6+ asynchronous generators

This commit is contained in:
Kaylynn Morgan
2022-02-20 13:58:13 +11:00
committed by GitHub
parent dc19c6c7d5
commit 588cda0996
8 changed files with 386 additions and 930 deletions

View File

@@ -25,11 +25,26 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import asyncio
import datetime
import logging
import signal
import sys
import traceback
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
Dict,
Generator,
List,
Optional,
Sequence,
TYPE_CHECKING,
Tuple,
TypeVar,
Union
)
import aiohttp
@@ -51,11 +66,10 @@ from .voice_client import VoiceClient
from .http import HTTPClient
from .state import ConnectionState
from . import utils
from .utils import MISSING
from .utils import MISSING, time_snowflake
from .object import Object
from .backoff import ExponentialBackoff
from .webhook import Webhook
from .iterators import GuildIterator
from .appinfo import AppInfo
from .ui.view import View
from .stage_instance import StageInstance
@@ -63,6 +77,7 @@ from .threads import Thread
from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory
if TYPE_CHECKING:
from .types.guild import Guild as GuildPayload
from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake
from .channel import DMChannel
from .message import Message
@@ -1120,14 +1135,14 @@ class Client:
# Guild stuff
def fetch_guilds(
async def fetch_guilds(
self,
*,
limit: Optional[int] = 100,
before: SnowflakeTime = None,
after: SnowflakeTime = None
) -> GuildIterator:
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
before: Optional[SnowflakeTime] = None,
after: Optional[SnowflakeTime] = None,
) -> AsyncIterator[Guild]:
"""Retrieves an :term:`asynchronous iterator` that enables receiving your guilds.
.. note::
@@ -1148,7 +1163,7 @@ class Client:
Flattening into a list ::
guilds = await client.fetch_guilds(limit=150).flatten()
guilds = [guild async for guild in client.fetch_guilds(limit=150)]
# guilds is now a list of Guild...
All parameters are optional.
@@ -1179,7 +1194,60 @@ class Client:
:class:`.Guild`
The guild with the guild data parsed.
"""
return GuildIterator(self, limit=limit, before=before, after=after)
async def _before_strategy(retrieve, before, limit):
before_id = before.id if before else None
data = await self.http.get_guilds(retrieve, before=before_id)
if data:
if limit is not None:
limit -= len(data)
before = Object(id=int(data[-1]['id']))
return data, before, limit
async def _after_strategy(retrieve, after, limit):
after_id = after.id if after else None
data = await self.http.get_guilds(retrieve, after=after_id)
if data:
if limit is not None:
limit -= len(data)
after = Object(id=int(data[0]['id']))
return data, after, limit
if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False))
if isinstance(after, datetime.datetime):
after = Object(id=time_snowflake(after, high=True))
predicate = None
strategy, state = _before_strategy, before
if before and after:
predicate = lambda m: int(m['id']) > after.id # type: ignore
elif after:
strategy, state = _after_strategy, after
while True:
retrieve = min(100 if limit is None else limit, 100)
if retrieve < 1:
return
data, state, limit = await strategy(retrieve, state, limit)
# Terminate loop on next iteration; there's no data left after this
if len(data) < 100:
limit = 0
if predicate:
data = filter(predicate, data)
for raw_guild in data:
yield Guild(state=self._connection, data=raw_guild)
async def fetch_template(self, code: Union[Template, str]) -> Template:
"""|coro|