Skip to main content

sp1_gpu_cudart/mle/
fold.rs

1use slop_algebra::Field;
2use slop_alloc::Backend;
3use sp1_gpu_sys::{
4    mle::{mle_fold_koala_bear_base_base, mle_fold_koala_bear_ext_ext},
5    runtime::KernelPtr,
6};
7use sp1_primitives::{SP1ExtensionField, SP1Field};
8
9use crate::{args, DeviceCopy, DeviceTensor, TaskScope};
10
11use super::DeviceMle;
12
13/// # Safety
14///
15/// todo
16pub unsafe trait FoldKernel<F: Field>: Backend {
17    fn fold_kernel() -> KernelPtr;
18}
19
20impl<F: DeviceCopy + Field> DeviceMle<F>
21where
22    TaskScope: FoldKernel<F>,
23{
24    /// Folds the MLE by the given beta value.
25    pub fn fold(&self, beta: F) -> DeviceMle<F> {
26        let guts = self.guts();
27        let num_polynomials = self.num_polynomials();
28        let num_non_zero_entries = self.num_non_zero_entries();
29        let folded_num_non_zero_entries = num_non_zero_entries / 2;
30        // MLE guts shape is [num_polynomials, num_entries] for TaskScope convention
31        let mut folded_guts = DeviceTensor::with_sizes_in(
32            [num_polynomials, folded_num_non_zero_entries],
33            self.backend().clone(),
34        );
35
36        const BLOCK_SIZE: usize = 256;
37        const STRIDE: usize = 16;
38        let block_dim = BLOCK_SIZE;
39        let grid_size_x = folded_num_non_zero_entries.div_ceil(BLOCK_SIZE * STRIDE);
40        let grid_size_y = num_polynomials;
41        let grid_dim = (grid_size_x, grid_size_y, 1);
42        let args = args!(
43            guts.as_ptr(),
44            folded_guts.as_mut_ptr(),
45            beta,
46            folded_num_non_zero_entries,
47            num_polynomials
48        );
49        unsafe {
50            folded_guts.assume_init();
51            self.backend()
52                .launch_kernel(TaskScope::fold_kernel(), grid_dim, block_dim, &args, 0)
53                .unwrap();
54        }
55        DeviceMle::new(folded_guts)
56    }
57}
58
59unsafe impl FoldKernel<SP1Field> for TaskScope {
60    fn fold_kernel() -> KernelPtr {
61        unsafe { mle_fold_koala_bear_base_base() }
62    }
63}
64
65unsafe impl FoldKernel<SP1ExtensionField> for TaskScope {
66    fn fold_kernel() -> KernelPtr {
67        unsafe { mle_fold_koala_bear_ext_ext() }
68    }
69}
70
#[cfg(test)]
mod tests {
    use rand::Rng;
    use slop_multilinear::Mle;
    use sp1_primitives::SP1ExtensionField;

    use crate::mle::DeviceMle;

    /// Folds a random single-polynomial MLE over the extension field on the host and on the
    /// device with the same `beta`, and checks that both results agree entry by entry.
    #[test]
    fn test_fold_mle() {
        type EF = SP1ExtensionField;

        let num_variables = 11;
        let mut rng = rand::thread_rng();

        let mle = Mle::<EF>::rand(&mut rng, 1, num_variables);
        let beta = rng.gen::<EF>();

        // Reference result computed on the host.
        let folded_mle_host = mle.fold(beta);

        // Same fold on the device, copied back to the host for comparison.
        let folded_mle_cuda = crate::run_sync_in_place(|t| {
            let d_mle = DeviceMle::from_host(&mle, &t).unwrap();
            let folded_mle_cuda = d_mle.fold(beta);
            folded_mle_cuda.to_host().unwrap()
        })
        .unwrap();

        let host_guts = folded_mle_host.guts();
        let cuda_guts = folded_mle_cuda.guts();
        let host_vals = host_guts.as_slice();
        let cuda_vals = cuda_guts.as_slice();

        // `zip` stops at the shorter input, so without this check a truncated (or empty) device
        // result would pass the element-wise comparison below.
        assert_eq!(
            host_vals.len(),
            cuda_vals.len(),
            "host and device folds produced a different number of entries"
        );

        for (i, (val, exp)) in host_vals.iter().zip(cuda_vals).enumerate() {
            assert_eq!(val, exp, "folded MLE mismatch at index {i}");
        }
    }
}