1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/csharp-npgsql.toml");
17
18pub struct CsharpNpgsqlBackend {
19 manifest: BackendManifest,
20}
21
22impl CsharpNpgsqlBackend {
23 pub fn new(engine: &str) -> Result<Self, ScytheError> {
24 match engine {
25 "postgresql" | "postgres" | "pg" => {}
26 _ => {
27 return Err(ScytheError::new(
28 ErrorCode::InternalError,
29 format!(
30 "csharp-npgsql only supports PostgreSQL, got engine '{}'",
31 engine
32 ),
33 ));
34 }
35 }
36 let manifest_path = Path::new("backends/csharp-npgsql/manifest.toml");
37 let manifest = if manifest_path.exists() {
38 load_manifest(manifest_path)
39 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
40 } else {
41 toml::from_str(DEFAULT_MANIFEST_TOML)
42 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43 };
44 Ok(Self { manifest })
45 }
46}
47
/// Returns the `NpgsqlDataReader` accessor (method name, including any generic
/// argument) used to read a column whose scythe-neutral type is `neutral_type`.
///
/// Types without a dedicated accessor fall back to the untyped `GetValue`.
fn reader_method(neutral_type: &str) -> &'static str {
    // Each entry lists the neutral types that share one reader accessor.
    const ACCESSORS: &[(&[&str], &str)] = &[
        (&["bool"], "GetBoolean"),
        (&["int16"], "GetInt16"),
        (&["int32"], "GetInt32"),
        (&["int64"], "GetInt64"),
        (&["float32"], "GetFloat"),
        (&["float64"], "GetDouble"),
        (&["string", "json", "inet", "interval"], "GetString"),
        (&["uuid"], "GetGuid"),
        (&["decimal"], "GetDecimal"),
        (&["date"], "GetFieldValue<DateOnly>"),
        (&["time", "time_tz"], "GetFieldValue<TimeOnly>"),
        (&["datetime"], "GetDateTime"),
        (&["datetime_tz"], "GetFieldValue<DateTimeOffset>"),
    ];

    ACCESSORS
        .iter()
        .find(|(kinds, _)| kinds.iter().any(|&kind| kind == neutral_type))
        .map(|&(_, accessor)| accessor)
        .unwrap_or("GetValue")
}
67
/// Rewrites PostgreSQL positional placeholders (`$1`, `$2`, …) into the named
/// `@p1`, `@p2`, … form bound through `cmd.Parameters`.
///
/// Works in a single left-to-right pass: every `$` immediately followed by a
/// digit `1`–`9` is replaced by `@p`, and the digits after it are copied
/// unchanged. Placeholders of any width (`$10`, `$123`) are therefore handled
/// without a hard-coded upper bound. `$0`, `$$` and `$tag$` dollar quotes are
/// left untouched.
fn rewrite_params(sql: &str) -> String {
    // A few extra bytes of headroom: each rewrite grows the text by one byte.
    let mut result = String::with_capacity(sql.len() + 8);
    let mut chars = sql.chars().peekable();
    while let Some(ch) = chars.next() {
        if ch == '$' && matches!(chars.peek(), Some('1'..='9')) {
            result.push_str("@p");
        } else {
            result.push(ch);
        }
    }
    result
}
79
80fn column_read_expr(col: &ResolvedColumn, ordinal: usize) -> String {
82 if col.neutral_type.starts_with("enum::") {
83 format!(
84 "(Enum.TryParse<{typ}>(reader.GetString({ord}), true, out var enumVal{ord}) ? enumVal{ord} : throw new InvalidOperationException($\"Invalid enum value '{{reader.GetString({ord})}}' for {typ}\"))",
85 typ = col.lang_type,
86 ord = ordinal
87 )
88 } else {
89 let method = reader_method(&col.neutral_type);
90 format!("reader.{}({})", method, ordinal)
91 }
92}
93
94impl CodegenBackend for CsharpNpgsqlBackend {
95 fn name(&self) -> &str {
96 "csharp-npgsql"
97 }
98
99 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
100 &self.manifest
101 }
102
103 fn file_header(&self) -> String {
104 "// Auto-generated by scythe. Do not edit.\n#nullable enable\n\nusing Npgsql;\n\npublic static class Queries {"
105 .to_string()
106 }
107
108 fn file_footer(&self) -> String {
109 "}".to_string()
110 }
111
112 fn generate_row_struct(
113 &self,
114 query_name: &str,
115 columns: &[ResolvedColumn],
116 ) -> Result<String, ScytheError> {
117 let struct_name = row_struct_name(query_name, &self.manifest.naming);
118 let mut out = String::new();
119 let _ = writeln!(out, "public record {}(", struct_name);
120 for (i, c) in columns.iter().enumerate() {
121 let field = to_pascal_case(&c.field_name);
122 let sep = if i + 1 < columns.len() { "," } else { "" };
123 let _ = writeln!(out, " {} {}{}", c.full_type, field, sep);
124 }
125 let _ = write!(out, ");");
126 Ok(out)
127 }
128
129 fn generate_model_struct(
130 &self,
131 table_name: &str,
132 columns: &[ResolvedColumn],
133 ) -> Result<String, ScytheError> {
134 let name = to_pascal_case(table_name);
135 self.generate_row_struct(&name, columns)
136 }
137
138 fn generate_query_fn(
139 &self,
140 analyzed: &AnalyzedQuery,
141 struct_name: &str,
142 columns: &[ResolvedColumn],
143 params: &[ResolvedParam],
144 ) -> Result<String, ScytheError> {
145 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
146 let mut sql = rewrite_params(&super::clean_sql_oneline_with_optional(
147 &analyzed.sql,
148 &analyzed.optional_params,
149 &analyzed.params,
150 ));
151 for (i, p) in params.iter().enumerate() {
153 if let Some(enum_name) = p.neutral_type.strip_prefix("enum::") {
154 let placeholder = format!("@p{}", i + 1);
155 let casted = format!("@p{}::{}", i + 1, enum_name);
156 sql = sql.replace(&placeholder, &casted);
157 }
158 }
159 let mut out = String::new();
160
161 let param_list = params
163 .iter()
164 .map(|p| format!("{} {}", p.full_type, p.field_name))
165 .collect::<Vec<_>>()
166 .join(", ");
167 let sep = if param_list.is_empty() { "" } else { ", " };
168
169 if matches!(analyzed.command, QueryCommand::Batch) {
171 let batch_fn_name = format!("{}Batch", func_name);
172 if params.len() > 1 {
173 let params_record_name = format!("{}BatchParams", to_pascal_case(&analyzed.name));
174 let _ = writeln!(out, "public record {}(", params_record_name);
175 for (i, p) in params.iter().enumerate() {
176 let field = to_pascal_case(&p.field_name);
177 let sep = if i + 1 < params.len() { "," } else { "" };
178 let _ = writeln!(out, " {} {}{}", p.full_type, field, sep);
179 }
180 let _ = writeln!(out, ");");
181 let _ = writeln!(out);
182 let _ = writeln!(
183 out,
184 "public static async Task {}(NpgsqlConnection conn, List<{}> items) {{",
185 batch_fn_name, params_record_name
186 );
187 } else if params.len() == 1 {
188 let _ = writeln!(
189 out,
190 "public static async Task {}(NpgsqlConnection conn, List<{}> items) {{",
191 batch_fn_name, params[0].full_type
192 );
193 } else {
194 let _ = writeln!(
195 out,
196 "public static async Task {}(NpgsqlConnection conn, int count) {{",
197 batch_fn_name
198 );
199 }
200 let _ = writeln!(
201 out,
202 " await using var tx = await conn.BeginTransactionAsync();"
203 );
204 let _ = writeln!(out, " try {{");
205 if params.is_empty() {
206 let _ = writeln!(out, " for (int i = 0; i < count; i++) {{");
207 } else {
208 let _ = writeln!(out, " foreach (var item in items) {{");
209 }
210 let _ = writeln!(
211 out,
212 " await using var cmd = new NpgsqlCommand(\"{}\", conn, tx);",
213 sql
214 );
215 for (i, p) in params.iter().enumerate() {
216 let value_expr = if params.len() > 1 {
217 let field = to_pascal_case(&p.field_name);
218 if p.neutral_type.starts_with("enum::") {
219 format!("item.{}.ToString().ToLower()", field)
220 } else {
221 format!("item.{}", field)
222 }
223 } else if p.neutral_type.starts_with("enum::") {
224 "item.ToString().ToLower()".to_string()
225 } else {
226 "item".to_string()
227 };
228 let _ = writeln!(
229 out,
230 " cmd.Parameters.AddWithValue(\"p{}\", {});",
231 i + 1,
232 value_expr
233 );
234 }
235 let _ = writeln!(out, " await cmd.ExecuteNonQueryAsync();");
236 let _ = writeln!(out, " }}");
237 let _ = writeln!(out, " await tx.CommitAsync();");
238 let _ = writeln!(out, " }} catch {{");
239 let _ = writeln!(out, " await tx.RollbackAsync();");
240 let _ = writeln!(out, " throw;");
241 let _ = writeln!(out, " }}");
242 let _ = write!(out, "}}");
243 return Ok(out);
244 }
245
246 let return_type = match &analyzed.command {
248 QueryCommand::One => format!("{}?", struct_name),
249 QueryCommand::Many => format!("List<{}>", struct_name),
250 QueryCommand::Exec => "void".to_string(),
251 QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
252 QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
253 };
254
255 let is_async_void = return_type == "void";
256 let task_type = if is_async_void {
257 "Task".to_string()
258 } else {
259 format!("Task<{}>", return_type)
260 };
261
262 let _ = writeln!(
263 out,
264 "public static async {} {}(NpgsqlConnection conn{}{}) {{",
265 task_type, func_name, sep, param_list
266 );
267
268 let _ = writeln!(
270 out,
271 " await using var cmd = new NpgsqlCommand(\"{}\", conn);",
272 sql
273 );
274 for (i, p) in params.iter().enumerate() {
275 let value_expr = if p.neutral_type.starts_with("enum::") {
276 format!("{}.ToString().ToLower()", p.field_name)
277 } else {
278 p.field_name.clone()
279 };
280 let _ = writeln!(
281 out,
282 " cmd.Parameters.AddWithValue(\"p{}\", {});",
283 i + 1,
284 value_expr
285 );
286 }
287
288 match &analyzed.command {
289 QueryCommand::One => {
290 let _ = writeln!(
291 out,
292 " await using var reader = await cmd.ExecuteReaderAsync();"
293 );
294 let _ = writeln!(out, " if (!await reader.ReadAsync()) return null;");
295 let _ = writeln!(out, " return new {}(", struct_name);
296 for (i, col) in columns.iter().enumerate() {
297 let expr = column_read_expr(col, i);
298 let sep = if i + 1 < columns.len() { "," } else { "" };
299 if col.nullable {
300 let _ = writeln!(out, " reader.IsDBNull({i}) ? null : {expr}{sep}");
301 } else {
302 let _ = writeln!(out, " {expr}{sep}");
303 }
304 }
305 let _ = writeln!(out, " );");
306 }
307 QueryCommand::Many => {
308 let _ = writeln!(
309 out,
310 " await using var reader = await cmd.ExecuteReaderAsync();"
311 );
312 let _ = writeln!(out, " var results = new List<{}>();", struct_name);
313 let _ = writeln!(out, " while (await reader.ReadAsync()) {{");
314 let _ = writeln!(out, " results.Add(new {}(", struct_name);
315 for (i, col) in columns.iter().enumerate() {
316 let expr = column_read_expr(col, i);
317 let sep = if i + 1 < columns.len() { "," } else { "" };
318 if col.nullable {
319 let _ =
320 writeln!(out, " reader.IsDBNull({i}) ? null : {expr}{sep}");
321 } else {
322 let _ = writeln!(out, " {expr}{sep}");
323 }
324 }
325 let _ = writeln!(out, " ));");
326 let _ = writeln!(out, " }}");
327 let _ = writeln!(out, " return results;");
328 }
329 QueryCommand::Exec => {
330 let _ = writeln!(out, " await cmd.ExecuteNonQueryAsync();");
331 }
332 QueryCommand::ExecResult | QueryCommand::ExecRows => {
333 let _ = writeln!(out, " return await cmd.ExecuteNonQueryAsync();");
334 }
335 QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
336 }
337
338 let _ = write!(out, "}}");
339 Ok(out)
340 }
341
342 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
343 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
344 let mut out = String::new();
345 let _ = writeln!(out, "public enum {} {{", type_name);
346 for value in &enum_info.values {
347 let variant = enum_variant_name(value, &self.manifest.naming);
348 let _ = writeln!(out, " {},", variant);
349 }
350 let _ = write!(out, "}}");
351 Ok(out)
352 }
353
354 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
355 let name = to_pascal_case(&composite.sql_name);
356 let mut out = String::new();
357 if composite.fields.is_empty() {
358 let _ = writeln!(out, "public record {}();", name);
359 } else {
360 let _ = writeln!(out, "public record {}(", name);
361 for (i, field) in composite.fields.iter().enumerate() {
362 let cs_type = resolve_type(&field.neutral_type, &self.manifest, false)
363 .map(|t| t.into_owned())
364 .unwrap_or_else(|_| "object".to_string());
365 let field_name = to_pascal_case(&field.name);
366 let sep = if i + 1 < composite.fields.len() {
367 ","
368 } else {
369 ""
370 };
371 let _ = writeln!(out, " {} {}{}", cs_type, field_name, sep);
372 }
373 let _ = write!(out, ");");
374 }
375 Ok(out)
376 }
377}