Type-hint user.py
This commit is contained in:
parent
36b9bc8ee3
commit
529fad6fec
@ -22,19 +22,36 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, TYPE_CHECKING
|
||||
|
||||
import discord.abc
|
||||
from .asset import Asset
|
||||
from .colour import Colour
|
||||
from .enums import DefaultAvatar
|
||||
from .flags import PublicUserFlags
|
||||
from .utils import snowflake_time, _bytes_to_base64_data, MISSING
|
||||
from .enums import DefaultAvatar
|
||||
from .colour import Colour
|
||||
from .asset import Asset
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import datetime
|
||||
|
||||
from .channel import DMChannel
|
||||
from .guild import Guild
|
||||
from .message import Message
|
||||
from .state import ConnectionState
|
||||
from .types.channel import DMChannel as DMChannelPayload
|
||||
from .types.user import User as UserPayload
|
||||
|
||||
|
||||
__all__ = (
|
||||
'User',
|
||||
'ClientUser',
|
||||
)
|
||||
|
||||
U = TypeVar('U', bound='User')
|
||||
BU = TypeVar('BU', bound='BaseUser')
|
||||
|
||||
|
||||
class _UserTag:
|
||||
__slots__ = ()
|
||||
@ -50,30 +67,35 @@ class BaseUser(_UserTag):
|
||||
discriminator: str
|
||||
bot: bool
|
||||
system: bool
|
||||
_state: ConnectionState
|
||||
_avatar: str
|
||||
_banner: Optional[str]
|
||||
_accent_colour: Optional[str]
|
||||
_public_flags: int
|
||||
|
||||
def __init__(self, *, state, data):
|
||||
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
|
||||
self._state = state
|
||||
self._update(data)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<BaseUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}"
|
||||
f" bot={self.bot} system={self.system}>"
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f'{self.name}#{self.discriminator}'
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, _UserTag) and other.id == self.id
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return self.id >> 22
|
||||
|
||||
def _update(self, data):
|
||||
def _update(self, data: UserPayload) -> None:
|
||||
self.name = data['username']
|
||||
self.id = int(data['id'])
|
||||
self.discriminator = data['discriminator']
|
||||
@ -85,7 +107,7 @@ class BaseUser(_UserTag):
|
||||
self.system = data.get('system', False)
|
||||
|
||||
@classmethod
|
||||
def _copy(cls, user):
|
||||
def _copy(cls: Type[BU], user: BU) -> BU:
|
||||
self = cls.__new__(cls) # bypass __init__
|
||||
|
||||
self.name = user.name
|
||||
@ -100,7 +122,7 @@ class BaseUser(_UserTag):
|
||||
|
||||
return self
|
||||
|
||||
def _to_minimal_user_json(self):
|
||||
def _to_minimal_user_json(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'username': self.name,
|
||||
'id': self.id,
|
||||
@ -110,12 +132,12 @@ class BaseUser(_UserTag):
|
||||
}
|
||||
|
||||
@property
|
||||
def public_flags(self):
|
||||
def public_flags(self) -> PublicUserFlags:
|
||||
""":class:`PublicUserFlags`: The publicly available flags the user has."""
|
||||
return PublicUserFlags._from_value(self._public_flags)
|
||||
|
||||
@property
|
||||
def avatar(self):
|
||||
def avatar(self) -> Asset:
|
||||
""":class:`Asset`: Returns an :class:`Asset` for the avatar the user has.
|
||||
|
||||
If the user does not have a traditional avatar, an asset for
|
||||
@ -127,7 +149,7 @@ class BaseUser(_UserTag):
|
||||
return Asset._from_avatar(self._state, self.id, self._avatar)
|
||||
|
||||
@property
|
||||
def default_avatar(self):
|
||||
def default_avatar(self) -> Asset:
|
||||
""":class:`Asset`: Returns the default avatar for a given user. This is calculated by the user's discriminator."""
|
||||
return Asset._from_default_avatar(self._state, int(self.discriminator) % len(DefaultAvatar))
|
||||
|
||||
@ -176,7 +198,7 @@ class BaseUser(_UserTag):
|
||||
return self.accent_colour
|
||||
|
||||
@property
|
||||
def colour(self):
|
||||
def colour(self) -> Colour:
|
||||
""":class:`Colour`: A property that returns a colour denoting the rendered colour
|
||||
for the user. This always returns :meth:`Colour.default`.
|
||||
|
||||
@ -185,7 +207,7 @@ class BaseUser(_UserTag):
|
||||
return Colour.default()
|
||||
|
||||
@property
|
||||
def color(self):
|
||||
def color(self) -> Colour:
|
||||
""":class:`Colour`: A property that returns a color denoting the rendered color
|
||||
for the user. This always returns :meth:`Colour.default`.
|
||||
|
||||
@ -194,12 +216,12 @@ class BaseUser(_UserTag):
|
||||
return self.colour
|
||||
|
||||
@property
|
||||
def mention(self):
|
||||
def mention(self) -> str:
|
||||
""":class:`str`: Returns a string that allows you to mention the given user."""
|
||||
return f'<@{self.id}>'
|
||||
|
||||
@property
|
||||
def created_at(self):
|
||||
def created_at(self) -> datetime:
|
||||
""":class:`datetime.datetime`: Returns the user's creation time in UTC.
|
||||
|
||||
This is when the user's Discord account was created.
|
||||
@ -207,7 +229,7 @@ class BaseUser(_UserTag):
|
||||
return snowflake_time(self.id)
|
||||
|
||||
@property
|
||||
def display_name(self):
|
||||
def display_name(self) -> str:
|
||||
""":class:`str`: Returns the user's display name.
|
||||
|
||||
For regular users this is just their username, but
|
||||
@ -216,7 +238,7 @@ class BaseUser(_UserTag):
|
||||
"""
|
||||
return self.name
|
||||
|
||||
def mentioned_in(self, message):
|
||||
def mentioned_in(self, message: Message) -> bool:
|
||||
"""Checks if the user is mentioned in the specified message.
|
||||
|
||||
Parameters
|
||||
@ -282,16 +304,22 @@ class ClientUser(BaseUser):
|
||||
|
||||
__slots__ = ('locale', '_flags', 'verified', 'mfa_enabled', '__weakref__')
|
||||
|
||||
def __init__(self, *, state, data):
|
||||
if TYPE_CHECKING:
|
||||
verified: bool
|
||||
local: Optional[str]
|
||||
mfa_enabled: bool
|
||||
_flags: int
|
||||
|
||||
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
|
||||
super().__init__(state=state, data=data)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<ClientUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}'
|
||||
f' bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>'
|
||||
)
|
||||
|
||||
def _update(self, data):
|
||||
def _update(self, data: UserPayload) -> None:
|
||||
super()._update(data)
|
||||
# There's actually an Optional[str] phone field as well but I won't use it
|
||||
self.verified = data.get('verified', False)
|
||||
@ -335,7 +363,7 @@ class ClientUser(BaseUser):
|
||||
if avatar is not MISSING:
|
||||
payload['avatar'] = _bytes_to_base64_data(avatar)
|
||||
|
||||
data = await self._state.http.edit_profile(payload)
|
||||
data: UserPayload = await self._state.http.edit_profile(payload)
|
||||
self._update(data)
|
||||
|
||||
|
||||
@ -376,11 +404,14 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
|
||||
__slots__ = ('_stored',)
|
||||
|
||||
def __init__(self, *, state, data):
|
||||
if TYPE_CHECKING:
|
||||
_stored: bool
|
||||
|
||||
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
|
||||
super().__init__(state=state, data=data)
|
||||
self._stored = False
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>'
|
||||
|
||||
def __del__(self) -> None:
|
||||
@ -391,17 +422,17 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _copy(cls, user):
|
||||
def _copy(cls: Type[U], user: U) -> U:
|
||||
self = super()._copy(user)
|
||||
self._stored = False
|
||||
return self
|
||||
|
||||
async def _get_channel(self):
|
||||
async def _get_channel(self) -> DMChannel:
|
||||
ch = await self.create_dm()
|
||||
return ch
|
||||
|
||||
@property
|
||||
def dm_channel(self):
|
||||
def dm_channel(self) -> Optional[DMChannel]:
|
||||
"""Optional[:class:`DMChannel`]: Returns the channel associated with this user if it exists.
|
||||
|
||||
If this returns ``None``, you can create a DM channel by calling the
|
||||
@ -410,7 +441,7 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
return self._state._get_private_channel_by_user(self.id)
|
||||
|
||||
@property
|
||||
def mutual_guilds(self):
|
||||
def mutual_guilds(self) -> List[Guild]:
|
||||
"""List[:class:`Guild`]: The guilds that the user shares with the client.
|
||||
|
||||
.. note::
|
||||
@ -421,7 +452,7 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
"""
|
||||
return [guild for guild in self._state._guilds.values() if guild.get_member(self.id)]
|
||||
|
||||
async def create_dm(self):
|
||||
async def create_dm(self) -> DMChannel:
|
||||
"""|coro|
|
||||
|
||||
Creates a :class:`DMChannel` with this user.
|
||||
@ -439,5 +470,5 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
return found
|
||||
|
||||
state = self._state
|
||||
data = await state.http.start_private_message(self.id)
|
||||
data: DMChannelPayload = await state.http.start_private_message(self.id)
|
||||
return state.add_dm_channel(data)
|
||||
|
Loading…
x
Reference in New Issue
Block a user