import numpy as np
import utils
import sys
from scipy.signal import convolve as sp_convolve
from scipy.signal import find_peaks
from scipy.ndimage import rotate

# The ETA FITS file to process is given as the first command-line argument.
if not sys.argv[ 1 : ]:
    raise Exception( 'this command must have a filename of an ETA fits as an argument' )
data = utils.load( sys.argv[ 1 ] )

"""
find fill points
"""
# Collect four reference points on the image, in this order:
#   points[0] x_min (from the first column), points[1] x_max (from the last
#   column), points[2] y_min and points[3] y_max (each from a single row).
# NOTE(review): utils.find_point returns a list of points (it is concatenated
# with `+=`, its length is compared to 3 and it is indexed with [1]); what each
# returned point means is defined in utils — confirm there.
points = []

points += utils.find_point( data[ : , 0 ] , 0 ) # x_min
points += utils.find_point(
        data[ : , data.shape[1] - 1 ],
        data.shape[1] - 1            ,
) # x_max

# Scan down from row 1 (the index is incremented before the first test) for
# the first row where find_point reports exactly 3 points.
# NOTE(review): if no such row exists the scan stops on the last row and the
# lookup below still runs on it. find_point is also evaluated twice for the
# selected row (once in the loop, once in the append); the result could be kept.
index_min = 0
while data.shape[0] - 1 > index_min:
    index_min += 1
    if len( utils.find_point( data[ index_min , : ] , index_min , 'y' ) ) == 3:
        break

points.append(
    utils.find_point( data[ index_min , : ] , index_min , 'y' )[1]
) # y_min

# Same search upward, starting from the second-to-last row and never going
# above the row found for y_min.
index_max = data.shape[0] - 1
while index_min < index_max:
    index_max -= 1
    if len( utils.find_point( data[ index_max , : ] , index_max , 'y' ) ) == 3:
        break


points.append(
    utils.find_point( data[ index_max , : ] , index_max , 'y' )[1]
) # y_max

# NOTE(review): presumably a downsampled copy of the image by a factor of 5 —
# confirm in utils.compress.
small_data = utils.compress( data , 5 )

# Presumably converts the four points to the compressed (factor 5)
# coordinates — confirm in utils.big_to_small.
points = utils.big_to_small( points , 5 )

# size - 1: pull the x_max column and the y_max row in by one (presumably to
# stay inside the compressed array bounds)
points[ 1 ][ 1 ] -= 1
points[ 3 ][ 0 ] -= 1

# little shift to be inside the light: move the y_min and y_max points 3
# columns to the right (compressed coordinates)
points[ 2 ][ 1 ] += 3
points[ 3 ][ 1 ] += 3

"""
fill data
"""

# For each reference point, run utils.fill inside the compressed sub-image
# bounded by the four points (rows points[2][0]..points[3][0], columns
# points[0][1]..points[1][1]) and record the extent of the returned region.
# NOTE(review): `point` aliases the entries of `points`, so the clamping below
# modifies `points` in place — including points[0] and points[2], which are
# themselves the clamp bounds.
extremum = []
for point in points:
    # Clamp the seed to at least the sub-image origin (row points[2][0],
    # column points[0][1]).
    if point[0] < points[2][0]:
        point[0] = points[2][0]
    if point[1] < points[0][1]:
        point[1] = points[0][1]
    # utils.fill receives the sub-image, the seed relative to the sub-image
    # origin and 1000 (presumably a threshold — confirm in utils). The origin
    # is added back to the result, then utils.small_to_big presumably rescales
    # by 5 to full-resolution coordinates.
    taken_points = utils.small_to_big(
        np.array( [
            points[2][0],
            points[0][1],
        ] ) + utils.fill(
            small_data[
                points[2][0] : points[3][0] + 1,
                points[0][1] : points[1][1] + 1
            ],
            [
                point[0] - points[2][0],
                point[1] - points[0][1],
            ],
            1000
        ),
        5
    )
    # [min, max] of coordinate 1 then [min, max] of coordinate 0 of the
    # returned points — presumably [x_min, x_max, y_min, y_max].
    extremum.append( [
        np.min( taken_points[ : , 1 ] ),
        np.max( taken_points[ : , 1 ] ),
        np.min( taken_points[ : , 0 ] ),
        np.max( taken_points[ : , 0 ] ),
    ] )

# Useful area of the image: it starts just past the region obtained from the
# x_min (resp. y_min) point and stops at the region obtained from the x_max
# (resp. y_max) point.
# NOTE(review): extremum already includes the sub-image origin (added before
# small_to_big), while points[0][1] / points[2][0] are compressed coordinates
# added again here without rescaling — verify the intended coordinate system.
border = {
    'x': {
        'min': points[0][1] + extremum[0][1] + 1,
        'max': points[0][1] + extremum[1][0]    ,
    },
    'y': {
        'min': points[2][0] + extremum[2][3] + 1,
        'max': points[2][0] + extremum[3][2]    ,
    },
}

"""
label deletion
"""

