'''
hxLAS parameter testing
Author: C. D. Armstrong
Year: 2021

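Comment:
Generates input parameters for the two-temperature test case, with temperatures
between 0 and 10 MeV and fluxes between 10^1 and 10^6.

Description:
Draws random two-temperature Boltzmann parameter sets, forward-models the resulting
crystal response, reconstructs the parameters with a Nelder-Mead "best guess"
optimiser, and saves the inputs and results to CSV files.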
'''


import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from scipy.interpolate import interp1d, pchip
from scipy import optimize
import csv
import h5py
from scipy import integrate
import os
from datetime import date
from lmfit.models import ExponentialModel
import time
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
import sys
import datetime
# Ten-colour hex palette for plotting.
# NOTE(review): these values appear to match matplotlib's default 'tab10' cycle;
# the palette is not referenced in the visible code of this file.
new_colors = [
    '#1f77b4',
    '#ff7f0e',
    '#2ca02c',
    '#d62728',
    '#9467bd',
    '#8c564b',
    '#e377c2',
    '#7f7f7f',
    '#bcbd22',
    '#17becf',
]

def generateInputData(dt='bestguess', niter=1000):
    '''
    Example operation routine to compute {niter} random parameter sets.
    Only uses the "best guess" method as an example.
    Generates and saves the input parameters to enable direct comparison with other methods.
    Inputs:
        dt: str, label for saved files; results are written under ./comparison/{dt}/
        niter: int, number of tests
    Outputs (files written):
        ./{dt}_{niter}_input.csv                 randomly drawn [n1, t1, n2, t2] per test
        ./comparison/{dt}/{dt}_{niter}_data.csv  time, merit value and reconstructed parameters
    '''
    # Working energies in MeV: 1001 log-spaced points from 0.01 MeV to 100 MeV
    energies = np.logspace(np.log10(0.01), np.log10(100), 1001)
    temperaturescan = np.geomspace(0.01, 15, 100)

    # Load the response matrix. tm (per-temperature response matrix) is part of the
    # shared reconstruction-routine signature; it was previously an undefined name,
    # which made the timed_bestguess call below raise NameError.
    rm = load_response(energies, file='./response_10.h5')
    tm = singleTemperatureMatrix(temperaturescan, energies, rm)

    # Generate input parameters, one [n1, t1, n2, t2] row per test.
    # reshape keeps a (0, 4) shape when niter == 0.
    input_params = np.asarray([randomstart() for _ in range(niter)],
                              dtype=float).reshape(niter, 4)

    # Save input parameters to file
    df_i = pd.DataFrame(data=input_params, columns=['n1', 't1', 'n2', 't2'])
    df_i.to_csv(f'./{dt}_{niter}_input.csv')

    # Predefine arrays for outputs
    t_out = np.zeros(niter)
    mv_out = np.zeros(niter)
    n1_out = np.zeros(niter)
    t1_out = np.zeros(niter)
    n2_out = np.zeros(niter)
    t2_out = np.zeros(niter)

    # Compute response for the best-guess routine
    for i, p in enumerate(input_params):
        printProgressBar(i, niter, 'complete')
        data = combineTwoResponse(p, energies, rm)  # generate crystal data
        p1, mv, elapsed = timed_bestguess(data, temperaturescan, tm, energies, rm, None)

        # commit to store; p1 is ordered [n1, t1, n2, t2]
        t_out[i] = elapsed
        mv_out[i] = mv
        n1_out[i], t1_out[i], n2_out[i], t2_out[i] = p1
    if niter:
        printProgressBar(niter, niter, 'complete')  # show 100% once finished
    print()  # terminate the progress-bar line before further output

    df = pd.DataFrame.from_dict(
        {
            'Time': t_out,
            'Merit': mv_out,
            'Reconstructed_N1': n1_out,
            'Reconstructed_T1': t1_out,
            'Reconstructed_N2': n2_out,
            'Reconstructed_T2': t2_out,
        }
    )

    # to_csv does not create missing folders, so make sure comparison/{dt} exists
    out_dir = os.path.join('.', 'comparison', dt)
    os.makedirs(out_dir, exist_ok=True)
    df.to_csv(os.path.join(out_dir, f'{dt}_{niter}_data.csv'))
    print('######################')
    print(df)
    print('######################')

### Nelder-Mead Optimiser Routine
def timed_bestguess(data, temperaturescan, tm, energies, rm, x):
    '''
    Computes the best guess optimisation routine for crystal data.
    Times the computation to compare against sparse and fine reconstruction.
    Inputs:
        data: 1D array of crystal values (indexed at 0 and 4, so at least 5 entries)
        temperaturescan: 1D array of temperatures
            (included for uniformity across routines, not used here)
        tm: 2D array of temperature responses
            (included for uniformity across routines, not used here)
        energies: 1D array of energy bins matching rm
        rm: 2D response matrix
        x: not used by this routine (kept for a uniform signature)
    Returns:
        p1: optimised [n1, t1, n2, t2]
        mv: merit value at p1
        elapsed: seconds spent in the optimisation
    '''
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # so it is the appropriate clock for measuring elapsed durations
    start = time.perf_counter()
    # Initial guess: fluxes seeded from crystals 0 and 4, temperatures 0.15 and 1
    p0 = [data[0], 0.15, data[4], 1]
    p1, mv = optimiseGridResults(data, p0, energies, rm, verbose=False)
    return p1, mv, time.perf_counter() - start

def optimiseGridResults(data, p0, energies, rm, verbose=False, maxiter=1000):
    '''
    Refines an initial parameter guess by minimising ``merit`` with Nelder-Mead.
    Inputs:
        data: crystal values to fit
        p0: initial [n1, t1, n2, t2] guess
        energies: energy bins, forwarded to ``merit``
        rm: response matrix, forwarded to ``merit``
        verbose: if True, also return the per-iteration history of best points
        maxiter: (opt) maximum number of Nelder-Mead iterations (default 1000)
    Returns:
        (x, fun) or, when verbose, (x, fun, allvecs)
    '''
    # NOTE(review): ``merit`` is not defined in the visible source of this file;
    # confirm it is provided elsewhere before running.
    guided = optimize.minimize(merit,
                               x0=p0,
                               args=(data, energies, rm),
                               method='Nelder-Mead',
                               options={'maxiter': maxiter,
                                        'disp': False,
                                        # only keep the iterate history when it is returned
                                        'return_all': bool(verbose)})
    if verbose:
        return (guided.x, guided.fun, guided.allvecs)
    return (guided.x, guided.fun)

### CONSTRUCT INDIVIDUAL RESPONSE
def boltz(p, en):
    '''
    Boltzmann spectrum (flux / T) * exp(-E / T) evaluated on the given energies
    Inputs:
        p: sequence whose first two entries are [flux, temperature]
        en: energy bins (array-like supporting negation/division)
    Returns:
        Boltzmann distribution array like en
    '''
    flux, kt = p[0], p[1]
    amplitude = flux / kt
    return amplitude * np.exp(-en / kt)

def singleResponse(temperature, energies, response):
    '''
    Projects a unit-flux (n=1) Boltzmann spectrum through the response matrix
    Inputs:
        temperature: float
        energies: 1D array of floats matching the response matrix rows
        response: 2D response matrix (energies x crystals)
    Returns:
        1D array with one value per crystal in the response matrix
    '''
    density = boltz([1, temperature], energies)
    # Drop from each bin to the next; the final bin has no upper neighbour and gets 0.
    # NOTE(review): this differences the density (1/T)exp(-E/T) itself rather than its
    # integral, i.e. it equals the per-bin photon number divided by T; confirm intended.
    per_bin = np.concatenate((density[:-1] - density[1:], [0]))
    return np.dot(per_bin, response)

def singleTemperatureMatrix(temperatures, energies, response):
    '''
    Compute the unit-flux (n=1) Boltzmann response, converted to CCD counts, for many temperatures
    Inputs:
        temperatures: 1D array of temperatures
        energies: 1D array of floats matching the response matrix rows
        response: 2D response matrix (energies x crystals)
    Returns:
        2D array of shape (number of temperatures, number of crystals)
    '''
    # np.dot(spectrum, response) yields response.shape[1] values per temperature.
    # The previous np.min(response.shape) only worked while crystals < energy bins.
    rm = np.zeros((len(temperatures), response.shape[1]))
    for ti, t in enumerate(temperatures):
        # reuse the single-temperature forward model rather than duplicating it
        rm[ti, :] = convertCounts(singleResponse(t, energies, response))
    return rm

def combineTwoResponse(p, energies, response):
    '''
    Sums the crystal responses of two Boltzmann components and converts them to CCD counts
    Inputs:
        p: 4-element sequence [n1, t1, n2, t2] (flux and temperature of each component)
        energies: 1D array of energies matching the response matrix
        response: 2D response matrix (energies x crystals)
    Returns:
        1D array of crystal values converted into expected counts on CCD
    '''
    flux_a, temp_a, flux_b, temp_b = p
    crystals = flux_a * singleResponse(temp_a, energies, response)
    crystals = crystals + flux_b * singleResponse(temp_b, energies, response)
    return convertCounts(crystals)

### DATA WRANGLER
def load_response(energies, file='./response_10.h5'):
    '''
    Loads and interpolates response matrix from .h5 files
    Input:
        energies: desired energy bins (MeV) for the response matrix
        file: (opt) relative file path for response matrix; must contain
              'energy' (keV) and 'edeposit' datasets
    Returns:
        interpolated response matrix (2D array)
    Raises:
        KeyError: if either dataset is missing from the file
    '''
    with h5py.File(file, 'r') as hf:
        # Index with [] so a missing dataset raises KeyError naming it; hf.get()
        # returned None, which only failed later with an opaque TypeError.
        en = np.array(hf['energy']) / 1000  # converting from keV to MeV
        # converting to per photon
        # NOTE(review): presumably the simulation used 1e6 photons - confirm
        edep = np.array(hf['edeposit']) / 1000000

    return interp_response(edep, en, energies)

def interp_response(edep, oldbins, newbins):
    '''
    Interpolates scintillator response matrix to new energy bins (linear)
    Extrapolates for values beyond initial set
    Input:
        edep: response matrix (MxN array), energy along one axis, scintillator layer along the other
        oldbins: array of input energies, must match either M or N
        newbins: array of desired output energies
    Returns:
        r2: interpolated response matrix of shape (len(newbins), number of layers)
    Raises:
        ValueError: if no axis of edep has len(oldbins) entries
    '''
    edep = np.asarray(edep)
    original_shape = edep.shape
    nbins = len(oldbins)

    # Compare lengths by value. The previous ``is not`` tested object identity,
    # which only happens to work for small cached ints, so e.g. 1001-bin
    # responses were wrongly transposed.
    if edep.shape[0] != nbins:
        edep = edep.T
    if edep.shape[0] != nbins:
        raise ValueError(
            f'edep shape {original_shape} has no axis of length len(oldbins)={nbins}')

    # Interpolate every scintillator layer at once along the energy axis
    f = interp1d(oldbins, edep, axis=0, fill_value='extrapolate')
    return np.asarray(f(newbins), dtype=float)

### UTILITY
def convertCounts(arr):
    '''
    Scales energy deposited in the scintillator (MeV) to expected CCD counts
    Input:
        arr: array-like of floats (e.g. one value per crystal)
    Returns:
        numpy array (or scalar) of the same shape, in CCD counts
    '''
    # Detection chain: light yield -> lens collection -> quantum efficiency -> camera gain
    light_yield = 404.7     # [optical photons / MeV / sr], from calibration data
    solid_angle = 4.88E-04  # [sr] lens acceptance angle
    quantum_eff = 0.6       # [dimensionless] quantum efficiency
    gain = 6                # [counts / optical photon]
    # NOTE(review): the original comment read "20 gain"; confirm which camera
    # gain setting the value of 6 counts per photon corresponds to.
    counts_per_mev = light_yield * solid_angle * quantum_eff * gain

    return np.multiply(arr, counts_per_mev)

def fn(array, value):
    '''
    "Find Nearest": index of the element of array closest to value
    Input:
        array: list/array of numbers
        value: float
    Returns:
        idx: int (numpy integer); the first index wins on ties
    '''
    values = np.asarray(array)
    offsets = np.subtract(values, value)
    return np.argmin(np.absolute(offsets))

def order_components(p0):
    '''
    Reorders a two-component parameter set so the hotter component comes second
    Input:
        p0: 4-element sequence [n1, t1, n2, t2]
    Returns:
        list [n, t, n, t] with the component of higher temperature last;
        when t1 >= t2 (argmax picks index 0) the components are swapped
    '''
    second_is_hotter = bool(np.argmax((p0[1], p0[3])))
    first, second = (p0[0], p0[1]), (p0[2], p0[3])
    if not second_is_hotter:
        first, second = second, first
    return [*first, *second]

def randomstart(rng=None):
    '''
    Generates random start positions for parameter scan
    temperature drawn uniformly from [0, 10) MeV
    flux drawn log-uniformly from [10^1, 10^6)
    Input:
        rng: (opt) numpy random Generator / RandomState used for the draws, for
             reproducible runs; defaults to the global np.random state
    Returns:
        [n1, t1, n2, t2]
    '''
    # np.random (module) and Generator objects share the same uniform() call,
    # so the default path draws exactly as before
    source = np.random if rng is None else rng
    t1, t2 = source.uniform(0, 10, 2)
    n1, n2 = 10**source.uniform(1, 6, 2)
    return [n1, t1, n2, t2]

def printProgressBar(i, max, postText):
    '''
    Minimalist method to print progress bar to console (rewrites the current line)
    Input:
        i: int, number of completed steps
        max: int, total number of steps (name kept for compatibility; it shadows
             the builtin max inside this function)
        postText: str, text printed after the percentage
    Returns:
        None
    '''
    n_bar = 10  # size of progress bar
    total = max
    # Nothing to do counts as complete (avoids ZeroDivisionError when max == 0)
    fraction = 1.0 if total <= 0 else i / total
    # Clamp so the bar never grows past its width or below empty
    if fraction < 0.0:
        fraction = 0.0
    elif fraction > 1.0:
        fraction = 1.0
    sys.stdout.write('\r')
    sys.stdout.write(
        f"[{'=' * int(n_bar * fraction):{n_bar}s}] {int(100 * fraction)}%  {postText}")
    sys.stdout.flush()
