// trustformers_models/comprehensive_testing/config.rs
use serde::{Deserialize, Serialize};

5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ValidationConfig {
8 pub numerical_tolerance: f32,
10 pub run_performance_tests: bool,
12 pub compare_with_reference: bool,
14 pub max_inference_time_ms: u64,
16 pub max_memory_usage_mb: u64,
18 pub test_inputs: Vec<TestInputConfig>,
20 pub test_data_types: Vec<TestDataType>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct TestInputConfig {
27 pub name: String,
29 pub dimensions: Vec<usize>,
31 pub data_type: TestDataType,
33 pub required: bool,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
39pub enum TestDataType {
40 F32,
41 F16,
42 I32,
43 I64,
44}
45
46impl Default for ValidationConfig {
47 fn default() -> Self {
48 Self {
49 numerical_tolerance: 1e-4,
50 run_performance_tests: true,
51 compare_with_reference: false,
52 max_inference_time_ms: 10000,
53 max_memory_usage_mb: 16384,
54 test_inputs: vec![
55 TestInputConfig {
56 name: "small_batch".to_string(),
57 dimensions: vec![1, 128],
58 data_type: TestDataType::I32,
59 required: true,
60 },
61 TestInputConfig {
62 name: "medium_batch".to_string(),
63 dimensions: vec![4, 256],
64 data_type: TestDataType::I32,
65 required: true,
66 },
67 TestInputConfig {
68 name: "large_batch".to_string(),
69 dimensions: vec![16, 512],
70 data_type: TestDataType::I32,
71 required: false,
72 },
73 ],
74 test_data_types: vec![TestDataType::F32, TestDataType::F16],
75 }
76 }
77}