//! Test infrastructure for `trustformers_tokenizers`: comprehensive tokenizer
//! testing (basic/random/edge-case suites), benchmarking, fuzzing, regression
//! checks and cross-tokenizer validation.
1use scirs2_core::random::*; // SciRS2 Integration Policy - Replaces rand
2use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, HashSet};
4use std::time::{Duration, Instant};
5use trustformers_core::errors::{Result, TrustformersError};
6use trustformers_core::traits::{TokenizedInput, Tokenizer};
7
/// Configuration for comprehensive testing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestConfig {
    /// Number of random test cases to generate
    pub num_random_tests: usize,
    /// Maximum input length for testing (upper bound on generated character count)
    pub max_input_length: usize,
    /// Timeout for individual tests (in milliseconds)
    // NOTE(review): not enforced by the visible TestRunner code — confirm where
    // (or whether) this timeout is applied.
    pub timeout_ms: u64,
    /// Whether to run performance benchmarks
    pub run_benchmarks: bool,
    /// Whether to run fuzzing tests
    pub run_fuzzing: bool,
    /// Whether to run regression tests
    pub run_regression: bool,
    /// Languages to test (ISO 639-1 codes, e.g. "en", "zh")
    // NOTE(review): not consumed by the visible TestRunner code — presumably used
    // by language-specific suites elsewhere; verify.
    pub test_languages: Vec<String>,
    /// Custom test cases to include (run by `run_custom_tests` when non-empty)
    pub custom_test_cases: Vec<String>,
}
28
29impl Default for TestConfig {
30    fn default() -> Self {
31        Self {
32            num_random_tests: 1000,
33            max_input_length: 1000,
34            timeout_ms: 5000,
35            run_benchmarks: true,
36            run_fuzzing: true,
37            run_regression: true,
38            test_languages: vec![
39                "en".to_string(),
40                "es".to_string(),
41                "fr".to_string(),
42                "de".to_string(),
43                "zh".to_string(),
44                "ja".to_string(),
45                "ru".to_string(),
46            ],
47            custom_test_cases: Vec::new(),
48        }
49    }
50}
51
/// Test result for a single test case.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestResult {
    /// Test case description (e.g. "random_17", "edge_case_3")
    pub test_case: String,
    /// Whether the test passed
    pub passed: bool,
    /// Error message if test failed; `None` on success
    pub error: Option<String>,
    /// Execution time of the test case
    pub execution_time: Duration,
    /// Additional metrics keyed by name (e.g. "num_tokens", "compression_ratio")
    pub metrics: HashMap<String, f64>,
}
66
/// Comprehensive test suite results aggregated by `TestRunner::run_complete_suite`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestSuiteResult {
    /// Total number of tests run
    pub total_tests: usize,
    /// Number of passed tests
    pub passed_tests: usize,
    /// Number of failed tests (total - passed)
    pub failed_tests: usize,
    /// Individual test results
    pub test_results: Vec<TestResult>,
    /// Performance benchmark results (`None` when benchmarks are disabled)
    pub benchmark_results: Option<BenchmarkResults>,
    /// Fuzzing test results (`None` when fuzzing is disabled)
    pub fuzzing_results: Option<FuzzingResults>,
    /// Regression test results (`None` when regression testing is disabled)
    pub regression_results: Option<RegressionResults>,
    /// Cross-tokenizer validation results; left `None` by the suite runner and
    /// filled in separately by `CrossValidationRunner`
    pub cross_validation_results: Option<CrossValidationResults>,
}
87
/// Performance benchmark results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResults {
    /// Tokens per second for encoding
    pub encode_tokens_per_second: f64,
    /// Tokens per second for decoding
    pub decode_tokens_per_second: f64,
    /// Memory usage estimate in megabytes (derived from vocabulary size)
    pub memory_usage_mb: f64,
    /// Encode-latency percentiles keyed "p50" / "p95" / "p99"
    pub latency_percentiles: HashMap<String, Duration>,
}
100
/// Fuzzing test results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FuzzingResults {
    /// Number of fuzzing tests run
    pub tests_run: usize,
    /// Number of crashes/panics detected (caught via `catch_unwind`)
    pub crashes_detected: usize,
    /// Unique error types found (Debug-formatted error values)
    pub error_types: HashSet<String>,
    /// Coverage metrics keyed by name (e.g. "crash_rate", "error_diversity")
    pub coverage_metrics: HashMap<String, f64>,
}
113
/// Regression test results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegressionResults {
    /// Number of regression tests run
    pub tests_run: usize,
    /// Number of regressions detected (equals `regression_details.len()`)
    pub regressions_detected: usize,
    /// Details of detected regressions
    pub regression_details: Vec<RegressionDetail>,
}
124
/// Details of a detected regression.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegressionDetail {
    /// Name of the test case that regressed
    pub test_case: String,
    /// Human-readable description of the expected result
    pub expected: String,
    /// Human-readable description of the actual result
    pub actual: String,
    /// Description of the difference(s), joined with "; "
    pub difference: String,
}
137
/// Regression test case definition (internal baseline expectation).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RegressionTestCase {
    /// Test case name
    name: String,
    /// Input text for tokenization
    input: String,
    /// Expected number of tokens; `None` when tokenizer-dependent
    expected_token_count: Option<usize>,
    /// Whether tokenization should succeed
    expected_success: bool,
    /// Maximum allowed execution time before the case counts as a regression
    max_execution_time: Duration,
}
152
/// Cross-tokenizer validation results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossValidationResults {
    /// Names of the tokenizers compared
    pub tokenizers_compared: Vec<String>,
    /// Consistency score in [0.0, 1.0]: consistent comparisons / total comparisons
    pub consistency_score: f64,
    /// Number of inconsistencies found
    pub inconsistencies_found: usize,
    /// Details of each inconsistency
    pub inconsistency_details: Vec<InconsistencyDetail>,
}
165
/// Details of a tokenization inconsistency between tokenizers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InconsistencyDetail {
    /// Input text that caused the inconsistency
    pub input: String,
    /// Token strings produced by each tokenizer, keyed by tokenizer name
    pub tokenizer_results: HashMap<String, Vec<String>>,
    /// Severity of the inconsistency
    pub severity: InconsistencySeverity,
}
176
/// Severity levels for cross-tokenizer inconsistencies, from least to most severe.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum InconsistencySeverity {
    /// Minor divergence (e.g. cosmetic token differences)
    Low,
    /// Noticeable divergence
    Medium,
    /// Significant divergence
    High,
    /// Severe divergence likely to affect downstream behavior
    Critical,
}
185
/// Test case generator for fuzzing and random testing.
pub struct TestCaseGenerator {
    // Seeded RNG (scirs2_core re-export); deterministic when constructed with a seed.
    rng: StdRng,
    // Generation limits and options (e.g. `max_input_length`).
    config: TestConfig,
}
191
#[allow(deprecated)]
impl TestCaseGenerator {
    /// Create a new test case generator.
    ///
    /// With `Some(seed)` the RNG is deterministic, giving reproducible test
    /// runs; with `None` it is seeded from the thread-local RNG.
    pub fn new(config: TestConfig, seed: Option<u64>) -> Self {
        let rng = if let Some(seed) = seed {
            StdRng::seed_from_u64(seed)
        } else {
            // NOTE(review): entropy seeding via the scirs2_core API; presumably the
            // impl-level #[allow(deprecated)] exists for this call — confirm
            // against the scirs2_core::random docs.
            StdRng::from_rng(&mut thread_rng().rng_mut())
        };

        Self { rng, config }
    }

