Refactor AsyncIter to use 3.6+ asynchronous generators

This commit is contained in:
Kaylynn Morgan
2022-02-20 13:58:13 +11:00
committed by GitHub
parent dc19c6c7d5
commit 588cda0996
8 changed files with 386 additions and 930 deletions

View File

@@ -28,6 +28,7 @@ import time
import asyncio
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterable,
@@ -54,7 +55,6 @@ from .asset import Asset
from .errors import ClientException, InvalidArgument
from .stage_instance import StageInstance
from .threads import Thread
from .iterators import ArchivedThreadIterator
__all__ = (
'TextChannel',
@@ -755,15 +755,15 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
return Thread(guild=self.guild, state=self._state, data=data)
def archived_threads(
async def archived_threads(
self,
*,
private: bool = False,
joined: bool = False,
limit: Optional[int] = 50,
before: Optional[Union[Snowflake, datetime.datetime]] = None,
) -> ArchivedThreadIterator:
"""Returns an :class:`~discord.AsyncIterator` that iterates over all archived threads in the guild.
) -> AsyncIterator[Thread]:
"""Returns an :term:`asynchronous iterator` that iterates over all archived threads in the guild.
You must have :attr:`~Permissions.read_message_history` to use this. If iterating over private threads
then :attr:`~Permissions.manage_threads` is also required.
@@ -790,13 +790,57 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
You do not have permissions to get archived threads.
HTTPException
The request to get the archived threads failed.
ValueError
`joined`` was set to ``True`` and ``private`` was set to ``False``. You cannot retrieve public archived
threads that you have joined.
Yields
-------
:class:`Thread`
The archived threads.
"""
return ArchivedThreadIterator(self.id, self.guild, limit=limit, joined=joined, private=private, before=before)
if joined and not private:
raise ValueError('Cannot retrieve joined public archived threads')
before_timestamp = None
if isinstance(before, datetime.datetime):
if joined:
before_timestamp = str(utils.time_snowflake(before, high=False))
else:
before_timestamp = before.isoformat()
elif before is not None:
if joined:
before_timestamp = str(before.id)
else:
before_timestamp = utils.snowflake_time(before.id).isoformat()
update_before = lambda data: data['thread_metadata']['archive_timestamp']
endpoint = self.guild._state.http.get_public_archived_threads
if joined:
update_before = lambda data: data['id']
endpoint = self.guild._state.http.get_joined_private_archived_threads
elif private:
endpoint = self.guild._state.http.get_private_archived_threads
while True:
retrieve = 50 if limit is None else max(limit, 50)
data = await endpoint(self.id, before=before_timestamp, limit=retrieve)
threads = data.get('threads', [])
for raw_thread in reversed(threads):
yield Thread(guild=self.guild, state=self.guild._state, data=raw_thread)
if not data.get('has_more', False):
return
if limit is not None:
limit -= len(threads)
if limit <= 0:
return
before = update_before(threads[-1])
class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):