import numpy as np

def check_side( data , point , tolerance ):
    """
    Return the 4-connected neighbours of ``point`` whose intensity differs from
    the intensity at ``point`` by at most ``tolerance``.

    Parameters
    ----------
    data : np.ndarray or list
        2D grid of intensities (a nested list is converted with ``np.asarray``).
    point : tuple, list or np.ndarray
        ``(row, column)`` coordinate of the reference pixel.
    tolerance : int or float
        Maximum absolute intensity difference (inclusive).

    Returns
    -------
    list of [row, column]
        Accepted neighbours, checked in the order below, above, right, left.

    Raises
    ------
    ValueError
        If an argument does not have an accepted type.
    """
    if not isinstance( data , ( np.ndarray , list ) ):
        raise ValueError( 'data must be a list, ' + type( data ).__name__ + ' given' )
    if not isinstance( point , ( np.ndarray , tuple , list ) ):
        raise ValueError( 'point must be a tuple, ' + type( point ).__name__ + ' given' )
    if not isinstance( tolerance , ( int , float ) ):
        raise ValueError( 'tolerance must be a number, ' + type( tolerance ).__name__ + ' given' )

    data      = np.asarray( data )
    row , col = point[0] , point[1]
    # Compute the bounds as Python floats: with unsigned integer images
    # (e.g. uint8) ``intensity - tolerance`` would otherwise wrap around.
    intensity = float( data[ tuple( point ) ] )
    low , high = intensity - tolerance , intensity + tolerance

    def similar( r , c ):
        return low <= data[ r , c ] <= high

    positions = []
    if 0 <= row < data.shape[0] - 1 and similar( row + 1 , col ):
        positions.append( [ row + 1 , col ] )
    if 0 < row < data.shape[0] and similar( row - 1 , col ):
        positions.append( [ row - 1 , col ] )
    if 0 <= col < data.shape[1] - 1 and similar( row , col + 1 ):
        positions.append( [ row , col + 1 ] )
    if 0 < col < data.shape[1] and similar( row , col - 1 ):
        positions.append( [ row , col - 1 ] )
    return positions

def fill( data , point , tolerance , limit = 100000 ):
    """
    Flood-fill outward from ``point`` and return every coordinate reached.

    Points are visited breadth-first. A neighbour (found with ``check_side``)
    is accepted when its intensity is within ``tolerance`` of the point it is
    reached from, not necessarily of the seed.

    Parameters
    ----------
    data : np.ndarray or list
        2D grid of intensities.
    point : tuple, list or np.ndarray
        ``(row, column)`` seed of the fill.
    tolerance : int or float
        Maximum intensity difference between neighbouring filled points.
    limit : int
        Maximum number of points visited; the fill stops once it is reached.

    Returns
    -------
    np.ndarray
        Visited ``[row, column]`` pairs in visiting order, shape ``(k, 2)``
        (an empty array when ``limit <= 0``).

    Raises
    ------
    ValueError
        If an argument does not have an accepted type.
    """
    if not isinstance( data , ( np.ndarray , list ) ):
        raise ValueError( 'data must be a list, ' + type( data ).__name__ + ' given' )
    if not isinstance( point , ( np.ndarray , tuple , list ) ):
        raise ValueError( 'point must be a tuple, ' + type( point ).__name__ + ' given' )
    if not isinstance( tolerance , ( int , float ) ):
        raise ValueError( 'tolerance must be a number, ' + type( tolerance ).__name__ + ' given' )
    if not isinstance( limit , int ):
        raise ValueError( 'limit must be an integer, ' + type( limit ).__name__ + ' given' )

    from collections import deque

    data  = np.asarray( data )
    # Coordinates are normalised to tuples so that the seed (whatever its
    # container type) and the neighbours compare and hash consistently.
    start       = tuple( point )
    taken_point = []
    queue       = deque( [ start ] )
    seen        = { start }  # every point ever queued: O(1) membership checks
    steps       = 0
    while queue and steps < limit:
        current = queue.popleft()
        taken_point.append( current )
        for position in check_side( data , current , tolerance ):
            key = tuple( position )
            if key not in seen:
                seen.add( key )
                queue.append( key )
        steps += 1
    return np.array( taken_point )

