sparse_ir/
sampling.rs

1//! Sparse sampling in imaginary time
2//!
3//! This module provides `TauSampling` for transforming between IR basis coefficients
4//! and values at sparse sampling points in imaginary time.
5
6use crate::gemm::{GemmBackendHandle, matmul_par};
7use crate::traits::StatisticsType;
8use mdarray::{DTensor, DynRank, Shape, Tensor};
9
10/// Move axis from position `src` to position `dst`
11///
12/// This is equivalent to numpy.moveaxis or libsparseir's movedim.
13/// It creates a permutation array that moves the specified axis.
14///
15/// # Arguments
16/// * `arr` - Input tensor
17/// * `src` - Source axis position
18/// * `dst` - Destination axis position
19///
20/// # Returns
21/// Tensor with axes permuted
22///
23/// # Example
24/// ```ignore
25/// // For a 4D tensor with shape (2, 3, 4, 5)
26/// // movedim(arr, 0, 2) moves axis 0 to position 2
27/// // Result shape: (3, 4, 2, 5) with axes permuted as [1, 2, 0, 3]
28/// ```
29pub fn movedim<T: Clone>(arr: &Tensor<T, DynRank>, src: usize, dst: usize) -> Tensor<T, DynRank> {
30    if src == dst {
31        return arr.clone();
32    }
33
34    let rank = arr.rank();
35    assert!(
36        src < rank,
37        "src axis {} out of bounds for rank {}",
38        src,
39        rank
40    );
41    assert!(
42        dst < rank,
43        "dst axis {} out of bounds for rank {}",
44        dst,
45        rank
46    );
47
48    // Generate permutation: move src to dst position
49    let mut perm = Vec::with_capacity(rank);
50    let mut pos = 0;
51    for i in 0..rank {
52        if i == dst {
53            perm.push(src);
54        } else {
55            // Skip src position
56            if pos == src {
57                pos += 1;
58            }
59            perm.push(pos);
60            pos += 1;
61        }
62    }
63
64    arr.permute(&perm[..]).to_tensor()
65}
66
/// Sparse sampling in imaginary time
///
/// Allows transformation between the IR basis and a set of sampling points
/// in imaginary time (τ).
///
/// `S` is the statistics marker type; no value of `S` is stored.
pub struct TauSampling<S>
where
    S: StatisticsType,
{
    /// Sampling points in imaginary time τ.
    ///
    /// `with_sampling_points` accepts points anywhere in [-β, β];
    /// `from_matrix` stores the points without a range check.
    sampling_points: Vec<f64>,

    /// Real matrix fitter built from the sampling matrix; evaluation and
    /// fitting are delegated to it
    fitter: crate::fitter::RealMatrixFitter,

    /// Zero-sized marker tying this object to the statistics type `S`
    _phantom: std::marker::PhantomData<S>,
}
84
85impl<S> TauSampling<S>
86where
87    S: StatisticsType,
88{
89    /// Create a new TauSampling with default sampling points
90    ///
91    /// The default sampling points are chosen as the extrema of the highest-order
92    /// basis function, which gives near-optimal conditioning.
93    /// SVD is computed lazily on first call to `fit` or `fit_nd`.
94    ///
95    /// # Arguments
96    /// * `basis` - Any basis implementing the `Basis` trait
97    ///
98    /// # Returns
99    /// A new TauSampling object
100    pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
101    where
102        S: 'static,
103    {
104        let sampling_points = basis.default_tau_sampling_points();
105        Self::with_sampling_points(basis, sampling_points)
106    }
107
108    /// Create a new TauSampling with custom sampling points
109    ///
110    /// SVD is computed lazily on first call to `fit` or `fit_nd`.
111    ///
112    /// # Arguments
113    /// * `basis` - Any basis implementing the `Basis` trait
114    /// * `sampling_points` - Custom sampling points in τ ∈ [-β, β]
115    ///
116    /// # Returns
117    /// A new TauSampling object
118    ///
119    /// # Panics
120    /// Panics if `sampling_points` is empty or if any point is outside [-β, β]
121    pub fn with_sampling_points(
122        basis: &impl crate::basis_trait::Basis<S>,
123        sampling_points: Vec<f64>,
124    ) -> Self
125    where
126        S: 'static,
127    {
128        assert!(!sampling_points.is_empty(), "No sampling points given");
129        assert!(
130            basis.size() <= sampling_points.len(),
131            "The number of sampling points must be greater than or equal to the basis size"
132        );
133
134        let beta = basis.beta();
135        for &tau in &sampling_points {
136            assert!(
137                tau >= -beta && tau <= beta,
138                "Sampling point τ={} is outside [-β, β]",
139                tau
140            );
141        }
142
143        // Compute sampling matrix: A[i, l] = u_l(τ_i)
144        // Use Basis trait's evaluate_tau method
145        let matrix = basis.evaluate_tau(&sampling_points);
146
147        // Create fitter
148        let fitter = crate::fitter::RealMatrixFitter::new(matrix);
149
150        Self {
151            sampling_points,
152            fitter,
153            _phantom: std::marker::PhantomData,
154        }
155    }
156
157    /// Create a new TauSampling with custom sampling points and pre-computed matrix
158    ///
159    /// This constructor is useful when the sampling matrix is already computed
160    /// (e.g., from external sources or for testing).
161    ///
162    /// # Arguments
163    /// * `sampling_points` - Sampling points in τ ∈ [-β, β]
164    /// * `matrix` - Pre-computed sampling matrix (n_points × basis_size)
165    ///
166    /// # Returns
167    /// A new TauSampling object
168    ///
169    /// # Panics
170    /// Panics if `sampling_points` is empty or if matrix dimensions don't match
171    pub fn from_matrix(sampling_points: Vec<f64>, matrix: DTensor<f64, 2>) -> Self {
172        assert!(!sampling_points.is_empty(), "No sampling points given");
173        assert_eq!(
174            matrix.shape().0,
175            sampling_points.len(),
176            "Matrix rows ({}) must match number of sampling points ({})",
177            matrix.shape().0,
178            sampling_points.len()
179        );
180
181        let fitter = crate::fitter::RealMatrixFitter::new(matrix);
182
183        Self {
184            sampling_points,
185            fitter,
186            _phantom: std::marker::PhantomData,
187        }
188    }
189
    /// Get the sampling points
    ///
    /// Returns the τ values as they were stored at construction.
    pub fn sampling_points(&self) -> &[f64] {
        &self.sampling_points
    }

    /// Get the number of sampling points
    ///
    /// The value is reported by the fitter (`fitter.n_points()`), not taken
    /// from the length of the stored sampling-point vector.
    pub fn n_sampling_points(&self) -> usize {
        self.fitter.n_points()
    }

    /// Get the basis size
    ///
    /// Reported by the fitter (`fitter.basis_size()`); presumably the number of
    /// columns of the sampling matrix — confirm in `RealMatrixFitter`.
    pub fn basis_size(&self) -> usize {
        self.fitter.basis_size()
    }

    /// Get the sampling matrix
    ///
    /// Borrows the matrix held by the fitter.
    pub fn matrix(&self) -> &DTensor<f64, 2> {
        &self.fitter.matrix
    }
209
    /// Evaluate basis coefficients at sampling points
    ///
    /// Computes g(τ_i) = Σ_l a_l * u_l(τ_i) for all sampling points by
    /// delegating to the fitter's `evaluate`, with no explicit GEMM backend
    /// (`None` is passed).
    ///
    /// # Arguments
    /// * `coeffs` - Basis coefficients (length = basis_size)
    ///
    /// # Returns
    /// Values at sampling points (length = n_sampling_points)
    ///
    /// # Panics
    /// Expected to panic if `coeffs.len() != basis_size`; the check lives in
    /// `RealMatrixFitter::evaluate`, which is not shown in this file.
    pub fn evaluate(&self, coeffs: &[f64]) -> Vec<f64> {
        self.fitter.evaluate(None, coeffs)
    }
225
226    /// Internal generic evaluate_nd implementation
227    fn evaluate_nd_impl<T>(
228        &self,
229        backend: Option<&GemmBackendHandle>,
230        coeffs: &Tensor<T, DynRank>,
231        dim: usize,
232    ) -> Tensor<T, DynRank>
233    where
234        T: num_complex::ComplexFloat + faer_traits::ComplexField + 'static + From<f64> + Copy,
235    {
236        let rank = coeffs.rank();
237        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
238
239        let basis_size = self.basis_size();
240        let target_dim_size = coeffs.shape().dim(dim);
241
242        // Check that the target dimension matches basis_size
243        assert_eq!(
244            target_dim_size, basis_size,
245            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
246            dim, target_dim_size, basis_size
247        );
248
249        // 1. Move target dimension to position 0
250        let coeffs_dim0 = movedim(coeffs, dim, 0);
251
252        // 2. Reshape to 2D: (basis_size, extra_size)
253        let extra_size: usize = coeffs_dim0.len() / basis_size;
254
255        // Convert DynRank to fixed Rank<2> for matmul_par
256        let coeffs_2d_dyn = coeffs_dim0
257            .reshape(&[basis_size, extra_size][..])
258            .to_tensor();
259        let coeffs_2d = DTensor::<T, 2>::from_fn([basis_size, extra_size], |idx| {
260            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
261        });
262
263        // 3. Matrix multiply: result = A * coeffs
264        //    A is real, convert to type T
265        let n_points = self.n_sampling_points();
266        let matrix_t = DTensor::<T, 2>::from_fn(*self.fitter.matrix.shape(), |idx| {
267            self.fitter.matrix[idx].into()
268        });
269        let result_2d = matmul_par(&matrix_t, &coeffs_2d, backend);
270
271        // 4. Reshape back to N-D with n_points at position 0
272        let mut result_shape = vec![n_points];
273        coeffs_dim0.shape().with_dims(|dims| {
274            for i in 1..dims.len() {
275                result_shape.push(dims[i]);
276            }
277        });
278
279        // Convert DTensor<T, 2> to DynRank using into_dyn()
280        let result_2d_dyn = result_2d.into_dyn();
281        let result_dim0 = result_2d_dyn.reshape(&result_shape[..]).to_tensor();
282
283        // 5. Move dimension back to original position
284        movedim(&result_dim0, 0, dim)
285    }
286
    /// Evaluate basis coefficients at sampling points (N-dimensional)
    ///
    /// Evaluates along the specified dimension, keeping other dimensions intact.
    /// Supports both real (`f64`) and complex (`Complex<f64>`) coefficients.
    /// This is a thin public wrapper that forwards to `evaluate_nd_impl`.
    ///
    /// # Type Parameters
    /// * `T` - Element type (f64 or Complex<f64>)
    ///
    /// # Arguments
    /// * `backend` - Optional GEMM backend handle, forwarded to the matrix
    ///   multiplication; pass `None` when no explicit backend is wanted
    /// * `coeffs` - N-dimensional array with `coeffs.shape().dim(dim) == basis_size`
    /// * `dim` - Dimension along which to evaluate (0-indexed)
    ///
    /// # Returns
    /// N-dimensional array with `result.shape().dim(dim) == n_sampling_points`
    ///
    /// # Panics
    /// Panics if `coeffs.shape().dim(dim) != basis_size` or if `dim >= rank`
    ///
    /// # Example
    /// ```ignore
    /// use num_complex::Complex;
    /// use mdarray::tensor;
    ///
    /// // Real coefficients
    /// let values_real = sampling.evaluate_nd::<f64>(None, &coeffs_real, 0);
    ///
    /// // Complex coefficients
    /// let values_complex = sampling.evaluate_nd::<Complex<f64>>(None, &coeffs_complex, 0);
    /// ```
    pub fn evaluate_nd<T>(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Tensor<T, DynRank>,
        dim: usize,
    ) -> Tensor<T, DynRank>
    where
        T: num_complex::ComplexFloat + faer_traits::ComplexField + 'static + From<f64> + Copy,
    {
        self.evaluate_nd_impl(backend, coeffs, dim)
    }
327
328    /// Internal generic fit_nd implementation
329    ///
330    /// Delegates to fitter for real values, fits real/imaginary parts separately for complex values
331    fn fit_nd_impl<T>(
332        &self,
333        backend: Option<&GemmBackendHandle>,
334        values: &Tensor<T, DynRank>,
335        dim: usize,
336    ) -> Tensor<T, DynRank>
337    where
338        T: num_complex::ComplexFloat
339            + faer_traits::ComplexField
340            + 'static
341            + From<f64>
342            + Copy
343            + Default,
344    {
345        use num_complex::Complex;
346
347        let rank = values.rank();
348        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
349
350        let n_points = self.n_sampling_points();
351        let basis_size = self.basis_size();
352        let target_dim_size = values.shape().dim(dim);
353
354        assert_eq!(
355            target_dim_size, n_points,
356            "values.shape().dim({}) = {} must equal n_sampling_points = {}",
357            dim, target_dim_size, n_points
358        );
359
360        // 1. Move target dimension to position 0
361        let values_dim0 = movedim(values, dim, 0);
362
363        // 2. Reshape to 2D: (n_points, extra_size)
364        let extra_size: usize = values_dim0.len() / n_points;
365        let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
366
367        // 3. Convert to DTensor<T, 2> and fit using fitter's 2D methods
368        // Use type introspection to dispatch between real and complex
369        use std::any::TypeId;
370        let is_real = TypeId::of::<T>() == TypeId::of::<f64>();
371
372        let coeffs_2d = if is_real {
373            // Real case: convert to f64 tensor and fit
374            let values_2d_f64 = DTensor::<f64, 2>::from_fn([n_points, extra_size], |idx| unsafe {
375                *(&values_2d_dyn[&[idx[0], idx[1]][..]] as *const T as *const f64)
376            });
377            let coeffs_2d_f64 = self.fitter.fit_2d(backend, &values_2d_f64);
378            // Convert back to T
379            DTensor::<T, 2>::from_fn(*coeffs_2d_f64.shape(), |idx| unsafe {
380                *(&coeffs_2d_f64[idx] as *const f64 as *const T)
381            })
382        } else {
383            // Complex case: convert to Complex<f64> tensor and fit
384            let values_2d_c64 =
385                DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| unsafe {
386                    *(&values_2d_dyn[&[idx[0], idx[1]][..]] as *const T as *const Complex<f64>)
387                });
388            let coeffs_2d_c64 = self.fitter.fit_complex_2d(backend, &values_2d_c64);
389            // Convert back to T
390            DTensor::<T, 2>::from_fn(*coeffs_2d_c64.shape(), |idx| unsafe {
391                *(&coeffs_2d_c64[idx] as *const Complex<f64> as *const T)
392            })
393        };
394
395        // 4. Reshape back to N-D with basis_size at position 0
396        let mut coeffs_shape = vec![basis_size];
397        values_dim0.shape().with_dims(|dims| {
398            for i in 1..dims.len() {
399                coeffs_shape.push(dims[i]);
400            }
401        });
402
403        let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
404
405        // 5. Move dimension 0 back to original position dim
406        movedim(&coeffs_dim0, 0, dim)
407    }
408
    /// Fit basis coefficients from values at sampling points (N-dimensional)
    ///
    /// Fits along the specified dimension, keeping other dimensions intact.
    /// Supports both real (`f64`) and complex (`Complex<f64>`) values.
    /// This is a thin public wrapper that forwards to `fit_nd_impl`.
    ///
    /// # Type Parameters
    /// * `T` - Element type (f64 or Complex<f64>)
    ///
    /// # Arguments
    /// * `backend` - Optional GEMM backend handle, forwarded to the fitter;
    ///   pass `None` when no explicit backend is wanted
    /// * `values` - N-dimensional array with `values.shape().dim(dim) == n_sampling_points`
    /// * `dim` - Dimension along which to fit (0-indexed)
    ///
    /// # Returns
    /// N-dimensional array with `result.shape().dim(dim) == basis_size`
    ///
    /// # Panics
    /// Panics if `values.shape().dim(dim) != n_sampling_points` or if `dim >= rank`.
    ///
    /// NOTE(review): earlier docs also listed "if SVD not computed", while the
    /// constructors describe the SVD as lazily computed on first fit — confirm
    /// the actual behaviour in `RealMatrixFitter`.
    ///
    /// # Example
    /// ```ignore
    /// use num_complex::Complex;
    /// use mdarray::tensor;
    ///
    /// // Real values
    /// let coeffs_real = sampling.fit_nd::<f64>(None, &values_real, 0);
    ///
    /// // Complex values
    /// let coeffs_complex = sampling.fit_nd::<Complex<f64>>(None, &values_complex, 0);
    /// ```
    pub fn fit_nd<T>(
        &self,
        backend: Option<&GemmBackendHandle>,
        values: &Tensor<T, DynRank>,
        dim: usize,
    ) -> Tensor<T, DynRank>
    where
        T: num_complex::ComplexFloat
            + faer_traits::ComplexField
            + 'static
            + From<f64>
            + Copy
            + Default,
    {
        self.fit_nd_impl(backend, values, dim)
    }
454}
455
456#[cfg(test)]
457#[path = "tau_sampling_tests.rs"]
458mod tests;