Skip to main content

sqlcx_core/generator/go/
pgx.rs

1// pgx (github.com/jackc/pgx/v5) driver generator for Go
2
3use std::collections::BTreeMap;
4use std::collections::BTreeSet;
5use std::path::Path;
6
7use crate::error::Result;
8use crate::generator::go::common::{
9    escape_sql, func_params, generate_result_struct, generate_row_struct, query_args, scan_fields,
10    sql_const_name,
11};
12use crate::generator::{DriverGenerator, GeneratedFile};
13use crate::ir::{ColumnDef, QueryCommand, QueryDef, SqlcxIR};
14use crate::utils::pascal_case;
15
16use super::structs::go_imports_for_columns;
17
/// Driver generator that emits Go code targeting pgx (`github.com/jackc/pgx/v5`):
/// a shared `client.go` plus one `<stem>.queries.go` file per group of queries.
pub struct PgxGenerator;
19
/// Returns the fixed contents of `client.go`.
///
/// The file declares the `DBTX` interface (the `Exec` / `Query` / `QueryRow`
/// methods that the generated query methods call on `q.db`), the `Queries`
/// wrapper holding a `DBTX`, and the `New` constructor. Go indentation uses
/// tabs; the text has no trailing newline.
fn generate_client() -> String {
    const CLIENT_GO: &str = concat!(
        "// Code generated by sqlcx. DO NOT EDIT.\n",
        "package db\n",
        "\n",
        "import (\n",
        "\t\"context\"\n",
        "\n",
        "\t\"github.com/jackc/pgx/v5\"\n",
        "\t\"github.com/jackc/pgx/v5/pgconn\"\n",
        ")\n",
        "\n",
        "type DBTX interface {\n",
        "\tExec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)\n",
        "\tQuery(ctx context.Context, sql string, args ...any) (pgx.Rows, error)\n",
        "\tQueryRow(ctx context.Context, sql string, args ...any) pgx.Row\n",
        "}\n",
        "\n",
        "type Queries struct {\n",
        "\tdb DBTX\n",
        "}\n",
        "\n",
        "func New(db DBTX) *Queries {\n",
        "\treturn &Queries{db: db}\n",
        "}",
    );
    String::from(CLIENT_GO)
}
46
47fn generate_query_function(query: &QueryDef) -> String {
48    let const_name = sql_const_name(&query.name);
49    let func_name = pascal_case(&query.name);
50    let params = func_params(query);
51    let args = query_args(query);
52
53    let mut parts: Vec<String> = Vec::new();
54
55    parts.push(format!(
56        "const {} = \"{}\"",
57        const_name,
58        escape_sql(&query.sql),
59    ));
60
61    match query.command {
62        QueryCommand::One => {
63            let row_type = format!("{}Row", pascal_case(&query.name));
64            if let Some(row_struct) = generate_row_struct(query) {
65                parts.push(row_struct);
66            }
67            let scans = scan_fields(&query.returns);
68            parts.push(format!(
69                "func (q *Queries) {}({}) (*{}, error) {{\n\
70                \trow := q.db.QueryRow(ctx, {}{})
71\tvar i {}
72\terr := row.Scan({})
73\tif err != nil {{
74\t\tif errors.Is(err, pgx.ErrNoRows) {{
75\t\t\treturn nil, nil
76\t\t}}
77\t\treturn nil, err
78\t}}
79\treturn &i, nil\n}}",
80                func_name, params, row_type, const_name, args, row_type, scans,
81            ));
82        }
83        QueryCommand::Many => {
84            let row_type = format!("{}Row", pascal_case(&query.name));
85            if let Some(row_struct) = generate_row_struct(query) {
86                parts.push(row_struct);
87            }
88            let scans = scan_fields(&query.returns);
89            parts.push(format!(
90                "func (q *Queries) {}({}) ([]{}, error) {{\n\
91                \trows, err := q.db.Query(ctx, {}{})
92\tif err != nil {{
93\t\treturn nil, err
94\t}}
95\tdefer rows.Close()
96\tvar items []{}
97\tfor rows.Next() {{
98\t\tvar i {}
99\t\tif err := rows.Scan({}); err != nil {{
100\t\t\treturn nil, err
101\t\t}}
102\t\titems = append(items, i)
103\t}}
104\treturn items, rows.Err()\n}}",
105                func_name, params, row_type, const_name, args, row_type, row_type, scans,
106            ));
107        }
108        QueryCommand::Exec => {
109            parts.push(format!(
110                "func (q *Queries) {}({}) error {{\n\
111                \t_, err := q.db.Exec(ctx, {}{})\n\treturn err\n}}",
112                func_name, params, const_name, args,
113            ));
114        }
115        QueryCommand::ExecResult => {
116            let result_type = format!("{}Result", pascal_case(&query.name));
117            parts.push(generate_result_struct(query));
118            parts.push(format!(
119                "func (q *Queries) {}({}) (*{}, error) {{\n\
120                \ttag, err := q.db.Exec(ctx, {}{})
121\tif err != nil {{
122\t\treturn nil, err
123\t}}
124\treturn &{}{{RowsAffected: tag.RowsAffected()}}, nil\n}}",
125                func_name, params, result_type, const_name, args, result_type,
126            ));
127        }
128    }
129
130    parts.join("\n\n")
131}
132
133fn collect_query_imports(queries: &[QueryDef]) -> BTreeSet<String> {
134    let mut imports = BTreeSet::new();
135    imports.insert("context".to_string());
136
137    let mut needs_pgx_err = false;
138    for query in queries {
139        let ret_imports = go_imports_for_columns(&query.returns);
140        imports.extend(ret_imports);
141
142        if query.command == QueryCommand::One {
143            needs_pgx_err = true;
144        }
145
146        for param in &query.params {
147            let col = ColumnDef {
148                name: param.name.clone(),
149                alias: None,
150                source_table: None,
151                sql_type: param.sql_type.clone(),
152                nullable: false,
153                has_default: false,
154            };
155            let col_imports = go_imports_for_columns(&[col]);
156            imports.extend(col_imports);
157        }
158    }
159
160    if needs_pgx_err {
161        imports.insert("errors".to_string());
162        imports.insert("github.com/jackc/pgx/v5".to_string());
163    }
164
165    imports
166}
167
168impl PgxGenerator {
169    pub fn generate_client(&self) -> String {
170        generate_client()
171    }
172
173    pub fn generate_query_file(&self, queries: &[QueryDef]) -> String {
174        let imports = collect_query_imports(queries);
175        let imports_str = if imports.is_empty() {
176            String::new()
177        } else {
178            let lines: Vec<String> = imports.iter().map(|i| format!("\t\"{}\"", i)).collect();
179            format!("\nimport (\n{}\n)\n", lines.join("\n"))
180        };
181
182        let functions: Vec<String> = queries.iter().map(generate_query_function).collect();
183
184        let mut content = String::new();
185        content.push_str("// Code generated by sqlcx. DO NOT EDIT.\npackage db\n");
186        content.push_str(&imports_str);
187        if !functions.is_empty() {
188            content.push('\n');
189            content.push_str(&functions.join("\n\n"));
190            content.push('\n');
191        }
192        content
193    }
194}
195
196impl DriverGenerator for PgxGenerator {
197    fn generate(&self, ir: &SqlcxIR) -> Result<Vec<GeneratedFile>> {
198        let mut files = Vec::new();
199
200        files.push(GeneratedFile {
201            path: "client.go".to_string(),
202            content: self.generate_client(),
203        });
204
205        let mut grouped: BTreeMap<String, Vec<&QueryDef>> = BTreeMap::new();
206        for query in &ir.queries {
207            grouped
208                .entry(query.source_file.clone())
209                .or_default()
210                .push(query);
211        }
212        for (source_file, queries) in &grouped {
213            let basename = Path::new(source_file)
214                .file_stem()
215                .unwrap_or_default()
216                .to_string_lossy();
217            let owned: Vec<QueryDef> = queries.iter().map(|q| (*q).clone()).collect();
218            files.push(GeneratedFile {
219                path: format!("{}.queries.go", basename),
220                content: self.generate_query_file(&owned),
221            });
222        }
223
224        Ok(files)
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::DatabaseParser;
    use crate::parser::postgres::PostgresParser;

    /// Parses the shared Postgres fixtures (schema plus `queries/users.sql`) into an IR.
    fn parse_fixture_ir() -> SqlcxIR {
        const SCHEMA_SQL: &str = include_str!("../../../../../tests/fixtures/schema.sql");
        const USERS_SQL: &str = include_str!("../../../../../tests/fixtures/queries/users.sql");

        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(SCHEMA_SQL).unwrap();
        let queries = parser
            .parse_queries(USERS_SQL, &tables, &enums, "queries/users.sql")
            .unwrap();

        SqlcxIR {
            tables,
            queries,
            enums,
        }
    }

    #[test]
    fn generates_client_file() {
        let client = PgxGenerator.generate_client();
        for needle in ["github.com/jackc/pgx/v5", "type DBTX interface"] {
            assert!(client.contains(needle));
        }
        insta::assert_snapshot!("go_pgx_client", client);
    }

    #[test]
    fn generates_query_file() {
        let fixture = parse_fixture_ir();
        let rendered = PgxGenerator.generate_query_file(&fixture.queries);
        for needle in ["func (q *Queries) GetUser", "func (q *Queries) ListUsers"] {
            assert!(rendered.contains(needle));
        }
        insta::assert_snapshot!("go_pgx_queries", rendered);
    }
}