// ruvector_sona/export/safetensors.rs

//! SafeTensors Export - PEFT-compatible LoRA weight serialization
//!
//! Exports SONA's learned LoRA weights in SafeTensors format for use with
//! HuggingFace's PEFT library and transformers ecosystem.
6use crate::engine::SonaEngine;
7use crate::lora::{MicroLoRA, BaseLoRA};
8use super::{ExportConfig, ExportResult, ExportType, ExportError};
9use std::path::Path;
10use std::collections::HashMap;
11
12#[cfg(feature = "serde-support")]
13use serde::{Deserialize, Serialize};
14
/// SafeTensors exporter for LoRA weights.
///
/// Borrows the shared [`ExportConfig`] for the exporter's lifetime `'a`.
pub struct SafeTensorsExporter<'a> {
    // Export configuration. NOTE(review): never read anywhere in this file's
    // visible code — confirm whether config options should influence output.
    config: &'a ExportConfig,
}
19
20impl<'a> SafeTensorsExporter<'a> {
21    /// Create new SafeTensors exporter
22    pub fn new(config: &'a ExportConfig) -> Self {
23        Self { config }
24    }
25
26    /// Export engine's LoRA weights to SafeTensors format
27    pub fn export_engine<P: AsRef<Path>>(
28        &self,
29        engine: &SonaEngine,
30        output_dir: P,
31    ) -> Result<ExportResult, ExportError> {
32        let output_dir = output_dir.as_ref();
33        std::fs::create_dir_all(output_dir).map_err(ExportError::Io)?;
34
35        // Get LoRA state from engine
36        let lora_state = engine.export_lora_state();
37
38        // Build tensor data map
39        let mut tensors: HashMap<String, TensorData> = HashMap::new();
40
41        // Export MicroLoRA weights (rank 1-2)
42        for (i, layer) in lora_state.micro_lora_layers.iter().enumerate() {
43            let a_key = format!("base_model.model.layers.{}.self_attn.micro_lora_A.weight", i);
44            let b_key = format!("base_model.model.layers.{}.self_attn.micro_lora_B.weight", i);
45
46            tensors.insert(a_key, TensorData {
47                data: layer.lora_a.clone(),
48                shape: vec![layer.rank, layer.input_dim],
49                dtype: "F32".to_string(),
50            });
51
52            tensors.insert(b_key, TensorData {
53                data: layer.lora_b.clone(),
54                shape: vec![layer.output_dim, layer.rank],
55                dtype: "F32".to_string(),
56            });
57        }
58
59        // Export BaseLoRA weights (rank 4-16)
60        for (i, layer) in lora_state.base_lora_layers.iter().enumerate() {
61            // Q projection
62            let q_a_key = format!("base_model.model.layers.{}.self_attn.q_proj.lora_A.weight", i);
63            let q_b_key = format!("base_model.model.layers.{}.self_attn.q_proj.lora_B.weight", i);
64
65            tensors.insert(q_a_key, TensorData {
66                data: layer.lora_a.clone(),
67                shape: vec![layer.rank, layer.input_dim],
68                dtype: "F32".to_string(),
69            });
70
71            tensors.insert(q_b_key, TensorData {
72                data: layer.lora_b.clone(),
73                shape: vec![layer.output_dim, layer.rank],
74                dtype: "F32".to_string(),
75            });
76
77            // K projection
78            let k_a_key = format!("base_model.model.layers.{}.self_attn.k_proj.lora_A.weight", i);
79            let k_b_key = format!("base_model.model.layers.{}.self_attn.k_proj.lora_B.weight", i);
80
81            tensors.insert(k_a_key, TensorData {
82                data: layer.lora_a.clone(),
83                shape: vec![layer.rank, layer.input_dim],
84                dtype: "F32".to_string(),
85            });
86
87            tensors.insert(k_b_key, TensorData {
88                data: layer.lora_b.clone(),
89                shape: vec![layer.output_dim, layer.rank],
90                dtype: "F32".to_string(),
91            });
92
93            // V projection
94            let v_a_key = format!("base_model.model.layers.{}.self_attn.v_proj.lora_A.weight", i);
95            let v_b_key = format!("base_model.model.layers.{}.self_attn.v_proj.lora_B.weight", i);
96
97            tensors.insert(v_a_key, TensorData {
98                data: layer.lora_a.clone(),
99                shape: vec![layer.rank, layer.input_dim],
100                dtype: "F32".to_string(),
101            });
102
103            tensors.insert(v_b_key, TensorData {
104                data: layer.lora_b.clone(),
105                shape: vec![layer.output_dim, layer.rank],
106                dtype: "F32".to_string(),
107            });
108
109            // O projection
110            let o_a_key = format!("base_model.model.layers.{}.self_attn.o_proj.lora_A.weight", i);
111            let o_b_key = format!("base_model.model.layers.{}.self_attn.o_proj.lora_B.weight", i);
112
113            tensors.insert(o_a_key, TensorData {
114                data: layer.lora_a.clone(),
115                shape: vec![layer.rank, layer.input_dim],
116                dtype: "F32".to_string(),
117            });
118
119            tensors.insert(o_b_key, TensorData {
120                data: layer.lora_b.clone(),
121                shape: vec![layer.output_dim, layer.rank],
122                dtype: "F32".to_string(),
123            });
124        }
125
126        // Serialize to SafeTensors format
127        let safetensors_path = output_dir.join("adapter_model.safetensors");
128        let bytes = self.serialize_safetensors(&tensors)?;
129        std::fs::write(&safetensors_path, &bytes).map_err(ExportError::Io)?;
130
131        let size_bytes = bytes.len() as u64;
132
133        Ok(ExportResult {
134            export_type: ExportType::SafeTensors,
135            items_exported: tensors.len(),
136            output_path: safetensors_path.to_string_lossy().to_string(),
137            size_bytes,
138        })
139    }
140
141    /// Serialize tensors to SafeTensors binary format
142    fn serialize_safetensors(&self, tensors: &HashMap<String, TensorData>) -> Result<Vec<u8>, ExportError> {
143        // SafeTensors format:
144        // 8 bytes: header size (little endian u64)
145        // N bytes: JSON header with tensor metadata
146        // ... tensor data (aligned to 8 bytes)
147
148        let mut header_data: HashMap<String, TensorMetadata> = HashMap::new();
149        let mut data_offset: usize = 0;
150        let mut tensor_bytes: Vec<u8> = Vec::new();
151
152        // Sort keys for deterministic output
153        let mut keys: Vec<_> = tensors.keys().collect();
154        keys.sort();
155
156        for key in keys {
157            let tensor = &tensors[key];
158            let tensor_size = tensor.data.len() * 4; // f32 = 4 bytes
159
160            // Align to 8 bytes
161            let padding = (8 - (tensor_bytes.len() % 8)) % 8;
162            tensor_bytes.extend(vec![0u8; padding]);
163
164            let start_offset = tensor_bytes.len();
165
166            // Write tensor data
167            for &val in &tensor.data {
168                tensor_bytes.extend_from_slice(&val.to_le_bytes());
169            }
170
171            let end_offset = tensor_bytes.len();
172
173            header_data.insert(key.clone(), TensorMetadata {
174                dtype: tensor.dtype.clone(),
175                shape: tensor.shape.clone(),
176                data_offsets: [start_offset, end_offset],
177            });
178        }
179
180        // Serialize header to JSON
181        let header_json = serde_json::to_string(&header_data)
182            .map_err(ExportError::Serialization)?;
183        let header_bytes = header_json.as_bytes();
184
185        // Build final buffer
186        let mut result = Vec::new();
187
188        // Header size (8 bytes, little endian)
189        result.extend_from_slice(&(header_bytes.len() as u64).to_le_bytes());
190
191        // Header JSON
192        result.extend_from_slice(header_bytes);
193
194        // Tensor data
195        result.extend(tensor_bytes);
196
197        Ok(result)
198    }
199}
200
/// Tensor data for export.
#[derive(Clone, Debug)]
pub struct TensorData {
    /// Flattened tensor values (length should equal the product of `shape`;
    /// not validated here). Layout is presumably row-major — confirm against
    /// consumers.
    pub data: Vec<f32>,
    /// Tensor shape, e.g. `[rank, input_dim]` for a LoRA A matrix.
    pub shape: Vec<usize>,
    /// Data type tag written into the SafeTensors header (F32, F16, BF16, etc.).
    /// This module only ever produces "F32", matching the `f32` payload.
    pub dtype: String,
}
211
/// Tensor metadata entry for the SafeTensors JSON header: dtype tag, logical
/// shape, and `[start, end)` byte offsets into the data section (relative to
/// the end of the header).
///
/// NOTE(review): this type is gated behind `serde-support`, yet
/// `serialize_safetensors` uses it (and `serde_json`) unconditionally — a
/// build without the feature will not compile. Confirm whether the gate
/// should be dropped or the exporter gated as well.
#[cfg(feature = "serde-support")]
#[derive(Clone, Debug, Serialize, Deserialize)]
struct TensorMetadata {
    dtype: String,
    shape: Vec<usize>,
    data_offsets: [usize; 2],
}
220
/// LoRA layer state for export.
#[derive(Clone, Debug)]
pub struct LoRALayerState {
    /// LoRA A matrix, flattened (rank x input_dim; expected length
    /// `rank * input_dim` — not validated here).
    pub lora_a: Vec<f32>,
    /// LoRA B matrix, flattened (output_dim x rank; expected length
    /// `output_dim * rank` — not validated here).
    pub lora_b: Vec<f32>,
    /// LoRA rank (low-rank factor dimension shared by A and B).
    pub rank: usize,
    /// Input dimension of the adapted projection.
    pub input_dim: usize,
    /// Output dimension of the adapted projection.
    pub output_dim: usize,
}
235
/// Complete LoRA state for export, as returned by the engine's
/// `export_lora_state` (see `export_engine`).
#[derive(Clone, Debug, Default)]
pub struct LoRAState {
    /// MicroLoRA layers (instant adaptation, rank 1-2).
    pub micro_lora_layers: Vec<LoRALayerState>,
    /// BaseLoRA layers (background learning, rank 4-16).
    pub base_lora_layers: Vec<LoRALayerState>,
}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// `TensorData` keeps flattened values alongside their logical shape.
    #[test]
    fn test_tensor_data_creation() {
        let dtype = String::from("F32");
        let tensor = TensorData {
            data: vec![1.0, 2.0, 3.0, 4.0],
            shape: vec![2, 2],
            dtype,
        };

        assert_eq!(4, tensor.data.len());
        assert_eq!(vec![2, 2], tensor.shape);
    }

    /// `LoRALayerState` records both factor matrices and their dimensions.
    #[test]
    fn test_lora_layer_state() {
        let layer = LoRALayerState {
            lora_a: vec![0.1, 0.2, 0.3, 0.4],
            lora_b: vec![0.5, 0.6, 0.7, 0.8],
            rank: 2,
            input_dim: 2,
            output_dim: 2,
        };

        assert_eq!(2, layer.rank);
        assert_eq!(4, layer.lora_a.len());
    }
}