Skip to main content

testx/
sharding.rs

1use std::hash::{Hash, Hasher};
2
3use crate::adapters::{TestCase, TestRunResult, TestSuite};
4use crate::error::{Result, TestxError};
5use crate::hash::StableHasher;
6
/// Sharding mode for distributing tests across CI workers.
///
/// Both variants carry the same two fields:
///
/// * `index` — the 1-based number of this shard (`ShardingMode::parse` enforces
///   `1 <= index <= total`; direct construction does not).
/// * `total` — the number of shards the run is split into.
///
/// The type is comparable with `==` so callers and tests can assert on parsed modes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShardingMode {
    /// Round-robin slice assignment — simple but not stable across test additions:
    /// a test's shard depends on its position in the flattened run.
    Slice { index: usize, total: usize },
    /// Hash-based assignment — deterministic, and adding tests to a suite does not move
    /// the suite's existing tests.
    ///
    /// The hash key also includes the suite's position in the run, so inserting,
    /// removing, or reordering suites can move tests between shards.
    Hash { index: usize, total: usize },
}
15
16impl ShardingMode {
17    /// Parse a partition string like "slice:1/4" or "hash:2/3".
18    pub fn parse(s: &str) -> Result<Self> {
19        let (mode, spec) = s.split_once(':').ok_or_else(|| TestxError::ConfigError {
20            message: format!(
21                "Invalid partition format '{}'. Expected 'slice:M/N' or 'hash:M/N'",
22                s
23            ),
24        })?;
25
26        let (m_str, n_str) = spec
27            .split_once('/')
28            .ok_or_else(|| TestxError::ConfigError {
29                message: format!(
30                    "Invalid partition spec '{}'. Expected 'M/N' where 1 <= M <= N",
31                    spec
32                ),
33            })?;
34
35        let m: usize = m_str.parse().map_err(|_| TestxError::ConfigError {
36            message: format!(
37                "Invalid partition index '{}': must be a positive integer",
38                m_str
39            ),
40        })?;
41
42        let n: usize = n_str.parse().map_err(|_| TestxError::ConfigError {
43            message: format!(
44                "Invalid partition total '{}': must be a positive integer",
45                n_str
46            ),
47        })?;
48
49        if n == 0 {
50            return Err(TestxError::ConfigError {
51                message: "Partition total must be >= 1".into(),
52            });
53        }
54
55        if m == 0 || m > n {
56            return Err(TestxError::ConfigError {
57                message: format!(
58                    "Partition index must satisfy 1 <= M <= N, got M={}, N={}",
59                    m, n
60                ),
61            });
62        }
63
64        match mode {
65            "slice" => Ok(ShardingMode::Slice { index: m, total: n }),
66            "hash" => Ok(ShardingMode::Hash { index: m, total: n }),
67            other => Err(TestxError::ConfigError {
68                message: format!("Unknown partition mode '{}'. Use 'slice' or 'hash'", other),
69            }),
70        }
71    }
72
73    /// Apply sharding to a test run result, keeping only tests in this shard.
74    pub fn apply(&self, result: &TestRunResult) -> TestRunResult {
75        match self {
76            ShardingMode::Slice { index, total } => shard_slice(result, *index, *total),
77            ShardingMode::Hash { index, total } => shard_hash(result, *index, *total),
78        }
79    }
80
81    /// Return a human-readable description.
82    pub fn description(&self) -> String {
83        match self {
84            ShardingMode::Slice { index, total } => {
85                format!("slice {}/{}", index, total)
86            }
87            ShardingMode::Hash { index, total } => {
88                format!("hash {}/{}", index, total)
89            }
90        }
91    }
92
93    /// Return the shard index (1-based).
94    pub fn index(&self) -> usize {
95        match self {
96            ShardingMode::Slice { index, .. } | ShardingMode::Hash { index, .. } => *index,
97        }
98    }
99
100    /// Return the total number of shards.
101    pub fn total(&self) -> usize {
102        match self {
103            ShardingMode::Slice { total, .. } | ShardingMode::Hash { total, .. } => *total,
104        }
105    }
106}
107
108/// Slice-based sharding: flatten all tests, assign round-robin by position.
109fn shard_slice(result: &TestRunResult, index: usize, total: usize) -> TestRunResult {
110    // Flatten all tests with their suite index
111    let all_tests: Vec<(usize, &TestCase)> = result
112        .suites
113        .iter()
114        .enumerate()
115        .flat_map(|(si, s)| s.tests.iter().map(move |t| (si, t)))
116        .collect();
117
118    // Keep only tests where (position % total) == (index - 1) since index is 1-based
119    let bucket = index - 1;
120    let mut suite_tests: Vec<Vec<TestCase>> = vec![Vec::new(); result.suites.len()];
121
122    for (i, (suite_idx, test)) in all_tests.iter().enumerate() {
123        if i % total == bucket {
124            suite_tests[*suite_idx].push((*test).clone());
125        }
126    }
127
128    let suites: Vec<TestSuite> = result
129        .suites
130        .iter()
131        .enumerate()
132        .filter_map(|(i, orig)| {
133            if suite_tests[i].is_empty() {
134                None
135            } else {
136                Some(TestSuite {
137                    name: orig.name.clone(),
138                    tests: std::mem::take(&mut suite_tests[i]),
139                })
140            }
141        })
142        .collect();
143
144    TestRunResult {
145        suites,
146        duration: result.duration,
147        raw_exit_code: result.raw_exit_code,
148    }
149}
150
151/// Hash-based sharding: deterministic assignment based on hash of suite_index+test name.
152fn shard_hash(result: &TestRunResult, index: usize, total: usize) -> TestRunResult {
153    let bucket = index - 1;
154    let mut suite_tests: Vec<Vec<TestCase>> = vec![Vec::new(); result.suites.len()];
155
156    for (si, suite) in result.suites.iter().enumerate() {
157        for test in &suite.tests {
158            let hash_key = format!("{}::{}::{}", si, suite.name, test.name);
159            let mut hasher = StableHasher::new();
160            hash_key.hash(&mut hasher);
161            let hash_val = hasher.finish();
162
163            if (hash_val as usize) % total == bucket {
164                suite_tests[si].push(test.clone());
165            }
166        }
167    }
168
169    let suites: Vec<TestSuite> = result
170        .suites
171        .iter()
172        .enumerate()
173        .filter_map(|(i, orig)| {
174            if suite_tests[i].is_empty() {
175                None
176            } else {
177                Some(TestSuite {
178                    name: orig.name.clone(),
179                    tests: std::mem::take(&mut suite_tests[i]),
180                })
181            }
182        })
183        .collect();
184
185    TestRunResult {
186        suites,
187        duration: result.duration,
188        raw_exit_code: result.raw_exit_code,
189    }
190}
191
/// Sharding statistics for display: how many tests a run had and how many of them
/// ended up in one particular shard.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShardStats {
    /// Number of tests in the run before sharding.
    pub total_tests: usize,
    /// Number of tests kept in this shard.
    pub shard_tests: usize,
    /// Number of tests left out of this shard.
    pub skipped_tests: usize,
    /// Shard index (1-based) the statistics refer to.
    pub shard_index: usize,
    /// Total number of shards.
    pub shard_total: usize,
}
200
201pub fn compute_shard_stats(
202    original: &TestRunResult,
203    sharded: &TestRunResult,
204    mode: &ShardingMode,
205) -> ShardStats {
206    let total_tests = original.total_tests();
207    let shard_tests = sharded.total_tests();
208
209    ShardStats {
210        total_tests,
211        shard_tests,
212        skipped_tests: total_tests.saturating_sub(shard_tests),
213        shard_index: mode.index(),
214        shard_total: mode.total(),
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::{TestError, TestStatus};
    use std::time::Duration;

    /// Build a passing test case with the given name and a 10 ms duration.
    fn make_test(name: &str) -> TestCase {
        TestCase {
            name: String::from(name),
            status: TestStatus::Passed,
            duration: Duration::from_millis(10),
            error: None,
        }
    }

    /// Build a run with `num_suites` suites named `suite_<i>`, each holding
    /// `tests_per_suite` passing tests named `test_<j>`.
    fn make_result(num_suites: usize, tests_per_suite: usize) -> TestRunResult {
        let mut suites = Vec::with_capacity(num_suites);
        for s in 0..num_suites {
            let mut tests = Vec::with_capacity(tests_per_suite);
            for t in 0..tests_per_suite {
                tests.push(make_test(&format!("test_{}", t)));
            }
            suites.push(TestSuite {
                name: format!("suite_{}", s),
                tests,
            });
        }

        TestRunResult {
            suites,
            duration: Duration::from_secs(1),
            raw_exit_code: 0,
        }
    }

    /// List every test of `run` as `"<suite name>::<test name>"`, in run order.
    fn qualified_names(run: &TestRunResult) -> Vec<String> {
        let mut names = Vec::new();
        for suite in &run.suites {
            for test in &suite.tests {
                names.push(format!("{}::{}", suite.name, test.name));
            }
        }
        names
    }

    #[test]
    fn parse_slice_valid() {
        let parsed = ShardingMode::parse("slice:1/4").unwrap();
        assert!(matches!(parsed, ShardingMode::Slice { index: 1, total: 4 }));
    }

    #[test]
    fn parse_hash_valid() {
        let parsed = ShardingMode::parse("hash:2/3").unwrap();
        assert!(matches!(parsed, ShardingMode::Hash { index: 2, total: 3 }));
    }

    #[test]
    fn parse_invalid_format() {
        let rejected = [
            "invalid",
            "slice:1",
            "slice:0/3",
            "slice:4/3",
            "slice:1/0",
            "unknown:1/3",
        ];
        for spec in rejected.iter() {
            assert!(ShardingMode::parse(spec).is_err());
        }
    }

    #[test]
    fn parse_edge_case_single_shard() {
        let parsed = ShardingMode::parse("slice:1/1").unwrap();
        assert!(matches!(parsed, ShardingMode::Slice { index: 1, total: 1 }));
    }

    #[test]
    fn slice_distributes_tests_evenly() {
        let result = make_result(1, 8);
        let first = ShardingMode::Slice { index: 1, total: 2 }.apply(&result);
        let second = ShardingMode::Slice { index: 2, total: 2 }.apply(&result);

        assert_eq!(first.total_tests(), 4);
        assert_eq!(second.total_tests(), 4);
    }

    #[test]
    fn slice_all_shards_cover_all_tests() {
        let result = make_result(2, 5); // 10 tests total
        let total_shards = 3;

        let mut all_test_names: Vec<String> = Vec::new();
        for i in 1..=total_shards {
            let mode = ShardingMode::Slice {
                index: i,
                total: total_shards,
            };
            all_test_names.extend(qualified_names(&mode.apply(&result)));
        }
        all_test_names.sort();

        let mut expected_names = qualified_names(&result);
        expected_names.sort();

        assert_eq!(all_test_names, expected_names);
    }

    #[test]
    fn slice_no_overlap_between_shards() {
        let result = make_result(2, 6); // 12 tests
        let total = 4;

        let all: Vec<Vec<String>> = (1..=total)
            .map(|i| qualified_names(&ShardingMode::Slice { index: i, total }.apply(&result)))
            .collect();

        // Every name in a shard must be absent from all later shards.
        for (i, earlier) in all.iter().enumerate() {
            for later in &all[i + 1..] {
                for name in earlier {
                    assert!(!later.contains(name), "Overlap found: {}", name);
                }
            }
        }
    }

    #[test]
    fn slice_single_shard_keeps_all() {
        let result = make_result(2, 5);
        let only = ShardingMode::Slice { index: 1, total: 1 }.apply(&result);
        assert_eq!(only.total_tests(), result.total_tests());
    }

    #[test]
    fn hash_deterministic() {
        let result = make_result(2, 5);
        let mode = ShardingMode::Hash { index: 1, total: 3 };

        let names_a = qualified_names(&mode.apply(&result));
        let names_b = qualified_names(&mode.apply(&result));

        assert_eq!(names_a, names_b);
    }

    #[test]
    fn hash_all_shards_cover_all_tests() {
        let result = make_result(3, 4); // 12 tests
        let total = 3;

        let mut all_names: Vec<String> = Vec::new();
        for i in 1..=total {
            let shard = ShardingMode::Hash { index: i, total }.apply(&result);
            all_names.extend(qualified_names(&shard));
        }
        all_names.sort();

        let mut expected = qualified_names(&result);
        expected.sort();

        assert_eq!(all_names, expected);
    }

    #[test]
    fn hash_no_overlap() {
        let result = make_result(2, 6);
        let total = 4;

        let all: Vec<Vec<String>> = (1..=total)
            .map(|i| qualified_names(&ShardingMode::Hash { index: i, total }.apply(&result)))
            .collect();

        for (i, earlier) in all.iter().enumerate() {
            for later in &all[i + 1..] {
                for name in earlier {
                    assert!(!later.contains(name), "Hash overlap: {}", name);
                }
            }
        }
    }

    #[test]
    fn empty_result_sharding() {
        let result = TestRunResult {
            suites: Vec::new(),
            duration: Duration::ZERO,
            raw_exit_code: 0,
        };

        let sliced = ShardingMode::Slice { index: 1, total: 3 }.apply(&result);
        assert_eq!(sliced.total_tests(), 0);

        let hashed = ShardingMode::Hash { index: 1, total: 3 }.apply(&result);
        assert_eq!(hashed.total_tests(), 0);
    }

    #[test]
    fn shard_stats_computation() {
        let result = make_result(2, 5);
        let mode = ShardingMode::Slice { index: 1, total: 3 };
        let sharded = mode.apply(&result);
        let stats = compute_shard_stats(&result, &sharded, &mode);

        assert_eq!(stats.total_tests, 10);
        assert_eq!(stats.shard_index, 1);
        assert_eq!(stats.shard_total, 3);
        assert_eq!(stats.shard_tests + stats.skipped_tests, stats.total_tests);
    }

    #[test]
    fn description_format() {
        assert_eq!(
            ShardingMode::Slice { index: 2, total: 5 }.description(),
            "slice 2/5"
        );
        assert_eq!(
            ShardingMode::Hash { index: 1, total: 3 }.description(),
            "hash 1/3"
        );
    }

    #[test]
    fn preserves_suite_ordering() {
        let result = make_result(3, 3);
        let shard = ShardingMode::Slice { index: 1, total: 1 }.apply(&result);

        let suite_names = |run: &TestRunResult| -> Vec<String> {
            run.suites.iter().map(|s| s.name.clone()).collect()
        };

        assert_eq!(suite_names(&result), suite_names(&shard));
    }

    #[test]
    fn hash_stable_after_test_addition() {
        // Hash sharding should be stable: adding a test shouldn't move existing tests
        let mode = ShardingMode::Hash { index: 1, total: 2 };
        let bare_names = |run: &TestRunResult| -> Vec<String> {
            run.suites
                .iter()
                .flat_map(|s| s.tests.iter())
                .map(|t| t.name.clone())
                .collect()
        };

        let mut result1 = make_result(1, 5);
        let names1 = bare_names(&mode.apply(&result1));

        // Add a new test
        result1.suites[0].tests.push(make_test("test_new"));
        let names2 = bare_names(&mode.apply(&result1));

        // All original tests that were in shard 1 should still be there
        for name in &names1 {
            assert!(
                names2.contains(name),
                "Test '{}' moved after addition",
                name
            );
        }
    }

    #[test]
    fn failed_tests_preserved_in_shard() {
        let mut result = make_result(1, 4);
        {
            let broken = &mut result.suites[0].tests[1];
            broken.status = TestStatus::Failed;
            broken.error = Some(TestError {
                message: "assertion failed".to_string(),
                location: Some("test.rs:42".to_string()),
            });
        }

        let shard = ShardingMode::Slice { index: 1, total: 1 }.apply(&result);
        let mut failed: Vec<&TestCase> = Vec::new();
        for test in &shard.suites[0].tests {
            if test.status == TestStatus::Failed {
                failed.push(test);
            }
        }

        assert_eq!(failed.len(), 1);
        assert_eq!(
            failed[0].error.as_ref().unwrap().message,
            "assertion failed"
        );
    }

    #[test]
    fn many_shards_more_than_tests() {
        let result = make_result(1, 3); // 3 tests
        let total = 10; // 10 shards

        let total_assigned: usize = (1..=total)
            .map(|i| {
                ShardingMode::Slice { index: i, total }
                    .apply(&result)
                    .total_tests()
            })
            .sum();

        assert_eq!(total_assigned, 3);
    }
}