1use serde::Deserialize;
2use std::collections::HashMap;
6use std::fs::File;
7use std::io::{BufReader, Read, Seek, SeekFrom};
8use std::path::{Path, PathBuf};
9use trustformers_core::{
10 errors::{invalid_format, runtime_error, Result, TrustformersError},
11 tensor::Tensor,
12};
13
14use super::config::{WeightDataType, WeightFormat, WeightLoadingConfig};
15
/// Top-level schema of a HuggingFace sharded-checkpoint index file
/// (e.g. `pytorch_model.bin.index.json`).
#[derive(Debug, Deserialize)]
pub struct HuggingFaceIndex {
    /// Checkpoint-wide bookkeeping (total size, storage format).
    pub metadata: HuggingFaceMetadata,
    /// Maps each tensor name to the weight file that stores it.
    pub weight_map: HashMap<String, String>,
}
22
/// Bookkeeping section of a HuggingFace index file.
#[derive(Debug, Deserialize)]
pub struct HuggingFaceMetadata {
    /// Combined size of all weight shards in bytes (0 when unknown).
    pub total_size: u64,
    /// Storage format label; this module writes "safetensors" or "pytorch".
    pub format: String,
}
28
/// Parsed SafeTensors file header: optional free-form metadata plus one
/// entry per tensor. Uses a hand-written `Deserialize` because the
/// reserved "__metadata__" key shares the same JSON object as the
/// tensor entries.
#[derive(Debug)]
pub struct SafeTensorsHeader {
    /// Contents of the optional "__metadata__" key, if present and well-formed.
    pub metadata: Option<HashMap<String, String>>,
    /// Tensor name → dtype/shape/offset descriptor.
    pub tensors: HashMap<String, TensorInfo>,
}
40
41impl<'de> serde::Deserialize<'de> for SafeTensorsHeader {
42 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
43 where
44 D: serde::Deserializer<'de>,
45 {
46 let mut map: HashMap<String, serde_json::Value> = HashMap::deserialize(deserializer)?;
48
49 let metadata = map.remove("__metadata__").and_then(|v| serde_json::from_value(v).ok());
51
52 let tensors: HashMap<String, TensorInfo> = map
54 .into_iter()
55 .filter_map(|(k, v)| serde_json::from_value(v).ok().map(|info| (k, info)))
56 .collect();
57
58 Ok(SafeTensorsHeader { metadata, tensors })
59 }
60}
61
/// Per-tensor entry in a SafeTensors header.
#[derive(Debug, Clone, Deserialize)]
pub struct TensorInfo {
    /// Element type tag as written by SafeTensors (this module handles
    /// "F32", "F16", and "I8").
    pub dtype: String,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// [start, end) byte range of the payload, relative to the start of
    /// the data section (i.e. after the 8-byte length prefix + header).
    pub data_offsets: [u64; 2],
}
68
/// Best-effort description of a tensor recovered from a PyTorch
/// checkpoint buffer by the heuristic parsers in this module.
#[derive(Debug)]
struct PyTorchTensorInfo {
    /// Dimensions inferred from the tensor's name (see infer_tensor_shape_static).
    pub shape: Vec<usize>,
    /// Element type; currently always assumed, not detected.
    pub dtype: WeightDataType,
    /// Byte offset of the raw payload within the checkpoint buffer.
    pub data_offset: usize,
}
76
/// Common interface for backends that load model weights from disk.
pub trait WeightLoader {
    /// Loads (and possibly caches) the named tensor.
    fn load_tensor(&mut self, name: &str) -> Result<Tensor>;
    /// Lists the names of all tensors known to this loader.
    fn list_tensors(&self) -> Result<Vec<String>>;
    /// Returns metadata for the named tensor, if available.
    fn tensor_info(&self, name: &str) -> Result<Option<TensorMetadata>>;
    /// Releases file handles and cached data.
    fn close(&mut self) -> Result<()>;
}
84
/// Shape/dtype/location summary for a tensor, without its data.
#[derive(Debug, Clone)]
pub struct TensorMetadata {
    pub shape: Vec<usize>,
    pub dtype: WeightDataType,
    /// Payload size in bytes.
    pub size_bytes: u64,
    /// Byte offset of the payload within its weight file.
    pub offset: u64,
}
93
/// Handle to a tensor that is materialized on demand via [`LazyTensor::load`].
pub struct LazyTensor {
    /// Tensor name as it appears in the weight map.
    name: String,
    /// Weight file resolved for this tensor; kept for future use.
    #[allow(dead_code)]
    filename: String,
    /// Metadata captured when the handle was created.
    metadata: TensorMetadata,
    /// Model directory, so `load` can construct a fresh loader.
    model_dir: PathBuf,
    config: WeightLoadingConfig,
}
103
/// Loads tensors from HuggingFace-style checkpoints: sharded or single
/// file, PyTorch `.bin` or SafeTensors.
pub struct HuggingFaceLoader {
    config: WeightLoadingConfig,
    /// Tensor-name → weight-file index (real or synthesized).
    index: HuggingFaceIndex,
    /// Open readers keyed by weight filename, reused across loads.
    file_handles: HashMap<String, BufReader<File>>,
    model_dir: PathBuf,
    /// Fully-loaded tensors; populated only when lazy loading is disabled.
    tensor_cache: HashMap<String, Tensor>,
}
112
113impl HuggingFaceLoader {
114 pub fn new(model_dir: impl AsRef<Path>, config: WeightLoadingConfig) -> Result<Self> {
115 let model_dir = model_dir.as_ref().to_path_buf();
116
117 let index_path = model_dir.join("pytorch_model.bin.index.json");
119 let index = if index_path.exists() {
120 Self::load_index(&index_path)?
121 } else {
122 Self::create_single_file_index(&model_dir)?
124 };
125
126 Ok(Self {
127 config,
128 index,
129 file_handles: HashMap::new(),
130 model_dir,
131 tensor_cache: HashMap::new(),
132 })
133 }
134
135 fn load_index(path: &Path) -> Result<HuggingFaceIndex> {
136 let file = File::open(path)?;
137 let reader = BufReader::new(file);
138 serde_json::from_reader(reader).map_err(|e| {
139 TrustformersError::weight_load_error(format!(
140 "Failed to parse HuggingFace index: {}",
141 e
142 ))
143 })
144 }
145
146 fn create_single_file_index(model_dir: &Path) -> Result<HuggingFaceIndex> {
147 let bin_path = model_dir.join("pytorch_model.bin");
149 let safetensors_path = model_dir.join("model.safetensors");
150
151 let (weight_file, is_safetensors) = if safetensors_path.exists() {
152 ("model.safetensors", true)
153 } else if bin_path.exists() {
154 ("pytorch_model.bin", false)
155 } else {
156 return Err(TrustformersError::file_not_found(
157 "No weight files found in model directory".to_string(),
158 ));
159 };
160
161 let mut weight_map = HashMap::new();
163
164 if is_safetensors {
165 match Self::read_safetensors_tensor_names(&model_dir.join(weight_file)) {
167 Ok(tensor_names) => {
168 for name in tensor_names {
170 weight_map.insert(name, weight_file.to_string());
171 }
172 },
173 Err(e) => {
174 eprintln!(
175 "Warning: Failed to read SafeTensors header: {}. Using fallback index.",
176 e
177 );
178 weight_map.insert("*".to_string(), weight_file.to_string());
180 },
181 }
182 } else {
183 weight_map.insert("*".to_string(), weight_file.to_string());
185 }
186
187 Ok(HuggingFaceIndex {
188 metadata: HuggingFaceMetadata {
189 total_size: 0,
190 format: if is_safetensors { "safetensors" } else { "pytorch" }.to_string(),
191 },
192 weight_map,
193 })
194 }
195
196 fn read_safetensors_tensor_names(path: &Path) -> Result<Vec<String>> {
197 use std::io::Read;
198
199 let file = File::open(path)?;
200 let mut reader = BufReader::new(file);
201
202 let mut header_len_bytes = [0u8; 8];
204 reader.read_exact(&mut header_len_bytes)?;
205 let header_len = u64::from_le_bytes(header_len_bytes);
206
207 let mut header_bytes = vec![0u8; header_len as usize];
209 reader.read_exact(&mut header_bytes)?;
210 let header_str = String::from_utf8(header_bytes).map_err(|e| {
211 TrustformersError::weight_load_error(format!(
212 "Invalid UTF-8 in SafeTensors header: {}",
213 e
214 ))
215 })?;
216
217 let header: serde_json::Value = serde_json::from_str(&header_str).map_err(|e| {
219 TrustformersError::weight_load_error(format!(
220 "Failed to parse SafeTensors header: {}",
221 e
222 ))
223 })?;
224
225 let mut tensor_names = Vec::new();
226 if let Some(obj) = header.as_object() {
227 for (key, _value) in obj {
228 if key != "__metadata__" {
230 tensor_names.push(key.clone());
231 }
232 }
233 }
234
235 Ok(tensor_names)
236 }
237
238 fn get_file_handle(&mut self, filename: &str) -> Result<&mut BufReader<File>> {
239 if !self.file_handles.contains_key(filename) {
240 let file_path = self.model_dir.join(filename);
241 let file = File::open(&file_path)?;
242 let reader = BufReader::new(file);
243 self.file_handles.insert(filename.to_string(), reader);
244 }
245
246 self.file_handles.get_mut(filename).ok_or_else(|| {
247 TrustformersError::runtime_error(format!(
248 "File handle for {} not found after insertion",
249 filename
250 ))
251 })
252 }
253
254 fn load_from_pytorch_bin(&mut self, name: &str, filename: &str) -> Result<Tensor> {
256 let reader = self.get_file_handle(filename)?;
257
258 let mut buffer = Vec::new();
260 reader.read_to_end(&mut buffer).map_err(|e| {
261 TrustformersError::weight_load_error(format!("Failed to read tensor file: {}", e))
262 })?;
263
264 match Self::parse_pytorch_pickle_static(&buffer, name) {
267 Ok(tensor) => Ok(tensor),
268 Err(e) => {
269 eprintln!(
271 "Warning: Pickle parsing failed for {}: {}. Attempting raw tensor parsing.",
272 name, e
273 );
274 Self::parse_raw_tensor_data_static(&buffer, name)
275 },
276 }
277 }
278
    /// Dead-code twin of `load_from_pytorch_bin` that operates on a
    /// caller-supplied reader instead of the cached handle. Reads the
    /// whole stream from its *current* position, tries heuristic pickle
    /// parsing, and falls back to raw byte scanning on failure.
    #[allow(dead_code)]
    fn parse_pytorch_tensor(&mut self, reader: &mut BufReader<File>, name: &str) -> Result<Tensor> {
        let mut buffer = Vec::new();
        reader.read_to_end(&mut buffer).map_err(|e| {
            TrustformersError::weight_load_error(format!("Failed to read tensor file: {}", e))
        })?;

        match Self::parse_pytorch_pickle_static(&buffer, name) {
            Ok(tensor) => Ok(tensor),
            Err(e) => {
                // Best-effort fallback; the warning goes to stderr only.
                eprintln!(
                    "Warning: Pickle parsing failed for {}: {}. Attempting raw tensor parsing.",
                    name, e
                );
                Self::parse_raw_tensor_data_static(&buffer, name)
            },
        }
    }
304
    /// Instance-method convenience wrapper around
    /// `parse_pytorch_pickle_static`; currently unused.
    #[allow(dead_code)]
    fn parse_pytorch_pickle(&self, data: &[u8], name: &str) -> Result<Tensor> {
        Self::parse_pytorch_pickle_static(data, name)
    }
309
310 fn parse_pytorch_pickle_static(data: &[u8], name: &str) -> Result<Tensor> {
311 if data.len() < 8 {
319 return Err(TrustformersError::weight_load_error(
320 "File too small to contain tensor data".to_string(),
321 ));
322 }
323
324 if let Some(tensor_info) = Self::extract_pytorch_tensor_info_static(data, name) {
327 let offset = tensor_info.data_offset;
328 let shape = tensor_info.shape;
329 let dtype = tensor_info.dtype;
330 let total_elements: usize = shape.iter().product();
331
332 match dtype {
333 WeightDataType::Float32 => {
334 let data_size = total_elements * 4;
335 if offset + data_size <= data.len() {
336 let tensor_data = &data[offset..offset + data_size];
337 let float_data: Vec<f32> = tensor_data
338 .chunks_exact(4)
339 .map(|chunk| {
340 f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])
341 })
342 .collect();
343
344 Tensor::from_vec(float_data, &shape).map_err(|e| {
345 TrustformersError::weight_load_error(format!(
346 "Failed to create tensor: {}",
347 e
348 ))
349 })
350 } else {
351 Err(TrustformersError::weight_load_error(
352 "Insufficient data for tensor".to_string(),
353 ))
354 }
355 },
356 WeightDataType::Float16 => {
357 let data_size = total_elements * 2;
358 if offset + data_size <= data.len() {
359 let tensor_data = &data[offset..offset + data_size];
360 let float_data: Vec<f32> = tensor_data
361 .chunks_exact(2)
362 .map(|chunk| {
363 let half_val = half::f16::from_le_bytes([chunk[0], chunk[1]]);
364 half_val.to_f32()
365 })
366 .collect();
367
368 Tensor::from_vec(float_data, &shape).map_err(|e| {
369 TrustformersError::weight_load_error(format!(
370 "Failed to create tensor: {}",
371 e
372 ))
373 })
374 } else {
375 Err(TrustformersError::weight_load_error(
376 "Insufficient data for tensor".to_string(),
377 ))
378 }
379 },
380 _ => Err(TrustformersError::weight_load_error(format!(
381 "Unsupported tensor dtype: {:?}",
382 dtype
383 ))),
384 }
385 } else {
386 Err(TrustformersError::weight_load_error(
387 "Could not extract tensor information from pickle data".to_string(),
388 ))
389 }
390 }
391
    /// Instance-method wrapper around
    /// `extract_pytorch_tensor_info_static`; currently unused.
    #[allow(dead_code)]
    fn extract_pytorch_tensor_info(&self, data: &[u8], name: &str) -> Option<PyTorchTensorInfo> {
        Self::extract_pytorch_tensor_info_static(data, name)
    }
