// File: trustformers_models/weight_loading/streaming.rs
1/// Streaming Weight Loader
2///
3/// This module provides streaming weight loading for models that don't fit entirely in memory.
4/// It loads and manages chunks of tensors, evicting old chunks as needed using an LRU strategy.
5use std::collections::{HashMap, VecDeque};
6use std::path::PathBuf;
7use std::sync::{Arc, Mutex};
8use trustformers_core::{
9    errors::{ErrorKind, Result, TrustformersError},
10    tensor::{DType, Tensor},
11};
12
13use super::config::WeightLoadingConfig;
14use super::huggingface::{HuggingFaceLoader, TensorMetadata, WeightLoader};
15
/// Chunk information for streaming loader
///
/// Describes one group of tensors that is loaded and evicted as a unit.
#[derive(Debug, Clone)]
pub struct ChunkInfo {
    /// Sequential id assigned while grouping tensors in `initialize`.
    pub chunk_id: usize,
    /// Names of the tensors that belong to this chunk.
    pub tensor_names: Vec<String>,
    /// Estimated chunk size in bytes (sum of the tensors' `size_bytes` metadata).
    pub memory_usage: usize,
    /// Time of the most recent access; refreshed on load and cache hits.
    pub last_accessed: std::time::Instant,
}
24
/// Streaming weight loader for models that don't fit in memory
///
/// Tensors are partitioned into chunks during `initialize`; chunks are loaded
/// on demand and evicted in least-recently-used order when the memory budget
/// would be exceeded.
pub struct StreamingLoader {
    // Configuration forwarded to the underlying HuggingFace loader.
    config: WeightLoadingConfig,
    // Location of the model on disk.
    model_path: PathBuf,
    // Target maximum size of a single chunk, in bytes.
    chunk_size: usize,
    // Memory budget for all resident tensors, in bytes.
    max_memory_usage: usize,
    // Every known chunk (loaded or not), keyed by chunk id.
    current_chunks: HashMap<usize, ChunkInfo>,
    // Reverse index: tensor name -> id of the chunk containing it.
    tensor_to_chunk: HashMap<String, usize>,
    // Tensors currently resident in memory.
    loaded_tensors: HashMap<String, Tensor>,
    // LRU order of loaded chunk ids; front = most recently used.
    chunk_access_order: VecDeque<usize>,
    // Running byte total for `loaded_tensors`.
    total_memory_usage: usize,
    // Per-tensor metadata cached once during `initialize`.
    tensor_metadata_cache: HashMap<String, TensorMetadata>,
    // Underlying loader; `None` until `initialize` succeeds.
    underlying_loader: Arc<Mutex<Option<HuggingFaceLoader>>>,
}
39
impl StreamingLoader {
    /// Create a new, uninitialized streaming loader.
    ///
    /// * `chunk_size` — target maximum size of one chunk, in bytes.
    /// * `max_memory_mb` — memory budget in megabytes (stored as bytes).
    ///
    /// No I/O happens here; call [`StreamingLoader::initialize`] before use.
    pub fn new(
        config: WeightLoadingConfig,
        model_path: PathBuf,
        chunk_size: usize,
        max_memory_mb: usize,
    ) -> Self {
        Self {
            config,
            model_path,
            chunk_size,
            max_memory_usage: max_memory_mb * 1024 * 1024, // Convert MB to bytes
            current_chunks: HashMap::new(),
            tensor_to_chunk: HashMap::new(),
            loaded_tensors: HashMap::new(),
            chunk_access_order: VecDeque::new(),
            total_memory_usage: 0,
            tensor_metadata_cache: HashMap::new(),
            underlying_loader: Arc::new(Mutex::new(None)),
        }
    }

