mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-09-07 10:22:59 +00:00
[commads] Change cog/extension load/unload methods to be async
This commit is contained in:
@ -170,13 +170,13 @@ class BotBase(GroupMixin):
|
||||
async def close(self) -> None:
|
||||
for extension in tuple(self.__extensions):
|
||||
try:
|
||||
self.unload_extension(extension)
|
||||
await self.unload_extension(extension)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for cog in tuple(self.__cogs):
|
||||
try:
|
||||
self.remove_cog(cog)
|
||||
await self.remove_cog(cog)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@ -528,7 +528,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
# cogs
|
||||
|
||||
def add_cog(
|
||||
async def add_cog(
|
||||
self,
|
||||
cog: Cog,
|
||||
/,
|
||||
@ -537,13 +537,20 @@ class BotBase(GroupMixin):
|
||||
guild: Optional[Snowflake] = MISSING,
|
||||
guilds: List[Snowflake] = MISSING,
|
||||
) -> None:
|
||||
"""Adds a "cog" to the bot.
|
||||
"""|coro|
|
||||
|
||||
Adds a "cog" to the bot.
|
||||
|
||||
A cog is a class that has its own event listeners and commands.
|
||||
|
||||
If the cog is a :class:`.app_commands.Group` then it is added to
|
||||
the bot's :class:`~discord.app_commands.CommandTree` as well.
|
||||
|
||||
.. note::
|
||||
|
||||
Exceptions raised inside a `class`:.Cog:'s :meth:`~.Cog.cog_load` method will be
|
||||
propagated to the caller.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
|
||||
:exc:`.ClientException` is raised when a cog with the same name
|
||||
@ -553,6 +560,10 @@ class BotBase(GroupMixin):
|
||||
|
||||
``cog`` parameter is now positional-only.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
|
||||
This method is now a :term:`coroutine`.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
cog: :class:`.Cog`
|
||||
@ -595,12 +606,12 @@ class BotBase(GroupMixin):
|
||||
if existing is not None:
|
||||
if not override:
|
||||
raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
|
||||
self.remove_cog(cog_name, guild=guild, guilds=guilds)
|
||||
await self.remove_cog(cog_name, guild=guild, guilds=guilds)
|
||||
|
||||
if isinstance(cog, app_commands.Group):
|
||||
self.__tree.add_command(cog, override=override, guild=guild, guilds=guilds)
|
||||
|
||||
cog = cog._inject(self, override=override, guild=guild, guilds=guilds)
|
||||
cog = await cog._inject(self, override=override, guild=guild, guilds=guilds)
|
||||
self.__cogs[cog_name] = cog
|
||||
|
||||
def get_cog(self, name: str, /) -> Optional[Cog]:
|
||||
@ -626,14 +637,16 @@ class BotBase(GroupMixin):
|
||||
"""
|
||||
return self.__cogs.get(name)
|
||||
|
||||
def remove_cog(
|
||||
async def remove_cog(
|
||||
self,
|
||||
name: str,
|
||||
/,
|
||||
guild: Optional[Snowflake] = MISSING,
|
||||
guilds: List[Snowflake] = MISSING,
|
||||
) -> Optional[Cog]:
|
||||
"""Removes a cog from the bot and returns it.
|
||||
"""|coro|
|
||||
|
||||
Removes a cog from the bot and returns it.
|
||||
|
||||
All registered commands and event listeners that the
|
||||
cog has registered will be removed as well.
|
||||
@ -644,6 +657,10 @@ class BotBase(GroupMixin):
|
||||
|
||||
``name`` parameter is now positional-only.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
|
||||
This method is now a :term:`coroutine`.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
name: :class:`str`
|
||||
@ -684,7 +701,7 @@ class BotBase(GroupMixin):
|
||||
for guild_id in guild_ids:
|
||||
self.__tree.remove_command(name, guild=discord.Object(guild_id))
|
||||
|
||||
cog._eject(self, guild_ids=guild_ids)
|
||||
await cog._eject(self, guild_ids=guild_ids)
|
||||
|
||||
return cog
|
||||
|
||||
@ -695,12 +712,12 @@ class BotBase(GroupMixin):
|
||||
|
||||
# extensions
|
||||
|
||||
def _remove_module_references(self, name: str) -> None:
|
||||
async def _remove_module_references(self, name: str) -> None:
|
||||
# find all references to the module
|
||||
# remove the cogs registered from the module
|
||||
for cogname, cog in self.__cogs.copy().items():
|
||||
if _is_submodule(name, cog.__module__):
|
||||
self.remove_cog(cogname)
|
||||
await self.remove_cog(cogname)
|
||||
|
||||
# remove all the commands from the module
|
||||
for cmd in self.all_commands.copy().values():
|
||||
@ -722,14 +739,14 @@ class BotBase(GroupMixin):
|
||||
# remove all relevant application commands from the tree
|
||||
self.__tree._remove_with_module(name)
|
||||
|
||||
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
|
||||
async def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
|
||||
try:
|
||||
func = getattr(lib, 'teardown')
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
func(self)
|
||||
await func(self)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
@ -740,7 +757,7 @@ class BotBase(GroupMixin):
|
||||
if _is_submodule(name, module):
|
||||
del sys.modules[module]
|
||||
|
||||
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
|
||||
async def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
|
||||
# precondition: key not in self.__extensions
|
||||
lib = importlib.util.module_from_spec(spec)
|
||||
sys.modules[key] = lib
|
||||
@ -757,11 +774,11 @@ class BotBase(GroupMixin):
|
||||
raise errors.NoEntryPointError(key)
|
||||
|
||||
try:
|
||||
setup(self)
|
||||
await setup(self)
|
||||
except Exception as e:
|
||||
del sys.modules[key]
|
||||
self._remove_module_references(lib.__name__)
|
||||
self._call_module_finalizers(lib, key)
|
||||
await self._remove_module_references(lib.__name__)
|
||||
await self._call_module_finalizers(lib, key)
|
||||
raise errors.ExtensionFailed(key, e) from e
|
||||
else:
|
||||
self.__extensions[key] = lib
|
||||
@ -772,8 +789,10 @@ class BotBase(GroupMixin):
|
||||
except ImportError:
|
||||
raise errors.ExtensionNotFound(name)
|
||||
|
||||
def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
|
||||
"""Loads an extension.
|
||||
async def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
Loads an extension.
|
||||
|
||||
An extension is a python module that contains commands, cogs, or
|
||||
listeners.
|
||||
@ -782,6 +801,10 @@ class BotBase(GroupMixin):
|
||||
the entry point on what to do when the extension is loaded. This entry
|
||||
point must have a single argument, the ``bot``.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
|
||||
This method is now a :term:`coroutine`.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
name: :class:`str`
|
||||
@ -817,10 +840,12 @@ class BotBase(GroupMixin):
|
||||
if spec is None:
|
||||
raise errors.ExtensionNotFound(name)
|
||||
|
||||
self._load_from_module_spec(spec, name)
|
||||
await self._load_from_module_spec(spec, name)
|
||||
|
||||
def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
|
||||
"""Unloads an extension.
|
||||
async def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
|
||||
Unloads an extension.
|
||||
|
||||
When the extension is unloaded, all commands, listeners, and cogs are
|
||||
removed from the bot and the module is un-imported.
|
||||
@ -830,6 +855,10 @@ class BotBase(GroupMixin):
|
||||
parameter, the ``bot``, similar to ``setup`` from
|
||||
:meth:`~.Bot.load_extension`.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
|
||||
This method is now a :term:`coroutine`.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
name: :class:`str`
|
||||
@ -857,10 +886,10 @@ class BotBase(GroupMixin):
|
||||
if lib is None:
|
||||
raise errors.ExtensionNotLoaded(name)
|
||||
|
||||
self._remove_module_references(lib.__name__)
|
||||
self._call_module_finalizers(lib, name)
|
||||
await self._remove_module_references(lib.__name__)
|
||||
await self._call_module_finalizers(lib, name)
|
||||
|
||||
def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
|
||||
async def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
|
||||
"""Atomically reloads an extension.
|
||||
|
||||
This replaces the extension with the same extension, only refreshed. This is
|
||||
@ -911,14 +940,14 @@ class BotBase(GroupMixin):
|
||||
|
||||
try:
|
||||
# Unload and then load the module...
|
||||
self._remove_module_references(lib.__name__)
|
||||
self._call_module_finalizers(lib, name)
|
||||
self.load_extension(name)
|
||||
await self._remove_module_references(lib.__name__)
|
||||
await self._call_module_finalizers(lib, name)
|
||||
await self.load_extension(name)
|
||||
except Exception:
|
||||
# if the load failed, the remnants should have been
|
||||
# cleaned from the load_extension function call
|
||||
# so let's load it from our old compiled library.
|
||||
lib.setup(self) # type: ignore
|
||||
await lib.setup(self) # type: ignore
|
||||
self.__extensions[name] = lib
|
||||
|
||||
# revert sys.modules back to normal and raise back to caller
|
||||
|
@ -26,6 +26,7 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.utils import maybe_coroutine
|
||||
|
||||
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union
|
||||
|
||||
@ -377,13 +378,30 @@ class Cog(metaclass=CogMeta):
|
||||
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
|
||||
|
||||
@_cog_special_method
|
||||
def cog_unload(self) -> None:
|
||||
"""A special method that is called when the cog gets removed.
|
||||
async def cog_load(self) -> None:
|
||||
"""|maybecoro|
|
||||
|
||||
This function **cannot** be a coroutine. It must be a regular
|
||||
function.
|
||||
A special method that is called when the cog gets loaded.
|
||||
|
||||
Subclasses must replace this if they want special asynchronous loading behaviour.
|
||||
Note that the ``__init__`` special method does not allow asynchronous code to run
|
||||
inside it, thus this is helpful for setting up code that needs to be asynchronous.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
pass
|
||||
|
||||
@_cog_special_method
|
||||
async def cog_unload(self) -> None:
|
||||
"""|maybecoro|
|
||||
|
||||
A special method that is called when the cog gets removed.
|
||||
|
||||
Subclasses must replace this if they want special unloading behaviour.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
|
||||
This method can now be a :term:`coroutine`.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -466,9 +484,13 @@ class Cog(metaclass=CogMeta):
|
||||
"""
|
||||
pass
|
||||
|
||||
def _inject(self, bot: BotBase, override: bool, guild: Optional[Snowflake], guilds: List[Snowflake]) -> Self:
|
||||
async def _inject(self, bot: BotBase, override: bool, guild: Optional[Snowflake], guilds: List[Snowflake]) -> Self:
|
||||
cls = self.__class__
|
||||
|
||||
# we'll call this first so that errors can propagate without
|
||||
# having to worry about undoing anything
|
||||
await maybe_coroutine(self.cog_load)
|
||||
|
||||
# realistically, the only thing that can cause loading errors
|
||||
# is essentially just the command loading, which raises if there are
|
||||
# duplicates. When this condition is met, we want to undo all what
|
||||
@ -507,7 +529,7 @@ class Cog(metaclass=CogMeta):
|
||||
|
||||
return self
|
||||
|
||||
def _eject(self, bot: BotBase, guild_ids: Optional[Iterable[int]]) -> None:
|
||||
async def _eject(self, bot: BotBase, guild_ids: Optional[Iterable[int]]) -> None:
|
||||
cls = self.__class__
|
||||
|
||||
try:
|
||||
@ -534,6 +556,6 @@ class Cog(metaclass=CogMeta):
|
||||
bot.remove_check(self.bot_check_once, call_once=True)
|
||||
finally:
|
||||
try:
|
||||
self.cog_unload()
|
||||
await maybe_coroutine(self.cog_unload)
|
||||
except Exception:
|
||||
pass
|
||||
|
@ -2350,7 +2350,6 @@ def before_invoke(coro) -> Callable[[T], T]:
|
||||
async def why(self, ctx): # Output: <Nothing>
|
||||
await ctx.send('because someone made me')
|
||||
|
||||
bot.add_cog(What())
|
||||
"""
|
||||
|
||||
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
|
||||
|
Reference in New Issue
Block a user