    /// Generate random ASCII text of 1..=`max_input_length` characters.
    ///
    /// Per-character distribution: 60% lowercase, 10% uppercase, 10% digit,
    /// 10% space, 10% special punctuation.
    pub fn generate_random_text(&mut self) -> String {
        let length = self.rng.random_range(1..=self.config.max_input_length);
        let mut text = String::new();

        for _ in 0..length {
            let char_type = self.rng.random_range(0..10);
            let ch = match char_type {
                0..=5 => self.rng.random_range(b'a'..=b'z') as char, // Lowercase letters
                6 => self.rng.random_range(b'A'..=b'Z') as char,     // Uppercase letters
                7 => self.rng.random_range(b'0'..=b'9') as char,     // Digits
                8 => ' ',                                            // Space
                _ => self.generate_special_char(),                   // Special characters
            };
            text.push(ch);
        }

        text
    }

    /// Generate random multi-script Unicode text.
    ///
    /// Half the length budget of `generate_random_text`, mixing Latin,
    /// Latin-extended, Greek, Cyrillic, CJK, Arabic and emoji code points.
    /// NOTE(review): the char ranges may include unassigned/unusual code
    /// points — intentional for stress testing, presumably.
    pub fn generate_unicode_text(&mut self) -> String {
        let length = self.rng.random_range(1..=self.config.max_input_length / 2);
        let mut text = String::new();

        for _ in 0..length {
            let char_type = self.rng.random_range(0..10);
            let ch = match char_type {
                0..=3 => self.rng.random_range('a'..='z'),
                4 => self.rng.random_range('À'..='ÿ'), // Latin extended
                5 => self.rng.random_range('Α'..='ω'), // Greek
                6 => self.rng.random_range('А'..='я'), // Cyrillic
                7 => self.rng.random_range('一'..='龯'), // CJK
                8 => self.rng.random_range('ا'..='ي'), // Arabic
                _ => self.rng.random_range('😀'..='🙏'), // Emoji
            };
            text.push(ch);
        }

        text
    }

    /// Pick one of a fixed pool of pathological inputs (empty string,
    /// whitespace-only, very long token, zero-width characters, NUL, emoji,
    /// combining marks, ...).
    pub fn generate_edge_case_text(&mut self) -> String {
        let long_token = "a".repeat(1000);
        let edge_cases = [
            "",                            // Empty string
            " ",                           // Single space
            "\n\t\r",                      // Whitespace only
            &long_token,                   // Very long token
            "123456789",                   // Numbers only
            "!@#$%^&*()",                  // Special characters only
            "\u{200B}\u{200C}\u{200D}",    // Zero-width characters
            "Test\u{0000}null",            // Null character
            "🚀🌟💫⭐",                    // Emoji sequence
            "a\u{0301}e\u{0301}i\u{0301}", // Combining characters
        ];

        edge_cases[self.rng.random_range(0..edge_cases.len())].to_string()
    }

