mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-19 15:36:02 +00:00
Add support for AsyncIterables in find and get
This commit is contained in:
parent
88b520b5ab
commit
eb6f5728e2
144
discord/utils.py
144
discord/utils.py
@ -26,11 +26,13 @@ from __future__ import annotations
|
||||
import array
|
||||
import asyncio
|
||||
import collections.abc
|
||||
import inspect
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
Generic,
|
||||
@ -141,6 +143,7 @@ else:
|
||||
T = TypeVar('T')
|
||||
T_co = TypeVar('T_co', covariant=True)
|
||||
_Iter = Union[Iterable[T], AsyncIterable[T]]
|
||||
Coro = Coroutine[Any, Any, T]
|
||||
|
||||
|
||||
class CachedSlotProperty(Generic[T, T_co]):
|
||||
@ -363,8 +366,30 @@ def time_snowflake(dt: datetime.datetime, high: bool = False) -> int:
|
||||
return (discord_millis << 22) + (2**22 - 1 if high else 0)
|
||||
|
||||
|
||||
def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]:
|
||||
"""A helper to return the first element found in the sequence
|
||||
def _find(predicate: Callable[[T], Any], iterable: Iterable[T], /) -> Optional[T]:
|
||||
return next((element for element in iterable if predicate(element)), None)
|
||||
|
||||
|
||||
async def _afind(predicate: Callable[[T], Any], iterable: AsyncIterable[T], /) -> Optional[T]:
|
||||
async for element in iterable:
|
||||
if predicate(element):
|
||||
return element
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@overload
|
||||
def find(predicate: Callable[[T], Any], iterable: Iterable[T], /) -> Optional[T]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def find(predicate: Callable[[T], Any], iterable: AsyncIterable[T], /) -> Coro[Optional[T]]:
|
||||
...
|
||||
|
||||
|
||||
def find(predicate: Callable[[T], Any], iterable: _Iter[T], /) -> Union[Optional[T], Coro[Optional[T]]]:
|
||||
r"""A helper to return the first element found in the sequence
|
||||
that meets the predicate. For example: ::
|
||||
|
||||
member = discord.utils.find(lambda m: m.name == 'Mighty', channel.guild.members)
|
||||
@ -379,17 +404,77 @@ def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]:
|
||||
-----------
|
||||
predicate
|
||||
A function that returns a boolean-like result.
|
||||
seq: :class:`collections.abc.Iterable`
|
||||
The iterable to search through.
|
||||
iterable: Union[:class:`collections.abc.Iterable`, :class:`collections.abc.AsyncIterable`]
|
||||
The iterable to search through. Using a :class:`collections.abc.AsyncIterable`,
|
||||
makes this function return a :term:`coroutine`.
|
||||
|
||||
.. versionchanged:: 2.0.0
|
||||
|
||||
Both parameters are now positional-only.
|
||||
|
||||
.. versionchanged:: 2.0.0
|
||||
|
||||
The ``iterable`` parameter supports :term:`asynchronous iterable`\s.
|
||||
"""
|
||||
|
||||
for element in seq:
|
||||
if predicate(element):
|
||||
return element
|
||||
return (
|
||||
_find(predicate, iterable) # type: ignore
|
||||
if hasattr(iterable, '__iter__') # isinstance(iterable, collections.abc.Iterable) is too slow
|
||||
else _afind(predicate, iterable) # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def _get(iterable: Iterable[T], /, **attrs: Any) -> Optional[T]:
|
||||
# global -> local
|
||||
_all = all
|
||||
attrget = attrgetter
|
||||
|
||||
# Special case the single element call
|
||||
if len(attrs) == 1:
|
||||
k, v = attrs.popitem()
|
||||
pred = attrget(k.replace('__', '.'))
|
||||
return next((elem for elem in iterable if pred(elem) == v), None)
|
||||
|
||||
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()]
|
||||
for elem in iterable:
|
||||
if _all(pred(elem) == value for pred, value in converted):
|
||||
return elem
|
||||
return None
|
||||
|
||||
|
||||
def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]:
|
||||
async def _aget(iterable: AsyncIterable[T], /, **attrs: Any) -> Optional[T]:
|
||||
# global -> local
|
||||
_all = all
|
||||
attrget = attrgetter
|
||||
|
||||
# Special case the single element call
|
||||
if len(attrs) == 1:
|
||||
k, v = attrs.popitem()
|
||||
pred = attrget(k.replace('__', '.'))
|
||||
async for elem in iterable:
|
||||
if pred(elem) == v:
|
||||
return elem
|
||||
return None
|
||||
|
||||
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()]
|
||||
|
||||
async for elem in iterable:
|
||||
if _all(pred(elem) == value for pred, value in converted):
|
||||
return elem
|
||||
return None
|
||||
|
||||
|
||||
@overload
|
||||
def get(iterable: Iterable[T], /, **attrs: Any) -> Optional[T]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def get(iterable: AsyncIterable[T], /, **attrs: Any) -> Coro[Optional[T]]:
|
||||
...
|
||||
|
||||
|
||||
def get(iterable: _Iter[T], /, **attrs: Any) -> Union[Optional[T], Coro[Optional[T]]]:
|
||||
r"""A helper that returns the first element in the iterable that meets
|
||||
all the traits passed in ``attrs``. This is an alternative for
|
||||
:func:`~discord.utils.find`.
|
||||
@ -425,33 +510,34 @@ def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]:
|
||||
|
||||
channel = discord.utils.get(client.get_all_channels(), guild__name='Cool', name='general')
|
||||
|
||||
Async iterables:
|
||||
|
||||
.. code-block:: python3
|
||||
|
||||
msg = await discord.utils.get(channel.history(), author__name='Dave')
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
iterable
|
||||
An iterable to search through.
|
||||
iterable: Union[:class:`collections.abc.Iterable`, :class:`collections.abc.AsyncIterable`]
|
||||
The iterable to search through. Using a :class:`collections.abc.AsyncIterable`,
|
||||
makes this function return a :term:`coroutine`.
|
||||
\*\*attrs
|
||||
Keyword arguments that denote attributes to search with.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
|
||||
The ``iterable`` parameter is now positional-only.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
|
||||
The ``iterable`` parameter supports :term:`asynchronous iterable`\s.
|
||||
"""
|
||||
|
||||
# global -> local
|
||||
_all = all
|
||||
attrget = attrgetter
|
||||
|
||||
# Special case the single element call
|
||||
if len(attrs) == 1:
|
||||
k, v = attrs.popitem()
|
||||
pred = attrget(k.replace('__', '.'))
|
||||
for elem in iterable:
|
||||
if pred(elem) == v:
|
||||
return elem
|
||||
return None
|
||||
|
||||
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()]
|
||||
|
||||
for elem in iterable:
|
||||
if _all(pred(elem) == value for pred, value in converted):
|
||||
return elem
|
||||
return None
|
||||
return (
|
||||
_get(iterable, **attrs) # type: ignore
|
||||
if hasattr(iterable, '__iter__') # isinstance(iterable, collections.abc.Iterable) is too slow
|
||||
else _aget(predicate, **attrs) # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def _unique(iterable: Iterable[T]) -> List[T]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user