sp1_gpu_tracegen/recursion/
convert.rs1use slop_air::BaseAir;
2use slop_alloc::mem::CopyError;
3use slop_alloc::Buffer;
4use slop_tensor::Tensor;
5use sp1_gpu_cudart::{args, DeviceMle, TaskScope};
6use sp1_gpu_cudart::{TracegenPreprocessedRecursionConvertKernel, TracegenRecursionConvertKernel};
7use sp1_hypercube::air::MachineAir;
8use sp1_recursion_executor::Instruction;
9use sp1_recursion_machine::chips::poseidon2_helper::convert::ConvertChip;
10
11use crate::{CudaTracegenAir, F};
12
13impl CudaTracegenAir<F> for ConvertChip {
14 fn supports_device_preprocessed_tracegen(&self) -> bool {
15 true
16 }
17
18 async fn generate_preprocessed_trace_device(
19 &self,
20 program: &Self::Program,
21 scope: &TaskScope,
22 ) -> Result<Option<DeviceMle<F>>, CopyError> {
23 let instrs = program
24 .inner
25 .iter()
26 .filter_map(|instruction| match instruction.inner() {
27 Instruction::ExtFelt(instr) => Some(*instr),
28 _ => None,
29 })
30 .collect::<Vec<_>>();
31
32 let instrs_device = {
33 let mut buf = Buffer::try_with_capacity_in(instrs.len(), scope.clone()).unwrap();
34 buf.extend_from_host_slice(&instrs)?;
35 buf
36 };
37
38 let width = MachineAir::<F>::preprocessed_width(self);
39
40 let height =
41 MachineAir::<F>::preprocessed_num_rows_with_instrs_len(self, program, instrs.len())
42 .expect("preprocessed_num_rows_with_instrs_len(...) should be Some(_)");
43
44 let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
45
46 unsafe {
47 const BLOCK_DIM: usize = 64;
48 let grid_dim = height.div_ceil(BLOCK_DIM);
49 let args = args!(trace.as_mut_ptr(), height, instrs_device.as_ptr(), instrs.len());
55 scope
56 .launch_kernel(
57 TaskScope::tracegen_preprocessed_recursion_convert_kernel(),
58 grid_dim,
59 BLOCK_DIM,
60 &args,
61 0,
62 )
63 .unwrap();
64 }
65
66 Ok(Some(DeviceMle::from(trace)))
67 }
68
69 fn supports_device_main_tracegen(&self) -> bool {
70 true
71 }
72
73 async fn generate_trace_device(
74 &self,
75 input: &Self::Record,
76 _: &mut Self::Record,
77 scope: &TaskScope,
78 ) -> Result<DeviceMle<F>, CopyError> {
79 let events = &input.ext_felt_conversion_events;
80
81 let events_device = {
82 let mut buf = Buffer::try_with_capacity_in(events.len(), scope.clone()).unwrap();
83 buf.extend_from_host_slice(events)?;
84 buf
85 };
86
87 let width = <Self as BaseAir<F>>::width(self);
88
89 let height = <Self as MachineAir<F>>::num_rows(self, input)
90 .expect("num_rows(...) should be Some(_)");
91
92 let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
93
94 unsafe {
95 const BLOCK_DIM: usize = 64;
96 let grid_dim = height.div_ceil(BLOCK_DIM);
97 let args = args!(trace.as_mut_ptr(), height, events_device.as_ptr(), events.len());
103 scope
104 .launch_kernel(
105 TaskScope::tracegen_recursion_convert_kernel(),
106 grid_dim,
107 BLOCK_DIM,
108 &args,
109 0,
110 )
111 .unwrap();
112 }
113
114 Ok(DeviceMle::from(trace))
115 }
116}
117
#[cfg(test)]
mod tests {
    use rand::Rng;

    use sp1_recursion_executor::{
        Address, AnalyzedInstruction, Block, ExecutionRecord, ExtFeltEvent, ExtFeltInstr,
        Instruction,
    };
    use sp1_recursion_machine::chips::poseidon2_helper::convert::ConvertChip;

    /// Runs the shared preprocessed-tracegen test helper for `ConvertChip`, feeding it randomly
    /// generated `ExtFelt` instructions.
    ///
    /// NOTE(review): the helper presumably compares the device trace against the host one —
    /// confirm in `crate::recursion::tests::test_preprocessed_tracegen`.
    #[tokio::test]
    async fn test_convert_generate_preprocessed_trace() {
        let task = sp1_gpu_cudart::spawn(|scope| {
            crate::recursion::tests::test_preprocessed_tracegen(
                ConvertChip,
                |rng| {
                    // Fields are drawn in declaration order (five addresses, five
                    // multiplicities, then the flag) so the RNG stream is consumed in a fixed
                    // sequence.
                    let instr = ExtFeltInstr {
                        addrs: std::array::from_fn(|_| Address(rng.gen())),
                        mults: std::array::from_fn(|_| rng.gen()),
                        ext2felt: rng.gen(),
                    };
                    AnalyzedInstruction::new(Instruction::ExtFelt(instr), rng.gen())
                },
                scope,
            )
        });

        task.await.unwrap();
    }

    /// Runs the shared main-tracegen test helper for `ConvertChip`, feeding it randomly
    /// generated `ExtFeltEvent`s wrapped in an otherwise default `ExecutionRecord`.
    ///
    /// NOTE(review): the helper presumably compares the device trace against the host one —
    /// confirm in `crate::tests::test_main_tracegen`.
    #[tokio::test]
    async fn test_convert_generate_main_trace() {
        let task = sp1_gpu_cudart::spawn(move |scope| {
            crate::tests::test_main_tracegen(
                ConvertChip,
                |rng| ExtFeltEvent { input: Block(rng.gen()) },
                |events| ExecutionRecord {
                    ext_felt_conversion_events: events,
                    ..Default::default()
                },
                scope,
            )
        });

        task.await.unwrap();
    }
}