[commands] Don't raise ExtensionNotFound for ImportErrors in modules

Now loading an extension that _contains_ a failed import will fail
with ExtensionFailed, rather than ExtensionNotFound.
This commit is contained in:
Benjamin Mintz
2019-06-24 06:12:23 +00:00
committed by Rapptz
parent 3961e7ef6d
commit 0a21591d0c
2 changed files with 23 additions and 17 deletions

View File

@@ -27,7 +27,7 @@ DEALINGS IN THE SOFTWARE.
import asyncio
import collections
import inspect
import importlib
import importlib.util
import sys
import traceback
import re
@@ -588,12 +588,17 @@ class BotBase(GroupMixin):
if _is_submodule(name, module):
del sys.modules[module]
def _load_from_module_spec(self, lib, key):
def _load_from_module_spec(self, spec, key):
# precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec)
try:
spec.loader.exec_module(lib)
except Exception as e:
raise errors.ExtensionFailed(key, e) from e
try:
setup = getattr(lib, 'setup')
except AttributeError:
del sys.modules[key]
raise errors.NoEntryPointError(key)
try:
@@ -603,7 +608,7 @@ class BotBase(GroupMixin):
self._call_module_finalizers(lib, key)
raise errors.ExtensionFailed(key, e) from e
else:
self.__extensions[key] = lib
sys.modules[key] = self.__extensions[key] = lib
def load_extension(self, name):
"""Loads an extension.
@@ -637,12 +642,11 @@ class BotBase(GroupMixin):
if name in self.__extensions:
raise errors.ExtensionAlreadyLoaded(name)
try:
lib = importlib.import_module(name)
except ImportError as e:
raise errors.ExtensionNotFound(name, e) from e
else:
self._load_from_module_spec(lib, name)
spec = importlib.util.find_spec(name)
if spec is None:
raise errors.ExtensionNotFound(name)
self._load_from_module_spec(spec, name)
def unload_extension(self, name):
"""Unloads an extension.