#!/usr/bin/python3  -O
import sys, getopt
from datetime import datetime
from distribs import DistList
from config_manager import ConfigReader
import dill
import os.path
from utils import ProgressBar
import numpy as np
from math import *
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from sklearn import cluster, covariance, manifold
from tabulate import tabulate
import makeclusters
from distribs import *
import pandas as pd
import io

class Graphs:
    """Matplotlib renderings of cluster, regime and per-ticker data.

    Drawing methods render into pyplot figures and return an ``io.BytesIO``
    holding the PNG image, rewound to position 0 so callers can ``read()`` it.
    """

    def __init__(self, conf):
        # conf: configuration object; loaddist_ calls its dataref(), clmap()
        # and outreg() methods.
        self._conf = conf

    @staticmethod
    def _png_buffer():
        """Save the current pyplot figure as PNG into a rewound BytesIO."""
        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        # Rewind: without this the stream sits at EOF and buf.read() is b''.
        buf.seek(0)
        return buf

    @staticmethod
    def _date_key(tt):
        """Format a datetime-like as the string key used by ``_rdates``.

        Years before 2000 use an unpadded two-digit year ('%-y'), later years
        a zero-padded one ('%y'), e.g. '12/31/99' and '3/7/05'.
        NOTE(review): the '%-' flags are glibc/BSD strftime extensions and
        are not supported on Windows.
        """
        if tt.year >= 2000:
            return tt.strftime('%-m/%-d/%y')
        return tt.strftime('%-m/%-d/%-y')

    @staticmethod
    def _runs(keys):
        """Split *keys* into maximal runs of equal consecutive values.

        Returns a list of ``(value, mid)`` pairs, where ``mid`` is the index
        of the run's middle element (``start + (length - 1) // 2``).
        Returns ``[]`` for empty input.
        """
        runs = []
        start = 0
        for i in range(1, len(keys) + 1):
            # Close the current run at the end of input or on a label change.
            if i == len(keys) or keys[i] != keys[start]:
                runs.append((keys[start], start + (i - 1 - start) // 2))
                start = i
        return runs

    @staticmethod
    def _text_positions(mids, n, boxsize):
        """Choose an x index for each annotation text so labels spread out.

        Each label advances at least *boxsize* positions past the previous
        one unless its anchor ``mid`` is already further right. Results are
        clamped to ``n - 1``, the last valid index.
        """
        positions = []
        xi = 0
        for mid in mids:
            xi = xi + boxsize if mid <= xi + boxsize else mid
            xi = min(xi, n - 1)
            positions.append(xi)
        return positions

    def graph_(self, embedding, centroids, labels, t='clusters'):
        """Scatter-plot an embedding coloured by label and mark the centroids.

        :param embedding: indexed as ``embedding[0, :]`` (x) and
            ``embedding[1, :]`` (y); presumably a 2 x N array.
        :param centroids: sequence where ``centroids[i][0]`` holds the (x, y)
            of centroid i.
        :param labels: non-negative integer label per point, used as an index
            into the colormap.
        :param t: plot title.
        :returns: rewound ``io.BytesIO`` with the PNG. The same image is also
            written to ``graph.png`` in the current working directory.
        """
        top = max(labels)
        colormap = [cm.gnuplot2(v) for v in np.linspace(0., 1., top + 2)]
        point_colors = [colormap[i] for i in labels]

        plt.figure()
        plt.scatter(embedding[0, :], embedding[1, :], s=40, c=point_colors, alpha=1.0)

        # One numbered arrow per label value, anchored at that centroid.
        for col in range(top + 1):
            labelx, labely = centroids[col][0][0], centroids[col][0][1]
            plt.annotate(str(col), xy=(labelx, labely),
                         xytext=(labelx + 0.001, labely + 0.01),
                         arrowprops=dict(arrowstyle="->", connectionstyle="arc3"),
                         size=6)

        # NOTE(review): assumes len(centroids) <= max(labels) + 2, otherwise
        # colormap[i] raises IndexError.
        for i, z in enumerate(centroids):
            plt.scatter(z[0][0], z[0][1], s=50, c=[colormap[i]], alpha=1.0)

        plt.title(t)
        buf = self._png_buffer()
        plt.savefig('graph.png', format='png')
        return buf

    def draw_clusters(self):
        """Plot ``self._clusters`` (embedding, centroids, labels) via graph_.

        NOTE(review): ``_clusters`` is not assigned anywhere in this class;
        presumably the caller sets it before calling.
        """
        cl = self._clusters
        return self.graph_(cl._embedding, cl._centroids, cl._clabels, 'clusters')

    def draw_regimes(self):
        """Plot regimes over the underlying cluster embedding via graph_.

        NOTE(review): ``_regimes`` is not assigned anywhere in this class;
        presumably the caller sets it before calling.
        """
        r = self._regimes
        return self.graph_(r._clusters._embedding, r._rcentroids, r._regimes, 'regimes')

    def loaddist_(self, ticker):
        """Load *ticker*'s series and the pickled regime object from config.

        Sets ``_cmap``, ``_ind``, ``_x`` (flattened ``_logret``), ``_title``
        and ``_reg``.
        """
        mfile = self._conf.dataref()
        mfields = [ticker]
        # SECURITY NOTE(review): eval() executes whatever expression the
        # config supplies; only safe if the config file is trusted.
        self._cmap = eval(self._conf.clmap())
        D = Distributions(varfact=1.1, eps=.2, window=90, overlap=60)
        D.load(mfile, mfields)
        self._ind = D._index
        self._x = D._logret.reshape(-1)
        self._title = ticker
        # SECURITY NOTE(review): dill (pickle) can execute code on load; the
        # path comes from config and must point at a trusted file.
        with open(self._conf.outreg(), 'rb') as fh:
            self._reg = dill.load(fh)

    def annotate(self, dt, Y):
        """Label each run of consecutive equal regimes on the current figure.

        :param dt: datetime-like values (``.year``/``.strftime`` are used).
            Each formatted date must be a key of ``self._reg._rdates``,
            otherwise KeyError is raised.
        :param Y: y values aligned with *dt*.
        :returns: rewound ``io.BytesIO`` with the PNG of the current figure.
        """
        reg = self._reg
        n = len(dt)
        keys = [reg._rdates[self._date_key(tt)] for tt in dt]
        # Minimum horizontal spacing between label texts; guard against an
        # empty regime list, which previously raised ZeroDivisionError.
        boxsize = (n - 1) // max(len(reg._regimes), 1) - 1

        runs = self._runs(keys)
        positions = self._text_positions([mid for _, mid in runs], n, boxsize)
        # Every run is labelled, including the final one, which the old
        # change-triggered loop skipped.
        for cx, ((label, mid), xi) in enumerate(zip(runs, positions)):
            # Alternate text above/below the curve so neighbours don't overlap.
            offset = 0.4 if cx % 2 == 0 else -0.4
            plt.annotate(str(label), xy=(dt[mid], Y[mid]),
                         xytext=(dt[xi], Y[xi] + offset),
                         arrowprops=dict(arrowstyle="->", connectionstyle="arc3"),
                         size=9)
        return self._png_buffer()

    def graph_data(self):
        """Scatter the loaded series, coloured by regime, in a new figure.

        Only index entries present in ``self._reg._rdates`` are plotted.
        Adds a boxed title at the first point.

        :returns: ``(xt, Y)``, the plotted timestamps and ``np.log`` values.
        :raises ValueError: if no index entry has a regime assigned.
        """
        reg = self._reg
        self._cmap = cm.gnuplot2
        colors = [self._cmap(v) for v in np.linspace(0., 1., reg._regimes.max() + 2)]
        bbox_props = dict(boxstyle="square,pad=0.3", fc='white', ec='black', lw='2')
        colorm = []
        xt = []
        Y = []
        for dx, y in zip(self._ind, self._x):
            if dx in reg._rdates:
                colorm.append(colors[reg._rdates[dx]])
                # Was `pandas.to_datetime`; the module is imported as `pd`.
                xt.append(pd.to_datetime(dx))
                # NOTE(review): _x comes from D._logret; log of non-positive
                # values gives NaN. Confirm a second log is intended.
                Y.append(np.log(y))
        if not xt:
            raise ValueError('no dates of the loaded series have a regime assigned')
        plt.figure()
        plt.text(xt[0], Y[0], self._title, ha="center", va="center", bbox=bbox_props)
        plt.scatter(xt, Y, c=colorm, s=2)
        return xt, Y

def printusage():
    """Write the one-line command-line usage summary to stdout."""
    usage = 'usage: graphs.py -w <what=[clusters|regimes|ticker]> -i <file>'
    print(usage)

def atexit():
    """Show the usage line, then terminate with exit status 2."""
    printusage()
    raise SystemExit(2)
 
def action(g, func, arg):
    """Dispatch a drawing request to a Graphs instance.

    :param g: Graphs-like object.
    :param func: one of 'regimes', 'clusters' or 'ticker'.
    :param arg: currently unused. Data loading for 'ticker' is expected to
        have happened beforehand.
    :returns: whatever the chosen drawing method returns; None for an
        unrecognised *func*.
    """
    if func == 'regimes':
        return g.draw_regimes()
    if func == 'clusters':
        return g.draw_clusters()
    if func != 'ticker':
        return None
    dates, values = g.graph_data()
    return g.annotate(dates, values)

def init_graph():
    """Build a Graphs instance configured from 'regimes.ini'."""
    return Graphs(ConfigReader('regimes.ini'))

if __name__ == "__main__":
    # CLI entry point: graphs.py -w <clusters|regimes|ticker> -i <arg>.
    # The -w values are passed through unchanged because they are exactly the
    # names action() dispatches on. (Previously they were mapped to
    # 'draw_*' names action() never matched, and action() was called with
    # two of its three arguments.)
    try:
        opts, _ = getopt.getopt(sys.argv[1:], 'w:i:')
    except getopt.GetoptError:
        atexit()

    ARG = None
    FUNC = None
    for opt, arg in opts:
        if opt == '-w':
            if arg not in ('clusters', 'regimes', 'ticker'):
                atexit()
            FUNC = arg
        elif opt == '-i':
            ARG = arg
    if ARG is None or FUNC is None:
        atexit()

    g = init_graph()
    if FUNC == 'ticker':
        # In ticker mode -i is the ticker symbol. action() does not load data
        # itself, and graph_data() needs the state set by loaddist_.
        g.loaddist_(ARG)
    # NOTE(review): for 'clusters'/'regimes', Graphs._clusters/_regimes are
    # not set anywhere in this file; confirm where they should be loaded.
    action(g, FUNC, ARG)
    plt.show()

        
