Typehint shard.py

This commit is contained in:
Stocker
2021-08-20 20:05:02 -04:00
committed by GitHub
parent 745cf541ea
commit 5390caa67d

View File

@@ -22,8 +22,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import asyncio import asyncio
import itertools
import logging import logging
import aiohttp import aiohttp
@@ -34,22 +35,29 @@ from .backoff import ExponentialBackoff
from .gateway import * from .gateway import *
from .errors import ( from .errors import (
ClientException, ClientException,
InvalidArgument,
HTTPException, HTTPException,
GatewayNotFound, GatewayNotFound,
ConnectionClosed, ConnectionClosed,
PrivilegedIntentsRequired, PrivilegedIntentsRequired,
) )
from . import utils
from .enums import Status from .enums import Status
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict, TypeVar
if TYPE_CHECKING:
from .gateway import DiscordWebSocket
from .activity import BaseActivity
from .enums import Status
EI = TypeVar('EI', bound='EventItem')
__all__ = ( __all__ = (
'AutoShardedClient', 'AutoShardedClient',
'ShardInfo', 'ShardInfo',
) )
log = logging.getLogger(__name__) log: logging.Logger = logging.getLogger(__name__)
class EventType: class EventType:
close = 0 close = 0
@@ -62,36 +70,36 @@ class EventType:
class EventItem: class EventItem:
__slots__ = ('type', 'shard', 'error') __slots__ = ('type', 'shard', 'error')
def __init__(self, etype, shard, error): def __init__(self, etype: int, shard: Optional['Shard'], error: Optional[Exception]) -> None:
self.type = etype self.type: int = etype
self.shard = shard self.shard: Optional['Shard'] = shard
self.error = error self.error: Optional[Exception] = error
def __lt__(self, other): def __lt__(self: EI, other: EI) -> bool:
if not isinstance(other, EventItem): if not isinstance(other, EventItem):
return NotImplemented return NotImplemented
return self.type < other.type return self.type < other.type
def __eq__(self, other): def __eq__(self: EI, other: EI) -> bool:
if not isinstance(other, EventItem): if not isinstance(other, EventItem):
return NotImplemented return NotImplemented
return self.type == other.type return self.type == other.type
def __hash__(self): def __hash__(self) -> int:
return hash(self.type) return hash(self.type)
class Shard: class Shard:
def __init__(self, ws, client, queue_put): def __init__(self, ws: DiscordWebSocket, client: AutoShardedClient, queue_put: Callable[[EventItem], None]) -> None:
self.ws = ws self.ws: DiscordWebSocket = ws
self._client = client self._client: Client = client
self._dispatch = client.dispatch self._dispatch: Callable[..., None] = client.dispatch
self._queue_put = queue_put self._queue_put: Callable[[EventItem], None] = queue_put
self.loop = self._client.loop self.loop: asyncio.AbstractEventLoop = self._client.loop
self._disconnect = False self._disconnect: bool = False
self._reconnect = client._reconnect self._reconnect = client._reconnect
self._backoff = ExponentialBackoff() self._backoff: ExponentialBackoff = ExponentialBackoff()
self._task = None self._task: Optional[asyncio.Task] = None
self._handled_exceptions = ( self._handled_exceptions: Tuple[Type[Exception], ...] = (
OSError, OSError,
HTTPException, HTTPException,
GatewayNotFound, GatewayNotFound,
@@ -101,25 +109,26 @@ class Shard:
) )
@property @property
def id(self): def id(self) -> int:
return self.ws.shard_id # DiscordWebSocket.shard_id is set in the from_client classmethod
return self.ws.shard_id # type: ignore
def launch(self): def launch(self) -> None:
self._task = self.loop.create_task(self.worker()) self._task = self.loop.create_task(self.worker())
def _cancel_task(self): def _cancel_task(self) -> None:
if self._task is not None and not self._task.done(): if self._task is not None and not self._task.done():
self._task.cancel() self._task.cancel()
async def close(self): async def close(self) -> None:
self._cancel_task() self._cancel_task()
await self.ws.close(code=1000) await self.ws.close(code=1000)
async def disconnect(self): async def disconnect(self) -> None:
await self.close() await self.close()
self._dispatch('shard_disconnect', self.id) self._dispatch('shard_disconnect', self.id)
async def _handle_disconnect(self, e): async def _handle_disconnect(self, e: Exception) -> None:
self._dispatch('disconnect') self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id) self._dispatch('shard_disconnect', self.id)
if not self._reconnect: if not self._reconnect:
@@ -148,7 +157,7 @@ class Shard:
await asyncio.sleep(retry) await asyncio.sleep(retry)
self._queue_put(EventItem(EventType.reconnect, self, e)) self._queue_put(EventItem(EventType.reconnect, self, e))
async def worker(self): async def worker(self) -> None:
while not self._client.is_closed(): while not self._client.is_closed():
try: try:
await self.ws.poll_event() await self.ws.poll_event()
@@ -165,7 +174,7 @@ class Shard:
self._queue_put(EventItem(EventType.terminate, self, e)) self._queue_put(EventItem(EventType.terminate, self, e))
break break
async def reidentify(self, exc): async def reidentify(self, exc: ReconnectWebSocket) -> None:
self._cancel_task() self._cancel_task()
self._dispatch('disconnect') self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id) self._dispatch('shard_disconnect', self.id)
@@ -183,7 +192,7 @@ class Shard:
else: else:
self.launch() self.launch()
async def reconnect(self): async def reconnect(self) -> None:
self._cancel_task() self._cancel_task()
try: try:
coro = DiscordWebSocket.from_client(self._client, shard_id=self.id) coro = DiscordWebSocket.from_client(self._client, shard_id=self.id)
@@ -215,16 +224,16 @@ class ShardInfo:
__slots__ = ('_parent', 'id', 'shard_count') __slots__ = ('_parent', 'id', 'shard_count')
def __init__(self, parent, shard_count): def __init__(self, parent: Shard, shard_count: Optional[int]) -> None:
self._parent = parent self._parent: Shard = parent
self.id = parent.id self.id: int = parent.id
self.shard_count = shard_count self.shard_count: Optional[int] = shard_count
def is_closed(self): def is_closed(self) -> bool:
""":class:`bool`: Whether the shard connection is currently closed.""" """:class:`bool`: Whether the shard connection is currently closed."""
return not self._parent.ws.open return not self._parent.ws.open
async def disconnect(self): async def disconnect(self) -> None:
"""|coro| """|coro|
Disconnects a shard. When this is called, the shard connection will no Disconnects a shard. When this is called, the shard connection will no
@@ -237,7 +246,7 @@ class ShardInfo:
await self._parent.disconnect() await self._parent.disconnect()
async def reconnect(self): async def reconnect(self) -> None:
"""|coro| """|coro|
Disconnects and then connects the shard again. Disconnects and then connects the shard again.
@@ -246,7 +255,7 @@ class ShardInfo:
await self._parent.disconnect() await self._parent.disconnect()
await self._parent.reconnect() await self._parent.reconnect()
async def connect(self): async def connect(self) -> None:
"""|coro| """|coro|
Connects a shard. If the shard is already connected this does nothing. Connects a shard. If the shard is already connected this does nothing.
@@ -257,11 +266,11 @@ class ShardInfo:
await self._parent.reconnect() await self._parent.reconnect()
@property @property
def latency(self): def latency(self) -> float:
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds for this shard.""" """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds for this shard."""
return self._parent.ws.latency return self._parent.ws.latency
def is_ws_ratelimited(self): def is_ws_ratelimited(self) -> bool:
""":class:`bool`: Whether the websocket is currently rate limited. """:class:`bool`: Whether the websocket is currently rate limited.
This can be useful to know when deciding whether you should query members This can be useful to know when deciding whether you should query members
@@ -297,9 +306,12 @@ class AutoShardedClient(Client):
shard_ids: Optional[List[:class:`int`]] shard_ids: Optional[List[:class:`int`]]
An optional list of shard_ids to launch the shards with. An optional list of shard_ids to launch the shards with.
""" """
def __init__(self, *args, loop=None, **kwargs): if TYPE_CHECKING:
_connection: AutoShardedConnectionState
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None:
kwargs.pop('shard_id', None) kwargs.pop('shard_id', None)
self.shard_ids = kwargs.pop('shard_ids', None) self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None)
super().__init__(*args, loop=loop, **kwargs) super().__init__(*args, loop=loop, **kwargs)
if self.shard_ids is not None: if self.shard_ids is not None:
@@ -315,18 +327,19 @@ class AutoShardedClient(Client):
self._connection._get_client = lambda: self self._connection._get_client = lambda: self
self.__queue = asyncio.PriorityQueue() self.__queue = asyncio.PriorityQueue()
def _get_websocket(self, guild_id=None, *, shard_id=None): def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket:
if shard_id is None: if shard_id is None:
shard_id = (guild_id >> 22) % self.shard_count # guild_id won't be None if shard_id is None and shard_count won't be None here
shard_id = (guild_id >> 22) % self.shard_count # type: ignore
return self.__shards[shard_id].ws return self.__shards[shard_id].ws
def _get_state(self, **options): def _get_state(self, **options: Any) -> AutoShardedConnectionState:
return AutoShardedConnectionState(dispatch=self.dispatch, return AutoShardedConnectionState(dispatch=self.dispatch,
handlers=self._handlers, handlers=self._handlers,
hooks=self._hooks, http=self.http, loop=self.loop, **options) hooks=self._hooks, http=self.http, loop=self.loop, **options)
@property @property
def latency(self): def latency(self) -> float:
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This operates similarly to :meth:`Client.latency` except it uses the average This operates similarly to :meth:`Client.latency` except it uses the average
@@ -338,14 +351,14 @@ class AutoShardedClient(Client):
return sum(latency for _, latency in self.latencies) / len(self.__shards) return sum(latency for _, latency in self.latencies) / len(self.__shards)
@property @property
def latencies(self): def latencies(self) -> List[Tuple[int, float]]:
"""List[Tuple[:class:`int`, :class:`float`]]: A list of latencies between a HEARTBEAT and a HEARTBEAT_ACK in seconds. """List[Tuple[:class:`int`, :class:`float`]]: A list of latencies between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This returns a list of tuples with elements ``(shard_id, latency)``. This returns a list of tuples with elements ``(shard_id, latency)``.
""" """
return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()] return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()]
def get_shard(self, shard_id): def get_shard(self, shard_id: int) -> Optional[ShardInfo]:
"""Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found.""" """Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found."""
try: try:
parent = self.__shards[shard_id] parent = self.__shards[shard_id]
@@ -355,11 +368,11 @@ class AutoShardedClient(Client):
return ShardInfo(parent, self.shard_count) return ShardInfo(parent, self.shard_count)
@property @property
def shards(self): def shards(self) -> Dict[int, ShardInfo]:
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object.""" """Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
return { shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items() } return { shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items() }
async def launch_shard(self, gateway, shard_id, *, initial=False): async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None:
try: try:
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id) coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
ws = await asyncio.wait_for(coro, timeout=180.0) ws = await asyncio.wait_for(coro, timeout=180.0)
@@ -372,7 +385,7 @@ class AutoShardedClient(Client):
self.__shards[shard_id] = ret = Shard(ws, self, self.__queue.put_nowait) self.__shards[shard_id] = ret = Shard(ws, self, self.__queue.put_nowait)
ret.launch() ret.launch()
async def launch_shards(self): async def launch_shards(self) -> None:
if self.shard_count is None: if self.shard_count is None:
self.shard_count, gateway = await self.http.get_bot_gateway() self.shard_count, gateway = await self.http.get_bot_gateway()
else: else:
@@ -389,7 +402,7 @@ class AutoShardedClient(Client):
self._connection.shards_launched.set() self._connection.shards_launched.set()
async def connect(self, *, reconnect=True): async def connect(self, *, reconnect: bool = True) -> None:
self._reconnect = reconnect self._reconnect = reconnect
await self.launch_shards() await self.launch_shards()
@@ -413,7 +426,7 @@ class AutoShardedClient(Client):
elif item.type == EventType.clean_close: elif item.type == EventType.clean_close:
return return
async def close(self): async def close(self) -> None:
"""|coro| """|coro|
Closes the connection to Discord. Closes the connection to Discord.
@@ -425,7 +438,7 @@ class AutoShardedClient(Client):
for vc in self.voice_clients: for vc in self.voice_clients:
try: try:
await vc.disconnect() await vc.disconnect(force=True)
except Exception: except Exception:
pass pass
@@ -436,7 +449,7 @@ class AutoShardedClient(Client):
await self.http.close() await self.http.close()
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None)) self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
async def change_presence(self, *, activity=None, status=None, shard_id=None): async def change_presence(self, *, activity: Optional[BaseActivity] = None, status: Optional[Status] = None, shard_id: int = None) -> None:
"""|coro| """|coro|
Changes the client's presence. Changes the client's presence.
@@ -468,23 +481,23 @@ class AutoShardedClient(Client):
""" """
if status is None: if status is None:
status = 'online' status_value = 'online'
status_enum = Status.online status_enum = Status.online
elif status is Status.offline: elif status is Status.offline:
status = 'invisible' status_value = 'invisible'
status_enum = Status.offline status_enum = Status.offline
else: else:
status_enum = status status_enum = status
status = str(status) status_value = str(status)
if shard_id is None: if shard_id is None:
for shard in self.__shards.values(): for shard in self.__shards.values():
await shard.ws.change_presence(activity=activity, status=status) await shard.ws.change_presence(activity=activity, status=status_value)
guilds = self._connection.guilds guilds = self._connection.guilds
else: else:
shard = self.__shards[shard_id] shard = self.__shards[shard_id]
await shard.ws.change_presence(activity=activity, status=status) await shard.ws.change_presence(activity=activity, status=status_value)
guilds = [g for g in self._connection.guilds if g.shard_id == shard_id] guilds = [g for g in self._connection.guilds if g.shard_id == shard_id]
activities = () if activity is None else (activity,) activities = () if activity is None else (activity,)
@@ -493,10 +506,11 @@ class AutoShardedClient(Client):
if me is None: if me is None:
continue continue
me.activities = activities # Member.activities is typehinted as Tuple[ActivityType, ...], we may be setting it as Tuple[BaseActivity, ...]
me.activities = activities # type: ignore
me.status = status_enum me.status = status_enum
def is_ws_ratelimited(self): def is_ws_ratelimited(self) -> bool:
""":class:`bool`: Whether the websocket is currently rate limited. """:class:`bool`: Whether the websocket is currently rate limited.
This can be useful to know when deciding whether you should query members This can be useful to know when deciding whether you should query members