Skip to main content

sgr_agent/
baml_parser.rs

//! Lightweight BAML parser — extracts classes (including inline string enums) and
//! functions from `.baml` files.
//!
//! NOT a full BAML compiler. Parses just enough to generate:
//! - Rust structs with `#[derive(JsonSchema, Serialize, Deserialize)]`
//! - Tool definitions with prompts
//! - Enum types with string variants
//!
//! Source of truth stays in `.baml` files.
9
10use std::path::Path;
11
12/// A parsed BAML class (→ Rust struct).
13#[derive(Debug, Clone)]
14pub struct BamlClass {
15    pub name: String,
16    pub fields: Vec<BamlField>,
17    pub description: Option<String>,
18}
19
20/// A field within a BAML class.
21#[derive(Debug, Clone)]
22pub struct BamlField {
23    pub name: String,
24    pub ty: BamlType,
25    pub description: Option<String>,
26    /// Fixed string value (e.g. `task "analysis_operation"`).
27    pub fixed_value: Option<String>,
28}
29
/// BAML type representation.
///
/// Implements `PartialEq`/`Eq` so parsed types can be compared directly
/// (e.g. with `assert_eq!`) instead of being destructured by hand.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BamlType {
    /// The `string` primitive.
    String,
    /// The `int` primitive.
    Int,
    /// The `float` primitive.
    Float,
    /// The `bool` primitive.
    Bool,
    /// A string enum: `"trim" | "keep" | "highlight"`
    StringEnum(Vec<String>),
    /// Reference to another class.
    Ref(String),
    /// Optional type (T | null).
    Optional(Box<BamlType>),
    /// Array type.
    Array(Box<BamlType>),
    /// Union of class references (for next_actions).
    Union(Vec<String>),
    /// Image type (special BAML type).
    Image,
}
50
51/// A parsed BAML function (→ tool definition + prompt).
52#[derive(Debug, Clone)]
53pub struct BamlFunction {
54    pub name: String,
55    pub params: Vec<(String, BamlType)>,
56    pub return_type: String,
57    pub client: String,
58    pub prompt: String,
59}
60
61/// All parsed items from BAML source files.
62#[derive(Debug, Clone, Default)]
63pub struct BamlModule {
64    pub classes: Vec<BamlClass>,
65    pub functions: Vec<BamlFunction>,
66}
67
68impl BamlModule {
69    /// Parse all `.baml` files in a directory.
70    pub fn parse_dir(dir: &Path) -> Result<Self, String> {
71        let mut module = BamlModule::default();
72
73        let entries =
74            std::fs::read_dir(dir).map_err(|e| format!("Cannot read {}: {}", dir.display(), e))?;
75
76        for entry in entries.flatten() {
77            let path = entry.path();
78            if path.extension().is_some_and(|ext| ext == "baml") {
79                let source = std::fs::read_to_string(&path)
80                    .map_err(|e| format!("Cannot read {}: {}", path.display(), e))?;
81                module.parse_source(&source);
82            }
83        }
84
85        Ok(module)
86    }
87
88    /// Parse a single BAML source string.
89    pub fn parse_source(&mut self, source: &str) {
90        let lines: Vec<&str> = source.lines().collect();
91        let mut i = 0;
92
93        while i < lines.len() {
94            let line = lines[i].trim();
95
96            // Skip comments and empty lines
97            if line.is_empty() || line.starts_with("//") {
98                i += 1;
99                continue;
100            }
101
102            // Class definition
103            if line.starts_with("class ")
104                && let Some((class, consumed)) = parse_class(&lines[i..])
105            {
106                self.classes.push(class);
107                i += consumed;
108                continue;
109            }
110
111            // Function definition
112            if line.starts_with("function ")
113                && let Some((func, consumed)) = parse_function(&lines[i..])
114            {
115                self.functions.push(func);
116                i += consumed;
117                continue;
118            }
119
120            i += 1;
121        }
122    }
123
124    /// Find a class by name.
125    pub fn find_class(&self, name: &str) -> Option<&BamlClass> {
126        self.classes.iter().find(|c| c.name == name)
127    }
128
129    /// Find a function by name.
130    pub fn find_function(&self, name: &str) -> Option<&BamlFunction> {
131        self.functions.iter().find(|f| f.name == name)
132    }
133}
134
// --- Parsers ---
136
137fn parse_class(lines: &[&str]) -> Option<(BamlClass, usize)> {
138    let header = lines[0].trim();
139    let name = header
140        .strip_prefix("class ")?
141        .trim()
142        .trim_end_matches('{')
143        .trim()
144        .to_string();
145
146    let mut fields = Vec::new();
147    let mut i = 1;
148
149    while i < lines.len() {
150        let line = lines[i].trim();
151        i += 1;
152
153        if line == "}" {
154            break;
155        }
156        if line.is_empty() || line.starts_with("//") {
157            continue;
158        }
159
160        if let Some(field) = parse_field(line) {
161            fields.push(field);
162        }
163    }
164
165    Some((
166        BamlClass {
167            name,
168            fields,
169            description: None,
170        },
171        i,
172    ))
173}
174
175fn parse_field(line: &str) -> Option<BamlField> {
176    // Examples:
177    //   action "trim" | "keep" | "highlight" @description("...")
178    //   input_path string | null @description("...")
179    //   task "analysis_operation" @description("...") @stream.not_null
180    //   target_seconds int @description("...")
181    //   next_actions (Type1 | Type2)[] @description("...")
182
183    let line = line.trim();
184
185    // Extract description
186    let description = extract_description(line);
187
188    // Remove annotations (@description, @stream, etc.)
189    let clean = remove_annotations(line);
190    let clean = clean.trim();
191
192    // Split into name and type
193    let mut parts = clean.splitn(2, char::is_whitespace);
194    let name = parts.next()?.trim().to_string();
195    let type_str = parts.next()?.trim();
196
197    // Check for fixed value: `task "analysis_operation"`
198    if type_str.starts_with('"') && !type_str.contains('|') {
199        let value = type_str.trim_matches('"').to_string();
200        return Some(BamlField {
201            name,
202            ty: BamlType::String,
203            description,
204            fixed_value: Some(value),
205        });
206    }
207
208    let ty = parse_type(type_str);
209
210    Some(BamlField {
211        name,
212        ty,
213        description,
214        fixed_value: None,
215    })
216}
217
218fn parse_type(s: &str) -> BamlType {
219    let s = s.trim();
220
221    // Array: T[] or (T)[]
222    if s.ends_with("[]") {
223        let inner = s.trim_end_matches("[]").trim();
224        // Union array: (Type1 | Type2)[]
225        if inner.starts_with('(') && inner.ends_with(')') {
226            let inner_types = &inner[1..inner.len() - 1];
227            let variants: Vec<String> = inner_types
228                .split('|')
229                .map(|v| v.trim().to_string())
230                .collect();
231            // Check if all variants are class references (start with uppercase)
232            if variants
233                .iter()
234                .all(|v| v.starts_with(|c: char| c.is_uppercase()))
235            {
236                return BamlType::Array(Box::new(BamlType::Union(variants)));
237            }
238        }
239        let inner_type = parse_type(inner);
240        return BamlType::Array(Box::new(inner_type));
241    }
242
243    // Nullable: T | null
244    if s.contains("| null") || s.contains("null |") {
245        let base = s
246            .replace("| null", "")
247            .replace("null |", "")
248            .trim()
249            .to_string();
250        return BamlType::Optional(Box::new(parse_type(&base)));
251    }
252
253    // String enum: "a" | "b" | "c"
254    if s.contains('"') && s.contains('|') {
255        let variants: Vec<String> = s
256            .split('|')
257            .map(|v| v.trim().trim_matches('"').to_string())
258            .filter(|v| !v.is_empty())
259            .collect();
260        return BamlType::StringEnum(variants);
261    }
262
263    // Union of types (without quotes): Type1 | Type2
264    if s.contains('|') {
265        let variants: Vec<String> = s.split('|').map(|v| v.trim().to_string()).collect();
266        if variants
267            .iter()
268            .all(|v| v.starts_with(|c: char| c.is_uppercase()))
269        {
270            return BamlType::Union(variants);
271        }
272    }
273
274    // Primitives
275    match s {
276        "string" => BamlType::String,
277        "int" => BamlType::Int,
278        "float" => BamlType::Float,
279        "bool" => BamlType::Bool,
280        "image" => BamlType::Image,
281        _ => {
282            // Class reference
283            if s.starts_with(|c: char| c.is_uppercase()) {
284                BamlType::Ref(s.to_string())
285            } else {
286                BamlType::String // fallback
287            }
288        }
289    }
290}
291
/// Return the text of the first `@description("...")` annotation on `line`.
///
/// The description ends at the first *unescaped* `"` that is immediately
/// followed by `)`, so text such as `f(\"x\")` inside the description does
/// not end it early. The content is returned verbatim (escape sequences are
/// not decoded). Returns `None` when the annotation is absent or unterminated.
fn extract_description(line: &str) -> Option<String> {
    const MARKER: &str = "@description(\"";

    let content_start = line.find(MARKER)? + MARKER.len();
    let rest = &line[content_start..];
    let bytes = rest.as_bytes();

    // Scanning bytes is safe: `"` and `\` are ASCII and never occur inside a
    // multi-byte UTF-8 sequence, so `idx` is always a char boundary here.
    let mut escaped = false;
    for (idx, &byte) in bytes.iter().enumerate() {
        if escaped {
            escaped = false;
            continue;
        }
        match byte {
            b'\\' => escaped = true,
            b'"' if bytes.get(idx + 1) == Some(&b')') => return Some(rest[..idx].to_string()),
            _ => {}
        }
    }
    None
}
302
/// Return `line` with every `@annotation` removed.
///
/// An annotation is `@name`, optionally followed directly by a parenthesised
/// argument list: `@stream.not_null`, `@description("...")`,
/// `@alias("first name")`, `@assert(valid_age, {{ this >= 0 }})`. The argument
/// list is removed up to its matching `)`, including nested parentheses and
/// quoted strings, so spaces inside arguments leave nothing behind.
///
/// An `@` inside a double-quoted literal (e.g. a fixed value `"a@b.com"`) is
/// ordinary text and is kept. Surrounding whitespace is left for the caller
/// to trim.
fn remove_annotations(line: &str) -> String {
    // All delimiters are ASCII, so scanning bytes is safe: every index used
    // for slicing is the position of an ASCII byte or the end of the line.
    let bytes = line.as_bytes();
    let mut kept = String::with_capacity(line.len());
    let mut copy_from = 0;
    let mut pos = 0;

    while pos < bytes.len() {
        match bytes[pos] {
            // Quoted literal outside an annotation: keep it verbatim.
            b'"' => pos = skip_quoted(bytes, pos),
            b'@' => {
                kept.push_str(&line[copy_from..pos]);
                pos = skip_annotation(bytes, pos);
                copy_from = pos;
            }
            _ => pos += 1,
        }
    }
    kept.push_str(&line[copy_from..]);
    kept
}

