sp1_gpu_tracegen/recursion/
select.rs1use slop_air::BaseAir;
2use slop_alloc::mem::CopyError;
3use slop_alloc::Buffer;
4use slop_tensor::Tensor;
5use sp1_gpu_cudart::{args, DeviceMle, TaskScope};
6use sp1_gpu_cudart::{TracegenPreprocessedRecursionSelectKernel, TracegenRecursionSelectKernel};
7use sp1_hypercube::air::MachineAir;
8use sp1_recursion_executor::Instruction;
9use sp1_recursion_machine::chips::select::SelectChip;
10
11use crate::{CudaTracegenAir, F};
12
13impl CudaTracegenAir<F> for SelectChip {
14 fn supports_device_preprocessed_tracegen(&self) -> bool {
15 true
16 }
17
18 async fn generate_preprocessed_trace_device(
19 &self,
20 program: &Self::Program,
21 scope: &TaskScope,
22 ) -> Result<Option<DeviceMle<F>>, CopyError> {
23 let instrs = program
24 .inner
25 .iter() .filter_map(|instruction| match instruction.inner() {
27 Instruction::Select(instr) => Some(*instr),
28 _ => None,
29 })
30 .collect::<Vec<_>>();
31
32 let instrs_device = {
33 let mut buf = Buffer::try_with_capacity_in(instrs.len(), scope.clone()).unwrap();
34 buf.extend_from_host_slice(&instrs)?;
35 buf
36 };
37
38 let width = MachineAir::<F>::preprocessed_width(self);
39
40 let height =
41 MachineAir::<F>::preprocessed_num_rows_with_instrs_len(self, program, instrs.len())
42 .expect("preprocessed_num_rows_with_instrs_len(...) should be Some(_)");
43
44 let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
45
46 unsafe {
47 const BLOCK_DIM: usize = 64;
48 let grid_dim = height.div_ceil(BLOCK_DIM);
49 let args = args!(trace.as_mut_ptr(), height, instrs_device.as_ptr(), instrs.len());
55 scope
56 .launch_kernel(
57 TaskScope::tracegen_preprocessed_recursion_select_kernel(),
58 grid_dim,
59 BLOCK_DIM,
60 &args,
61 0,
62 )
63 .unwrap();
64 }
65
66 Ok(Some(DeviceMle::from(trace)))
67 }
68
69 fn supports_device_main_tracegen(&self) -> bool {
70 true
71 }
72
73 async fn generate_trace_device(
74 &self,
75 input: &Self::Record,
76 _: &mut Self::Record,
77 scope: &TaskScope,
78 ) -> Result<DeviceMle<F>, CopyError> {
79 let events = &input.select_events;
80
81 let events_device = {
82 let mut buf = Buffer::try_with_capacity_in(events.len(), scope.clone()).unwrap();
83 buf.extend_from_host_slice(events)?;
84 buf
85 };
86
87 let width = <Self as BaseAir<F>>::width(self);
88
89 let height = <Self as MachineAir<F>>::num_rows(self, input)
90 .expect("num_rows(...) should be Some(_)");
91
92 let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
93
94 unsafe {
95 const BLOCK_DIM: usize = 64;
96 let grid_dim = height.div_ceil(BLOCK_DIM);
97 let args = args!(trace.as_mut_ptr(), height, events_device.as_ptr(), events.len());
103 scope
104 .launch_kernel(
105 TaskScope::tracegen_recursion_select_kernel(),
106 grid_dim,
107 BLOCK_DIM,
108 &args,
109 0,
110 )
111 .unwrap();
112 }
113
114 Ok(DeviceMle::from(trace))
115 }
116}
117
#[cfg(test)]
mod tests {
    use rand::Rng;

    use slop_algebra::{AbstractField, Field};
    use sp1_recursion_executor::{
        Address, AnalyzedInstruction, ExecutionRecord, Instruction, SelectEvent, SelectInstr,
        SelectIo,
    };
    use sp1_recursion_machine::chips::select::SelectChip;

    use crate::F;

    /// Runs the shared preprocessed-tracegen test helper for [`SelectChip`], feeding it
    /// randomly generated `Select` instructions.
    #[tokio::test]
    async fn test_select_generate_preprocessed_trace() {
        let task = sp1_gpu_cudart::spawn(move |scope| {
            crate::recursion::tests::test_preprocessed_tracegen(
                SelectChip,
                |rng| {
                    // Random draws are made in field order: addresses first, then the two
                    // multiplicities, then the second constructor argument.
                    let addrs = SelectIo {
                        bit: Address(rng.gen()),
                        out1: Address(rng.gen()),
                        out2: Address(rng.gen()),
                        in1: Address(rng.gen()),
                        in2: Address(rng.gen()),
                    };
                    let mult1 = rng.gen();
                    let mult2 = rng.gen();
                    let select = Instruction::Select(SelectInstr { addrs, mult1, mult2 });
                    AnalyzedInstruction::new(select, rng.gen())
                },
                scope,
            )
        });

        task.await.unwrap();
    }

    /// Runs the shared main-tracegen test helper for [`SelectChip`], feeding it randomly
    /// generated select events.
    #[tokio::test]
    async fn test_select_generate_main_trace() {
        let task = sp1_gpu_cudart::spawn(move |scope| {
            crate::tests::test_main_tracegen(
                SelectChip,
                |rng| {
                    let bit = F::from_bool(rng.gen());
                    let in1 = rng.gen();
                    let in2 = rng.gen();
                    // A set bit passes the inputs through unchanged; a clear bit swaps them.
                    let (out1, out2) = if !bit.is_one() { (in2, in1) } else { (in1, in2) };
                    SelectEvent { bit, out1, out2, in1, in2 }
                },
                |events| ExecutionRecord { select_events: events, ..Default::default() },
                scope,
            )
        });

        task.await.unwrap();
    }
}