1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/kotlin-jdbc.toml");
17const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/kotlin-jdbc.mysql.toml");
18const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/kotlin-jdbc.sqlite.toml");
19
20pub struct KotlinJdbcBackend {
21 manifest: BackendManifest,
22}
23
24impl KotlinJdbcBackend {
25 pub fn new(engine: &str) -> Result<Self, ScytheError> {
26 let default_toml = match engine {
27 "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
28 "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
29 "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
30 _ => {
31 return Err(ScytheError::new(
32 ErrorCode::InternalError,
33 format!("unsupported engine '{}' for kotlin-jdbc backend", engine),
34 ));
35 }
36 };
37 let manifest_path = Path::new("backends/kotlin-jdbc/manifest.toml");
38 let manifest = if manifest_path.exists() {
39 load_manifest(manifest_path)
40 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41 } else {
42 toml::from_str(default_toml)
43 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44 };
45 Ok(Self { manifest })
46 }
47}
48
/// Rewrites PostgreSQL-style positional placeholders (`$1`, `$2`, ...) into JDBC `?`
/// placeholders.
///
/// Text inside single-quoted string literals (`'...'`) and double-quoted identifiers
/// (`"..."`) is copied verbatim, so a literal such as `'costs $5'` does not gain a bind
/// parameter. Doubled quotes (`''`, `""`) inside those spans need no special handling: they
/// close and immediately reopen the quoted span.
///
/// A `$N` glued to a preceding identifier character (letter, digit, `_` or `$`, e.g.
/// `col$1`) is part of that identifier and is left alone. A `$` not followed by an ASCII
/// digit is copied unchanged.
///
/// Every occurrence becomes its own `?`; the `N` numbers themselves are discarded.
///
/// Limitations: backslash escapes in `E'...'` strings, dollar-quoted strings
/// (`$tag$...$tag$`) and SQL comments are not recognised.
fn pg_to_jdbc_params(sql: &str) -> String {
    // True when the text emitted so far ends in a character that can continue an identifier.
    let ends_in_identifier = |emitted: &str| {
        emitted
            .chars()
            .next_back()
            .is_some_and(|p| p.is_alphanumeric() || p == '_' || p == '$')
    };

    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    // Quote character that will close the literal/identifier currently being copied.
    let mut open_quote: Option<char> = None;

    while let Some(ch) = chars.next() {
        if let Some(quote) = open_quote {
            if ch == quote {
                open_quote = None;
            }
            result.push(ch);
            continue;
        }
        match ch {
            '\'' | '"' => {
                open_quote = Some(ch);
                result.push(ch);
            }
            '$' if chars.peek().is_some_and(|c| c.is_ascii_digit())
                && !ends_in_identifier(&result) =>
            {
                // Swallow the whole digit run of `$N` and emit a single JDBC placeholder.
                while chars.next_if(|c| c.is_ascii_digit()).is_some() {}
                result.push('?');
            }
            _ => result.push(ch),
        }
    }
    result
}
69
/// Returns the `java.sql.ResultSet` getter used to read a column of the given Kotlin type.
///
/// Behaviour:
/// - Primitive Kotlin types, `String` and `ByteArray` map to their dedicated typed getters.
/// - Any type name containing `BigDecimal` uses `getBigDecimal`.
/// - Everything else (date/time types, `UUID`, enums, ...) falls back to `getObject`.
///
/// The result is always a string literal, so it is `'static` and does not borrow
/// `kotlin_type`.
///
/// NOTE(review): JDBC typed getters such as `getInt` return `0`/`false` for SQL `NULL`. If
/// nullable columns reach this function as their non-null base type, the generated code
/// cannot observe `NULL`. Confirm how callers handle nullability.
fn rs_getter(kotlin_type: &str) -> &'static str {
    match kotlin_type {
        "Boolean" => "getBoolean",
        "Byte" => "getByte",
        "Short" => "getShort",
        "Int" => "getInt",
        "Long" => "getLong",
        "Float" => "getFloat",
        "Double" => "getDouble",
        "String" => "getString",
        "ByteArray" => "getBytes",
        other if other.contains("BigDecimal") => "getBigDecimal",
        // Date/time types, UUID and anything else without a dedicated typed getter.
        _ => "getObject",
    }
}
92
/// Returns the `java.sql.PreparedStatement` setter used to bind a parameter of the given
/// Kotlin type.
///
/// Behaviour:
/// - Primitive Kotlin types, `String` and `ByteArray` map to dedicated typed setters.
/// - Any type name containing `BigDecimal` uses `setBigDecimal`.
/// - Everything else uses `setObject`.
///
/// The result is always a string literal, so it is `'static` and does not borrow
/// `kotlin_type`.
fn ps_setter(kotlin_type: &str) -> &'static str {
    match kotlin_type {
        "Boolean" => "setBoolean",
        "Byte" => "setByte",
        "Short" => "setShort",
        "Int" => "setInt",
        "Long" => "setLong",
        "Float" => "setFloat",
        "Double" => "setDouble",
        "String" => "setString",
        "ByteArray" => "setBytes",
        other if other.contains("BigDecimal") => "setBigDecimal",
        _ => "setObject",
    }
}
108}
109
110impl CodegenBackend for KotlinJdbcBackend {
111 fn name(&self) -> &str {
112 "kotlin-jdbc"
113 }
114
115 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
116 &self.manifest
117 }
118
119 fn supported_engines(&self) -> &[&str] {
120 &["postgresql", "mysql", "sqlite"]
121 }
122
123 fn file_header(&self) -> String {
124 "import java.sql.Connection\n".to_string()
125 }
126
127 fn generate_row_struct(
128 &self,
129 query_name: &str,
130 columns: &[ResolvedColumn],
131 ) -> Result<String, ScytheError> {
132 let struct_name = row_struct_name(query_name, &self.manifest.naming);
133 let mut out = String::new();
134 let _ = writeln!(out, "data class {}(", struct_name);
135 for col in columns.iter() {
136 let _ = writeln!(out, " val {}: {},", col.field_name, col.full_type);
137 }
138 let _ = writeln!(out, ")");
139 Ok(out)
140 }
141
142 fn generate_model_struct(
143 &self,
144 table_name: &str,
145 columns: &[ResolvedColumn],
146 ) -> Result<String, ScytheError> {
147 let name = to_pascal_case(table_name);
148 self.generate_row_struct(&name, columns)
149 }
150
151 fn generate_query_fn(
152 &self,
153 analyzed: &AnalyzedQuery,
154 struct_name: &str,
155 columns: &[ResolvedColumn],
156 params: &[ResolvedParam],
157 ) -> Result<String, ScytheError> {
158 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
159 let sql = pg_to_jdbc_params(&super::clean_sql_oneline_with_optional(
160 &analyzed.sql,
161 &analyzed.optional_params,
162 &analyzed.params,
163 ));
164
165 let use_multiline_params = !params.is_empty();
167
168 let mut out = String::new();
169
170 let write_setters = |out: &mut String, params: &[ResolvedParam]| {
172 for (i, param) in params.iter().enumerate() {
173 let setter = ps_setter(¶m.lang_type);
174 let _ = writeln!(
175 out,
176 " ps.{}({}, {})",
177 setter,
178 i + 1,
179 param.field_name
180 );
181 }
182 };
183
184 let write_fn_sig =
186 |out: &mut String, name: &str, ret: &str, multiline: bool, params: &[ResolvedParam]| {
187 if multiline {
188 let _ = writeln!(out, "fun {}(", name);
189 let _ = writeln!(out, " conn: Connection,");
190 for p in params {
191 let _ = writeln!(out, " {}: {},", p.field_name, p.full_type);
192 }
193 let _ = writeln!(out, "){} {{", ret);
194 } else {
195 let _ = writeln!(out, "fun {}(conn: Connection){} {{", name, ret);
196 }
197 };
198
199 match &analyzed.command {
200 QueryCommand::Exec => {
201 write_fn_sig(&mut out, &func_name, "", use_multiline_params, params);
202 let _ = writeln!(out, " conn.prepareStatement(\"{}\").use {{ ps ->", sql);
203 write_setters(&mut out, params);
204 let _ = writeln!(out, " ps.executeUpdate()");
205 let _ = writeln!(out, " }}");
206 let _ = writeln!(out, "}}");
207 }
208 QueryCommand::ExecResult | QueryCommand::ExecRows => {
209 write_fn_sig(&mut out, &func_name, ": Int", use_multiline_params, params);
210 let _ = writeln!(
211 out,
212 " return conn.prepareStatement(\"{}\").use {{ ps ->",
213 sql
214 );
215 write_setters(&mut out, params);
216 let _ = writeln!(out, " ps.executeUpdate()");
217 let _ = writeln!(out, " }}");
218 let _ = writeln!(out, "}}");
219 }
220 QueryCommand::One => {
221 let ret = format!(": {}?", struct_name);
222 write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
223 let _ = writeln!(out, " conn.prepareStatement(\"{}\").use {{ ps ->", sql);
224 write_setters(&mut out, params);
225 let _ = writeln!(out, " ps.executeQuery().use {{ rs ->");
226 let _ = writeln!(out, " return if (rs.next()) {{");
227 let _ = writeln!(out, " {}(", struct_name);
228 for col in columns.iter() {
229 let getter = rs_getter(&col.lang_type);
230 let _ = writeln!(
231 out,
232 " {} = rs.{}(\"{}\"),",
233 col.field_name, getter, col.name
234 );
235 }
236 let _ = writeln!(out, " )");
237 let _ = writeln!(out, " }} else {{");
238 let _ = writeln!(out, " null");
239 let _ = writeln!(out, " }}");
240 let _ = writeln!(out, " }}");
241 let _ = writeln!(out, " }}");
242 let _ = writeln!(out, "}}");
243 }
244 QueryCommand::Batch => {
245 let batch_fn_name = format!("{}Batch", func_name);
246 if params.len() > 1 {
247 let params_class_name =
248 format!("{}BatchParams", to_pascal_case(&analyzed.name));
249 let _ = writeln!(out, "data class {}(", params_class_name);
250 for p in params {
251 let _ = writeln!(out, " val {}: {},", p.field_name, p.full_type);
252 }
253 let _ = writeln!(out, ")");
254 let _ = writeln!(out);
255 let _ = writeln!(out, "fun {}(", batch_fn_name);
256 let _ = writeln!(out, " conn: Connection,");
257 let _ = writeln!(out, " items: List<{}>,", params_class_name);
258 let _ = writeln!(out, ") {{");
259 let _ = writeln!(out, " conn.prepareStatement(\"{}\").use {{ ps ->", sql);
260 let _ = writeln!(out, " for (item in items) {{");
261 for (i, param) in params.iter().enumerate() {
262 let setter = ps_setter(¶m.lang_type);
263 let _ = writeln!(
264 out,
265 " ps.{}({}, item.{})",
266 setter,
267 i + 1,
268 param.field_name
269 );
270 }
271 let _ = writeln!(out, " ps.addBatch()");
272 let _ = writeln!(out, " }}");
273 let _ = writeln!(out, " ps.executeBatch()");
274 let _ = writeln!(out, " }}");
275 let _ = writeln!(out, "}}");
276 } else if params.len() == 1 {
277 let _ = writeln!(out, "fun {}(", batch_fn_name);
278 let _ = writeln!(out, " conn: Connection,");
279 let _ = writeln!(out, " items: List<{}>,", params[0].full_type);
280 let _ = writeln!(out, ") {{");
281 let _ = writeln!(out, " conn.prepareStatement(\"{}\").use {{ ps ->", sql);
282 let _ = writeln!(out, " for (item in items) {{");
283 let setter = ps_setter(¶ms[0].lang_type);
284 let _ = writeln!(out, " ps.{}(1, item)", setter);
285 let _ = writeln!(out, " ps.addBatch()");
286 let _ = writeln!(out, " }}");
287 let _ = writeln!(out, " ps.executeBatch()");
288 let _ = writeln!(out, " }}");
289 let _ = writeln!(out, "}}");
290 } else {
291 let _ = writeln!(
292 out,
293 "fun {}(conn: Connection, count: Int) {{",
294 batch_fn_name
295 );
296 let _ = writeln!(out, " conn.prepareStatement(\"{}\").use {{ ps ->", sql);
297 let _ = writeln!(out, " repeat(count) {{");
298 let _ = writeln!(out, " ps.addBatch()");
299 let _ = writeln!(out, " }}");
300 let _ = writeln!(out, " ps.executeBatch()");
301 let _ = writeln!(out, " }}");
302 let _ = writeln!(out, "}}");
303 }
304 }
305 QueryCommand::Many => {
306 let ret = format!(": List<{}>", struct_name);
307 write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
308 let _ = writeln!(out, " conn.prepareStatement(\"{}\").use {{ ps ->", sql);
309 write_setters(&mut out, params);
310 let _ = writeln!(out, " ps.executeQuery().use {{ rs ->");
311 let _ = writeln!(
312 out,
313 " val result = mutableListOf<{}>()",
314 struct_name
315 );
316 let _ = writeln!(out, " while (rs.next()) {{");
317 let _ = writeln!(out, " result.add(");
318 let _ = writeln!(out, " {}(", struct_name);
319 for col in columns.iter() {
320 let getter = rs_getter(&col.lang_type);
321 let _ = writeln!(
322 out,
323 " {} = rs.{}(\"{}\"),",
324 col.field_name, getter, col.name
325 );
326 }
327 let _ = writeln!(out, " ),");
328 let _ = writeln!(out, " )");
329 let _ = writeln!(out, " }}");
330 let _ = writeln!(out, " return result");
331 let _ = writeln!(out, " }}");
332 let _ = writeln!(out, " }}");
333 let _ = writeln!(out, "}}");
334 }
335 }
336
337 Ok(out)
338 }
339
340 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
341 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
342 let mut out = String::new();
343 let _ = writeln!(out, "enum class {}(val value: String) {{", type_name);
344 for (i, value) in enum_info.values.iter().enumerate() {
345 let variant = enum_variant_name(value, &self.manifest.naming);
346 let sep = if i + 1 < enum_info.values.len() {
347 ","
348 } else {
349 ";"
350 };
351 let _ = writeln!(out, " {}(\"{}\"){}", variant, value, sep);
352 }
353 let _ = writeln!(out, "}}");
354 Ok(out)
355 }
356
357 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
358 let name = to_pascal_case(&composite.sql_name);
359 let mut out = String::new();
360 let _ = writeln!(out, "data class {}(", name);
361 for field in composite.fields.iter() {
362 let field_name = to_camel_case(&field.name);
363 let field_type = resolve_type(&field.neutral_type, &self.manifest, false)
364 .map(|t| t.into_owned())
365 .unwrap_or_else(|_| "Any".to_string());
366 let _ = writeln!(out, " val {}: {},", field_name, field_type);
367 }
368 let _ = writeln!(out, ")");
369 Ok(out)
370 }
371}