// ruvector_sona/export/dataset.rs

//! Dataset Export - HuggingFace-compatible dataset formats
//!
//! Exports SONA's learned patterns, preference pairs, and distillation
//! targets as JSONL datasets compatible with the HuggingFace `datasets` library.

use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;

#[cfg(feature = "serde-support")]
use serde::{Deserialize, Serialize};

use super::{ExportConfig, ExportError, ExportResult, ExportType};
use crate::engine::SonaEngine;
use crate::types::LearnedPattern;

15/// Dataset exporter for patterns and preferences
16pub struct DatasetExporter<'a> {
17    config: &'a ExportConfig,
18}
19
20impl<'a> DatasetExporter<'a> {
21    /// Create new dataset exporter
22    pub fn new(config: &'a ExportConfig) -> Self {
23        Self { config }
24    }
25
26    /// Export learned patterns as JSONL dataset
27    pub fn export_patterns<P: AsRef<Path>>(
28        &self,
29        engine: &SonaEngine,
30        output_path: P,
31    ) -> Result<ExportResult, ExportError> {
32        let output_path = output_path.as_ref();
33
34        // Ensure parent directory exists
35        if let Some(parent) = output_path.parent() {
36            std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
37        }
38
39        let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
40        let mut writer = BufWriter::new(file);
41
42        let patterns = engine.get_all_patterns();
43        let mut items_exported = 0;
44
45        for pattern in patterns {
46            // Filter by quality threshold
47            if pattern.avg_quality < self.config.min_quality_threshold {
48                continue;
49            }
50
51            let record = PatternRecord {
52                id: pattern.id.to_string(),
53                embedding: pattern.centroid.clone(),
54                cluster_size: pattern.cluster_size,
55                avg_quality: pattern.avg_quality,
56                pattern_type: pattern.pattern_type.to_string(),
57                access_count: pattern.access_count as u64,
58                metadata: PatternMetadata {
59                    source: "sona".to_string(),
60                    version: env!("CARGO_PKG_VERSION").to_string(),
61                    target_model: self.config.target_architecture.clone(),
62                },
63            };
64
65            let json = serde_json::to_string(&record).map_err(ExportError::Serialization)?;
66            writeln!(writer, "{}", json).map_err(ExportError::Io)?;
67            items_exported += 1;
68        }
69
70        writer.flush().map_err(ExportError::Io)?;
71
72        let size_bytes = std::fs::metadata(output_path)
73            .map(|m| m.len())
74            .unwrap_or(0);
75
76        Ok(ExportResult {
77            export_type: ExportType::PatternsDataset,
78            items_exported,
79            output_path: output_path.to_string_lossy().to_string(),
80            size_bytes,
81        })
82    }
83
84    /// Export preference pairs for DPO/RLHF training
85    pub fn export_preferences<P: AsRef<Path>>(
86        &self,
87        engine: &SonaEngine,
88        output_path: P,
89    ) -> Result<ExportResult, ExportError> {
90        let output_path = output_path.as_ref();
91
92        // Ensure parent directory exists
93        if let Some(parent) = output_path.parent() {
94            std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
95        }
96
97        let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
98        let mut writer = BufWriter::new(file);
99
100        let trajectories = engine.get_quality_trajectories();
101        let mut items_exported = 0;
102
103        // Generate preference pairs from trajectories
104        // Sort by quality and pair high-quality with low-quality
105        let mut sorted_trajectories = trajectories.clone();
106        sorted_trajectories.sort_by(|a, b| {
107            b.quality.partial_cmp(&a.quality).unwrap_or(std::cmp::Ordering::Equal)
108        });
109
110        let mid = sorted_trajectories.len() / 2;
111        let (high_quality, low_quality) = sorted_trajectories.split_at(mid);
112
113        for (chosen, rejected) in high_quality.iter().zip(low_quality.iter().rev()) {
114            // Skip if quality difference is too small
115            if (chosen.quality - rejected.quality).abs() < 0.1 {
116                continue;
117            }
118
119            let pair = PreferencePair {
120                prompt: PreferencePrompt {
121                    embedding: chosen.query_embedding.clone(),
122                    context: chosen.context_ids.clone(),
123                },
124                chosen: PreferenceResponse {
125                    route: chosen.route.clone(),
126                    quality: chosen.quality,
127                    embedding: chosen.response_embedding.clone(),
128                },
129                rejected: PreferenceResponse {
130                    route: rejected.route.clone(),
131                    quality: rejected.quality,
132                    embedding: rejected.response_embedding.clone(),
133                },
134                metadata: PreferenceMetadata {
135                    quality_delta: chosen.quality - rejected.quality,
136                    source: "sona".to_string(),
137                    version: env!("CARGO_PKG_VERSION").to_string(),
138                },
139            };
140
141            let json = serde_json::to_string(&pair).map_err(ExportError::Serialization)?;
142            writeln!(writer, "{}", json).map_err(ExportError::Io)?;
143            items_exported += 1;
144        }
145
146        writer.flush().map_err(ExportError::Io)?;
147
148        let size_bytes = std::fs::metadata(output_path)
149            .map(|m| m.len())
150            .unwrap_or(0);
151
152        Ok(ExportResult {
153            export_type: ExportType::PreferencePairs,
154            items_exported,
155            output_path: output_path.to_string_lossy().to_string(),
156            size_bytes,
157        })
158    }
159
160    /// Export distillation targets for knowledge distillation
161    pub fn export_distillation_targets<P: AsRef<Path>>(
162        &self,
163        engine: &SonaEngine,
164        output_path: P,
165    ) -> Result<ExportResult, ExportError> {
166        let output_path = output_path.as_ref();
167
168        // Ensure parent directory exists
169        if let Some(parent) = output_path.parent() {
170            std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
171        }
172
173        let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
174        let mut writer = BufWriter::new(file);
175
176        let routing_decisions = engine.get_routing_decisions();
177        let mut items_exported = 0;
178
179        for decision in routing_decisions {
180            // Filter by quality
181            if decision.quality < self.config.min_quality_threshold {
182                continue;
183            }
184
185            let target = DistillationTarget {
186                input_embedding: decision.query_embedding.clone(),
187                teacher_logits: decision.routing_logits.clone(),
188                selected_route: decision.selected_route.clone(),
189                confidence: decision.confidence,
190                quality: decision.quality,
191                metadata: DistillationMetadata {
192                    source: "sona".to_string(),
193                    version: env!("CARGO_PKG_VERSION").to_string(),
194                    temperature: 1.0,
195                },
196            };
197
198            let json = serde_json::to_string(&target).map_err(ExportError::Serialization)?;
199            writeln!(writer, "{}", json).map_err(ExportError::Io)?;
200            items_exported += 1;
201        }
202
203        writer.flush().map_err(ExportError::Io)?;
204
205        let size_bytes = std::fs::metadata(output_path)
206            .map(|m| m.len())
207            .unwrap_or(0);
208
209        Ok(ExportResult {
210            export_type: ExportType::DistillationTargets,
211            items_exported,
212            output_path: output_path.to_string_lossy().to_string(),
213            size_bytes,
214        })
215    }
216}
217
218/// Pattern record for JSONL export
219#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
220#[derive(Clone, Debug)]
221pub struct PatternRecord {
222    /// Pattern ID
223    pub id: String,
224    /// Embedding vector
225    pub embedding: Vec<f32>,
226    /// Number of trajectories in cluster
227    pub cluster_size: usize,
228    /// Average quality score
229    pub avg_quality: f32,
230    /// Pattern type (routing, reasoning, etc.)
231    pub pattern_type: String,
232    /// Access count
233    pub access_count: u64,
234    /// Export metadata
235    pub metadata: PatternMetadata,
236}
237
/// Provenance stamped onto every exported pattern record.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PatternMetadata {
    /// Name of the producing system.
    pub source: String,
    /// Version string of the producer.
    pub version: String,
    /// Model architecture the dataset is intended for.
    pub target_model: String,
}

