1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
/// Top-level configuration for one synthetic-data job, parsed from YAML via
/// [`SynthConfig::from_yaml`] or [`SynthConfig::from_file`].
///
/// `name`, `provider`, `generation` and `output` are required keys; `source`
/// and `validation` may be omitted and then deserialize to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SynthConfig {
    /// Job name (e.g. `"sentiment_augmentation"`).
    pub name: String,
    /// Where input examples come from; `None` when the YAML has no `source`
    /// section.
    #[serde(default)]
    pub source: Option<SourceConfig>,
    /// Model backend and its settings (selected by `provider.type`).
    pub provider: ProviderConfig,
    /// What to produce, how many records, and with what concurrency.
    pub generation: GenerationConfig,
    /// Output file format and destination.
    pub output: OutputConfig,
    /// Optional checks on generated records; `None` when the YAML omits
    /// `validation`.
    #[serde(default)]
    pub validation: Option<ValidationConfig>,
}
16
/// Input dataset for the job, selected by the YAML `type` tag
/// (`huggingface` or `local`, i.e. the lowercased variant name).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SourceConfig {
    /// A Hugging Face dataset (`type: huggingface`).
    HuggingFace {
        /// Dataset identifier, e.g. `"cornell-movie-review-data/rotten_tomatoes"`.
        dataset: String,
        /// Optional subset name — presumably the dataset configuration passed
        /// to the loader; confirm.
        #[serde(default)]
        subset: Option<String>,
        /// Split to read; `"train"` when omitted (see `default_split`).
        #[serde(default = "default_split")]
        split: String,
        /// Optional number of examples to use. NOTE(review): how the sample is
        /// drawn (first N vs. random) is decided by the loader, not here.
        #[serde(default)]
        sample: Option<usize>,
        /// Optional list of columns to keep — presumably `None` means all
        /// columns; confirm in the loader.
        #[serde(default)]
        columns: Option<Vec<String>>,
    },
    /// A file on the local filesystem (`type: local`).
    Local {
        /// Path to the input file.
        path: PathBuf,
        /// Format used to read `path` (required; no default).
        format: FileFormat,
        /// Optional number of examples to use (same caveat as the
        /// Hugging Face `sample`).
        #[serde(default)]
        sample: Option<usize>,
    },
}
38
/// Serde fallback for the Hugging Face `split` field when the YAML omits it.
fn default_split() -> String {
    String::from("train")
}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44#[serde(rename_all = "lowercase")]
45pub enum FileFormat {
46 Json,
47 Jsonl,
48 Csv,
49 Parquet,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53#[serde(tag = "type", rename_all = "lowercase")]
54pub enum ProviderConfig {
55 OpenAI {
56 model: String,
57 #[serde(default)]
58 api_key: Option<String>,
59 #[serde(default)]
60 base_url: Option<String>,
61 #[serde(default)]
62 temperature: Option<f32>,
63 #[serde(default)]
64 max_tokens: Option<u32>,
65 },
66 Anthropic {
67 model: String,
68 #[serde(default)]
69 api_key: Option<String>,
70 #[serde(default)]
71 temperature: Option<f32>,
72 #[serde(default)]
73 max_tokens: Option<u32>,
74 },
75}
76
/// Controls what is generated, how many records, and how they are produced.
///
/// Only `task` is required; every other key has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfig {
    /// `generate` or `augment` (see `GenerationTask`).
    pub task: GenerationTask,
    /// Target record count; 100 when omitted (see `default_count`).
    #[serde(default = "default_count")]
    pub count: usize,
    /// Optional records per input example. NOTE(review): presumably only
    /// meaningful for `augment` (the augment sample config sets it instead of
    /// `count`); confirm in the runner.
    #[serde(default)]
    pub count_per_example: Option<usize>,
    /// Degree of parallelism; 5 when omitted (see `default_concurrency`).
    /// Presumably the number of in-flight provider requests — confirm.
    #[serde(default = "default_concurrency")]
    pub concurrency: usize,
    /// Optional strategy (e.g. `paraphrase`); `None` when omitted.
    #[serde(default)]
    pub strategy: Option<GenerationStrategy>,
    /// Free-form strategy parameters kept as raw YAML values; an empty map
    /// when omitted.
    #[serde(default)]
    pub strategy_config: HashMap<String, serde_yaml::Value>,
    /// Optional prompt template. The sample config uses a `{category}`
    /// placeholder; substitution happens elsewhere.
    #[serde(default)]
    pub template: Option<String>,
    /// Optional system prompt, presumably forwarded to the provider — confirm.
    #[serde(default)]
    pub system_prompt: Option<String>,
    /// Optional category list. Presumably the values fill `{category}` in
    /// `template` — confirm.
    #[serde(default)]
    pub categories: Option<Vec<String>>,
}
97
/// Serde fallback for `GenerationConfig::count` when the YAML omits it.
fn default_count() -> usize {
    const DEFAULT_COUNT: usize = 100;
    DEFAULT_COUNT
}
101
/// Serde fallback for `GenerationConfig::concurrency` when the YAML omits it.
fn default_concurrency() -> usize {
    const DEFAULT_CONCURRENCY: usize = 5;
    DEFAULT_CONCURRENCY
}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
107#[serde(rename_all = "snake_case")]
108pub enum GenerationTask {
109 Generate,
110 Augment,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
114#[serde(rename_all = "snake_case")]
115pub enum GenerationStrategy {
116 Paraphrase,
117 StyleTransfer,
118 BackTranslation,
119 Custom,
120}
121
/// Where and how generated records are written.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputConfig {
    /// Output file format (required).
    pub format: OutputFormat,
    /// Destination file path (required).
    pub path: PathBuf,
    /// Writer batch size; 100 when omitted (see `default_batch_size`).
    /// NOTE(review): presumably records per flush. A value of 0 is not
    /// rejected here — confirm the writer handles it.
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,
}
129
/// Serde fallback for `OutputConfig::batch_size` when the YAML omits it.
fn default_batch_size() -> usize {
    const DEFAULT_BATCH_SIZE: usize = 100;
    DEFAULT_BATCH_SIZE
}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
135#[serde(rename_all = "lowercase")]
136pub enum OutputFormat {
137 Json,
138 Jsonl,
139 Csv,
140 Parquet,
141}
142
/// Optional checks on generated records (`validation:` section).
///
/// Every field has `#[serde(default)]` (`None` / `false`), and the type
/// derives `Default`, so any subset of keys may be given.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ValidationConfig {
    /// Optional lower length bound. NOTE(review): unit (bytes vs. chars) and
    /// inclusivity are defined by the validator. `min_length <= max_length` is
    /// not enforced here.
    #[serde(default)]
    pub min_length: Option<usize>,
    /// Optional upper length bound (same caveats as `min_length`).
    #[serde(default)]
    pub max_length: Option<usize>,
    /// Flag for a JSON check — presumably "output must parse as JSON"; confirm.
    #[serde(default)]
    pub json: bool,
    /// Despite the name, this is a list of key names (e.g. `question`,
    /// `answer`), not a JSON Schema document. Presumably these keys must
    /// appear in the JSON output — confirm in the validator.
    #[serde(default)]
    pub json_schema: Option<Vec<String>>,
    /// Flag for a blocklist check; the list itself is defined elsewhere.
    #[serde(default)]
    pub blocklist: bool,
    /// Flag for a repetition check; the criterion is defined elsewhere.
    #[serde(default)]
    pub repetition: bool,
    /// Optional deduplication strategy; presumably `None` disables dedupe.
    #[serde(default)]
    pub dedupe: Option<DedupeStrategy>,
}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
162#[serde(rename_all = "lowercase")]
163pub enum DedupeStrategy {
164 Exact,
165 Normalized,
166 Jaccard,
167}
168
169impl SynthConfig {
170 pub fn from_yaml(content: &str) -> crate::Result<Self> {
171 serde_yaml::from_str(content).map_err(Into::into)
172 }
173
174 pub fn from_file(path: &PathBuf) -> crate::Result<Self> {
175 let content = std::fs::read_to_string(path)?;
176 Self::from_yaml(&content)
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_augment_config() {
        let yaml = r#"
name: "sentiment_augmentation"

source:
  type: huggingface
  dataset: "cornell-movie-review-data/rotten_tomatoes"
  split: "train"
  sample: 1000

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: augment
  count_per_example: 3
  concurrency: 10
  strategy: paraphrase

output:
  format: jsonl
  path: "./output/augmented.jsonl"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.name, "sentiment_augmentation");
        assert!(matches!(
            config.source,
            Some(SourceConfig::HuggingFace { .. })
        ));
        assert!(matches!(config.provider, ProviderConfig::OpenAI { .. }));
        assert!(matches!(config.generation.task, GenerationTask::Augment));
    }

    #[test]
    fn test_parse_generate_config() {
        let yaml = r#"
name: "product_reviews"

provider:
  type: anthropic
  model: "claude-haiku-4-5-20251001"

generation:
  task: generate
  count: 500
  concurrency: 5
  categories:
    - electronics
    - books
    - clothing
  template: |
    Generate a realistic {category} product review.
    Output only the review text.

output:
  format: parquet
  path: "./output/reviews.parquet"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.name, "product_reviews");
        assert!(config.source.is_none());
        assert!(matches!(config.provider, ProviderConfig::Anthropic { .. }));
        assert!(matches!(config.generation.task, GenerationTask::Generate));
        assert_eq!(config.generation.categories.as_ref().unwrap().len(), 3);
    }

    #[test]
    fn test_parse_validation_config() {
        let yaml = r#"
name: "with_validation"

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: generate
  count: 10
  template: "Generate JSON: {\"q\": \"...\", \"a\": \"...\"}"

output:
  format: jsonl
  path: "./output.jsonl"

validation:
  min_length: 20
  max_length: 1000
  json: true
  json_schema:
    - question
    - answer
  blocklist: true
  repetition: true
  dedupe: normalized
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        let v = config.validation.unwrap();
        assert_eq!(v.min_length, Some(20));
        assert_eq!(v.max_length, Some(1000));
        assert!(v.json);
        assert_eq!(v.json_schema.unwrap(), vec!["question", "answer"]);
        assert!(v.blocklist);
        assert!(v.repetition);
        assert!(matches!(v.dedupe, Some(DedupeStrategy::Normalized)));
    }

    /// A config with only the required keys must pick up every serde default.
    #[test]
    fn test_defaults_applied() {
        let yaml = r#"
name: "minimal"

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: generate

output:
  format: jsonl
  path: "./out.jsonl"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        assert!(config.source.is_none());
        assert!(config.validation.is_none());
        assert_eq!(config.generation.count, 100);
        assert_eq!(config.generation.concurrency, 5);
        assert!(config.generation.count_per_example.is_none());
        assert!(config.generation.strategy.is_none());
        assert!(config.generation.strategy_config.is_empty());
        assert!(config.generation.template.is_none());
        assert!(config.generation.system_prompt.is_none());
        assert!(config.generation.categories.is_none());
        assert_eq!(config.output.batch_size, 100);
        assert!(matches!(
            config.provider,
            ProviderConfig::OpenAI {
                api_key: None,
                base_url: None,
                temperature: None,
                max_tokens: None,
                ..
            }
        ));
    }

    #[test]
    fn test_huggingface_split_defaults_to_train() {
        let yaml = r#"
name: "hf_default_split"

source:
  type: huggingface
  dataset: "imdb"

provider:
  type: anthropic
  model: "claude-haiku-4-5-20251001"

generation:
  task: augment

output:
  format: json
  path: "./out.json"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        match config.source {
            Some(SourceConfig::HuggingFace {
                dataset,
                subset,
                split,
                sample,
                columns,
            }) => {
                assert_eq!(dataset, "imdb");
                assert_eq!(split, "train");
                assert!(subset.is_none());
                assert!(sample.is_none());
                assert!(columns.is_none());
            }
            other => panic!("expected HuggingFace source, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_local_source() {
        let yaml = r#"
name: "local_source"

source:
  type: local
  path: "./data/reviews.csv"
  format: csv
  sample: 50

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: augment

output:
  format: csv
  path: "./out.csv"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        match config.source {
            Some(SourceConfig::Local {
                path,
                format,
                sample,
            }) => {
                assert_eq!(path, PathBuf::from("./data/reviews.csv"));
                assert!(matches!(format, FileFormat::Csv));
                assert_eq!(sample, Some(50));
            }
            other => panic!("expected local source, got {:?}", other),
        }
    }

    #[test]
    fn test_unknown_provider_type_is_rejected() {
        let yaml = r#"
name: "bad_provider"

provider:
  type: cohere
  model: "command-r"

generation:
  task: generate

output:
  format: jsonl
  path: "./out.jsonl"
"#;

        assert!(SynthConfig::from_yaml(yaml).is_err());
    }

    #[test]
    fn test_from_file_missing_path_is_error() {
        let path = PathBuf::from("this/path/should/not/exist/synth_config.yaml");
        assert!(SynthConfig::from_file(&path).is_err());
    }
}