396
397 fn extract_pytorch_tensor_info_static(data: &[u8], name: &str) -> Option<PyTorchTensorInfo> {
398 let shape = Self::infer_tensor_shape_static(name);
403 let dtype = WeightDataType::Float32; let mut data_offset = 0;
408
409 for i in 0..data.len().saturating_sub(16) {
411 if Self::looks_like_tensor_data_static(&data[i..i.min(i + 16)]) {
413 data_offset = i;
414 break;
415 }
416 }
417
418 if data_offset == 0 && data.len() > 1024 {
420 data_offset = 1024; }
422
423 Some(PyTorchTensorInfo {
424 shape,
425 dtype,
426 data_offset,
427 })
428 }
429
    /// Instance-method wrapper around `infer_tensor_shape_static`;
    /// currently unused.
    #[allow(dead_code)]
    fn infer_tensor_shape(&self, name: &str) -> Vec<usize> {
        Self::infer_tensor_shape_static(name)
    }
434
435 fn infer_tensor_shape_static(name: &str) -> Vec<usize> {
436 if name.contains("embeddings.word_embeddings.weight") {
440 vec![30522, 768] } else if name.contains("embeddings.position_embeddings.weight") {
442 vec![512, 768] } else if name.contains("attention.self.query.weight")
444 || name.contains("attention.self.key.weight")
445 || name.contains("attention.self.value.weight")
446 {
447 vec![768, 768] } else if name.contains("attention.output.dense.weight") {
449 vec![768, 768] } else if name.contains("intermediate.dense.weight") {
451 vec![768, 3072] } else if name.contains("output.dense.weight") {
453 vec![3072, 768] } else if name.contains("LayerNorm.weight") || name.contains("LayerNorm.bias") {
455 vec![768] } else if name.contains("bias") {
457 vec![768] } else {
459 vec![768, 768]
461 }
462 }
463
    /// Instance-method wrapper around `looks_like_tensor_data_static`;
    /// currently unused.
    #[allow(dead_code)]
    fn looks_like_tensor_data(&self, chunk: &[u8]) -> bool {
        Self::looks_like_tensor_data_static(chunk)
    }
