[commads] Change cog/extension load/unload methods to be async

This commit is contained in:
Josh
2022-03-14 11:03:45 +10:00
committed by GitHub
parent a339e01047
commit a1c618215e
8 changed files with 210 additions and 50 deletions

View File

@ -26,6 +26,7 @@ from __future__ import annotations
import inspect
import discord
from discord import app_commands
from discord.utils import maybe_coroutine
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union
@ -377,13 +378,30 @@ class Cog(metaclass=CogMeta):
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
@_cog_special_method
def cog_unload(self) -> None:
"""A special method that is called when the cog gets removed.
async def cog_load(self) -> None:
"""|maybecoro|
This function **cannot** be a coroutine. It must be a regular
function.
A special method that is called when the cog gets loaded.
Subclasses must replace this if they want special asynchronous loading behaviour.
Note that the ``__init__`` special method does not allow asynchronous code to run
inside it, thus this is helpful for setting up code that needs to be asynchronous.
.. versionadded:: 2.0
"""
pass
@_cog_special_method
async def cog_unload(self) -> None:
"""|maybecoro|
A special method that is called when the cog gets removed.
Subclasses must replace this if they want special unloading behaviour.
.. versionchanged:: 2.0
This method can now be a :term:`coroutine`.
"""
pass
@ -466,9 +484,13 @@ class Cog(metaclass=CogMeta):
"""
pass
def _inject(self, bot: BotBase, override: bool, guild: Optional[Snowflake], guilds: List[Snowflake]) -> Self:
async def _inject(self, bot: BotBase, override: bool, guild: Optional[Snowflake], guilds: List[Snowflake]) -> Self:
cls = self.__class__
# we'll call this first so that errors can propagate without
# having to worry about undoing anything
await maybe_coroutine(self.cog_load)
# realistically, the only thing that can cause loading errors
# is essentially just the command loading, which raises if there are
# duplicates. When this condition is met, we want to undo all what
@ -507,7 +529,7 @@ class Cog(metaclass=CogMeta):
return self
def _eject(self, bot: BotBase, guild_ids: Optional[Iterable[int]]) -> None:
async def _eject(self, bot: BotBase, guild_ids: Optional[Iterable[int]]) -> None:
cls = self.__class__
try:
@ -534,6 +556,6 @@ class Cog(metaclass=CogMeta):
bot.remove_check(self.bot_check_once, call_once=True)
finally:
try:
self.cog_unload()
await maybe_coroutine(self.cog_unload)
except Exception:
pass