mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-11-03 23:12:56 +00:00 
			
		
		
		
	Implement remaining HTTP endpoints on threads
I'm not sure if I missed any -- but this is the entire documented set so far.
This commit is contained in:
		@@ -29,7 +29,7 @@ import datetime
 | 
			
		||||
from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator
 | 
			
		||||
 | 
			
		||||
from .errors import NoMoreItems
 | 
			
		||||
from .utils import time_snowflake, maybe_coroutine
 | 
			
		||||
from .utils import snowflake_time, time_snowflake, maybe_coroutine
 | 
			
		||||
from .object import Object
 | 
			
		||||
from .audit_logs import AuditLogEntry
 | 
			
		||||
 | 
			
		||||
@@ -55,11 +55,17 @@ if TYPE_CHECKING:
 | 
			
		||||
        PartialUser as PartialUserPayload,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    from .types.threads import (
 | 
			
		||||
        Thread as ThreadPayload,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    from .member import Member
 | 
			
		||||
    from .user import User
 | 
			
		||||
    from .message import Message
 | 
			
		||||
    from .audit_logs import AuditLogEntry
 | 
			
		||||
    from .guild import Guild
 | 
			
		||||
    from .threads import Thread
 | 
			
		||||
    from .abc import Snowflake
 | 
			
		||||
 | 
			
		||||
T = TypeVar('T')
 | 
			
		||||
OT = TypeVar('OT')
 | 
			
		||||
@@ -655,3 +661,92 @@ class MemberIterator(_AsyncIterator['Member']):
 | 
			
		||||
        from .member import Member
 | 
			
		||||
 | 
			
		||||
        return Member(data=data, guild=self.guild, state=self.state)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ArchivedThreadIterator(_AsyncIterator['Thread']):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        channel_id: int,
 | 
			
		||||
        guild: Guild,
 | 
			
		||||
        limit: Optional[int],
 | 
			
		||||
        joined: bool,
 | 
			
		||||
        private: bool,
 | 
			
		||||
        before: Optional[Union[Snowflake, datetime.datetime]] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        self.channel_id = channel_id
 | 
			
		||||
        self.guild = guild
 | 
			
		||||
        self.limit = limit
 | 
			
		||||
        self.joined = joined
 | 
			
		||||
        self.private = private
 | 
			
		||||
        self.http = guild._state.http
 | 
			
		||||
 | 
			
		||||
        if joined and not private:
 | 
			
		||||
            raise ValueError('Cannot iterate over joined public archived threads')
 | 
			
		||||
 | 
			
		||||
        self.before: Optional[str]
 | 
			
		||||
        if before is None:
 | 
			
		||||
            self.before = None
 | 
			
		||||
        elif isinstance(before, datetime.datetime):
 | 
			
		||||
            if joined:
 | 
			
		||||
                self.before = str(time_snowflake(before, high=False))
 | 
			
		||||
            else:
 | 
			
		||||
                self.before = before.isoformat()
 | 
			
		||||
        else:
 | 
			
		||||
            if joined:
 | 
			
		||||
                self.before = str(before.id)
 | 
			
		||||
            else:
 | 
			
		||||
                self.before = snowflake_time(before.id).isoformat()
 | 
			
		||||
 | 
			
		||||
        self.update_before: Callable[[ThreadPayload], str] = self.get_archive_timestamp
 | 
			
		||||
 | 
			
		||||
        if joined:
 | 
			
		||||
            self.endpoint = self.http.get_joined_private_archived_threads
 | 
			
		||||
            self.update_before = self.get_thread_id
 | 
			
		||||
        elif private:
 | 
			
		||||
            self.endpoint = self.http.get_private_archived_threads
 | 
			
		||||
        else:
 | 
			
		||||
            self.endpoint = self.http.get_archived_threads
 | 
			
		||||
 | 
			
		||||
        self.queue: asyncio.Queue[Thread] = asyncio.Queue()
 | 
			
		||||
        self.has_more: bool = True
 | 
			
		||||
 | 
			
		||||
    async def next(self) -> Thread:
 | 
			
		||||
        if self.queue.empty():
 | 
			
		||||
            await self.fill_queue()
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            return self.queue.get_nowait()
 | 
			
		||||
        except asyncio.QueueEmpty:
 | 
			
		||||
            raise NoMoreItems()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_archive_timestamp(data: ThreadPayload) -> str:
 | 
			
		||||
        return data['thread_metadata']['archive_timestamp']
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_thread_id(data: ThreadPayload) -> str:
 | 
			
		||||
        return data['id']  # type: ignore
 | 
			
		||||
 | 
			
		||||
    async def fill_queue(self) -> None:
 | 
			
		||||
        if not self.has_more:
 | 
			
		||||
            raise NoMoreItems()
 | 
			
		||||
 | 
			
		||||
        limit = 50 if self.limit is None else max(self.limit, 50)
 | 
			
		||||
        data = await self.endpoint(self.channel_id, before=self.before, limit=limit)
 | 
			
		||||
 | 
			
		||||
        # This stuff is obviously WIP because 'members' is always empty
 | 
			
		||||
        threads: List[ThreadPayload] = data.get('threads', [])
 | 
			
		||||
        for d in reversed(threads):
 | 
			
		||||
            self.queue.put_nowait(self.create_thread(d))
 | 
			
		||||
 | 
			
		||||
        self.has_more = data.get('has_more', False)
 | 
			
		||||
        if self.limit is not None:
 | 
			
		||||
            self.limit -= len(threads)
 | 
			
		||||
            if self.limit <= 0:
 | 
			
		||||
                self.has_more = False
 | 
			
		||||
 | 
			
		||||
        if self.has_more:
 | 
			
		||||
            self.before = self.update_before(threads[-1])
 | 
			
		||||
 | 
			
		||||
    def create_thread(self, data: ThreadPayload) -> Thread:
 | 
			
		||||
        return Thread(guild=self.guild, data=data)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user