Fix type-errors in commands extension

This commit is contained in:
Josh
2022-02-23 23:04:49 +10:00
committed by GitHub
parent a315786869
commit 39c5a4fdc3
5 changed files with 103 additions and 82 deletions

View File

@@ -33,7 +33,7 @@ import importlib.util
import sys
import traceback
import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union, overload
import discord
@@ -65,6 +65,7 @@ MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context')
BT = TypeVar('BT', bound='Union[Bot, AutoShardedBot]')
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
@@ -932,7 +933,15 @@ class BotBase(GroupMixin):
return ret
async def get_context(self, message: Message, *, cls: Type[CXT] = Context) -> CXT:
@overload
async def get_context(self: BT, message: Message) -> Context[BT]:
...
@overload
async def get_context(self, message: Message, *, cls: Type[CXT] = ...) -> CXT:
...
async def get_context(self, message: Message, *, cls: Type[Context] = Context) -> Any:
r"""|coro|
Returns the invocation context from the message.

View File

@@ -41,6 +41,7 @@ from typing import (
Tuple,
Union,
runtime_checkable,
overload,
)
import discord
@@ -48,7 +49,8 @@ from .errors import *
if TYPE_CHECKING:
from .context import Context
from discord.message import PartialMessageableChannel
from discord.state import Channel
from discord.threads import Thread
from .bot import Bot, AutoShardedBot
_Bot = Union[Bot, AutoShardedBot]
@@ -357,7 +359,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod
def _resolve_channel(
ctx: Context[_Bot], guild_id: Optional[int], channel_id: Optional[int]
) -> Optional[PartialMessageableChannel]:
) -> Optional[Union[Channel, Thread]]:
if channel_id is None:
# we were passed just a message id so we can assume the channel is the current context channel
return ctx.channel
@@ -373,8 +375,8 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialMessage:
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
channel = self._resolve_channel(ctx, guild_id, channel_id)
if not channel:
raise ChannelNotFound(channel_id)
if not channel or not isinstance(channel, discord.abc.Messageable):
raise ChannelNotFound(channel_id) # type: ignore - channel_id won't be None here
return discord.PartialMessage(channel=channel, id=message_id)
@@ -399,14 +401,14 @@ class MessageConverter(IDConverter[discord.Message]):
if message:
return message
channel = PartialMessageConverter._resolve_channel(ctx, guild_id, channel_id)
if not channel:
raise ChannelNotFound(channel_id)
if not channel or not isinstance(channel, discord.abc.Messageable):
raise ChannelNotFound(channel_id) # type: ignore - channel_id won't be None here
try:
return await channel.fetch_message(message_id)
except discord.NotFound:
raise MessageNotFound(argument)
except discord.Forbidden:
raise ChannelNotReadable(channel)
raise ChannelNotReadable(channel) # type: ignore - type-checker thinks channel could be a DMChannel at this point
class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
@@ -449,7 +451,8 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
else:
channel_id = int(match.group(1))
if guild:
result = guild.get_channel(channel_id)
# guild.get_channel returns an explicit union instead of the base class
result = guild.get_channel(channel_id) # type: ignore
else:
result = _get_from_guilds(bot, 'get_channel', channel_id)

View File

@@ -99,7 +99,7 @@ __all__ = (
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CogT = TypeVar('CogT', bound='Cog')
CogT = TypeVar('CogT', bound='Optional[Cog]')
CommandT = TypeVar('CommandT', bound='Command')
ContextT = TypeVar('ContextT', bound='Context')
# CHT = TypeVar('CHT', bound='Check')
@@ -307,7 +307,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
Callable[Concatenate[ContextT, P], Coro[T]],
],
**kwargs: Any,
):
) -> None:
if not asyncio.iscoroutinefunction(func):
raise TypeError('Callback must be a coroutine.')
@@ -372,7 +372,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.require_var_positional: bool = kwargs.get('require_var_positional', False)
self.ignore_extra: bool = kwargs.get('ignore_extra', True)
self.cooldown_after_parsing: bool = kwargs.get('cooldown_after_parsing', False)
self.cog: Optional[CogT] = None
self.cog: CogT = None
# bandaid for the fact that sometimes parent can be the bot instance
parent = kwargs.get('parent')
@@ -1321,9 +1321,8 @@ class GroupMixin(Generic[CogT]):
@overload
def command(
self,
self: GroupMixin[CogT],
name: str = ...,
cls: Type[Command[CogT, P, T]] = ...,
*args: Any,
**kwargs: Any,
) -> Callable[
@@ -1339,21 +1338,29 @@ class GroupMixin(Generic[CogT]):
@overload
def command(
self,
self: GroupMixin[CogT],
name: str = ...,
cls: Type[CommandT] = ...,
*args: Any,
**kwargs: Any,
) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], CommandT]:
) -> Callable[
[
Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
],
CommandT,
]:
...
def command(
self,
name: str = MISSING,
cls: Type[CommandT] = MISSING,
cls: Type[Command] = MISSING,
*args: Any,
**kwargs: Any,
) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], CommandT]:
) -> Any:
"""A shortcut decorator that invokes :func:`.command` and adds it to
the internal command list via :meth:`~.GroupMixin.add_command`.
@@ -1363,7 +1370,8 @@ class GroupMixin(Generic[CogT]):
A decorator that converts the provided method into a Command, adds it to the bot, then returns it.
"""
def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> CommandT:
def decorator(func):
kwargs.setdefault('parent', self)
result = command(name=name, cls=cls, *args, **kwargs)(func)
self.add_command(result)
@@ -1373,34 +1381,46 @@ class GroupMixin(Generic[CogT]):
@overload
def group(
self,
self: GroupMixin[CogT],
name: str = ...,
cls: Type[Group[CogT, P, T]] = ...,
*args: Any,
**kwargs: Any,
) -> Callable[
[Union[Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]]]],
[
Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
],
Group[CogT, P, T],
]:
...
@overload
def group(
self,
self: GroupMixin[CogT],
name: str = ...,
cls: Type[GroupT] = ...,
*args: Any,
**kwargs: Any,
) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], GroupT]:
) -> Callable[
[
Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
],
GroupT,
]:
...
def group(
self,
name: str = MISSING,
cls: Type[GroupT] = MISSING,
cls: Type[Group] = MISSING,
*args: Any,
**kwargs: Any,
) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], GroupT]:
) -> Any:
"""A shortcut decorator that invokes :func:`.group` and adds it to
the internal command list via :meth:`~.GroupMixin.add_command`.
@@ -1410,7 +1430,7 @@ class GroupMixin(Generic[CogT]):
A decorator that converts the provided method into a Group, adds it to the bot, then returns it.
"""
def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> GroupT:
def decorator(func):
kwargs.setdefault('parent', self)
result = group(name=name, cls=cls, *args, **kwargs)(func)
self.add_command(result)
@@ -1533,21 +1553,39 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
# Decorators
if TYPE_CHECKING:
# Using a class to emulate a function allows for overloading the inner function in the decorator.
class _CommandDecorator:
@overload
def __call__(self, func: Callable[Concatenate[CogT, ContextT, P], Coro[T]], /) -> Command[CogT, P, T]:
...
@overload
def __call__(self, func: Callable[Concatenate[ContextT, P], Coro[T]], /) -> Command[None, P, T]:
...
def __call__(self, func: Callable[..., Coro[T]], /) -> Any:
...
class _GroupDecorator:
@overload
def __call__(self, func: Callable[Concatenate[CogT, ContextT, P], Coro[T]], /) -> Group[CogT, P, T]:
...
@overload
def __call__(self, func: Callable[Concatenate[ContextT, P], Coro[T]], /) -> Group[None, P, T]:
...
def __call__(self, func: Callable[..., Coro[T]], /) -> Any:
...
@overload
def command(
name: str = ...,
cls: Type[Command[CogT, P, T]] = ...,
**attrs: Any,
) -> Callable[
[
Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
],
Command[CogT, P, T],
]:
) -> _CommandDecorator:
...
@@ -1559,8 +1597,8 @@ def command(
) -> Callable[
[
Union[
Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[Any]], # type: ignore - CogT is used here to allow covariance
]
],
CommandT,
@@ -1570,17 +1608,9 @@ def command(
def command(
name: str = MISSING,
cls: Type[CommandT] = MISSING,
cls: Type[Command] = MISSING,
**attrs: Any,
) -> Callable[
[
Union[
Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
]
],
Union[Command[CogT, P, T], CommandT],
]:
) -> Any:
"""A decorator that transforms a function into a :class:`.Command`
or if called with :func:`.group`, :class:`.Group`.
@@ -1611,14 +1641,9 @@ def command(
If the function is not a coroutine or is already a command.
"""
if cls is MISSING:
cls = Command # type: ignore
cls = Command
def decorator(
func: Union[
Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
]
) -> CommandT:
def decorator(func):
if isinstance(func, Command):
raise TypeError('Callback is already a command.')
return cls(func, name=name, **attrs)
@@ -1629,17 +1654,8 @@ def command(
@overload
def group(
name: str = ...,
cls: Type[Group[CogT, P, T]] = ...,
**attrs: Any,
) -> Callable[
[
Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
],
Group[CogT, P, T],
]:
) -> _GroupDecorator:
...
@@ -1651,7 +1667,7 @@ def group(
) -> Callable[
[
Union[
Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[Any]], # type: ignore - CogT is used here to allow covariance
Callable[Concatenate[ContextT, P], Coro[Any]],
]
],
@@ -1662,17 +1678,9 @@ def group(
def group(
name: str = MISSING,
cls: Type[GroupT] = MISSING,
cls: Type[Group] = MISSING,
**attrs: Any,
) -> Callable[
[
Union[
Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
]
],
Union[Group[CogT, P, T], GroupT],
]:
) -> Any:
"""A decorator that transforms a function into a :class:`.Group`.
This is similar to the :func:`.command` decorator but the ``cls``
@@ -1682,8 +1690,9 @@ def group(
The ``cls`` parameter can now be passed.
"""
if cls is MISSING:
cls = Group # type: ignore
return command(name=name, cls=cls, **attrs) # type: ignore
cls = Group
return command(name=name, cls=cls, **attrs)
def check(predicate: Check) -> Callable[[T], T]:

View File

@@ -21,7 +21,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
# map from opening quotes to closing quotes
@@ -177,7 +176,7 @@ class StringView:
next_char = self.get()
valid_eof = not next_char or next_char.isspace()
if not valid_eof:
raise InvalidEndOfQuotedStringError(next_char)
raise InvalidEndOfQuotedStringError(next_char) # type: ignore - this will always be a string
# we're quoted so it's okay
return ''.join(result)