    /// Initialize the streaming loader by analyzing the model structure.
    ///
    /// Opens the underlying [`HuggingFaceLoader`], caches per-tensor metadata,
    /// and greedily packs tensors (in listing order) into chunks of at most
    /// `chunk_size` bytes. A single tensor larger than `chunk_size` still gets
    /// a chunk of its own, because a chunk is only closed when non-empty.
    ///
    /// # Errors
    /// Propagates loader construction/inspection failures and lock poisoning.
    pub fn initialize(&mut self) -> Result<()> {
        // Create underlying loader to analyze tensor structure
        let loader = HuggingFaceLoader::new(&self.model_path, self.config.clone())?;

        // Get all tensor information
        let tensor_names = loader.list_tensors()?;
        let mut current_chunk_id = 0;
        let mut current_chunk_size: usize = 0;
        let mut current_chunk_tensors = Vec::new();

        // Group tensors into chunks based on size.
        // NOTE(review): tensors for which `tensor_info` returns None are
        // silently skipped and never become loadable — confirm this is intended.
        for tensor_name in tensor_names {
            if let Some(metadata) = loader.tensor_info(&tensor_name)? {
                self.tensor_metadata_cache.insert(tensor_name.clone(), metadata.clone());

                // Check if adding this tensor would exceed chunk size
                if current_chunk_size + (metadata.size_bytes as usize) > self.chunk_size
                    && !current_chunk_tensors.is_empty()
                {
                    // Finalize current chunk
                    self.finalize_chunk(
                        current_chunk_id,
                        current_chunk_tensors,
                        current_chunk_size,
                    );

                    // Start new chunk
                    current_chunk_id += 1;
                    current_chunk_tensors = Vec::new();
                    current_chunk_size = 0;
                }

                current_chunk_tensors.push(tensor_name.clone());
                current_chunk_size += metadata.size_bytes as usize;
                self.tensor_to_chunk.insert(tensor_name, current_chunk_id);
            }
        }

        // Finalize the last chunk
        if !current_chunk_tensors.is_empty() {
            self.finalize_chunk(current_chunk_id, current_chunk_tensors, current_chunk_size);
        }

        // Store the underlying loader for later use
        *self
            .underlying_loader
            .lock()
            .map_err(|e| TrustformersError::io_error(format!("Failed to acquire lock: {}", e)))? =
            Some(loader);

        Ok(())
    }

    /// Record a finished chunk in `current_chunks`. Metadata only — no tensor
    /// data is loaded here.
    fn finalize_chunk(&mut self, chunk_id: usize, tensor_names: Vec<String>, memory_usage: usize) {
        let chunk_info = ChunkInfo {
            chunk_id,
            tensor_names,
            memory_usage,
            last_accessed: std::time::Instant::now(),
        };
        self.current_chunks.insert(chunk_id, chunk_info);
    }

    /// Load every tensor of `chunk_id` into memory, evicting LRU chunks first
    /// if the budget would be exceeded. If the chunk is already loaded, only
    /// its LRU position is refreshed.
    ///
    /// NOTE(review): a chunk larger than `max_memory_usage` is still loaded
    /// once all other chunks have been evicted, overrunning the budget.
    ///
    /// # Errors
    /// Fails if the chunk id is unknown, the loader was never initialized,
    /// the lock is poisoned, or an underlying tensor load fails (in which case
    /// earlier tensors of the chunk remain cached but the chunk is not marked
    /// loaded).
    pub fn load_chunk(&mut self, chunk_id: usize) -> Result<()> {
        // Check if chunk is already loaded
        if self.chunk_access_order.contains(&chunk_id) {
            // Move to front of access order (most recently used)
            self.chunk_access_order.retain(|&x| x != chunk_id);
            self.chunk_access_order.push_front(chunk_id);
            return Ok(());
        }

        // Get chunk info
        let chunk_info = self
            .current_chunks
            .get(&chunk_id)
            .ok_or_else(|| {
                TrustformersError::invalid_operation(format!("Chunk {} not found", chunk_id))
            })?
            .clone();

        // Calculate memory needed for this chunk.
        // NOTE(review): this eviction check uses the metadata-based estimate,
        // while the running total below is computed from actual tensor shapes —
        // the two can drift slightly.
        let chunk_memory = chunk_info.memory_usage;

        // Evict chunks if necessary to make room
        while self.total_memory_usage + chunk_memory > self.max_memory_usage
            && !self.chunk_access_order.is_empty()
        {
            let oldest_chunk = self.chunk_access_order.pop_back().ok_or_else(|| {
                TrustformersError::runtime_error(
                    "chunk_access_order unexpectedly empty".to_string(),
                )
            })?;
            self.evict_chunk_internal(oldest_chunk)?;
        }

        // Load tensors for this chunk
        let mut loader_guard = self
            .underlying_loader
            .lock()
            .map_err(|e| TrustformersError::io_error(format!("Failed to acquire lock: {}", e)))?;
        if let Some(loader) = loader_guard.as_mut() {
            for tensor_name in &chunk_info.tensor_names {
                let tensor = loader.load_tensor(tensor_name)?;
                let tensor_size = self.calculate_tensor_memory_usage(&tensor);
                self.loaded_tensors.insert(tensor_name.clone(), tensor);
                self.total_memory_usage += tensor_size;
            }
        } else {
            return Err(TrustformersError::invalid_operation(
                "Streaming loader not initialized".to_string(),
            ));
        }

        // Add to access order (most recently used at front)
        self.chunk_access_order.push_front(chunk_id);

        // Update chunk access time
        if let Some(chunk) = self.current_chunks.get_mut(&chunk_id) {
            chunk.last_accessed = std::time::Instant::now();
        }

        Ok(())
    }

