sp1_gpu_tracegen/recursion/
sbox.rs1use slop_air::BaseAir;
2use slop_alloc::mem::CopyError;
3use slop_alloc::Buffer;
4use slop_tensor::Tensor;
5use sp1_gpu_cudart::{args, DeviceMle, TaskScope};
6use sp1_gpu_cudart::{TracegenPreprocessedRecursionSBoxKernel, TracegenRecursionSBoxKernel};
7use sp1_hypercube::air::MachineAir;
8use sp1_recursion_executor::Instruction;
9use sp1_recursion_machine::chips::poseidon2_helper::sbox::Poseidon2SBoxChip;
10
11use crate::{CudaTracegenAir, F};
12
13impl CudaTracegenAir<F> for Poseidon2SBoxChip {
14 fn supports_device_preprocessed_tracegen(&self) -> bool {
15 true
16 }
17
18 async fn generate_preprocessed_trace_device(
19 &self,
20 program: &Self::Program,
21 scope: &TaskScope,
22 ) -> Result<Option<DeviceMle<F>>, CopyError> {
23 let instrs = program
24 .inner
25 .iter()
26 .filter_map(|instruction| match instruction.inner() {
27 Instruction::Poseidon2SBox(instr) => Some(*instr),
28 _ => None,
29 })
30 .collect::<Vec<_>>();
31
32 let instrs_device = {
33 let mut buf = Buffer::try_with_capacity_in(instrs.len(), scope.clone()).unwrap();
34 buf.extend_from_host_slice(&instrs)?;
35 buf
36 };
37
38 let width = MachineAir::<F>::preprocessed_width(self);
39
40 let height =
41 MachineAir::<F>::preprocessed_num_rows_with_instrs_len(self, program, instrs.len())
42 .expect("preprocessed_num_rows_with_instrs_len(...) should be Some(_)");
43
44 let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
45
46 unsafe {
47 const BLOCK_DIM: usize = 64;
48 let grid_dim = height.div_ceil(BLOCK_DIM);
49 let args = args!(trace.as_mut_ptr(), height, instrs_device.as_ptr(), instrs.len());
55 scope
56 .launch_kernel(
57 TaskScope::tracegen_preprocessed_recursion_sbox_kernel(),
58 grid_dim,
59 BLOCK_DIM,
60 &args,
61 0,
62 )
63 .unwrap();
64 }
65
66 Ok(Some(DeviceMle::from(trace)))
67 }
68
69 fn supports_device_main_tracegen(&self) -> bool {
70 true
71 }
72
73 async fn generate_trace_device(
74 &self,
75 input: &Self::Record,
76 _: &mut Self::Record,
77 scope: &TaskScope,
78 ) -> Result<DeviceMle<F>, CopyError> {
79 let events = &input.poseidon2_sbox_events;
80
81 let events_device = {
82 let mut buf = Buffer::try_with_capacity_in(events.len(), scope.clone()).unwrap();
83 buf.extend_from_host_slice(events)?;
84 buf
85 };
86
87 let width = <Self as BaseAir<F>>::width(self);
88
89 let height = <Self as MachineAir<F>>::num_rows(self, input)
90 .expect("num_rows(...) should be Some(_)");
91
92 let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
93
94 unsafe {
95 const BLOCK_DIM: usize = 64;
96 let grid_dim = height.div_ceil(BLOCK_DIM);
97 let args = args!(trace.as_mut_ptr(), height, events_device.as_ptr(), events.len());
103 scope
104 .launch_kernel(
105 TaskScope::tracegen_recursion_sbox_kernel(),
106 grid_dim,
107 BLOCK_DIM,
108 &args,
109 0,
110 )
111 .unwrap();
112 }
113
114 Ok(DeviceMle::from(trace))
115 }
116}
117
#[cfg(test)]
mod tests {
    use rand::Rng;

    use slop_algebra::AbstractField;
    use sp1_recursion_executor::{
        Address, AnalyzedInstruction, Block, ExecutionRecord, Instruction, Poseidon2SBoxInstr,
        Poseidon2SBoxIo,
    };
    use sp1_recursion_machine::chips::poseidon2_helper::sbox::Poseidon2SBoxChip;

    use crate::F;

    /// Runs the shared preprocessed-tracegen harness for the S-box chip on randomly generated
    /// `Poseidon2SBox` instructions with multiplicity one.
    #[tokio::test]
    async fn test_sbox_generate_preprocessed_trace() {
        let result = sp1_gpu_cudart::spawn(|scope| {
            crate::recursion::tests::test_preprocessed_tracegen(
                Poseidon2SBoxChip,
                |rng| {
                    // Random values are drawn in a fixed order: input address, output
                    // address, `external`, then the second argument of
                    // `AnalyzedInstruction::new`.
                    let input = Address(rng.gen());
                    let output = Address(rng.gen());
                    let external = rng.gen();
                    let sbox_instr = Poseidon2SBoxInstr {
                        addrs: Poseidon2SBoxIo { input, output },
                        mults: F::one(),
                        external,
                    };
                    AnalyzedInstruction::new(Instruction::Poseidon2SBox(sbox_instr), rng.gen())
                },
                scope,
            )
        })
        .await;
        result.unwrap();
    }

    /// Runs the shared main-tracegen harness for the S-box chip on random S-box events whose
    /// outputs are computed on the host.
    #[tokio::test]
    async fn test_sbox_generate_main_trace() {
        let result = sp1_gpu_cudart::spawn(move |scope| {
            crate::tests::test_main_tracegen(
                Poseidon2SBoxChip,
                |rng| {
                    let input = Block(rng.gen());
                    // Expected output per limb is x * (x^3)^2, i.e. x^7.
                    let cubed =
                        Block(core::array::from_fn(|i| input.0[i] * input.0[i] * input.0[i]));
                    let output =
                        Block(core::array::from_fn(|i| input.0[i] * cubed.0[i] * cubed.0[i]));
                    Poseidon2SBoxIo { input, output }
                },
                |sbox_events| ExecutionRecord {
                    poseidon2_sbox_events: sbox_events,
                    ..Default::default()
                },
                scope,
            )
        })
        .await;
        result.unwrap();
    }
}