import warnings

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Ellipse
from matplotlib.transforms import ScaledTranslation
from scipy.interpolate import interp1d

def numericalaperture(wd, d):
    """Estimate a lens's numerical aperture from its geometry.

    Uses the paraxial estimate NA ~ (d / 2) / wd: the half-aperture divided
    by the working distance. No refractive-index factor is applied.

    wd: working distance; d: aperture diameter (same length unit as wd).
    """
    half_aperture = d / 2
    return half_aperture / wd

def magnification(wd, f):
    """Approximate lateral magnification as focal length over working distance.

    NOTE: this is the simplified ratio f / wd. The thin-lens lateral
    magnification would be f / (wd - f); the two agree only when wd >> f.

    wd: working distance; f: focal length (same length unit as wd).
    """
    ratio = f / wd
    return ratio

def resolution(NA, ell, dx, Mo, p=0.7, q=0.28):
    ''' Resolution limit from the Koch et al. model, adjusted for the
    low-magnification domain.

    Three contributions are added in quadrature:
      * p / NA        -- aperture (diffraction-like) term,
      * q * ell * NA  -- scintillator-thickness term (grows with NA),
      * 2 * dx / Mo   -- pixel-sampling term referred to the object plane.

    NA: numerical aperture; ell: scintillator thickness; dx: pixel size;
    Mo: optical magnification; p, q: empirical model coefficients.
    Units of the result follow those of ell and dx (the plotting code
    passes micrometres). Works elementwise on NumPy arrays.
    '''
    aperture_term = p / NA
    thickness_term = q * ell * NA
    pixel_term = 2 * dx / Mo
    return np.sqrt(aperture_term**2 + thickness_term**2 + pixel_term**2)

def expected_signal(edep,kappa,qe,NA,n,apx,mo,gain=1):
    """Estimate the signal recorded by one sensor pixel.

    The result is the product of four factors, evaluated left to right:

    * ``edep * kappa`` -- scintillation photons generated per unit area;
    * ``qe * NA**2 / (4 * n**2)`` -- detected fraction of that light
      (presumably the small-NA collection fraction for emission inside a
      medium of index ``n``, times the sensor quantum efficiency);
    * ``apx / mo**2`` -- pixel area referred back to the object plane;
    * ``gain`` -- conversion from detected photons to sensor counts.

    Parameters
    ----------
    edep : Energy deposited per area [MeV/mm2]
    kappa : Scintillator conversion [phots/MeV]
    qe : Quantum efficiency [abs]
    NA : Lens numerical aperture [abs]
    n : Refractive index [arb]
    apx : Pixel area [mm2]
    mo : Optical magnification [arb]
    gain : Photons -> counts for sensor [Counts/Phot_opt]; defaults to 1.

    Returns
    -------
    The expected signal. Its units follow the inputs, i.e. counts per pixel
    when the units above are used. Works elementwise on NumPy arrays.
    """
    photons_per_area = edep * kappa
    detected_fraction = (qe * NA**2) / (4 * n**2)
    object_pixel_area = apx / mo**2
    return photons_per_area * detected_fraction * object_pixel_area * gain

# Thickness grid [um] and tabulated LYSO stopping curve used by
# stopping_interp(); built once at import instead of on every call.
_LYSO_THICKNESS_UM = np.linspace(0, 1000, 101)
_LYSO_STOPPING = np.array([
    0.0, 0.04150892299928336, 0.07117304622963712,
    0.09712681396274517, 0.12089683461260739, 0.14309973214930702,
    0.16405177416674746, 0.18394543849052813, 0.20291240787306994,
    0.2210501355932717, 0.23843475823017246, 0.2551281210955598,
    0.27118194230630743, 0.28664044685642476, 0.30154211972324524,
    0.31592092151124074, 0.3298071606174514, 0.3432281373054879,
    0.35620863131847924, 0.3687712791299881, 0.38093687143517196,
    0.3927245917535853, 0.4041522107204561, 0.4152362464632444,
    0.4259920986189435, 0.4364341615755821, 0.44657592112703415,
    0.4564300377276975, 0.46600841880182386, 0.475322282020764,
    0.48438221105563983, 0.4931982050053899, 0.5017797224645222,
    0.5101357210125951, 0.5182746927650512, 0.5262046965128472,
    0.5339333868891727, 0.5414680409301653, 0.5488155823389113,
    0.555982603715188, 0.5629753869750687, 0.5697999221529098,
    0.5764619247520402, 0.5829668517885971, 0.5893199166545908,
    0.5955261029107757, 0.6015901771067389, 0.6075167007143831,
    0.6133100412513378, 0.6189743826625264, 0.6245137350209209,
    0.6299319436022555, 0.6352326973830049, 0.6404195370061384,
    0.6454958622549406, 0.6504649390714593, 0.6553299061528369,
    0.66009378115584, 0.6647594665372784, 0.6693297550556565,
    0.673807334957297, 0.6781947948682743, 0.682494628411793,
    0.6867092385690987, 0.6908409418006057, 0.6948919719426581,
    0.6988644838941844, 0.702760557106449, 0.7065821988881433,
    0.7103313475371763, 0.7140098753097263, 0.7176195912363595,
    0.7211622437943606, 0.7246395234447789, 0.7280530650421293,
    0.731404450124147, 0.7346952090885108, 0.7379268232629936,
    0.7411007268750794, 0.7442183089266968, 0.7472809149793616,
    0.750289848854681, 0.7532463742548718, 0.7561517163076419,
    0.7590070630395298, 0.7618135667815393, 0.7645723455106757,
    0.7672844841307722, 0.7699510356957953, 0.7725730225786254,
    0.7751514375881363, 0.7776872450372281, 0.7801813817643195,
    0.7826347581106545, 0.7850482588556511, 0.7874227441123844,
    0.7897590501851899, 0.792057990391247, 0.7943203558479123,
    0.796546916227467, 0.7987384204808525,
])


