1use crate::fitters::InplaceFitter;
7use crate::gemm::GemmBackendHandle;
8use crate::traits::StatisticsType;
9use mdarray::{DTensor, DynRank, Shape, Slice, Tensor, ViewMut};
10use num_complex::Complex;
11
12fn build_output_shape<S: Shape>(input_shape: &S, dim: usize, new_size: usize) -> Vec<usize> {
14 let mut out_shape: Vec<usize> = Vec::with_capacity(input_shape.rank());
15 input_shape.with_dims(|dims| {
16 for (i, d) in dims.iter().enumerate() {
17 if i == dim {
18 out_shape.push(new_size);
19 } else {
20 out_shape.push(*d);
21 }
22 }
23 });
24 out_shape
25}
26
27pub fn movedim<T: Clone>(arr: &Slice<T, DynRank>, src: usize, dst: usize) -> Tensor<T, DynRank> {
47 if src == dst {
48 return arr.to_tensor();
49 }
50
51 let rank = arr.rank();
52 assert!(
53 src < rank,
54 "src axis {} out of bounds for rank {}",
55 src,
56 rank
57 );
58 assert!(
59 dst < rank,
60 "dst axis {} out of bounds for rank {}",
61 dst,
62 rank
63 );
64
65 let mut perm = Vec::with_capacity(rank);
67 let mut pos = 0;
68 for i in 0..rank {
69 if i == dst {
70 perm.push(src);
71 } else {
72 if pos == src {
74 pos += 1;
75 }
76 perm.push(pos);
77 pos += 1;
78 }
79 }
80
81 arr.permute(&perm[..]).to_tensor()
82}
83
84pub struct TauSampling<S>
89where
90 S: StatisticsType,
91{
92 sampling_points: Vec<f64>,
94
95 fitter: crate::fitters::RealMatrixFitter,
97
98 _phantom: std::marker::PhantomData<S>,
100}
101
102impl<S> TauSampling<S>
103where
104 S: StatisticsType,
105{
106 pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
118 where
119 S: 'static,
120 {
121 let sampling_points = basis.default_tau_sampling_points();
122 Self::with_sampling_points(basis, sampling_points)
123 }
124
125 pub fn with_sampling_points(
139 basis: &impl crate::basis_trait::Basis<S>,
140 sampling_points: Vec<f64>,
141 ) -> Self
142 where
143 S: 'static,
144 {
145 assert!(!sampling_points.is_empty(), "No sampling points given");
146 assert!(
147 basis.size() <= sampling_points.len(),
148 "The number of sampling points must be greater than or equal to the basis size"
149 );
150
151 let beta = basis.beta();
152 for &tau in &sampling_points {
153 assert!(
154 tau >= -beta && tau <= beta,
155 "Sampling point τ={} is outside [-β, β]",
156 tau
157 );
158 }
159
160 let matrix = basis.evaluate_tau(&sampling_points);
163
164 let fitter = crate::fitters::RealMatrixFitter::new(matrix);
166
167 Self {
168 sampling_points,
169 fitter,
170 _phantom: std::marker::PhantomData,
171 }
172 }
173
174 pub fn from_matrix(sampling_points: Vec<f64>, matrix: DTensor<f64, 2>) -> Self {
189 assert!(!sampling_points.is_empty(), "No sampling points given");
190 assert_eq!(
191 matrix.shape().0,
192 sampling_points.len(),
193 "Matrix rows ({}) must match number of sampling points ({})",
194 matrix.shape().0,
195 sampling_points.len()
196 );
197
198 let fitter = crate::fitters::RealMatrixFitter::new(matrix);
199
200 Self {
201 sampling_points,
202 fitter,
203 _phantom: std::marker::PhantomData,
204 }
205 }
206
207 pub fn sampling_points(&self) -> &[f64] {
209 &self.sampling_points
210 }
211
212 pub fn n_sampling_points(&self) -> usize {
214 self.fitter.n_points()
215 }
216
217 pub fn basis_size(&self) -> usize {
219 self.fitter.basis_size()
220 }
221
222 pub fn matrix(&self) -> &DTensor<f64, 2> {
224 &self.fitter.matrix
225 }
226
227 pub fn evaluate(&self, coeffs: &[f64]) -> Vec<f64> {
241 self.fitter.evaluate(None, coeffs)
242 }
243
244 pub fn evaluate_to(&self, coeffs: &[f64], out: &mut [f64]) {
246 self.fitter.evaluate_to(None, coeffs, out)
247 }
248
249 pub fn fit(&self, values: &[f64]) -> Vec<f64> {
251 self.fitter.fit(None, values)
252 }
253
254 pub fn fit_to(&self, values: &[f64], out: &mut [f64]) {
256 self.fitter.fit_to(None, values, out)
257 }
258
259 pub fn evaluate_zz(&self, coeffs: &[Complex<f64>]) -> Vec<Complex<f64>> {
261 self.fitter.evaluate_zz(None, coeffs)
262 }
263
264 pub fn evaluate_zz_to(&self, coeffs: &[Complex<f64>], out: &mut [Complex<f64>]) {
266 self.fitter.evaluate_zz_to(None, coeffs, out)
267 }
268
269 pub fn fit_zz(&self, values: &[Complex<f64>]) -> Vec<Complex<f64>> {
271 self.fitter.fit_zz(None, values)
272 }
273
274 pub fn fit_zz_to(&self, values: &[Complex<f64>], out: &mut [Complex<f64>]) {
276 self.fitter.fit_zz_to(None, values, out)
277 }
278
279 pub fn evaluate_nd(
292 &self,
293 backend: Option<&GemmBackendHandle>,
294 coeffs: &Slice<f64, DynRank>,
295 dim: usize,
296 ) -> Tensor<f64, DynRank> {
297 let out_shape = build_output_shape(coeffs.shape(), dim, self.n_sampling_points());
298 let mut out = Tensor::<f64, DynRank>::zeros(&out_shape[..]);
299 self.evaluate_nd_to(backend, coeffs, dim, &mut out.expr_mut());
300 out
301 }
302
303 pub fn evaluate_nd_to(
305 &self,
306 backend: Option<&GemmBackendHandle>,
307 coeffs: &Slice<f64, DynRank>,
308 dim: usize,
309 out: &mut ViewMut<'_, f64, DynRank>,
310 ) {
311 InplaceFitter::evaluate_nd_dd_to(self, backend, coeffs, dim, out);
312 }
313
314 pub fn fit_nd(
323 &self,
324 backend: Option<&GemmBackendHandle>,
325 values: &Slice<f64, DynRank>,
326 dim: usize,
327 ) -> Tensor<f64, DynRank> {
328 let out_shape = build_output_shape(values.shape(), dim, self.basis_size());
329 let mut out = Tensor::<f64, DynRank>::zeros(&out_shape[..]);
330 self.fit_nd_to(backend, values, dim, &mut out.expr_mut());
331 out
332 }
333
334 pub fn fit_nd_to(
336 &self,
337 backend: Option<&GemmBackendHandle>,
338 values: &Slice<f64, DynRank>,
339 dim: usize,
340 out: &mut ViewMut<'_, f64, DynRank>,
341 ) {
342 InplaceFitter::fit_nd_dd_to(self, backend, values, dim, out);
343 }
344
345 pub fn evaluate_nd_zz(
358 &self,
359 backend: Option<&GemmBackendHandle>,
360 coeffs: &Slice<Complex<f64>, DynRank>,
361 dim: usize,
362 ) -> Tensor<Complex<f64>, DynRank> {
363 let out_shape = build_output_shape(coeffs.shape(), dim, self.n_sampling_points());
364 let mut out = Tensor::<Complex<f64>, DynRank>::zeros(&out_shape[..]);
365 self.evaluate_nd_zz_to(backend, coeffs, dim, &mut out.expr_mut());
366 out
367 }
368
369 pub fn evaluate_nd_zz_to(
371 &self,
372 backend: Option<&GemmBackendHandle>,
373 coeffs: &Slice<Complex<f64>, DynRank>,
374 dim: usize,
375 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
376 ) {
377 InplaceFitter::evaluate_nd_zz_to(self, backend, coeffs, dim, out);
378 }
379
380 pub fn fit_nd_zz(
389 &self,
390 backend: Option<&GemmBackendHandle>,
391 values: &Slice<Complex<f64>, DynRank>,
392 dim: usize,
393 ) -> Tensor<Complex<f64>, DynRank> {
394 let out_shape = build_output_shape(values.shape(), dim, self.basis_size());
395 let mut out = Tensor::<Complex<f64>, DynRank>::zeros(&out_shape[..]);
396 self.fit_nd_zz_to(backend, values, dim, &mut out.expr_mut());
397 out
398 }
399
400 pub fn fit_nd_zz_to(
402 &self,
403 backend: Option<&GemmBackendHandle>,
404 values: &Slice<Complex<f64>, DynRank>,
405 dim: usize,
406 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
407 ) {
408 InplaceFitter::fit_nd_zz_to(self, backend, values, dim, out);
409 }
410}
411
412impl<S: StatisticsType> InplaceFitter for TauSampling<S> {
416 fn n_points(&self) -> usize {
417 self.n_sampling_points()
418 }
419
420 fn basis_size(&self) -> usize {
421 self.basis_size()
422 }
423
424 fn evaluate_nd_dd_to(
425 &self,
426 backend: Option<&GemmBackendHandle>,
427 coeffs: &Slice<f64, DynRank>,
428 dim: usize,
429 out: &mut ViewMut<'_, f64, DynRank>,
430 ) -> bool {
431 self.fitter.evaluate_nd_dd_to(backend, coeffs, dim, out)
432 }
433
434 fn evaluate_nd_zz_to(
435 &self,
436 backend: Option<&GemmBackendHandle>,
437 coeffs: &Slice<Complex<f64>, DynRank>,
438 dim: usize,
439 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
440 ) -> bool {
441 self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
442 }
443
444 fn fit_nd_dd_to(
445 &self,
446 backend: Option<&GemmBackendHandle>,
447 values: &Slice<f64, DynRank>,
448 dim: usize,
449 out: &mut ViewMut<'_, f64, DynRank>,
450 ) -> bool {
451 self.fitter.fit_nd_dd_to(backend, values, dim, out)
452 }
453
454 fn fit_nd_zz_to(
455 &self,
456 backend: Option<&GemmBackendHandle>,
457 values: &Slice<Complex<f64>, DynRank>,
458 dim: usize,
459 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
460 ) -> bool {
461 self.fitter.fit_nd_zz_to(backend, values, dim, out)
462 }
463}
464
465#[cfg(test)]
466#[path = "tau_sampling_tests.rs"]
467mod tests;