mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-10-22 00:13:01 +00:00
Refactor AsyncIter to use 3.6+ asynchronous generators
This commit is contained in:
133
discord/guild.py
133
discord/guild.py
@@ -25,9 +25,11 @@ DEALINGS IN THE SOFTWARE.
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import datetime
|
||||
import unicodedata
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
@@ -67,7 +69,6 @@ from .enums import (
|
||||
from .mixins import Hashable
|
||||
from .user import User
|
||||
from .invite import Invite
|
||||
from .iterators import AuditLogIterator, MemberIterator
|
||||
from .widget import Widget
|
||||
from .asset import Asset
|
||||
from .flags import SystemChannelFlags
|
||||
@@ -76,6 +77,8 @@ from .stage_instance import StageInstance
|
||||
from .threads import Thread, ThreadMember
|
||||
from .sticker import GuildSticker
|
||||
from .file import File
|
||||
from .audit_logs import AuditLogEntry
|
||||
from .object import OLDEST_OBJECT, Object
|
||||
|
||||
|
||||
__all__ = (
|
||||
@@ -98,8 +101,6 @@ if TYPE_CHECKING:
|
||||
from .state import ConnectionState
|
||||
from .voice_client import VoiceProtocol
|
||||
|
||||
import datetime
|
||||
|
||||
VocalGuildChannel = Union[VoiceChannel, StageChannel]
|
||||
GuildChannel = Union[VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel]
|
||||
ByCategoryItem = Tuple[Optional[CategoryChannel], List[GuildChannel]]
|
||||
@@ -1649,9 +1650,8 @@ class Guild(Hashable):
|
||||
|
||||
return threads
|
||||
|
||||
# TODO: Remove Optional typing here when async iterators are refactored
|
||||
def fetch_members(self, *, limit: int = 1000, after: Optional[SnowflakeTime] = None) -> MemberIterator:
|
||||
"""Retrieves an :class:`.AsyncIterator` that enables receiving the guild's members. In order to use this,
|
||||
async def fetch_members(self, *, limit: int = 1000, after: SnowflakeTime = MISSING) -> AsyncIterator[Member]:
|
||||
"""Retrieves an :term:`asynchronous iterator` that enables receiving the guild's members. In order to use this,
|
||||
:meth:`Intents.members` must be enabled.
|
||||
|
||||
.. note::
|
||||
@@ -1701,7 +1701,30 @@ class Guild(Hashable):
|
||||
if not self._state._intents.members:
|
||||
raise ClientException('Intents.members must be enabled to use this.')
|
||||
|
||||
return MemberIterator(self, limit=limit, after=after)
|
||||
while True:
|
||||
retrieve = min(1000 if limit is None else limit, 1000)
|
||||
if retrieve < 1:
|
||||
return
|
||||
|
||||
if isinstance(after, datetime.datetime):
|
||||
after = Object(id=utils.time_snowflake(after, high=True))
|
||||
|
||||
after = after or OLDEST_OBJECT
|
||||
after_id = after.id if after else None
|
||||
state = self._state
|
||||
|
||||
data = await state.http.get_members(self.id, retrieve, after_id)
|
||||
if not data:
|
||||
return
|
||||
|
||||
# Terminate loop on next iteration; there's no data left after this
|
||||
if len(data) < 1000:
|
||||
limit = 0
|
||||
|
||||
after = Object(id=int(data[-1]['user']['id']))
|
||||
|
||||
for raw_member in reversed(data):
|
||||
yield Member(data=raw_member, guild=self, state=state)
|
||||
|
||||
async def fetch_member(self, member_id: int, /) -> Member:
|
||||
"""|coro|
|
||||
@@ -2731,18 +2754,17 @@ class Guild(Hashable):
|
||||
payload['uses'] = payload.get('uses', 0)
|
||||
return Invite(state=self._state, data=payload, guild=self, channel=channel)
|
||||
|
||||
# TODO: use MISSING when async iterators get refactored
|
||||
def audit_logs(
|
||||
async def audit_logs(
|
||||
self,
|
||||
*,
|
||||
limit: int = 100,
|
||||
before: Optional[SnowflakeTime] = None,
|
||||
after: Optional[SnowflakeTime] = None,
|
||||
oldest_first: Optional[bool] = None,
|
||||
user: Snowflake = None,
|
||||
action: AuditLogAction = None,
|
||||
) -> AuditLogIterator:
|
||||
"""Returns an :class:`AsyncIterator` that enables receiving the guild's audit logs.
|
||||
user: Snowflake = MISSING,
|
||||
action: AuditLogAction = MISSING,
|
||||
) -> AsyncIterator[AuditLogEntry]:
|
||||
"""Returns an :term:`asynchronous iterator` that enables receiving the guild's audit logs.
|
||||
|
||||
You must have the :attr:`~Permissions.view_audit_log` permission to use this.
|
||||
|
||||
@@ -2761,7 +2783,7 @@ class Guild(Hashable):
|
||||
|
||||
Getting entries made by a specific user: ::
|
||||
|
||||
entries = await guild.audit_logs(limit=None, user=guild.me).flatten()
|
||||
entries = [entry async for entry in guild.audit_logs(limit=None, user=guild.me)]
|
||||
await channel.send(f'I made {len(entries)} moderation actions.')
|
||||
|
||||
Parameters
|
||||
@@ -2796,6 +2818,39 @@ class Guild(Hashable):
|
||||
:class:`AuditLogEntry`
|
||||
The audit log entry.
|
||||
"""
|
||||
|
||||
async def _before_strategy(retrieve, before, limit):
|
||||
before_id = before.id if before else None
|
||||
data = await self._state.http.get_audit_logs(
|
||||
self.id, limit=retrieve, user_id=user_id, action_type=action, before=before_id
|
||||
)
|
||||
|
||||
entries = data.get('audit_log_entries', [])
|
||||
|
||||
if data and entries:
|
||||
if limit is not None:
|
||||
limit -= len(data)
|
||||
|
||||
before = Object(id=int(entries[-1]['id']))
|
||||
|
||||
return data.get('users', []), entries, before, limit
|
||||
|
||||
async def _after_strategy(retrieve, after, limit):
|
||||
after_id = after.id if after else None
|
||||
data = await self._state.http.get_audit_logs(
|
||||
self.id, limit=retrieve, user_id=user_id, action_type=action, after=after_id
|
||||
)
|
||||
|
||||
entries = data.get('audit_log_entries', [])
|
||||
|
||||
if data and entries:
|
||||
if limit is not None:
|
||||
limit -= len(data)
|
||||
|
||||
after = Object(id=int(entries[0]['id']))
|
||||
|
||||
return data.get('users', []), entries, after, limit
|
||||
|
||||
if user is not None:
|
||||
user_id = user.id
|
||||
else:
|
||||
@@ -2804,9 +2859,53 @@ class Guild(Hashable):
|
||||
if action:
|
||||
action = action.value
|
||||
|
||||
return AuditLogIterator(
|
||||
self, before=before, after=after, limit=limit, oldest_first=oldest_first, user_id=user_id, action_type=action
|
||||
)
|
||||
if isinstance(before, datetime.datetime):
|
||||
before = Object(id=utils.time_snowflake(before, high=False))
|
||||
if isinstance(after, datetime.datetime):
|
||||
after = Object(id=utils.time_snowflake(after, high=True))
|
||||
|
||||
|
||||
if oldest_first is None:
|
||||
reverse = after is not None
|
||||
else:
|
||||
reverse = oldest_first
|
||||
|
||||
predicate = None
|
||||
|
||||
if reverse:
|
||||
strategy, state = _after_strategy, after
|
||||
if before:
|
||||
predicate = lambda m: int(m['id']) < before.id
|
||||
else:
|
||||
strategy, state = _before_strategy, before
|
||||
if after and after != OLDEST_OBJECT:
|
||||
predicate = lambda m: int(m['id']) > after.id
|
||||
|
||||
while True:
|
||||
retrieve = min(100 if limit is None else limit, 100)
|
||||
if retrieve < 1:
|
||||
return
|
||||
|
||||
raw_users, data, state, limit = await strategy(retrieve, state, limit)
|
||||
|
||||
# Terminate loop on next iteration; there's no data left after this
|
||||
if len(data) < 100:
|
||||
limit = 0
|
||||
|
||||
if reverse:
|
||||
data = reversed(data)
|
||||
if predicate:
|
||||
data = filter(predicate, data)
|
||||
|
||||
users = (User(data=raw_user, state=self._state) for raw_user in raw_users)
|
||||
user_map = {user.id: user for user in users}
|
||||
|
||||
for raw_entry in data:
|
||||
# Weird Discord quirk
|
||||
if raw_entry['action_type'] is None:
|
||||
continue
|
||||
|
||||
yield AuditLogEntry(data=raw_entry, users=user_map, guild=self)
|
||||
|
||||
async def widget(self) -> Widget:
|
||||
"""|coro|
|
||||
|
Reference in New Issue
Block a user