// torsh_models/lazy_loading.rs
1//! Lazy loading optimizations for efficient model loading
2//!
3//! This module provides advanced loading strategies:
4//! - Memory-mapped file access for large models
5//! - Lazy tensor materialization (load on first use)
6//! - Streaming loading for huge models
7//! - LRU caching for frequently accessed tensors
8
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::sync::{Arc, RwLock};
12
13use safetensors::SafeTensors;
14use torsh_core::{device::DeviceType, dtype::DType};
15use torsh_tensor::Tensor;
16
17use crate::{ModelError, ModelResult};
18
19/// Simple f16 to f32 conversion (not IEEE 754 accurate)
20/// In production, use a proper half-precision library
21fn f16_to_f32_simple(bits: u16) -> f32 {
22    // Extract sign, exponent, and mantissa
23    let sign = (bits >> 15) & 0x1;
24    let exponent = (bits >> 10) & 0x1F;
25    let mantissa = bits & 0x3FF;
26
27    // Handle special cases
28    if exponent == 0 {
29        if mantissa == 0 {
30            // Zero
31            return if sign == 1 { -0.0 } else { 0.0 };
32        }
33        // Subnormal
34        return 0.0; // Simplified: return 0 for subnormals
35    } else if exponent == 0x1F {
36        if mantissa == 0 {
37            // Infinity
38            return if sign == 1 {
39                f32::NEG_INFINITY
40            } else {
41                f32::INFINITY
42            };
43        }
44        // NaN
45        return f32::NAN;
46    }
47
48    // Normal numbers
49    let f32_exponent = (exponent as i32) - 15 + 127;
50    let f32_mantissa = (mantissa as u32) << 13;
51    let f32_sign = (sign as u32) << 31;
52
53    let f32_bits = f32_sign | ((f32_exponent as u32) << 23) | f32_mantissa;
54    f32::from_bits(f32_bits)
55}
56
/// Lazy tensor that loads data on first access.
///
/// Holds only metadata (name, shape, dtype, source path); the actual
/// tensor data is read from disk and cached on the first `get()` call.
pub struct LazyTensor {
    /// Tensor name (the key under which it is stored in the SafeTensors file)
    name: String,
    /// Shape of the tensor
    shape: Vec<usize>,
    /// Data type as declared in the file (data is widened to f32 on load)
    dtype: DType,
    /// Path to the model file this tensor is read from
    file_path: PathBuf,
    /// Cached tensor (None until first access); behind Arc<RwLock<..>> so
    /// the cache state is shared and thread-safe
    cached: Arc<RwLock<Option<Tensor>>>,
    /// Offset in file — currently unused; reserved for future
    /// memory-mapped access
    _offset: usize,
    /// Size of the raw tensor data in bytes (used for cache accounting
    /// by LazyModelLoader)
    size: usize,
}
74
75impl LazyTensor {
76    /// Create a new lazy tensor
77    pub fn new(
78        name: String,
79        shape: Vec<usize>,
80        dtype: DType,
81        file_path: PathBuf,
82        offset: usize,
83        size: usize,
84    ) -> Self {
85        Self {
86            name,
87            shape,
88            dtype,
89            file_path,
90            cached: Arc::new(RwLock::new(None)),
91            _offset: offset,
92            size,
93        }
94    }
95
96    /// Get the tensor, loading it if not cached
97    pub fn get(&self) -> ModelResult<Tensor> {
98        // Check if cached
99        {
100            let cache = self.cached.read().expect("lock should not be poisoned");
101            if let Some(tensor) = cache.as_ref() {
102                return Ok(tensor.clone());
103            }
104        }
105
106        // Load tensor
107        let tensor = self.load_from_file()?;
108
109        // Cache it
110        {
111            let mut cache = self.cached.write().expect("lock should not be poisoned");
112            *cache = Some(tensor.clone());
113        }
114
115        Ok(tensor)
116    }
117
118    /// Load tensor from file
119    fn load_from_file(&self) -> ModelResult<Tensor> {
120        // Read the file data
121        let file_data = std::fs::read(&self.file_path)?;
122
123        // Parse SafeTensors
124        let safetensors = SafeTensors::deserialize(&file_data)?;
125
126        // Get the specific tensor
127        let tensor_view = safetensors
128            .tensor(&self.name)
129            .map_err(|e| ModelError::LoadingError {
130                reason: format!("Tensor {} not found in file: {}", self.name, e),
131            })?;
132
133        // Convert to ToRSh tensor - safetensors data is &[u8], need to handle properly
134        let data = tensor_view.data();
135
136        // NOTE: Current limitation - all tensors are converted to f32 because Tensor<T> is generic
137        // In the future, this should properly support all dtypes with dynamic dispatch or enum wrapping
138
139        // Convert bytes to f32 based on original dtype
140        let float_data: Vec<f32> = match self.dtype {
141            DType::F32 => data
142                .chunks_exact(4)
143                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
144                .collect(),
145            DType::F64 => data
146                .chunks_exact(8)
147                .map(|chunk| {
148                    let val = f64::from_le_bytes([
149                        chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
150                        chunk[7],
151                    ]);
152                    val as f32
153                })
154                .collect(),
155            DType::I32 => data
156                .chunks_exact(4)
157                .map(|chunk| {
158                    let val = i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
159                    val as f32
160                })
161                .collect(),
162            DType::I64 => data
163                .chunks_exact(8)
164                .map(|chunk| {
165                    let val = i64::from_le_bytes([
166                        chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
167                        chunk[7],
168                    ]);
169                    val as f32
170                })
171                .collect(),
172            DType::I16 => data
173                .chunks_exact(2)
174                .map(|chunk| {
175                    let val = i16::from_le_bytes([chunk[0], chunk[1]]);
176                    val as f32
177                })
178                .collect(),
179            DType::I8 => data.iter().map(|&b| (b as i8) as f32).collect(),
180            DType::U8 => data.iter().map(|&b| b as f32).collect(),
181            DType::U32 => data
182                .chunks_exact(4)
183                .map(|chunk| {
184                    let val = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
185                    val as f32
186                })
187                .collect(),
188            DType::U64 => data
189                .chunks_exact(8)
190                .map(|chunk| {
191                    let val = u64::from_le_bytes([
192                        chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
193                        chunk[7],
194                    ]);
195                    val as f32
196                })
197                .collect(),
198            DType::F16 | DType::BF16 => data
199                .chunks_exact(2)
200                .map(|chunk| {
201                    let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
202                    f16_to_f32_simple(bits)
203                })
204                .collect(),
205            _ => {
206                // Fallback: treat as u8 and convert to f32
207                data.iter().map(|&b| b as f32).collect()
208            }
209        };
210
211        let tensor = Tensor::from_data(float_data, self.shape.clone(), DeviceType::Cpu)?;
212
213        Ok(tensor)
214    }
215
216    /// Clear the cache to free memory
217    pub fn clear_cache(&self) {
218        let mut cache = self.cached.write().expect("lock should not be poisoned");
219        *cache = None;
220    }
221
222    /// Check if tensor is cached
223    pub fn is_cached(&self) -> bool {
224        let cache = self.cached.read().expect("lock should not be poisoned");
225        cache.is_some()
226    }
227
228    /// Get tensor shape without loading
229    pub fn shape(&self) -> &[usize] {
230        &self.shape
231    }
232
233    /// Get tensor dtype without loading
234    pub fn dtype(&self) -> DType {
235        self.dtype
236    }
237
238    /// Get tensor name
239    pub fn name(&self) -> &str {
240        &self.name
241    }
242}
243
/// Lazy model loader with LRU cache.
///
/// Scans a SafeTensors file up front for metadata and hands out tensors
/// on demand, evicting least-recently-used entries once the configured
/// byte budget is exceeded.
pub struct LazyModelLoader {
    /// Path to the model file — kept for reference; each LazyTensor
    /// stores its own copy of the path for loading
    _file_path: PathBuf,
    /// Lazy tensors indexed by name
    tensors: HashMap<String, LazyTensor>,
    /// Maximum cache size in bytes (budget that triggers LRU eviction)
    max_cache_size: usize,
    /// Current cache size in bytes, shared for thread-safe updates
    current_cache_size: Arc<RwLock<usize>>,
    /// Access order for LRU eviction: front = least recently used,
    /// back = most recently used
    access_order: Arc<RwLock<Vec<String>>>,
}
257
258impl LazyModelLoader {
259    /// Create a new lazy model loader
260    pub fn new<P: AsRef<Path>>(path: P, max_cache_size: usize) -> ModelResult<Self> {
261        let file_path = path.as_ref().to_path_buf();
262        let tensors = Self::scan_tensors(&file_path)?;
263
264        Ok(Self {
265            _file_path: file_path,
266            tensors,
267            max_cache_size,
268            current_cache_size: Arc::new(RwLock::new(0)),
269            access_order: Arc::new(RwLock::new(Vec::new())),
270        })
271    }
272
273    /// Scan file and create lazy tensors
274    fn scan_tensors(path: &Path) -> ModelResult<HashMap<String, LazyTensor>> {
275        let file_data = std::fs::read(path)?;
276        let safetensors = SafeTensors::deserialize(&file_data)?;
277
278        let mut tensors = HashMap::new();
279
280        for (name, _tensor_view) in safetensors.tensors() {
281            // Get tensor view again to extract metadata
282            let tensor_view = safetensors
283                .tensor(&name)
284                .map_err(|e| ModelError::LoadingError {
285                    reason: format!("Failed to get tensor {}: {}", name, e),
286                })?;
287
288            let shape = tensor_view.shape().to_vec();
289            let dtype = Self::convert_dtype(tensor_view.dtype());
290            let size = tensor_view.data().len();
291
292            let lazy_tensor = LazyTensor::new(
293                name.to_string(),
294                shape,
295                dtype,
296                path.to_path_buf(),
297                0, // Offset would need to be calculated properly
298                size,
299            );
300
301            tensors.insert(name.to_string(), lazy_tensor);
302        }
303
304        Ok(tensors)
305    }
306
307    /// Convert SafeTensors dtype to ToRSh DType
308    fn convert_dtype(dtype: safetensors::Dtype) -> DType {
309        match dtype {
310            safetensors::Dtype::F32 => DType::F32,
311            safetensors::Dtype::F64 => DType::F64,
312            safetensors::Dtype::I32 => DType::I32,
313            safetensors::Dtype::I64 => DType::I64,
314            safetensors::Dtype::U8 => DType::U8,
315            safetensors::Dtype::I8 => DType::I8,
316            safetensors::Dtype::I16 => DType::I16,
317            safetensors::Dtype::U16 => DType::I16, // Note: DType doesn't have U16, using I16
318            safetensors::Dtype::U32 => DType::U32,
319            safetensors::Dtype::U64 => DType::U64,
320            safetensors::Dtype::F16 => DType::F16,
321            safetensors::Dtype::BF16 => DType::BF16,
322            _ => DType::F32, // Default
323        }
324    }
325
326    /// Get a tensor by name
327    pub fn get_tensor(&self, name: &str) -> ModelResult<Tensor> {
328        let lazy_tensor = self
329            .tensors
330            .get(name)
331            .ok_or_else(|| ModelError::LoadingError {
332                reason: format!("Tensor {} not found", name),
333            })?;
334
335        // Update access order for LRU
336        self.update_access_order(name);
337
338        // Get the tensor (will load if not cached)
339        let tensor = lazy_tensor.get()?;
340
341        // Update cache size
342        let tensor_size = lazy_tensor.size;
343        self.add_to_cache(tensor_size)?;
344
345        Ok(tensor)
346    }
347
348    /// Update access order for LRU eviction
349    fn update_access_order(&self, name: &str) {
350        let mut access_order = self
351            .access_order
352            .write()
353            .expect("lock should not be poisoned");
354
355        // Remove if already present
356        if let Some(pos) = access_order.iter().position(|n| n == name) {
357            access_order.remove(pos);
358        }
359
360        // Add to the end (most recently used)
361        access_order.push(name.to_string());
362    }
363
364    /// Add to cache and evict if necessary
365    fn add_to_cache(&self, size: usize) -> ModelResult<()> {
366        let mut current_size = self
367            .current_cache_size
368            .write()
369            .expect("lock should not be poisoned");
370        *current_size += size;
371
372        // Evict least recently used tensors if cache is full
373        while *current_size > self.max_cache_size {
374            let evicted = self.evict_lru()?;
375            if !evicted {
376                break; // Nothing left to evict
377            }
378        }
379
380        Ok(())
381    }
382
383    /// Evict least recently used tensor
384    fn evict_lru(&self) -> ModelResult<bool> {
385        let mut access_order = self
386            .access_order
387            .write()
388            .expect("lock should not be poisoned");
389
390        if access_order.is_empty() {
391            return Ok(false);
392        }
393
394        // Get least recently used (first in the list)
395        let lru_name = access_order.remove(0);
396
397        if let Some(tensor) = self.tensors.get(&lru_name) {
398            let tensor_size = tensor.size;
399            tensor.clear_cache();
400
401            let mut current_size = self
402                .current_cache_size
403                .write()
404                .expect("lock should not be poisoned");
405            *current_size = current_size.saturating_sub(tensor_size);
406        }
407
408        Ok(true)
409    }
410
411    /// Get all tensor names
412    pub fn tensor_names(&self) -> Vec<String> {
413        self.tensors.keys().cloned().collect()
414    }
415
416    /// Get tensor metadata without loading
417    pub fn tensor_metadata(&self, name: &str) -> Option<(Vec<usize>, DType)> {
418        self.tensors
419            .get(name)
420            .map(|t| (t.shape().to_vec(), t.dtype()))
421    }
422
423    /// Clear entire cache
424    pub fn clear_cache(&self) {
425        for tensor in self.tensors.values() {
426            tensor.clear_cache();
427        }
428
429        let mut current_size = self
430            .current_cache_size
431            .write()
432            .expect("lock should not be poisoned");
433        *current_size = 0;
434
435        let mut access_order = self
436            .access_order
437            .write()
438            .expect("lock should not be poisoned");
439        access_order.clear();
440    }
441
442    /// Get current cache statistics
443    pub fn cache_stats(&self) -> CacheStats {
444        let cached_count = self.tensors.values().filter(|t| t.is_cached()).count();
445        let total_count = self.tensors.len();
446        let current_size = *self
447            .current_cache_size
448            .read()
449            .expect("lock should not be poisoned");
450
451        CacheStats {
452            cached_tensors: cached_count,
453            total_tensors: total_count,
454            cache_size_bytes: current_size,
455            max_cache_size_bytes: self.max_cache_size,
456        }
457    }
458}
459
/// Cache statistics snapshot produced by `LazyModelLoader::cache_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of currently cached (materialized) tensors
    pub cached_tensors: usize,
    /// Total number of tensors known to the loader
    pub total_tensors: usize,
    /// Current cache size in bytes
    pub cache_size_bytes: usize,
    /// Maximum cache size in bytes (eviction budget)
    pub max_cache_size_bytes: usize,
}
472
473impl CacheStats {
474    /// Get cache hit rate (0.0 to 1.0)
475    pub fn hit_rate(&self) -> f64 {
476        if self.total_tensors == 0 {
477            0.0
478        } else {
479            self.cached_tensors as f64 / self.total_tensors as f64
480        }
481    }
482
483    /// Get cache utilization (0.0 to 1.0)
484    pub fn utilization(&self) -> f64 {
485        if self.max_cache_size_bytes == 0 {
486            0.0
487        } else {
488            self.cache_size_bytes as f64 / self.max_cache_size_bytes as f64
489        }
490    }
491}
492
/// Streaming model loader for very large models.
///
/// Stateless beyond its configuration: each streaming call re-reads the
/// file from `file_path`.
pub struct StreamingModelLoader {
    /// Path to the model file
    file_path: PathBuf,
    /// Chunk size in bytes used by `stream_tensor_chunks`
    chunk_size: usize,
}
500
501impl StreamingModelLoader {
502    /// Create a new streaming model loader
503    pub fn new<P: AsRef<Path>>(path: P, chunk_size: usize) -> Self {
504        Self {
505            file_path: path.as_ref().to_path_buf(),
506            chunk_size,
507        }
508    }
509
510    /// Stream tensors one at a time
511    pub fn stream_tensors<F>(&self, mut callback: F) -> ModelResult<()>
512    where
513        F: FnMut(&str, Tensor) -> ModelResult<()>,
514    {
515        let file_data = std::fs::read(&self.file_path)?;
516        let safetensors = SafeTensors::deserialize(&file_data)?;
517
518        for (name, _tensor_view) in safetensors.tensors() {
519            // Get tensor view again
520            let tensor_view = safetensors
521                .tensor(&name)
522                .map_err(|e| ModelError::LoadingError {
523                    reason: format!("Failed to get tensor {}: {}", name, e),
524                })?;
525
526            let shape = tensor_view.shape().to_vec();
527            let data = tensor_view.data();
528
529            // Convert bytes to f32 for simplicity (in production, handle all dtypes)
530            let float_data: Vec<f32> = data
531                .chunks_exact(4)
532                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
533                .collect();
534
535            let tensor = Tensor::from_data(float_data, shape, DeviceType::Cpu)?;
536
537            callback(&name, tensor)?;
538        }
539
540        Ok(())
541    }
542
543    /// Stream tensors in chunks
544    pub fn stream_tensor_chunks<F>(&self, tensor_name: &str, mut callback: F) -> ModelResult<()>
545    where
546        F: FnMut(usize, &[u8]) -> ModelResult<()>,
547    {
548        let file_data = std::fs::read(&self.file_path)?;
549        let safetensors = SafeTensors::deserialize(&file_data)?;
550
551        let tensor_view =
552            safetensors
553                .tensor(tensor_name)
554                .map_err(|e| ModelError::LoadingError {
555                    reason: format!("Tensor {} not found: {}", tensor_name, e),
556                })?;
557
558        let data = tensor_view.data();
559
560        // Stream in chunks
561        for (i, chunk) in data.chunks(self.chunk_size).enumerate() {
562            callback(i, chunk)?;
563        }
564
565        Ok(())
566    }
567}
568
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Write a 100-byte placeholder file for tests that only need a path
    /// on disk. NOTE: this is NOT a valid SafeTensors payload — suitable
    /// only for constructor-level tests that never parse the file.
    fn create_test_safetensors() -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&[0u8; 100]).unwrap();
        file.flush().unwrap();
        file
    }

    /// Half-full cache of half the tensors: both ratios come out 0.5.
    #[test]
    fn test_cache_stats() {
        let stats = CacheStats {
            cached_tensors: 5,
            total_tensors: 10,
            cache_size_bytes: 1024,
            max_cache_size_bytes: 2048,
        };

        assert_eq!(stats.hit_rate(), 0.5);
        assert_eq!(stats.utilization(), 0.5);
    }

    /// Constructing the streaming loader records the chunk size verbatim.
    #[test]
    fn test_streaming_loader_creation() {
        let file = create_test_safetensors();
        let loader = StreamingModelLoader::new(file.path(), 1024);
        assert_eq!(loader.chunk_size, 1024);
    }
}