sparse_ir/fitters/common.rs

1//! Common utilities for fitters
2//!
3//! This module contains shared helper functions and SVD structures
4//! used by all fitter implementations.
5
6use crate::fpu_check::FpuGuard;
7use crate::gemm::GemmBackendHandle;
8use mdarray::{DTensor, DynRank, Shape, Slice, ViewMut};
9use num_complex::Complex;
10
11// ============================================================================
12// InplaceFitter trait
13// ============================================================================
14
15/// Trait for inplace evaluation and fitting operations on N-dimensional arrays.
16///
17/// Uses BLAS-style naming convention for type suffixes:
18/// - `d` = double (f64)
19/// - `z` = double complex (Complex<f64>)
20///
21/// For example:
22/// - `evaluate_nd_dd_to`: f64 input → f64 output
23/// - `evaluate_nd_zz_to`: Complex<f64> input → Complex<f64> output
24/// - `evaluate_nd_dz_to`: f64 input → Complex<f64> output
25/// - `evaluate_nd_zd_to`: Complex<f64> input → f64 output
26/// Trait for inplace evaluation and fitting operations on N-dimensional arrays.
27///
28/// All methods return `bool`:
29/// - `true` = operation succeeded
30/// - `false` = operation not supported for this fitter
31///
32/// Default implementations return `false` (not supported).
33pub trait InplaceFitter {
34    /// Number of sampling points
35    fn n_points(&self) -> usize;
36
37    /// Number of basis functions
38    fn basis_size(&self) -> usize;
39
40    /// Evaluate ND: f64 coeffs → f64 values
41    fn evaluate_nd_dd_to(
42        &self,
43        backend: Option<&GemmBackendHandle>,
44        coeffs: &Slice<f64, DynRank>,
45        dim: usize,
46        out: &mut ViewMut<'_, f64, DynRank>,
47    ) -> bool {
48        let _ = (backend, coeffs, dim, out);
49        false
50    }
51
52    /// Evaluate ND: f64 coeffs → Complex<f64> values
53    fn evaluate_nd_dz_to(
54        &self,
55        backend: Option<&GemmBackendHandle>,
56        coeffs: &Slice<f64, DynRank>,
57        dim: usize,
58        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
59    ) -> bool {
60        let _ = (backend, coeffs, dim, out);
61        false
62    }
63
64    /// Evaluate ND: Complex<f64> coeffs → f64 values
65    fn evaluate_nd_zd_to(
66        &self,
67        backend: Option<&GemmBackendHandle>,
68        coeffs: &Slice<Complex<f64>, DynRank>,
69        dim: usize,
70        out: &mut ViewMut<'_, f64, DynRank>,
71    ) -> bool {
72        let _ = (backend, coeffs, dim, out);
73        false
74    }
75
76    /// Evaluate ND: Complex<f64> coeffs → Complex<f64> values
77    fn evaluate_nd_zz_to(
78        &self,
79        backend: Option<&GemmBackendHandle>,
80        coeffs: &Slice<Complex<f64>, DynRank>,
81        dim: usize,
82        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
83    ) -> bool {
84        let _ = (backend, coeffs, dim, out);
85        false
86    }
87
88    /// Fit ND: f64 values → f64 coeffs
89    fn fit_nd_dd_to(
90        &self,
91        backend: Option<&GemmBackendHandle>,
92        values: &Slice<f64, DynRank>,
93        dim: usize,
94        out: &mut ViewMut<'_, f64, DynRank>,
95    ) -> bool {
96        let _ = (backend, values, dim, out);
97        false
98    }
99
100    /// Fit ND: f64 values → Complex<f64> coeffs
101    fn fit_nd_dz_to(
102        &self,
103        backend: Option<&GemmBackendHandle>,
104        values: &Slice<f64, DynRank>,
105        dim: usize,
106        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
107    ) -> bool {
108        let _ = (backend, values, dim, out);
109        false
110    }
111
112    /// Fit ND: Complex<f64> values → f64 coeffs
113    fn fit_nd_zd_to(
114        &self,
115        backend: Option<&GemmBackendHandle>,
116        values: &Slice<Complex<f64>, DynRank>,
117        dim: usize,
118        out: &mut ViewMut<'_, f64, DynRank>,
119    ) -> bool {
120        let _ = (backend, values, dim, out);
121        false
122    }
123
124    /// Fit ND: Complex<f64> values → Complex<f64> coeffs
125    fn fit_nd_zz_to(
126        &self,
127        backend: Option<&GemmBackendHandle>,
128        values: &Slice<Complex<f64>, DynRank>,
129        dim: usize,
130        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
131    ) -> bool {
132        let _ = (backend, values, dim, out);
133        false
134    }
135}
136
137// ============================================================================
138// Permutation helpers
139// ============================================================================
140
/// Generate a permutation that moves dimension `dim` to position 0.
///
/// The remaining dimensions keep their relative order. For example, with
/// `rank = 4` and `dim = 2` the result is `[2, 0, 1, 3]`.
///
/// # Panics
///
/// Panics if `dim >= rank`. Without this check the result would have `rank + 1`
/// entries and would not be a valid permutation of `0..rank`.
pub(crate) fn make_perm_to_front(rank: usize, dim: usize) -> Vec<usize> {
    assert!(dim < rank, "dim={} must be less than rank={}", dim, rank);
    std::iter::once(dim)
        .chain((0..rank).filter(|&i| i != dim))
        .collect()
}
155
156// ============================================================================
157// Strided copy helpers
158// ============================================================================
159
160/// Copy data from a contiguous slice to a strided view
161///
162/// This is useful for copying GEMM results back to a permuted output view.
163///
164/// # Arguments
165/// * `src` - Source slice (contiguous)
166/// * `dst` - Destination slice (may be strided)
167pub(crate) fn copy_from_contiguous<T: Copy>(
168    src: &[T],
169    dst: &mut mdarray::Slice<T, mdarray::DynRank, mdarray::Strided>,
170) {
171    assert_eq!(src.len(), dst.len(), "Source size mismatch");
172
173    // mdarray's iter_mut() returns elements in row-major order
174    for (d, s) in dst.iter_mut().zip(src.iter()) {
175        *d = *s;
176    }
177}
178
179// ============================================================================
180// Complex-Real reinterpretation helpers
181// ============================================================================
182
183/// Reinterpret a mutable Complex<f64> slice as a mutable f64 view with an extra dimension of size 2
184///
185/// Complex<f64> array `[d0, d1, ..., dN]` becomes f64 array `[d0, d1, ..., dN, 2]`
186/// where the last dimension contains [re, im] pairs.
187#[allow(dead_code)]
188pub(crate) fn complex_slice_mut_as_real<'a>(
189    out: &'a mut Slice<Complex<f64>, DynRank>,
190) -> mdarray::ViewMut<'a, f64, DynRank, mdarray::Dense> {
191    // Build new shape: [..., 2]
192    let mut new_shape: Vec<usize> = Vec::with_capacity(out.rank() + 1);
193    out.shape().with_dims(|dims| {
194        for d in dims {
195            new_shape.push(*d);
196        }
197    });
198    new_shape.push(2);
199
200    unsafe {
201        let shape: DynRank = Shape::from_dims(&new_shape[..]);
202        let mapping = mdarray::DenseMapping::new(shape);
203        mdarray::ViewMut::new_unchecked(out.as_mut_ptr() as *mut f64, mapping)
204    }
205}
206
207// ============================================================================
208// SVD structures
209// ============================================================================
210
211/// SVD decomposition for real matrices
212pub(crate) struct RealSVD {
213    pub ut: DTensor<f64, 2>, // (min_dim, n_rows) - U^T
214    pub s: Vec<f64>,         // (min_dim,)
215    pub v: DTensor<f64, 2>,  // (n_cols, min_dim) - V (transpose of V^T)
216}
217
218impl RealSVD {
219    pub fn new(u: DTensor<f64, 2>, s: Vec<f64>, vt: DTensor<f64, 2>) -> Self {
220        // Check dimensions
221        let (_, u_cols) = *u.shape();
222        let (vt_rows, _) = *vt.shape();
223        let min_dim = s.len();
224
225        assert_eq!(
226            u_cols, min_dim,
227            "u.cols()={} must equal s.len()={}",
228            u_cols, min_dim
229        );
230        assert_eq!(
231            vt_rows, min_dim,
232            "vt.rows()={} must equal s.len()={}",
233            vt_rows, min_dim
234        );
235
236        // Create ut and v from u and vt
237        let ut = u.transpose().to_tensor(); // (min_dim, n_rows)
238        let v = vt.transpose().to_tensor(); // (n_cols, min_dim)
239
240        // Verify v.cols() == s.len() (v.shape().1 is the second dimension, which is min_dim)
241        assert_eq!(
242            v.shape().1,
243            min_dim,
244            "v.cols()={} must equal s.len()={}",
245            v.shape().1,
246            min_dim
247        );
248
249        Self { ut, s, v }
250    }
251}
252
253/// SVD decomposition for complex matrices
254pub(crate) struct ComplexSVD {
255    pub ut: DTensor<Complex<f64>, 2>, // (min_dim, n_rows) - U^H
256    pub s: Vec<f64>,                  // (min_dim,) - singular values are real
257    pub v: DTensor<Complex<f64>, 2>,  // (n_cols, min_dim) - V (transpose of V^T)
258}
259
260impl ComplexSVD {
261    pub fn new(u: DTensor<Complex<f64>, 2>, s: Vec<f64>, vt: DTensor<Complex<f64>, 2>) -> Self {
262        // Check dimensions
263        let (u_rows, u_cols) = *u.shape();
264        let (vt_rows, _) = *vt.shape();
265        let min_dim = s.len();
266
267        assert_eq!(
268            u_cols, min_dim,
269            "u.cols()={} must equal s.len()={}",
270            u_cols, min_dim
271        );
272        assert_eq!(
273            vt_rows, min_dim,
274            "vt.rows()={} must equal s.len()={}",
275            vt_rows, min_dim
276        );
277
278        // Create ut (U^H, conjugate transpose) and v from u and vt
279        let ut = DTensor::<Complex<f64>, 2>::from_fn([u_cols, u_rows], |idx| {
280            u[[idx[1], idx[0]]].conj() // conjugate transpose: U^H
281        });
282        let v = vt.transpose().to_tensor(); // (n_cols, min_dim)
283
284        // Verify v.cols() == s.len() (v.shape().1 is the second dimension, which is min_dim)
285        assert_eq!(
286            v.shape().1,
287            min_dim,
288            "v.cols()={} must equal s.len()={}",
289            v.shape().1,
290            min_dim
291        );
292
293        Self { ut, s, v }
294    }
295}
296
297// ============================================================================
298// SVD computation functions
299// ============================================================================
300
301/// Compute SVD of a real matrix using mdarray-linalg
302pub(crate) fn compute_real_svd(matrix: &DTensor<f64, 2>) -> RealSVD {
303    use mdarray_linalg::prelude::SVD;
304    use mdarray_linalg::svd::SVDDecomp;
305    use mdarray_linalg_faer::Faer;
306
307    // Protect FPU state during SVD computation (required for Intel Fortran compatibility)
308    let _guard = FpuGuard::new_protect_computation();
309
310    let mut a = matrix.clone();
311    let SVDDecomp { u, s, vt } = Faer.svd(&mut *a).expect("SVD computation failed");
312
313    // Extract singular values from first row
314    let min_dim = s.shape().0.min(s.shape().1);
315    let s_vec: Vec<f64> = (0..min_dim).map(|i| s[[0, i]]).collect();
316
317    // Trim u and vt to min_dim
318    // u: (n_rows, n_cols) -> (n_rows, min_dim) - take first min_dim columns
319    // vt: (n_rows, n_cols) -> (min_dim, n_cols) - take first min_dim rows
320    let u_trimmed = u.view(.., ..min_dim).to_tensor();
321    let vt_trimmed = vt.view(..min_dim, ..).to_tensor();
322
323    RealSVD::new(u_trimmed, s_vec, vt_trimmed)
324}
325
326/// Compute SVD of a complex matrix directly
327pub(crate) fn compute_complex_svd(matrix: &DTensor<Complex<f64>, 2>) -> ComplexSVD {
328    use mdarray_linalg::prelude::SVD;
329    use mdarray_linalg::svd::SVDDecomp;
330    use mdarray_linalg_faer::Faer;
331
332    // Protect FPU state during SVD computation (required for Intel Fortran compatibility)
333    let _guard = FpuGuard::new_protect_computation();
334
335    // Use matrix directly (Complex<f64> is compatible with faer's c64)
336    let mut matrix_c64 = matrix.clone();
337
338    // Compute complex SVD directly
339    let SVDDecomp { u, s, vt } = Faer
340        .svd(&mut *matrix_c64)
341        .expect("Complex SVD computation failed");
342
343    // Extract singular values from first row (they are real even though stored as Complex)
344    let min_dim = s.shape().0.min(s.shape().1);
345    let s_vec: Vec<f64> = (0..min_dim).map(|i| s[[0, i]].re).collect();
346
347    // Trim u and vt to min_dim
348    // u: (n_rows, n_cols) -> (n_rows, min_dim) - take first min_dim columns
349    // vt: (n_rows, n_cols) -> (min_dim, n_cols) - take first min_dim rows
350    let u_trimmed = u.view(.., ..min_dim).to_tensor();
351    let vt_trimmed = vt.view(..min_dim, ..).to_tensor();
352
353    ComplexSVD::new(u_trimmed, s_vec, vt_trimmed)
354}
355
356// ============================================================================
357// Complex-Real conversion helpers
358// ============================================================================
359
360/// Combine real and imaginary parts into complex tensor
361pub(crate) fn combine_complex(
362    re: &DTensor<f64, 2>,
363    im: &DTensor<f64, 2>,
364) -> DTensor<Complex<f64>, 2> {
365    let (n_points, extra_size) = *re.shape();
366    DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
367        Complex::new(re[idx], im[idx])
368    })
369}
370
371/// Extract real parts from complex tensor (for coefficients)
372pub(crate) fn extract_real_parts_coeffs(coeffs_2d: &DTensor<Complex<f64>, 2>) -> DTensor<f64, 2> {
373    let (basis_size, extra_size) = *coeffs_2d.shape();
374    DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| coeffs_2d[idx].re)
375}