1use scirs2_core::random::*; use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, HashSet};
4use std::time::{Duration, Instant};
5use trustformers_core::errors::{Result, TrustformersError};
6use trustformers_core::traits::{TokenizedInput, Tokenizer};
7
/// Configuration for the tokenizer test suite: how many inputs to generate,
/// how large they may be, and which optional phases to run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestConfig {
    /// Number of randomly generated inputs exercised by the random-test phase.
    pub num_random_tests: usize,
    /// Upper bound (in characters) on generated test inputs.
    pub max_input_length: usize,
    /// Per-test timeout in milliseconds.
    // NOTE(review): not consulted by the runner in this file — confirm callers use it.
    pub timeout_ms: u64,
    /// Whether to run the performance benchmark phase.
    pub run_benchmarks: bool,
    /// Whether to run the fuzzing phase (malformed / oversized inputs).
    pub run_fuzzing: bool,
    /// Whether to run the regression phase against the built-in baseline cases.
    pub run_regression: bool,
    /// Language codes to cover.
    // NOTE(review): not read by the runner in this file — presumably consumed by callers.
    pub test_languages: Vec<String>,
    /// Extra caller-supplied inputs; run as the "custom" phase when non-empty.
    pub custom_test_cases: Vec<String>,
}
28
29impl Default for TestConfig {
30 fn default() -> Self {
31 Self {
32 num_random_tests: 1000,
33 max_input_length: 1000,
34 timeout_ms: 5000,
35 run_benchmarks: true,
36 run_fuzzing: true,
37 run_regression: true,
38 test_languages: vec![
39 "en".to_string(),
40 "es".to_string(),
41 "fr".to_string(),
42 "de".to_string(),
43 "zh".to_string(),
44 "ja".to_string(),
45 "ru".to_string(),
46 ],
47 custom_test_cases: Vec::new(),
48 }
49 }
50}
51
/// Outcome of a single test case.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestResult {
    /// Identifier of the test case (phase name plus index).
    pub test_case: String,
    /// Whether the encode/decode cycle completed without error.
    pub passed: bool,
    /// Error message when the test failed.
    pub error: Option<String>,
    /// Wall-clock time spent on the test.
    pub execution_time: Duration,
    /// Numeric metrics collected during the test (token counts, ratios, …).
    pub metrics: HashMap<String, f64>,
}
66
/// Aggregated outcome of a complete suite run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestSuiteResult {
    /// Total number of individual tests executed.
    pub total_tests: usize,
    /// How many of them passed.
    pub passed_tests: usize,
    /// How many of them failed (`total_tests - passed_tests`).
    pub failed_tests: usize,
    /// Per-test outcomes, in execution order.
    pub test_results: Vec<TestResult>,
    /// Benchmark phase output, when `TestConfig::run_benchmarks` was set.
    pub benchmark_results: Option<BenchmarkResults>,
    /// Fuzzing phase output, when `TestConfig::run_fuzzing` was set.
    pub fuzzing_results: Option<FuzzingResults>,
    /// Regression phase output, when `TestConfig::run_regression` was set.
    pub regression_results: Option<RegressionResults>,
    /// Cross-tokenizer comparison output; `None` unless filled in by a
    /// `CrossValidationRunner` (the single-tokenizer suite never sets it).
    pub cross_validation_results: Option<CrossValidationResults>,
}
87
/// Throughput and latency figures from the benchmark phase.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResults {
    /// Encoding throughput in tokens per second.
    pub encode_tokens_per_second: f64,
    /// Decoding throughput in tokens per second.
    pub decode_tokens_per_second: f64,
    /// Estimated memory footprint in megabytes (derived from vocabulary size).
    pub memory_usage_mb: f64,
    /// Per-input encode latency keyed by percentile label ("p50", "p95", "p99").
    pub latency_percentiles: HashMap<String, Duration>,
}
100
/// Outcome of the fuzzing phase (random bytes and oversized inputs).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FuzzingResults {
    /// Number of fuzzing iterations executed.
    pub tests_run: usize,
    /// How many iterations panicked inside the tokenizer.
    pub crashes_detected: usize,
    /// Debug representations of the distinct errors returned (not panics).
    pub error_types: HashSet<String>,
    /// Summary metrics; keys written here: "crash_rate", "error_diversity".
    pub coverage_metrics: HashMap<String, f64>,
}
113
/// Outcome of the regression phase against the built-in baseline cases.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegressionResults {
    /// Number of baseline cases replayed.
    pub tests_run: usize,
    /// How many cases deviated from their baseline.
    pub regressions_detected: usize,
    /// One record per detected deviation.
    pub regression_details: Vec<RegressionDetail>,
}
124
/// Human-readable description of a single regression.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegressionDetail {
    /// Name of the baseline case that regressed.
    pub test_case: String,
    /// What the baseline expected.
    pub expected: String,
    /// What actually happened.
    pub actual: String,
    /// Summary of the differences found.
    pub difference: String,
}
137
/// One baseline input with its expectations, used by the regression phase.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RegressionTestCase {
    /// Stable identifier for reporting.
    name: String,
    /// Text fed to the tokenizer.
    input: String,
    /// Expected token count, when it is tokenizer-agnostic; `None` to skip
    /// the count check.
    expected_token_count: Option<usize>,
    /// Whether encoding this input is expected to succeed.
    expected_success: bool,
    /// Latency budget; exceeding it is reported as a regression.
    max_execution_time: Duration,
}
152
/// Outcome of comparing several tokenizers on the same inputs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossValidationResults {
    /// Names of the tokenizers that were compared.
    pub tokenizers_compared: Vec<String>,
    /// Fraction of inputs (over all test cases) on which every tokenizer
    /// produced the same token sequence; 0.0 when no inputs were tested.
    pub consistency_score: f64,
    /// Number of inputs where the tokenizers disagreed.
    pub inconsistencies_found: usize,
    /// One record per disagreement.
    pub inconsistency_details: Vec<InconsistencyDetail>,
}
165
/// One input on which the compared tokenizers disagreed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InconsistencyDetail {
    /// The input text that produced the disagreement.
    pub input: String,
    /// Token strings produced by each tokenizer, keyed by tokenizer name.
    pub tokenizer_results: HashMap<String, Vec<String>>,
    /// How severe the disagreement is, judged by token-count divergence.
    pub severity: InconsistencySeverity,
}
176
/// Severity of a cross-tokenizer disagreement, from mild to critical.
// NOTE(review): the classifier in this file only ever produces Low/Medium/High;
// Critical is presumably reserved for callers elsewhere — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum InconsistencySeverity {
    Low,
    Medium,
    High,
    Critical,
}
185
/// Generates random, Unicode, edge-case and malformed test inputs, bounded by
/// the limits in `TestConfig`.
pub struct TestCaseGenerator {
    // Seedable RNG so generated inputs can be made reproducible.
    rng: StdRng,
    config: TestConfig,
}
191
192#[allow(deprecated)]
193impl TestCaseGenerator {
194 pub fn new(config: TestConfig, seed: Option<u64>) -> Self {
196 let rng = if let Some(seed) = seed {
197 StdRng::seed_from_u64(seed)
198 } else {
199 StdRng::from_rng(&mut thread_rng().rng_mut())
200 };
201
202 Self { rng, config }
203 }
204
205 pub fn generate_random_text(&mut self) -> String {
207 let length = self.rng.random_range(1..=self.config.max_input_length);
208 let mut text = String::new();
209
210 for _ in 0..length {
211 let char_type = self.rng.random_range(0..10);
212 let ch = match char_type {
213 0..=5 => self.rng.random_range(b'a'..=b'z') as char, 6 => self.rng.random_range(b'A'..=b'Z') as char, 7 => self.rng.random_range(b'0'..=b'9') as char, 8 => ' ', _ => self.generate_special_char(), };
219 text.push(ch);
220 }
221
222 text
223 }
224
225 pub fn generate_unicode_text(&mut self) -> String {
227 let length = self.rng.random_range(1..=self.config.max_input_length / 2);
228 let mut text = String::new();
229
230 for _ in 0..length {
231 let char_type = self.rng.random_range(0..10);
232 let ch = match char_type {
233 0..=3 => self.rng.random_range('a'..='z'),
234 4 => self.rng.random_range('À'..='ÿ'), 5 => self.rng.random_range('Α'..='ω'), 6 => self.rng.random_range('А'..='я'), 7 => self.rng.random_range('一'..='龯'), 8 => self.rng.random_range('ا'..='ي'), _ => self.rng.random_range('😀'..='🙏'), };
241 text.push(ch);
242 }
243
244 text
245 }
246
247 pub fn generate_edge_case_text(&mut self) -> String {
249 let long_token = "a".repeat(1000);
250 let edge_cases = [
251 "", " ", "\n\t\r", &long_token, "123456789", "!@#$%^&*()", "\u{200B}\u{200C}\u{200D}", "Test\u{0000}null", "🚀🌟💫⭐", "a\u{0301}e\u{0301}i\u{0301}", ];
262
263 edge_cases[self.rng.random_range(0..edge_cases.len())].to_string()
264 }
265
266 pub fn generate_malformed_input(&mut self) -> Vec<u8> {
268 let length = self.rng.random_range(1..=100);
269 let mut bytes = Vec::new();
270
271 for _ in 0..length {
272 bytes.push(self.rng.random());
273 }
274
275 bytes
276 }
277
278 fn generate_special_char(&mut self) -> char {
279 let special_chars = [
280 '!', '@', '#', '$', '%', '^', '&', '*', '(', ')', '-', '_', '+', '=',
281 ];
282 special_chars[self.rng.random_range(0..special_chars.len())]
283 }
284}
285
/// Drives the individual test phases (basic, random, edge-case, custom,
/// benchmark, fuzzing, regression) against a single tokenizer.
pub struct TestRunner {
    config: TestConfig,
    // Source of generated inputs; seeded from the thread RNG in `new`.
    generator: TestCaseGenerator,
}
291
292impl TestRunner {
293 pub fn new(config: TestConfig) -> Self {
295 let generator = TestCaseGenerator::new(config.clone(), None);
296 Self { config, generator }
297 }
298
299 pub fn run_complete_suite<T: Tokenizer + Clone>(
301 &mut self,
302 tokenizer: &T,
303 test_name: &str,
304 ) -> Result<TestSuiteResult> {
305 let mut results = Vec::new();
306 let mut total_tests = 0;
307 let mut passed_tests = 0;
308
309 let basic_results = self.run_basic_tests(tokenizer, test_name)?;
311 total_tests += basic_results.len();
312 passed_tests += basic_results.iter().filter(|r| r.passed).count();
313 results.extend(basic_results);
314
315 let random_results = self.run_random_tests(tokenizer)?;
317 total_tests += random_results.len();
318 passed_tests += random_results.iter().filter(|r| r.passed).count();
319 results.extend(random_results);
320
321 let edge_results = self.run_edge_case_tests(tokenizer)?;
323 total_tests += edge_results.len();
324 passed_tests += edge_results.iter().filter(|r| r.passed).count();
325 results.extend(edge_results);
326
327 if !self.config.custom_test_cases.is_empty() {
329 let custom_results = self.run_custom_tests(tokenizer)?;
330 total_tests += custom_results.len();
331 passed_tests += custom_results.iter().filter(|r| r.passed).count();
332 results.extend(custom_results);
333 }
334
335 let failed_tests = total_tests - passed_tests;
336
337 let benchmark_results = if self.config.run_benchmarks {
339 Some(self.run_benchmarks(tokenizer)?)
340 } else {
341 None
342 };
343
344 let fuzzing_results = if self.config.run_fuzzing {
345 Some(self.run_fuzzing_tests(tokenizer)?)
346 } else {
347 None
348 };
349
350 let regression_results = if self.config.run_regression {
351 Some(self.run_regression_tests(tokenizer)?)
352 } else {
353 None
354 };
355
356 Ok(TestSuiteResult {
357 total_tests,
358 passed_tests,
359 failed_tests,
360 test_results: results,
361 benchmark_results,
362 fuzzing_results,
363 regression_results,
364 cross_validation_results: None, })
366 }
367
368 fn run_basic_tests<T: Tokenizer>(
370 &mut self,
371 tokenizer: &T,
372 test_name: &str,
373 ) -> Result<Vec<TestResult>> {
374 let mut results = Vec::new();
375
376 let test_cases = vec![
378 "Hello, world!",
379 "The quick brown fox jumps over the lazy dog.",
380 "123456789",
381 "Special chars: !@#$%^&*()",
382 "",
383 ];
384
385 for (i, text) in test_cases.into_iter().enumerate() {
386 let start = Instant::now();
387 let test_case = format!("{}_basic_{}", test_name, i);
388
389 match self.test_encode_decode_cycle(tokenizer, text) {
390 Ok(metrics) => {
391 results.push(TestResult {
392 test_case,
393 passed: true,
394 error: None,
395 execution_time: start.elapsed(),
396 metrics,
397 });
398 },
399 Err(e) => {
400 results.push(TestResult {
401 test_case,
402 passed: false,
403 error: Some(e.to_string()),
404 execution_time: start.elapsed(),
405 metrics: HashMap::new(),
406 });
407 },
408 }
409 }
410
411 Ok(results)
412 }
413
414 fn run_random_tests<T: Tokenizer>(&mut self, tokenizer: &T) -> Result<Vec<TestResult>> {
416 let mut results = Vec::new();
417
418 for i in 0..self.config.num_random_tests {
419 let text = self.generator.generate_random_text();
420 let start = Instant::now();
421 let test_case = format!("random_{}", i);
422
423 match self.test_encode_decode_cycle(tokenizer, &text) {
424 Ok(metrics) => {
425 results.push(TestResult {
426 test_case,
427 passed: true,
428 error: None,
429 execution_time: start.elapsed(),
430 metrics,
431 });
432 },
433 Err(e) => {
434 results.push(TestResult {
435 test_case,
436 passed: false,
437 error: Some(e.to_string()),
438 execution_time: start.elapsed(),
439 metrics: HashMap::new(),
440 });
441 },
442 }
443 }
444
445 Ok(results)
446 }
447
448 fn run_edge_case_tests<T: Tokenizer>(&mut self, tokenizer: &T) -> Result<Vec<TestResult>> {
450 let mut results = Vec::new();
451
452 for i in 0..100 {
453 let text = self.generator.generate_edge_case_text();
454 let start = Instant::now();
455 let test_case = format!("edge_case_{}", i);
456
457 match self.test_encode_decode_cycle(tokenizer, &text) {
458 Ok(metrics) => {
459 results.push(TestResult {
460 test_case,
461 passed: true,
462 error: None,
463 execution_time: start.elapsed(),
464 metrics,
465 });
466 },
467 Err(e) => {
468 results.push(TestResult {
469 test_case,
470 passed: false,
471 error: Some(e.to_string()),
472 execution_time: start.elapsed(),
473 metrics: HashMap::new(),
474 });
475 },
476 }
477 }
478
479 Ok(results)
480 }
481
482 fn run_custom_tests<T: Tokenizer>(&mut self, tokenizer: &T) -> Result<Vec<TestResult>> {
484 let mut results = Vec::new();
485
486 for (i, text) in self.config.custom_test_cases.iter().enumerate() {
487 let start = Instant::now();
488 let test_case = format!("custom_{}", i);
489
490 match self.test_encode_decode_cycle(tokenizer, text) {
491 Ok(metrics) => {
492 results.push(TestResult {
493 test_case,
494 passed: true,
495 error: None,
496 execution_time: start.elapsed(),
497 metrics,
498 });
499 },
500 Err(e) => {
501 results.push(TestResult {
502 test_case,
503 passed: false,
504 error: Some(e.to_string()),
505 execution_time: start.elapsed(),
506 metrics: HashMap::new(),
507 });
508 },
509 }
510 }
511
512 Ok(results)
513 }
514
515 fn test_encode_decode_cycle<T: Tokenizer>(
517 &self,
518 tokenizer: &T,
519 text: &str,
520 ) -> Result<HashMap<String, f64>> {
521 let mut metrics = HashMap::new();
522
523 let encoded = tokenizer.encode(text)?;
525 metrics.insert("num_tokens".to_string(), encoded.input_ids.len() as f64);
526 metrics.insert("input_length".to_string(), text.chars().count() as f64);
527
528 let decoded = tokenizer.decode(&encoded.input_ids)?;
530
531 for &token_id in &encoded.input_ids {
533 if let Some(token) = tokenizer.id_to_token(token_id) {
534 if tokenizer.token_to_id(&token).is_none() {
535 return Err(TrustformersError::runtime_error(format!(
536 "Token '{}' not found in vocabulary",
537 token
538 )));
539 }
540 }
541 }
542
543 if !text.is_empty() {
545 let compression_ratio = encoded.input_ids.len() as f64 / text.chars().count() as f64;
546 metrics.insert("compression_ratio".to_string(), compression_ratio);
547 }
548
549 if decoded.trim() != text.trim() {
551 metrics.insert("exact_match".to_string(), 0.0);
552 } else {
553 metrics.insert("exact_match".to_string(), 1.0);
554 }
555
556 Ok(metrics)
557 }
558
559 fn run_benchmarks<T: Tokenizer>(&mut self, tokenizer: &T) -> Result<BenchmarkResults> {
561 let test_texts: Vec<String> =
562 (0..1000).map(|_| self.generator.generate_random_text()).collect();
563
564 let start = Instant::now();
566 let mut total_tokens = 0;
567
568 for text in &test_texts {
569 let encoded = tokenizer.encode(text)?;
570 total_tokens += encoded.input_ids.len();
571 }
572
573 let encoding_time = start.elapsed();
574 let encode_tokens_per_second = total_tokens as f64 / encoding_time.as_secs_f64();
575
576 let token_sequences: Vec<Vec<u32>> = test_texts
578 .iter()
579 .map(|text| tokenizer.encode(text).map(|enc| enc.input_ids))
580 .collect::<std::result::Result<Vec<_>, _>>()?;
581
582 let start = Instant::now();
583 for tokens in &token_sequences {
584 let _ = tokenizer.decode(tokens)?;
585 }
586 let decoding_time = start.elapsed();
587 let decode_tokens_per_second = total_tokens as f64 / decoding_time.as_secs_f64();
588
589 let vocab = tokenizer.get_vocab();
591 let memory_usage_mb = (vocab.len() * 100) as f64 / 1024.0 / 1024.0; let mut latency_percentiles = HashMap::new();
595 latency_percentiles.insert("p50".to_string(), encoding_time / test_texts.len() as u32);
596 latency_percentiles.insert("p95".to_string(), encoding_time / test_texts.len() as u32);
597 latency_percentiles.insert("p99".to_string(), encoding_time / test_texts.len() as u32);
598
599 Ok(BenchmarkResults {
600 encode_tokens_per_second,
601 decode_tokens_per_second,
602 memory_usage_mb,
603 latency_percentiles,
604 })
605 }
606
607 fn run_fuzzing_tests<T: Tokenizer>(&mut self, tokenizer: &T) -> Result<FuzzingResults> {
609 let mut tests_run = 0;
610 let mut crashes_detected = 0;
611 let mut error_types = HashSet::new();
612 let mut coverage_metrics = HashMap::new();
613
614 for _ in 0..1000 {
616 tests_run += 1;
617
618 let malformed_bytes = self.generator.generate_malformed_input();
620 if let Ok(malformed_string) = String::from_utf8(malformed_bytes) {
621 match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
622 tokenizer.encode(&malformed_string)
623 })) {
624 Ok(result) => {
625 if let Err(e) = result {
626 error_types.insert(format!("{:?}", e));
627 }
628 },
629 Err(_) => {
630 crashes_detected += 1;
631 },
632 }
633 }
634
635 if tests_run % 100 == 0 {
637 let very_long_text = "a".repeat(10000);
638 match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
639 tokenizer.encode(&very_long_text)
640 })) {
641 Ok(result) => {
642 if let Err(e) = result {
643 error_types.insert(format!("{:?}", e));
644 }
645 },
646 Err(_) => {
647 crashes_detected += 1;
648 },
649 }
650 }
651 }
652
653 coverage_metrics.insert(
655 "crash_rate".to_string(),
656 crashes_detected as f64 / tests_run as f64,
657 );
658 coverage_metrics.insert("error_diversity".to_string(), error_types.len() as f64);
659
660 Ok(FuzzingResults {
661 tests_run,
662 crashes_detected,
663 error_types,
664 coverage_metrics,
665 })
666 }
667
668 fn run_regression_tests<T: Tokenizer>(&mut self, tokenizer: &T) -> Result<RegressionResults> {
670 let mut regression_details = Vec::new();
671 let mut tests_run = 0;
672
673 let test_cases = self.create_regression_test_cases();
675
676 for test_case in &test_cases {
677 tests_run += 1;
678
679 let start_time = Instant::now();
681 let current_result = match tokenizer.encode(&test_case.input) {
682 Ok(result) => result,
683 Err(e) => {
684 if test_case.expected_success {
686 regression_details.push(RegressionDetail {
687 test_case: test_case.name.clone(),
688 expected: "Successful tokenization".to_string(),
689 actual: format!("Failed with error: {}", e),
690 difference: "Tokenization failed unexpectedly".to_string(),
691 });
692 }
693 continue;
694 },
695 };
696
697 let execution_time = start_time.elapsed();
698
699 if let Some(regression) =
701 self.compare_with_baseline(test_case, ¤t_result, execution_time)
702 {
703 regression_details.push(regression);
704 }
705 }
706
707 Ok(RegressionResults {
708 tests_run,
709 regressions_detected: regression_details.len(),
710 regression_details,
711 })
712 }
713
714 fn create_regression_test_cases(&self) -> Vec<RegressionTestCase> {
716 vec![
717 RegressionTestCase {
718 name: "basic_english".to_string(),
719 input: "Hello world".to_string(),
720 expected_token_count: Some(2),
721 expected_success: true,
722 max_execution_time: Duration::from_millis(100),
723 },
724 RegressionTestCase {
725 name: "unicode_text".to_string(),
726 input: "你好世界 🌍".to_string(),
727 expected_token_count: None, expected_success: true,
729 max_execution_time: Duration::from_millis(100),
730 },
731 RegressionTestCase {
732 name: "long_sentence".to_string(),
733 input: "The quick brown fox jumps over the lazy dog. This is a longer sentence to test tokenization performance and accuracy.".to_string(),
734 expected_token_count: None, expected_success: true,
736 max_execution_time: Duration::from_millis(200),
737 },
738 RegressionTestCase {
739 name: "empty_string".to_string(),
740 input: "".to_string(),
741 expected_token_count: Some(0),
742 expected_success: true,
743 max_execution_time: Duration::from_millis(50),
744 },
745 RegressionTestCase {
746 name: "special_characters".to_string(),
747 input: "!@#$%^&*()_+-=[]{}|;':\",./<>?".to_string(),
748 expected_token_count: None, expected_success: true,
750 max_execution_time: Duration::from_millis(100),
751 },
752 RegressionTestCase {
753 name: "mixed_languages".to_string(),
754 input: "Hello こんにちは 你好 Hola".to_string(),
755 expected_token_count: None, expected_success: true,
757 max_execution_time: Duration::from_millis(150),
758 },
759 ]
760 }
761
762 fn compare_with_baseline(
764 &self,
765 test_case: &RegressionTestCase,
766 current_result: &TokenizedInput,
767 execution_time: Duration,
768 ) -> Option<RegressionDetail> {
769 let mut differences = Vec::new();
770
771 if let Some(expected_count) = test_case.expected_token_count {
773 let actual_count = current_result.input_ids.len();
774 if actual_count != expected_count {
775 differences.push(format!(
776 "Token count: expected {}, got {}",
777 expected_count, actual_count
778 ));
779 }
780 }
781
782 if execution_time > test_case.max_execution_time {
784 differences.push(format!(
785 "Execution time: expected <= {:?}, got {:?}",
786 test_case.max_execution_time, execution_time
787 ));
788 }
789
790 if !test_case.input.is_empty() && current_result.input_ids.is_empty() {
792 differences.push("Unexpected empty tokenization result".to_string());
793 }
794
795 if current_result.input_ids.len() != current_result.attention_mask.len() {
797 differences.push(format!(
798 "Attention mask length mismatch: input_ids={}, attention_mask={}",
799 current_result.input_ids.len(),
800 current_result.attention_mask.len()
801 ));
802 }
803
804 if !differences.is_empty() {
805 Some(RegressionDetail {
806 test_case: test_case.name.clone(),
807 expected: format!(
808 "Proper tokenization within {} ms",
809 test_case.max_execution_time.as_millis()
810 ),
811 actual: format!("Issues detected: {:?}", differences),
812 difference: differences.join("; "),
813 })
814 } else {
815 None
816 }
817 }
818}
819
/// Compares the outputs of several tokenizers on shared test inputs.
pub struct CrossValidationRunner {
    // Stored for parity with the other runners; not read by the comparison
    // logic in this file.
    #[allow(dead_code)]
    config: TestConfig,
}
825
826impl CrossValidationRunner {
827 pub fn new(config: TestConfig) -> Self {
828 Self { config }
829 }
830
831 pub fn compare_tokenizers(
833 &self,
834 tokenizers: Vec<(&str, &dyn Tokenizer)>,
835 test_cases: &[String],
836 ) -> Result<CrossValidationResults> {
837 let mut inconsistencies = Vec::new();
838 let mut total_comparisons = 0;
839 let mut consistent_comparisons = 0;
840
841 for text in test_cases {
842 total_comparisons += 1;
843 let mut results = HashMap::new();
844
845 for (name, tokenizer) in &tokenizers {
847 match tokenizer.encode(text) {
848 Ok(encoded) => {
849 let tokens: Vec<String> = encoded
850 .input_ids
851 .iter()
852 .filter_map(|&id| tokenizer.id_to_token(id))
853 .collect();
854 results.insert(name.to_string(), tokens);
855 },
856 Err(_) => {
857 continue;
859 },
860 }
861 }
862
863 if results.len() > 1 {
865 let first_result = match results.values().next() {
866 Some(result) => result,
867 None => continue, };
869 let is_consistent = results.values().all(|tokens| tokens == first_result);
870
871 if is_consistent {
872 consistent_comparisons += 1;
873 } else {
874 let severity = self.determine_inconsistency_severity(&results);
875 inconsistencies.push(InconsistencyDetail {
876 input: text.clone(),
877 tokenizer_results: results,
878 severity,
879 });
880 }
881 }
882 }
883
884 let consistency_score = if total_comparisons > 0 {
885 consistent_comparisons as f64 / total_comparisons as f64
886 } else {
887 0.0
888 };
889
890 Ok(CrossValidationResults {
891 tokenizers_compared: tokenizers.iter().map(|(name, _)| name.to_string()).collect(),
892 consistency_score,
893 inconsistencies_found: inconsistencies.len(),
894 inconsistency_details: inconsistencies,
895 })
896 }
897
898 fn determine_inconsistency_severity(
899 &self,
900 results: &HashMap<String, Vec<String>>,
901 ) -> InconsistencySeverity {
902 let token_counts: Vec<usize> = results.values().map(|tokens| tokens.len()).collect();
904 let min_count = *token_counts.iter().min().unwrap_or(&0);
905 let max_count = *token_counts.iter().max().unwrap_or(&0);
906
907 if max_count == 0 {
908 InconsistencySeverity::Low
909 } else {
910 let ratio = min_count as f64 / max_count as f64;
911 if ratio < 0.5 {
912 InconsistencySeverity::High
913 } else if ratio < 0.8 {
914 InconsistencySeverity::Medium
915 } else {
916 InconsistencySeverity::Low
917 }
918 }
919 }
920}
921
/// Stateless helpers for rendering and summarizing suite results.
pub struct TestReportUtils;
924
925impl TestReportUtils {
926 pub fn generate_report(result: &TestSuiteResult) -> String {
928 let mut report = String::new();
929
930 report.push_str("=== COMPREHENSIVE TEST REPORT ===\n\n");
931
932 report.push_str(&format!("Total Tests: {}\n", result.total_tests));
934 report.push_str(&format!(
935 "Passed: {} ({:.1}%)\n",
936 result.passed_tests,
937 (result.passed_tests as f64 / result.total_tests as f64) * 100.0
938 ));
939 report.push_str(&format!(
940 "Failed: {} ({:.1}%)\n\n",
941 result.failed_tests,
942 (result.failed_tests as f64 / result.total_tests as f64) * 100.0
943 ));
944
945 if result.failed_tests > 0 {
947 report.push_str("FAILED TESTS:\n");
948 for test in &result.test_results {
949 if !test.passed {
950 report.push_str(&format!(
951 " {} - {}\n",
952 test.test_case,
953 test.error.as_ref().unwrap_or(&"Unknown error".to_string())
954 ));
955 }
956 }
957 report.push('\n');
958 }
959
960 if let Some(ref benchmarks) = result.benchmark_results {
962 report.push_str("PERFORMANCE BENCHMARKS:\n");
963 report.push_str(&format!(
964 " Encoding: {:.0} tokens/sec\n",
965 benchmarks.encode_tokens_per_second
966 ));
967 report.push_str(&format!(
968 " Decoding: {:.0} tokens/sec\n",
969 benchmarks.decode_tokens_per_second
970 ));
971 report.push_str(&format!(
972 " Memory Usage: {:.1} MB\n\n",
973 benchmarks.memory_usage_mb
974 ));
975 }
976
977 if let Some(ref fuzzing) = result.fuzzing_results {
979 report.push_str("FUZZING RESULTS:\n");
980 report.push_str(&format!(" Tests Run: {}\n", fuzzing.tests_run));
981 report.push_str(&format!(
982 " Crashes Detected: {}\n",
983 fuzzing.crashes_detected
984 ));
985 report.push_str(&format!(
986 " Unique Error Types: {}\n\n",
987 fuzzing.error_types.len()
988 ));
989 }
990
991 if let Some(ref cross_val) = result.cross_validation_results {
993 report.push_str("CROSS-VALIDATION RESULTS:\n");
994 report.push_str(&format!(
995 " Consistency Score: {:.3}\n",
996 cross_val.consistency_score
997 ));
998 report.push_str(&format!(
999 " Inconsistencies Found: {}\n\n",
1000 cross_val.inconsistencies_found
1001 ));
1002 }
1003
1004 report
1005 }
1006
1007 pub fn analyze_metrics(results: &[TestResult]) -> HashMap<String, f64> {
1009 let mut analysis = HashMap::new();
1010
1011 let mut total_time = Duration::new(0, 0);
1012 let mut total_tokens = 0.0;
1013 let mut compression_ratios = Vec::new();
1014
1015 for result in results {
1016 total_time += result.execution_time;
1017
1018 if let Some(&tokens) = result.metrics.get("num_tokens") {
1019 total_tokens += tokens;
1020 }
1021
1022 if let Some(&ratio) = result.metrics.get("compression_ratio") {
1023 compression_ratios.push(ratio);
1024 }
1025 }
1026
1027 analysis.insert(
1028 "avg_execution_time_ms".to_string(),
1029 total_time.as_millis() as f64 / results.len() as f64,
1030 );
1031 analysis.insert(
1032 "avg_tokens_per_test".to_string(),
1033 total_tokens / results.len() as f64,
1034 );
1035
1036 if !compression_ratios.is_empty() {
1037 let avg_compression =
1038 compression_ratios.iter().sum::<f64>() / compression_ratios.len() as f64;
1039 analysis.insert("avg_compression_ratio".to_string(), avg_compression);
1040 }
1041
1042 analysis
1043 }
1044}
1045
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Minimal deterministic tokenizer used to exercise the harness: splits on
    /// whitespace and keeps only tokens present in a fixed 4-entry vocabulary.
    #[derive(Clone)]
    struct MockTokenizer {
        vocab: HashMap<String, u32>,
    }

    impl MockTokenizer {
        fn new() -> Self {
            let mut vocab = HashMap::new();
            vocab.insert("hello".to_string(), 1);
            vocab.insert("world".to_string(), 2);
            vocab.insert("test".to_string(), 3);
            vocab.insert("!".to_string(), 4);

            Self { vocab }
        }
    }

    impl Tokenizer for MockTokenizer {
        // Whitespace-split; out-of-vocabulary tokens are silently dropped.
        fn encode(&self, text: &str) -> Result<TokenizedInput> {
            let tokens: Vec<&str> = text.split_whitespace().collect();
            let mut input_ids = Vec::new();
            let mut token_strings = Vec::new();

            for token in tokens {
                if let Some(&id) = self.vocab.get(token) {
                    input_ids.push(id);
                    token_strings.push(token.to_string());
                }
            }

            Ok(TokenizedInput {
                input_ids,
                attention_mask: vec![1; token_strings.len()],
                token_type_ids: None,
                special_tokens_mask: None,
                offset_mapping: None,
                overflowing_tokens: None,
            })
        }

        // Reverse-lookup each id through the vocab; unknown ids are an error.
        fn decode(&self, token_ids: &[u32]) -> Result<String> {
            let tokens: Result<Vec<String>> = token_ids
                .iter()
                .map(|&id| {
                    self.vocab.iter().find(|(_, &v)| v == id).map(|(k, _)| k.clone()).ok_or_else(
                        || TrustformersError::other(format!("Unknown token ID: {}", id)),
                    )
                })
                .collect();

            Ok(tokens?.join(" "))
        }

        fn get_vocab(&self) -> HashMap<String, u32> {
            self.vocab.clone()
        }

        fn token_to_id(&self, token: &str) -> Option<u32> {
            self.vocab.get(token).copied()
        }

        fn id_to_token(&self, id: u32) -> Option<String> {
            self.vocab.iter().find(|(_, &v)| v == id).map(|(k, _)| k.clone())
        }

        // Pair encoding is simulated by joining the two texts with a space.
        fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<TokenizedInput> {
            let combined = format!("{} {}", text_a, text_b);
            self.encode(&combined)
        }

        fn vocab_size(&self) -> usize {
            self.vocab.len()
        }
    }

    // Defaults must enable all heavyweight phases.
    #[test]
    fn test_config_default() {
        let config = TestConfig::default();
        assert_eq!(config.num_random_tests, 1000);
        assert!(config.run_benchmarks);
        assert!(config.run_fuzzing);
    }

    // A seeded generator must produce non-empty, bounded inputs.
    #[test]
    fn test_case_generator() {
        let config = TestConfig::default();
        let mut generator = TestCaseGenerator::new(config, Some(42));

        let random_text = generator.generate_random_text();
        assert!(!random_text.is_empty());

        let unicode_text = generator.generate_unicode_text();
        assert!(!unicode_text.is_empty());

        // Longest built-in edge case is the 1000-char token.
        let edge_case = generator.generate_edge_case_text();
        assert!(edge_case.len() <= 1000);
    }

    // The basic phase should run and pass at least one case on the mock.
    #[test]
    fn test_basic_functionality() {
        let config = TestConfig::default();
        let mut runner = TestRunner::new(config);
        let tokenizer = MockTokenizer::new();

        let results = runner.run_basic_tests(&tokenizer, "test").expect("Operation failed in test");
        assert!(!results.is_empty());

        let passed_count = results.iter().filter(|r| r.passed).count();
        assert!(passed_count > 0);
    }

    // A successful cycle must report the core metrics.
    #[test]
    fn test_encode_decode_cycle() {
        let config = TestConfig::default();
        let runner = TestRunner::new(config);
        let tokenizer = MockTokenizer::new();

        let metrics = runner
            .test_encode_decode_cycle(&tokenizer, "hello world")
            .expect("Operation failed in test");
        assert!(metrics.contains_key("num_tokens"));
        assert!(metrics.contains_key("input_length"));
    }

    // Two identical tokenizers should yield a valid consistency score.
    #[test]
    fn test_cross_validation() {
        let config = TestConfig::default();
        let validator = CrossValidationRunner::new(config);

        let tokenizer1 = MockTokenizer::new();
        let tokenizer2 = MockTokenizer::new();

        let tokenizers: Vec<(&str, &dyn Tokenizer)> =
            vec![("mock1", &tokenizer1), ("mock2", &tokenizer2)];

        let test_cases = vec!["hello world".to_string(), "test".to_string()];
        let results = validator
            .compare_tokenizers(tokenizers, &test_cases)
            .expect("Operation failed in test");

        assert_eq!(results.tokenizers_compared.len(), 2);
        assert!(results.consistency_score >= 0.0 && results.consistency_score <= 1.0);
    }

    // The rendered report must include the headline totals.
    #[test]
    fn test_report_generation() {
        let test_result = TestSuiteResult {
            total_tests: 10,
            passed_tests: 8,
            failed_tests: 2,
            test_results: vec![
                TestResult {
                    test_case: "test1".to_string(),
                    passed: true,
                    error: None,
                    execution_time: Duration::from_millis(10),
                    metrics: HashMap::new(),
                },
                TestResult {
                    test_case: "test2".to_string(),
                    passed: false,
                    error: Some("Test failed".to_string()),
                    execution_time: Duration::from_millis(5),
                    metrics: HashMap::new(),
                },
            ],
            benchmark_results: None,
            fuzzing_results: None,
            regression_results: None,
            cross_validation_results: None,
        };

        let report = TestReportUtils::generate_report(&test_result);
        assert!(report.contains("Total Tests: 10"));
        assert!(report.contains("Passed: 8"));
        assert!(report.contains("Failed: 2"));
    }

    // Metric aggregation must produce the three expected average keys.
    #[test]
    fn test_metrics_analysis() {
        let results = vec![
            TestResult {
                test_case: "test1".to_string(),
                passed: true,
                error: None,
                execution_time: Duration::from_millis(10),
                metrics: {
                    let mut m = HashMap::new();
                    m.insert("num_tokens".to_string(), 5.0);
                    m.insert("compression_ratio".to_string(), 0.8);
                    m
                },
            },
            TestResult {
                test_case: "test2".to_string(),
                passed: true,
                error: None,
                execution_time: Duration::from_millis(20),
                metrics: {
                    let mut m = HashMap::new();
                    m.insert("num_tokens".to_string(), 3.0);
                    m.insert("compression_ratio".to_string(), 1.2);
                    m
                },
            },
        ];

        let analysis = TestReportUtils::analyze_metrics(&results);
        assert!(analysis.contains_key("avg_execution_time_ms"));
        assert!(analysis.contains_key("avg_tokens_per_test"));
        assert!(analysis.contains_key("avg_compression_ratio"));
    }
}
1266}