1use crate::fitters::{ComplexMatrixFitter, ComplexToRealFitter, InplaceFitter};
7use crate::freq::MatsubaraFreq;
8use crate::gemm::GemmBackendHandle;
9use crate::traits::StatisticsType;
10use mdarray::{DTensor, DynRank, Shape, Slice, Tensor, ViewMut};
11use num_complex::Complex;
12use std::marker::PhantomData;
13
14fn movedim<T: Clone>(arr: &Slice<T, DynRank>, src: usize, dst: usize) -> Tensor<T, DynRank> {
16 if src == dst {
17 return arr.to_tensor();
18 }
19
20 let rank = arr.rank();
21 assert!(
22 src < rank,
23 "src axis {} out of bounds for rank {}",
24 src,
25 rank
26 );
27 assert!(
28 dst < rank,
29 "dst axis {} out of bounds for rank {}",
30 dst,
31 rank
32 );
33
34 let mut perm = Vec::with_capacity(rank);
36 let mut pos = 0;
37 for i in 0..rank {
38 if i == dst {
39 perm.push(src);
40 } else {
41 if pos == src {
42 pos += 1;
43 }
44 if pos < rank {
45 perm.push(pos);
46 pos += 1;
47 }
48 }
49 }
50
51 arr.permute(&perm[..]).to_tensor()
52}
53
/// Sampler that pairs a list of Matsubara frequencies with a complex matrix fitter.
///
/// `S` is the statistics type carried by the frequencies.
pub struct MatsubaraSampling<S: StatisticsType> {
    /// Sampling frequencies; every constructor in this file sorts them before storing.
    sampling_points: Vec<MatsubaraFreq<S>>,
    /// Fitter built from the evaluation matrix supplied at construction time.
    fitter: ComplexMatrixFitter,
    /// Zero-sized marker carrying the statistics parameter `S`.
    _phantom: PhantomData<S>,
}
62
63impl<S: StatisticsType> MatsubaraSampling<S> {
64 pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
68 where
69 S: 'static,
70 {
71 let sampling_points = basis.default_matsubara_sampling_points(false);
72 Self::with_sampling_points(basis, sampling_points)
73 }
74
75 pub fn with_sampling_points(
77 basis: &impl crate::basis_trait::Basis<S>,
78 mut sampling_points: Vec<MatsubaraFreq<S>>,
79 ) -> Self
80 where
81 S: 'static,
82 {
83 sampling_points.sort();
85
86 let matrix = basis.evaluate_matsubara(&sampling_points);
89
90 let fitter = ComplexMatrixFitter::new(matrix);
92
93 Self {
94 sampling_points,
95 fitter,
96 _phantom: PhantomData,
97 }
98 }
99
100 pub fn from_matrix(
115 mut sampling_points: Vec<MatsubaraFreq<S>>,
116 matrix: DTensor<Complex<f64>, 2>,
117 ) -> Self {
118 assert!(!sampling_points.is_empty(), "No sampling points given");
119 assert_eq!(
120 matrix.shape().0,
121 sampling_points.len(),
122 "Matrix rows ({}) must match number of sampling points ({})",
123 matrix.shape().0,
124 sampling_points.len()
125 );
126
127 sampling_points.sort();
129
130 let fitter = ComplexMatrixFitter::new(matrix);
131
132 Self {
133 sampling_points,
134 fitter,
135 _phantom: PhantomData,
136 }
137 }
138
    /// Returns the stored sampling frequencies (the constructors store them sorted).
    pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
        &self.sampling_points
    }
143
    /// Returns the number of stored sampling frequencies.
    pub fn n_sampling_points(&self) -> usize {
        self.sampling_points.len()
    }
148
    /// Returns the basis size reported by the underlying fitter.
    pub fn basis_size(&self) -> usize {
        self.fitter.basis_size()
    }
153
    /// Returns a reference to the complex matrix held by the fitter.
    pub fn matrix(&self) -> &DTensor<Complex<f64>, 2> {
        &self.fitter.matrix
    }
