sp1_gpu_cudart/mle/
fold.rs1use slop_algebra::Field;
2use slop_alloc::Backend;
3use sp1_gpu_sys::{
4 mle::{mle_fold_koala_bear_base_base, mle_fold_koala_bear_ext_ext},
5 runtime::KernelPtr,
6};
7use sp1_primitives::{SP1ExtensionField, SP1Field};
8
9use crate::{args, DeviceCopy, DeviceTensor, TaskScope};
10
11use super::DeviceMle;
12
13pub unsafe trait FoldKernel<F: Field>: Backend {
17 fn fold_kernel() -> KernelPtr;
18}
19
20impl<F: DeviceCopy + Field> DeviceMle<F>
21where
22 TaskScope: FoldKernel<F>,
23{
24 pub fn fold(&self, beta: F) -> DeviceMle<F> {
26 let guts = self.guts();
27 let num_polynomials = self.num_polynomials();
28 let num_non_zero_entries = self.num_non_zero_entries();
29 let folded_num_non_zero_entries = num_non_zero_entries / 2;
30 let mut folded_guts = DeviceTensor::with_sizes_in(
32 [num_polynomials, folded_num_non_zero_entries],
33 self.backend().clone(),
34 );
35
36 const BLOCK_SIZE: usize = 256;
37 const STRIDE: usize = 16;
38 let block_dim = BLOCK_SIZE;
39 let grid_size_x = folded_num_non_zero_entries.div_ceil(BLOCK_SIZE * STRIDE);
40 let grid_size_y = num_polynomials;
41 let grid_dim = (grid_size_x, grid_size_y, 1);
42 let args = args!(
43 guts.as_ptr(),
44 folded_guts.as_mut_ptr(),
45 beta,
46 folded_num_non_zero_entries,
47 num_polynomials
48 );
49 unsafe {
50 folded_guts.assume_init();
51 self.backend()
52 .launch_kernel(TaskScope::fold_kernel(), grid_dim, block_dim, &args, 0)
53 .unwrap();
54 }
55 DeviceMle::new(folded_guts)
56 }
57}
58
59unsafe impl FoldKernel<SP1Field> for TaskScope {
60 fn fold_kernel() -> KernelPtr {
61 unsafe { mle_fold_koala_bear_base_base() }
62 }
63}
64
65unsafe impl FoldKernel<SP1ExtensionField> for TaskScope {
66 fn fold_kernel() -> KernelPtr {
67 unsafe { mle_fold_koala_bear_ext_ext() }
68 }
69}
70
#[cfg(test)]
mod tests {
    use rand::Rng;
    use slop_multilinear::Mle;
    use sp1_primitives::SP1ExtensionField;

    use crate::mle::DeviceMle;

    /// Folds a random single-polynomial extension-field MLE once on the host and once on the
    /// device, then checks that both results have the same length and the same entries.
    #[test]
    fn test_fold_mle() {
        type EF = SP1ExtensionField;

        let num_variables = 11;
        let mut rng = rand::thread_rng();

        let mle = Mle::<EF>::rand(&mut rng, 1, num_variables);
        let beta = rng.gen::<EF>();

        // Reference result computed by the host implementation.
        let folded_mle_host = mle.fold(beta);

        // Same fold on the device, copied back to the host for comparison.
        let folded_mle_cuda = crate::run_sync_in_place(|t| {
            let d_mle = DeviceMle::from_host(&mle, &t).unwrap();
            let folded_mle_cuda = d_mle.fold(beta);
            folded_mle_cuda.to_host().unwrap()
        })
        .unwrap();

        let host_guts = folded_mle_host.guts();
        let cuda_guts = folded_mle_cuda.guts();
        let host_entries = host_guts.as_slice();
        let cuda_entries = cuda_guts.as_slice();

        // `zip` stops at the shorter side, so a truncated or empty device result would
        // otherwise pass; pin the lengths before comparing element by element.
        assert_eq!(
            host_entries.len(),
            cuda_entries.len(),
            "host and device folds produced a different number of entries"
        );
        for (host, device) in host_entries.iter().zip(cuda_entries) {
            assert_eq!(host, device);
        }
    }
}