[commands] Add Bot.reload_extension for atomic loading.
Also do atomic loading in Bot.load_extension
This commit is contained in:
@ -523,6 +523,65 @@ class BotBase(GroupMixin):
|
||||
|
||||
# extensions
|
||||
|
||||
def _remove_module_references(self, name):
|
||||
# find all references to the module
|
||||
# remove the cogs registered from the module
|
||||
for cogname, cog in self._cogs.copy().items():
|
||||
if _is_submodule(name, cog.__module__):
|
||||
self.remove_cog(cogname)
|
||||
|
||||
# remove all the commands from the module
|
||||
for cmd in self.all_commands.copy().values():
|
||||
if cmd.module is not None and _is_submodule(name, cmd.module):
|
||||
if isinstance(cmd, GroupMixin):
|
||||
cmd.recursively_remove_all_commands()
|
||||
self.remove_command(cmd.name)
|
||||
|
||||
# remove all the listeners from the module
|
||||
for event_list in self.extra_events.copy().values():
|
||||
remove = []
|
||||
for index, event in enumerate(event_list):
|
||||
if event.__module__ is not None and _is_submodule(name, event.__module__):
|
||||
remove.append(index)
|
||||
|
||||
for index in reversed(remove):
|
||||
del event_list[index]
|
||||
|
||||
def _call_module_finalizers(self, lib, key):
|
||||
try:
|
||||
func = getattr(lib, 'teardown')
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
func(self)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self._extensions.pop(key, None)
|
||||
sys.modules.pop(key, None)
|
||||
name = lib.__name__
|
||||
for module in list(sys.modules.keys()):
|
||||
if _is_submodule(name, module):
|
||||
del sys.modules[module]
|
||||
|
||||
def _load_from_module_spec(self, lib, key):
|
||||
# precondition: key not in self._extensions
|
||||
try:
|
||||
setup = getattr(lib, 'setup')
|
||||
except AttributeError:
|
||||
del sys.modules[key]
|
||||
raise discord.ClientException('extension {!r} ({!r}) does not have a setup function.'.format(key, lib))
|
||||
|
||||
try:
|
||||
setup(self)
|
||||
except Exception:
|
||||
self._remove_module_references(lib.__name__)
|
||||
self._call_module_finalizers(lib, key)
|
||||
raise
|
||||
else:
|
||||
self._extensions[key] = lib
|
||||
|
||||
def load_extension(self, name):
|
||||
"""Loads an extension.
|
||||
|
||||
@ -546,19 +605,16 @@ class BotBase(GroupMixin):
|
||||
The extension does not have a setup function.
|
||||
ImportError
|
||||
The extension could not be imported.
|
||||
Exception
|
||||
Any other exception raised by the extension will be raised back
|
||||
to the caller.
|
||||
"""
|
||||
|
||||
if name in self._extensions:
|
||||
return
|
||||
|
||||
lib = importlib.import_module(name)
|
||||
if not hasattr(lib, 'setup'):
|
||||
del lib
|
||||
del sys.modules[name]
|
||||
raise discord.ClientException('extension does not have a setup function')
|
||||
|
||||
lib.setup(self)
|
||||
self._extensions[name] = lib
|
||||
self._load_from_module_spec(lib, name)
|
||||
|
||||
def unload_extension(self, name):
|
||||
"""Unloads an extension.
|
||||
@ -583,49 +639,56 @@ class BotBase(GroupMixin):
|
||||
if lib is None:
|
||||
return
|
||||
|
||||
lib_name = lib.__name__
|
||||
self._remove_module_references(lib.__name__)
|
||||
self._call_module_finalizers(lib, name)
|
||||
|
||||
# find all references to the module
|
||||
def reload_extension(self, name):
|
||||
"""Atomically reloads an extension.
|
||||
|
||||
# remove the cogs registered from the module
|
||||
for cogname, cog in self._cogs.copy().items():
|
||||
if _is_submodule(lib_name, cog.__module__):
|
||||
self.remove_cog(cogname)
|
||||
This replaces the extension with the same extension, only refreshed. This is
|
||||
equivalent to a :meth:`unload_extension` followed by a :meth:`load_extension`
|
||||
except done in an atomic way. That is, if an operation fails mid-reload then
|
||||
the bot will roll-back to the prior working state.
|
||||
|
||||
# remove all the commands from the module
|
||||
for cmd in self.all_commands.copy().values():
|
||||
if cmd.module is not None and _is_submodule(lib_name, cmd.module):
|
||||
if isinstance(cmd, GroupMixin):
|
||||
cmd.recursively_remove_all_commands()
|
||||
self.remove_command(cmd.name)
|
||||
Parameters
|
||||
------------
|
||||
name: :class:`str`
|
||||
The extension name to reload. It must be dot separated like
|
||||
regular Python imports if accessing a sub-module. e.g.
|
||||
``foo.test`` if you want to import ``foo/test.py``.
|
||||
|
||||
# remove all the listeners from the module
|
||||
for event_list in self.extra_events.copy().values():
|
||||
remove = []
|
||||
for index, event in enumerate(event_list):
|
||||
if event.__module__ is not None and _is_submodule(lib_name, event.__module__):
|
||||
remove.append(index)
|
||||
Raises
|
||||
-------
|
||||
Exception
|
||||
Any exception raised by the extension will be raised back
|
||||
to the caller.
|
||||
"""
|
||||
|
||||
for index in reversed(remove):
|
||||
del event_list[index]
|
||||
lib = self._extensions.get(name)
|
||||
if lib is None:
|
||||
return
|
||||
|
||||
# get the previous module states from sys modules
|
||||
modules = {
|
||||
name: module
|
||||
for name, module in sys.modules.items()
|
||||
if _is_submodule(lib.__name__, name)
|
||||
}
|
||||
|
||||
try:
|
||||
func = getattr(lib, 'teardown')
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
func(self)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
# finally remove the import..
|
||||
del lib
|
||||
del self._extensions[name]
|
||||
del sys.modules[name]
|
||||
for module in list(sys.modules.keys()):
|
||||
if _is_submodule(lib_name, module):
|
||||
del sys.modules[module]
|
||||
# Unload and then load the module...
|
||||
self._remove_module_references(lib.__name__)
|
||||
self._call_module_finalizers(lib, name)
|
||||
self.load_extension(name)
|
||||
except Exception as e:
|
||||
# if the load failed, the remnants should have been
|
||||
# cleaned from the load_extension function call
|
||||
# so let's load it from our old compiled library.
|
||||
self._load_from_module_spec(lib, name)
|
||||
|
||||
# revert sys.modules back to normal and raise back to caller
|
||||
sys.modules.update(modules)
|
||||
raise
|
||||
|
||||
@property
|
||||
def extensions(self):
|
||||
|
Reference in New Issue
Block a user