1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
/// Bundled manifest (embedded at compile time) used by `KotlinJdbcBackend::new` for the
/// `postgresql` / `postgres` / `pg` engines.
const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/kotlin-jdbc.toml");
/// Bundled manifest used for the `mysql` / `mariadb` engines.
const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/kotlin-jdbc.mysql.toml");
/// Bundled manifest used for the `sqlite` / `sqlite3` engines.
const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/kotlin-jdbc.sqlite.toml");
/// Bundled manifest used for the `duckdb` engine.
const DEFAULT_MANIFEST_DUCKDB: &str = include_str!("../../manifests/kotlin-jdbc.duckdb.toml");
20
/// Code generation backend that emits Kotlin source talking to the database through plain
/// JDBC (`java.sql.Connection`, prepared statements and result sets).
pub struct KotlinJdbcBackend {
    /// Manifest supplying the type mappings and naming rules used during generation.
    manifest: BackendManifest,
}
24
25impl KotlinJdbcBackend {
26 pub fn new(engine: &str) -> Result<Self, ScytheError> {
27 let default_toml = match engine {
28 "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
29 "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
30 "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
31 "duckdb" => DEFAULT_MANIFEST_DUCKDB,
32 _ => {
33 return Err(ScytheError::new(
34 ErrorCode::InternalError,
35 format!("unsupported engine '{}' for kotlin-jdbc backend", engine),
36 ));
37 }
38 };
39 let manifest_path = Path::new("backends/kotlin-jdbc/manifest.toml");
40 let manifest = if manifest_path.exists() {
41 load_manifest(manifest_path)
42 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43 } else {
44 toml::from_str(default_toml)
45 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
46 };
47 Ok(Self { manifest })
48 }
49}
50
/// Rewrites PostgreSQL positional placeholders (`$1`, `$23`, ...) into JDBC `?` markers.
///
/// Every `$` immediately followed by one or more ASCII digits is replaced, digits included,
/// by a single `?`. Any other `$` (for example in `$$`) is copied through unchanged.
fn pg_to_jdbc_params(sql: &str) -> String {
    let mut rewritten = String::with_capacity(sql.len());
    let mut rest = sql.chars().peekable();
    while let Some(c) = rest.next() {
        let starts_placeholder = c == '$' && rest.peek().is_some_and(char::is_ascii_digit);
        if !starts_placeholder {
            rewritten.push(c);
            continue;
        }
        // Swallow the whole placeholder number.
        while rest.next_if(char::is_ascii_digit).is_some() {}
        rewritten.push('?');
    }
    rewritten
}
71
/// Chooses the `java.sql.ResultSet` getter used to read a column of Kotlin type `kotlin_type`.
///
/// The exact Kotlin primitive, `String` and `ByteArray` names have dedicated getters. Any type
/// whose name mentions `BigDecimal` uses `getBigDecimal`. Everything else (java.time types,
/// UUIDs, nullable spellings, unknown types) falls back to `getObject`.
fn rs_getter(kotlin_type: &str) -> &str {
    const DEDICATED: [(&str, &str); 9] = [
        ("Boolean", "getBoolean"),
        ("Byte", "getByte"),
        ("Short", "getShort"),
        ("Int", "getInt"),
        ("Long", "getLong"),
        ("Float", "getFloat"),
        ("Double", "getDouble"),
        ("String", "getString"),
        ("ByteArray", "getBytes"),
    ];

    let dedicated = DEDICATED
        .iter()
        .find(|(name, _)| *name == kotlin_type)
        .map(|&(_, getter)| getter);

    match dedicated {
        Some(getter) => getter,
        None if kotlin_type.contains("BigDecimal") => "getBigDecimal",
        None => "getObject",
    }
}
94
/// Chooses the `java.sql.PreparedStatement` setter used to bind a value of Kotlin type
/// `kotlin_type`.
///
/// Exact primitive, `String` and `ByteArray` names get their typed setter. Names mentioning
/// `BigDecimal` use `setBigDecimal`. Anything else is bound with `setObject`.
fn ps_setter(kotlin_type: &str) -> &str {
    let dedicated = match kotlin_type {
        "Boolean" => Some("setBoolean"),
        "Byte" => Some("setByte"),
        "Short" => Some("setShort"),
        "Int" => Some("setInt"),
        "Long" => Some("setLong"),
        "Float" => Some("setFloat"),
        "Double" => Some("setDouble"),
        "String" => Some("setString"),
        "ByteArray" => Some("setBytes"),
        _ => None,
    };
    if let Some(setter) = dedicated {
        return setter;
    }
    if kotlin_type.contains("BigDecimal") {
        "setBigDecimal"
    } else {
        "setObject"
    }
}
110}
111
112impl CodegenBackend for KotlinJdbcBackend {
113 fn name(&self) -> &str {
114 "kotlin-jdbc"
115 }
116
117 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
118 &self.manifest
119 }
120
121 fn supported_engines(&self) -> &[&str] {
122 &["postgresql", "mysql", "sqlite", "duckdb"]
123 }
124
125 fn file_header(&self) -> String {
126 "import java.sql.Connection\n".to_string()
127 }
128
129 fn generate_row_struct(
130 &self,
131 query_name: &str,
132 columns: &[ResolvedColumn],
133 ) -> Result<String, ScytheError> {
134 let struct_name = row_struct_name(query_name, &self.manifest.naming);
135 let mut out = String::new();
136 let _ = writeln!(out, "data class {}(", struct_name);
137 for col in columns.iter() {
138 let _ = writeln!(out, " val {}: {},", col.field_name, col.full_type);
139 }
140 let _ = writeln!(out, ")");
141 Ok(out)
142 }
143
144 fn generate_model_struct(
145 &self,
146 table_name: &str,
147 columns: &[ResolvedColumn],
148 ) -> Result<String, ScytheError> {
149 let name = to_pascal_case(table_name);
150 self.generate_row_struct(&name, columns)
151 }
152
153 fn generate_query_fn(
154 &self,
155 analyzed: &AnalyzedQuery,
156 struct_name: &str,
157 columns: &[ResolvedColumn],
158 params: &[ResolvedParam],
159 ) -> Result<String, ScytheError> {
160 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
161 let sql = pg_to_jdbc_params(&super::clean_sql_oneline_with_optional(
162 &analyzed.sql,
163 &analyzed.optional_params,
164 &analyzed.params,
165 ));
166
167 let use_multiline_params = !params.is_empty();
169
170 let mut out = String::new();
171
172 let write_setters = |out: &mut String, params: &[ResolvedParam]| {
174 for (i, param) in params.iter().enumerate() {
175 let setter = ps_setter(¶m.lang_type);
176 let _ = writeln!(
177 out,
178 " ps.{}({}, {})",
179 setter,
180 i + 1,
181 param.field_name
182 );
183 }
184 };
185
186 let write_fn_sig =
188 |out: &mut String, name: &str, ret: &str, multiline: bool, params: &[ResolvedParam]| {
189 if multiline {
190 let _ = writeln!(out, "fun {}(", name);
191 let _ = writeln!(out, " conn: Connection,");
192 for p in params {
193 let _ = writeln!(out, " {}: {},", p.field_name, p.full_type);
194 }
195 let _ = writeln!(out, "){} {{", ret);
196 } else {
197 let _ = writeln!(out, "fun {}(conn: Connection){} {{", name, ret);
198 }
199 };
200
201 match &analyzed.command {
202 QueryCommand::Exec => {
203 write_fn_sig(&mut out, &func_name, "", use_multiline_params, params);
204 let _ = writeln!(out, " conn.prepareStatement(\"{}\").use {{ ps ->", sql);
205 write_setters(&mut out, params);
206 let _ = writeln!(out, " ps.executeUpdate()");
207 let _ = writeln!(out, " }}");
208 let _ = writeln!(out, "}}");
209 }
210 QueryCommand::ExecResult | QueryCommand::ExecRows => {
211 write_fn_sig(&mut out, &func_name, ": Int", use_multiline_params, params);
212 let _ = writeln!(
213 out,
214 " return conn.prepareStatement(\"{}\").use {{ ps ->",
215 sql
216 );
217 write_setters(&mut out, params);
218 let _ = writeln!(out, " ps.executeUpdate()");
219 let _ = writeln!(out, " }}");
220 let _ = writeln!(out, "}}");
221 }
222 QueryCommand::One => {
223 let ret = format!(": {}?", struct_name);
224 write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
225 let _ = writeln!(out, " conn.prepareStatement(\"{}\").use {{ ps ->", sql);
226 write_setters(&mut out, params);
227 let _ = writeln!(out, " ps.executeQuery().use {{ rs ->");
228 let _ = writeln!(out, " return if (rs.next()) {{");
229 let _ = writeln!(out, " {}(", struct_name);
230 for col in columns.iter() {
231 let getter = rs_getter(&col.lang_type);
232 let _ = writeln!(
233 out,
234 " {} = rs.{}(\"{}\"),",
235 col.field_name, getter, col.name
236 );
237 }
238 let _ = writeln!(out, " )");
239 let _ = writeln!(out, " }} else {{");
240 let _ = writeln!(out, " null");
241 let _ = writeln!(out, " }}");
242 let _ = writeln!(out, " }}");
243 let _ = writeln!(out, " }}");
244 let _ = writeln!(out, "}}");
245 }
246 QueryCommand::Batch => {
247 let batch_fn_name = format!("{}Batch", func_name);
248 if params.len() > 1 {
249 let params_class_name =
250 format!("{}BatchParams", to_pascal_case(&analyzed.name));
251 let _ = writeln!(out, "data class {}(", params_class_name);
252 for p in params {
253 let _ = writeln!(out, " val {}: {},", p.field_name, p.full_type);
254 }
255 let _ = writeln!(out, ")");
256 let _ = writeln!(out);
257 let _ = writeln!(out, "fun {}(", batch_fn_name);
258 let _ = writeln!(out, " conn: Connection,");
259 let _ = writeln!(out, " items: List<{}>,", params_class_name);
260 let _ = writeln!(out, ") {{");
261 let _ = writeln!(out, " conn.prepareStatement(\"{}\").use {{ ps ->", sql);
262 let _ = writeln!(out, " for (item in items) {{");
263 for (i, param) in params.iter().enumerate() {
264 let setter = ps_setter(¶m.lang_type);
265 let _ = writeln!(
266 out,
267 " ps.{}({}, item.{})",
268 setter,
269 i + 1,
270 param.field_name
271 );
272 }
273 let _ = writeln!(out, " ps.addBatch()");
274 let _ = writeln!(out, " }}");
275 let _ = writeln!(out, " ps.executeBatch()");
276 let _ = writeln!(out, " }}");
277 let _ = writeln!(out, "}}");
278 } else if params.len() == 1 {
279 let _ = writeln!(out, "fun {}(", batch_fn_name);
280 let _ = writeln!(out, " conn: Connection,");
281 let _ = writeln!(out, " items: List<{}>,", params[0].full_type);
282 let _ = writeln!(out, ") {{");
283 let _ = writeln!(out, " conn.prepareStatement(\"{}\").use {{ ps ->", sql);
284 let _ = writeln!(out, " for (item in items) {{");
285 let setter = ps_setter(¶ms[0].lang_type);
286 let _ = writeln!(out, " ps.{}(1, item)", setter);
287 let _ = writeln!(out, " ps.addBatch()");
288 let _ = writeln!(out, " }}");
289 let _ = writeln!(out, " ps.executeBatch()");
290 let _ = writeln!(out, " }}");
291 let _ = writeln!(out, "}}");
292 } else {
293 let _ = writeln!(
294 out,
295 "fun {}(conn: Connection, count: Int) {{",
296 batch_fn_name
297 );
298 let _ = writeln!(out, " conn.prepareStatement(\"{}\").use {{ ps ->", sql);
299 let _ = writeln!(out, " repeat(count) {{");
300 let _ = writeln!(out, " ps.addBatch()");
301 let _ = writeln!(out, " }}");
302 let _ = writeln!(out, " ps.executeBatch()");
303 let _ = writeln!(out, " }}");
304 let _ = writeln!(out, "}}");
305 }
306 }
307 QueryCommand::Many => {
308 let ret = format!(": List<{}>", struct_name);
309 write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
310 let _ = writeln!(out, " conn.prepareStatement(\"{}\").use {{ ps ->", sql);
311 write_setters(&mut out, params);
312 let _ = writeln!(out, " ps.executeQuery().use {{ rs ->");
313 let _ = writeln!(
314 out,
315 " val result = mutableListOf<{}>()",
316 struct_name
317 );
318 let _ = writeln!(out, " while (rs.next()) {{");
319 let _ = writeln!(out, " result.add(");
320 let _ = writeln!(out, " {}(", struct_name);
321 for col in columns.iter() {
322 let getter = rs_getter(&col.lang_type);
323 let _ = writeln!(
324 out,
325 " {} = rs.{}(\"{}\"),",
326 col.field_name, getter, col.name
327 );
328 }
329 let _ = writeln!(out, " ),");
330 let _ = writeln!(out, " )");
331 let _ = writeln!(out, " }}");
332 let _ = writeln!(out, " return result");
333 let _ = writeln!(out, " }}");
334 let _ = writeln!(out, " }}");
335 let _ = writeln!(out, "}}");
336 }
337 QueryCommand::Grouped => {
338 return Err(ScytheError::new(
339 ErrorCode::InternalError,
340 "grouped queries are not yet supported for kotlin-jdbc".to_string(),
341 ));
342 }
343 }
344
345 Ok(out)
346 }
347
348 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
349 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
350 let mut out = String::new();
351 let _ = writeln!(out, "enum class {}(val value: String) {{", type_name);
352 for (i, value) in enum_info.values.iter().enumerate() {
353 let variant = enum_variant_name(value, &self.manifest.naming);
354 let sep = if i + 1 < enum_info.values.len() {
355 ","
356 } else {
357 ";"
358 };
359 let _ = writeln!(out, " {}(\"{}\"){}", variant, value, sep);
360 }
361 let _ = writeln!(out, "}}");
362 Ok(out)
363 }
364
365 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
366 let name = to_pascal_case(&composite.sql_name);
367 let mut out = String::new();
368 let _ = writeln!(out, "data class {}(", name);
369 for field in composite.fields.iter() {
370 let field_name = to_camel_case(&field.name);
371 let field_type = resolve_type(&field.neutral_type, &self.manifest, false)
372 .map(|t| t.into_owned())
373 .unwrap_or_else(|_| "Any".to_string());
374 let _ = writeln!(out, " val {}: {},", field_name, field_type);
375 }
376 let _ = writeln!(out, ")");
377 Ok(out)
378 }
379}