Fix code style issues with Black
This commit is contained in:
@@ -29,7 +29,20 @@ import logging
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import aiohttp
|
||||
|
||||
@@ -69,46 +82,49 @@ if TYPE_CHECKING:
|
||||
from .member import Member
|
||||
from .voice_client import VoiceProtocol
|
||||
|
||||
__all__ = (
|
||||
'Client',
|
||||
)
|
||||
__all__ = ("Client",)
|
||||
|
||||
Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
|
||||
Coro = TypeVar("Coro", bound=Callable[..., Coroutine[Any, Any, Any]])
|
||||
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
|
||||
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
|
||||
|
||||
if not tasks:
|
||||
return
|
||||
|
||||
_log.info('Cleaning up after %d tasks.', len(tasks))
|
||||
_log.info("Cleaning up after %d tasks.", len(tasks))
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
|
||||
_log.info('All tasks finished cancelling.')
|
||||
_log.info("All tasks finished cancelling.")
|
||||
|
||||
for task in tasks:
|
||||
if task.cancelled():
|
||||
continue
|
||||
if task.exception() is not None:
|
||||
loop.call_exception_handler({
|
||||
'message': 'Unhandled exception during Client.run shutdown.',
|
||||
'exception': task.exception(),
|
||||
'task': task
|
||||
})
|
||||
loop.call_exception_handler(
|
||||
{
|
||||
"message": "Unhandled exception during Client.run shutdown.",
|
||||
"exception": task.exception(),
|
||||
"task": task,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||
try:
|
||||
_cancel_tasks(loop)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
finally:
|
||||
_log.info('Closing the event loop.')
|
||||
_log.info("Closing the event loop.")
|
||||
loop.close()
|
||||
|
||||
|
||||
class Client:
|
||||
r"""Represents a client connection that connects to Discord.
|
||||
This class is used to interact with the Discord WebSocket and API.
|
||||
@@ -199,6 +215,7 @@ class Client:
|
||||
loop: :class:`asyncio.AbstractEventLoop`
|
||||
The event loop that the client uses for asynchronous operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -212,24 +229,22 @@ class Client:
|
||||
self.ws: DiscordWebSocket = None # type: ignore
|
||||
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
|
||||
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
|
||||
self.shard_id: Optional[int] = options.get('shard_id')
|
||||
self.shard_count: Optional[int] = options.get('shard_count')
|
||||
self.shard_id: Optional[int] = options.get("shard_id")
|
||||
self.shard_count: Optional[int] = options.get("shard_count")
|
||||
|
||||
connector: Optional[aiohttp.BaseConnector] = options.pop('connector', None)
|
||||
proxy: Optional[str] = options.pop('proxy', None)
|
||||
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
|
||||
unsync_clock: bool = options.pop('assume_unsync_clock', True)
|
||||
self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop)
|
||||
connector: Optional[aiohttp.BaseConnector] = options.pop("connector", None)
|
||||
proxy: Optional[str] = options.pop("proxy", None)
|
||||
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop("proxy_auth", None)
|
||||
unsync_clock: bool = options.pop("assume_unsync_clock", True)
|
||||
self.http: HTTPClient = HTTPClient(
|
||||
connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop
|
||||
)
|
||||
|
||||
self._handlers: Dict[str, Callable] = {
|
||||
'ready': self._handle_ready
|
||||
}
|
||||
self._handlers: Dict[str, Callable] = {"ready": self._handle_ready}
|
||||
|
||||
self._hooks: Dict[str, Callable] = {
|
||||
'before_identify': self._call_before_identify_hook
|
||||
}
|
||||
self._hooks: Dict[str, Callable] = {"before_identify": self._call_before_identify_hook}
|
||||
|
||||
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
|
||||
self._enable_debug_events: bool = options.pop("enable_debug_events", False)
|
||||
self._connection: ConnectionState = self._get_state(**options)
|
||||
self._connection.shard_count = self.shard_count
|
||||
self._closed: bool = False
|
||||
@@ -247,8 +262,14 @@ class Client:
|
||||
return self.ws
|
||||
|
||||
def _get_state(self, **options: Any) -> ConnectionState:
|
||||
return ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
|
||||
hooks=self._hooks, http=self.http, loop=self.loop, **options)
|
||||
return ConnectionState(
|
||||
dispatch=self.dispatch,
|
||||
handlers=self._handlers,
|
||||
hooks=self._hooks,
|
||||
http=self.http,
|
||||
loop=self.loop,
|
||||
**options,
|
||||
)
|
||||
|
||||
def _handle_ready(self) -> None:
|
||||
self._ready.set()
|
||||
@@ -260,7 +281,7 @@ class Client:
|
||||
This could be referred to as the Discord WebSocket protocol latency.
|
||||
"""
|
||||
ws = self.ws
|
||||
return float('nan') if not ws else ws.latency
|
||||
return float("nan") if not ws else ws.latency
|
||||
|
||||
def is_ws_ratelimited(self) -> bool:
|
||||
""":class:`bool`: Whether the websocket is currently rate limited.
|
||||
@@ -331,7 +352,7 @@ class Client:
|
||||
If this is not passed via ``__init__`` then this is retrieved
|
||||
through the gateway when an event contains the data. Usually
|
||||
after :func:`~discord.on_connect` is called.
|
||||
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self._connection.application_id
|
||||
@@ -348,7 +369,9 @@ class Client:
|
||||
""":class:`bool`: Specifies if the client's internal cache is ready for use."""
|
||||
return self._ready.is_set()
|
||||
|
||||
async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None:
|
||||
async def _run_event(
|
||||
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
try:
|
||||
await coro(*args, **kwargs)
|
||||
except asyncio.CancelledError:
|
||||
@@ -359,14 +382,16 @@ class Client:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task:
|
||||
def _schedule_event(
|
||||
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
|
||||
) -> asyncio.Task:
|
||||
wrapped = self._run_event(coro, event_name, *args, **kwargs)
|
||||
# Schedules the task
|
||||
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
|
||||
return asyncio.create_task(wrapped, name=f"discord.py: {event_name}")
|
||||
|
||||
def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
|
||||
_log.debug('Dispatching event %s', event)
|
||||
method = 'on_' + event
|
||||
_log.debug("Dispatching event %s", event)
|
||||
method = "on_" + event
|
||||
|
||||
listeners = self._listeners.get(event)
|
||||
if listeners:
|
||||
@@ -413,7 +438,7 @@ class Client:
|
||||
overridden to have a different implementation.
|
||||
Check :func:`~discord.on_error` for more details.
|
||||
"""
|
||||
print(f'Ignoring exception in {event_method}', file=sys.stderr)
|
||||
print(f"Ignoring exception in {event_method}", file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
# hooks
|
||||
@@ -470,7 +495,7 @@ class Client:
|
||||
passing status code.
|
||||
"""
|
||||
|
||||
_log.info('logging in using static token')
|
||||
_log.info("logging in using static token")
|
||||
|
||||
data = await self.http.static_login(token.strip())
|
||||
self._connection.user = ClientUser(state=self._connection, data=data)
|
||||
@@ -502,29 +527,31 @@ class Client:
|
||||
|
||||
backoff = ExponentialBackoff()
|
||||
ws_params = {
|
||||
'initial': True,
|
||||
'shard_id': self.shard_id,
|
||||
"initial": True,
|
||||
"shard_id": self.shard_id,
|
||||
}
|
||||
while not self.is_closed():
|
||||
try:
|
||||
coro = DiscordWebSocket.from_client(self, **ws_params)
|
||||
self.ws = await asyncio.wait_for(coro, timeout=60.0)
|
||||
ws_params['initial'] = False
|
||||
ws_params["initial"] = False
|
||||
while True:
|
||||
await self.ws.poll_event()
|
||||
except ReconnectWebSocket as e:
|
||||
_log.info('Got a request to %s the websocket.', e.op)
|
||||
self.dispatch('disconnect')
|
||||
_log.info("Got a request to %s the websocket.", e.op)
|
||||
self.dispatch("disconnect")
|
||||
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
|
||||
continue
|
||||
except (OSError,
|
||||
HTTPException,
|
||||
GatewayNotFound,
|
||||
ConnectionClosed,
|
||||
aiohttp.ClientError,
|
||||
asyncio.TimeoutError) as exc:
|
||||
except (
|
||||
OSError,
|
||||
HTTPException,
|
||||
GatewayNotFound,
|
||||
ConnectionClosed,
|
||||
aiohttp.ClientError,
|
||||
asyncio.TimeoutError,
|
||||
) as exc:
|
||||
|
||||
self.dispatch('disconnect')
|
||||
self.dispatch("disconnect")
|
||||
if not reconnect:
|
||||
await self.close()
|
||||
if isinstance(exc, ConnectionClosed) and exc.code == 1000:
|
||||
@@ -654,10 +681,10 @@ class Client:
|
||||
try:
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
_log.info('Received signal to terminate bot and event loop.')
|
||||
_log.info("Received signal to terminate bot and event loop.")
|
||||
finally:
|
||||
future.remove_done_callback(stop_loop_on_completion)
|
||||
_log.info('Cleaning up tasks.')
|
||||
_log.info("Cleaning up tasks.")
|
||||
_cleanup_loop(loop)
|
||||
|
||||
if not future.cancelled():
|
||||
@@ -686,10 +713,10 @@ class Client:
|
||||
self._connection._activity = None
|
||||
elif isinstance(value, BaseActivity):
|
||||
# ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any]
|
||||
self._connection._activity = value.to_dict() # type: ignore
|
||||
self._connection._activity = value.to_dict() # type: ignore
|
||||
else:
|
||||
raise TypeError('activity must derive from BaseActivity.')
|
||||
|
||||
raise TypeError("activity must derive from BaseActivity.")
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
""":class:`.Status`:
|
||||
@@ -704,11 +731,11 @@ class Client:
|
||||
@status.setter
|
||||
def status(self, value):
|
||||
if value is Status.offline:
|
||||
self._connection._status = 'invisible'
|
||||
self._connection._status = "invisible"
|
||||
elif isinstance(value, Status):
|
||||
self._connection._status = str(value)
|
||||
else:
|
||||
raise TypeError('status must derive from Status.')
|
||||
raise TypeError("status must derive from Status.")
|
||||
|
||||
@property
|
||||
def allowed_mentions(self) -> Optional[AllowedMentions]:
|
||||
@@ -723,7 +750,7 @@ class Client:
|
||||
if value is None or isinstance(value, AllowedMentions):
|
||||
self._connection.allowed_mentions = value
|
||||
else:
|
||||
raise TypeError(f'allowed_mentions must be AllowedMentions not {value.__class__!r}')
|
||||
raise TypeError(f"allowed_mentions must be AllowedMentions not {value.__class__!r}")
|
||||
|
||||
@property
|
||||
def intents(self) -> Intents:
|
||||
@@ -760,7 +787,7 @@ class Client:
|
||||
|
||||
This is useful if you have a channel_id but don't want to do an API call
|
||||
to send messages to it.
|
||||
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Parameters
|
||||
@@ -1033,8 +1060,10 @@ class Client:
|
||||
|
||||
future = self.loop.create_future()
|
||||
if check is None:
|
||||
|
||||
def _check(*args):
|
||||
return True
|
||||
|
||||
check = _check
|
||||
|
||||
ev = event.lower()
|
||||
@@ -1072,10 +1101,10 @@ class Client:
|
||||
"""
|
||||
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise TypeError('event registered must be a coroutine function')
|
||||
raise TypeError("event registered must be a coroutine function")
|
||||
|
||||
setattr(self, coro.__name__, coro)
|
||||
_log.debug('%s has successfully been registered as an event', coro.__name__)
|
||||
_log.debug("%s has successfully been registered as an event", coro.__name__)
|
||||
return coro
|
||||
|
||||
async def change_presence(
|
||||
@@ -1114,10 +1143,10 @@ class Client:
|
||||
"""
|
||||
|
||||
if status is None:
|
||||
status_str = 'online'
|
||||
status_str = "online"
|
||||
status = Status.online
|
||||
elif status is Status.offline:
|
||||
status_str = 'invisible'
|
||||
status_str = "invisible"
|
||||
status = Status.offline
|
||||
else:
|
||||
status_str = str(status)
|
||||
@@ -1139,11 +1168,7 @@ class Client:
|
||||
# Guild stuff
|
||||
|
||||
def fetch_guilds(
|
||||
self,
|
||||
*,
|
||||
limit: Optional[int] = 100,
|
||||
before: SnowflakeTime = None,
|
||||
after: SnowflakeTime = None
|
||||
self, *, limit: Optional[int] = 100, before: SnowflakeTime = None, after: SnowflakeTime = None
|
||||
) -> GuildIterator:
|
||||
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
|
||||
|
||||
@@ -1223,7 +1248,7 @@ class Client:
|
||||
"""
|
||||
code = utils.resolve_template(code)
|
||||
data = await self.http.get_template(code)
|
||||
return Template(data=data, state=self._connection) # type: ignore
|
||||
return Template(data=data, state=self._connection) # type: ignore
|
||||
|
||||
async def fetch_guild(self, guild_id: int, /) -> Guild:
|
||||
"""|coro|
|
||||
@@ -1339,12 +1364,14 @@ class Client:
|
||||
The stage instance from the stage channel ID.
|
||||
"""
|
||||
data = await self.http.get_stage_instance(channel_id)
|
||||
guild = self.get_guild(int(data['guild_id']))
|
||||
guild = self.get_guild(int(data["guild_id"]))
|
||||
return StageInstance(guild=guild, state=self._connection, data=data) # type: ignore
|
||||
|
||||
# Invite management
|
||||
|
||||
async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite:
|
||||
async def fetch_invite(
|
||||
self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True
|
||||
) -> Invite:
|
||||
"""|coro|
|
||||
|
||||
Gets an :class:`.Invite` from a discord.gg URL or ID.
|
||||
@@ -1460,8 +1487,8 @@ class Client:
|
||||
The bot's application information.
|
||||
"""
|
||||
data = await self.http.application_info()
|
||||
if 'rpc_origins' not in data:
|
||||
data['rpc_origins'] = None
|
||||
if "rpc_origins" not in data:
|
||||
data["rpc_origins"] = None
|
||||
return AppInfo(self._connection, data)
|
||||
|
||||
async def fetch_user(self, user_id: int, /) -> User:
|
||||
@@ -1524,19 +1551,19 @@ class Client:
|
||||
"""
|
||||
data = await self.http.get_channel(channel_id)
|
||||
|
||||
factory, ch_type = _threaded_channel_factory(data['type'])
|
||||
factory, ch_type = _threaded_channel_factory(data["type"])
|
||||
if factory is None:
|
||||
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))
|
||||
raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(data))
|
||||
|
||||
if ch_type in (ChannelType.group, ChannelType.private):
|
||||
# the factory will be a DMChannel or GroupChannel here
|
||||
channel = factory(me=self.user, data=data, state=self._connection) # type: ignore
|
||||
channel = factory(me=self.user, data=data, state=self._connection) # type: ignore
|
||||
else:
|
||||
# the factory can't be a DMChannel or GroupChannel here
|
||||
guild_id = int(data['guild_id']) # type: ignore
|
||||
guild_id = int(data["guild_id"]) # type: ignore
|
||||
guild = self.get_guild(guild_id) or Object(id=guild_id)
|
||||
# GuildChannels expect a Guild, we may be passing an Object
|
||||
channel = factory(guild=guild, state=self._connection, data=data) # type: ignore
|
||||
channel = factory(guild=guild, state=self._connection, data=data) # type: ignore
|
||||
|
||||
return channel
|
||||
|
||||
@@ -1582,8 +1609,8 @@ class Client:
|
||||
The sticker you requested.
|
||||
"""
|
||||
data = await self.http.get_sticker(sticker_id)
|
||||
cls, _ = _sticker_factory(data['type']) # type: ignore
|
||||
return cls(state=self._connection, data=data) # type: ignore
|
||||
cls, _ = _sticker_factory(data["type"]) # type: ignore
|
||||
return cls(state=self._connection, data=data) # type: ignore
|
||||
|
||||
async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
|
||||
"""|coro|
|
||||
@@ -1603,7 +1630,7 @@ class Client:
|
||||
All available premium sticker packs.
|
||||
"""
|
||||
data = await self.http.list_premium_sticker_packs()
|
||||
return [StickerPack(state=self._connection, data=pack) for pack in data['sticker_packs']]
|
||||
return [StickerPack(state=self._connection, data=pack) for pack in data["sticker_packs"]]
|
||||
|
||||
async def create_dm(self, user: Snowflake) -> DMChannel:
|
||||
"""|coro|
|
||||
@@ -1638,7 +1665,7 @@ class Client:
|
||||
|
||||
This method should be used for when a view is comprised of components
|
||||
that last longer than the lifecycle of the program.
|
||||
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Parameters
|
||||
@@ -1660,17 +1687,17 @@ class Client:
|
||||
"""
|
||||
|
||||
if not isinstance(view, View):
|
||||
raise TypeError(f'expected an instance of View not {view.__class__!r}')
|
||||
raise TypeError(f"expected an instance of View not {view.__class__!r}")
|
||||
|
||||
if not view.is_persistent():
|
||||
raise ValueError('View is not persistent. Items need to have a custom_id set and View must have no timeout')
|
||||
raise ValueError("View is not persistent. Items need to have a custom_id set and View must have no timeout")
|
||||
|
||||
self._connection.store_view(view, message_id)
|
||||
|
||||
@property
|
||||
def persistent_views(self) -> Sequence[View]:
|
||||
"""Sequence[:class:`.View`]: A sequence of persistent views added to the client.
|
||||
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self._connection.persistent_views
|
||||
|
||||
Reference in New Issue
Block a user