1use std::collections::HashMap;
18use std::num::NonZeroUsize;
19use std::path::Path;
20use std::sync::Arc;
21
22use arc_swap::ArcSwap;
23use lru::LruCache;
24use serde::Deserialize;
25use tokio::sync::Mutex;
26
27use crate::error::MemoryError;
28
/// Cardinality declared for a canonical predicate in the ontology.
///
/// `OntologyState::build` maps a declared cardinality of `"1"` to `One` and
/// everything else (including an omitted declaration) to `Many`.
/// NOTE(review): `One` presumably means "at most one object per subject" —
/// confirm against the consumers of `OntologyTable::cardinality`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Cardinality {
    /// Predicate declared with cardinality `"1"`.
    One,
    /// Any other or missing declaration. This is the `Default`, which is also
    /// what `OntologyTable::cardinality` returns for predicates it has no
    /// entry for.
    #[default]
    Many,
}
48
/// Lookup tables compiled from a list of predicate definitions by
/// `OntologyState::build`. `OntologyTable` replaces whole snapshots of this
/// type on reload rather than editing one in place.
#[derive(Debug, Default)]
struct OntologyState {
    /// Normalized canonical name or alias -> normalized canonical name.
    /// Every canonical name is also present as a key mapping to itself.
    alias_to_canonical: HashMap<String, String>,
    /// Normalized canonical name -> declared cardinality. Only canonical
    /// names are keys here; aliases are not.
    cardinality: HashMap<String, Cardinality>,
}
57
58impl OntologyState {
59 fn build(predicates: &[PredicateToml]) -> Self {
60 let mut alias_to_canonical = HashMap::new();
61 let mut cardinality = HashMap::new();
62
63 for entry in predicates {
64 let canonical = normalize(&entry.canonical);
65 let card = match entry.cardinality.as_deref() {
66 Some("1") => Cardinality::One,
67 _ => Cardinality::Many,
68 };
69 alias_to_canonical.insert(canonical.clone(), canonical.clone());
70 cardinality.insert(canonical.clone(), card);
71 for alias in &entry.aliases {
72 alias_to_canonical.insert(normalize(alias), canonical.clone());
73 }
74 }
75 Self {
76 alias_to_canonical,
77 cardinality,
78 }
79 }
80}
81
/// Resolves raw predicate strings to canonical ontology predicates and
/// reports each canonical predicate's declared cardinality.
///
/// The underlying ontology can be swapped at runtime with `reload`.
pub struct OntologyTable {
    /// Current ontology snapshot; `reload` replaces it wholesale via
    /// `ArcSwap::store`.
    state: ArcSwap<OntologyState>,
    /// LRU cache of normalized predicate -> canonical name used by `resolve`;
    /// cleared by `reload`.
    cache: Mutex<LruCache<String, String>>,
    /// Capacity as requested by the caller (the LRU itself is clamped to at
    /// least 1). In the code shown it is only read by the `Debug` impl.
    cache_max: usize,
}
93
94impl std::fmt::Debug for OntologyTable {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 f.debug_struct("OntologyTable")
97 .field("state", &"<ArcSwap<OntologyState>>")
98 .field("cache", &"<Mutex<LruCache>>")
99 .field("cache_max", &self.cache_max)
100 .finish()
101 }
102}
103
104impl OntologyTable {
105 fn new_with_state(state: OntologyState, cache_max: usize) -> Self {
106 let cap = NonZeroUsize::new(cache_max.max(1)).expect("cache_max >= 1");
107 Self {
108 state: ArcSwap::new(Arc::new(state)),
109 cache: Mutex::new(LruCache::new(cap)),
110 cache_max,
111 }
112 }
113
114 #[must_use]
116 pub fn from_default(cache_max: usize) -> Self {
117 let state = OntologyState::build(default_predicates());
118 Self::new_with_state(state, cache_max)
119 }
120
121 pub async fn from_path(path: &Path, cache_max: usize) -> Result<Self, MemoryError> {
127 let predicates = if path.as_os_str().is_empty() {
128 default_predicates().to_vec()
129 } else {
130 load_toml_file(path).await?
131 };
132 let state = OntologyState::build(&predicates);
133 Ok(Self::new_with_state(state, cache_max))
134 }
135
136 pub async fn reload(&self, path: &Path) -> Result<(), MemoryError> {
145 let predicates = if path.as_os_str().is_empty() {
146 default_predicates().to_vec()
147 } else {
148 load_toml_file(path).await?
149 };
150 let new_state = Arc::new(OntologyState::build(&predicates));
151 let mut cache = self.cache.lock().await;
154 cache.clear();
155 self.state.store(new_state);
156 Ok(())
157 }
158
159 pub async fn resolve(&self, raw_predicate: &str) -> (String, bool) {
182 let key = normalize(raw_predicate);
183 tracing::debug!(target: "memory.graph.apex.ontology_resolve", predicate = raw_predicate);
184
185 {
186 let mut cache = self.cache.lock().await;
187 if let Some(canonical) = cache.get(&key) {
188 return (canonical.clone(), false);
189 }
190 }
191
192 let state = self.state.load();
193 if let Some(canonical) = state.alias_to_canonical.get(&key) {
194 let canonical = canonical.clone();
195 let mut cache = self.cache.lock().await;
196 cache.put(key, canonical.clone());
197 return (canonical, false);
198 }
199
200 let canonical = key.clone();
202 let mut cache = self.cache.lock().await;
203 cache.put(key, canonical.clone());
204 (canonical, true)
205 }
206
207 #[must_use]
211 pub fn cardinality(&self, canonical_predicate: &str) -> Cardinality {
212 let key = normalize(canonical_predicate);
213 self.state
214 .load()
215 .cardinality
216 .get(&key)
217 .copied()
218 .unwrap_or_default()
219 }
220}
221
/// Normalizes a predicate string for lookup: removes control characters,
/// trims surrounding whitespace, and lowercases the result.
///
/// Control characters are removed *before* trimming. Otherwise a stray
/// control character can shield adjacent whitespace from `trim`: e.g.
/// `"\u{7} Works_At"` must normalize to `"works_at"`, not `" works_at"`.
pub(crate) fn normalize(s: &str) -> String {
    let visible: String = s.chars().filter(|c| !c.is_control()).collect();
    visible.trim().to_lowercase()
}
230
/// Top-level shape of an ontology TOML file: the `predicate` key holds the
/// list of predicate entries (written as `[[predicate]]` tables).
///
/// The field has no `#[serde(default)]`, so a file with no `predicate`
/// entries is rejected at parse time.
#[derive(Debug, Clone, Deserialize)]
struct OntologyToml {
    #[serde(rename = "predicate")]
    predicates: Vec<PredicateToml>,
}
238
/// One predicate definition: canonical name, optional aliases, and an
/// optional cardinality marker. Also used to build the in-code defaults.
#[derive(Debug, Clone, Deserialize)]
struct PredicateToml {
    /// Canonical predicate name; normalized by `OntologyState::build`.
    canonical: String,
    /// Alternative spellings that resolve to `canonical`; empty when omitted.
    #[serde(default)]
    aliases: Vec<String>,
    /// Raw cardinality marker as produced by `de_cardinality`. Only
    /// `Some("1")` becomes `Cardinality::One`; any other value, or `None`
    /// when the field is omitted, becomes `Cardinality::Many`.
    #[serde(default, deserialize_with = "de_cardinality")]
    cardinality: Option<String>,
}
248
249fn de_cardinality<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
250where
251 D: serde::Deserializer<'de>,
252{
253 use serde::de::Visitor;
254
255 struct CardVisitor;
256 impl<'de> Visitor<'de> for CardVisitor {
257 type Value = Option<String>;
258
259 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 write!(f, r#"cardinality string "1" or "n", or integer 1"#)
261 }
262
263 fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
264 Ok(Some(v.to_string()))
265 }
266
267 fn visit_u64<E: serde::de::Error>(self, v: u64) -> Result<Self::Value, E> {
268 Ok(Some(if v == 1 {
269 "1".to_string()
270 } else {
271 "n".to_string()
272 }))
273 }
274
275 fn visit_none<E: serde::de::Error>(self) -> Result<Self::Value, E> {
276 Ok(None)
277 }
278
279 fn visit_some<D2: serde::Deserializer<'de>>(self, d: D2) -> Result<Self::Value, D2::Error> {
280 d.deserialize_any(self)
281 }
282 }
283
284 deserializer.deserialize_option(CardVisitor)
285}
286
287async fn load_toml_file(path: &Path) -> Result<Vec<PredicateToml>, MemoryError> {
288 let content = tokio::fs::read_to_string(path)
289 .await
290 .map_err(|e| MemoryError::InvalidInput(format!("ontology TOML read error: {e}")))?;
291 let parsed: OntologyToml = toml::from_str(&content)
292 .map_err(|e| MemoryError::InvalidInput(format!("ontology TOML parse error: {e}")))?;
293 Ok(parsed.predicates)
294}
295
296fn make(canonical: &str, aliases: &[&str], cardinality: &str) -> PredicateToml {
299 PredicateToml {
300 canonical: canonical.to_string(),
301 aliases: aliases.iter().map(|s| (*s).to_string()).collect(),
302 cardinality: Some(cardinality.to_string()),
303 }
304}
305
306fn default_predicates() -> &'static [PredicateToml] {
307 use std::sync::OnceLock;
308 static DEFAULTS: OnceLock<Vec<PredicateToml>> = OnceLock::new();
309 DEFAULTS.get_or_init(|| {
310 vec![
311 make("works_at", &["employed_by", "job_at", "works_for"], "1"),
312 make("lives_in", &["resides_in", "based_in"], "1"),
313 make("born_in", &["birthplace", "born_at"], "1"),
314 make("manages", &["manages_team", "leads", "supervises"], "1"),
315 make("owns", &["has", "possesses"], "n"),
316 make("depends_on", &["requires", "needs"], "n"),
317 make("knows", &[], "n"),
318 ]
319 })
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn resolves_alias_to_canonical() {
        let table = OntologyTable::from_default(64);
        let (canonical, unmapped) = table.resolve("employed_by").await;
        assert_eq!(canonical, "works_at");
        assert!(!unmapped);
    }

    #[tokio::test]
    async fn resolves_canonical_to_itself() {
        let table = OntologyTable::from_default(64);
        let (canonical, unmapped) = table.resolve("works_at").await;
        assert_eq!(canonical, "works_at");
        assert!(!unmapped);
    }

    #[tokio::test]
    async fn unknown_predicate_returns_raw_and_unmapped() {
        let table = OntologyTable::from_default(64);
        let (canonical, unmapped) = table.resolve("some_new_predicate").await;
        assert_eq!(canonical, "some_new_predicate");
        assert!(unmapped);
    }

    // Cardinality lookups and normalization are synchronous, so these tests
    // do not need an async runtime.
    #[test]
    fn cardinality_one_predicates() {
        let table = OntologyTable::from_default(64);
        assert_eq!(table.cardinality("works_at"), Cardinality::One);
        assert_eq!(table.cardinality("lives_in"), Cardinality::One);
        assert_eq!(table.cardinality("born_in"), Cardinality::One);
        assert_eq!(table.cardinality("manages"), Cardinality::One);
    }

    #[test]
    fn cardinality_many_predicates() {
        let table = OntologyTable::from_default(64);
        assert_eq!(table.cardinality("owns"), Cardinality::Many);
        assert_eq!(table.cardinality("depends_on"), Cardinality::Many);
        assert_eq!(table.cardinality("unknown_pred"), Cardinality::Many);
    }

    #[test]
    fn normalize_trims_and_lowercases() {
        assert_eq!(normalize(" Works_At "), "works_at");
        assert_eq!(normalize("EMPLOYED_BY"), "employed_by");
    }

    #[tokio::test]
    async fn cache_hit_on_second_resolve() {
        let table = OntologyTable::from_default(64);
        let (c1, _) = table.resolve("job_at").await;
        let (c2, _) = table.resolve("job_at").await;
        assert_eq!(c1, c2);
        assert_eq!(c1, "works_at");
    }

    #[tokio::test]
    async fn reload_clears_cache_and_preserves_resolution() {
        let table = OntologyTable::from_default(64);
        let _ = table.resolve("job_at").await;
        table.reload(Path::new("")).await.unwrap();
        let (canonical, _) = table.resolve("job_at").await;
        assert_eq!(canonical, "works_at");
    }

    #[tokio::test]
    async fn resolve_ignores_case_and_surrounding_whitespace() {
        let table = OntologyTable::from_default(64);
        let (canonical, unmapped) = table.resolve("  Employed_By \n").await;
        assert_eq!(canonical, "works_at");
        assert!(!unmapped);
    }

    #[tokio::test]
    async fn empty_path_loads_default_predicates() {
        let table = OntologyTable::from_path(Path::new(""), 64).await.unwrap();
        let (canonical, unmapped) = table.resolve("works_for").await;
        assert_eq!(canonical, "works_at");
        assert!(!unmapped);
        assert_eq!(table.cardinality("works_at"), Cardinality::One);
    }

    #[tokio::test]
    async fn failed_reload_keeps_previous_ontology() {
        let table = OntologyTable::from_default(64);
        let missing = Path::new("this/ontology/file/does/not/exist.toml");
        assert!(table.reload(missing).await.is_err());
        let (canonical, unmapped) = table.resolve("resides_in").await;
        assert_eq!(canonical, "lives_in");
        assert!(!unmapped);
    }

    #[tokio::test]
    async fn zero_cache_capacity_is_clamped() {
        let table = OntologyTable::from_default(0);
        assert_eq!(table.resolve("needs").await.0, "depends_on");
        assert_eq!(table.resolve("requires").await.0, "depends_on");
        assert_eq!(table.resolve("needs").await.0, "depends_on");
    }
}