vector_expression/
vector_expression.rs

1use xpr::{ops::Term, Expression, Fold, Xpr};
2
/// A statically sized vector of `N` `f64` values, the basic building block a
/// linear algebra library needs.
///
/// The elements live behind a `Box`, so the struct itself is a single thin
/// pointer whatever `N` is. Note that this type shadows the standard
/// library's `Vec` throughout this module.
#[derive(Debug)]
struct Vec<const N: usize>(Box<[f64; N]>);

impl<const N: usize> Vec<N> {
    /// Creates a vector by moving `array` into a fresh heap allocation.
    #[inline]
    fn new(array: [f64; N]) -> Self {
        let storage: Box<[f64; N]> = Box::new(array);
        Vec(storage)
    }
}
14
15// a convenience trait for cnverting Vec instances to xpr terminals
16impl<const N: usize> Expression for Vec<N> {}
17
// Now let's implement conversion from an `Xpr<T>` expression to `Vec`.
19
20struct IthElement<const N: usize>(usize);
21
22// match all terminals wrapping a `Vec`
23impl<const N: usize> Fold<Term<Vec<{ N }>>> for IthElement<{ N }> {
24    // replace by the value at the index in `IthElement`
25    type Output = f64;
26
27    // extracts the i-th element of a vector terminal
28    #[inline]
29    fn fold(&mut self, Term(v): &Term<Vec<{ N }>>) -> f64 {
30        v.0[self.0]
31    }
32}
33
34impl<T, const N: usize> From<Xpr<T>> for Vec<{ N }>
35where
36    IthElement<N>: Fold<Xpr<T>, Output = f64>,
37{
38    // conversion from a vector expression to a Vec instance
39    #[inline]
40    fn from(expr: Xpr<T>) -> Self {
41        // scary unsafe uninitialized array
42        let mut ret = Vec::new(unsafe { std::mem::MaybeUninit::uninit().assume_init() });
43
44        // apply the operations in the vector expression element-wise
45        for (i, e) in ret.0.iter_mut().enumerate() {
46            *e = IthElement(i).fold(&expr);
47        }
48        ret
49    }
50}
51
52pub fn main() {
53    // Create a couple of vectors and convert to Xpr expressions
54    let x1 = Vec::new([0.6; 5000]).into_xpr();
55    let x2 = Vec::new([1.0; 5000]).into_xpr();
56    let x3 = Vec::new([40.0; 5000]).into_xpr();
57    let x4 = Vec::new([100.0; 5000]).into_xpr();
58    let x5 = Vec::new([3000.0; 5000]).into_xpr();
59
60    // A chained addition without any Vec temporaries!
61    let v = Vec::from(x1 + x2 + x3 + x4 + x5);
62    println!("v[0..5] = {:?}", &v.0[0..5]);
63}