1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
/// Manifest compiled into the binary; `CsharpNpgsqlBackend::new` parses it when no on-disk
/// `backends/csharp-npgsql/manifest.toml` override exists.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/csharp-npgsql.toml");
17
/// Code generator that emits C# query methods and records targeting the Npgsql driver
/// (PostgreSQL only).
pub struct CsharpNpgsqlBackend {
    /// Backend manifest supplying the naming rules (`manifest.naming`) and the neutral-type to
    /// C# type mappings used by `resolve_type`.
    manifest: BackendManifest,
}
21
22impl CsharpNpgsqlBackend {
23 pub fn new(engine: &str) -> Result<Self, ScytheError> {
24 match engine {
25 "postgresql" | "postgres" | "pg" => {}
26 _ => {
27 return Err(ScytheError::new(
28 ErrorCode::InternalError,
29 format!(
30 "csharp-npgsql only supports PostgreSQL, got engine '{}'",
31 engine
32 ),
33 ));
34 }
35 }
36 let manifest_path = Path::new("backends/csharp-npgsql/manifest.toml");
37 let manifest = if manifest_path.exists() {
38 load_manifest(manifest_path)
39 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
40 } else {
41 toml::from_str(DEFAULT_MANIFEST_TOML)
42 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43 };
44 Ok(Self { manifest })
45 }
46}
47
/// Returns the `NpgsqlDataReader` getter used to read a column of the given neutral type.
///
/// The result is spliced verbatim after `reader.`, so generic getters include their type
/// argument (e.g. `GetFieldValue<DateOnly>`). Unrecognised neutral types yield `"GetValue"`.
///
/// NOTE(review): `inet`, `interval` and `time_tz` are read via `GetString` / `TimeOnly`; confirm
/// these agree with the C# types the manifest assigns and with Npgsql's supported conversions.
fn reader_method(neutral_type: &str) -> &'static str {
    const GETTERS: &[(&str, &str)] = &[
        ("bool", "GetBoolean"),
        ("int16", "GetInt16"),
        ("int32", "GetInt32"),
        ("int64", "GetInt64"),
        ("float32", "GetFloat"),
        ("float64", "GetDouble"),
        ("string", "GetString"),
        ("json", "GetString"),
        ("inet", "GetString"),
        ("interval", "GetString"),
        ("uuid", "GetGuid"),
        ("decimal", "GetDecimal"),
        ("date", "GetFieldValue<DateOnly>"),
        ("time", "GetFieldValue<TimeOnly>"),
        ("time_tz", "GetFieldValue<TimeOnly>"),
        ("datetime", "GetDateTime"),
        ("datetime_tz", "GetFieldValue<DateTimeOffset>"),
    ];

    GETTERS
        .iter()
        .find(|&&(ty, _)| ty == neutral_type)
        .map_or("GetValue", |&(_, getter)| getter)
}
67
/// Rewrites PostgreSQL positional placeholders (`$1`, `$2`, …) into Npgsql named parameters
/// (`@p1`, `@p2`, …) in a single left-to-right pass.
///
/// - Text inside single-quoted string literals and double-quoted identifiers is copied verbatim,
///   so a literal such as `'$1'` keeps its value. Doubled quotes (`''`) are handled naturally
///   because each quote toggles the quoted state.
/// - A `$` that continues an identifier (e.g. `col$1`), or follows another `$`, is not a
///   placeholder.
/// - Any number of digits is accepted, so `$10` becomes `@p10` (never `@p1` followed by `0`).
fn rewrite_params(sql: &str) -> String {
    /// Characters that may appear inside a PostgreSQL identifier after its first character.
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_' || c == '$'
    }

    let mut out = String::with_capacity(sql.len() + 8);
    let mut chars = sql.chars().peekable();
    // The quote character that will close the current literal / quoted identifier, if any.
    let mut open_quote: Option<char> = None;
    // Last character written, used to detect a `$` that is part of an identifier.
    let mut prev: Option<char> = None;

    while let Some(c) = chars.next() {
        if let Some(q) = open_quote {
            if c == q {
                open_quote = None;
            }
            out.push(c);
        } else if c == '\'' || c == '"' {
            open_quote = Some(c);
            out.push(c);
        } else if c == '$'
            && !prev.map_or(false, is_ident_char)
            && chars.peek().map_or(false, char::is_ascii_digit)
        {
            out.push_str("@p");
            while let Some(digit) = chars.next_if(char::is_ascii_digit) {
                out.push(digit);
            }
        } else {
            out.push(c);
        }
        prev = out.chars().next_back();
    }
    out
}
79
80fn column_read_expr(col: &ResolvedColumn, ordinal: usize) -> String {
82 if col.neutral_type.starts_with("enum::") {
83 format!(
84 "Enum.Parse<{}>(reader.GetString({}), true)",
85 col.lang_type, ordinal
86 )
87 } else {
88 let method = reader_method(&col.neutral_type);
89 format!("reader.{}({})", method, ordinal)
90 }
91}
92
93impl CodegenBackend for CsharpNpgsqlBackend {
94 fn name(&self) -> &str {
95 "csharp-npgsql"
96 }
97
98 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
99 &self.manifest
100 }
101
102 fn file_header(&self) -> String {
103 "// Auto-generated by scythe. Do not edit.\nusing Npgsql;\n\npublic static class Queries {"
104 .to_string()
105 }
106
107 fn file_footer(&self) -> String {
108 "}".to_string()
109 }
110
111 fn generate_row_struct(
112 &self,
113 query_name: &str,
114 columns: &[ResolvedColumn],
115 ) -> Result<String, ScytheError> {
116 let struct_name = row_struct_name(query_name, &self.manifest.naming);
117 let mut out = String::new();
118 let _ = writeln!(out, "public record {}(", struct_name);
119 for (i, c) in columns.iter().enumerate() {
120 let field = to_pascal_case(&c.field_name);
121 let sep = if i + 1 < columns.len() { "," } else { "" };
122 let _ = writeln!(out, " {} {}{}", c.full_type, field, sep);
123 }
124 let _ = write!(out, ");");
125 Ok(out)
126 }
127
128 fn generate_model_struct(
129 &self,
130 table_name: &str,
131 columns: &[ResolvedColumn],
132 ) -> Result<String, ScytheError> {
133 let name = to_pascal_case(table_name);
134 self.generate_row_struct(&name, columns)
135 }
136
137 fn generate_query_fn(
138 &self,
139 analyzed: &AnalyzedQuery,
140 struct_name: &str,
141 columns: &[ResolvedColumn],
142 params: &[ResolvedParam],
143 ) -> Result<String, ScytheError> {
144 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
145 let mut sql = rewrite_params(&super::clean_sql_oneline(&analyzed.sql));
146 for (i, p) in params.iter().enumerate() {
148 if let Some(enum_name) = p.neutral_type.strip_prefix("enum::") {
149 let placeholder = format!("@p{}", i + 1);
150 let casted = format!("@p{}::{}", i + 1, enum_name);
151 sql = sql.replace(&placeholder, &casted);
152 }
153 }
154 let mut out = String::new();
155
156 let param_list = params
158 .iter()
159 .map(|p| format!("{} {}", p.full_type, p.field_name))
160 .collect::<Vec<_>>()
161 .join(", ");
162 let sep = if param_list.is_empty() { "" } else { ", " };
163
164 let return_type = match &analyzed.command {
166 QueryCommand::One => format!("{}?", struct_name),
167 QueryCommand::Many | QueryCommand::Batch => {
168 format!("List<{}>", struct_name)
169 }
170 QueryCommand::Exec => "void".to_string(),
171 QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
172 };
173
174 let is_async_void = return_type == "void";
175 let task_type = if is_async_void {
176 "Task".to_string()
177 } else {
178 format!("Task<{}>", return_type)
179 };
180
181 let _ = writeln!(
182 out,
183 "public static async {} {}(NpgsqlConnection conn{}{}) {{",
184 task_type, func_name, sep, param_list
185 );
186
187 let _ = writeln!(
189 out,
190 " await using var cmd = new NpgsqlCommand(\"{}\", conn);",
191 sql
192 );
193 for (i, p) in params.iter().enumerate() {
194 let value_expr = if p.neutral_type.starts_with("enum::") {
195 format!("{}.ToString().ToLower()", p.field_name)
196 } else {
197 p.field_name.clone()
198 };
199 let _ = writeln!(
200 out,
201 " cmd.Parameters.AddWithValue(\"p{}\", {});",
202 i + 1,
203 value_expr
204 );
205 }
206
207 match &analyzed.command {
208 QueryCommand::One => {
209 let _ = writeln!(
210 out,
211 " await using var reader = await cmd.ExecuteReaderAsync();"
212 );
213 let _ = writeln!(out, " if (!await reader.ReadAsync()) return null;");
214 let _ = writeln!(out, " return new {}(", struct_name);
215 for (i, col) in columns.iter().enumerate() {
216 let expr = column_read_expr(col, i);
217 let sep = if i + 1 < columns.len() { "," } else { "" };
218 if col.nullable {
219 let _ = writeln!(out, " reader.IsDBNull({i}) ? null : {expr}{sep}");
220 } else {
221 let _ = writeln!(out, " {expr}{sep}");
222 }
223 }
224 let _ = writeln!(out, " );");
225 }
226 QueryCommand::Many | QueryCommand::Batch => {
227 let _ = writeln!(
228 out,
229 " await using var reader = await cmd.ExecuteReaderAsync();"
230 );
231 let _ = writeln!(out, " var results = new List<{}>();", struct_name);
232 let _ = writeln!(out, " while (await reader.ReadAsync()) {{");
233 let _ = writeln!(out, " results.Add(new {}(", struct_name);
234 for (i, col) in columns.iter().enumerate() {
235 let expr = column_read_expr(col, i);
236 let sep = if i + 1 < columns.len() { "," } else { "" };
237 if col.nullable {
238 let _ =
239 writeln!(out, " reader.IsDBNull({i}) ? null : {expr}{sep}");
240 } else {
241 let _ = writeln!(out, " {expr}{sep}");
242 }
243 }
244 let _ = writeln!(out, " ));");
245 let _ = writeln!(out, " }}");
246 let _ = writeln!(out, " return results;");
247 }
248 QueryCommand::Exec => {
249 let _ = writeln!(out, " await cmd.ExecuteNonQueryAsync();");
250 }
251 QueryCommand::ExecResult | QueryCommand::ExecRows => {
252 let _ = writeln!(out, " return await cmd.ExecuteNonQueryAsync();");
253 }
254 }
255
256 let _ = write!(out, "}}");
257 Ok(out)
258 }
259
260 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
261 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
262 let mut out = String::new();
263 let _ = writeln!(out, "public enum {} {{", type_name);
264 for value in &enum_info.values {
265 let variant = enum_variant_name(value, &self.manifest.naming);
266 let _ = writeln!(out, " {},", variant);
267 }
268 let _ = write!(out, "}}");
269 Ok(out)
270 }
271
272 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
273 let name = to_pascal_case(&composite.sql_name);
274 let mut out = String::new();
275 if composite.fields.is_empty() {
276 let _ = writeln!(out, "public record {}();", name);
277 } else {
278 let _ = writeln!(out, "public record {}(", name);
279 for (i, field) in composite.fields.iter().enumerate() {
280 let cs_type = resolve_type(&field.neutral_type, &self.manifest, false)
281 .map(|t| t.into_owned())
282 .unwrap_or_else(|_| "object".to_string());
283 let field_name = to_pascal_case(&field.name);
284 let sep = if i + 1 < composite.fields.len() {
285 ","
286 } else {
287 ""
288 };
289 let _ = writeln!(out, " {} {}{}", cs_type, field_name, sep);
290 }
291 let _ = write!(out, ");");
292 }
293 Ok(out)
294 }
295}