Fix typing issues and improve typing completeness across the library

Co-authored-by: Danny <Rapptz@users.noreply.github.com>
Co-authored-by: Josh <josh.ja.butt@gmail.com>
This commit is contained in:
Stocker
2022-03-13 23:52:10 -04:00
committed by GitHub
parent 603681940f
commit 5aa696ccfa
66 changed files with 1071 additions and 802 deletions

View File

@ -30,7 +30,7 @@ import json
import re
from urllib.parse import quote as urlquote
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, TypeVar, Type, overload
from contextvars import ContextVar
import weakref
@ -43,7 +43,7 @@ from ..enums import try_enum, WebhookType
from ..user import BaseUser, User
from ..flags import MessageFlags
from ..asset import Asset
from ..http import Route, handle_message_parameters, MultipartParameters
from ..http import Route, handle_message_parameters, MultipartParameters, HTTPClient
from ..mixins import Hashable
from ..channel import PartialMessageable
from ..file import File
@ -58,24 +58,38 @@ __all__ = (
_log = logging.getLogger(__name__)
if TYPE_CHECKING:
from typing_extensions import Self
from types import TracebackType
from ..embeds import Embed
from ..mentions import AllowedMentions
from ..message import Attachment
from ..state import ConnectionState
from ..http import Response
from ..types.webhook import (
Webhook as WebhookPayload,
)
from ..types.message import (
Message as MessagePayload,
)
from ..guild import Guild
from ..channel import TextChannel
from ..abc import Snowflake
from ..ui.view import View
import datetime
from ..types.webhook import (
Webhook as WebhookPayload,
SourceGuild as SourceGuildPayload,
)
from ..types.message import (
Message as MessagePayload,
)
from ..types.user import (
User as UserPayload,
PartialUser as PartialUserPayload,
)
from ..types.channel import (
PartialChannel as PartialChannelPayload,
)
MISSING = utils.MISSING
BE = TypeVar('BE', bound=BaseException)
_State = Union[ConnectionState, '_WebhookState']
MISSING: Any = utils.MISSING
class AsyncDeferredLock:
@ -83,14 +97,19 @@ class AsyncDeferredLock:
self.lock = lock
self.delta: Optional[float] = None
async def __aenter__(self):
async def __aenter__(self) -> Self:
await self.lock.acquire()
return self
def delay_by(self, delta: float) -> None:
self.delta = delta
async def __aexit__(self, type, value, traceback):
async def __aexit__(
self,
exc_type: Optional[Type[BE]],
exc: Optional[BE],
traceback: Optional[TracebackType],
) -> None:
if self.delta:
await asyncio.sleep(self.delta)
self.lock.release()
@ -545,11 +564,11 @@ class PartialWebhookChannel(Hashable):
__slots__ = ('id', 'name')
def __init__(self, *, data):
self.id = int(data['id'])
self.name = data['name']
def __init__(self, *, data: PartialChannelPayload) -> None:
self.id: int = int(data['id'])
self.name: str = data['name']
def __repr__(self):
def __repr__(self) -> str:
return f'<PartialWebhookChannel name={self.name!r} id={self.id}>'
@ -570,13 +589,13 @@ class PartialWebhookGuild(Hashable):
__slots__ = ('id', 'name', '_icon', '_state')
def __init__(self, *, data, state):
self._state = state
self.id = int(data['id'])
self.name = data['name']
self._icon = data['icon']
def __init__(self, *, data: SourceGuildPayload, state: _State) -> None:
self._state: _State = state
self.id: int = int(data['id'])
self.name: str = data['name']
self._icon: str = data['icon']
def __repr__(self):
def __repr__(self) -> str:
return f'<PartialWebhookGuild name={self.name!r} id={self.id}>'
@property
@ -590,14 +609,14 @@ class PartialWebhookGuild(Hashable):
class _FriendlyHttpAttributeErrorHelper:
__slots__ = ()
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> Any:
raise AttributeError('PartialWebhookState does not support http methods.')
class _WebhookState:
__slots__ = ('_parent', '_webhook')
def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]):
def __init__(self, webhook: Any, parent: Optional[_State]):
self._webhook: Any = webhook
self._parent: Optional[ConnectionState]
@ -606,23 +625,23 @@ class _WebhookState:
else:
self._parent = parent
def _get_guild(self, guild_id):
def _get_guild(self, guild_id: Optional[int]) -> Optional[Guild]:
if self._parent is not None:
return self._parent._get_guild(guild_id)
return None
def store_user(self, data):
def store_user(self, data: Union[UserPayload, PartialUserPayload]) -> BaseUser:
if self._parent is not None:
return self._parent.store_user(data)
# state parameter is artificial
return BaseUser(state=self, data=data) # type: ignore
def create_user(self, data):
def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> BaseUser:
# state parameter is artificial
return BaseUser(state=self, data=data) # type: ignore
@property
def http(self):
def http(self) -> Union[HTTPClient, _FriendlyHttpAttributeErrorHelper]:
if self._parent is not None:
return self._parent.http
@ -630,7 +649,7 @@ class _WebhookState:
# however, using it should result in a late-binding error.
return _FriendlyHttpAttributeErrorHelper()
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> Any:
if self._parent is not None:
return getattr(self._parent, attr)
@ -830,19 +849,24 @@ class BaseWebhook(Hashable):
'_state',
)
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None):
def __init__(
self,
data: WebhookPayload,
token: Optional[str] = None,
state: Optional[_State] = None,
) -> None:
self.auth_token: Optional[str] = token
self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(self, parent=state)
self._state: _State = state or _WebhookState(self, parent=state)
self._update(data)
def _update(self, data: WebhookPayload):
self.id = int(data['id'])
self.type = try_enum(WebhookType, int(data['type']))
self.channel_id = utils._get_as_snowflake(data, 'channel_id')
self.guild_id = utils._get_as_snowflake(data, 'guild_id')
self.name = data.get('name')
self._avatar = data.get('avatar')
self.token = data.get('token')
def _update(self, data: WebhookPayload) -> None:
self.id: int = int(data['id'])
self.type: WebhookType = try_enum(WebhookType, int(data['type']))
self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id')
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
self.name: Optional[str] = data.get('name')
self._avatar: Optional[str] = data.get('avatar')
self.token: Optional[str] = data.get('token')
user = data.get('user')
self.user: Optional[Union[BaseUser, User]] = None
@ -1010,11 +1034,17 @@ class Webhook(BaseWebhook):
__slots__: Tuple[str, ...] = ('session',)
def __init__(self, data: WebhookPayload, session: aiohttp.ClientSession, token: Optional[str] = None, state=None):
def __init__(
self,
data: WebhookPayload,
session: aiohttp.ClientSession,
token: Optional[str] = None,
state: Optional[_State] = None,
) -> None:
super().__init__(data, token, state)
self.session = session
self.session: aiohttp.ClientSession = session
def __repr__(self):
def __repr__(self) -> str:
return f'<Webhook id={self.id!r}>'
@property
@ -1023,7 +1053,7 @@ class Webhook(BaseWebhook):
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
@classmethod
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Self:
"""Creates a partial :class:`Webhook`.
Parameters
@ -1059,7 +1089,7 @@ class Webhook(BaseWebhook):
return cls(data, session, token=bot_token)
@classmethod
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Self:
"""Creates a partial :class:`Webhook` from a webhook URL.
.. versionchanged:: 2.0
@ -1102,7 +1132,7 @@ class Webhook(BaseWebhook):
return cls(data, session, token=bot_token) # type: ignore
@classmethod
def _as_follower(cls, data, *, channel, user) -> Webhook:
def _as_follower(cls, data, *, channel, user) -> Self:
name = f"{channel.guild} #{channel}"
feed: WebhookPayload = {
'id': data['webhook_id'],
@ -1118,8 +1148,8 @@ class Webhook(BaseWebhook):
return cls(feed, session=session, state=state, token=state.http.token)
@classmethod
def from_state(cls, data, state) -> Webhook:
session = state.http._HTTPClient__session
def from_state(cls, data: WebhookPayload, state: ConnectionState) -> Self:
session = state.http._HTTPClient__session # type: ignore
return cls(data, session=session, state=state, token=state.http.token)
async def fetch(self, *, prefer_auth: bool = True) -> Webhook:
@ -1168,7 +1198,7 @@ class Webhook(BaseWebhook):
return Webhook(data, self.session, token=self.auth_token, state=self._state)
async def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True):
async def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True) -> None:
"""|coro|
Deletes this Webhook.

View File

@ -37,7 +37,7 @@ import time
import re
from urllib.parse import quote as urlquote
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, TypeVar, Type, overload
import weakref
from .. import utils
@ -56,36 +56,50 @@ __all__ = (
_log = logging.getLogger(__name__)
if TYPE_CHECKING:
from typing_extensions import Self
from types import TracebackType
from ..file import File
from ..embeds import Embed
from ..mentions import AllowedMentions
from ..message import Attachment
from ..abc import Snowflake
from ..state import ConnectionState
from ..types.webhook import (
Webhook as WebhookPayload,
)
from ..abc import Snowflake
from ..types.message import (
Message as MessagePayload,
)
BE = TypeVar('BE', bound=BaseException)
try:
from requests import Session, Response
except ModuleNotFoundError:
pass
MISSING = utils.MISSING
MISSING: Any = utils.MISSING
class DeferredLock:
def __init__(self, lock: threading.Lock):
self.lock = lock
def __init__(self, lock: threading.Lock) -> None:
self.lock: threading.Lock = lock
self.delta: Optional[float] = None
def __enter__(self):
def __enter__(self) -> Self:
self.lock.acquire()
return self
def delay_by(self, delta: float) -> None:
self.delta = delta
def __exit__(self, type, value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BE]],
exc: Optional[BE],
traceback: Optional[TracebackType],
) -> None:
if self.delta:
time.sleep(self.delta)
self.lock.release()
@ -218,7 +232,7 @@ class WebhookAdapter:
token: Optional[str] = None,
session: Session,
reason: Optional[str] = None,
):
) -> None:
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, auth_token=token)
@ -229,7 +243,7 @@ class WebhookAdapter:
*,
session: Session,
reason: Optional[str] = None,
):
) -> None:
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason)
@ -241,7 +255,7 @@ class WebhookAdapter:
*,
session: Session,
reason: Optional[str] = None,
):
) -> WebhookPayload:
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, payload=payload, auth_token=token)
@ -253,7 +267,7 @@ class WebhookAdapter:
*,
session: Session,
reason: Optional[str] = None,
):
) -> WebhookPayload:
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason, payload=payload)
@ -268,7 +282,7 @@ class WebhookAdapter:
files: Optional[List[File]] = None,
thread_id: Optional[int] = None,
wait: bool = False,
):
) -> MessagePayload:
params = {'wait': int(wait)}
if thread_id:
params['thread_id'] = thread_id
@ -282,7 +296,7 @@ class WebhookAdapter:
message_id: int,
*,
session: Session,
):
) -> MessagePayload:
route = Route(
'GET',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -302,7 +316,7 @@ class WebhookAdapter:
payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None,
):
) -> MessagePayload:
route = Route(
'PATCH',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -319,7 +333,7 @@ class WebhookAdapter:
message_id: int,
*,
session: Session,
):
) -> None:
route = Route(
'DELETE',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -335,7 +349,7 @@ class WebhookAdapter:
token: str,
*,
session: Session,
):
) -> WebhookPayload:
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session=session, auth_token=token)
@ -345,7 +359,7 @@ class WebhookAdapter:
token: str,
*,
session: Session,
):
) -> WebhookPayload:
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session=session)
@ -569,11 +583,17 @@ class SyncWebhook(BaseWebhook):
__slots__: Tuple[str, ...] = ('session',)
def __init__(self, data: WebhookPayload, session: Session, token: Optional[str] = None, state=None):
def __init__(
self,
data: WebhookPayload,
session: Session,
token: Optional[str] = None,
state: Optional[Union[ConnectionState, _WebhookState]] = None,
) -> None:
super().__init__(data, token, state)
self.session = session
self.session: Session = session
def __repr__(self):
def __repr__(self) -> str:
return f'<Webhook id={self.id!r}>'
@property
@ -812,7 +832,7 @@ class SyncWebhook(BaseWebhook):
return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state)
def _create_message(self, data):
def _create_message(self, data: MessagePayload) -> SyncWebhookMessage:
state = _WebhookState(self, parent=self._state)
# state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore