1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
/// Go/pgx backend manifest embedded into the binary at compile time.
///
/// `GoPgxBackend::new` parses this as a fallback when no on-disk manifest override
/// (`backends/go-pgx/manifest.toml`) is present.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/go-pgx.toml");
17
/// Code-generation backend that emits Go source using `pgx` v5 (`*pgxpool.Pool`).
pub struct GoPgxBackend {
    /// Backend manifest; supplies the naming rules (`manifest.naming`) and the type
    /// resolution data consulted when rendering Go code.
    manifest: BackendManifest,
}
21
22impl GoPgxBackend {
23 pub fn new() -> Result<Self, ScytheError> {
24 let manifest_path = Path::new("backends/go-pgx/manifest.toml");
25 let manifest = if manifest_path.exists() {
26 load_manifest(manifest_path)
27 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
28 } else {
29 toml::from_str(DEFAULT_MANIFEST_TOML)
30 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
31 };
32 Ok(Self { manifest })
33 }
34
35 pub fn manifest(&self) -> &BackendManifest {
36 &self.manifest
37 }
38}
39
40impl CodegenBackend for GoPgxBackend {
41 fn name(&self) -> &str {
42 "go-pgx"
43 }
44
45 fn file_header(&self) -> String {
46 "package queries\n\nimport (\n\t\"context\"\n\n\t\"github.com/jackc/pgx/v5/pgxpool\"\n)\n"
47 .to_string()
48 }
49
50 fn generate_row_struct(
51 &self,
52 query_name: &str,
53 columns: &[ResolvedColumn],
54 ) -> Result<String, ScytheError> {
55 let struct_name = row_struct_name(query_name, &self.manifest.naming);
56 let mut out = String::new();
57 let _ = writeln!(out, "type {} struct {{", struct_name);
58 for col in columns {
59 let field = to_pascal_case(&col.field_name);
60 let json_tag = &col.field_name;
61 let _ = writeln!(out, "\t{} {} `json:\"{}\"`", field, col.full_type, json_tag);
62 }
63 let _ = write!(out, "}}");
64 Ok(out)
65 }
66
67 fn generate_model_struct(
68 &self,
69 table_name: &str,
70 columns: &[ResolvedColumn],
71 ) -> Result<String, ScytheError> {
72 let name = to_pascal_case(table_name);
73 self.generate_row_struct(&name, columns)
74 }
75
76 fn generate_query_fn(
77 &self,
78 analyzed: &AnalyzedQuery,
79 struct_name: &str,
80 columns: &[ResolvedColumn],
81 params: &[ResolvedParam],
82 ) -> Result<String, ScytheError> {
83 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
84 let sql = super::clean_sql_oneline(&analyzed.sql);
85
86 let param_list = params
87 .iter()
88 .map(|p| {
89 let field = to_pascal_case(&p.field_name);
90 format!("{} {}", field, p.full_type)
91 })
92 .collect::<Vec<_>>()
93 .join(", ");
94 let sep = if param_list.is_empty() { "" } else { ", " };
95
96 let args = params
97 .iter()
98 .map(|p| to_pascal_case(&p.field_name).into_owned())
99 .collect::<Vec<_>>();
100
101 let mut out = String::new();
102
103 match &analyzed.command {
104 QueryCommand::Exec => {
105 let _ = writeln!(
107 out,
108 "func {}(ctx context.Context, db *pgxpool.Pool{}{}) error {{",
109 func_name, sep, param_list
110 );
111 let args_str = if args.is_empty() {
112 String::new()
113 } else {
114 format!(", {}", args.join(", "))
115 };
116 let _ = writeln!(out, "\t_, err := db.Exec(ctx, \"{}\"{})", sql, args_str);
117 let _ = writeln!(out, "\treturn err");
118 let _ = write!(out, "}}");
119 }
120 QueryCommand::ExecResult | QueryCommand::ExecRows => {
121 let _ = writeln!(
123 out,
124 "func {}(ctx context.Context, db *pgxpool.Pool{}{}) (int64, error) {{",
125 func_name, sep, param_list
126 );
127 let args_str = if args.is_empty() {
128 String::new()
129 } else {
130 format!(", {}", args.join(", "))
131 };
132 let _ = writeln!(
133 out,
134 "\tresult, err := db.Exec(ctx, \"{}\"{})",
135 sql, args_str
136 );
137 let _ = writeln!(out, "\tif err != nil {{");
138 let _ = writeln!(out, "\t\treturn 0, err");
139 let _ = writeln!(out, "\t}}");
140 let _ = writeln!(out, "\treturn result.RowsAffected(), nil");
141 let _ = write!(out, "}}");
142 }
143 QueryCommand::One => {
144 let _ = writeln!(
146 out,
147 "func {}(ctx context.Context, db *pgxpool.Pool{}{}) ({}, error) {{",
148 func_name, sep, param_list, struct_name
149 );
150 let args_str = if args.is_empty() {
151 String::new()
152 } else {
153 format!(", {}", args.join(", "))
154 };
155 let _ = writeln!(out, "\trow := db.QueryRow(ctx, \"{}\"{})", sql, args_str);
156 let _ = writeln!(out, "\tvar r {}", struct_name);
157 let scan_fields: Vec<String> = columns
158 .iter()
159 .map(|c| format!("&r.{}", to_pascal_case(&c.field_name)))
160 .collect();
161 let _ = writeln!(out, "\terr := row.Scan({})", scan_fields.join(", "));
162 let _ = writeln!(out, "\treturn r, err");
163 let _ = write!(out, "}}");
164 }
165 QueryCommand::Many | QueryCommand::Batch => {
166 let _ = writeln!(
168 out,
169 "func {}(ctx context.Context, db *pgxpool.Pool{}{}) ([]{}, error) {{",
170 func_name, sep, param_list, struct_name
171 );
172 let args_str = if args.is_empty() {
173 String::new()
174 } else {
175 format!(", {}", args.join(", "))
176 };
177 let _ = writeln!(out, "\trows, err := db.Query(ctx, \"{}\"{})", sql, args_str);
178 let _ = writeln!(out, "\tif err != nil {{");
179 let _ = writeln!(out, "\t\treturn nil, err");
180 let _ = writeln!(out, "\t}}");
181 let _ = writeln!(out, "\tdefer rows.Close()");
182 let _ = writeln!(out, "\tvar result []{}", struct_name);
183 let _ = writeln!(out, "\tfor rows.Next() {{");
184 let _ = writeln!(out, "\t\tvar r {}", struct_name);
185 let scan_fields: Vec<String> = columns
186 .iter()
187 .map(|c| format!("&r.{}", to_pascal_case(&c.field_name)))
188 .collect();
189 let _ = writeln!(
190 out,
191 "\t\tif err := rows.Scan({}); err != nil {{",
192 scan_fields.join(", ")
193 );
194 let _ = writeln!(out, "\t\t\treturn nil, err");
195 let _ = writeln!(out, "\t\t}}");
196 let _ = writeln!(out, "\t\tresult = append(result, r)");
197 let _ = writeln!(out, "\t}}");
198 let _ = writeln!(out, "\treturn result, rows.Err()");
199 let _ = write!(out, "}}");
200 }
201 }
202
203 Ok(out)
204 }
205
206 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
207 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
208 let mut out = String::new();
209 let _ = writeln!(out, "type {} string", type_name);
210 let _ = writeln!(out);
211 let _ = writeln!(out, "const (");
212 for value in &enum_info.values {
213 let variant = enum_variant_name(value, &self.manifest.naming);
214 let _ = writeln!(
215 out,
216 "\t{}{} {} = \"{}\"",
217 type_name, variant, type_name, value
218 );
219 }
220 let _ = write!(out, ")");
221 Ok(out)
222 }
223
224 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
225 let name = to_pascal_case(&composite.sql_name);
226 let mut out = String::new();
227 let _ = writeln!(out, "type {} struct {{", name);
228 if composite.fields.is_empty() {
229 } else {
231 for field in &composite.fields {
232 let field_name = to_pascal_case(&field.name);
233 let go_type = resolve_type(&field.neutral_type, &self.manifest, false)
234 .map(|t| t.into_owned())
235 .unwrap_or_else(|_| "interface{}".to_string());
236 let json_tag = &field.name;
237 let _ = writeln!(out, "\t{} {} `json:\"{}\"`", field_name, go_type, json_tag);
238 }
239 }
240 let _ = write!(out, "}}");
241 Ok(out)
242 }
243}