From f7ad2d1c09a7f505c247f00bc92fb2b5de82024f Mon Sep 17 00:00:00 2001 From: Oscar Gustafsson <oscar.gustafsson@gmail.com> Date: Wed, 1 Feb 2023 18:10:45 +0100 Subject: [PATCH] Initial work on constant propagation --- b_asic/core_operations.py | 101 +++++++++++++++++++++++++++++++++++++- b_asic/operation.py | 65 ++++++++++++++++++++++++ 2 files changed, 164 insertions(+), 2 deletions(-) diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index b4c9642b..1d660a3b 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -5,13 +5,13 @@ Contains some of the most commonly used mathematical operations. """ from numbers import Number -from typing import Dict, Optional +from typing import Dict, Iterable, Optional, Set from numpy import abs as np_abs from numpy import conjugate, sqrt from b_asic.graph_component import Name, TypeName -from b_asic.operation import AbstractOperation +from b_asic.operation import AbstractOperation, Operation from b_asic.port import SignalSourceProvider @@ -125,6 +125,14 @@ class Addition(AbstractOperation): def evaluate(self, a, b): return a + b + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + if any(c == 0.0 for c in constants): + print("One input is 0!") + class Subtraction(AbstractOperation): """ @@ -185,6 +193,14 @@ class Subtraction(AbstractOperation): def evaluate(self, a, b): return a - b + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + if any(c == 0.0 for c in constants): + print("One input is 0!") + class AddSub(AbstractOperation): r""" @@ -266,6 +282,20 @@ class AddSub(AbstractOperation): """Set if operation is add.""" self.set_param("is_add", is_add) + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + if any(c == 0.0 for c in constants): + print("One input is 0!") + + def _propagate_constant_parameters( + self, valid_operations: Optional[Set["Operation"]] = None + ) -> None: + print(f"Can turn into {'Addition' if self.is_add else 'Subtraction'}") + return + class Multiplication(AbstractOperation): r""" @@ -327,6 +357,17 @@ class Multiplication(AbstractOperation): def evaluate(self, a, b): return a * b + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + if any(c == 0.0 for c in constants): + print("One input is 0!") + if any(c == 1.0 for c in constants): + print("One input is 1!") + print("Can turn into ConstantMultiplication") + class Division(AbstractOperation): r""" @@ -368,6 +409,17 @@ class Division(AbstractOperation): def evaluate(self, a, b): return a / b + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + numerator, denominator = constants + if numerator == 0.0: + print("Result is 0!") + if denominator is not None: + print("Can turn into ConstantMultiplication") + class Min(AbstractOperation): r""" @@ -619,6 +671,15 @@ class ConstantMultiplication(AbstractOperation): """Set the constant value of this operation.""" self.set_param("value", value) + def _propagate_constant_parameters( + self, valid_operations: Optional[Set["Operation"]] = None + ) -> None: + if self.value == 0.0: + print("Constant is zero!") + if self.value == 1.0: + print("Constant is zero!") + return + class Butterfly(AbstractOperation): r""" @@ -661,6 +722,14 @@ class Butterfly(AbstractOperation): def evaluate(self, a, b): return a + b, a - b + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + if any(c == 0.0 for c in constants): + print("One input is 0!") + class MAD(AbstractOperation): r""" @@ -700,6 +769,19 @@ class MAD(AbstractOperation): def evaluate(self, a, b, c): return a * b + c + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + a, b, c = constants + if a == 0.0 or b == 0.0: + print("One multiplier input is zero!") + if a == 1.0 or b == 1.0: + print("One multiplier input is one!") + if any(c == 0.0): + print("Adder input is zero!") + class SymmetricTwoportAdaptor(AbstractOperation): r""" @@ -752,6 +834,21 @@ class SymmetricTwoportAdaptor(AbstractOperation): """Set the constant value of this operation.""" self.set_param("value", value) + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + if any(c == 0.0 for c in constants): + print("One input is 0!") + + def _propagate_constant_parameters( + self, valid_operations: Optional[Set["Operation"]] = None + ) -> None: + if self.value == 0.0: + print("Constant is zero!") + return + class Reciprocal(AbstractOperation): r""" diff --git a/b_asic/operation.py b/b_asic/operation.py index 23c5cd99..15284757 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -19,6 +19,7 @@ from typing import ( NewType, Optional, Sequence, + Set, Tuple, Union, cast, @@ -435,6 +436,30 @@ class Operation(GraphComponent, SignalSourceProvider): def _check_all_latencies_set(self) -> None: raise NotImplementedError + @abstractmethod + def _propagate_constants( + self, valid_operations: Optional[Set["Operation"]] = None + ) -> None: + raise NotImplementedError + + @abstractmethod + def _constant_inputs(self) -> Iterable[Number]: + raise NotImplementedError + + @abstractmethod + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + raise NotImplementedError + + @abstractmethod + def _propagate_constant_parameters( + self, valid_operations: Optional[Set["Operation"]] = None + ) -> None: + raise NotImplementedError + class AbstractOperation(Operation, AbstractGraphComponent): """ @@ -1093,3 +1118,43 @@ class AbstractOperation(Operation, AbstractGraphComponent): for k in range(len(self.outputs)) ] return input_coordinates, output_coordinates + + def _propagate_constants( + self, valid_operations: Optional[Set["Operation"]] = None + ) -> None: + # Must be implemented per operation, so just return otherwise + constants = self._constant_inputs() + if all(c is None for c in constants): + return + if all(c is not None for c in constants): + res = self.evalute(*constants) + print(f"Result is {res}!") + if any(c is not None for c in constants): + # This is operation dependent + self._propagate_some_constants(constants, valid_operations) + return + + def _constant_inputs(self) -> Iterable[Optional[Number]]: + from b_asic.core_operations import Constant + + ret = [] + for port in self._input_ports: + if port.connected_source is None: + ret.append(None) + elif isinstance(port.connected_source.operation, Constant): + ret.append(port.connected_source.operation.value) + else: + ret.append(None) + return ret + + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + return + + def _propagate_constant_parameters( + self, valid_operations: Optional[Set["Operation"]] = None + ) -> None: + return -- GitLab