def stopping_interp(ell):
    """Linearly interpolate the tabulated LYSO stopping curve at thickness *ell*.

    The table covers 0-1000 um in 10 um steps and rises monotonically from
    0 to ~0.80. Presumably it is the fraction of incident radiation stopped
    in the scintillator -- TODO confirm against the data source.

    Parameters
    ----------
    ell : float or array_like
        Scintillator thickness [um].

    Returns
    -------
    float or ndarray
        Interpolated value(s), one per entry of *ell*.

    Warns
    -----
    UserWarning
        If any thickness lies outside the table. ``np.interp`` then clamps to
        the nearest end value (0.0 below, ~0.7987 above), so thicker crystals
        are NOT extrapolated; the warning makes that silent clamp visible.
    """
    lo, hi = _LYSO_THICKNESS_UM[0], _LYSO_THICKNESS_UM[-1]
    thickness = np.asarray(ell)
    if np.any((thickness < lo) | (thickness > hi)):
        warnings.warn(
            f"stopping_interp: thickness outside the tabulated LYSO range "
            f"[{lo:g}, {hi:g}] um; result is clamped to the nearest table end",
            stacklevel=2,
        )
    return np.interp(ell, _LYSO_THICKNESS_UM, _LYSO_STOPPING)


# Lens catalogue: name -> [numerical aperture, magnification, plot marker].
# The marker doubles as the lens family when plotting:
#   '.' microscope objectives, 'o' machine-vision lenses,
#   'd' fibre-optic plate/taper.
lens_system = {
    # microscope objectives
    'mitu1x':   [0.025, 1,   '.'],
    'mitu2x':   [0.055, 2,   '.'],
    'mitu5x':   [0.14,  5,   '.'],
    'mitu5x+':  [0.21,  5,   '.'],
    'mitu7.5x': [0.21,  7.5, '.'],
    'mitu10x':  [0.28,  10,  '.'],
    'mitu10x+': [0.42,  10,  '.'],
    'mitu20x':  [0.28,  20,  '.'],
    'mitu20x+': [0.42,  20,  '.'],
    # machine-vision lenses; computed entries come from
    # numericalaperture(working distance, aperture diameter) and
    # magnification(working distance, focal length)
    'ehd25085':     [numericalaperture(170, 55), magnification(170, 25), 'o'],
    'ehd50085':     [numericalaperture(240, 65), magnification(240, 50), 'o'],
    '25mmCSeries':  [0.0706, magnification(100, 25),  'o'],
    '35mmCSeries':  [0.0559, magnification(165, 35),  'o'],
    '50mmCSeries':  [0.0437, magnification(250, 50),  'o'],
    '100mmCSeries': [0.0228, magnification(750, 100), 'o'],
    # fibre-optic plate / taper
    'HamamatsuFOPbwtaper': [0.99, 0.5, 'd'],
}

# Camera catalogue: name -> [pixel size, quantum efficiency, gain].
# The plotting code passes the pixel size both as dx to resolution() and as
# apx to expected_signal(); presumably it is the pixel pitch in um -- TODO
# confirm. Gain converts detected photons to sensor counts.
camera_system = {
    'mantag235b': [5.86, 0.65, 1],
    'mantag033b': [6.5,  0.65, 1],
    'zyla5.5':    [6.5,  0.85, 1],
    'ixon0':      [13,   0.85, 1],
    'qcmos':      [4.85, 0.95, 0.5],
}

# Scintillator catalogue:
#   material -> [light yield [phots/MeV], refractive index, plot colour]
kappa = {
    'CsI': [55000, 1.80, 'tab:blue'],
}

# --- Figure layout and resolution panels a)-c) -------------------------------
# Top row: one resolution-vs-NA panel per magnification. Row 1 is left empty
# (room for the shared legend), row 3 is a near-zero spacer, and the signal
# panel ax4 spans rows 2-4.
mags = [0.1, 0.5, 5]
na = np.geomspace(0.01, 0.5, 25)

