//! testx/adapters/python.rs — Python test adapter.
//!
//! Detects Python projects, chooses pytest / Django / unittest as the test
//! framework, builds the test command, and parses pytest output.
1use std::path::Path;
2use std::process::Command;
3use std::time::Duration;
4
5use anyhow::Result;
6
7use super::util::{combined_output, duration_from_secs_safe};
8use super::{
9    ConfidenceScore, DetectionResult, TestAdapter, TestCase, TestRunResult, TestStatus, TestSuite,
10};
11
/// Adapter that detects Python projects and runs their tests via pytest,
/// Django's `manage.py test`, or the stdlib `unittest` module.
///
/// Stateless unit struct: all behavior lives in the associated functions
/// and the `TestAdapter` trait impl.
pub struct PythonAdapter;

impl Default for PythonAdapter {
    fn default() -> Self {
        Self::new()
    }
}
19
20impl PythonAdapter {
21    pub fn new() -> Self {
22        Self
23    }
24
25    /// Check if pytest is the test framework
26    fn is_pytest(project_dir: &Path) -> bool {
27        // Check for pytest-specific files/configs
28        let markers = ["pytest.ini", ".pytest_cache", "conftest.py"];
29        for m in &markers {
30            if project_dir.join(m).exists() {
31                return true;
32            }
33        }
34
35        // Check pyproject.toml for pytest config
36        let pyproject = project_dir.join("pyproject.toml");
37        if pyproject.exists()
38            && let Ok(content) = std::fs::read_to_string(&pyproject)
39            && (content.contains("[tool.pytest") || content.contains("pytest"))
40        {
41            return true;
42        }
43
44        // Check setup.cfg
45        let setup_cfg = project_dir.join("setup.cfg");
46        if setup_cfg.exists()
47            && let Ok(content) = std::fs::read_to_string(&setup_cfg)
48            && content.contains("[tool:pytest]")
49        {
50            return true;
51        }
52
53        // Check tox.ini
54        let tox_ini = project_dir.join("tox.ini");
55        if tox_ini.exists()
56            && let Ok(content) = std::fs::read_to_string(&tox_ini)
57            && content.contains("[pytest]")
58        {
59            return true;
60        }
61
62        false
63    }
64
65    /// Check if Django is present
66    fn is_django(project_dir: &Path) -> bool {
67        project_dir.join("manage.py").exists()
68    }
69
70    /// Detect the Python package manager to use as a prefix
71    fn detect_runner_prefix(project_dir: &Path) -> Option<Vec<String>> {
72        if project_dir.join("uv.lock").exists() || {
73            let pyproject = project_dir.join("pyproject.toml");
74            pyproject.exists()
75                && std::fs::read_to_string(&pyproject)
76                    .map(|c| c.contains("[tool.uv]"))
77                    .unwrap_or(false)
78        } {
79            return Some(vec!["uv".into(), "run".into()]);
80        }
81        if project_dir.join("poetry.lock").exists() {
82            return Some(vec!["poetry".into(), "run".into()]);
83        }
84        if project_dir.join("pdm.lock").exists() {
85            return Some(vec!["pdm".into(), "run".into()]);
86        }
87        None
88    }
89
90    /// Find pytest binary inside a local virtualenv (.venv, venv, .env, env).
91    fn find_venv_pytest(project_dir: &Path) -> Option<std::path::PathBuf> {
92        for venv_dir in [".venv", "venv", ".env", "env"] {
93            let venv_pytest = project_dir.join(venv_dir).join("bin").join("pytest");
94            if venv_pytest.exists() {
95                return Some(venv_pytest);
96            }
97        }
98        None
99    }
100}
101
impl TestAdapter for PythonAdapter {
    /// Human-readable language name for this adapter.
    fn name(&self) -> &str {
        "Python"
    }

    /// Return `Some(tool_name)` when no usable runner is found on PATH,
    /// `None` when at least one candidate runner exists.
    fn check_runner(&self) -> Option<String> {
        // Any one of these being on PATH is enough to run tests; which one
        // is actually used is decided later in `build_command`.
        // If uv/poetry/pdm is the runner, check that instead
        for runner in ["uv", "poetry", "pdm", "pytest", "python"] {
            if which::which(runner).is_ok() {
                return None;
            }
        }
        // Nothing found: report "python" as the missing tool.
        Some("python".into())
    }

