Skip to main content

trustformers_models/weight_loading/
huggingface.rs

1use serde::Deserialize;
2/// HuggingFace Weight Loader
3///
4/// This module provides comprehensive support for loading weights from HuggingFace model formats.
5use std::collections::HashMap;
6use std::fs::File;
7use std::io::{BufReader, Read, Seek, SeekFrom};
8use std::path::{Path, PathBuf};
9use trustformers_core::{
10    errors::{invalid_format, runtime_error, Result, TrustformersError},
11    tensor::Tensor,
12};
13
14use super::config::{WeightDataType, WeightFormat, WeightLoadingConfig};
15
/// HuggingFace model index structure
#[derive(Debug, Deserialize)]
pub struct HuggingFaceIndex {
    /// Checkpoint-level metadata (total size, storage format).
    pub metadata: HuggingFaceMetadata,
    /// Maps each tensor name to the weight file that contains it; a `"*"`
    /// key acts as a wildcard for single-file checkpoints (see
    /// `create_single_file_index` / `find_tensor_file`).
    pub weight_map: HashMap<String, String>,
}
22
#[derive(Debug, Deserialize)]
pub struct HuggingFaceMetadata {
    /// Total size of all weight shards in bytes (set to 0 when the index is
    /// synthesized for a single-file checkpoint).
    pub total_size: u64,
    /// Storage format tag, e.g. "safetensors" or "pytorch".
    pub format: String,
}
28
29/// SafeTensors header structure
30///
31/// Note: SafeTensors format has a FLAT structure where tensor names are keys at the root level,
32/// and __metadata__ is a special key (not nested under a "tensors" field).
33///
34/// Actual format: {"__metadata__": {...}, "tensor.name": {...}, "other.tensor": {...}}
#[derive(Debug)]
pub struct SafeTensorsHeader {
    /// Free-form string metadata taken from the special `__metadata__` key,
    /// when present.
    pub metadata: Option<HashMap<String, String>>,
    /// Tensor name -> dtype/shape/offset entry for every other root key.
    pub tensors: HashMap<String, TensorInfo>,
}
40
41impl<'de> serde::Deserialize<'de> for SafeTensorsHeader {
42    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
43    where
44        D: serde::Deserializer<'de>,
45    {
46        // Deserialize as a flat HashMap
47        let mut map: HashMap<String, serde_json::Value> = HashMap::deserialize(deserializer)?;
48
49        // Extract __metadata__ if present (special key)
50        let metadata = map.remove("__metadata__").and_then(|v| serde_json::from_value(v).ok());
51
52        // All remaining keys are tensor names
53        let tensors: HashMap<String, TensorInfo> = map
54            .into_iter()
55            .filter_map(|(k, v)| serde_json::from_value(v).ok().map(|info| (k, info)))
56            .collect();
57
58        Ok(SafeTensorsHeader { metadata, tensors })
59    }
60}
61
#[derive(Debug, Clone, Deserialize)]
pub struct TensorInfo {
    /// SafeTensors dtype tag, e.g. "F32", "F16", "I8".
    pub dtype: String,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// [start, end) byte offsets, relative to the start of the tensor data
    /// section (i.e. after the 8-byte length prefix and the JSON header).
    pub data_offsets: [u64; 2],
}
68
/// Internal tensor info for PyTorch parsing
#[derive(Debug)]
struct PyTorchTensorInfo {
    /// Tensor shape — heuristic, inferred from the parameter name.
    pub shape: Vec<usize>,
    /// Element type; the heuristic parser currently always assumes f32.
    pub dtype: WeightDataType,
    /// Byte offset of the candidate tensor payload within the file buffer.
    pub data_offset: usize,
}
76
/// Weight loader trait
pub trait WeightLoader {
    /// Loads the named tensor, returning a fully materialized `Tensor`.
    fn load_tensor(&mut self, name: &str) -> Result<Tensor>;
    /// Lists the names of all tensors this loader knows about.
    fn list_tensors(&self) -> Result<Vec<String>>;
    /// Returns metadata for the named tensor, if available.
    fn tensor_info(&self, name: &str) -> Result<Option<TensorMetadata>>;
    /// Releases any held resources (file handles, caches).
    fn close(&mut self) -> Result<()>;
}
84
/// Tensor metadata
#[derive(Debug, Clone)]
pub struct TensorMetadata {
    /// Dimensions of the tensor.
    pub shape: Vec<usize>,
    /// Element data type.
    pub dtype: WeightDataType,
    /// Total payload size in bytes.
    pub size_bytes: u64,
    /// Byte offset of the payload within its weight file.
    pub offset: u64,
}
93
/// Lazy tensor that loads data on-demand
pub struct LazyTensor {
    /// Name of the tensor inside the checkpoint.
    name: String,
    /// Weight file the tensor lives in (currently unused; `load` resolves
    /// the file through the index instead).
    #[allow(dead_code)]
    filename: String,
    /// Shape/dtype/size/offset recorded when the handle was created.
    metadata: TensorMetadata,
    /// Root directory of the model checkpoint.
    model_dir: PathBuf,
    /// Loader configuration used when the tensor is materialized.
    config: WeightLoadingConfig,
}
103
/// HuggingFace weight loader
pub struct HuggingFaceLoader {
    /// Loader options (lazy loading, etc.).
    config: WeightLoadingConfig,
    /// Tensor-name -> weight-file index (loaded from disk or synthesized).
    index: HuggingFaceIndex,
    /// Cached buffered readers, keyed by weight filename.
    file_handles: HashMap<String, BufReader<File>>,
    /// Root directory of the checkpoint.
    model_dir: PathBuf,
    /// Decoded tensors, cached when lazy loading is disabled.
    tensor_cache: HashMap<String, Tensor>,
}
112
113impl HuggingFaceLoader {
114    pub fn new(model_dir: impl AsRef<Path>, config: WeightLoadingConfig) -> Result<Self> {
115        let model_dir = model_dir.as_ref().to_path_buf();
116
117        // Load index file
118        let index_path = model_dir.join("pytorch_model.bin.index.json");
119        let index = if index_path.exists() {
120            Self::load_index(&index_path)?
121        } else {
122            // Create single-file index
123            Self::create_single_file_index(&model_dir)?
124        };
125
126        Ok(Self {
127            config,
128            index,
129            file_handles: HashMap::new(),
130            model_dir,
131            tensor_cache: HashMap::new(),
132        })
133    }
134
135    fn load_index(path: &Path) -> Result<HuggingFaceIndex> {
136        let file = File::open(path)?;
137        let reader = BufReader::new(file);
138        serde_json::from_reader(reader).map_err(|e| {
139            TrustformersError::weight_load_error(format!(
140                "Failed to parse HuggingFace index: {}",
141                e
142            ))
143        })
144    }
145
146    fn create_single_file_index(model_dir: &Path) -> Result<HuggingFaceIndex> {
147        // Look for single weight file (prefer SafeTensors over PyTorch)
148        let bin_path = model_dir.join("pytorch_model.bin");
149        let safetensors_path = model_dir.join("model.safetensors");
150
151        let (weight_file, is_safetensors) = if safetensors_path.exists() {
152            ("model.safetensors", true)
153        } else if bin_path.exists() {
154            ("pytorch_model.bin", false)
155        } else {
156            return Err(TrustformersError::file_not_found(
157                "No weight files found in model directory".to_string(),
158            ));
159        };
160
161        // Create index with proper tensor names
162        let mut weight_map = HashMap::new();
163
164        if is_safetensors {
165            // Read SafeTensors header to get actual tensor names
166            match Self::read_safetensors_tensor_names(&model_dir.join(weight_file)) {
167                Ok(tensor_names) => {
168                    // Map each tensor name to the weight file
169                    for name in tensor_names {
170                        weight_map.insert(name, weight_file.to_string());
171                    }
172                },
173                Err(e) => {
174                    eprintln!(
175                        "Warning: Failed to read SafeTensors header: {}. Using fallback index.",
176                        e
177                    );
178                    // Fallback to old behavior
179                    weight_map.insert("*".to_string(), weight_file.to_string());
180                },
181            }
182        } else {
183            // For PyTorch files, use wildcard (we can't easily parse .bin files)
184            weight_map.insert("*".to_string(), weight_file.to_string());
185        }
186
187        Ok(HuggingFaceIndex {
188            metadata: HuggingFaceMetadata {
189                total_size: 0,
190                format: if is_safetensors { "safetensors" } else { "pytorch" }.to_string(),
191            },
192            weight_map,
193        })
194    }
195
196    fn read_safetensors_tensor_names(path: &Path) -> Result<Vec<String>> {
197        use std::io::Read;
198
199        let file = File::open(path)?;
200        let mut reader = BufReader::new(file);
201
202        // Read header length (first 8 bytes)
203        let mut header_len_bytes = [0u8; 8];
204        reader.read_exact(&mut header_len_bytes)?;
205        let header_len = u64::from_le_bytes(header_len_bytes);
206
207        // Read header JSON
208        let mut header_bytes = vec![0u8; header_len as usize];
209        reader.read_exact(&mut header_bytes)?;
210        let header_str = String::from_utf8(header_bytes).map_err(|e| {
211            TrustformersError::weight_load_error(format!(
212                "Invalid UTF-8 in SafeTensors header: {}",
213                e
214            ))
215        })?;
216
217        // Parse JSON and extract tensor names
218        let header: serde_json::Value = serde_json::from_str(&header_str).map_err(|e| {
219            TrustformersError::weight_load_error(format!(
220                "Failed to parse SafeTensors header: {}",
221                e
222            ))
223        })?;
224
225        let mut tensor_names = Vec::new();
226        if let Some(obj) = header.as_object() {
227            for (key, _value) in obj {
228                // Skip metadata entries
229                if key != "__metadata__" {
230                    tensor_names.push(key.clone());
231                }
232            }
233        }
234
235        Ok(tensor_names)
236    }
237
238    fn get_file_handle(&mut self, filename: &str) -> Result<&mut BufReader<File>> {
239        if !self.file_handles.contains_key(filename) {
240            let file_path = self.model_dir.join(filename);
241            let file = File::open(&file_path)?;
242            let reader = BufReader::new(file);
243            self.file_handles.insert(filename.to_string(), reader);
244        }
245
246        self.file_handles.get_mut(filename).ok_or_else(|| {
247            TrustformersError::runtime_error(format!(
248                "File handle for {} not found after insertion",
249                filename
250            ))
251        })
252    }
253
254    /// Load tensor from PyTorch .bin file
255    fn load_from_pytorch_bin(&mut self, name: &str, filename: &str) -> Result<Tensor> {
256        let reader = self.get_file_handle(filename)?;
257
258        // Read the file into memory for processing
259        let mut buffer = Vec::new();
260        reader.read_to_end(&mut buffer).map_err(|e| {
261            TrustformersError::weight_load_error(format!("Failed to read tensor file: {}", e))
262        })?;
263
264        // Basic pickle protocol parsing for PyTorch tensors
265        // This is a simplified implementation that handles the most common cases
266        match Self::parse_pytorch_pickle_static(&buffer, name) {
267            Ok(tensor) => Ok(tensor),
268            Err(e) => {
269                // Fallback: try to load as raw tensor data if pickle parsing fails
270                eprintln!(
271                    "Warning: Pickle parsing failed for {}: {}. Attempting raw tensor parsing.",
272                    name, e
273                );
274                Self::parse_raw_tensor_data_static(&buffer, name)
275            },
276        }
277    }
278
    /// Reads the remainder of `reader` into memory and parses tensor `name`
    /// with the same pickle-then-raw fallback as `load_from_pytorch_bin`.
    ///
    /// NOTE(review): duplicates `load_from_pytorch_bin` except for where the
    /// reader comes from; reads from the reader's *current* cursor, so a
    /// cached handle already read to EOF yields an empty buffer — confirm
    /// before reviving this dead-code path.
    #[allow(dead_code)]
    fn parse_pytorch_tensor(&mut self, reader: &mut BufReader<File>, name: &str) -> Result<Tensor> {
        // Basic PyTorch .bin file parser implementation
        // This handles the common PyTorch pickle format used by HuggingFace models

        // Read the file into memory for processing
        let mut buffer = Vec::new();
        reader.read_to_end(&mut buffer).map_err(|e| {
            TrustformersError::weight_load_error(format!("Failed to read tensor file: {}", e))
        })?;

        // Basic pickle protocol parsing for PyTorch tensors
        // This is a simplified implementation that handles the most common cases
        match Self::parse_pytorch_pickle_static(&buffer, name) {
            Ok(tensor) => Ok(tensor),
            Err(e) => {
                // Fallback: try to load as raw tensor data if pickle parsing fails
                eprintln!(
                    "Warning: Pickle parsing failed for {}: {}. Attempting raw tensor parsing.",
                    name, e
                );
                Self::parse_raw_tensor_data_static(&buffer, name)
            },
        }
    }
304
    /// Instance-method convenience wrapper around
    /// `Self::parse_pytorch_pickle_static`.
    #[allow(dead_code)]
    fn parse_pytorch_pickle(&self, data: &[u8], name: &str) -> Result<Tensor> {
        Self::parse_pytorch_pickle_static(data, name)
    }
309
    /// Heuristic parser for PyTorch pickle (.bin) payloads.
    ///
    /// Infers shape/dtype/offset via `extract_pytorch_tensor_info_static`
    /// (name-based heuristics — NOT a real pickle decoder) and then decodes
    /// the byte range as little-endian f32 or f16 values.
    ///
    /// # Errors
    /// Fails when the buffer is under 8 bytes, when no tensor info can be
    /// inferred, when the inferred range exceeds the buffer, or for dtypes
    /// other than f32/f16.
    fn parse_pytorch_pickle_static(data: &[u8], name: &str) -> Result<Tensor> {
        // Simplified PyTorch pickle parser
        // This handles the basic structure of PyTorch .bin files

        // Look for tensor data markers in the pickle stream
        // PyTorch typically stores tensors with specific magic numbers

        // Check for PyTorch magic numbers
        if data.len() < 8 {
            return Err(TrustformersError::weight_load_error(
                "File too small to contain tensor data".to_string(),
            ));
        }

        // Try to find tensor metadata in the pickle stream
        // This is a heuristic approach for common PyTorch formats
        if let Some(tensor_info) = Self::extract_pytorch_tensor_info_static(data, name) {
            let offset = tensor_info.data_offset;
            let shape = tensor_info.shape;
            let dtype = tensor_info.dtype;
            let total_elements: usize = shape.iter().product();

            match dtype {
                WeightDataType::Float32 => {
                    // 4 bytes per element, little-endian.
                    let data_size = total_elements * 4;
                    if offset + data_size <= data.len() {
                        let tensor_data = &data[offset..offset + data_size];
                        let float_data: Vec<f32> = tensor_data
                            .chunks_exact(4)
                            .map(|chunk| {
                                f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])
                            })
                            .collect();

                        Tensor::from_vec(float_data, &shape).map_err(|e| {
                            TrustformersError::weight_load_error(format!(
                                "Failed to create tensor: {}",
                                e
                            ))
                        })
                    } else {
                        Err(TrustformersError::weight_load_error(
                            "Insufficient data for tensor".to_string(),
                        ))
                    }
                },
                WeightDataType::Float16 => {
                    // 2 bytes per element; widened to f32 for the Tensor type.
                    let data_size = total_elements * 2;
                    if offset + data_size <= data.len() {
                        let tensor_data = &data[offset..offset + data_size];
                        let float_data: Vec<f32> = tensor_data
                            .chunks_exact(2)
                            .map(|chunk| {
                                let half_val = half::f16::from_le_bytes([chunk[0], chunk[1]]);
                                half_val.to_f32()
                            })
                            .collect();

                        Tensor::from_vec(float_data, &shape).map_err(|e| {
                            TrustformersError::weight_load_error(format!(
                                "Failed to create tensor: {}",
                                e
                            ))
                        })
                    } else {
                        Err(TrustformersError::weight_load_error(
                            "Insufficient data for tensor".to_string(),
                        ))
                    }
                },
                _ => Err(TrustformersError::weight_load_error(format!(
                    "Unsupported tensor dtype: {:?}",
                    dtype
                ))),
            }
        } else {
            Err(TrustformersError::weight_load_error(
                "Could not extract tensor information from pickle data".to_string(),
            ))
        }
    }
