1use crate::engine::SonaEngine;
7use crate::lora::{MicroLoRA, BaseLoRA};
8use super::{ExportConfig, ExportResult, ExportType, ExportError};
9use std::path::Path;
10use std::collections::HashMap;
11
12#[cfg(feature = "serde-support")]
13use serde::{Deserialize, Serialize};
14
/// Exports a `SonaEngine`'s LoRA adapter weights to the SafeTensors on-disk
/// format (a PEFT-style `adapter_model.safetensors` file).
pub struct SafeTensorsExporter<'a> {
    // Export settings borrowed for the exporter's lifetime.
    // NOTE(review): not read by the visible export path in this file —
    // confirm it is consumed elsewhere or by future options.
    config: &'a ExportConfig,
}
19
20impl<'a> SafeTensorsExporter<'a> {
21 pub fn new(config: &'a ExportConfig) -> Self {
23 Self { config }
24 }
25
26 pub fn export_engine<P: AsRef<Path>>(
28 &self,
29 engine: &SonaEngine,
30 output_dir: P,
31 ) -> Result<ExportResult, ExportError> {
32 let output_dir = output_dir.as_ref();
33 std::fs::create_dir_all(output_dir).map_err(ExportError::Io)?;
34
35 let lora_state = engine.export_lora_state();
37
38 let mut tensors: HashMap<String, TensorData> = HashMap::new();
40
41 for (i, layer) in lora_state.micro_lora_layers.iter().enumerate() {
43 let a_key = format!("base_model.model.layers.{}.self_attn.micro_lora_A.weight", i);
44 let b_key = format!("base_model.model.layers.{}.self_attn.micro_lora_B.weight", i);
45
46 tensors.insert(a_key, TensorData {
47 data: layer.lora_a.clone(),
48 shape: vec![layer.rank, layer.input_dim],
49 dtype: "F32".to_string(),
50 });
51
52 tensors.insert(b_key, TensorData {
53 data: layer.lora_b.clone(),
54 shape: vec![layer.output_dim, layer.rank],
55 dtype: "F32".to_string(),
56 });
57 }
58
59 for (i, layer) in lora_state.base_lora_layers.iter().enumerate() {
61 let q_a_key = format!("base_model.model.layers.{}.self_attn.q_proj.lora_A.weight", i);
63 let q_b_key = format!("base_model.model.layers.{}.self_attn.q_proj.lora_B.weight", i);
64
65 tensors.insert(q_a_key, TensorData {
66 data: layer.lora_a.clone(),
67 shape: vec![layer.rank, layer.input_dim],
68 dtype: "F32".to_string(),
69 });
70
71 tensors.insert(q_b_key, TensorData {
72 data: layer.lora_b.clone(),
73 shape: vec![layer.output_dim, layer.rank],
74 dtype: "F32".to_string(),
75 });
76
77 let k_a_key = format!("base_model.model.layers.{}.self_attn.k_proj.lora_A.weight", i);
79 let k_b_key = format!("base_model.model.layers.{}.self_attn.k_proj.lora_B.weight", i);
80
81 tensors.insert(k_a_key, TensorData {
82 data: layer.lora_a.clone(),
83 shape: vec![layer.rank, layer.input_dim],
84 dtype: "F32".to_string(),
85 });
86
87 tensors.insert(k_b_key, TensorData {
88 data: layer.lora_b.clone(),
89 shape: vec![layer.output_dim, layer.rank],
90 dtype: "F32".to_string(),
91 });
92
93 let v_a_key = format!("base_model.model.layers.{}.self_attn.v_proj.lora_A.weight", i);
95 let v_b_key = format!("base_model.model.layers.{}.self_attn.v_proj.lora_B.weight", i);
96
97 tensors.insert(v_a_key, TensorData {
98 data: layer.lora_a.clone(),
99 shape: vec![layer.rank, layer.input_dim],
100 dtype: "F32".to_string(),
101 });
102
103 tensors.insert(v_b_key, TensorData {
104 data: layer.lora_b.clone(),
105 shape: vec![layer.output_dim, layer.rank],
106 dtype: "F32".to_string(),
107 });
108
109 let o_a_key = format!("base_model.model.layers.{}.self_attn.o_proj.lora_A.weight", i);
111 let o_b_key = format!("base_model.model.layers.{}.self_attn.o_proj.lora_B.weight", i);
112
113 tensors.insert(o_a_key, TensorData {
114 data: layer.lora_a.clone(),
115 shape: vec![layer.rank, layer.input_dim],
116 dtype: "F32".to_string(),
117 });
118
119 tensors.insert(o_b_key, TensorData {
120 data: layer.lora_b.clone(),
121 shape: vec![layer.output_dim, layer.rank],
122 dtype: "F32".to_string(),
123 });
124 }
125
126 let safetensors_path = output_dir.join("adapter_model.safetensors");
128 let bytes = self.serialize_safetensors(&tensors)?;
129 std::fs::write(&safetensors_path, &bytes).map_err(ExportError::Io)?;
130
131 let size_bytes = bytes.len() as u64;
132
133 Ok(ExportResult {
134 export_type: ExportType::SafeTensors,
135 items_exported: tensors.len(),
136 output_path: safetensors_path.to_string_lossy().to_string(),
137 size_bytes,
138 })
139 }
140
141 fn serialize_safetensors(&self, tensors: &HashMap<String, TensorData>) -> Result<Vec<u8>, ExportError> {
143 let mut header_data: HashMap<String, TensorMetadata> = HashMap::new();
149 let mut data_offset: usize = 0;
150 let mut tensor_bytes: Vec<u8> = Vec::new();
151
152 let mut keys: Vec<_> = tensors.keys().collect();
154 keys.sort();
155
156 for key in keys {
157 let tensor = &tensors[key];
158 let tensor_size = tensor.data.len() * 4; let padding = (8 - (tensor_bytes.len() % 8)) % 8;
162 tensor_bytes.extend(vec![0u8; padding]);
163
164 let start_offset = tensor_bytes.len();
165
166 for &val in &tensor.data {
168 tensor_bytes.extend_from_slice(&val.to_le_bytes());
169 }
170
171 let end_offset = tensor_bytes.len();
172
173 header_data.insert(key.clone(), TensorMetadata {
174 dtype: tensor.dtype.clone(),
175 shape: tensor.shape.clone(),
176 data_offsets: [start_offset, end_offset],
177 });
178 }
179
180 let header_json = serde_json::to_string(&header_data)
182 .map_err(ExportError::Serialization)?;
183 let header_bytes = header_json.as_bytes();
184
185 let mut result = Vec::new();
187
188 result.extend_from_slice(&(header_bytes.len() as u64).to_le_bytes());
190
191 result.extend_from_slice(header_bytes);
193
194 result.extend(tensor_bytes);
196
197 Ok(result)
198 }
199}
200
/// An in-memory tensor staged for SafeTensors serialization.
#[derive(Clone, Debug)]
pub struct TensorData {
    /// Flattened element values (this exporter emits f32 only).
    pub data: Vec<f32>,
    /// Dimension sizes recorded in the SafeTensors header.
    pub shape: Vec<usize>,
    /// SafeTensors dtype tag; always "F32" in this file.
    pub dtype: String,
}
211
/// Per-tensor entry in the SafeTensors JSON header.
///
/// NOTE(review): this type is gated behind the `serde-support` feature, but
/// `SafeTensorsExporter::serialize_safetensors` references it (and
/// `serde_json`) unconditionally — a build without the feature will fail to
/// compile. Confirm the feature wiring (gate the exporter too, or drop the
/// gate here).
#[cfg(feature = "serde-support")]
#[derive(Clone, Debug, Serialize, Deserialize)]
struct TensorMetadata {
    /// SafeTensors dtype tag (e.g. "F32").
    dtype: String,
    /// Tensor dimension sizes.
    shape: Vec<usize>,
    /// `[start, end)` byte offsets relative to the start of the data section.
    data_offsets: [usize; 2],
}
220
/// Flattened weights and dimensions for a single LoRA adapter layer.
#[derive(Clone, Debug)]
pub struct LoRALayerState {
    /// Flattened A matrix; exported with shape `[rank, input_dim]`.
    pub lora_a: Vec<f32>,
    /// Flattened B matrix; exported with shape `[output_dim, rank]`.
    pub lora_b: Vec<f32>,
    /// LoRA rank — the inner dimension shared by A and B.
    pub rank: usize,
    /// Input feature dimension.
    pub input_dim: usize,
    /// Output feature dimension.
    pub output_dim: usize,
}
235
/// Snapshot of all LoRA layers exported from an engine.
#[derive(Clone, Debug, Default)]
pub struct LoRAState {
    /// Micro-LoRA layers, indexed by layer number.
    pub micro_lora_layers: Vec<LoRALayerState>,
    /// Base LoRA layers, indexed by layer number.
    pub base_lora_layers: Vec<LoRALayerState>,
}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// A small 2x2 tensor keeps the fields it was built with.
    #[test]
    fn test_tensor_data_creation() {
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = TensorData {
            data: values,
            shape: vec![2, 2],
            dtype: String::from("F32"),
        };

        assert_eq!(4, tensor.data.len());
        assert_eq!(vec![2, 2], tensor.shape);
    }

    /// A layer state stores its rank and flattened weight matrices.
    #[test]
    fn test_lora_layer_state() {
        let layer = LoRALayerState {
            lora_a: vec![0.1, 0.2, 0.3, 0.4],
            lora_b: vec![0.5, 0.6, 0.7, 0.8],
            rank: 2,
            input_dim: 2,
            output_dim: 2,
        };

        assert_eq!(2, layer.rank);
        assert_eq!(4, layer.lora_a.len());
    }
}