Skip to main content

second_brain_api/eval/
mod.rs

1pub mod bootstrap;
2pub mod caching_store;
3pub mod metrics;
4
5use std::collections::HashSet;
6use std::io::Write;
7use std::path::Path;
8use std::sync::Mutex;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::thread;
11
12use anyhow::Result;
13use serde::{Deserialize, Serialize};
14use uuid::Uuid;
15
16use second_brain_core::embedding::Embedder;
17use second_brain_core::kuzu_store::KuzuStore;
18use second_brain_core::query::{QueryEngine, QueryFilters, QueryRequest};
19use second_brain_core::store::Store;
20
21#[derive(Debug, Clone, Deserialize, Serialize)]
22pub struct EvalQuery {
23    pub query_id: String,
24    pub query: String,
25    pub query_variant: String,
26    pub seed_memory_id: Uuid,
27    pub memory_type: String,
28    pub relevant_memory_ids: Vec<Uuid>,
29    #[serde(default)]
30    pub note: Option<String>,
31    #[serde(default)]
32    pub tags: Vec<String>,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct QueryRecord {
37    pub query_id: String,
38    pub use_prefix: bool,
39    pub ranked_ids: Vec<Uuid>,
40    pub scores: Vec<f32>,
41    pub first_relevant_rank: Option<usize>,
42    pub gold_raw_rank: Option<usize>,
43    pub gold_raw_similarity: Option<f32>,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct ArmMetrics {
48    pub recall_at_1: f32,
49    pub recall_at_3: f32,
50    pub recall_at_5: f32,
51    pub mrr: f32,
52    pub precision_at_5: f32,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct AggregateReport {
57    pub bare: ArmMetrics,
58    pub prefixed: ArmMetrics,
59    pub delta_recall_at_3_ci: (f32, f32),
60    pub delta_mrr_ci: (f32, f32),
61    pub gated_out_rate_bare: f32,
62    pub gated_out_rate_prefixed: f32,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct GatePoint {
67    pub threshold: f32,
68    pub recall_at_1: f32,
69    pub recall_at_3: f32,
70    pub recall_at_5: f32,
71    pub precision_proxy: f32,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct GateSweepReport {
76    pub frontier: Vec<GatePoint>,
77    pub baseline_threshold: f32,
78    pub chosen_threshold: f32,
79    pub chosen_beats_baseline: bool,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct CorpusEntry {
84    pub id: Uuid,
85    pub content: String,
86    pub memory_type: String,
87    pub created_at: String,
88    pub project_path: Option<String>,
89}
90
/// Similarity threshold used as the reference gate: `aggregate` counts a gold hit
/// scoring below it as gated out, and `gate_sweep` compares every candidate against it.
const BASELINE_THRESHOLD: f32 = 0.59;
92
93pub fn load_eval_set(path: &Path) -> Result<Vec<EvalQuery>> {
94    let text = std::fs::read_to_string(path)?;
95    let mut out = Vec::new();
96    for line in text.lines() {
97        let trimmed = line.trim();
98        if trimmed.is_empty() {
99            continue;
100        }
101        out.push(serde_json::from_str(trimmed)?);
102    }
103    Ok(out)
104}
105
106pub fn run_arm(
107    store: &KuzuStore,
108    embedder: &Embedder,
109    queries: &[EvalQuery],
110    use_prefix: bool,
111    limit: usize,
112) -> Result<Vec<QueryRecord>> {
113    let engine = QueryEngine::new(store);
114    let mut records = Vec::with_capacity(queries.len());
115
116    for q in queries {
117        let embedding = if use_prefix {
118            embedder.embed_query(&q.query)?
119        } else {
120            embedder.embed(&q.query)?
121        };
122
123        let relevant: HashSet<Uuid> = q.relevant_memory_ids.iter().copied().collect();
124
125        let request = QueryRequest {
126            text: q.query.clone(),
127            embedding: embedding.clone(),
128            limit,
129            filters: QueryFilters::default(),
130        };
131        let results = engine.recall(&request)?;
132
133        let ranked_ids: Vec<Uuid> = results.iter().map(|r| r.memory.id).collect();
134        let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
135        let first_relevant_rank = ranked_ids
136            .iter()
137            .position(|id| relevant.contains(id))
138            .map(|idx| idx + 1);
139
140        let raw = store.vector_search(&embedding, limit * 3)?;
141        let mut gold_raw_rank = None;
142        let mut gold_raw_similarity = None;
143        for (idx, (mem, sim)) in raw.iter().enumerate() {
144            if relevant.contains(&mem.id) {
145                gold_raw_rank = Some(idx + 1);
146                gold_raw_similarity = Some(*sim);
147                break;
148            }
149        }
150
151        records.push(QueryRecord {
152            query_id: q.query_id.clone(),
153            use_prefix,
154            ranked_ids,
155            scores,
156            first_relevant_rank,
157            gold_raw_rank,
158            gold_raw_similarity,
159        });
160    }
161
162    Ok(records)
163}
164
165pub struct EmbeddedQuery {
166    pub query: EvalQuery,
167    pub relevant: HashSet<Uuid>,
168    pub bare_embedding: Vec<f32>,
169    pub prefixed_embedding: Vec<f32>,
170}
171
172// The Embedder wraps the model in a Mutex, so embedding cannot run in parallel.
173// We pre-compute every embedding serially here, in batches, to keep the model
174// out of the parallel recall section entirely.
175pub fn embed_all_queries(
176    embedder: &Embedder,
177    queries: &[EvalQuery],
178) -> Result<Vec<EmbeddedQuery>> {
179    let bare_texts: Vec<&str> = queries.iter().map(|q| q.query.as_str()).collect();
180    let bare = embedder.embed_batch(&bare_texts)?;
181
182    let prefixed_owned: Vec<String> = queries
183        .iter()
184        .map(|q| second_brain_core::embedding::query_prompt(&q.query))
185        .collect();
186    let prefixed_texts: Vec<&str> = prefixed_owned.iter().map(|s| s.as_str()).collect();
187    let prefixed = embedder.embed_batch(&prefixed_texts)?;
188
189    let mut out = Vec::with_capacity(queries.len());
190    for (i, q) in queries.iter().enumerate() {
191        out.push(EmbeddedQuery {
192            query: q.clone(),
193            relevant: q.relevant_memory_ids.iter().copied().collect(),
194            bare_embedding: bare[i].clone(),
195            prefixed_embedding: prefixed[i].clone(),
196        });
197    }
198    Ok(out)
199}
200
201fn record_for<S: Store + Sync>(
202    embedded: &EmbeddedQuery,
203    store: &S,
204    use_prefix: bool,
205    limit: usize,
206) -> Result<QueryRecord> {
207    let engine = QueryEngine::new(store);
208    let embedding = if use_prefix {
209        &embedded.prefixed_embedding
210    } else {
211        &embedded.bare_embedding
212    };
213
214    let request = QueryRequest {
215        text: embedded.query.query.clone(),
216        embedding: embedding.clone(),
217        limit,
218        filters: QueryFilters::default(),
219    };
220    let results = engine.recall(&request)?;
221
222    let ranked_ids: Vec<Uuid> = results.iter().map(|r| r.memory.id).collect();
223    let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
224    let first_relevant_rank = ranked_ids
225        .iter()
226        .position(|id| embedded.relevant.contains(id))
227        .map(|idx| idx + 1);
228
229    let raw = store.vector_search(embedding, limit * 3)?;
230    let mut gold_raw_rank = None;
231    let mut gold_raw_similarity = None;
232    for (idx, (mem, sim)) in raw.iter().enumerate() {
233        if embedded.relevant.contains(&mem.id) {
234            gold_raw_rank = Some(idx + 1);
235            gold_raw_similarity = Some(*sim);
236            break;
237        }
238    }
239
240    Ok(QueryRecord {
241        query_id: embedded.query.query_id.clone(),
242        use_prefix,
243        ranked_ids,
244        scores,
245        first_relevant_rank,
246        gold_raw_rank,
247        gold_raw_similarity,
248    })
249}
250
251// Each store call creates its own Kuzu Connection (KuzuStore::conn does
252// Connection::new per call), which is the same concurrent-read pattern the
253// daemon uses to serve overlapping recall requests, so sharing &store across
254// scoped threads is safe. The Embedder is absent here on purpose: embeddings
255// are pre-computed in embed_all_queries.
256pub fn run_arm_parallel<S: Store + Sync>(
257    store: &S,
258    embedded: &[EmbeddedQuery],
259    use_prefix: bool,
260    limit: usize,
261) -> Result<Vec<QueryRecord>> {
262    let total = embedded.len();
263    if total == 0 {
264        return Ok(Vec::new());
265    }
266
267    let workers = thread::available_parallelism()
268        .map(|n| n.get().saturating_sub(1).max(1))
269        .unwrap_or(1);
270    let chunk_size = total.div_ceil(workers);
271
272    let done = AtomicUsize::new(0);
273    let collected: Mutex<Vec<(usize, QueryRecord)>> = Mutex::new(Vec::with_capacity(total));
274    let error: Mutex<Option<anyhow::Error>> = Mutex::new(None);
275
276    thread::scope(|scope| {
277        for chunk in 0..workers {
278            let start = chunk * chunk_size;
279            if start >= total {
280                break;
281            }
282            let end = (start + chunk_size).min(total);
283            let done = &done;
284            let collected = &collected;
285            let error = &error;
286            scope.spawn(move || {
287                let mut local: Vec<(usize, QueryRecord)> = Vec::with_capacity(end - start);
288                for (offset, eq) in embedded[start..end].iter().enumerate() {
289                    if error.lock().unwrap().is_some() {
290                        return;
291                    }
292                    match record_for(eq, store, use_prefix, limit) {
293                        Ok(rec) => local.push((start + offset, rec)),
294                        Err(e) => {
295                            *error.lock().unwrap() = Some(e);
296                            return;
297                        }
298                    }
299                    // eprintln progress because long runs were previously invisible.
300                    let n = done.fetch_add(1, Ordering::Relaxed) + 1;
301                    if n % 25 == 0 || n == total {
302                        eprintln!("  {n}/{total} queries");
303                    }
304                }
305                collected.lock().unwrap().extend(local);
306            });
307        }
308    });
309
310    if let Some(e) = error.into_inner().unwrap() {
311        return Err(e);
312    }
313
314    let mut indexed = collected.into_inner().unwrap();
315    indexed.sort_by_key(|(i, _)| *i);
316    Ok(indexed.into_iter().map(|(_, r)| r).collect())
317}
318
319pub fn aggregate(
320    bare: &[QueryRecord],
321    prefixed: &[QueryRecord],
322    relevant_sets: &std::collections::HashMap<String, HashSet<Uuid>>,
323) -> AggregateReport {
324    const AGG_SEED: u64 = 0x4B1D_C0DE;
325    let empty: HashSet<Uuid> = HashSet::new();
326
327    let per_query = |rec: &QueryRecord| -> (f32, f32, f32, f32, f32) {
328        let rel = relevant_sets.get(&rec.query_id).unwrap_or(&empty);
329        (
330            metrics::recall_at_k(&rec.ranked_ids, rel, 1),
331            metrics::recall_at_k(&rec.ranked_ids, rel, 3),
332            metrics::recall_at_k(&rec.ranked_ids, rel, 5),
333            metrics::mrr(&rec.ranked_ids, rel),
334            metrics::precision_at_k(&rec.ranked_ids, rel, 5),
335        )
336    };
337
338    let arm = |records: &[QueryRecord]| -> ArmMetrics {
339        if records.is_empty() {
340            return ArmMetrics {
341                recall_at_1: 0.0,
342                recall_at_3: 0.0,
343                recall_at_5: 0.0,
344                mrr: 0.0,
345                precision_at_5: 0.0,
346            };
347        }
348        let n = records.len() as f32;
349        let mut acc = (0.0, 0.0, 0.0, 0.0, 0.0);
350        for r in records {
351            let (r1, r3, r5, m, p5) = per_query(r);
352            acc.0 += r1;
353            acc.1 += r3;
354            acc.2 += r5;
355            acc.3 += m;
356            acc.4 += p5;
357        }
358        ArmMetrics {
359            recall_at_1: acc.0 / n,
360            recall_at_3: acc.1 / n,
361            recall_at_5: acc.2 / n,
362            mrr: acc.3 / n,
363            precision_at_5: acc.4 / n,
364        }
365    };
366
367    let bare_idx: std::collections::HashMap<&str, &QueryRecord> =
368        bare.iter().map(|r| (r.query_id.as_str(), r)).collect();
369
370    let mut delta_r3 = Vec::new();
371    let mut delta_mrr = Vec::new();
372    for p_rec in prefixed {
373        if let Some(b_rec) = bare_idx.get(p_rec.query_id.as_str()) {
374            let (_, p_r3, _, p_mrr, _) = per_query(p_rec);
375            let (_, b_r3, _, b_mrr, _) = per_query(b_rec);
376            delta_r3.push(p_r3 - b_r3);
377            delta_mrr.push(p_mrr - b_mrr);
378        }
379    }
380
381    let gated_rate = |records: &[QueryRecord]| -> f32 {
382        let flags: Vec<bool> = records
383            .iter()
384            .map(|r| match (r.gold_raw_rank, r.gold_raw_similarity) {
385                (Some(_), Some(sim)) => sim < BASELINE_THRESHOLD,
386                _ => false,
387            })
388            .collect();
389        metrics::gated_out_rate(&flags)
390    };
391
392    AggregateReport {
393        bare: arm(bare),
394        prefixed: arm(prefixed),
395        delta_recall_at_3_ci: bootstrap::paired_bootstrap_ci(&delta_r3, 10000, 0.95, AGG_SEED),
396        delta_mrr_ci: bootstrap::paired_bootstrap_ci(&delta_mrr, 10000, 0.95, AGG_SEED),
397        gated_out_rate_bare: gated_rate(bare),
398        gated_out_rate_prefixed: gated_rate(prefixed),
399    }
400}
401
402pub fn gate_sweep(prefixed: &[QueryRecord]) -> GateSweepReport {
403    const GRID_SEED: u64 = 0x5EED_6A7E;
404
405    let recalled_at = |rec: &QueryRecord, k: usize, t: f32| -> f32 {
406        match (rec.gold_raw_rank, rec.gold_raw_similarity) {
407            (Some(rank), Some(sim)) if rank <= k && sim >= t => 1.0,
408            _ => 0.0,
409        }
410    };
411
412    let mean = |vals: &[f32]| -> f32 {
413        if vals.is_empty() {
414            0.0
415        } else {
416            vals.iter().sum::<f32>() / vals.len() as f32
417        }
418    };
419
420    let baseline_r3: Vec<f32> = prefixed
421        .iter()
422        .map(|r| recalled_at(r, 3, BASELINE_THRESHOLD))
423        .collect();
424    let baseline_r3_mean = mean(&baseline_r3);
425
426    let mut frontier = Vec::with_capacity(41);
427    let mut chosen_threshold = BASELINE_THRESHOLD;
428    let mut chosen_beats_baseline = false;
429    let mut best_recall_at_3 = baseline_r3_mean;
430
431    for step in 0..=40u32 {
432        let t = 0.40 + step as f32 * 0.01;
433
434        let r1: Vec<f32> = prefixed.iter().map(|r| recalled_at(r, 1, t)).collect();
435        let r3: Vec<f32> = prefixed.iter().map(|r| recalled_at(r, 3, t)).collect();
436        let r5: Vec<f32> = prefixed.iter().map(|r| recalled_at(r, 5, t)).collect();
437
438        let recall_at_3 = mean(&r3);
439        let recall_at_5 = mean(&r5);
440
441        frontier.push(GatePoint {
442            threshold: t,
443            recall_at_1: mean(&r1),
444            recall_at_3,
445            recall_at_5,
446            precision_proxy: recall_at_5 / 5.0,
447        });
448
449        let deltas: Vec<f32> = r3
450            .iter()
451            .zip(baseline_r3.iter())
452            .map(|(t_val, b_val)| t_val - b_val)
453            .collect();
454        let (lo, _hi) = bootstrap::paired_bootstrap_ci(&deltas, 2000, 0.95, GRID_SEED);
455
456        if lo > 0.0 && recall_at_3 > best_recall_at_3 + 1e-6 {
457            best_recall_at_3 = recall_at_3;
458            chosen_threshold = t;
459            chosen_beats_baseline = true;
460        }
461    }
462
463    GateSweepReport {
464        frontier,
465        baseline_threshold: BASELINE_THRESHOLD,
466        chosen_threshold,
467        chosen_beats_baseline,
468    }
469}
470
471pub fn extract_corpus(store: &KuzuStore, out: &Path) -> Result<usize> {
472    let memories = store.all_memories_with_embeddings()?;
473    let mut file = std::fs::File::create(out)?;
474    let mut count = 0;
475    for m in &memories {
476        let entry = CorpusEntry {
477            id: m.id,
478            content: m.content.clone(),
479            memory_type: format!("{:?}", m.memory_type).to_lowercase(),
480            created_at: m.created_at.to_rfc3339(),
481            project_path: m.project_path.clone(),
482        };
483        writeln!(file, "{}", serde_json::to_string(&entry)?)?;
484        count += 1;
485    }
486    Ok(count)
487}
488
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a prefixed-arm record with empty result lists and the given rank data.
    fn record(id_rank: Option<usize>, raw_rank: Option<usize>, raw_sim: Option<f32>) -> QueryRecord {
        QueryRecord {
            query_id: "q".to_string(),
            use_prefix: true,
            ranked_ids: Vec::new(),
            scores: Vec::new(),
            first_relevant_rank: id_rank,
            gold_raw_rank: raw_rank,
            gold_raw_similarity: raw_sim,
        }
    }

    /// Four records: gold at raw ranks 1, 2 and 4 with decreasing similarity, plus a miss.
    fn sample_records() -> Vec<QueryRecord> {
        vec![
            record(Some(1), Some(1), Some(0.85)),
            record(Some(2), Some(2), Some(0.62)),
            record(Some(4), Some(4), Some(0.55)),
            record(None, None, None),
        ]
    }

    /// Writes `contents` to a uniquely named temp file, loads it, then removes the file.
    fn load_from_temp(stem: &str, contents: &str) -> Vec<EvalQuery> {
        let path = std::env::temp_dir().join(format!("{stem}_{}.jsonl", Uuid::new_v4()));
        std::fs::write(&path, contents).unwrap();

        let queries = load_eval_set(&path).unwrap();
        std::fs::remove_file(&path).ok();
        queries
    }

    #[test]
    fn load_eval_set_parses_one_object_per_line() {
        let id_a = Uuid::new_v4();
        let id_b = Uuid::new_v4();
        let line1 = format!(
            r#"{{"query_id":"q1","query":"kuzu choice","query_variant":"literal","seed_memory_id":"{id_a}","memory_type":"decision","relevant_memory_ids":["{id_a}"]}}"#
        );
        let line2 = format!(
            r#"{{"query_id":"q2","query":"sync design","query_variant":"paraphrase","seed_memory_id":"{id_b}","memory_type":"architecture","relevant_memory_ids":["{id_b}","{id_a}"],"tags":["sync"]}}"#
        );

        let queries = load_from_temp("eval_set", &format!("{line1}\n{line2}\n"));

        assert_eq!(queries.len(), 2);
        assert_eq!(queries[0].query_id, "q1");
        assert_eq!(queries[0].seed_memory_id, id_a);
        assert_eq!(queries[1].relevant_memory_ids.len(), 2);
        assert_eq!(queries[1].tags, vec!["sync".to_string()]);
    }

    #[test]
    fn load_eval_set_tolerates_blank_lines() {
        let id = Uuid::new_v4();
        let line = format!(
            r#"{{"query_id":"q1","query":"x","query_variant":"v","seed_memory_id":"{id}","memory_type":"semantic","relevant_memory_ids":["{id}"]}}"#
        );

        let queries = load_from_temp("eval_blank", &format!("\n{line}\n\n"));

        assert_eq!(queries.len(), 1);
    }

    #[test]
    fn gate_sweep_emits_full_grid_and_monotone_recall() {
        let report = gate_sweep(&sample_records());

        // grid 0.40..=0.80 step 0.01 inclusive is 41 points.
        assert_eq!(report.frontier.len(), 41);
        assert!((report.frontier.first().unwrap().threshold - 0.40).abs() < 1e-4);
        assert!((report.frontier.last().unwrap().threshold - 0.80).abs() < 1e-4);

        for pair in report.frontier.windows(2) {
            assert!(
                pair[0].recall_at_3 >= pair[1].recall_at_3 - 1e-6,
                "recall must not increase as the gate tightens"
            );
        }
    }

    #[test]
    fn gate_sweep_recall_reflects_raw_rank_and_similarity() {
        let report = gate_sweep(&sample_records());

        let at = |t: f32| {
            report
                .frontier
                .iter()
                .find(|p| (p.threshold - t).abs() < 1e-4)
                .unwrap()
        };

        // T=0.50: golds at sim 0.85, 0.62, 0.55 survive; ranks 1,2,4. recall@3 covers
        // the first two (rank<=3 and sim>=0.50) so 2/4 = 0.5; recall@5 covers all three = 0.75.
        let p050 = at(0.50);
        assert!((p050.recall_at_1 - 0.25).abs() < 1e-6, "recall@1 was {}", p050.recall_at_1);
        assert!((p050.recall_at_3 - 0.5).abs() < 1e-6, "recall@3 was {}", p050.recall_at_3);
        assert!((p050.recall_at_5 - 0.75).abs() < 1e-6, "recall@5 was {}", p050.recall_at_5);

        // T=0.70: only the 0.85 gold survives, at rank 1.
        let p070 = at(0.70);
        assert!((p070.recall_at_1 - 0.25).abs() < 1e-6);
        assert!((p070.recall_at_3 - 0.25).abs() < 1e-6);
        assert!((p070.recall_at_5 - 0.25).abs() < 1e-6);
    }

    #[test]
    fn gate_sweep_keeps_baseline_when_nothing_beats_it() {
        let report = gate_sweep(&[record(Some(1), Some(1), Some(0.85))]);

        assert!((report.baseline_threshold - 0.59).abs() < 1e-6);
        assert!(!report.chosen_beats_baseline);
        assert!((report.chosen_threshold - 0.59).abs() < 1e-6);
    }
}