391
    /// Instance-method convenience wrapper around
    /// `Self::extract_pytorch_tensor_info_static`.
    #[allow(dead_code)]
    fn extract_pytorch_tensor_info(&self, data: &[u8], name: &str) -> Option<PyTorchTensorInfo> {
        Self::extract_pytorch_tensor_info_static(data, name)
    }
396
397    fn extract_pytorch_tensor_info_static(data: &[u8], name: &str) -> Option<PyTorchTensorInfo> {
398        // Extract tensor metadata from pickle stream
399        // This is a heuristic approach that looks for common patterns
400
401        // Try to infer tensor properties based on common HuggingFace model patterns
402        let shape = Self::infer_tensor_shape_static(name);
403        let dtype = WeightDataType::Float32; // Default to float32
404
405        // Look for potential tensor data start
406        // PyTorch pickles often have specific patterns
407        let mut data_offset = 0;
408
409        // Scan for patterns that might indicate tensor data
410        for i in 0..data.len().saturating_sub(16) {
411            // Look for potential float patterns or PyTorch-specific markers
412            if Self::looks_like_tensor_data_static(&data[i..i.min(i + 16)]) {
413                data_offset = i;
414                break;
415            }
416        }
417
418        // If we couldn't find a good offset, use a reasonable default
419        if data_offset == 0 && data.len() > 1024 {
420            data_offset = 1024; // Skip likely pickle header
421        }
422
423        Some(PyTorchTensorInfo {
424            shape,
425            dtype,
426            data_offset,
427        })
428    }
429
    /// Instance-method convenience wrapper around
    /// `Self::infer_tensor_shape_static`.
    #[allow(dead_code)]
    fn infer_tensor_shape(&self, name: &str) -> Vec<usize> {
        Self::infer_tensor_shape_static(name)
    }
