Refactor AsyncIter to use 3.6+ asynchronous generators

This commit is contained in:
Kaylynn Morgan
2022-02-20 13:58:13 +11:00
committed by GitHub
parent dc19c6c7d5
commit 588cda0996
8 changed files with 386 additions and 930 deletions

View File

@@ -25,9 +25,11 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import copy
import datetime
import unicodedata
from typing import (
Any,
AsyncIterator,
ClassVar,
Dict,
List,
@@ -67,7 +69,6 @@ from .enums import (
from .mixins import Hashable
from .user import User
from .invite import Invite
from .iterators import AuditLogIterator, MemberIterator
from .widget import Widget
from .asset import Asset
from .flags import SystemChannelFlags
@@ -76,6 +77,8 @@ from .stage_instance import StageInstance
from .threads import Thread, ThreadMember
from .sticker import GuildSticker
from .file import File
from .audit_logs import AuditLogEntry
from .object import OLDEST_OBJECT, Object
__all__ = (
@@ -98,8 +101,6 @@ if TYPE_CHECKING:
from .state import ConnectionState
from .voice_client import VoiceProtocol
import datetime
VocalGuildChannel = Union[VoiceChannel, StageChannel]
GuildChannel = Union[VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel]
ByCategoryItem = Tuple[Optional[CategoryChannel], List[GuildChannel]]
@@ -1649,9 +1650,8 @@ class Guild(Hashable):
return threads
# TODO: Remove Optional typing here when async iterators are refactored
def fetch_members(self, *, limit: int = 1000, after: Optional[SnowflakeTime] = None) -> MemberIterator:
"""Retrieves an :class:`.AsyncIterator` that enables receiving the guild's members. In order to use this,
async def fetch_members(self, *, limit: int = 1000, after: SnowflakeTime = MISSING) -> AsyncIterator[Member]:
"""Retrieves an :term:`asynchronous iterator` that enables receiving the guild's members. In order to use this,
:meth:`Intents.members` must be enabled.
.. note::
@@ -1701,7 +1701,30 @@ class Guild(Hashable):
if not self._state._intents.members:
raise ClientException('Intents.members must be enabled to use this.')
return MemberIterator(self, limit=limit, after=after)
while True:
retrieve = min(1000 if limit is None else limit, 1000)
if retrieve < 1:
return
if isinstance(after, datetime.datetime):
after = Object(id=utils.time_snowflake(after, high=True))
after = after or OLDEST_OBJECT
after_id = after.id if after else None
state = self._state
data = await state.http.get_members(self.id, retrieve, after_id)
if not data:
return
# Terminate loop on next iteration; there's no data left after this
if len(data) < 1000:
limit = 0
after = Object(id=int(data[-1]['user']['id']))
for raw_member in reversed(data):
yield Member(data=raw_member, guild=self, state=state)
async def fetch_member(self, member_id: int, /) -> Member:
"""|coro|
@@ -2731,18 +2754,17 @@ class Guild(Hashable):
payload['uses'] = payload.get('uses', 0)
return Invite(state=self._state, data=payload, guild=self, channel=channel)
# TODO: use MISSING when async iterators get refactored
def audit_logs(
async def audit_logs(
self,
*,
limit: int = 100,
before: Optional[SnowflakeTime] = None,
after: Optional[SnowflakeTime] = None,
oldest_first: Optional[bool] = None,
user: Snowflake = None,
action: AuditLogAction = None,
) -> AuditLogIterator:
"""Returns an :class:`AsyncIterator` that enables receiving the guild's audit logs.
user: Snowflake = MISSING,
action: AuditLogAction = MISSING,
) -> AsyncIterator[AuditLogEntry]:
"""Returns an :term:`asynchronous iterator` that enables receiving the guild's audit logs.
You must have the :attr:`~Permissions.view_audit_log` permission to use this.
@@ -2761,7 +2783,7 @@ class Guild(Hashable):
Getting entries made by a specific user: ::
entries = await guild.audit_logs(limit=None, user=guild.me).flatten()
entries = [entry async for entry in guild.audit_logs(limit=None, user=guild.me)]
await channel.send(f'I made {len(entries)} moderation actions.')
Parameters
@@ -2796,6 +2818,39 @@ class Guild(Hashable):
:class:`AuditLogEntry`
The audit log entry.
"""
async def _before_strategy(retrieve, before, limit):
before_id = before.id if before else None
data = await self._state.http.get_audit_logs(
self.id, limit=retrieve, user_id=user_id, action_type=action, before=before_id
)
entries = data.get('audit_log_entries', [])
if data and entries:
if limit is not None:
limit -= len(data)
before = Object(id=int(entries[-1]['id']))
return data.get('users', []), entries, before, limit
async def _after_strategy(retrieve, after, limit):
after_id = after.id if after else None
data = await self._state.http.get_audit_logs(
self.id, limit=retrieve, user_id=user_id, action_type=action, after=after_id
)
entries = data.get('audit_log_entries', [])
if data and entries:
if limit is not None:
limit -= len(data)
after = Object(id=int(entries[0]['id']))
return data.get('users', []), entries, after, limit
if user is not None:
user_id = user.id
else:
@@ -2804,9 +2859,53 @@ class Guild(Hashable):
if action:
action = action.value
return AuditLogIterator(
self, before=before, after=after, limit=limit, oldest_first=oldest_first, user_id=user_id, action_type=action
)
if isinstance(before, datetime.datetime):
before = Object(id=utils.time_snowflake(before, high=False))
if isinstance(after, datetime.datetime):
after = Object(id=utils.time_snowflake(after, high=True))
if oldest_first is None:
reverse = after is not None
else:
reverse = oldest_first
predicate = None
if reverse:
strategy, state = _after_strategy, after
if before:
predicate = lambda m: int(m['id']) < before.id
else:
strategy, state = _before_strategy, before
if after and after != OLDEST_OBJECT:
predicate = lambda m: int(m['id']) > after.id
while True:
retrieve = min(100 if limit is None else limit, 100)
if retrieve < 1:
return
raw_users, data, state, limit = await strategy(retrieve, state, limit)
# Terminate loop on next iteration; there's no data left after this
if len(data) < 100:
limit = 0
if reverse:
data = reversed(data)
if predicate:
data = filter(predicate, data)
users = (User(data=raw_user, state=self._state) for raw_user in raw_users)
user_map = {user.id: user for user in users}
for raw_entry in data:
# Weird Discord quirk
if raw_entry['action_type'] is None:
continue
yield AuditLogEntry(data=raw_entry, users=user_map, guild=self)
async def widget(self) -> Widget:
"""|coro|