Skip to main content

ruvector_sona/export/
dataset.rs

1//! Dataset Export - HuggingFace-compatible dataset formats
2//!
3//! Exports SONA's learned patterns and preference pairs as JSONL datasets
4//! compatible with HuggingFace's datasets library.
5
6use super::{ExportConfig, ExportError, ExportResult, ExportType};
7use crate::engine::SonaEngine;
8use crate::types::LearnedPattern;
9use std::io::{BufWriter, Write};
10use std::path::Path;
11
12#[cfg(feature = "serde-support")]
13use serde::{Deserialize, Serialize};
14
15/// Dataset exporter for patterns and preferences
16pub struct DatasetExporter<'a> {
17    config: &'a ExportConfig,
18}
19
20impl<'a> DatasetExporter<'a> {
21    /// Create new dataset exporter
22    pub fn new(config: &'a ExportConfig) -> Self {
23        Self { config }
24    }
25
26    /// Export learned patterns as JSONL dataset
27    pub fn export_patterns<P: AsRef<Path>>(
28        &self,
29        engine: &SonaEngine,
30        output_path: P,
31    ) -> Result<ExportResult, ExportError> {
32        let output_path = output_path.as_ref();
33
34        // Ensure parent directory exists
35        if let Some(parent) = output_path.parent() {
36            std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
37        }
38
39        let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
40        let mut writer = BufWriter::new(file);
41
42        let patterns = engine.get_all_patterns();
43        let mut items_exported = 0;
44
45        for pattern in patterns {
46            // Filter by quality threshold
47            if pattern.avg_quality < self.config.min_quality_threshold {
48                continue;
49            }
50
51            let record = PatternRecord {
52                id: pattern.id.to_string(),
53                embedding: pattern.centroid.clone(),
54                cluster_size: pattern.cluster_size,
55                avg_quality: pattern.avg_quality,
56                pattern_type: pattern.pattern_type.to_string(),
57                access_count: pattern.access_count as u64,
58                metadata: PatternMetadata {
59                    source: "sona".to_string(),
60                    version: env!("CARGO_PKG_VERSION").to_string(),
61                    target_model: self.config.target_architecture.clone(),
62                },
63            };
64
65            let json = serde_json::to_string(&record).map_err(ExportError::Serialization)?;
66            writeln!(writer, "{}", json).map_err(ExportError::Io)?;
67            items_exported += 1;
68        }
69
70        writer.flush().map_err(ExportError::Io)?;
71
72        let size_bytes = std::fs::metadata(output_path).map(|m| m.len()).unwrap_or(0);
73
74        Ok(ExportResult {
75            export_type: ExportType::PatternsDataset,
76            items_exported,
77            output_path: output_path.to_string_lossy().to_string(),
78            size_bytes,
79        })
80    }
81
82    /// Export preference pairs for DPO/RLHF training
83    pub fn export_preferences<P: AsRef<Path>>(
84        &self,
85        engine: &SonaEngine,
86        output_path: P,
87    ) -> Result<ExportResult, ExportError> {
88        let output_path = output_path.as_ref();
89
90        // Ensure parent directory exists
91        if let Some(parent) = output_path.parent() {
92            std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
93        }
94
95        let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
96        let mut writer = BufWriter::new(file);
97
98        let trajectories = engine.get_quality_trajectories();
99        let mut items_exported = 0;
100
101        // Generate preference pairs from trajectories
102        // Sort by quality and pair high-quality with low-quality
103        let mut sorted_trajectories = trajectories.clone();
104        sorted_trajectories.sort_by(|a, b| {
105            b.quality
106                .partial_cmp(&a.quality)
107                .unwrap_or(std::cmp::Ordering::Equal)
108        });
109
110        let mid = sorted_trajectories.len() / 2;
111        let (high_quality, low_quality) = sorted_trajectories.split_at(mid);
112
113        for (chosen, rejected) in high_quality.iter().zip(low_quality.iter().rev()) {
114            // Skip if quality difference is too small
115            if (chosen.quality - rejected.quality).abs() < 0.1 {
116                continue;
117            }
118
119            let pair = PreferencePair {
120                prompt: PreferencePrompt {
121                    embedding: chosen.query_embedding.clone(),
122                    context: chosen.context_ids.clone(),
123                },
124                chosen: PreferenceResponse {
125                    route: chosen.route.clone(),
126                    quality: chosen.quality,
127                    embedding: chosen.response_embedding.clone(),
128                },
129                rejected: PreferenceResponse {
130                    route: rejected.route.clone(),
131                    quality: rejected.quality,
132                    embedding: rejected.response_embedding.clone(),
133                },
134                metadata: PreferenceMetadata {
135                    quality_delta: chosen.quality - rejected.quality,
136                    source: "sona".to_string(),
137                    version: env!("CARGO_PKG_VERSION").to_string(),
138                },
139            };
140
141            let json = serde_json::to_string(&pair).map_err(ExportError::Serialization)?;
142            writeln!(writer, "{}", json).map_err(ExportError::Io)?;
143            items_exported += 1;
144        }
145
146        writer.flush().map_err(ExportError::Io)?;
147
148        let size_bytes = std::fs::metadata(output_path).map(|m| m.len()).unwrap_or(0);
149
150        Ok(ExportResult {
151            export_type: ExportType::PreferencePairs,
152            items_exported,
153            output_path: output_path.to_string_lossy().to_string(),
154            size_bytes,
155        })
156    }
157
158    /// Export distillation targets for knowledge distillation
159    pub fn export_distillation_targets<P: AsRef<Path>>(
160        &self,
161        engine: &SonaEngine,
162        output_path: P,
163    ) -> Result<ExportResult, ExportError> {
164        let output_path = output_path.as_ref();
165
166        // Ensure parent directory exists
167        if let Some(parent) = output_path.parent() {
168            std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
169        }
170
171        let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
172        let mut writer = BufWriter::new(file);
173
174        let routing_decisions = engine.get_routing_decisions();
175        let mut items_exported = 0;
176
177        for decision in routing_decisions {
178            // Filter by quality
179            if decision.quality < self.config.min_quality_threshold {
180                continue;
181            }
182
183            let target = DistillationTarget {
184                input_embedding: decision.query_embedding.clone(),
185                teacher_logits: decision.routing_logits.clone(),
186                selected_route: decision.selected_route.clone(),
187                confidence: decision.confidence,
188                quality: decision.quality,
189                metadata: DistillationMetadata {
190                    source: "sona".to_string(),
191                    version: env!("CARGO_PKG_VERSION").to_string(),
192                    temperature: 1.0,
193                },
194            };
195
196            let json = serde_json::to_string(&target).map_err(ExportError::Serialization)?;
197            writeln!(writer, "{}", json).map_err(ExportError::Io)?;
198            items_exported += 1;
199        }
200
201        writer.flush().map_err(ExportError::Io)?;
202
203        let size_bytes = std::fs::metadata(output_path).map(|m| m.len()).unwrap_or(0);
204
205        Ok(ExportResult {
206            export_type: ExportType::DistillationTargets,
207            items_exported,
208            output_path: output_path.to_string_lossy().to_string(),
209            size_bytes,
210        })
211    }
212}
213
214/// Pattern record for JSONL export
215#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
216#[derive(Clone, Debug)]
217pub struct PatternRecord {
218    /// Pattern ID
219    pub id: String,
220    /// Embedding vector
221    pub embedding: Vec<f32>,
222    /// Number of trajectories in cluster
223    pub cluster_size: usize,
224    /// Average quality score
225    pub avg_quality: f32,
226    /// Pattern type (routing, reasoning, etc.)
227    pub pattern_type: String,
228    /// Access count
229    pub access_count: u64,
230    /// Export metadata
231    pub metadata: PatternMetadata,
232}
233
/// Provenance attached to every exported pattern record.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PatternMetadata {
    /// Name of the producing system (the exporter writes `"sona"`).
    pub source: String,
    /// Version string of the producing crate.
    pub version: String,
    /// Model architecture the dataset is intended for.
    pub target_model: String,
}
245
246/// Preference pair for DPO/RLHF
247#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
248#[derive(Clone, Debug)]
249pub struct PreferencePair {
250    /// Input prompt
251    pub prompt: PreferencePrompt,
252    /// Chosen (preferred) response
253    pub chosen: PreferenceResponse,
254    /// Rejected response
255    pub rejected: PreferenceResponse,
256    /// Metadata
257    pub metadata: PreferenceMetadata,
258}
259
/// Prompt side of a preference pair.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PreferencePrompt {
    /// Embedding of the query.
    pub embedding: Vec<f32>,
    /// Identifiers of the context associated with the query.
    pub context: Vec<String>,
}
269
/// One candidate response within a preference pair.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PreferenceResponse {
    /// Route (model) that produced the response.
    pub route: String,
    /// Quality score the response received.
    pub quality: f32,
    /// Embedding of the response.
    pub embedding: Vec<f32>,
}
281
/// Provenance and quality-gap information for a preference pair.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct PreferenceMetadata {
    /// Chosen quality minus rejected quality.
    pub quality_delta: f32,
    /// Name of the producing system.
    pub source: String,
    /// Version string of the producing crate.
    pub version: String,
}
293
294/// Distillation target for knowledge distillation
295#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
296#[derive(Clone, Debug)]
297pub struct DistillationTarget {
298    /// Input embedding
299    pub input_embedding: Vec<f32>,
300    /// Teacher model logits
301    pub teacher_logits: Vec<f32>,
302    /// Selected route
303    pub selected_route: String,
304    /// Confidence score
305    pub confidence: f32,
306    /// Quality score
307    pub quality: f32,
308    /// Metadata
309    pub metadata: DistillationMetadata,
310}
311
/// Provenance and temperature information for a distillation target.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct DistillationMetadata {
    /// Name of the producing system.
    pub source: String,
    /// Version string of the producing crate.
    pub version: String,
    /// Softmax temperature associated with the teacher logits.
    pub temperature: f32,
}
323
/// A scored query/response trajectory, used as input for preference-pair export.
#[derive(Clone, Debug)]
pub struct QualityTrajectory {
    /// Embedding of the query.
    pub query_embedding: Vec<f32>,
    /// Embedding of the response.
    pub response_embedding: Vec<f32>,
    /// Route (model) that handled the query.
    pub route: String,
    /// Quality score of the trajectory.
    pub quality: f32,
    /// Identifiers of the context used for the query.
    pub context_ids: Vec<String>,
}
338
/// A recorded routing decision, used as input for distillation-target export.
#[derive(Clone, Debug)]
pub struct RoutingDecision {
    /// Embedding of the routed query.
    pub query_embedding: Vec<f32>,
    /// Logits the router produced over the candidate routes.
    pub routing_logits: Vec<f32>,
    /// Route that was selected.
    pub selected_route: String,
    /// Confidence attached to the decision.
    pub confidence: f32,
    /// Quality score of the decision.
    pub quality: f32,
}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for `f32` values that have gone through a JSON round trip.
    const EPS: f32 = 1e-6;

    /// Assert that two float slices match element-wise within [`EPS`].
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < EPS, "{} != {}", a, e);
        }
    }

    #[test]
    fn test_pattern_record() {
        let record = PatternRecord {
            id: "test-pattern".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
            cluster_size: 10,
            avg_quality: 0.85,
            pattern_type: "routing".to_string(),
            access_count: 100,
            metadata: PatternMetadata {
                source: "sona".to_string(),
                version: "0.1.0".to_string(),
                target_model: "phi-4".to_string(),
            },
        };

        let json = serde_json::to_string(&record).unwrap();
        assert!(json.contains("test-pattern"));
        assert!(json.contains("0.85"));
        // JSONL requires each record to fit on a single line.
        assert!(!json.contains('\n'));

        // Round-trip so every field is checked, not just two substrings.
        let parsed: PatternRecord = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, record.id);
        assert_close(&parsed.embedding, &record.embedding);
        assert_eq!(parsed.cluster_size, record.cluster_size);
        assert_close(&[parsed.avg_quality], &[record.avg_quality]);
        assert_eq!(parsed.pattern_type, record.pattern_type);
        assert_eq!(parsed.access_count, record.access_count);
        assert_eq!(parsed.metadata.source, record.metadata.source);
        assert_eq!(parsed.metadata.version, record.metadata.version);
        assert_eq!(parsed.metadata.target_model, record.metadata.target_model);
    }

    #[test]
    fn test_preference_pair() {
        let pair = PreferencePair {
            prompt: PreferencePrompt {
                embedding: vec![0.1, 0.2],
                context: vec!["ctx1".to_string()],
            },
            chosen: PreferenceResponse {
                route: "gpt-4".to_string(),
                quality: 0.9,
                embedding: vec![0.3, 0.4],
            },
            rejected: PreferenceResponse {
                route: "gpt-3.5".to_string(),
                quality: 0.6,
                embedding: vec![0.5, 0.6],
            },
            metadata: PreferenceMetadata {
                quality_delta: 0.3,
                source: "sona".to_string(),
                version: "0.1.0".to_string(),
            },
        };

        let json = serde_json::to_string(&pair).unwrap();
        assert!(json.contains("gpt-4"));
        assert!(json.contains("0.9"));
        // JSONL requires each record to fit on a single line.
        assert!(!json.contains('\n'));

        // Round-trip so chosen/rejected cannot be silently swapped or dropped.
        let parsed: PreferencePair = serde_json::from_str(&json).unwrap();
        assert_close(&parsed.prompt.embedding, &pair.prompt.embedding);
        assert_eq!(parsed.prompt.context, pair.prompt.context);
        assert_eq!(parsed.chosen.route, "gpt-4");
        assert_close(&[parsed.chosen.quality], &[pair.chosen.quality]);
        assert_close(&parsed.chosen.embedding, &pair.chosen.embedding);
        assert_eq!(parsed.rejected.route, "gpt-3.5");
        assert_close(&[parsed.rejected.quality], &[pair.rejected.quality]);
        assert_close(&parsed.rejected.embedding, &pair.rejected.embedding);
        assert_close(&[parsed.metadata.quality_delta], &[pair.metadata.quality_delta]);
        assert_eq!(parsed.metadata.source, pair.metadata.source);
        assert_eq!(parsed.metadata.version, pair.metadata.version);
    }
}