sparse_ir_capi/
kernel.rs

1//! Kernel API for C
2//!
3//! Functions for creating and manipulating kernel objects.
4
5use std::panic::catch_unwind;
6
7use sparse_ir::kernel::SVEHints;
8
9use crate::types::spir_kernel;
10use crate::{SPIR_COMPUTATION_SUCCESS, SPIR_INTERNAL_ERROR, SPIR_INVALID_ARGUMENT, StatusCode};
11
// The common opaque-type functions (release, clone, is_assigned) are implemented by hand further down in this file.
13
14/// Create a new Logistic kernel
15///
16/// # Arguments
17/// * `lambda` - The kernel parameter Λ = β * ωmax (must be > 0)
18/// * `status` - Pointer to store the status code
19///
20/// # Returns
21/// * Pointer to the newly created kernel object, or NULL if creation fails
22///
23/// # Safety
24/// The caller must ensure `status` is a valid pointer.
25///
26/// # Example (C)
27/// ```c
28/// int status;
29/// spir_kernel* kernel = spir_logistic_kernel_new(10.0, &status);
30/// if (kernel != NULL) {
31///     // Use kernel...
32///     spir_kernel_release(kernel);
33/// }
34/// ```
35#[unsafe(no_mangle)]
36pub extern "C" fn spir_logistic_kernel_new(
37    lambda: f64,
38    status: *mut StatusCode,
39) -> *mut spir_kernel {
40    // Input validation
41    if status.is_null() {
42        return std::ptr::null_mut();
43    }
44
45    if lambda <= 0.0 {
46        unsafe {
47            *status = SPIR_INVALID_ARGUMENT;
48        }
49        return std::ptr::null_mut();
50    }
51
52    // Catch panics to prevent unwinding across FFI boundary
53    let result = catch_unwind(|| {
54        let kernel = spir_kernel::new_logistic(lambda);
55        Box::into_raw(Box::new(kernel))
56    });
57
58    match result {
59        Ok(ptr) => {
60            unsafe {
61                *status = SPIR_COMPUTATION_SUCCESS;
62            }
63            ptr
64        }
65        Err(_) => {
66            unsafe {
67                *status = SPIR_INTERNAL_ERROR;
68            }
69            std::ptr::null_mut()
70        }
71    }
72}
73
74/// Create a new RegularizedBose kernel
75///
76/// # Arguments
77/// * `lambda` - The kernel parameter Λ = β * ωmax (must be > 0)
78/// * `status` - Pointer to store the status code
79///
80/// # Returns
81/// * Pointer to the newly created kernel object, or NULL if creation fails
82#[unsafe(no_mangle)]
83pub extern "C" fn spir_reg_bose_kernel_new(
84    lambda: f64,
85    status: *mut StatusCode,
86) -> *mut spir_kernel {
87    if status.is_null() {
88        return std::ptr::null_mut();
89    }
90
91    if lambda <= 0.0 {
92        unsafe {
93            *status = SPIR_INVALID_ARGUMENT;
94        }
95        return std::ptr::null_mut();
96    }
97
98    let result = catch_unwind(|| {
99        let kernel = spir_kernel::new_regularized_bose(lambda);
100        Box::into_raw(Box::new(kernel))
101    });
102
103    match result {
104        Ok(ptr) => {
105            unsafe {
106                *status = SPIR_COMPUTATION_SUCCESS;
107            }
108            ptr
109        }
110        Err(_) => {
111            unsafe {
112                *status = SPIR_INTERNAL_ERROR;
113            }
114            std::ptr::null_mut()
115        }
116    }
117}
118
119/// Get the lambda parameter of a kernel
120///
121/// # Arguments
122/// * `kernel` - Kernel object
123/// * `lambda_out` - Pointer to store the lambda value
124///
125/// # Returns
126/// * `SPIR_COMPUTATION_SUCCESS` on success
127/// * `SPIR_INVALID_ARGUMENT` if kernel or lambda_out is null
128/// * `SPIR_INTERNAL_ERROR` if internal panic occurs
129#[unsafe(no_mangle)]
130pub extern "C" fn spir_kernel_get_lambda(
131    kernel: *const spir_kernel,
132    lambda_out: *mut f64,
133) -> StatusCode {
134    if kernel.is_null() || lambda_out.is_null() {
135        return SPIR_INVALID_ARGUMENT;
136    }
137
138    let result = catch_unwind(|| unsafe {
139        let k = &*kernel;
140        *lambda_out = k.lambda();
141        SPIR_COMPUTATION_SUCCESS
142    });
143
144    result.unwrap_or(SPIR_INTERNAL_ERROR)
145}
146
147/// Compute kernel value K(x, y)
148///
149/// # Arguments
150/// * `kernel` - Kernel object
151/// * `x` - First argument (typically in [-1, 1])
152/// * `y` - Second argument (typically in [-1, 1])
153/// * `out` - Pointer to store the result
154///
155/// # Returns
156/// * `SPIR_COMPUTATION_SUCCESS` on success
157/// * `SPIR_INVALID_ARGUMENT` if kernel or out is null
158/// * `SPIR_INTERNAL_ERROR` if internal panic occurs
159#[unsafe(no_mangle)]
160pub extern "C" fn spir_kernel_compute(
161    kernel: *const spir_kernel,
162    x: f64,
163    y: f64,
164    out: *mut f64,
165) -> StatusCode {
166    if kernel.is_null() || out.is_null() {
167        return SPIR_INVALID_ARGUMENT;
168    }
169
170    let result = catch_unwind(|| unsafe {
171        let k = &*kernel;
172        *out = k.compute(x, y);
173        SPIR_COMPUTATION_SUCCESS
174    });
175
176    result.unwrap_or(SPIR_INTERNAL_ERROR)
177}
178
179/// Manual release function (replaces macro-generated one)
180///
181/// # Safety
182/// This function drops the kernel. The inner KernelType data is automatically freed
183/// by the Drop implementation when the spir_kernel structure is dropped.
184#[unsafe(no_mangle)]
185pub extern "C" fn spir_kernel_release(kernel: *mut spir_kernel) {
186    if !kernel.is_null() {
187        unsafe {
188            // Drop the spir_kernel structure itself.
189            // The Drop implementation will automatically free the inner KernelType data.
190            let _ = Box::from_raw(kernel);
191        }
192    }
193}
194
195/// Manual clone function (replaces macro-generated one)
196#[unsafe(no_mangle)]
197pub extern "C" fn spir_kernel_clone(src: *const spir_kernel) -> *mut spir_kernel {
198    if src.is_null() {
199        return std::ptr::null_mut();
200    }
201
202    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
203        let src_ref = &*src;
204        let cloned = (*src_ref).clone();
205        Box::into_raw(Box::new(cloned))
206    }));
207
208    result.unwrap_or(std::ptr::null_mut())
209}
210
211/// Manual is_assigned function (replaces macro-generated one)
212#[unsafe(no_mangle)]
213pub extern "C" fn spir_kernel_is_assigned(obj: *const spir_kernel) -> i32 {
214    if obj.is_null() {
215        return 0;
216    }
217
218    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
219        let _ = &*obj;
220        1
221    }));
222
223    result.unwrap_or(0)
224}
225
226/// Get kernel domain boundaries
227///
228/// # Arguments
229/// * `k` - Kernel object
230/// * `xmin` - Pointer to store minimum x value
231/// * `xmax` - Pointer to store maximum x value
232/// * `ymin` - Pointer to store minimum y value
233/// * `ymax` - Pointer to store maximum y value
234///
235/// # Returns
236/// * `SPIR_COMPUTATION_SUCCESS` on success
237/// * `SPIR_INVALID_ARGUMENT` if any pointer is null
238/// * `SPIR_INTERNAL_ERROR` if internal panic occurs
239#[unsafe(no_mangle)]
240pub extern "C" fn spir_kernel_get_domain(
241    k: *const spir_kernel,
242    xmin: *mut f64,
243    xmax: *mut f64,
244    ymin: *mut f64,
245    ymax: *mut f64,
246) -> StatusCode {
247    if k.is_null() || xmin.is_null() || xmax.is_null() || ymin.is_null() || ymax.is_null() {
248        return SPIR_INVALID_ARGUMENT;
249    }
250
251    let result = catch_unwind(|| unsafe {
252        let kernel = &*k;
253        let (xmin_val, xmax_val, ymin_val, ymax_val) = kernel.domain();
254        *xmin = xmin_val;
255        *xmax = xmax_val;
256        *ymin = ymin_val;
257        *ymax = ymax_val;
258        SPIR_COMPUTATION_SUCCESS
259    });
260
261    result.unwrap_or(SPIR_INTERNAL_ERROR)
262}
263
264/// Get x-segments for SVE discretization hints from a kernel
265///
266/// This function should be called twice:
267/// 1. First call with segments=NULL: set n_segments to the required array size
268/// 2. Second call with segments allocated: fill segments[0..n_segments-1] with values
269///
270/// # Arguments
271/// * `k` - Kernel object
272/// * `epsilon` - Accuracy target for the basis
273/// * `segments` - Pointer to store segments array (NULL for first call)
274/// * `n_segments` - [IN/OUT] Input: ignored when segments is NULL. Output: number of segments
275///
276/// # Returns
277/// * `SPIR_COMPUTATION_SUCCESS` on success
278/// * `SPIR_INVALID_ARGUMENT` if k or n_segments is null, or segments array is too small
279/// * `SPIR_INTERNAL_ERROR` if internal panic occurs
280#[unsafe(no_mangle)]
281pub extern "C" fn spir_kernel_get_sve_hints_segments_x(
282    k: *const spir_kernel,
283    epsilon: f64,
284    segments: *mut f64,
285    n_segments: *mut libc::c_int,
286) -> StatusCode {
287    if k.is_null() || n_segments.is_null() {
288        return SPIR_INVALID_ARGUMENT;
289    }
290
291    if epsilon <= 0.0 || !epsilon.is_finite() {
292        return SPIR_INVALID_ARGUMENT;
293    }
294
295    let result = catch_unwind(|| unsafe {
296        let kernel = &*k;
297
298        // Get SVE hints based on kernel type
299        let segs = match kernel.inner() {
300            crate::types::KernelType::Logistic(k) => {
301                use sparse_ir::kernel::KernelProperties;
302                let hints = k.sve_hints::<f64>(epsilon);
303                hints.segments_x()
304            }
305            crate::types::KernelType::RegularizedBose(k) => {
306                use sparse_ir::kernel::KernelProperties;
307                let hints = k.sve_hints::<f64>(epsilon);
308                hints.segments_x()
309            }
310        };
311
312        if segments.is_null() {
313            // First call: return the number of segments
314            *n_segments = (segs.len() - 1) as libc::c_int;
315            return SPIR_COMPUTATION_SUCCESS;
316        }
317
318        // Second call: copy segments to output array
319        if *n_segments < (segs.len() - 1) as libc::c_int {
320            return SPIR_INVALID_ARGUMENT;
321        }
322
323        for (i, &seg) in segs.iter().enumerate() {
324            *segments.add(i) = seg;
325        }
326        *n_segments = (segs.len() - 1) as libc::c_int;
327        SPIR_COMPUTATION_SUCCESS
328    });
329
330    result.unwrap_or(SPIR_INTERNAL_ERROR)
331}
332
333/// Get y-segments for SVE discretization hints from a kernel
334///
335/// This function should be called twice:
336/// 1. First call with segments=NULL: set n_segments to the required array size
337/// 2. Second call with segments allocated: fill segments[0..n_segments-1] with values
338///
339/// # Arguments
340/// * `k` - Kernel object
341/// * `epsilon` - Accuracy target for the basis
342/// * `segments` - Pointer to store segments array (NULL for first call)
343/// * `n_segments` - [IN/OUT] Input: ignored when segments is NULL. Output: number of segments
344///
345/// # Returns
346/// * `SPIR_COMPUTATION_SUCCESS` on success
347/// * `SPIR_INVALID_ARGUMENT` if k or n_segments is null, or segments array is too small
348/// * `SPIR_INTERNAL_ERROR` if internal panic occurs
349#[unsafe(no_mangle)]
350pub extern "C" fn spir_kernel_get_sve_hints_segments_y(
351    k: *const spir_kernel,
352    epsilon: f64,
353    segments: *mut f64,
354    n_segments: *mut libc::c_int,
355) -> StatusCode {
356    if k.is_null() || n_segments.is_null() {
357        return SPIR_INVALID_ARGUMENT;
358    }
359
360    if epsilon <= 0.0 || !epsilon.is_finite() {
361        return SPIR_INVALID_ARGUMENT;
362    }
363
364    let result = catch_unwind(|| unsafe {
365        let kernel = &*k;
366
367        // Get SVE hints based on kernel type
368        let segs = match kernel.inner() {
369            crate::types::KernelType::Logistic(k) => {
370                use sparse_ir::kernel::KernelProperties;
371                let hints = k.sve_hints::<f64>(epsilon);
372                hints.segments_y()
373            }
374            crate::types::KernelType::RegularizedBose(k) => {
375                use sparse_ir::kernel::KernelProperties;
376                let hints = k.sve_hints::<f64>(epsilon);
377                hints.segments_y()
378            }
379        };
380
381        if segments.is_null() {
382            // First call: return the number of segments
383            *n_segments = (segs.len() - 1) as libc::c_int;
384            return SPIR_COMPUTATION_SUCCESS;
385        }
386
387        // Second call: copy segments to output array
388        if *n_segments < (segs.len() - 1) as libc::c_int {
389            return SPIR_INVALID_ARGUMENT;
390        }
391
392        for (i, &seg) in segs.iter().enumerate() {
393            *segments.add(i) = seg;
394        }
395        *n_segments = (segs.len() - 1) as libc::c_int;
396        SPIR_COMPUTATION_SUCCESS
397    });
398
399    result.unwrap_or(SPIR_INTERNAL_ERROR)
400}
401
402/// Get the number of singular values hint from a kernel
403///
404/// # Arguments
405/// * `k` - Kernel object
406/// * `epsilon` - Accuracy target for the basis
407/// * `nsvals` - Pointer to store the number of singular values
408///
409/// # Returns
410/// * `SPIR_COMPUTATION_SUCCESS` on success
411/// * `SPIR_INVALID_ARGUMENT` if k or nsvals is null
412/// * `SPIR_INTERNAL_ERROR` if internal panic occurs
413#[unsafe(no_mangle)]
414pub extern "C" fn spir_kernel_get_sve_hints_nsvals(
415    k: *const spir_kernel,
416    epsilon: f64,
417    nsvals: *mut libc::c_int,
418) -> StatusCode {
419    if k.is_null() || nsvals.is_null() {
420        return SPIR_INVALID_ARGUMENT;
421    }
422
423    if epsilon <= 0.0 || !epsilon.is_finite() {
424        return SPIR_INVALID_ARGUMENT;
425    }
426
427    let result = catch_unwind(|| unsafe {
428        let kernel = &*k;
429
430        // Get SVE hints based on kernel type
431        let n = match kernel.inner() {
432            crate::types::KernelType::Logistic(k) => {
433                use sparse_ir::kernel::KernelProperties;
434                let hints = k.sve_hints::<f64>(epsilon);
435                hints.nsvals()
436            }
437            crate::types::KernelType::RegularizedBose(k) => {
438                use sparse_ir::kernel::KernelProperties;
439                let hints = k.sve_hints::<f64>(epsilon);
440                hints.nsvals()
441            }
442        };
443
444        *nsvals = n as libc::c_int;
445        SPIR_COMPUTATION_SUCCESS
446    });
447
448    result.unwrap_or(SPIR_INTERNAL_ERROR)
449}
450
451/// Get the number of Gauss points hint from a kernel
452///
453/// # Arguments
454/// * `k` - Kernel object
455/// * `epsilon` - Accuracy target for the basis
456/// * `ngauss` - Pointer to store the number of Gauss points
457///
458/// # Returns
459/// * `SPIR_COMPUTATION_SUCCESS` on success
460/// * `SPIR_INVALID_ARGUMENT` if k or ngauss is null
461/// * `SPIR_INTERNAL_ERROR` if internal panic occurs
462#[unsafe(no_mangle)]
463pub extern "C" fn spir_kernel_get_sve_hints_ngauss(
464    k: *const spir_kernel,
465    epsilon: f64,
466    ngauss: *mut libc::c_int,
467) -> StatusCode {
468    if k.is_null() || ngauss.is_null() {
469        return SPIR_INVALID_ARGUMENT;
470    }
471
472    if epsilon <= 0.0 || !epsilon.is_finite() {
473        return SPIR_INVALID_ARGUMENT;
474    }
475
476    let result = catch_unwind(|| unsafe {
477        let kernel = &*k;
478
479        // Get SVE hints based on kernel type
480        let n = match kernel.inner() {
481            crate::types::KernelType::Logistic(k) => {
482                use sparse_ir::kernel::KernelProperties;
483                let hints = k.sve_hints::<f64>(epsilon);
484                hints.ngauss()
485            }
486            crate::types::KernelType::RegularizedBose(k) => {
487                use sparse_ir::kernel::KernelProperties;
488                let hints = k.sve_hints::<f64>(epsilon);
489                hints.ngauss()
490            }
491        };
492
493        *ngauss = n as libc::c_int;
494        SPIR_COMPUTATION_SUCCESS
495    });
496
497    result.unwrap_or(SPIR_INTERNAL_ERROR)
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use std::ptr;

    /// Signature shared by the x- and y-segment query functions.
    type SegmentsQuery =
        extern "C" fn(*const spir_kernel, f64, *mut f64, *mut libc::c_int) -> StatusCode;

    /// Builds a logistic kernel and asserts that construction succeeded.
    fn make_logistic(lambda: f64) -> *mut spir_kernel {
        let mut status = SPIR_INTERNAL_ERROR;
        let kernel = spir_logistic_kernel_new(lambda, &mut status);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(!kernel.is_null());
        kernel
    }

    /// Builds a regularized-Bose kernel and asserts that construction succeeded.
    fn make_reg_bose(lambda: f64) -> *mut spir_kernel {
        let mut status = SPIR_INTERNAL_ERROR;
        let kernel = spir_reg_bose_kernel_new(lambda, &mut status);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(!kernel.is_null());
        kernel
    }

    /// Runs the two-call segment protocol against `query` on a logistic
    /// kernel and checks the boundary values that come back.
    fn check_segment_boundaries(query: SegmentsQuery) {
        let epsilon = 1e-8;
        let kernel = make_logistic(10.0);

        // First call: get the number of segments
        let mut n_segments: libc::c_int = 0;
        let status = query(kernel, epsilon, ptr::null_mut(), &mut n_segments);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(n_segments > 0);

        // Second call: get the actual segments
        let mut segments = vec![0.0_f64; (n_segments + 1) as usize];
        let mut n_segments_out = n_segments + 1;
        let status = query(kernel, epsilon, segments.as_mut_ptr(), &mut n_segments_out);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert_eq!(n_segments_out, n_segments);

        // Verify segments are valid
        assert_eq!(segments.len(), (n_segments + 1) as usize);
        assert!((segments[0] - (0.0)).abs() < 1e-10);
        assert!((segments[n_segments as usize] - 1.0).abs() < 1e-10);

        // Verify segments are in ascending order
        assert!(segments.windows(2).all(|pair| pair[1] > pair[0]));

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_logistic_kernel_creation() {
        let kernel = make_logistic(10.0);
        spir_kernel_release(kernel);
    }

    #[test]
    fn test_regularized_bose_kernel_creation() {
        let kernel = make_reg_bose(10.0);
        spir_kernel_release(kernel);
    }

    #[test]
    fn test_kernel_lambda() {
        let kernel = make_logistic(10.0);

        let mut lambda = 0.0;
        let status = spir_kernel_get_lambda(kernel, &mut lambda);

        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert_eq!(lambda, 10.0);

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_kernel_compute() {
        let kernel = make_logistic(10.0);

        let mut value = 0.0;
        let status = spir_kernel_compute(kernel, 0.5, 0.5, &mut value);

        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(value > 0.0); // Kernel should be positive

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_null_pointer_errors() {
        // A null status pointer yields a null kernel
        assert!(spir_logistic_kernel_new(10.0, ptr::null_mut()).is_null());

        // A null kernel pointer is reported as an invalid argument
        let mut lambda = 0.0;
        assert_eq!(
            spir_kernel_get_lambda(ptr::null(), &mut lambda),
            SPIR_INVALID_ARGUMENT
        );
    }

    #[test]
    fn test_invalid_lambda() {
        let mut status = SPIR_COMPUTATION_SUCCESS;

        // Zero lambda, then negative lambda
        for bad_lambda in [0.0, -1.0] {
            let kernel = spir_logistic_kernel_new(bad_lambda, &mut status);
            assert_eq!(status, SPIR_INVALID_ARGUMENT);
            assert!(kernel.is_null());
        }
    }

    #[test]
    fn test_kernel_domain() {
        let kernel = make_logistic(10.0);

        let (mut xmin, mut xmax, mut ymin, mut ymax) = (0.0, 0.0, 0.0, 0.0);
        let status = spir_kernel_get_domain(kernel, &mut xmin, &mut xmax, &mut ymin, &mut ymax);

        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert_eq!(xmin, -1.0);
        assert_eq!(xmax, 1.0);
        assert_eq!(ymin, -1.0);
        assert_eq!(ymax, 1.0);

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_kernel_get_sve_hints_nsvals() {
        let epsilon = 1e-8;
        let kernel = make_logistic(10.0);

        let mut nsvals = 0;
        let status = spir_kernel_get_sve_hints_nsvals(kernel, epsilon, &mut nsvals);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(nsvals > 0);
        assert!(nsvals >= 10);
        assert!(nsvals <= 1000);

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_kernel_get_sve_hints_ngauss() {
        let epsilon_coarse = 1e-6;
        let epsilon_fine = 1e-10;
        let kernel = make_logistic(10.0);

        let mut ngauss_coarse = 0;
        let status = spir_kernel_get_sve_hints_ngauss(kernel, epsilon_coarse, &mut ngauss_coarse);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(ngauss_coarse > 0);
        assert_eq!(ngauss_coarse, 10); // For epsilon >= 1e-8, ngauss should be 10

        let mut ngauss_fine = 0;
        let status = spir_kernel_get_sve_hints_ngauss(kernel, epsilon_fine, &mut ngauss_fine);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(ngauss_fine > 0);
        assert_eq!(ngauss_fine, 16); // For epsilon < 1e-8, ngauss should be 16

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_kernel_get_sve_hints_segments_x() {
        check_segment_boundaries(spir_kernel_get_sve_hints_segments_x);
    }

    #[test]
    fn test_kernel_get_sve_hints_segments_y() {
        check_segment_boundaries(spir_kernel_get_sve_hints_segments_y);
    }

    #[test]
    fn test_kernel_get_sve_hints_with_regularized_bose() {
        let epsilon = 1e-8;
        let kernel = make_reg_bose(10.0);

        // Test nsvals
        let mut nsvals = 0;
        let status = spir_kernel_get_sve_hints_nsvals(kernel, epsilon, &mut nsvals);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(nsvals > 0);

        // Test ngauss
        let mut ngauss = 0;
        let status = spir_kernel_get_sve_hints_ngauss(kernel, epsilon, &mut ngauss);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(ngauss > 0);

        // Test segments_x, then segments_y (size query only)
        let queries: [SegmentsQuery; 2] = [
            spir_kernel_get_sve_hints_segments_x,
            spir_kernel_get_sve_hints_segments_y,
        ];
        for query in queries {
            let mut n_segments: libc::c_int = 0;
            let status = query(kernel, epsilon, ptr::null_mut(), &mut n_segments);
            assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
            assert!(n_segments > 0);
        }

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_kernel_get_sve_hints_error_handling() {
        let epsilon = 1e-8;
        let kernel = make_logistic(10.0);

        // Test with nullptr kernel
        let mut nsvals = 0;
        let status = spir_kernel_get_sve_hints_nsvals(ptr::null(), epsilon, &mut nsvals);
        assert_ne!(status, SPIR_COMPUTATION_SUCCESS);

        // Test with nullptr output parameter
        let status = spir_kernel_get_sve_hints_nsvals(kernel, epsilon, ptr::null_mut());
        assert_ne!(status, SPIR_COMPUTATION_SUCCESS);

        // Test with invalid epsilon
        let mut nsvals = 0;
        let status = spir_kernel_get_sve_hints_nsvals(kernel, -1.0, &mut nsvals);
        assert_ne!(status, SPIR_COMPUTATION_SUCCESS);

        spir_kernel_release(kernel);
    }
}