Skip to main content

sp1_gpu_cudart/mle/
restrict.rs

1use std::sync::Arc;
2
3use slop_algebra::{ExtensionField, Field};
4use slop_multilinear::MleEval;
5use sp1_gpu_sys::{
6    runtime::KernelPtr,
7    v2_kernels::{
8        fix_last_variable_ext_ext_kernel, fix_last_variable_felt_ext_kernel,
9        mle_fix_last_variable_koala_bear_base_extension_constant_padding,
10        mle_fix_last_variable_koala_bear_ext_ext_constant_padding,
11    },
12};
13use sp1_primitives::{SP1ExtensionField, SP1Field};
14
15use crate::{args, DeviceCopy, DeviceTensor, TaskScope};
16
17use super::DeviceMle;
18
19/// # Safety
20pub unsafe trait MleFixLastVariableKernel<F: Field, EF: ExtensionField<F>> {
21    fn mle_fix_last_variable_kernel() -> KernelPtr;
22
23    fn mle_fix_last_variable_constant_padding_kernel() -> KernelPtr;
24}
25
26impl<F: DeviceCopy + Field> DeviceMle<F> {
27    /// Fix the last variable of the MLE at the given alpha value.
28    pub fn fix_last_variable<EF: DeviceCopy + ExtensionField<F>>(
29        &self,
30        alpha: EF,
31        padding_values: Arc<MleEval<F, TaskScope>>,
32    ) -> DeviceMle<EF>
33    where
34        TaskScope: MleFixLastVariableKernel<F, EF>,
35    {
36        let mle = self.guts();
37        let num_polynomials = self.num_polynomials();
38        // MLE guts shape is [num_polynomials, num_entries] for TaskScope convention
39        let input_height = mle.sizes()[1];
40        assert!(input_height > 0);
41        let output_height = input_height.div_ceil(2);
42        let mut output =
43            DeviceTensor::with_sizes_in([num_polynomials, output_height], self.backend().clone());
44
45        const BLOCK_SIZE: usize = 256;
46        const STRIDE: usize = 128;
47        let grid_size_x = output_height.div_ceil(BLOCK_SIZE * STRIDE);
48        let grid_size_y = num_polynomials;
49        let grid_size = (grid_size_x, grid_size_y, 1);
50
51        let args = args!(
52            mle.as_ptr(),
53            output.as_mut_ptr(),
54            padding_values.evaluations().as_ptr(),
55            alpha,
56            input_height,
57            num_polynomials
58        );
59
60        unsafe {
61            output.assume_init();
62            self.backend()
63                .launch_kernel(
64                    <TaskScope as MleFixLastVariableKernel<F, EF>>::mle_fix_last_variable_kernel(),
65                    grid_size,
66                    BLOCK_SIZE,
67                    &args,
68                    0,
69                )
70                .unwrap();
71        }
72
73        DeviceMle::new(output)
74    }
75
76    /// Fix the last variable of the MLE at the given alpha value with constant padding.
77    pub fn fix_last_variable_constant_padding<EF: DeviceCopy + ExtensionField<F>>(
78        &self,
79        alpha: EF,
80        padding_value: F,
81    ) -> DeviceMle<EF>
82    where
83        TaskScope: MleFixLastVariableKernel<F, EF>,
84    {
85        let mle = self.guts();
86        let num_polynomials = self.num_polynomials();
87        // MLE guts shape is [num_polynomials, num_entries] for TaskScope convention
88        let input_height = mle.sizes()[1];
89        assert!(input_height > 0);
90        let output_height = input_height.div_ceil(2);
91        let mut output =
92            DeviceTensor::with_sizes_in([num_polynomials, output_height], self.backend().clone());
93
94        const BLOCK_SIZE: usize = 256;
95        const STRIDE: usize = 128;
96        let grid_size_x = output_height.div_ceil(BLOCK_SIZE * STRIDE);
97        let grid_size_y = num_polynomials;
98        let grid_size = (grid_size_x, grid_size_y, 1);
99
100        let args = args!(
101            mle.as_ptr(),
102            output.as_mut_ptr(),
103            padding_value,
104            alpha,
105            input_height,
106            num_polynomials
107        );
108
109        unsafe {
110            output.assume_init();
111            self.backend()
112                .launch_kernel(
113                    <TaskScope as MleFixLastVariableKernel<F, EF>>::mle_fix_last_variable_constant_padding_kernel(),
114                    grid_size,
115                    BLOCK_SIZE,
116                    &args,
117                    0,
118                )
119                .unwrap();
120        }
121
122        DeviceMle::new(output)
123    }
124}
125
126unsafe impl MleFixLastVariableKernel<SP1Field, SP1ExtensionField> for TaskScope {
127    fn mle_fix_last_variable_kernel() -> KernelPtr {
128        unsafe { fix_last_variable_felt_ext_kernel() }
129    }
130    fn mle_fix_last_variable_constant_padding_kernel() -> KernelPtr {
131        unsafe { mle_fix_last_variable_koala_bear_base_extension_constant_padding() }
132    }
133}
134
135unsafe impl MleFixLastVariableKernel<SP1ExtensionField, SP1ExtensionField> for TaskScope {
136    fn mle_fix_last_variable_kernel() -> KernelPtr {
137        unsafe { fix_last_variable_ext_ext_kernel() }
138    }
139
140    fn mle_fix_last_variable_constant_padding_kernel() -> KernelPtr {
141        unsafe { mle_fix_last_variable_koala_bear_ext_ext_constant_padding() }
142    }
143}
144
#[cfg(test)]
mod tests {
    use rand::Rng;
    use slop_algebra::AbstractField;
    use slop_multilinear::{Mle, Point};
    use slop_tensor::Tensor;
    use sp1_primitives::{SP1ExtensionField, SP1Field};

