Add intensity step correction

- Finished curve correction
- Add intensity correction for each x
This commit is contained in:
linarphy 2023-05-15 16:43:59 +02:00
parent 2840f86646
commit 15325e0e8f
No known key found for this signature in database
GPG key ID: 3D4AAAC3AD16E79C

84
ETA.py
View file

@ -1,6 +1,9 @@
import numpy as np
import matplotlib.pyplot as plt
import utils
import sys
import pathlib
import shelve
from scipy.signal import convolve as sp_convolve
from scipy.signal import find_peaks
from scipy.ndimage import rotate
@ -9,6 +12,14 @@ if len( sys.argv ) < 2:
raise Exception( 'this command must have a filename of an ETA fits as an argument' )
data = utils.load( sys.argv[1] )
cache_file = pathlib.Path( 'asset/points_' + sys.argv[1].split( '/' )[-1][:-5] + '.pag' )
if cache_file.is_file():
with shelve.open( str( cache_file ) ) as cache:
data = cache[ 'rotated_data' ]
border = cache[ 'border' ]
calibrations = cache[ 'calibrations' ]
else:
"""
find fill points
"""
@ -224,7 +235,7 @@ def indicator( list_ ):
return 4
return 10
indicators = np.array( [ indicator( data[ i , border[ 'x' ][ 'min' ] : border[ 'x' ][ 'max' ] ] ) for i in range( border[ 'y' ][ 'min' ] , border[ 'y' ][ 'max' ] , 1 ) ] )
indicators = np.array( [ indicator( data[ i , border[ 'x' ][ 'min' ] : border[ 'x' ][ 'max' ] ].copy() ) for i in range( border[ 'y' ][ 'min' ] , border[ 'y' ][ 'max' ] , 1 ) ] )
calibration_areas = utils.consecutive( np.where( indicators == 10 )[0] )
calibration_sizes = [ len( calibration_area ) for calibration_area in calibration_areas ]
@ -253,36 +264,73 @@ calibrations = {
},
}
with shelve.open( str( cache_file ) ) as cache:
cache[ 'rotated_data' ] = data
cache[ 'border' ] = border
cache[ 'calibrations'] = calibrations
"""
stripes curves detection
"""
list_ = data[
calibrations[ 'top' ][ 'y' ][ 'max' ] : calibrations[ 'down' ][ 'y' ][ 'min' ],
border[ 'x' ][ 'min' ] : border[ 'x' ][ 'max' ]
].copy()
list_ -= np.min( list_ , axis = 0 )
list_ /= np.max( list_ , axis = 0 )
size = list_.shape[1]
size = border[ 'x' ][ 'max' ] - border[ 'x' ][ 'min' ]
x_stripe = np.arange( border[ 'x' ][ 'min' ] + 1 * size // 4 , border[ 'x' ][ 'min' ] + 3 * size // 4 , 1 ).astype( int )
y_stripe = np.array( [
np.where(
list_[ : , x ] > 0.8
)[0][0] for x in x_stripe
] )
y_stripe = np.argmax( list_ , axis = 0 )
good_x = np.where( y_stripe < 2 * np.mean( y_stripe ) )[0]
x_stripe = np.arange( 0 , size , 1 ).astype( int )[ good_x ]
y_stripe = y_stripe[ good_x ]
stripes = [ # list of polyval result for each stripe
np.polyfit( x_stripe , y_stripe , 2 )
np.polyfit( x_stripe , y_stripe , 3 )
]
# First deformation
y_diff = np.polyval( stripes[0] , np.arange( 0 , size , 1 ) ).astype( int )
results = np.zeros( ( list_.shape[1] , list_.shape[0] - np.max( y_diff ) ) )
for i in range( list_.shape[1] ):
results[i] = list_[ y_diff[ i ] : list_.shape[0] + y_diff[ i ] - np.max( y_diff ) , i ]
results = results.transpose()
import matplotlib.pyplot as plt
plt.imshow( results )
plt.savefig( 'asset/deformation.png' )
y_diff = ( np.polyval( stripes[0] , np.arange( 0 , size , 1 ) ) ).astype( int )
y_diff[ np.where( y_diff < 0 ) ] = 0
results = np.zeros( ( list_.shape[0] + np.max( y_diff ) , list_.shape[1] ) )
for i in range( list_.shape[1] ):
results[ : , i ] = np.concatenate( ( np.zeros( np.max( y_diff ) - y_diff[ i ] ) , list_[ : , i ] , np.zeros( y_diff[i] ) ) )
list_results = np.convolve(
np.gradient(
np.mean( results , axis = 1 ),
) ,
np.ones( 50 ),
'same' ,
)
fall = utils.consecutive( np.where( list_results < - 0.02 )[0] )
fall = np.array( [
np.argmax( list_results )
] + [
consecutive[0] + np.argmin(
list_results[ consecutive[0] : consecutive[-1] ]
) for consecutive in fall
] ).astype( int )
"""
plt.imshow( results , aspect = 'auto' )
plt.hlines( fall , 0 , size )
plt.show()
"""
temp = np.convolve( results[ : , 10000 ] , np.ones( 50 ) , 'same' )
for i in range( len( fall ) - 1 ):
temp[ fall[ i ] : fall[ i + 1 ] ] = np.mean( temp[ fall[ i ] : fall[ i + 1 ] ] )
plt.plot( temp )
plt.plot(
np.convolve(
results[ : , 10000 ],
np.ones( 50 ) ,
'same' ,
),
)
plt.vlines( fall , 0 , 50 , colors = 'red' )
plt.show()