Fix fitting & typing

This commit is contained in:
linarphy 2024-05-24 16:22:18 +02:00
parent ec5e67536e
commit 59a01b77ed
No known key found for this signature in database
GPG key ID: E61920135EFF2295
5 changed files with 114 additions and 28 deletions

View file

@ -42,3 +42,7 @@ docstring-code-format = true
[tool.basedpyright] [tool.basedpyright]
typeCheckingMode = "all" typeCheckingMode = "all"
reportUnknown = false
reportUnknownMemberType = false
reportUnknownArgumentType = false
reportUnknownVariableType = false

View file

@ -98,4 +98,39 @@ def opt_compute_light(
) )
def fit_compute_light(
scatter_factor: float,
factor_n: Signal,
coupling_n: Signal,
factor_d: Signal,
coupling_d: Signal,
power_in: float,
power_out: float,
data: Signal,
reference: Signal,
) -> float:
"""
Scalar function used to find the right scattering factor
"""
return sum(
abs(
Signal(
factor_n.x,
opt_compute_light(
scatter_factor=scatter_factor,
factor_n=factor_n,
coupling_n=coupling_n,
factor_d=factor_d,
coupling_d=coupling_d,
power_in=power_in,
power_out=power_out,
),
factor_n.settings,
)
+ reference
- data
).y
)
from backscattering_analyzer.signal import Signal # no circular import from backscattering_analyzer.signal import Signal # no circular import

View file

@ -1,14 +1,15 @@
# utils
from sys import argv from sys import argv
from backscattering_analyzer.settings import Settings from backscattering_analyzer.settings import Settings
from backscattering_analyzer.signal import Signal from backscattering_analyzer.signal import Signal
from backscattering_analyzer import ( from backscattering_analyzer import (
compute_light, compute_light,
fit_compute_light,
opt_compute_light, opt_compute_light,
interpolate, interpolate,
) )
from numpy import loadtxt, logspace, where, zeros, argmin, intp, pi from numpy import loadtxt, pi
from scipy.io.matlab import loadmat from scipy.io.matlab import loadmat
from scipy.optimize import Bounds, minimize
class Analyzer: class Analyzer:
@ -149,24 +150,22 @@ class Analyzer:
power_out=self.settings.power_out, power_out=self.settings.power_out,
) )
def fit_scatter_factor( def fit_scatter_factor(self, guess: None | float = None) -> float:
self, start: int, stop: int, number: int
) -> tuple[intp, float]:
""" """
Find the best scatter factor (first order only) in the given Find the best scatter factor (first order only) in the given
range range
""" """
import matplotlib.pyplot as plt if guess is None:
factors = logspace(start, stop, number) guess = self.settings.scattering_factor[0]
sums = zeros(number)
phase = 4 * pi / self.settings.wavelength phase = 4 * pi / self.settings.wavelength
factor_n = (self.movement * phase).sin().psd().sqrt() factor_n = (self.movement * phase).sin().psd().sqrt()
coupling_n = self.coupling[0].abs() coupling_n = self.coupling[0].abs()
factor_d = (self.movement * phase).cos().psd().sqrt() factor_d = (self.movement * phase).cos().psd().sqrt()
coupling_d = self.coupling[1].abs() coupling_d = self.coupling[1].abs()
coupling_d.cut(5, 40) # cut signal between 5 and 40 Hz coupling_d = coupling_d.cut_x(
10, 200
) # cut signal between 10 and 200 Hz
factor_n, coupling_n, factor_d, coupling_d, data, reference = ( factor_n, coupling_n, factor_d, coupling_d, data, reference = (
interpolate( interpolate(
@ -181,25 +180,56 @@ class Analyzer:
) )
) )
reference = reference.y bounds = Bounds(0, 1)
data = data.y min_result = minimize(
fit_compute_light,
guess,
(
factor_n,
coupling_n,
factor_d,
coupling_d,
self.settings.power_in,
self.settings.power_out,
data,
reference,
),
method = 'TNC',
bounds=bounds,
)
for index in range(number): if not min_result.success:
self.settings.log("{}".format(index)) raise Exception(min_result.message)
projection = opt_compute_light(
scatter_factor=factors[index], self.settings.log(
"found the best scattering factor in {} iterations".format(
min_result.nit
)
)
import matplotlib.pyplot as plt
projection = Signal(
factor_n.x,
opt_compute_light(
scatter_factor=min_result.x,
factor_n=factor_n, factor_n=factor_n,
coupling_n=coupling_n, coupling_n=coupling_n,
factor_d=factor_d, factor_d=factor_d,
coupling_d=coupling_d, coupling_d=coupling_d,
power_in=self.settings.power_in, power_in=self.settings.power_in,
power_out=self.settings.power_out, power_out=self.settings.power_out,
) ),
diff = abs(projection + reference - data) self.settings,
_ = plt.loglog(projection + reference) )
_ = plt.loglog(data)
_ = plt.show()
sums[index] = sum(diff)
min_index = argmin(sums)
return min_index, factors[min_index] _ = plt.loglog(projection.x, projection.y, label="projection")
_ = plt.loglog(reference.x, reference.y, label="référence")
_ = plt.loglog(data.x, data.y, label="data")
_ = plt.loglog(
reference.x, projection.y + reference.y, label="somme"
)
_ = plt.legend()
_ = plt.show()
return min_result.x

