Skip to content
Snippets Groups Projects
Commit b8647143 authored by Felix Goding's avatar Felix Goding
Browse files

Merge branch '19-print-sfg' into 'develop'

Resolve "Print SFG"

See merge request PUM_TDDD96/B-ASIC!21
parents afe22efb f69a4f3c
No related branches found
No related tags found
3 merge requests!67WIP: B-ASIC version 1.0.0 hotfix,!65B-ASIC version 1.0.0,!21Resolve "Print SFG"
Pipeline #13300 passed
......@@ -166,13 +166,14 @@ class AbstractOperation(Operation, AbstractGraphComponent):
if input_sources is not None:
source_count = len(input_sources)
if source_count != input_count:
raise ValueError(f"Operation expected {input_count} input sources but only got {source_count}")
raise ValueError(
f"Operation expected {input_count} input sources but only got {source_count}")
for i, src in enumerate(input_sources):
if src is not None:
self._input_ports[i].connect(src.source)
@abstractmethod
def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ
def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ
"""Evaluate the operation and generate a list of output values given a
list of input values.
"""
......@@ -246,11 +247,13 @@ class AbstractOperation(Operation, AbstractGraphComponent):
result = self.evaluate(*input_values)
if isinstance(result, collections.Sequence):
if len(result) != self.output_count:
raise RuntimeError("Operation evaluated to incorrect number of outputs")
raise RuntimeError(
"Operation evaluated to incorrect number of outputs")
return result
if isinstance(result, Number):
if self.output_count != 1:
raise RuntimeError("Operation evaluated to incorrect number of outputs")
raise RuntimeError(
"Operation evaluated to incorrect number of outputs")
return [result]
raise RuntimeError("Operation evaluated to invalid type")
......@@ -296,11 +299,13 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def source(self) -> OutputPort:
if self.output_count != 1:
diff = "more" if self.output_count > 1 else "less"
raise TypeError(f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output")
raise TypeError(
f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output")
return self.output(0)
def copy_unconnected(self) -> GraphComponent:
new_comp: AbstractOperation = super().copy_unconnected()
for name, value in self.params.items():
new_comp.set_param(name, deepcopy(value)) # pylint: disable=no-member
new_comp.set_param(name, deepcopy(
value)) # pylint: disable=no-member
return new_comp
......@@ -8,6 +8,7 @@ from copy import copy
from typing import NewType, Optional, List, Iterable, TYPE_CHECKING
from b_asic.signal import Signal
from b_asic.graph_component import Name
if TYPE_CHECKING:
from b_asic.operation import Operation
......@@ -144,22 +145,24 @@ class InputPort(AbstractPort):
"""
return None if self._source_signal is None else self._source_signal.source
def connect(self, src: SignalSourceProvider) -> Signal:
def connect(self, src: SignalSourceProvider, name: Name = "") -> Signal:
"""Connect the provided signal source to this input port by creating a new signal.
Returns the new signal.
"""
assert self._source_signal is None, "Attempted to connect already connected input port."
return Signal(src.source, self) # self._source_signal is set by the signal constructor.
# self._source_signal is set by the signal constructor.
return Signal(source=src.source, destination=self, name=name)
@property
def value_length(self) -> Optional[int]:
"""Get the number of bits that this port should truncate received values to."""
return self._value_length
@value_length.setter
def value_length(self, bits: Optional[int]) -> None:
"""Set the number of bits that this port should truncate received values to."""
assert bits is None or (isinstance(bits, int) and bits >= 0), "Value length must be non-negative."
assert bits is None or (isinstance(
bits, int) and bits >= 0), "Value length must be non-negative."
self._value_length = bits
......@@ -185,7 +188,7 @@ class OutputPort(AbstractPort, SignalSourceProvider):
def add_signal(self, signal: Signal) -> None:
assert signal not in self._destination_signals, "Attempted to add already connected signal."
self._destination_signals.append(signal)
signal.set_source(self)
signal.set_source(self)
def remove_signal(self, signal: Signal) -> None:
assert signal in self._destination_signals, "Attempted to remove already removed signal."
......@@ -195,7 +198,7 @@ class OutputPort(AbstractPort, SignalSourceProvider):
def clear(self) -> None:
for signal in copy(self._destination_signals):
self.remove_signal(signal)
@property
def source(self) -> "OutputPort":
return self
\ No newline at end of file
return self
This diff is collapsed.
"""
B-ASIC test suite for printing a SFG
"""
from b_asic.signal_flow_graph import SFG
from b_asic.core_operations import Addition, Multiplication, Constant, ConstantAddition
from b_asic.port import InputPort, OutputPort
from b_asic.signal import Signal
from b_asic.special_operations import Input, Output
import pytest
class TestPrintSfg:
def test_print_one_addition(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
add1 = Addition(inp1, inp2, "ADD1")
out1 = Output(add1, "OUT1")
sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="sf1")
assert sfg.__str__() == ("id: add1, name: ADD1, input: [s1, s2], output: [s3]\nid: in1, name: INP1, input: [], output: [s1]\nid: in2, name: INP2, input: [], output: [s2]\nid: out1, name: OUT1, input: [s3], output: []\n")
def test_print_add_mul(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
inp3 = Input("INP3")
add1 = Addition(inp1, inp2, "ADD1")
mul1 = Multiplication(add1, inp3, "MUL1")
out1 = Output(mul1, "OUT1")
sfg = SFG(inputs=[inp1, inp2, inp3], outputs=[out1], name="mac_sfg")
assert sfg.__str__() == ("id: add1, name: ADD1, input: [s1, s2], output: [s5]\nid: in1, name: INP1, input: [], output: [s1]\nid: in2, name: INP2, input: [], output: [s2]\nid: mul1, name: MUL1, input: [s5, s3], output: [s4]\nid: in3, name: INP3, input: [], output: [s3]\nid: out1, name: OUT1, input: [s4], output: []\n")
def test_print_constant(self):
inp1 = Input("INP1")
const1 = Constant(3, "CONST")
add1 = Addition(const1, inp1, "ADD1")
out1 = Output(add1, "OUT1")
sfg = SFG(inputs=[inp1], outputs=[out1], name="sfg")
assert sfg.__str__() == ("id: add1, name: ADD1, input: [s3, s1], output: [s2]\nid: c1, name: CONST, value: 3, input: [], output: [s3]\nid: in1, name: INP1, input: [], output: [s1]\nid: out1, name: OUT1, input: [s2], output: []\n")
\ No newline at end of file
from b_asic import SFG
from b_asic.signal import Signal
from b_asic.core_operations import Addition, Constant
from b_asic.core_operations import Addition, Constant, Multiplication
from b_asic.special_operations import Input, Output
class TestConstructor:
def test_direct_input_to_output_sfg_construction(self):
inp = Input("INP1")
out = Output(None, "OUT1")
out.input(0).connect(inp, "S1")
sfg = SFG(inputs=[inp], outputs=[out])
assert len(list(sfg.components)) == 3
assert sfg.input_count == 1
assert sfg.output_count == 1
def test_same_signal_input_and_output_sfg_construction(self):
add1 = Addition(None, None, "ADD1")
add2 = Addition(None, None, "ADD2")
sig1 = add2.input(0).connect(add1, "S1")
sfg = SFG(input_signals=[sig1], output_signals=[sig1])
assert len(list(sfg.components)) == 3
assert sfg.input_count == 1
assert sfg.output_count == 1
def test_outputs_construction(self, operation_tree):
outp = Output(operation_tree)
sfg = SFG(outputs=[outp])
......@@ -20,13 +44,73 @@ class TestConstructor:
assert sfg.input_count == 0
assert sfg.output_count == 1
def test_operations_construction(self, operation_tree):
sfg1 = SFG(operations=[operation_tree])
sfg2 = SFG(operations=[operation_tree.input(1).signals[0].source.operation])
assert len(list(sfg1.components)) == 5
assert len(list(sfg2.components)) == 5
assert sfg1.input_count == 0
assert sfg2.input_count == 0
assert sfg1.output_count == 0
assert sfg2.output_count == 0
class TestDeepCopy:
def test_deep_copy_no_duplicates(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
inp3 = Input("INP3")
add1 = Addition(inp1, inp2, "ADD1")
mul1 = Multiplication(add1, inp3, "MUL1")
out1 = Output(mul1, "OUT1")
mac_sfg = SFG(inputs=[inp1, inp2],
outputs=[out1], name="mac_sfg")
mac_sfg_deep_copy = mac_sfg.deep_copy()
for g_id, component in mac_sfg._components_by_id.items():
component_copy = mac_sfg_deep_copy.find_by_id(g_id)
assert component.name == component_copy.name
def test_deep_copy(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
inp3 = Input("INP3")
add1 = Addition(None, None, "ADD1")
add2 = Addition(None, None, "ADD2")
mul1 = Multiplication(None, None, "MUL1")
out1 = Output(None, "OUT1")
add1.input(0).connect(inp1, "S1")
add1.input(1).connect(inp2, "S2")
add2.input(0).connect(add1, "S4")
add2.input(1).connect(inp3, "S3")
mul1.input(0).connect(add1, "S5")
mul1.input(1).connect(add2, "S6")
out1.input(0).connect(mul1, "S7")
mac_sfg = SFG(inputs=[inp1, inp2],
outputs=[out1], name="mac_sfg")
mac_sfg_deep_copy = mac_sfg.deep_copy()
for g_id, component in mac_sfg._components_by_id.items():
component_copy = mac_sfg_deep_copy.find_by_id(g_id)
assert component.name == component_copy.name
class TestComponents:
def test_advanced_components(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
inp3 = Input("INP3")
add1 = Addition(None, None, "ADD1")
add2 = Addition(None, None, "ADD2")
mul1 = Multiplication(None, None, "MUL1")
out1 = Output(None, "OUT1")
add1.input(0).connect(inp1, "S1")
add1.input(1).connect(inp2, "S2")
add2.input(0).connect(add1, "S4")
add2.input(1).connect(inp3, "S3")
mul1.input(0).connect(add1, "S5")
mul1.input(1).connect(add2, "S6")
out1.input(0).connect(mul1, "S7")
mac_sfg = SFG(inputs=[inp1, inp2],
outputs=[out1], name="mac_sfg")
assert set([comp.name for comp in mac_sfg.components]) == {
"INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1", "S1", "S2", "S3", "S4", "S5", "S6", "S7"}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment