// second_brain_core/query.rs

1use anyhow::Result;
2use chrono::Utc;
3use uuid::Uuid;
4
5use crate::schema::{Memory, RelationType};
6use crate::store::Store;
7
/// Input to [`QueryEngine::recall`]: the query itself plus result limits and filters.
#[derive(Debug, Clone)]
pub struct QueryRequest {
    /// Raw query text.
    /// NOTE(review): `recall` in this file never reads this field — presumably it is
    /// kept for callers or logging; confirm.
    pub text: String,
    /// Query embedding passed to the store's `vector_search`.
    pub embedding: Vec<f32>,
    /// Maximum number of results returned by `recall`; the vector search is asked
    /// for three times this many candidates before filtering.
    pub limit: usize,
    /// Optional constraints and hints applied to the candidates.
    pub filters: QueryFilters,
}
15
/// Optional constraints applied by [`QueryEngine::recall`].
///
/// The `Default` value (all `None` / empty) applies no constraint.
#[derive(Debug, Clone, Default)]
pub struct QueryFilters {
    /// When set, only memories whose `source` equals this string are kept.
    pub source: Option<String>,
    /// When set, only memories whose `memory_type` equals this value are kept.
    pub memory_type: Option<crate::schema::MemoryType>,
    /// When set, memories with `confidence` below this value are skipped.
    pub min_confidence: Option<f32>,
    /// Entity names forwarded to the graph-relevance computation.
    /// NOTE(review): that computation currently ignores them (its parameter is
    /// `_entity_names`), so this field has no effect on results yet.
    pub entity_names: Vec<String>,
    /// When set, memories whose `project_path` is exactly this string get their score
    /// multiplied by `PROJECT_AFFINITY_BOOST`. This is a boost, not a filter: memories
    /// from other projects are still returned.
    pub project_path: Option<String>,
}
24
/// One ranked hit produced by [`QueryEngine::recall`].
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// The matched memory, as returned by the store's vector search.
    pub memory: Memory,
    /// Final combined score (weighted similarity + graph relevance + recency, with the
    /// project-affinity multiplier applied when it matches).
    pub score: f32,
    /// NOTE(review): `recall` always leaves this empty; presumably reserved for a
    /// graph traversal path of memory ids — confirm intended use.
    pub path: Vec<Uuid>,
}
31
/// Results whose final combined score is below this threshold are dropped by `recall`.
const MIN_RELEVANCE_SCORE: f32 = 0.59;
/// Multiplier applied to a memory's score when its `project_path` equals the
/// query's `project_path` filter.
const PROJECT_AFFINITY_BOOST: f32 = 1.15;
34
/// Ranks memories from a borrowed [`Store`] by blending vector similarity,
/// graph relevance and recency.
pub struct QueryEngine<'a, S: Store> {
    /// Backing store used for vector search and relation lookups.
    store: &'a S,
    /// Weight of the vector-similarity component of the final score.
    vector_weight: f32,
    /// Weight of the graph-relevance component of the final score.
    graph_weight: f32,
    /// Weight of the recency component of the final score.
    recency_weight: f32,
}
41
42impl<'a, S: Store> QueryEngine<'a, S> {
43    pub fn new(store: &'a S) -> Self {
44        Self {
45            store,
46            vector_weight: 0.5,
47            graph_weight: 0.3,
48            recency_weight: 0.2,
49        }
50    }
51
52    pub fn with_weights(mut self, vector: f32, graph: f32, recency: f32) -> Self {
53        self.vector_weight = vector;
54        self.graph_weight = graph;
55        self.recency_weight = recency;
56        self
57    }
58
59    pub fn recall(&self, request: &QueryRequest) -> Result<Vec<QueryResult>> {
60        let vector_results = self
61            .store
62            .vector_search(&request.embedding, request.limit * 3)?;
63
64        let mut scored: Vec<QueryResult> = Vec::new();
65
66        for (memory, similarity) in vector_results {
67            if let Some(min_conf) = request.filters.min_confidence
68                && memory.confidence < min_conf
69            {
70                continue;
71            }
72            if let Some(ref source) = request.filters.source
73                && &memory.source != source
74            {
75                continue;
76            }
77            if let Some(ref mt) = request.filters.memory_type
78                && &memory.memory_type != mt
79            {
80                continue;
81            }
82
83            let recency_score = self.compute_recency(&memory);
84            let graph_score = self
85                .compute_graph_relevance(&memory, &request.filters.entity_names)
86                .unwrap_or(0.0);
87
88            let mut final_score = (similarity * self.vector_weight)
89                + (graph_score * self.graph_weight)
90                + (recency_score * self.recency_weight);
91
92            if let Some(ref qp) = request.filters.project_path {
93                if memory.project_path.as_deref() == Some(qp.as_str()) {
94                    final_score *= PROJECT_AFFINITY_BOOST;
95                }
96            }
97
98            scored.push(QueryResult {
99                memory,
100                score: final_score,
101                path: Vec::new(),
102            });
103        }
104
105        scored.retain(|r| r.score >= MIN_RELEVANCE_SCORE);
106        scored.sort_by(|a, b| {
107            b.score
108                .partial_cmp(&a.score)
109                .unwrap_or(std::cmp::Ordering::Equal)
110        });
111        scored.truncate(request.limit);
112
113        Ok(scored)
114    }
115
116    fn compute_recency(&self, memory: &Memory) -> f32 {
117        let hours_since_access = Utc::now()
118            .signed_duration_since(memory.last_accessed)
119            .num_hours() as f32;
120
121        let decay = (-hours_since_access / (24.0 * 30.0)).exp();
122        let access_boost = (memory.access_count as f32).ln_1p() / 10.0;
123
124        (decay + access_boost).min(1.0)
125    }
126
127    fn compute_graph_relevance(&self, memory: &Memory, _entity_names: &[String]) -> Result<f32> {
128        let scored_types = [
129            (RelationType::Reinforces, 0.3_f32),
130            (RelationType::RelatesTo, 0.2),
131            (RelationType::DistilledFrom, 0.15),
132            (RelationType::Mentions, 0.1),
133            (RelationType::DerivedFrom, 0.05),
134            (RelationType::Contradicts, -0.1),
135            (RelationType::Supersedes, -0.2),
136        ];
137
138        let mut relevance = 0.0_f32;
139        for (rt, boost) in &scored_types {
140            if let Ok(rels) = self.store.get_relations(memory.id, Some(*rt)) {
141                for rel in &rels {
142                    let b = if *rt == RelationType::RelatesTo {
143                        boost * rel.strength
144                    } else {
145                        *boost
146                    };
147                    relevance += b;
148                }
149            }
150        }
151
152        Ok(relevance.clamp(0.0, 1.0))
153    }
154
155    pub fn store(&self) -> &S {
156        self.store
157    }
158}