def point( index_1 , index_2 , axis = 'x' ):
    """
    Order a pair of indexes according to ``axis``.

    Parameters
    ----------
    index_1, index_2 : int
        Indexes to order (Python or NumPy integers).
    axis : str
        ``'x'`` returns ``[index_2, index_1]``; ``'y'`` returns
        ``[index_1, index_2]``.

    Returns
    -------
    list of two ints

    Raises
    ------
    ValueError
        If an index is not an integer, or ``axis`` is not ``'x'`` / ``'y'``.
    """
    if not isinstance( index_1 , ( int , np.integer ) ):
        raise ValueError( 'index_1 must be an integer, ' + type( index_1 ).__name__ + ' given' )
    if not isinstance( index_2 , ( int , np.integer ) ):
        raise ValueError( 'index_2 must be an integer, ' + type( index_2 ).__name__ + ' given' )
    if not isinstance( axis , str ):
        raise ValueError( 'axis must be a string, ' + type( axis ).__name__ + ' given' )
    if axis not in [ 'x' , 'y' ]:
        raise ValueError( 'axis must be "x" or "y", ' + axis + ' given' )
    if axis == 'x':
        return [ index_2 , index_1 ]
    return [ index_1 , index_2 ]

def find_point( list_ , index , axis = 'x' , threshold = 0.95 ):
    """
    Find the positions along a 1D profile where a fill should be started.

    If the profile's amplitude (max - min) is below half its mean, a single
    coordinate at position 0 is returned. Otherwise the profile is rescaled to
    [0, 1] and a coordinate is emitted at the first sample of every run that
    rises above ``threshold``. A run ends once more than 1 % of the profile
    length has been counted below ``threshold`` since it started.

    Parameters
    ----------
    list_ : list or np.ndarray
        1D profile (integer or float values).
    index : int
        Fixed coordinate of the profile, forwarded to ``point``.
    axis : str
        ``'x'`` or ``'y'``, forwarded to ``point``.
    threshold : float
        Level, in the rescaled [0, 1] profile, a run must exceed.

    Returns
    -------
    list of [int, int]
        Coordinates built with ``point( index , position , axis )``.

    Raises
    ------
    ValueError
        If an argument does not have an accepted type.
    """
    if not isinstance( list_ , ( list , np.ndarray ) ):
        raise ValueError( 'list_ must be a list, ' + type( list_ ).__name__ + ' given' )
    if not isinstance( index , int ):
        raise ValueError( 'index must be an integer, ' + type( index ).__name__ + ' given' )
    if not isinstance( axis , str ):
        raise ValueError( 'axis must be a string, ' + type( axis ).__name__ + ' given' )
    if axis not in [ 'x' , 'y' ]:
        raise ValueError( 'axis must be "x" or "y", ' + axis + ' given' )
    if not isinstance( threshold , float ):
        raise ValueError( 'threshold must be a float, ' + type( threshold ).__name__ + ' given' )

    values = np.asarray( list_ )
    lowest = np.min( values )
    ampl   = np.max( values ) - lowest

    if ampl < np.mean( values ) / 2:
        return [ point( index , 0 , axis ) ]
    if ampl == 0:
        # Constant profile with a non-positive mean: rescaling would be 0/0
        # (all NaN) and nothing could exceed the threshold.
        return []

    # Out-of-place rescaling: an in-place ``/=`` fails on integer arrays.
    normalized = ( values - lowest ) / ampl
    min_gap    = 0.01 * len( normalized )  # low sensibility

    points , inside , below = [] , False , 0
    for i , level in enumerate( normalized ):
        if level > threshold and not inside:
            points.append( point( index , i , axis ) )
            inside , below = True , 0
        elif level < threshold and inside:
            below += 1  # cumulative: not reset when the run climbs back up
            if below > min_gap:
                inside = False
    return points
