mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-06-07 12:18:59 +00:00
Fix inaccuracies with AsyncIterator
typings
This commit is contained in:
parent
87e64dff06
commit
f8bea3bb05
@ -26,7 +26,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
from typing import TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator, Coroutine
|
from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator
|
||||||
|
|
||||||
from .errors import NoMoreItems
|
from .errors import NoMoreItems
|
||||||
from .utils import time_snowflake, maybe_coroutine
|
from .utils import time_snowflake, maybe_coroutine
|
||||||
@ -50,16 +50,18 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
OT = TypeVar('OT')
|
OT = TypeVar('OT')
|
||||||
_Func = Callable[[T], Union[OT, Coroutine[Any, Any, OT]]]
|
_Func = Callable[[T], Union[OT, Awaitable[OT]]]
|
||||||
_Predicate = Callable[[T], Union[T, Coroutine[Any, Any, T]]]
|
|
||||||
|
|
||||||
OLDEST_OBJECT = Object(id=0)
|
OLDEST_OBJECT = Object(id=0)
|
||||||
|
|
||||||
class _AsyncIterator(AsyncIterator[T]):
|
class _AsyncIterator(AsyncIterator[T]):
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
def get(self, **attrs: Any) -> Optional[T]:
|
async def next(self) -> T:
|
||||||
def predicate(elem):
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get(self, **attrs: Any) -> Awaitable[Optional[T]]:
|
||||||
|
def predicate(elem: T):
|
||||||
for attr, val in attrs.items():
|
for attr, val in attrs.items():
|
||||||
nested = attr.split('__')
|
nested = attr.split('__')
|
||||||
obj = elem
|
obj = elem
|
||||||
@ -72,7 +74,7 @@ class _AsyncIterator(AsyncIterator[T]):
|
|||||||
|
|
||||||
return self.find(predicate)
|
return self.find(predicate)
|
||||||
|
|
||||||
async def find(self, predicate: _Predicate[T]) -> Optional[T]:
|
async def find(self, predicate: _Func[T, bool]) -> Optional[T]:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
elem = await self.next()
|
elem = await self.next()
|
||||||
@ -91,7 +93,7 @@ class _AsyncIterator(AsyncIterator[T]):
|
|||||||
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
|
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
|
||||||
return _MappedAsyncIterator(self, func)
|
return _MappedAsyncIterator(self, func)
|
||||||
|
|
||||||
def filter(self, predicate: _Predicate[T]) -> _FilteredAsyncIterator[T]:
|
def filter(self, predicate: _Func[T, bool]) -> _FilteredAsyncIterator[T]:
|
||||||
return _FilteredAsyncIterator(self, predicate)
|
return _FilteredAsyncIterator(self, predicate)
|
||||||
|
|
||||||
async def flatten(self) -> List[T]:
|
async def flatten(self) -> List[T]:
|
||||||
@ -106,13 +108,13 @@ class _AsyncIterator(AsyncIterator[T]):
|
|||||||
def _identity(x):
|
def _identity(x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class _ChunkedAsyncIterator(_AsyncIterator[T]):
|
class _ChunkedAsyncIterator(_AsyncIterator[List[T]]):
|
||||||
def __init__(self, iterator, max_size):
|
def __init__(self, iterator, max_size):
|
||||||
self.iterator = iterator
|
self.iterator = iterator
|
||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
|
|
||||||
async def next(self) -> T:
|
async def next(self) -> List[T]:
|
||||||
ret = []
|
ret: List[T] = []
|
||||||
n = 0
|
n = 0
|
||||||
while n < self.max_size:
|
while n < self.max_size:
|
||||||
try:
|
try:
|
||||||
@ -168,7 +170,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
|
|||||||
self.channel_id = message.channel.id
|
self.channel_id = message.channel.id
|
||||||
self.users = asyncio.Queue()
|
self.users = asyncio.Queue()
|
||||||
|
|
||||||
async def next(self) -> T:
|
async def next(self) -> Union[User, Member]:
|
||||||
if self.users.empty():
|
if self.users.empty():
|
||||||
await self.fill_users()
|
await self.fill_users()
|
||||||
|
|
||||||
@ -289,7 +291,7 @@ class HistoryIterator(_AsyncIterator['Message']):
|
|||||||
if (self.after and self.after != OLDEST_OBJECT):
|
if (self.after and self.after != OLDEST_OBJECT):
|
||||||
self._filter = lambda m: int(m['id']) > self.after.id
|
self._filter = lambda m: int(m['id']) > self.after.id
|
||||||
|
|
||||||
async def next(self) -> T:
|
async def next(self) -> Message:
|
||||||
if self.messages.empty():
|
if self.messages.empty():
|
||||||
await self.fill_messages()
|
await self.fill_messages()
|
||||||
|
|
||||||
@ -422,7 +424,7 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
|
|||||||
self.after = Object(id=int(entries[0]['id']))
|
self.after = Object(id=int(entries[0]['id']))
|
||||||
return data.get('users', []), entries
|
return data.get('users', []), entries
|
||||||
|
|
||||||
async def next(self) -> T:
|
async def next(self) -> AuditLogEntry:
|
||||||
if self.entries.empty():
|
if self.entries.empty():
|
||||||
await self._fill()
|
await self._fill()
|
||||||
|
|
||||||
@ -519,7 +521,7 @@ class GuildIterator(_AsyncIterator['Guild']):
|
|||||||
else:
|
else:
|
||||||
self._retrieve_guilds = self._retrieve_guilds_before_strategy
|
self._retrieve_guilds = self._retrieve_guilds_before_strategy
|
||||||
|
|
||||||
async def next(self) -> T:
|
async def next(self) -> Guild:
|
||||||
if self.guilds.empty():
|
if self.guilds.empty():
|
||||||
await self.fill_guilds()
|
await self.fill_guilds()
|
||||||
|
|
||||||
@ -591,7 +593,7 @@ class MemberIterator(_AsyncIterator['Member']):
|
|||||||
self.get_members = self.state.http.get_members
|
self.get_members = self.state.http.get_members
|
||||||
self.members = asyncio.Queue()
|
self.members = asyncio.Queue()
|
||||||
|
|
||||||
async def next(self) -> T:
|
async def next(self) -> Member:
|
||||||
if self.members.empty():
|
if self.members.empty():
|
||||||
await self.fill_members()
|
await self.fill_members()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user