trustformers_models/weight_loading/
streaming.rs1use std::collections::{HashMap, VecDeque};
6use std::path::PathBuf;
7use std::sync::{Arc, Mutex};
8use trustformers_core::{
9 errors::{ErrorKind, Result, TrustformersError},
10 tensor::{DType, Tensor},
11};
12
13use super::config::WeightLoadingConfig;
14use super::huggingface::{HuggingFaceLoader, TensorMetadata, WeightLoader};
15
/// Bookkeeping record for one group ("chunk") of tensors that is loaded and
/// evicted as a unit by `StreamingLoader`.
#[derive(Debug, Clone)]
pub struct ChunkInfo {
    /// Sequential id assigned when the chunk plan is built in `initialize`.
    pub chunk_id: usize,
    /// Names of the tensors grouped into this chunk.
    pub tensor_names: Vec<String>,
    /// Planned memory footprint of the chunk in bytes (summed from tensor metadata).
    pub memory_usage: usize,
    /// Last time this chunk was loaded or touched; updated for LRU bookkeeping.
    pub last_accessed: std::time::Instant,
}
24
/// Loads model weights on demand in chunks, keeping total resident memory
/// under a configurable budget by evicting least-recently-used chunks.
pub struct StreamingLoader {
    // Loader configuration forwarded to the underlying HuggingFace loader.
    config: WeightLoadingConfig,
    // Path to the model on disk.
    model_path: PathBuf,
    // Target maximum size (bytes) of a single chunk when planning.
    chunk_size: usize,
    // Hard memory budget in bytes (converted from MB in `new`).
    max_memory_usage: usize,
    // Chunk plan: chunk id -> bookkeeping info (covers all chunks, loaded or not).
    current_chunks: HashMap<usize, ChunkInfo>,
    // Reverse index: tensor name -> id of the chunk that contains it.
    tensor_to_chunk: HashMap<String, usize>,
    // Tensors currently resident in memory.
    loaded_tensors: HashMap<String, Tensor>,
    // LRU order of *loaded* chunks; front = most recently used, back = eviction candidate.
    chunk_access_order: VecDeque<usize>,
    // Running total of bytes held in `loaded_tensors` (computed from tensor shapes).
    total_memory_usage: usize,
    // Metadata captured during `initialize` for cheap `tensor_info` lookups.
    tensor_metadata_cache: HashMap<String, TensorMetadata>,
    // Underlying loader, created in `initialize`; `None` until then (and after `close`).
    underlying_loader: Arc<Mutex<Option<HuggingFaceLoader>>>,
}
39
impl StreamingLoader {
    /// Creates a new streaming loader.
    ///
    /// `chunk_size` is the target size in bytes for one chunk of tensors;
    /// `max_memory_mb` is the overall resident-memory budget in megabytes.
    /// No I/O happens here; call `initialize` before loading tensors.
    pub fn new(
        config: WeightLoadingConfig,
        model_path: PathBuf,
        chunk_size: usize,
        max_memory_mb: usize,
    ) -> Self {
        Self {
            config,
            model_path,
            chunk_size,
            // The budget is tracked internally in bytes.
            max_memory_usage: max_memory_mb * 1024 * 1024,
            current_chunks: HashMap::new(),
            tensor_to_chunk: HashMap::new(),
            loaded_tensors: HashMap::new(),
            chunk_access_order: VecDeque::new(),
            total_memory_usage: 0,
            tensor_metadata_cache: HashMap::new(),
            underlying_loader: Arc::new(Mutex::new(None)),
        }
    }

