1use crate::provenance::{ProvenanceBundle, TrustScore, TrustScoreConfig};
2use crate::store::MemoryStore;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::time::Instant;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
8pub enum Lane {
9 Bm25,
10 Trigram,
11 Vector,
12}
13
14#[derive(Debug, Clone, Default, Serialize, Deserialize)]
15pub struct QueryLanes {
16 pub bm25: bool,
17 pub trigram: bool,
18 pub vector: bool,
19 pub bm25_weight: f32,
20 pub trigram_weight: f32,
21 pub vector_weight: f32,
22}
23
24impl QueryLanes {
25 pub fn lexical_only() -> Self {
26 Self {
27 bm25: true,
28 trigram: true,
29 vector: false,
30 bm25_weight: 0.6,
31 trigram_weight: 0.4,
32 vector_weight: 0.0,
33 }
34 }
35
36 pub fn three_lane() -> Self {
37 Self {
38 bm25: true,
39 trigram: true,
40 vector: true,
41 bm25_weight: 0.4,
42 trigram_weight: 0.2,
43 vector_weight: 0.4,
44 }
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct MemoryQuery {
50 pub namespace: String,
51 pub query: String,
52 pub as_of: Option<u64>,
53 pub lanes: QueryLanes,
54 pub k: usize,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct MemoryHit {
59 pub doc_id: u64,
60 pub score: f32,
61 pub lane: Lane,
62 pub snippet: String,
63 pub provenance: ProvenanceBundle,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct MemoryQueryResult {
68 pub hits: Vec<MemoryHit>,
69 pub query_latency_us: u64,
70 pub lanes_used: Vec<Lane>,
71}
72
73impl MemoryStore {
74 pub fn query(&self, q: &MemoryQuery) -> MemoryQueryResult {
76 let start = Instant::now();
77 let k = q.k.max(1);
78 let mut scores: HashMap<u64, f32> = HashMap::new();
79 let mut lanes_used = Vec::new();
80
81 if q.lanes.bm25 {
82 lanes_used.push(Lane::Bm25);
83 for (doc_id, score) in self.search_bm25(&q.namespace, &q.query, k * 2) {
84 *scores.entry(doc_id).or_default() += score * q.lanes.bm25_weight;
85 }
86 }
87
88 if q.lanes.trigram {
89 lanes_used.push(Lane::Trigram);
90 for (doc_id, score) in self.search_trigram_literal(&q.namespace, &q.query, k * 2) {
91 *scores.entry(doc_id).or_default() += score * q.lanes.trigram_weight;
92 }
93 }
94
95 if q.lanes.vector {
96 lanes_used.push(Lane::Vector);
97 for (doc_id, score) in self.search_vector(&q.namespace, &q.query, k * 2) {
98 *scores.entry(doc_id).or_default() += score * q.lanes.vector_weight;
99 }
100 }
101
102 let tau = q.as_of.unwrap_or(u64::MAX);
103 let trust_cfg = TrustScoreConfig::default();
104
105 let mut ranked: Vec<(u64, f32)> = scores.into_iter().collect();
106 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
107 ranked.truncate(k);
108
109 let hits: Vec<MemoryHit> = ranked
110 .into_iter()
111 .filter_map(|(doc_id, score)| {
112 let text = self.episode_text(&q.namespace, doc_id)?;
113 let episode = self
114 .get_episode(&q.namespace, crate::episode::EpisodeId(doc_id))
115 .ok()?;
116 let snippet: String = text.chars().take(256).collect();
117 let provenance = ProvenanceBundle {
118 episode_id: doc_id,
119 t_valid_from: episode.t_valid_from,
120 t_valid_to: if tau < u64::MAX { tau } else { u64::MAX },
121 trust: TrustScore::compute(&trust_cfg, 1, episode.t_created, 0),
122 };
123 Some(MemoryHit {
124 doc_id,
125 score,
126 lane: Lane::Bm25,
127 snippet,
128 provenance,
129 })
130 })
131 .collect();
132
133 MemoryQueryResult {
134 hits,
135 query_latency_us: start.elapsed().as_micros() as u64,
136 lanes_used,
137 }
138 }
139}