sparse_ir/
matsubara_sampling.rs

1//! Sparse sampling in Matsubara frequencies
2//!
3//! This module provides Matsubara frequency sampling for transforming between
4//! IR basis coefficients and values at sparse Matsubara frequencies.
5
6use crate::fitters::{ComplexMatrixFitter, ComplexToRealFitter, InplaceFitter};
7use crate::freq::MatsubaraFreq;
8use crate::gemm::GemmBackendHandle;
9use crate::sampling::movedim;
10use crate::traits::StatisticsType;
11use mdarray::{DTensor, DynRank, Shape, Slice, Tensor, ViewMut};
12use num_complex::Complex;
13use std::marker::PhantomData;
14
15/// Trait for coefficient types that can be evaluated by Matsubara sampling
16///
17/// This provides compile-time dispatch for different coefficient types,
18/// avoiding runtime TypeId checks and unsafe pointer casts.
19pub trait MatsubaraCoeffs: Copy + 'static {
20    /// Evaluate coefficients using the given sampler
21    fn evaluate_nd_with<S: StatisticsType>(
22        sampler: &MatsubaraSampling<S>,
23        backend: Option<&GemmBackendHandle>,
24        coeffs: &Slice<Self, DynRank>,
25        dim: usize,
26    ) -> Tensor<Complex<f64>, DynRank>;
27}
28
29impl MatsubaraCoeffs for f64 {
30    fn evaluate_nd_with<S: StatisticsType>(
31        sampler: &MatsubaraSampling<S>,
32        backend: Option<&GemmBackendHandle>,
33        coeffs: &Slice<Self, DynRank>,
34        dim: usize,
35    ) -> Tensor<Complex<f64>, DynRank> {
36        sampler.evaluate_nd_impl_real(backend, coeffs, dim)
37    }
38}
39
40impl MatsubaraCoeffs for Complex<f64> {
41    fn evaluate_nd_with<S: StatisticsType>(
42        sampler: &MatsubaraSampling<S>,
43        backend: Option<&GemmBackendHandle>,
44        coeffs: &Slice<Self, DynRank>,
45        dim: usize,
46    ) -> Tensor<Complex<f64>, DynRank> {
47        sampler.evaluate_nd_impl_complex(backend, coeffs, dim)
48    }
49}
50
51/// Matsubara sampling for full frequency range (positive and negative)
52///
53/// General complex problem without symmetry → complex coefficients
54pub struct MatsubaraSampling<S: StatisticsType> {
55    sampling_points: Vec<MatsubaraFreq<S>>,
56    fitter: ComplexMatrixFitter,
57    _phantom: PhantomData<S>,
58}
59
60impl<S: StatisticsType> MatsubaraSampling<S> {
61    /// Create Matsubara sampling with default sampling points
62    ///
63    /// Uses extrema-based sampling point selection (symmetric: positive and negative frequencies).
64    pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
65    where
66        S: 'static,
67    {
68        let sampling_points = basis.default_matsubara_sampling_points(false);
69        Self::with_sampling_points(basis, sampling_points)
70    }
71
72    /// Create Matsubara sampling with custom sampling points
73    pub fn with_sampling_points(
74        basis: &impl crate::basis_trait::Basis<S>,
75        mut sampling_points: Vec<MatsubaraFreq<S>>,
76    ) -> Self
77    where
78        S: 'static,
79    {
80        // Sort sampling points
81        sampling_points.sort();
82
83        // Evaluate matrix at sampling points
84        // Use Basis trait's evaluate_matsubara method
85        let matrix = basis.evaluate_matsubara(&sampling_points);
86
87        // Create fitter (complex → complex, no symmetry)
88        let fitter = ComplexMatrixFitter::new(matrix);
89
90        Self {
91            sampling_points,
92            fitter,
93            _phantom: PhantomData,
94        }
95    }
96
97    /// Create Matsubara sampling with custom sampling points and pre-computed matrix
98    ///
99    /// This constructor is useful when the sampling matrix is already computed
100    /// (e.g., from external sources or for testing).
101    ///
102    /// # Arguments
103    /// * `sampling_points` - Matsubara frequency sampling points
104    /// * `matrix` - Pre-computed sampling matrix (n_points × basis_size)
105    ///
106    /// # Returns
107    /// A new MatsubaraSampling object
108    ///
109    /// # Panics
110    /// Panics if `sampling_points` is empty or if matrix dimensions don't match
111    pub fn from_matrix(
112        mut sampling_points: Vec<MatsubaraFreq<S>>,
113        matrix: DTensor<Complex<f64>, 2>,
114    ) -> Self {
115        assert!(!sampling_points.is_empty(), "No sampling points given");
116        assert_eq!(
117            matrix.shape().0,
118            sampling_points.len(),
119            "Matrix rows ({}) must match number of sampling points ({})",
120            matrix.shape().0,
121            sampling_points.len()
122        );
123
124        // Sort sampling points
125        sampling_points.sort();
126
127        let fitter = ComplexMatrixFitter::new(matrix);
128
129        Self {
130            sampling_points,
131            fitter,
132            _phantom: PhantomData,
133        }
134    }
135
136    /// Get sampling points
137    pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
138        &self.sampling_points
139    }
140
141    /// Number of sampling points
142    pub fn n_sampling_points(&self) -> usize {
143        self.sampling_points.len()
144    }
145
146    /// Basis size
147    pub fn basis_size(&self) -> usize {
148        self.fitter.basis_size()
149    }
150
151    /// Get the sampling matrix
152    pub fn matrix(&self) -> &DTensor<Complex<f64>, 2> {
153        &self.fitter.matrix
154    }
155
156    /// Evaluate complex basis coefficients at sampling points
157    ///
158    /// # Arguments
159    /// * `coeffs` - Complex basis coefficients (length = basis_size)
160    ///
161    /// # Returns
162    /// Complex values at Matsubara frequencies (length = n_sampling_points)
163    pub fn evaluate(&self, coeffs: &[Complex<f64>]) -> Vec<Complex<f64>> {
164        self.fitter.evaluate(None, coeffs)
165    }
166
167    /// Fit complex basis coefficients from values at sampling points
168    ///
169    /// # Arguments
170    /// * `values` - Complex values at Matsubara frequencies (length = n_sampling_points)
171    ///
172    /// # Returns
173    /// Fitted complex basis coefficients (length = basis_size)
174    pub fn fit(&self, values: &[Complex<f64>]) -> Vec<Complex<f64>> {
175        self.fitter.fit(None, values)
176    }
177
178    /// Evaluate N-dimensional array of basis coefficients at sampling points
179    ///
180    /// Supports both real (`f64`) and complex (`Complex<f64>`) coefficients.
181    /// Always returns complex values at Matsubara frequencies.
182    ///
183    /// # Type Parameters
184    /// * `T` - Element type (f64 or Complex<f64>)
185    ///
186    /// # Arguments
187    /// * `backend` - Optional GEMM backend handle (None uses default)
188    /// * `coeffs` - N-dimensional tensor of basis coefficients
189    /// * `dim` - Dimension along which to evaluate (must have size = basis_size)
190    ///
191    /// # Returns
192    /// N-dimensional tensor of complex values at Matsubara frequencies
193    ///
194    /// # Example
195    /// ```ignore
196    /// use num_complex::Complex;
197    ///
198    /// // Real coefficients
199    /// let values = matsubara_sampling.evaluate_nd::<f64>(None, &coeffs_real, 0);
200    ///
201    /// // Complex coefficients
202    /// let values = matsubara_sampling.evaluate_nd::<Complex<f64>>(None, &coeffs_complex, 0);
203    /// ```
204    /// Evaluate N-D coefficients for the real case `T = f64`
205    fn evaluate_nd_impl_real(
206        &self,
207        backend: Option<&GemmBackendHandle>,
208        coeffs: &Slice<f64, DynRank>,
209        dim: usize,
210    ) -> Tensor<Complex<f64>, DynRank> {
211        let rank = coeffs.rank();
212        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
213
214        let basis_size = self.basis_size();
215        let target_dim_size = coeffs.shape().dim(dim);
216
217        assert_eq!(
218            target_dim_size, basis_size,
219            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
220            dim, target_dim_size, basis_size
221        );
222
223        // 1. Move target dimension to position 0
224        let coeffs_dim0 = movedim(coeffs, dim, 0);
225
226        // 2. Reshape to 2D: (basis_size, extra_size)
227        let extra_size: usize = coeffs_dim0.len() / basis_size;
228
229        let coeffs_2d_dyn = coeffs_dim0
230            .reshape(&[basis_size, extra_size][..])
231            .to_tensor();
232
233        // 3. Convert to DTensor and evaluate using evaluate_2d_real
234        let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
235            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
236        });
237        let coeffs_2d_view = coeffs_2d.view(.., ..);
238        let result_2d = self.fitter.evaluate_2d_real(backend, &coeffs_2d_view);
239
240        // 4. Reshape back to N-D with n_points at position 0
241        let n_points = self.n_sampling_points();
242        let mut result_shape = vec![n_points];
243        coeffs_dim0.shape().with_dims(|dims| {
244            for i in 1..dims.len() {
245                result_shape.push(dims[i]);
246            }
247        });
248
249        let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
250
251        // 5. Move dimension 0 back to original position dim
252        movedim(&result_dim0, 0, dim)
253    }
254
255    /// Evaluate N-D coefficients for the complex case `T = Complex<f64>`
256    fn evaluate_nd_impl_complex(
257        &self,
258        backend: Option<&GemmBackendHandle>,
259        coeffs: &Slice<Complex<f64>, DynRank>,
260        dim: usize,
261    ) -> Tensor<Complex<f64>, DynRank> {
262        let rank = coeffs.rank();
263        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
264
265        let basis_size = self.basis_size();
266        let target_dim_size = coeffs.shape().dim(dim);
267
268        assert_eq!(
269            target_dim_size, basis_size,
270            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
271            dim, target_dim_size, basis_size
272        );
273
274        // 1. Move target dimension to position 0
275        let coeffs_dim0 = movedim(coeffs, dim, 0);
276
277        // 2. Reshape to 2D: (basis_size, extra_size)
278        let extra_size: usize = coeffs_dim0.len() / basis_size;
279
280        let coeffs_2d_dyn = coeffs_dim0
281            .reshape(&[basis_size, extra_size][..])
282            .to_tensor();
283
284        // 3. Convert to DTensor and evaluate using evaluate_2d
285        let coeffs_2d = DTensor::<Complex<f64>, 2>::from_fn([basis_size, extra_size], |idx| {
286            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
287        });
288        let coeffs_2d_view = coeffs_2d.view(.., ..);
289        let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d_view);
290
291        // 4. Reshape back to N-D with n_points at position 0
292        let n_points = self.n_sampling_points();
293        let mut result_shape = vec![n_points];
294        coeffs_dim0.shape().with_dims(|dims| {
295            for i in 1..dims.len() {
296                result_shape.push(dims[i]);
297            }
298        });
299
300        let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
301
302        // 5. Move dimension 0 back to original position dim
303        movedim(&result_dim0, 0, dim)
304    }
305
306    /// Evaluate N-dimensional coefficients at Matsubara sampling points
307    ///
308    /// This method dispatches to the appropriate implementation based on the
309    /// coefficient type at compile time using the `MatsubaraCoeffs` trait.
310    ///
311    /// # Type Parameter
312    /// * `T` - Must implement `MatsubaraCoeffs` (currently `f64` or `Complex<f64>`)
313    ///
314    /// # Arguments
315    /// * `backend` - Optional GEMM backend handle
316    /// * `coeffs` - N-dimensional tensor of basis coefficients
317    /// * `dim` - Dimension along which to evaluate
318    ///
319    /// # Returns
320    /// N-dimensional tensor of complex values at Matsubara frequencies
321    pub fn evaluate_nd<T: MatsubaraCoeffs>(
322        &self,
323        backend: Option<&GemmBackendHandle>,
324        coeffs: &Slice<T, DynRank>,
325        dim: usize,
326    ) -> Tensor<Complex<f64>, DynRank> {
327        T::evaluate_nd_with(self, backend, coeffs, dim)
328    }
329
330    /// Evaluate real basis coefficients at Matsubara sampling points (N-dimensional)
331    ///
332    /// This method takes real coefficients and produces complex values, useful when
333    /// working with symmetry-exploiting representations or real-valued IR coefficients.
334    ///
335    /// # Arguments
336    /// * `backend` - Optional GEMM backend handle (None uses default)
337    /// * `coeffs` - N-dimensional tensor of real basis coefficients
338    /// * `dim` - Dimension along which to evaluate (must have size = basis_size)
339    ///
340    /// # Returns
341    /// N-dimensional tensor of complex values at Matsubara frequencies
342    pub fn evaluate_nd_real(
343        &self,
344        backend: Option<&GemmBackendHandle>,
345        coeffs: &Tensor<f64, DynRank>,
346        dim: usize,
347    ) -> Tensor<Complex<f64>, DynRank> {
348        let rank = coeffs.rank();
349        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
350
351        let basis_size = self.basis_size();
352        let target_dim_size = coeffs.shape().dim(dim);
353
354        assert_eq!(
355            target_dim_size, basis_size,
356            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
357            dim, target_dim_size, basis_size
358        );
359
360        // 1. Move target dimension to position 0
361        let coeffs_dim0 = movedim(coeffs, dim, 0);
362
363        // 2. Reshape to 2D: (basis_size, extra_size)
364        let extra_size: usize = coeffs_dim0.len() / basis_size;
365
366        let coeffs_2d_dyn = coeffs_dim0
367            .reshape(&[basis_size, extra_size][..])
368            .to_tensor();
369
370        // 3. Convert to DTensor and evaluate using ComplexMatrixFitter
371        let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
372            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
373        });
374
375        // 4. Evaluate: values = A * coeffs (A is complex, coeffs is real)
376        let coeffs_2d_view = coeffs_2d.view(.., ..);
377        let values_2d = self.fitter.evaluate_2d_real(backend, &coeffs_2d_view);
378
379        // 5. Reshape result back to N-D with first dimension = n_sampling_points
380        let n_points = self.n_sampling_points();
381        let mut result_shape = Vec::with_capacity(rank);
382        result_shape.push(n_points);
383        coeffs_dim0.shape().with_dims(|dims| {
384            for i in 1..dims.len() {
385                result_shape.push(dims[i]);
386            }
387        });
388
389        let result_dim0 = values_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
390
391        // 6. Move dimension 0 back to original position dim
392        movedim(&result_dim0, 0, dim)
393    }
394
395    /// Fit N-dimensional array of complex values to complex basis coefficients
396    ///
397    /// # Arguments
398    /// * `backend` - Optional GEMM backend handle (None uses default)
399    /// * `values` - N-dimensional tensor of complex values at Matsubara frequencies
400    /// * `dim` - Dimension along which to fit (must have size = n_sampling_points)
401    ///
402    /// # Returns
403    /// N-dimensional tensor of complex basis coefficients
404    pub fn fit_nd(
405        &self,
406        backend: Option<&GemmBackendHandle>,
407        values: &Tensor<Complex<f64>, DynRank>,
408        dim: usize,
409    ) -> Tensor<Complex<f64>, DynRank> {
410        let rank = values.rank();
411        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
412
413        let n_points = self.n_sampling_points();
414        let target_dim_size = values.shape().dim(dim);
415
416        assert_eq!(
417            target_dim_size, n_points,
418            "values.shape().dim({}) = {} must equal n_sampling_points = {}",
419            dim, target_dim_size, n_points
420        );
421
422        // 1. Move target dimension to position 0
423        let values_dim0 = movedim(values, dim, 0);
424
425        // 2. Reshape to 2D: (n_points, extra_size)
426        let extra_size: usize = values_dim0.len() / n_points;
427        let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
428
429        // 3. Convert to DTensor and fit using GEMM
430        let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
431            values_2d_dyn[&[idx[0], idx[1]][..]]
432        });
433
434        // Use fitter's efficient 2D fit (GEMM-based)
435        let values_2d_view = values_2d.view(.., ..);
436        let coeffs_2d = self.fitter.fit_2d(backend, &values_2d_view);
437
438        // 4. Reshape back to N-D with basis_size at position 0
439        let basis_size = self.basis_size();
440        let mut coeffs_shape = vec![basis_size];
441        values_dim0.shape().with_dims(|dims| {
442            for i in 1..dims.len() {
443                coeffs_shape.push(dims[i]);
444            }
445        });
446
447        let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
448
449        // 5. Move dimension 0 back to original position dim
450        movedim(&coeffs_dim0, 0, dim)
451    }
452
453    /// Fit N-dimensional array of complex values to real basis coefficients
454    ///
455    /// This method fits complex Matsubara values to real IR coefficients.
456    /// Takes the real part of the least-squares solution.
457    ///
458    /// # Arguments
459    /// * `backend` - Optional GEMM backend handle (None uses default)
460    /// * `values` - N-dimensional tensor of complex values at Matsubara frequencies
461    /// * `dim` - Dimension along which to fit (must have size = n_sampling_points)
462    ///
463    /// # Returns
464    /// N-dimensional tensor of real basis coefficients
465    pub fn fit_nd_real(
466        &self,
467        backend: Option<&GemmBackendHandle>,
468        values: &Tensor<Complex<f64>, DynRank>,
469        dim: usize,
470    ) -> Tensor<f64, DynRank> {
471        let rank = values.rank();
472        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
473
474        let n_points = self.n_sampling_points();
475        let target_dim_size = values.shape().dim(dim);
476
477        assert_eq!(
478            target_dim_size, n_points,
479            "values.shape().dim({}) = {} must equal n_sampling_points = {}",
480            dim, target_dim_size, n_points
481        );
482
483        // 1. Move target dimension to position 0
484        let values_dim0 = movedim(values, dim, 0);
485
486        // 2. Reshape to 2D: (n_points, extra_size)
487        let extra_size: usize = values_dim0.len() / n_points;
488        let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
489
490        // 3. Convert to DTensor and fit
491        let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
492            values_2d_dyn[&[idx[0], idx[1]][..]]
493        });
494
495        // Use fitter's fit_2d_real method
496        let values_2d_view = values_2d.view(.., ..);
497        let coeffs_2d = self.fitter.fit_2d_real(backend, &values_2d_view);
498
499        // 4. Reshape back to N-D with basis_size at position 0
500        let basis_size = self.basis_size();
501        let mut coeffs_shape = vec![basis_size];
502        values_dim0.shape().with_dims(|dims| {
503            for i in 1..dims.len() {
504                coeffs_shape.push(dims[i]);
505            }
506        });
507
508        let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
509
510        // 5. Move dimension 0 back to original position dim
511        movedim(&coeffs_dim0, 0, dim)
512    }
513
514    /// Evaluate basis coefficients at Matsubara sampling points (N-dimensional) with in-place output
515    ///
516    /// # Type Parameters
517    /// * `T` - Coefficient type (f64 or Complex<f64>)
518    ///
519    /// # Arguments
520    /// * `coeffs` - N-dimensional tensor with `coeffs.shape().dim(dim) == basis_size`
521    /// * `dim` - Dimension along which to evaluate (0-indexed)
522    /// * `out` - Output tensor with `out.shape().dim(dim) == n_sampling_points` (Complex<f64>)
523    pub fn evaluate_nd_to<T: MatsubaraCoeffs>(
524        &self,
525        backend: Option<&GemmBackendHandle>,
526        coeffs: &Slice<T, DynRank>,
527        dim: usize,
528        out: &mut Tensor<Complex<f64>, DynRank>,
529    ) {
530        // Validate output shape
531        let rank = coeffs.rank();
532        assert_eq!(
533            out.rank(),
534            rank,
535            "out.rank()={} must equal coeffs.rank()={}",
536            out.rank(),
537            rank
538        );
539
540        let n_points = self.n_sampling_points();
541        let out_dim_size = out.shape().dim(dim);
542        assert_eq!(
543            out_dim_size, n_points,
544            "out.shape().dim({}) = {} must equal n_sampling_points = {}",
545            dim, out_dim_size, n_points
546        );
547
548        // Validate other dimensions match
549        for d in 0..rank {
550            if d != dim {
551                let coeffs_d = coeffs.shape().dim(d);
552                let out_d = out.shape().dim(d);
553                assert_eq!(
554                    coeffs_d, out_d,
555                    "coeffs.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
556                    d, coeffs_d, d, out_d
557                );
558            }
559        }
560
561        // Compute result and copy to out
562        let result = self.evaluate_nd(backend, coeffs, dim);
563
564        // Copy result to out
565        let total = out.len();
566        for i in 0..total {
567            let mut idx = vec![0usize; rank];
568            let mut remaining = i;
569            for d in (0..rank).rev() {
570                let dim_size = out.shape().dim(d);
571                idx[d] = remaining % dim_size;
572                remaining /= dim_size;
573            }
574            out[&idx[..]] = result[&idx[..]];
575        }
576    }
577
578    /// Fit N-dimensional complex values to complex coefficients with in-place output
579    ///
580    /// # Arguments
581    /// * `values` - N-dimensional tensor with `values.shape().dim(dim) == n_sampling_points`
582    /// * `dim` - Dimension along which to fit (0-indexed)
583    /// * `out` - Output tensor with `out.shape().dim(dim) == basis_size` (Complex<f64>)
584    pub fn fit_nd_to(
585        &self,
586        backend: Option<&GemmBackendHandle>,
587        values: &Tensor<Complex<f64>, DynRank>,
588        dim: usize,
589        out: &mut Tensor<Complex<f64>, DynRank>,
590    ) {
591        // Validate output shape
592        let rank = values.rank();
593        assert_eq!(
594            out.rank(),
595            rank,
596            "out.rank()={} must equal values.rank()={}",
597            out.rank(),
598            rank
599        );
600
601        let basis_size = self.basis_size();
602        let out_dim_size = out.shape().dim(dim);
603        assert_eq!(
604            out_dim_size, basis_size,
605            "out.shape().dim({}) = {} must equal basis_size = {}",
606            dim, out_dim_size, basis_size
607        );
608
609        // Validate other dimensions match
610        for d in 0..rank {
611            if d != dim {
612                let values_d = values.shape().dim(d);
613                let out_d = out.shape().dim(d);
614                assert_eq!(
615                    values_d, out_d,
616                    "values.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
617                    d, values_d, d, out_d
618                );
619            }
620        }
621
622        // Compute result and copy to out
623        let result = self.fit_nd(backend, values, dim);
624
625        // Copy result to out
626        let total = out.len();
627        for i in 0..total {
628            let mut idx = vec![0usize; rank];
629            let mut remaining = i;
630            for d in (0..rank).rev() {
631                let dim_size = out.shape().dim(d);
632                idx[d] = remaining % dim_size;
633                remaining /= dim_size;
634            }
635            out[&idx[..]] = result[&idx[..]];
636        }
637    }
638}
639
640/// InplaceFitter implementation for MatsubaraSampling
641///
642/// Delegates to ComplexMatrixFitter which supports:
643/// - zz: Complex input → Complex output (full support)
644/// - dz: Real input → Complex output (evaluate only)
645/// - zd: Complex input → Real output (fit only, takes real part)
646impl<S: StatisticsType> InplaceFitter for MatsubaraSampling<S> {
647    fn n_points(&self) -> usize {
648        self.n_sampling_points()
649    }
650
651    fn basis_size(&self) -> usize {
652        self.basis_size()
653    }
654
655    fn evaluate_nd_dz_to(
656        &self,
657        backend: Option<&GemmBackendHandle>,
658        coeffs: &Slice<f64, DynRank>,
659        dim: usize,
660        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
661    ) -> bool {
662        self.fitter.evaluate_nd_dz_to(backend, coeffs, dim, out)
663    }
664
665    fn evaluate_nd_zz_to(
666        &self,
667        backend: Option<&GemmBackendHandle>,
668        coeffs: &Slice<Complex<f64>, DynRank>,
669        dim: usize,
670        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
671    ) -> bool {
672        self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
673    }
674
675    fn fit_nd_zd_to(
676        &self,
677        backend: Option<&GemmBackendHandle>,
678        values: &Slice<Complex<f64>, DynRank>,
679        dim: usize,
680        out: &mut ViewMut<'_, f64, DynRank>,
681    ) -> bool {
682        self.fitter.fit_nd_zd_to(backend, values, dim, out)
683    }
684
685    fn fit_nd_zz_to(
686        &self,
687        backend: Option<&GemmBackendHandle>,
688        values: &Slice<Complex<f64>, DynRank>,
689        dim: usize,
690        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
691    ) -> bool {
692        self.fitter.fit_nd_zz_to(backend, values, dim, out)
693    }
694}
695
696/// Matsubara sampling for positive frequencies only
697///
698/// Exploits symmetry to reconstruct real coefficients from positive frequencies only.
699/// Supports: {0, 1, 2, 3, ...} (no negative frequencies)
700pub struct MatsubaraSamplingPositiveOnly<S: StatisticsType> {
701    sampling_points: Vec<MatsubaraFreq<S>>,
702    fitter: ComplexToRealFitter,
703    _phantom: PhantomData<S>,
704}
705
706impl<S: StatisticsType> MatsubaraSamplingPositiveOnly<S> {
707    /// Create Matsubara sampling with default positive-only sampling points
708    ///
709    /// Uses extrema-based sampling point selection (positive frequencies only).
710    /// Exploits symmetry to reconstruct real coefficients.
711    pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
712    where
713        S: 'static,
714    {
715        let sampling_points = basis.default_matsubara_sampling_points(true);
716        Self::with_sampling_points(basis, sampling_points)
717    }
718
719    /// Create Matsubara sampling with custom positive-only sampling points
720    pub fn with_sampling_points(
721        basis: &impl crate::basis_trait::Basis<S>,
722        mut sampling_points: Vec<MatsubaraFreq<S>>,
723    ) -> Self
724    where
725        S: 'static,
726    {
727        // Sort and validate (all n >= 0)
728        sampling_points.sort();
729
730        // TODO: Validate that all points are non-negative
731
732        // Evaluate matrix at sampling points
733        // Use Basis trait's evaluate_matsubara method
734        let matrix = basis.evaluate_matsubara(&sampling_points);
735
736        // Create fitter (complex → real, exploits symmetry)
737        let fitter = ComplexToRealFitter::new(&matrix);
738
739        Self {
740            sampling_points,
741            fitter,
742            _phantom: PhantomData,
743        }
744    }
745
746    /// Create Matsubara sampling (positive-only) with custom sampling points and pre-computed matrix
747    ///
748    /// This constructor is useful when the sampling matrix is already computed.
749    /// Uses symmetry to fit real coefficients from complex values at positive frequencies.
750    ///
751    /// # Arguments
752    /// * `sampling_points` - Matsubara frequency sampling points (should be positive)
753    /// * `matrix` - Pre-computed sampling matrix (n_points × basis_size)
754    ///
755    /// # Returns
756    /// A new MatsubaraSamplingPositiveOnly object
757    ///
758    /// # Panics
759    /// Panics if `sampling_points` is empty or if matrix dimensions don't match
760    pub fn from_matrix(
761        mut sampling_points: Vec<MatsubaraFreq<S>>,
762        matrix: DTensor<Complex<f64>, 2>,
763    ) -> Self {
764        assert!(!sampling_points.is_empty(), "No sampling points given");
765        assert_eq!(
766            matrix.shape().0,
767            sampling_points.len(),
768            "Matrix rows ({}) must match number of sampling points ({})",
769            matrix.shape().0,
770            sampling_points.len()
771        );
772
773        // Sort sampling points
774        sampling_points.sort();
775
776        let fitter = ComplexToRealFitter::new(&matrix);
777
778        Self {
779            sampling_points,
780            fitter,
781            _phantom: PhantomData,
782        }
783    }
784
785    /// Get sampling points
786    pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
787        &self.sampling_points
788    }
789
790    /// Number of sampling points
791    pub fn n_sampling_points(&self) -> usize {
792        self.sampling_points.len()
793    }
794
795    /// Basis size
796    pub fn basis_size(&self) -> usize {
797        self.fitter.basis_size()
798    }
799
800    /// Get the original complex sampling matrix
801    pub fn matrix(&self) -> &DTensor<Complex<f64>, 2> {
802        &self.fitter.matrix
803    }
804
805    /// Evaluate basis coefficients at sampling points
806    pub fn evaluate(&self, coeffs: &[f64]) -> Vec<Complex<f64>> {
807        self.fitter.evaluate(None, coeffs)
808    }
809
810    /// Fit basis coefficients from values at sampling points
811    pub fn fit(&self, values: &[Complex<f64>]) -> Vec<f64> {
812        self.fitter.fit(None, values)
813    }
814
815    /// Evaluate N-dimensional array of real basis coefficients at sampling points
816    ///
817    /// # Arguments
818    /// * `coeffs` - N-dimensional tensor of real basis coefficients
819    /// * `dim` - Dimension along which to evaluate (must have size = basis_size)
820    ///
821    /// # Returns
822    /// N-dimensional tensor of complex values at Matsubara frequencies
823    pub fn evaluate_nd(
824        &self,
825        backend: Option<&GemmBackendHandle>,
826        coeffs: &Tensor<f64, DynRank>,
827        dim: usize,
828    ) -> Tensor<Complex<f64>, DynRank> {
829        let rank = coeffs.rank();
830        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
831
832        let basis_size = self.basis_size();
833        let target_dim_size = coeffs.shape().dim(dim);
834
835        assert_eq!(
836            target_dim_size, basis_size,
837            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
838            dim, target_dim_size, basis_size
839        );
840
841        // 1. Move target dimension to position 0
842        let coeffs_dim0 = movedim(coeffs, dim, 0);
843
844        // 2. Reshape to 2D: (basis_size, extra_size)
845        let extra_size: usize = coeffs_dim0.len() / basis_size;
846
847        let coeffs_2d_dyn = coeffs_dim0
848            .reshape(&[basis_size, extra_size][..])
849            .to_tensor();
850
851        // 3. Convert to DTensor and evaluate using GEMM
852        let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
853            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
854        });
855
856        // Use fitter's efficient 2D evaluate (GEMM-based)
857        let coeffs_2d_view = coeffs_2d.view(.., ..);
858        let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d_view);
859
860        // 4. Reshape back to N-D with n_points at position 0
861        let n_points = self.n_sampling_points();
862        let mut result_shape = vec![n_points];
863        coeffs_dim0.shape().with_dims(|dims| {
864            for i in 1..dims.len() {
865                result_shape.push(dims[i]);
866            }
867        });
868
869        let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
870
871        // 5. Move dimension 0 back to original position dim
872        movedim(&result_dim0, 0, dim)
873    }
874
875    /// Fit N-dimensional array of complex values to real basis coefficients
876    ///
877    /// # Arguments
878    /// * `backend` - Optional GEMM backend handle (None uses default)
879    /// * `values` - N-dimensional tensor of complex values at Matsubara frequencies
880    /// * `dim` - Dimension along which to fit (must have size = n_sampling_points)
881    ///
882    /// # Returns
883    /// N-dimensional tensor of real basis coefficients
884    pub fn fit_nd(
885        &self,
886        backend: Option<&GemmBackendHandle>,
887        values: &Tensor<Complex<f64>, DynRank>,
888        dim: usize,
889    ) -> Tensor<f64, DynRank> {
890        let rank = values.rank();
891        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
892
893        let n_points = self.n_sampling_points();
894        let target_dim_size = values.shape().dim(dim);
895
896        assert_eq!(
897            target_dim_size, n_points,
898            "values.shape().dim({}) = {} must equal n_sampling_points = {}",
899            dim, target_dim_size, n_points
900        );
901
902        // 1. Move target dimension to position 0
903        let values_dim0 = movedim(values, dim, 0);
904
905        // 2. Reshape to 2D: (n_points, extra_size)
906        let extra_size: usize = values_dim0.len() / n_points;
907        let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
908
909        // 3. Convert to DTensor and fit using GEMM
910        let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
911            values_2d_dyn[&[idx[0], idx[1]][..]]
912        });
913
914        // Use fitter's efficient 2D fit (GEMM-based)
915        let values_2d_view = values_2d.view(.., ..);
916        let coeffs_2d = self.fitter.fit_2d(backend, &values_2d_view);
917
918        // 4. Reshape back to N-D with basis_size at position 0
919        let basis_size = self.basis_size();
920        let mut coeffs_shape = vec![basis_size];
921        values_dim0.shape().with_dims(|dims| {
922            for i in 1..dims.len() {
923                coeffs_shape.push(dims[i]);
924            }
925        });
926
927        let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
928
929        // 5. Move dimension 0 back to original position dim
930        movedim(&coeffs_dim0, 0, dim)
931    }
932
933    /// Evaluate real basis coefficients at Matsubara sampling points (N-dimensional) with in-place output
934    ///
935    /// # Arguments
936    /// * `coeffs` - N-dimensional tensor of real coefficients with `coeffs.shape().dim(dim) == basis_size`
937    /// * `dim` - Dimension along which to evaluate (0-indexed)
938    /// * `out` - Output tensor with `out.shape().dim(dim) == n_sampling_points` (Complex<f64>)
939    pub fn evaluate_nd_to(
940        &self,
941        backend: Option<&GemmBackendHandle>,
942        coeffs: &Tensor<f64, DynRank>,
943        dim: usize,
944        out: &mut Tensor<Complex<f64>, DynRank>,
945    ) {
946        // Validate output shape
947        let rank = coeffs.rank();
948        assert_eq!(
949            out.rank(),
950            rank,
951            "out.rank()={} must equal coeffs.rank()={}",
952            out.rank(),
953            rank
954        );
955
956        let n_points = self.n_sampling_points();
957        let out_dim_size = out.shape().dim(dim);
958        assert_eq!(
959            out_dim_size, n_points,
960            "out.shape().dim({}) = {} must equal n_sampling_points = {}",
961            dim, out_dim_size, n_points
962        );
963
964        // Validate other dimensions match
965        for d in 0..rank {
966            if d != dim {
967                let coeffs_d = coeffs.shape().dim(d);
968                let out_d = out.shape().dim(d);
969                assert_eq!(
970                    coeffs_d, out_d,
971                    "coeffs.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
972                    d, coeffs_d, d, out_d
973                );
974            }
975        }
976
977        // Compute result and copy to out
978        let result = self.evaluate_nd(backend, coeffs, dim);
979
980        // Copy result to out
981        let total = out.len();
982        for i in 0..total {
983            let mut idx = vec![0usize; rank];
984            let mut remaining = i;
985            for d in (0..rank).rev() {
986                let dim_size = out.shape().dim(d);
987                idx[d] = remaining % dim_size;
988                remaining /= dim_size;
989            }
990            out[&idx[..]] = result[&idx[..]];
991        }
992    }
993
994    /// Fit N-dimensional complex values to real coefficients with in-place output
995    ///
996    /// # Arguments
997    /// * `values` - N-dimensional tensor with `values.shape().dim(dim) == n_sampling_points`
998    /// * `dim` - Dimension along which to fit (0-indexed)
999    /// * `out` - Output tensor with `out.shape().dim(dim) == basis_size` (f64)
1000    pub fn fit_nd_to(
1001        &self,
1002        backend: Option<&GemmBackendHandle>,
1003        values: &Tensor<Complex<f64>, DynRank>,
1004        dim: usize,
1005        out: &mut Tensor<f64, DynRank>,
1006    ) {
1007        // Validate output shape
1008        let rank = values.rank();
1009        assert_eq!(
1010            out.rank(),
1011            rank,
1012            "out.rank()={} must equal values.rank()={}",
1013            out.rank(),
1014            rank
1015        );
1016
1017        let basis_size = self.basis_size();
1018        let out_dim_size = out.shape().dim(dim);
1019        assert_eq!(
1020            out_dim_size, basis_size,
1021            "out.shape().dim({}) = {} must equal basis_size = {}",
1022            dim, out_dim_size, basis_size
1023        );
1024
1025        // Validate other dimensions match
1026        for d in 0..rank {
1027            if d != dim {
1028                let values_d = values.shape().dim(d);
1029                let out_d = out.shape().dim(d);
1030                assert_eq!(
1031                    values_d, out_d,
1032                    "values.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
1033                    d, values_d, d, out_d
1034                );
1035            }
1036        }
1037
1038        // Compute result and copy to out
1039        let result = self.fit_nd(backend, values, dim);
1040
1041        // Copy result to out
1042        let total = out.len();
1043        for i in 0..total {
1044            let mut idx = vec![0usize; rank];
1045            let mut remaining = i;
1046            for d in (0..rank).rev() {
1047                let dim_size = out.shape().dim(d);
1048                idx[d] = remaining % dim_size;
1049                remaining /= dim_size;
1050            }
1051            out[&idx[..]] = result[&idx[..]];
1052        }
1053    }
1054}
1055
1056/// InplaceFitter implementation for MatsubaraSamplingPositiveOnly
1057///
1058/// Delegates to ComplexToRealFitter which supports:
1059/// - dz: Real coefficients → Complex values (evaluate)
1060/// - zz: Complex coefficients → Complex values (evaluate, extracts real parts)
1061/// - zd: Complex values → Real coefficients (fit)
1062/// - zz: Complex values → Complex coefficients (fit, with zero imaginary parts)
1063impl<S: StatisticsType> InplaceFitter for MatsubaraSamplingPositiveOnly<S> {
1064    fn n_points(&self) -> usize {
1065        self.n_sampling_points()
1066    }
1067
1068    fn basis_size(&self) -> usize {
1069        self.basis_size()
1070    }
1071
1072    fn evaluate_nd_dz_to(
1073        &self,
1074        backend: Option<&GemmBackendHandle>,
1075        coeffs: &Slice<f64, DynRank>,
1076        dim: usize,
1077        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
1078    ) -> bool {
1079        self.fitter.evaluate_nd_dz_to(backend, coeffs, dim, out)
1080    }
1081
1082    fn evaluate_nd_zz_to(
1083        &self,
1084        backend: Option<&GemmBackendHandle>,
1085        coeffs: &Slice<Complex<f64>, DynRank>,
1086        dim: usize,
1087        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
1088    ) -> bool {
1089        self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
1090    }
1091
1092    fn fit_nd_zd_to(
1093        &self,
1094        backend: Option<&GemmBackendHandle>,
1095        values: &Slice<Complex<f64>, DynRank>,
1096        dim: usize,
1097        out: &mut ViewMut<'_, f64, DynRank>,
1098    ) -> bool {
1099        self.fitter.fit_nd_zd_to(backend, values, dim, out)
1100    }
1101
1102    fn fit_nd_zz_to(
1103        &self,
1104        backend: Option<&GemmBackendHandle>,
1105        values: &Slice<Complex<f64>, DynRank>,
1106        dim: usize,
1107        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
1108    ) -> bool {
1109        self.fitter.fit_nd_zz_to(backend, values, dim, out)
1110    }
1111}
1112
1113#[cfg(test)]
1114#[path = "matsubara_sampling_tests.rs"]
1115mod tests;