mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-20 16:00:29 +00:00
use typing.AsyncIterator
for iterators
This commit is contained in:
parent
7a34de1570
commit
9f0c701a7a
@ -22,20 +22,43 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator, Coroutine
|
||||
|
||||
from .errors import NoMoreItems
|
||||
from .utils import time_snowflake, maybe_coroutine
|
||||
from .object import Object
|
||||
from .audit_logs import AuditLogEntry
|
||||
|
||||
__all__ = (
|
||||
'ReactionIterator',
|
||||
'HistoryIterator',
|
||||
'AuditLogIterator',
|
||||
'GuildIterator',
|
||||
'MemberIterator',
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .member import Member
|
||||
from .user import User
|
||||
from .message import Message
|
||||
from .audit_logs import AuditLogEntry
|
||||
from .guild import Guild
|
||||
|
||||
T = TypeVar('T')
|
||||
OT = TypeVar('OT')
|
||||
_Func = Callable[[T], Union[OT, Coroutine[Any, Any, OT]]]
|
||||
_Predicate = Callable[[T], Union[T, Coroutine[Any, Any, T]]]
|
||||
|
||||
OLDEST_OBJECT = Object(id=0)
|
||||
|
||||
class _AsyncIterator:
|
||||
class _AsyncIterator(AsyncIterator[T]):
|
||||
__slots__ = ()
|
||||
|
||||
def get(self, **attrs):
|
||||
def get(self, **attrs: Any) -> Optional[T]:
|
||||
def predicate(elem):
|
||||
for attr, val in attrs.items():
|
||||
nested = attr.split('__')
|
||||
@ -49,7 +72,7 @@ class _AsyncIterator:
|
||||
|
||||
return self.find(predicate)
|
||||
|
||||
async def find(self, predicate):
|
||||
async def find(self, predicate: _Predicate[T]) -> Optional[T]:
|
||||
while True:
|
||||
try:
|
||||
elem = await self.next()
|
||||
@ -60,40 +83,35 @@ class _AsyncIterator:
|
||||
if ret:
|
||||
return elem
|
||||
|
||||
def chunk(self, max_size):
|
||||
def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]:
|
||||
if max_size <= 0:
|
||||
raise ValueError('async iterator chunk sizes must be greater than 0.')
|
||||
return _ChunkedAsyncIterator(self, max_size)
|
||||
|
||||
def map(self, func):
|
||||
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
|
||||
return _MappedAsyncIterator(self, func)
|
||||
|
||||
def filter(self, predicate):
|
||||
def filter(self, predicate: _Predicate[T]) -> _FilteredAsyncIterator[T]:
|
||||
return _FilteredAsyncIterator(self, predicate)
|
||||
|
||||
async def flatten(self):
|
||||
async def flatten(self) -> List[T]:
|
||||
return [element async for element in self]
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
async def __anext__(self) -> T:
|
||||
try:
|
||||
msg = await self.next()
|
||||
return await self.next()
|
||||
except NoMoreItems:
|
||||
raise StopAsyncIteration()
|
||||
else:
|
||||
return msg
|
||||
|
||||
def _identity(x):
|
||||
return x
|
||||
|
||||
class _ChunkedAsyncIterator(_AsyncIterator):
|
||||
class _ChunkedAsyncIterator(_AsyncIterator[T]):
|
||||
def __init__(self, iterator, max_size):
|
||||
self.iterator = iterator
|
||||
self.max_size = max_size
|
||||
|
||||
async def next(self):
|
||||
async def next(self) -> T:
|
||||
ret = []
|
||||
n = 0
|
||||
while n < self.max_size:
|
||||
@ -108,17 +126,17 @@ class _ChunkedAsyncIterator(_AsyncIterator):
|
||||
n += 1
|
||||
return ret
|
||||
|
||||
class _MappedAsyncIterator(_AsyncIterator):
|
||||
class _MappedAsyncIterator(_AsyncIterator[T]):
|
||||
def __init__(self, iterator, func):
|
||||
self.iterator = iterator
|
||||
self.func = func
|
||||
|
||||
async def next(self):
|
||||
async def next(self) -> T:
|
||||
# this raises NoMoreItems and will propagate appropriately
|
||||
item = await self.iterator.next()
|
||||
return await maybe_coroutine(self.func, item)
|
||||
|
||||
class _FilteredAsyncIterator(_AsyncIterator):
|
||||
class _FilteredAsyncIterator(_AsyncIterator[T]):
|
||||
def __init__(self, iterator, predicate):
|
||||
self.iterator = iterator
|
||||
|
||||
@ -127,7 +145,7 @@ class _FilteredAsyncIterator(_AsyncIterator):
|
||||
|
||||
self.predicate = predicate
|
||||
|
||||
async def next(self):
|
||||
async def next(self) -> T:
|
||||
getter = self.iterator.next
|
||||
pred = self.predicate
|
||||
while True:
|
||||
@ -137,7 +155,7 @@ class _FilteredAsyncIterator(_AsyncIterator):
|
||||
if ret:
|
||||
return item
|
||||
|
||||
class ReactionIterator(_AsyncIterator):
|
||||
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
|
||||
def __init__(self, message, emoji, limit=100, after=None):
|
||||
self.message = message
|
||||
self.limit = limit
|
||||
@ -150,7 +168,7 @@ class ReactionIterator(_AsyncIterator):
|
||||
self.channel_id = message.channel.id
|
||||
self.users = asyncio.Queue()
|
||||
|
||||
async def next(self):
|
||||
async def next(self) -> T:
|
||||
if self.users.empty():
|
||||
await self.fill_users()
|
||||
|
||||
@ -185,7 +203,7 @@ class ReactionIterator(_AsyncIterator):
|
||||
else:
|
||||
await self.users.put(User(state=self.state, data=element))
|
||||
|
||||
class HistoryIterator(_AsyncIterator):
|
||||
class HistoryIterator(_AsyncIterator['Message']):
|
||||
"""Iterator for receiving a channel's message history.
|
||||
|
||||
The messages endpoint has two behaviours we care about here:
|
||||
@ -271,7 +289,7 @@ class HistoryIterator(_AsyncIterator):
|
||||
if (self.after and self.after != OLDEST_OBJECT):
|
||||
self._filter = lambda m: int(m['id']) > self.after.id
|
||||
|
||||
async def next(self):
|
||||
async def next(self) -> T:
|
||||
if self.messages.empty():
|
||||
await self.fill_messages()
|
||||
|
||||
@ -342,7 +360,7 @@ class HistoryIterator(_AsyncIterator):
|
||||
return data
|
||||
return []
|
||||
|
||||
class AuditLogIterator(_AsyncIterator):
|
||||
class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
|
||||
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
|
||||
if isinstance(before, datetime.datetime):
|
||||
before = Object(id=time_snowflake(before, high=False))
|
||||
@ -404,7 +422,7 @@ class AuditLogIterator(_AsyncIterator):
|
||||
self.after = Object(id=int(entries[0]['id']))
|
||||
return data.get('users', []), entries
|
||||
|
||||
async def next(self):
|
||||
async def next(self) -> T:
|
||||
if self.entries.empty():
|
||||
await self._fill()
|
||||
|
||||
@ -447,7 +465,7 @@ class AuditLogIterator(_AsyncIterator):
|
||||
await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild))
|
||||
|
||||
|
||||
class GuildIterator(_AsyncIterator):
|
||||
class GuildIterator(_AsyncIterator['Guild']):
|
||||
"""Iterator for receiving the client's guilds.
|
||||
|
||||
The guilds endpoint has the same two behaviours as described
|
||||
@ -501,7 +519,7 @@ class GuildIterator(_AsyncIterator):
|
||||
else:
|
||||
self._retrieve_guilds = self._retrieve_guilds_before_strategy
|
||||
|
||||
async def next(self):
|
||||
async def next(self) -> T:
|
||||
if self.guilds.empty():
|
||||
await self.fill_guilds()
|
||||
|
||||
@ -559,7 +577,7 @@ class GuildIterator(_AsyncIterator):
|
||||
self.after = Object(id=int(data[0]['id']))
|
||||
return data
|
||||
|
||||
class MemberIterator(_AsyncIterator):
|
||||
class MemberIterator(_AsyncIterator['Member']):
|
||||
def __init__(self, guild, limit=1000, after=None):
|
||||
|
||||
if isinstance(after, datetime.datetime):
|
||||
@ -573,7 +591,7 @@ class MemberIterator(_AsyncIterator):
|
||||
self.get_members = self.state.http.get_members
|
||||
self.members = asyncio.Queue()
|
||||
|
||||
async def next(self):
|
||||
async def next(self) -> T:
|
||||
if self.members.empty():
|
||||
await self.fill_members()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user