1use crate::fpu_check::FpuGuard;
7use crate::gemm::GemmBackendHandle;
8use mdarray::{DTensor, DynRank, Shape, Slice, ViewMut};
9use num_complex::Complex;
10
11pub trait InplaceFitter {
34 fn n_points(&self) -> usize;
36
37 fn basis_size(&self) -> usize;
39
40 fn evaluate_nd_dd_to(
42 &self,
43 backend: Option<&GemmBackendHandle>,
44 coeffs: &Slice<f64, DynRank>,
45 dim: usize,
46 out: &mut ViewMut<'_, f64, DynRank>,
47 ) -> bool {
48 let _ = (backend, coeffs, dim, out);
49 false
50 }
51
52 fn evaluate_nd_dz_to(
54 &self,
55 backend: Option<&GemmBackendHandle>,
56 coeffs: &Slice<f64, DynRank>,
57 dim: usize,
58 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
59 ) -> bool {
60 let _ = (backend, coeffs, dim, out);
61 false
62 }
63
64 fn evaluate_nd_zd_to(
66 &self,
67 backend: Option<&GemmBackendHandle>,
68 coeffs: &Slice<Complex<f64>, DynRank>,
69 dim: usize,
70 out: &mut ViewMut<'_, f64, DynRank>,
71 ) -> bool {
72 let _ = (backend, coeffs, dim, out);
73 false
74 }
75
76 fn evaluate_nd_zz_to(
78 &self,
79 backend: Option<&GemmBackendHandle>,
80 coeffs: &Slice<Complex<f64>, DynRank>,
81 dim: usize,
82 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
83 ) -> bool {
84 let _ = (backend, coeffs, dim, out);
85 false
86 }
87
88 fn fit_nd_dd_to(
90 &self,
91 backend: Option<&GemmBackendHandle>,
92 values: &Slice<f64, DynRank>,
93 dim: usize,
94 out: &mut ViewMut<'_, f64, DynRank>,
95 ) -> bool {
96 let _ = (backend, values, dim, out);
97 false
98 }
99
100 fn fit_nd_dz_to(
102 &self,
103 backend: Option<&GemmBackendHandle>,
104 values: &Slice<f64, DynRank>,
105 dim: usize,
106 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
107 ) -> bool {
108 let _ = (backend, values, dim, out);
109 false
110 }
111
112 fn fit_nd_zd_to(
114 &self,
115 backend: Option<&GemmBackendHandle>,
116 values: &Slice<Complex<f64>, DynRank>,
117 dim: usize,
118 out: &mut ViewMut<'_, f64, DynRank>,
119 ) -> bool {
120 let _ = (backend, values, dim, out);
121 false
122 }
123
124 fn fit_nd_zz_to(
126 &self,
127 backend: Option<&GemmBackendHandle>,
128 values: &Slice<Complex<f64>, DynRank>,
129 dim: usize,
130 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
131 ) -> bool {
132 let _ = (backend, values, dim, out);
133 false
134 }
135}
136
/// Builds an axis permutation that moves axis `dim` to the front while keeping
/// the remaining axes in their original relative order.
///
/// For example, `make_perm_to_front(4, 2)` returns `[2, 0, 1, 3]`.
///
/// # Panics
///
/// Panics if `dim >= rank`: the result would otherwise contain `rank + 1`
/// entries and not be a valid permutation of `0..rank`.
pub(crate) fn make_perm_to_front(rank: usize, dim: usize) -> Vec<usize> {
    assert!(
        dim < rank,
        "dim={} is out of range for rank={}",
        dim,
        rank
    );
    std::iter::once(dim)
        .chain((0..rank).filter(|&axis| axis != dim))
        .collect()
}
155
156pub(crate) fn copy_from_contiguous<T: Copy>(
168 src: &[T],
169 dst: &mut mdarray::Slice<T, mdarray::DynRank, mdarray::Strided>,
170) {
171 assert_eq!(src.len(), dst.len(), "Source size mismatch");
172
173 for (d, s) in dst.iter_mut().zip(src.iter()) {
175 *d = *s;
176 }
177}
178
179#[allow(dead_code)]
188pub(crate) fn complex_slice_mut_as_real<'a>(
189 out: &'a mut Slice<Complex<f64>, DynRank>,
190) -> mdarray::ViewMut<'a, f64, DynRank, mdarray::Dense> {
191 let mut new_shape: Vec<usize> = Vec::with_capacity(out.rank() + 1);
193 out.shape().with_dims(|dims| {
194 for d in dims {
195 new_shape.push(*d);
196 }
197 });
198 new_shape.push(2);
199
200 unsafe {
201 let shape: DynRank = Shape::from_dims(&new_shape[..]);
202 let mapping = mdarray::DenseMapping::new(shape);
203 mdarray::ViewMut::new_unchecked(out.as_mut_ptr() as *mut f64, mapping)
204 }
205}
206
207pub(crate) struct RealSVD {
213 pub ut: DTensor<f64, 2>, pub s: Vec<f64>, pub v: DTensor<f64, 2>, }
217
218impl RealSVD {
219 pub fn new(u: DTensor<f64, 2>, s: Vec<f64>, vt: DTensor<f64, 2>) -> Self {
220 let (_, u_cols) = *u.shape();
222 let (vt_rows, _) = *vt.shape();
223 let min_dim = s.len();
224
225 assert_eq!(
226 u_cols, min_dim,
227 "u.cols()={} must equal s.len()={}",
228 u_cols, min_dim
229 );
230 assert_eq!(
231 vt_rows, min_dim,
232 "vt.rows()={} must equal s.len()={}",
233 vt_rows, min_dim
234 );
235
236 let ut = u.transpose().to_tensor(); let v = vt.transpose().to_tensor(); assert_eq!(
242 v.shape().1,
243 min_dim,
244 "v.cols()={} must equal s.len()={}",
245 v.shape().1,
246 min_dim
247 );
248
249 Self { ut, s, v }
250 }
251}
252
253pub(crate) struct ComplexSVD {
255 pub ut: DTensor<Complex<f64>, 2>, pub s: Vec<f64>, pub v: DTensor<Complex<f64>, 2>, }
259
260impl ComplexSVD {
261 pub fn new(u: DTensor<Complex<f64>, 2>, s: Vec<f64>, vt: DTensor<Complex<f64>, 2>) -> Self {
262 let (u_rows, u_cols) = *u.shape();
264 let (vt_rows, _) = *vt.shape();
265 let min_dim = s.len();
266
267 assert_eq!(
268 u_cols, min_dim,
269 "u.cols()={} must equal s.len()={}",
270 u_cols, min_dim
271 );
272 assert_eq!(
273 vt_rows, min_dim,
274 "vt.rows()={} must equal s.len()={}",
275 vt_rows, min_dim
276 );
277
278 let ut = DTensor::<Complex<f64>, 2>::from_fn([u_cols, u_rows], |idx| {
280 u[[idx[1], idx[0]]].conj() });
282 let v = vt.transpose().to_tensor(); assert_eq!(
286 v.shape().1,
287 min_dim,
288 "v.cols()={} must equal s.len()={}",
289 v.shape().1,
290 min_dim
291 );
292
293 Self { ut, s, v }
294 }
295}
296
297pub(crate) fn compute_real_svd(matrix: &DTensor<f64, 2>) -> RealSVD {
303 use mdarray_linalg::prelude::SVD;
304 use mdarray_linalg::svd::SVDDecomp;
305 use mdarray_linalg_faer::Faer;
306
307 let _guard = FpuGuard::new_protect_computation();
309
310 let mut a = matrix.clone();
311 let SVDDecomp { u, s, vt } = Faer.svd(&mut *a).expect("SVD computation failed");
312
313 let min_dim = s.shape().0.min(s.shape().1);
315 let s_vec: Vec<f64> = (0..min_dim).map(|i| s[[0, i]]).collect();
316
317 let u_trimmed = u.view(.., ..min_dim).to_tensor();
321 let vt_trimmed = vt.view(..min_dim, ..).to_tensor();
322
323 RealSVD::new(u_trimmed, s_vec, vt_trimmed)
324}
325
326pub(crate) fn compute_complex_svd(matrix: &DTensor<Complex<f64>, 2>) -> ComplexSVD {
328 use mdarray_linalg::prelude::SVD;
329 use mdarray_linalg::svd::SVDDecomp;
330 use mdarray_linalg_faer::Faer;
331
332 let _guard = FpuGuard::new_protect_computation();
334
335 let mut matrix_c64 = matrix.clone();
337
338 let SVDDecomp { u, s, vt } = Faer
340 .svd(&mut *matrix_c64)
341 .expect("Complex SVD computation failed");
342
343 let min_dim = s.shape().0.min(s.shape().1);
345 let s_vec: Vec<f64> = (0..min_dim).map(|i| s[[0, i]].re).collect();
346
347 let u_trimmed = u.view(.., ..min_dim).to_tensor();
351 let vt_trimmed = vt.view(..min_dim, ..).to_tensor();
352
353 ComplexSVD::new(u_trimmed, s_vec, vt_trimmed)
354}
355
356pub(crate) fn combine_complex(
362 re: &DTensor<f64, 2>,
363 im: &DTensor<f64, 2>,
364) -> DTensor<Complex<f64>, 2> {
365 let (n_points, extra_size) = *re.shape();
366 DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
367 Complex::new(re[idx], im[idx])
368 })
369}
370
371pub(crate) fn extract_real_parts_coeffs(coeffs_2d: &DTensor<Complex<f64>, 2>) -> DTensor<f64, 2> {
373 let (basis_size, extra_size) = *coeffs_2d.shape();
374 DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| coeffs_2d[idx].re)
375}