import numpy as np 
import matplotlib.pyplot as plt 
import matplotlib.colors as mcolors
from matplotlib.ticker import ScalarFormatter, FormatStrFormatter


## function to add arrow
def add_arrow(line, position=None, direction='right', size=15, color=None):
    """
    Add one arrow per requested position along a line.

    line:       Line2D object (anything providing get_xdata/get_ydata/get_color
                and an ``axes`` attribute with ``annotate``)
    position:   x-position(s) of the arrow(s); a scalar or a sequence. Each
                arrow is anchored at the data point whose x-value is closest
                to the requested position. If None, two arrows are drawn, at
                1/2 and 3/4 of the mean of xdata.
    direction:  'right' points each arrow towards the next data point (higher
                index); any other value (e.g. 'left') towards the previous one
    size:       size of the arrow in fontsize points
    color:      if None, line color is taken.

    Lines with fewer than two points are left unannotated, since an arrow
    needs two distinct points.
    """
    if color is None:
        color = line.get_color()

    # get_xdata/get_ydata may return the original (possibly list) data; the
    # arithmetic below needs arrays.
    xdata = np.asarray(line.get_xdata())
    ydata = np.asarray(line.get_ydata())

    n_points = len(xdata)
    if n_points < 2:
        return

    if position is None:
        position = [xdata.mean()/2, 3*xdata.mean()/4]

    # atleast_1d lets a single scalar position be passed as well as a list
    for p in np.atleast_1d(position):
        # find closest index
        start_ind = int(np.argmin(np.absolute(xdata - p)))
        if direction == 'right':
            if start_ind + 1 < n_points:
                end_ind = start_ind + 1
            else:
                # at the last point: draw along the final segment instead
                end_ind = start_ind
                start_ind -= 1
        else:
            if start_ind - 1 >= 0:
                end_ind = start_ind - 1
            else:
                # at the first point: index -1 would wrap to the END of the
                # array, so draw along the first segment instead
                end_ind = start_ind
                start_ind += 1

        line.axes.annotate('',
                           xytext=(xdata[start_ind], ydata[start_ind]),
                           xy=(xdata[end_ind], ydata[end_ind]),
                           arrowprops=dict(arrowstyle="simple", color=color),
                           size=size,
                           )

## initial parameters for calculations
# Each table maps a configuration name to a five-element list, always in the
# order the script unpacks it (``NA, mo, Z, ell, dx``):
#   numerical aperture, optical magnification, detector position Z,
#   scintillator thickness ell, pixel size dx
# Units are not stated in this file. Z is multiplied by 1000 before use as a
# distance (presumably m -> mm) and ell is fed to ``stopping``, whose table
# spans 0-1000 (presumably um) -- TODO confirm.

# microscope objectives
parameters_mitu = {
    'mitu5x':    [0.14,  5, 2, 1, 6.5],
    'mitu10x+':  [0.42, 10, 2, 1, 6.5],
    'mituideal': [0.72, 20, 2, 2, 4.5],
}

# fibre optic plates (magnification 1)
parameters_fop = {
    'fopLOWNA':  [0.43, 1, 20, 100, 100],
    'fopHIGHNA': [0.7,  1, 20, 100, 100],
    'fopzyla':   [0.7,  1, 20,  20,  25],
}

# machine vision lenses; NA and magnification are written as fractions
parameters_mvl = {
    'mvlEHD25085': [62/170, 25/170, 20, 250, 6.5],
    'mvlEHD50085': [62/240, 50/240, 20, 250, 6.5],
}

sample = 1         # enters the geometry as mg = sample/(Z - sample)
nphot = 7e+11      # used as nphot/(4*pi*(Z*1000)**2) per unit area
theta = 5.08/1000  # NOTE(review): not referenced in the visible code

## functions for calculations
def expected_signal(edep, kappa, qe, NA, n, apx, mo, gain=1):
    """Return the expected detector signal.

    signal = (edep * kappa) * (qe * NA**2 / (4 * n**2)) * (apx / mo**2) * gain

    Parameters:
        edep  -- energy deposited per area [MeV/mm2]
        kappa -- scintillator conversion [phots/MeV]
        qe    -- quantum efficiency [abs]
        NA    -- lens numerical aperture [abs]
        n     -- refractive index [arb]
        apx   -- pixel area [mm2]
        mo    -- optical magnification [arb]
        gain  -- photons -> counts for sensor [Counts/Phot_opt]

    Only elementwise arithmetic is used, so any argument may be a NumPy array
    and the result broadcasts accordingly.
    """
    # NOTE(review): the script passes the pixel *size* (dx, the "pixel_size"
    # column of the parameter tables) as apx; confirm whether an area
    # (dx**2, in consistent units) was intended.
    return (edep*kappa)*((qe*NA**2)/(4*n**2))*(apx/mo**2)*gain

