vector_expression_ref/
vector_expression_ref.rs

1use std::ops::{Index, Range};
2use xpr::{ops::Term, Expression, Fold, Xpr};
3
/// A fixed-length vector of `N` `f64` values stored behind a `Box`.
///
/// Because the element array is boxed, moving a `Vec` moves only the pointer,
/// however large `N` is.
#[derive(Debug)]
struct Vec<const N: usize>(Box<[f64; N]>);

impl<const N: usize> Vec<N> {
    /// Moves `array` into a fresh heap allocation and wraps it.
    #[inline]
    fn new(array: [f64; N]) -> Self {
        let boxed: Box<[f64; N]> = Box::from(array);
        Vec(boxed)
    }
}

/// Range indexing (`&v[a..b]`) borrows a sub-slice of the elements.
/// Panics if the range is out of bounds, exactly like slice indexing.
impl<const N: usize> Index<Range<usize>> for Vec<N> {
    type Output = [f64];

    #[inline]
    fn index(&self, range: Range<usize>) -> &[f64] {
        let elements: &[f64] = &*self.0;
        &elements[range]
    }
}

/// Scalar indexing (`v[i]`) borrows a single element.
/// Panics if `i >= N`, exactly like slice indexing.
impl<const N: usize> Index<usize> for Vec<N> {
    type Output = f64;

    #[inline]
    fn index(&self, position: usize) -> &f64 {
        let elements: &[f64] = &*self.0;
        &elements[position]
    }
}

33// a convenience trait for cnverting Vec instances to xpr terminals
34impl<const N: usize> Expression for Vec<N> {}
35
// Now let's implement conversion from an `Xpr<T>` expression back to a `Vec`.

/// A folder that evaluates a vector expression at one element position.
///
/// `N` is not stored; it only selects the vector length this folder is used with.
struct IthElement<'a, const N: usize>(
    // Position of the element to extract from every vector terminal.
    usize,
    // Holds no data. It carries the lifetime `'a` of the borrowed `&'a Vec` terminals.
    std::marker::PhantomData<&'a ()>,
);

40impl<'a, const N: usize> Fold<Term<&'a Vec<{ N }>>> for IthElement<'a, { N }> {
41    // replace by the value at the index in `IthElement`
42    type Output = f64;
43
44    // extracts the i-th element of a vector terminal
45    #[inline]
46    fn fold(&mut self, Term(v): &Term<&'a Vec<{ N }>>) -> f64 {
47        v[self.0]
48    }
49}
50
51impl<'a, T, const N: usize> From<Xpr<T>> for Vec<{ N }>
52where
53    IthElement<'a, N>: Fold<Xpr<T>, Output = f64>,
54{
55    // conversion from a vector expression to a Vec instance
56    #[inline]
57    fn from(expr: Xpr<T>) -> Self {
58        // scary unsafe uninitialized array
59        let mut ret = Vec::new(unsafe { std::mem::MaybeUninit::uninit().assume_init() });
60
61        // apply the operations in the vector expression element-wise
62        for (i, e) in ret.0.iter_mut().enumerate() {
63            *e = IthElement(i, std::marker::PhantomData).fold(&expr);
64        }
65        ret
66    }
67}
68
69pub fn main() {
70    // Create a couple of vectors
71    let x1 = Vec::new([0.6; 5000]);
72    let x2 = Vec::new([1.0; 5000]);
73    let x3 = Vec::new([40.0; 5000]);
74    let x4 = Vec::new([100.0; 5000]);
75    let x5 = Vec::new([3000.0; 5000]);
76
77    // A chained addition without any Vec temporaries!
78    let v = Vec::from(x1.as_xpr() + &x2 + &x3 + &x4 + &x5);
79    println!("v[0..5] = {:?}", &v[0..5]);
80}