1use std::collections::BTreeSet;
15use std::path::{Path, PathBuf};
16
17use anyhow::{anyhow, bail, Context, Result};
18use oxc::allocator::Allocator;
19use oxc::ast::ast::{Argument, CallExpression, Expression, ImportDeclaration, ImportOrExportKind};
20use oxc::ast_visit::{walk, Visit};
21use oxc::parser::Parser;
22use oxc::span::{SourceType, Span};
23
24use crate::lint::Violation;
25
/// Where the module named by an import or mock specifier comes from, as decided
/// by `classify`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Origin {
    /// Project code: a relative (`./`, `../`) or absolute (`/`) path.
    FirstParty,
    /// A Node.js built-in, bare (`fs`, `fs/promises`) or `node:`-prefixed.
    Builtin,
    /// Anything else, i.e. a package specifier such as `lodash` or `@scope/pkg`.
    ThirdParty,
}
36
37pub fn classify(specifier: &str) -> Origin {
46 if specifier.starts_with('.') || specifier.starts_with('/') {
47 return Origin::FirstParty;
48 }
49 if specifier.starts_with("node:") || is_node_builtin(specifier) {
50 return Origin::Builtin;
51 }
52 Origin::ThirdParty
53}
54
55fn is_node_builtin(specifier: &str) -> bool {
58 let head = specifier.split('/').next().unwrap_or(specifier);
59 NODE_BUILTINS.contains(&head)
60}
61
/// Node.js built-in modules that may be imported without the `node:` prefix.
///
/// Entries are matched against a specifier's first path segment (see
/// `is_node_builtin`), so subpath modules such as `fs/promises` need no entry of
/// their own. Modules reachable only through the `node:` scheme (e.g. `node:test`)
/// are intentionally absent: a bare `test` is an ordinary package name.
const NODE_BUILTINS: &[&str] = &[
    "assert", "async_hooks", "buffer", "child_process", "cluster", "console", "constants",
    "crypto", "dgram", "diagnostics_channel", "dns", "domain", "events", "fs",
    "http", "http2", "https", "inspector", "module", "net", "os",
    "path", "perf_hooks", "process", "punycode", "querystring", "readline", "repl",
    "stream", "string_decoder", "sys", "timers", "tls", "trace_events", "tty",
    "url", "util", "v8", "vm", "wasi", "worker_threads", "zlib",
];
109
110pub fn find_integration_violations(root: impl AsRef<Path>) -> Result<Vec<Violation>> {
117 let root = root.as_ref();
118 let mut files = Vec::new();
119 collect_ts_test_files(root, &mut files)?;
120 files.sort();
121
122 let mut violations = Vec::new();
123 for file in &files {
124 let source = std::fs::read_to_string(file)
125 .with_context(|| format!("reading test file `{}`", file.display()))?;
126 violations.extend(integration_violations_in(file, &source)?);
127 }
128
129 violations.sort_by(|a, b| a.file.cmp(&b.file).then(a.line.cmp(&b.line)));
130 Ok(violations)
131}
132
133pub fn find_unit_violations(root: impl AsRef<Path>) -> Result<Vec<Violation>> {
141 let root = root.as_ref();
142 let mut files = Vec::new();
143 collect_ts_test_files(root, &mut files)?;
144 files.sort();
145
146 let mut violations = Vec::new();
147 for file in &files {
148 let source = std::fs::read_to_string(file)
149 .with_context(|| format!("reading test file `{}`", file.display()))?;
150 violations.extend(unit_violations_in(file, &source)?);
151 }
152
153 violations.sort_by(|a, b| a.file.cmp(&b.file).then(a.line.cmp(&b.line)));
154 Ok(violations)
155}
156
157fn unit_violations_in(file: &Path, source: &str) -> Result<Vec<Violation>> {
161 let allocator = Allocator::default();
162 let source_type = SourceType::from_path(file).map_err(|err| {
163 anyhow!(
164 "unsupported TypeScript extension `{}`: {err}",
165 file.display()
166 )
167 })?;
168 let ret = Parser::new(&allocator, source, source_type).parse();
169 if ret.panicked || !ret.diagnostics.is_empty() {
170 let detail = ret
171 .diagnostics
172 .iter()
173 .map(|d| d.to_string())
174 .collect::<Vec<_>>()
175 .join("; ");
176 bail!("parsing `{}` failed: {detail}", file.display());
177 }
178
179 let mut collector = UnitCollector {
180 source,
181 imports: Vec::new(),
182 mocked: BTreeSet::new(),
183 untyped: Vec::new(),
184 };
185 collector.visit_program(&ret.program);
186
187 let unit = unit_under_test_specifier(file);
188 let mut violations = Vec::new();
189 for (spec, line) in &collector.imports {
190 if is_unit_under_test(spec, &unit)
191 || is_test_runner(spec)
192 || collector.mocked.contains(spec)
193 {
194 continue;
195 }
196 violations.push(Violation {
197 file: file.to_path_buf(),
198 line: *line,
199 rule: "unmocked-collaborator",
200 message: format!(
201 "unit test imports `{spec}` without mocking it — a unit test isolates the \
202 unit under test, so every collaborator must be `vi.mock()`-ed"
203 ),
204 });
205 }
206 for (spec, line) in &collector.untyped {
207 violations.push(Violation {
208 file: file.to_path_buf(),
209 line: *line,
210 rule: "untyped-mock",
211 message: format!(
212 "`vi.mock('{spec}', …)` has an untyped factory — anchor it to the real module \
213 with `vi.importActual<typeof import('{spec}')>()` so the double can't drift \
214 from the source"
215 ),
216 });
217 }
218 violations.sort_by_key(|v| v.line);
219 Ok(violations)
220}
221
222struct UnitCollector<'s> {
225 source: &'s str,
226 imports: Vec<(String, usize)>,
227 mocked: BTreeSet<String>,
228 untyped: Vec<(String, usize)>,
229}
230
231impl<'a> Visit<'a> for UnitCollector<'_> {
232 fn visit_import_declaration(&mut self, decl: &ImportDeclaration<'a>) {
233 if matches!(decl.import_kind, ImportOrExportKind::Type) {
235 return;
236 }
237 self.imports.push((
238 decl.source.value.to_string(),
239 line_of(self.source, decl.span.start),
240 ));
241 }
242
243 fn visit_call_expression(&mut self, call: &CallExpression<'a>) {
244 if let Some(spec) = vi_mock_target(call) {
245 if let Some(factory) = call.arguments.get(1) {
252 if is_factory(factory) && !factory_is_typed(factory) {
253 self.untyped
254 .push((spec.clone(), line_of(self.source, call.span.start)));
255 }
256 }
257 self.mocked.insert(spec);
258 }
259 walk::walk_call_expression(self, call);
260 }
261}
262
/// Returns the relative specifier a test file is expected to use for the module
/// it tests: the file name up to its first `.test.` marker, prefixed with `./`.
///
/// `pkg/widget.test.ts` yields `./widget`. A name without `.test.` is kept whole,
/// and a path with no UTF-8 file name yields `./`.
fn unit_under_test_specifier(file: &Path) -> String {
    let file_name = match file.file_name().and_then(|n| n.to_str()) {
        Some(n) => n,
        None => "",
    };
    let stem = file_name
        .split_once(".test.")
        .map_or(file_name, |(before, _after)| before);
    format!("./{stem}")
}
272
273fn is_unit_under_test(spec: &str, unit: &str) -> bool {
276 strip_module_ext(spec) == unit
277}
278
/// Removes one trailing JavaScript/TypeScript module extension from `spec`.
///
/// Only the extensions listed below are recognized. Any other specifier (no
/// extension, `.json`, a bare package name) is returned unchanged.
fn strip_module_ext(spec: &str) -> &str {
    const MODULE_EXTENSIONS: [&str; 8] =
        [".js", ".mjs", ".cjs", ".jsx", ".ts", ".mts", ".cts", ".tsx"];
    MODULE_EXTENSIONS
        .iter()
        .copied()
        .find_map(|ext| spec.strip_suffix(ext))
        .unwrap_or(spec)
}
288
/// Returns `true` for Vitest's own modules: `vitest`, any `vitest/…` subpath,
/// and any package in the `@vitest/` scope.
fn is_test_runner(spec: &str) -> bool {
    let is_vitest_or_subpath = matches!(
        spec.strip_prefix("vitest"),
        Some(rest) if rest.is_empty() || rest.starts_with('/')
    );
    is_vitest_or_subpath || spec.starts_with("@vitest/")
}
294
295fn is_factory(arg: &Argument) -> bool {
302 matches!(
303 arg.as_expression(),
304 Some(Expression::ArrowFunctionExpression(_) | Expression::FunctionExpression(_))
305 )
306}
307
308fn factory_is_typed(factory: &Argument) -> bool {
312 let mut finder = ImportActualFinder { typed: false };
313 finder.visit_argument(factory);
314 finder.typed
315}
316
317struct ImportActualFinder {
319 typed: bool,
320}
321
322impl<'a> Visit<'a> for ImportActualFinder {
323 fn visit_call_expression(&mut self, call: &CallExpression<'a>) {
324 if is_typed_import_actual(call) {
325 self.typed = true;
326 }
327 walk::walk_call_expression(self, call);
328 }
329}
330
331fn is_typed_import_actual(call: &CallExpression) -> bool {
334 let Expression::StaticMemberExpression(member) = &call.callee else {
335 return false;
336 };
337 let is_vi = matches!(&member.object, Expression::Identifier(id) if id.name == "vi");
338 is_vi && member.property.name.as_str() == "importActual" && call.type_arguments.is_some()
339}
340
341fn integration_violations_in(file: &Path, source: &str) -> Result<Vec<Violation>> {
345 let allocator = Allocator::default();
346 let source_type = SourceType::from_path(file).map_err(|err| {
347 anyhow!(
348 "unsupported TypeScript extension `{}`: {err}",
349 file.display()
350 )
351 })?;
352 let ret = Parser::new(&allocator, source, source_type).parse();
353 if ret.panicked || !ret.diagnostics.is_empty() {
354 let detail = ret
355 .diagnostics
356 .iter()
357 .map(|d| d.to_string())
358 .collect::<Vec<_>>()
359 .join("; ");
360 bail!("parsing `{}` failed: {detail}", file.display());
361 }
362
363 let mut visitor = MockVisitor {
364 file,
365 source,
366 violations: Vec::new(),
367 };
368 visitor.visit_program(&ret.program);
369 Ok(visitor.violations)
370}
371
372struct MockVisitor<'s> {
375 file: &'s Path,
376 source: &'s str,
377 violations: Vec<Violation>,
378}
379
380impl MockVisitor<'_> {
381 fn report(&mut self, span: Span, spec: &str) {
382 self.violations.push(Violation {
383 file: self.file.to_path_buf(),
384 line: line_of(self.source, span.start),
385 rule: "no-first-party-mock",
386 message: format!(
387 "integration test mocks first-party module `{spec}` — an integration test \
388 runs first-party code for real; only third-party packages and Node built-ins \
389 may be mocked"
390 ),
391 });
392 }
393}
394
395impl<'a> Visit<'a> for MockVisitor<'_> {
396 fn visit_call_expression(&mut self, call: &CallExpression<'a>) {
397 if let Some(spec) = vi_mock_target(call) {
398 if classify(&spec) == Origin::FirstParty {
399 self.report(call.span, &spec);
400 }
401 }
402 walk::walk_call_expression(self, call);
403 }
404}
405
406fn vi_mock_target(call: &CallExpression) -> Option<String> {
412 let Expression::StaticMemberExpression(member) = &call.callee else {
413 return None;
414 };
415 let is_vi = matches!(&member.object, Expression::Identifier(id) if id.name == "vi");
416 if !is_vi {
417 return None;
418 }
419 let method = member.property.name.as_str();
420 if method != "mock" && method != "doMock" {
421 return None;
422 }
423 match call.arguments.first() {
424 Some(Argument::StringLiteral(lit)) => Some(lit.value.to_string()),
425 _ => None,
426 }
427}
428
/// Converts a byte offset into `source` to a 1-based line number by counting the
/// `\n` bytes that precede it. Offsets past the end are treated as the end.
fn line_of(source: &str, offset: u32) -> usize {
    let preceding_newlines = source
        .bytes()
        .take(offset as usize)
        .filter(|&byte| byte == b'\n')
        .count();
    1 + preceding_newlines
}
438
439fn collect_ts_test_files(dir: &Path, out: &mut Vec<PathBuf>) -> Result<()> {
441 let entries =
442 std::fs::read_dir(dir).with_context(|| format!("reading directory `{}`", dir.display()))?;
443 for entry in entries {
444 let path = entry
445 .with_context(|| format!("reading an entry under `{}`", dir.display()))?
446 .path();
447 if path.is_dir() {
448 collect_ts_test_files(&path, out)?;
449 } else if is_ts_test_file(&path) {
450 out.push(path);
451 }
452 }
453 Ok(())
454}
455
/// Returns `true` when the file name ends in `.test.` followed by a TypeScript
/// extension (`ts`, `tsx`, `mts` or `cts`). Paths without a UTF-8 file name are
/// never test files.
fn is_ts_test_file(path: &Path) -> bool {
    const TEST_SUFFIXES: [&str; 4] = [".test.ts", ".test.tsx", ".test.mts", ".test.cts"];
    match path.file_name().and_then(|n| n.to_str()) {
        Some(name) => TEST_SUFFIXES.iter().any(|&suffix| name.ends_with(suffix)),
        None => false,
    }
}
467
#[cfg(test)]
mod tests {
    use super::*;

