1use crate::{BoolExpression, FloatExt, RealExpression, StringExpression};
2use bitvec::vec::BitVec;
3
4#[cfg(feature = "rayon")]
5use rayon::{
6 prelude::{
7 IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator,
8 ParallelExtend, ParallelIterator,
9 },
10 slice::ParallelSlice,
11};
12
/// Integer handle that stands in for a string value during evaluation.
///
/// String bindings are supplied as columns of `StringId`s. String literals inside an expression
/// are converted to a `StringId` through the caller-supplied `get_string_literal_id` callback.
/// String (in)equality is then evaluated by comparing these ids with `==` / `!=`.
/// NOTE(review): this relies on the caller mapping equal strings to equal ids (and distinct
/// strings to distinct ids) consistently across bindings and literals.
pub type StringId = u32;
15
16impl<Real: FloatExt> BoolExpression<Real> {
17 pub fn evaluate<R: AsRef<[Real]>, S: AsRef<[StringId]>>(
19 &self,
20 real_bindings: &[R],
21 string_bindings: &[S],
22 mut get_string_literal_id: impl FnMut(&str) -> StringId,
23 registers: &mut Registers<Real>,
24 ) -> BitVec {
25 validate_bindings(real_bindings, registers.register_length);
26 validate_bindings(string_bindings, registers.register_length);
27 self.evaluate_recursive(
28 real_bindings,
29 string_bindings,
30 &mut get_string_literal_id,
31 registers,
32 )
33 }
34
35 fn evaluate_recursive<R: AsRef<[Real]>, S: AsRef<[StringId]>>(
36 &self,
37 real_bindings: &[R],
38 string_bindings: &[S],
39 get_string_literal_id: &mut impl FnMut(&str) -> StringId,
40 registers: &mut Registers<Real>,
41 ) -> BitVec {
42 let reg_len = registers.register_length;
43 match self {
44 Self::And(lhs, rhs) => evaluate_binary_logic(
45 |lhs, rhs, out| {
46 #[cfg(feature = "rayon")]
47 {
48 out.resize(reg_len, Default::default());
49 lhs.as_raw_slice()
50 .par_iter()
51 .zip(rhs.as_raw_slice().par_iter())
52 .zip(out.as_raw_mut_slice().par_iter_mut())
53 .for_each(|((lhs, rhs), out)| {
54 *out = lhs & rhs;
55 })
56 }
57 #[cfg(not(feature = "rayon"))]
58 {
59 out.resize(reg_len, true);
60 *out &= lhs;
61 *out &= rhs;
62 }
63 },
64 lhs.as_ref(),
65 rhs.as_ref(),
66 real_bindings,
67 string_bindings,
68 get_string_literal_id,
69 registers,
70 ),
71 Self::Equal(lhs, rhs) => evaluate_real_comparison(
72 |lhs, rhs| lhs == rhs,
73 lhs.as_ref(),
74 rhs.as_ref(),
75 real_bindings,
76 registers,
77 ),
78 Self::Greater(lhs, rhs) => evaluate_real_comparison(
79 |lhs, rhs| lhs > rhs,
80 lhs.as_ref(),
81 rhs.as_ref(),
82 real_bindings,
83 registers,
84 ),
85 Self::GreaterEqual(lhs, rhs) => evaluate_real_comparison(
86 |lhs, rhs| lhs >= rhs,
87 lhs.as_ref(),
88 rhs.as_ref(),
89 real_bindings,
90 registers,
91 ),
92 Self::Less(lhs, rhs) => evaluate_real_comparison(
93 |lhs, rhs| lhs < rhs,
94 lhs.as_ref(),
95 rhs.as_ref(),
96 real_bindings,
97 registers,
98 ),
99 Self::LessEqual(lhs, rhs) => evaluate_real_comparison(
100 |lhs, rhs| lhs <= rhs,
101 lhs.as_ref(),
102 rhs.as_ref(),
103 real_bindings,
104 registers,
105 ),
106 Self::Not(only) => evaluate_unary_logic(
107 |only| {
108 #[cfg(feature = "rayon")]
109 {
110 only.as_raw_mut_slice().par_iter_mut().for_each(|i| {
111 *i = !*i;
112 });
113 }
114 #[cfg(not(feature = "rayon"))]
115 {
116 *only = !std::mem::take(only);
117 }
118 },
119 only.as_ref(),
120 real_bindings,
121 string_bindings,
122 get_string_literal_id,
123 registers,
124 ),
125 Self::NotEqual(lhs, rhs) => evaluate_real_comparison(
126 |lhs, rhs| lhs != rhs,
127 lhs.as_ref(),
128 rhs.as_ref(),
129 real_bindings,
130 registers,
131 ),
132 Self::Or(lhs, rhs) => evaluate_binary_logic(
133 |lhs, rhs, out| {
134 #[cfg(feature = "rayon")]
135 {
136 out.resize(reg_len, Default::default());
137 lhs.as_raw_slice()
138 .par_iter()
139 .zip(rhs.as_raw_slice().par_iter())
140 .zip(out.as_raw_mut_slice().par_iter_mut())
141 .for_each(|((lhs, rhs), out)| {
142 *out = lhs | rhs;
143 })
144 }
145 #[cfg(not(feature = "rayon"))]
146 {
147 out.resize(reg_len, false);
148 *out |= lhs;
149 *out |= rhs;
150 }
151 },
152 lhs.as_ref(),
153 rhs.as_ref(),
154 real_bindings,
155 string_bindings,
156 get_string_literal_id,
157 registers,
158 ),
159 Self::StrEqual(lhs, rhs) => evaluate_string_comparison(
160 |lhs, rhs| lhs == rhs,
161 lhs,
162 rhs,
163 string_bindings,
164 get_string_literal_id,
165 registers,
166 ),
167 Self::StrNotEqual(lhs, rhs) => evaluate_string_comparison(
168 |lhs, rhs| lhs != rhs,
169 lhs,
170 rhs,
171 string_bindings,
172 get_string_literal_id,
173 registers,
174 ),
175 }
176 }
177}
178
179impl<Real: FloatExt> RealExpression<Real> {
180 pub fn evaluate_without_vars(&self, registers: &mut Registers<Real>) -> Vec<Real> {
181 self.evaluate::<[_; 0]>(&[], registers)
182 }
183
184 pub fn evaluate<R: AsRef<[Real]>>(
186 &self,
187 bindings: &[R],
188 registers: &mut Registers<Real>,
189 ) -> Vec<Real> {
190 validate_bindings(bindings, registers.register_length);
191 self.evaluate_recursive(bindings, registers)
192 }
193
194 fn evaluate_recursive<R: AsRef<[Real]>>(
195 &self,
196 bindings: &[R],
197 registers: &mut Registers<Real>,
198 ) -> Vec<Real> {
199 match self {
200 Self::Add(lhs, rhs) => evaluate_binary_real_op(
201 |lhs, rhs| lhs + rhs,
202 lhs.as_ref(),
203 rhs.as_ref(),
204 bindings,
205 registers,
206 ),
207 Self::Binding(binding) => {
210 let mut output = registers.allocate_real();
211 output.extend_from_slice(bindings[*binding].as_ref());
212 output
213 }
214 Self::Div(lhs, rhs) => evaluate_binary_real_op(
215 |lhs, rhs| lhs / rhs,
216 lhs.as_ref(),
217 rhs.as_ref(),
218 bindings,
219 registers,
220 ),
221 Self::Literal(value) => {
222 let mut output = registers.allocate_real();
223 output.extend(std::iter::repeat(*value).take(registers.register_length));
224 output
225 }
226 Self::Mul(lhs, rhs) => evaluate_binary_real_op(
227 |lhs, rhs| lhs * rhs,
228 lhs.as_ref(),
229 rhs.as_ref(),
230 bindings,
231 registers,
232 ),
233 Self::Neg(only) => {
234 evaluate_unary_real_op(|only| -only, only.as_ref(), bindings, registers)
235 }
236 Self::Pow(lhs, rhs) => evaluate_binary_real_op(
237 |lhs, rhs| lhs.powf(rhs),
238 lhs.as_ref(),
239 rhs.as_ref(),
240 bindings,
241 registers,
242 ),
243 Self::Sub(lhs, rhs) => evaluate_binary_real_op(
244 |lhs, rhs| lhs - rhs,
245 lhs.as_ref(),
246 rhs.as_ref(),
247 bindings,
248 registers,
249 ),
250 }
251 }
252}
253
/// Asserts that every binding column holds exactly `expected_length` values.
///
/// # Panics
///
/// Panics if any binding's length differs from `expected_length`. The message names the
/// offending binding index so the caller can tell which column is malformed.
fn validate_bindings<T, B: AsRef<[T]>>(input_bindings: &[B], expected_length: usize) {
    for (index, binding) in input_bindings.iter().enumerate() {
        let actual_length = binding.as_ref().len();
        assert_eq!(
            actual_length, expected_length,
            "binding {} has {} elements, expected {} (the register length)",
            index, actual_length, expected_length
        );
    }
}
259
260fn evaluate_binary_real_op<Real: FloatExt, R: AsRef<[Real]>>(
261 op: fn(Real, Real) -> Real,
262 lhs: &RealExpression<Real>,
263 rhs: &RealExpression<Real>,
264 bindings: &[R],
265 registers: &mut Registers<Real>,
266) -> Vec<Real> {
267 let mut lhs_reg = None;
270 let lhs_values = if let RealExpression::Binding(binding) = lhs {
271 bindings[*binding].as_ref()
272 } else {
273 lhs_reg = Some(lhs.evaluate_recursive(bindings, registers));
274 lhs_reg.as_ref().unwrap()
275 };
276 let mut rhs_reg = None;
277 let rhs_values = if let RealExpression::Binding(binding) = rhs {
278 bindings[*binding].as_ref()
279 } else {
280 rhs_reg = Some(rhs.evaluate_recursive(bindings, registers));
281 rhs_reg.as_ref().unwrap()
282 };
283 let mut output = registers.allocate_real();
285
286 #[cfg(feature = "rayon")]
287 {
288 output.par_extend(
289 lhs_values
290 .par_iter()
291 .zip(rhs_values.par_iter())
292 .map(|(lhs, rhs)| op(*lhs, *rhs)),
293 );
294 }
295 #[cfg(not(feature = "rayon"))]
296 {
297 output.extend(
298 lhs_values
299 .iter()
300 .zip(rhs_values.iter())
301 .map(|(lhs, rhs)| op(*lhs, *rhs)),
302 );
303 }
304
305 if let Some(r) = lhs_reg {
306 registers.recycle_real(r);
307 }
308 if let Some(r) = rhs_reg {
309 registers.recycle_real(r);
310 }
311 output
312}
313
314fn evaluate_unary_real_op<Real: FloatExt, R: AsRef<[Real]>>(
315 op: fn(Real) -> Real,
316 only: &RealExpression<Real>,
317 bindings: &[R],
318 registers: &mut Registers<Real>,
319) -> Vec<Real> {
320 let mut only_reg = None;
323 let only_values = if let RealExpression::Binding(binding) = only {
324 bindings[*binding].as_ref()
325 } else {
326 only_reg = Some(only.evaluate_recursive(bindings, registers));
327 only_reg.as_ref().unwrap()
328 };
329 let mut output = registers.allocate_real();
331
332 #[cfg(feature = "rayon")]
333 {
334 output.par_extend(only_values.par_iter().map(|only| op(*only)));
335 }
336 #[cfg(not(feature = "rayon"))]
337 {
338 output.extend(only_values.iter().map(|only| op(*only)));
339 }
340
341 if let Some(r) = only_reg {
342 registers.recycle_real(r);
343 }
344 output
345}
346
347fn evaluate_real_comparison<Real: FloatExt, R: AsRef<[Real]>>(
348 op: fn(Real, Real) -> bool,
349 lhs: &RealExpression<Real>,
350 rhs: &RealExpression<Real>,
351 bindings: &[R],
352 registers: &mut Registers<Real>,
353) -> BitVec {
354 let mut lhs_reg = None;
357 let lhs_values = if let RealExpression::Binding(binding) = lhs {
358 bindings[*binding].as_ref()
359 } else {
360 lhs_reg = Some(lhs.evaluate_recursive(bindings, registers));
361 lhs_reg.as_ref().unwrap()
362 };
363 let mut rhs_reg = None;
364 let rhs_values = if let RealExpression::Binding(binding) = rhs {
365 bindings[*binding].as_ref()
366 } else {
367 rhs_reg = Some(rhs.evaluate_recursive(bindings, registers));
368 rhs_reg.as_ref().unwrap()
369 };
370 let mut output = registers.allocate_bool();
372
373 #[cfg(feature = "rayon")]
374 {
375 output.resize(registers.register_length, Default::default());
376 parallel_comparison(op, lhs_values, rhs_values, &mut output);
377 }
378 #[cfg(not(feature = "rayon"))]
379 {
380 output.extend(
381 lhs_values
382 .iter()
383 .zip(rhs_values.iter())
384 .map(|(lhs, rhs)| op(*lhs, *rhs)),
385 );
386 }
387
388 if let Some(r) = lhs_reg {
389 registers.recycle_real(r);
390 }
391 if let Some(r) = rhs_reg {
392 registers.recycle_real(r);
393 }
394 output
395}
396
397fn evaluate_string_comparison<Real, S: AsRef<[StringId]>>(
398 op: fn(StringId, StringId) -> bool,
399 lhs: &StringExpression,
400 rhs: &StringExpression,
401 bindings: &[S],
402 mut get_string_literal_id: impl FnMut(&str) -> StringId,
403 registers: &mut Registers<Real>,
404) -> BitVec {
405 let mut lhs_reg = None;
406 let lhs_values = match lhs {
407 StringExpression::Binding(binding) => bindings[*binding].as_ref(),
408 StringExpression::Literal(literal_value) => {
409 let mut reg = registers.allocate_string();
410 let literal_id = get_string_literal_id(literal_value);
411 reg.extend(std::iter::repeat(literal_id).take(registers.register_length));
412 lhs_reg = Some(reg);
413 lhs_reg.as_ref().unwrap()
414 }
415 };
416 let mut rhs_reg = None;
417 let rhs_values = match rhs {
418 StringExpression::Binding(binding) => bindings[*binding].as_ref(),
419 StringExpression::Literal(literal_value) => {
420 let mut reg = registers.allocate_string();
421 let literal_id = get_string_literal_id(literal_value);
422 reg.extend(std::iter::repeat(literal_id).take(registers.register_length));
423 rhs_reg = Some(reg);
424 rhs_reg.as_ref().unwrap()
425 }
426 };
427 let mut output = registers.allocate_bool();
429
430 #[cfg(feature = "rayon")]
431 {
432 output.resize(registers.register_length, Default::default());
433 parallel_comparison(op, lhs_values, rhs_values, &mut output);
434 }
435 #[cfg(not(feature = "rayon"))]
436 {
437 output.extend(
438 lhs_values
439 .iter()
440 .zip(rhs_values.iter())
441 .map(|(lhs, rhs)| op(*lhs, *rhs)),
442 );
443 }
444
445 if let Some(r) = lhs_reg {
446 registers.recycle_string(r);
447 }
448 if let Some(r) = rhs_reg {
449 registers.recycle_string(r);
450 }
451 output
452}
453
/// Packs `op(lhs_values[i], rhs_values[i])` into bit `i` of `output`, in parallel.
///
/// The inputs are split into chunks of `usize::BITS` elements, so each full chunk maps to
/// exactly one storage word of `output`. The trailing partial chunk is packed into the last
/// word.
///
/// Results are OR-ed into the words, so `output` must already hold `lhs_values.len()` bits, all
/// cleared. NOTE(review): this assumes the default `BitVec` layout (`usize` words,
/// least-significant bit first), so that bit `i` of a word corresponds to element `i` of its
/// chunk.
#[cfg(feature = "rayon")]
fn parallel_comparison<T: Copy + Send + Sync>(
    op: fn(T, T) -> bool,
    lhs_values: &[T],
    rhs_values: &[T],
    output: &mut BitVec,
) {
    const WORD_BITS: usize = usize::BITS as usize;

    // ORs the per-element results of one chunk into a single storage word.
    let pack = |lhs_chunk: &[T], rhs_chunk: &[T], word: &mut usize| {
        for (bit, (&l, &r)) in lhs_chunk.iter().zip(rhs_chunk).enumerate() {
            *word |= usize::from(op(l, r)) << bit;
        }
    };

    let words = output.as_raw_mut_slice();
    let lhs_chunks = lhs_values.par_chunks_exact(WORD_BITS);
    let rhs_chunks = rhs_values.par_chunks_exact(WORD_BITS);

    // Leftover elements (fewer than WORD_BITS) belong to the final, partially used word.
    if let Some(last_word) = words.last_mut() {
        pack(lhs_chunks.remainder(), rhs_chunks.remainder(), last_word);
    }
    lhs_chunks
        .zip(rhs_chunks)
        .zip(words.par_iter_mut())
        .for_each(|((lhs_chunk, rhs_chunk), word)| pack(lhs_chunk, rhs_chunk, word));
}
486
487fn evaluate_binary_logic<Real: FloatExt, R: AsRef<[Real]>, S: AsRef<[StringId]>>(
488 op: impl Fn(&BitVec, &BitVec, &mut BitVec),
489 lhs: &BoolExpression<Real>,
490 rhs: &BoolExpression<Real>,
491 real_bindings: &[R],
492 string_bindings: &[S],
493 get_string_literal_id: &mut impl FnMut(&str) -> StringId,
494 registers: &mut Registers<Real>,
495) -> BitVec {
496 let lhs_values = lhs.evaluate_recursive(
497 real_bindings,
498 string_bindings,
499 get_string_literal_id,
500 registers,
501 );
502 let rhs_values = rhs.evaluate_recursive(
503 real_bindings,
504 string_bindings,
505 get_string_literal_id,
506 registers,
507 );
508
509 let mut output = registers.allocate_bool();
511
512 op(&lhs_values, &rhs_values, &mut output);
513
514 registers.recycle_bool(lhs_values);
515 registers.recycle_bool(rhs_values);
516 output
517}
518
519fn evaluate_unary_logic<Real: FloatExt, R: AsRef<[Real]>, S: AsRef<[StringId]>>(
520 op: fn(&mut BitVec),
521 only: &BoolExpression<Real>,
522 real_bindings: &[R],
523 string_bindings: &[S],
524 get_string_literal_id: &mut impl FnMut(&str) -> StringId,
525 registers: &mut Registers<Real>,
526) -> BitVec {
527 let mut only_values = only.evaluate_recursive(
528 real_bindings,
529 string_bindings,
530 get_string_literal_id,
531 registers,
532 );
533
534 op(&mut only_values);
535
536 only_values
537}
538
539pub struct Registers<Real> {
545 num_allocations: usize,
546 real_registers: Vec<Vec<Real>>,
547 bool_registers: Vec<BitVec>,
548 string_registers: Vec<Vec<StringId>>,
549 register_length: usize,
550}
551
552impl<Real> Registers<Real> {
553 pub fn new(register_length: usize) -> Self {
554 Self {
555 num_allocations: 0,
556 real_registers: vec![],
557 bool_registers: vec![],
558 string_registers: vec![],
559 register_length,
560 }
561 }
562
563 pub fn set_register_length(&mut self, register_length: usize) {
571 self.register_length = register_length;
572 self.real_registers
573 .retain(|reg| reg.capacity() >= self.register_length);
574 self.bool_registers
575 .retain(|reg| reg.capacity() >= self.register_length);
576 self.string_registers
577 .retain(|reg| reg.capacity() >= self.register_length);
578 }
579
580 fn recycle_real(&mut self, mut used: Vec<Real>) {
581 used.clear();
582 self.real_registers.push(used);
583 }
584
585 fn recycle_bool(&mut self, mut used: BitVec) {
586 used.clear();
587 self.bool_registers.push(used);
588 }
589
590 fn recycle_string(&mut self, mut used: Vec<StringId>) {
591 used.clear();
592 self.string_registers.push(used);
593 }
594
595 fn allocate_real(&mut self) -> Vec<Real> {
596 self.real_registers.pop().unwrap_or_else(|| {
597 self.num_allocations += 1;
598 Vec::with_capacity(self.register_length)
599 })
600 }
601
602 fn allocate_bool(&mut self) -> BitVec {
603 self.bool_registers.pop().unwrap_or_else(|| {
604 self.num_allocations += 1;
605 BitVec::with_capacity(self.register_length)
606 })
607 }
608
609 fn allocate_string(&mut self) -> Vec<StringId> {
610 self.string_registers.pop().unwrap_or_else(|| {
611 self.num_allocations += 1;
612 Vec::with_capacity(self.register_length)
613 })
614 }
615
616 pub fn num_allocations(&self) -> usize {
617 self.num_allocations
618 }
619}