1use std::sync::Arc;
7
8use crate::kernel::{CentrosymmKernel, KernelProperties, LogisticKernel};
9use crate::poly::{PiecewiseLegendrePolyVector, default_sampling_points};
10use crate::polyfourier::PiecewiseLegendreFTVector;
11use crate::sve::{SVEResult, TworkType, compute_sve};
12use crate::traits::{Bosonic, Fermionic, StatisticsType};
13
/// Particle statistics of a basis: fermionic or bosonic.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Statistics {
    /// Fermionic statistics.
    Fermionic,
    /// Bosonic statistics.
    Bosonic,
}
20
21#[derive(Clone)]
39pub struct FiniteTempBasis<K, S>
40where
41 K: KernelProperties + CentrosymmKernel + Clone + 'static,
42 S: StatisticsType,
43{
44 kernel: K,
46
47 sve_result: Arc<SVEResult>,
49
50 accuracy: f64,
52
53 beta: f64,
55
56 u: Arc<PiecewiseLegendrePolyVector>,
59
60 v: Arc<PiecewiseLegendrePolyVector>,
63
64 s: Vec<f64>,
66
67 uhat: Arc<PiecewiseLegendreFTVector<S>>,
70
71 uhat_full: Arc<PiecewiseLegendreFTVector<S>>,
74
75 _phantom: std::marker::PhantomData<S>,
76}
77
78impl<K, S> FiniteTempBasis<K, S>
79where
80 K: KernelProperties + CentrosymmKernel + Clone + 'static,
81 S: StatisticsType,
82{
83 pub fn kernel(&self) -> &K {
87 &self.kernel
88 }
89
90 pub fn sve_result(&self) -> &Arc<SVEResult> {
92 &self.sve_result
93 }
94
95 pub fn accuracy(&self) -> f64 {
97 self.accuracy
98 }
99
100 pub fn beta(&self) -> f64 {
102 self.beta
103 }
104
105 pub fn u(&self) -> &Arc<PiecewiseLegendrePolyVector> {
107 &self.u
108 }
109
110 pub fn v(&self) -> &Arc<PiecewiseLegendrePolyVector> {
112 &self.v
113 }
114
115 pub fn s(&self) -> &[f64] {
117 &self.s
118 }
119
120 pub fn uhat(&self) -> &Arc<PiecewiseLegendreFTVector<S>> {
122 &self.uhat
123 }
124
125 pub fn uhat_full(&self) -> &Arc<PiecewiseLegendreFTVector<S>> {
127 &self.uhat_full
128 }
129
130 pub fn wmax(&self) -> f64 {
134 self.kernel.lambda() / self.beta
135 }
136
137 pub fn default_matsubara_sampling_points_i64(&self, positive_only: bool) -> Vec<i64>
139 where
140 S: 'static,
141 {
142 let freqs = self.default_matsubara_sampling_points(positive_only);
143 freqs.into_iter().map(|f| f.n()).collect()
144 }
145
146 pub fn default_matsubara_sampling_points_i64_with_mitigate(
152 &self,
153 positive_only: bool,
154 mitigate: bool,
155 n_points: usize,
156 ) -> Vec<i64>
157 where
158 S: 'static,
159 {
160 if !self.kernel().is_centrosymmetric() {
161 panic!(
162 "default_matsubara_sampling_points_i64_with_mitigate is not supported for non-centrosymmetric kernels. \
163 The current implementation relies on centrosymmetry to generate sampling points."
164 );
165 }
166 let fence = mitigate;
167 let freqs = Self::default_matsubara_sampling_points_impl(
168 &self.uhat_full,
169 n_points,
170 fence,
171 positive_only,
172 );
173 freqs.into_iter().map(|f| f.n()).collect()
174 }
175
176 pub fn new(kernel: K, beta: f64, epsilon: Option<f64>, max_size: Option<usize>) -> Self {
189 if beta <= 0.0 {
191 panic!("Inverse temperature beta must be positive, got {}", beta);
192 }
193
194 let epsilon_value = epsilon.unwrap_or(f64::NAN);
196 let sve_result = compute_sve(
197 kernel.clone(),
198 epsilon_value,
199 None, max_size,
201 TworkType::Auto,
202 );
203
204 Self::from_sve_result(kernel, beta, sve_result, epsilon, max_size)
205 }
206
207 pub fn from_sve_result(
212 kernel: K,
213 beta: f64,
214 sve_result: SVEResult,
215 epsilon: Option<f64>,
216 max_size: Option<usize>,
217 ) -> Self {
218 let (u_sve, s_sve, v_sve) = sve_result.part(epsilon, max_size);
220
221 let accuracy = if sve_result.s.len() > s_sve.len() {
223 sve_result.s[s_sve.len()] / sve_result.s[0]
224 } else {
225 sve_result.s[sve_result.s.len() - 1] / sve_result.s[0]
226 };
227
228 let lambda = kernel.lambda();
230 let omega_max = lambda / beta;
231
232 let u_knots: Vec<f64> = u_sve.get_polys()[0]
237 .knots
238 .iter()
239 .map(|&x| beta / 2.0 * (x + 1.0))
240 .collect();
241 let u_delta_x: Vec<f64> = u_sve.get_polys()[0]
242 .delta_x
243 .iter()
244 .map(|&dx| beta / 2.0 * dx)
245 .collect();
246 let u_symm: Vec<i32> = u_sve.get_polys().iter().map(|p| p.symm).collect();
247
248 let u = u_sve.rescale_domain(u_knots, Some(u_delta_x), Some(u_symm));
249
250 let v_knots: Vec<f64> = v_sve.get_polys()[0]
252 .knots
253 .iter()
254 .map(|&y| omega_max * y)
255 .collect();
256 let v_delta_x: Vec<f64> = v_sve.get_polys()[0]
257 .delta_x
258 .iter()
259 .map(|&dy| omega_max * dy)
260 .collect();
261 let v_symm: Vec<i32> = v_sve.get_polys().iter().map(|p| p.symm).collect();
262
263 let v = v_sve.rescale_domain(v_knots, Some(v_delta_x), Some(v_symm));
264
265 let ypower = kernel.ypower();
268 let scale_factor = (beta / 2.0 * omega_max).sqrt() * omega_max.powi(-ypower);
269 let s: Vec<f64> = s_sve.iter().map(|&x| scale_factor * x).collect();
270
271 let uhat_base_full = sve_result.u.scale_data(beta.sqrt());
274 let conv_rad = kernel.conv_radius();
275
276 let stat_marker = S::default();
279
280 let uhat_full = PiecewiseLegendreFTVector::<S>::from_poly_vector(
281 &uhat_base_full,
282 stat_marker,
283 Some(conv_rad),
284 );
285
286 let uhat_polyvec: Vec<_> = uhat_full.polyvec.iter().take(s.len()).cloned().collect();
288 let uhat = PiecewiseLegendreFTVector::from_vector(uhat_polyvec);
289
290 Self {
291 kernel,
292 sve_result: Arc::new(sve_result),
293 accuracy,
294 beta,
295 u: Arc::new(u),
296 v: Arc::new(v),
297 s,
298 uhat: Arc::new(uhat),
299 uhat_full: Arc::new(uhat_full),
300 _phantom: std::marker::PhantomData,
301 }
302 }
303
304 pub fn size(&self) -> usize {
306 self.s.len()
307 }
308
309 pub fn lambda(&self) -> f64 {
311 self.kernel.lambda()
312 }
313
314 pub fn omega_max(&self) -> f64 {
316 self.lambda() / self.beta
317 }
318
319 pub fn significance(&self) -> Vec<f64> {
321 let s0 = self.s[0];
322 self.s.iter().map(|&s| s / s0).collect()
323 }
324
325 pub fn default_tau_sampling_points(&self) -> Vec<f64> {
335 if !self.kernel().is_centrosymmetric() {
336 panic!(
337 "default_tau_sampling_points is not supported for non-centrosymmetric kernels. \
338 The current implementation relies on centrosymmetry to generate symmetric sampling points."
339 );
340 }
341 let points = self.default_tau_sampling_points_size_requested(self.size());
342 let basis_size = self.size();
343 if points.len() < basis_size {
344 eprintln!(
345 "Warning: Number of tau sampling points ({}) is less than basis size ({}). \
346 Basis parameters: beta={}, wmax={}, epsilon={:.2e}",
347 points.len(),
348 basis_size,
349 self.beta,
350 self.wmax(),
351 self.accuracy()
352 );
353 }
354 points
355 }
356
357 pub fn default_tau_sampling_points_size_requested(&self, size_requested: usize) -> Vec<f64> {
363 if !self.kernel().is_centrosymmetric() {
364 panic!(
365 "default_tau_sampling_points_size_requested is not supported for non-centrosymmetric kernels. \
366 The current implementation relies on centrosymmetry to generate symmetric sampling points."
367 );
368 }
369 let x = default_sampling_points(&self.sve_result.u, size_requested);
371 let mut unique_x = Vec::new();
373 if x.len() % 2 == 0 {
374 for i in 0..(x.len() / 2) {
376 unique_x.push(x[i]);
377 }
378 } else {
379 for i in 0..(x.len() / 2) {
381 unique_x.push(x[i]);
382 }
383 let x_new = 0.5 * (unique_x.last().unwrap() + 0.5);
385 unique_x.push(x_new);
386 }
387
388 let mut smpl_taus = Vec::with_capacity(2 * unique_x.len());
395 for &ux in &unique_x {
396 smpl_taus.push((self.beta / 2.0) * (ux + 1.0));
397 }
398 for i in 0..unique_x.len() {
399 smpl_taus.push(-smpl_taus[i]);
400 }
401
402 smpl_taus.sort_by(|a, b| a.partial_cmp(b).unwrap());
404
405 if smpl_taus.len() % 2 != 0 {
407 panic!("The number of tau sampling points is odd!");
408 }
409
410 for &tau in &smpl_taus {
412 if tau.abs() < 1e-10 {
413 eprintln!(
414 "Warning: tau = 0 is in the sampling points (absolute error: {})",
415 tau.abs()
416 );
417 }
418 }
419
420 smpl_taus
423 }
424
425 pub fn default_matsubara_sampling_points(
440 &self,
441 positive_only: bool,
442 ) -> Vec<crate::freq::MatsubaraFreq<S>>
443 where
444 S: 'static,
445 {
446 if !self.kernel().is_centrosymmetric() {
447 panic!(
448 "default_matsubara_sampling_points is not supported for non-centrosymmetric kernels. \
449 The current implementation relies on centrosymmetry to generate sampling points."
450 );
451 }
452 let fence = false;
453 let points = Self::default_matsubara_sampling_points_impl(
454 &self.uhat_full,
455 self.size(),
456 fence,
457 positive_only,
458 );
459 let basis_size = self.size();
460 let effective_points = if positive_only {
463 2 * points.len()
464 } else {
465 points.len()
466 };
467 if effective_points < basis_size {
468 eprintln!(
469 "Warning: Number of Matsubara sampling points ({}{}) is less than basis size ({}). \
470 Basis parameters: beta={}, wmax={}, epsilon={:.2e}",
471 points.len(),
472 if positive_only { " × 2" } else { "" },
473 basis_size,
474 self.beta,
475 self.wmax(),
476 self.accuracy()
477 );
478 }
479 points
480 }
481
482 fn fence_matsubara_sampling(
490 omega_n: &mut Vec<crate::freq::MatsubaraFreq<S>>,
491 positive_only: bool,
492 ) where
493 S: StatisticsType + 'static,
494 {
495 use crate::freq::{BosonicFreq, MatsubaraFreq};
496
497 if omega_n.is_empty() {
498 return;
499 }
500
501 let mut outer_frequencies = Vec::new();
503 if positive_only {
504 outer_frequencies.push(omega_n[omega_n.len() - 1]);
505 } else {
506 outer_frequencies.push(omega_n[0]);
507 outer_frequencies.push(omega_n[omega_n.len() - 1]);
508 }
509
510 for wn_outer in outer_frequencies {
511 let outer_val = wn_outer.n();
512 let mut diff_val = 2 * (0.025 * outer_val as f64).round() as i64;
515
516 if diff_val == 0 {
518 diff_val = 2;
519 }
520
521 let wn_diff = BosonicFreq::new(diff_val).unwrap().n();
523
524 let sign_val = if outer_val > 0 {
527 1
528 } else if outer_val < 0 {
529 -1
530 } else {
531 0
532 };
533
534 let original_size = omega_n.len();
536 if original_size >= 20 {
537 let new_n = outer_val - sign_val * wn_diff;
540 if let Ok(new_freq) = MatsubaraFreq::<S>::new(new_n) {
541 omega_n.push(new_freq);
542 }
543 }
544 if original_size >= 42 {
545 let new_n = outer_val + sign_val * wn_diff;
546 if let Ok(new_freq) = MatsubaraFreq::<S>::new(new_n) {
547 omega_n.push(new_freq);
548 }
549 }
550 }
551
552 let omega_n_set: std::collections::BTreeSet<MatsubaraFreq<S>> = omega_n.drain(..).collect();
554 *omega_n = omega_n_set.into_iter().collect();
555 }
556
557 pub fn default_matsubara_sampling_points_impl(
558 uhat_full: &PiecewiseLegendreFTVector<S>,
559 l: usize,
560 fence: bool,
561 positive_only: bool,
562 ) -> Vec<crate::freq::MatsubaraFreq<S>>
563 where
564 S: StatisticsType + 'static,
565 {
566 use crate::freq::MatsubaraFreq;
567 use crate::polyfourier::{find_extrema, sign_changes};
568 use std::collections::BTreeSet;
569
570 let mut l_requested = l;
571
572 if S::STATISTICS == crate::traits::Statistics::Fermionic && l_requested % 2 != 0 {
574 l_requested += 1;
575 } else if S::STATISTICS == crate::traits::Statistics::Bosonic && l_requested % 2 == 0 {
576 l_requested += 1;
577 }
578
579 let mut omega_n = if l_requested < uhat_full.len() {
581 sign_changes(&uhat_full[l_requested], positive_only)
582 } else {
583 find_extrema(&uhat_full[uhat_full.len() - 1], positive_only)
584 };
585
586 if S::STATISTICS == crate::traits::Statistics::Bosonic {
588 omega_n.push(MatsubaraFreq::<S>::new(0).unwrap());
589 }
590
591 let omega_n_set: BTreeSet<MatsubaraFreq<S>> = omega_n.into_iter().collect();
593 let mut omega_n: Vec<MatsubaraFreq<S>> = omega_n_set.into_iter().collect();
594
595 let expected_size = if positive_only {
597 l_requested.div_ceil(2)
598 } else {
599 l_requested
600 };
601
602 if omega_n.len() != expected_size {
603 eprintln!(
604 "Warning: Requested {} sampling frequencies for basis size L = {}, but got {}.",
605 expected_size,
606 l,
607 omega_n.len()
608 );
609 }
610
611 if fence {
613 Self::fence_matsubara_sampling(&mut omega_n, positive_only);
614 }
615
616 omega_n
617 }
618 pub fn default_omega_sampling_points(&self) -> Vec<f64> {
629 let sz = self.size();
630
631 let y = default_sampling_points(&self.sve_result.v, sz);
634
635 let wmax = self.kernel.lambda() / self.beta;
637 let omega_points: Vec<f64> = y.into_iter().map(|yi| wmax * yi).collect();
638
639 omega_points
640 }
641}
642
643impl<K, S> crate::basis_trait::Basis<S> for FiniteTempBasis<K, S>
648where
649 K: KernelProperties + CentrosymmKernel + Clone + 'static,
650 S: StatisticsType + 'static,
651{
652 type Kernel = K;
653
654 fn kernel(&self) -> &Self::Kernel {
655 &self.kernel
656 }
657
658 fn beta(&self) -> f64 {
659 self.beta
660 }
661
662 fn wmax(&self) -> f64 {
663 self.kernel.lambda() / self.beta
664 }
665
666 fn lambda(&self) -> f64 {
667 self.kernel.lambda()
668 }
669
670 fn size(&self) -> usize {
671 self.size()
672 }
673
674 fn accuracy(&self) -> f64 {
675 self.accuracy
676 }
677
678 fn significance(&self) -> Vec<f64> {
679 if let Some(&first_s) = self.s.first() {
680 self.s.iter().map(|&s| s / first_s).collect()
681 } else {
682 vec![]
683 }
684 }
685
686 fn svals(&self) -> Vec<f64> {
687 self.s.clone()
688 }
689
690 fn default_tau_sampling_points(&self) -> Vec<f64> {
691 self.default_tau_sampling_points()
692 }
693
694 fn default_matsubara_sampling_points(
695 &self,
696 positive_only: bool,
697 ) -> Vec<crate::freq::MatsubaraFreq<S>> {
698 self.default_matsubara_sampling_points(positive_only)
699 }
700
701 fn evaluate_tau(&self, tau: &[f64]) -> mdarray::DTensor<f64, 2> {
702 use crate::taufuncs::normalize_tau;
703 use mdarray::DTensor;
704
705 let n_points = tau.len();
706 let basis_size = self.size();
707
708 DTensor::<f64, 2>::from_fn([n_points, basis_size], |idx| {
713 let i = idx[0]; let l = idx[1]; let (tau_norm, sign) = normalize_tau::<S>(tau[i], self.beta);
718
719 sign * self.u[l].evaluate(tau_norm)
721 })
722 }
723
724 fn evaluate_matsubara(
725 &self,
726 freqs: &[crate::freq::MatsubaraFreq<S>],
727 ) -> mdarray::DTensor<num_complex::Complex<f64>, 2> {
728 use mdarray::DTensor;
729 use num_complex::Complex;
730
731 let n_points = freqs.len();
732 let basis_size = self.size();
733
734 DTensor::<Complex<f64>, 2>::from_fn([n_points, basis_size], |idx| {
737 let i = idx[0]; let l = idx[1]; self.uhat[l].evaluate(&freqs[i])
740 })
741 }
742
743 fn evaluate_omega(&self, omega: &[f64]) -> mdarray::DTensor<f64, 2> {
744 use mdarray::DTensor;
745
746 let n_points = omega.len();
747 let basis_size = self.size();
748
749 DTensor::<f64, 2>::from_fn([n_points, basis_size], |idx| {
752 let i = idx[0]; let l = idx[1]; self.v[l].evaluate(omega[i])
755 })
756 }
757
758 fn default_omega_sampling_points(&self) -> Vec<f64> {
759 self.default_omega_sampling_points()
760 }
761}
762
763pub type FermionicBasis = FiniteTempBasis<LogisticKernel, Fermionic>;
769
770pub type BosonicBasis = FiniteTempBasis<LogisticKernel, Bosonic>;
772
// Unit tests for this module are loaded from `basis_tests.rs` (via `#[path]`) and
// compiled only under `cfg(test)`.
#[cfg(test)]
#[path = "basis_tests.rs"]
mod basis_tests;