468
469 fn looks_like_tensor_data_static(chunk: &[u8]) -> bool {
470 if chunk.len() < 4 {
472 return false;
473 }
474
475 let float_val = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
477
478 float_val.is_finite() && float_val.abs() < 100.0
480 }
481
    /// Instance-method wrapper around `parse_raw_tensor_data_static`;
    /// currently unused.
    #[allow(dead_code)]
    fn parse_raw_tensor_data(&self, data: &[u8], name: &str) -> Result<Tensor> {
        Self::parse_raw_tensor_data_static(data, name)
    }
486
    /// Last-resort fallback: interprets `data` as raw little-endian f32
    /// values using a shape inferred from the tensor name. Slides a
    /// 4-byte-aligned start offset over the first 1 KiB of the buffer and
    /// accepts the first candidate whose decoded values contain at least
    /// one plausible weight (finite, |x| < 100) and that `Tensor::from_vec`
    /// accepts.
    fn parse_raw_tensor_data_static(data: &[u8], name: &str) -> Result<Tensor> {
        let shape = Self::infer_tensor_shape_static(name);
        let total_elements: usize = shape.iter().product();
        // f32 payload size implied by the inferred shape.
        let expected_size = total_elements * 4;
        if data.len() >= expected_size {
            // Try successive aligned start offsets within the first 1 KiB
            // (or the whole buffer if it is smaller than that).
            for offset in (0..1024.min(data.len())).step_by(4) {
                if offset + expected_size <= data.len() {
                    let tensor_data = &data[offset..offset + expected_size];
                    let float_data: Vec<f32> = tensor_data
                        .chunks_exact(4)
                        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                        .collect();

                    // Weak sanity check: a single plausible value accepts
                    // the whole candidate window.
                    if float_data.iter().any(|&x| x.is_finite() && x.abs() < 100.0) {
                        if let Ok(tensor) = Tensor::from_vec(float_data, &shape) {
                            return Ok(tensor);
                        }
                    }
                }
            }
        }

        Err(TrustformersError::weight_load_error(format!(
            "Could not parse tensor data for {}",
            name
        )))
    }
