Skip to main content

sochdb_memory/
query.rs

1use crate::provenance::{ProvenanceBundle, TrustScore, TrustScoreConfig};
2use crate::store::MemoryStore;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::time::Instant;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
8pub enum Lane {
9    Bm25,
10    Trigram,
11    Vector,
12}
13
14#[derive(Debug, Clone, Default, Serialize, Deserialize)]
15pub struct QueryLanes {
16    pub bm25: bool,
17    pub trigram: bool,
18    pub vector: bool,
19    pub bm25_weight: f32,
20    pub trigram_weight: f32,
21    pub vector_weight: f32,
22}
23
24impl QueryLanes {
25    pub fn lexical_only() -> Self {
26        Self {
27            bm25: true,
28            trigram: true,
29            vector: false,
30            bm25_weight: 0.6,
31            trigram_weight: 0.4,
32            vector_weight: 0.0,
33        }
34    }
35
36    pub fn three_lane() -> Self {
37        Self {
38            bm25: true,
39            trigram: true,
40            vector: true,
41            bm25_weight: 0.4,
42            trigram_weight: 0.2,
43            vector_weight: 0.4,
44        }
45    }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct MemoryQuery {
50    pub namespace: String,
51    pub query: String,
52    pub as_of: Option<u64>,
53    pub lanes: QueryLanes,
54    pub k: usize,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct MemoryHit {
59    pub doc_id: u64,
60    pub score: f32,
61    pub lane: Lane,
62    pub snippet: String,
63    pub provenance: ProvenanceBundle,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct MemoryQueryResult {
68    pub hits: Vec<MemoryHit>,
69    pub query_latency_us: u64,
70    pub lanes_used: Vec<Lane>,
71}
72
73impl MemoryStore {
74    /// Three-lane fusion: BM25 + trigram (+ vector when enriched).
75    pub fn query(&self, q: &MemoryQuery) -> MemoryQueryResult {
76        let start = Instant::now();
77        let k = q.k.max(1);
78        let mut scores: HashMap<u64, f32> = HashMap::new();
79        let mut lanes_used = Vec::new();
80
81        if q.lanes.bm25 {
82            lanes_used.push(Lane::Bm25);
83            for (doc_id, score) in self.search_bm25(&q.namespace, &q.query, k * 2) {
84                *scores.entry(doc_id).or_default() += score * q.lanes.bm25_weight;
85            }
86        }
87
88        if q.lanes.trigram {
89            lanes_used.push(Lane::Trigram);
90            for (doc_id, score) in self.search_trigram_literal(&q.namespace, &q.query, k * 2) {
91                *scores.entry(doc_id).or_default() += score * q.lanes.trigram_weight;
92            }
93        }
94
95        if q.lanes.vector {
96            lanes_used.push(Lane::Vector);
97            for (doc_id, score) in self.search_vector(&q.namespace, &q.query, k * 2) {
98                *scores.entry(doc_id).or_default() += score * q.lanes.vector_weight;
99            }
100        }
101
102        let tau = q.as_of.unwrap_or(u64::MAX);
103        let trust_cfg = TrustScoreConfig::default();
104
105        let mut ranked: Vec<(u64, f32)> = scores.into_iter().collect();
106        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
107        ranked.truncate(k);
108
109        let hits: Vec<MemoryHit> = ranked
110            .into_iter()
111            .filter_map(|(doc_id, score)| {
112                let text = self.episode_text(&q.namespace, doc_id)?;
113                let episode = self
114                    .get_episode(&q.namespace, crate::episode::EpisodeId(doc_id))
115                    .ok()?;
116                let snippet: String = text.chars().take(256).collect();
117                let provenance = ProvenanceBundle {
118                    episode_id: doc_id,
119                    t_valid_from: episode.t_valid_from,
120                    t_valid_to: if tau < u64::MAX { tau } else { u64::MAX },
121                    trust: TrustScore::compute(&trust_cfg, 1, episode.t_created, 0),
122                };
123                Some(MemoryHit {
124                    doc_id,
125                    score,
126                    lane: Lane::Bm25,
127                    snippet,
128                    provenance,
129                })
130            })
131            .collect();
132
133        MemoryQueryResult {
134            hits,
135            query_latency_us: start.elapsed().as_micros() as u64,
136            lanes_used,
137        }
138    }
139}