Skip to main content

vela_protocol/
search.rs

1//! Full-text search across findings in a frontier or VelaRepo.
2
3use std::path::Path;
4
5use colored::Colorize;
6
7use crate::cli_style as style;
8
9use crate::bundle::FindingBundle;
10use crate::project::Project;
11use crate::repo;
12
/// A single search result with relevance score.
///
/// Every field is an owned copy taken from the matching finding, so a result
/// stays valid after the frontier it came from has been dropped.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Identifier of the matched finding.
    pub id: String,
    /// Relevance score for the query; only findings scoring above zero are returned.
    pub score: f32,
    /// Full assertion text of the finding.
    pub assertion: String,
    /// Assertion type of the finding, as stored on the finding.
    pub assertion_type: String,
    /// Confidence score recorded on the finding.
    pub confidence: f64,
    /// Names of the entities attached to the finding's assertion.
    pub entities: Vec<String>,
    /// DOI from the finding's provenance, when one is recorded.
    pub doi: Option<String>,
}
23
24/// Search findings by query text, with optional entity and assertion type filters.
25///
26/// Scoring: query word matches in assertion text (x2), entity names (x3),
27/// conditions text (x1). Normalized by total query words.
28pub fn search(
29    source_path: &Path,
30    query: &str,
31    entity_filter: Option<&str>,
32    type_filter: Option<&str>,
33    limit: usize,
34) -> Vec<SearchResult> {
35    let frontier = match repo::load_from_path(source_path) {
36        Ok(c) => c,
37        Err(e) => {
38            eprintln!("{} failed to load frontier: {e}", style::err_prefix());
39            return Vec::new();
40        }
41    };
42
43    let query_words: Vec<String> = query
44        .to_lowercase()
45        .split_whitespace()
46        .map(|w| w.to_string())
47        .collect();
48
49    if query_words.is_empty() {
50        return Vec::new();
51    }
52
53    let mut results: Vec<SearchResult> = frontier
54        .findings
55        .iter()
56        .filter(|f| {
57            if let Some(ef) = entity_filter {
58                let ef_lower = ef.to_lowercase();
59                if !f
60                    .assertion
61                    .entities
62                    .iter()
63                    .any(|e| e.name.to_lowercase().contains(&ef_lower))
64                {
65                    return false;
66                }
67            }
68            if let Some(tf) = type_filter
69                && f.assertion.assertion_type.to_lowercase() != tf.to_lowercase()
70            {
71                return false;
72            }
73            true
74        })
75        .filter_map(|f| {
76            let score = score_finding(f, &query_words);
77            if score > 0.0 {
78                Some(SearchResult {
79                    id: f.id.clone(),
80                    score,
81                    assertion: f.assertion.text.clone(),
82                    assertion_type: f.assertion.assertion_type.clone(),
83                    confidence: f.confidence.score,
84                    entities: f
85                        .assertion
86                        .entities
87                        .iter()
88                        .map(|e| e.name.clone())
89                        .collect(),
90                    doi: f.provenance.doi.clone(),
91                })
92            } else {
93                None
94            }
95        })
96        .collect();
97
98    results.sort_by(|a, b| {
99        b.score
100            .partial_cmp(&a.score)
101            .unwrap_or(std::cmp::Ordering::Equal)
102    });
103    results.truncate(limit);
104    results
105}
106
107/// Score a finding against query words.
108fn score_finding(finding: &FindingBundle, query_words: &[String]) -> f32 {
109    let assertion_lower = finding.assertion.text.to_lowercase();
110    let conditions_lower = finding.conditions.text.to_lowercase();
111
112    let mut total_score: f32 = 0.0;
113
114    for word in query_words {
115        // Assertion text matches (weight x2)
116        if assertion_lower.contains(word.as_str()) {
117            total_score += 2.0;
118        }
119
120        // Entity name matches (weight x3)
121        for entity in &finding.assertion.entities {
122            if entity.name.to_lowercase().contains(word.as_str()) {
123                total_score += 3.0;
124            }
125        }
126
127        // Conditions text matches (weight x1)
128        if conditions_lower.contains(word.as_str()) {
129            total_score += 1.0;
130        }
131    }
132
133    // Normalize by number of query words
134    total_score / query_words.len() as f32
135}
136
/// Shorten `text` to at most `max_chars` characters for single-line CLI output.
///
/// Length is measured in `char`s (Unicode scalar values), so the cut never
/// lands inside a multi-byte UTF-8 sequence. Text that already fits is returned
/// unchanged; otherwise the tail is replaced by `...`. When `max_chars` is
/// smaller than the ellipsis itself, the ellipsis is shortened so the result
/// still never exceeds `max_chars`.
fn truncate_for_cli(text: &str, max_chars: usize) -> String {
    const ELLIPSIS: &str = "...";

    // No character at index `max_chars` means the whole text fits.
    if text.chars().nth(max_chars).is_none() {
        return text.to_string();
    }

    let ellipsis_len = ELLIPSIS.len().min(max_chars);
    let keep = max_chars - ellipsis_len;
    // Byte offset of the first character that gets dropped.
    let cut = text
        .char_indices()
        .nth(keep)
        .map_or(text.len(), |(offset, _)| offset);

    format!("{}{}", &text[..cut], &ELLIPSIS[..ellipsis_len])
}
145
146/// CLI entry point for `vela search`.
147pub fn run(
148    source: &Path,
149    query: &str,
150    entity: Option<&str>,
151    type_filter: Option<&str>,
152    limit: usize,
153) {
154    let results = search(source, query, entity, type_filter, limit);
155
156    if results.is_empty() {
157        println!("no findings matched the query.");
158        return;
159    }
160
161    println!();
162    println!(
163        "  {} results for {}",
164        results.len(),
165        format!("\"{}\"", query).bold()
166    );
167    println!("  {}", style::tick_row(60));
168
169    for (i, r) in results.iter().enumerate() {
170        let truncated = truncate_for_cli(&r.assertion, 120);
171
172        println!(
173            "  {}. {} [score: {:.2}] [conf: {:.2}] [{}]",
174            (i + 1).to_string().dimmed(),
175            style::signal(&r.id),
176            r.score,
177            r.confidence,
178            style::dust_color(&r.assertion_type),
179        );
180        println!("     {}", truncated);
181        if !r.entities.is_empty() {
182            println!("     entities: {}", r.entities.join(", ").dimmed());
183        }
184        if let Some(doi) = &r.doi {
185            println!("     doi: {}", doi.dimmed());
186        }
187        println!();
188    }
189}
190
191// ── Cross-frontier search ───────────────────────────────────────────
192
193/// A search result grouped by source frontier.
194#[allow(dead_code)]
195pub struct CrossFrontierResult {
196    pub frontier_name: String,
197    pub frontier_file: String,
198    pub results: Vec<SearchResult>,
199}
200
201/// Search a pre-loaded frontier (avoids re-loading from disk).
202pub fn search_frontier(
203    frontier: &Project,
204    query: &str,
205    entity_filter: Option<&str>,
206    type_filter: Option<&str>,
207    limit: usize,
208) -> Vec<SearchResult> {
209    let query_words: Vec<String> = query
210        .to_lowercase()
211        .split_whitespace()
212        .map(|w| w.to_string())
213        .collect();
214
215    if query_words.is_empty() {
216        return Vec::new();
217    }
218
219    let mut results: Vec<SearchResult> = frontier
220        .findings
221        .iter()
222        .filter(|f| {
223            if let Some(ef) = entity_filter {
224                let ef_lower = ef.to_lowercase();
225                if !f
226                    .assertion
227                    .entities
228                    .iter()
229                    .any(|e| e.name.to_lowercase().contains(&ef_lower))
230                {
231                    return false;
232                }
233            }
234            if let Some(tf) = type_filter
235                && f.assertion.assertion_type.to_lowercase() != tf.to_lowercase()
236            {
237                return false;
238            }
239            true
240        })
241        .filter_map(|f| {
242            let score = score_finding(f, &query_words);
243            if score > 0.0 {
244                Some(SearchResult {
245                    id: f.id.clone(),
246                    score,
247                    assertion: f.assertion.text.clone(),
248                    assertion_type: f.assertion.assertion_type.clone(),
249                    confidence: f.confidence.score,
250                    entities: f
251                        .assertion
252                        .entities
253                        .iter()
254                        .map(|e| e.name.clone())
255                        .collect(),
256                    doi: f.provenance.doi.clone(),
257                })
258            } else {
259                None
260            }
261        })
262        .collect();
263
264    results.sort_by(|a, b| {
265        b.score
266            .partial_cmp(&a.score)
267            .unwrap_or(std::cmp::Ordering::Equal)
268    });
269    results.truncate(limit);
270    results
271}
272
273/// Search across all `.json` frontier files in a directory.
274///
275/// Loads each frontier, runs scored search, collects top `limit` results
276/// sorted by score across all frontiers, then groups by frontier.
277pub fn search_all(
278    dir: &Path,
279    query: &str,
280    entity_filter: Option<&str>,
281    type_filter: Option<&str>,
282    limit: usize,
283) -> Vec<CrossFrontierResult> {
284    // Collect all .json files in the directory
285    let entries: Vec<std::path::PathBuf> = match std::fs::read_dir(dir) {
286        Ok(rd) => rd
287            .filter_map(|e| e.ok())
288            .map(|e| e.path())
289            .filter(|p| p.extension().is_some_and(|ext| ext == "json"))
290            .collect(),
291        Err(e) => {
292            eprintln!(
293                "{} failed to read directory '{}': {e}",
294                style::err_prefix(),
295                dir.display()
296            );
297            return Vec::new();
298        }
299    };
300
301    if entries.is_empty() {
302        eprintln!("no .json frontier files found in {}", dir.display());
303        return Vec::new();
304    }
305
306    // Score every finding across all frontiers, keeping frontier provenance
307    let mut scored: Vec<(String, String, SearchResult)> = Vec::new(); // (name, file, result)
308
309    for path in &entries {
310        let frontier = match repo::load_from_path(path) {
311            Ok(c) => c,
312            Err(e) => {
313                eprintln!("skipping {}: {e}", path.display());
314                continue;
315            }
316        };
317
318        let name = frontier.project.name.clone();
319        let file = path
320            .file_name()
321            .unwrap_or_default()
322            .to_string_lossy()
323            .to_string();
324
325        // Get all results from this frontier (no per-frontier limit yet)
326        let results = search_frontier(&frontier, query, entity_filter, type_filter, usize::MAX);
327        for r in results {
328            scored.push((name.clone(), file.clone(), r));
329        }
330    }
331
332    // Sort all results by score descending, take top `limit`
333    scored.sort_by(|a, b| {
334        b.2.score
335            .partial_cmp(&a.2.score)
336            .unwrap_or(std::cmp::Ordering::Equal)
337    });
338    scored.truncate(limit);
339
340    // Group by frontier, preserving sort order within each group
341    let mut groups: Vec<CrossFrontierResult> = Vec::new();
342    let mut seen_frontiers: std::collections::HashMap<String, usize> =
343        std::collections::HashMap::new();
344
345    for (name, file, result) in scored {
346        if let Some(&idx) = seen_frontiers.get(&file) {
347            groups[idx].results.push(result);
348        } else {
349            let idx = groups.len();
350            seen_frontiers.insert(file.clone(), idx);
351            groups.push(CrossFrontierResult {
352                frontier_name: name,
353                frontier_file: file,
354                results: vec![result],
355            });
356        }
357    }
358
359    groups
360}
361
362/// CLI entry point for `vela search --all <dir>`.
363pub fn run_all(
364    dir: &Path,
365    query: &str,
366    entity: Option<&str>,
367    type_filter: Option<&str>,
368    limit: usize,
369) {
370    let groups = search_all(dir, query, entity, type_filter, limit);
371
372    if groups.is_empty() {
373        println!("no findings matched the query across any frontier.");
374        return;
375    }
376
377    let total_results: usize = groups.iter().map(|g| g.results.len()).sum();
378    let frontier_count = groups.len();
379
380    println!();
381    println!(
382        "  {} results across {} frontiers for {}",
383        total_results,
384        frontier_count,
385        format!("\"{}\"", query).bold(),
386    );
387    println!("  {}", style::tick_row(60));
388
389    for group in &groups {
390        let stem = group
391            .frontier_file
392            .strip_suffix(".json")
393            .unwrap_or(&group.frontier_file);
394        println!(
395            "  [{}] {} results",
396            style::signal(stem),
397            group.results.len()
398        );
399        for (i, r) in group.results.iter().enumerate() {
400            let truncated = truncate_for_cli(&r.assertion, 100);
401            println!(
402                "    {}. {} [score: {:.1}] [conf: {:.2}] {}",
403                (i + 1).to_string().dimmed(),
404                style::signal(&r.id),
405                r.score,
406                r.confidence,
407                truncated,
408            );
409        }
410        println!();
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bundle::*;
    use crate::project;
    use tempfile::TempDir;

    /// Build a minimal finding for tests.
    ///
    /// Only the fields that search reads (id, assertion text/type, entity names,
    /// conditions text, confidence, DOI) are parameterized; everything else is
    /// filled with fixed placeholder values. `entities` is a list of
    /// `(name, entity_type)` pairs.
    fn make_finding(
        id: &str,
        assertion: &str,
        assertion_type: &str,
        entities: Vec<(&str, &str)>,
        conditions: &str,
        confidence: f64,
        doi: Option<&str>,
    ) -> FindingBundle {
        FindingBundle {
            id: id.into(),
            version: 1,
            previous_version: None,
            assertion: Assertion {
                text: assertion.into(),
                assertion_type: assertion_type.into(),
                entities: entities
                    .iter()
                    .map(|(name, etype)| Entity {
                        name: name.to_string(),
                        entity_type: etype.to_string(),
                        identifiers: serde_json::Map::new(),
                        canonical_id: None,
                        candidates: vec![],
                        aliases: vec![],
                        resolution_provenance: None,
                        resolution_confidence: 1.0,
                        resolution_method: None,
                        species_context: None,
                        needs_review: false,
                    })
                    .collect(),
                relation: None,
                direction: None,
                causal_claim: None,
                causal_evidence_grade: None,
            },
            evidence: Evidence {
                evidence_type: "experimental".into(),
                model_system: String::new(),
                species: None,
                method: String::new(),
                sample_size: None,
                effect_size: None,
                p_value: None,
                replicated: false,
                replication_count: None,
                evidence_spans: vec![],
            },
            conditions: Conditions {
                text: conditions.into(),
                species_verified: vec![],
                species_unverified: vec![],
                in_vitro: false,
                in_vivo: false,
                human_data: false,
                clinical_trial: false,
                concentration_range: None,
                duration: None,
                age_group: None,
                cell_type: None,
            },
            confidence: Confidence::raw(confidence, "test", 0.85),
            provenance: Provenance {
                source_type: "published_paper".into(),
                doi: doi.map(|s| s.to_string()),
                pmid: None,
                pmc: None,
                openalex_id: None,
                url: None,
                title: "Test Paper".into(),
                authors: vec![],
                year: Some(2024),
                journal: None,
                license: None,
                publisher: None,
                funders: vec![],
                extraction: Extraction::default(),
                review: None,
                citation_count: None,
            },
            flags: Flags {
                gap: false,
                negative_space: false,
                contested: false,
                retracted: false,
                declining: false,
                gravity_well: false,
                review_state: None,
                superseded: false,
                signature_threshold: None,
                jointly_accepted: false,
            },
            links: vec![],
            annotations: vec![],
            attachments: vec![],
            created: String::new(),
            updated: None,

            access_tier: crate::access_tier::AccessTier::Public,
        }
    }

    /// Write a three-finding frontier to `dir/test.json` and return its path.
    fn write_test_frontier(dir: &Path) -> std::path::PathBuf {
        let findings = vec![
            make_finding(
                "vf_0000000000000001",
                "NLRP3 activates caspase-1 in microglia",
                "mechanism",
                vec![("NLRP3", "protein"), ("caspase-1", "protein")],
                "in vitro mouse",
                0.9,
                Some("10.1234/a"),
            ),
            make_finding(
                "vf_0000000000000002",
                "Tau phosphorylation increases in Alzheimer disease",
                "biomarker",
                vec![("Tau", "protein"), ("Alzheimer disease", "disease")],
                "human brain tissue",
                0.85,
                Some("10.1234/b"),
            ),
            make_finding(
                "vf_0000000000000003",
                "Donepezil improves cognition in mild AD patients",
                "therapeutic",
                vec![("Donepezil", "compound"), ("Alzheimer disease", "disease")],
                "clinical trial phase 3",
                0.95,
                None,
            ),
        ];
        let c = project::assemble("test-frontier", findings, 3, 0, "Test");
        let path = dir.join("test.json");
        let json = serde_json::to_string_pretty(&c).unwrap();
        std::fs::write(&path, json).unwrap();
        path
    }

    #[test]
    fn search_by_query_returns_scored_results() {
        let tmp = TempDir::new().unwrap();
        let path = write_test_frontier(tmp.path());
        let results = search(&path, "NLRP3 caspase", None, None, 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "vf_0000000000000001");
        assert!(results[0].score > 0.0);
    }

    #[test]
    fn search_with_entity_filter() {
        let tmp = TempDir::new().unwrap();
        let path = write_test_frontier(tmp.path());
        let results = search(&path, "disease", Some("Tau"), None, 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "vf_0000000000000002");
    }

    #[test]
    fn search_with_type_filter() {
        let tmp = TempDir::new().unwrap();
        let path = write_test_frontier(tmp.path());
        let results = search(&path, "Alzheimer", None, Some("therapeutic"), 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "vf_0000000000000003");
    }

    #[test]
    fn search_no_match_returns_empty() {
        let tmp = TempDir::new().unwrap();
        let path = write_test_frontier(tmp.path());
        let results = search(&path, "xyzzyfoobar", None, None, 10);
        assert!(results.is_empty());
    }

    #[test]
    fn search_empty_query_returns_empty() {
        let tmp = TempDir::new().unwrap();
        let path = write_test_frontier(tmp.path());
        // A query with no words cannot match anything.
        assert!(search(&path, "", None, None, 10).is_empty());
        assert!(search(&path, "   ", None, None, 10).is_empty());
    }

    #[test]
    fn search_respects_limit() {
        let tmp = TempDir::new().unwrap();
        let path = write_test_frontier(tmp.path());
        let results = search(&path, "disease", None, None, 1);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn entity_match_scores_higher_than_assertion() {
        let tmp = TempDir::new().unwrap();
        let path = write_test_frontier(tmp.path());
        // "NLRP3" appears in both entity name and assertion text for finding 1
        // but only in assertion text for others (if any)
        let results = search(&path, "NLRP3", None, None, 10);
        assert!(!results.is_empty());
        // First result should have entity match (score = (2+3)/1 = 5.0)
        assert!(results[0].score >= 5.0);
    }

    #[test]
    fn truncate_for_cli_preserves_utf8_boundaries() {
        // The en dash is a multi-byte character, so cutting by byte offset
        // could split it; the cut must be made by character count instead.
        let text =
            "Lecanemab reduced amyloid burden with 200–1000-fold selectivity in early disease.";
        let truncated = truncate_for_cli(text, 48);
        assert!(truncated.ends_with("..."));
        assert!(truncated.len() <= text.len());
        // 45 kept characters plus the 3-character ellipsis.
        assert_eq!(truncated.chars().count(), 48);
        assert!(truncated.contains("200–"));
    }

    // ── Cross-frontier search tests ─────────────────────────────────

    /// Write a two-finding frontier to `dir/iron-biology.json` and return its path.
    fn write_second_frontier(dir: &Path) -> std::path::PathBuf {
        let findings = vec![
            make_finding(
                "vf_1000000000000001",
                "Iron accumulation in senescent cells",
                "mechanism",
                vec![("iron", "compound"), ("senescent cells", "cell_type")],
                "in vitro human fibroblasts",
                0.88,
                Some("10.5678/a"),
            ),
            make_finding(
                "vf_1000000000000002",
                "Ferrostatin-1 prevents iron-mediated neuronal death after TBI",
                "therapeutic",
                vec![("Ferrostatin-1", "compound"), ("iron", "compound")],
                "mouse model TBI",
                0.94,
                Some("10.5678/b"),
            ),
        ];
        let c = project::assemble("iron-biology", findings, 2, 0, "Iron biology frontier");
        let path = dir.join("iron-biology.json");
        let json = serde_json::to_string_pretty(&c).unwrap();
        std::fs::write(&path, json).unwrap();
        path
    }

    #[test]
    fn search_all_finds_across_frontiers() {
        let tmp = TempDir::new().unwrap();
        write_test_frontier(tmp.path());
        write_second_frontier(tmp.path());

        let groups = search_all(tmp.path(), "iron", None, None, 20);
        assert!(!groups.is_empty());
        let total: usize = groups.iter().map(|g| g.results.len()).sum();
        assert!(
            total >= 2,
            "Expected at least 2 results for 'iron', got {total}"
        );
        // Iron findings should come from the iron-biology frontier
        for group in &groups {
            assert_eq!(group.frontier_file, "iron-biology.json");
        }
    }

    #[test]
    fn search_all_respects_limit() {
        let tmp = TempDir::new().unwrap();
        write_test_frontier(tmp.path());
        write_second_frontier(tmp.path());

        // More than one finding matches "iron", but the limit applies to the
        // total across all frontiers, so exactly one result may come back.
        let groups = search_all(tmp.path(), "iron", None, None, 1);
        let total: usize = groups.iter().map(|g| g.results.len()).sum();
        assert_eq!(total, 1);
    }

    #[test]
    fn search_all_empty_dir_returns_empty() {
        let tmp = TempDir::new().unwrap();
        let empty_dir = tmp.path().join("empty");
        std::fs::create_dir_all(&empty_dir).unwrap();
        let groups = search_all(&empty_dir, "anything", None, None, 20);
        assert!(groups.is_empty());
    }
}
694}