1use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::sync::{Arc, RwLock};
12
13use safetensors::SafeTensors;
14use torsh_core::{device::DeviceType, dtype::DType};
15use torsh_tensor::Tensor;
16
17use crate::{ModelError, ModelResult};
18
/// Decodes an IEEE 754 half-precision (binary16) bit pattern into an `f32`.
///
/// Handles signed zero, subnormals, infinities, and NaN. The previous version
/// flushed subnormal inputs to 0.0; they are now decoded exactly (every f16
/// subnormal is representable as a normal f32).
fn f16_to_f32_simple(bits: u16) -> f32 {
    let sign = (bits >> 15) & 0x1;
    let exponent = (bits >> 10) & 0x1F;
    let mantissa = bits & 0x3FF;

    if exponent == 0 {
        if mantissa == 0 {
            // Signed zero.
            return if sign == 1 { -0.0 } else { 0.0 };
        }
        // Subnormal: value = mantissa * 2^-24 (no implicit leading 1).
        let magnitude = (mantissa as f32) * 2.0f32.powi(-24);
        return if sign == 1 { -magnitude } else { magnitude };
    } else if exponent == 0x1F {
        if mantissa == 0 {
            return if sign == 1 {
                f32::NEG_INFINITY
            } else {
                f32::INFINITY
            };
        }
        // Any non-zero mantissa with an all-ones exponent is NaN; the payload
        // and sign are not preserved.
        return f32::NAN;
    }

    // Normal number: rebias the exponent (f16 bias 15 -> f32 bias 127) and
    // left-align the 10-bit mantissa into f32's 23-bit field.
    let f32_exponent = (exponent as i32) - 15 + 127;
    let f32_mantissa = (mantissa as u32) << 13;
    let f32_sign = (sign as u32) << 31;

    let f32_bits = f32_sign | ((f32_exponent as u32) << 23) | f32_mantissa;
    f32::from_bits(f32_bits)
}
56
/// A tensor whose data stays on disk until first access, then is cached in memory.
pub struct LazyTensor {
    /// Name of the tensor inside the safetensors file.
    name: String,
    /// Logical shape of the tensor.
    shape: Vec<usize>,
    /// Element type as stored on disk (data is converted to f32 on load).
    dtype: DType,
    /// Path to the safetensors file containing this tensor.
    file_path: PathBuf,
    /// Lazily populated cache of the materialized tensor; shared across clones of the handle.
    cached: Arc<RwLock<Option<Tensor>>>,
    /// Byte offset within the file; currently unused — loading re-parses the whole file.
    _offset: usize,
    /// Size of the raw tensor data in bytes (used for cache-size accounting).
    size: usize,
}
74
75impl LazyTensor {
76 pub fn new(
78 name: String,
79 shape: Vec<usize>,
80 dtype: DType,
81 file_path: PathBuf,
82 offset: usize,
83 size: usize,
84 ) -> Self {
85 Self {
86 name,
87 shape,
88 dtype,
89 file_path,
90 cached: Arc::new(RwLock::new(None)),
91 _offset: offset,
92 size,
93 }
94 }
95
96 pub fn get(&self) -> ModelResult<Tensor> {
98 {
100 let cache = self.cached.read().expect("lock should not be poisoned");
101 if let Some(tensor) = cache.as_ref() {
102 return Ok(tensor.clone());
103 }
104 }
105
106 let tensor = self.load_from_file()?;
108
109 {
111 let mut cache = self.cached.write().expect("lock should not be poisoned");
112 *cache = Some(tensor.clone());
113 }
114
115 Ok(tensor)
116 }
117
118 fn load_from_file(&self) -> ModelResult<Tensor> {
120 let file_data = std::fs::read(&self.file_path)?;
122
123 let safetensors = SafeTensors::deserialize(&file_data)?;
125
126 let tensor_view = safetensors
128 .tensor(&self.name)
129 .map_err(|e| ModelError::LoadingError {
130 reason: format!("Tensor {} not found in file: {}", self.name, e),
131 })?;
132
133 let data = tensor_view.data();
135
136 let float_data: Vec<f32> = match self.dtype {
141 DType::F32 => data
142 .chunks_exact(4)
143 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
144 .collect(),
145 DType::F64 => data
146 .chunks_exact(8)
147 .map(|chunk| {
148 let val = f64::from_le_bytes([
149 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
150 chunk[7],
151 ]);
152 val as f32
153 })
154 .collect(),
155 DType::I32 => data
156 .chunks_exact(4)
157 .map(|chunk| {
158 let val = i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
159 val as f32
160 })
161 .collect(),
162 DType::I64 => data
163 .chunks_exact(8)
164 .map(|chunk| {
165 let val = i64::from_le_bytes([
166 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
167 chunk[7],
168 ]);
169 val as f32
170 })
171 .collect(),
172 DType::I16 => data
173 .chunks_exact(2)
174 .map(|chunk| {
175 let val = i16::from_le_bytes([chunk[0], chunk[1]]);
176 val as f32
177 })
178 .collect(),
179 DType::I8 => data.iter().map(|&b| (b as i8) as f32).collect(),
180 DType::U8 => data.iter().map(|&b| b as f32).collect(),
181 DType::U32 => data
182 .chunks_exact(4)
183 .map(|chunk| {
184 let val = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
185 val as f32
186 })
187 .collect(),
188 DType::U64 => data
189 .chunks_exact(8)
190 .map(|chunk| {
191 let val = u64::from_le_bytes([
192 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
193 chunk[7],
194 ]);
195 val as f32
196 })
197 .collect(),
198 DType::F16 | DType::BF16 => data
199 .chunks_exact(2)
200 .map(|chunk| {
201 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
202 f16_to_f32_simple(bits)
203 })
204 .collect(),
205 _ => {
206 data.iter().map(|&b| b as f32).collect()
208 }
209 };
210
211 let tensor = Tensor::from_data(float_data, self.shape.clone(), DeviceType::Cpu)?;
212
213 Ok(tensor)
214 }
215
216 pub fn clear_cache(&self) {
218 let mut cache = self.cached.write().expect("lock should not be poisoned");
219 *cache = None;
220 }
221
222 pub fn is_cached(&self) -> bool {
224 let cache = self.cached.read().expect("lock should not be poisoned");
225 cache.is_some()
226 }
227
228 pub fn shape(&self) -> &[usize] {
230 &self.shape
231 }
232
233 pub fn dtype(&self) -> DType {
235 self.dtype
236 }
237
238 pub fn name(&self) -> &str {
240 &self.name
241 }
242}
243
/// Loads tensors from a safetensors file on demand, with LRU eviction once a
/// byte-size cache budget is exceeded.
pub struct LazyModelLoader {
    /// Path of the underlying safetensors file (tensors keep their own copy).
    _file_path: PathBuf,
    /// Index of all tensors in the file, keyed by tensor name.
    tensors: HashMap<String, LazyTensor>,
    /// Cache budget in bytes.
    max_cache_size: usize,
    /// Accounted size in bytes of currently cached tensor data.
    current_cache_size: Arc<RwLock<usize>>,
    /// Tensor names ordered least- to most-recently accessed (front = LRU victim).
    access_order: Arc<RwLock<Vec<String>>>,
}
257
258impl LazyModelLoader {
259 pub fn new<P: AsRef<Path>>(path: P, max_cache_size: usize) -> ModelResult<Self> {
261 let file_path = path.as_ref().to_path_buf();
262 let tensors = Self::scan_tensors(&file_path)?;
263
264 Ok(Self {
265 _file_path: file_path,
266 tensors,
267 max_cache_size,
268 current_cache_size: Arc::new(RwLock::new(0)),
269 access_order: Arc::new(RwLock::new(Vec::new())),
270 })
271 }
272
273 fn scan_tensors(path: &Path) -> ModelResult<HashMap<String, LazyTensor>> {
275 let file_data = std::fs::read(path)?;
276 let safetensors = SafeTensors::deserialize(&file_data)?;
277
278 let mut tensors = HashMap::new();
279
280 for (name, _tensor_view) in safetensors.tensors() {
281 let tensor_view = safetensors
283 .tensor(&name)
284 .map_err(|e| ModelError::LoadingError {
285 reason: format!("Failed to get tensor {}: {}", name, e),
286 })?;
287
288 let shape = tensor_view.shape().to_vec();
289 let dtype = Self::convert_dtype(tensor_view.dtype());
290 let size = tensor_view.data().len();
291
292 let lazy_tensor = LazyTensor::new(
293 name.to_string(),
294 shape,
295 dtype,
296 path.to_path_buf(),
297 0, size,
299 );
300
301 tensors.insert(name.to_string(), lazy_tensor);
302 }
303
304 Ok(tensors)
305 }
306
307 fn convert_dtype(dtype: safetensors::Dtype) -> DType {
309 match dtype {
310 safetensors::Dtype::F32 => DType::F32,
311 safetensors::Dtype::F64 => DType::F64,
312 safetensors::Dtype::I32 => DType::I32,
313 safetensors::Dtype::I64 => DType::I64,
314 safetensors::Dtype::U8 => DType::U8,
315 safetensors::Dtype::I8 => DType::I8,
316 safetensors::Dtype::I16 => DType::I16,
317 safetensors::Dtype::U16 => DType::I16, safetensors::Dtype::U32 => DType::U32,
319 safetensors::Dtype::U64 => DType::U64,
320 safetensors::Dtype::F16 => DType::F16,
321 safetensors::Dtype::BF16 => DType::BF16,
322 _ => DType::F32, }
324 }
325
326 pub fn get_tensor(&self, name: &str) -> ModelResult<Tensor> {
328 let lazy_tensor = self
329 .tensors
330 .get(name)
331 .ok_or_else(|| ModelError::LoadingError {
332 reason: format!("Tensor {} not found", name),
333 })?;
334
335 self.update_access_order(name);
337
338 let tensor = lazy_tensor.get()?;
340
341 let tensor_size = lazy_tensor.size;
343 self.add_to_cache(tensor_size)?;
344
345 Ok(tensor)
346 }
347
348 fn update_access_order(&self, name: &str) {
350 let mut access_order = self
351 .access_order
352 .write()
353 .expect("lock should not be poisoned");
354
355 if let Some(pos) = access_order.iter().position(|n| n == name) {
357 access_order.remove(pos);
358 }
359
360 access_order.push(name.to_string());
362 }
363
364 fn add_to_cache(&self, size: usize) -> ModelResult<()> {
366 let mut current_size = self
367 .current_cache_size
368 .write()
369 .expect("lock should not be poisoned");
370 *current_size += size;
371
372 while *current_size > self.max_cache_size {
374 let evicted = self.evict_lru()?;
375 if !evicted {
376 break; }
378 }
379
380 Ok(())
381 }
382
383 fn evict_lru(&self) -> ModelResult<bool> {
385 let mut access_order = self
386 .access_order
387 .write()
388 .expect("lock should not be poisoned");
389
390 if access_order.is_empty() {
391 return Ok(false);
392 }
393
394 let lru_name = access_order.remove(0);
396
397 if let Some(tensor) = self.tensors.get(&lru_name) {
398 let tensor_size = tensor.size;
399 tensor.clear_cache();
400
401 let mut current_size = self
402 .current_cache_size
403 .write()
404 .expect("lock should not be poisoned");
405 *current_size = current_size.saturating_sub(tensor_size);
406 }
407
408 Ok(true)
409 }
410
411 pub fn tensor_names(&self) -> Vec<String> {
413 self.tensors.keys().cloned().collect()
414 }
415
416 pub fn tensor_metadata(&self, name: &str) -> Option<(Vec<usize>, DType)> {
418 self.tensors
419 .get(name)
420 .map(|t| (t.shape().to_vec(), t.dtype()))
421 }
422
423 pub fn clear_cache(&self) {
425 for tensor in self.tensors.values() {
426 tensor.clear_cache();
427 }
428
429 let mut current_size = self
430 .current_cache_size
431 .write()
432 .expect("lock should not be poisoned");
433 *current_size = 0;
434
435 let mut access_order = self
436 .access_order
437 .write()
438 .expect("lock should not be poisoned");
439 access_order.clear();
440 }
441
442 pub fn cache_stats(&self) -> CacheStats {
444 let cached_count = self.tensors.values().filter(|t| t.is_cached()).count();
445 let total_count = self.tensors.len();
446 let current_size = *self
447 .current_cache_size
448 .read()
449 .expect("lock should not be poisoned");
450
451 CacheStats {
452 cached_tensors: cached_count,
453 total_tensors: total_count,
454 cache_size_bytes: current_size,
455 max_cache_size_bytes: self.max_cache_size,
456 }
457 }
458}
459
/// Point-in-time snapshot of the lazy loader's cache occupancy.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of tensors currently materialized in memory.
    pub cached_tensors: usize,
    /// Total number of tensors known to the loader.
    pub total_tensors: usize,
    /// Accounted size in bytes of cached tensor data.
    pub cache_size_bytes: usize,
    /// Configured cache budget in bytes.
    pub max_cache_size_bytes: usize,
}

