vector_expr/
evaluate.rs

1use crate::{BoolExpression, FloatExt, RealExpression, StringExpression};
2use bitvec::vec::BitVec;
3
4#[cfg(feature = "rayon")]
5use rayon::{
6    prelude::{
7        IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator,
8        ParallelExtend, ParallelIterator,
9    },
10    slice::ParallelSlice,
11};
12
/// Identifier of an interned string.
///
/// To speed up string comparisons, we use string interning: the evaluator
/// compares these integer IDs instead of the string contents.
pub type StringId = u32;
15
16impl<Real: FloatExt> BoolExpression<Real> {
17    /// Calculates the `bool`-valued results of the expression component-wise.
18    pub fn evaluate<R: AsRef<[Real]>, S: AsRef<[StringId]>>(
19        &self,
20        real_bindings: &[R],
21        string_bindings: &[S],
22        mut get_string_literal_id: impl FnMut(&str) -> StringId,
23        registers: &mut Registers<Real>,
24    ) -> BitVec {
25        validate_bindings(real_bindings, registers.register_length);
26        validate_bindings(string_bindings, registers.register_length);
27        self.evaluate_recursive(
28            real_bindings,
29            string_bindings,
30            &mut get_string_literal_id,
31            registers,
32        )
33    }
34
35    fn evaluate_recursive<R: AsRef<[Real]>, S: AsRef<[StringId]>>(
36        &self,
37        real_bindings: &[R],
38        string_bindings: &[S],
39        get_string_literal_id: &mut impl FnMut(&str) -> StringId,
40        registers: &mut Registers<Real>,
41    ) -> BitVec {
42        let reg_len = registers.register_length;
43        match self {
44            Self::And(lhs, rhs) => evaluate_binary_logic(
45                |lhs, rhs, out| {
46                    #[cfg(feature = "rayon")]
47                    {
48                        out.resize(reg_len, Default::default());
49                        lhs.as_raw_slice()
50                            .par_iter()
51                            .zip(rhs.as_raw_slice().par_iter())
52                            .zip(out.as_raw_mut_slice().par_iter_mut())
53                            .for_each(|((lhs, rhs), out)| {
54                                *out = lhs & rhs;
55                            })
56                    }
57                    #[cfg(not(feature = "rayon"))]
58                    {
59                        out.resize(reg_len, true);
60                        *out &= lhs;
61                        *out &= rhs;
62                    }
63                },
64                lhs.as_ref(),
65                rhs.as_ref(),
66                real_bindings,
67                string_bindings,
68                get_string_literal_id,
69                registers,
70            ),
71            Self::Equal(lhs, rhs) => evaluate_real_comparison(
72                |lhs, rhs| lhs == rhs,
73                lhs.as_ref(),
74                rhs.as_ref(),
75                real_bindings,
76                registers,
77            ),
78            Self::Greater(lhs, rhs) => evaluate_real_comparison(
79                |lhs, rhs| lhs > rhs,
80                lhs.as_ref(),
81                rhs.as_ref(),
82                real_bindings,
83                registers,
84            ),
85            Self::GreaterEqual(lhs, rhs) => evaluate_real_comparison(
86                |lhs, rhs| lhs >= rhs,
87                lhs.as_ref(),
88                rhs.as_ref(),
89                real_bindings,
90                registers,
91            ),
92            Self::Less(lhs, rhs) => evaluate_real_comparison(
93                |lhs, rhs| lhs < rhs,
94                lhs.as_ref(),
95                rhs.as_ref(),
96                real_bindings,
97                registers,
98            ),
99            Self::LessEqual(lhs, rhs) => evaluate_real_comparison(
100                |lhs, rhs| lhs <= rhs,
101                lhs.as_ref(),
102                rhs.as_ref(),
103                real_bindings,
104                registers,
105            ),
106            Self::Not(only) => evaluate_unary_logic(
107                |only| {
108                    #[cfg(feature = "rayon")]
109                    {
110                        only.as_raw_mut_slice().par_iter_mut().for_each(|i| {
111                            *i = !*i;
112                        });
113                    }
114                    #[cfg(not(feature = "rayon"))]
115                    {
116                        *only = !std::mem::take(only);
117                    }
118                },
119                only.as_ref(),
120                real_bindings,
121                string_bindings,
122                get_string_literal_id,
123                registers,
124            ),
125            Self::NotEqual(lhs, rhs) => evaluate_real_comparison(
126                |lhs, rhs| lhs != rhs,
127                lhs.as_ref(),
128                rhs.as_ref(),
129                real_bindings,
130                registers,
131            ),
132            Self::Or(lhs, rhs) => evaluate_binary_logic(
133                |lhs, rhs, out| {
134                    #[cfg(feature = "rayon")]
135                    {
136                        out.resize(reg_len, Default::default());
137                        lhs.as_raw_slice()
138                            .par_iter()
139                            .zip(rhs.as_raw_slice().par_iter())
140                            .zip(out.as_raw_mut_slice().par_iter_mut())
141                            .for_each(|((lhs, rhs), out)| {
142                                *out = lhs | rhs;
143                            })
144                    }
145                    #[cfg(not(feature = "rayon"))]
146                    {
147                        out.resize(reg_len, false);
148                        *out |= lhs;
149                        *out |= rhs;
150                    }
151                },
152                lhs.as_ref(),
153                rhs.as_ref(),
154                real_bindings,
155                string_bindings,
156                get_string_literal_id,
157                registers,
158            ),
159            Self::StrEqual(lhs, rhs) => evaluate_string_comparison(
160                |lhs, rhs| lhs == rhs,
161                lhs,
162                rhs,
163                string_bindings,
164                get_string_literal_id,
165                registers,
166            ),
167            Self::StrNotEqual(lhs, rhs) => evaluate_string_comparison(
168                |lhs, rhs| lhs != rhs,
169                lhs,
170                rhs,
171                string_bindings,
172                get_string_literal_id,
173                registers,
174            ),
175        }
176    }
177}
178
179impl<Real: FloatExt> RealExpression<Real> {
180    pub fn evaluate_without_vars(&self, registers: &mut Registers<Real>) -> Vec<Real> {
181        self.evaluate::<[_; 0]>(&[], registers)
182    }
183
184    /// Calculates the real-valued results of the expression component-wise.
185    pub fn evaluate<R: AsRef<[Real]>>(
186        &self,
187        bindings: &[R],
188        registers: &mut Registers<Real>,
189    ) -> Vec<Real> {
190        validate_bindings(bindings, registers.register_length);
191        self.evaluate_recursive(bindings, registers)
192    }
193
194    fn evaluate_recursive<R: AsRef<[Real]>>(
195        &self,
196        bindings: &[R],
197        registers: &mut Registers<Real>,
198    ) -> Vec<Real> {
199        match self {
200            Self::Add(lhs, rhs) => evaluate_binary_real_op(
201                |lhs, rhs| lhs + rhs,
202                lhs.as_ref(),
203                rhs.as_ref(),
204                bindings,
205                registers,
206            ),
207            // This branch should only be taken if the entire expression is
208            // literally the identity map from one of the bindings.
209            Self::Binding(binding) => {
210                let mut output = registers.allocate_real();
211                output.extend_from_slice(bindings[*binding].as_ref());
212                output
213            }
214            Self::Div(lhs, rhs) => evaluate_binary_real_op(
215                |lhs, rhs| lhs / rhs,
216                lhs.as_ref(),
217                rhs.as_ref(),
218                bindings,
219                registers,
220            ),
221            Self::Literal(value) => {
222                let mut output = registers.allocate_real();
223                output.extend(std::iter::repeat(*value).take(registers.register_length));
224                output
225            }
226            Self::Mul(lhs, rhs) => evaluate_binary_real_op(
227                |lhs, rhs| lhs * rhs,
228                lhs.as_ref(),
229                rhs.as_ref(),
230                bindings,
231                registers,
232            ),
233            Self::Neg(only) => {
234                evaluate_unary_real_op(|only| -only, only.as_ref(), bindings, registers)
235            }
236            Self::Pow(lhs, rhs) => evaluate_binary_real_op(
237                |lhs, rhs| lhs.powf(rhs),
238                lhs.as_ref(),
239                rhs.as_ref(),
240                bindings,
241                registers,
242            ),
243            Self::Sub(lhs, rhs) => evaluate_binary_real_op(
244                |lhs, rhs| lhs - rhs,
245                lhs.as_ref(),
246                rhs.as_ref(),
247                bindings,
248                registers,
249            ),
250        }
251    }
252}
253
/// Asserts that every binding holds exactly `expected_length` elements.
///
/// # Panics
///
/// Panics on the first binding whose length differs from `expected_length`.
/// The message names the index of the offending binding and both lengths, so
/// a caller can tell which input is malformed.
fn validate_bindings<T, B: AsRef<[T]>>(input_bindings: &[B], expected_length: usize) {
    for (index, binding) in input_bindings.iter().enumerate() {
        let actual_length = binding.as_ref().len();
        assert_eq!(
            actual_length, expected_length,
            "binding {} has length {}, but the register length is {}",
            index, actual_length, expected_length
        );
    }
}
259
260fn evaluate_binary_real_op<Real: FloatExt, R: AsRef<[Real]>>(
261    op: fn(Real, Real) -> Real,
262    lhs: &RealExpression<Real>,
263    rhs: &RealExpression<Real>,
264    bindings: &[R],
265    registers: &mut Registers<Real>,
266) -> Vec<Real> {
267    // Before doing recursive evaluation, we check first if we already have
268    // input values in our bindings. This avoids unnecessary copies.
269    let mut lhs_reg = None;
270    let lhs_values = if let RealExpression::Binding(binding) = lhs {
271        bindings[*binding].as_ref()
272    } else {
273        lhs_reg = Some(lhs.evaluate_recursive(bindings, registers));
274        lhs_reg.as_ref().unwrap()
275    };
276    let mut rhs_reg = None;
277    let rhs_values = if let RealExpression::Binding(binding) = rhs {
278        bindings[*binding].as_ref()
279    } else {
280        rhs_reg = Some(rhs.evaluate_recursive(bindings, registers));
281        rhs_reg.as_ref().unwrap()
282    };
283    // Allocate this output register as lazily as possible.
284    let mut output = registers.allocate_real();
285
286    #[cfg(feature = "rayon")]
287    {
288        output.par_extend(
289            lhs_values
290                .par_iter()
291                .zip(rhs_values.par_iter())
292                .map(|(lhs, rhs)| op(*lhs, *rhs)),
293        );
294    }
295    #[cfg(not(feature = "rayon"))]
296    {
297        output.extend(
298            lhs_values
299                .iter()
300                .zip(rhs_values.iter())
301                .map(|(lhs, rhs)| op(*lhs, *rhs)),
302        );
303    }
304
305    if let Some(r) = lhs_reg {
306        registers.recycle_real(r);
307    }
308    if let Some(r) = rhs_reg {
309        registers.recycle_real(r);
310    }
311    output
312}
313
314fn evaluate_unary_real_op<Real: FloatExt, R: AsRef<[Real]>>(
315    op: fn(Real) -> Real,
316    only: &RealExpression<Real>,
317    bindings: &[R],
318    registers: &mut Registers<Real>,
319) -> Vec<Real> {
320    // Before doing recursive evaluation, we check first if we already have
321    // input values in our bindings. This avoids unnecessary copies.
322    let mut only_reg = None;
323    let only_values = if let RealExpression::Binding(binding) = only {
324        bindings[*binding].as_ref()
325    } else {
326        only_reg = Some(only.evaluate_recursive(bindings, registers));
327        only_reg.as_ref().unwrap()
328    };
329    // Allocate this output register as lazily as possible.
330    let mut output = registers.allocate_real();
331
332    #[cfg(feature = "rayon")]
333    {
334        output.par_extend(only_values.par_iter().map(|only| op(*only)));
335    }
336    #[cfg(not(feature = "rayon"))]
337    {
338        output.extend(only_values.iter().map(|only| op(*only)));
339    }
340
341    if let Some(r) = only_reg {
342        registers.recycle_real(r);
343    }
344    output
345}
346
347fn evaluate_real_comparison<Real: FloatExt, R: AsRef<[Real]>>(
348    op: fn(Real, Real) -> bool,
349    lhs: &RealExpression<Real>,
350    rhs: &RealExpression<Real>,
351    bindings: &[R],
352    registers: &mut Registers<Real>,
353) -> BitVec {
354    // Before doing recursive evaluation, we check first if we already have
355    // input values in our bindings. This avoids unnecessary copies.
356    let mut lhs_reg = None;
357    let lhs_values = if let RealExpression::Binding(binding) = lhs {
358        bindings[*binding].as_ref()
359    } else {
360        lhs_reg = Some(lhs.evaluate_recursive(bindings, registers));
361        lhs_reg.as_ref().unwrap()
362    };
363    let mut rhs_reg = None;
364    let rhs_values = if let RealExpression::Binding(binding) = rhs {
365        bindings[*binding].as_ref()
366    } else {
367        rhs_reg = Some(rhs.evaluate_recursive(bindings, registers));
368        rhs_reg.as_ref().unwrap()
369    };
370    // Allocate this output register as lazily as possible.
371    let mut output = registers.allocate_bool();
372
373    #[cfg(feature = "rayon")]
374    {
375        output.resize(registers.register_length, Default::default());
376        parallel_comparison(op, lhs_values, rhs_values, &mut output);
377    }
378    #[cfg(not(feature = "rayon"))]
379    {
380        output.extend(
381            lhs_values
382                .iter()
383                .zip(rhs_values.iter())
384                .map(|(lhs, rhs)| op(*lhs, *rhs)),
385        );
386    }
387
388    if let Some(r) = lhs_reg {
389        registers.recycle_real(r);
390    }
391    if let Some(r) = rhs_reg {
392        registers.recycle_real(r);
393    }
394    output
395}
396
397fn evaluate_string_comparison<Real, S: AsRef<[StringId]>>(
398    op: fn(StringId, StringId) -> bool,
399    lhs: &StringExpression,
400    rhs: &StringExpression,
401    bindings: &[S],
402    mut get_string_literal_id: impl FnMut(&str) -> StringId,
403    registers: &mut Registers<Real>,
404) -> BitVec {
405    let mut lhs_reg = None;
406    let lhs_values = match lhs {
407        StringExpression::Binding(binding) => bindings[*binding].as_ref(),
408        StringExpression::Literal(literal_value) => {
409            let mut reg = registers.allocate_string();
410            let literal_id = get_string_literal_id(literal_value);
411            reg.extend(std::iter::repeat(literal_id).take(registers.register_length));
412            lhs_reg = Some(reg);
413            lhs_reg.as_ref().unwrap()
414        }
415    };
416    let mut rhs_reg = None;
417    let rhs_values = match rhs {
418        StringExpression::Binding(binding) => bindings[*binding].as_ref(),
419        StringExpression::Literal(literal_value) => {
420            let mut reg = registers.allocate_string();
421            let literal_id = get_string_literal_id(literal_value);
422            reg.extend(std::iter::repeat(literal_id).take(registers.register_length));
423            rhs_reg = Some(reg);
424            rhs_reg.as_ref().unwrap()
425        }
426    };
427    // Allocate this output register as lazily as possible.
428    let mut output = registers.allocate_bool();
429
430    #[cfg(feature = "rayon")]
431    {
432        output.resize(registers.register_length, Default::default());
433        parallel_comparison(op, lhs_values, rhs_values, &mut output);
434    }
435    #[cfg(not(feature = "rayon"))]
436    {
437        output.extend(
438            lhs_values
439                .iter()
440                .zip(rhs_values.iter())
441                .map(|(lhs, rhs)| op(*lhs, *rhs)),
442        );
443    }
444
445    if let Some(r) = lhs_reg {
446        registers.recycle_string(r);
447    }
448    if let Some(r) = rhs_reg {
449        registers.recycle_string(r);
450    }
451    output
452}
453
/// Compares `lhs_values` and `rhs_values` element-wise with `op`, OR-ing each
/// result into the corresponding bit of `output`.
///
/// Results are OR-ed into the bits already stored in `output`, so the caller
/// is expected to pass a register that is already sized and cleared.
#[cfg(feature = "rayon")]
fn parallel_comparison<T: Copy + Send + Sync>(
    op: fn(T, T) -> bool,
    lhs_values: &[T],
    rhs_values: &[T],
    output: &mut BitVec,
) {
    /// Packs the comparison results of one chunk into a storage word, with
    /// element `i` of the chunk landing in bit `i`.
    fn pack_block<T: Copy>(op: fn(T, T) -> bool, lhs: &[T], rhs: &[T]) -> usize {
        lhs.iter()
            .zip(rhs)
            .enumerate()
            .fold(0usize, |word, (i, (&a, &b))| {
                word | (usize::from(op(a, b)) << i)
            })
    }

    // Inputs are chunked by the storage word width so that every chunk maps
    // onto exactly one word of the bit storage.
    const BLOCK_BITS: usize = usize::BITS as usize;

    let blocks = output.as_raw_mut_slice();
    let lhs_chunks = lhs_values.par_chunks_exact(BLOCK_BITS);
    let rhs_chunks = rhs_values.par_chunks_exact(BLOCK_BITS);

    // The elements that do not fill a whole word go into the last word.
    if let Some(tail) = blocks.last_mut() {
        *tail |= pack_block(op, lhs_chunks.remainder(), rhs_chunks.remainder());
    }

    // Every full chunk fills one whole word; these are processed in parallel.
    lhs_chunks
        .zip(rhs_chunks)
        .zip(blocks.par_iter_mut())
        .for_each(|((lhs_chunk, rhs_chunk), block)| {
            *block |= pack_block(op, lhs_chunk, rhs_chunk);
        });
}
486
487fn evaluate_binary_logic<Real: FloatExt, R: AsRef<[Real]>, S: AsRef<[StringId]>>(
488    op: impl Fn(&BitVec, &BitVec, &mut BitVec),
489    lhs: &BoolExpression<Real>,
490    rhs: &BoolExpression<Real>,
491    real_bindings: &[R],
492    string_bindings: &[S],
493    get_string_literal_id: &mut impl FnMut(&str) -> StringId,
494    registers: &mut Registers<Real>,
495) -> BitVec {
496    let lhs_values = lhs.evaluate_recursive(
497        real_bindings,
498        string_bindings,
499        get_string_literal_id,
500        registers,
501    );
502    let rhs_values = rhs.evaluate_recursive(
503        real_bindings,
504        string_bindings,
505        get_string_literal_id,
506        registers,
507    );
508
509    // Allocate this output register as lazily as possible.
510    let mut output = registers.allocate_bool();
511
512    op(&lhs_values, &rhs_values, &mut output);
513
514    registers.recycle_bool(lhs_values);
515    registers.recycle_bool(rhs_values);
516    output
517}
518
519fn evaluate_unary_logic<Real: FloatExt, R: AsRef<[Real]>, S: AsRef<[StringId]>>(
520    op: fn(&mut BitVec),
521    only: &BoolExpression<Real>,
522    real_bindings: &[R],
523    string_bindings: &[S],
524    get_string_literal_id: &mut impl FnMut(&str) -> StringId,
525    registers: &mut Registers<Real>,
526) -> BitVec {
527    let mut only_values = only.evaluate_recursive(
528        real_bindings,
529        string_bindings,
530        get_string_literal_id,
531        registers,
532    );
533
534    op(&mut only_values);
535
536    only_values
537}
538
/// Scratch space for calculations. Can be reused across evaluations with the
/// same data binding length.
///
/// Attempts to minimize allocations by recycling registers after intermediate
/// calculations have finished.
pub struct Registers<Real> {
    // Number of times a fresh register had to be created because the
    // matching pool was empty.
    num_allocations: usize,
    // Pool of cleared real-valued registers that are ready for reuse.
    real_registers: Vec<Vec<Real>>,
    // Pool of cleared boolean registers that are ready for reuse.
    bool_registers: Vec<BitVec>,
    // Pool of cleared string-ID registers that are ready for reuse.
    string_registers: Vec<Vec<StringId>>,
    // Capacity requested for freshly created registers; evaluation also
    // requires every binding to have exactly this length.
    register_length: usize,
}
551
552impl<Real> Registers<Real> {
553    pub fn new(register_length: usize) -> Self {
554        Self {
555            num_allocations: 0,
556            real_registers: vec![],
557            bool_registers: vec![],
558            string_registers: vec![],
559            register_length,
560        }
561    }
562
563    /// Change the register length.
564    ///
565    /// This allows reusing `self` across evaluations even when the register
566    /// length changes.
567    ///
568    /// Allocated registers will be retained only if they have capacity of at
569    /// least `register_length`.
570    pub fn set_register_length(&mut self, register_length: usize) {
571        self.register_length = register_length;
572        self.real_registers
573            .retain(|reg| reg.capacity() >= self.register_length);
574        self.bool_registers
575            .retain(|reg| reg.capacity() >= self.register_length);
576        self.string_registers
577            .retain(|reg| reg.capacity() >= self.register_length);
578    }
579
580    fn recycle_real(&mut self, mut used: Vec<Real>) {
581        used.clear();
582        self.real_registers.push(used);
583    }
584
585    fn recycle_bool(&mut self, mut used: BitVec) {
586        used.clear();
587        self.bool_registers.push(used);
588    }
589
590    fn recycle_string(&mut self, mut used: Vec<StringId>) {
591        used.clear();
592        self.string_registers.push(used);
593    }
594
595    fn allocate_real(&mut self) -> Vec<Real> {
596        self.real_registers.pop().unwrap_or_else(|| {
597            self.num_allocations += 1;
598            Vec::with_capacity(self.register_length)
599        })
600    }
601
602    fn allocate_bool(&mut self) -> BitVec {
603        self.bool_registers.pop().unwrap_or_else(|| {
604            self.num_allocations += 1;
605            BitVec::with_capacity(self.register_length)
606        })
607    }
608
609    fn allocate_string(&mut self) -> Vec<StringId> {
610        self.string_registers.pop().unwrap_or_else(|| {
611            self.num_allocations += 1;
612            Vec::with_capacity(self.register_length)
613        })
614    }
615
616    pub fn num_allocations(&self) -> usize {
617        self.num_allocations
618    }
619}