def resolution(NA, ell, dx, Mo, p=0.7, q=0.28):
    """Koch et al. resolution model, adjusted for the low-magnification domain.

    Three contributions are combined in quadrature:
        p / NA        -- lens term, shrinks as NA grows
        q * ell * NA  -- thickness term, grows with ell and with NA
        2 * dx / Mo   -- pixel term: two pixel widths divided by Mo

    Arguments may be NumPy arrays; the result broadcasts accordingly.
    """
    lens_term = p / NA
    thickness_term = q * ell * NA
    pixel_term = 2 * dx / Mo
    return np.sqrt(lens_term**2 + thickness_term**2 + pixel_term**2)

def stopping(ell):
    """Interpolate the tabulated LYSO response at scintillator thickness ``ell``.

    Valid for LYSO up to 1000 um thick: the table is sampled every 10 um from
    0 to 1000 um. The script stores the result as ``edep_per_photon``.

    ``ell`` may be a scalar or an array. Outside 0-1000 the value is not
    extrapolated: ``np.interp`` returns the first/last table entry.
    """
    thickness_grid = np.linspace(0, 1000, 101)
    # one row per 50 um of thickness (5 samples per row)
    response = [
        0.0, 0.04150892299928336, 0.07117304622963712, 0.09712681396274517, 0.12089683461260739,
        0.14309973214930702, 0.16405177416674746, 0.18394543849052813, 0.20291240787306994, 0.2210501355932717,
        0.23843475823017246, 0.2551281210955598, 0.27118194230630743, 0.28664044685642476, 0.30154211972324524,
        0.31592092151124074, 0.3298071606174514, 0.3432281373054879, 0.35620863131847924, 0.3687712791299881,
        0.38093687143517196, 0.3927245917535853, 0.4041522107204561, 0.4152362464632444, 0.4259920986189435,
        0.4364341615755821, 0.44657592112703415, 0.4564300377276975, 0.46600841880182386, 0.475322282020764,
        0.48438221105563983, 0.4931982050053899, 0.5017797224645222, 0.5101357210125951, 0.5182746927650512,
        0.5262046965128472, 0.5339333868891727, 0.5414680409301653, 0.5488155823389113, 0.555982603715188,
        0.5629753869750687, 0.5697999221529098, 0.5764619247520402, 0.5829668517885971, 0.5893199166545908,
        0.5955261029107757, 0.6015901771067389, 0.6075167007143831, 0.6133100412513378, 0.6189743826625264,
        0.6245137350209209, 0.6299319436022555, 0.6352326973830049, 0.6404195370061384, 0.6454958622549406,
        0.6504649390714593, 0.6553299061528369, 0.66009378115584, 0.6647594665372784, 0.6693297550556565,
        0.673807334957297, 0.6781947948682743, 0.682494628411793, 0.6867092385690987, 0.6908409418006057,
        0.6948919719426581, 0.6988644838941844, 0.702760557106449, 0.7065821988881433, 0.7103313475371763,
        0.7140098753097263, 0.7176195912363595, 0.7211622437943606, 0.7246395234447789, 0.7280530650421293,
        0.731404450124147, 0.7346952090885108, 0.7379268232629936, 0.7411007268750794, 0.7442183089266968,
        0.7472809149793616, 0.750289848854681, 0.7532463742548718, 0.7561517163076419, 0.7590070630395298,
        0.7618135667815393, 0.7645723455106757, 0.7672844841307722, 0.7699510356957953, 0.7725730225786254,
        0.7751514375881363, 0.7776872450372281, 0.7801813817643195, 0.7826347581106545, 0.7850482588556511,
        0.7874227441123844, 0.7897590501851899, 0.792057990391247, 0.7943203558479123, 0.796546916227467,
        0.7987384204808525,
    ]
    return np.interp(ell, thickness_grid, response)

