Skip to main content

sparse_ir/
matsubara_sampling.rs

1//! Sparse sampling in Matsubara frequencies
2//!
3//! This module provides Matsubara frequency sampling for transforming between
4//! IR basis coefficients and values at sparse Matsubara frequencies.
5
6use crate::fitters::{ComplexMatrixFitter, ComplexToRealFitter, InplaceFitter};
7use crate::freq::MatsubaraFreq;
8use crate::gemm::GemmBackendHandle;
9use crate::sampling::movedim;
10use crate::traits::StatisticsType;
11use mdarray::{DTensor, DynRank, Shape, Slice, Tensor, ViewMut};
12use num_complex::Complex;
13use std::marker::PhantomData;
14
15/// Trait for coefficient types that can be evaluated by Matsubara sampling
16///
17/// This provides compile-time dispatch for different coefficient types,
18/// avoiding runtime TypeId checks and unsafe pointer casts.
19pub trait MatsubaraCoeffs: Copy + 'static {
20    /// Evaluate coefficients using the given sampler
21    fn evaluate_nd_with<S: StatisticsType>(
22        sampler: &MatsubaraSampling<S>,
23        backend: Option<&GemmBackendHandle>,
24        coeffs: &Slice<Self, DynRank>,
25        dim: usize,
26    ) -> Tensor<Complex<f64>, DynRank>;
27}
28
29impl MatsubaraCoeffs for f64 {
30    fn evaluate_nd_with<S: StatisticsType>(
31        sampler: &MatsubaraSampling<S>,
32        backend: Option<&GemmBackendHandle>,
33        coeffs: &Slice<Self, DynRank>,
34        dim: usize,
35    ) -> Tensor<Complex<f64>, DynRank> {
36        sampler.evaluate_nd_impl_real(backend, coeffs, dim)
37    }
38}
39
40impl MatsubaraCoeffs for Complex<f64> {
41    fn evaluate_nd_with<S: StatisticsType>(
42        sampler: &MatsubaraSampling<S>,
43        backend: Option<&GemmBackendHandle>,
44        coeffs: &Slice<Self, DynRank>,
45        dim: usize,
46    ) -> Tensor<Complex<f64>, DynRank> {
47        sampler.evaluate_nd_impl_complex(backend, coeffs, dim)
48    }
49}
50
51/// Matsubara sampling for full frequency range (positive and negative)
52///
53/// General complex problem without symmetry → complex coefficients
54pub struct MatsubaraSampling<S: StatisticsType> {
55    sampling_points: Vec<MatsubaraFreq<S>>,
56    fitter: ComplexMatrixFitter,
57    _phantom: PhantomData<S>,
58}
59
60impl<S: StatisticsType> MatsubaraSampling<S> {
61    /// Create Matsubara sampling with default sampling points
62    ///
63    /// Uses extrema-based sampling point selection (symmetric: positive and negative frequencies).
64    pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
65    where
66        S: 'static,
67    {
68        let sampling_points = basis.default_matsubara_sampling_points(false);
69        Self::with_sampling_points(basis, sampling_points)
70    }
71
72    /// Create Matsubara sampling with custom sampling points
73    pub fn with_sampling_points(
74        basis: &impl crate::basis_trait::Basis<S>,
75        mut sampling_points: Vec<MatsubaraFreq<S>>,
76    ) -> Self
77    where
78        S: 'static,
79    {
80        // Sort sampling points
81        sampling_points.sort();
82
83        // Evaluate matrix at sampling points
84        // Use Basis trait's evaluate_matsubara method
85        let matrix = basis.evaluate_matsubara(&sampling_points);
86
87        // Create fitter (complex → complex, no symmetry)
88        let fitter = ComplexMatrixFitter::new(matrix);
89
90        Self {
91            sampling_points,
92            fitter,
93            _phantom: PhantomData,
94        }
95    }
96
97    /// Create Matsubara sampling with custom sampling points and pre-computed matrix
98    ///
99    /// This constructor is useful when the sampling matrix is already computed
100    /// (e.g., from external sources or for testing).
101    ///
102    /// # Arguments
103    /// * `sampling_points` - Matsubara frequency sampling points
104    /// * `matrix` - Pre-computed sampling matrix (n_points × basis_size)
105    ///
106    /// # Returns
107    /// A new MatsubaraSampling object
108    ///
109    /// # Panics
110    /// Panics if `sampling_points` is empty or if matrix dimensions don't match
111    pub fn from_matrix(
112        sampling_points: Vec<MatsubaraFreq<S>>,
113        matrix: DTensor<Complex<f64>, 2>,
114    ) -> Self {
115        assert!(!sampling_points.is_empty(), "No sampling points given");
116        assert_eq!(
117            matrix.shape().0,
118            sampling_points.len(),
119            "Matrix rows ({}) must match number of sampling points ({})",
120            matrix.shape().0,
121            sampling_points.len()
122        );
123        debug_assert!(
124            sampling_points.windows(2).all(|w| w[0] <= w[1]),
125            "Sampling points must be sorted in ascending order"
126        );
127
128        let fitter = ComplexMatrixFitter::new(matrix);
129
130        Self {
131            sampling_points,
132            fitter,
133            _phantom: PhantomData,
134        }
135    }
136
137    /// Get sampling points
138    pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
139        &self.sampling_points
140    }
141
142    /// Number of sampling points
143    pub fn n_sampling_points(&self) -> usize {
144        self.sampling_points.len()
145    }
146
147    /// Basis size
148    pub fn basis_size(&self) -> usize {
149        self.fitter.basis_size()
150    }
151
152    /// Get the sampling matrix
153    pub fn matrix(&self) -> &DTensor<Complex<f64>, 2> {
154        &self.fitter.matrix
155    }
156
157    /// Evaluate complex basis coefficients at sampling points
158    ///
159    /// # Arguments
160    /// * `coeffs` - Complex basis coefficients (length = basis_size)
161    ///
162    /// # Returns
163    /// Complex values at Matsubara frequencies (length = n_sampling_points)
164    pub fn evaluate(&self, coeffs: &[Complex<f64>]) -> Vec<Complex<f64>> {
165        self.fitter.evaluate(None, coeffs)
166    }
167
168    /// Fit complex basis coefficients from values at sampling points
169    ///
170    /// # Arguments
171    /// * `values` - Complex values at Matsubara frequencies (length = n_sampling_points)
172    ///
173    /// # Returns
174    /// Fitted complex basis coefficients (length = basis_size)
175    pub fn fit(&self, values: &[Complex<f64>]) -> Vec<Complex<f64>> {
176        self.fitter.fit(None, values)
177    }
178
179    /// Evaluate N-dimensional array of basis coefficients at sampling points
180    ///
181    /// Supports both real (`f64`) and complex (`Complex<f64>`) coefficients.
182    /// Always returns complex values at Matsubara frequencies.
183    ///
184    /// # Type Parameters
185    /// * `T` - Element type (f64 or Complex<f64>)
186    ///
187    /// # Arguments
188    /// * `backend` - Optional GEMM backend handle (None uses default)
189    /// * `coeffs` - N-dimensional tensor of basis coefficients
190    /// * `dim` - Dimension along which to evaluate (must have size = basis_size)
191    ///
192    /// # Returns
193    /// N-dimensional tensor of complex values at Matsubara frequencies
194    ///
195    /// # Example
196    /// ```ignore
197    /// use num_complex::Complex;
198    ///
199    /// // Real coefficients
200    /// let values = matsubara_sampling.evaluate_nd::<f64>(None, &coeffs_real, 0);
201    ///
202    /// // Complex coefficients
203    /// let values = matsubara_sampling.evaluate_nd::<Complex<f64>>(None, &coeffs_complex, 0);
204    /// ```
205    /// Evaluate N-D coefficients for the real case `T = f64`
206    fn evaluate_nd_impl_real(
207        &self,
208        backend: Option<&GemmBackendHandle>,
209        coeffs: &Slice<f64, DynRank>,
210        dim: usize,
211    ) -> Tensor<Complex<f64>, DynRank> {
212        let rank = coeffs.rank();
213        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
214
215        let basis_size = self.basis_size();
216        let target_dim_size = coeffs.shape().dim(dim);
217
218        assert_eq!(
219            target_dim_size, basis_size,
220            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
221            dim, target_dim_size, basis_size
222        );
223
224        // 1. Move target dimension to position 0
225        let coeffs_dim0 = movedim(coeffs, dim, 0);
226
227        // 2. Reshape to 2D: (basis_size, extra_size)
228        let extra_size: usize = coeffs_dim0.len() / basis_size;
229
230        let coeffs_2d_dyn = coeffs_dim0
231            .reshape(&[basis_size, extra_size][..])
232            .to_tensor();
233
234        // 3. Convert to DTensor and evaluate using evaluate_2d_real
235        let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
236            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
237        });
238        let coeffs_2d_view = coeffs_2d.view(.., ..);
239        let result_2d = self.fitter.evaluate_2d_real(backend, &coeffs_2d_view);
240
241        // 4. Reshape back to N-D with n_points at position 0
242        let n_points = self.n_sampling_points();
243        let mut result_shape = vec![n_points];
244        coeffs_dim0.shape().with_dims(|dims| {
245            for i in 1..dims.len() {
246                result_shape.push(dims[i]);
247            }
248        });
249
250        let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
251
252        // 5. Move dimension 0 back to original position dim
253        movedim(&result_dim0, 0, dim)
254    }
255
256    /// Evaluate N-D coefficients for the complex case `T = Complex<f64>`
257    fn evaluate_nd_impl_complex(
258        &self,
259        backend: Option<&GemmBackendHandle>,
260        coeffs: &Slice<Complex<f64>, DynRank>,
261        dim: usize,
262    ) -> Tensor<Complex<f64>, DynRank> {
263        let rank = coeffs.rank();
264        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
265
266        let basis_size = self.basis_size();
267        let target_dim_size = coeffs.shape().dim(dim);
268
269        assert_eq!(
270            target_dim_size, basis_size,
271            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
272            dim, target_dim_size, basis_size
273        );
274
275        // 1. Move target dimension to position 0
276        let coeffs_dim0 = movedim(coeffs, dim, 0);
277
278        // 2. Reshape to 2D: (basis_size, extra_size)
279        let extra_size: usize = coeffs_dim0.len() / basis_size;
280
281        let coeffs_2d_dyn = coeffs_dim0
282            .reshape(&[basis_size, extra_size][..])
283            .to_tensor();
284
285        // 3. Convert to DTensor and evaluate using evaluate_2d
286        let coeffs_2d = DTensor::<Complex<f64>, 2>::from_fn([basis_size, extra_size], |idx| {
287            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
288        });
289        let coeffs_2d_view = coeffs_2d.view(.., ..);
290        let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d_view);
291
292        // 4. Reshape back to N-D with n_points at position 0
293        let n_points = self.n_sampling_points();
294        let mut result_shape = vec![n_points];
295        coeffs_dim0.shape().with_dims(|dims| {
296            for i in 1..dims.len() {
297                result_shape.push(dims[i]);
298            }
299        });
300
301        let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
302
303        // 5. Move dimension 0 back to original position dim
304        movedim(&result_dim0, 0, dim)
305    }
306
307    /// Evaluate N-dimensional coefficients at Matsubara sampling points
308    ///
309    /// This method dispatches to the appropriate implementation based on the
310    /// coefficient type at compile time using the `MatsubaraCoeffs` trait.
311    ///
312    /// # Type Parameter
313    /// * `T` - Must implement `MatsubaraCoeffs` (currently `f64` or `Complex<f64>`)
314    ///
315    /// # Arguments
316    /// * `backend` - Optional GEMM backend handle
317    /// * `coeffs` - N-dimensional tensor of basis coefficients
318    /// * `dim` - Dimension along which to evaluate
319    ///
320    /// # Returns
321    /// N-dimensional tensor of complex values at Matsubara frequencies
322    pub fn evaluate_nd<T: MatsubaraCoeffs>(
323        &self,
324        backend: Option<&GemmBackendHandle>,
325        coeffs: &Slice<T, DynRank>,
326        dim: usize,
327    ) -> Tensor<Complex<f64>, DynRank> {
328        T::evaluate_nd_with(self, backend, coeffs, dim)
329    }
330
331    /// Evaluate real basis coefficients at Matsubara sampling points (N-dimensional)
332    ///
333    /// This method takes real coefficients and produces complex values, useful when
334    /// working with symmetry-exploiting representations or real-valued IR coefficients.
335    ///
336    /// # Arguments
337    /// * `backend` - Optional GEMM backend handle (None uses default)
338    /// * `coeffs` - N-dimensional tensor of real basis coefficients
339    /// * `dim` - Dimension along which to evaluate (must have size = basis_size)
340    ///
341    /// # Returns
342    /// N-dimensional tensor of complex values at Matsubara frequencies
343    pub fn evaluate_nd_real(
344        &self,
345        backend: Option<&GemmBackendHandle>,
346        coeffs: &Tensor<f64, DynRank>,
347        dim: usize,
348    ) -> Tensor<Complex<f64>, DynRank> {
349        let rank = coeffs.rank();
350        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
351
352        let basis_size = self.basis_size();
353        let target_dim_size = coeffs.shape().dim(dim);
354
355        assert_eq!(
356            target_dim_size, basis_size,
357            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
358            dim, target_dim_size, basis_size
359        );
360
361        // 1. Move target dimension to position 0
362        let coeffs_dim0 = movedim(coeffs, dim, 0);
363
364        // 2. Reshape to 2D: (basis_size, extra_size)
365        let extra_size: usize = coeffs_dim0.len() / basis_size;
366
367        let coeffs_2d_dyn = coeffs_dim0
368            .reshape(&[basis_size, extra_size][..])
369            .to_tensor();
370
371        // 3. Convert to DTensor and evaluate using ComplexMatrixFitter
372        let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
373            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
374        });
375
376        // 4. Evaluate: values = A * coeffs (A is complex, coeffs is real)
377        let coeffs_2d_view = coeffs_2d.view(.., ..);
378        let values_2d = self.fitter.evaluate_2d_real(backend, &coeffs_2d_view);
379
380        // 5. Reshape result back to N-D with first dimension = n_sampling_points
381        let n_points = self.n_sampling_points();
382        let mut result_shape = Vec::with_capacity(rank);
383        result_shape.push(n_points);
384        coeffs_dim0.shape().with_dims(|dims| {
385            for i in 1..dims.len() {
386                result_shape.push(dims[i]);
387            }
388        });
389
390        let result_dim0 = values_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
391
392        // 6. Move dimension 0 back to original position dim
393        movedim(&result_dim0, 0, dim)
394    }
395
396    /// Fit N-dimensional array of complex values to complex basis coefficients
397    ///
398    /// # Arguments
399    /// * `backend` - Optional GEMM backend handle (None uses default)
400    /// * `values` - N-dimensional tensor of complex values at Matsubara frequencies
401    /// * `dim` - Dimension along which to fit (must have size = n_sampling_points)
402    ///
403    /// # Returns
404    /// N-dimensional tensor of complex basis coefficients
405    pub fn fit_nd(
406        &self,
407        backend: Option<&GemmBackendHandle>,
408        values: &Tensor<Complex<f64>, DynRank>,
409        dim: usize,
410    ) -> Tensor<Complex<f64>, DynRank> {
411        let rank = values.rank();
412        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
413
414        let n_points = self.n_sampling_points();
415        let target_dim_size = values.shape().dim(dim);
416
417        assert_eq!(
418            target_dim_size, n_points,
419            "values.shape().dim({}) = {} must equal n_sampling_points = {}",
420            dim, target_dim_size, n_points
421        );
422
423        // 1. Move target dimension to position 0
424        let values_dim0 = movedim(values, dim, 0);
425
426        // 2. Reshape to 2D: (n_points, extra_size)
427        let extra_size: usize = values_dim0.len() / n_points;
428        let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
429
430        // 3. Convert to DTensor and fit using GEMM
431        let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
432            values_2d_dyn[&[idx[0], idx[1]][..]]
433        });
434
435        // Use fitter's efficient 2D fit (GEMM-based)
436        let values_2d_view = values_2d.view(.., ..);
437        let coeffs_2d = self.fitter.fit_2d(backend, &values_2d_view);
438
439        // 4. Reshape back to N-D with basis_size at position 0
440        let basis_size = self.basis_size();
441        let mut coeffs_shape = vec![basis_size];
442        values_dim0.shape().with_dims(|dims| {
443            for i in 1..dims.len() {
444                coeffs_shape.push(dims[i]);
445            }
446        });
447
448        let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
449
450        // 5. Move dimension 0 back to original position dim
451        movedim(&coeffs_dim0, 0, dim)
452    }
453
454    /// Fit N-dimensional array of complex values to real basis coefficients
455    ///
456    /// This method fits complex Matsubara values to real IR coefficients.
457    /// Takes the real part of the least-squares solution.
458    ///
459    /// # Arguments
460    /// * `backend` - Optional GEMM backend handle (None uses default)
461    /// * `values` - N-dimensional tensor of complex values at Matsubara frequencies
462    /// * `dim` - Dimension along which to fit (must have size = n_sampling_points)
463    ///
464    /// # Returns
465    /// N-dimensional tensor of real basis coefficients
466    pub fn fit_nd_real(
467        &self,
468        backend: Option<&GemmBackendHandle>,
469        values: &Tensor<Complex<f64>, DynRank>,
470        dim: usize,
471    ) -> Tensor<f64, DynRank> {
472        let rank = values.rank();
473        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
474
475        let n_points = self.n_sampling_points();
476        let target_dim_size = values.shape().dim(dim);
477
478        assert_eq!(
479            target_dim_size, n_points,
480            "values.shape().dim({}) = {} must equal n_sampling_points = {}",
481            dim, target_dim_size, n_points
482        );
483
484        // 1. Move target dimension to position 0
485        let values_dim0 = movedim(values, dim, 0);
486
487        // 2. Reshape to 2D: (n_points, extra_size)
488        let extra_size: usize = values_dim0.len() / n_points;
489        let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
490
491        // 3. Convert to DTensor and fit
492        let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
493            values_2d_dyn[&[idx[0], idx[1]][..]]
494        });
495
496        // Use fitter's fit_2d_real method
497        let values_2d_view = values_2d.view(.., ..);
498        let coeffs_2d = self.fitter.fit_2d_real(backend, &values_2d_view);
499
500        // 4. Reshape back to N-D with basis_size at position 0
501        let basis_size = self.basis_size();
502        let mut coeffs_shape = vec![basis_size];
503        values_dim0.shape().with_dims(|dims| {
504            for i in 1..dims.len() {
505                coeffs_shape.push(dims[i]);
506            }
507        });
508
509        let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
510
511        // 5. Move dimension 0 back to original position dim
512        movedim(&coeffs_dim0, 0, dim)
513    }
514
515    /// Evaluate basis coefficients at Matsubara sampling points (N-dimensional) with in-place output
516    ///
517    /// # Type Parameters
518    /// * `T` - Coefficient type (f64 or Complex<f64>)
519    ///
520    /// # Arguments
521    /// * `coeffs` - N-dimensional tensor with `coeffs.shape().dim(dim) == basis_size`
522    /// * `dim` - Dimension along which to evaluate (0-indexed)
523    /// * `out` - Output tensor with `out.shape().dim(dim) == n_sampling_points` (Complex<f64>)
524    pub fn evaluate_nd_to<T: MatsubaraCoeffs>(
525        &self,
526        backend: Option<&GemmBackendHandle>,
527        coeffs: &Slice<T, DynRank>,
528        dim: usize,
529        out: &mut Tensor<Complex<f64>, DynRank>,
530    ) {
531        // Validate output shape
532        let rank = coeffs.rank();
533        assert_eq!(
534            out.rank(),
535            rank,
536            "out.rank()={} must equal coeffs.rank()={}",
537            out.rank(),
538            rank
539        );
540
541        let n_points = self.n_sampling_points();
542        let out_dim_size = out.shape().dim(dim);
543        assert_eq!(
544            out_dim_size, n_points,
545            "out.shape().dim({}) = {} must equal n_sampling_points = {}",
546            dim, out_dim_size, n_points
547        );
548
549        // Validate other dimensions match
550        for d in 0..rank {
551            if d != dim {
552                let coeffs_d = coeffs.shape().dim(d);
553                let out_d = out.shape().dim(d);
554                assert_eq!(
555                    coeffs_d, out_d,
556                    "coeffs.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
557                    d, coeffs_d, d, out_d
558                );
559            }
560        }
561
562        // Compute result and copy to out
563        let result = self.evaluate_nd(backend, coeffs, dim);
564
565        // Copy result to out
566        let total = out.len();
567        for i in 0..total {
568            let mut idx = vec![0usize; rank];
569            let mut remaining = i;
570            for d in (0..rank).rev() {
571                let dim_size = out.shape().dim(d);
572                idx[d] = remaining % dim_size;
573                remaining /= dim_size;
574            }
575            out[&idx[..]] = result[&idx[..]];
576        }
577    }
578
579    /// Fit N-dimensional complex values to complex coefficients with in-place output
580    ///
581    /// # Arguments
582    /// * `values` - N-dimensional tensor with `values.shape().dim(dim) == n_sampling_points`
583    /// * `dim` - Dimension along which to fit (0-indexed)
584    /// * `out` - Output tensor with `out.shape().dim(dim) == basis_size` (Complex<f64>)
585    pub fn fit_nd_to(
586        &self,
587        backend: Option<&GemmBackendHandle>,
588        values: &Tensor<Complex<f64>, DynRank>,
589        dim: usize,
590        out: &mut Tensor<Complex<f64>, DynRank>,
591    ) {
592        // Validate output shape
593        let rank = values.rank();
594        assert_eq!(
595            out.rank(),
596            rank,
597            "out.rank()={} must equal values.rank()={}",
598            out.rank(),
599            rank
600        );
601
602        let basis_size = self.basis_size();
603        let out_dim_size = out.shape().dim(dim);
604        assert_eq!(
605            out_dim_size, basis_size,
606            "out.shape().dim({}) = {} must equal basis_size = {}",
607            dim, out_dim_size, basis_size
608        );
609
610        // Validate other dimensions match
611        for d in 0..rank {
612            if d != dim {
613                let values_d = values.shape().dim(d);
614                let out_d = out.shape().dim(d);
615                assert_eq!(
616                    values_d, out_d,
617                    "values.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
618                    d, values_d, d, out_d
619                );
620            }
621        }
622
623        // Compute result and copy to out
624        let result = self.fit_nd(backend, values, dim);
625
626        // Copy result to out
627        let total = out.len();
628        for i in 0..total {
629            let mut idx = vec![0usize; rank];
630            let mut remaining = i;
631            for d in (0..rank).rev() {
632                let dim_size = out.shape().dim(d);
633                idx[d] = remaining % dim_size;
634                remaining /= dim_size;
635            }
636            out[&idx[..]] = result[&idx[..]];
637        }
638    }
639}
640
641/// InplaceFitter implementation for MatsubaraSampling
642///
643/// Delegates to ComplexMatrixFitter which supports:
644/// - zz: Complex input → Complex output (full support)
645/// - dz: Real input → Complex output (evaluate only)
646/// - zd: Complex input → Real output (fit only, takes real part)
647impl<S: StatisticsType> InplaceFitter for MatsubaraSampling<S> {
648    fn n_points(&self) -> usize {
649        self.n_sampling_points()
650    }
651
652    fn basis_size(&self) -> usize {
653        self.basis_size()
654    }
655
656    fn evaluate_nd_dz_to(
657        &self,
658        backend: Option<&GemmBackendHandle>,
659        coeffs: &Slice<f64, DynRank>,
660        dim: usize,
661        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
662    ) -> bool {
663        self.fitter.evaluate_nd_dz_to(backend, coeffs, dim, out)
664    }
665
666    fn evaluate_nd_zz_to(
667        &self,
668        backend: Option<&GemmBackendHandle>,
669        coeffs: &Slice<Complex<f64>, DynRank>,
670        dim: usize,
671        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
672    ) -> bool {
673        self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
674    }
675
676    fn fit_nd_zd_to(
677        &self,
678        backend: Option<&GemmBackendHandle>,
679        values: &Slice<Complex<f64>, DynRank>,
680        dim: usize,
681        out: &mut ViewMut<'_, f64, DynRank>,
682    ) -> bool {
683        self.fitter.fit_nd_zd_to(backend, values, dim, out)
684    }
685
686    fn fit_nd_zz_to(
687        &self,
688        backend: Option<&GemmBackendHandle>,
689        values: &Slice<Complex<f64>, DynRank>,
690        dim: usize,
691        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
692    ) -> bool {
693        self.fitter.fit_nd_zz_to(backend, values, dim, out)
694    }
695}
696
697/// Matsubara sampling for positive frequencies only
698///
699/// Exploits symmetry to reconstruct real coefficients from positive frequencies only.
700/// Supports: {0, 1, 2, 3, ...} (no negative frequencies)
701pub struct MatsubaraSamplingPositiveOnly<S: StatisticsType> {
702    sampling_points: Vec<MatsubaraFreq<S>>,
703    fitter: ComplexToRealFitter,
704    _phantom: PhantomData<S>,
705}
706
707impl<S: StatisticsType> MatsubaraSamplingPositiveOnly<S> {
708    /// Create Matsubara sampling with default positive-only sampling points
709    ///
710    /// Uses extrema-based sampling point selection (positive frequencies only).
711    /// Exploits symmetry to reconstruct real coefficients.
712    pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
713    where
714        S: 'static,
715    {
716        let sampling_points = basis.default_matsubara_sampling_points(true);
717        Self::with_sampling_points(basis, sampling_points)
718    }
719
720    /// Create Matsubara sampling with custom positive-only sampling points
721    pub fn with_sampling_points(
722        basis: &impl crate::basis_trait::Basis<S>,
723        mut sampling_points: Vec<MatsubaraFreq<S>>,
724    ) -> Self
725    where
726        S: 'static,
727    {
728        // Sort and validate (all n >= 0)
729        sampling_points.sort();
730
731        // Validate that all points are non-negative
732        assert!(
733            sampling_points.iter().all(|f| f.n() >= 0),
734            "All sampling points must be non-negative for positive-only Matsubara sampling"
735        );
736
737        // Evaluate matrix at sampling points
738        // Use Basis trait's evaluate_matsubara method
739        let matrix = basis.evaluate_matsubara(&sampling_points);
740
741        // Create fitter (complex → real, exploits symmetry)
742        let fitter = ComplexToRealFitter::new(&matrix);
743
744        Self {
745            sampling_points,
746            fitter,
747            _phantom: PhantomData,
748        }
749    }
750
751    /// Create Matsubara sampling (positive-only) with custom sampling points and pre-computed matrix
752    ///
753    /// This constructor is useful when the sampling matrix is already computed.
754    /// Uses symmetry to fit real coefficients from complex values at positive frequencies.
755    ///
756    /// # Arguments
757    /// * `sampling_points` - Matsubara frequency sampling points (should be positive)
758    /// * `matrix` - Pre-computed sampling matrix (n_points × basis_size)
759    ///
760    /// # Returns
761    /// A new MatsubaraSamplingPositiveOnly object
762    ///
763    /// # Panics
764    /// Panics if `sampling_points` is empty or if matrix dimensions don't match
765    pub fn from_matrix(
766        sampling_points: Vec<MatsubaraFreq<S>>,
767        matrix: DTensor<Complex<f64>, 2>,
768    ) -> Self {
769        assert!(!sampling_points.is_empty(), "No sampling points given");
770        assert_eq!(
771            matrix.shape().0,
772            sampling_points.len(),
773            "Matrix rows ({}) must match number of sampling points ({})",
774            matrix.shape().0,
775            sampling_points.len()
776        );
777        debug_assert!(
778            sampling_points.windows(2).all(|w| w[0] <= w[1]),
779            "Sampling points must be sorted in ascending order"
780        );
781
782        let fitter = ComplexToRealFitter::new(&matrix);
783
784        Self {
785            sampling_points,
786            fitter,
787            _phantom: PhantomData,
788        }
789    }
790
791    /// Get sampling points
792    pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
793        &self.sampling_points
794    }
795
796    /// Number of sampling points
797    pub fn n_sampling_points(&self) -> usize {
798        self.sampling_points.len()
799    }
800
801    /// Basis size
802    pub fn basis_size(&self) -> usize {
803        self.fitter.basis_size()
804    }
805
806    /// Get the original complex sampling matrix
807    pub fn matrix(&self) -> &DTensor<Complex<f64>, 2> {
808        &self.fitter.matrix
809    }
810
811    /// Evaluate basis coefficients at sampling points
812    pub fn evaluate(&self, coeffs: &[f64]) -> Vec<Complex<f64>> {
813        self.fitter.evaluate(None, coeffs)
814    }
815
816    /// Fit basis coefficients from values at sampling points
817    pub fn fit(&self, values: &[Complex<f64>]) -> Vec<f64> {
818        self.fitter.fit(None, values)
819    }
820
821    /// Evaluate N-dimensional array of real basis coefficients at sampling points
822    ///
823    /// # Arguments
824    /// * `coeffs` - N-dimensional tensor of real basis coefficients
825    /// * `dim` - Dimension along which to evaluate (must have size = basis_size)
826    ///
827    /// # Returns
828    /// N-dimensional tensor of complex values at Matsubara frequencies
829    pub fn evaluate_nd(
830        &self,
831        backend: Option<&GemmBackendHandle>,
832        coeffs: &Tensor<f64, DynRank>,
833        dim: usize,
834    ) -> Tensor<Complex<f64>, DynRank> {
835        let rank = coeffs.rank();
836        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
837
838        let basis_size = self.basis_size();
839        let target_dim_size = coeffs.shape().dim(dim);
840
841        assert_eq!(
842            target_dim_size, basis_size,
843            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
844            dim, target_dim_size, basis_size
845        );
846
847        // 1. Move target dimension to position 0
848        let coeffs_dim0 = movedim(coeffs, dim, 0);
849
850        // 2. Reshape to 2D: (basis_size, extra_size)
851        let extra_size: usize = coeffs_dim0.len() / basis_size;
852
853        let coeffs_2d_dyn = coeffs_dim0
854            .reshape(&[basis_size, extra_size][..])
855            .to_tensor();
856
857        // 3. Convert to DTensor and evaluate using GEMM
858        let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
859            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
860        });
861
862        // Use fitter's efficient 2D evaluate (GEMM-based)
863        let coeffs_2d_view = coeffs_2d.view(.., ..);
864        let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d_view);
865
866        // 4. Reshape back to N-D with n_points at position 0
867        let n_points = self.n_sampling_points();
868        let mut result_shape = vec![n_points];
869        coeffs_dim0.shape().with_dims(|dims| {
870            for i in 1..dims.len() {
871                result_shape.push(dims[i]);
872            }
873        });
874
875        let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
876
877        // 5. Move dimension 0 back to original position dim
878        movedim(&result_dim0, 0, dim)
879    }
880
881    /// Fit N-dimensional array of complex values to real basis coefficients
882    ///
883    /// # Arguments
884    /// * `backend` - Optional GEMM backend handle (None uses default)
885    /// * `values` - N-dimensional tensor of complex values at Matsubara frequencies
886    /// * `dim` - Dimension along which to fit (must have size = n_sampling_points)
887    ///
888    /// # Returns
889    /// N-dimensional tensor of real basis coefficients
890    pub fn fit_nd(
891        &self,
892        backend: Option<&GemmBackendHandle>,
893        values: &Tensor<Complex<f64>, DynRank>,
894        dim: usize,
895    ) -> Tensor<f64, DynRank> {
896        let rank = values.rank();
897        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
898
899        let n_points = self.n_sampling_points();
900        let target_dim_size = values.shape().dim(dim);
901
902        assert_eq!(
903            target_dim_size, n_points,
904            "values.shape().dim({}) = {} must equal n_sampling_points = {}",
905            dim, target_dim_size, n_points
906        );
907
908        // 1. Move target dimension to position 0
909        let values_dim0 = movedim(values, dim, 0);
910
911        // 2. Reshape to 2D: (n_points, extra_size)
912        let extra_size: usize = values_dim0.len() / n_points;
913        let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
914
915        // 3. Convert to DTensor and fit using GEMM
916        let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
917            values_2d_dyn[&[idx[0], idx[1]][..]]
918        });
919
920        // Use fitter's efficient 2D fit (GEMM-based)
921        let values_2d_view = values_2d.view(.., ..);
922        let coeffs_2d = self.fitter.fit_2d(backend, &values_2d_view);
923
924        // 4. Reshape back to N-D with basis_size at position 0
925        let basis_size = self.basis_size();
926        let mut coeffs_shape = vec![basis_size];
927        values_dim0.shape().with_dims(|dims| {
928            for i in 1..dims.len() {
929                coeffs_shape.push(dims[i]);
930            }
931        });
932
933        let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
934
935        // 5. Move dimension 0 back to original position dim
936        movedim(&coeffs_dim0, 0, dim)
937    }
938
939    /// Evaluate real basis coefficients at Matsubara sampling points (N-dimensional) with in-place output
940    ///
941    /// # Arguments
942    /// * `coeffs` - N-dimensional tensor of real coefficients with `coeffs.shape().dim(dim) == basis_size`
943    /// * `dim` - Dimension along which to evaluate (0-indexed)
944    /// * `out` - Output tensor with `out.shape().dim(dim) == n_sampling_points` (Complex<f64>)
945    pub fn evaluate_nd_to(
946        &self,
947        backend: Option<&GemmBackendHandle>,
948        coeffs: &Tensor<f64, DynRank>,
949        dim: usize,
950        out: &mut Tensor<Complex<f64>, DynRank>,
951    ) {
952        // Validate output shape
953        let rank = coeffs.rank();
954        assert_eq!(
955            out.rank(),
956            rank,
957            "out.rank()={} must equal coeffs.rank()={}",
958            out.rank(),
959            rank
960        );
961
962        let n_points = self.n_sampling_points();
963        let out_dim_size = out.shape().dim(dim);
964        assert_eq!(
965            out_dim_size, n_points,
966            "out.shape().dim({}) = {} must equal n_sampling_points = {}",
967            dim, out_dim_size, n_points
968        );
969
970        // Validate other dimensions match
971        for d in 0..rank {
972            if d != dim {
973                let coeffs_d = coeffs.shape().dim(d);
974                let out_d = out.shape().dim(d);
975                assert_eq!(
976                    coeffs_d, out_d,
977                    "coeffs.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
978                    d, coeffs_d, d, out_d
979                );
980            }
981        }
982
983        // Compute result and copy to out
984        let result = self.evaluate_nd(backend, coeffs, dim);
985
986        // Copy result to out
987        let total = out.len();
988        for i in 0..total {
989            let mut idx = vec![0usize; rank];
990            let mut remaining = i;
991            for d in (0..rank).rev() {
992                let dim_size = out.shape().dim(d);
993                idx[d] = remaining % dim_size;
994                remaining /= dim_size;
995            }
996            out[&idx[..]] = result[&idx[..]];
997        }
998    }
999
1000    /// Fit N-dimensional complex values to real coefficients with in-place output
1001    ///
1002    /// # Arguments
1003    /// * `values` - N-dimensional tensor with `values.shape().dim(dim) == n_sampling_points`
1004    /// * `dim` - Dimension along which to fit (0-indexed)
1005    /// * `out` - Output tensor with `out.shape().dim(dim) == basis_size` (f64)
1006    pub fn fit_nd_to(
1007        &self,
1008        backend: Option<&GemmBackendHandle>,
1009        values: &Tensor<Complex<f64>, DynRank>,
1010        dim: usize,
1011        out: &mut Tensor<f64, DynRank>,
1012    ) {
1013        // Validate output shape
1014        let rank = values.rank();
1015        assert_eq!(
1016            out.rank(),
1017            rank,
1018            "out.rank()={} must equal values.rank()={}",
1019            out.rank(),
1020            rank
1021        );
1022
1023        let basis_size = self.basis_size();
1024        let out_dim_size = out.shape().dim(dim);
1025        assert_eq!(
1026            out_dim_size, basis_size,
1027            "out.shape().dim({}) = {} must equal basis_size = {}",
1028            dim, out_dim_size, basis_size
1029        );
1030
1031        // Validate other dimensions match
1032        for d in 0..rank {
1033            if d != dim {
1034                let values_d = values.shape().dim(d);
1035                let out_d = out.shape().dim(d);
1036                assert_eq!(
1037                    values_d, out_d,
1038                    "values.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
1039                    d, values_d, d, out_d
1040                );
1041            }
1042        }
1043
1044        // Compute result and copy to out
1045        let result = self.fit_nd(backend, values, dim);
1046
1047        // Copy result to out
1048        let total = out.len();
1049        for i in 0..total {
1050            let mut idx = vec![0usize; rank];
1051            let mut remaining = i;
1052            for d in (0..rank).rev() {
1053                let dim_size = out.shape().dim(d);
1054                idx[d] = remaining % dim_size;
1055                remaining /= dim_size;
1056            }
1057            out[&idx[..]] = result[&idx[..]];
1058        }
1059    }
1060}
1061
1062/// InplaceFitter implementation for MatsubaraSamplingPositiveOnly
1063///
1064/// Delegates to ComplexToRealFitter which supports:
1065/// - dz: Real coefficients → Complex values (evaluate)
1066/// - zz: Complex coefficients → Complex values (evaluate, extracts real parts)
1067/// - zd: Complex values → Real coefficients (fit)
1068/// - zz: Complex values → Complex coefficients (fit, with zero imaginary parts)
1069impl<S: StatisticsType> InplaceFitter for MatsubaraSamplingPositiveOnly<S> {
1070    fn n_points(&self) -> usize {
1071        self.n_sampling_points()
1072    }
1073
1074    fn basis_size(&self) -> usize {
1075        self.basis_size()
1076    }
1077
1078    fn evaluate_nd_dz_to(
1079        &self,
1080        backend: Option<&GemmBackendHandle>,
1081        coeffs: &Slice<f64, DynRank>,
1082        dim: usize,
1083        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
1084    ) -> bool {
1085        self.fitter.evaluate_nd_dz_to(backend, coeffs, dim, out)
1086    }
1087
1088    fn evaluate_nd_zz_to(
1089        &self,
1090        backend: Option<&GemmBackendHandle>,
1091        coeffs: &Slice<Complex<f64>, DynRank>,
1092        dim: usize,
1093        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
1094    ) -> bool {
1095        self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
1096    }
1097
1098    fn fit_nd_zd_to(
1099        &self,
1100        backend: Option<&GemmBackendHandle>,
1101        values: &Slice<Complex<f64>, DynRank>,
1102        dim: usize,
1103        out: &mut ViewMut<'_, f64, DynRank>,
1104    ) -> bool {
1105        self.fitter.fit_nd_zd_to(backend, values, dim, out)
1106    }
1107
1108    fn fit_nd_zz_to(
1109        &self,
1110        backend: Option<&GemmBackendHandle>,
1111        values: &Slice<Complex<f64>, DynRank>,
1112        dim: usize,
1113        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
1114    ) -> bool {
1115        self.fitter.fit_nd_zz_to(backend, values, dim, out)
1116    }
1117}
1118
1119#[cfg(test)]
1120#[path = "matsubara_sampling_tests.rs"]
1121mod tests;