# Column profile of the useful area: mean of each column, differentiated along
# x, then smoothed with a box filter about 1 % of the useful width wide. The
# filter keeps at least one tap: below 100 columns the 1 % width truncates to 0
# and np.convolve rejects an empty kernel.
mean_data = np.convolve(
    np.gradient(
        np.mean(
            data[
                border[ 'y' ][ 'min' ] : border[ 'y' ][ 'max' ],
                border[ 'x' ][ 'min' ] : border[ 'x' ][ 'max' ]
            ],
            axis = 0
        )
    ),
    np.ones(
        max( 1 , int( 0.01 * (
            border[ 'x' ][ 'max' ] - border[ 'x' ][ 'min' ]
        ) ) )
    ),
    'same'
)

# Rescale the profile to [0, 1] so the fixed 0.25 / 0.75 thresholds apply.
mean_data -= np.min( mean_data )
if np.max( mean_data ) == 0:
    # A flat profile would divide by zero and leave only NaNs to threshold.
    raise Exception( 'the column profile is flat, no label edge can be found' )
mean_data /= np.max( mean_data )

# Runs of columns (relative to border x min) where the smoothed derivative is
# close to its maximum (top) or to its minimum (down).
top  = utils.consecutive( np.where( mean_data > 0.75 )[0] )
down = utils.consecutive( np.where( mean_data < 0.25 )[0] )

size_top  = [ len( list_ ) for list_ in top  ]
size_down = [ len( list_ ) for list_ in down ]

# Label extent in absolute columns: first column of the longest 'top' run to
# the last column of the longest 'down' run.
label_x = {
    'min': border[ 'x' ][ 'min' ] + top[ np.argmax( size_top ) ][0]   ,
    'max': border[ 'x' ][ 'min' ] + down[ np.argmax( size_down ) ][-1]
}

# Cut the label off the useful area: a label entirely left of the image centre
# moves the left border to its right edge; a label right of the centre moves
# the right border to its left edge.
if label_x[ 'min' ] < data.shape[1] // 2:
    if label_x[ 'max' ] < data.shape[1] // 2:
        border[ 'x' ][ 'min' ] = label_x[ 'max' ]
    else:
        raise Exception( 'the label seems to be in the middle of the picture' )
elif label_x[ 'max' ] > data.shape[1] // 2:
    border[ 'x' ][ 'max' ] = label_x[ 'min' ]
else:
    raise Exception( 'for an unkown reason, label_x[ \'min\' ] > label_x[ \'max\' ]' )

"""
Rotation
"""

# Rows used to estimate the tilt: the upper half of the useful area.
y_start = border[ 'y' ][ 'min' ]
y_stop  = border[ 'y' ][ 'min' ] + ( border[ 'y' ][ 'max' ] - border[ 'y' ][ 'min' ] ) // 2

# Walk right from the left border while the column still shows a large
# vertical intensity swing (range of its gradient above 5500). Stop with an
# explicit error instead of indexing past the last column.
index = border[ 'x' ][ 'min' ]
while True:
    if index >= data.shape[1]:
        raise Exception( 'every column up to the image edge has a strong vertical gradient' )
    gradient = np.gradient( data[ y_start : y_stop , index ] )
    if np.max( gradient ) - np.min( gradient ) <= 5500:
        break
    index += 1

# The columns walked over are the ones the tilt is measured on; with none of
# them the argmax / polyfit below would fail on empty input.
if index == border[ 'x' ][ 'min' ]:
    raise Exception( 'no column with a strong vertical gradient was found next to the left border' )

# For each of those columns, row offset (from y_start) of the strongest
# vertical edge, after smoothing the vertical gradient with a 100-row box
# filter. 'valid' mode adds the same constant shift to every offset, which does
# not affect the fitted slope.
positions = np.argmax(
    sp_convolve(
        np.gradient( data[ y_start : y_stop , border[ 'x' ][ 'min' ] : index ] , axis = 0 ),
        np.ones( ( 100 , 1 ) ),
        'valid'
    ),
    axis = 0
)

# Fit the edge row as a linear function of the column offset; the slope is the
# tangent of the tilt. `polyval` holds the polyfit coefficients (slope first).
list_   = np.arange( index - border[ 'x' ][ 'min' ] )
polyval = np.polyfit( list_ , positions , 1 )

angle = np.arctan( polyval[0] )

# NOTE(review): scipy.ndimage.rotate defaults to reshape=True, which enlarges
# the output so the whole rotated input fits and therefore shifts the image
# content; `border` is only corrected below by the vertical drift across the
# useful width. Confirm the intended coordinate correction.
data = rotate( data , angle * ( 180 / np.pi ) ) # utils.rotate does not keep intensity absolute value ? TODO

# Vertical drift accumulated across the useful width at this tilt.
diff_y = int( np.tan( angle ) * ( border[ 'x' ][ 'max' ] - border[ 'x' ][ 'min' ] ) )

border[ 'y' ][ 'min' ] -= diff_y
border[ 'y' ][ 'max' ] -= diff_y

"""
Calibration
"""

