Fix fitting & typing
This commit is contained in:
parent
ec5e67536e
commit
59a01b77ed
5 changed files with 114 additions and 28 deletions
|
@ -42,3 +42,7 @@ docstring-code-format = true
|
||||||
|
|
||||||
[tool.basedpyright]
|
[tool.basedpyright]
|
||||||
typeCheckingMode = "all"
|
typeCheckingMode = "all"
|
||||||
|
reportUnknown = false
|
||||||
|
reportUnknownMemberType = false
|
||||||
|
reportUnknownArgumentType = false
|
||||||
|
reportUnknownVariableType = false
|
||||||
|
|
|
@ -98,4 +98,39 @@ def opt_compute_light(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fit_compute_light(
|
||||||
|
scatter_factor: float,
|
||||||
|
factor_n: Signal,
|
||||||
|
coupling_n: Signal,
|
||||||
|
factor_d: Signal,
|
||||||
|
coupling_d: Signal,
|
||||||
|
power_in: float,
|
||||||
|
power_out: float,
|
||||||
|
data: Signal,
|
||||||
|
reference: Signal,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Scalar function used to find the right scattering factor
|
||||||
|
"""
|
||||||
|
return sum(
|
||||||
|
abs(
|
||||||
|
Signal(
|
||||||
|
factor_n.x,
|
||||||
|
opt_compute_light(
|
||||||
|
scatter_factor=scatter_factor,
|
||||||
|
factor_n=factor_n,
|
||||||
|
coupling_n=coupling_n,
|
||||||
|
factor_d=factor_d,
|
||||||
|
coupling_d=coupling_d,
|
||||||
|
power_in=power_in,
|
||||||
|
power_out=power_out,
|
||||||
|
),
|
||||||
|
factor_n.settings,
|
||||||
|
)
|
||||||
|
+ reference
|
||||||
|
- data
|
||||||
|
).y
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
from backscattering_analyzer.signal import Signal # no circular import
|
from backscattering_analyzer.signal import Signal # no circular import
|
||||||
|
|
|
@ -1,14 +1,15 @@
|
||||||
# utils
|
|
||||||
from sys import argv
|
from sys import argv
|
||||||
from backscattering_analyzer.settings import Settings
|
from backscattering_analyzer.settings import Settings
|
||||||
from backscattering_analyzer.signal import Signal
|
from backscattering_analyzer.signal import Signal
|
||||||
from backscattering_analyzer import (
|
from backscattering_analyzer import (
|
||||||
compute_light,
|
compute_light,
|
||||||
|
fit_compute_light,
|
||||||
opt_compute_light,
|
opt_compute_light,
|
||||||
interpolate,
|
interpolate,
|
||||||
)
|
)
|
||||||
from numpy import loadtxt, logspace, where, zeros, argmin, intp, pi
|
from numpy import loadtxt, pi
|
||||||
from scipy.io.matlab import loadmat
|
from scipy.io.matlab import loadmat
|
||||||
|
from scipy.optimize import Bounds, minimize
|
||||||
|
|
||||||
|
|
||||||
class Analyzer:
|
class Analyzer:
|
||||||
|
@ -149,24 +150,22 @@ class Analyzer:
|
||||||
power_out=self.settings.power_out,
|
power_out=self.settings.power_out,
|
||||||
)
|
)
|
||||||
|
|
||||||
def fit_scatter_factor(
|
def fit_scatter_factor(self, guess: None | float = None) -> float:
|
||||||
self, start: int, stop: int, number: int
|
|
||||||
) -> tuple[intp, float]:
|
|
||||||
"""
|
"""
|
||||||
Find the best scatter factor (first order only) in the given
|
Find the best scatter factor (first order only) in the given
|
||||||
range
|
range
|
||||||
"""
|
"""
|
||||||
import matplotlib.pyplot as plt
|
if guess is None:
|
||||||
factors = logspace(start, stop, number)
|
guess = self.settings.scattering_factor[0]
|
||||||
sums = zeros(number)
|
|
||||||
|
|
||||||
phase = 4 * pi / self.settings.wavelength
|
phase = 4 * pi / self.settings.wavelength
|
||||||
factor_n = (self.movement * phase).sin().psd().sqrt()
|
factor_n = (self.movement * phase).sin().psd().sqrt()
|
||||||
coupling_n = self.coupling[0].abs()
|
coupling_n = self.coupling[0].abs()
|
||||||
factor_d = (self.movement * phase).cos().psd().sqrt()
|
factor_d = (self.movement * phase).cos().psd().sqrt()
|
||||||
coupling_d = self.coupling[1].abs()
|
coupling_d = self.coupling[1].abs()
|
||||||
|
|
||||||
coupling_d.cut(5, 40) # cut signal between 5 and 40 Hz
|
coupling_d = coupling_d.cut_x(
|
||||||
|
10, 200
|
||||||
|
) # cut signal between 10 and 200 Hz
|
||||||
|
|
||||||
factor_n, coupling_n, factor_d, coupling_d, data, reference = (
|
factor_n, coupling_n, factor_d, coupling_d, data, reference = (
|
||||||
interpolate(
|
interpolate(
|
||||||
|
@ -181,25 +180,56 @@ class Analyzer:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
reference = reference.y
|
bounds = Bounds(0, 1)
|
||||||
data = data.y
|
min_result = minimize(
|
||||||
|
fit_compute_light,
|
||||||
|
guess,
|
||||||
|
(
|
||||||
|
factor_n,
|
||||||
|
coupling_n,
|
||||||
|
factor_d,
|
||||||
|
coupling_d,
|
||||||
|
self.settings.power_in,
|
||||||
|
self.settings.power_out,
|
||||||
|
data,
|
||||||
|
reference,
|
||||||
|
),
|
||||||
|
method = 'TNC',
|
||||||
|
bounds=bounds,
|
||||||
|
)
|
||||||
|
|
||||||
for index in range(number):
|
if not min_result.success:
|
||||||
self.settings.log("{}".format(index))
|
raise Exception(min_result.message)
|
||||||
projection = opt_compute_light(
|
|
||||||
scatter_factor=factors[index],
|
self.settings.log(
|
||||||
|
"found the best scattering factor in {} iterations".format(
|
||||||
|
min_result.nit
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
projection = Signal(
|
||||||
|
factor_n.x,
|
||||||
|
opt_compute_light(
|
||||||
|
scatter_factor=min_result.x,
|
||||||
factor_n=factor_n,
|
factor_n=factor_n,
|
||||||
coupling_n=coupling_n,
|
coupling_n=coupling_n,
|
||||||
factor_d=factor_d,
|
factor_d=factor_d,
|
||||||
coupling_d=coupling_d,
|
coupling_d=coupling_d,
|
||||||
power_in=self.settings.power_in,
|
power_in=self.settings.power_in,
|
||||||
power_out=self.settings.power_out,
|
power_out=self.settings.power_out,
|
||||||
)
|
),
|
||||||
diff = abs(projection + reference - data)
|
self.settings,
|
||||||
_ = plt.loglog(projection + reference)
|
)
|
||||||
_ = plt.loglog(data)
|
|
||||||
_ = plt.show()
|
|
||||||
sums[index] = sum(diff)
|
|
||||||
min_index = argmin(sums)
|
|
||||||
|
|
||||||
return min_index, factors[min_index]
|
_ = plt.loglog(projection.x, projection.y, label="projection")
|
||||||
|
_ = plt.loglog(reference.x, reference.y, label="référence")
|
||||||
|
_ = plt.loglog(data.x, data.y, label="data")
|
||||||
|
_ = plt.loglog(
|
||||||
|
reference.x, projection.y + reference.y, label="somme"
|
||||||
|
)
|
||||||
|
_ = plt.legend()
|
||||||
|
_ = plt.show()
|
||||||
|
|
||||||
|
return min_result.x
|
||||||
|
|
|
@ -14,6 +14,7 @@ class Settings:
|
||||||
"option": "grey50 italic",
|
"option": "grey50 italic",
|
||||||
"argument": "red",
|
"argument": "red",
|
||||||
"description": "green italic",
|
"description": "green italic",
|
||||||
|
"warning": "bold red",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
self.console = Console(theme=self.theme)
|
self.console = Console(theme=self.theme)
|
||||||
|
|
|
@ -6,7 +6,7 @@ from backscattering_analyzer import interpolate
|
||||||
from scipy.signal import welch, detrend
|
from scipy.signal import welch, detrend
|
||||||
from scipy.fft import irfft, rfft, rfftfreq
|
from scipy.fft import irfft, rfft, rfftfreq
|
||||||
from scipy.interpolate import CubicSpline
|
from scipy.interpolate import CubicSpline
|
||||||
from numpy import where, sin, cos, array, sqrt, float64
|
from numpy import logical_and, where, sin, cos, array, sqrt, float64
|
||||||
|
|
||||||
|
|
||||||
class Signal:
|
class Signal:
|
||||||
|
@ -21,7 +21,11 @@ class Signal:
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
) -> None:
|
) -> None:
|
||||||
if x.shape != value.shape:
|
if x.shape != value.shape:
|
||||||
raise Exception("x and y does not have the same dimension")
|
raise Exception(
|
||||||
|
"x and y does not have the same dimension ({x} and {y})".format(
|
||||||
|
x=x.shape, y=value.shape
|
||||||
|
)
|
||||||
|
)
|
||||||
self.sampling = x[1] - x[0]
|
self.sampling = x[1] - x[0]
|
||||||
self.rate = 1 / self.sampling
|
self.rate = 1 / self.sampling
|
||||||
self.x = x
|
self.x = x
|
||||||
|
@ -66,6 +70,17 @@ class Signal:
|
||||||
self.settings,
|
self.settings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def cut_x(self, start: float, end: float) -> Signal:
|
||||||
|
"""
|
||||||
|
Cut signal from a start value to an end value
|
||||||
|
"""
|
||||||
|
indexes = where(logical_and(self.x > start, self.x < end))
|
||||||
|
return Signal(
|
||||||
|
self.x[indexes],
|
||||||
|
self.y[indexes],
|
||||||
|
self.settings,
|
||||||
|
)
|
||||||
|
|
||||||
def low_pass_filter(self, cutoff: float) -> Signal:
|
def low_pass_filter(self, cutoff: float) -> Signal:
|
||||||
"""
|
"""
|
||||||
Cut higher frequencies than cutoff for this signal
|
Cut higher frequencies than cutoff for this signal
|
||||||
|
@ -76,9 +91,10 @@ class Signal:
|
||||||
freq_y = rfft(self.y)
|
freq_y = rfft(self.y)
|
||||||
index_to_remove = where(abs(freq_x) > cutoff)
|
index_to_remove = where(abs(freq_x) > cutoff)
|
||||||
freq_y[index_to_remove] = 0
|
freq_y[index_to_remove] = 0
|
||||||
|
y = irfft(freq_y)
|
||||||
signal = Signal(
|
signal = Signal(
|
||||||
self.x,
|
self.x[: len(y)],
|
||||||
irfft(freq_y),
|
y[: len(self.x)],
|
||||||
self.settings,
|
self.settings,
|
||||||
)
|
)
|
||||||
return signal
|
return signal
|
||||||
|
|
Loading…
Reference in a new issue