sparse_ir_capi/
sampling.rs

//! Sampling API for C
//!
//! This module provides the C API for sparse sampling in imaginary time (τ),
//! Matsubara frequency (iωn), and real frequency (ω) domains.
//!
//! Functions:
//! - Creation: spir_tau_sampling_new, spir_matsu_sampling_new, ...
//! - Introspection: get_npoints, get_taus, get_matsus, get_cond_num
//! - Evaluation: eval_dd, eval_dz, eval_zz (coefficients → sampling points)
//! - Fitting: fit_dd, fit_zz, fit_zd (sampling points → coefficients)
//! - Memory: release, clone, is_assigned (hand-written; they replace the former macro-generated versions)
12
13use mdarray::Shape;
14use num_complex::Complex64;
15use std::panic::{AssertUnwindSafe, catch_unwind};
16use std::sync::Arc;
17
18use crate::gemm::{get_backend_handle, spir_gemm_backend};
19use crate::types::{BasisType, SamplingType, spir_basis, spir_sampling};
20use crate::utils::{
21    MemoryOrder, build_output_dims, convert_dims_for_row_major, create_dview_from_ptr,
22    create_dviewmut_from_ptr, read_tensor_nd,
23};
24use crate::{
25    SPIR_COMPUTATION_SUCCESS, SPIR_INVALID_ARGUMENT, SPIR_NOT_SUPPORTED, SPIR_STATISTICS_BOSONIC,
26    SPIR_STATISTICS_FERMIONIC, StatusCode,
27};
28use sparse_ir::fitters::InplaceFitter;
29use sparse_ir::{Bosonic, Fermionic};
30
31/// Manual release function (replaces macro-generated one)
32#[unsafe(no_mangle)]
33pub extern "C" fn spir_sampling_release(sampling: *mut spir_sampling) {
34    if !sampling.is_null() {
35        unsafe {
36            let _ = Box::from_raw(sampling);
37        }
38    }
39}
40
41/// Manual clone function (replaces macro-generated one)
42#[unsafe(no_mangle)]
43pub extern "C" fn spir_sampling_clone(src: *const spir_sampling) -> *mut spir_sampling {
44    if src.is_null() {
45        return std::ptr::null_mut();
46    }
47
48    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
49        let src_ref = &*src;
50        let cloned = (*src_ref).clone();
51        Box::into_raw(Box::new(cloned))
52    }));
53
54    result.unwrap_or(std::ptr::null_mut())
55}
56
57/// Manual is_assigned function (replaces macro-generated one)
58#[unsafe(no_mangle)]
59pub extern "C" fn spir_sampling_is_assigned(obj: *const spir_sampling) -> i32 {
60    if obj.is_null() {
61        return 0;
62    }
63
64    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
65        let _ = &*obj;
66        1
67    }));
68
69    result.unwrap_or(0)
70}
71
72// ============================================================================
73// Creation Functions
74// ============================================================================
75
76/// Creates a new tau sampling object for sparse sampling in imaginary time
77///
78/// # Arguments
79/// * `b` - Pointer to a finite temperature basis object
80/// * `num_points` - Number of sampling points
81/// * `points` - Array of sampling points in imaginary time (τ)
82/// * `status` - Pointer to store the status code
83///
84/// # Returns
85/// Pointer to the newly created sampling object, or NULL if creation fails
86///
87/// # Safety
88/// Caller must ensure `b` is valid and `points` has `num_points` elements
89#[unsafe(no_mangle)]
90pub extern "C" fn spir_tau_sampling_new(
91    b: *const spir_basis,
92    num_points: libc::c_int,
93    points: *const f64,
94    status: *mut StatusCode,
95) -> *mut spir_sampling {
96    let result = catch_unwind(AssertUnwindSafe(|| {
97        // Validate inputs
98        if b.is_null() || points.is_null() {
99            return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
100        }
101        if num_points <= 0 {
102            return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
103        }
104
105        let basis_ref = unsafe { &*b };
106        let points_slice = unsafe { std::slice::from_raw_parts(points, num_points as usize) };
107
108        // Convert points to Vec
109        let tau_points: Vec<f64> = points_slice.to_vec();
110
111        // Create sampling based on basis statistics
112        let sampling_type = match basis_ref.inner() {
113            BasisType::LogisticFermionic(ir_basis) => {
114                let tau_sampling = sparse_ir::sampling::TauSampling::with_sampling_points(
115                    ir_basis.as_ref(),
116                    tau_points,
117                );
118                SamplingType::TauFermionic(Arc::new(tau_sampling))
119            }
120            BasisType::RegularizedBoseFermionic(ir_basis) => {
121                let tau_sampling = sparse_ir::sampling::TauSampling::with_sampling_points(
122                    ir_basis.as_ref(),
123                    tau_points,
124                );
125                SamplingType::TauFermionic(Arc::new(tau_sampling))
126            }
127            BasisType::LogisticBosonic(ir_basis) => {
128                let tau_sampling = sparse_ir::sampling::TauSampling::with_sampling_points(
129                    ir_basis.as_ref(),
130                    tau_points,
131                );
132                SamplingType::TauBosonic(Arc::new(tau_sampling))
133            }
134            BasisType::RegularizedBoseBosonic(ir_basis) => {
135                let tau_sampling = sparse_ir::sampling::TauSampling::with_sampling_points(
136                    ir_basis.as_ref(),
137                    tau_points,
138                );
139                SamplingType::TauBosonic(Arc::new(tau_sampling))
140            }
141            // DLR: tau sampling supported via Basis trait
142            BasisType::DLRFermionic(dlr) => {
143                let tau_sampling = sparse_ir::sampling::TauSampling::with_sampling_points(
144                    dlr.as_ref(),
145                    tau_points,
146                );
147                SamplingType::TauFermionic(Arc::new(tau_sampling))
148            }
149            BasisType::DLRBosonic(dlr) => {
150                let tau_sampling = sparse_ir::sampling::TauSampling::with_sampling_points(
151                    dlr.as_ref(),
152                    tau_points,
153                );
154                SamplingType::TauBosonic(Arc::new(tau_sampling))
155            }
156        };
157
158        let inner = sampling_type;
159        let sampling = spir_sampling {
160            _private: Box::into_raw(Box::new(inner)) as *mut std::ffi::c_void,
161        };
162
163        (Box::into_raw(Box::new(sampling)), SPIR_COMPUTATION_SUCCESS)
164    }));
165
166    match result {
167        Ok((ptr, code)) => {
168            if !status.is_null() {
169                unsafe {
170                    *status = code;
171                }
172            }
173            ptr
174        }
175        Err(_) => {
176            if !status.is_null() {
177                unsafe {
178                    *status = crate::SPIR_INTERNAL_ERROR;
179                }
180            }
181            std::ptr::null_mut()
182        }
183    }
184}
185
186/// Creates a new Matsubara sampling object for sparse sampling in Matsubara frequencies
187///
188/// # Arguments
189/// * `b` - Pointer to a finite temperature basis object
190/// * `positive_only` - If true, only positive frequencies are used
191/// * `num_points` - Number of sampling points
192/// * `points` - Array of Matsubara frequency indices (n)
193/// * `status` - Pointer to store the status code
194///
195/// # Returns
196/// Pointer to the newly created sampling object, or NULL if creation fails
197#[unsafe(no_mangle)]
198pub extern "C" fn spir_matsu_sampling_new(
199    b: *const spir_basis,
200    positive_only: bool,
201    num_points: libc::c_int,
202    points: *const i64,
203    status: *mut StatusCode,
204) -> *mut spir_sampling {
205    let result = catch_unwind(AssertUnwindSafe(|| {
206        // Validate inputs
207        if b.is_null() || points.is_null() {
208            return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
209        }
210        if num_points <= 0 {
211            return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
212        }
213
214        let basis_ref = unsafe { &*b };
215        let points_slice = unsafe { std::slice::from_raw_parts(points, num_points as usize) };
216
217        // Convert points to Vec
218        let matsu_points: Vec<i64> = points_slice.to_vec();
219
220        // Convert i64 indices to MatsubaraFreq
221        use sparse_ir::freq::MatsubaraFreq;
222
223        // Helper macro to reduce duplication
224        macro_rules! create_matsu_sampling {
225            ($basis:expr, Fermionic) => {
226                if positive_only {
227                    let matsu_freqs: Vec<MatsubaraFreq<Fermionic>> = matsu_points
228                        .iter()
229                        .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
230                        .collect();
231                    let matsu_sampling = sparse_ir::matsubara_sampling::MatsubaraSamplingPositiveOnly::with_sampling_points(
232                        $basis,
233                        matsu_freqs,
234                    );
235                    SamplingType::MatsubaraPositiveOnlyFermionic(Arc::new(matsu_sampling))
236                } else {
237                    let matsu_freqs: Vec<MatsubaraFreq<Fermionic>> = matsu_points
238                        .iter()
239                        .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
240                        .collect();
241                    let matsu_sampling = sparse_ir::matsubara_sampling::MatsubaraSampling::with_sampling_points(
242                        $basis,
243                        matsu_freqs,
244                    );
245                    SamplingType::MatsubaraFermionic(Arc::new(matsu_sampling))
246                }
247            };
248            ($basis:expr, Bosonic) => {
249                if positive_only {
250                    let matsu_freqs: Vec<MatsubaraFreq<Bosonic>> = matsu_points
251                        .iter()
252                        .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
253                        .collect();
254                    let matsu_sampling = sparse_ir::matsubara_sampling::MatsubaraSamplingPositiveOnly::with_sampling_points(
255                        $basis,
256                        matsu_freqs,
257                    );
258                    SamplingType::MatsubaraPositiveOnlyBosonic(Arc::new(matsu_sampling))
259                } else {
260                    let matsu_freqs: Vec<MatsubaraFreq<Bosonic>> = matsu_points
261                        .iter()
262                        .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
263                        .collect();
264                    let matsu_sampling = sparse_ir::matsubara_sampling::MatsubaraSampling::with_sampling_points(
265                        $basis,
266                        matsu_freqs,
267                    );
268                    SamplingType::MatsubaraBosonic(Arc::new(matsu_sampling))
269                }
270            };
271        }
272
273        // Create sampling based on basis statistics and positive_only flag
274        let sampling_type = match basis_ref.inner() {
275            BasisType::LogisticFermionic(ir_basis) => {
276                create_matsu_sampling!(ir_basis.as_ref(), Fermionic)
277            }
278            BasisType::RegularizedBoseFermionic(ir_basis) => {
279                create_matsu_sampling!(ir_basis.as_ref(), Fermionic)
280            }
281            BasisType::LogisticBosonic(ir_basis) => {
282                create_matsu_sampling!(ir_basis.as_ref(), Bosonic)
283            }
284            BasisType::RegularizedBoseBosonic(ir_basis) => {
285                create_matsu_sampling!(ir_basis.as_ref(), Bosonic)
286            }
287            // DLR: Matsubara sampling supported via Basis trait
288            BasisType::DLRFermionic(dlr) => {
289                create_matsu_sampling!(dlr.as_ref(), Fermionic)
290            }
291            BasisType::DLRBosonic(dlr) => {
292                create_matsu_sampling!(dlr.as_ref(), Bosonic)
293            }
294        };
295
296        let inner = sampling_type;
297        let sampling = spir_sampling {
298            _private: Box::into_raw(Box::new(inner)) as *mut std::ffi::c_void,
299        };
300
301        (Box::into_raw(Box::new(sampling)), SPIR_COMPUTATION_SUCCESS)
302    }));
303
304    match result {
305        Ok((ptr, code)) => {
306            if !status.is_null() {
307                unsafe {
308                    *status = code;
309                }
310            }
311            ptr
312        }
313        Err(_) => {
314            if !status.is_null() {
315                unsafe {
316                    *status = crate::SPIR_INTERNAL_ERROR;
317                }
318            }
319            std::ptr::null_mut()
320        }
321    }
322}
323
324/// Creates a new tau sampling object with custom sampling points and pre-computed matrix
325///
326/// # Arguments
327/// * `order` - Memory layout order (SPIR_ORDER_ROW_MAJOR or SPIR_ORDER_COLUMN_MAJOR)
328/// * `statistics` - Statistics type (SPIR_STATISTICS_FERMIONIC or SPIR_STATISTICS_BOSONIC)
329/// * `basis_size` - Basis size
330/// * `num_points` - Number of sampling points
331/// * `points` - Array of sampling points in imaginary time (τ)
332/// * `matrix` - Pre-computed matrix for the sampling points (num_points x basis_size)
333/// * `status` - Pointer to store the status code
334///
335/// # Returns
336/// Pointer to the newly created sampling object, or NULL if creation fails
337///
338/// # Safety
339/// Caller must ensure `points` and `matrix` have correct sizes
340#[unsafe(no_mangle)]
341pub extern "C" fn spir_tau_sampling_new_with_matrix(
342    order: libc::c_int,
343    statistics: libc::c_int,
344    basis_size: libc::c_int,
345    num_points: libc::c_int,
346    points: *const f64,
347    matrix: *const f64,
348    status: *mut StatusCode,
349) -> *mut spir_sampling {
350    let result = catch_unwind(AssertUnwindSafe(|| {
351        // Validate inputs
352        if points.is_null() || matrix.is_null() {
353            return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
354        }
355        if num_points <= 0 || basis_size <= 0 {
356            return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
357        }
358
359        // Parse order
360        let mem_order = match MemoryOrder::from_c_int(order) {
361            Ok(o) => o,
362            Err(_) => return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT),
363        };
364
365        // Convert points to Vec
366        let points_slice = unsafe { std::slice::from_raw_parts(points, num_points as usize) };
367        let tau_points: Vec<f64> = points_slice.to_vec();
368
369        // Convert matrix to Tensor using the new helper function
370        let orig_dims = [num_points as usize, basis_size as usize];
371        let dyn_tensor = unsafe { read_tensor_nd(matrix, &orig_dims, mem_order) };
372
373        // Convert DynRank to fixed 2D shape using from_fn (safe conversion)
374        let shape_dims = dyn_tensor.shape().with_dims(|dims| dims.to_vec());
375        assert_eq!(
376            shape_dims.len(),
377            2,
378            "Expected 2D tensor, got {}D",
379            shape_dims.len()
380        );
381        let num_points_actual = shape_dims[0];
382        let basis_size_actual = shape_dims[1];
383        let matrix_tensor =
384            sparse_ir::DTensor::<f64, 2>::from_fn([num_points_actual, basis_size_actual], |idx| {
385                dyn_tensor[&[idx[0], idx[1]][..]]
386            });
387        // Create sampling based on statistics
388        let sampling_type = match statistics {
389            SPIR_STATISTICS_FERMIONIC => {
390                // SPIR_STATISTICS_FERMIONIC
391                let tau_sampling = sparse_ir::sampling::TauSampling::<Fermionic>::from_matrix(
392                    tau_points,
393                    matrix_tensor,
394                );
395                SamplingType::TauFermionic(Arc::new(tau_sampling))
396            }
397            SPIR_STATISTICS_BOSONIC => {
398                // SPIR_STATISTICS_BOSONIC
399                let tau_sampling = sparse_ir::sampling::TauSampling::<Bosonic>::from_matrix(
400                    tau_points,
401                    matrix_tensor,
402                );
403                SamplingType::TauBosonic(Arc::new(tau_sampling))
404            }
405            _ => return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT),
406        };
407
408        let inner = sampling_type;
409        let sampling = spir_sampling {
410            _private: Box::into_raw(Box::new(inner)) as *mut std::ffi::c_void,
411        };
412
413        (Box::into_raw(Box::new(sampling)), SPIR_COMPUTATION_SUCCESS)
414    }));
415
416    match result {
417        Ok((ptr, code)) => {
418            if !status.is_null() {
419                unsafe {
420                    *status = code;
421                }
422            }
423            ptr
424        }
425        Err(_) => {
426            if !status.is_null() {
427                unsafe {
428                    *status = crate::SPIR_INTERNAL_ERROR;
429                }
430            }
431            std::ptr::null_mut()
432        }
433    }
434}
435
436/// Creates a new Matsubara sampling object with custom sampling points and pre-computed matrix
437///
438/// # Arguments
439/// * `order` - Memory layout order (SPIR_ORDER_ROW_MAJOR or SPIR_ORDER_COLUMN_MAJOR)
440/// * `statistics` - Statistics type (SPIR_STATISTICS_FERMIONIC or SPIR_STATISTICS_BOSONIC)
441/// * `basis_size` - Basis size
442/// * `positive_only` - If true, only positive frequencies are used
443/// * `num_points` - Number of sampling points
444/// * `points` - Array of Matsubara frequency indices (n)
445/// * `matrix` - Pre-computed complex matrix (num_points x basis_size)
446/// * `status` - Pointer to store the status code
447///
448/// # Returns
449/// Pointer to the newly created sampling object, or NULL if creation fails
450///
451/// # Safety
452/// Caller must ensure `points` and `matrix` have correct sizes
453#[unsafe(no_mangle)]
454pub extern "C" fn spir_matsu_sampling_new_with_matrix(
455    order: libc::c_int,
456    statistics: libc::c_int,
457    basis_size: libc::c_int,
458    positive_only: bool,
459    num_points: libc::c_int,
460    points: *const i64,
461    matrix: *const Complex64,
462    status: *mut StatusCode,
463) -> *mut spir_sampling {
464    use std::io::Write;
465    debug_println!(
466        "spir_matsu_sampling_new_with_matrix: start, order={}, statistics={}, basis_size={}, positive_only={}, num_points={}",
467        order,
468        statistics,
469        basis_size,
470        positive_only,
471        num_points
472    );
473    std::io::stderr().flush().ok();
474    let result = catch_unwind(AssertUnwindSafe(|| {
475        use std::io::Write;
476        debug_println!("spir_matsu_sampling_new_with_matrix: inside catch_unwind");
477        std::io::stderr().flush().ok();
478        // Validate inputs
479        if points.is_null() || matrix.is_null() {
480            debug_eprintln!("spir_matsu_sampling_new_with_matrix: null pointer");
481            return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
482        }
483        if num_points <= 0 || basis_size <= 0 {
484            debug_eprintln!(
485                "spir_matsu_sampling_new_with_matrix: invalid size, num_points={}, basis_size={}",
486                num_points,
487                basis_size
488            );
489            return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
490        }
491        debug_println!("spir_matsu_sampling_new_with_matrix: input validation passed");
492        std::io::stderr().flush().ok();
493
494        // Parse order
495        let mem_order = match MemoryOrder::from_c_int(order) {
496            Ok(o) => o,
497            Err(_) => return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT),
498        };
499
500        // Convert points to Vec<MatsubaraFreq>
501        debug_println!("spir_matsu_sampling_new_with_matrix: creating points slice...");
502        std::io::stderr().flush().ok();
503        let points_slice = unsafe { std::slice::from_raw_parts(points, num_points as usize) };
504        debug_println!(
505            "spir_matsu_sampling_new_with_matrix: points slice created, len = {}",
506            points_slice.len()
507        );
508        std::io::stderr().flush().ok();
509        let matsu_points: Vec<i64> = points_slice.to_vec();
510        debug_println!(
511            "spir_matsu_sampling_new_with_matrix: matsu_points created, len = {}",
512            matsu_points.len()
513        );
514        std::io::stderr().flush().ok();
515
516        use sparse_ir::freq::MatsubaraFreq;
517
518        // Convert matrix to Tensor using the new helper function
519        let orig_dims = [num_points as usize, basis_size as usize];
520        debug_println!(
521            "spir_matsu_sampling_new_with_matrix: orig_dims = {:?}, mem_order = {:?}",
522            orig_dims,
523            mem_order
524        );
525        std::io::stderr().flush().ok();
526
527        debug_println!("spir_matsu_sampling_new_with_matrix: reading tensor from buffer...");
528        std::io::stderr().flush().ok();
529        let dyn_tensor = unsafe { read_tensor_nd(matrix, &orig_dims, mem_order) };
530        let shape_dims = dyn_tensor.shape().with_dims(|dims| dims.to_vec());
531        debug_println!(
532            "spir_matsu_sampling_new_with_matrix: dyn_tensor created, shape = {:?}",
533            shape_dims
534        );
535        std::io::stderr().flush().ok();
536
537        // Convert DynRank to fixed 2D shape using from_fn (safe conversion)
538        debug_println!("spir_matsu_sampling_new_with_matrix: converting to fixed 2D tensor...");
539        std::io::stderr().flush().ok();
540        assert_eq!(
541            shape_dims.len(),
542            2,
543            "Expected 2D tensor, got {}D",
544            shape_dims.len()
545        );
546        let num_points_actual = shape_dims[0];
547        let basis_size_actual = shape_dims[1];
548        debug_println!(
549            "spir_matsu_sampling_new_with_matrix: converting from shape {:?} to DTensor<Complex64, 2>",
550            shape_dims
551        );
552        std::io::stderr().flush().ok();
553        let matrix_tensor = sparse_ir::DTensor::<Complex64, 2>::from_fn(
554            [num_points_actual, basis_size_actual],
555            |idx| dyn_tensor[&[idx[0], idx[1]][..]],
556        );
557        debug_println!(
558            "spir_matsu_sampling_new_with_matrix: matrix_tensor created, shape = {:?}",
559            matrix_tensor.shape()
560        );
561        std::io::stderr().flush().ok();
562
563        // Create sampling based on statistics and positive_only
564        debug_println!(
565            "spir_matsu_sampling_new_with_matrix: creating sampling, statistics={}, positive_only={}",
566            statistics,
567            positive_only
568        );
569        std::io::stderr().flush().ok();
570        let sampling_type = match (statistics, positive_only) {
571            (SPIR_STATISTICS_FERMIONIC, true) => {
572                debug_println!("spir_matsu_sampling_new_with_matrix: Fermionic, positive-only");
573                std::io::stderr().flush().ok();
574                // Fermionic, positive-only
575                let matsu_freqs: Vec<MatsubaraFreq<Fermionic>> = matsu_points
576                    .iter()
577                    .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
578                    .collect();
579                debug_println!(
580                    "spir_matsu_sampling_new_with_matrix: matsu_freqs created, len = {}",
581                    matsu_freqs.len()
582                );
583                std::io::stderr().flush().ok();
584                debug_println!("spir_matsu_sampling_new_with_matrix: calling from_matrix...");
585                std::io::stderr().flush().ok();
586                let matsu_sampling =
587                    sparse_ir::matsubara_sampling::MatsubaraSamplingPositiveOnly::from_matrix(
588                        matsu_freqs,
589                        matrix_tensor.clone(),
590                    );
591                debug_println!("spir_matsu_sampling_new_with_matrix: from_matrix returned");
592                std::io::stderr().flush().ok();
593                SamplingType::MatsubaraPositiveOnlyFermionic(Arc::new(matsu_sampling))
594            }
595            (SPIR_STATISTICS_FERMIONIC, false) => {
596                debug_println!("spir_matsu_sampling_new_with_matrix: Fermionic, full range");
597                std::io::stderr().flush().ok();
598                // Fermionic, full range
599                let matsu_freqs: Vec<MatsubaraFreq<Fermionic>> = matsu_points
600                    .iter()
601                    .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
602                    .collect();
603                debug_println!(
604                    "spir_matsu_sampling_new_with_matrix: matsu_freqs created, len = {}",
605                    matsu_freqs.len()
606                );
607                std::io::stderr().flush().ok();
608                debug_println!("spir_matsu_sampling_new_with_matrix: calling from_matrix...");
609                std::io::stderr().flush().ok();
610                let matsu_sampling = sparse_ir::matsubara_sampling::MatsubaraSampling::from_matrix(
611                    matsu_freqs,
612                    matrix_tensor.clone(),
613                );
614                debug_println!("spir_matsu_sampling_new_with_matrix: from_matrix returned");
615                std::io::stderr().flush().ok();
616                SamplingType::MatsubaraFermionic(Arc::new(matsu_sampling))
617            }
618            (SPIR_STATISTICS_BOSONIC, true) => {
619                // Bosonic, positive-only
620                let matsu_freqs: Vec<MatsubaraFreq<Bosonic>> = matsu_points
621                    .iter()
622                    .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
623                    .collect();
624                let matsu_sampling =
625                    sparse_ir::matsubara_sampling::MatsubaraSamplingPositiveOnly::from_matrix(
626                        matsu_freqs,
627                        matrix_tensor.clone(),
628                    );
629                SamplingType::MatsubaraPositiveOnlyBosonic(Arc::new(matsu_sampling))
630            }
631            (SPIR_STATISTICS_BOSONIC, false) => {
632                // Bosonic, full range
633                let matsu_freqs: Vec<MatsubaraFreq<Bosonic>> = matsu_points
634                    .iter()
635                    .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
636                    .collect();
637                let matsu_sampling = sparse_ir::matsubara_sampling::MatsubaraSampling::from_matrix(
638                    matsu_freqs,
639                    matrix_tensor.clone(),
640                );
641                SamplingType::MatsubaraBosonic(Arc::new(matsu_sampling))
642            }
643            _ => return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT),
644        };
645
646        let inner = sampling_type;
647        let sampling = spir_sampling {
648            _private: Box::into_raw(Box::new(inner)) as *mut std::ffi::c_void,
649        };
650
651        (Box::into_raw(Box::new(sampling)), SPIR_COMPUTATION_SUCCESS)
652    }));
653
654    match result {
655        Ok((ptr, code)) => {
656            if !status.is_null() {
657                unsafe {
658                    *status = code;
659                }
660            }
661            ptr
662        }
663        Err(_) => {
664            if !status.is_null() {
665                unsafe {
666                    *status = crate::SPIR_INTERNAL_ERROR;
667                }
668            }
669            std::ptr::null_mut()
670        }
671    }
672}
673
674// ============================================================================
675// Introspection Functions
676// ============================================================================
677
678/// Gets the number of sampling points in a sampling object.
679///
680/// This function returns the number of sampling points used in the specified
681/// sampling object. This number is needed to allocate arrays of the correct size
682/// when retrieving the actual sampling points.
683///
684/// # Arguments
685///
686/// * `s` - Pointer to the sampling object.
687/// * `num_points` - Pointer to store the number of sampling points.
688///
689/// # Returns
690///
691/// A status code:
692/// - `0` ([`SPIR_COMPUTATION_SUCCESS`]) on success
693/// - A non-zero error code on failure
694///
695/// # See also
696///
697/// - [`spir_sampling_get_taus`]
698/// - [`spir_sampling_get_matsus`]
699#[unsafe(no_mangle)]
700pub extern "C" fn spir_sampling_get_npoints(
701    s: *const spir_sampling,
702    num_points: *mut libc::c_int,
703) -> StatusCode {
704    let result = catch_unwind(AssertUnwindSafe(|| {
705        if s.is_null() || num_points.is_null() {
706            return SPIR_INVALID_ARGUMENT;
707        }
708
709        let sampling_ref = unsafe { &*s };
710
711        let n_points = match sampling_ref.inner() {
712            SamplingType::TauFermionic(tau) => tau.n_sampling_points(),
713            SamplingType::TauBosonic(tau) => tau.n_sampling_points(),
714            SamplingType::MatsubaraFermionic(matsu) => matsu.n_sampling_points(),
715            SamplingType::MatsubaraBosonic(matsu) => matsu.n_sampling_points(),
716            SamplingType::MatsubaraPositiveOnlyFermionic(matsu) => matsu.n_sampling_points(),
717            SamplingType::MatsubaraPositiveOnlyBosonic(matsu) => matsu.n_sampling_points(),
718        };
719
720        unsafe {
721            *num_points = n_points as libc::c_int;
722        }
723        SPIR_COMPUTATION_SUCCESS
724    }));
725
726    result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
727}
728
729/// Gets the imaginary time (τ) sampling points used in the specified sampling object.
730///
731/// This function fills the provided array with the imaginary time (τ) sampling points used in the specified sampling object.
732/// The array must be pre-allocated with sufficient size (use [`spir_sampling_get_npoints`] to determine the required size).
733///
734/// # Arguments
735///
736/// * `s` - Pointer to the sampling object.
737/// * `points` - Pre-allocated array to store the τ sampling points.
738///
739/// # Returns
740///
741/// An integer status code:
742/// - `0` ([`SPIR_COMPUTATION_SUCCESS`]) on success
743/// - A non-zero error code on failure
744///
745/// # Notes
746///
747/// The array must be pre-allocated with size >= [`spir_sampling_get_npoints`](spir_sampling_get_npoints).
748///
749/// # See also
750///
751/// - [`spir_sampling_get_npoints`]
752#[unsafe(no_mangle)]
753pub extern "C" fn spir_sampling_get_taus(s: *const spir_sampling, points: *mut f64) -> StatusCode {
754    let result = catch_unwind(AssertUnwindSafe(|| {
755        if s.is_null() || points.is_null() {
756            return SPIR_INVALID_ARGUMENT;
757        }
758
759        let sampling_ref = unsafe { &*s };
760
761        match sampling_ref.inner() {
762            SamplingType::TauFermionic(tau) => {
763                let tau_points = tau.sampling_points();
764                let out_slice = unsafe { std::slice::from_raw_parts_mut(points, tau_points.len()) };
765                out_slice.copy_from_slice(tau_points);
766                SPIR_COMPUTATION_SUCCESS
767            }
768            SamplingType::TauBosonic(tau) => {
769                let tau_points = tau.sampling_points();
770                let out_slice = unsafe { std::slice::from_raw_parts_mut(points, tau_points.len()) };
771                out_slice.copy_from_slice(tau_points);
772                SPIR_COMPUTATION_SUCCESS
773            }
774            _ => SPIR_NOT_SUPPORTED,
775        }
776    }));
777
778    result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
779}
780
781/// Gets the Matsubara frequency sampling points
782#[unsafe(no_mangle)]
783pub extern "C" fn spir_sampling_get_matsus(
784    s: *const spir_sampling,
785    points: *mut i64,
786) -> StatusCode {
787    let result = catch_unwind(AssertUnwindSafe(|| {
788        if s.is_null() || points.is_null() {
789            return SPIR_INVALID_ARGUMENT;
790        }
791
792        let sampling_ref = unsafe { &*s };
793
794        match sampling_ref.inner() {
795            SamplingType::MatsubaraFermionic(matsu) => {
796                let matsu_freqs = matsu.sampling_points();
797                let out_slice =
798                    unsafe { std::slice::from_raw_parts_mut(points, matsu_freqs.len()) };
799                for (i, freq) in matsu_freqs.iter().enumerate() {
800                    out_slice[i] = freq.n();
801                }
802                SPIR_COMPUTATION_SUCCESS
803            }
804            SamplingType::MatsubaraBosonic(matsu) => {
805                let matsu_freqs = matsu.sampling_points();
806                let out_slice =
807                    unsafe { std::slice::from_raw_parts_mut(points, matsu_freqs.len()) };
808                for (i, freq) in matsu_freqs.iter().enumerate() {
809                    out_slice[i] = freq.n();
810                }
811                SPIR_COMPUTATION_SUCCESS
812            }
813            SamplingType::MatsubaraPositiveOnlyFermionic(matsu) => {
814                let matsu_freqs = matsu.sampling_points();
815                let out_slice =
816                    unsafe { std::slice::from_raw_parts_mut(points, matsu_freqs.len()) };
817                for (i, freq) in matsu_freqs.iter().enumerate() {
818                    out_slice[i] = freq.n();
819                }
820                SPIR_COMPUTATION_SUCCESS
821            }
822            SamplingType::MatsubaraPositiveOnlyBosonic(matsu) => {
823                let matsu_freqs = matsu.sampling_points();
824                let out_slice =
825                    unsafe { std::slice::from_raw_parts_mut(points, matsu_freqs.len()) };
826                for (i, freq) in matsu_freqs.iter().enumerate() {
827                    out_slice[i] = freq.n();
828                }
829                SPIR_COMPUTATION_SUCCESS
830            }
831            _ => SPIR_NOT_SUPPORTED,
832        }
833    }));
834
835    result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
836}
837
838/// Gets the condition number of the sampling matrix.
839///
840/// This function returns the condition number of the sampling matrix used in the
841/// specified sampling object. The condition number is a measure of how well-
842/// conditioned the sampling matrix is.
843///
844/// # Parameters
845/// - `s`: Pointer to the sampling object.
846/// - `cond_num`: Pointer to store the condition number.
847///
848/// # Returns
849/// An integer status code:
850/// - 0 (`SPIR_COMPUTATION_SUCCESS`) on success
851/// - Non-zero error code on failure
852///
853/// # Notes
854/// - A large condition number indicates that the sampling matrix is ill-conditioned,
855///   which may lead to numerical instability in transformations.
856/// - The condition number is the ratio of the largest to smallest singular value
857///   of the sampling matrix.
858#[unsafe(no_mangle)]
859pub extern "C" fn spir_sampling_get_cond_num(
860    s: *const spir_sampling,
861    cond_num: *mut f64,
862) -> StatusCode {
863    let result = catch_unwind(AssertUnwindSafe(|| {
864        if s.is_null() || cond_num.is_null() {
865            return SPIR_INVALID_ARGUMENT;
866        }
867
868        let sampling_ref = unsafe { &*s };
869
870        // Calculate condition number from SVD of the sampling matrix
871        let condition_number = match sampling_ref.inner() {
872            SamplingType::TauFermionic(tau) => {
873                // For tau sampling, matrix is real
874                let matrix = tau.matrix();
875                compute_condition_number_real(matrix)
876            }
877            SamplingType::TauBosonic(tau) => {
878                // For tau sampling, matrix is real
879                let matrix = tau.matrix();
880                compute_condition_number_real(matrix)
881            }
882            SamplingType::MatsubaraFermionic(matsu) => {
883                // For Matsubara sampling, matrix is complex
884                let matrix = matsu.matrix();
885                compute_condition_number_complex(matrix)
886            }
887            SamplingType::MatsubaraBosonic(matsu) => {
888                let matrix = matsu.matrix();
889                compute_condition_number_complex(matrix)
890            }
891            SamplingType::MatsubaraPositiveOnlyFermionic(matsu) => {
892                // For positive-only Matsubara, use the complex matrix
893                // The fitter uses ComplexToRealFitter internally, but we can use the complex matrix
894                let matrix = matsu.matrix();
895                compute_condition_number_complex(matrix)
896            }
897            SamplingType::MatsubaraPositiveOnlyBosonic(matsu) => {
898                let matrix = matsu.matrix();
899                compute_condition_number_complex(matrix)
900            }
901        };
902
903        unsafe {
904            *cond_num = condition_number;
905        }
906        SPIR_COMPUTATION_SUCCESS
907    }));
908
909    result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
910}
911
912/// Compute condition number from real matrix using SVD
913fn compute_condition_number_real(matrix: &mdarray::DTensor<f64, 2>) -> f64 {
914    use mdarray_linalg::prelude::SVD;
915    use mdarray_linalg::svd::SVDDecomp;
916    use mdarray_linalg_faer::Faer;
917
918    let mut matrix_copy = matrix.clone();
919    let SVDDecomp { s, .. } = Faer.svd(&mut *matrix_copy).expect("SVD computation failed");
920
921    let min_dim = s.shape().0.min(s.shape().1);
922    if min_dim == 0 {
923        return 1.0;
924    }
925
926    let max_sv = s[[0, 0]];
927    let min_sv = s[[0, min_dim - 1]];
928
929    if min_sv.abs() < 1e-15 {
930        // Matrix is singular or nearly singular
931        return f64::INFINITY;
932    }
933
934    max_sv / min_sv
935}
936
937/// Compute condition number from complex matrix using SVD
938fn compute_condition_number_complex(matrix: &mdarray::DTensor<num_complex::Complex64, 2>) -> f64 {
939    use mdarray_linalg::prelude::SVD;
940    use mdarray_linalg::svd::SVDDecomp;
941    use mdarray_linalg_faer::Faer;
942
943    let mut matrix_copy = matrix.clone();
944    let SVDDecomp { s, .. } = Faer.svd(&mut *matrix_copy).expect("SVD computation failed");
945
946    let min_dim = s.shape().0.min(s.shape().1);
947    if min_dim == 0 {
948        return 1.0;
949    }
950
951    // Singular values are real (stored as Complex, but imaginary part is 0)
952    let max_sv = s[[0, 0]].re;
953    let min_sv = s[[0, min_dim - 1]].re;
954
955    if min_sv.abs() < 1e-15 {
956        // Matrix is singular or nearly singular
957        return f64::INFINITY;
958    }
959
960    max_sv / min_sv
961}
962
963// ============================================================================
964// Evaluation Functions (coefficients → sampling points)
965// ============================================================================
966
967/// Evaluates basis coefficients at sampling points (double to double version).
968///
969/// Transforms basis coefficients to values at sampling points, where both input
970/// and output are real (double precision) values. The operation can be performed
971/// along any dimension of a multidimensional array.
972///
973/// # Arguments
974///
975/// * `s` - Pointer to the sampling object
976/// * `order` - Memory layout order (`SPIR_ORDER_ROW_MAJOR` or `SPIR_ORDER_COLUMN_MAJOR`)
977/// * `ndim` - Number of dimensions in the input/output arrays
978/// * `input_dims` - Array of dimension sizes
979/// * `target_dim` - Target dimension for the transformation (0-based)
980/// * `input` - Input array of basis coefficients
981/// * `out` - Output array for the evaluated values at sampling points
982///
983/// # Returns
984///
985/// An integer status code:
986/// - `0` (`SPIR_COMPUTATION_SUCCESS`) on success
987/// - A non-zero error code on failure
988///
989/// # Notes
990///
991/// - For optimal performance, the target dimension should be either the
992///   first (`0`) or the last (`ndim-1`) dimension to avoid large temporary array allocations
993/// - The output array must be pre-allocated with the correct size
994/// - The input and output arrays must be contiguous in memory
995/// - The transformation is performed using a pre-computed sampling matrix
996///   that is factorized using SVD for efficiency
997///
998/// # See also
999/// - [`spir_sampling_eval_dz`]
1000/// - [`spir_sampling_eval_zz`]
1001/// # Note
1002/// Supports both row-major and column-major order. Zero-copy implementation.
1003#[unsafe(no_mangle)]
1004pub extern "C" fn spir_sampling_eval_dd(
1005    s: *const spir_sampling,
1006    backend: *const spir_gemm_backend,
1007    order: libc::c_int,
1008    ndim: libc::c_int,
1009    input_dims: *const libc::c_int,
1010    target_dim: libc::c_int,
1011    input: *const f64,
1012    out: *mut f64,
1013) -> StatusCode {
1014    let result = catch_unwind(AssertUnwindSafe(|| {
1015        // Validate inputs
1016        if s.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
1017            return SPIR_INVALID_ARGUMENT;
1018        }
1019        if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
1020            return SPIR_INVALID_ARGUMENT;
1021        }
1022
1023        // Parse order
1024        let mem_order = match MemoryOrder::from_c_int(order) {
1025            Ok(o) => o,
1026            Err(_) => return SPIR_INVALID_ARGUMENT,
1027        };
1028
1029        let sampling_ref = unsafe { &*s };
1030        let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
1031        let orig_dims: Vec<usize> = dims_slice.iter().map(|&d| d as usize).collect();
1032
1033        // Convert dimensions for row-major processing
1034        // For column-major, this reverses dims and adjusts target_dim
1035        let (row_major_dims, row_major_target_dim) =
1036            convert_dims_for_row_major(&orig_dims, target_dim as usize, mem_order);
1037
1038        // Create input view directly from buffer (zero-copy)
1039        let input_view = unsafe { create_dview_from_ptr(input, &row_major_dims) };
1040
1041        // Validate that input dimension matches basis size
1042        let sampling_inner = sampling_ref.inner();
1043        let expected_basis_size = sampling_inner.basis_size();
1044        if row_major_dims[row_major_target_dim] != expected_basis_size {
1045            return crate::SPIR_INPUT_DIMENSION_MISMATCH;
1046        }
1047
1048        // Build output dimensions
1049        let n_points = sampling_inner.n_points();
1050        let out_dims = build_output_dims(&row_major_dims, row_major_target_dim, n_points);
1051
1052        // Create output view directly from buffer (zero-copy)
1053        let mut output_view = unsafe { create_dviewmut_from_ptr(out, &out_dims) };
1054
1055        // Get backend handle (NULL means use default)
1056        let backend_handle = unsafe { get_backend_handle(backend) };
1057
1058        // Evaluate using InplaceFitter (zero-copy: writes directly to output buffer)
1059        if !InplaceFitter::evaluate_nd_dd_to(
1060            sampling_inner,
1061            backend_handle,
1062            &input_view,
1063            row_major_target_dim,
1064            &mut output_view,
1065        ) {
1066            return SPIR_NOT_SUPPORTED;
1067        }
1068
1069        SPIR_COMPUTATION_SUCCESS
1070    }));
1071
1072    result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
1073}
1074
1075/// Evaluate basis coefficients at sampling points (double → complex)
1076///
1077/// For Matsubara sampling: transforms real IR coefficients to complex values.
1078/// Zero-copy implementation.
1079#[unsafe(no_mangle)]
1080pub extern "C" fn spir_sampling_eval_dz(
1081    s: *const spir_sampling,
1082    backend: *const spir_gemm_backend,
1083    order: libc::c_int,
1084    ndim: libc::c_int,
1085    input_dims: *const libc::c_int,
1086    target_dim: libc::c_int,
1087    input: *const f64,
1088    out: *mut Complex64,
1089) -> StatusCode {
1090    let result = catch_unwind(AssertUnwindSafe(|| {
1091        // Validate inputs
1092        if s.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
1093            return SPIR_INVALID_ARGUMENT;
1094        }
1095        if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
1096            return SPIR_INVALID_ARGUMENT;
1097        }
1098
1099        // Parse order
1100        let mem_order = match MemoryOrder::from_c_int(order) {
1101            Ok(o) => o,
1102            Err(_) => return SPIR_INVALID_ARGUMENT,
1103        };
1104
1105        let sampling_ref = unsafe { &*s };
1106        let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
1107        let orig_dims: Vec<usize> = dims_slice.iter().map(|&d| d as usize).collect();
1108
1109        // Convert dimensions for row-major processing
1110        let (row_major_dims, row_major_target_dim) =
1111            convert_dims_for_row_major(&orig_dims, target_dim as usize, mem_order);
1112
1113        // Create input view directly from buffer (zero-copy)
1114        let input_view = unsafe { create_dview_from_ptr(input, &row_major_dims) };
1115
1116        // Validate that input dimension matches basis size
1117        let sampling_inner = sampling_ref.inner();
1118        let expected_basis_size = sampling_inner.basis_size();
1119        if row_major_dims[row_major_target_dim] != expected_basis_size {
1120            return crate::SPIR_INPUT_DIMENSION_MISMATCH;
1121        }
1122
1123        // Build output dimensions
1124        let n_points = sampling_inner.n_points();
1125        let out_dims = build_output_dims(&row_major_dims, row_major_target_dim, n_points);
1126
1127        // Create output view directly from buffer (zero-copy)
1128        let mut output_view = unsafe { create_dviewmut_from_ptr(out, &out_dims) };
1129
1130        // Get backend handle (NULL means use default)
1131        let backend_handle = unsafe { get_backend_handle(backend) };
1132
1133        // Evaluate using InplaceFitter (dz: real → complex)
1134        if !InplaceFitter::evaluate_nd_dz_to(
1135            sampling_inner,
1136            backend_handle,
1137            &input_view,
1138            row_major_target_dim,
1139            &mut output_view,
1140        ) {
1141            return SPIR_NOT_SUPPORTED;
1142        }
1143
1144        SPIR_COMPUTATION_SUCCESS
1145    }));
1146
1147    result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
1148}
1149
1150/// Evaluate basis coefficients at sampling points (complex → complex)
1151///
1152/// For Matsubara sampling: transforms complex coefficients to complex values.
1153/// Zero-copy implementation.
1154#[unsafe(no_mangle)]
1155pub extern "C" fn spir_sampling_eval_zz(
1156    s: *const spir_sampling,
1157    backend: *const spir_gemm_backend,
1158    order: libc::c_int,
1159    ndim: libc::c_int,
1160    input_dims: *const libc::c_int,
1161    target_dim: libc::c_int,
1162    input: *const Complex64,
1163    out: *mut Complex64,
1164) -> StatusCode {
1165    let result = catch_unwind(AssertUnwindSafe(|| {
1166        if s.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
1167            return SPIR_INVALID_ARGUMENT;
1168        }
1169        if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
1170            return SPIR_INVALID_ARGUMENT;
1171        }
1172
1173        // Parse order
1174        let mem_order = match MemoryOrder::from_c_int(order) {
1175            Ok(o) => o,
1176            Err(_) => return SPIR_INVALID_ARGUMENT,
1177        };
1178
1179        let sampling_ref = unsafe { &*s };
1180        let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
1181        let orig_dims: Vec<usize> = dims_slice.iter().map(|&d| d as usize).collect();
1182
1183        // Convert dimensions for row-major processing
1184        let (row_major_dims, row_major_target_dim) =
1185            convert_dims_for_row_major(&orig_dims, target_dim as usize, mem_order);
1186
1187        // Create input view directly from buffer (zero-copy)
1188        let input_view = unsafe { create_dview_from_ptr(input, &row_major_dims) };
1189
1190        // Validate that input dimension matches basis size
1191        let sampling_inner = sampling_ref.inner();
1192        let expected_basis_size = sampling_inner.basis_size();
1193        if row_major_dims[row_major_target_dim] != expected_basis_size {
1194            return crate::SPIR_INPUT_DIMENSION_MISMATCH;
1195        }
1196
1197        // Build output dimensions
1198        let n_points = sampling_inner.n_points();
1199        let out_dims = build_output_dims(&row_major_dims, row_major_target_dim, n_points);
1200
1201        // Create output view directly from buffer (zero-copy)
1202        let mut output_view = unsafe { create_dviewmut_from_ptr(out, &out_dims) };
1203
1204        // Get backend handle (NULL means use default)
1205        let backend_handle = unsafe { get_backend_handle(backend) };
1206
1207        // Evaluate using InplaceFitter (zz: complex → complex)
1208        if !InplaceFitter::evaluate_nd_zz_to(
1209            sampling_inner,
1210            backend_handle,
1211            &input_view,
1212            row_major_target_dim,
1213            &mut output_view,
1214        ) {
1215            return SPIR_NOT_SUPPORTED;
1216        }
1217
1218        SPIR_COMPUTATION_SUCCESS
1219    }));
1220
1221    result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
1222}
1223
1224// ============================================================================
1225// Fitting Functions (sampling points → coefficients)
1226// ============================================================================
1227
1228/// Fits values at sampling points to basis coefficients (double to double version).
1229///
1230/// Transforms values at sampling points back to basis coefficients, where both
1231/// input and output are real (double precision) values. The operation can be
1232/// performed along any dimension of a multidimensional array.
1233///
1234/// # Arguments
1235///
1236/// * `s` - Pointer to the sampling object
1237/// * `backend` - Pointer to the GEMM backend (can be null to use default)
1238/// * `order` - Memory layout order (SPIR_ORDER_ROW_MAJOR or SPIR_ORDER_COLUMN_MAJOR)
1239/// * `ndim` - Number of dimensions in the input/output arrays
1240/// * `input_dims` - Array of dimension sizes
1241/// * `target_dim` - Target dimension for the transformation (0-based)
1242/// * `input` - Input array of values at sampling points
1243/// * `out` - Output array for the fitted basis coefficients
1244///
1245/// # Returns
1246///
1247/// An integer status code:
1248/// * `0` (SPIR_COMPUTATION_SUCCESS) on success
1249/// * A non-zero error code on failure
1250///
1251/// # Notes
1252///
1253/// * The output array must be pre-allocated with the correct size
1254/// * This function performs the inverse operation of `spir_sampling_eval_dd`
1255/// * The transformation is performed using a pre-computed sampling matrix
1256///   that is factorized using SVD for efficiency
1257/// * Zero-copy implementation
1258///
1259/// # See also
1260///
1261/// * [`spir_sampling_eval_dd`]
1262/// * [`spir_sampling_fit_zz`]
1263#[unsafe(no_mangle)]
1264pub extern "C" fn spir_sampling_fit_dd(
1265    s: *const spir_sampling,
1266    backend: *const spir_gemm_backend,
1267    order: libc::c_int,
1268    ndim: libc::c_int,
1269    input_dims: *const libc::c_int,
1270    target_dim: libc::c_int,
1271    input: *const f64,
1272    out: *mut f64,
1273) -> StatusCode {
1274    let result = catch_unwind(AssertUnwindSafe(|| {
1275        if s.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
1276            return SPIR_INVALID_ARGUMENT;
1277        }
1278        if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
1279            return SPIR_INVALID_ARGUMENT;
1280        }
1281
1282        // Parse order
1283        let mem_order = match MemoryOrder::from_c_int(order) {
1284            Ok(o) => o,
1285            Err(_) => return SPIR_INVALID_ARGUMENT,
1286        };
1287
1288        let sampling_ref = unsafe { &*s };
1289        let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
1290        let orig_dims: Vec<usize> = dims_slice.iter().map(|&d| d as usize).collect();
1291
1292        // Convert dimensions for row-major processing
1293        let (row_major_dims, row_major_target_dim) =
1294            convert_dims_for_row_major(&orig_dims, target_dim as usize, mem_order);
1295
1296        // Create input view directly from buffer (zero-copy)
1297        let input_view = unsafe { create_dview_from_ptr(input, &row_major_dims) };
1298
1299        // Validate that input dimension matches n_points
1300        let sampling_inner = sampling_ref.inner();
1301        let expected_n_points = sampling_inner.n_points();
1302        if row_major_dims[row_major_target_dim] != expected_n_points {
1303            return crate::SPIR_INPUT_DIMENSION_MISMATCH;
1304        }
1305
1306        // Build output dimensions (replace n_points with basis_size)
1307        let basis_size = sampling_inner.basis_size();
1308        let out_dims = build_output_dims(&row_major_dims, row_major_target_dim, basis_size);
1309
1310        // Create output view directly from buffer (zero-copy)
1311        let mut output_view = unsafe { create_dviewmut_from_ptr(out, &out_dims) };
1312
1313        // Get backend handle (NULL means use default)
1314        let backend_handle = unsafe { get_backend_handle(backend) };
1315
1316        // Fit using InplaceFitter (dd: real → real)
1317        if !InplaceFitter::fit_nd_dd_to(
1318            sampling_inner,
1319            backend_handle,
1320            &input_view,
1321            row_major_target_dim,
1322            &mut output_view,
1323        ) {
1324            return SPIR_NOT_SUPPORTED;
1325        }
1326
1327        SPIR_COMPUTATION_SUCCESS
1328    }));
1329
1330    result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
1331}
1332
1333/// Fits values at sampling points to basis coefficients (complex to complex version).
1334///
1335/// For more details, see [`spir_sampling_fit_dd`]
1336/// Zero-copy implementation for Tau and Matsubara (full).
1337/// MatsubaraPositiveOnly requires intermediate storage for real→complex conversion.
1338#[unsafe(no_mangle)]
1339pub extern "C" fn spir_sampling_fit_zz(
1340    s: *const spir_sampling,
1341    backend: *const spir_gemm_backend,
1342    order: libc::c_int,
1343    ndim: libc::c_int,
1344    input_dims: *const libc::c_int,
1345    target_dim: libc::c_int,
1346    input: *const Complex64,
1347    out: *mut Complex64,
1348) -> StatusCode {
1349    let result = catch_unwind(AssertUnwindSafe(|| {
1350        if s.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
1351            return SPIR_INVALID_ARGUMENT;
1352        }
1353        if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
1354            return SPIR_INVALID_ARGUMENT;
1355        }
1356
1357        // Parse order
1358        let mem_order = match MemoryOrder::from_c_int(order) {
1359            Ok(o) => o,
1360            Err(_) => return SPIR_INVALID_ARGUMENT,
1361        };
1362
1363        let sampling_ref = unsafe { &*s };
1364        let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
1365        let orig_dims: Vec<usize> = dims_slice.iter().map(|&d| d as usize).collect();
1366
1367        // Convert dimensions for row-major processing
1368        let (row_major_dims, row_major_target_dim) =
1369            convert_dims_for_row_major(&orig_dims, target_dim as usize, mem_order);
1370
1371        // Create input view directly from buffer (zero-copy)
1372        let input_view = unsafe { create_dview_from_ptr(input, &row_major_dims) };
1373
1374        // Validate that input dimension matches n_points
1375        let sampling_inner = sampling_ref.inner();
1376        let expected_n_points = sampling_inner.n_points();
1377        if row_major_dims[row_major_target_dim] != expected_n_points {
1378            return crate::SPIR_INPUT_DIMENSION_MISMATCH;
1379        }
1380
1381        // Build output dimensions (replace n_points with basis_size)
1382        let basis_size = sampling_inner.basis_size();
1383        let out_dims = build_output_dims(&row_major_dims, row_major_target_dim, basis_size);
1384
1385        // Create output view directly from buffer (zero-copy)
1386        let mut output_view = unsafe { create_dviewmut_from_ptr(out, &out_dims) };
1387
1388        // Get backend handle (NULL means use default)
1389        let backend_handle = unsafe { get_backend_handle(backend) };
1390
1391        // Fit using InplaceFitter (zz: complex → complex)
1392        if !InplaceFitter::fit_nd_zz_to(
1393            sampling_inner,
1394            backend_handle,
1395            &input_view,
1396            row_major_target_dim,
1397            &mut output_view,
1398        ) {
1399            return SPIR_NOT_SUPPORTED;
1400        }
1401
1402        SPIR_COMPUTATION_SUCCESS
1403    }));
1404
1405    result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
1406}
1407
1408/// Fit basis coefficients from Matsubara sampling points (complex input, real output)
1409///
1410/// This function fits basis coefficients from Matsubara sampling points
1411/// using complex input and real output.
1412///
1413/// # Supported Sampling Types
1414///
1415/// - **Matsubara (full)**: ✅ Supported (takes real part of fitted complex coefficients)
1416/// - **Matsubara (positive_only)**: ✅ Supported
1417/// - **Tau**: ❌ Not supported (use `spir_sampling_fit_dd` instead)
1418///
1419/// # Notes
1420///
1421/// For full-range Matsubara sampling, this function fits complex coefficients
1422/// internally and returns their real parts. This is physically correct for
1423/// Green's functions where IR coefficients are guaranteed to be real by symmetry.
1424///
1425/// Zero-copy implementation.
1426///
1427/// # Arguments
1428///
1429/// * `s` - Pointer to the sampling object (must be Matsubara)
1430/// * `backend` - Pointer to the GEMM backend (can be null to use default)
1431/// * `order` - Memory layout order (SPIR_ORDER_COLUMN_MAJOR or SPIR_ORDER_ROW_MAJOR)
1432/// * `ndim` - Number of dimensions in the input/output arrays
1433/// * `input_dims` - Array of dimension sizes
1434/// * `target_dim` - Target dimension for the transformation (0-based)
1435/// * `input` - Input array (complex)
1436/// * `out` - Output array (real)
1437///
1438/// # Returns
1439///
1440/// - `SPIR_COMPUTATION_SUCCESS` on success
1441/// - `SPIR_NOT_SUPPORTED` if the sampling type doesn't support this operation
1442/// - Other error codes on failure
1443///
1444/// # See also
1445///
1446/// * [`spir_sampling_fit_zz`]
1447/// * [`spir_sampling_fit_dd`]
1448#[unsafe(no_mangle)]
1449pub extern "C" fn spir_sampling_fit_zd(
1450    s: *const spir_sampling,
1451    backend: *const spir_gemm_backend,
1452    order: libc::c_int,
1453    ndim: libc::c_int,
1454    input_dims: *const libc::c_int,
1455    target_dim: libc::c_int,
1456    input: *const Complex64,
1457    out: *mut f64,
1458) -> StatusCode {
1459    let result = catch_unwind(AssertUnwindSafe(|| {
1460        if s.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
1461            return SPIR_INVALID_ARGUMENT;
1462        }
1463        if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
1464            return SPIR_INVALID_ARGUMENT;
1465        }
1466
1467        // Parse order
1468        let mem_order = match MemoryOrder::from_c_int(order) {
1469            Ok(o) => o,
1470            Err(_) => return SPIR_INVALID_ARGUMENT,
1471        };
1472
1473        let sampling_ref = unsafe { &*s };
1474        let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
1475        let orig_dims: Vec<usize> = dims_slice.iter().map(|&d| d as usize).collect();
1476
1477        // Convert dimensions for row-major processing
1478        let (row_major_dims, row_major_target_dim) =
1479            convert_dims_for_row_major(&orig_dims, target_dim as usize, mem_order);
1480
1481        // Create input view directly from buffer (zero-copy)
1482        let input_view = unsafe { create_dview_from_ptr(input, &row_major_dims) };
1483
1484        // Validate that input dimension matches n_points
1485        let sampling_inner = sampling_ref.inner();
1486        let expected_n_points = sampling_inner.n_points();
1487        if row_major_dims[row_major_target_dim] != expected_n_points {
1488            return crate::SPIR_INPUT_DIMENSION_MISMATCH;
1489        }
1490
1491        // Build output dimensions (replace n_points with basis_size)
1492        let basis_size = sampling_inner.basis_size();
1493        let out_dims = build_output_dims(&row_major_dims, row_major_target_dim, basis_size);
1494
1495        // Create output view directly from buffer (zero-copy)
1496        let mut output_view = unsafe { create_dviewmut_from_ptr(out, &out_dims) };
1497
1498        // Get backend handle (NULL means use default)
1499        let backend_handle = unsafe { get_backend_handle(backend) };
1500
1501        // Fit using InplaceFitter (zd: complex → real)
1502        // Note: For full-range Matsubara, this takes the real part of the fitted
1503        // complex coefficients. This is physically correct for Green's functions
1504        // where IR coefficients are guaranteed to be real by symmetry.
1505        if !InplaceFitter::fit_nd_zd_to(
1506            sampling_inner,
1507            backend_handle,
1508            &input_view,
1509            row_major_target_dim,
1510            &mut output_view,
1511        ) {
1512            return SPIR_NOT_SUPPORTED;
1513        }
1514
1515        SPIR_COMPUTATION_SUCCESS
1516    }));
1517
1518    result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
1519}
1520
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small basis, creates a τ sampling from a hand-made grid, and checks that
    /// the introspection functions (npoints, taus, condition number) report it back.
    #[test]
    fn test_tau_sampling_creation() {
        let mut status = 0;

        // Kernel -> SVE -> basis, with the basis size capped at 5.
        let kernel = crate::spir_logistic_kernel_new(10.0, &mut status);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        let sve = crate::spir_sve_result_new(kernel, 1e-6, -1, -1, -1, &mut status);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        let basis = crate::spir_basis_new(1, 10.0, 1.0, 1e-6, kernel, sve, 5, &mut status);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);

        // The cap is an upper bound, so ask the basis how large it really is.
        let mut size = 0;
        assert_eq!(
            crate::spir_basis_get_size(basis, &mut size),
            SPIR_COMPUTATION_SUCCESS
        );

        // One interior point per basis function, evenly spaced in (0, 10).
        let denom = size as f64 + 1.0;
        let tau_points: Vec<f64> = (1..=size).map(|k| k as f64 * 10.0 / denom).collect();

        let sampling = spir_tau_sampling_new(
            basis,
            tau_points.len() as i32,
            tau_points.as_ptr(),
            &mut status,
        );
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(!sampling.is_null());

        // Point count must equal the basis size we asked for.
        let mut n_points = 0;
        assert_eq!(
            spir_sampling_get_npoints(sampling, &mut n_points),
            SPIR_COMPUTATION_SUCCESS
        );
        assert_eq!(n_points, size);

        // The τ grid must come back unchanged.
        let mut retrieved = vec![0.0; size as usize];
        assert_eq!(
            spir_sampling_get_taus(sampling, retrieved.as_mut_ptr()),
            SPIR_COMPUTATION_SUCCESS
        );
        for (idx, (got, want)) in retrieved.iter().zip(&tau_points).enumerate() {
            assert!(
                (got - want).abs() < 1e-10,
                "Point {} mismatch: {} vs {}",
                idx,
                got,
                want
            );
        }

        // Any condition number is at least 1.
        let mut cond = 0.0;
        assert_eq!(
            spir_sampling_get_cond_num(sampling, &mut cond),
            SPIR_COMPUTATION_SUCCESS
        );
        assert!(cond >= 1.0);

        // Release in reverse order of creation.
        crate::spir_sampling_release(sampling);
        crate::spir_basis_release(basis);
        crate::spir_sve_result_release(sve);
        crate::spir_kernel_release(kernel);
    }
}