Fix type annotations to adhere to latest pyright release

This commit is contained in:
Josh 2022-06-13 05:30:45 +10:00 committed by GitHub
parent 334ef1d7fa
commit c9f777c873
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 50 additions and 47 deletions

View File

@ -38,7 +38,7 @@ jobs:
- name: Run Pyright - name: Run Pyright
uses: jakebailey/pyright-action@v1 uses: jakebailey/pyright-action@v1
with: with:
version: '1.1.242' version: '1.1.253'
warnings: false warnings: false
no-comments: ${{ matrix.python-version != '3.x' }} no-comments: ${{ matrix.python-version != '3.x' }}

View File

@ -771,8 +771,7 @@ class Command(Generic[GroupT, P, T]):
if not predicates: if not predicates:
return True return True
# Type checker does not understand negative narrowing cases like this function return await async_all(f(interaction) for f in predicates)
return await async_all(f(interaction) for f in predicates) # type: ignore
def error(self, coro: Error[GroupT]) -> Error[GroupT]: def error(self, coro: Error[GroupT]) -> Error[GroupT]:
"""A decorator that registers a coroutine as a local error handler. """A decorator that registers a coroutine as a local error handler.
@ -997,8 +996,7 @@ class ContextMenu:
if not predicates: if not predicates:
return True return True
# Type checker does not understand negative narrowing cases like this function return await async_all(f(interaction) for f in predicates)
return await async_all(f(interaction) for f in predicates) # type: ignore
def _has_any_error_handlers(self) -> bool: def _has_any_error_handlers(self) -> bool:
return self.on_error is not None return self.on_error is not None

View File

@ -96,14 +96,14 @@ class AllChannels:
__slots__ = ('guild',) __slots__ = ('guild',)
def __init__(self, guild: Guild): def __init__(self, guild: Guild):
self.guild = guild self.guild: Guild = guild
@property @property
def id(self) -> int: def id(self) -> int:
""":class:`int`: The ID sentinel used to represent all channels. Equivalent to the guild's ID minus 1.""" """:class:`int`: The ID sentinel used to represent all channels. Equivalent to the guild's ID minus 1."""
return self.guild.id - 1 return self.guild.id - 1
def __repr__(self): def __repr__(self) -> str:
return f'<AllChannels guild={self.guild}>' return f'<AllChannels guild={self.guild}>'

View File

@ -27,13 +27,9 @@ from typing import Dict, List, Optional, TYPE_CHECKING, Any, Tuple, Union
if TYPE_CHECKING: if TYPE_CHECKING:
from aiohttp import ClientResponse, ClientWebSocketResponse from aiohttp import ClientResponse, ClientWebSocketResponse
try:
from requests import Response from requests import Response
_ResponseType = Union[ClientResponse, Response] _ResponseType = Union[ClientResponse, Response]
except ModuleNotFoundError:
_ResponseType = ClientResponse
from .interactions import Interaction from .interactions import Interaction

View File

@ -59,7 +59,7 @@ BotT = TypeVar('BotT', bound=_Bot, covariant=True)
ContextT_co = TypeVar('ContextT_co', bound='Context[Any]', covariant=True) ContextT_co = TypeVar('ContextT_co', bound='Context[Any]', covariant=True)
class Check(Protocol[ContextT_co]): class Check(Protocol[ContextT_co]): # type: ignore # TypeVar is expected to be invariant
predicate: Callable[[ContextT_co], Coroutine[Any, Any, bool]] predicate: Callable[[ContextT_co], Coroutine[Any, Any, bool]]

View File

@ -460,8 +460,7 @@ class BotBase(GroupMixin[None]):
if len(data) == 0: if len(data) == 0:
return True return True
# type-checker doesn't distinguish between functions and methods return await discord.utils.async_all(f(ctx) for f in data)
return await discord.utils.async_all(f(ctx) for f in data) # type: ignore
async def is_owner(self, user: User, /) -> bool: async def is_owner(self, user: User, /) -> bool:
"""|coro| """|coro|

View File

@ -35,7 +35,7 @@ from .view import StringView
from ._types import BotT from ._types import BotT
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self, ParamSpec from typing_extensions import Self, ParamSpec, TypeGuard
from discord.abc import MessageableChannel from discord.abc import MessageableChannel
from discord.guild import Guild from discord.guild import Guild
@ -77,6 +77,10 @@ else:
P = TypeVar('P') P = TypeVar('P')
def is_cog(obj: Any) -> TypeGuard[Cog]:
return hasattr(obj, '__cog_commands__')
class DeferTyping: class DeferTyping:
def __init__(self, ctx: Context[BotT], *, ephemeral: bool): def __init__(self, ctx: Context[BotT], *, ephemeral: bool):
self.ctx: Context[BotT] = ctx self.ctx: Context[BotT] = ctx
@ -526,7 +530,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
await cmd.prepare_help_command(self, entity.qualified_name) await cmd.prepare_help_command(self, entity.qualified_name)
try: try:
if hasattr(entity, '__cog_commands__'): if is_cog(entity):
injected = wrap_callback(cmd.send_cog_help) injected = wrap_callback(cmd.send_cog_help)
return await injected(entity) return await injected(entity)
elif isinstance(entity, Group): elif isinstance(entity, Group):

View File

@ -234,6 +234,7 @@ class MemberConverter(IDConverter[discord.Member]):
guild = ctx.guild guild = ctx.guild
result = None result = None
user_id = None user_id = None
if match is None: if match is None:
# not a mention... # not a mention...
if guild: if guild:
@ -247,7 +248,7 @@ class MemberConverter(IDConverter[discord.Member]):
else: else:
result = _get_from_guilds(bot, 'get_member', user_id) result = _get_from_guilds(bot, 'get_member', user_id)
if result is None: if not isinstance(result, discord.Member):
if guild is None: if guild is None:
raise MemberNotFound(argument) raise MemberNotFound(argument)
@ -1182,7 +1183,7 @@ async def _actual_conversion(ctx: Context[BotT], converter: Any, argument: str,
except CommandError: except CommandError:
raise raise
except Exception as exc: except Exception as exc:
raise ConversionError(converter, exc) from exc raise ConversionError(converter, exc) from exc # type: ignore
try: try:
return converter(argument) return converter(argument)

View File

@ -354,8 +354,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
def __init__( def __init__(
self, self,
func: Union[ func: Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[CogT, Context[Any], P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]], Callable[Concatenate[Context[Any], P], Coro[T]],
], ],
/, /,
**kwargs: Any, **kwargs: Any,
@ -399,7 +399,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
except AttributeError: except AttributeError:
checks = kwargs.get('checks', []) checks = kwargs.get('checks', [])
self.checks: List[UserCheck[ContextT]] = checks self.checks: List[UserCheck[Context[Any]]] = checks
try: try:
cooldown = func.__commands_cooldown__ cooldown = func.__commands_cooldown__
@ -479,7 +479,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.params: Dict[str, Parameter] = get_signature_parameters(function, globalns) self.params: Dict[str, Parameter] = get_signature_parameters(function, globalns)
def add_check(self, func: UserCheck[ContextT], /) -> None: def add_check(self, func: UserCheck[Context[Any]], /) -> None:
"""Adds a check to the command. """Adds a check to the command.
This is the non-decorator interface to :func:`.check`. This is the non-decorator interface to :func:`.check`.
@ -500,7 +500,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.checks.append(func) self.checks.append(func)
def remove_check(self, func: UserCheck[ContextT], /) -> None: def remove_check(self, func: UserCheck[Context[Any]], /) -> None:
"""Removes a check from the command. """Removes a check from the command.
This function is idempotent and will not raise an exception This function is idempotent and will not raise an exception
@ -1249,7 +1249,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
# since we have no checks, then we just return True. # since we have no checks, then we just return True.
return True return True
return await discord.utils.async_all(predicate(ctx) for predicate in predicates) # type: ignore return await discord.utils.async_all(predicate(ctx) for predicate in predicates)
finally: finally:
ctx.command = original ctx.command = original
@ -1448,7 +1448,7 @@ class GroupMixin(Generic[CogT]):
def command( def command(
self: GroupMixin[CogT], self: GroupMixin[CogT],
name: str = ..., name: str = ...,
cls: Type[CommandT] = ..., cls: Type[CommandT] = ..., # type: ignore # previous overload handles case where cls is not set
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Callable[ ) -> Callable[
@ -1508,7 +1508,7 @@ class GroupMixin(Generic[CogT]):
def group( def group(
self: GroupMixin[CogT], self: GroupMixin[CogT],
name: str = ..., name: str = ...,
cls: Type[GroupT] = ..., cls: Type[GroupT] = ..., # type: ignore # previous overload handles case where cls is not set
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Callable[ ) -> Callable[
@ -1700,7 +1700,7 @@ def command(
@overload @overload
def command( def command(
name: str = ..., name: str = ...,
cls: Type[CommandT] = ..., cls: Type[CommandT] = ..., # type: ignore # previous overload handles case where cls is not set
**attrs: Any, **attrs: Any,
) -> Callable[ ) -> Callable[
[ [
@ -1770,7 +1770,7 @@ def group(
@overload @overload
def group( def group(
name: str = ..., name: str = ...,
cls: Type[GroupT] = ..., cls: Type[GroupT] = ..., # type: ignore # previous overload handles case where cls is not set
**attrs: Any, **attrs: Any,
) -> Callable[ ) -> Callable[
[ [
@ -1878,9 +1878,9 @@ def check(predicate: UserCheck[ContextT], /) -> Check[ContextT]:
The predicate to check if the command should be invoked. The predicate to check if the command should be invoked.
""" """
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: def decorator(func: Union[Command[Any, ..., Any], CoroFunc]) -> Union[Command[Any, ..., Any], CoroFunc]:
if isinstance(func, Command): if isinstance(func, Command):
func.checks.append(predicate) func.checks.append(predicate) # type: ignore
else: else:
if not hasattr(func, '__commands_checks__'): if not hasattr(func, '__commands_checks__'):
func.__commands_checks__ = [] func.__commands_checks__ = []

View File

@ -61,7 +61,6 @@ if TYPE_CHECKING:
from ._types import ( from ._types import (
UserCheck, UserCheck,
ContextT,
BotT, BotT,
_Bot, _Bot,
) )
@ -378,7 +377,7 @@ class HelpCommand:
bot.remove_command(self._command_impl.name) bot.remove_command(self._command_impl.name)
self._command_impl._eject_cog() self._command_impl._eject_cog()
def add_check(self, func: UserCheck[ContextT], /) -> None: def add_check(self, func: UserCheck[Context[Any]], /) -> None:
""" """
Adds a check to the help command. Adds a check to the help command.
@ -398,7 +397,7 @@ class HelpCommand:
self._command_impl.add_check(func) self._command_impl.add_check(func)
def remove_check(self, func: UserCheck[ContextT], /) -> None: def remove_check(self, func: UserCheck[Context[Any]], /) -> None:
""" """
Removes a check from the help command. Removes a check from the help command.

View File

@ -140,7 +140,7 @@ def make_converter_transformer(converter: Any, parameter: Parameter) -> Type[app
except CommandError: except CommandError:
raise raise
except Exception as exc: except Exception as exc:
raise ConversionError(converter, exc) from exc raise ConversionError(converter, exc) from exc # type: ignore
return type('ConverterTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)}) return type('ConverterTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
@ -400,10 +400,10 @@ class HybridAppCommand(discord.app_commands.Command[CogT, P, T]):
if not ret: if not ret:
return False return False
if self.checks and not await async_all(f(interaction) for f in self.checks): # type: ignore if self.checks and not await async_all(f(interaction) for f in self.checks):
return False return False
if self.wrapped.checks and not await async_all(f(ctx) for f in self.wrapped.checks): # type: ignore if self.wrapped.checks and not await async_all(f(ctx) for f in self.wrapped.checks):
return False return False
return True return True
@ -468,7 +468,7 @@ class HybridCommand(Command[CogT, P, T]):
def __init__( def __init__(
self, self,
func: CommandCallback[CogT, ContextT, P, T], func: CommandCallback[CogT, Context[Any], P, T],
/, /,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
@ -838,10 +838,10 @@ def hybrid_command(
If the function is not a coroutine or is already a command. If the function is not a coroutine or is already a command.
""" """
def decorator(func: CommandCallback[CogT, ContextT, P, T]): def decorator(func: CommandCallback[CogT, ContextT, P, T]) -> HybridCommand[CogT, P, T]:
if isinstance(func, Command): if isinstance(func, Command):
raise TypeError('Callback is already a command.') raise TypeError('Callback is already a command.')
return HybridCommand(func, name=name, with_app_command=with_app_command, **attrs) return HybridCommand(func, name=name, with_app_command=with_app_command, **attrs) # type: ignore # ???
return decorator return decorator

View File

@ -35,7 +35,7 @@ from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data, MISSING
if TYPE_CHECKING: if TYPE_CHECKING:
from .types.scheduled_event import ( from .types.scheduled_event import (
GuildScheduledEvent as GuildScheduledEventPayload, GuildScheduledEvent as BaseGuildScheduledEventPayload,
GuildScheduledEventWithUserCount as GuildScheduledEventWithUserCountPayload, GuildScheduledEventWithUserCount as GuildScheduledEventWithUserCountPayload,
EntityMetadata, EntityMetadata,
) )
@ -46,7 +46,7 @@ if TYPE_CHECKING:
from .state import ConnectionState from .state import ConnectionState
from .user import User from .user import User
GuildScheduledEventPayload = Union[GuildScheduledEventPayload, GuildScheduledEventWithUserCountPayload] GuildScheduledEventPayload = Union[BaseGuildScheduledEventPayload, GuildScheduledEventWithUserCountPayload]
# fmt: off # fmt: off
__all__ = ( __all__ = (

View File

@ -131,7 +131,7 @@ class _cached_property:
if TYPE_CHECKING: if TYPE_CHECKING:
from functools import cached_property as cached_property from functools import cached_property as cached_property
from typing_extensions import ParamSpec, Self from typing_extensions import ParamSpec, Self, TypeGuard
from .permissions import Permissions from .permissions import Permissions
from .abc import Snowflake from .abc import Snowflake
@ -624,7 +624,11 @@ async def maybe_coroutine(f: MaybeAwaitableFunc[P, T], *args: P.args, **kwargs:
return value # type: ignore return value # type: ignore
async def async_all(gen: Iterable[Awaitable[T]], *, check: Callable[[T], bool] = _isawaitable) -> bool: async def async_all(
gen: Iterable[Union[T, Awaitable[T]]],
*,
check: Callable[[Union[T, Awaitable[T]]], TypeGuard[Awaitable[T]]] = _isawaitable,
) -> bool:
for elem in gen: for elem in gen:
if check(elem): if check(elem):
elem = await elem elem = await elem

View File

@ -126,9 +126,11 @@ class WelcomeScreen:
self._store(data) self._store(data)
def _store(self, data: WelcomeScreenPayload) -> None: def _store(self, data: WelcomeScreenPayload) -> None:
self.description = data['description'] self.description: str = data['description']
welcome_channels = data.get('welcome_channels', []) welcome_channels = data.get('welcome_channels', [])
self.welcome_channels = [WelcomeChannel._from_dict(data=wc, guild=self._guild) for wc in welcome_channels] self.welcome_channels: List[WelcomeChannel] = [
WelcomeChannel._from_dict(data=wc, guild=self._guild) for wc in welcome_channels
]
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<WelcomeScreen description={self.description!r} welcome_channels={self.welcome_channels!r} enabled={self.enabled}>' return f'<WelcomeScreen description={self.description!r} welcome_channels={self.welcome_channels!r} enabled={self.enabled}>'