250/// Preference pair for DPO/RLHF
251#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
252#[derive(Clone, Debug)]
253pub struct PreferencePair {
254    /// Input prompt
255    pub prompt: PreferencePrompt,
256    /// Chosen (preferred) response
257    pub chosen: PreferenceResponse,
258    /// Rejected response
259    pub rejected: PreferenceResponse,
260    /// Metadata
261    pub metadata: PreferenceMetadata,
262}
263
/// Prompt side of a preference pair.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PreferencePrompt {
    /// Embedding of the query.
    pub embedding: Vec<f32>,
    /// Identifiers of the context items attached to the query.
    pub context: Vec<String>,
}

/// One response in a preference pair, identified by the route that produced it.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PreferenceResponse {
    /// Model route that produced the response.
    pub route: String,
    /// Quality score assigned to the response.
    pub quality: f32,
    /// Embedding of the response.
    pub embedding: Vec<f32>,
}

/// Provenance attached to each preference pair.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PreferenceMetadata {
    /// Chosen quality minus rejected quality.
    pub quality_delta: f32,
    /// Name of the producing system.
    pub source: String,
    /// Version string of the producer.
    pub version: String,
}

298/// Distillation target for knowledge distillation
299#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
300#[derive(Clone, Debug)]
301pub struct DistillationTarget {
302    /// Input embedding
303    pub input_embedding: Vec<f32>,
304    /// Teacher model logits
305    pub teacher_logits: Vec<f32>,
306    /// Selected route
307    pub selected_route: String,
308    /// Confidence score
309    pub confidence: f32,
310    /// Quality score
311    pub quality: f32,
312    /// Metadata
313    pub metadata: DistillationMetadata,
314}
315
/// Provenance attached to each distillation target.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct DistillationMetadata {
    /// Name of the producing system.
    pub source: String,
    /// Version string of the producer.
    pub version: String,
    /// Softmax temperature associated with the teacher logits.
    pub temperature: f32,
}

