//! Domain masks for filtering and constraints.

3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
6use crate::DomainInfo;
7
8/// Domain mask for filtering and constraints
9#[derive(Clone, Debug, Serialize, Deserialize)]
10pub struct DomainMask {
11    pub domain: String,
12    pub included_elements: HashSet<String>,
13    pub excluded_elements: HashSet<String>,
14}
15
16impl DomainMask {
17    pub fn new(domain: impl Into<String>) -> Self {
18        DomainMask {
19            domain: domain.into(),
20            included_elements: HashSet::new(),
21            excluded_elements: HashSet::new(),
22        }
23    }
24
25    pub fn include(&mut self, element: impl Into<String>) -> &mut Self {
26        self.included_elements.insert(element.into());
27        self
28    }
29
30    pub fn exclude(&mut self, element: impl Into<String>) -> &mut Self {
31        self.excluded_elements.insert(element.into());
32        self
33    }
34
35    pub fn is_allowed(&self, element: &str) -> bool {
36        if !self.excluded_elements.is_empty() && self.excluded_elements.contains(element) {
37            return false;
38        }
39
40        if !self.included_elements.is_empty() {
41            return self.included_elements.contains(element);
42        }
43
44        true
45    }
46
47    pub fn apply_to_indices(&self, domain_info: &DomainInfo) -> Vec<usize> {
48        if let Some(elements) = &domain_info.elements {
49            elements
50                .iter()
51                .enumerate()
52                .filter(|(_, elem)| self.is_allowed(elem))
53                .map(|(idx, _)| idx)
54                .collect()
55        } else {
56            (0..domain_info.cardinality).collect()
57        }
58    }
59}