mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-06-08 04:38:42 +00:00
Add support for AsyncIterables in find and get
This commit is contained in:
parent
88b520b5ab
commit
eb6f5728e2
144
discord/utils.py
144
discord/utils.py
@ -26,11 +26,13 @@ from __future__ import annotations
|
|||||||
import array
|
import array
|
||||||
import asyncio
|
import asyncio
|
||||||
import collections.abc
|
import collections.abc
|
||||||
|
import inspect
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncIterable,
|
AsyncIterable,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Callable,
|
Callable,
|
||||||
|
Coroutine,
|
||||||
Dict,
|
Dict,
|
||||||
ForwardRef,
|
ForwardRef,
|
||||||
Generic,
|
Generic,
|
||||||
@ -141,6 +143,7 @@ else:
|
|||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
T_co = TypeVar('T_co', covariant=True)
|
T_co = TypeVar('T_co', covariant=True)
|
||||||
_Iter = Union[Iterable[T], AsyncIterable[T]]
|
_Iter = Union[Iterable[T], AsyncIterable[T]]
|
||||||
|
Coro = Coroutine[Any, Any, T]
|
||||||
|
|
||||||
|
|
||||||
class CachedSlotProperty(Generic[T, T_co]):
|
class CachedSlotProperty(Generic[T, T_co]):
|
||||||
@ -363,8 +366,30 @@ def time_snowflake(dt: datetime.datetime, high: bool = False) -> int:
|
|||||||
return (discord_millis << 22) + (2**22 - 1 if high else 0)
|
return (discord_millis << 22) + (2**22 - 1 if high else 0)
|
||||||
|
|
||||||
|
|
||||||
def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]:
|
def _find(predicate: Callable[[T], Any], iterable: Iterable[T], /) -> Optional[T]:
|
||||||
"""A helper to return the first element found in the sequence
|
return next((element for element in iterable if predicate(element)), None)
|
||||||
|
|
||||||
|
|
||||||
|
async def _afind(predicate: Callable[[T], Any], iterable: AsyncIterable[T], /) -> Optional[T]:
|
||||||
|
async for element in iterable:
|
||||||
|
if predicate(element):
|
||||||
|
return element
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def find(predicate: Callable[[T], Any], iterable: Iterable[T], /) -> Optional[T]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def find(predicate: Callable[[T], Any], iterable: AsyncIterable[T], /) -> Coro[Optional[T]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def find(predicate: Callable[[T], Any], iterable: _Iter[T], /) -> Union[Optional[T], Coro[Optional[T]]]:
|
||||||
|
r"""A helper to return the first element found in the sequence
|
||||||
that meets the predicate. For example: ::
|
that meets the predicate. For example: ::
|
||||||
|
|
||||||
member = discord.utils.find(lambda m: m.name == 'Mighty', channel.guild.members)
|
member = discord.utils.find(lambda m: m.name == 'Mighty', channel.guild.members)
|
||||||
@ -379,17 +404,77 @@ def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]:
|
|||||||
-----------
|
-----------
|
||||||
predicate
|
predicate
|
||||||
A function that returns a boolean-like result.
|
A function that returns a boolean-like result.
|
||||||
seq: :class:`collections.abc.Iterable`
|
iterable: Union[:class:`collections.abc.Iterable`, :class:`collections.abc.AsyncIterable`]
|
||||||
The iterable to search through.
|
The iterable to search through. Using a :class:`collections.abc.AsyncIterable`,
|
||||||
|
makes this function return a :term:`coroutine`.
|
||||||
|
|
||||||
|
.. versionchanged:: 2.0.0
|
||||||
|
|
||||||
|
Both parameters are now positional-only.
|
||||||
|
|
||||||
|
.. versionchanged:: 2.0.0
|
||||||
|
|
||||||
|
The ``iterable`` parameter supports :term:`asynchronous iterable`\s.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for element in seq:
|
return (
|
||||||
if predicate(element):
|
_find(predicate, iterable) # type: ignore
|
||||||
return element
|
if hasattr(iterable, '__iter__') # isinstance(iterable, collections.abc.Iterable) is too slow
|
||||||
|
else _afind(predicate, iterable) # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get(iterable: Iterable[T], /, **attrs: Any) -> Optional[T]:
|
||||||
|
# global -> local
|
||||||
|
_all = all
|
||||||
|
attrget = attrgetter
|
||||||
|
|
||||||
|
# Special case the single element call
|
||||||
|
if len(attrs) == 1:
|
||||||
|
k, v = attrs.popitem()
|
||||||
|
pred = attrget(k.replace('__', '.'))
|
||||||
|
return next((elem for elem in iterable if pred(elem) == v), None)
|
||||||
|
|
||||||
|
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()]
|
||||||
|
for elem in iterable:
|
||||||
|
if _all(pred(elem) == value for pred, value in converted):
|
||||||
|
return elem
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]:
|
async def _aget(iterable: AsyncIterable[T], /, **attrs: Any) -> Optional[T]:
|
||||||
|
# global -> local
|
||||||
|
_all = all
|
||||||
|
attrget = attrgetter
|
||||||
|
|
||||||
|
# Special case the single element call
|
||||||
|
if len(attrs) == 1:
|
||||||
|
k, v = attrs.popitem()
|
||||||
|
pred = attrget(k.replace('__', '.'))
|
||||||
|
async for elem in iterable:
|
||||||
|
if pred(elem) == v:
|
||||||
|
return elem
|
||||||
|
return None
|
||||||
|
|
||||||
|
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()]
|
||||||
|
|
||||||
|
async for elem in iterable:
|
||||||
|
if _all(pred(elem) == value for pred, value in converted):
|
||||||
|
return elem
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get(iterable: Iterable[T], /, **attrs: Any) -> Optional[T]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get(iterable: AsyncIterable[T], /, **attrs: Any) -> Coro[Optional[T]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def get(iterable: _Iter[T], /, **attrs: Any) -> Union[Optional[T], Coro[Optional[T]]]:
|
||||||
r"""A helper that returns the first element in the iterable that meets
|
r"""A helper that returns the first element in the iterable that meets
|
||||||
all the traits passed in ``attrs``. This is an alternative for
|
all the traits passed in ``attrs``. This is an alternative for
|
||||||
:func:`~discord.utils.find`.
|
:func:`~discord.utils.find`.
|
||||||
@ -425,33 +510,34 @@ def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]:
|
|||||||
|
|
||||||
channel = discord.utils.get(client.get_all_channels(), guild__name='Cool', name='general')
|
channel = discord.utils.get(client.get_all_channels(), guild__name='Cool', name='general')
|
||||||
|
|
||||||
|
Async iterables:
|
||||||
|
|
||||||
|
.. code-block:: python3
|
||||||
|
|
||||||
|
msg = await discord.utils.get(channel.history(), author__name='Dave')
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
-----------
|
-----------
|
||||||
iterable
|
iterable: Union[:class:`collections.abc.Iterable`, :class:`collections.abc.AsyncIterable`]
|
||||||
An iterable to search through.
|
The iterable to search through. Using a :class:`collections.abc.AsyncIterable`,
|
||||||
|
makes this function return a :term:`coroutine`.
|
||||||
\*\*attrs
|
\*\*attrs
|
||||||
Keyword arguments that denote attributes to search with.
|
Keyword arguments that denote attributes to search with.
|
||||||
|
|
||||||
|
.. versionchanged:: 2.0
|
||||||
|
|
||||||
|
The ``iterable`` parameter is now positional-only.
|
||||||
|
|
||||||
|
.. versionchanged:: 2.0
|
||||||
|
|
||||||
|
The ``iterable`` parameter supports :term:`asynchronous iterable`\s.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# global -> local
|
return (
|
||||||
_all = all
|
_get(iterable, **attrs) # type: ignore
|
||||||
attrget = attrgetter
|
if hasattr(iterable, '__iter__') # isinstance(iterable, collections.abc.Iterable) is too slow
|
||||||
|
else _aget(predicate, **attrs) # type: ignore
|
||||||
# Special case the single element call
|
)
|
||||||
if len(attrs) == 1:
|
|
||||||
k, v = attrs.popitem()
|
|
||||||
pred = attrget(k.replace('__', '.'))
|
|
||||||
for elem in iterable:
|
|
||||||
if pred(elem) == v:
|
|
||||||
return elem
|
|
||||||
return None
|
|
||||||
|
|
||||||
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()]
|
|
||||||
|
|
||||||
for elem in iterable:
|
|
||||||
if _all(pred(elem) == value for pred, value in converted):
|
|
||||||
return elem
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _unique(iterable: Iterable[T]) -> List[T]:
|
def _unique(iterable: Iterable[T]) -> List[T]:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user