1use super::{ExportConfig, ExportError, ExportResult, ExportType};
7use crate::engine::SonaEngine;
8use crate::types::LearnedPattern;
9use std::io::{BufWriter, Write};
10use std::path::Path;
11
12#[cfg(feature = "serde-support")]
13use serde::{Deserialize, Serialize};
14
15pub struct DatasetExporter<'a> {
17 config: &'a ExportConfig,
18}
19
20impl<'a> DatasetExporter<'a> {
21 pub fn new(config: &'a ExportConfig) -> Self {
23 Self { config }
24 }
25
26 pub fn export_patterns<P: AsRef<Path>>(
28 &self,
29 engine: &SonaEngine,
30 output_path: P,
31 ) -> Result<ExportResult, ExportError> {
32 let output_path = output_path.as_ref();
33
34 if let Some(parent) = output_path.parent() {
36 std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
37 }
38
39 let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
40 let mut writer = BufWriter::new(file);
41
42 let patterns = engine.get_all_patterns();
43 let mut items_exported = 0;
44
45 for pattern in patterns {
46 if pattern.avg_quality < self.config.min_quality_threshold {
48 continue;
49 }
50
51 let record = PatternRecord {
52 id: pattern.id.to_string(),
53 embedding: pattern.centroid.clone(),
54 cluster_size: pattern.cluster_size,
55 avg_quality: pattern.avg_quality,
56 pattern_type: pattern.pattern_type.to_string(),
57 access_count: pattern.access_count as u64,
58 metadata: PatternMetadata {
59 source: "sona".to_string(),
60 version: env!("CARGO_PKG_VERSION").to_string(),
61 target_model: self.config.target_architecture.clone(),
62 },
63 };
64
65 let json = serde_json::to_string(&record).map_err(ExportError::Serialization)?;
66 writeln!(writer, "{}", json).map_err(ExportError::Io)?;
67 items_exported += 1;
68 }
69
70 writer.flush().map_err(ExportError::Io)?;
71
72 let size_bytes = std::fs::metadata(output_path).map(|m| m.len()).unwrap_or(0);
73
74 Ok(ExportResult {
75 export_type: ExportType::PatternsDataset,
76 items_exported,
77 output_path: output_path.to_string_lossy().to_string(),
78 size_bytes,
79 })
80 }
81
82 pub fn export_preferences<P: AsRef<Path>>(
84 &self,
85 engine: &SonaEngine,
86 output_path: P,
87 ) -> Result<ExportResult, ExportError> {
88 let output_path = output_path.as_ref();
89
90 if let Some(parent) = output_path.parent() {
92 std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
93 }
94
95 let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
96 let mut writer = BufWriter::new(file);
97
98 let trajectories = engine.get_quality_trajectories();
99 let mut items_exported = 0;
100
101 let mut sorted_trajectories = trajectories.clone();
104 sorted_trajectories.sort_by(|a, b| {
105 b.quality
106 .partial_cmp(&a.quality)
107 .unwrap_or(std::cmp::Ordering::Equal)
108 });
109
110 let mid = sorted_trajectories.len() / 2;
111 let (high_quality, low_quality) = sorted_trajectories.split_at(mid);
112
113 for (chosen, rejected) in high_quality.iter().zip(low_quality.iter().rev()) {
114 if (chosen.quality - rejected.quality).abs() < 0.1 {
116 continue;
117 }
118
119 let pair = PreferencePair {
120 prompt: PreferencePrompt {
121 embedding: chosen.query_embedding.clone(),
122 context: chosen.context_ids.clone(),
123 },
124 chosen: PreferenceResponse {
125 route: chosen.route.clone(),
126 quality: chosen.quality,
127 embedding: chosen.response_embedding.clone(),
128 },
129 rejected: PreferenceResponse {
130 route: rejected.route.clone(),
131 quality: rejected.quality,
132 embedding: rejected.response_embedding.clone(),
133 },
134 metadata: PreferenceMetadata {
135 quality_delta: chosen.quality - rejected.quality,
136 source: "sona".to_string(),
137 version: env!("CARGO_PKG_VERSION").to_string(),
138 },
139 };
140
141 let json = serde_json::to_string(&pair).map_err(ExportError::Serialization)?;
142 writeln!(writer, "{}", json).map_err(ExportError::Io)?;
143 items_exported += 1;
144 }
145
146 writer.flush().map_err(ExportError::Io)?;
147
148 let size_bytes = std::fs::metadata(output_path).map(|m| m.len()).unwrap_or(0);
149
150 Ok(ExportResult {
151 export_type: ExportType::PreferencePairs,
152 items_exported,
153 output_path: output_path.to_string_lossy().to_string(),
154 size_bytes,
155 })
156 }
157
158 pub fn export_distillation_targets<P: AsRef<Path>>(
160 &self,
161 engine: &SonaEngine,
162 output_path: P,
163 ) -> Result<ExportResult, ExportError> {
164 let output_path = output_path.as_ref();
165
166 if let Some(parent) = output_path.parent() {
168 std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
169 }
170
171 let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
172 let mut writer = BufWriter::new(file);
173
174 let routing_decisions = engine.get_routing_decisions();
175 let mut items_exported = 0;
176
177 for decision in routing_decisions {
178 if decision.quality < self.config.min_quality_threshold {
180 continue;
181 }
182
183 let target = DistillationTarget {
184 input_embedding: decision.query_embedding.clone(),
185 teacher_logits: decision.routing_logits.clone(),
186 selected_route: decision.selected_route.clone(),
187 confidence: decision.confidence,
188 quality: decision.quality,
189 metadata: DistillationMetadata {
190 source: "sona".to_string(),
191 version: env!("CARGO_PKG_VERSION").to_string(),
192 temperature: 1.0,
193 },
194 };
195
196 let json = serde_json::to_string(&target).map_err(ExportError::Serialization)?;
197 writeln!(writer, "{}", json).map_err(ExportError::Io)?;
198 items_exported += 1;
199 }
200
201 writer.flush().map_err(ExportError::Io)?;
202
203 let size_bytes = std::fs::metadata(output_path).map(|m| m.len()).unwrap_or(0);
204
205 Ok(ExportResult {
206 export_type: ExportType::DistillationTargets,
207 items_exported,
208 output_path: output_path.to_string_lossy().to_string(),
209 size_bytes,
210 })
211 }
212}
213
214#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
216#[derive(Clone, Debug)]
217pub struct PatternRecord {
218 pub id: String,
220 pub embedding: Vec<f32>,
222 pub cluster_size: usize,
224 pub avg_quality: f32,
226 pub pattern_type: String,
228 pub access_count: u64,
230 pub metadata: PatternMetadata,
232}
233
/// Provenance attached to every [`PatternRecord`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PatternMetadata {
    /// Name of the producing system (the exporter writes `"sona"`).
    pub source: String,
    /// Version string of the producing crate.
    pub version: String,
    /// Model architecture the dataset is intended for.
    pub target_model: String,
}
245
246#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
248#[derive(Clone, Debug)]
249pub struct PreferencePair {
250 pub prompt: PreferencePrompt,
252 pub chosen: PreferenceResponse,
254 pub rejected: PreferenceResponse,
256 pub metadata: PreferenceMetadata,
258}
259
/// Prompt side of a [`PreferencePair`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PreferencePrompt {
    /// Embedding of the query.
    pub embedding: Vec<f32>,
    /// Identifiers of the context items associated with the query.
    pub context: Vec<String>,
}
269
/// One response of a [`PreferencePair`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PreferenceResponse {
    /// Route that produced this response.
    pub route: String,
    /// Quality score of this response.
    pub quality: f32,
    /// Embedding of the response.
    pub embedding: Vec<f32>,
}
281
/// Metadata attached to every [`PreferencePair`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PreferenceMetadata {
    /// Chosen quality minus rejected quality.
    pub quality_delta: f32,
    /// Name of the producing system (the exporter writes `"sona"`).
    pub source: String,
    /// Version string of the producing crate.
    pub version: String,
}
293
294#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
296#[derive(Clone, Debug)]
297pub struct DistillationTarget {
298 pub input_embedding: Vec<f32>,
300 pub teacher_logits: Vec<f32>,
302 pub selected_route: String,
304 pub confidence: f32,
306 pub quality: f32,
308 pub metadata: DistillationMetadata,
310}
311
/// Metadata attached to every [`DistillationTarget`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct DistillationMetadata {
    /// Name of the producing system (the exporter writes `"sona"`).
    pub source: String,
    /// Version string of the producing crate.
    pub version: String,
    /// Temperature associated with the teacher logits.
    pub temperature: f32,
}
323
/// A scored query/response interaction used as raw material for
/// preference pairs.
///
/// NOTE(review): the valid range of `quality` is not established here;
/// presumably a normalized score — confirm with the producer.
#[derive(Debug, Clone)]
pub struct QualityTrajectory {
    /// Embedding of the query.
    pub query_embedding: Vec<f32>,
    /// Embedding of the response.
    pub response_embedding: Vec<f32>,
    /// Route that served the query.
    pub route: String,
    /// Quality score of the interaction.
    pub quality: f32,
    /// Identifiers of the context items used for the query.
    pub context_ids: Vec<String>,
}
338
/// A recorded routing choice used as raw material for distillation targets.
///
/// NOTE(review): the meaning/scale of `confidence` and `quality` is not
/// established here — confirm with the producer.
#[derive(Debug, Clone)]
pub struct RoutingDecision {
    /// Embedding of the routed query.
    pub query_embedding: Vec<f32>,
    /// Raw routing logits.
    pub routing_logits: Vec<f32>,
    /// Route that was selected.
    pub selected_route: String,
    /// Confidence reported for the selection.
    pub confidence: f32,
    /// Quality score of the outcome.
    pub quality: f32,
}
353
#[cfg(test)]
mod tests {
    //! Serialization tests for the exported record types.
    //!
    //! Assertions inspect parsed JSON fields and round-trip back into the
    //! typed structs, instead of matching substrings of the raw JSON text
    //! (a substring such as `"0.9"` can match unrelated parts of the output).