518
    /// Builds a [`LazyTensor`] handle without reading any tensor data.
    ///
    /// NOTE(review): the metadata comes from `get_tensor_metadata`, which
    /// currently returns fixed placeholder values — confirm before use.
    #[allow(dead_code)]
    fn load_lazy(&mut self, name: &str) -> Result<LazyTensor> {
        let filename = self.find_tensor_file(name)?;
        let metadata = self.get_tensor_metadata(name, &filename)?;

        Ok(LazyTensor {
            name: name.to_string(),
            filename,
            metadata,
            model_dir: self.model_dir.clone(),
            config: self.config.clone(),
        })
    }
533
534 fn find_tensor_file(&self, name: &str) -> Result<String> {
535 if let Some(filename) = self.index.weight_map.get(name) {
537 Ok(filename.clone())
538 } else if let Some(filename) = self.index.weight_map.get("*") {
539 Ok(filename.clone())
541 } else {
542 Err(runtime_error(format!("Tensor not found: {}", name)))
543 }
544 }
545
    /// Returns metadata for a tensor.
    ///
    /// NOTE(review): this is a stub — it ignores both arguments and
    /// returns fixed placeholder values (1024x768 f32 at offset 0)
    /// instead of reading the actual file; confirm before relying on it
    /// (it also backs `tensor_info` and `load_lazy`).
    fn get_tensor_metadata(&self, _name: &str, _filename: &str) -> Result<TensorMetadata> {
        Ok(TensorMetadata {
            shape: vec![1024, 768],
            dtype: WeightDataType::Float32,
            size_bytes: 1024 * 768 * 4,
            offset: 0,
        })
    }