    /// Detect whether `project_dir` is a Python project and, if so, which
    /// test framework (pytest / django / unittest) it uses, with a
    /// confidence score built from several file-system signals.
    fn detect(&self, project_dir: &Path) -> Option<DetectionResult> {
        // Require at least one conventional Python project marker file.
        let has_python_files = project_dir.join("pyproject.toml").exists()
            || project_dir.join("setup.py").exists()
            || project_dir.join("setup.cfg").exists()
            || project_dir.join("requirements.txt").exists()
            || project_dir.join("Pipfile").exists();

        if !has_python_files {
            return None;
        }

        // pytest takes precedence over django; unittest is the fallback.
        let framework = if Self::is_pytest(project_dir) {
            "pytest"
        } else if Self::is_django(project_dir) {
            "django"
        } else {
            "unittest"
        };

        let is_pytest = framework == "pytest";
        let is_django = framework == "django";
        let has_test_dir = project_dir.join("tests").is_dir() || project_dir.join("test").is_dir();
        let has_lock = project_dir.join("poetry.lock").exists()
            || project_dir.join("Pipfile.lock").exists()
            || project_dir.join("uv.lock").exists()
            || project_dir.join("pdm.lock").exists();

        // Weighted signals on top of a 0.50 base; weights sum to at most 1.0.
        let confidence = ConfidenceScore::base(0.50)
            .signal(0.20, is_pytest)
            .signal(0.15, is_django)
            .signal(0.10, has_test_dir)
            .signal(0.07, has_lock)
            .signal(
                0.07,
                which::which("pytest").is_ok() || which::which("python").is_ok(),
            )
            .finish();

        Some(DetectionResult {
            language: "Python".into(),
            framework: framework.into(),
            confidence,
        })
    }

    /// Build the test `Command` for the project: package-manager prefix
    /// (uv/poetry/pdm) if detected, then pytest / Django / unittest, with
    /// `extra_args` appended and the working directory set to `project_dir`.
    fn build_command(&self, project_dir: &Path, extra_args: &[String]) -> Result<Command> {
        let prefix = Self::detect_runner_prefix(project_dir);
        let is_pytest = Self::is_pytest(project_dir);
        let is_django = Self::is_django(project_dir);

        let mut cmd;

        if let Some(prefix_args) = &prefix {
            // e.g. `uv run pytest` / `poetry run python -m unittest`.
            cmd = Command::new(&prefix_args[0]);
            for arg in &prefix_args[1..] {
                cmd.arg(arg);
            }
            if is_pytest {
                cmd.arg("pytest");
            } else if is_django {
                cmd.arg("python").arg("-m").arg("django").arg("test");
            } else {
                cmd.arg("python").arg("-m").arg("unittest");
            }
        } else if is_pytest {
            // Check for pytest inside a local virtualenv first
            if let Some(venv_pytest) = Self::find_venv_pytest(project_dir) {
                cmd = Command::new(venv_pytest);
            } else {
                cmd = Command::new("pytest");
            }
        } else if is_django {
            cmd = Command::new("python");
            cmd.arg("manage.py").arg("test");
        } else {
            cmd = Command::new("python");
            cmd.arg("-m").arg("unittest");
        }

        // Add verbose flag for better output parsing (pytest). Only when the
        // caller passed no args, so user-supplied flags are never overridden.
        if is_pytest && extra_args.is_empty() {
            cmd.arg("-v");
        }

        for arg in extra_args {
            cmd.arg(arg);
        }

        cmd.current_dir(project_dir);
        Ok(cmd)
    }

    /// Map a name filter pattern to pytest/unittest's `-k PATTERN` flag.
    fn filter_args(&self, pattern: &str) -> Vec<String> {
        vec!["-k".to_string(), pattern.to_string()]
    }

