//! Distributed Weight Loader (trustformers_models/weight_loading/distributed.rs)
//!
//! This module provides distributed weight loading capabilities across multiple nodes
//! with load balancing, fault tolerance, and caching.
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

use tokio::io::AsyncReadExt;

use trustformers_core::{
    errors::{runtime_error, Result, TrustformersError},
    tensor::Tensor,
};

use super::config::{
    CacheStrategy, DistributedConfig, FaultToleranceConfig, LoadBalancingStrategy, NodeConfig,
    WeightLoadingConfig,
};
use super::huggingface::{TensorMetadata, WeightLoader};

/// Distributed weight loader for loading across multiple nodes
///
/// Coordinates per-node loaders, a shared in-memory tensor cache, strategy-based
/// load balancing, health monitoring, and load statistics.
pub struct DistributedWeightLoader {
    /// Base weight-loading configuration (e.g. the `streaming` flag consulted
    /// in `load_from_node`).
    config: WeightLoadingConfig,
    /// Cluster topology plus fault-tolerance, network, and caching settings.
    distributed_config: DistributedConfig,
    /// Lazily created loaders, keyed by `"node_id:file_path"`
    /// (see `load_tensor_from_node`).
    local_loaders: HashMap<String, Box<dyn WeightLoader>>,
    /// Open TCP connections keyed by node id. NOTE(review): connections are
    /// established in `connect_to_node`, but no traffic is sent over them yet.
    node_connections: HashMap<String, tokio::net::TcpStream>,
    /// Selects which node serves each tensor load.
    load_balancer: LoadBalancer,
    /// Tracks per-node health; populated by `start_monitoring`.
    health_monitor: HealthMonitor,
    /// Shared in-memory tensor cache guarded by a blocking mutex.
    tensor_cache: Arc<Mutex<HashMap<String, Tensor>>>,
    /// Counters and timings for cache and load activity.
    stats: DistributedStats,
}

