mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-11-03 23:12:56 +00:00 
			
		
		
		
	use typing.AsyncIterator for iterators
				
					
				
			This commit is contained in:
		@@ -22,20 +22,43 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 | 
			
		||||
DEALINGS IN THE SOFTWARE.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import asyncio
 | 
			
		||||
import datetime
 | 
			
		||||
from typing import TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator, Coroutine
 | 
			
		||||
 | 
			
		||||
from .errors import NoMoreItems
 | 
			
		||||
from .utils import time_snowflake, maybe_coroutine
 | 
			
		||||
from .object import Object
 | 
			
		||||
from .audit_logs import AuditLogEntry
 | 
			
		||||
 | 
			
		||||
__all__ = (
 | 
			
		||||
    'ReactionIterator',
 | 
			
		||||
    'HistoryIterator',
 | 
			
		||||
    'AuditLogIterator',
 | 
			
		||||
    'GuildIterator',
 | 
			
		||||
    'MemberIterator',
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from .member import Member
 | 
			
		||||
    from .user import User
 | 
			
		||||
    from .message import Message
 | 
			
		||||
    from .audit_logs import AuditLogEntry
 | 
			
		||||
    from .guild import Guild
 | 
			
		||||
 | 
			
		||||
T = TypeVar('T')
 | 
			
		||||
OT = TypeVar('OT')
 | 
			
		||||
_Func = Callable[[T], Union[OT, Coroutine[Any, Any, OT]]]
 | 
			
		||||
_Predicate = Callable[[T], Union[T, Coroutine[Any, Any, T]]]
 | 
			
		||||
 | 
			
		||||
OLDEST_OBJECT = Object(id=0)
 | 
			
		||||
 | 
			
		||||
class _AsyncIterator:
 | 
			
		||||
class _AsyncIterator(AsyncIterator[T]):
 | 
			
		||||
    __slots__ = ()
 | 
			
		||||
 | 
			
		||||
    def get(self, **attrs):
 | 
			
		||||
    def get(self, **attrs: Any) -> Optional[T]:
 | 
			
		||||
        def predicate(elem):
 | 
			
		||||
            for attr, val in attrs.items():
 | 
			
		||||
                nested = attr.split('__')
 | 
			
		||||
@@ -49,7 +72,7 @@ class _AsyncIterator:
 | 
			
		||||
 | 
			
		||||
        return self.find(predicate)
 | 
			
		||||
 | 
			
		||||
    async def find(self, predicate):
 | 
			
		||||
    async def find(self, predicate: _Predicate[T]) -> Optional[T]:
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                elem = await self.next()
 | 
			
		||||
@@ -60,40 +83,35 @@ class _AsyncIterator:
 | 
			
		||||
            if ret:
 | 
			
		||||
                return elem
 | 
			
		||||
 | 
			
		||||
    def chunk(self, max_size):
 | 
			
		||||
    def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]:
 | 
			
		||||
        if max_size <= 0:
 | 
			
		||||
            raise ValueError('async iterator chunk sizes must be greater than 0.')
 | 
			
		||||
        return _ChunkedAsyncIterator(self, max_size)
 | 
			
		||||
 | 
			
		||||
    def map(self, func):
 | 
			
		||||
    def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
 | 
			
		||||
        return _MappedAsyncIterator(self, func)
 | 
			
		||||
 | 
			
		||||
    def filter(self, predicate):
 | 
			
		||||
    def filter(self, predicate: _Predicate[T]) -> _FilteredAsyncIterator[T]:
 | 
			
		||||
        return _FilteredAsyncIterator(self, predicate)
 | 
			
		||||
 | 
			
		||||
    async def flatten(self):
 | 
			
		||||
    async def flatten(self) -> List[T]:
 | 
			
		||||
        return [element async for element in self]
 | 
			
		||||
 | 
			
		||||
    def __aiter__(self):
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    async def __anext__(self):
 | 
			
		||||
    async def __anext__(self) -> T:
 | 
			
		||||
        try:
 | 
			
		||||
            msg = await self.next()
 | 
			
		||||
            return await self.next()
 | 
			
		||||
        except NoMoreItems:
 | 
			
		||||
            raise StopAsyncIteration()
 | 
			
		||||
        else:
 | 
			
		||||
            return msg
 | 
			
		||||
 | 
			
		||||
