// sp1_gpu_tracegen/recursion/poseidon2_wide.rs
use slop_air::BaseAir;
2use slop_alloc::mem::CopyError;
3use slop_alloc::Buffer;
4use slop_tensor::Tensor;
5use sp1_gpu_cudart::{args, DeviceMle, TaskScope};
6use sp1_gpu_cudart::{
7 TracegenPreprocessedRecursionPoseidon2WideKernel, TracegenRecursionPoseidon2WideKernel,
8};
9use sp1_hypercube::air::MachineAir;
10use sp1_recursion_executor::Instruction;
11use sp1_recursion_machine::chips::poseidon2_wide::Poseidon2WideChip;
12
13use crate::{CudaTracegenAir, F};
14
15impl<const DEGREE: usize> CudaTracegenAir<F> for Poseidon2WideChip<DEGREE> {
16 fn supports_device_preprocessed_tracegen(&self) -> bool {
17 true
18 }
19
20 async fn generate_preprocessed_trace_device(
21 &self,
22 program: &Self::Program,
23 scope: &TaskScope,
24 ) -> Result<Option<DeviceMle<F>>, CopyError> {
25 let instrs = program
26 .inner
27 .iter() .filter_map(|instruction| match instruction.inner() {
29 Instruction::Poseidon2(instr) => Some(**instr),
30 _ => None,
31 })
32 .collect::<Vec<_>>();
33
34 let instrs_device = {
35 let mut buf = Buffer::try_with_capacity_in(instrs.len(), scope.clone()).unwrap();
36 buf.extend_from_host_slice(&instrs)?;
37 buf
38 };
39
40 let width = MachineAir::<F>::preprocessed_width(self);
41
42 let height =
43 MachineAir::<F>::preprocessed_num_rows_with_instrs_len(self, program, instrs.len())
44 .expect("preprocessed_num_rows_with_instrs_len(...) should be Some(_)");
45
46 let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
47
48 unsafe {
49 const BLOCK_DIM: usize = 64;
50 let grid_dim = height.div_ceil(BLOCK_DIM);
51 let args = args!(trace.as_mut_ptr(), height, instrs_device.as_ptr(), instrs.len());
57 scope
58 .launch_kernel(
59 TaskScope::tracegen_preprocessed_recursion_poseidon2_wide_kernel(),
60 grid_dim,
61 BLOCK_DIM,
62 &args,
63 0,
64 )
65 .unwrap();
66 }
67
68 Ok(Some(DeviceMle::from(trace)))
69 }
70
71 fn supports_device_main_tracegen(&self) -> bool {
72 true
73 }
74
75 async fn generate_trace_device(
76 &self,
77 input: &Self::Record,
78 _: &mut Self::Record,
79 scope: &TaskScope,
80 ) -> Result<DeviceMle<F>, CopyError> {
81 debug_assert!(DEGREE == 3 || DEGREE == 9);
82 let sbox_state = DEGREE == 3;
83
84 let events = &input.poseidon2_events;
85
86 let events_device = {
87 let mut buf = Buffer::try_with_capacity_in(events.len(), scope.clone()).unwrap();
88 buf.extend_from_host_slice(events)?;
89 buf
90 };
91
92 let width = <Self as BaseAir<F>>::width(self);
93
94 let height = <Self as MachineAir<F>>::num_rows(self, input)
95 .expect("num_rows(...) should be Some(_)");
96
97 let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
98
99 unsafe {
100 const BLOCK_DIM: usize = 64;
101 let grid_dim = height.div_ceil(BLOCK_DIM);
102 let args =
109 args!(trace.as_mut_ptr(), height, events_device.as_ptr(), events.len(), sbox_state);
110 scope
111 .launch_kernel(
112 TaskScope::tracegen_recursion_poseidon2_wide_kernel(),
113 grid_dim,
114 BLOCK_DIM,
115 &args,
116 0,
117 )
118 .unwrap();
119 }
120
121 Ok(DeviceMle::from(trace))
122 }
123}
124
#[cfg(test)]
mod tests {
    use rand::{rngs::StdRng, Rng};
    use slop_symmetric::Permutation;
    use sp1_recursion_executor::{
        Address, AnalyzedInstruction, ExecutionRecord, Instruction, Poseidon2Event, Poseidon2Instr,
        Poseidon2Io, PERMUTATION_WIDTH,
    };
    use sp1_recursion_machine::chips::poseidon2_wide::Poseidon2WideChip;

    use crate::F;

    /// Builds a Poseidon2 instruction whose addresses, multiplicities and trailing constructor
    /// argument are all drawn from `rng`.
    fn make_poseidon2_instr(rng: &mut StdRng) -> AnalyzedInstruction<F> {
        // Values are drawn in a fixed order: input addresses, output addresses, multiplicities,
        // and finally the second argument of `AnalyzedInstruction::new`.
        let input = rng.gen::<[F; PERMUTATION_WIDTH]>().map(Address);
        let output = rng.gen::<[F; PERMUTATION_WIDTH]>().map(Address);
        let mults = rng.gen();
        let instr = Poseidon2Instr { addrs: Poseidon2Io { input, output }, mults };

        AnalyzedInstruction::new(Instruction::Poseidon2(Box::new(instr)), rng.gen())
    }

    /// Runs the shared preprocessed-tracegen test helper for the degree-3 chip.
    #[tokio::test]
    async fn test_poseidon2_wide_deg_3_generate_preprocessed_trace() {
        let task = sp1_gpu_cudart::spawn(move |scope| {
            crate::recursion::tests::test_preprocessed_tracegen(
                Poseidon2WideChip::<3>,
                make_poseidon2_instr,
                scope,
            )
        });

        task.await.unwrap();
    }

    /// Runs the shared main-tracegen test helper for the degree-3 chip, using events whose
    /// output is the permutation of a random input.
    #[tokio::test]
    async fn test_poseidon2_wide_deg_3_generate_main_trace() {
        let task = sp1_gpu_cudart::spawn(move |scope| {
            crate::tests::test_main_tracegen(
                Poseidon2WideChip::<3>,
                |rng| {
                    let input = rng.gen();
                    let output = sp1_hypercube::inner_perm().permute(input);

                    Poseidon2Event { input, output }
                },
                |events| ExecutionRecord { poseidon2_events: events, ..Default::default() },
                scope,
            )
        });

        task.await.unwrap();
    }
}