# Reference intensity used by `indicator`: mean over the whole useful area of
# the (rotated) image.
tot_avg = data[
    border[ 'y' ][ 'min' ] : border[ 'y' ][ 'max' ],
    border[ 'x' ][ 'min' ] : border[ 'x' ][ 'max' ]
].mean()
def indicator( list_ , reference = None ):
    """
    Classify one image row for calibration-band detection.

    The input is never modified: the caller passes views into ``data``, so
    the normalization is done on a float copy. Normalizing in place would
    overwrite the image, and it would fail for integer dtypes.

    Parameters:
        list_     : 1-D array of pixel intensities (one image row).
        reference : reference intensity; defaults to the module-level
                    ``tot_avg`` (mean of the useful area).

    Returns:
        0  : row mean above 0.75 * reference (bright row)
        1  : row mean below 0.25 * reference (dark row)
        2  : fewer than 10 samples above half of the row range (incl. flat row)
        3  : more than 400 samples above half of the row range
        4  : mean spacing between those samples below 10
        10 : otherwise (candidate calibration row)
    """
    if reference is None:
        reference = tot_avg

    mean = np.mean( list_ )
    if mean > 0.75 * reference:
        return 0
    if mean < 0.25 * reference:
        return 1

    # Float copy shifted to start at 0 (also avoids integer overflow).
    shifted = np.asarray( list_ , dtype = float ) - np.min( list_ )
    peak = np.max( shifted )
    if peak == 0:
        # Constant row: nothing rises above half of the (empty) range.
        return 2

    positions = np.where( shifted / peak > 0.5 )[0]
    if len( positions ) < 10:
        return 2
    if len( positions ) > 400:
        return 3
    distance = np.mean( np.diff( positions ) )
    if distance < 10:
        return 4
    return 10

# Classify every row of the useful area; indices in `indicators` are relative
# to border y min.
indicators = np.array( [
    indicator( data[ row , border[ 'x' ][ 'min' ] : border[ 'x' ][ 'max' ] ] )
    for row in range( border[ 'y' ][ 'min' ] , border[ 'y' ][ 'max' ] )
] )

# Runs of consecutive rows classified as calibration rows (code 10).
calibration_areas = utils.consecutive( np.where( indicators == 10 )[0] )
calibration_sizes = [ len( calibration_area ) for calibration_area in calibration_areas ]

if len( calibration_areas ) < 2:
    raise Exception( 'less than two calibration areas were found' )

# Keep the two longest runs, then order them top to bottom. argsort orders by
# size, so without the second sort 'top' could lie below 'down' and the stripe
# band between them would be empty.
y_calibrations = sorted(
    [ calibration_areas[ i ] for i in np.argsort( calibration_sizes )[ -2 : ] ],
    key = lambda calibration_area: calibration_area[0],
)
calibrations = {
    'top': {
        'x': {
            'min': border['x']['min'],
            'max': border['x']['max'],
        },
        'y': {
            'min': border['y']['min'] + y_calibrations[0][ 0],
            'max': border['y']['min'] + y_calibrations[0][-1],
        },
    },
    'down': {
        'x': {
            'min': border['x']['min'],
            'max': border['x']['max'],
        },
        'y': {
            'min': border['y']['min'] + y_calibrations[1][ 0],
            'max': border['y']['min'] + y_calibrations[1][-1],
        },
    },
}

"""
stripes curves detection
"""

# Band between the two calibration areas, normalized to [0, 1]. astype(float)
# copies the band (so `data` is untouched) and keeps the in-place division
# valid even if `data` has an integer dtype.
list_ = data[
    calibrations[ 'top' ][ 'y' ][ 'max' ] : calibrations[ 'down' ][ 'y' ][ 'min' ],
    border[ 'x' ][ 'min' ] : border[ 'x' ][ 'max' ]
].astype( float )
list_ -= np.min( list_ )
list_ /= np.max( list_ )

# Sample the middle half of the useful width. x_stripe holds absolute image
# columns, while list_ starts at border['x']['min'], hence the offset when
# indexing it.
size  = border[ 'x' ][ 'max' ] - border[ 'x' ][ 'min' ]
x_stripe = np.arange( border[ 'x' ][ 'min' ] + 1 * size / 4 , border[ 'x' ][ 'min' ] + 3 * size / 4 , 1 ).astype( int )
# For each sampled column, first row of the band (relative to the last row of
# the top calibration area) brighter than 0.8 of the band's range.
# NOTE(review): raises IndexError if a column never exceeds 0.8.
y_stripe = np.array( [
    np.where(
        list_[ : , x - border[ 'x' ][ 'min' ] ] > 0.8
    )[0][0] for x in x_stripe
] )

stripes = [ # list of polyfit coefficients (highest degree first) for each stripe
    np.polyfit( x_stripe , y_stripe , 2 )
]

import matplotlib.pyplot as plt
plt.plot( x_stripe , y_stripe )
plt.plot( x_stripe , np.polyval( stripes[0] , x_stripe ) )
plt.savefig( 'asset/stripe.png' )