    use crate::mle::eval::DevicePoint;
    use crate::mle::DeviceMle;

    type F = SP1Field;
    type EF = SP1ExtensionField;

    /// Checks that the device constant-padding restriction with zero padding agrees with the
    /// host `Mle::fix_last_variable` for a single-polynomial MLE with `height` entries.
    ///
    /// `height` must lie in `(1 << 15, 1 << 16]` so the restricted MLE has 15 variables.
    fn check_constant_padding_matches_host(height: usize) {
        assert!(height > 1 << 15 && height <= 1 << 16, "height out of range for 15 variables");
        let mut rng = rand::thread_rng();

        let mle = Mle::<F>::new(Tensor::rand(&mut rng, [height, 1]));
        let random_point = Point::<EF>::rand(&mut rng, 15);
        let alpha = rng.gen::<EF>();

        let evals = crate::run_sync_in_place(|t| {
            let d_mle = DeviceMle::from_host(&mle, &t).unwrap();
            // Using fix_last_variable_constant_padding with F::zero() is equivalent
            // to the host's fix_last_variable method.
            let restriction = d_mle.fix_last_variable_constant_padding(alpha, F::zero());
            let d_point = DevicePoint::from_host(&random_point, &t).unwrap();
            let eval = restriction.eval_at_point(&d_point);
            eval.to_host_vec().unwrap()
        })
        .unwrap();

        // Host's fix_last_variable uses zero padding internally
        let restriction = mle.fix_last_variable(alpha);
        let host_evals = restriction.eval_at(&random_point).to_vec();

        assert_eq!(evals, host_evals);
    }

    #[test]
    fn test_mle_fix_last_variable_constant_padding() {
        check_constant_padding_matches_host((1 << 16) - 1000);
    }

    #[test]
    fn test_mle_fix_last_variable_constant_padding_odd_height() {
        // With an odd height, 2 * input_height.div_ceil(2) > input_height, so the kernel has to
        // supply one padded entry. The even-height case never reads the padding value.
        check_constant_padding_matches_host((1 << 16) - 1001);
    }

    // Note: tests for the spawned (async) interface and for PaddedMle are not included here;
    // they require the async spawn interface, which is not part of this sync refactor.
}
183    }
184
185    // Note: The spawned tests and PaddedMle tests are commented out as they require
186    // the async spawn interface which is not part of this sync refactor.
187}