33impl DistributedWeightLoader {
34    /// Create a new distributed weight loader
35    pub fn new(config: WeightLoadingConfig, distributed_config: DistributedConfig) -> Result<Self> {
36        let load_balancer =
37            LoadBalancer::new(&distributed_config.load_balancer, &distributed_config.nodes)?;
38        let health_monitor = HealthMonitor::new(&distributed_config.fault_tolerance)?;
39
40        Ok(Self {
41            config,
42            distributed_config,
43            local_loaders: HashMap::new(),
44            node_connections: HashMap::new(),
45            load_balancer,
46            health_monitor,
47            tensor_cache: Arc::new(Mutex::new(HashMap::new())),
48            stats: DistributedStats::new(),
49        })
50    }
51
52    /// Initialize connections to all nodes
53    pub async fn initialize(&mut self) -> Result<()> {
54        for node in &self.distributed_config.nodes.clone() {
55            if let Err(e) = self.connect_to_node(node).await {
56                if !self.distributed_config.fault_tolerance.enable_failover {
57                    return Err(e);
58                }
59                // Log warning but continue with other nodes
60                eprintln!("Warning: Failed to connect to node {}: {}", node.id, e);
61            }
62        }
63
64        // Start health monitoring
65        self.health_monitor.start_monitoring(&self.distributed_config.nodes).await?;
66
67        Ok(())
68    }
69
70    /// Connect to a specific node
71    async fn connect_to_node(&mut self, node: &NodeConfig) -> Result<()> {
72        let address = format!("{}:{}", node.address, node.port);
73        let timeout_duration = self.distributed_config.network.connection_timeout;
74
75        let stream =
76            tokio::time::timeout(timeout_duration, tokio::net::TcpStream::connect(&address))
77                .await
78                .map_err(|_| {
79                    TrustformersError::runtime_error(format!("Connection to {} timed out", address))
80                })?
81                .map_err(|e| {
82                    TrustformersError::io_error(format!("Failed to connect to {}: {}", address, e))
83                })?;
84
85        self.node_connections.insert(node.id.clone(), stream);
86        Ok(())
87    }
88
89    /// Load tensor with distributed strategy
90    async fn load_tensor_distributed(&mut self, name: &str) -> Result<Tensor> {
91        // Check local cache first
92        if let Some(tensor) = self.check_cache(name).await? {
93            self.stats.cache_hits += 1;
94            return Ok(tensor);
95        }
96
97        self.stats.cache_misses += 1;
98
99        // Select optimal node for loading
100        let selected_node =
101            self.load_balancer.select_node(name, &self.distributed_config.nodes)?.clone();
102
103        // Attempt to load from selected node with fault tolerance
104        let mut attempts = 0;
105        let max_retries = self.distributed_config.fault_tolerance.max_retries;
106
107        loop {
108            match self.load_from_node(&selected_node, name).await {
109                Ok(tensor) => {
110                    // Cache the tensor if enabled
111                    if self.should_cache(name) {
112                        self.cache_tensor(name, &tensor).await?;
113                    }
114
115                    self.stats.successful_loads += 1;
116                    return Ok(tensor);
117                },
118                Err(e) => {
119                    attempts += 1;
120                    self.stats.failed_loads += 1;
121
122                    if attempts >= max_retries {
123                        if self.distributed_config.fault_tolerance.enable_failover {
124                            // Try backup nodes
125                            if let Some(backup_node) = self.find_backup_node(&selected_node.id) {
126                                return self.load_from_node(&backup_node, name).await;
127                            }
128                        }
129                        return Err(e);
130                    }
131
132                    // Wait before retry
133                    tokio::time::sleep(self.distributed_config.fault_tolerance.retry_delay).await;
134                },
135            }
136        }
137    }
138
139    /// Load tensor from a specific node
140    async fn load_from_node(&mut self, node: &NodeConfig, name: &str) -> Result<Tensor> {
141        let start_time = Instant::now();
142
143        // Find tensor file on the node
144        let file_path = self.find_tensor_on_node(node, name)?;
145
146        // Load tensor based on file type and configuration
147        let tensor = if self.config.streaming {
148            self.stream_tensor_from_node(node, &file_path, name).await?
149        } else {
150            self.load_tensor_from_node(node, &file_path, name).await?
151        };
152
153        let load_time = start_time.elapsed();
154        self.stats.total_load_time += load_time;
155        self.stats.node_load_times.entry(node.id.clone()).or_default().push(load_time);
156
157        Ok(tensor)
158    }
159
160    /// Find tensor file on a specific node
161    fn find_tensor_on_node(&self, node: &NodeConfig, name: &str) -> Result<PathBuf> {
162        for storage_path in &node.storage_paths {
163            let potential_files = vec![
164                storage_path.join(format!("{}.safetensors", name)),
165                storage_path.join(format!("{}.bin", name)),
166                storage_path.join("pytorch_model.bin"),
167                storage_path.join("model.safetensors"),
168            ];
169
170            for file_path in potential_files {
171                if file_path.exists() {
172                    return Ok(file_path);
173                }
174            }
175        }
176
177        Err(runtime_error(format!(
178            "Tensor {} not found on node {}",
179            name, node.id
180        )))
181    }
182
183    /// Load tensor from node using appropriate loader
184    async fn load_tensor_from_node(
185        &mut self,
186        node: &NodeConfig,
187        file_path: &PathBuf,
188        name: &str,
189    ) -> Result<Tensor> {
190        // Create appropriate loader for the node
191        let loader_key = format!("{}:{}", node.id, file_path.to_string_lossy());
192
193        if !self.local_loaders.contains_key(&loader_key) {
194            let loader = super::create_huggingface_loader(
195                file_path.parent().unwrap_or(file_path),
196                Some(self.config.clone()),
197            )?;
198            self.local_loaders.insert(loader_key.clone(), loader);
199        }
200
201        let loader = self.local_loaders.get_mut(&loader_key).ok_or_else(|| {
202            TrustformersError::runtime_error(format!(
203                "Loader for {} not found after insertion",
204                loader_key
205            ))
206        })?;
207        loader.load_tensor(name)
208    }
209
210    /// Stream tensor from node in chunks
211    async fn stream_tensor_from_node(
212        &mut self,
213        _node: &NodeConfig,
214        file_path: &PathBuf,
215        name: &str,
216    ) -> Result<Tensor> {
217        // Open file for streaming
218        let mut file = tokio::fs::File::open(file_path).await.map_err(|e| {
219            TrustformersError::file_not_found(format!(
220                "Failed to open {}: {}",
221                file_path.display(),
222                e
223            ))
224        })?;
225
226        // For simplicity, load the entire tensor
227        // In practice, this would stream chunks and reconstruct the tensor
228        let mut buffer = Vec::new();
229        file.read_to_end(&mut buffer)
230            .await
231            .map_err(|e| TrustformersError::io_error(e.to_string()))?;
232
233        // Parse tensor from bytes (simplified)
234        self.parse_tensor_from_bytes(buffer, name)
235    }
236
237    /// Parse tensor from raw bytes
238    fn parse_tensor_from_bytes(&self, data: Vec<u8>, _name: &str) -> Result<Tensor> {
239        // Simplified tensor parsing - in practice this would handle different formats
240        if data.len() < 4 {
241            return Err(TrustformersError::invalid_format_simple(
242                "Insufficient data for tensor".to_string(),
243            ));
244        }
245
246        // For demo purposes, create a simple tensor
247        let shape = vec![data.len() / 4]; // Assume f32 data
248        let floats: Vec<f32> = data
249            .chunks_exact(4)
250            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
251            .collect();
252
253        Tensor::from_vec(floats, &shape)
254    }
255
256    /// Check if tensor should be cached
257    fn should_cache(&self, _name: &str) -> bool {
258        match self.distributed_config.distributed_cache.cache_strategy {
259            CacheStrategy::None => false,
260            CacheStrategy::ReadThrough | CacheStrategy::WriteThrough | CacheStrategy::WriteBack => {
261                true
262            },
263            CacheStrategy::ReadAround => false, // Skip cache on read
264        }
265    }
266
267    /// Check cache for tensor
268    async fn check_cache(&self, name: &str) -> Result<Option<Tensor>> {
269        let cache = self
270            .tensor_cache
271            .lock()
272            .map_err(|_| TrustformersError::lock_error("Cache lock poisoned".to_string()))?;
273        Ok(cache.get(name).cloned())
274    }
275
276    /// Cache tensor with replication
277    async fn cache_tensor(&self, name: &str, tensor: &Tensor) -> Result<()> {
278        let mut cache = self
279            .tensor_cache
280            .lock()
281            .map_err(|_| TrustformersError::lock_error("Cache lock poisoned".to_string()))?;
282
283        // Apply eviction policy if cache is full
284        self.apply_eviction_policy(&mut cache)?;
285
286        cache.insert(name.to_string(), tensor.clone());
287
288        // Replicate to other nodes based on replication factor
289        if self.distributed_config.distributed_cache.replication_factor > 1 {
290            self.replicate_tensor(name, tensor).await?;
291        }
292
293        Ok(())
294    }
295
296    /// Apply cache eviction policy
297    fn apply_eviction_policy(&self, cache: &mut HashMap<String, Tensor>) -> Result<()> {
298        // Simplified eviction - remove random entry if cache is too large
299        if cache.len() > 1000 {
300            // Arbitrary limit
301            if let Some(key) = cache.keys().next().cloned() {
302                cache.remove(&key);
303            }
304        }
305        Ok(())
306    }
307
308    /// Replicate tensor to other nodes
309    async fn replicate_tensor(&self, name: &str, _tensor: &Tensor) -> Result<()> {
310        let replication_count = (self.distributed_config.distributed_cache.replication_factor
311            as usize)
312            .min(self.distributed_config.nodes.len());
313
314        // Select nodes for replication (simplified)
315        for node in self.distributed_config.nodes.iter().take(replication_count) {
316            // In practice, this would send the tensor data to the node
317            println!("Replicating tensor {} to node {}", name, node.id);
318        }
319
320        Ok(())
321    }
322
323    /// Find backup node for failover
324    fn find_backup_node(&self, failed_node_id: &str) -> Option<NodeConfig> {
325        self.distributed_config
326            .nodes
327            .iter()
328            .find(|node| {
329                node.id != failed_node_id
330                    && self.distributed_config.fault_tolerance.backup_nodes.contains(&node.id)
331            })
332            .cloned()
333    }
334
335    /// Get distributed loading statistics
336    pub fn get_stats(&self) -> &DistributedStats {
337        &self.stats
338    }
339}
340
341impl WeightLoader for DistributedWeightLoader {
342    fn load_tensor(&mut self, name: &str) -> Result<Tensor> {
343        // For sync interface, use blocking runtime
344        let rt = tokio::runtime::Runtime::new().map_err(|e| {
345            TrustformersError::runtime_error(format!("Failed to create async runtime: {}", e))
346        })?;
347
348        rt.block_on(self.load_tensor_distributed(name))
349    }
350
351    fn list_tensors(&self) -> Result<Vec<String>> {
352        // Aggregate tensor lists from all nodes
353        let mut all_tensors = Vec::new();
354
355        for loader in self.local_loaders.values() {
356            let tensors = loader.list_tensors()?;
357            all_tensors.extend(tensors);
358        }
359
360        // Remove duplicates
361        all_tensors.sort();
362        all_tensors.dedup();
363
364        Ok(all_tensors)
365    }
366
367    fn tensor_info(&self, name: &str) -> Result<Option<TensorMetadata>> {
368        // Try to get info from any available loader
369        for loader in self.local_loaders.values() {
370            if let Ok(Some(info)) = loader.tensor_info(name) {
371                return Ok(Some(info));
372            }
373        }
374        Ok(None)
375    }
376
377    fn close(&mut self) -> Result<()> {
378        // Close all local loaders
379        for loader in self.local_loaders.values_mut() {
380            loader.close()?;
381        }
382
383        // Close network connections
384        self.node_connections.clear();
385
386        Ok(())
387    }
388}
389
/// Load balancer for distributed weight loading
struct LoadBalancer {
    /// Selection strategy, copied from `DistributedConfig.load_balancer`.
    strategy: LoadBalancingStrategy,
    /// Per-node load bookkeeping, keyed by node id (used by `LeastLoaded`).
    node_states: HashMap<String, NodeState>,
    /// Monotonic counter driving round-robin selection.
    round_robin_index: usize,
}

