from config_manager import ConfigReader
from distribs import Distributions
import pandas
import numpy as np
from objects import ObjectContainer
import operator
import datetime
from series_c import *
from concurrent.futures import ThreadPoolExecutor,ProcessPoolExecutor
import queue
from time import time
import dill

class RandomGenerator:
    """Pre-sampled source of regime transitions and per-regime return vectors.

    ``generate(n)`` draws, for every regime ``x`` (row of ``transitions``),
    ``n`` successor regimes and, when the regime has real parameters, ``n``
    multivariate-normal vectors. ``get_next`` and ``get_vector`` then replay
    those samples cyclically through a shared cursor.
    """

    # Cursor into the pre-drawn samples. This is a *class* attribute, so it
    # is shared by every instance, and it is incremented without any lock.
    _lastx = 0

    def __init__(self, transitions, covariances):
        """Store the model parameters; call :meth:`generate` before sampling.

        :param transitions: 2-D array; row ``x`` is the probability vector
            used to draw successors of regime ``x``.
        :param covariances: mapping ``regime -> (mu, sigma)``. An entry whose
            ``mu`` is a plain ``int`` is treated as a placeholder with no
            return vectors.
        """
        self._transitions = transitions
        self._covariances = covariances
        self._tr = {}    # regime -> array of n pre-drawn successor regimes
        self._vect = {}  # regime -> (n, dim) array of pre-drawn return vectors

    def generate(self, n):
        """Pre-draw ``n`` successors and ``n`` return vectors per regime.

        Regimes missing from ``covariances``, or whose ``mu`` is an ``int``
        placeholder, get successor samples but no return vectors.
        """
        self._n = n
        # Start from empty tables so nothing from an earlier call survives.
        self._tr = {}
        self._vect = {}
        for x in range(self._transitions.shape[0]):
            probs = self._transitions[x, :]
            self._tr[x] = np.random.choice(len(probs), p=probs, size=n)
            if x not in self._covariances:
                print(x, 'not in covariances')
                # No parameters: treat like a placeholder regime, so
                # get_next will never land on it.
                continue
            mu, sigma = self._covariances[x]
            if type(mu) is not int:
                self._vect[x] = np.random.multivariate_normal(mu, sigma, size=n)

    def get_next(self, tran):
        """Return the next pre-drawn successor of ``tran`` that has vectors.

        Samples without return vectors are skipped.

        :raises ValueError: if none of the ``n`` pre-drawn successors of
            ``tran`` has return vectors. Without this check the loop would
            never terminate.
        """
        successors = self._tr[tran]
        # n consecutive cursor values visit every sample index once, so
        # n misses in a row mean no valid successor exists. This assumes a
        # single caller; concurrent callers share the cursor.
        for _ in range(self._n):
            RandomGenerator._lastx += 1
            i = RandomGenerator._lastx % self._n
            new = successors[i]
            if new in self._vect:
                return new
        raise ValueError(
            f"no pre-drawn successor of regime {tran} has return vectors")

    def get_vector(self, rid):
        """Return the next pre-drawn return vector for regime ``rid``."""
        RandomGenerator._lastx += 1
        i = RandomGenerator._lastx % self._n
        return self._vect[rid][i]
        
        
