1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
/// Manifest bundled into the binary at compile time. `CsharpNpgsqlBackend::new`
/// parses it when no on-disk override manifest exists.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/csharp-npgsql.toml");
17
/// Code-generation backend that emits C# data-access code targeting the
/// Npgsql driver (PostgreSQL only).
pub struct CsharpNpgsqlBackend {
    /// Backend manifest; its `naming` rules drive generated identifiers and it
    /// is consulted when resolving neutral types to C# types.
    manifest: BackendManifest,
}
21
22impl CsharpNpgsqlBackend {
23 pub fn new(engine: &str) -> Result<Self, ScytheError> {
24 match engine {
25 "postgresql" | "postgres" | "pg" => {}
26 _ => {
27 return Err(ScytheError::new(
28 ErrorCode::InternalError,
29 format!(
30 "csharp-npgsql only supports PostgreSQL, got engine '{}'",
31 engine
32 ),
33 ));
34 }
35 }
36 let manifest_path = Path::new("backends/csharp-npgsql/manifest.toml");
37 let manifest = if manifest_path.exists() {
38 load_manifest(manifest_path)
39 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
40 } else {
41 toml::from_str(DEFAULT_MANIFEST_TOML)
42 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43 };
44 Ok(Self { manifest })
45 }
46}
47
/// Returns the name of the `NpgsqlDataReader` accessor used to read a column
/// of the given neutral type.
///
/// Types without a dedicated accessor map to `"GetValue"`.
fn reader_method(neutral_type: &str) -> &'static str {
    // (neutral types, accessor) pairs; the first row whose list contains the
    // neutral type wins.
    const ACCESSORS: &[(&[&str], &str)] = &[
        (&["bool"], "GetBoolean"),
        (&["int16"], "GetInt16"),
        (&["int32"], "GetInt32"),
        (&["int64"], "GetInt64"),
        (&["float32"], "GetFloat"),
        (&["float64"], "GetDouble"),
        // NOTE(review): inet/interval are read as text here; this must agree
        // with the C# type the manifest assigns them — confirm.
        (&["string", "json", "inet", "interval"], "GetString"),
        (&["uuid"], "GetGuid"),
        (&["decimal"], "GetDecimal"),
        (&["date"], "GetFieldValue<DateOnly>"),
        (&["time", "time_tz"], "GetFieldValue<TimeOnly>"),
        (&["datetime"], "GetDateTime"),
        (&["datetime_tz"], "GetFieldValue<DateTimeOffset>"),
    ];

    ACCESSORS
        .iter()
        .find(|(types, _)| types.contains(&neutral_type))
        .map_or("GetValue", |&(_, accessor)| accessor)
}
67
/// Rewrites PostgreSQL positional placeholders (`$1`, `$2`, …, any number of
/// digits) into Npgsql named placeholders (`@p1`, `@p2`, …).
///
/// A `$` is rewritten only when it is directly followed by a digit `1`-`9`
/// and it appears:
/// - outside single-quoted string literals (including `E'…'` strings, where a
///   backslash escapes the next character, and `''` doubled quotes),
/// - outside double-quoted identifiers, and
/// - not as part of an identifier (PostgreSQL allows `$` inside identifiers,
///   e.g. `col$1`).
///
/// Only the `$` sigil is replaced; the digits are copied unchanged.
/// Dollar-quoted bodies and SQL comments are not specially recognised.
fn rewrite_params(sql: &str) -> String {
    // Bytes that may continue a PostgreSQL identifier: ASCII letters, digits,
    // `_`, `$`, and any non-ASCII byte (identifiers may hold non-ASCII letters).
    fn is_ident_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || !b.is_ascii()
    }

    // Returns the index just past the quoted section opened at `open`, or
    // `bytes.len()` if it is unterminated. A doubled quote is an escaped quote;
    // with `backslash_escapes`, a backslash escapes the following byte.
    fn skip_quoted(bytes: &[u8], open: usize, backslash_escapes: bool) -> usize {
        let quote = bytes[open];
        let mut j = open + 1;
        while j < bytes.len() {
            match bytes[j] {
                b'\\' if backslash_escapes => j += 2,
                b if b == quote => {
                    if bytes.get(j + 1) == Some(&quote) {
                        j += 2;
                    } else {
                        return j + 1;
                    }
                }
                _ => j += 1,
            }
        }
        bytes.len()
    }

    let bytes = sql.as_bytes();
    let mut out = String::with_capacity(sql.len() + 8);
    // Start of the portion of `sql` not yet copied to `out`. Every cut point is
    // an ASCII byte, so slicing always lands on a char boundary.
    let mut copied = 0;
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'\'' => {
                // `E'…'` / `e'…'` escape-string constants honour backslashes,
                // but only when the `E` is not the tail of a longer word.
                let backslash_escapes = i >= 1
                    && matches!(bytes[i - 1], b'E' | b'e')
                    && (i < 2 || !is_ident_byte(bytes[i - 2]));
                i = skip_quoted(bytes, i, backslash_escapes);
            }
            b'"' => i = skip_quoted(bytes, i, false),
            b'$' if matches!(bytes.get(i + 1), Some(d) if (b'1'..=b'9').contains(d))
                && (i == 0 || !is_ident_byte(bytes[i - 1])) =>
            {
                out.push_str(&sql[copied..i]);
                out.push_str("@p");
                i += 1;
                copied = i;
            }
            _ => i += 1,
        }
    }
    out.push_str(&sql[copied..]);
    out
}
79
80fn column_read_expr(col: &ResolvedColumn, ordinal: usize) -> String {
82 if col.neutral_type.starts_with("enum::") {
83 format!(
84 "(Enum.TryParse<{typ}>(reader.GetString({ord}), true, out var enumVal{ord}) ? enumVal{ord} : throw new InvalidOperationException($\"Invalid enum value '{{reader.GetString({ord})}}' for {typ}\"))",
85 typ = col.lang_type,
86 ord = ordinal
87 )
88 } else {
89 let method = reader_method(&col.neutral_type);
90 format!("reader.{}({})", method, ordinal)
91 }
92}
93
94impl CodegenBackend for CsharpNpgsqlBackend {
95 fn name(&self) -> &str {
96 "csharp-npgsql"
97 }
98
99 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
100 &self.manifest
101 }
102
103 fn file_header(&self) -> String {
104 "// Auto-generated by scythe. Do not edit.\nusing Npgsql;\n\npublic static class Queries {"
105 .to_string()
106 }
107
108 fn file_footer(&self) -> String {
109 "}".to_string()
110 }
111
112 fn generate_row_struct(
113 &self,
114 query_name: &str,
115 columns: &[ResolvedColumn],
116 ) -> Result<String, ScytheError> {
117 let struct_name = row_struct_name(query_name, &self.manifest.naming);
118 let mut out = String::new();
119 let _ = writeln!(out, "public record {}(", struct_name);
120 for (i, c) in columns.iter().enumerate() {
121 let field = to_pascal_case(&c.field_name);
122 let sep = if i + 1 < columns.len() { "," } else { "" };
123 let _ = writeln!(out, " {} {}{}", c.full_type, field, sep);
124 }
125 let _ = write!(out, ");");
126 Ok(out)
127 }
128
129 fn generate_model_struct(
130 &self,
131 table_name: &str,
132 columns: &[ResolvedColumn],
133 ) -> Result<String, ScytheError> {
134 let name = to_pascal_case(table_name);
135 self.generate_row_struct(&name, columns)
136 }
137
138 fn generate_query_fn(
139 &self,
140 analyzed: &AnalyzedQuery,
141 struct_name: &str,
142 columns: &[ResolvedColumn],
143 params: &[ResolvedParam],
144 ) -> Result<String, ScytheError> {
145 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
146 let mut sql = rewrite_params(&super::clean_sql_oneline(&analyzed.sql));
147 for (i, p) in params.iter().enumerate() {
149 if let Some(enum_name) = p.neutral_type.strip_prefix("enum::") {
150 let placeholder = format!("@p{}", i + 1);
151 let casted = format!("@p{}::{}", i + 1, enum_name);
152 sql = sql.replace(&placeholder, &casted);
153 }
154 }
155 let mut out = String::new();
156
157 let param_list = params
159 .iter()
160 .map(|p| format!("{} {}", p.full_type, p.field_name))
161 .collect::<Vec<_>>()
162 .join(", ");
163 let sep = if param_list.is_empty() { "" } else { ", " };
164
165 let return_type = match &analyzed.command {
167 QueryCommand::One => format!("{}?", struct_name),
168 QueryCommand::Many | QueryCommand::Batch => {
169 format!("List<{}>", struct_name)
170 }
171 QueryCommand::Exec => "void".to_string(),
172 QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
173 };
174
175 let is_async_void = return_type == "void";
176 let task_type = if is_async_void {
177 "Task".to_string()
178 } else {
179 format!("Task<{}>", return_type)
180 };
181
182 let _ = writeln!(
183 out,
184 "public static async {} {}(NpgsqlConnection conn{}{}) {{",
185 task_type, func_name, sep, param_list
186 );
187
188 let _ = writeln!(
190 out,
191 " await using var cmd = new NpgsqlCommand(\"{}\", conn);",
192 sql
193 );
194 for (i, p) in params.iter().enumerate() {
195 let value_expr = if p.neutral_type.starts_with("enum::") {
196 format!("{}.ToString().ToLower()", p.field_name)
197 } else {
198 p.field_name.clone()
199 };
200 let _ = writeln!(
201 out,
202 " cmd.Parameters.AddWithValue(\"p{}\", {});",
203 i + 1,
204 value_expr
205 );
206 }
207
208 match &analyzed.command {
209 QueryCommand::One => {
210 let _ = writeln!(
211 out,
212 " await using var reader = await cmd.ExecuteReaderAsync();"
213 );
214 let _ = writeln!(out, " if (!await reader.ReadAsync()) return null;");
215 let _ = writeln!(out, " return new {}(", struct_name);
216 for (i, col) in columns.iter().enumerate() {
217 let expr = column_read_expr(col, i);
218 let sep = if i + 1 < columns.len() { "," } else { "" };
219 if col.nullable {
220 let _ = writeln!(out, " reader.IsDBNull({i}) ? null : {expr}{sep}");
221 } else {
222 let _ = writeln!(out, " {expr}{sep}");
223 }
224 }
225 let _ = writeln!(out, " );");
226 }
227 QueryCommand::Many | QueryCommand::Batch => {
228 let _ = writeln!(
229 out,
230 " await using var reader = await cmd.ExecuteReaderAsync();"
231 );
232 let _ = writeln!(out, " var results = new List<{}>();", struct_name);
233 let _ = writeln!(out, " while (await reader.ReadAsync()) {{");
234 let _ = writeln!(out, " results.Add(new {}(", struct_name);
235 for (i, col) in columns.iter().enumerate() {
236 let expr = column_read_expr(col, i);
237 let sep = if i + 1 < columns.len() { "," } else { "" };
238 if col.nullable {
239 let _ =
240 writeln!(out, " reader.IsDBNull({i}) ? null : {expr}{sep}");
241 } else {
242 let _ = writeln!(out, " {expr}{sep}");
243 }
244 }
245 let _ = writeln!(out, " ));");
246 let _ = writeln!(out, " }}");
247 let _ = writeln!(out, " return results;");
248 }
249 QueryCommand::Exec => {
250 let _ = writeln!(out, " await cmd.ExecuteNonQueryAsync();");
251 }
252 QueryCommand::ExecResult | QueryCommand::ExecRows => {
253 let _ = writeln!(out, " return await cmd.ExecuteNonQueryAsync();");
254 }
255 }
256
257 let _ = write!(out, "}}");
258 Ok(out)
259 }
260
261 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
262 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
263 let mut out = String::new();
264 let _ = writeln!(out, "public enum {} {{", type_name);
265 for value in &enum_info.values {
266 let variant = enum_variant_name(value, &self.manifest.naming);
267 let _ = writeln!(out, " {},", variant);
268 }
269 let _ = write!(out, "}}");
270 Ok(out)
271 }
272
273 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
274 let name = to_pascal_case(&composite.sql_name);
275 let mut out = String::new();
276 if composite.fields.is_empty() {
277 let _ = writeln!(out, "public record {}();", name);
278 } else {
279 let _ = writeln!(out, "public record {}(", name);
280 for (i, field) in composite.fields.iter().enumerate() {
281 let cs_type = resolve_type(&field.neutral_type, &self.manifest, false)
282 .map(|t| t.into_owned())
283 .unwrap_or_else(|_| "object".to_string());
284 let field_name = to_pascal_case(&field.name);
285 let sep = if i + 1 < composite.fields.len() {
286 ","
287 } else {
288 ""
289 };
290 let _ = writeln!(out, " {} {}{}", cs_type, field_name, sep);
291 }
292 let _ = write!(out, ");");
293 }
294 Ok(out)
295 }
296}