Skip to main content

scythe_codegen/backends/
csharp_npgsql.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/csharp-npgsql.toml");
17
18pub struct CsharpNpgsqlBackend {
19    manifest: BackendManifest,
20}
21
22impl CsharpNpgsqlBackend {
23    pub fn new() -> Result<Self, ScytheError> {
24        let manifest_path = Path::new("backends/csharp-npgsql/manifest.toml");
25        let manifest = if manifest_path.exists() {
26            load_manifest(manifest_path)
27                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
28        } else {
29            toml::from_str(DEFAULT_MANIFEST_TOML)
30                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
31        };
32        Ok(Self { manifest })
33    }
34
35    pub fn manifest(&self) -> &BackendManifest {
36        &self.manifest
37    }
38}
39
/// Pick the `NpgsqlDataReader` accessor that reads a column of `neutral_type`.
///
/// The returned text is spliced verbatim after `reader.` in the generated C#,
/// so generic accessors carry their type argument (e.g.
/// `GetFieldValue<DateOnly>`). Neutral types without an entry fall back to
/// `GetValue`.
///
/// NOTE(review): `interval` and `inet` are read with `GetString`, and `time_tz`
/// as `TimeOnly`; confirm these agree with the C# types the manifest assigns and
/// with the conversions Npgsql supports for those PostgreSQL types.
fn reader_method(neutral_type: &str) -> &'static str {
    // (neutral types, accessor) pairs; no neutral type appears more than once.
    const ACCESSORS: &[(&[&str], &str)] = &[
        (&["bool"], "GetBoolean"),
        (&["int16"], "GetInt16"),
        (&["int32"], "GetInt32"),
        (&["int64"], "GetInt64"),
        (&["float32"], "GetFloat"),
        (&["float64"], "GetDouble"),
        (&["decimal"], "GetDecimal"),
        (&["string", "json", "inet", "interval"], "GetString"),
        (&["uuid"], "GetGuid"),
        (&["date"], "GetFieldValue<DateOnly>"),
        (&["time", "time_tz"], "GetFieldValue<TimeOnly>"),
        (&["datetime"], "GetDateTime"),
        (&["datetime_tz"], "GetFieldValue<DateTimeOffset>"),
    ];

    ACCESSORS
        .iter()
        .find(|(names, _)| names.iter().any(|name| *name == neutral_type))
        .map_or("GetValue", |&(_, accessor)| accessor)
}
59
/// Rewrite PostgreSQL positional placeholders (`$1`, `$2`, ...) to the Npgsql
/// named parameters `@p1`, `@p2`, ... that the generated command binds.
///
/// The SQL is scanned once, left to right, following the PostgreSQL lexer
/// closely enough that only real placeholders are touched. These are copied
/// through verbatim:
///
/// - single-quoted string literals, including `E'...'` escape strings
///   (`'cost $1'` is data, not a parameter);
/// - double-quoted identifiers (`"a$1"`);
/// - dollar-quoted strings (`$$ ... $$`, `$tag$ ... $tag$`);
/// - `--` line comments and (nestable) `/* ... */` block comments;
/// - a `$` that continues an identifier, since PostgreSQL allows `$` inside
///   identifiers (`col$1` is a name, not a placeholder).
///
/// Placeholders of any width are supported (`$100` becomes `@p100`); `$0` is
/// not a valid placeholder and is left unchanged. An unterminated quote or
/// comment causes the rest of the input to be copied unchanged.
fn rewrite_params(sql: &str) -> String {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut out = String::with_capacity(len + 8);
    // Start of the verbatim run that has not been copied to `out` yet.
    let mut copied = 0;
    let mut i = 0;

    while i < len {
        match bytes[i] {
            b'\'' => {
                let backslash_escapes = is_escape_string_prefix(bytes, i);
                i = skip_quoted(bytes, i + 1, b'\'', backslash_escapes);
            }
            b'"' => i = skip_quoted(bytes, i + 1, b'"', false),
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                i = bytes[i..]
                    .iter()
                    .position(|&b| b == b'\n')
                    .map_or(len, |p| i + p + 1);
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => i = skip_block_comment(bytes, i + 2),
            // `$` inside an identifier such as `col$1`.
            b'$' if i > 0 && is_ident_byte(bytes[i - 1]) => i += 1,
            b'$' if matches!(bytes.get(i + 1), Some(&d) if (b'1'..=b'9').contains(&d)) => {
                let end = bytes[i + 1..]
                    .iter()
                    .position(|b| !b.is_ascii_digit())
                    .map_or(len, |p| i + 1 + p);
                // `i` and `end` sit next to ASCII bytes, so both are char boundaries.
                out.push_str(&sql[copied..i]);
                out.push_str("@p");
                out.push_str(&sql[i + 1..end]);
                copied = end;
                i = end;
            }
            b'$' => {
                i = match dollar_quote_delimiter_len(bytes, i) {
                    Some(delim_len) => {
                        let body = i + delim_len;
                        let delimiter = &bytes[i..body];
                        bytes[body..]
                            .windows(delim_len)
                            .position(|w| w == delimiter)
                            .map_or(len, |p| body + p + delim_len)
                    }
                    None => i + 1,
                };
            }
            _ => i += 1,
        }
    }

    out.push_str(&sql[copied..]);
    out
}

