Skip to main content

sp1_gpu_cudart/mle/
eval.rs

1use slop_algebra::{ExtensionField, Field};
2use slop_alloc::{mem::CopyError, CpuBackend};
3use slop_multilinear::{MleBaseBackend, MleEval, Point};
4use slop_tensor::Tensor;
5use sp1_gpu_sys::{
6    mle::{
7        partial_geq_koala_bear, partial_lagrange_koala_bear, partial_lagrange_koala_bear_extension,
8    },
9    runtime::KernelPtr,
10};
11use sp1_primitives::{SP1ExtensionField, SP1Field};
12
13use crate::{args, tensor::dot::DotKernel, DeviceCopy, DeviceTensor, TaskScope};
14
15use super::DeviceMle;
16
17/// A Point stored on the GPU device.
18pub struct DevicePoint<F> {
19    raw: Point<F, TaskScope>,
20}
21
22impl<F: DeviceCopy + Field> DevicePoint<F> {
23    /// Creates a new DevicePoint from a Point.
24    pub fn new(point: Point<F, TaskScope>) -> Self {
25        Self { raw: point }
26    }
27
28    /// Returns a reference to the underlying Point.
29    pub fn inner(&self) -> &Point<F, TaskScope> {
30        &self.raw
31    }
32
33    /// Consumes self and returns the underlying Point.
34    pub fn into_inner(self) -> Point<F, TaskScope> {
35        self.raw
36    }
37
38    /// Returns the dimension of this point.
39    pub fn dimension(&self) -> usize {
40        self.raw.dimension()
41    }
42
43    /// Returns the backend (TaskScope) for this point.
44    pub fn backend(&self) -> &TaskScope {
45        self.raw.backend()
46    }
47
48    /// Returns a pointer to the underlying data.
49    pub fn as_ptr(&self) -> *const F {
50        self.raw.as_ptr()
51    }
52
53    /// Copies a host Point to the device.
54    pub fn from_host(
55        host_point: &Point<F, CpuBackend>,
56        scope: &TaskScope,
57    ) -> Result<Self, CopyError> {
58        use slop_alloc::Buffer;
59        let host_values = host_point.values();
60        let mut device_buf = Buffer::with_capacity_in(host_values.len(), scope.clone());
61        device_buf.extend_from_host_slice(host_values)?;
62        Ok(Self::new(Point::new(device_buf)))
63    }
64
65    /// Computes the partial Lagrange polynomial for this point.
66    pub fn partial_lagrange(&self) -> DeviceMle<F>
67    where
68        TaskScope: PartialLagrangeKernel<F>,
69    {
70        let dimension = self.dimension();
71        let num_elements = 1 << dimension;
72        // Shape [1, num_elements] to match MleBaseBackend convention for TaskScope: [num_polynomials, num_entries]
73        let mut eq = DeviceTensor::with_sizes_in([1, num_elements], self.backend().clone());
74        unsafe {
75            eq.assume_init();
76            let block_dim = 256;
77            let grid_dim = ((1 << dimension) as u32).div_ceil(block_dim);
78            let args = args!(eq.as_mut_ptr(), self.as_ptr(), dimension);
79            self.backend()
80                .launch_kernel(
81                    <TaskScope as PartialLagrangeKernel<F>>::partial_lagrange_kernel(),
82                    grid_dim,
83                    block_dim,
84                    &args,
85                    0,
86                )
87                .unwrap();
88        }
89        DeviceMle::new(eq)
90    }
91}
92
93/// MLE evaluations stored on the GPU device.
94pub struct DeviceMleEval<F> {
95    raw: MleEval<F, TaskScope>,
96}
97
98impl<F: DeviceCopy + Field> DeviceMleEval<F> {
99    /// Creates a new DeviceMleEval from an MleEval.
100    pub fn new(eval: MleEval<F, TaskScope>) -> Self {
101        Self { raw: eval }
102    }
103
104    /// Returns a reference to the underlying MleEval.
105    pub fn inner(&self) -> &MleEval<F, TaskScope> {
106        &self.raw
107    }
108
109    /// Consumes self and returns the underlying MleEval.
110    pub fn into_inner(self) -> MleEval<F, TaskScope> {
111        self.raw
112    }
113
114    /// Returns a reference to the evaluations tensor.
115    pub fn evaluations(&self) -> &Tensor<F, TaskScope> {
116        self.raw.evaluations()
117    }
118
119    /// Copies the evaluations to the host and returns them as a Vec.
120    pub fn to_host_vec(&self) -> Result<Vec<F>, CopyError> {
121        let device_tensor = DeviceTensor::from_raw(self.raw.evaluations().clone());
122        let host_tensor = device_tensor.to_host()?;
123        Ok(host_tensor.into_buffer().into_vec())
124    }
125}
126
127/// # Safety
128///
129pub unsafe trait PartialLagrangeKernel<F: Field> {
130    fn partial_lagrange_kernel() -> KernelPtr;
131}
132
133/// # Safety
134///
135pub unsafe trait PartialGeqKernel<F: Field> {
136    fn partial_geq_kernel() -> KernelPtr;
137}
138
139impl<F: DeviceCopy + Field> DeviceMle<F> {
140    /// Evaluates the MLE at the given point.
141    pub fn eval_at_point<EF: DeviceCopy + ExtensionField<F>>(
142        &self,
143        point: &DevicePoint<EF>,
144    ) -> DeviceMleEval<EF>
145    where
146        TaskScope: PartialLagrangeKernel<EF> + DotKernel<F, EF>,
147    {
148        let eq = point.partial_lagrange();
149        self.eval_at_eq(&eq)
150    }
151
152    /// Evaluates the MLE given precomputed eq polynomial.
153    pub fn eval_at_eq<EF: DeviceCopy + ExtensionField<F>>(
154        &self,
155        eq: &DeviceMle<EF>,
156    ) -> DeviceMleEval<EF>
157    where
158        TaskScope: DotKernel<F, EF>,
159    {
160        // MLE guts shape is [num_polynomials, num_entries] (TaskScope convention)
161        // eq shape is [1, num_entries] from partial_lagrange
162        // Dot along dim 1 reduces the num_entries dimension, giving [num_polynomials]
163        let result = self.guts.dot_along_dim(eq.guts(), 1);
164        DeviceMleEval::new(MleEval::new(result.into_inner()))
165    }
166
167    /// Evaluates the MLE at the given point with the last variable fixed to zero.
168    /// This is equivalent to evaluating at (point, 0).
169    pub fn fixed_at_zero<EF: DeviceCopy + ExtensionField<F>>(
170        &self,
171        point: &Point<EF>,
172    ) -> DeviceMleEval<EF>
173    where
174        TaskScope: PartialLagrangeKernel<EF> + DotKernel<F, EF>,
175    {
176        // Extend the point with zero at the end
177        let mut extended_point = point.clone();
178        extended_point.add_dimension_back(EF::zero());
179        let device_point = DevicePoint::from_host(&extended_point, self.backend()).unwrap();
180        self.eval_at_point(&device_point)
181    }
182}
183
184impl<F: Field> MleBaseBackend<F> for TaskScope {
185    #[inline]
186    fn uninit_mle(&self, num_polynomials: usize, num_non_zero_entries: usize) -> Tensor<F, Self> {
187        // TaskScope convention: [num_polynomials, num_non_zero_entries]
188        Tensor::with_sizes_in([num_polynomials, num_non_zero_entries], self.clone())
189    }
190
191    #[inline]
192    fn num_polynomials(guts: &Tensor<F, Self>) -> usize {
193        // TaskScope convention: sizes()[0] is num_polynomials
194        guts.sizes()[0]
195    }
196
197    #[inline]
198    fn num_variables(guts: &Tensor<F, Self>) -> u32 {
199        // TaskScope convention: sizes()[1] is num_non_zero_entries
200        guts.sizes()[1].next_power_of_two().ilog2()
201    }
202
203    #[inline]
204    fn num_non_zero_entries(guts: &Tensor<F, Self>) -> usize {
205        // TaskScope convention: sizes()[1] is num_non_zero_entries
206        guts.sizes()[1]
207    }
208}
209
210unsafe impl PartialLagrangeKernel<SP1Field> for TaskScope {
211    fn partial_lagrange_kernel() -> KernelPtr {
212        unsafe { partial_lagrange_koala_bear() }
213    }
214}
215
216unsafe impl PartialLagrangeKernel<SP1ExtensionField> for TaskScope {
217    fn partial_lagrange_kernel() -> KernelPtr {
218        unsafe { partial_lagrange_koala_bear_extension() }
219    }
220}
221
222unsafe impl PartialGeqKernel<SP1Field> for TaskScope {
223    fn partial_geq_kernel() -> KernelPtr {
224        unsafe { partial_geq_koala_bear() }
225    }
226}
227
#[cfg(test)]
mod tests {
    use slop_multilinear::{Mle, Point};
    use sp1_primitives::{SP1ExtensionField, SP1Field};

    use super::{DeviceMleEval, DevicePoint};
    use crate::mle::DeviceMle;

    type F = SP1Field;
    type EF = SP1ExtensionField;

    /// Device evaluation of a random MLE at a random extension-field point must
    /// agree with the host-side evaluation.
    #[test]
    fn test_mle_eval() {
        let mut rng = rand::thread_rng();

        // Random base-field MLE and random extension-field point (16 variables).
        let host_mle = Mle::<F>::rand(&mut rng, 100, 16);
        let host_point = Point::<EF>::rand(&mut rng, 16);

        // Upload both, evaluate on the device, and bring the results back.
        let device_evals = crate::run_sync_in_place(|scope| {
            let d_point = DevicePoint::from_host(&host_point, &scope).unwrap();
            let d_mle = DeviceMle::from_host(&host_mle, &scope).unwrap();
            let d_eval: DeviceMleEval<EF> = d_mle.eval_at_point(&d_point);
            d_eval.to_host_vec().unwrap()
        })
        .unwrap();

        let expected = host_mle.eval_at(&host_point).to_vec();
        assert_eq!(device_evals, expected);
    }
}