'''------------------------------------------------------------
Title:   Utilities for fibre-array tomography and analysis
Date:    07/12/2021
Author:  J. K. Patel
Version: 1.0

####################################################################

Useful functions for fibre-array tomographic analysis

####################################################################

-------------------------------------------------------------'''

import sys, time
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np

def normalize(a):
    '''
    Rescales the values of 'a' linearly so that they span 0-1.

    Params
    ------------------
    a:   array-like of numbers (anything np.min / np.max accept)

    Returns
    ------------------
    (a - min(a)) / (max(a) - min(a)). If every element of 'a' is equal the
    range is zero, and an all-zero result is returned instead of dividing
    0 by 0 (which would give NaN).
    '''
    offset = a-np.min(a)
    span = np.max(offset)

    if span == 0:
        # constant input: 'offset' is already all zeros; multiplying by 1.0
        # gives the same floating-point result type the division would
        return offset*1.0

    return offset/span

def scale(arr):
    '''
    Rescales 'arr' so that its smallest non-zero element becomes 1.

    The reference value is np.min over the entries of 'arr' that are not
    equal to zero (so if negative values are present it is the most
    negative one). np.min raises ValueError when 'arr' has no non-zero entry.

    Returns a new float32 array; 'arr' itself is not modified.
    '''
    nonzero_entries = arr[arr != 0.]
    # multiply by the reciprocal (rather than divide) of the reference value
    factor = 1/np.min(nonzero_entries)
    rescaled = arr*factor
    return np.array(rescaled, dtype = np.float32)

def shift(arr, xshift = 0, yshift = 0, rolled_value = 0):
    '''
    Rolls a 2D array, but sets the elements that wrapped around from one edge
    to the opposite edge to 'rolled_value' instead of keeping them.

    Params
    ------------------
    arr:           2D array to shift
    xshift:        number of places to shift along axis 1 (negative shifts the other way)
    yshift:        number of places to shift along axis 0 (negative shifts the other way)
    rolled_value:  value written into the cells vacated by the shift. 0 as default.
                   It is assigned into the rolled array, so it is cast to that array's dtype.

    Returns
    ------------------
    A new array of the same shape as 'arr' (np.roll returns a copy, so the
    input is left untouched).
    '''
    arr = np.roll(arr, (xshift,yshift), axis = (1,0))

    # overwrite the columns that wrapped around; a shift of 0 wraps nothing
    if xshift > 0:
        arr[:, :xshift] = rolled_value
    elif xshift < 0:
        arr[:, xshift:] = rolled_value

    # same for the rows
    if yshift > 0:
        arr[:yshift, :] = rolled_value
    elif yshift < 0:
        arr[yshift:, :] = rolled_value

    return arr


def _format_duration(seconds):
    ''' formats a duration given in seconds as 'HH hrs MM mins SS secs' without wrapping the hours at 24 '''
    # clamp at zero so a negative estimate cannot produce negative fields
    total = max(0, int(round(seconds)))
    minutes, secs = divmod(total, 60)
    hours, minutes = divmod(minutes, 60)
    return f'{hours:02d} hrs {minutes:02d} mins {secs:02d} secs'


def printProgressBar(i,imax,start, pre_text = '', post_text = ''):
    
    ''' 
    Function to print progress of a loop and estimate remaining and total time to complete.
    
    The line is written to sys.stdout followed by a carriage return (no newline),
    so each call overwrites the output of the previous one.
    
    Params
    ------------------
    i:           current iteration of loop (0-based)
    imax:        total number of iterations of loop
    start:       start time in seconds, as returned by time.time() before the loop
    pre_text:    insert optional additional string at start of output. Empty string as default.
    post_text:   insert optional additional string at end of output. Empty string as default.
    
    
    Example
    -----------------
    
    # get current time before loop start
    START = time.time()
    
    # begin loop with counter i
    for i, value in enumerate(some_iterable):
        value.do_something()
        printProgressBar(i, len(some_iterable), START, pre_text = f'Run {run_number}.', post_text = f'This is iteration {i} of {len(some_iterable)}')
        
    Output (example)
    ----------------
    Run 01.[======              ] 30%  complete. Estimated remaining time: 00 hrs 05 mins 37 secs/01 hrs 12 mins 07 secs. 
    
    The function itself returns None.
    '''
    
    n_bar = 20 #size of progress bar
    j = i/imax
    
    if i == 0:
        # if we are on the very first iteration we can't tell how long it might take, so just print that it's unknown
        estimate = 'unknown'
        sys.stdout.write('\r')
        sys.stdout.flush()
        
    else:
        # update elapsed time and recalculate estimated remaining time assuming subsequent iterations will take average of previous iteration times.
        elapsed_time = time.time()-start
        total_time_est = elapsed_time/i*imax
        rem_time = total_time_est-elapsed_time
        estimate = f'{_format_duration(rem_time)}/{_format_duration(total_time_est)}'
        
        # show a full bar on the last iteration of the loop
        if i == imax-1:
            j = 1

    sys.stdout.write(f"{pre_text}[{'=' * int(n_bar * j):{n_bar}s}] {int(100 * j)}%  complete. Estimated remaining time: {estimate}. {post_text}")
    sys.stdout.write('\r')
    sys.stdout.flush()

def add_cbar(fig, ax, plot):
    '''
    Attaches a colorbar for 'plot' to the right-hand side of 'ax'.

    Params
    ------------------
    fig:    figure object whose colorbar() method is called (presumably a matplotlib Figure)
    ax:     axes the colorbar is attached to
    plot:   object handed to fig.colorbar as the thing to draw the colorbar for

    A new axes with 5% of the width of 'ax' and a padding of 0.1 is appended
    to the right of 'ax' and the colorbar is drawn into it. Returns None.
    '''
    colorbar_axes = make_axes_locatable(ax).append_axes('right', size = '5%', pad = 0.1)
    fig.colorbar(plot, cax = colorbar_axes)