    // Runs the integration-test rule on an in-memory source. The file name only
    // drives `SourceType::from_path`. Panics if parsing fails, so each test must
    // use valid TypeScript.
    fn violations(name: &str, source: &str) -> Vec<Violation> {
        integration_violations_in(Path::new(name), source).expect("source should parse")
    }

    // Runs the unit-test rules on an in-memory source. Here the file name also
    // determines the unit under test (`widget.test.ts` -> `./widget`).
    fn unit_violations(name: &str, source: &str) -> Vec<Violation> {
        unit_violations_in(Path::new(name), source).expect("source should parse")
    }

    // ---- unit-test rules --------------------------------------------------

    #[test]
    fn unit_flags_unmocked_first_party_and_external() {
        let found = unit_violations(
            "widget.test.ts",
            "import { makeWidget } from './widget';\n\
             import { format } from './formatter';\n\
             import { chunk } from 'lodash';\n",
        );
        // `./widget` is the unit under test; the other two imports are unmocked.
        assert_eq!(found.len(), 2, "got: {found:?}");
        assert!(found.iter().all(|v| v.rule == "unmocked-collaborator"));
        assert!(found.iter().any(|v| v.message.contains("./formatter")));
        assert!(found.iter().any(|v| v.message.contains("lodash")));
    }

