styx_gen_go/
lib.rs

1#![doc = include_str!("../README.md")]
2//! Go code generation from Styx schemas.
3
4use facet_styx::SchemaFile;
5use std::fmt::Write as _;
6use std::path::Path;
7
8mod error;
9mod types;
10
11pub use error::GenError;
12use types::{GoType, TypeMapper};
13
14/// Generate Go code from a Styx schema.
15pub fn generate(schema: &SchemaFile, package_name: &str, output_dir: &str) -> Result<(), GenError> {
16    let mut mapper = TypeMapper::new();
17
18    // Collect all type definitions from the schema
19    for (name_opt, schema_type) in &schema.schema {
20        let type_name = match name_opt {
21            Some(name) => name.clone(),
22            None => "Root".to_string(), // Root type gets default name
23        };
24        mapper.register_type(&type_name, schema_type)?;
25    }
26
27    // Generate the main types file
28    let types_code = generate_types_file(&mapper, package_name)?;
29    let types_path = Path::new(output_dir).join("types.go");
30    std::fs::write(&types_path, types_code)
31        .map_err(|e| GenError::Io(format!("failed to write {}: {}", types_path.display(), e)))?;
32
33    // Generate the validation file
34    let validation_code = generate_validation_file(&mapper, package_name)?;
35    let validation_path = Path::new(output_dir).join("validation.go");
36    std::fs::write(&validation_path, validation_code).map_err(|e| {
37        GenError::Io(format!(
38            "failed to write {}: {}",
39            validation_path.display(),
40            e
41        ))
42    })?;
43
44    // Generate the parsing helpers file
45    let parse_code = generate_parse_file(&mapper, package_name, schema.meta.id.as_str())?;
46    let parse_path = Path::new(output_dir).join("parse.go");
47    std::fs::write(&parse_path, parse_code)
48        .map_err(|e| GenError::Io(format!("failed to write {}: {}", parse_path.display(), e)))?;
49
50    Ok(())
51}
52
53fn generate_types_file(mapper: &TypeMapper, package_name: &str) -> Result<String, GenError> {
54    let mut out = String::new();
55
56    // Package declaration
57    writeln!(out, "// Code generated by styx gen go. DO NOT EDIT.")?;
58    writeln!(out)?;
59    writeln!(out, "package {}", package_name)?;
60    writeln!(out)?;
61
62    // Generate type definitions
63    for (name, go_type) in mapper.types() {
64        generate_type_definition(&mut out, name, go_type)?;
65        writeln!(out)?;
66    }
67
68    Ok(out)
69}
70
71fn generate_type_definition(
72    out: &mut String,
73    name: &str,
74    go_type: &GoType,
75) -> Result<(), GenError> {
76    match go_type {
77        GoType::Struct { fields, doc } => {
78            if let Some(doc) = doc {
79                for line in doc.lines() {
80                    writeln!(out, "// {}", line)?;
81                }
82            }
83            writeln!(out, "type {} struct {{", name)?;
84            for field in fields {
85                if let Some(doc) = &field.doc {
86                    for line in doc.lines() {
87                        writeln!(out, "\t// {}", line)?;
88                    }
89                }
90                let tags = format!(
91                    "`json:\"{}{}\" styx:\"{}{}\"`",
92                    field.json_name,
93                    if field.optional { ",omitempty" } else { "" },
94                    field.styx_name,
95                    if field.optional { ",optional" } else { "" }
96                );
97                writeln!(out, "\t{} {} {}", field.go_name, field.type_name, tags)?;
98            }
99            writeln!(out, "}}")?;
100        }
101        GoType::Enum { variants, doc } => {
102            if let Some(doc) = doc {
103                for line in doc.lines() {
104                    writeln!(out, "// {}", line)?;
105                }
106            }
107            writeln!(out, "type {} string", name)?;
108            writeln!(out)?;
109            writeln!(out, "const (")?;
110            for variant in variants {
111                let const_name = format!("{}{}", name, to_pascal_case(&variant.name));
112                if let Some(doc) = &variant.doc {
113                    for line in doc.lines() {
114                        writeln!(out, "\t// {}", line)?;
115                    }
116                }
117                writeln!(out, "\t{} {} = \"{}\"", const_name, name, variant.name)?;
118            }
119            writeln!(out, ")")?;
120        }
121        _ => {}
122    }
123    Ok(())
124}
125
126fn generate_validation_file(mapper: &TypeMapper, package_name: &str) -> Result<String, GenError> {
127    let mut out = String::new();
128
129    writeln!(out, "// Code generated by styx gen go. DO NOT EDIT.")?;
130    writeln!(out)?;
131    writeln!(out, "package {}", package_name)?;
132    writeln!(out)?;
133    writeln!(out, "import (")?;
134    writeln!(out, "\t\"fmt\"")?;
135    writeln!(out, "\t\"strings\"")?;
136    writeln!(out, ")")?;
137    writeln!(out)?;
138
139    // ValidationError type
140    writeln!(out, "// ValidationError represents a validation failure.")?;
141    writeln!(out, "type ValidationError struct {{")?;
142    writeln!(out, "\tField   string")?;
143    writeln!(out, "\tMessage string")?;
144    writeln!(out, "}}")?;
145    writeln!(out)?;
146    writeln!(out, "func (e *ValidationError) Error() string {{")?;
147    writeln!(out, "\treturn fmt.Sprintf(\"%s: %s\", e.Field, e.Message)")?;
148    writeln!(out, "}}")?;
149    writeln!(out)?;
150
151    // ValidationErrors type
152    writeln!(
153        out,
154        "// ValidationErrors represents multiple validation failures."
155    )?;
156    writeln!(out, "type ValidationErrors []*ValidationError")?;
157    writeln!(out)?;
158    writeln!(out, "func (e ValidationErrors) Error() string {{")?;
159    writeln!(out, "\tif len(e) == 0 {{")?;
160    writeln!(out, "\t\treturn \"no validation errors\"")?;
161    writeln!(out, "\t}}")?;
162    writeln!(out, "\tif len(e) == 1 {{")?;
163    writeln!(out, "\t\treturn e[0].Error()")?;
164    writeln!(out, "\t}}")?;
165    writeln!(out, "\tvar msgs []string")?;
166    writeln!(out, "\tfor _, err := range e {{")?;
167    writeln!(out, "\t\tmsgs = append(msgs, err.Error())")?;
168    writeln!(out, "\t}}")?;
169    writeln!(
170        out,
171        "\treturn fmt.Sprintf(\"%d validation errors: %s\", len(e), strings.Join(msgs, \"; \"))"
172    )?;
173    writeln!(out, "}}")?;
174    writeln!(out)?;
175
176    // Generate validation functions for each type
177    for (name, go_type) in mapper.types() {
178        if let GoType::Struct { fields, .. } = go_type {
179            generate_validation_function(&mut out, name, fields)?;
180            writeln!(out)?;
181        }
182    }
183
184    Ok(out)
185}
186
187fn generate_validation_function(
188    out: &mut String,
189    type_name: &str,
190    fields: &[types::StructField],
191) -> Result<(), GenError> {
192    writeln!(out, "// Validate validates the {} instance.", type_name)?;
193    writeln!(out, "func (v *{}) Validate() error {{", type_name)?;
194    writeln!(out, "\tvar errs ValidationErrors")?;
195    writeln!(out)?;
196
197    for field in fields {
198        // Check required fields
199        if !field.optional && field.type_name.starts_with('*') {
200            writeln!(out, "\tif v.{} == nil {{", field.go_name)?;
201            writeln!(out, "\t\terrs = append(errs, &ValidationError{{")?;
202            writeln!(out, "\t\t\tField:   \"{}\",", field.styx_name)?;
203            writeln!(out, "\t\t\tMessage: \"required field is missing\",")?;
204            writeln!(out, "\t\t}})")?;
205            writeln!(out, "\t}}")?;
206        }
207
208        // Validate constraints
209        if let Some(constraints) = &field.constraints {
210            if field.optional {
211                writeln!(out, "\tif v.{} != nil {{", field.go_name)?;
212            }
213
214            // String constraints
215            if let Some(min_len) = constraints.min_length {
216                let deref = if field.type_name.starts_with('*') {
217                    "*"
218                } else {
219                    ""
220                };
221                writeln!(
222                    out,
223                    "\t\tif len({}v.{}) < {} {{",
224                    deref, field.go_name, min_len
225                )?;
226                writeln!(out, "\t\t\terrs = append(errs, &ValidationError{{")?;
227                writeln!(out, "\t\t\t\tField:   \"{}\",", field.styx_name)?;
228                writeln!(
229                    out,
230                    "\t\t\t\tMessage: fmt.Sprintf(\"length must be at least {}\", {}),",
231                    min_len, min_len
232                )?;
233                writeln!(out, "\t\t\t}})")?;
234                writeln!(out, "\t\t}}")?;
235            }
236
237            if let Some(max_len) = constraints.max_length {
238                let deref = if field.type_name.starts_with('*') {
239                    "*"
240                } else {
241                    ""
242                };
243                writeln!(
244                    out,
245                    "\t\tif len({}v.{}) > {} {{",
246                    deref, field.go_name, max_len
247                )?;
248                writeln!(out, "\t\t\terrs = append(errs, &ValidationError{{")?;
249                writeln!(out, "\t\t\t\tField:   \"{}\",", field.styx_name)?;
250                writeln!(
251                    out,
252                    "\t\t\t\tMessage: fmt.Sprintf(\"length must be at most {}\", {}),",
253                    max_len, max_len
254                )?;
255                writeln!(out, "\t\t\t}})")?;
256                writeln!(out, "\t\t}}")?;
257            }
258
259            // Integer constraints
260            if let Some(min) = constraints.min_int {
261                let deref = if field.type_name.starts_with('*') {
262                    "*"
263                } else {
264                    ""
265                };
266                writeln!(out, "\t\tif {}v.{} < {} {{", deref, field.go_name, min)?;
267                writeln!(out, "\t\t\terrs = append(errs, &ValidationError{{")?;
268                writeln!(out, "\t\t\t\tField:   \"{}\",", field.styx_name)?;
269                writeln!(
270                    out,
271                    "\t\t\t\tMessage: fmt.Sprintf(\"must be at least {}\", {}),",
272                    min, min
273                )?;
274                writeln!(out, "\t\t\t}})")?;
275                writeln!(out, "\t\t}}")?;
276            }
277
278            if let Some(max) = constraints.max_int {
279                let deref = if field.type_name.starts_with('*') {
280                    "*"
281                } else {
282                    ""
283                };
284                writeln!(out, "\t\tif {}v.{} > {} {{", deref, field.go_name, max)?;
285                writeln!(out, "\t\t\terrs = append(errs, &ValidationError{{")?;
286                writeln!(out, "\t\t\t\tField:   \"{}\",", field.styx_name)?;
287                writeln!(
288                    out,
289                    "\t\t\t\tMessage: fmt.Sprintf(\"must be at most {}\", {}),",
290                    max, max
291                )?;
292                writeln!(out, "\t\t\t}})")?;
293                writeln!(out, "\t\t}}")?;
294            }
295
296            // Float constraints
297            if let Some(min) = constraints.min_float {
298                let deref = if field.type_name.starts_with('*') {
299                    "*"
300                } else {
301                    ""
302                };
303                writeln!(out, "\t\tif {}v.{} < {} {{", deref, field.go_name, min)?;
304                writeln!(out, "\t\t\terrs = append(errs, &ValidationError{{")?;
305                writeln!(out, "\t\t\t\tField:   \"{}\",", field.styx_name)?;
306                writeln!(
307                    out,
308                    "\t\t\t\tMessage: fmt.Sprintf(\"must be at least {}\", {}),",
309                    min, min
310                )?;
311                writeln!(out, "\t\t\t}})")?;
312                writeln!(out, "\t\t}}")?;
313            }
314
315            if let Some(max) = constraints.max_float {
316                let deref = if field.type_name.starts_with('*') {
317                    "*"
318                } else {
319                    ""
320                };
321                writeln!(out, "\t\tif {}v.{} > {} {{", deref, field.go_name, max)?;
322                writeln!(out, "\t\t\terrs = append(errs, &ValidationError{{")?;
323                writeln!(out, "\t\t\t\tField:   \"{}\",", field.styx_name)?;
324                writeln!(
325                    out,
326                    "\t\t\t\tMessage: fmt.Sprintf(\"must be at most {}\", {}),",
327                    max, max
328                )?;
329                writeln!(out, "\t\t\t}})")?;
330                writeln!(out, "\t\t}}")?;
331            }
332
333            if field.optional {
334                writeln!(out, "\t}}")?;
335            }
336        }
337    }
338
339    writeln!(out)?;
340    writeln!(out, "\tif len(errs) > 0 {{")?;
341    writeln!(out, "\t\treturn errs")?;
342    writeln!(out, "\t}}")?;
343    writeln!(out, "\treturn nil")?;
344    writeln!(out, "}}")?;
345
346    Ok(())
347}
348
349fn generate_parse_file(
350    mapper: &TypeMapper,
351    package_name: &str,
352    _schema_id: &str,
353) -> Result<String, GenError> {
354    let mut out = String::new();
355
356    writeln!(out, "// Code generated by styx gen go. DO NOT EDIT.")?;
357    writeln!(out)?;
358    writeln!(out, "package {}", package_name)?;
359    writeln!(out)?;
360    writeln!(out, "import (")?;
361    writeln!(out, "\t\"os\"")?;
362    writeln!(out)?;
363    writeln!(
364        out,
365        "\tstyx \"github.com/bearcove/styx/implementations/styx-go\""
366    )?;
367    writeln!(out, ")")?;
368    writeln!(out)?;
369
370    // Find the root type - should be "Root" which is the default name for the @ type
371    let root_type = mapper
372        .types()
373        .get("Root")
374        .and_then(|go_type| {
375            if matches!(go_type, GoType::Struct { .. }) {
376                Some("Root")
377            } else {
378                None
379            }
380        })
381        .or_else(|| {
382            // Fallback: find any struct type if "Root" doesn't exist
383            mapper.types().iter().find_map(|(name, go_type)| {
384                if matches!(go_type, GoType::Struct { .. }) {
385                    Some(name.as_str())
386                } else {
387                    None
388                }
389            })
390        });
391
392    if let Some(root_type) = root_type {
393        // LoadFromFile helper
394        writeln!(
395            out,
396            "// LoadFromFile loads and validates a {} from a .styx file.",
397            root_type
398        )?;
399        writeln!(
400            out,
401            "func LoadFromFile(path string) (*{}, error) {{",
402            root_type
403        )?;
404        writeln!(out, "\tdata, err := os.ReadFile(path)")?;
405        writeln!(out, "\tif err != nil {{")?;
406        writeln!(out, "\t\treturn nil, err")?;
407        writeln!(out, "\t}}")?;
408        writeln!(out, "\treturn Parse(string(data))")?;
409        writeln!(out, "}}")?;
410        writeln!(out)?;
411
412        // Parse helper
413        writeln!(
414            out,
415            "// Parse parses and validates a {} from a .styx string.",
416            root_type
417        )?;
418        writeln!(out, "func Parse(source string) (*{}, error) {{", root_type)?;
419        writeln!(out, "\tdoc, err := styx.Parse(source)")?;
420        writeln!(out, "\tif err != nil {{")?;
421        writeln!(out, "\t\treturn nil, err")?;
422        writeln!(out, "\t}}")?;
423        writeln!(out)?;
424        writeln!(out, "\t// TODO: Map styx.Document to {} struct", root_type)?;
425        writeln!(
426            out,
427            "\t// This requires implementing a mapper from styx values to Go types"
428        )?;
429        writeln!(out, "\tvar config {}", root_type)?;
430        writeln!(out, "\t_ = doc // Use doc to populate config")?;
431        writeln!(out)?;
432        writeln!(out, "\t// Validate the parsed config")?;
433        writeln!(out, "\tif err := config.Validate(); err != nil {{")?;
434        writeln!(out, "\t\treturn nil, err")?;
435        writeln!(out, "\t}}")?;
436        writeln!(out)?;
437        writeln!(out, "\treturn &config, nil")?;
438        writeln!(out, "}}")?;
439    }
440
441    Ok(out)
442}
443
/// Converts an identifier such as `foo_bar`, `foo-bar` or `foo.bar` to
/// PascalCase (`FooBar`).
///
/// The input is split on every character that is not alphanumeric, so
/// separators that cannot appear in a Go identifier are dropped rather than
/// copied into the output. Empty segments are skipped. The first character of
/// each segment is upper-cased and the rest is kept as-is (`HTTP_ok` → `HTTPOk`).
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split(|c: char| !c.is_alphanumeric()) {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            // `to_uppercase` may yield several chars (e.g. 'ß' -> "SS").
            out.extend(first.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}
456
457impl From<std::fmt::Error> for GenError {
458    fn from(_: std::fmt::Error) -> Self {
459        GenError::Format("formatting error".into())
460    }
461}