def _identity(x):
 | 
			
		||||
    return x
 | 
			
		||||
 | 
			
		||||
class _ChunkedAsyncIterator(_AsyncIterator):
 | 
			
		||||
class _ChunkedAsyncIterator(_AsyncIterator[T]):
 | 
			
		||||
    def __init__(self, iterator, max_size):
 | 
			
		||||
        self.iterator = iterator
 | 
			
		||||
        self.max_size = max_size
 | 
			
		||||
 | 
			
		||||
    async def next(self):
 | 
			
		||||
    async def next(self) -> T:
 | 
			
		||||
        ret = []
 | 
			
		||||
        n = 0
 | 
			
		||||
        while n < self.max_size:
 | 
			
		||||
@@ -108,17 +126,17 @@ class _ChunkedAsyncIterator(_AsyncIterator):
 | 
			
		||||
                n += 1
 | 
			
		||||
        return ret
 | 
			
		||||
 | 
			
		||||
class _MappedAsyncIterator(_AsyncIterator):
 | 
			
		||||
class _MappedAsyncIterator(_AsyncIterator[T]):
 | 
			
		||||
    def __init__(self, iterator, func):
 | 
			
		||||
        self.iterator = iterator
 | 
			
		||||
        self.func = func
 | 
			
		||||
 | 
			
		||||
    async def next(self):
 | 
			
		||||
    async def next(self) -> T:
 | 
			
		||||
        # this raises NoMoreItems and will propagate appropriately
 | 
			
		||||
        item = await self.iterator.next()
 | 
			
		||||
        return await maybe_coroutine(self.func, item)
 | 
			
		||||
 | 
			
		||||
class _FilteredAsyncIterator(_AsyncIterator):
 | 
			
		||||
class _FilteredAsyncIterator(_AsyncIterator[T]):
 | 
			
		||||
    def __init__(self, iterator, predicate):
 | 
			
		||||
        self.iterator = iterator
 | 
			
		||||
 | 
			
		||||
@@ -127,7 +145,7 @@ class _FilteredAsyncIterator(_AsyncIterator):
 | 
			
		||||
 | 
			
		||||
        self.predicate = predicate
 | 
			
		||||
 | 
			
		||||
    async def next(self):
 | 
			
		||||
    async def next(self) -> T:
 | 
			
		||||
        getter = self.iterator.next
 | 
			
		||||
        pred = self.predicate
 | 
			
		||||
        while True:
 | 
			
		||||
@@ -137,7 +155,7 @@ class _FilteredAsyncIterator(_AsyncIterator):
 | 
			
		||||
            if ret:
 | 
			
		||||
                return item
 | 
			
		||||
 | 
			
		||||
class ReactionIterator(_AsyncIterator):
 | 
			
		||||
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
 | 
			
		||||
    def __init__(self, message, emoji, limit=100, after=None):
 | 
			
		||||
        self.message = message
 | 
			
		||||
        self.limit = limit
 | 
			
		||||
@@ -150,7 +168,7 @@ class ReactionIterator(_AsyncIterator):
 | 
			
		||||
        self.channel_id = message.channel.id
 | 
			
		||||
        self.users = asyncio.Queue()
 | 
			
		||||
 | 
			
		||||
    async def next(self):
 | 
			
		||||
    async def next(self) -> T:
 | 
			
		||||
        if self.users.empty():
 | 
			
		||||
            await self.fill_users()
 | 
			
		||||
 | 
			
		||||
@@ -185,7 +203,7 @@ class ReactionIterator(_AsyncIterator):
 | 
			
		||||
                    else:
 | 
			
		||||
                        await self.users.put(User(state=self.state, data=element))
 | 
			
		||||
 | 
			
		||||
