1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_TOML: &str =
17 include_str!("../../manifests/csharp-npgsql.toml");
18
19pub struct CsharpNpgsqlBackend {
20 manifest: BackendManifest,
21}
22
23impl CsharpNpgsqlBackend {
24 pub fn new() -> Result<Self, ScytheError> {
25 let manifest_path = Path::new("backends/csharp-npgsql/manifest.toml");
26 let manifest = if manifest_path.exists() {
27 load_manifest(manifest_path)
28 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
29 } else {
30 toml::from_str(DEFAULT_MANIFEST_TOML)
31 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
32 };
33 Ok(Self { manifest })
34 }
35
36 pub fn manifest(&self) -> &BackendManifest {
37 &self.manifest
38 }
39}
40
/// Returns the `NpgsqlDataReader` accessor used to read a column of `neutral_type`.
///
/// Neutral types without a dedicated accessor fall back to the untyped `GetValue`.
fn reader_method(neutral_type: &str) -> &'static str {
    // Each entry pairs the neutral type names with the reader call that handles them.
    const ACCESSORS: &[(&[&str], &str)] = &[
        (&["bool"], "GetBoolean"),
        (&["int16"], "GetInt16"),
        (&["int32"], "GetInt32"),
        (&["int64"], "GetInt64"),
        (&["float32"], "GetFloat"),
        (&["float64"], "GetDouble"),
        (&["string", "json", "inet", "interval"], "GetString"),
        (&["uuid"], "GetGuid"),
        (&["decimal"], "GetDecimal"),
        (&["date"], "GetFieldValue<DateOnly>"),
        (&["time", "time_tz"], "GetFieldValue<TimeOnly>"),
        (&["datetime"], "GetDateTime"),
        (&["datetime_tz"], "GetFieldValue<DateTimeOffset>"),
    ];

    ACCESSORS
        .iter()
        .find(|(names, _)| names.iter().any(|name| *name == neutral_type))
        .map_or("GetValue", |&(_, method)| method)
}
60
/// Rewrites PostgreSQL positional placeholders (`$1`, `$2`, ...) into the
/// `@p1`, `@p2`, ... names bound by the generated `AddWithValue("pN", ...)` calls.
///
/// The SQL is scanned once, left to right, and placeholders are rewritten only in
/// ordinary SQL text. The following regions are copied verbatim:
/// - single-quoted string literals (including `''`, and backslash escapes in `E'...'`),
/// - double-quoted identifiers,
/// - `--` line comments and `/* ... */` block comments,
/// - dollar-quoted bodies (`$$ ... $$`, `$tag$ ... $tag$`),
/// - a `$` that continues an identifier (e.g. `col$1`).
///
/// An unterminated quote or comment extends to the end of the input. Nested block
/// comments are not tracked, and plain `'...'` literals are assumed to follow
/// `standard_conforming_strings = on` (backslash is not an escape there).
fn rewrite_params(sql: &str) -> String {
    /// Whether `b` can continue an unquoted identifier. Bytes >= 0x80 belong to
    /// multi-byte UTF-8 characters, which PostgreSQL accepts in identifiers.
    fn is_ident_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || b >= 0x80
    }

    /// Index just past the quoted run opened by `bytes[start]`, or `bytes.len()`
    /// if it is never closed. A doubled quote simply re-opens a run on the next scan.
    fn skip_quoted(bytes: &[u8], start: usize, backslash_escapes: bool) -> usize {
        let quote = bytes[start];
        let mut j = start + 1;
        while j < bytes.len() {
            match bytes[j] {
                b'\\' if backslash_escapes => j += 2,
                b if b == quote => return j + 1,
                _ => j += 1,
            }
        }
        bytes.len()
    }

    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut out = String::with_capacity(len + 8);
    // Bytes before `copied` have already been appended to `out`.
    let mut copied = 0;
    let mut i = 0;

    // Every delimiter matched below is ASCII, so each index used to slice `sql`
    // lands on a UTF-8 character boundary.
    while i < len {
        match bytes[i] {
            b'\'' => {
                // `E'...'` (not part of a longer identifier) enables backslash escapes.
                let escape_string = i >= 1
                    && matches!(bytes[i - 1], b'E' | b'e')
                    && (i == 1 || !is_ident_byte(bytes[i - 2]));
                i = skip_quoted(bytes, i, escape_string);
            }
            b'"' => i = skip_quoted(bytes, i, false),
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                i = sql[i..].find('\n').map_or(len, |p| i + p + 1);
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                i = sql[i + 2..].find("*/").map_or(len, |p| i + 2 + p + 2);
            }
            b'$' if i == 0 || !is_ident_byte(bytes[i - 1]) => {
                let digits = bytes[i + 1..]
                    .iter()
                    .take_while(|b| b.is_ascii_digit())
                    .count();
                if digits > 0 {
                    let end = i + 1 + digits;
                    out.push_str(&sql[copied..i]);
                    out.push_str("@p");
                    out.push_str(&sql[i + 1..end]);
                    copied = end;
                    i = end;
                } else {
                    // `$$` or `$tag$` opens a dollar-quoted body closed by the same tag.
                    let tag_len = bytes[i + 1..]
                        .iter()
                        .take_while(|&&b| b != b'$' && is_ident_byte(b))
                        .count();
                    let tag_end = i + 1 + tag_len;
                    if bytes.get(tag_end) == Some(&b'$') {
                        let tag = &sql[i..=tag_end];
                        let body = tag_end + 1;
                        i = sql[body..].find(tag).map_or(len, |p| body + p + tag.len());
                    } else {
                        i += 1;
                    }
                }
            }
            _ => i += 1,
        }
    }

    out.push_str(&sql[copied..]);
    out
}
72
73impl CodegenBackend for CsharpNpgsqlBackend {
74 fn name(&self) -> &str {
75 "csharp-npgsql"
76 }
77
78 fn generate_row_struct(
79 &self,
80 query_name: &str,
81 columns: &[ResolvedColumn],
82 ) -> Result<String, ScytheError> {
83 let struct_name = row_struct_name(query_name, &self.manifest.naming);
84 let mut out = String::new();
85 let _ = writeln!(out, "public record {}(", struct_name);
86 for (i, c) in columns.iter().enumerate() {
87 let field = to_pascal_case(&c.field_name);
88 let sep = if i + 1 < columns.len() { "," } else { "" };
89 let _ = writeln!(out, " {} {}{}", c.full_type, field, sep);
90 }
91 let _ = write!(out, ");");
92 Ok(out)
93 }
94
95 fn generate_model_struct(
96 &self,
97 table_name: &str,
98 columns: &[ResolvedColumn],
99 ) -> Result<String, ScytheError> {
100 let name = to_pascal_case(table_name);
101 self.generate_row_struct(&name, columns)
102 }
103
104 fn generate_query_fn(
105 &self,
106 analyzed: &AnalyzedQuery,
107 struct_name: &str,
108 columns: &[ResolvedColumn],
109 params: &[ResolvedParam],
110 ) -> Result<String, ScytheError> {
111 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
112 let sql = rewrite_params(&super::clean_sql(&analyzed.sql));
113 let mut out = String::new();
114
115 let param_list = params
117 .iter()
118 .map(|p| format!("{} {}", p.full_type, p.field_name))
119 .collect::<Vec<_>>()
120 .join(", ");
121 let sep = if param_list.is_empty() { "" } else { ", " };
122
123 let return_type = match &analyzed.command {
125 QueryCommand::One => format!("{}?", struct_name),
126 QueryCommand::Many | QueryCommand::Batch => {
127 format!("List<{}>", struct_name)
128 }
129 QueryCommand::Exec => "void".to_string(),
130 QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
131 };
132
133 let is_async_void = return_type == "void";
134 let task_type = if is_async_void {
135 "Task".to_string()
136 } else {
137 format!("Task<{}>", return_type)
138 };
139
140 let _ = writeln!(
141 out,
142 "public static async {} {}(NpgsqlConnection conn{}{}) {{",
143 task_type, func_name, sep, param_list
144 );
145
146 let _ = writeln!(
148 out,
149 " await using var cmd = new NpgsqlCommand(\"{}\", conn);",
150 sql
151 );
152 for (i, p) in params.iter().enumerate() {
153 let _ = writeln!(
154 out,
155 " cmd.Parameters.AddWithValue(\"p{}\", {});",
156 i + 1,
157 p.field_name
158 );
159 }
160
161 match &analyzed.command {
162 QueryCommand::One => {
163 let _ = writeln!(
164 out,
165 " await using var reader = await cmd.ExecuteReaderAsync();"
166 );
167 let _ = writeln!(out, " if (!await reader.ReadAsync()) return null;");
168 let _ = writeln!(out, " return new {}(", struct_name);
169 for (i, col) in columns.iter().enumerate() {
170 let method = reader_method(&col.neutral_type);
171 let sep = if i + 1 < columns.len() { "," } else { "" };
172 if col.nullable {
173 let _ = writeln!(
174 out,
175 " reader.IsDBNull({i}) ? null : reader.{method}({i}){sep}"
176 );
177 } else {
178 let _ = writeln!(out, " reader.{method}({i}){sep}");
179 }
180 }
181 let _ = writeln!(out, " );");
182 }
183 QueryCommand::Many | QueryCommand::Batch => {
184 let _ = writeln!(
185 out,
186 " await using var reader = await cmd.ExecuteReaderAsync();"
187 );
188 let _ = writeln!(out, " var results = new List<{}>();", struct_name);
189 let _ = writeln!(out, " while (await reader.ReadAsync()) {{");
190 let _ = writeln!(out, " results.Add(new {}(", struct_name);
191 for (i, col) in columns.iter().enumerate() {
192 let method = reader_method(&col.neutral_type);
193 let sep = if i + 1 < columns.len() { "," } else { "" };
194 if col.nullable {
195 let _ = writeln!(
196 out,
197 " reader.IsDBNull({i}) ? null : reader.{method}({i}){sep}"
198 );
199 } else {
200 let _ = writeln!(out, " reader.{method}({i}){sep}");
201 }
202 }
203 let _ = writeln!(out, " ));");
204 let _ = writeln!(out, " }}");
205 let _ = writeln!(out, " return results;");
206 }
207 QueryCommand::Exec => {
208 let _ = writeln!(out, " await cmd.ExecuteNonQueryAsync();");
209 }
210 QueryCommand::ExecResult | QueryCommand::ExecRows => {
211 let _ = writeln!(out, " return await cmd.ExecuteNonQueryAsync();");
212 }
213 }
214
215 let _ = write!(out, "}}");
216 Ok(out)
217 }
218
219 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
220 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
221 let mut out = String::new();
222 let _ = writeln!(out, "public enum {} {{", type_name);
223 for value in &enum_info.values {
224 let variant = enum_variant_name(value, &self.manifest.naming);
225 let _ = writeln!(out, " {},", variant);
226 }
227 let _ = write!(out, "}}");
228 Ok(out)
229 }
230
231 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
232 let name = to_pascal_case(&composite.sql_name);
233 let mut out = String::new();
234 if composite.fields.is_empty() {
235 let _ = writeln!(out, "public record {}();", name);
236 } else {
237 let _ = writeln!(out, "public record {}(", name);
238 for (i, field) in composite.fields.iter().enumerate() {
239 let cs_type = resolve_type(&field.neutral_type, &self.manifest, false)
240 .map(|t| t.into_owned())
241 .unwrap_or_else(|_| "object".to_string());
242 let field_name = to_pascal_case(&field.name);
243 let sep = if i + 1 < composite.fields.len() {
244 ","
245 } else {
246 ""
247 };
248 let _ = writeln!(out, " {} {}{}", cs_type, field_name, sep);
249 }
250 let _ = write!(out, ");");
251 }
252 Ok(out)
253 }
254}