Type hint GuildChannel and don't make it a Protocol

This reverts GuildChannel back into a base class mixin.
This commit is contained in:
Rapptz
2021-05-05 11:14:58 -04:00
parent 7fde57c89a
commit c31946f29f
2 changed files with 153 additions and 32 deletions

View File

@@ -26,7 +26,7 @@ from __future__ import annotations
import copy import copy
import asyncio import asyncio
from typing import TYPE_CHECKING, Protocol, runtime_checkable from typing import Any, Dict, List, Mapping, Optional, TYPE_CHECKING, Protocol, TypeVar, Union, overload, runtime_checkable
from .iterators import HistoryIterator from .iterators import HistoryIterator
from .context_managers import Typing from .context_managers import Typing
@@ -54,16 +54,21 @@ if TYPE_CHECKING:
from .user import ClientUser from .user import ClientUser
from .asset import Asset from .asset import Asset
from .state import ConnectionState
from .guild import Guild
from .member import Member
from .channel import CategoryChannel
MISSING = utils.MISSING MISSING = utils.MISSING
class _Undefined: class _Undefined:
def __repr__(self): def __repr__(self):
return 'see-below' return 'see-below'
_undefined = _Undefined() _undefined: Any = _Undefined()
@runtime_checkable @runtime_checkable
@@ -81,6 +86,7 @@ class Snowflake(Protocol):
id: :class:`int` id: :class:`int`
The model's unique ID. The model's unique ID.
""" """
__slots__ = () __slots__ = ()
id: int id: int
@@ -113,6 +119,7 @@ class User(Snowflake, Protocol):
bot: :class:`bool` bot: :class:`bool`
If the user is a bot account. If the user is a bot account.
""" """
__slots__ = () __slots__ = ()
name: str name: str
@@ -147,6 +154,7 @@ class PrivateChannel(Snowflake, Protocol):
me: :class:`~discord.ClientUser` me: :class:`~discord.ClientUser`
The user presenting yourself. The user presenting yourself.
""" """
__slots__ = () __slots__ = ()
me: ClientUser me: ClientUser
@@ -179,7 +187,10 @@ class _Overwrites:
return self.type == 1 return self.type == 1
class GuildChannel(Snowflake, Protocol): GCH = TypeVar('GCH', bound='GuildChannel')
class GuildChannel:
"""An ABC that details the common operations on a Discord guild channel. """An ABC that details the common operations on a Discord guild channel.
The following implement this ABC: The following implement this ABC:
@@ -206,16 +217,38 @@ class GuildChannel(Snowflake, Protocol):
The position in the channel list. This is a number that starts at 0. The position in the channel list. This is a number that starts at 0.
e.g. the top channel is position 0. e.g. the top channel is position 0.
""" """
__slots__ = () __slots__ = ()
def __str__(self): id: int
name: str
guild: Guild
type: ChannelType
_state: ConnectionState
if TYPE_CHECKING:
def __init__(self, *, state: ConnectionState, guild: Guild, data: Dict[str, Any]):
...
def __str__(self) -> str:
return self.name return self.name
@property @property
def _sorting_bucket(self): def _sorting_bucket(self) -> int:
raise NotImplementedError raise NotImplementedError
async def _move(self, position, parent_id=None, lock_permissions=False, *, reason): def _update(self, guild: Guild, data: Dict[str, Any]) -> None:
raise NotImplementedError
async def _move(
self,
position: int,
parent_id: Optional[Any] = None,
lock_permissions: bool = False,
*,
reason: Optional[str],
):
if position < 0: if position < 0:
raise InvalidArgument('Channel position cannot be less than 0.') raise InvalidArgument('Channel position cannot be less than 0.')
@@ -304,7 +337,7 @@ class GuildChannel(Snowflake, Protocol):
payload = { payload = {
'allow': allow.value, 'allow': allow.value,
'deny': deny.value, 'deny': deny.value,
'id': target.id 'id': target.id,
} }
if isinstance(target, Role): if isinstance(target, Role):
@@ -354,7 +387,7 @@ class GuildChannel(Snowflake, Protocol):
tmp[everyone_index], tmp[0] = tmp[0], tmp[everyone_index] tmp[everyone_index], tmp[0] = tmp[0], tmp[everyone_index]
@property @property
def changed_roles(self): def changed_roles(self) -> List[Role]:
"""List[:class:`~discord.Role`]: Returns a list of roles that have been overridden from """List[:class:`~discord.Role`]: Returns a list of roles that have been overridden from
their default values in the :attr:`~discord.Guild.roles` attribute.""" their default values in the :attr:`~discord.Guild.roles` attribute."""
ret = [] ret = []
@@ -370,16 +403,16 @@ class GuildChannel(Snowflake, Protocol):
return ret return ret
@property @property
def mention(self): def mention(self) -> str:
""":class:`str`: The string that allows you to mention the channel.""" """:class:`str`: The string that allows you to mention the channel."""
return f'<#{self.id}>' return f'<#{self.id}>'
@property @property
def created_at(self): def created_at(self) -> datetime:
""":class:`datetime.datetime`: Returns the channel's creation time in UTC.""" """:class:`datetime.datetime`: Returns the channel's creation time in UTC."""
return utils.snowflake_time(self.id) return utils.snowflake_time(self.id)
def overwrites_for(self, obj): def overwrites_for(self, obj: Union[Role, User]) -> PermissionOverwrite:
"""Returns the channel-specific overwrites for a member or a role. """Returns the channel-specific overwrites for a member or a role.
Parameters Parameters
@@ -410,7 +443,7 @@ class GuildChannel(Snowflake, Protocol):
return PermissionOverwrite() return PermissionOverwrite()
@property @property
def overwrites(self): def overwrites(self) -> Mapping[Union[Role, Member], PermissionOverwrite]:
"""Returns all of the channel's overwrites. """Returns all of the channel's overwrites.
This is returned as a dictionary where the key contains the target which This is returned as a dictionary where the key contains the target which
@@ -427,6 +460,7 @@ class GuildChannel(Snowflake, Protocol):
allow = Permissions(ow.allow) allow = Permissions(ow.allow)
deny = Permissions(ow.deny) deny = Permissions(ow.deny)
overwrite = PermissionOverwrite.from_pair(allow, deny) overwrite = PermissionOverwrite.from_pair(allow, deny)
target = None
if ow.is_role(): if ow.is_role():
target = self.guild.get_role(ow.id) target = self.guild.get_role(ow.id)
@@ -443,7 +477,7 @@ class GuildChannel(Snowflake, Protocol):
return ret return ret
@property @property
def category(self): def category(self) -> Optional[CategoryChannel]:
"""Optional[:class:`~discord.CategoryChannel`]: The category this channel belongs to. """Optional[:class:`~discord.CategoryChannel`]: The category this channel belongs to.
If there is no category then this is ``None``. If there is no category then this is ``None``.
@@ -451,7 +485,7 @@ class GuildChannel(Snowflake, Protocol):
return self.guild.get_channel(self.category_id) return self.guild.get_channel(self.category_id)
@property @property
def permissions_synced(self): def permissions_synced(self) -> bool:
""":class:`bool`: Whether or not the permissions for this channel are synced with the """:class:`bool`: Whether or not the permissions for this channel are synced with the
category it belongs to. category it belongs to.
@@ -462,7 +496,7 @@ class GuildChannel(Snowflake, Protocol):
category = self.guild.get_channel(self.category_id) category = self.guild.get_channel(self.category_id)
return bool(category and category.overwrites == self.overwrites) return bool(category and category.overwrites == self.overwrites)
def permissions_for(self, obj, /): def permissions_for(self, obj: Union[Member, Role], /) -> Permissions:
"""Handles permission resolution for the :class:`~discord.Member` """Handles permission resolution for the :class:`~discord.Member`
or :class:`~discord.Role`. or :class:`~discord.Role`.
@@ -595,7 +629,7 @@ class GuildChannel(Snowflake, Protocol):
return base return base
async def delete(self, *, reason=None): async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro| """|coro|
Deletes the channel. Deletes the channel.
@@ -619,7 +653,14 @@ class GuildChannel(Snowflake, Protocol):
""" """
await self._state.http.delete_channel(self.id, reason=reason) await self._state.http.delete_channel(self.id, reason=reason)
async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions): async def set_permissions(
self,
target: Union[Member, Role],
*,
overwrite: Optional[PermissionOverwrite] = _undefined,
reason: Optional[str] = None,
**permissions: bool,
) -> None:
r"""|coro| r"""|coro|
Sets the channel specific permission overwrites for a target in the Sets the channel specific permission overwrites for a target in the
@@ -714,10 +755,14 @@ class GuildChannel(Snowflake, Protocol):
else: else:
raise InvalidArgument('Invalid overwrite type provided.') raise InvalidArgument('Invalid overwrite type provided.')
async def _clone_impl(self, base_attrs, *, name=None, reason=None): async def _clone_impl(
base_attrs['permission_overwrites'] = [ self: GCH,
x._asdict() for x in self._overwrites base_attrs: Dict[str, Any],
] *,
name: Optional[str] = None,
reason: Optional[str] = None,
) -> GCH:
base_attrs['permission_overwrites'] = [x._asdict() for x in self._overwrites]
base_attrs['parent_id'] = self.category_id base_attrs['parent_id'] = self.category_id
base_attrs['name'] = name or self.name base_attrs['name'] = name or self.name
guild_id = self.guild.id guild_id = self.guild.id
@@ -729,7 +774,7 @@ class GuildChannel(Snowflake, Protocol):
self.guild._channels[obj.id] = obj self.guild._channels[obj.id] = obj
return obj return obj
async def clone(self, *, name=None, reason=None): async def clone(self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None) -> GCH:
"""|coro| """|coro|
Clones this channel. This creates a channel with the same properties Clones this channel. This creates a channel with the same properties
@@ -762,7 +807,55 @@ class GuildChannel(Snowflake, Protocol):
""" """
raise NotImplementedError raise NotImplementedError
async def move(self, **kwargs): @overload
async def move(
self,
*,
beginning: bool,
offset: int = MISSING,
category: Optional[Snowflake] = MISSING,
sync_permissions: bool = MISSING,
reason: str = MISSING,
) -> None:
...
@overload
async def move(
self,
*,
end: bool,
offset: int = MISSING,
category: Optional[Snowflake] = MISSING,
sync_permissions: bool = MISSING,
reason: str = MISSING,
) -> None:
...
@overload
async def move(
self,
*,
before: Snowflake,
offset: int = MISSING,
category: Optional[Snowflake] = MISSING,
sync_permissions: bool = MISSING,
reason: str = MISSING,
) -> None:
...
@overload
async def move(
self,
*,
after: Snowflake,
offset: int = MISSING,
category: Optional[Snowflake] = MISSING,
sync_permissions: bool = MISSING,
reason: str = MISSING,
) -> None:
...
async def move(self, **kwargs) -> None:
"""|coro| """|coro|
A rich interface to help move a channel relative to other channels. A rich interface to help move a channel relative to other channels.
@@ -832,6 +925,7 @@ class GuildChannel(Snowflake, Protocol):
bucket = self._sorting_bucket bucket = self._sorting_bucket
parent_id = kwargs.get('category', MISSING) parent_id = kwargs.get('category', MISSING)
# fmt: off
if parent_id not in (MISSING, None): if parent_id not in (MISSING, None):
parent_id = parent_id.id parent_id = parent_id.id
channels = [ channels = [
@@ -847,6 +941,7 @@ class GuildChannel(Snowflake, Protocol):
if ch._sorting_bucket == bucket if ch._sorting_bucket == bucket
and ch.category_id == self.category_id and ch.category_id == self.category_id
] ]
# fmt: on
channels.sort(key=lambda c: (c.position, c.id)) channels.sort(key=lambda c: (c.position, c.id))
@@ -882,7 +977,15 @@ class GuildChannel(Snowflake, Protocol):
await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason) await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason)
async def create_invite(self, *, reason=None, **fields): async def create_invite(
self,
*,
reason: Optional[str] = None,
max_age: int = 0,
max_uses: int = 0,
temporary: bool = False,
unique: bool = True,
) -> Invite:
"""|coro| """|coro|
Creates an instant invite from a text or voice channel. Creates an instant invite from a text or voice channel.
@@ -922,10 +1025,17 @@ class GuildChannel(Snowflake, Protocol):
The invite that was created. The invite that was created.
""" """
data = await self._state.http.create_invite(self.id, reason=reason, **fields) data = await self._state.http.create_invite(
self.id,
reason=reason,
max_age=max_age,
max_uses=max_uses,
temporary=temporary,
unique=unique,
)
return Invite.from_incomplete(data=data, state=self._state) return Invite.from_incomplete(data=data, state=self._state)
async def invites(self): async def invites(self) -> List[Invite]:
"""|coro| """|coro|
Returns a list of all active instant invites from this channel. Returns a list of all active instant invites from this channel.
@@ -1283,6 +1393,7 @@ class Connectable(Protocol):
This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass` This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
checks. checks.
""" """
__slots__ = () __slots__ = ()
def _get_voice_client_key(self): def _get_voice_client_key(self):

View File

@@ -28,7 +28,7 @@ import asyncio
import json import json
import logging import logging
import sys import sys
from typing import Any, Coroutine, List, TYPE_CHECKING, TypeVar from typing import Any, Coroutine, List, Optional, TYPE_CHECKING, TypeVar
from urllib.parse import quote as _uriquote from urllib.parse import quote as _uriquote
import weakref import weakref
@@ -43,6 +43,7 @@ log = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from .types import ( from .types import (
interactions, interactions,
invite,
) )
T = TypeVar('T') T = TypeVar('T')
@@ -966,13 +967,22 @@ class HTTPClient:
# Invite management # Invite management
def create_invite(self, channel_id, *, reason=None, **options): def create_invite(
self,
channel_id: int,
*,
reason: Optional[str] = None,
max_age: int = 0,
max_uses: int = 0,
temporary: bool = False,
unique: bool = True,
) -> Response[invite.Invite]:
r = Route('POST', '/channels/{channel_id}/invites', channel_id=channel_id) r = Route('POST', '/channels/{channel_id}/invites', channel_id=channel_id)
payload = { payload = {
'max_age': options.get('max_age', 0), 'max_age': max_age,
'max_uses': options.get('max_uses', 0), 'max_uses': max_uses,
'temporary': options.get('temporary', False), 'temporary': temporary,
'unique': options.get('unique', True), 'unique': unique,
} }
return self.request(r, reason=reason, json=payload) return self.request(r, reason=reason, json=payload)