1mod evaluate;
37mod expression;
38mod parse;
39
40#[doc = include_str!("grammar.pest")]
44pub mod grammar_doc {}
46
47pub use evaluate::*;
48pub use expression::*;
49pub use parse::ParseError;
50
51pub fn empty_binding_map(_var_name: &str) -> BindingId {
53 panic!("Empty binding map")
54}
55
56pub trait FloatExt: num_traits::Float + std::str::FromStr + Send + Sync {}
57impl FloatExt for f32 {}
58impl FloatExt for f64 {}
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Binding map shared by the tests that bind `bar`, `baz` and `foo`
    /// (in that index order) to real-valued inputs.
    fn bar_baz_foo_binding_map(var_name: &str) -> BindingId {
        match var_name {
            "bar" => 0,
            "baz" => 1,
            "foo" => 2,
            var => panic!("Unexpected variable: {var}"),
        }
    }

    /// Binding map for the quadratic-root expression over `x`, `y` and `z`.
    fn xyz_binding_map(var_name: &str) -> BindingId {
        match var_name {
            "x" => 0,
            "y" => 1,
            "z" => 2,
            var => panic!("Unexpected variable: {var}"),
        }
    }

    /// Parses and evaluates `(z + (z^2 - 4*x*y)^0.5) / (2*x)` over `len`
    /// synthetic elements. Asserts that evaluation needed exactly 3 register
    /// allocations and returns the wall-clock time spent inside `evaluate`.
    fn evaluate_quadratic_root(len: i32) -> std::time::Duration {
        let parsed =
            Expression::parse("(z + (z^2 - 4*x*y)^0.5) / (2*x)", xyz_binding_map).unwrap();
        let real = parsed.unwrap_real();

        let x: Vec<_> = (0..len).map(|i| i as f32).collect();
        let y: Vec<_> = (0..len).map(|i| (len - i) as f32).collect();
        let z: Vec<_> = (0..len).map(|i| ((len / 2) - i) as f32).collect();
        let bindings = &[x, y, z];

        let mut registers = Registers::new(len as usize);
        let start = std::time::Instant::now();
        let _output = real.evaluate(bindings, &mut registers);
        let elapsed = start.elapsed();
        assert_eq!(registers.num_allocations(), 3);
        elapsed
    }

    #[test]
    fn real_expression() {
        let parsed =
            Expression::parse("2 * (foo + bar) * -baz", bar_baz_foo_binding_map).unwrap();
        let real = parsed.unwrap_real();

        let bar = [1.0, 2.0, 3.0];
        let baz = [4.0, 5.0, 6.0];
        let foo = [7.0, 8.0, 9.0];
        let bindings = &[bar, baz, foo];
        let mut registers = Registers::new(3);
        let output = real.evaluate(bindings, &mut registers);
        assert_eq!(&output, &[-64.0, -100.0, -144.0]);
        assert_eq!(registers.num_allocations(), 3);
    }

    #[test]
    fn real_op_precedence() {
        let mut registers = Registers::new(1);

        let parsed = Expression::<f32>::parse("1 * 2 + 3 * 4", empty_binding_map).unwrap();
        let real = parsed.unwrap_real();
        let output = real.evaluate_without_vars(&mut registers);
        assert_eq!(&output, &[14.0]);

        let parsed = Expression::<f32>::parse("8 / 4 * 3", empty_binding_map).unwrap();
        let real = parsed.unwrap_real();
        let output = real.evaluate_without_vars(&mut registers);
        assert_eq!(&output, &[6.0]);

        // `^` is right-associative: 4 ^ (3 ^ 2) = 4 ^ 9.
        let parsed = Expression::<f32>::parse("4 ^ 3 ^ 2", empty_binding_map).unwrap();
        let real = parsed.unwrap_real();
        let output = real.evaluate_without_vars(&mut registers);
        assert_eq!(&output, &[262144.0]);
    }

    #[test]
    fn bool_expression_with_real_bindings() {
        let parsed =
            Expression::parse("!(bar < foo && bar < baz)", bar_baz_foo_binding_map).unwrap();
        let predicate = parsed.unwrap_bool();

        let bar = [1.0, 6.0, 7.0];
        let baz = [2.0, 5.0, 8.0];
        let foo = [3.0, 4.0, 9.0];
        let bindings = &[bar, baz, foo];
        let mut registers = Registers::new(3);
        let output =
            predicate.evaluate::<_, [_; 0]>(bindings, &[], |_| unreachable!(), &mut registers);
        assert_eq!([output[0], output[1], output[2]], [false, true, false]);
        assert_eq!(registers.num_allocations(), 3);
    }

    #[test]
    fn bool_expression_with_real_and_string_bindings() {
        // `foo` indexes the string bindings and `bar` indexes the real bindings.
        // The two kinds are passed as separate slices below, so each has its own
        // index space and both variables legitimately map to slot 0.
        fn binding_map(var_name: &str) -> BindingId {
            match var_name {
                "foo" => 0,
                "bar" => 0,
                var => panic!("Unexpected variable: {var}"),
            }
        }
        let parsed = Expression::parse("foo == \"foo_123\" && bar > 2", binding_map).unwrap();
        let predicate = parsed.unwrap_bool();

        fn string_literal_id(value: &str) -> StringId {
            match value {
                "foo_123" => 0,
                _ => unreachable!(),
            }
        }

        let foo = [0, 1, 0];
        let bar = [1.0, 2.0, 3.0];
        let real_bindings = &[bar];
        let string_bindings = &[foo];
        let mut registers = Registers::new(3);
        let output = predicate.evaluate(
            real_bindings,
            string_bindings,
            string_literal_id,
            &mut registers,
        );
        assert_eq!([output[0], output[1], output[2]], [false, false, true]);
        assert_eq!(registers.num_allocations(), 5);
    }

    #[test]
    fn naive_allocations_limited_by_recycling() {
        let parsed = Expression::parse(
            "foo + bar + baz + foo + bar + baz + foo + bar + baz",
            bar_baz_foo_binding_map,
        )
        .unwrap();
        let real = parsed.unwrap_real();

        let bar = [1.0, 2.0, 3.0];
        let baz = [4.0, 5.0, 6.0];
        let foo = [7.0, 8.0, 9.0];
        let bindings = &[bar, baz, foo];
        let mut registers = Registers::new(3);
        let output = real.evaluate(bindings, &mut registers);
        assert_eq!(&output, &[36.0, 45.0, 54.0]);
        assert_eq!(registers.num_allocations(), 2);
    }

    /// Fast, always-run check of the register budget for the expression that
    /// `real_bench` times. It uses a small input so it stays cheap in debug builds.
    #[test]
    fn real_bench_expression_allocations() {
        evaluate_quadratic_root(16);
    }

    /// Throughput benchmark over 10 million elements.
    ///
    /// Ignored by default: the three input vectors alone take 120 MB, and the
    /// timing is only meaningful in an optimized build.
    #[test]
    #[ignore = "benchmark; run with `cargo test --release -- --ignored --nocapture real_bench`"]
    fn real_bench() {
        const LEN: i32 = 10_000_000;
        let elapsed = evaluate_quadratic_root(LEN);
        println!(
            "Took {} ms, {} ns per element",
            elapsed.as_millis(),
            elapsed.as_nanos() / LEN as u128
        );
    }
}