1use std::path::{Path, PathBuf};
2
3use crate::error::{Result, UbtError};
4use crate::plugin::{Plugin, PluginRegistry, PluginSource};
5
6#[derive(Debug)]
8pub struct DetectionResult {
9 pub plugin_name: String,
10 pub variant_name: String,
11 pub source: PluginSource,
12 pub project_root: PathBuf,
13}
14
15pub fn detect_tool(
21 cli_tool: Option<&str>,
22 config_tool: Option<&str>,
23 start_dir: &Path,
24 registry: &PluginRegistry,
25) -> Result<DetectionResult> {
26 if let Some(tool) = cli_tool {
28 return resolve_explicit_tool(tool, start_dir, registry);
29 }
30
31 if let Ok(tool) = std::env::var("UBT_TOOL") {
34 if !tool.is_empty() {
35 return resolve_explicit_tool(&tool, start_dir, registry);
36 }
37 }
38
39 if let Some(tool) = config_tool {
41 return resolve_explicit_tool(tool, start_dir, registry);
42 }
43
44 auto_detect(start_dir, registry)
46}
47
48fn resolve_explicit_tool(
51 tool: &str,
52 start_dir: &Path,
53 registry: &PluginRegistry,
54) -> Result<DetectionResult> {
55 if let Some((plugin, source)) = registry.get(tool) {
57 return Ok(DetectionResult {
58 plugin_name: plugin.name.clone(),
59 variant_name: detect_variant(plugin, start_dir)
60 .unwrap_or_else(|| plugin.default_variant.clone()),
61 source: source.clone(),
62 project_root: start_dir.to_path_buf(),
63 });
64 }
65
66 for (_name, (plugin, source)) in registry.iter() {
68 if plugin.variants.contains_key(tool) {
69 return Ok(DetectionResult {
70 plugin_name: plugin.name.clone(),
71 variant_name: tool.to_string(),
72 source: source.clone(),
73 project_root: start_dir.to_path_buf(),
74 });
75 }
76 }
77
78 Err(UbtError::PluginLoadError {
79 name: tool.to_string(),
80 detail: "no plugin or variant found with this name".into(),
81 })
82}
83
84fn auto_detect(start_dir: &Path, registry: &PluginRegistry) -> Result<DetectionResult> {
86 let mut current = start_dir.to_path_buf();
87
88 loop {
89 let matches = detect_at_dir(¤t, registry);
90 if !matches.is_empty() {
91 return resolve_matches(matches, ¤t);
92 }
93 if !current.pop() {
94 break;
95 }
96 }
97
98 Err(UbtError::NoPluginMatch)
99}
100
/// One candidate found by `detect_at_dir`: a plugin whose detect files are
/// present in the probed directory, with its variant already resolved.
#[derive(Debug)]
struct DetectMatch {
    /// Name of the matching plugin (copied from the plugin's `name`).
    plugin_name: String,
    /// Variant detected in the directory, or the plugin's `default_variant`.
    variant_name: String,
    /// The plugin's `priority`; `resolve_matches` prefers the highest value.
    priority: i32,
    /// Source of the plugin definition, as reported by the registry.
    source: PluginSource,
}
109
110fn detect_at_dir(dir: &Path, registry: &PluginRegistry) -> Vec<DetectMatch> {
112 let mut matches = Vec::new();
113
114 for (_name, (plugin, source)) in registry.iter() {
115 if plugin_matches_dir(plugin, dir) {
116 let variant =
117 detect_variant(plugin, dir).unwrap_or_else(|| plugin.default_variant.clone());
118 matches.push(DetectMatch {
119 plugin_name: plugin.name.clone(),
120 variant_name: variant,
121 priority: plugin.priority,
122 source: source.clone(),
123 });
124 }
125 }
126
127 matches
128}
129
130fn plugin_matches_dir(plugin: &Plugin, dir: &Path) -> bool {
132 plugin.detect.files.iter().any(|pattern| {
133 if pattern.contains('*') {
134 glob_matches(dir, pattern)
136 } else {
137 dir.join(pattern).exists()
138 }
139 })
140}
141
142fn glob_matches(dir: &Path, pattern: &str) -> bool {
144 let Ok(matcher) = globset::GlobBuilder::new(pattern)
145 .literal_separator(true)
146 .build()
147 .map(|g| g.compile_matcher())
148 else {
149 return false;
150 };
151
152 let Ok(entries) = std::fs::read_dir(dir) else {
153 return false;
154 };
155
156 entries.filter_map(|e| e.ok()).any(|entry| {
157 entry
158 .file_name()
159 .to_str()
160 .map(|name| matcher.is_match(name))
161 .unwrap_or(false)
162 })
163}
164
165fn detect_variant(plugin: &Plugin, dir: &Path) -> Option<String> {
167 for (variant_name, variant) in &plugin.variants {
168 for detect_file in &variant.detect_files {
169 if detect_file.contains('*') {
170 if glob_matches(dir, detect_file) {
171 return Some(variant_name.clone());
172 }
173 } else if dir.join(detect_file).exists() {
174 return Some(variant_name.clone());
175 }
176 }
177 }
178 None
179}
180
181fn resolve_matches(matches: Vec<DetectMatch>, dir: &Path) -> Result<DetectionResult> {
183 assert!(!matches.is_empty());
184
185 if matches.len() == 1 {
186 let m = matches.into_iter().next().unwrap();
187 return Ok(DetectionResult {
188 plugin_name: m.plugin_name,
189 variant_name: m.variant_name,
190 source: m.source,
191 project_root: dir.to_path_buf(),
192 });
193 }
194
195 let mut sorted = matches;
197 sorted.sort_by(|a, b| b.priority.cmp(&a.priority));
198
199 if sorted[0].priority == sorted[1].priority {
201 let plugins: Vec<_> = sorted.iter().map(|m| m.plugin_name.as_str()).collect();
202 return Err(UbtError::PluginConflict {
203 plugins: plugins.join(", "),
204 suggested_tool: sorted[0].plugin_name.clone(),
205 });
206 }
207
208 let winner = sorted.into_iter().next().unwrap();
209 Ok(DetectionResult {
210 plugin_name: winner.plugin_name,
211 variant_name: winner.variant_name,
212 source: winner.source,
213 project_root: dir.to_path_buf(),
214 })
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::Mutex;
    use tempfile::TempDir;

    /// Serialises the tests in this module, which all read or mutate the
    /// process-wide `UBT_TOOL` variable through `with_clean_env`.
    static ENV_MUTEX: Mutex<()> = Mutex::new(());

    /// Puts `UBT_TOOL` back to the value captured at construction when dropped,
    /// so the environment is restored even if the test body panics.
    struct RestoreUbtTool(Option<OsString>);

    impl Drop for RestoreUbtTool {
        fn drop(&mut self) {
            // SAFETY: this guard only exists inside `with_clean_env`, where it is
            // declared after (and therefore dropped before) the `ENV_MUTEX` guard,
            // so no other test in this module touches the environment at the same
            // time. Code outside this module is not coordinated by this lock.
            unsafe {
                match &self.0 {
                    Some(value) => std::env::set_var("UBT_TOOL", value),
                    None => std::env::remove_var("UBT_TOOL"),
                }
            }
        }
    }

    /// Runs `f` with `UBT_TOOL` unset while holding `ENV_MUTEX`, then restores
    /// the previous value (also when `f` panics).
    fn with_clean_env<F, R>(f: F) -> R
    where
        F: FnOnce() -> R,
    {
        // A poisoned lock only means an earlier test panicked; its guard has
        // already restored the environment, so it is safe to continue.
        let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
        // `var_os` also captures non-Unicode values so they can be restored.
        let _restore = RestoreUbtTool(std::env::var_os("UBT_TOOL"));
        // SAFETY: `ENV_MUTEX` is held, so no other test in this module is
        // accessing the environment concurrently.
        unsafe {
            std::env::remove_var("UBT_TOOL");
        }
        f()
    }

    /// Sets `UBT_TOOL` for the current test; only call inside `with_clean_env`.
    fn set_ubt_tool(value: &str) {
        // SAFETY: callers run inside `with_clean_env`, which holds `ENV_MUTEX`
        // and restores the previous value afterwards.
        unsafe {
            std::env::set_var("UBT_TOOL", value);
        }
    }

    #[test]
    fn detect_go_project() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("go.mod"), "module example.com/foo").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "go");
            assert_eq!(result.variant_name, "go");
        });
    }

    #[test]
    fn detect_node_npm() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("package.json"), "{}").unwrap();
            std::fs::write(dir.path().join("package-lock.json"), "{}").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
            assert_eq!(result.variant_name, "npm");
        });
    }

    #[test]
    fn detect_node_pnpm() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("package.json"), "{}").unwrap();
            std::fs::write(dir.path().join("pnpm-lock.yaml"), "").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
            assert_eq!(result.variant_name, "pnpm");
        });
    }

    #[test]
    fn detect_node_default_variant_when_no_lockfile() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("package.json"), "{}").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
            assert_eq!(result.variant_name, "npm");
        });
    }

    #[test]
    fn detect_rust_project() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("Cargo.toml"), "[package]\nname = \"foo\"").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "rust");
            assert_eq!(result.variant_name, "cargo");
        });
    }

    #[test]
    fn detect_cli_override() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("go.mod"), "module foo").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(Some("node"), None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
        });
    }

    #[test]
    fn detect_config_override() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("go.mod"), "module foo").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, Some("node"), dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
        });
    }

    #[test]
    fn detect_env_override() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("go.mod"), "module foo").unwrap();
            set_ubt_tool("node");

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
        });
    }

    #[test]
    fn detect_cli_beats_env() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            set_ubt_tool("node");

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(Some("go"), None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "go");
        });
    }

    #[test]
    fn detect_env_beats_config() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            set_ubt_tool("node");

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, Some("go"), dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
        });
    }

    #[test]
    fn detect_empty_env_is_ignored() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("go.mod"), "module foo").unwrap();
            set_ubt_tool("");

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "go");
        });
    }

    #[test]
    fn detect_variant_name_as_tool() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(Some("pnpm"), None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
            assert_eq!(result.variant_name, "pnpm");
        });
    }

    #[test]
    fn detect_walks_upward() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("go.mod"), "module foo").unwrap();
            let nested = dir.path().join("a").join("b").join("c");
            std::fs::create_dir_all(&nested).unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, &nested, &registry).unwrap();

            assert_eq!(result.plugin_name, "go");
            assert_eq!(result.project_root, dir.path());
        });
    }

    #[test]
    fn detect_no_match_errors() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry);

            assert!(result.is_err());
            assert!(matches!(result.unwrap_err(), UbtError::NoPluginMatch));
        });
    }

    #[test]
    fn detect_unknown_tool_errors() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(Some("nonexistent"), None, dir.path(), &registry);

            assert!(result.is_err());
            assert!(matches!(
                result.unwrap_err(),
                UbtError::PluginLoadError { .. }
            ));
        });
    }

    #[test]
    fn detect_dotnet_glob() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("MyApp.csproj"), "<Project/>").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "dotnet");
        });
    }

    #[test]
    fn detect_ruby_project() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("Gemfile"), "source 'https://rubygems.org'").unwrap();
            std::fs::write(dir.path().join("Gemfile.lock"), "").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "ruby");
            assert_eq!(result.variant_name, "bundler");
        });
    }

    #[test]
    fn detect_python_pip() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("requirements.txt"), "flask").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "python");
        });
    }

    #[test]
    fn detect_java_maven() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("pom.xml"), "<project/>").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "java");
            assert_eq!(result.variant_name, "mvn");
        });
    }
}