Fix typing issues and improve typing completeness across the library

Co-authored-by: Danny <Rapptz@users.noreply.github.com>
Co-authored-by: Josh <josh.ja.butt@gmail.com>
This commit is contained in:
Stocker
2022-03-13 23:52:10 -04:00
committed by GitHub
parent 603681940f
commit 5aa696ccfa
66 changed files with 1071 additions and 802 deletions

View File

@@ -43,6 +43,8 @@ from typing import (
Sequence,
Tuple,
Deque,
Literal,
overload,
)
import weakref
import inspect
@@ -88,7 +90,7 @@ if TYPE_CHECKING:
from .types.activity import Activity as ActivityPayload
from .types.channel import DMChannel as DMChannelPayload
from .types.user import User as UserPayload, PartialUser as PartialUserPayload
from .types.emoji import Emoji as EmojiPayload
from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
from .types.sticker import GuildSticker as GuildStickerPayload
from .types.guild import Guild as GuildPayload
from .types.message import Message as MessagePayload, PartialMessage as PartialMessagePayload
@@ -165,9 +167,9 @@ class ConnectionState:
def __init__(
self,
*,
dispatch: Callable,
handlers: Dict[str, Callable],
hooks: Dict[str, Callable],
dispatch: Callable[..., Any],
handlers: Dict[str, Callable[..., Any]],
hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]],
http: HTTPClient,
**options: Any,
) -> None:
@@ -178,9 +180,9 @@ class ConnectionState:
if self.max_messages is not None and self.max_messages <= 0:
self.max_messages = 1000
self.dispatch: Callable = dispatch
self.handlers: Dict[str, Callable] = handlers
self.hooks: Dict[str, Callable] = hooks
self.dispatch: Callable[..., Any] = dispatch
self.handlers: Dict[str, Callable[..., Any]] = handlers
self.hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = hooks
self.shard_count: Optional[int] = None
self._ready_task: Optional[asyncio.Task] = None
self.application_id: Optional[int] = utils._get_as_snowflake(options, 'application_id')
@@ -245,6 +247,7 @@ class ConnectionState:
if not intents.members or cache_flags._empty:
self.store_user = self.store_user_no_intents # type: ignore - This reassignment is on purpose
self.parsers: Dict[str, Callable[[Any], None]]
self.parsers = parsers = {}
for attr, func in inspect.getmembers(self):
if attr.startswith('parse_'):
@@ -343,13 +346,13 @@ class ConnectionState:
self._users[user_id] = user
return user
def store_user_no_intents(self, data):
def store_user_no_intents(self, data: Union[UserPayload, PartialUserPayload]) -> User:
return User(state=self, data=data)
def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> User:
return User(state=self, data=data)
def get_user(self, id):
def get_user(self, id: int) -> Optional[User]:
return self._users.get(id)
def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji:
@@ -571,8 +574,7 @@ class ConnectionState:
pass
else:
self.application_id = utils._get_as_snowflake(application, 'id')
# flags will always be present here
self.application_flags = ApplicationFlags._from_value(application['flags'])
self.application_flags: ApplicationFlags = ApplicationFlags._from_value(application['flags'])
for guild_data in data['guilds']:
self._add_guild_from_data(guild_data) # type: ignore
@@ -743,7 +745,7 @@ class ConnectionState:
self.dispatch('presence_update', old_member, member)
def parse_user_update(self, data: gw.UserUpdateEvent):
def parse_user_update(self, data: gw.UserUpdateEvent) -> None:
if self.user:
self.user._update(data)
@@ -1050,7 +1052,7 @@ class ConnectionState:
guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers']))
self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers)
def _get_create_guild(self, data):
def _get_create_guild(self, data: gw.GuildCreateEvent) -> Guild:
if data.get('unavailable') is False:
# GUILD_CREATE with unavailable in the response
# usually means that the guild has become available
@@ -1063,10 +1065,22 @@ class ConnectionState:
return self._add_guild_from_data(data)
def is_guild_evicted(self, guild) -> bool:
def is_guild_evicted(self, guild: Guild) -> bool:
return guild.id not in self._guilds
async def chunk_guild(self, guild, *, wait=True, cache=None):
@overload
async def chunk_guild(self, guild: Guild, *, wait: Literal[True] = ..., cache: Optional[bool] = ...) -> List[Member]:
...
@overload
async def chunk_guild(
self, guild: Guild, *, wait: Literal[False] = ..., cache: Optional[bool] = ...
) -> asyncio.Future[List[Member]]:
...
async def chunk_guild(
self, guild: Guild, *, wait: bool = True, cache: Optional[bool] = None
) -> Union[List[Member], asyncio.Future[List[Member]]]:
cache = cache or self.member_cache_flags.joined
request = self._chunk_requests.get(guild.id)
if request is None:
@@ -1445,16 +1459,19 @@ class ConnectionState:
return channel.guild.get_member(user_id)
return self.get_user(user_id)
def get_reaction_emoji(self, data) -> Union[Emoji, PartialEmoji]:
def get_reaction_emoji(self, data: PartialEmojiPayload) -> Union[Emoji, PartialEmoji, str]:
emoji_id = utils._get_as_snowflake(data, 'id')
if not emoji_id:
return data['name']
# the name key will be a str
return data['name'] # type: ignore
try:
return self._emojis[emoji_id]
except KeyError:
return PartialEmoji.with_state(self, animated=data.get('animated', False), id=emoji_id, name=data['name'])
return PartialEmoji.with_state(
self, animated=data.get('animated', False), id=emoji_id, name=data['name'] # type: ignore
)
def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmoji, str]:
emoji_id = emoji.id
@@ -1589,6 +1606,7 @@ class AutoShardedConnectionState(ConnectionState):
if not hasattr(self, '_ready_state'):
self._ready_state = asyncio.Queue()
self.user: Optional[ClientUser]
self.user = user = ClientUser(state=self, data=data['user'])
# self._users is a list of Users, we're setting a ClientUser
self._users[user.id] = user # type: ignore
@@ -1599,8 +1617,8 @@ class AutoShardedConnectionState(ConnectionState):
except KeyError:
pass
else:
self.application_id = utils._get_as_snowflake(application, 'id')
self.application_flags = ApplicationFlags._from_value(application['flags'])
self.application_id: Optional[int] = utils._get_as_snowflake(application, 'id')
self.application_flags: ApplicationFlags = ApplicationFlags._from_value(application['flags'])
for guild_data in data['guilds']:
self._add_guild_from_data(guild_data) # type: ignore - _add_guild_from_data requires a complete Guild payload