vector_expr/
lib.rs

//! Vectorized math expression parser/evaluator.
//!
//! # Why?
//!
//! Performance. Evaluating math expressions that involve many variables can
//! incur significant overhead from traversing the expression tree or performing
//! variable lookups. We amortize that cost by performing each intermediate
//! operation on whole _vectors_ of input data at a time (with optional data
//! parallelism via the `rayon` feature).
//!
//! # Example
//!
//! ```rust
//! use vector_expr::*;
//!
//! // Each variable name maps to its index in the `bindings` slice below.
//! fn binding_map(var_name: &str) -> BindingId {
//!     match var_name {
//!         "bar" => 0,
//!         "baz" => 1,
//!         "foo" => 2,
//!         _ => unreachable!(),
//!     }
//! }
//! let parsed = Expression::parse("2 * (foo + bar) * baz", binding_map).unwrap();
//! let real = parsed.unwrap_real();
//!
//! let bar = [1.0, 2.0, 3.0];
//! let baz = [4.0, 5.0, 6.0];
//! let foo = [7.0, 8.0, 9.0];
//! let bindings: &[&[f64]] = &[&bar, &baz, &foo];
//! let mut registers = Registers::new(3);
//! let output = real.evaluate(bindings, &mut registers);
//! assert_eq!(&output, &[64.0, 100.0, 144.0]);
//! ```
35
36mod evaluate;
37mod expression;
38mod parse;
39
40/// Uses the [`pest`] parsing expression grammar language.
41///
42/// ```text
43#[doc = include_str!("grammar.pest")]
44/// ```
45pub mod grammar_doc {}
46
47pub use evaluate::*;
48pub use expression::*;
49pub use parse::ParseError;
50
51/// Pass to `Expression::parse` if the expression has no variables.
52pub fn empty_binding_map(_var_name: &str) -> BindingId {
53    panic!("Empty binding map")
54}
55
56pub trait FloatExt: num_traits::Float + std::str::FromStr + Send + Sync {}
57impl FloatExt for f32 {}
58impl FloatExt for f64 {}
59
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn real_expression() {
        fn binding_map(var_name: &str) -> BindingId {
            match var_name {
                "bar" => 0,
                "baz" => 1,
                "foo" => 2,
                _ => unreachable!(),
            }
        }
        let parsed = Expression::parse("2 * (foo + bar) * -baz", binding_map).unwrap();
        let real = parsed.unwrap_real();

        let bar = [1.0, 2.0, 3.0];
        let baz = [4.0, 5.0, 6.0];
        let foo = [7.0, 8.0, 9.0];
        let bindings = &[bar, baz, foo];
        let mut registers = Registers::new(3);
        let output = real.evaluate(bindings, &mut registers);
        assert_eq!(&output, &[-64.0, -100.0, -144.0]);
        assert_eq!(registers.num_allocations(), 3);
    }

    #[test]
    fn real_op_precedence() {
        let mut registers = Registers::new(1);

        // `*` binds tighter than `+`.
        let parsed = Expression::<f32>::parse("1 * 2 + 3 * 4", empty_binding_map).unwrap();
        let real = parsed.unwrap_real();
        let output = real.evaluate_without_vars(&mut registers);
        assert_eq!(&output, &[14.0]);

        // `/` and `*` are left-associative: (8 / 4) * 3.
        let parsed = Expression::<f32>::parse("8 / 4 * 3", empty_binding_map).unwrap();
        let real = parsed.unwrap_real();
        let output = real.evaluate_without_vars(&mut registers);
        assert_eq!(&output, &[6.0]);

        // `^` is right-associative: 4 ^ (3 ^ 2) = 4 ^ 9.
        let parsed = Expression::<f32>::parse("4 ^ 3 ^ 2", empty_binding_map).unwrap();
        let real = parsed.unwrap_real();
        let output = real.evaluate_without_vars(&mut registers);
        assert_eq!(&output, &[262144.0]);
    }

    #[test]
    #[should_panic(expected = "Empty binding map")]
    fn empty_binding_map_panics() {
        empty_binding_map("x");
    }

    #[test]
    fn bool_expression_with_real_bindings() {
        fn binding_map(var_name: &str) -> BindingId {
            match var_name {
                "bar" => 0,
                "baz" => 1,
                "foo" => 2,
                _ => unreachable!(),
            }
        }
        let parsed = Expression::parse("!(bar < foo && bar < baz)", binding_map).unwrap();
        let bool = parsed.unwrap_bool();

        let bar = [1.0, 6.0, 7.0];
        let baz = [2.0, 5.0, 8.0];
        let foo = [3.0, 4.0, 9.0];
        let bindings = &[bar, baz, foo];
        let mut registers = Registers::new(3);
        let output = bool.evaluate::<_, [_; 0]>(bindings, &[], |_| unreachable!(), &mut registers);
        assert_eq!([output[0], output[1], output[2]], [false, true, false]);
        assert_eq!(registers.num_allocations(), 3);
    }

    #[test]
    fn bool_expression_with_real_and_string_bindings() {
        // `foo` indexes the string bindings and `bar` indexes the real bindings,
        // so both can legitimately map to index 0.
        fn binding_map(var_name: &str) -> BindingId {
            match var_name {
                "foo" => 0,
                "bar" => 0,
                _ => unreachable!(),
            }
        }
        let parsed = Expression::parse("foo == \"foo_123\" && bar > 2", binding_map).unwrap();
        let bool = parsed.unwrap_bool();

        fn string_literal_id(value: &str) -> StringId {
            match value {
                "foo_123" => 0,
                _ => unreachable!(),
            }
        }

        let foo = [0, 1, 0];
        let bar = [1.0, 2.0, 3.0];
        let real_bindings = &[bar];
        let string_bindings = &[foo];
        let mut registers = Registers::new(3);
        let output = bool.evaluate(
            real_bindings,
            string_bindings,
            string_literal_id,
            &mut registers,
        );
        assert_eq!([output[0], output[1], output[2]], [false, false, true]);
        assert_eq!(registers.num_allocations(), 5);
    }

    #[test]
    fn naive_allocations_limited_by_recycling() {
        fn binding_map(var_name: &str) -> BindingId {
            match var_name {
                "bar" => 0,
                "baz" => 1,
                "foo" => 2,
                _ => unreachable!(),
            }
        }
        let parsed = Expression::parse(
            "foo + bar + baz + foo + bar + baz + foo + bar + baz",
            binding_map,
        )
        .unwrap();
        let real = parsed.unwrap_real();

        let bar = [1.0, 2.0, 3.0];
        let baz = [4.0, 5.0, 6.0];
        let foo = [7.0, 8.0, 9.0];
        let bindings = &[bar, baz, foo];
        let mut registers = Registers::new(3);
        let output = real.evaluate(bindings, &mut registers);
        assert_eq!(&output, &[36.0, 45.0, 54.0]);
        assert_eq!(registers.num_allocations(), 2);
    }

    /// Rough throughput benchmark over 10 million elements.
    ///
    /// Ignored by default: the three input vectors alone are 3 × 10M `f32`
    /// (~120 MB), and it is very slow in debug builds. Run on demand with
    /// `cargo test --release -- --ignored real_bench --nocapture`.
    #[test]
    #[ignore = "benchmark; run with `cargo test --release -- --ignored real_bench --nocapture`"]
    fn real_bench() {
        fn binding_map(var_name: &str) -> BindingId {
            match var_name {
                "x" => 0,
                "y" => 1,
                "z" => 2,
                var => panic!("Unexpected variable: {var}"),
            }
        }
        let parsed = Expression::parse("(z + (z^2 - 4*x*y)^0.5) / (2*x)", binding_map).unwrap();
        let real = parsed.unwrap_real();

        const LEN: i32 = 10_000_000;
        let x: Vec<_> = (0..LEN).map(|i| i as f32).collect();
        let y: Vec<_> = (0..LEN).map(|i| (LEN - i) as f32).collect();
        let z: Vec<_> = (0..LEN).map(|i| ((LEN / 2) - i) as f32).collect();
        let bindings = &[x, y, z];

        let mut registers = Registers::new(LEN as usize);
        let start = std::time::Instant::now();
        let _output = real.evaluate(bindings, &mut registers);
        let elapsed = start.elapsed();
        // Derive the per-element cost from nanoseconds; going through whole
        // milliseconds first would truncate the measurement.
        println!(
            "Took {} ms, {} ns per element",
            elapsed.as_millis(),
            elapsed.as_nanos() / LEN as u128
        );
        assert_eq!(registers.num_allocations(), 3);
    }
}