    /// Parse combined stdout/stderr of a pytest run into suites of test
    /// cases, grouping by source file and attaching failure messages from
    /// the FAILURES section. Falls back to the summary line (or, lacking
    /// that, the exit code) when no per-test lines are present.
    fn parse_output(&self, stdout: &str, stderr: &str, exit_code: i32) -> TestRunResult {
        let combined = combined_output(stdout, stderr);
        // Map of test name -> error message, extracted up front.
        let failure_messages = parse_pytest_failures(&combined);
        let mut suites: Vec<TestSuite> = Vec::new();
        let mut current_suite_name = String::from("tests");
        let mut tests: Vec<TestCase> = Vec::new();

        for line in combined.lines() {
            let trimmed = line.trim();

            // pytest verbose output: "test_file.py::TestClass::test_name PASSED"
            // or: "test_file.py::test_name PASSED"
            if let Some((test_path, status_str)) = parse_pytest_line(trimmed) {
                // First `::` segment is the file (suite), last is the test name.
                let parts: Vec<&str> = test_path.split("::").collect();
                let suite_name = parts.first().unwrap_or(&"tests").to_string();
                let test_name = parts.last().unwrap_or(&"unknown").to_string();

                // If suite changed, flush current tests
                if suite_name != current_suite_name && !tests.is_empty() {
                    suites.push(TestSuite {
                        name: current_suite_name.clone(),
                        tests: std::mem::take(&mut tests),
                    });
                }
                current_suite_name = suite_name;

                // XFAIL/XPASS (expected failures) are reported as skipped;
                // ERROR (e.g. fixture/setup failure) counts as failed.
                let status = match status_str.to_uppercase().as_str() {
                    "PASSED" => TestStatus::Passed,
                    "FAILED" => TestStatus::Failed,
                    "SKIPPED" | "XFAIL" | "XPASS" => TestStatus::Skipped,
                    "ERROR" => TestStatus::Failed,
                    _ => TestStatus::Failed,
                };

                let error = if status == TestStatus::Failed {
                    // Try full path first, then just test name
                    failure_messages
                        .get(&test_path)
                        .or_else(|| failure_messages.get(&test_name))
                        .map(|msg| super::TestError {
                            message: msg.clone(),
                            location: None,
                        })
                } else {
                    None
                };

                // Per-test durations are not present in this output format.
                tests.push(TestCase {
                    name: test_name,
                    status,
                    duration: Duration::from_millis(0),
                    error,
                });
            }
        }

        // Flush remaining tests
        if !tests.is_empty() {
            suites.push(TestSuite {
                name: current_suite_name,
                tests,
            });
        }

        // If we couldn't parse any individual tests, create a summary suite from the summary line
        if suites.is_empty() {
            suites.push(parse_pytest_summary(&combined, exit_code));
        }

        // Try to parse total duration from pytest summary
        let duration = parse_pytest_duration(&combined).unwrap_or(Duration::from_secs(0));

        TestRunResult {
            suites,
            duration,
            raw_exit_code: exit_code,
        }
    }
}
292
/// Parse a pytest verbose output line like "tests/test_foo.py::test_bar PASSED".
///
/// Returns `(test_path, status)` when the line contains a `::`-qualified
/// test id followed by a whole-word status token (optionally trailed by a
/// progress marker like "[ 50%]"); `None` for any other line.
fn parse_pytest_line(line: &str) -> Option<(String, String)> {
    // Match patterns like: "path::test_name PASSED  [ 50%]"
    let statuses = ["PASSED", "FAILED", "SKIPPED", "ERROR", "XFAIL", "XPASS"];
    for status in &statuses {
        if let Some(idx) = line.rfind(status) {
            // Ensure the status word is preceded by whitespace (not part of test name)
            if idx > 0 && !line.as_bytes()[idx - 1].is_ascii_whitespace() {
                continue;
            }
            // Ensure the status also ENDS at a word boundary: end of line,
            // whitespace, or the "[ 50%]" progress marker. Without this a
            // token like "PASSEDX" would be misread as PASSED.
            let after = &line[idx + status.len()..];
            if !after.is_empty() && !after.starts_with(|c: char| c.is_whitespace() || c == '[') {
                continue;
            }
            let path = line[..idx].trim().to_string();
            if path.contains("::") {
                return Some((path, status.to_string()));
            }
        }
    }
    None
}
311
312/// Parse pytest summary line like "=== 5 passed, 2 failed in 0.32s ==="
313fn parse_pytest_summary(output: &str, exit_code: i32) -> TestSuite {
314    let mut passed = 0usize;
315    let mut failed = 0usize;
316    let mut skipped = 0usize;
317
318    for line in output.lines() {
319        let trimmed = line.trim().trim_matches('=').trim();
320        if trimmed.contains("passed") || trimmed.contains("failed") || trimmed.contains("error") {
321            // Parse "5 passed", "2 failed", etc.
322            for part in trimmed.split(',') {
323                let part = part.trim();
324                if let Some(n) = part
325                    .split_whitespace()
326                    .next()
327                    .and_then(|s| s.parse::<usize>().ok())
328                {
329                    if part.contains("passed") {
330                        passed = n;
331                    } else if part.contains("failed") || part.contains("error") {
332                        failed = n;
333                    } else if part.contains("skipped") {
334                        skipped = n;
335                    }
336                }
337            }
338        }
339    }
340
341    let mut tests = Vec::new();
342    for i in 0..passed {
343        tests.push(TestCase {
344            name: format!("test_{}", i + 1),
345            status: TestStatus::Passed,
346            duration: Duration::from_millis(0),
347            error: None,
348        });
349    }
350    for i in 0..failed {
351        tests.push(TestCase {
352            name: format!("failed_test_{}", i + 1),
353            status: TestStatus::Failed,
354            duration: Duration::from_millis(0),
355            error: None,
356        });
357    }
358    for i in 0..skipped {
359        tests.push(TestCase {
360            name: format!("skipped_test_{}", i + 1),
361            status: TestStatus::Skipped,
362            duration: Duration::from_millis(0),
363            error: None,
364        });
365    }
366
367    // If we still got nothing, infer from exit code
368    if tests.is_empty() {
369        tests.push(TestCase {
370            name: "test_suite".into(),
371            status: if exit_code == 0 {
372                TestStatus::Passed
373            } else {
374                TestStatus::Failed
375            },
376            duration: Duration::from_millis(0),
377            error: None,
378        });
379    }
380
381    TestSuite {
382        name: "tests".into(),
383        tests,
384    }
385}
386
/// Parse pytest FAILURES section to extract error messages per test.
/// Pytest output looks like:
/// ```text
/// =========================== FAILURES ===========================
/// __________________ test_multiply __________________
///
///     def test_multiply():
/// >       assert multiply(2, 3) == 7
/// E       assert 6 == 7
/// E       +  where 6 = multiply(2, 3)
///
/// tests/test_math.py:10: AssertionError
/// =========================== short test summary info ===========================
/// ```
///
/// Returns a map from the bare test name (as printed in the `___ name ___`
/// header) to a message joined from the `E ` assertion lines, with the
/// `file:line` location appended in parentheses when one was found.
fn parse_pytest_failures(output: &str) -> std::collections::HashMap<String, String> {
    let mut failures = std::collections::HashMap::new();
    let lines: Vec<&str> = output.lines().collect();
    let mut in_failures = false;

    // Manual index loop: the inner per-test scan advances `i` by a variable
    // number of lines, which a plain `for` over an iterator cannot express.
    let mut i = 0;
    while i < lines.len() {
        let trimmed = lines[i].trim();

        // Enter FAILURES section
        if trimmed.contains("FAILURES") && trimmed.starts_with('=') {
            in_failures = true;
            i += 1;
            continue;
        }

        // Exit FAILURES section
        if in_failures
            && trimmed.starts_with('=')
            && (trimmed.contains("short test summary")
                || trimmed.contains("passed")
                || trimmed.contains("failed")
                || trimmed.contains("error"))
        {
            break;
        }

        // Match test header: "__________________ test_name __________________"
        if in_failures && trimmed.starts_with('_') && trimmed.ends_with('_') {
            let test_name = trimmed.trim_matches('_').trim().to_string();
            if !test_name.is_empty() {
                let mut error_lines = Vec::new();
                let mut location = None;
                i += 1;
                // Consume this test's body up to the next header/boundary.
                while i < lines.len() {
                    let l = lines[i].trim();
                    // Next test header or section boundary
                    // (len > 5 avoids matching short identifiers like `__x__`).
                    if (l.starts_with('_') && l.ends_with('_') && l.len() > 5)
                        || (l.starts_with('=') && l.len() > 5)
                    {
                        break;
                    }
                    // Assertion lines start with "E"
                    if l.starts_with("E ") || l.starts_with("E\t") {
                        error_lines.push(l[1..].trim().to_string());
                    }
                    // Location line like "tests/test_math.py:10: AssertionError"
                    if l.contains(".py:")
                        && l.contains(':')
                        && !l.starts_with('>')
                        && !l.starts_with("E")
                    {
                        let parts: Vec<&str> = l.splitn(3, ':').collect();
                        if parts.len() >= 2 {
                            location = Some(format!("{}:{}", parts[0].trim(), parts[1].trim()));
                        }
                    }
                    i += 1;
                }
                if !error_lines.is_empty() {
                    let mut msg = error_lines.join(" | ");
                    if let Some(loc) = location {
                        msg = format!("{} ({})", msg, loc);
                    }
                    failures.insert(test_name, msg);
                }
                // `i` already points at the boundary line; re-examine it
                // in the outer loop rather than skipping past it.
                continue;
            }
        }
        i += 1;
    }
    failures
}
474
475/// Parse duration from pytest summary like "in 0.32s"
476fn parse_pytest_duration(output: &str) -> Option<Duration> {
477    for line in output.lines() {
478        if let Some(idx) = line.find(" in ") {
479            let after = &line[idx + 4..];
480            let num_str: String = after
481                .chars()
482                .take_while(|c| c.is_ascii_digit() || *c == '.')
483                .collect();
484            if let Ok(secs) = num_str.parse::<f64>() {
485                return Some(duration_from_secs_safe(secs));
486            }
487        }
488    }
489    None
490}
491
// Unit tests for output parsing and project detection. Detection tests use
// `tempfile` to build throwaway project directories on disk.
#[cfg(test)]
mod tests {
    use super::*;

