Skip to main content

vecfit/
fit.rs

1use faer::{Mat, prelude::SolveLstsq};
2use num_complex::Complex64;
3use serde::{Deserialize, Serialize};
4
5use crate::error::{Result, VecfitError};
6use crate::shape::{Layout, Shape};
7
/// Borrowed flat `(sample, channel)` response data.
///
/// `values` is laid out sample-major: entry `(k, ch)` lives at index
/// `k * channels + ch`. The fields are public, so the length invariant
/// `values.len() == samples * channels` is only enforced when the view is
/// built through [`SampleMatrixRef::new`].
#[derive(Debug, Clone, Copy)]
pub struct SampleMatrixRef<'a> {
    /// Flat response values, one run of `channels` entries per sample.
    pub values: &'a [Complex64],
    /// Number of samples (rows).
    pub samples: usize,
    /// Number of channels (columns).
    pub channels: usize,
}
15
16impl<'a> SampleMatrixRef<'a> {
17    pub fn new(values: &'a [Complex64], samples: usize, channels: usize) -> Result<Self> {
18        if samples == 0 {
19            return Err(VecfitError::Dimension(
20                "sample matrix must have at least one row".to_string(),
21            ));
22        }
23        if channels == 0 {
24            return Err(VecfitError::Dimension(
25                "sample matrix must have at least one channel".to_string(),
26            ));
27        }
28        if values.len() != samples * channels {
29            return Err(VecfitError::Dimension(format!(
30                "sample matrix length {} does not match {samples}x{channels}",
31                values.len()
32            )));
33        }
34        Ok(Self {
35            values,
36            samples,
37            channels,
38        })
39    }
40}
41
/// Owned flat `(sample, channel)` response data.
///
/// Same sample-major layout as [`SampleMatrixRef`]. Only [`SampleMatrix::new`]
/// validates the dimensions; values created with a struct literal or through
/// the derived `Deserialize` impl are not checked.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SampleMatrix {
    /// Flat response values, one run of `channels` entries per sample.
    pub values: Vec<Complex64>,
    /// Number of samples (rows).
    pub samples: usize,
    /// Number of channels (columns).
    pub channels: usize,
}
49
50impl SampleMatrix {
51    pub fn new(values: Vec<Complex64>, samples: usize, channels: usize) -> Result<Self> {
52        SampleMatrixRef::new(&values, samples, channels)?;
53        Ok(Self {
54            values,
55            samples,
56            channels,
57        })
58    }
59
60    pub fn as_ref(&self) -> SampleMatrixRef<'_> {
61        SampleMatrixRef {
62            values: &self.values,
63            samples: self.samples,
64            channels: self.channels,
65        }
66    }
67
68    pub fn row(&self, idx: usize) -> &[Complex64] {
69        let start = idx * self.channels;
70        &self.values[start..start + self.channels]
71    }
72}
73
/// Complete borrowed fitting problem description.
///
/// All fields are public and nothing is checked at construction time;
/// [`ProblemRef::validate`] verifies that the pieces agree with each other.
#[derive(Debug, Clone, Copy)]
pub struct ProblemRef<'a> {
    /// Sample points; `validate` requires this to be non-empty.
    pub axis: &'a [Complex64],
    /// Response data; `validate` requires `axis.len()` rows and
    /// `shape.channels()` channels.
    pub response: SampleMatrixRef<'a>,
    /// Optional per-sample weights; `validate` requires `axis.len()` finite,
    /// nonnegative entries.
    pub weights: Option<&'a [f64]>,
    /// Channel shape of the response.
    pub shape: &'a Shape,
    /// Output memory layout. Presumably the same meaning as `Options::layout`;
    /// confirm at the call site.
    pub layout: Layout,
}
83
84impl<'a> ProblemRef<'a> {
85    pub fn validate(&self) -> Result<()> {
86        if self.axis.is_empty() {
87            return Err(VecfitError::InvalidInput(
88                "sample axis cannot be empty".to_string(),
89            ));
90        }
91        if self.response.samples != self.axis.len() {
92            return Err(VecfitError::Dimension(format!(
93                "response rows {} do not match sample length {}",
94                self.response.samples,
95                self.axis.len()
96            )));
97        }
98        if self.response.channels != self.shape.channels() {
99            return Err(VecfitError::Dimension(format!(
100                "response channels {} do not match shape {:?}",
101                self.response.channels,
102                self.shape.dims()
103            )));
104        }
105        validate_weights(self.weights, self.axis.len())?;
106        Ok(())
107    }
108}
109
/// Least-squares backend preference used during fitting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum SolverPolicy {
    /// Default. Tries a column-pivoted QR solve first and redoes it with a thin
    /// SVD if any entry of the QR solution is non-finite.
    #[default]
    Auto,
    /// Always uses column-pivoted QR, with no fallback.
    ColPivQr,
    /// Always uses a thin SVD.
    SvdOnly,
}
118
/// Least-squares backend that actually produced the final solution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SolverUsed {
    /// Column-pivoted QR decomposition.
    ColPivQr,
    /// Thin singular value decomposition.
    Svd,
}
125
126/// Automatic per-sample weighting strategy for the fitting objective.
127#[derive(Debug, Clone, Serialize, Deserialize, Default)]
128pub enum WeightStrategy {
129    /// No automatic weighting (default). Minimizes absolute error.
130    #[default]
131    None,
132    /// Weight each sample by `1 / max_channel(|f(s_k)|)`, floored to avoid
133    /// division by zero.  Minimizes relative error uniformly across frequencies.
134    InverseMagnitude,
135}
136
/// Automatically search for the best pole count within a range.
///
/// Defaults: `min_poles = 2`, `max_poles = 30`, `target_rel_rmse = 1e-3`.
/// NOTE(review): the bounds are not validated in this file; presumably the
/// search routine checks them. Confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoPoles {
    /// Smallest pole count to try (must be >= 1).
    pub min_poles: usize,
    /// Largest pole count to try.
    pub max_poles: usize,
    /// Stop early when relative RMSE drops below this target.
    pub target_rel_rmse: f64,
}
147
148impl Default for AutoPoles {
149    fn default() -> Self {
150        Self {
151            min_poles: 2,
152            max_poles: 30,
153            target_rel_rmse: 1e-3,
154        }
155    }
156}
157
/// Tuning knobs for relaxed vector fitting.
///
/// Usually built with [`Options::new`] (equivalent to `Default`) followed by
/// the chained setters, e.g. `Options::new().poles(8).tolerance(1e-10)`.
/// The default of each field is given in its description.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Options {
    /// Number of poles to fit (default: 6).
    pub poles: usize,
    /// User-supplied starting poles (must match `poles` in length; this is not
    /// checked by the `initial_poles` setter). Default: none.
    pub initial_poles: Option<Vec<Complex64>>,
    /// Maximum pole-relocation iterations before stopping (default: 30).
    pub max_iterations: usize,
    /// Convergence threshold on the relative pole shift (default: `1e-9`).
    pub tolerance: f64,
    /// Whether to include a constant (d) term in the model (default: true).
    pub fit_constant: bool,
    /// Whether to include a proportional (e * s) term in the model (default: false).
    pub fit_proportional: bool,
    /// Constrain all poles to the real axis (default: false).
    pub real_only: bool,
    /// Explicit per-sample weights (length must match the sample axis). Default: none.
    pub weights: Option<Vec<f64>>,
    /// Least-squares backend preference (default: `SolverPolicy::Auto`).
    pub solver: SolverPolicy,
    /// Automatic per-sample weighting strategy (default: `WeightStrategy::None`).
    pub weight_strategy: WeightStrategy,
    /// Maximum multi-start restart attempts (default: 3).
    pub max_restarts: usize,
    /// Relative RMSE below which restarts stop early (default: 0.05).
    pub restart_threshold: f64,
    /// Automatic pole-count search configuration (default: none).
    pub auto_poles: Option<AutoPoles>,
    /// Record pole positions at each iteration for migration diagnostics (default: false).
    pub track_pole_history: bool,
    /// Memory layout for flattened output (default: RowMajor).
    pub layout: Layout,
}
192
193impl Default for Options {
194    fn default() -> Self {
195        Self {
196            poles: 6,
197            initial_poles: None,
198            max_iterations: 30,
199            tolerance: 1e-9,
200            fit_constant: true,
201            fit_proportional: false,
202            real_only: false,
203            weights: None,
204            solver: SolverPolicy::Auto,
205            weight_strategy: WeightStrategy::None,
206            max_restarts: 3,
207            restart_threshold: 0.05,
208            auto_poles: None,
209            track_pole_history: false,
210            layout: Layout::RowMajor,
211        }
212    }
213}
214
215impl Options {
216    pub fn new() -> Self {
217        Self::default()
218    }
219
220    pub fn poles(mut self, poles: usize) -> Self {
221        self.poles = poles;
222        self
223    }
224
225    pub fn initial_poles(mut self, poles: Vec<Complex64>) -> Self {
226        self.initial_poles = Some(poles);
227        self
228    }
229
230    pub fn max_iterations(mut self, max_iterations: usize) -> Self {
231        self.max_iterations = max_iterations;
232        self
233    }
234
235    pub fn tolerance(mut self, tolerance: f64) -> Self {
236        self.tolerance = tolerance;
237        self
238    }
239
240    pub fn fit_constant(mut self, fit_constant: bool) -> Self {
241        self.fit_constant = fit_constant;
242        self
243    }
244
245    pub fn fit_proportional(mut self, fit_proportional: bool) -> Self {
246        self.fit_proportional = fit_proportional;
247        self
248    }
249
250    pub fn real_only(mut self, real_only: bool) -> Self {
251        self.real_only = real_only;
252        self
253    }
254
255    pub fn weights(mut self, weights: Vec<f64>) -> Self {
256        self.weights = Some(weights);
257        self
258    }
259
260    pub fn solver(mut self, solver: SolverPolicy) -> Self {
261        self.solver = solver;
262        self
263    }
264
265    pub fn weight_strategy(mut self, strategy: WeightStrategy) -> Self {
266        self.weight_strategy = strategy;
267        self
268    }
269
270    pub fn max_restarts(mut self, max_restarts: usize) -> Self {
271        self.max_restarts = max_restarts;
272        self
273    }
274
275    pub fn restart_threshold(mut self, threshold: f64) -> Self {
276        self.restart_threshold = threshold;
277        self
278    }
279
280    pub fn auto_poles(mut self, auto_poles: AutoPoles) -> Self {
281        self.auto_poles = Some(auto_poles);
282        self
283    }
284
285    pub fn track_pole_history(mut self, track: bool) -> Self {
286        self.track_pole_history = track;
287        self
288    }
289
290    pub fn layout(mut self, layout: Layout) -> Self {
291        self.layout = layout;
292        self
293    }
294
295    /// Shorthand for `Options::new().poles(n)`.
296    pub fn with_poles(n: usize) -> Self {
297        Self::new().poles(n)
298    }
299
300    /// Automatic pole-count search with default configuration.
301    pub fn auto() -> Self {
302        Self::new().auto_poles(AutoPoles::default())
303    }
304
305    /// Real-only fit with the given pole count.
306    pub fn real(n: usize) -> Self {
307        Self::new().poles(n).real_only(true)
308    }
309
310    /// Weighted fit with inverse-magnitude strategy.
311    pub fn weighted(n: usize) -> Self {
312        Self::new()
313            .poles(n)
314            .weight_strategy(WeightStrategy::InverseMagnitude)
315    }
316
317    /// Set convergence parameters (max iterations and tolerance).
318    pub fn convergence(mut self, max_iter: usize, tol: f64) -> Self {
319        self.max_iterations = max_iter;
320        self.tolerance = tol;
321        self
322    }
323
324    /// Set multi-start restart parameters.
325    pub fn restarts(mut self, max: usize, threshold: f64) -> Self {
326        self.max_restarts = max;
327        self.restart_threshold = threshold;
328        self
329    }
330}
331
/// Summary statistics describing the last fit.
///
/// `Report::default()` has every flag `false`, every counter and error `0`,
/// `solver_used = SolverUsed::ColPivQr`, and empty vectors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Report {
    /// Whether the pole-relocation loop converged within `max_iterations`.
    pub converged: bool,
    /// Number of pole-relocation iterations performed.
    pub iterations: usize,
    /// Absolute root-mean-square error between the model and the reference data.
    pub abs_rmse: f64,
    /// Relative root-mean-square error (absolute RMSE divided by signal RMS).
    pub rel_rmse: f64,
    /// Largest relative pole shift in the final iteration.
    pub max_pole_shift: f64,
    /// Relative pole shift at each iteration (for convergence diagnostics).
    pub pole_shifts: Vec<f64>,
    /// Which least-squares backend produced the final solution.
    pub solver_used: SolverUsed,
    /// Whether the SVD fallback was triggered during the fit.
    pub svd_fallback_used: bool,
    /// Whether per-sample weights were applied.
    pub weighted: bool,
    /// Whether all poles have non-positive real parts.
    pub stable: bool,
    /// Whether the model can be exported as real first/second-order sections.
    pub real_sections_valid: bool,
    /// Number of multi-start restarts performed.
    pub restarts: usize,
    /// Per-channel absolute RMSE. Length = channels.
    pub channel_abs_rmse: Vec<f64>,
    /// Per-channel relative RMSE. Length = channels.
    pub channel_rel_rmse: Vec<f64>,
    /// Pole snapshots at each iteration. Only populated when Options::track_pole_history is true.
    /// Each pole is a two-element array, presumably `[re, im]`; confirm against
    /// the fitting loop. Omitted from serialized output when empty, and
    /// defaults to empty when absent on deserialization.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub pole_history: Vec<Vec<[f64; 2]>>,
}
367
368impl Default for Report {
369    fn default() -> Self {
370        Self {
371            converged: false,
372            iterations: 0,
373            abs_rmse: 0.0,
374            rel_rmse: 0.0,
375            max_pole_shift: 0.0,
376            pole_shifts: Vec::new(),
377            solver_used: SolverUsed::ColPivQr,
378            svd_fallback_used: false,
379            weighted: false,
380            stable: false,
381            real_sections_valid: false,
382            restarts: 0,
383            channel_abs_rmse: Vec::new(),
384            channel_rel_rmse: Vec::new(),
385            pole_history: Vec::new(),
386        }
387    }
388}
389
390impl std::fmt::Display for Report {
391    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392        writeln!(f, "VecFit Report")?;
393        writeln!(f, "  converged:    {}", self.converged)?;
394        writeln!(f, "  iterations:   {}", self.iterations)?;
395        writeln!(f, "  restarts:     {}", self.restarts)?;
396        writeln!(f, "  abs RMSE:     {:.6e}", self.abs_rmse)?;
397        writeln!(f, "  rel RMSE:     {:.6e}", self.rel_rmse)?;
398        writeln!(f, "  max pole shift: {:.6e}", self.max_pole_shift)?;
399        writeln!(f, "  solver:       {:?}", self.solver_used)?;
400        writeln!(f, "  SVD fallback: {}", self.svd_fallback_used)?;
401        writeln!(f, "  weighted:     {}", self.weighted)?;
402        writeln!(f, "  stable:       {}", self.stable)?;
403        write!(f, "  real sections: {}", self.real_sections_valid)?;
404        if !self.channel_abs_rmse.is_empty() {
405            writeln!(f)?;
406            writeln!(f, "  per-channel abs RMSE:")?;
407            for (i, rmse) in self.channel_abs_rmse.iter().enumerate() {
408                write!(f, "    ch {}: {:.6e}", i, rmse)?;
409                if i + 1 < self.channel_abs_rmse.len() {
410                    writeln!(f)?;
411                }
412            }
413        }
414        if !self.channel_rel_rmse.is_empty() {
415            writeln!(f)?;
416            writeln!(f, "  per-channel rel RMSE:")?;
417            for (i, rmse) in self.channel_rel_rmse.iter().enumerate() {
418                write!(f, "    ch {}: {:.6e}", i, rmse)?;
419                if i + 1 < self.channel_rel_rmse.len() {
420                    writeln!(f)?;
421                }
422            }
423        }
424        if !self.pole_history.is_empty() {
425            writeln!(f)?;
426            write!(
427                f,
428                "  pole history: {} iterations tracked",
429                self.pole_history.len()
430            )?;
431        }
432        Ok(())
433    }
434}
435
436pub(crate) fn matrix_from_row_major_slice(
437    values: &[Complex64],
438    rows: usize,
439    cols: usize,
440) -> Mat<Complex64> {
441    Mat::from_fn(rows, cols, |row, col| values[row * cols + col])
442}
443
444pub(crate) fn pole_basis_matrix(
445    axis: &[Complex64],
446    poles: &[Complex64],
447    fit_constant: bool,
448    fit_proportional: bool,
449) -> Mat<Complex64> {
450    let cols = poles.len() + usize::from(fit_constant) + usize::from(fit_proportional);
451    Mat::from_fn(axis.len(), cols, |row, col| {
452        if col < poles.len() {
453            Complex64::new(1.0, 0.0) / (axis[row] - poles[col])
454        } else if fit_constant && col == poles.len() {
455            Complex64::new(1.0, 0.0)
456        } else {
457            axis[row]
458        }
459    })
460}
461
/// Returns `count` values evenly spaced on a base-10 log scale from `start` to `stop`.
///
/// `count == 0` yields an empty vector, and `count == 1` yields the geometric
/// mean `sqrt(start * stop)`. Both bounds are expected to be positive;
/// otherwise `log10`/`sqrt` produce NaN or infinite values.
pub(crate) fn geometric_space(start: f64, stop: f64, count: usize) -> Vec<f64> {
    if count == 0 {
        return Vec::new();
    }
    if count == 1 {
        return vec![(start * stop).sqrt()];
    }
    let (lo_exp, hi_exp) = (start.log10(), stop.log10());
    let span = hi_exp - lo_exp;
    let last = count as f64 - 1.0;
    let mut out = Vec::with_capacity(count);
    for step in 0..count {
        let t = step as f64 / last;
        out.push(10f64.powf(lo_exp + t * span));
    }
    out
}
478
479/// Generate initial poles for the VF iteration.
480///
481/// When `real_only` is false, generates complex conjugate pairs with imaginary
482/// parts spanning the frequency range of the sample axis — matching the
483/// Gustavsen & Semlyen approach for resonant systems.  When `real_only` is
484/// true, generates purely real negative poles (suitable for monotonic responses).
485pub(crate) fn initial_poles(axis: &[Complex64], poles: usize, real_only: bool) -> Vec<Complex64> {
486    let mut magnitudes = axis
487        .iter()
488        .map(|value| value.im.abs().max(value.norm()))
489        .filter(|value| *value > 1e-15)
490        .collect::<Vec<_>>();
491    if magnitudes.is_empty() {
492        magnitudes.push(1.0);
493    }
494    magnitudes.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
495    let lo = (magnitudes[0] * 0.5).max(1e-6);
496    let hi = magnitudes[magnitudes.len() - 1] * 1.5;
497
498    if real_only || poles < 2 {
499        geometric_space(lo, hi, poles)
500            .into_iter()
501            .map(|value| Complex64::new(-value * 0.01, 0.0))
502            .collect()
503    } else {
504        let pair_count = poles / 2;
505        let has_extra = poles % 2 == 1;
506        let n_pts = pair_count + usize::from(has_extra);
507        let pts = geometric_space(lo, hi, n_pts);
508        let mut result = Vec::with_capacity(poles);
509        for (i, &beta) in pts.iter().enumerate() {
510            if has_extra && i == n_pts - 1 {
511                result.push(Complex64::new(-beta * 0.01, 0.0));
512            } else {
513                result.push(Complex64::new(-beta * 0.01, beta));
514                result.push(Complex64::new(-beta * 0.01, -beta));
515            }
516        }
517        result.truncate(poles);
518        result
519    }
520}
521
522pub(crate) fn apply_sample_weights(
523    matrix: &mut Mat<Complex64>,
524    weights: Option<&[f64]>,
525    rows_per_sample: usize,
526) {
527    if let Some(weights) = weights {
528        let rows_per_sample = rows_per_sample.max(1);
529        for row in 0..matrix.nrows() {
530            let weight = weights[row / rows_per_sample].sqrt();
531            for col in 0..matrix.ncols() {
532                matrix[(row, col)] *= weight;
533            }
534        }
535    }
536}
537
538pub(crate) fn solve_least_squares(
539    system: &Mat<Complex64>,
540    rhs: &Mat<Complex64>,
541    solver_policy: SolverPolicy,
542) -> Result<(Mat<Complex64>, SolverUsed, bool)> {
543    if system.nrows() < system.ncols() {
544        return Err(VecfitError::InvalidInput(format!(
545            "least-squares system is underdetermined ({} rows for {} unknowns); provide more samples or reduce the number of fitted terms",
546            system.nrows(),
547            system.ncols()
548        )));
549    }
550
551    match solver_policy {
552        SolverPolicy::SvdOnly => {
553            let svd = system.as_ref().thin_svd()?;
554            let solution = svd.solve_lstsq(rhs.as_ref());
555            Ok((solution, SolverUsed::Svd, false))
556        }
557        SolverPolicy::ColPivQr => {
558            let qr = system.as_ref().col_piv_qr();
559            let solution = qr.solve_lstsq(rhs.as_ref());
560            Ok((solution, SolverUsed::ColPivQr, false))
561        }
562        SolverPolicy::Auto => {
563            let qr = system.as_ref().col_piv_qr();
564            let solution = qr.solve_lstsq(rhs.as_ref());
565            let all_finite = (0..solution.nrows()).all(|row| {
566                (0..solution.ncols()).all(|col| {
567                    let v = solution[(row, col)];
568                    v.re.is_finite() && v.im.is_finite()
569                })
570            });
571            if !all_finite {
572                let svd = system.as_ref().thin_svd()?;
573                let fallback = svd.solve_lstsq(rhs.as_ref());
574                return Ok((fallback, SolverUsed::Svd, true));
575            }
576            Ok((solution, SolverUsed::ColPivQr, false))
577        }
578    }
579}
580
581/// Solve a column-scaled least-squares system for better conditioning.
582///
583/// Scales each column of `system` to unit norm before solving, then unscales
584/// the solution.  Returns the same tuple as `solve_least_squares`.
585pub(crate) fn solve_least_squares_scaled(
586    system: &Mat<Complex64>,
587    rhs: &Mat<Complex64>,
588    solver_policy: SolverPolicy,
589) -> Result<(Mat<Complex64>, SolverUsed, bool)> {
590    let cols = system.ncols();
591    let mut scaled = system.clone();
592    let mut col_norms = vec![0.0f64; cols];
593    for j in 0..cols {
594        let mut norm_sq = 0.0;
595        for i in 0..scaled.nrows() {
596            norm_sq += scaled[(i, j)].norm_sqr();
597        }
598        col_norms[j] = norm_sq.sqrt().max(1e-30);
599        let inv = 1.0 / col_norms[j];
600        for i in 0..scaled.nrows() {
601            scaled[(i, j)] *= inv;
602        }
603    }
604    let (mut solution, solver_used, fallback) = solve_least_squares(&scaled, rhs, solver_policy)?;
605    // solution shape: (system_cols, rhs_cols) — unscale each row
606    for j in 0..cols {
607        let inv = 1.0 / col_norms[j];
608        for k in 0..solution.ncols() {
609            solution[(j, k)] *= inv;
610        }
611    }
612    Ok((solution, solver_used, fallback))
613}
614
615/// Compute inverse-magnitude weights from a flat response buffer.
616///
617/// For each sample, takes the maximum channel magnitude and returns
618/// `1 / max(|f_ch|, floor)`.  The floor prevents division by zero at
619/// transmission zeros.
620pub(crate) fn compute_inverse_magnitude_weights(
621    values: &[Complex64],
622    samples: usize,
623    channels: usize,
624) -> Vec<f64> {
625    let mut weights = Vec::with_capacity(samples);
626    let mut max_mag = 0.0f64;
627    for k in 0..samples {
628        let mut sample_max = 0.0f64;
629        for ch in 0..channels {
630            sample_max = sample_max.max(values[k * channels + ch].norm());
631        }
632        max_mag = max_mag.max(sample_max);
633        weights.push(sample_max);
634    }
635    let floor = max_mag * 1e-8;
636    for w in &mut weights {
637        *w = 1.0 / (*w).max(floor);
638    }
639    weights
640}
641
642pub(crate) fn validate_weights(weights: Option<&[f64]>, samples: usize) -> Result<()> {
643    if let Some(weights) = weights {
644        if weights.len() != samples {
645            return Err(VecfitError::Dimension(format!(
646                "weights length {} does not match sample length {}",
647                weights.len(),
648                samples
649            )));
650        }
651        if weights
652            .iter()
653            .any(|weight| !weight.is_finite() || *weight < 0.0)
654        {
655            return Err(VecfitError::InvalidInput(
656                "weights must be finite and nonnegative".to_string(),
657            ));
658        }
659    }
660    Ok(())
661}