trustformers_models/weight_loading/
distributed.rs1use std::collections::HashMap;
6use std::path::PathBuf;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9use tokio::io::AsyncReadExt;
10use trustformers_core::{
11 errors::{runtime_error, Result, TrustformersError},
12 tensor::Tensor,
13};
14
15use super::config::{
16 CacheStrategy, DistributedConfig, FaultToleranceConfig, LoadBalancingStrategy, NodeConfig,
17 WeightLoadingConfig,
18};
19use super::huggingface::{TensorMetadata, WeightLoader};
20
/// Weight loader that fetches model tensors from a cluster of storage nodes,
/// with load balancing, retry/failover, and an in-memory tensor cache.
pub struct DistributedWeightLoader {
    // General weight-loading options (e.g. the `streaming` flag read in
    // `load_from_node`); also forwarded to per-file loaders.
    config: WeightLoadingConfig,
    // Cluster topology plus fault-tolerance, network, and cache settings.
    distributed_config: DistributedConfig,
    // Lazily created per-(node, file) loaders, keyed by "node_id:path".
    local_loaders: HashMap<String, Box<dyn WeightLoader>>,
    // Open TCP connections keyed by node id (established in `initialize`).
    node_connections: HashMap<String, tokio::net::TcpStream>,
    // Chooses which node serves a given tensor request.
    load_balancer: LoadBalancer,
    // Tracks node health; monitoring starts in `initialize`.
    health_monitor: HealthMonitor,
    // Shared cache of already-loaded tensors, keyed by tensor name.
    tensor_cache: Arc<Mutex<HashMap<String, Tensor>>>,
    // Counters and timings gathered while loading.
    stats: DistributedStats,
}
32
33impl DistributedWeightLoader {
34 pub fn new(config: WeightLoadingConfig, distributed_config: DistributedConfig) -> Result<Self> {
36 let load_balancer =
37 LoadBalancer::new(&distributed_config.load_balancer, &distributed_config.nodes)?;
38 let health_monitor = HealthMonitor::new(&distributed_config.fault_tolerance)?;
39
40 Ok(Self {
41 config,
42 distributed_config,
43 local_loaders: HashMap::new(),
44 node_connections: HashMap::new(),
45 load_balancer,
46 health_monitor,
47 tensor_cache: Arc::new(Mutex::new(HashMap::new())),
48 stats: DistributedStats::new(),
49 })
50 }
51
52 pub async fn initialize(&mut self) -> Result<()> {
54 for node in &self.distributed_config.nodes.clone() {
55 if let Err(e) = self.connect_to_node(node).await {
56 if !self.distributed_config.fault_tolerance.enable_failover {
57 return Err(e);
58 }
59 eprintln!("Warning: Failed to connect to node {}: {}", node.id, e);
61 }
62 }
63
64 self.health_monitor.start_monitoring(&self.distributed_config.nodes).await?;
66
67 Ok(())
68 }
69
70 async fn connect_to_node(&mut self, node: &NodeConfig) -> Result<()> {
72 let address = format!("{}:{}", node.address, node.port);
73 let timeout_duration = self.distributed_config.network.connection_timeout;
74
75 let stream =
76 tokio::time::timeout(timeout_duration, tokio::net::TcpStream::connect(&address))
77 .await
78 .map_err(|_| {
79 TrustformersError::runtime_error(format!("Connection to {} timed out", address))
80 })?
81 .map_err(|e| {
82 TrustformersError::io_error(format!("Failed to connect to {}: {}", address, e))
83 })?;
84
85 self.node_connections.insert(node.id.clone(), stream);
86 Ok(())
87 }
88
89 async fn load_tensor_distributed(&mut self, name: &str) -> Result<Tensor> {
91 if let Some(tensor) = self.check_cache(name).await? {
93 self.stats.cache_hits += 1;
94 return Ok(tensor);
95 }
96
97 self.stats.cache_misses += 1;
98
99 let selected_node =
101 self.load_balancer.select_node(name, &self.distributed_config.nodes)?.clone();
102
103 let mut attempts = 0;
105 let max_retries = self.distributed_config.fault_tolerance.max_retries;
106
107 loop {
108 match self.load_from_node(&selected_node, name).await {
109 Ok(tensor) => {
110 if self.should_cache(name) {
112 self.cache_tensor(name, &tensor).await?;
113 }
114
115 self.stats.successful_loads += 1;
116 return Ok(tensor);
117 },
118 Err(e) => {
119 attempts += 1;
120 self.stats.failed_loads += 1;
121
122 if attempts >= max_retries {
123 if self.distributed_config.fault_tolerance.enable_failover {
124 if let Some(backup_node) = self.find_backup_node(&selected_node.id) {
126 return self.load_from_node(&backup_node, name).await;
127 }
128 }
129 return Err(e);
130 }
131
132 tokio::time::sleep(self.distributed_config.fault_tolerance.retry_delay).await;
134 },
135 }
136 }
137 }
138
139 async fn load_from_node(&mut self, node: &NodeConfig, name: &str) -> Result<Tensor> {
141 let start_time = Instant::now();
142
143 let file_path = self.find_tensor_on_node(node, name)?;
145
146 let tensor = if self.config.streaming {
148 self.stream_tensor_from_node(node, &file_path, name).await?
149 } else {
150 self.load_tensor_from_node(node, &file_path, name).await?
151 };
152
153 let load_time = start_time.elapsed();
154 self.stats.total_load_time += load_time;
155 self.stats.node_load_times.entry(node.id.clone()).or_default().push(load_time);
156
157 Ok(tensor)
158 }
159
160 fn find_tensor_on_node(&self, node: &NodeConfig, name: &str) -> Result<PathBuf> {
162 for storage_path in &node.storage_paths {
163 let potential_files = vec![
164 storage_path.join(format!("{}.safetensors", name)),
165 storage_path.join(format!("{}.bin", name)),
166 storage_path.join("pytorch_model.bin"),
167 storage_path.join("model.safetensors"),
168 ];
169
170 for file_path in potential_files {
171 if file_path.exists() {
172 return Ok(file_path);
173 }
174 }
175 }
176
177 Err(runtime_error(format!(
178 "Tensor {} not found on node {}",
179 name, node.id
180 )))
181 }
182
183 async fn load_tensor_from_node(
185 &mut self,
186 node: &NodeConfig,
187 file_path: &PathBuf,
188 name: &str,
189 ) -> Result<Tensor> {
190 let loader_key = format!("{}:{}", node.id, file_path.to_string_lossy());
192
193 if !self.local_loaders.contains_key(&loader_key) {
194 let loader = super::create_huggingface_loader(
195 file_path.parent().unwrap_or(file_path),
196 Some(self.config.clone()),
197 )?;
198 self.local_loaders.insert(loader_key.clone(), loader);
199 }
200
201 let loader = self.local_loaders.get_mut(&loader_key).ok_or_else(|| {
202 TrustformersError::runtime_error(format!(
203 "Loader for {} not found after insertion",
204 loader_key
205 ))
206 })?;
207 loader.load_tensor(name)
208 }
209
210 async fn stream_tensor_from_node(
212 &mut self,
213 _node: &NodeConfig,
214 file_path: &PathBuf,
215 name: &str,
216 ) -> Result<Tensor> {
217 let mut file = tokio::fs::File::open(file_path).await.map_err(|e| {
219 TrustformersError::file_not_found(format!(
220 "Failed to open {}: {}",
221 file_path.display(),
222 e
223 ))
224 })?;
225
226 let mut buffer = Vec::new();
229 file.read_to_end(&mut buffer)
230 .await
231 .map_err(|e| TrustformersError::io_error(e.to_string()))?;
232
233 self.parse_tensor_from_bytes(buffer, name)
235 }
236
237 fn parse_tensor_from_bytes(&self, data: Vec<u8>, _name: &str) -> Result<Tensor> {
239 if data.len() < 4 {
241 return Err(TrustformersError::invalid_format_simple(
242 "Insufficient data for tensor".to_string(),
243 ));
244 }
245
246 let shape = vec![data.len() / 4]; let floats: Vec<f32> = data
249 .chunks_exact(4)
250 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
251 .collect();
252
253 Tensor::from_vec(floats, &shape)
254 }
255
256 fn should_cache(&self, _name: &str) -> bool {
258 match self.distributed_config.distributed_cache.cache_strategy {
259 CacheStrategy::None => false,
260 CacheStrategy::ReadThrough | CacheStrategy::WriteThrough | CacheStrategy::WriteBack => {
261 true
262 },
263 CacheStrategy::ReadAround => false, }
265 }
266
267 async fn check_cache(&self, name: &str) -> Result<Option<Tensor>> {
269 let cache = self
270 .tensor_cache
271 .lock()
272 .map_err(|_| TrustformersError::lock_error("Cache lock poisoned".to_string()))?;
273 Ok(cache.get(name).cloned())
274 }
275
276 async fn cache_tensor(&self, name: &str, tensor: &Tensor) -> Result<()> {
278 let mut cache = self
279 .tensor_cache
280 .lock()
281 .map_err(|_| TrustformersError::lock_error("Cache lock poisoned".to_string()))?;
282
283 self.apply_eviction_policy(&mut cache)?;
285
286 cache.insert(name.to_string(), tensor.clone());
287
288 if self.distributed_config.distributed_cache.replication_factor > 1 {
290 self.replicate_tensor(name, tensor).await?;
291 }
292
293 Ok(())
294 }
295
296 fn apply_eviction_policy(&self, cache: &mut HashMap<String, Tensor>) -> Result<()> {
298 if cache.len() > 1000 {
300 if let Some(key) = cache.keys().next().cloned() {
302 cache.remove(&key);
303 }
304 }
305 Ok(())
306 }
307
308 async fn replicate_tensor(&self, name: &str, _tensor: &Tensor) -> Result<()> {
310 let replication_count = (self.distributed_config.distributed_cache.replication_factor
311 as usize)
312 .min(self.distributed_config.nodes.len());
313
314 for node in self.distributed_config.nodes.iter().take(replication_count) {
316 println!("Replicating tensor {} to node {}", name, node.id);
318 }
319
320 Ok(())
321 }
322
323 fn find_backup_node(&self, failed_node_id: &str) -> Option<NodeConfig> {
325 self.distributed_config
326 .nodes
327 .iter()
328 .find(|node| {
329 node.id != failed_node_id
330 && self.distributed_config.fault_tolerance.backup_nodes.contains(&node.id)
331 })
332 .cloned()
333 }
334
335 pub fn get_stats(&self) -> &DistributedStats {
337 &self.stats
338 }
339}
340
341impl WeightLoader for DistributedWeightLoader {
342 fn load_tensor(&mut self, name: &str) -> Result<Tensor> {
343 let rt = tokio::runtime::Runtime::new().map_err(|e| {
345 TrustformersError::runtime_error(format!("Failed to create async runtime: {}", e))
346 })?;
347
348 rt.block_on(self.load_tensor_distributed(name))
349 }
350
351 fn list_tensors(&self) -> Result<Vec<String>> {
352 let mut all_tensors = Vec::new();
354
355 for loader in self.local_loaders.values() {
356 let tensors = loader.list_tensors()?;
357 all_tensors.extend(tensors);
358 }
359
360 all_tensors.sort();
362 all_tensors.dedup();
363
364 Ok(all_tensors)
365 }
366
367 fn tensor_info(&self, name: &str) -> Result<Option<TensorMetadata>> {
368 for loader in self.local_loaders.values() {
370 if let Ok(Some(info)) = loader.tensor_info(name) {
371 return Ok(Some(info));
372 }
373 }
374 Ok(None)
375 }
376
377 fn close(&mut self) -> Result<()> {
378 for loader in self.local_loaders.values_mut() {
380 loader.close()?;
381 }
382
383 self.node_connections.clear();
385
386 Ok(())
387 }
388}
389
/// Selects which node serves a tensor request, according to the configured
/// [`LoadBalancingStrategy`].
struct LoadBalancer {
    // Strategy cloned from the distributed configuration at construction.
    strategy: LoadBalancingStrategy,
    // Per-node bookkeeping keyed by node id (used by `LeastLoaded`).
    node_states: HashMap<String, NodeState>,
    // Monotonically increasing cursor for round-robin selection.
    round_robin_index: usize,
}
396
397impl LoadBalancer {
398 fn new(strategy: &LoadBalancingStrategy, nodes: &[NodeConfig]) -> Result<Self> {
399 let mut node_states = HashMap::new();
400 for node in nodes {
401 node_states.insert(node.id.clone(), NodeState::new());
402 }
403
404 Ok(Self {
405 strategy: strategy.clone(),
406 node_states,
407 round_robin_index: 0,
408 })
409 }
410
411 fn select_node<'a>(
412 &mut self,
413 tensor_name: &str,
414 nodes: &'a [NodeConfig],
415 ) -> Result<&'a NodeConfig> {
416 match self.strategy {
417 LoadBalancingStrategy::RoundRobin => {
418 let selected = &nodes[self.round_robin_index % nodes.len()];
419 self.round_robin_index += 1;
420 Ok(selected)
421 },
422 LoadBalancingStrategy::LeastLoaded => {
423 let least_loaded_id = self
424 .node_states
425 .iter()
426 .min_by_key(|(_, state)| state.current_load)
427 .map(|(id, _)| id)
428 .ok_or_else(|| {
429 TrustformersError::invalid_state("No nodes available".to_string())
430 })?;
431
432 nodes
433 .iter()
434 .find(|node| &node.id == least_loaded_id)
435 .ok_or_else(|| TrustformersError::invalid_state("Node not found".to_string()))
436 },
437 LoadBalancingStrategy::ConsistentHashing => {
438 let hash = self.hash_tensor_name(tensor_name);
440 let index = hash % nodes.len();
441 Ok(&nodes[index])
442 },
443 _ => {
444 let selected = &nodes[self.round_robin_index % nodes.len()];
446 self.round_robin_index += 1;
447 Ok(selected)
448 },
449 }
450 }
451
452 fn hash_tensor_name(&self, name: &str) -> usize {
453 name.bytes().map(|b| b as usize).sum()
455 }
456}
457
/// Tracks per-node health records and owns the (currently placeholder)
/// background monitoring task.
struct HealthMonitor {
    // Fault-tolerance settings, retained for future use (underscore-prefixed:
    // not read anywhere in this file yet).
    _config: FaultToleranceConfig,
    // Health record per node id, populated by `start_monitoring`.
    node_health: HashMap<String, NodeHealth>,
}
463
464impl HealthMonitor {
465 fn new(config: &FaultToleranceConfig) -> Result<Self> {
466 Ok(Self {
467 _config: config.clone(),
468 node_health: HashMap::new(),
469 })
470 }
471
472 async fn start_monitoring(&mut self, nodes: &[NodeConfig]) -> Result<()> {
473 for node in nodes {
474 self.node_health.insert(node.id.clone(), NodeHealth::new());
475 }
476
477 tokio::spawn(async move {
479 loop {
480 tokio::time::sleep(Duration::from_secs(30)).await;
482 }
483 });
484
485 Ok(())
486 }
487}
488
/// Per-node bookkeeping consulted by the load balancer.
struct NodeState {
    // Load figure compared by the `LeastLoaded` strategy.
    current_load: u64,
    // The underscore-prefixed fields are recorded for future use and are not
    // read anywhere in this file yet.
    _total_requests: u64,
    _failed_requests: u64,
    _last_request_time: Instant,
}