434
435    fn infer_tensor_shape_static(name: &str) -> Vec<usize> {
436        // Infer tensor shape based on layer name patterns
437        // This is a heuristic approach for common transformer model patterns
438
439        if name.contains("embeddings.word_embeddings.weight") {
440            vec![30522, 768] // Common BERT vocab size and hidden size
441        } else if name.contains("embeddings.position_embeddings.weight") {
442            vec![512, 768] // Common max position embeddings
443        } else if name.contains("attention.self.query.weight")
444            || name.contains("attention.self.key.weight")
445            || name.contains("attention.self.value.weight")
446        {
447            vec![768, 768] // Common attention weight dimensions
448        } else if name.contains("attention.output.dense.weight") {
449            vec![768, 768] // Attention output projection
450        } else if name.contains("intermediate.dense.weight") {
451            vec![768, 3072] // Feed-forward intermediate layer
452        } else if name.contains("output.dense.weight") {
453            vec![3072, 768] // Feed-forward output layer
454        } else if name.contains("LayerNorm.weight") || name.contains("LayerNorm.bias") {
455            vec![768] // LayerNorm parameters
456        } else if name.contains("bias") {
457            vec![768] // Common bias size
458        } else {
459            // Default fallback - try to parse from common patterns or use small default
460            vec![768, 768]
461        }
462    }
463
    /// Instance-method convenience wrapper around
    /// `Self::looks_like_tensor_data_static`.
    #[allow(dead_code)]
    fn looks_like_tensor_data(&self, chunk: &[u8]) -> bool {
        Self::looks_like_tensor_data_static(chunk)
    }
