sparse_ir/sampling.rs
1//! Sparse sampling in imaginary time
2//!
3//! This module provides `TauSampling` for transforming between IR basis coefficients
4//! and values at sparse sampling points in imaginary time.
5
6use crate::gemm::{GemmBackendHandle, matmul_par};
7use crate::traits::StatisticsType;
8use mdarray::{DTensor, DynRank, Shape, Tensor};
9
10/// Move axis from position `src` to position `dst`
11///
12/// This is equivalent to numpy.moveaxis or libsparseir's movedim.
13/// It creates a permutation array that moves the specified axis.
14///
15/// # Arguments
16/// * `arr` - Input tensor
17/// * `src` - Source axis position
18/// * `dst` - Destination axis position
19///
20/// # Returns
21/// Tensor with axes permuted
22///
23/// # Example
24/// ```ignore
25/// // For a 4D tensor with shape (2, 3, 4, 5)
26/// // movedim(arr, 0, 2) moves axis 0 to position 2
27/// // Result shape: (3, 4, 2, 5) with axes permuted as [1, 2, 0, 3]
28/// ```
29pub fn movedim<T: Clone>(arr: &Tensor<T, DynRank>, src: usize, dst: usize) -> Tensor<T, DynRank> {
30 if src == dst {
31 return arr.clone();
32 }
33
34 let rank = arr.rank();
35 assert!(
36 src < rank,
37 "src axis {} out of bounds for rank {}",
38 src,
39 rank
40 );
41 assert!(
42 dst < rank,
43 "dst axis {} out of bounds for rank {}",
44 dst,
45 rank
46 );
47
48 // Generate permutation: move src to dst position
49 let mut perm = Vec::with_capacity(rank);
50 let mut pos = 0;
51 for i in 0..rank {
52 if i == dst {
53 perm.push(src);
54 } else {
55 // Skip src position
56 if pos == src {
57 pos += 1;
58 }
59 perm.push(pos);
60 pos += 1;
61 }
62 }
63
64 arr.permute(&perm[..]).to_tensor()
65}
66
67/// Sparse sampling in imaginary time
68///
69/// Allows transformation between the IR basis and a set of sampling points
70/// in imaginary time (τ).
71pub struct TauSampling<S>
72where
73 S: StatisticsType,
74{
75 /// Sampling points in imaginary time τ ∈ [0, β]
76 sampling_points: Vec<f64>,
77
78 /// Real matrix fitter for least-squares fitting
79 fitter: crate::fitter::RealMatrixFitter,
80
81 /// Marker for statistics type
82 _phantom: std::marker::PhantomData<S>,
83}
84
85impl<S> TauSampling<S>
86where
87 S: StatisticsType,
88{
89 /// Create a new TauSampling with default sampling points
90 ///
91 /// The default sampling points are chosen as the extrema of the highest-order
92 /// basis function, which gives near-optimal conditioning.
93 /// SVD is computed lazily on first call to `fit` or `fit_nd`.
94 ///
95 /// # Arguments
96 /// * `basis` - Any basis implementing the `Basis` trait
97 ///
98 /// # Returns
99 /// A new TauSampling object
100 pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
101 where
102 S: 'static,
103 {
104 let sampling_points = basis.default_tau_sampling_points();
105 Self::with_sampling_points(basis, sampling_points)
106 }
107
108 /// Create a new TauSampling with custom sampling points
109 ///
110 /// SVD is computed lazily on first call to `fit` or `fit_nd`.
111 ///
112 /// # Arguments
113 /// * `basis` - Any basis implementing the `Basis` trait
114 /// * `sampling_points` - Custom sampling points in τ ∈ [-β, β]
115 ///
116 /// # Returns
117 /// A new TauSampling object
118 ///
119 /// # Panics
120 /// Panics if `sampling_points` is empty or if any point is outside [-β, β]
121 pub fn with_sampling_points(
122 basis: &impl crate::basis_trait::Basis<S>,
123 sampling_points: Vec<f64>,
124 ) -> Self
125 where
126 S: 'static,
127 {
128 assert!(!sampling_points.is_empty(), "No sampling points given");
129 assert!(
130 basis.size() <= sampling_points.len(),
131 "The number of sampling points must be greater than or equal to the basis size"
132 );
133
134 let beta = basis.beta();
135 for &tau in &sampling_points {
136 assert!(
137 tau >= -beta && tau <= beta,
138 "Sampling point τ={} is outside [-β, β]",
139 tau
140 );
141 }
142
143 // Compute sampling matrix: A[i, l] = u_l(τ_i)
144 // Use Basis trait's evaluate_tau method
145 let matrix = basis.evaluate_tau(&sampling_points);
146
147 // Create fitter
148 let fitter = crate::fitter::RealMatrixFitter::new(matrix);
149
150 Self {
151 sampling_points,
152 fitter,
153 _phantom: std::marker::PhantomData,
154 }
155 }
156
157 /// Create a new TauSampling with custom sampling points and pre-computed matrix
158 ///
159 /// This constructor is useful when the sampling matrix is already computed
160 /// (e.g., from external sources or for testing).
161 ///
162 /// # Arguments
163 /// * `sampling_points` - Sampling points in τ ∈ [-β, β]
164 /// * `matrix` - Pre-computed sampling matrix (n_points × basis_size)
165 ///
166 /// # Returns
167 /// A new TauSampling object
168 ///
169 /// # Panics
170 /// Panics if `sampling_points` is empty or if matrix dimensions don't match
171 pub fn from_matrix(sampling_points: Vec<f64>, matrix: DTensor<f64, 2>) -> Self {
172 assert!(!sampling_points.is_empty(), "No sampling points given");
173 assert_eq!(
174 matrix.shape().0,
175 sampling_points.len(),
176 "Matrix rows ({}) must match number of sampling points ({})",
177 matrix.shape().0,
178 sampling_points.len()
179 );
180
181 let fitter = crate::fitter::RealMatrixFitter::new(matrix);
182
183 Self {
184 sampling_points,
185 fitter,
186 _phantom: std::marker::PhantomData,
187 }
188 }
189
190 /// Get the sampling points
191 pub fn sampling_points(&self) -> &[f64] {
192 &self.sampling_points
193 }
194
195 /// Get the number of sampling points
196 pub fn n_sampling_points(&self) -> usize {
197 self.fitter.n_points()
198 }
199
200 /// Get the basis size
201 pub fn basis_size(&self) -> usize {
202 self.fitter.basis_size()
203 }
204
205 /// Get the sampling matrix
206 pub fn matrix(&self) -> &DTensor<f64, 2> {
207 &self.fitter.matrix
208 }
209
210 /// Evaluate basis coefficients at sampling points
211 ///
212 /// Computes g(τ_i) = Σ_l a_l * u_l(τ_i) for all sampling points
213 ///
214 /// # Arguments
215 /// * `coeffs` - Basis coefficients (length = basis_size)
216 ///
217 /// # Returns
218 /// Values at sampling points (length = n_sampling_points)
219 ///
220 /// # Panics
221 /// Panics if `coeffs.len() != basis_size`
222 pub fn evaluate(&self, coeffs: &[f64]) -> Vec<f64> {
223 self.fitter.evaluate(None, coeffs)
224 }
225
226 /// Internal generic evaluate_nd implementation
227 fn evaluate_nd_impl<T>(
228 &self,
229 backend: Option<&GemmBackendHandle>,
230 coeffs: &Tensor<T, DynRank>,
231 dim: usize,
232 ) -> Tensor<T, DynRank>
233 where
234 T: num_complex::ComplexFloat + faer_traits::ComplexField + 'static + From<f64> + Copy,
235 {
236 let rank = coeffs.rank();
237 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
238
239 let basis_size = self.basis_size();
240 let target_dim_size = coeffs.shape().dim(dim);
241
242 // Check that the target dimension matches basis_size
243 assert_eq!(
244 target_dim_size, basis_size,
245 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
246 dim, target_dim_size, basis_size
247 );
248
249 // 1. Move target dimension to position 0
250 let coeffs_dim0 = movedim(coeffs, dim, 0);
251
252 // 2. Reshape to 2D: (basis_size, extra_size)
253 let extra_size: usize = coeffs_dim0.len() / basis_size;
254
255 // Convert DynRank to fixed Rank<2> for matmul_par
256 let coeffs_2d_dyn = coeffs_dim0
257 .reshape(&[basis_size, extra_size][..])
258 .to_tensor();
259 let coeffs_2d = DTensor::<T, 2>::from_fn([basis_size, extra_size], |idx| {
260 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
261 });
262
263 // 3. Matrix multiply: result = A * coeffs
264 // A is real, convert to type T
265 let n_points = self.n_sampling_points();
266 let matrix_t = DTensor::<T, 2>::from_fn(*self.fitter.matrix.shape(), |idx| {
267 self.fitter.matrix[idx].into()
268 });
269 let result_2d = matmul_par(&matrix_t, &coeffs_2d, backend);
270
271 // 4. Reshape back to N-D with n_points at position 0
272 let mut result_shape = vec![n_points];
273 coeffs_dim0.shape().with_dims(|dims| {
274 for i in 1..dims.len() {
275 result_shape.push(dims[i]);
276 }
277 });
278
279 // Convert DTensor<T, 2> to DynRank using into_dyn()
280 let result_2d_dyn = result_2d.into_dyn();
281 let result_dim0 = result_2d_dyn.reshape(&result_shape[..]).to_tensor();
282
283 // 5. Move dimension back to original position
284 movedim(&result_dim0, 0, dim)
285 }
286
287 /// Evaluate basis coefficients at sampling points (N-dimensional)
288 ///
289 /// Evaluates along the specified dimension, keeping other dimensions intact.
290 /// Supports both real (`f64`) and complex (`Complex<f64>`) coefficients.
291 ///
292 /// # Type Parameters
293 /// * `T` - Element type (f64 or Complex<f64>)
294 ///
295 /// # Arguments
296 /// * `coeffs` - N-dimensional array with `coeffs.shape().dim(dim) == basis_size`
297 /// * `dim` - Dimension along which to evaluate (0-indexed)
298 ///
299 /// # Returns
300 /// N-dimensional array with `result.shape().dim(dim) == n_sampling_points`
301 ///
302 /// # Panics
303 /// Panics if `coeffs.shape().dim(dim) != basis_size` or if `dim >= rank`
304 ///
305 /// # Example
306 /// ```ignore
307 /// use num_complex::Complex;
308 /// use mdarray::tensor;
309 ///
310 /// // Real coefficients
311 /// let values_real = sampling.evaluate_nd::<f64>(&coeffs_real, 0);
312 ///
313 /// // Complex coefficients
314 /// let values_complex = sampling.evaluate_nd::<Complex<f64>>(&coeffs_complex, 0);
315 /// ```
316 pub fn evaluate_nd<T>(
317 &self,
318 backend: Option<&GemmBackendHandle>,
319 coeffs: &Tensor<T, DynRank>,
320 dim: usize,
321 ) -> Tensor<T, DynRank>
322 where
323 T: num_complex::ComplexFloat + faer_traits::ComplexField + 'static + From<f64> + Copy,
324 {
325 self.evaluate_nd_impl(backend, coeffs, dim)
326 }
327
328 /// Internal generic fit_nd implementation
329 ///
330 /// Delegates to fitter for real values, fits real/imaginary parts separately for complex values
331 fn fit_nd_impl<T>(
332 &self,
333 backend: Option<&GemmBackendHandle>,
334 values: &Tensor<T, DynRank>,
335 dim: usize,
336 ) -> Tensor<T, DynRank>
337 where
338 T: num_complex::ComplexFloat
339 + faer_traits::ComplexField
340 + 'static
341 + From<f64>
342 + Copy
343 + Default,
344 {
345 use num_complex::Complex;
346
347 let rank = values.rank();
348 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
349
350 let n_points = self.n_sampling_points();
351 let basis_size = self.basis_size();
352 let target_dim_size = values.shape().dim(dim);
353
354 assert_eq!(
355 target_dim_size, n_points,
356 "values.shape().dim({}) = {} must equal n_sampling_points = {}",
357 dim, target_dim_size, n_points
358 );
359
360 // 1. Move target dimension to position 0
361 let values_dim0 = movedim(values, dim, 0);
362
363 // 2. Reshape to 2D: (n_points, extra_size)
364 let extra_size: usize = values_dim0.len() / n_points;
365 let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
366
367 // 3. Convert to DTensor<T, 2> and fit using fitter's 2D methods
368 // Use type introspection to dispatch between real and complex
369 use std::any::TypeId;
370 let is_real = TypeId::of::<T>() == TypeId::of::<f64>();
371
372 let coeffs_2d = if is_real {
373 // Real case: convert to f64 tensor and fit
374 let values_2d_f64 = DTensor::<f64, 2>::from_fn([n_points, extra_size], |idx| unsafe {
375 *(&values_2d_dyn[&[idx[0], idx[1]][..]] as *const T as *const f64)
376 });
377 let coeffs_2d_f64 = self.fitter.fit_2d(backend, &values_2d_f64);
378 // Convert back to T
379 DTensor::<T, 2>::from_fn(*coeffs_2d_f64.shape(), |idx| unsafe {
380 *(&coeffs_2d_f64[idx] as *const f64 as *const T)
381 })
382 } else {
383 // Complex case: convert to Complex<f64> tensor and fit
384 let values_2d_c64 =
385 DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| unsafe {
386 *(&values_2d_dyn[&[idx[0], idx[1]][..]] as *const T as *const Complex<f64>)
387 });
388 let coeffs_2d_c64 = self.fitter.fit_complex_2d(backend, &values_2d_c64);
389 // Convert back to T
390 DTensor::<T, 2>::from_fn(*coeffs_2d_c64.shape(), |idx| unsafe {
391 *(&coeffs_2d_c64[idx] as *const Complex<f64> as *const T)
392 })
393 };
394
395 // 4. Reshape back to N-D with basis_size at position 0
396 let mut coeffs_shape = vec![basis_size];
397 values_dim0.shape().with_dims(|dims| {
398 for i in 1..dims.len() {
399 coeffs_shape.push(dims[i]);
400 }
401 });
402
403 let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
404
405 // 5. Move dimension 0 back to original position dim
406 movedim(&coeffs_dim0, 0, dim)
407 }
408
409 /// Fit basis coefficients from values at sampling points (N-dimensional)
410 ///
411 /// Fits along the specified dimension, keeping other dimensions intact.
412 /// Supports both real (`f64`) and complex (`Complex<f64>`) values.
413 ///
414 /// # Type Parameters
415 /// * `T` - Element type (f64 or Complex<f64>)
416 ///
417 /// # Arguments
418 /// * `values` - N-dimensional array with `values.shape().dim(dim) == n_sampling_points`
419 /// * `dim` - Dimension along which to fit (0-indexed)
420 ///
421 /// # Returns
422 /// N-dimensional array with `result.shape().dim(dim) == basis_size`
423 ///
424 /// # Panics
425 /// Panics if `values.shape().dim(dim) != n_sampling_points`, if `dim >= rank`, or if SVD not computed
426 ///
427 /// # Example
428 /// ```ignore
429 /// use num_complex::Complex;
430 /// use mdarray::tensor;
431 ///
432 /// // Real values
433 /// let coeffs_real = sampling.fit_nd::<f64>(&values_real, 0);
434 ///
435 /// // Complex values
436 /// let coeffs_complex = sampling.fit_nd::<Complex<f64>>(&values_complex, 0);
437 /// ```
438 pub fn fit_nd<T>(
439 &self,
440 backend: Option<&GemmBackendHandle>,
441 values: &Tensor<T, DynRank>,
442 dim: usize,
443 ) -> Tensor<T, DynRank>
444 where
445 T: num_complex::ComplexFloat
446 + faer_traits::ComplexField
447 + 'static
448 + From<f64>
449 + Copy
450 + Default,
451 {
452 self.fit_nd_impl(backend, values, dim)
453 }
454}
455
456#[cfg(test)]
457#[path = "tau_sampling_tests.rs"]
458mod tests;