diff --git a/pyproject.toml b/pyproject.toml index 2bbc7a8..3646bec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,3 +42,7 @@ docstring-code-format = true [tool.basedpyright] typeCheckingMode = "all" +reportUnknown = false +reportUnknownMemberType = false +reportUnknownArgumentType = false +reportUnknownVariableType = false diff --git a/src/backscattering_analyzer/__init__.py b/src/backscattering_analyzer/__init__.py index adb422c..04c6029 100644 --- a/src/backscattering_analyzer/__init__.py +++ b/src/backscattering_analyzer/__init__.py @@ -98,4 +98,39 @@ def opt_compute_light( ) +def fit_compute_light( + scatter_factor: float, + factor_n: Signal, + coupling_n: Signal, + factor_d: Signal, + coupling_d: Signal, + power_in: float, + power_out: float, + data: Signal, + reference: Signal, +) -> float: + """ + Scalar function used to find the right scattering factor + """ + return sum( + abs( + Signal( + factor_n.x, + opt_compute_light( + scatter_factor=scatter_factor, + factor_n=factor_n, + coupling_n=coupling_n, + factor_d=factor_d, + coupling_d=coupling_d, + power_in=power_in, + power_out=power_out, + ), + factor_n.settings, + ) + + reference + - data + ).y + ) + + from backscattering_analyzer.signal import Signal # no circular import diff --git a/src/backscattering_analyzer/analyzer.py b/src/backscattering_analyzer/analyzer.py index c674e35..c345bc7 100644 --- a/src/backscattering_analyzer/analyzer.py +++ b/src/backscattering_analyzer/analyzer.py @@ -1,14 +1,15 @@ -# utils from sys import argv from backscattering_analyzer.settings import Settings from backscattering_analyzer.signal import Signal from backscattering_analyzer import ( compute_light, + fit_compute_light, opt_compute_light, interpolate, ) -from numpy import loadtxt, logspace, where, zeros, argmin, intp, pi +from numpy import loadtxt, pi from scipy.io.matlab import loadmat +from scipy.optimize import Bounds, minimize class Analyzer: @@ -149,24 +150,22 @@ class Analyzer: power_out=self.settings.power_out, ) - def fit_scatter_factor( - self, start: int, stop: int, number: int - ) -> tuple[intp, float]: + def fit_scatter_factor(self, guess: None | float = None) -> float: """ Find the best scatter factor (first order only) in the given range """ - import matplotlib.pyplot as plt - factors = logspace(start, stop, number) - sums = zeros(number) - + if guess is None: + guess = self.settings.scattering_factor[0] phase = 4 * pi / self.settings.wavelength factor_n = (self.movement * phase).sin().psd().sqrt() coupling_n = self.coupling[0].abs() factor_d = (self.movement * phase).cos().psd().sqrt() coupling_d = self.coupling[1].abs() - coupling_d.cut(5, 40) # cut signal between 5 and 40 Hz + coupling_d = coupling_d.cut_x( + 10, 200 + ) # cut signal between 10 and 200 Hz factor_n, coupling_n, factor_d, coupling_d, data, reference = ( interpolate( @@ -181,25 +180,56 @@ class Analyzer: ) ) - reference = reference.y - data = data.y + bounds = Bounds(0, 1) + min_result = minimize( + fit_compute_light, + guess, + ( + factor_n, + coupling_n, + factor_d, + coupling_d, + self.settings.power_in, + self.settings.power_out, + data, + reference, + ), + method = 'TNC', + bounds=bounds, + ) - for index in range(number): - self.settings.log("{}".format(index)) - projection = opt_compute_light( - scatter_factor=factors[index], + if not min_result.success: + raise Exception(min_result.message) + + self.settings.log( + "found the best scattering factor in {} iterations".format( + min_result.nit + ) + ) + + import matplotlib.pyplot as plt + + projection = Signal( + factor_n.x, + opt_compute_light( + scatter_factor=min_result.x, factor_n=factor_n, coupling_n=coupling_n, factor_d=factor_d, coupling_d=coupling_d, power_in=self.settings.power_in, power_out=self.settings.power_out, - ) - diff = abs(projection + reference - data) - _ = plt.loglog(projection + reference) - _ = plt.loglog(data) - _ = plt.show() - sums[index] = sum(diff) - min_index = argmin(sums) + ), + self.settings, + ) - return min_index, factors[min_index] + _ = plt.loglog(projection.x, projection.y, label="projection") + _ = plt.loglog(reference.x, reference.y, label="référence") + _ = plt.loglog(data.x, data.y, label="data") + _ = plt.loglog( + reference.x, projection.y + reference.y, label="somme" + ) + _ = plt.legend() + _ = plt.show() + + return min_result.x diff --git a/src/backscattering_analyzer/settings.py b/src/backscattering_analyzer/settings.py index 3797ece..2f591d7 100644 --- a/src/backscattering_analyzer/settings.py +++ b/src/backscattering_analyzer/settings.py @@ -14,6 +14,7 @@ class Settings: "option": "grey50 italic", "argument": "red", "description": "green italic", + "warning": "bold red", } ) self.console = Console(theme=self.theme) diff --git a/src/backscattering_analyzer/signal.py b/src/backscattering_analyzer/signal.py index 1829df8..de84f19 100644 --- a/src/backscattering_analyzer/signal.py +++ b/src/backscattering_analyzer/signal.py @@ -6,7 +6,7 @@ from backscattering_analyzer import interpolate from scipy.signal import welch, detrend from scipy.fft import irfft, rfft, rfftfreq from scipy.interpolate import CubicSpline -from numpy import where, sin, cos, array, sqrt, float64 +from numpy import logical_and, where, sin, cos, array, sqrt, float64 class Signal: @@ -21,7 +21,11 @@ class Signal: settings: Settings, ) -> None: if x.shape != value.shape: - raise Exception("x and y does not have the same dimension") + raise Exception( + "x and y does not have the same dimension ({x} and {y})".format( + x=x.shape, y=value.shape + ) + ) self.sampling = x[1] - x[0] self.rate = 1 / self.sampling self.x = x @@ -66,6 +70,17 @@ class Signal: self.settings, ) + def cut_x(self, start: float, end: float) -> Signal: + """ + Cut signal from a start value to an end value + """ + indexes = where(logical_and(self.x > start, self.x < end)) + return Signal( + self.x[indexes], + self.y[indexes], + self.settings, + ) + def low_pass_filter(self, cutoff: float) -> Signal: """ Cut higher frequencies than cutoff for this signal @@ -76,9 +91,10 @@ class Signal: freq_y = rfft(self.y) index_to_remove = where(abs(freq_x) > cutoff) freq_y[index_to_remove] = 0 + y = irfft(freq_y) signal = Signal( - self.x, - irfft(freq_y), + self.x[: len(y)], + y[: len(self.x)], self.settings, ) return signal