import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import matplotlib.colors as mcolors 
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable

def expected_signal(edep, kappa, qe, NA, n, apx, mo, gain=1):
    """Return the expected per-pixel signal of a lens-coupled scintillator imager.

    signal = (edep * kappa) * (qe * NA**2 / (4 * n**2)) * (apx / mo**2) * gain

    Parameters (units as used by the original author):
        edep  : energy deposited per area [MeV/mm2]
        kappa : scintillator conversion [phots/MeV]
        qe    : quantum efficiency [abs]
        NA    : lens numerical aperture [abs]
        n     : refractive index [arb]
        apx   : pixel area [mm2]
        mo    : optical magnification [arb]
        gain  : photons -> counts for the sensor [Counts/Phot_opt], default 1

    All arguments are combined with plain arithmetic, so NumPy arrays
    broadcast element-wise.
    """
    generated = edep*kappa                    # optical photons per unit area
    collected = (qe*NA**2)/(4*n**2)           # NA**2/(4 n**2) light-collection factor times QE
    pixel_footprint = apx/mo**2               # pixel area referred back to the object plane
    return generated*collected*pixel_footprint*gain

### NA vs Mo scaling
# Reference objectives as [magnification, numerical aperture] pairs; used
# below to fit magnification as a function of NA.
mitu = [[1, 0.025],
        [2, 0.055],
        [5, 0.14],
        [5, 0.21],
        [7.5, 0.21],
        [10, 0.28],
        [10, 0.42]]

xx = [mag for mag, _ in mitu]   # magnifications (fit targets)
yy = [na for _, na in mitu]     # numerical apertures (fit inputs)

def func(x, a, b, c):
    """Power law with an offset, a * x**b + c (the curve_fit model)."""
    scaled = a * x**b
    return scaled + c

# Fit magnification as a power law in NA: Mo = a*NA**b + c  (x = NA, y = Mo).
popt, pcov = curve_fit(func, xdata=yy, ydata=xx)

def mofromna(NA):
    """Estimate optical magnification from numerical aperture using the fitted power law."""
    a, b, c = popt
    return func(NA, a, b, c)

def koch(NA, ell, dx=0, Mo=1, p=0.7, q=0.28):
    '''Koch et al. resolution model, adjusted for the low-magnification domain.

    Returns the quadrature sum sqrt((p/NA)**2 + (q*ell*NA)**2 + (2*dx/Mo)**2).

    NA  : numerical aperture.
    ell : scintillator thickness (this script passes micrometres).
    dx  : pixel-size term, scaled by the magnification Mo (default 0 disables it).
    p, q: model coefficients.
    '''
    lens_term = (p/NA)**2
    thickness_term = (q*ell*NA)**2
    pixel_term = (2*dx/Mo)**2
    return np.sqrt(lens_term + thickness_term + pixel_term)

def stopping(ell):
    ''' Tabulated stopping fraction, valid for LYSO up to 1000um thick.

    ell is the thickness in micrometres; the table is sampled every 10 um
    from 0 to 1000 um and linearly interpolated. Thicknesses outside the
    table are clamped to the end values (np.interp behaviour).
    '''
    thickness_um = np.linspace(0, 1000, 101)
    fraction = [
        0.0, 0.04150892299928336, 0.07117304622963712, 0.09712681396274517,
        0.12089683461260739, 0.14309973214930702, 0.16405177416674746, 0.18394543849052813,
        0.20291240787306994, 0.2210501355932717, 0.23843475823017246, 0.2551281210955598,
        0.27118194230630743, 0.28664044685642476, 0.30154211972324524, 0.31592092151124074,
        0.3298071606174514, 0.3432281373054879, 0.35620863131847924, 0.3687712791299881,
        0.38093687143517196, 0.3927245917535853, 0.4041522107204561, 0.4152362464632444,
        0.4259920986189435, 0.4364341615755821, 0.44657592112703415, 0.4564300377276975,
        0.46600841880182386, 0.475322282020764, 0.48438221105563983, 0.4931982050053899,
        0.5017797224645222, 0.5101357210125951, 0.5182746927650512, 0.5262046965128472,
        0.5339333868891727, 0.5414680409301653, 0.5488155823389113, 0.555982603715188,
        0.5629753869750687, 0.5697999221529098, 0.5764619247520402, 0.5829668517885971,
        0.5893199166545908, 0.5955261029107757, 0.6015901771067389, 0.6075167007143831,
        0.6133100412513378, 0.6189743826625264, 0.6245137350209209, 0.6299319436022555,
        0.6352326973830049, 0.6404195370061384, 0.6454958622549406, 0.6504649390714593,
        0.6553299061528369, 0.66009378115584, 0.6647594665372784, 0.6693297550556565,
        0.673807334957297, 0.6781947948682743, 0.682494628411793, 0.6867092385690987,
        0.6908409418006057, 0.6948919719426581, 0.6988644838941844, 0.702760557106449,
        0.7065821988881433, 0.7103313475371763, 0.7140098753097263, 0.7176195912363595,
        0.7211622437943606, 0.7246395234447789, 0.7280530650421293, 0.731404450124147,
        0.7346952090885108, 0.7379268232629936, 0.7411007268750794, 0.7442183089266968,
        0.7472809149793616, 0.750289848854681, 0.7532463742548718, 0.7561517163076419,
        0.7590070630395298, 0.7618135667815393, 0.7645723455106757, 0.7672844841307722,
        0.7699510356957953, 0.7725730225786254, 0.7751514375881363, 0.7776872450372281,
        0.7801813817643195, 0.7826347581106545, 0.7850482588556511, 0.7874227441123844,
        0.7897590501851899, 0.792057990391247, 0.7943203558479123, 0.796546916227467,
        0.7987384204808525,
    ]
    return np.interp(ell, thickness_um, fraction)

