// sp1_gpu_tracegen/recursion/sbox.rs

1use slop_air::BaseAir;
2use slop_alloc::mem::CopyError;
3use slop_alloc::Buffer;
4use slop_tensor::Tensor;
5use sp1_gpu_cudart::{args, DeviceMle, TaskScope};
6use sp1_gpu_cudart::{TracegenPreprocessedRecursionSBoxKernel, TracegenRecursionSBoxKernel};
7use sp1_hypercube::air::MachineAir;
8use sp1_recursion_executor::Instruction;
9use sp1_recursion_machine::chips::poseidon2_helper::sbox::Poseidon2SBoxChip;
10
11use crate::{CudaTracegenAir, F};
12
13impl CudaTracegenAir<F> for Poseidon2SBoxChip {
14    fn supports_device_preprocessed_tracegen(&self) -> bool {
15        true
16    }
17
18    async fn generate_preprocessed_trace_device(
19        &self,
20        program: &Self::Program,
21        scope: &TaskScope,
22    ) -> Result<Option<DeviceMle<F>>, CopyError> {
23        let instrs = program
24            .inner
25            .iter()
26            .filter_map(|instruction| match instruction.inner() {
27                Instruction::Poseidon2SBox(instr) => Some(*instr),
28                _ => None,
29            })
30            .collect::<Vec<_>>();
31
32        let instrs_device = {
33            let mut buf = Buffer::try_with_capacity_in(instrs.len(), scope.clone()).unwrap();
34            buf.extend_from_host_slice(&instrs)?;
35            buf
36        };
37
38        let width = MachineAir::<F>::preprocessed_width(self);
39
40        let height =
41            MachineAir::<F>::preprocessed_num_rows_with_instrs_len(self, program, instrs.len())
42                .expect("preprocessed_num_rows_with_instrs_len(...) should be Some(_)");
43
44        let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
45
46        unsafe {
47            const BLOCK_DIM: usize = 64;
48            let grid_dim = height.div_ceil(BLOCK_DIM);
49            // args:
50            // T *trace,
51            // uintptr_t trace_height,
52            // const sp1_gpu_sys::Poseidon2SBoxInstr<T> *instructions,
53            // uintptr_t nb_instructions
54            let args = args!(trace.as_mut_ptr(), height, instrs_device.as_ptr(), instrs.len());
55            scope
56                .launch_kernel(
57                    TaskScope::tracegen_preprocessed_recursion_sbox_kernel(),
58                    grid_dim,
59                    BLOCK_DIM,
60                    &args,
61                    0,
62                )
63                .unwrap();
64        }
65
66        Ok(Some(DeviceMle::from(trace)))
67    }
68
69    fn supports_device_main_tracegen(&self) -> bool {
70        true
71    }
72
73    async fn generate_trace_device(
74        &self,
75        input: &Self::Record,
76        _: &mut Self::Record,
77        scope: &TaskScope,
78    ) -> Result<DeviceMle<F>, CopyError> {
79        let events = &input.poseidon2_sbox_events;
80
81        let events_device = {
82            let mut buf = Buffer::try_with_capacity_in(events.len(), scope.clone()).unwrap();
83            buf.extend_from_host_slice(events)?;
84            buf
85        };
86
87        let width = <Self as BaseAir<F>>::width(self);
88
89        let height = <Self as MachineAir<F>>::num_rows(self, input)
90            .expect("num_rows(...) should be Some(_)");
91
92        let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
93
94        unsafe {
95            const BLOCK_DIM: usize = 64;
96            let grid_dim = height.div_ceil(BLOCK_DIM);
97            // args:
98            // T *trace,
99            // uintptr_t trace_height,
100            // const sp1_gpu_sys::Poseidon2SBoxIo<T> *events,
101            // uintptr_t nb_events
102            let args = args!(trace.as_mut_ptr(), height, events_device.as_ptr(), events.len());
103            scope
104                .launch_kernel(
105                    TaskScope::tracegen_recursion_sbox_kernel(),
106                    grid_dim,
107                    BLOCK_DIM,
108                    &args,
109                    0,
110                )
111                .unwrap();
112        }
113
114        Ok(DeviceMle::from(trace))
115    }
116}
117
#[cfg(test)]
mod tests {
    use rand::Rng;

    use slop_algebra::AbstractField;
    use sp1_recursion_executor::{
        Address, AnalyzedInstruction, Block, ExecutionRecord, Instruction, Poseidon2SBoxInstr,
        Poseidon2SBoxIo,
    };
    use sp1_recursion_machine::chips::poseidon2_helper::sbox::Poseidon2SBoxChip;

    use crate::F;

    /// Runs the shared preprocessed-tracegen test helper on randomly generated S-box
    /// instructions.
    #[tokio::test]
    async fn test_sbox_generate_preprocessed_trace() {
        let task = sp1_gpu_cudart::spawn(|scope| {
            crate::recursion::tests::test_preprocessed_tracegen(
                Poseidon2SBoxChip,
                |rng| {
                    // Random values are drawn in a fixed order: input address, output address,
                    // the `external` flag, then the final `AnalyzedInstruction` argument.
                    let input = Address(rng.gen());
                    let output = Address(rng.gen());
                    let external = rng.gen();
                    let sbox_instr = Poseidon2SBoxInstr {
                        addrs: Poseidon2SBoxIo { input, output },
                        mults: F::one(),
                        external,
                    };
                    AnalyzedInstruction::new(Instruction::Poseidon2SBox(sbox_instr), rng.gen())
                },
                scope,
            )
        });
        task.await.unwrap();
    }

    /// Runs the shared main-tracegen test helper on random S-box events whose outputs are the
    /// element-wise seventh power of their inputs.
    #[tokio::test]
    async fn test_sbox_generate_main_trace() {
        let task = sp1_gpu_cudart::spawn(move |scope| {
            crate::tests::test_main_tracegen(
                Poseidon2SBoxChip,
                |rng| {
                    let input = Block(rng.gen());
                    // S-box: x^7, computed per element as x * x^3 * x^3.
                    let output = Block(core::array::from_fn(|i| {
                        let x = input.0[i];
                        let x_cubed = x * x * x;
                        x * x_cubed * x_cubed
                    }));
                    Poseidon2SBoxIo { input, output }
                },
                |events| ExecutionRecord { poseidon2_sbox_events: events, ..Default::default() },
                scope,
            )
        });
        task.await.unwrap();
    }
}