1use anyhow::Result;
13use rocksdb::{IteratorMode, DB};
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16
17use super::compression::{FactType, SemanticFact};
18
/// Result of a fact query: the facts returned plus a count.
///
/// NOTE(review): nothing in the visible code constructs this type; `total` is
/// presumably the number of matching facts before any limit was applied (so it
/// may exceed `facts.len()`) — confirm against the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactQueryResponse {
    /// Facts included in this response.
    pub facts: Vec<SemanticFact>,
    /// Count reported alongside `facts` (see the note on the type).
    pub total: usize,
}
25
/// Aggregate statistics over one user's facts, as returned by
/// `SemanticFactStore::stats`.
///
/// When the user has no facts every field is zero / empty (the derived `Default`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FactStats {
    /// Number of facts counted.
    pub total_facts: usize,
    /// Fact count per type, keyed by the `Debug` name of the `FactType`
    /// (e.g. `"Pattern"`).
    pub by_type: std::collections::HashMap<String, usize>,
    /// Mean `confidence` across the counted facts.
    pub avg_confidence: f32,
    /// Mean `support_count` across the counted facts.
    pub avg_support: f32,
}
34
/// Per-user storage of `SemanticFact`s on top of a RocksDB handle.
///
/// Key layout used by the methods of this type:
/// - `facts:{user_id}:{fact_id}` → bincode-encoded `SemanticFact`
/// - `facts_by_entity:{user_id}:{lower-cased entity}:{fact_id}` → fact id
/// - `facts_by_type:{user_id}:{FactType Debug name}:{fact_id}` → fact id
/// - `facts_embedding:{user_id}:{fact_id}` → bincode-encoded `Vec<f32>`
pub struct SemanticFactStore {
    /// Database handle holding every key listed above.
    db: Arc<DB>,
}
39
40impl SemanticFactStore {
41 pub fn new(db: Arc<DB>) -> Self {
43 Self { db }
44 }
45
46 pub fn databases(&self) -> Vec<(&str, &Arc<DB>)> {
48 vec![("semantic_facts", &self.db)]
49 }
50
51 pub fn store(&self, user_id: &str, fact: &SemanticFact) -> Result<()> {
53 let key = format!("facts:{}:{}", user_id, fact.id);
55 let value = bincode::serde::encode_to_vec(fact, bincode::config::standard())?;
56 self.db.put(key.as_bytes(), &value)?;
57
58 for entity in &fact.related_entities {
60 let entity_key = format!(
61 "facts_by_entity:{}:{}:{}",
62 user_id,
63 entity.to_lowercase(),
64 fact.id
65 );
66 self.db.put(entity_key.as_bytes(), fact.id.as_bytes())?;
67 }
68
69 let type_name = format!("{:?}", fact.fact_type);
71 let type_key = format!("facts_by_type:{}:{}:{}", user_id, type_name, fact.id);
72 self.db.put(type_key.as_bytes(), fact.id.as_bytes())?;
73
74 Ok(())
75 }
76
77 pub fn store_batch(&self, user_id: &str, facts: &[SemanticFact]) -> Result<usize> {
79 let mut stored = 0;
80 for fact in facts {
81 if self.store(user_id, fact).is_ok() {
82 stored += 1;
83 }
84 }
85 Ok(stored)
86 }
87
88 pub fn get(&self, user_id: &str, fact_id: &str) -> Result<Option<SemanticFact>> {
90 let key = format!("facts:{}:{}", user_id, fact_id);
91 match self.db.get(key.as_bytes())? {
92 Some(data) => {
93 let (fact, _): (SemanticFact, _) =
94 bincode::serde::decode_from_slice(&data, bincode::config::standard())?;
95 Ok(Some(fact))
96 }
97 None => Ok(None),
98 }
99 }
100
101 pub fn update(&self, user_id: &str, fact: &SemanticFact) -> Result<()> {
103 let key = format!("facts:{}:{}", user_id, fact.id);
105 let value = bincode::serde::encode_to_vec(fact, bincode::config::standard())?;
106 self.db.put(key.as_bytes(), &value)?;
107 Ok(())
108 }
109
110 pub fn delete(&self, user_id: &str, fact_id: &str) -> Result<bool> {
112 if let Some(fact) = self.get(user_id, fact_id)? {
114 for entity in &fact.related_entities {
116 let entity_key = format!(
117 "facts_by_entity:{}:{}:{}",
118 user_id,
119 entity.to_lowercase(),
120 fact_id
121 );
122 self.db.delete(entity_key.as_bytes())?;
123 }
124
125 let type_name = format!("{:?}", fact.fact_type);
127 let type_key = format!("facts_by_type:{}:{}:{}", user_id, type_name, fact_id);
128 self.db.delete(type_key.as_bytes())?;
129
130 let key = format!("facts:{}:{}", user_id, fact_id);
132 self.db.delete(key.as_bytes())?;
133
134 let _ = self.delete_embedding(user_id, fact_id);
136
137 Ok(true)
138 } else {
139 Ok(false)
140 }
141 }
142
143 pub fn list(&self, user_id: &str, limit: usize) -> Result<Vec<SemanticFact>> {
145 let prefix = format!("facts:{}:", user_id);
146 let mut facts = Vec::new();
147
148 let iter = self.db.iterator(IteratorMode::From(
149 prefix.as_bytes(),
150 rocksdb::Direction::Forward,
151 ));
152
153 for item in iter {
154 let (key, value) = item?;
155 let key_str = String::from_utf8_lossy(&key);
156
157 if !key_str.starts_with(&prefix) {
159 break;
160 }
161
162 if key_str.matches(':').count() > 2 {
164 continue;
165 }
166
167 if let Ok(fact) = bincode::serde::decode_from_slice::<SemanticFact, _>(
168 &value,
169 bincode::config::standard(),
170 )
171 .map(|(v, _)| v)
172 {
173 facts.push(fact);
174 if facts.len() >= limit {
175 break;
176 }
177 }
178 }
179
180 facts.sort_by(|a, b| b.confidence.total_cmp(&a.confidence));
182
183 Ok(facts)
184 }
185
186 pub fn find_by_entity(
188 &self,
189 user_id: &str,
190 entity: &str,
191 limit: usize,
192 ) -> Result<Vec<SemanticFact>> {
193 let prefix = format!("facts_by_entity:{}:{}:", user_id, entity.to_lowercase());
194 let mut facts = Vec::new();
195 let mut seen_ids = std::collections::HashSet::new();
196
197 let iter = self.db.iterator(IteratorMode::From(
198 prefix.as_bytes(),
199 rocksdb::Direction::Forward,
200 ));
201
202 for item in iter {
203 let (key, value) = item?;
204 let key_str = String::from_utf8_lossy(&key);
205
206 if !key_str.starts_with(&prefix) {
207 break;
208 }
209
210 let fact_id = String::from_utf8_lossy(&value);
211 if seen_ids.insert(fact_id.to_string()) {
212 if let Some(fact) = self.get(user_id, &fact_id)? {
213 facts.push(fact);
214 if facts.len() >= limit {
215 break;
216 }
217 }
218 }
219 }
220
221 Ok(facts)
222 }
223
224 pub fn find_by_type(
226 &self,
227 user_id: &str,
228 fact_type: FactType,
229 limit: usize,
230 ) -> Result<Vec<SemanticFact>> {
231 let type_name = format!("{:?}", fact_type);
232 let prefix = format!("facts_by_type:{}:{}:", user_id, type_name);
233 let mut facts = Vec::new();
234
235 let iter = self.db.iterator(IteratorMode::From(
236 prefix.as_bytes(),
237 rocksdb::Direction::Forward,
238 ));
239
240 for item in iter {
241 let (key, value) = item?;
242 let key_str = String::from_utf8_lossy(&key);
243
244 if !key_str.starts_with(&prefix) {
245 break;
246 }
247
248 let fact_id = String::from_utf8_lossy(&value);
249 if let Some(fact) = self.get(user_id, &fact_id)? {
250 facts.push(fact);
251 if facts.len() >= limit {
252 break;
253 }
254 }
255 }
256
257 Ok(facts)
258 }
259
260 pub fn search(&self, user_id: &str, query: &str, limit: usize) -> Result<Vec<SemanticFact>> {
262 let query_lower = query.to_lowercase();
263 let all_facts = self.list(user_id, 1000)?; let mut matching: Vec<SemanticFact> = all_facts
266 .into_iter()
267 .filter(|f| f.fact.to_lowercase().contains(&query_lower))
268 .collect();
269
270 matching.truncate(limit);
271 Ok(matching)
272 }
273
274 pub fn stats(&self, user_id: &str) -> Result<FactStats> {
276 let facts = self.list(user_id, 10000)?;
277
278 if facts.is_empty() {
279 return Ok(FactStats::default());
280 }
281
282 let mut by_type: std::collections::HashMap<String, usize> =
283 std::collections::HashMap::new();
284 let mut total_confidence: f32 = 0.0;
285 let mut total_support: usize = 0;
286
287 for fact in &facts {
288 let type_name = format!("{:?}", fact.fact_type);
289 *by_type.entry(type_name).or_insert(0) += 1;
290 total_confidence += fact.confidence;
291 total_support += fact.support_count;
292 }
293
294 let count = facts.len();
295 Ok(FactStats {
296 total_facts: count,
297 by_type,
298 avg_confidence: total_confidence / count as f32,
299 avg_support: total_support as f32 / count as f32,
300 })
301 }
302
303 pub fn latest_fact_created_at(&self, user_id: &str) -> Option<i64> {
308 let prefix = format!("facts:{user_id}:");
309 let mut max_millis: Option<i64> = None;
310
311 let iter = self.db.iterator(IteratorMode::From(
312 prefix.as_bytes(),
313 rocksdb::Direction::Forward,
314 ));
315
316 for item in iter {
317 let (key, value) = match item {
318 Ok(kv) => kv,
319 Err(_) => break,
320 };
321 let key_str = String::from_utf8_lossy(&key);
322 if !key_str.starts_with(&prefix) {
323 break;
324 }
325 if key_str.matches(':').count() > 2 {
327 continue;
328 }
329 if let Ok((fact, _)) = bincode::serde::decode_from_slice::<SemanticFact, _>(
330 &value,
331 bincode::config::standard(),
332 ) {
333 let millis = fact.created_at.timestamp_millis();
334 max_millis = Some(max_millis.map_or(millis, |cur| cur.max(millis)));
335 }
336 }
337
338 max_millis
339 }
340
341 pub fn find_decaying_facts(
343 &self,
344 user_id: &str,
345 max_age_days: i64,
346 ) -> Result<Vec<SemanticFact>> {
347 let cutoff = chrono::Utc::now() - chrono::Duration::days(max_age_days);
348 let all_facts = self.list(user_id, 10000)?;
349
350 let decaying: Vec<SemanticFact> = all_facts
351 .into_iter()
352 .filter(|f| f.last_reinforced < cutoff)
353 .collect();
354
355 Ok(decaying)
356 }
357
358 pub fn find_similar(
368 &self,
369 user_id: &str,
370 fact_content: &str,
371 fact_entities: &[String],
372 new_embedding: Option<&[f32]>,
373 ) -> Result<Option<SemanticFact>> {
374 use crate::constants::{
375 FACT_DEDUP_COSINE_THRESHOLD, FACT_DEDUP_JACCARD_FALLBACK, FACT_DEDUP_JACCARD_FLOOR,
376 };
377 use crate::similarity::cosine_similarity;
378
379 let facts = self.list(user_id, 1000)?;
380 let query_lower = fact_content.to_lowercase();
381 let query_words: std::collections::HashSet<&str> = query_lower.split_whitespace().collect();
382 let new_polarity = detect_polarity(&query_lower);
383 let new_entity_set: std::collections::HashSet<&str> =
384 fact_entities.iter().map(|s| s.as_str()).collect();
385
386 let use_hybrid = new_embedding.is_some();
387 let mut best_match: Option<(f32, SemanticFact)> = None;
388
389 for fact in facts {
390 let fact_lower = fact.fact.to_lowercase();
391 let fact_words: std::collections::HashSet<&str> =
392 fact_lower.split_whitespace().collect();
393
394 let intersection = query_words.intersection(&fact_words).count();
396 let union = query_words.union(&fact_words).count();
397 let jaccard = if union > 0 {
398 intersection as f32 / union as f32
399 } else {
400 0.0
401 };
402
403 if use_hybrid {
404 let existing_entity_set: std::collections::HashSet<&str> =
406 fact.related_entities.iter().map(|s| s.as_str()).collect();
407 let both_empty = new_entity_set.is_empty() && existing_entity_set.is_empty();
408 let has_overlap = !new_entity_set.is_disjoint(&existing_entity_set);
409 if !both_empty && !has_overlap {
410 continue;
411 }
412
413 let existing_polarity = detect_polarity(&fact_lower);
415 if new_polarity != existing_polarity {
416 continue;
417 }
418
419 let new_emb = new_embedding.unwrap();
421 match self.get_embedding(user_id, &fact.id) {
422 Ok(Some(existing_emb)) => {
423 let cosine = cosine_similarity(new_emb, &existing_emb);
424 if cosine < FACT_DEDUP_COSINE_THRESHOLD {
425 continue;
426 }
427
428 if jaccard < FACT_DEDUP_JACCARD_FLOOR {
430 continue;
431 }
432
433 if best_match.as_ref().map_or(true, |(s, _)| cosine > *s) {
435 best_match = Some((cosine, fact));
436 }
437 }
438 _ => {
439 if jaccard >= FACT_DEDUP_JACCARD_FALLBACK
441 && best_match.as_ref().map_or(true, |(s, _)| jaccard > *s)
442 {
443 best_match = Some((jaccard, fact));
444 }
445 }
446 }
447 } else {
448 if jaccard >= FACT_DEDUP_JACCARD_FALLBACK {
450 return Ok(Some(fact));
451 }
452 }
453 }
454
455 Ok(best_match.map(|(_, fact)| fact))
456 }
457
458 pub fn store_embedding(&self, user_id: &str, fact_id: &str, embedding: &[f32]) -> Result<()> {
467 let key = format!("facts_embedding:{user_id}:{fact_id}");
468 let value = bincode::serde::encode_to_vec(embedding, bincode::config::standard())?;
469 self.db.put(key.as_bytes(), &value)?;
470 Ok(())
471 }
472
473 pub fn get_embedding(&self, user_id: &str, fact_id: &str) -> Result<Option<Vec<f32>>> {
475 let key = format!("facts_embedding:{user_id}:{fact_id}");
476 match self.db.get(key.as_bytes())? {
477 Some(data) => {
478 let (embedding, _): (Vec<f32>, _) =
479 bincode::serde::decode_from_slice(&data, bincode::config::standard())?;
480 Ok(Some(embedding))
481 }
482 None => Ok(None),
483 }
484 }
485
486 pub fn delete_embedding(&self, user_id: &str, fact_id: &str) -> Result<()> {
488 let key = format!("facts_embedding:{user_id}:{fact_id}");
489 self.db.delete(key.as_bytes())?;
490 Ok(())
491 }
492
493 pub fn list_users(&self, limit: usize) -> Result<Vec<String>> {
495 let prefix = "facts:";
496 let mut users = std::collections::HashSet::new();
497
498 let iter = self.db.iterator(IteratorMode::From(
499 prefix.as_bytes(),
500 rocksdb::Direction::Forward,
501 ));
502
503 for item in iter {
504 let (key, _) = item?;
505 let key_str = String::from_utf8_lossy(&key);
506
507 if !key_str.starts_with(prefix) {
508 break;
509 }
510
511 if key_str.starts_with("facts_by_") {
514 continue;
515 }
516
517 let parts: Vec<&str> = key_str.splitn(3, ':').collect();
519 if parts.len() >= 2 {
520 users.insert(parts[1].to_string());
521 if users.len() >= limit {
522 break;
523 }
524 }
525 }
526
527 Ok(users.into_iter().collect())
528 }
529}
530
531fn detect_polarity(text_lower: &str) -> bool {
537 use crate::constants::FACT_NEGATION_MARKERS;
538 let words: Vec<&str> = text_lower.split_whitespace().collect();
539 let negation_count = words
540 .iter()
541 .filter(|w| FACT_NEGATION_MARKERS.iter().any(|marker| *w == marker))
542 .count();
543 negation_count % 2 == 0
544}
545
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Open a store on a fresh temporary directory.
    ///
    /// The returned `TempDir` must outlive the store: dropping it deletes the
    /// directory.
    fn create_test_store() -> (SemanticFactStore, TempDir) {
        let dir = TempDir::new().unwrap();
        let db = DB::open_default(dir.path()).unwrap();
        let store = SemanticFactStore::new(Arc::new(db));
        (store, dir)
    }

    /// Build a `Pattern` fact with confidence 0.8, support 3, and the related
    /// entities `rust` and `memory`.
    fn create_test_fact(id: &str, content: &str) -> SemanticFact {
        let related_entities: Vec<String> =
            ["rust", "memory"].iter().map(|e| e.to_string()).collect();
        SemanticFact {
            id: id.to_owned(),
            fact: content.to_owned(),
            confidence: 0.8,
            support_count: 3,
            source_memories: Vec::new(),
            related_entities,
            created_at: chrono::Utc::now(),
            last_reinforced: chrono::Utc::now(),
            fact_type: FactType::Pattern,
        }
    }

    #[test]
    fn test_store_and_get() {
        let (store, _dir) = create_test_store();
        let content = "Rust is a systems programming language";
        store
            .store("user-1", &create_test_fact("fact-1", content))
            .unwrap();

        let fetched = store.get("user-1", "fact-1").unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().fact, content);
    }

    #[test]
    fn test_find_by_entity() {
        let (store, _dir) = create_test_store();
        store
            .store(
                "user-1",
                &create_test_fact("fact-1", "Rust has efficient memory management"),
            )
            .unwrap();

        let hits = store.find_by_entity("user-1", "rust", 10).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "fact-1");
    }

    #[test]
    fn test_find_by_type() {
        let (store, _dir) = create_test_store();
        store
            .store(
                "user-1",
                &create_test_fact("fact-1", "Pattern detected in codebase"),
            )
            .unwrap();

        let patterns = store.find_by_type("user-1", FactType::Pattern, 10).unwrap();
        assert_eq!(patterns.len(), 1);
    }

    #[test]
    fn test_delete() {
        let (store, _dir) = create_test_store();
        store
            .store("user-1", &create_test_fact("fact-1", "Test fact"))
            .unwrap();
        assert!(store.get("user-1", "fact-1").unwrap().is_some());

        store.delete("user-1", "fact-1").unwrap();
        assert!(store.get("user-1", "fact-1").unwrap().is_none());

        // The entity index must be cleaned up along with the primary record.
        assert!(store.find_by_entity("user-1", "rust", 10).unwrap().is_empty());
    }

    #[test]
    fn test_stats() {
        let (store, _dir) = create_test_store();
        for (id, content) in [("fact-1", "Fact one"), ("fact-2", "Fact two")] {
            store.store("user-1", &create_test_fact(id, content)).unwrap();
        }

        let stats = store.stats("user-1").unwrap();
        assert_eq!(stats.total_facts, 2);
        assert!(stats.avg_confidence > 0.0);
    }
}