/// A scored query/response trajectory used as raw material for preference
/// learning.
#[derive(Debug, Clone)]
pub struct QualityTrajectory {
    /// Embedding of the incoming query.
    pub query_embedding: Vec<f32>,
    /// Embedding of the produced response.
    pub response_embedding: Vec<f32>,
    /// Model route that handled the query.
    pub route: String,
    /// Quality score of the response.
    pub quality: f32,
    /// Identifiers of the context items used.
    pub context_ids: Vec<String>,
}

/// A recorded routing decision, the raw material for distillation targets.
#[derive(Debug, Clone)]
pub struct RoutingDecision {
    /// Embedding of the routed query.
    pub query_embedding: Vec<f32>,
    /// Logits the router produced over the candidate routes.
    pub routing_logits: Vec<f32>,
    /// Route that was selected.
    pub selected_route: String,
    /// Router confidence in the selection.
    pub confidence: f32,
    /// Quality score observed for the outcome.
    pub quality: f32,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two float slices are equal within a small tolerance; exported
    /// floats take a text round-trip, so exact equality is not guaranteed.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < 1e-6, "{} != {}", a, e);
        }
    }

    #[test]
    fn test_pattern_record() {
        let record = PatternRecord {
            id: "test-pattern".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
            cluster_size: 10,
            avg_quality: 0.85,
            pattern_type: "routing".to_string(),
            access_count: 100,
            metadata: PatternMetadata {
                source: "sona".to_string(),
                version: "0.1.0".to_string(),
                target_model: "phi-4".to_string(),
            },
        };

        let json = serde_json::to_string(&record).unwrap();
        // JSONL requires every record to occupy exactly one line.
        assert!(!json.contains('\n'));

        // Check fields by key rather than by substring, so incidental matches
        // elsewhere in the JSON cannot make the test pass.
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["id"], "test-pattern");
        assert_eq!(value["pattern_type"], "routing");
        assert_eq!(value["cluster_size"], 10_u64);
        assert_eq!(value["access_count"], 100_u64);
        assert_eq!(value["metadata"]["source"], "sona");
        assert_eq!(value["metadata"]["target_model"], "phi-4");

        let back: PatternRecord = serde_json::from_str(&json).unwrap();
        assert_eq!(back.id, record.id);
        assert_close(&back.embedding, &record.embedding);
        assert!((back.avg_quality - 0.85).abs() < 1e-6);
    }

    #[test]
    fn test_preference_pair() {
        let pair = PreferencePair {
            prompt: PreferencePrompt {
                embedding: vec![0.1, 0.2],
                context: vec!["ctx1".to_string()],
            },
            chosen: PreferenceResponse {
                route: "gpt-4".to_string(),
                quality: 0.9,
                embedding: vec![0.3, 0.4],
            },
            rejected: PreferenceResponse {
                route: "gpt-3.5".to_string(),
                quality: 0.6,
                embedding: vec![0.5, 0.6],
            },
            metadata: PreferenceMetadata {
                quality_delta: 0.3,
                source: "sona".to_string(),
                version: "0.1.0".to_string(),
            },
        };

        let json = serde_json::to_string(&pair).unwrap();
        assert!(!json.contains('\n'));

        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        // DPO-style trainers look these top-level columns up by name.
        for key in ["prompt", "chosen", "rejected", "metadata"] {
            assert!(value.get(key).is_some(), "missing key {}", key);
        }
        assert_eq!(value["chosen"]["route"], "gpt-4");
        assert_eq!(value["rejected"]["route"], "gpt-3.5");
        assert_eq!(value["prompt"]["context"][0], "ctx1");

        let back: PreferencePair = serde_json::from_str(&json).unwrap();
        assert!((back.chosen.quality - 0.9).abs() < 1e-6);
        assert!((back.rejected.quality - 0.6).abs() < 1e-6);
        assert!((back.metadata.quality_delta - 0.3).abs() < 1e-6);
        assert_close(&back.prompt.embedding, &pair.prompt.embedding);
    }

    #[test]
    fn test_distillation_target() {
        let target = DistillationTarget {
            input_embedding: vec![0.1],
            teacher_logits: vec![1.5, -0.5],
            selected_route: "fast".to_string(),
            confidence: 0.7,
            quality: 0.8,
            metadata: DistillationMetadata {
                source: "sona".to_string(),
                version: "0.1.0".to_string(),
                temperature: 1.0,
            },
        };

        let json = serde_json::to_string(&target).unwrap();
        assert!(!json.contains('\n'));

        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["selected_route"], "fast");
        assert_eq!(value["teacher_logits"].as_array().map(Vec::len), Some(2));
        assert_eq!(value["metadata"]["temperature"], 1.0_f64);

        let back: DistillationTarget = serde_json::from_str(&json).unwrap();
        assert_close(&back.teacher_logits, &target.teacher_logits);
        assert!((back.confidence - 0.7).abs() < 1e-6);
    }
}