1use faer::{Mat, prelude::SolveLstsq};
2use num_complex::Complex64;
3use serde::{Deserialize, Serialize};
4
5use crate::error::{Result, VecfitError};
6use crate::shape::{Layout, Shape};
7
8#[derive(Debug, Clone, Copy)]
10pub struct SampleMatrixRef<'a> {
11 pub values: &'a [Complex64],
12 pub samples: usize,
13 pub channels: usize,
14}
15
16impl<'a> SampleMatrixRef<'a> {
17 pub fn new(values: &'a [Complex64], samples: usize, channels: usize) -> Result<Self> {
18 if samples == 0 {
19 return Err(VecfitError::Dimension(
20 "sample matrix must have at least one row".to_string(),
21 ));
22 }
23 if channels == 0 {
24 return Err(VecfitError::Dimension(
25 "sample matrix must have at least one channel".to_string(),
26 ));
27 }
28 if values.len() != samples * channels {
29 return Err(VecfitError::Dimension(format!(
30 "sample matrix length {} does not match {samples}x{channels}",
31 values.len()
32 )));
33 }
34 Ok(Self {
35 values,
36 samples,
37 channels,
38 })
39 }
40}
41
42#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
44pub struct SampleMatrix {
45 pub values: Vec<Complex64>,
46 pub samples: usize,
47 pub channels: usize,
48}
49
50impl SampleMatrix {
51 pub fn new(values: Vec<Complex64>, samples: usize, channels: usize) -> Result<Self> {
52 SampleMatrixRef::new(&values, samples, channels)?;
53 Ok(Self {
54 values,
55 samples,
56 channels,
57 })
58 }
59
60 pub fn as_ref(&self) -> SampleMatrixRef<'_> {
61 SampleMatrixRef {
62 values: &self.values,
63 samples: self.samples,
64 channels: self.channels,
65 }
66 }
67
68 pub fn row(&self, idx: usize) -> &[Complex64] {
69 let start = idx * self.channels;
70 &self.values[start..start + self.channels]
71 }
72}
73
74#[derive(Debug, Clone, Copy)]
76pub struct ProblemRef<'a> {
77 pub axis: &'a [Complex64],
78 pub response: SampleMatrixRef<'a>,
79 pub weights: Option<&'a [f64]>,
80 pub shape: &'a Shape,
81 pub layout: Layout,
82}
83
84impl<'a> ProblemRef<'a> {
85 pub fn validate(&self) -> Result<()> {
86 if self.axis.is_empty() {
87 return Err(VecfitError::InvalidInput(
88 "sample axis cannot be empty".to_string(),
89 ));
90 }
91 if self.response.samples != self.axis.len() {
92 return Err(VecfitError::Dimension(format!(
93 "response rows {} do not match sample length {}",
94 self.response.samples,
95 self.axis.len()
96 )));
97 }
98 if self.response.channels != self.shape.channels() {
99 return Err(VecfitError::Dimension(format!(
100 "response channels {} do not match shape {:?}",
101 self.response.channels,
102 self.shape.dims()
103 )));
104 }
105 validate_weights(self.weights, self.axis.len())?;
106 Ok(())
107 }
108}
109
110#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
112pub enum SolverPolicy {
113 #[default]
114 Auto,
115 ColPivQr,
116 SvdOnly,
117}
118
119#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
121pub enum SolverUsed {
122 ColPivQr,
123 Svd,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize, Default)]
128pub enum WeightStrategy {
129 #[default]
131 None,
132 InverseMagnitude,
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct AutoPoles {
140 pub min_poles: usize,
142 pub max_poles: usize,
144 pub target_rel_rmse: f64,
146}
147
148impl Default for AutoPoles {
149 fn default() -> Self {
150 Self {
151 min_poles: 2,
152 max_poles: 30,
153 target_rel_rmse: 1e-3,
154 }
155 }
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct Options {
161 pub poles: usize,
163 pub initial_poles: Option<Vec<Complex64>>,
165 pub max_iterations: usize,
167 pub tolerance: f64,
169 pub fit_constant: bool,
171 pub fit_proportional: bool,
173 pub real_only: bool,
175 pub weights: Option<Vec<f64>>,
177 pub solver: SolverPolicy,
179 pub weight_strategy: WeightStrategy,
181 pub max_restarts: usize,
183 pub restart_threshold: f64,
185 pub auto_poles: Option<AutoPoles>,
187 pub track_pole_history: bool,
189 pub layout: Layout,
191}
192
193impl Default for Options {
194 fn default() -> Self {
195 Self {
196 poles: 6,
197 initial_poles: None,
198 max_iterations: 30,
199 tolerance: 1e-9,
200 fit_constant: true,
201 fit_proportional: false,
202 real_only: false,
203 weights: None,
204 solver: SolverPolicy::Auto,
205 weight_strategy: WeightStrategy::None,
206 max_restarts: 3,
207 restart_threshold: 0.05,
208 auto_poles: None,
209 track_pole_history: false,
210 layout: Layout::RowMajor,
211 }
212 }
213}
214
215impl Options {
216 pub fn new() -> Self {
217 Self::default()
218 }
219
220 pub fn poles(mut self, poles: usize) -> Self {
221 self.poles = poles;
222 self
223 }
224
225 pub fn initial_poles(mut self, poles: Vec<Complex64>) -> Self {
226 self.initial_poles = Some(poles);
227 self
228 }
229
230 pub fn max_iterations(mut self, max_iterations: usize) -> Self {
231 self.max_iterations = max_iterations;
232 self
233 }
234
235 pub fn tolerance(mut self, tolerance: f64) -> Self {
236 self.tolerance = tolerance;
237 self
238 }
239
240 pub fn fit_constant(mut self, fit_constant: bool) -> Self {
241 self.fit_constant = fit_constant;
242 self
243 }
244
245 pub fn fit_proportional(mut self, fit_proportional: bool) -> Self {
246 self.fit_proportional = fit_proportional;
247 self
248 }
249
250 pub fn real_only(mut self, real_only: bool) -> Self {
251 self.real_only = real_only;
252 self
253 }
254
255 pub fn weights(mut self, weights: Vec<f64>) -> Self {
256 self.weights = Some(weights);
257 self
258 }
259
260 pub fn solver(mut self, solver: SolverPolicy) -> Self {
261 self.solver = solver;
262 self
263 }
264
265 pub fn weight_strategy(mut self, strategy: WeightStrategy) -> Self {
266 self.weight_strategy = strategy;
267 self
268 }
269
270 pub fn max_restarts(mut self, max_restarts: usize) -> Self {
271 self.max_restarts = max_restarts;
272 self
273 }
274
275 pub fn restart_threshold(mut self, threshold: f64) -> Self {
276 self.restart_threshold = threshold;
277 self
278 }
279
280 pub fn auto_poles(mut self, auto_poles: AutoPoles) -> Self {
281 self.auto_poles = Some(auto_poles);
282 self
283 }
284
285 pub fn track_pole_history(mut self, track: bool) -> Self {
286 self.track_pole_history = track;
287 self
288 }
289
290 pub fn layout(mut self, layout: Layout) -> Self {
291 self.layout = layout;
292 self
293 }
294
295 pub fn with_poles(n: usize) -> Self {
297 Self::new().poles(n)
298 }
299
300 pub fn auto() -> Self {
302 Self::new().auto_poles(AutoPoles::default())
303 }
304
305 pub fn real(n: usize) -> Self {
307 Self::new().poles(n).real_only(true)
308 }
309
310 pub fn weighted(n: usize) -> Self {
312 Self::new()
313 .poles(n)
314 .weight_strategy(WeightStrategy::InverseMagnitude)
315 }
316
317 pub fn convergence(mut self, max_iter: usize, tol: f64) -> Self {
319 self.max_iterations = max_iter;
320 self.tolerance = tol;
321 self
322 }
323
324 pub fn restarts(mut self, max: usize, threshold: f64) -> Self {
326 self.max_restarts = max;
327 self.restart_threshold = threshold;
328 self
329 }
330}
331
332#[derive(Debug, Clone, Serialize, Deserialize)]
334pub struct Report {
335 pub converged: bool,
337 pub iterations: usize,
339 pub abs_rmse: f64,
341 pub rel_rmse: f64,
343 pub max_pole_shift: f64,
345 pub pole_shifts: Vec<f64>,
347 pub solver_used: SolverUsed,
349 pub svd_fallback_used: bool,
351 pub weighted: bool,
353 pub stable: bool,
355 pub real_sections_valid: bool,
357 pub restarts: usize,
359 pub channel_abs_rmse: Vec<f64>,
361 pub channel_rel_rmse: Vec<f64>,
363 #[serde(default, skip_serializing_if = "Vec::is_empty")]
365 pub pole_history: Vec<Vec<[f64; 2]>>,
366}
367
368impl Default for Report {
369 fn default() -> Self {
370 Self {
371 converged: false,
372 iterations: 0,
373 abs_rmse: 0.0,
374 rel_rmse: 0.0,
375 max_pole_shift: 0.0,
376 pole_shifts: Vec::new(),
377 solver_used: SolverUsed::ColPivQr,
378 svd_fallback_used: false,
379 weighted: false,
380 stable: false,
381 real_sections_valid: false,
382 restarts: 0,
383 channel_abs_rmse: Vec::new(),
384 channel_rel_rmse: Vec::new(),
385 pole_history: Vec::new(),
386 }
387 }
388}
389
390impl std::fmt::Display for Report {
391 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392 writeln!(f, "VecFit Report")?;
393 writeln!(f, " converged: {}", self.converged)?;
394 writeln!(f, " iterations: {}", self.iterations)?;
395 writeln!(f, " restarts: {}", self.restarts)?;
396 writeln!(f, " abs RMSE: {:.6e}", self.abs_rmse)?;
397 writeln!(f, " rel RMSE: {:.6e}", self.rel_rmse)?;
398 writeln!(f, " max pole shift: {:.6e}", self.max_pole_shift)?;
399 writeln!(f, " solver: {:?}", self.solver_used)?;
400 writeln!(f, " SVD fallback: {}", self.svd_fallback_used)?;
401 writeln!(f, " weighted: {}", self.weighted)?;
402 writeln!(f, " stable: {}", self.stable)?;
403 write!(f, " real sections: {}", self.real_sections_valid)?;
404 if !self.channel_abs_rmse.is_empty() {
405 writeln!(f)?;
406 writeln!(f, " per-channel abs RMSE:")?;
407 for (i, rmse) in self.channel_abs_rmse.iter().enumerate() {
408 write!(f, " ch {}: {:.6e}", i, rmse)?;
409 if i + 1 < self.channel_abs_rmse.len() {
410 writeln!(f)?;
411 }
412 }
413 }
414 if !self.channel_rel_rmse.is_empty() {
415 writeln!(f)?;
416 writeln!(f, " per-channel rel RMSE:")?;
417 for (i, rmse) in self.channel_rel_rmse.iter().enumerate() {
418 write!(f, " ch {}: {:.6e}", i, rmse)?;
419 if i + 1 < self.channel_rel_rmse.len() {
420 writeln!(f)?;
421 }
422 }
423 }
424 if !self.pole_history.is_empty() {
425 writeln!(f)?;
426 write!(
427 f,
428 " pole history: {} iterations tracked",
429 self.pole_history.len()
430 )?;
431 }
432 Ok(())
433 }
434}
435
436pub(crate) fn matrix_from_row_major_slice(
437 values: &[Complex64],
438 rows: usize,
439 cols: usize,
440) -> Mat<Complex64> {
441 Mat::from_fn(rows, cols, |row, col| values[row * cols + col])
442}
443
444pub(crate) fn pole_basis_matrix(
445 axis: &[Complex64],
446 poles: &[Complex64],
447 fit_constant: bool,
448 fit_proportional: bool,
449) -> Mat<Complex64> {
450 let cols = poles.len() + usize::from(fit_constant) + usize::from(fit_proportional);
451 Mat::from_fn(axis.len(), cols, |row, col| {
452 if col < poles.len() {
453 Complex64::new(1.0, 0.0) / (axis[row] - poles[col])
454 } else if fit_constant && col == poles.len() {
455 Complex64::new(1.0, 0.0)
456 } else {
457 axis[row]
458 }
459 })
460}
461
/// Returns `count` values spaced evenly on a logarithmic scale from `start` to `stop`.
///
/// Both bounds are expected to be positive. For `count >= 2` the first and last entries are
/// exactly `start` and `stop`; interior entries are `10^(log10(start) + t * (log10(stop) -
/// log10(start)))` with `t = idx / (count - 1)`. For `count == 1` the single entry is the
/// geometric mean `sqrt(start * stop)`. For `count == 0` the result is empty.
pub(crate) fn geometric_space(start: f64, stop: f64, count: usize) -> Vec<f64> {
    match count {
        0 => Vec::new(),
        // Take each square root separately so the product cannot overflow or underflow
        // for extreme bounds (e.g. 1e200 * 1e200 == inf).
        1 => vec![start.sqrt() * stop.sqrt()],
        _ => {
            let log_start = start.log10();
            let log_span = stop.log10() - log_start;
            let last = count - 1;
            let denom = last as f64;
            (0..count)
                .map(|idx| match idx {
                    // Pin the endpoints: the log10/powf round trip can drift by an ulp or more.
                    0 => start,
                    i if i == last => stop,
                    i => 10f64.powf(log_start + (i as f64 / denom) * log_span),
                })
                .collect()
        }
    }
}
478
479pub(crate) fn initial_poles(axis: &[Complex64], poles: usize, real_only: bool) -> Vec<Complex64> {
486 let mut magnitudes = axis
487 .iter()
488 .map(|value| value.im.abs().max(value.norm()))
489 .filter(|value| *value > 1e-15)
490 .collect::<Vec<_>>();
491 if magnitudes.is_empty() {
492 magnitudes.push(1.0);
493 }
494 magnitudes.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
495 let lo = (magnitudes[0] * 0.5).max(1e-6);
496 let hi = magnitudes[magnitudes.len() - 1] * 1.5;
497
498 if real_only || poles < 2 {
499 geometric_space(lo, hi, poles)
500 .into_iter()
501 .map(|value| Complex64::new(-value * 0.01, 0.0))
502 .collect()
503 } else {
504 let pair_count = poles / 2;
505 let has_extra = poles % 2 == 1;
506 let n_pts = pair_count + usize::from(has_extra);
507 let pts = geometric_space(lo, hi, n_pts);
508 let mut result = Vec::with_capacity(poles);
509 for (i, &beta) in pts.iter().enumerate() {
510 if has_extra && i == n_pts - 1 {
511 result.push(Complex64::new(-beta * 0.01, 0.0));
512 } else {
513 result.push(Complex64::new(-beta * 0.01, beta));
514 result.push(Complex64::new(-beta * 0.01, -beta));
515 }
516 }
517 result.truncate(poles);
518 result
519 }
520}
521
522pub(crate) fn apply_sample_weights(
523 matrix: &mut Mat<Complex64>,
524 weights: Option<&[f64]>,
525 rows_per_sample: usize,
526) {
527 if let Some(weights) = weights {
528 let rows_per_sample = rows_per_sample.max(1);
529 for row in 0..matrix.nrows() {
530 let weight = weights[row / rows_per_sample].sqrt();
531 for col in 0..matrix.ncols() {
532 matrix[(row, col)] *= weight;
533 }
534 }
535 }
536}
537
538pub(crate) fn solve_least_squares(
539 system: &Mat<Complex64>,
540 rhs: &Mat<Complex64>,
541 solver_policy: SolverPolicy,
542) -> Result<(Mat<Complex64>, SolverUsed, bool)> {
543 if system.nrows() < system.ncols() {
544 return Err(VecfitError::InvalidInput(format!(
545 "least-squares system is underdetermined ({} rows for {} unknowns); provide more samples or reduce the number of fitted terms",
546 system.nrows(),
547 system.ncols()
548 )));
549 }
550
551 match solver_policy {
552 SolverPolicy::SvdOnly => {
553 let svd = system.as_ref().thin_svd()?;
554 let solution = svd.solve_lstsq(rhs.as_ref());
555 Ok((solution, SolverUsed::Svd, false))
556 }
557 SolverPolicy::ColPivQr => {
558 let qr = system.as_ref().col_piv_qr();
559 let solution = qr.solve_lstsq(rhs.as_ref());
560 Ok((solution, SolverUsed::ColPivQr, false))
561 }
562 SolverPolicy::Auto => {
563 let qr = system.as_ref().col_piv_qr();
564 let solution = qr.solve_lstsq(rhs.as_ref());
565 let all_finite = (0..solution.nrows()).all(|row| {
566 (0..solution.ncols()).all(|col| {
567 let v = solution[(row, col)];
568 v.re.is_finite() && v.im.is_finite()
569 })
570 });
571 if !all_finite {
572 let svd = system.as_ref().thin_svd()?;
573 let fallback = svd.solve_lstsq(rhs.as_ref());
574 return Ok((fallback, SolverUsed::Svd, true));
575 }
576 Ok((solution, SolverUsed::ColPivQr, false))
577 }
578 }
579}
580
581pub(crate) fn solve_least_squares_scaled(
586 system: &Mat<Complex64>,
587 rhs: &Mat<Complex64>,
588 solver_policy: SolverPolicy,
589) -> Result<(Mat<Complex64>, SolverUsed, bool)> {
590 let cols = system.ncols();
591 let mut scaled = system.clone();
592 let mut col_norms = vec![0.0f64; cols];
593 for j in 0..cols {
594 let mut norm_sq = 0.0;
595 for i in 0..scaled.nrows() {
596 norm_sq += scaled[(i, j)].norm_sqr();
597 }
598 col_norms[j] = norm_sq.sqrt().max(1e-30);
599 let inv = 1.0 / col_norms[j];
600 for i in 0..scaled.nrows() {
601 scaled[(i, j)] *= inv;
602 }
603 }
604 let (mut solution, solver_used, fallback) = solve_least_squares(&scaled, rhs, solver_policy)?;
605 for j in 0..cols {
607 let inv = 1.0 / col_norms[j];
608 for k in 0..solution.ncols() {
609 solution[(j, k)] *= inv;
610 }
611 }
612 Ok((solution, solver_used, fallback))
613}
614
615pub(crate) fn compute_inverse_magnitude_weights(
621 values: &[Complex64],
622 samples: usize,
623 channels: usize,
624) -> Vec<f64> {
625 let mut weights = Vec::with_capacity(samples);
626 let mut max_mag = 0.0f64;
627 for k in 0..samples {
628 let mut sample_max = 0.0f64;
629 for ch in 0..channels {
630 sample_max = sample_max.max(values[k * channels + ch].norm());
631 }
632 max_mag = max_mag.max(sample_max);
633 weights.push(sample_max);
634 }
635 let floor = max_mag * 1e-8;
636 for w in &mut weights {
637 *w = 1.0 / (*w).max(floor);
638 }
639 weights
640}
641
642pub(crate) fn validate_weights(weights: Option<&[f64]>, samples: usize) -> Result<()> {
643 if let Some(weights) = weights {
644 if weights.len() != samples {
645 return Err(VecfitError::Dimension(format!(
646 "weights length {} does not match sample length {}",
647 weights.len(),
648 samples
649 )));
650 }
651 if weights
652 .iter()
653 .any(|weight| !weight.is_finite() || *weight < 0.0)
654 {
655 return Err(VecfitError::InvalidInput(
656 "weights must be finite and nonnegative".to_string(),
657 ));
658 }
659 }
660 Ok(())
661}