diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index df0868ee53b0b2d8b0d74706e8f3b31cad813892..c670b15a1fdc76b0f575da9f7aa3d0c5da01f492 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -1053,6 +1053,7 @@ class MAD(AbstractOperation): _execution_time: Optional[int] is_swappable = True + is_distributive = True def __init__( self, diff --git a/b_asic/gui_utils/color_button.py b/b_asic/gui_utils/color_button.py index 834d07271576b177ec95efdb5a69a1b1fee77e97..ea54c6b5ac560a9afeac2b1516db4039bd518a6a 100644 --- a/b_asic/gui_utils/color_button.py +++ b/b_asic/gui_utils/color_button.py @@ -31,7 +31,10 @@ class ColorButton(QPushButton): self.set_color(self._default) def set_color(self, color: QColor): - """Set new color.""" + """Set new color. + color : QColor + The new color of the button. + """ if color != self._color: self._color = color self._color_changed.emit(color) @@ -42,7 +45,10 @@ class ColorButton(QPushButton): self.setStyleSheet("") def set_text_color(self, color: QColor): - """Set text color.""" + """Set text color. + color : QColor + The new color of the text in the button. + """ self.setStyleSheet(f"color: {color.name()};") @property diff --git a/b_asic/operation.py b/b_asic/operation.py index 02995d3f7cc6ee7658b9677fbd36d26406aa82f7..8e550874b9c4900cd1d8dd08321cba9eb354eee3 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -425,6 +425,22 @@ class Operation(GraphComponent, SignalSourceProvider): """ raise NotImplementedError + @property + @abstractmethod + def is_commutative(self) -> bool: + """ + Return True if the operation is commutative. + """ + raise NotImplementedError + + @property + @abstractmethod + def is_distributive(self) -> bool: + """ + Return True if the operation is distributive. + """ + raise NotImplementedError + @property @abstractmethod def is_swappable(self) -> bool: @@ -1064,6 +1080,29 @@ class AbstractOperation(Operation, AbstractGraphComponent): input_.connected_source.operation.is_constant for input_ in self.inputs ) + @property + def is_commutative(self) -> bool: + """ + Checks if the operation is commutative. + + An operation is commutative if the order of the inputs does not change the result. + For example, addition is commutative because `a + b == b + a`, but subtraction is not + because `a - b != b - a`. + + Returns: + bool: True if the operation is commutative, False otherwise. + + """ + # doc-string inherited + if self.input_count == 2: + return self.is_swappable + return False + + @property + def is_distributive(self) -> bool: + # doc-string inherited + return False + @property def is_swappable(self) -> bool: # doc-string inherited diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index cbd491d78ccee41bd79f50d444c1721ff983da60..29900ac45b2c67d08706db4f82288fb569e627f2 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -2111,6 +2111,153 @@ class SFG(AbstractOperation): def is_constant(self) -> bool: return all(output.is_constant for output in self._output_operations) + @property + def is_commutative(self) -> bool: + """ + Checks if all operations in the SFG are commutative. + + An operation is considered commutative if it is not in B-ASIC Special Operations Module, + and its `is_commutative` property is `True`. + + Returns + ------- + bool: `True` if all operations are commutative, `False` otherwise. + """ + return all( + ( + op.is_commutative + if op.type_name() not in ["in", "out", "t", "c"] + else True + ) + for op in self.split() + ) + + @property + def is_distributive(self) -> bool: + """ + Checks if the SFG is distributive. + + An operation is considered distributive if it can be applied to each element of a set separately. + For example, multiplication is distributive over addition, meaning that `a * (b + c)` is equivalent to `a * b + a * c`. + + Returns + ------- + bool: True if the SFG is distributive, False otherwise. + + Examples + -------- + >>> Mad_op = MAD(Input(), Input(), Input()) # Creates an instance of the Mad operation, MAD is defined in b_asic.core_operations + >>> Mad_sfg = Mad_op.to_sfg() # The operation is turned into a sfg + >>> Mad_sfg.is_distributive # True # if the distributive property holds, False otherwise + + """ + structures = [] + operations = self.get_operations_topological_order() + for op in operations: + if not ( + op.type_name() == "in" + or op.type_name() == "out" + or op.type_name() == "c" + or op.type_name() == "t" + ): + structures.append(op) + return ( + all(self.has_distributive_structure(op) for op in structures) + if len(structures) > 1 + else False + ) + + def has_distributive_structure(self, op: Operation) -> bool: + """ + Checks if the SFG contains distributive structures. + NOTE: a*b + c = a*(b + c/a) is considered distributive. Meaning that an algorithm transformation would require an additionat operation. + + Parameters: + ---------- + op : Operation + The operation that is the start of the structure to check for distributivity. + + Returns: + ------- + bool: True if a distributive structures is found, False otherwise. + """ + # TODO Butterfly and SymmetricTwoportAdaptor needs to be converted to a SF using to_sfg() in order to be evaluated + if op.type_name() == 'mac': + return True + elif op.type_name() in ['mul', 'div']: + for subsequent_op in op.subsequent_operations: + if subsequent_op.type_name() in [ + 'add', + 'sub', + 'addsub', + 'min', + 'max', + 'sqrt', + 'abs', + 'rec', + 'out', + 't', + ]: + return True + elif subsequent_op.type_name() in ['mul', 'div']: + for subsequent_op in subsequent_op.subsequent_operations: + return self.has_distributive_structure(subsequent_op) + else: + return False + elif op.type_name() in ['cmul', 'shift', 'rshift', 'lshift']: + for subsequent_op in op.subsequent_operations: + if subsequent_op.type_name() in [ + 'add', + 'sub', + 'addsub', + 'min', + 'max', + 'out', + 't', + ]: + return True + elif subsequent_op.type_name() in ['cmul', 'shift', 'rshift', 'lshift']: + for subsequent_op in subsequent_op.subsequent_operations: + return self.has_distributive_structure(subsequent_op) + else: + return False + elif op.type_name() in ['add', 'sub', 'addsub']: + for subsequent_op in op.subsequent_operations: + if subsequent_op.type_name() in [ + 'mul', + 'div', + 'min', + 'max', + 'out', + 'cmul', + 't', + ]: + return True + elif subsequent_op.type_name() in ['add', 'sub', 'addsub']: + for subsequent_op in subsequent_op.subsequent_operations: + return self.has_distributive_structure(subsequent_op) + else: + return False + elif op.type_name() in ['min', 'max']: + for subsequent_op in op.subsequent_operations: + if subsequent_op.type_name() in [ + 'add', + 'sub', + 'addsub', + 'mul', + 'div', + 'cmul', + 'out', + 't', + ]: + return True + elif subsequent_op.type_name() in ['min', 'max']: + for subsequent_op in subsequent_op.subsequent_operations: + return self.has_distributive_structure(subsequent_op) + else: + return False + return False + def get_used_type_names(self) -> List[TypeName]: """Get a list of all TypeNames used in the SFG.""" ret = list({op.type_name() for op in self.operations}) diff --git a/test/test_sfg.py b/test/test_sfg.py index 6ece6b74ab738baace869143a3f2dd0a800d8f6d..55652ae042a770fcb52f6d12a9052ed5cb83b5f4 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1611,6 +1611,36 @@ class TestIsConstant: assert not sfg_nested.is_constant +class TestIsCommutative: + def test_single_accumulator(self, sfg_simple_accumulator: SFG): + assert sfg_simple_accumulator.is_commutative + + def test_sfg_accumulation(self, sfg_accumulator: SFG): + assert not sfg_accumulator.is_commutative + + +class TestIsDistributive: + def test_single_accumulator(self, sfg_simple_accumulator: SFG): + assert not sfg_simple_accumulator.is_distributive + assert sfg_simple_accumulator.has_distributive_structure( + sfg_simple_accumulator.find_by_id('add0') + ) + + def test_sfg_accumulation(self, sfg_accumulator: SFG): + assert sfg_accumulator.is_distributive + + def test_sfg_with_Chain_of_Additions(self): # value*(a+b+c) + a = Input() + b = Input() + c = Input() + add1 = a + b + add2 = add1 + c + mul = ConstantMultiplication(2, add2) + out1 = Output(mul) + sfg_simple = SFG(inputs=[a, b, c], outputs=[out1]) + assert sfg_simple.is_distributive + + class TestSwapIOOfOperation: def do_test(self, sfg: SFG, graph_id: GraphID): NUM_TESTS = 5