impl NodeState {
    /// Fresh record for an idle node: zero load, zero request history,
    /// timestamped at creation.
    fn new() -> Self {
        NodeState {
            current_load: 0,
            _total_requests: 0,
            _failed_requests: 0,
            _last_request_time: Instant::now(),
        }
    }
}
507
/// Health record kept per node by the `HealthMonitor`.
struct NodeHealth {
    // All fields are underscore-prefixed: they are recorded but not yet read
    // by any code in this file (monitoring is a placeholder).
    _is_healthy: bool,
    _last_check: Instant,
    _consecutive_failures: u32,
}

impl NodeHealth {
    /// Optimistic initial record: the node is presumed healthy with no
    /// failures, checked "now".
    fn new() -> Self {
        NodeHealth {
            _is_healthy: true,
            _last_check: Instant::now(),
            _consecutive_failures: 0,
        }
    }
}
524
/// Counters and timing data accumulated by `DistributedWeightLoader`.
#[derive(Debug, Default)]
pub struct DistributedStats {
    // Tensor requests served straight from the in-memory cache.
    pub cache_hits: u64,
    // Tensor requests that had to go to a node.
    pub cache_misses: u64,
    // Completed node loads.
    pub successful_loads: u64,
    // Failed node load attempts (each retry counts separately).
    pub failed_loads: u64,
    // Total wall-clock time spent in node loads.
    pub total_load_time: Duration,
    // Per-node history of individual load durations.
    pub node_load_times: HashMap<String, Vec<Duration>>,
    // Reserved for byte accounting; not currently updated in this file.
    pub bytes_transferred: u64,
}

impl DistributedStats {
    /// Zero-initialized statistics.
    fn new() -> Self {
        Self::default()
    }

    /// Fraction of requests served from cache; 0.0 when nothing was requested.
    pub fn cache_hit_rate(&self) -> f64 {
        let total = self.cache_hits + self.cache_misses;
        if total == 0 {
            0.0
        } else {
            self.cache_hits as f64 / total as f64
        }
    }

    /// Fraction of node loads that succeeded; 0.0 when none were attempted.
    pub fn success_rate(&self) -> f64 {
        let total = self.successful_loads + self.failed_loads;
        if total == 0 {
            0.0
        } else {
            self.successful_loads as f64 / total as f64
        }
    }

    /// Mean duration of a load; `Duration::ZERO` when none succeeded.
    pub fn average_load_time(&self) -> Duration {
        if self.successful_loads == 0 {
            Duration::ZERO
        } else {
            // BUGFIX: saturate instead of `as u32`. The old truncating cast
            // could turn a count that is a multiple of 2^32 into 0, and
            // `Duration / 0u32` panics with a division by zero.
            let divisor = u32::try_from(self.successful_loads).unwrap_or(u32::MAX);
            self.total_load_time / divisor
        }
    }
}
567}