def consecutive( list_ ):
    """
    Split a sorted list of integers into runs of consecutive values.

    e.g. ``[1, 2, 3, 7, 8, 10]`` -> ``[[1, 2, 3], [7, 8], [10]]``. Each run is a
    slice of ``list_`` (so NumPy input gives NumPy slices). The end of each run
    is located with ``last_consecutive``. An empty input is returned unchanged.

    Implemented iteratively so long inputs with many runs do not hit Python's
    recursion limit.

    Raises
    ------
    ValueError
        If ``list_`` is neither a list nor a numpy array.
    """
    if not isinstance( list_ , ( list , np.ndarray ) ):
        raise ValueError( 'list_ must be a list, ' + type( list_ ).__name__ + ' given' )
    if len( list_ ) == 0:
        return list_
    parts , rest = [] , list_
    while True:
        index = last_consecutive( rest )
        if index == len( rest ) - 1:  # the remainder is a single run
            parts.append( rest )
            return parts
        parts.append( rest[ : index + 1 ] )
        rest = rest[ index + 1 : ]
def last_consecutive( list_ ):
    """
    Return the index of the last element of the leading run of consecutive
    integers in ``list_``, found by binary search.

    ``list_`` is expected to be sorted and strictly increasing, so that
    ``list_[ i ] - list_[ 0 ] == i`` holds exactly on the leading run.
    e.g. ``[1, 2, 3, 7, 8]`` -> ``2``.

    Raises
    ------
    ValueError
        If ``list_`` is neither a list nor a numpy array.
    IndexError
        If ``list_`` is empty.
    """
    if not isinstance( list_ , ( list , np.ndarray ) ):
        raise ValueError( 'list_ must be a list, ' + type( list_ ).__name__ + ' given' )
    first , lower , greater = list_[0] , 0 , len( list_ )
    # Invariant: index ``lower`` is inside the leading run; ``greater`` is
    # either len( list_ ) or an index known to be outside it.
    while greater - lower != 0:
        i = lower + ( greater - lower ) // 2
        if list_[ i ] - first != i:  # outside of the consecutive run
            greater = i
        elif i == len( list_ ) - 1:  # inside the run and last element: all consecutive
            break
        elif list_[ i ] + 1 != list_[ i + 1 ]:  # next one leaves the run => limit found
            break
        else:
            lower = i
    return i
def same_value( list_ ):
    """
    Split a sorted list into runs of equal values.

    e.g. ``[1, 1, 2, 3, 3, 3]`` -> ``[array([1, 1]), array([2]), array([3, 3, 3])]``.
    Python lists are converted with ``np.asarray`` so the element-wise
    comparison works. An empty input is returned unchanged.

    Raises
    ------
    ValueError
        If ``list_`` is neither a list nor a numpy array.
    """
    if not isinstance( list_ , ( list , np.ndarray ) ):
        raise ValueError( 'list_ must be a list, ' + type( list_ ).__name__ + ' given' )
    if len( list_ ) == 0:
        return list_
    values     = np.asarray( list_ )
    boundaries = np.flatnonzero( values[ 1 : ] != values[ : - 1 ] ) + 1
    return np.split( values , boundaries )
def last_same_value( list_ ):
    """
    Return the last index holding the same value as ``list_[0]``.

    For a sorted input this is the end of the first run of equal values.
    Python lists are converted with ``np.asarray`` so the comparison is
    element-wise.

    Raises
    ------
    ValueError
        If ``list_`` is neither a list nor a numpy array.
    IndexError
        If ``list_`` is empty.
    """
    if not isinstance( list_ , ( list , np.ndarray ) ):
        raise ValueError( 'list_ must be a list, ' + type( list_ ).__name__ + ' given' )
    values = np.asarray( list_ )
    value  = values[0]
    return np.argwhere( values == value ).max()
def _retrieve_peaks_window( spectral_energy , window_size , error_thr ):
    """
    Smooth ``spectral_energy`` with a ``window_size`` moving average and return
    the mean index of every consecutive run where the smoothed value divided by
    its squared mean exceeds ``error_thr``.
    """
    average_window = np.convolve(
        spectral_energy        ,
        np.ones( window_size ) ,
        'same'                 ,
    ) / window_size
    average_energy = np.mean( average_window )
    above          = np.where(
        average_window / average_energy ** 2 > error_thr
    )[0]
    return [ np.mean( run ) for run in consecutive( above ) ]

