1use crate::engine::SonaEngine;
7use crate::types::LearnedPattern;
8use super::{ExportConfig, ExportResult, ExportType, ExportError};
9use std::path::Path;
10use std::io::{BufWriter, Write};
11
12#[cfg(feature = "serde-support")]
13use serde::{Deserialize, Serialize};
14
15pub struct DatasetExporter<'a> {
17 config: &'a ExportConfig,
18}
19
20impl<'a> DatasetExporter<'a> {
21 pub fn new(config: &'a ExportConfig) -> Self {
23 Self { config }
24 }
25
26 pub fn export_patterns<P: AsRef<Path>>(
28 &self,
29 engine: &SonaEngine,
30 output_path: P,
31 ) -> Result<ExportResult, ExportError> {
32 let output_path = output_path.as_ref();
33
34 if let Some(parent) = output_path.parent() {
36 std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
37 }
38
39 let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
40 let mut writer = BufWriter::new(file);
41
42 let patterns = engine.get_all_patterns();
43 let mut items_exported = 0;
44
45 for pattern in patterns {
46 if pattern.avg_quality < self.config.min_quality_threshold {
48 continue;
49 }
50
51 let record = PatternRecord {
52 id: pattern.id.to_string(),
53 embedding: pattern.centroid.clone(),
54 cluster_size: pattern.cluster_size,
55 avg_quality: pattern.avg_quality,
56 pattern_type: pattern.pattern_type.to_string(),
57 access_count: pattern.access_count as u64,
58 metadata: PatternMetadata {
59 source: "sona".to_string(),
60 version: env!("CARGO_PKG_VERSION").to_string(),
61 target_model: self.config.target_architecture.clone(),
62 },
63 };
64
65 let json = serde_json::to_string(&record).map_err(ExportError::Serialization)?;
66 writeln!(writer, "{}", json).map_err(ExportError::Io)?;
67 items_exported += 1;
68 }
69
70 writer.flush().map_err(ExportError::Io)?;
71
72 let size_bytes = std::fs::metadata(output_path)
73 .map(|m| m.len())
74 .unwrap_or(0);
75
76 Ok(ExportResult {
77 export_type: ExportType::PatternsDataset,
78 items_exported,
79 output_path: output_path.to_string_lossy().to_string(),
80 size_bytes,
81 })
82 }
83
84 pub fn export_preferences<P: AsRef<Path>>(
86 &self,
87 engine: &SonaEngine,
88 output_path: P,
89 ) -> Result<ExportResult, ExportError> {
90 let output_path = output_path.as_ref();
91
92 if let Some(parent) = output_path.parent() {
94 std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
95 }
96
97 let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
98 let mut writer = BufWriter::new(file);
99
100 let trajectories = engine.get_quality_trajectories();
101 let mut items_exported = 0;
102
103 let mut sorted_trajectories = trajectories.clone();
106 sorted_trajectories.sort_by(|a, b| {
107 b.quality.partial_cmp(&a.quality).unwrap_or(std::cmp::Ordering::Equal)
108 });
109
110 let mid = sorted_trajectories.len() / 2;
111 let (high_quality, low_quality) = sorted_trajectories.split_at(mid);
112
113 for (chosen, rejected) in high_quality.iter().zip(low_quality.iter().rev()) {
114 if (chosen.quality - rejected.quality).abs() < 0.1 {
116 continue;
117 }
118
119 let pair = PreferencePair {
120 prompt: PreferencePrompt {
121 embedding: chosen.query_embedding.clone(),
122 context: chosen.context_ids.clone(),
123 },
124 chosen: PreferenceResponse {
125 route: chosen.route.clone(),
126 quality: chosen.quality,
127 embedding: chosen.response_embedding.clone(),
128 },
129 rejected: PreferenceResponse {
130 route: rejected.route.clone(),
131 quality: rejected.quality,
132 embedding: rejected.response_embedding.clone(),
133 },
134 metadata: PreferenceMetadata {
135 quality_delta: chosen.quality - rejected.quality,
136 source: "sona".to_string(),
137 version: env!("CARGO_PKG_VERSION").to_string(),
138 },
139 };
140
141 let json = serde_json::to_string(&pair).map_err(ExportError::Serialization)?;
142 writeln!(writer, "{}", json).map_err(ExportError::Io)?;
143 items_exported += 1;
144 }
145
146 writer.flush().map_err(ExportError::Io)?;
147
148 let size_bytes = std::fs::metadata(output_path)
149 .map(|m| m.len())
150 .unwrap_or(0);
151
152 Ok(ExportResult {
153 export_type: ExportType::PreferencePairs,
154 items_exported,
155 output_path: output_path.to_string_lossy().to_string(),
156 size_bytes,
157 })
158 }
159
160 pub fn export_distillation_targets<P: AsRef<Path>>(
162 &self,
163 engine: &SonaEngine,
164 output_path: P,
165 ) -> Result<ExportResult, ExportError> {
166 let output_path = output_path.as_ref();
167
168 if let Some(parent) = output_path.parent() {
170 std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
171 }
172
173 let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
174 let mut writer = BufWriter::new(file);
175
176 let routing_decisions = engine.get_routing_decisions();
177 let mut items_exported = 0;
178
179 for decision in routing_decisions {
180 if decision.quality < self.config.min_quality_threshold {
182 continue;
183 }
184
185 let target = DistillationTarget {
186 input_embedding: decision.query_embedding.clone(),
187 teacher_logits: decision.routing_logits.clone(),
188 selected_route: decision.selected_route.clone(),
189 confidence: decision.confidence,
190 quality: decision.quality,
191 metadata: DistillationMetadata {
192 source: "sona".to_string(),
193 version: env!("CARGO_PKG_VERSION").to_string(),
194 temperature: 1.0,
195 },
196 };
197
198 let json = serde_json::to_string(&target).map_err(ExportError::Serialization)?;
199 writeln!(writer, "{}", json).map_err(ExportError::Io)?;
200 items_exported += 1;
201 }
202
203 writer.flush().map_err(ExportError::Io)?;
204
205 let size_bytes = std::fs::metadata(output_path)
206 .map(|m| m.len())
207 .unwrap_or(0);
208
209 Ok(ExportResult {
210 export_type: ExportType::DistillationTargets,
211 items_exported,
212 output_path: output_path.to_string_lossy().to_string(),
213 size_bytes,
214 })
215 }
216}
217
/// One line of the patterns dataset written by [`DatasetExporter::export_patterns`].
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct PatternRecord {
    /// Identifier of the learned pattern, in string form.
    pub id: String,
    /// Centroid vector of the learned pattern.
    pub embedding: Vec<f32>,
    /// Cluster size, copied from the learned pattern.
    pub cluster_size: usize,
    /// Average quality of the learned pattern.
    pub avg_quality: f32,
    /// Pattern type rendered through its `to_string` form.
    pub pattern_type: String,
    /// Access count, copied from the learned pattern.
    pub access_count: u64,
    /// Provenance of this record.
    pub metadata: PatternMetadata,
}

