[commads] Change cog/extension load/unload methods to be async

This commit is contained in:
Josh
2022-03-14 11:03:45 +10:00
committed by GitHub
parent a339e01047
commit a1c618215e
8 changed files with 210 additions and 50 deletions

View File

@ -170,13 +170,13 @@ class BotBase(GroupMixin):
async def close(self) -> None:
for extension in tuple(self.__extensions):
try:
self.unload_extension(extension)
await self.unload_extension(extension)
except Exception:
pass
for cog in tuple(self.__cogs):
try:
self.remove_cog(cog)
await self.remove_cog(cog)
except Exception:
pass
@ -528,7 +528,7 @@ class BotBase(GroupMixin):
# cogs
def add_cog(
async def add_cog(
self,
cog: Cog,
/,
@ -537,13 +537,20 @@ class BotBase(GroupMixin):
guild: Optional[Snowflake] = MISSING,
guilds: List[Snowflake] = MISSING,
) -> None:
"""Adds a "cog" to the bot.
"""|coro|
Adds a "cog" to the bot.
A cog is a class that has its own event listeners and commands.
If the cog is a :class:`.app_commands.Group` then it is added to
the bot's :class:`~discord.app_commands.CommandTree` as well.
.. note::
Exceptions raised inside a `class`:.Cog:'s :meth:`~.Cog.cog_load` method will be
propagated to the caller.
.. versionchanged:: 2.0
:exc:`.ClientException` is raised when a cog with the same name
@ -553,6 +560,10 @@ class BotBase(GroupMixin):
``cog`` parameter is now positional-only.
.. versionchanged:: 2.0
This method is now a :term:`coroutine`.
Parameters
-----------
cog: :class:`.Cog`
@ -595,12 +606,12 @@ class BotBase(GroupMixin):
if existing is not None:
if not override:
raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
self.remove_cog(cog_name, guild=guild, guilds=guilds)
await self.remove_cog(cog_name, guild=guild, guilds=guilds)
if isinstance(cog, app_commands.Group):
self.__tree.add_command(cog, override=override, guild=guild, guilds=guilds)
cog = cog._inject(self, override=override, guild=guild, guilds=guilds)
cog = await cog._inject(self, override=override, guild=guild, guilds=guilds)
self.__cogs[cog_name] = cog
def get_cog(self, name: str, /) -> Optional[Cog]:
@ -626,14 +637,16 @@ class BotBase(GroupMixin):
"""
return self.__cogs.get(name)
def remove_cog(
async def remove_cog(
self,
name: str,
/,
guild: Optional[Snowflake] = MISSING,
guilds: List[Snowflake] = MISSING,
) -> Optional[Cog]:
"""Removes a cog from the bot and returns it.
"""|coro|
Removes a cog from the bot and returns it.
All registered commands and event listeners that the
cog has registered will be removed as well.
@ -644,6 +657,10 @@ class BotBase(GroupMixin):
``name`` parameter is now positional-only.
.. versionchanged:: 2.0
This method is now a :term:`coroutine`.
Parameters
-----------
name: :class:`str`
@ -684,7 +701,7 @@ class BotBase(GroupMixin):
for guild_id in guild_ids:
self.__tree.remove_command(name, guild=discord.Object(guild_id))
cog._eject(self, guild_ids=guild_ids)
await cog._eject(self, guild_ids=guild_ids)
return cog
@ -695,12 +712,12 @@ class BotBase(GroupMixin):
# extensions
def _remove_module_references(self, name: str) -> None:
async def _remove_module_references(self, name: str) -> None:
# find all references to the module
# remove the cogs registered from the module
for cogname, cog in self.__cogs.copy().items():
if _is_submodule(name, cog.__module__):
self.remove_cog(cogname)
await self.remove_cog(cogname)
# remove all the commands from the module
for cmd in self.all_commands.copy().values():
@ -722,14 +739,14 @@ class BotBase(GroupMixin):
# remove all relevant application commands from the tree
self.__tree._remove_with_module(name)
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
async def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
try:
func = getattr(lib, 'teardown')
except AttributeError:
pass
else:
try:
func(self)
await func(self)
except Exception:
pass
finally:
@ -740,7 +757,7 @@ class BotBase(GroupMixin):
if _is_submodule(name, module):
del sys.modules[module]
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
async def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
# precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec)
sys.modules[key] = lib
@ -757,11 +774,11 @@ class BotBase(GroupMixin):
raise errors.NoEntryPointError(key)
try:
setup(self)
await setup(self)
except Exception as e:
del sys.modules[key]
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, key)
await self._remove_module_references(lib.__name__)
await self._call_module_finalizers(lib, key)
raise errors.ExtensionFailed(key, e) from e
else:
self.__extensions[key] = lib
@ -772,8 +789,10 @@ class BotBase(GroupMixin):
except ImportError:
raise errors.ExtensionNotFound(name)
def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Loads an extension.
async def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""|coro|
Loads an extension.
An extension is a python module that contains commands, cogs, or
listeners.
@ -782,6 +801,10 @@ class BotBase(GroupMixin):
the entry point on what to do when the extension is loaded. This entry
point must have a single argument, the ``bot``.
.. versionchanged:: 2.0
This method is now a :term:`coroutine`.
Parameters
------------
name: :class:`str`
@ -817,10 +840,12 @@ class BotBase(GroupMixin):
if spec is None:
raise errors.ExtensionNotFound(name)
self._load_from_module_spec(spec, name)
await self._load_from_module_spec(spec, name)
def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Unloads an extension.
async def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""|coro|
Unloads an extension.
When the extension is unloaded, all commands, listeners, and cogs are
removed from the bot and the module is un-imported.
@ -830,6 +855,10 @@ class BotBase(GroupMixin):
parameter, the ``bot``, similar to ``setup`` from
:meth:`~.Bot.load_extension`.
.. versionchanged:: 2.0
This method is now a :term:`coroutine`.
Parameters
------------
name: :class:`str`
@ -857,10 +886,10 @@ class BotBase(GroupMixin):
if lib is None:
raise errors.ExtensionNotLoaded(name)
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name)
await self._remove_module_references(lib.__name__)
await self._call_module_finalizers(lib, name)
def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
async def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Atomically reloads an extension.
This replaces the extension with the same extension, only refreshed. This is
@ -911,14 +940,14 @@ class BotBase(GroupMixin):
try:
# Unload and then load the module...
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name)
self.load_extension(name)
await self._remove_module_references(lib.__name__)
await self._call_module_finalizers(lib, name)
await self.load_extension(name)
except Exception:
# if the load failed, the remnants should have been
# cleaned from the load_extension function call
# so let's load it from our old compiled library.
lib.setup(self) # type: ignore
await lib.setup(self) # type: ignore
self.__extensions[name] = lib
# revert sys.modules back to normal and raise back to caller