Use typing.Self throughout library

This commit is contained in:
Josh
2022-03-01 22:53:24 +10:00
committed by GitHub
parent a90e1824f4
commit 147948af9b
28 changed files with 212 additions and 191 deletions

View File

@ -45,6 +45,8 @@ from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog
if TYPE_CHECKING:
from typing_extensions import Self
import importlib.machinery
from discord.message import Message
@ -934,14 +936,27 @@ class BotBase(GroupMixin):
return ret
@overload
async def get_context(self: BT, message: Message) -> Context[BT]:
async def get_context(
self,
message: Message,
) -> Context[Self]: # type: ignore
...
@overload
async def get_context(self, message: Message, *, cls: Type[CXT] = ...) -> CXT:
async def get_context(
self,
message: Message,
*,
cls: Type[CXT] = ...,
) -> CXT: # type: ignore
...
async def get_context(self, message: Message, *, cls: Type[Context] = Context) -> Any:
async def get_context(
self,
message: Message,
*,
cls: Type[CXT] = MISSING,
) -> Any:
r"""|coro|
Returns the invocation context from the message.
@ -970,6 +985,8 @@ class BotBase(GroupMixin):
The invocation context. The type of this can change via the
``cls`` parameter.
"""
if cls is MISSING:
cls = Context # type: ignore
view = StringView(message.content)
ctx = cls(prefix=None, view=view, bot=self, message=message)

View File

@ -31,6 +31,8 @@ from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYP
from ._types import _BaseCommand
if TYPE_CHECKING:
from typing_extensions import Self
from .bot import BotBase
from .context import Context
from .core import Command
@ -40,7 +42,6 @@ __all__ = (
'Cog',
)
CogT = TypeVar('CogT', bound='Cog')
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
MISSING: Any = discord.utils.MISSING
@ -111,7 +112,7 @@ class CogMeta(type):
__cog_commands__: List[Command]
__cog_listeners__: List[Tuple[str, str]]
def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
name, bases, attrs = args
attrs['__cog_name__'] = kwargs.pop('name', name)
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
@ -190,10 +191,10 @@ class Cog(metaclass=CogMeta):
__cog_name__: ClassVar[str]
__cog_settings__: ClassVar[Dict[str, Any]]
__cog_commands__: ClassVar[List[Command]]
__cog_commands__: ClassVar[List[Command[Self, Any, Any]]]
__cog_listeners__: ClassVar[List[Tuple[str, str]]]
def __new__(cls: Type[CogT], *args: Any, **kwargs: Any) -> CogT:
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
# For issue 426, we need to store a copy of the command objects
# since we modify them to inject `self` to them.
# To do this, we need to interfere with the Cog creation process.
@ -220,7 +221,7 @@ class Cog(metaclass=CogMeta):
return self
def get_commands(self) -> List[Command]:
def get_commands(self) -> List[Command[Self, Any, Any]]:
r"""
Returns
--------
@ -248,7 +249,7 @@ class Cog(metaclass=CogMeta):
def description(self, description: str) -> None:
self.__cog_description__ = description
def walk_commands(self) -> Generator[Command, None, None]:
def walk_commands(self) -> Generator[Command[Self, Any, Any], None, None]:
"""An iterator that recursively walks through this cog's commands and subcommands.
Yields
@ -418,7 +419,7 @@ class Cog(metaclass=CogMeta):
"""
pass
def _inject(self: CogT, bot: BotBase) -> CogT:
def _inject(self, bot: BotBase) -> Self:
cls = self.__class__
# realistically, the only thing that can cause loading errors

View File

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, Callable, Deque, Dict, Optional, Type, TypeVar, TYPE_CHECKING
from typing import Any, Callable, Deque, Dict, Optional, TYPE_CHECKING
from discord.enums import Enum
import time
import asyncio
@ -35,6 +35,8 @@ from ...abc import PrivateChannel
from .errors import MaxConcurrencyReached
if TYPE_CHECKING:
from typing_extensions import Self
from ...message import Message
__all__ = (
@ -45,9 +47,6 @@ __all__ = (
'MaxConcurrency',
)
C = TypeVar('C', bound='CooldownMapping')
MC = TypeVar('MC', bound='MaxConcurrency')
class BucketType(Enum):
default = 0
@ -221,7 +220,7 @@ class CooldownMapping:
return self._type
@classmethod
def from_cooldown(cls: Type[C], rate, per, type) -> C:
def from_cooldown(cls, rate, per, type) -> Self:
return cls(Cooldown(rate, per), type)
def _bucket_key(self, msg: Message) -> Any:
@ -356,7 +355,7 @@ class MaxConcurrency:
if not isinstance(per, BucketType):
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}')
def copy(self: MC) -> MC:
def copy(self) -> Self:
return self.__class__(self.number, per=self.per, wait=self.wait)
def __repr__(self) -> str:

View File

@ -56,7 +56,7 @@ from .context import Context
if TYPE_CHECKING:
from typing_extensions import Concatenate, ParamSpec, TypeGuard
from typing_extensions import Concatenate, ParamSpec, TypeGuard, Self
from discord.message import Message
@ -292,7 +292,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
"""
__original_kwargs__: Dict[str, Any]
def __new__(cls: Type[CommandT], *args: Any, **kwargs: Any) -> CommandT:
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
# if you're wondering why this is done, it's because we need to ensure
# we have a complete original copy of **kwargs even for classes that
# mess with it by popping before delegating to the subclass __init__.
@ -498,7 +498,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
else:
return await self.callback(context, *args, **kwargs) # type: ignore
def _ensure_assignment_on_copy(self, other: CommandT) -> CommandT:
def _ensure_assignment_on_copy(self, other: Self) -> Self:
other._before_invoke = self._before_invoke
other._after_invoke = self._after_invoke
if self.checks != other.checks:
@ -515,7 +515,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
pass
return other
def copy(self: CommandT) -> CommandT:
def copy(self) -> Self:
"""Creates a copy of this command.
Returns
@ -526,7 +526,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
ret = self.__class__(self.callback, **self.__original_kwargs__)
return self._ensure_assignment_on_copy(ret)
def _update_copy(self: CommandT, kwargs: Dict[str, Any]) -> CommandT:
def _update_copy(self, kwargs: Dict[str, Any]) -> Self:
if kwargs:
kw = kwargs.copy()
kw.update(self.__original_kwargs__)
@ -1446,7 +1446,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
self.invoke_without_command: bool = attrs.pop('invoke_without_command', False)
super().__init__(*args, **attrs)
def copy(self: GroupT) -> GroupT:
def copy(self) -> Self:
"""Creates a copy of this :class:`Group`.
Returns

