#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import medfilt , find_peaks
from scipy.optimize import curve_fit
import utils
import sys
import shelve
import pathlib
from scipy.ndimage import rotate
from astropy.io import fits

# -- Command-line parsing --------------------------------------------------
# Short options (-c/-o/-w/-i) are rewritten in place into their long
# '--name=value' form so a single parser below handles both spellings.
cache , filename , output , calibration , intensity , verbose , no_cache = '' , None , None , None , None , False , False
if len( sys.argv ) < 2:
    raise Exception( 'spectrum.py: type \'spectrum.py -h\' for more information' )

argv , i = sys.argv[ 1 : ] , 0
while i < len( argv ):
    arg = argv[ i ]
    if arg[0] == '-':
        if len( arg ) < 2:
            raise Exception( 'spectrum.py: unknown argument, type \'spectrum.py -h\' for more information' )
        if arg[1] != '-':
            if arg == '-h':
                arg = '--help'
            elif arg == '-V':
                arg = '--version'
            elif arg == '-v':
                arg = '--verbose'
            elif arg == '-n':
                arg = '--no-cache' # was `arg == '--no-cache'` (comparison, not assignment): -n was silently ignored
            elif arg == '-c':
                # value-taking options must check against len( argv ), not
                # len( sys.argv ): argv is sys.argv[1:], so the old test
                # could never fire and a missing value crashed with IndexError
                if i == len( argv ) - 1:
                    raise Exception( 'spectrum.py: cache have to take a value' )
                argv[ i + 1 ] = '--cache=' + argv[ i + 1 ]
                i += 1
                continue
            elif arg == '-o':
                if i == len( argv ) - 1:
                    raise Exception( 'spectrum.py: output have to take a value' )
                argv[ i + 1 ] = '--output=' + argv[ i + 1 ]
                i += 1
                continue
            elif arg == '-w':
                if i == len( argv ) - 1:
                    raise Exception( 'spectrum.py: wavelength have to take a value' )
                argv[ i + 1 ] = '--wavelength=' + argv[ i + 1 ]
                i += 1
                continue
            elif arg == '-i':
                if i == len( argv ) - 1:
                    raise Exception( 'spectrum.py: intensity have to take a value' )
                argv[ i + 1 ] = '--intensity=' + argv[ i + 1 ]
                i += 1
                continue
            else:
                raise Exception( 'spectrum.py: unknown argument "' + arg + '", type \'spectrum.py -h\' for more information' )
        if arg[1] == '-': # not elif because arg can change after last if
            if arg == '--help':
                print( 'spectrum.py [options...] filename\
                      \n    -w --wavelength  wavelength calibration file, default to no calibration.\
                      \n                     No calibration means no wavelength interpolation\
                      \n    -i --intensity   intensity calibration file, default to no calibration.\
                      \n                     No calibration means no intensity interpolation\
                      \n    -c --cache       use given cache\
                      \n    -h --help        show this help and quit\
                      \n    -n --no-cache    do not use cache and rewrite it\
                      \n    -o --output      output file, default to standard output\
                      \n    -V --version     show version number and quit\
                      \n    -v --verbose     show more information to help debugging\
                      \n\
                      \nParse a naroo spectrum fits' )
                exit()
            elif arg == '--version':
                print( '0.3' )
                exit()
            elif arg == '--verbose':
                verbose = True
            elif arg == '--no-cache':
                no_cache = True
            elif len( arg ) > 8 and arg[ : 8 ] == '--cache=':
                cache = arg[ 8 : ]
            elif len( arg ) > 9 and arg[ : 9 ] == '--output=':
                output = arg[ 9 : ]
            elif len( arg ) > 13 and arg[ : 13 ] == '--wavelength=':
                calibration = arg[ 13 : ]
            elif len( arg ) > 12 and arg[ : 12 ] == '--intensity=':
                intensity = arg[ 12 : ]
            else:
                raise Exception( 'spectrum.py: unknown argument "' + arg + '", type \'spectrum.py -h\' for more information' )
        else:
            raise Exception( 'spectrum.py: this exception should never be raised' )
    else:
        filename = arg
    i += 1
if filename is None:
    raise Exception( 'spectrum.py: filename should be given' )