    /// Generate 1..=100 uniformly random bytes for fuzzing.
    ///
    /// The output is raw bytes and frequently NOT valid UTF-8 — callers filter
    /// with `String::from_utf8` before feeding it to a tokenizer.
    pub fn generate_malformed_input(&mut self) -> Vec<u8> {
        let length = self.rng.random_range(1..=100);
        let mut bytes = Vec::new();

        for _ in 0..length {
            bytes.push(self.rng.random());
        }

        bytes
    }

    // Pick one ASCII punctuation character uniformly from a fixed pool.
    fn generate_special_char(&mut self) -> char {
        let special_chars = [
            '!', '@', '#', '$', '%', '^', '&', '*', '(', ')', '-', '_', '+', '=',
        ];
        special_chars[self.rng.random_range(0..special_chars.len())]
    }
}
285
/// Comprehensive test runner: drives the basic/random/edge/custom suites plus
/// the optional benchmark, fuzzing and regression phases.
pub struct TestRunner {
    // Suite options (counts, phase toggles, custom cases).
    config: TestConfig,
    // Input generator; unseeded (entropy-based) when built via `TestRunner::new`.
    generator: TestCaseGenerator,
}
291
292impl TestRunner {
293    /// Create a new test runner
294    pub fn new(config: TestConfig) -> Self {
295        let generator = TestCaseGenerator::new(config.clone(), None);
296        Self { config, generator }
297    }
298
299    /// Run the complete test suite
300    pub fn run_complete_suite<T: Tokenizer + Clone>(
301        &mut self,
302        tokenizer: &T,
303        test_name: &str,
304    ) -> Result<TestSuiteResult> {
305        let mut results = Vec::new();
306        let mut total_tests = 0;
307        let mut passed_tests = 0;
308
309        // Basic functionality tests
310        let basic_results = self.run_basic_tests(tokenizer, test_name)?;
311        total_tests += basic_results.len();
312        passed_tests += basic_results.iter().filter(|r| r.passed).count();
313        results.extend(basic_results);
314
315        // Random tests
316        let random_results = self.run_random_tests(tokenizer)?;
317        total_tests += random_results.len();
318        passed_tests += random_results.iter().filter(|r| r.passed).count();
319        results.extend(random_results);
320
321        // Edge case tests
322        let edge_results = self.run_edge_case_tests(tokenizer)?;
323        total_tests += edge_results.len();
324        passed_tests += edge_results.iter().filter(|r| r.passed).count();
325        results.extend(edge_results);
326
327        // Custom test cases
328        if !self.config.custom_test_cases.is_empty() {
329            let custom_results = self.run_custom_tests(tokenizer)?;
330            total_tests += custom_results.len();
331            passed_tests += custom_results.iter().filter(|r| r.passed).count();
332            results.extend(custom_results);
333        }
334
335        let failed_tests = total_tests - passed_tests;
336
337        // Optional advanced testing
338        let benchmark_results = if self.config.run_benchmarks {
339            Some(self.run_benchmarks(tokenizer)?)
340        } else {
341            None
342        };
343
344        let fuzzing_results = if self.config.run_fuzzing {
345            Some(self.run_fuzzing_tests(tokenizer)?)
346        } else {
347            None
348        };
349
350        let regression_results = if self.config.run_regression {
351            Some(self.run_regression_tests(tokenizer)?)
352        } else {
353            None
354        };
355
356        Ok(TestSuiteResult {
357            total_tests,
358            passed_tests,
359            failed_tests,
360            test_results: results,
361            benchmark_results,
362            fuzzing_results,
363            regression_results,
364            cross_validation_results: None, // Filled by cross-validation runner
365        })
366    }
367
368    /// Run basic functionality tests
369    fn run_basic_tests<T: Tokenizer>(
370        &mut self,
371        tokenizer: &T,
372        test_name: &str,
373    ) -> Result<Vec<TestResult>> {
374        let mut results = Vec::new();
375
376        // Test basic encode/decode cycle
377        let test_cases = vec![
378            "Hello, world!",
379            "The quick brown fox jumps over the lazy dog.",
380            "123456789",
381            "Special chars: !@#$%^&*()",
382            "",
383        ];
384
385        for (i, text) in test_cases.into_iter().enumerate() {
386            let start = Instant::now();
387            let test_case = format!("{}_basic_{}", test_name, i);
388
389            match self.test_encode_decode_cycle(tokenizer, text) {
390                Ok(metrics) => {
391                    results.push(TestResult {
392                        test_case,
393                        passed: true,
394                        error: None,
395                        execution_time: start.elapsed(),
396                        metrics,
397                    });
398                },
399                Err(e) => {
400                    results.push(TestResult {
401                        test_case,
402                        passed: false,
403                        error: Some(e.to_string()),
404                        execution_time: start.elapsed(),
405                        metrics: HashMap::new(),
406                    });
407                },
408            }
409        }
410
411        Ok(results)
412    }
413
414    /// Run random tests
415    fn run_random_tests<T: Tokenizer>(&mut self, tokenizer: &T) -> Result<Vec<TestResult>> {
416        let mut results = Vec::new();
417
418        for i in 0..self.config.num_random_tests {
419            let text = self.generator.generate_random_text();
420            let start = Instant::now();
421            let test_case = format!("random_{}", i);
422
423            match self.test_encode_decode_cycle(tokenizer, &text) {
424                Ok(metrics) => {
425                    results.push(TestResult {
426                        test_case,
427                        passed: true,
428                        error: None,
429                        execution_time: start.elapsed(),
430                        metrics,
431                    });
432                },
433                Err(e) => {
434                    results.push(TestResult {
435                        test_case,
436                        passed: false,
437                        error: Some(e.to_string()),
438                        execution_time: start.elapsed(),
439                        metrics: HashMap::new(),
440                    });
441                },
442            }
443        }
444
445        Ok(results)
446    }
447
448    /// Run edge case tests
449    fn run_edge_case_tests<T: Tokenizer>(&mut self, tokenizer: &T) -> Result<Vec<TestResult>> {
450        let mut results = Vec::new();
451
452        for i in 0..100 {
453            let text = self.generator.generate_edge_case_text();
454            let start = Instant::now();
455            let test_case = format!("edge_case_{}", i);
456
457            match self.test_encode_decode_cycle(tokenizer, &text) {
458                Ok(metrics) => {
459                    results.push(TestResult {
460                        test_case,
461                        passed: true,
462                        error: None,
463                        execution_time: start.elapsed(),
464                        metrics,
465                    });
466                },
467                Err(e) => {
468                    results.push(TestResult {
469                        test_case,
470                        passed: false,
471                        error: Some(e.to_string()),
472                        execution_time: start.elapsed(),
473                        metrics: HashMap::new(),
474                    });
475                },
476            }
477        }
478
479        Ok(results)
480    }
481
482    /// Run custom test cases
483    fn run_custom_tests<T: Tokenizer>(&mut self, tokenizer: &T) -> Result<Vec<TestResult>> {
484        let mut results = Vec::new();
485
486        for (i, text) in self.config.custom_test_cases.iter().enumerate() {
487            let start = Instant::now();
488            let test_case = format!("custom_{}", i);
489
490            match self.test_encode_decode_cycle(tokenizer, text) {
491                Ok(metrics) => {
492                    results.push(TestResult {
493                        test_case,
494                        passed: true,
495                        error: None,
496                        execution_time: start.elapsed(),
497                        metrics,
498                    });
499                },
500                Err(e) => {
501                    results.push(TestResult {
502                        test_case,
503                        passed: false,
504                        error: Some(e.to_string()),
505                        execution_time: start.elapsed(),
506                        metrics: HashMap::new(),
507                    });
508                },
509            }
510        }
511
512        Ok(results)
513    }
514
515    /// Test encode/decode cycle for correctness
516    fn test_encode_decode_cycle<T: Tokenizer>(
517        &self,
518        tokenizer: &T,
519        text: &str,
520    ) -> Result<HashMap<String, f64>> {
521        let mut metrics = HashMap::new();
522
523        // Encode
524        let encoded = tokenizer.encode(text)?;
525        metrics.insert("num_tokens".to_string(), encoded.input_ids.len() as f64);
526        metrics.insert("input_length".to_string(), text.chars().count() as f64);
527
528        // Decode
529        let decoded = tokenizer.decode(&encoded.input_ids)?;
530
531        // Verify vocabulary consistency
532        for &token_id in &encoded.input_ids {
533            if let Some(token) = tokenizer.id_to_token(token_id) {
534                if tokenizer.token_to_id(&token).is_none() {
535                    return Err(TrustformersError::runtime_error(format!(
536                        "Token '{}' not found in vocabulary",
537                        token
538                    )));
539                }
540            }
541        }
542
543        // Calculate compression ratio
544        if !text.is_empty() {
545            let compression_ratio = encoded.input_ids.len() as f64 / text.chars().count() as f64;
546            metrics.insert("compression_ratio".to_string(), compression_ratio);
547        }
548
549        // Verify round-trip if possible (not all tokenizers preserve exact text)
550        if decoded.trim() != text.trim() {
551            metrics.insert("exact_match".to_string(), 0.0);
552        } else {
553            metrics.insert("exact_match".to_string(), 1.0);
554        }
555
556        Ok(metrics)
557    }
558
559    /// Run performance benchmarks
560    fn run_benchmarks<T: Tokenizer>(&mut self, tokenizer: &T) -> Result<BenchmarkResults> {
561        let test_texts: Vec<String> =
562            (0..1000).map(|_| self.generator.generate_random_text()).collect();
563
564        // Benchmark encoding
565        let start = Instant::now();
566        let mut total_tokens = 0;
567
568        for text in &test_texts {
569            let encoded = tokenizer.encode(text)?;
570            total_tokens += encoded.input_ids.len();
571        }
572
573        let encoding_time = start.elapsed();
574        let encode_tokens_per_second = total_tokens as f64 / encoding_time.as_secs_f64();
575
576        // Benchmark decoding
577        let token_sequences: Vec<Vec<u32>> = test_texts
578            .iter()
579            .map(|text| tokenizer.encode(text).map(|enc| enc.input_ids))
580            .collect::<std::result::Result<Vec<_>, _>>()?;
581
582        let start = Instant::now();
583        for tokens in &token_sequences {
584            let _ = tokenizer.decode(tokens)?;
585        }
586        let decoding_time = start.elapsed();
587        let decode_tokens_per_second = total_tokens as f64 / decoding_time.as_secs_f64();
588
589        // Memory usage (simplified)
590        let vocab = tokenizer.get_vocab();
591        let memory_usage_mb = (vocab.len() * 100) as f64 / 1024.0 / 1024.0; // Rough estimate
592
593        // Latency percentiles (simplified)
594        let mut latency_percentiles = HashMap::new();
595        latency_percentiles.insert("p50".to_string(), encoding_time / test_texts.len() as u32);
596        latency_percentiles.insert("p95".to_string(), encoding_time / test_texts.len() as u32);
597        latency_percentiles.insert("p99".to_string(), encoding_time / test_texts.len() as u32);
598
599        Ok(BenchmarkResults {
600            encode_tokens_per_second,
601            decode_tokens_per_second,
602            memory_usage_mb,
603            latency_percentiles,
604        })
605    }
606
607    /// Run fuzzing tests
608    fn run_fuzzing_tests<T: Tokenizer>(&mut self, tokenizer: &T) -> Result<FuzzingResults> {
609        let mut tests_run = 0;
610        let mut crashes_detected = 0;
611        let mut error_types = HashSet::new();
612        let mut coverage_metrics = HashMap::new();
613
614        // Generate and test malformed inputs
615        for _ in 0..1000 {
616            tests_run += 1;
617
618            // Try with malformed bytes converted to string (if possible)
619            let malformed_bytes = self.generator.generate_malformed_input();
620            if let Ok(malformed_string) = String::from_utf8(malformed_bytes) {
621                match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
622                    tokenizer.encode(&malformed_string)
623                })) {
624                    Ok(result) => {
625                        if let Err(e) = result {
626                            error_types.insert(format!("{:?}", e));
627                        }
628                    },
629                    Err(_) => {
630                        crashes_detected += 1;
631                    },
632                }
633            }
634
635            // Test with extremely long inputs
636            if tests_run % 100 == 0 {
637                let very_long_text = "a".repeat(10000);
638                match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
639                    tokenizer.encode(&very_long_text)
640                })) {
641                    Ok(result) => {
642                        if let Err(e) = result {
643                            error_types.insert(format!("{:?}", e));
644                        }
645                    },
646                    Err(_) => {
647                        crashes_detected += 1;
648                    },
649                }
650            }
651        }
652
653        // Calculate basic coverage metrics
654        coverage_metrics.insert(
655            "crash_rate".to_string(),
656            crashes_detected as f64 / tests_run as f64,
657        );
658        coverage_metrics.insert("error_diversity".to_string(), error_types.len() as f64);
659
660        Ok(FuzzingResults {
661            tests_run,
662            crashes_detected,
663            error_types,
664            coverage_metrics,
665        })
666    }
667
668    /// Run regression tests against stored baseline data
669    fn run_regression_tests<T: Tokenizer>(&mut self, tokenizer: &T) -> Result<RegressionResults> {
670        let mut regression_details = Vec::new();
671        let mut tests_run = 0;
672
673        // Create regression test cases
674        let test_cases = self.create_regression_test_cases();
675
676        for test_case in &test_cases {
677            tests_run += 1;
678
679            // Run current tokenization
680            let start_time = Instant::now();
681            let current_result = match tokenizer.encode(&test_case.input) {
682                Ok(result) => result,
683                Err(e) => {
684                    // If tokenization fails, that's a regression if baseline succeeded
685                    if test_case.expected_success {
686                        regression_details.push(RegressionDetail {
687                            test_case: test_case.name.clone(),
688                            expected: "Successful tokenization".to_string(),
689                            actual: format!("Failed with error: {}", e),
690                            difference: "Tokenization failed unexpectedly".to_string(),
691                        });
692                    }
693                    continue;
694                },
695            };
696
697            let execution_time = start_time.elapsed();
698
699            // Compare against baseline
700            if let Some(regression) =
701                self.compare_with_baseline(test_case, &current_result, execution_time)
702            {
703                regression_details.push(regression);
704            }
705        }
706
707        Ok(RegressionResults {
708            tests_run,
709            regressions_detected: regression_details.len(),
710            regression_details,
711        })
712    }
713
714    /// Create standard regression test cases
715    fn create_regression_test_cases(&self) -> Vec<RegressionTestCase> {
716        vec![
717            RegressionTestCase {
718                name: "basic_english".to_string(),
719                input: "Hello world".to_string(),
720                expected_token_count: Some(2),
721                expected_success: true,
722                max_execution_time: Duration::from_millis(100),
723            },
724            RegressionTestCase {
725                name: "unicode_text".to_string(),
726                input: "你好世界 🌍".to_string(),
727                expected_token_count: None, // Variable depending on tokenizer
728                expected_success: true,
729                max_execution_time: Duration::from_millis(100),
730            },
731            RegressionTestCase {
732                name: "long_sentence".to_string(),
733                input: "The quick brown fox jumps over the lazy dog. This is a longer sentence to test tokenization performance and accuracy.".to_string(),
734                expected_token_count: None, // Variable depending on tokenizer
735                expected_success: true,
736                max_execution_time: Duration::from_millis(200),
737            },
738            RegressionTestCase {
739                name: "empty_string".to_string(),
740                input: "".to_string(),
741                expected_token_count: Some(0),
742                expected_success: true,
743                max_execution_time: Duration::from_millis(50),
744            },
745            RegressionTestCase {
746                name: "special_characters".to_string(),
747                input: "!@#$%^&*()_+-=[]{}|;':\",./<>?".to_string(),
748                expected_token_count: None, // Variable depending on tokenizer
749                expected_success: true,
750                max_execution_time: Duration::from_millis(100),
751            },
752            RegressionTestCase {
753                name: "mixed_languages".to_string(),
754                input: "Hello こんにちは 你好 Hola".to_string(),
755                expected_token_count: None, // Variable depending on tokenizer
756                expected_success: true,
757                max_execution_time: Duration::from_millis(150),
758            },
759        ]
760    }
761
762    /// Compare current results with baseline expectations
763    fn compare_with_baseline(
764        &self,
765        test_case: &RegressionTestCase,
766        current_result: &TokenizedInput,
767        execution_time: Duration,
768    ) -> Option<RegressionDetail> {
769        let mut differences = Vec::new();
770
771        // Check token count if expected
772        if let Some(expected_count) = test_case.expected_token_count {
773            let actual_count = current_result.input_ids.len();
774            if actual_count != expected_count {
775                differences.push(format!(
776                    "Token count: expected {}, got {}",
777                    expected_count, actual_count
778                ));
779            }
780        }
781
782        // Check execution time
783        if execution_time > test_case.max_execution_time {
784            differences.push(format!(
785                "Execution time: expected <= {:?}, got {:?}",
786                test_case.max_execution_time, execution_time
787            ));
788        }
789
790        // Check for empty results when they shouldn't be
791        if !test_case.input.is_empty() && current_result.input_ids.is_empty() {
792            differences.push("Unexpected empty tokenization result".to_string());
793        }
794
795        // Check attention mask consistency
796        if current_result.input_ids.len() != current_result.attention_mask.len() {
797            differences.push(format!(
798                "Attention mask length mismatch: input_ids={}, attention_mask={}",
799                current_result.input_ids.len(),
800                current_result.attention_mask.len()
801            ));
802        }
803
804        if !differences.is_empty() {
805            Some(RegressionDetail {
806                test_case: test_case.name.clone(),
807                expected: format!(
808                    "Proper tokenization within {} ms",
809                    test_case.max_execution_time.as_millis()
810                ),
811                actual: format!("Issues detected: {:?}", differences),
812                difference: differences.join("; "),
813            })
814        } else {
815            None
816        }
817    }
818}
819
/// Cross-tokenizer validation runner: compares multiple tokenizers on shared
/// inputs and scores their consistency.
pub struct CrossValidationRunner {
    // Currently unused by the visible methods; kept for future configuration.
    #[allow(dead_code)]
    config: TestConfig,
}
825
826impl CrossValidationRunner {
827    pub fn new(config: TestConfig) -> Self {
828        Self { config }
829    }
830
831    /// Compare multiple tokenizers for consistency
832    pub fn compare_tokenizers(
833        &self,
834        tokenizers: Vec<(&str, &dyn Tokenizer)>,
835        test_cases: &[String],
836    ) -> Result<CrossValidationResults> {
837        let mut inconsistencies = Vec::new();
838        let mut total_comparisons = 0;
839        let mut consistent_comparisons = 0;
840
841        for text in test_cases {
842            total_comparisons += 1;
843            let mut results = HashMap::new();
844
845            // Get results from each tokenizer
846            for (name, tokenizer) in &tokenizers {
847                match tokenizer.encode(text) {
848                    Ok(encoded) => {
849                        let tokens: Vec<String> = encoded
850                            .input_ids
851                            .iter()
852                            .filter_map(|&id| tokenizer.id_to_token(id))
853                            .collect();
854                        results.insert(name.to_string(), tokens);
855                    },
856                    Err(_) => {
857                        // Skip this comparison if any tokenizer fails
858                        continue;
859                    },
860                }
861            }
862
863            // Check for consistency
864            if results.len() > 1 {
865                let first_result = match results.values().next() {
866                    Some(result) => result,
867                    None => continue, // Should not happen since results.len() > 1, but handle it safely
868                };
869                let is_consistent = results.values().all(|tokens| tokens == first_result);
870
871                if is_consistent {
872                    consistent_comparisons += 1;
873                } else {
874                    let severity = self.determine_inconsistency_severity(&results);
875                    inconsistencies.push(InconsistencyDetail {
876                        input: text.clone(),
877                        tokenizer_results: results,
878                        severity,
879                    });
880                }
881            }
882        }
883
884        let consistency_score = if total_comparisons > 0 {
885            consistent_comparisons as f64 / total_comparisons as f64
886        } else {
887            0.0
888        };
889
890        Ok(CrossValidationResults {
891            tokenizers_compared: tokenizers.iter().map(|(name, _)| name.to_string()).collect(),
892            consistency_score,
893            inconsistencies_found: inconsistencies.len(),
894            inconsistency_details: inconsistencies,
895        })
896    }
897
898    fn determine_inconsistency_severity(
899        &self,
900        results: &HashMap<String, Vec<String>>,
901    ) -> InconsistencySeverity {
902        // Simple heuristic: if token counts differ significantly, it's high severity
903        let token_counts: Vec<usize> = results.values().map(|tokens| tokens.len()).collect();
904        let min_count = *token_counts.iter().min().unwrap_or(&0);
905        let max_count = *token_counts.iter().max().unwrap_or(&0);
906
907        if max_count == 0 {
908            InconsistencySeverity::Low
909        } else {
910            let ratio = min_count as f64 / max_count as f64;
911            if ratio < 0.5 {
912                InconsistencySeverity::High
913            } else if ratio < 0.8 {
914                InconsistencySeverity::Medium
915            } else {
916                InconsistencySeverity::Low
917            }
918        }
919    }
920}
921
/// Utilities for test reporting and analysis.
///
/// Stateless namespace type: `generate_report` renders a human-readable
/// summary of a test-suite run, and `analyze_metrics` aggregates per-test
/// numeric metrics into averages.
pub struct TestReportUtils;
924
925impl TestReportUtils {
926    /// Generate a comprehensive test report
927    pub fn generate_report(result: &TestSuiteResult) -> String {
928        let mut report = String::new();
929
930        report.push_str("=== COMPREHENSIVE TEST REPORT ===\n\n");
931
932        // Summary
933        report.push_str(&format!("Total Tests: {}\n", result.total_tests));
934        report.push_str(&format!(
935            "Passed: {} ({:.1}%)\n",
936            result.passed_tests,
937            (result.passed_tests as f64 / result.total_tests as f64) * 100.0
938        ));
939        report.push_str(&format!(
940            "Failed: {} ({:.1}%)\n\n",
941            result.failed_tests,
942            (result.failed_tests as f64 / result.total_tests as f64) * 100.0
943        ));
944
945        // Failed tests
946        if result.failed_tests > 0 {
947            report.push_str("FAILED TESTS:\n");
948            for test in &result.test_results {
949                if !test.passed {
950                    report.push_str(&format!(
951                        "  {} - {}\n",
952                        test.test_case,
953                        test.error.as_ref().unwrap_or(&"Unknown error".to_string())
954                    ));
955                }
956            }
957            report.push('\n');
958        }
959
960        // Benchmark results
961        if let Some(ref benchmarks) = result.benchmark_results {
962            report.push_str("PERFORMANCE BENCHMARKS:\n");
963            report.push_str(&format!(
964                "  Encoding: {:.0} tokens/sec\n",
965                benchmarks.encode_tokens_per_second
966            ));
967            report.push_str(&format!(
968                "  Decoding: {:.0} tokens/sec\n",
969                benchmarks.decode_tokens_per_second
970            ));
971            report.push_str(&format!(
972                "  Memory Usage: {:.1} MB\n\n",
973                benchmarks.memory_usage_mb
974            ));
975        }
976
977        // Fuzzing results
978        if let Some(ref fuzzing) = result.fuzzing_results {
979            report.push_str("FUZZING RESULTS:\n");
980            report.push_str(&format!("  Tests Run: {}\n", fuzzing.tests_run));
981            report.push_str(&format!(
982                "  Crashes Detected: {}\n",
983                fuzzing.crashes_detected
984            ));
985            report.push_str(&format!(
986                "  Unique Error Types: {}\n\n",
987                fuzzing.error_types.len()
988            ));
989        }
990
991        // Cross-validation results
992        if let Some(ref cross_val) = result.cross_validation_results {
993            report.push_str("CROSS-VALIDATION RESULTS:\n");
994            report.push_str(&format!(
995                "  Consistency Score: {:.3}\n",
996                cross_val.consistency_score
997            ));
998            report.push_str(&format!(
999                "  Inconsistencies Found: {}\n\n",
1000                cross_val.inconsistencies_found
1001            ));
1002        }
1003
1004        report
1005    }
1006
1007    /// Analyze test metrics
1008    pub fn analyze_metrics(results: &[TestResult]) -> HashMap<String, f64> {
1009        let mut analysis = HashMap::new();
1010
1011        let mut total_time = Duration::new(0, 0);
1012        let mut total_tokens = 0.0;
1013        let mut compression_ratios = Vec::new();
1014
1015        for result in results {
1016            total_time += result.execution_time;
1017
1018            if let Some(&tokens) = result.metrics.get("num_tokens") {
1019                total_tokens += tokens;
1020            }
1021
1022            if let Some(&ratio) = result.metrics.get("compression_ratio") {
1023                compression_ratios.push(ratio);
1024            }
1025        }
1026
1027        analysis.insert(
1028            "avg_execution_time_ms".to_string(),
1029            total_time.as_millis() as f64 / results.len() as f64,
1030        );
1031        analysis.insert(
1032            "avg_tokens_per_test".to_string(),
1033            total_tokens / results.len() as f64,
1034        );
1035
1036        if !compression_ratios.is_empty() {
1037            let avg_compression =
1038                compression_ratios.iter().sum::<f64>() / compression_ratios.len() as f64;
1039            analysis.insert("avg_compression_ratio".to_string(), avg_compression);
1040        }
1041
1042        analysis
1043    }
1044}
1045
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Minimal whitespace tokenizer over a fixed four-word vocabulary,
    /// used to exercise the test infrastructure without a real model.
    #[derive(Clone)]
    struct MockTokenizer {
        vocab: HashMap<String, u32>,
    }

    impl MockTokenizer {
        fn new() -> Self {
            let mut vocab = HashMap::new();
            for (word, id) in [("hello", 1), ("world", 2), ("test", 3), ("!", 4)] {
                vocab.insert(word.to_string(), id);
            }
            Self { vocab }
        }
    }

    impl Tokenizer for MockTokenizer {
        /// Splits on whitespace and keeps only in-vocabulary words;
        /// unknown words are silently dropped.
        fn encode(&self, text: &str) -> Result<TokenizedInput> {
            let input_ids: Vec<u32> = text
                .split_whitespace()
                .filter_map(|word| self.vocab.get(word).copied())
                .collect();
            let attention_mask = vec![1; input_ids.len()];

            Ok(TokenizedInput {
                input_ids,
                attention_mask,
                token_type_ids: None,
                special_tokens_mask: None,
                offset_mapping: None,
                overflowing_tokens: None,
            })
        }

        /// Reverse lookup of every id; errors on the first unknown id.
        fn decode(&self, token_ids: &[u32]) -> Result<String> {
            let mut words = Vec::with_capacity(token_ids.len());
            for &id in token_ids {
                let word = self
                    .id_to_token(id)
                    .ok_or_else(|| TrustformersError::other(format!("Unknown token ID: {}", id)))?;
                words.push(word);
            }
            Ok(words.join(" "))
        }

        fn get_vocab(&self) -> HashMap<String, u32> {
            self.vocab.clone()
        }

        fn token_to_id(&self, token: &str) -> Option<u32> {
            self.vocab.get(token).copied()
        }

        fn id_to_token(&self, id: u32) -> Option<String> {
            // Linear scan is fine for the tiny test vocabulary.
            self.vocab
                .iter()
                .find_map(|(token, &candidate)| (candidate == id).then(|| token.clone()))
        }

        fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<TokenizedInput> {
            self.encode(&format!("{} {}", text_a, text_b))
        }

        fn vocab_size(&self) -> usize {
            self.vocab.len()
        }
    }

    #[test]
    fn test_config_default() {
        let defaults = TestConfig::default();
        assert_eq!(defaults.num_random_tests, 1000);
        assert!(defaults.run_benchmarks);
        assert!(defaults.run_fuzzing);
    }

    #[test]
    fn test_case_generator() {
        // Seeded generator so the test is deterministic.
        let mut generator = TestCaseGenerator::new(TestConfig::default(), Some(42));

        assert!(!generator.generate_random_text().is_empty());
        assert!(!generator.generate_unicode_text().is_empty());

        // Edge cases are allowed to be empty, but stay within the size cap.
        let edge_case = generator.generate_edge_case_text();
        assert!(edge_case.len() <= 1000);
    }

    #[test]
    fn test_basic_functionality() {
        let mut test_runner = TestRunner::new(TestConfig::default());
        let mock = MockTokenizer::new();

        let outcomes =
            test_runner.run_basic_tests(&mock, "test").expect("Operation failed in test");
        assert!(!outcomes.is_empty());

        // At least some tests should pass
        assert!(outcomes.iter().filter(|r| r.passed).count() > 0);
    }

    #[test]
    fn test_encode_decode_cycle() {
        let test_runner = TestRunner::new(TestConfig::default());
        let mock = MockTokenizer::new();

        let cycle_metrics = test_runner
            .test_encode_decode_cycle(&mock, "hello world")
            .expect("Operation failed in test");
        assert!(cycle_metrics.contains_key("num_tokens"));
        assert!(cycle_metrics.contains_key("input_length"));
    }

    #[test]
    fn test_cross_validation() {
        let validator = CrossValidationRunner::new(TestConfig::default());

        let first = MockTokenizer::new();
        let second = MockTokenizer::new();
        let pair: Vec<(&str, &dyn Tokenizer)> = vec![("mock1", &first), ("mock2", &second)];

        let inputs = vec!["hello world".to_string(), "test".to_string()];
        let outcome =
            validator.compare_tokenizers(pair, &inputs).expect("Operation failed in test");

        assert_eq!(outcome.tokenizers_compared.len(), 2);
        // Identical tokenizers must yield a score inside [0, 1].
        assert!(outcome.consistency_score >= 0.0 && outcome.consistency_score <= 1.0);
    }

    #[test]
    fn test_report_generation() {
        let suite = TestSuiteResult {
            total_tests: 10,
            passed_tests: 8,
            failed_tests: 2,
            test_results: vec![
                TestResult {
                    test_case: "test1".to_string(),
                    passed: true,
                    error: None,
                    execution_time: Duration::from_millis(10),
                    metrics: HashMap::new(),
                },
                TestResult {
                    test_case: "test2".to_string(),
                    passed: false,
                    error: Some("Test failed".to_string()),
                    execution_time: Duration::from_millis(5),
                    metrics: HashMap::new(),
                },
            ],
            benchmark_results: None,
            fuzzing_results: None,
            regression_results: None,
            cross_validation_results: None,
        };

        let rendered = TestReportUtils::generate_report(&suite);
        assert!(rendered.contains("Total Tests: 10"));
        assert!(rendered.contains("Passed: 8"));
        assert!(rendered.contains("Failed: 2"));
    }

    #[test]
    fn test_metrics_analysis() {
        let sample_results = vec![
            TestResult {
                test_case: "test1".to_string(),
                passed: true,
                error: None,
                execution_time: Duration::from_millis(10),
                metrics: HashMap::from([
                    ("num_tokens".to_string(), 5.0),
                    ("compression_ratio".to_string(), 0.8),
                ]),
            },
            TestResult {
                test_case: "test2".to_string(),
                passed: true,
                error: None,
                execution_time: Duration::from_millis(20),
                metrics: HashMap::from([
                    ("num_tokens".to_string(), 3.0),
                    ("compression_ratio".to_string(), 1.2),
                ]),
            },
        ];

        let summary = TestReportUtils::analyze_metrics(&sample_results);
        assert!(summary.contains_key("avg_execution_time_ms"));
        assert!(summary.contains_key("avg_tokens_per_test"));
        assert!(summary.contains_key("avg_compression_ratio"));
    }
}