158
    /// Forwards `coeffs` to the fitter's `evaluate`, passing `None` as the GEMM backend.
    ///
    /// NOTE(review): presumably `coeffs` holds one value per basis function and the result one
    /// value per sampling point; confirm against `ComplexMatrixFitter::evaluate`.
    pub fn evaluate(&self, coeffs: &[Complex<f64>]) -> Vec<Complex<f64>> {
        self.fitter.evaluate(None, coeffs)
    }
169
    /// Forwards `values` to the fitter's `fit`, passing `None` as the GEMM backend.
    ///
    /// NOTE(review): presumably `values` holds one value per sampling point and the result one
    /// coefficient per basis function; confirm against `ComplexMatrixFitter::fit`.
    pub fn fit(&self, values: &[Complex<f64>]) -> Vec<Complex<f64>> {
        self.fitter.fit(None, values)
    }
180
181 fn evaluate_nd_impl_real(
209 &self,
210 backend: Option<&GemmBackendHandle>,
211 coeffs: &Slice<f64, DynRank>,
212 dim: usize,
213 ) -> Tensor<Complex<f64>, DynRank> {
214 let rank = coeffs.rank();
215 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
216
217 let basis_size = self.basis_size();
218 let target_dim_size = coeffs.shape().dim(dim);
219
220 assert_eq!(
221 target_dim_size, basis_size,
222 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
223 dim, target_dim_size, basis_size
224 );
225
226 let coeffs_dim0 = movedim(coeffs, dim, 0);
228
229 let extra_size: usize = coeffs_dim0.len() / basis_size;
231
232 let coeffs_2d_dyn = coeffs_dim0
233 .reshape(&[basis_size, extra_size][..])
234 .to_tensor();
235
236 let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
238 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
239 });
240 let coeffs_2d_view = coeffs_2d.view(.., ..);
241 let result_2d = self.fitter.evaluate_2d_real(backend, &coeffs_2d_view);
242
243 let n_points = self.n_sampling_points();
245 let mut result_shape = vec![n_points];
246 coeffs_dim0.shape().with_dims(|dims| {
247 for i in 1..dims.len() {
248 result_shape.push(dims[i]);
249 }
250 });
251
252 let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
253
254 movedim(&result_dim0, 0, dim)
256 }
257
258 fn evaluate_nd_impl_complex(
260 &self,
261 backend: Option<&GemmBackendHandle>,
262 coeffs: &Slice<Complex<f64>, DynRank>,
263 dim: usize,
264 ) -> Tensor<Complex<f64>, DynRank> {
265 let rank = coeffs.rank();
266 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
267
268 let basis_size = self.basis_size();
269 let target_dim_size = coeffs.shape().dim(dim);
270
271 assert_eq!(
272 target_dim_size, basis_size,
273 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
274 dim, target_dim_size, basis_size
275 );
276
277 let coeffs_dim0 = movedim(coeffs, dim, 0);
279
280 let extra_size: usize = coeffs_dim0.len() / basis_size;
282
283 let coeffs_2d_dyn = coeffs_dim0
284 .reshape(&[basis_size, extra_size][..])
285 .to_tensor();
286
287 let coeffs_2d = DTensor::<Complex<f64>, 2>::from_fn([basis_size, extra_size], |idx| {
289 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
290 });
291 let coeffs_2d_view = coeffs_2d.view(.., ..);
292 let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d_view);
293
294 let n_points = self.n_sampling_points();
296 let mut result_shape = vec![n_points];
297 coeffs_dim0.shape().with_dims(|dims| {
298 for i in 1..dims.len() {
299 result_shape.push(dims[i]);
300 }
301 });
302
303 let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
304
305 movedim(&result_dim0, 0, dim)
307 }
308
309 pub fn evaluate_nd<T>(
310 &self,
311 backend: Option<&GemmBackendHandle>,
312 coeffs: &Slice<T, DynRank>,
313 dim: usize,
314 ) -> Tensor<Complex<f64>, DynRank>
315 where
316 T: Copy + 'static,
317 {
318 use std::any::TypeId;
319
320 if TypeId::of::<T>() == TypeId::of::<f64>() {
321 let coeffs_f64 =
324 unsafe { &*(coeffs as *const Slice<T, DynRank> as *const Slice<f64, DynRank>) };
325 self.evaluate_nd_impl_real(backend, coeffs_f64, dim)
326 } else if TypeId::of::<T>() == TypeId::of::<Complex<f64>>() {
327 let coeffs_complex = unsafe {
330 &*(coeffs as *const Slice<T, DynRank> as *const Slice<Complex<f64>, DynRank>)
331 };
332 self.evaluate_nd_impl_complex(backend, coeffs_complex, dim)
333 } else {
334 panic!("Unsupported type for evaluate_nd: must be f64 or Complex<f64>");
335 }
336 }
337
338 pub fn evaluate_nd_real(
351 &self,
352 backend: Option<&GemmBackendHandle>,
353 coeffs: &Tensor<f64, DynRank>,
354 dim: usize,
355 ) -> Tensor<Complex<f64>, DynRank> {
356 let rank = coeffs.rank();
357 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
358
359 let basis_size = self.basis_size();
360 let target_dim_size = coeffs.shape().dim(dim);
361
362 assert_eq!(
363 target_dim_size, basis_size,
364 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
365 dim, target_dim_size, basis_size
366 );
367
368 let coeffs_dim0 = movedim(coeffs, dim, 0);
370
371 let extra_size: usize = coeffs_dim0.len() / basis_size;
373
374 let coeffs_2d_dyn = coeffs_dim0
375 .reshape(&[basis_size, extra_size][..])
376 .to_tensor();
377
378 let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
380 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
381 });
382
383 let coeffs_2d_view = coeffs_2d.view(.., ..);
385 let values_2d = self.fitter.evaluate_2d_real(backend, &coeffs_2d_view);
386
387 let n_points = self.n_sampling_points();
389 let mut result_shape = Vec::with_capacity(rank);
390 result_shape.push(n_points);
391 coeffs_dim0.shape().with_dims(|dims| {
392 for i in 1..dims.len() {
393 result_shape.push(dims[i]);
394 }
395 });
396
397 let result_dim0 = values_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
398
399 movedim(&result_dim0, 0, dim)
401 }
402
403 pub fn fit_nd(
413 &self,
414 backend: Option<&GemmBackendHandle>,
415 values: &Tensor<Complex<f64>, DynRank>,
416 dim: usize,
417 ) -> Tensor<Complex<f64>, DynRank> {
418 let rank = values.rank();
419 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
420
421 let n_points = self.n_sampling_points();
422 let target_dim_size = values.shape().dim(dim);
423
424 assert_eq!(
425 target_dim_size, n_points,
426 "values.shape().dim({}) = {} must equal n_sampling_points = {}",
427 dim, target_dim_size, n_points
428 );
429
430 let values_dim0 = movedim(values, dim, 0);
432
433 let extra_size: usize = values_dim0.len() / n_points;
435 let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
436
437 let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
439 values_2d_dyn[&[idx[0], idx[1]][..]]
440 });
441
442 let values_2d_view = values_2d.view(.., ..);
444 let coeffs_2d = self.fitter.fit_2d(backend, &values_2d_view);
445
446 let basis_size = self.basis_size();
448 let mut coeffs_shape = vec![basis_size];
449 values_dim0.shape().with_dims(|dims| {
450 for i in 1..dims.len() {
451 coeffs_shape.push(dims[i]);
452 }
453 });
454
455 let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
456
457 movedim(&coeffs_dim0, 0, dim)
459 }
460
461 pub fn fit_nd_real(
474 &self,
475 backend: Option<&GemmBackendHandle>,
476 values: &Tensor<Complex<f64>, DynRank>,
477 dim: usize,
478 ) -> Tensor<f64, DynRank> {
479 let rank = values.rank();
480 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
481
482 let n_points = self.n_sampling_points();
483 let target_dim_size = values.shape().dim(dim);
484
485 assert_eq!(
486 target_dim_size, n_points,
487 "values.shape().dim({}) = {} must equal n_sampling_points = {}",
488 dim, target_dim_size, n_points
489 );
490
491 let values_dim0 = movedim(values, dim, 0);
493
494 let extra_size: usize = values_dim0.len() / n_points;
496 let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
497
498 let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
500 values_2d_dyn[&[idx[0], idx[1]][..]]
501 });
502
503 let values_2d_view = values_2d.view(.., ..);
505 let coeffs_2d = self.fitter.fit_2d_real(backend, &values_2d_view);
506
507 let basis_size = self.basis_size();
509 let mut coeffs_shape = vec![basis_size];
510 values_dim0.shape().with_dims(|dims| {
511 for i in 1..dims.len() {
512 coeffs_shape.push(dims[i]);
513 }
514 });
515
516 let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
517
518 movedim(&coeffs_dim0, 0, dim)
520 }
521
522 pub fn evaluate_nd_to<T>(
532 &self,
533 backend: Option<&GemmBackendHandle>,
534 coeffs: &Slice<T, DynRank>,
535 dim: usize,
536 out: &mut Tensor<Complex<f64>, DynRank>,
537 ) where
538 T: Copy + 'static,
539 {
540 let rank = coeffs.rank();
542 assert_eq!(
543 out.rank(),
544 rank,
545 "out.rank()={} must equal coeffs.rank()={}",
546 out.rank(),
547 rank
548 );
549
550 let n_points = self.n_sampling_points();
551 let out_dim_size = out.shape().dim(dim);
552 assert_eq!(
553 out_dim_size, n_points,
554 "out.shape().dim({}) = {} must equal n_sampling_points = {}",
555 dim, out_dim_size, n_points
556 );
557
558 for d in 0..rank {
560 if d != dim {
561 let coeffs_d = coeffs.shape().dim(d);
562 let out_d = out.shape().dim(d);
563 assert_eq!(
564 coeffs_d, out_d,
565 "coeffs.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
566 d, coeffs_d, d, out_d
567 );
568 }
569 }
570
571 let result = self.evaluate_nd(backend, coeffs, dim);
573
574 let total = out.len();
576 for i in 0..total {
577 let mut idx = vec![0usize; rank];
578 let mut remaining = i;
579 for d in (0..rank).rev() {
580 let dim_size = out.shape().dim(d);
581 idx[d] = remaining % dim_size;
582 remaining /= dim_size;
583 }
584 out[&idx[..]] = result[&idx[..]];
585 }
586 }
587
588 pub fn fit_nd_to(
595 &self,
596 backend: Option<&GemmBackendHandle>,
597 values: &Tensor<Complex<f64>, DynRank>,
598 dim: usize,
599 out: &mut Tensor<Complex<f64>, DynRank>,
600 ) {
601 let rank = values.rank();
603 assert_eq!(
604 out.rank(),
605 rank,
606 "out.rank()={} must equal values.rank()={}",
607 out.rank(),
608 rank
609 );
610
611 let basis_size = self.basis_size();
612 let out_dim_size = out.shape().dim(dim);
613 assert_eq!(
614 out_dim_size, basis_size,
615 "out.shape().dim({}) = {} must equal basis_size = {}",
616 dim, out_dim_size, basis_size
617 );
618
619 for d in 0..rank {
621 if d != dim {
622 let values_d = values.shape().dim(d);
623 let out_d = out.shape().dim(d);
624 assert_eq!(
625 values_d, out_d,
626 "values.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
627 d, values_d, d, out_d
628 );
629 }
630 }
631
632 let result = self.fit_nd(backend, values, dim);
634
635 let total = out.len();
637 for i in 0..total {
638 let mut idx = vec![0usize; rank];
639 let mut remaining = i;
640 for d in (0..rank).rev() {
641 let dim_size = out.shape().dim(d);
642 idx[d] = remaining % dim_size;
643 remaining /= dim_size;
644 }
645 out[&idx[..]] = result[&idx[..]];
646 }
647 }
648}
649
/// [`InplaceFitter`] implementation that forwards every operation to the wrapped
/// `ComplexMatrixFitter`.
///
/// NOTE(review): the meaning of the `bool` returned by the `*_to` methods is defined by the
/// fitter and is not visible in this file; it is passed through unchanged.
impl<S: StatisticsType> InplaceFitter for MatsubaraSampling<S> {
    /// Number of sampling points, as reported by `n_sampling_points()`.
    fn n_points(&self) -> usize {
        self.n_sampling_points()
    }

