//! search_semantically/engine.rs — hybrid lexical + semantic search engine.
1use std::collections::HashMap;
2use std::path::PathBuf;
3
4use anyhow::{Context, Result};
5
6use super::chunker;
7use super::db::{SearchDb, StoredChunk};
8use super::embedder::{DownloadCallback, Embedder};
9use super::format::{SearchResult, format_results};
10use super::metrics;
11use super::query_classifier::classify_query;
12use super::ranker::{MetricScores, poem_rank};
13use super::scanner;
14use super::vector_store;
15
/// Embedding model identifier; keys the stored vectors in the database.
const MODEL_NAME: &str = "Xenova/all-MiniLM-L6-v2";
/// Maximum candidates each individual metric may contribute before ranking.
const METRIC_CANDIDATE_LIMIT: usize = 1000;
18
/// Hybrid search engine over a project directory: combines BM25, vector
/// cosine similarity, path/symbol matching, import-graph proximity and
/// git recency into one ranked result list.
#[derive(Clone)]
pub struct SearchEngine {
    // Root of the project being indexed and searched.
    project_root: PathBuf,
    // On-disk cache directory for the embedding model files.
    embedder_cache_dir: PathBuf,
    // Optional progress hook invoked while the model downloads.
    download_callback: Option<DownloadCallback>,
}
25
26impl SearchEngine {
27    pub fn new(project_root: PathBuf) -> Self {
28        let embedder_cache_dir = dirs::cache_dir()
29            .unwrap_or_else(|| PathBuf::from("/tmp"))
30            .join("search-semantically")
31            .join("models");
32        Self {
33            project_root,
34            embedder_cache_dir,
35            download_callback: None,
36        }
37    }
38
39    pub fn with_download_callback(mut self, callback: DownloadCallback) -> Self {
40        self.download_callback = Some(callback);
41        self
42    }
43
44    pub fn set_download_callback(&mut self, callback: DownloadCallback) {
45        self.download_callback = Some(callback);
46    }
47
48    fn get_embedder(&self) -> Embedder {
49        let mut embedder = Embedder::new(self.embedder_cache_dir.clone());
50        if let Some(ref cb) = self.download_callback {
51            embedder.set_download_callback(cb.clone());
52        }
53        embedder
54    }
55
56    fn ensure_embedder(embedder: &mut Embedder) -> bool {
57        if embedder.initialize().is_err() {
58            return false;
59        }
60        true
61    }
62
63    pub fn search(
64        &self,
65        query: &str,
66        limit: usize,
67        restrict_to_dir: Option<&str>,
68    ) -> Result<String> {
69        let index_dir = self.project_root.join(".search-index");
70        std::fs::create_dir_all(&index_dir)
71            .with_context(|| format!("Creating index directory: {}", index_dir.display()))?;
72
73        let db_path = index_dir.join("search.db");
74        let mut db = SearchDb::open(&db_path)?;
75
76        let mut embedder = self.get_embedder();
77
78        self.build_index(&mut db, &mut embedder)?;
79
80        let mut all_chunks = db.get_all_chunks()?;
81        if let Some(dir) = restrict_to_dir {
82            all_chunks.retain(|c| c.file_path.starts_with(dir));
83        }
84
85        if all_chunks.is_empty() {
86            return Ok(format_results(&[]));
87        }
88
89        let query_type = classify_query(query);
90
91        let bm25_scores = metrics::compute_bm25_scores(&mut db, query, METRIC_CANDIDATE_LIMIT);
92
93        let cosine_scores =
94            self.compute_vector_scores(&mut db, &mut embedder, query, METRIC_CANDIDATE_LIMIT)?;
95
96        let path_scores = metrics::compute_path_match_scores(query, &all_chunks);
97
98        let symbols = db.get_all_symbols()?;
99        let symbol_scores = metrics::compute_symbol_match_scores(query, &symbols);
100
101        let file_seed_scores =
102            self.aggregate_file_scores(&all_chunks, &bm25_scores, &cosine_scores);
103        let seed_threshold = self.compute_seed_threshold(&file_seed_scores);
104        let filtered_seeds: HashMap<i64, f64> = file_seed_scores
105            .into_iter()
106            .filter(|(_, score)| *score >= seed_threshold)
107            .collect();
108        let file_id_to_chunk_ids = self.build_file_chunk_map(&all_chunks);
109        let import_scores =
110            metrics::compute_import_graph_scores(&mut db, &filtered_seeds, &file_id_to_chunk_ids);
111
112        let recency_scores = metrics::compute_git_recency_scores(&self.project_root, &all_chunks);
113
114        let candidate_ids = self.collect_candidate_ids(
115            &bm25_scores,
116            &cosine_scores,
117            &path_scores,
118            &symbol_scores,
119            &import_scores,
120            &recency_scores,
121        );
122
123        let mut candidates: HashMap<i64, MetricScores> = HashMap::new();
124        for &id in &candidate_ids {
125            candidates.insert(
126                id,
127                MetricScores {
128                    bm25: bm25_scores.get(&id).copied().unwrap_or(0.0),
129                    cosine: cosine_scores.get(&id).copied().unwrap_or(0.0),
130                    path_match: path_scores.get(&id).copied().unwrap_or(0.0),
131                    symbol_match: symbol_scores.get(&id).copied().unwrap_or(0.0),
132                    import_graph: import_scores.get(&id).copied().unwrap_or(0.0),
133                    git_recency: recency_scores.get(&id).copied().unwrap_or(0.0),
134                },
135            );
136        }
137
138        if candidates.is_empty() {
139            return Ok(format_results(&[]));
140        }
141
142        let ranked = poem_rank(&candidates, &query_type, METRIC_CANDIDATE_LIMIT);
143
144        let chunk_map: HashMap<i64, &StoredChunk> = all_chunks.iter().map(|c| (c.id, c)).collect();
145
146        let results: Vec<SearchResult> = ranked
147            .into_iter()
148            .take(limit)
149            .filter_map(|candidate| {
150                let chunk = chunk_map.get(&candidate.id)?;
151                Some(SearchResult {
152                    chunk: (*chunk).clone(),
153                    scores: candidate.scores,
154                    rank: candidate.rank,
155                })
156            })
157            .collect();
158
159        Ok(format_results(&results))
160    }
161
162    fn build_index(&self, db: &mut SearchDb, embedder: &mut Embedder) -> Result<()> {
163        let scanned_files = scanner::scan_project(&self.project_root);
164        let existing_files = db.get_all_files()?;
165
166        let existing_by_path: HashMap<String, _> = existing_files
167            .iter()
168            .map(|f| (f.file_path.clone(), f))
169            .collect();
170
171        let scanned_by_path: HashMap<String, _> = scanned_files
172            .iter()
173            .map(|f| (f.file_path.clone(), f))
174            .collect();
175
176        let mut to_add = Vec::new();
177        let mut to_update = Vec::new();
178        let mut to_remove = Vec::new();
179
180        for scanned in &scanned_files {
181            if let Some(existing) = existing_by_path.get(&scanned.file_path) {
182                if (existing.mtime - scanned.mtime).abs() > 0.001 {
183                    to_update.push(scanned);
184                }
185            } else {
186                to_add.push(scanned);
187            }
188        }
189
190        for existing in &existing_files {
191            if !scanned_by_path.contains_key(&existing.file_path) {
192                to_remove.push(existing);
193            }
194        }
195
196        if to_add.is_empty() && to_update.is_empty() && to_remove.is_empty() {
197            return Ok(());
198        }
199
200        for file in &to_remove {
201            db.delete_file(file.id)?;
202        }
203
204        let files_to_process: Vec<_> = to_add
205            .into_iter()
206            .chain(to_update.iter().copied())
207            .collect();
208        let mut all_new_chunk_ids: Vec<i64> = Vec::new();
209
210        for scanned in &files_to_process {
211            let abs_path = self.project_root.join(&scanned.file_path);
212            let content = match std::fs::read_to_string(&abs_path) {
213                Ok(c) => c,
214                Err(_) => continue,
215            };
216
217            let chunks = chunker::chunk_file(&content, &scanned.file_path, &scanned.file_type);
218
219            let file_id = db.upsert_file(
220                &scanned.file_path,
221                scanned.mtime,
222                &scanned.file_type.to_string(),
223            )?;
224
225            if let Some(existing) = existing_by_path.get(&scanned.file_path) {
226                db.delete_chunks_for_file(existing.id)?;
227                let _ = db.delete_imports_for_file(existing.id);
228            }
229
230            for text_chunk in &chunks {
231                let chunk_id = db.insert_chunk(
232                    file_id,
233                    &text_chunk.file_path,
234                    text_chunk.start_line as i64,
235                    text_chunk.end_line as i64,
236                    &text_chunk.kind.to_string(),
237                    text_chunk.name.as_deref(),
238                    &text_chunk.content,
239                    &scanned.file_type.to_string(),
240                )?;
241                all_new_chunk_ids.push(chunk_id);
242
243                if let Some(name) = &text_chunk.name {
244                    db.insert_symbol(chunk_id, name, &text_chunk.kind.to_string())?;
245                }
246            }
247
248            let imports = extract_imports(&content, &scanned.file_type);
249            for target_path in imports {
250                let _ = db.insert_import(file_id, &target_path);
251            }
252        }
253
254        if !all_new_chunk_ids.is_empty() {
255            let _ = self.embed_chunks(db, embedder, &all_new_chunk_ids);
256        }
257
258        Ok(())
259    }
260
261    fn embed_chunks(
262        &self,
263        db: &mut SearchDb,
264        embedder: &mut Embedder,
265        chunk_ids: &[i64],
266    ) -> Result<()> {
267        if chunk_ids.is_empty() {
268            return Ok(());
269        }
270
271        let chunks = db.get_chunks_by_ids(chunk_ids)?;
272        if chunks.is_empty() {
273            return Ok(());
274        }
275
276        let texts: Vec<&str> = chunks.iter().map(|c| c.content.as_str()).collect();
277
278        if !Self::ensure_embedder(embedder) {
279            return Ok(());
280        }
281
282        let vectors = embedder.embed(&texts)?;
283
284        let items: Vec<(i64, String, Vec<u8>)> = chunks
285            .iter()
286            .zip(vectors.iter())
287            .map(|(chunk, vector)| {
288                let blob = vector_store::pack_vector(vector);
289                (chunk.id, MODEL_NAME.to_string(), blob)
290            })
291            .collect();
292
293        db.batch_upsert_embeddings(&items)?;
294        Ok(())
295    }
296
297    fn compute_vector_scores(
298        &self,
299        db: &mut SearchDb,
300        embedder: &mut Embedder,
301        query: &str,
302        limit: usize,
303    ) -> Result<HashMap<i64, f64>> {
304        if !Self::ensure_embedder(embedder) {
305            return Ok(HashMap::new());
306        }
307
308        let query_vectors = embedder.embed(&[query])?;
309        let query_vector = &query_vectors[0];
310
311        let stored = db.get_all_embeddings(MODEL_NAME)?;
312        if stored.is_empty() {
313            return Ok(HashMap::new());
314        }
315
316        let vectors: Vec<(i64, Vec<f32>)> = stored
317            .into_iter()
318            .map(|(id, blob)| (id, vector_store::unpack_vector(&blob)))
319            .collect();
320
321        let top_k = vector_store::top_k_similar(query_vector, &vectors, limit);
322
323        let scores: HashMap<i64, f64> = top_k
324            .into_iter()
325            .map(|(id, score)| (id, score.max(0.0) as f64))
326            .collect();
327
328        Ok(scores)
329    }
330
331    fn collect_candidate_ids(
332        &self,
333        bm25: &HashMap<i64, f64>,
334        cosine: &HashMap<i64, f64>,
335        path: &HashMap<i64, f64>,
336        symbol: &HashMap<i64, f64>,
337        import: &HashMap<i64, f64>,
338        recency: &HashMap<i64, f64>,
339    ) -> Vec<i64> {
340        let mut ids = std::collections::HashSet::new();
341        for map in &[bm25, cosine, path, symbol, import, recency] {
342            for &id in map.keys() {
343                ids.insert(id);
344            }
345        }
346        ids.into_iter().collect()
347    }
348
349    fn aggregate_file_scores(
350        &self,
351        chunks: &[StoredChunk],
352        bm25: &HashMap<i64, f64>,
353        cosine: &HashMap<i64, f64>,
354    ) -> HashMap<i64, f64> {
355        let mut file_scores = HashMap::new();
356        for chunk in chunks {
357            let max_score = bm25
358                .get(&chunk.id)
359                .copied()
360                .unwrap_or(0.0)
361                .max(cosine.get(&chunk.id).copied().unwrap_or(0.0));
362            if max_score > 0.0 {
363                let entry = file_scores.entry(chunk.file_id).or_insert(0.0);
364                if max_score > *entry {
365                    *entry = max_score;
366                }
367            }
368        }
369        file_scores
370    }
371
372    fn compute_seed_threshold(&self, file_scores: &HashMap<i64, f64>) -> f64 {
373        if file_scores.is_empty() {
374            return 0.0;
375        }
376        let mut sorted: Vec<f64> = file_scores.values().copied().collect();
377        sorted.sort_by(|a, b| b.partial_cmp(a).expect("floats"));
378        let median = sorted[sorted.len() / 2];
379        median.max(0.1)
380    }
381
382    fn build_file_chunk_map(&self, chunks: &[StoredChunk]) -> HashMap<i64, Vec<i64>> {
383        let mut map = HashMap::new();
384        for chunk in chunks {
385            map.entry(chunk.file_id)
386                .or_insert_with(Vec::new)
387                .push(chunk.id);
388        }
389        map
390    }
391}
392
393fn extract_imports(content: &str, file_type: &super::scanner::FileType) -> Vec<String> {
394    match file_type {
395        super::scanner::FileType::Rust => extract_rust_imports(content),
396        _ => Vec::new(),
397    }
398}
399
400fn extract_rust_imports(content: &str) -> Vec<String> {
401    let mut imports = Vec::new();
402    let lines: Vec<&str> = content.lines().collect();
403    let mut i = 0;
404
405    while i < lines.len() {
406        let trimmed = lines[i].trim();
407        if !trimmed.starts_with("use ") {
408            i += 1;
409            continue;
410        }
411
412        let mut full_line = trimmed.to_string();
413
414        if trimmed.contains("::{") && !trimmed.contains("};") {
415            let mut found_close = false;
416            for next_line in &lines[i + 1..] {
417                full_line.push(' ');
418                full_line.push_str(next_line.trim());
419                if full_line.contains("};") {
420                    found_close = true;
421                    break;
422                }
423            }
424            if !found_close {
425                i += 1;
426                continue;
427            }
428        }
429
430        let path = full_line
431            .strip_prefix("use ")
432            .unwrap_or("")
433            .trim()
434            .trim_end_matches(';')
435            .trim();
436
437        if let Some(brace_start) = path.find("::{") {
438            let base = &path[..brace_start];
439            let after_brace = &path[brace_start + 3..];
440            let close_pos = match after_brace.rfind('}') {
441                Some(p) => p,
442                None => {
443                    i += 1;
444                    continue;
445                }
446            };
447            let inner = after_brace[..close_pos].trim();
448            let base_normalized = base.replace("::", "/");
449            for item in inner.split(',') {
450                let item = item.trim();
451                if item.is_empty() {
452                    continue;
453                }
454                let full = format!("{}/{}", base_normalized, item);
455                let resolved = resolve_rust_path(&full);
456                imports.push(resolved);
457            }
458        } else {
459            let resolved = resolve_rust_path(&path.replace("::", "/"));
460            imports.push(resolved);
461        }
462
463        i += 1;
464    }
465    imports
466}
467
/// Maps a `/`-separated Rust use-path onto an index pseudo-path:
/// `crate`/`super`/`self` prefixes are stripped, and the built-in crates
/// (`std`, `core`, `alloc`) are rooted under `lib/`. Anything else is
/// returned unchanged.
fn resolve_rust_path(crate_path: &str) -> String {
    let parts: Vec<&str> = crate_path.split('/').collect();
    if parts.len() < 2 {
        // A bare segment has nothing to resolve.
        return crate_path.to_string();
    }

    let head = parts[0];
    let tail = &parts[1..];

    match head {
        "crate" => {
            if tail.is_empty() {
                crate_path.to_string()
            } else if tail[0] == "super" || tail[0] == "self" {
                // `crate::self::x` / `crate::super::x` — drop both prefixes.
                tail[1..].join("/")
            } else {
                tail.join("/")
            }
        }
        "super" | "self" => tail.join("/"),
        "std" | "core" | "alloc" => format!("lib/{}/{}", head, tail.join("/")),
        _ => parts.join("/"),
    }
}
492
#[cfg(feature = "ts-typescript")]
/// Pulls module specifiers out of TypeScript/JavaScript sources:
/// `import [type] X from '…'`, `from '…' import …`, and bare
/// `require(…)` statements.
fn extract_ts_imports(content: &str) -> Vec<String> {
    let mut found = Vec::new();
    for raw in content.lines() {
        let line = raw.trim();
        let spec = if let Some(after) = line
            .strip_prefix("import ")
            .and_then(|s| s.strip_prefix("type "))
            .or_else(|| line.strip_prefix("import "))
        {
            // Whatever follows the last `from` is the module specifier.
            after.split("from").last().map(|tail| {
                tail.trim()
                    .trim_end_matches(';')
                    .trim()
                    .trim_matches('"')
                    .trim_matches('\'')
            })
        } else if let Some(after) = line.strip_prefix("from ") {
            let head = after.split("import").next().unwrap_or("");
            Some(
                head.trim()
                    .trim_end_matches(';')
                    .trim()
                    .trim_matches('"')
                    .trim_matches('\''),
            )
        } else if line.starts_with("require(") {
            let inner = line.trim_start_matches("require(").trim_end_matches(')');
            Some(inner.trim().trim_matches('"').trim_matches('\''))
        } else {
            None
        };

        match spec {
            Some(path) if !path.is_empty() => found.push(path.to_string()),
            _ => {}
        }
    }
    found
}
542
#[cfg(feature = "ts-python")]
/// Pulls module names from Python `import x` / `from x import …` lines,
/// dropping any `as` alias and converting dots to slashes.
fn extract_python_imports(content: &str) -> Vec<String> {
    let mut found = Vec::new();
    for raw in content.lines() {
        let line = raw.trim();
        let module = if let Some(rest) = line.strip_prefix("from ") {
            rest.split(" import").next().unwrap_or("")
        } else if let Some(rest) = line.strip_prefix("import ") {
            rest
        } else {
            continue;
        };
        let module = module.trim().split(" as ").next().unwrap_or(module).trim();
        if !module.is_empty() {
            found.push(module.replace('.', "/"));
        }
    }
    found
}
562
#[cfg(feature = "ts-go")]
/// Pulls import paths from Go sources, handling both the grouped
/// `import ( … )` block form and single-line `import "…"` statements.
/// Trailing `//` comments on import lines are discarded.
fn extract_go_imports(content: &str) -> Vec<String> {
    // Strips surrounding quotes and any trailing line comment.
    fn clean(s: &str) -> &str {
        s.trim_matches('"').split("//").next().unwrap_or("").trim()
    }

    let mut found = Vec::new();
    let mut in_block = false;
    for raw in content.lines() {
        let line = raw.trim();
        if line == "import (" {
            in_block = true;
        } else if in_block && line == ")" {
            in_block = false;
        } else if in_block {
            let path = clean(line);
            if !path.is_empty() {
                found.push(path.to_string());
            }
        } else if let Some(rest) = line.strip_prefix("import ") {
            let path = clean(rest.trim());
            if !path.is_empty() {
                found.push(path.to_string());
            }
        }
    }
    found
}
602
#[cfg(feature = "ts-java")]
/// Pulls package paths from Java `import …;` lines, converting dots to
/// slashes.
fn extract_java_imports(content: &str) -> Vec<String> {
    content
        .lines()
        .filter_map(|raw| {
            let line = raw.trim();
            let rest = line.strip_prefix("import ")?;
            let path = rest.trim().trim_end_matches(';').trim();
            if path.is_empty() {
                None
            } else {
                Some(path.replace('.', "/"))
            }
        })
        .collect()
}
622
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::fs;
    use tempfile::TempDir;

    // An empty project should produce the canonical no-results message.
    #[test]
    fn search_empty_directory_returns_no_results() {
        let temp = TempDir::new().expect("temp dir");
        let engine = SearchEngine::new(temp.path().to_path_buf());
        let result = engine.search("test", 20, None).expect("search");
        assert_eq!(result, "No results found.");
    }

    // A file containing the query symbol should appear in the output.
    #[test]
    fn search_finds_files_in_project() {
        let temp = TempDir::new().expect("temp dir");
        fs::write(
            temp.path().join("main.rs"),
            "fn search_engine() -> Vec<String> {\n    vec![\"hello\".to_string()]\n}\n",
        )
        .expect("write");

        let engine = SearchEngine::new(temp.path().to_path_buf());
        let result = engine.search("search_engine", 20, None).expect("search");
        assert_ne!(result, "No results found.");
        assert!(result.contains("main.rs"));
    }

    // Restricting to a subdirectory must exclude files outside it.
    #[test]
    fn search_with_restrict_to_dir() {
        let temp = TempDir::new().expect("temp dir");
        let sub = temp.path().join("src");
        fs::create_dir_all(&sub).expect("dir");
        fs::write(sub.join("mod.rs"), "fn helper() {}").expect("write");
        fs::write(temp.path().join("main.rs"), "fn main() {}").expect("write");

        let engine = SearchEngine::new(temp.path().to_path_buf());
        let result = engine.search("main", 20, Some("src")).expect("search");
        // Either the root-level main.rs is filtered out, or nothing matched.
        assert!(!result.contains("main.rs") || result.contains("No results"));
    }

    #[test]
    fn extract_rust_imports_simple() {
        let code = "use std::collections::HashMap;\nuse crate::tools::search::db;\n";
        let imports = extract_rust_imports(code);
        // `std` paths are rooted under lib/; `crate` is stripped.
        assert!(imports.contains(&"lib/std/collections/HashMap".to_string()));
        assert!(imports.contains(&"tools/search/db".to_string()));
    }

    #[test]
    fn extract_rust_imports_grouped() {
        let code = "use std::collections::{HashMap, BTreeMap};\n";
        let imports = extract_rust_imports(code);
        assert!(imports.contains(&"lib/std/collections/HashMap".to_string()));
        assert!(imports.contains(&"lib/std/collections/BTreeMap".to_string()));
    }

    #[test]
    fn extract_rust_imports_ignores_non_use_lines() {
        let code = "fn main() {}\n// use something\nconst X: i32 = 1;\n";
        let imports = extract_rust_imports(code);
        assert!(imports.is_empty());
    }

    #[test]
    fn extract_rust_imports_super_path() {
        let code = "use super::engine::SearchEngine;\n";
        let imports = extract_rust_imports(code);
        assert!(imports.contains(&"engine/SearchEngine".to_string()));
    }

    #[test]
    fn extract_rust_imports_crate_self() {
        let code = "use crate::config::SearchConfig;\n";
        let imports = extract_rust_imports(code);
        assert!(imports.contains(&"config/SearchConfig".to_string()));
    }

    #[test]
    fn extract_rust_imports_empty() {
        let imports = extract_rust_imports("");
        assert!(imports.is_empty());
    }

    // Deleting the DB between searches must rebuild a usable index that
    // reflects file modifications and additions.
    #[test]
    fn rebuild_creates_fresh_index() {
        let temp = TempDir::new().expect("temp dir");
        fs::write(temp.path().join("a.rs"), "fn first() {}").expect("write");

        let engine = SearchEngine::new(temp.path().to_path_buf());
        let result1 = engine.search("first", 20, None).expect("search");
        assert!(result1.contains("a.rs"));

        // Modify the file
        fs::write(temp.path().join("a.rs"), "fn renamed() {}").expect("write");

        // Add another file
        fs::write(temp.path().join("b.rs"), "fn second() {}").expect("write");

        // Delete the db to simulate rebuild
        let db_path = temp.path().join(".search-index").join("search.db");
        let _ = std::fs::remove_file(&db_path);

        let result2 = engine.search("second", 20, None).expect("search");
        assert!(result2.contains("b.rs"));
    }

    // Grouped imports spread over several physical lines are joined.
    #[test]
    fn extract_rust_imports_multiline_grouped() {
        let code = "use std::collections::{\n    HashMap,\n    BTreeMap,\n};\n";
        let imports = extract_rust_imports(code);
        assert!(imports.contains(&"lib/std/collections/HashMap".to_string()));
        assert!(imports.contains(&"lib/std/collections/BTreeMap".to_string()));
    }

    #[test]
    fn extract_rust_imports_multiline_crate_path() {
        let code = "use crate::upnp::{\n    browse,\n    browse_recursively,\n    Container,\n};\n";
        let imports = extract_rust_imports(code);
        assert!(imports.contains(&"upnp/browse".to_string()));
        assert!(imports.contains(&"upnp/browse_recursively".to_string()));
        assert!(imports.contains(&"upnp/Container".to_string()));
    }

    // A group that never closes (truncated input) must be skipped safely.
    #[test]
    fn extract_rust_imports_open_brace_only_no_panic() {
        let code = "use foo::{\n";
        let imports = extract_rust_imports(code);
        assert!(imports.is_empty());
    }

    // A trailing comma inside the braces must not produce an empty item.
    #[test]
    fn extract_rust_imports_trailing_comma_multiline() {
        let code = "use std::io::{Read, Write,};\n";
        let imports = extract_rust_imports(code);
        assert!(imports.contains(&"lib/std/io/Read".to_string()));
        assert!(imports.contains(&"lib/std/io/Write".to_string()));
        assert_eq!(imports.len(), 2);
    }

    #[test]
    fn extract_rust_imports_mixed_single_and_multiline() {
        let code = "use std::fs;\nuse crate::engine::{\n    SearchEngine,\n    MetricScores,\n};\nuse super::db::SearchDb;\n";
        let imports = extract_rust_imports(code);
        assert!(imports.contains(&"lib/std/fs".to_string()));
        assert!(imports.contains(&"engine/SearchEngine".to_string()));
        assert!(imports.contains(&"engine/MetricScores".to_string()));
        assert!(imports.contains(&"db/SearchDb".to_string()));
        assert_eq!(imports.len(), 4);
    }

    // Regression test: blocking work inside search() must not panic when
    // invoked from within a tokio runtime.
    #[tokio::test]
    async fn search_from_tokio_context_does_not_panic() {
        let cache_temp = TempDir::new().expect("cache temp dir");
        // SAFETY: This is an isolated test — mutating the env var is safe
        // within this test's scope to force a unique cold model cache path.
        unsafe { env::set_var("XDG_CACHE_HOME", cache_temp.path()) };

        let temp = TempDir::new().expect("project temp dir");
        fs::write(
            temp.path().join("lib.rs"),
            "fn search_engine() -> Vec<String> {\n    vec![\"hello\".to_string()]\n}\n",
        )
        .expect("write");

        let engine = SearchEngine::new(temp.path().to_path_buf());
        let result = engine.search("search_engine", 20, None);
        assert!(
            result.is_ok(),
            "Search from tokio context should not panic: {:?}",
            result.err()
        );
    }
}