"""@package docstring
B-ASIC Basic Operation Module.
TODO: More info.
"""

from abc import abstractmethod
from typing import List, Dict, Optional, Any
from numbers import Number

from b_asic.port import InputPort, OutputPort
from b_asic.signal import Signal
from b_asic.operation import Operation
from b_asic.simulation import SimulationState, OperationState


class BasicOperation(Operation):
	"""Generic abstract operation class which most implementations will derive from.

	Holds the operation's input ports, output ports and named parameters, and
	implements the port/parameter accessors and simulation evaluation on top of
	the abstract `evaluate` method that concrete operations provide.
	TODO: More info.
	"""

	_input_ports: List[InputPort]
	_output_ports: List[OutputPort]
	_parameters: Dict[str, Optional[Any]]

	def __init__(self):
		"""Construct a BasicOperation with no ports and no parameters."""
		self._input_ports = []
		self._output_ports = []
		self._parameters = {}

	@abstractmethod
	def evaluate(self, inputs: list) -> list:
		"""Evaluate the operation and generate a list of output values given a list of input values.

		Implementations must return exactly one value per output port;
		`evaluate_outputs` checks this.
		"""
		pass

	def inputs(self) -> List[InputPort]:
		"""Return a shallow copy of the list of input ports."""
		return self._input_ports.copy()

	def outputs(self) -> List[OutputPort]:
		"""Return a shallow copy of the list of output ports."""
		return self._output_ports.copy()

	def input_count(self) -> int:
		"""Return the number of input ports."""
		return len(self._input_ports)

	def output_count(self) -> int:
		"""Return the number of output ports."""
		return len(self._output_ports)

	def input(self, i: int) -> InputPort:
		"""Return the input port at index `i` (raises IndexError if out of range)."""
		return self._input_ports[i]

	def output(self, i: int) -> OutputPort:
		"""Return the output port at index `i` (raises IndexError if out of range)."""
		return self._output_ports[i]

	def params(self) -> Dict[str, Optional[Any]]:
		"""Return a shallow copy of the parameter name -> value mapping."""
		return self._parameters.copy()

	def param(self, name: str) -> Optional[Any]:
		"""Return the value of parameter `name`, or None if there is no such parameter."""
		return self._parameters.get(name)

	def set_param(self, name: str, value: Any) -> None:
		"""Set the value of the existing parameter `name`.

		Only parameters already present on the operation can be set; this
		method never creates new ones.

		Raises:
			KeyError: if `name` is not one of this operation's parameters.
		"""
		# An explicit raise (rather than `assert`) keeps this check active
		# under `python -O`, where asserts are stripped and an unknown name
		# would otherwise silently be added as a new parameter.
		if name not in self._parameters:
			raise KeyError(f"{type(self).__name__} has no parameter named {name!r}")
		self._parameters[name] = value

	def evaluate_outputs(self, state: SimulationState) -> List[Number]:
		"""Bring this operation's outputs up to date with the simulation and return them.

		While this operation's own iteration counter (its entry in
		`state.operation_states`) lags behind `state.iteration`, it:
		1. evaluates the operation feeding each input port and takes the
		   relevant output value from it,
		2. calls `evaluate` on those input values and stores the result,
		3. advances its own iteration counter, and
		4. asks every operation connected to its output ports to evaluate too.

		Returns the output values stored for this operation.
		"""
		# TODO: Check implementation.
		output_count: int = self.output_count()
		self_state: OperationState = state.operation_states[self.identifier()]

		while self_state.iteration < state.iteration:
			input_values: List[Number] = []
			for port in self._input_ports:
				signal: Signal = port.signal
				# Reach the upstream operation through the signal's source
				# endpoint, the same way `neighbours` does.
				# NOTE(review): assumes the source endpoint also exposes
				# `port_index` (index of the upstream output feeding this
				# signal) — confirm against b_asic.signal / b_asic.port.
				source = signal.source
				upstream_outputs = source.operation.evaluate_outputs(state)
				input_values.append(upstream_outputs[source.port_index])

			output_values = self.evaluate(input_values)
			assert len(output_values) == output_count, (
				f"{type(self).__name__}.evaluate returned {len(output_values)} "
				f"values, expected {output_count}"
			)
			self_state.output_values = output_values
			# Advance the counter *before* notifying downstream operations: when
			# a downstream BasicOperation pulls its inputs it calls back into
			# this method, and the updated counter makes that call return the
			# stored values immediately instead of re-evaluating.
			self_state.iteration += 1
			for port in self._output_ports:
				for out_signal in port.signals:
					out_signal.destination.operation.evaluate_outputs(state)

		return self_state.output_values

	def split(self) -> List[Operation]:
		"""Try to decompose this operation into simpler operations.

		Calls `evaluate` with (a copy of) this operation's input ports in place
		of input values. If that yields a non-empty list consisting only of
		Operation objects, those operations are the decomposition; otherwise
		the operation cannot be split and `[self]` is returned.
		"""
		# TODO: Check implementation.
		results = self.evaluate(self._input_ports.copy())
		# `all()` is vacuously true for an empty sequence; without the length
		# check an operation whose `evaluate` returns nothing would "split"
		# into no operations at all and vanish.
		if len(results) > 0 and all(isinstance(e, Operation) for e in results):
			return results
		return [self]

	@property
	def neighbours(self) -> List[Operation]:
		"""Operations directly connected to this one by a signal.

		Upstream operations (feeding the input ports) come first, followed by
		downstream operations (fed by the output ports). An operation that is
		connected through several signals appears once per signal.
		"""
		neighbours: List[Operation] = []
		for port in self._input_ports:
			for signal in port.signals:
				neighbours.append(signal.source.operation)

		for port in self._output_ports:
			for signal in port.signals:
				neighbours.append(signal.destination.operation)

		return neighbours

	# TODO: More stuff.