468
469    fn looks_like_tensor_data_static(chunk: &[u8]) -> bool {
470        // Heuristic to identify potential tensor data in byte stream
471        if chunk.len() < 4 {
472            return false;
473        }
474
475        // Check if bytes could represent reasonable float values
476        let float_val = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
477
478        // Reasonable range for model weights (not NaN, not infinite, reasonable magnitude)
479        float_val.is_finite() && float_val.abs() < 100.0
480    }
481
    /// Instance-method convenience wrapper around
    /// `Self::parse_raw_tensor_data_static`.
    #[allow(dead_code)]
    fn parse_raw_tensor_data(&self, data: &[u8], name: &str) -> Result<Tensor> {
        Self::parse_raw_tensor_data_static(data, name)
    }
486
    /// Last-resort parser: treats the buffer as raw little-endian f32 data
    /// for a tensor whose shape is inferred from `name`, trying 4-byte-
    /// aligned start offsets within the first KiB until one decodes to
    /// plausible values.
    fn parse_raw_tensor_data_static(data: &[u8], name: &str) -> Result<Tensor> {
        // Fallback: try to parse as raw tensor data
        let shape = Self::infer_tensor_shape_static(name);
        let total_elements: usize = shape.iter().product();
        let expected_size = total_elements * 4; // Assume float32

        if data.len() >= expected_size {
            // Try different offsets to find the actual tensor data
            for offset in (0..1024.min(data.len())).step_by(4) {
                if offset + expected_size <= data.len() {
                    let tensor_data = &data[offset..offset + expected_size];
                    let float_data: Vec<f32> = tensor_data
                        .chunks_exact(4)
                        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                        .collect();

                    // Validate that the data looks reasonable
                    // NOTE(review): `any` accepts a buffer where a *single*
                    // value is plausible; `all` may have been intended —
                    // confirm before tightening.
                    if float_data.iter().any(|&x| x.is_finite() && x.abs() < 100.0) {
                        if let Ok(tensor) = Tensor::from_vec(float_data, &shape) {
                            return Ok(tensor);
                        }
                    }
                }
            }
        }

        Err(TrustformersError::weight_load_error(format!(
            "Could not parse tensor data for {}",
            name
        )))
    }