if verbose:
    print( f'spectrum.py: launching now with parameters:\
           \n    --filename:    (unknown)\
           \n    --cache:       {cache} ( default: \'\' )\
           \n    --wavelength:  {calibration} ( default to None )\
           \n    --intensity:   {intensity} ( default to None )\
           \n    --output:      {output} ( default to None )\
           \n    --verbose:     True ( default to False)\
           \n\
           \n===========================================' )
# TODO: check in advance file to check if exists or writeable

# Load the plate image and its FITS header; only the primary HDU is used.
hdul = fits.open( filename )
data = hdul[0].data
head = hdul[0].header
hdul.close()
if verbose:
    print( 'data loaded' )

cache_file = pathlib.Path( cache )

# Reuse intermediate results from a previous run when a cache file exists
# and the user did not ask to bypass it.
if cache_file.is_file() and not no_cache:
    if verbose:
        print( 'using cache' )
    with shelve.open( str( cache_file ) ) as store:
        # 'spectrum' was missing from this list even though it is read
        # below, so a partial cache raised a bare KeyError instead of the
        # explicit error message.
        for key in [ 'data' , 'border' , 'spectrum' , 'calibrations' ]:
            if key not in store:
                raise Exception( 'spectrum.py: missing data in cache file' )
        data         = store[ 'data' ]
        border       = store[ 'border' ]
        spectrum     = store[ 'spectrum' ]
        calibrations = store[ 'calibrations' ]