    /// Basis size, as reported by the inherent `basis_size()` method.
    fn basis_size(&self) -> usize {
        // Inherent methods take precedence over trait methods, so this is not a recursive call.
        self.basis_size()
    }

    /// Evaluates `f64` coefficients along `dim` into a complex output view (delegates to the
    /// fitter).
    fn evaluate_nd_dz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Slice<f64, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        self.fitter.evaluate_nd_dz_to(backend, coeffs, dim, out)
    }

    /// Evaluates complex coefficients along `dim` into a complex output view (delegates to the
    /// fitter).
    fn evaluate_nd_zz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
    }

    /// Fits complex values along `dim` into an `f64` output view (delegates to the fitter).
    fn fit_nd_zd_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        values: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, f64, DynRank>,
    ) -> bool {
        self.fitter.fit_nd_zd_to(backend, values, dim, out)
    }

    /// Fits complex values along `dim` into a complex output view (delegates to the fitter).
    fn fit_nd_zz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        values: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        self.fitter.fit_nd_zz_to(backend, values, dim, out)
    }
}
705
/// Sampler that pairs a list of Matsubara frequencies with a complex-to-real fitter.
///
/// Its `evaluate*` methods take real (`f64`) coefficients and its `fit*` methods return real
/// coefficients. `S` is the statistics type carried by the frequencies.
///
/// NOTE(review): the name suggests only non-negative frequencies are stored; nothing in this
/// file enforces that for caller-supplied points — confirm against the callers.
pub struct MatsubaraSamplingPositiveOnly<S: StatisticsType> {
    /// Sampling frequencies; every constructor in this file sorts them before storing.
    sampling_points: Vec<MatsubaraFreq<S>>,
    /// Fitter built from the evaluation matrix supplied at construction time.
    fitter: ComplexToRealFitter,
    /// Zero-sized marker carrying the statistics parameter `S`.
    _phantom: PhantomData<S>,
}
715
716impl<S: StatisticsType> MatsubaraSamplingPositiveOnly<S> {
717 pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
722 where
723 S: 'static,
724 {
725 let sampling_points = basis.default_matsubara_sampling_points(true);
726 Self::with_sampling_points(basis, sampling_points)
727 }
728
729 pub fn with_sampling_points(
731 basis: &impl crate::basis_trait::Basis<S>,
732 mut sampling_points: Vec<MatsubaraFreq<S>>,
733 ) -> Self
734 where
735 S: 'static,
736 {
737 sampling_points.sort();
739
740 let matrix = basis.evaluate_matsubara(&sampling_points);
745
746 let fitter = ComplexToRealFitter::new(&matrix);
748
749 Self {
750 sampling_points,
751 fitter,
752 _phantom: PhantomData,
753 }
754 }
755
756 pub fn from_matrix(
771 mut sampling_points: Vec<MatsubaraFreq<S>>,
772 matrix: DTensor<Complex<f64>, 2>,
773 ) -> Self {
774 assert!(!sampling_points.is_empty(), "No sampling points given");
775 assert_eq!(
776 matrix.shape().0,
777 sampling_points.len(),
778 "Matrix rows ({}) must match number of sampling points ({})",
779 matrix.shape().0,
780 sampling_points.len()
781 );
782
783 sampling_points.sort();
785
786 let fitter = ComplexToRealFitter::new(&matrix);
787
788 Self {
789 sampling_points,
790 fitter,
791 _phantom: PhantomData,
792 }
793 }
794
    /// Returns the stored sampling frequencies (the constructors store them sorted).
    pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
        &self.sampling_points
    }
