// synth_claw/generation/engine.rs — synthetic-data generation engine.

1use crate::config::{GenerationConfig, GenerationStrategy, GenerationTask, ProviderConfig, SourceConfig, SynthConfig};
2use crate::datasets::{DataSource, HuggingFaceSource, LocalSource, Record};
3use crate::providers::{create_provider, GenerationRequest, GenerationResponse, LLMProvider};
4use crate::{Error, Result};
5
6use super::prompt::{default_template_for_augment, default_template_for_generate, PromptBuilder};
7
8use futures::stream::{self, StreamExt};
9use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
10use std::sync::Arc;
11use tokio::sync::Mutex;
12
/// Thread-safe counters tracking generation progress across concurrent tasks.
///
/// All fields are atomics so the stats can be shared via `Arc` and updated
/// from many in-flight generation futures without locking.
#[derive(Debug, Default)]
pub struct GenerationStats {
    /// Number of generation tasks that completed successfully.
    pub completed: AtomicUsize,
    /// Number of generation tasks that failed.
    pub failed: AtomicUsize,
    /// Running total of input (prompt) tokens across all successful tasks.
    pub total_input_tokens: AtomicU64,
    /// Running total of output (completion) tokens across all successful tasks.
    pub total_output_tokens: AtomicU64,
}
20
21impl GenerationStats {
22    pub fn record_success(&self, response: &GenerationResponse) {
23        self.completed.fetch_add(1, Ordering::Relaxed);
24        self.total_input_tokens.fetch_add(response.input_tokens as u64, Ordering::Relaxed);
25        self.total_output_tokens.fetch_add(response.output_tokens as u64, Ordering::Relaxed);
26    }
27
28    pub fn record_failure(&self) {
29        self.failed.fetch_add(1, Ordering::Relaxed);
30    }
31
32    pub fn snapshot(&self) -> StatsSnapshot {
33        StatsSnapshot {
34            completed: self.completed.load(Ordering::Relaxed),
35            failed: self.failed.load(Ordering::Relaxed),
36            total_input_tokens: self.total_input_tokens.load(Ordering::Relaxed),
37            total_output_tokens: self.total_output_tokens.load(Ordering::Relaxed),
38        }
39    }
40}
41
/// A plain (non-atomic) copy of `GenerationStats` taken at a single point in time.
#[derive(Debug, Clone)]
pub struct StatsSnapshot {
    /// Tasks completed successfully at snapshot time.
    pub completed: usize,
    /// Tasks failed at snapshot time.
    pub failed: usize,
    /// Total input tokens consumed at snapshot time.
    pub total_input_tokens: u64,
    /// Total output tokens produced at snapshot time.
    pub total_output_tokens: u64,
}
49
/// One completed generation: the model output plus provenance and token usage.
#[derive(Debug, Clone)]
pub struct GenerationResult {
    /// Raw text returned by the provider.
    pub content: String,
    /// Index of the source record this result augments (augment mode only).
    pub source_index: Option<usize>,
    /// Category label this result was generated for (generate mode only).
    pub category: Option<String>,
    /// Input tokens consumed by this request.
    pub input_tokens: u32,
    /// Output tokens produced by this request.
    pub output_tokens: u32,
}
58
59impl GenerationResult {
60    /// Parse content as JSON, extracting from markdown code blocks if needed
61    pub fn parse_json(&self) -> Result<serde_json::Value> {
62        let content = self.extract_json_content();
63        serde_json::from_str(&content).map_err(|e| Error::Json(e))
64    }
65
66    /// Parse content as a typed JSON object
67    pub fn parse_json_as<T: serde::de::DeserializeOwned>(&self) -> Result<T> {
68        let content = self.extract_json_content();
69        serde_json::from_str(&content).map_err(|e| Error::Json(e))
70    }
71
72    fn extract_json_content(&self) -> String {
73        let content = self.content.trim();
74
75        // Try ```json blocks
76        if let Some(start) = content.find("```json") {
77            if let Some(end) = content[start + 7..].find("```") {
78                return content[start + 7..start + 7 + end].trim().to_string();
79            }
80        }
81
82        // Try generic ``` blocks
83        if let Some(start) = content.find("```") {
84            if let Some(end) = content[start + 3..].find("```") {
85                let inner = content[start + 3..start + 3 + end].trim();
86                if let Some(newline) = inner.find('\n') {
87                    return inner[newline + 1..].trim().to_string();
88                }
89                return inner.to_string();
90            }
91        }
92
93        content.to_string()
94    }
95}
96
/// Internal unit of work: a single prompt to send to the provider.
///
/// The trailing underscore avoids a clash with the `config::GenerationTask`
/// enum imported at the top of this file.
struct GenerationTask_ {
    /// Fully rendered user prompt for this request.
    prompt: String,
    /// Optional system prompt shared by all tasks in a run.
    system_prompt: Option<String>,
    /// Index of the source record being augmented (augment mode only).
    source_index: Option<usize>,
    /// Category this task generates for (generate mode only).
    category: Option<String>,
}
103
/// Drives concurrent synthetic-data generation against a configured LLM provider.
pub struct GenerationEngine {
    /// Provider used to execute generation requests.
    provider: Box<dyn LLMProvider>,
    /// Generation settings (task kind, counts, concurrency, templates, ...).
    config: GenerationConfig,
    /// Shared progress counters, updated by every in-flight task.
    stats: Arc<GenerationStats>,
}
109
110impl GenerationEngine {
111    pub fn new(provider_config: &ProviderConfig, generation_config: GenerationConfig) -> Result<Self> {
112        let provider = create_provider(provider_config)?;
113        Ok(Self {
114            provider,
115            config: generation_config,
116            stats: Arc::new(GenerationStats::default()),
117        })
118    }
119
120    pub fn stats(&self) -> Arc<GenerationStats> {
121        Arc::clone(&self.stats)
122    }
123
124    pub fn provider(&self) -> &dyn LLMProvider {
125        self.provider.as_ref()
126    }
127
128    /// Run generation and collect all results
129    pub async fn run(&self, config: &SynthConfig) -> Result<Vec<GenerationResult>> {
130        let tasks = self.build_tasks(config).await?;
131        let results = Arc::new(Mutex::new(Vec::with_capacity(tasks.len())));
132        
133        stream::iter(tasks)
134            .map(|task| {
135                let provider = &self.provider;
136                let stats = Arc::clone(&self.stats);
137                let results = Arc::clone(&results);
138                async move {
139                    match self.execute_task(provider.as_ref(), task).await {
140                        Ok(result) => {
141                            stats.record_success(&GenerationResponse {
142                                content: result.content.clone(),
143                                input_tokens: result.input_tokens,
144                                output_tokens: result.output_tokens,
145                            });
146                            results.lock().await.push(result);
147                        }
148                        Err(e) => {
149                            stats.record_failure();
150                            tracing::warn!("Generation failed: {}", e);
151                        }
152                    }
153                }
154            })
155            .buffer_unordered(self.config.concurrency)
156            .collect::<Vec<_>>()
157            .await;
158
159        let results = Arc::try_unwrap(results)
160            .map_err(|_| Error::Provider("Failed to unwrap results".to_string()))?
161            .into_inner();
162        
163        Ok(results)
164    }
165
166    /// Run generation with a callback for each result (for streaming output)
167    pub async fn run_with_callback<F>(&self, config: &SynthConfig, on_result: F) -> Result<()>
168    where
169        F: FnMut(GenerationResult) + Send,
170    {
171        let tasks = self.build_tasks(config).await?;
172        let callback = Arc::new(Mutex::new(on_result));
173        
174        stream::iter(tasks)
175            .map(|task| {
176                let provider = &self.provider;
177                let stats = Arc::clone(&self.stats);
178                let callback = Arc::clone(&callback);
179                async move {
180                    match self.execute_task(provider.as_ref(), task).await {
181                        Ok(result) => {
182                            stats.record_success(&GenerationResponse {
183                                content: result.content.clone(),
184                                input_tokens: result.input_tokens,
185                                output_tokens: result.output_tokens,
186                            });
187                            callback.lock().await(result);
188                        }
189                        Err(e) => {
190                            stats.record_failure();
191                            tracing::warn!("Generation failed: {}", e);
192                        }
193                    }
194                }
195            })
196            .buffer_unordered(self.config.concurrency)
197            .collect::<Vec<_>>()
198            .await;
199        
200        Ok(())
201    }
202
203    async fn build_tasks(&self, config: &SynthConfig) -> Result<Vec<GenerationTask_>> {
204        let prompt_builder = self.create_prompt_builder();
205        
206        match &config.generation.task {
207            GenerationTask::Generate => self.build_generate_tasks(&prompt_builder),
208            GenerationTask::Augment => self.build_augment_tasks(config, &prompt_builder).await,
209        }
210    }
211
212    fn build_generate_tasks(&self, prompt_builder: &PromptBuilder) -> Result<Vec<GenerationTask_>> {
213        let categories = self.config.categories.as_ref();
214        let count = self.config.count;
215        let system_prompt = Some(prompt_builder.system_prompt().to_string());
216
217        let mut tasks = Vec::with_capacity(count);
218
219        if let Some(cats) = categories {
220            let per_category = count / cats.len();
221            let remainder = count % cats.len();
222
223            for (cat_idx, category) in cats.iter().enumerate() {
224                let cat_count = per_category + if cat_idx < remainder { 1 } else { 0 };
225                for i in 0..cat_count {
226                    tasks.push(GenerationTask_ {
227                        prompt: prompt_builder.build_for_category(category, i),
228                        system_prompt: system_prompt.clone(),
229                        source_index: None,
230                        category: Some(category.clone()),
231                    });
232                }
233            }
234        } else {
235            for i in 0..count {
236                tasks.push(GenerationTask_ {
237                    prompt: prompt_builder.build_for_category("default", i),
238                    system_prompt: system_prompt.clone(),
239                    source_index: None,
240                    category: None,
241                });
242            }
243        }
244
245        Ok(tasks)
246    }
247
248    async fn build_augment_tasks(&self, config: &SynthConfig, prompt_builder: &PromptBuilder) -> Result<Vec<GenerationTask_>> {
249        let source_config = config.source.as_ref()
250            .ok_or_else(|| Error::Config("Augment task requires a source configuration".to_string()))?;
251
252        let records = self.load_source_data(source_config.clone()).await?;
253        let count_per = self.config.count_per_example.unwrap_or(1);
254        let system_prompt = Some(prompt_builder.system_prompt().to_string());
255
256        let mut tasks = Vec::with_capacity(records.len() * count_per);
257
258        for record in &records {
259            for _ in 0..count_per {
260                tasks.push(GenerationTask_ {
261                    prompt: prompt_builder.build_for_record(record),
262                    system_prompt: system_prompt.clone(),
263                    source_index: Some(record.index),
264                    category: None,
265                });
266            }
267        }
268
269        Ok(tasks)
270    }
271
272    async fn load_source_data(&self, source_config: SourceConfig) -> Result<Vec<Record>> {
273        // Run blocking IO operations in a separate thread pool
274        tokio::task::spawn_blocking(move || {
275            match source_config {
276                SourceConfig::HuggingFace { dataset, subset, split, sample, columns } => {
277                    let mut source = HuggingFaceSource::new(
278                        dataset,
279                        subset,
280                        split,
281                        columns,
282                    )?;
283                    source.load(sample)
284                }
285                SourceConfig::Local { path, format, sample } => {
286                    let mut source = LocalSource::new(path, format)?;
287                    source.load(sample)
288                }
289            }
290        })
291        .await
292        .map_err(|e| Error::Dataset(format!("Task join error: {}", e)))?
293    }
294
295    fn create_prompt_builder(&self) -> PromptBuilder {
296        let is_augment = matches!(&self.config.task, GenerationTask::Augment);
297        
298        let template = self.config.template.clone().unwrap_or_else(|| {
299            match &self.config.task {
300                GenerationTask::Generate => default_template_for_generate(),
301                GenerationTask::Augment => {
302                    let strategy = self.config.strategy.as_ref()
303                        .map(|s| match s {
304                            GenerationStrategy::Paraphrase => "paraphrase",
305                            GenerationStrategy::StyleTransfer => "style_transfer",
306                            GenerationStrategy::BackTranslation => "back_translation",
307                            GenerationStrategy::Custom => "custom",
308                        })
309                        .unwrap_or("paraphrase");
310                    default_template_for_augment(strategy)
311                }
312            }
313        });
314
315        PromptBuilder::new(template, self.config.system_prompt.clone(), is_augment)
316    }
317
318    async fn execute_task(&self, provider: &dyn LLMProvider, task: GenerationTask_) -> Result<GenerationResult> {
319        let request = GenerationRequest {
320            prompt: task.prompt,
321            system_prompt: task.system_prompt,
322            temperature: None,
323            max_tokens: None,
324        };
325
326        let response = provider.generate(request).await?;
327
328        Ok(GenerationResult {
329            content: response.content,
330            source_index: task.source_index,
331            category: task.category,
332            input_tokens: response.input_tokens,
333            output_tokens: response.output_tokens,
334        })
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::*;

    /// Minimal config: OpenAI provider, Generate task, 10 items split across
    /// two categories.
    fn test_config() -> SynthConfig {
        let provider = ProviderConfig::OpenAI {
            model: "gpt-4o-mini".to_string(),
            api_key: Some("test-key".to_string()),
            base_url: None,
            temperature: None,
            max_tokens: None,
        };
        let generation = GenerationConfig {
            task: GenerationTask::Generate,
            count: 10,
            count_per_example: None,
            concurrency: 2,
            strategy: None,
            strategy_config: Default::default(),
            template: Some("Generate a {category} example".to_string()),
            system_prompt: None,
            categories: Some(vec!["A".to_string(), "B".to_string()]),
        };
        let output = OutputConfig {
            format: OutputFormat::Jsonl,
            path: "./output.jsonl".into(),
            batch_size: 100,
        };
        SynthConfig {
            name: "test".to_string(),
            source: None,
            provider,
            generation,
            output,
            validation: None,
            hub: None,
        }
    }

    #[test]
    fn test_build_generate_tasks() {
        let config = test_config();
        let engine = GenerationEngine::new(&config.provider, config.generation.clone()).unwrap();
        let builder = engine.create_prompt_builder();

        let tasks = engine.build_generate_tasks(&builder).unwrap();
        assert_eq!(tasks.len(), 10);

        // The 10 tasks split evenly between the two configured categories.
        let count_for =
            |cat: &str| tasks.iter().filter(|t| t.category.as_deref() == Some(cat)).count();
        assert_eq!(count_for("A"), 5);
        assert_eq!(count_for("B"), 5);
    }

    #[test]
    fn test_stats_tracking() {
        let stats = GenerationStats::default();

        // Two successes with known token counts, plus one failure.
        for (input, output) in [(100, 50), (200, 100)] {
            stats.record_success(&GenerationResponse {
                content: "test".to_string(),
                input_tokens: input,
                output_tokens: output,
            });
        }
        stats.record_failure();

        let snap = stats.snapshot();
        assert_eq!(snap.completed, 2);
        assert_eq!(snap.failed, 1);
        assert_eq!(snap.total_input_tokens, 300);
        assert_eq!(snap.total_output_tokens, 150);
    }
}
413}