// scythe_codegen/backends/go_pgx.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/go-pgx.toml");
17
18pub struct GoPgxBackend {
19    manifest: BackendManifest,
20}
21
22impl GoPgxBackend {
23    pub fn new() -> Result<Self, ScytheError> {
24        let manifest_path = Path::new("backends/go-pgx/manifest.toml");
25        let manifest = if manifest_path.exists() {
26            load_manifest(manifest_path)
27                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
28        } else {
29            toml::from_str(DEFAULT_MANIFEST_TOML)
30                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
31        };
32        Ok(Self { manifest })
33    }
34
35    pub fn manifest(&self) -> &BackendManifest {
36        &self.manifest
37    }
38}
39
40impl CodegenBackend for GoPgxBackend {
41    fn name(&self) -> &str {
42        "go-pgx"
43    }
44
45    fn file_header(&self) -> String {
46        "package queries\n\nimport (\n\t\"context\"\n\n\t\"github.com/jackc/pgx/v5/pgxpool\"\n)\n"
47            .to_string()
48    }
49
50    fn generate_row_struct(
51        &self,
52        query_name: &str,
53        columns: &[ResolvedColumn],
54    ) -> Result<String, ScytheError> {
55        let struct_name = row_struct_name(query_name, &self.manifest.naming);
56        let mut out = String::new();
57        let _ = writeln!(out, "type {} struct {{", struct_name);
58        for col in columns {
59            let field = to_pascal_case(&col.field_name);
60            let json_tag = &col.field_name;
61            let _ = writeln!(out, "\t{} {} `json:\"{}\"`", field, col.full_type, json_tag);
62        }
63        let _ = write!(out, "}}");
64        Ok(out)
65    }
66
67    fn generate_model_struct(
68        &self,
69        table_name: &str,
70        columns: &[ResolvedColumn],
71    ) -> Result<String, ScytheError> {
72        let name = to_pascal_case(table_name);
73        self.generate_row_struct(&name, columns)
74    }
75
76    fn generate_query_fn(
77        &self,
78        analyzed: &AnalyzedQuery,
79        struct_name: &str,
80        columns: &[ResolvedColumn],
81        params: &[ResolvedParam],
82    ) -> Result<String, ScytheError> {
83        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
84        let sql = super::clean_sql_oneline(&analyzed.sql);
85
86        let param_list = params
87            .iter()
88            .map(|p| {
89                let field = to_pascal_case(&p.field_name);
90                format!("{} {}", field, p.full_type)
91            })
92            .collect::<Vec<_>>()
93            .join(", ");
94        let sep = if param_list.is_empty() { "" } else { ", " };
95
96        let args = params
97            .iter()
98            .map(|p| to_pascal_case(&p.field_name).into_owned())
99            .collect::<Vec<_>>();
100
101        let mut out = String::new();
102
103        match &analyzed.command {
104            QueryCommand::Exec => {
105                // :exec - returns error only
106                let _ = writeln!(
107                    out,
108                    "func {}(ctx context.Context, db *pgxpool.Pool{}{}) error {{",
109                    func_name, sep, param_list
110                );
111                let args_str = if args.is_empty() {
112                    String::new()
113                } else {
114                    format!(", {}", args.join(", "))
115                };
116                let _ = writeln!(out, "\t_, err := db.Exec(ctx, \"{}\"{})", sql, args_str);
117                let _ = writeln!(out, "\treturn err");
118                let _ = write!(out, "}}");
119            }
120            QueryCommand::ExecResult | QueryCommand::ExecRows => {
121                // :exec_rows - returns affected row count
122                let _ = writeln!(
123                    out,
124                    "func {}(ctx context.Context, db *pgxpool.Pool{}{}) (int64, error) {{",
125                    func_name, sep, param_list
126                );
127                let args_str = if args.is_empty() {
128                    String::new()
129                } else {
130                    format!(", {}", args.join(", "))
131                };
132                let _ = writeln!(
133                    out,
134                    "\tresult, err := db.Exec(ctx, \"{}\"{})",
135                    sql, args_str
136                );
137                let _ = writeln!(out, "\tif err != nil {{");
138                let _ = writeln!(out, "\t\treturn 0, err");
139                let _ = writeln!(out, "\t}}");
140                let _ = writeln!(out, "\treturn result.RowsAffected(), nil");
141                let _ = write!(out, "}}");
142            }
143            QueryCommand::One => {
144                // :one - returns single struct
145                let _ = writeln!(
146                    out,
147                    "func {}(ctx context.Context, db *pgxpool.Pool{}{}) ({}, error) {{",
148                    func_name, sep, param_list, struct_name
149                );
150                let args_str = if args.is_empty() {
151                    String::new()
152                } else {
153                    format!(", {}", args.join(", "))
154                };
155                let _ = writeln!(out, "\trow := db.QueryRow(ctx, \"{}\"{})", sql, args_str);
156                let _ = writeln!(out, "\tvar r {}", struct_name);
157                let scan_fields: Vec<String> = columns
158                    .iter()
159                    .map(|c| format!("&r.{}", to_pascal_case(&c.field_name)))
160                    .collect();
161                let _ = writeln!(out, "\terr := row.Scan({})", scan_fields.join(", "));
162                let _ = writeln!(out, "\treturn r, err");
163                let _ = write!(out, "}}");
164            }
165            QueryCommand::Many | QueryCommand::Batch => {
166                // :many - returns slice
167                let _ = writeln!(
168                    out,
169                    "func {}(ctx context.Context, db *pgxpool.Pool{}{}) ([]{}, error) {{",
170                    func_name, sep, param_list, struct_name
171                );
172                let args_str = if args.is_empty() {
173                    String::new()
174                } else {
175                    format!(", {}", args.join(", "))
176                };
177                let _ = writeln!(out, "\trows, err := db.Query(ctx, \"{}\"{})", sql, args_str);
178                let _ = writeln!(out, "\tif err != nil {{");
179                let _ = writeln!(out, "\t\treturn nil, err");
180                let _ = writeln!(out, "\t}}");
181                let _ = writeln!(out, "\tdefer rows.Close()");
182                let _ = writeln!(out, "\tvar result []{}", struct_name);
183                let _ = writeln!(out, "\tfor rows.Next() {{");
184                let _ = writeln!(out, "\t\tvar r {}", struct_name);
185                let scan_fields: Vec<String> = columns
186                    .iter()
187                    .map(|c| format!("&r.{}", to_pascal_case(&c.field_name)))
188                    .collect();
189                let _ = writeln!(
190                    out,
191                    "\t\tif err := rows.Scan({}); err != nil {{",
192                    scan_fields.join(", ")
193                );
194                let _ = writeln!(out, "\t\t\treturn nil, err");
195                let _ = writeln!(out, "\t\t}}");
196                let _ = writeln!(out, "\t\tresult = append(result, r)");
197                let _ = writeln!(out, "\t}}");
198                let _ = writeln!(out, "\treturn result, rows.Err()");
199                let _ = write!(out, "}}");
200            }
201        }
202
203        Ok(out)
204    }
205
206    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
207        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
208        let mut out = String::new();
209        let _ = writeln!(out, "type {} string", type_name);
210        let _ = writeln!(out);
211        let _ = writeln!(out, "const (");
212        for value in &enum_info.values {
213            let variant = enum_variant_name(value, &self.manifest.naming);
214            let _ = writeln!(
215                out,
216                "\t{}{} {} = \"{}\"",
217                type_name, variant, type_name, value
218            );
219        }
220        let _ = write!(out, ")");
221        Ok(out)
222    }
223
224    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
225        let name = to_pascal_case(&composite.sql_name);
226        let mut out = String::new();
227        let _ = writeln!(out, "type {} struct {{", name);
228        if composite.fields.is_empty() {
229            // empty struct
230        } else {
231            for field in &composite.fields {
232                let field_name = to_pascal_case(&field.name);
233                let go_type = resolve_type(&field.neutral_type, &self.manifest, false)
234                    .map(|t| t.into_owned())
235                    .unwrap_or_else(|_| "interface{}".to_string());
236                let json_tag = &field.name;
237                let _ = writeln!(out, "\t{} {} `json:\"{}\"`", field_name, go_type, json_tag);
238            }
239        }
240        let _ = write!(out, "}}");
241        Ok(out)
242    }
243}