else:
    # No usable cache: run the full detection pipeline (and possibly store
    # the result in the cache at the end).
    if verbose:
        print( 'not using cache' )
        print( 'starting first zoom' )
    """
    find fill point
    """
    # Seed points near each edge of the frame, used below as flood-fill
    # starting positions.
    points = []

    points += utils.find_point( data[ : , 0 ] , 0 ) # x_min
    points += utils.find_point(
        data[ : , data.shape[1] - 1 ],
        data.shape[1] - 1
    ) # x_max

    # Walk down from the top until a row yields exactly 3 detected points,
    # then keep the middle one as the top seed.
    # NOTE(review): utils.find_point's semantics are not visible here — the
    # "3 points" criterion is assumed to mark a row crossing the exposed
    # plate area; confirm against utils.
    index_min = 0
    while data.shape[0] - 1 > index_min:
        index_min += 1
        if len( utils.find_point(
            data[ index_min , : ],
            index_min            ,
            'y'                  ,
        ) ) == 3:
            break
    points.append(
        utils.find_point(
            data[ index_min , : ],
            index_min            ,
            'y'                  ,
        )[1]
    ) # y_min

    # Same scan, bottom-up, for the bottom seed.
    index_max = data.shape[0] - 1
    while index_min < index_max:
        index_max -= 1
        if len( utils.find_point(
            data[ index_max , : ],
            index_max            ,
            'y'                  ,
        ) ) == 3:
            break

    points.append(
        utils.find_point(
            data[ index_max , : ],
            index_max            ,
            'y'                  ,
        )[1]
    )

    # Work on a 5x downscaled copy to keep the flood fill cheap; the seed
    # coordinates are converted to the same scale.
    small_data = utils.compress( data , 5 )
    points     = utils.big_to_small( points , 5 )

    # size - 1
    points[ 1 ][ 1 ] -= 1
    points[ 3 ][ 0 ] -= 1

    # little shift to be inside the light
    points[ 2 ][ 1 ] += 3
    points[ 3 ][ 1 ] += 3

    """
    fill data
    """

    extremum = []
    for point in points:
        if point[0] < points[2][0]:
            point[0] = points[2][0]
        if point[1] < points[0][1]:
            point[1] = points[0][1]
        taken_points = utils.small_to_big(
            np.array( [
                points[2][0],
                points[0][1],
            ] ) + utils.fill(
                small_data[
                    points[2][0] : points[3][0] + 1,
                    points[0][1] : points[1][1] + 1,
                ]   ,
                [
                    point[0] - points[2][0],
                    point[1] - points[0][1],
                ]   ,
                1000,
            ),
            5,
        )
        extremum.append( [
            np.min( taken_points[ : , 1 ] ),
            np.max( taken_points[ : , 1 ] ),
            np.min( taken_points[ : , 0 ] ),
            np.max( taken_points[ : , 0 ] ),
        ] )

    border = {
        'x': {
            'min': points[0][1] + extremum[0][1] + 1,
            'max': points[0][1] + extremum[1][0]    ,
        },
        'y': {
            'min': points[2][0] + extremum[2][3] + 1,
            'max': points[2][0] + extremum[3][2]    ,
        },
    }

    if verbose:
        print( 'first zoom finished'      )
        print( 'starting rotation'        )
        print( 'retrieving current angle' )

    """
    Rotation
    """

    # Gaussian profile with a constant offset: a * N(mu, sigma)(x) + b.
    # Fitted on vertical slices to locate the spectrum's centre row.
    gauss = lambda x , sigma , mu , a , b : a * (
            1 / ( sigma * np.sqrt( 2 * np.pi ) )
        ) * np.exp(
            - ( x - mu ) ** 2 / ( 2 * sigma ** 2 )
        ) + b
    # NOTE(review): the mu seed uses the *x* border width although the fit
    # runs along y — it appears to work as a rough initial guess; confirm.
    guess_params = [
        1                                                      ,
        ( border[ 'x' ][ 'max' ] - border[ 'x' ][ 'min' ] ) / 2,
        np.mean( data[
            border[ 'y' ][ 'min' ] : border[ 'y' ][ 'max' ],
            border[ 'x' ][ 'min' ] : border[ 'x' ][ 'max' ]
        ] )                                                    ,
        np.mean( data[
            border[ 'y' ][ 'min' ] : border[ 'y' ][ 'max' ],
            border[ 'x' ][ 'min' ] : border[ 'x' ][ 'max' ]
        ] )                                                    ,
    ]
    # Fit 1000 evenly spaced columns; a failed fit is recorded as 0 and
    # discarded afterwards.
    number         = 1000
    position_peaks = np.zeros( number )
    indexes        = np.linspace( border[ 'x' ][ 'min' ] , border[ 'x' ][ 'max' ] , number , dtype = int )
    # x is the same for every column: hoisted out of the loop
    x = np.arange(
        border[ 'y' ][ 'min' ],
        border[ 'y' ][ 'max' ],
        1                     ,
    )
    for i in range( number ):
        y = data[
            border[ 'y' ][ 'min' ] : border[ 'y' ][ 'max' ],
            indexes[ i ]
        ]

        try:
            position_peaks[ i ] = curve_fit(
                gauss       ,
                x           ,
                y           ,
                guess_params,
            )[0][1]
        except Exception: # was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
            position_peaks[ i ] = 0
    # Drop failed fits, median-filter the centre positions, then fit a line:
    # its slope is the tangent of the plate's tilt angle.
    position_peaks = medfilt(
        position_peaks[ np.where( position_peaks != 0 ) ],
        11                                               ,
    )
    abciss         = np.arange(
        len( position_peaks )
    )[ np.where( position_peaks ) ]
    # keep x and y the same length: medfilt can zero out samples and
    # np.polyfit requires matching arrays (the original passed the full
    # position_peaks and crashed whenever the filter produced a zero)
    polyval        = np.polyfit( abciss , position_peaks[ np.where( position_peaks ) ] , 1 )

    angle = np.arctan( polyval[0] )

    if verbose:
        print( 'current angle retrieved: ' + str( angle ) )
        print( 'starting image rotation'                  )

    data = rotate( data , angle * ( 180 / np.pi ) ) # utils.rotate does not keep intenisty absolute value TODO

    if verbose:
        print( 'image rotation finished' )
        print( 'rotation finished'           )
        print( 'starting spectrum isolation' )
        print( 'starting y border detection' )

    """
    Spectrum y
    """

    list_ = np.mean(
        data[
            border[ 'y' ][ 'min' ] : border[ 'y' ][ 'max' ],
            border[ 'x' ][ 'min' ] : border[ 'x' ][ 'max' ],
        ]       ,
        axis = 1,
    )
    indexes = utils.near_value(
        list_                                                        ,
        np.mean( list_ ) + ( np.max( list_ ) - np.mean( list_ ) ) / 10,
    ) + border[ 'y' ][ 'min' ]

    spectrum = {
        'y': {
            'min': indexes[0],
            'max': indexes[1],
        },
    }

    if verbose:
        print( 'y border detection finished' )
        print( 'starting x border detection' )

    """
    Spectrum x
    """

    list_ = np.convolve(
        np.mean(
            data[
                spectrum[ 'y' ][ 'min' ] : spectrum[ 'y' ][ 'max' ],
                border[ 'x' ][ 'min' ]   : border[ 'x' ][ 'max' ]  ,
            ]       ,
            axis = 0,
        )             ,
        np.ones( 200 ),
        'valid'       ,
    )
    abciss = np.arange(
        border[ 'x' ][ 'min' ] + 100,
        border[ 'x' ][ 'max' ] - 99 ,
        1                           ,
    )
    indexes = utils.near_value(
        list_                                                   ,
        np.min( list_ ) + ( np.mean( list_ ) - np.min( list_ ) ),
    )
    factor = 1
    while len( indexes ) == 2:
        factor += 1
        indexes = utils.near_value(
            list_                                                            ,
            np.min( list_ ) + ( np.mean( list_ ) - np.min( list_ ) ) / factor,
        )
    factor -= 1
    indexes = utils.near_value(
        list_                                                            ,
        np.min( list_ ) + ( np.mean( list_ ) - np.min( list_ ) ) / factor,
    ) + border[ 'x' ][ 'min' ] + 100 # valid convolution only

    spectrum[ 'x' ] = {
        'min': indexes[ 0 ],
        'max': indexes[ 1 ],
    }

    if verbose:
        print( 'x border detection finished'    )
        print( 'spectrum isolation finished'    )
        print( 'starting calibration isolation' )

    """
    Calibration
    """

    def indicator( list_ ):
        """
        define an indicator which define if the horizontal slice has a
        chance to be a part of a calibration
        """
        list_  = list_.copy()    # do not change data
        list_ -= np.min( list_ ) # min 0
        list_ /= np.max( list_ ) # max 1

        amplitude = np.mean( list_ ) # the lower the better
        peaks     = find_peaks(
            list_                                                   ,
            height = np.mean( list_ ) + ( 1 - np.mean( list_ ) ) / 2,
        )[0]
        number    = 0.01 + abs( len( peaks ) - 90 ) # the lower the better
        intensity = np.sum( list_[ peaks ] )
        return intensity / ( amplitude ** 2 * number )

    indicators = np.convolve(
        np.array( [
            indicator(
                data[
                    i                                                  ,
                    spectrum[ 'x' ][ 'min' ] : spectrum[ 'x' ][ 'max' ],
                ],
            ) for i in range(
                border[ 'y' ][ 'min' ],
                border[ 'y' ][ 'max' ],
                1                     ,
            )
        ] )          ,
        np.ones( 10 ),
        'valid'      ,
    )
    indicators       /= np.max( indicators )
    calibration_areas = utils.consecutive( np.where( indicators > 1 / 1000 )[0] )
    calibration_areas = [
        [ calibration_area for calibration_area in calibration_areas if (
            calibration_area[0] < (
                border[ 'y' ][ 'max' ] - border[ 'y' ][ 'min' ]
            ) / 2
        ) ],
        [ calibration_area for calibration_area in calibration_areas if (
            calibration_area[0] > (
                border[ 'y' ][ 'max' ] - border[ 'y' ][ 'min' ]
            ) / 2
        ) ],
    ]
    calibration_sizes = [
        [ len( calibration_area ) for calibration_area in calibration_areas[0] ],
        [ len( calibration_area ) for calibration_area in calibration_areas[1] ],
    ]
    calibrations_y = [
        calibration_areas[0][
            np.argmax( calibration_sizes[0] )
        ],
        calibration_areas[1][
            np.argmax( calibration_sizes[1] )
        ],
    ]

    # Pixel boxes of the two calibration strips; they share the spectrum's
    # x range and the y ranges are offset back to absolute rows.
    calibrations = {
        'top': {
            'x': {
                'min': spectrum[ 'x' ][ 'min' ],
                'max': spectrum[ 'x' ][ 'max' ],
            },
            'y': {
                'min': border[ 'y' ][ 'min' ] + calibrations_y[0][0],
                'max': border[ 'y' ][ 'min' ] + calibrations_y[0][-1],
            },
        },
        'down': {
            'x': {
                'min': spectrum[ 'x' ][ 'min' ],
                'max': spectrum[ 'x' ][ 'max' ],
            },
            'y': {
                'min': border[ 'y' ][ 'min' ] + calibrations_y[1][0],
                'max': border[ 'y' ][ 'min' ] + calibrations_y[1][-1],
            },
        }
    }

    if verbose:
        print( 'calibration isolation finished' )

    # Write the cache when a path was given and either the file does not
    # exist yet or --no-cache asked for a rewrite.  The help text says -n
    # means "do not use cache and rewrite it"; the previous condition
    # (`not cache_file.exists() and not no_cache`) never rewrote with -n.
    if cache != '' and ( no_cache or not cache_file.exists() ):
        if verbose:
            print( 'writing result in cache' )
        with shelve.open( str( cache_file ) ) as cache:
            cache[ 'data' ]         = data
            cache[ 'border' ]       = border
            cache[ 'spectrum' ]     = spectrum
            cache[ 'calibrations' ] = calibrations
        if verbose:
            print( 'cache saved' )

