Skip to main content

scythe_codegen/backends/
kotlin_jdbc.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
// Default manifests, one per engine family, embedded at compile time via `include_str!`.
// `KotlinJdbcBackend::new` picks one by engine name and uses it unless an on-disk
// override manifest is present.
const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/kotlin-jdbc.toml");
const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/kotlin-jdbc.mysql.toml");
const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/kotlin-jdbc.sqlite.toml");
const DEFAULT_MANIFEST_DUCKDB: &str = include_str!("../../manifests/kotlin-jdbc.duckdb.toml");
20
/// Code generator that emits Kotlin source calling plain `java.sql` JDBC APIs.
///
/// Construct with [`KotlinJdbcBackend::new`], which selects the manifest for the
/// target database engine.
pub struct KotlinJdbcBackend {
    /// Type-mapping and naming rules used while generating code (either the
    /// bundled default for the engine or an on-disk override).
    manifest: BackendManifest,
}
24
25impl KotlinJdbcBackend {
26    pub fn new(engine: &str) -> Result<Self, ScytheError> {
27        let default_toml = match engine {
28            "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
29            "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
30            "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
31            "duckdb" => DEFAULT_MANIFEST_DUCKDB,
32            _ => {
33                return Err(ScytheError::new(
34                    ErrorCode::InternalError,
35                    format!("unsupported engine '{}' for kotlin-jdbc backend", engine),
36                ));
37            }
38        };
39        let manifest_path = Path::new("backends/kotlin-jdbc/manifest.toml");
40        let manifest = if manifest_path.exists() {
41            load_manifest(manifest_path)
42                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43        } else {
44            toml::from_str(default_toml)
45                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
46        };
47        Ok(Self { manifest })
48    }
49}
50
/// Convert PostgreSQL `$1`, `$2`, ... placeholders to JDBC `?` placeholders.
///
/// Every `$` immediately followed by one or more ASCII digits, outside quotes,
/// is replaced by a single `?`. Text inside single-quoted string literals
/// (`'...'`) and double-quoted identifiers (`"..."`) is copied verbatim, so a
/// literal such as `'$1'` keeps its content. A doubled quote (`''` / `""`)
/// inside a quoted span closes and immediately reopens it, which leaves the
/// span intact.
///
/// Limitations: backslash escapes in `E'...'` strings and dollar-quoted bodies
/// (`$$ ... $$`) are not tracked. JDBC `?` markers are purely positional, so the
/// number of each `$n` is discarded. Callers binding parameters by position must
/// ensure each `$n` occurs once and in ascending order.
fn pg_to_jdbc_params(sql: &str) -> String {
    let mut result = String::with_capacity(sql.len());
    // Quote character of the literal / identifier currently being copied, if any.
    let mut open_quote: Option<char> = None;
    let mut chars = sql.chars().peekable();
    while let Some(ch) = chars.next() {
        if let Some(quote) = open_quote {
            // Inside a quoted span: copy everything; the matching quote closes it.
            if ch == quote {
                open_quote = None;
            }
            result.push(ch);
            continue;
        }
        match ch {
            '\'' | '"' => {
                open_quote = Some(ch);
                result.push(ch);
            }
            '$' if chars.peek().is_some_and(|c| c.is_ascii_digit()) => {
                // Consume the whole parameter number (`$12` -> one `?`).
                while chars.peek().is_some_and(|c| c.is_ascii_digit()) {
                    chars.next();
                }
                result.push('?');
            }
            _ => result.push(ch),
        }
    }
    result
}
71
/// Name of the `java.sql.ResultSet` getter used to read a value of `kotlin_type`.
///
/// Exact matches on the primitive types, `String` and `ByteArray` select a
/// dedicated getter. Any type whose name mentions `BigDecimal` uses
/// `getBigDecimal`. Everything else, including `java.time` types and `UUID`,
/// falls back to `getObject`.
fn rs_getter(kotlin_type: &str) -> &str {
    const DEDICATED: [(&str, &str); 9] = [
        ("Boolean", "getBoolean"),
        ("Byte", "getByte"),
        ("Short", "getShort"),
        ("Int", "getInt"),
        ("Long", "getLong"),
        ("Float", "getFloat"),
        ("Double", "getDouble"),
        ("String", "getString"),
        ("ByteArray", "getBytes"),
    ];

    if let Some(&(_, getter)) = DEDICATED.iter().find(|(ty, _)| *ty == kotlin_type) {
        getter
    } else if kotlin_type.contains("BigDecimal") {
        "getBigDecimal"
    } else {
        "getObject"
    }
}
94
/// Name of the `java.sql.PreparedStatement` setter used to bind a value of `kotlin_type`.
///
/// The primitive types, `String` and `ByteArray` (exact names) have dedicated
/// setters. A type whose name mentions `BigDecimal` uses `setBigDecimal`.
/// Anything else is bound with `setObject`.
fn ps_setter(kotlin_type: &str) -> &str {
    let dedicated = match kotlin_type {
        "Boolean" => Some("setBoolean"),
        "Byte" => Some("setByte"),
        "Short" => Some("setShort"),
        "Int" => Some("setInt"),
        "Long" => Some("setLong"),
        "Float" => Some("setFloat"),
        "Double" => Some("setDouble"),
        "String" => Some("setString"),
        "ByteArray" => Some("setBytes"),
        _ => None,
    };

    dedicated.unwrap_or_else(|| {
        if kotlin_type.contains("BigDecimal") {
            "setBigDecimal"
        } else {
            "setObject"
        }
    })
}
111
112impl CodegenBackend for KotlinJdbcBackend {
113    fn name(&self) -> &str {
114        "kotlin-jdbc"
115    }
116
117    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
118        &self.manifest
119    }
120
121    fn supported_engines(&self) -> &[&str] {
122        &["postgresql", "mysql", "sqlite", "duckdb"]
123    }
124
125    fn file_header(&self) -> String {
126        "import java.sql.Connection\n".to_string()
127    }
128
129    fn generate_row_struct(
130        &self,
131        query_name: &str,
132        columns: &[ResolvedColumn],
133    ) -> Result<String, ScytheError> {
134        let struct_name = row_struct_name(query_name, &self.manifest.naming);
135        let mut out = String::new();
136        let _ = writeln!(out, "data class {}(", struct_name);
137        for col in columns.iter() {
138            let _ = writeln!(out, "    val {}: {},", col.field_name, col.full_type);
139        }
140        let _ = writeln!(out, ")");
141        Ok(out)
142    }
143
144    fn generate_model_struct(
145        &self,
146        table_name: &str,
147        columns: &[ResolvedColumn],
148    ) -> Result<String, ScytheError> {
149        let name = to_pascal_case(table_name);
150        self.generate_row_struct(&name, columns)
151    }
152
153    fn generate_query_fn(
154        &self,
155        analyzed: &AnalyzedQuery,
156        struct_name: &str,
157        columns: &[ResolvedColumn],
158        params: &[ResolvedParam],
159    ) -> Result<String, ScytheError> {
160        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
161        let sql = pg_to_jdbc_params(&super::clean_sql_oneline_with_optional(
162            &analyzed.sql,
163            &analyzed.optional_params,
164            &analyzed.params,
165        ));
166
167        // Build function params: inline for single param (conn only), multi-line for 2+
168        let use_multiline_params = !params.is_empty();
169
170        let mut out = String::new();
171
172        // Helper: write param setters
173        let write_setters = |out: &mut String, params: &[ResolvedParam]| {
174            for (i, param) in params.iter().enumerate() {
175                let setter = ps_setter(&param.lang_type);
176                let _ = writeln!(
177                    out,
178                    "        ps.{}({}, {})",
179                    setter,
180                    i + 1,
181                    param.field_name
182                );
183            }
184        };
185
186        // Helper: write function signature
187        let write_fn_sig =
188            |out: &mut String, name: &str, ret: &str, multiline: bool, params: &[ResolvedParam]| {
189                if multiline {
190                    let _ = writeln!(out, "fun {}(", name);
191                    let _ = writeln!(out, "    conn: Connection,");
192                    for p in params {
193                        let _ = writeln!(out, "    {}: {},", p.field_name, p.full_type);
194                    }
195                    let _ = writeln!(out, "){} {{", ret);
196                } else {
197                    let _ = writeln!(out, "fun {}(conn: Connection){} {{", name, ret);
198                }
199            };
200
201        match &analyzed.command {
202            QueryCommand::Exec => {
203                write_fn_sig(&mut out, &func_name, "", use_multiline_params, params);
204                let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
205                write_setters(&mut out, params);
206                let _ = writeln!(out, "        ps.executeUpdate()");
207                let _ = writeln!(out, "    }}");
208                let _ = writeln!(out, "}}");
209            }
210            QueryCommand::ExecResult | QueryCommand::ExecRows => {
211                write_fn_sig(&mut out, &func_name, ": Int", use_multiline_params, params);
212                let _ = writeln!(
213                    out,
214                    "    return conn.prepareStatement(\"{}\").use {{ ps ->",
215                    sql
216                );
217                write_setters(&mut out, params);
218                let _ = writeln!(out, "        ps.executeUpdate()");
219                let _ = writeln!(out, "    }}");
220                let _ = writeln!(out, "}}");
221            }
222            QueryCommand::One => {
223                let ret = format!(": {}?", struct_name);
224                write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
225                let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
226                write_setters(&mut out, params);
227                let _ = writeln!(out, "        ps.executeQuery().use {{ rs ->");
228                let _ = writeln!(out, "            return if (rs.next()) {{");
229                let _ = writeln!(out, "                {}(", struct_name);
230                for col in columns.iter() {
231                    let getter = rs_getter(&col.lang_type);
232                    let _ = writeln!(
233                        out,
234                        "                    {} = rs.{}(\"{}\"),",
235                        col.field_name, getter, col.name
236                    );
237                }
238                let _ = writeln!(out, "                )");
239                let _ = writeln!(out, "            }} else {{");
240                let _ = writeln!(out, "                null");
241                let _ = writeln!(out, "            }}");
242                let _ = writeln!(out, "        }}");
243                let _ = writeln!(out, "    }}");
244                let _ = writeln!(out, "}}");
245            }
246            QueryCommand::Batch => {
247                let batch_fn_name = format!("{}Batch", func_name);
248                if params.len() > 1 {
249                    let params_class_name =
250                        format!("{}BatchParams", to_pascal_case(&analyzed.name));
251                    let _ = writeln!(out, "data class {}(", params_class_name);
252                    for p in params {
253                        let _ = writeln!(out, "    val {}: {},", p.field_name, p.full_type);
254                    }
255                    let _ = writeln!(out, ")");
256                    let _ = writeln!(out);
257                    let _ = writeln!(out, "fun {}(", batch_fn_name);
258                    let _ = writeln!(out, "    conn: Connection,");
259                    let _ = writeln!(out, "    items: List<{}>,", params_class_name);
260                    let _ = writeln!(out, ") {{");
261                    let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
262                    let _ = writeln!(out, "        for (item in items) {{");
263                    for (i, param) in params.iter().enumerate() {
264                        let setter = ps_setter(&param.lang_type);
265                        let _ = writeln!(
266                            out,
267                            "            ps.{}({}, item.{})",
268                            setter,
269                            i + 1,
270                            param.field_name
271                        );
272                    }
273                    let _ = writeln!(out, "            ps.addBatch()");
274                    let _ = writeln!(out, "        }}");
275                    let _ = writeln!(out, "        ps.executeBatch()");
276                    let _ = writeln!(out, "    }}");
277                    let _ = writeln!(out, "}}");
278                } else if params.len() == 1 {
279                    let _ = writeln!(out, "fun {}(", batch_fn_name);
280                    let _ = writeln!(out, "    conn: Connection,");
281                    let _ = writeln!(out, "    items: List<{}>,", params[0].full_type);
282                    let _ = writeln!(out, ") {{");
283                    let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
284                    let _ = writeln!(out, "        for (item in items) {{");
285                    let setter = ps_setter(&params[0].lang_type);
286                    let _ = writeln!(out, "            ps.{}(1, item)", setter);
287                    let _ = writeln!(out, "            ps.addBatch()");
288                    let _ = writeln!(out, "        }}");
289                    let _ = writeln!(out, "        ps.executeBatch()");
290                    let _ = writeln!(out, "    }}");
291                    let _ = writeln!(out, "}}");
292                } else {
293                    let _ = writeln!(
294                        out,
295                        "fun {}(conn: Connection, count: Int) {{",
296                        batch_fn_name
297                    );
298                    let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
299                    let _ = writeln!(out, "        repeat(count) {{");
300                    let _ = writeln!(out, "            ps.addBatch()");
301                    let _ = writeln!(out, "        }}");
302                    let _ = writeln!(out, "        ps.executeBatch()");
303                    let _ = writeln!(out, "    }}");
304                    let _ = writeln!(out, "}}");
305                }
306            }
307            QueryCommand::Many => {
308                let ret = format!(": List<{}>", struct_name);
309                write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
310                let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
311                write_setters(&mut out, params);
312                let _ = writeln!(out, "        ps.executeQuery().use {{ rs ->");
313                let _ = writeln!(
314                    out,
315                    "            val result = mutableListOf<{}>()",
316                    struct_name
317                );
318                let _ = writeln!(out, "            while (rs.next()) {{");
319                let _ = writeln!(out, "                result.add(");
320                let _ = writeln!(out, "                    {}(", struct_name);
321                for col in columns.iter() {
322                    let getter = rs_getter(&col.lang_type);
323                    let _ = writeln!(
324                        out,
325                        "                        {} = rs.{}(\"{}\"),",
326                        col.field_name, getter, col.name
327                    );
328                }
329                let _ = writeln!(out, "                    ),");
330                let _ = writeln!(out, "                )");
331                let _ = writeln!(out, "            }}");
332                let _ = writeln!(out, "            return result");
333                let _ = writeln!(out, "        }}");
334                let _ = writeln!(out, "    }}");
335                let _ = writeln!(out, "}}");
336            }
337            QueryCommand::Grouped => {
338                return Err(ScytheError::new(
339                    ErrorCode::InternalError,
340                    "grouped queries are not yet supported for kotlin-jdbc".to_string(),
341                ));
342            }
343        }
344
345        Ok(out)
346    }
347
348    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
349        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
350        let mut out = String::new();
351        let _ = writeln!(out, "enum class {}(val value: String) {{", type_name);
352        for (i, value) in enum_info.values.iter().enumerate() {
353            let variant = enum_variant_name(value, &self.manifest.naming);
354            let sep = if i + 1 < enum_info.values.len() {
355                ","
356            } else {
357                ";"
358            };
359            let _ = writeln!(out, "    {}(\"{}\"){}", variant, value, sep);
360        }
361        let _ = writeln!(out, "}}");
362        Ok(out)
363    }
364
365    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
366        let name = to_pascal_case(&composite.sql_name);
367        let mut out = String::new();
368        let _ = writeln!(out, "data class {}(", name);
369        for field in composite.fields.iter() {
370            let field_name = to_camel_case(&field.name);
371            let field_type = resolve_type(&field.neutral_type, &self.manifest, false)
372                .map(|t| t.into_owned())
373                .unwrap_or_else(|_| "Any".to_string());
374            let _ = writeln!(out, "    val {}: {},", field_name, field_type);
375        }
376        let _ = writeln!(out, ")");
377        Ok(out)
378    }
379}