Skip to main content

synth_claw/generation/
prompt.rs

1use crate::datasets::Record;
2use crate::{Error, Result};
3use regex::Regex;
4use serde_json::Value;
5use std::collections::HashMap;
6
/// Renders user prompts from a `{placeholder}` template and holds the system prompt
/// that goes with them.
#[derive(Debug, Clone)]
pub struct PromptBuilder {
    /// User-prompt template; `{name}` placeholders are filled in by the `build_*` methods.
    template: String,
    /// The caller-supplied system prompt, or the mode default chosen in `new`.
    system_prompt: String,
}
11
12impl PromptBuilder {
13    pub fn new(template: String, system_prompt: Option<String>, is_augment: bool) -> Self {
14        let system_prompt = system_prompt.unwrap_or_else(|| {
15            if is_augment {
16                default_system_prompt_augment().to_string()
17            } else {
18                default_system_prompt_generate().to_string()
19            }
20        });
21        Self {
22            template,
23            system_prompt,
24        }
25    }
26
27    pub fn system_prompt(&self) -> &str {
28        &self.system_prompt
29    }
30
31    /// Build prompt for from-scratch generation with category
32    pub fn build_for_category(&self, category: &str, index: usize) -> String {
33        let mut vars = HashMap::new();
34        vars.insert("category".to_string(), Value::String(category.to_string()));
35        vars.insert("index".to_string(), Value::Number(index.into()));
36        self.substitute(&vars)
37    }
38
39    /// Build prompt for augmenting existing data
40    pub fn build_for_record(&self, record: &Record) -> String {
41        let vars = self.extract_vars(&record.data);
42        self.substitute(&vars)
43    }
44
45    /// Extract variable names from template (e.g., {category}, {text})
46    pub fn required_variables(&self) -> Vec<String> {
47        let re = Regex::new(r"\{(\w+)\}").unwrap();
48        re.captures_iter(&self.template)
49            .map(|cap| cap[1].to_string())
50            .collect()
51    }
52
53    /// Validate that all required variables will be available
54    pub fn validate_for_generate(&self, categories: &Option<Vec<String>>) -> Result<()> {
55        let required = self.required_variables();
56
57        for var in &required {
58            match var.as_str() {
59                "category" => {
60                    if categories.is_none()
61                        || categories.as_ref().map(|c| c.is_empty()).unwrap_or(true)
62                    {
63                        return Err(Error::Config(
64                            "Template uses {category} but no categories provided".to_string(),
65                        ));
66                    }
67                }
68                "index" => {} // always available
69                other => {
70                    return Err(Error::Config(format!(
71                        "Template uses {{{}}} which is not available in generate mode. Available: {{category}}, {{index}}",
72                        other
73                    )));
74                }
75            }
76        }
77        Ok(())
78    }
79
80    /// Validate that required variables exist in source data
81    pub fn validate_for_augment(&self, available_columns: &[String]) -> Result<()> {
82        let required = self.required_variables();
83
84        for var in &required {
85            if var != "index" && !available_columns.contains(var) {
86                return Err(Error::Config(format!(
87                    "Template uses {{{}}} but source data only has columns: {:?}",
88                    var, available_columns
89                )));
90            }
91        }
92        Ok(())
93    }
94
95    fn substitute(&self, vars: &HashMap<String, Value>) -> String {
96        let mut result = self.template.clone();
97        for (key, value) in vars {
98            let placeholder = format!("{{{}}}", key);
99            let replacement = match value {
100                Value::String(s) => s.clone(),
101                Value::Number(n) => n.to_string(),
102                Value::Bool(b) => b.to_string(),
103                Value::Null => "null".to_string(),
104                Value::Array(arr) => serde_json::to_string(arr).unwrap_or_default(),
105                Value::Object(obj) => serde_json::to_string(obj).unwrap_or_default(),
106            };
107            result = result.replace(&placeholder, &replacement);
108        }
109        result
110    }
111
112    fn extract_vars(&self, data: &Value) -> HashMap<String, Value> {
113        let mut vars = HashMap::new();
114        if let Value::Object(map) = data {
115            for (key, value) in map {
116                vars.insert(key.clone(), value.clone());
117            }
118        }
119        vars
120    }
121}
122
/// Default system prompt for from-scratch generation.
///
/// [`PromptBuilder::new`] uses it when no custom system prompt is supplied and
/// `is_augment` is false. Assembled at compile time, one rule per line, with no
/// trailing newline.
pub fn default_system_prompt_generate() -> &'static str {
    concat!(
        "You are a synthetic data generation assistant. ",
        "Your task is to generate realistic, high-quality training data.\n",
        "\n",
        "Rules:\n",
        "- Output ONLY the requested content, nothing else\n",
        "- No explanations, meta-commentary, or surrounding text\n",
        "- No markdown formatting unless explicitly requested\n",
        "- Generate diverse, realistic examples that could plausibly exist in the real world\n",
        "- Vary your outputs - avoid repetitive patterns or templates\n",
        "- Match the tone and style appropriate for the content type\n",
        "- If generating text that would have a label (sentiment, category, etc.), ",
        "make the content clearly match that label",
    )
}
135
/// Default system prompt for augmenting existing records.
///
/// [`PromptBuilder::new`] uses it when no custom system prompt is supplied and
/// `is_augment` is true. Assembled at compile time, one rule per line, with no
/// trailing newline.
pub fn default_system_prompt_augment() -> &'static str {
    concat!(
        "You are a data augmentation assistant. ",
        "Your task is to transform input data while preserving its essential properties.\n",
        "\n",
        "Rules:\n",
        "- Output ONLY the transformed content, nothing else\n",
        "- No explanations, meta-commentary, or surrounding text\n",
        "- No markdown formatting unless explicitly requested\n",
        "- Preserve the original meaning, sentiment, and intent\n",
        "- If the data has a label (positive/negative, category, etc.), ",
        "the augmented version must retain the same label\n",
        "- Make meaningful changes - simple word swaps are not sufficient\n",
        "- The output should be natural and fluent",
    )
}
148
/// Default user-prompt template for generate mode.
///
/// It contains the `{category}` placeholder, so
/// [`PromptBuilder::validate_for_generate`] rejects it unless categories are provided.
pub fn default_template_for_generate() -> String {
    [
        "Generate a realistic example of: {category}",
        "",
        "Requirements:",
        "- Authentic, natural language",
        "- Specific details that make it believable",
        "- 2-5 sentences unless otherwise specified",
        "- Diverse - vary structure and content",
    ]
    .join("\n")
}
159
/// Default user-prompt template for an augmentation `strategy`.
///
/// Recognised strategies are `"paraphrase"`, `"style_transfer"` and `"back_translation"`.
/// Any other value gets a generic "transform" template.
///
/// Every template shares one layout: an instruction, a blank line, the `{text}` input,
/// any extra inputs, a blank line, and a label introducing the answer. Note that
/// `"style_transfer"` also needs a `{style}` column in the source data.
pub fn default_template_for_augment(strategy: &str) -> String {
    let (instruction, extra_inputs, answer_label) = match strategy {
        "paraphrase" => (
            "Paraphrase the following text. Preserve the original meaning and sentiment exactly.",
            "",
            "Paraphrased version:",
        ),
        "style_transfer" => (
            "Rewrite the following text in a different style while preserving the core meaning.",
            "Target style: {style}\n",
            "Rewritten version:",
        ),
        "back_translation" => (
            "Rephrase this text as if it was translated to another language and back. Keep the same meaning but use different word choices and sentence structures.",
            "",
            "Rephrased version:",
        ),
        _ => (
            "Transform the following text while preserving its meaning:",
            "",
            "Transformed version:",
        ),
    };
    // `{{text}}` is a format-escape for a literal `{text}` placeholder.
    format!(
        "{}\n\nInput: {{text}}\n{}\n{}",
        instruction, extra_inputs, answer_label
    )
}
200
/// Unit tests for `PromptBuilder` and the default prompt/template helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_build_for_category() {
        let builder = PromptBuilder::new(
            "Generate a {category} review (item #{index})".to_string(),
            None,
            false,
        );
        let result = builder.build_for_category("electronics", 5);
        assert_eq!(result, "Generate a electronics review (item #5)");
    }

    #[test]
    fn test_build_for_record() {
        let builder =
            PromptBuilder::new("Paraphrase: {text}\nLabel: {label}".to_string(), None, true);
        let record = Record {
            data: serde_json::json!({
                "text": "This movie is great!",
                "label": 1
            }),
            index: 0,
        };
        let result = builder.build_for_record(&record);
        assert_eq!(result, "Paraphrase: This movie is great!\nLabel: 1");
    }

    #[test]
    fn test_build_for_record_renders_json_values() {
        // Strings go in raw; other JSON values are rendered compactly; unknown
        // placeholders are left untouched.
        let builder = PromptBuilder::new(
            "{text}|{flag}|{none}|{tags}|{meta}|{missing}".to_string(),
            None,
            true,
        );
        let record = Record {
            data: serde_json::json!({
                "text": "hi",
                "flag": true,
                "none": null,
                "tags": ["a", "b"],
                "meta": {"k": 1}
            }),
            index: 3,
        };
        assert_eq!(
            builder.build_for_record(&record),
            r#"hi|true|null|["a","b"]|{"k":1}|{missing}"#
        );
    }

    #[test]
    fn test_required_variables() {
        let builder = PromptBuilder::new(
            "Hello {name}, your {item} for {category} is ready".to_string(),
            None,
            false,
        );
        let vars = builder.required_variables();
        assert!(vars.contains(&"name".to_string()));
        assert!(vars.contains(&"item".to_string()));
        assert!(vars.contains(&"category".to_string()));
    }

    #[test]
    fn test_required_variables_ignores_non_word_braces() {
        let builder =
            PromptBuilder::new(r#"Return {"label": "x"} for {text}"#.to_string(), None, true);
        assert_eq!(builder.required_variables(), vec!["text".to_string()]);
    }

    #[test]
    fn test_validate_generate_missing_categories() {
        let builder = PromptBuilder::new("Generate a {category} example".to_string(), None, false);
        let result = builder.validate_for_generate(&None);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_generate_empty_categories() {
        let builder = PromptBuilder::new("Generate a {category} example".to_string(), None, false);
        assert!(builder.validate_for_generate(&Some(Vec::new())).is_err());
    }

    #[test]
    fn test_validate_generate_with_categories() {
        let builder = PromptBuilder::new("Generate a {category} example".to_string(), None, false);
        let result = builder.validate_for_generate(&Some(vec!["test".to_string()]));
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_generate_index_needs_no_categories() {
        let builder = PromptBuilder::new("Example #{index}".to_string(), None, false);
        assert!(builder.validate_for_generate(&None).is_ok());
    }

    #[test]
    fn test_validate_generate_rejects_unknown_variable() {
        let builder = PromptBuilder::new("Write about {topic}".to_string(), None, false);
        let result = builder.validate_for_generate(&Some(vec!["news".to_string()]));
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_augment_missing_column() {
        let builder =
            PromptBuilder::new("Paraphrase: {text} with {missing}".to_string(), None, true);
        let result = builder.validate_for_augment(&["text".to_string()]);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_augment_valid() {
        let builder = PromptBuilder::new("Paraphrase: {text}".to_string(), None, true);
        let result = builder.validate_for_augment(&["text".to_string(), "label".to_string()]);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_augment_allows_index_without_column() {
        let builder = PromptBuilder::new("{index}: {text}".to_string(), None, true);
        assert!(builder.validate_for_augment(&["text".to_string()]).is_ok());
    }

    #[test]
    fn test_default_system_prompts_exist() {
        assert!(!default_system_prompt_generate().is_empty());
        assert!(!default_system_prompt_augment().is_empty());
    }

    #[test]
    fn test_system_prompt_selection() {
        let generate = PromptBuilder::new(String::new(), None, false);
        assert_eq!(generate.system_prompt(), default_system_prompt_generate());

        let augment = PromptBuilder::new(String::new(), None, true);
        assert_eq!(augment.system_prompt(), default_system_prompt_augment());

        let custom = PromptBuilder::new(String::new(), Some("custom".to_string()), true);
        assert_eq!(custom.system_prompt(), "custom");
    }

    #[test]
    fn test_default_templates_validate() {
        let generate = PromptBuilder::new(default_template_for_generate(), None, false);
        assert!(generate
            .validate_for_generate(&Some(vec!["news".to_string()]))
            .is_ok());

        let text_only = ["text".to_string()];
        for strategy in ["paraphrase", "back_translation", "unknown"] {
            let builder = PromptBuilder::new(default_template_for_augment(strategy), None, true);
            assert!(
                builder.validate_for_augment(&text_only).is_ok(),
                "strategy {}",
                strategy
            );
        }

        // style_transfer additionally needs a `style` column.
        let style = PromptBuilder::new(default_template_for_augment("style_transfer"), None, true);
        assert!(style.validate_for_augment(&text_only).is_err());
        assert!(style
            .validate_for_augment(&["text".to_string(), "style".to_string()])
            .is_ok());
    }
}