View file

@ -14,6 +14,7 @@ class Settings:
"option": "grey50 italic", "option": "grey50 italic",
"argument": "red", "argument": "red",
"description": "green italic", "description": "green italic",
"warning": "bold red",
} }
) )
self.console = Console(theme=self.theme) self.console = Console(theme=self.theme)

View file

@ -6,7 +6,7 @@ from backscattering_analyzer import interpolate
from scipy.signal import welch, detrend from scipy.signal import welch, detrend
from scipy.fft import irfft, rfft, rfftfreq from scipy.fft import irfft, rfft, rfftfreq
from scipy.interpolate import CubicSpline from scipy.interpolate import CubicSpline
from numpy import where, sin, cos, array, sqrt, float64 from numpy import logical_and, where, sin, cos, array, sqrt, float64
class Signal: class Signal:
@ -21,7 +21,11 @@ class Signal:
settings: Settings, settings: Settings,
) -> None: ) -> None:
if x.shape != value.shape: if x.shape != value.shape:
raise Exception("x and y does not have the same dimension") raise Exception(
"x and y does not have the same dimension ({x} and {y})".format(
x=x.shape, y=value.shape
)
)
self.sampling = x[1] - x[0] self.sampling = x[1] - x[0]
self.rate = 1 / self.sampling self.rate = 1 / self.sampling
self.x = x self.x = x
@ -66,6 +70,17 @@ class Signal:
self.settings, self.settings,
) )
def cut_x(self, start: float, end: float) -> Signal:
"""
Cut signal from a start value to an end value
"""
indexes = where(logical_and(self.x > start, self.x < end))
return Signal(
self.x[indexes],
self.y[indexes],
self.settings,
)
def low_pass_filter(self, cutoff: float) -> Signal: def low_pass_filter(self, cutoff: float) -> Signal:
""" """
Cut higher frequencies than cutoff for this signal Cut higher frequencies than cutoff for this signal
@ -76,9 +91,10 @@ class Signal:
freq_y = rfft(self.y) freq_y = rfft(self.y)
index_to_remove = where(abs(freq_x) > cutoff) index_to_remove = where(abs(freq_x) > cutoff)
freq_y[index_to_remove] = 0 freq_y[index_to_remove] = 0
y = irfft(freq_y)
signal = Signal( signal = Signal(
self.x, self.x[: len(y)],
irfft(freq_y), y[: len(self.x)],
self.settings, self.settings,
) )
return signal return signal