mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-19 15:36:02 +00:00
[commands][types] Type hint commands-ext
This commit is contained in:
parent
d4c683738d
commit
f3cb197429
@ -22,6 +22,26 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
|
||||
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
from .cog import Cog
|
||||
from .errors import CommandError
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
Coro = Coroutine[Any, Any, T]
|
||||
MaybeCoro = Union[T, Coro[T]]
|
||||
CoroFunc = Callable[..., Coro[Any]]
|
||||
|
||||
Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
|
||||
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
|
||||
Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]]
|
||||
|
||||
|
||||
# This is merely a tag type to avoid circular import issues.
|
||||
# Yes, this is a terrible solution but ultimately it is the only solution.
|
||||
class _BaseCommand:
|
||||
|
@ -22,13 +22,18 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import collections.abc
|
||||
import inspect
|
||||
import importlib.util
|
||||
import sys
|
||||
import traceback
|
||||
import types
|
||||
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
|
||||
|
||||
import discord
|
||||
|
||||
@ -39,6 +44,15 @@ from . import errors
|
||||
from .help import HelpCommand, DefaultHelpCommand
|
||||
from .cog import Cog
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import importlib.machinery
|
||||
|
||||
from discord.message import Message
|
||||
from ._types import (
|
||||
Check,
|
||||
CoroFunc,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
'when_mentioned',
|
||||
'when_mentioned_or',
|
||||
@ -46,14 +60,21 @@ __all__ = (
|
||||
'AutoShardedBot',
|
||||
)
|
||||
|
||||
def when_mentioned(bot, msg):
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
T = TypeVar('T')
|
||||
CFT = TypeVar('CFT', bound='CoroFunc')
|
||||
CXT = TypeVar('CXT', bound='Context')
|
||||
|
||||
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
|
||||
"""A callable that implements a command prefix equivalent to being mentioned.
|
||||
|
||||
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
|
||||
"""
|
||||
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> ']
|
||||
# bot.user will never be None when this is called
|
||||
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
|
||||
|
||||
def when_mentioned_or(*prefixes):
|
||||
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
|
||||
"""A callable that implements when mentioned or other prefixes provided.
|
||||
|
||||
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
|
||||
@ -89,7 +110,7 @@ def when_mentioned_or(*prefixes):
|
||||
|
||||
return inner
|
||||
|
||||
def _is_submodule(parent, child):
|
||||
def _is_submodule(parent: str, child: str) -> bool:
|
||||
return parent == child or child.startswith(parent + ".")
|
||||
|
||||
class _DefaultRepr:
|
||||
@ -102,10 +123,10 @@ class BotBase(GroupMixin):
|
||||
def __init__(self, command_prefix, help_command=_default, description=None, **options):
|
||||
super().__init__(**options)
|
||||
self.command_prefix = command_prefix
|
||||
self.extra_events = {}
|
||||
self.__cogs = {}
|
||||
self.__extensions = {}
|
||||
self._checks = []
|
||||
self.extra_events: Dict[str, List[CoroFunc]] = {}
|
||||
self.__cogs: Dict[str, Cog] = {}
|
||||
self.__extensions: Dict[str, types.ModuleType] = {}
|
||||
self._checks: List[Check] = []
|
||||
self._check_once = []
|
||||
self._before_invoke = None
|
||||
self._after_invoke = None
|
||||
@ -128,13 +149,14 @@ class BotBase(GroupMixin):
|
||||
|
||||
# internal helpers
|
||||
|
||||
def dispatch(self, event_name, *args, **kwargs):
|
||||
super().dispatch(event_name, *args, **kwargs)
|
||||
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
|
||||
# super() will resolve to Client
|
||||
super().dispatch(event_name, *args, **kwargs) # type: ignore
|
||||
ev = 'on_' + event_name
|
||||
for event in self.extra_events.get(ev, []):
|
||||
self._schedule_event(event, ev, *args, **kwargs)
|
||||
self._schedule_event(event, ev, *args, **kwargs) # type: ignore
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
for extension in tuple(self.__extensions):
|
||||
try:
|
||||
self.unload_extension(extension)
|
||||
@ -147,9 +169,9 @@ class BotBase(GroupMixin):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await super().close()
|
||||
await super().close() # type: ignore
|
||||
|
||||
async def on_command_error(self, context, exception):
|
||||
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
|
||||
"""|coro|
|
||||
|
||||
The default command error handler provided by the bot.
|
||||
@ -175,7 +197,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
# global check registration
|
||||
|
||||
def check(self, func):
|
||||
def check(self, func: T) -> T:
|
||||
r"""A decorator that adds a global check to the bot.
|
||||
|
||||
A global check is similar to a :func:`.check` that is applied
|
||||
@ -200,10 +222,11 @@ class BotBase(GroupMixin):
|
||||
return ctx.command.qualified_name in allowed_commands
|
||||
|
||||
"""
|
||||
self.add_check(func)
|
||||
# T was used instead of Check to ensure the type matches on return
|
||||
self.add_check(func) # type: ignore
|
||||
return func
|
||||
|
||||
def add_check(self, func, *, call_once=False):
|
||||
def add_check(self, func: Check, *, call_once: bool = False) -> None:
|
||||
"""Adds a global check to the bot.
|
||||
|
||||
This is the non-decorator interface to :meth:`.check`
|
||||
@ -223,7 +246,7 @@ class BotBase(GroupMixin):
|
||||
else:
|
||||
self._checks.append(func)
|
||||
|
||||
def remove_check(self, func, *, call_once=False):
|
||||
def remove_check(self, func: Check, *, call_once: bool = False) -> None:
|
||||
"""Removes a global check from the bot.
|
||||
|
||||
This function is idempotent and will not raise an exception
|
||||
@ -244,7 +267,7 @@ class BotBase(GroupMixin):
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def check_once(self, func):
|
||||
def check_once(self, func: CFT) -> CFT:
|
||||
r"""A decorator that adds a "call once" global check to the bot.
|
||||
|
||||
Unlike regular global checks, this one is called only once
|
||||
@ -282,15 +305,16 @@ class BotBase(GroupMixin):
|
||||
self.add_check(func, call_once=True)
|
||||
return func
|
||||
|
||||
async def can_run(self, ctx, *, call_once=False):
|
||||
async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool:
|
||||
data = self._check_once if call_once else self._checks
|
||||
|
||||
if len(data) == 0:
|
||||
return True
|
||||
|
||||
return await discord.utils.async_all(f(ctx) for f in data)
|
||||
# type-checker doesn't distinguish between functions and methods
|
||||
return await discord.utils.async_all(f(ctx) for f in data) # type: ignore
|
||||
|
||||
async def is_owner(self, user):
|
||||
async def is_owner(self, user: discord.User) -> bool:
|
||||
"""|coro|
|
||||
|
||||
Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of
|
||||
@ -319,7 +343,8 @@ class BotBase(GroupMixin):
|
||||
elif self.owner_ids:
|
||||
return user.id in self.owner_ids
|
||||
else:
|
||||
app = await self.application_info()
|
||||
|
||||
app = await self.application_info() # type: ignore
|
||||
if app.team:
|
||||
self.owner_ids = ids = {m.id for m in app.team.members}
|
||||
return user.id in ids
|
||||
@ -327,7 +352,7 @@ class BotBase(GroupMixin):
|
||||
self.owner_id = owner_id = app.owner.id
|
||||
return user.id == owner_id
|
||||
|
||||
def before_invoke(self, coro):
|
||||
def before_invoke(self, coro: CFT) -> CFT:
|
||||
"""A decorator that registers a coroutine as a pre-invoke hook.
|
||||
|
||||
A pre-invoke hook is called directly before the command is
|
||||
@ -359,7 +384,7 @@ class BotBase(GroupMixin):
|
||||
self._before_invoke = coro
|
||||
return coro
|
||||
|
||||
def after_invoke(self, coro):
|
||||
def after_invoke(self, coro: CFT) -> CFT:
|
||||
r"""A decorator that registers a coroutine as a post-invoke hook.
|
||||
|
||||
A post-invoke hook is called directly after the command is
|
||||
@ -394,14 +419,14 @@ class BotBase(GroupMixin):
|
||||
|
||||
# listener registration
|
||||
|
||||
def add_listener(self, func, name=None):
|
||||
def add_listener(self, func: CoroFunc, name: str = MISSING) -> None:
|
||||
"""The non decorator alternative to :meth:`.listen`.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
func: :ref:`coroutine <coroutine>`
|
||||
The function to call.
|
||||
name: Optional[:class:`str`]
|
||||
name: :class:`str`
|
||||
The name of the event to listen for. Defaults to ``func.__name__``.
|
||||
|
||||
Example
|
||||
@ -416,7 +441,7 @@ class BotBase(GroupMixin):
|
||||
bot.add_listener(my_message, 'on_message')
|
||||
|
||||
"""
|
||||
name = func.__name__ if name is None else name
|
||||
name = func.__name__ if name is MISSING else name
|
||||
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError('Listeners must be coroutines')
|
||||
@ -426,7 +451,7 @@ class BotBase(GroupMixin):
|
||||
else:
|
||||
self.extra_events[name] = [func]
|
||||
|
||||
def remove_listener(self, func, name=None):
|
||||
def remove_listener(self, func: CoroFunc, name: str = MISSING) -> None:
|
||||
"""Removes a listener from the pool of listeners.
|
||||
|
||||
Parameters
|
||||
@ -438,7 +463,7 @@ class BotBase(GroupMixin):
|
||||
``func.__name__``.
|
||||
"""
|
||||
|
||||
name = func.__name__ if name is None else name
|
||||
name = func.__name__ if name is MISSING else name
|
||||
|
||||
if name in self.extra_events:
|
||||
try:
|
||||
@ -446,7 +471,7 @@ class BotBase(GroupMixin):
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def listen(self, name=None):
|
||||
def listen(self, name: str = MISSING) -> Callable[[CFT], CFT]:
|
||||
"""A decorator that registers another function as an external
|
||||
event listener. Basically this allows you to listen to multiple
|
||||
events from different places e.g. such as :func:`.on_ready`
|
||||
@ -476,7 +501,7 @@ class BotBase(GroupMixin):
|
||||
The function being listened to is not a coroutine.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
def decorator(func: CFT) -> CFT:
|
||||
self.add_listener(func, name)
|
||||
return func
|
||||
|
||||
@ -528,7 +553,7 @@ class BotBase(GroupMixin):
|
||||
cog = cog._inject(self)
|
||||
self.__cogs[cog_name] = cog
|
||||
|
||||
def get_cog(self, name):
|
||||
def get_cog(self, name: str) -> Optional[Cog]:
|
||||
"""Gets the cog instance requested.
|
||||
|
||||
If the cog is not found, ``None`` is returned instead.
|
||||
@ -547,7 +572,7 @@ class BotBase(GroupMixin):
|
||||
"""
|
||||
return self.__cogs.get(name)
|
||||
|
||||
def remove_cog(self, name):
|
||||
def remove_cog(self, name: str) -> Optional[Cog]:
|
||||
"""Removes a cog from the bot and returns it.
|
||||
|
||||
All registered commands and event listeners that the
|
||||
@ -578,13 +603,13 @@ class BotBase(GroupMixin):
|
||||
return cog
|
||||
|
||||
@property
|
||||
def cogs(self):
|
||||
def cogs(self) -> Mapping[str, Cog]:
|
||||
"""Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog."""
|
||||
return types.MappingProxyType(self.__cogs)
|
||||
|
||||
# extensions
|
||||
|
||||
def _remove_module_references(self, name):
|
||||
def _remove_module_references(self, name: str) -> None:
|
||||
# find all references to the module
|
||||
# remove the cogs registered from the module
|
||||
for cogname, cog in self.__cogs.copy().items():
|
||||
@ -608,7 +633,7 @@ class BotBase(GroupMixin):
|
||||
for index in reversed(remove):
|
||||
del event_list[index]
|
||||
|
||||
def _call_module_finalizers(self, lib, key):
|
||||
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
|
||||
try:
|
||||
func = getattr(lib, 'teardown')
|
||||
except AttributeError:
|
||||
@ -626,12 +651,12 @@ class BotBase(GroupMixin):
|
||||
if _is_submodule(name, module):
|
||||
del sys.modules[module]
|
||||
|
||||
def _load_from_module_spec(self, spec, key):
|
||||
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
|
||||
# precondition: key not in self.__extensions
|
||||
lib = importlib.util.module_from_spec(spec)
|
||||
sys.modules[key] = lib
|
||||
try:
|
||||
spec.loader.exec_module(lib)
|
||||
spec.loader.exec_module(lib) # type: ignore
|
||||
except Exception as e:
|
||||
del sys.modules[key]
|
||||
raise errors.ExtensionFailed(key, e) from e
|
||||
@ -652,13 +677,13 @@ class BotBase(GroupMixin):
|
||||
else:
|
||||
self.__extensions[key] = lib
|
||||
|
||||
def _resolve_name(self, name, package):
|
||||
def _resolve_name(self, name: str, package: Optional[str]) -> str:
|
||||
try:
|
||||
return importlib.util.resolve_name(name, package)
|
||||
except ImportError:
|
||||
raise errors.ExtensionNotFound(name)
|
||||
|
||||
def load_extension(self, name, *, package=None):
|
||||
def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
|
||||
"""Loads an extension.
|
||||
|
||||
An extension is a python module that contains commands, cogs, or
|
||||
@ -705,7 +730,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
self._load_from_module_spec(spec, name)
|
||||
|
||||
def unload_extension(self, name, *, package=None):
|
||||
def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
|
||||
"""Unloads an extension.
|
||||
|
||||
When the extension is unloaded, all commands, listeners, and cogs are
|
||||
@ -746,7 +771,7 @@ class BotBase(GroupMixin):
|
||||
self._remove_module_references(lib.__name__)
|
||||
self._call_module_finalizers(lib, name)
|
||||
|
||||
def reload_extension(self, name, *, package=None):
|
||||
def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
|
||||
"""Atomically reloads an extension.
|
||||
|
||||
This replaces the extension with the same extension, only refreshed. This is
|
||||
@ -802,7 +827,7 @@ class BotBase(GroupMixin):
|
||||
# if the load failed, the remnants should have been
|
||||
# cleaned from the load_extension function call
|
||||
# so let's load it from our old compiled library.
|
||||
lib.setup(self)
|
||||
lib.setup(self) # type: ignore
|
||||
self.__extensions[name] = lib
|
||||
|
||||
# revert sys.modules back to normal and raise back to caller
|
||||
@ -810,18 +835,18 @@ class BotBase(GroupMixin):
|
||||
raise
|
||||
|
||||
@property
|
||||
def extensions(self):
|
||||
def extensions(self) -> Mapping[str, types.ModuleType]:
|
||||
"""Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension."""
|
||||
return types.MappingProxyType(self.__extensions)
|
||||
|
||||
# help command stuff
|
||||
|
||||
@property
|
||||
def help_command(self):
|
||||
def help_command(self) -> Optional[HelpCommand]:
|
||||
return self._help_command
|
||||
|
||||
@help_command.setter
|
||||
def help_command(self, value):
|
||||
def help_command(self, value: Optional[HelpCommand]) -> None:
|
||||
if value is not None:
|
||||
if not isinstance(value, HelpCommand):
|
||||
raise TypeError('help_command must be a subclass of HelpCommand')
|
||||
@ -837,7 +862,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
# command processing
|
||||
|
||||
async def get_prefix(self, message):
|
||||
async def get_prefix(self, message: Message) -> Union[List[str], str]:
|
||||
"""|coro|
|
||||
|
||||
Retrieves the prefix the bot is listening to
|
||||
@ -875,7 +900,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
return ret
|
||||
|
||||
async def get_context(self, message, *, cls=Context):
|
||||
async def get_context(self, message: Message, *, cls: Type[CXT] = Context) -> CXT:
|
||||
r"""|coro|
|
||||
|
||||
Returns the invocation context from the message.
|
||||
@ -908,7 +933,7 @@ class BotBase(GroupMixin):
|
||||
view = StringView(message.content)
|
||||
ctx = cls(prefix=None, view=view, bot=self, message=message)
|
||||
|
||||
if message.author.id == self.user.id:
|
||||
if message.author.id == self.user.id: # type: ignore
|
||||
return ctx
|
||||
|
||||
prefix = await self.get_prefix(message)
|
||||
@ -945,11 +970,12 @@ class BotBase(GroupMixin):
|
||||
|
||||
invoker = view.get_word()
|
||||
ctx.invoked_with = invoker
|
||||
ctx.prefix = invoked_prefix
|
||||
# type-checker fails to narrow invoked_prefix type.
|
||||
ctx.prefix = invoked_prefix # type: ignore
|
||||
ctx.command = self.all_commands.get(invoker)
|
||||
return ctx
|
||||
|
||||
async def invoke(self, ctx):
|
||||
async def invoke(self, ctx: Context) -> None:
|
||||
"""|coro|
|
||||
|
||||
Invokes the command given under the invocation context and
|
||||
@ -975,7 +1001,7 @@ class BotBase(GroupMixin):
|
||||
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
|
||||
self.dispatch('command_error', ctx, exc)
|
||||
|
||||
async def process_commands(self, message):
|
||||
async def process_commands(self, message: Message) -> None:
|
||||
"""|coro|
|
||||
|
||||
This function processes the commands that have been registered
|
||||
|
@ -21,15 +21,30 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import discord.utils
|
||||
|
||||
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type
|
||||
|
||||
from ._types import _BaseCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .bot import BotBase
|
||||
from .context import Context
|
||||
from .core import Command
|
||||
|
||||
__all__ = (
|
||||
'CogMeta',
|
||||
'Cog',
|
||||
)
|
||||
|
||||
CogT = TypeVar('CogT', bound='Cog')
|
||||
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
|
||||
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
class CogMeta(type):
|
||||
"""A metaclass for defining a cog.
|
||||
|
||||
@ -89,8 +104,12 @@ class CogMeta(type):
|
||||
async def bar(self, ctx):
|
||||
pass # hidden -> False
|
||||
"""
|
||||
__cog_name__: str
|
||||
__cog_settings__: Dict[str, Any]
|
||||
__cog_commands__: List[Command]
|
||||
__cog_listeners__: List[Tuple[str, str]]
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
|
||||
name, bases, attrs = args
|
||||
attrs['__cog_name__'] = kwargs.pop('name', name)
|
||||
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
|
||||
@ -143,14 +162,14 @@ class CogMeta(type):
|
||||
new_cls.__cog_listeners__ = listeners_as_list
|
||||
return new_cls
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args)
|
||||
|
||||
@classmethod
|
||||
def qualified_name(cls):
|
||||
def qualified_name(cls) -> str:
|
||||
return cls.__cog_name__
|
||||
|
||||
def _cog_special_method(func):
|
||||
def _cog_special_method(func: FuncT) -> FuncT:
|
||||
func.__cog_special_method__ = None
|
||||
return func
|
||||
|
||||
@ -164,8 +183,12 @@ class Cog(metaclass=CogMeta):
|
||||
When inheriting from this class, the options shown in :class:`CogMeta`
|
||||
are equally valid here.
|
||||
"""
|
||||
__cog_name__: ClassVar[str]
|
||||
__cog_settings__: ClassVar[Dict[str, Any]]
|
||||
__cog_commands__: ClassVar[List[Command]]
|
||||
__cog_listeners__: ClassVar[List[Tuple[str, str]]]
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
def __new__(cls: Type[CogT], *args: Any, **kwargs: Any) -> CogT:
|
||||
# For issue 426, we need to store a copy of the command objects
|
||||
# since we modify them to inject `self` to them.
|
||||
# To do this, we need to interfere with the Cog creation process.
|
||||
@ -173,7 +196,8 @@ class Cog(metaclass=CogMeta):
|
||||
cmd_attrs = cls.__cog_settings__
|
||||
|
||||
# Either update the command with the cog provided defaults or copy it.
|
||||
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__)
|
||||
# r.e type ignore, type-checker complains about overriding a ClassVar
|
||||
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
|
||||
|
||||
lookup = {
|
||||
cmd.qualified_name: cmd
|
||||
@ -186,15 +210,15 @@ class Cog(metaclass=CogMeta):
|
||||
parent = command.parent
|
||||
if parent is not None:
|
||||
# Get the latest parent reference
|
||||
parent = lookup[parent.qualified_name]
|
||||
parent = lookup[parent.qualified_name] # type: ignore
|
||||
|
||||
# Update our parent's reference to our self
|
||||
parent.remove_command(command.name)
|
||||
parent.add_command(command)
|
||||
parent.remove_command(command.name) # type: ignore
|
||||
parent.add_command(command) # type: ignore
|
||||
|
||||
return self
|
||||
|
||||
def get_commands(self):
|
||||
def get_commands(self) -> List[Command]:
|
||||
r"""
|
||||
Returns
|
||||
--------
|
||||
@ -209,20 +233,20 @@ class Cog(metaclass=CogMeta):
|
||||
return [c for c in self.__cog_commands__ if c.parent is None]
|
||||
|
||||
@property
|
||||
def qualified_name(self):
|
||||
def qualified_name(self) -> str:
|
||||
""":class:`str`: Returns the cog's specified name, not the class name."""
|
||||
return self.__cog_name__
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
def description(self) -> str:
|
||||
""":class:`str`: Returns the cog's description, typically the cleaned docstring."""
|
||||
return self.__cog_description__
|
||||
|
||||
@description.setter
|
||||
def description(self, description):
|
||||
def description(self, description: str) -> None:
|
||||
self.__cog_description__ = description
|
||||
|
||||
def walk_commands(self):
|
||||
def walk_commands(self) -> Generator[Command, None, None]:
|
||||
"""An iterator that recursively walks through this cog's commands and subcommands.
|
||||
|
||||
Yields
|
||||
@ -237,7 +261,7 @@ class Cog(metaclass=CogMeta):
|
||||
if isinstance(command, GroupMixin):
|
||||
yield from command.walk_commands()
|
||||
|
||||
def get_listeners(self):
|
||||
def get_listeners(self) -> List[Tuple[str, Callable[..., Any]]]:
|
||||
"""Returns a :class:`list` of (name, function) listener pairs that are defined in this cog.
|
||||
|
||||
Returns
|
||||
@ -248,12 +272,12 @@ class Cog(metaclass=CogMeta):
|
||||
return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__]
|
||||
|
||||
@classmethod
|
||||
def _get_overridden_method(cls, method):
|
||||
def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]:
|
||||
"""Return None if the method is not overridden. Otherwise returns the overridden method."""
|
||||
return getattr(method.__func__, '__cog_special_method__', method)
|
||||
|
||||
@classmethod
|
||||
def listener(cls, name=None):
|
||||
def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]:
|
||||
"""A decorator that marks a function as a listener.
|
||||
|
||||
This is the cog equivalent of :meth:`.Bot.listen`.
|
||||
@ -271,10 +295,10 @@ class Cog(metaclass=CogMeta):
|
||||
the name.
|
||||
"""
|
||||
|
||||
if name is not None and not isinstance(name, str):
|
||||
if name is not MISSING and not isinstance(name, str):
|
||||
raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.')
|
||||
|
||||
def decorator(func):
|
||||
def decorator(func: FuncT) -> FuncT:
|
||||
actual = func
|
||||
if isinstance(actual, staticmethod):
|
||||
actual = actual.__func__
|
||||
@ -293,7 +317,7 @@ class Cog(metaclass=CogMeta):
|
||||
return func
|
||||
return decorator
|
||||
|
||||
def has_error_handler(self):
|
||||
def has_error_handler(self) -> bool:
|
||||
""":class:`bool`: Checks whether the cog has an error handler.
|
||||
|
||||
.. versionadded:: 1.7
|
||||
@ -301,7 +325,7 @@ class Cog(metaclass=CogMeta):
|
||||
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
|
||||
|
||||
@_cog_special_method
|
||||
def cog_unload(self):
|
||||
def cog_unload(self) -> None:
|
||||
"""A special method that is called when the cog gets removed.
|
||||
|
||||
This function **cannot** be a coroutine. It must be a regular
|
||||
@ -312,7 +336,7 @@ class Cog(metaclass=CogMeta):
|
||||
pass
|
||||
|
||||
@_cog_special_method
|
||||
def bot_check_once(self, ctx):
|
||||
def bot_check_once(self, ctx: Context) -> bool:
|
||||
"""A special method that registers as a :meth:`.Bot.check_once`
|
||||
check.
|
||||
|
||||
@ -322,7 +346,7 @@ class Cog(metaclass=CogMeta):
|
||||
return True
|
||||
|
||||
@_cog_special_method
|
||||
def bot_check(self, ctx):
|
||||
def bot_check(self, ctx: Context) -> bool:
|
||||
"""A special method that registers as a :meth:`.Bot.check`
|
||||
check.
|
||||
|
||||
@ -332,7 +356,7 @@ class Cog(metaclass=CogMeta):
|
||||
return True
|
||||
|
||||
@_cog_special_method
|
||||
def cog_check(self, ctx):
|
||||
def cog_check(self, ctx: Context) -> bool:
|
||||
"""A special method that registers as a :func:`~discord.ext.commands.check`
|
||||
for every command and subcommand in this cog.
|
||||
|
||||
@ -342,7 +366,7 @@ class Cog(metaclass=CogMeta):
|
||||
return True
|
||||
|
||||
@_cog_special_method
|
||||
async def cog_command_error(self, ctx, error):
|
||||
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
|
||||
"""A special method that is called whenever an error
|
||||
is dispatched inside this cog.
|
||||
|
||||
@ -361,7 +385,7 @@ class Cog(metaclass=CogMeta):
|
||||
pass
|
||||
|
||||
@_cog_special_method
|
||||
async def cog_before_invoke(self, ctx):
|
||||
async def cog_before_invoke(self, ctx: Context) -> None:
|
||||
"""A special method that acts as a cog local pre-invoke hook.
|
||||
|
||||
This is similar to :meth:`.Command.before_invoke`.
|
||||
@ -376,7 +400,7 @@ class Cog(metaclass=CogMeta):
|
||||
pass
|
||||
|
||||
@_cog_special_method
|
||||
async def cog_after_invoke(self, ctx):
|
||||
async def cog_after_invoke(self, ctx: Context) -> None:
|
||||
"""A special method that acts as a cog local post-invoke hook.
|
||||
|
||||
This is similar to :meth:`.Command.after_invoke`.
|
||||
@ -390,7 +414,7 @@ class Cog(metaclass=CogMeta):
|
||||
"""
|
||||
pass
|
||||
|
||||
def _inject(self, bot):
|
||||
def _inject(self: CogT, bot: BotBase) -> CogT:
|
||||
cls = self.__class__
|
||||
|
||||
# realistically, the only thing that can cause loading errors
|
||||
@ -425,7 +449,7 @@ class Cog(metaclass=CogMeta):
|
||||
|
||||
return self
|
||||
|
||||
def _eject(self, bot):
|
||||
def _eject(self, bot: BotBase) -> None:
|
||||
cls = self.__class__
|
||||
|
||||
try:
|
||||
|
@ -21,16 +21,52 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import re
|
||||
|
||||
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
import discord.abc
|
||||
import discord.utils
|
||||
import re
|
||||
|
||||
from discord.message import Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from discord.abc import MessageableChannel
|
||||
from discord.guild import Guild
|
||||
from discord.member import Member
|
||||
from discord.state import ConnectionState
|
||||
from discord.user import ClientUser, User
|
||||
from discord.voice_client import VoiceProtocol
|
||||
|
||||
from .bot import Bot, AutoShardedBot
|
||||
from .cog import Cog
|
||||
from .core import Command
|
||||
from .help import HelpCommand
|
||||
from .view import StringView
|
||||
|
||||
__all__ = (
|
||||
'Context',
|
||||
)
|
||||
|
||||
class Context(discord.abc.Messageable):
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
|
||||
CogT = TypeVar('CogT', bound="Cog")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
P = ParamSpec('P')
|
||||
else:
|
||||
P = TypeVar('P')
|
||||
|
||||
|
||||
class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
r"""Represents the context in which a command is being invoked under.
|
||||
|
||||
This class contains a lot of meta data to help you understand more about
|
||||
@ -58,11 +94,11 @@ class Context(discord.abc.Messageable):
|
||||
This is only of use for within converters.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
prefix: :class:`str`
|
||||
prefix: Optional[:class:`str`]
|
||||
The prefix that was used to invoke the command.
|
||||
command: :class:`Command`
|
||||
command: Optional[:class:`Command`]
|
||||
The command that is being invoked currently.
|
||||
invoked_with: :class:`str`
|
||||
invoked_with: Optional[:class:`str`]
|
||||
The command name that triggered this invocation. Useful for finding out
|
||||
which alias called the command.
|
||||
invoked_parents: List[:class:`str`]
|
||||
@ -73,7 +109,7 @@ class Context(discord.abc.Messageable):
|
||||
|
||||
.. versionadded:: 1.7
|
||||
|
||||
invoked_subcommand: :class:`Command`
|
||||
invoked_subcommand: Optional[:class:`Command`]
|
||||
The subcommand that was invoked.
|
||||
If no valid subcommand was invoked then this is equal to ``None``.
|
||||
subcommand_passed: Optional[:class:`str`]
|
||||
@ -86,23 +122,38 @@ class Context(discord.abc.Messageable):
|
||||
or invoked.
|
||||
"""
|
||||
|
||||
def __init__(self, **attrs):
|
||||
self.message = attrs.pop('message', None)
|
||||
self.bot = attrs.pop('bot', None)
|
||||
self.args = attrs.pop('args', [])
|
||||
self.kwargs = attrs.pop('kwargs', {})
|
||||
self.prefix = attrs.pop('prefix')
|
||||
self.command = attrs.pop('command', None)
|
||||
self.view = attrs.pop('view', None)
|
||||
self.invoked_with = attrs.pop('invoked_with', None)
|
||||
self.invoked_parents = attrs.pop('invoked_parents', [])
|
||||
self.invoked_subcommand = attrs.pop('invoked_subcommand', None)
|
||||
self.subcommand_passed = attrs.pop('subcommand_passed', None)
|
||||
self.command_failed = attrs.pop('command_failed', False)
|
||||
self.current_parameter = attrs.pop('current_parameter', None)
|
||||
self._state = self.message._state
|
||||
def __init__(self,
|
||||
*,
|
||||
message: Message,
|
||||
bot: BotT,
|
||||
view: StringView,
|
||||
args: List[Any] = MISSING,
|
||||
kwargs: Dict[str, Any] = MISSING,
|
||||
prefix: Optional[str] = None,
|
||||
command: Optional[Command] = None,
|
||||
invoked_with: Optional[str] = None,
|
||||
invoked_parents: List[str] = MISSING,
|
||||
invoked_subcommand: Optional[Command] = None,
|
||||
subcommand_passed: Optional[str] = None,
|
||||
command_failed: bool = False,
|
||||
current_parameter: Optional[inspect.Parameter] = None,
|
||||
):
|
||||
self.message: Message = message
|
||||
self.bot: BotT = bot
|
||||
self.args: List[Any] = args or []
|
||||
self.kwargs: Dict[str, Any] = kwargs or {}
|
||||
self.prefix: Optional[str] = prefix
|
||||
self.command: Optional[Command] = command
|
||||
self.view: StringView = view
|
||||
self.invoked_with: Optional[str] = invoked_with
|
||||
self.invoked_parents: List[str] = invoked_parents or []
|
||||
self.invoked_subcommand: Optional[Command] = invoked_subcommand
|
||||
self.subcommand_passed: Optional[str] = subcommand_passed
|
||||
self.command_failed: bool = command_failed
|
||||
self.current_parameter: Optional[inspect.Parameter] = current_parameter
|
||||
self._state: ConnectionState = self.message._state
|
||||
|
||||
async def invoke(self, command, /, *args, **kwargs):
|
||||
async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
r"""|coro|
|
||||
|
||||
Calls a command with the arguments given.
|
||||
@ -133,17 +184,9 @@ class Context(discord.abc.Messageable):
|
||||
TypeError
|
||||
The command argument to invoke is missing.
|
||||
"""
|
||||
arguments = []
|
||||
if command.cog is not None:
|
||||
arguments.append(command.cog)
|
||||
return await command(self, *args, **kwargs)
|
||||
|
||||
arguments.append(self)
|
||||
arguments.extend(args)
|
||||
|
||||
ret = await command.callback(*arguments, **kwargs)
|
||||
return ret
|
||||
|
||||
async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True):
|
||||
async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> None:
|
||||
"""|coro|
|
||||
|
||||
Calls the command again.
|
||||
@ -187,7 +230,7 @@ class Context(discord.abc.Messageable):
|
||||
|
||||
if restart:
|
||||
to_call = cmd.root_parent or cmd
|
||||
view.index = len(self.prefix)
|
||||
view.index = len(self.prefix or '')
|
||||
view.previous = 0
|
||||
self.invoked_parents = []
|
||||
self.invoked_with = view.get_word() # advance to get the root command
|
||||
@ -206,20 +249,23 @@ class Context(discord.abc.Messageable):
|
||||
self.subcommand_passed = subcommand_passed
|
||||
|
||||
@property
|
||||
def valid(self):
|
||||
def valid(self) -> bool:
|
||||
""":class:`bool`: Checks if the invocation context is valid to be invoked with."""
|
||||
return self.prefix is not None and self.command is not None
|
||||
|
||||
async def _get_channel(self):
|
||||
async def _get_channel(self) -> discord.abc.Messageable:
|
||||
return self.channel
|
||||
|
||||
@property
|
||||
def clean_prefix(self):
|
||||
def clean_prefix(self) -> str:
|
||||
""":class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
user = self.guild.me if self.guild else self.bot.user
|
||||
if self.prefix is None:
|
||||
return ''
|
||||
|
||||
user = self.me
|
||||
# this breaks if the prefix mention is not the bot itself but I
|
||||
# consider this to be an *incredibly* strange use case. I'd rather go
|
||||
# for this common use case rather than waste performance for the
|
||||
@ -228,7 +274,7 @@ class Context(discord.abc.Messageable):
|
||||
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix)
|
||||
|
||||
@property
|
||||
def cog(self):
|
||||
def cog(self) -> Optional[Cog]:
|
||||
"""Optional[:class:`.Cog`]: Returns the cog associated with this context's command. None if it does not exist."""
|
||||
|
||||
if self.command is None:
|
||||
@ -236,38 +282,39 @@ class Context(discord.abc.Messageable):
|
||||
return self.command.cog
|
||||
|
||||
@discord.utils.cached_property
|
||||
def guild(self):
|
||||
def guild(self) -> Optional[Guild]:
|
||||
"""Optional[:class:`.Guild`]: Returns the guild associated with this context's command. None if not available."""
|
||||
return self.message.guild
|
||||
|
||||
@discord.utils.cached_property
|
||||
def channel(self):
|
||||
def channel(self) -> MessageableChannel:
|
||||
"""Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command.
|
||||
Shorthand for :attr:`.Message.channel`.
|
||||
"""
|
||||
return self.message.channel
|
||||
|
||||
@discord.utils.cached_property
|
||||
def author(self):
|
||||
def author(self) -> Union[User, Member]:
|
||||
"""Union[:class:`~discord.User`, :class:`.Member`]:
|
||||
Returns the author associated with this context's command. Shorthand for :attr:`.Message.author`
|
||||
"""
|
||||
return self.message.author
|
||||
|
||||
@discord.utils.cached_property
|
||||
def me(self):
|
||||
def me(self) -> Union[Member, ClientUser]:
|
||||
"""Union[:class:`.Member`, :class:`.ClientUser`]:
|
||||
Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message contexts.
|
||||
"""
|
||||
return self.guild.me if self.guild is not None else self.bot.user
|
||||
# bot.user will never be None at this point.
|
||||
return self.guild.me if self.guild is not None else self.bot.user # type: ignore
|
||||
|
||||
@property
|
||||
def voice_client(self):
|
||||
def voice_client(self) -> Optional[VoiceProtocol]:
|
||||
r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable."""
|
||||
g = self.guild
|
||||
return g.voice_client if g else None
|
||||
|
||||
async def send_help(self, *args):
|
||||
async def send_help(self, *args: Any) -> Any:
|
||||
"""send_help(entity=<bot>)
|
||||
|
||||
|coro|
|
||||
@ -319,12 +366,12 @@ class Context(discord.abc.Messageable):
|
||||
return None
|
||||
|
||||
entity = args[0]
|
||||
if entity is None:
|
||||
return None
|
||||
|
||||
if isinstance(entity, str):
|
||||
entity = bot.get_cog(entity) or bot.get_command(entity)
|
||||
|
||||
if entity is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
entity.qualified_name
|
||||
except AttributeError:
|
||||
@ -348,6 +395,6 @@ class Context(discord.abc.Messageable):
|
||||
except CommandError as e:
|
||||
await cmd.on_help_command_error(self, e)
|
||||
|
||||
@discord.utils.copy_doc(discord.Message.reply)
|
||||
async def reply(self, content=None, **kwargs):
|
||||
@discord.utils.copy_doc(Message.reply)
|
||||
async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message:
|
||||
return await self.message.reply(content, **kwargs)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -27,11 +27,17 @@ import copy
|
||||
import functools
|
||||
import inspect
|
||||
import re
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import discord.utils
|
||||
|
||||
from .core import Group, Command
|
||||
from .errors import CommandError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
|
||||
__all__ = (
|
||||
'Paginator',
|
||||
'HelpCommand',
|
||||
@ -320,7 +326,7 @@ class HelpCommand:
|
||||
self.command_attrs = attrs = options.pop('command_attrs', {})
|
||||
attrs.setdefault('name', 'help')
|
||||
attrs.setdefault('help', 'Shows this message')
|
||||
self.context = None
|
||||
self.context: Optional[Context] = None
|
||||
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
|
||||
|
||||
def copy(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user