1use core::mem;
2use std::sync::Arc;
3
4use futures::future::join_all;
5use slop_algebra::PrimeField32;
6use slop_alloc::mem::CopyError;
7use slop_alloc::Buffer;
8use slop_tensor::Tensor;
9use sp1_core_machine::global::{GlobalChip, GlobalCols, GLOBAL_INITIAL_DIGEST_POS};
10use sp1_gpu_cudart::sys::runtime::Dim3;
11use sp1_gpu_cudart::transpose::DeviceTransposeKernel;
12use sp1_gpu_cudart::{args, DeviceMle, ScanKernel, TaskScope};
13use sp1_hypercube::air::MachineAir;
14use sp1_hypercube::septic_curve::SepticCurve;
15use sp1_hypercube::septic_digest::SepticDigest;
16use sp1_hypercube::septic_extension::{SepticBlock, SepticExtension};
17
18use sp1_gpu_cudart::TracegenRiscvGlobalKernel;
19
20use crate::{CudaTracegenAir, F};
21
impl CudaTracegenAir<F> for GlobalChip {
    /// The global chip supports generating its main trace directly on the device.
    fn supports_device_main_tracegen(&self) -> bool {
        true
    }

    /// Generates the `GlobalChip` main trace on the GPU.
    ///
    /// Pipeline (all kernel launches go through `scope`):
    /// 1. Copy the record's global interaction events host -> device.
    /// 2. The "decompress" kernel fills the column-major trace tensor from the events.
    /// 3. Transpose the initial-digest columns into row-major `SepticCurve<F>` points.
    /// 4. A parallel scan (prefix sum) over those points produces per-row cumulative sums.
    /// 5. The "finalize" kernel writes the cumulative sums back into the trace.
    /// 6. The final cumulative sum is copied back to the host and published through
    ///    `input.global_cumulative_sum`; the event count is recorded on `output`.
    ///
    /// Returns the device-resident trace, or a `CopyError` if the host->device event
    /// copy fails. Allocation and kernel-launch failures panic (`unwrap`/`expect`),
    /// i.e. they are treated as unrecoverable.
    async fn generate_trace_device(
        &self,
        input: &Self::Record,
        output: &mut Self::Record,
        scope: &TaskScope,
    ) -> Result<DeviceMle<F>, CopyError> {
        let events = &input.global_interaction_events;
        let events_len = events.len();

        // Host -> device copy of the raw event array.
        let events_device = {
            let mut buf = Buffer::try_with_capacity_in(events.len(), scope.clone()).unwrap();
            buf.extend_from_host_slice(events)?;
            buf
        };

        // Trace width in field elements, derived from the column struct's byte layout
        // (one byte per column in `GlobalCols<u8>`).
        const NUM_GLOBAL_COLS: usize = size_of::<GlobalCols<u8>>();

        let height = <Self as MachineAir<F>>::num_rows(self, input)
            .expect("num_rows(...) should be Some(_)");

        // Column-major trace: shape [NUM_GLOBAL_COLS, height], zero-initialized so
        // padding rows need no separate pass.
        let mut trace = Tensor::<F, TaskScope>::zeros_in([NUM_GLOBAL_COLS, height], scope.clone());

        // Stage 2: decompress kernel. Launch geometry covers `height` rows with
        // 64-thread blocks (presumably one thread per row — confirm against the kernel).
        unsafe {
            const BLOCK_DIM: usize = 64;
            let grid_dim = height.div_ceil(BLOCK_DIM);
            let tracegen_riscv_global_args =
                args!(trace.as_mut_ptr(), height, events_device.as_ptr(), events.len());
            scope
                .launch_kernel(
                    TaskScope::tracegen_riscv_global_decompress_kernel(),
                    grid_dim,
                    BLOCK_DIM,
                    &tracegen_riscv_global_args,
                    0,
                )
                .unwrap();
        }

        // Layout checks: a septic curve point is two degree-7 extension elements,
        // i.e. 14 contiguous `F` values with no padding. The kernels below rely on
        // this exact byte layout when reinterpreting trace columns as curve points.
        const CURVE_FIELD_EXT_DEGREE: usize = 7;
        assert_eq!(CURVE_FIELD_EXT_DEGREE * mem::size_of::<F>(), mem::size_of::<SepticBlock<F>>());
        const CURVE_POINT_WIDTH: usize = 2 * CURVE_FIELD_EXT_DEGREE;
        assert_eq!(mem::size_of::<[SepticBlock<F>; 2]>(), CURVE_POINT_WIDTH * mem::size_of::<F>());
        assert_eq!(mem::size_of::<SepticCurve<F>>(), CURVE_POINT_WIDTH * mem::size_of::<F>());
        // Capacity-only allocations: the kernels write through raw pointers, so the
        // buffers' logical `len` is never advanced — later reads also go through
        // raw pointers.
        let mut cumulative_sums =
            Buffer::<SepticCurve<F>, _>::with_capacity_in(height, scope.clone());
        let mut accumulation_initial_digest_row_major =
            Buffer::<SepticCurve<F>, _>::with_capacity_in(height, scope.clone());

        // Stage 3: transpose the 14 initial-digest columns (column-major, each of
        // length `height`) into `height` contiguous row-major curve points.
        {
            let accumulation_initial_digest_col_major = &trace.as_buffer()[(height
                * GLOBAL_INITIAL_DIGEST_POS)
                ..(height * (GLOBAL_INITIAL_DIGEST_POS + CURVE_POINT_WIDTH))];
            let src_sizes = [CURVE_POINT_WIDTH, height];
            let src_ptr = accumulation_initial_digest_col_major.as_ptr();
            assert_eq!(
                src_sizes.into_iter().product::<usize>(),
                accumulation_initial_digest_col_major.len()
            );
            let dst_sizes = [height, CURVE_POINT_WIDTH];
            let dst_mut_ptr = accumulation_initial_digest_row_major.as_mut_ptr();
            let num_dims = src_sizes.len();

            // Innermost two dims are transposed; any leading dims (none here) batch in z.
            let dim_x = src_sizes[num_dims - 2];
            let dim_y = src_sizes[num_dims - 1];
            let dim_z: usize = src_sizes.iter().take(num_dims - 2).product();
            assert_eq!(dim_x, dst_sizes[num_dims - 1]);
            assert_eq!(dim_y, dst_sizes[num_dims - 2]);

            // 32x32 thread tiles — presumably matching the transpose kernel's tile
            // size; confirm against the kernel source.
            let block_dim: Dim3 = (32u32, 32u32, 1u32).into();
            let grid_dim: Dim3 = (
                dim_x.div_ceil(block_dim.x as usize),
                dim_y.div_ceil(block_dim.y as usize),
                dim_z.div_ceil(block_dim.z as usize),
            )
                .into();
            let args = args!(src_ptr, dst_mut_ptr, dim_x, dim_y, dim_z);
            unsafe {
                scope
                    .launch_kernel(
                        <TaskScope as DeviceTransposeKernel<F>>::transpose_kernel(),
                        grid_dim,
                        block_dim,
                        &args,
                        0,
                    )
                    .unwrap();
            }
        }

        // Stage 4: prefix-sum the per-row curve points into `cumulative_sums`.
        {
            const SCAN_KERNEL_LARGE_SECTION_SIZE: usize = 512;
            let d_out = cumulative_sums.as_mut_ptr();
            let d_in = accumulation_initial_digest_row_major.as_ptr();
            let n = height;
            // Small inputs fit in a single block (each thread presumably handles two
            // elements, hence the `2 * n` bound — confirm against the kernel).
            if (2 * n) <= SCAN_KERNEL_LARGE_SECTION_SIZE {
                let args = args!(d_out, d_in, n);
                unsafe {
                    scope
                        .launch_kernel(
                            <TaskScope as ScanKernel<F>>::single_block_scan_kernel_large_bb31_septic_curve(
                            ),
                            1,
                            n,
                            &args,
                            0,
                        )
                        .unwrap()
                };
            } else {
                // Multi-block scan. `scan_values`/`block_counter`/`flags` look like a
                // decoupled-lookback-style inter-block coordination scheme —
                // NOTE(review): confirm against the kernel implementation.
                let block_dim = SCAN_KERNEL_LARGE_SECTION_SIZE / 2;
                let num_blocks = n.div_ceil(block_dim);
                let mut scan_values =
                    Buffer::<SepticCurve<F>, _>::with_capacity_in(num_blocks + 1, scope.clone());
                // `write_bytes(byte, len)` appends `len` bytes of `byte` (the
                // debug_assert on `flags.len()` below confirms the append semantics):
                // seed element zeroed here, the kernel fills the rest.
                scan_values.write_bytes(0, mem::size_of::<SepticCurve<F>>()).unwrap();
                let mut block_counter = Buffer::<u32, _>::with_capacity_in(1, scope.clone());
                block_counter.write_bytes(0, mem::size_of::<u32>()).unwrap();
                // flags[0] is set to a nonzero pattern (0x01 bytes); the remaining
                // `num_blocks` flags start at zero.
                let mut flags = Buffer::<u32, _>::with_capacity_in(num_blocks + 1, scope.clone());
                flags.write_bytes(1, size_of::<u32>()).unwrap();
                flags.write_bytes(0, num_blocks * size_of::<u32>()).unwrap();
                debug_assert_eq!(flags.len(), num_blocks + 1);
                let args = args!(
                    d_out,
                    d_in,
                    n,
                    scan_values.as_mut_ptr(),
                    block_counter.as_mut_ptr(),
                    flags.as_mut_ptr()
                );
                unsafe {
                    scope
                        .launch_kernel(
                            <TaskScope as ScanKernel<F>>::scan_kernel_large_bb31_septic_curve(),
                            num_blocks,
                            block_dim,
                            &args,
                            0,
                        )
                        .unwrap()
                };
            }
        }
        // Release the transposed intermediate. NOTE(review): assumes the scope's
        // deallocation is ordered after the scan launch that reads this buffer
        // (stream-ordered free) — confirm.
        drop(accumulation_initial_digest_row_major);

        // Stage 5: write the cumulative sums back into the trace columns.
        unsafe {
            const BLOCK_DIM: usize = 64;
            let grid_dim = height.div_ceil(BLOCK_DIM);
            let tracegen_riscv_global_args =
                args!(trace.as_mut_ptr(), height, cumulative_sums.as_ptr(), events.len());
            scope
                .launch_kernel(
                    TaskScope::tracegen_riscv_global_finalize_kernel(),
                    grid_dim,
                    BLOCK_DIM,
                    &tracegen_riscv_global_args,
                    0,
                )
                .unwrap();
        }

        output.global_interaction_event_count =
            events.len().try_into().expect("number of Global events should fit in a u32");
        // Shared so each blocking copy task below can own a handle to the trace.
        let trace = Arc::new(trace);

        // Stage 6: read the final cumulative sum back to the host.
        let global_sum = if height == 0 {
            // No rows: use the canonical zero digest, converted to canonical-u32 form.
            SepticDigest(SepticCurve::convert(SepticDigest::<F>::zero().0, |x| {
                F::as_canonical_u32(&x)
            }))
        } else {
            // The cumulative-sum columns are the trailing CURVE_POINT_WIDTH columns
            // of the trace (verified by the assert below).
            const CUMULATIVE_SUM_COL_START: usize =
                mem::offset_of!(GlobalCols<u8>, accumulation.cumulative_sum);
            assert_eq!(CUMULATIVE_SUM_COL_START + CURVE_POINT_WIDTH, NUM_GLOBAL_COLS);
            // Each element copy is a blocking device->host read, so run them on the
            // blocking pool and await them all concurrently.
            // NOTE(review): this indexes row `events_len - 1`, not `height - 1` —
            // assumes the cumulative sum at the last *event* row equals the final sum
            // (padding rows presumably contribute nothing) and that `events_len > 0`
            // whenever `height > 0`; confirm both invariants.
            let copied_sum = join_all((CUMULATIVE_SUM_COL_START..NUM_GLOBAL_COLS).map(|i| {
                let trace = Arc::clone(&trace);
                let scope = scope.clone();
                tokio::task::spawn_blocking(move || {
                    trace[[i, events_len - 1]].copy_into_host(&scope)
                })
            }))
            .await;
            // First 7 values are the x-coordinate, next 7 the y-coordinate.
            SepticDigest(SepticCurve {
                x: SepticExtension(core::array::from_fn(|i| {
                    copied_sum[i].as_ref().unwrap().as_canonical_u32()
                })),
                y: SepticExtension(core::array::from_fn(|i| {
                    copied_sum[CURVE_FIELD_EXT_DEGREE + i].as_ref().unwrap().as_canonical_u32()
                })),
            })
        };

        // Publish the digest for consumers of the input record.
        *input.global_cumulative_sum.lock().unwrap() = global_sum;

        // All spawned copy tasks have been awaited, so this Arc is unique again.
        let trace =
            Arc::into_inner(trace).expect("trace Arc should have exactly one strong reference");

        Ok(DeviceMle::from(trace))
    }
}
257
#[cfg(test)]
mod tests {
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use slop_algebra::PrimeField32;
    use slop_tensor::Tensor;
    use sp1_core_executor::{events::GlobalInteractionEvent, ExecutionRecord};
    use sp1_core_machine::global::GlobalChip;
    use sp1_gpu_cudart::TaskScope;
    use sp1_hypercube::air::MachineAir;
    use sp1_hypercube::MachineRecord;

