// testx/adapters/go.rs

1use std::path::Path;
2use std::process::Command;
3use std::time::Duration;
4
5use anyhow::Result;
6
7use super::util::duration_from_secs_safe;
8use super::{DetectionResult, TestAdapter, TestCase, TestRunResult, TestStatus, TestSuite};
9
/// Test adapter for Go projects that are run through `go test`.
///
/// The adapter is a stateless unit struct, so it is cheap to copy and can be
/// printed in diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct GoAdapter;
11
12impl Default for GoAdapter {
13    fn default() -> Self {
14        Self::new()
15    }
16}
17
18impl GoAdapter {
19    pub fn new() -> Self {
20        Self
21    }
22}
23
24impl TestAdapter for GoAdapter {
25    fn name(&self) -> &str {
26        "Go"
27    }
28
29    fn check_runner(&self) -> Option<String> {
30        if which::which("go").is_err() {
31            Some("go".into())
32        } else {
33            None
34        }
35    }
36
37    fn detect(&self, project_dir: &Path) -> Option<DetectionResult> {
38        if !project_dir.join("go.mod").exists() {
39            return None;
40        }
41
42        // Check for test files
43        let has_tests = std::fs::read_dir(project_dir).ok()?.any(|entry| {
44            entry
45                .ok()
46                .is_some_and(|e| e.file_name().to_string_lossy().ends_with("_test.go"))
47        }) || find_test_files_recursive(project_dir);
48
49        if !has_tests {
50            return None;
51        }
52
53        Some(DetectionResult {
54            language: "Go".into(),
55            framework: "go test".into(),
56            confidence: 0.95,
57        })
58    }
59
60    fn build_command(&self, project_dir: &Path, extra_args: &[String]) -> Result<Command> {
61        let mut cmd = Command::new("go");
62        cmd.arg("test");
63
64        if extra_args.is_empty() {
65            cmd.arg("-v"); // verbose for parsing individual tests
66            cmd.arg("./..."); // all packages
67        }
68
69        for arg in extra_args {
70            cmd.arg(arg);
71        }
72
73        cmd.current_dir(project_dir);
74        Ok(cmd)
75    }
76
77    fn parse_output(&self, stdout: &str, stderr: &str, exit_code: i32) -> TestRunResult {
78        let combined = format!("{}\n{}", stdout, stderr);
79        let failure_messages = parse_go_failures(&combined);
80        let mut suites: Vec<TestSuite> = Vec::new();
81        let mut current_pkg = String::new();
82        let mut current_tests: Vec<TestCase> = Vec::new();
83
84        for line in combined.lines() {
85            let trimmed = line.trim();
86
87            // Go test verbose output:
88            // "=== RUN   TestFoo"
89            // "--- PASS: TestFoo (0.00s)"
90            // "--- FAIL: TestFoo (0.05s)"
91            // "--- SKIP: TestFoo (0.00s)"
92
93            if trimmed.starts_with("--- PASS:")
94                || trimmed.starts_with("--- FAIL:")
95                || trimmed.starts_with("--- SKIP:")
96            {
97                let status = if trimmed.starts_with("--- PASS:") {
98                    TestStatus::Passed
99                } else if trimmed.starts_with("--- FAIL:") {
100                    TestStatus::Failed
101                } else {
102                    TestStatus::Skipped
103                };
104
105                let rest = trimmed.split(':').nth(1).unwrap_or("").trim();
106                let parts: Vec<&str> = rest.split_whitespace().collect();
107                let name = parts.first().unwrap_or(&"unknown").to_string();
108                let duration = parts
109                    .get(1)
110                    .and_then(|s| {
111                        let s = s.trim_matches(|c| c == '(' || c == ')' || c == 's');
112                        s.parse::<f64>().ok()
113                    })
114                    .map(duration_from_secs_safe)
115                    .unwrap_or(Duration::from_millis(0));
116
117                let error = if status == TestStatus::Failed {
118                    failure_messages
119                        .get(name.as_str())
120                        .map(|msg| super::TestError {
121                            message: msg.clone(),
122                            location: None,
123                        })
124                } else {
125                    None
126                };
127
128                current_tests.push(TestCase {
129                    name,
130                    status,
131                    duration,
132                    error,
133                });
134                continue;
135            }
136
137            // Package result line: "ok  	github.com/user/pkg	0.005s"
138            // or: "FAIL	github.com/user/pkg	0.005s"
139            if (trimmed.starts_with("ok") || trimmed.starts_with("FAIL")) && trimmed.contains('\t')
140            {
141                // Flush current tests to this new package suite
142                let parts: Vec<&str> = trimmed.split('\t').collect();
143                let pkg_name = parts.get(1).unwrap_or(&"").trim().to_string();
144
145                if !current_tests.is_empty() {
146                    suites.push(TestSuite {
147                        name: if current_pkg.is_empty() {
148                            pkg_name.clone()
149                        } else {
150                            current_pkg.clone()
151                        },
152                        tests: std::mem::take(&mut current_tests),
153                    });
154                }
155                current_pkg = pkg_name;
156            }
157        }
158
159        // Flush remaining
160        if !current_tests.is_empty() {
161            let name = if current_pkg.is_empty() {
162                "tests".into()
163            } else {
164                current_pkg
165            };
166            suites.push(TestSuite {
167                name,
168                tests: current_tests,
169            });
170        }
171
172        if suites.is_empty() {
173            let status = if exit_code == 0 {
174                TestStatus::Passed
175            } else {
176                TestStatus::Failed
177            };
178            suites.push(TestSuite {
179                name: "tests".into(),
180                tests: vec![TestCase {
181                    name: "test_suite".into(),
182                    status,
183                    duration: Duration::from_millis(0),
184                    error: None,
185                }],
186            });
187        }
188
189        // Parse total duration from last "ok" or "FAIL" line
190        let duration = parse_go_total_duration(&combined).unwrap_or(Duration::from_secs(0));
191
192        TestRunResult {
193            suites,
194            duration,
195            raw_exit_code: exit_code,
196        }
197    }
198}
199
/// Returns `true` when `dir` or any of its sub-directories contains a `*_test.go` file.
///
/// Directories that cannot be read count as containing no tests. The walk skips:
///
/// * hidden directories (leading `.`), `vendor` and `node_modules`;
/// * `testdata` and directories with a leading `_`, which the go tool ignores
///   when expanding package patterns such as `./...`;
/// * symlinked directories, so a link cycle cannot cause runaway recursion.
///
/// Symlinks that point at regular files are still considered.
fn find_test_files_recursive(dir: &Path) -> bool {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return false;
    };
    for entry in entries.flatten() {
        // `DirEntry::file_type` does not traverse symlinks, unlike `Path::is_dir`.
        let Ok(file_type) = entry.file_type() else {
            continue;
        };
        let path = entry.path();

        if file_type.is_dir() {
            let name = entry.file_name();
            let name = name.to_string_lossy();
            let ignored = name.starts_with('.')
                || name.starts_with('_')
                || matches!(name.as_ref(), "vendor" | "node_modules" | "testdata");
            if !ignored && find_test_files_recursive(&path) {
                return true;
            }
        } else if path.to_string_lossy().ends_with("_test.go")
            // A symlink is accepted only when it resolves to a regular file.
            && (file_type.is_file() || path.is_file())
        {
            return true;
        }
    }
    false
}
223
/// Collects the log lines printed by each failing test in `go test -v` output.
///
/// The lines of interest sit between a `=== RUN` header and the result marker that follows:
/// ```text
/// === RUN   TestDivide
///     math_test.go:15: expected 2, got 0
/// --- FAIL: TestDivide (0.00s)
/// ```
/// The returned map is keyed by the name taken from the `=== RUN` header. The value is
/// the trimmed, non-empty lines joined with `" | "`. An entry is recorded only when the
/// collected section is closed by a `--- FAIL:` line and contains at least one line.
fn parse_go_failures(output: &str) -> std::collections::HashMap<String, String> {
    let mut failures = std::collections::HashMap::new();
    // The test whose section is currently open, with the lines gathered so far.
    let mut open_section: Option<(String, Vec<&str>)> = None;

    for raw_line in output.lines() {
        let line = raw_line.trim();

        if let Some((test_name, mut collected)) = open_section.take() {
            let closes_with_failure = line.starts_with("--- FAIL:");
            let is_boundary = closes_with_failure
                || line.starts_with("--- PASS:")
                || line.starts_with("--- SKIP:")
                || line.starts_with("=== RUN");

            if !is_boundary {
                if !line.is_empty() {
                    collected.push(line);
                }
                open_section = Some((test_name, collected));
                continue;
            }

            if closes_with_failure && !collected.is_empty() {
                failures.insert(test_name, collected.join(" | "));
            }
            // Fall through: the boundary line may itself open the next section.
        }

        if let Some(rest) = line.strip_prefix("=== RUN") {
            let test_name = rest.trim();
            if !test_name.is_empty() {
                open_section = Some((test_name.to_string(), Vec::new()));
            }
        }
    }

    failures
}
272
273fn parse_go_total_duration(output: &str) -> Option<Duration> {
274    let mut total = Duration::from_secs(0);
275    let mut found = false;
276    for line in output.lines() {
277        let trimmed = line.trim();
278        if (trimmed.starts_with("ok") || trimmed.starts_with("FAIL")) && trimmed.contains('\t') {
279            let parts: Vec<&str> = trimmed.split('\t').collect();
280            if let Some(time_str) = parts.last() {
281                let time_str = time_str.trim().trim_end_matches('s');
282                if let Ok(secs) = time_str.parse::<f64>() {
283                    total += duration_from_secs_safe(secs);
284                    found = true;
285                }
286            }
287        }
288    }
289    if found { Some(total) } else { None }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the Go parser over `stdout` with an empty stderr.
    fn parse(stdout: &str, exit_code: i32) -> TestRunResult {
        GoAdapter::new().parse_output(stdout, "", exit_code)
    }

    /// Mixed pass/fail run: counts are right and the failure message is attached.
    #[test]
    fn parse_go_verbose_output() {
        let stdout = concat!(
            "\n",
            "=== RUN   TestAdd\n",
            "--- PASS: TestAdd (0.00s)\n",
            "=== RUN   TestSubtract\n",
            "--- PASS: TestSubtract (0.00s)\n",
            "=== RUN   TestDivide\n",
            "    math_test.go:15: expected 2, got 0\n",
            "--- FAIL: TestDivide (0.05s)\n",
            "FAIL\n",
            "FAIL\tgithub.com/user/mathpkg\t0.052s\n",
        );
        let result = parse(stdout, 1);

        assert_eq!(result.total_tests(), 3);
        assert_eq!(result.total_passed(), 2);
        assert_eq!(result.total_failed(), 1);
        assert!(!result.is_success());

        // The log line printed by the failing test must end up in its error.
        let failed = result.suites[0].failures();
        assert_eq!(failed.len(), 1);
        let error = failed[0].error.as_ref();
        assert!(error.is_some());
        assert!(error.unwrap().message.contains("expected 2, got 0"));
    }

    /// Every test passes.
    #[test]
    fn parse_go_all_pass() {
        let stdout = concat!(
            "\n",
            "=== RUN   TestHello\n",
            "--- PASS: TestHello (0.00s)\n",
            "=== RUN   TestWorld\n",
            "--- PASS: TestWorld (0.01s)\n",
            "ok  \tgithub.com/user/pkg\t0.015s\n",
        );
        let result = parse(stdout, 0);

        assert_eq!(result.total_passed(), 2);
        assert_eq!(result.total_failed(), 0);
        assert!(result.is_success());
    }

    /// A skipped test is counted as skipped and does not fail the run.
    #[test]
    fn parse_go_skipped() {
        let stdout = concat!(
            "\n",
            "=== RUN   TestFoo\n",
            "--- SKIP: TestFoo (0.00s)\n",
            "ok  \tgithub.com/user/pkg\t0.001s\n",
        );
        let result = parse(stdout, 0);

        assert_eq!(result.total_skipped(), 1);
        assert!(result.is_success());
    }

    /// Results from two packages are both counted.
    #[test]
    fn parse_go_multiple_packages() {
        let stdout = concat!(
            "\n",
            "=== RUN   TestA\n",
            "--- PASS: TestA (0.00s)\n",
            "ok  \tgithub.com/user/pkg/a\t0.005s\n",
            "=== RUN   TestB\n",
            "--- FAIL: TestB (0.02s)\n",
            "FAIL\tgithub.com/user/pkg/b\t0.025s\n",
        );
        let result = parse(stdout, 1);

        assert_eq!(result.total_tests(), 2);
        assert_eq!(result.total_passed(), 1);
        assert_eq!(result.total_failed(), 1);
    }

    /// The elapsed time on a package summary line is parsed as seconds.
    #[test]
    fn parse_go_duration() {
        let parsed = parse_go_total_duration("ok  \tgithub.com/user/pkg\t1.234s\n");
        assert_eq!(parsed.unwrap(), Duration::from_millis(1234));
    }

    /// `go.mod` plus a top-level `*_test.go` file is detected as a Go project.
    #[test]
    fn detect_go_project() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path();
        std::fs::write(root.join("go.mod"), "module example.com/test\n").unwrap();
        std::fs::write(root.join("main_test.go"), "package main\n").unwrap();

        let det = GoAdapter::new().detect(root).unwrap();
        assert_eq!(det.framework, "go test");
    }

    /// An empty directory is not a Go project.
    #[test]
    fn detect_no_go() {
        let dir = tempfile::tempdir().unwrap();
        assert!(GoAdapter::new().detect(dir.path()).is_none());
    }

    /// Empty output with exit code 0 yields one synthetic passing test.
    #[test]
    fn parse_go_empty_output() {
        let result = parse("", 0);

        assert_eq!(result.total_tests(), 1);
        assert!(result.is_success());
    }

    /// Parent tests and their subtests are both recorded.
    #[test]
    fn parse_go_subtests() {
        let stdout = concat!(
            "\n",
            "=== RUN   TestMath\n",
            "=== RUN   TestMath/Add\n",
            "--- PASS: TestMath/Add (0.00s)\n",
            "=== RUN   TestMath/Subtract\n",
            "--- PASS: TestMath/Subtract (0.00s)\n",
            "--- PASS: TestMath (0.00s)\n",
            "ok  \tgithub.com/user/pkg\t0.003s\n",
        );
        let result = parse(stdout, 0);

        assert!(result.total_passed() >= 2);
        assert!(result.is_success());
    }

    /// A panic after a failure marker still counts as exactly one failed test.
    #[test]
    fn parse_go_panic_output() {
        let stdout = concat!(
            "\n",
            "=== RUN   TestCrash\n",
            "--- FAIL: TestCrash (0.00s)\n",
            "panic: runtime error: index out of range [recovered]\n",
            "FAIL\tgithub.com/user/pkg\t0.001s\n",
        );
        let result = parse(stdout, 1);

        assert_eq!(result.total_failed(), 1);
        assert!(!result.is_success());
    }

    /// A package without test files falls back to the synthetic passing suite.
    #[test]
    fn parse_go_no_test_files() {
        let result = parse("?   \tgithub.com/user/pkg\t[no test files]\n", 0);

        assert!(result.is_success());
    }

    /// `go.mod` alone is not enough: at least one `*_test.go` file is required.
    #[test]
    fn detect_go_needs_test_files() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path();
        std::fs::write(root.join("go.mod"), "module example.com/test\n").unwrap();
        std::fs::write(root.join("main.go"), "package main\n").unwrap();

        assert!(GoAdapter::new().detect(root).is_none());
    }
}