View File

@ -66,6 +66,8 @@ __all__ = (
if TYPE_CHECKING:
from typing_extensions import Self
from .context import Context
@ -265,7 +267,7 @@ class FlagsMeta(type):
__commands_flag_prefix__: str
def __new__(
cls: Type[type],
cls,
name: str,
bases: Tuple[type, ...],
attrs: Dict[str, Any],
@ -273,7 +275,7 @@ class FlagsMeta(type):
case_insensitive: bool = MISSING,
delimiter: str = MISSING,
prefix: str = MISSING,
):
) -> Self:
attrs['__commands_is_flag__'] = True
try:
@ -432,9 +434,6 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -
raise BadFlagArgument(flag) from e
F = TypeVar('F', bound='FlagConverter')
class FlagConverter(metaclass=FlagsMeta):
"""A converter that allows for a user-friendly flag syntax.
@ -481,8 +480,8 @@ class FlagConverter(metaclass=FlagsMeta):
yield (flag.name, getattr(self, flag.attribute))
@classmethod
async def _construct_default(cls: Type[F], ctx: Context) -> F:
self: F = cls.__new__(cls)
async def _construct_default(cls, ctx: Context) -> Self:
self = cls.__new__(cls)
flags = cls.__commands_flags__
for flag in flags.values():
if callable(flag.default):
@ -547,7 +546,7 @@ class FlagConverter(metaclass=FlagsMeta):
return result
@classmethod
async def convert(cls: Type[F], ctx: Context, argument: str) -> F:
async def convert(cls, ctx: Context, argument: str) -> Self:
"""|coro|
The method that actually converters an argument to the flag mapping.
@ -576,7 +575,7 @@ class FlagConverter(metaclass=FlagsMeta):
arguments = cls.parse_flags(argument)
flags = cls.__commands_flags__
self: F = cls.__new__(cls)
self = cls.__new__(cls)
for name, flag in flags.items():
try:
values = arguments[name]