518
    /// Load tensor with lazy loading
    ///
    /// Builds a `LazyTensor` handle from the index and (placeholder)
    /// metadata; no tensor data is read until `LazyTensor::load` is called.
    #[allow(dead_code)]
    fn load_lazy(&mut self, name: &str) -> Result<LazyTensor> {
        let filename = self.find_tensor_file(name)?;
        let metadata = self.get_tensor_metadata(name, &filename)?;

        Ok(LazyTensor {
            name: name.to_string(),
            filename,
            metadata,
            model_dir: self.model_dir.clone(),
            config: self.config.clone(),
        })
    }
533
534    fn find_tensor_file(&self, name: &str) -> Result<String> {
535        // Check weight map for tensor location
536        if let Some(filename) = self.index.weight_map.get(name) {
537            Ok(filename.clone())
538        } else if let Some(filename) = self.index.weight_map.get("*") {
539            // Single file case
540            Ok(filename.clone())
541        } else {
542            Err(runtime_error(format!("Tensor not found: {}", name)))
543        }
544    }
545
    /// Returns metadata for a tensor.
    ///
    /// NOTE(review): placeholder implementation — it ignores both arguments
    /// and returns a hard-coded 1024x768 f32 descriptor instead of parsing
    /// the file header. Callers (e.g. `tensor_info`, `load_lazy`) must not
    /// rely on these values being accurate.
    fn get_tensor_metadata(&self, _name: &str, _filename: &str) -> Result<TensorMetadata> {
        // Parse metadata from file header
        Ok(TensorMetadata {
            shape: vec![1024, 768],
            dtype: WeightDataType::Float32,
            size_bytes: 1024 * 768 * 4,
            offset: 0,
        })
    }
