Skip to main content

ruvector_sona/export/
dataset.rs

//! Dataset Export - HuggingFace-compatible dataset formats
//!
//! Exports SONA's learned patterns, preference pairs, and distillation targets
//! as JSONL datasets compatible with the HuggingFace `datasets` library.
5
6use super::{ExportConfig, ExportError, ExportResult, ExportType};
7use crate::engine::SonaEngine;
8use std::io::{BufWriter, Write};
9use std::path::Path;
10
11#[cfg(feature = "serde-support")]
12use serde::{Deserialize, Serialize};
13
14/// Dataset exporter for patterns and preferences
15pub struct DatasetExporter<'a> {
16    config: &'a ExportConfig,
17}
18
19impl<'a> DatasetExporter<'a> {
20    /// Create new dataset exporter
21    pub fn new(config: &'a ExportConfig) -> Self {
22        Self { config }
23    }
24
25    /// Export learned patterns as JSONL dataset
26    pub fn export_patterns<P: AsRef<Path>>(
27        &self,
28        engine: &SonaEngine,
29        output_path: P,
30    ) -> Result<ExportResult, ExportError> {
31        let output_path = output_path.as_ref();
32
33        // Ensure parent directory exists
34        if let Some(parent) = output_path.parent() {
35            std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
36        }
37
38        let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
39        let mut writer = BufWriter::new(file);
40
41        let patterns = engine.get_all_patterns();
42        let mut items_exported = 0;
43
44        for pattern in patterns {
45            // Filter by quality threshold
46            if pattern.avg_quality < self.config.min_quality_threshold {
47                continue;
48            }
49
50            let record = PatternRecord {
51                id: pattern.id.to_string(),
52                embedding: pattern.centroid.clone(),
53                cluster_size: pattern.cluster_size,
54                avg_quality: pattern.avg_quality,
55                pattern_type: pattern.pattern_type.to_string(),
56                access_count: pattern.access_count as u64,
57                metadata: PatternMetadata {
58                    source: "sona".to_string(),
59                    version: env!("CARGO_PKG_VERSION").to_string(),
60                    target_model: self.config.target_architecture.clone(),
61                },
62            };
63
64            let json = serde_json::to_string(&record).map_err(ExportError::Serialization)?;
65            writeln!(writer, "{}", json).map_err(ExportError::Io)?;
66            items_exported += 1;
67        }
68
69        writer.flush().map_err(ExportError::Io)?;
70
71        let size_bytes = std::fs::metadata(output_path).map(|m| m.len()).unwrap_or(0);
72
73        Ok(ExportResult {
74            export_type: ExportType::PatternsDataset,
75            items_exported,
76            output_path: output_path.to_string_lossy().to_string(),
77            size_bytes,
78        })
79    }
80
81    /// Export preference pairs for DPO/RLHF training
82    pub fn export_preferences<P: AsRef<Path>>(
83        &self,
84        engine: &SonaEngine,
85        output_path: P,
86    ) -> Result<ExportResult, ExportError> {
87        let output_path = output_path.as_ref();
88
89        // Ensure parent directory exists
90        if let Some(parent) = output_path.parent() {
91            std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
92        }
93
94        let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
95        let mut writer = BufWriter::new(file);
96
97        let trajectories = engine.get_quality_trajectories();
98        let mut items_exported = 0;
99
100        // Generate preference pairs from trajectories
101        // Sort by quality and pair high-quality with low-quality
102        let mut sorted_trajectories = trajectories.clone();
103        sorted_trajectories.sort_by(|a, b| {
104            b.quality
105                .partial_cmp(&a.quality)
106                .unwrap_or(std::cmp::Ordering::Equal)
107        });
108
109        let mid = sorted_trajectories.len() / 2;
110        let (high_quality, low_quality) = sorted_trajectories.split_at(mid);
111
112        for (chosen, rejected) in high_quality.iter().zip(low_quality.iter().rev()) {
113            // Skip if quality difference is too small
114            if (chosen.quality - rejected.quality).abs() < 0.1 {
115                continue;
116            }
117
118            let pair = PreferencePair {
119                prompt: PreferencePrompt {
120                    embedding: chosen.query_embedding.clone(),
121                    context: chosen.context_ids.clone(),
122                },
123                chosen: PreferenceResponse {
124                    route: chosen.route.clone(),
125                    quality: chosen.quality,
126                    embedding: chosen.response_embedding.clone(),
127                },
128                rejected: PreferenceResponse {
129                    route: rejected.route.clone(),
130                    quality: rejected.quality,
131                    embedding: rejected.response_embedding.clone(),
132                },
133                metadata: PreferenceMetadata {
134                    quality_delta: chosen.quality - rejected.quality,
135                    source: "sona".to_string(),
136                    version: env!("CARGO_PKG_VERSION").to_string(),
137                },
138            };
139
140            let json = serde_json::to_string(&pair).map_err(ExportError::Serialization)?;
141            writeln!(writer, "{}", json).map_err(ExportError::Io)?;
142            items_exported += 1;
143        }
144
145        writer.flush().map_err(ExportError::Io)?;
146
147        let size_bytes = std::fs::metadata(output_path).map(|m| m.len()).unwrap_or(0);
148
149        Ok(ExportResult {
150            export_type: ExportType::PreferencePairs,
151            items_exported,
152            output_path: output_path.to_string_lossy().to_string(),
153            size_bytes,
154        })
155    }
156
157    /// Export distillation targets for knowledge distillation
158    pub fn export_distillation_targets<P: AsRef<Path>>(
159        &self,
160        engine: &SonaEngine,
161        output_path: P,
162    ) -> Result<ExportResult, ExportError> {
163        let output_path = output_path.as_ref();
164
165        // Ensure parent directory exists
166        if let Some(parent) = output_path.parent() {
167            std::fs::create_dir_all(parent).map_err(ExportError::Io)?;
168        }
169
170        let file = std::fs::File::create(output_path).map_err(ExportError::Io)?;
171        let mut writer = BufWriter::new(file);
172
173        let routing_decisions = engine.get_routing_decisions();
174        let mut items_exported = 0;
175
176        for decision in routing_decisions {
177            // Filter by quality
178            if decision.quality < self.config.min_quality_threshold {
179                continue;
180            }
181
182            let target = DistillationTarget {
183                input_embedding: decision.query_embedding.clone(),
184                teacher_logits: decision.routing_logits.clone(),
185                selected_route: decision.selected_route.clone(),
186                confidence: decision.confidence,
187                quality: decision.quality,
188                metadata: DistillationMetadata {
189                    source: "sona".to_string(),
190                    version: env!("CARGO_PKG_VERSION").to_string(),
191                    temperature: 1.0,
192                },
193            };
194
195            let json = serde_json::to_string(&target).map_err(ExportError::Serialization)?;
196            writeln!(writer, "{}", json).map_err(ExportError::Io)?;
197            items_exported += 1;
198        }
199
200        writer.flush().map_err(ExportError::Io)?;
201
202        let size_bytes = std::fs::metadata(output_path).map(|m| m.len()).unwrap_or(0);
203
204        Ok(ExportResult {
205            export_type: ExportType::DistillationTargets,
206            items_exported,
207            output_path: output_path.to_string_lossy().to_string(),
208            size_bytes,
209        })
210    }
211}
212
/// Pattern record for JSONL export
///
/// One record is written per exported pattern by
/// `DatasetExporter::export_patterns`. Derives `PartialEq` so records can be
/// compared directly (for example after a serialize/deserialize round trip).
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct PatternRecord {
    /// Pattern ID
    pub id: String,
    /// Embedding vector (the exporter fills this from the pattern centroid)
    pub embedding: Vec<f32>,
    /// Number of trajectories in cluster
    pub cluster_size: usize,
    /// Average quality score
    pub avg_quality: f32,
    /// Pattern type (routing, reasoning, etc.)
    pub pattern_type: String,
    /// Access count
    pub access_count: u64,
    /// Export metadata
    pub metadata: PatternMetadata,
}

