Skip to main content

tensorlogic_adapters/
hierarchy.rs

1//! Domain hierarchy and subtype relationships.
2
3use anyhow::{bail, Result};
4use indexmap::IndexMap;
5use serde::{Deserialize, Serialize};
6
7/// Domain hierarchy tracking subtype relationships
8#[derive(Clone, Debug, Serialize, Deserialize)]
9pub struct DomainHierarchy {
10    /// Map from subdomain to parent domain
11    parent_map: IndexMap<String, String>,
12}
13
14impl DomainHierarchy {
15    pub fn new() -> Self {
16        Self {
17            parent_map: IndexMap::new(),
18        }
19    }
20
21    /// Add a subtype relationship: subdomain <: parent
22    pub fn add_subtype(&mut self, subdomain: impl Into<String>, parent: impl Into<String>) {
23        self.parent_map.insert(subdomain.into(), parent.into());
24    }
25
26    /// Check if subdomain is a subtype of parent (directly or transitively)
27    pub fn is_subtype(&self, subdomain: &str, parent: &str) -> bool {
28        if subdomain == parent {
29            return true;
30        }
31
32        let mut current = subdomain;
33        while let Some(p) = self.parent_map.get(current) {
34            if p == parent {
35                return true;
36            }
37            current = p;
38        }
39
40        false
41    }
42
43    /// Get the direct parent of a domain, if any
44    pub fn get_parent(&self, domain: &str) -> Option<&str> {
45        self.parent_map.get(domain).map(|s| s.as_str())
46    }
47
48    /// Get all ancestors of a domain (parent, grandparent, etc.)
49    pub fn get_ancestors(&self, domain: &str) -> Vec<String> {
50        let mut ancestors = Vec::new();
51        let mut current = domain;
52
53        while let Some(parent) = self.parent_map.get(current) {
54            ancestors.push(parent.clone());
55            current = parent;
56        }
57
58        ancestors
59    }
60
61    /// Get all descendants of a domain
62    pub fn get_descendants(&self, domain: &str) -> Vec<String> {
63        self.parent_map
64            .iter()
65            .filter_map(|(child, parent)| {
66                if parent == domain || self.is_subtype(child, domain) {
67                    Some(child.clone())
68                } else {
69                    None
70                }
71            })
72            .collect()
73    }
74
75    /// Validate that there are no cycles in the hierarchy
76    pub fn validate_acyclic(&self) -> Result<()> {
77        for domain in self.parent_map.keys() {
78            let mut visited = std::collections::HashSet::new();
79            let mut current = domain.as_str();
80
81            while let Some(parent) = self.parent_map.get(current) {
82                if !visited.insert(current) {
83                    bail!("Cycle detected in domain hierarchy involving '{}'", domain);
84                }
85                current = parent;
86            }
87        }
88
89        Ok(())
90    }
91
92    /// Find the least common supertype of two domains
93    pub fn least_common_supertype(&self, domain1: &str, domain2: &str) -> Option<String> {
94        if domain1 == domain2 {
95            return Some(domain1.to_string());
96        }
97
98        let ancestors1: std::collections::HashSet<_> =
99            self.get_ancestors(domain1).into_iter().collect();
100
101        if ancestors1.contains(domain2) {
102            return Some(domain2.to_string());
103        }
104
105        self.get_ancestors(domain2)
106            .into_iter()
107            .find(|ancestor| ancestors1.contains(ancestor))
108    }
109
110    /// Get all domains in the hierarchy (both subdomains and their parents).
111    pub fn all_domains(&self) -> Vec<String> {
112        let mut domains = std::collections::HashSet::new();
113
114        for (subdomain, parent) in &self.parent_map {
115            domains.insert(subdomain.clone());
116            domains.insert(parent.clone());
117        }
118
119        domains.into_iter().collect()
120    }
121}
122
123impl Default for DomainHierarchy {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a hierarchy from `(subdomain, parent)` edges, added in slice order.
    fn hierarchy_of(edges: &[(&str, &str)]) -> DomainHierarchy {
        let mut hierarchy = DomainHierarchy::new();
        for &(subdomain, parent) in edges {
            hierarchy.add_subtype(subdomain, parent);
        }
        hierarchy
    }

    #[test]
    fn test_subtype_direct() {
        let h = hierarchy_of(&[("Student", "Person")]);

        assert!(h.is_subtype("Student", "Person"));
        // Subtyping is reflexive but not symmetric.
        assert!(h.is_subtype("Student", "Student"));
        assert!(!h.is_subtype("Person", "Student"));
    }

    #[test]
    fn test_subtype_transitive() {
        let h = hierarchy_of(&[("Student", "Person"), ("Person", "Agent")]);

        for (sub, sup) in [("Student", "Agent"), ("Student", "Person"), ("Person", "Agent")] {
            assert!(h.is_subtype(sub, sup), "{} should be a subtype of {}", sub, sup);
        }
        assert!(!h.is_subtype("Agent", "Student"));
    }

    #[test]
    fn test_get_ancestors() {
        let h = hierarchy_of(&[("Student", "Person"), ("Person", "Agent")]);

        // Nearest ancestor comes first.
        assert_eq!(h.get_ancestors("Student"), vec!["Person", "Agent"]);
    }

    #[test]
    fn test_get_descendants() {
        let h = hierarchy_of(&[("Student", "Person"), ("Teacher", "Person")]);

        let descendants = h.get_descendants("Person");
        assert_eq!(descendants.len(), 2);
        for expected in ["Student", "Teacher"] {
            assert!(descendants.iter().any(|d| d.as_str() == expected));
        }
    }

    #[test]
    fn test_least_common_supertype() {
        let h = hierarchy_of(&[
            ("Student", "Person"),
            ("Teacher", "Person"),
            ("Person", "Agent"),
        ]);

        assert_eq!(
            h.least_common_supertype("Student", "Teacher").as_deref(),
            Some("Person")
        );
        assert_eq!(
            h.least_common_supertype("Student", "Student").as_deref(),
            Some("Student")
        );
    }

    #[test]
    fn test_validate_acyclic() {
        let h = hierarchy_of(&[("A", "B"), ("B", "C")]);
        assert!(h.validate_acyclic().is_ok());
    }

    #[test]
    fn test_validate_cycle_detection() {
        let h = hierarchy_of(&[("A", "B"), ("B", "A")]);
        assert!(h.validate_acyclic().is_err());
    }
}