diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index 2a0e542b..ef879a5d 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -38,7 +38,7 @@ class RepoWrapperTestCase(unittest.TestCase): def setUp(self): """Load the wrapper module every time.""" - wrapper._wrapper_module = None + wrapper.Wrapper.cache_clear() self.wrapper = wrapper.Wrapper() diff --git a/wrapper.py b/wrapper.py index 65dcf3c6..3099ad5d 100644 --- a/wrapper.py +++ b/wrapper.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import importlib.machinery import importlib.util import os @@ -21,15 +22,11 @@ def WrapperPath(): return os.path.join(os.path.dirname(__file__), 'repo') -_wrapper_module = None - - +@functools.lru_cache(maxsize=None) def Wrapper(): - global _wrapper_module - if not _wrapper_module: - modname = 'wrapper' - loader = importlib.machinery.SourceFileLoader(modname, WrapperPath()) - spec = importlib.util.spec_from_loader(modname, loader) - _wrapper_module = importlib.util.module_from_spec(spec) - loader.exec_module(_wrapper_module) - return _wrapper_module + modname = 'wrapper' + loader = importlib.machinery.SourceFileLoader(modname, WrapperPath()) + spec = importlib.util.spec_from_loader(modname, loader) + module = importlib.util.module_from_spec(spec) + loader.exec_module(module) + return module