Skip to main content

haystack_server/
domain_scope.rs

//! Domain scoping for federated queries.
2
3use std::collections::HashSet;
4
/// Scopes federation queries to specific connector domains.
///
/// A scope is either a wildcard or an explicit set of domain names:
///
/// - `None` — wildcard: every connector is in scope.
/// - `Some(set)` — only connectors whose domain is a member of `set`, plus
///   connectors that declare no domain at all.
///
/// Connectors without a domain are admitted by *every* scope. An empty set
/// therefore excludes each connector that declares a domain, but still
/// includes unscoped connectors. See [`DomainScope::includes`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DomainScope {
    /// Allowed domain names, or `None` for a wildcard scope.
    domains: Option<HashSet<String>>,
}
13
14impl DomainScope {
15    /// Create a wildcard scope that includes all connectors.
16    pub fn all() -> Self {
17        Self { domains: None }
18    }
19
20    /// Create a scope limited to specific domains.
21    pub fn scoped(domains: impl IntoIterator<Item = String>) -> Self {
22        Self {
23            domains: Some(domains.into_iter().collect()),
24        }
25    }
26
27    /// Check if a domain is included in this scope.
28    /// Connectors with `None` domain are always included in any scope.
29    pub fn includes(&self, domain: Option<&str>) -> bool {
30        match (&self.domains, domain) {
31            (None, _) => true, // wildcard scope includes everything
32            (_, None) => true, // unscoped connector is always included
33            (Some(set), Some(d)) => set.contains(d),
34        }
35    }
36
37    /// Check if this is a wildcard scope.
38    pub fn is_wildcard(&self) -> bool {
39        self.domains.is_none()
40    }
41
42    /// Get the domain set, if scoped.
43    pub fn domains(&self) -> Option<&HashSet<String>> {
44        self.domains.as_ref()
45    }
46}
47
48impl Default for DomainScope {
49    fn default() -> Self {
50        Self::all()
51    }
52}
53
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn wildcard_includes_everything() {
        let scope = DomainScope::all();
        assert!(scope.includes(Some("a")));
        assert!(scope.includes(Some("b")));
        assert!(scope.includes(None));
    }

    #[test]
    fn scoped_includes_matching() {
        let scope = DomainScope::scoped(["a".to_string(), "b".to_string()]);
        assert!(scope.includes(Some("a")));
        assert!(scope.includes(Some("b")));
    }

    #[test]
    fn scoped_excludes_non_matching() {
        let scope = DomainScope::scoped(["a".to_string()]);
        assert!(!scope.includes(Some("c")));
    }

    #[test]
    fn unscoped_connector_always_included() {
        let scope = DomainScope::scoped(["a".to_string()]);
        assert!(scope.includes(None));
    }

    #[test]
    fn default_is_wildcard() {
        let scope = DomainScope::default();
        assert!(scope.is_wildcard());
    }

    #[test]
    fn empty_scope_includes_unscoped() {
        let scope = DomainScope::scoped(std::iter::empty::<String>());
        assert!(scope.includes(None));
        assert!(!scope.includes(Some("any")));
    }

    // An explicit domain list is never a wildcard, even when it is empty.
    #[test]
    fn scoped_is_not_wildcard() {
        assert!(!DomainScope::scoped(["a".to_string()]).is_wildcard());
        assert!(!DomainScope::scoped(std::iter::empty::<String>()).is_wildcard());
    }

    // `domains()` exposes the (deduplicated) set for scoped scopes and `None`
    // for the wildcard scope.
    #[test]
    fn domains_accessor_reflects_scope() {
        assert!(DomainScope::all().domains().is_none());

        let scope = DomainScope::scoped(["a".to_string(), "b".to_string(), "a".to_string()]);
        let set = scope
            .domains()
            .expect("a scoped scope always exposes its domain set");
        assert_eq!(set.len(), 2);
        assert!(set.contains("a"));
        assert!(set.contains("b"));

        let empty = DomainScope::scoped(std::iter::empty::<String>());
        assert_eq!(empty.domains().map(|s| s.len()), Some(0));
    }
}