1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
/// Default PostgreSQL manifest for this backend, embedded at compile time.
///
/// `KotlinExposedBackend::new` parses it when no on-disk override manifest exists.
const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/kotlin-exposed.toml");
17
/// Code-generation backend ("kotlin-exposed") that emits Kotlin source built on the
/// JetBrains Exposed library.
pub struct KotlinExposedBackend {
    /// Parsed backend manifest. It supplies the naming rules (`manifest.naming`) and the
    /// type resolution used when generating code.
    manifest: BackendManifest,
}
21
22impl KotlinExposedBackend {
23 pub fn new(engine: &str) -> Result<Self, ScytheError> {
24 let default_toml = match engine {
25 "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
26 _ => {
27 return Err(ScytheError::new(
28 ErrorCode::InternalError,
29 format!("unsupported engine '{}' for kotlin-exposed backend", engine),
30 ));
31 }
32 };
33 let manifest_path = Path::new("backends/kotlin-exposed/manifest.toml");
34 let manifest = if manifest_path.exists() {
35 load_manifest(manifest_path)
36 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
37 } else {
38 toml::from_str(default_toml)
39 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
40 };
41 Ok(Self { manifest })
42 }
43}
44
/// Rewrites PostgreSQL positional placeholders (`$1`, `$2`, ...) into JDBC `?` markers.
///
/// Only placeholders in ordinary SQL text are rewritten. The following are copied verbatim,
/// because a `$N` inside them is literal text rather than a parameter:
/// - single-quoted string literals, including `E'...'` strings with backslash escapes;
/// - double-quoted identifiers;
/// - dollar-quoted strings (`$$...$$`, `$tag$...$tag$`);
/// - identifiers that contain `$` (e.g. `price$1`).
///
/// A `$` not followed by a digit is copied unchanged. Unterminated quoted regions are
/// copied through to the end of the input.
///
/// Note: JDBC markers are positional by occurrence, so the numeric index is dropped. SQL
/// that reuses or reorders `$N` will not line up with the declared parameter list.
fn pg_to_jdbc_params(sql: &str) -> String {
    /// Characters that may continue a PostgreSQL identifier (`$` is allowed after the first).
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_' || c == '$'
    }

    /// Copies the quoted token that opens at `start` into `out`.
    ///
    /// Returns the index just past its closing quote, or the end of input if the token is
    /// unterminated. A doubled quote is an escaped quote. When `backslash_escapes` is set
    /// (`E'...'` strings), `\x` escapes `x`.
    fn copy_quoted(
        chars: &[char],
        start: usize,
        quote: char,
        backslash_escapes: bool,
        out: &mut String,
    ) -> usize {
        let mut j = start + 1;
        while j < chars.len() {
            let c = chars[j];
            if backslash_escapes && c == '\\' {
                j += 2;
            } else if c == quote {
                if chars.get(j + 1) == Some(&quote) {
                    j += 2;
                } else {
                    j += 1;
                    break;
                }
            } else {
                j += 1;
            }
        }
        let end = j.min(chars.len());
        out.extend(&chars[start..end]);
        end
    }

    /// Returns the length of a dollar-quote delimiter (`$$` or `$tag$`) starting at `start`.
    fn dollar_tag_len(chars: &[char], start: usize) -> Option<usize> {
        let mut j = start + 1;
        if chars.get(j).is_some_and(|&c| c.is_alphabetic() || c == '_') {
            j += 1;
            while chars.get(j).is_some_and(|&c| c.is_alphanumeric() || c == '_') {
                j += 1;
            }
        }
        (chars.get(j) == Some(&'$')).then_some(j + 1 - start)
    }

    let chars: Vec<char> = sql.chars().collect();
    let mut out = String::with_capacity(sql.len());
    let mut i = 0;
    while i < chars.len() {
        let ch = chars[i];
        let prev = i.checked_sub(1).map(|p| chars[p]);
        match ch {
            '\'' => {
                // `E'...'` (E not part of a longer identifier) enables backslash escapes.
                let escape_string = matches!(prev, Some('E' | 'e'))
                    && !i.checked_sub(2).is_some_and(|p| is_ident_char(chars[p]));
                i = copy_quoted(&chars, i, '\'', escape_string, &mut out);
            }
            '"' => i = copy_quoted(&chars, i, '"', false, &mut out),
            // `$` directly after an identifier character is part of that identifier.
            '$' if prev.is_some_and(is_ident_char) => {
                out.push('$');
                i += 1;
            }
            '$' if chars.get(i + 1).is_some_and(char::is_ascii_digit) => {
                i += 1;
                while chars.get(i).is_some_and(char::is_ascii_digit) {
                    i += 1;
                }
                out.push('?');
            }
            '$' => {
                if let Some(tag_len) = dollar_tag_len(&chars, i) {
                    // Copy the opening tag, the body and the matching closing tag verbatim.
                    let tag = &chars[i..i + tag_len];
                    let body_start = i + tag_len;
                    let end = chars[body_start..]
                        .windows(tag_len)
                        .position(|w| w == tag)
                        .map_or(chars.len(), |pos| body_start + pos + tag_len);
                    out.extend(&chars[i..end]);
                    i = end;
                } else {
                    out.push('$');
                    i += 1;
                }
            }
            _ => {
                out.push(ch);
                i += 1;
            }
        }
    }
    out
}
65
/// Maps a Kotlin type to the Exposed `Table` column-builder function used for it
/// (e.g. `Int` -> `integer`, `String` -> `varchar`).
///
/// Kotlin built-ins are matched exactly. JVM library types are matched by substring, so both
/// simple (`LocalDate`) and fully-qualified (`java.time.LocalDate`) spellings work. Unknown
/// types fall back to `text`.
fn exposed_column_fn(kotlin_type: &str) -> &'static str {
    match kotlin_type {
        "Boolean" => "bool",
        "Byte" => "byte",
        "Short" => "short",
        "Int" => "integer",
        "Long" => "long",
        "Float" => "float",
        "Double" => "double",
        "String" => "varchar",
        "ByteArray" => "binary",
        _ if kotlin_type.contains("BigDecimal") => "decimal",
        // `LocalDateTime` contains `LocalDate` as a substring, so the date-time names must be
        // tested before `LocalDate`/`LocalTime`; otherwise `LocalDateTime` maps to `date`.
        _ if kotlin_type.contains("LocalDateTime") => "datetime",
        _ if kotlin_type.contains("OffsetDateTime") => "timestampWithTimeZone",
        _ if kotlin_type.contains("LocalDate") => "date",
        _ if kotlin_type.contains("LocalTime") => "time",
        _ if kotlin_type.contains("OffsetTime") => "time",
        _ if kotlin_type.contains("UUID") => "uuid",
        _ => "text",
    }
}
88
/// Picks the `java.sql.ResultSet` getter used to read a column of the given Kotlin type.
///
/// Kotlin built-ins that have a dedicated JDBC getter map to it directly. Any type whose
/// name contains `BigDecimal` uses `getBigDecimal`. Every other type (`java.time` classes,
/// `UUID`, user enums, ...) is read with `getObject`.
fn rs_getter(kotlin_type: &str) -> &str {
    const TYPED_GETTERS: [(&str, &str); 9] = [
        ("Boolean", "getBoolean"),
        ("Byte", "getByte"),
        ("Short", "getShort"),
        ("Int", "getInt"),
        ("Long", "getLong"),
        ("Float", "getFloat"),
        ("Double", "getDouble"),
        ("String", "getString"),
        ("ByteArray", "getBytes"),
    ];

    if let Some(&(_, getter)) = TYPED_GETTERS.iter().find(|&&(ty, _)| ty == kotlin_type) {
        getter
    } else if kotlin_type.contains("BigDecimal") {
        "getBigDecimal"
    } else {
        "getObject"
    }
}
111
/// Maps a Kotlin type to the Exposed `IColumnType` instance used to bind a parameter of that
/// type in `exec(sql, listOf(columnType to value, ...))`.
///
/// Kotlin built-ins are matched exactly. JVM library types are matched by substring, so both
/// simple and fully-qualified spellings work. Unknown types fall back to `TextColumnType()`.
fn exposed_column_type_class(kotlin_type: &str) -> &'static str {
    match kotlin_type {
        "Boolean" => "BooleanColumnType()",
        "Byte" => "ByteColumnType()",
        "Short" => "ShortColumnType()",
        "Int" => "IntegerColumnType()",
        "Long" => "LongColumnType()",
        "Float" => "FloatColumnType()",
        "Double" => "DoubleColumnType()",
        "String" => "VarCharColumnType(255)",
        "ByteArray" => "BinaryColumnType()",
        _ if kotlin_type.contains("BigDecimal") => "DecimalColumnType(10, 2)",
        // `LocalDateTime` contains `LocalDate` as a substring, so the date-time names must be
        // tested first; otherwise `LocalDateTime` binds with the `LocalDate` column type.
        _ if kotlin_type.contains("LocalDateTime") => "JavaLocalDateTimeColumnType()",
        _ if kotlin_type.contains("OffsetDateTime") => "JavaOffsetDateTimeColumnType()",
        _ if kotlin_type.contains("LocalDate") => "JavaLocalDateColumnType()",
        _ if kotlin_type.contains("LocalTime") => "JavaLocalTimeColumnType()",
        // NOTE(review): an `OffsetTime` value is bound with the `LocalTime` column type here;
        // confirm the targeted Exposed version accepts that.
        _ if kotlin_type.contains("OffsetTime") => "JavaLocalTimeColumnType()",
        _ if kotlin_type.contains("UUID") => "UUIDColumnType()",
        _ => "TextColumnType()",
    }
}
135
136impl CodegenBackend for KotlinExposedBackend {
137 fn name(&self) -> &str {
138 "kotlin-exposed"
139 }
140
141 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
142 &self.manifest
143 }
144
145 fn supported_engines(&self) -> &[&str] {
146 &["postgresql"]
147 }
148
149 fn file_header(&self) -> String {
150 let mut out = String::new();
151 out.push_str("import org.jetbrains.exposed.sql.*\n");
152 out.push_str("import org.jetbrains.exposed.sql.transactions.transaction\n");
153 out.push_str("import org.jetbrains.exposed.dao.*\n");
154 out.push_str("import org.jetbrains.exposed.dao.id.IntIdTable\n");
155 out
156 }
157
158 fn generate_row_struct(
159 &self,
160 query_name: &str,
161 columns: &[ResolvedColumn],
162 ) -> Result<String, ScytheError> {
163 let struct_name = row_struct_name(query_name, &self.manifest.naming);
164 let mut out = String::new();
165 let _ = writeln!(out, "data class {}(", struct_name);
166 for col in columns.iter() {
167 let _ = writeln!(out, " val {}: {},", col.field_name, col.full_type);
168 }
169 let _ = writeln!(out, ")");
170 Ok(out)
171 }
172
173 fn generate_model_struct(
174 &self,
175 table_name: &str,
176 columns: &[ResolvedColumn],
177 ) -> Result<String, ScytheError> {
178 let name = to_pascal_case(table_name);
179 let table_obj_name = format!("{}Table", name);
180 let mut out = String::new();
181 let _ = writeln!(
185 out,
186 "object {} : IntIdTable(\"{}\") {{",
187 table_obj_name, table_name
188 );
189 for col in columns.iter() {
190 let col_fn = exposed_column_fn(&col.lang_type);
191 let nullable_suffix = if col.nullable { ".nullable()" } else { "" };
192 if col_fn == "varchar" {
196 let _ = writeln!(
197 out,
198 " val {} = varchar(\"{}\", 255){}",
199 col.field_name, col.name, nullable_suffix
200 );
201 } else {
202 let _ = writeln!(
203 out,
204 " val {} = {}(\"{}\"){}",
205 col.field_name, col_fn, col.name, nullable_suffix
206 );
207 }
208 }
209 let _ = writeln!(out, "}}");
210 Ok(out)
211 }
212
213 fn generate_query_fn(
214 &self,
215 analyzed: &AnalyzedQuery,
216 struct_name: &str,
217 columns: &[ResolvedColumn],
218 params: &[ResolvedParam],
219 ) -> Result<String, ScytheError> {
220 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
221 let sql = pg_to_jdbc_params(&super::clean_sql_oneline_with_optional(
222 &analyzed.sql,
223 &analyzed.optional_params,
224 &analyzed.params,
225 ));
226
227 let use_multiline_params = !params.is_empty();
228 let mut out = String::new();
229
230 let write_fn_sig =
232 |out: &mut String, name: &str, ret: &str, multiline: bool, params: &[ResolvedParam]| {
233 if multiline {
234 let _ = writeln!(out, "fun {}(", name);
235 for p in params {
236 let _ = writeln!(out, " {}: {},", p.field_name, p.full_type);
237 }
238 let _ = writeln!(out, "){} = transaction {{", ret);
239 } else {
240 let _ = writeln!(out, "fun {}(){} = transaction {{", name, ret);
241 }
242 };
243
244 let build_args = |params: &[ResolvedParam]| -> String {
246 if params.is_empty() {
247 return String::new();
248 }
249 let pairs: Vec<String> = params
250 .iter()
251 .map(|p| {
252 format!(
253 "{} to {}",
254 exposed_column_type_class(&p.lang_type),
255 p.field_name
256 )
257 })
258 .collect();
259 format!(", listOf({})", pairs.join(", "))
260 };
261
262 match &analyzed.command {
263 QueryCommand::Exec => {
264 write_fn_sig(&mut out, &func_name, "", use_multiline_params, params);
265 let args = build_args(params);
266 let _ = writeln!(out, " exec(\"{}\"{})", sql, args);
267 let _ = writeln!(out, "}}");
268 }
269 QueryCommand::ExecResult | QueryCommand::ExecRows => {
270 write_fn_sig(&mut out, &func_name, ": Int", use_multiline_params, params);
271 let args = build_args(params);
272 let _ = writeln!(out, " exec(\"{}\"{}) ?: 0", sql, args);
273 let _ = writeln!(out, "}}");
274 }
275 QueryCommand::One => {
276 let ret = format!(": {}?", struct_name);
277 write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
278 let args = build_args(params);
279 let _ = writeln!(out, " exec(\"{}\"{}) {{ rs ->", sql, args);
280 let _ = writeln!(out, " if (rs.next()) {}(", struct_name);
281 for col in columns.iter() {
282 let getter = rs_getter(&col.lang_type);
283 let _ = writeln!(
284 out,
285 " {} = rs.{}(\"{}\"),",
286 col.field_name, getter, col.name
287 );
288 }
289 let _ = writeln!(out, " )");
290 let _ = writeln!(out, " else null");
291 let _ = writeln!(out, " }}");
292 let _ = writeln!(out, "}}");
293 }
294 QueryCommand::Batch => {
295 let batch_fn_name = format!("{}Batch", func_name);
296 if params.len() > 1 {
297 let params_class_name =
298 format!("{}BatchParams", to_pascal_case(&analyzed.name));
299 let _ = writeln!(out, "data class {}(", params_class_name);
300 for p in params {
301 let _ = writeln!(out, " val {}: {},", p.field_name, p.full_type);
302 }
303 let _ = writeln!(out, ")");
304 let _ = writeln!(out);
305 let _ = writeln!(out, "fun {}(", batch_fn_name);
306 let _ = writeln!(out, " items: List<{}>,", params_class_name);
307 let _ = writeln!(out, ") = transaction {{");
308 let _ = writeln!(out, " for (item in items) {{");
309 let args: Vec<String> = params
310 .iter()
311 .map(|p| {
312 format!(
313 "{} to item.{}",
314 exposed_column_type_class(&p.lang_type),
315 p.field_name
316 )
317 })
318 .collect();
319 let _ = writeln!(
320 out,
321 " exec(\"{}\", listOf({}))",
322 sql,
323 args.join(", ")
324 );
325 let _ = writeln!(out, " }}");
326 let _ = writeln!(out, "}}");
327 } else if params.len() == 1 {
328 let _ = writeln!(out, "fun {}(", batch_fn_name);
329 let _ = writeln!(out, " items: List<{}>,", params[0].full_type);
330 let _ = writeln!(out, ") = transaction {{");
331 let _ = writeln!(out, " for (item in items) {{");
332 let _ = writeln!(
333 out,
334 " exec(\"{}\", listOf({} to item))",
335 sql,
336 exposed_column_type_class(¶ms[0].lang_type)
337 );
338 let _ = writeln!(out, " }}");
339 let _ = writeln!(out, "}}");
340 } else {
341 let _ = writeln!(out, "fun {}(count: Int) = transaction {{", batch_fn_name);
342 let _ = writeln!(out, " repeat(count) {{");
343 let _ = writeln!(out, " exec(\"{}\")", sql);
344 let _ = writeln!(out, " }}");
345 let _ = writeln!(out, "}}");
346 }
347 }
348 QueryCommand::Grouped => {
349 return Err(ScytheError::new(
351 ErrorCode::InternalError,
352 "kotlin-exposed backend does not yet support :grouped queries".to_string(),
353 ));
354 }
355 QueryCommand::Many => {
356 let ret = format!(": List<{}>", struct_name);
357 write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
358 let args = build_args(params);
359 let _ = writeln!(out, " val result = mutableListOf<{}>()", struct_name);
360 let _ = writeln!(out, " exec(\"{}\"{}) {{ rs ->", sql, args);
361 let _ = writeln!(out, " while (rs.next()) {{");
362 let _ = writeln!(out, " result.add(");
363 let _ = writeln!(out, " {}(", struct_name);
364 for col in columns.iter() {
365 let getter = rs_getter(&col.lang_type);
366 let _ = writeln!(
367 out,
368 " {} = rs.{}(\"{}\"),",
369 col.field_name, getter, col.name
370 );
371 }
372 let _ = writeln!(out, " ),");
373 let _ = writeln!(out, " )");
374 let _ = writeln!(out, " }}");
375 let _ = writeln!(out, " }}");
376 let _ = writeln!(out, " result");
377 let _ = writeln!(out, "}}");
378 }
379 }
380
381 Ok(out)
382 }
383
384 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
385 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
386 let mut out = String::new();
387 let _ = writeln!(out, "enum class {}(val value: String) {{", type_name);
388 for (i, value) in enum_info.values.iter().enumerate() {
389 let variant = enum_variant_name(value, &self.manifest.naming);
390 let sep = if i + 1 < enum_info.values.len() {
391 ","
392 } else {
393 ";"
394 };
395 let _ = writeln!(out, " {}(\"{}\"){}", variant, value, sep);
396 }
397 let _ = writeln!(out, "}}");
398 Ok(out)
399 }
400
401 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
402 let name = to_pascal_case(&composite.sql_name);
403 let mut out = String::new();
404 let _ = writeln!(out, "data class {}(", name);
405 for field in composite.fields.iter() {
406 let field_name = to_camel_case(&field.name);
407 let field_type = resolve_type(&field.neutral_type, &self.manifest, false)
408 .map(|t| t.into_owned())
409 .unwrap_or_else(|_| "Any".to_string());
410 let _ = writeln!(out, " val {}: {},", field_name, field_type);
411 }
412 let _ = writeln!(out, ")");
413 Ok(out)
414 }
415}