1use super::{ExportConfig, ExportError, ExportResult, ExportType};
7use crate::engine::SonaEngine;
8use std::collections::HashMap;
9use std::path::Path;
10
11#[cfg(feature = "serde-support")]
12use serde::{Deserialize, Serialize};
13
/// Exports a `SonaEngine`'s LoRA adapter state to the safetensors
/// file format (PEFT-style `adapter_model.safetensors`).
pub struct SafeTensorsExporter<'a> {
    // Held for future export options; currently unused (hence the `_` prefix).
    _config: &'a ExportConfig,
}
18
19impl<'a> SafeTensorsExporter<'a> {
20 pub fn new(config: &'a ExportConfig) -> Self {
22 Self { _config: config }
23 }
24
25 pub fn export_engine<P: AsRef<Path>>(
27 &self,
28 engine: &SonaEngine,
29 output_dir: P,
30 ) -> Result<ExportResult, ExportError> {
31 let output_dir = output_dir.as_ref();
32 std::fs::create_dir_all(output_dir).map_err(ExportError::Io)?;
33
34 let lora_state = engine.export_lora_state();
36
37 let mut tensors: HashMap<String, TensorData> = HashMap::new();
39
40 for (i, layer) in lora_state.micro_lora_layers.iter().enumerate() {
42 let a_key = format!(
43 "base_model.model.layers.{}.self_attn.micro_lora_A.weight",
44 i
45 );
46 let b_key = format!(
47 "base_model.model.layers.{}.self_attn.micro_lora_B.weight",
48 i
49 );
50
51 tensors.insert(
52 a_key,
53 TensorData {
54 data: layer.lora_a.clone(),
55 shape: vec![layer.rank, layer.input_dim],
56 dtype: "F32".to_string(),
57 },
58 );
59
60 tensors.insert(
61 b_key,
62 TensorData {
63 data: layer.lora_b.clone(),
64 shape: vec![layer.output_dim, layer.rank],
65 dtype: "F32".to_string(),
66 },
67 );
68 }
69
70 for (i, layer) in lora_state.base_lora_layers.iter().enumerate() {
72 let q_a_key = format!(
74 "base_model.model.layers.{}.self_attn.q_proj.lora_A.weight",
75 i
76 );
77 let q_b_key = format!(
78 "base_model.model.layers.{}.self_attn.q_proj.lora_B.weight",
79 i
80 );
81
82 tensors.insert(
83 q_a_key,
84 TensorData {
85 data: layer.lora_a.clone(),
86 shape: vec![layer.rank, layer.input_dim],
87 dtype: "F32".to_string(),
88 },
89 );
90
91 tensors.insert(
92 q_b_key,
93 TensorData {
94 data: layer.lora_b.clone(),
95 shape: vec![layer.output_dim, layer.rank],
96 dtype: "F32".to_string(),
97 },
98 );
99
100 let k_a_key = format!(
102 "base_model.model.layers.{}.self_attn.k_proj.lora_A.weight",
103 i
104 );
105 let k_b_key = format!(
106 "base_model.model.layers.{}.self_attn.k_proj.lora_B.weight",
107 i
108 );
109
110 tensors.insert(
111 k_a_key,
112 TensorData {
113 data: layer.lora_a.clone(),
114 shape: vec![layer.rank, layer.input_dim],
115 dtype: "F32".to_string(),
116 },
117 );
118
119 tensors.insert(
120 k_b_key,
121 TensorData {
122 data: layer.lora_b.clone(),
123 shape: vec![layer.output_dim, layer.rank],
124 dtype: "F32".to_string(),
125 },
126 );
127
128 let v_a_key = format!(
130 "base_model.model.layers.{}.self_attn.v_proj.lora_A.weight",
131 i
132 );
133 let v_b_key = format!(
134 "base_model.model.layers.{}.self_attn.v_proj.lora_B.weight",
135 i
136 );
137
138 tensors.insert(
139 v_a_key,
140 TensorData {
141 data: layer.lora_a.clone(),
142 shape: vec![layer.rank, layer.input_dim],
143 dtype: "F32".to_string(),
144 },
145 );
146
147 tensors.insert(
148 v_b_key,
149 TensorData {
150 data: layer.lora_b.clone(),
151 shape: vec![layer.output_dim, layer.rank],
152 dtype: "F32".to_string(),
153 },
154 );
155
156 let o_a_key = format!(
158 "base_model.model.layers.{}.self_attn.o_proj.lora_A.weight",
159 i
160 );
161 let o_b_key = format!(
162 "base_model.model.layers.{}.self_attn.o_proj.lora_B.weight",
163 i
164 );
165
166 tensors.insert(
167 o_a_key,
168 TensorData {
169 data: layer.lora_a.clone(),
170 shape: vec![layer.rank, layer.input_dim],
171 dtype: "F32".to_string(),
172 },
173 );
174
175 tensors.insert(
176 o_b_key,
177 TensorData {
178 data: layer.lora_b.clone(),
179 shape: vec![layer.output_dim, layer.rank],
180 dtype: "F32".to_string(),
181 },
182 );
183 }
184
185 let safetensors_path = output_dir.join("adapter_model.safetensors");
187 let bytes = self.serialize_safetensors(&tensors)?;
188 std::fs::write(&safetensors_path, &bytes).map_err(ExportError::Io)?;
189
190 let size_bytes = bytes.len() as u64;
191
192 Ok(ExportResult {
193 export_type: ExportType::SafeTensors,
194 items_exported: tensors.len(),
195 output_path: safetensors_path.to_string_lossy().to_string(),
196 size_bytes,
197 })
198 }
199
200 fn serialize_safetensors(
202 &self,
203 tensors: &HashMap<String, TensorData>,
204 ) -> Result<Vec<u8>, ExportError> {
205 let mut header_data: HashMap<String, TensorMetadata> = HashMap::new();
211 let mut tensor_bytes: Vec<u8> = Vec::new();
212
213 let mut keys: Vec<_> = tensors.keys().collect();
215 keys.sort();
216
217 for key in keys {
218 let tensor = &tensors[key];
219
220 let padding = (8 - (tensor_bytes.len() % 8)) % 8;
222 tensor_bytes.extend(vec![0u8; padding]);
223
224 let start_offset = tensor_bytes.len();
225
226 for &val in &tensor.data {
228 tensor_bytes.extend_from_slice(&val.to_le_bytes());
229 }
230
231 let end_offset = tensor_bytes.len();
232
233 header_data.insert(
234 key.clone(),
235 TensorMetadata {
236 dtype: tensor.dtype.clone(),
237 shape: tensor.shape.clone(),
238 data_offsets: [start_offset, end_offset],
239 },
240 );
241 }
242
243 let header_json =
245 serde_json::to_string(&header_data).map_err(ExportError::Serialization)?;
246 let header_bytes = header_json.as_bytes();
247
248 let mut result = Vec::new();
250
251 result.extend_from_slice(&(header_bytes.len() as u64).to_le_bytes());
253
254 result.extend_from_slice(header_bytes);
256
257 result.extend(tensor_bytes);
259
260 Ok(result)
261 }
262}
263
/// An in-memory tensor staged for safetensors serialization.
#[derive(Clone, Debug)]
pub struct TensorData {
    /// Flattened element values; serialized as little-endian f32 bytes.
    pub data: Vec<f32>,
    /// Dimension sizes, e.g. `[rows, cols]` for a 2-D matrix.
    pub shape: Vec<usize>,
    /// safetensors dtype string; this module always writes `"F32"`.
    pub dtype: String,
}
274
/// One entry of the safetensors JSON header: dtype, shape, and the byte
/// range the tensor occupies within the data section.
///
/// NOTE(review): this type is gated on `serde-support`, but
/// `serialize_safetensors` uses it unconditionally — a build without that
/// feature will not compile. Confirm whether the feature gate should be
/// removed here or extended to the exporter.
#[cfg(feature = "serde-support")]
#[derive(Clone, Debug, Serialize, Deserialize)]
struct TensorMetadata {
    /// safetensors dtype identifier (e.g. `"F32"`).
    dtype: String,
    /// Tensor dimension sizes.
    shape: Vec<usize>,
    /// `[start, end)` byte offsets relative to the start of the data buffer.
    data_offsets: [usize; 2],
}
283
/// Exported weights for a single LoRA adapter layer.
#[derive(Clone, Debug)]
pub struct LoRALayerState {
    /// Flattened A matrix; exported with shape `[rank, input_dim]`
    /// (presumably row-major — confirm against the engine's layout).
    pub lora_a: Vec<f32>,
    /// Flattened B matrix; exported with shape `[output_dim, rank]`.
    pub lora_b: Vec<f32>,
    /// LoRA rank — the inner dimension shared by A and B.
    pub rank: usize,
    /// Input feature dimension.
    pub input_dim: usize,
    /// Output feature dimension.
    pub output_dim: usize,
}
298
/// Snapshot of all LoRA layers exported from the engine.
#[derive(Clone, Debug, Default)]
pub struct LoRAState {
    /// Micro-LoRA layers, written as `micro_lora_{A,B}` tensors.
    pub micro_lora_layers: Vec<LoRALayerState>,
    /// Base LoRA layers, fanned out to the q/k/v/o attention projections.
    pub base_lora_layers: Vec<LoRALayerState>,
}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// A TensorData literal should retain exactly the values it was built with.
    #[test]
    fn test_tensor_data_creation() {
        let t = TensorData {
            dtype: String::from("F32"),
            shape: vec![2, 2],
            data: vec![1.0, 2.0, 3.0, 4.0],
        };

        assert_eq!(4, t.data.len());
        assert_eq!(vec![2, 2], t.shape);
    }

    /// A LoRALayerState literal should retain its rank and weight lengths.
    #[test]
    fn test_lora_layer_state() {
        let layer = LoRALayerState {
            lora_a: vec![0.1, 0.2, 0.3, 0.4],
            lora_b: vec![0.5, 0.6, 0.7, 0.8],
            rank: 2,
            input_dim: 2,
            output_dim: 2,
        };

        assert_eq!(2, layer.rank);
        assert_eq!(4, layer.lora_a.len());
    }
}