Skip to main content

trustformers_models/comprehensive_testing/
config.rs

1//! Configuration types for comprehensive testing framework
2
3use serde::{Deserialize, Serialize};
4
5/// Configuration for model validation tests
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ValidationConfig {
8    /// Numerical tolerance for comparisons
9    pub numerical_tolerance: f32,
10    /// Whether to run performance benchmarks
11    pub run_performance_tests: bool,
12    /// Whether to compare against reference implementations
13    pub compare_with_reference: bool,
14    /// Maximum acceptable inference time (milliseconds)
15    pub max_inference_time_ms: u64,
16    /// Maximum acceptable memory usage (MB)
17    pub max_memory_usage_mb: u64,
18    /// Test input configurations
19    pub test_inputs: Vec<TestInputConfig>,
20    /// Supported data types for testing
21    pub test_data_types: Vec<TestDataType>,
22}
23
24/// Test input configuration
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct TestInputConfig {
27    /// Input name/description
28    pub name: String,
29    /// Input dimensions
30    pub dimensions: Vec<usize>,
31    /// Input data type
32    pub data_type: TestDataType,
33    /// Whether this is a required test
34    pub required: bool,
35}
36
37/// Supported data types for testing
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
39pub enum TestDataType {
40    F32,
41    F16,
42    I32,
43    I64,
44}
45
46impl Default for ValidationConfig {
47    fn default() -> Self {
48        Self {
49            numerical_tolerance: 1e-4,
50            run_performance_tests: true,
51            compare_with_reference: false,
52            max_inference_time_ms: 10000,
53            max_memory_usage_mb: 16384,
54            test_inputs: vec![
55                TestInputConfig {
56                    name: "small_batch".to_string(),
57                    dimensions: vec![1, 128],
58                    data_type: TestDataType::I32,
59                    required: true,
60                },
61                TestInputConfig {
62                    name: "medium_batch".to_string(),
63                    dimensions: vec![4, 256],
64                    data_type: TestDataType::I32,
65                    required: true,
66                },
67                TestInputConfig {
68                    name: "large_batch".to_string(),
69                    dimensions: vec![16, 512],
70                    data_type: TestDataType::I32,
71                    required: false,
72                },
73            ],
74            test_data_types: vec![TestDataType::F32, TestDataType::F16],
75        }
76    }
77}