Fix inaccuracies with AsyncIterator
typings
This commit is contained in:
parent
87e64dff06
commit
f8bea3bb05
@ -26,7 +26,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator, Coroutine
|
||||
from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator
|
||||
|
||||
from .errors import NoMoreItems
|
||||
from .utils import time_snowflake, maybe_coroutine
|
||||
@ -50,16 +50,18 @@ if TYPE_CHECKING:
|
||||
|
||||
T = TypeVar('T')
|
||||
OT = TypeVar('OT')
|
||||
_Func = Callable[[T], Union[OT, Coroutine[Any, Any, OT]]]
|
||||
_Predicate = Callable[[T], Union[T, Coroutine[Any, Any, T]]]
|
||||
_Func = Callable[[T], Union[OT, Awaitable[OT]]]
|
||||
|
||||
OLDEST_OBJECT = Object(id=0)
|
||||
|
||||
class _AsyncIterator(AsyncIterator[T]):
|
||||
__slots__ = ()
|
||||
|
||||
def get(self, **attrs: Any) -> Optional[T]:
|
||||
def predicate(elem):
|
||||
async def next(self) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
def get(self, **attrs: Any) -> Awaitable[Optional[T]]:
|
||||
def predicate(elem: T):
|
||||
for attr, val in attrs.items():
|
||||
nested = attr.split('__')
|
||||
obj = elem
|
||||
@ -72,7 +74,7 @@ class _AsyncIterator(AsyncIterator[T]):
|
||||
|
||||
return self.find(predicate)
|
||||
|
||||
async def find(self, predicate: _Predicate[T]) -> Optional[T]:
|
||||
async def find(self, predicate: _Func[T, bool]) -> Optional[T]:
|
||||
while True:
|
||||
try:
|
||||
elem = await self.next()
|
||||
@ -91,7 +93,7 @@ class _AsyncIterator(AsyncIterator[T]):
|
||||
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
|
||||
return _MappedAsyncIterator(self, func)
|
||||
|
||||
def filter(self, predicate: _Predicate[T]) -> _FilteredAsyncIterator[T]:
|
||||
def filter(self, predicate: _Func[T, bool]) -> _FilteredAsyncIterator[T]:
|
||||
return _FilteredAsyncIterator(self, predicate)
|
||||
|
||||
async def flatten(self) -> List[T]:
|
||||
@ -106,13 +108,13 @@ class _AsyncIterator(AsyncIterator[T]):
|
||||
def _identity(x):
|
||||
return x
|
||||
|
||||
class _ChunkedAsyncIterator(_AsyncIterator[T]):
|
||||
class _ChunkedAsyncIterator(_AsyncIterator[List[T]]):
|
||||
def __init__(self, iterator, max_size):
|
||||
self.iterator = iterator
|
||||
self.max_size = max_size
|
||||
|
||||
async def next(self) -> T:
|
||||
ret = []
|
||||
async def next(self) -> List[T]:
|
||||
ret: List[T] = []
|
||||
n = 0
|
||||
while n < self.max_size:
|
||||
try:
|
||||
@ -168,7 +170,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
|
||||
self.channel_id = message.channel.id
|
||||
self.users = asyncio.Queue()
|
||||
|
||||
async def next(self) -> T:
|
||||
async def next(self) -> Union[User, Member]:
|
||||
if self.users.empty():
|
||||
await self.fill_users()
|
||||
|
||||
@ -289,7 +291,7 @@ class HistoryIterator(_AsyncIterator['Message']):
|
||||
if (self.after and self.after != OLDEST_OBJECT):
|
||||
self._filter = lambda m: int(m['id']) > self.after.id
|
||||
|
||||
async def next(self) -> T:
|
||||
async def next(self) -> Message:
|
||||
if self.messages.empty():
|
||||
await self.fill_messages()
|
||||
|
||||
@ -422,7 +424,7 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
|
||||
self.after = Object(id=int(entries[0]['id']))
|
||||
return data.get('users', []), entries
|
||||
|
||||
async def next(self) -> T:
|
||||
async def next(self) -> AuditLogEntry:
|
||||
if self.entries.empty():
|
||||
await self._fill()
|
||||
|
||||
@ -519,7 +521,7 @@ class GuildIterator(_AsyncIterator['Guild']):
|
||||
else:
|
||||
self._retrieve_guilds = self._retrieve_guilds_before_strategy
|
||||
|
||||
async def next(self) -> T:
|
||||
async def next(self) -> Guild:
|
||||
if self.guilds.empty():
|
||||
await self.fill_guilds()
|
||||
|
||||
@ -591,7 +593,7 @@ class MemberIterator(_AsyncIterator['Member']):
|
||||
self.get_members = self.state.http.get_members
|
||||
self.members = asyncio.Queue()
|
||||
|
||||
async def next(self) -> T:
|
||||
async def next(self) -> Member:
|
||||
if self.members.empty():
|
||||
await self.fill_members()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user