397impl LoadBalancer {
398    fn new(strategy: &LoadBalancingStrategy, nodes: &[NodeConfig]) -> Result<Self> {
399        let mut node_states = HashMap::new();
400        for node in nodes {
401            node_states.insert(node.id.clone(), NodeState::new());
402        }
403
404        Ok(Self {
405            strategy: strategy.clone(),
406            node_states,
407            round_robin_index: 0,
408        })
409    }
410
411    fn select_node<'a>(
412        &mut self,
413        tensor_name: &str,
414        nodes: &'a [NodeConfig],
415    ) -> Result<&'a NodeConfig> {
416        match self.strategy {
417            LoadBalancingStrategy::RoundRobin => {
418                let selected = &nodes[self.round_robin_index % nodes.len()];
419                self.round_robin_index += 1;
420                Ok(selected)
421            },
422            LoadBalancingStrategy::LeastLoaded => {
423                let least_loaded_id = self
424                    .node_states
425                    .iter()
426                    .min_by_key(|(_, state)| state.current_load)
427                    .map(|(id, _)| id)
428                    .ok_or_else(|| {
429                        TrustformersError::invalid_state("No nodes available".to_string())
430                    })?;
431
432                nodes
433                    .iter()
434                    .find(|node| &node.id == least_loaded_id)
435                    .ok_or_else(|| TrustformersError::invalid_state("Node not found".to_string()))
436            },
437            LoadBalancingStrategy::ConsistentHashing => {
438                // Simple hash-based selection
439                let hash = self.hash_tensor_name(tensor_name);
440                let index = hash % nodes.len();
441                Ok(&nodes[index])
442            },
443            _ => {
444                // Fallback to round robin for other strategies
445                let selected = &nodes[self.round_robin_index % nodes.len()];
446                self.round_robin_index += 1;
447                Ok(selected)
448            },
449        }
450    }
451
452    fn hash_tensor_name(&self, name: &str) -> usize {
453        // Simple hash function
454        name.bytes().map(|b| b as usize).sum()
455    }
456}
457
/// Health monitor for tracking node health
struct HealthMonitor {
    /// Fault-tolerance settings; the leading underscore marks it as not yet
    /// read by any monitoring logic.
    _config: FaultToleranceConfig,
    /// Last known health per node id; seeded by `start_monitoring`.
    node_health: HashMap<String, NodeHealth>,
}

