1use sp1_primitives::SP1Field;
2
3use crate::runtime::{CudaRustError, CudaStreamHandle, DEFAULT_STREAM};
4
5pub unsafe fn sppark_init_default_stream() -> CudaRustError {
9 sppark_init(DEFAULT_STREAM)
10}
11
// Raw FFI declarations for functions implemented in native code. Calling any
// of them is `unsafe`: Rust cannot check the raw pointers or the stream
// handle. Each returns a `CudaRustError` status.
// NOTE(review): these signatures must match the native declarations exactly
// (argument order and types; Rust `bool` corresponds to C `_Bool`). Confirm
// against the native headers whenever either side changes.
extern "C" {
    /// Native initialization entry point, taking the stream to use.
    /// NOTE(review): what state this sets up is inferred from the name only —
    /// confirm on the native side.
    pub fn sppark_init(stream: CudaStreamHandle) -> CudaRustError;

    /// Out-of-place batched transform reading from `d_in` and writing to
    /// `d_out`, over `poly_count` polynomials, issued on `stream`.
    /// NOTE(review): presumably a coset DFT/LDE of size
    /// `2^(lg_domain_size + lg_blowup)` per polynomial using coset `shift`,
    /// with `d_in`/`d_out` being device pointers and `is_bit_rev` selecting
    /// bit-reversed ordering — all inferred from names, TODO confirm buffer
    /// sizes and layouts.
    pub fn batch_coset_dft(
        d_out: *mut SP1Field,
        d_in: *mut SP1Field,
        lg_domain_size: u32,
        lg_blowup: u32,
        shift: SP1Field,
        poly_count: u32,
        is_bit_rev: bool,
        stream: CudaStreamHandle,
    ) -> CudaRustError;

    /// In-place operation on the single buffer `d_inout` for `poly_count`
    /// polynomials, issued on `stream`.
    /// NOTE(review): presumably applies the LDE coset `shift` (and blowup by
    /// `2^lg_blowup`) in place; `d_inout` presumably must already be large
    /// enough for the blown-up size — inferred from names, TODO confirm.
    pub fn batch_lde_shift_in_place(
        d_inout: *mut SP1Field,
        lg_domain_size: u32,
        lg_blowup: u32,
        shift: SP1Field,
        poly_count: u32,
        is_bit_rev: bool,
        stream: CudaStreamHandle,
    ) -> CudaRustError;

    /// In-place variant of `batch_coset_dft`, operating on `d_inout`; takes
    /// the same parameters apart from the separate input/output pointers.
    /// NOTE(review): buffer-size requirements for `d_inout` (input size vs.
    /// blown-up output size) are not visible here — confirm.
    pub fn batch_coset_dft_in_place(
        d_inout: *mut SP1Field,
        lg_domain_size: u32,
        lg_blowup: u32,
        shift: SP1Field,
        poly_count: u32,
        is_bit_rev: bool,
        stream: CudaStreamHandle,
    ) -> CudaRustError;

    /// In-place batched transform over `poly_count` polynomials in `d_inout`,
    /// issued on `stream`.
    /// NOTE(review): presumably a forward NTT of size `2^lg_domain_size` per
    /// polynomial — inferred from the name, TODO confirm ordering and layout.
    pub fn batch_NTT(
        d_inout: *mut SP1Field,
        lg_domain_size: u32,
        poly_count: u32,
        stream: CudaStreamHandle,
    ) -> CudaRustError;

    /// In-place batched transform over `poly_count` polynomials in `d_inout`,
    /// issued on `stream`.
    /// NOTE(review): presumably the inverse of `batch_NTT` (size
    /// `2^lg_domain_size`); whether it includes the `1/n` scaling is not
    /// visible here — confirm.
    pub fn batch_iNTT(
        d_inout: *mut SP1Field,
        lg_domain_size: u32,
        poly_count: u32,
        stream: CudaStreamHandle,
    ) -> CudaRustError;
}
60
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test that invokes `sppark_init_default_stream` once.
    ///
    /// NOTE(review): the returned `CudaRustError` is discarded, so this only
    /// shows the call returns without aborting — it does not assert that
    /// initialization succeeded. It also presumably needs a CUDA-capable
    /// environment at runtime; confirm.
    #[test]
    fn test_sppark_init() {
        // SAFETY: assumes the test environment satisfies the native
        // initialization preconditions (presumably an available CUDA
        // runtime) — TODO confirm.
        let _ = unsafe { sppark_init_default_stream() };
    }
}