1use std::collections::BTreeMap;
4use std::collections::BTreeSet;
5use std::path::Path;
6
7use crate::error::Result;
8use crate::generator::go::common::{
9 escape_sql, func_params, generate_result_struct, generate_row_struct, query_args, scan_fields,
10 sql_const_name,
11};
12use crate::generator::{DriverGenerator, GeneratedFile};
13use crate::ir::{ColumnDef, QueryCommand, QueryDef, SqlcxIR};
14use crate::utils::pascal_case;
15
16use super::structs::go_imports_for_columns;
17
/// Go code generator targeting the `github.com/jackc/pgx/v5` driver.
///
/// Produces a shared `client.go` (the `DBTX` interface, the `Queries` wrapper
/// and its `New` constructor) plus one `<stem>.queries.go` file per SQL source
/// file. The generator is stateless, so it is a cheap `Copy` unit value.
#[derive(Debug, Clone, Copy, Default)]
pub struct PgxGenerator;
19
/// Returns the Go source of `client.go` shared by every generated query file.
///
/// It declares the `DBTX` interface (the `Exec`/`Query`/`QueryRow` subset of
/// pgx used by generated methods), the `Queries` struct that wraps a `DBTX`,
/// and the `New` constructor. The text has no trailing newline.
fn generate_client() -> String {
    // One entry per Go source line; Go indentation is an explicit `\t`.
    let go_lines = [
        "// Code generated by sqlcx. DO NOT EDIT.",
        "package db",
        "",
        "import (",
        "\t\"context\"",
        "",
        "\t\"github.com/jackc/pgx/v5\"",
        "\t\"github.com/jackc/pgx/v5/pgconn\"",
        ")",
        "",
        "type DBTX interface {",
        "\tExec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)",
        "\tQuery(ctx context.Context, sql string, args ...any) (pgx.Rows, error)",
        "\tQueryRow(ctx context.Context, sql string, args ...any) pgx.Row",
        "}",
        "",
        "type Queries struct {",
        "\tdb DBTX",
        "}",
        "",
        "func New(db DBTX) *Queries {",
        "\treturn &Queries{db: db}",
        "}",
    ];
    go_lines.join("\n")
}
46
47fn generate_query_function(query: &QueryDef) -> String {
48 let const_name = sql_const_name(&query.name);
49 let func_name = pascal_case(&query.name);
50 let params = func_params(query);
51 let args = query_args(query);
52
53 let mut parts: Vec<String> = Vec::new();
54
55 parts.push(format!(
56 "const {} = \"{}\"",
57 const_name,
58 escape_sql(&query.sql),
59 ));
60
61 match query.command {
62 QueryCommand::One => {
63 let row_type = format!("{}Row", pascal_case(&query.name));
64 if let Some(row_struct) = generate_row_struct(query) {
65 parts.push(row_struct);
66 }
67 let scans = scan_fields(&query.returns);
68 parts.push(format!(
69 "func (q *Queries) {}({}) (*{}, error) {{\n\
70 \trow := q.db.QueryRow(ctx, {}{})
71\tvar i {}
72\terr := row.Scan({})
73\tif err != nil {{
74\t\tif errors.Is(err, pgx.ErrNoRows) {{
75\t\t\treturn nil, nil
76\t\t}}
77\t\treturn nil, err
78\t}}
79\treturn &i, nil\n}}",
80 func_name, params, row_type, const_name, args, row_type, scans,
81 ));
82 }
83 QueryCommand::Many => {
84 let row_type = format!("{}Row", pascal_case(&query.name));
85 if let Some(row_struct) = generate_row_struct(query) {
86 parts.push(row_struct);
87 }
88 let scans = scan_fields(&query.returns);
89 parts.push(format!(
90 "func (q *Queries) {}({}) ([]{}, error) {{\n\
91 \trows, err := q.db.Query(ctx, {}{})
92\tif err != nil {{
93\t\treturn nil, err
94\t}}
95\tdefer rows.Close()
96\tvar items []{}
97\tfor rows.Next() {{
98\t\tvar i {}
99\t\tif err := rows.Scan({}); err != nil {{
100\t\t\treturn nil, err
101\t\t}}
102\t\titems = append(items, i)
103\t}}
104\treturn items, rows.Err()\n}}",
105 func_name, params, row_type, const_name, args, row_type, row_type, scans,
106 ));
107 }
108 QueryCommand::Exec => {
109 parts.push(format!(
110 "func (q *Queries) {}({}) error {{\n\
111 \t_, err := q.db.Exec(ctx, {}{})\n\treturn err\n}}",
112 func_name, params, const_name, args,
113 ));
114 }
115 QueryCommand::ExecResult => {
116 let result_type = format!("{}Result", pascal_case(&query.name));
117 parts.push(generate_result_struct(query));
118 parts.push(format!(
119 "func (q *Queries) {}({}) (*{}, error) {{\n\
120 \ttag, err := q.db.Exec(ctx, {}{})
121\tif err != nil {{
122\t\treturn nil, err
123\t}}
124\treturn &{}{{RowsAffected: tag.RowsAffected()}}, nil\n}}",
125 func_name, params, result_type, const_name, args, result_type,
126 ));
127 }
128 }
129
130 parts.join("\n\n")
131}
132
133fn collect_query_imports(queries: &[QueryDef]) -> BTreeSet<String> {
134 let mut imports = BTreeSet::new();
135 imports.insert("context".to_string());
136
137 let mut needs_pgx_err = false;
138 for query in queries {
139 let ret_imports = go_imports_for_columns(&query.returns);
140 imports.extend(ret_imports);
141
142 if query.command == QueryCommand::One {
143 needs_pgx_err = true;
144 }
145
146 for param in &query.params {
147 let col = ColumnDef {
148 name: param.name.clone(),
149 alias: None,
150 source_table: None,
151 sql_type: param.sql_type.clone(),
152 nullable: false,
153 has_default: false,
154 };
155 let col_imports = go_imports_for_columns(&[col]);
156 imports.extend(col_imports);
157 }
158 }
159
160 if needs_pgx_err {
161 imports.insert("errors".to_string());
162 imports.insert("github.com/jackc/pgx/v5".to_string());
163 }
164
165 imports
166}
167
168impl PgxGenerator {
169 pub fn generate_client(&self) -> String {
170 generate_client()
171 }
172
173 pub fn generate_query_file(&self, queries: &[QueryDef]) -> String {
174 let imports = collect_query_imports(queries);
175 let imports_str = if imports.is_empty() {
176 String::new()
177 } else {
178 let lines: Vec<String> = imports.iter().map(|i| format!("\t\"{}\"", i)).collect();
179 format!("\nimport (\n{}\n)\n", lines.join("\n"))
180 };
181
182 let functions: Vec<String> = queries.iter().map(generate_query_function).collect();
183
184 let mut content = String::new();
185 content.push_str("// Code generated by sqlcx. DO NOT EDIT.\npackage db\n");
186 content.push_str(&imports_str);
187 if !functions.is_empty() {
188 content.push('\n');
189 content.push_str(&functions.join("\n\n"));
190 content.push('\n');
191 }
192 content
193 }
194}
195
196impl DriverGenerator for PgxGenerator {
197 fn generate(&self, ir: &SqlcxIR) -> Result<Vec<GeneratedFile>> {
198 let mut files = Vec::new();
199
200 files.push(GeneratedFile {
201 path: "client.go".to_string(),
202 content: self.generate_client(),
203 });
204
205 let mut grouped: BTreeMap<String, Vec<&QueryDef>> = BTreeMap::new();
206 for query in &ir.queries {
207 grouped
208 .entry(query.source_file.clone())
209 .or_default()
210 .push(query);
211 }
212 for (source_file, queries) in &grouped {
213 let basename = Path::new(source_file)
214 .file_stem()
215 .unwrap_or_default()
216 .to_string_lossy();
217 let owned: Vec<QueryDef> = queries.iter().map(|q| (*q).clone()).collect();
218 files.push(GeneratedFile {
219 path: format!("{}.queries.go", basename),
220 content: self.generate_query_file(&owned),
221 });
222 }
223
224 Ok(files)
225 }
226}
227
228#[cfg(test)]
229mod tests {
230 use super::*;
231 use crate::parser::DatabaseParser;
232 use crate::parser::postgres::PostgresParser;
233
234 fn parse_fixture_ir() -> SqlcxIR {
235 let schema_sql = include_str!("../../../../../tests/fixtures/schema.sql");
236 let queries_sql = include_str!("../../../../../tests/fixtures/queries/users.sql");
237 let parser = PostgresParser::new();
238 let (tables, enums) = parser.parse_schema(schema_sql).unwrap();
239 let queries = parser
240 .parse_queries(queries_sql, &tables, &enums, "queries/users.sql")
241 .unwrap();
242 SqlcxIR {
243 tables,
244 queries,
245 enums,
246 }
247 }
248
249 #[test]
250 fn generates_client_file() {
251 let gen_ = PgxGenerator;
252 let content = gen_.generate_client();
253 assert!(content.contains("github.com/jackc/pgx/v5"));
254 assert!(content.contains("type DBTX interface"));
255 insta::assert_snapshot!("go_pgx_client", content);
256 }
257
258 #[test]
259 fn generates_query_file() {
260 let ir = parse_fixture_ir();
261 let gen_ = PgxGenerator;
262 let content = gen_.generate_query_file(&ir.queries);
263 assert!(content.contains("func (q *Queries) GetUser"));
264 assert!(content.contains("func (q *Queries) ListUsers"));
265 insta::assert_snapshot!("go_pgx_queries", content);
266 }
267}