Skip to main content

sqlcx_core/generator/go/
database_sql.rs

1// database/sql driver generator for Go
2
3use std::collections::BTreeMap;
4use std::collections::BTreeSet;
5use std::path::Path;
6
7use crate::error::Result;
8use crate::generator::go::common::{
9    escape_sql, func_params, generate_result_struct, generate_row_struct, query_args, scan_fields,
10    sql_const_name,
11};
12use crate::generator::{DriverGenerator, GeneratedFile};
13use crate::ir::{ColumnDef, QueryCommand, QueryDef, SqlcxIR};
14use crate::utils::pascal_case;
15
16use super::structs::go_imports_for_columns;
17
/// Placeholder dialect used in the SQL constants emitted for Go's
/// driver-agnostic `database/sql` package.
///
/// The generated Go code is the same for every backend. Only the parameter
/// placeholders inside the SQL constants change between backends.
#[derive(Clone, Copy, Debug)]
pub enum DatabaseSqlBackend {
    /// Numbered placeholders (`$1`, `$2`, ...), emitted exactly as written.
    Postgres,
    /// Anonymous `?` placeholders, one per occurrence.
    MySql,
    /// Anonymous `?` placeholders, one per occurrence.
    Sqlite,
}
29
30impl DatabaseSqlBackend {
31    /// Rewrite `$N` placeholders to the style this backend expects. For
32    /// Postgres, returns the SQL unchanged with empty occurrence indices.
33    /// For MySQL/SQLite, replaces every `$N` outside a single-quoted
34    /// string literal with `?`, and returns the `$N` indices in document
35    /// order so callers can emit args in positional order (handles reused
36    /// and out-of-order params correctly).
37    fn rewrite_placeholders(self, sql: &str) -> (String, Vec<u32>) {
38        match self {
39            DatabaseSqlBackend::Postgres => (sql.to_string(), Vec::new()),
40            DatabaseSqlBackend::MySql | DatabaseSqlBackend::Sqlite => {
41                let mut out = String::with_capacity(sql.len());
42                let mut indices = Vec::new();
43                let mut chars = sql.chars().peekable();
44                let mut in_string = false;
45                while let Some(c) = chars.next() {
46                    if c == '\'' {
47                        if in_string && chars.peek() == Some(&'\'') {
48                            out.push(c);
49                            out.push(chars.next().unwrap());
50                            continue;
51                        }
52                        in_string = !in_string;
53                        out.push(c);
54                        continue;
55                    }
56                    if !in_string && c == '$' && chars.peek().is_some_and(|ch| ch.is_ascii_digit())
57                    {
58                        let mut num = String::new();
59                        while chars.peek().is_some_and(|ch| ch.is_ascii_digit()) {
60                            num.push(chars.next().unwrap());
61                        }
62                        indices.push(num.parse::<u32>().unwrap_or(0));
63                        out.push('?');
64                    } else {
65                        out.push(c);
66                    }
67                }
68                (out, indices)
69            }
70        }
71    }
72}
73
74pub struct DatabaseSqlGenerator {
75    backend: DatabaseSqlBackend,
76}
77
78impl DatabaseSqlGenerator {
79    pub fn postgres() -> Self {
80        Self {
81            backend: DatabaseSqlBackend::Postgres,
82        }
83    }
84
85    pub fn mysql() -> Self {
86        Self {
87            backend: DatabaseSqlBackend::MySql,
88        }
89    }
90
91    pub fn sqlite() -> Self {
92        Self {
93            backend: DatabaseSqlBackend::Sqlite,
94        }
95    }
96}
97
98impl Default for DatabaseSqlGenerator {
99    fn default() -> Self {
100        Self::postgres()
101    }
102}
103
104// ── Client file ───────────────────────────────────────────────────────────────
105
/// Render `client.go`.
///
/// The file contains the `DBTX` interface (the `ExecContext` /
/// `QueryContext` / `QueryRowContext` method set shared by `*sql.DB` and
/// `*sql.Tx`), the `Queries` wrapper and its `New` constructor.
fn generate_client() -> String {
    const CLIENT_GO: &str = r#"// Code generated by sqlcx. DO NOT EDIT.
package db

import (
	"context"
	"database/sql"
)

type DBTX interface {
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}

type Queries struct {
	db DBTX
}

func New(db DBTX) *Queries {
	return &Queries{db: db}
}"#;
    CLIENT_GO.to_owned()
}
130
131// ── Query generation ──────────────────────────────────────────────────────────
132
133/// Generate a full query function. `backend` determines the placeholder
134/// style in the emitted SQL constant.
135fn generate_query_function(backend: DatabaseSqlBackend, query: &QueryDef) -> String {
136    let const_name = sql_const_name(&query.name);
137    let func_name = pascal_case(&query.name);
138    let params = func_params(query);
139
140    let mut parts: Vec<String> = Vec::new();
141
142    // SQL constant — placeholder style depends on backend.
143    let (rewritten_sql, occurrence_indices) = backend.rewrite_placeholders(&query.sql);
144    parts.push(format!(
145        "const {} = \"{}\"",
146        const_name,
147        escape_sql(&rewritten_sql),
148    ));
149
150    // Args — Postgres binds each unique param once (query_args default),
151    // MySQL/SQLite bind one arg per placeholder occurrence (handles reused
152    // and out-of-order $N correctly).
153    let args = if occurrence_indices.is_empty() {
154        query_args(query)
155    } else {
156        let names: Vec<String> = occurrence_indices
157            .iter()
158            .map(|idx| {
159                query
160                    .params
161                    .iter()
162                    .find(|p| p.index == *idx)
163                    .map(|p| p.name.clone())
164                    .unwrap_or_else(|| "nil".to_string())
165            })
166            .collect();
167        if names.is_empty() {
168            String::new()
169        } else {
170            format!(", {}", names.join(", "))
171        }
172    };
173
174    match query.command {
175        QueryCommand::One => {
176            let row_type = format!("{}Row", pascal_case(&query.name));
177            if let Some(row_struct) = generate_row_struct(query) {
178                parts.push(row_struct);
179            }
180            let scans = scan_fields(&query.returns);
181            parts.push(format!(
182                "func (q *Queries) {}({}) (*{}, error) {{\n\
183                \trow := q.db.QueryRowContext(ctx, {}{})
184\tvar i {}
185\terr := row.Scan({})
186\tif err == sql.ErrNoRows {{
187\t\treturn nil, nil
188\t}}
189\tif err != nil {{
190\t\treturn nil, err
191\t}}
192\treturn &i, nil\n}}",
193                func_name, params, row_type, const_name, args, row_type, scans,
194            ));
195        }
196        QueryCommand::Many => {
197            let row_type = format!("{}Row", pascal_case(&query.name));
198            if let Some(row_struct) = generate_row_struct(query) {
199                parts.push(row_struct);
200            }
201            let scans = scan_fields(&query.returns);
202            parts.push(format!(
203                "func (q *Queries) {}({}) ([]{}, error) {{\n\
204                \trows, err := q.db.QueryContext(ctx, {}{})
205\tif err != nil {{
206\t\treturn nil, err
207\t}}
208\tdefer rows.Close()
209\tvar items []{}
210\tfor rows.Next() {{
211\t\tvar i {}
212\t\tif err := rows.Scan({}); err != nil {{
213\t\t\treturn nil, err
214\t\t}}
215\t\titems = append(items, i)
216\t}}
217\treturn items, rows.Err()\n}}",
218                func_name, params, row_type, const_name, args, row_type, row_type, scans,
219            ));
220        }
221        QueryCommand::Exec => {
222            parts.push(format!(
223                "func (q *Queries) {}({}) error {{\n\
224                \t_, err := q.db.ExecContext(ctx, {}{})\n\treturn err\n}}",
225                func_name, params, const_name, args,
226            ));
227        }
228        QueryCommand::ExecResult => {
229            let result_type = format!("{}Result", pascal_case(&query.name));
230            parts.push(generate_result_struct(query));
231            parts.push(format!(
232                "func (q *Queries) {}({}) (*{}, error) {{\n\
233                \tresult, err := q.db.ExecContext(ctx, {}{})
234\tif err != nil {{
235\t\treturn nil, err
236\t}}
237\taffected, err := result.RowsAffected()
238\tif err != nil {{
239\t\treturn nil, err
240\t}}
241\treturn &{}{{RowsAffected: affected}}, nil\n}}",
242                func_name, params, result_type, const_name, args, result_type,
243            ));
244        }
245    }
246
247    parts.join("\n\n")
248}
249
250/// Collect imports needed for a set of queries.
251fn collect_query_imports(queries: &[QueryDef]) -> BTreeSet<String> {
252    let mut imports = BTreeSet::new();
253    imports.insert("context".to_string());
254
255    let mut needs_database_sql = false;
256    for query in queries {
257        // Check return columns
258        let ret_imports = go_imports_for_columns(&query.returns);
259        imports.extend(ret_imports);
260
261        // Check if we need database/sql (for ErrNoRows or Result)
262        match query.command {
263            QueryCommand::One | QueryCommand::ExecResult => {
264                needs_database_sql = true;
265            }
266            _ => {}
267        }
268
269        // Check params for time/json imports
270        for param in &query.params {
271            let col = ColumnDef {
272                name: param.name.clone(),
273                alias: None,
274                source_table: None,
275                sql_type: param.sql_type.clone(),
276                nullable: false,
277                has_default: false,
278            };
279            let col_imports = go_imports_for_columns(&[col]);
280            imports.extend(col_imports);
281        }
282    }
283
284    if needs_database_sql {
285        imports.insert("database/sql".to_string());
286    }
287
288    imports
289}
290
291// ── Public API ────────────────────────────────────────────────────────────────
292
293impl DatabaseSqlGenerator {
294    pub fn generate_client(&self) -> String {
295        generate_client()
296    }
297
298    pub fn generate_query_file(&self, queries: &[QueryDef]) -> String {
299        let imports = collect_query_imports(queries);
300        let imports_str = if imports.is_empty() {
301            String::new()
302        } else {
303            let lines: Vec<String> = imports.iter().map(|i| format!("\t\"{}\"", i)).collect();
304            format!("\nimport (\n{}\n)\n", lines.join("\n"))
305        };
306
307        let functions: Vec<String> = queries
308            .iter()
309            .map(|q| generate_query_function(self.backend, q))
310            .collect();
311
312        let mut content = String::new();
313        content.push_str("// Code generated by sqlcx. DO NOT EDIT.\npackage db\n");
314        content.push_str(&imports_str);
315        if !functions.is_empty() {
316            content.push('\n');
317            content.push_str(&functions.join("\n\n"));
318            content.push('\n');
319        }
320        content
321    }
322}
323
324impl DriverGenerator for DatabaseSqlGenerator {
325    fn generate(&self, ir: &SqlcxIR) -> Result<Vec<GeneratedFile>> {
326        let mut files = Vec::new();
327
328        // client.go
329        files.push(GeneratedFile {
330            path: "client.go".to_string(),
331            content: self.generate_client(),
332        });
333
334        // Group queries by source_file → one .queries.go per file
335        let mut grouped: BTreeMap<String, Vec<&QueryDef>> = BTreeMap::new();
336        for query in &ir.queries {
337            grouped
338                .entry(query.source_file.clone())
339                .or_default()
340                .push(query);
341        }
342        for (source_file, queries) in &grouped {
343            let basename = Path::new(source_file)
344                .file_stem()
345                .unwrap_or_default()
346                .to_string_lossy();
347            let owned: Vec<QueryDef> = queries.iter().map(|q| (*q).clone()).collect();
348            files.push(GeneratedFile {
349                path: format!("{}.queries.go", basename),
350                content: self.generate_query_file(&owned),
351            });
352        }
353
354        Ok(files)
355    }
356}
357
358// ── Tests ─────────────────────────────────────────────────────────────────────
359
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{ParamDef, SqlType, SqlTypeCategory};
    use crate::parser::DatabaseParser;
    use crate::parser::postgres::PostgresParser;

    const SCHEMA_SQL: &str = include_str!("../../../../../tests/fixtures/schema.sql");
    const USERS_SQL: &str = include_str!("../../../../../tests/fixtures/queries/users.sql");

    /// Parse the shared schema and `queries/users.sql` fixtures with the
    /// Postgres parser.
    fn parse_fixture_ir() -> SqlcxIR {
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(SCHEMA_SQL).unwrap();
        let queries = parser
            .parse_queries(USERS_SQL, &tables, &enums, "queries/users.sql")
            .unwrap();
        SqlcxIR {
            tables,
            queries,
            enums,
        }
    }

    /// Render all fixture queries as one query file using `generator`.
    fn render_fixture(generator: &DatabaseSqlGenerator) -> String {
        generator.generate_query_file(&parse_fixture_ir().queries)
    }

    /// A `SqlType` with no element/enum/JSON metadata whose raw and
    /// normalized names are both `name`.
    fn scalar_type(name: &str, category: SqlTypeCategory) -> SqlType {
        SqlType {
            raw: name.to_string(),
            normalized: name.to_string(),
            category,
            element_type: None,
            enum_name: None,
            enum_values: None,
            json_shape: None,
        }
    }

    /// A parameter bound to placeholder `$index`.
    fn param(index: u32, name: &str, sql_type: SqlType) -> ParamDef {
        ParamDef {
            index,
            name: name.to_string(),
            sql_type,
        }
    }

    /// A `:many` query from `q.sql` with no result columns.
    fn many_query(name: &str, sql: &str, params: Vec<ParamDef>) -> QueryDef {
        QueryDef {
            name: name.to_string(),
            command: QueryCommand::Many,
            sql: sql.to_string(),
            params,
            returns: Vec::new(),
            source_file: "q.sql".to_string(),
        }
    }

    #[test]
    fn generates_client_file() {
        let content = DatabaseSqlGenerator::postgres().generate_client();
        for needle in [
            "type DBTX interface",
            "type Queries struct",
            "func New(db DBTX) *Queries",
        ] {
            assert!(content.contains(needle), "missing `{needle}` in client.go");
        }
        insta::assert_snapshot!("go_database_sql_client", content);
    }

    #[test]
    fn generates_query_file() {
        let content = render_fixture(&DatabaseSqlGenerator::postgres());
        for func in ["GetUser", "ListUsers", "CreateUser", "DeleteUser"] {
            let signature = format!("func (q *Queries) {func}");
            assert!(content.contains(&signature), "missing `{signature}`");
        }
        // Postgres keeps $N placeholders.
        assert!(content.contains("WHERE id = $1"));
        insta::assert_snapshot!("go_database_sql_queries", content);
    }

    #[test]
    fn mysql_backend_rewrites_placeholders_to_question_marks() {
        let content = render_fixture(&DatabaseSqlGenerator::mysql());
        // MySQL: `$1` → `?`.
        assert!(content.contains("WHERE id = ?"));
        assert!(!content.contains("WHERE id = $1"));
        insta::assert_snapshot!("go_database_sql_mysql_queries", content);
    }

    #[test]
    fn sqlite_backend_rewrites_placeholders_to_question_marks() {
        let content = render_fixture(&DatabaseSqlGenerator::sqlite());
        assert!(content.contains("WHERE id = ?"));
        assert!(!content.contains("WHERE id = $1"));
        insta::assert_snapshot!("go_database_sql_sqlite_queries", content);
    }

    #[test]
    fn placeholder_rewrite_preserves_dollar_in_string_literals() {
        let (rewritten, indices) =
            DatabaseSqlBackend::MySql.rewrite_placeholders("SELECT '$1' FROM x WHERE a = $1");
        assert_eq!(rewritten, "SELECT '$1' FROM x WHERE a = ?");
        assert_eq!(indices, vec![1]);
    }

    #[test]
    fn reused_param_emits_repeated_args_in_mysql() {
        // WHERE x = $1 OR y = $1 must produce two `?` AND pass the arg twice.
        let query = many_query(
            "Search",
            "SELECT id FROM users WHERE name = $1 OR email = $1",
            vec![param(1, "q", scalar_type("text", SqlTypeCategory::String))],
        );
        let out = generate_query_function(DatabaseSqlBackend::MySql, &query);
        // Two `?` in the SQL const.
        assert_eq!(out.matches('?').count(), 2);
        // Two `q` args in the Query/Exec call.
        assert!(out.contains(", q, q)"), "expected `, q, q)` in: {out}");
    }

    #[test]
    fn out_of_order_params_bind_in_document_order_in_sqlite() {
        // WHERE b = $2 AND a = $1: the first ? binds the param with index 2,
        // the second binds the param with index 1.
        let int_type = scalar_type("integer", SqlTypeCategory::Number);
        let query = many_query(
            "Range",
            "SELECT id FROM t WHERE b = $2 AND a = $1",
            vec![param(1, "a", int_type.clone()), param(2, "b", int_type)],
        );
        let out = generate_query_function(DatabaseSqlBackend::Sqlite, &query);
        // Must be `, b, a)` (document order), NOT `, a, b)` (param-index order).
        assert!(out.contains(", b, a)"), "expected `, b, a)` in: {out}");
        assert!(!out.contains(", a, b)"));
    }
}
499        // Must be `, b, a)` (document order), NOT `, a, b)` (param-index order).
500        assert!(out.contains(", b, a)"), "expected `, b, a)` in: {out}");
501        assert!(!out.contains(", a, b)"));
502    }
503}