sp1_recursion_compiler/ir/
collections.rs

1use itertools::Itertools;
2use p3_field::AbstractField;
3
4use super::{Builder, Config, FromConstant, MemIndex, MemVariable, Ptr, Usize, Var, Variable};
5
6/// An array that is either of static or dynamic size.
7#[derive(Debug, Clone)]
8pub enum Array<C: Config, T> {
9    Fixed(Vec<T>),
10    Dyn(Ptr<C::N>, Usize<C::N>),
11}
12
13impl<C: Config, V: MemVariable<C>> Array<C, V> {
14    /// Gets a fixed version of the array.
15    pub fn vec(&self) -> Vec<V> {
16        match self {
17            Self::Fixed(vec) => vec.clone(),
18            _ => panic!("array is dynamic, not fixed"),
19        }
20    }
21
22    /// Gets the length of the array as a variable inside the DSL.
23    pub fn len(&self) -> Usize<C::N> {
24        match self {
25            Self::Fixed(vec) => Usize::from(vec.len()),
26            Self::Dyn(_, len) => *len,
27        }
28    }
29
30    /// Shifts the array by `shift` elements.
31    pub fn shift(&self, builder: &mut Builder<C>, shift: Var<C::N>) -> Array<C, V> {
32        match self {
33            Self::Fixed(_) => {
34                todo!()
35            }
36            Self::Dyn(ptr, len) => {
37                assert!(V::size_of() == 1, "only support variables of size 1");
38                let new_address = builder.eval(ptr.address + shift);
39                let new_ptr = Ptr::<C::N> { address: new_address };
40                let len_var = len.materialize(builder);
41                let new_length = builder.eval(len_var - shift);
42                Array::Dyn(new_ptr, Usize::Var(new_length))
43            }
44        }
45    }
46
47    /// Truncates the array to `len` elements.
48    pub fn truncate(&self, builder: &mut Builder<C>, len: Usize<C::N>) {
49        match self {
50            Self::Fixed(_) => {
51                todo!()
52            }
53            Self::Dyn(_, old_len) => {
54                builder.assign(*old_len, len);
55            }
56        };
57    }
58
59    pub fn slice(
60        &self,
61        builder: &mut Builder<C>,
62        start: Usize<C::N>,
63        end: Usize<C::N>,
64    ) -> Array<C, V> {
65        match self {
66            Self::Fixed(vec) => {
67                if let (Usize::Const(start), Usize::Const(end)) = (start, end) {
68                    builder.vec(vec[start..end].to_vec())
69                } else {
70                    panic!("Cannot slice a fixed array with a variable start or end");
71                }
72            }
73            Self::Dyn(_, len) => {
74                if builder.debug {
75                    let start_v = start.materialize(builder);
76                    let end_v = end.materialize(builder);
77                    let valid = builder.lt(start_v, end_v);
78                    builder.assert_var_eq(valid, C::N::one());
79
80                    let len_v = len.materialize(builder);
81                    let len_plus_1_v = builder.eval(len_v + C::N::one());
82                    let valid = builder.lt(end_v, len_plus_1_v);
83                    builder.assert_var_eq(valid, C::N::one());
84                }
85
86                let slice_len: Usize<_> = builder.eval(end - start);
87                let mut slice = builder.dyn_array(slice_len);
88                builder.range(0, slice_len).for_each(|i, builder| {
89                    let idx: Usize<_> = builder.eval(start + i);
90                    let value = builder.get(self, idx);
91                    builder.set(&mut slice, i, value);
92                });
93
94                slice
95            }
96        }
97    }
98}
99
100impl<C: Config> Builder<C> {
101    /// Initialize an array of fixed length `len`. The entries will be uninitialized.
102    pub fn array<V: MemVariable<C>>(&mut self, len: impl Into<Usize<C::N>>) -> Array<C, V> {
103        self.dyn_array(len)
104    }
105
106    /// Creates an array from a vector.
107    pub fn vec<V: MemVariable<C>>(&mut self, v: Vec<V>) -> Array<C, V> {
108        Array::Fixed(v)
109    }
110
111    /// Creates a dynamic array for a length.
112    pub fn dyn_array<V: MemVariable<C>>(&mut self, len: impl Into<Usize<C::N>>) -> Array<C, V> {
113        let len = match len.into() {
114            Usize::Const(len) => self.eval(C::N::from_canonical_usize(len)),
115            Usize::Var(len) => len,
116        };
117        let len = Usize::Var(len);
118        let ptr = self.alloc(len, V::size_of());
119        Array::Dyn(ptr, len)
120    }
121
122    pub fn get<V: MemVariable<C>, I: Into<Usize<C::N>>>(
123        &mut self,
124        slice: &Array<C, V>,
125        index: I,
126    ) -> V {
127        let index = index.into();
128
129        match slice {
130            Array::Fixed(slice) => {
131                if let Usize::Const(idx) = index {
132                    slice[idx].clone()
133                } else {
134                    panic!("Cannot index into a fixed slice with a variable size")
135                }
136            }
137            Array::Dyn(ptr, len) => {
138                if self.debug {
139                    let index_v = index.materialize(self);
140                    let len_v = len.materialize(self);
141                    let valid = self.lt(index_v, len_v);
142                    self.assert_var_eq(valid, C::N::one());
143                }
144                let index = MemIndex { index, offset: 0, size: V::size_of() };
145                let var: V = self.uninit();
146                self.load(var.clone(), *ptr, index);
147                var
148            }
149        }
150    }
151
152    pub fn get_ptr<V: MemVariable<C>, I: Into<Usize<C::N>>>(
153        &mut self,
154        slice: &Array<C, V>,
155        index: I,
156    ) -> Ptr<C::N> {
157        let index = index.into();
158
159        match slice {
160            Array::Fixed(_) => {
161                todo!()
162            }
163            Array::Dyn(ptr, len) => {
164                if self.debug {
165                    let index_v = index.materialize(self);
166                    let len_v = len.materialize(self);
167                    let valid = self.lt(index_v, len_v);
168                    self.assert_var_eq(valid, C::N::one());
169                }
170                let index = MemIndex { index, offset: 0, size: V::size_of() };
171                let var: Ptr<C::N> = self.uninit();
172                self.load(var, *ptr, index);
173                var
174            }
175        }
176    }
177
178    pub fn set<V: MemVariable<C>, I: Into<Usize<C::N>>, Expr: Into<V::Expression>>(
179        &mut self,
180        slice: &mut Array<C, V>,
181        index: I,
182        value: Expr,
183    ) {
184        let index = index.into();
185
186        match slice {
187            Array::Fixed(_) => {
188                todo!()
189            }
190            Array::Dyn(ptr, len) => {
191                if self.debug {
192                    let index_v = index.materialize(self);
193                    let len_v = len.materialize(self);
194                    let valid = self.lt(index_v, len_v);
195                    self.assert_var_eq(valid, C::N::one());
196                }
197                let index = MemIndex { index, offset: 0, size: V::size_of() };
198                let value: V = self.eval(value);
199                self.store(*ptr, index, value);
200            }
201        }
202    }
203
204    pub fn set_value<V: MemVariable<C>, I: Into<Usize<C::N>>>(
205        &mut self,
206        slice: &mut Array<C, V>,
207        index: I,
208        value: V,
209    ) {
210        let index = index.into();
211
212        match slice {
213            Array::Fixed(_) => {
214                todo!()
215            }
216            Array::Dyn(ptr, _) => {
217                let index = MemIndex { index, offset: 0, size: V::size_of() };
218                self.store(*ptr, index, value);
219            }
220        }
221    }
222}
223
224impl<C: Config, T: MemVariable<C>> Variable<C> for Array<C, T> {
225    type Expression = Self;
226
227    fn uninit(builder: &mut Builder<C>) -> Self {
228        Array::Dyn(builder.uninit(), builder.uninit())
229    }
230
231    fn assign(&self, src: Self::Expression, builder: &mut Builder<C>) {
232        match (self, src.clone()) {
233            (Array::Dyn(lhs_ptr, lhs_len), Array::Dyn(rhs_ptr, rhs_len)) => {
234                builder.assign(*lhs_ptr, rhs_ptr);
235                builder.assign(*lhs_len, rhs_len);
236            }
237            _ => unreachable!(),
238        }
239    }
240
241    fn assert_eq(
242        lhs: impl Into<Self::Expression>,
243        rhs: impl Into<Self::Expression>,
244        builder: &mut Builder<C>,
245    ) {
246        let lhs = lhs.into();
247        let rhs = rhs.into();
248
249        match (lhs.clone(), rhs.clone()) {
250            (Array::Fixed(lhs), Array::Fixed(rhs)) => {
251                for (l, r) in lhs.iter().zip_eq(rhs.iter()) {
252                    T::assert_eq(
253                        T::Expression::from(l.clone()),
254                        T::Expression::from(r.clone()),
255                        builder,
256                    );
257                }
258            }
259            (Array::Dyn(_, lhs_len), Array::Dyn(_, rhs_len)) => {
260                let lhs_len_var = builder.materialize(lhs_len);
261                let rhs_len_var = builder.materialize(rhs_len);
262                builder.assert_eq::<Var<_>>(lhs_len_var, rhs_len_var);
263
264                let start = Usize::Const(0);
265                let end = lhs_len;
266                builder.range(start, end).for_each(|i, builder| {
267                    let a = builder.get(&lhs, i);
268                    let b = builder.get(&rhs, i);
269                    builder.assert_eq::<T>(a, b);
270                });
271            }
272            _ => panic!("cannot compare arrays of different types"),
273        }
274    }
275
276    fn assert_ne(
277        lhs: impl Into<Self::Expression>,
278        rhs: impl Into<Self::Expression>,
279        builder: &mut Builder<C>,
280    ) {
281        let lhs = lhs.into();
282        let rhs = rhs.into();
283
284        match (lhs.clone(), rhs.clone()) {
285            (Array::Fixed(lhs), Array::Fixed(rhs)) => {
286                for (l, r) in lhs.iter().zip_eq(rhs.iter()) {
287                    T::assert_ne(
288                        T::Expression::from(l.clone()),
289                        T::Expression::from(r.clone()),
290                        builder,
291                    );
292                }
293            }
294            (Array::Dyn(_, lhs_len), Array::Dyn(_, rhs_len)) => {
295                builder.assert_usize_eq(lhs_len, rhs_len);
296
297                let end = lhs_len;
298                builder.range(0, end).for_each(|i, builder| {
299                    let a = builder.get(&lhs, i);
300                    let b = builder.get(&rhs, i);
301                    builder.assert_ne::<T>(a, b);
302                });
303            }
304            _ => panic!("cannot compare arrays of different types"),
305        }
306    }
307}
308
309impl<C: Config, T: MemVariable<C>> MemVariable<C> for Array<C, T> {
310    fn size_of() -> usize {
311        2
312    }
313
314    fn load(&self, src: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
315        match self {
316            Array::Dyn(dst, Usize::Var(len)) => {
317                let mut index = index;
318                dst.load(src, index, builder);
319                index.offset += <Ptr<C::N> as MemVariable<C>>::size_of();
320                len.load(src, index, builder);
321            }
322            _ => unreachable!(),
323        }
324    }
325
326    fn store(&self, dst: Ptr<<C as Config>::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
327        match self {
328            Array::Dyn(src, Usize::Var(len)) => {
329                let mut index = index;
330                src.store(dst, index, builder);
331                index.offset += <Ptr<C::N> as MemVariable<C>>::size_of();
332                len.store(dst, index, builder);
333            }
334            _ => unreachable!(),
335        }
336    }
337}
338
339impl<C: Config, V: FromConstant<C> + MemVariable<C>> FromConstant<C> for Array<C, V> {
340    type Constant = Vec<V::Constant>;
341
342    fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
343        let mut array = builder.dyn_array(value.len());
344        for (i, val) in value.into_iter().enumerate() {
345            let val = V::constant(val, builder);
346            builder.set(&mut array, i, val);
347        }
348        array
349    }
350}
351
352impl<C: Config, V: FromConstant<C> + MemVariable<C>> FromConstant<C> for Vec<V> {
353    type Constant = Vec<V::Constant>;
354
355    fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
356        value.into_iter().map(|x| V::constant(x, builder)).collect()
357    }
358}
359
360impl<C: Config, V: FromConstant<C> + MemVariable<C>, const N: usize> FromConstant<C> for [V; N] {
361    type Constant = [V::Constant; N];
362
363    fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
364        value.map(|x| V::constant(x, builder))
365    }
366}