diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index d2089f17d4f2c4811aa8b4db04ded44c6c7a09c2..4541c3679623b1a83218a4a90766825d876bd5ca 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -416,34 +416,24 @@ class SFG(AbstractOperation): _sfg = self() - inputs = [] - outputs = [] + input_signals = [] + output_signals = [] for _operation in [_sfg.find_by_id(_id) for _id in operation_ids]: - - # Retrive input operations - for _signal in _operation.input_signals: - if _signal.source.operation.graph_id not in operation_ids: - inputs.append(_signal.source.operation) - - # Retrive output operations - for _signal in _operation.output_signals: - if _signal.destination.operation.graph_id not in operation_ids: - outputs.append(_signal.destination.operation) - - assert len(inputs) == operation.input_count, "The input count must match" - assert len(outputs) == operation.output_count, "The output count must match" - - for index_in, _input in enumerate(inputs): - for _signal in _input.output_signals: + input_signals.extend(filter(lambda s: s.source.operation.graph_id not in operation_ids, _operation.input_signals)) + output_signals.extend(filter(lambda s: s.destination.operation.graph_id not in operation_ids, _operation.output_signals)) + + assert len(input_signals) == operation.input_count, "The input count must match" + assert len(output_signals) == operation.output_count, "The output count must match" + + for index_in, _signal in enumerate(input_signals): _signal.remove_destination() _signal.set_destination(operation.input(index_in)) - for index_out, _output in enumerate(outputs): - for _signal in _output.input_signals: + for index_out, _signal in enumerate(output_signals): _signal.remove_source() _signal.set_source(operation.output(index_out)) - + return _sfg() def _evaluate_source(self, src: OutputPort, results: MutableResultMap, registers: MutableRegisterMap, prefix: str) -> Number: