[commands] Don't raise ExtensionNotFound for ImportErrors in modules
Now loading an extension that _contains_ a failed import will fail with ExtensionFailed, rather than ExtensionNotFound.
This commit is contained in:
parent
3961e7ef6d
commit
0a21591d0c
@ -27,7 +27,7 @@ DEALINGS IN THE SOFTWARE.
|
||||
import asyncio
|
||||
import collections
|
||||
import inspect
|
||||
import importlib
|
||||
import importlib.util
|
||||
import sys
|
||||
import traceback
|
||||
import re
|
||||
@ -588,12 +588,17 @@ class BotBase(GroupMixin):
|
||||
if _is_submodule(name, module):
|
||||
del sys.modules[module]
|
||||
|
||||
def _load_from_module_spec(self, lib, key):
|
||||
def _load_from_module_spec(self, spec, key):
|
||||
# precondition: key not in self.__extensions
|
||||
lib = importlib.util.module_from_spec(spec)
|
||||
try:
|
||||
spec.loader.exec_module(lib)
|
||||
except Exception as e:
|
||||
raise errors.ExtensionFailed(key, e) from e
|
||||
|
||||
try:
|
||||
setup = getattr(lib, 'setup')
|
||||
except AttributeError:
|
||||
del sys.modules[key]
|
||||
raise errors.NoEntryPointError(key)
|
||||
|
||||
try:
|
||||
@ -603,7 +608,7 @@ class BotBase(GroupMixin):
|
||||
self._call_module_finalizers(lib, key)
|
||||
raise errors.ExtensionFailed(key, e) from e
|
||||
else:
|
||||
self.__extensions[key] = lib
|
||||
sys.modules[key] = self.__extensions[key] = lib
|
||||
|
||||
def load_extension(self, name):
|
||||
"""Loads an extension.
|
||||
@ -637,12 +642,11 @@ class BotBase(GroupMixin):
|
||||
if name in self.__extensions:
|
||||
raise errors.ExtensionAlreadyLoaded(name)
|
||||
|
||||
try:
|
||||
lib = importlib.import_module(name)
|
||||
except ImportError as e:
|
||||
raise errors.ExtensionNotFound(name, e) from e
|
||||
else:
|
||||
self._load_from_module_spec(lib, name)
|
||||
spec = importlib.util.find_spec(name)
|
||||
if spec is None:
|
||||
raise errors.ExtensionNotFound(name)
|
||||
|
||||
self._load_from_module_spec(spec, name)
|
||||
|
||||
def unload_extension(self, name):
|
||||
"""Unloads an extension.
|
||||
|
@ -503,7 +503,7 @@ class NoEntryPointError(ExtensionError):
|
||||
super().__init__("Extension {!r} has no 'setup' function.".format(name), name=name)
|
||||
|
||||
class ExtensionFailed(ExtensionError):
|
||||
"""An exception raised when an extension failed to load during execution of the ``setup`` entry point.
|
||||
"""An exception raised when an extension failed to load during execution of the module or ``setup`` entry point.
|
||||
|
||||
This inherits from :exc:`ExtensionError`
|
||||
|
||||
@ -521,19 +521,21 @@ class ExtensionFailed(ExtensionError):
|
||||
super().__init__(fmt.format(name, original), name=name)
|
||||
|
||||
class ExtensionNotFound(ExtensionError):
|
||||
"""An exception raised when an extension failed to be imported.
|
||||
"""An exception raised when an extension is not found.
|
||||
|
||||
This inherits from :exc:`ExtensionError`
|
||||
|
||||
.. versionchanged:: 1.3.0
|
||||
Made the ``original`` attribute always None.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
The extension that had the error.
|
||||
original: :exc:`ImportError`
|
||||
The original exception that was raised. You can also get this via
|
||||
the ``__cause__`` attribute.
|
||||
original: :class:`NoneType`
|
||||
Always ``None`` for backwards compatibility.
|
||||
"""
|
||||
def __init__(self, name, original):
|
||||
self.original = original
|
||||
def __init__(self, name, original=None):
|
||||
self.original = None
|
||||
fmt = 'Extension {0!r} could not be loaded.'
|
||||
super().__init__(fmt.format(name), name=name)
|
||||
|
Loading…
x
Reference in New Issue
Block a user