// sp1_gpu_tracegen/recursion/linear_layer.rs

1use slop_air::BaseAir;
2use slop_alloc::mem::CopyError;
3use slop_alloc::Buffer;
4use slop_tensor::Tensor;
5use sp1_gpu_cudart::{args, DeviceMle, TaskScope};
6use sp1_gpu_cudart::{
7    TracegenPreprocessedRecursionLinearLayerKernel, TracegenRecursionLinearLayerKernel,
8};
9use sp1_hypercube::air::MachineAir;
10use sp1_recursion_executor::Instruction;
11use sp1_recursion_machine::chips::poseidon2_helper::linear::Poseidon2LinearLayerChip;
12
13use crate::{CudaTracegenAir, F};
14
15impl CudaTracegenAir<F> for Poseidon2LinearLayerChip {
16    fn supports_device_preprocessed_tracegen(&self) -> bool {
17        true
18    }
19
20    async fn generate_preprocessed_trace_device(
21        &self,
22        program: &Self::Program,
23        scope: &TaskScope,
24    ) -> Result<Option<DeviceMle<F>>, CopyError> {
25        let instrs = program
26            .inner
27            .iter()
28            .filter_map(|instruction| match instruction.inner() {
29                Instruction::Poseidon2LinearLayer(instr) => Some(**instr),
30                _ => None,
31            })
32            .collect::<Vec<_>>();
33
34        let instrs_device = {
35            let mut buf = Buffer::try_with_capacity_in(instrs.len(), scope.clone()).unwrap();
36            buf.extend_from_host_slice(&instrs)?;
37            buf
38        };
39
40        let width = MachineAir::<F>::preprocessed_width(self);
41
42        let height =
43            MachineAir::<F>::preprocessed_num_rows_with_instrs_len(self, program, instrs.len())
44                .expect("preprocessed_num_rows_with_instrs_len(...) should be Some(_)");
45
46        let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
47
48        unsafe {
49            const BLOCK_DIM: usize = 64;
50            let grid_dim = height.div_ceil(BLOCK_DIM);
51            // args:
52            // T *trace,
53            // uintptr_t trace_height,
54            // const sp1_gpu_sys::Poseidon2LinearLayerInstr<T> *instructions,
55            // uintptr_t nb_instructions
56            let args = args!(trace.as_mut_ptr(), height, instrs_device.as_ptr(), instrs.len());
57            scope
58                .launch_kernel(
59                    TaskScope::tracegen_preprocessed_recursion_linear_layer_kernel(),
60                    grid_dim,
61                    BLOCK_DIM,
62                    &args,
63                    0,
64                )
65                .unwrap();
66        }
67
68        Ok(Some(DeviceMle::from(trace)))
69    }
70
71    fn supports_device_main_tracegen(&self) -> bool {
72        true
73    }
74
75    async fn generate_trace_device(
76        &self,
77        input: &Self::Record,
78        _: &mut Self::Record,
79        scope: &TaskScope,
80    ) -> Result<DeviceMle<F>, CopyError> {
81        let events = &input.poseidon2_linear_layer_events;
82
83        let events_device = {
84            let mut buf = Buffer::try_with_capacity_in(events.len(), scope.clone()).unwrap();
85            buf.extend_from_host_slice(events)?;
86            buf
87        };
88
89        let width = <Self as BaseAir<F>>::width(self);
90
91        let height = <Self as MachineAir<F>>::num_rows(self, input)
92            .expect("num_rows(...) should be Some(_)");
93
94        let mut trace = Tensor::<F, TaskScope>::zeros_in([width, height], scope.clone());
95
96        unsafe {
97            const BLOCK_DIM: usize = 64;
98            let grid_dim = height.div_ceil(BLOCK_DIM);
99            // args:
100            // T *trace,
101            // uintptr_t trace_height,
102            // const sp1_gpu_sys::Poseidon2LinearLayerIo<T> *events,
103            // uintptr_t nb_events
104            let args = args!(trace.as_mut_ptr(), height, events_device.as_ptr(), events.len());
105            scope
106                .launch_kernel(
107                    TaskScope::tracegen_recursion_linear_layer_kernel(),
108                    grid_dim,
109                    BLOCK_DIM,
110                    &args,
111                    0,
112                )
113                .unwrap();
114        }
115
116        Ok(DeviceMle::from(trace))
117    }
118}
119
#[cfg(test)]
mod tests {
    use rand::Rng;

    use slop_algebra::AbstractField;
    use sp1_recursion_executor::{
        Address, AnalyzedInstruction, Block, ExecutionRecord, Instruction,
        Poseidon2LinearLayerInstr, Poseidon2LinearLayerIo,
    };
    use sp1_recursion_machine::chips::poseidon2_helper::linear::Poseidon2LinearLayerChip;

    use crate::F;

    /// Drives the shared preprocessed-tracegen harness with randomly generated
    /// `Poseidon2LinearLayer` instructions. Each instruction has random input/output
    /// addresses, all `mults` set to one, and a random `external` value.
    #[tokio::test]
    async fn test_linear_layer_generate_preprocessed_trace() {
        let outcome = sp1_gpu_cudart::spawn(|scope| {
            crate::recursion::tests::test_preprocessed_tracegen(
                Poseidon2LinearLayerChip,
                |rng| {
                    // Values are drawn from `rng` in this order: input addresses, output
                    // addresses, `external`, then the second argument of
                    // `AnalyzedInstruction::new`.
                    let input = core::array::from_fn(|_| Address(rng.gen()));
                    let output = core::array::from_fn(|_| Address(rng.gen()));
                    let instr = Poseidon2LinearLayerInstr {
                        addrs: Poseidon2LinearLayerIo { input, output },
                        mults: core::array::from_fn(|_| F::one()),
                        external: rng.gen(),
                    };
                    AnalyzedInstruction::new(
                        Instruction::Poseidon2LinearLayer(Box::new(instr)),
                        rng.gen(),
                    )
                },
                scope,
            )
        })
        .await;
        outcome.unwrap();
    }

    /// Drives the shared main-tracegen harness with randomly generated linear-layer events.
    /// Each event has random input/output blocks and is wrapped in an otherwise-default
    /// `ExecutionRecord`.
    #[tokio::test]
    async fn test_linear_layer_generate_main_trace() {
        let outcome = sp1_gpu_cudart::spawn(|scope| {
            crate::tests::test_main_tracegen(
                Poseidon2LinearLayerChip,
                |rng| {
                    let input = core::array::from_fn(|_| Block(rng.gen()));
                    let output = core::array::from_fn(|_| Block(rng.gen()));
                    Poseidon2LinearLayerIo { input, output }
                },
                |events| ExecutionRecord {
                    poseidon2_linear_layer_events: events,
                    ..Default::default()
                },
                scope,
            )
        })
        .await;
        outcome.unwrap();
    }
}