[commands] Don't raise ExtensionNotFound for ImportErrors in modules
Now loading an extension that _contains_ a failed import will fail with ExtensionFailed, rather than ExtensionNotFound.
This commit is contained in:
		| @@ -27,7 +27,7 @@ DEALINGS IN THE SOFTWARE. | ||||
| import asyncio | ||||
| import collections | ||||
| import inspect | ||||
| import importlib | ||||
| import importlib.util | ||||
| import sys | ||||
| import traceback | ||||
| import re | ||||
| @@ -588,12 +588,17 @@ class BotBase(GroupMixin): | ||||
|                 if _is_submodule(name, module): | ||||
|                     del sys.modules[module] | ||||
|  | ||||
|     def _load_from_module_spec(self, lib, key): | ||||
|     def _load_from_module_spec(self, spec, key): | ||||
|         # precondition: key not in self.__extensions | ||||
|         lib = importlib.util.module_from_spec(spec) | ||||
|         try: | ||||
|             spec.loader.exec_module(lib) | ||||
|         except Exception as e: | ||||
|             raise errors.ExtensionFailed(key, e) from e | ||||
|  | ||||
|         try: | ||||
|             setup = getattr(lib, 'setup') | ||||
|         except AttributeError: | ||||
|             del sys.modules[key] | ||||
|             raise errors.NoEntryPointError(key) | ||||
|  | ||||
|         try: | ||||
| @@ -603,7 +608,7 @@ class BotBase(GroupMixin): | ||||
|             self._call_module_finalizers(lib, key) | ||||
|             raise errors.ExtensionFailed(key, e) from e | ||||
|         else: | ||||
|             self.__extensions[key] = lib | ||||
|             sys.modules[key] = self.__extensions[key] = lib | ||||
|  | ||||
|     def load_extension(self, name): | ||||
|         """Loads an extension. | ||||
| @@ -637,12 +642,11 @@ class BotBase(GroupMixin): | ||||
|         if name in self.__extensions: | ||||
|             raise errors.ExtensionAlreadyLoaded(name) | ||||
|  | ||||
|         try: | ||||
|             lib = importlib.import_module(name) | ||||
|         except ImportError as e: | ||||
|             raise errors.ExtensionNotFound(name, e) from e | ||||
|         else: | ||||
|             self._load_from_module_spec(lib, name) | ||||
|         spec = importlib.util.find_spec(name) | ||||
|         if spec is None: | ||||
|             raise errors.ExtensionNotFound(name) | ||||
|  | ||||
|         self._load_from_module_spec(spec, name) | ||||
|  | ||||
|     def unload_extension(self, name): | ||||
|         """Unloads an extension. | ||||
|   | ||||
| @@ -503,7 +503,7 @@ class NoEntryPointError(ExtensionError): | ||||
|         super().__init__("Extension {!r} has no 'setup' function.".format(name), name=name) | ||||
|  | ||||
| class ExtensionFailed(ExtensionError): | ||||
|     """An exception raised when an extension failed to load during execution of the ``setup`` entry point. | ||||
|     """An exception raised when an extension failed to load during execution of the module or ``setup`` entry point. | ||||
|  | ||||
|     This inherits from :exc:`ExtensionError` | ||||
|  | ||||
| @@ -521,19 +521,21 @@ class ExtensionFailed(ExtensionError): | ||||
|         super().__init__(fmt.format(name, original), name=name) | ||||
|  | ||||
| class ExtensionNotFound(ExtensionError): | ||||
|     """An exception raised when an extension failed to be imported. | ||||
|     """An exception raised when an extension is not found. | ||||
|  | ||||
|     This inherits from :exc:`ExtensionError` | ||||
|  | ||||
|     .. versionchanged:: 1.3.0 | ||||
|         Made the ``original`` attribute always None. | ||||
|  | ||||
|     Attributes | ||||
|     ----------- | ||||
|     name: :class:`str` | ||||
|         The extension that had the error. | ||||
|     original: :exc:`ImportError` | ||||
|         The original exception that was raised. You can also get this via | ||||
|         the ``__cause__`` attribute. | ||||
|     original: :class:`NoneType` | ||||
|         Always ``None`` for backwards compatibility. | ||||
|     """ | ||||
|     def __init__(self, name, original): | ||||
|         self.original = original | ||||
|     def __init__(self, name, original=None): | ||||
|         self.original = None | ||||
|         fmt = 'Extension {0!r} could not be loaded.' | ||||
|         super().__init__(fmt.format(name), name=name) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user