## define colours for plot
# Colour list (presumably colour-blind friendly, per the "CB" prefix).
# Entries 0-3 colour the parameter-sweep curves in the insets and the
# matching coloured text "legend" drawn on the main axes at the end.
CB_color_cycle = ['#377eb8', '#ff7f00', '#4daf4a',
                  '#f781bf', '#a65628', '#984ea3',
                  '#999999', '#e41a1c', '#dede00']

# The main axes only use the left 60% of the figure (gridspec right=0.6),
# leaving room on the right for the inset axes created later at x0=1.05.
fig = plt.figure(figsize=(7, 4))
ax = fig.add_gridspec(right=0.6).subplots()

## all data
# Scatter every lens configuration on the main axes as resolution (1000/res)
# against expected signal, then switch both axes to log scale. Each family
# gets one marker style. Only one member per family carries a label, so the
# legend shows a single entry per family.
if 1:
    # (parameter table, format string, key that carries the label, label)
    families = [
        (parameters_mitu, 'b.', 'mitu5x', 'Microscope Objectives'),
        (parameters_fop, 'rd', 'fopzyla', 'Fibre optic plates'),
        (parameters_mvl, 'ko', 'mvlEHD25085', 'Machine Vision Lenses'),
    ]
    for params, fmt, labelled_key, family_label in families:
        for p in params:
            NA, mo, Z, ell, dx = params[p]

            mg = sample/(Z-sample)
            nphot_per_mm = nphot/(4*np.pi*(Z*1000)**2)

            edep_per_photon = stopping(ell)
            edep_total = edep_per_photon * nphot_per_mm

            res = resolution(NA, ell, dx=dx, Mo=mo)*mg
            cpp = expected_signal(edep=edep_total, kappa=5500,
                                  qe=0.65, NA=NA, n=1.85, apx=dx, mo=mo)
            # label only one point per family -> one legend entry per family
            label_kw = {'label': family_label} if p == labelled_key else {}
            ax.plot(1000/res, cpp, fmt, alpha=0.7, **label_kw)

    ax.set_yscale('log')
    ax.set_xscale('log')

    # Lower inset, placed to the right of the main axes (bounds are in axes
    # coordinates: x0=1.05, y0=0, width=height=0.47).
    axins = ax.inset_axes([1.05, 0, 0.47, 0.47])

    # Nominal configuration explored in the lower inset; the sweeps below
    # read ``p`` to pick their starting parameters.
    p = 'mitu10x+'
    print('NA,M0,Z,ell,dx')
    print(parameters_mitu[p])
    NA, mo, Z, ell, dx = parameters_mitu[p]

    mg = sample/(Z-sample)
    nphot_per_mm = nphot/(4*np.pi*(Z*1000)**2)

    edep_per_photon = stopping(ell)
    edep_total = edep_per_photon * nphot_per_mm

    res = resolution(NA, ell, dx=dx, Mo=mo)*mg
    cpp = expected_signal(edep=edep_total, kappa=5500,
                          qe=0.65, NA=NA, n=1.85, apx=dx, mo=mo)

#####################
# vary mg independently
# Lower inset, colour CB_color_cycle[0]. Sweep the detector position Z over
# 0.9x-1.1x of its nominal value (mg = sample/(Z - sample) changes with it).
# All other parameters stay at configuration ``p`` ('mitu10x+', set earlier).
if 1:
    NA, mo, Z0, ell, dx = parameters_mitu[p]
    Z = Z0*np.linspace(0.9, 1.1, 50)
    # NOTE(review): this reports the ratio of Z, not of the resulting mg
    print(f'Varying Mg by {Z[0]/Z0} to {Z[-1]/Z0}')
    mg = sample/(Z-sample)
    # Z is an array here, so the photon density is evaluated per sweep point
    nphot_per_mm = nphot/(4*np.pi*(Z*1000)**2)

    edep_per_photon = stopping(ell)
    edep_total = edep_per_photon * nphot_per_mm

    res = resolution(NA, ell, dx=dx, Mo=mo)*mg
    cpp = expected_signal(edep=edep_total, kappa=5500,
                        qe=0.65, NA=NA, n=1.85, apx=dx, mo=mo)

    line = axins.plot(1000/res, cpp, '-', label='Mg', color=CB_color_cycle[0])[0]
    add_arrow(line, position=[np.min(1000/res), np.max(1000/res)], size=15)
    # End labels, offset by hand-tuned factors. NOTE(review): they read
    # 0.8x/1.2x while Z spans 0.9x-1.1x. With the nominal values (Z0=2,
    # sample=1) that is the ratio of (Z - sample), not of Z or mg; confirm
    # which quantity is intended.
    axins.text((1000/res[0]) * 0.97,cpp[0]*1.08,'0.8x',color=CB_color_cycle[0],fontsize=8,fontweight='bold')
    axins.text((1000/res[-1]) * 0.96,cpp[-1]*1.15,'1.2x',color=CB_color_cycle[0],fontsize=8,fontweight='bold')