    #[test]
    fn unit_mocked_collaborator_is_clean() {
        let found = unit_violations(
            "widget.test.ts",
            "import { format } from './formatter';\nvi.mock('./formatter');\n",
        );
        assert!(found.is_empty(), "got: {found:?}");
    }

    #[test]
    fn unit_under_test_and_runner_are_not_flagged() {
        // The `.js` extension on the unit-under-test import must be tolerated.
        let found = unit_violations(
            "widget.test.ts",
            "import { vi } from 'vitest';\n\
             import { makeWidget } from './widget.js';\n",
        );
        assert!(found.is_empty(), "got: {found:?}");
    }

    #[test]
    fn unit_type_only_import_is_not_flagged() {
        let found = unit_violations(
            "widget.test.ts",
            "import type { Opts } from './opts';\nimport { x } from './x';\nvi.mock('./x');\n",
        );
        assert!(found.is_empty(), "got: {found:?}");
    }

    #[test]
    fn unit_under_test_specifier_strips_test_suffix() {
        assert_eq!(
            unit_under_test_specifier(Path::new("pkg/widget.test.ts")),
            "./widget"
        );
        assert_eq!(
            unit_under_test_specifier(Path::new("button.test.tsx")),
            "./button"
        );
    }

    #[test]
    fn strip_module_ext_drops_known_extensions_only() {
        assert_eq!(strip_module_ext("./widget.js"), "./widget");
        assert_eq!(strip_module_ext("./widget.mts"), "./widget");
        assert_eq!(strip_module_ext("./widget"), "./widget");
        assert_eq!(strip_module_ext("lodash"), "lodash");
    }

