From bb6605f0407c66de130f9f8cfff41be5b65dc4bb Mon Sep 17 00:00:00 2001 From: Angus Lothian <anglo547@student.liu.se> Date: Tue, 28 Apr 2020 19:17:46 +0200 Subject: [PATCH] Change usage of type_name from property to function and add function for returning all components of the entered type --- b_asic/__init__.py | 2 - b_asic/core_operations.py | 85 +++++----- b_asic/graph_component.py | 10 +- b_asic/operation.py | 97 ++++++----- b_asic/port.py | 2 +- b_asic/precedence_chart.py | 21 --- b_asic/schema.py | 21 --- b_asic/signal.py | 11 +- b_asic/signal_flow_graph.py | 278 +++++++++++++++++++++---------- b_asic/simulation.py | 21 ++- b_asic/special_operations.py | 38 +++-- test/test_core_operations.py | 11 ++ test/test_depends.py | 19 --- test/test_sfg.py | 307 ++++++++++++++++++++++++++++++----- 14 files changed, 620 insertions(+), 303 deletions(-) delete mode 100644 b_asic/precedence_chart.py delete mode 100644 b_asic/schema.py delete mode 100644 test/test_depends.py diff --git a/b_asic/__init__.py b/b_asic/__init__.py index bd3574ba..b35d0c1b 100644 --- a/b_asic/__init__.py +++ b/b_asic/__init__.py @@ -5,9 +5,7 @@ TODO: More info. from b_asic.core_operations import * from b_asic.graph_component import * from b_asic.operation import * -from b_asic.precedence_chart import * from b_asic.port import * -from b_asic.schema import * from b_asic.signal_flow_graph import * from b_asic.signal import * from b_asic.simulation import * diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index 296803e3..3e6cd787 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -18,16 +18,16 @@ class Constant(AbstractOperation): """ def __init__(self, value: Number = 0, name: Name = ""): - super().__init__(input_count = 0, output_count = 1, name = name) + super().__init__(input_count=0, output_count=1, name=name) self.set_param("value", value) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "c" def evaluate(self): return self.param("value") - + @property def value(self) -> Number: """Get the constant value of this operation.""" @@ -45,10 +45,11 @@ class Addition(AbstractOperation): """ def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1]) + super().__init__(input_count=2, output_count=1, + name=name, input_sources=[src0, src1]) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "add" def evaluate(self, a, b): @@ -61,10 +62,11 @@ class Subtraction(AbstractOperation): """ def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1]) + super().__init__(input_count=2, output_count=1, + name=name, input_sources=[src0, src1]) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "sub" def evaluate(self, a, b): @@ -77,10 +79,11 @@ class Multiplication(AbstractOperation): """ def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1]) + super().__init__(input_count=2, output_count=1, + name=name, input_sources=[src0, src1]) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "mul" def evaluate(self, a, b): @@ -93,10 +96,11 @@ class Division(AbstractOperation): """ def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1]) + super().__init__(input_count=2, output_count=1, + name=name, input_sources=[src0, src1]) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "div" def evaluate(self, a, b): @@ -109,10 +113,11 @@ class Min(AbstractOperation): """ def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1]) + super().__init__(input_count=2, output_count=1, + name=name, input_sources=[src0, src1]) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "min" def evaluate(self, a, b): @@ -127,10 +132,11 @@ class Max(AbstractOperation): """ def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1]) + super().__init__(input_count=2, output_count=1, + name=name, input_sources=[src0, src1]) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "max" def evaluate(self, a, b): @@ -145,10 +151,11 @@ class SquareRoot(AbstractOperation): """ def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) + super().__init__(input_count=1, output_count=1, + name=name, input_sources=[src0]) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "sqrt" def evaluate(self, a): @@ -161,10 +168,11 @@ class ComplexConjugate(AbstractOperation): """ def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) + super().__init__(input_count=1, output_count=1, + name=name, input_sources=[src0]) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "conj" def evaluate(self, a): @@ -177,10 +185,11 @@ class Absolute(AbstractOperation): """ def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) + super().__init__(input_count=1, output_count=1, + name=name, input_sources=[src0]) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "abs" def evaluate(self, a): @@ -193,11 +202,12 @@ class ConstantMultiplication(AbstractOperation): """ def __init__(self, value: Number = 0, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) + super().__init__(input_count=1, output_count=1, + name=name, input_sources=[src0]) self.set_param("value", value) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "cmul" def evaluate(self, a): @@ -221,10 +231,11 @@ class Butterfly(AbstractOperation): """ def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 2, output_count = 2, name = name, input_sources = [src0, src1]) + super().__init__(input_count=2, output_count=2, + name=name, input_sources=[src0, src1]) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "bfly" def evaluate(self, a, b): diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py index e3799701..5efa8038 100644 --- a/b_asic/graph_component.py +++ b/b_asic/graph_component.py @@ -20,9 +20,9 @@ class GraphComponent(ABC): TODO: More info. """ - @property + @classmethod @abstractmethod - def type_name(self) -> TypeName: + def type_name(cls) -> TypeName: """Get the type name of this graph component""" raise NotImplementedError @@ -112,7 +112,7 @@ class AbstractGraphComponent(GraphComponent): @name.setter def name(self, name: Name) -> None: self._name = name - + @property def graph_id(self) -> GraphID: return self._graph_id @@ -136,7 +136,7 @@ class AbstractGraphComponent(GraphComponent): new_component.name = copy(self.name) new_component.graph_id = copy(self.graph_id) for name, value in self.params.items(): - new_component.set_param(copy(name), deepcopy(value)) # pylint: disable=no-member + new_component.set_param(copy(name), deepcopy(value)) # pylint: disable=no-member return new_component def traverse(self) -> Generator[GraphComponent, None, None]: @@ -149,4 +149,4 @@ class AbstractGraphComponent(GraphComponent): for neighbor in component.neighbors: if neighbor not in visited: visited.add(neighbor) - fontier.append(neighbor) \ No newline at end of file + fontier.append(neighbor) diff --git a/b_asic/operation.py b/b_asic/operation.py index 92c7b2b0..f8ac22e2 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -14,11 +14,12 @@ from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name from b_asic.port import SignalSourceProvider, InputPort, OutputPort from b_asic.signal import Signal -ResultKey = NewType("ResultKey", str) -ResultMap = Mapping[ResultKey, Optional[Number]] -MutableResultMap = MutableMapping[ResultKey, Optional[Number]] -RegisterMap = Mapping[ResultKey, Number] -MutableRegisterMap = MutableMapping[ResultKey, Number] +OutputKey = NewType("OutputKey", str) +OutputMap = Mapping[OutputKey, Optional[Number]] +MutableOutputMap = MutableMapping[OutputKey, Optional[Number]] +RegisterMap = Mapping[OutputKey, Number] +MutableRegisterMap = MutableMapping[OutputKey, Number] + class Operation(GraphComponent, SignalSourceProvider): """Operation interface. @@ -134,9 +135,9 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def key(self, index: int, prefix: str = "") -> ResultKey: - """Get the key used to access the result of a certain output of this operation - from the results parameter passed to current_output(s) or evaluate_output(s). + def key(self, index: int, prefix: str = "") -> OutputKey: + """Get the key used to access the output of a certain output of this operation + from the output parameter passed to current_output(s) or evaluate_output(s). """ raise NotImplementedError @@ -150,7 +151,7 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number: + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableOutputMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number: """Evaluate the output at the given index of this operation with the given input values. The results parameter will be used to store any results (including intermediate results) for caching. The registers parameter will be used to get the current value of any intermediate registers that are encountered, and be updated with their new values. @@ -167,7 +168,7 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def evaluate_outputs(self, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Sequence[Number]: + def evaluate_outputs(self, input_values: Sequence[Number], results: Optional[MutableOutputMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Sequence[Number]: """Evaluate all outputs of this operation given the input values. See evaluate_output for more information. """ @@ -196,55 +197,65 @@ class AbstractOperation(Operation, AbstractGraphComponent): def __init__(self, input_count: int, output_count: int, name: Name = "", input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None): super().__init__(name) - self._input_ports = [InputPort(self, i) for i in range(input_count)] # Allocate input ports. - self._output_ports = [OutputPort(self, i) for i in range(output_count)] # Allocate output ports. + + self._input_ports = [InputPort(self, i) for i in range(input_count)] + self._output_ports = [OutputPort(self, i) for i in range(output_count)] # Connect given input sources, if any. if input_sources is not None: source_count = len(input_sources) if source_count != input_count: - raise ValueError(f"Wrong number of input sources supplied to Operation (expected {input_count}, got {source_count})") + raise ValueError( + f"Wrong number of input sources supplied to Operation (expected {input_count}, got {source_count})") for i, src in enumerate(input_sources): if src is not None: self._input_ports[i].connect(src.source) @abstractmethod - def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ + def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ """Evaluate the operation and generate a list of output values given a list of input values.""" raise NotImplementedError def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": - from b_asic.core_operations import Constant, Addition # Import here to avoid circular imports. + # Import here to avoid circular imports. + from b_asic.core_operations import Constant, Addition return Addition(self, Constant(src) if isinstance(src, Number) else src) - + def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": - from b_asic.core_operations import Constant, Addition # Import here to avoid circular imports. + # Import here to avoid circular imports. + from b_asic.core_operations import Constant, Addition return Addition(Constant(src) if isinstance(src, Number) else src, self) def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": - from b_asic.core_operations import Constant, Subtraction # Import here to avoid circular imports. + # Import here to avoid circular imports. + from b_asic.core_operations import Constant, Subtraction return Subtraction(self, Constant(src) if isinstance(src, Number) else src) - + def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": - from b_asic.core_operations import Constant, Subtraction # Import here to avoid circular imports. + # Import here to avoid circular imports. + from b_asic.core_operations import Constant, Subtraction return Subtraction(Constant(src) if isinstance(src, Number) else src, self) def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": - from b_asic.core_operations import Multiplication, ConstantMultiplication # Import here to avoid circular imports. + # Import here to avoid circular imports. + from b_asic.core_operations import Multiplication, ConstantMultiplication return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(self, src) - + def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": - from b_asic.core_operations import Multiplication, ConstantMultiplication # Import here to avoid circular imports. + # Import here to avoid circular imports. + from b_asic.core_operations import Multiplication, ConstantMultiplication return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(src, self) def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": - from b_asic.core_operations import Constant, Division # Import here to avoid circular imports. + # Import here to avoid circular imports. + from b_asic.core_operations import Constant, Division return Division(self, Constant(src) if isinstance(src, Number) else src) - + def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": - from b_asic.core_operations import Constant, Division # Import here to avoid circular imports. + # Import here to avoid circular imports. + from b_asic.core_operations import Constant, Division return Division(Constant(src) if isinstance(src, Number) else src, self) - + @property def input_count(self) -> int: return len(self._input_ports) @@ -282,8 +293,8 @@ class AbstractOperation(Operation, AbstractGraphComponent): for s in p.signals: result.append(s) return result - - def key(self, index: int, prefix: str = "") -> ResultKey: + + def key(self, index: int, prefix: str = "") -> OutputKey: key = prefix if self.output_count != 1: if key: @@ -292,15 +303,17 @@ class AbstractOperation(Operation, AbstractGraphComponent): elif not key: key = str(index) return key - + def current_output(self, index: int, registers: Optional[RegisterMap] = None, prefix: str = "") -> Optional[Number]: return None - def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number: + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableOutputMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number: if index < 0 or index >= self.output_count: - raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})") + raise IndexError( + f"Output index out of range (expected 0-{self.output_count - 1}, got {index})") if len(input_values) != self.input_count: - raise ValueError(f"Wrong number of input values supplied to operation (expected {self.input_count}, got {len(input_values)})") + raise ValueError( + f"Wrong number of input values supplied to operation (expected {self.input_count}, got {len(input_values)})") if results is None: results = {} if registers is None: @@ -309,13 +322,16 @@ class AbstractOperation(Operation, AbstractGraphComponent): values = self.evaluate(*self.truncate_inputs(input_values)) if isinstance(values, collections.abc.Sequence): if len(values) != self.output_count: - raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got {len(values)})") + raise RuntimeError( + f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got {len(values)})") elif isinstance(values, Number): if self.output_count != 1: - raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got 1)") + raise RuntimeError( + f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got 1)") values = (values,) else: - raise RuntimeError(f"Operation evaluated to invalid type (expected Sequence/Number, got {values.__class__.__name__})") + raise RuntimeError( + f"Operation evaluated to invalid type (expected Sequence/Number, got {values.__class__.__name__})") if self.output_count == 1: results[self.key(index, prefix)] = values[index] @@ -327,7 +343,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): def current_outputs(self, registers: Optional[RegisterMap] = None, prefix: str = "") -> Sequence[Optional[Number]]: return [self.current_output(i, registers, prefix) for i in range(self.output_count)] - def evaluate_outputs(self, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Sequence[Number]: + def evaluate_outputs(self, input_values: Sequence[Number], results: Optional[MutableOutputMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Sequence[Number]: return [self.evaluate_output(i, input_values, results, registers, prefix) for i in range(self.output_count)] def split(self) -> Iterable[Operation]: @@ -348,7 +364,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): def inputs_required_for_output(self, output_index: int) -> Iterable[int]: if output_index < 0 or output_index >= self.output_count: raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") - return [i for i in range(self.input_count)] # By default, assume each output depends on all inputs. + return [i for i in range(self.input_count)] # By default, assume each output depends on all inputs. @property def neighbors(self) -> Iterable[GraphComponent]: @@ -361,7 +377,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): raise TypeError( f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output") return self.output(0) - + def truncate_input(self, index: int, value: Number, bits: int) -> Number: """Truncate the value to be used as input at the given index to a certain bit length.""" n = value @@ -379,7 +395,8 @@ class AbstractOperation(Operation, AbstractGraphComponent): args.append(input_values[i]) else: if isinstance(input_values[i], complex): - raise TypeError("Complex value cannot be truncated to {bits} bits as requested by the signal connected to input #{i}") + raise TypeError( + "Complex value cannot be truncated to {bits} bits as requested by the signal connected to input #{i}") args.append(self.truncate_input(i, input_values[i], bits)) else: args.append(input_values[i]) diff --git a/b_asic/port.py b/b_asic/port.py index e8c007cb..59a218d9 100644 --- a/b_asic/port.py +++ b/b_asic/port.py @@ -5,7 +5,7 @@ TODO: More info. from abc import ABC, abstractmethod from copy import copy -from typing import NewType, Optional, List, Iterable, TYPE_CHECKING +from typing import Optional, List, Iterable, TYPE_CHECKING from b_asic.signal import Signal from b_asic.graph_component import Name diff --git a/b_asic/precedence_chart.py b/b_asic/precedence_chart.py deleted file mode 100644 index be55a123..00000000 --- a/b_asic/precedence_chart.py +++ /dev/null @@ -1,21 +0,0 @@ -"""@package docstring -B-ASIC Precedence Chart Module. -TODO: More info. -""" - -from b_asic.signal_flow_graph import SFG - - -class PrecedenceChart: - """Precedence chart constructed from a signal flow graph. - TODO: More info. - """ - - sfg: SFG - # TODO: More members. - - def __init__(self, sfg: SFG): - self.sfg = sfg - # TODO: Implement. - - # TODO: More stuff. diff --git a/b_asic/schema.py b/b_asic/schema.py deleted file mode 100644 index e5068cdc..00000000 --- a/b_asic/schema.py +++ /dev/null @@ -1,21 +0,0 @@ -"""@package docstring -B-ASIC Schema Module. -TODO: More info. -""" - -from b_asic.precedence_chart import PrecedenceChart - - -class Schema: - """Schema constructed from a precedence chart. - TODO: More info. - """ - - pc: PrecedenceChart - # TODO: More members. - - def __init__(self, pc: PrecedenceChart): - self.pc = pc - # TODO: Implement. - - # TODO: More stuff. diff --git a/b_asic/signal.py b/b_asic/signal.py index d322f161..24e8cc81 100644 --- a/b_asic/signal.py +++ b/b_asic/signal.py @@ -25,14 +25,14 @@ class Signal(AbstractGraphComponent): self.set_destination(destination) self.set_param("bits", bits) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "s" @property def neighbors(self) -> Iterable[GraphComponent]: return [p.operation for p in [self.source, self.destination] if p is not None] - + @property def source(self) -> Optional["OutputPort"]: """Return the source OutputPort of the signal.""" @@ -103,5 +103,6 @@ class Signal(AbstractGraphComponent): def bits(self, bits: Optional[int]) -> None: """Set the number of bits that operations using this signal as an input should truncate received values to. None = unlimited.""" - assert bits is None or (isinstance(bits, int) and bits >= 0), "Bits must be non-negative." - self.set_param("bits", bits) \ No newline at end of file + assert bits is None or (isinstance(bits, int) + and bits >= 0), "Bits must be non-negative." + self.set_param("bits", bits) diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 61449fc3..4dfb5ef3 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -8,10 +8,11 @@ from numbers import Number from collections import defaultdict, deque from b_asic.port import SignalSourceProvider, OutputPort -from b_asic.operation import Operation, AbstractOperation, ResultKey, RegisterMap, MutableResultMap, MutableRegisterMap +from b_asic.operation import Operation, AbstractOperation, MutableOutputMap, MutableRegisterMap from b_asic.signal import Signal from b_asic.graph_component import GraphID, GraphIDNumber, GraphComponent, Name, TypeName -from b_asic.special_operations import Input, Output +from b_asic.special_operations import Input, Output, Register +from b_asic.core_operations import Constant class GraphIDGenerator: @@ -30,7 +31,7 @@ class GraphIDGenerator: @property def id_number_offset(self) -> GraphIDNumber: """Get the graph id number offset of this generator.""" - return self._next_id_number.default_factory() # pylint: disable=not-callable + return self._next_id_number.default_factory() # pylint: disable=not-callable class SFG(AbstractOperation): @@ -48,18 +49,21 @@ class SFG(AbstractOperation): _original_components_to_new: MutableSet[GraphComponent] _original_input_signals_to_indices: Dict[Signal, int] _original_output_signals_to_indices: Dict[Signal, int] + _precedence_list: Optional[List[List[OutputPort]]] - def __init__(self, input_signals: Optional[Sequence[Signal]] = None, output_signals: Optional[Sequence[Signal]] = None, \ - inputs: Optional[Sequence[Input]] = None, outputs: Optional[Sequence[Output]] = None, \ - id_number_offset: GraphIDNumber = 0, name: Name = "", \ + def __init__(self, input_signals: Optional[Sequence[Signal]] = None, output_signals: Optional[Sequence[Signal]] = None, + inputs: Optional[Sequence[Input]] = None, outputs: Optional[Sequence[Output]] = None, + id_number_offset: GraphIDNumber = 0, name: Name = "", input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None): + input_signal_count = 0 if input_signals is None else len(input_signals) input_operation_count = 0 if inputs is None else len(inputs) - output_signal_count = 0 if output_signals is None else len(output_signals) + output_signal_count = 0 if output_signals is None else len( + output_signals) output_operation_count = 0 if outputs is None else len(outputs) - super().__init__(input_count = input_signal_count + input_operation_count, - output_count = output_signal_count + output_operation_count, - name = name, input_sources = input_sources) + super().__init__(input_count=input_signal_count + input_operation_count, + output_count=output_signal_count + output_operation_count, + name=name, input_sources=input_sources) self._components_by_id = dict() self._components_by_name = defaultdict(list) @@ -71,6 +75,7 @@ class SFG(AbstractOperation): self._original_components_to_new = {} self._original_input_signals_to_indices = {} self._original_output_signals_to_indices = {} + self._precedence_list = None # Setup input signals. if input_signals is not None: @@ -123,7 +128,8 @@ class SFG(AbstractOperation): new_signal = self._original_components_to_new[signal] else: # New signal has to be created. - new_signal = self._add_component_unconnected_copy(signal) + new_signal = self._add_component_unconnected_copy( + signal) new_signal.set_destination(new_output_op.input(0)) self._original_output_signals_to_indices[signal] = output_index @@ -138,13 +144,17 @@ class SFG(AbstractOperation): new_signal = self._original_components_to_new[signal] if new_signal.destination is None: if signal.destination is None: - raise ValueError(f"Input signal #{input_index} is missing destination in SFG") + raise ValueError( + f"Input signal #{input_index} is missing destination in SFG") if signal.destination.operation not in self._original_components_to_new: - self._add_operation_connected_tree_copy(signal.destination.operation) + self._add_operation_connected_tree_copy( + signal.destination.operation) elif new_signal.destination.operation in output_operations_set: # Add directly connected input to output to ordered list. - self._components_ordered.extend([new_signal.source.operation, new_signal, new_signal.destination.operation]) - self._operations_ordered.extend([new_signal.source.operation, new_signal.destination.operation]) + self._components_ordered.extend( + [new_signal.source.operation, new_signal, new_signal.destination.operation]) + self._operations_ordered.extend( + [new_signal.source.operation, new_signal.destination.operation]) # Search the graph inwards from each output signal. for signal, output_index in self._original_output_signals_to_indices.items(): @@ -152,13 +162,12 @@ class SFG(AbstractOperation): new_signal = self._original_components_to_new[signal] if new_signal.source is None: if signal.source is None: - raise ValueError(f"Output signal #{output_index} is missing source in SFG") + raise ValueError( + f"Output signal #{output_index} is missing source in SFG") if signal.source.operation not in self._original_components_to_new: - self._add_operation_connected_tree_copy(signal.source.operation) - - # Find dependencies. + self._add_operation_connected_tree_copy( + signal.source.operation) - def __str__(self) -> str: """Get a string representation of this SFG.""" output_string = "" @@ -167,16 +176,17 @@ class SFG(AbstractOperation): for key, value in self._components_by_id.items(): if value is component: output_string += "id: " + key + ", name: " - + if component.name != None: output_string += component.name + ", " else: output_string += "-, " - - if component.type_name is "c": - output_string += "value: " + str(component.value) + ", input: [" + + if isinstance(component, Constant): + output_string += "value: " + \ + str(component.value) + ", input: [" else: - output_string += "input: [" + output_string += "input: [" counter_input = 0 for input in component.inputs: @@ -185,7 +195,7 @@ class SFG(AbstractOperation): for key, value in self._components_by_id.items(): if value is signal: output_string += key + ", " - + if counter_input > 0: output_string = output_string[:-2] output_string += "], output: [" @@ -204,11 +214,11 @@ class SFG(AbstractOperation): def __call__(self, *src: Optional[SignalSourceProvider], name: Name = "") -> "SFG": """Get a new independent SFG instance that is identical to this SFG except without any of its external connections.""" - return SFG(inputs = self._input_operations, outputs = self._output_operations, - id_number_offset = self.id_number_offset, name = name, input_sources = src if src else None) + return SFG(inputs=self._input_operations, outputs=self._output_operations, + id_number_offset=self.id_number_offset, name=name, input_sources=src if src else None) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "sfg" def evaluate(self, *args): @@ -216,21 +226,24 @@ class SFG(AbstractOperation): n = len(result) return None if n == 0 else result[0] if n == 1 else result - def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number: + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableOutputMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number: if index < 0 or index >= self.output_count: - raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})") + raise IndexError( + f"Output index out of range (expected 0-{self.output_count - 1}, got {index})") if len(input_values) != self.input_count: - raise ValueError(f"Wrong number of inputs supplied to SFG for evaluation (expected {self.input_count}, got {len(input_values)})") + raise ValueError( + f"Wrong number of inputs supplied to SFG for evaluation (expected {self.input_count}, got {len(input_values)})") if results is None: results = {} if registers is None: registers = {} - + # Set the values of our input operations to the given input values. for op, arg in zip(self._input_operations, self.truncate_inputs(input_values)): op.value = arg - - value = self._evaluate_source(self._output_operations[index].input(0).signals[0].source, results, registers, prefix) + + value = self._evaluate_source(self._output_operations[index].input( + 0).signals[0].source, results, registers, prefix) results[self.key(index, prefix)] = value return value @@ -273,12 +286,36 @@ class SFG(AbstractOperation): def inputs_required_for_output(self, output_index: int) -> Iterable[int]: if output_index < 0 or output_index >= self.output_count: - raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") - return self._inputs_required_for_source(self._output_operations[output_index].input(0).signals[0].source, set()) - + raise IndexError( + f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") + + input_indexes_required = [] + sfg_input_operations_to_indexes = { + input_op: index for index, input_op in enumerate(self._input_operations)} + output_op = self._output_operations[output_index] + queue = deque([output_op]) + visited = set([output_op]) + while queue: + op = queue.popleft() + if isinstance(op, Input): + if op in sfg_input_operations_to_indexes: + input_indexes_required.append( + sfg_input_operations_to_indexes[op]) + del sfg_input_operations_to_indexes[op] + + for input_port in op.inputs: + for signal in input_port.signals: + if signal.source is not None: + new_op = signal.source.operation + if new_op not in visited: + queue.append(new_op) + visited.add(new_op) + + return input_indexes_required + def copy_component(self, *args, **kwargs) -> GraphComponent: - return super().copy_component(*args, **kwargs, inputs = self._input_operations, outputs = self._output_operations, - id_number_offset = self.id_number_offset, name = self.name) + return super().copy_component(*args, **kwargs, inputs=self._input_operations, outputs=self._output_operations, + id_number_offset=self.id_number_offset, name=self.name) @property def id_number_offset(self) -> GraphIDNumber: @@ -295,12 +332,28 @@ class SFG(AbstractOperation): """Get all operations of this graph in depth-first order.""" return self._operations_ordered + def get_components_with_type_name(self, type_name: TypeName) -> List[GraphComponent]: + """Get a list with all components in this graph with the specified type_name. + + Keyword arguments: + type_name: The type_name of the desired components. + """ + i = self.id_number_offset + 1 + components = [] + found_comp = self.find_by_id(type_name + str(i)) + while found_comp is not None: + components.append(found_comp) + i += 1 + found_comp = self.find_by_id(type_name + str(i)) + + return components + def find_by_id(self, graph_id: GraphID) -> Optional[GraphComponent]: """Find the graph component with the specified ID. Returns None if the component was not found. Keyword arguments: - graph_id: Graph ID of the desired component(s) + graph_id: Graph ID of the desired component. """ return self._components_by_id.get(graph_id, None) @@ -317,7 +370,7 @@ class SFG(AbstractOperation): assert original_component not in self._original_components_to_new, "Tried to add duplicate SFG component" new_component = original_component.copy_component() self._original_components_to_new[original_component] = new_component - new_id = self._graph_id_generator.next_id(new_component.type_name) + new_id = self._graph_id_generator.next_id(new_component.type_name()) new_component.graph_id = new_id self._components_by_id[new_id] = new_component self._components_by_name[new_component.name].append(new_component) @@ -346,28 +399,37 @@ class SFG(AbstractOperation): if original_signal in self._original_input_signals_to_indices: # New signal already created during first step of constructor. new_signal = self._original_components_to_new[original_signal] - new_signal.set_destination(new_op.input(original_input_port.index)) - self._components_ordered.extend([new_signal, new_signal.source.operation]) - self._operations_ordered.append(new_signal.source.operation) + new_signal.set_destination( + new_op.input(original_input_port.index)) + self._components_ordered.extend( + [new_signal, new_signal.source.operation]) + self._operations_ordered.append( + new_signal.source.operation) # Check if the signal has not been added before. elif original_signal not in self._original_components_to_new: if original_signal.source is None: - raise ValueError("Dangling signal without source in SFG") - - new_signal = self._add_component_unconnected_copy(original_signal) - new_signal.set_destination(new_op.input(original_input_port.index)) + raise ValueError( + "Dangling signal without source in SFG") + + new_signal = self._add_component_unconnected_copy( + original_signal) + new_signal.set_destination( + new_op.input(original_input_port.index)) self._components_ordered.append(new_signal) original_connected_op = original_signal.source.operation # Check if connected Operation has been added before. if original_connected_op in self._original_components_to_new: # Set source to the already added operations port. - new_signal.set_source(self._original_components_to_new[original_connected_op].output(original_signal.source.index)) + new_signal.set_source(self._original_components_to_new[original_connected_op].output( + original_signal.source.index)) else: # Create new operation, set signal source to it. - new_connected_op = self._add_component_unconnected_copy(original_connected_op) - new_signal.set_source(new_connected_op.output(original_signal.source.index)) + new_connected_op = self._add_component_unconnected_copy( + original_connected_op) + new_signal.set_source(new_connected_op.output( + original_signal.source.index)) self._components_ordered.append(new_connected_op) self._operations_ordered.append(new_connected_op) @@ -381,28 +443,37 @@ class SFG(AbstractOperation): if original_signal in self._original_output_signals_to_indices: # New signal already created during first step of constructor. new_signal = self._original_components_to_new[original_signal] - new_signal.set_source(new_op.output(original_output_port.index)) - self._components_ordered.extend([new_signal, new_signal.destination.operation]) - self._operations_ordered.append(new_signal.destination.operation) + new_signal.set_source( + new_op.output(original_output_port.index)) + self._components_ordered.extend( + [new_signal, new_signal.destination.operation]) + self._operations_ordered.append( + new_signal.destination.operation) # Check if signal has not been added before. elif original_signal not in self._original_components_to_new: if original_signal.source is None: - raise ValueError("Dangling signal without source in SFG") + raise ValueError( + "Dangling signal without source in SFG") - new_signal = self._add_component_unconnected_copy(original_signal) - new_signal.set_source(new_op.output(original_output_port.index)) + new_signal = self._add_component_unconnected_copy( + original_signal) + new_signal.set_source( + new_op.output(original_output_port.index)) self._components_ordered.append(new_signal) original_connected_op = original_signal.destination.operation # Check if connected operation has been added. if original_connected_op in self._original_components_to_new: # Set destination to the already connected operations port. - new_signal.set_destination(self._original_components_to_new[original_connected_op].input(original_signal.destination.index)) + new_signal.set_destination(self._original_components_to_new[original_connected_op].input( + original_signal.destination.index)) else: # Create new operation, set destination to it. - new_connected_op = self._add_component_unconnected_copy(original_connected_op) - new_signal.set_destination(new_connected_op.input(original_signal.destination.index)) + new_connected_op = self._add_component_unconnected_copy( + original_connected_op) + new_signal.set_destination(new_connected_op.input( + original_signal.destination.index)) self._components_ordered.append(new_connected_op) self._operations_ordered.append(new_connected_op) @@ -438,7 +509,7 @@ class SFG(AbstractOperation): for _signal in _inp.signals: _signal.remove_destination() _signal.set_destination(component.input(index_in)) - + for index_out, _out in enumerate(_component.outputs): for _signal in _out.signals: _signal.remove_source() @@ -447,7 +518,7 @@ class SFG(AbstractOperation): # The old SFG will be deleted by Python GC return self() - def _evaluate_source(self, src: OutputPort, results: MutableResultMap, registers: MutableRegisterMap, prefix: str) -> Number: + def _evaluate_source(self, src: OutputPort, results: MutableOutputMap, registers: MutableRegisterMap, prefix: str) -> Number: src_prefix = prefix if src_prefix: src_prefix += "." @@ -457,26 +528,69 @@ class SFG(AbstractOperation): if key in results: value = results[key] if value is None: - raise RuntimeError(f"Direct feedback loop detected when evaluating operation.") + raise RuntimeError( + f"Direct feedback loop detected when evaluating operation.") return value - results[key] = src.operation.current_output(src.index, registers, src_prefix) - input_values = [self._evaluate_source(input_port.signals[0].source, results, registers, prefix) for input_port in src.operation.inputs] - value = src.operation.evaluate_output(src.index, input_values, results, registers, src_prefix) + results[key] = src.operation.current_output( + src.index, registers, src_prefix) + input_values = [self._evaluate_source( + input_port.signals[0].source, results, registers, prefix) for input_port in src.operation.inputs] + value = src.operation.evaluate_output( + src.index, input_values, results, registers, src_prefix) results[key] = value return value - def _inputs_required_for_source(self, src: OutputPort, visited: MutableSet[Operation]) -> Sequence[bool]: - if src.operation in visited: - return [] - visited.add(src.operation) - - if isinstance(src.operation, Input): - for i, input_operation in enumerate(self._input_operations): - if input_operation is src.operation: - return [i] - - input_indices = [] - for i in src.operation.inputs_required_for_output(src.index): - input_indices.extend(self._inputs_required_for_source(src.operation.input(i).signals[0].source, visited)) - return input_indices + def get_precedence_list(self) -> List[List[OutputPort]]: + """Returns a Precedence list of the SFG where each element in n:th the list consists + of elements that are executed in the n:th step. If the precedence list already has been + calculated for the current SFG then returns the cached version.""" + if self._precedence_list is not None: + return self._precedence_list + + # Find all operations with only outputs and no inputs. + no_input_ops = list(filter(lambda op: op.input_count == 0, self.operations)) + reg_ops = self.get_components_with_type_name(Register.type_name()) + + # Find all first iter output ports for precedence + first_iter_ports = [op.output(i) for op in (no_input_ops + reg_ops) for i in range(op.output_count)] + + self._precedence_list = self._traverse_for_precedence_list(first_iter_ports) + + return self._precedence_list + + def _traverse_for_precedence_list(self, first_iter_ports): + # Find dependencies of output ports and input ports. + outports_per_inport = defaultdict(list) + remaining_inports_per_outport = dict() + for op in self.operations: + op_inputs = op.inputs + for out_i, outport in enumerate(op.outputs): + dependendent_indexes = op.inputs_required_for_output(out_i) + remaining_inports_per_outport[outport] = len(dependendent_indexes) + for in_i in dependendent_indexes: + outports_per_inport[op_inputs[in_i]].append(outport) + + # Traverse output ports for precedence + curr_iter_ports = first_iter_ports + precedence_list = [] + + while curr_iter_ports: + # Add the found ports to the current iter + precedence_list.append(curr_iter_ports) + + next_iter_ports = [] + + for outport in curr_iter_ports: + for signal in outport.signals: + new_inport = signal.destination + # Don't traverse over Registers + if new_inport is not None and not isinstance(new_inport.operation, Register): + for new_outport in outports_per_inport[new_inport]: + remaining_inports_per_outport[new_outport] -= 1 + if remaining_inports_per_outport[new_outport] == 0: + next_iter_ports.append(new_outport) + + curr_iter_ports = next_iter_ports + + return precedence_list diff --git a/b_asic/simulation.py b/b_asic/simulation.py index 9d0d154f..81be8de4 100644 --- a/b_asic/simulation.py +++ b/b_asic/simulation.py @@ -7,7 +7,7 @@ from collections import defaultdict from numbers import Number from typing import List, Dict, DefaultDict, Callable, Sequence, Mapping, Union, Optional -from b_asic.operation import ResultKey, ResultMap +from b_asic.operation import OutputKey, OutputMap from b_asic.signal_flow_graph import SFG @@ -33,7 +33,8 @@ class Simulation: self._results = defaultdict(dict) self._registers = {} self._iteration = 0 - self._input_functions = [lambda _: 0 for _ in range(self._sfg.input_count)] + self._input_functions = [ + lambda _: 0 for _ in range(self._sfg.input_count)] self._current_input_values = [0 for _ in range(self._sfg.input_count)] self._latest_output_values = [0 for _ in range(self._sfg.output_count)] self._save_results = save_results @@ -43,7 +44,8 @@ class Simulation: def set_input(self, index: int, input_provider: InputProvider) -> None: """Set the input function used to get values for the specific input at the given index to the internal SFG.""" if index < 0 or index >= len(self._input_functions): - raise IndexError(f"Input index out of range (expected 0-{len(self._input_functions) - 1}, got {index})") + raise IndexError( + f"Input index out of range (expected 0-{len(self._input_functions) - 1}, got {index})") if callable(input_provider): self._input_functions[index] = input_provider elif isinstance(input_provider, Number): @@ -54,7 +56,8 @@ class Simulation: def set_inputs(self, input_providers: Sequence[Optional[InputProvider]]) -> None: """Set the input functions used to get values for the inputs to the internal SFG.""" if len(input_providers) != self._sfg.input_count: - raise ValueError(f"Wrong number of inputs supplied to simulation (expected {self._sfg.input_count}, got {len(input_providers)})") + raise ValueError( + f"Wrong number of inputs supplied to simulation (expected {self._sfg.input_count}, got {len(input_providers)})") self._input_functions = [None for _ in range(self._sfg.input_count)] for index, input_provider in enumerate(input_providers): if input_provider is not None: @@ -78,8 +81,10 @@ class Simulation: and return the resulting output values. """ while self._iteration < iteration: - self._current_input_values = [self._input_functions[i](self._iteration) for i in range(self._sfg.input_count)] - self._latest_output_values = self._sfg.evaluate_outputs(self._current_input_values, self._results[self._iteration], self._registers) + self._current_input_values = [self._input_functions[i]( + self._iteration) for i in range(self._sfg.input_count)] + self._latest_output_values = self._sfg.evaluate_outputs( + self._current_input_values, self._results[self._iteration], self._registers) if not self._save_results: del self._results[self.iteration] self._iteration += 1 @@ -95,7 +100,7 @@ class Simulation: return self._iteration @property - def results(self) -> Mapping[int, ResultMap]: + def results(self) -> Mapping[int, OutputMap]: """Get a mapping of all results, including intermediate values, calculated for each iteration up until now. The outer mapping maps from iteration number to value mapping. The value mapping maps output port identifiers to values. Example: {0: {"c1": 3, "c2": 4, "bfly1.0": 7, "bfly1.1": -1, "0": 7}} @@ -110,4 +115,4 @@ class Simulation: """Clear all current state of the simulation, except for the results and iteration.""" self._registers.clear() self._current_input_values = [0 for _ in range(self._sfg.input_count)] - self._latest_output_values = [0 for _ in range(self._sfg.output_count)] \ No newline at end of file + self._latest_output_values = [0 for _ in range(self._sfg.output_count)] diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py index 96d341b9..0a256bc8 100644 --- a/b_asic/special_operations.py +++ b/b_asic/special_operations.py @@ -6,7 +6,7 @@ TODO: More info. from numbers import Number from typing import Optional, Sequence -from b_asic.operation import AbstractOperation, ResultKey, RegisterMap, MutableResultMap, MutableRegisterMap +from b_asic.operation import AbstractOperation, OutputKey, RegisterMap, MutableOutputMap, MutableRegisterMap from b_asic.graph_component import Name, TypeName from b_asic.port import SignalSourceProvider @@ -17,13 +17,13 @@ class Input(AbstractOperation): """ def __init__(self, name: Name = ""): - super().__init__(input_count = 0, output_count = 1, name = name) + super().__init__(input_count=0, output_count=1, name=name) self.set_param("value", 0) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "in" - + def evaluate(self): return self.param("value") @@ -44,10 +44,11 @@ class Output(AbstractOperation): """ def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 0, name = name, input_sources = [src0]) + super().__init__(input_count=1, output_count=0, + name=name, input_sources=[src0]) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "out" def evaluate(self, _): @@ -60,11 +61,12 @@ class Register(AbstractOperation): """ def __init__(self, src0: Optional[SignalSourceProvider] = None, initial_value: Number = 0, name: Name = ""): - super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) + super().__init__(input_count=1, output_count=1, + name=name, input_sources=[src0]) self.set_param("initial_value", initial_value) - @property - def type_name(self) -> TypeName: + @classmethod + def type_name(cls) -> TypeName: return "reg" def evaluate(self, a): @@ -74,13 +76,15 @@ class Register(AbstractOperation): if registers is not None: return registers.get(self.key(index, prefix), self.param("initial_value")) return self.param("initial_value") - - def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableResultMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number: + + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableOutputMap] = None, registers: Optional[MutableRegisterMap] = None, prefix: str = "") -> Number: if index != 0: - raise IndexError(f"Output index out of range (expected 0-0, got {index})") + raise IndexError( + f"Output index out of range (expected 0-0, got {index})") if len(input_values) != 1: - raise ValueError(f"Wrong number of inputs supplied to SFG for evaluation (expected 1, got {len(input_values)})") - + raise ValueError( + f"Wrong number of inputs supplied to SFG for evaluation (expected 1, got {len(input_values)})") + key = self.key(index, prefix) value = self.param("initial_value") if registers is not None: @@ -88,4 +92,4 @@ class Register(AbstractOperation): registers[key] = self.truncate_inputs(input_values)[0] if results is not None: results[key] = value - return value \ No newline at end of file + return value diff --git a/test/test_core_operations.py b/test/test_core_operations.py index 4d0039b5..2eb341da 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -164,3 +164,14 @@ class TestButterfly: test_operation = Butterfly() assert test_operation.evaluate_output(0, [2+1j, 3-2j]) == 5-1j assert test_operation.evaluate_output(1, [2+1j, 3-2j]) == -1+3j + + +class TestDepends: + def test_depends_addition(self): + add1 = Addition() + assert set(add1.inputs_required_for_output(0)) == {0, 1} + + def test_depends_butterfly(self): + bfly1 = Butterfly() + assert set(bfly1.inputs_required_for_output(0)) == {0, 1} + assert set(bfly1.inputs_required_for_output(1)) == {0, 1} diff --git a/test/test_depends.py b/test/test_depends.py deleted file mode 100644 index e2691105..00000000 --- a/test/test_depends.py +++ /dev/null @@ -1,19 +0,0 @@ -from b_asic import Addition, Butterfly - -class TestDepends: - def test_depends_addition(self): - add1 = Addition() - assert set(add1.inputs_required_for_output(0)) == {0, 1} - - def test_depends_butterfly(self): - bfly1 = Butterfly() - assert set(bfly1.inputs_required_for_output(0)) == {0, 1} - assert set(bfly1.inputs_required_for_output(1)) == {0, 1} - - def test_depends_sfg(self, sfg_two_inputs_two_outputs): - assert set(sfg_two_inputs_two_outputs.inputs_required_for_output(0)) == {0, 1} - assert set(sfg_two_inputs_two_outputs.inputs_required_for_output(1)) == {0, 1} - - def test_depends_sfg_independent(self, sfg_two_inputs_two_outputs_independent): - assert set(sfg_two_inputs_two_outputs_independent.inputs_required_for_output(0)) == {0} - assert set(sfg_two_inputs_two_outputs_independent.inputs_required_for_output(1)) == {1} \ No newline at end of file diff --git a/test/test_sfg.py b/test/test_sfg.py index c188351f..cf309c26 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1,6 +1,7 @@ import pytest -from b_asic import SFG, Signal, Input, Output, Constant, Addition, Multiplication, Subtraction +from b_asic import SFG, Signal, Input, Output, Constant, ConstantMultiplication, Addition, Multiplication, Register, \ + Butterfly, Subtraction class TestInit: @@ -9,7 +10,7 @@ class TestInit: out1 = Output(None, "OUT1") out1.input(0).connect(in1, "S1") - sfg = SFG(inputs = [in1], outputs = [out1]) # in1 ---s1---> out1 + sfg = SFG(inputs=[in1], outputs=[out1]) # in1 ---s1---> out1 assert len(list(sfg.components)) == 3 assert len(list(sfg.operations)) == 2 @@ -22,7 +23,8 @@ class TestInit: s1 = add2.input(0).connect(add1, "S1") - sfg = SFG(input_signals = [s1], output_signals = [s1]) # in1 ---s1---> out1 + # in1 ---s1---> out1 + sfg = SFG(input_signals=[s1], output_signals=[s1]) assert len(list(sfg.components)) == 3 assert len(list(sfg.operations)) == 2 @@ -30,7 +32,7 @@ class TestInit: assert sfg.output_count == 1 def test_outputs_construction(self, operation_tree): - sfg = SFG(outputs = [Output(operation_tree)]) + sfg = SFG(outputs=[Output(operation_tree)]) assert len(list(sfg.components)) == 7 assert len(list(sfg.operations)) == 4 @@ -38,13 +40,14 @@ class TestInit: assert sfg.output_count == 1 def test_signals_construction(self, operation_tree): - sfg = SFG(output_signals = [Signal(source = operation_tree.output(0))]) + sfg = SFG(output_signals=[Signal(source=operation_tree.output(0))]) assert len(list(sfg.components)) == 7 assert len(list(sfg.operations)) == 4 assert sfg.input_count == 0 assert sfg.output_count == 1 + class TestPrintSfg: def test_one_addition(self): inp1 = Input("INP1") @@ -91,13 +94,14 @@ class TestPrintSfg: "id: out1, name: OUT1, input: [s2], output: []\n" def test_simple_filter(self, simple_filter): - assert simple_filter.__str__() == \ + assert simple_filter.__str__() == \ 'id: add1, name: , input: [s1, s3], output: [s4]\n' + \ 'id: in1, name: , input: [], output: [s1]\n' + \ 'id: cmul1, name: , input: [s5], output: [s3]\n' + \ 'id: reg1, name: , input: [s4], output: [s5, s2]\n' + \ 'id: out1, name: , input: [s2], output: []\n' + class TestDeepCopy: def test_deep_copy_no_duplicates(self): inp1 = Input("INP1") @@ -107,7 +111,7 @@ class TestDeepCopy: mul1 = Multiplication(add1, inp3, "MUL1") out1 = Output(mul1, "OUT1") - mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], name = "mac_sfg") + mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="mac_sfg") mac_sfg_new = mac_sfg() assert mac_sfg.name == "mac_sfg" @@ -134,8 +138,9 @@ class TestDeepCopy: mul1.input(1).connect(add2, "S6") out1.input(0).connect(mul1, "S7") - mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], id_number_offset = 100, name = "mac_sfg") - mac_sfg_new = mac_sfg(name = "mac_sfg2") + mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], + id_number_offset=100, name="mac_sfg") + mac_sfg_new = mac_sfg(name="mac_sfg2") assert mac_sfg.name == "mac_sfg" assert mac_sfg_new.name == "mac_sfg2" @@ -145,7 +150,7 @@ class TestDeepCopy: for g_id, component in mac_sfg._components_by_id.items(): component_copy = mac_sfg_new.find_by_id(g_id) assert component.name == component_copy.name - + def test_deep_copy_with_new_sources(self): inp1 = Input("INP1") inp2 = Input("INP2") @@ -154,7 +159,7 @@ class TestDeepCopy: mul1 = Multiplication(add1, inp3, "MUL1") out1 = Output(mul1, "OUT1") - mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], name = "mac_sfg") + mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="mac_sfg") a = Addition(Constant(3), Constant(5)) b = Constant(2) @@ -162,20 +167,22 @@ class TestDeepCopy: assert mac_sfg_new.input(0).signals[0].source.operation is a assert mac_sfg_new.input(1).signals[0].source.operation is b + class TestEvaluateOutput: def test_evaluate_output(self, operation_tree): - sfg = SFG(outputs = [Output(operation_tree)]) + sfg = SFG(outputs=[Output(operation_tree)]) assert sfg.evaluate_output(0, []) == 5 def test_evaluate_output_large(self, large_operation_tree): - sfg = SFG(outputs = [Output(large_operation_tree)]) + sfg = SFG(outputs=[Output(large_operation_tree)]) assert sfg.evaluate_output(0, []) == 14 def test_evaluate_output_cycle(self, operation_graph_with_cycle): - sfg = SFG(outputs = [Output(operation_graph_with_cycle)]) + sfg = SFG(outputs=[Output(operation_graph_with_cycle)]) with pytest.raises(Exception): sfg.evaluate_output(0, []) + class TestComponents: def test_advanced_components(self): inp1 = Input("INP1") @@ -194,9 +201,10 @@ class TestComponents: mul1.input(1).connect(add2, "S6") out1.input(0).connect(mul1, "S7") - mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], name = "mac_sfg") + mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="mac_sfg") - assert set([comp.name for comp in mac_sfg.components]) == {"INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1", "S1", "S2", "S3", "S4", "S5", "S6", "S7"} + assert set([comp.name for comp in mac_sfg.components]) == { + "INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1", "S1", "S2", "S3", "S4", "S5", "S6", "S7"} class TestReplaceComponents: @@ -204,7 +212,8 @@ class TestReplaceComponents: sfg = SFG(outputs=[Output(operation_tree)]) component_id = "add1" - sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + sfg = sfg.replace_component( + Multiplication(name="Multi"), _id=component_id) assert component_id not in sfg._components_by_id.keys() assert "Multi" in sfg._components_by_name.keys() @@ -213,7 +222,8 @@ class TestReplaceComponents: component_id = "add1" component = sfg.find_by_id(component_id) - sfg = sfg.replace_component(Multiplication(name="Multi"), _component=component) + sfg = sfg.replace_component(Multiplication( + name="Multi"), _component=component) assert component_id not in sfg._components_by_id.keys() assert "Multi" in sfg._components_by_name.keys() @@ -221,15 +231,16 @@ class TestReplaceComponents: sfg = SFG(outputs=[Output(large_operation_tree)]) component_id = "add3" - sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + sfg = sfg.replace_component( + Multiplication(name="Multi"), _id=component_id) assert "Multi" in sfg._components_by_name.keys() assert component_id not in sfg._components_by_id.keys() - + def test_replace_no_input_component(self, operation_tree): sfg = SFG(outputs=[Output(operation_tree)]) component_id = "c1" _const = sfg.find_by_id(component_id) - + sfg = sfg.replace_component(Constant(1), _id=component_id) assert _const is not sfg.find_by_id(component_id) @@ -238,7 +249,8 @@ class TestReplaceComponents: component_id = "addd1" try: - sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + sfg = sfg.replace_component( + Multiplication(name="Multi"), _id=component_id) except AssertionError: assert True else: @@ -249,12 +261,216 @@ class TestReplaceComponents: component_id = "c1" try: - sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + sfg = sfg.replace_component( + Multiplication(name="Multi"), _id=component_id) except AssertionError: assert True else: assert False + +class TestFindComponentsWithTypeName: + def test_mac_components(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + mul1 = Multiplication(None, None, "MUL1") + out1 = Output(None, "OUT1") + + add1.input(0).connect(inp1, "S1") + add1.input(1).connect(inp2, "S2") + add2.input(0).connect(add1, "S4") + add2.input(1).connect(inp3, "S3") + mul1.input(0).connect(add1, "S5") + mul1.input(1).connect(add2, "S6") + out1.input(0).connect(mul1, "S7") + + mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="mac_sfg") + + assert {comp.name for comp in mac_sfg.get_components_with_type_name( + inp1.type_name())} == {"INP1", "INP2", "INP3"} + + assert {comp.name for comp in mac_sfg.get_components_with_type_name( + add1.type_name())} == {"ADD1", "ADD2"} + + assert {comp.name for comp in mac_sfg.get_components_with_type_name( + mul1.type_name())} == {"MUL1"} + + assert {comp.name for comp in mac_sfg.get_components_with_type_name( + out1.type_name())} == {"OUT1"} + + assert {comp.name for comp in mac_sfg.get_components_with_type_name( + Signal.type_name())} == {"S1", "S2", "S3", "S4", "S5", "S6", "S7"} + + +class TestGetPrecedenceList: + + def test_inputs_registers(self): + in1 = Input("IN1") + c0 = ConstantMultiplication(5, in1, "C0") + add1 = Addition(c0, None, "ADD1") + # Not sure what operation "Q" is supposed to be in the example + Q1 = ConstantMultiplication(1, add1, "Q1") + T1 = Register(Q1, 0, "T1") + T2 = Register(T1, 0, "T2") + b2 = ConstantMultiplication(2, T2, "B2") + b1 = ConstantMultiplication(3, T1, "B1") + add2 = Addition(b1, b2, "ADD2") + add1.input(1).connect(add2) + a1 = ConstantMultiplication(4, T1, "A1") + a2 = ConstantMultiplication(6, T2, "A2") + add3 = Addition(a1, a2, "ADD3") + a0 = ConstantMultiplication(7, Q1, "A0") + add4 = Addition(a0, add3, "ADD4") + out1 = Output(add4, "OUT1") + + sfg = SFG(inputs=[in1], outputs=[out1], name="SFG") + + precedence_list = sfg.get_precedence_list() + + assert len(precedence_list) == 7 + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[0]]) == {"IN1", "T1", "T2"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[1]]) == {"C0", "B1", "B2", "A1", "A2"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[2]]) == {"ADD2", "ADD3"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[3]]) == {"ADD1"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[4]]) == {"Q1"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[5]]) == {"A0"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[6]]) == {"ADD4"} + + def test_inputs_constants_registers_multiple_outputs(self): + in1 = Input("IN1") + c0 = ConstantMultiplication(5, in1, "C0") + add1 = Addition(c0, None, "ADD1") + # Not sure what operation "Q" is supposed to be in the example + Q1 = ConstantMultiplication(1, add1, "Q1") + T1 = Register(Q1, 0, "T1") + const1 = Constant(10, "CONST1") # Replace T2 register with a constant + b2 = ConstantMultiplication(2, const1, "B2") + b1 = ConstantMultiplication(3, T1, "B1") + add2 = Addition(b1, b2, "ADD2") + add1.input(1).connect(add2) + a1 = ConstantMultiplication(4, T1, "A1") + a2 = ConstantMultiplication(10, const1, "A2") + add3 = Addition(a1, a2, "ADD3") + a0 = ConstantMultiplication(7, Q1, "A0") + # Replace ADD4 with a butterfly to test multiple output ports + bfly1 = Butterfly(a0, add3, "BFLY1") + out1 = Output(bfly1.output(0), "OUT1") + out2 = Output(bfly1.output(1), "OUT2") + + sfg = SFG(inputs=[in1], outputs=[out1], name="SFG") + + precedence_list = sfg.get_precedence_list() + + assert len(precedence_list) == 7 + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[0]]) == {"IN1", "T1", "CONST1"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[1]]) == {"C0", "B1", "B2", "A1", "A2"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[2]]) == {"ADD2", "ADD3"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[3]]) == {"ADD1"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[4]]) == {"Q1"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[5]]) == {"A0"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[6]]) == {"BFLY1.0", "BFLY1.1"} + + def test_precedence_multiple_outputs_same_precedence(self, sfg_two_inputs_two_outputs): + sfg_two_inputs_two_outputs.name = "NESTED_SFG" + + in1 = Input("IN1") + sfg_two_inputs_two_outputs.input(0).connect(in1, "S1") + in2 = Input("IN2") + cmul1 = ConstantMultiplication(10, None, "CMUL1") + cmul1.input(0).connect(in2, "S2") + sfg_two_inputs_two_outputs.input(1).connect(cmul1, "S3") + + out1 = Output(sfg_two_inputs_two_outputs.output(0), "OUT1") + out2 = Output(sfg_two_inputs_two_outputs.output(1), "OUT2") + + sfg = SFG(inputs=[in1, in2], outputs=[out1, out2]) + + precedence_list = sfg.get_precedence_list() + + assert len(precedence_list) == 3 + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[0]]) == {"IN1", "IN2"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[1]]) == {"CMUL1"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[2]]) == {"NESTED_SFG.0", "NESTED_SFG.1"} + + def test_precedence_sfg_multiple_outputs_different_precedences(self, sfg_two_inputs_two_outputs_independent): + sfg_two_inputs_two_outputs_independent.name = "NESTED_SFG" + + in1 = Input("IN1") + in2 = Input("IN2") + sfg_two_inputs_two_outputs_independent.input(0).connect(in1, "S1") + cmul1 = ConstantMultiplication(10, None, "CMUL1") + cmul1.input(0).connect(in2, "S2") + sfg_two_inputs_two_outputs_independent.input(1).connect(cmul1, "S3") + out1 = Output(sfg_two_inputs_two_outputs_independent.output(0), "OUT1") + out2 = Output(sfg_two_inputs_two_outputs_independent.output(1), "OUT2") + + sfg = SFG(inputs=[in1, in2], outputs=[out1, out2]) + + precedence_list = sfg.get_precedence_list() + + assert len(precedence_list) == 3 + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[0]]) == {"IN1", "IN2"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[1]]) == {"NESTED_SFG.0", "CMUL1"} + + assert set([port.operation.key(port.index, port.operation.name) + for port in precedence_list[2]]) == {"NESTED_SFG.1"} + + +class TestDepends: + def test_depends_sfg(self, sfg_two_inputs_two_outputs): + assert set(sfg_two_inputs_two_outputs.inputs_required_for_output(0)) == { + 0, 1} + assert set(sfg_two_inputs_two_outputs.inputs_required_for_output(1)) == { + 0, 1} + + def test_depends_sfg_independent(self, sfg_two_inputs_two_outputs_independent): + assert set( + sfg_two_inputs_two_outputs_independent.inputs_required_for_output(0)) == {0} + assert set( + sfg_two_inputs_two_outputs_independent.inputs_required_for_output(1)) == {1} + + class TestConnectExternalSignalsToComponentsSoloComp: def test_connect_external_signals_to_components_mac(self): @@ -275,8 +491,8 @@ class TestConnectExternalSignalsToComponentsSoloComp: mul1.input(1).connect(add2, "S6") out1.input(0).connect(mul1, "S7") - mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1]) - + mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1]) + inp4 = Input("INP4") inp5 = Input("INP5") out2 = Output(None, "OUT2") @@ -285,39 +501,40 @@ class TestConnectExternalSignalsToComponentsSoloComp: mac_sfg.input(1).connect(inp5, "S9") out2.input(0).connect(mac_sfg.outputs[0], "S10") - test_sfg = SFG(inputs = [inp4, inp5], outputs = [out2]) - assert test_sfg.evaluate(1,2) == 9 + test_sfg = SFG(inputs=[inp4, inp5], outputs=[out2]) + assert test_sfg.evaluate(1, 2) == 9 mac_sfg.connect_external_signals_to_components() - assert test_sfg.evaluate(1,2) == 9 - assert test_sfg.connect_external_signals_to_components() == False + assert test_sfg.evaluate(1, 2) == 9 + assert not test_sfg.connect_external_signals_to_components() def test_connect_external_signals_to_components_operation_tree(self, operation_tree): """ Replaces an SFG with only a operation_tree component with its inner components """ - sfg1 = SFG(outputs = [Output(operation_tree)]) + sfg1 = SFG(outputs=[Output(operation_tree)]) out1 = Output(None, "OUT1") out1.input(0).connect(sfg1.outputs[0], "S1") - test_sfg = SFG(outputs = [out1]) + test_sfg = SFG(outputs=[out1]) assert test_sfg.evaluate_output(0, []) == 5 sfg1.connect_external_signals_to_components() assert test_sfg.evaluate_output(0, []) == 5 - assert test_sfg.connect_external_signals_to_components() == False + assert not test_sfg.connect_external_signals_to_components() def test_connect_external_signals_to_components_large_operation_tree(self, large_operation_tree): """ Replaces an SFG with only a large_operation_tree component with its inner components """ - sfg1 = SFG(outputs = [Output(large_operation_tree)]) + sfg1 = SFG(outputs=[Output(large_operation_tree)]) out1 = Output(None, "OUT1") out1.input(0).connect(sfg1.outputs[0], "S1") - test_sfg = SFG(outputs = [out1]) + test_sfg = SFG(outputs=[out1]) assert test_sfg.evaluate_output(0, []) == 14 sfg1.connect_external_signals_to_components() assert test_sfg.evaluate_output(0, []) == 14 - assert test_sfg.connect_external_signals_to_components() == False + assert not test_sfg.connect_external_signals_to_components() + class TestConnectExternalSignalsToComponentsMultipleComp: def test_connect_external_signals_to_components_operation_tree(self, operation_tree): """ Replaces a operation_tree in an SFG with other components """ - sfg1 = SFG(outputs = [Output(operation_tree)]) + sfg1 = SFG(outputs=[Output(operation_tree)]) inp1 = Input("INP1") inp2 = Input("INP2") @@ -332,15 +549,15 @@ class TestConnectExternalSignalsToComponentsMultipleComp: add2.input(1).connect(sfg1.outputs[0], "S4") out1.input(0).connect(add2, "S5") - test_sfg = SFG(inputs = [inp1, inp2], outputs = [out1]) + test_sfg = SFG(inputs=[inp1, inp2], outputs=[out1]) assert test_sfg.evaluate(1, 2) == 8 sfg1.connect_external_signals_to_components() assert test_sfg.evaluate(1, 2) == 8 - assert test_sfg.connect_external_signals_to_components() == False + assert not test_sfg.connect_external_signals_to_components() def test_connect_external_signals_to_components_large_operation_tree(self, large_operation_tree): """ Replaces a large_operation_tree in an SFG with other components """ - sfg1 = SFG(outputs = [Output(large_operation_tree)]) + sfg1 = SFG(outputs=[Output(large_operation_tree)]) inp1 = Input("INP1") inp2 = Input("INP2") @@ -354,15 +571,15 @@ class TestConnectExternalSignalsToComponentsMultipleComp: add2.input(1).connect(sfg1.outputs[0], "S4") out1.input(0).connect(add2, "S5") - test_sfg = SFG(inputs = [inp1, inp2], outputs = [out1]) + test_sfg = SFG(inputs=[inp1, inp2], outputs=[out1]) assert test_sfg.evaluate(1, 2) == 17 sfg1.connect_external_signals_to_components() assert test_sfg.evaluate(1, 2) == 17 - assert test_sfg.connect_external_signals_to_components() == False + assert not test_sfg.connect_external_signals_to_components() def create_sfg(self, op_tree): """ Create a simple SFG with either operation_tree or large_operation_tree """ - sfg1 = SFG(outputs = [Output(op_tree)]) + sfg1 = SFG(outputs=[Output(op_tree)]) inp1 = Input("INP1") inp2 = Input("INP2") @@ -376,7 +593,7 @@ class TestConnectExternalSignalsToComponentsMultipleComp: add2.input(1).connect(sfg1.outputs[0], "S4") out1.input(0).connect(add2, "S5") - return SFG(inputs = [inp1, inp2], outputs = [out1]) + return SFG(inputs=[inp1, inp2], outputs=[out1]) def test_connect_external_signals_to_components_many_op(self, large_operation_tree): """ Replaces an sfg component in a larger SFG with several component operations """ @@ -399,8 +616,8 @@ class TestConnectExternalSignalsToComponentsMultipleComp: sub1.input(1).connect(inp4, "S6") out1.input(0).connect(sub1, "S7") - test_sfg = SFG(inputs = [inp1, inp2, inp3, inp4], outputs = [out1]) + test_sfg = SFG(inputs=[inp1, inp2, inp3, inp4], outputs=[out1]) assert test_sfg.evaluate(1, 2, 3, 4) == 16 sfg1.connect_external_signals_to_components() assert test_sfg.evaluate(1, 2, 3, 4) == 16 - assert test_sfg.connect_external_signals_to_components() == False \ No newline at end of file + assert not test_sfg.connect_external_signals_to_components() -- GitLab