[commands][types] Type hint commands-ext

This commit is contained in:
Josh
2021-08-20 09:51:26 +10:00
committed by GitHub
parent d4c683738d
commit f3cb197429
6 changed files with 635 additions and 311 deletions

View File

@ -22,13 +22,18 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
import collections
import collections.abc
import inspect
import importlib.util
import sys
import traceback
import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
import discord
@ -39,6 +44,15 @@ from . import errors
from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog
if TYPE_CHECKING:
import importlib.machinery
from discord.message import Message
from ._types import (
Check,
CoroFunc,
)
__all__ = (
'when_mentioned',
'when_mentioned_or',
@ -46,14 +60,21 @@ __all__ = (
'AutoShardedBot',
)
def when_mentioned(bot, msg):
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context')
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
"""
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> ']
# bot.user will never be None when this is called
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
def when_mentioned_or(*prefixes):
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@ -89,7 +110,7 @@ def when_mentioned_or(*prefixes):
return inner
def _is_submodule(parent, child):
def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".")
class _DefaultRepr:
@ -102,10 +123,10 @@ class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, **options):
super().__init__(**options)
self.command_prefix = command_prefix
self.extra_events = {}
self.__cogs = {}
self.__extensions = {}
self._checks = []
self.extra_events: Dict[str, List[CoroFunc]] = {}
self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {}
self._checks: List[Check] = []
self._check_once = []
self._before_invoke = None
self._after_invoke = None
@ -128,13 +149,14 @@ class BotBase(GroupMixin):
# internal helpers
def dispatch(self, event_name, *args, **kwargs):
super().dispatch(event_name, *args, **kwargs)
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
# super() will resolve to Client
super().dispatch(event_name, *args, **kwargs) # type: ignore
ev = 'on_' + event_name
for event in self.extra_events.get(ev, []):
self._schedule_event(event, ev, *args, **kwargs)
self._schedule_event(event, ev, *args, **kwargs) # type: ignore
async def close(self):
async def close(self) -> None:
for extension in tuple(self.__extensions):
try:
self.unload_extension(extension)
@ -147,9 +169,9 @@ class BotBase(GroupMixin):
except Exception:
pass
await super().close()
await super().close() # type: ignore
async def on_command_error(self, context, exception):
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
"""|coro|
The default command error handler provided by the bot.
@ -175,7 +197,7 @@ class BotBase(GroupMixin):
# global check registration
def check(self, func):
def check(self, func: T) -> T:
r"""A decorator that adds a global check to the bot.
A global check is similar to a :func:`.check` that is applied
@ -200,10 +222,11 @@ class BotBase(GroupMixin):
return ctx.command.qualified_name in allowed_commands
"""
self.add_check(func)
# T was used instead of Check to ensure the type matches on return
self.add_check(func) # type: ignore
return func
def add_check(self, func, *, call_once=False):
def add_check(self, func: Check, *, call_once: bool = False) -> None:
"""Adds a global check to the bot.
This is the non-decorator interface to :meth:`.check`
@ -223,7 +246,7 @@ class BotBase(GroupMixin):
else:
self._checks.append(func)
def remove_check(self, func, *, call_once=False):
def remove_check(self, func: Check, *, call_once: bool = False) -> None:
"""Removes a global check from the bot.
This function is idempotent and will not raise an exception
@ -244,7 +267,7 @@ class BotBase(GroupMixin):
except ValueError:
pass
def check_once(self, func):
def check_once(self, func: CFT) -> CFT:
r"""A decorator that adds a "call once" global check to the bot.
Unlike regular global checks, this one is called only once
@ -282,15 +305,16 @@ class BotBase(GroupMixin):
self.add_check(func, call_once=True)
return func
async def can_run(self, ctx, *, call_once=False):
async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool:
data = self._check_once if call_once else self._checks
if len(data) == 0:
return True
return await discord.utils.async_all(f(ctx) for f in data)
# type-checker doesn't distinguish between functions and methods
return await discord.utils.async_all(f(ctx) for f in data) # type: ignore
async def is_owner(self, user):
async def is_owner(self, user: discord.User) -> bool:
"""|coro|
Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of
@ -319,7 +343,8 @@ class BotBase(GroupMixin):
elif self.owner_ids:
return user.id in self.owner_ids
else:
app = await self.application_info()
app = await self.application_info() # type: ignore
if app.team:
self.owner_ids = ids = {m.id for m in app.team.members}
return user.id in ids
@ -327,7 +352,7 @@ class BotBase(GroupMixin):
self.owner_id = owner_id = app.owner.id
return user.id == owner_id
def before_invoke(self, coro):
def before_invoke(self, coro: CFT) -> CFT:
"""A decorator that registers a coroutine as a pre-invoke hook.
A pre-invoke hook is called directly before the command is
@ -359,7 +384,7 @@ class BotBase(GroupMixin):
self._before_invoke = coro
return coro
def after_invoke(self, coro):
def after_invoke(self, coro: CFT) -> CFT:
r"""A decorator that registers a coroutine as a post-invoke hook.
A post-invoke hook is called directly after the command is
@ -394,14 +419,14 @@ class BotBase(GroupMixin):
# listener registration
def add_listener(self, func, name=None):
def add_listener(self, func: CoroFunc, name: str = MISSING) -> None:
"""The non decorator alternative to :meth:`.listen`.
Parameters
-----------
func: :ref:`coroutine <coroutine>`
The function to call.
name: Optional[:class:`str`]
name: :class:`str`
The name of the event to listen for. Defaults to ``func.__name__``.
Example
@ -416,7 +441,7 @@ class BotBase(GroupMixin):
bot.add_listener(my_message, 'on_message')
"""
name = func.__name__ if name is None else name
name = func.__name__ if name is MISSING else name
if not asyncio.iscoroutinefunction(func):
raise TypeError('Listeners must be coroutines')
@ -426,7 +451,7 @@ class BotBase(GroupMixin):
else:
self.extra_events[name] = [func]
def remove_listener(self, func, name=None):
def remove_listener(self, func: CoroFunc, name: str = MISSING) -> None:
"""Removes a listener from the pool of listeners.
Parameters
@ -438,7 +463,7 @@ class BotBase(GroupMixin):
``func.__name__``.
"""
name = func.__name__ if name is None else name
name = func.__name__ if name is MISSING else name
if name in self.extra_events:
try:
@ -446,7 +471,7 @@ class BotBase(GroupMixin):
except ValueError:
pass
def listen(self, name=None):
def listen(self, name: str = MISSING) -> Callable[[CFT], CFT]:
"""A decorator that registers another function as an external
event listener. Basically this allows you to listen to multiple
events from different places e.g. such as :func:`.on_ready`
@ -476,7 +501,7 @@ class BotBase(GroupMixin):
The function being listened to is not a coroutine.
"""
def decorator(func):
def decorator(func: CFT) -> CFT:
self.add_listener(func, name)
return func
@ -528,7 +553,7 @@ class BotBase(GroupMixin):
cog = cog._inject(self)
self.__cogs[cog_name] = cog
def get_cog(self, name):
def get_cog(self, name: str) -> Optional[Cog]:
"""Gets the cog instance requested.
If the cog is not found, ``None`` is returned instead.
@ -547,7 +572,7 @@ class BotBase(GroupMixin):
"""
return self.__cogs.get(name)
def remove_cog(self, name):
def remove_cog(self, name: str) -> Optional[Cog]:
"""Removes a cog from the bot and returns it.
All registered commands and event listeners that the
@ -578,13 +603,13 @@ class BotBase(GroupMixin):
return cog
@property
def cogs(self):
def cogs(self) -> Mapping[str, Cog]:
"""Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog."""
return types.MappingProxyType(self.__cogs)
# extensions
def _remove_module_references(self, name):
def _remove_module_references(self, name: str) -> None:
# find all references to the module
# remove the cogs registered from the module
for cogname, cog in self.__cogs.copy().items():
@ -608,7 +633,7 @@ class BotBase(GroupMixin):
for index in reversed(remove):
del event_list[index]
def _call_module_finalizers(self, lib, key):
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
try:
func = getattr(lib, 'teardown')
except AttributeError:
@ -626,12 +651,12 @@ class BotBase(GroupMixin):
if _is_submodule(name, module):
del sys.modules[module]
def _load_from_module_spec(self, spec, key):
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
# precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec)
sys.modules[key] = lib
try:
spec.loader.exec_module(lib)
spec.loader.exec_module(lib) # type: ignore
except Exception as e:
del sys.modules[key]
raise errors.ExtensionFailed(key, e) from e
@ -652,13 +677,13 @@ class BotBase(GroupMixin):
else:
self.__extensions[key] = lib
def _resolve_name(self, name, package):
def _resolve_name(self, name: str, package: Optional[str]) -> str:
try:
return importlib.util.resolve_name(name, package)
except ImportError:
raise errors.ExtensionNotFound(name)
def load_extension(self, name, *, package=None):
def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Loads an extension.
An extension is a python module that contains commands, cogs, or
@ -705,7 +730,7 @@ class BotBase(GroupMixin):
self._load_from_module_spec(spec, name)
def unload_extension(self, name, *, package=None):
def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Unloads an extension.
When the extension is unloaded, all commands, listeners, and cogs are
@ -746,7 +771,7 @@ class BotBase(GroupMixin):
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name)
def reload_extension(self, name, *, package=None):
def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Atomically reloads an extension.
This replaces the extension with the same extension, only refreshed. This is
@ -802,7 +827,7 @@ class BotBase(GroupMixin):
# if the load failed, the remnants should have been
# cleaned from the load_extension function call
# so let's load it from our old compiled library.
lib.setup(self)
lib.setup(self) # type: ignore
self.__extensions[name] = lib
# revert sys.modules back to normal and raise back to caller
@ -810,18 +835,18 @@ class BotBase(GroupMixin):
raise
@property
def extensions(self):
def extensions(self) -> Mapping[str, types.ModuleType]:
"""Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension."""
return types.MappingProxyType(self.__extensions)
# help command stuff
@property
def help_command(self):
def help_command(self) -> Optional[HelpCommand]:
return self._help_command
@help_command.setter
def help_command(self, value):
def help_command(self, value: Optional[HelpCommand]) -> None:
if value is not None:
if not isinstance(value, HelpCommand):
raise TypeError('help_command must be a subclass of HelpCommand')
@ -837,7 +862,7 @@ class BotBase(GroupMixin):
# command processing
async def get_prefix(self, message):
async def get_prefix(self, message: Message) -> Union[List[str], str]:
"""|coro|
Retrieves the prefix the bot is listening to
@ -875,7 +900,7 @@ class BotBase(GroupMixin):
return ret
async def get_context(self, message, *, cls=Context):
async def get_context(self, message: Message, *, cls: Type[CXT] = Context) -> CXT:
r"""|coro|
Returns the invocation context from the message.
@ -908,7 +933,7 @@ class BotBase(GroupMixin):
view = StringView(message.content)
ctx = cls(prefix=None, view=view, bot=self, message=message)
if message.author.id == self.user.id:
if message.author.id == self.user.id: # type: ignore
return ctx
prefix = await self.get_prefix(message)
@ -945,11 +970,12 @@ class BotBase(GroupMixin):
invoker = view.get_word()
ctx.invoked_with = invoker
ctx.prefix = invoked_prefix
# type-checker fails to narrow invoked_prefix type.
ctx.prefix = invoked_prefix # type: ignore
ctx.command = self.all_commands.get(invoker)
return ctx
async def invoke(self, ctx):
async def invoke(self, ctx: Context) -> None:
"""|coro|
Invokes the command given under the invocation context and
@ -975,7 +1001,7 @@ class BotBase(GroupMixin):
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
self.dispatch('command_error', ctx, exc)
async def process_commands(self, message):
async def process_commands(self, message: Message) -> None:
"""|coro|
This function processes the commands that have been registered