// Source file: shodh_memory/memory/facts.rs

//! Semantic Fact Storage
//!
//! Persistent storage for semantic facts extracted from episodic memories.
//! Facts represent durable knowledge distilled from multiple experiences.
//!
//! Storage schema (index values hold the fact id):
//! - `facts:{user_id}:{fact_id}` - Primary fact storage (bincode-encoded `SemanticFact`)
//! - `facts_by_entity:{user_id}:{entity}:{fact_id}` - Entity index for fast lookup
//!   (`{entity}` is lowercased)
//! - `facts_by_type:{user_id}:{type}:{fact_id}` - Type index (`{type}` is the `Debug`
//!   name of the `FactType`)
//! - `facts_embedding:{user_id}:{fact_id}` - Pre-computed embedding vector (bincode
//!   `Vec<f32>`; 384-dim expected, but the dimension is not enforced by this store)

12use anyhow::Result;
13use rocksdb::{IteratorMode, DB};
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16
17use super::compression::{FactType, SemanticFact};
18
19/// Response for fact queries
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct FactQueryResponse {
22    pub facts: Vec<SemanticFact>,
23    pub total: usize,
24}
25
26/// Statistics about semantic facts
27#[derive(Debug, Clone, Default, Serialize, Deserialize)]
28pub struct FactStats {
29    pub total_facts: usize,
30    pub by_type: std::collections::HashMap<String, usize>,
31    pub avg_confidence: f32,
32    pub avg_support: f32,
33}
34
35/// Storage for semantic facts with indexing
36pub struct SemanticFactStore {
37    db: Arc<DB>,
38}
39
40impl SemanticFactStore {
41    /// Create a new fact store backed by RocksDB
42    pub fn new(db: Arc<DB>) -> Self {
43        Self { db }
44    }
45
46    /// Get references to all RocksDB databases for backup
47    pub fn databases(&self) -> Vec<(&str, &Arc<DB>)> {
48        vec![("semantic_facts", &self.db)]
49    }
50
51    /// Store a semantic fact
52    pub fn store(&self, user_id: &str, fact: &SemanticFact) -> Result<()> {
53        // Primary storage
54        let key = format!("facts:{}:{}", user_id, fact.id);
55        let value = bincode::serde::encode_to_vec(fact, bincode::config::standard())?;
56        self.db.put(key.as_bytes(), &value)?;
57
58        // Entity index - index by each related entity
59        for entity in &fact.related_entities {
60            let entity_key = format!(
61                "facts_by_entity:{}:{}:{}",
62                user_id,
63                entity.to_lowercase(),
64                fact.id
65            );
66            self.db.put(entity_key.as_bytes(), fact.id.as_bytes())?;
67        }
68
69        // Type index
70        let type_name = format!("{:?}", fact.fact_type);
71        let type_key = format!("facts_by_type:{}:{}:{}", user_id, type_name, fact.id);
72        self.db.put(type_key.as_bytes(), fact.id.as_bytes())?;
73
74        Ok(())
75    }
76
77    /// Store multiple facts in a batch
78    pub fn store_batch(&self, user_id: &str, facts: &[SemanticFact]) -> Result<usize> {
79        let mut stored = 0;
80        for fact in facts {
81            if self.store(user_id, fact).is_ok() {
82                stored += 1;
83            }
84        }
85        Ok(stored)
86    }
87
88    /// Get a fact by ID
89    pub fn get(&self, user_id: &str, fact_id: &str) -> Result<Option<SemanticFact>> {
90        let key = format!("facts:{}:{}", user_id, fact_id);
91        match self.db.get(key.as_bytes())? {
92            Some(data) => {
93                let (fact, _): (SemanticFact, _) =
94                    bincode::serde::decode_from_slice(&data, bincode::config::standard())?;
95                Ok(Some(fact))
96            }
97            None => Ok(None),
98        }
99    }
100
101    /// Update an existing fact (for reinforcement)
102    pub fn update(&self, user_id: &str, fact: &SemanticFact) -> Result<()> {
103        // Simply overwrite - indices stay valid since ID doesn't change
104        let key = format!("facts:{}:{}", user_id, fact.id);
105        let value = bincode::serde::encode_to_vec(fact, bincode::config::standard())?;
106        self.db.put(key.as_bytes(), &value)?;
107        Ok(())
108    }
109
110    /// Delete a fact
111    pub fn delete(&self, user_id: &str, fact_id: &str) -> Result<bool> {
112        // Get fact first to clean up indices
113        if let Some(fact) = self.get(user_id, fact_id)? {
114            // Delete entity indices
115            for entity in &fact.related_entities {
116                let entity_key = format!(
117                    "facts_by_entity:{}:{}:{}",
118                    user_id,
119                    entity.to_lowercase(),
120                    fact_id
121                );
122                self.db.delete(entity_key.as_bytes())?;
123            }
124
125            // Delete type index
126            let type_name = format!("{:?}", fact.fact_type);
127            let type_key = format!("facts_by_type:{}:{}:{}", user_id, type_name, fact_id);
128            self.db.delete(type_key.as_bytes())?;
129
130            // Delete primary record
131            let key = format!("facts:{}:{}", user_id, fact_id);
132            self.db.delete(key.as_bytes())?;
133
134            // Delete embedding if present
135            let _ = self.delete_embedding(user_id, fact_id);
136
137            Ok(true)
138        } else {
139            Ok(false)
140        }
141    }
142
143    /// List all facts for a user
144    pub fn list(&self, user_id: &str, limit: usize) -> Result<Vec<SemanticFact>> {
145        let prefix = format!("facts:{}:", user_id);
146        let mut facts = Vec::new();
147
148        let iter = self.db.iterator(IteratorMode::From(
149            prefix.as_bytes(),
150            rocksdb::Direction::Forward,
151        ));
152
153        for item in iter {
154            let (key, value) = item?;
155            let key_str = String::from_utf8_lossy(&key);
156
157            // Stop when we leave the prefix
158            if !key_str.starts_with(&prefix) {
159                break;
160            }
161
162            // Skip index keys (they contain extra colons)
163            if key_str.matches(':').count() > 2 {
164                continue;
165            }
166
167            if let Ok(fact) = bincode::serde::decode_from_slice::<SemanticFact, _>(
168                &value,
169                bincode::config::standard(),
170            )
171            .map(|(v, _)| v)
172            {
173                facts.push(fact);
174                if facts.len() >= limit {
175                    break;
176                }
177            }
178        }
179
180        // Sort by confidence (highest first)
181        facts.sort_by(|a, b| b.confidence.total_cmp(&a.confidence));
182
183        Ok(facts)
184    }
185
186    /// Find facts by related entity
187    pub fn find_by_entity(
188        &self,
189        user_id: &str,
190        entity: &str,
191        limit: usize,
192    ) -> Result<Vec<SemanticFact>> {
193        let prefix = format!("facts_by_entity:{}:{}:", user_id, entity.to_lowercase());
194        let mut facts = Vec::new();
195        let mut seen_ids = std::collections::HashSet::new();
196
197        let iter = self.db.iterator(IteratorMode::From(
198            prefix.as_bytes(),
199            rocksdb::Direction::Forward,
200        ));
201
202        for item in iter {
203            let (key, value) = item?;
204            let key_str = String::from_utf8_lossy(&key);
205
206            if !key_str.starts_with(&prefix) {
207                break;
208            }
209
210            let fact_id = String::from_utf8_lossy(&value);
211            if seen_ids.insert(fact_id.to_string()) {
212                if let Some(fact) = self.get(user_id, &fact_id)? {
213                    facts.push(fact);
214                    if facts.len() >= limit {
215                        break;
216                    }
217                }
218            }
219        }
220
221        Ok(facts)
222    }
223
224    /// Find facts by type
225    pub fn find_by_type(
226        &self,
227        user_id: &str,
228        fact_type: FactType,
229        limit: usize,
230    ) -> Result<Vec<SemanticFact>> {
231        let type_name = format!("{:?}", fact_type);
232        let prefix = format!("facts_by_type:{}:{}:", user_id, type_name);
233        let mut facts = Vec::new();
234
235        let iter = self.db.iterator(IteratorMode::From(
236            prefix.as_bytes(),
237            rocksdb::Direction::Forward,
238        ));
239
240        for item in iter {
241            let (key, value) = item?;
242            let key_str = String::from_utf8_lossy(&key);
243
244            if !key_str.starts_with(&prefix) {
245                break;
246            }
247
248            let fact_id = String::from_utf8_lossy(&value);
249            if let Some(fact) = self.get(user_id, &fact_id)? {
250                facts.push(fact);
251                if facts.len() >= limit {
252                    break;
253                }
254            }
255        }
256
257        Ok(facts)
258    }
259
260    /// Search facts by keyword in fact content
261    pub fn search(&self, user_id: &str, query: &str, limit: usize) -> Result<Vec<SemanticFact>> {
262        let query_lower = query.to_lowercase();
263        let all_facts = self.list(user_id, 1000)?; // Get all facts
264
265        let mut matching: Vec<SemanticFact> = all_facts
266            .into_iter()
267            .filter(|f| f.fact.to_lowercase().contains(&query_lower))
268            .collect();
269
270        matching.truncate(limit);
271        Ok(matching)
272    }
273
274    /// Get statistics about stored facts
275    pub fn stats(&self, user_id: &str) -> Result<FactStats> {
276        let facts = self.list(user_id, 10000)?;
277
278        if facts.is_empty() {
279            return Ok(FactStats::default());
280        }
281
282        let mut by_type: std::collections::HashMap<String, usize> =
283            std::collections::HashMap::new();
284        let mut total_confidence: f32 = 0.0;
285        let mut total_support: usize = 0;
286
287        for fact in &facts {
288            let type_name = format!("{:?}", fact.fact_type);
289            *by_type.entry(type_name).or_insert(0) += 1;
290            total_confidence += fact.confidence;
291            total_support += fact.support_count;
292        }
293
294        let count = facts.len();
295        Ok(FactStats {
296            total_facts: count,
297            by_type,
298            avg_confidence: total_confidence / count as f32,
299            avg_support: total_support as f32 / count as f32,
300        })
301    }
302
303    /// Find the creation timestamp of the most recent fact for a user.
304    ///
305    /// Used at startup to initialize the fact extraction watermark when no
306    /// persisted watermark exists. Returns None if user has no facts.
307    pub fn latest_fact_created_at(&self, user_id: &str) -> Option<i64> {
308        let prefix = format!("facts:{user_id}:");
309        let mut max_millis: Option<i64> = None;
310
311        let iter = self.db.iterator(IteratorMode::From(
312            prefix.as_bytes(),
313            rocksdb::Direction::Forward,
314        ));
315
316        for item in iter {
317            let (key, value) = match item {
318                Ok(kv) => kv,
319                Err(_) => break,
320            };
321            let key_str = String::from_utf8_lossy(&key);
322            if !key_str.starts_with(&prefix) {
323                break;
324            }
325            // Skip index keys (entity/type sub-keys have extra colons)
326            if key_str.matches(':').count() > 2 {
327                continue;
328            }
329            if let Ok((fact, _)) = bincode::serde::decode_from_slice::<SemanticFact, _>(
330                &value,
331                bincode::config::standard(),
332            ) {
333                let millis = fact.created_at.timestamp_millis();
334                max_millis = Some(max_millis.map_or(millis, |cur| cur.max(millis)));
335            }
336        }
337
338        max_millis
339    }
340
341    /// Find facts that should decay (no reinforcement for too long)
342    pub fn find_decaying_facts(
343        &self,
344        user_id: &str,
345        max_age_days: i64,
346    ) -> Result<Vec<SemanticFact>> {
347        let cutoff = chrono::Utc::now() - chrono::Duration::days(max_age_days);
348        let all_facts = self.list(user_id, 10000)?;
349
350        let decaying: Vec<SemanticFact> = all_facts
351            .into_iter()
352            .filter(|f| f.last_reinforced < cutoff)
353            .collect();
354
355        Ok(decaying)
356    }
357
358    /// Check if a similar fact already exists (hybrid dedup)
359    ///
360    /// Multi-gate pipeline when embedding is provided:
361    /// 1. Entity gate: at least 1 shared entity, OR both have zero entities
362    /// 2. Polarity gate: same negation polarity (prevents merging contradictions)
363    /// 3. Cosine gate: embedding similarity >= FACT_DEDUP_COSINE_THRESHOLD
364    /// 4. Jaccard floor: word overlap >= FACT_DEDUP_JACCARD_FLOOR
365    ///
366    /// Falls back to pure Jaccard (0.70) if no embedding is provided.
367    pub fn find_similar(
368        &self,
369        user_id: &str,
370        fact_content: &str,
371        fact_entities: &[String],
372        new_embedding: Option<&[f32]>,
373    ) -> Result<Option<SemanticFact>> {
374        use crate::constants::{
375            FACT_DEDUP_COSINE_THRESHOLD, FACT_DEDUP_JACCARD_FALLBACK, FACT_DEDUP_JACCARD_FLOOR,
376        };
377        use crate::similarity::cosine_similarity;
378
379        let facts = self.list(user_id, 1000)?;
380        let query_lower = fact_content.to_lowercase();
381        let query_words: std::collections::HashSet<&str> = query_lower.split_whitespace().collect();
382        let new_polarity = detect_polarity(&query_lower);
383        let new_entity_set: std::collections::HashSet<&str> =
384            fact_entities.iter().map(|s| s.as_str()).collect();
385
386        let use_hybrid = new_embedding.is_some();
387        let mut best_match: Option<(f32, SemanticFact)> = None;
388
389        for fact in facts {
390            let fact_lower = fact.fact.to_lowercase();
391            let fact_words: std::collections::HashSet<&str> =
392                fact_lower.split_whitespace().collect();
393
394            // Compute Jaccard (needed in both modes)
395            let intersection = query_words.intersection(&fact_words).count();
396            let union = query_words.union(&fact_words).count();
397            let jaccard = if union > 0 {
398                intersection as f32 / union as f32
399            } else {
400                0.0
401            };
402
403            if use_hybrid {
404                // Gate 1: Entity overlap — at least 1 shared entity, or both empty
405                let existing_entity_set: std::collections::HashSet<&str> =
406                    fact.related_entities.iter().map(|s| s.as_str()).collect();
407                let both_empty = new_entity_set.is_empty() && existing_entity_set.is_empty();
408                let has_overlap = !new_entity_set.is_disjoint(&existing_entity_set);
409                if !both_empty && !has_overlap {
410                    continue;
411                }
412
413                // Gate 2: Polarity match — prevents merging contradictions
414                let existing_polarity = detect_polarity(&fact_lower);
415                if new_polarity != existing_polarity {
416                    continue;
417                }
418
419                // Gate 3: Cosine similarity
420                let new_emb = new_embedding.unwrap();
421                match self.get_embedding(user_id, &fact.id) {
422                    Ok(Some(existing_emb)) => {
423                        let cosine = cosine_similarity(new_emb, &existing_emb);
424                        if cosine < FACT_DEDUP_COSINE_THRESHOLD {
425                            continue;
426                        }
427
428                        // Gate 4: Jaccard sanity floor
429                        if jaccard < FACT_DEDUP_JACCARD_FLOOR {
430                            continue;
431                        }
432
433                        // Passed all gates — rank by cosine
434                        if best_match.as_ref().map_or(true, |(s, _)| cosine > *s) {
435                            best_match = Some((cosine, fact));
436                        }
437                    }
438                    _ => {
439                        // No stored embedding — fall back to Jaccard-only for this candidate
440                        if jaccard >= FACT_DEDUP_JACCARD_FALLBACK
441                            && best_match.as_ref().map_or(true, |(s, _)| jaccard > *s)
442                        {
443                            best_match = Some((jaccard, fact));
444                        }
445                    }
446                }
447            } else {
448                // Fallback: pure Jaccard (legacy behavior when embedder unavailable)
449                if jaccard >= FACT_DEDUP_JACCARD_FALLBACK {
450                    return Ok(Some(fact));
451                }
452            }
453        }
454
455        Ok(best_match.map(|(_, fact)| fact))
456    }
457
458    // =========================================================================
459    // EMBEDDING PERSISTENCE
460    // =========================================================================
461
462    /// Store pre-computed embedding vector for a fact
463    ///
464    /// Key format: `facts_embedding:{user_id}:{fact_id}` → bincode Vec<f32>
465    /// Stored separately from SemanticFact struct for backward compatibility.
466    pub fn store_embedding(&self, user_id: &str, fact_id: &str, embedding: &[f32]) -> Result<()> {
467        let key = format!("facts_embedding:{user_id}:{fact_id}");
468        let value = bincode::serde::encode_to_vec(embedding, bincode::config::standard())?;
469        self.db.put(key.as_bytes(), &value)?;
470        Ok(())
471    }
472
473    /// Get pre-computed embedding vector for a fact
474    pub fn get_embedding(&self, user_id: &str, fact_id: &str) -> Result<Option<Vec<f32>>> {
475        let key = format!("facts_embedding:{user_id}:{fact_id}");
476        match self.db.get(key.as_bytes())? {
477            Some(data) => {
478                let (embedding, _): (Vec<f32>, _) =
479                    bincode::serde::decode_from_slice(&data, bincode::config::standard())?;
480                Ok(Some(embedding))
481            }
482            None => Ok(None),
483        }
484    }
485
486    /// Delete embedding for a fact (called during fact deletion)
487    pub fn delete_embedding(&self, user_id: &str, fact_id: &str) -> Result<()> {
488        let key = format!("facts_embedding:{user_id}:{fact_id}");
489        self.db.delete(key.as_bytes())?;
490        Ok(())
491    }
492
493    /// List all unique user IDs that have facts
494    pub fn list_users(&self, limit: usize) -> Result<Vec<String>> {
495        let prefix = "facts:";
496        let mut users = std::collections::HashSet::new();
497
498        let iter = self.db.iterator(IteratorMode::From(
499            prefix.as_bytes(),
500            rocksdb::Direction::Forward,
501        ));
502
503        for item in iter {
504            let (key, _) = item?;
505            let key_str = String::from_utf8_lossy(&key);
506
507            if !key_str.starts_with(prefix) {
508                break;
509            }
510
511            // Key format: facts:{user_id}:{fact_id}
512            // Skip index keys (facts_by_entity, facts_by_type)
513            if key_str.starts_with("facts_by_") {
514                continue;
515            }
516
517            // Extract user_id from key
518            let parts: Vec<&str> = key_str.splitn(3, ':').collect();
519            if parts.len() >= 2 {
520                users.insert(parts[1].to_string());
521                if users.len() >= limit {
522                    break;
523                }
524            }
525        }
526
527        Ok(users.into_iter().collect())
528    }
529}
530
531/// Detect negation polarity of a fact statement.
532///
533/// Returns `true` for positive polarity (even negation count, including 0),
534/// `false` for negative polarity (odd negation count).
535/// Handles double-negation: "not unlike" = positive.
536fn detect_polarity(text_lower: &str) -> bool {
537    use crate::constants::FACT_NEGATION_MARKERS;
538    let words: Vec<&str> = text_lower.split_whitespace().collect();
539    let negation_count = words
540        .iter()
541        .filter(|w| FACT_NEGATION_MARKERS.iter().any(|marker| *w == marker))
542        .count();
543    negation_count % 2 == 0
544}
545
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Open a store on a fresh temporary RocksDB directory.
    ///
    /// The `TempDir` is handed back so the caller keeps the directory alive for as
    /// long as the store is in use.
    fn create_test_store() -> (SemanticFactStore, TempDir) {
        let dir = TempDir::new().unwrap();
        let handle = DB::open_default(dir.path()).unwrap();
        let store = SemanticFactStore::new(Arc::new(handle));
        (store, dir)
    }

    /// Build a `Pattern` fact linked to the entities "rust" and "memory".
    fn create_test_fact(id: &str, content: &str) -> SemanticFact {
        let related_entities = ["rust", "memory"].iter().map(|e| e.to_string()).collect();
        SemanticFact {
            id: id.to_owned(),
            fact: content.to_owned(),
            confidence: 0.8,
            support_count: 3,
            source_memories: vec![],
            related_entities,
            created_at: chrono::Utc::now(),
            last_reinforced: chrono::Utc::now(),
            fact_type: FactType::Pattern,
        }
    }

    #[test]
    fn test_store_and_get() {
        let (store, _dir) = create_test_store();
        let content = "Rust is a systems programming language";
        store
            .store("user-1", &create_test_fact("fact-1", content))
            .unwrap();

        let fetched = store.get("user-1", "fact-1").unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().fact, content);
    }

    #[test]
    fn test_find_by_entity() {
        let (store, _dir) = create_test_store();
        store
            .store(
                "user-1",
                &create_test_fact("fact-1", "Rust has efficient memory management"),
            )
            .unwrap();

        let hits = store.find_by_entity("user-1", "rust", 10).unwrap();
        let ids: Vec<&str> = hits.iter().map(|f| f.id.as_str()).collect();
        assert_eq!(ids, ["fact-1"]);
    }

    #[test]
    fn test_find_by_type() {
        let (store, _dir) = create_test_store();
        store
            .store(
                "user-1",
                &create_test_fact("fact-1", "Pattern detected in codebase"),
            )
            .unwrap();

        let hits = store.find_by_type("user-1", FactType::Pattern, 10).unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn test_delete() {
        let (store, _dir) = create_test_store();
        store
            .store("user-1", &create_test_fact("fact-1", "Test fact"))
            .unwrap();
        assert!(store.get("user-1", "fact-1").unwrap().is_some());

        store.delete("user-1", "fact-1").unwrap();
        assert!(store.get("user-1", "fact-1").unwrap().is_none());

        // The entity index must not resurrect the deleted fact.
        assert!(store
            .find_by_entity("user-1", "rust", 10)
            .unwrap()
            .is_empty());
    }

    #[test]
    fn test_stats() {
        let (store, _dir) = create_test_store();
        for (id, text) in [("fact-1", "Fact one"), ("fact-2", "Fact two")] {
            store.store("user-1", &create_test_fact(id, text)).unwrap();
        }

        let stats = store.stats("user-1").unwrap();
        assert_eq!(stats.total_facts, 2);
        assert!(stats.avg_confidence > 0.0);
    }
}