    /// Opens the underlying HuggingFace loader, caches per-tensor metadata,
    /// and partitions all tensors into chunks of at most `chunk_size` bytes.
    /// A single tensor larger than `chunk_size` gets a chunk of its own.
    ///
    /// # Errors
    /// Propagates loader-construction and metadata errors; fails if the
    /// internal mutex is poisoned.
    pub fn initialize(&mut self) -> Result<()> {
        let loader = HuggingFaceLoader::new(&self.model_path, self.config.clone())?;

        let tensor_names = loader.list_tensors()?;
        let mut current_chunk_id = 0;
        let mut current_chunk_size: usize = 0;
        let mut current_chunk_tensors = Vec::new();

        for tensor_name in tensor_names {
            // Tensors without metadata are skipped entirely: they end up in
            // no chunk and are therefore unloadable via this loader.
            if let Some(metadata) = loader.tensor_info(&tensor_name)? {
                self.tensor_metadata_cache.insert(tensor_name.clone(), metadata.clone());

                // Close out the current chunk when this tensor would overflow
                // it — unless the chunk is still empty (oversized tensors get
                // their own chunk instead of looping forever).
                if current_chunk_size + (metadata.size_bytes as usize) > self.chunk_size
                    && !current_chunk_tensors.is_empty()
                {
                    self.finalize_chunk(
                        current_chunk_id,
                        current_chunk_tensors,
                        current_chunk_size,
                    );

                    current_chunk_id += 1;
                    current_chunk_tensors = Vec::new();
                    current_chunk_size = 0;
                }

                current_chunk_tensors.push(tensor_name.clone());
                current_chunk_size += metadata.size_bytes as usize;
                self.tensor_to_chunk.insert(tensor_name, current_chunk_id);
            }
        }

        // Flush the trailing, partially-filled chunk.
        if !current_chunk_tensors.is_empty() {
            self.finalize_chunk(current_chunk_id, current_chunk_tensors, current_chunk_size);
        }

        // Store the loader for later on-demand chunk loads.
        *self
            .underlying_loader
            .lock()
            .map_err(|e| TrustformersError::io_error(format!("Failed to acquire lock: {}", e)))? =
            Some(loader);

        Ok(())
    }

    /// Records a finished chunk in the chunk plan (does not load anything).
    fn finalize_chunk(&mut self, chunk_id: usize, tensor_names: Vec<String>, memory_usage: usize) {
        let chunk_info = ChunkInfo {
            chunk_id,
            tensor_names,
            memory_usage,
            last_accessed: std::time::Instant::now(),
        };
        self.current_chunks.insert(chunk_id, chunk_info);
    }

    /// Loads every tensor of `chunk_id` into memory, evicting least-recently-
    /// used chunks first when the memory budget would be exceeded.
    ///
    /// If the chunk is already resident this only refreshes its LRU position.
    ///
    /// # Errors
    /// Fails when the chunk id is unknown, the loader is uninitialized, the
    /// mutex is poisoned, or an underlying tensor read fails.
    pub fn load_chunk(&mut self, chunk_id: usize) -> Result<()> {
        // Already resident: promote to most-recently-used and return.
        if self.chunk_access_order.contains(&chunk_id) {
            self.chunk_access_order.retain(|&x| x != chunk_id);
            self.chunk_access_order.push_front(chunk_id);
            return Ok(());
        }

        let chunk_info = self
            .current_chunks
            .get(&chunk_id)
            .ok_or_else(|| {
                TrustformersError::invalid_operation(format!("Chunk {} not found", chunk_id))
            })?
            .clone();

        // Planned size of the incoming chunk (from tensor metadata).
        let chunk_memory = chunk_info.memory_usage;

        // Evict LRU chunks (back of the deque) until the new chunk fits.
        // NOTE(review): this check uses the metadata-derived chunk size while
        // `total_memory_usage` below accumulates sizes recomputed from tensor
        // shapes; when the two disagree the budget is only approximate.
        while self.total_memory_usage + chunk_memory > self.max_memory_usage
            && !self.chunk_access_order.is_empty()
        {
            let oldest_chunk = self.chunk_access_order.pop_back().ok_or_else(|| {
                TrustformersError::runtime_error(
                    "chunk_access_order unexpectedly empty".to_string(),
                )
            })?;
            self.evict_chunk_internal(oldest_chunk)?;
        }

        let mut loader_guard = self
            .underlying_loader
            .lock()
            .map_err(|e| TrustformersError::io_error(format!("Failed to acquire lock: {}", e)))?;
        if let Some(loader) = loader_guard.as_mut() {
            for tensor_name in &chunk_info.tensor_names {
                let tensor = loader.load_tensor(tensor_name)?;
                let tensor_size = self.calculate_tensor_memory_usage(&tensor);
                self.loaded_tensors.insert(tensor_name.clone(), tensor);
                self.total_memory_usage += tensor_size;
            }
        } else {
            // `initialize` has not been called (or `close` already ran).
            return Err(TrustformersError::invalid_operation(
                "Streaming loader not initialized".to_string(),
            ));
        }

        // The newly loaded chunk becomes most-recently-used.
        self.chunk_access_order.push_front(chunk_id);

        if let Some(chunk) = self.current_chunks.get_mut(&chunk_id) {
            chunk.last_accessed = std::time::Instant::now();
        }

        Ok(())
    }

