1use std::collections::HashMap;
18use std::num::NonZeroUsize;
19use std::path::Path;
20use std::sync::Arc;
21
22use arc_swap::ArcSwap;
23use lru::LruCache;
24use serde::Deserialize;
25use tokio::sync::Mutex;
26
27use crate::error::MemoryError;
28
/// Cardinality class assigned to an ontology predicate.
///
/// Ontology entries whose cardinality is the string `"1"` are classified as
/// [`Cardinality::One`]; every other entry, and every predicate the ontology
/// does not know, is classified as [`Cardinality::Many`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum Cardinality {
    /// The predicate was declared with cardinality `"1"`.
    One,
    /// The predicate was declared with any other cardinality, or not declared
    /// at all. This is the default variant.
    #[default]
    Many,
}
49
50#[derive(Debug, Default)]
52struct OntologyState {
53 alias_to_canonical: HashMap<String, String>,
55 cardinality: HashMap<String, Cardinality>,
57}
58
59impl OntologyState {
60 fn build(predicates: &[PredicateToml]) -> Self {
61 let mut alias_to_canonical = HashMap::new();
62 let mut cardinality = HashMap::new();
63
64 for entry in predicates {
65 let canonical = normalize(&entry.canonical);
66 let card = match entry.cardinality.as_deref() {
67 Some("1") => Cardinality::One,
68 _ => Cardinality::Many,
69 };
70 alias_to_canonical.insert(canonical.clone(), canonical.clone());
71 cardinality.insert(canonical.clone(), card);
72 for alias in &entry.aliases {
73 alias_to_canonical.insert(normalize(alias), canonical.clone());
74 }
75 }
76 Self {
77 alias_to_canonical,
78 cardinality,
79 }
80 }
81}
82
83pub type Ontology = OntologyTable;
85
86pub struct OntologyTable {
92 state: ArcSwap<OntologyState>,
93 cache: Mutex<LruCache<String, String>>,
95 cache_max: usize,
96}
97
98impl std::fmt::Debug for OntologyTable {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct("OntologyTable")
101 .field("state", &"<ArcSwap<OntologyState>>")
102 .field("cache", &"<Mutex<LruCache>>")
103 .field("cache_max", &self.cache_max)
104 .finish()
105 }
106}
107
108impl OntologyTable {
109 fn new_with_state(state: OntologyState, cache_max: usize) -> Self {
110 let cap = NonZeroUsize::new(cache_max.max(1)).expect("cache_max >= 1");
111 Self {
112 state: ArcSwap::new(Arc::new(state)),
113 cache: Mutex::new(LruCache::new(cap)),
114 cache_max,
115 }
116 }
117
118 #[must_use]
120 pub fn from_default(cache_max: usize) -> Self {
121 let state = OntologyState::build(default_predicates());
122 Self::new_with_state(state, cache_max)
123 }
124
125 pub async fn from_path(path: &Path, cache_max: usize) -> Result<Self, MemoryError> {
131 let predicates = if path.as_os_str().is_empty() {
132 default_predicates().to_vec()
133 } else {
134 load_toml_file(path).await?
135 };
136 let state = OntologyState::build(&predicates);
137 Ok(Self::new_with_state(state, cache_max))
138 }
139
140 pub async fn reload(&self, path: &Path) -> Result<(), MemoryError> {
149 let predicates = if path.as_os_str().is_empty() {
150 default_predicates().to_vec()
151 } else {
152 load_toml_file(path).await?
153 };
154 let new_state = Arc::new(OntologyState::build(&predicates));
155 let mut cache = self.cache.lock().await;
158 cache.clear();
159 self.state.store(new_state);
160 Ok(())
161 }
162
163 pub async fn resolve(&self, raw_predicate: &str) -> (String, bool) {
186 let key = normalize(raw_predicate);
187 tracing::debug!(target: "memory.graph.apex.ontology_resolve", predicate = raw_predicate);
188
189 {
190 let mut cache = self.cache.lock().await;
191 if let Some(canonical) = cache.get(&key) {
192 return (canonical.clone(), false);
193 }
194 }
195
196 let state = self.state.load();
197 if let Some(canonical) = state.alias_to_canonical.get(&key) {
198 let canonical = canonical.clone();
199 let mut cache = self.cache.lock().await;
200 cache.put(key, canonical.clone());
201 return (canonical, false);
202 }
203
204 let canonical = key.clone();
206 let mut cache = self.cache.lock().await;
207 cache.put(key, canonical.clone());
208 (canonical, true)
209 }
210
211 #[must_use]
215 pub fn cardinality(&self, canonical_predicate: &str) -> Cardinality {
216 let key = normalize(canonical_predicate);
217 self.state
218 .load()
219 .cardinality
220 .get(&key)
221 .copied()
222 .unwrap_or_default()
223 }
224}
225
/// Normalizes a predicate name for use as a lookup key.
///
/// Control characters are removed, surrounding whitespace is trimmed, and the
/// result is lowercased. Control characters are stripped *before* trimming so
/// that whitespace sitting next to a leading or trailing control character
/// does not survive in the key.
pub(crate) fn normalize(s: &str) -> String {
    let without_controls: String = s.chars().filter(|c| !c.is_control()).collect();
    without_controls.trim().to_lowercase()
}
234
235#[derive(Debug, Clone, Deserialize)]
238struct OntologyToml {
239 #[serde(rename = "predicate")]
240 predicates: Vec<PredicateToml>,
241}
242
243#[derive(Debug, Clone, Deserialize)]
244struct PredicateToml {
245 canonical: String,
246 #[serde(default)]
247 aliases: Vec<String>,
248 #[serde(default, deserialize_with = "de_cardinality")]
250 cardinality: Option<String>,
251}
252
253fn de_cardinality<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
254where
255 D: serde::Deserializer<'de>,
256{
257 use serde::de::Visitor;
258
259 struct CardVisitor;
260 impl<'de> Visitor<'de> for CardVisitor {
261 type Value = Option<String>;
262
263 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264 write!(f, r#"cardinality string "1" or "n", or integer 1"#)
265 }
266
267 fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
268 Ok(Some(v.to_string()))
269 }
270
271 fn visit_u64<E: serde::de::Error>(self, v: u64) -> Result<Self::Value, E> {
272 Ok(Some(if v == 1 {
273 "1".to_string()
274 } else {
275 "n".to_string()
276 }))
277 }
278
279 fn visit_none<E: serde::de::Error>(self) -> Result<Self::Value, E> {
280 Ok(None)
281 }
282
283 fn visit_some<D2: serde::Deserializer<'de>>(self, d: D2) -> Result<Self::Value, D2::Error> {
284 d.deserialize_any(self)
285 }
286 }
287
288 deserializer.deserialize_option(CardVisitor)
289}
290
291async fn load_toml_file(path: &Path) -> Result<Vec<PredicateToml>, MemoryError> {
292 let content = tokio::fs::read_to_string(path)
293 .await
294 .map_err(|e| MemoryError::InvalidInput(format!("ontology TOML read error: {e}")))?;
295 let parsed: OntologyToml = toml::from_str(&content)
296 .map_err(|e| MemoryError::InvalidInput(format!("ontology TOML parse error: {e}")))?;
297 Ok(parsed.predicates)
298}
299
300fn make(canonical: &str, aliases: &[&str], cardinality: &str) -> PredicateToml {
303 PredicateToml {
304 canonical: canonical.to_string(),
305 aliases: aliases.iter().map(|s| (*s).to_string()).collect(),
306 cardinality: Some(cardinality.to_string()),
307 }
308}
309
310fn default_predicates() -> &'static [PredicateToml] {
311 use std::sync::OnceLock;
312 static DEFAULTS: OnceLock<Vec<PredicateToml>> = OnceLock::new();
313 DEFAULTS.get_or_init(|| {
314 vec![
315 make("works_at", &["employed_by", "job_at", "works_for"], "1"),
316 make("lives_in", &["resides_in", "based_in"], "1"),
317 make("born_in", &["birthplace", "born_at"], "1"),
318 make("manages", &["manages_team", "leads", "supervises"], "1"),
319 make("owns", &["has", "possesses"], "n"),
320 make("depends_on", &["requires", "needs"], "n"),
321 make("knows", &[], "n"),
322 ]
323 })
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    // --- Async tests: `resolve` and `reload` must be awaited. ---

    #[tokio::test]
    async fn resolves_alias_to_canonical() {
        let table = OntologyTable::from_default(64);
        let (canonical, unmapped) = table.resolve("employed_by").await;
        assert_eq!(canonical, "works_at");
        assert!(!unmapped);
    }

    #[tokio::test]
    async fn resolves_canonical_to_itself() {
        let table = OntologyTable::from_default(64);
        let (canonical, unmapped) = table.resolve("works_at").await;
        assert_eq!(canonical, "works_at");
        assert!(!unmapped);
    }

    #[tokio::test]
    async fn unknown_predicate_returns_raw_and_unmapped() {
        let table = OntologyTable::from_default(64);
        let (canonical, unmapped) = table.resolve("some_new_predicate").await;
        assert_eq!(canonical, "some_new_predicate");
        assert!(unmapped);
    }

    #[tokio::test]
    async fn resolve_normalizes_input() {
        let table = OntologyTable::from_default(64);
        let (canonical, unmapped) = table.resolve("  EMPLOYED_BY ").await;
        assert_eq!(canonical, "works_at");
        assert!(!unmapped);
    }

    #[tokio::test]
    async fn cache_hit_on_second_resolve() {
        let table = OntologyTable::from_default(64);
        let (c1, _) = table.resolve("job_at").await;
        let (c2, _) = table.resolve("job_at").await;
        assert_eq!(c1, c2);
        assert_eq!(c1, "works_at");
    }

    #[tokio::test]
    async fn zero_cache_max_still_resolves() {
        // A requested capacity of 0 must not panic and must still resolve.
        let table = OntologyTable::from_default(0);
        let (first, _) = table.resolve("job_at").await;
        let (second, _) = table.resolve("resides_in").await;
        assert_eq!(first, "works_at");
        assert_eq!(second, "lives_in");
    }

    #[tokio::test]
    async fn reload_clears_cache_and_preserves_resolution() {
        let table = OntologyTable::from_default(64);
        let _ = table.resolve("job_at").await;
        table.reload(Path::new("")).await.unwrap();
        let (canonical, _) = table.resolve("job_at").await;
        assert_eq!(canonical, "works_at");
    }

    // --- Synchronous tests: nothing here awaits, so no runtime is needed. ---

    #[test]
    fn cardinality_one_predicates() {
        let table = OntologyTable::from_default(64);
        assert_eq!(table.cardinality("works_at"), Cardinality::One);
        assert_eq!(table.cardinality("lives_in"), Cardinality::One);
        assert_eq!(table.cardinality("born_in"), Cardinality::One);
        assert_eq!(table.cardinality("manages"), Cardinality::One);
    }

    #[test]
    fn cardinality_many_predicates() {
        let table = OntologyTable::from_default(64);
        assert_eq!(table.cardinality("owns"), Cardinality::Many);
        assert_eq!(table.cardinality("depends_on"), Cardinality::Many);
        assert_eq!(table.cardinality("unknown_pred"), Cardinality::Many);
    }

    #[test]
    fn cardinality_lookup_normalizes_input() {
        let table = OntologyTable::from_default(64);
        assert_eq!(table.cardinality("  WORKS_AT "), Cardinality::One);
    }

    #[test]
    fn normalize_trims_and_lowercases() {
        assert_eq!(normalize("  Works_At  "), "works_at");
        assert_eq!(normalize("EMPLOYED_BY"), "employed_by");
    }
}