diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index e8e7af01ab93fdba948d9ff7ec19078b3b71dee6..d311bfacb69878c76d03b4194c5cf24a54e4f79c 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -367,6 +367,44 @@ class SFG(AbstractOperation): # Add connected operation to the queue of operations to visit. op_stack.append(original_connected_op) + def replace_component(self, component: Operation, _component: Operation = None, _id: GraphID = None): + """Find and replace all components matching either on GraphID, Type or both. + Then return a new deepcopy of the sfg with the replaced component. + + Arguments: + component: The new component(s), e.g Multiplication + + Keyword arguments: + _component: The specific component to replace. + _id: The GraphID to match the component to replace. + """ + + assert _component is not None or _id is not None, \ + "Define either operation to replace or GraphID of operation" + + if _id is not None: + _component = self.find_by_id(_id) + + assert _component is not None and isinstance(_component, Operation), \ + "No operation matching the criteria found" + assert _component.output_count == component.output_count, \ + "The output count may not differ between the operations" + assert _component.input_count == component.input_count, \ + "The input count may not differ between the operations" + + for index_in, _inp in enumerate(_component.inputs): + for _signal in _inp.signals: + _signal.remove_destination() + _signal.set_destination(component.input(index_in)) + + for index_out, _out in enumerate(_component.outputs): + for _signal in _out.signals: + _signal.remove_source() + _signal.set_source(component.output(index_out)) + + # The old SFG will be deleted by Python GC + return self() + def _evaluate_source(self, src: OutputPort, results: MutableResultMap, registers: MutableRegisterMap, prefix: str) -> Number: src_prefix = prefix if src_prefix: diff --git a/test/test_signal_flow_graph.py b/test/test_signal_flow_graph.py index 51267cc44ece60c05e622c02c89b8ec1a5d5b17d..a2114f81482de70454d3b4328a379ba1a09e33f9 100644 --- a/test/test_signal_flow_graph.py +++ b/test/test_signal_flow_graph.py @@ -147,3 +147,61 @@ class TestComponents: mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], name = "mac_sfg") assert set([comp.name for comp in mac_sfg.components]) == {"INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1", "S1", "S2", "S3", "S4", "S5", "S6", "S7"} + + +class TestReplaceComponents: + + def test_replace_addition_by_id(self, operation_tree): + sfg = SFG(outputs=[Output(operation_tree)]) + component_id = "add1" + + sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + assert component_id not in sfg._components_by_id.keys() + assert "Multi" in sfg._components_by_name.keys() + + def test_replace_addition_by_component(self, operation_tree): + sfg = SFG(outputs=[Output(operation_tree)]) + component_id = "add1" + component = sfg.find_by_id(component_id) + + sfg = sfg.replace_component(Multiplication(name="Multi"), _component=component) + assert component_id not in sfg._components_by_id.keys() + assert "Multi" in sfg._components_by_name.keys() + + def test_replace_addition_large_tree(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + component_id = "add3" + + sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + assert "Multi" in sfg._components_by_name.keys() + assert component_id not in sfg._components_by_id.keys() + + def test_replace_no_input_component(self, operation_tree): + sfg = SFG(outputs=[Output(operation_tree)]) + component_id = "c1" + _const = sfg.find_by_id(component_id) + + sfg = sfg.replace_component(Constant(1), _id=component_id) + assert _const is not sfg.find_by_id(component_id) + + def test_no_match_on_replace(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + component_id = "addd1" + + try: + sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + except AssertionError: + assert True + else: + assert False + + def test_not_equal_input(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + component_id = "c1" + + try: + sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + except AssertionError: + assert True + else: + assert False