    /// Estimates a tensor's memory footprint as element count × dtype width.
    ///
    /// Unknown dtypes fall back to 4 bytes per element; container overhead is
    /// not counted, so this is an approximation.
    fn calculate_tensor_memory_usage(&self, tensor: &Tensor) -> usize {
        let element_count: usize = tensor.shape().iter().product();
        let bytes_per_element = match tensor.dtype() {
            DType::F32 => 4,
            DType::F16 => 2,
            DType::F64 => 8,
            DType::I32 => 4,
            DType::I64 => 8,
            // Conservative default for dtypes not listed above.
            _ => 4,
        };
        element_count * bytes_per_element
    }

    /// Evicts a chunk by id, releasing its tensors from memory.
    pub fn evict_chunk(&mut self, chunk_id: usize) -> Result<()> {
        self.evict_chunk_internal(chunk_id)
    }

    /// Removes all of a chunk's tensors from memory and drops the chunk from
    /// the LRU order. Unknown chunk ids are a no-op; the chunk-plan entry is
    /// kept so the chunk can be reloaded later.
    fn evict_chunk_internal(&mut self, chunk_id: usize) -> Result<()> {
        if let Some(chunk_info) = self.current_chunks.get(&chunk_id) {
            let tensor_names = chunk_info.tensor_names.clone();

            for tensor_name in &tensor_names {
                if let Some(tensor) = self.loaded_tensors.remove(tensor_name) {
                    let tensor_size = self.calculate_tensor_memory_usage(&tensor);
                    // `saturating_sub` guards against accounting drift underflowing.
                    self.total_memory_usage = self.total_memory_usage.saturating_sub(tensor_size);
                }
            }

            self.chunk_access_order.retain(|&x| x != chunk_id);
        }

        Ok(())
    }

    /// Current resident memory in bytes, as tracked by internal accounting.
    pub fn get_memory_usage(&self) -> usize {
        self.total_memory_usage
    }

    /// Current usage as a percentage of the budget; 0.0 when the budget is 0.
    pub fn get_memory_usage_percentage(&self) -> f32 {
        if self.max_memory_usage == 0 {
            0.0
        } else {
            (self.total_memory_usage as f32 / self.max_memory_usage as f32) * 100.0
        }
    }

    /// Snapshot of memory and chunk counters for monitoring.
    pub fn get_memory_stats(&self) -> MemoryStats {
        MemoryStats {
            current_usage_bytes: self.total_memory_usage,
            max_usage_bytes: self.max_memory_usage,
            usage_percentage: self.get_memory_usage_percentage(),
            loaded_chunks: self.chunk_access_order.len(),
            total_chunks: self.current_chunks.len(),
            loaded_tensors: self.loaded_tensors.len(),
        }
    }

