import numpy as np
import scipy.signal as sig

def remove_peaks( signal , peaks , kernel_size = 111 ):
    """
    Remove peaks from a signal and return a smoothed baseline estimate.

    Each peak is flattened on both sides. Starting two samples away from the
    peak, we walk outwards while the signal keeps decreasing (or stays flat).
    Every sample between the stopping point and the peak is then overwritten
    with the signal value at the stopping point. Walks stop at the first and
    last samples of the signal, so peaks close to either edge are handled.
    A median filter finally smooths what is left.

    signal      : 1D array-like of samples (a copy is modified, not the input)
    peaks       : iterable of integer indices into signal
    kernel_size : odd median-filter width given to scipy.signal.medfilt
                  (the default 111 is the historical value)

    Returns the median-filtered, peakless copy of signal.
    """
    peakless_signal = np.array( signal , copy = True )
    last            = len( peakless_signal ) - 1

    for top in peaks:
        # left side: walk down while the signal does not rise, never below 0
        pos , prev = max( top - 2 , 0 ) , max( top - 1 , 0 )
        while pos > 0 and peakless_signal[ pos ] <= peakless_signal[ prev ]:
            prev  = pos
            pos  -= 1
        peakless_signal[ pos : top ] = peakless_signal[ pos ]

        # right side: same walk towards the end, never past the last sample
        pos , prev = min( top + 2 , last ) , min( top + 1 , last )
        while pos < last and peakless_signal[ pos ] <= peakless_signal[ prev ]:
            prev  = pos
            pos  += 1
        peakless_signal[ top : pos ] = peakless_signal[ pos ]

    return sig.medfilt( peakless_signal , kernel_size )

def get_extremities( signal , peaks ):
    """
    Select the first and last calibration peaks among the detected peaks.

    ONLY TRUE FOR THE CURRENT CALIBRATION (Hg):

    * The spectrum begins with a small, broad bump that should contain the
      3810 A line. The bump is bracketed by the highest sample before the
      minimum of the first half of the signal, and by that minimum. The first
      detected peak strictly inside this bracket is selected.
    * The first detected peak after the minimum of the second half of the
      signal (the last int(0.1 * len(signal) // 2) samples are excluded from
      that search) should be the 5079 A line.

    signal : 1D array-like spectrum
    peaks  : 1D array-like of peak positions (indices into signal); presumably
             sorted in increasing order, since the first match is taken —
             TODO confirm with caller

    Returns ( first , last ): the indices INTO peaks (not into signal) of the
    two selected peaks.

    Raises ValueError( 'unknown plage, cannot autocalibrate' ) when either
    peak cannot be located.
    """
    signal = np.asarray( signal )
    peaks  = np.asarray( peaks )
    size   = len( signal )
    half   = size // 2
    error  = 'unknown plage, cannot autocalibrate'

    # --- start of the spectrum -------------------------------------------
    if half == 0:
        raise ValueError( error )
    argmin = np.argmin( signal[ : half ] )
    if argmin == 0:
        # nothing precedes the minimum, so there is no bump to bracket
        raise ValueError( error )
    argmax = np.argmax( signal[ : argmin ] )
    start_candidates = np.flatnonzero( ( argmax < peaks ) & ( peaks < argmin ) )
    if start_candidates.size == 0:
        raise ValueError( error )

    # --- end of the spectrum ---------------------------------------------
    # The trailing samples are excluded from the minimum search. An explicit
    # end index is used because a negative slice bound of -0 would make the
    # slice empty for short signals.
    trim    = int( 0.1 * size // 2 )
    end_min = np.argmin( signal[ half : size - trim ] ) + half
    end_candidates = np.flatnonzero( end_min < peaks )
    if end_candidates.size == 0:
        raise ValueError( error )

    return ( start_candidates[ 0 ] , end_candidates[ 0 ] )

def _normalized_gaps( positions ):
    """
    Return the gaps between consecutive positions, min-max scaled to [0, 1].

    When every gap is equal the scaling is undefined; the gaps are then
    returned as all zeros instead of NaN.
    """
    gaps  = np.diff( np.asarray( positions ) ).astype( float )
    gaps -= np.min( gaps )
    span  = np.max( gaps )
    if span > 0:
        gaps /= span
    return gaps


def only_keep_calib( peaks_data , peaks_calib , tol = 0.002 ):
    """
    Only keep the data peaks that correspond to calibration peaks.

    The gaps between consecutive peaks are min-max normalised separately for
    the data and for the calibration. The first data peak is taken as the
    first calibration peak. For each calibration gap, consecutive data gaps
    are then accumulated until their sum matches the calibration gap within
    tol. Data peaks falling inside an accumulated run are treated as
    spurious and dropped.

    NOTE(review): each normalised gap has the minimum gap subtracted, so a
    sum of several data gaps subtracts that minimum several times. A merged
    match therefore only succeeds when the skipped gaps are (close to) the
    minimum gap.

    peaks_data  : 1D array-like of detected peak positions
    peaks_calib : 1D array-like of calibration peak positions
    tol         : matching tolerance on the normalised gaps

    Returns an int array with one data position per calibration peak.

    Raises ValueError when fewer than two peaks are given on either side, or
    when a calibration gap is overshot or cannot be matched.
    """
    if len( peaks_data ) < 2 or len( peaks_calib ) < 2:
        raise ValueError(
            'at least two data and two calibration peaks are needed' )
    peaks_data = np.asarray( peaks_data )
    diff_calib = _normalized_gaps( peaks_calib )
    diff_data  = _normalized_gaps( peaks_data )

    kept  = [ 0 ]   # indices into peaks_data, one per calibration peak
    start = 0       # first data gap not consumed yet
    for target in diff_calib:
        total = 0.0
        for j in range( start , len( diff_data ) ):
            total  += diff_data[ j ]
            excess  = total - target
            if excess > tol:
                raise ValueError(
                    'reference peak not found (overshoot: {})'.format( excess ) )
            if excess > - tol:
                break
        else:
            # data gaps exhausted (or none left) without reaching the target
            raise ValueError( 'reference peak not found and not exceeded' )
        # the matched run ends on data peak j + 1, where the next gap starts
        start = j + 1
        kept.append( start )

    return peaks_data[ kept ].astype( int )