#####################
# vary mo independently
# Lower inset, colour CB_color_cycle[1]. Sweep the optical magnification over
# 0.8x-1.2x of its nominal value M0; Z, ell, NA and dx stay at configuration
# ``p``. mo enters both the resolution (pixel term) and the signal (apx/mo**2).
if 1:
    NA, M0, Z, ell, dx = parameters_mitu[p]
    # geometry and deposited energy do not depend on mo: computed once
    mg = sample/(Z-sample)
    nphot_per_mm = nphot/(4*np.pi*(Z*1000)**2)
    edep_per_photon = stopping(ell)
    edep_total = edep_per_photon * nphot_per_mm

    mo = M0 * np.linspace(0.8, 1.2, 50)
    print(f'Varying Mo by {mo[0]/M0} to {mo[-1]/M0}')
    res = resolution(NA, ell, dx=dx, Mo=mo)*mg
    cpp = expected_signal(edep=edep_total, kappa=5500,
                        qe=0.65, NA=NA, n=1.85, apx=dx, mo=mo)

    line = axins.plot(1000/res, cpp, '-', label='Mo', color=CB_color_cycle[1])[0]
    add_arrow(line, position=[np.min(1000/res), np.max(1000/res)], size=15)
    # end labels for the 0.8x / 1.2x sweep bounds (offsets tuned by hand)
    axins.text((1000/res[0]) * 0.90,cpp[0]*1.08,'0.8x',color=CB_color_cycle[1],fontsize=8,fontweight='bold')
    axins.text((1000/res[-1]) * 1,cpp[-1]*0.9,'1.2x',color=CB_color_cycle[1],fontsize=8,fontweight='bold')

#####################
# vary NA independently
# Lower inset, colour CB_color_cycle[2]. Sweep the numerical aperture over
# 0.8x-1.2x of its nominal value (50 geometrically spaced samples). The other
# parameters stay fixed at configuration ``p``.
if 1:
    na_lo, na_hi = 0.8, 1.2  # sweep bounds, reused for the end labels
    na0, mo, Z, ell, dx = parameters_mitu[p]
    NA = na0*np.geomspace(na_lo, na_hi, 50)
    print(f'Varying NA by {NA[0]/na0} to {NA[-1]/na0}')
    mg = sample/(Z-sample)
    nphot_per_mm = nphot/(4*np.pi*(Z*1000)**2)

    edep_per_photon = stopping(ell)
    edep_total = edep_per_photon * nphot_per_mm

    res = resolution(NA, ell, dx=dx, Mo=mo)*mg
    cpp = expected_signal(edep=edep_total, kappa=5500,
                          qe=0.65, NA=NA, n=1.85, apx=dx, mo=mo)

    line = axins.plot(1000/res, cpp, '-', label='NA', color=CB_color_cycle[2])[0]
    add_arrow(line, position=[np.min(1000/res), np.max(1000/res)], size=15)
    # End labels come from the sweep bounds so they always match the plotted
    # range (hard-coded '0.9x'/'1.1x' disagreed with the 0.8-1.2 sweep).
    axins.text((1000/res[0]) * 0.94, cpp[0]*0.7, f'{na_lo:g}x',
               color=CB_color_cycle[2], fontsize=8, fontweight='bold')
    axins.text((1000/res[-1]) * 0.99, cpp[-1]*1.03, f'{na_hi:g}x',
               color=CB_color_cycle[2], fontsize=8, fontweight='bold')


