Skip to content
Snippets Groups Projects
operation.py 4.8 KiB
Newer Older
  • Learn to ignore specific revisions
    """
    B-ASIC Operation Module.

    Defines the Operation interface and the BasicOperation base class.
    TODO: More info.
    """
    from abc import ABC, abstractmethod
    from numbers import Number
    from typing import Any, Dict, List, NewType, Optional, final

    from b_asic.port import InputPort, OutputPort
    from b_asic.signal import SignalSource, SignalDestination
    from b_asic.simulation import SimulationState, OperationState
    
    # Distinct type for operation identifiers; behaves as a plain int at runtime.
    OperationId = NewType(name="OperationId", tp=int)
    
    class Operation(ABC):
    	"""
    	Operation interface.

    	An operation has a unique identifier, input ports, output ports and a set
    	of named parameters. BasicOperation provides a generic implementation of
    	most of this interface.
    	TODO: More info.
    	"""

    	@abstractmethod
    	def identifier(self) -> OperationId:
    		"""
    		Get the unique identifier.
    		"""
    		pass

    	@abstractmethod
    	def inputs(self) -> List[InputPort]:
    		"""
    		Get a list of all input ports.
    		"""
    		pass

    	@abstractmethod
    	def outputs(self) -> List[OutputPort]:
    		"""
    		Get a list of all output ports.
    		"""
    		pass

    	@abstractmethod
    	def input_count(self) -> int:
    		"""
    		Get the number of input ports.
    		"""
    		pass

    	@abstractmethod
    	def output_count(self) -> int:
    		"""
    		Get the number of output ports.
    		"""
    		pass

    	@abstractmethod
    	def input(self, i: int) -> InputPort:
    		"""
    		Get the input port at index i.
    		"""
    		pass

    	@abstractmethod
    	def output(self, i: int) -> OutputPort:
    		"""
    		Get the output port at index i.
    		"""
    		pass

    	@abstractmethod
    	def params(self) -> Dict[str, Optional[Any]]:
    		"""
    		Get a dictionary of all parameter values, keyed by parameter name.
    		"""
    		pass

    	@abstractmethod
    	def param(self, name: str) -> Optional[Any]:
    		"""
    		Get the value of a parameter.
    		Returns None if the parameter is not defined.
    		"""
    		pass

    	@abstractmethod
    	def set_param(self, name: str, value: Any) -> None:
    		"""
    		Set the value of a parameter.
    		The parameter must be defined.
    		"""
    		pass

    	@abstractmethod
    	def evaluate_outputs(self, state: SimulationState) -> List[Number]:
    		"""
    		Simulate the circuit until its iteration count matches that of the simulation state,
    		then return the resulting output vector.
    		"""
    		pass

    	@abstractmethod
    	def split(self) -> List["Operation"]:
    		"""
    		Split the operation into multiple operations.
    		If splitting is not possible, this may return a list containing only the operation itself.
    		"""
    		# The annotation is a string because the name Operation is not bound yet
    		# while this class body executes; a bare List[Operation] raises NameError.
    		pass

    	# TODO: More stuff.
    
    class BasicOperation(Operation):
    	"""
    	Generic abstract operation class which most implementations will derive from.

    	Stores the identifier, input ports, output ports and parameters, and
    	implements the Operation accessors on top of them. Subclasses implement
    	evaluate() and (presumably in their constructors) fill the port lists and
    	define their parameters.
    	TODO: More info.
    	"""

    	# Operation already derives from ABC. Listing ABC before Operation in the
    	# bases produced an inconsistent MRO and raised TypeError at import time.

    	_identifier: OperationId
    	_input_ports: List[InputPort]
    	_output_ports: List[OutputPort]
    	_parameters: Dict[str, Optional[Any]]

    	def __init__(self, identifier: OperationId):
    		"""
    		Construct a BasicOperation with the given identifier, no ports and no parameters.
    		"""
    		self._identifier = identifier
    		self._input_ports = []
    		self._output_ports = []
    		self._parameters = {}

    	@abstractmethod
    	def evaluate(self, inputs: list) -> list:
    		"""
    		Evaluate the operation and generate a list of output values given a list of input values.
    		"""
    		pass

    	@final
    	def identifier(self) -> OperationId:
    		"""
    		Get the unique identifier (implements Operation.identifier).
    		"""
    		return self._identifier

    	@final
    	def id(self) -> OperationId:
    		"""
    		Alias of identifier(), kept so existing callers of id() keep working.
    		"""
    		return self._identifier

    	@final
    	def inputs(self) -> List[InputPort]:
    		"""
    		Get a copy of the list of input ports.
    		"""
    		return self._input_ports.copy()

    	@final
    	def outputs(self) -> List[OutputPort]:
    		"""
    		Get a copy of the list of output ports.
    		"""
    		return self._output_ports.copy()

    	@final
    	def input_count(self) -> int:
    		return len(self._input_ports)

    	@final
    	def output_count(self) -> int:
    		return len(self._output_ports)

    	@final
    	def input(self, i: int) -> InputPort:
    		return self._input_ports[i]

    	@final
    	def output(self, i: int) -> OutputPort:
    		return self._output_ports[i]

    	@final
    	def params(self) -> Dict[str, Optional[Any]]:
    		"""
    		Get a shallow copy of the parameter dictionary.
    		"""
    		return self._parameters.copy()

    	@final
    	def param(self, name: str) -> Optional[Any]:
    		"""
    		Get the value of a parameter, or None if it is not defined.
    		"""
    		return self._parameters.get(name)

    	@final
    	def set_param(self, name: str, value: Any) -> None:
    		"""
    		Set the value of an already-defined parameter.
    		Raises KeyError if no parameter with that name is defined.
    		"""
    		# An explicit raise instead of assert: asserts are stripped under -O.
    		if name not in self._parameters:
    			raise KeyError(f"Operation {self._identifier} has no parameter named {name!r}.")
    		self._parameters[name] = value

    	def evaluate_outputs(self, state: SimulationState) -> List[Number]:
    		"""
    		Evaluate this operation until its iteration count reaches state.iteration.

    		Each iteration pulls input values from the operations driving the input
    		signals, calls evaluate(), stores the result in this operation's
    		OperationState and then evaluates the operations connected to the
    		output signals.
    		Returns the most recent output values.
    		Raises ValueError if evaluate() returns the wrong number of values.
    		"""
    		# TODO: Check implementation.
    		output_count = self.output_count()
    		self_state: OperationState = state.operation_states[self.identifier()]

    		while self_state.iteration < state.iteration:
    			# Pull each input from the output port its signal is connected to.
    			input_values: List[Number] = []
    			for port in self._input_ports:
    				source: SignalSource = port.signal().source
    				input_values.append(source.operation.evaluate_outputs(state)[source.port_index])

    			output_values = self.evaluate(input_values)
    			if len(output_values) != output_count:
    				raise ValueError(
    					f"Operation {self._identifier} produced {len(output_values)} "
    					f"output values, expected {output_count}."
    				)
    			self_state.output_values = output_values
    			# Advance before pushing forward, so a destination that pulls its inputs
    			# back from this operation finds it up to date instead of re-evaluating it.
    			self_state.iteration += 1
    			for port in self._output_ports:
    				for signal in port.signals():
    					destination: SignalDestination = signal.destination
    					# NOTE(review): assumes SignalDestination mirrors SignalSource
    					# (.operation / .port_index); the previous code called
    					# evaluate_outputs on the destination object itself. Verify
    					# against b_asic.signal.
    					destination.operation.evaluate_outputs(state)

    		return self_state.output_values

    	def split(self) -> List[Operation]:
    		"""
    		Split the operation into multiple operations.

    		Calls evaluate() with the input ports instead of numeric values. If that
    		yields a non-empty list made up only of Operation instances, those are
    		returned. Otherwise the operation is considered unsplittable and [self]
    		is returned.
    		"""
    		# TODO: Check implementation.
    		results = self.evaluate(self.inputs())
    		# An empty result would make the operation vanish, so treat it as unsplittable.
    		if results and all(isinstance(e, Operation) for e in results):
    			return results
    		return [self]

    	# TODO: More stuff.