/// Provenance attached to every [`PatternRecord`].
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct PatternMetadata {
    /// Producing system; the exporter writes `"sona"`.
    pub source: String,
    /// Crate version (`CARGO_PKG_VERSION`) that produced the export.
    pub version: String,
    /// Target model architecture, taken from the export configuration.
    pub target_model: String,
}
249
/// One line of the preference dataset written by
/// [`DatasetExporter::export_preferences`]: a prompt with a higher-quality
/// (`chosen`) and a lower-quality (`rejected`) response.
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct PreferencePair {
    /// Prompt the pair is attributed to.
    pub prompt: PreferencePrompt,
    /// The higher-quality response of the pair.
    pub chosen: PreferenceResponse,
    /// The lower-quality response of the pair.
    pub rejected: PreferenceResponse,
    /// Provenance and quality gap of the pair.
    pub metadata: PreferenceMetadata,
}

/// Prompt side of a [`PreferencePair`].
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct PreferencePrompt {
    /// Query embedding of the prompt.
    pub embedding: Vec<f32>,
    /// Context identifiers associated with the prompt.
    pub context: Vec<String>,
}

/// One response (chosen or rejected) in a [`PreferencePair`].
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct PreferenceResponse {
    /// Route that produced the response.
    pub route: String,
    /// Quality score of the response.
    pub quality: f32,
    /// Embedding of the response.
    pub embedding: Vec<f32>,
}

/// Provenance attached to every [`PreferencePair`].
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct PreferenceMetadata {
    /// `chosen.quality - rejected.quality` as computed by the exporter.
    pub quality_delta: f32,
    /// Producing system; the exporter writes `"sona"`.
    pub source: String,
    /// Crate version (`CARGO_PKG_VERSION`) that produced the export.
    pub version: String,
}
297
/// One line of the distillation dataset written by
/// [`DatasetExporter::export_distillation_targets`].
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct DistillationTarget {
    /// Embedding of the routed query.
    pub input_embedding: Vec<f32>,
    /// Routing logits recorded for the decision, used as the teacher signal.
    pub teacher_logits: Vec<f32>,
    /// Route the teacher selected.
    pub selected_route: String,
    /// Confidence recorded for the decision.
    pub confidence: f32,
    /// Quality score recorded for the decision.
    pub quality: f32,
    /// Provenance and temperature of this target.
    pub metadata: DistillationMetadata,
}