#####################
# vary ell
# Lower inset, colour CB_color_cycle[3]. Sweep the scintillator thickness ell
# over 0.5x-2x of its nominal value (25 geometrically spaced samples), with
# everything else fixed at configuration ``p``. Thickness changes both the
# deposited energy (via stopping) and the resolution.
if 1:
    NA, mo, Z, ell, dx = parameters_mitu[p]

    ell_sweep = ell*np.geomspace(0.5, 2, 25)
    print(f'Varying ell by {ell_sweep[0]/ell} to {ell_sweep[-1]/ell}')
    res = np.zeros_like(ell_sweep)
    cpp = np.zeros_like(ell_sweep)

    mg = sample/(Z-sample)
    nphot_per_mm = nphot/(4*np.pi*(Z*1000)**2)

    # separate loop variable so the nominal ``ell`` is not overwritten
    for i, thickness in enumerate(ell_sweep):
        edep_per_photon = stopping(thickness)
        edep_total = edep_per_photon * nphot_per_mm

        res[i] = resolution(NA, thickness, dx=dx, Mo=mo)*mg
        cpp[i] = expected_signal(edep=edep_total, kappa=5500,
                                 qe=0.65, NA=NA, n=1.85, apx=dx, mo=mo)

    # raw string: '\e' is not a valid escape and triggers a warning otherwise
    line = axins.plot(1000/res, cpp, '-', label=r'$\ell$',
                      color=CB_color_cycle[3])[0]
    add_arrow(line, position=[np.min(1000/res), np.max(1000/res)], size=15)
    # end labels for the 0.5x / 2x sweep bounds (offsets tuned by hand)
    axins.text((1000/res[0]) * 1.02, cpp[0]*0.8, '0.5x',
               color=CB_color_cycle[3], fontsize=8, fontweight='bold')
    axins.text((1000/res[-1]) * 0.99, cpp[-1]*1.03, '2x',
               color=CB_color_cycle[3], fontsize=8, fontweight='bold')

#####################
# plot central point
# Mark the nominal configuration ``p`` in the lower inset. The scalar res/cpp
# computed here are also what the inset's zoom limits are centred on.
if 1:
    NA, mo, Z, ell, dx = parameters_mitu[p]

    mg = sample/(Z-sample)
    nphot_per_mm = nphot/(4*np.pi*(Z*1000)**2)

    edep_per_photon = stopping(ell)
    edep_total = edep_per_photon * nphot_per_mm

    res = resolution(NA, ell, dx=dx, Mo=mo)*mg
    cpp = expected_signal(edep=edep_total, kappa=5500,
                        qe=0.65, NA=NA, n=1.85, apx=dx, mo=mo)
    axins.plot(1000/res, cpp, marker='o', color='blue', alpha=0.7)

#####################
# format axis
# Zoom and style the lower inset, then draw its zoom indicator on the main axes.
if 1:
    # axins.set_yscale('log')
    # axins.set_xscale('log')
    # axins.legend(ncol=1,bbox_to_anchor=(1.5,0.9))

    # zoom window around the central point (res/cpp are scalars here):
    # 0.75x-1.25x in resolution, 1/2.5x-2x in signal
    x1, x2, y1, y2 = (1000/res) * 0.75, (1000/res) * 1.25, cpp / 2.5, cpp * 2
    axins.set_xlim(x1, x2)
    axins.set_ylim(y1, y2)

    axins.xaxis.set_label_position('top')
    # y ticks on the inset's right-hand side
    axins.yaxis.tick_right()

    # zoom rectangle + connectors on the main axes; '_nolegend_' keeps it
    # out of the legend
    ax.indicate_inset_zoom(axins, edgecolor="black", label='_nolegend_')

######################################################## second datapoint

# Upper inset (axes-coordinate bounds x0=1.05, y0=0.53, width=height=0.47),
# exploring machine-vision lens 'mvlEHD50085'. The sweeps below read ``p``.
axins = ax.inset_axes([1.05, 0.53, 0.47, 0.47])

