1use super::{ExportConfig, ExportError, ExportResult, ExportType};
7use crate::engine::SonaEngine;
8use std::io::{BufWriter, Write};
9use std::path::Path;
10
11#[cfg(feature = "serde-support")]
12use serde::{Deserialize, Serialize};
13
14pub struct DatasetExporter<'a> {
16 config: &'a ExportConfig,
17}
18
19impl<'a> DatasetExporter<'a> {
20 pub fn new(config: &'a ExportConfig) -> Self {
22 Self { config }
23 }
24
25 pub fn export_patterns<P: AsRef<Path>>(
27 &self,
28 engine: &SonaEngine,
29 output_path: P,
30 ) -> Result<ExportResult, ExportError> {
31 let output_path = output_path.as_ref();
32
33 if let Some(parent) = output_path.parent() {
35 std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
36 }
37
38 let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
39 let mut writer = BufWriter::new(file);
40
41 let patterns = engine.get_all_patterns();
42 let mut items_exported = 0;
43
44 for pattern in patterns {
45 if pattern.avg_quality < self.config.min_quality_threshold {
47 continue;
48 }
49
50 let record = PatternRecord {
51 id: pattern.id.to_string(),
52 embedding: pattern.centroid.clone(),
53 cluster_size: pattern.cluster_size,
54 avg_quality: pattern.avg_quality,
55 pattern_type: pattern.pattern_type.to_string(),
56 access_count: pattern.access_count as u64,
57 metadata: PatternMetadata {
58 source: "sona".to_string(),
59 version: env!("CARGO_PKG_VERSION").to_string(),
60 target_model: self.config.target_architecture.clone(),
61 },
62 };
63
64 let json = serde_json::to_string(&record).map_err(ExportError::Serialization)?;
65 writeln!(writer, "{}", json).map_err(ExportError::Io)?;
66 items_exported += 1;
67 }
68
69 writer.flush().map_err(ExportError::Io)?;
70
71 let size_bytes = std::fs::metadata(output_path).map(|m| m.len()).unwrap_or(0);
72
73 Ok(ExportResult {
74 export_type: ExportType::PatternsDataset,
75 items_exported,
76 output_path: output_path.to_string_lossy().to_string(),
77 size_bytes,
78 })
79 }
80
81 pub fn export_preferences<P: AsRef<Path>>(
83 &self,
84 engine: &SonaEngine,
85 output_path: P,
86 ) -> Result<ExportResult, ExportError> {
87 let output_path = output_path.as_ref();
88
89 if let Some(parent) = output_path.parent() {
91 std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
92 }
93
94 let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
95 let mut writer = BufWriter::new(file);
96
97 let trajectories = engine.get_quality_trajectories();
98 let mut items_exported = 0;
99
100 let mut sorted_trajectories = trajectories.clone();
103 sorted_trajectories.sort_by(|a, b| {
104 b.quality
105 .partial_cmp(&a.quality)
106 .unwrap_or(std::cmp::Ordering::Equal)
107 });
108
109 let mid = sorted_trajectories.len() / 2;
110 let (high_quality, low_quality) = sorted_trajectories.split_at(mid);
111
112 for (chosen, rejected) in high_quality.iter().zip(low_quality.iter().rev()) {
113 if (chosen.quality - rejected.quality).abs() < 0.1 {
115 continue;
116 }
117
118 let pair = PreferencePair {
119 prompt: PreferencePrompt {
120 embedding: chosen.query_embedding.clone(),
121 context: chosen.context_ids.clone(),
122 },
123 chosen: PreferenceResponse {
124 route: chosen.route.clone(),
125 quality: chosen.quality,
126 embedding: chosen.response_embedding.clone(),
127 },
128 rejected: PreferenceResponse {
129 route: rejected.route.clone(),
130 quality: rejected.quality,
131 embedding: rejected.response_embedding.clone(),
132 },
133 metadata: PreferenceMetadata {
134 quality_delta: chosen.quality - rejected.quality,
135 source: "sona".to_string(),
136 version: env!("CARGO_PKG_VERSION").to_string(),
137 },
138 };
139
140 let json = serde_json::to_string(&pair).map_err(ExportError::Serialization)?;
141 writeln!(writer, "{}", json).map_err(ExportError::Io)?;
142 items_exported += 1;
143 }
144
145 writer.flush().map_err(ExportError::Io)?;
146
147 let size_bytes = std::fs::metadata(output_path).map(|m| m.len()).unwrap_or(0);
148
149 Ok(ExportResult {
150 export_type: ExportType::PreferencePairs,
151 items_exported,
152 output_path: output_path.to_string_lossy().to_string(),
153 size_bytes,
154 })
155 }
156
157 pub fn export_distillation_targets<P: AsRef<Path>>(
159 &self,
160 engine: &SonaEngine,
161 output_path: P,
162 ) -> Result<ExportResult, ExportError> {
163 let output_path = output_path.as_ref();
164
165 if let Some(parent) = output_path.parent() {
167 std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
168 }
169
170 let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
171 let mut writer = BufWriter::new(file);
172
173 let routing_decisions = engine.get_routing_decisions();
174 let mut items_exported = 0;
175
176 for decision in routing_decisions {
177 if decision.quality < self.config.min_quality_threshold {
179 continue;
180 }
181
182 let target = DistillationTarget {
183 input_embedding: decision.query_embedding.clone(),
184 teacher_logits: decision.routing_logits.clone(),
185 selected_route: decision.selected_route.clone(),
186 confidence: decision.confidence,
187 quality: decision.quality,
188 metadata: DistillationMetadata {
189 source: "sona".to_string(),
190 version: env!("CARGO_PKG_VERSION").to_string(),
191 temperature: 1.0,
192 },
193 };
194
195 let json = serde_json::to_string(&target).map_err(ExportError::Serialization)?;
196 writeln!(writer, "{}", json).map_err(ExportError::Io)?;
197 items_exported += 1;
198 }
199
200 writer.flush().map_err(ExportError::Io)?;
201
202 let size_bytes = std::fs::metadata(output_path).map(|m| m.len()).unwrap_or(0);
203
204 Ok(ExportResult {
205 export_type: ExportType::DistillationTargets,
206 items_exported,
207 output_path: output_path.to_string_lossy().to_string(),
208 size_bytes,
209 })
210 }
211}
212
213#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
215#[derive(Clone, Debug)]
216pub struct PatternRecord {
217 pub id: String,
219 pub embedding: Vec<f32>,
221 pub cluster_size: usize,
223 pub avg_quality: f32,
225 pub pattern_type: String,
227 pub access_count: u64,
229 pub metadata: PatternMetadata,
231}
232
/// Provenance information attached to every [`PatternRecord`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PatternMetadata {
    /// Name of the producing system; the exporter writes `"sona"`.
    pub source: String,
    /// Version string; the exporter writes the crate version it was built from.
    pub version: String,
    /// Target model; the exporter copies the configured target architecture.
    pub target_model: String,
}
244
245#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
247#[derive(Clone, Debug)]
248pub struct PreferencePair {
249 pub prompt: PreferencePrompt,
251 pub chosen: PreferenceResponse,
253 pub rejected: PreferenceResponse,
255 pub metadata: PreferenceMetadata,
257}
258
/// Prompt side of a [`PreferencePair`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PreferencePrompt {
    /// Query embedding; the exporter copies it from the chosen trajectory.
    pub embedding: Vec<f32>,
    /// Context identifiers; the exporter copies them from the chosen trajectory.
    pub context: Vec<String>,
}
268
/// One response (chosen or rejected) inside a [`PreferencePair`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PreferenceResponse {
    /// Route recorded on the source trajectory.
    pub route: String,
    /// Quality recorded on the source trajectory.
    pub quality: f32,
    /// Response embedding of the source trajectory.
    pub embedding: Vec<f32>,
}
280
/// Quality gap and provenance information attached to every
/// [`PreferencePair`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PreferenceMetadata {
    /// Chosen quality minus rejected quality.
    pub quality_delta: f32,
    /// Name of the producing system; the exporter writes `"sona"`.
    pub source: String,
    /// Version string; the exporter writes the crate version it was built from.
    pub version: String,
}
292
293#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
295#[derive(Clone, Debug)]
296pub struct DistillationTarget {
297 pub input_embedding: Vec<f32>,
299 pub teacher_logits: Vec<f32>,
301 pub selected_route: String,
303 pub confidence: f32,
305 pub quality: f32,
307 pub metadata: DistillationMetadata,
309}
310
/// Provenance information attached to every [`DistillationTarget`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct DistillationMetadata {
    /// Name of the producing system; the exporter writes `"sona"`.
    pub source: String,
    /// Version string; the exporter writes the crate version it was built from.
    pub version: String,
    /// Temperature value; the exporter always writes `1.0`.
    // NOTE(review): presumably the softmax temperature that applies to
    // `teacher_logits` — confirm with the consumer of this dataset.
    pub temperature: f32,
}
322
/// A query/response observation with an associated quality score.
///
/// The field names match what `DatasetExporter::export_preferences` reads
/// when it builds preference pairs.
// NOTE(review): presumably the item type returned by
// `SonaEngine::get_quality_trajectories` — confirm against the engine.
#[derive(Debug, Clone)]
pub struct QualityTrajectory {
    /// Embedding of the query.
    pub query_embedding: Vec<f32>,
    /// Embedding of the response.
    pub response_embedding: Vec<f32>,
    /// Route associated with this trajectory.
    pub route: String,
    /// Quality score used to rank trajectories.
    pub quality: f32,
    /// Identifiers of the context associated with the query.
    pub context_ids: Vec<String>,
}
337
/// A routing decision together with the scores that produced it.
///
/// The field names match what
/// `DatasetExporter::export_distillation_targets` reads when it builds
/// distillation targets.
// NOTE(review): presumably the item type returned by
// `SonaEngine::get_routing_decisions` — confirm against the engine.
#[derive(Debug, Clone)]
pub struct RoutingDecision {
    /// Embedding of the query that was routed.
    pub query_embedding: Vec<f32>,
    /// Logits recorded for the routing decision.
    pub routing_logits: Vec<f32>,
    /// Route that was selected.
    pub selected_route: String,
    /// Confidence recorded for the decision.
    pub confidence: f32,
    /// Quality score recorded for the decision.
    pub quality: f32,
}
352
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the pattern record fixture used by `test_pattern_record`.
    fn sample_pattern_record() -> PatternRecord {
        let metadata = PatternMetadata {
            source: "sona".to_string(),
            version: "0.1.0".to_string(),
            target_model: "phi-4".to_string(),
        };

        PatternRecord {
            id: "test-pattern".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
            cluster_size: 10,
            avg_quality: 0.85,
            pattern_type: "routing".to_string(),
            access_count: 100,
            metadata,
        }
    }

    /// Builds the preference pair fixture used by `test_preference_pair`.
    fn sample_preference_pair() -> PreferencePair {
        let prompt = PreferencePrompt {
            embedding: vec![0.1, 0.2],
            context: vec!["ctx1".to_string()],
        };
        let chosen = PreferenceResponse {
            route: "gpt-4".to_string(),
            quality: 0.9,
            embedding: vec![0.3, 0.4],
        };
        let rejected = PreferenceResponse {
            route: "gpt-3.5".to_string(),
            quality: 0.6,
            embedding: vec![0.5, 0.6],
        };
        let metadata = PreferenceMetadata {
            quality_delta: 0.3,
            source: "sona".to_string(),
            version: "0.1.0".to_string(),
        };

        PreferencePair {
            prompt,
            chosen,
            rejected,
            metadata,
        }
    }

    /// A serialized pattern record carries its id and average quality.
    #[test]
    fn test_pattern_record() {
        let json = serde_json::to_string(&sample_pattern_record()).unwrap();

        assert!(json.contains("test-pattern"));
        assert!(json.contains("0.85"));
    }

    /// A serialized preference pair carries the chosen route and its quality.
    #[test]
    fn test_preference_pair() {
        let json = serde_json::to_string(&sample_preference_pair()).unwrap();

        assert!(json.contains("gpt-4"));
        assert!(json.contains("0.9"));
    }
}