    #[test]
    fn recognizes_the_test_runner() {
        assert!(is_test_runner("vitest"));
        assert!(is_test_runner("vitest/config"));
        assert!(is_test_runner("@vitest/spy"));
        assert!(!is_test_runner("./vitest-helpers"));
        assert!(!is_test_runner("lodash"));
    }

    #[test]
    fn unit_flags_untyped_factory_mock() {
        let found = unit_violations(
            "widget.test.ts",
            "import { x } from './x';\nvi.mock('./x', () => ({ x: vi.fn() }));\n",
        );
        // `./x` counts as mocked, so the only finding is the untyped factory.
        assert_eq!(found.len(), 1, "got: {found:?}");
        assert_eq!(found[0].rule, "untyped-mock");
        assert!(found[0].message.contains("./x"));
    }

    #[test]
    fn unit_typed_factory_mock_is_clean() {
        // `\x20` re-inserts the indentation that the `\` line continuation strips.
        let found = unit_violations(
            "widget.test.ts",
            "import { x } from './x';\n\
             vi.mock('./x', async () => {\n\
             \x20 const actual = await vi.importActual<typeof import('./x')>('./x');\n\
             \x20 return { ...actual, x: vi.fn() };\n\
             });\n",
        );
        assert!(found.is_empty(), "got: {found:?}");
    }

