#!/usr/bin/python3  -O
import sys
from datetime import datetime
from distribs import DistList
from config_manager import ConfigReader
import dill
import os.path
from utils import ProgressBar
import numpy as np
from math import *
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from sklearn import cluster, covariance, manifold
from tabulate import tabulate



class Clusters:
    """Spectral embedding and clustering of the distributions in a pickled DistList.

    Call order used by the script entry point: ``cal_affinity_matrix()`` first
    (sets ``self._hellingerd``), then ``compute_clusters()`` (sets
    ``self._embedding``, ``self._clabels`` and ``self._centroids``).
    ``spectral_centroids()`` and ``graph_()`` both read ``self._embedding``.
    """

    def __init__(self, conf):
        """Keep the configuration and load the distribution list it points to.

        conf -- configuration object; ``clusterN()`` and ``mastercovlist()``
                are read here, other getters are read by the methods below.
        """
        self._conf = conf
        self._NCLUSTERS = conf.clusterN()
        self._THRESHOLD = 1
        distfile = conf.mastercovlist()
        # NOTE: dill (like pickle) can run arbitrary code while loading; the
        # configured file must come from a trusted source.
        with open(distfile, 'rb') as f:
            self._distlist = dill.load(f)

    def spectral_centroids(self, labels):
        """Return one representative embedded point per label 0..labels.max().

        For every label the mean of its member columns of ``self._embedding``
        is computed and the embedded point with the smallest squared Euclidean
        distance to that mean becomes the representative.

        labels -- integer numpy array, one entry per column of the embedding.

        Returns a list indexed by label; each item is the tuple
        ``(coordinates, point_index, member_count, covs[point_index])`` where
        ``covs`` is ``self._distlist._covlist``.
        """
        embedding = self._embedding
        covs = self._distlist._covlist
        centroids = []
        for i in range(labels.max() + 1):
            members = embedding[:, labels == i]
            # reshape(-1, 1) keeps the mean a column vector whatever the number
            # of embedding components is (previously hard-coded to 2).
            mu = members.mean(axis=1).reshape(-1, 1)
            # find the closest point to mu
            # NOTE(review): the search runs over all points, not only the
            # members of label i -- confirm this is intended.
            dist = np.square(embedding - mu).sum(axis=0)
            i2 = np.argmin(dist)
            centroids.append((embedding[:, i2], i2, members.shape[1], covs[i2]))
        return centroids

    def graph_(self, centroids, labels, t='clusters'):
        """Scatter-plot the embedding coloured by label and mark the centroids.

        centroids -- list indexed by label, items shaped like the tuples
                     returned by ``spectral_centroids`` (only item[0], the
                     coordinates, is used here).
        labels    -- integer label of every embedded point.
        t         -- plot title.

        The figure is saved under the name 'clusters' and then shown.
        """
        embedding = self._embedding
        n_labels = max(labels) + 1

        # n_labels + 1 evenly spaced samples of the gnuplot2 colormap.
        colormap = [cm.gnuplot2(i) for i in np.linspace(0., 1., n_labels + 1)]
        point_colors = [colormap[i] for i in labels]

        fig = plt.figure()
        plt.scatter(embedding[0, :], embedding[1, :], s=10, c=point_colors, alpha=0.8)

        # Annotate each centroid with its label number, slightly offset.
        for col in range(n_labels):
            lxy = centroids[col][0]
            labelx, labely = lxy[0], lxy[1]
            plt.annotate(str(col), xy=(labelx, labely),
                         xytext=(labelx + 0.004, labely + 0.01),
                         arrowprops=dict(arrowstyle="->", connectionstyle="arc3"),
                         size=6)

        # Draw the centroids themselves as larger, opaque markers.
        for i, z in enumerate(centroids):
            plt.scatter(z[0][0], z[0][1], s=50, c=[colormap[i]], alpha=1.0)

        plt.title(t)
        # Save BEFORE show(): with an interactive backend the figure is gone
        # once the window is closed and a later savefig writes an empty image.
        fig.savefig('clusters')
        plt.show()
        plt.close(fig)

    def compute_clusters(self):
        """Embed the affinity matrix in 2-D and run spectral clustering on it.

        Requires ``self._hellingerd`` (see ``cal_affinity_matrix``).  Sets
        ``self._embedding`` (components x points), ``self._agg`` (the fitted
        SpectralClustering), ``self._clabels`` and ``self._centroids``, and
        dumps the labels and the fitted estimator to 'clabels.plk' and
        'clusterengine.plk'.
        """
        embed_neighbors = self._conf.clspectral_neighbors()
        node_position_model = manifold.SpectralEmbedding(
            n_components=2, n_neighbors=embed_neighbors, random_state=42)
        self._embedding = node_position_model.fit_transform(self._hellingerd.T).T

        # The clustering estimator expects one row per sample.
        X = self._embedding.astype('float64').T
        N = self._conf.clusterN()
        aff = self._conf.claffinity()
        cluster_neighbors = self._conf.cln_neighbors()
        rs = self._conf.clrandom_states()
        es = self._conf.cleigen_solver()
        self._agg = cluster.SpectralClustering(
            n_clusters=N, affinity=aff, eigen_solver=es,
            n_neighbors=cluster_neighbors, random_state=rs).fit(X)
        self._clabels = self._agg.labels_

        with open('clabels.plk', 'wb') as f:
            dill.dump(self._clabels, f)
        with open('clusterengine.plk', 'wb') as f:
            dill.dump(self._agg, f)
        self._centroids = self.spectral_centroids(self._clabels)

    def cal_affinity_matrix(self):
        """Load or compute the pairwise matrix of ``DistList.hellinger`` values.

        If 'hellinger.plk' exists in the working directory it is loaded and
        used as is; otherwise the matrix is computed for every pair of indices
        in ``self._distlist._covindex`` and written to that file.  The result
        is stored in ``self._hellingerd``.
        """
        if os.path.isfile('hellinger.plk'):
            # NOTE: dill can run arbitrary code while loading; the cache file
            # must be trusted.  The handle is closed by the with-block.
            with open('hellinger.plk', 'rb') as f:
                self._hellingerd = dill.load(f)
            return

        cind = self._distlist._covindex
        covs = self._distlist._covlist
        size = covs.shape[0]
        n = covs.shape[1]
        smooth = self._conf.smooth_hellinger()
        hellingerd = np.zeros((size, size))
        br = ProgressBar(size * size, prefix='Hellinger Affinities', suffix='Completed')
        for u in cind:
            for v in cind:
                br.iterate()
                mirrored = hellingerd[v, u]
                if mirrored == 0:
                    # Mirror cell still zero: not computed yet (or the value
                    # really is zero), so compute this cell.
                    hellingerd[u, v] = self._distlist.hellinger(n, covs[u], covs[v], smooth)
                else:
                    # Reuse the value already stored in the mirrored cell.
                    hellingerd[u, v] = mirrored

        self._hellingerd = hellingerd
        with open('hellinger.plk', 'wb') as f:
            dill.dump(hellingerd, f)
       
 
class Regimes:
    """Regime labelling built on top of a fitted ``Clusters`` object.

    Call order used by the script entry point: ``compute_inter()`` (sets
    ``self._dlabels``), ``label_regimes()`` (sets ``self._regimes``,
    ``self._rdates``, ``self._rcentroids`` ...), then optionally
    ``filter_small()``.
    """

    def __init__(self, clusters):
        """Keep the ``Clusters`` object and read ``reg_threshold()`` from its config."""
        self._clusters = clusters
        self._reg_threshold = clusters._conf.reg_threshold()

    def printout_centroids(self, prefix):
        """Print the size of every regime and a table of centroid-to-centroid values.

        prefix -- text put in front of '#<index>' in the printed names.

        The table cells are ``DistList.hellinger`` evaluated between the
        ``_covlist`` rows of each pair of regime centroids.
        """
        centroids = self._rcentroids
        sm = self._clusters._conf.smooth_hellinger()
        hellinger = self._clusters._distlist.hellinger
        data = self._clusters._distlist._covlist
        n = data.shape[1]

        rows = []
        for i, x in enumerate(centroids):
            print(prefix + '#', i, ':', x[2], 'elements')
            row = [prefix + '#%d' % i]
            row.extend(hellinger(n, data[x[1]], data[y[1]], sm) for y in centroids)
            rows.append(row)

        headers = [prefix + '#%d' % r for r in range(len(centroids))]
        table = tabulate(rows, headers=headers, tablefmt='orgtbl', floatfmt='.3f')
        print(table)

    def merge_regimes(self, s, t):
        """Fold regime ``s`` into regime ``t`` and close the gap in the numbering.

        Every label greater than ``s`` is shifted down by one, so when
        ``t > s`` the merged regime ends up numbered ``t - 1``.  Both
        ``self._regimes`` (array) and ``self._rdates`` (dict) are updated in
        place and stay consistent with each other.
        """
        regimes = self._regimes
        regimes[regimes == s] = t
        # Entries just set to t are shifted too when t > s, which is intended.
        regimes[regimes > s] -= 1

        # Label the target carries after the shift; the dict must use the same
        # final number as the array (it previously kept the stale value t).
        merged_label = t - 1 if t > s else t
        # Only values of existing keys are replaced, which is safe while
        # iterating over the dict.
        for k, x in self._rdates.items():
            if x == s:
                self._rdates[k] = merged_label
            elif x > s:
                self._rdates[k] = x - 1

    def delete_centroid(self, s):
        """Remove the centroid at position ``s``; later centroids move down by one."""
        del self._rcentroids[s]

    def compute_inter(self):
        """Collect, for every entry of ``dist._index``, the set of cluster labels covering it.

        ``dist._ilist`` items are read as ``kitem[0]`` / ``kitem[1]``; a
        position ``x`` is treated as covered by an item when
        ``kitem[0] <= x < kitem[1]``, and the item at list position ``j``
        contributes the cluster label ``_clabels[j]``.
        NOTE(review): the scan assumes the items are ordered so that finished
        ones form a prefix of the remaining list -- confirm against DistList.

        Result is stored in ``self._dlabels`` (index entry -> set of labels);
        entries covered by no item get no key.
        """
        dist = self._clusters._distlist
        rlabels = np.array(self._clusters._clabels)
        dlabels = dict()
        K = 0  # number of leading items already known to be finished
        for x, dt in enumerate(dist._index):
            lk = 0  # items found finished while scanning for this x
            for l, kitem in enumerate(dist._ilist[K:]):
                if x < kitem[0]:
                    # this item (and, given the ordering, the rest) starts later
                    break
                if x >= kitem[1]:
                    lk += 1
                    continue
                dlabels.setdefault(dt, set()).add(rlabels[l + K])
            K += lk
        self._dlabels = dlabels

    def label_regimes(self):
        """Turn each distinct set of cluster labels into a regime number.

        Regime numbers are handed out in order of first appearance while
        walking ``self._dlabels``.  Sets the following attributes:

        _rcounts    -- label tuple -> number of entries carrying it
        _reflabel   -- label tuple -> regime number
        _rdates     -- ``_dlabels`` key -> regime number
        _regimes    -- array with the regime number recorded each time the
                       running entry counter equals the first element of the
                       next pending ``_ilist`` item
        _rcentroids -- ``spectral_centroids`` of ``_regimes``
        """
        self._rcounts = dict()
        self._rdates = dict()
        self._reflabel = dict()
        regimes = []
        windows = list(self._clusters._distlist._ilist)
        pending = 0  # position of the next window not yet matched

        for x, (k, v) in enumerate(self._dlabels.items()):
            key = tuple(sorted(v))
            if key in self._rcounts:
                self._rcounts[key] += 1
            else:
                self._rcounts[key] = 1
                self._reflabel[key] = len(self._reflabel)
            self._rdates[k] = self._reflabel[key]
            # A pointer replaces re-slicing the list on every match.
            if pending < len(windows) and x == windows[pending][0]:
                pending += 1
                regimes.append(self._reflabel[key])

        self._regimes = np.array(regimes)
        self._rcentroids = self._clusters.spectral_centroids(self._regimes)

    def filter_small(self):
        """Merge every regime whose size is not above the threshold into its nearest regime.

        A regime is kept when its centroid's member count (item [2]) is
        greater than ``self._reg_threshold``.  Otherwise ``DistList.vdist`` is
        evaluated from its centroid to the current centroids, the regime is
        merged into the one with the smallest value and its centroid is
        dropped.  Centroids are recomputed from the final labels at the end.
        """
        data = self._clusters._distlist._covlist
        s = 0
        while True:
            if self._rcentroids[s][2] > self._reg_threshold:
                # large enough: move on, stop after the last centroid
                s += 1
                if s >= len(self._rcentroids):
                    break
                continue

            distarget = data[self._rcentroids[s][1]]
            v = self._clusters._distlist.vdist(
                distarget, self._rcentroids, self._clusters._conf.smooth_hellinger())
            # Large sentinel so the regime is not chosen as its own target.
            v[s] = 100.
            t = np.argmin(v)
            self.merge_regimes(s, t)

            if len(self._rcentroids) <= 1:
                # nothing left to merge into
                break
            self.delete_centroid(s)
            if s >= len(self._rcentroids):
                s -= 1

        # re-evaluate centroids
        self._rcentroids = self._clusters.spectral_centroids(self._regimes)
        
        
if __name__ == "__main__":
    # Stage 1: configuration, affinity matrix (cached or computed), clustering.
    conf = ConfigReader('regimes.ini')
    cl = Clusters(conf)
    cl.cal_affinity_matrix()
    cl.compute_clusters()

    # Stage 2: derive regimes from the cluster labels.
    r = Regimes(cl)
    r.compute_inter()
    r.label_regimes()

    # Report the regimes before and after merging the small ones.
    r.printout_centroids('regimes')
    r.filter_small()
    r.printout_centroids('regimes')

    # Plot the final regimes (r._clusters is the Clusters object built above).
    r._clusters.graph_(centroids=r._rcentroids, labels=r._regimes, t='regimes')
    