def retrieve_peaks( data , window_size = 5 , error_coef = 1.05 , max_window_size = 30 , min_successive = 2 ):
    """
    Get peak positions from 1D data.

    The log energy ``log( data ** 2 )`` is smoothed with a moving average. The
    samples whose smoothed value, divided by the squared mean of the smoothed
    signal, exceeds ``error_coef / median( log energy )`` are grouped into runs
    of consecutive indexes, and each run is reduced to its mean index.

    The window is then widened by one sample per iteration. This continues
    until ``min_successive`` consecutive evaluations yield the same number of
    peaks as the kept result, or ``window_size`` reaches ``max_window_size``.

    Returns
    -------
    list of float
        Mean index of each detected peak.
    """
    spectral_energy = np.log( np.asarray( data ) ** 2 )
    error_thr       = error_coef / np.median( spectral_energy )

    peaks      = _retrieve_peaks_window( spectral_energy , window_size , error_thr )
    successive = 0

    # NOTE(review): the first iteration re-evaluates the initial window_size,
    # so it always counts as one agreement — confirm this is intended.
    while successive < min_successive and window_size < max_window_size:
        new_peaks = _retrieve_peaks_window( spectral_energy , window_size , error_thr )
        # Only the number of peaks is compared; the returned positions come
        # from the window at which that number last changed.
        if len( peaks ) == len( new_peaks ):
            successive += 1
        else:
            successive = 0
            peaks      = new_peaks
        window_size += 1
    return peaks
def near_value( list_ , value ):
    """
    Return the (rounded) indexes where ``list_`` crosses ``value``.

    Every sign change of ``list_ - value`` between samples ``k`` and ``k + 1`` is
    located by linear interpolation. Indexes where ``list_`` equals ``value``
    exactly are added as well. The positions are sorted and rounded to ints.

    Note: an exact hit is reported more than once, because the neighbouring
    sign changes interpolate onto it too. For example, ``[0, 1, 2]`` with
    ``value=1`` gives ``[1, 1, 1]``.

    Parameters
    ----------
    list_ : list or np.ndarray
        1D sequence of numbers.
    value : number
        Level to look for.

    Returns
    -------
    np.ndarray of int
    """
    values = np.asarray( list_ )
    change = np.flatnonzero( np.diff( np.sign( values - value ) ) != 0 )  # sign change
    index  = change + (
        value - values[ change ]
    ) / (
        values[ change + 1 ] - values[ change ]
    )  # interpolation
    index = np.append( index , np.flatnonzero( values == value ) )
    return np.round( np.sort( index ) ).astype( int )  # triage

def cut_biggest( list_ ):
    """
    Return index of start and end of the biggest peak in a list (intended).

    Procedure as implemented:
    1. ``near_value`` at ``np.max( list_ )``. More than two indexes plots
       ``list_`` with matplotlib and raises ``Exception( 'too much peak' )``.
    2. While fewer than two indexes are found, ``factor`` is incremented. This
       lowers the level ``min + ( max - mean ) / factor`` and ``near_value``
       is re-run.
    3. ``factor`` is decremented and ``near_value`` is re-run at that level.
       Exactly two indexes raises ``Exception( 'less than two pixel peak' )``;
       otherwise the indexes from ``near_value`` are returned.

    NOTE(review): ``near_value`` reports an exact hit more than once. An
    isolated interior maximum therefore already yields three indexes in step 1
    and triggers 'too much peak' — confirm this is intended.
    """
    factor = 1
    indexes = near_value(
        list_          ,
        np.max( list_ ),
    )
    if len( indexes ) > 2:
        # NOTE(review): presumably a debugging leftover — plt.show() may block
        # and requires matplotlib at call time.
        import matplotlib.pyplot as plt
        plt.plot( list_ )
        plt.show()
        raise Exception( 'too much peak' )
    # NOTE(review): nothing bounds this loop; it relies on some level
    # eventually producing at least two indexes.
    while len( indexes ) < 2:
        factor += 1
        indexes = near_value(
            list_     ,
            np.min( list_ ) + (
                np.max( list_ ) - np.mean( list_ )
            ) / factor,
        )
    # NOTE(review): if the loop above never ran (exactly two indexes at the
    # maximum), factor becomes 0 here and the level below divides by zero.
    factor -= 1
    indexes = near_value(
        list_     ,
        np.min( list_ ) + (
            np.max( list_ ) - np.mean( list_ )
        ) / factor,
    )
    if len( indexes ) == 2:
        raise Exception( 'less than two pixel peak' )

    return indexes