impl CacheStats {
    /// Fraction of known tensors that are resident in the cache.
    /// Returns 0.0 when no tensors are tracked.
    pub fn hit_rate(&self) -> f64 {
        match self.total_tensors {
            0 => 0.0,
            total => self.cached_tensors as f64 / total as f64,
        }
    }

    /// Fraction of the cache budget currently in use.
    /// Returns 0.0 when the budget is zero.
    pub fn utilization(&self) -> f64 {
        match self.max_cache_size_bytes {
            0 => 0.0,
            budget => self.cache_size_bytes as f64 / budget as f64,
        }
    }
}
492
/// Iterates tensors (or raw byte chunks of one tensor) from a safetensors
/// file via callbacks, without keeping anything cached.
pub struct StreamingModelLoader {
    /// Path to the safetensors file to stream from.
    file_path: PathBuf,
    /// Chunk size in bytes used by `stream_tensor_chunks`.
    chunk_size: usize,
}
500
501impl StreamingModelLoader {
502 pub fn new<P: AsRef<Path>>(path: P, chunk_size: usize) -> Self {
504 Self {
505 file_path: path.as_ref().to_path_buf(),
506 chunk_size,
507 }
508 }
509
510 pub fn stream_tensors<F>(&self, mut callback: F) -> ModelResult<()>
512 where
513 F: FnMut(&str, Tensor) -> ModelResult<()>,
514 {
515 let file_data = std::fs::read(&self.file_path)?;
516 let safetensors = SafeTensors::deserialize(&file_data)?;
517
518 for (name, _tensor_view) in safetensors.tensors() {
519 let tensor_view = safetensors
521 .tensor(&name)
522 .map_err(|e| ModelError::LoadingError {
523 reason: format!("Failed to get tensor {}: {}", name, e),
524 })?;
525
526 let shape = tensor_view.shape().to_vec();
527 let data = tensor_view.data();
528
529 let float_data: Vec<f32> = data
531 .chunks_exact(4)
532 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
533 .collect();
534
535 let tensor = Tensor::from_data(float_data, shape, DeviceType::Cpu)?;
536
537 callback(&name, tensor)?;
538 }
539
540 Ok(())
541 }
542
543 pub fn stream_tensor_chunks<F>(&self, tensor_name: &str, mut callback: F) -> ModelResult<()>
545 where
546 F: FnMut(usize, &[u8]) -> ModelResult<()>,
547 {
548 let file_data = std::fs::read(&self.file_path)?;
549 let safetensors = SafeTensors::deserialize(&file_data)?;
550
551 let tensor_view =
552 safetensors
553 .tensor(tensor_name)
554 .map_err(|e| ModelError::LoadingError {
555 reason: format!("Tensor {} not found: {}", tensor_name, e),
556 })?;
557
558 let data = tensor_view.data();
559
560 for (i, chunk) in data.chunks(self.chunk_size).enumerate() {
562 callback(i, chunk)?;
563 }
564
565 Ok(())
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes 100 zero bytes to a temp file. This is NOT a valid safetensors
    /// payload; it only serves tests that need a path but never deserialize.
    fn create_test_safetensors() -> NamedTempFile {
        let mut tmp = NamedTempFile::new().unwrap();
        tmp.write_all(&[0u8; 100]).unwrap();
        tmp.flush().unwrap();
        tmp
    }

    #[test]
    fn test_cache_stats() {
        let stats = CacheStats {
            cached_tensors: 5,
            total_tensors: 10,
            cache_size_bytes: 1024,
            max_cache_size_bytes: 2048,
        };

        assert_eq!(stats.hit_rate(), 0.5);
        assert_eq!(stats.utilization(), 0.5);
    }

    #[test]
    fn test_streaming_loader_creation() {
        let backing = create_test_safetensors();
        let loader = StreamingModelLoader::new(backing.path(), 1024);
        assert_eq!(loader.chunk_size, 1024);
    }
}