    /// Estimate the in-memory footprint of `tensor`:
    /// element count × bytes-per-element for its dtype.
    fn calculate_tensor_memory_usage(&self, tensor: &Tensor) -> usize {
        let element_count: usize = tensor.shape().iter().product();
        let bytes_per_element = match tensor.dtype() {
            DType::F32 => 4,
            DType::F16 => 2,
            DType::F64 => 8,
            DType::I32 => 4,
            DType::I64 => 8,
            _ => 4, // Default to 4 bytes; undercounts wider unknown dtypes
        };
        element_count * bytes_per_element
    }

    /// Evict a chunk's tensors from memory. No-op for unknown or unloaded ids.
    pub fn evict_chunk(&mut self, chunk_id: usize) -> Result<()> {
        self.evict_chunk_internal(chunk_id)
    }

    /// Remove a chunk's tensors from the cache, decrement the memory total,
    /// and drop the chunk from the LRU order. The chunk's metadata stays in
    /// `current_chunks`, so it can be reloaded later.
    fn evict_chunk_internal(&mut self, chunk_id: usize) -> Result<()> {
        // Get chunk info
        if let Some(chunk_info) = self.current_chunks.get(&chunk_id) {
            let tensor_names = chunk_info.tensor_names.clone();

            // Remove tensors from memory and update memory usage
            for tensor_name in &tensor_names {
                if let Some(tensor) = self.loaded_tensors.remove(tensor_name) {
                    let tensor_size = self.calculate_tensor_memory_usage(&tensor);
                    // saturating_sub guards against accounting drift going negative.
                    self.total_memory_usage = self.total_memory_usage.saturating_sub(tensor_size);
                }
            }

            // Remove from access order
            self.chunk_access_order.retain(|&x| x != chunk_id);
        }

        Ok(())
    }

    /// Current memory usage of all resident tensors, in bytes.
    pub fn get_memory_usage(&self) -> usize {
        self.total_memory_usage
    }

    /// Get memory usage as a percentage of maximum allowed
    /// (0.0 when the budget is zero, avoiding a division by zero).
    pub fn get_memory_usage_percentage(&self) -> f32 {
        if self.max_memory_usage == 0 {
            0.0
        } else {
            (self.total_memory_usage as f32 / self.max_memory_usage as f32) * 100.0
        }
    }

    /// Get detailed memory statistics as a snapshot.
    pub fn get_memory_stats(&self) -> MemoryStats {
        MemoryStats {
            current_usage_bytes: self.total_memory_usage,
            max_usage_bytes: self.max_memory_usage,
            usage_percentage: self.get_memory_usage_percentage(),
            loaded_chunks: self.chunk_access_order.len(),
            total_chunks: self.current_chunks.len(),
            loaded_tensors: self.loaded_tensors.len(),
        }
    }

    /// Whether the chunk is currently resident (i.e. present in the LRU order).
    pub fn is_chunk_loaded(&self, chunk_id: usize) -> bool {
        self.chunk_access_order.contains(&chunk_id)
    }

    /// Get information about all chunks (loaded and unloaded), in no
    /// particular order.
    pub fn get_chunk_info(&self) -> Vec<ChunkInfo> {
        self.current_chunks.values().cloned().collect()
    }

