// sp1_gpu_sys/v2_kernels.rs

use std::ffi::c_void;

use crate::runtime::{CudaRustError, CudaStreamHandle, KernelPtr};

5extern "C" {
6    // Sum kernels
7    pub fn sum_kernel_u32() -> KernelPtr;
8    pub fn sum_kernel_felt() -> KernelPtr;
9    pub fn sum_kernel_ext() -> KernelPtr;
10
11    // Tracegen kernels
12    pub fn generate_col_index() -> KernelPtr;
13    pub fn generate_start_indices() -> KernelPtr;
14    pub fn fill_buffer() -> KernelPtr;
15    pub fn count_and_add_kernel() -> KernelPtr;
16    pub fn sum_to_trace_kernel() -> KernelPtr;
17
18    // Reduce kernels
19    pub fn reduce_kernel_felt() -> KernelPtr;
20    pub fn reduce_kernel_ext() -> KernelPtr;
21
22    // JaggedMLE kernels
23    pub fn jagged_eval_kernel_chunked_felt() -> KernelPtr;
24    pub fn jagged_eval_kernel_chunked_ext() -> KernelPtr;
25
26    // JaggedInfo kernels
27    pub fn initialize_jagged_info() -> KernelPtr;
28    pub fn fix_last_variable_jagged_info() -> KernelPtr;
29
30    // Basic jagged fix last variable
31    pub fn fix_last_variable_jagged_felt() -> KernelPtr;
32    pub fn fix_last_variable_jagged_ext() -> KernelPtr;
33
34    // Jagged Zerocheck Kernels
35    pub fn jagged_constraint_poly_eval_32_koala_bear_kernel() -> KernelPtr;
36    pub fn jagged_constraint_poly_eval_64_koala_bear_kernel() -> KernelPtr;
37    pub fn jagged_constraint_poly_eval_128_koala_bear_kernel() -> KernelPtr;
38    pub fn jagged_constraint_poly_eval_256_koala_bear_kernel() -> KernelPtr;
39    pub fn jagged_constraint_poly_eval_512_koala_bear_kernel() -> KernelPtr;
40    pub fn jagged_constraint_poly_eval_1024_koala_bear_kernel() -> KernelPtr;
41
42    pub fn jagged_constraint_poly_eval_32_koala_bear_extension_kernel() -> KernelPtr;
43    pub fn jagged_constraint_poly_eval_64_koala_bear_extension_kernel() -> KernelPtr;
44    pub fn jagged_constraint_poly_eval_128_koala_bear_extension_kernel() -> KernelPtr;
45    pub fn jagged_constraint_poly_eval_256_koala_bear_extension_kernel() -> KernelPtr;
46    pub fn jagged_constraint_poly_eval_512_koala_bear_extension_kernel() -> KernelPtr;
47    pub fn jagged_constraint_poly_eval_1024_koala_bear_extension_kernel() -> KernelPtr;
48
49    // Zerocheck kernels
50    pub fn zerocheck_sum_as_poly_base_ext_kernel() -> KernelPtr;
51    pub fn zerocheck_sum_as_poly_ext_ext_kernel() -> KernelPtr;
52
53    pub fn zerocheck_fix_last_variable_and_sum_as_poly_base_ext_kernel() -> KernelPtr;
54    pub fn zerocheck_fix_last_variable_and_sum_as_poly_ext_ext_kernel() -> KernelPtr;
55
56    // Hadamard kernels
57    pub fn hadamard_sum_as_poly_base_ext_kernel() -> KernelPtr;
58    pub fn hadamard_sum_as_poly_ext_ext_kernel() -> KernelPtr;
59
60    pub fn hadamard_fix_last_variable_and_sum_as_poly_base_ext_kernel() -> KernelPtr;
61    pub fn hadamard_fix_last_variable_and_sum_as_poly_ext_ext_kernel() -> KernelPtr;
62
63    pub fn fix_last_variable_felt_ext_kernel() -> KernelPtr;
64    pub fn fix_last_variable_ext_ext_kernel() -> KernelPtr;
65    pub fn mle_fix_last_variable_koala_bear_base_base_constant_padding() -> KernelPtr;
66    pub fn mle_fix_last_variable_koala_bear_base_extension_constant_padding() -> KernelPtr;
67    pub fn mle_fix_last_variable_koala_bear_ext_ext_constant_padding() -> KernelPtr;
68
69    pub fn mle_fix_last_variable_koala_bear_ext_ext_zero_padding() -> KernelPtr;
70
71    // ******** LogUp GKR kernels - Round operations ********
72    pub fn logup_gkr_sum_as_poly_circuit_layer() -> KernelPtr;
73    pub fn logup_gkr_first_sum_as_poly_circuit_layer() -> KernelPtr;
74    pub fn logup_gkr_fix_last_variable_circuit_layer() -> KernelPtr;
75    pub fn logup_gkr_fix_last_variable_last_circuit_layer() -> KernelPtr;
76    pub fn logup_gkr_sum_as_poly_interactions_layer() -> KernelPtr;
77    pub fn logup_gkr_fix_last_variable_interactions_layer() -> KernelPtr;
78
79    // LogUp GKR kernels - First layer operations
80    pub fn logup_gkr_fix_last_variable_first_layer() -> KernelPtr;
81    pub fn logup_gkr_fix_and_sum_first_layer() -> KernelPtr;
82    pub fn logup_gkr_sum_as_poly_first_layer() -> KernelPtr;
83    pub fn logup_gkr_first_layer_transition() -> KernelPtr;
84
85    // LogUp GKR kernels - Execution operations
86    pub fn logup_gkr_circuit_transition() -> KernelPtr;
87    pub fn logup_gkr_populate_last_circuit_layer() -> KernelPtr;
88    pub fn logup_gkr_extract_output() -> KernelPtr;
89
90    // Logup GKR kernels - Fused fix and sum kernels
91    pub fn logup_gkr_fix_and_sum_circuit_layer() -> KernelPtr;
92    pub fn logup_gkr_fix_and_sum_last_circuit_layer() -> KernelPtr;
93    pub fn logup_gkr_fix_and_sum_interactions_layer() -> KernelPtr;
94
95    // ******** Jagged sumcheck kernels ********
96    pub fn jagged_sum_as_poly() -> KernelPtr;
97    pub fn jagged_fix_and_sum() -> KernelPtr;
98    pub fn padded_hadamard_fix_and_sum() -> KernelPtr;
99
100    // Populate restrict eq
101    pub fn populate_restrict_eq_host(
102        src: *const c_void,
103        len: usize,
104        stream: CudaStreamHandle,
105    ) -> CudaRustError;
106    pub fn populate_restrict_eq_device(
107        src: *const c_void,
108        len: usize,
109        stream: CudaStreamHandle,
110    ) -> CudaRustError;
111
112    // ******** Hadamard look ahead kernels ********
113    // Look ahead kernels - FIX_TILE=32
114    pub fn round_kernel_1_32_2_2_false() -> KernelPtr;
115    pub fn round_kernel_2_32_2_2_true() -> KernelPtr;
116    pub fn round_kernel_2_32_2_2_false() -> KernelPtr;
117    pub fn round_kernel_4_32_2_2_true() -> KernelPtr;
118    pub fn round_kernel_4_32_2_2_false() -> KernelPtr;
119    pub fn round_kernel_8_32_2_2_true() -> KernelPtr;
120    pub fn round_kernel_8_32_2_2_false() -> KernelPtr;
121
122    // Look ahead kernels - FIX_TILE=64
123    pub fn round_kernel_1_64_2_2_false() -> KernelPtr;
124    pub fn round_kernel_2_64_2_2_true() -> KernelPtr;
125    pub fn round_kernel_2_64_2_2_false() -> KernelPtr;
126    pub fn round_kernel_4_64_2_2_true() -> KernelPtr;
127    pub fn round_kernel_4_64_2_2_false() -> KernelPtr;
128    pub fn round_kernel_8_64_2_2_true() -> KernelPtr;
129    pub fn round_kernel_8_64_2_2_false() -> KernelPtr;
130
131    // Look ahead kernels - NUM_POINTS=3, FIX_TILE=32
132    pub fn round_kernel_1_32_2_3_false() -> KernelPtr;
133    pub fn round_kernel_2_32_2_3_true() -> KernelPtr;
134    pub fn round_kernel_2_32_2_3_false() -> KernelPtr;
135    pub fn round_kernel_4_32_2_3_true() -> KernelPtr;
136    pub fn round_kernel_4_32_2_3_false() -> KernelPtr;
137    pub fn round_kernel_8_32_2_3_true() -> KernelPtr;
138    pub fn round_kernel_8_32_2_3_false() -> KernelPtr;
139
140    // Look ahead kernels - NUM_POINTS=3, FIX_TILE=64
141    pub fn round_kernel_1_64_2_3_false() -> KernelPtr;
142    pub fn round_kernel_1_64_4_8_false() -> KernelPtr;
143    pub fn round_kernel_2_64_2_3_true() -> KernelPtr;
144    pub fn round_kernel_2_64_2_3_false() -> KernelPtr;
145    pub fn round_kernel_4_64_2_3_true() -> KernelPtr;
146    pub fn round_kernel_4_64_2_3_false() -> KernelPtr;
147    pub fn round_kernel_4_64_4_8_true() -> KernelPtr;
148    pub fn round_kernel_4_64_4_8_false() -> KernelPtr;
149    pub fn round_kernel_8_64_2_3_true() -> KernelPtr;
150    pub fn round_kernel_8_64_2_3_false() -> KernelPtr;
151
152    // Look ahead kernels - FIX_TILE=128
153    pub fn round_kernel_1_128_4_8_false() -> KernelPtr;
154}