stepflow_data/
statedata_filtered.rs

1use std::collections::HashSet;
2use super::StateData;
3use super::var::VarId;
4use super::value::ValidVal;
5
6/// Wrapper to a [`StateData`] that provides a filtered view of the data contained
7pub struct StateDataFiltered<'sd> {
8  allowed_var_ids: HashSet<VarId>,
9  state_data: &'sd StateData,
10}
11
12impl<'sd> StateDataFiltered<'sd> {
13  /// Wrap the `state_data` with a filtered view. Only IDs specified in `allowed_var_ids` are visible.
14  pub fn new(state_data: &'sd StateData, allowed_var_ids: HashSet<VarId>) -> Self {
15    Self { state_data, allowed_var_ids }
16  }
17
18  pub fn get(&self, var_id: &VarId) -> Option<&ValidVal> {
19    if !self.allowed_var_ids.contains(var_id) {
20      return None
21    }
22    self.state_data.get(var_id)
23  }
24
25  pub fn contains(&self, var_id: &VarId) -> bool {
26    if !self.allowed_var_ids.contains(var_id) {
27      return false;
28    }
29    self.state_data.contains(var_id)
30  }
31}
32
#[cfg(test)]
mod tests {
  use std::collections::HashSet;
  use crate::{StateData, value::ValidVal, test_var_val};
  use super::StateDataFiltered;

  /// Two stored variables, only the first allowed: the second must be hidden from
  /// both `get` and `contains`.
  #[test]
  fn basic() {
    let var1 = test_var_val();
    let var2 = test_var_val();

    let val1_valid = ValidVal::try_new(var1.1.clone(), &var1.0).unwrap();

    // add var1 + var2
    let mut data = StateData::new();
    data.insert(&var1.0, var1.1).unwrap();
    data.insert(&var2.0, var2.1).unwrap();

    // create filtered statedata
    let mut filter = HashSet::new();
    filter.insert(var1.0.id().clone());
    let data_filtered = StateDataFiltered::new(&data, filter);

    assert_eq!(data_filtered.get(var1.0.id()), Some(&val1_valid));
    assert_eq!(data_filtered.get(var2.0.id()), None);

    // `contains` must agree with `get` for both the allowed and the filtered-out ID.
    assert!(data_filtered.contains(var1.0.id()));
    assert!(!data_filtered.contains(var2.0.id()));
  }

  /// An ID that passes the filter but has no stored value must still be reported as
  /// absent: the filter only hides data, it never invents it.
  #[test]
  fn allowed_but_unset() {
    let var1 = test_var_val();
    let var2 = test_var_val();

    // only var1 is stored ...
    let mut data = StateData::new();
    data.insert(&var1.0, var1.1).unwrap();

    // ... but both var1 and var2 are allowed
    let mut filter = HashSet::new();
    filter.insert(var1.0.id().clone());
    filter.insert(var2.0.id().clone());
    let data_filtered = StateDataFiltered::new(&data, filter);

    assert!(data_filtered.contains(var1.0.id()));
    assert_eq!(data_filtered.get(var2.0.id()), None);
    assert!(!data_filtered.contains(var2.0.id()));
  }
}