1use std::path::{Path, PathBuf};
2
3use globset::GlobMatcher;
4
5use crate::error::{Result, UbtError};
6use crate::plugin::{Plugin, PluginRegistry, PluginSource};
7
8#[derive(Debug)]
10pub struct DetectionResult {
11 pub plugin_name: String,
12 pub variant_name: String,
13 pub source: PluginSource,
14 pub project_root: PathBuf,
15}
16
/// A registry plugin with its detection patterns pre-compiled, so that walking
/// many directories does not rebuild the same globs over and over.
struct CompiledPlugin<'a> {
    /// The plugin definition borrowed from the registry.
    plugin: &'a Plugin,
    /// Where the plugin was loaded from; cloned into match results.
    source: &'a PluginSource,
    /// Parallel to `plugin.detect.files`. `Some` holds the compiled glob for a
    /// pattern containing `*`; `None` means the pattern is a literal path that
    /// is checked with `exists()`.
    detect_matchers: Vec<Option<GlobMatcher>>,
    /// One entry per variant, in `plugin.variants` iteration order: the variant
    /// name plus matchers parallel to that variant's `detect_files`, using the
    /// same `Some`/`None` convention as `detect_matchers`.
    variant_matchers: Vec<(&'a str, Vec<Option<GlobMatcher>>)>,
}
26
27impl<'a> CompiledPlugin<'a> {
28 fn new(plugin: &'a Plugin, source: &'a PluginSource) -> Result<Self> {
29 let detect_matchers = compile_patterns(&plugin.detect.files)?;
30
31 let variant_matchers = plugin
32 .variants
33 .iter()
34 .map(|(name, variant)| {
35 compile_patterns(&variant.detect_files).map(|m| (name.as_str(), m))
36 })
37 .collect::<Result<Vec<_>>>()?;
38
39 Ok(Self {
40 plugin,
41 source,
42 detect_matchers,
43 variant_matchers,
44 })
45 }
46}
47
48fn compile_patterns(patterns: &[String]) -> Result<Vec<Option<GlobMatcher>>> {
52 patterns
53 .iter()
54 .map(|p| {
55 if p.contains('*') {
56 globset::GlobBuilder::new(p)
57 .literal_separator(true)
58 .build()
59 .map(|g| Some(g.compile_matcher()))
60 .map_err(|e| UbtError::InvalidGlobPattern {
61 pattern: p.clone(),
62 detail: e.to_string(),
63 })
64 } else {
65 Ok(None)
66 }
67 })
68 .collect()
69}
70
71fn compile_registry(registry: &PluginRegistry) -> Result<Vec<CompiledPlugin<'_>>> {
73 registry
74 .iter()
75 .map(|(_name, (plugin, source))| CompiledPlugin::new(plugin, source))
76 .collect()
77}
78
79pub fn detect_tool(
85 cli_tool: Option<&str>,
86 config_tool: Option<&str>,
87 start_dir: &Path,
88 registry: &PluginRegistry,
89) -> Result<DetectionResult> {
90 if let Some(tool) = cli_tool {
92 return resolve_explicit_tool(tool, start_dir, registry);
93 }
94
95 if let Ok(tool) = std::env::var("UBT_TOOL")
98 && !tool.is_empty()
99 {
100 return resolve_explicit_tool(&tool, start_dir, registry);
101 }
102
103 if let Some(tool) = config_tool {
105 return resolve_explicit_tool(tool, start_dir, registry);
106 }
107
108 let compiled = compile_registry(registry)?;
110 auto_detect(start_dir, &compiled)
111}
112
113fn resolve_explicit_tool(
116 tool: &str,
117 start_dir: &Path,
118 registry: &PluginRegistry,
119) -> Result<DetectionResult> {
120 if let Some((plugin, source)) = registry.get(tool) {
122 return Ok(DetectionResult {
123 plugin_name: plugin.name.clone(),
124 variant_name: detect_variant_literal(plugin, start_dir)?
125 .unwrap_or_else(|| plugin.default_variant.clone()),
126 source: source.clone(),
127 project_root: start_dir.to_path_buf(),
128 });
129 }
130
131 for (_name, (plugin, source)) in registry.iter() {
133 if plugin.variants.contains_key(tool) {
134 return Ok(DetectionResult {
135 plugin_name: plugin.name.clone(),
136 variant_name: tool.to_string(),
137 source: source.clone(),
138 project_root: start_dir.to_path_buf(),
139 });
140 }
141 }
142
143 Err(UbtError::PluginLoadError {
144 name: tool.to_string(),
145 detail: "no plugin or variant found with this name".into(),
146 })
147}
148
149fn auto_detect(start_dir: &Path, compiled: &[CompiledPlugin<'_>]) -> Result<DetectionResult> {
151 let mut current = start_dir.to_path_buf();
152
153 loop {
154 let matches = detect_at_dir(¤t, compiled);
155 if !matches.is_empty() {
156 return resolve_matches(matches, ¤t);
157 }
158 if !current.pop() {
159 break;
160 }
161 }
162
163 Err(UbtError::NoPluginMatch)
164}
165
/// A plugin that matched in a particular directory during auto-detection.
#[derive(Debug)]
struct DetectMatch {
    /// The matching plugin's `name`.
    plugin_name: String,
    /// The detected variant, or the plugin's `default_variant` if no
    /// variant-specific file matched.
    variant_name: String,
    /// The plugin's priority. `resolve_matches` prefers higher values.
    priority: i32,
    /// Where the plugin was loaded from.
    source: PluginSource,
}
174
175fn detect_at_dir(dir: &Path, compiled: &[CompiledPlugin<'_>]) -> Vec<DetectMatch> {
177 let mut matches = Vec::new();
178
179 for cp in compiled {
180 if plugin_matches_dir(cp, dir) {
181 let variant = detect_variant_compiled(cp, dir)
182 .unwrap_or_else(|| cp.plugin.default_variant.clone());
183 matches.push(DetectMatch {
184 plugin_name: cp.plugin.name.clone(),
185 variant_name: variant,
186 priority: cp.plugin.priority,
187 source: cp.source.clone(),
188 });
189 }
190 }
191
192 matches
193}
194
195fn plugin_matches_dir(cp: &CompiledPlugin<'_>, dir: &Path) -> bool {
198 cp.plugin
199 .detect
200 .files
201 .iter()
202 .zip(cp.detect_matchers.iter())
203 .any(|(pattern, matcher)| match matcher {
204 Some(m) => glob_matches_with(dir, m),
205 None => dir.join(pattern).exists(),
206 })
207}
208
209fn detect_variant_compiled(cp: &CompiledPlugin<'_>, dir: &Path) -> Option<String> {
211 for (variant_name, matchers) in &cp.variant_matchers {
212 let variant = cp.plugin.variants.get(*variant_name)?;
213 for (detect_file, matcher) in variant.detect_files.iter().zip(matchers.iter()) {
214 let matched = match matcher {
215 Some(m) => glob_matches_with(dir, m),
216 None => dir.join(detect_file).exists(),
217 };
218 if matched {
219 return Some((*variant_name).to_string());
220 }
221 }
222 }
223 None
224}
225
226fn detect_variant_literal(plugin: &Plugin, dir: &Path) -> Result<Option<String>> {
229 for (variant_name, variant) in &plugin.variants {
230 for detect_file in &variant.detect_files {
231 let matched = if detect_file.contains('*') {
232 let glob = globset::GlobBuilder::new(detect_file)
233 .literal_separator(true)
234 .build()
235 .map_err(|e| UbtError::InvalidGlobPattern {
236 pattern: detect_file.clone(),
237 detail: e.to_string(),
238 })?;
239 glob_matches_with(dir, &glob.compile_matcher())
240 } else {
241 dir.join(detect_file).exists()
242 };
243 if matched {
244 return Ok(Some(variant_name.clone()));
245 }
246 }
247 }
248 Ok(None)
249}
250
251fn glob_matches_with(dir: &Path, matcher: &GlobMatcher) -> bool {
253 let Ok(entries) = std::fs::read_dir(dir) else {
254 return false;
255 };
256
257 entries.filter_map(|e| e.ok()).any(|entry| {
258 entry
259 .file_name()
260 .to_str()
261 .map(|name| matcher.is_match(name))
262 .unwrap_or(false)
263 })
264}
265
266fn resolve_matches(matches: Vec<DetectMatch>, dir: &Path) -> Result<DetectionResult> {
268 assert!(!matches.is_empty());
269
270 if matches.len() == 1 {
271 let m = matches.into_iter().next().unwrap();
272 return Ok(DetectionResult {
273 plugin_name: m.plugin_name,
274 variant_name: m.variant_name,
275 source: m.source,
276 project_root: dir.to_path_buf(),
277 });
278 }
279
280 let mut sorted = matches;
282 sorted.sort_by(|a, b| b.priority.cmp(&a.priority));
283
284 if sorted[0].priority == sorted[1].priority {
286 let plugins: Vec<_> = sorted.iter().map(|m| m.plugin_name.as_str()).collect();
287 return Err(UbtError::PluginConflict {
288 plugins: plugins.join(", "),
289 suggested_tool: sorted[0].plugin_name.clone(),
290 });
291 }
292
293 let winner = sorted.into_iter().next().unwrap();
294 Ok(DetectionResult {
295 plugin_name: winner.plugin_name,
296 variant_name: winner.variant_name,
297 source: winner.source,
298 project_root: dir.to_path_buf(),
299 })
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Runs `f` with `UBT_TOOL` unset, so an override in the developer's
    /// environment cannot leak into tests of auto-detection or CLI/config
    /// selection.
    fn with_clean_env<F, R>(f: F) -> R
    where
        F: FnOnce() -> R,
    {
        temp_env::with_var("UBT_TOOL", None::<&str>, f)
    }

    #[test]
    fn detect_go_project() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("go.mod"), "module example.com/foo").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "go");
            assert_eq!(result.variant_name, "go");
        });
    }

    #[test]
    fn detect_node_npm() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("package.json"), "{}").unwrap();
            std::fs::write(dir.path().join("package-lock.json"), "{}").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
            assert_eq!(result.variant_name, "npm");
        });
    }

    #[test]
    fn detect_node_pnpm() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("package.json"), "{}").unwrap();
            std::fs::write(dir.path().join("pnpm-lock.yaml"), "").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
            assert_eq!(result.variant_name, "pnpm");
        });
    }

    #[test]
    fn detect_node_default_variant_when_no_lockfile() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("package.json"), "{}").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
            assert_eq!(result.variant_name, "npm");
        });
    }

    #[test]
    fn detect_rust_project() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("Cargo.toml"), "[package]\nname = \"foo\"").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "rust");
            assert_eq!(result.variant_name, "cargo");
        });
    }

    #[test]
    fn detect_cli_override() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            // The go.mod would auto-detect as "go"; the CLI choice must win.
            std::fs::write(dir.path().join("go.mod"), "module foo").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(Some("node"), None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
        });
    }

    #[test]
    fn detect_config_override() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("go.mod"), "module foo").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, Some("node"), dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
        });
    }

    #[test]
    fn detect_env_override() {
        temp_env::with_var("UBT_TOOL", Some("node"), || {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("go.mod"), "module foo").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
        });
    }

    #[test]
    fn detect_env_beats_config() {
        temp_env::with_var("UBT_TOOL", Some("go"), || {
            let dir = TempDir::new().unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, Some("node"), dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "go");
        });
    }

    #[test]
    fn detect_cli_beats_env() {
        temp_env::with_var("UBT_TOOL", Some("node"), || {
            let dir = TempDir::new().unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(Some("go"), None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "go");
        });
    }

    #[test]
    fn detect_empty_env_is_ignored() {
        temp_env::with_var("UBT_TOOL", Some(""), || {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("go.mod"), "module foo").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "go");
        });
    }

    #[test]
    fn detect_variant_name_as_tool() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(Some("pnpm"), None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
            assert_eq!(result.variant_name, "pnpm");
        });
    }

    #[test]
    fn detect_walks_upward() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("go.mod"), "module foo").unwrap();
            let nested = dir.path().join("a").join("b").join("c");
            std::fs::create_dir_all(&nested).unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, &nested, &registry).unwrap();

            assert_eq!(result.plugin_name, "go");
            assert_eq!(result.project_root, dir.path());
        });
    }

    #[test]
    fn detect_no_match_errors() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry);

            assert!(result.is_err());
            assert!(matches!(result.unwrap_err(), UbtError::NoPluginMatch));
        });
    }

    #[test]
    fn detect_unknown_tool_errors() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(Some("nonexistent"), None, dir.path(), &registry);

            assert!(result.is_err());
            assert!(matches!(
                result.unwrap_err(),
                UbtError::PluginLoadError { .. }
            ));
        });
    }

    #[test]
    fn detect_dotnet_glob() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("MyApp.csproj"), "<Project/>").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "dotnet");
        });
    }

    #[test]
    fn detect_ruby_project() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("Gemfile"), "source 'https://rubygems.org'").unwrap();
            std::fs::write(dir.path().join("Gemfile.lock"), "").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "ruby");
            assert_eq!(result.variant_name, "bundler");
        });
    }

    #[test]
    fn detect_ruby_rails_project() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::create_dir(dir.path().join("bin")).unwrap();
            std::fs::write(dir.path().join("Gemfile"), "source 'https://rubygems.org'").unwrap();
            std::fs::write(dir.path().join("bin/rails"), "#!/usr/bin/env ruby").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "ruby");
            assert_eq!(result.variant_name, "rails");
        });
    }

    #[test]
    fn detect_ruby_rails_with_lockfile() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::create_dir(dir.path().join("bin")).unwrap();
            std::fs::write(dir.path().join("Gemfile"), "source 'https://rubygems.org'").unwrap();
            std::fs::write(dir.path().join("Gemfile.lock"), "").unwrap();
            std::fs::write(dir.path().join("bin/rails"), "#!/usr/bin/env ruby").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "ruby");
            assert_eq!(result.variant_name, "rails");
        });
    }

    #[test]
    fn detect_python_pip() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("requirements.txt"), "flask").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "python");
        });
    }

    #[test]
    fn detect_java_maven() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("pom.xml"), "<project/>").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "java");
            assert_eq!(result.variant_name, "mvn");
        });
    }
}
551}