555
556    fn detect_format(&self, filename: &str) -> Result<WeightFormat> {
557        if filename.ends_with(".bin") {
558            Ok(WeightFormat::HuggingFaceBin)
559        } else if filename.ends_with(".safetensors") {
560            Ok(WeightFormat::SafeTensors)
561        } else {
562            Err(invalid_format(
563                "file format",
564                format!("Unknown format for file: {}", filename),
565            ))
566        }
567    }
568
    /// Loads `name` from a SafeTensors file by delegating to the combined
    /// header-parse-and-read implementation.
    fn load_from_safetensors(&mut self, name: &str, filename: &str) -> Result<Tensor> {
        // Use a single method that handles both header parsing and tensor loading
        self.load_safetensors_tensor_complete(name, filename)
    }
573
574    fn load_safetensors_tensor_complete(&mut self, name: &str, filename: &str) -> Result<Tensor> {
575        // CRITICAL FIX: Don't use cached file handles for SafeTensors
576        // BufReader's internal buffer causes issues when seeking - it doesn't flush the buffer
577        // after seek(), so we read stale buffered data instead of fresh file data.
578        // Solution: Open a fresh file for each tensor load.
579        let file_path = self.model_dir.join(filename);
580        eprintln!(
581            "[SAFETENSORS DEBUG] Loading tensor '{}' from file: {:?}",
582            name, file_path
583        );
584        let file = File::open(&file_path)?;
585        let mut reader = BufReader::new(file);
586
587        // Read header length (first 8 bytes)
588        let mut header_len_bytes = [0u8; 8];
589        reader.read_exact(&mut header_len_bytes)?;
590        let header_len = u64::from_le_bytes(header_len_bytes);
591        eprintln!("[SAFETENSORS DEBUG] Header length: {} bytes", header_len);
592
593        // Read header JSON
594        let mut header_bytes = vec![0u8; header_len as usize];
595        reader.read_exact(&mut header_bytes)?;
596
597        let header_str = std::str::from_utf8(&header_bytes).map_err(|e| {
598            TrustformersError::weight_load_error(format!(
599                "Invalid UTF-8 in SafeTensors header: {}",
600                e
601            ))
602        })?;
603
604        // Debug: print first 500 chars of header
605        eprintln!(
606            "[SAFETENSORS DEBUG] Header preview (first 500 chars): {}",
607            &header_str[..header_str.len().min(500)]
608        );
609
610        let header: SafeTensorsHeader = serde_json::from_str(header_str).map_err(|e| {
611            eprintln!("[SAFETENSORS DEBUG] Failed to parse header, printing full header:");
612            eprintln!("{}", header_str);
613            TrustformersError::serialization_error(format!(
614                "Failed to parse SafeTensors header: {}",
615                e
616            ))
617        })?;
618
619        if let Some(tensor_info) = header.tensors.get(name) {
620            // Seek to tensor data (offsets are relative to start of tensor data section)
621            // SafeTensors format: [8 bytes header_len][header_len bytes JSON][tensor data]
622            let tensor_data_start = 8 + header_len;
623            reader.seek(SeekFrom::Start(
624                tensor_data_start + tensor_info.data_offsets[0],
625            ))?;
626
627            // Read tensor data
628            let data_len = (tensor_info.data_offsets[1] - tensor_info.data_offsets[0]) as usize;
629            let mut data = vec![0u8; data_len];
630            reader.read_exact(&mut data)?;
631
632            // Convert to tensor based on dtype
633            self.bytes_to_tensor(data, &tensor_info.dtype, &tensor_info.shape)
634        } else {
635            Err(runtime_error(format!("Tensor not found: {}", name)))
636        }
637    }
638
    /// Parses the SafeTensors header (8-byte LE length prefix followed by a
    /// JSON object) from `reader`, which must be positioned at the start of
    /// the file.
    #[allow(dead_code)]
    fn parse_safetensors_header(
        &mut self,
        reader: &mut BufReader<File>,
    ) -> Result<SafeTensorsHeader> {
        // Read header length (first 8 bytes)
        let mut header_len_bytes = [0u8; 8];
        reader.read_exact(&mut header_len_bytes)?;
        let header_len = u64::from_le_bytes(header_len_bytes);

        // Read header JSON
        let mut header_bytes = vec![0u8; header_len as usize];
        reader.read_exact(&mut header_bytes)?;

        let header_str = std::str::from_utf8(&header_bytes).map_err(|e| {
            TrustformersError::weight_load_error(format!(
                "Invalid UTF-8 in SafeTensors header: {}",
                e
            ))
        })?;
        serde_json::from_str(header_str).map_err(|e| {
            TrustformersError::serialization_error(format!(
                "Failed to parse SafeTensors header: {}",
                e
            ))
        })
    }
