Skip to main content

trustformers_models/weight_loading/
memory_mapped.rs

//! Memory-mapped weight loader.
//!
//! This module provides memory-mapped file support for efficient weight loading
//! without reading entire files into memory.
5use std::fs::File;
6use std::path::Path;
7use trustformers_core::{
8    errors::{invalid_format, Result, TrustformersError},
9    tensor::Tensor,
10};
11
12use super::config::WeightDataType;
13use super::huggingface::{SafeTensorsHeader, TensorMetadata, WeightLoader};
14
15/// Memory-mapped weight loader for efficient access to large weight files
16pub struct MemoryMappedLoader {
17    #[allow(dead_code)]
18    file: File,
19    mapping: Option<memmap2::Mmap>,
20    header: SafeTensorsHeader,
21}
22
23impl MemoryMappedLoader {
24    pub fn new(path: impl AsRef<Path>) -> Result<Self> {
25        let file = File::open(path)?;
26        let mapping = unsafe { memmap2::Mmap::map(&file)? };
27
28        // Parse header from memory map
29        let header = Self::parse_header_from_mmap(&mapping)?;
30
31        Ok(Self {
32            file,
33            mapping: Some(mapping),
34            header,
35        })
36    }
37
38    fn parse_header_from_mmap(mmap: &[u8]) -> Result<SafeTensorsHeader> {
39        // Read header length
40        let header_len = u64::from_le_bytes([
41            mmap[0], mmap[1], mmap[2], mmap[3], mmap[4], mmap[5], mmap[6], mmap[7],
42        ]);
43
44        // Parse header JSON
45        let header_bytes = &mmap[8..8 + header_len as usize];
46        let header_str = std::str::from_utf8(header_bytes).map_err(|e| {
47            TrustformersError::weight_load_error(format!(
48                "Invalid UTF-8 in SafeTensors header: {}",
49                e
50            ))
51        })?;
52
53        serde_json::from_str(header_str).map_err(|e| {
54            TrustformersError::serialization_error(format!(
55                "Failed to parse SafeTensors header: {}",
56                e
57            ))
58        })
59    }
60
61    fn mmap_bytes_to_tensor(&self, data: &[u8], dtype: &str, shape: &[usize]) -> Result<Tensor> {
62        match dtype {
63            "F32" => {
64                let floats: Vec<f32> = data
65                    .chunks_exact(4)
66                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
67                    .collect();
68                Tensor::from_vec(floats, shape)
69            },
70            "F16" => {
71                let floats: Vec<f32> = data
72                    .chunks_exact(2)
73                    .map(|chunk| {
74                        let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
75                        half::f16::from_bits(bits).to_f32()
76                    })
77                    .collect();
78                Tensor::from_vec(floats, shape)
79            },
80            _ => Err(invalid_format(
81                "data type",
82                format!("Unsupported dtype: {}", dtype),
83            )),
84        }
85    }
86}
87
88impl WeightLoader for MemoryMappedLoader {
89    fn load_tensor(&mut self, name: &str) -> Result<Tensor> {
90        if let Some(tensor_info) = self.header.tensors.get(name) {
91            if let Some(ref mapping) = self.mapping {
92                let start = tensor_info.data_offsets[0] as usize;
93                let end = tensor_info.data_offsets[1] as usize;
94                let data = &mapping[start..end];
95
96                self.mmap_bytes_to_tensor(data, &tensor_info.dtype, &tensor_info.shape)
97            } else {
98                Err(TrustformersError::invalid_state(
99                    "No memory mapping".to_string(),
100                ))
101            }
102        } else {
103            Err(TrustformersError::runtime_error(format!(
104                "Tensor not found: {}",
105                name
106            )))
107        }
108    }
109
110    fn list_tensors(&self) -> Result<Vec<String>> {
111        Ok(self.header.tensors.keys().cloned().collect())
112    }
113
114    fn tensor_info(&self, name: &str) -> Result<Option<TensorMetadata>> {
115        if let Some(tensor_info) = self.header.tensors.get(name) {
116            let dtype = match tensor_info.dtype.as_str() {
117                "F32" => WeightDataType::Float32,
118                "F16" => WeightDataType::Float16,
119                "I8" => WeightDataType::Int8,
120                _ => WeightDataType::Float32,
121            };
122
123            Ok(Some(TensorMetadata {
124                shape: tensor_info.shape.clone(),
125                dtype,
126                size_bytes: tensor_info.data_offsets[1] - tensor_info.data_offsets[0],
127                offset: tensor_info.data_offsets[0],
128            }))
129        } else {
130            Ok(None)
131        }
132    }
133
134    fn close(&mut self) -> Result<()> {
135        self.mapping.take();
136        Ok(())
137    }
138}