1use crate::fitters::{ComplexMatrixFitter, ComplexToRealFitter, InplaceFitter};
7use crate::freq::MatsubaraFreq;
8use crate::gemm::GemmBackendHandle;
9use crate::sampling::movedim;
10use crate::traits::StatisticsType;
11use mdarray::{DTensor, DynRank, Shape, Slice, Tensor, ViewMut};
12use num_complex::Complex;
13use std::marker::PhantomData;
14
/// Coefficient element types that [`MatsubaraSampling::evaluate_nd`] can accept.
///
/// This file implements the trait for `f64` and `Complex<f64>`; each
/// implementation forwards to the matching private routine on
/// [`MatsubaraSampling`].
pub trait MatsubaraCoeffs: Copy + 'static {
    /// Evaluates `coeffs` along axis `dim` using `sampler` and returns the
    /// complex result.
    ///
    /// `backend` is handed on unchanged to the sampler.
    fn evaluate_nd_with<S: StatisticsType>(
        sampler: &MatsubaraSampling<S>,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Slice<Self, DynRank>,
        dim: usize,
    ) -> Tensor<Complex<f64>, DynRank>;
}
28
29impl MatsubaraCoeffs for f64 {
30 fn evaluate_nd_with<S: StatisticsType>(
31 sampler: &MatsubaraSampling<S>,
32 backend: Option<&GemmBackendHandle>,
33 coeffs: &Slice<Self, DynRank>,
34 dim: usize,
35 ) -> Tensor<Complex<f64>, DynRank> {
36 sampler.evaluate_nd_impl_real(backend, coeffs, dim)
37 }
38}
39
40impl MatsubaraCoeffs for Complex<f64> {
41 fn evaluate_nd_with<S: StatisticsType>(
42 sampler: &MatsubaraSampling<S>,
43 backend: Option<&GemmBackendHandle>,
44 coeffs: &Slice<Self, DynRank>,
45 dim: usize,
46 ) -> Tensor<Complex<f64>, DynRank> {
47 sampler.evaluate_nd_impl_complex(backend, coeffs, dim)
48 }
49}
50
/// Matsubara sampling object: a list of sampling frequencies together with a
/// fitter built from the sampling matrix.
pub struct MatsubaraSampling<S: StatisticsType> {
    /// Sampling frequencies; the constructors in this file store them sorted.
    sampling_points: Vec<MatsubaraFreq<S>>,
    /// Fitter wrapping the complex sampling matrix.
    fitter: ComplexMatrixFitter,
    /// Marker tying the struct to the statistics type `S`.
    _phantom: PhantomData<S>,
}
59
60impl<S: StatisticsType> MatsubaraSampling<S> {
61 pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
65 where
66 S: 'static,
67 {
68 let sampling_points = basis.default_matsubara_sampling_points(false);
69 Self::with_sampling_points(basis, sampling_points)
70 }
71
72 pub fn with_sampling_points(
74 basis: &impl crate::basis_trait::Basis<S>,
75 mut sampling_points: Vec<MatsubaraFreq<S>>,
76 ) -> Self
77 where
78 S: 'static,
79 {
80 sampling_points.sort();
82
83 let matrix = basis.evaluate_matsubara(&sampling_points);
86
87 let fitter = ComplexMatrixFitter::new(matrix);
89
90 Self {
91 sampling_points,
92 fitter,
93 _phantom: PhantomData,
94 }
95 }
96
97 pub fn from_matrix(
112 mut sampling_points: Vec<MatsubaraFreq<S>>,
113 matrix: DTensor<Complex<f64>, 2>,
114 ) -> Self {
115 assert!(!sampling_points.is_empty(), "No sampling points given");
116 assert_eq!(
117 matrix.shape().0,
118 sampling_points.len(),
119 "Matrix rows ({}) must match number of sampling points ({})",
120 matrix.shape().0,
121 sampling_points.len()
122 );
123
124 sampling_points.sort();
126
127 let fitter = ComplexMatrixFitter::new(matrix);
128
129 Self {
130 sampling_points,
131 fitter,
132 _phantom: PhantomData,
133 }
134 }
135
136 pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
138 &self.sampling_points
139 }
140
141 pub fn n_sampling_points(&self) -> usize {
143 self.sampling_points.len()
144 }
145
146 pub fn basis_size(&self) -> usize {
148 self.fitter.basis_size()
149 }
150
151 pub fn matrix(&self) -> &DTensor<Complex<f64>, 2> {
153 &self.fitter.matrix
154 }
155
156 pub fn evaluate(&self, coeffs: &[Complex<f64>]) -> Vec<Complex<f64>> {
164 self.fitter.evaluate(None, coeffs)
165 }
166
167 pub fn fit(&self, values: &[Complex<f64>]) -> Vec<Complex<f64>> {
175 self.fitter.fit(None, values)
176 }
177
178 fn evaluate_nd_impl_real(
206 &self,
207 backend: Option<&GemmBackendHandle>,
208 coeffs: &Slice<f64, DynRank>,
209 dim: usize,
210 ) -> Tensor<Complex<f64>, DynRank> {
211 let rank = coeffs.rank();
212 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
213
214 let basis_size = self.basis_size();
215 let target_dim_size = coeffs.shape().dim(dim);
216
217 assert_eq!(
218 target_dim_size, basis_size,
219 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
220 dim, target_dim_size, basis_size
221 );
222
223 let coeffs_dim0 = movedim(coeffs, dim, 0);
225
226 let extra_size: usize = coeffs_dim0.len() / basis_size;
228
229 let coeffs_2d_dyn = coeffs_dim0
230 .reshape(&[basis_size, extra_size][..])
231 .to_tensor();
232
233 let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
235 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
236 });
237 let coeffs_2d_view = coeffs_2d.view(.., ..);
238 let result_2d = self.fitter.evaluate_2d_real(backend, &coeffs_2d_view);
239
240 let n_points = self.n_sampling_points();
242 let mut result_shape = vec![n_points];
243 coeffs_dim0.shape().with_dims(|dims| {
244 for i in 1..dims.len() {
245 result_shape.push(dims[i]);
246 }
247 });
248
249 let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
250
251 movedim(&result_dim0, 0, dim)
253 }
254
255 fn evaluate_nd_impl_complex(
257 &self,
258 backend: Option<&GemmBackendHandle>,
259 coeffs: &Slice<Complex<f64>, DynRank>,
260 dim: usize,
261 ) -> Tensor<Complex<f64>, DynRank> {
262 let rank = coeffs.rank();
263 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
264
265 let basis_size = self.basis_size();
266 let target_dim_size = coeffs.shape().dim(dim);
267
268 assert_eq!(
269 target_dim_size, basis_size,
270 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
271 dim, target_dim_size, basis_size
272 );
273
274 let coeffs_dim0 = movedim(coeffs, dim, 0);
276
277 let extra_size: usize = coeffs_dim0.len() / basis_size;
279
280 let coeffs_2d_dyn = coeffs_dim0
281 .reshape(&[basis_size, extra_size][..])
282 .to_tensor();
283
284 let coeffs_2d = DTensor::<Complex<f64>, 2>::from_fn([basis_size, extra_size], |idx| {
286 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
287 });
288 let coeffs_2d_view = coeffs_2d.view(.., ..);
289 let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d_view);
290
291 let n_points = self.n_sampling_points();
293 let mut result_shape = vec![n_points];
294 coeffs_dim0.shape().with_dims(|dims| {
295 for i in 1..dims.len() {
296 result_shape.push(dims[i]);
297 }
298 });
299
300 let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
301
302 movedim(&result_dim0, 0, dim)
304 }
305
306 pub fn evaluate_nd<T: MatsubaraCoeffs>(
322 &self,
323 backend: Option<&GemmBackendHandle>,
324 coeffs: &Slice<T, DynRank>,
325 dim: usize,
326 ) -> Tensor<Complex<f64>, DynRank> {
327 T::evaluate_nd_with(self, backend, coeffs, dim)
328 }
329
330 pub fn evaluate_nd_real(
343 &self,
344 backend: Option<&GemmBackendHandle>,
345 coeffs: &Tensor<f64, DynRank>,
346 dim: usize,
347 ) -> Tensor<Complex<f64>, DynRank> {
348 let rank = coeffs.rank();
349 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
350
351 let basis_size = self.basis_size();
352 let target_dim_size = coeffs.shape().dim(dim);
353
354 assert_eq!(
355 target_dim_size, basis_size,
356 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
357 dim, target_dim_size, basis_size
358 );
359
360 let coeffs_dim0 = movedim(coeffs, dim, 0);
362
363 let extra_size: usize = coeffs_dim0.len() / basis_size;
365
366 let coeffs_2d_dyn = coeffs_dim0
367 .reshape(&[basis_size, extra_size][..])
368 .to_tensor();
369
370 let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
372 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
373 });
374
375 let coeffs_2d_view = coeffs_2d.view(.., ..);
377 let values_2d = self.fitter.evaluate_2d_real(backend, &coeffs_2d_view);
378
379 let n_points = self.n_sampling_points();
381 let mut result_shape = Vec::with_capacity(rank);
382 result_shape.push(n_points);
383 coeffs_dim0.shape().with_dims(|dims| {
384 for i in 1..dims.len() {
385 result_shape.push(dims[i]);
386 }
387 });
388
389 let result_dim0 = values_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
390
391 movedim(&result_dim0, 0, dim)
393 }
394
395 pub fn fit_nd(
405 &self,
406 backend: Option<&GemmBackendHandle>,
407 values: &Tensor<Complex<f64>, DynRank>,
408 dim: usize,
409 ) -> Tensor<Complex<f64>, DynRank> {
410 let rank = values.rank();
411 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
412
413 let n_points = self.n_sampling_points();
414 let target_dim_size = values.shape().dim(dim);
415
416 assert_eq!(
417 target_dim_size, n_points,
418 "values.shape().dim({}) = {} must equal n_sampling_points = {}",
419 dim, target_dim_size, n_points
420 );
421
422 let values_dim0 = movedim(values, dim, 0);
424
425 let extra_size: usize = values_dim0.len() / n_points;
427 let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
428
429 let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
431 values_2d_dyn[&[idx[0], idx[1]][..]]
432 });
433
434 let values_2d_view = values_2d.view(.., ..);
436 let coeffs_2d = self.fitter.fit_2d(backend, &values_2d_view);
437
438 let basis_size = self.basis_size();
440 let mut coeffs_shape = vec![basis_size];
441 values_dim0.shape().with_dims(|dims| {
442 for i in 1..dims.len() {
443 coeffs_shape.push(dims[i]);
444 }
445 });
446
447 let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
448
449 movedim(&coeffs_dim0, 0, dim)
451 }
452
453 pub fn fit_nd_real(
466 &self,
467 backend: Option<&GemmBackendHandle>,
468 values: &Tensor<Complex<f64>, DynRank>,
469 dim: usize,
470 ) -> Tensor<f64, DynRank> {
471 let rank = values.rank();
472 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
473
474 let n_points = self.n_sampling_points();
475 let target_dim_size = values.shape().dim(dim);
476
477 assert_eq!(
478 target_dim_size, n_points,
479 "values.shape().dim({}) = {} must equal n_sampling_points = {}",
480 dim, target_dim_size, n_points
481 );
482
483 let values_dim0 = movedim(values, dim, 0);
485
486 let extra_size: usize = values_dim0.len() / n_points;
488 let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
489
490 let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
492 values_2d_dyn[&[idx[0], idx[1]][..]]
493 });
494
495 let values_2d_view = values_2d.view(.., ..);
497 let coeffs_2d = self.fitter.fit_2d_real(backend, &values_2d_view);
498
499 let basis_size = self.basis_size();
501 let mut coeffs_shape = vec![basis_size];
502 values_dim0.shape().with_dims(|dims| {
503 for i in 1..dims.len() {
504 coeffs_shape.push(dims[i]);
505 }
506 });
507
508 let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
509
510 movedim(&coeffs_dim0, 0, dim)
512 }
513
514 pub fn evaluate_nd_to<T: MatsubaraCoeffs>(
524 &self,
525 backend: Option<&GemmBackendHandle>,
526 coeffs: &Slice<T, DynRank>,
527 dim: usize,
528 out: &mut Tensor<Complex<f64>, DynRank>,
529 ) {
530 let rank = coeffs.rank();
532 assert_eq!(
533 out.rank(),
534 rank,
535 "out.rank()={} must equal coeffs.rank()={}",
536 out.rank(),
537 rank
538 );
539
540 let n_points = self.n_sampling_points();
541 let out_dim_size = out.shape().dim(dim);
542 assert_eq!(
543 out_dim_size, n_points,
544 "out.shape().dim({}) = {} must equal n_sampling_points = {}",
545 dim, out_dim_size, n_points
546 );
547
548 for d in 0..rank {
550 if d != dim {
551 let coeffs_d = coeffs.shape().dim(d);
552 let out_d = out.shape().dim(d);
553 assert_eq!(
554 coeffs_d, out_d,
555 "coeffs.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
556 d, coeffs_d, d, out_d
557 );
558 }
559 }
560
561 let result = self.evaluate_nd(backend, coeffs, dim);
563
564 let total = out.len();
566 for i in 0..total {
567 let mut idx = vec![0usize; rank];
568 let mut remaining = i;
569 for d in (0..rank).rev() {
570 let dim_size = out.shape().dim(d);
571 idx[d] = remaining % dim_size;
572 remaining /= dim_size;
573 }
574 out[&idx[..]] = result[&idx[..]];
575 }
576 }
577
578 pub fn fit_nd_to(
585 &self,
586 backend: Option<&GemmBackendHandle>,
587 values: &Tensor<Complex<f64>, DynRank>,
588 dim: usize,
589 out: &mut Tensor<Complex<f64>, DynRank>,
590 ) {
591 let rank = values.rank();
593 assert_eq!(
594 out.rank(),
595 rank,
596 "out.rank()={} must equal values.rank()={}",
597 out.rank(),
598 rank
599 );
600
601 let basis_size = self.basis_size();
602 let out_dim_size = out.shape().dim(dim);
603 assert_eq!(
604 out_dim_size, basis_size,
605 "out.shape().dim({}) = {} must equal basis_size = {}",
606 dim, out_dim_size, basis_size
607 );
608
609 for d in 0..rank {
611 if d != dim {
612 let values_d = values.shape().dim(d);
613 let out_d = out.shape().dim(d);
614 assert_eq!(
615 values_d, out_d,
616 "values.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
617 d, values_d, d, out_d
618 );
619 }
620 }
621
622 let result = self.fit_nd(backend, values, dim);
624
625 let total = out.len();
627 for i in 0..total {
628 let mut idx = vec![0usize; rank];
629 let mut remaining = i;
630 for d in (0..rank).rev() {
631 let dim_size = out.shape().dim(d);
632 idx[d] = remaining % dim_size;
633 remaining /= dim_size;
634 }
635 out[&idx[..]] = result[&idx[..]];
636 }
637 }
638}
639
640impl<S: StatisticsType> InplaceFitter for MatsubaraSampling<S> {
647 fn n_points(&self) -> usize {
648 self.n_sampling_points()
649 }
650
651 fn basis_size(&self) -> usize {
652 self.basis_size()
653 }
654
655 fn evaluate_nd_dz_to(
656 &self,
657 backend: Option<&GemmBackendHandle>,
658 coeffs: &Slice<f64, DynRank>,
659 dim: usize,
660 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
661 ) -> bool {
662 self.fitter.evaluate_nd_dz_to(backend, coeffs, dim, out)
663 }
664
665 fn evaluate_nd_zz_to(
666 &self,
667 backend: Option<&GemmBackendHandle>,
668 coeffs: &Slice<Complex<f64>, DynRank>,
669 dim: usize,
670 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
671 ) -> bool {
672 self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
673 }
674
675 fn fit_nd_zd_to(
676 &self,
677 backend: Option<&GemmBackendHandle>,
678 values: &Slice<Complex<f64>, DynRank>,
679 dim: usize,
680 out: &mut ViewMut<'_, f64, DynRank>,
681 ) -> bool {
682 self.fitter.fit_nd_zd_to(backend, values, dim, out)
683 }
684
685 fn fit_nd_zz_to(
686 &self,
687 backend: Option<&GemmBackendHandle>,
688 values: &Slice<Complex<f64>, DynRank>,
689 dim: usize,
690 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
691 ) -> bool {
692 self.fitter.fit_nd_zz_to(backend, values, dim, out)
693 }
694}
695
/// Matsubara sampling object whose fitter maps complex sampled values to real
/// coefficients.
///
/// NOTE(review): the name suggests only non-negative frequencies are sampled —
/// confirm against `Basis::default_matsubara_sampling_points`.
pub struct MatsubaraSamplingPositiveOnly<S: StatisticsType> {
    /// Sampling frequencies; the constructors in this file store them sorted.
    sampling_points: Vec<MatsubaraFreq<S>>,
    /// Complex-to-real fitter built from the sampling matrix.
    fitter: ComplexToRealFitter,
    /// Marker tying the struct to the statistics type `S`.
    _phantom: PhantomData<S>,
}
705
706impl<S: StatisticsType> MatsubaraSamplingPositiveOnly<S> {
707 pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
712 where
713 S: 'static,
714 {
715 let sampling_points = basis.default_matsubara_sampling_points(true);
716 Self::with_sampling_points(basis, sampling_points)
717 }
718
719 pub fn with_sampling_points(
721 basis: &impl crate::basis_trait::Basis<S>,
722 mut sampling_points: Vec<MatsubaraFreq<S>>,
723 ) -> Self
724 where
725 S: 'static,
726 {
727 sampling_points.sort();
729
730 let matrix = basis.evaluate_matsubara(&sampling_points);
735
736 let fitter = ComplexToRealFitter::new(&matrix);
738
739 Self {
740 sampling_points,
741 fitter,
742 _phantom: PhantomData,
743 }
744 }
745
746 pub fn from_matrix(
761 mut sampling_points: Vec<MatsubaraFreq<S>>,
762 matrix: DTensor<Complex<f64>, 2>,
763 ) -> Self {
764 assert!(!sampling_points.is_empty(), "No sampling points given");
765 assert_eq!(
766 matrix.shape().0,
767 sampling_points.len(),
768 "Matrix rows ({}) must match number of sampling points ({})",
769 matrix.shape().0,
770 sampling_points.len()
771 );
772
773 sampling_points.sort();
775
776 let fitter = ComplexToRealFitter::new(&matrix);
777
778 Self {
779 sampling_points,
780 fitter,
781 _phantom: PhantomData,
782 }
783 }
784
785 pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
787 &self.sampling_points
788 }
789
790 pub fn n_sampling_points(&self) -> usize {
792 self.sampling_points.len()
793 }
794
795 pub fn basis_size(&self) -> usize {
797 self.fitter.basis_size()
798 }
799
800 pub fn matrix(&self) -> &DTensor<Complex<f64>, 2> {
802 &self.fitter.matrix
803 }
804
805 pub fn evaluate(&self, coeffs: &[f64]) -> Vec<Complex<f64>> {
807 self.fitter.evaluate(None, coeffs)
808 }
809
810 pub fn fit(&self, values: &[Complex<f64>]) -> Vec<f64> {
812 self.fitter.fit(None, values)
813 }
814
815 pub fn evaluate_nd(
824 &self,
825 backend: Option<&GemmBackendHandle>,
826 coeffs: &Tensor<f64, DynRank>,
827 dim: usize,
828 ) -> Tensor<Complex<f64>, DynRank> {
829 let rank = coeffs.rank();
830 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
831
832 let basis_size = self.basis_size();
833 let target_dim_size = coeffs.shape().dim(dim);
834
835 assert_eq!(
836 target_dim_size, basis_size,
837 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
838 dim, target_dim_size, basis_size
839 );
840
841 let coeffs_dim0 = movedim(coeffs, dim, 0);
843
844 let extra_size: usize = coeffs_dim0.len() / basis_size;
846
847 let coeffs_2d_dyn = coeffs_dim0
848 .reshape(&[basis_size, extra_size][..])
849 .to_tensor();
850
851 let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
853 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
854 });
855
856 let coeffs_2d_view = coeffs_2d.view(.., ..);
858 let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d_view);
859
860 let n_points = self.n_sampling_points();
862 let mut result_shape = vec![n_points];
863 coeffs_dim0.shape().with_dims(|dims| {
864 for i in 1..dims.len() {
865 result_shape.push(dims[i]);
866 }
867 });
868
869 let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
870
871 movedim(&result_dim0, 0, dim)
873 }
874
875 pub fn fit_nd(
885 &self,
886 backend: Option<&GemmBackendHandle>,
887 values: &Tensor<Complex<f64>, DynRank>,
888 dim: usize,
889 ) -> Tensor<f64, DynRank> {
890 let rank = values.rank();
891 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
892
893 let n_points = self.n_sampling_points();
894 let target_dim_size = values.shape().dim(dim);
895
896 assert_eq!(
897 target_dim_size, n_points,
898 "values.shape().dim({}) = {} must equal n_sampling_points = {}",
899 dim, target_dim_size, n_points
900 );
901
902 let values_dim0 = movedim(values, dim, 0);
904
905 let extra_size: usize = values_dim0.len() / n_points;
907 let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
908
909 let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
911 values_2d_dyn[&[idx[0], idx[1]][..]]
912 });
913
914 let values_2d_view = values_2d.view(.., ..);
916 let coeffs_2d = self.fitter.fit_2d(backend, &values_2d_view);
917
918 let basis_size = self.basis_size();
920 let mut coeffs_shape = vec![basis_size];
921 values_dim0.shape().with_dims(|dims| {
922 for i in 1..dims.len() {
923 coeffs_shape.push(dims[i]);
924 }
925 });
926
927 let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
928
929 movedim(&coeffs_dim0, 0, dim)
931 }
932
933 pub fn evaluate_nd_to(
940 &self,
941 backend: Option<&GemmBackendHandle>,
942 coeffs: &Tensor<f64, DynRank>,
943 dim: usize,
944 out: &mut Tensor<Complex<f64>, DynRank>,
945 ) {
946 let rank = coeffs.rank();
948 assert_eq!(
949 out.rank(),
950 rank,
951 "out.rank()={} must equal coeffs.rank()={}",
952 out.rank(),
953 rank
954 );
955
956 let n_points = self.n_sampling_points();
957 let out_dim_size = out.shape().dim(dim);
958 assert_eq!(
959 out_dim_size, n_points,
960 "out.shape().dim({}) = {} must equal n_sampling_points = {}",
961 dim, out_dim_size, n_points
962 );
963
964 for d in 0..rank {
966 if d != dim {
967 let coeffs_d = coeffs.shape().dim(d);
968 let out_d = out.shape().dim(d);
969 assert_eq!(
970 coeffs_d, out_d,
971 "coeffs.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
972 d, coeffs_d, d, out_d
973 );
974 }
975 }
976
977 let result = self.evaluate_nd(backend, coeffs, dim);
979
980 let total = out.len();
982 for i in 0..total {
983 let mut idx = vec![0usize; rank];
984 let mut remaining = i;
985 for d in (0..rank).rev() {
986 let dim_size = out.shape().dim(d);
987 idx[d] = remaining % dim_size;
988 remaining /= dim_size;
989 }
990 out[&idx[..]] = result[&idx[..]];
991 }
992 }
993
994 pub fn fit_nd_to(
1001 &self,
1002 backend: Option<&GemmBackendHandle>,
1003 values: &Tensor<Complex<f64>, DynRank>,
1004 dim: usize,
1005 out: &mut Tensor<f64, DynRank>,
1006 ) {
1007 let rank = values.rank();
1009 assert_eq!(
1010 out.rank(),
1011 rank,
1012 "out.rank()={} must equal values.rank()={}",
1013 out.rank(),
1014 rank
1015 );
1016
1017 let basis_size = self.basis_size();
1018 let out_dim_size = out.shape().dim(dim);
1019 assert_eq!(
1020 out_dim_size, basis_size,
1021 "out.shape().dim({}) = {} must equal basis_size = {}",
1022 dim, out_dim_size, basis_size
1023 );
1024
1025 for d in 0..rank {
1027 if d != dim {
1028 let values_d = values.shape().dim(d);
1029 let out_d = out.shape().dim(d);
1030 assert_eq!(
1031 values_d, out_d,
1032 "values.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
1033 d, values_d, d, out_d
1034 );
1035 }
1036 }
1037
1038 let result = self.fit_nd(backend, values, dim);
1040
1041 let total = out.len();
1043 for i in 0..total {
1044 let mut idx = vec![0usize; rank];
1045 let mut remaining = i;
1046 for d in (0..rank).rev() {
1047 let dim_size = out.shape().dim(d);
1048 idx[d] = remaining % dim_size;
1049 remaining /= dim_size;
1050 }
1051 out[&idx[..]] = result[&idx[..]];
1052 }
1053 }
1054}
1055
1056impl<S: StatisticsType> InplaceFitter for MatsubaraSamplingPositiveOnly<S> {
1064 fn n_points(&self) -> usize {
1065 self.n_sampling_points()
1066 }
1067
1068 fn basis_size(&self) -> usize {
1069 self.basis_size()
1070 }
1071
1072 fn evaluate_nd_dz_to(
1073 &self,
1074 backend: Option<&GemmBackendHandle>,
1075 coeffs: &Slice<f64, DynRank>,
1076 dim: usize,
1077 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
1078 ) -> bool {
1079 self.fitter.evaluate_nd_dz_to(backend, coeffs, dim, out)
1080 }
1081
1082 fn evaluate_nd_zz_to(
1083 &self,
1084 backend: Option<&GemmBackendHandle>,
1085 coeffs: &Slice<Complex<f64>, DynRank>,
1086 dim: usize,
1087 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
1088 ) -> bool {
1089 self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
1090 }
1091
1092 fn fit_nd_zd_to(
1093 &self,
1094 backend: Option<&GemmBackendHandle>,
1095 values: &Slice<Complex<f64>, DynRank>,
1096 dim: usize,
1097 out: &mut ViewMut<'_, f64, DynRank>,
1098 ) -> bool {
1099 self.fitter.fit_nd_zd_to(backend, values, dim, out)
1100 }
1101
1102 fn fit_nd_zz_to(
1103 &self,
1104 backend: Option<&GemmBackendHandle>,
1105 values: &Slice<Complex<f64>, DynRank>,
1106 dim: usize,
1107 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
1108 ) -> bool {
1109 self.fitter.fit_nd_zz_to(backend, values, dim, out)
1110 }
1111}
1112
// Unit tests for this module are kept in the sibling file named below and are
// compiled only for test builds.
#[cfg(test)]
#[path = "matsubara_sampling_tests.rs"]
mod tests;