Skip to main content

sp1_gpu_tracegen/recursion/
select.rs

1use slop_air::BaseAir;
2use slop_alloc::mem::CopyError;
3use slop_alloc::Buffer;
4use slop_tensor::Tensor;
5use sp1_gpu_cudart::{args, DeviceMle, TaskScope};
6use sp1_gpu_cudart::{TracegenPreprocessedRecursionSelectKernel, TracegenRecursionSelectKernel};
7use sp1_hypercube::air::MachineAir;
8use sp1_recursion_executor::Instruction;
9use sp1_recursion_machine::chips::select::SelectChip;
10
11use crate::{CudaTracegenAir, F};
12
13impl CudaTracegenAir<F> for SelectChip {
14    fn supports_device_preprocessed_tracegen(&self) -> bool {
15        true
16    }
17
18    async fn generate_preprocessed_trace_device(
19        &self,
20        program: &Self::Program,
21        scope: &TaskScope,
22    ) -> Result<Option<DeviceMle<F>>, CopyError> {
23        let instrs = program
24            .inner
25            .iter() // Faster than using `rayon` for some reason. Maybe vectorization?
26            .filter_map(|instruction| match instruction.inner() {
27                Instruction::Select(instr) => Some(*instr),
28                _ => None,
29            })
30            .collect::<Vec<_>>();
31
32        let instrs_device = {
33            let mut buf = Buffer::try_with_capacity_in(instrs.len(), scope.clone()).unwrap();
34            buf.extend_from_host_slice(&instrs)?;
35            buf
36        };
37
38        let width = MachineAir::<F>::preprocessed_width(self);
39
40        let height =
41            MachineAir::<F>::preprocessed_num_rows_with_instrs_len(self, program, instrs.len())
42                .expect("preprocessed_num_rows_with_instrs_len(...) should be Some(_)");
43
44        let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
45
46        unsafe {
47            const BLOCK_DIM: usize = 64;
48            let grid_dim = height.div_ceil(BLOCK_DIM);
49            // args:
50            // T *trace,
51            // uintptr_t trace_height,
52            // const sp1_gpu_sys::SelectInstr<T> *instructions,
53            // uintptr_t nb_instructions
54            let args = args!(trace.as_mut_ptr(), height, instrs_device.as_ptr(), instrs.len());
55            scope
56                .launch_kernel(
57                    TaskScope::tracegen_preprocessed_recursion_select_kernel(),
58                    grid_dim,
59                    BLOCK_DIM,
60                    &args,
61                    0,
62                )
63                .unwrap();
64        }
65
66        Ok(Some(DeviceMle::from(trace)))
67    }
68
69    fn supports_device_main_tracegen(&self) -> bool {
70        true
71    }
72
73    async fn generate_trace_device(
74        &self,
75        input: &Self::Record,
76        _: &mut Self::Record,
77        scope: &TaskScope,
78    ) -> Result<DeviceMle<F>, CopyError> {
79        let events = &input.select_events;
80
81        let events_device = {
82            let mut buf = Buffer::try_with_capacity_in(events.len(), scope.clone()).unwrap();
83            buf.extend_from_host_slice(events)?;
84            buf
85        };
86
87        let width = <Self as BaseAir<F>>::width(self);
88
89        let height = <Self as MachineAir<F>>::num_rows(self, input)
90            .expect("num_rows(...) should be Some(_)");
91
92        let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
93
94        unsafe {
95            const BLOCK_DIM: usize = 64;
96            let grid_dim = height.div_ceil(BLOCK_DIM);
97            // args:
98            // T *trace,
99            // uintptr_t trace_height,
100            // const sp1_gpu_sys::SelectEvent<T> *events,
101            // uintptr_t nb_events
102            let args = args!(trace.as_mut_ptr(), height, events_device.as_ptr(), events.len());
103            scope
104                .launch_kernel(
105                    TaskScope::tracegen_recursion_select_kernel(),
106                    grid_dim,
107                    BLOCK_DIM,
108                    &args,
109                    0,
110                )
111                .unwrap();
112        }
113
114        Ok(DeviceMle::from(trace))
115    }
116}
117
#[cfg(test)]
mod tests {
    use rand::Rng;

    use slop_algebra::{AbstractField, Field};
    use sp1_recursion_executor::{
        Address, AnalyzedInstruction, ExecutionRecord, Instruction, SelectEvent, SelectInstr,
        SelectIo,
    };
    use sp1_recursion_machine::chips::select::SelectChip;

    use crate::F;

    /// Runs the shared preprocessed-tracegen test helper for `SelectChip`, feeding it
    /// randomly generated `Select` instructions.
    #[tokio::test]
    async fn test_select_generate_preprocessed_trace() {
        sp1_gpu_cudart::spawn(move |scope| {
            crate::recursion::tests::test_preprocessed_tracegen(
                SelectChip,
                |rng| {
                    // Random values are drawn in this order: the five addresses, the two
                    // multiplicities, then the extra argument of `AnalyzedInstruction::new`.
                    let addrs = SelectIo {
                        bit: Address(rng.gen()),
                        out1: Address(rng.gen()),
                        out2: Address(rng.gen()),
                        in1: Address(rng.gen()),
                        in2: Address(rng.gen()),
                    };
                    let mult1 = rng.gen();
                    let mult2 = rng.gen();
                    let select = Instruction::Select(SelectInstr { addrs, mult1, mult2 });
                    AnalyzedInstruction::new(select, rng.gen())
                },
                scope,
            )
        })
        .await
        .unwrap();
    }

    /// Runs the shared main-tracegen test helper for `SelectChip`, feeding it randomly
    /// generated select events wrapped in an otherwise-default `ExecutionRecord`.
    #[tokio::test]
    async fn test_select_generate_main_trace() {
        sp1_gpu_cudart::spawn(move |scope| {
            crate::tests::test_main_tracegen(
                SelectChip,
                |rng| {
                    let bit = F::from_bool(rng.gen());
                    let in1 = rng.gen();
                    let in2 = rng.gen();
                    // The outputs are the inputs, in the same order when `bit` is one and
                    // swapped otherwise.
                    let (out1, out2) = match bit.is_one() {
                        true => (in1, in2),
                        false => (in2, in1),
                    };
                    SelectEvent { bit, out1, out2, in1, in2 }
                },
                |events| ExecutionRecord { select_events: events, ..Default::default() },
                scope,
            )
        })
        .await
        .unwrap();
    }
}