class MCMC:
    """Monte-Carlo simulation of regime paths and their log-return vectors.

    Each path starts from ``serie.getR(datefrom)`` and the first row of
    ``serie.values(datefrom, dateto)``. It then steps once per date in
    ``serie.get_dates(datefrom, dateto)``, drawing the next regime and its
    vector from a pre-sampled ``RandomGenerator``.

    Results live in ``self._regimes`` (npaths x nsteps, int) and
    ``self._logret`` (npaths x nsteps x ncols, float).
    """

    def __init__(self, conf, serie, nsamples=40000, nthreads=12):
        """Build the simulator and pre-draw the random samples.

        :param conf: configuration; must provide ``mcmc_simul_npaths()``.
        :param serie: series object providing ``getR``, ``values``,
            ``get_dates``, ``_values`` and ``_covars._covariances``.
        :param nsamples: samples pre-drawn per regime by the generator
            (defaults to the previously hard-coded 40000).
        :param nthreads: worker count used by :meth:`start_threads`
            (defaults to the previously hard-coded 12).
        """
        import threading  # local import: the module header does not import it

        self._serie = serie
        self._paths = None
        self._conf = conf
        self._transitions = ObjectContainer.getInstance().get('transitions')
        self._q = queue.Queue()           # pending (x, datefrom, dateto) jobs
        self._run = threading.Event()     # released by start_threads
        self._threads = []
        self._Nthreads = nthreads
        self._complete = threading.Event()
        self._processed = queue.Queue()   # indices of finished paths
        self._generator = RandomGenerator(self._transitions, self._serie._covars._covariances)
        self._generator.generate(nsamples)

    def _thread_run(self):
        """Worker loop: simulate jobs from ``self._q`` until it is empty.

        Waits for ``self._run``, then for each ``(x, datefrom, dateto)`` job
        calls :meth:`run_path` and records ``x`` in ``self._processed``.
        Sets ``self._complete`` once ``self._npaths`` paths are processed.
        It expects ``_npaths``, ``_Taxis``, ``_regimes`` and ``_logret`` to be
        prepared already, as :meth:`run` does.

        NOTE(review): RandomGenerator advances an unsynchronised class-level
        cursor, so concurrent workers may interleave or reuse samples.
        """
        self._run.wait()
        while not self._complete.is_set():
            try:
                x, dt1, dt2 = self._q.get_nowait()
            except queue.Empty:
                if self._processed.qsize() == self._npaths:
                    self._complete.set()
                return
            # run_path writes its own results and returns None.
            self.run_path(x, dt1, dt2)
            self._processed.put(x)
            if self._processed.qsize() == self._npaths:
                self._complete.set()

    def start_threads(self):
        """Start ``self._Nthreads`` workers running :meth:`_thread_run`, then release them."""
        import threading  # local import: the module header does not import it

        for _ in range(self._Nthreads):
            worker = threading.Thread(target=self._thread_run)
            worker.start()
            self._threads.append(worker)
        self._run.set()

    def random_r(self, probs):
        """Draw a regime from ``probs``, redrawing while its ``sigma`` is an int.

        Returns ``(regime, mu, sigma)``. This is not called by :meth:`run_path`;
        that call is commented out there.
        """
        tr = True
        r = -1
        while tr:
            r = np.random.choice(len(probs), p=probs)
            mu, sigma = self._serie._covars._covariances[r]
            tr = type(sigma) is int
        return r, mu, sigma

    def run_path(self, x, datefrom, dateto):
        """Simulate path ``x`` and store it in row ``x`` of the result arrays.

        Requires :meth:`run` (or equivalent setup) to have populated
        ``self._Taxis``, ``self._regimes`` and ``self._logret``.
        """
        nsteps = self._Taxis.shape[0]
        regimes = np.zeros(self._Taxis.shape, dtype=int)
        regimes[0] = self._serie.getR(datefrom)
        slogrets = self._serie.values(datefrom, dateto)
        # Size by the time axis, like `regimes` and the destination row, so
        # a slice length that differs from the date axis cannot overflow
        # or mis-broadcast.
        logret = np.zeros((nsteps,) + slogrets.shape[1:])
        logret[0] = slogrets[0, :]

        for ip in range(1, nsteps):
            rnext = self._generator.get_next(regimes[ip - 1])
            regimes[ip] = rnext
            logret[ip] = self._generator.get_vector(rnext)

        self._regimes[x, :] = regimes
        self._logret[x, :] = logret

    def run(self, datefrom, dateto):
        """Simulate ``conf.mcmc_simul_npaths()`` paths sequentially.

        Allocates ``self._regimes`` (npaths x nsteps, int) and
        ``self._logret`` (npaths x nsteps x ``serie._values.shape[1]``). Here
        nsteps is the length of ``serie.get_dates(datefrom, dateto)``. Both
        arrays are filled via :meth:`run_path`. Returns None.
        """
        self._npaths = self._conf.mcmc_simul_npaths()
        self._Taxis = self._serie.get_dates(datefrom, dateto)
        nsteps = self._Taxis.shape[0]
        self._regimes = np.zeros((self._npaths, nsteps), dtype=int)
        self._logret = np.zeros((self._npaths, nsteps, self._serie._values.shape[1]))
        for x in range(self._npaths):
            self.run_path(x, datefrom, dateto)
   
if __name__ == "__main__":
    def _lap(label, since):
        """Print the seconds elapsed since ``since`` and return the new mark."""
        now = time()
        print(label, now - since)
        return now

    # Each stage is timed from the end of the previous one.
    mark = time()
    conf = ConfigReader('regimes.ini')

    ts = Series(conf)
    mark = _lap('build ts', mark)

    ts.init()
    mark = _lap('init ts', mark)

    first_day = np.datetime64('2007-09-01')
    last_day = np.datetime64('2008-03-01')
    datefrom = ts.closest_date(first_day)
    dateto = ts.closest_date(last_day)
    ts.calculate(None, dateto)
    mark = _lap('calculate ts', mark)

    mc = MCMC(conf, ts)
    mark = _lap('build MC', mark)

    mc.run(datefrom, dateto)
    _lap('run MC', mark)

    print(mc._regimes)
    