    /// Force garbage collection of unused tensors.
    ///
    /// Currently only recomputes `total_memory_usage` from the tensors that
    /// are actually resident, correcting any accounting drift; it does not
    /// free memory.
    pub fn garbage_collect(&mut self) -> Result<()> {
        // This could be enhanced with more sophisticated GC logic
        // For now, just ensure memory usage tracking is accurate
        let mut actual_usage = 0;
        for tensor in self.loaded_tensors.values() {
            actual_usage += self.calculate_tensor_memory_usage(tensor);
        }
        self.total_memory_usage = actual_usage;
        Ok(())
    }
}
271
272impl WeightLoader for StreamingLoader {
273    fn load_tensor(&mut self, name: &str) -> Result<Tensor> {
274        // Check if tensor is already loaded
275        if let Some(tensor) = self.loaded_tensors.get(name) {
276            // Update access time for the chunk containing this tensor
277            if let Some(&chunk_id) = self.tensor_to_chunk.get(name) {
278                if let Some(chunk) = self.current_chunks.get_mut(&chunk_id) {
279                    chunk.last_accessed = std::time::Instant::now();
280                }
281                // Move chunk to front of access order
282                self.chunk_access_order.retain(|&x| x != chunk_id);
283                self.chunk_access_order.push_front(chunk_id);
284            }
285            return Ok(tensor.clone());
286        }
287
288        // Find which chunk contains this tensor
289        if let Some(&chunk_id) = self.tensor_to_chunk.get(name) {
290            // Load the chunk
291            self.load_chunk(chunk_id)?;
292
293            // Now the tensor should be loaded
294            if let Some(tensor) = self.loaded_tensors.get(name) {
295                Ok(tensor.clone())
296            } else {
297                Err(TrustformersError::new(ErrorKind::WeightLoadingError {
298                    reason: format!("Failed to load tensor {} from chunk {}", name, chunk_id),
299                }))
300            }
301        } else {
302            Err(TrustformersError::new(ErrorKind::WeightLoadingError {
303                reason: format!("Tensor {} not found in any chunk", name),
304            }))
305        }
306    }
307
308    fn list_tensors(&self) -> Result<Vec<String>> {
309        // Return all tensor names across all chunks
310        Ok(self.tensor_to_chunk.keys().cloned().collect())
311    }
312
313    fn tensor_info(&self, name: &str) -> Result<Option<TensorMetadata>> {
314        Ok(self.tensor_metadata_cache.get(name).cloned())
315    }
316
317    fn close(&mut self) -> Result<()> {
318        // Clear all loaded data
319        self.loaded_tensors.clear();
320        self.chunk_access_order.clear();
321        self.total_memory_usage = 0;
322
323        // Close underlying loader
324        if let Some(mut loader) = self.underlying_loader.lock().expect("operation failed").take() {
325            loader.close()?;
326        }
327
328        Ok(())
329    }
330}
331
332/// Additional methods for streaming-specific functionality
333impl StreamingLoader {
334    /// List only currently loaded tensors
335    pub fn list_loaded_tensors(&self) -> Vec<String> {
336        self.loaded_tensors.keys().cloned().collect()
337    }
338}
339
/// Memory usage statistics for the streaming loader
///
/// Snapshot produced by `StreamingLoader::get_memory_stats`.
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Bytes currently occupied by resident tensors.
    pub current_usage_bytes: usize,
    /// Configured memory budget, in bytes.
    pub max_usage_bytes: usize,
    /// `current / max` as a percentage (0.0 when the budget is zero).
    pub usage_percentage: f32,
    /// Number of chunks currently resident in memory.
    pub loaded_chunks: usize,
    /// Total number of chunks known to the loader.
    pub total_chunks: usize,
    /// Number of individual tensors currently resident.
    pub loaded_tensors: usize,
}
350
351impl std::fmt::Display for MemoryStats {
352    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
353        write!(
354            f,
355            "Memory Usage: {:.1}% ({} / {} bytes), Chunks: {} / {}, Tensors: {}",
356            self.usage_percentage,
357            self.current_usage_bytes,
358            self.max_usage_bytes,
359            self.loaded_chunks,
360            self.total_chunks,
361            self.loaded_tensors
362        )
363    }
364}