// sparse_ir/basis.rs

//! Finite temperature basis for SparseIR
//!
//! This module provides the `FiniteTempBasis` type, which represents the
//! intermediate representation (IR) basis for a given temperature.

6use std::sync::Arc;
7
8use crate::kernel::{CentrosymmKernel, KernelProperties, LogisticKernel};
9use crate::poly::{PiecewiseLegendrePolyVector, default_sampling_points};
10use crate::polyfourier::PiecewiseLegendreFTVector;
11use crate::sve::{SVEResult, TworkType, compute_sve};
12use crate::traits::{Bosonic, Fermionic, StatisticsType};
13
/// Particle statistics selector exposed for the C-API.
///
/// This is a standalone definition in this module; it is distinct from
/// `crate::traits::Statistics`, which the generic code reads through
/// `S::STATISTICS`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum Statistics {
    /// Fermionic statistics (odd Matsubara indices).
    Fermionic,
    /// Bosonic statistics (even Matsubara indices).
    Bosonic,
}
20
21/// Finite temperature basis for imaginary time/frequency Green's functions
22///
23/// For a continuation kernel `K` from real frequencies `ω ∈ [-ωmax, ωmax]` to
24/// imaginary time `τ ∈ [0, β]`, this type stores the truncated singular
25/// value expansion or IR basis:
26///
27/// ```text
28/// K(τ, ω) ≈ sum(u[l](τ) * s[l] * v[l](ω) for l in 1:L)
29/// ```
30///
31/// This basis is inferred from a reduced form by appropriate scaling of
32/// the variables.
33///
34/// # Type Parameters
35///
36/// * `K` - Kernel type implementing `KernelProperties + CentrosymmKernel`
37/// * `S` - Statistics type (`Fermionic` or `Bosonic`)
38#[derive(Clone)]
39pub struct FiniteTempBasis<K, S>
40where
41    K: KernelProperties + CentrosymmKernel + Clone + 'static,
42    S: StatisticsType,
43{
44    /// The kernel used to construct this basis
45    pub kernel: K,
46
47    /// The SVE result (in scaled variables)
48    pub sve_result: Arc<SVEResult>,
49
50    /// Accuracy of the basis (relative error)
51    pub accuracy: f64,
52
53    /// Inverse temperature β
54    pub beta: f64,
55
56    /// Left singular functions on imaginary time axis τ ∈ [0, β]
57    /// Arc for efficient sharing (large immutable data)
58    pub u: Arc<PiecewiseLegendrePolyVector>,
59
60    /// Right singular functions on real frequency axis ω ∈ [-ωmax, ωmax]
61    /// Arc for efficient sharing (large immutable data)
62    pub v: Arc<PiecewiseLegendrePolyVector>,
63
64    /// Singular values
65    pub s: Vec<f64>,
66
67    /// Left singular functions on Matsubara frequency axis (Fourier transform of u)
68    /// Arc for efficient sharing (large immutable data)
69    pub uhat: Arc<PiecewiseLegendreFTVector<S>>,
70
71    /// Full uhat (before truncation to basis size)
72    /// Arc for efficient sharing (large immutable data, used for Matsubara sampling)
73    pub uhat_full: Arc<PiecewiseLegendreFTVector<S>>,
74
75    _phantom: std::marker::PhantomData<S>,
76}
77
78impl<K, S> FiniteTempBasis<K, S>
79where
80    K: KernelProperties + CentrosymmKernel + Clone + 'static,
81    S: StatisticsType,
82{
83    /// Get the frequency cutoff ωmax
84    pub fn wmax(&self) -> f64 {
85        self.kernel.lambda() / self.beta
86    }
87
88    /// Get default Matsubara sampling points as i64 indices (for C-API)
89    pub fn default_matsubara_sampling_points_i64(&self, positive_only: bool) -> Vec<i64>
90    where
91        S: 'static,
92    {
93        let freqs = self.default_matsubara_sampling_points(positive_only);
94        freqs.into_iter().map(|f| f.n()).collect()
95    }
96
97    /// Get default Matsubara sampling points as i64 indices with mitigate parameter (for C-API)
98    pub fn default_matsubara_sampling_points_i64_with_mitigate(
99        &self,
100        positive_only: bool,
101        mitigate: bool,
102        n_points: usize,
103    ) -> Vec<i64>
104    where
105        S: 'static,
106    {
107        let fence = mitigate;
108        let freqs = Self::default_matsubara_sampling_points_impl(
109            &self.uhat_full,
110            n_points,
111            fence,
112            positive_only,
113        );
114        freqs.into_iter().map(|f| f.n()).collect()
115    }
116
117    /// Create a new FiniteTempBasis
118    ///
119    /// # Arguments
120    ///
121    /// * `kernel` - Kernel implementing `KernelProperties + CentrosymmKernel`
122    /// * `beta` - Inverse temperature (β > 0)
123    /// * `epsilon` - Accuracy parameter (optional, defaults to NaN for auto)
124    /// * `max_size` - Maximum number of basis functions (optional)
125    ///
126    /// # Returns
127    ///
128    /// A new FiniteTempBasis
129    pub fn new(kernel: K, beta: f64, epsilon: Option<f64>, max_size: Option<usize>) -> Self {
130        // Validate inputs
131        if beta <= 0.0 {
132            panic!("Inverse temperature beta must be positive, got {}", beta);
133        }
134
135        // Compute SVE
136        let epsilon_value = epsilon.unwrap_or(f64::NAN);
137        let sve_result = compute_sve(
138            kernel.clone(),
139            epsilon_value,
140            None, // cutoff
141            max_size,
142            TworkType::Auto,
143        );
144
145        Self::from_sve_result(kernel, beta, sve_result, epsilon, max_size)
146    }
147
148    /// Create basis from existing SVE result
149    ///
150    /// This is useful when you want to reuse the same SVE computation
151    /// for both fermionic and bosonic bases.
152    pub fn from_sve_result(
153        kernel: K,
154        beta: f64,
155        sve_result: SVEResult,
156        epsilon: Option<f64>,
157        max_size: Option<usize>,
158    ) -> Self {
159        // Get truncated u, s, v from SVE result
160        let (u_sve, s_sve, v_sve) = sve_result.part(epsilon, max_size);
161
162        // Calculate accuracy
163        let accuracy = if sve_result.s.len() > s_sve.len() {
164            sve_result.s[s_sve.len()] / sve_result.s[0]
165        } else {
166            sve_result.s[sve_result.s.len() - 1] / sve_result.s[0]
167        };
168
169        // Get kernel parameters
170        let lambda = kernel.lambda();
171        let omega_max = lambda / beta;
172
173        // Scale polynomials to new variables
174        // tau = β/2 * (x + 1), w = ωmax * y
175
176        // Transform u: x ∈ [-1, 1] → τ ∈ [0, β]
177        let u_knots: Vec<f64> = u_sve.get_polys()[0]
178            .knots
179            .iter()
180            .map(|&x| beta / 2.0 * (x + 1.0))
181            .collect();
182        let u_delta_x: Vec<f64> = u_sve.get_polys()[0]
183            .delta_x
184            .iter()
185            .map(|&dx| beta / 2.0 * dx)
186            .collect();
187        let u_symm: Vec<i32> = u_sve.get_polys().iter().map(|p| p.symm).collect();
188
189        let u = u_sve.rescale_domain(u_knots, Some(u_delta_x), Some(u_symm));
190
191        // Transform v: y ∈ [-1, 1] → ω ∈ [-ωmax, ωmax]
192        let v_knots: Vec<f64> = v_sve.get_polys()[0]
193            .knots
194            .iter()
195            .map(|&y| omega_max * y)
196            .collect();
197        let v_delta_x: Vec<f64> = v_sve.get_polys()[0]
198            .delta_x
199            .iter()
200            .map(|&dy| omega_max * dy)
201            .collect();
202        let v_symm: Vec<i32> = v_sve.get_polys().iter().map(|p| p.symm).collect();
203
204        let v = v_sve.rescale_domain(v_knots, Some(v_delta_x), Some(v_symm));
205
206        // Scale singular values
207        // s_scaled = sqrt(β/2 * ωmax) * ωmax^(-ypower) * s_sve
208        let ypower = kernel.ypower();
209        let scale_factor = (beta / 2.0 * omega_max).sqrt() * omega_max.powi(-ypower);
210        let s: Vec<f64> = s_sve.iter().map(|&x| scale_factor * x).collect();
211
212        // Construct uhat (Fourier transform of u)
213        // HACK: Fourier transforms only work on unit interval, so we scale the data
214        let uhat_base_full = sve_result.u.scale_data(beta.sqrt());
215        let conv_rad = kernel.conv_radius();
216
217        // Create statistics instance - we need a value of type S
218        // For Fermionic: S = Fermionic, for Bosonic: S = Bosonic
219        let stat_marker = unsafe { std::mem::zeroed::<S>() };
220
221        let uhat_full = PiecewiseLegendreFTVector::<S>::from_poly_vector(
222            &uhat_base_full,
223            stat_marker,
224            Some(conv_rad),
225        );
226
227        // Truncate uhat to basis size
228        let uhat_polyvec: Vec<_> = uhat_full.polyvec.iter().take(s.len()).cloned().collect();
229        let uhat = PiecewiseLegendreFTVector::from_vector(uhat_polyvec);
230
231        Self {
232            kernel,
233            sve_result: Arc::new(sve_result),
234            accuracy,
235            beta,
236            u: Arc::new(u),
237            v: Arc::new(v),
238            s,
239            uhat: Arc::new(uhat),
240            uhat_full: Arc::new(uhat_full),
241            _phantom: std::marker::PhantomData,
242        }
243    }
244
245    /// Get the size of the basis (number of basis functions)
246    pub fn size(&self) -> usize {
247        self.s.len()
248    }
249
250    /// Get the cutoff parameter Λ = β * ωmax
251    pub fn lambda(&self) -> f64 {
252        self.kernel.lambda()
253    }
254
255    /// Get the frequency cutoff ωmax
256    pub fn omega_max(&self) -> f64 {
257        self.lambda() / self.beta
258    }
259
260    /// Get significance of each singular value (s[i] / s[0])
261    pub fn significance(&self) -> Vec<f64> {
262        let s0 = self.s[0];
263        self.s.iter().map(|&s| s / s0).collect()
264    }
265
266    /// Get default tau sampling points
267    ///
268    /// C++ implementation: libsparseir/include/sparseir/basis.hpp:229-270
269    ///
270    /// Returns sampling points in imaginary time τ ∈ [-β/2, β/2].
271    pub fn default_tau_sampling_points(&self) -> Vec<f64> {
272        let sz = self.size();
273
274        // C++: Eigen::VectorXd x = default_sampling_points(*(this->sve_result->u), sz);
275        let x = default_sampling_points(&self.sve_result.u, sz);
276
277        // C++: Extract unique half of sampling points
278        let mut unique_x = Vec::new();
279        if x.len() % 2 == 0 {
280            // C++: for (auto i = 0; i < x.size() / 2; ++i)
281            for i in 0..(x.len() / 2) {
282                unique_x.push(x[i]);
283            }
284        } else {
285            // C++: for (auto i = 0; i < x.size() / 2; ++i)
286            for i in 0..(x.len() / 2) {
287                unique_x.push(x[i]);
288            }
289            // C++: auto x_new = 0.5 * (unique_x.back() + 0.5);
290            let x_new = 0.5 * (unique_x.last().unwrap() + 0.5);
291            unique_x.push(x_new);
292        }
293
294        // C++: Generate symmetric points
295        //      Eigen::VectorXd smpl_taus(2 * unique_x.size());
296        //      for (auto i = 0; i < unique_x.size(); ++i) {
297        //          smpl_taus(i) = (this->beta / 2.0) * (unique_x[i] + 1.0);
298        //          smpl_taus(unique_x.size() + i) = -smpl_taus(i);
299        //      }
300        let mut smpl_taus = Vec::with_capacity(2 * unique_x.len());
301        for &ux in &unique_x {
302            smpl_taus.push((self.beta / 2.0) * (ux + 1.0));
303        }
304        for i in 0..unique_x.len() {
305            smpl_taus.push(-smpl_taus[i]);
306        }
307
308        // C++: std::sort(smpl_taus.data(), smpl_taus.data() + smpl_taus.size());
309        smpl_taus.sort_by(|a, b| a.partial_cmp(b).unwrap());
310
311        // C++: Check if the number of sampling points is even
312        if smpl_taus.len() % 2 != 0 {
313            panic!("The number of tau sampling points is odd!");
314        }
315
316        // C++: Check if tau = 0 is not in the sampling points
317        for &tau in &smpl_taus {
318            if tau.abs() < 1e-10 {
319                eprintln!(
320                    "Warning: tau = 0 is in the sampling points (absolute error: {})",
321                    tau.abs()
322                );
323            }
324        }
325
326        // C++ implementation returns tau in [-beta/2, beta/2] (does NOT convert to [0, beta])
327        // This matches the natural range for centrosymmetric kernels
328        smpl_taus
329    }
330
331    /// Get default Matsubara frequency sampling points
332    ///
333    /// Returns sampling points as MatsubaraFreq objects based on extrema
334    /// of the Matsubara basis functions (same algorithm as C++/Julia).
335    ///
336    /// # Arguments
337    /// * `positive_only` - If true, returns only non-negative frequencies
338    ///
339    /// # Returns
340    /// Vector of Matsubara frequency sampling points
341    pub fn default_matsubara_sampling_points(
342        &self,
343        positive_only: bool,
344    ) -> Vec<crate::freq::MatsubaraFreq<S>>
345    where
346        S: 'static,
347    {
348        let fence = false;
349        Self::default_matsubara_sampling_points_impl(
350            &self.uhat_full,
351            self.size(),
352            fence,
353            positive_only,
354        )
355    }
356
357    /// Fence Matsubara sampling points to improve conditioning
358    ///
359    /// This function adds additional sampling points near the outer frequencies
360    /// to improve the conditioning of the sampling matrix. This is particularly
361    /// important for Matsubara sampling where we cannot freely choose sampling points.
362    ///
363    /// Implementation matches C++ version in `basis.hpp` (lines 407-452).
364    fn fence_matsubara_sampling(
365        omega_n: &mut Vec<crate::freq::MatsubaraFreq<S>>,
366        positive_only: bool,
367    ) where
368        S: StatisticsType + 'static,
369    {
370        use crate::freq::{BosonicFreq, MatsubaraFreq};
371
372        if omega_n.is_empty() {
373            return;
374        }
375
376        // Collect outer frequencies
377        let mut outer_frequencies = Vec::new();
378        if positive_only {
379            outer_frequencies.push(omega_n[omega_n.len() - 1]);
380        } else {
381            outer_frequencies.push(omega_n[0]);
382            outer_frequencies.push(omega_n[omega_n.len() - 1]);
383        }
384
385        for wn_outer in outer_frequencies {
386            let outer_val = wn_outer.n();
387            // In SparseIR.jl-v1, ωn_diff is always created as BosonicFreq
388            // This ensures diff_val is always even (valid for Bosonic)
389            let mut diff_val = 2 * (0.025 * outer_val as f64).round() as i64;
390
391            // Handle edge case: if diff_val is 0, set it to 2 (minimum even value for Bosonic)
392            if diff_val == 0 {
393                diff_val = 2;
394            }
395
396            // Get the n value from BosonicFreq (same as diff_val since it's even)
397            let wn_diff = BosonicFreq::new(diff_val).unwrap().n();
398
399            // Sign function: returns +1 if n > 0, -1 if n < 0, 0 if n == 0
400            // Matches C++ implementation: (a.get_n() > 0) - (a.get_n() < 0)
401            let sign_val = if outer_val > 0 {
402                1
403            } else if outer_val < 0 {
404                -1
405            } else {
406                0
407            };
408
409            // Check original size before adding (C++ checks wn.size() before each push)
410            let original_size = omega_n.len();
411            if original_size >= 20 {
412                // For Fermionic: wn_outer.n is odd, wn_diff is even, so wn_outer.n ± wn_diff is odd (valid)
413                // For Bosonic: wn_outer.n is even, wn_diff is even, so wn_outer.n ± wn_diff is even (valid)
414                let new_n = outer_val - sign_val * wn_diff;
415                if let Ok(new_freq) = MatsubaraFreq::<S>::new(new_n) {
416                    omega_n.push(new_freq);
417                }
418            }
419            if original_size >= 42 {
420                let new_n = outer_val + sign_val * wn_diff;
421                if let Ok(new_freq) = MatsubaraFreq::<S>::new(new_n) {
422                    omega_n.push(new_freq);
423                }
424            }
425        }
426
427        // Sort and remove duplicates using BTreeSet
428        let omega_n_set: std::collections::BTreeSet<MatsubaraFreq<S>> = omega_n.drain(..).collect();
429        *omega_n = omega_n_set.into_iter().collect();
430    }
431
432    pub fn default_matsubara_sampling_points_impl(
433        uhat_full: &PiecewiseLegendreFTVector<S>,
434        l: usize,
435        fence: bool,
436        positive_only: bool,
437    ) -> Vec<crate::freq::MatsubaraFreq<S>>
438    where
439        S: StatisticsType + 'static,
440    {
441        use crate::freq::MatsubaraFreq;
442        use crate::polyfourier::{find_extrema, sign_changes};
443        use std::collections::BTreeSet;
444
445        let mut l_requested = l;
446
447        // Adjust l_requested based on statistics (same as C++)
448        if S::STATISTICS == crate::traits::Statistics::Fermionic && l_requested % 2 != 0 {
449            l_requested += 1;
450        } else if S::STATISTICS == crate::traits::Statistics::Bosonic && l_requested % 2 == 0 {
451            l_requested += 1;
452        }
453
454        // Choose sign_changes or find_extrema based on l_requested
455        let mut omega_n = if l_requested < uhat_full.len() {
456            sign_changes(&uhat_full[l_requested], positive_only)
457        } else {
458            find_extrema(&uhat_full[uhat_full.len() - 1], positive_only)
459        };
460
461        // For bosons, include zero frequency explicitly to prevent conditioning issues
462        if S::STATISTICS == crate::traits::Statistics::Bosonic {
463            omega_n.push(MatsubaraFreq::<S>::new(0).unwrap());
464        }
465
466        // Sort and remove duplicates using BTreeSet
467        let omega_n_set: BTreeSet<MatsubaraFreq<S>> = omega_n.into_iter().collect();
468        let mut omega_n: Vec<MatsubaraFreq<S>> = omega_n_set.into_iter().collect();
469
470        // Check expected size
471        let expected_size = if positive_only {
472            l_requested.div_ceil(2)
473        } else {
474            l_requested
475        };
476
477        if omega_n.len() != expected_size {
478            eprintln!(
479                "Warning: Requested {} sampling frequencies for basis size L = {}, but got {}.",
480                expected_size,
481                l,
482                omega_n.len()
483            );
484        }
485
486        // Apply fencing if requested (same as C++ implementation)
487        if fence {
488            Self::fence_matsubara_sampling(&mut omega_n, positive_only);
489        }
490
491        omega_n
492    }
493    /// Get default omega (real frequency) sampling points
494    ///
495    /// Returns sampling points on the real-frequency axis ω ∈ [-ωmax, ωmax].
496    /// These are used as pole locations for the Discrete Lehmann Representation (DLR).
497    ///
498    /// The sampling points are chosen as the roots of the L-th basis function
499    /// in the spectral domain (v), which provides near-optimal conditioning.
500    ///
501    /// # Returns
502    /// Vector of real-frequency sampling points in [-ωmax, ωmax]
503    pub fn default_omega_sampling_points(&self) -> Vec<f64> {
504        let sz = self.size();
505
506        // Use UNTRUNCATED sve_result.v (same as C++)
507        // C++: default_sampling_points(*(sve_result->v), sz)
508        let y = default_sampling_points(&self.sve_result.v, sz);
509
510        // Scale to [-ωmax, ωmax]
511        let wmax = self.kernel.lambda() / self.beta;
512        let omega_points: Vec<f64> = y.into_iter().map(|yi| wmax * yi).collect();
513
514        omega_points
515    }
516}
517
518// ============================================================================
519// Trait implementations
520// ============================================================================
521
522impl<K, S> crate::basis_trait::Basis<S> for FiniteTempBasis<K, S>
523where
524    K: KernelProperties + CentrosymmKernel + Clone + 'static,
525    S: StatisticsType + 'static,
526{
527    type Kernel = K;
528
529    fn kernel(&self) -> &Self::Kernel {
530        &self.kernel
531    }
532
533    fn beta(&self) -> f64 {
534        self.beta
535    }
536
537    fn wmax(&self) -> f64 {
538        self.kernel.lambda() / self.beta
539    }
540
541    fn lambda(&self) -> f64 {
542        self.kernel.lambda()
543    }
544
545    fn size(&self) -> usize {
546        self.size()
547    }
548
549    fn accuracy(&self) -> f64 {
550        self.accuracy
551    }
552
553    fn significance(&self) -> Vec<f64> {
554        if let Some(&first_s) = self.s.first() {
555            self.s.iter().map(|&s| s / first_s).collect()
556        } else {
557            vec![]
558        }
559    }
560
561    fn svals(&self) -> Vec<f64> {
562        self.s.clone()
563    }
564
565    fn default_tau_sampling_points(&self) -> Vec<f64> {
566        self.default_tau_sampling_points()
567    }
568
569    fn default_matsubara_sampling_points(
570        &self,
571        positive_only: bool,
572    ) -> Vec<crate::freq::MatsubaraFreq<S>> {
573        self.default_matsubara_sampling_points(positive_only)
574    }
575
576    fn evaluate_tau(&self, tau: &[f64]) -> mdarray::DTensor<f64, 2> {
577        use crate::taufuncs::normalize_tau;
578        use mdarray::DTensor;
579
580        let n_points = tau.len();
581        let basis_size = self.size();
582
583        // Evaluate each basis function at all tau points
584        // Result: matrix[i, l] = u_l(tau[i])
585        // Note: tau can be in [-beta, beta] and will be normalized to [0, beta]
586        // self.u polynomials are already scaled to tau ∈ [0, beta] domain
587        DTensor::<f64, 2>::from_fn([n_points, basis_size], |idx| {
588            let i = idx[0]; // tau index
589            let l = idx[1]; // basis function index
590
591            // Normalize tau to [0, beta] with statistics-dependent sign
592            let (tau_norm, sign) = normalize_tau::<S>(tau[i], self.beta);
593
594            // Evaluate basis function directly (u polynomials are in tau domain)
595            sign * self.u[l].evaluate(tau_norm)
596        })
597    }
598
599    fn evaluate_matsubara(
600        &self,
601        freqs: &[crate::freq::MatsubaraFreq<S>],
602    ) -> mdarray::DTensor<num_complex::Complex<f64>, 2> {
603        use mdarray::DTensor;
604        use num_complex::Complex;
605
606        let n_points = freqs.len();
607        let basis_size = self.size();
608
609        // Evaluate each basis function at all Matsubara frequencies
610        // Result: matrix[i, l] = uhat_l(iωn[i])
611        DTensor::<Complex<f64>, 2>::from_fn([n_points, basis_size], |idx| {
612            let i = idx[0]; // frequency index
613            let l = idx[1]; // basis function index
614            self.uhat[l].evaluate(&freqs[i])
615        })
616    }
617
618    fn evaluate_omega(&self, omega: &[f64]) -> mdarray::DTensor<f64, 2> {
619        use mdarray::DTensor;
620
621        let n_points = omega.len();
622        let basis_size = self.size();
623
624        // Evaluate each spectral basis function at all omega points
625        // Result: matrix[i, l] = V_l(omega[i])
626        DTensor::<f64, 2>::from_fn([n_points, basis_size], |idx| {
627            let i = idx[0]; // omega index
628            let l = idx[1]; // basis function index
629            self.v[l].evaluate(omega[i])
630        })
631    }
632
633    fn default_omega_sampling_points(&self) -> Vec<f64> {
634        self.default_omega_sampling_points()
635    }
636}
637
638// ============================================================================
639// Type aliases
640// ============================================================================
641
/// Fermionic IR basis built on the logistic kernel.
///
/// Shorthand for `FiniteTempBasis<LogisticKernel, Fermionic>`.
pub type FermionicBasis = FiniteTempBasis<LogisticKernel, Fermionic>;

/// Bosonic IR basis built on the logistic kernel.
///
/// Shorthand for `FiniteTempBasis<LogisticKernel, Bosonic>`.
pub type BosonicBasis = FiniteTempBasis<LogisticKernel, Bosonic>;

// Unit tests live in `basis_tests.rs` (resolved via `#[path]`) and are
// compiled only for `cfg(test)` builds.
#[cfg(test)]
#[path = "basis_tests.rs"]
mod basis_tests;