/// Whether `b` can continue a PostgreSQL identifier: ASCII letters and digits,
/// `_`, `$`, or any byte of a non-ASCII character.
fn is_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || b >= 0x80
}

/// Whether the quote at `quote_pos` opens an `E'...'` escape string, in which a
/// backslash escapes the following character.
fn is_escape_string_prefix(bytes: &[u8], quote_pos: usize) -> bool {
    quote_pos >= 1
        && matches!(bytes[quote_pos - 1], b'E' | b'e')
        && (quote_pos < 2 || !is_ident_byte(bytes[quote_pos - 2]))
}

/// Skip a quoted token whose opening `quote` sits just before `start`.
///
/// Returns the index just past the closing quote, or `bytes.len()` if the
/// token is unterminated. A doubled quote is an escaped quote; with
/// `backslash_escapes`, `\` escapes the following byte.
fn skip_quoted(bytes: &[u8], start: usize, quote: u8, backslash_escapes: bool) -> usize {
    let mut i = start;
    while i < bytes.len() {
        match bytes[i] {
            b'\\' if backslash_escapes => i += 2,
            b if b == quote && bytes.get(i + 1) == Some(&quote) => i += 2,
            b if b == quote => return i + 1,
            _ => i += 1,
        }
    }
    bytes.len()
}

/// Skip a (possibly nested) block comment whose opening `/*` ends just before
/// `start`.
///
/// Returns the index just past the matching `*/`, or `bytes.len()` if the
/// comment is unterminated.
fn skip_block_comment(bytes: &[u8], start: usize) -> usize {
    let mut depth = 1usize;
    let mut i = start;
    while i < bytes.len() {
        match (bytes[i], bytes.get(i + 1)) {
            (b'/', Some(&b'*')) => {
                depth += 1;
                i += 2;
            }
            (b'*', Some(&b'/')) => {
                depth -= 1;
                i += 2;
                if depth == 0 {
                    return i;
                }
            }
            _ => i += 1,
        }
    }
    bytes.len()
}

