Skip to main content

scythe_codegen/backends/
kotlin_jdbc.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/kotlin-jdbc.toml");
17
18pub struct KotlinJdbcBackend {
19    manifest: BackendManifest,
20}
21
22impl KotlinJdbcBackend {
23    pub fn new() -> Result<Self, ScytheError> {
24        let manifest_path = Path::new("backends/kotlin-jdbc/manifest.toml");
25        let manifest = if manifest_path.exists() {
26            load_manifest(manifest_path)
27                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
28        } else {
29            toml::from_str(DEFAULT_MANIFEST_TOML)
30                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
31        };
32        Ok(Self { manifest })
33    }
34
35    pub fn manifest(&self) -> &BackendManifest {
36        &self.manifest
37    }
38}
39
/// Convert PostgreSQL `$1, $2, ...` placeholders to JDBC `?` placeholders.
///
/// - Text inside single-quoted string literals and double-quoted identifiers is copied
///   verbatim, so `'$1'` stays a literal. In `E'...'` strings a backslash escapes the next
///   character. Doubled quotes (`''`, `""`) are handled naturally.
/// - A `?` outside quotes (for example the jsonb `?`, `?|`, `?&` operators) is written as
///   `??`. That is the PgJDBC escape for a literal question mark, so the driver does not treat
///   it as a bind parameter.
///
/// Dollar-quoted strings (`$$...$$`, `$tag$...$tag$`) and SQL comments are not recognized;
/// the input is expected to be a single-line statement without them.
fn pg_to_jdbc_params(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    // Quote character of the literal/identifier currently open, if any.
    let mut open_quote: Option<char> = None;
    // Whether backslash escapes apply inside the open literal (only `E'...'` strings).
    let mut backslash_escapes = false;
    // The two most recent input characters, used to recognize an `E'` prefix.
    let (mut prev, mut prev2) = (None::<char>, None::<char>);

    while let Some(ch) = chars.next() {
        if let Some(quote) = open_quote {
            out.push(ch);
            if backslash_escapes && ch == '\\' {
                if let Some(escaped) = chars.next() {
                    out.push(escaped);
                }
            } else if ch == quote {
                open_quote = None;
            }
        } else {
            match ch {
                '\'' | '"' => {
                    // `E'...'` only counts when the `E` is a token of its own, not the tail of
                    // an identifier such as `LIKE'...'`.
                    backslash_escapes = ch == '\''
                        && matches!(prev, Some('e' | 'E'))
                        && !prev2.is_some_and(|c| c.is_alphanumeric() || c == '_' || c == '$');
                    open_quote = Some(ch);
                    out.push(ch);
                }
                '?' => out.push_str("??"),
                '$' if chars.peek().is_some_and(char::is_ascii_digit) => {
                    while chars.next_if(char::is_ascii_digit).is_some() {}
                    out.push('?');
                }
                _ => out.push(ch),
            }
        }
        prev2 = prev;
        prev = Some(ch);
    }
    out
}
60
/// Get the `java.sql.ResultSet` getter method name for a Kotlin type.
///
/// - Kotlin primitives, `String` and `ByteArray` (exact names) map to their dedicated getters.
/// - Any type whose name contains `BigDecimal` maps to `getBigDecimal`.
/// - Everything else, including `java.time` types, `UUID` and unknown types, maps to
///   `getObject`.
///
/// The name is a `'static` literal, so it does not borrow from `kotlin_type`.
fn rs_getter(kotlin_type: &str) -> &'static str {
    match kotlin_type {
        "Boolean" => "getBoolean",
        "Byte" => "getByte",
        "Short" => "getShort",
        "Int" => "getInt",
        "Long" => "getLong",
        "Float" => "getFloat",
        "Double" => "getDouble",
        "String" => "getString",
        "ByteArray" => "getBytes",
        _ if kotlin_type.contains("BigDecimal") => "getBigDecimal",
        // Note: plain `getObject(name)` returns `Object`; callers needing a specific JVM type
        // (e.g. `java.time.LocalDate`) should pass a target class to `getObject`.
        _ => "getObject",
    }
}
83
/// Map a Kotlin type name to the `java.sql.PreparedStatement` method used to bind it.
///
/// - Exact matches on Kotlin primitives, `String` and `ByteArray` use their dedicated setter.
/// - Any type whose name contains `BigDecimal` uses `setBigDecimal`.
/// - Everything else is bound through the generic `setObject`.
fn ps_setter(kotlin_type: &str) -> &str {
    const DEDICATED: [(&str, &str); 9] = [
        ("Boolean", "setBoolean"),
        ("Byte", "setByte"),
        ("Short", "setShort"),
        ("Int", "setInt"),
        ("Long", "setLong"),
        ("Float", "setFloat"),
        ("Double", "setDouble"),
        ("String", "setString"),
        ("ByteArray", "setBytes"),
    ];

    if let Some(&(_, method)) = DEDICATED.iter().find(|(ty, _)| *ty == kotlin_type) {
        return method;
    }
    if kotlin_type.contains("BigDecimal") {
        "setBigDecimal"
    } else {
        "setObject"
    }
}
100
101impl CodegenBackend for KotlinJdbcBackend {
102    fn name(&self) -> &str {
103        "kotlin-jdbc"
104    }
105
106    fn file_header(&self) -> String {
107        "import java.sql.Connection\n".to_string()
108    }
109
110    fn generate_row_struct(
111        &self,
112        query_name: &str,
113        columns: &[ResolvedColumn],
114    ) -> Result<String, ScytheError> {
115        let struct_name = row_struct_name(query_name, &self.manifest.naming);
116        let mut out = String::new();
117        let _ = writeln!(out, "data class {}(", struct_name);
118        for col in columns.iter() {
119            let _ = writeln!(out, "    val {}: {},", col.field_name, col.full_type);
120        }
121        let _ = writeln!(out, ")");
122        Ok(out)
123    }
124
125    fn generate_model_struct(
126        &self,
127        table_name: &str,
128        columns: &[ResolvedColumn],
129    ) -> Result<String, ScytheError> {
130        let name = to_pascal_case(table_name);
131        self.generate_row_struct(&name, columns)
132    }
133
134    fn generate_query_fn(
135        &self,
136        analyzed: &AnalyzedQuery,
137        struct_name: &str,
138        columns: &[ResolvedColumn],
139        params: &[ResolvedParam],
140    ) -> Result<String, ScytheError> {
141        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
142        let sql = pg_to_jdbc_params(&super::clean_sql_oneline(&analyzed.sql));
143
144        // Build function params: inline for single param (conn only), multi-line for 2+
145        let use_multiline_params = !params.is_empty();
146
147        let mut out = String::new();
148
149        // Helper: write param setters
150        let write_setters = |out: &mut String, params: &[ResolvedParam]| {
151            for (i, param) in params.iter().enumerate() {
152                let setter = ps_setter(&param.lang_type);
153                let _ = writeln!(
154                    out,
155                    "        ps.{}({}, {})",
156                    setter,
157                    i + 1,
158                    param.field_name
159                );
160            }
161        };
162
163        // Helper: write function signature
164        let write_fn_sig =
165            |out: &mut String, name: &str, ret: &str, multiline: bool, params: &[ResolvedParam]| {
166                if multiline {
167                    let _ = writeln!(out, "fun {}(", name);
168                    let _ = writeln!(out, "    conn: Connection,");
169                    for p in params {
170                        let _ = writeln!(out, "    {}: {},", p.field_name, p.full_type);
171                    }
172                    let _ = writeln!(out, "){} {{", ret);
173                } else {
174                    let _ = writeln!(out, "fun {}(conn: Connection){} {{", name, ret);
175                }
176            };
177
178        match &analyzed.command {
179            QueryCommand::Exec => {
180                write_fn_sig(&mut out, &func_name, "", use_multiline_params, params);
181                let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
182                write_setters(&mut out, params);
183                let _ = writeln!(out, "        ps.executeUpdate()");
184                let _ = writeln!(out, "    }}");
185                let _ = writeln!(out, "}}");
186            }
187            QueryCommand::ExecResult | QueryCommand::ExecRows => {
188                write_fn_sig(&mut out, &func_name, ": Int", use_multiline_params, params);
189                let _ = writeln!(
190                    out,
191                    "    return conn.prepareStatement(\"{}\").use {{ ps ->",
192                    sql
193                );
194                write_setters(&mut out, params);
195                let _ = writeln!(out, "        ps.executeUpdate()");
196                let _ = writeln!(out, "    }}");
197                let _ = writeln!(out, "}}");
198            }
199            QueryCommand::One => {
200                let ret = format!(": {}?", struct_name);
201                write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
202                let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
203                write_setters(&mut out, params);
204                let _ = writeln!(out, "        ps.executeQuery().use {{ rs ->");
205                let _ = writeln!(out, "            return if (rs.next()) {{");
206                let _ = writeln!(out, "                {}(", struct_name);
207                for col in columns.iter() {
208                    let getter = rs_getter(&col.lang_type);
209                    let _ = writeln!(
210                        out,
211                        "                    {} = rs.{}(\"{}\"),",
212                        col.field_name, getter, col.name
213                    );
214                }
215                let _ = writeln!(out, "                )");
216                let _ = writeln!(out, "            }} else {{");
217                let _ = writeln!(out, "                null");
218                let _ = writeln!(out, "            }}");
219                let _ = writeln!(out, "        }}");
220                let _ = writeln!(out, "    }}");
221                let _ = writeln!(out, "}}");
222            }
223            QueryCommand::Many | QueryCommand::Batch => {
224                let ret = format!(": List<{}>", struct_name);
225                write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
226                let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
227                write_setters(&mut out, params);
228                let _ = writeln!(out, "        ps.executeQuery().use {{ rs ->");
229                let _ = writeln!(
230                    out,
231                    "            val result = mutableListOf<{}>()",
232                    struct_name
233                );
234                let _ = writeln!(out, "            while (rs.next()) {{");
235                let _ = writeln!(out, "                result.add(");
236                let _ = writeln!(out, "                    {}(", struct_name);
237                for col in columns.iter() {
238                    let getter = rs_getter(&col.lang_type);
239                    let _ = writeln!(
240                        out,
241                        "                        {} = rs.{}(\"{}\"),",
242                        col.field_name, getter, col.name
243                    );
244                }
245                let _ = writeln!(out, "                    ),");
246                let _ = writeln!(out, "                )");
247                let _ = writeln!(out, "            }}");
248                let _ = writeln!(out, "            return result");
249                let _ = writeln!(out, "        }}");
250                let _ = writeln!(out, "    }}");
251                let _ = writeln!(out, "}}");
252            }
253        }
254
255        Ok(out)
256    }
257
258    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
259        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
260        let mut out = String::new();
261        let _ = writeln!(out, "enum class {}(val value: String) {{", type_name);
262        for (i, value) in enum_info.values.iter().enumerate() {
263            let variant = enum_variant_name(value, &self.manifest.naming);
264            let sep = if i + 1 < enum_info.values.len() {
265                ","
266            } else {
267                ";"
268            };
269            let _ = writeln!(out, "    {}(\"{}\"){}", variant, value, sep);
270        }
271        let _ = writeln!(out, "}}");
272        Ok(out)
273    }
274
275    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
276        let name = to_pascal_case(&composite.sql_name);
277        let mut out = String::new();
278        let _ = writeln!(out, "data class {}(", name);
279        for field in composite.fields.iter() {
280            let field_name = to_camel_case(&field.name);
281            let field_type = resolve_type(&field.neutral_type, &self.manifest, false)
282                .map(|t| t.into_owned())
283                .unwrap_or_else(|_| "Any".to_string());
284            let _ = writeln!(out, "    val {}: {},", field_name, field_type);
285        }
286        let _ = writeln!(out, ")");
287        Ok(out)
288    }
289}