1pub mod bootstrap;
2pub mod caching_store;
3pub mod metrics;
4
5use std::collections::HashSet;
6use std::io::Write;
7use std::path::Path;
8use std::sync::Mutex;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::thread;
11
12use anyhow::Result;
13use serde::{Deserialize, Serialize};
14use uuid::Uuid;
15
16use second_brain_core::embedding::Embedder;
17use second_brain_core::kuzu_store::KuzuStore;
18use second_brain_core::query::{QueryEngine, QueryFilters, QueryRequest};
19use second_brain_core::store::Store;
20
/// One labelled query from the evaluation set; `load_eval_set` deserializes one of these from
/// each non-blank line of the JSONL file.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct EvalQuery {
    /// Identifier carried into every `QueryRecord`; `aggregate` pairs arms and looks up
    /// relevant sets by it.
    pub query_id: String,
    /// Query text that is embedded and sent to recall.
    pub query: String,
    /// Free-form phrasing label (the tests use e.g. "literal" and "paraphrase").
    pub query_variant: String,
    /// NOTE(review): presumably the memory this query was written from — confirm.
    pub seed_memory_id: Uuid,
    /// Memory type label as a plain string (the tests use e.g. "decision", "architecture").
    pub memory_type: String,
    /// Gold set: a result whose id is in this list counts as relevant.
    pub relevant_memory_ids: Vec<Uuid>,
    /// Optional note; a missing field deserializes to `None`.
    #[serde(default)]
    pub note: Option<String>,
    /// Optional tags; a missing field deserializes to an empty list.
    #[serde(default)]
    pub tags: Vec<String>,
}
34
/// Outcome of running one eval query through one arm (bare or prefixed embedding).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRecord {
    /// `EvalQuery::query_id` of the query this record belongs to.
    pub query_id: String,
    /// `true` for the prefixed (query-prompt) arm, `false` for the bare arm.
    pub use_prefix: bool,
    /// Memory ids returned by `QueryEngine::recall`, in result order.
    pub ranked_ids: Vec<Uuid>,
    /// Recall score for each entry of `ranked_ids`, index-aligned.
    pub scores: Vec<f32>,
    /// 1-based position of the first relevant id in `ranked_ids`; `None` if none was returned.
    pub first_relevant_rank: Option<usize>,
    /// 1-based rank of the first relevant memory in a raw `vector_search` of `limit * 3` hits;
    /// `None` if it was not among them.
    pub gold_raw_rank: Option<usize>,
    /// Similarity reported by `vector_search` for the hit at `gold_raw_rank`.
    pub gold_raw_similarity: Option<f32>,
}
45
/// Per-arm retrieval metrics; `aggregate` fills each field with the mean of the per-query
/// values computed by the `metrics` module (all zeros for an empty arm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArmMetrics {
    pub recall_at_1: f32,
    pub recall_at_3: f32,
    pub recall_at_5: f32,
    /// Mean reciprocal rank.
    pub mrr: f32,
    pub precision_at_5: f32,
}
54
/// Side-by-side comparison of the bare and prefixed arms, produced by `aggregate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregateReport {
    /// Metrics for the bare-embedding arm.
    pub bare: ArmMetrics,
    /// Metrics for the prefixed-embedding arm.
    pub prefixed: ArmMetrics,
    /// Paired bootstrap interval on per-query (prefixed − bare) recall@3, paired by
    /// `query_id`; the tuple is whatever `bootstrap::paired_bootstrap_ci` returns (elsewhere in
    /// this file the first element is treated as the lower bound).
    pub delta_recall_at_3_ci: (f32, f32),
    /// Same interval for the per-query (prefixed − bare) MRR delta.
    pub delta_mrr_ci: (f32, f32),
    /// `metrics::gated_out_rate` over the bare arm: a record is flagged when its gold memory
    /// was found by raw search with similarity below `BASELINE_THRESHOLD`.
    pub gated_out_rate_bare: f32,
    /// Same rate for the prefixed arm.
    pub gated_out_rate_prefixed: f32,
}
64
/// One threshold on the `gate_sweep` grid and the recall it yields.
///
/// A query counts as recalled at `k` when its gold memory's raw rank is `<= k` and its raw
/// similarity is `>= threshold`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatePoint {
    pub threshold: f32,
    pub recall_at_1: f32,
    pub recall_at_3: f32,
    pub recall_at_5: f32,
    /// `recall_at_5 / 5.0`.
    pub precision_proxy: f32,
}
73
/// Result of `gate_sweep`: the full threshold grid plus the threshold it recommends.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GateSweepReport {
    /// Every grid point, in ascending threshold order.
    pub frontier: Vec<GatePoint>,
    /// The threshold being compared against (`BASELINE_THRESHOLD`).
    pub baseline_threshold: f32,
    /// The recommended threshold; equals `baseline_threshold` when nothing beat it.
    pub chosen_threshold: f32,
    /// Whether `chosen_threshold` came from the grid rather than the baseline.
    pub chosen_beats_baseline: bool,
}
81
/// One memory as written by `extract_corpus` (one JSON object per output line).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorpusEntry {
    pub id: Uuid,
    pub content: String,
    /// Lowercased `Debug` name of the memory's type.
    pub memory_type: String,
    /// Creation time formatted with `to_rfc3339`.
    pub created_at: String,
    pub project_path: Option<String>,
}
90
/// Similarity gate the evaluation measures against.
///
/// `aggregate` flags a record as gated out when its gold memory was found by raw vector search
/// with similarity below this value, and `gate_sweep` uses it as the baseline threshold that a
/// swept threshold has to beat.
const BASELINE_THRESHOLD: f32 = 0.59;
92
93pub fn load_eval_set(path: &Path) -> Result<Vec<EvalQuery>> {
94 let text = std::fs::read_to_string(path)?;
95 let mut out = Vec::new();
96 for line in text.lines() {
97 let trimmed = line.trim();
98 if trimmed.is_empty() {
99 continue;
100 }
101 out.push(serde_json::from_str(trimmed)?);
102 }
103 Ok(out)
104}
105
106pub fn run_arm(
107 store: &KuzuStore,
108 embedder: &Embedder,
109 queries: &[EvalQuery],
110 use_prefix: bool,
111 limit: usize,
112) -> Result<Vec<QueryRecord>> {
113 let engine = QueryEngine::new(store);
114 let mut records = Vec::with_capacity(queries.len());
115
116 for q in queries {
117 let embedding = if use_prefix {
118 embedder.embed_query(&q.query)?
119 } else {
120 embedder.embed(&q.query)?
121 };
122
123 let relevant: HashSet<Uuid> = q.relevant_memory_ids.iter().copied().collect();
124
125 let request = QueryRequest {
126 text: q.query.clone(),
127 embedding: embedding.clone(),
128 limit,
129 filters: QueryFilters::default(),
130 };
131 let results = engine.recall(&request)?;
132
133 let ranked_ids: Vec<Uuid> = results.iter().map(|r| r.memory.id).collect();
134 let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
135 let first_relevant_rank = ranked_ids
136 .iter()
137 .position(|id| relevant.contains(id))
138 .map(|idx| idx + 1);
139
140 let raw = store.vector_search(&embedding, limit * 3)?;
141 let mut gold_raw_rank = None;
142 let mut gold_raw_similarity = None;
143 for (idx, (mem, sim)) in raw.iter().enumerate() {
144 if relevant.contains(&mem.id) {
145 gold_raw_rank = Some(idx + 1);
146 gold_raw_similarity = Some(*sim);
147 break;
148 }
149 }
150
151 records.push(QueryRecord {
152 query_id: q.query_id.clone(),
153 use_prefix,
154 ranked_ids,
155 scores,
156 first_relevant_rank,
157 gold_raw_rank,
158 gold_raw_similarity,
159 });
160 }
161
162 Ok(records)
163}
164
/// An eval query together with both of its precomputed embeddings.
///
/// `embed_all_queries` builds these; `record_for` picks one of the two embeddings per arm.
pub struct EmbeddedQuery {
    /// The eval query (cloned from the input set).
    pub query: EvalQuery,
    /// `query.relevant_memory_ids` collected into a set for membership checks.
    pub relevant: HashSet<Uuid>,
    /// Embedding of the raw query text; used when `use_prefix` is false.
    pub bare_embedding: Vec<f32>,
    /// Embedding of `query_prompt(&query.query)`; used when `use_prefix` is true.
    pub prefixed_embedding: Vec<f32>,
}
171
172pub fn embed_all_queries(
176 embedder: &Embedder,
177 queries: &[EvalQuery],
178) -> Result<Vec<EmbeddedQuery>> {
179 let bare_texts: Vec<&str> = queries.iter().map(|q| q.query.as_str()).collect();
180 let bare = embedder.embed_batch(&bare_texts)?;
181
182 let prefixed_owned: Vec<String> = queries
183 .iter()
184 .map(|q| second_brain_core::embedding::query_prompt(&q.query))
185 .collect();
186 let prefixed_texts: Vec<&str> = prefixed_owned.iter().map(|s| s.as_str()).collect();
187 let prefixed = embedder.embed_batch(&prefixed_texts)?;
188
189 let mut out = Vec::with_capacity(queries.len());
190 for (i, q) in queries.iter().enumerate() {
191 out.push(EmbeddedQuery {
192 query: q.clone(),
193 relevant: q.relevant_memory_ids.iter().copied().collect(),
194 bare_embedding: bare[i].clone(),
195 prefixed_embedding: prefixed[i].clone(),
196 });
197 }
198 Ok(out)
199}
200
201fn record_for<S: Store + Sync>(
202 embedded: &EmbeddedQuery,
203 store: &S,
204 use_prefix: bool,
205 limit: usize,
206) -> Result<QueryRecord> {
207 let engine = QueryEngine::new(store);
208 let embedding = if use_prefix {
209 &embedded.prefixed_embedding
210 } else {
211 &embedded.bare_embedding
212 };
213
214 let request = QueryRequest {
215 text: embedded.query.query.clone(),
216 embedding: embedding.clone(),
217 limit,
218 filters: QueryFilters::default(),
219 };
220 let results = engine.recall(&request)?;
221
222 let ranked_ids: Vec<Uuid> = results.iter().map(|r| r.memory.id).collect();
223 let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
224 let first_relevant_rank = ranked_ids
225 .iter()
226 .position(|id| embedded.relevant.contains(id))
227 .map(|idx| idx + 1);
228
229 let raw = store.vector_search(embedding, limit * 3)?;
230 let mut gold_raw_rank = None;
231 let mut gold_raw_similarity = None;
232 for (idx, (mem, sim)) in raw.iter().enumerate() {
233 if embedded.relevant.contains(&mem.id) {
234 gold_raw_rank = Some(idx + 1);
235 gold_raw_similarity = Some(*sim);
236 break;
237 }
238 }
239
240 Ok(QueryRecord {
241 query_id: embedded.query.query_id.clone(),
242 use_prefix,
243 ranked_ids,
244 scores,
245 first_relevant_rank,
246 gold_raw_rank,
247 gold_raw_similarity,
248 })
249}
250
251pub fn run_arm_parallel<S: Store + Sync>(
257 store: &S,
258 embedded: &[EmbeddedQuery],
259 use_prefix: bool,
260 limit: usize,
261) -> Result<Vec<QueryRecord>> {
262 let total = embedded.len();
263 if total == 0 {
264 return Ok(Vec::new());
265 }
266
267 let workers = thread::available_parallelism()
268 .map(|n| n.get().saturating_sub(1).max(1))
269 .unwrap_or(1);
270 let chunk_size = total.div_ceil(workers);
271
272 let done = AtomicUsize::new(0);
273 let collected: Mutex<Vec<(usize, QueryRecord)>> = Mutex::new(Vec::with_capacity(total));
274 let error: Mutex<Option<anyhow::Error>> = Mutex::new(None);
275
276 thread::scope(|scope| {
277 for chunk in 0..workers {
278 let start = chunk * chunk_size;
279 if start >= total {
280 break;
281 }
282 let end = (start + chunk_size).min(total);
283 let done = &done;
284 let collected = &collected;
285 let error = &error;
286 scope.spawn(move || {
287 let mut local: Vec<(usize, QueryRecord)> = Vec::with_capacity(end - start);
288 for (offset, eq) in embedded[start..end].iter().enumerate() {
289 if error.lock().unwrap().is_some() {
290 return;
291 }
292 match record_for(eq, store, use_prefix, limit) {
293 Ok(rec) => local.push((start + offset, rec)),
294 Err(e) => {
295 *error.lock().unwrap() = Some(e);
296 return;
297 }
298 }
299 let n = done.fetch_add(1, Ordering::Relaxed) + 1;
301 if n % 25 == 0 || n == total {
302 eprintln!(" {n}/{total} queries");
303 }
304 }
305 collected.lock().unwrap().extend(local);
306 });
307 }
308 });
309
310 if let Some(e) = error.into_inner().unwrap() {
311 return Err(e);
312 }
313
314 let mut indexed = collected.into_inner().unwrap();
315 indexed.sort_by_key(|(i, _)| *i);
316 Ok(indexed.into_iter().map(|(_, r)| r).collect())
317}
318
319pub fn aggregate(
320 bare: &[QueryRecord],
321 prefixed: &[QueryRecord],
322 relevant_sets: &std::collections::HashMap<String, HashSet<Uuid>>,
323) -> AggregateReport {
324 const AGG_SEED: u64 = 0x4B1D_C0DE;
325 let empty: HashSet<Uuid> = HashSet::new();
326
327 let per_query = |rec: &QueryRecord| -> (f32, f32, f32, f32, f32) {
328 let rel = relevant_sets.get(&rec.query_id).unwrap_or(&empty);
329 (
330 metrics::recall_at_k(&rec.ranked_ids, rel, 1),
331 metrics::recall_at_k(&rec.ranked_ids, rel, 3),
332 metrics::recall_at_k(&rec.ranked_ids, rel, 5),
333 metrics::mrr(&rec.ranked_ids, rel),
334 metrics::precision_at_k(&rec.ranked_ids, rel, 5),
335 )
336 };
337
338 let arm = |records: &[QueryRecord]| -> ArmMetrics {
339 if records.is_empty() {
340 return ArmMetrics {
341 recall_at_1: 0.0,
342 recall_at_3: 0.0,
343 recall_at_5: 0.0,
344 mrr: 0.0,
345 precision_at_5: 0.0,
346 };
347 }
348 let n = records.len() as f32;
349 let mut acc = (0.0, 0.0, 0.0, 0.0, 0.0);
350 for r in records {
351 let (r1, r3, r5, m, p5) = per_query(r);
352 acc.0 += r1;
353 acc.1 += r3;
354 acc.2 += r5;
355 acc.3 += m;
356 acc.4 += p5;
357 }
358 ArmMetrics {
359 recall_at_1: acc.0 / n,
360 recall_at_3: acc.1 / n,
361 recall_at_5: acc.2 / n,
362 mrr: acc.3 / n,
363 precision_at_5: acc.4 / n,
364 }
365 };
366
367 let bare_idx: std::collections::HashMap<&str, &QueryRecord> =
368 bare.iter().map(|r| (r.query_id.as_str(), r)).collect();
369
370 let mut delta_r3 = Vec::new();
371 let mut delta_mrr = Vec::new();
372 for p_rec in prefixed {
373 if let Some(b_rec) = bare_idx.get(p_rec.query_id.as_str()) {
374 let (_, p_r3, _, p_mrr, _) = per_query(p_rec);
375 let (_, b_r3, _, b_mrr, _) = per_query(b_rec);
376 delta_r3.push(p_r3 - b_r3);
377 delta_mrr.push(p_mrr - b_mrr);
378 }
379 }
380
381 let gated_rate = |records: &[QueryRecord]| -> f32 {
382 let flags: Vec<bool> = records
383 .iter()
384 .map(|r| match (r.gold_raw_rank, r.gold_raw_similarity) {
385 (Some(_), Some(sim)) => sim < BASELINE_THRESHOLD,
386 _ => false,
387 })
388 .collect();
389 metrics::gated_out_rate(&flags)
390 };
391
392 AggregateReport {
393 bare: arm(bare),
394 prefixed: arm(prefixed),
395 delta_recall_at_3_ci: bootstrap::paired_bootstrap_ci(&delta_r3, 10000, 0.95, AGG_SEED),
396 delta_mrr_ci: bootstrap::paired_bootstrap_ci(&delta_mrr, 10000, 0.95, AGG_SEED),
397 gated_out_rate_bare: gated_rate(bare),
398 gated_out_rate_prefixed: gated_rate(prefixed),
399 }
400}
401
402pub fn gate_sweep(prefixed: &[QueryRecord]) -> GateSweepReport {
403 const GRID_SEED: u64 = 0x5EED_6A7E;
404
405 let recalled_at = |rec: &QueryRecord, k: usize, t: f32| -> f32 {
406 match (rec.gold_raw_rank, rec.gold_raw_similarity) {
407 (Some(rank), Some(sim)) if rank <= k && sim >= t => 1.0,
408 _ => 0.0,
409 }
410 };
411
412 let mean = |vals: &[f32]| -> f32 {
413 if vals.is_empty() {
414 0.0
415 } else {
416 vals.iter().sum::<f32>() / vals.len() as f32
417 }
418 };
419
420 let baseline_r3: Vec<f32> = prefixed
421 .iter()
422 .map(|r| recalled_at(r, 3, BASELINE_THRESHOLD))
423 .collect();
424 let baseline_r3_mean = mean(&baseline_r3);
425
426 let mut frontier = Vec::with_capacity(41);
427 let mut chosen_threshold = BASELINE_THRESHOLD;
428 let mut chosen_beats_baseline = false;
429 let mut best_recall_at_3 = baseline_r3_mean;
430
431 for step in 0..=40u32 {
432 let t = 0.40 + step as f32 * 0.01;
433
434 let r1: Vec<f32> = prefixed.iter().map(|r| recalled_at(r, 1, t)).collect();
435 let r3: Vec<f32> = prefixed.iter().map(|r| recalled_at(r, 3, t)).collect();
436 let r5: Vec<f32> = prefixed.iter().map(|r| recalled_at(r, 5, t)).collect();
437
438 let recall_at_3 = mean(&r3);
439 let recall_at_5 = mean(&r5);
440
441 frontier.push(GatePoint {
442 threshold: t,
443 recall_at_1: mean(&r1),
444 recall_at_3,
445 recall_at_5,
446 precision_proxy: recall_at_5 / 5.0,
447 });
448
449 let deltas: Vec<f32> = r3
450 .iter()
451 .zip(baseline_r3.iter())
452 .map(|(t_val, b_val)| t_val - b_val)
453 .collect();
454 let (lo, _hi) = bootstrap::paired_bootstrap_ci(&deltas, 2000, 0.95, GRID_SEED);
455
456 if lo > 0.0 && recall_at_3 > best_recall_at_3 + 1e-6 {
457 best_recall_at_3 = recall_at_3;
458 chosen_threshold = t;
459 chosen_beats_baseline = true;
460 }
461 }
462
463 GateSweepReport {
464 frontier,
465 baseline_threshold: BASELINE_THRESHOLD,
466 chosen_threshold,
467 chosen_beats_baseline,
468 }
469}
470
471pub fn extract_corpus(store: &KuzuStore, out: &Path) -> Result<usize> {
472 let memories = store.all_memories_with_embeddings()?;
473 let mut file = std::fs::File::create(out)?;
474 let mut count = 0;
475 for m in &memories {
476 let entry = CorpusEntry {
477 id: m.id,
478 content: m.content.clone(),
479 memory_type: format!("{:?}", m.memory_type).to_lowercase(),
480 created_at: m.created_at.to_rfc3339(),
481 project_path: m.project_path.clone(),
482 };
483 writeln!(file, "{}", serde_json::to_string(&entry)?)?;
484 count += 1;
485 }
486 Ok(count)
487}
488
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a prefixed-arm record with no ranked results; only the rank/similarity fields vary.
    fn sample_record(
        first_rank: Option<usize>,
        raw_rank: Option<usize>,
        raw_sim: Option<f32>,
    ) -> QueryRecord {
        QueryRecord {
            query_id: "q".to_string(),
            use_prefix: true,
            ranked_ids: Vec::new(),
            scores: Vec::new(),
            first_relevant_rank: first_rank,
            gold_raw_rank: raw_rank,
            gold_raw_similarity: raw_sim,
        }
    }

    /// Four queries: gold at raw rank 1, 2 and 4 with falling similarity, plus one total miss.
    fn sweep_fixture() -> Vec<QueryRecord> {
        vec![
            sample_record(Some(1), Some(1), Some(0.85)),
            sample_record(Some(2), Some(2), Some(0.62)),
            sample_record(Some(4), Some(4), Some(0.55)),
            sample_record(None, None, None),
        ]
    }

    /// A fresh `<stem>_<uuid>.jsonl` path in the system temp dir, unique per call.
    fn unique_temp_path(stem: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{stem}_{}.jsonl", Uuid::new_v4()))
    }

    /// Look up the frontier point whose threshold is within 1e-4 of `threshold`.
    fn point_at(report: &GateSweepReport, threshold: f32) -> &GatePoint {
        report
            .frontier
            .iter()
            .find(|p| (p.threshold - threshold).abs() < 1e-4)
            .unwrap()
    }

    #[test]
    fn load_eval_set_parses_one_object_per_line() {
        let path = unique_temp_path("eval_set");
        let id_a = Uuid::new_v4();
        let id_b = Uuid::new_v4();
        let literal = format!(
            r#"{{"query_id":"q1","query":"kuzu choice","query_variant":"literal","seed_memory_id":"{id_a}","memory_type":"decision","relevant_memory_ids":["{id_a}"]}}"#
        );
        let paraphrase = format!(
            r#"{{"query_id":"q2","query":"sync design","query_variant":"paraphrase","seed_memory_id":"{id_b}","memory_type":"architecture","relevant_memory_ids":["{id_b}","{id_a}"],"tags":["sync"]}}"#
        );
        std::fs::write(&path, format!("{literal}\n{paraphrase}\n")).unwrap();

        let parsed = load_eval_set(&path).unwrap();
        std::fs::remove_file(&path).ok();

        assert_eq!(parsed.len(), 2);
        let (first, second) = (&parsed[0], &parsed[1]);
        assert_eq!(first.query_id, "q1");
        assert_eq!(first.seed_memory_id, id_a);
        assert_eq!(second.relevant_memory_ids.len(), 2);
        assert_eq!(second.tags, vec!["sync".to_string()]);
    }

    #[test]
    fn load_eval_set_tolerates_blank_lines() {
        let path = unique_temp_path("eval_blank");
        let id = Uuid::new_v4();
        let only_line = format!(
            r#"{{"query_id":"q1","query":"x","query_variant":"v","seed_memory_id":"{id}","memory_type":"semantic","relevant_memory_ids":["{id}"]}}"#
        );
        std::fs::write(&path, format!("\n{only_line}\n\n")).unwrap();

        let parsed = load_eval_set(&path).unwrap();
        std::fs::remove_file(&path).ok();

        assert_eq!(parsed.len(), 1);
    }

    #[test]
    fn gate_sweep_emits_full_grid_and_monotone_recall() {
        let report = gate_sweep(&sweep_fixture());
        let frontier = &report.frontier;

        assert_eq!(frontier.len(), 41);
        let lowest = frontier.first().unwrap();
        let highest = frontier.last().unwrap();
        assert!((lowest.threshold - 0.40).abs() < 1e-4);
        assert!((highest.threshold - 0.80).abs() < 1e-4);

        for (looser, tighter) in frontier.iter().zip(frontier.iter().skip(1)) {
            assert!(
                looser.recall_at_3 >= tighter.recall_at_3 - 1e-6,
                "recall must not increase as the gate tightens"
            );
        }
    }

    #[test]
    fn gate_sweep_recall_reflects_raw_rank_and_similarity() {
        let report = gate_sweep(&sweep_fixture());

        let p050 = point_at(&report, 0.50);
        assert!((p050.recall_at_1 - 0.25).abs() < 1e-6, "recall@1 was {}", p050.recall_at_1);
        assert!((p050.recall_at_3 - 0.5).abs() < 1e-6, "recall@3 was {}", p050.recall_at_3);
        assert!((p050.recall_at_5 - 0.75).abs() < 1e-6, "recall@5 was {}", p050.recall_at_5);

        let p070 = point_at(&report, 0.70);
        assert!((p070.recall_at_1 - 0.25).abs() < 1e-6);
        assert!((p070.recall_at_3 - 0.25).abs() < 1e-6);
        assert!((p070.recall_at_5 - 0.25).abs() < 1e-6);
    }

    #[test]
    fn gate_sweep_keeps_baseline_when_nothing_beats_it() {
        let report = gate_sweep(&[sample_record(Some(1), Some(1), Some(0.85))]);
        assert!((report.baseline_threshold - 0.59).abs() < 1e-6);
        assert!(!report.chosen_beats_baseline);
        assert!((report.chosen_threshold - 0.59).abs() < 1e-6);
    }
}