1use itertools::Itertools;
2use p3_field::AbstractField;
3
4use super::{Builder, Config, FromConstant, MemIndex, MemVariable, Ptr, Usize, Var, Variable};
5
/// An array of `T` values that is either fixed or dynamic.
///
/// Several operations in this file are only implemented for one of the two
/// variants and panic or hit `todo!()` on the other.
#[derive(Debug, Clone)]
pub enum Array<C: Config, T> {
    /// The elements are held directly in a `Vec`.
    Fixed(Vec<T>),
    /// A pointer together with the array length, as produced by
    /// `Builder::dyn_array` from a call to `alloc`.
    Dyn(Ptr<C::N>, Usize<C::N>),
}
12
13impl<C: Config, V: MemVariable<C>> Array<C, V> {
14 pub fn vec(&self) -> Vec<V> {
16 match self {
17 Self::Fixed(vec) => vec.clone(),
18 _ => panic!("array is dynamic, not fixed"),
19 }
20 }
21
22 pub fn len(&self) -> Usize<C::N> {
24 match self {
25 Self::Fixed(vec) => Usize::from(vec.len()),
26 Self::Dyn(_, len) => *len,
27 }
28 }
29
30 pub fn shift(&self, builder: &mut Builder<C>, shift: Var<C::N>) -> Array<C, V> {
32 match self {
33 Self::Fixed(_) => {
34 todo!()
35 }
36 Self::Dyn(ptr, len) => {
37 assert!(V::size_of() == 1, "only support variables of size 1");
38 let new_address = builder.eval(ptr.address + shift);
39 let new_ptr = Ptr::<C::N> { address: new_address };
40 let len_var = len.materialize(builder);
41 let new_length = builder.eval(len_var - shift);
42 Array::Dyn(new_ptr, Usize::Var(new_length))
43 }
44 }
45 }
46
47 pub fn truncate(&self, builder: &mut Builder<C>, len: Usize<C::N>) {
49 match self {
50 Self::Fixed(_) => {
51 todo!()
52 }
53 Self::Dyn(_, old_len) => {
54 builder.assign(*old_len, len);
55 }
56 };
57 }
58
59 pub fn slice(
60 &self,
61 builder: &mut Builder<C>,
62 start: Usize<C::N>,
63 end: Usize<C::N>,
64 ) -> Array<C, V> {
65 match self {
66 Self::Fixed(vec) => {
67 if let (Usize::Const(start), Usize::Const(end)) = (start, end) {
68 builder.vec(vec[start..end].to_vec())
69 } else {
70 panic!("Cannot slice a fixed array with a variable start or end");
71 }
72 }
73 Self::Dyn(_, len) => {
74 if builder.debug {
75 let start_v = start.materialize(builder);
76 let end_v = end.materialize(builder);
77 let valid = builder.lt(start_v, end_v);
78 builder.assert_var_eq(valid, C::N::one());
79
80 let len_v = len.materialize(builder);
81 let len_plus_1_v = builder.eval(len_v + C::N::one());
82 let valid = builder.lt(end_v, len_plus_1_v);
83 builder.assert_var_eq(valid, C::N::one());
84 }
85
86 let slice_len: Usize<_> = builder.eval(end - start);
87 let mut slice = builder.dyn_array(slice_len);
88 builder.range(0, slice_len).for_each(|i, builder| {
89 let idx: Usize<_> = builder.eval(start + i);
90 let value = builder.get(self, idx);
91 builder.set(&mut slice, i, value);
92 });
93
94 slice
95 }
96 }
97 }
98}
99
100impl<C: Config> Builder<C> {
101 pub fn array<V: MemVariable<C>>(&mut self, len: impl Into<Usize<C::N>>) -> Array<C, V> {
103 self.dyn_array(len)
104 }
105
106 pub fn vec<V: MemVariable<C>>(&mut self, v: Vec<V>) -> Array<C, V> {
108 Array::Fixed(v)
109 }
110
111 pub fn dyn_array<V: MemVariable<C>>(&mut self, len: impl Into<Usize<C::N>>) -> Array<C, V> {
113 let len = match len.into() {
114 Usize::Const(len) => self.eval(C::N::from_canonical_usize(len)),
115 Usize::Var(len) => len,
116 };
117 let len = Usize::Var(len);
118 let ptr = self.alloc(len, V::size_of());
119 Array::Dyn(ptr, len)
120 }
121
122 pub fn get<V: MemVariable<C>, I: Into<Usize<C::N>>>(
123 &mut self,
124 slice: &Array<C, V>,
125 index: I,
126 ) -> V {
127 let index = index.into();
128
129 match slice {
130 Array::Fixed(slice) => {
131 if let Usize::Const(idx) = index {
132 slice[idx].clone()
133 } else {
134 panic!("Cannot index into a fixed slice with a variable size")
135 }
136 }
137 Array::Dyn(ptr, len) => {
138 if self.debug {
139 let index_v = index.materialize(self);
140 let len_v = len.materialize(self);
141 let valid = self.lt(index_v, len_v);
142 self.assert_var_eq(valid, C::N::one());
143 }
144 let index = MemIndex { index, offset: 0, size: V::size_of() };
145 let var: V = self.uninit();
146 self.load(var.clone(), *ptr, index);
147 var
148 }
149 }
150 }
151
152 pub fn get_ptr<V: MemVariable<C>, I: Into<Usize<C::N>>>(
153 &mut self,
154 slice: &Array<C, V>,
155 index: I,
156 ) -> Ptr<C::N> {
157 let index = index.into();
158
159 match slice {
160 Array::Fixed(_) => {
161 todo!()
162 }
163 Array::Dyn(ptr, len) => {
164 if self.debug {
165 let index_v = index.materialize(self);
166 let len_v = len.materialize(self);
167 let valid = self.lt(index_v, len_v);
168 self.assert_var_eq(valid, C::N::one());
169 }
170 let index = MemIndex { index, offset: 0, size: V::size_of() };
171 let var: Ptr<C::N> = self.uninit();
172 self.load(var, *ptr, index);
173 var
174 }
175 }
176 }
177
178 pub fn set<V: MemVariable<C>, I: Into<Usize<C::N>>, Expr: Into<V::Expression>>(
179 &mut self,
180 slice: &mut Array<C, V>,
181 index: I,
182 value: Expr,
183 ) {
184 let index = index.into();
185
186 match slice {
187 Array::Fixed(_) => {
188 todo!()
189 }
190 Array::Dyn(ptr, len) => {
191 if self.debug {
192 let index_v = index.materialize(self);
193 let len_v = len.materialize(self);
194 let valid = self.lt(index_v, len_v);
195 self.assert_var_eq(valid, C::N::one());
196 }
197 let index = MemIndex { index, offset: 0, size: V::size_of() };
198 let value: V = self.eval(value);
199 self.store(*ptr, index, value);
200 }
201 }
202 }
203
204 pub fn set_value<V: MemVariable<C>, I: Into<Usize<C::N>>>(
205 &mut self,
206 slice: &mut Array<C, V>,
207 index: I,
208 value: V,
209 ) {
210 let index = index.into();
211
212 match slice {
213 Array::Fixed(_) => {
214 todo!()
215 }
216 Array::Dyn(ptr, _) => {
217 let index = MemIndex { index, offset: 0, size: V::size_of() };
218 self.store(*ptr, index, value);
219 }
220 }
221 }
222}
223
224impl<C: Config, T: MemVariable<C>> Variable<C> for Array<C, T> {
225 type Expression = Self;
226
227 fn uninit(builder: &mut Builder<C>) -> Self {
228 Array::Dyn(builder.uninit(), builder.uninit())
229 }
230
231 fn assign(&self, src: Self::Expression, builder: &mut Builder<C>) {
232 match (self, src.clone()) {
233 (Array::Dyn(lhs_ptr, lhs_len), Array::Dyn(rhs_ptr, rhs_len)) => {
234 builder.assign(*lhs_ptr, rhs_ptr);
235 builder.assign(*lhs_len, rhs_len);
236 }
237 _ => unreachable!(),
238 }
239 }
240
241 fn assert_eq(
242 lhs: impl Into<Self::Expression>,
243 rhs: impl Into<Self::Expression>,
244 builder: &mut Builder<C>,
245 ) {
246 let lhs = lhs.into();
247 let rhs = rhs.into();
248
249 match (lhs.clone(), rhs.clone()) {
250 (Array::Fixed(lhs), Array::Fixed(rhs)) => {
251 for (l, r) in lhs.iter().zip_eq(rhs.iter()) {
252 T::assert_eq(
253 T::Expression::from(l.clone()),
254 T::Expression::from(r.clone()),
255 builder,
256 );
257 }
258 }
259 (Array::Dyn(_, lhs_len), Array::Dyn(_, rhs_len)) => {
260 let lhs_len_var = builder.materialize(lhs_len);
261 let rhs_len_var = builder.materialize(rhs_len);
262 builder.assert_eq::<Var<_>>(lhs_len_var, rhs_len_var);
263
264 let start = Usize::Const(0);
265 let end = lhs_len;
266 builder.range(start, end).for_each(|i, builder| {
267 let a = builder.get(&lhs, i);
268 let b = builder.get(&rhs, i);
269 builder.assert_eq::<T>(a, b);
270 });
271 }
272 _ => panic!("cannot compare arrays of different types"),
273 }
274 }
275
276 fn assert_ne(
277 lhs: impl Into<Self::Expression>,
278 rhs: impl Into<Self::Expression>,
279 builder: &mut Builder<C>,
280 ) {
281 let lhs = lhs.into();
282 let rhs = rhs.into();
283
284 match (lhs.clone(), rhs.clone()) {
285 (Array::Fixed(lhs), Array::Fixed(rhs)) => {
286 for (l, r) in lhs.iter().zip_eq(rhs.iter()) {
287 T::assert_ne(
288 T::Expression::from(l.clone()),
289 T::Expression::from(r.clone()),
290 builder,
291 );
292 }
293 }
294 (Array::Dyn(_, lhs_len), Array::Dyn(_, rhs_len)) => {
295 builder.assert_usize_eq(lhs_len, rhs_len);
296
297 let end = lhs_len;
298 builder.range(0, end).for_each(|i, builder| {
299 let a = builder.get(&lhs, i);
300 let b = builder.get(&rhs, i);
301 builder.assert_ne::<T>(a, b);
302 });
303 }
304 _ => panic!("cannot compare arrays of different types"),
305 }
306 }
307}
308
309impl<C: Config, T: MemVariable<C>> MemVariable<C> for Array<C, T> {
310 fn size_of() -> usize {
311 2
312 }
313
314 fn load(&self, src: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
315 match self {
316 Array::Dyn(dst, Usize::Var(len)) => {
317 let mut index = index;
318 dst.load(src, index, builder);
319 index.offset += <Ptr<C::N> as MemVariable<C>>::size_of();
320 len.load(src, index, builder);
321 }
322 _ => unreachable!(),
323 }
324 }
325
326 fn store(&self, dst: Ptr<<C as Config>::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
327 match self {
328 Array::Dyn(src, Usize::Var(len)) => {
329 let mut index = index;
330 src.store(dst, index, builder);
331 index.offset += <Ptr<C::N> as MemVariable<C>>::size_of();
332 len.store(dst, index, builder);
333 }
334 _ => unreachable!(),
335 }
336 }
337}
338
339impl<C: Config, V: FromConstant<C> + MemVariable<C>> FromConstant<C> for Array<C, V> {
340 type Constant = Vec<V::Constant>;
341
342 fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
343 let mut array = builder.dyn_array(value.len());
344 for (i, val) in value.into_iter().enumerate() {
345 let val = V::constant(val, builder);
346 builder.set(&mut array, i, val);
347 }
348 array
349 }
350}
351
352impl<C: Config, V: FromConstant<C> + MemVariable<C>> FromConstant<C> for Vec<V> {
353 type Constant = Vec<V::Constant>;
354
355 fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
356 value.into_iter().map(|x| V::constant(x, builder)).collect()
357 }
358}
359
360impl<C: Config, V: FromConstant<C> + MemVariable<C>, const N: usize> FromConstant<C> for [V; N] {
361 type Constant = [V::Constant; N];
362
363 fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
364 value.map(|x| V::constant(x, builder))
365 }
366}