sparse_ir/
sampling.rs

1//! Sparse sampling in imaginary time
2//!
3//! This module provides `TauSampling` for transforming between IR basis coefficients
4//! and values at sparse sampling points in imaginary time.
5
6use crate::fitters::InplaceFitter;
7use crate::gemm::GemmBackendHandle;
8use crate::traits::StatisticsType;
9use mdarray::{DTensor, DynRank, Shape, Slice, Tensor, ViewMut};
10use num_complex::Complex;
11
12/// Build output shape by replacing dimension `dim` with `new_size`
13fn build_output_shape<S: Shape>(input_shape: &S, dim: usize, new_size: usize) -> Vec<usize> {
14    let mut out_shape: Vec<usize> = Vec::with_capacity(input_shape.rank());
15    input_shape.with_dims(|dims| {
16        for (i, d) in dims.iter().enumerate() {
17            if i == dim {
18                out_shape.push(new_size);
19            } else {
20                out_shape.push(*d);
21            }
22        }
23    });
24    out_shape
25}
26
27/// Move axis from position `src` to position `dst`
28///
29/// This is equivalent to numpy.moveaxis or libsparseir's movedim.
30/// It creates a permutation array that moves the specified axis.
31///
32/// # Arguments
33/// * `arr` - Input array slice (Tensor or View)
34/// * `src` - Source axis position
35/// * `dst` - Destination axis position
36///
37/// # Returns
38/// Tensor with axes permuted
39///
40/// # Example
41/// ```ignore
42/// // For a 4D tensor with shape (2, 3, 4, 5)
43/// // movedim(arr, 0, 2) moves axis 0 to position 2
44/// // Result shape: (3, 4, 2, 5) with axes permuted as [1, 2, 0, 3]
45/// ```
46pub fn movedim<T: Clone>(arr: &Slice<T, DynRank>, src: usize, dst: usize) -> Tensor<T, DynRank> {
47    if src == dst {
48        return arr.to_tensor();
49    }
50
51    let rank = arr.rank();
52    assert!(
53        src < rank,
54        "src axis {} out of bounds for rank {}",
55        src,
56        rank
57    );
58    assert!(
59        dst < rank,
60        "dst axis {} out of bounds for rank {}",
61        dst,
62        rank
63    );
64
65    // Generate permutation: move src to dst position
66    let mut perm = Vec::with_capacity(rank);
67    let mut pos = 0;
68    for i in 0..rank {
69        if i == dst {
70            perm.push(src);
71        } else {
72            // Skip src position
73            if pos == src {
74                pos += 1;
75            }
76            perm.push(pos);
77            pos += 1;
78        }
79    }
80
81    arr.permute(&perm[..]).to_tensor()
82}
83
84/// Sparse sampling in imaginary time
85///
86/// Allows transformation between the IR basis and a set of sampling points
87/// in imaginary time (τ).
88pub struct TauSampling<S>
89where
90    S: StatisticsType,
91{
92    /// Sampling points in imaginary time τ ∈ [-β/2, β/2]
93    sampling_points: Vec<f64>,
94
95    /// Real matrix fitter for least-squares fitting
96    fitter: crate::fitters::RealMatrixFitter,
97
98    /// Marker for statistics type
99    _phantom: std::marker::PhantomData<S>,
100}
101
102impl<S> TauSampling<S>
103where
104    S: StatisticsType,
105{
106    /// Create a new TauSampling with default sampling points
107    ///
108    /// The default sampling points are chosen as the extrema of the highest-order
109    /// basis function, which gives near-optimal conditioning.
110    /// SVD is computed lazily on first call to `fit` or `fit_nd`.
111    ///
112    /// # Arguments
113    /// * `basis` - Any basis implementing the `Basis` trait
114    ///
115    /// # Returns
116    /// A new TauSampling object
117    pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
118    where
119        S: 'static,
120    {
121        let sampling_points = basis.default_tau_sampling_points();
122        Self::with_sampling_points(basis, sampling_points)
123    }
124
125    /// Create a new TauSampling with custom sampling points
126    ///
127    /// SVD is computed lazily on first call to `fit` or `fit_nd`.
128    ///
129    /// # Arguments
130    /// * `basis` - Any basis implementing the `Basis` trait
131    /// * `sampling_points` - Custom sampling points in τ ∈ [-β, β]
132    ///
133    /// # Returns
134    /// A new TauSampling object
135    ///
136    /// # Panics
137    /// Panics if `sampling_points` is empty or if any point is outside [-β, β]
138    pub fn with_sampling_points(
139        basis: &impl crate::basis_trait::Basis<S>,
140        sampling_points: Vec<f64>,
141    ) -> Self
142    where
143        S: 'static,
144    {
145        assert!(!sampling_points.is_empty(), "No sampling points given");
146        assert!(
147            basis.size() <= sampling_points.len(),
148            "The number of sampling points must be greater than or equal to the basis size"
149        );
150
151        let beta = basis.beta();
152        for &tau in &sampling_points {
153            assert!(
154                tau >= -beta && tau <= beta,
155                "Sampling point τ={} is outside [-β, β]",
156                tau
157            );
158        }
159
160        // Compute sampling matrix: A[i, l] = u_l(τ_i)
161        // Use Basis trait's evaluate_tau method
162        let matrix = basis.evaluate_tau(&sampling_points);
163
164        // Create fitter
165        let fitter = crate::fitters::RealMatrixFitter::new(matrix);
166
167        Self {
168            sampling_points,
169            fitter,
170            _phantom: std::marker::PhantomData,
171        }
172    }
173
174    /// Create a new TauSampling with custom sampling points and pre-computed matrix
175    ///
176    /// This constructor is useful when the sampling matrix is already computed
177    /// (e.g., from external sources or for testing).
178    ///
179    /// # Arguments
180    /// * `sampling_points` - Sampling points in τ ∈ [-β, β]
181    /// * `matrix` - Pre-computed sampling matrix (n_points × basis_size)
182    ///
183    /// # Returns
184    /// A new TauSampling object
185    ///
186    /// # Panics
187    /// Panics if `sampling_points` is empty or if matrix dimensions don't match
188    pub fn from_matrix(sampling_points: Vec<f64>, matrix: DTensor<f64, 2>) -> Self {
189        assert!(!sampling_points.is_empty(), "No sampling points given");
190        assert_eq!(
191            matrix.shape().0,
192            sampling_points.len(),
193            "Matrix rows ({}) must match number of sampling points ({})",
194            matrix.shape().0,
195            sampling_points.len()
196        );
197
198        let fitter = crate::fitters::RealMatrixFitter::new(matrix);
199
200        Self {
201            sampling_points,
202            fitter,
203            _phantom: std::marker::PhantomData,
204        }
205    }
206
207    /// Get the sampling points
208    pub fn sampling_points(&self) -> &[f64] {
209        &self.sampling_points
210    }
211
212    /// Get the number of sampling points
213    pub fn n_sampling_points(&self) -> usize {
214        self.fitter.n_points()
215    }
216
217    /// Get the basis size
218    pub fn basis_size(&self) -> usize {
219        self.fitter.basis_size()
220    }
221
222    /// Get the sampling matrix
223    pub fn matrix(&self) -> &DTensor<f64, 2> {
224        &self.fitter.matrix
225    }
226
227    // ========================================================================
228    // 1D functions (real and complex)
229    // ========================================================================
230
231    /// Evaluate basis coefficients at sampling points
232    ///
233    /// Computes g(τ_i) = Σ_l a_l * u_l(τ_i) for all sampling points
234    ///
235    /// # Arguments
236    /// * `coeffs` - Basis coefficients (length = basis_size)
237    ///
238    /// # Returns
239    /// Values at sampling points (length = n_sampling_points)
240    pub fn evaluate(&self, coeffs: &[f64]) -> Vec<f64> {
241        self.fitter.evaluate(None, coeffs)
242    }
243
244    /// Evaluate basis coefficients at sampling points, writing to output slice
245    pub fn evaluate_to(&self, coeffs: &[f64], out: &mut [f64]) {
246        self.fitter.evaluate_to(None, coeffs, out)
247    }
248
249    /// Fit values at sampling points to basis coefficients
250    pub fn fit(&self, values: &[f64]) -> Vec<f64> {
251        self.fitter.fit(None, values)
252    }
253
254    /// Fit values at sampling points to basis coefficients, writing to output slice
255    pub fn fit_to(&self, values: &[f64], out: &mut [f64]) {
256        self.fitter.fit_to(None, values, out)
257    }
258
259    /// Evaluate complex basis coefficients at sampling points
260    pub fn evaluate_zz(&self, coeffs: &[Complex<f64>]) -> Vec<Complex<f64>> {
261        self.fitter.evaluate_zz(None, coeffs)
262    }
263
264    /// Evaluate complex basis coefficients, writing to output slice
265    pub fn evaluate_zz_to(&self, coeffs: &[Complex<f64>], out: &mut [Complex<f64>]) {
266        self.fitter.evaluate_zz_to(None, coeffs, out)
267    }
268
269    /// Fit complex values at sampling points to basis coefficients
270    pub fn fit_zz(&self, values: &[Complex<f64>]) -> Vec<Complex<f64>> {
271        self.fitter.fit_zz(None, values)
272    }
273
274    /// Fit complex values, writing to output slice
275    pub fn fit_zz_to(&self, values: &[Complex<f64>], out: &mut [Complex<f64>]) {
276        self.fitter.fit_zz_to(None, values, out)
277    }
278
279    // ========================================================================
280    // N-D functions (real)
281    // ========================================================================
282
283    /// Evaluate N-D real coefficients at sampling points
284    ///
285    /// # Arguments
286    /// * `coeffs` - N-dimensional array with `coeffs.shape().dim(dim) == basis_size`
287    /// * `dim` - Dimension along which to evaluate (0-indexed)
288    ///
289    /// # Returns
290    /// N-dimensional array with `result.shape().dim(dim) == n_sampling_points`
291    pub fn evaluate_nd(
292        &self,
293        backend: Option<&GemmBackendHandle>,
294        coeffs: &Slice<f64, DynRank>,
295        dim: usize,
296    ) -> Tensor<f64, DynRank> {
297        let out_shape = build_output_shape(coeffs.shape(), dim, self.n_sampling_points());
298        let mut out = Tensor::<f64, DynRank>::zeros(&out_shape[..]);
299        self.evaluate_nd_to(backend, coeffs, dim, &mut out.expr_mut());
300        out
301    }
302
303    /// Evaluate N-D real coefficients, writing to a mutable view
304    pub fn evaluate_nd_to(
305        &self,
306        backend: Option<&GemmBackendHandle>,
307        coeffs: &Slice<f64, DynRank>,
308        dim: usize,
309        out: &mut ViewMut<'_, f64, DynRank>,
310    ) {
311        InplaceFitter::evaluate_nd_dd_to(self, backend, coeffs, dim, out);
312    }
313
314    /// Fit N-D real values at sampling points to basis coefficients
315    ///
316    /// # Arguments
317    /// * `values` - N-dimensional array with `values.shape().dim(dim) == n_sampling_points`
318    /// * `dim` - Dimension along which to fit (0-indexed)
319    ///
320    /// # Returns
321    /// N-dimensional array with `result.shape().dim(dim) == basis_size`
322    pub fn fit_nd(
323        &self,
324        backend: Option<&GemmBackendHandle>,
325        values: &Slice<f64, DynRank>,
326        dim: usize,
327    ) -> Tensor<f64, DynRank> {
328        let out_shape = build_output_shape(values.shape(), dim, self.basis_size());
329        let mut out = Tensor::<f64, DynRank>::zeros(&out_shape[..]);
330        self.fit_nd_to(backend, values, dim, &mut out.expr_mut());
331        out
332    }
333
334    /// Fit N-D real values, writing to a mutable view
335    pub fn fit_nd_to(
336        &self,
337        backend: Option<&GemmBackendHandle>,
338        values: &Slice<f64, DynRank>,
339        dim: usize,
340        out: &mut ViewMut<'_, f64, DynRank>,
341    ) {
342        InplaceFitter::fit_nd_dd_to(self, backend, values, dim, out);
343    }
344
345    // ========================================================================
346    // N-D functions (complex)
347    // ========================================================================
348
349    /// Evaluate N-D complex coefficients at sampling points
350    ///
351    /// # Arguments
352    /// * `coeffs` - N-dimensional complex array with `coeffs.shape().dim(dim) == basis_size`
353    /// * `dim` - Dimension along which to evaluate (0-indexed)
354    ///
355    /// # Returns
356    /// N-dimensional complex array with `result.shape().dim(dim) == n_sampling_points`
357    pub fn evaluate_nd_zz(
358        &self,
359        backend: Option<&GemmBackendHandle>,
360        coeffs: &Slice<Complex<f64>, DynRank>,
361        dim: usize,
362    ) -> Tensor<Complex<f64>, DynRank> {
363        let out_shape = build_output_shape(coeffs.shape(), dim, self.n_sampling_points());
364        let mut out = Tensor::<Complex<f64>, DynRank>::zeros(&out_shape[..]);
365        self.evaluate_nd_zz_to(backend, coeffs, dim, &mut out.expr_mut());
366        out
367    }
368
369    /// Evaluate N-D complex coefficients, writing to a mutable view
370    pub fn evaluate_nd_zz_to(
371        &self,
372        backend: Option<&GemmBackendHandle>,
373        coeffs: &Slice<Complex<f64>, DynRank>,
374        dim: usize,
375        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
376    ) {
377        InplaceFitter::evaluate_nd_zz_to(self, backend, coeffs, dim, out);
378    }
379
380    /// Fit N-D complex values at sampling points to basis coefficients
381    ///
382    /// # Arguments
383    /// * `values` - N-dimensional complex array with `values.shape().dim(dim) == n_sampling_points`
384    /// * `dim` - Dimension along which to fit (0-indexed)
385    ///
386    /// # Returns
387    /// N-dimensional complex array with `result.shape().dim(dim) == basis_size`
388    pub fn fit_nd_zz(
389        &self,
390        backend: Option<&GemmBackendHandle>,
391        values: &Slice<Complex<f64>, DynRank>,
392        dim: usize,
393    ) -> Tensor<Complex<f64>, DynRank> {
394        let out_shape = build_output_shape(values.shape(), dim, self.basis_size());
395        let mut out = Tensor::<Complex<f64>, DynRank>::zeros(&out_shape[..]);
396        self.fit_nd_zz_to(backend, values, dim, &mut out.expr_mut());
397        out
398    }
399
400    /// Fit N-D complex values, writing to a mutable view
401    pub fn fit_nd_zz_to(
402        &self,
403        backend: Option<&GemmBackendHandle>,
404        values: &Slice<Complex<f64>, DynRank>,
405        dim: usize,
406        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
407    ) {
408        InplaceFitter::fit_nd_zz_to(self, backend, values, dim, out);
409    }
410}
411
412/// InplaceFitter implementation for TauSampling
413///
414/// Delegates to RealMatrixFitter which supports dd and zz operations.
415impl<S: StatisticsType> InplaceFitter for TauSampling<S> {
416    fn n_points(&self) -> usize {
417        self.n_sampling_points()
418    }
419
420    fn basis_size(&self) -> usize {
421        self.basis_size()
422    }
423
424    fn evaluate_nd_dd_to(
425        &self,
426        backend: Option<&GemmBackendHandle>,
427        coeffs: &Slice<f64, DynRank>,
428        dim: usize,
429        out: &mut ViewMut<'_, f64, DynRank>,
430    ) -> bool {
431        self.fitter.evaluate_nd_dd_to(backend, coeffs, dim, out)
432    }
433
434    fn evaluate_nd_zz_to(
435        &self,
436        backend: Option<&GemmBackendHandle>,
437        coeffs: &Slice<Complex<f64>, DynRank>,
438        dim: usize,
439        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
440    ) -> bool {
441        self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
442    }
443
444    fn fit_nd_dd_to(
445        &self,
446        backend: Option<&GemmBackendHandle>,
447        values: &Slice<f64, DynRank>,
448        dim: usize,
449        out: &mut ViewMut<'_, f64, DynRank>,
450    ) -> bool {
451        self.fitter.fit_nd_dd_to(backend, values, dim, out)
452    }
453
454    fn fit_nd_zz_to(
455        &self,
456        backend: Option<&GemmBackendHandle>,
457        values: &Slice<Complex<f64>, DynRank>,
458        dim: usize,
459        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
460    ) -> bool {
461        self.fitter.fit_nd_zz_to(backend, values, dim, out)
462    }
463}
464
465#[cfg(test)]
466#[path = "tau_sampling_tests.rs"]
467mod tests;