Add typings for audit logs, integrations, and webhooks
This commit is contained in:
@ -42,6 +42,19 @@ __all__ = (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .types.audit_log import (
|
||||
AuditLog as AuditLogPayload,
|
||||
)
|
||||
from .types.guild import (
|
||||
Guild as GuildPayload,
|
||||
)
|
||||
from .types.message import (
|
||||
Message as MessagePayload,
|
||||
)
|
||||
from .types.user import (
|
||||
PartialUser as PartialUserPayload,
|
||||
)
|
||||
|
||||
from .member import Member
|
||||
from .user import User
|
||||
from .message import Message
|
||||
@ -54,6 +67,7 @@ _Func = Callable[[T], Union[OT, Awaitable[OT]]]
|
||||
|
||||
OLDEST_OBJECT = Object(id=0)
|
||||
|
||||
|
||||
class _AsyncIterator(AsyncIterator[T]):
|
||||
__slots__ = ()
|
||||
|
||||
@ -105,9 +119,11 @@ class _AsyncIterator(AsyncIterator[T]):
|
||||
except NoMoreItems:
|
||||
raise StopAsyncIteration()
|
||||
|
||||
|
||||
def _identity(x):
|
||||
return x
|
||||
|
||||
|
||||
class _ChunkedAsyncIterator(_AsyncIterator[List[T]]):
|
||||
def __init__(self, iterator, max_size):
|
||||
self.iterator = iterator
|
||||
@ -128,6 +144,7 @@ class _ChunkedAsyncIterator(_AsyncIterator[List[T]]):
|
||||
n += 1
|
||||
return ret
|
||||
|
||||
|
||||
class _MappedAsyncIterator(_AsyncIterator[T]):
|
||||
def __init__(self, iterator, func):
|
||||
self.iterator = iterator
|
||||
@ -138,6 +155,7 @@ class _MappedAsyncIterator(_AsyncIterator[T]):
|
||||
item = await self.iterator.next()
|
||||
return await maybe_coroutine(self.func, item)
|
||||
|
||||
|
||||
class _FilteredAsyncIterator(_AsyncIterator[T]):
|
||||
def __init__(self, iterator, predicate):
|
||||
self.iterator = iterator
|
||||
@ -157,6 +175,7 @@ class _FilteredAsyncIterator(_AsyncIterator[T]):
|
||||
if ret:
|
||||
return item
|
||||
|
||||
|
||||
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
|
||||
def __init__(self, message, emoji, limit=100, after=None):
|
||||
self.message = message
|
||||
@ -187,7 +206,9 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
|
||||
retrieve = self.limit if self.limit <= 100 else 100
|
||||
|
||||
after = self.after.id if self.after else None
|
||||
data = await self.getter(self.channel_id, self.message.id, self.emoji, retrieve, after=after)
|
||||
data: List[PartialUserPayload] = await self.getter(
|
||||
self.channel_id, self.message.id, self.emoji, retrieve, after=after
|
||||
)
|
||||
|
||||
if data:
|
||||
self.limit -= retrieve
|
||||
@ -205,6 +226,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
|
||||
else:
|
||||
await self.users.put(User(state=self.state, data=element))
|
||||
|
||||
|
||||
class HistoryIterator(_AsyncIterator['Message']):
|
||||
"""Iterator for receiving a channel's message history.
|
||||
|
||||
@ -239,8 +261,7 @@ class HistoryIterator(_AsyncIterator['Message']):
|
||||
``True`` if `after` is specified, otherwise ``False``.
|
||||
"""
|
||||
|
||||
def __init__(self, messageable, limit,
|
||||
before=None, after=None, around=None, oldest_first=None):
|
||||
def __init__(self, messageable, limit, before=None, after=None, around=None, oldest_first=None):
|
||||
|
||||
if isinstance(before, datetime.datetime):
|
||||
before = Object(id=time_snowflake(before, high=False))
|
||||
@ -274,7 +295,7 @@ class HistoryIterator(_AsyncIterator['Message']):
|
||||
elif self.limit == 101:
|
||||
self.limit = 100 # Thanks discord
|
||||
|
||||
self._retrieve_messages = self._retrieve_messages_around_strategy
|
||||
self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore
|
||||
if self.before and self.after:
|
||||
self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
|
||||
elif self.before:
|
||||
@ -283,12 +304,12 @@ class HistoryIterator(_AsyncIterator['Message']):
|
||||
self._filter = lambda m: self.after.id < int(m['id'])
|
||||
else:
|
||||
if self.reverse:
|
||||
self._retrieve_messages = self._retrieve_messages_after_strategy
|
||||
if (self.before):
|
||||
self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore
|
||||
if self.before:
|
||||
self._filter = lambda m: int(m['id']) < self.before.id
|
||||
else:
|
||||
self._retrieve_messages = self._retrieve_messages_before_strategy
|
||||
if (self.after and self.after != OLDEST_OBJECT):
|
||||
self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore
|
||||
if self.after and self.after != OLDEST_OBJECT:
|
||||
self._filter = lambda m: int(m['id']) > self.after.id
|
||||
|
||||
async def next(self) -> Message:
|
||||
@ -318,7 +339,7 @@ class HistoryIterator(_AsyncIterator['Message']):
|
||||
if self._get_retrieve():
|
||||
data = await self._retrieve_messages(self.retrieve)
|
||||
if len(data) < 100:
|
||||
self.limit = 0 # terminate the infinite loop
|
||||
self.limit = 0 # terminate the infinite loop
|
||||
|
||||
if self.reverse:
|
||||
data = reversed(data)
|
||||
@ -329,14 +350,14 @@ class HistoryIterator(_AsyncIterator['Message']):
|
||||
for element in data:
|
||||
await self.messages.put(self.state.create_message(channel=channel, data=element))
|
||||
|
||||
async def _retrieve_messages(self, retrieve):
|
||||
async def _retrieve_messages(self, retrieve) -> List[Message]:
|
||||
"""Retrieve messages and update next parameters."""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
async def _retrieve_messages_before_strategy(self, retrieve):
|
||||
"""Retrieve messages using before parameter."""
|
||||
before = self.before.id if self.before else None
|
||||
data = await self.logs_from(self.channel.id, retrieve, before=before)
|
||||
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, before=before)
|
||||
if len(data):
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
@ -346,7 +367,7 @@ class HistoryIterator(_AsyncIterator['Message']):
|
||||
async def _retrieve_messages_after_strategy(self, retrieve):
|
||||
"""Retrieve messages using after parameter."""
|
||||
after = self.after.id if self.after else None
|
||||
data = await self.logs_from(self.channel.id, retrieve, after=after)
|
||||
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, after=after)
|
||||
if len(data):
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
@ -357,11 +378,12 @@ class HistoryIterator(_AsyncIterator['Message']):
|
||||
"""Retrieve messages using around parameter."""
|
||||
if self.around:
|
||||
around = self.around.id if self.around else None
|
||||
data = await self.logs_from(self.channel.id, retrieve, around=around)
|
||||
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, around=around)
|
||||
self.around = None
|
||||
return data
|
||||
return []
|
||||
|
||||
|
||||
class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
|
||||
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
|
||||
if isinstance(before, datetime.datetime):
|
||||
@ -369,7 +391,6 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
|
||||
if isinstance(after, datetime.datetime):
|
||||
after = Object(id=time_snowflake(after, high=True))
|
||||
|
||||
|
||||
if oldest_first is None:
|
||||
self.reverse = after is not None
|
||||
else:
|
||||
@ -386,12 +407,10 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
|
||||
self._users = {}
|
||||
self._state = guild._state
|
||||
|
||||
|
||||
self._filter = None # entry dict -> bool
|
||||
|
||||
self.entries = asyncio.Queue()
|
||||
|
||||
|
||||
if self.reverse:
|
||||
self._strategy = self._after_strategy
|
||||
if self.before:
|
||||
@ -403,8 +422,9 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
|
||||
|
||||
async def _before_strategy(self, retrieve):
|
||||
before = self.before.id if self.before else None
|
||||
data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id,
|
||||
action_type=self.action_type, before=before)
|
||||
data: AuditLogPayload = await self.request(
|
||||
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before
|
||||
)
|
||||
|
||||
entries = data.get('audit_log_entries', [])
|
||||
if len(data) and entries:
|
||||
@ -415,8 +435,9 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
|
||||
|
||||
async def _after_strategy(self, retrieve):
|
||||
after = self.after.id if self.after else None
|
||||
data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id,
|
||||
action_type=self.action_type, after=after)
|
||||
data: AuditLogPayload = await self.request(
|
||||
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after
|
||||
)
|
||||
entries = data.get('audit_log_entries', [])
|
||||
if len(data) and entries:
|
||||
if self.limit is not None:
|
||||
@ -448,7 +469,7 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
|
||||
if self._get_retrieve():
|
||||
users, data = await self._strategy(self.retrieve)
|
||||
if len(data) < 100:
|
||||
self.limit = 0 # terminate the infinite loop
|
||||
self.limit = 0 # terminate the infinite loop
|
||||
|
||||
if self.reverse:
|
||||
data = reversed(data)
|
||||
@ -495,6 +516,7 @@ class GuildIterator(_AsyncIterator['Guild']):
|
||||
after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
|
||||
Object after which all guilds must be.
|
||||
"""
|
||||
|
||||
def __init__(self, bot, limit, before=None, after=None):
|
||||
|
||||
if isinstance(before, datetime.datetime):
|
||||
@ -514,12 +536,12 @@ class GuildIterator(_AsyncIterator['Guild']):
|
||||
self.guilds = asyncio.Queue()
|
||||
|
||||
if self.before and self.after:
|
||||
self._retrieve_guilds = self._retrieve_guilds_before_strategy
|
||||
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore
|
||||
self._filter = lambda m: int(m['id']) > self.after.id
|
||||
elif self.after:
|
||||
self._retrieve_guilds = self._retrieve_guilds_after_strategy
|
||||
self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore
|
||||
else:
|
||||
self._retrieve_guilds = self._retrieve_guilds_before_strategy
|
||||
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore
|
||||
|
||||
async def next(self) -> Guild:
|
||||
if self.guilds.empty():
|
||||
@ -541,6 +563,7 @@ class GuildIterator(_AsyncIterator['Guild']):
|
||||
|
||||
def create_guild(self, data):
|
||||
from .guild import Guild
|
||||
|
||||
return Guild(state=self.state, data=data)
|
||||
|
||||
async def fill_guilds(self):
|
||||
@ -555,14 +578,14 @@ class GuildIterator(_AsyncIterator['Guild']):
|
||||
for element in data:
|
||||
await self.guilds.put(self.create_guild(element))
|
||||
|
||||
async def _retrieve_guilds(self, retrieve):
|
||||
async def _retrieve_guilds(self, retrieve) -> List[Guild]:
|
||||
"""Retrieve guilds and update next parameters."""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
async def _retrieve_guilds_before_strategy(self, retrieve):
|
||||
"""Retrieve guilds using before parameter."""
|
||||
before = self.before.id if self.before else None
|
||||
data = await self.get_guilds(retrieve, before=before)
|
||||
data: List[GuildPayload] = await self.get_guilds(retrieve, before=before)
|
||||
if len(data):
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
@ -572,13 +595,14 @@ class GuildIterator(_AsyncIterator['Guild']):
|
||||
async def _retrieve_guilds_after_strategy(self, retrieve):
|
||||
"""Retrieve guilds using after parameter."""
|
||||
after = self.after.id if self.after else None
|
||||
data = await self.get_guilds(retrieve, after=after)
|
||||
data: List[GuildPayload] = await self.get_guilds(retrieve, after=after)
|
||||
if len(data):
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
self.after = Object(id=int(data[0]['id']))
|
||||
return data
|
||||
|
||||
|
||||
class MemberIterator(_AsyncIterator['Member']):
|
||||
def __init__(self, guild, limit=1000, after=None):
|
||||
|
||||
@ -620,7 +644,7 @@ class MemberIterator(_AsyncIterator['Member']):
|
||||
return
|
||||
|
||||
if len(data) < 1000:
|
||||
self.limit = 0 # terminate loop
|
||||
self.limit = 0 # terminate loop
|
||||
|
||||
self.after = Object(id=int(data[-1]['user']['id']))
|
||||
|
||||
@ -629,4 +653,5 @@ class MemberIterator(_AsyncIterator['Member']):
|
||||
|
||||
def create_member(self, data):
|
||||
from .member import Member
|
||||
|
||||
return Member(data=data, guild=self.guild, state=self.state)
|
||||
|
Reference in New Issue
Block a user