sp1_gpu_tracegen/recursion/
alu_base.rs1use slop_air::BaseAir;
2use slop_alloc::mem::CopyError;
3use slop_alloc::Buffer;
4use slop_tensor::Tensor;
5use sp1_gpu_cudart::{args, DeviceMle, TaskScope};
6use sp1_gpu_cudart::{TracegenPreprocessedRecursionBaseAluKernel, TracegenRecursionBaseAluKernel};
7use sp1_hypercube::air::MachineAir;
8use sp1_recursion_executor::Instruction;
9use sp1_recursion_machine::chips::alu_base::BaseAluChip;
10
11use crate::{CudaTracegenAir, F};
12
13impl CudaTracegenAir<F> for BaseAluChip {
14 fn supports_device_preprocessed_tracegen(&self) -> bool {
15 true
16 }
17
18 async fn generate_preprocessed_trace_device(
19 &self,
20 program: &Self::Program,
21 scope: &TaskScope,
22 ) -> Result<Option<DeviceMle<F>>, CopyError> {
23 let instrs = program
24 .inner
25 .iter() .filter_map(|instruction| match instruction.inner() {
27 Instruction::BaseAlu(instr) => Some(*instr),
28 _ => None,
29 })
30 .collect::<Vec<_>>();
31
32 let instrs_device = {
33 let mut buf = Buffer::try_with_capacity_in(instrs.len(), scope.clone()).unwrap();
34 buf.extend_from_host_slice(&instrs)?;
35 buf
36 };
37
38 let width = MachineAir::<F>::preprocessed_width(self);
39
40 let height =
41 MachineAir::<F>::preprocessed_num_rows_with_instrs_len(self, program, instrs.len())
42 .expect("preprocessed_num_rows_with_instrs_len(...) should be Some(_)");
43
44 let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
45
46 unsafe {
47 const BLOCK_DIM: usize = 64;
48 let grid_dim = height.div_ceil(BLOCK_DIM);
49 let args = args!(trace.as_mut_ptr(), height, instrs_device.as_ptr(), instrs.len());
55 scope
56 .launch_kernel(
57 TaskScope::tracegen_preprocessed_recursion_base_alu_kernel(),
58 grid_dim,
59 BLOCK_DIM,
60 &args,
61 0,
62 )
63 .unwrap();
64 }
65
66 Ok(Some(DeviceMle::from(trace)))
67 }
68
69 fn supports_device_main_tracegen(&self) -> bool {
70 true
71 }
72
73 async fn generate_trace_device(
74 &self,
75 input: &Self::Record,
76 _: &mut Self::Record,
77 scope: &TaskScope,
78 ) -> Result<DeviceMle<F>, CopyError> {
79 let events = &input.base_alu_events;
80
81 let events_device = {
82 let mut buf = Buffer::try_with_capacity_in(events.len(), scope.clone()).unwrap();
83 buf.extend_from_host_slice(events)?;
84 buf
85 };
86
87 let width = <Self as BaseAir<F>>::width(self);
88
89 let height = <Self as MachineAir<F>>::num_rows(self, input)
90 .expect("num_rows(...) should be Some(_)");
91
92 let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
93
94 unsafe {
95 const BLOCK_DIM: usize = 64;
96 let grid_dim = height.div_ceil(BLOCK_DIM);
97 let args = args!(trace.as_mut_ptr(), height, events_device.as_ptr(), events.len());
103 scope
104 .launch_kernel(
105 TaskScope::tracegen_recursion_base_alu_kernel(),
106 grid_dim,
107 BLOCK_DIM,
108 &args,
109 0,
110 )
111 .unwrap();
112 }
113
114 Ok(DeviceMle::from(trace))
115 }
116}
117
#[cfg(test)]
mod tests {
    use rand::Rng;

    use slop_algebra::{AbstractField, Field};
    use sp1_recursion_executor::{
        Address, AnalyzedInstruction, BaseAluEvent, BaseAluInstr, BaseAluIo, BaseAluOpcode,
        ExecutionRecord, Instruction,
    };
    use sp1_recursion_machine::chips::alu_base::BaseAluChip;

    use crate::F;

    /// Runs the shared `test_preprocessed_tracegen` helper for `BaseAluChip`, feeding
    /// it randomly generated `BaseAlu` instructions.
    // NOTE(review): the helper presumably compares device output with the host
    // implementation — confirm in `crate::recursion::tests`.
    #[tokio::test]
    async fn test_base_alu_generate_preprocessed_trace() {
        let task = sp1_gpu_cudart::spawn(|scope| {
            crate::recursion::tests::test_preprocessed_tracegen(
                BaseAluChip,
                |rng| {
                    // Random draws happen in this order: opcode, mult, out, in1, in2,
                    // and finally the second argument of `AnalyzedInstruction::new`.
                    let opcode = match rng.gen_range(0..4) {
                        0 => BaseAluOpcode::AddF,
                        1 => BaseAluOpcode::SubF,
                        2 => BaseAluOpcode::MulF,
                        _ => BaseAluOpcode::DivF,
                    };
                    let mult = rng.gen();
                    let addrs = BaseAluIo {
                        out: Address(rng.gen()),
                        in1: Address(rng.gen()),
                        in2: Address(rng.gen()),
                    };
                    let instr = Instruction::BaseAlu(BaseAluInstr { opcode, mult, addrs });
                    AnalyzedInstruction::new(instr, rng.gen())
                },
                scope,
            )
        });
        task.await.unwrap();
    }

    /// Runs the shared `test_main_tracegen` helper for `BaseAluChip`, feeding it
    /// randomly generated `BaseAluEvent`s wrapped in an `ExecutionRecord`.
    #[tokio::test]
    async fn test_base_alu_generate_main_trace() {
        let task = sp1_gpu_cudart::spawn(move |scope| {
            crate::tests::test_main_tracegen(
                BaseAluChip,
                |rng| {
                    let in1: F = rng.gen();
                    let in2: F = rng.gen();
                    let selector = rng.gen_range(0..4);
                    let out = if selector == 0 {
                        in1 + in2
                    } else if selector == 1 {
                        in1 - in2
                    } else if selector == 2 {
                        in1 * in2
                    } else {
                        // Avoid dividing by zero; the event below still records the
                        // originally drawn `in2`.
                        let divisor = if in2.is_zero() { F::one() } else { in2 };
                        in1 / divisor
                    };
                    BaseAluEvent { out, in1, in2 }
                },
                |events| ExecutionRecord { base_alu_events: events, ..Default::default() },
                scope,
            )
        });
        task.await.unwrap();
    }
}