799
    /// Returns the number of stored sampling frequencies.
    pub fn n_sampling_points(&self) -> usize {
        self.sampling_points.len()
    }
804
    /// Returns the basis size reported by the underlying fitter.
    pub fn basis_size(&self) -> usize {
        self.fitter.basis_size()
    }
809
    /// Returns a reference to the complex matrix held by the fitter.
    pub fn matrix(&self) -> &DTensor<Complex<f64>, 2> {
        &self.fitter.matrix
    }
814
    /// Forwards real `coeffs` to the fitter's `evaluate`, passing `None` as the GEMM backend.
    ///
    /// NOTE(review): presumably `coeffs` holds one value per basis function and the result one
    /// value per sampling point; confirm against `ComplexToRealFitter::evaluate`.
    pub fn evaluate(&self, coeffs: &[f64]) -> Vec<Complex<f64>> {
        self.fitter.evaluate(None, coeffs)
    }
819
    /// Forwards `values` to the fitter's `fit`, passing `None` as the GEMM backend, and returns
    /// real coefficients.
    ///
    /// NOTE(review): presumably `values` holds one value per sampling point and the result one
    /// coefficient per basis function; confirm against `ComplexToRealFitter::fit`.
    pub fn fit(&self, values: &[Complex<f64>]) -> Vec<f64> {
        self.fitter.fit(None, values)
    }
