//! synth_claw — generation engine (`synth_claw/generation/engine.rs`).
1use crate::config::{GenerationConfig, GenerationStrategy, GenerationTask, ProviderConfig, SourceConfig, SynthConfig};
2use crate::datasets::{DataSource, HuggingFaceSource, LocalSource, Record};
3use crate::providers::{create_provider, GenerationRequest, GenerationResponse, LLMProvider};
4use crate::{Error, Result};
5
6use super::prompt::{default_template_for_augment, default_template_for_generate, PromptBuilder};
7
8use futures::stream::{self, StreamExt};
9use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
10use std::sync::Arc;
11use tokio::sync::Mutex;
12
/// Thread-safe progress counters for a generation run.
///
/// Shared across concurrent tasks via `Arc`; all fields use relaxed
/// atomics since they are independent counters, not a consistent snapshot.
#[derive(Debug, Default)]
pub struct GenerationStats {
    /// Tasks that completed successfully.
    pub completed: AtomicUsize,
    /// Tasks that failed.
    pub failed: AtomicUsize,
    /// Sum of prompt (input) tokens across successful responses.
    pub total_input_tokens: AtomicU64,
    /// Sum of completion (output) tokens across successful responses.
    pub total_output_tokens: AtomicU64,
}
20
21impl GenerationStats {
22    pub fn record_success(&self, response: &GenerationResponse) {
23        self.completed.fetch_add(1, Ordering::Relaxed);
24        self.total_input_tokens.fetch_add(response.input_tokens as u64, Ordering::Relaxed);
25        self.total_output_tokens.fetch_add(response.output_tokens as u64, Ordering::Relaxed);
26    }
27
28    pub fn record_failure(&self) {
29        self.failed.fetch_add(1, Ordering::Relaxed);
30    }
31
32    pub fn snapshot(&self) -> StatsSnapshot {
33        StatsSnapshot {
34            completed: self.completed.load(Ordering::Relaxed),
35            failed: self.failed.load(Ordering::Relaxed),
36            total_input_tokens: self.total_input_tokens.load(Ordering::Relaxed),
37            total_output_tokens: self.total_output_tokens.load(Ordering::Relaxed),
38        }
39    }
40}
41
/// A plain-value copy of [`GenerationStats`] captured at a single point in time.
#[derive(Debug, Clone)]
pub struct StatsSnapshot {
    /// Tasks completed successfully at capture time.
    pub completed: usize,
    /// Tasks failed at capture time.
    pub failed: usize,
    /// Total prompt tokens across successful responses.
    pub total_input_tokens: u64,
    /// Total completion tokens across successful responses.
    pub total_output_tokens: u64,
}
49
/// One successful generation: the model output plus bookkeeping about
/// where it came from and what it cost in tokens.
#[derive(Debug, Clone)]
pub struct GenerationResult {
    /// Generated text returned by the provider.
    pub content: String,
    /// Index of the source record this result augments (augment task only).
    pub source_index: Option<usize>,
    /// Category the prompt targeted (generate task with categories only).
    pub category: Option<String>,
    /// Prompt tokens consumed for this request.
    pub input_tokens: u32,
    /// Completion tokens produced for this request.
    pub output_tokens: u32,
}
58
/// Internal unit of work: a fully built prompt awaiting execution.
// NOTE(review): the trailing underscore presumably avoids a clash with
// `config::GenerationTask` imported above — consider a clearer name
// (e.g. `PreparedTask`) in a follow-up.
struct GenerationTask_ {
    /// User prompt to send to the provider.
    prompt: String,
    /// Optional system prompt shared by all tasks in a run.
    system_prompt: Option<String>,
    /// Source record index (augment tasks only).
    source_index: Option<usize>,
    /// Category label (categorized generate tasks only).
    category: Option<String>,
}
65
/// Drives concurrent synthetic-data generation against an LLM provider.
pub struct GenerationEngine {
    /// Provider used for every generation request.
    provider: Box<dyn LLMProvider>,
    /// Generation settings (task kind, count, concurrency, templates, ...).
    config: GenerationConfig,
    /// Shared progress counters, handed out to callers via `stats()`.
    stats: Arc<GenerationStats>,
}
71
72impl GenerationEngine {
73    pub fn new(provider_config: &ProviderConfig, generation_config: GenerationConfig) -> Result<Self> {
74        let provider = create_provider(provider_config)?;
75        Ok(Self {
76            provider,
77            config: generation_config,
78            stats: Arc::new(GenerationStats::default()),
79        })
80    }
81
82    pub fn stats(&self) -> Arc<GenerationStats> {
83        Arc::clone(&self.stats)
84    }
85
86    pub fn provider(&self) -> &dyn LLMProvider {
87        self.provider.as_ref()
88    }
89
90    /// Run generation and collect all results
91    pub async fn run(&self, config: &SynthConfig) -> Result<Vec<GenerationResult>> {
92        let tasks = self.build_tasks(config).await?;
93        let results = Arc::new(Mutex::new(Vec::with_capacity(tasks.len())));
94        
95        stream::iter(tasks)
96            .map(|task| {
97                let provider = &self.provider;
98                let stats = Arc::clone(&self.stats);
99                let results = Arc::clone(&results);
100                async move {
101                    match self.execute_task(provider.as_ref(), task).await {
102                        Ok(result) => {
103                            stats.record_success(&GenerationResponse {
104                                content: result.content.clone(),
105                                input_tokens: result.input_tokens,
106                                output_tokens: result.output_tokens,
107                            });
108                            results.lock().await.push(result);
109                        }
110                        Err(e) => {
111                            stats.record_failure();
112                            tracing::warn!("Generation failed: {}", e);
113                        }
114                    }
115                }
116            })
117            .buffer_unordered(self.config.concurrency)
118            .collect::<Vec<_>>()
119            .await;
120
121        let results = Arc::try_unwrap(results)
122            .map_err(|_| Error::Provider("Failed to unwrap results".to_string()))?
123            .into_inner();
124        
125        Ok(results)
126    }
127
128    /// Run generation with a callback for each result (for streaming output)
129    pub async fn run_with_callback<F>(&self, config: &SynthConfig, on_result: F) -> Result<()>
130    where
131        F: FnMut(GenerationResult) + Send,
132    {
133        let tasks = self.build_tasks(config).await?;
134        let callback = Arc::new(Mutex::new(on_result));
135        
136        stream::iter(tasks)
137            .map(|task| {
138                let provider = &self.provider;
139                let stats = Arc::clone(&self.stats);
140                let callback = Arc::clone(&callback);
141                async move {
142                    match self.execute_task(provider.as_ref(), task).await {
143                        Ok(result) => {
144                            stats.record_success(&GenerationResponse {
145                                content: result.content.clone(),
146                                input_tokens: result.input_tokens,
147                                output_tokens: result.output_tokens,
148                            });
149                            callback.lock().await(result);
150                        }
151                        Err(e) => {
152                            stats.record_failure();
153                            tracing::warn!("Generation failed: {}", e);
154                        }
155                    }
156                }
157            })
158            .buffer_unordered(self.config.concurrency)
159            .collect::<Vec<_>>()
160            .await;
161        
162        Ok(())
163    }
164
165    async fn build_tasks(&self, config: &SynthConfig) -> Result<Vec<GenerationTask_>> {
166        let prompt_builder = self.create_prompt_builder();
167        
168        match &config.generation.task {
169            GenerationTask::Generate => self.build_generate_tasks(&prompt_builder),
170            GenerationTask::Augment => self.build_augment_tasks(config, &prompt_builder).await,
171        }
172    }
173
174    fn build_generate_tasks(&self, prompt_builder: &PromptBuilder) -> Result<Vec<GenerationTask_>> {
175        let categories = self.config.categories.as_ref();
176        let count = self.config.count;
177        let system_prompt = Some(prompt_builder.system_prompt().to_string());
178
179        let mut tasks = Vec::with_capacity(count);
180
181        if let Some(cats) = categories {
182            let per_category = count / cats.len();
183            let remainder = count % cats.len();
184
185            for (cat_idx, category) in cats.iter().enumerate() {
186                let cat_count = per_category + if cat_idx < remainder { 1 } else { 0 };
187                for i in 0..cat_count {
188                    tasks.push(GenerationTask_ {
189                        prompt: prompt_builder.build_for_category(category, i),
190                        system_prompt: system_prompt.clone(),
191                        source_index: None,
192                        category: Some(category.clone()),
193                    });
194                }
195            }
196        } else {
197            for i in 0..count {
198                tasks.push(GenerationTask_ {
199                    prompt: prompt_builder.build_for_category("default", i),
200                    system_prompt: system_prompt.clone(),
201                    source_index: None,
202                    category: None,
203                });
204            }
205        }
206
207        Ok(tasks)
208    }
209
210    async fn build_augment_tasks(&self, config: &SynthConfig, prompt_builder: &PromptBuilder) -> Result<Vec<GenerationTask_>> {
211        let source_config = config.source.as_ref()
212            .ok_or_else(|| Error::Config("Augment task requires a source configuration".to_string()))?;
213
214        let records = self.load_source_data(source_config.clone()).await?;
215        let count_per = self.config.count_per_example.unwrap_or(1);
216        let system_prompt = Some(prompt_builder.system_prompt().to_string());
217
218        let mut tasks = Vec::with_capacity(records.len() * count_per);
219
220        for record in &records {
221            for _ in 0..count_per {
222                tasks.push(GenerationTask_ {
223                    prompt: prompt_builder.build_for_record(record),
224                    system_prompt: system_prompt.clone(),
225                    source_index: Some(record.index),
226                    category: None,
227                });
228            }
229        }
230
231        Ok(tasks)
232    }
233
234    async fn load_source_data(&self, source_config: SourceConfig) -> Result<Vec<Record>> {
235        // Run blocking IO operations in a separate thread pool
236        tokio::task::spawn_blocking(move || {
237            match source_config {
238                SourceConfig::HuggingFace { dataset, subset, split, sample, columns } => {
239                    let mut source = HuggingFaceSource::new(
240                        dataset,
241                        subset,
242                        split,
243                        columns,
244                    )?;
245                    source.load(sample)
246                }
247                SourceConfig::Local { path, format, sample } => {
248                    let mut source = LocalSource::new(path, format)?;
249                    source.load(sample)
250                }
251            }
252        })
253        .await
254        .map_err(|e| Error::Dataset(format!("Task join error: {}", e)))?
255    }
256
257    fn create_prompt_builder(&self) -> PromptBuilder {
258        let is_augment = matches!(&self.config.task, GenerationTask::Augment);
259        
260        let template = self.config.template.clone().unwrap_or_else(|| {
261            match &self.config.task {
262                GenerationTask::Generate => default_template_for_generate(),
263                GenerationTask::Augment => {
264                    let strategy = self.config.strategy.as_ref()
265                        .map(|s| match s {
266                            GenerationStrategy::Paraphrase => "paraphrase",
267                            GenerationStrategy::StyleTransfer => "style_transfer",
268                            GenerationStrategy::BackTranslation => "back_translation",
269                            GenerationStrategy::Custom => "custom",
270                        })
271                        .unwrap_or("paraphrase");
272                    default_template_for_augment(strategy)
273                }
274            }
275        });
276
277        PromptBuilder::new(template, self.config.system_prompt.clone(), is_augment)
278    }
279
280    async fn execute_task(&self, provider: &dyn LLMProvider, task: GenerationTask_) -> Result<GenerationResult> {
281        let request = GenerationRequest {
282            prompt: task.prompt,
283            system_prompt: task.system_prompt,
284            temperature: None,
285            max_tokens: None,
286        };
287
288        let response = provider.generate(request).await?;
289
290        Ok(GenerationResult {
291            content: response.content,
292            source_index: task.source_index,
293            category: task.category,
294            input_tokens: response.input_tokens,
295            output_tokens: response.output_tokens,
296        })
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::*;

    /// Minimal config: generate 10 items split across categories "A" and
    /// "B", concurrency 2, with a simple one-placeholder template.
    fn test_config() -> SynthConfig {
        SynthConfig {
            name: "test".to_string(),
            source: None,
            provider: ProviderConfig::OpenAI {
                model: "gpt-4o-mini".to_string(),
                api_key: Some("test-key".to_string()),
                base_url: None,
                temperature: None,
                max_tokens: None,
            },
            generation: GenerationConfig {
                task: GenerationTask::Generate,
                count: 10,
                count_per_example: None,
                concurrency: 2,
                strategy: None,
                strategy_config: Default::default(),
                template: Some("Generate a {category} example".to_string()),
                system_prompt: None,
                categories: Some(vec!["A".to_string(), "B".to_string()]),
            },
            output: OutputConfig {
                format: OutputFormat::Jsonl,
                path: "./output.jsonl".into(),
                batch_size: 100,
            },
        }
    }

    #[test]
    fn test_build_generate_tasks() {
        let config = test_config();
        let engine = GenerationEngine::new(&config.provider, config.generation.clone()).unwrap();
        let builder = engine.create_prompt_builder();

        let tasks = engine.build_generate_tasks(&builder).unwrap();
        assert_eq!(tasks.len(), 10);

        // The 10 tasks split evenly: 5 for category A, 5 for category B.
        let count_for = |cat: &str| {
            tasks
                .iter()
                .filter(|t| t.category.as_deref() == Some(cat))
                .count()
        };
        assert_eq!(count_for("A"), 5);
        assert_eq!(count_for("B"), 5);
    }

    #[test]
    fn test_stats_tracking() {
        let stats = GenerationStats::default();

        // Two successes with known token counts, then one failure.
        for (input_tokens, output_tokens) in [(100, 50), (200, 100)] {
            stats.record_success(&GenerationResponse {
                content: "test".to_string(),
                input_tokens,
                output_tokens,
            });
        }
        stats.record_failure();

        let snapshot = stats.snapshot();
        assert_eq!(snapshot.completed, 2);
        assert_eq!(snapshot.failed, 1);
        assert_eq!(snapshot.total_input_tokens, 300);
        assert_eq!(snapshot.total_output_tokens, 150);
    }
}