1use mdarray::Shape;
14use num_complex::Complex64;
15use std::panic::{AssertUnwindSafe, catch_unwind};
16use std::sync::Arc;
17
18use crate::gemm::{get_backend_handle, spir_gemm_backend};
19use crate::types::{BasisType, SamplingType, spir_basis, spir_sampling};
20use crate::utils::{
21 MemoryOrder, build_output_dims, convert_dims_for_row_major, create_dview_from_ptr,
22 create_dviewmut_from_ptr, read_tensor_nd,
23};
24use crate::{
25 SPIR_COMPUTATION_SUCCESS, SPIR_INVALID_ARGUMENT, SPIR_NOT_SUPPORTED, SPIR_STATISTICS_BOSONIC,
26 SPIR_STATISTICS_FERMIONIC, StatusCode,
27};
28use sparse_ir::fitters::InplaceFitter;
29use sparse_ir::{Bosonic, Fermionic};
30
31#[unsafe(no_mangle)]
33pub extern "C" fn spir_sampling_release(sampling: *mut spir_sampling) {
34 if !sampling.is_null() {
35 unsafe {
36 let _ = Box::from_raw(sampling);
37 }
38 }
39}
40
41#[unsafe(no_mangle)]
43pub extern "C" fn spir_sampling_clone(src: *const spir_sampling) -> *mut spir_sampling {
44 if src.is_null() {
45 return std::ptr::null_mut();
46 }
47
48 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
49 let src_ref = &*src;
50 let cloned = (*src_ref).clone();
51 Box::into_raw(Box::new(cloned))
52 }));
53
54 result.unwrap_or(std::ptr::null_mut())
55}
56
57#[unsafe(no_mangle)]
59pub extern "C" fn spir_sampling_is_assigned(obj: *const spir_sampling) -> i32 {
60 if obj.is_null() {
61 return 0;
62 }
63
64 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
65 let _ = &*obj;
66 1
67 }));
68
69 result.unwrap_or(0)
70}
71
72#[unsafe(no_mangle)]
90pub extern "C" fn spir_tau_sampling_new(
91 b: *const spir_basis,
92 num_points: libc::c_int,
93 points: *const f64,
94 status: *mut StatusCode,
95) -> *mut spir_sampling {
96 let result = catch_unwind(AssertUnwindSafe(|| {
97 if b.is_null() || points.is_null() {
99 return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
100 }
101 if num_points <= 0 {
102 return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
103 }
104
105 let basis_ref = unsafe { &*b };
106 let points_slice = unsafe { std::slice::from_raw_parts(points, num_points as usize) };
107
108 let tau_points: Vec<f64> = points_slice.to_vec();
110
111 let sampling_type = match basis_ref.inner() {
113 BasisType::LogisticFermionic(ir_basis) => {
114 let tau_sampling = sparse_ir::sampling::TauSampling::with_sampling_points(
115 ir_basis.as_ref(),
116 tau_points,
117 );
118 SamplingType::TauFermionic(Arc::new(tau_sampling))
119 }
120 BasisType::RegularizedBoseFermionic(ir_basis) => {
121 let tau_sampling = sparse_ir::sampling::TauSampling::with_sampling_points(
122 ir_basis.as_ref(),
123 tau_points,
124 );
125 SamplingType::TauFermionic(Arc::new(tau_sampling))
126 }
127 BasisType::LogisticBosonic(ir_basis) => {
128 let tau_sampling = sparse_ir::sampling::TauSampling::with_sampling_points(
129 ir_basis.as_ref(),
130 tau_points,
131 );
132 SamplingType::TauBosonic(Arc::new(tau_sampling))
133 }
134 BasisType::RegularizedBoseBosonic(ir_basis) => {
135 let tau_sampling = sparse_ir::sampling::TauSampling::with_sampling_points(
136 ir_basis.as_ref(),
137 tau_points,
138 );
139 SamplingType::TauBosonic(Arc::new(tau_sampling))
140 }
141 BasisType::DLRFermionic(dlr) => {
143 let tau_sampling = sparse_ir::sampling::TauSampling::with_sampling_points(
144 dlr.as_ref(),
145 tau_points,
146 );
147 SamplingType::TauFermionic(Arc::new(tau_sampling))
148 }
149 BasisType::DLRBosonic(dlr) => {
150 let tau_sampling = sparse_ir::sampling::TauSampling::with_sampling_points(
151 dlr.as_ref(),
152 tau_points,
153 );
154 SamplingType::TauBosonic(Arc::new(tau_sampling))
155 }
156 };
157
158 let inner = sampling_type;
159 let sampling = spir_sampling {
160 _private: Box::into_raw(Box::new(inner)) as *mut std::ffi::c_void,
161 };
162
163 (Box::into_raw(Box::new(sampling)), SPIR_COMPUTATION_SUCCESS)
164 }));
165
166 match result {
167 Ok((ptr, code)) => {
168 if !status.is_null() {
169 unsafe {
170 *status = code;
171 }
172 }
173 ptr
174 }
175 Err(_) => {
176 if !status.is_null() {
177 unsafe {
178 *status = crate::SPIR_INTERNAL_ERROR;
179 }
180 }
181 std::ptr::null_mut()
182 }
183 }
184}
185
186#[unsafe(no_mangle)]
198pub extern "C" fn spir_matsu_sampling_new(
199 b: *const spir_basis,
200 positive_only: bool,
201 num_points: libc::c_int,
202 points: *const i64,
203 status: *mut StatusCode,
204) -> *mut spir_sampling {
205 let result = catch_unwind(AssertUnwindSafe(|| {
206 if b.is_null() || points.is_null() {
208 return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
209 }
210 if num_points <= 0 {
211 return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
212 }
213
214 let basis_ref = unsafe { &*b };
215 let points_slice = unsafe { std::slice::from_raw_parts(points, num_points as usize) };
216
217 let matsu_points: Vec<i64> = points_slice.to_vec();
219
220 use sparse_ir::freq::MatsubaraFreq;
222
223 macro_rules! create_matsu_sampling {
225 ($basis:expr, Fermionic) => {
226 if positive_only {
227 let matsu_freqs: Vec<MatsubaraFreq<Fermionic>> = matsu_points
228 .iter()
229 .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
230 .collect();
231 let matsu_sampling = sparse_ir::matsubara_sampling::MatsubaraSamplingPositiveOnly::with_sampling_points(
232 $basis,
233 matsu_freqs,
234 );
235 SamplingType::MatsubaraPositiveOnlyFermionic(Arc::new(matsu_sampling))
236 } else {
237 let matsu_freqs: Vec<MatsubaraFreq<Fermionic>> = matsu_points
238 .iter()
239 .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
240 .collect();
241 let matsu_sampling = sparse_ir::matsubara_sampling::MatsubaraSampling::with_sampling_points(
242 $basis,
243 matsu_freqs,
244 );
245 SamplingType::MatsubaraFermionic(Arc::new(matsu_sampling))
246 }
247 };
248 ($basis:expr, Bosonic) => {
249 if positive_only {
250 let matsu_freqs: Vec<MatsubaraFreq<Bosonic>> = matsu_points
251 .iter()
252 .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
253 .collect();
254 let matsu_sampling = sparse_ir::matsubara_sampling::MatsubaraSamplingPositiveOnly::with_sampling_points(
255 $basis,
256 matsu_freqs,
257 );
258 SamplingType::MatsubaraPositiveOnlyBosonic(Arc::new(matsu_sampling))
259 } else {
260 let matsu_freqs: Vec<MatsubaraFreq<Bosonic>> = matsu_points
261 .iter()
262 .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
263 .collect();
264 let matsu_sampling = sparse_ir::matsubara_sampling::MatsubaraSampling::with_sampling_points(
265 $basis,
266 matsu_freqs,
267 );
268 SamplingType::MatsubaraBosonic(Arc::new(matsu_sampling))
269 }
270 };
271 }
272
273 let sampling_type = match basis_ref.inner() {
275 BasisType::LogisticFermionic(ir_basis) => {
276 create_matsu_sampling!(ir_basis.as_ref(), Fermionic)
277 }
278 BasisType::RegularizedBoseFermionic(ir_basis) => {
279 create_matsu_sampling!(ir_basis.as_ref(), Fermionic)
280 }
281 BasisType::LogisticBosonic(ir_basis) => {
282 create_matsu_sampling!(ir_basis.as_ref(), Bosonic)
283 }
284 BasisType::RegularizedBoseBosonic(ir_basis) => {
285 create_matsu_sampling!(ir_basis.as_ref(), Bosonic)
286 }
287 BasisType::DLRFermionic(dlr) => {
289 create_matsu_sampling!(dlr.as_ref(), Fermionic)
290 }
291 BasisType::DLRBosonic(dlr) => {
292 create_matsu_sampling!(dlr.as_ref(), Bosonic)
293 }
294 };
295
296 let inner = sampling_type;
297 let sampling = spir_sampling {
298 _private: Box::into_raw(Box::new(inner)) as *mut std::ffi::c_void,
299 };
300
301 (Box::into_raw(Box::new(sampling)), SPIR_COMPUTATION_SUCCESS)
302 }));
303
304 match result {
305 Ok((ptr, code)) => {
306 if !status.is_null() {
307 unsafe {
308 *status = code;
309 }
310 }
311 ptr
312 }
313 Err(_) => {
314 if !status.is_null() {
315 unsafe {
316 *status = crate::SPIR_INTERNAL_ERROR;
317 }
318 }
319 std::ptr::null_mut()
320 }
321 }
322}
323
324#[unsafe(no_mangle)]
341pub extern "C" fn spir_tau_sampling_new_with_matrix(
342 order: libc::c_int,
343 statistics: libc::c_int,
344 basis_size: libc::c_int,
345 num_points: libc::c_int,
346 points: *const f64,
347 matrix: *const f64,
348 status: *mut StatusCode,
349) -> *mut spir_sampling {
350 let result = catch_unwind(AssertUnwindSafe(|| {
351 if points.is_null() || matrix.is_null() {
353 return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
354 }
355 if num_points <= 0 || basis_size <= 0 {
356 return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
357 }
358
359 let mem_order = match MemoryOrder::from_c_int(order) {
361 Ok(o) => o,
362 Err(_) => return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT),
363 };
364
365 let points_slice = unsafe { std::slice::from_raw_parts(points, num_points as usize) };
367 let tau_points: Vec<f64> = points_slice.to_vec();
368
369 let orig_dims = [num_points as usize, basis_size as usize];
371 let dyn_tensor = unsafe { read_tensor_nd(matrix, &orig_dims, mem_order) };
372
373 let shape_dims = dyn_tensor.shape().with_dims(|dims| dims.to_vec());
375 assert_eq!(
376 shape_dims.len(),
377 2,
378 "Expected 2D tensor, got {}D",
379 shape_dims.len()
380 );
381 let num_points_actual = shape_dims[0];
382 let basis_size_actual = shape_dims[1];
383 let matrix_tensor =
384 sparse_ir::DTensor::<f64, 2>::from_fn([num_points_actual, basis_size_actual], |idx| {
385 dyn_tensor[&[idx[0], idx[1]][..]]
386 });
387 let sampling_type = match statistics {
389 SPIR_STATISTICS_FERMIONIC => {
390 let tau_sampling = sparse_ir::sampling::TauSampling::<Fermionic>::from_matrix(
392 tau_points,
393 matrix_tensor,
394 );
395 SamplingType::TauFermionic(Arc::new(tau_sampling))
396 }
397 SPIR_STATISTICS_BOSONIC => {
398 let tau_sampling = sparse_ir::sampling::TauSampling::<Bosonic>::from_matrix(
400 tau_points,
401 matrix_tensor,
402 );
403 SamplingType::TauBosonic(Arc::new(tau_sampling))
404 }
405 _ => return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT),
406 };
407
408 let inner = sampling_type;
409 let sampling = spir_sampling {
410 _private: Box::into_raw(Box::new(inner)) as *mut std::ffi::c_void,
411 };
412
413 (Box::into_raw(Box::new(sampling)), SPIR_COMPUTATION_SUCCESS)
414 }));
415
416 match result {
417 Ok((ptr, code)) => {
418 if !status.is_null() {
419 unsafe {
420 *status = code;
421 }
422 }
423 ptr
424 }
425 Err(_) => {
426 if !status.is_null() {
427 unsafe {
428 *status = crate::SPIR_INTERNAL_ERROR;
429 }
430 }
431 std::ptr::null_mut()
432 }
433 }
434}
435
436#[unsafe(no_mangle)]
454pub extern "C" fn spir_matsu_sampling_new_with_matrix(
455 order: libc::c_int,
456 statistics: libc::c_int,
457 basis_size: libc::c_int,
458 positive_only: bool,
459 num_points: libc::c_int,
460 points: *const i64,
461 matrix: *const Complex64,
462 status: *mut StatusCode,
463) -> *mut spir_sampling {
464 use std::io::Write;
465 debug_println!(
466 "spir_matsu_sampling_new_with_matrix: start, order={}, statistics={}, basis_size={}, positive_only={}, num_points={}",
467 order,
468 statistics,
469 basis_size,
470 positive_only,
471 num_points
472 );
473 std::io::stderr().flush().ok();
474 let result = catch_unwind(AssertUnwindSafe(|| {
475 use std::io::Write;
476 debug_println!("spir_matsu_sampling_new_with_matrix: inside catch_unwind");
477 std::io::stderr().flush().ok();
478 if points.is_null() || matrix.is_null() {
480 debug_eprintln!("spir_matsu_sampling_new_with_matrix: null pointer");
481 return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
482 }
483 if num_points <= 0 || basis_size <= 0 {
484 debug_eprintln!(
485 "spir_matsu_sampling_new_with_matrix: invalid size, num_points={}, basis_size={}",
486 num_points,
487 basis_size
488 );
489 return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
490 }
491 debug_println!("spir_matsu_sampling_new_with_matrix: input validation passed");
492 std::io::stderr().flush().ok();
493
494 let mem_order = match MemoryOrder::from_c_int(order) {
496 Ok(o) => o,
497 Err(_) => return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT),
498 };
499
500 debug_println!("spir_matsu_sampling_new_with_matrix: creating points slice...");
502 std::io::stderr().flush().ok();
503 let points_slice = unsafe { std::slice::from_raw_parts(points, num_points as usize) };
504 debug_println!(
505 "spir_matsu_sampling_new_with_matrix: points slice created, len = {}",
506 points_slice.len()
507 );
508 std::io::stderr().flush().ok();
509 let matsu_points: Vec<i64> = points_slice.to_vec();
510 debug_println!(
511 "spir_matsu_sampling_new_with_matrix: matsu_points created, len = {}",
512 matsu_points.len()
513 );
514 std::io::stderr().flush().ok();
515
516 use sparse_ir::freq::MatsubaraFreq;
517
518 let orig_dims = [num_points as usize, basis_size as usize];
520 debug_println!(
521 "spir_matsu_sampling_new_with_matrix: orig_dims = {:?}, mem_order = {:?}",
522 orig_dims,
523 mem_order
524 );
525 std::io::stderr().flush().ok();
526
527 debug_println!("spir_matsu_sampling_new_with_matrix: reading tensor from buffer...");
528 std::io::stderr().flush().ok();
529 let dyn_tensor = unsafe { read_tensor_nd(matrix, &orig_dims, mem_order) };
530 let shape_dims = dyn_tensor.shape().with_dims(|dims| dims.to_vec());
531 debug_println!(
532 "spir_matsu_sampling_new_with_matrix: dyn_tensor created, shape = {:?}",
533 shape_dims
534 );
535 std::io::stderr().flush().ok();
536
537 debug_println!("spir_matsu_sampling_new_with_matrix: converting to fixed 2D tensor...");
539 std::io::stderr().flush().ok();
540 assert_eq!(
541 shape_dims.len(),
542 2,
543 "Expected 2D tensor, got {}D",
544 shape_dims.len()
545 );
546 let num_points_actual = shape_dims[0];
547 let basis_size_actual = shape_dims[1];
548 debug_println!(
549 "spir_matsu_sampling_new_with_matrix: converting from shape {:?} to DTensor<Complex64, 2>",
550 shape_dims
551 );
552 std::io::stderr().flush().ok();
553 let matrix_tensor = sparse_ir::DTensor::<Complex64, 2>::from_fn(
554 [num_points_actual, basis_size_actual],
555 |idx| dyn_tensor[&[idx[0], idx[1]][..]],
556 );
557 debug_println!(
558 "spir_matsu_sampling_new_with_matrix: matrix_tensor created, shape = {:?}",
559 matrix_tensor.shape()
560 );
561 std::io::stderr().flush().ok();
562
563 debug_println!(
565 "spir_matsu_sampling_new_with_matrix: creating sampling, statistics={}, positive_only={}",
566 statistics,
567 positive_only
568 );
569 std::io::stderr().flush().ok();
570 let sampling_type = match (statistics, positive_only) {
571 (SPIR_STATISTICS_FERMIONIC, true) => {
572 debug_println!("spir_matsu_sampling_new_with_matrix: Fermionic, positive-only");
573 std::io::stderr().flush().ok();
574 let matsu_freqs: Vec<MatsubaraFreq<Fermionic>> = matsu_points
576 .iter()
577 .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
578 .collect();
579 debug_println!(
580 "spir_matsu_sampling_new_with_matrix: matsu_freqs created, len = {}",
581 matsu_freqs.len()
582 );
583 std::io::stderr().flush().ok();
584 debug_println!("spir_matsu_sampling_new_with_matrix: calling from_matrix...");
585 std::io::stderr().flush().ok();
586 let matsu_sampling =
587 sparse_ir::matsubara_sampling::MatsubaraSamplingPositiveOnly::from_matrix(
588 matsu_freqs,
589 matrix_tensor.clone(),
590 );
591 debug_println!("spir_matsu_sampling_new_with_matrix: from_matrix returned");
592 std::io::stderr().flush().ok();
593 SamplingType::MatsubaraPositiveOnlyFermionic(Arc::new(matsu_sampling))
594 }
595 (SPIR_STATISTICS_FERMIONIC, false) => {
596 debug_println!("spir_matsu_sampling_new_with_matrix: Fermionic, full range");
597 std::io::stderr().flush().ok();
598 let matsu_freqs: Vec<MatsubaraFreq<Fermionic>> = matsu_points
600 .iter()
601 .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
602 .collect();
603 debug_println!(
604 "spir_matsu_sampling_new_with_matrix: matsu_freqs created, len = {}",
605 matsu_freqs.len()
606 );
607 std::io::stderr().flush().ok();
608 debug_println!("spir_matsu_sampling_new_with_matrix: calling from_matrix...");
609 std::io::stderr().flush().ok();
610 let matsu_sampling = sparse_ir::matsubara_sampling::MatsubaraSampling::from_matrix(
611 matsu_freqs,
612 matrix_tensor.clone(),
613 );
614 debug_println!("spir_matsu_sampling_new_with_matrix: from_matrix returned");
615 std::io::stderr().flush().ok();
616 SamplingType::MatsubaraFermionic(Arc::new(matsu_sampling))
617 }
618 (SPIR_STATISTICS_BOSONIC, true) => {
619 let matsu_freqs: Vec<MatsubaraFreq<Bosonic>> = matsu_points
621 .iter()
622 .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
623 .collect();
624 let matsu_sampling =
625 sparse_ir::matsubara_sampling::MatsubaraSamplingPositiveOnly::from_matrix(
626 matsu_freqs,
627 matrix_tensor.clone(),
628 );
629 SamplingType::MatsubaraPositiveOnlyBosonic(Arc::new(matsu_sampling))
630 }
631 (SPIR_STATISTICS_BOSONIC, false) => {
632 let matsu_freqs: Vec<MatsubaraFreq<Bosonic>> = matsu_points
634 .iter()
635 .map(|&n| MatsubaraFreq::new(n).expect("Invalid Matsubara frequency"))
636 .collect();
637 let matsu_sampling = sparse_ir::matsubara_sampling::MatsubaraSampling::from_matrix(
638 matsu_freqs,
639 matrix_tensor.clone(),
640 );
641 SamplingType::MatsubaraBosonic(Arc::new(matsu_sampling))
642 }
643 _ => return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT),
644 };
645
646 let inner = sampling_type;
647 let sampling = spir_sampling {
648 _private: Box::into_raw(Box::new(inner)) as *mut std::ffi::c_void,
649 };
650
651 (Box::into_raw(Box::new(sampling)), SPIR_COMPUTATION_SUCCESS)
652 }));
653
654 match result {
655 Ok((ptr, code)) => {
656 if !status.is_null() {
657 unsafe {
658 *status = code;
659 }
660 }
661 ptr
662 }
663 Err(_) => {
664 if !status.is_null() {
665 unsafe {
666 *status = crate::SPIR_INTERNAL_ERROR;
667 }
668 }
669 std::ptr::null_mut()
670 }
671 }
672}
673
674#[unsafe(no_mangle)]
700pub extern "C" fn spir_sampling_get_npoints(
701 s: *const spir_sampling,
702 num_points: *mut libc::c_int,
703) -> StatusCode {
704 let result = catch_unwind(AssertUnwindSafe(|| {
705 if s.is_null() || num_points.is_null() {
706 return SPIR_INVALID_ARGUMENT;
707 }
708
709 let sampling_ref = unsafe { &*s };
710
711 let n_points = match sampling_ref.inner() {
712 SamplingType::TauFermionic(tau) => tau.n_sampling_points(),
713 SamplingType::TauBosonic(tau) => tau.n_sampling_points(),
714 SamplingType::MatsubaraFermionic(matsu) => matsu.n_sampling_points(),
715 SamplingType::MatsubaraBosonic(matsu) => matsu.n_sampling_points(),
716 SamplingType::MatsubaraPositiveOnlyFermionic(matsu) => matsu.n_sampling_points(),
717 SamplingType::MatsubaraPositiveOnlyBosonic(matsu) => matsu.n_sampling_points(),
718 };
719
720 unsafe {
721 *num_points = n_points as libc::c_int;
722 }
723 SPIR_COMPUTATION_SUCCESS
724 }));
725
726 result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
727}
728
729#[unsafe(no_mangle)]
753pub extern "C" fn spir_sampling_get_taus(s: *const spir_sampling, points: *mut f64) -> StatusCode {
754 let result = catch_unwind(AssertUnwindSafe(|| {
755 if s.is_null() || points.is_null() {
756 return SPIR_INVALID_ARGUMENT;
757 }
758
759 let sampling_ref = unsafe { &*s };
760
761 match sampling_ref.inner() {
762 SamplingType::TauFermionic(tau) => {
763 let tau_points = tau.sampling_points();
764 let out_slice = unsafe { std::slice::from_raw_parts_mut(points, tau_points.len()) };
765 out_slice.copy_from_slice(tau_points);
766 SPIR_COMPUTATION_SUCCESS
767 }
768 SamplingType::TauBosonic(tau) => {
769 let tau_points = tau.sampling_points();
770 let out_slice = unsafe { std::slice::from_raw_parts_mut(points, tau_points.len()) };
771 out_slice.copy_from_slice(tau_points);
772 SPIR_COMPUTATION_SUCCESS
773 }
774 _ => SPIR_NOT_SUPPORTED,
775 }
776 }));
777
778 result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
779}
780
781#[unsafe(no_mangle)]
783pub extern "C" fn spir_sampling_get_matsus(
784 s: *const spir_sampling,
785 points: *mut i64,
786) -> StatusCode {
787 let result = catch_unwind(AssertUnwindSafe(|| {
788 if s.is_null() || points.is_null() {
789 return SPIR_INVALID_ARGUMENT;
790 }
791
792 let sampling_ref = unsafe { &*s };
793
794 match sampling_ref.inner() {
795 SamplingType::MatsubaraFermionic(matsu) => {
796 let matsu_freqs = matsu.sampling_points();
797 let out_slice =
798 unsafe { std::slice::from_raw_parts_mut(points, matsu_freqs.len()) };
799 for (i, freq) in matsu_freqs.iter().enumerate() {
800 out_slice[i] = freq.n();
801 }
802 SPIR_COMPUTATION_SUCCESS
803 }
804 SamplingType::MatsubaraBosonic(matsu) => {
805 let matsu_freqs = matsu.sampling_points();
806 let out_slice =
807 unsafe { std::slice::from_raw_parts_mut(points, matsu_freqs.len()) };
808 for (i, freq) in matsu_freqs.iter().enumerate() {
809 out_slice[i] = freq.n();
810 }
811 SPIR_COMPUTATION_SUCCESS
812 }
813 SamplingType::MatsubaraPositiveOnlyFermionic(matsu) => {
814 let matsu_freqs = matsu.sampling_points();
815 let out_slice =
816 unsafe { std::slice::from_raw_parts_mut(points, matsu_freqs.len()) };
817 for (i, freq) in matsu_freqs.iter().enumerate() {
818 out_slice[i] = freq.n();
819 }
820 SPIR_COMPUTATION_SUCCESS
821 }
822 SamplingType::MatsubaraPositiveOnlyBosonic(matsu) => {
823 let matsu_freqs = matsu.sampling_points();
824 let out_slice =
825 unsafe { std::slice::from_raw_parts_mut(points, matsu_freqs.len()) };
826 for (i, freq) in matsu_freqs.iter().enumerate() {
827 out_slice[i] = freq.n();
828 }
829 SPIR_COMPUTATION_SUCCESS
830 }
831 _ => SPIR_NOT_SUPPORTED,
832 }
833 }));
834
835 result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
836}
837
838#[unsafe(no_mangle)]
859pub extern "C" fn spir_sampling_get_cond_num(
860 s: *const spir_sampling,
861 cond_num: *mut f64,
862) -> StatusCode {
863 let result = catch_unwind(AssertUnwindSafe(|| {
864 if s.is_null() || cond_num.is_null() {
865 return SPIR_INVALID_ARGUMENT;
866 }
867
868 let sampling_ref = unsafe { &*s };
869
870 let condition_number = match sampling_ref.inner() {
872 SamplingType::TauFermionic(tau) => {
873 let matrix = tau.matrix();
875 compute_condition_number_real(matrix)
876 }
877 SamplingType::TauBosonic(tau) => {
878 let matrix = tau.matrix();
880 compute_condition_number_real(matrix)
881 }
882 SamplingType::MatsubaraFermionic(matsu) => {
883 let matrix = matsu.matrix();
885 compute_condition_number_complex(matrix)
886 }
887 SamplingType::MatsubaraBosonic(matsu) => {
888 let matrix = matsu.matrix();
889 compute_condition_number_complex(matrix)
890 }
891 SamplingType::MatsubaraPositiveOnlyFermionic(matsu) => {
892 let matrix = matsu.matrix();
895 compute_condition_number_complex(matrix)
896 }
897 SamplingType::MatsubaraPositiveOnlyBosonic(matsu) => {
898 let matrix = matsu.matrix();
899 compute_condition_number_complex(matrix)
900 }
901 };
902
903 unsafe {
904 *cond_num = condition_number;
905 }
906 SPIR_COMPUTATION_SUCCESS
907 }));
908
909 result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
910}
911
912fn compute_condition_number_real(matrix: &mdarray::DTensor<f64, 2>) -> f64 {
914 use mdarray_linalg::prelude::SVD;
915 use mdarray_linalg::svd::SVDDecomp;
916 use mdarray_linalg_faer::Faer;
917
918 let mut matrix_copy = matrix.clone();
919 let SVDDecomp { s, .. } = Faer.svd(&mut *matrix_copy).expect("SVD computation failed");
920
921 let min_dim = s.shape().0.min(s.shape().1);
922 if min_dim == 0 {
923 return 1.0;
924 }
925
926 let max_sv = s[[0, 0]];
927 let min_sv = s[[0, min_dim - 1]];
928
929 if min_sv.abs() < 1e-15 {
930 return f64::INFINITY;
932 }
933
934 max_sv / min_sv
935}
936
937fn compute_condition_number_complex(matrix: &mdarray::DTensor<num_complex::Complex64, 2>) -> f64 {
939 use mdarray_linalg::prelude::SVD;
940 use mdarray_linalg::svd::SVDDecomp;
941 use mdarray_linalg_faer::Faer;
942
943 let mut matrix_copy = matrix.clone();
944 let SVDDecomp { s, .. } = Faer.svd(&mut *matrix_copy).expect("SVD computation failed");
945
946 let min_dim = s.shape().0.min(s.shape().1);
947 if min_dim == 0 {
948 return 1.0;
949 }
950
951 let max_sv = s[[0, 0]].re;
953 let min_sv = s[[0, min_dim - 1]].re;
954
955 if min_sv.abs() < 1e-15 {
956 return f64::INFINITY;
958 }
959
960 max_sv / min_sv
961}
962
963#[unsafe(no_mangle)]
1004pub extern "C" fn spir_sampling_eval_dd(
1005 s: *const spir_sampling,
1006 backend: *const spir_gemm_backend,
1007 order: libc::c_int,
1008 ndim: libc::c_int,
1009 input_dims: *const libc::c_int,
1010 target_dim: libc::c_int,
1011 input: *const f64,
1012 out: *mut f64,
1013) -> StatusCode {
1014 let result = catch_unwind(AssertUnwindSafe(|| {
1015 if s.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
1017 return SPIR_INVALID_ARGUMENT;
1018 }
1019 if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
1020 return SPIR_INVALID_ARGUMENT;
1021 }
1022
1023 let mem_order = match MemoryOrder::from_c_int(order) {
1025 Ok(o) => o,
1026 Err(_) => return SPIR_INVALID_ARGUMENT,
1027 };
1028
1029 let sampling_ref = unsafe { &*s };
1030 let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
1031 let orig_dims: Vec<usize> = dims_slice.iter().map(|&d| d as usize).collect();
1032
1033 let (row_major_dims, row_major_target_dim) =
1036 convert_dims_for_row_major(&orig_dims, target_dim as usize, mem_order);
1037
1038 let input_view = unsafe { create_dview_from_ptr(input, &row_major_dims) };
1040
1041 let sampling_inner = sampling_ref.inner();
1043 let expected_basis_size = sampling_inner.basis_size();
1044 if row_major_dims[row_major_target_dim] != expected_basis_size {
1045 return crate::SPIR_INPUT_DIMENSION_MISMATCH;
1046 }
1047
1048 let n_points = sampling_inner.n_points();
1050 let out_dims = build_output_dims(&row_major_dims, row_major_target_dim, n_points);
1051
1052 let mut output_view = unsafe { create_dviewmut_from_ptr(out, &out_dims) };
1054
1055 let backend_handle = unsafe { get_backend_handle(backend) };
1057
1058 if !InplaceFitter::evaluate_nd_dd_to(
1060 sampling_inner,
1061 backend_handle,
1062 &input_view,
1063 row_major_target_dim,
1064 &mut output_view,
1065 ) {
1066 return SPIR_NOT_SUPPORTED;
1067 }
1068
1069 SPIR_COMPUTATION_SUCCESS
1070 }));
1071
1072 result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
1073}
1074
1075#[unsafe(no_mangle)]
1080pub extern "C" fn spir_sampling_eval_dz(
1081 s: *const spir_sampling,
1082 backend: *const spir_gemm_backend,
1083 order: libc::c_int,
1084 ndim: libc::c_int,
1085 input_dims: *const libc::c_int,
1086 target_dim: libc::c_int,
1087 input: *const f64,
1088 out: *mut Complex64,
1089) -> StatusCode {
1090 let result = catch_unwind(AssertUnwindSafe(|| {
1091 if s.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
1093 return SPIR_INVALID_ARGUMENT;
1094 }
1095 if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
1096 return SPIR_INVALID_ARGUMENT;
1097 }
1098
1099 let mem_order = match MemoryOrder::from_c_int(order) {
1101 Ok(o) => o,
1102 Err(_) => return SPIR_INVALID_ARGUMENT,
1103 };
1104
1105 let sampling_ref = unsafe { &*s };
1106 let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
1107 let orig_dims: Vec<usize> = dims_slice.iter().map(|&d| d as usize).collect();
1108
1109 let (row_major_dims, row_major_target_dim) =
1111 convert_dims_for_row_major(&orig_dims, target_dim as usize, mem_order);
1112
1113 let input_view = unsafe { create_dview_from_ptr(input, &row_major_dims) };
1115
1116 let sampling_inner = sampling_ref.inner();
1118 let expected_basis_size = sampling_inner.basis_size();
1119 if row_major_dims[row_major_target_dim] != expected_basis_size {
1120 return crate::SPIR_INPUT_DIMENSION_MISMATCH;
1121 }
1122
1123 let n_points = sampling_inner.n_points();
1125 let out_dims = build_output_dims(&row_major_dims, row_major_target_dim, n_points);
1126
1127 let mut output_view = unsafe { create_dviewmut_from_ptr(out, &out_dims) };
1129
1130 let backend_handle = unsafe { get_backend_handle(backend) };
1132
1133 if !InplaceFitter::evaluate_nd_dz_to(
1135 sampling_inner,
1136 backend_handle,
1137 &input_view,
1138 row_major_target_dim,
1139 &mut output_view,
1140 ) {
1141 return SPIR_NOT_SUPPORTED;
1142 }
1143
1144 SPIR_COMPUTATION_SUCCESS
1145 }));
1146
1147 result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
1148}
1149
1150#[unsafe(no_mangle)]
1155pub extern "C" fn spir_sampling_eval_zz(
1156 s: *const spir_sampling,
1157 backend: *const spir_gemm_backend,
1158 order: libc::c_int,
1159 ndim: libc::c_int,
1160 input_dims: *const libc::c_int,
1161 target_dim: libc::c_int,
1162 input: *const Complex64,
1163 out: *mut Complex64,
1164) -> StatusCode {
1165 let result = catch_unwind(AssertUnwindSafe(|| {
1166 if s.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
1167 return SPIR_INVALID_ARGUMENT;
1168 }
1169 if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
1170 return SPIR_INVALID_ARGUMENT;
1171 }
1172
1173 let mem_order = match MemoryOrder::from_c_int(order) {
1175 Ok(o) => o,
1176 Err(_) => return SPIR_INVALID_ARGUMENT,
1177 };
1178
1179 let sampling_ref = unsafe { &*s };
1180 let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
1181 let orig_dims: Vec<usize> = dims_slice.iter().map(|&d| d as usize).collect();
1182
1183 let (row_major_dims, row_major_target_dim) =
1185 convert_dims_for_row_major(&orig_dims, target_dim as usize, mem_order);
1186
1187 let input_view = unsafe { create_dview_from_ptr(input, &row_major_dims) };
1189
1190 let sampling_inner = sampling_ref.inner();
1192 let expected_basis_size = sampling_inner.basis_size();
1193 if row_major_dims[row_major_target_dim] != expected_basis_size {
1194 return crate::SPIR_INPUT_DIMENSION_MISMATCH;
1195 }
1196
1197 let n_points = sampling_inner.n_points();
1199 let out_dims = build_output_dims(&row_major_dims, row_major_target_dim, n_points);
1200
1201 let mut output_view = unsafe { create_dviewmut_from_ptr(out, &out_dims) };
1203
1204 let backend_handle = unsafe { get_backend_handle(backend) };
1206
1207 if !InplaceFitter::evaluate_nd_zz_to(
1209 sampling_inner,
1210 backend_handle,
1211 &input_view,
1212 row_major_target_dim,
1213 &mut output_view,
1214 ) {
1215 return SPIR_NOT_SUPPORTED;
1216 }
1217
1218 SPIR_COMPUTATION_SUCCESS
1219 }));
1220
1221 result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
1222}
1223
1224#[unsafe(no_mangle)]
1264pub extern "C" fn spir_sampling_fit_dd(
1265 s: *const spir_sampling,
1266 backend: *const spir_gemm_backend,
1267 order: libc::c_int,
1268 ndim: libc::c_int,
1269 input_dims: *const libc::c_int,
1270 target_dim: libc::c_int,
1271 input: *const f64,
1272 out: *mut f64,
1273) -> StatusCode {
1274 let result = catch_unwind(AssertUnwindSafe(|| {
1275 if s.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
1276 return SPIR_INVALID_ARGUMENT;
1277 }
1278 if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
1279 return SPIR_INVALID_ARGUMENT;
1280 }
1281
1282 let mem_order = match MemoryOrder::from_c_int(order) {
1284 Ok(o) => o,
1285 Err(_) => return SPIR_INVALID_ARGUMENT,
1286 };
1287
1288 let sampling_ref = unsafe { &*s };
1289 let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
1290 let orig_dims: Vec<usize> = dims_slice.iter().map(|&d| d as usize).collect();
1291
1292 let (row_major_dims, row_major_target_dim) =
1294 convert_dims_for_row_major(&orig_dims, target_dim as usize, mem_order);
1295
1296 let input_view = unsafe { create_dview_from_ptr(input, &row_major_dims) };
1298
1299 let sampling_inner = sampling_ref.inner();
1301 let expected_n_points = sampling_inner.n_points();
1302 if row_major_dims[row_major_target_dim] != expected_n_points {
1303 return crate::SPIR_INPUT_DIMENSION_MISMATCH;
1304 }
1305
1306 let basis_size = sampling_inner.basis_size();
1308 let out_dims = build_output_dims(&row_major_dims, row_major_target_dim, basis_size);
1309
1310 let mut output_view = unsafe { create_dviewmut_from_ptr(out, &out_dims) };
1312
1313 let backend_handle = unsafe { get_backend_handle(backend) };
1315
1316 if !InplaceFitter::fit_nd_dd_to(
1318 sampling_inner,
1319 backend_handle,
1320 &input_view,
1321 row_major_target_dim,
1322 &mut output_view,
1323 ) {
1324 return SPIR_NOT_SUPPORTED;
1325 }
1326
1327 SPIR_COMPUTATION_SUCCESS
1328 }));
1329
1330 result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
1331}
1332
1333#[unsafe(no_mangle)]
1339pub extern "C" fn spir_sampling_fit_zz(
1340 s: *const spir_sampling,
1341 backend: *const spir_gemm_backend,
1342 order: libc::c_int,
1343 ndim: libc::c_int,
1344 input_dims: *const libc::c_int,
1345 target_dim: libc::c_int,
1346 input: *const Complex64,
1347 out: *mut Complex64,
1348) -> StatusCode {
1349 let result = catch_unwind(AssertUnwindSafe(|| {
1350 if s.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
1351 return SPIR_INVALID_ARGUMENT;
1352 }
1353 if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
1354 return SPIR_INVALID_ARGUMENT;
1355 }
1356
1357 let mem_order = match MemoryOrder::from_c_int(order) {
1359 Ok(o) => o,
1360 Err(_) => return SPIR_INVALID_ARGUMENT,
1361 };
1362
1363 let sampling_ref = unsafe { &*s };
1364 let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
1365 let orig_dims: Vec<usize> = dims_slice.iter().map(|&d| d as usize).collect();
1366
1367 let (row_major_dims, row_major_target_dim) =
1369 convert_dims_for_row_major(&orig_dims, target_dim as usize, mem_order);
1370
1371 let input_view = unsafe { create_dview_from_ptr(input, &row_major_dims) };
1373
1374 let sampling_inner = sampling_ref.inner();
1376 let expected_n_points = sampling_inner.n_points();
1377 if row_major_dims[row_major_target_dim] != expected_n_points {
1378 return crate::SPIR_INPUT_DIMENSION_MISMATCH;
1379 }
1380
1381 let basis_size = sampling_inner.basis_size();
1383 let out_dims = build_output_dims(&row_major_dims, row_major_target_dim, basis_size);
1384
1385 let mut output_view = unsafe { create_dviewmut_from_ptr(out, &out_dims) };
1387
1388 let backend_handle = unsafe { get_backend_handle(backend) };
1390
1391 if !InplaceFitter::fit_nd_zz_to(
1393 sampling_inner,
1394 backend_handle,
1395 &input_view,
1396 row_major_target_dim,
1397 &mut output_view,
1398 ) {
1399 return SPIR_NOT_SUPPORTED;
1400 }
1401
1402 SPIR_COMPUTATION_SUCCESS
1403 }));
1404
1405 result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
1406}
1407
1408#[unsafe(no_mangle)]
1449pub extern "C" fn spir_sampling_fit_zd(
1450 s: *const spir_sampling,
1451 backend: *const spir_gemm_backend,
1452 order: libc::c_int,
1453 ndim: libc::c_int,
1454 input_dims: *const libc::c_int,
1455 target_dim: libc::c_int,
1456 input: *const Complex64,
1457 out: *mut f64,
1458) -> StatusCode {
1459 let result = catch_unwind(AssertUnwindSafe(|| {
1460 if s.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
1461 return SPIR_INVALID_ARGUMENT;
1462 }
1463 if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
1464 return SPIR_INVALID_ARGUMENT;
1465 }
1466
1467 let mem_order = match MemoryOrder::from_c_int(order) {
1469 Ok(o) => o,
1470 Err(_) => return SPIR_INVALID_ARGUMENT,
1471 };
1472
1473 let sampling_ref = unsafe { &*s };
1474 let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
1475 let orig_dims: Vec<usize> = dims_slice.iter().map(|&d| d as usize).collect();
1476
1477 let (row_major_dims, row_major_target_dim) =
1479 convert_dims_for_row_major(&orig_dims, target_dim as usize, mem_order);
1480
1481 let input_view = unsafe { create_dview_from_ptr(input, &row_major_dims) };
1483
1484 let sampling_inner = sampling_ref.inner();
1486 let expected_n_points = sampling_inner.n_points();
1487 if row_major_dims[row_major_target_dim] != expected_n_points {
1488 return crate::SPIR_INPUT_DIMENSION_MISMATCH;
1489 }
1490
1491 let basis_size = sampling_inner.basis_size();
1493 let out_dims = build_output_dims(&row_major_dims, row_major_target_dim, basis_size);
1494
1495 let mut output_view = unsafe { create_dviewmut_from_ptr(out, &out_dims) };
1497
1498 let backend_handle = unsafe { get_backend_handle(backend) };
1500
1501 if !InplaceFitter::fit_nd_zd_to(
1506 sampling_inner,
1507 backend_handle,
1508 &input_view,
1509 row_major_target_dim,
1510 &mut output_view,
1511 ) {
1512 return SPIR_NOT_SUPPORTED;
1513 }
1514
1515 SPIR_COMPUTATION_SUCCESS
1516 }));
1517
1518 result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
1519}
1520
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end exercise of the tau-sampling C API.
    ///
    /// The test builds a logistic kernel, an SVE result and a basis, then creates a tau
    /// sampling on explicitly chosen points. It checks the reported point count, the points
    /// themselves and the condition number, and finally releases every object it created.
    #[test]
    fn test_tau_sampling_creation() {
        let mut status = 0;

        // Kernel -> SVE result -> basis; each constructor reports through `status`.
        let kernel = crate::spir_logistic_kernel_new(10.0, &mut status);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);

        let sve = crate::spir_sve_result_new(kernel, 1e-6, -1, -1, -1, &mut status);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);

        let basis = crate::spir_basis_new(1, 10.0, 1.0, 1e-6, kernel, sve, 5, &mut status);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);

        let mut size = 0;
        assert_eq!(crate::spir_basis_get_size(basis, &mut size), SPIR_COMPUTATION_SUCCESS);

        // `size` points spaced evenly and strictly inside (0, 10).
        let denom = size as f64 + 1.0;
        let taus: Vec<f64> = (0..size).map(|k| (k as f64 + 1.0) * 10.0 / denom).collect();

        let sampling =
            spir_tau_sampling_new(basis, taus.len() as i32, taus.as_ptr(), &mut status);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(!sampling.is_null());

        let mut n_points = 0;
        assert_eq!(spir_sampling_get_npoints(sampling, &mut n_points), SPIR_COMPUTATION_SUCCESS);
        assert_eq!(n_points, size);

        // The sampling must hand back exactly the points it was created with.
        let mut fetched = vec![0.0; size as usize];
        let rc = spir_sampling_get_taus(sampling, fetched.as_mut_ptr());
        assert_eq!(rc, SPIR_COMPUTATION_SUCCESS);

        let pairs = fetched.iter().copied().zip(taus.iter().copied());
        for (i, (got, want)) in pairs.enumerate() {
            assert!(
                (got - want).abs() < 1e-10,
                "Point {} mismatch: {} vs {}",
                i,
                got,
                want
            );
        }

        let mut cond = 0.0;
        assert_eq!(spir_sampling_get_cond_num(sampling, &mut cond), SPIR_COMPUTATION_SUCCESS);
        assert!(cond >= 1.0);

        // Release in reverse order of creation.
        crate::spir_sampling_release(sampling);
        crate::spir_basis_release(basis);
        crate::spir_sve_result_release(sve);
        crate::spir_kernel_release(kernel);
    }
}