fig = plt.figure(figsize=(5.5, 7))
gs = gridspec.GridSpec(5, 3, height_ratios=[0.4, 0.4, 0.4, 0.005, 0.65],
                       width_ratios=[0.33, 0.33, 0.33])

# Add subplots to the grid
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[0, 2])
ax4 = fig.add_subplot(gs[2:, :])

ax = [ax1, ax2, ax3]

# Scintillator thickness (passed as ell to resolution()) and line colour.
# The legend labels are generated from this same table so they cannot drift
# out of sync with the plotted curves.
thicknesses = [(5, 'tab:blue'), (25, 'tab:orange'),
               (125, 'tab:green'), (1025, 'tab:red')]
dx = 5  # pixel size passed as dx to resolution()

for panel, m in zip(ax, mags):
    for thickness, colour in thicknesses:
        panel.plot(na, resolution(na, thickness, dx, m), color=colour)
    panel.set_ylim([2, 200])
    panel.set_title(f'{m}x')
    panel.set_yscale('log')
ax[1].set_xlabel('Numerical Aperture')

for panel, tag in zip(ax, 'abc'):
    panel.text(0, 3, f'{tag})', fontsize=13)

# Only the leftmost panel keeps its y tick labels.
for panel in ax[1:]:
    panel.set_yticks([])

# Raw strings: '\m' is not a valid escape sequence and raises a
# SyntaxWarning on Python 3.12+ in a normal string literal.
ax[0].set_ylabel(r'Resolution Limit ($\mu$m)')
ax[0].legend([rf'{thickness}$\mu$m' for thickness, _ in thicknesses],
             ncol=4, bbox_to_anchor=(len(mags) + 0.5, -0.6),
             title='Scintillator Thickness')


# --- Panel d): expected signal vs resolution for every lens/camera pair -------
ax = ax4
# Light yield [phots/MeV] and refractive index; the plot colour is unused here.
k, eta, _ = kappa['CsI']

# lens_system's marker encodes the lens family; colour points by family
# (anything other than '.' or 'o' -- the fibre-optic plate -- is red).
family_colour = {'.': 'tab:blue', 'o': 'tab:grey'}

# NOTE(review): stopping_interp is documented as a LYSO table valid only up
# to 1000 um, yet it is combined with the CsI light yield here, and ell=5000
# lies beyond the table (np.interp clamps it to the last entry). Confirm intent.
for ell in [5, 50, 500, 5000]:
    ed = stopping_interp(ell)
    for lens_na, mo, shape in lens_system.values():
        col = family_colour.get(shape, 'tab:red')
        for px, qe, gain in camera_system.values():
            r = resolution(lens_na, ell, px, mo)
            # NOTE(review): expected_signal documents apx as a pixel *area*
            # [mm2] but the pixel size itself is passed. Changing it would
            # move the points relative to the hand-placed ellipses below.
            s = expected_signal(ed, k, qe, lens_na, eta, px, mo, gain)
            # 1000 / (2 r): resolution in um -> line pairs per mm.
            ax.plot(1000 / (2 * r), s, shape, color=col)


def _outline_region(axis, x, y, width, height, angle, color):
    """Draw an unfilled ellipse centred on data point (x, y) of *axis*.

    The centre is translated through ``axis.transScale`` and the result is
    mapped by ``transLimits`` + ``transAxes``. The ellipse size and angle are
    therefore expressed in scaled-axis units (decades once the axes are log).
    The transforms are composed from the axes' own transform objects, so the
    log scales set further below still apply.
    """
    offset = ScaledTranslation(x, y, axis.transScale)
    tform = offset + axis.transLimits + axis.transAxes
    axis.add_patch(Ellipse(xy=(0, 0), width=width, height=height, color=color,
                           angle=angle, fill=False, lw=1, zorder=5,
                           transform=tform))


# Label each lens family with an outline and a caption.
_outline_region(ax, 15, 2, width=1.8, height=4, angle=45, color="tab:blue")
ax.text(1, 0.08, 'Microscope Objectives',
        rotation=-18, fontsize=8, color="tab:blue")

_outline_region(ax, 4.5, 1000, width=0.8, height=5, angle=7.5, color="grey")
ax.text(15, 100, 'Machine Vision Lenses', rotation=0, fontsize=8)

_outline_region(ax, 4, 12000, width=1.2, height=3.5, angle=40, color="tab:red")
ax.text(15, 4000, 'Fibre Optic Plates',
        rotation=-25, fontsize=8, color="tab:red")

ax.set_xlabel('Resolution Limit (lp/mm)')
ax.set_ylabel('Expected Signal (arb)')
ax.set_yscale('log')
ax.set_xscale('log')
ax.text(0.14, 4e5, 'd)')

plt.show()