diff --git a/b_asic/utils.py b/b_asic/utils.py index 42242fd31297462c1d763e37a372acfbe9b73d2f..11c3007398e2e30d3636d10bb2f4c25e189f75a7 100644 --- a/b_asic/utils.py +++ b/b_asic/utils.py @@ -5,6 +5,7 @@ This module contains functions that are used as utilities by other modules or by from typing import Optional from os import getcwd, path +import importlib from b_asic import AbstractOperation @@ -55,13 +56,24 @@ def load_structure(_path: str) -> AbstractOperation: return None -def load_recipe(module: str) -> None: +def load_recipe(module: str, _namespace: Optional[dict] = None) -> "Module": """Given the module name or the path to a module, import the module and let it evaluate it's content. - Returns None as the content from the module will be added to the namespace. + Returns the attributes of the module to be added to the namespace. This runs .py scripts inline by importing them, it currently does no checks for security measure. Arguments: module: The path or name of the module to import from. + + Keyword Arguments: + _namespace: The namespace to add the imported vars to. Normally you would want to use globals(). If None return it's own namespace. """ - pass + try: + if _namespace is None: + return importlib.import_module(module) + + return _namespace.update(importlib.import_module(module).__dict__) + except Exception as e: + print("Unexpected error occured while loading recipe: ", e) + + return None diff --git a/test/test_load_save_structure.py b/test/test_load_save_structure.py index 4881e2374a5261c61b411eccda69d58158ce8331..549d772be7ced78eb6228bce59b991582c7fee8f 100644 --- a/test/test_load_save_structure.py +++ b/test/test_load_save_structure.py @@ -53,17 +53,17 @@ class TestSaveStructures: def test_load_recipe_file(self): # Create a file that doesn't exist - _file = "".join(choice(ascii_lowercase) for _ in range(4)) + ".py" - while path.exists(path.join(getcwd(), _file)): - _file = "".join(choice(ascii_lowercase) for _ in range(4)) + ".py" + _file = "".join(choice(ascii_lowercase) for _ in range(4)) + while path.exists(path.join(getcwd(), _file + ".py")): + _file = "".join(choice(ascii_lowercase) for _ in range(4)) try: - with open(_file, "w+") as handle: + with open(_file + ".py", "w+") as handle: # The string is indented that way so the file is properly indented, .strip() did not work idk why handle.write( """ from b_asic import SFG, Output, Addition, Constant -sfg = SFG(outputs=[Output(Addition(Addition(Constant(2), Constant(3)), Addition(Constant(4), Constant(5))))]) +sfg_recipe = SFG(outputs=[Output(Addition(Addition(Constant(2), Constant(3)), Addition(Constant(4), Constant(5))))]) """ ) except Exception as e: @@ -71,17 +71,42 @@ sfg = SFG(outputs=[Output(Addition(Addition(Constant(2), Constant(3)), Addition( # Not defined yet with pytest.raises(NameError): - sfg.evaluate() == 14 + sfg_recipe.evaluate() == 14 - load_recipe(_file) - assert sfg.evaluate() == 14 + load_recipe(_file, globals()) + assert sfg_recipe.evaluate() == 14 - def test_load_invalid_recipe_file(self): - _file = "".join(choice(ascii_lowercase) for _ in range(4)) + ".py" - while path.exists(path.join(getcwd(), _file)): - _file = "".join(choice(ascii_lowercase) for _ in range(4)) + ".py" + def test_load_recipe_file_namespace(self): + # Create a file that doesn't exist + _file = "".join(choice(ascii_lowercase) for _ in range(4)) + while path.exists(path.join(getcwd(), _file + ".py")): + _file = "".join(choice(ascii_lowercase) for _ in range(4)) + try: + with open(_file + ".py", "w+") as handle: + # The string is indented that way so the file is properly indented, .strip() did not work idk why + handle.write( +""" +from b_asic import SFG, Output, Addition, Constant +sfg_recipe_2 = SFG(outputs=[Output(Addition(Addition(Constant(2), Constant(3)), Addition(Constant(4), Constant(5))))]) +""" + ) + except Exception as e: + assert False, f"Could not create file: {e}" + + # Not defined yet + with pytest.raises(NameError): + sfg_recipe_2.evaluate() == 14 + + _namespace = load_recipe(_file) + assert _namespace.sfg_recipe_2.evaluate() == 14 + assert _namespace != globals() + + def test_load_invalid_recipe_file(self): + _file = "".join(choice(ascii_lowercase) for _ in range(4)) + while path.exists(path.join(getcwd(), _file + ".py")): + _file = "".join(choice(ascii_lowercase) for _ in range(4)) + load_recipe(_file) with pytest.raises(NameError): - assert sfg.evaluate() == 14 - \ No newline at end of file + assert sfg_recipe_3.evaluate() == 14