diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index 98523fbbc3ad6cf26c028e028429c60710578d33..296803e3e55b7b92d85f52b91b30533ecdfbc0b6 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -192,7 +192,7 @@ class ConstantMultiplication(AbstractOperation): TODO: More info. """ - def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): + def __init__(self, value: Number = 0, src0: Optional[SignalSourceProvider] = None, name: Name = ""): super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) self.set_param("value", value) diff --git a/b_asic/operation.py b/b_asic/operation.py index 456c83290a57e6c874b514cdd78fbd8583b199d9..92c7b2b04029d039b125970774d84993ae1ae88b 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -180,6 +180,11 @@ class Operation(GraphComponent, SignalSourceProvider): """ raise NotImplementedError + @abstractmethod + def inputs_required_for_output(self, output_index: int) -> Iterable[int]: + """Get the input indices of all inputs in this operation whose values are required in order to evalueate the output at the given output index.""" + raise NotImplementedError + class AbstractOperation(Operation, AbstractGraphComponent): """Generic abstract operation class which most implementations will derive from. @@ -331,29 +336,6 @@ class AbstractOperation(Operation, AbstractGraphComponent): try: result = self.evaluate(*([Input()] * self.input_count)) if isinstance(result, collections.Sequence) and all(isinstance(e, Operation) for e in result): - - # Loopa igenom alla inputs för Self - # Spara Destination - # Hitta utsignalen från Destination - # Spara utsignalens destination - # Gör så att current Input Signals destination blir utsignalens destination - - - - - - # Loopa igenom alla outputs för Self - # Spara Source - # Hitta insignalen till Source - # Spara insignalens source - # Gör så att current Output Signals source blir insignalens source - - - self.input_signals - - - - return result if isinstance(result, Operation): return [result] @@ -363,7 +345,10 @@ class AbstractOperation(Operation, AbstractGraphComponent): pass return [self] - + def inputs_required_for_output(self, output_index: int) -> Iterable[int]: + if output_index < 0 or output_index >= self.output_count: + raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") + return [i for i in range(self.input_count)] # By default, assume each output depends on all inputs. @property def neighbors(self) -> Iterable[GraphComponent]: diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 2aeb13aa711b940e953dc9f6778f841b122cbff8..04c43d7f43cba59d1c528fc30bf48e7cf7675577 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -3,7 +3,7 @@ B-ASIC Signal Flow Graph Module. TODO: More info. """ -from typing import List, Iterable, Sequence, Dict, Optional, DefaultDict, Set +from typing import List, Iterable, Sequence, Dict, Optional, DefaultDict, MutableSet from numbers import Number from collections import defaultdict, deque @@ -45,7 +45,7 @@ class SFG(AbstractOperation): _graph_id_generator: GraphIDGenerator _input_operations: List[Input] _output_operations: List[Output] - _original_components_to_new: Set[GraphComponent] + _original_components_to_new: MutableSet[GraphComponent] _original_input_signals_to_indices: Dict[Signal, int] _original_output_signals_to_indices: Dict[Signal, int] @@ -155,6 +155,9 @@ class SFG(AbstractOperation): raise ValueError(f"Output signal #{output_index} is missing source in SFG") if signal.source.operation not in self._original_components_to_new: self._add_operation_connected_tree_copy(signal.source.operation) + + # Find dependencies. + def __str__(self) -> str: """Get a string representation of this SFG.""" @@ -231,17 +234,6 @@ class SFG(AbstractOperation): results[self.key(index, prefix)] = value return value - def split(self) -> Iterable[Operation]: - """ Returns every operation in the SFG except for Input and Output types. """ - - ops = [] - for op in self.operations: - if not isinstance(op, Input) and not isinstance(op, Output): - ops.append(op) - - return ops # Need any checking before returning? - - def replace_self(self) -> None: """ Iterates over the SFG's (self) Input- and OutputSignals to reconnect them to each necessary operation inside the SFG, so that the inner operations of the SFG can function on their own, effectively replacing the SFG. """ @@ -277,6 +269,14 @@ class SFG(AbstractOperation): def output_operations(self) -> Sequence[Operation]: """Get the internal output operations in the same order as their respective output ports.""" return self._output_operations + + def split(self) -> Iterable[Operation]: + return self.operations + + def inputs_required_for_output(self, output_index: int) -> Iterable[int]: + if output_index < 0 or output_index >= self.output_count: + raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") + return self._inputs_required_for_source(self._output_operations[output_index].input(0).signals[0].source, set()) def copy_component(self, *args, **kwargs) -> GraphComponent: return super().copy_component(*args, **kwargs, inputs = self._input_operations, outputs = self._output_operations, @@ -411,6 +411,44 @@ class SFG(AbstractOperation): # Add connected operation to the queue of operations to visit. op_stack.append(original_connected_op) + def replace_component(self, component: Operation, _component: Operation = None, _id: GraphID = None): + """Find and replace all components matching either on GraphID, Type or both. + Then return a new deepcopy of the sfg with the replaced component. + + Arguments: + component: The new component(s), e.g Multiplication + + Keyword arguments: + _component: The specific component to replace. + _id: The GraphID to match the component to replace. + """ + + assert _component is not None or _id is not None, \ + "Define either operation to replace or GraphID of operation" + + if _id is not None: + _component = self.find_by_id(_id) + + assert _component is not None and isinstance(_component, Operation), \ + "No operation matching the criteria found" + assert _component.output_count == component.output_count, \ + "The output count may not differ between the operations" + assert _component.input_count == component.input_count, \ + "The input count may not differ between the operations" + + for index_in, _inp in enumerate(_component.inputs): + for _signal in _inp.signals: + _signal.remove_destination() + _signal.set_destination(component.input(index_in)) + + for index_out, _out in enumerate(_component.outputs): + for _signal in _out.signals: + _signal.remove_source() + _signal.set_source(component.output(index_out)) + + # The old SFG will be deleted by Python GC + return self() + def _evaluate_source(self, src: OutputPort, results: MutableResultMap, registers: MutableRegisterMap, prefix: str) -> Number: src_prefix = prefix if src_prefix: @@ -429,3 +467,18 @@ class SFG(AbstractOperation): value = src.operation.evaluate_output(src.index, input_values, results, registers, src_prefix) results[key] = value return value + + def _inputs_required_for_source(self, src: OutputPort, visited: MutableSet[Operation]) -> Sequence[bool]: + if src.operation in visited: + return [] + visited.add(src.operation) + + if isinstance(src.operation, Input): + for i, input_operation in enumerate(self._input_operations): + if input_operation is src.operation: + return [i] + + input_indices = [] + for i in src.operation.inputs_required_for_output(src.index): + input_indices.extend(self._inputs_required_for_source(src.operation.input(i).signals[0].source, visited)) + return input_indices diff --git a/test/fixtures/port.py b/test/fixtures/port.py index fa528b8d9437e60b99c1ec426f317eb97b0164f2..4cce4f69b1f11b44426d1bd39702dba4e11c0efe 100644 --- a/test/fixtures/port.py +++ b/test/fixtures/port.py @@ -10,3 +10,11 @@ def input_port(): @pytest.fixture def output_port(): return OutputPort(None, 0) + +@pytest.fixture +def list_of_input_ports(): + return [InputPort(None, i) for i in range(0, 3)] + +@pytest.fixture +def list_of_output_ports(): + return [OutputPort(None, i) for i in range(0, 3)] diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py index 0a6c554d1340478dad25a11655f0542bf6fba1d1..5a0ef25b94cec8e3fad9275cccf97882703de330 100644 --- a/test/fixtures/signal_flow_graph.py +++ b/test/fixtures/signal_flow_graph.py @@ -1,6 +1,6 @@ import pytest -from b_asic import SFG, Input, Output, Constant, Register +from b_asic import SFG, Input, Output, Constant, Register, ConstantMultiplication @pytest.fixture @@ -29,6 +29,33 @@ def sfg_two_inputs_two_outputs(): out2 = Output(add2) return SFG(inputs = [in1, in2], outputs = [out1, out2]) +@pytest.fixture +def sfg_two_inputs_two_outputs_independent(): + """Valid SFG with two inputs and two outputs, where the first output only depends + on the first input and the second output only depends on the second input. + . . + in1-------------------->out1 + . . + . . + . c1--+ . + . | . + . v . + in2------+ add1---->out2 + . | ^ . + . | | . + . +------+ . + . . + out1 = in1 + out2 = in2 + 3 + """ + in1 = Input() + in2 = Input() + c1 = Constant(3) + add1 = in2 + c1 + out1 = Output(in1) + out2 = Output(add1) + return SFG(inputs = [in1, in2], outputs = [out1, out2]) + @pytest.fixture def sfg_nested(): """Valid SFG with two inputs and one output. @@ -68,4 +95,21 @@ def sfg_accumulator(): reg = Register() reg.input(0).connect((reg + data_in) * (1 - reset)) data_out = Output(reg) - return SFG(inputs = [data_in, reset], outputs = [data_out]) \ No newline at end of file + return SFG(inputs = [data_in, reset], outputs = [data_out]) + +@pytest.fixture +def simple_filter(): + """A valid SFG that is used as a filter in the first lab for TSTE87. + +----<constmul1----+ + | | + | | + in1>------add1>------reg>------+------out1> + """ + in1 = Input() + reg = Register() + constmul1 = ConstantMultiplication(0.5) + add1 = in1 + constmul1 + reg.input(0).connect(add1) + constmul1.input(0).connect(reg) + out1 = Output(reg) + return SFG(inputs=[in1], outputs=[out1]) diff --git a/test/test_depends.py b/test/test_depends.py new file mode 100644 index 0000000000000000000000000000000000000000..e26911054a9604db2f08998f6ecfccd81a012e5a --- /dev/null +++ b/test/test_depends.py @@ -0,0 +1,19 @@ +from b_asic import Addition, Butterfly + +class TestDepends: + def test_depends_addition(self): + add1 = Addition() + assert set(add1.inputs_required_for_output(0)) == {0, 1} + + def test_depends_butterfly(self): + bfly1 = Butterfly() + assert set(bfly1.inputs_required_for_output(0)) == {0, 1} + assert set(bfly1.inputs_required_for_output(1)) == {0, 1} + + def test_depends_sfg(self, sfg_two_inputs_two_outputs): + assert set(sfg_two_inputs_two_outputs.inputs_required_for_output(0)) == {0, 1} + assert set(sfg_two_inputs_two_outputs.inputs_required_for_output(1)) == {0, 1} + + def test_depends_sfg_independent(self, sfg_two_inputs_two_outputs_independent): + assert set(sfg_two_inputs_two_outputs_independent.inputs_required_for_output(0)) == {0} + assert set(sfg_two_inputs_two_outputs_independent.inputs_required_for_output(1)) == {1} \ No newline at end of file diff --git a/test/test_inputport.py b/test/test_inputport.py index 85f892217c7e0f766417f6cc2e6d066d48d8a537..f4668938ce3aa90e052115d57e3d3d52c9d9e3eb 100644 --- a/test/test_inputport.py +++ b/test/test_inputport.py @@ -7,15 +7,7 @@ import pytest from b_asic import InputPort, OutputPort, Signal @pytest.fixture -def inp_port(): - return InputPort(None, 0) - -@pytest.fixture -def out_port(): - return OutputPort(None, 0) - -@pytest.fixture -def out_port2(): +def output_port2(): return OutputPort(None, 1) @pytest.fixture @@ -23,53 +15,53 @@ def dangling_sig(): return Signal() @pytest.fixture -def s_w_source(out_port): - return Signal(source=out_port) +def s_w_source(output_port): + return Signal(source=output_port) @pytest.fixture def sig_with_dest(inp_port): return Signal(destination=inp_port) @pytest.fixture -def connected_sig(inp_port, out_port): - return Signal(source=out_port, destination=inp_port) +def connected_sig(inp_port, output_port): + return Signal(source=output_port, destination=inp_port) -def test_connect_then_disconnect(inp_port, out_port): +def test_connect_then_disconnect(input_port, output_port): """Test connect unused port to port.""" - s1 = inp_port.connect(out_port) + s1 = input_port.connect(output_port) - assert inp_port.connected_source == out_port - assert inp_port.signals == [s1] - assert out_port.signals == [s1] - assert s1.source is out_port - assert s1.destination is inp_port + assert input_port.connected_source == output_port + assert input_port.signals == [s1] + assert output_port.signals == [s1] + assert s1.source is output_port + assert s1.destination is input_port - inp_port.remove_signal(s1) + input_port.remove_signal(s1) - assert inp_port.connected_source is None - assert inp_port.signals == [] - assert out_port.signals == [s1] - assert s1.source is out_port + assert input_port.connected_source is None + assert input_port.signals == [] + assert output_port.signals == [s1] + assert s1.source is output_port assert s1.destination is None -def test_connect_used_port_to_new_port(inp_port, out_port, out_port2): - """Does connecting multiple ports to an inputport throw error?""" - inp_port.connect(out_port) +def test_connect_used_port_to_new_port(input_port, output_port, output_port2): + """Multiple connections to an input port should throw an error.""" + input_port.connect(output_port) with pytest.raises(Exception): - inp_port.connect(out_port2) + input_port.connect(output_port2) -def test_add_signal_then_disconnect(inp_port, s_w_source): +def test_add_signal_then_disconnect(input_port, s_w_source): """Can signal be connected then disconnected properly?""" - inp_port.add_signal(s_w_source) + input_port.add_signal(s_w_source) - assert inp_port.connected_source == s_w_source.source - assert inp_port.signals == [s_w_source] + assert input_port.connected_source == s_w_source.source + assert input_port.signals == [s_w_source] assert s_w_source.source.signals == [s_w_source] - assert s_w_source.destination is inp_port + assert s_w_source.destination is input_port - inp_port.remove_signal(s_w_source) + input_port.remove_signal(s_w_source) - assert inp_port.connected_source is None - assert inp_port.signals == [] + assert input_port.connected_source is None + assert input_port.signals == [] assert s_w_source.source.signals == [s_w_source] assert s_w_source.destination is None diff --git a/test/test_outputport.py b/test/test_outputport.py index 189c89225f88f263294d24aae21995ee7a821ada..7cc250ee083bdecb5ee04c3c28e56518a5bb6331 100644 --- a/test/test_outputport.py +++ b/test/test_outputport.py @@ -5,33 +5,19 @@ import pytest from b_asic import OutputPort, InputPort, Signal - -@pytest.fixture -def output_port(): - return OutputPort(None, 0) - -@pytest.fixture -def input_port(): - return InputPort(None, 0) - -@pytest.fixture -def list_of_input_ports(): - return [InputPort(None, i) for i in range(0, 3)] - - class TestConnect: def test_multiple_ports(self, output_port, list_of_input_ports): - """Can multiple ports connect to an output port?""" + """Multiple connections to an output port should be possible.""" for port in list_of_input_ports: port.connect(output_port) assert output_port.signal_count == len(list_of_input_ports) - def test_same_port(self, output_port, list_of_input_ports): + def test_same_port(self, output_port, input_port): """Check error handing.""" - list_of_input_ports[0].connect(output_port) + input_port.connect(output_port) with pytest.raises(Exception): - list_of_input_ports[0].connect(output_port) + input_port.connect(output_port) assert output_port.signal_count == 1 @@ -43,9 +29,8 @@ class TestAddSignal: assert output_port.signal_count == 1 assert output_port.signals == [s] -class TestDisconnect: +class TestClear: def test_others_clear(self, output_port, list_of_input_ports): - """Can multiple ports disconnect from OutputPort?""" for port in list_of_input_ports: port.connect(output_port) @@ -56,7 +41,6 @@ class TestDisconnect: assert all(s.dangling() for s in output_port.signals) def test_self_clear(self, output_port, list_of_input_ports): - """Can an OutputPort disconnect from multiple ports?""" for port in list_of_input_ports: port.connect(output_port) @@ -74,7 +58,6 @@ class TestRemoveSignal: assert output_port.signals == [] def test_multiple_signals(self, output_port, list_of_input_ports): - """Can multiple signals disconnect from OutputPort?""" sigs = [] for port in list_of_input_ports: diff --git a/test/test_print_sfg.py b/test/test_print_sfg.py deleted file mode 100644 index 49b0950d82857f86ba652e76075b5d3cb40e1584..0000000000000000000000000000000000000000 --- a/test/test_print_sfg.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -B-ASIC test suite for printing a SFG -""" - - -from b_asic.signal_flow_graph import SFG -from b_asic.core_operations import Addition, Multiplication, Constant -from b_asic.port import InputPort, OutputPort -from b_asic.signal import Signal -from b_asic.special_operations import Input, Output - -import pytest - - -class TestPrintSfg: - def test_print_one_addition(self): - inp1 = Input("INP1") - inp2 = Input("INP2") - add1 = Addition(inp1, inp2, "ADD1") - out1 = Output(add1, "OUT1") - sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="sf1") - - assert sfg.__str__() == ("id: add1, name: ADD1, input: [s1, s2], output: [s3]\nid: in1, name: INP1, input: [], output: [s1]\nid: in2, name: INP2, input: [], output: [s2]\nid: out1, name: OUT1, input: [s3], output: []\n") - - def test_print_add_mul(self): - inp1 = Input("INP1") - inp2 = Input("INP2") - inp3 = Input("INP3") - add1 = Addition(inp1, inp2, "ADD1") - mul1 = Multiplication(add1, inp3, "MUL1") - out1 = Output(mul1, "OUT1") - sfg = SFG(inputs=[inp1, inp2, inp3], outputs=[out1], name="mac_sfg") - - assert sfg.__str__() == ("id: add1, name: ADD1, input: [s1, s2], output: [s5]\nid: in1, name: INP1, input: [], output: [s1]\nid: in2, name: INP2, input: [], output: [s2]\nid: mul1, name: MUL1, input: [s5, s3], output: [s4]\nid: in3, name: INP3, input: [], output: [s3]\nid: out1, name: OUT1, input: [s4], output: []\n") - - def test_print_constant(self): - inp1 = Input("INP1") - const1 = Constant(3, "CONST") - add1 = Addition(const1, inp1, "ADD1") - out1 = Output(add1, "OUT1") - - sfg = SFG(inputs=[inp1], outputs=[out1], name="sfg") - - assert sfg.__str__() == ("id: add1, name: ADD1, input: [s3, s1], output: [s2]\nid: c1, name: CONST, value: 3, input: [], output: [s3]\nid: in1, name: INP1, input: [], output: [s1]\nid: out1, name: OUT1, input: [s2], output: []\n") - - \ No newline at end of file diff --git a/test/test_signal_flow_graph.py b/test/test_sfg.py similarity index 70% rename from test/test_signal_flow_graph.py rename to test/test_sfg.py index e68639294050ccacc675525539fa6733e7edf780..38d7f27f9321b6a480ba03b85345e2324170fa91 100644 --- a/test/test_signal_flow_graph.py +++ b/test/test_sfg.py @@ -3,7 +3,7 @@ import pytest from b_asic import SFG, Signal, Input, Output, Constant, Addition, Multiplication, Subtraction -class TestConstructor: +class TestInit: def test_direct_input_to_output_sfg_construction(self): in1 = Input("IN1") out1 = Output(None, "OUT1") @@ -45,21 +45,58 @@ class TestConstructor: assert sfg.input_count == 0 assert sfg.output_count == 1 +class TestPrintSfg: + def test_one_addition(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + add1 = Addition(inp1, inp2, "ADD1") + out1 = Output(add1, "OUT1") + sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="sf1") -class TestEvaluation: - def test_evaluate_output(self, operation_tree): - sfg = SFG(outputs = [Output(operation_tree)]) - assert sfg.evaluate_output(0, []) == 5 + assert sfg.__str__() == \ + "id: add1, name: ADD1, input: [s1, s2], output: [s3]\n" + \ + "id: in1, name: INP1, input: [], output: [s1]\n" + \ + "id: in2, name: INP2, input: [], output: [s2]\n" + \ + "id: out1, name: OUT1, input: [s3], output: []\n" - def test_evaluate_output_large(self, large_operation_tree): - sfg = SFG(outputs = [Output(large_operation_tree)]) - assert sfg.evaluate_output(0, []) == 14 + def test_add_mul(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(inp1, inp2, "ADD1") + mul1 = Multiplication(add1, inp3, "MUL1") + out1 = Output(mul1, "OUT1") + sfg = SFG(inputs=[inp1, inp2, inp3], outputs=[out1], name="mac_sfg") - def test_evaluate_output_cycle(self, operation_graph_with_cycle): - sfg = SFG(outputs = [Output(operation_graph_with_cycle)]) - with pytest.raises(Exception): - sfg.evaluate_output(0, []) + assert sfg.__str__() == \ + "id: add1, name: ADD1, input: [s1, s2], output: [s5]\n" + \ + "id: in1, name: INP1, input: [], output: [s1]\n" + \ + "id: in2, name: INP2, input: [], output: [s2]\n" + \ + "id: mul1, name: MUL1, input: [s5, s3], output: [s4]\n" + \ + "id: in3, name: INP3, input: [], output: [s3]\n" + \ + "id: out1, name: OUT1, input: [s4], output: []\n" + def test_constant(self): + inp1 = Input("INP1") + const1 = Constant(3, "CONST") + add1 = Addition(const1, inp1, "ADD1") + out1 = Output(add1, "OUT1") + + sfg = SFG(inputs=[inp1], outputs=[out1], name="sfg") + + assert sfg.__str__() == \ + "id: add1, name: ADD1, input: [s3, s1], output: [s2]\n" + \ + "id: c1, name: CONST, value: 3, input: [], output: [s3]\n" + \ + "id: in1, name: INP1, input: [], output: [s1]\n" + \ + "id: out1, name: OUT1, input: [s2], output: []\n" + + def test_simple_filter(self, simple_filter): + assert simple_filter.__str__() == \ + 'id: add1, name: , input: [s1, s3], output: [s4]\n' + \ + 'id: in1, name: , input: [], output: [s1]\n' + \ + 'id: cmul1, name: , input: [s5], output: [s3]\n' + \ + 'id: reg1, name: , input: [s4], output: [s5, s2]\n' + \ + 'id: out1, name: , input: [s2], output: []\n' class TestDeepCopy: def test_deep_copy_no_duplicates(self): @@ -125,6 +162,19 @@ class TestDeepCopy: assert mac_sfg_new.input(0).signals[0].source.operation is a assert mac_sfg_new.input(1).signals[0].source.operation is b +class TestEvaluateOutput: + def test_evaluate_output(self, operation_tree): + sfg = SFG(outputs = [Output(operation_tree)]) + assert sfg.evaluate_output(0, []) == 5 + + def test_evaluate_output_large(self, large_operation_tree): + sfg = SFG(outputs = [Output(large_operation_tree)]) + assert sfg.evaluate_output(0, []) == 14 + + def test_evaluate_output_cycle(self, operation_graph_with_cycle): + sfg = SFG(outputs = [Output(operation_graph_with_cycle)]) + with pytest.raises(Exception): + sfg.evaluate_output(0, []) class TestComponents: def test_advanced_components(self): @@ -149,6 +199,62 @@ class TestComponents: assert set([comp.name for comp in mac_sfg.components]) == {"INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1", "S1", "S2", "S3", "S4", "S5", "S6", "S7"} +class TestReplaceComponents: + def test_replace_addition_by_id(self, operation_tree): + sfg = SFG(outputs=[Output(operation_tree)]) + component_id = "add1" + + sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + assert component_id not in sfg._components_by_id.keys() + assert "Multi" in sfg._components_by_name.keys() + + def test_replace_addition_by_component(self, operation_tree): + sfg = SFG(outputs=[Output(operation_tree)]) + component_id = "add1" + component = sfg.find_by_id(component_id) + + sfg = sfg.replace_component(Multiplication(name="Multi"), _component=component) + assert component_id not in sfg._components_by_id.keys() + assert "Multi" in sfg._components_by_name.keys() + + def test_replace_addition_large_tree(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + component_id = "add3" + + sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + assert "Multi" in sfg._components_by_name.keys() + assert component_id not in sfg._components_by_id.keys() + + def test_replace_no_input_component(self, operation_tree): + sfg = SFG(outputs=[Output(operation_tree)]) + component_id = "c1" + _const = sfg.find_by_id(component_id) + + sfg = sfg.replace_component(Constant(1), _id=component_id) + assert _const is not sfg.find_by_id(component_id) + + def test_no_match_on_replace(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + component_id = "addd1" + + try: + sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + except AssertionError: + assert True + else: + assert False + + def test_not_equal_input(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + component_id = "c1" + + try: + sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + except AssertionError: + assert True + else: + assert False + class TestReplaceSelfSoloComp: def test_replace_self_mac(self): @@ -224,8 +330,6 @@ class TestReplaceSelfSoloComp: class TestReplaceSelfMultipleComp: - - def test_replace_self_operation_tree(self, operation_tree): """ Replaces a operation_tree in an SFG with other components """ sfg1 = SFG(outputs = [Output(operation_tree)]) @@ -335,8 +439,4 @@ class TestReplaceSelfMultipleComp: sfg1.replace_self() - assert test_sfg.evaluate(1, 2, 3, 4) == 16 - - - - + assert test_sfg.evaluate(1, 2, 3, 4) == 16 \ No newline at end of file diff --git a/test/test_signal.py b/test/test_signal.py index cad16c9ba5b73b3d597c4e80aa666677d1909888..42086d4d5fb68eb861eee72d1351931d473b480b 100644 --- a/test/test_signal.py +++ b/test/test_signal.py @@ -61,27 +61,28 @@ def test_signal_creation_and_disconnction_and_connection_changing(): assert s.source is out_port assert s.destination is in_port -def test_signal_set_bits_pos_int(signal): - signal.bits = 10 - assert signal.bits == 10 - -def test_signal_set_bits_zero(signal): - signal.bits = 0 - assert signal.bits == 0 - -def test_signal_set_bits_neg_int(signal): - with pytest.raises(Exception): - signal.bits = -10 - -def test_signal_set_bits_complex(signal): - with pytest.raises(Exception): - signal.bits = (2+4j) - -def test_signal_set_bits_float(signal): - with pytest.raises(Exception): - signal.bits = 3.2 - -def test_signal_set_bits_pos_then_none(signal): - signal.bits = 10 - signal.bits = None - assert signal.bits is None \ No newline at end of file +class Bits: + def test_pos_int(self, signal): + signal.bits = 10 + assert signal.bits == 10 + + def test_bits_zero(self, signal): + signal.bits = 0 + assert signal.bits == 0 + + def test_bits_neg_int(self, signal): + with pytest.raises(Exception): + signal.bits = -10 + + def test_bits_complex(self, signal): + with pytest.raises(Exception): + signal.bits = (2+4j) + + def test_bits_float(self, signal): + with pytest.raises(Exception): + signal.bits = 3.2 + + def test_bits_pos_then_none(self, signal): + signal.bits = 10 + signal.bits = None + assert signal.bits is None \ No newline at end of file diff --git a/test/test_simulation.py b/test/test_simulation.py index faa1f75eb12acccf26169f31849f61da83df598e..70d4ede54d39dc3bfed2dac87ca3703016935a92 100644 --- a/test/test_simulation.py +++ b/test/test_simulation.py @@ -4,8 +4,8 @@ import numpy as np from b_asic import SFG, Output, Simulation -class TestSimulation: - def test_simulate_with_lambdas_as_input(self, sfg_two_inputs_two_outputs): +class TestRunFor: + def test_with_lambdas_as_input(self, sfg_two_inputs_two_outputs): simulation = Simulation(sfg_two_inputs_two_outputs, [lambda n: n + 3, lambda n: 1 + n * 2], save_results = True) output = simulation.run_for(101) @@ -44,7 +44,7 @@ class TestSimulation: assert simulation.results[3]["0"] == 13 assert simulation.results[3]["1"] == 20 - def test_simulate_with_numpy_arrays_as_input(self, sfg_two_inputs_two_outputs): + def test_with_numpy_arrays_as_input(self, sfg_two_inputs_two_outputs): input0 = np.array([5, 9, 25, -5, 7]) input1 = np.array([7, 3, 3, 54, 2]) simulation = Simulation(sfg_two_inputs_two_outputs, [input0, input1]) @@ -85,8 +85,8 @@ class TestSimulation: assert simulation.results[4]["0"] == 9 assert simulation.results[4]["1"] == 11 - - def test_simulate_with_numpy_array_overflow(self, sfg_two_inputs_two_outputs): + + def test_with_numpy_array_overflow(self, sfg_two_inputs_two_outputs): input0 = np.array([5, 9, 25, -5, 7]) input1 = np.array([7, 3, 3, 54, 2]) simulation = Simulation(sfg_two_inputs_two_outputs, [input0, input1]) @@ -94,18 +94,7 @@ class TestSimulation: with pytest.raises(IndexError): simulation.run_for(1) - def test_simulate_nested(self, sfg_nested): - input0 = np.array([5, 9]) - input1 = np.array([7, 3]) - simulation = Simulation(sfg_nested, [input0, input1]) - - output0 = simulation.run() - output1 = simulation.run() - - assert output0[0] == 11405 - assert output1[0] == 4221 - - def test_simulate_delay(self, sfg_delay): + def test_delay(self, sfg_delay): simulation = Simulation(sfg_delay, save_results = True) simulation.set_input(0, [5, -2, 25, -6, 7, 0]) simulation.run_for(6) @@ -116,8 +105,20 @@ class TestSimulation: assert simulation.results[3]["0"] == 25 assert simulation.results[4]["0"] == -6 assert simulation.results[5]["0"] == 7 + +class TestRun: + def test_nested(self, sfg_nested): + input0 = np.array([5, 9]) + input1 = np.array([7, 3]) + simulation = Simulation(sfg_nested, [input0, input1]) + + output0 = simulation.run() + output1 = simulation.run() + + assert output0[0] == 11405 + assert output1[0] == 4221 - def test_simulate_accumulator(self, sfg_accumulator): + def test_accumulator(self, sfg_accumulator): data_in = np.array([5, -2, 25, -6, 7, 0]) reset = np.array([0, 0, 0, 1, 0, 0]) simulation = Simulation(sfg_accumulator, [data_in, reset]) @@ -132,4 +133,10 @@ class TestSimulation: assert output2[0] == 3 assert output3[0] == 28 assert output4[0] == 0 - assert output5[0] == 7 \ No newline at end of file + assert output5[0] == 7 + + def test_simple_filter(self, simple_filter): + input0 = np.array([1, 2, 3, 4, 5]) + simulation = Simulation(simple_filter, [input0], save_results=True) + output0 = [simulation.run()[0] for _ in range(len(input0))] + assert output0 == [0, 1.0, 2.5, 4.25, 6.125]