464impl HealthMonitor {
465    fn new(config: &FaultToleranceConfig) -> Result<Self> {
466        Ok(Self {
467            _config: config.clone(),
468            node_health: HashMap::new(),
469        })
470    }
471
472    async fn start_monitoring(&mut self, nodes: &[NodeConfig]) -> Result<()> {
473        for node in nodes {
474            self.node_health.insert(node.id.clone(), NodeHealth::new());
475        }
476
477        // Start background health checking
478        tokio::spawn(async move {
479            loop {
480                // Perform health checks
481                tokio::time::sleep(Duration::from_secs(30)).await;
482            }
483        });
484
485        Ok(())
486    }
487}
488
/// Per-node bookkeeping used by the load balancer.
struct NodeState {
    /// In-flight load currently attributed to the node (read by the
    /// `LeastLoaded` strategy).
    current_load: u64,
    /// Lifetime request counter; underscore-prefixed because nothing reads it
    /// yet.
    _total_requests: u64,
    /// Lifetime failure counter; underscore-prefixed because nothing reads it
    /// yet.
    _failed_requests: u64,
    /// Timestamp of the most recent request routed to the node.
    _last_request_time: Instant,
}

impl NodeState {
    /// A fresh, idle node: all counters zeroed, "last request" set to now.
    fn new() -> Self {
        NodeState {
            _last_request_time: Instant::now(),
            _failed_requests: 0,
            _total_requests: 0,
            current_load: 0,
        }
    }
}

/// Snapshot of a node's last observed health.
///
/// Seeded by `HealthMonitor::start_monitoring`; all fields are
/// underscore-prefixed because no monitoring logic reads them yet.
struct NodeHealth {
    /// Whether the node passed its most recent check.
    _is_healthy: bool,
    /// When the node was last probed.
    _last_check: Instant,
    /// Failures observed in a row since the last success.
    _consecutive_failures: u32,
}

impl NodeHealth {
    /// A node starts out healthy, just checked, with a zeroed failure streak.
    fn new() -> Self {
        NodeHealth {
            _consecutive_failures: 0,
            _last_check: Instant::now(),
            _is_healthy: true,
        }
    }
}

/// Statistics for distributed weight loading
#[derive(Debug, Default)]
pub struct DistributedStats {
    /// Loads served from the local tensor cache.
    pub cache_hits: u64,
    /// Loads that missed the cache and went to a node.
    pub cache_misses: u64,
    /// Loads that completed successfully.
    pub successful_loads: u64,
    /// Load attempts that returned an error (retries count individually).
    pub failed_loads: u64,
    /// Total wall-clock time spent in successful node loads.
    pub total_load_time: Duration,
    /// Individual load durations, keyed by node id.
    pub node_load_times: HashMap<String, Vec<Duration>>,
    /// Bytes moved over the network (not yet updated by any code path here).
    pub bytes_transferred: u64,
}

impl DistributedStats {
    /// Zeroed statistics.
    fn new() -> Self {
        Self::default()
    }

    /// Fraction of lookups served from cache, in `[0.0, 1.0]`; `0.0` before
    /// any lookup has happened.
    pub fn cache_hit_rate(&self) -> f64 {
        let total = self.cache_hits + self.cache_misses;
        if total == 0 {
            0.0
        } else {
            self.cache_hits as f64 / total as f64
        }
    }

    /// Fraction of load attempts that succeeded, in `[0.0, 1.0]`; `0.0`
    /// before any attempt has happened.
    pub fn success_rate(&self) -> f64 {
        let total = self.successful_loads + self.failed_loads;
        if total == 0 {
            0.0
        } else {
            self.successful_loads as f64 / total as f64
        }
    }

    /// Mean wall-clock time per successful load.
    ///
    /// Computed in floating point: the previous
    /// `total_load_time / (successful_loads as u32)` truncated the count to
    /// 32 bits, which both skews the average and panics with a division by
    /// zero whenever the count is a non-zero multiple of 2^32.
    pub fn average_load_time(&self) -> Duration {
        if self.successful_loads == 0 {
            Duration::ZERO
        } else {
            Duration::from_secs_f64(
                self.total_load_time.as_secs_f64() / self.successful_loads as f64,
            )
        }
    }
}