Skip to content
Snippets Groups Projects
Commit f69a4f3c authored by Adam Jakobsson's avatar Adam Jakobsson Committed by Felix Goding
Browse files

Refactor constructor so that Input signals and Output signals are connected to...

Refactor constructor so that Input signals and Output signals are connected to ports before traversal is started, that way edge cases of empty SFG's are easily handled
parent afe22efb
No related branches found
No related tags found
2 merge requests!67WIP: B-ASIC version 1.0.0 hotfix,!65B-ASIC version 1.0.0
...@@ -166,13 +166,14 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -166,13 +166,14 @@ class AbstractOperation(Operation, AbstractGraphComponent):
if input_sources is not None: if input_sources is not None:
source_count = len(input_sources) source_count = len(input_sources)
if source_count != input_count: if source_count != input_count:
raise ValueError(f"Operation expected {input_count} input sources but only got {source_count}") raise ValueError(
f"Operation expected {input_count} input sources but only got {source_count}")
for i, src in enumerate(input_sources): for i, src in enumerate(input_sources):
if src is not None: if src is not None:
self._input_ports[i].connect(src.source) self._input_ports[i].connect(src.source)
@abstractmethod @abstractmethod
def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ
"""Evaluate the operation and generate a list of output values given a """Evaluate the operation and generate a list of output values given a
list of input values. list of input values.
""" """
...@@ -246,11 +247,13 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -246,11 +247,13 @@ class AbstractOperation(Operation, AbstractGraphComponent):
result = self.evaluate(*input_values) result = self.evaluate(*input_values)
if isinstance(result, collections.Sequence): if isinstance(result, collections.Sequence):
if len(result) != self.output_count: if len(result) != self.output_count:
raise RuntimeError("Operation evaluated to incorrect number of outputs") raise RuntimeError(
"Operation evaluated to incorrect number of outputs")
return result return result
if isinstance(result, Number): if isinstance(result, Number):
if self.output_count != 1: if self.output_count != 1:
raise RuntimeError("Operation evaluated to incorrect number of outputs") raise RuntimeError(
"Operation evaluated to incorrect number of outputs")
return [result] return [result]
raise RuntimeError("Operation evaluated to invalid type") raise RuntimeError("Operation evaluated to invalid type")
...@@ -296,11 +299,13 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -296,11 +299,13 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def source(self) -> OutputPort: def source(self) -> OutputPort:
if self.output_count != 1: if self.output_count != 1:
diff = "more" if self.output_count > 1 else "less" diff = "more" if self.output_count > 1 else "less"
raise TypeError(f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output") raise TypeError(
f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output")
return self.output(0) return self.output(0)
def copy_unconnected(self) -> GraphComponent: def copy_unconnected(self) -> GraphComponent:
new_comp: AbstractOperation = super().copy_unconnected() new_comp: AbstractOperation = super().copy_unconnected()
for name, value in self.params.items(): for name, value in self.params.items():
new_comp.set_param(name, deepcopy(value)) # pylint: disable=no-member new_comp.set_param(name, deepcopy(
value)) # pylint: disable=no-member
return new_comp return new_comp
...@@ -8,6 +8,7 @@ from copy import copy ...@@ -8,6 +8,7 @@ from copy import copy
from typing import NewType, Optional, List, Iterable, TYPE_CHECKING from typing import NewType, Optional, List, Iterable, TYPE_CHECKING
from b_asic.signal import Signal from b_asic.signal import Signal
from b_asic.graph_component import Name
if TYPE_CHECKING: if TYPE_CHECKING:
from b_asic.operation import Operation from b_asic.operation import Operation
...@@ -144,22 +145,24 @@ class InputPort(AbstractPort): ...@@ -144,22 +145,24 @@ class InputPort(AbstractPort):
""" """
return None if self._source_signal is None else self._source_signal.source return None if self._source_signal is None else self._source_signal.source
def connect(self, src: SignalSourceProvider) -> Signal: def connect(self, src: SignalSourceProvider, name: Name = "") -> Signal:
"""Connect the provided signal source to this input port by creating a new signal. """Connect the provided signal source to this input port by creating a new signal.
Returns the new signal. Returns the new signal.
""" """
assert self._source_signal is None, "Attempted to connect already connected input port." assert self._source_signal is None, "Attempted to connect already connected input port."
return Signal(src.source, self) # self._source_signal is set by the signal constructor. # self._source_signal is set by the signal constructor.
return Signal(source=src.source, destination=self, name=name)
@property @property
def value_length(self) -> Optional[int]: def value_length(self) -> Optional[int]:
"""Get the number of bits that this port should truncate received values to.""" """Get the number of bits that this port should truncate received values to."""
return self._value_length return self._value_length
@value_length.setter @value_length.setter
def value_length(self, bits: Optional[int]) -> None: def value_length(self, bits: Optional[int]) -> None:
"""Set the number of bits that this port should truncate received values to.""" """Set the number of bits that this port should truncate received values to."""
assert bits is None or (isinstance(bits, int) and bits >= 0), "Value length must be non-negative." assert bits is None or (isinstance(
bits, int) and bits >= 0), "Value length must be non-negative."
self._value_length = bits self._value_length = bits
...@@ -185,7 +188,7 @@ class OutputPort(AbstractPort, SignalSourceProvider): ...@@ -185,7 +188,7 @@ class OutputPort(AbstractPort, SignalSourceProvider):
def add_signal(self, signal: Signal) -> None: def add_signal(self, signal: Signal) -> None:
assert signal not in self._destination_signals, "Attempted to add already connected signal." assert signal not in self._destination_signals, "Attempted to add already connected signal."
self._destination_signals.append(signal) self._destination_signals.append(signal)
signal.set_source(self) signal.set_source(self)
def remove_signal(self, signal: Signal) -> None: def remove_signal(self, signal: Signal) -> None:
assert signal in self._destination_signals, "Attempted to remove already removed signal." assert signal in self._destination_signals, "Attempted to remove already removed signal."
...@@ -195,7 +198,7 @@ class OutputPort(AbstractPort, SignalSourceProvider): ...@@ -195,7 +198,7 @@ class OutputPort(AbstractPort, SignalSourceProvider):
def clear(self) -> None: def clear(self) -> None:
for signal in copy(self._destination_signals): for signal in copy(self._destination_signals):
self.remove_signal(signal) self.remove_signal(signal)
@property @property
def source(self) -> "OutputPort": def source(self) -> "OutputPort":
return self return self
\ No newline at end of file
This diff is collapsed.
"""
B-ASIC test suite for printing a SFG
"""
from b_asic.signal_flow_graph import SFG
from b_asic.core_operations import Addition, Multiplication, Constant, ConstantAddition
from b_asic.port import InputPort, OutputPort
from b_asic.signal import Signal
from b_asic.special_operations import Input, Output
import pytest
class TestPrintSfg:
def test_print_one_addition(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
add1 = Addition(inp1, inp2, "ADD1")
out1 = Output(add1, "OUT1")
sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="sf1")
assert sfg.__str__() == ("id: add1, name: ADD1, input: [s1, s2], output: [s3]\nid: in1, name: INP1, input: [], output: [s1]\nid: in2, name: INP2, input: [], output: [s2]\nid: out1, name: OUT1, input: [s3], output: []\n")
def test_print_add_mul(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
inp3 = Input("INP3")
add1 = Addition(inp1, inp2, "ADD1")
mul1 = Multiplication(add1, inp3, "MUL1")
out1 = Output(mul1, "OUT1")
sfg = SFG(inputs=[inp1, inp2, inp3], outputs=[out1], name="mac_sfg")
assert sfg.__str__() == ("id: add1, name: ADD1, input: [s1, s2], output: [s5]\nid: in1, name: INP1, input: [], output: [s1]\nid: in2, name: INP2, input: [], output: [s2]\nid: mul1, name: MUL1, input: [s5, s3], output: [s4]\nid: in3, name: INP3, input: [], output: [s3]\nid: out1, name: OUT1, input: [s4], output: []\n")
def test_print_constant(self):
inp1 = Input("INP1")
const1 = Constant(3, "CONST")
add1 = Addition(const1, inp1, "ADD1")
out1 = Output(add1, "OUT1")
sfg = SFG(inputs=[inp1], outputs=[out1], name="sfg")
assert sfg.__str__() == ("id: add1, name: ADD1, input: [s3, s1], output: [s2]\nid: c1, name: CONST, value: 3, input: [], output: [s3]\nid: in1, name: INP1, input: [], output: [s1]\nid: out1, name: OUT1, input: [s2], output: []\n")
\ No newline at end of file
from b_asic import SFG from b_asic import SFG
from b_asic.signal import Signal from b_asic.signal import Signal
from b_asic.core_operations import Addition, Constant from b_asic.core_operations import Addition, Constant, Multiplication
from b_asic.special_operations import Input, Output from b_asic.special_operations import Input, Output
class TestConstructor: class TestConstructor:
def test_direct_input_to_output_sfg_construction(self):
inp = Input("INP1")
out = Output(None, "OUT1")
out.input(0).connect(inp, "S1")
sfg = SFG(inputs=[inp], outputs=[out])
assert len(list(sfg.components)) == 3
assert sfg.input_count == 1
assert sfg.output_count == 1
def test_same_signal_input_and_output_sfg_construction(self):
add1 = Addition(None, None, "ADD1")
add2 = Addition(None, None, "ADD2")
sig1 = add2.input(0).connect(add1, "S1")
sfg = SFG(input_signals=[sig1], output_signals=[sig1])
assert len(list(sfg.components)) == 3
assert sfg.input_count == 1
assert sfg.output_count == 1
def test_outputs_construction(self, operation_tree): def test_outputs_construction(self, operation_tree):
outp = Output(operation_tree) outp = Output(operation_tree)
sfg = SFG(outputs=[outp]) sfg = SFG(outputs=[outp])
...@@ -20,13 +44,73 @@ class TestConstructor: ...@@ -20,13 +44,73 @@ class TestConstructor:
assert sfg.input_count == 0 assert sfg.input_count == 0
assert sfg.output_count == 1 assert sfg.output_count == 1
def test_operations_construction(self, operation_tree):
sfg1 = SFG(operations=[operation_tree])
sfg2 = SFG(operations=[operation_tree.input(1).signals[0].source.operation])
assert len(list(sfg1.components)) == 5 class TestDeepCopy:
assert len(list(sfg2.components)) == 5 def test_deep_copy_no_duplicates(self):
assert sfg1.input_count == 0 inp1 = Input("INP1")
assert sfg2.input_count == 0 inp2 = Input("INP2")
assert sfg1.output_count == 0 inp3 = Input("INP3")
assert sfg2.output_count == 0 add1 = Addition(inp1, inp2, "ADD1")
mul1 = Multiplication(add1, inp3, "MUL1")
out1 = Output(mul1, "OUT1")
mac_sfg = SFG(inputs=[inp1, inp2],
outputs=[out1], name="mac_sfg")
mac_sfg_deep_copy = mac_sfg.deep_copy()
for g_id, component in mac_sfg._components_by_id.items():
component_copy = mac_sfg_deep_copy.find_by_id(g_id)
assert component.name == component_copy.name
def test_deep_copy(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
inp3 = Input("INP3")
add1 = Addition(None, None, "ADD1")
add2 = Addition(None, None, "ADD2")
mul1 = Multiplication(None, None, "MUL1")
out1 = Output(None, "OUT1")
add1.input(0).connect(inp1, "S1")
add1.input(1).connect(inp2, "S2")
add2.input(0).connect(add1, "S4")
add2.input(1).connect(inp3, "S3")
mul1.input(0).connect(add1, "S5")
mul1.input(1).connect(add2, "S6")
out1.input(0).connect(mul1, "S7")
mac_sfg = SFG(inputs=[inp1, inp2],
outputs=[out1], name="mac_sfg")
mac_sfg_deep_copy = mac_sfg.deep_copy()
for g_id, component in mac_sfg._components_by_id.items():
component_copy = mac_sfg_deep_copy.find_by_id(g_id)
assert component.name == component_copy.name
class TestComponents:
def test_advanced_components(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
inp3 = Input("INP3")
add1 = Addition(None, None, "ADD1")
add2 = Addition(None, None, "ADD2")
mul1 = Multiplication(None, None, "MUL1")
out1 = Output(None, "OUT1")
add1.input(0).connect(inp1, "S1")
add1.input(1).connect(inp2, "S2")
add2.input(0).connect(add1, "S4")
add2.input(1).connect(inp3, "S3")
mul1.input(0).connect(add1, "S5")
mul1.input(1).connect(add2, "S6")
out1.input(0).connect(mul1, "S7")
mac_sfg = SFG(inputs=[inp1, inp2],
outputs=[out1], name="mac_sfg")
assert set([comp.name for comp in mac_sfg.components]) == {
"INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1", "S1", "S2", "S3", "S4", "S5", "S6", "S7"}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment