#!/usr/bin/python3  -O
import sys
from datetime import datetime
from distribs import DistList
from config_manager import ConfigReader
import dill
import os.path
from utils import ProgressBar
import numpy as np
from math import *
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from sklearn import cluster, covariance, manifold



class Clusters:
    """Cluster a list of Gaussian distributions by their Hellinger distances.

    Typical use: cal_affinity_matrix() -> compute_clusters() -> graph_clusters().
    """

    def __init__(self,conf):
        """Load the distribution list named by the configuration.

        Args:
            conf: configuration object; this class calls clusterN(),
                mastercovlist(), clspectral_neighbors(), claffinity(),
                cln_neighbors(), clrandom_states(), cleigen_solver() and
                smooth_hellinger() on it.
        """
        self._conf = conf
        self._NCLUSTERS = conf.clusterN()
        self._THRESHOLD = 1  # not read anywhere in this class
        distfile = conf.mastercovlist()
        # NOTE: dill.load can execute arbitrary code; distfile must be trusted.
        with open(distfile,'rb') as f:
            self._distlist = dill.load(f)

    def spectral_centroids(self):
        """Pick one representative point per cluster label.

        For each label 0..max(label), the representative is the member of that
        cluster closest (squared Euclidean distance in the embedding) to the
        cluster's mean. self._centroids becomes a list indexed by label of
        tuples (embedding coordinates, point index, cluster size,
        covs[point index]).

        Raises:
            ValueError: if some label in 0..max(label) has no members.
        """
        covs = self._distlist._covlist
        self._centroids = []
        for label in range(self._clabels.max() + 1):
            members = np.flatnonzero(self._clabels == label)
            if members.size == 0:
                raise ValueError('cluster %d has no members' % label)
            points = self._embedding[:, members]
            mu = points.mean(axis=1, keepdims=True)
            # Search only this cluster's members: the point nearest the mean
            # over all points may belong to another (e.g. non-convex) cluster.
            nearest = members[np.argmin(np.square(points - mu).sum(axis=0))]
            self._centroids.append(
                (self._embedding[:, nearest], nearest, members.size, covs[nearest]))

    def graph_clusters(self):
        """Scatter-plot the 2-D embedding coloured by cluster and save it.

        Each cluster's representative point (see spectral_centroids) is drawn
        larger and annotated with its label; the figure is written via
        plt.savefig('clusters').
        """
        embedding = self._embedding
        centroids = self._centroids
        labels = self._clabels
        n_labels = labels.max() + 1

        # One more colour than there are labels, so the colormap's upper
        # endpoint is never used (presumably to avoid a near-white colour).
        colormap = [cm.gnuplot2(v) for v in np.linspace(0., 1., n_labels + 1)]
        point_colors = [colormap[lab] for lab in labels]

        fig = plt.figure()
        try:
            plt.scatter(embedding[0,:],embedding[1,:],s=10,c=point_colors,alpha=0.8)

            for label in range(n_labels):
                labelx, labely = centroids[label][0][0], centroids[label][0][1]
                plt.annotate(str(label), xy=(labelx, labely),
                             xytext=(labelx + 0.004, labely + 0.01),
                             arrowprops=dict(arrowstyle="->",connectionstyle="arc3"),size=6)

            for i, z in enumerate(centroids):
                plt.scatter(z[0][0],z[0][1],s=50,c=[colormap[i]],alpha=1.0)

            plt.title('clusters')
            plt.savefig('clusters')
        finally:
            plt.close(fig)

    def hellinger(self,n,x,y,smooth,**kwargs):
        """Return the Hellinger distance between two multivariate Gaussians.

        Args:
            n: dimensionality of the distributions.
            x, y: n*(n+1) values each, reshaped to (n, n+1); column 0 is the
                mean vector and columns 1..n the covariance matrix.
            smooth: if true return atanh(H^2) instead of H.
            **kwargs: ignored.

        Returns:
            H = sqrt(H^2) (or atanh(H^2) when smooth), where
            H^2 = 1 - det(S1)^1/4 det(S2)^1/4 / det(S)^1/2
                      * exp(-1/8 (mu1-mu2)^T S^-1 (mu1-mu2)),  S = (S1+S2)/2.
            H^2 is clamped to [0, 1) for atanh and [0, 1] for sqrt so that
            rounding/underflow cannot raise a math domain error.
        """
        x1 = x.reshape(n,n+1)
        x2 = y.reshape(n,n+1)
        mu1,sigma1=x1[:,0],x1[:,1:]
        mu2,sigma2=x2[:,0],x2[:,1:]

        sign1, ld1 = np.linalg.slogdet(sigma1)
        sign2, ld2 = np.linalg.slogdet(sigma2)
        dd = 0.5*(sigma1+sigma2)
        signdd, ldd = np.linalg.slogdet(dd)

        # A non-positive determinant means the matrix is not a valid
        # covariance (slogdet then only gives log|det|). Check each sign
        # directly: dividing by signdd fails when it is 0, and a product of
        # signs hides two negative determinants.
        if min(sign1, sign2, signdd) <= 0:
            print ('negative')

        diff = mu1 - mu2
        # solve() is numerically preferable to forming inv(dd) explicitly.
        e = -0.125*diff.dot(np.linalg.solve(dd, diff))
        # det(S1)^.25 * det(S2)^.25 / det(S)^.5, computed in log space.
        d = exp(.25*(ld1+ld2-2*ldd))
        es = 1-d*exp(e)
        # Rounding can push es slightly below 0 for near-identical inputs.
        es = max(es, 0.0)
        if smooth:
            # exp(e) underflows to 0 for far-apart means, giving es == 1.0 and
            # atanh(1) -> domain error; keep es strictly below 1.
            return atanh(min(es, float(np.nextafter(1.0, 0.0))))
        return sqrt(es)

    def compute_clusters(self):
        """Embed the Hellinger matrix in 2-D and spectrally cluster the points.

        Requires cal_affinity_matrix() first. Sets self._embedding (shape
        (2, N)), self._agg, self._clabels and self._centroids, and writes the
        labels and the fitted estimator to 'clabels.plk' and
        'clusterengine.plk'.
        """
        embed_neighbors = self._conf.clspectral_neighbors()
        node_position_model = manifold.SpectralEmbedding(
            n_components=2,n_neighbors=embed_neighbors,random_state=42)
        # Columns of the Hellinger matrix serve as feature vectors; the result
        # is transposed so each column of self._embedding is one point.
        self._embedding = node_position_model.fit_transform(self._hellingerd.T).T
        X = self._embedding.astype('float64').T
        n_clusters = self._conf.clusterN()
        affinity = self._conf.claffinity()
        cluster_neighbors = self._conf.cln_neighbors()
        random_state = self._conf.clrandom_states()
        eigen_solver = self._conf.cleigen_solver()
        self._agg = cluster.SpectralClustering(
            n_clusters=n_clusters,affinity=affinity,eigen_solver=eigen_solver,
            n_neighbors=cluster_neighbors,random_state=random_state).fit(X)
        self._clabels=self._agg.labels_
        with open('clabels.plk','wb') as f:
            dill.dump(self._clabels,f)
        with open('clusterengine.plk','wb') as f:
            dill.dump(self._agg,f)
        self.spectral_centroids()

    def cal_affinity_matrix(self):
        """Compute (or load cached) pairwise Hellinger distances into self._hellingerd.

        If 'hellinger.plk' exists it is loaded and nothing is recomputed (the
        cache is not validated against the current distribution list).
        Otherwise every pair (u, v) of indices in the distribution list's
        _covindex is evaluated once, mirrored to (v, u), and the matrix is
        written to 'hellinger.plk'.
        """
        cache = 'hellinger.plk'
        if os.path.isfile(cache):
            # NOTE: dill.load can execute arbitrary code; the cache must be trusted.
            with open(cache,'rb') as f:
                self._hellingerd = dill.load(f)
            return
        cind = self._distlist._covindex
        covs = self._distlist._covlist
        count = covs.shape[0]
        n = covs.shape[1]
        smooth = self._conf.smooth_hellinger()  # loop-invariant
        hellingerd = np.zeros((count, count))
        # Track which entries were filled explicitly; a value of 0 is a valid
        # distance and cannot serve as the "not yet computed" marker.
        computed = np.zeros((count, count), dtype=bool)
        # NOTE(review): total assumes cind spans all covs.shape[0] entries.
        br = ProgressBar(count*count, prefix='Hellinger Affinities', suffix='Completed')
        for u in cind:
            for v in cind:
                br.iterate()
                if computed[v,u]:
                    hellingerd[u,v] = hellingerd[v,u]
                else:
                    hellingerd[u,v] = self.hellinger(n,covs[u],covs[v],smooth)
                computed[u,v] = True

        self._hellingerd = hellingerd
        with open(cache,'wb') as f:
            dill.dump(hellingerd,f)
       
        
        
if __name__ == "__main__":
    # Full pipeline: pairwise Hellinger distances -> spectral clustering -> plot.
    engine = Clusters(ConfigReader('regimes.ini'))
    engine.cal_affinity_matrix()
    engine.compute_clusters()
    engine.graph_clusters()