sp1_gpu_cudart/mle/
eval.rs1use slop_algebra::{ExtensionField, Field};
2use slop_alloc::{mem::CopyError, CpuBackend};
3use slop_multilinear::{MleBaseBackend, MleEval, Point};
4use slop_tensor::Tensor;
5use sp1_gpu_sys::{
6 mle::{
7 partial_geq_koala_bear, partial_lagrange_koala_bear, partial_lagrange_koala_bear_extension,
8 },
9 runtime::KernelPtr,
10};
11use sp1_primitives::{SP1ExtensionField, SP1Field};
12
13use crate::{args, tensor::dot::DotKernel, DeviceCopy, DeviceTensor, TaskScope};
14
15use super::DeviceMle;
16
17pub struct DevicePoint<F> {
19 raw: Point<F, TaskScope>,
20}
21
22impl<F: DeviceCopy + Field> DevicePoint<F> {
23 pub fn new(point: Point<F, TaskScope>) -> Self {
25 Self { raw: point }
26 }
27
28 pub fn inner(&self) -> &Point<F, TaskScope> {
30 &self.raw
31 }
32
33 pub fn into_inner(self) -> Point<F, TaskScope> {
35 self.raw
36 }
37
38 pub fn dimension(&self) -> usize {
40 self.raw.dimension()
41 }
42
43 pub fn backend(&self) -> &TaskScope {
45 self.raw.backend()
46 }
47
48 pub fn as_ptr(&self) -> *const F {
50 self.raw.as_ptr()
51 }
52
53 pub fn from_host(
55 host_point: &Point<F, CpuBackend>,
56 scope: &TaskScope,
57 ) -> Result<Self, CopyError> {
58 use slop_alloc::Buffer;
59 let host_values = host_point.values();
60 let mut device_buf = Buffer::with_capacity_in(host_values.len(), scope.clone());
61 device_buf.extend_from_host_slice(host_values)?;
62 Ok(Self::new(Point::new(device_buf)))
63 }
64
65 pub fn partial_lagrange(&self) -> DeviceMle<F>
67 where
68 TaskScope: PartialLagrangeKernel<F>,
69 {
70 let dimension = self.dimension();
71 let num_elements = 1 << dimension;
72 let mut eq = DeviceTensor::with_sizes_in([1, num_elements], self.backend().clone());
74 unsafe {
75 eq.assume_init();
76 let block_dim = 256;
77 let grid_dim = ((1 << dimension) as u32).div_ceil(block_dim);
78 let args = args!(eq.as_mut_ptr(), self.as_ptr(), dimension);
79 self.backend()
80 .launch_kernel(
81 <TaskScope as PartialLagrangeKernel<F>>::partial_lagrange_kernel(),
82 grid_dim,
83 block_dim,
84 &args,
85 0,
86 )
87 .unwrap();
88 }
89 DeviceMle::new(eq)
90 }
91}
92
93pub struct DeviceMleEval<F> {
95 raw: MleEval<F, TaskScope>,
96}
97
98impl<F: DeviceCopy + Field> DeviceMleEval<F> {
99 pub fn new(eval: MleEval<F, TaskScope>) -> Self {
101 Self { raw: eval }
102 }
103
104 pub fn inner(&self) -> &MleEval<F, TaskScope> {
106 &self.raw
107 }
108
109 pub fn into_inner(self) -> MleEval<F, TaskScope> {
111 self.raw
112 }
113
114 pub fn evaluations(&self) -> &Tensor<F, TaskScope> {
116 self.raw.evaluations()
117 }
118
119 pub fn to_host_vec(&self) -> Result<Vec<F>, CopyError> {
121 let device_tensor = DeviceTensor::from_raw(self.raw.evaluations().clone());
122 let host_tensor = device_tensor.to_host()?;
123 Ok(host_tensor.into_buffer().into_vec())
124 }
125}
126
127pub unsafe trait PartialLagrangeKernel<F: Field> {
130 fn partial_lagrange_kernel() -> KernelPtr;
131}
132
133pub unsafe trait PartialGeqKernel<F: Field> {
136 fn partial_geq_kernel() -> KernelPtr;
137}
138
139impl<F: DeviceCopy + Field> DeviceMle<F> {
140 pub fn eval_at_point<EF: DeviceCopy + ExtensionField<F>>(
142 &self,
143 point: &DevicePoint<EF>,
144 ) -> DeviceMleEval<EF>
145 where
146 TaskScope: PartialLagrangeKernel<EF> + DotKernel<F, EF>,
147 {
148 let eq = point.partial_lagrange();
149 self.eval_at_eq(&eq)
150 }
151
152 pub fn eval_at_eq<EF: DeviceCopy + ExtensionField<F>>(
154 &self,
155 eq: &DeviceMle<EF>,
156 ) -> DeviceMleEval<EF>
157 where
158 TaskScope: DotKernel<F, EF>,
159 {
160 let result = self.guts.dot_along_dim(eq.guts(), 1);
164 DeviceMleEval::new(MleEval::new(result.into_inner()))
165 }
166
167 pub fn fixed_at_zero<EF: DeviceCopy + ExtensionField<F>>(
170 &self,
171 point: &Point<EF>,
172 ) -> DeviceMleEval<EF>
173 where
174 TaskScope: PartialLagrangeKernel<EF> + DotKernel<F, EF>,
175 {
176 let mut extended_point = point.clone();
178 extended_point.add_dimension_back(EF::zero());
179 let device_point = DevicePoint::from_host(&extended_point, self.backend()).unwrap();
180 self.eval_at_point(&device_point)
181 }
182}
183
184impl<F: Field> MleBaseBackend<F> for TaskScope {
185 #[inline]
186 fn uninit_mle(&self, num_polynomials: usize, num_non_zero_entries: usize) -> Tensor<F, Self> {
187 Tensor::with_sizes_in([num_polynomials, num_non_zero_entries], self.clone())
189 }
190
191 #[inline]
192 fn num_polynomials(guts: &Tensor<F, Self>) -> usize {
193 guts.sizes()[0]
195 }
196
197 #[inline]
198 fn num_variables(guts: &Tensor<F, Self>) -> u32 {
199 guts.sizes()[1].next_power_of_two().ilog2()
201 }
202
203 #[inline]
204 fn num_non_zero_entries(guts: &Tensor<F, Self>) -> usize {
205 guts.sizes()[1]
207 }
208}
209
210unsafe impl PartialLagrangeKernel<SP1Field> for TaskScope {
211 fn partial_lagrange_kernel() -> KernelPtr {
212 unsafe { partial_lagrange_koala_bear() }
213 }
214}
215
216unsafe impl PartialLagrangeKernel<SP1ExtensionField> for TaskScope {
217 fn partial_lagrange_kernel() -> KernelPtr {
218 unsafe { partial_lagrange_koala_bear_extension() }
219 }
220}
221
222unsafe impl PartialGeqKernel<SP1Field> for TaskScope {
223 fn partial_geq_kernel() -> KernelPtr {
224 unsafe { partial_geq_koala_bear() }
225 }
226}
227
#[cfg(test)]
mod tests {
    use slop_multilinear::{Mle, Point};
    use sp1_primitives::{SP1ExtensionField, SP1Field};

    use super::{DeviceMleEval, DevicePoint};
    use crate::mle::DeviceMle;

    type F = SP1Field;
    type EF = SP1ExtensionField;

    /// Checks that `DeviceMle::eval_at_point` agrees with the host `Mle::eval_at`
    /// on a random MLE and a random point.
    #[test]
    fn test_mle_eval() {
        let mut rng = rand::thread_rng();

        let host_mle = Mle::<F>::rand(&mut rng, 100, 16);
        let host_point = Point::<EF>::rand(&mut rng, 16);

        // Upload both inputs, evaluate, and bring the result back to the host.
        let device_result = crate::run_sync_in_place(|scope| {
            let device_point = DevicePoint::from_host(&host_point, &scope).unwrap();
            let device_mle = DeviceMle::from_host(&host_mle, &scope).unwrap();
            let evaluation: DeviceMleEval<EF> = device_mle.eval_at_point(&device_point);
            evaluation.to_host_vec().unwrap()
        })
        .unwrap();

        let expected = host_mle.eval_at(&host_point).to_vec();
        assert_eq!(device_result, expected);
    }
}