p = 'mvlEHD50085'
print('NA,M0,Z,ell,dx')
print(parameters_mvl[p])
if 1:
    # nominal res/cpp for this lens; not plotted here (the central marker is
    # drawn after the sweeps)
    NA, mo, Z, ell, dx = parameters_mvl[p]
    mg = sample/(Z-sample)
    nphot_per_mm = nphot/(4*np.pi*(Z*1000)**2)

    edep_per_photon = stopping(ell)
    edep_total = edep_per_photon * nphot_per_mm

    res = resolution(NA, ell, dx=dx, Mo=mo)*mg
    cpp = expected_signal(edep=edep_total, kappa=5500,
                        qe=0.65, NA=NA, n=1.85, apx=dx, mo=mo)
    # plt.plot(1000/res, cpp, marker='o',color='tab:blue')

#####################
# vary mg independently
# Upper inset, colour CB_color_cycle[0]. Sweep the detector position Z over
# 0.9x-1.1x of its nominal value (mg = sample/(Z - sample) changes with it),
# with the rest fixed at lens ``p``.
if 1:
    NA, mo, Z0, ell, dx = parameters_mvl[p]
    Z = Z0*np.linspace(0.9, 1.1, 50)
    # NOTE(review): this reports the ratio of Z, not of the resulting mg
    print(f'Varying Mg by {Z[0]/Z0} to {Z[-1]/Z0}')
    mg = sample/(Z-sample)
    nphot_per_mm = nphot/(4*np.pi*(Z*1000)**2)

    edep_per_photon = stopping(ell)
    edep_total = edep_per_photon * nphot_per_mm

    res = resolution(NA, ell, dx=dx, Mo=mo)*mg
    cpp = expected_signal(edep=edep_total, kappa=5500,
                        qe=0.65, NA=NA, n=1.85, apx=dx, mo=mo)

    line = axins.plot(1000/res, cpp, '-', label='Mg',
                      color=CB_color_cycle[0])[0]
    add_arrow(line, position=[np.min(1000/res), np.max(1000/res)], size=15)

#####################
# vary mo independently
# Upper inset, colour CB_color_cycle[1]. Sweep the optical magnification over
# 0.8x-1.2x of its nominal value M0, with the rest fixed at lens ``p``.
if 1:
    NA, M0, Z, ell, dx = parameters_mvl[p]

    # geometry and deposited energy do not depend on mo: computed once
    mg = sample/(Z-sample)
    nphot_per_mm = nphot/(4*np.pi*(Z*1000)**2)
    edep_per_photon = stopping(ell)
    edep_total = edep_per_photon * nphot_per_mm

    mo = M0 * np.linspace(0.8, 1.2, 50)
    print(f'Varying Mo by {mo[0]/M0} to {mo[-1]/M0}')

    res = resolution(NA, ell, dx=dx, Mo=mo)*mg
    cpp = expected_signal(edep=edep_total, kappa=5500,
                        qe=0.65, NA=NA, n=1.85, apx=dx, mo=mo)

    line = axins.plot(1000/res, cpp, '-', label='Mo',
                      color=CB_color_cycle[1])[0]
    add_arrow(line, position=[np.min(1000/res), np.max(1000/res)], size=15)

###########
# vary NA independently
# Upper inset, colour CB_color_cycle[2]. Sweep the numerical aperture over
# 0.8x-1.2x of its nominal value (50 geometrically spaced samples), with the
# rest fixed at lens ``p``.
if 1:
    na0, mo, Z, ell, dx = parameters_mvl[p]
    NA = na0*np.geomspace(0.8, 1.2, 50)
    print(f'Varying NA by {NA[0]/na0} to {NA[-1]/na0}')

    mg = sample/(Z-sample)
    nphot_per_mm = nphot/(4*np.pi*(Z*1000)**2)

    edep_per_photon = stopping(ell)
    edep_total = edep_per_photon * nphot_per_mm

    res = resolution(NA, ell, dx=dx, Mo=mo)*mg
    cpp = expected_signal(edep=edep_total, kappa=5500,
                        qe=0.65, NA=NA, n=1.85, apx=dx, mo=mo)

    line = axins.plot(1000/res, cpp, '-', label='NA', color=CB_color_cycle[2])[0]
    add_arrow(line, position=[np.min(1000/res), np.max(1000/res)], size=15)