/// Pattern export metadata
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct PatternMetadata {
    /// Source system (the exporter writes `"sona"`)
    pub source: String,
    /// Version (the exporter writes the crate version)
    pub version: String,
    /// Target model architecture
    pub target_model: String,
}
244
/// Preference pair for DPO/RLHF
///
/// One pair is written per line by `DatasetExporter::export_preferences`.
/// Derives `PartialEq` so pairs can be compared directly.
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct PreferencePair {
    /// Input prompt
    pub prompt: PreferencePrompt,
    /// Chosen (preferred) response
    pub chosen: PreferenceResponse,
    /// Rejected response
    pub rejected: PreferenceResponse,
    /// Metadata
    pub metadata: PreferenceMetadata,
}

/// Preference prompt
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct PreferencePrompt {
    /// Query embedding
    pub embedding: Vec<f32>,
    /// Context IDs
    pub context: Vec<String>,
}

/// Preference response
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct PreferenceResponse {
    /// Model route
    pub route: String,
    /// Quality score
    pub quality: f32,
    /// Response embedding
    pub embedding: Vec<f32>,
}

/// Preference pair metadata
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct PreferenceMetadata {
    /// Quality difference between chosen and rejected (chosen minus rejected)
    pub quality_delta: f32,
    /// Source system
    pub source: String,
    /// Version
    pub version: String,
}
292
/// Distillation target for knowledge distillation
///
/// One target is written per line by
/// `DatasetExporter::export_distillation_targets`. Derives `PartialEq` so
/// targets can be compared directly.
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct DistillationTarget {
    /// Input embedding
    pub input_embedding: Vec<f32>,
    /// Teacher model logits
    pub teacher_logits: Vec<f32>,
    /// Selected route
    pub selected_route: String,
    /// Confidence score
    pub confidence: f32,
    /// Quality score
    pub quality: f32,
    /// Metadata
    pub metadata: DistillationMetadata,
}

