[commands] Don't raise ExtensionNotFound for ImportErrors in modules

Now loading an extension that _contains_ a failed import will fail
with ExtensionFailed, rather than ExtensionNotFound.
This commit is contained in:
Benjamin Mintz 2019-06-24 06:12:23 +00:00 committed by Rapptz
parent 3961e7ef6d
commit 0a21591d0c
2 changed files with 23 additions and 17 deletions

View File

@ -27,7 +27,7 @@ DEALINGS IN THE SOFTWARE.
import asyncio import asyncio
import collections import collections
import inspect import inspect
import importlib import importlib.util
import sys import sys
import traceback import traceback
import re import re
@ -588,12 +588,17 @@ class BotBase(GroupMixin):
if _is_submodule(name, module): if _is_submodule(name, module):
del sys.modules[module] del sys.modules[module]
def _load_from_module_spec(self, lib, key): def _load_from_module_spec(self, spec, key):
# precondition: key not in self.__extensions # precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec)
try:
spec.loader.exec_module(lib)
except Exception as e:
raise errors.ExtensionFailed(key, e) from e
try: try:
setup = getattr(lib, 'setup') setup = getattr(lib, 'setup')
except AttributeError: except AttributeError:
del sys.modules[key]
raise errors.NoEntryPointError(key) raise errors.NoEntryPointError(key)
try: try:
@ -603,7 +608,7 @@ class BotBase(GroupMixin):
self._call_module_finalizers(lib, key) self._call_module_finalizers(lib, key)
raise errors.ExtensionFailed(key, e) from e raise errors.ExtensionFailed(key, e) from e
else: else:
self.__extensions[key] = lib sys.modules[key] = self.__extensions[key] = lib
def load_extension(self, name): def load_extension(self, name):
"""Loads an extension. """Loads an extension.
@ -637,12 +642,11 @@ class BotBase(GroupMixin):
if name in self.__extensions: if name in self.__extensions:
raise errors.ExtensionAlreadyLoaded(name) raise errors.ExtensionAlreadyLoaded(name)
try: spec = importlib.util.find_spec(name)
lib = importlib.import_module(name) if spec is None:
except ImportError as e: raise errors.ExtensionNotFound(name)
raise errors.ExtensionNotFound(name, e) from e
else: self._load_from_module_spec(spec, name)
self._load_from_module_spec(lib, name)
def unload_extension(self, name): def unload_extension(self, name):
"""Unloads an extension. """Unloads an extension.

View File

@ -503,7 +503,7 @@ class NoEntryPointError(ExtensionError):
super().__init__("Extension {!r} has no 'setup' function.".format(name), name=name) super().__init__("Extension {!r} has no 'setup' function.".format(name), name=name)
class ExtensionFailed(ExtensionError): class ExtensionFailed(ExtensionError):
"""An exception raised when an extension failed to load during execution of the ``setup`` entry point. """An exception raised when an extension failed to load during execution of the module or ``setup`` entry point.
This inherits from :exc:`ExtensionError` This inherits from :exc:`ExtensionError`
@ -521,19 +521,21 @@ class ExtensionFailed(ExtensionError):
super().__init__(fmt.format(name, original), name=name) super().__init__(fmt.format(name, original), name=name)
class ExtensionNotFound(ExtensionError): class ExtensionNotFound(ExtensionError):
"""An exception raised when an extension failed to be imported. """An exception raised when an extension is not found.
This inherits from :exc:`ExtensionError` This inherits from :exc:`ExtensionError`
.. versionchanged:: 1.3.0
Made the ``original`` attribute always None.
Attributes Attributes
----------- -----------
name: :class:`str` name: :class:`str`
The extension that had the error. The extension that had the error.
original: :exc:`ImportError` original: :class:`NoneType`
The original exception that was raised. You can also get this via Always ``None`` for backwards compatibility.
the ``__cause__`` attribute.
""" """
def __init__(self, name, original): def __init__(self, name, original=None):
self.original = original self.original = None
fmt = 'Extension {0!r} could not be loaded.' fmt = 'Extension {0!r} could not be loaded.'
super().__init__(fmt.format(name), name=name) super().__init__(fmt.format(name), name=name)