    // Full verbose run: per-test lines, a FAILURES section, and a summary.
    #[test]
    fn parse_pytest_verbose_output() {
        let stdout = r#"
============================= test session starts ==============================
collected 4 items

tests/test_math.py::test_add PASSED                                      [ 25%]
tests/test_math.py::test_subtract PASSED                                 [ 50%]
tests/test_math.py::test_multiply FAILED                                 [ 75%]
tests/test_string.py::test_upper PASSED                                  [100%]

=================================== FAILURES ===================================
________________________________ test_multiply _________________________________

    def test_multiply():
>       assert multiply(2, 3) == 7
E       assert 6 == 7
E         +  where 6 = multiply(2, 3)

tests/test_math.py:10: AssertionError
=========================== short test summary info ============================
FAILED tests/test_math.py::test_multiply - assert 6 == 7
============================== 3 passed, 1 failed in 0.12s =====================
"#;
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 1);

        assert_eq!(result.total_tests(), 4);
        assert_eq!(result.total_passed(), 3);
        assert_eq!(result.total_failed(), 1);
        assert!(!result.is_success());
        assert_eq!(result.suites.len(), 2); // two test files
        assert_eq!(result.duration, Duration::from_millis(120));

        // Verify error message was captured
        let failed: Vec<_> = result.suites.iter().flat_map(|s| s.failures()).collect();
        assert_eq!(failed.len(), 1);
        assert!(failed[0].error.is_some());
        assert!(
            failed[0]
                .error
                .as_ref()
                .unwrap()
                .message
                .contains("assert 6 == 7")
        );
    }

    // Summary-only output (no per-test lines) with all tests passing.
    #[test]
    fn parse_pytest_all_pass() {
        let stdout = "========================= 5 passed in 0.32s =========================\n";
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 0);

        assert_eq!(result.total_tests(), 5);
        assert_eq!(result.total_passed(), 5);
        assert!(result.is_success());
    }

    // SKIPPED lines must land in the skipped count, not pass/fail.
    #[test]
    fn parse_pytest_with_skipped() {
        let stdout = r#"
tests/test_foo.py::test_a PASSED
tests/test_foo.py::test_b SKIPPED
tests/test_foo.py::test_c PASSED

========================= 2 passed, 1 skipped in 0.05s =========================
"#;
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 0);

        assert_eq!(result.total_passed(), 2);
        assert_eq!(result.total_skipped(), 1);
        assert!(result.is_success());
    }

    // Class-based ids have three `::` segments; last segment is the name.
    #[test]
    fn parse_pytest_class_based() {
        let stdout = r#"
tests/test_calc.py::TestCalculator::test_add PASSED
tests/test_calc.py::TestCalculator::test_div FAILED
"#;
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 1);

        assert_eq!(result.total_tests(), 2);
        assert_eq!(result.total_passed(), 1);
        assert_eq!(result.total_failed(), 1);
    }

    // No per-test lines at all: counts come from the summary line.
    #[test]
    fn parse_pytest_summary_only() {
        let stdout = "===== 10 passed, 2 failed, 3 skipped in 1.50s =====\n";
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 1);

        assert_eq!(result.total_passed(), 10);
        assert_eq!(result.total_failed(), 2);
        assert_eq!(result.total_skipped(), 3);
        assert_eq!(result.total_tests(), 15);
    }

    #[test]
    fn parse_pytest_duration_extraction() {
        assert_eq!(
            parse_pytest_duration("=== 1 passed in 2.34s ==="),
            Some(Duration::from_millis(2340))
        );
        assert_eq!(parse_pytest_duration("no duration here"), None);
    }

    #[test]
    fn parse_pytest_line_function() {
        assert_eq!(
            parse_pytest_line("tests/test_foo.py::test_bar PASSED                    [ 50%]"),
            Some(("tests/test_foo.py::test_bar".into(), "PASSED".into()))
        );
        assert_eq!(parse_pytest_line("collected 5 items"), None);
        assert_eq!(parse_pytest_line(""), None);
    }

    // A `[tool.pytest...]` section in pyproject.toml marks a pytest project.
    #[test]
    fn detect_in_pytest_project() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("pyproject.toml"),
            "[tool.pytest.ini_options]\n",
        )
        .unwrap();
        let adapter = PythonAdapter::new();
        let det = adapter.detect(dir.path()).unwrap();
        assert_eq!(det.framework, "pytest");
        assert!(det.confidence > 0.65);
    }

    // No Python marker files at all -> detection returns None.
    #[test]
    fn detect_no_python() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("main.go"), "package main\n").unwrap();
        let adapter = PythonAdapter::new();
        assert!(adapter.detect(dir.path()).is_none());
    }

    // manage.py alongside requirements.txt -> Django framework.
    #[test]
    fn detect_django_project() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("requirements.txt"), "django\n").unwrap();
        std::fs::write(dir.path().join("manage.py"), "#!/usr/bin/env python\n").unwrap();
        let adapter = PythonAdapter::new();
        let det = adapter.detect(dir.path()).unwrap();
        assert_eq!(det.framework, "django");
    }

    // Empty output with exit code 0: a single synthetic passing case.
    #[test]
    fn parse_pytest_empty_output() {
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output("", "", 0);

        assert_eq!(result.total_tests(), 1);
        assert!(result.is_success());
    }

    #[test]
    fn parse_pytest_xfail_xpass() {
        let stdout = r#"
tests/test_edge.py::test_expected_fail XFAIL
tests/test_edge.py::test_unexpected_pass XPASS

========================= 2 xfailed in 0.05s =========================
"#;
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 0);

        // XFAIL and XPASS should be counted as skipped
        assert_eq!(result.total_skipped(), 2);
        assert!(result.is_success());
    }

    // Parametrized ids like "test_add[1-2-3]" are distinct test cases.
    #[test]
    fn parse_pytest_parametrized() {
        let stdout = r#"
tests/test_math.py::test_add[1-2-3] PASSED
tests/test_math.py::test_add[0-0-0] PASSED
tests/test_math.py::test_add[-1-1-0] PASSED

========================= 3 passed in 0.01s =========================
"#;
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 0);

        assert_eq!(result.total_tests(), 3);
        assert_eq!(result.total_passed(), 3);
    }

    // ERROR status (e.g. fixture failure) must count as a failure.
    #[test]
    fn parse_pytest_error_status() {
        let stdout = r#"
tests/test_math.py::test_setup ERROR

========================= 1 error in 0.10s =========================
"#;
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 1);

        assert_eq!(result.total_failed(), 1);
        assert!(!result.is_success());
    }

    // A Pipfile alone is enough to identify the project as Python.
    #[test]
    fn detect_pipfile_project() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("Pipfile"), "[packages]\n").unwrap();
        let adapter = PythonAdapter::new();
        let det = adapter.detect(dir.path()).unwrap();
        assert_eq!(det.language, "Python");
    }

    #[test]
    fn detect_unittest_fallback() {
        // Has Python markers but no pytest/django markers
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("setup.py"),
            "from setuptools import setup\n",
        )
        .unwrap();
        let adapter = PythonAdapter::new();
        let det = adapter.detect(dir.path()).unwrap();
        assert_eq!(det.framework, "unittest");
        assert!(det.confidence < 0.6);
    }
}