// sparse_ir/fitters/common.rs

//! Common utilities for fitters.
//!
//! This module contains shared helper functions and SVD structures
//! used by all fitter implementations.

6use crate::fpu_check::FpuGuard;
7use crate::gemm::GemmBackendHandle;
8use mdarray::{DTensor, DynRank, Shape, Slice, ViewMut};
9use num_complex::Complex;

// ============================================================================
// InplaceFitter trait
// ============================================================================

15/// Trait for inplace evaluation and fitting operations on N-dimensional arrays.
16///
17/// Uses BLAS-style naming convention for type suffixes:
18/// - `d` = double (f64)
19/// - `z` = double complex (Complex<f64>)
20///
21/// For example:
22/// - `evaluate_nd_dd_to`: f64 input → f64 output
23/// - `evaluate_nd_zz_to`: Complex<f64> input → Complex<f64> output
24/// - `evaluate_nd_dz_to`: f64 input → Complex<f64> output
25/// - `evaluate_nd_zd_to`: Complex<f64> input → f64 output
26/// Trait for inplace evaluation and fitting operations on N-dimensional arrays.
27///
28/// All methods return `bool`:
29/// - `true` = operation succeeded
30/// - `false` = operation not supported for this fitter
31///
32/// Default implementations return `false` (not supported).
33pub trait InplaceFitter {
34    /// Number of sampling points
35    fn n_points(&self) -> usize;
36
37    /// Number of basis functions
38    fn basis_size(&self) -> usize;
39
40    /// Evaluate ND: f64 coeffs → f64 values
41    fn evaluate_nd_dd_to(
42        &self,
43        backend: Option<&GemmBackendHandle>,
44        coeffs: &Slice<f64, DynRank>,
45        dim: usize,
46        out: &mut ViewMut<'_, f64, DynRank>,
47    ) -> bool {
48        let _ = (backend, coeffs, dim, out);
49        false
50    }
51
52    /// Evaluate ND: f64 coeffs → Complex<f64> values
53    fn evaluate_nd_dz_to(
54        &self,
55        backend: Option<&GemmBackendHandle>,
56        coeffs: &Slice<f64, DynRank>,
57        dim: usize,
58        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
59    ) -> bool {
60        let _ = (backend, coeffs, dim, out);
61        false
62    }
63
64    /// Evaluate ND: Complex<f64> coeffs → f64 values
65    fn evaluate_nd_zd_to(
66        &self,
67        backend: Option<&GemmBackendHandle>,
68        coeffs: &Slice<Complex<f64>, DynRank>,
69        dim: usize,
70        out: &mut ViewMut<'_, f64, DynRank>,
71    ) -> bool {
72        let _ = (backend, coeffs, dim, out);
73        false
74    }
75
76    /// Evaluate ND: Complex<f64> coeffs → Complex<f64> values
77    fn evaluate_nd_zz_to(
78        &self,
79        backend: Option<&GemmBackendHandle>,
80        coeffs: &Slice<Complex<f64>, DynRank>,
81        dim: usize,
82        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
83    ) -> bool {
84        let _ = (backend, coeffs, dim, out);
85        false
86    }
87
88    /// Fit ND: f64 values → f64 coeffs
89    fn fit_nd_dd_to(
90        &self,
91        backend: Option<&GemmBackendHandle>,
92        values: &Slice<f64, DynRank>,
93        dim: usize,
94        out: &mut ViewMut<'_, f64, DynRank>,
95    ) -> bool {
96        let _ = (backend, values, dim, out);
97        false
98    }
99
100    /// Fit ND: f64 values → Complex<f64> coeffs
101    fn fit_nd_dz_to(
102        &self,
103        backend: Option<&GemmBackendHandle>,
104        values: &Slice<f64, DynRank>,
105        dim: usize,
106        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
107    ) -> bool {
108        let _ = (backend, values, dim, out);
109        false
110    }
111
112    /// Fit ND: Complex<f64> values → f64 coeffs
113    fn fit_nd_zd_to(
114        &self,
115        backend: Option<&GemmBackendHandle>,
116        values: &Slice<Complex<f64>, DynRank>,
117        dim: usize,
118        out: &mut ViewMut<'_, f64, DynRank>,
119    ) -> bool {
120        let _ = (backend, values, dim, out);
121        false
122    }
123
124    /// Fit ND: Complex<f64> values → Complex<f64> coeffs
125    fn fit_nd_zz_to(
126        &self,
127        backend: Option<&GemmBackendHandle>,
128        values: &Slice<Complex<f64>, DynRank>,
129        dim: usize,
130        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
131    ) -> bool {
132        let _ = (backend, values, dim, out);
133        false
134    }
135}