    #[test]
    fn unit_options_object_mock_is_not_a_factory() {
        // A second argument that is an object literal is not a factory, so the
        // untyped-factory check does not apply.
        let found = unit_violations(
            "widget.test.ts",
            "import { x } from './x';\nvi.mock('./x', { spy: true });\n",
        );
        assert!(found.is_empty(), "got: {found:?}");
    }

    #[test]
    fn unit_untyped_import_actual_is_still_untyped() {
        // `vi.importActual` without type arguments does not anchor the factory.
        let found = unit_violations(
            "widget.test.ts",
            "import { x } from './x';\n\
             vi.mock('./x', async () => {\n\
             \x20 const actual = await vi.importActual('./x');\n\
             \x20 return { ...(actual as object), x: vi.fn() };\n\
             });\n",
        );
        assert_eq!(found.len(), 1, "got: {found:?}");
        assert_eq!(found[0].rule, "untyped-mock");
    }

    // ---- specifier classification -----------------------------------------

    #[test]
    fn classify_relative_is_first_party() {
        assert_eq!(classify("./service"), Origin::FirstParty);
        assert_eq!(classify("../pkg/util"), Origin::FirstParty);
        assert_eq!(classify("/abs/path"), Origin::FirstParty);
    }

    #[test]
    fn classify_node_builtins() {
        assert_eq!(classify("fs"), Origin::Builtin);
        assert_eq!(classify("node:fs"), Origin::Builtin);
        assert_eq!(classify("fs/promises"), Origin::Builtin);
        assert_eq!(classify("node:test"), Origin::Builtin);
        assert_eq!(classify("child_process"), Origin::Builtin);
        // Any `node:` specifier is a built-in, even one missing from NODE_BUILTINS.
        assert_eq!(classify("node:some-future-builtin"), Origin::Builtin);
    }

