sparse_ir/
basis.rs

1//! Finite temperature basis for SparseIR
2//!
3//! This module provides the `FiniteTempBasis` type which represents the
4//! intermediate representation (IR) basis for a given temperature.
5
6use std::sync::Arc;
7
8use crate::kernel::{CentrosymmKernel, KernelProperties, LogisticKernel};
9use crate::poly::{PiecewiseLegendrePolyVector, default_sampling_points};
10use crate::polyfourier::PiecewiseLegendreFTVector;
11use crate::sve::{SVEResult, TworkType, compute_sve};
12use crate::traits::{Bosonic, Fermionic, StatisticsType};
13
/// Runtime statistics selector, intended for the C-API.
///
/// This is a plain value-level enum defined in this module. It is distinct from
/// the type-level marker types (`Fermionic`, `Bosonic`) used as the `S`
/// parameter of `FiniteTempBasis`, and from `crate::traits::Statistics`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Statistics {
    /// Fermionic statistics.
    Fermionic,
    /// Bosonic statistics.
    Bosonic,
}
20
21/// Finite temperature basis for imaginary time/frequency Green's functions
22///
23/// For a continuation kernel `K` from real frequencies `ω ∈ [-ωmax, ωmax]` to
24/// imaginary time `τ ∈ [0, β]`, this type stores the truncated singular
25/// value expansion or IR basis:
26///
27/// ```text
28/// K(τ, ω) ≈ sum(u[l](τ) * s[l] * v[l](ω) for l in 1:L)
29/// ```
30///
31/// This basis is inferred from a reduced form by appropriate scaling of
32/// the variables.
33///
34/// # Type Parameters
35///
36/// * `K` - Kernel type implementing `KernelProperties + CentrosymmKernel`
37/// * `S` - Statistics type (`Fermionic` or `Bosonic`)
38#[derive(Clone)]
39pub struct FiniteTempBasis<K, S>
40where
41    K: KernelProperties + CentrosymmKernel + Clone + 'static,
42    S: StatisticsType,
43{
44    /// The kernel used to construct this basis
45    pub kernel: K,
46
47    /// The SVE result (in scaled variables)
48    pub sve_result: Arc<SVEResult>,
49
50    /// Accuracy of the basis (relative error)
51    pub accuracy: f64,
52
53    /// Inverse temperature β
54    pub beta: f64,
55
56    /// Left singular functions on imaginary time axis τ ∈ [0, β]
57    /// Arc for efficient sharing (large immutable data)
58    pub u: Arc<PiecewiseLegendrePolyVector>,
59
60    /// Right singular functions on real frequency axis ω ∈ [-ωmax, ωmax]
61    /// Arc for efficient sharing (large immutable data)
62    pub v: Arc<PiecewiseLegendrePolyVector>,
63
64    /// Singular values
65    pub s: Vec<f64>,
66
67    /// Left singular functions on Matsubara frequency axis (Fourier transform of u)
68    /// Arc for efficient sharing (large immutable data)
69    pub uhat: Arc<PiecewiseLegendreFTVector<S>>,
70
71    /// Full uhat (before truncation to basis size)
72    /// Arc for efficient sharing (large immutable data, used for Matsubara sampling)
73    pub uhat_full: Arc<PiecewiseLegendreFTVector<S>>,
74
75    _phantom: std::marker::PhantomData<S>,
76}
77
78impl<K, S> FiniteTempBasis<K, S>
79where
80    K: KernelProperties + CentrosymmKernel + Clone + 'static,
81    S: StatisticsType,
82{
83    /// Get the frequency cutoff ωmax
84    pub fn wmax(&self) -> f64 {
85        self.kernel.lambda() / self.beta
86    }
87
88    /// Get default Matsubara sampling points as i64 indices (for C-API)
89    pub fn default_matsubara_sampling_points_i64(&self, positive_only: bool) -> Vec<i64>
90    where
91        S: 'static,
92    {
93        let freqs = self.default_matsubara_sampling_points(positive_only);
94        freqs.into_iter().map(|f| f.n()).collect()
95    }
96
97    /// Get default Matsubara sampling points as i64 indices with mitigate parameter (for C-API)
98    pub fn default_matsubara_sampling_points_i64_with_mitigate(
99        &self,
100        positive_only: bool,
101        mitigate: bool,
102        n_points: usize,
103    ) -> Vec<i64>
104    where
105        S: 'static,
106    {
107        let fence = mitigate;
108        let freqs = Self::default_matsubara_sampling_points_impl(
109            &self.uhat_full,
110            n_points,
111            fence,
112            positive_only,
113        );
114        freqs.into_iter().map(|f| f.n()).collect()
115    }
116
117    /// Create a new FiniteTempBasis
118    ///
119    /// # Arguments
120    ///
121    /// * `kernel` - Kernel implementing `KernelProperties + CentrosymmKernel`
122    /// * `beta` - Inverse temperature (β > 0)
123    /// * `epsilon` - Accuracy parameter (optional, defaults to NaN for auto)
124    /// * `max_size` - Maximum number of basis functions (optional)
125    ///
126    /// # Returns
127    ///
128    /// A new FiniteTempBasis
129    pub fn new(kernel: K, beta: f64, epsilon: Option<f64>, max_size: Option<usize>) -> Self {
130        // Validate inputs
131        if beta <= 0.0 {
132            panic!("Inverse temperature beta must be positive, got {}", beta);
133        }
134
135        // Compute SVE
136        let epsilon_value = epsilon.unwrap_or(f64::NAN);
137        let sve_result = compute_sve(
138            kernel.clone(),
139            epsilon_value,
140            None, // cutoff
141            max_size,
142            TworkType::Auto,
143        );
144
145        Self::from_sve_result(kernel, beta, sve_result, epsilon, max_size)
146    }
147
148    /// Create basis from existing SVE result
149    ///
150    /// This is useful when you want to reuse the same SVE computation
151    /// for both fermionic and bosonic bases.
152    pub fn from_sve_result(
153        kernel: K,
154        beta: f64,
155        sve_result: SVEResult,
156        epsilon: Option<f64>,
157        max_size: Option<usize>,
158    ) -> Self {
159        // Get truncated u, s, v from SVE result
160        let (u_sve, s_sve, v_sve) = sve_result.part(epsilon, max_size);
161
162        // Calculate accuracy
163        let accuracy = if sve_result.s.len() > s_sve.len() {
164            sve_result.s[s_sve.len()] / sve_result.s[0]
165        } else {
166            sve_result.s[sve_result.s.len() - 1] / sve_result.s[0]
167        };
168
169        // Get kernel parameters
170        let lambda = kernel.lambda();
171        let omega_max = lambda / beta;
172
173        // Scale polynomials to new variables
174        // tau = β/2 * (x + 1), w = ωmax * y
175
176        // Transform u: x ∈ [-1, 1] → τ ∈ [0, β]
177        let u_knots: Vec<f64> = u_sve.get_polys()[0]
178            .knots
179            .iter()
180            .map(|&x| beta / 2.0 * (x + 1.0))
181            .collect();
182        let u_delta_x: Vec<f64> = u_sve.get_polys()[0]
183            .delta_x
184            .iter()
185            .map(|&dx| beta / 2.0 * dx)
186            .collect();
187        let u_symm: Vec<i32> = u_sve.get_polys().iter().map(|p| p.symm).collect();
188
189        let u = u_sve.rescale_domain(u_knots, Some(u_delta_x), Some(u_symm));
190
191        // Transform v: y ∈ [-1, 1] → ω ∈ [-ωmax, ωmax]
192        let v_knots: Vec<f64> = v_sve.get_polys()[0]
193            .knots
194            .iter()
195            .map(|&y| omega_max * y)
196            .collect();
197        let v_delta_x: Vec<f64> = v_sve.get_polys()[0]
198            .delta_x
199            .iter()
200            .map(|&dy| omega_max * dy)
201            .collect();
202        let v_symm: Vec<i32> = v_sve.get_polys().iter().map(|p| p.symm).collect();
203
204        let v = v_sve.rescale_domain(v_knots, Some(v_delta_x), Some(v_symm));
205
206        // Scale singular values
207        // s_scaled = sqrt(β/2 * ωmax) * ωmax^(-ypower) * s_sve
208        let ypower = kernel.ypower();
209        let scale_factor = (beta / 2.0 * omega_max).sqrt() * omega_max.powi(-ypower);
210        let s: Vec<f64> = s_sve.iter().map(|&x| scale_factor * x).collect();
211
212        // Construct uhat (Fourier transform of u)
213        // HACK: Fourier transforms only work on unit interval, so we scale the data
214        let uhat_base_full = sve_result.u.scale_data(beta.sqrt());
215        let conv_rad = kernel.conv_radius();
216
217        // Create statistics instance - we need a value of type S
218        // For Fermionic: S = Fermionic, for Bosonic: S = Bosonic
219        let stat_marker = unsafe { std::mem::zeroed::<S>() };
220
221        let uhat_full = PiecewiseLegendreFTVector::<S>::from_poly_vector(
222            &uhat_base_full,
223            stat_marker,
224            Some(conv_rad),
225        );
226
227        // Truncate uhat to basis size
228        let uhat_polyvec: Vec<_> = uhat_full.polyvec.iter().take(s.len()).cloned().collect();
229        let uhat = PiecewiseLegendreFTVector::from_vector(uhat_polyvec);
230
231        Self {
232            kernel,
233            sve_result: Arc::new(sve_result),
234            accuracy,
235            beta,
236            u: Arc::new(u),
237            v: Arc::new(v),
238            s,
239            uhat: Arc::new(uhat),
240            uhat_full: Arc::new(uhat_full),
241            _phantom: std::marker::PhantomData,
242        }
243    }
244
245    /// Get the size of the basis (number of basis functions)
246    pub fn size(&self) -> usize {
247        self.s.len()
248    }
249
250    /// Get the cutoff parameter Λ = β * ωmax
251    pub fn lambda(&self) -> f64 {
252        self.kernel.lambda()
253    }
254
255    /// Get the frequency cutoff ωmax
256    pub fn omega_max(&self) -> f64 {
257        self.lambda() / self.beta
258    }
259
260    /// Get significance of each singular value (s[i] / s[0])
261    pub fn significance(&self) -> Vec<f64> {
262        let s0 = self.s[0];
263        self.s.iter().map(|&s| s / s0).collect()
264    }
265
266    /// Get default tau sampling points
267    ///
268    /// C++ implementation: libsparseir/include/sparseir/basis.hpp:229-270
269    ///
270    /// Returns sampling points in imaginary time τ ∈ [-β/2, β/2].
271    pub fn default_tau_sampling_points(&self) -> Vec<f64> {
272        self.default_tau_sampling_points_size_requested(self.size())
273    }
274
275    pub fn default_tau_sampling_points_size_requested(&self, size_requested: usize) -> Vec<f64> {
276        // C++: Eigen::VectorXd x = default_sampling_points(*(this->sve_result->u), sz);
277        let x = default_sampling_points(&self.sve_result.u, size_requested);
278        // C++: Extract unique half of sampling points
279        let mut unique_x = Vec::new();
280        if x.len() % 2 == 0 {
281            // C++: for (auto i = 0; i < x.size() / 2; ++i)
282            for i in 0..(x.len() / 2) {
283                unique_x.push(x[i]);
284            }
285        } else {
286            // C++: for (auto i = 0; i < x.size() / 2; ++i)
287            for i in 0..(x.len() / 2) {
288                unique_x.push(x[i]);
289            }
290            // C++: auto x_new = 0.5 * (unique_x.back() + 0.5);
291            let x_new = 0.5 * (unique_x.last().unwrap() + 0.5);
292            unique_x.push(x_new);
293        }
294
295        // C++: Generate symmetric points
296        //      Eigen::VectorXd smpl_taus(2 * unique_x.size());
297        //      for (auto i = 0; i < unique_x.size(); ++i) {
298        //          smpl_taus(i) = (this->beta / 2.0) * (unique_x[i] + 1.0);
299        //          smpl_taus(unique_x.size() + i) = -smpl_taus(i);
300        //      }
301        let mut smpl_taus = Vec::with_capacity(2 * unique_x.len());
302        for &ux in &unique_x {
303            smpl_taus.push((self.beta / 2.0) * (ux + 1.0));
304        }
305        for i in 0..unique_x.len() {
306            smpl_taus.push(-smpl_taus[i]);
307        }
308
309        // C++: std::sort(smpl_taus.data(), smpl_taus.data() + smpl_taus.size());
310        smpl_taus.sort_by(|a, b| a.partial_cmp(b).unwrap());
311
312        // C++: Check if the number of sampling points is even
313        if smpl_taus.len() % 2 != 0 {
314            panic!("The number of tau sampling points is odd!");
315        }
316
317        // C++: Check if tau = 0 is not in the sampling points
318        for &tau in &smpl_taus {
319            if tau.abs() < 1e-10 {
320                eprintln!(
321                    "Warning: tau = 0 is in the sampling points (absolute error: {})",
322                    tau.abs()
323                );
324            }
325        }
326
327        // C++ implementation returns tau in [-beta/2, beta/2] (does NOT convert to [0, beta])
328        // This matches the natural range for centrosymmetric kernels
329        smpl_taus
330    }
331
332    /// Get default Matsubara frequency sampling points
333    ///
334    /// Returns sampling points as MatsubaraFreq objects based on extrema
335    /// of the Matsubara basis functions (same algorithm as C++/Julia).
336    ///
337    /// # Arguments
338    /// * `positive_only` - If true, returns only non-negative frequencies
339    ///
340    /// # Returns
341    /// Vector of Matsubara frequency sampling points
342    pub fn default_matsubara_sampling_points(
343        &self,
344        positive_only: bool,
345    ) -> Vec<crate::freq::MatsubaraFreq<S>>
346    where
347        S: 'static,
348    {
349        let fence = false;
350        Self::default_matsubara_sampling_points_impl(
351            &self.uhat_full,
352            self.size(),
353            fence,
354            positive_only,
355        )
356    }
357
358    /// Fence Matsubara sampling points to improve conditioning
359    ///
360    /// This function adds additional sampling points near the outer frequencies
361    /// to improve the conditioning of the sampling matrix. This is particularly
362    /// important for Matsubara sampling where we cannot freely choose sampling points.
363    ///
364    /// Implementation matches C++ version in `basis.hpp` (lines 407-452).
365    fn fence_matsubara_sampling(
366        omega_n: &mut Vec<crate::freq::MatsubaraFreq<S>>,
367        positive_only: bool,
368    ) where
369        S: StatisticsType + 'static,
370    {
371        use crate::freq::{BosonicFreq, MatsubaraFreq};
372
373        if omega_n.is_empty() {
374            return;
375        }
376
377        // Collect outer frequencies
378        let mut outer_frequencies = Vec::new();
379        if positive_only {
380            outer_frequencies.push(omega_n[omega_n.len() - 1]);
381        } else {
382            outer_frequencies.push(omega_n[0]);
383            outer_frequencies.push(omega_n[omega_n.len() - 1]);
384        }
385
386        for wn_outer in outer_frequencies {
387            let outer_val = wn_outer.n();
388            // In SparseIR.jl-v1, ωn_diff is always created as BosonicFreq
389            // This ensures diff_val is always even (valid for Bosonic)
390            let mut diff_val = 2 * (0.025 * outer_val as f64).round() as i64;
391
392            // Handle edge case: if diff_val is 0, set it to 2 (minimum even value for Bosonic)
393            if diff_val == 0 {
394                diff_val = 2;
395            }
396
397            // Get the n value from BosonicFreq (same as diff_val since it's even)
398            let wn_diff = BosonicFreq::new(diff_val).unwrap().n();
399
400            // Sign function: returns +1 if n > 0, -1 if n < 0, 0 if n == 0
401            // Matches C++ implementation: (a.get_n() > 0) - (a.get_n() < 0)
402            let sign_val = if outer_val > 0 {
403                1
404            } else if outer_val < 0 {
405                -1
406            } else {
407                0
408            };
409
410            // Check original size before adding (C++ checks wn.size() before each push)
411            let original_size = omega_n.len();
412            if original_size >= 20 {
413                // For Fermionic: wn_outer.n is odd, wn_diff is even, so wn_outer.n ± wn_diff is odd (valid)
414                // For Bosonic: wn_outer.n is even, wn_diff is even, so wn_outer.n ± wn_diff is even (valid)
415                let new_n = outer_val - sign_val * wn_diff;
416                if let Ok(new_freq) = MatsubaraFreq::<S>::new(new_n) {
417                    omega_n.push(new_freq);
418                }
419            }
420            if original_size >= 42 {
421                let new_n = outer_val + sign_val * wn_diff;
422                if let Ok(new_freq) = MatsubaraFreq::<S>::new(new_n) {
423                    omega_n.push(new_freq);
424                }
425            }
426        }
427
428        // Sort and remove duplicates using BTreeSet
429        let omega_n_set: std::collections::BTreeSet<MatsubaraFreq<S>> = omega_n.drain(..).collect();
430        *omega_n = omega_n_set.into_iter().collect();
431    }
432
433    pub fn default_matsubara_sampling_points_impl(
434        uhat_full: &PiecewiseLegendreFTVector<S>,
435        l: usize,
436        fence: bool,
437        positive_only: bool,
438    ) -> Vec<crate::freq::MatsubaraFreq<S>>
439    where
440        S: StatisticsType + 'static,
441    {
442        use crate::freq::MatsubaraFreq;
443        use crate::polyfourier::{find_extrema, sign_changes};
444        use std::collections::BTreeSet;
445
446        let mut l_requested = l;
447
448        // Adjust l_requested based on statistics (same as C++)
449        if S::STATISTICS == crate::traits::Statistics::Fermionic && l_requested % 2 != 0 {
450            l_requested += 1;
451        } else if S::STATISTICS == crate::traits::Statistics::Bosonic && l_requested % 2 == 0 {
452            l_requested += 1;
453        }
454
455        // Choose sign_changes or find_extrema based on l_requested
456        let mut omega_n = if l_requested < uhat_full.len() {
457            sign_changes(&uhat_full[l_requested], positive_only)
458        } else {
459            find_extrema(&uhat_full[uhat_full.len() - 1], positive_only)
460        };
461
462        // For bosons, include zero frequency explicitly to prevent conditioning issues
463        if S::STATISTICS == crate::traits::Statistics::Bosonic {
464            omega_n.push(MatsubaraFreq::<S>::new(0).unwrap());
465        }
466
467        // Sort and remove duplicates using BTreeSet
468        let omega_n_set: BTreeSet<MatsubaraFreq<S>> = omega_n.into_iter().collect();
469        let mut omega_n: Vec<MatsubaraFreq<S>> = omega_n_set.into_iter().collect();
470
471        // Check expected size
472        let expected_size = if positive_only {
473            l_requested.div_ceil(2)
474        } else {
475            l_requested
476        };
477
478        if omega_n.len() != expected_size {
479            eprintln!(
480                "Warning: Requested {} sampling frequencies for basis size L = {}, but got {}.",
481                expected_size,
482                l,
483                omega_n.len()
484            );
485        }
486
487        // Apply fencing if requested (same as C++ implementation)
488        if fence {
489            Self::fence_matsubara_sampling(&mut omega_n, positive_only);
490        }
491
492        omega_n
493    }
494    /// Get default omega (real frequency) sampling points
495    ///
496    /// Returns sampling points on the real-frequency axis ω ∈ [-ωmax, ωmax].
497    /// These are used as pole locations for the Discrete Lehmann Representation (DLR).
498    ///
499    /// The sampling points are chosen as the roots of the L-th basis function
500    /// in the spectral domain (v), which provides near-optimal conditioning.
501    ///
502    /// # Returns
503    /// Vector of real-frequency sampling points in [-ωmax, ωmax]
504    pub fn default_omega_sampling_points(&self) -> Vec<f64> {
505        let sz = self.size();
506
507        // Use UNTRUNCATED sve_result.v (same as C++)
508        // C++: default_sampling_points(*(sve_result->v), sz)
509        let y = default_sampling_points(&self.sve_result.v, sz);
510
511        // Scale to [-ωmax, ωmax]
512        let wmax = self.kernel.lambda() / self.beta;
513        let omega_points: Vec<f64> = y.into_iter().map(|yi| wmax * yi).collect();
514
515        omega_points
516    }
517}
518
519// ============================================================================
520// Trait implementations
521// ============================================================================
522
523impl<K, S> crate::basis_trait::Basis<S> for FiniteTempBasis<K, S>
524where
525    K: KernelProperties + CentrosymmKernel + Clone + 'static,
526    S: StatisticsType + 'static,
527{
528    type Kernel = K;
529
530    fn kernel(&self) -> &Self::Kernel {
531        &self.kernel
532    }
533
534    fn beta(&self) -> f64 {
535        self.beta
536    }
537
538    fn wmax(&self) -> f64 {
539        self.kernel.lambda() / self.beta
540    }
541
542    fn lambda(&self) -> f64 {
543        self.kernel.lambda()
544    }
545
546    fn size(&self) -> usize {
547        self.size()
548    }
549
550    fn accuracy(&self) -> f64 {
551        self.accuracy
552    }
553
554    fn significance(&self) -> Vec<f64> {
555        if let Some(&first_s) = self.s.first() {
556            self.s.iter().map(|&s| s / first_s).collect()
557        } else {
558            vec![]
559        }
560    }
561
562    fn svals(&self) -> Vec<f64> {
563        self.s.clone()
564    }
565
566    fn default_tau_sampling_points(&self) -> Vec<f64> {
567        self.default_tau_sampling_points()
568    }
569
570    fn default_matsubara_sampling_points(
571        &self,
572        positive_only: bool,
573    ) -> Vec<crate::freq::MatsubaraFreq<S>> {
574        self.default_matsubara_sampling_points(positive_only)
575    }
576
577    fn evaluate_tau(&self, tau: &[f64]) -> mdarray::DTensor<f64, 2> {
578        use crate::taufuncs::normalize_tau;
579        use mdarray::DTensor;
580
581        let n_points = tau.len();
582        let basis_size = self.size();
583
584        // Evaluate each basis function at all tau points
585        // Result: matrix[i, l] = u_l(tau[i])
586        // Note: tau can be in [-beta, beta] and will be normalized to [0, beta]
587        // self.u polynomials are already scaled to tau ∈ [0, beta] domain
588        DTensor::<f64, 2>::from_fn([n_points, basis_size], |idx| {
589            let i = idx[0]; // tau index
590            let l = idx[1]; // basis function index
591
592            // Normalize tau to [0, beta] with statistics-dependent sign
593            let (tau_norm, sign) = normalize_tau::<S>(tau[i], self.beta);
594
595            // Evaluate basis function directly (u polynomials are in tau domain)
596            sign * self.u[l].evaluate(tau_norm)
597        })
598    }
599
600    fn evaluate_matsubara(
601        &self,
602        freqs: &[crate::freq::MatsubaraFreq<S>],
603    ) -> mdarray::DTensor<num_complex::Complex<f64>, 2> {
604        use mdarray::DTensor;
605        use num_complex::Complex;
606
607        let n_points = freqs.len();
608        let basis_size = self.size();
609
610        // Evaluate each basis function at all Matsubara frequencies
611        // Result: matrix[i, l] = uhat_l(iωn[i])
612        DTensor::<Complex<f64>, 2>::from_fn([n_points, basis_size], |idx| {
613            let i = idx[0]; // frequency index
614            let l = idx[1]; // basis function index
615            self.uhat[l].evaluate(&freqs[i])
616        })
617    }
618
619    fn evaluate_omega(&self, omega: &[f64]) -> mdarray::DTensor<f64, 2> {
620        use mdarray::DTensor;
621
622        let n_points = omega.len();
623        let basis_size = self.size();
624
625        // Evaluate each spectral basis function at all omega points
626        // Result: matrix[i, l] = V_l(omega[i])
627        DTensor::<f64, 2>::from_fn([n_points, basis_size], |idx| {
628            let i = idx[0]; // omega index
629            let l = idx[1]; // basis function index
630            self.v[l].evaluate(omega[i])
631        })
632    }
633
634    fn default_omega_sampling_points(&self) -> Vec<f64> {
635        self.default_omega_sampling_points()
636    }
637}
638
639// ============================================================================
640// Type aliases
641// ============================================================================
642
643/// Type alias for fermionic basis with LogisticKernel
644pub type FermionicBasis = FiniteTempBasis<LogisticKernel, Fermionic>;
645
646/// Type alias for bosonic basis with LogisticKernel
647pub type BosonicBasis = FiniteTempBasis<LogisticKernel, Bosonic>;
648
649#[cfg(test)]
650#[path = "basis_tests.rs"]
651mod basis_tests;