Skip to main content

sgr_agent/
codegen.rs

//! Code generator: BAML AST → Rust source code with schemars + serde derives.
//!
//! Generates:
//! - Rust structs for each BAML class
//! - Rust enums for string unions
//! - Tool registrations for BAML classes with `task` field (= tool definitions)
//! - Prompt constants for BAML functions
8
9use crate::baml_parser::*;
10
11/// Generate Rust source code from parsed BAML module.
12pub fn generate(module: &BamlModule) -> String {
13    let mut out = String::new();
14
15    // Header
16    out.push_str("//! Auto-generated from .baml files by sgr-agent codegen.\n");
17    out.push_str("//! Do not edit manually — edit the .baml source and re-run.\n\n");
18    out.push_str("#![allow(dead_code, clippy::derivable_impls)]\n\n");
19    out.push_str("use serde::{Deserialize, Serialize};\n");
20    out.push_str("use schemars::JsonSchema;\n\n");
21
22    // Collect all inline string enums that need separate types
23    let mut enum_map: Vec<(String, Vec<String>)> = Vec::new();
24
25    // Generate structs
26    for class in &module.classes {
27        generate_struct(&mut out, class, module, &mut enum_map);
28    }
29
30    // Generate collected enums
31    for (name, variants) in &enum_map {
32        generate_string_enum(&mut out, name, variants);
33    }
34
35    // Generate tool registry
36    generate_tool_registry(&mut out, module);
37
38    // Generate prompt constants
39    generate_prompts(&mut out, module);
40
41    out
42}
43
44fn generate_struct(
45    out: &mut String,
46    class: &BamlClass,
47    module: &BamlModule,
48    enum_map: &mut Vec<(String, Vec<String>)>,
49) {
50    // Doc comment
51    if let Some(desc) = &class.description {
52        out.push_str(&format!("/// {}\n", desc));
53    }
54
55    out.push_str("#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]\n");
56    out.push_str(&format!("pub struct {} {{\n", class.name));
57
58    for field in &class.fields {
59        // Doc comment
60        if let Some(desc) = &field.description {
61            out.push_str(&format!("    /// {}\n", desc));
62        }
63
64        // Fixed value fields become constants (skip in struct, add as default)
65        if let Some(fixed) = &field.fixed_value {
66            out.push_str(&format!("    /// Fixed value: \"{}\"\n", fixed));
67            out.push_str(&format!(
68                "    #[serde(default = \"default_{}__{}\")]\n",
69                snake_case(&class.name),
70                field.name
71            ));
72            out.push_str(&format!("    pub {}: String,\n", field.name));
73            continue;
74        }
75
76        let rust_type = baml_type_to_rust(&field.ty, &class.name, &field.name, module, enum_map);
77
78        // Optional fields get serde skip_serializing_if
79        if matches!(&field.ty, BamlType::Optional(_)) {
80            out.push_str("    #[serde(skip_serializing_if = \"Option::is_none\")]\n");
81        }
82
83        out.push_str(&format!("    pub {}: {},\n", field.name, rust_type));
84    }
85
86    out.push_str("}\n\n");
87
88    // Generate default functions for fixed-value fields
89    for field in &class.fields {
90        if let Some(fixed) = &field.fixed_value {
91            out.push_str(&format!(
92                "fn default_{}__{}() -> String {{ \"{}\".to_string() }}\n",
93                snake_case(&class.name),
94                field.name,
95                fixed
96            ));
97        }
98    }
99    // Extra newline after defaults
100    if class.fields.iter().any(|f| f.fixed_value.is_some()) {
101        out.push('\n');
102    }
103}
104
105fn generate_string_enum(out: &mut String, name: &str, variants: &[String]) {
106    out.push_str("#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]\n");
107    out.push_str(&format!("pub enum {} {{\n", name));
108    for variant in variants {
109        let rust_variant = pascal_case(variant);
110        out.push_str(&format!("    #[serde(rename = \"{}\")]\n", variant));
111        out.push_str(&format!("    {},\n", rust_variant));
112    }
113    out.push_str("}\n\n");
114}
115
116fn generate_tool_registry(out: &mut String, module: &BamlModule) {
117    // Collect classes that have a `task` field with a fixed value — these are tools
118    let tool_classes: Vec<&BamlClass> = module
119        .classes
120        .iter()
121        .filter(|c| {
122            c.fields
123                .iter()
124                .any(|f| f.name == "task" && f.fixed_value.is_some())
125        })
126        .collect();
127
128    if tool_classes.is_empty() {
129        return;
130    }
131
132    out.push_str("// --- Tool Registry ---\n\n");
133    out.push_str("use crate::tool::ToolDef;\n\n");
134    out.push_str("/// All tools extracted from BAML definitions.\n");
135    out.push_str("pub fn all_tools() -> Vec<ToolDef> {\n");
136    out.push_str("    vec![\n");
137
138    for class in &tool_classes {
139        let task_field = class.fields.iter().find(|f| f.name == "task").unwrap();
140        let tool_name = task_field.fixed_value.as_deref().unwrap();
141        let description = task_field.description.as_deref().unwrap_or(&class.name);
142
143        out.push_str(&format!(
144            "        crate::tool::tool::<{}>(\"{}\", \"{}\"),\n",
145            class.name,
146            tool_name,
147            escape_string(description),
148        ));
149    }
150
151    out.push_str("    ]\n");
152    out.push_str("}\n\n");
153
154    // Generate ActionUnion enum for dispatch
155    out.push_str("/// Union of all tool types (for dispatching tool calls).\n");
156    out.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
157    out.push_str("#[serde(tag = \"task\")]\n");
158    out.push_str("pub enum ActionUnion {\n");
159
160    for class in &tool_classes {
161        let task_field = class.fields.iter().find(|f| f.name == "task").unwrap();
162        let tool_name = task_field.fixed_value.as_deref().unwrap();
163        out.push_str(&format!("    #[serde(rename = \"{}\")]\n", tool_name));
164        out.push_str(&format!("    {}({}),\n", class.name, class.name));
165    }
166
167    out.push_str("}\n\n");
168}
169
170fn generate_prompts(out: &mut String, module: &BamlModule) {
171    if module.functions.is_empty() {
172        return;
173    }
174
175    out.push_str("// --- Prompt Constants ---\n\n");
176
177    for func in &module.functions {
178        let const_name = screaming_snake_case(&func.name);
179        // Escape the prompt for a raw string
180        out.push_str(&format!(
181            "pub const {}_PROMPT: &str = r##\"\n{}\"##;\n\n",
182            const_name,
183            func.prompt.trim(),
184        ));
185    }
186}
187
188// --- Type conversion ---
189
190fn baml_type_to_rust(
191    ty: &BamlType,
192    class_name: &str,
193    field_name: &str,
194    module: &BamlModule,
195    enum_map: &mut Vec<(String, Vec<String>)>,
196) -> String {
197    match ty {
198        BamlType::String => "String".to_string(),
199        BamlType::Int => "i64".to_string(),
200        BamlType::Float => "f64".to_string(),
201        BamlType::Bool => "bool".to_string(),
202        BamlType::Image => "String".to_string(), // base64 or URL
203        BamlType::Ref(name) => {
204            if module.find_class(name).is_some() {
205                name.clone()
206            } else {
207                // Might be an enum we haven't seen — treat as String
208                "String".to_string()
209            }
210        }
211        BamlType::Optional(inner) => {
212            let inner_rust = baml_type_to_rust(inner, class_name, field_name, module, enum_map);
213            format!("Option<{}>", inner_rust)
214        }
215        BamlType::Array(inner) => {
216            let inner_rust = baml_type_to_rust(inner, class_name, field_name, module, enum_map);
217            format!("Vec<{}>", inner_rust)
218        }
219        BamlType::StringEnum(variants) => {
220            // Create a named enum type
221            let enum_name = format!("{}{}", class_name, pascal_case(field_name));
222            if !enum_map.iter().any(|(n, _)| n == &enum_name) {
223                enum_map.push((enum_name.clone(), variants.clone()));
224            }
225            enum_name
226        }
227        BamlType::Union(variants) => {
228            // For now, use serde_json::Value for complex unions
229            // Could generate a proper enum with #[serde(untagged)]
230            if variants.len() <= 4 {
231                // Small union → generate enum
232                let enum_name = format!("{}{}", class_name, pascal_case(field_name));
233                if !enum_map.iter().any(|(n, _)| n == &enum_name) {
234                    // This is a class union, not string enum — skip for now
235                    // Would need #[serde(untagged)] enum
236                }
237            }
238            "serde_json::Value".to_string()
239        }
240    }
241}
242
243// --- String helpers ---
244
/// Convert a `PascalCase`/`camelCase` identifier to `snake_case`.
///
/// An underscore is inserted before an uppercase letter when it starts a new
/// word: after a lowercase letter or digit, or at the end of an acronym run
/// (`HTTPServer` → `http_server`). No underscore is added at the start of the
/// string or right after an existing `_`, so separators are never doubled.
/// Lowercasing is Unicode-aware, matching the Unicode-aware uppercase test.
fn snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + 4);

    for (i, &c) in chars.iter().enumerate() {
        if !c.is_uppercase() {
            result.push(c);
            continue;
        }

        let prev = if i == 0 { None } else { Some(chars[i - 1]) };
        let next = chars.get(i + 1).copied();

        let boundary = match prev {
            // Start of input, or a separator is already present.
            None | Some('_') => false,
            // Inside an acronym: break only where the next word begins.
            Some(p) if p.is_uppercase() => matches!(next, Some(n) if n.is_lowercase()),
            // Lowercase letter or digit followed by uppercase.
            Some(_) => true,
        };
        if boundary {
            result.push('_');
        }
        result.extend(c.to_lowercase());
    }

    result
}
259
/// Convert a string to `PascalCase`.
///
/// The input is split on every non-alphanumeric character (underscores,
/// hyphens, spaces, …); the first character of each non-empty piece is
/// uppercased and the rest is kept as written. Splitting on more than `_`
/// matters for string-union variants such as `"in-progress"`, which would
/// otherwise produce an invalid Rust identifier.
fn pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for word in s.split(|c: char| !c.is_alphanumeric()) {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            result.extend(first.to_uppercase());
            result.push_str(chars.as_str());
        }
    }
    result
}
271
272fn screaming_snake_case(s: &str) -> String {
273    snake_case(s).to_uppercase()
274}
275
/// Escape `s` so it can be placed between double quotes in generated Rust
/// source.
///
/// Backslashes and double quotes are escaped, and newline, carriage return and
/// tab are written as `\n`, `\r`, `\t`. A bare carriage return is not accepted
/// inside a Rust string literal, and escaping newlines keeps each generated
/// literal on a single line.
fn escape_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            other => escaped.push(other),
        }
    }
    escaped
}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` into a fresh module and run the generator on it.
    fn codegen(source: &str) -> String {
        let mut module = BamlModule::default();
        module.parse_source(source);
        generate(&module)
    }

    /// Assert that every fragment in `needles` occurs in `code`.
    fn assert_contains_all(code: &str, needles: &[&str]) {
        for needle in needles {
            assert!(code.contains(needle));
        }
    }

    /// A class with an inline string union yields a struct plus a named enum.
    #[test]
    fn generates_from_simple_baml() {
        let code = codegen(
            r#"
class CutDecision {
  action "trim" | "keep" | "highlight" @description("Editing action")
  reason string @description("Short reasoning")
}
"#,
        );

        assert_contains_all(
            &code,
            &[
                "pub struct CutDecision",
                "pub action: CutDecisionAction",
                "pub reason: String",
                "pub enum CutDecisionAction",
                "#[serde(rename = \"trim\")]",
            ],
        );
    }

    /// A class with a fixed-value `task` field is registered as a tool.
    #[test]
    fn generates_tools_from_baml() {
        let code = codegen(
            r#"
class FfmpegTask {
  task "ffmpeg_operation" @description("FFmpeg operations") @stream.not_null
  operation "convert" | "trim"
  input_path string | null
}
"#,
        );

        assert_contains_all(
            &code,
            &[
                "pub fn all_tools()",
                "\"ffmpeg_operation\"",
                "pub enum ActionUnion",
                "FfmpegTask(FfmpegTask)",
            ],
        );
    }

    /// Runs the generator on the real `montage.baml` when it is available in
    /// the surrounding repository; otherwise the test is skipped.
    #[test]
    fn generates_from_real_montage_baml() {
        let mut path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        // Walk up three levels: sgr-agent, crates, rust-code.
        for _ in 0..3 {
            path.pop();
        }
        path.extend([
            "startups",
            "active",
            "video-analyzer",
            "crates",
            "va-agent",
            "baml_src",
            "montage.baml",
        ]);

        if !path.exists() {
            eprintln!("Skipping: montage.baml not found at {}", path.display());
            return;
        }

        let source = std::fs::read_to_string(&path).unwrap();
        let code = codegen(&source);

        // Major structs, tool registry entries, and prompt constants.
        assert_contains_all(
            &code,
            &[
                "pub struct MontageAgentNextStep",
                "pub struct AnalysisTask",
                "pub struct FfmpegTask",
                "pub struct ProjectTask",
                "pub fn all_tools()",
                "\"analysis_operation\"",
                "\"ffmpeg_operation\"",
                "\"project_operation\"",
                "DECIDE_MONTAGE_NEXT_STEP_SGR_PROMPT",
                "ANALYZE_SEGMENT_SGR_PROMPT",
            ],
        );
    }
}