"""
Calibration
"""

wavelengths = np.arange( spectrum[ 'x' ][ 'max' ] - spectrum[ 'x' ][ 'min' ] )

if calibration != None:
    if verbose:
        print( 'starting wavelength calibration' )

    mean_data = np.mean( data[
        spectrum[ 'y' ][ 'min' ] : spectrum[ 'y' ][ 'max' ],
        spectrum[ 'x' ][ 'min' ] : spectrum[ 'x' ][ 'max' ]
    ] , axis = 0 )
    abciss    = np.arange( len( mean_data ) )

    ref = np.array( [
        6562.79,
        4861.35,
        4340.47,
        4101.73,
        3970.08,
        3889.06,
        3835.40,
#        3646   ,
    ] ) * 1e-10
    start , end = 5000 , 18440

    polyval_before = np.polyfit( abciss[       : start ] , mean_data[       : start ] , 2 )
    polyval_middle = np.polyfit( abciss[ start : end   ] , mean_data[ start : end   ] , 2 )
    polyval_end    = np.polyfit( abciss[ end   :       ] , mean_data[ end   :       ] , 2 )

    mean_data[       : start ] = mean_data[       : start ] - np.polyval( polyval_before , abciss[       : start ] )
    mean_data[ start : end   ] = mean_data[ start : end   ] - np.polyval( polyval_middle , abciss[ start : end   ] )
    mean_data[ end   :       ] = mean_data[ end   :       ] - np.polyval( polyval_end    , abciss[ end   :       ] )

    mean_data_normalized  = mean_data.copy() # normalization
    mean_data_normalized -= np.min( mean_data_normalized )
    mean_data_normalized /= np.max( mean_data_normalized )

    lines = [ np.mean( cons ) for cons in utils.consecutive( np.where( mean_data_normalized < 0.3 )[0] ) ]
    lines = np.array( lines[ : len( ref ) - 1 ] ) # Balmer discontinuity

    ref   = ref[ 1 : ] # start with H-beta

    wavelength_polyval = np.polyfit( lines , ref , 1 )

    wavelengths = np.polyval( wavelength_polyval , abciss )

    if verbose:
        print( 'wavelength calibration finished' )

