// tensorlogic_adapters/hierarchy.rs
use std::collections::HashSet;

use anyhow::{bail, Result};
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};

7#[derive(Clone, Debug, Serialize, Deserialize)]
9pub struct DomainHierarchy {
10 parent_map: IndexMap<String, String>,
12}
13
14impl DomainHierarchy {
15 pub fn new() -> Self {
16 Self {
17 parent_map: IndexMap::new(),
18 }
19 }
20
21 pub fn add_subtype(&mut self, subdomain: impl Into<String>, parent: impl Into<String>) {
23 self.parent_map.insert(subdomain.into(), parent.into());
24 }
25
26 pub fn is_subtype(&self, subdomain: &str, parent: &str) -> bool {
28 if subdomain == parent {
29 return true;
30 }
31
32 let mut current = subdomain;
33 while let Some(p) = self.parent_map.get(current) {
34 if p == parent {
35 return true;
36 }
37 current = p;
38 }
39
40 false
41 }
42
43 pub fn get_parent(&self, domain: &str) -> Option<&str> {
45 self.parent_map.get(domain).map(|s| s.as_str())
46 }
47
48 pub fn get_ancestors(&self, domain: &str) -> Vec<String> {
50 let mut ancestors = Vec::new();
51 let mut current = domain;
52
53 while let Some(parent) = self.parent_map.get(current) {
54 ancestors.push(parent.clone());
55 current = parent;
56 }
57
58 ancestors
59 }
60
61 pub fn get_descendants(&self, domain: &str) -> Vec<String> {
63 self.parent_map
64 .iter()
65 .filter_map(|(child, parent)| {
66 if parent == domain || self.is_subtype(child, domain) {
67 Some(child.clone())
68 } else {
69 None
70 }
71 })
72 .collect()
73 }
74
75 pub fn validate_acyclic(&self) -> Result<()> {
77 for domain in self.parent_map.keys() {
78 let mut visited = std::collections::HashSet::new();
79 let mut current = domain.as_str();
80
81 while let Some(parent) = self.parent_map.get(current) {
82 if !visited.insert(current) {
83 bail!("Cycle detected in domain hierarchy involving '{}'", domain);
84 }
85 current = parent;
86 }
87 }
88
89 Ok(())
90 }
91
92 pub fn least_common_supertype(&self, domain1: &str, domain2: &str) -> Option<String> {
94 if domain1 == domain2 {
95 return Some(domain1.to_string());
96 }
97
98 let ancestors1: std::collections::HashSet<_> =
99 self.get_ancestors(domain1).into_iter().collect();
100
101 if ancestors1.contains(domain2) {
102 return Some(domain2.to_string());
103 }
104
105 self.get_ancestors(domain2)
106 .into_iter()
107 .find(|ancestor| ancestors1.contains(ancestor))
108 }
109
110 pub fn all_domains(&self) -> Vec<String> {
112 let mut domains = std::collections::HashSet::new();
113
114 for (subdomain, parent) in &self.parent_map {
115 domains.insert(subdomain.clone());
116 domains.insert(parent.clone());
117 }
118
119 domains.into_iter().collect()
120 }
121}
122
123impl Default for DomainHierarchy {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a hierarchy from `(subdomain, parent)` pairs, declared in slice order.
    fn build(relations: &[(&str, &str)]) -> DomainHierarchy {
        let mut h = DomainHierarchy::new();
        for &(child, parent) in relations {
            h.add_subtype(child, parent);
        }
        h
    }

    #[test]
    fn test_subtype_direct() {
        let h = build(&[("Student", "Person")]);

        assert!(h.is_subtype("Student", "Person"));
        // Subtyping is reflexive...
        assert!(h.is_subtype("Student", "Student"));
        // ...but not symmetric.
        assert!(!h.is_subtype("Person", "Student"));
    }

    #[test]
    fn test_subtype_transitive() {
        let h = build(&[("Student", "Person"), ("Person", "Agent")]);

        assert!(h.is_subtype("Student", "Agent"));
        assert!(h.is_subtype("Student", "Person"));
        assert!(h.is_subtype("Person", "Agent"));
        assert!(!h.is_subtype("Agent", "Student"));
    }

    #[test]
    fn test_get_ancestors() {
        let h = build(&[("Student", "Person"), ("Person", "Agent")]);

        // Nearest ancestor first.
        assert_eq!(h.get_ancestors("Student"), ["Person", "Agent"]);
    }

    #[test]
    fn test_get_descendants() {
        let h = build(&[("Student", "Person"), ("Teacher", "Person")]);

        // Compare order-independently: exactly these two, nothing else.
        let mut found = h.get_descendants("Person");
        found.sort();
        assert_eq!(found, ["Student", "Teacher"]);
    }

    #[test]
    fn test_least_common_supertype() {
        let h = build(&[
            ("Student", "Person"),
            ("Teacher", "Person"),
            ("Person", "Agent"),
        ]);

        let lcs = |a: &str, b: &str| h.least_common_supertype(a, b);
        assert_eq!(lcs("Student", "Teacher"), Some("Person".to_string()));
        assert_eq!(lcs("Student", "Student"), Some("Student".to_string()));
    }

    #[test]
    fn test_validate_acyclic() {
        let h = build(&[("A", "B"), ("B", "C")]);
        assert!(h.validate_acyclic().is_ok());
    }

    #[test]
    fn test_validate_cycle_detection() {
        let h = build(&[("A", "B"), ("B", "A")]);
        assert!(h.validate_acyclic().is_err());
    }
}