Newer
Older
"""B-ASIC Operation Module.

TODO: More info.
"""
Angus Lothian
committed
from abc import abstractmethod
Angus Lothian
committed
from typing import List, Dict, Optional, Any, Set, Iterator, TYPE_CHECKING
from collections import deque
Angus Lothian
committed
from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name
from b_asic.simulation import SimulationState, OperationState
from b_asic.signal import Signal
Angus Lothian
committed
Angus Lothian
committed
Angus Lothian
committed
class Operation(GraphComponent):
    """Interface that every operation has to implement.

    Declares accessors for ports and parameters, output evaluation against a
    simulation state, splitting, and neighbour lookup. Every member is
    abstract and raises NotImplementedError unless overridden.

    TODO: More info.
    """

    @abstractmethod
    def inputs(self) -> "List[InputPort]":
        """Return a list with all input ports of the operation."""
        raise NotImplementedError

    @abstractmethod
    def outputs(self) -> "List[OutputPort]":
        """Return a list with all output ports of the operation."""
        raise NotImplementedError

    @abstractmethod
    def input_count(self) -> int:
        """Return how many input ports the operation has."""
        raise NotImplementedError

    @abstractmethod
    def output_count(self) -> int:
        """Return how many output ports the operation has."""
        raise NotImplementedError

    @abstractmethod
    def input(self, i: int) -> "InputPort":
        """Return the input port located at index i."""
        raise NotImplementedError

    @abstractmethod
    def output(self, i: int) -> "OutputPort":
        """Return the output port located at index i."""
        raise NotImplementedError

    @abstractmethod
    def params(self) -> Dict[str, Optional[Any]]:
        """Return a dictionary holding every parameter value."""
        raise NotImplementedError

    @abstractmethod
    def param(self, name: str) -> Optional[Any]:
        """Return the value of the parameter called name.

        The result is None when no such parameter is defined.
        """
        raise NotImplementedError

    @abstractmethod
    def set_param(self, name: str, value: Any) -> None:
        """Assign value to the parameter called name.

        The parameter has to be defined already.
        """
        raise NotImplementedError

    @abstractmethod
    def evaluate_outputs(self, state: "SimulationState") -> List[Number]:
        """Simulate the circuit until its iteration count equals the one of the
        simulation state, and return the output vector that results.
        """
        raise NotImplementedError

    @abstractmethod
    def split(self) -> "List[Operation]":
        """Break the operation up into several operations.

        When the operation cannot be split, the returned list may contain
        nothing but the operation itself.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def neighbours(self) -> "List[Operation]":
        """All operations connected to this operation through signals.

        The list is empty when the operation has no neighbours.
        """
        raise NotImplementedError
Angus Lothian
committed
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
class AbstractOperation(Operation, AbstractGraphComponent):
    """Generic abstract operation class which most implementations will derive from.

    Keeps the input ports, output ports and parameters of an operation in
    private containers and implements the accessor part of the Operation
    interface on top of them. Subclasses provide the computation itself by
    overriding evaluate().

    TODO: More info.
    """

    _input_ports: List["InputPort"]
    _output_ports: List["OutputPort"]
    _parameters: Dict[str, Optional[Any]]

    def __init__(self, name: Name = ""):
        """Create an operation called name without any ports or parameters."""
        super().__init__(name)
        self._input_ports = []
        self._output_ports = []
        self._parameters = {}

    @abstractmethod
    def evaluate(self, *inputs) -> Any:  # pylint: disable=arguments-differ
        """Evaluate the operation and generate a list of output values given a
        list of input values.

        NOTE(review): evaluate_outputs() and split() both pass one list as a
        single positional argument rather than unpacking it -- confirm which
        calling convention subclasses expect.
        """
        raise NotImplementedError

    def inputs(self) -> List["InputPort"]:
        """Return a shallow copy of the list of input ports."""
        return self._input_ports.copy()

    def outputs(self) -> List["OutputPort"]:
        """Return a shallow copy of the list of output ports."""
        return self._output_ports.copy()

    def input_count(self) -> int:
        """Return the number of input ports."""
        return len(self._input_ports)

    def output_count(self) -> int:
        """Return the number of output ports."""
        return len(self._output_ports)

    def input(self, i: int) -> "InputPort":
        """Return the input port at index i (plain list indexing)."""
        return self._input_ports[i]

    def output(self, i: int) -> "OutputPort":
        """Return the output port at index i (plain list indexing)."""
        return self._output_ports[i]

    def params(self) -> Dict[str, Optional[Any]]:
        """Return a shallow copy of the parameter dictionary."""
        return self._parameters.copy()

    def param(self, name: str) -> Optional[Any]:
        """Return the value of parameter name, or None if it is not defined."""
        return self._parameters.get(name)

    def set_param(self, name: str, value: Any) -> None:
        """Set the value of the already defined parameter name.

        Raises:
            AssertionError: If no parameter called name is defined.
        """
        assert name in self._parameters, f"Parameter {name!r} is not defined."
        self._parameters[name] = value

    def evaluate_outputs(self, state: SimulationState) -> List[Number]:
        """Evaluate this operation until it has caught up with state.

        While the iteration stored for this operation in
        state.operation_states is lower than state.iteration, the input
        values are fetched by evaluating the operations feeding the input
        ports, evaluate() is called on them, the stored iteration is
        incremented and the operations at the destinations of the output
        signals are evaluated as well.

        Returns:
            The output values stored for this operation in state.

        Raises:
            AssertionError: If the port counts disagree with the port lists,
                or if evaluate() returns a wrong number of output values.
        """
        # TODO: Check implementation.
        # NOTE(review): this method reads `port.signal` and calls
        # `port.signals()`, whereas `neighbours` iterates `port.signals` as an
        # attribute; it also calls evaluate_outputs() on `signal.destination`
        # directly where `neighbours` goes through `.operation`. Confirm
        # against the port and signal APIs.
        input_count: int = self.input_count()
        output_count: int = self.output_count()
        assert input_count == len(self._input_ports), \
            "input_count() does not match the number of input ports."
        assert output_count == len(self._output_ports), \
            "output_count() does not match the number of output ports."

        self_state: OperationState = state.operation_states[self]

        while self_state.iteration < state.iteration:
            # Pull one value per input port from the operation driving it.
            input_values: List[Number] = [0] * input_count
            for i in range(input_count):
                source: Signal = self._input_ports[i].signal
                input_values[i] = source.operation.evaluate_outputs(state)[source.port_index]

            self_state.output_values = self.evaluate(input_values)
            assert len(self_state.output_values) == output_count, \
                "evaluate() returned a wrong number of output values."
            self_state.iteration += 1

            # NOTE(review): assumed to belong to the while body -- confirm.
            for i in range(output_count):
                for signal in self._output_ports[i].signals():
                    destination: Signal = signal.destination
                    destination.evaluate_outputs(state)

        return self_state.output_values

    def split(self) -> List[Operation]:
        """Split the operation by evaluating it on its own input ports.

        Returns the result of evaluate() if every element of it is an
        Operation, otherwise a list containing only this operation.
        """
        # TODO: Check implementation.
        results = self.evaluate(self._input_ports)
        if all(isinstance(e, Operation) for e in results):
            return results
        return [self]

    @property
    def neighbours(self) -> List[Operation]:
        """Operations connected to this operation through signals.

        Operations found at the source of the input port signals come first,
        followed by those at the destination of the output port signals. An
        operation appears once per connecting signal. The list is empty when
        no signals are connected.
        """
        neighbours: List[Operation] = []
        for port in self._input_ports:
            neighbours.extend(signal.source.operation for signal in port.signals)
        for port in self._output_ports:
            neighbours.extend(signal.destination.operation for signal in port.signals)
        return neighbours

    def traverse(self) -> Iterator[Operation]:
        """Traverse the operation tree and return a generator with start point in the operation."""
        return self._breadth_first_search()

    def _breadth_first_search(self) -> Iterator[Operation]:
        """Use breadth first search to traverse the operation tree.

        Yields this operation first and then every operation reachable through
        `neighbours`, each exactly once.
        """
        visited: Set[Operation] = {self}
        queue = deque([self])
        while queue:
            operation = queue.popleft()
            yield operation
            for n_operation in operation.neighbours:
                # Mark on enqueue so an operation is never queued twice.
                if n_operation not in visited:
                    visited.add(n_operation)
                    queue.append(n_operation)