use std::collections::HashMap;
use std::path::{Path, PathBuf};

use serde::{Deserialize, Serialize};
4
/// Top-level configuration for one synthetic-data job, deserialized from
/// YAML via [`SynthConfig::from_yaml`] / [`SynthConfig::from_file`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SynthConfig {
    /// Human-readable job name.
    pub name: String,
    /// Optional seed dataset; absent for pure `generate` jobs (see the
    /// `test_parse_generate_config` test, which omits it).
    #[serde(default)]
    pub source: Option<SourceConfig>,
    /// Which LLM provider to call and with what model/settings.
    pub provider: ProviderConfig,
    /// What to generate and how much.
    pub generation: GenerationConfig,
    /// Where and in which format results are written.
    pub output: OutputConfig,
}
14
/// Seed-data source, tagged by `type` in YAML
/// (`type: huggingface` or `type: local`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SourceConfig {
    /// A dataset hosted on the HuggingFace Hub.
    HuggingFace {
        /// Dataset id, e.g. `"cornell-movie-review-data/rotten_tomatoes"`.
        dataset: String,
        /// Optional dataset configuration/subset name.
        #[serde(default)]
        subset: Option<String>,
        /// Split to load; defaults to `"train"`.
        #[serde(default = "default_split")]
        split: String,
        /// If set, limits how many examples are used — presumably a sample
        /// size; TODO confirm (random vs. head-of-dataset) in the loader.
        #[serde(default)]
        sample: Option<usize>,
        /// Optional column selection — presumably restricts which columns
        /// are read; confirm against the loader implementation.
        #[serde(default)]
        columns: Option<Vec<String>>,
    },
    /// A file on the local filesystem.
    Local {
        /// Path to the input file.
        path: PathBuf,
        /// On-disk format of the input file.
        format: FileFormat,
        /// If set, limits how many examples are used (see note above on
        /// `HuggingFace::sample`).
        #[serde(default)]
        sample: Option<usize>,
    },
}
36
/// Serde default for `SourceConfig::HuggingFace::split`: the `"train"` split.
fn default_split() -> String {
    String::from("train")
}
40
/// Supported input file formats for `SourceConfig::Local`.
/// Spelled lowercase in YAML: `json`, `jsonl`, `csv`, `parquet`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FileFormat {
    Json,
    Jsonl,
    Csv,
    Parquet,
}
49
/// LLM provider settings, tagged by `type` in YAML
/// (`type: openai` or `type: anthropic`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    /// OpenAI (or OpenAI-compatible) API.
    OpenAI {
        /// Model id, e.g. `"gpt-4o-mini"`.
        model: String,
        /// API key; when absent, presumably read from the environment —
        /// TODO confirm against the client code.
        #[serde(default)]
        api_key: Option<String>,
        /// Optional base-URL override — presumably for self-hosted or
        /// OpenAI-compatible endpoints; confirm in the client.
        #[serde(default)]
        base_url: Option<String>,
        /// Sampling temperature; provider-side default when absent.
        #[serde(default)]
        temperature: Option<f32>,
        /// Completion token cap; provider-side default when absent.
        #[serde(default)]
        max_tokens: Option<u32>,
    },
    /// Anthropic API.
    Anthropic {
        /// Model id, e.g. `"claude-haiku-4-5-20251001"`.
        model: String,
        /// API key; when absent, presumably read from the environment —
        /// TODO confirm against the client code.
        #[serde(default)]
        api_key: Option<String>,
        /// Sampling temperature; provider-side default when absent.
        #[serde(default)]
        temperature: Option<f32>,
        /// Completion token cap; provider-side default when absent.
        #[serde(default)]
        max_tokens: Option<u32>,
    },
}
74
/// Controls what is generated and how the work is scheduled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfig {
    /// Generate from scratch or augment an existing source.
    pub task: GenerationTask,
    /// Total number of examples to produce; defaults to 100.
    #[serde(default = "default_count")]
    pub count: usize,
    /// Outputs produced per source example — presumably only meaningful
    /// for `task: augment` (the augment test sets it); TODO confirm.
    #[serde(default)]
    pub count_per_example: Option<usize>,
    /// Degree of concurrency for generation requests; defaults to 5.
    #[serde(default = "default_concurrency")]
    pub concurrency: usize,
    /// Optional augmentation strategy (e.g. `paraphrase`).
    #[serde(default)]
    pub strategy: Option<GenerationStrategy>,
    /// Free-form, strategy-specific options passed through as raw YAML.
    #[serde(default)]
    pub strategy_config: HashMap<String, serde_yaml::Value>,
    /// Optional prompt template. NOTE(review): the generate test uses a
    /// `{category}` placeholder — confirm substitution rules elsewhere.
    #[serde(default)]
    pub template: Option<String>,
    /// Optional system prompt for the provider.
    #[serde(default)]
    pub system_prompt: Option<String>,
    /// Optional category list — presumably used to vary generated
    /// examples via the template; confirm in the generator.
    #[serde(default)]
    pub categories: Option<Vec<String>>,
}
95
/// Serde default for `GenerationConfig::count`.
fn default_count() -> usize {
    100
}
99
/// Serde default for `GenerationConfig::concurrency`.
fn default_concurrency() -> usize {
    5
}
103
/// Kind of job to run; spelled snake_case in YAML (`generate`, `augment`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GenerationTask {
    /// Produce examples from scratch (no `source` required).
    Generate,
    /// Derive new examples from an existing `source`.
    Augment,
}
110
/// Built-in augmentation strategies; spelled snake_case in YAML
/// (`paraphrase`, `style_transfer`, `back_translation`, `custom`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GenerationStrategy {
    Paraphrase,
    StyleTransfer,
    BackTranslation,
    /// User-defined strategy — presumably configured via
    /// `GenerationConfig::strategy_config`; TODO confirm.
    Custom,
}
119
/// Where and how generated results are written.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputConfig {
    /// Serialization format of the output file.
    pub format: OutputFormat,
    /// Destination path for the output file.
    pub path: PathBuf,
    /// Records per write batch — presumably a flush interval; defaults
    /// to 100. TODO confirm semantics in the writer.
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,
}
127
/// Serde default for `OutputConfig::batch_size`.
fn default_batch_size() -> usize {
    100
}
131
/// Supported output file formats; spelled lowercase in YAML
/// (`json`, `jsonl`, `csv`, `parquet`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OutputFormat {
    Json,
    Jsonl,
    Csv,
    Parquet,
}
140
141impl SynthConfig {
142 pub fn from_yaml(content: &str) -> crate::Result<Self> {
143 serde_yaml::from_str(content).map_err(Into::into)
144 }
145
146 pub fn from_file(path: &PathBuf) -> crate::Result<Self> {
147 let content = std::fs::read_to_string(path)?;
148 Self::from_yaml(&content)
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// An `augment` job: HuggingFace source plus a paraphrase strategy.
    /// Exercises the `type`-tagged deserialization of `SourceConfig` and
    /// `ProviderConfig` and the snake_case `GenerationTask`.
    #[test]
    fn test_parse_augment_config() {
        let yaml = r#"
name: "sentiment_augmentation"

source:
  type: huggingface
  dataset: "cornell-movie-review-data/rotten_tomatoes"
  split: "train"
  sample: 1000

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: augment
  count_per_example: 3
  concurrency: 10
  strategy: paraphrase

output:
  format: jsonl
  path: "./output/augmented.jsonl"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.name, "sentiment_augmentation");
        // Variant checks only — field contents are covered by serde itself.
        assert!(matches!(
            config.source,
            Some(SourceConfig::HuggingFace { .. })
        ));
        assert!(matches!(config.provider, ProviderConfig::OpenAI { .. }));
        assert!(matches!(config.generation.task, GenerationTask::Augment));
    }

    /// A `generate` job with no `source` key: verifies that `source`
    /// defaults to `None` and that `categories` round-trips as a list.
    #[test]
    fn test_parse_generate_config() {
        let yaml = r#"
name: "product_reviews"

provider:
  type: anthropic
  model: "claude-haiku-4-5-20251001"

generation:
  task: generate
  count: 500
  concurrency: 5
  categories:
    - electronics
    - books
    - clothing
  template: |
    Generate a realistic {category} product review.
    Output only the review text.

output:
  format: parquet
  path: "./output/reviews.parquet"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.name, "product_reviews");
        assert!(config.source.is_none());
        assert!(matches!(config.provider, ProviderConfig::Anthropic { .. }));
        assert!(matches!(config.generation.task, GenerationTask::Generate));
        assert_eq!(config.generation.categories.as_ref().unwrap().len(), 3);
    }
}