use itertools::Itertools;
use p3_field::AbstractField;
use super::{Builder, Config, FromConstant, MemIndex, MemVariable, Ptr, Usize, Var, Variable};
#[derive(Debug, Clone)]
pub enum Array<C: Config, T> {
Fixed(Vec<T>),
Dyn(Ptr<C::N>, Usize<C::N>),
}
impl<C: Config, V: MemVariable<C>> Array<C, V> {
pub fn vec(&self) -> Vec<V> {
match self {
Self::Fixed(vec) => vec.clone(),
_ => panic!("array is dynamic, not fixed"),
}
}
pub fn len(&self) -> Usize<C::N> {
match self {
Self::Fixed(vec) => Usize::from(vec.len()),
Self::Dyn(_, len) => *len,
}
}
pub fn shift(&self, builder: &mut Builder<C>, shift: Var<C::N>) -> Array<C, V> {
match self {
Self::Fixed(_) => {
todo!()
}
Self::Dyn(ptr, len) => {
assert!(V::size_of() == 1, "only support variables of size 1");
let new_address = builder.eval(ptr.address + shift);
let new_ptr = Ptr::<C::N> { address: new_address };
let len_var = len.materialize(builder);
let new_length = builder.eval(len_var - shift);
Array::Dyn(new_ptr, Usize::Var(new_length))
}
}
}
pub fn truncate(&self, builder: &mut Builder<C>, len: Usize<C::N>) {
match self {
Self::Fixed(_) => {
todo!()
}
Self::Dyn(_, old_len) => {
builder.assign(*old_len, len);
}
};
}
pub fn slice(
&self,
builder: &mut Builder<C>,
start: Usize<C::N>,
end: Usize<C::N>,
) -> Array<C, V> {
match self {
Self::Fixed(vec) => {
if let (Usize::Const(start), Usize::Const(end)) = (start, end) {
builder.vec(vec[start..end].to_vec())
} else {
panic!("Cannot slice a fixed array with a variable start or end");
}
}
Self::Dyn(_, len) => {
if builder.debug {
let start_v = start.materialize(builder);
let end_v = end.materialize(builder);
let valid = builder.lt(start_v, end_v);
builder.assert_var_eq(valid, C::N::one());
let len_v = len.materialize(builder);
let len_plus_1_v = builder.eval(len_v + C::N::one());
let valid = builder.lt(end_v, len_plus_1_v);
builder.assert_var_eq(valid, C::N::one());
}
let slice_len: Usize<_> = builder.eval(end - start);
let mut slice = builder.dyn_array(slice_len);
builder.range(0, slice_len).for_each(|i, builder| {
let idx: Usize<_> = builder.eval(start + i);
let value = builder.get(self, idx);
builder.set(&mut slice, i, value);
});
slice
}
}
}
}
impl<C: Config> Builder<C> {
pub fn array<V: MemVariable<C>>(&mut self, len: impl Into<Usize<C::N>>) -> Array<C, V> {
self.dyn_array(len)
}
pub fn vec<V: MemVariable<C>>(&mut self, v: Vec<V>) -> Array<C, V> {
Array::Fixed(v)
}
pub fn dyn_array<V: MemVariable<C>>(&mut self, len: impl Into<Usize<C::N>>) -> Array<C, V> {
let len = match len.into() {
Usize::Const(len) => self.eval(C::N::from_canonical_usize(len)),
Usize::Var(len) => len,
};
let len = Usize::Var(len);
let ptr = self.alloc(len, V::size_of());
Array::Dyn(ptr, len)
}
pub fn get<V: MemVariable<C>, I: Into<Usize<C::N>>>(
&mut self,
slice: &Array<C, V>,
index: I,
) -> V {
let index = index.into();
match slice {
Array::Fixed(slice) => {
if let Usize::Const(idx) = index {
slice[idx].clone()
} else {
panic!("Cannot index into a fixed slice with a variable size")
}
}
Array::Dyn(ptr, len) => {
if self.debug {
let index_v = index.materialize(self);
let len_v = len.materialize(self);
let valid = self.lt(index_v, len_v);
self.assert_var_eq(valid, C::N::one());
}
let index = MemIndex { index, offset: 0, size: V::size_of() };
let var: V = self.uninit();
self.load(var.clone(), *ptr, index);
var
}
}
}
pub fn get_ptr<V: MemVariable<C>, I: Into<Usize<C::N>>>(
&mut self,
slice: &Array<C, V>,
index: I,
) -> Ptr<C::N> {
let index = index.into();
match slice {
Array::Fixed(_) => {
todo!()
}
Array::Dyn(ptr, len) => {
if self.debug {
let index_v = index.materialize(self);
let len_v = len.materialize(self);
let valid = self.lt(index_v, len_v);
self.assert_var_eq(valid, C::N::one());
}
let index = MemIndex { index, offset: 0, size: V::size_of() };
let var: Ptr<C::N> = self.uninit();
self.load(var, *ptr, index);
var
}
}
}
pub fn set<V: MemVariable<C>, I: Into<Usize<C::N>>, Expr: Into<V::Expression>>(
&mut self,
slice: &mut Array<C, V>,
index: I,
value: Expr,
) {
let index = index.into();
match slice {
Array::Fixed(_) => {
todo!()
}
Array::Dyn(ptr, len) => {
if self.debug {
let index_v = index.materialize(self);
let len_v = len.materialize(self);
let valid = self.lt(index_v, len_v);
self.assert_var_eq(valid, C::N::one());
}
let index = MemIndex { index, offset: 0, size: V::size_of() };
let value: V = self.eval(value);
self.store(*ptr, index, value);
}
}
}
pub fn set_value<V: MemVariable<C>, I: Into<Usize<C::N>>>(
&mut self,
slice: &mut Array<C, V>,
index: I,
value: V,
) {
let index = index.into();
match slice {
Array::Fixed(_) => {
todo!()
}
Array::Dyn(ptr, _) => {
let index = MemIndex { index, offset: 0, size: V::size_of() };
self.store(*ptr, index, value);
}
}
}
}
impl<C: Config, T: MemVariable<C>> Variable<C> for Array<C, T> {
type Expression = Self;
fn uninit(builder: &mut Builder<C>) -> Self {
Array::Dyn(builder.uninit(), builder.uninit())
}
fn assign(&self, src: Self::Expression, builder: &mut Builder<C>) {
match (self, src.clone()) {
(Array::Dyn(lhs_ptr, lhs_len), Array::Dyn(rhs_ptr, rhs_len)) => {
builder.assign(*lhs_ptr, rhs_ptr);
builder.assign(*lhs_len, rhs_len);
}
_ => unreachable!(),
}
}
fn assert_eq(
lhs: impl Into<Self::Expression>,
rhs: impl Into<Self::Expression>,
builder: &mut Builder<C>,
) {
let lhs = lhs.into();
let rhs = rhs.into();
match (lhs.clone(), rhs.clone()) {
(Array::Fixed(lhs), Array::Fixed(rhs)) => {
for (l, r) in lhs.iter().zip_eq(rhs.iter()) {
T::assert_eq(
T::Expression::from(l.clone()),
T::Expression::from(r.clone()),
builder,
);
}
}
(Array::Dyn(_, lhs_len), Array::Dyn(_, rhs_len)) => {
let lhs_len_var = builder.materialize(lhs_len);
let rhs_len_var = builder.materialize(rhs_len);
builder.assert_eq::<Var<_>>(lhs_len_var, rhs_len_var);
let start = Usize::Const(0);
let end = lhs_len;
builder.range(start, end).for_each(|i, builder| {
let a = builder.get(&lhs, i);
let b = builder.get(&rhs, i);
builder.assert_eq::<T>(a, b);
});
}
_ => panic!("cannot compare arrays of different types"),
}
}
fn assert_ne(
lhs: impl Into<Self::Expression>,
rhs: impl Into<Self::Expression>,
builder: &mut Builder<C>,
) {
let lhs = lhs.into();
let rhs = rhs.into();
match (lhs.clone(), rhs.clone()) {
(Array::Fixed(lhs), Array::Fixed(rhs)) => {
for (l, r) in lhs.iter().zip_eq(rhs.iter()) {
T::assert_ne(
T::Expression::from(l.clone()),
T::Expression::from(r.clone()),
builder,
);
}
}
(Array::Dyn(_, lhs_len), Array::Dyn(_, rhs_len)) => {
builder.assert_usize_eq(lhs_len, rhs_len);
let end = lhs_len;
builder.range(0, end).for_each(|i, builder| {
let a = builder.get(&lhs, i);
let b = builder.get(&rhs, i);
builder.assert_ne::<T>(a, b);
});
}
_ => panic!("cannot compare arrays of different types"),
}
}
}
impl<C: Config, T: MemVariable<C>> MemVariable<C> for Array<C, T> {
fn size_of() -> usize {
2
}
fn load(&self, src: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
match self {
Array::Dyn(dst, Usize::Var(len)) => {
let mut index = index;
dst.load(src, index, builder);
index.offset += <Ptr<C::N> as MemVariable<C>>::size_of();
len.load(src, index, builder);
}
_ => unreachable!(),
}
}
fn store(&self, dst: Ptr<<C as Config>::N>, index: MemIndex<C::N>, builder: &mut Builder<C>) {
match self {
Array::Dyn(src, Usize::Var(len)) => {
let mut index = index;
src.store(dst, index, builder);
index.offset += <Ptr<C::N> as MemVariable<C>>::size_of();
len.store(dst, index, builder);
}
_ => unreachable!(),
}
}
}
impl<C: Config, V: FromConstant<C> + MemVariable<C>> FromConstant<C> for Array<C, V> {
type Constant = Vec<V::Constant>;
fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
let mut array = builder.dyn_array(value.len());
for (i, val) in value.into_iter().enumerate() {
let val = V::constant(val, builder);
builder.set(&mut array, i, val);
}
array
}
}
impl<C: Config, V: FromConstant<C> + MemVariable<C>> FromConstant<C> for Vec<V> {
type Constant = Vec<V::Constant>;
fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
value.into_iter().map(|x| V::constant(x, builder)).collect()
}
}
impl<C: Config, V: FromConstant<C> + MemVariable<C>, const N: usize> FromConstant<C> for [V; N] {
type Constant = [V::Constant; N];
fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
value.map(|x| V::constant(x, builder))
}
}