diff --git a/b_asic/utils.py b/b_asic/utils.py index 7cca1f152546a445dbf39b659e99c3f99ac6c09d..11f690c91b12f7ca214db0ffa319c943288f367b 100644 --- a/b_asic/utils.py +++ b/b_asic/utils.py @@ -20,15 +20,18 @@ def save_structure(struct: AbstractOperation, _path: Optional[str] = None, _name Keyword Arguments: _path: The path (str) to which the structure will be saved. - _name: The name (str) of the file to be saved. Only used if _path is not defined and is not None. + _name: The name (str) of the file to be saved. Only used if _path is None. """ try: - _name = _name if _name is not None else f"{struct.type_name}.pickle" - if _path is None: + _name = _name if _name is not None else f"{struct.type_name}.pickle" _path = path.join(getcwd(), _name) - print(_path) + index = 1 + while path.exists(_path): + _path = path.join(getcwd(), f"{struct.type_name}({index}).pickle") + index += 1 + with open(_path, "wb") as handle: dill.dump(struct, handle, protocol=dill.HIGHEST_PROTOCOL) except Exception as e: diff --git a/test/test_load_save_structure.py b/test/test_load_save_structure.py index 193753b63d293cf99bb171cbdd53ff196fb82364..d84d8da0e5e10039d062f7df149055088668011a 100644 --- a/test/test_load_save_structure.py +++ b/test/test_load_save_structure.py @@ -1,4 +1,10 @@ -from os.path import isfile +""" +B-ASIC test suite for load/save datastructures. +""" + +from os import getcwd, path +from random import choice +from string import ascii_lowercase import pytest @@ -9,15 +15,38 @@ class TestSaveStructures: def test_save_sfg(self, large_operation_tree): sfg = SFG(outputs=[Output(large_operation_tree)]) - path = save_structure(sfg) - - assert path is not None - assert isfile(path) + _path = save_structure(sfg) + + assert _path is not None + assert path.isfile(_path) def test_load_sfg(self, large_operation_tree): sfg = SFG(outputs=[Output(large_operation_tree)]) - path = save_structure(sfg) - _sfg = load_structure(path) + _path = save_structure(sfg) + _sfg = load_structure(_path) assert isinstance(_sfg, SFG) + assert sorted([comp.type_name for comp in _sfg.components]) == sorted([comp.type_name for comp in sfg.components]) assert sfg.evaluate() == _sfg.evaluate() + + def test_save_invalid_path(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + + _folder = "".join(choice(ascii_lowercase) for _ in range(4)) + while path.exists(path.join(getcwd(), _folder)): + _folder = "".join(choice(ascii_lowercase) for _ in range(4)) + + _invalid_path = path.join(getcwd(), _folder, "cool.pickle") + _path = save_structure(sfg, _path=_invalid_path) + + assert _path is None + + def test_load_invalid_path(self, large_operation_tree): + _folder = "".join(choice(ascii_lowercase) for _ in range(4)) + while path.exists(path.join(getcwd(), _folder)): + _folder = "".join(choice(ascii_lowercase) for _ in range(4)) + + _invalid_path = path.join(getcwd(), _folder, "cool.pickle") + _path = load_structure(_path=_invalid_path) + + assert _path is None \ No newline at end of file