// ============================================================================
// Permutation helpers
// ============================================================================

/// Generate the permutation that moves dimension `dim` to position 0.
///
/// The remaining axes keep their relative order. For example, with
/// `rank = 4` and `dim = 2` the result is `[2, 0, 1, 3]`.
///
/// # Panics
///
/// Panics if `dim >= rank`: such a `dim` names no axis, and the result would
/// not be a valid permutation of `0..rank`.
pub(crate) fn make_perm_to_front(rank: usize, dim: usize) -> Vec<usize> {
    assert!(
        dim < rank,
        "dim={} is out of range for an array of rank {}",
        dim,
        rank
    );
    std::iter::once(dim)
        .chain((0..rank).filter(|&axis| axis != dim))
        .collect()
}

// ============================================================================
// Strided copy helpers
// ============================================================================

160/// Copy data from a strided view to a contiguous slice
161///
162/// This is useful for copying permuted views to a contiguous buffer
163/// before performing GEMM operations.
164///
165/// # Arguments
166/// * `src` - Source slice (may be strided)
167/// * `dst` - Destination slice (must be contiguous, same total elements)
168pub(crate) fn copy_to_contiguous<T: Copy>(
169    src: &mdarray::Slice<T, mdarray::DynRank, mdarray::Strided>,
170    dst: &mut [T],
171) {
172    assert_eq!(dst.len(), src.len(), "Destination size mismatch");
173
174    // mdarray's iter() returns elements in row-major order
175    for (d, s) in dst.iter_mut().zip(src.iter()) {
176        *d = *s;
177    }
178}
179
180/// Copy data from a contiguous slice to a strided view
181///
182/// This is useful for copying GEMM results back to a permuted output view.
183///
184/// # Arguments
185/// * `src` - Source slice (contiguous)
186/// * `dst` - Destination slice (may be strided)
187pub(crate) fn copy_from_contiguous<T: Copy>(
188    src: &[T],
189    dst: &mut mdarray::Slice<T, mdarray::DynRank, mdarray::Strided>,
190) {
191    assert_eq!(src.len(), dst.len(), "Source size mismatch");
192
193    // mdarray's iter_mut() returns elements in row-major order
194    for (d, s) in dst.iter_mut().zip(src.iter()) {
195        *d = *s;
196    }
197}

// ============================================================================
// Complex-Real reinterpretation helpers
// ============================================================================

203/// Reinterpret a Complex<f64> slice as a f64 view with an extra dimension of size 2
204///
205/// Complex<f64> array `[d0, d1, ..., dN]` becomes f64 array `[d0, d1, ..., dN, 2]`
206/// where the last dimension contains [re, im] pairs.
207pub(crate) fn complex_slice_as_real<'a>(
208    coeffs: &'a Slice<Complex<f64>, DynRank>,
209) -> mdarray::View<'a, f64, DynRank, mdarray::Dense> {
210    // Build new shape: [..., 2]
211    let mut new_shape: Vec<usize> = Vec::with_capacity(coeffs.rank() + 1);
212    coeffs.shape().with_dims(|dims| {
213        for d in dims {
214            new_shape.push(*d);
215        }
216    });
217    new_shape.push(2);
218
219    unsafe {
220        let shape: DynRank = Shape::from_dims(&new_shape[..]);
221        let mapping = mdarray::DenseMapping::new(shape);
222        mdarray::View::new_unchecked(coeffs.as_ptr() as *const f64, mapping)
223    }
224}
225
226/// Reinterpret a mutable Complex<f64> slice as a mutable f64 view with an extra dimension of size 2
227///
228/// Complex<f64> array `[d0, d1, ..., dN]` becomes f64 array `[d0, d1, ..., dN, 2]`
229/// where the last dimension contains [re, im] pairs.
230#[allow(dead_code)]
231pub(crate) fn complex_slice_mut_as_real<'a>(
232    out: &'a mut Slice<Complex<f64>, DynRank>,
233) -> mdarray::ViewMut<'a, f64, DynRank, mdarray::Dense> {
234    // Build new shape: [..., 2]
235    let mut new_shape: Vec<usize> = Vec::with_capacity(out.rank() + 1);
236    out.shape().with_dims(|dims| {
237        for d in dims {
238            new_shape.push(*d);
239        }
240    });
241    new_shape.push(2);
242
243    unsafe {
244        let shape: DynRank = Shape::from_dims(&new_shape[..]);
245        let mapping = mdarray::DenseMapping::new(shape);
246        mdarray::ViewMut::new_unchecked(out.as_mut_ptr() as *mut f64, mapping)
247    }
248}

