Fix code style issues with Black

This commit is contained in:
Lint Action
2021-09-05 21:34:20 +00:00
parent a23dae8604
commit 7513c2138f
108 changed files with 5369 additions and 4858 deletions

View File

@ -34,11 +34,11 @@ from .object import Object
from .audit_logs import AuditLogEntry
__all__ = (
'ReactionIterator',
'HistoryIterator',
'AuditLogIterator',
'GuildIterator',
'MemberIterator',
"ReactionIterator",
"HistoryIterator",
"AuditLogIterator",
"GuildIterator",
"MemberIterator",
)
if TYPE_CHECKING:
@ -67,8 +67,8 @@ if TYPE_CHECKING:
from .threads import Thread
from .abc import Snowflake
T = TypeVar('T')
OT = TypeVar('OT')
T = TypeVar("T")
OT = TypeVar("OT")
_Func = Callable[[T], Union[OT, Awaitable[OT]]]
OLDEST_OBJECT = Object(id=0)
@ -83,7 +83,7 @@ class _AsyncIterator(AsyncIterator[T]):
def get(self, **attrs: Any) -> Awaitable[Optional[T]]:
def predicate(elem: T):
for attr, val in attrs.items():
nested = attr.split('__')
nested = attr.split("__")
obj = elem
for attribute in nested:
obj = getattr(obj, attribute)
@ -107,7 +107,7 @@ class _AsyncIterator(AsyncIterator[T]):
def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]:
if max_size <= 0:
raise ValueError('async iterator chunk sizes must be greater than 0.')
raise ValueError("async iterator chunk sizes must be greater than 0.")
return _ChunkedAsyncIterator(self, max_size)
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
@ -182,7 +182,7 @@ class _FilteredAsyncIterator(_AsyncIterator[T]):
return item
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
class ReactionIterator(_AsyncIterator[Union["User", "Member"]]):
def __init__(self, message, emoji, limit=100, after=None):
self.message = message
self.limit = limit
@ -218,14 +218,14 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
if data:
self.limit -= retrieve
self.after = Object(id=int(data[-1]['id']))
self.after = Object(id=int(data[-1]["id"]))
if self.guild is None or isinstance(self.guild, Object):
for element in reversed(data):
await self.users.put(User(state=self.state, data=element))
else:
for element in reversed(data):
member_id = int(element['id'])
member_id = int(element["id"])
member = self.guild.get_member(member_id)
if member is not None:
await self.users.put(member)
@ -233,7 +233,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
await self.users.put(User(state=self.state, data=element))
class HistoryIterator(_AsyncIterator['Message']):
class HistoryIterator(_AsyncIterator["Message"]):
"""Iterator for receiving a channel's message history.
The messages endpoint has two behaviours we care about here:
@ -295,7 +295,7 @@ class HistoryIterator(_AsyncIterator['Message']):
if self.around:
if self.limit is None:
raise ValueError('history does not support around with limit=None')
raise ValueError("history does not support around with limit=None")
if self.limit > 101:
raise ValueError("history max limit 101 when specifying around parameter")
elif self.limit == 101:
@ -303,20 +303,20 @@ class HistoryIterator(_AsyncIterator['Message']):
self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore
if self.before and self.after:
self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
self._filter = lambda m: self.after.id < int(m["id"]) < self.before.id
elif self.before:
self._filter = lambda m: int(m['id']) < self.before.id
self._filter = lambda m: int(m["id"]) < self.before.id
elif self.after:
self._filter = lambda m: self.after.id < int(m['id'])
self._filter = lambda m: self.after.id < int(m["id"])
else:
if self.reverse:
self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore
if self.before:
self._filter = lambda m: int(m['id']) < self.before.id
self._filter = lambda m: int(m["id"]) < self.before.id
else:
self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore
if self.after and self.after != OLDEST_OBJECT:
self._filter = lambda m: int(m['id']) > self.after.id
self._filter = lambda m: int(m["id"]) > self.after.id
async def next(self) -> Message:
if self.messages.empty():
@ -337,7 +337,7 @@ class HistoryIterator(_AsyncIterator['Message']):
return r > 0
async def fill_messages(self):
if not hasattr(self, 'channel'):
if not hasattr(self, "channel"):
# do the required set up
channel = await self.messageable._get_channel()
self.channel = channel
@ -367,7 +367,7 @@ class HistoryIterator(_AsyncIterator['Message']):
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(data[-1]['id']))
self.before = Object(id=int(data[-1]["id"]))
return data
async def _retrieve_messages_after_strategy(self, retrieve):
@ -377,7 +377,7 @@ class HistoryIterator(_AsyncIterator['Message']):
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(data[0]['id']))
self.after = Object(id=int(data[0]["id"]))
return data
async def _retrieve_messages_around_strategy(self, retrieve):
@ -390,7 +390,7 @@ class HistoryIterator(_AsyncIterator['Message']):
return []
class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
class AuditLogIterator(_AsyncIterator["AuditLogEntry"]):
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False))
@ -420,11 +420,11 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
if self.reverse:
self._strategy = self._after_strategy
if self.before:
self._filter = lambda m: int(m['id']) < self.before.id
self._filter = lambda m: int(m["id"]) < self.before.id
else:
self._strategy = self._before_strategy
if self.after and self.after != OLDEST_OBJECT:
self._filter = lambda m: int(m['id']) > self.after.id
self._filter = lambda m: int(m["id"]) > self.after.id
async def _before_strategy(self, retrieve):
before = self.before.id if self.before else None
@ -432,24 +432,24 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before
)
entries = data.get('audit_log_entries', [])
entries = data.get("audit_log_entries", [])
if len(data) and entries:
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(entries[-1]['id']))
return data.get('users', []), entries
self.before = Object(id=int(entries[-1]["id"]))
return data.get("users", []), entries
async def _after_strategy(self, retrieve):
after = self.after.id if self.after else None
data: AuditLogPayload = await self.request(
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after
)
entries = data.get('audit_log_entries', [])
entries = data.get("audit_log_entries", [])
if len(data) and entries:
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(entries[0]['id']))
return data.get('users', []), entries
self.after = Object(id=int(entries[0]["id"]))
return data.get("users", []), entries
async def next(self) -> AuditLogEntry:
if self.entries.empty():
@ -488,13 +488,13 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
for element in data:
# TODO: remove this if statement later
if element['action_type'] is None:
if element["action_type"] is None:
continue
await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild))
class GuildIterator(_AsyncIterator['Guild']):
class GuildIterator(_AsyncIterator["Guild"]):
"""Iterator for receiving the client's guilds.
The guilds endpoint has the same two behaviours as described
@ -543,7 +543,7 @@ class GuildIterator(_AsyncIterator['Guild']):
if self.before and self.after:
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore
self._filter = lambda m: int(m['id']) > self.after.id
self._filter = lambda m: int(m["id"]) > self.after.id
elif self.after:
self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore
else:
@ -595,7 +595,7 @@ class GuildIterator(_AsyncIterator['Guild']):
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(data[-1]['id']))
self.before = Object(id=int(data[-1]["id"]))
return data
async def _retrieve_guilds_after_strategy(self, retrieve):
@ -605,11 +605,11 @@ class GuildIterator(_AsyncIterator['Guild']):
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(data[0]['id']))
self.after = Object(id=int(data[0]["id"]))
return data
class MemberIterator(_AsyncIterator['Member']):
class MemberIterator(_AsyncIterator["Member"]):
def __init__(self, guild, limit=1000, after=None):
if isinstance(after, datetime.datetime):
@ -652,7 +652,7 @@ class MemberIterator(_AsyncIterator['Member']):
if len(data) < 1000:
self.limit = 0 # terminate loop
self.after = Object(id=int(data[-1]['user']['id']))
self.after = Object(id=int(data[-1]["user"]["id"]))
for element in reversed(data):
await self.members.put(self.create_member(element))
@ -663,7 +663,7 @@ class MemberIterator(_AsyncIterator['Member']):
return Member(data=data, guild=self.guild, state=self.state)
class ArchivedThreadIterator(_AsyncIterator['Thread']):
class ArchivedThreadIterator(_AsyncIterator["Thread"]):
def __init__(
self,
channel_id: int,
@ -681,7 +681,7 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
self.http = guild._state.http
if joined and not private:
raise ValueError('Cannot iterate over joined public archived threads')
raise ValueError("Cannot iterate over joined public archived threads")
self.before: Optional[str]
if before is None:
@ -721,11 +721,11 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
@staticmethod
def get_archive_timestamp(data: ThreadPayload) -> str:
return data['thread_metadata']['archive_timestamp']
return data["thread_metadata"]["archive_timestamp"]
@staticmethod
def get_thread_id(data: ThreadPayload) -> str:
return data['id'] # type: ignore
return data["id"] # type: ignore
async def fill_queue(self) -> None:
if not self.has_more:
@ -735,11 +735,11 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
data = await self.endpoint(self.channel_id, before=self.before, limit=limit)
# This stuff is obviously WIP because 'members' is always empty
threads: List[ThreadPayload] = data.get('threads', [])
threads: List[ThreadPayload] = data.get("threads", [])
for d in reversed(threads):
self.queue.put_nowait(self.create_thread(d))
self.has_more = data.get('has_more', False)
self.has_more = data.get("has_more", False)
if self.limit is not None:
self.limit -= len(threads)
if self.limit <= 0:
@ -750,4 +750,5 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
def create_thread(self, data: ThreadPayload) -> Thread:
from .threads import Thread
return Thread(guild=self.guild, state=self.guild._state, data=data)