class HistoryIterator(_AsyncIterator):
 | 
			
		||||
class HistoryIterator(_AsyncIterator['Message']):
 | 
			
		||||
    """Iterator for receiving a channel's message history.
 | 
			
		||||
 | 
			
		||||
    The messages endpoint has two behaviours we care about here:
 | 
			
		||||
@@ -271,7 +289,7 @@ class HistoryIterator(_AsyncIterator):
 | 
			
		||||
                if (self.after and self.after != OLDEST_OBJECT):
 | 
			
		||||
                    self._filter = lambda m: int(m['id']) > self.after.id
 | 
			
		||||
 | 
			
		||||
    async def next(self):
 | 
			
		||||
    async def next(self) -> T:
 | 
			
		||||
        if self.messages.empty():
 | 
			
		||||
            await self.fill_messages()
 | 
			
		||||
 | 
			
		||||
@@ -342,7 +360,7 @@ class HistoryIterator(_AsyncIterator):
 | 
			
		||||
            return data
 | 
			
		||||
        return []
 | 
			
		||||
 | 
			
		||||
class AuditLogIterator(_AsyncIterator):
 | 
			
		||||
class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
 | 
			
		||||
    def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
 | 
			
		||||
        if isinstance(before, datetime.datetime):
 | 
			
		||||
            before = Object(id=time_snowflake(before, high=False))
 | 
			
		||||
@@ -404,7 +422,7 @@ class AuditLogIterator(_AsyncIterator):
 | 
			
		||||
            self.after = Object(id=int(entries[0]['id']))
 | 
			
		||||
        return data.get('users', []), entries
 | 
			
		||||
 | 
			
		||||
    async def next(self):
 | 
			
		||||
    async def next(self) -> T:
 | 
			
		||||
        if self.entries.empty():
 | 
			
		||||
            await self._fill()
 | 
			
		||||
 | 
			
		||||
@@ -447,7 +465,7 @@ class AuditLogIterator(_AsyncIterator):
 | 
			
		||||
                await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GuildIterator(_AsyncIterator):
 | 
			
		||||
class GuildIterator(_AsyncIterator['Guild']):
 | 
			
		||||
    """Iterator for receiving the client's guilds.
 | 
			
		||||
 | 
			
		||||
    The guilds endpoint has the same two behaviours as described
 | 
			
		||||
@@ -501,7 +519,7 @@ class GuildIterator(_AsyncIterator):
 | 
			
		||||
        else:
 | 
			
		||||
            self._retrieve_guilds = self._retrieve_guilds_before_strategy
 | 
			
		||||
 | 
			
		||||
    async def next(self):
 | 
			
		||||
    async def next(self) -> T:
 | 
			
		||||
        if self.guilds.empty():
 | 
			
		||||
            await self.fill_guilds()
 | 
			
		||||
 | 
			
		||||
@@ -559,7 +577,7 @@ class GuildIterator(_AsyncIterator):
 | 
			
		||||
            self.after = Object(id=int(data[0]['id']))
 | 
			
		||||
        return data
 | 
			
		||||
 | 
			
		||||
class MemberIterator(_AsyncIterator):
 | 
			
		||||
class MemberIterator(_AsyncIterator['Member']):
 | 
			
		||||
    def __init__(self, guild, limit=1000, after=None):
 | 
			
		||||
 | 
			
		||||
        if isinstance(after, datetime.datetime):
 | 
			
		||||
@@ -573,7 +591,7 @@ class MemberIterator(_AsyncIterator):
 | 
			
		||||
        self.get_members = self.state.http.get_members
 | 
			
		||||
        self.members = asyncio.Queue()
 | 
			
		||||
 | 
			
		||||
    async def next(self):
 | 
			
		||||
    async def next(self) -> T:
 | 
			
		||||
        if self.members.empty():
 | 
			
		||||
            await self.fill_members()
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user