555
556 fn detect_format(&self, filename: &str) -> Result<WeightFormat> {
557 if filename.ends_with(".bin") {
558 Ok(WeightFormat::HuggingFaceBin)
559 } else if filename.ends_with(".safetensors") {
560 Ok(WeightFormat::SafeTensors)
561 } else {
562 Err(invalid_format(
563 "file format",
564 format!("Unknown format for file: {}", filename),
565 ))
566 }
567 }
568
    /// Loads `name` from a SafeTensors file; thin alias for
    /// `load_safetensors_tensor_complete`.
    fn load_from_safetensors(&mut self, name: &str, filename: &str) -> Result<Tensor> {
        self.load_safetensors_tensor_complete(name, filename)
    }
573
574 fn load_safetensors_tensor_complete(&mut self, name: &str, filename: &str) -> Result<Tensor> {
575 let file_path = self.model_dir.join(filename);
580 eprintln!(
581 "[SAFETENSORS DEBUG] Loading tensor '{}' from file: {:?}",
582 name, file_path
583 );
584 let file = File::open(&file_path)?;
585 let mut reader = BufReader::new(file);
586
587 let mut header_len_bytes = [0u8; 8];
589 reader.read_exact(&mut header_len_bytes)?;
590 let header_len = u64::from_le_bytes(header_len_bytes);
591 eprintln!("[SAFETENSORS DEBUG] Header length: {} bytes", header_len);
592
593 let mut header_bytes = vec![0u8; header_len as usize];
595 reader.read_exact(&mut header_bytes)?;
596
597 let header_str = std::str::from_utf8(&header_bytes).map_err(|e| {
598 TrustformersError::weight_load_error(format!(
599 "Invalid UTF-8 in SafeTensors header: {}",
600 e
601 ))
602 })?;
603
604 eprintln!(
606 "[SAFETENSORS DEBUG] Header preview (first 500 chars): {}",
607 &header_str[..header_str.len().min(500)]
608 );
609
610 let header: SafeTensorsHeader = serde_json::from_str(header_str).map_err(|e| {
611 eprintln!("[SAFETENSORS DEBUG] Failed to parse header, printing full header:");
612 eprintln!("{}", header_str);
613 TrustformersError::serialization_error(format!(
614 "Failed to parse SafeTensors header: {}",
615 e
616 ))
617 })?;
618
619 if let Some(tensor_info) = header.tensors.get(name) {
620 let tensor_data_start = 8 + header_len;
623 reader.seek(SeekFrom::Start(
624 tensor_data_start + tensor_info.data_offsets[0],
625 ))?;
626
627 let data_len = (tensor_info.data_offsets[1] - tensor_info.data_offsets[0]) as usize;
629 let mut data = vec![0u8; data_len];
630 reader.read_exact(&mut data)?;
631
632 self.bytes_to_tensor(data, &tensor_info.dtype, &tensor_info.shape)
634 } else {
635 Err(runtime_error(format!("Tensor not found: {}", name)))
636 }
637 }
638
    /// Dead-code helper that parses a SafeTensors header from `reader`'s
    /// current position (8-byte LE length prefix followed by JSON).
    #[allow(dead_code)]
    fn parse_safetensors_header(
        &mut self,
        reader: &mut BufReader<File>,
    ) -> Result<SafeTensorsHeader> {
        // 8-byte little-endian header length.
        let mut header_len_bytes = [0u8; 8];
        reader.read_exact(&mut header_len_bytes)?;
        let header_len = u64::from_le_bytes(header_len_bytes);

        let mut header_bytes = vec![0u8; header_len as usize];
        reader.read_exact(&mut header_bytes)?;

        let header_str = std::str::from_utf8(&header_bytes).map_err(|e| {
            TrustformersError::weight_load_error(format!(
                "Invalid UTF-8 in SafeTensors header: {}",
                e
            ))
        })?;
        serde_json::from_str(header_str).map_err(|e| {
            TrustformersError::serialization_error(format!(
                "Failed to parse SafeTensors header: {}",
                e
            ))
        })
    }
