// trustformers_models/weight_loading/config.rs
use std::path::PathBuf;
use std::time::Duration;
6
/// Weight file formats that a loading configuration can name.
///
/// Variant spellings (including the upper-case acronyms `ONNX` and `GGUF`)
/// are part of the public interface and are kept as-is.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WeightFormat {
    /// HuggingFace `.bin` checkpoint (presumably a PyTorch pickle; confirm
    /// against the loader).
    HuggingFaceBin,
    /// SafeTensors file.
    SafeTensors,
    /// ONNX model file.
    ONNX,
    /// TensorFlow weights.
    TensorFlow,
    /// GGUF file.
    GGUF,
    /// Any other format, identified by a caller-supplied name.
    Custom(String),
}
17
18#[derive(Debug, Clone)]
20pub struct WeightLoadingConfig {
21 pub format: Option<WeightFormat>,
22 pub lazy_loading: bool,
23 pub memory_mapped: bool,
24 pub streaming: bool,
25 pub device: String,
26 pub dtype: WeightDataType,
27 pub quantization: Option<QuantizationConfig>,
28 pub cache_dir: Option<PathBuf>,
29 pub verify_checksums: bool,
30 pub distributed: Option<DistributedConfig>,
31}
32
33#[derive(Debug, Clone)]
35pub struct DistributedConfig {
36 pub nodes: Vec<NodeConfig>,
38 pub load_balancer: LoadBalancingStrategy,
40 pub fault_tolerance: FaultToleranceConfig,
42 pub network: NetworkConfig,
44 pub distributed_cache: DistributedCacheConfig,
46 pub compression: bool,
48}
49
/// Description of a single node.
///
/// Equality is field-wise. Because `bandwidth` is an `f64`, a value holding
/// `NaN` never compares equal to anything, itself included.
#[derive(Debug, Clone, PartialEq)]
pub struct NodeConfig {
    /// Node identifier.
    pub id: String,
    /// Node address (host name or IP; the expected form is not fixed here).
    pub address: String,
    /// Port on `address`.
    pub port: u16,
    /// Weight capacity of the node (unit not stated in this file; presumably
    /// bytes, to be confirmed).
    pub weight_capacity: u64,
    /// Bandwidth of the node (unit not stated in this file).
    pub bandwidth: f64,
    /// Node priority (whether higher or lower wins is not defined here).
    pub priority: u8,
    /// Storage paths associated with the node.
    pub storage_paths: Vec<PathBuf>,
}
61
/// Named load-balancing strategies.
///
/// This type only names the strategies; their behaviour is implemented by
/// whatever consumes the configuration.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LoadBalancingStrategy {
    /// Round-robin selection.
    RoundRobin,
    /// Prefer the least-loaded node.
    LeastLoaded,
    /// Weighted round-robin selection.
    WeightedRoundRobin,
    /// Consistent hashing.
    ConsistentHashing,
    /// Locality-aware selection.
    LocalityAware,
    /// Performance-based selection.
    PerformanceBased,
    /// Adaptive selection.
    Adaptive,
}
73
/// Retry, timeout and failover settings.
///
/// The `Default` implementation in this file uses 3 retries, a 1000 ms retry
/// delay, a 30 s timeout, failover enabled, a 60 s health-check interval and
/// no backup nodes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FaultToleranceConfig {
    /// Maximum number of retries.
    pub max_retries: u32,
    /// Delay between retries.
    pub retry_delay: Duration,
    /// Operation timeout.
    pub timeout: Duration,
    /// Whether failover is enabled.
    pub enable_failover: bool,
    /// Interval between health checks.
    pub health_check_interval: Duration,
    /// Backup nodes (presumably node ids; verify against the consumer).
    pub backup_nodes: Vec<String>,
}
84
/// Connection-level settings.
///
/// The `Default` implementation in this file uses 10 concurrent connections,
/// a 30 s connection timeout, a 60 s read timeout, a chunk size of 8192,
/// keep-alive enabled and a compression threshold of `1024 * 1024`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NetworkConfig {
    /// Maximum number of concurrent connections.
    pub max_concurrent_connections: usize,
    /// Timeout for establishing a connection.
    pub connection_timeout: Duration,
    /// Timeout for a read.
    pub read_timeout: Duration,
    /// Chunk size (unit not stated in this file; presumably bytes, to be
    /// confirmed).
    pub chunk_size: usize,
    /// Whether keep-alive is enabled.
    pub enable_keepalive: bool,
    /// Compression threshold (unit not stated in this file; presumably bytes,
    /// to be confirmed).
    pub compression_threshold: usize,
}
95
96#[derive(Debug, Clone)]
98pub struct DistributedCacheConfig {
99 pub cache_strategy: CacheStrategy,
100 pub replication_factor: u8, pub eviction_policy: CacheEvictionPolicy,
102 pub consistency_level: ConsistencyLevel,
103}
104
/// Named caching strategies.
///
/// The `None` variant is unrelated to `Option::None`; write it as
/// `CacheStrategy::None` to keep call sites unambiguous.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CacheStrategy {
    /// No caching strategy selected.
    None,
    /// Read-through.
    ReadThrough,
    /// Write-through.
    WriteThrough,
    /// Write-back.
    WriteBack,
    /// Read-around.
    ReadAround,
}
113
/// Named cache eviction policies.
///
/// Variant spellings are part of the public interface and are kept as-is.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CacheEvictionPolicy {
    /// Least recently used.
    LRU,
    /// Least frequently used.
    LFU,
    /// First in, first out.
    FIFO,
    /// Random eviction.
    Random,
    /// Time-to-live based eviction.
    TTL,
}
122
/// Named consistency levels.
///
/// This type only names the levels; the guarantees behind each name are
/// defined by whatever consumes the configuration.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConsistencyLevel {
    /// Strong consistency.
    Strong,
    /// Eventual consistency.
    Eventual,
    /// Weak consistency.
    Weak,
}
129
/// Data types that can be selected for weights.
///
/// Derives `PartialEq`, `Eq` and `Hash` like the other enums in this file, so
/// a selected dtype can be compared and used as a map or set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WeightDataType {
    /// 32-bit float.
    Float32,
    /// 16-bit float.
    Float16,
    /// bfloat16.
    BFloat16,
    /// 8-bit integer.
    Int8,
    /// 4-bit integer.
    Int4,
}
138
/// Quantization settings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QuantizationConfig {
    /// Bit width.
    pub bits: u8,
    /// Optional group size; `None` leaves it unspecified.
    pub group_size: Option<usize>,
    /// Whether symmetric quantization is selected.
    pub symmetric: bool,
}
145
146impl Default for WeightLoadingConfig {
147 fn default() -> Self {
148 Self {
149 format: None,
150 lazy_loading: false,
151 memory_mapped: false,
152 streaming: false,
153 device: "cpu".to_string(),
154 dtype: WeightDataType::Float32,
155 quantization: None,
156 cache_dir: None,
157 verify_checksums: true,
158 distributed: None,
159 }
160 }
161}
162
163impl Default for FaultToleranceConfig {
164 fn default() -> Self {
165 Self {
166 max_retries: 3,
167 retry_delay: Duration::from_millis(1000),
168 timeout: Duration::from_secs(30),
169 enable_failover: true,
170 health_check_interval: Duration::from_secs(60),
171 backup_nodes: Vec::new(),
172 }
173 }
174}
175
176impl Default for NetworkConfig {
177 fn default() -> Self {
178 Self {
179 max_concurrent_connections: 10,
180 connection_timeout: Duration::from_secs(30),
181 read_timeout: Duration::from_secs(60),
182 chunk_size: 8192,
183 enable_keepalive: true,
184 compression_threshold: 1024 * 1024, }
186 }
187}
188
189impl Default for DistributedCacheConfig {
190 fn default() -> Self {
191 Self {
192 cache_strategy: CacheStrategy::ReadThrough,
193 replication_factor: 2,
194 eviction_policy: CacheEvictionPolicy::LRU,
195 consistency_level: ConsistencyLevel::Eventual,
196 }
197 }
198}