// ============================================================================
// SVD structures
// ============================================================================

254/// SVD decomposition for real matrices
255pub(crate) struct RealSVD {
256    pub ut: DTensor<f64, 2>, // (min_dim, n_rows) - U^T
257    pub s: Vec<f64>,         // (min_dim,)
258    pub v: DTensor<f64, 2>,  // (n_cols, min_dim) - V (transpose of V^T)
259}
260
261impl RealSVD {
262    pub fn new(u: DTensor<f64, 2>, s: Vec<f64>, vt: DTensor<f64, 2>) -> Self {
263        // Check dimensions
264        let (_, u_cols) = *u.shape();
265        let (vt_rows, _) = *vt.shape();
266        let min_dim = s.len();
267
268        assert_eq!(
269            u_cols, min_dim,
270            "u.cols()={} must equal s.len()={}",
271            u_cols, min_dim
272        );
273        assert_eq!(
274            vt_rows, min_dim,
275            "vt.rows()={} must equal s.len()={}",
276            vt_rows, min_dim
277        );
278
279        // Create ut and v from u and vt
280        let ut = u.transpose().to_tensor(); // (min_dim, n_rows)
281        let v = vt.transpose().to_tensor(); // (n_cols, min_dim)
282
283        // Verify v.cols() == s.len() (v.shape().1 is the second dimension, which is min_dim)
284        assert_eq!(
285            v.shape().1,
286            min_dim,
287            "v.cols()={} must equal s.len()={}",
288            v.shape().1,
289            min_dim
290        );
291
292        Self { ut, s, v }
293    }
294}
295
296/// SVD decomposition for complex matrices
297pub(crate) struct ComplexSVD {
298    pub ut: DTensor<Complex<f64>, 2>, // (min_dim, n_rows) - U^H
299    pub s: Vec<f64>,                  // (min_dim,) - singular values are real
300    pub v: DTensor<Complex<f64>, 2>,  // (n_cols, min_dim) - V (transpose of V^T)
301}
302
303impl ComplexSVD {
304    pub fn new(u: DTensor<Complex<f64>, 2>, s: Vec<f64>, vt: DTensor<Complex<f64>, 2>) -> Self {
305        // Check dimensions
306        let (u_rows, u_cols) = *u.shape();
307        let (vt_rows, _) = *vt.shape();
308        let min_dim = s.len();
309
310        assert_eq!(
311            u_cols, min_dim,
312            "u.cols()={} must equal s.len()={}",
313            u_cols, min_dim
314        );
315        assert_eq!(
316            vt_rows, min_dim,
317            "vt.rows()={} must equal s.len()={}",
318            vt_rows, min_dim
319        );
320
321        // Create ut (U^H, conjugate transpose) and v from u and vt
322        let ut = DTensor::<Complex<f64>, 2>::from_fn([u_cols, u_rows], |idx| {
323            u[[idx[1], idx[0]]].conj() // conjugate transpose: U^H
324        });
325        let v = vt.transpose().to_tensor(); // (n_cols, min_dim)
326
327        // Verify v.cols() == s.len() (v.shape().1 is the second dimension, which is min_dim)
328        assert_eq!(
329            v.shape().1,
330            min_dim,
331            "v.cols()={} must equal s.len()={}",
332            v.shape().1,
333            min_dim
334        );
335
336        Self { ut, s, v }
337    }
338}

