mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-10-25 02:23:04 +00:00 
			
		
		
		
	Fix inaccuracies with AsyncIterator typings
				
					
				
			This commit is contained in:
		| @@ -26,7 +26,7 @@ from __future__ import annotations | ||||
|  | ||||
| import asyncio | ||||
| import datetime | ||||
| from typing import TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator, Coroutine | ||||
| from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator | ||||
|  | ||||
| from .errors import NoMoreItems | ||||
| from .utils import time_snowflake, maybe_coroutine | ||||
| @@ -50,16 +50,18 @@ if TYPE_CHECKING: | ||||
|  | ||||
| T = TypeVar('T') | ||||
| OT = TypeVar('OT') | ||||
| _Func = Callable[[T], Union[OT, Coroutine[Any, Any, OT]]] | ||||
| _Predicate = Callable[[T], Union[T, Coroutine[Any, Any, T]]] | ||||
| _Func = Callable[[T], Union[OT, Awaitable[OT]]] | ||||
|  | ||||
| OLDEST_OBJECT = Object(id=0) | ||||
|  | ||||
| class _AsyncIterator(AsyncIterator[T]): | ||||
|     __slots__ = () | ||||
|  | ||||
|     def get(self, **attrs: Any) -> Optional[T]: | ||||
|         def predicate(elem): | ||||
|     async def next(self) -> T: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def get(self, **attrs: Any) -> Awaitable[Optional[T]]: | ||||
|         def predicate(elem: T): | ||||
|             for attr, val in attrs.items(): | ||||
|                 nested = attr.split('__') | ||||
|                 obj = elem | ||||
| @@ -72,7 +74,7 @@ class _AsyncIterator(AsyncIterator[T]): | ||||
|  | ||||
|         return self.find(predicate) | ||||
|  | ||||
|     async def find(self, predicate: _Predicate[T]) -> Optional[T]: | ||||
|     async def find(self, predicate: _Func[T, bool]) -> Optional[T]: | ||||
|         while True: | ||||
|             try: | ||||
|                 elem = await self.next() | ||||
| @@ -91,7 +93,7 @@ class _AsyncIterator(AsyncIterator[T]): | ||||
|     def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]: | ||||
|         return _MappedAsyncIterator(self, func) | ||||
|  | ||||
|     def filter(self, predicate: _Predicate[T]) -> _FilteredAsyncIterator[T]: | ||||
|     def filter(self, predicate: _Func[T, bool]) -> _FilteredAsyncIterator[T]: | ||||
|         return _FilteredAsyncIterator(self, predicate) | ||||
|  | ||||
|     async def flatten(self) -> List[T]: | ||||
| @@ -106,13 +108,13 @@ class _AsyncIterator(AsyncIterator[T]): | ||||
| def _identity(x): | ||||
|     return x | ||||
|  | ||||
| class _ChunkedAsyncIterator(_AsyncIterator[T]): | ||||
| class _ChunkedAsyncIterator(_AsyncIterator[List[T]]): | ||||
|     def __init__(self, iterator, max_size): | ||||
|         self.iterator = iterator | ||||
|         self.max_size = max_size | ||||
|  | ||||
|     async def next(self) -> T: | ||||
|         ret = [] | ||||
|     async def next(self) -> List[T]: | ||||
|         ret: List[T] = [] | ||||
|         n = 0 | ||||
|         while n < self.max_size: | ||||
|             try: | ||||
| @@ -168,7 +170,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): | ||||
|         self.channel_id = message.channel.id | ||||
|         self.users = asyncio.Queue() | ||||
|  | ||||
|     async def next(self) -> T: | ||||
|     async def next(self) -> Union[User, Member]: | ||||
|         if self.users.empty(): | ||||
|             await self.fill_users() | ||||
|  | ||||
| @@ -289,7 +291,7 @@ class HistoryIterator(_AsyncIterator['Message']): | ||||
|                 if (self.after and self.after != OLDEST_OBJECT): | ||||
|                     self._filter = lambda m: int(m['id']) > self.after.id | ||||
|  | ||||
|     async def next(self) -> T: | ||||
|     async def next(self) -> Message: | ||||
|         if self.messages.empty(): | ||||
|             await self.fill_messages() | ||||
|  | ||||
| @@ -422,7 +424,7 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']): | ||||
|             self.after = Object(id=int(entries[0]['id'])) | ||||
|         return data.get('users', []), entries | ||||
|  | ||||
|     async def next(self) -> T: | ||||
|     async def next(self) -> AuditLogEntry: | ||||
|         if self.entries.empty(): | ||||
|             await self._fill() | ||||
|  | ||||
| @@ -519,7 +521,7 @@ class GuildIterator(_AsyncIterator['Guild']): | ||||
|         else: | ||||
|             self._retrieve_guilds = self._retrieve_guilds_before_strategy | ||||
|  | ||||
|     async def next(self) -> T: | ||||
|     async def next(self) -> Guild: | ||||
|         if self.guilds.empty(): | ||||
|             await self.fill_guilds() | ||||
|  | ||||
| @@ -591,7 +593,7 @@ class MemberIterator(_AsyncIterator['Member']): | ||||
|         self.get_members = self.state.http.get_members | ||||
|         self.members = asyncio.Queue() | ||||
|  | ||||
|     async def next(self) -> T: | ||||
|     async def next(self) -> Member: | ||||
|         if self.members.empty(): | ||||
|             await self.fill_members() | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user