Skip to main content

sp1_gpu_sys/
kernels.rs

1use std::ffi::c_void;
2
3use crate::runtime::{CudaRustError, CudaStreamHandle, KernelPtr};
4
5extern "C" {
6    // Sum kernels
7    pub fn sum_kernel_u32() -> KernelPtr;
8    pub fn sum_kernel_felt() -> KernelPtr;
9    pub fn sum_kernel_ext() -> KernelPtr;
10
11    // Tracegen kernels
12    pub fn generate_col_index() -> KernelPtr;
13    pub fn generate_start_indices() -> KernelPtr;
14    pub fn fill_buffer() -> KernelPtr;
15    pub fn count_and_add_kernel() -> KernelPtr;
16    pub fn sum_to_trace_kernel() -> KernelPtr;
17
18    // Reduce kernels
19    pub fn reduce_kernel_felt() -> KernelPtr;
20    pub fn reduce_kernel_ext() -> KernelPtr;
21
22    // JaggedMLE kernels
23    pub fn jagged_eval_kernel_chunked_felt() -> KernelPtr;
24    pub fn jagged_eval_kernel_chunked_ext() -> KernelPtr;
25
26    // JaggedInfo kernels
27    pub fn initialize_jagged_info() -> KernelPtr;
28    pub fn fix_last_variable_jagged_info() -> KernelPtr;
29
30    // Basic jagged fix last variable
31    pub fn fix_last_variable_jagged_felt() -> KernelPtr;
32    pub fn fix_last_variable_jagged_ext() -> KernelPtr;
33
34    // Fused dispatch: one launch per non-empty tier handles every
35    // Sequential chunk in a round. The launcher's per-block dispatch
36    // descriptor maps each block to its `(chunk_id, row_offset, n_rows)`.
37    pub fn zerocheck_fused_sequential_kb_32_kernel() -> KernelPtr;
38    pub fn zerocheck_fused_sequential_kb_64_kernel() -> KernelPtr;
39    pub fn zerocheck_fused_sequential_kb_128_kernel() -> KernelPtr;
40    pub fn zerocheck_fused_sequential_kb_256_kernel() -> KernelPtr;
41    pub fn zerocheck_fused_sequential_kb_512_kernel() -> KernelPtr;
42    pub fn zerocheck_fused_sequential_kb_1024_kernel() -> KernelPtr;
43    pub fn zerocheck_fused_sequential_ext_32_kernel() -> KernelPtr;
44    pub fn zerocheck_fused_sequential_ext_64_kernel() -> KernelPtr;
45    pub fn zerocheck_fused_sequential_ext_128_kernel() -> KernelPtr;
46    pub fn zerocheck_fused_sequential_ext_256_kernel() -> KernelPtr;
47    pub fn zerocheck_fused_sequential_ext_512_kernel() -> KernelPtr;
48    pub fn zerocheck_fused_sequential_ext_1024_kernel() -> KernelPtr;
49
50    // zerocheck (DAG-native): ColumnTile lowering kernels.
51    pub fn zerocheck_column_tile_kb_kernel() -> KernelPtr;
52    pub fn zerocheck_column_tile_ext_kernel() -> KernelPtr;
53
54    // zerocheck (DAG-native): per-chip geq correction. One block per chip,
55    // writes 3 ext_t partials per chip (one per eval point) that the host
56    // aggregation sums into the round's totals.
57    pub fn zerocheck_geq_corrections_kernel() -> KernelPtr;
58
59    // zerocheck (DAG-native): apply `VirtualGeq::fix_last_variable(alpha)`
60    // in place to each chip's geq state. One thread per chip.
61    pub fn zerocheck_fix_geq_state_kernel() -> KernelPtr;
62
63    // zerocheck (DAG-native): aggregate per-block partials into the 3
64    // per-eval-point totals via a single-block grid-stride reduction. The
65    // host then only downloads the 3 totals instead of the full partials.
66    pub fn zerocheck_aggregate_partials_kernel() -> KernelPtr;
67
68    // zerocheck (DAG-native): per-chip GKR column sweep. Decoupled from the
69    // sequential constraint kernel so wide chips can parallelise the column
70    // reduction across a warp's lanes. One block per (chip, row-tile).
71    pub fn zerocheck_gkr_sweep_kb_kernel() -> KernelPtr;
72    pub fn zerocheck_gkr_sweep_ext_kernel() -> KernelPtr;
73
74    // zerocheck (DAG-native): per-chunk padded_row_adjustment via the
75    // bytecode interpreter at the all-zero trace. One thread per chunk;
76    // output is one ext_t per chunk, summed by chip on the host into the
77    // per-chip `padded_row_adjustment`. Tiered by MAX_REGS like
78    // `fused_sequential`.
79    pub fn zerocheck_pad_adj_32_kernel() -> KernelPtr;
80    pub fn zerocheck_pad_adj_64_kernel() -> KernelPtr;
81    pub fn zerocheck_pad_adj_128_kernel() -> KernelPtr;
82    pub fn zerocheck_pad_adj_256_kernel() -> KernelPtr;
83    pub fn zerocheck_pad_adj_512_kernel() -> KernelPtr;
84    pub fn zerocheck_pad_adj_1024_kernel() -> KernelPtr;
85
86    // JaggedMle fold-metadata: one fused multi-block kernel reads
87    // `column_heights`, writes `new_column_heights` (= `h.div_ceil(4)*2`
88    // element-wise) and `new_start_indices` (= exclusive prefix sum) — all
89    // on device, no host round-trip. Uses decoupled-lookback to handle any
90    // n_columns. See `include/jagged_assist/fold_metadata.cuh` for the
91    // caller-init contract on `block_counter`, `flags`, `scan_values`.
92    pub fn jagged_fold_metadata_kernel() -> KernelPtr;
93    pub fn jagged_fold_metadata_block_dim() -> u32;
94    pub fn jagged_fold_metadata_section_size() -> u32;
95
96    // JaggedMle chip-layouts: reads `start_indices` + `column_heights` at
97    // the sparse per-chip positions described by `ChipColumnLayoutEntry`,
98    // writes per-chip `ChipLayout[chip_idx]` + `chip_heights[chip_idx]`.
99    // One thread per chip. See `include/jagged_assist/chip_layouts.cuh`.
100    pub fn jagged_chip_layouts_kernel() -> KernelPtr;
101
102    // Jagged Zerocheck Kernels
103    pub fn jagged_constraint_poly_eval_32_koala_bear_kernel() -> KernelPtr;
104    pub fn jagged_constraint_poly_eval_64_koala_bear_kernel() -> KernelPtr;
105    pub fn jagged_constraint_poly_eval_128_koala_bear_kernel() -> KernelPtr;
106    pub fn jagged_constraint_poly_eval_256_koala_bear_kernel() -> KernelPtr;
107    pub fn jagged_constraint_poly_eval_512_koala_bear_kernel() -> KernelPtr;
108    pub fn jagged_constraint_poly_eval_1024_koala_bear_kernel() -> KernelPtr;
109
110    pub fn jagged_constraint_poly_eval_32_koala_bear_extension_kernel() -> KernelPtr;
111    pub fn jagged_constraint_poly_eval_64_koala_bear_extension_kernel() -> KernelPtr;
112    pub fn jagged_constraint_poly_eval_128_koala_bear_extension_kernel() -> KernelPtr;
113    pub fn jagged_constraint_poly_eval_256_koala_bear_extension_kernel() -> KernelPtr;
114    pub fn jagged_constraint_poly_eval_512_koala_bear_extension_kernel() -> KernelPtr;
115    pub fn jagged_constraint_poly_eval_1024_koala_bear_extension_kernel() -> KernelPtr;
116
117    // Zerocheck kernels
118    pub fn zerocheck_sum_as_poly_base_ext_kernel() -> KernelPtr;
119    pub fn zerocheck_sum_as_poly_ext_ext_kernel() -> KernelPtr;
120
121    pub fn zerocheck_fix_last_variable_and_sum_as_poly_base_ext_kernel() -> KernelPtr;
122    pub fn zerocheck_fix_last_variable_and_sum_as_poly_ext_ext_kernel() -> KernelPtr;
123
124    // Hadamard kernels
125    pub fn hadamard_sum_as_poly_base_ext_kernel() -> KernelPtr;
126    pub fn hadamard_sum_as_poly_ext_ext_kernel() -> KernelPtr;
127
128    pub fn hadamard_fix_last_variable_and_sum_as_poly_base_ext_kernel() -> KernelPtr;
129    pub fn hadamard_fix_last_variable_and_sum_as_poly_ext_ext_kernel() -> KernelPtr;
130
131    pub fn fix_last_variable_felt_ext_kernel() -> KernelPtr;
132    pub fn fix_last_variable_ext_ext_kernel() -> KernelPtr;
133    pub fn mle_fix_last_variable_koala_bear_base_base_constant_padding() -> KernelPtr;
134    pub fn mle_fix_last_variable_koala_bear_base_extension_constant_padding() -> KernelPtr;
135    pub fn mle_fix_last_variable_koala_bear_ext_ext_constant_padding() -> KernelPtr;
136
137    pub fn mle_fix_last_variable_koala_bear_ext_ext_zero_padding() -> KernelPtr;
138
139    // ******** LogUp GKR kernels - Round operations ********
140    pub fn logup_gkr_sum_as_poly_circuit_layer() -> KernelPtr;
141    pub fn logup_gkr_first_sum_as_poly_circuit_layer() -> KernelPtr;
142    pub fn logup_gkr_fix_last_variable_circuit_layer() -> KernelPtr;
143    pub fn logup_gkr_fix_last_variable_last_circuit_layer() -> KernelPtr;
144    pub fn logup_gkr_sum_as_poly_interactions_layer() -> KernelPtr;
145    pub fn logup_gkr_fix_last_variable_interactions_layer() -> KernelPtr;
146
147    // LogUp GKR kernels - First layer operations
148    pub fn logup_gkr_fix_last_variable_first_layer() -> KernelPtr;
149    pub fn logup_gkr_fix_and_sum_first_layer() -> KernelPtr;
150    pub fn logup_gkr_sum_as_poly_first_layer() -> KernelPtr;
151    pub fn logup_gkr_first_layer_transition() -> KernelPtr;
152
153    // LogUp GKR kernels - Execution operations
154    pub fn logup_gkr_circuit_transition() -> KernelPtr;
155    pub fn logup_gkr_populate_last_circuit_layer() -> KernelPtr;
156    pub fn logup_gkr_extract_output() -> KernelPtr;
157
158    // Logup GKR kernels - Fused fix and sum kernels
159    pub fn logup_gkr_fix_and_sum_circuit_layer() -> KernelPtr;
160    pub fn logup_gkr_fix_and_sum_last_circuit_layer() -> KernelPtr;
161    pub fn logup_gkr_fix_and_sum_interactions_layer() -> KernelPtr;
162
163    // ******** Jagged sumcheck kernels ********
164    pub fn jagged_sum_as_poly() -> KernelPtr;
165    pub fn jagged_fix_and_sum() -> KernelPtr;
166    pub fn padded_hadamard_fix_and_sum() -> KernelPtr;
167
168    // Populate restrict eq
169    pub fn populate_restrict_eq_host(
170        src: *const c_void,
171        len: usize,
172        stream: CudaStreamHandle,
173    ) -> CudaRustError;
174    pub fn populate_restrict_eq_device(
175        src: *const c_void,
176        len: usize,
177        stream: CudaStreamHandle,
178    ) -> CudaRustError;
179
180    // ******** Hadamard look ahead kernels ********
181    // Look ahead kernels - FIX_TILE=32
182    pub fn round_kernel_1_32_2_2_false() -> KernelPtr;
183    pub fn round_kernel_2_32_2_2_true() -> KernelPtr;
184    pub fn round_kernel_2_32_2_2_false() -> KernelPtr;
185    pub fn round_kernel_4_32_2_2_true() -> KernelPtr;
186    pub fn round_kernel_4_32_2_2_false() -> KernelPtr;
187    pub fn round_kernel_8_32_2_2_true() -> KernelPtr;
188    pub fn round_kernel_8_32_2_2_false() -> KernelPtr;
189
190    // Look ahead kernels - FIX_TILE=64
191    pub fn round_kernel_1_64_2_2_false() -> KernelPtr;
192    pub fn round_kernel_2_64_2_2_true() -> KernelPtr;
193    pub fn round_kernel_2_64_2_2_false() -> KernelPtr;
194    pub fn round_kernel_4_64_2_2_true() -> KernelPtr;
195    pub fn round_kernel_4_64_2_2_false() -> KernelPtr;
196    pub fn round_kernel_8_64_2_2_true() -> KernelPtr;
197    pub fn round_kernel_8_64_2_2_false() -> KernelPtr;
198
199    // Look ahead kernels - NUM_POINTS=3, FIX_TILE=32
200    pub fn round_kernel_1_32_2_3_false() -> KernelPtr;
201    pub fn round_kernel_2_32_2_3_true() -> KernelPtr;
202    pub fn round_kernel_2_32_2_3_false() -> KernelPtr;
203    pub fn round_kernel_4_32_2_3_true() -> KernelPtr;
204    pub fn round_kernel_4_32_2_3_false() -> KernelPtr;
205    pub fn round_kernel_8_32_2_3_true() -> KernelPtr;
206    pub fn round_kernel_8_32_2_3_false() -> KernelPtr;
207
208    // Look ahead kernels - NUM_POINTS=3, FIX_TILE=64
209    pub fn round_kernel_1_64_2_3_false() -> KernelPtr;
210    pub fn round_kernel_1_64_4_8_false() -> KernelPtr;
211    pub fn round_kernel_2_64_2_3_true() -> KernelPtr;
212    pub fn round_kernel_2_64_2_3_false() -> KernelPtr;
213    pub fn round_kernel_4_64_2_3_true() -> KernelPtr;
214    pub fn round_kernel_4_64_2_3_false() -> KernelPtr;
215    pub fn round_kernel_4_64_4_8_true() -> KernelPtr;
216    pub fn round_kernel_4_64_4_8_false() -> KernelPtr;
217    pub fn round_kernel_8_64_2_3_true() -> KernelPtr;
218    pub fn round_kernel_8_64_2_3_false() -> KernelPtr;
219
220    // Look ahead kernels - FIX_TILE=128
221    pub fn round_kernel_1_128_4_8_false() -> KernelPtr;
222}