// ============================================================================
// SVD computation functions
// ============================================================================

344/// Compute SVD of a real matrix using mdarray-linalg
345pub(crate) fn compute_real_svd(matrix: &DTensor<f64, 2>) -> RealSVD {
346    use mdarray_linalg::prelude::SVD;
347    use mdarray_linalg::svd::SVDDecomp;
348    use mdarray_linalg_faer::Faer;
349
350    // Protect FPU state during SVD computation (required for Intel Fortran compatibility)
351    let _guard = FpuGuard::new_protect_computation();
352
353    let mut a = matrix.clone();
354    let SVDDecomp { u, s, vt } = Faer.svd(&mut *a).expect("SVD computation failed");
355
356    // Extract singular values from first row
357    let min_dim = s.shape().0.min(s.shape().1);
358    let s_vec: Vec<f64> = (0..min_dim).map(|i| s[[0, i]]).collect();
359
360    // Trim u and vt to min_dim
361    // u: (n_rows, n_cols) -> (n_rows, min_dim) - take first min_dim columns
362    // vt: (n_rows, n_cols) -> (min_dim, n_cols) - take first min_dim rows
363    let u_trimmed = u.view(.., ..min_dim).to_tensor();
364    let vt_trimmed = vt.view(..min_dim, ..).to_tensor();
365
366    RealSVD::new(u_trimmed, s_vec, vt_trimmed)
367}
368
369/// Compute SVD of a complex matrix directly
370pub(crate) fn compute_complex_svd(matrix: &DTensor<Complex<f64>, 2>) -> ComplexSVD {
371    use mdarray_linalg::prelude::SVD;
372    use mdarray_linalg::svd::SVDDecomp;
373    use mdarray_linalg_faer::Faer;
374
375    // Protect FPU state during SVD computation (required for Intel Fortran compatibility)
376    let _guard = FpuGuard::new_protect_computation();
377
378    // Use matrix directly (Complex<f64> is compatible with faer's c64)
379    let mut matrix_c64 = matrix.clone();
380
381    // Compute complex SVD directly
382    let SVDDecomp { u, s, vt } = Faer
383        .svd(&mut *matrix_c64)
384        .expect("Complex SVD computation failed");
385
386    // Extract singular values from first row (they are real even though stored as Complex)
387    let min_dim = s.shape().0.min(s.shape().1);
388    let s_vec: Vec<f64> = (0..min_dim).map(|i| s[[0, i]].re).collect();
389
390    // Trim u and vt to min_dim
391    // u: (n_rows, n_cols) -> (n_rows, min_dim) - take first min_dim columns
392    // vt: (n_rows, n_cols) -> (min_dim, n_cols) - take first min_dim rows
393    let u_trimmed = u.view(.., ..min_dim).to_tensor();
394    let vt_trimmed = vt.view(..min_dim, ..).to_tensor();
395
396    ComplexSVD::new(u_trimmed, s_vec, vt_trimmed)
397}

// ============================================================================
// Complex-Real conversion helpers
// ============================================================================

403/// Combine real and imaginary parts into complex tensor
404pub(crate) fn combine_complex(
405    re: &DTensor<f64, 2>,
406    im: &DTensor<f64, 2>,
407) -> DTensor<Complex<f64>, 2> {
408    let (n_points, extra_size) = *re.shape();
409    DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
410        Complex::new(re[idx], im[idx])
411    })
412}
413
414/// Extract real parts from complex tensor (for coefficients)
415pub(crate) fn extract_real_parts_coeffs(coeffs_2d: &DTensor<Complex<f64>, 2>) -> DTensor<f64, 2> {
416    let (basis_size, extra_size) = *coeffs_2d.shape();
417    DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| coeffs_2d[idx].re)
418}