if verbose:
    print( 'starting bias substraction' )

# Background (bias) level estimated from 100-row strips just outside the two
# calibration strips: above the 'top' one and below the 'down' one.
bias = {
    'top':  np.mean(
        data[
            calibrations[ 'top' ][ 'y' ][ 'min' ] - 100 :
            calibrations[ 'top' ][ 'y' ][ 'min' ]       ,
            calibrations[ 'top' ][ 'x' ][ 'min' ]       :
            calibrations[ 'top' ][ 'x' ][ 'max' ]
        ]       ,
        axis = 0,
    ),
    'down': np.mean(
        data[
            calibrations[ 'down' ][ 'y' ][ 'max' ]      :
            calibrations[ 'down' ][ 'y' ][ 'max' ] + 100,
            calibrations[ 'down' ][ 'x' ][ 'min' ]      :
            calibrations[ 'down' ][ 'x' ][ 'max' ]
        ]       ,
        axis = 0,
    ),
}

# NOTE(review): mean_bias is computed but never subtracted from the data
# below, despite the 'bias substraction' messages — confirm whether the
# subtraction step is missing.
mean_bias = np.mean( [ bias[ 'top' ] , bias[ 'down' ] ] , axis = 0 )

if verbose:
    print( 'bias substraction finished' )

