#!/usr/bin/env python3

'''
Created on 23/11/2014

@author: rand huso
'''

from mpi4py import MPI
import psutil
import socket

class MI_MachineInfo( object ):
	'''
	This (message) class provides information about the CPU where this instance of the application is running in an MPI environment.

	Attributes:
		rank     -- the rank number passed in by the caller
		hostname -- the canonical host name when it can be resolved, otherwise the plain local host name
		memory   -- total physical memory as reported by psutil.virtual_memory().total
	'''
	def __init__( self, rank ):
		self.hostname = self._localHostname()  # or use MPI.Get_processor_name()
		self.memory = psutil.virtual_memory().total
		self.rank = rank

	@staticmethod
	def _localHostname() -> str:
		'''
		Return this machine's canonical host name, falling back to the bare host name.

		socket.gethostbyaddr() raises socket.herror / socket.gaierror (both OSError subclasses) when the
		local host name has no resolver entry - common on cluster compute nodes and in containers - which
		would otherwise make the constructor raise on that node.
		'''
		localName = socket.gethostname()
		try:
			return socket.gethostbyaddr( localName )[0]
		except OSError:
			return localName

	def __repr__( self ):
		return '{},{},{}'.format( self.rank, self.hostname, self.memory )

class MI_Shutdown( object ):
	''' This (message) class is for internal use only: MPI_IoC sends it to tell a node to leave its message loop. '''
	pass

class MPI_IoC( object ):
	'''
	This class is the framework that controls the sending and receiving of messages in an MPI execution environment.

	Typical use (the same code runs on every rank):
	mi = MPI_IoC( infoRecipientMethod, startupMethod )   # either argument may be None
	mi.register( MyCommClass1, methodToProcessCommClass1 )
	mi.register( MyCommClass2, methodToProcessCommClass2 )
	mi.register( MyCommClass3, methodToProcessCommClass3 )
	...
	mi.start()

	Once started, the framework on each node will gather the capabilities of the CPU and send this information to Rank0 (the "control node"),
	provided a machineInfoRecipientRank0 method was supplied.
	When any message is received by any node (including the capabilities message) the registered method is invoked and given the number of the
	node that sent the message and the received message.
	When a message is received, the node may take the opportunity to send message(s) to other node(s).
	The control node (Rank0) will end the process by sending a request for all to terminate (the "stop" method).

	Message tags are handed out sequentially by register(), so every rank must register the same classes in the same order.
	'''
	def __init__( self, machineInfoRecipientRank0=None, startupMethodRank0=None ):
		'''
		The startupMethodRank0 is your method to be invoked (with no arguments, on Rank0 only) when the framework is started.
		The machineInfoRecipientRank0 is your method to be invoked as machineInfoRecipientRank0( sourceNode, machineInfo ) when the
		capabilities message (MI_MachineInfo) is received from any node.
		One of these methods should be supplied - leaving both None will be very uneventful.
		Pass the same machineInfoRecipientRank0 (None or not) on every rank: it decides whether MI_MachineInfo takes a tag, which
		shifts the tags of every class registered afterwards.
		'''
		self.startupMethodRank0 = startupMethodRank0
		self.callbackTagByClassName = {}   # message class __name__ -> MPI tag
		self.callbackMethodByNumber = {}   # MPI tag -> callback( sourceNode, message )
		self.comm = MPI.COMM_WORLD
		self.rank = self.comm.Get_rank()
		self.size = self.comm.Get_size()
		self.nextAvailableTag = 900        # the first registered class receives tag 901
		self.iocLoop = True
		# MI_Shutdown is always registered first, so its tag is identical on every rank
		self.register( MI_Shutdown, self._doShutdown )
		self.machineInfoRecipientRank0 = machineInfoRecipientRank0
		if machineInfoRecipientRank0 is not None:
			self.register( MI_MachineInfo, machineInfoRecipientRank0 )

	def getRank( self ) -> int: return self.rank
	def getSize( self ) -> int: return self.size
	def isControl( self ) -> bool: return 0 == self.getRank()

	def getNextAvailableTag( self ) -> int:
		''' this requires that all messages be registered in the same order - the alternative appeals less '''
		self.nextAvailableTag += 1
		return self.nextAvailableTag

	def register( self, callbackItem, callbackMethod ) -> None:
		'''
		Register callback methods to be invoked when objects of the callback item are received.
		The framework will invoke the "callbackMethod" as callbackMethod( sourceNode, message ) when objects of the "callbackItem" are received.
		Registering the same class again assigns it a new tag; the old tag keeps pointing at the old callback.
		'''
		callbackTag = self.getNextAvailableTag()
		self.callbackTagByClassName[callbackItem.__name__] = callbackTag
		self.callbackMethodByNumber[callbackTag] = callbackMethod

	def sendMessage( self, txNode, txMessage ) -> None:
		'''
		Sends the message to the specified node.
		A message whose class was never registered is silently dropped (bcastMessage raises KeyError instead).
		'''
		tag = self.callbackTagByClassName.get( txMessage.__class__.__name__ )
		if tag is not None:
			self._sendMsg( txNode, txMessage, tag )

	def bcastMessage( self, txMessage, toAll=True ) -> None:
		'''
		This is a shorthand for sending a message to all nodes - not like the MPI_Bcast method.
		With toAll=False the sending node skips itself. Raises KeyError if the message class was never registered.
		'''
		tag = self.callbackTagByClassName[txMessage.__class__.__name__]
		for node in range( self.getSize()):
			if toAll or node != self.rank:
				self._sendMsg( node, txMessage, tag )

	def _sendMsg( self, txNode, txMessage, tag ) -> None:
		''' Point-to-point send of one (pickled by mpi4py) message object. '''
		self.comm.send( txMessage, dest=txNode, tag=tag )

	def msgAvailable( self ) -> bool:
		''' Check, without blocking, whether another message has become available from any node. '''
		return self.comm.Iprobe( source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG )

	def processIncomingMessage( self ) -> None:
		'''
		Process the newly arrived message (wait for one to arrive if necessary).
		The callback registered for the message's tag is invoked with the sender's rank and the message;
		a tag that was never registered raises KeyError.
		'''
		status = MPI.Status()
		data = self.comm.recv( source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status )
		self.callbackMethodByNumber[ status.Get_tag()]( status.Get_source(), data )

	def start( self ) -> None:
		'''
		Yield control of the application to this framework.
		Sends this node's MI_MachineInfo to Rank0 (when a recipient was supplied), runs the Rank0 startup method,
		then dispatches incoming messages until stop() / an MI_Shutdown message ends the loop.
		'''
		if self.machineInfoRecipientRank0 is not None: self.sendMessage( 0, MI_MachineInfo( self.getRank()) )
		if self.isControl():
			if self.startupMethodRank0 is not None: self.startupMethodRank0()
		while self.iocLoop:
			self.processIncomingMessage()

	def stop( self ) -> None:
		'''
		Stops the main loop and returns control to the application.
		On the control node (Rank0) this also tells every other node to stop; on any other node only the local loop ends.
		'''
		if self.isControl() and self.iocLoop:
			# Rank0 ends its own loop directly below, so it must not send MI_Shutdown to itself:
			# that message would never be received and would still be pending when MPI is finalized.
			self.bcastMessage( MI_Shutdown(), toAll=False )
		self._doShutdown()

	def _doShutdown( self, unused_sourceNode=None, unused_rxMessage=None ) -> None:
		''' Callback for MI_Shutdown (and used by stop()): makes start()'s receive loop exit. '''
		self.iocLoop = False

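'''
Test harness for the framework above.
It exercises registering, sending, and receiving messages, checking whether a message is available, the startup callback, and the capabilities (MI_MachineInfo) exchange.
It needs at least two MPI processes, because TestMain.startupMethod sends directly to rank 1.
'''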
# import time
# import random

class MI_Test( object ):
	''' Test message carrying a small dict: three random ints in [1, 3] under 'a' plus two constant values. '''
	def __init__( self ):
		# Imported here: the module-level 'random' is only bound when this file runs as __main__,
		# so without this, constructing MI_Test from an importing module raises NameError.
		import random
		self.data = { 'a': [ random.randint( 1, 3 ) for _ in range( 3 ) ], 'b': 'Fred', 'c': 2.718281828 }
	def __repr__( self ):
		return '\tMI_Test:: {}'.format( self.data )

class TestMain( object ):
	'''
	Exercises MPI_IoC on every rank: registers MI_Test, has Rank0 broadcast/send a few MI_Test messages,
	and has Rank0 stop everything once every rank's MI_MachineInfo has arrived.
	'''
	nodesReportedIn = 0  # MI_MachineInfo messages received so far (they are only sent to Rank0)

	def rxMessage( self, sourceNode, msg ):
		''' Callback for MI_Test: print it, pause, then process one more queued message (nested) if one is waiting. '''
		import time  # the module-level 'time' is only bound when this file runs as __main__
		print( '[{}:{}]\t->\tTestMain::rxMessage \tsourceNode {} data {}\t\tmsgAvailable {}'.format( self.rank, self.size, sourceNode, msg, self.m.msgAvailable()) )
		time.sleep( 1 )
		print( '[{}:{}]\t->\tTestMain::rxMessage \tmsgAvailable[{}]'.format( self.rank, self.size, self.m.msgAvailable()) )
		if self.m.msgAvailable():
			print( '[{}:{}]\t->\tTestMain::rxMessage message available - processing locally vvvvv'.format( self.rank, self.size ))
			self.m.processIncomingMessage()
			print( '[{}:{}]\t->\tTestMain::rxMessage message available - processed locally  ^^^^^'.format( self.rank, self.size ))

	def startupMethod( self ):
		''' Rank0 startup: three broadcasts to every rank plus two direct sends to rank 1 (so at least two ranks are required). '''
		print( '[{}:{}]\t->\tTestMain::startupMethod'.format( self.rank, self.size ))
		data = MI_Test()
		self.m.bcastMessage( data )
		self.m.bcastMessage( data )
		self.m.bcastMessage( data )
		self.m.sendMessage( 1, data )
		self.m.sendMessage( 1, data )

	def capabilitiesIn( self, sourceNode, data ):
		''' Callback for MI_MachineInfo: count reports and stop the framework once every rank has reported. '''
		print( '[{}:{}]\t->\tTestMain::capabilitiesIn \t{} {}\t\tmsgAvailable {}'.format( self.rank, self.size, sourceNode, data, self.m.msgAvailable()) )
		self.nodesReportedIn += 1
		if self.size == self.nodesReportedIn:
			import time  # the module-level 'time' is only bound when this file runs as __main__
			time.sleep( 5 )
			print( '[{}:{}]\t->\tTestMain::capabilitiesIn \t\tmsgAvailable[{}]'.format( self.rank, self.size, self.m.msgAvailable()) )
			self.m.stop()

	def start( self ):
		''' Build the framework, register the test message and hand control to it (returns after shutdown). '''
		self.m = MPI_IoC( self.capabilitiesIn, self.startupMethod )
		self.rank = self.m.getRank()
		self.size = self.m.getSize()
		self.m.register( MI_Test, self.rxMessage )
		self.m.start()

def main():
	''' Script entry point: create the test driver and hand control to the MPI framework. '''
	TestMain().start()

if __name__ == "__main__":
	# Binds 'random' and 'time' as module globals for the test classes when run as a script.
	import random
	import time
	main()
