trustformers_models/weight_loading/
memory_mapped.rs1use std::fs::File;
6use std::path::Path;
7use trustformers_core::{
8 errors::{invalid_format, Result, TrustformersError},
9 tensor::Tensor,
10};
11
12use super::config::WeightDataType;
13use super::huggingface::{SafeTensorsHeader, TensorMetadata, WeightLoader};
14
15pub struct MemoryMappedLoader {
17 #[allow(dead_code)]
18 file: File,
19 mapping: Option<memmap2::Mmap>,
20 header: SafeTensorsHeader,
21}
22
23impl MemoryMappedLoader {
24 pub fn new(path: impl AsRef<Path>) -> Result<Self> {
25 let file = File::open(path)?;
26 let mapping = unsafe { memmap2::Mmap::map(&file)? };
27
28 let header = Self::parse_header_from_mmap(&mapping)?;
30
31 Ok(Self {
32 file,
33 mapping: Some(mapping),
34 header,
35 })
36 }
37
38 fn parse_header_from_mmap(mmap: &[u8]) -> Result<SafeTensorsHeader> {
39 let header_len = u64::from_le_bytes([
41 mmap[0], mmap[1], mmap[2], mmap[3], mmap[4], mmap[5], mmap[6], mmap[7],
42 ]);
43
44 let header_bytes = &mmap[8..8 + header_len as usize];
46 let header_str = std::str::from_utf8(header_bytes).map_err(|e| {
47 TrustformersError::weight_load_error(format!(
48 "Invalid UTF-8 in SafeTensors header: {}",
49 e
50 ))
51 })?;
52
53 serde_json::from_str(header_str).map_err(|e| {
54 TrustformersError::serialization_error(format!(
55 "Failed to parse SafeTensors header: {}",
56 e
57 ))
58 })
59 }
60
61 fn mmap_bytes_to_tensor(&self, data: &[u8], dtype: &str, shape: &[usize]) -> Result<Tensor> {
62 match dtype {
63 "F32" => {
64 let floats: Vec<f32> = data
65 .chunks_exact(4)
66 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
67 .collect();
68 Tensor::from_vec(floats, shape)
69 },
70 "F16" => {
71 let floats: Vec<f32> = data
72 .chunks_exact(2)
73 .map(|chunk| {
74 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
75 half::f16::from_bits(bits).to_f32()
76 })
77 .collect();
78 Tensor::from_vec(floats, shape)
79 },
80 _ => Err(invalid_format(
81 "data type",
82 format!("Unsupported dtype: {}", dtype),
83 )),
84 }
85 }
86}
87
88impl WeightLoader for MemoryMappedLoader {
89 fn load_tensor(&mut self, name: &str) -> Result<Tensor> {
90 if let Some(tensor_info) = self.header.tensors.get(name) {
91 if let Some(ref mapping) = self.mapping {
92 let start = tensor_info.data_offsets[0] as usize;
93 let end = tensor_info.data_offsets[1] as usize;
94 let data = &mapping[start..end];
95
96 self.mmap_bytes_to_tensor(data, &tensor_info.dtype, &tensor_info.shape)
97 } else {
98 Err(TrustformersError::invalid_state(
99 "No memory mapping".to_string(),
100 ))
101 }
102 } else {
103 Err(TrustformersError::runtime_error(format!(
104 "Tensor not found: {}",
105 name
106 )))
107 }
108 }
109
110 fn list_tensors(&self) -> Result<Vec<String>> {
111 Ok(self.header.tensors.keys().cloned().collect())
112 }
113
114 fn tensor_info(&self, name: &str) -> Result<Option<TensorMetadata>> {
115 if let Some(tensor_info) = self.header.tensors.get(name) {
116 let dtype = match tensor_info.dtype.as_str() {
117 "F32" => WeightDataType::Float32,
118 "F16" => WeightDataType::Float16,
119 "I8" => WeightDataType::Int8,
120 _ => WeightDataType::Float32,
121 };
122
123 Ok(Some(TensorMetadata {
124 shape: tensor_info.shape.clone(),
125 dtype,
126 size_bytes: tensor_info.data_offsets[1] - tensor_info.data_offsets[0],
127 offset: tensor_info.data_offsets[0],
128 }))
129 } else {
130 Ok(None)
131 }
132 }
133
134 fn close(&mut self) -> Result<()> {
135 self.mapping.take();
136 Ok(())
137 }
138}