#!/usr/bin/python3 -O
import asyncio
import tcpip
from tcpip import Message
import pickle
from config_manager import ConfigReader
from process import ProcessCovar
from process import ProcessClusters
from process import ProcessRegimes
from process import ProcessFile
from process import ProcessGraphs
from objects import ObjectContainer
import atexit

class Server:
    """Asyncio TCP server that queues processing jobs and broadcasts status.

    Clients connect over TCP and send pickled ``Message`` objects.  Depending
    on the message command a ``Process*`` job is added to a pending set.  Three
    background tasks cooperate on the event loop:

    * ``processing_thread`` starts pending jobs one at a time,
    * ``scan_thread`` reports progress/completion once per second,
    * ``out_thread`` broadcasts every queued message to all connections,
      framed as a 4-byte big-endian length followed by ``msg.tos()``.
    """

    def __init__(self):
        """Load persisted objects, read ``regimes.ini`` and schedule the tasks.

        Nothing executes until :meth:`run` drives the event loop.
        """
        ObjectContainer.getInstance().load()
        self._conf = ConfigReader('regimes.ini')
        self._host = self._conf.server()
        self._port = self._conf.port()
        self._loop = self._get_loop()
        self._run = True
        self._processes = set()
        # peername -> (StreamReader, StreamWriter)
        self._connections = {}
        # The ``loop=`` argument of asyncio.Queue / start_server was removed in
        # Python 3.10; without it both bind to the loop that runs them.
        self._messages = asyncio.Queue()
        self._delconnections = asyncio.Queue()
        # asyncio server object; set by run() once the listening socket exists.
        self._server = None
        self._core = asyncio.start_server(self.handle_connect, self._host, self._port)
        self._th1 = self._loop.create_task(self.processing_thread())
        self._th2 = self._loop.create_task(self.scan_thread())
        self._th3 = self._loop.create_task(self.out_thread())

    @staticmethod
    def _get_loop():
        """Return the current event loop, creating and installing one if needed."""
        try:
            return asyncio.get_event_loop()
        except RuntimeError:
            # Raised when no current loop is set for this thread.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop

    def run(self):
        """Start listening, then block while the background tasks are running.

        An exception raised by any background task propagates to the caller.
        """
        # Await only the start_server coroutine here so that ``self._server``
        # really is the asyncio server object that close() needs.
        self._server = self._loop.run_until_complete(self._core)
        self._loop.run_until_complete(
            asyncio.gather(self._th1, self._th2, self._th3))

    async def out_thread(self):
        """Broadcast each queued message to every open connection."""
        while self._run:
            await asyncio.sleep(0)
            msg = await self._messages.get()
            data = msg.tos()
            size = len(data).to_bytes(4, byteorder='big')

            # Iterate over a snapshot: handle_connect() may register a new
            # connection while this coroutine is suspended in drain().
            for addr, (_reader, writer) in list(self._connections.items()):
                if writer.transport.is_closing():
                    print('disconnected')
                    self._delconnections.put_nowait(addr)
                    continue
                try:
                    writer.write(size)
                    writer.write(data)
                    await writer.drain()
                except asyncio.CancelledError:
                    # Never swallow task cancellation (an Exception before 3.8).
                    raise
                except Exception:
                    self._delconnections.put_nowait(addr)
                    writer.close()
                    print('error connection:', addr)
                await asyncio.sleep(0)

            while not self._delconnections.empty():
                self._connections.pop(self._delconnections.get_nowait(), None)

    def close(self):
        """Stop the background tasks, close the listening server and the loop.

        Safe to call when run() never started the server and when the loop is
        already closed.
        """
        if self._loop.is_closed():
            return
        self._run = False
        tasks = [self._th1, self._th2, self._th3]
        for task in tasks:
            if not task.done():
                task.cancel()
        # Let the cancellations be delivered so no task is left pending when
        # the loop is closed; return_exceptions keeps their errors from raising.
        self._loop.run_until_complete(
            asyncio.gather(*tasks, return_exceptions=True))
        if self._server is not None:
            self._server.close()
            self._loop.run_until_complete(self._server.wait_closed())
        self._loop.close()

    def kill_all(self):
        """Interrupt every running process and queue a notification message."""
        for p in self._processes:
            if p._running:
                p.interrupt()
        m = Message(Message.STRING, 'all tasks interrupted')
        # Queue.put() is a coroutine and would never run un-awaited from this
        # synchronous method; the queue is unbounded so put_nowait cannot fail.
        self._messages.put_nowait(m)

    async def scan_thread(self):
        """Once per second, report progress and drop completed processes."""
        while self._run:
            await asyncio.sleep(1)
            finished = []
            for p in list(self._processes):
                if p._completed:
                    if p._progress:
                        # Force the progress display to its final value.
                        p._progress._iteration = p._progress._total
                        m = Message(Message.PROGRESS, p._progress._display())
                        await self._messages.put(m)
                    m = Message(Message.STRING, p.name + ' done')
                    await self._messages.put(m)
                    finished.append(p)
                if p._running:
                    if p._progress:
                        m = Message(Message.PROGRESS, p._progress._display())
                    else:
                        m = Message(Message.STRING, p.name + ': in progress')
                    # The original built the "in progress" message but never
                    # queued it; send it like the progress message.
                    await self._messages.put(m)

            for p in finished:
                self._processes.discard(p)

    async def processing_thread(self):
        """Once per second, start at most one process that has not run yet."""
        while self._run:
            pending = next(
                (p for p in self._processes
                 if not p._running and not p._completed),
                None)
            if pending is not None:
                r = await pending.run()
                print(pending.name, 'finished')
                if not r:
                    m = Message(Message.STRING,
                                pending.name + ': cannot be started')
                    await self._messages.put(m)
                elif r != 'ok':
                    # Exact type checks are kept on purpose: the relationship
                    # between the Process* classes is not visible from here.
                    if type(pending) is ProcessRegimes:
                        await self._messages.put(Message(Message.STRING, r))
                    elif type(pending) is ProcessGraphs:
                        await self._messages.put(Message(Message.GRAPH, r))

            await asyncio.sleep(1)

    async def handle_connect(self, reader, writer):
        """Register a new client and dispatch the messages it sends.

        Returns on end-of-stream or on a read error.  On end-of-stream the
        connection stays registered; out_thread() removes it once its
        transport is closing.
        """
        addr = writer.get_extra_info('peername')
        # bug python 3.6
        writer.transport.set_write_buffer_limits(0)
        self._connections[addr] = (reader, writer)
        print('connection from ', addr)
        if reader.exception():
            print(reader.exception())
        while self._run:
            try:
                # NOTE(review): incoming data is not length-framed; one read is
                # treated as one complete pickled message.
                data = await reader.read(3096)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                # A failed stream keeps raising the same error on every read,
                # so retrying (as the old bare ``except`` did) never recovers.
                print('error connection:', addr, exc)
                self._connections.pop(addr, None)
                writer.close()
                return
            if not data:
                return
            try:
                await self.process(data)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                # A bad message must not terminate the connection.
                print('error processing message from', addr, ':', exc)
            await asyncio.sleep(0)

    async def process(self, data):
        """Decode one pickled ``Message`` and act on its command.

        :param data: raw bytes received from a client; empty input is ignored.
        """
        if not data:
            return

        # SECURITY: pickle.loads() on data received from the network can run
        # arbitrary code; only expose this server to trusted clients.
        message = pickle.loads(data)
        cmd = message.cmd

        if cmd == Message.STRING:
            print('processing message STRING')
            # NOTE(review): ``rString`` is not defined on this class; look it
            # up dynamically so a subclass may provide it.
            handler = getattr(self, 'rString', None)
            if handler is None:
                print('no handler defined for STRING messages')
            else:
                handler(*message.args)

        elif cmd == Message.RUN_COVARS:
            print('processing message RUN_COVARS')
            self._processes.add(ProcessCovar(self._loop))

        elif cmd == Message.KILL_ALL:
            print('processing message KILL_ALL')
            self.kill_all()

        elif cmd == Message.RUN_CLUSTERS:
            print('processing message RUN_CLUSTERS')
            self._processes.add(ProcessClusters(self._loop))

        elif cmd == Message.RUN_REGIMES:
            print('processing message RUN_REGIMES')
            self._processes.add(ProcessRegimes(self._loop))

        elif cmd == Message.FILE:
            print('processing message FILE')
            self._processes.add(ProcessFile(self._loop, message.args))

        elif cmd == Message.GRAPH:
            print('processing message GRAPH')
            self._processes.add(ProcessGraphs(self._loop, message.args))
             
def clean_exit():
    """Announce the save on stdout, then persist the shared ObjectContainer."""
    print('saving ...')
    container = ObjectContainer.getInstance()
    container.save()
            
            
           
if __name__ == "__main__":
    # clean_exit saves the object container at interpreter exit, on both the
    # normal and the error path, so no extra save is needed below.
    atexit.register(clean_exit)
    s = Server()

    try:
        s.run()
    except KeyboardInterrupt:
        # Ctrl-C is the usual way to stop the server; it is not an Exception
        # subclass, so it needs its own handler to shut down quietly.
        print('interrupted')
    except Exception as e:
        print('exception ',e)
    finally:
        # Always release the listening socket and the event loop, however
        # run() ended; a failure while closing must not mask the exit path.
        try:
            s.close()
        except Exception as e:
            print('exception ',e)