# Mean spectrum profile over the spectrum's rows (recomputed here so it is
# available whether or not the wavelength-calibration branch ran).
mean_data = np.mean( data[
    spectrum[ 'y' ][ 'min' ] : spectrum[ 'y' ][ 'max' ],
    spectrum[ 'x' ][ 'min' ] : spectrum[ 'x' ][ 'max' ]
] , axis = 0 )

if intensity != None:
    if verbose:
        print( 'starting intensity calibration' )

    intensity_file = pathlib.Path( intensity )

    # ETA calibration data: a stack of exposure "stairs" per wavelength.
    # NOTE(review): assumed shape ( n_stairs , n_wavelengths , >=1 ) with
    # stair intensity decreasing as the stair index grows — confirm against
    # the ETA producer.
    with shelve.open( str( intensity_file ) ) as storage:
        intensity_stairs      = storage[ 'data'       ]
        intensity_wavelengths = storage[ 'wavelength' ] * 1e-10

    wavelengths = wavelengths[ # remove wavelengths outside range
        np.where(
            np.logical_and(
                wavelengths > np.min( intensity_wavelengths ),
                wavelengths < np.max( intensity_wavelengths ),
            ),
        )
    ]
    intensity_wavelengths = intensity_wavelengths[ # remove intensity_wavelengths outside range
        np.where(
            np.logical_and(
                intensity_wavelengths > np.min( wavelengths ),
                intensity_wavelengths < np.max( wavelengths ),
            ),
        )
    ]

    if len( wavelengths ) == 0:
        raise Exception( 'spectrum.py: spectrum and ETA does not share any common wavelengths' )

    step = 0.2 # depends of ETA source #TODO

    # NOTE(review): mean_data is indexed below with the indexes of the
    # *trimmed* wavelengths array — confirm the two stay aligned after the
    # range filtering above.
    final_intensity = np.zeros( len( wavelengths ) )

    # For each sample, bracket its measured value between two adjacent
    # stairs and linearly interpolate a fractional stair index, which is
    # then converted to a (log-scale) exposure value.
    for index in range( len( wavelengths ) ):
        intensity_value      = mean_data[ index ]
        intensity_wavelength = wavelengths[ index ]
        intensity_index      = utils.near_value( # list of index corresponding to index for intensity_stairs
            intensity_wavelengths,
            intensity_wavelength ,
        )

        if len( intensity_index ) != 1: # too much or no intensity found near value
            final_intensity[ index ] = - 1
            continue
        intensity_index = intensity_index[0]

        indexes_stair_lower = np.where( # stairs lower than intensity value
            intensity_stairs[
                :              ,
                intensity_index,
                0
            ] < intensity_value
        )[0]

        if len( indexes_stair_lower ) == 0: # intensity value outside ETA (below)
            final_intensity[ index ] = 0
            continue
        if len( indexes_stair_lower ) != intensity_stairs.shape[0] - indexes_stair_lower[0]: # stairs intensity does not decrease with index as it should
                                                                                             # could indicate an artefact in ETA
            indexes_stair_lower = [
                int( np.mean(
                    [
                        intensity_stairs.shape[0] -
                        indexes_stair_lower[0]     ,
                        len( indexes_stair_lower )
                    ]
                ) )
            ]

        indexes_stair_higher = np.where( # stairs higher than intensity value
            intensity_stairs[
                :              ,
                intensity_index,
                0
            ] > intensity_value
        )[0]

        if len( indexes_stair_higher ) == 0: # intensity value outside ETA (upper)
            final_intensity[ index ] = intensity_stairs.shape[0] * step
            continue
        if len( indexes_stair_higher ) - 1 != indexes_stair_higher[-1]: # stairs intensity does not decrease with index as it should
            indexes_stair_higher = [
                int( np.mean(
                    [
                        indexes_stair_higher[ - 1 ]    ,
                        len( indexes_stair_higher ) - 1
                    ]
                ) )
            ]
            # NOTE(review): the assignment just above is immediately
            # overwritten by the one below, so the averaged value is dead
            # code — confirm which of the two corrections is intended.
            indexes_stair_higher = [
                indexes_stair_lower[ 0 ] - 1,
            ]

        index_stair = {
            'higher': indexes_stair_higher[-1],
            'lower' : indexes_stair_lower[0]  ,
        }

        if index_stair[ 'lower' ] - index_stair[ 'higher' ] != 1: # ETA curve should decrease
            raise Exception( 'spectrum.py: given intensity stairs (from ETA) are missformed' )

        stair_intensity = {
            'higher': intensity_stairs[ index_stair[ 'higher' ] , intensity_index , 0 ],
            'lower' : intensity_stairs[ index_stair[ 'lower'  ] , intensity_index , 0 ],
        }

        # Linear map (intensity -> fractional stair index) through the two
        # bracketing stairs.
        index_polyval = np.polyfit(            # fraction stair index from intensity value
            [ stair_intensity[ 'higher' ] , stair_intensity[ 'lower' ] ],
            [     index_stair[ 'higher' ] ,     index_stair[ 'lower' ] ],
            1                                                           ,
        )

        true_intensity_value = ( intensity_stairs.shape[0] - np.polyval( index_polyval , intensity_value ) ) * step

        final_intensity[index] = np.exp( true_intensity_value )

    if verbose:
        print( 'intensity calibration finished' )