/// If a dollar-quote delimiter (`$$` or `$tag$`) starts at `start`, return its
/// length in bytes.
///
/// Tags follow identifier rules, except that they cannot contain `$` and cannot
/// start with a digit.
fn dollar_quote_delimiter_len(bytes: &[u8], start: usize) -> Option<usize> {
    let mut j = start + 1;
    while let Some(&b) = bytes.get(j) {
        match b {
            b'$' => return Some(j + 1 - start),
            b'0'..=b'9' if j == start + 1 => return None,
            _ if b == b'_' || b.is_ascii_alphanumeric() || b >= 0x80 => j += 1,
            _ => return None,
        }
    }
    None
}
71
72impl CodegenBackend for CsharpNpgsqlBackend {
73    fn name(&self) -> &str {
74        "csharp-npgsql"
75    }
76
77    fn generate_row_struct(
78        &self,
79        query_name: &str,
80        columns: &[ResolvedColumn],
81    ) -> Result<String, ScytheError> {
82        let struct_name = row_struct_name(query_name, &self.manifest.naming);
83        let mut out = String::new();
84        let _ = writeln!(out, "public record {}(", struct_name);
85        for (i, c) in columns.iter().enumerate() {
86            let field = to_pascal_case(&c.field_name);
87            let sep = if i + 1 < columns.len() { "," } else { "" };
88            let _ = writeln!(out, "    {} {}{}", c.full_type, field, sep);
89        }
90        let _ = write!(out, ");");
91        Ok(out)
92    }
93
94    fn generate_model_struct(
95        &self,
96        table_name: &str,
97        columns: &[ResolvedColumn],
98    ) -> Result<String, ScytheError> {
99        let name = to_pascal_case(table_name);
100        self.generate_row_struct(&name, columns)
101    }
102
103    fn generate_query_fn(
104        &self,
105        analyzed: &AnalyzedQuery,
106        struct_name: &str,
107        columns: &[ResolvedColumn],
108        params: &[ResolvedParam],
109    ) -> Result<String, ScytheError> {
110        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
111        let sql = rewrite_params(&super::clean_sql(&analyzed.sql));
112        let mut out = String::new();
113
114        // Build C# parameter list
115        let param_list = params
116            .iter()
117            .map(|p| format!("{} {}", p.full_type, p.field_name))
118            .collect::<Vec<_>>()
119            .join(", ");
120        let sep = if param_list.is_empty() { "" } else { ", " };
121
122        // Return type depends on command
123        let return_type = match &analyzed.command {
124            QueryCommand::One => format!("{}?", struct_name),
125            QueryCommand::Many | QueryCommand::Batch => {
126                format!("List<{}>", struct_name)
127            }
128            QueryCommand::Exec => "void".to_string(),
129            QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
130        };
131
132        let is_async_void = return_type == "void";
133        let task_type = if is_async_void {
134            "Task".to_string()
135        } else {
136            format!("Task<{}>", return_type)
137        };
138
139        let _ = writeln!(
140            out,
141            "public static async {} {}(NpgsqlConnection conn{}{}) {{",
142            task_type, func_name, sep, param_list
143        );
144
145        // Command setup
146        let _ = writeln!(
147            out,
148            "    await using var cmd = new NpgsqlCommand(\"{}\", conn);",
149            sql
150        );
151        for (i, p) in params.iter().enumerate() {
152            let _ = writeln!(
153                out,
154                "    cmd.Parameters.AddWithValue(\"p{}\", {});",
155                i + 1,
156                p.field_name
157            );
158        }
159
160        match &analyzed.command {
161            QueryCommand::One => {
162                let _ = writeln!(
163                    out,
164                    "    await using var reader = await cmd.ExecuteReaderAsync();"
165                );
166                let _ = writeln!(out, "    if (!await reader.ReadAsync()) return null;");
167                let _ = writeln!(out, "    return new {}(", struct_name);
168                for (i, col) in columns.iter().enumerate() {
169                    let method = reader_method(&col.neutral_type);
170                    let sep = if i + 1 < columns.len() { "," } else { "" };
171                    if col.nullable {
172                        let _ = writeln!(
173                            out,
174                            "        reader.IsDBNull({i}) ? null : reader.{method}({i}){sep}"
175                        );
176                    } else {
177                        let _ = writeln!(out, "        reader.{method}({i}){sep}");
178                    }
179                }
180                let _ = writeln!(out, "    );");
181            }
182            QueryCommand::Many | QueryCommand::Batch => {
183                let _ = writeln!(
184                    out,
185                    "    await using var reader = await cmd.ExecuteReaderAsync();"
186                );
187                let _ = writeln!(out, "    var results = new List<{}>();", struct_name);
188                let _ = writeln!(out, "    while (await reader.ReadAsync()) {{");
189                let _ = writeln!(out, "        results.Add(new {}(", struct_name);
190                for (i, col) in columns.iter().enumerate() {
191                    let method = reader_method(&col.neutral_type);
192                    let sep = if i + 1 < columns.len() { "," } else { "" };
193                    if col.nullable {
194                        let _ = writeln!(
195                            out,
196                            "            reader.IsDBNull({i}) ? null : reader.{method}({i}){sep}"
197                        );
198                    } else {
199                        let _ = writeln!(out, "            reader.{method}({i}){sep}");
200                    }
201                }
202                let _ = writeln!(out, "        ));");
203                let _ = writeln!(out, "    }}");
204                let _ = writeln!(out, "    return results;");
205            }
206            QueryCommand::Exec => {
207                let _ = writeln!(out, "    await cmd.ExecuteNonQueryAsync();");
208            }
209            QueryCommand::ExecResult | QueryCommand::ExecRows => {
210                let _ = writeln!(out, "    return await cmd.ExecuteNonQueryAsync();");
211            }
212        }
213
214        let _ = write!(out, "}}");
215        Ok(out)
216    }
217
218    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
219        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
220        let mut out = String::new();
221        let _ = writeln!(out, "public enum {} {{", type_name);
222        for value in &enum_info.values {
223            let variant = enum_variant_name(value, &self.manifest.naming);
224            let _ = writeln!(out, "    {},", variant);
225        }
226        let _ = write!(out, "}}");
227        Ok(out)
228    }
229
230    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
231        let name = to_pascal_case(&composite.sql_name);
232        let mut out = String::new();
233        if composite.fields.is_empty() {
234            let _ = writeln!(out, "public record {}();", name);
235        } else {
236            let _ = writeln!(out, "public record {}(", name);
237            for (i, field) in composite.fields.iter().enumerate() {
238                let cs_type = resolve_type(&field.neutral_type, &self.manifest, false)
239                    .map(|t| t.into_owned())
240                    .unwrap_or_else(|_| "object".to_string());
241                let field_name = to_pascal_case(&field.name);
242                let sep = if i + 1 < composite.fields.len() {
243                    ","
244                } else {
245                    ""
246                };
247                let _ = writeln!(out, "    {} {}{}", cs_type, field_name, sep);
248            }
249            let _ = write!(out, ");");
250        }
251        Ok(out)
252    }
253}