Skip to main content

scythe_codegen/backends/
kotlin_jdbc.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/kotlin-jdbc.toml");
17const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/kotlin-jdbc.mysql.toml");
18const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/kotlin-jdbc.sqlite.toml");
19
20pub struct KotlinJdbcBackend {
21    manifest: BackendManifest,
22}
23
24impl KotlinJdbcBackend {
25    pub fn new(engine: &str) -> Result<Self, ScytheError> {
26        let default_toml = match engine {
27            "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
28            "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
29            "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
30            _ => {
31                return Err(ScytheError::new(
32                    ErrorCode::InternalError,
33                    format!("unsupported engine '{}' for kotlin-jdbc backend", engine),
34                ));
35            }
36        };
37        let manifest_path = Path::new("backends/kotlin-jdbc/manifest.toml");
38        let manifest = if manifest_path.exists() {
39            load_manifest(manifest_path)
40                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41        } else {
42            toml::from_str(default_toml)
43                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44        };
45        Ok(Self { manifest })
46    }
47}
48
/// Rewrite PostgreSQL positional placeholders (`$1`, `$2`, ...) into JDBC `?` markers.
///
/// Every `$` immediately followed by one or more ASCII digits is replaced, together
/// with those digits, by a single `?`. A `$` that is not followed by a digit is kept
/// verbatim. The scan is purely lexical: quoted literals and comments are not skipped.
fn pg_to_jdbc_params(sql: &str) -> String {
    let mut segments = sql.split('$');
    // `split` always yields at least one segment: the text before the first `$`.
    let mut converted = String::with_capacity(sql.len());
    converted.push_str(segments.next().unwrap_or(""));

    for segment in segments {
        // ASCII digits are single bytes, so this count is also a valid char boundary.
        let digit_len = segment.bytes().take_while(u8::is_ascii_digit).count();
        if digit_len == 0 {
            converted.push('$');
            converted.push_str(segment);
        } else {
            converted.push('?');
            converted.push_str(&segment[digit_len..]);
        }
    }
    converted
}
69
/// Get the `ResultSet` getter method name for a given Kotlin type.
///
/// Kotlin primitives, `String` and `ByteArray` map to their dedicated JDBC accessors.
/// Any type whose name contains `BigDecimal` maps to `getBigDecimal`. Everything else
/// (java.time types, `UUID`, enums, composites, ...) is read with `getObject`.
///
/// The result is always a string literal, so it is `'static` and does not borrow
/// from `kotlin_type`.
fn rs_getter(kotlin_type: &str) -> &'static str {
    match kotlin_type {
        "Boolean" => "getBoolean",
        "Byte" => "getByte",
        "Short" => "getShort",
        "Int" => "getInt",
        "Long" => "getLong",
        "Float" => "getFloat",
        "Double" => "getDouble",
        "String" => "getString",
        "ByteArray" => "getBytes",
        other if other.contains("BigDecimal") => "getBigDecimal",
        // LocalDate, LocalTime, OffsetTime, LocalDateTime, OffsetDateTime, UUID and
        // every other type share the generic accessor.
        _ => "getObject",
    }
}
92
/// Get the `PreparedStatement` setter method name for a given Kotlin type.
///
/// Kotlin primitives, `String` and `ByteArray` map to their dedicated JDBC setters.
/// Any type whose name contains `BigDecimal` maps to `setBigDecimal`. Everything else
/// is bound with `setObject`.
///
/// The result is always a string literal, so it is `'static` and does not borrow
/// from `kotlin_type`.
fn ps_setter(kotlin_type: &str) -> &'static str {
    match kotlin_type {
        "Boolean" => "setBoolean",
        "Byte" => "setByte",
        "Short" => "setShort",
        "Int" => "setInt",
        "Long" => "setLong",
        "Float" => "setFloat",
        "Double" => "setDouble",
        "String" => "setString",
        "ByteArray" => "setBytes",
        other if other.contains("BigDecimal") => "setBigDecimal",
        _ => "setObject",
    }
}
109
110impl CodegenBackend for KotlinJdbcBackend {
111    fn name(&self) -> &str {
112        "kotlin-jdbc"
113    }
114
115    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
116        &self.manifest
117    }
118
119    fn supported_engines(&self) -> &[&str] {
120        &["postgresql", "mysql", "sqlite"]
121    }
122
123    fn file_header(&self) -> String {
124        "import java.sql.Connection\n".to_string()
125    }
126
127    fn generate_row_struct(
128        &self,
129        query_name: &str,
130        columns: &[ResolvedColumn],
131    ) -> Result<String, ScytheError> {
132        let struct_name = row_struct_name(query_name, &self.manifest.naming);
133        let mut out = String::new();
134        let _ = writeln!(out, "data class {}(", struct_name);
135        for col in columns.iter() {
136            let _ = writeln!(out, "    val {}: {},", col.field_name, col.full_type);
137        }
138        let _ = writeln!(out, ")");
139        Ok(out)
140    }
141
142    fn generate_model_struct(
143        &self,
144        table_name: &str,
145        columns: &[ResolvedColumn],
146    ) -> Result<String, ScytheError> {
147        let name = to_pascal_case(table_name);
148        self.generate_row_struct(&name, columns)
149    }
150
151    fn generate_query_fn(
152        &self,
153        analyzed: &AnalyzedQuery,
154        struct_name: &str,
155        columns: &[ResolvedColumn],
156        params: &[ResolvedParam],
157    ) -> Result<String, ScytheError> {
158        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
159        let sql = pg_to_jdbc_params(&super::clean_sql_oneline(&analyzed.sql));
160
161        // Build function params: inline for single param (conn only), multi-line for 2+
162        let use_multiline_params = !params.is_empty();
163
164        let mut out = String::new();
165
166        // Helper: write param setters
167        let write_setters = |out: &mut String, params: &[ResolvedParam]| {
168            for (i, param) in params.iter().enumerate() {
169                let setter = ps_setter(&param.lang_type);
170                let _ = writeln!(
171                    out,
172                    "        ps.{}({}, {})",
173                    setter,
174                    i + 1,
175                    param.field_name
176                );
177            }
178        };
179
180        // Helper: write function signature
181        let write_fn_sig =
182            |out: &mut String, name: &str, ret: &str, multiline: bool, params: &[ResolvedParam]| {
183                if multiline {
184                    let _ = writeln!(out, "fun {}(", name);
185                    let _ = writeln!(out, "    conn: Connection,");
186                    for p in params {
187                        let _ = writeln!(out, "    {}: {},", p.field_name, p.full_type);
188                    }
189                    let _ = writeln!(out, "){} {{", ret);
190                } else {
191                    let _ = writeln!(out, "fun {}(conn: Connection){} {{", name, ret);
192                }
193            };
194
195        match &analyzed.command {
196            QueryCommand::Exec => {
197                write_fn_sig(&mut out, &func_name, "", use_multiline_params, params);
198                let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
199                write_setters(&mut out, params);
200                let _ = writeln!(out, "        ps.executeUpdate()");
201                let _ = writeln!(out, "    }}");
202                let _ = writeln!(out, "}}");
203            }
204            QueryCommand::ExecResult | QueryCommand::ExecRows => {
205                write_fn_sig(&mut out, &func_name, ": Int", use_multiline_params, params);
206                let _ = writeln!(
207                    out,
208                    "    return conn.prepareStatement(\"{}\").use {{ ps ->",
209                    sql
210                );
211                write_setters(&mut out, params);
212                let _ = writeln!(out, "        ps.executeUpdate()");
213                let _ = writeln!(out, "    }}");
214                let _ = writeln!(out, "}}");
215            }
216            QueryCommand::One => {
217                let ret = format!(": {}?", struct_name);
218                write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
219                let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
220                write_setters(&mut out, params);
221                let _ = writeln!(out, "        ps.executeQuery().use {{ rs ->");
222                let _ = writeln!(out, "            return if (rs.next()) {{");
223                let _ = writeln!(out, "                {}(", struct_name);
224                for col in columns.iter() {
225                    let getter = rs_getter(&col.lang_type);
226                    let _ = writeln!(
227                        out,
228                        "                    {} = rs.{}(\"{}\"),",
229                        col.field_name, getter, col.name
230                    );
231                }
232                let _ = writeln!(out, "                )");
233                let _ = writeln!(out, "            }} else {{");
234                let _ = writeln!(out, "                null");
235                let _ = writeln!(out, "            }}");
236                let _ = writeln!(out, "        }}");
237                let _ = writeln!(out, "    }}");
238                let _ = writeln!(out, "}}");
239            }
240            QueryCommand::Many | QueryCommand::Batch => {
241                let ret = format!(": List<{}>", struct_name);
242                write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
243                let _ = writeln!(out, "    conn.prepareStatement(\"{}\").use {{ ps ->", sql);
244                write_setters(&mut out, params);
245                let _ = writeln!(out, "        ps.executeQuery().use {{ rs ->");
246                let _ = writeln!(
247                    out,
248                    "            val result = mutableListOf<{}>()",
249                    struct_name
250                );
251                let _ = writeln!(out, "            while (rs.next()) {{");
252                let _ = writeln!(out, "                result.add(");
253                let _ = writeln!(out, "                    {}(", struct_name);
254                for col in columns.iter() {
255                    let getter = rs_getter(&col.lang_type);
256                    let _ = writeln!(
257                        out,
258                        "                        {} = rs.{}(\"{}\"),",
259                        col.field_name, getter, col.name
260                    );
261                }
262                let _ = writeln!(out, "                    ),");
263                let _ = writeln!(out, "                )");
264                let _ = writeln!(out, "            }}");
265                let _ = writeln!(out, "            return result");
266                let _ = writeln!(out, "        }}");
267                let _ = writeln!(out, "    }}");
268                let _ = writeln!(out, "}}");
269            }
270        }
271
272        Ok(out)
273    }
274
275    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
276        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
277        let mut out = String::new();
278        let _ = writeln!(out, "enum class {}(val value: String) {{", type_name);
279        for (i, value) in enum_info.values.iter().enumerate() {
280            let variant = enum_variant_name(value, &self.manifest.naming);
281            let sep = if i + 1 < enum_info.values.len() {
282                ","
283            } else {
284                ";"
285            };
286            let _ = writeln!(out, "    {}(\"{}\"){}", variant, value, sep);
287        }
288        let _ = writeln!(out, "}}");
289        Ok(out)
290    }
291
292    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
293        let name = to_pascal_case(&composite.sql_name);
294        let mut out = String::new();
295        let _ = writeln!(out, "data class {}(", name);
296        for field in composite.fields.iter() {
297            let field_name = to_camel_case(&field.name);
298            let field_type = resolve_type(&field.neutral_type, &self.manifest, false)
299                .map(|t| t.into_owned())
300                .unwrap_or_else(|_| "Any".to_string());
301            let _ = writeln!(out, "    val {}: {},", field_name, field_type);
302        }
303        let _ = writeln!(out, ")");
304        Ok(out)
305    }
306}