1use std::hash::{Hash, Hasher};
2
3use crate::adapters::{TestCase, TestRunResult, TestSuite};
4use crate::error::{Result, TestxError};
5use crate::hash::StableHasher;
6
/// Strategy for splitting a test run into `total` disjoint shards and selecting one of them.
///
/// `index` is 1-based: a valid pair satisfies `1 <= index <= total`. Values produced by
/// [`ShardingMode::parse`] are always validated; because the fields are public, values
/// constructed directly are not.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ShardingMode {
    /// Round-robin over the flattened test list (all suites, in order): the test at
    /// 0-based position `i` belongs to shard `i % total + 1`.
    Slice { index: usize, total: usize },
    /// Assigns each test to a shard from a hash of a key derived from its suite and test
    /// name, taken modulo `total`.
    Hash { index: usize, total: usize },
}
15
16impl ShardingMode {
17 pub fn parse(s: &str) -> Result<Self> {
19 let (mode, spec) = s.split_once(':').ok_or_else(|| TestxError::ConfigError {
20 message: format!(
21 "Invalid partition format '{}'. Expected 'slice:M/N' or 'hash:M/N'",
22 s
23 ),
24 })?;
25
26 let (m_str, n_str) = spec
27 .split_once('/')
28 .ok_or_else(|| TestxError::ConfigError {
29 message: format!(
30 "Invalid partition spec '{}'. Expected 'M/N' where 1 <= M <= N",
31 spec
32 ),
33 })?;
34
35 let m: usize = m_str.parse().map_err(|_| TestxError::ConfigError {
36 message: format!(
37 "Invalid partition index '{}': must be a positive integer",
38 m_str
39 ),
40 })?;
41
42 let n: usize = n_str.parse().map_err(|_| TestxError::ConfigError {
43 message: format!(
44 "Invalid partition total '{}': must be a positive integer",
45 n_str
46 ),
47 })?;
48
49 if n == 0 {
50 return Err(TestxError::ConfigError {
51 message: "Partition total must be >= 1".into(),
52 });
53 }
54
55 if m == 0 || m > n {
56 return Err(TestxError::ConfigError {
57 message: format!(
58 "Partition index must satisfy 1 <= M <= N, got M={}, N={}",
59 m, n
60 ),
61 });
62 }
63
64 match mode {
65 "slice" => Ok(ShardingMode::Slice { index: m, total: n }),
66 "hash" => Ok(ShardingMode::Hash { index: m, total: n }),
67 other => Err(TestxError::ConfigError {
68 message: format!("Unknown partition mode '{}'. Use 'slice' or 'hash'", other),
69 }),
70 }
71 }
72
73 pub fn apply(&self, result: &TestRunResult) -> TestRunResult {
75 match self {
76 ShardingMode::Slice { index, total } => shard_slice(result, *index, *total),
77 ShardingMode::Hash { index, total } => shard_hash(result, *index, *total),
78 }
79 }
80
81 pub fn description(&self) -> String {
83 match self {
84 ShardingMode::Slice { index, total } => {
85 format!("slice {}/{}", index, total)
86 }
87 ShardingMode::Hash { index, total } => {
88 format!("hash {}/{}", index, total)
89 }
90 }
91 }
92
93 pub fn index(&self) -> usize {
95 match self {
96 ShardingMode::Slice { index, .. } | ShardingMode::Hash { index, .. } => *index,
97 }
98 }
99
100 pub fn total(&self) -> usize {
102 match self {
103 ShardingMode::Slice { total, .. } | ShardingMode::Hash { total, .. } => *total,
104 }
105 }
106}
107
108fn shard_slice(result: &TestRunResult, index: usize, total: usize) -> TestRunResult {
110 let all_tests: Vec<(usize, &TestCase)> = result
112 .suites
113 .iter()
114 .enumerate()
115 .flat_map(|(si, s)| s.tests.iter().map(move |t| (si, t)))
116 .collect();
117
118 let bucket = index - 1;
120 let mut suite_tests: Vec<Vec<TestCase>> = vec![Vec::new(); result.suites.len()];
121
122 for (i, (suite_idx, test)) in all_tests.iter().enumerate() {
123 if i % total == bucket {
124 suite_tests[*suite_idx].push((*test).clone());
125 }
126 }
127
128 let suites: Vec<TestSuite> = result
129 .suites
130 .iter()
131 .enumerate()
132 .filter_map(|(i, orig)| {
133 if suite_tests[i].is_empty() {
134 None
135 } else {
136 Some(TestSuite {
137 name: orig.name.clone(),
138 tests: std::mem::take(&mut suite_tests[i]),
139 })
140 }
141 })
142 .collect();
143
144 TestRunResult {
145 suites,
146 duration: result.duration,
147 raw_exit_code: result.raw_exit_code,
148 }
149}
150
151fn shard_hash(result: &TestRunResult, index: usize, total: usize) -> TestRunResult {
153 let bucket = index - 1;
154 let mut suite_tests: Vec<Vec<TestCase>> = vec![Vec::new(); result.suites.len()];
155
156 for (si, suite) in result.suites.iter().enumerate() {
157 for test in &suite.tests {
158 let hash_key = format!("{}::{}::{}", si, suite.name, test.name);
159 let mut hasher = StableHasher::new();
160 hash_key.hash(&mut hasher);
161 let hash_val = hasher.finish();
162
163 if (hash_val as usize) % total == bucket {
164 suite_tests[si].push(test.clone());
165 }
166 }
167 }
168
169 let suites: Vec<TestSuite> = result
170 .suites
171 .iter()
172 .enumerate()
173 .filter_map(|(i, orig)| {
174 if suite_tests[i].is_empty() {
175 None
176 } else {
177 Some(TestSuite {
178 name: orig.name.clone(),
179 tests: std::mem::take(&mut suite_tests[i]),
180 })
181 }
182 })
183 .collect();
184
185 TestRunResult {
186 suites,
187 duration: result.duration,
188 raw_exit_code: result.raw_exit_code,
189 }
190}
191
/// Summary of how many tests one shard kept relative to the full, unsharded run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShardStats {
    /// Number of tests in the original, unsharded run.
    pub total_tests: usize,
    /// Number of tests kept by this shard.
    pub shard_tests: usize,
    /// Tests left to other shards. As filled in by `compute_shard_stats`, this is
    /// `total_tests - shard_tests`, saturating at zero.
    pub skipped_tests: usize,
    /// 1-based index of this shard.
    pub shard_index: usize,
    /// Total number of shards.
    pub shard_total: usize,
}
200
201pub fn compute_shard_stats(
202 original: &TestRunResult,
203 sharded: &TestRunResult,
204 mode: &ShardingMode,
205) -> ShardStats {
206 let total_tests = original.total_tests();
207 let shard_tests = sharded.total_tests();
208
209 ShardStats {
210 total_tests,
211 shard_tests,
212 skipped_tests: total_tests.saturating_sub(shard_tests),
213 shard_index: mode.index(),
214 shard_total: mode.total(),
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::{TestError, TestStatus};
    use std::time::Duration;

    fn make_test(name: &str) -> TestCase {
        TestCase {
            name: name.to_string(),
            status: TestStatus::Passed,
            duration: Duration::from_millis(10),
            error: None,
        }
    }

    /// `num_suites` suites named `suite_K`, each holding tests `test_0..tests_per_suite`.
    fn make_result(num_suites: usize, tests_per_suite: usize) -> TestRunResult {
        let suites = (0..num_suites)
            .map(|s| TestSuite {
                name: format!("suite_{}", s),
                tests: (0..tests_per_suite)
                    .map(|t| make_test(&format!("test_{}", t)))
                    .collect(),
            })
            .collect();

        TestRunResult {
            suites,
            duration: Duration::from_secs(1),
            raw_exit_code: 0,
        }
    }

    fn slice_mode(index: usize, total: usize) -> ShardingMode {
        ShardingMode::Slice { index, total }
    }

    fn hash_mode(index: usize, total: usize) -> ShardingMode {
        ShardingMode::Hash { index, total }
    }

    /// Every test in `result` as `suite::test`, in run order.
    fn qualified_names(result: &TestRunResult) -> Vec<String> {
        result
            .suites
            .iter()
            .flat_map(|s| {
                s.tests
                    .iter()
                    .map(move |t| format!("{}::{}", s.name, t.name))
            })
            .collect()
    }

    /// Applies shards `1..=total` built by `mode` and returns each shard's qualified names.
    fn names_per_shard(
        result: &TestRunResult,
        total: usize,
        mode: fn(usize, usize) -> ShardingMode,
    ) -> Vec<Vec<String>> {
        (1..=total)
            .map(|index| qualified_names(&mode(index, total).apply(result)))
            .collect()
    }

    fn sorted(mut names: Vec<String>) -> Vec<String> {
        names.sort();
        names
    }

    /// Panics with `"{label}: {name}"` if any test name appears in more than one shard.
    fn assert_disjoint(shards: &[Vec<String>], label: &str) {
        for (i, earlier) in shards.iter().enumerate() {
            for later in &shards[i + 1..] {
                for name in earlier {
                    assert!(!later.contains(name), "{}: {}", label, name);
                }
            }
        }
    }

    #[test]
    fn parse_slice_valid() {
        let mode = ShardingMode::parse("slice:1/4").unwrap();
        assert!(matches!(mode, ShardingMode::Slice { index: 1, total: 4 }));
    }

    #[test]
    fn parse_hash_valid() {
        let mode = ShardingMode::parse("hash:2/3").unwrap();
        assert!(matches!(mode, ShardingMode::Hash { index: 2, total: 3 }));
    }

    #[test]
    fn parse_invalid_format() {
        assert!(ShardingMode::parse("invalid").is_err());
        assert!(ShardingMode::parse("slice:1").is_err());
        assert!(ShardingMode::parse("slice:0/3").is_err());
        assert!(ShardingMode::parse("slice:4/3").is_err());
        assert!(ShardingMode::parse("slice:1/0").is_err());
        assert!(ShardingMode::parse("unknown:1/3").is_err());
    }

    #[test]
    fn parse_rejects_non_numeric_components() {
        assert!(ShardingMode::parse("slice:a/3").is_err());
        assert!(ShardingMode::parse("slice:1/b").is_err());
        assert!(ShardingMode::parse("slice:/3").is_err());
        assert!(ShardingMode::parse("hash:1/").is_err());
        assert!(ShardingMode::parse("hash:-1/3").is_err());
    }

    #[test]
    fn parse_edge_case_single_shard() {
        let mode = ShardingMode::parse("slice:1/1").unwrap();
        assert!(matches!(mode, ShardingMode::Slice { index: 1, total: 1 }));
    }

    #[test]
    fn index_and_total_accessors() {
        let slice = slice_mode(3, 7);
        assert_eq!((slice.index(), slice.total()), (3, 7));

        let hash = hash_mode(2, 5);
        assert_eq!((hash.index(), hash.total()), (2, 5));
    }

    #[test]
    fn slice_distributes_tests_evenly() {
        let result = make_result(1, 8);
        assert_eq!(slice_mode(1, 2).apply(&result).total_tests(), 4);
        assert_eq!(slice_mode(2, 2).apply(&result).total_tests(), 4);
    }

    #[test]
    fn slice_all_shards_cover_all_tests() {
        let result = make_result(2, 5);
        let covered = sorted(names_per_shard(&result, 3, slice_mode).concat());
        assert_eq!(covered, sorted(qualified_names(&result)));
    }

    #[test]
    fn slice_no_overlap_between_shards() {
        let result = make_result(2, 6);
        assert_disjoint(&names_per_shard(&result, 4, slice_mode), "Overlap found");
    }

    #[test]
    fn slice_single_shard_keeps_all() {
        let result = make_result(2, 5);
        let shard = slice_mode(1, 1).apply(&result);
        assert_eq!(shard.total_tests(), result.total_tests());
    }

    #[test]
    fn hash_deterministic() {
        let result = make_result(2, 5);
        let first = qualified_names(&hash_mode(1, 3).apply(&result));
        let second = qualified_names(&hash_mode(1, 3).apply(&result));
        assert_eq!(first, second);
    }

    #[test]
    fn hash_all_shards_cover_all_tests() {
        let result = make_result(3, 4);
        let covered = sorted(names_per_shard(&result, 3, hash_mode).concat());
        assert_eq!(covered, sorted(qualified_names(&result)));
    }

    #[test]
    fn hash_no_overlap() {
        let result = make_result(2, 6);
        assert_disjoint(&names_per_shard(&result, 4, hash_mode), "Hash overlap");
    }

    #[test]
    fn hash_single_shard_keeps_all() {
        let result = make_result(3, 4);
        let shard = hash_mode(1, 1).apply(&result);
        assert_eq!(shard.total_tests(), result.total_tests());
    }

    #[test]
    fn empty_result_sharding() {
        let result = TestRunResult {
            suites: vec![],
            duration: Duration::ZERO,
            raw_exit_code: 0,
        };

        assert_eq!(slice_mode(1, 3).apply(&result).total_tests(), 0);
        assert_eq!(hash_mode(1, 3).apply(&result).total_tests(), 0);
    }

    #[test]
    fn shard_stats_computation() {
        let result = make_result(2, 5);
        let mode = slice_mode(1, 3);
        let sharded = mode.apply(&result);
        let stats = compute_shard_stats(&result, &sharded, &mode);

        assert_eq!(stats.total_tests, 10);
        assert_eq!(stats.shard_index, 1);
        assert_eq!(stats.shard_total, 3);
        assert_eq!(stats.shard_tests + stats.skipped_tests, stats.total_tests);
    }

    #[test]
    fn description_format() {
        assert_eq!(slice_mode(2, 5).description(), "slice 2/5");
        assert_eq!(hash_mode(1, 3).description(), "hash 1/3");
    }

    #[test]
    fn preserves_suite_ordering() {
        let result = make_result(3, 3);
        let shard = slice_mode(1, 1).apply(&result);

        let original_order: Vec<&str> = result.suites.iter().map(|s| s.name.as_str()).collect();
        let shard_order: Vec<&str> = shard.suites.iter().map(|s| s.name.as_str()).collect();

        assert_eq!(original_order, shard_order);
    }

    #[test]
    fn hash_stable_after_test_addition() {
        let mut result = make_result(1, 5);
        let before = qualified_names(&hash_mode(1, 2).apply(&result));

        result.suites[0].tests.push(make_test("test_new"));
        let after = qualified_names(&hash_mode(1, 2).apply(&result));

        // Adding a test must not move any previously-assigned test out of this shard.
        for name in &before {
            assert!(after.contains(name), "Test '{}' moved after addition", name);
        }
    }

    #[test]
    fn failed_tests_preserved_in_shard() {
        let mut result = make_result(1, 4);
        result.suites[0].tests[1].status = TestStatus::Failed;
        result.suites[0].tests[1].error = Some(TestError {
            message: "assertion failed".to_string(),
            location: Some("test.rs:42".to_string()),
        });

        let shard = slice_mode(1, 1).apply(&result);
        let failed: Vec<&TestCase> = shard.suites[0]
            .tests
            .iter()
            .filter(|t| t.status == TestStatus::Failed)
            .collect();

        assert_eq!(failed.len(), 1);
        assert_eq!(
            failed[0].error.as_ref().unwrap().message,
            "assertion failed"
        );
    }

    #[test]
    fn many_shards_more_than_tests() {
        let result = make_result(1, 3);
        let total = 10;
        let assigned: usize = (1..=total)
            .map(|i| slice_mode(i, total).apply(&result).total_tests())
            .sum();

        assert_eq!(assigned, 3);
    }
}
560}