tensorlogic_adapters/
mask.rs1use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
6use crate::DomainInfo;
7
8#[derive(Clone, Debug, Serialize, Deserialize)]
10pub struct DomainMask {
11 pub domain: String,
12 pub included_elements: HashSet<String>,
13 pub excluded_elements: HashSet<String>,
14}
15
16impl DomainMask {
17 pub fn new(domain: impl Into<String>) -> Self {
18 DomainMask {
19 domain: domain.into(),
20 included_elements: HashSet::new(),
21 excluded_elements: HashSet::new(),
22 }
23 }
24
25 pub fn include(&mut self, element: impl Into<String>) -> &mut Self {
26 self.included_elements.insert(element.into());
27 self
28 }
29
30 pub fn exclude(&mut self, element: impl Into<String>) -> &mut Self {
31 self.excluded_elements.insert(element.into());
32 self
33 }
34
35 pub fn is_allowed(&self, element: &str) -> bool {
36 if !self.excluded_elements.is_empty() && self.excluded_elements.contains(element) {
37 return false;
38 }
39
40 if !self.included_elements.is_empty() {
41 return self.included_elements.contains(element);
42 }
43
44 true
45 }
46
47 pub fn apply_to_indices(&self, domain_info: &DomainInfo) -> Vec<usize> {
48 if let Some(elements) = &domain_info.elements {
49 elements
50 .iter()
51 .enumerate()
52 .filter(|(_, elem)| self.is_allowed(elem))
53 .map(|(idx, _)| idx)
54 .collect()
55 } else {
56 (0..domain_info.cardinality).collect()
57 }
58 }
59}