Skip to main content

sp1_gpu_tracegen/recursion/
alu_ext.rs

1use slop_air::BaseAir;
2use slop_alloc::mem::CopyError;
3use slop_alloc::Buffer;
4use slop_tensor::Tensor;
5use sp1_gpu_cudart::{args, DeviceMle, TaskScope};
6use sp1_gpu_cudart::{TracegenPreprocessedRecursionExtAluKernel, TracegenRecursionExtAluKernel};
7use sp1_hypercube::air::MachineAir;
8use sp1_recursion_executor::Instruction;
9use sp1_recursion_machine::chips::alu_ext::ExtAluChip;
10
11use crate::{CudaTracegenAir, F};
12
13impl CudaTracegenAir<F> for ExtAluChip {
14    fn supports_device_preprocessed_tracegen(&self) -> bool {
15        true
16    }
17
18    async fn generate_preprocessed_trace_device(
19        &self,
20        program: &Self::Program,
21        scope: &TaskScope,
22    ) -> Result<Option<DeviceMle<F>>, CopyError> {
23        let instrs = program
24            .inner
25            .iter() // Faster than using `rayon` for some reason. Maybe vectorization?
26            .filter_map(|instruction| match instruction.inner() {
27                Instruction::ExtAlu(instr) => Some(*instr),
28                _ => None,
29            })
30            .collect::<Vec<_>>();
31
32        let instrs_device = {
33            let mut buf = Buffer::try_with_capacity_in(instrs.len(), scope.clone()).unwrap();
34            buf.extend_from_host_slice(&instrs)?;
35            buf
36        };
37
38        let width = MachineAir::<F>::preprocessed_width(self);
39
40        let height =
41            MachineAir::<F>::preprocessed_num_rows_with_instrs_len(self, program, instrs.len())
42                .expect("preprocessed_num_rows_with_instrs_len(...) should be Some(_)");
43
44        let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
45
46        unsafe {
47            const BLOCK_DIM: usize = 64;
48            let grid_dim = height.div_ceil(BLOCK_DIM);
49            // args:
50            // T *trace,
51            // uintptr_t trace_height,
52            // const sp1_gpu_sys::ExtAluInstr<T> *instructions,
53            // uintptr_t nb_instructions
54            let args = args!(trace.as_mut_ptr(), height, instrs_device.as_ptr(), instrs.len());
55            scope
56                .launch_kernel(
57                    TaskScope::tracegen_preprocessed_recursion_ext_alu_kernel(),
58                    grid_dim,
59                    BLOCK_DIM,
60                    &args,
61                    0,
62                )
63                .unwrap();
64        }
65
66        Ok(Some(DeviceMle::from(trace)))
67    }
68
69    fn supports_device_main_tracegen(&self) -> bool {
70        true
71    }
72
73    async fn generate_trace_device(
74        &self,
75        input: &Self::Record,
76        _: &mut Self::Record,
77        scope: &TaskScope,
78    ) -> Result<DeviceMle<F>, CopyError> {
79        let events = &input.ext_alu_events;
80
81        let events_device = {
82            let mut buf = Buffer::try_with_capacity_in(events.len(), scope.clone()).unwrap();
83            buf.extend_from_host_slice(events)?;
84            buf
85        };
86
87        let width = <Self as BaseAir<F>>::width(self);
88
89        let height = <Self as MachineAir<F>>::num_rows(self, input)
90            .expect("num_rows(...) should be Some(_)");
91
92        let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
93
94        unsafe {
95            const BLOCK_DIM: usize = 64;
96            let grid_dim = height.div_ceil(BLOCK_DIM);
97            // args:
98            // T *trace,
99            // uintptr_t trace_height,
100            // const sp1_gpu_sys::ExtAluEvent<T> *events,
101            // uintptr_t nb_events
102            let args = args!(trace.as_mut_ptr(), height, events_device.as_ptr(), events.len());
103            scope
104                .launch_kernel(
105                    TaskScope::tracegen_recursion_ext_alu_kernel(),
106                    grid_dim,
107                    BLOCK_DIM,
108                    &args,
109                    0,
110                )
111                .unwrap();
112        }
113
114        Ok(DeviceMle::from(trace))
115    }
116}
117
#[cfg(test)]
mod tests {

    use rand::Rng;

    use slop_algebra::extension::BinomialExtensionField;
    use slop_algebra::{AbstractExtensionField, AbstractField, Field};
    use sp1_recursion_executor::{
        Address, AnalyzedInstruction, Block, ExecutionRecord, ExtAluEvent, ExtAluInstr, ExtAluIo,
        ExtAluOpcode, Instruction,
    };
    use sp1_recursion_machine::chips::alu_ext::ExtAluChip;

    use crate::F;

    /// Degree-4 binomial extension of the base field `F`, used to compute event outputs.
    type EF = BinomialExtensionField<F, 4>;

    /// Drives `crate::recursion::tests::test_preprocessed_tracegen` with random ext-ALU
    /// instructions for [`ExtAluChip`].
    #[tokio::test]
    async fn test_ext_alu_generate_preprocessed_trace() {
        sp1_gpu_cudart::spawn(move |scope| {
            crate::recursion::tests::test_preprocessed_tracegen(
                ExtAluChip,
                |rng| {
                    // Random values are drawn in this order: opcode, mult, out, in1, in2,
                    // then the second argument of `AnalyzedInstruction::new`.
                    let opcode = match rng.gen_range(0..4) {
                        0 => ExtAluOpcode::AddE,
                        1 => ExtAluOpcode::SubE,
                        2 => ExtAluOpcode::MulE,
                        _ => ExtAluOpcode::DivE,
                    };
                    let mult = rng.gen();
                    let out = Address(rng.gen());
                    let in1 = Address(rng.gen());
                    let in2 = Address(rng.gen());
                    let instr = ExtAluInstr { opcode, mult, addrs: ExtAluIo { out, in1, in2 } };
                    AnalyzedInstruction::new(Instruction::ExtAlu(instr), rng.gen())
                },
                scope,
            )
        })
        .await
        .unwrap();
    }

    /// Drives `crate::tests::test_main_tracegen` with random ext-ALU events for
    /// [`ExtAluChip`].
    ///
    /// Each event carries two random input blocks and the result of a randomly chosen
    /// add/sub/mul/div on their `EF` interpretations; a zero divisor is replaced by one.
    #[tokio::test]
    async fn test_ext_alu_generate_main_trace() {
        sp1_gpu_cudart::spawn(move |scope| {
            crate::tests::test_main_tracegen(
                ExtAluChip,
                |rng| {
                    let lhs = Block(rng.gen());
                    let rhs = Block(rng.gen());
                    let a: EF = lhs.ext();
                    let b: EF = rhs.ext();
                    let result: EF = match rng.gen_range(0..4) {
                        0 => a + b, // Add
                        1 => a - b, // Sub
                        2 => a * b, // Mul
                        _ => {
                            // Div: guard against a zero divisor by dividing by one instead.
                            let divisor = if b.is_zero() { EF::one() } else { b };
                            a / divisor
                        }
                    };
                    ExtAluEvent { out: Block::from(result.as_base_slice()), in1: lhs, in2: rhs }
                },
                |ext_alu_events| ExecutionRecord { ext_alu_events, ..Default::default() },
                scope,
            )
        })
        .await
        .unwrap();
    }
}