/// Given the index of an opening `"`, return the index just past its closing
/// quote (honouring backslash escapes), or the end of input if unterminated.
fn skip_quoted(bytes: &[u8], open: usize) -> usize {
    let mut pos = open + 1;
    while pos < bytes.len() {
        match bytes[pos] {
            b'\\' => pos += 2,
            b'"' => return pos + 1,
            _ => pos += 1,
        }
    }
    bytes.len()
}

/// Given the index of an `@`, return the index just past the annotation: its
/// name (up to whitespace or `(`) plus a balanced argument list if present.
fn skip_annotation(bytes: &[u8], at: usize) -> usize {
    let mut pos = at + 1;
    while pos < bytes.len() && !bytes[pos].is_ascii_whitespace() && bytes[pos] != b'(' {
        pos += 1;
    }
    if pos >= bytes.len() || bytes[pos] != b'(' {
        return pos;
    }

    let mut depth = 0usize;
    while pos < bytes.len() {
        match bytes[pos] {
            b'"' => {
                pos = skip_quoted(bytes, pos);
                continue;
            }
            b'(' => depth += 1,
            b')' => {
                depth = depth.saturating_sub(1);
                if depth == 0 {
                    return pos + 1;
                }
            }
            _ => {}
        }
        pos += 1;
    }
    // Unbalanced argument list: the annotation runs to the end of the line.
    bytes.len()
}
321
322fn parse_function(lines: &[&str]) -> Option<(BamlFunction, usize)> {
323    let header = lines[0].trim();
324
325    // function Name(param: Type, ...) -> ReturnType {
326    let rest = header.strip_prefix("function ")?;
327
328    // Extract name
329    let paren_start = rest.find('(')?;
330    let name = rest[..paren_start].trim().to_string();
331
332    // Extract params
333    let paren_end = rest.find(')')?;
334    let params_str = &rest[paren_start + 1..paren_end];
335    let params: Vec<(String, BamlType)> = if params_str.trim().is_empty() {
336        vec![]
337    } else {
338        params_str
339            .split(',')
340            .filter_map(|p| {
341                let p = p.trim();
342                let mut parts = p.splitn(2, ':');
343                let pname = parts.next()?.trim().to_string();
344                let ptype = parse_type(parts.next()?.trim());
345                Some((pname, ptype))
346            })
347            .collect()
348    };
349
350    // Extract return type
351    let arrow = rest.find("->")?;
352    let return_rest = rest[arrow + 2..].trim();
353    let return_type = return_rest.trim_end_matches('{').trim().to_string();
354
355    // Extract body (client + prompt)
356    let mut client = String::new();
357    let mut prompt_lines = Vec::new();
358    let mut in_prompt = false;
359    let mut i = 1;
360
361    while i < lines.len() {
362        let line = lines[i].trim();
363        i += 1;
364
365        if line == "}" && !in_prompt {
366            break;
367        }
368
369        if line.starts_with("client ") {
370            client = line
371                .strip_prefix("client ")
372                .unwrap_or("")
373                .trim()
374                .trim_matches('"')
375                .to_string();
376            continue;
377        }
378
379        if line.starts_with("prompt #\"") {
380            in_prompt = true;
381            // Content after prompt #"
382            let after = line.strip_prefix("prompt #\"").unwrap_or("");
383            if !after.is_empty() {
384                prompt_lines.push(after.to_string());
385            }
386            continue;
387        }
388
389        if in_prompt {
390            if line.contains("\"#") {
391                let before = line.trim_end_matches("\"#").trim_end();
392                if !before.is_empty() {
393                    prompt_lines.push(before.to_string());
394                }
395                in_prompt = false;
396                continue;
397            }
398            prompt_lines.push(lines[i - 1].to_string());
399        }
400    }
401
402    Some((
403        BamlFunction {
404            name,
405            params,
406            return_type,
407            client,
408            prompt: prompt_lines.join("\n"),
409        },
410        i,
411    ))
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` into a fresh module.
    fn parse(source: &str) -> BamlModule {
        let mut module = BamlModule::default();
        module.parse_source(source);
        module
    }

    /// A class with a string-enum field and a plain string field.
    #[test]
    fn parses_simple_class() {
        let source = r#"
class CutDecision {
  action "trim" | "keep" | "highlight" @description("Editing action")
  reason string @description("Short reasoning")
}
"#;
        let module = parse(source);

        assert_eq!(module.classes.len(), 1);
        let cls = &module.classes[0];
        assert_eq!(cls.name, "CutDecision");
        assert_eq!(cls.fields.len(), 2);

        let action = &cls.fields[0];
        assert_eq!(action.name, "action");
        match &action.ty {
            BamlType::StringEnum(variants) => {
                assert_eq!(variants, &["trim", "keep", "highlight"]);
            }
            other => panic!("Expected StringEnum, got {:?}", other),
        }
        assert_eq!(action.description.as_deref(), Some("Editing action"));
    }

    /// Fixed values, optional primitives and optional arrays.
    #[test]
    fn parses_class_with_optional_and_array() {
        let source = r#"
class FfmpegTask {
  task "ffmpeg_operation" @description("FFmpeg ops") @stream.not_null
  operation "convert" | "trim" | "concat"
  input_path string | null
  custom_args string[] | null
  overwrite bool | null
}
"#;
        let module = parse(source);

        let cls = &module.classes[0];
        assert_eq!(cls.name, "FfmpegTask");

        // task has fixed value
        assert_eq!(
            cls.fields[0].fixed_value.as_deref(),
            Some("ffmpeg_operation")
        );

        // input_path is Optional<String>
        assert!(matches!(cls.fields[2].ty, BamlType::Optional(_)));

        // custom_args is Optional<Array<String>>
        match &cls.fields[3].ty {
            BamlType::Optional(inner) => {
                assert!(matches!(inner.as_ref(), BamlType::Array(_)));
            }
            other => panic!("Expected Optional(Array), got {:?}", other),
        }
    }

    /// `(A | B | C)[]` becomes an array of a class-reference union.
    #[test]
    fn parses_union_array() {
        let source = r#"
class MontageAgentNextStep {
  intent "display" | "montage"
  next_actions (AnalysisTask | FfmpegTask | ProjectTask)[] @description("Tools to execute")
}
"#;
        let module = parse(source);

        let cls = &module.classes[0];
        let actions_field = &cls.fields[1];
        match &actions_field.ty {
            BamlType::Array(inner) => match inner.as_ref() {
                BamlType::Union(variants) => {
                    assert_eq!(variants, &["AnalysisTask", "FfmpegTask", "ProjectTask"]);
                }
                other => panic!("Expected Union inside Array, got {:?}", other),
            },
            other => panic!("Expected Array, got {:?}", other),
        }
    }

    /// Header, client and multi-line prompt of a function.
    #[test]
    fn parses_function() {
        let source = r##"
function AnalyzeSegmentSgr(genre: string, scene: string) -> SgrSegmentDecision {
  client AgentFallback
  prompt #"
    You are a video editor.
    Genre: {{ genre }}
    {{ ctx.output_format }}
  "#
}
"##;
        let module = parse(source);

        assert_eq!(module.functions.len(), 1);
        let func = &module.functions[0];
        assert_eq!(func.name, "AnalyzeSegmentSgr");
        assert_eq!(func.params.len(), 2);
        assert_eq!(func.return_type, "SgrSegmentDecision");
        assert_eq!(func.client, "AgentFallback");
        assert!(func.prompt.contains("video editor"));
    }

    /// Smoke test against the real `montage.baml`; skipped when the file is
    /// not present in the checkout.
    #[test]
    fn parses_real_montage_baml() {
        let mut path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.pop(); // sgr-agent
        path.pop(); // crates
        path.pop(); // rust-code
        for component in [
            "startups",
            "active",
            "video-analyzer",
            "crates",
            "va-agent",
            "baml_src",
            "montage.baml",
        ] {
            path.push(component);
        }
        if !path.exists() {
            eprintln!("Skipping: montage.baml not found at {}", path.display());
            return;
        }
        let source = std::fs::read_to_string(&path).unwrap();
        let module = parse(&source);

        // Should find all major classes
        for class_name in [
            "CutDecision",
            "MontageAgentNextStep",
            "AnalysisTask",
            "FfmpegTask",
            "ProjectTask",
            "ReportTaskCompletion",
        ] {
            assert!(
                module.find_class(class_name).is_some(),
                "missing class {}",
                class_name
            );
        }

        // Should find major functions
        for function_name in [
            "AnalyzeSegmentSgr",
            "DecideMontageNextStepSgr",
            "SummarizeTranscriptSgr",
        ] {
            assert!(
                module.find_function(function_name).is_some(),
                "missing function {}",
                function_name
            );
        }

        // MontageAgentNextStep should have union array for next_actions
        let step = module.find_class("MontageAgentNextStep").unwrap();
        let actions = step
            .fields
            .iter()
            .find(|f| f.name == "next_actions")
            .unwrap();
        match &actions.ty {
            BamlType::Array(inner) => match inner.as_ref() {
                BamlType::Union(variants) => {
                    assert!(variants.contains(&"AnalysisTask".to_string()));
                    assert!(variants.contains(&"FfmpegTask".to_string()));
                    assert!(
                        variants.len() >= 10,
                        "Should have at least 10 tool types, got {}",
                        variants.len()
                    );
                }
                other => panic!("Expected Union, got {:?}", other),
            },
            other => panic!("Expected Array, got {:?}", other),
        }
    }
}
586}