1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/csharp-npgsql.toml");
17
18pub struct CsharpNpgsqlBackend {
19 manifest: BackendManifest,
20}
21
22impl CsharpNpgsqlBackend {
23 pub fn new() -> Result<Self, ScytheError> {
24 let manifest_path = Path::new("backends/csharp-npgsql/manifest.toml");
25 let manifest = if manifest_path.exists() {
26 load_manifest(manifest_path)
27 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
28 } else {
29 toml::from_str(DEFAULT_MANIFEST_TOML)
30 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
31 };
32 Ok(Self { manifest })
33 }
34
35 pub fn manifest(&self) -> &BackendManifest {
36 &self.manifest
37 }
38}
39
40fn reader_method(neutral_type: &str) -> &'static str {
42 match neutral_type {
43 "bool" => "GetBoolean",
44 "int16" => "GetInt16",
45 "int32" => "GetInt32",
46 "int64" => "GetInt64",
47 "float32" => "GetFloat",
48 "float64" => "GetDouble",
49 "string" | "json" | "inet" | "interval" => "GetString",
50 "uuid" => "GetGuid",
51 "decimal" => "GetDecimal",
52 "date" => "GetFieldValue<DateOnly>",
53 "time" | "time_tz" => "GetFieldValue<TimeOnly>",
54 "datetime" => "GetDateTime",
55 "datetime_tz" => "GetFieldValue<DateTimeOffset>",
56 _ => "GetValue",
57 }
58}
59
60fn rewrite_params(sql: &str) -> String {
62 let mut result = sql.to_string();
63 for i in (1..=99).rev() {
65 let from = format!("${}", i);
66 let to = format!("@p{}", i);
67 result = result.replace(&from, &to);
68 }
69 result
70}
71
72impl CodegenBackend for CsharpNpgsqlBackend {
73 fn name(&self) -> &str {
74 "csharp-npgsql"
75 }
76
77 fn generate_row_struct(
78 &self,
79 query_name: &str,
80 columns: &[ResolvedColumn],
81 ) -> Result<String, ScytheError> {
82 let struct_name = row_struct_name(query_name, &self.manifest.naming);
83 let mut out = String::new();
84 let _ = writeln!(out, "public record {}(", struct_name);
85 for (i, c) in columns.iter().enumerate() {
86 let field = to_pascal_case(&c.field_name);
87 let sep = if i + 1 < columns.len() { "," } else { "" };
88 let _ = writeln!(out, " {} {}{}", c.full_type, field, sep);
89 }
90 let _ = write!(out, ");");
91 Ok(out)
92 }
93
94 fn generate_model_struct(
95 &self,
96 table_name: &str,
97 columns: &[ResolvedColumn],
98 ) -> Result<String, ScytheError> {
99 let name = to_pascal_case(table_name);
100 self.generate_row_struct(&name, columns)
101 }
102
103 fn generate_query_fn(
104 &self,
105 analyzed: &AnalyzedQuery,
106 struct_name: &str,
107 columns: &[ResolvedColumn],
108 params: &[ResolvedParam],
109 ) -> Result<String, ScytheError> {
110 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
111 let sql = rewrite_params(&super::clean_sql(&analyzed.sql));
112 let mut out = String::new();
113
114 let param_list = params
116 .iter()
117 .map(|p| format!("{} {}", p.full_type, p.field_name))
118 .collect::<Vec<_>>()
119 .join(", ");
120 let sep = if param_list.is_empty() { "" } else { ", " };
121
122 let return_type = match &analyzed.command {
124 QueryCommand::One => format!("{}?", struct_name),
125 QueryCommand::Many | QueryCommand::Batch => {
126 format!("List<{}>", struct_name)
127 }
128 QueryCommand::Exec => "void".to_string(),
129 QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
130 };
131
132 let is_async_void = return_type == "void";
133 let task_type = if is_async_void {
134 "Task".to_string()
135 } else {
136 format!("Task<{}>", return_type)
137 };
138
139 let _ = writeln!(
140 out,
141 "public static async {} {}(NpgsqlConnection conn{}{}) {{",
142 task_type, func_name, sep, param_list
143 );
144
145 let _ = writeln!(
147 out,
148 " await using var cmd = new NpgsqlCommand(\"{}\", conn);",
149 sql
150 );
151 for (i, p) in params.iter().enumerate() {
152 let _ = writeln!(
153 out,
154 " cmd.Parameters.AddWithValue(\"p{}\", {});",
155 i + 1,
156 p.field_name
157 );
158 }
159
160 match &analyzed.command {
161 QueryCommand::One => {
162 let _ = writeln!(
163 out,
164 " await using var reader = await cmd.ExecuteReaderAsync();"
165 );
166 let _ = writeln!(out, " if (!await reader.ReadAsync()) return null;");
167 let _ = writeln!(out, " return new {}(", struct_name);
168 for (i, col) in columns.iter().enumerate() {
169 let method = reader_method(&col.neutral_type);
170 let sep = if i + 1 < columns.len() { "," } else { "" };
171 if col.nullable {
172 let _ = writeln!(
173 out,
174 " reader.IsDBNull({i}) ? null : reader.{method}({i}){sep}"
175 );
176 } else {
177 let _ = writeln!(out, " reader.{method}({i}){sep}");
178 }
179 }
180 let _ = writeln!(out, " );");
181 }
182 QueryCommand::Many | QueryCommand::Batch => {
183 let _ = writeln!(
184 out,
185 " await using var reader = await cmd.ExecuteReaderAsync();"
186 );
187 let _ = writeln!(out, " var results = new List<{}>();", struct_name);
188 let _ = writeln!(out, " while (await reader.ReadAsync()) {{");
189 let _ = writeln!(out, " results.Add(new {}(", struct_name);
190 for (i, col) in columns.iter().enumerate() {
191 let method = reader_method(&col.neutral_type);
192 let sep = if i + 1 < columns.len() { "," } else { "" };
193 if col.nullable {
194 let _ = writeln!(
195 out,
196 " reader.IsDBNull({i}) ? null : reader.{method}({i}){sep}"
197 );
198 } else {
199 let _ = writeln!(out, " reader.{method}({i}){sep}");
200 }
201 }
202 let _ = writeln!(out, " ));");
203 let _ = writeln!(out, " }}");
204 let _ = writeln!(out, " return results;");
205 }
206 QueryCommand::Exec => {
207 let _ = writeln!(out, " await cmd.ExecuteNonQueryAsync();");
208 }
209 QueryCommand::ExecResult | QueryCommand::ExecRows => {
210 let _ = writeln!(out, " return await cmd.ExecuteNonQueryAsync();");
211 }
212 }
213
214 let _ = write!(out, "}}");
215 Ok(out)
216 }
217
218 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
219 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
220 let mut out = String::new();
221 let _ = writeln!(out, "public enum {} {{", type_name);
222 for value in &enum_info.values {
223 let variant = enum_variant_name(value, &self.manifest.naming);
224 let _ = writeln!(out, " {},", variant);
225 }
226 let _ = write!(out, "}}");
227 Ok(out)
228 }
229
230 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
231 let name = to_pascal_case(&composite.sql_name);
232 let mut out = String::new();
233 if composite.fields.is_empty() {
234 let _ = writeln!(out, "public record {}();", name);
235 } else {
236 let _ = writeln!(out, "public record {}(", name);
237 for (i, field) in composite.fields.iter().enumerate() {
238 let cs_type = resolve_type(&field.neutral_type, &self.manifest, false)
239 .map(|t| t.into_owned())
240 .unwrap_or_else(|_| "object".to_string());
241 let field_name = to_pascal_case(&field.name);
242 let sep = if i + 1 < composite.fields.len() {
243 ","
244 } else {
245 ""
246 };
247 let _ = writeln!(out, " {} {}{}", cs_type, field_name, sep);
248 }
249 let _ = write!(out, ");");
250 }
251 Ok(out)
252 }
253}