// second_brain_core/query.rs

1use anyhow::Result;
2use chrono::Utc;
3use uuid::Uuid;
4
5use crate::schema::{Memory, RelationType};
6use crate::store::Store;
7
8#[derive(Debug, Clone)]
9pub struct QueryRequest {
10    pub text: String,
11    pub embedding: Vec<f32>,
12    pub limit: usize,
13    pub filters: QueryFilters,
14}
15
16#[derive(Debug, Clone, Default)]
17pub struct QueryFilters {
18    pub source: Option<String>,
19    pub memory_type: Option<crate::schema::MemoryType>,
20    pub min_confidence: Option<f32>,
21    pub entity_names: Vec<String>,
22}
23
24#[derive(Debug, Clone)]
25pub struct QueryResult {
26    pub memory: Memory,
27    pub score: f32,
28    pub path: Vec<Uuid>,
29}
30
/// Blended scores below this threshold are discarded by `QueryEngine::recall`.
const MIN_RELEVANCE_SCORE: f32 = 0.59_f32;
32
33pub struct QueryEngine<'a, S: Store> {
34    store: &'a S,
35    vector_weight: f32,
36    graph_weight: f32,
37    recency_weight: f32,
38}
39
40impl<'a, S: Store> QueryEngine<'a, S> {
41    pub fn new(store: &'a S) -> Self {
42        Self {
43            store,
44            vector_weight: 0.5,
45            graph_weight: 0.3,
46            recency_weight: 0.2,
47        }
48    }
49
50    pub fn with_weights(mut self, vector: f32, graph: f32, recency: f32) -> Self {
51        self.vector_weight = vector;
52        self.graph_weight = graph;
53        self.recency_weight = recency;
54        self
55    }
56
57    pub fn recall(&self, request: &QueryRequest) -> Result<Vec<QueryResult>> {
58        let vector_results = self
59            .store
60            .vector_search(&request.embedding, request.limit * 3)?;
61
62        let mut scored: Vec<QueryResult> = Vec::new();
63
64        for (memory, similarity) in vector_results {
65            if let Some(min_conf) = request.filters.min_confidence {
66                if memory.confidence < min_conf {
67                    continue;
68                }
69            }
70            if let Some(ref source) = request.filters.source {
71                if &memory.source != source {
72                    continue;
73                }
74            }
75            if let Some(ref mt) = request.filters.memory_type {
76                if &memory.memory_type != mt {
77                    continue;
78                }
79            }
80
81            let recency_score = self.compute_recency(&memory);
82            let graph_score = self
83                .compute_graph_relevance(&memory, &request.filters.entity_names)
84                .unwrap_or(0.0);
85
86            let final_score = (similarity * self.vector_weight)
87                + (graph_score * self.graph_weight)
88                + (recency_score * self.recency_weight);
89
90            scored.push(QueryResult {
91                memory,
92                score: final_score,
93                path: Vec::new(),
94            });
95        }
96
97        scored.retain(|r| r.score >= MIN_RELEVANCE_SCORE);
98        scored.sort_by(|a, b| {
99            b.score
100                .partial_cmp(&a.score)
101                .unwrap_or(std::cmp::Ordering::Equal)
102        });
103        scored.truncate(request.limit);
104
105        Ok(scored)
106    }
107
108    fn compute_recency(&self, memory: &Memory) -> f32 {
109        let hours_since_access = Utc::now()
110            .signed_duration_since(memory.last_accessed)
111            .num_hours() as f32;
112
113        let decay = (-hours_since_access / (24.0 * 30.0)).exp();
114        let access_boost = (memory.access_count as f32).ln_1p() / 10.0;
115
116        (decay + access_boost).min(1.0)
117    }
118
119    fn compute_graph_relevance(&self, memory: &Memory, _entity_names: &[String]) -> Result<f32> {
120        let scored_types = [
121            (RelationType::Reinforces, 0.3_f32),
122            (RelationType::RelatesTo, 0.2),
123            (RelationType::DistilledFrom, 0.15),
124            (RelationType::Mentions, 0.1),
125            (RelationType::DerivedFrom, 0.05),
126            (RelationType::Contradicts, -0.1),
127            (RelationType::Supersedes, -0.2),
128        ];
129
130        let mut relevance = 0.0_f32;
131        for (rt, boost) in &scored_types {
132            if let Ok(rels) = self.store.get_relations(memory.id, Some(*rt)) {
133                for rel in &rels {
134                    let b = if *rt == RelationType::RelatesTo {
135                        boost * rel.strength
136                    } else {
137                        *boost
138                    };
139                    relevance += b;
140                }
141            }
142        }
143
144        Ok(relevance.clamp(0.0, 1.0))
145    }
146
147    pub fn store(&self) -> &S {
148        self.store
149    }
150}