    #[test]
    fn classify_third_party() {
        assert_eq!(classify("lodash"), Origin::ThirdParty);
        assert_eq!(classify("@scope/pkg"), Origin::ThirdParty);
        assert_eq!(classify("stripe/lib/client"), Origin::ThirdParty);
        // A bare `test` is a package; only `node:test` is the built-in.
        assert_eq!(classify("test"), Origin::ThirdParty);
    }

    // ---- file discovery and line numbers ----------------------------------

    #[test]
    fn recognizes_ts_test_files() {
        assert!(is_ts_test_file(Path::new("widget.test.ts")));
        assert!(is_ts_test_file(Path::new("pkg/button.test.tsx")));
        assert!(is_ts_test_file(Path::new("service.test.mts")));
        assert!(is_ts_test_file(Path::new("legacy.test.cts")));
        assert!(!is_ts_test_file(Path::new("widget.ts")));
        assert!(!is_ts_test_file(Path::new("types.d.ts")));
        assert!(!is_ts_test_file(Path::new("README.md")));
    }

    #[test]
    fn line_of_counts_newlines() {
        let src = "a\nb\nc\n";
        assert_eq!(line_of(src, 0), 1);
        assert_eq!(line_of(src, 2), 2);
        assert_eq!(line_of(src, 4), 3);
    }

    // ---- integration-test rule --------------------------------------------

    #[test]
    fn flags_mock_of_relative_module() {
        let found = violations("a.test.ts", "vi.mock('./service');\n");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].rule, "no-first-party-mock");
        assert_eq!(found[0].line, 1);
    }

    #[test]
    fn flags_mock_with_factory_and_parent_path() {
        let found = violations(
            "a.test.ts",
            "import { x } from './x';\nvi.mock('../src/ledger', () => ({ record: vi.fn() }));\n",
        );
        assert_eq!(found.len(), 1);
        assert!(found[0].message.contains("../src/ledger"));
    }

    #[test]
    fn flags_domock_of_relative_module() {
        let found = violations("a.test.mts", "vi.doMock('./mailer');\n");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn allows_mock_of_third_party_and_builtins() {
        let found = violations(
            "a.test.ts",
            "vi.mock('stripe');\nvi.mock('node:fs');\nvi.mock('fs/promises');\nvi.mock('@scope/pkg');\n",
        );
        assert!(found.is_empty(), "got: {found:?}");
    }

    #[test]
    fn ignores_non_vi_and_non_mock_calls() {
        // `other.mock` has the right method name but is not on `vi`.
        let found = violations(
            "a.test.ts",
            "describe('s', () => {});\nvi.fn();\nexpect(1).toBe(1);\nother.mock('./x');\n",
        );
        assert!(found.is_empty(), "got: {found:?}");
    }

    #[test]
    fn ignores_dynamic_mock_target() {
        // An identifier target cannot be resolved statically, so it is skipped.
        let found = violations("a.test.ts", "const m = './x';\nvi.mock(m);\n");
        assert!(found.is_empty(), "got: {found:?}");
    }

    #[test]
    fn finds_mocks_nested_in_blocks() {
        // The visitor must walk into callback bodies, and the line must be that
        // of the nested call.
        let found = violations(
            "a.test.ts",
            "describe('s', () => {\n vi.mock('./inner');\n});\n",
        );
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].line, 2);
    }

    // ---- error reporting --------------------------------------------------

    #[test]
    fn parse_error_is_reported() {
        let err = integration_violations_in(Path::new("bad.test.ts"), "const x = ;\n").unwrap_err();
        assert!(err.to_string().contains("parsing"), "got: {err}");
    }

    #[test]
    fn unsupported_extension_is_reported() {
        let err = integration_violations_in(Path::new("weird.test.bogus"), "vi.mock('./x');\n")
            .unwrap_err();
        assert!(err.to_string().contains("unsupported"), "got: {err}");
    }
}