1use std::collections::BTreeMap;
4use std::collections::BTreeSet;
5use std::path::Path;
6
7use crate::error::Result;
8use crate::generator::go::common::{
9 escape_sql, func_params, generate_result_struct, generate_row_struct, query_args, scan_fields,
10 sql_const_name,
11};
12use crate::generator::{DriverGenerator, GeneratedFile};
13use crate::ir::{ColumnDef, QueryCommand, QueryDef, SqlcxIR};
14use crate::utils::pascal_case;
15
16use super::structs::go_imports_for_columns;
17
/// SQL dialect targeted by [`DatabaseSqlGenerator`].
///
/// The only dialect-specific behavior in this module is placeholder handling:
/// `Postgres` keeps `$N` placeholders as written, while `MySql` and `Sqlite` have
/// them rewritten to `?`.
#[derive(Clone, Copy, Debug)]
pub enum DatabaseSqlBackend {
    /// PostgreSQL (`$1`, `$2`, ... placeholders).
    Postgres,
    /// MySQL (`?` placeholders).
    MySql,
    /// SQLite (`?` placeholders).
    Sqlite,
}
29
30impl DatabaseSqlBackend {
31 fn rewrite_placeholders(self, sql: &str) -> (String, Vec<u32>) {
38 match self {
39 DatabaseSqlBackend::Postgres => (sql.to_string(), Vec::new()),
40 DatabaseSqlBackend::MySql | DatabaseSqlBackend::Sqlite => {
41 let mut out = String::with_capacity(sql.len());
42 let mut indices = Vec::new();
43 let mut chars = sql.chars().peekable();
44 let mut in_string = false;
45 while let Some(c) = chars.next() {
46 if c == '\'' {
47 if in_string && chars.peek() == Some(&'\'') {
48 out.push(c);
49 out.push(chars.next().unwrap());
50 continue;
51 }
52 in_string = !in_string;
53 out.push(c);
54 continue;
55 }
56 if !in_string && c == '$' && chars.peek().is_some_and(|ch| ch.is_ascii_digit())
57 {
58 let mut num = String::new();
59 while chars.peek().is_some_and(|ch| ch.is_ascii_digit()) {
60 num.push(chars.next().unwrap());
61 }
62 indices.push(num.parse::<u32>().unwrap_or(0));
63 out.push('?');
64 } else {
65 out.push(c);
66 }
67 }
68 (out, indices)
69 }
70 }
71 }
72}
73
74pub struct DatabaseSqlGenerator {
75 backend: DatabaseSqlBackend,
76}
77
78impl DatabaseSqlGenerator {
79 pub fn postgres() -> Self {
80 Self {
81 backend: DatabaseSqlBackend::Postgres,
82 }
83 }
84
85 pub fn mysql() -> Self {
86 Self {
87 backend: DatabaseSqlBackend::MySql,
88 }
89 }
90
91 pub fn sqlite() -> Self {
92 Self {
93 backend: DatabaseSqlBackend::Sqlite,
94 }
95 }
96}
97
98impl Default for DatabaseSqlGenerator {
99 fn default() -> Self {
100 Self::postgres()
101 }
102}
103
/// Builds the contents of the shared `client.go` file.
///
/// The file declares the `DBTX` interface (the context-aware exec/query methods that
/// the generated query functions call), the `Queries` wrapper around a `DBTX`, and the
/// `New` constructor. Lines are joined with `\n`, and there is no trailing newline.
fn generate_client() -> String {
    const CLIENT_GO: &[&str] = &[
        "// Code generated by sqlcx. DO NOT EDIT.",
        "package db",
        "",
        "import (",
        "\t\"context\"",
        "\t\"database/sql\"",
        ")",
        "",
        "type DBTX interface {",
        "\tExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)",
        "\tQueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)",
        "\tQueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row",
        "}",
        "",
        "type Queries struct {",
        "\tdb DBTX",
        "}",
        "",
        "func New(db DBTX) *Queries {",
        "\treturn &Queries{db: db}",
        "}",
    ];
    CLIENT_GO.join("\n")
}
130
131fn generate_query_function(backend: DatabaseSqlBackend, query: &QueryDef) -> String {
136 let const_name = sql_const_name(&query.name);
137 let func_name = pascal_case(&query.name);
138 let params = func_params(query);
139
140 let mut parts: Vec<String> = Vec::new();
141
142 let (rewritten_sql, occurrence_indices) = backend.rewrite_placeholders(&query.sql);
144 parts.push(format!(
145 "const {} = \"{}\"",
146 const_name,
147 escape_sql(&rewritten_sql),
148 ));
149
150 let args = if occurrence_indices.is_empty() {
154 query_args(query)
155 } else {
156 let names: Vec<String> = occurrence_indices
157 .iter()
158 .map(|idx| {
159 query
160 .params
161 .iter()
162 .find(|p| p.index == *idx)
163 .map(|p| p.name.clone())
164 .unwrap_or_else(|| "nil".to_string())
165 })
166 .collect();
167 if names.is_empty() {
168 String::new()
169 } else {
170 format!(", {}", names.join(", "))
171 }
172 };
173
174 match query.command {
175 QueryCommand::One => {
176 let row_type = format!("{}Row", pascal_case(&query.name));
177 if let Some(row_struct) = generate_row_struct(query) {
178 parts.push(row_struct);
179 }
180 let scans = scan_fields(&query.returns);
181 parts.push(format!(
182 "func (q *Queries) {}({}) (*{}, error) {{\n\
183 \trow := q.db.QueryRowContext(ctx, {}{})
184\tvar i {}
185\terr := row.Scan({})
186\tif err == sql.ErrNoRows {{
187\t\treturn nil, nil
188\t}}
189\tif err != nil {{
190\t\treturn nil, err
191\t}}
192\treturn &i, nil\n}}",
193 func_name, params, row_type, const_name, args, row_type, scans,
194 ));
195 }
196 QueryCommand::Many => {
197 let row_type = format!("{}Row", pascal_case(&query.name));
198 if let Some(row_struct) = generate_row_struct(query) {
199 parts.push(row_struct);
200 }
201 let scans = scan_fields(&query.returns);
202 parts.push(format!(
203 "func (q *Queries) {}({}) ([]{}, error) {{\n\
204 \trows, err := q.db.QueryContext(ctx, {}{})
205\tif err != nil {{
206\t\treturn nil, err
207\t}}
208\tdefer rows.Close()
209\tvar items []{}
210\tfor rows.Next() {{
211\t\tvar i {}
212\t\tif err := rows.Scan({}); err != nil {{
213\t\t\treturn nil, err
214\t\t}}
215\t\titems = append(items, i)
216\t}}
217\treturn items, rows.Err()\n}}",
218 func_name, params, row_type, const_name, args, row_type, row_type, scans,
219 ));
220 }
221 QueryCommand::Exec => {
222 parts.push(format!(
223 "func (q *Queries) {}({}) error {{\n\
224 \t_, err := q.db.ExecContext(ctx, {}{})\n\treturn err\n}}",
225 func_name, params, const_name, args,
226 ));
227 }
228 QueryCommand::ExecResult => {
229 let result_type = format!("{}Result", pascal_case(&query.name));
230 parts.push(generate_result_struct(query));
231 parts.push(format!(
232 "func (q *Queries) {}({}) (*{}, error) {{\n\
233 \tresult, err := q.db.ExecContext(ctx, {}{})
234\tif err != nil {{
235\t\treturn nil, err
236\t}}
237\taffected, err := result.RowsAffected()
238\tif err != nil {{
239\t\treturn nil, err
240\t}}
241\treturn &{}{{RowsAffected: affected}}, nil\n}}",
242 func_name, params, result_type, const_name, args, result_type,
243 ));
244 }
245 }
246
247 parts.join("\n\n")
248}
249
250fn collect_query_imports(queries: &[QueryDef]) -> BTreeSet<String> {
252 let mut imports = BTreeSet::new();
253 imports.insert("context".to_string());
254
255 let mut needs_database_sql = false;
256 for query in queries {
257 let ret_imports = go_imports_for_columns(&query.returns);
259 imports.extend(ret_imports);
260
261 match query.command {
263 QueryCommand::One | QueryCommand::ExecResult => {
264 needs_database_sql = true;
265 }
266 _ => {}
267 }
268
269 for param in &query.params {
271 let col = ColumnDef {
272 name: param.name.clone(),
273 alias: None,
274 source_table: None,
275 sql_type: param.sql_type.clone(),
276 nullable: false,
277 has_default: false,
278 };
279 let col_imports = go_imports_for_columns(&[col]);
280 imports.extend(col_imports);
281 }
282 }
283
284 if needs_database_sql {
285 imports.insert("database/sql".to_string());
286 }
287
288 imports
289}
290
291impl DatabaseSqlGenerator {
294 pub fn generate_client(&self) -> String {
295 generate_client()
296 }
297
298 pub fn generate_query_file(&self, queries: &[QueryDef]) -> String {
299 let imports = collect_query_imports(queries);
300 let imports_str = if imports.is_empty() {
301 String::new()
302 } else {
303 let lines: Vec<String> = imports.iter().map(|i| format!("\t\"{}\"", i)).collect();
304 format!("\nimport (\n{}\n)\n", lines.join("\n"))
305 };
306
307 let functions: Vec<String> = queries
308 .iter()
309 .map(|q| generate_query_function(self.backend, q))
310 .collect();
311
312 let mut content = String::new();
313 content.push_str("// Code generated by sqlcx. DO NOT EDIT.\npackage db\n");
314 content.push_str(&imports_str);
315 if !functions.is_empty() {
316 content.push('\n');
317 content.push_str(&functions.join("\n\n"));
318 content.push('\n');
319 }
320 content
321 }
322}
323
324impl DriverGenerator for DatabaseSqlGenerator {
325 fn generate(&self, ir: &SqlcxIR) -> Result<Vec<GeneratedFile>> {
326 let mut files = Vec::new();
327
328 files.push(GeneratedFile {
330 path: "client.go".to_string(),
331 content: self.generate_client(),
332 });
333
334 let mut grouped: BTreeMap<String, Vec<&QueryDef>> = BTreeMap::new();
336 for query in &ir.queries {
337 grouped
338 .entry(query.source_file.clone())
339 .or_default()
340 .push(query);
341 }
342 for (source_file, queries) in &grouped {
343 let basename = Path::new(source_file)
344 .file_stem()
345 .unwrap_or_default()
346 .to_string_lossy();
347 let owned: Vec<QueryDef> = queries.iter().map(|q| (*q).clone()).collect();
348 files.push(GeneratedFile {
349 path: format!("{}.queries.go", basename),
350 content: self.generate_query_file(&owned),
351 });
352 }
353
354 Ok(files)
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{ParamDef, SqlType, SqlTypeCategory};
    use crate::parser::DatabaseParser;
    use crate::parser::postgres::PostgresParser;

    /// Shared schema fixture.
    const SCHEMA_SQL: &str = include_str!("../../../../../tests/fixtures/schema.sql");
    /// Query fixture, registered under the source path `queries/users.sql`.
    const USERS_QUERIES_SQL: &str =
        include_str!("../../../../../tests/fixtures/queries/users.sql");

    /// Parses the schema and `users.sql` fixtures with the Postgres parser into an IR.
    fn parse_fixture_ir() -> SqlcxIR {
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(SCHEMA_SQL).unwrap();
        let queries = parser
            .parse_queries(USERS_QUERIES_SQL, &tables, &enums, "queries/users.sql")
            .unwrap();
        SqlcxIR {
            tables,
            queries,
            enums,
        }
    }

    /// Builds a `SqlType` whose raw and normalized names are both `name`, with no
    /// element, enum, or JSON metadata.
    fn scalar_type(name: &str, category: SqlTypeCategory) -> SqlType {
        SqlType {
            raw: name.to_string(),
            normalized: name.to_string(),
            category,
            element_type: None,
            enum_name: None,
            enum_values: None,
            json_shape: None,
        }
    }

    /// Builds a parameter bound to placeholder `$index`.
    fn param(index: u32, name: &str, sql_type: SqlType) -> ParamDef {
        ParamDef {
            index,
            name: name.to_string(),
            sql_type,
        }
    }

    /// Builds a `QueryCommand::Many` query from `q.sql` with no result columns.
    fn many_query(name: &str, sql: &str, params: Vec<ParamDef>) -> QueryDef {
        QueryDef {
            name: name.to_string(),
            command: QueryCommand::Many,
            sql: sql.to_string(),
            params,
            returns: vec![],
            source_file: "q.sql".to_string(),
        }
    }

    #[test]
    fn generates_client_file() {
        let content = DatabaseSqlGenerator::postgres().generate_client();
        assert!(content.contains("type DBTX interface"));
        assert!(content.contains("type Queries struct"));
        assert!(content.contains("func New(db DBTX) *Queries"));
        insta::assert_snapshot!("go_database_sql_client", content);
    }

    #[test]
    fn generates_query_file() {
        let ir = parse_fixture_ir();
        let content = DatabaseSqlGenerator::postgres().generate_query_file(&ir.queries);
        assert!(content.contains("func (q *Queries) GetUser"));
        assert!(content.contains("func (q *Queries) ListUsers"));
        assert!(content.contains("func (q *Queries) CreateUser"));
        assert!(content.contains("func (q *Queries) DeleteUser"));
        assert!(content.contains("WHERE id = $1"));
        insta::assert_snapshot!("go_database_sql_queries", content);
    }

    #[test]
    fn mysql_backend_rewrites_placeholders_to_question_marks() {
        let ir = parse_fixture_ir();
        let content = DatabaseSqlGenerator::mysql().generate_query_file(&ir.queries);
        assert!(content.contains("WHERE id = ?"));
        assert!(!content.contains("WHERE id = $1"));
        insta::assert_snapshot!("go_database_sql_mysql_queries", content);
    }

    #[test]
    fn sqlite_backend_rewrites_placeholders_to_question_marks() {
        let ir = parse_fixture_ir();
        let content = DatabaseSqlGenerator::sqlite().generate_query_file(&ir.queries);
        assert!(content.contains("WHERE id = ?"));
        assert!(!content.contains("WHERE id = $1"));
        insta::assert_snapshot!("go_database_sql_sqlite_queries", content);
    }

    #[test]
    fn placeholder_rewrite_preserves_dollar_in_string_literals() {
        let sql = "SELECT '$1' FROM x WHERE a = $1";
        let (rewritten, idx) = DatabaseSqlBackend::MySql.rewrite_placeholders(sql);
        assert_eq!(rewritten, "SELECT '$1' FROM x WHERE a = ?");
        assert_eq!(idx, vec![1]);
    }

    #[test]
    fn reused_param_emits_repeated_args_in_mysql() {
        let query = many_query(
            "Search",
            "SELECT id FROM users WHERE name = $1 OR email = $1",
            vec![param(1, "q", scalar_type("text", SqlTypeCategory::String))],
        );
        let out = generate_query_function(DatabaseSqlBackend::MySql, &query);
        assert_eq!(out.matches('?').count(), 2);
        assert!(out.contains(", q, q)"), "expected `, q, q)` in: {out}");
    }

    #[test]
    fn out_of_order_params_bind_in_document_order_in_sqlite() {
        let int_type = scalar_type("integer", SqlTypeCategory::Number);
        let query = many_query(
            "Range",
            "SELECT id FROM t WHERE b = $2 AND a = $1",
            vec![param(1, "a", int_type.clone()), param(2, "b", int_type)],
        );
        let out = generate_query_function(DatabaseSqlBackend::Sqlite, &query);
        assert!(out.contains(", b, a)"), "expected `, b, a)` in: {out}");
        assert!(!out.contains(", a, b)"));
    }
}