if verbose:
    print( 'starting output writing' )

if intensity is None:
    # No intensity calibration requested: fall back to the raw mean spectrum.
    # final_intensity only exists when the calibration branch above ran, so
    # the original crashed with NameError whenever -i was not given, even
    # though the help text allows running without it.
    final_intensity = mean_data

if output == None:
    print( final_intensity[1:-1] )
else:
    if verbose:
        print( 'storing result in ' + output )
    main_hdu  = fits.PrimaryHDU( final_intensity[1:-1] ) # remove -1
    # Linear wavelength axis WCS.
    # NOTE(review): the data are trimmed with [1:-1] but CRVAL1 uses
    # wavelengths[0] — confirm whether it should be wavelengths[1].
    main_hdu.header[ 'CRVAL1'   ] = wavelengths[0]
    main_hdu.header[ 'CDELT1'   ] = wavelengths[1] - wavelengths[0]
    main_hdu.header[ 'CTYPE1'   ] = 'Wavelength'
    main_hdu.header[ 'CUNIT1'   ] = 'Angstrom'
    main_hdu.header[ 'OBJNAME'  ] = head[ 'OBJECT' ]
    # missing from Naroo, to be recovered from the companion txt file
    #main_hdu.header[ 'OBSERVER' ] = head[ '' ]
    #main_hdu.header[ 'DATE-OBS' ] = head[ '' ]
    #main_hdu.header[ 'EXPTIME'  ] = head[ '' ]
    # NOTE(review): 'FKS' looks like a typo for the 'FK5' reference frame —
    # confirm before changing, downstream tools may match on this value.
    main_hdu.header[ 'RADECSYS' ] = 'FKS'
    main_hdu.header[ 'OBS-ID'   ] = head[ 'OBS_ID' ]
    main_hdu.header[ 'DATE-NUM' ] = head[ 'DATE' ]
    main_hdu.header[ 'COMMENT'  ] = head[ 'COMMENT' ]
    main_hdu.header[ 'POLICY'   ] = head[ 'POLICY' ]

    # BSS keywords
    main_hdu.header[ 'BSS_VHEL' ] = 0
    main_hdu.header[ 'BSS_TELL' ] = 'None'
    main_hdu.header[ 'BSS_COM'  ] = 'None'
    # missing from Naroo
    #main_hdu.header[ 'BSS_INST' ] = head[ '' ]
    #main_hdu.header[ 'BSS_SITE' ] = head[ '' ]

    hdul = fits.HDUList( [ main_hdu ] )
    hdul.writeto( output , overwrite = True )
if verbose:
    print( 'output writing finished' )
    print( '===========================================\
          \nend of spectrum.py' )