sqlcx_core/generator/go/
structs.rs1use crate::error::Result;
2use crate::generator::{GeneratedFile, SchemaGenerator};
3use crate::ir::{ColumnDef, Overrides, SqlType, SqlTypeCategory, SqlcxIR};
4use crate::utils::pascal_case;
5use std::collections::BTreeSet;
6
/// Generator that renders Go model structs into a `models.go` file from the
/// sqlcx IR.
///
/// Stateless unit struct: all work happens in `generate_models_file`, which the
/// `SchemaGenerator` impl wraps.
pub struct GoStructGenerator;
8
9fn go_type(sql_type: &SqlType) -> String {
13 if let Some(elem) = &sql_type.element_type {
14 return format!("[]{}", go_type(elem));
15 }
16 match sql_type.category {
17 SqlTypeCategory::String | SqlTypeCategory::Uuid | SqlTypeCategory::Enum => {
18 "string".to_string()
19 }
20 SqlTypeCategory::Number => {
21 let raw_upper = sql_type.raw.to_uppercase();
22 if raw_upper.contains("REAL")
23 || raw_upper.contains("FLOAT")
24 || raw_upper.contains("DOUBLE")
25 || raw_upper.contains("DECIMAL")
26 || raw_upper.contains("NUMERIC")
27 {
28 "float64".to_string()
29 } else {
30 "int64".to_string()
31 }
32 }
33 SqlTypeCategory::Boolean => "bool".to_string(),
34 SqlTypeCategory::Date => "time.Time".to_string(),
35 SqlTypeCategory::Json => "json.RawMessage".to_string(),
36 SqlTypeCategory::Binary => "[]byte".to_string(),
37 SqlTypeCategory::Unknown => "interface{}".to_string(),
38 }
39}
40
41fn nullable_go_type(sql_type: &SqlType) -> String {
43 let base = go_type(sql_type);
44 format!("*{}", base)
46}
47
48fn collect_imports(columns: &[ColumnDef]) -> BTreeSet<String> {
50 let mut imports = BTreeSet::new();
51 for col in columns {
52 collect_type_imports(&col.sql_type, &mut imports);
53 }
54 imports
55}
56
57fn collect_type_imports(sql_type: &SqlType, imports: &mut BTreeSet<String>) {
58 if let Some(elem) = &sql_type.element_type {
59 collect_type_imports(elem, imports);
60 return;
61 }
62 match sql_type.category {
63 SqlTypeCategory::Date => {
64 imports.insert("time".to_string());
65 }
66 SqlTypeCategory::Json => {
67 imports.insert("encoding/json".to_string());
68 }
69 _ => {}
70 }
71}
72
/// Renders a parenthesized Go `import ( ... )` block, one tab-indented quoted
/// path per line, followed by a trailing newline.
///
/// Returns an empty string when there is nothing to import, so callers can
/// omit the block entirely.
fn format_imports(imports: &BTreeSet<String>) -> String {
    if imports.is_empty() {
        return String::new();
    }
    let mut block = String::from("import (\n");
    for path in imports {
        block.push_str("\t\"");
        block.push_str(path);
        block.push_str("\"\n");
    }
    block.push_str(")\n");
    block
}
80
81fn generate_select_struct(
84 table_name: &str,
85 columns: &[ColumnDef],
86 _overrides: &Overrides,
87) -> String {
88 let struct_name = pascal_case(table_name);
89 let fields: Vec<String> = columns
90 .iter()
91 .map(|col| {
92 let field_name = pascal_case(&col.name);
93 let field_type = if col.nullable {
94 nullable_go_type(&col.sql_type)
95 } else {
96 go_type(&col.sql_type)
97 };
98 let pad = compute_padding(&field_name, columns, false);
99 let type_pad = compute_type_padding(&field_type, columns, false);
100 format!(
101 "\t{}{}{}{}`db:\"{}\" json:\"{}\"`",
102 field_name, pad, field_type, type_pad, col.name, col.name,
103 )
104 })
105 .collect();
106 format!("type {} struct {{\n{}\n}}", struct_name, fields.join("\n"))
107}
108
109fn generate_insert_struct(
110 table_name: &str,
111 columns: &[ColumnDef],
112 _overrides: &Overrides,
113) -> String {
114 let insertable: Vec<&ColumnDef> = columns
116 .iter()
117 .filter(|col| !(col.has_default && col.name == "id"))
118 .filter(|col| !(col.has_default && col.name == "created_at"))
119 .collect();
120
121 let struct_name = format!("Insert{}", pascal_case(table_name));
122 let fields: Vec<String> = insertable
123 .iter()
124 .map(|col| {
125 let field_name = pascal_case(&col.name);
126 let field_type = if col.nullable || col.has_default {
127 nullable_go_type(&col.sql_type)
128 } else {
129 go_type(&col.sql_type)
130 };
131 let pad = compute_padding_refs(&field_name, &insertable);
132 let type_pad = compute_type_padding_refs(&field_type, &insertable);
133 format!(
134 "\t{}{}{}{}`db:\"{}\" json:\"{}\"`",
135 field_name, pad, field_type, type_pad, col.name, col.name,
136 )
137 })
138 .collect();
139 format!(
140 "// {} has optional fields for columns with defaults\ntype {} struct {{\n{}\n}}",
141 struct_name,
142 struct_name,
143 fields.join("\n")
144 )
145}
146
147fn compute_padding(field_name: &str, columns: &[ColumnDef], _nullable_insert: bool) -> String {
149 let max_len = columns
150 .iter()
151 .map(|c| pascal_case(&c.name).len())
152 .max()
153 .unwrap_or(0);
154 let pad = max_len - field_name.len() + 1;
155 " ".repeat(pad)
156}
157
158fn compute_padding_refs(field_name: &str, columns: &[&ColumnDef]) -> String {
159 let max_len = columns
160 .iter()
161 .map(|c| pascal_case(&c.name).len())
162 .max()
163 .unwrap_or(0);
164 let pad = max_len - field_name.len() + 1;
165 " ".repeat(pad)
166}
167
168fn compute_type_padding(field_type: &str, columns: &[ColumnDef], _nullable_insert: bool) -> String {
170 let max_len = columns
171 .iter()
172 .map(|c| {
173 if c.nullable {
174 nullable_go_type(&c.sql_type).len()
175 } else {
176 go_type(&c.sql_type).len()
177 }
178 })
179 .max()
180 .unwrap_or(0);
181 let pad = max_len.saturating_sub(field_type.len()) + 1;
182 " ".repeat(pad)
183}
184
185fn compute_type_padding_refs(field_type: &str, columns: &[&ColumnDef]) -> String {
186 let max_len = columns
187 .iter()
188 .map(|c| {
189 if c.nullable || c.has_default {
190 nullable_go_type(&c.sql_type).len()
191 } else {
192 go_type(&c.sql_type).len()
193 }
194 })
195 .max()
196 .unwrap_or(0);
197 let pad = max_len.saturating_sub(field_type.len()) + 1;
198 " ".repeat(pad)
199}
200
201impl GoStructGenerator {
204 pub fn generate_models_file(&self, ir: &SqlcxIR, overrides: &Overrides) -> String {
205 let mut parts: Vec<String> = Vec::new();
206
207 parts.push("// Code generated by sqlcx. DO NOT EDIT.".to_string());
208 parts.push("package db".to_string());
209
210 let mut all_imports = BTreeSet::new();
212 for table in &ir.tables {
213 let sel_imports = collect_imports(&table.columns);
214 let ins_imports = collect_imports(&table.columns);
215 all_imports.extend(sel_imports);
216 all_imports.extend(ins_imports);
217 }
218 let imports_str = format_imports(&all_imports);
219 if !imports_str.is_empty() {
220 parts.push(imports_str);
221 }
222
223 for table in &ir.tables {
224 parts.push(generate_select_struct(
225 &table.name,
226 &table.columns,
227 overrides,
228 ));
229 parts.push(generate_insert_struct(
230 &table.name,
231 &table.columns,
232 overrides,
233 ));
234 }
235
236 parts.join("\n\n") + "\n"
237 }
238}
239
240impl SchemaGenerator for GoStructGenerator {
241 fn generate(&self, ir: &SqlcxIR, overrides: &Overrides) -> Result<GeneratedFile> {
242 Ok(GeneratedFile {
243 path: "models.go".to_string(),
244 content: self.generate_models_file(ir, overrides),
245 })
246 }
247}
248
249pub fn go_column_type(col: &ColumnDef) -> String {
253 if col.nullable {
254 nullable_go_type(&col.sql_type)
255 } else {
256 go_type(&col.sql_type)
257 }
258}
259
/// Maps a SQL type to its plain (non-pointer) Go type; arrays become `[]T`.
///
/// Public wrapper around the module-private `go_type`. Nullability is not
/// considered here; use `go_column_type` for a column-aware type.
pub fn go_base_type(sql_type: &SqlType) -> String {
    go_type(sql_type)
}
264
265pub fn go_imports_for_columns(columns: &[ColumnDef]) -> BTreeSet<String> {
267 collect_imports(columns)
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::DatabaseParser;
    use crate::parser::postgres::PostgresParser;
    use std::collections::HashMap;

    /// Shared Postgres schema fixture used by the generator tests.
    const SCHEMA_SQL: &str = include_str!("../../../../../tests/fixtures/schema.sql");

    /// Parses the fixture schema into an IR with tables and enums but no queries.
    fn parse_fixture_ir() -> SqlcxIR {
        let (tables, enums) = PostgresParser::new().parse_schema(SCHEMA_SQL).unwrap();
        SqlcxIR {
            tables,
            queries: vec![],
            enums,
        }
    }

    #[test]
    fn generates_models_file() {
        let ir = parse_fixture_ir();
        let content = GoStructGenerator.generate_models_file(&ir, &HashMap::new());

        let expected_fragments = [
            "package db",
            "type Users struct {",
            "type InsertUsers struct {",
            "type Posts struct {",
            "type InsertPosts struct {",
        ];
        for fragment in expected_fragments.iter() {
            assert!(
                content.contains(*fragment),
                "generated models.go is missing {:?}",
                fragment
            );
        }

        insta::assert_snapshot!("go_structs_models", content);
    }
}
302}