    use crate::{CudaTracegenAir, F};

    /// Entry point: hands the actual test body to the CUDA task-scope runner.
    #[tokio::test]
    async fn test_global_generate_trace() {
        sp1_gpu_cudart::spawn(run_global_tracegen_test).await.unwrap();
    }

    /// Builds a deterministic batch of global interaction events, generates the
    /// GlobalChip trace on both host and device, and checks that the traces,
    /// the published cumulative sums, and the public values all agree.
    async fn run_global_tracegen_test(scope: TaskScope) {
        // Fixed seed keeps the event batch identical across runs.
        let mut prng = StdRng::seed_from_u64(0xDEADBEEF);
        let sample_events: Vec<_> = (0..1000)
            .map(|_| GlobalInteractionEvent {
                message: core::array::from_fn(|_| prng.gen::<F>().as_canonical_u32()),
                is_receive: prng.gen(),
                kind: prng.gen_range(0..(1 << 6)),
            })
            .collect();

        // Two identical records: one driven through the host path, one through
        // the device path.
        let [cpu_record, gpu_record] = core::array::from_fn(|_| ExecutionRecord {
            global_interaction_events: sample_events.clone(),
            ..Default::default()
        });

        let global_chip = GlobalChip;

        // Reference trace generated on the host.
        let host_trace = Tensor::<F>::from(
            global_chip.generate_trace(&cpu_record, &mut ExecutionRecord::default()),
        );

        // Device trace, copied back to the host for comparison.
        let device_trace = global_chip
            .generate_trace_device(&gpu_record, &mut ExecutionRecord::default(), &scope)
            .await
            .expect("should copy events to device successfully")
            .to_host()
            .expect("should copy trace to host successfully")
            .into_guts();

        crate::tests::test_traces_eq(&host_trace, &device_trace, &sample_events);

        // The device path must publish the same cumulative sum as the host path.
        assert_eq!(
            *gpu_record.global_cumulative_sum.lock().unwrap(),
            *cpu_record.global_cumulative_sum.lock().unwrap()
        );

        assert_eq!(gpu_record.public_values::<F>(), cpu_record.public_values::<F>());
    }
}