Update label detection method

- Remove old label detection code and dependancy to scipy
- Add quick label detection method without dependancies
This commit is contained in:
linarphy 2023-05-09 12:43:34 +02:00
parent ba1a591669
commit a240bf83e8
No known key found for this signature in database
GPG key ID: 3D4AAAC3AD16E79C

131
ETA.py
View file

@ -1,9 +1,10 @@
import numpy as np import numpy as np
from scipy.optimize import curve_fit
import utils import utils
import sys import sys
data = utils.load( sys.argv[1] ) if len( sys.argv ) < 2:
raise Exception( 'this command must have a filename of an ETA fits as an argument' )
data = utils.load( sys.argv[1] )
""" """
find fill points find fill points
@ -55,8 +56,25 @@ fill data
extremum = [] extremum = []
for point in points: for point in points:
if point[0] < points[2][0]:
point[0] = points[2][0]
if point[1] < points[0][1]:
point[1] = points[0][1]
taken_points = utils.small_to_big( taken_points = utils.small_to_big(
utils.fill( small_data , point , 1000 ), np.array( [
points[2][0],
points[0][1],
] ) + utils.fill(
small_data[
points[2][0] : points[3][0] + 1,
points[0][1] : points[1][1] + 1
],
[
point[0] - points[2][0],
point[1] - points[0][1],
],
1000
),
5 5
) )
extremum.append( [ extremum.append( [
@ -68,12 +86,12 @@ for point in points:
border = { border = {
'x': { 'x': {
'min': extremum[0][1] + 1, 'min': points[0][1] + extremum[0][1] + 1,
'max': extremum[1][0] , 'max': points[0][1] + extremum[1][0] ,
}, },
'y': { 'y': {
'min': extremum[2][3] + 1, 'min': points[2][0] + extremum[2][3] + 1,
'max': extremum[3][2] , 'max': points[2][0] + extremum[3][2] ,
}, },
} }
@ -81,72 +99,47 @@ border = {
label deletion label deletion
""" """
mean_data = np.mean( data[ mean_data = np.convolve(
border['y']['min'] : border['y']['max'], np.gradient(
border['x']['min'] : border['x']['max'] np.mean(
] , axis = 0 ) data[
border[ 'y' ][ 'min' ] : border[ 'y' ][ 'max' ],
gauss = lambda x , sigma , mu , a , b : a * ( border[ 'x' ][ 'min' ] : border[ 'x' ][ 'max' ]
1 / sigma * np.sqrt( ],
2 * np.pi axis = 0
) * np.exp(
- ( x - mu ) ** 2 / ( 2 * sigma ** 2 )
)
) + b
abciss = np.arange(
border['x']['min'],
border['x']['max'],
1
)
guess_params = [
1 ,
border['x']['min'] + ( border['x']['max'] - border['x']['min'] ) // 2,
np.max( mean_data ) ,
np.min( mean_data ) ,
]
first_estimate = curve_fit(
gauss ,
abciss ,
mean_data ,
guess_params
)[0]
part_data = [
mean_data[ : mean_data.shape[0] // 2 ],
mean_data[ mean_data.shape[0] // 2 : ]
]
part_abciss = [
abciss[ : abciss.shape[0] // 2 ],
abciss[ abciss.shape[0] // 2 : ]
]
part_result = []
for i in range( 2 ):
part_result.append(
curve_fit(
gauss ,
part_abciss[i],
part_data[i] ,
first_estimate
) )
) ),
np.ones(
cov = np.array( [ int( 0.01 * (
np.sum( np.diag( part_result[i][1] ) ) for i in range( 2 ) border[ 'x' ][ 'max' ] - border[ 'x' ][ 'min' ]
] ) ) )
i = np.argmax( cov ) # part where the label is ),
'same'
derivee = np.convolve(
np.gradient( part_data[i] ),
np.ones( part_data[i].shape[0] // 100 ),
'same',
) )
start_label = np.argmax( derivee )
end_label = np.argmin( derivee[ start_label :: ( - 1 ) ** i ] )
keys = [ 'min' , 'max' ] mean_data -= np.min( mean_data )
mean_data /= np.max( mean_data )
border['x'][keys[i]] += ( - 1 ) ** i * ( start_label + end_label ) top = utils.consecutive( np.where( mean_data > 0.75 )[0] )
down = utils.consecutive( np.where( mean_data < 0.25 )[0] )
size_top = [ len( list_ ) for list_ in top ]
size_down = [ len( list_ ) for list_ in down ]
label_x = {
'min': border[ 'x' ][ 'min' ] + top[ np.argmax( size_top ) ][0] ,
'max': border[ 'x' ][ 'min' ] + down[ np.argmax( size_down ) ][-1]
}
if label_x[ 'min' ] < data.shape[1] // 2:
if label_x[ 'max' ] < data.shape[1] // 2:
border[ 'x' ][ 'min' ] = label_x[ 'max' ]
else:
raise Exception( 'the label seems to be in the middle of the picture' )
elif label_x[ 'max' ] > data.shape[1] // 2:
border[ 'x' ][ 'max' ] = label_x[ 'min' ]
else:
raise Exception( 'for an unkown reason, label_x[ \'min\' ] > label_x[ \'max\' ]' )
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
plt.imshow( data[ plt.imshow( data[