/// Provenance attached to every [`DistillationTarget`].
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct DistillationMetadata {
    /// Producing system; the exporter writes `"sona"`.
    pub source: String,
    /// Crate version (`CARGO_PKG_VERSION`) that produced the export.
    pub version: String,
    /// Temperature declared for `teacher_logits`; the exporter always writes `1.0`.
    /// NOTE(review): presumably the softmax temperature to apply to the logits; confirm.
    pub temperature: f32,
}
327
/// A scored query/response trajectory, the input consumed by
/// [`DatasetExporter::export_preferences`].
///
/// NOTE(review): presumably the element type yielded by
/// `SonaEngine::get_quality_trajectories`; the exporter only relies on these
/// field names. Confirm.
#[derive(Clone, Debug, PartialEq)]
pub struct QualityTrajectory {
    /// Embedding of the query.
    pub query_embedding: Vec<f32>,
    /// Embedding of the response.
    pub response_embedding: Vec<f32>,
    /// Route that served the query.
    pub route: String,
    /// Quality score used to rank trajectories.
    pub quality: f32,
    /// Context identifiers associated with the query.
    pub context_ids: Vec<String>,
}
342
/// A recorded routing decision, the input consumed by
/// [`DatasetExporter::export_distillation_targets`].
///
/// NOTE(review): presumably the element type yielded by
/// `SonaEngine::get_routing_decisions`; the exporter only relies on these
/// field names. Confirm.
#[derive(Clone, Debug, PartialEq)]
pub struct RoutingDecision {
    /// Embedding of the routed query.
    pub query_embedding: Vec<f32>,
    /// Router output logits; exported as teacher logits.
    pub routing_logits: Vec<f32>,
    /// Route that was selected.
    pub selected_route: String,
    /// Confidence recorded for the decision.
    pub confidence: f32,
    /// Quality score, compared against the export quality threshold.
    pub quality: f32,
}
357
#[cfg(test)]
mod tests {
    use super::*;

    // Assertions go through a parsed `serde_json::Value` rather than substring
    // matches on the raw JSON. That way a value in the wrong field, or a "0.9"
    // that is really "0.91", cannot produce a false pass.

    /// Reads a JSON number at `f32` precision, matching the Rust field type,
    /// so checks do not depend on how many digits the serializer emits.
    fn as_f32(value: &serde_json::Value) -> Option<f32> {
        value.as_f64().map(|v| v as f32)
    }

    #[test]
    fn test_pattern_record() {
        let record = PatternRecord {
            id: "test-pattern".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
            cluster_size: 10,
            avg_quality: 0.85,
            pattern_type: "routing".to_string(),
            access_count: 100,
            metadata: PatternMetadata {
                source: "sona".to_string(),
                version: "0.1.0".to_string(),
                target_model: "phi-4".to_string(),
            },
        };

        let json = serde_json::to_string(&record).unwrap();
        assert!(!json.contains('\n'), "a JSONL record must fit on one line");

        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["id"].as_str(), Some("test-pattern"));
        assert_eq!(value["embedding"].as_array().map(Vec::len), Some(3));
        assert_eq!(value["cluster_size"].as_u64(), Some(10));
        assert_eq!(as_f32(&value["avg_quality"]), Some(0.85));
        assert_eq!(value["pattern_type"].as_str(), Some("routing"));
        assert_eq!(value["access_count"].as_u64(), Some(100));
        assert_eq!(value["metadata"]["source"].as_str(), Some("sona"));
        assert_eq!(value["metadata"]["target_model"].as_str(), Some("phi-4"));
    }

    #[test]
    fn test_preference_pair() {
        let pair = PreferencePair {
            prompt: PreferencePrompt {
                embedding: vec![0.1, 0.2],
                context: vec!["ctx1".to_string()],
            },
            chosen: PreferenceResponse {
                route: "gpt-4".to_string(),
                quality: 0.9,
                embedding: vec![0.3, 0.4],
            },
            rejected: PreferenceResponse {
                route: "gpt-3.5".to_string(),
                quality: 0.6,
                embedding: vec![0.5, 0.6],
            },
            metadata: PreferenceMetadata {
                quality_delta: 0.3,
                source: "sona".to_string(),
                version: "0.1.0".to_string(),
            },
        };

        let json = serde_json::to_string(&pair).unwrap();
        assert!(!json.contains('\n'), "a JSONL record must fit on one line");

        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["prompt"]["context"][0].as_str(), Some("ctx1"));
        assert_eq!(value["chosen"]["route"].as_str(), Some("gpt-4"));
        assert_eq!(as_f32(&value["chosen"]["quality"]), Some(0.9));
        assert_eq!(value["rejected"]["route"].as_str(), Some("gpt-3.5"));
        assert_eq!(as_f32(&value["rejected"]["quality"]), Some(0.6));
        assert_eq!(as_f32(&value["metadata"]["quality_delta"]), Some(0.3));
    }

    #[test]
    fn test_distillation_target() {
        let target = DistillationTarget {
            input_embedding: vec![0.1, 0.2],
            teacher_logits: vec![1.5, -0.5],
            selected_route: "fast".to_string(),
            confidence: 0.8,
            quality: 0.7,
            metadata: DistillationMetadata {
                source: "sona".to_string(),
                version: "0.1.0".to_string(),
                temperature: 1.0,
            },
        };

        let json = serde_json::to_string(&target).unwrap();
        assert!(!json.contains('\n'), "a JSONL record must fit on one line");

        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["teacher_logits"].as_array().map(Vec::len), Some(2));
        assert_eq!(value["selected_route"].as_str(), Some("fast"));
        assert_eq!(as_f32(&value["confidence"]), Some(0.8));
        assert_eq!(as_f32(&value["quality"]), Some(0.7));
        assert_eq!(as_f32(&value["metadata"]["temperature"]), Some(1.0));
    }
}
411}