Skip to main content

sp1_gpu_tracegen/recursion/
alu_base.rs

1use slop_air::BaseAir;
2use slop_alloc::mem::CopyError;
3use slop_alloc::Buffer;
4use slop_tensor::Tensor;
5use sp1_gpu_cudart::{args, DeviceMle, TaskScope};
6use sp1_gpu_cudart::{TracegenPreprocessedRecursionBaseAluKernel, TracegenRecursionBaseAluKernel};
7use sp1_hypercube::air::MachineAir;
8use sp1_recursion_executor::Instruction;
9use sp1_recursion_machine::chips::alu_base::BaseAluChip;
10
11use crate::{CudaTracegenAir, F};
12
13impl CudaTracegenAir<F> for BaseAluChip {
14    fn supports_device_preprocessed_tracegen(&self) -> bool {
15        true
16    }
17
18    async fn generate_preprocessed_trace_device(
19        &self,
20        program: &Self::Program,
21        scope: &TaskScope,
22    ) -> Result<Option<DeviceMle<F>>, CopyError> {
23        let instrs = program
24            .inner
25            .iter() // Faster than using `rayon` for some reason. Maybe vectorization?
26            .filter_map(|instruction| match instruction.inner() {
27                Instruction::BaseAlu(instr) => Some(*instr),
28                _ => None,
29            })
30            .collect::<Vec<_>>();
31
32        let instrs_device = {
33            let mut buf = Buffer::try_with_capacity_in(instrs.len(), scope.clone()).unwrap();
34            buf.extend_from_host_slice(&instrs)?;
35            buf
36        };
37
38        let width = MachineAir::<F>::preprocessed_width(self);
39
40        let height =
41            MachineAir::<F>::preprocessed_num_rows_with_instrs_len(self, program, instrs.len())
42                .expect("preprocessed_num_rows_with_instrs_len(...) should be Some(_)");
43
44        let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
45
46        unsafe {
47            const BLOCK_DIM: usize = 64;
48            let grid_dim = height.div_ceil(BLOCK_DIM);
49            // args:
50            // T *trace,
51            // uintptr_t trace_height,
52            // const sp1_gpu_sys::BaseAluInstr<T> *instructions,
53            // uintptr_t nb_instructions
54            let args = args!(trace.as_mut_ptr(), height, instrs_device.as_ptr(), instrs.len());
55            scope
56                .launch_kernel(
57                    TaskScope::tracegen_preprocessed_recursion_base_alu_kernel(),
58                    grid_dim,
59                    BLOCK_DIM,
60                    &args,
61                    0,
62                )
63                .unwrap();
64        }
65
66        Ok(Some(DeviceMle::from(trace)))
67    }
68
69    fn supports_device_main_tracegen(&self) -> bool {
70        true
71    }
72
73    async fn generate_trace_device(
74        &self,
75        input: &Self::Record,
76        _: &mut Self::Record,
77        scope: &TaskScope,
78    ) -> Result<DeviceMle<F>, CopyError> {
79        let events = &input.base_alu_events;
80
81        let events_device = {
82            let mut buf = Buffer::try_with_capacity_in(events.len(), scope.clone()).unwrap();
83            buf.extend_from_host_slice(events)?;
84            buf
85        };
86
87        let width = <Self as BaseAir<F>>::width(self);
88
89        let height = <Self as MachineAir<F>>::num_rows(self, input)
90            .expect("num_rows(...) should be Some(_)");
91
92        let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
93
94        unsafe {
95            const BLOCK_DIM: usize = 64;
96            let grid_dim = height.div_ceil(BLOCK_DIM);
97            // args:
98            // T *trace,
99            // uintptr_t trace_height,
100            // const sp1_gpu_sys::BaseAluEvent<T> *events,
101            // uintptr_t nb_events
102            let args = args!(trace.as_mut_ptr(), height, events_device.as_ptr(), events.len());
103            scope
104                .launch_kernel(
105                    TaskScope::tracegen_recursion_base_alu_kernel(),
106                    grid_dim,
107                    BLOCK_DIM,
108                    &args,
109                    0,
110                )
111                .unwrap();
112        }
113
114        Ok(DeviceMle::from(trace))
115    }
116}
117
#[cfg(test)]
mod tests {

    use rand::Rng;

    use slop_algebra::{AbstractField, Field};
    use sp1_recursion_executor::{
        Address, AnalyzedInstruction, BaseAluEvent, BaseAluInstr, BaseAluIo, BaseAluOpcode,
        ExecutionRecord, Instruction,
    };
    use sp1_recursion_machine::chips::alu_base::BaseAluChip;

    use crate::F;

    /// Runs the shared preprocessed-tracegen test helper for `BaseAluChip`, feeding it randomly
    /// generated base-ALU instructions.
    #[tokio::test]
    async fn test_base_alu_generate_preprocessed_trace() {
        let task = sp1_gpu_cudart::spawn(|scope| {
            crate::recursion::tests::test_preprocessed_tracegen(
                BaseAluChip,
                |rng| {
                    // Draw order matters for reproducibility: opcode, multiplicity, the three
                    // addresses, and finally the second argument of `AnalyzedInstruction::new`.
                    let selector = rng.gen_range(0..4);
                    let opcode = match selector {
                        0 => BaseAluOpcode::AddF,
                        1 => BaseAluOpcode::SubF,
                        2 => BaseAluOpcode::MulF,
                        _ => BaseAluOpcode::DivF,
                    };
                    let mult = rng.gen();
                    let out = Address(rng.gen());
                    let in1 = Address(rng.gen());
                    let in2 = Address(rng.gen());

                    let instr = BaseAluInstr { opcode, mult, addrs: BaseAluIo { out, in1, in2 } };
                    AnalyzedInstruction::new(Instruction::BaseAlu(instr), rng.gen())
                },
                scope,
            )
        });

        task.await.unwrap();
    }

    /// Runs the shared main-tracegen test helper for `BaseAluChip`, feeding it randomly
    /// generated base-ALU events.
    #[tokio::test]
    async fn test_base_alu_generate_main_trace() {
        let task = sp1_gpu_cudart::spawn(move |scope| {
            crate::tests::test_main_tracegen(
                BaseAluChip,
                |rng| {
                    let in1: F = rng.gen();
                    let in2: F = rng.gen();
                    let selector = rng.gen_range(0..4);
                    let out = if selector == 0 {
                        // Add
                        in1 + in2
                    } else if selector == 1 {
                        // Sub
                        in1 - in2
                    } else if selector == 2 {
                        // Mul
                        in1 * in2
                    } else {
                        // Div: avoid dividing by zero. The substitute divisor is only used to
                        // compute `out`; the event below still records the original `in2`.
                        let divisor = if in2.is_zero() { F::one() } else { in2 };
                        in1 / divisor
                    };
                    BaseAluEvent { out, in1, in2 }
                },
                |base_alu_events| ExecutionRecord { base_alu_events, ..Default::default() },
                scope,
            )
        });

        task.await.unwrap();
    }
}