Skip to main content

sp1_gpu_tracegen/recursion/
poseidon2_wide.rs

1use slop_air::BaseAir;
2use slop_alloc::mem::CopyError;
3use slop_alloc::Buffer;
4use slop_tensor::Tensor;
5use sp1_gpu_cudart::{args, DeviceMle, TaskScope};
6use sp1_gpu_cudart::{
7    TracegenPreprocessedRecursionPoseidon2WideKernel, TracegenRecursionPoseidon2WideKernel,
8};
9use sp1_hypercube::air::MachineAir;
10use sp1_recursion_executor::Instruction;
11use sp1_recursion_machine::chips::poseidon2_wide::Poseidon2WideChip;
12
13use crate::{CudaTracegenAir, F};
14
15impl<const DEGREE: usize> CudaTracegenAir<F> for Poseidon2WideChip<DEGREE> {
16    fn supports_device_preprocessed_tracegen(&self) -> bool {
17        true
18    }
19
20    async fn generate_preprocessed_trace_device(
21        &self,
22        program: &Self::Program,
23        scope: &TaskScope,
24    ) -> Result<Option<DeviceMle<F>>, CopyError> {
25        let instrs = program
26            .inner
27            .iter() // Faster than using `rayon` for some reason. Maybe vectorization?
28            .filter_map(|instruction| match instruction.inner() {
29                Instruction::Poseidon2(instr) => Some(**instr),
30                _ => None,
31            })
32            .collect::<Vec<_>>();
33
34        let instrs_device = {
35            let mut buf = Buffer::try_with_capacity_in(instrs.len(), scope.clone()).unwrap();
36            buf.extend_from_host_slice(&instrs)?;
37            buf
38        };
39
40        let width = MachineAir::<F>::preprocessed_width(self);
41
42        let height =
43            MachineAir::<F>::preprocessed_num_rows_with_instrs_len(self, program, instrs.len())
44                .expect("preprocessed_num_rows_with_instrs_len(...) should be Some(_)");
45
46        let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
47
48        unsafe {
49            const BLOCK_DIM: usize = 64;
50            let grid_dim = height.div_ceil(BLOCK_DIM);
51            // args:
52            // T *trace,
53            // uintptr_t trace_height,
54            // const sp1_gpu_sys::Poseidon2Instr<T> *instructions,
55            // uintptr_t nb_instructions
56            let args = args!(trace.as_mut_ptr(), height, instrs_device.as_ptr(), instrs.len());
57            scope
58                .launch_kernel(
59                    TaskScope::tracegen_preprocessed_recursion_poseidon2_wide_kernel(),
60                    grid_dim,
61                    BLOCK_DIM,
62                    &args,
63                    0,
64                )
65                .unwrap();
66        }
67
68        Ok(Some(DeviceMle::from(trace)))
69    }
70
71    fn supports_device_main_tracegen(&self) -> bool {
72        true
73    }
74
75    async fn generate_trace_device(
76        &self,
77        input: &Self::Record,
78        _: &mut Self::Record,
79        scope: &TaskScope,
80    ) -> Result<DeviceMle<F>, CopyError> {
81        debug_assert!(DEGREE == 3 || DEGREE == 9);
82        let sbox_state = DEGREE == 3;
83
84        let events = &input.poseidon2_events;
85
86        let events_device = {
87            let mut buf = Buffer::try_with_capacity_in(events.len(), scope.clone()).unwrap();
88            buf.extend_from_host_slice(events)?;
89            buf
90        };
91
92        let width = <Self as BaseAir<F>>::width(self);
93
94        let height = <Self as MachineAir<F>>::num_rows(self, input)
95            .expect("num_rows(...) should be Some(_)");
96
97        let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
98
99        unsafe {
100            const BLOCK_DIM: usize = 64;
101            let grid_dim = height.div_ceil(BLOCK_DIM);
102            // args:
103            // kb31_t *trace,
104            // uintptr_t trace_height,
105            // const sp1_gpu_sys::Poseidon2Event<kb31_t> *events,
106            // uintptr_t nb_events,
107            // bool sbox_state
108            let args =
109                args!(trace.as_mut_ptr(), height, events_device.as_ptr(), events.len(), sbox_state);
110            scope
111                .launch_kernel(
112                    TaskScope::tracegen_recursion_poseidon2_wide_kernel(),
113                    grid_dim,
114                    BLOCK_DIM,
115                    &args,
116                    0,
117                )
118                .unwrap();
119        }
120
121        Ok(DeviceMle::from(trace))
122    }
123}
124
#[cfg(test)]
mod tests {
    use rand::{rngs::StdRng, Rng};

    use slop_symmetric::Permutation;

    use sp1_recursion_executor::{
        Address, AnalyzedInstruction, ExecutionRecord, Instruction, Poseidon2Event, Poseidon2Instr,
        Poseidon2Io, PERMUTATION_WIDTH,
    };
    use sp1_recursion_machine::chips::poseidon2_wide::Poseidon2WideChip;

    use crate::F;

    /// Builds a Poseidon2 instruction whose fields are all drawn from `rng`.
    ///
    /// Values are drawn in a fixed order: input addresses, output addresses, multiplicities,
    /// and finally the second argument of `AnalyzedInstruction::new`.
    fn make_poseidon2_instr(rng: &mut StdRng) -> AnalyzedInstruction<F> {
        let input = rng.gen::<[F; PERMUTATION_WIDTH]>().map(Address);
        let output = rng.gen::<[F; PERMUTATION_WIDTH]>().map(Address);
        let mults = rng.gen();

        let instr = Poseidon2Instr { addrs: Poseidon2Io { input, output }, mults };

        AnalyzedInstruction::new(Instruction::Poseidon2(Box::new(instr)), rng.gen())
    }

    /// Runs the shared preprocessed-tracegen test helper for the degree-3 chip.
    #[tokio::test]
    async fn test_poseidon2_wide_deg_3_generate_preprocessed_trace() {
        let task = sp1_gpu_cudart::spawn(move |scope| {
            crate::recursion::tests::test_preprocessed_tracegen(
                Poseidon2WideChip::<3>,
                make_poseidon2_instr,
                scope,
            )
        });

        task.await.unwrap();
    }

    /// Runs the shared main-tracegen test helper for the degree-3 chip, feeding it events
    /// whose output is the inner permutation applied to a random input.
    #[tokio::test]
    async fn test_poseidon2_wide_deg_3_generate_main_trace() {
        let task = sp1_gpu_cudart::spawn(move |scope| {
            crate::tests::test_main_tracegen(
                Poseidon2WideChip::<3>,
                |rng| {
                    let state = rng.gen();
                    let perm = sp1_hypercube::inner_perm();
                    let permuted = perm.permute(state);

                    Poseidon2Event { input: state, output: permuted }
                },
                |events| ExecutionRecord { poseidon2_events: events, ..Default::default() },
                scope,
            )
        });

        task.await.unwrap();
    }
}