Skip to main content

scythe_codegen/backends/kotlin_jdbc.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/kotlin-jdbc.toml");
17const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/kotlin-jdbc.mysql.toml");
18const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/kotlin-jdbc.sqlite.toml");
19
20pub struct KotlinJdbcBackend {
21    manifest: BackendManifest,
22}
23
24impl KotlinJdbcBackend {
25    pub fn new(engine: &str) -> Result<Self, ScytheError> {
26        let default_toml = match engine {
27            "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
28            "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
29            "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
30            _ => {
31                return Err(ScytheError::new(
32                    ErrorCode::InternalError,
33                    format!("unsupported engine '{}' for kotlin-jdbc backend", engine),
34                ));
35            }
36        };
37        let manifest_path = Path::new("backends/kotlin-jdbc/manifest.toml");
38        let manifest = if manifest_path.exists() {
39            load_manifest(manifest_path)
40                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41        } else {
42            toml::from_str(default_toml)
43                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44        };
45        Ok(Self { manifest })
46    }
47}
48
/// Convert PostgreSQL positional placeholders (`$1`, `$2`, ...) to JDBC `?` placeholders.
///
/// Every `$` that is immediately followed by one or more ASCII digits, and that appears
/// outside quotes, becomes a single `?`. Text inside single-quoted string literals
/// (`'...'`) and double-quoted identifiers (`"..."`) is copied verbatim, so a literal
/// such as `'$1'` is preserved. A `$` not followed by a digit is copied unchanged.
///
/// Limitations: JDBC binds `?` strictly by order of appearance, so a placeholder that is
/// reused or that appears out of numeric order is not remapped. Backslash escapes in
/// `E'...'` strings, dollar-quoted bodies (`$tag$...$tag$`) and SQL comments are not
/// recognised. The input is assumed to be comment-free one-line SQL; verify this against
/// the caller.
fn pg_to_jdbc_params(sql: &str) -> String {
    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    // The quote character of the literal/identifier currently being copied, if any.
    let mut in_quote: Option<char> = None;

    while let Some(ch) = chars.next() {
        if let Some(quote) = in_quote {
            // A doubled quote (`''` / `""`) closes and immediately reopens the region,
            // so this simple toggle handles SQL's quote-escaping convention.
            if ch == quote {
                in_quote = None;
            }
            result.push(ch);
        } else if ch == '\'' || ch == '"' {
            in_quote = Some(ch);
            result.push(ch);
        } else if ch == '$' && chars.peek().is_some_and(|c| c.is_ascii_digit()) {
            // Swallow the whole parameter number (`$12` -> one `?`).
            while chars.next_if(char::is_ascii_digit).is_some() {}
            result.push('?');
        } else {
            result.push(ch);
        }
    }
    result
}
69
/// Pick the `java.sql.ResultSet` getter used to read a value of the given Kotlin type.
///
/// `kotlin_type` is compared exactly against the JVM-primitive types, `String` and
/// `ByteArray`. Failing that, any name containing `BigDecimal` maps to `getBigDecimal`,
/// and every other type (the `java.time` types, `UUID`, anything unknown) maps to
/// `getObject`.
fn rs_getter(kotlin_type: &str) -> &str {
    const DEDICATED: [(&str, &str); 9] = [
        ("Boolean", "getBoolean"),
        ("Byte", "getByte"),
        ("Short", "getShort"),
        ("Int", "getInt"),
        ("Long", "getLong"),
        ("Float", "getFloat"),
        ("Double", "getDouble"),
        ("String", "getString"),
        ("ByteArray", "getBytes"),
    ];

    let fallback = if kotlin_type.contains("BigDecimal") {
        "getBigDecimal"
    } else {
        "getObject"
    };

    DEDICATED
        .iter()
        .find(|&&(name, _)| name == kotlin_type)
        .map_or(fallback, |&(_, getter)| getter)
}
92
/// Pick the `java.sql.PreparedStatement` setter used to bind a value of the given Kotlin type.
///
/// The JVM-primitive types, `String` and `ByteArray` map to their dedicated setters (by
/// exact name). Any other name containing `BigDecimal` uses `setBigDecimal`, and
/// everything else uses `setObject`.
fn ps_setter(kotlin_type: &str) -> &str {
    let dedicated = match kotlin_type {
        "Boolean" => Some("setBoolean"),
        "Byte" => Some("setByte"),
        "Short" => Some("setShort"),
        "Int" => Some("setInt"),
        "Long" => Some("setLong"),
        "Float" => Some("setFloat"),
        "Double" => Some("setDouble"),
        "String" => Some("setString"),
        "ByteArray" => Some("setBytes"),
        _ => None,
    };

    match dedicated {
        Some(setter) => setter,
        None if kotlin_type.contains("BigDecimal") => "setBigDecimal",
        None => "setObject",
    }
}
109
110impl CodegenBackend for KotlinJdbcBackend {
111    fn name(&self) -> &str {
112        "kotlin-jdbc"
113    }
114
115    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
116        &self.manifest
117    }
118
119    fn supported_engines(&self) -> &[&str] {
120        &["postgresql", "mysql", "sqlite"]
121    }
122
123    fn file_header(&self) -> String {
124        "import java.sql.Connection\n".to_string()
125    }
126
127    fn generate_row_struct(
128        &self,
129        query_name: &str,
130        columns: &[ResolvedColumn],
131    ) -> Result<String, ScytheError> {
132        let struct_name = row_struct_name(query_name, &self.manifest.naming);
133        let mut out = String::new();
134        let _ = writeln!(out, "data class {}(", struct_name);
135        for col in columns.iter() {
136            let _ = writeln!(out, "    val {}: {},", col.field_name, col.full_type);
137        }
138        let _ = writeln!(out, ")");
139        Ok(out)
140    }
141
142    fn generate_model_struct(
143        &self,
144        table_name: &str,
145        columns: &[ResolvedColumn],
146    ) -> Result<String, ScytheError> {
147        let name = to_pascal_case(table_name);
148        self.generate_row_struct(&name, columns)
149    }
150
151    fn generate_query_fn(
152        &self,
153        analyzed: &AnalyzedQuery,
154        struct_name: &str,
155        columns: &[ResolvedColumn],
156        params: &[ResolvedParam],
157    ) -> Result<String, ScytheError> {
158        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
159        let sql = pg_to_jdbc_params(&super::clean_sql_oneline_with_optional(
160            &analyzed.sql,
161            &analyzed.optional_params,
162            &analyzed.params,
163        ));
164
165        // Build function params: inline for single param (conn only), multi-line for 2+
166        let use_multiline_params = !params.is_empty();
167
168        let mut out = String::new();
169
170        // Helper: write param setters
171        let write_setters = |out: &mut String, params: &[ResolvedParam]| {
172            for (i, param) in params.iter().enumerate() {
173                let setter = ps_setter(&param.lang_type);
174                let _ = writeln!(
175                    out,
176                    "        ps.{}({}, {})",
177                    setter,
178                    i + 1,
179                    param.field_name
180                );
181            }
182        };
183
184        // Helper: write function signature
185        let write_fn_sig =
186            |out: &mut String, name: &str, ret: &str, multiline: bool, params: &[ResolvedParam]| {
187                if multiline {
188                    let _ = writeln!(out, "fun {}(", name);
189                    let _ = writeln!(out, "    conn: Connection,");
190                    for p in params {
191                        let _ = writeln!(out, "    {}: {},", p.field_name, p.full_type);
192                    }
193                    let _ = writeln!(out, "){} {{", ret);
194                } else {
195                    let _ = writeln!(out, "fun {}(conn: Connection){} {{", name, ret);
196                }
197            };
198
199        match &analyzed.command {
200            QueryCommand::Exec => {
201                write_fn_sig(&mut out, &func_name, "", use_multiline_params, params);
202                let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
203                write_setters(&mut out, params);
204                let _ = writeln!(out, "        ps.executeUpdate()");
205                let _ = writeln!(out, "    }}");
206                let _ = writeln!(out, "}}");
207            }
208            QueryCommand::ExecResult | QueryCommand::ExecRows => {
209                write_fn_sig(&mut out, &func_name, ": Int", use_multiline_params, params);
210                let _ = writeln!(
211                    out,
212                    "    return conn.prepareStatement(\"{}\").use {{ ps ->",
213                    sql
214                );
215                write_setters(&mut out, params);
216                let _ = writeln!(out, "        ps.executeUpdate()");
217                let _ = writeln!(out, "    }}");
218                let _ = writeln!(out, "}}");
219            }
220            QueryCommand::One => {
221                let ret = format!(": {}?", struct_name);
222                write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
223                let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
224                write_setters(&mut out, params);
225                let _ = writeln!(out, "        ps.executeQuery().use {{ rs ->");
226                let _ = writeln!(out, "            return if (rs.next()) {{");
227                let _ = writeln!(out, "                {}(", struct_name);
228                for col in columns.iter() {
229                    let getter = rs_getter(&col.lang_type);
230                    let _ = writeln!(
231                        out,
232                        "                    {} = rs.{}(\"{}\"),",
233                        col.field_name, getter, col.name
234                    );
235                }
236                let _ = writeln!(out, "                )");
237                let _ = writeln!(out, "            }} else {{");
238                let _ = writeln!(out, "                null");
239                let _ = writeln!(out, "            }}");
240                let _ = writeln!(out, "        }}");
241                let _ = writeln!(out, "    }}");
242                let _ = writeln!(out, "}}");
243            }
244            QueryCommand::Batch => {
245                let batch_fn_name = format!("{}Batch", func_name);
246                if params.len() > 1 {
247                    let params_class_name =
248                        format!("{}BatchParams", to_pascal_case(&analyzed.name));
249                    let _ = writeln!(out, "data class {}(", params_class_name);
250                    for p in params {
251                        let _ = writeln!(out, "    val {}: {},", p.field_name, p.full_type);
252                    }
253                    let _ = writeln!(out, ")");
254                    let _ = writeln!(out);
255                    let _ = writeln!(out, "fun {}(", batch_fn_name);
256                    let _ = writeln!(out, "    conn: Connection,");
257                    let _ = writeln!(out, "    items: List<{}>,", params_class_name);
258                    let _ = writeln!(out, ") {{");
259                    let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
260                    let _ = writeln!(out, "        for (item in items) {{");
261                    for (i, param) in params.iter().enumerate() {
262                        let setter = ps_setter(&param.lang_type);
263                        let _ = writeln!(
264                            out,
265                            "            ps.{}({}, item.{})",
266                            setter,
267                            i + 1,
268                            param.field_name
269                        );
270                    }
271                    let _ = writeln!(out, "            ps.addBatch()");
272                    let _ = writeln!(out, "        }}");
273                    let _ = writeln!(out, "        ps.executeBatch()");
274                    let _ = writeln!(out, "    }}");
275                    let _ = writeln!(out, "}}");
276                } else if params.len() == 1 {
277                    let _ = writeln!(out, "fun {}(", batch_fn_name);
278                    let _ = writeln!(out, "    conn: Connection,");
279                    let _ = writeln!(out, "    items: List<{}>,", params[0].full_type);
280                    let _ = writeln!(out, ") {{");
281                    let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
282                    let _ = writeln!(out, "        for (item in items) {{");
283                    let setter = ps_setter(&params[0].lang_type);
284                    let _ = writeln!(out, "            ps.{}(1, item)", setter);
285                    let _ = writeln!(out, "            ps.addBatch()");
286                    let _ = writeln!(out, "        }}");
287                    let _ = writeln!(out, "        ps.executeBatch()");
288                    let _ = writeln!(out, "    }}");
289                    let _ = writeln!(out, "}}");
290                } else {
291                    let _ = writeln!(
292                        out,
293                        "fun {}(conn: Connection, count: Int) {{",
294                        batch_fn_name
295                    );
296                    let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
297                    let _ = writeln!(out, "        repeat(count) {{");
298                    let _ = writeln!(out, "            ps.addBatch()");
299                    let _ = writeln!(out, "        }}");
300                    let _ = writeln!(out, "        ps.executeBatch()");
301                    let _ = writeln!(out, "    }}");
302                    let _ = writeln!(out, "}}");
303                }
304            }
305            QueryCommand::Many => {
306                let ret = format!(": List<{}>", struct_name);
307                write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
308                let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
309                write_setters(&mut out, params);
310                let _ = writeln!(out, "        ps.executeQuery().use {{ rs ->");
311                let _ = writeln!(
312                    out,
313                    "            val result = mutableListOf<{}>()",
314                    struct_name
315                );
316                let _ = writeln!(out, "            while (rs.next()) {{");
317                let _ = writeln!(out, "                result.add(");
318                let _ = writeln!(out, "                    {}(", struct_name);
319                for col in columns.iter() {
320                    let getter = rs_getter(&col.lang_type);
321                    let _ = writeln!(
322                        out,
323                        "                        {} = rs.{}(\"{}\"),",
324                        col.field_name, getter, col.name
325                    );
326                }
327                let _ = writeln!(out, "                    ),");
328                let _ = writeln!(out, "                )");
329                let _ = writeln!(out, "            }}");
330                let _ = writeln!(out, "            return result");
331                let _ = writeln!(out, "        }}");
332                let _ = writeln!(out, "    }}");
333                let _ = writeln!(out, "}}");
334            }
335        }
336
337        Ok(out)
338    }
339
340    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
341        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
342        let mut out = String::new();
343        let _ = writeln!(out, "enum class {}(val value: String) {{", type_name);
344        for (i, value) in enum_info.values.iter().enumerate() {
345            let variant = enum_variant_name(value, &self.manifest.naming);
346            let sep = if i + 1 < enum_info.values.len() {
347                ","
348            } else {
349                ";"
350            };
351            let _ = writeln!(out, "    {}(\"{}\"){}", variant, value, sep);
352        }
353        let _ = writeln!(out, "}}");
354        Ok(out)
355    }
356
357    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
358        let name = to_pascal_case(&composite.sql_name);
359        let mut out = String::new();
360        let _ = writeln!(out, "data class {}(", name);
361        for field in composite.fields.iter() {
362            let field_name = to_camel_case(&field.name);
363            let field_type = resolve_type(&field.neutral_type, &self.manifest, false)
364                .map(|t| t.into_owned())
365                .unwrap_or_else(|_| "Any".to_string());
366            let _ = writeln!(out, "    val {}: {},", field_name, field_type);
367        }
368        let _ = writeln!(out, ")");
369        Ok(out)
370    }
371}