    /// Returns true when the chunk is currently resident (present in the LRU order).
    pub fn is_chunk_loaded(&self, chunk_id: usize) -> bool {
        self.chunk_access_order.contains(&chunk_id)
    }

    /// Returns a copy of every chunk's bookkeeping info (order unspecified).
    pub fn get_chunk_info(&self) -> Vec<ChunkInfo> {
        self.current_chunks.values().cloned().collect()
    }

    /// Recomputes `total_memory_usage` from the tensors actually resident,
    /// correcting any drift accumulated by the approximate accounting.
    pub fn garbage_collect(&mut self) -> Result<()> {
        let mut actual_usage = 0;
        for tensor in self.loaded_tensors.values() {
            actual_usage += self.calculate_tensor_memory_usage(tensor);
        }
        self.total_memory_usage = actual_usage;
        Ok(())
    }
}
271
272impl WeightLoader for StreamingLoader {
273 fn load_tensor(&mut self, name: &str) -> Result<Tensor> {
274 if let Some(tensor) = self.loaded_tensors.get(name) {
276 if let Some(&chunk_id) = self.tensor_to_chunk.get(name) {
278 if let Some(chunk) = self.current_chunks.get_mut(&chunk_id) {
279 chunk.last_accessed = std::time::Instant::now();
280 }
281 self.chunk_access_order.retain(|&x| x != chunk_id);
283 self.chunk_access_order.push_front(chunk_id);
284 }
285 return Ok(tensor.clone());
286 }
287
288 if let Some(&chunk_id) = self.tensor_to_chunk.get(name) {
290 self.load_chunk(chunk_id)?;
292
293 if let Some(tensor) = self.loaded_tensors.get(name) {
295 Ok(tensor.clone())
296 } else {
297 Err(TrustformersError::new(ErrorKind::WeightLoadingError {
298 reason: format!("Failed to load tensor {} from chunk {}", name, chunk_id),
299 }))
300 }
301 } else {
302 Err(TrustformersError::new(ErrorKind::WeightLoadingError {
303 reason: format!("Tensor {} not found in any chunk", name),
304 }))
305 }
306 }
307
308 fn list_tensors(&self) -> Result<Vec<String>> {
309 Ok(self.tensor_to_chunk.keys().cloned().collect())
311 }
312
313 fn tensor_info(&self, name: &str) -> Result<Option<TensorMetadata>> {
314 Ok(self.tensor_metadata_cache.get(name).cloned())
315 }
316
317 fn close(&mut self) -> Result<()> {
318 self.loaded_tensors.clear();
320 self.chunk_access_order.clear();
321 self.total_memory_usage = 0;
322
323 if let Some(mut loader) = self.underlying_loader.lock().expect("operation failed").take() {
325 loader.close()?;
326 }
327
328 Ok(())
329 }
330}
331
332impl StreamingLoader {
334 pub fn list_loaded_tensors(&self) -> Vec<String> {
336 self.loaded_tensors.keys().cloned().collect()
337 }
338}
339
/// Point-in-time snapshot of the streaming loader's memory state.
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Bytes currently held by loaded tensors.
    pub current_usage_bytes: usize,
    /// Configured memory budget in bytes.
    pub max_usage_bytes: usize,
    /// `current / max` expressed as a percentage (0.0 when the budget is 0).
    pub usage_percentage: f32,
    /// Number of chunks currently resident in memory.
    pub loaded_chunks: usize,
    /// Total number of chunks in the chunk plan.
    pub total_chunks: usize,
    /// Number of tensors currently resident in memory.
    pub loaded_tensors: usize,
}
350
351impl std::fmt::Display for MemoryStats {
352 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
353 write!(
354 f,
355 "Memory Usage: {:.1}% ({} / {} bytes), Chunks: {} / {}, Tensors: {}",
356 self.usage_percentage,
357 self.current_usage_bytes,
358 self.max_usage_bytes,
359 self.loaded_chunks,
360 self.total_chunks,
361 self.loaded_tensors
362 )
363 }
364}