### DATA SET
# Scan grid: numerical aperture and scintillator thickness.
# ell_scan is in mm (1 um .. 401 um), matching the 'Thickness (mm)' axis labels.
# Both stopping() (tabulated on a 0-1000 um grid) and koch() take the
# thickness in micrometres, so convert once per point.
na_scan = np.linspace(0.05, 1.4, 55)
ell_scan = np.geomspace(1, 401, 50)*1e-3
koch_map = np.zeros((len(na_scan), len(ell_scan)))
kappa_map = np.zeros((len(na_scan), len(ell_scan)))

for i, na in enumerate(na_scan):
    # Magnification depends only on NA, so evaluate it once per row.
    mo = mofromna(na)
    for j, ell_mm in enumerate(ell_scan):
        ell_um = ell_mm*1e3
        # Previously stopping() received mm, so every thickness fell inside the
        # first 10 um table bin and the signal was underestimated by ~1000x.
        edep = 1e4*stopping(ell_um)
        # NOTE(review): apx=6.5 looks like a 6.5 um pixel pitch rather than an
        # area in mm2; harmless here because the signal is shown as an arbitrary FOM.
        kappa_map[i, j] = expected_signal(edep=edep, kappa=55000, qe=0.65, NA=na, n=1.85, apx=6.5, mo=mo)
        koch_map[i, j] = koch(na, ell_um)

## combined image
# Layout: the combined-FOM panel spans rows 0-1 across all columns; row 2 is a
# spacer; row 3 holds the signal (left) and resolution (right) panels, with
# column 1 acting as a horizontal gap between them.
fig = plt.figure(figsize=(5, 7))
gs = gridspec.GridSpec(4, 3, width_ratios=[0.4, 0.1, 0.4], height_ratios=[0.35, 0.35, 0.07, 0.3])

# subplots
ax1 = fig.add_subplot(gs[3, 0])
ax2 = fig.add_subplot(gs[3, 2])
ax3 = fig.add_subplot(gs[:2, :])

# Signal FOM
im0 = ax1.pcolormesh(na_scan, ell_scan, kappa_map.T/1000,
                     cmap='Reds', norm=mcolors.LogNorm(vmin=1, vmax=100), shading='auto')
cs1 = fig.colorbar(im0, ax=ax1, label='Signal FOM (Arb.)')

# Resolution FOM: inverse of the Koch resolution, so larger means sharper.
im1 = ax2.pcolormesh(na_scan, ell_scan, 1000/koch_map.T,
                     cmap='Blues', norm=mcolors.LogNorm(vmin=10, vmax=1000), shading='auto')
cs2 = fig.colorbar(im1, ax=ax2, label='Resolution FOM (Arb.)')

# Combined FOM: signal divided by resolution, overlaid with iso-resolution
# contours of the Koch model at 1, 2 and 10 um.
im2 = ax3.pcolormesh(na_scan, ell_scan, (kappa_map/(1000*koch_map)).T, cmap='Greys', shading='auto')
ax3.contour(na_scan, ell_scan, koch_map.T, levels=[1, 2, 10], colors=['r'])

# Formatting axis
ax1.set_yscale('log')
ax2.set_yscale('log')
ax2.set_yticks([])
ax1.set_ylabel('Thickness (mm)')
ax1.set_xlabel('Numerical Aperture')
ax2.set_xlabel('Numerical Aperture')
ax3.set_yscale('log')
ax3.set_xlabel('Numerical Aperture')
ax3.set_ylabel('Thickness (mm)')

# Colorbar for the combined panel, attached to the right of ax3.
divider = make_axes_locatable(ax3)
cax = divider.append_axes('right', size="7%", pad=0.2)
cs3 = fig.colorbar(im2, cax=cax, label="Combined FOM (arb.)")

## Add text labels to the resolution contours.
# Raw strings: '\m' is not a valid Python escape sequence (warning since 3.6,
# SyntaxWarning on 3.12+); the text passed to matplotlib is unchanged.
ax3.text(1.2, 0.053, r'10 $\mu$m', color='w', rotation=-5)
ax3.text(1.2, 0.01, r'2 $\mu$m', color='w', rotation=-5)
ax3.text(1.2, 0.0041, r'1 $\mu$m', color='w', rotation=-5)

## Panel labels, placed in ax3 data coordinates (x = NA, y = thickness in mm).
# 'b)' and 'c)' use y below ax3's thickness range, so they land outside ax3
# next to the lower panels.
ax3.text(-0.1, 0.5, 'a)')
ax3.text(-0.1, 0.0003, 'b)')
ax3.text(0.9, 0.0003, 'c)')

plt.show()