"""
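B-ASIC Port Module.

Defines the abstract Port interface and its InputPort and OutputPort
implementations, to which signals are connected.
TODO: More info.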
"""

from b_asic.signal import Signal
from abc import ABC, abstractmethod
from typing import NewType, Optional, List, Dict, final

# Identifier type for ports. NewType makes it a distinct type for static
# type checkers only; at runtime PortId(x) simply returns the int x.
PortId = NewType("PortId", int)

class Port(ABC):
	"""
	Abstract base class shared by all ports.

	Every port stores the identifier it was constructed with and exposes
	an interface for inspecting, connecting and disconnecting signals.
	How many signals a port accepts, and how they are stored, is left to
	the concrete subclasses.
	TODO: More info.
	"""

	# Identifier assigned at construction; read back via identifier().
	_identifier: PortId

	def __init__(self, identifier: PortId):
		"""
		Create a port carrying the given identifier.
		"""
		self._identifier = identifier

	@final
	def identifier(self) -> PortId:
		"""
		Return the unique identifier this port was constructed with.
		"""
		return self._identifier

	@abstractmethod
	def signals(self) -> List[Signal]:
		"""
		Return a list of every signal currently connected to this port.
		"""
		...

	@abstractmethod
	def signal_count(self) -> int:
		"""
		Return how many signals are currently connected to this port.
		"""
		...

	@abstractmethod
	def signal(self, i: int = 0) -> Signal:
		"""
		Return the connected signal found at index i (default: the first).
		"""
		...

	@abstractmethod
	def connect(self, signal: Signal) -> None:
		"""
		Attach the given signal to this port.
		"""
		...

	@abstractmethod
	def disconnect(self, i: int = 0) -> None:
		"""
		Detach the connected signal found at index i (default: the first).
		"""
		...

	# TODO: More stuff.

class InputPort(Port):
	"""
	Input port with at most one connected (source) signal.

	Connecting a signal replaces whatever signal was connected before.
	TODO: More info.
	"""

	# The connected source signal, or None while the port is unconnected.
	_source_signal: Optional[Signal]

	def __init__(self, identifier: PortId):
		"""
		Construct an InputPort with no signal connected.
		"""
		super().__init__(identifier)
		self._source_signal = None

	@final
	def signals(self) -> List[Signal]:
		"""
		Get a list holding the connected signal, or an empty list if none.
		"""
		# Identity check (`is None`) so a custom Signal.__eq__ cannot
		# influence whether the port counts as connected.
		return [] if self._source_signal is None else [self._source_signal]

	@final
	def signal_count(self) -> int:
		"""
		Get the number of connected signals (0 or 1).
		"""
		return 0 if self._source_signal is None else 1

	@final
	def signal(self, i: int = 0) -> Signal:
		"""
		Get the connected signal at index i.

		Raises IndexError unless 0 <= i < signal_count(), i.e. unless
		i == 0 and a signal is connected.
		"""
		self._check_index(i)
		source = self._source_signal
		# Internal invariant guaranteed by _check_index; this only narrows
		# Optional[Signal] to Signal for type checkers.
		assert source is not None
		return source

	@final
	def connect(self, signal: Signal) -> None:
		"""
		Connect a signal, replacing any previously connected signal.
		"""
		self._source_signal = signal

	@final
	def disconnect(self, i: int = 0) -> None:
		"""
		Disconnect the signal at index i.

		Raises IndexError unless 0 <= i < signal_count().
		"""
		self._check_index(i)
		self._source_signal = None

	def _check_index(self, i: int) -> None:
		"""
		Raise IndexError if i is not the index of a connected signal.
		"""
		# An explicit raise (not assert) so validation survives `python -O`.
		count = self.signal_count()
		if not 0 <= i < count:
			raise IndexError(
				f"signal index {i} out of range for input port "
				f"{self.identifier()} with {count} connected signal(s)"
			)

	# TODO: More stuff.

class OutputPort(Port):
	"""
	Output port that can feed any number of distinct destination signals.
	TODO: More info.
	"""

	# Connected destination signals, kept in the order they were connected.
	_destination_signals: List[Signal]

	def __init__(self, identifier: PortId):
		"""
		Construct an OutputPort with no signals connected.
		"""
		super().__init__(identifier)
		self._destination_signals = []

	@final
	def signals(self) -> List[Signal]:
		"""
		Get a copy of the connected signals, in connection order.
		"""
		# A copy, so callers cannot mutate the port's internal list.
		return self._destination_signals.copy()

	@final
	def signal_count(self) -> int:
		"""
		Get the number of connected signals.
		"""
		return len(self._destination_signals)

	@final
	def signal(self, i: int = 0) -> Signal:
		"""
		Get the connected signal at index i.

		Raises IndexError unless 0 <= i < signal_count(); negative
		indices are rejected rather than counting from the end.
		"""
		self._check_index(i)
		return self._destination_signals[i]

	@final
	def connect(self, signal: Signal) -> None:
		"""
		Connect a signal, appending it after the already connected ones.

		Raises ValueError if the signal is already connected to this port.
		"""
		# An explicit raise (not assert) so the check survives `python -O`.
		if signal in self._destination_signals:
			raise ValueError(
				f"signal is already connected to output port {self.identifier()}"
			)
		self._destination_signals.append(signal)

	@final
	def disconnect(self, i: int = 0) -> None:
		"""
		Disconnect the signal at index i; later signals shift down by one.

		Raises IndexError unless 0 <= i < signal_count().
		"""
		self._check_index(i)
		del self._destination_signals[i]

	def _check_index(self, i: int) -> None:
		"""
		Raise IndexError if i is not the index of a connected signal.
		"""
		# An explicit raise (not assert) so validation survives `python -O`.
		count = self.signal_count()
		if not 0 <= i < count:
			raise IndexError(
				f"signal index {i} out of range for output port "
				f"{self.identifier()} with {count} connected signal(s)"
			)

	# TODO: More stuff.