824
825 pub fn evaluate_nd(
834 &self,
835 backend: Option<&GemmBackendHandle>,
836 coeffs: &Tensor<f64, DynRank>,
837 dim: usize,
838 ) -> Tensor<Complex<f64>, DynRank> {
839 let rank = coeffs.rank();
840 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
841
842 let basis_size = self.basis_size();
843 let target_dim_size = coeffs.shape().dim(dim);
844
845 assert_eq!(
846 target_dim_size, basis_size,
847 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
848 dim, target_dim_size, basis_size
849 );
850
851 let coeffs_dim0 = movedim(coeffs, dim, 0);
853
854 let extra_size: usize = coeffs_dim0.len() / basis_size;
856
857 let coeffs_2d_dyn = coeffs_dim0
858 .reshape(&[basis_size, extra_size][..])
859 .to_tensor();
860
861 let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
863 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
864 });
865
866 let coeffs_2d_view = coeffs_2d.view(.., ..);
868 let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d_view);
869
870 let n_points = self.n_sampling_points();
872 let mut result_shape = vec![n_points];
873 coeffs_dim0.shape().with_dims(|dims| {
874 for i in 1..dims.len() {
875 result_shape.push(dims[i]);
876 }
877 });
878
879 let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
880
881 movedim(&result_dim0, 0, dim)
883 }
884
885 pub fn fit_nd(
895 &self,
896 backend: Option<&GemmBackendHandle>,
897 values: &Tensor<Complex<f64>, DynRank>,
898 dim: usize,
899 ) -> Tensor<f64, DynRank> {
900 let rank = values.rank();
901 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
902
903 let n_points = self.n_sampling_points();
904 let target_dim_size = values.shape().dim(dim);
905
906 assert_eq!(
907 target_dim_size, n_points,
908 "values.shape().dim({}) = {} must equal n_sampling_points = {}",
909 dim, target_dim_size, n_points
910 );
911
912 let values_dim0 = movedim(values, dim, 0);
914
915 let extra_size: usize = values_dim0.len() / n_points;
917 let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
918
919 let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
921 values_2d_dyn[&[idx[0], idx[1]][..]]
922 });
923
924 let values_2d_view = values_2d.view(.., ..);
926 let coeffs_2d = self.fitter.fit_2d(backend, &values_2d_view);
927
928 let basis_size = self.basis_size();
930 let mut coeffs_shape = vec![basis_size];
931 values_dim0.shape().with_dims(|dims| {
932 for i in 1..dims.len() {
933 coeffs_shape.push(dims[i]);
934 }
935 });
936
937 let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
938
939 movedim(&coeffs_dim0, 0, dim)
941 }
942
943 pub fn evaluate_nd_to(
950 &self,
951 backend: Option<&GemmBackendHandle>,
952 coeffs: &Tensor<f64, DynRank>,
953 dim: usize,
954 out: &mut Tensor<Complex<f64>, DynRank>,
955 ) {
956 let rank = coeffs.rank();
958 assert_eq!(
959 out.rank(),
960 rank,
961 "out.rank()={} must equal coeffs.rank()={}",
962 out.rank(),
963 rank
964 );
965
966 let n_points = self.n_sampling_points();
967 let out_dim_size = out.shape().dim(dim);
968 assert_eq!(
969 out_dim_size, n_points,
970 "out.shape().dim({}) = {} must equal n_sampling_points = {}",
971 dim, out_dim_size, n_points
972 );
973
974 for d in 0..rank {
976 if d != dim {
977 let coeffs_d = coeffs.shape().dim(d);
978 let out_d = out.shape().dim(d);
979 assert_eq!(
980 coeffs_d, out_d,
981 "coeffs.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
982 d, coeffs_d, d, out_d
983 );
984 }
985 }
986
987 let result = self.evaluate_nd(backend, coeffs, dim);
989
990 let total = out.len();
992 for i in 0..total {
993 let mut idx = vec![0usize; rank];
994 let mut remaining = i;
995 for d in (0..rank).rev() {
996 let dim_size = out.shape().dim(d);
997 idx[d] = remaining % dim_size;
998 remaining /= dim_size;
999 }
1000 out[&idx[..]] = result[&idx[..]];
1001 }
1002 }
1003
1004 pub fn fit_nd_to(
1011 &self,
1012 backend: Option<&GemmBackendHandle>,
1013 values: &Tensor<Complex<f64>, DynRank>,
1014 dim: usize,
1015 out: &mut Tensor<f64, DynRank>,
1016 ) {
1017 let rank = values.rank();
1019 assert_eq!(
1020 out.rank(),
1021 rank,
1022 "out.rank()={} must equal values.rank()={}",
1023 out.rank(),
1024 rank
1025 );
1026
1027 let basis_size = self.basis_size();
1028 let out_dim_size = out.shape().dim(dim);
1029 assert_eq!(
1030 out_dim_size, basis_size,
1031 "out.shape().dim({}) = {} must equal basis_size = {}",
1032 dim, out_dim_size, basis_size
1033 );
1034
1035 for d in 0..rank {
1037 if d != dim {
1038 let values_d = values.shape().dim(d);
1039 let out_d = out.shape().dim(d);
1040 assert_eq!(
1041 values_d, out_d,
1042 "values.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
1043 d, values_d, d, out_d
1044 );
1045 }
1046 }
1047
1048 let result = self.fit_nd(backend, values, dim);
1050
1051 let total = out.len();
1053 for i in 0..total {
1054 let mut idx = vec![0usize; rank];
1055 let mut remaining = i;
1056 for d in (0..rank).rev() {
1057 let dim_size = out.shape().dim(d);
1058 idx[d] = remaining % dim_size;
1059 remaining /= dim_size;
1060 }
1061 out[&idx[..]] = result[&idx[..]];
1062 }
1063 }
1064}
1065
/// [`InplaceFitter`] implementation that forwards every operation to the wrapped
/// `ComplexToRealFitter`.
///
/// NOTE(review): the meaning of the `bool` returned by the `*_to` methods is defined by the
/// fitter and is not visible in this file; it is passed through unchanged.
impl<S: StatisticsType> InplaceFitter for MatsubaraSamplingPositiveOnly<S> {
    /// Number of sampling points, as reported by `n_sampling_points()`.
    fn n_points(&self) -> usize {
        self.n_sampling_points()
    }

    /// Basis size, as reported by the inherent `basis_size()` method.
    fn basis_size(&self) -> usize {
        // Inherent methods take precedence over trait methods, so this is not a recursive call.
        self.basis_size()
    }

    /// Evaluates `f64` coefficients along `dim` into a complex output view (delegates to the
    /// fitter).
    fn evaluate_nd_dz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Slice<f64, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        self.fitter.evaluate_nd_dz_to(backend, coeffs, dim, out)
    }

    /// Evaluates complex coefficients along `dim` into a complex output view (delegates to the
    /// fitter).
    fn evaluate_nd_zz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
    }

    /// Fits complex values along `dim` into an `f64` output view (delegates to the fitter).
    fn fit_nd_zd_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        values: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, f64, DynRank>,
    ) -> bool {
        self.fitter.fit_nd_zd_to(backend, values, dim, out)
    }

    /// Fits complex values along `dim` into a complex output view (delegates to the fitter).
    fn fit_nd_zz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        values: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        self.fitter.fit_nd_zz_to(backend, values, dim, out)
    }
}
1122
// Unit tests are compiled only for test builds and live in `matsubara_sampling_tests.rs`.
#[cfg(test)]
#[path = "matsubara_sampling_tests.rs"]
mod tests;