mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-11-03 23:12:56 +00:00 
			
		
		
		
	Add typings for audit logs, integrations, and webhooks
This commit is contained in:
		@@ -42,6 +42,19 @@ __all__ = (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from .types.audit_log import (
 | 
			
		||||
        AuditLog as AuditLogPayload,
 | 
			
		||||
    )
 | 
			
		||||
    from .types.guild import (
 | 
			
		||||
        Guild as GuildPayload,
 | 
			
		||||
    )
 | 
			
		||||
    from .types.message import (
 | 
			
		||||
        Message as MessagePayload,
 | 
			
		||||
    )
 | 
			
		||||
    from .types.user import (
 | 
			
		||||
        PartialUser as PartialUserPayload,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    from .member import Member
 | 
			
		||||
    from .user import User
 | 
			
		||||
    from .message import Message
 | 
			
		||||
@@ -54,6 +67,7 @@ _Func = Callable[[T], Union[OT, Awaitable[OT]]]
 | 
			
		||||
 | 
			
		||||
OLDEST_OBJECT = Object(id=0)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _AsyncIterator(AsyncIterator[T]):
 | 
			
		||||
    __slots__ = ()
 | 
			
		||||
 | 
			
		||||
@@ -105,9 +119,11 @@ class _AsyncIterator(AsyncIterator[T]):
 | 
			
		||||
        except NoMoreItems:
 | 
			
		||||
            raise StopAsyncIteration()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _identity(x):
 | 
			
		||||
    return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _ChunkedAsyncIterator(_AsyncIterator[List[T]]):
 | 
			
		||||
    def __init__(self, iterator, max_size):
 | 
			
		||||
        self.iterator = iterator
 | 
			
		||||
@@ -128,6 +144,7 @@ class _ChunkedAsyncIterator(_AsyncIterator[List[T]]):
 | 
			
		||||
                n += 1
 | 
			
		||||
        return ret
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _MappedAsyncIterator(_AsyncIterator[T]):
 | 
			
		||||
    def __init__(self, iterator, func):
 | 
			
		||||
        self.iterator = iterator
 | 
			
		||||
@@ -138,6 +155,7 @@ class _MappedAsyncIterator(_AsyncIterator[T]):
 | 
			
		||||
        item = await self.iterator.next()
 | 
			
		||||
        return await maybe_coroutine(self.func, item)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _FilteredAsyncIterator(_AsyncIterator[T]):
 | 
			
		||||
    def __init__(self, iterator, predicate):
 | 
			
		||||
        self.iterator = iterator
 | 
			
		||||
@@ -157,6 +175,7 @@ class _FilteredAsyncIterator(_AsyncIterator[T]):
 | 
			
		||||
            if ret:
 | 
			
		||||
                return item
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
 | 
			
		||||
    def __init__(self, message, emoji, limit=100, after=None):
 | 
			
		||||
        self.message = message
 | 
			
		||||
@@ -187,7 +206,9 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
 | 
			
		||||
            retrieve = self.limit if self.limit <= 100 else 100
 | 
			
		||||
 | 
			
		||||
            after = self.after.id if self.after else None
 | 
			
		||||
            data = await self.getter(self.channel_id, self.message.id, self.emoji, retrieve, after=after)
 | 
			
		||||
            data: List[PartialUserPayload] = await self.getter(
 | 
			
		||||
                self.channel_id, self.message.id, self.emoji, retrieve, after=after
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if data:
 | 
			
		||||
                self.limit -= retrieve
 | 
			
		||||
@@ -205,6 +226,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
 | 
			
		||||
                    else:
 | 
			
		||||
                        await self.users.put(User(state=self.state, data=element))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HistoryIterator(_AsyncIterator['Message']):
 | 
			
		||||
    """Iterator for receiving a channel's message history.
 | 
			
		||||
 | 
			
		||||
@@ -239,8 +261,7 @@ class HistoryIterator(_AsyncIterator['Message']):
 | 
			
		||||
        ``True`` if `after` is specified, otherwise ``False``.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, messageable, limit,
 | 
			
		||||
                 before=None, after=None, around=None, oldest_first=None):
 | 
			
		||||
    def __init__(self, messageable, limit, before=None, after=None, around=None, oldest_first=None):
 | 
			
		||||
 | 
			
		||||
        if isinstance(before, datetime.datetime):
 | 
			
		||||
            before = Object(id=time_snowflake(before, high=False))
 | 
			
		||||
@@ -274,7 +295,7 @@ class HistoryIterator(_AsyncIterator['Message']):
 | 
			
		||||
            elif self.limit == 101:
 | 
			
		||||
                self.limit = 100  # Thanks discord
 | 
			
		||||
 | 
			
		||||
            self._retrieve_messages = self._retrieve_messages_around_strategy
 | 
			
		||||
            self._retrieve_messages = self._retrieve_messages_around_strategy  # type: ignore
 | 
			
		||||
            if self.before and self.after:
 | 
			
		||||
                self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
 | 
			
		||||
            elif self.before:
 | 
			
		||||
@@ -283,12 +304,12 @@ class HistoryIterator(_AsyncIterator['Message']):
 | 
			
		||||
                self._filter = lambda m: self.after.id < int(m['id'])
 | 
			
		||||
        else:
 | 
			
		||||
            if self.reverse:
 | 
			
		||||
                self._retrieve_messages = self._retrieve_messages_after_strategy
 | 
			
		||||
                if (self.before):
 | 
			
		||||
                self._retrieve_messages = self._retrieve_messages_after_strategy  # type: ignore
 | 
			
		||||
                if self.before:
 | 
			
		||||
                    self._filter = lambda m: int(m['id']) < self.before.id
 | 
			
		||||
            else:
 | 
			
		||||
                self._retrieve_messages = self._retrieve_messages_before_strategy
 | 
			
		||||
                if (self.after and self.after != OLDEST_OBJECT):
 | 
			
		||||
                self._retrieve_messages = self._retrieve_messages_before_strategy  # type: ignore
 | 
			
		||||
                if self.after and self.after != OLDEST_OBJECT:
 | 
			
		||||
                    self._filter = lambda m: int(m['id']) > self.after.id
 | 
			
		||||
 | 
			
		||||
    async def next(self) -> Message:
 | 
			
		||||
@@ -318,7 +339,7 @@ class HistoryIterator(_AsyncIterator['Message']):
 | 
			
		||||
        if self._get_retrieve():
 | 
			
		||||
            data = await self._retrieve_messages(self.retrieve)
 | 
			
		||||
            if len(data) < 100:
 | 
			
		||||
                self.limit = 0 # terminate the infinite loop
 | 
			
		||||
                self.limit = 0  # terminate the infinite loop
 | 
			
		||||
 | 
			
		||||
            if self.reverse:
 | 
			
		||||
                data = reversed(data)
 | 
			
		||||
@@ -329,14 +350,14 @@ class HistoryIterator(_AsyncIterator['Message']):
 | 
			
		||||
            for element in data:
 | 
			
		||||
                await self.messages.put(self.state.create_message(channel=channel, data=element))
 | 
			
		||||
 | 
			
		||||
    async def _retrieve_messages(self, retrieve):
 | 
			
		||||
    async def _retrieve_messages(self, retrieve) -> List[Message]:
 | 
			
		||||
        """Retrieve messages and update next parameters."""
 | 
			
		||||
        pass
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    async def _retrieve_messages_before_strategy(self, retrieve):
 | 
			
		||||
        """Retrieve messages using before parameter."""
 | 
			
		||||
        before = self.before.id if self.before else None
 | 
			
		||||
        data = await self.logs_from(self.channel.id, retrieve, before=before)
 | 
			
		||||
        data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, before=before)
 | 
			
		||||
        if len(data):
 | 
			
		||||
            if self.limit is not None:
 | 
			
		||||
                self.limit -= retrieve
 | 
			
		||||
@@ -346,7 +367,7 @@ class HistoryIterator(_AsyncIterator['Message']):
 | 
			
		||||
    async def _retrieve_messages_after_strategy(self, retrieve):
 | 
			
		||||
        """Retrieve messages using after parameter."""
 | 
			
		||||
        after = self.after.id if self.after else None
 | 
			
		||||
        data = await self.logs_from(self.channel.id, retrieve, after=after)
 | 
			
		||||
        data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, after=after)
 | 
			
		||||
        if len(data):
 | 
			
		||||
            if self.limit is not None:
 | 
			
		||||
                self.limit -= retrieve
 | 
			
		||||
@@ -357,11 +378,12 @@ class HistoryIterator(_AsyncIterator['Message']):
 | 
			
		||||
        """Retrieve messages using around parameter."""
 | 
			
		||||
        if self.around:
 | 
			
		||||
            around = self.around.id if self.around else None
 | 
			
		||||
            data = await self.logs_from(self.channel.id, retrieve, around=around)
 | 
			
		||||
            data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, around=around)
 | 
			
		||||
            self.around = None
 | 
			
		||||
            return data
 | 
			
		||||
        return []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
 | 
			
		||||
    def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
 | 
			
		||||
        if isinstance(before, datetime.datetime):
 | 
			
		||||
@@ -369,7 +391,6 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
 | 
			
		||||
        if isinstance(after, datetime.datetime):
 | 
			
		||||
            after = Object(id=time_snowflake(after, high=True))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        if oldest_first is None:
 | 
			
		||||
            self.reverse = after is not None
 | 
			
		||||
        else:
 | 
			
		||||
@@ -386,12 +407,10 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
 | 
			
		||||
        self._users = {}
 | 
			
		||||
        self._state = guild._state
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        self._filter = None  # entry dict -> bool
 | 
			
		||||
 | 
			
		||||
        self.entries = asyncio.Queue()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        if self.reverse:
 | 
			
		||||
            self._strategy = self._after_strategy
 | 
			
		||||
            if self.before:
 | 
			
		||||
@@ -403,8 +422,9 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
 | 
			
		||||
 | 
			
		||||
    async def _before_strategy(self, retrieve):
 | 
			
		||||
        before = self.before.id if self.before else None
 | 
			
		||||
        data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id,
 | 
			
		||||
                                  action_type=self.action_type, before=before)
 | 
			
		||||
        data: AuditLogPayload = await self.request(
 | 
			
		||||
            self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        entries = data.get('audit_log_entries', [])
 | 
			
		||||
        if len(data) and entries:
 | 
			
		||||
@@ -415,8 +435,9 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
 | 
			
		||||
 | 
			
		||||
    async def _after_strategy(self, retrieve):
 | 
			
		||||
        after = self.after.id if self.after else None
 | 
			
		||||
        data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id,
 | 
			
		||||
                                  action_type=self.action_type, after=after)
 | 
			
		||||
        data: AuditLogPayload = await self.request(
 | 
			
		||||
            self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after
 | 
			
		||||
        )
 | 
			
		||||
        entries = data.get('audit_log_entries', [])
 | 
			
		||||
        if len(data) and entries:
 | 
			
		||||
            if self.limit is not None:
 | 
			
		||||
@@ -448,7 +469,7 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
 | 
			
		||||
        if self._get_retrieve():
 | 
			
		||||
            users, data = await self._strategy(self.retrieve)
 | 
			
		||||
            if len(data) < 100:
 | 
			
		||||
                self.limit = 0 # terminate the infinite loop
 | 
			
		||||
                self.limit = 0  # terminate the infinite loop
 | 
			
		||||
 | 
			
		||||
            if self.reverse:
 | 
			
		||||
                data = reversed(data)
 | 
			
		||||
@@ -495,6 +516,7 @@ class GuildIterator(_AsyncIterator['Guild']):
 | 
			
		||||
    after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
 | 
			
		||||
        Object after which all guilds must be.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, bot, limit, before=None, after=None):
 | 
			
		||||
 | 
			
		||||
        if isinstance(before, datetime.datetime):
 | 
			
		||||
@@ -514,12 +536,12 @@ class GuildIterator(_AsyncIterator['Guild']):
 | 
			
		||||
        self.guilds = asyncio.Queue()
 | 
			
		||||
 | 
			
		||||
        if self.before and self.after:
 | 
			
		||||
            self._retrieve_guilds = self._retrieve_guilds_before_strategy
 | 
			
		||||
            self._retrieve_guilds = self._retrieve_guilds_before_strategy  # type: ignore
 | 
			
		||||
            self._filter = lambda m: int(m['id']) > self.after.id
 | 
			
		||||
        elif self.after:
 | 
			
		||||
            self._retrieve_guilds = self._retrieve_guilds_after_strategy
 | 
			
		||||
            self._retrieve_guilds = self._retrieve_guilds_after_strategy  # type: ignore
 | 
			
		||||
        else:
 | 
			
		||||
            self._retrieve_guilds = self._retrieve_guilds_before_strategy
 | 
			
		||||
            self._retrieve_guilds = self._retrieve_guilds_before_strategy  # type: ignore
 | 
			
		||||
 | 
			
		||||
    async def next(self) -> Guild:
 | 
			
		||||
        if self.guilds.empty():
 | 
			
		||||
@@ -541,6 +563,7 @@ class GuildIterator(_AsyncIterator['Guild']):
 | 
			
		||||
 | 
			
		||||
    def create_guild(self, data):
 | 
			
		||||
        from .guild import Guild
 | 
			
		||||
 | 
			
		||||
        return Guild(state=self.state, data=data)
 | 
			
		||||
 | 
			
		||||
    async def fill_guilds(self):
 | 
			
		||||
@@ -555,14 +578,14 @@ class GuildIterator(_AsyncIterator['Guild']):
 | 
			
		||||
            for element in data:
 | 
			
		||||
                await self.guilds.put(self.create_guild(element))
 | 
			
		||||
 | 
			
		||||
    async def _retrieve_guilds(self, retrieve):
 | 
			
		||||
    async def _retrieve_guilds(self, retrieve) -> List[Guild]:
 | 
			
		||||
        """Retrieve guilds and update next parameters."""
 | 
			
		||||
        pass
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    async def _retrieve_guilds_before_strategy(self, retrieve):
 | 
			
		||||
        """Retrieve guilds using before parameter."""
 | 
			
		||||
        before = self.before.id if self.before else None
 | 
			
		||||
        data = await self.get_guilds(retrieve, before=before)
 | 
			
		||||
        data: List[GuildPayload] = await self.get_guilds(retrieve, before=before)
 | 
			
		||||
        if len(data):
 | 
			
		||||
            if self.limit is not None:
 | 
			
		||||
                self.limit -= retrieve
 | 
			
		||||
@@ -572,13 +595,14 @@ class GuildIterator(_AsyncIterator['Guild']):
 | 
			
		||||
    async def _retrieve_guilds_after_strategy(self, retrieve):
 | 
			
		||||
        """Retrieve guilds using after parameter."""
 | 
			
		||||
        after = self.after.id if self.after else None
 | 
			
		||||
        data = await self.get_guilds(retrieve, after=after)
 | 
			
		||||
        data: List[GuildPayload] = await self.get_guilds(retrieve, after=after)
 | 
			
		||||
        if len(data):
 | 
			
		||||
            if self.limit is not None:
 | 
			
		||||
                self.limit -= retrieve
 | 
			
		||||
            self.after = Object(id=int(data[0]['id']))
 | 
			
		||||
        return data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MemberIterator(_AsyncIterator['Member']):
 | 
			
		||||
    def __init__(self, guild, limit=1000, after=None):
 | 
			
		||||
 | 
			
		||||
@@ -620,7 +644,7 @@ class MemberIterator(_AsyncIterator['Member']):
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
            if len(data) < 1000:
 | 
			
		||||
                self.limit = 0 # terminate loop
 | 
			
		||||
                self.limit = 0  # terminate loop
 | 
			
		||||
 | 
			
		||||
            self.after = Object(id=int(data[-1]['user']['id']))
 | 
			
		||||
 | 
			
		||||
@@ -629,4 +653,5 @@ class MemberIterator(_AsyncIterator['Member']):
 | 
			
		||||
 | 
			
		||||
    def create_member(self, data):
 | 
			
		||||
        from .member import Member
 | 
			
		||||
 | 
			
		||||
        return Member(data=data, guild=self.guild, state=self.state)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user