1use crate::fitters::{ComplexMatrixFitter, ComplexToRealFitter, InplaceFitter};
7use crate::freq::MatsubaraFreq;
8use crate::gemm::GemmBackendHandle;
9use crate::sampling::movedim;
10use crate::traits::StatisticsType;
11use mdarray::{DTensor, DynRank, Shape, Slice, Tensor, ViewMut};
12use num_complex::Complex;
13use std::marker::PhantomData;
14
15pub trait MatsubaraCoeffs: Copy + 'static {
20 fn evaluate_nd_with<S: StatisticsType>(
22 sampler: &MatsubaraSampling<S>,
23 backend: Option<&GemmBackendHandle>,
24 coeffs: &Slice<Self, DynRank>,
25 dim: usize,
26 ) -> Tensor<Complex<f64>, DynRank>;
27}
28
29impl MatsubaraCoeffs for f64 {
30 fn evaluate_nd_with<S: StatisticsType>(
31 sampler: &MatsubaraSampling<S>,
32 backend: Option<&GemmBackendHandle>,
33 coeffs: &Slice<Self, DynRank>,
34 dim: usize,
35 ) -> Tensor<Complex<f64>, DynRank> {
36 sampler.evaluate_nd_impl_real(backend, coeffs, dim)
37 }
38}
39
40impl MatsubaraCoeffs for Complex<f64> {
41 fn evaluate_nd_with<S: StatisticsType>(
42 sampler: &MatsubaraSampling<S>,
43 backend: Option<&GemmBackendHandle>,
44 coeffs: &Slice<Self, DynRank>,
45 dim: usize,
46 ) -> Tensor<Complex<f64>, DynRank> {
47 sampler.evaluate_nd_impl_complex(backend, coeffs, dim)
48 }
49}
50
51pub struct MatsubaraSampling<S: StatisticsType> {
55 sampling_points: Vec<MatsubaraFreq<S>>,
56 fitter: ComplexMatrixFitter,
57 _phantom: PhantomData<S>,
58}
59
60impl<S: StatisticsType> MatsubaraSampling<S> {
61 pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
65 where
66 S: 'static,
67 {
68 let sampling_points = basis.default_matsubara_sampling_points(false);
69 Self::with_sampling_points(basis, sampling_points)
70 }
71
72 pub fn with_sampling_points(
74 basis: &impl crate::basis_trait::Basis<S>,
75 mut sampling_points: Vec<MatsubaraFreq<S>>,
76 ) -> Self
77 where
78 S: 'static,
79 {
80 sampling_points.sort();
82
83 let matrix = basis.evaluate_matsubara(&sampling_points);
86
87 let fitter = ComplexMatrixFitter::new(matrix);
89
90 Self {
91 sampling_points,
92 fitter,
93 _phantom: PhantomData,
94 }
95 }
96
97 pub fn from_matrix(
112 sampling_points: Vec<MatsubaraFreq<S>>,
113 matrix: DTensor<Complex<f64>, 2>,
114 ) -> Self {
115 assert!(!sampling_points.is_empty(), "No sampling points given");
116 assert_eq!(
117 matrix.shape().0,
118 sampling_points.len(),
119 "Matrix rows ({}) must match number of sampling points ({})",
120 matrix.shape().0,
121 sampling_points.len()
122 );
123 debug_assert!(
124 sampling_points.windows(2).all(|w| w[0] <= w[1]),
125 "Sampling points must be sorted in ascending order"
126 );
127
128 let fitter = ComplexMatrixFitter::new(matrix);
129
130 Self {
131 sampling_points,
132 fitter,
133 _phantom: PhantomData,
134 }
135 }
136
137 pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
139 &self.sampling_points
140 }
141
142 pub fn n_sampling_points(&self) -> usize {
144 self.sampling_points.len()
145 }
146
147 pub fn basis_size(&self) -> usize {
149 self.fitter.basis_size()
150 }
151
152 pub fn matrix(&self) -> &DTensor<Complex<f64>, 2> {
154 &self.fitter.matrix
155 }
156
157 pub fn evaluate(&self, coeffs: &[Complex<f64>]) -> Vec<Complex<f64>> {
165 self.fitter.evaluate(None, coeffs)
166 }
167
168 pub fn fit(&self, values: &[Complex<f64>]) -> Vec<Complex<f64>> {
176 self.fitter.fit(None, values)
177 }
178
179 fn evaluate_nd_impl_real(
207 &self,
208 backend: Option<&GemmBackendHandle>,
209 coeffs: &Slice<f64, DynRank>,
210 dim: usize,
211 ) -> Tensor<Complex<f64>, DynRank> {
212 let rank = coeffs.rank();
213 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
214
215 let basis_size = self.basis_size();
216 let target_dim_size = coeffs.shape().dim(dim);
217
218 assert_eq!(
219 target_dim_size, basis_size,
220 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
221 dim, target_dim_size, basis_size
222 );
223
224 let coeffs_dim0 = movedim(coeffs, dim, 0);
226
227 let extra_size: usize = coeffs_dim0.len() / basis_size;
229
230 let coeffs_2d_dyn = coeffs_dim0
231 .reshape(&[basis_size, extra_size][..])
232 .to_tensor();
233
234 let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
236 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
237 });
238 let coeffs_2d_view = coeffs_2d.view(.., ..);
239 let result_2d = self.fitter.evaluate_2d_real(backend, &coeffs_2d_view);
240
241 let n_points = self.n_sampling_points();
243 let mut result_shape = vec![n_points];
244 coeffs_dim0.shape().with_dims(|dims| {
245 for i in 1..dims.len() {
246 result_shape.push(dims[i]);
247 }
248 });
249
250 let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
251
252 movedim(&result_dim0, 0, dim)
254 }
255
256 fn evaluate_nd_impl_complex(
258 &self,
259 backend: Option<&GemmBackendHandle>,
260 coeffs: &Slice<Complex<f64>, DynRank>,
261 dim: usize,
262 ) -> Tensor<Complex<f64>, DynRank> {
263 let rank = coeffs.rank();
264 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
265
266 let basis_size = self.basis_size();
267 let target_dim_size = coeffs.shape().dim(dim);
268
269 assert_eq!(
270 target_dim_size, basis_size,
271 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
272 dim, target_dim_size, basis_size
273 );
274
275 let coeffs_dim0 = movedim(coeffs, dim, 0);
277
278 let extra_size: usize = coeffs_dim0.len() / basis_size;
280
281 let coeffs_2d_dyn = coeffs_dim0
282 .reshape(&[basis_size, extra_size][..])
283 .to_tensor();
284
285 let coeffs_2d = DTensor::<Complex<f64>, 2>::from_fn([basis_size, extra_size], |idx| {
287 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
288 });
289 let coeffs_2d_view = coeffs_2d.view(.., ..);
290 let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d_view);
291
292 let n_points = self.n_sampling_points();
294 let mut result_shape = vec![n_points];
295 coeffs_dim0.shape().with_dims(|dims| {
296 for i in 1..dims.len() {
297 result_shape.push(dims[i]);
298 }
299 });
300
301 let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
302
303 movedim(&result_dim0, 0, dim)
305 }
306
307 pub fn evaluate_nd<T: MatsubaraCoeffs>(
323 &self,
324 backend: Option<&GemmBackendHandle>,
325 coeffs: &Slice<T, DynRank>,
326 dim: usize,
327 ) -> Tensor<Complex<f64>, DynRank> {
328 T::evaluate_nd_with(self, backend, coeffs, dim)
329 }
330
331 pub fn evaluate_nd_real(
344 &self,
345 backend: Option<&GemmBackendHandle>,
346 coeffs: &Tensor<f64, DynRank>,
347 dim: usize,
348 ) -> Tensor<Complex<f64>, DynRank> {
349 let rank = coeffs.rank();
350 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
351
352 let basis_size = self.basis_size();
353 let target_dim_size = coeffs.shape().dim(dim);
354
355 assert_eq!(
356 target_dim_size, basis_size,
357 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
358 dim, target_dim_size, basis_size
359 );
360
361 let coeffs_dim0 = movedim(coeffs, dim, 0);
363
364 let extra_size: usize = coeffs_dim0.len() / basis_size;
366
367 let coeffs_2d_dyn = coeffs_dim0
368 .reshape(&[basis_size, extra_size][..])
369 .to_tensor();
370
371 let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
373 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
374 });
375
376 let coeffs_2d_view = coeffs_2d.view(.., ..);
378 let values_2d = self.fitter.evaluate_2d_real(backend, &coeffs_2d_view);
379
380 let n_points = self.n_sampling_points();
382 let mut result_shape = Vec::with_capacity(rank);
383 result_shape.push(n_points);
384 coeffs_dim0.shape().with_dims(|dims| {
385 for i in 1..dims.len() {
386 result_shape.push(dims[i]);
387 }
388 });
389
390 let result_dim0 = values_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
391
392 movedim(&result_dim0, 0, dim)
394 }
395
396 pub fn fit_nd(
406 &self,
407 backend: Option<&GemmBackendHandle>,
408 values: &Tensor<Complex<f64>, DynRank>,
409 dim: usize,
410 ) -> Tensor<Complex<f64>, DynRank> {
411 let rank = values.rank();
412 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
413
414 let n_points = self.n_sampling_points();
415 let target_dim_size = values.shape().dim(dim);
416
417 assert_eq!(
418 target_dim_size, n_points,
419 "values.shape().dim({}) = {} must equal n_sampling_points = {}",
420 dim, target_dim_size, n_points
421 );
422
423 let values_dim0 = movedim(values, dim, 0);
425
426 let extra_size: usize = values_dim0.len() / n_points;
428 let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
429
430 let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
432 values_2d_dyn[&[idx[0], idx[1]][..]]
433 });
434
435 let values_2d_view = values_2d.view(.., ..);
437 let coeffs_2d = self.fitter.fit_2d(backend, &values_2d_view);
438
439 let basis_size = self.basis_size();
441 let mut coeffs_shape = vec![basis_size];
442 values_dim0.shape().with_dims(|dims| {
443 for i in 1..dims.len() {
444 coeffs_shape.push(dims[i]);
445 }
446 });
447
448 let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
449
450 movedim(&coeffs_dim0, 0, dim)
452 }
453
454 pub fn fit_nd_real(
467 &self,
468 backend: Option<&GemmBackendHandle>,
469 values: &Tensor<Complex<f64>, DynRank>,
470 dim: usize,
471 ) -> Tensor<f64, DynRank> {
472 let rank = values.rank();
473 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
474
475 let n_points = self.n_sampling_points();
476 let target_dim_size = values.shape().dim(dim);
477
478 assert_eq!(
479 target_dim_size, n_points,
480 "values.shape().dim({}) = {} must equal n_sampling_points = {}",
481 dim, target_dim_size, n_points
482 );
483
484 let values_dim0 = movedim(values, dim, 0);
486
487 let extra_size: usize = values_dim0.len() / n_points;
489 let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
490
491 let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
493 values_2d_dyn[&[idx[0], idx[1]][..]]
494 });
495
496 let values_2d_view = values_2d.view(.., ..);
498 let coeffs_2d = self.fitter.fit_2d_real(backend, &values_2d_view);
499
500 let basis_size = self.basis_size();
502 let mut coeffs_shape = vec![basis_size];
503 values_dim0.shape().with_dims(|dims| {
504 for i in 1..dims.len() {
505 coeffs_shape.push(dims[i]);
506 }
507 });
508
509 let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
510
511 movedim(&coeffs_dim0, 0, dim)
513 }
514
515 pub fn evaluate_nd_to<T: MatsubaraCoeffs>(
525 &self,
526 backend: Option<&GemmBackendHandle>,
527 coeffs: &Slice<T, DynRank>,
528 dim: usize,
529 out: &mut Tensor<Complex<f64>, DynRank>,
530 ) {
531 let rank = coeffs.rank();
533 assert_eq!(
534 out.rank(),
535 rank,
536 "out.rank()={} must equal coeffs.rank()={}",
537 out.rank(),
538 rank
539 );
540
541 let n_points = self.n_sampling_points();
542 let out_dim_size = out.shape().dim(dim);
543 assert_eq!(
544 out_dim_size, n_points,
545 "out.shape().dim({}) = {} must equal n_sampling_points = {}",
546 dim, out_dim_size, n_points
547 );
548
549 for d in 0..rank {
551 if d != dim {
552 let coeffs_d = coeffs.shape().dim(d);
553 let out_d = out.shape().dim(d);
554 assert_eq!(
555 coeffs_d, out_d,
556 "coeffs.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
557 d, coeffs_d, d, out_d
558 );
559 }
560 }
561
562 let result = self.evaluate_nd(backend, coeffs, dim);
564
565 let total = out.len();
567 for i in 0..total {
568 let mut idx = vec![0usize; rank];
569 let mut remaining = i;
570 for d in (0..rank).rev() {
571 let dim_size = out.shape().dim(d);
572 idx[d] = remaining % dim_size;
573 remaining /= dim_size;
574 }
575 out[&idx[..]] = result[&idx[..]];
576 }
577 }
578
579 pub fn fit_nd_to(
586 &self,
587 backend: Option<&GemmBackendHandle>,
588 values: &Tensor<Complex<f64>, DynRank>,
589 dim: usize,
590 out: &mut Tensor<Complex<f64>, DynRank>,
591 ) {
592 let rank = values.rank();
594 assert_eq!(
595 out.rank(),
596 rank,
597 "out.rank()={} must equal values.rank()={}",
598 out.rank(),
599 rank
600 );
601
602 let basis_size = self.basis_size();
603 let out_dim_size = out.shape().dim(dim);
604 assert_eq!(
605 out_dim_size, basis_size,
606 "out.shape().dim({}) = {} must equal basis_size = {}",
607 dim, out_dim_size, basis_size
608 );
609
610 for d in 0..rank {
612 if d != dim {
613 let values_d = values.shape().dim(d);
614 let out_d = out.shape().dim(d);
615 assert_eq!(
616 values_d, out_d,
617 "values.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
618 d, values_d, d, out_d
619 );
620 }
621 }
622
623 let result = self.fit_nd(backend, values, dim);
625
626 let total = out.len();
628 for i in 0..total {
629 let mut idx = vec![0usize; rank];
630 let mut remaining = i;
631 for d in (0..rank).rev() {
632 let dim_size = out.shape().dim(d);
633 idx[d] = remaining % dim_size;
634 remaining /= dim_size;
635 }
636 out[&idx[..]] = result[&idx[..]];
637 }
638 }
639}
640
641impl<S: StatisticsType> InplaceFitter for MatsubaraSampling<S> {
648 fn n_points(&self) -> usize {
649 self.n_sampling_points()
650 }
651
652 fn basis_size(&self) -> usize {
653 self.basis_size()
654 }
655
656 fn evaluate_nd_dz_to(
657 &self,
658 backend: Option<&GemmBackendHandle>,
659 coeffs: &Slice<f64, DynRank>,
660 dim: usize,
661 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
662 ) -> bool {
663 self.fitter.evaluate_nd_dz_to(backend, coeffs, dim, out)
664 }
665
666 fn evaluate_nd_zz_to(
667 &self,
668 backend: Option<&GemmBackendHandle>,
669 coeffs: &Slice<Complex<f64>, DynRank>,
670 dim: usize,
671 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
672 ) -> bool {
673 self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
674 }
675
676 fn fit_nd_zd_to(
677 &self,
678 backend: Option<&GemmBackendHandle>,
679 values: &Slice<Complex<f64>, DynRank>,
680 dim: usize,
681 out: &mut ViewMut<'_, f64, DynRank>,
682 ) -> bool {
683 self.fitter.fit_nd_zd_to(backend, values, dim, out)
684 }
685
686 fn fit_nd_zz_to(
687 &self,
688 backend: Option<&GemmBackendHandle>,
689 values: &Slice<Complex<f64>, DynRank>,
690 dim: usize,
691 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
692 ) -> bool {
693 self.fitter.fit_nd_zz_to(backend, values, dim, out)
694 }
695}
696
697pub struct MatsubaraSamplingPositiveOnly<S: StatisticsType> {
702 sampling_points: Vec<MatsubaraFreq<S>>,
703 fitter: ComplexToRealFitter,
704 _phantom: PhantomData<S>,
705}
706
707impl<S: StatisticsType> MatsubaraSamplingPositiveOnly<S> {
708 pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
713 where
714 S: 'static,
715 {
716 let sampling_points = basis.default_matsubara_sampling_points(true);
717 Self::with_sampling_points(basis, sampling_points)
718 }
719
720 pub fn with_sampling_points(
722 basis: &impl crate::basis_trait::Basis<S>,
723 mut sampling_points: Vec<MatsubaraFreq<S>>,
724 ) -> Self
725 where
726 S: 'static,
727 {
728 sampling_points.sort();
730
731 assert!(
733 sampling_points.iter().all(|f| f.n() >= 0),
734 "All sampling points must be non-negative for positive-only Matsubara sampling"
735 );
736
737 let matrix = basis.evaluate_matsubara(&sampling_points);
740
741 let fitter = ComplexToRealFitter::new(&matrix);
743
744 Self {
745 sampling_points,
746 fitter,
747 _phantom: PhantomData,
748 }
749 }
750
751 pub fn from_matrix(
766 sampling_points: Vec<MatsubaraFreq<S>>,
767 matrix: DTensor<Complex<f64>, 2>,
768 ) -> Self {
769 assert!(!sampling_points.is_empty(), "No sampling points given");
770 assert_eq!(
771 matrix.shape().0,
772 sampling_points.len(),
773 "Matrix rows ({}) must match number of sampling points ({})",
774 matrix.shape().0,
775 sampling_points.len()
776 );
777 debug_assert!(
778 sampling_points.windows(2).all(|w| w[0] <= w[1]),
779 "Sampling points must be sorted in ascending order"
780 );
781
782 let fitter = ComplexToRealFitter::new(&matrix);
783
784 Self {
785 sampling_points,
786 fitter,
787 _phantom: PhantomData,
788 }
789 }
790
791 pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
793 &self.sampling_points
794 }
795
796 pub fn n_sampling_points(&self) -> usize {
798 self.sampling_points.len()
799 }
800
801 pub fn basis_size(&self) -> usize {
803 self.fitter.basis_size()
804 }
805
806 pub fn matrix(&self) -> &DTensor<Complex<f64>, 2> {
808 &self.fitter.matrix
809 }
810
811 pub fn evaluate(&self, coeffs: &[f64]) -> Vec<Complex<f64>> {
813 self.fitter.evaluate(None, coeffs)
814 }
815
816 pub fn fit(&self, values: &[Complex<f64>]) -> Vec<f64> {
818 self.fitter.fit(None, values)
819 }
820
821 pub fn evaluate_nd(
830 &self,
831 backend: Option<&GemmBackendHandle>,
832 coeffs: &Tensor<f64, DynRank>,
833 dim: usize,
834 ) -> Tensor<Complex<f64>, DynRank> {
835 let rank = coeffs.rank();
836 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
837
838 let basis_size = self.basis_size();
839 let target_dim_size = coeffs.shape().dim(dim);
840
841 assert_eq!(
842 target_dim_size, basis_size,
843 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
844 dim, target_dim_size, basis_size
845 );
846
847 let coeffs_dim0 = movedim(coeffs, dim, 0);
849
850 let extra_size: usize = coeffs_dim0.len() / basis_size;
852
853 let coeffs_2d_dyn = coeffs_dim0
854 .reshape(&[basis_size, extra_size][..])
855 .to_tensor();
856
857 let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
859 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
860 });
861
862 let coeffs_2d_view = coeffs_2d.view(.., ..);
864 let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d_view);
865
866 let n_points = self.n_sampling_points();
868 let mut result_shape = vec![n_points];
869 coeffs_dim0.shape().with_dims(|dims| {
870 for i in 1..dims.len() {
871 result_shape.push(dims[i]);
872 }
873 });
874
875 let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
876
877 movedim(&result_dim0, 0, dim)
879 }
880
881 pub fn fit_nd(
891 &self,
892 backend: Option<&GemmBackendHandle>,
893 values: &Tensor<Complex<f64>, DynRank>,
894 dim: usize,
895 ) -> Tensor<f64, DynRank> {
896 let rank = values.rank();
897 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
898
899 let n_points = self.n_sampling_points();
900 let target_dim_size = values.shape().dim(dim);
901
902 assert_eq!(
903 target_dim_size, n_points,
904 "values.shape().dim({}) = {} must equal n_sampling_points = {}",
905 dim, target_dim_size, n_points
906 );
907
908 let values_dim0 = movedim(values, dim, 0);
910
911 let extra_size: usize = values_dim0.len() / n_points;
913 let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
914
915 let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
917 values_2d_dyn[&[idx[0], idx[1]][..]]
918 });
919
920 let values_2d_view = values_2d.view(.., ..);
922 let coeffs_2d = self.fitter.fit_2d(backend, &values_2d_view);
923
924 let basis_size = self.basis_size();
926 let mut coeffs_shape = vec![basis_size];
927 values_dim0.shape().with_dims(|dims| {
928 for i in 1..dims.len() {
929 coeffs_shape.push(dims[i]);
930 }
931 });
932
933 let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
934
935 movedim(&coeffs_dim0, 0, dim)
937 }
938
939 pub fn evaluate_nd_to(
946 &self,
947 backend: Option<&GemmBackendHandle>,
948 coeffs: &Tensor<f64, DynRank>,
949 dim: usize,
950 out: &mut Tensor<Complex<f64>, DynRank>,
951 ) {
952 let rank = coeffs.rank();
954 assert_eq!(
955 out.rank(),
956 rank,
957 "out.rank()={} must equal coeffs.rank()={}",
958 out.rank(),
959 rank
960 );
961
962 let n_points = self.n_sampling_points();
963 let out_dim_size = out.shape().dim(dim);
964 assert_eq!(
965 out_dim_size, n_points,
966 "out.shape().dim({}) = {} must equal n_sampling_points = {}",
967 dim, out_dim_size, n_points
968 );
969
970 for d in 0..rank {
972 if d != dim {
973 let coeffs_d = coeffs.shape().dim(d);
974 let out_d = out.shape().dim(d);
975 assert_eq!(
976 coeffs_d, out_d,
977 "coeffs.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
978 d, coeffs_d, d, out_d
979 );
980 }
981 }
982
983 let result = self.evaluate_nd(backend, coeffs, dim);
985
986 let total = out.len();
988 for i in 0..total {
989 let mut idx = vec![0usize; rank];
990 let mut remaining = i;
991 for d in (0..rank).rev() {
992 let dim_size = out.shape().dim(d);
993 idx[d] = remaining % dim_size;
994 remaining /= dim_size;
995 }
996 out[&idx[..]] = result[&idx[..]];
997 }
998 }
999
1000 pub fn fit_nd_to(
1007 &self,
1008 backend: Option<&GemmBackendHandle>,
1009 values: &Tensor<Complex<f64>, DynRank>,
1010 dim: usize,
1011 out: &mut Tensor<f64, DynRank>,
1012 ) {
1013 let rank = values.rank();
1015 assert_eq!(
1016 out.rank(),
1017 rank,
1018 "out.rank()={} must equal values.rank()={}",
1019 out.rank(),
1020 rank
1021 );
1022
1023 let basis_size = self.basis_size();
1024 let out_dim_size = out.shape().dim(dim);
1025 assert_eq!(
1026 out_dim_size, basis_size,
1027 "out.shape().dim({}) = {} must equal basis_size = {}",
1028 dim, out_dim_size, basis_size
1029 );
1030
1031 for d in 0..rank {
1033 if d != dim {
1034 let values_d = values.shape().dim(d);
1035 let out_d = out.shape().dim(d);
1036 assert_eq!(
1037 values_d, out_d,
1038 "values.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
1039 d, values_d, d, out_d
1040 );
1041 }
1042 }
1043
1044 let result = self.fit_nd(backend, values, dim);
1046
1047 let total = out.len();
1049 for i in 0..total {
1050 let mut idx = vec![0usize; rank];
1051 let mut remaining = i;
1052 for d in (0..rank).rev() {
1053 let dim_size = out.shape().dim(d);
1054 idx[d] = remaining % dim_size;
1055 remaining /= dim_size;
1056 }
1057 out[&idx[..]] = result[&idx[..]];
1058 }
1059 }
1060}
1061
1062impl<S: StatisticsType> InplaceFitter for MatsubaraSamplingPositiveOnly<S> {
1070 fn n_points(&self) -> usize {
1071 self.n_sampling_points()
1072 }
1073
1074 fn basis_size(&self) -> usize {
1075 self.basis_size()
1076 }
1077
1078 fn evaluate_nd_dz_to(
1079 &self,
1080 backend: Option<&GemmBackendHandle>,
1081 coeffs: &Slice<f64, DynRank>,
1082 dim: usize,
1083 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
1084 ) -> bool {
1085 self.fitter.evaluate_nd_dz_to(backend, coeffs, dim, out)
1086 }
1087
1088 fn evaluate_nd_zz_to(
1089 &self,
1090 backend: Option<&GemmBackendHandle>,
1091 coeffs: &Slice<Complex<f64>, DynRank>,
1092 dim: usize,
1093 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
1094 ) -> bool {
1095 self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
1096 }
1097
1098 fn fit_nd_zd_to(
1099 &self,
1100 backend: Option<&GemmBackendHandle>,
1101 values: &Slice<Complex<f64>, DynRank>,
1102 dim: usize,
1103 out: &mut ViewMut<'_, f64, DynRank>,
1104 ) -> bool {
1105 self.fitter.fit_nd_zd_to(backend, values, dim, out)
1106 }
1107
1108 fn fit_nd_zz_to(
1109 &self,
1110 backend: Option<&GemmBackendHandle>,
1111 values: &Slice<Complex<f64>, DynRank>,
1112 dim: usize,
1113 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
1114 ) -> bool {
1115 self.fitter.fit_nd_zz_to(backend, values, dim, out)
1116 }
1117}
1118
1119#[cfg(test)]
1120#[path = "matsubara_sampling_tests.rs"]
1121mod tests;