/// Distillation metadata
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct DistillationMetadata {
    /// Source system
    pub source: String,
    /// Version
    pub version: String,
    /// Temperature for softmax (the exporter currently always writes 1.0)
    pub temperature: f32,
}
322
/// Quality trajectory for preference learning
///
/// `DatasetExporter::export_preferences` ranks these by `quality` and turns
/// them into chosen/rejected pairs. Derives `PartialEq` so trajectories can
/// be compared directly.
#[derive(Clone, Debug, PartialEq)]
pub struct QualityTrajectory {
    /// Query embedding
    pub query_embedding: Vec<f32>,
    /// Response embedding
    pub response_embedding: Vec<f32>,
    /// Model route
    pub route: String,
    /// Quality score
    pub quality: f32,
    /// Context IDs
    pub context_ids: Vec<String>,
}
337
/// Routing decision for distillation
///
/// `DatasetExporter::export_distillation_targets` converts each decision that
/// passes the quality threshold into a `DistillationTarget`. Derives
/// `PartialEq` so decisions can be compared directly.
#[derive(Clone, Debug, PartialEq)]
pub struct RoutingDecision {
    /// Query embedding
    pub query_embedding: Vec<f32>,
    /// Routing logits
    pub routing_logits: Vec<f32>,
    /// Selected route
    pub selected_route: String,
    /// Confidence
    pub confidence: f32,
    /// Quality
    pub quality: f32,
}
352
#[cfg(test)]
mod tests {
    // NOTE(review): these tests need the serde derives on the record types,
    // i.e. they presumably only build with the `serde-support` feature enabled.
    use super::*;

    /// A pattern record serializes its fields and survives a JSON round trip.
    #[test]
    fn test_pattern_record() {
        let record = PatternRecord {
            id: "test-pattern".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
            cluster_size: 10,
            avg_quality: 0.85,
            pattern_type: "routing".to_string(),
            access_count: 100,
            metadata: PatternMetadata {
                source: "sona".to_string(),
                version: "0.1.0".to_string(),
                target_model: "phi-4".to_string(),
            },
        };

        let json = serde_json::to_string(&record).unwrap();
        assert!(json.contains("test-pattern"));
        assert!(json.contains("0.85"));
        assert!(json.contains("\"id\":\"test-pattern\""));

        // Compare field by field; floats are left out to avoid formatting noise.
        let restored: PatternRecord = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.id, record.id);
        assert_eq!(restored.cluster_size, record.cluster_size);
        assert_eq!(restored.access_count, record.access_count);
        assert_eq!(restored.embedding.len(), record.embedding.len());
        assert_eq!(restored.metadata.target_model, record.metadata.target_model);
    }

    /// A preference pair serializes both responses and survives a JSON round trip.
    #[test]
    fn test_preference_pair() {
        let pair = PreferencePair {
            prompt: PreferencePrompt {
                embedding: vec![0.1, 0.2],
                context: vec!["ctx1".to_string()],
            },
            chosen: PreferenceResponse {
                route: "gpt-4".to_string(),
                quality: 0.9,
                embedding: vec![0.3, 0.4],
            },
            rejected: PreferenceResponse {
                route: "gpt-3.5".to_string(),
                quality: 0.6,
                embedding: vec![0.5, 0.6],
            },
            metadata: PreferenceMetadata {
                quality_delta: 0.3,
                source: "sona".to_string(),
                version: "0.1.0".to_string(),
            },
        };

        let json = serde_json::to_string(&pair).unwrap();
        assert!(json.contains("gpt-4"));
        assert!(json.contains("0.9"));
        assert!(json.contains("\"route\":\"gpt-3.5\""));

        let restored: PreferencePair = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.chosen.route, pair.chosen.route);
        assert_eq!(restored.rejected.route, pair.rejected.route);
        assert_eq!(restored.prompt.context, pair.prompt.context);
        assert_eq!(restored.metadata.source, pair.metadata.source);
    }

    /// A distillation target serializes its route and survives a JSON round trip.
    #[test]
    fn test_distillation_target() {
        let target = DistillationTarget {
            input_embedding: vec![0.1, 0.2],
            teacher_logits: vec![1.5, -0.5],
            selected_route: "gpt-4".to_string(),
            confidence: 0.75,
            quality: 0.9,
            metadata: DistillationMetadata {
                source: "sona".to_string(),
                version: "0.1.0".to_string(),
                temperature: 1.0,
            },
        };

        let json = serde_json::to_string(&target).unwrap();
        assert!(json.contains("\"selected_route\":\"gpt-4\""));

        let restored: DistillationTarget = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.selected_route, target.selected_route);
        assert_eq!(restored.input_embedding.len(), target.input_embedding.len());
        assert_eq!(restored.teacher_logits.len(), target.teacher_logits.len());
        assert_eq!(restored.metadata.source, target.metadata.source);
    }
}