mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-19 15:36:02 +00:00
Refactor AsyncIter to use 3.6+ asynchronous generators
This commit is contained in:
parent
dc19c6c7d5
commit
588cda0996
110
discord/abc.py
110
discord/abc.py
@ -26,8 +26,10 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
@ -42,7 +44,7 @@ from typing import (
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from .iterators import HistoryIterator
|
||||
from .object import OLDEST_OBJECT, Object
|
||||
from .context_managers import Typing
|
||||
from .enums import ChannelType
|
||||
from .errors import InvalidArgument, ClientException
|
||||
@ -68,8 +70,6 @@ __all__ = (
|
||||
T = TypeVar('T', bound=VoiceProtocol)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import datetime
|
||||
|
||||
from .client import Client
|
||||
from .user import ClientUser
|
||||
from .asset import Asset
|
||||
@ -1465,7 +1465,7 @@ class Messageable:
|
||||
data = await state.http.pins_from(channel.id)
|
||||
return [state.create_message(channel=channel, data=m) for m in data]
|
||||
|
||||
def history(
|
||||
async def history(
|
||||
self,
|
||||
*,
|
||||
limit: Optional[int] = 100,
|
||||
@ -1473,8 +1473,8 @@ class Messageable:
|
||||
after: Optional[SnowflakeTime] = None,
|
||||
around: Optional[SnowflakeTime] = None,
|
||||
oldest_first: Optional[bool] = None,
|
||||
) -> HistoryIterator:
|
||||
"""Returns an :class:`~discord.AsyncIterator` that enables receiving the destination's message history.
|
||||
) -> AsyncIterator[Message]:
|
||||
"""Returns an :term:`asynchronous iterator` that enables receiving the destination's message history.
|
||||
|
||||
You must have :attr:`~discord.Permissions.read_message_history` permissions to use this.
|
||||
|
||||
@ -1490,7 +1490,7 @@ class Messageable:
|
||||
|
||||
Flattening into a list: ::
|
||||
|
||||
messages = await channel.history(limit=123).flatten()
|
||||
messages = [message async for message in channel.history(limit=123)]
|
||||
# messages is now a list of Message...
|
||||
|
||||
All parameters are optional.
|
||||
@ -1531,7 +1531,101 @@ class Messageable:
|
||||
:class:`~discord.Message`
|
||||
The message with the message data parsed.
|
||||
"""
|
||||
return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first)
|
||||
|
||||
async def _around_strategy(retrieve, around, limit):
|
||||
if not around:
|
||||
return []
|
||||
|
||||
around_id = around.id if around else None
|
||||
data = await self._state.http.logs_from(channel.id, retrieve, around=around_id)
|
||||
|
||||
return data, None, limit
|
||||
|
||||
async def _after_strategy(retrieve, after, limit):
|
||||
after_id = after.id if after else None
|
||||
data = await self._state.http.logs_from(channel.id, retrieve, after=after_id)
|
||||
|
||||
if data:
|
||||
if limit is not None:
|
||||
limit -= len(data)
|
||||
|
||||
after = Object(id=int(data[0]['id']))
|
||||
|
||||
return data, after, limit
|
||||
|
||||
async def _before_strategy(retrieve, before, limit):
|
||||
before_id = before.id if before else None
|
||||
data = await self._state.http.logs_from(channel.id, retrieve, before=before_id)
|
||||
|
||||
if data:
|
||||
if limit is not None:
|
||||
limit -= len(data)
|
||||
|
||||
before = Object(id=int(data[-1]['id']))
|
||||
|
||||
return data, before, limit
|
||||
|
||||
if isinstance(before, datetime):
|
||||
before = Object(id=utils.time_snowflake(before, high=False))
|
||||
if isinstance(after, datetime):
|
||||
after = Object(id=utils.time_snowflake(after, high=True))
|
||||
if isinstance(around, datetime):
|
||||
around = Object(id=utils.time_snowflake(around))
|
||||
|
||||
if oldest_first is None:
|
||||
reverse = after is not None
|
||||
else:
|
||||
reverse = oldest_first
|
||||
|
||||
after = after or OLDEST_OBJECT
|
||||
predicate = None
|
||||
|
||||
if around:
|
||||
if limit is None:
|
||||
raise ValueError('history does not support around with limit=None')
|
||||
if limit > 101:
|
||||
raise ValueError("history max limit 101 when specifying around parameter")
|
||||
|
||||
# Strange Discord quirk
|
||||
limit = 100 if limit == 101 else limit
|
||||
|
||||
strategy, state = _around_strategy, around
|
||||
|
||||
if before and after:
|
||||
predicate = lambda m: after.id < int(m['id']) < before.id
|
||||
elif before:
|
||||
predicate = lambda m: int(m['id']) < before.id
|
||||
elif after:
|
||||
predicate = lambda m: after.id < int(m['id'])
|
||||
elif reverse:
|
||||
strategy, state = _after_strategy, after
|
||||
if before:
|
||||
predicate = lambda m: int(m['id']) < before.id
|
||||
else:
|
||||
strategy, state = _before_strategy, before
|
||||
if after and after != OLDEST_OBJECT:
|
||||
predicate = lambda m: int(m['id']) > after.id
|
||||
|
||||
channel = await self._get_channel()
|
||||
|
||||
while True:
|
||||
retrieve = min(100 if limit is None else limit, 100)
|
||||
if retrieve < 1:
|
||||
return
|
||||
|
||||
data, state, limit = await strategy(retrieve, state, limit)
|
||||
|
||||
# Terminate loop on next iteration; there's no data left after this
|
||||
if len(data) < 100:
|
||||
limit = 0
|
||||
|
||||
if reverse:
|
||||
data = reversed(data)
|
||||
if predicate:
|
||||
data = filter(predicate, data)
|
||||
|
||||
for raw_message in data:
|
||||
yield self._state.create_message(channel=channel, data=raw_message)
|
||||
|
||||
|
||||
class Connectable(Protocol):
|
||||
|
@ -28,6 +28,7 @@ import time
|
||||
import asyncio
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
@ -54,7 +55,6 @@ from .asset import Asset
|
||||
from .errors import ClientException, InvalidArgument
|
||||
from .stage_instance import StageInstance
|
||||
from .threads import Thread
|
||||
from .iterators import ArchivedThreadIterator
|
||||
|
||||
__all__ = (
|
||||
'TextChannel',
|
||||
@ -755,15 +755,15 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
|
||||
return Thread(guild=self.guild, state=self._state, data=data)
|
||||
|
||||
def archived_threads(
|
||||
async def archived_threads(
|
||||
self,
|
||||
*,
|
||||
private: bool = False,
|
||||
joined: bool = False,
|
||||
limit: Optional[int] = 50,
|
||||
before: Optional[Union[Snowflake, datetime.datetime]] = None,
|
||||
) -> ArchivedThreadIterator:
|
||||
"""Returns an :class:`~discord.AsyncIterator` that iterates over all archived threads in the guild.
|
||||
) -> AsyncIterator[Thread]:
|
||||
"""Returns an :term:`asynchronous iterator` that iterates over all archived threads in the guild.
|
||||
|
||||
You must have :attr:`~Permissions.read_message_history` to use this. If iterating over private threads
|
||||
then :attr:`~Permissions.manage_threads` is also required.
|
||||
@ -790,13 +790,57 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
You do not have permissions to get archived threads.
|
||||
HTTPException
|
||||
The request to get the archived threads failed.
|
||||
ValueError
|
||||
`joined`` was set to ``True`` and ``private`` was set to ``False``. You cannot retrieve public archived
|
||||
threads that you have joined.
|
||||
|
||||
Yields
|
||||
-------
|
||||
:class:`Thread`
|
||||
The archived threads.
|
||||
"""
|
||||
return ArchivedThreadIterator(self.id, self.guild, limit=limit, joined=joined, private=private, before=before)
|
||||
if joined and not private:
|
||||
raise ValueError('Cannot retrieve joined public archived threads')
|
||||
|
||||
before_timestamp = None
|
||||
|
||||
if isinstance(before, datetime.datetime):
|
||||
if joined:
|
||||
before_timestamp = str(utils.time_snowflake(before, high=False))
|
||||
else:
|
||||
before_timestamp = before.isoformat()
|
||||
elif before is not None:
|
||||
if joined:
|
||||
before_timestamp = str(before.id)
|
||||
else:
|
||||
before_timestamp = utils.snowflake_time(before.id).isoformat()
|
||||
|
||||
update_before = lambda data: data['thread_metadata']['archive_timestamp']
|
||||
endpoint = self.guild._state.http.get_public_archived_threads
|
||||
|
||||
if joined:
|
||||
update_before = lambda data: data['id']
|
||||
endpoint = self.guild._state.http.get_joined_private_archived_threads
|
||||
elif private:
|
||||
endpoint = self.guild._state.http.get_private_archived_threads
|
||||
|
||||
while True:
|
||||
retrieve = 50 if limit is None else max(limit, 50)
|
||||
data = await endpoint(self.id, before=before_timestamp, limit=retrieve)
|
||||
|
||||
threads = data.get('threads', [])
|
||||
for raw_thread in reversed(threads):
|
||||
yield Thread(guild=self.guild, state=self.guild._state, data=raw_thread)
|
||||
|
||||
if not data.get('has_more', False):
|
||||
return
|
||||
|
||||
if limit is not None:
|
||||
limit -= len(threads)
|
||||
if limit <= 0:
|
||||
return
|
||||
|
||||
before = update_before(threads[-1])
|
||||
|
||||
|
||||
class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
|
||||
|
@ -25,11 +25,26 @@ DEALINGS IN THE SOFTWARE.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union
|
||||
)
|
||||
|
||||
import aiohttp
|
||||
|
||||
@ -51,11 +66,10 @@ from .voice_client import VoiceClient
|
||||
from .http import HTTPClient
|
||||
from .state import ConnectionState
|
||||
from . import utils
|
||||
from .utils import MISSING
|
||||
from .utils import MISSING, time_snowflake
|
||||
from .object import Object
|
||||
from .backoff import ExponentialBackoff
|
||||
from .webhook import Webhook
|
||||
from .iterators import GuildIterator
|
||||
from .appinfo import AppInfo
|
||||
from .ui.view import View
|
||||
from .stage_instance import StageInstance
|
||||
@ -63,6 +77,7 @@ from .threads import Thread
|
||||
from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .types.guild import Guild as GuildPayload
|
||||
from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake
|
||||
from .channel import DMChannel
|
||||
from .message import Message
|
||||
@ -1120,14 +1135,14 @@ class Client:
|
||||
|
||||
# Guild stuff
|
||||
|
||||
def fetch_guilds(
|
||||
async def fetch_guilds(
|
||||
self,
|
||||
*,
|
||||
limit: Optional[int] = 100,
|
||||
before: SnowflakeTime = None,
|
||||
after: SnowflakeTime = None
|
||||
) -> GuildIterator:
|
||||
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
|
||||
before: Optional[SnowflakeTime] = None,
|
||||
after: Optional[SnowflakeTime] = None,
|
||||
) -> AsyncIterator[Guild]:
|
||||
"""Retrieves an :term:`asynchronous iterator` that enables receiving your guilds.
|
||||
|
||||
.. note::
|
||||
|
||||
@ -1148,7 +1163,7 @@ class Client:
|
||||
|
||||
Flattening into a list ::
|
||||
|
||||
guilds = await client.fetch_guilds(limit=150).flatten()
|
||||
guilds = [guild async for guild in client.fetch_guilds(limit=150)]
|
||||
# guilds is now a list of Guild...
|
||||
|
||||
All parameters are optional.
|
||||
@ -1179,7 +1194,60 @@ class Client:
|
||||
:class:`.Guild`
|
||||
The guild with the guild data parsed.
|
||||
"""
|
||||
return GuildIterator(self, limit=limit, before=before, after=after)
|
||||
|
||||
async def _before_strategy(retrieve, before, limit):
|
||||
before_id = before.id if before else None
|
||||
data = await self.http.get_guilds(retrieve, before=before_id)
|
||||
|
||||
if data:
|
||||
if limit is not None:
|
||||
limit -= len(data)
|
||||
|
||||
before = Object(id=int(data[-1]['id']))
|
||||
|
||||
return data, before, limit
|
||||
|
||||
async def _after_strategy(retrieve, after, limit):
|
||||
after_id = after.id if after else None
|
||||
data = await self.http.get_guilds(retrieve, after=after_id)
|
||||
|
||||
if data:
|
||||
if limit is not None:
|
||||
limit -= len(data)
|
||||
|
||||
after = Object(id=int(data[0]['id']))
|
||||
|
||||
return data, after, limit
|
||||
|
||||
if isinstance(before, datetime.datetime):
|
||||
before = Object(id=time_snowflake(before, high=False))
|
||||
if isinstance(after, datetime.datetime):
|
||||
after = Object(id=time_snowflake(after, high=True))
|
||||
|
||||
predicate = None
|
||||
strategy, state = _before_strategy, before
|
||||
|
||||
if before and after:
|
||||
predicate = lambda m: int(m['id']) > after.id # type: ignore
|
||||
elif after:
|
||||
strategy, state = _after_strategy, after
|
||||
|
||||
while True:
|
||||
retrieve = min(100 if limit is None else limit, 100)
|
||||
if retrieve < 1:
|
||||
return
|
||||
|
||||
data, state, limit = await strategy(retrieve, state, limit)
|
||||
|
||||
# Terminate loop on next iteration; there's no data left after this
|
||||
if len(data) < 100:
|
||||
limit = 0
|
||||
|
||||
if predicate:
|
||||
data = filter(predicate, data)
|
||||
|
||||
for raw_guild in data:
|
||||
yield Guild(state=self._connection, data=raw_guild)
|
||||
|
||||
async def fetch_template(self, code: Union[Template, str]) -> Template:
|
||||
"""|coro|
|
||||
|
133
discord/guild.py
133
discord/guild.py
@ -25,9 +25,11 @@ DEALINGS IN THE SOFTWARE.
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import datetime
|
||||
import unicodedata
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
@ -67,7 +69,6 @@ from .enums import (
|
||||
from .mixins import Hashable
|
||||
from .user import User
|
||||
from .invite import Invite
|
||||
from .iterators import AuditLogIterator, MemberIterator
|
||||
from .widget import Widget
|
||||
from .asset import Asset
|
||||
from .flags import SystemChannelFlags
|
||||
@ -76,6 +77,8 @@ from .stage_instance import StageInstance
|
||||
from .threads import Thread, ThreadMember
|
||||
from .sticker import GuildSticker
|
||||
from .file import File
|
||||
from .audit_logs import AuditLogEntry
|
||||
from .object import OLDEST_OBJECT, Object
|
||||
|
||||
|
||||
__all__ = (
|
||||
@ -98,8 +101,6 @@ if TYPE_CHECKING:
|
||||
from .state import ConnectionState
|
||||
from .voice_client import VoiceProtocol
|
||||
|
||||
import datetime
|
||||
|
||||
VocalGuildChannel = Union[VoiceChannel, StageChannel]
|
||||
GuildChannel = Union[VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel]
|
||||
ByCategoryItem = Tuple[Optional[CategoryChannel], List[GuildChannel]]
|
||||
@ -1649,9 +1650,8 @@ class Guild(Hashable):
|
||||
|
||||
return threads
|
||||
|
||||
# TODO: Remove Optional typing here when async iterators are refactored
|
||||
def fetch_members(self, *, limit: int = 1000, after: Optional[SnowflakeTime] = None) -> MemberIterator:
|
||||
"""Retrieves an :class:`.AsyncIterator` that enables receiving the guild's members. In order to use this,
|
||||
async def fetch_members(self, *, limit: int = 1000, after: SnowflakeTime = MISSING) -> AsyncIterator[Member]:
|
||||
"""Retrieves an :term:`asynchronous iterator` that enables receiving the guild's members. In order to use this,
|
||||
:meth:`Intents.members` must be enabled.
|
||||
|
||||
.. note::
|
||||
@ -1701,7 +1701,30 @@ class Guild(Hashable):
|
||||
if not self._state._intents.members:
|
||||
raise ClientException('Intents.members must be enabled to use this.')
|
||||
|
||||
return MemberIterator(self, limit=limit, after=after)
|
||||
while True:
|
||||
retrieve = min(1000 if limit is None else limit, 1000)
|
||||
if retrieve < 1:
|
||||
return
|
||||
|
||||
if isinstance(after, datetime.datetime):
|
||||
after = Object(id=utils.time_snowflake(after, high=True))
|
||||
|
||||
after = after or OLDEST_OBJECT
|
||||
after_id = after.id if after else None
|
||||
state = self._state
|
||||
|
||||
data = await state.http.get_members(self.id, retrieve, after_id)
|
||||
if not data:
|
||||
return
|
||||
|
||||
# Terminate loop on next iteration; there's no data left after this
|
||||
if len(data) < 1000:
|
||||
limit = 0
|
||||
|
||||
after = Object(id=int(data[-1]['user']['id']))
|
||||
|
||||
for raw_member in reversed(data):
|
||||
yield Member(data=raw_member, guild=self, state=state)
|
||||
|
||||
async def fetch_member(self, member_id: int, /) -> Member:
|
||||
"""|coro|
|
||||
@ -2731,18 +2754,17 @@ class Guild(Hashable):
|
||||
payload['uses'] = payload.get('uses', 0)
|
||||
return Invite(state=self._state, data=payload, guild=self, channel=channel)
|
||||
|
||||
# TODO: use MISSING when async iterators get refactored
|
||||
def audit_logs(
|
||||
async def audit_logs(
|
||||
self,
|
||||
*,
|
||||
limit: int = 100,
|
||||
before: Optional[SnowflakeTime] = None,
|
||||
after: Optional[SnowflakeTime] = None,
|
||||
oldest_first: Optional[bool] = None,
|
||||
user: Snowflake = None,
|
||||
action: AuditLogAction = None,
|
||||
) -> AuditLogIterator:
|
||||
"""Returns an :class:`AsyncIterator` that enables receiving the guild's audit logs.
|
||||
user: Snowflake = MISSING,
|
||||
action: AuditLogAction = MISSING,
|
||||
) -> AsyncIterator[AuditLogEntry]:
|
||||
"""Returns an :term:`asynchronous iterator` that enables receiving the guild's audit logs.
|
||||
|
||||
You must have the :attr:`~Permissions.view_audit_log` permission to use this.
|
||||
|
||||
@ -2761,7 +2783,7 @@ class Guild(Hashable):
|
||||
|
||||
Getting entries made by a specific user: ::
|
||||
|
||||
entries = await guild.audit_logs(limit=None, user=guild.me).flatten()
|
||||
entries = [entry async for entry in guild.audit_logs(limit=None, user=guild.me)]
|
||||
await channel.send(f'I made {len(entries)} moderation actions.')
|
||||
|
||||
Parameters
|
||||
@ -2796,6 +2818,39 @@ class Guild(Hashable):
|
||||
:class:`AuditLogEntry`
|
||||
The audit log entry.
|
||||
"""
|
||||
|
||||
async def _before_strategy(retrieve, before, limit):
|
||||
before_id = before.id if before else None
|
||||
data = await self._state.http.get_audit_logs(
|
||||
self.id, limit=retrieve, user_id=user_id, action_type=action, before=before_id
|
||||
)
|
||||
|
||||
entries = data.get('audit_log_entries', [])
|
||||
|
||||
if data and entries:
|
||||
if limit is not None:
|
||||
limit -= len(data)
|
||||
|
||||
before = Object(id=int(entries[-1]['id']))
|
||||
|
||||
return data.get('users', []), entries, before, limit
|
||||
|
||||
async def _after_strategy(retrieve, after, limit):
|
||||
after_id = after.id if after else None
|
||||
data = await self._state.http.get_audit_logs(
|
||||
self.id, limit=retrieve, user_id=user_id, action_type=action, after=after_id
|
||||
)
|
||||
|
||||
entries = data.get('audit_log_entries', [])
|
||||
|
||||
if data and entries:
|
||||
if limit is not None:
|
||||
limit -= len(data)
|
||||
|
||||
after = Object(id=int(entries[0]['id']))
|
||||
|
||||
return data.get('users', []), entries, after, limit
|
||||
|
||||
if user is not None:
|
||||
user_id = user.id
|
||||
else:
|
||||
@ -2804,9 +2859,53 @@ class Guild(Hashable):
|
||||
if action:
|
||||
action = action.value
|
||||
|
||||
return AuditLogIterator(
|
||||
self, before=before, after=after, limit=limit, oldest_first=oldest_first, user_id=user_id, action_type=action
|
||||
)
|
||||
if isinstance(before, datetime.datetime):
|
||||
before = Object(id=utils.time_snowflake(before, high=False))
|
||||
if isinstance(after, datetime.datetime):
|
||||
after = Object(id=utils.time_snowflake(after, high=True))
|
||||
|
||||
|
||||
if oldest_first is None:
|
||||
reverse = after is not None
|
||||
else:
|
||||
reverse = oldest_first
|
||||
|
||||
predicate = None
|
||||
|
||||
if reverse:
|
||||
strategy, state = _after_strategy, after
|
||||
if before:
|
||||
predicate = lambda m: int(m['id']) < before.id
|
||||
else:
|
||||
strategy, state = _before_strategy, before
|
||||
if after and after != OLDEST_OBJECT:
|
||||
predicate = lambda m: int(m['id']) > after.id
|
||||
|
||||
while True:
|
||||
retrieve = min(100 if limit is None else limit, 100)
|
||||
if retrieve < 1:
|
||||
return
|
||||
|
||||
raw_users, data, state, limit = await strategy(retrieve, state, limit)
|
||||
|
||||
# Terminate loop on next iteration; there's no data left after this
|
||||
if len(data) < 100:
|
||||
limit = 0
|
||||
|
||||
if reverse:
|
||||
data = reversed(data)
|
||||
if predicate:
|
||||
data = filter(predicate, data)
|
||||
|
||||
users = (User(data=raw_user, state=self._state) for raw_user in raw_users)
|
||||
user_map = {user.id: user for user in users}
|
||||
|
||||
for raw_entry in data:
|
||||
# Weird Discord quirk
|
||||
if raw_entry['action_type'] is None:
|
||||
continue
|
||||
|
||||
yield AuditLogEntry(data=raw_entry, users=user_map, guild=self)
|
||||
|
||||
async def widget(self) -> Widget:
|
||||
"""|coro|
|
||||
|
@ -1,753 +0,0 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-present Rapptz
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator
|
||||
|
||||
from .errors import NoMoreItems
|
||||
from .utils import snowflake_time, time_snowflake, maybe_coroutine
|
||||
from .object import Object
|
||||
from .audit_logs import AuditLogEntry
|
||||
|
||||
__all__ = (
|
||||
'ReactionIterator',
|
||||
'HistoryIterator',
|
||||
'AuditLogIterator',
|
||||
'GuildIterator',
|
||||
'MemberIterator',
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .types.audit_log import (
|
||||
AuditLog as AuditLogPayload,
|
||||
)
|
||||
from .types.guild import (
|
||||
Guild as GuildPayload,
|
||||
)
|
||||
from .types.message import (
|
||||
Message as MessagePayload,
|
||||
)
|
||||
from .types.user import (
|
||||
PartialUser as PartialUserPayload,
|
||||
)
|
||||
|
||||
from .types.threads import (
|
||||
Thread as ThreadPayload,
|
||||
)
|
||||
|
||||
from .member import Member
|
||||
from .user import User
|
||||
from .message import Message
|
||||
from .audit_logs import AuditLogEntry
|
||||
from .guild import Guild
|
||||
from .threads import Thread
|
||||
from .abc import Snowflake
|
||||
|
||||
T = TypeVar('T')
|
||||
OT = TypeVar('OT')
|
||||
_Func = Callable[[T], Union[OT, Awaitable[OT]]]
|
||||
|
||||
OLDEST_OBJECT = Object(id=0)
|
||||
|
||||
|
||||
class _AsyncIterator(AsyncIterator[T]):
|
||||
__slots__ = ()
|
||||
|
||||
async def next(self) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
def get(self, **attrs: Any) -> Awaitable[Optional[T]]:
|
||||
def predicate(elem: T):
|
||||
for attr, val in attrs.items():
|
||||
nested = attr.split('__')
|
||||
obj = elem
|
||||
for attribute in nested:
|
||||
obj = getattr(obj, attribute)
|
||||
|
||||
if obj != val:
|
||||
return False
|
||||
return True
|
||||
|
||||
return self.find(predicate)
|
||||
|
||||
async def find(self, predicate: _Func[T, bool]) -> Optional[T]:
|
||||
while True:
|
||||
try:
|
||||
elem = await self.next()
|
||||
except NoMoreItems:
|
||||
return None
|
||||
|
||||
ret = await maybe_coroutine(predicate, elem)
|
||||
if ret:
|
||||
return elem
|
||||
|
||||
def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]:
|
||||
if max_size <= 0:
|
||||
raise ValueError('async iterator chunk sizes must be greater than 0.')
|
||||
return _ChunkedAsyncIterator(self, max_size)
|
||||
|
||||
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
|
||||
return _MappedAsyncIterator(self, func)
|
||||
|
||||
def filter(self, predicate: _Func[T, bool]) -> _FilteredAsyncIterator[T]:
|
||||
return _FilteredAsyncIterator(self, predicate)
|
||||
|
||||
async def flatten(self) -> List[T]:
|
||||
return [element async for element in self]
|
||||
|
||||
async def __anext__(self) -> T:
|
||||
try:
|
||||
return await self.next()
|
||||
except NoMoreItems:
|
||||
raise StopAsyncIteration()
|
||||
|
||||
|
||||
def _identity(x):
|
||||
return x
|
||||
|
||||
|
||||
class _ChunkedAsyncIterator(_AsyncIterator[List[T]]):
|
||||
def __init__(self, iterator, max_size):
|
||||
self.iterator = iterator
|
||||
self.max_size = max_size
|
||||
|
||||
async def next(self) -> List[T]:
|
||||
ret: List[T] = []
|
||||
n = 0
|
||||
while n < self.max_size:
|
||||
try:
|
||||
item = await self.iterator.next()
|
||||
except NoMoreItems:
|
||||
if ret:
|
||||
return ret
|
||||
raise
|
||||
else:
|
||||
ret.append(item)
|
||||
n += 1
|
||||
return ret
|
||||
|
||||
|
||||
class _MappedAsyncIterator(_AsyncIterator[T]):
|
||||
def __init__(self, iterator, func):
|
||||
self.iterator = iterator
|
||||
self.func = func
|
||||
|
||||
async def next(self) -> T:
|
||||
# this raises NoMoreItems and will propagate appropriately
|
||||
item = await self.iterator.next()
|
||||
return await maybe_coroutine(self.func, item)
|
||||
|
||||
|
||||
class _FilteredAsyncIterator(_AsyncIterator[T]):
|
||||
def __init__(self, iterator, predicate):
|
||||
self.iterator = iterator
|
||||
|
||||
if predicate is None:
|
||||
predicate = _identity
|
||||
|
||||
self.predicate = predicate
|
||||
|
||||
async def next(self) -> T:
|
||||
getter = self.iterator.next
|
||||
pred = self.predicate
|
||||
while True:
|
||||
# propagate NoMoreItems similar to _MappedAsyncIterator
|
||||
item = await getter()
|
||||
ret = await maybe_coroutine(pred, item)
|
||||
if ret:
|
||||
return item
|
||||
|
||||
|
||||
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
|
||||
def __init__(self, message, emoji, limit=100, after=None):
|
||||
self.message = message
|
||||
self.limit = limit
|
||||
self.after = after
|
||||
state = message._state
|
||||
self.getter = state.http.get_reaction_users
|
||||
self.state = state
|
||||
self.emoji = emoji
|
||||
self.guild = message.guild
|
||||
self.channel_id = message.channel.id
|
||||
self.users = asyncio.Queue()
|
||||
|
||||
async def next(self) -> Union[User, Member]:
|
||||
if self.users.empty():
|
||||
await self.fill_users()
|
||||
|
||||
try:
|
||||
return self.users.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
raise NoMoreItems()
|
||||
|
||||
async def fill_users(self):
|
||||
# this is a hack because >circular imports<
|
||||
from .user import User
|
||||
|
||||
if self.limit > 0:
|
||||
retrieve = self.limit if self.limit <= 100 else 100
|
||||
|
||||
after = self.after.id if self.after else None
|
||||
data: List[PartialUserPayload] = await self.getter(
|
||||
self.channel_id, self.message.id, self.emoji, retrieve, after=after
|
||||
)
|
||||
|
||||
if data:
|
||||
self.limit -= retrieve
|
||||
self.after = Object(id=int(data[-1]['id']))
|
||||
|
||||
if self.guild is None or isinstance(self.guild, Object):
|
||||
for element in reversed(data):
|
||||
await self.users.put(User(state=self.state, data=element))
|
||||
else:
|
||||
for element in reversed(data):
|
||||
member_id = int(element['id'])
|
||||
member = self.guild.get_member(member_id)
|
||||
if member is not None:
|
||||
await self.users.put(member)
|
||||
else:
|
||||
await self.users.put(User(state=self.state, data=element))
|
||||
|
||||
|
||||
class HistoryIterator(_AsyncIterator['Message']):
|
||||
"""Iterator for receiving a channel's message history.
|
||||
|
||||
The messages endpoint has two behaviours we care about here:
|
||||
If ``before`` is specified, the messages endpoint returns the `limit`
|
||||
newest messages before ``before``, sorted with newest first. For filling over
|
||||
100 messages, update the ``before`` parameter to the oldest message received.
|
||||
Messages will be returned in order by time.
|
||||
If ``after`` is specified, it returns the ``limit`` oldest messages after
|
||||
``after``, sorted with newest first. For filling over 100 messages, update the
|
||||
``after`` parameter to the newest message received. If messages are not
|
||||
reversed, they will be out of order (99-0, 199-100, so on)
|
||||
|
||||
A note that if both ``before`` and ``after`` are specified, ``before`` is ignored by the
|
||||
messages endpoint.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
messageable: :class:`abc.Messageable`
|
||||
Messageable class to retrieve message history from.
|
||||
limit: :class:`int`
|
||||
Maximum number of messages to retrieve
|
||||
before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
|
||||
Message before which all messages must be.
|
||||
after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
|
||||
Message after which all messages must be.
|
||||
around: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
|
||||
Message around which all messages must be. Limit max 101. Note that if
|
||||
limit is an even number, this will return at most limit+1 messages.
|
||||
oldest_first: Optional[:class:`bool`]
|
||||
If set to ``True``, return messages in oldest->newest order. Defaults to
|
||||
``True`` if `after` is specified, otherwise ``False``.
|
||||
"""
|
||||
|
||||
def __init__(self, messageable, limit, before=None, after=None, around=None, oldest_first=None):
|
||||
|
||||
if isinstance(before, datetime.datetime):
|
||||
before = Object(id=time_snowflake(before, high=False))
|
||||
if isinstance(after, datetime.datetime):
|
||||
after = Object(id=time_snowflake(after, high=True))
|
||||
if isinstance(around, datetime.datetime):
|
||||
around = Object(id=time_snowflake(around))
|
||||
|
||||
if oldest_first is None:
|
||||
self.reverse = after is not None
|
||||
else:
|
||||
self.reverse = oldest_first
|
||||
|
||||
self.messageable = messageable
|
||||
self.limit = limit
|
||||
self.before = before
|
||||
self.after = after or OLDEST_OBJECT
|
||||
self.around = around
|
||||
|
||||
self._filter = None # message dict -> bool
|
||||
|
||||
self.state = self.messageable._state
|
||||
self.logs_from = self.state.http.logs_from
|
||||
self.messages = asyncio.Queue()
|
||||
|
||||
if self.around:
|
||||
if self.limit is None:
|
||||
raise ValueError('history does not support around with limit=None')
|
||||
if self.limit > 101:
|
||||
raise ValueError("history max limit 101 when specifying around parameter")
|
||||
elif self.limit == 101:
|
||||
self.limit = 100 # Thanks discord
|
||||
|
||||
self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore
|
||||
if self.before and self.after:
|
||||
self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
|
||||
elif self.before:
|
||||
self._filter = lambda m: int(m['id']) < self.before.id
|
||||
elif self.after:
|
||||
self._filter = lambda m: self.after.id < int(m['id'])
|
||||
else:
|
||||
if self.reverse:
|
||||
self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore
|
||||
if self.before:
|
||||
self._filter = lambda m: int(m['id']) < self.before.id
|
||||
else:
|
||||
self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore
|
||||
if self.after and self.after != OLDEST_OBJECT:
|
||||
self._filter = lambda m: int(m['id']) > self.after.id
|
||||
|
||||
async def next(self) -> Message:
|
||||
if self.messages.empty():
|
||||
await self.fill_messages()
|
||||
|
||||
try:
|
||||
return self.messages.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
raise NoMoreItems()
|
||||
|
||||
def _get_retrieve(self):
|
||||
l = self.limit
|
||||
if l is None or l > 100:
|
||||
r = 100
|
||||
else:
|
||||
r = l
|
||||
self.retrieve = r
|
||||
return r > 0
|
||||
|
||||
async def fill_messages(self):
|
||||
if not hasattr(self, 'channel'):
|
||||
# do the required set up
|
||||
channel = await self.messageable._get_channel()
|
||||
self.channel = channel
|
||||
|
||||
if self._get_retrieve():
|
||||
data = await self._retrieve_messages(self.retrieve)
|
||||
if len(data) < 100:
|
||||
self.limit = 0 # terminate the infinite loop
|
||||
|
||||
if self.reverse:
|
||||
data = reversed(data)
|
||||
if self._filter:
|
||||
data = filter(self._filter, data)
|
||||
|
||||
channel = self.channel
|
||||
for element in data:
|
||||
await self.messages.put(self.state.create_message(channel=channel, data=element))
|
||||
|
||||
async def _retrieve_messages(self, retrieve) -> List[Message]:
|
||||
"""Retrieve messages and update next parameters."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _retrieve_messages_before_strategy(self, retrieve):
|
||||
"""Retrieve messages using before parameter."""
|
||||
before = self.before.id if self.before else None
|
||||
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, before=before)
|
||||
if len(data):
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
self.before = Object(id=int(data[-1]['id']))
|
||||
return data
|
||||
|
||||
async def _retrieve_messages_after_strategy(self, retrieve):
|
||||
"""Retrieve messages using after parameter."""
|
||||
after = self.after.id if self.after else None
|
||||
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, after=after)
|
||||
if len(data):
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
self.after = Object(id=int(data[0]['id']))
|
||||
return data
|
||||
|
||||
async def _retrieve_messages_around_strategy(self, retrieve):
|
||||
"""Retrieve messages using around parameter."""
|
||||
if self.around:
|
||||
around = self.around.id if self.around else None
|
||||
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, around=around)
|
||||
self.around = None
|
||||
return data
|
||||
return []
|
||||
|
||||
|
||||
class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
|
||||
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
|
||||
if isinstance(before, datetime.datetime):
|
||||
before = Object(id=time_snowflake(before, high=False))
|
||||
if isinstance(after, datetime.datetime):
|
||||
after = Object(id=time_snowflake(after, high=True))
|
||||
|
||||
if oldest_first is None:
|
||||
self.reverse = after is not None
|
||||
else:
|
||||
self.reverse = oldest_first
|
||||
|
||||
self.guild = guild
|
||||
self.loop = guild._state.loop
|
||||
self.request = guild._state.http.get_audit_logs
|
||||
self.limit = limit
|
||||
self.before = before
|
||||
self.user_id = user_id
|
||||
self.action_type = action_type
|
||||
self.after = OLDEST_OBJECT
|
||||
self._users = {}
|
||||
self._state = guild._state
|
||||
|
||||
self._filter = None # entry dict -> bool
|
||||
|
||||
self.entries = asyncio.Queue()
|
||||
|
||||
if self.reverse:
|
||||
self._strategy = self._after_strategy
|
||||
if self.before:
|
||||
self._filter = lambda m: int(m['id']) < self.before.id
|
||||
else:
|
||||
self._strategy = self._before_strategy
|
||||
if self.after and self.after != OLDEST_OBJECT:
|
||||
self._filter = lambda m: int(m['id']) > self.after.id
|
||||
|
||||
async def _before_strategy(self, retrieve):
|
||||
before = self.before.id if self.before else None
|
||||
data: AuditLogPayload = await self.request(
|
||||
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before
|
||||
)
|
||||
|
||||
entries = data.get('audit_log_entries', [])
|
||||
if len(data) and entries:
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
self.before = Object(id=int(entries[-1]['id']))
|
||||
return data.get('users', []), entries
|
||||
|
||||
async def _after_strategy(self, retrieve):
|
||||
after = self.after.id if self.after else None
|
||||
data: AuditLogPayload = await self.request(
|
||||
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after
|
||||
)
|
||||
entries = data.get('audit_log_entries', [])
|
||||
if len(data) and entries:
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
self.after = Object(id=int(entries[0]['id']))
|
||||
return data.get('users', []), entries
|
||||
|
||||
async def next(self) -> AuditLogEntry:
|
||||
if self.entries.empty():
|
||||
await self._fill()
|
||||
|
||||
try:
|
||||
return self.entries.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
raise NoMoreItems()
|
||||
|
||||
def _get_retrieve(self):
|
||||
l = self.limit
|
||||
if l is None or l > 100:
|
||||
r = 100
|
||||
else:
|
||||
r = l
|
||||
self.retrieve = r
|
||||
return r > 0
|
||||
|
||||
async def _fill(self):
|
||||
from .user import User
|
||||
|
||||
if self._get_retrieve():
|
||||
users, data = await self._strategy(self.retrieve)
|
||||
if len(data) < 100:
|
||||
self.limit = 0 # terminate the infinite loop
|
||||
|
||||
if self.reverse:
|
||||
data = reversed(data)
|
||||
if self._filter:
|
||||
data = filter(self._filter, data)
|
||||
|
||||
for user in users:
|
||||
u = User(data=user, state=self._state)
|
||||
self._users[u.id] = u
|
||||
|
||||
for element in data:
|
||||
# TODO: remove this if statement later
|
||||
if element['action_type'] is None:
|
||||
continue
|
||||
|
||||
await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild))
|
||||
|
||||
|
||||
class GuildIterator(_AsyncIterator['Guild']):
|
||||
"""Iterator for receiving the client's guilds.
|
||||
|
||||
The guilds endpoint has the same two behaviours as described
|
||||
in :class:`HistoryIterator`:
|
||||
If ``before`` is specified, the guilds endpoint returns the ``limit``
|
||||
newest guilds before ``before``, sorted with newest first. For filling over
|
||||
100 guilds, update the ``before`` parameter to the oldest guild received.
|
||||
Guilds will be returned in order by time.
|
||||
If `after` is specified, it returns the ``limit`` oldest guilds after ``after``,
|
||||
sorted with newest first. For filling over 100 guilds, update the ``after``
|
||||
parameter to the newest guild received, If guilds are not reversed, they
|
||||
will be out of order (99-0, 199-100, so on)
|
||||
|
||||
Not that if both ``before`` and ``after`` are specified, ``before`` is ignored by the
|
||||
guilds endpoint.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
bot: :class:`discord.Client`
|
||||
The client to retrieve the guilds from.
|
||||
limit: :class:`int`
|
||||
Maximum number of guilds to retrieve.
|
||||
before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
|
||||
Object before which all guilds must be.
|
||||
after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
|
||||
Object after which all guilds must be.
|
||||
"""
|
||||
|
||||
def __init__(self, bot, limit, before=None, after=None):
|
||||
|
||||
if isinstance(before, datetime.datetime):
|
||||
before = Object(id=time_snowflake(before, high=False))
|
||||
if isinstance(after, datetime.datetime):
|
||||
after = Object(id=time_snowflake(after, high=True))
|
||||
|
||||
self.bot = bot
|
||||
self.limit = limit
|
||||
self.before = before
|
||||
self.after = after
|
||||
|
||||
self._filter = None
|
||||
|
||||
self.state = self.bot._connection
|
||||
self.get_guilds = self.bot.http.get_guilds
|
||||
self.guilds = asyncio.Queue()
|
||||
|
||||
if self.before and self.after:
|
||||
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore
|
||||
self._filter = lambda m: int(m['id']) > self.after.id
|
||||
elif self.after:
|
||||
self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore
|
||||
else:
|
||||
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore
|
||||
|
||||
async def next(self) -> Guild:
|
||||
if self.guilds.empty():
|
||||
await self.fill_guilds()
|
||||
|
||||
try:
|
||||
return self.guilds.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
raise NoMoreItems()
|
||||
|
||||
def _get_retrieve(self):
|
||||
l = self.limit
|
||||
if l is None or l > 100:
|
||||
r = 100
|
||||
else:
|
||||
r = l
|
||||
self.retrieve = r
|
||||
return r > 0
|
||||
|
||||
def create_guild(self, data):
|
||||
from .guild import Guild
|
||||
|
||||
return Guild(state=self.state, data=data)
|
||||
|
||||
async def fill_guilds(self):
|
||||
if self._get_retrieve():
|
||||
data = await self._retrieve_guilds(self.retrieve)
|
||||
if self.limit is None or len(data) < 100:
|
||||
self.limit = 0
|
||||
|
||||
if self._filter:
|
||||
data = filter(self._filter, data)
|
||||
|
||||
for element in data:
|
||||
await self.guilds.put(self.create_guild(element))
|
||||
|
||||
async def _retrieve_guilds(self, retrieve) -> List[Guild]:
|
||||
"""Retrieve guilds and update next parameters."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _retrieve_guilds_before_strategy(self, retrieve):
|
||||
"""Retrieve guilds using before parameter."""
|
||||
before = self.before.id if self.before else None
|
||||
data: List[GuildPayload] = await self.get_guilds(retrieve, before=before)
|
||||
if len(data):
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
self.before = Object(id=int(data[-1]['id']))
|
||||
return data
|
||||
|
||||
async def _retrieve_guilds_after_strategy(self, retrieve):
|
||||
"""Retrieve guilds using after parameter."""
|
||||
after = self.after.id if self.after else None
|
||||
data: List[GuildPayload] = await self.get_guilds(retrieve, after=after)
|
||||
if len(data):
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
self.after = Object(id=int(data[0]['id']))
|
||||
return data
|
||||
|
||||
|
||||
class MemberIterator(_AsyncIterator['Member']):
|
||||
def __init__(self, guild, limit=1000, after=None):
|
||||
|
||||
if isinstance(after, datetime.datetime):
|
||||
after = Object(id=time_snowflake(after, high=True))
|
||||
|
||||
self.guild = guild
|
||||
self.limit = limit
|
||||
self.after = after or OLDEST_OBJECT
|
||||
|
||||
self.state = self.guild._state
|
||||
self.get_members = self.state.http.get_members
|
||||
self.members = asyncio.Queue()
|
||||
|
||||
async def next(self) -> Member:
|
||||
if self.members.empty():
|
||||
await self.fill_members()
|
||||
|
||||
try:
|
||||
return self.members.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
raise NoMoreItems()
|
||||
|
||||
def _get_retrieve(self):
|
||||
l = self.limit
|
||||
if l is None or l > 1000:
|
||||
r = 1000
|
||||
else:
|
||||
r = l
|
||||
self.retrieve = r
|
||||
return r > 0
|
||||
|
||||
async def fill_members(self):
|
||||
if self._get_retrieve():
|
||||
after = self.after.id if self.after else None
|
||||
data = await self.get_members(self.guild.id, self.retrieve, after)
|
||||
if not data:
|
||||
# no data, terminate
|
||||
return
|
||||
|
||||
if len(data) < 1000:
|
||||
self.limit = 0 # terminate loop
|
||||
|
||||
self.after = Object(id=int(data[-1]['user']['id']))
|
||||
|
||||
for element in reversed(data):
|
||||
await self.members.put(self.create_member(element))
|
||||
|
||||
def create_member(self, data):
|
||||
from .member import Member
|
||||
|
||||
return Member(data=data, guild=self.guild, state=self.state)
|
||||
|
||||
|
||||
class ArchivedThreadIterator(_AsyncIterator['Thread']):
|
||||
def __init__(
|
||||
self,
|
||||
channel_id: int,
|
||||
guild: Guild,
|
||||
limit: Optional[int],
|
||||
joined: bool,
|
||||
private: bool,
|
||||
before: Optional[Union[Snowflake, datetime.datetime]] = None,
|
||||
):
|
||||
self.channel_id = channel_id
|
||||
self.guild = guild
|
||||
self.limit = limit
|
||||
self.joined = joined
|
||||
self.private = private
|
||||
self.http = guild._state.http
|
||||
|
||||
if joined and not private:
|
||||
raise ValueError('Cannot iterate over joined public archived threads')
|
||||
|
||||
self.before: Optional[str]
|
||||
if before is None:
|
||||
self.before = None
|
||||
elif isinstance(before, datetime.datetime):
|
||||
if joined:
|
||||
self.before = str(time_snowflake(before, high=False))
|
||||
else:
|
||||
self.before = before.isoformat()
|
||||
else:
|
||||
if joined:
|
||||
self.before = str(before.id)
|
||||
else:
|
||||
self.before = snowflake_time(before.id).isoformat()
|
||||
|
||||
self.update_before: Callable[[ThreadPayload], str] = self.get_archive_timestamp
|
||||
|
||||
if joined:
|
||||
self.endpoint = self.http.get_joined_private_archived_threads
|
||||
self.update_before = self.get_thread_id
|
||||
elif private:
|
||||
self.endpoint = self.http.get_private_archived_threads
|
||||
else:
|
||||
self.endpoint = self.http.get_public_archived_threads
|
||||
|
||||
self.queue: asyncio.Queue[Thread] = asyncio.Queue()
|
||||
self.has_more: bool = True
|
||||
|
||||
async def next(self) -> Thread:
|
||||
if self.queue.empty():
|
||||
await self.fill_queue()
|
||||
|
||||
try:
|
||||
return self.queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
raise NoMoreItems()
|
||||
|
||||
@staticmethod
|
||||
def get_archive_timestamp(data: ThreadPayload) -> str:
|
||||
return data['thread_metadata']['archive_timestamp']
|
||||
|
||||
@staticmethod
|
||||
def get_thread_id(data: ThreadPayload) -> str:
|
||||
return data['id'] # type: ignore
|
||||
|
||||
async def fill_queue(self) -> None:
|
||||
if not self.has_more:
|
||||
raise NoMoreItems()
|
||||
|
||||
limit = 50 if self.limit is None else max(self.limit, 50)
|
||||
data = await self.endpoint(self.channel_id, before=self.before, limit=limit)
|
||||
|
||||
# This stuff is obviously WIP because 'members' is always empty
|
||||
threads: List[ThreadPayload] = data.get('threads', [])
|
||||
for d in reversed(threads):
|
||||
self.queue.put_nowait(self.create_thread(d))
|
||||
|
||||
self.has_more = data.get('has_more', False)
|
||||
if self.limit is not None:
|
||||
self.limit -= len(threads)
|
||||
if self.limit <= 0:
|
||||
self.has_more = False
|
||||
|
||||
if self.has_more:
|
||||
self.before = self.update_before(threads[-1])
|
||||
|
||||
def create_thread(self, data: ThreadPayload) -> Thread:
|
||||
from .threads import Thread
|
||||
return Thread(guild=self.guild, state=self.guild._state, data=data)
|
@ -24,8 +24,8 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from . import utils
|
||||
from .mixins import Hashable
|
||||
from .utils import snowflake_time
|
||||
|
||||
from typing import (
|
||||
SupportsInt,
|
||||
@ -89,4 +89,7 @@ class Object(Hashable):
|
||||
@property
|
||||
def created_at(self) -> datetime.datetime:
|
||||
""":class:`datetime.datetime`: Returns the snowflake's creation time in UTC."""
|
||||
return utils.snowflake_time(self.id)
|
||||
return snowflake_time(self.id)
|
||||
|
||||
|
||||
OLDEST_OBJECT = Object(id=0)
|
||||
|
@ -23,15 +23,18 @@ DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Any, TYPE_CHECKING, Union, Optional
|
||||
from typing import Any, TYPE_CHECKING, AsyncIterator, List, Union, Optional
|
||||
from typing_extensions import reveal_type
|
||||
|
||||
from .iterators import ReactionIterator
|
||||
from .object import Object
|
||||
|
||||
__all__ = (
|
||||
'Reaction',
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
from .member import Member
|
||||
from .types.message import Reaction as ReactionPayload
|
||||
from .message import Message
|
||||
from .partial_emoji import PartialEmoji
|
||||
@ -155,8 +158,8 @@ class Reaction:
|
||||
"""
|
||||
await self.message.clear_reaction(self.emoji)
|
||||
|
||||
def users(self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None) -> ReactionIterator:
|
||||
"""Returns an :class:`AsyncIterator` representing the users that have reacted to the message.
|
||||
async def users(self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None) -> AsyncIterator[Union[Member, User]]:
|
||||
"""Returns an :term:`asynchronous iterator` representing the users that have reacted to the message.
|
||||
|
||||
The ``after`` parameter must represent a member
|
||||
and meet the :class:`abc.Snowflake` abc.
|
||||
@ -176,7 +179,7 @@ class Reaction:
|
||||
|
||||
Flattening into a list: ::
|
||||
|
||||
users = await reaction.users().flatten()
|
||||
users = [user async for user in reaction.users()]
|
||||
# users is now a list of User...
|
||||
winner = random.choice(users)
|
||||
await channel.send(f'{winner} has won the raffle.')
|
||||
@ -212,4 +215,31 @@ class Reaction:
|
||||
if limit is None:
|
||||
limit = self.count
|
||||
|
||||
return ReactionIterator(self.message, emoji, limit, after)
|
||||
while limit > 0:
|
||||
retrieve = min(limit, 100)
|
||||
|
||||
message = self.message
|
||||
guild = message.guild
|
||||
state = message._state
|
||||
after_id = after.id if after else None
|
||||
|
||||
data = await state.http.get_reaction_users(
|
||||
message.channel.id, message.id, emoji, retrieve, after=after_id
|
||||
)
|
||||
|
||||
if data:
|
||||
limit -= len(data)
|
||||
after = Object(id=int(data[-1]['id']))
|
||||
|
||||
if guild is None or isinstance(guild, Object):
|
||||
for raw_user in reversed(data):
|
||||
yield User(state=state, data=raw_user)
|
||||
|
||||
continue
|
||||
|
||||
for raw_user in reversed(data):
|
||||
member_id = int(raw_user['id'])
|
||||
member = guild.get_member(member_id)
|
||||
|
||||
yield member or User(state=state, data=raw_user)
|
||||
|
||||
|
129
docs/api.rst
129
docs/api.rst
@ -2614,135 +2614,6 @@ of :class:`enum.Enum`.
|
||||
|
||||
The guild may contain NSFW content.
|
||||
|
||||
Async Iterator
|
||||
----------------
|
||||
|
||||
Some API functions return an "async iterator". An async iterator is something that is
|
||||
capable of being used in an :ref:`async for statement <py:async for>`.
|
||||
|
||||
These async iterators can be used as follows: ::
|
||||
|
||||
async for elem in channel.history():
|
||||
# do stuff with elem here
|
||||
|
||||
Certain utilities make working with async iterators easier, detailed below.
|
||||
|
||||
.. class:: AsyncIterator
|
||||
|
||||
Represents the "AsyncIterator" concept. Note that no such class exists,
|
||||
it is purely abstract.
|
||||
|
||||
.. container:: operations
|
||||
|
||||
.. describe:: async for x in y
|
||||
|
||||
Iterates over the contents of the async iterator.
|
||||
|
||||
|
||||
.. method:: next()
|
||||
:async:
|
||||
|
||||
|coro|
|
||||
|
||||
Advances the iterator by one, if possible. If no more items are found
|
||||
then this raises :exc:`NoMoreItems`.
|
||||
|
||||
.. method:: get(**attrs)
|
||||
:async:
|
||||
|
||||
|coro|
|
||||
|
||||
Similar to :func:`utils.get` except run over the async iterator.
|
||||
|
||||
Getting the last message by a user named 'Dave' or ``None``: ::
|
||||
|
||||
msg = await channel.history().get(author__name='Dave')
|
||||
|
||||
.. method:: find(predicate)
|
||||
:async:
|
||||
|
||||
|coro|
|
||||
|
||||
Similar to :func:`utils.find` except run over the async iterator.
|
||||
|
||||
Unlike :func:`utils.find`\, the predicate provided can be a
|
||||
|coroutine_link|_.
|
||||
|
||||
Getting the last audit log with a reason or ``None``: ::
|
||||
|
||||
def predicate(event):
|
||||
return event.reason is not None
|
||||
|
||||
event = await guild.audit_logs().find(predicate)
|
||||
|
||||
:param predicate: The predicate to use. Could be a |coroutine_link|_.
|
||||
:return: The first element that returns ``True`` for the predicate or ``None``.
|
||||
|
||||
.. method:: flatten()
|
||||
:async:
|
||||
|
||||
|coro|
|
||||
|
||||
Flattens the async iterator into a :class:`list` with all the elements.
|
||||
|
||||
:return: A list of every element in the async iterator.
|
||||
:rtype: list
|
||||
|
||||
.. method:: chunk(max_size)
|
||||
|
||||
Collects items into chunks of up to a given maximum size.
|
||||
Another :class:`AsyncIterator` is returned which collects items into
|
||||
:class:`list`\s of a given size. The maximum chunk size must be a positive integer.
|
||||
|
||||
.. versionadded:: 1.6
|
||||
|
||||
Collecting groups of users: ::
|
||||
|
||||
async for leader, *users in reaction.users().chunk(3):
|
||||
...
|
||||
|
||||
.. warning::
|
||||
|
||||
The last chunk collected may not be as large as ``max_size``.
|
||||
|
||||
:param max_size: The size of individual chunks.
|
||||
:rtype: :class:`AsyncIterator`
|
||||
|
||||
.. method:: map(func)
|
||||
|
||||
This is similar to the built-in :func:`map <py:map>` function. Another
|
||||
:class:`AsyncIterator` is returned that executes the function on
|
||||
every element it is iterating over. This function can either be a
|
||||
regular function or a |coroutine_link|_.
|
||||
|
||||
Creating a content iterator: ::
|
||||
|
||||
def transform(message):
|
||||
return message.content
|
||||
|
||||
async for content in channel.history().map(transform):
|
||||
message_length = len(content)
|
||||
|
||||
:param func: The function to call on every element. Could be a |coroutine_link|_.
|
||||
:rtype: :class:`AsyncIterator`
|
||||
|
||||
.. method:: filter(predicate)
|
||||
|
||||
This is similar to the built-in :func:`filter <py:filter>` function. Another
|
||||
:class:`AsyncIterator` is returned that filters over the original
|
||||
async iterator. This predicate can be a regular function or a |coroutine_link|_.
|
||||
|
||||
Getting messages by non-bot accounts: ::
|
||||
|
||||
def predicate(message):
|
||||
return not message.author.bot
|
||||
|
||||
async for elem in channel.history().filter(predicate):
|
||||
...
|
||||
|
||||
:param predicate: The predicate to call on every element. Could be a |coroutine_link|_.
|
||||
:rtype: :class:`AsyncIterator`
|
||||
|
||||
.. _discord-api-audit-logs:
|
||||
|
||||
Audit Log Data
|
||||
|
Loading…
x
Reference in New Issue
Block a user