stepflow_data/
statedata_filtered.rs1use std::collections::HashSet;
2use super::StateData;
3use super::var::VarId;
4use super::value::ValidVal;
5
6pub struct StateDataFiltered<'sd> {
8 allowed_var_ids: HashSet<VarId>,
9 state_data: &'sd StateData,
10}
11
12impl<'sd> StateDataFiltered<'sd> {
13 pub fn new(state_data: &'sd StateData, allowed_var_ids: HashSet<VarId>) -> Self {
15 Self { state_data, allowed_var_ids }
16 }
17
18 pub fn get(&self, var_id: &VarId) -> Option<&ValidVal> {
19 if !self.allowed_var_ids.contains(var_id) {
20 return None
21 }
22 self.state_data.get(var_id)
23 }
24
25 pub fn contains(&self, var_id: &VarId) -> bool {
26 if !self.allowed_var_ids.contains(var_id) {
27 return false;
28 }
29 self.state_data.contains(var_id)
30 }
31}
32
#[cfg(test)]
mod tests {
    use super::StateDataFiltered;
    use crate::{test_var_val, value::ValidVal, StateData};
    use std::collections::HashSet;

    /// `get` returns the stored value for an allowed id and `None` for an id
    /// that was left off the allow-list, even though it exists in the data.
    #[test]
    fn basic() {
        let var1 = test_var_val();
        let var2 = test_var_val();

        // Build the expected value before `var1.1` is moved into `data`.
        let val1_valid = ValidVal::try_new(var1.1.clone(), &var1.0).unwrap();

        let mut data = StateData::new();
        data.insert(&var1.0, var1.1).unwrap();
        data.insert(&var2.0, var2.1).unwrap();

        // Only var1 is allowed through the filter.
        let mut filter = HashSet::new();
        filter.insert(var1.0.id().clone());
        let data_filtered = StateDataFiltered::new(&data, filter);

        assert_eq!(data_filtered.get(var1.0.id()), Some(&val1_valid));
        assert_eq!(data_filtered.get(var2.0.id()), None);
    }

    /// `contains` applies the same allow-list as `get`: an inserted and allowed
    /// id is reported, an inserted but filtered-out id is not.
    #[test]
    fn contains_respects_filter() {
        let var1 = test_var_val();
        let var2 = test_var_val();

        let mut data = StateData::new();
        data.insert(&var1.0, var1.1).unwrap();
        data.insert(&var2.0, var2.1).unwrap();

        // Only var1 is allowed through the filter.
        let mut filter = HashSet::new();
        filter.insert(var1.0.id().clone());
        let data_filtered = StateDataFiltered::new(&data, filter);

        assert!(data_filtered.contains(var1.0.id()));
        assert!(!data_filtered.contains(var2.0.id()));
    }
}