1use std::path::Path;
2
3use crate::adapters::cpp::CppAdapter;
4use crate::adapters::dotnet::DotnetAdapter;
5use crate::adapters::elixir::ElixirAdapter;
6use crate::adapters::go::GoAdapter;
7use crate::adapters::java::JavaAdapter;
8use crate::adapters::javascript::JavaScriptAdapter;
9use crate::adapters::php::PhpAdapter;
10use crate::adapters::python::PythonAdapter;
11use crate::adapters::ruby::RubyAdapter;
12use crate::adapters::rust::RustAdapter;
13use crate::adapters::zig::ZigAdapter;
14use crate::adapters::{DetectionResult, TestAdapter};
15
/// Chooses which registered [`TestAdapter`] best matches a project directory.
///
/// The position of an adapter in the internal list is the value reported as
/// [`DetectedProject::adapter_index`].
pub struct DetectionEngine {
    /// Registered adapters, in registration order.
    adapters: Vec<Box<dyn TestAdapter>>,
}
19
/// A successful detection paired with the adapter that produced it.
pub struct DetectedProject {
    /// The adapter's detection result (carries at least `language` and `confidence`).
    pub detection: DetectionResult,
    /// Index of the producing adapter in the engine's adapter list; pass it to
    /// `DetectionEngine::adapter` to retrieve that adapter.
    pub adapter_index: usize,
}
24
25impl Default for DetectionEngine {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl DetectionEngine {
32 pub fn new() -> Self {
33 Self {
34 adapters: vec![
35 Box::new(RustAdapter::new()),
36 Box::new(GoAdapter::new()),
37 Box::new(PythonAdapter::new()),
38 Box::new(JavaScriptAdapter::new()),
39 Box::new(JavaAdapter::new()),
40 Box::new(CppAdapter::new()),
41 Box::new(RubyAdapter::new()),
42 Box::new(ElixirAdapter::new()),
43 Box::new(PhpAdapter::new()),
44 Box::new(DotnetAdapter::new()),
45 Box::new(ZigAdapter::new()),
46 ],
47 }
48 }
49
50 pub fn detect(&self, project_dir: &Path) -> Option<DetectedProject> {
53 let mut best: Option<DetectedProject> = None;
54
55 for (i, adapter) in self.adapters.iter().enumerate() {
56 if let Some(result) = adapter.detect(project_dir) {
57 let dominated = best
58 .as_ref()
59 .map(|b| result.confidence > b.detection.confidence)
60 .unwrap_or(true);
61 if dominated {
62 best = Some(DetectedProject {
63 detection: result,
64 adapter_index: i,
65 });
66 }
67 }
68 }
69
70 best
71 }
72
73 pub fn detect_all(&self, project_dir: &Path) -> Vec<DetectedProject> {
75 let mut results = Vec::new();
76 for (i, adapter) in self.adapters.iter().enumerate() {
77 if let Some(result) = adapter.detect(project_dir) {
78 results.push(DetectedProject {
79 detection: result,
80 adapter_index: i,
81 });
82 }
83 }
84 results.sort_by(|a, b| {
85 b.detection
86 .confidence
87 .partial_cmp(&a.detection.confidence)
88 .unwrap_or(std::cmp::Ordering::Equal)
89 });
90 results
91 }
92
93 pub fn adapter(&self, index: usize) -> &dyn TestAdapter {
95 self.adapters[index].as_ref()
96 }
97
98 pub fn adapters(&self) -> &[Box<dyn TestAdapter>] {
100 &self.adapters
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a temp dir that contains both a Rust and a Python project marker.
    fn polyglot_dir() -> tempfile::TempDir {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("Cargo.toml"),
            "[package]\nname = \"test\"\n",
        )
        .unwrap();
        std::fs::write(dir.path().join("pyproject.toml"), "[tool.pytest]\n").unwrap();
        dir
    }

    #[test]
    fn detect_rust_project() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("Cargo.toml"),
            "[package]\nname = \"test\"\n",
        )
        .unwrap();
        let engine = DetectionEngine::new();
        let det = engine.detect(dir.path()).unwrap();
        assert_eq!(det.detection.language, "Rust");
    }

    #[test]
    fn detect_go_project() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("go.mod"), "module example.com/test\n").unwrap();
        std::fs::write(dir.path().join("main_test.go"), "package main\n").unwrap();
        let engine = DetectionEngine::new();
        let det = engine.detect(dir.path()).unwrap();
        assert_eq!(det.detection.language, "Go");
    }

    #[test]
    fn detect_python_project() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("pyproject.toml"), "[tool.pytest]\n").unwrap();
        let engine = DetectionEngine::new();
        let det = engine.detect(dir.path()).unwrap();
        assert_eq!(det.detection.language, "Python");
    }

    #[test]
    fn detect_js_project() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("package.json"),
            r#"{"devDependencies":{"jest":"^29"}}"#,
        )
        .unwrap();
        std::fs::write(dir.path().join("jest.config.js"), "").unwrap();
        let engine = DetectionEngine::new();
        let det = engine.detect(dir.path()).unwrap();
        assert_eq!(det.detection.language, "JavaScript");
    }

    #[test]
    fn detect_nothing_in_empty_dir() {
        let dir = tempfile::tempdir().unwrap();
        let engine = DetectionEngine::new();
        assert!(engine.detect(dir.path()).is_none());
    }

    #[test]
    fn detect_all_polyglot() {
        let dir = polyglot_dir();
        let engine = DetectionEngine::new();
        let all = engine.detect_all(dir.path());
        assert!(all.len() >= 2);
        // Both ecosystems must actually be reported, not just any two adapters.
        assert!(all.iter().any(|p| p.detection.language == "Rust"));
        assert!(all.iter().any(|p| p.detection.language == "Python"));
    }

    #[test]
    fn detect_all_is_sorted_by_descending_confidence() {
        let dir = polyglot_dir();
        let all = DetectionEngine::new().detect_all(dir.path());
        assert!(all
            .windows(2)
            .all(|pair| pair[0].detection.confidence >= pair[1].detection.confidence));
    }

    #[test]
    fn detect_agrees_with_first_of_detect_all() {
        let dir = polyglot_dir();
        let engine = DetectionEngine::new();
        let best = engine.detect(dir.path()).unwrap();
        let all = engine.detect_all(dir.path());
        assert_eq!(best.adapter_index, all[0].adapter_index);
    }

    #[test]
    fn default_matches_new() {
        assert_eq!(
            DetectionEngine::default().adapters().len(),
            DetectionEngine::new().adapters().len()
        );
    }

    #[test]
    fn adapter_count() {
        let engine = DetectionEngine::new();
        assert_eq!(engine.adapters().len(), 11);
    }
}