    use super::*;

    #[test]
    fn test_pattern_record() {
        let record = PatternRecord {
            id: "test-pattern".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
            cluster_size: 10,
            avg_quality: 0.85,
            pattern_type: "routing".to_string(),
            access_count: 100,
            metadata: PatternMetadata {
                source: "sona".to_string(),
                version: "0.1.0".to_string(),
                target_model: "phi-4".to_string(),
            },
        };

        let json = serde_json::to_string(&record).unwrap();
        // Records are written as JSON Lines, so each encoding must be one line.
        assert!(!json.contains('\n'));

        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["id"], "test-pattern");
        assert_eq!(value["pattern_type"], "routing");
        assert_eq!(value["cluster_size"], 10);
        assert_eq!(value["access_count"], 100);
        assert_eq!(value["embedding"].as_array().map(Vec::len), Some(3));
        assert_eq!(value["metadata"]["source"], "sona");
        assert_eq!(value["metadata"]["target_model"], "phi-4");

        let back: PatternRecord = serde_json::from_str(&json).unwrap();
        assert_eq!(back.id, record.id);
        assert_eq!(back.embedding, record.embedding);
        assert_eq!(back.avg_quality, record.avg_quality);
        assert_eq!(back.metadata.version, record.metadata.version);
    }

    #[test]
    fn test_preference_pair() {
        let pair = PreferencePair {
            prompt: PreferencePrompt {
                embedding: vec![0.1, 0.2],
                context: vec!["ctx1".to_string()],
            },
            chosen: PreferenceResponse {
                route: "gpt-4".to_string(),
                quality: 0.9,
                embedding: vec![0.3, 0.4],
            },
            rejected: PreferenceResponse {
                route: "gpt-3.5".to_string(),
                quality: 0.6,
                embedding: vec![0.5, 0.6],
            },
            metadata: PreferenceMetadata {
                quality_delta: 0.3,
                source: "sona".to_string(),
                version: "0.1.0".to_string(),
            },
        };

        let json = serde_json::to_string(&pair).unwrap();
        assert!(!json.contains('\n'));

        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["chosen"]["route"], "gpt-4");
        assert_eq!(value["rejected"]["route"], "gpt-3.5");
        assert_eq!(value["prompt"]["context"][0], "ctx1");
        assert_eq!(value["metadata"]["source"], "sona");

        let back: PreferencePair = serde_json::from_str(&json).unwrap();
        assert_eq!(back.chosen.quality, 0.9);
        assert_eq!(back.rejected.quality, 0.6);
        assert_eq!(back.prompt.embedding, pair.prompt.embedding);
        assert_eq!(back.metadata.quality_delta, 0.3);
    }

    #[test]
    fn test_distillation_target() {
        let target = DistillationTarget {
            input_embedding: vec![0.1, 0.2],
            teacher_logits: vec![1.5, -0.5],
            selected_route: "gpt-4".to_string(),
            confidence: 0.75,
            quality: 0.8,
            metadata: DistillationMetadata {
                source: "sona".to_string(),
                version: "0.1.0".to_string(),
                temperature: 1.0,
            },
        };

        let json = serde_json::to_string(&target).unwrap();
        assert!(!json.contains('\n'));

        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["selected_route"], "gpt-4");
        assert_eq!(value["teacher_logits"].as_array().map(Vec::len), Some(2));
        assert_eq!(value["metadata"]["source"], "sona");

        let back: DistillationTarget = serde_json::from_str(&json).unwrap();
        assert_eq!(back.input_embedding, target.input_embedding);
        assert_eq!(back.teacher_logits, target.teacher_logits);
        assert_eq!(back.confidence, 0.75);
        assert_eq!(back.quality, 0.8);
        assert_eq!(back.metadata.temperature, 1.0);
    }
}