// Source file: ruvector_sona/export/safetensors.rs
1//! SafeTensors Export - PEFT-compatible LoRA weight serialization
2//!
3//! Exports SONA's learned LoRA weights in SafeTensors format for use with
4//! HuggingFace's PEFT library and transformers ecosystem.
5
6use super::{ExportConfig, ExportError, ExportResult, ExportType};
7use crate::engine::SonaEngine;
8use std::collections::HashMap;
9use std::path::Path;
10
11#[cfg(feature = "serde-support")]
12use serde::{Deserialize, Serialize};
13
14/// SafeTensors exporter for LoRA weights
15pub struct SafeTensorsExporter<'a> {
16    _config: &'a ExportConfig,
17}
18
19impl<'a> SafeTensorsExporter<'a> {
20    /// Create new SafeTensors exporter
21    pub fn new(config: &'a ExportConfig) -> Self {
22        Self { _config: config }
23    }
24
25    /// Export engine's LoRA weights to SafeTensors format
26    pub fn export_engine<P: AsRef<Path>>(
27        &self,
28        engine: &SonaEngine,
29        output_dir: P,
30    ) -> Result<ExportResult, ExportError> {
31        let output_dir = output_dir.as_ref();
32        std::fs::create_dir_all(output_dir).map_err(ExportError::Io)?;
33
34        // Get LoRA state from engine
35        let lora_state = engine.export_lora_state();
36
37        // Build tensor data map
38        let mut tensors: HashMap<String, TensorData> = HashMap::new();
39
40        // Export MicroLoRA weights (rank 1-2)
41        for (i, layer) in lora_state.micro_lora_layers.iter().enumerate() {
42            let a_key = format!(
43                "base_model.model.layers.{}.self_attn.micro_lora_A.weight",
44                i
45            );
46            let b_key = format!(
47                "base_model.model.layers.{}.self_attn.micro_lora_B.weight",
48                i
49            );
50
51            tensors.insert(
52                a_key,
53                TensorData {
54                    data: layer.lora_a.clone(),
55                    shape: vec![layer.rank, layer.input_dim],
56                    dtype: "F32".to_string(),
57                },
58            );
59
60            tensors.insert(
61                b_key,
62                TensorData {
63                    data: layer.lora_b.clone(),
64                    shape: vec![layer.output_dim, layer.rank],
65                    dtype: "F32".to_string(),
66                },
67            );
68        }
69
70        // Export BaseLoRA weights (rank 4-16)
71        for (i, layer) in lora_state.base_lora_layers.iter().enumerate() {
72            // Q projection
73            let q_a_key = format!(
74                "base_model.model.layers.{}.self_attn.q_proj.lora_A.weight",
75                i
76            );
77            let q_b_key = format!(
78                "base_model.model.layers.{}.self_attn.q_proj.lora_B.weight",
79                i
80            );
81
82            tensors.insert(
83                q_a_key,
84                TensorData {
85                    data: layer.lora_a.clone(),
86                    shape: vec![layer.rank, layer.input_dim],
87                    dtype: "F32".to_string(),
88                },
89            );
90
91            tensors.insert(
92                q_b_key,
93                TensorData {
94                    data: layer.lora_b.clone(),
95                    shape: vec![layer.output_dim, layer.rank],
96                    dtype: "F32".to_string(),
97                },
98            );
99
100            // K projection
101            let k_a_key = format!(
102                "base_model.model.layers.{}.self_attn.k_proj.lora_A.weight",
103                i
104            );
105            let k_b_key = format!(
106                "base_model.model.layers.{}.self_attn.k_proj.lora_B.weight",
107                i
108            );
109
110            tensors.insert(
111                k_a_key,
112                TensorData {
113                    data: layer.lora_a.clone(),
114                    shape: vec![layer.rank, layer.input_dim],
115                    dtype: "F32".to_string(),
116                },
117            );
118
119            tensors.insert(
120                k_b_key,
121                TensorData {
122                    data: layer.lora_b.clone(),
123                    shape: vec![layer.output_dim, layer.rank],
124                    dtype: "F32".to_string(),
125                },
126            );
127
128            // V projection
129            let v_a_key = format!(
130                "base_model.model.layers.{}.self_attn.v_proj.lora_A.weight",
131                i
132            );
133            let v_b_key = format!(
134                "base_model.model.layers.{}.self_attn.v_proj.lora_B.weight",
135                i
136            );
137
138            tensors.insert(
139                v_a_key,
140                TensorData {
141                    data: layer.lora_a.clone(),
142                    shape: vec![layer.rank, layer.input_dim],
143                    dtype: "F32".to_string(),
144                },
145            );
146
147            tensors.insert(
148                v_b_key,
149                TensorData {
150                    data: layer.lora_b.clone(),
151                    shape: vec![layer.output_dim, layer.rank],
152                    dtype: "F32".to_string(),
153                },
154            );
155
156            // O projection
157            let o_a_key = format!(
158                "base_model.model.layers.{}.self_attn.o_proj.lora_A.weight",
159                i
160            );
161            let o_b_key = format!(
162                "base_model.model.layers.{}.self_attn.o_proj.lora_B.weight",
163                i
164            );
165
166            tensors.insert(
167                o_a_key,
168                TensorData {
169                    data: layer.lora_a.clone(),
170                    shape: vec![layer.rank, layer.input_dim],
171                    dtype: "F32".to_string(),
172                },
173            );
174
175            tensors.insert(
176                o_b_key,
177                TensorData {
178                    data: layer.lora_b.clone(),
179                    shape: vec![layer.output_dim, layer.rank],
180                    dtype: "F32".to_string(),
181                },
182            );
183        }
184
185        // Serialize to SafeTensors format
186        let safetensors_path = output_dir.join("adapter_model.safetensors");
187        let bytes = self.serialize_safetensors(&tensors)?;
188        std::fs::write(&safetensors_path, &bytes).map_err(ExportError::Io)?;
189
190        let size_bytes = bytes.len() as u64;
191
192        Ok(ExportResult {
193            export_type: ExportType::SafeTensors,
194            items_exported: tensors.len(),
195            output_path: safetensors_path.to_string_lossy().to_string(),
196            size_bytes,
197        })
198    }
199
200    /// Serialize tensors to SafeTensors binary format
201    fn serialize_safetensors(
202        &self,
203        tensors: &HashMap<String, TensorData>,
204    ) -> Result<Vec<u8>, ExportError> {
205        // SafeTensors format:
206        // 8 bytes: header size (little endian u64)
207        // N bytes: JSON header with tensor metadata
208        // ... tensor data (aligned to 8 bytes)
209
210        let mut header_data: HashMap<String, TensorMetadata> = HashMap::new();
211        let mut tensor_bytes: Vec<u8> = Vec::new();
212
213        // Sort keys for deterministic output
214        let mut keys: Vec<_> = tensors.keys().collect();
215        keys.sort();
216
217        for key in keys {
218            let tensor = &tensors[key];
219
220            // Align to 8 bytes
221            let padding = (8 - (tensor_bytes.len() % 8)) % 8;
222            tensor_bytes.extend(vec![0u8; padding]);
223
224            let start_offset = tensor_bytes.len();
225
226            // Write tensor data
227            for &val in &tensor.data {
228                tensor_bytes.extend_from_slice(&val.to_le_bytes());
229            }
230
231            let end_offset = tensor_bytes.len();
232
233            header_data.insert(
234                key.clone(),
235                TensorMetadata {
236                    dtype: tensor.dtype.clone(),
237                    shape: tensor.shape.clone(),
238                    data_offsets: [start_offset, end_offset],
239                },
240            );
241        }
242
243        // Serialize header to JSON
244        let header_json =
245            serde_json::to_string(&header_data).map_err(ExportError::Serialization)?;
246        let header_bytes = header_json.as_bytes();
247
248        // Build final buffer
249        let mut result = Vec::new();
250
251        // Header size (8 bytes, little endian)
252        result.extend_from_slice(&(header_bytes.len() as u64).to_le_bytes());
253
254        // Header JSON
255        result.extend_from_slice(header_bytes);
256
257        // Tensor data
258        result.extend(tensor_bytes);
259
260        Ok(result)
261    }
262}
263
/// A single tensor staged for SafeTensors serialization.
#[derive(Clone, Debug)]
pub struct TensorData {
    /// Flattened tensor values
    pub data: Vec<f32>,
    /// Tensor dimensions
    pub shape: Vec<usize>,
    /// Dtype tag written to the header ("F32", "F16", "BF16", ...)
    pub dtype: String,
}
274
/// Per-tensor entry in the SafeTensors JSON header.
///
/// NOTE(review): this type is only compiled with the "serde-support" feature,
/// yet `serialize_safetensors` references it unconditionally — confirm the
/// feature gating is consistent across the module.
#[cfg(feature = "serde-support")]
#[derive(Clone, Debug, Serialize, Deserialize)]
struct TensorMetadata {
    // Dtype tag, e.g. "F32"
    dtype: String,
    // Tensor dimensions
    shape: Vec<usize>,
    // [start, end) byte offsets into the data section
    data_offsets: [usize; 2],
}
283
/// Snapshot of a single LoRA layer's weights for export.
#[derive(Clone, Debug)]
pub struct LoRALayerState {
    /// Flattened A matrix, shaped (rank x input_dim)
    pub lora_a: Vec<f32>,
    /// Flattened B matrix, shaped (output_dim x rank)
    pub lora_b: Vec<f32>,
    /// LoRA rank
    pub rank: usize,
    /// Input dimension
    pub input_dim: usize,
    /// Output dimension
    pub output_dim: usize,
}
298
299/// Complete LoRA state for export
300#[derive(Clone, Debug, Default)]
301pub struct LoRAState {
302    /// MicroLoRA layers (instant adaptation)
303    pub micro_lora_layers: Vec<LoRALayerState>,
304    /// BaseLoRA layers (background learning)
305    pub base_lora_layers: Vec<LoRALayerState>,
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    // TensorData should hold exactly the values and shape it was given.
    #[test]
    fn test_tensor_data_creation() {
        let tensor = TensorData {
            data: vec![1.0, 2.0, 3.0, 4.0],
            shape: vec![2, 2],
            dtype: "F32".to_string(),
        };

        assert_eq!(4, tensor.data.len());
        assert_eq!(vec![2, 2], tensor.shape);
    }

    // LoRALayerState is a plain data carrier; verify field round-trip.
    #[test]
    fn test_lora_layer_state() {
        let layer = LoRALayerState {
            lora_a: vec![0.1, 0.2, 0.3, 0.4],
            lora_b: vec![0.5, 0.6, 0.7, 0.8],
            rank: 2,
            input_dim: 2,
            output_dim: 2,
        };

        assert_eq!(2, layer.rank);
        assert_eq!(4, layer.lora_a.len());
    }
}