666
    /// Reads one tensor's byte range from `reader` and decodes it per dtype.
    ///
    /// NOTE(review): seeks to `data_offsets[0]` as an *absolute* file
    /// position, but SafeTensors offsets are relative to the tensor data
    /// section (8 bytes + header length) — compare
    /// `load_safetensors_tensor_complete`. This dead-code path would read
    /// from the wrong location if revived; confirm and fix before use.
    #[allow(dead_code)]
    fn load_safetensors_tensor(
        &mut self,
        reader: &mut BufReader<File>,
        info: &TensorInfo,
    ) -> Result<Tensor> {
        // Seek to tensor data
        reader.seek(SeekFrom::Start(info.data_offsets[0]))?;

        // Read tensor data
        let data_len = (info.data_offsets[1] - info.data_offsets[0]) as usize;
        let mut data = vec![0u8; data_len];
        reader.read_exact(&mut data)?;

        // Convert to tensor based on dtype
        self.bytes_to_tensor(data, &info.dtype, &info.shape)
    }
684
685    fn bytes_to_tensor(&self, data: Vec<u8>, dtype: &str, shape: &[usize]) -> Result<Tensor> {
686        match dtype {
687            "F32" => {
688                let floats: Vec<f32> = data
689                    .chunks_exact(4)
690                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
691                    .collect();
692                Tensor::from_vec(floats, shape)
693            },
694            "F16" => {
695                // Convert f16 to f32
696                let floats: Vec<f32> = data
697                    .chunks_exact(2)
698                    .map(|chunk| {
699                        let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
700                        half::f16::from_bits(bits).to_f32()
701                    })
702                    .collect();
703                Tensor::from_vec(floats, shape)
704            },
705            "I8" => {
706                let ints: Vec<i8> = data.into_iter().map(|b| b as i8).collect();
707                // Convert to f32 for now
708                let floats: Vec<f32> = ints.into_iter().map(|i| i as f32).collect();
709                Tensor::from_vec(floats, shape)
710            },
711            _ => Err(invalid_format(
712                "dtype",
713                format!("Unsupported dtype: {}", dtype),
714            )),
715        }
716    }
717}
718
719impl WeightLoader for HuggingFaceLoader {
720    fn load_tensor(&mut self, name: &str) -> Result<Tensor> {
721        // Check cache first
722        if let Some(tensor) = self.tensor_cache.get(name) {
723            return Ok(tensor.clone());
724        }
725
726        let filename = self.find_tensor_file(name)?;
727
728        let tensor = match self.detect_format(&filename)? {
729            WeightFormat::HuggingFaceBin => self.load_from_pytorch_bin(name, &filename)?,
730            WeightFormat::SafeTensors => self.load_from_safetensors(name, &filename)?,
731            _ => {
732                return Err(invalid_format("weight format", "Unsupported weight format"));
733            },
734        };
735
736        // Cache if not lazy loading
737        if !self.config.lazy_loading {
738            self.tensor_cache.insert(name.to_string(), tensor.clone());
739        }
740
741        Ok(tensor)
742    }
743
744    fn list_tensors(&self) -> Result<Vec<String>> {
745        Ok(self.index.weight_map.keys().cloned().collect())
746    }
747
748    fn tensor_info(&self, name: &str) -> Result<Option<TensorMetadata>> {
749        let filename = self.find_tensor_file(name)?;
750        Ok(Some(self.get_tensor_metadata(name, &filename)?))
751    }
752
753    fn close(&mut self) -> Result<()> {
754        self.file_handles.clear();
755        self.tensor_cache.clear();
756        Ok(())
757    }
758}
759
impl LazyTensor {
    /// Materializes the tensor by constructing a fresh `HuggingFaceLoader`
    /// over the stored model directory and loading by name.
    ///
    /// NOTE(review): this re-reads/re-builds the index on every call —
    /// acceptable for occasional lazy loads, expensive in a tight loop.
    pub fn load(&self) -> Result<Tensor> {
        // Create a temporary loader instance to load this specific tensor
        let mut temp_loader = HuggingFaceLoader::new(&self.model_dir, self.config.clone())?;
        temp_loader.load_tensor(&self.name)
    }

    /// Shape/dtype/size/offset recorded when the lazy handle was created.
    pub fn metadata(&self) -> &TensorMetadata {
        &self.metadata
    }
}