Skip to main content

sqlcx_core/generator/go/
structs.rs

1use crate::error::Result;
2use crate::generator::{GeneratedFile, SchemaGenerator};
3use crate::ir::{ColumnDef, Overrides, SqlType, SqlTypeCategory, SqlcxIR};
4use crate::utils::pascal_case;
5use std::collections::BTreeSet;
6
/// Stateless generator that renders Go model structs from the sqlcx IR.
///
/// The type carries no data, so it derives the cheap common traits: callers
/// can copy it, build it with `Default::default()` and print it with `{:?}`.
#[derive(Debug, Clone, Copy, Default)]
pub struct GoStructGenerator;
8
9// ── Type mapping ──────────────────────────────────────────────────────────────
10
11/// Map a SQL type to a Go type string (non-nullable).
12fn go_type(sql_type: &SqlType) -> String {
13    if let Some(elem) = &sql_type.element_type {
14        return format!("[]{}", go_type(elem));
15    }
16    match sql_type.category {
17        SqlTypeCategory::String | SqlTypeCategory::Uuid | SqlTypeCategory::Enum => {
18            "string".to_string()
19        }
20        SqlTypeCategory::Number => {
21            let raw_upper = sql_type.raw.to_uppercase();
22            if raw_upper.contains("REAL")
23                || raw_upper.contains("FLOAT")
24                || raw_upper.contains("DOUBLE")
25                || raw_upper.contains("DECIMAL")
26                || raw_upper.contains("NUMERIC")
27            {
28                "float64".to_string()
29            } else {
30                "int64".to_string()
31            }
32        }
33        SqlTypeCategory::Boolean => "bool".to_string(),
34        SqlTypeCategory::Date => "time.Time".to_string(),
35        SqlTypeCategory::Json => "json.RawMessage".to_string(),
36        SqlTypeCategory::Binary => "[]byte".to_string(),
37        SqlTypeCategory::Unknown => "interface{}".to_string(),
38    }
39}
40
41/// Wrap a Go type for nullable columns (pointer type).
42fn nullable_go_type(sql_type: &SqlType) -> String {
43    let base = go_type(sql_type);
44    // Slices are already nullable in Go, but we use pointer for consistency
45    format!("*{}", base)
46}
47
48/// Collect Go imports needed for a set of columns.
49fn collect_imports(columns: &[ColumnDef]) -> BTreeSet<String> {
50    let mut imports = BTreeSet::new();
51    for col in columns {
52        collect_type_imports(&col.sql_type, &mut imports);
53    }
54    imports
55}
56
57fn collect_type_imports(sql_type: &SqlType, imports: &mut BTreeSet<String>) {
58    if let Some(elem) = &sql_type.element_type {
59        collect_type_imports(elem, imports);
60        return;
61    }
62    match sql_type.category {
63        SqlTypeCategory::Date => {
64            imports.insert("time".to_string());
65        }
66        SqlTypeCategory::Json => {
67            imports.insert("encoding/json".to_string());
68        }
69        _ => {}
70    }
71}
72
/// Render `imports` as a parenthesised Go `import ( … )` block.
///
/// Each path is written on its own tab-indented, double-quoted line in the
/// set's iteration order, and the block ends with a newline. An empty set
/// yields an empty string so the caller can skip the block entirely.
fn format_imports(imports: &BTreeSet<String>) -> String {
    if imports.is_empty() {
        return String::new();
    }
    let mut block = String::from("import (\n");
    for package in imports {
        block.push_str("\t\"");
        block.push_str(package);
        block.push_str("\"\n");
    }
    block.push_str(")\n");
    block
}
80
81// ── Struct generation ─────────────────────────────────────────────────────────
82
83fn generate_select_struct(
84    table_name: &str,
85    columns: &[ColumnDef],
86    _overrides: &Overrides,
87) -> String {
88    let struct_name = pascal_case(table_name);
89    let fields: Vec<String> = columns
90        .iter()
91        .map(|col| {
92            let field_name = pascal_case(&col.name);
93            let field_type = if col.nullable {
94                nullable_go_type(&col.sql_type)
95            } else {
96                go_type(&col.sql_type)
97            };
98            let pad = compute_padding(&field_name, columns, false);
99            let type_pad = compute_type_padding(&field_type, columns, false);
100            format!(
101                "\t{}{}{}{}`db:\"{}\" json:\"{}\"`",
102                field_name, pad, field_type, type_pad, col.name, col.name,
103            )
104        })
105        .collect();
106    format!("type {} struct {{\n{}\n}}", struct_name, fields.join("\n"))
107}
108
109fn generate_insert_struct(
110    table_name: &str,
111    columns: &[ColumnDef],
112    _overrides: &Overrides,
113) -> String {
114    // Filter to only insertable columns: skip PK with default and auto-timestamps
115    let insertable: Vec<&ColumnDef> = columns
116        .iter()
117        .filter(|col| !(col.has_default && col.name == "id"))
118        .filter(|col| !(col.has_default && col.name == "created_at"))
119        .collect();
120
121    let struct_name = format!("Insert{}", pascal_case(table_name));
122    let fields: Vec<String> = insertable
123        .iter()
124        .map(|col| {
125            let field_name = pascal_case(&col.name);
126            let field_type = if col.nullable || col.has_default {
127                nullable_go_type(&col.sql_type)
128            } else {
129                go_type(&col.sql_type)
130            };
131            let pad = compute_padding_refs(&field_name, &insertable);
132            let type_pad = compute_type_padding_refs(&field_type, &insertable);
133            format!(
134                "\t{}{}{}{}`db:\"{}\" json:\"{}\"`",
135                field_name, pad, field_type, type_pad, col.name, col.name,
136            )
137        })
138        .collect();
139    format!(
140        "// {} has optional fields for columns with defaults\ntype {} struct {{\n{}\n}}",
141        struct_name,
142        struct_name,
143        fields.join("\n")
144    )
145}
146
147/// Compute field name padding for alignment.
148fn compute_padding(field_name: &str, columns: &[ColumnDef], _nullable_insert: bool) -> String {
149    let max_len = columns
150        .iter()
151        .map(|c| pascal_case(&c.name).len())
152        .max()
153        .unwrap_or(0);
154    let pad = max_len - field_name.len() + 1;
155    " ".repeat(pad)
156}
157
158fn compute_padding_refs(field_name: &str, columns: &[&ColumnDef]) -> String {
159    let max_len = columns
160        .iter()
161        .map(|c| pascal_case(&c.name).len())
162        .max()
163        .unwrap_or(0);
164    let pad = max_len - field_name.len() + 1;
165    " ".repeat(pad)
166}
167
168/// Compute type padding for struct tag alignment.
169fn compute_type_padding(field_type: &str, columns: &[ColumnDef], _nullable_insert: bool) -> String {
170    let max_len = columns
171        .iter()
172        .map(|c| {
173            if c.nullable {
174                nullable_go_type(&c.sql_type).len()
175            } else {
176                go_type(&c.sql_type).len()
177            }
178        })
179        .max()
180        .unwrap_or(0);
181    let pad = max_len.saturating_sub(field_type.len()) + 1;
182    " ".repeat(pad)
183}
184
185fn compute_type_padding_refs(field_type: &str, columns: &[&ColumnDef]) -> String {
186    let max_len = columns
187        .iter()
188        .map(|c| {
189            if c.nullable || c.has_default {
190                nullable_go_type(&c.sql_type).len()
191            } else {
192                go_type(&c.sql_type).len()
193            }
194        })
195        .max()
196        .unwrap_or(0);
197    let pad = max_len.saturating_sub(field_type.len()) + 1;
198    " ".repeat(pad)
199}
200
201// ── Generator ─────────────────────────────────────────────────────────────────
202
203impl GoStructGenerator {
204    pub fn generate_models_file(&self, ir: &SqlcxIR, overrides: &Overrides) -> String {
205        let mut parts: Vec<String> = Vec::new();
206
207        parts.push("// Code generated by sqlcx. DO NOT EDIT.".to_string());
208        parts.push("package db".to_string());
209
210        // Collect all imports
211        let mut all_imports = BTreeSet::new();
212        for table in &ir.tables {
213            let sel_imports = collect_imports(&table.columns);
214            let ins_imports = collect_imports(&table.columns);
215            all_imports.extend(sel_imports);
216            all_imports.extend(ins_imports);
217        }
218        let imports_str = format_imports(&all_imports);
219        if !imports_str.is_empty() {
220            parts.push(imports_str);
221        }
222
223        for table in &ir.tables {
224            parts.push(generate_select_struct(
225                &table.name,
226                &table.columns,
227                overrides,
228            ));
229            parts.push(generate_insert_struct(
230                &table.name,
231                &table.columns,
232                overrides,
233            ));
234        }
235
236        parts.join("\n\n") + "\n"
237    }
238}
239
240impl SchemaGenerator for GoStructGenerator {
241    fn generate(&self, ir: &SqlcxIR, overrides: &Overrides) -> Result<GeneratedFile> {
242        Ok(GeneratedFile {
243            path: "models.go".to_string(),
244            content: self.generate_models_file(ir, overrides),
245        })
246    }
247}
248
249// ── Public helpers for driver generator ───────────────────────────────────────
250
251/// Get the Go type for a column (used by the driver generator for scan types).
252pub fn go_column_type(col: &ColumnDef) -> String {
253    if col.nullable {
254        nullable_go_type(&col.sql_type)
255    } else {
256        go_type(&col.sql_type)
257    }
258}
259
260/// Get the base Go type for a SqlType (non-nullable).
261pub fn go_base_type(sql_type: &SqlType) -> String {
262    go_type(sql_type)
263}
264
265/// Collect imports needed for a set of columns.
266pub fn go_imports_for_columns(columns: &[ColumnDef]) -> BTreeSet<String> {
267    collect_imports(columns)
268}
269
270// ── Tests ─────────────────────────────────────────────────────────────────────
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::DatabaseParser;
    use crate::parser::postgres::PostgresParser;
    use std::collections::HashMap;

    /// Parse the shared schema fixture with the Postgres parser into an IR
    /// that has tables and enums but no queries.
    fn parse_fixture_ir() -> SqlcxIR {
        let (tables, enums) = PostgresParser::new()
            .parse_schema(include_str!("../../../../../tests/fixtures/schema.sql"))
            .unwrap();
        SqlcxIR {
            tables,
            queries: Vec::new(),
            enums,
        }
    }

    /// The rendered models file declares the package and both structs for
    /// the fixture's `users` and `posts` tables, and matches the snapshot.
    #[test]
    fn generates_models_file() {
        let rendered = GoStructGenerator.generate_models_file(&parse_fixture_ir(), &HashMap::new());
        let expected_snippets = [
            "package db",
            "type Users struct {",
            "type InsertUsers struct {",
            "type Posts struct {",
            "type InsertPosts struct {",
        ];
        for snippet in expected_snippets {
            assert!(rendered.contains(snippet));
        }
        insta::assert_snapshot!("go_structs_models", rendered);
    }
}