#####################
# vary ell
# Upper inset, colour CB_color_cycle[3]. Sweep the scintillator thickness ell
# over 0.5x-2x of its nominal value (25 geometrically spaced samples), with
# everything else fixed at lens ``p``.
if 1:
    NA, mo, Z, ell, dx = parameters_mvl[p]

    ell_sweep = ell*np.geomspace(0.5, 2, 25)
    print(f'Varying ell by {ell_sweep[0]/ell} to {ell_sweep[-1]/ell}')
    res = np.zeros_like(ell_sweep)
    cpp = np.zeros_like(ell_sweep)

    mg = sample/(Z-sample)
    nphot_per_mm = nphot/(4*np.pi*(Z*1000)**2)

    # separate loop variable so the nominal ``ell`` is not overwritten
    for i, thickness in enumerate(ell_sweep):
        edep_per_photon = stopping(thickness)
        edep_total = edep_per_photon * nphot_per_mm

        res[i] = resolution(NA, thickness, dx=dx, Mo=mo)*mg
        cpp[i] = expected_signal(edep=edep_total, kappa=5500,
                                 qe=0.65, NA=NA, n=1.85, apx=dx, mo=mo)

    # raw string: '\e' is not a valid escape and triggers a warning otherwise
    line = axins.plot(1000/res, cpp, '-', label=r'$\ell$',
                      color=CB_color_cycle[3])[0]
    add_arrow(line, position=[np.min(1000/res), np.max(1000/res)], size=15)



# Central (nominal) point for lens ``p`` in the upper inset, then zoom/format
# that inset and draw its zoom indicator on the main axes.
NA, mo, Z, ell, dx = parameters_mvl[p]

mg = sample/(Z-sample)
nphot_per_mm = nphot/(4*np.pi*(Z*1000)**2)

edep_per_photon = stopping(ell)
edep_total = edep_per_photon * nphot_per_mm

res = resolution(NA, ell, dx=dx, Mo=mo)*mg
cpp = expected_signal(edep=edep_total, kappa=5500,
                      qe=0.65, NA=NA, n=1.85, apx=dx, mo=mo)
axins.plot(1000/res, cpp, marker='o',color='k', alpha=0.7)

# axins.set_yscale('log')
# axins.set_xscale('log')

# zoom window around the central point: 0.75x-1.25x in resolution,
# 1/2.5x-2x in signal
x1, x2, y1, y2 = (1000/res) * 0.75, (1000/res) * 1.25, cpp / 2.5, cpp * 2
axins.set_xlim(x1, x2)
axins.set_ylim(y1, y2)

# signal tick labels in scientific notation without decimals (e.g. '1e+05')
axins.yaxis.set_major_formatter(FormatStrFormatter('%.0e'))
# extra minor ticks at x = 250 and 325
axins.set_xticks([250, 325],minor=True)

# x ticks and label along the top of the upper inset
axins.xaxis.set_label_position('top')
axins.xaxis.tick_top()

axins.yaxis.tick_right()
# '_nolegend_' keeps the zoom indicator out of the legend
ax.indicate_inset_zoom(axins, edgecolor="black", label='_nolegend_')

# plt.* calls act on the current axes. The insets are child axes of ``ax``,
# so these presumably apply to the main axes -- TODO confirm.
plt.xlim(8e1,2e3)
plt.xlabel('Resolution (lp/mm)')
plt.ylabel('Signal (Counts/px)')

ax.legend(fontsize=10,ncol=1,bbox_to_anchor=(0.7,.55),frameon=False)

# Axis titles for the insets. They are placed in the main axes' data
# coordinates beyond its x-limit (2e3), i.e. out to the right, where the
# insets sit.
ax.text(2e4, 3e3, 'Signal (Counts/px)',rotation=90)
ax.text(2.5e3, 2e1, 'Resolution (lp/mm)')

## Custom Legend for Secondary Data Points
# Coloured text entries matching the inset sweep colours (CB_color_cycle[0..3]);
# each successive entry sits 0.6x lower on the log y axis.
ax.text(0.45e3,2.5e5,'Numerical Aperture',color=CB_color_cycle[2],fontweight='bold',fontsize=9)
ax.text(0.78e3,2.5e5*0.6,'Optical Mag',color=CB_color_cycle[1],fontweight='bold',fontsize=9)
ax.text(0.62e3,2.5e5*0.6*0.6,'Geometric Mag',color=CB_color_cycle[0],fontweight='bold',fontsize=9)
ax.text(0.48e3,2.5e5*0.6*0.6*0.6,'Scintillator Length',color=CB_color_cycle[3],fontweight='bold',fontsize=9)

plt.show()

