[commands] Add Bot.reload_extension for atomic loading.

Also do atomic loading in Bot.load_extension
This commit is contained in:
Rapptz
2019-03-19 06:21:39 -04:00
parent d221ca5f7d
commit 26e9b5bfac
2 changed files with 109 additions and 47 deletions

View File

@ -523,6 +523,65 @@ class BotBase(GroupMixin):
# extensions
def _remove_module_references(self, name):
# find all references to the module
# remove the cogs registered from the module
for cogname, cog in self._cogs.copy().items():
if _is_submodule(name, cog.__module__):
self.remove_cog(cogname)
# remove all the commands from the module
for cmd in self.all_commands.copy().values():
if cmd.module is not None and _is_submodule(name, cmd.module):
if isinstance(cmd, GroupMixin):
cmd.recursively_remove_all_commands()
self.remove_command(cmd.name)
# remove all the listeners from the module
for event_list in self.extra_events.copy().values():
remove = []
for index, event in enumerate(event_list):
if event.__module__ is not None and _is_submodule(name, event.__module__):
remove.append(index)
for index in reversed(remove):
del event_list[index]
def _call_module_finalizers(self, lib, key):
try:
func = getattr(lib, 'teardown')
except AttributeError:
pass
else:
try:
func(self)
except Exception:
pass
finally:
self._extensions.pop(key, None)
sys.modules.pop(key, None)
name = lib.__name__
for module in list(sys.modules.keys()):
if _is_submodule(name, module):
del sys.modules[module]
def _load_from_module_spec(self, lib, key):
# precondition: key not in self._extensions
try:
setup = getattr(lib, 'setup')
except AttributeError:
del sys.modules[key]
raise discord.ClientException('extension {!r} ({!r}) does not have a setup function.'.format(key, lib))
try:
setup(self)
except Exception:
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, key)
raise
else:
self._extensions[key] = lib
def load_extension(self, name):
"""Loads an extension.
@ -546,19 +605,16 @@ class BotBase(GroupMixin):
The extension does not have a setup function.
ImportError
The extension could not be imported.
Exception
Any other exception raised by the extension will be raised back
to the caller.
"""
if name in self._extensions:
return
lib = importlib.import_module(name)
if not hasattr(lib, 'setup'):
del lib
del sys.modules[name]
raise discord.ClientException('extension does not have a setup function')
lib.setup(self)
self._extensions[name] = lib
self._load_from_module_spec(lib, name)
def unload_extension(self, name):
"""Unloads an extension.
@ -583,49 +639,56 @@ class BotBase(GroupMixin):
if lib is None:
return
lib_name = lib.__name__
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name)
# find all references to the module
def reload_extension(self, name):
"""Atomically reloads an extension.
# remove the cogs registered from the module
for cogname, cog in self._cogs.copy().items():
if _is_submodule(lib_name, cog.__module__):
self.remove_cog(cogname)
This replaces the extension with the same extension, only refreshed. This is
equivalent to a :meth:`unload_extension` followed by a :meth:`load_extension`
except done in an atomic way. That is, if an operation fails mid-reload then
the bot will roll-back to the prior working state.
# remove all the commands from the module
for cmd in self.all_commands.copy().values():
if cmd.module is not None and _is_submodule(lib_name, cmd.module):
if isinstance(cmd, GroupMixin):
cmd.recursively_remove_all_commands()
self.remove_command(cmd.name)
Parameters
------------
name: :class:`str`
The extension name to reload. It must be dot separated like
regular Python imports if accessing a sub-module. e.g.
``foo.test`` if you want to import ``foo/test.py``.
# remove all the listeners from the module
for event_list in self.extra_events.copy().values():
remove = []
for index, event in enumerate(event_list):
if event.__module__ is not None and _is_submodule(lib_name, event.__module__):
remove.append(index)
Raises
-------
Exception
Any exception raised by the extension will be raised back
to the caller.
"""
for index in reversed(remove):
del event_list[index]
lib = self._extensions.get(name)
if lib is None:
return
# get the previous module states from sys modules
modules = {
name: module
for name, module in sys.modules.items()
if _is_submodule(lib.__name__, name)
}
try:
func = getattr(lib, 'teardown')
except AttributeError:
pass
else:
try:
func(self)
except Exception:
pass
finally:
# finally remove the import..
del lib
del self._extensions[name]
del sys.modules[name]
for module in list(sys.modules.keys()):
if _is_submodule(lib_name, module):
del sys.modules[module]
# Unload and then load the module...
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name)
self.load_extension(name)
except Exception as e:
# if the load failed, the remnants should have been
# cleaned from the load_extension function call
# so let's load it from our old compiled library.
self._load_from_module_spec(lib, name)
# revert sys.modules back to normal and raise back to caller
sys.modules.update(modules)
raise
@property
def extensions(self):