666
    /// Dead-code variant that reads a tensor given its header entry.
    ///
    /// NOTE(review): this seeks to `data_offsets[0]` as an ABSOLUTE file
    /// position, but in the SafeTensors layout the offsets are relative
    /// to the data section (after the 8-byte length prefix + header) —
    /// compare `load_safetensors_tensor_complete`, which adds that base.
    /// Verify before reviving this path.
    #[allow(dead_code)]
    fn load_safetensors_tensor(
        &mut self,
        reader: &mut BufReader<File>,
        info: &TensorInfo,
    ) -> Result<Tensor> {
        reader.seek(SeekFrom::Start(info.data_offsets[0]))?;

        let data_len = (info.data_offsets[1] - info.data_offsets[0]) as usize;
        let mut data = vec![0u8; data_len];
        reader.read_exact(&mut data)?;

        self.bytes_to_tensor(data, &info.dtype, &info.shape)
    }
684
685 fn bytes_to_tensor(&self, data: Vec<u8>, dtype: &str, shape: &[usize]) -> Result<Tensor> {
686 match dtype {
687 "F32" => {
688 let floats: Vec<f32> = data
689 .chunks_exact(4)
690 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
691 .collect();
692 Tensor::from_vec(floats, shape)
693 },
694 "F16" => {
695 let floats: Vec<f32> = data
697 .chunks_exact(2)
698 .map(|chunk| {
699 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
700 half::f16::from_bits(bits).to_f32()
701 })
702 .collect();
703 Tensor::from_vec(floats, shape)
704 },
705 "I8" => {
706 let ints: Vec<i8> = data.into_iter().map(|b| b as i8).collect();
707 let floats: Vec<f32> = ints.into_iter().map(|i| i as f32).collect();
709 Tensor::from_vec(floats, shape)
710 },
711 _ => Err(invalid_format(
712 "dtype",
713 format!("Unsupported dtype: {}", dtype),
714 )),
715 }
716 }
717}
718
719impl WeightLoader for HuggingFaceLoader {
720 fn load_tensor(&mut self, name: &str) -> Result<Tensor> {
721 if let Some(tensor) = self.tensor_cache.get(name) {
723 return Ok(tensor.clone());
724 }
725
726 let filename = self.find_tensor_file(name)?;
727
728 let tensor = match self.detect_format(&filename)? {
729 WeightFormat::HuggingFaceBin => self.load_from_pytorch_bin(name, &filename)?,
730 WeightFormat::SafeTensors => self.load_from_safetensors(name, &filename)?,
731 _ => {
732 return Err(invalid_format("weight format", "Unsupported weight format"));
733 },
734 };
735
736 if !self.config.lazy_loading {
738 self.tensor_cache.insert(name.to_string(), tensor.clone());
739 }
740
741 Ok(tensor)
742 }
743
744 fn list_tensors(&self) -> Result<Vec<String>> {
745 Ok(self.index.weight_map.keys().cloned().collect())
746 }
747
748 fn tensor_info(&self, name: &str) -> Result<Option<TensorMetadata>> {
749 let filename = self.find_tensor_file(name)?;
750 Ok(Some(self.get_tensor_metadata(name, &filename)?))
751 }
752
753 fn close(&mut self) -> Result<()> {
754 self.file_handles.clear();
755 self.tensor_cache.clear();
756 Ok(())
757 }
758}
759
760impl LazyTensor {
761 pub fn load(&self) -> Result<Tensor> {
762 let mut temp_loader = HuggingFaceLoader::new(&self.model_dir, self.config.clone())?;
764 temp_loader.load_tensor(&self.name)
765 }
766
767 pub fn metadata(&self) -> &TensorMetadata {
768 &self.metadata
769 }
770}