sp1_gpu_cudart/mle/
restrict.rs1use std::sync::Arc;
2
3use slop_algebra::{ExtensionField, Field};
4use slop_multilinear::MleEval;
5use sp1_gpu_sys::{
6 runtime::KernelPtr,
7 v2_kernels::{
8 fix_last_variable_ext_ext_kernel, fix_last_variable_felt_ext_kernel,
9 mle_fix_last_variable_koala_bear_base_extension_constant_padding,
10 mle_fix_last_variable_koala_bear_ext_ext_constant_padding,
11 },
12};
13use sp1_primitives::{SP1ExtensionField, SP1Field};
14
15use crate::{args, DeviceCopy, DeviceTensor, TaskScope};
16
17use super::DeviceMle;
18
19pub unsafe trait MleFixLastVariableKernel<F: Field, EF: ExtensionField<F>> {
21 fn mle_fix_last_variable_kernel() -> KernelPtr;
22
23 fn mle_fix_last_variable_constant_padding_kernel() -> KernelPtr;
24}
25
26impl<F: DeviceCopy + Field> DeviceMle<F> {
27 pub fn fix_last_variable<EF: DeviceCopy + ExtensionField<F>>(
29 &self,
30 alpha: EF,
31 padding_values: Arc<MleEval<F, TaskScope>>,
32 ) -> DeviceMle<EF>
33 where
34 TaskScope: MleFixLastVariableKernel<F, EF>,
35 {
36 let mle = self.guts();
37 let num_polynomials = self.num_polynomials();
38 let input_height = mle.sizes()[1];
40 assert!(input_height > 0);
41 let output_height = input_height.div_ceil(2);
42 let mut output =
43 DeviceTensor::with_sizes_in([num_polynomials, output_height], self.backend().clone());
44
45 const BLOCK_SIZE: usize = 256;
46 const STRIDE: usize = 128;
47 let grid_size_x = output_height.div_ceil(BLOCK_SIZE * STRIDE);
48 let grid_size_y = num_polynomials;
49 let grid_size = (grid_size_x, grid_size_y, 1);
50
51 let args = args!(
52 mle.as_ptr(),
53 output.as_mut_ptr(),
54 padding_values.evaluations().as_ptr(),
55 alpha,
56 input_height,
57 num_polynomials
58 );
59
60 unsafe {
61 output.assume_init();
62 self.backend()
63 .launch_kernel(
64 <TaskScope as MleFixLastVariableKernel<F, EF>>::mle_fix_last_variable_kernel(),
65 grid_size,
66 BLOCK_SIZE,
67 &args,
68 0,
69 )
70 .unwrap();
71 }
72
73 DeviceMle::new(output)
74 }
75
76 pub fn fix_last_variable_constant_padding<EF: DeviceCopy + ExtensionField<F>>(
78 &self,
79 alpha: EF,
80 padding_value: F,
81 ) -> DeviceMle<EF>
82 where
83 TaskScope: MleFixLastVariableKernel<F, EF>,
84 {
85 let mle = self.guts();
86 let num_polynomials = self.num_polynomials();
87 let input_height = mle.sizes()[1];
89 assert!(input_height > 0);
90 let output_height = input_height.div_ceil(2);
91 let mut output =
92 DeviceTensor::with_sizes_in([num_polynomials, output_height], self.backend().clone());
93
94 const BLOCK_SIZE: usize = 256;
95 const STRIDE: usize = 128;
96 let grid_size_x = output_height.div_ceil(BLOCK_SIZE * STRIDE);
97 let grid_size_y = num_polynomials;
98 let grid_size = (grid_size_x, grid_size_y, 1);
99
100 let args = args!(
101 mle.as_ptr(),
102 output.as_mut_ptr(),
103 padding_value,
104 alpha,
105 input_height,
106 num_polynomials
107 );
108
109 unsafe {
110 output.assume_init();
111 self.backend()
112 .launch_kernel(
113 <TaskScope as MleFixLastVariableKernel<F, EF>>::mle_fix_last_variable_constant_padding_kernel(),
114 grid_size,
115 BLOCK_SIZE,
116 &args,
117 0,
118 )
119 .unwrap();
120 }
121
122 DeviceMle::new(output)
123 }
124}
125
126unsafe impl MleFixLastVariableKernel<SP1Field, SP1ExtensionField> for TaskScope {
127 fn mle_fix_last_variable_kernel() -> KernelPtr {
128 unsafe { fix_last_variable_felt_ext_kernel() }
129 }
130 fn mle_fix_last_variable_constant_padding_kernel() -> KernelPtr {
131 unsafe { mle_fix_last_variable_koala_bear_base_extension_constant_padding() }
132 }
133}
134
135unsafe impl MleFixLastVariableKernel<SP1ExtensionField, SP1ExtensionField> for TaskScope {
136 fn mle_fix_last_variable_kernel() -> KernelPtr {
137 unsafe { fix_last_variable_ext_ext_kernel() }
138 }
139
140 fn mle_fix_last_variable_constant_padding_kernel() -> KernelPtr {
141 unsafe { mle_fix_last_variable_koala_bear_ext_ext_constant_padding() }
142 }
143}
144
#[cfg(test)]
mod tests {
    use rand::Rng;
    use slop_algebra::AbstractField;
    use slop_multilinear::{Mle, Point};
    use slop_tensor::Tensor;
    use sp1_primitives::{SP1ExtensionField, SP1Field};

    use crate::mle::eval::DevicePoint;
    use crate::mle::DeviceMle;

    type F = SP1Field;
    type EF = SP1ExtensionField;

    /// Checks that the device constant-padding restriction with zero padding evaluates to the
    /// same values as the host `Mle::fix_last_variable` at a random point.
    #[test]
    fn test_mle_fix_last_variable_constant_padding() {
        let mut rng = rand::thread_rng();

        // A single-column MLE whose height is deliberately not a power of two.
        let host_mle = Mle::<F>::new(Tensor::rand(&mut rng, [(1 << 16) - 1000, 1]));
        // Evaluation point for the restricted MLE, then the value substituted for the last
        // variable.
        let point = Point::<EF>::rand(&mut rng, 15);
        let alpha = rng.gen::<EF>();

        // Host reference: restrict, then evaluate.
        let host_evals = host_mle.fix_last_variable(alpha).eval_at(&point).to_vec();

        // Device path: upload, restrict with zero padding, evaluate, download.
        let device_evals = crate::run_sync_in_place(|t| {
            let device_mle = DeviceMle::from_host(&host_mle, &t).unwrap();
            let restricted = device_mle.fix_last_variable_constant_padding(alpha, F::zero());
            let device_point = DevicePoint::from_host(&point, &t).unwrap();
            restricted.eval_at_point(&device_point).to_host_vec().unwrap()
        })
        .unwrap();

        assert_eq!(device_evals, host_evals);
    }
}