Add typings for audit logs, integrations, and webhooks

This commit is contained in:
Nadir Chowdhury
2021-04-10 07:53:24 +01:00
committed by GitHub
parent 68aef92b37
commit 3e92196a2b
8 changed files with 386 additions and 52 deletions

View File

@ -42,6 +42,19 @@ __all__ = (
)
if TYPE_CHECKING:
from .types.audit_log import (
AuditLog as AuditLogPayload,
)
from .types.guild import (
Guild as GuildPayload,
)
from .types.message import (
Message as MessagePayload,
)
from .types.user import (
PartialUser as PartialUserPayload,
)
from .member import Member
from .user import User
from .message import Message
@ -54,6 +67,7 @@ _Func = Callable[[T], Union[OT, Awaitable[OT]]]
OLDEST_OBJECT = Object(id=0)
class _AsyncIterator(AsyncIterator[T]):
__slots__ = ()
@ -105,9 +119,11 @@ class _AsyncIterator(AsyncIterator[T]):
except NoMoreItems:
raise StopAsyncIteration()
def _identity(x):
return x
class _ChunkedAsyncIterator(_AsyncIterator[List[T]]):
def __init__(self, iterator, max_size):
self.iterator = iterator
@ -128,6 +144,7 @@ class _ChunkedAsyncIterator(_AsyncIterator[List[T]]):
n += 1
return ret
class _MappedAsyncIterator(_AsyncIterator[T]):
def __init__(self, iterator, func):
self.iterator = iterator
@ -138,6 +155,7 @@ class _MappedAsyncIterator(_AsyncIterator[T]):
item = await self.iterator.next()
return await maybe_coroutine(self.func, item)
class _FilteredAsyncIterator(_AsyncIterator[T]):
def __init__(self, iterator, predicate):
self.iterator = iterator
@ -157,6 +175,7 @@ class _FilteredAsyncIterator(_AsyncIterator[T]):
if ret:
return item
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
def __init__(self, message, emoji, limit=100, after=None):
self.message = message
@ -187,7 +206,9 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
retrieve = self.limit if self.limit <= 100 else 100
after = self.after.id if self.after else None
data = await self.getter(self.channel_id, self.message.id, self.emoji, retrieve, after=after)
data: List[PartialUserPayload] = await self.getter(
self.channel_id, self.message.id, self.emoji, retrieve, after=after
)
if data:
self.limit -= retrieve
@ -205,6 +226,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
else:
await self.users.put(User(state=self.state, data=element))
class HistoryIterator(_AsyncIterator['Message']):
"""Iterator for receiving a channel's message history.
@ -239,8 +261,7 @@ class HistoryIterator(_AsyncIterator['Message']):
``True`` if `after` is specified, otherwise ``False``.
"""
def __init__(self, messageable, limit,
before=None, after=None, around=None, oldest_first=None):
def __init__(self, messageable, limit, before=None, after=None, around=None, oldest_first=None):
if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False))
@ -274,7 +295,7 @@ class HistoryIterator(_AsyncIterator['Message']):
elif self.limit == 101:
self.limit = 100 # Thanks discord
self._retrieve_messages = self._retrieve_messages_around_strategy
self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore
if self.before and self.after:
self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
elif self.before:
@ -283,12 +304,12 @@ class HistoryIterator(_AsyncIterator['Message']):
self._filter = lambda m: self.after.id < int(m['id'])
else:
if self.reverse:
self._retrieve_messages = self._retrieve_messages_after_strategy
if (self.before):
self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore
if self.before:
self._filter = lambda m: int(m['id']) < self.before.id
else:
self._retrieve_messages = self._retrieve_messages_before_strategy
if (self.after and self.after != OLDEST_OBJECT):
self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore
if self.after and self.after != OLDEST_OBJECT:
self._filter = lambda m: int(m['id']) > self.after.id
async def next(self) -> Message:
@ -318,7 +339,7 @@ class HistoryIterator(_AsyncIterator['Message']):
if self._get_retrieve():
data = await self._retrieve_messages(self.retrieve)
if len(data) < 100:
self.limit = 0 # terminate the infinite loop
self.limit = 0 # terminate the infinite loop
if self.reverse:
data = reversed(data)
@ -329,14 +350,14 @@ class HistoryIterator(_AsyncIterator['Message']):
for element in data:
await self.messages.put(self.state.create_message(channel=channel, data=element))
async def _retrieve_messages(self, retrieve):
async def _retrieve_messages(self, retrieve) -> List[Message]:
"""Retrieve messages and update next parameters."""
pass
raise NotImplementedError
async def _retrieve_messages_before_strategy(self, retrieve):
"""Retrieve messages using before parameter."""
before = self.before.id if self.before else None
data = await self.logs_from(self.channel.id, retrieve, before=before)
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, before=before)
if len(data):
if self.limit is not None:
self.limit -= retrieve
@ -346,7 +367,7 @@ class HistoryIterator(_AsyncIterator['Message']):
async def _retrieve_messages_after_strategy(self, retrieve):
"""Retrieve messages using after parameter."""
after = self.after.id if self.after else None
data = await self.logs_from(self.channel.id, retrieve, after=after)
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, after=after)
if len(data):
if self.limit is not None:
self.limit -= retrieve
@ -357,11 +378,12 @@ class HistoryIterator(_AsyncIterator['Message']):
"""Retrieve messages using around parameter."""
if self.around:
around = self.around.id if self.around else None
data = await self.logs_from(self.channel.id, retrieve, around=around)
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, around=around)
self.around = None
return data
return []
class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
if isinstance(before, datetime.datetime):
@ -369,7 +391,6 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
if isinstance(after, datetime.datetime):
after = Object(id=time_snowflake(after, high=True))
if oldest_first is None:
self.reverse = after is not None
else:
@ -386,12 +407,10 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
self._users = {}
self._state = guild._state
self._filter = None # entry dict -> bool
self.entries = asyncio.Queue()
if self.reverse:
self._strategy = self._after_strategy
if self.before:
@ -403,8 +422,9 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
async def _before_strategy(self, retrieve):
before = self.before.id if self.before else None
data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id,
action_type=self.action_type, before=before)
data: AuditLogPayload = await self.request(
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before
)
entries = data.get('audit_log_entries', [])
if len(data) and entries:
@ -415,8 +435,9 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
async def _after_strategy(self, retrieve):
after = self.after.id if self.after else None
data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id,
action_type=self.action_type, after=after)
data: AuditLogPayload = await self.request(
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after
)
entries = data.get('audit_log_entries', [])
if len(data) and entries:
if self.limit is not None:
@ -448,7 +469,7 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
if self._get_retrieve():
users, data = await self._strategy(self.retrieve)
if len(data) < 100:
self.limit = 0 # terminate the infinite loop
self.limit = 0 # terminate the infinite loop
if self.reverse:
data = reversed(data)
@ -495,6 +516,7 @@ class GuildIterator(_AsyncIterator['Guild']):
after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
Object after which all guilds must be.
"""
def __init__(self, bot, limit, before=None, after=None):
if isinstance(before, datetime.datetime):
@ -514,12 +536,12 @@ class GuildIterator(_AsyncIterator['Guild']):
self.guilds = asyncio.Queue()
if self.before and self.after:
self._retrieve_guilds = self._retrieve_guilds_before_strategy
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore
self._filter = lambda m: int(m['id']) > self.after.id
elif self.after:
self._retrieve_guilds = self._retrieve_guilds_after_strategy
self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore
else:
self._retrieve_guilds = self._retrieve_guilds_before_strategy
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore
async def next(self) -> Guild:
if self.guilds.empty():
@ -541,6 +563,7 @@ class GuildIterator(_AsyncIterator['Guild']):
def create_guild(self, data):
from .guild import Guild
return Guild(state=self.state, data=data)
async def fill_guilds(self):
@ -555,14 +578,14 @@ class GuildIterator(_AsyncIterator['Guild']):
for element in data:
await self.guilds.put(self.create_guild(element))
async def _retrieve_guilds(self, retrieve):
async def _retrieve_guilds(self, retrieve) -> List[Guild]:
"""Retrieve guilds and update next parameters."""
pass
raise NotImplementedError
async def _retrieve_guilds_before_strategy(self, retrieve):
"""Retrieve guilds using before parameter."""
before = self.before.id if self.before else None
data = await self.get_guilds(retrieve, before=before)
data: List[GuildPayload] = await self.get_guilds(retrieve, before=before)
if len(data):
if self.limit is not None:
self.limit -= retrieve
@ -572,13 +595,14 @@ class GuildIterator(_AsyncIterator['Guild']):
async def _retrieve_guilds_after_strategy(self, retrieve):
"""Retrieve guilds using after parameter."""
after = self.after.id if self.after else None
data = await self.get_guilds(retrieve, after=after)
data: List[GuildPayload] = await self.get_guilds(retrieve, after=after)
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(data[0]['id']))
return data
class MemberIterator(_AsyncIterator['Member']):
def __init__(self, guild, limit=1000, after=None):
@ -620,7 +644,7 @@ class MemberIterator(_AsyncIterator['Member']):
return
if len(data) < 1000:
self.limit = 0 # terminate loop
self.limit = 0 # terminate loop
self.after = Object(id=int(data[-1]['user']['id']))
@ -629,4 +653,5 @@ class MemberIterator(_AsyncIterator['Member']):
def create_member(self, data):
from .member import Member
return Member(data=data, guild=self.guild, state=self.state)