Skip to main content

scythe_codegen/backends/
kotlin_exposed.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/kotlin-exposed.toml");
17
18pub struct KotlinExposedBackend {
19    manifest: BackendManifest,
20}
21
22impl KotlinExposedBackend {
23    pub fn new(engine: &str) -> Result<Self, ScytheError> {
24        let default_toml = match engine {
25            "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
26            _ => {
27                return Err(ScytheError::new(
28                    ErrorCode::InternalError,
29                    format!("unsupported engine '{}' for kotlin-exposed backend", engine),
30                ));
31            }
32        };
33        let manifest_path = Path::new("backends/kotlin-exposed/manifest.toml");
34        let manifest = if manifest_path.exists() {
35            load_manifest(manifest_path)
36                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
37        } else {
38            toml::from_str(default_toml)
39                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
40        };
41        Ok(Self { manifest })
42    }
43}
44
45/// Convert PostgreSQL $1, $2, ... placeholders to JDBC ? placeholders.
46fn pg_to_jdbc_params(sql: &str) -> String {
47    let mut result = String::with_capacity(sql.len());
48    let mut chars = sql.chars().peekable();
49    while let Some(ch) = chars.next() {
50        if ch == '$' {
51            if chars.peek().is_some_and(|c| c.is_ascii_digit()) {
52                while chars.peek().is_some_and(|c| c.is_ascii_digit()) {
53                    chars.next();
54                }
55                result.push('?');
56            } else {
57                result.push(ch);
58            }
59        } else {
60            result.push(ch);
61        }
62    }
63    result
64}
65
66/// Get the Exposed column type function for a given Kotlin type.
67fn exposed_column_fn(kotlin_type: &str) -> &str {
68    match kotlin_type {
69        "Boolean" => "bool",
70        "Byte" => "byte",
71        "Short" => "short",
72        "Int" => "integer",
73        "Long" => "long",
74        "Float" => "float",
75        "Double" => "double",
76        "String" => "varchar",
77        "ByteArray" => "binary",
78        _ if kotlin_type.contains("BigDecimal") => "decimal",
79        _ if kotlin_type.contains("LocalDate") => "date",
80        _ if kotlin_type.contains("LocalTime") => "time",
81        _ if kotlin_type.contains("OffsetTime") => "time",
82        _ if kotlin_type.contains("LocalDateTime") => "datetime",
83        _ if kotlin_type.contains("OffsetDateTime") => "timestampWithTimeZone",
84        _ if kotlin_type.contains("UUID") => "uuid",
85        _ => "text",
86    }
87}
88
89/// Get the ResultSet getter method name for a given Kotlin type.
90fn rs_getter(kotlin_type: &str) -> &str {
91    match kotlin_type {
92        "Boolean" => "getBoolean",
93        "Byte" => "getByte",
94        "Short" => "getShort",
95        "Int" => "getInt",
96        "Long" => "getLong",
97        "Float" => "getFloat",
98        "Double" => "getDouble",
99        "String" => "getString",
100        "ByteArray" => "getBytes",
101        _ if kotlin_type.contains("BigDecimal") => "getBigDecimal",
102        _ if kotlin_type.contains("LocalDate") => "getObject",
103        _ if kotlin_type.contains("LocalTime") => "getObject",
104        _ if kotlin_type.contains("OffsetTime") => "getObject",
105        _ if kotlin_type.contains("LocalDateTime") => "getObject",
106        _ if kotlin_type.contains("OffsetDateTime") => "getObject",
107        _ if kotlin_type.contains("UUID") => "getObject",
108        _ => "getObject",
109    }
110}
111
112/// Get the Exposed column type class for use in `exec()` parameter binding.
113fn exposed_column_type_class(kotlin_type: &str) -> &str {
114    match kotlin_type {
115        "Boolean" => "BooleanColumnType()",
116        "Byte" => "ByteColumnType()",
117        "Short" => "ShortColumnType()",
118        "Int" => "IntegerColumnType()",
119        "Long" => "LongColumnType()",
120        "Float" => "FloatColumnType()",
121        "Double" => "DoubleColumnType()",
122        // TODO: varchar length 255 is hardcoded; see generate_model_struct TODO.
123        "String" => "VarCharColumnType(255)",
124        "ByteArray" => "BinaryColumnType()",
125        _ if kotlin_type.contains("BigDecimal") => "DecimalColumnType(10, 2)",
126        _ if kotlin_type.contains("LocalDate") => "JavaLocalDateColumnType()",
127        _ if kotlin_type.contains("LocalTime") => "JavaLocalTimeColumnType()",
128        _ if kotlin_type.contains("OffsetTime") => "JavaLocalTimeColumnType()",
129        _ if kotlin_type.contains("LocalDateTime") => "JavaLocalDateTimeColumnType()",
130        _ if kotlin_type.contains("OffsetDateTime") => "JavaOffsetDateTimeColumnType()",
131        _ if kotlin_type.contains("UUID") => "UUIDColumnType()",
132        _ => "TextColumnType()",
133    }
134}
135
136impl CodegenBackend for KotlinExposedBackend {
137    fn name(&self) -> &str {
138        "kotlin-exposed"
139    }
140
141    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
142        &self.manifest
143    }
144
145    fn supported_engines(&self) -> &[&str] {
146        &["postgresql"]
147    }
148
149    fn file_header(&self) -> String {
150        let mut out = String::new();
151        out.push_str("import org.jetbrains.exposed.sql.*\n");
152        out.push_str("import org.jetbrains.exposed.sql.transactions.transaction\n");
153        out.push_str("import org.jetbrains.exposed.dao.*\n");
154        out.push_str("import org.jetbrains.exposed.dao.id.IntIdTable\n");
155        out
156    }
157
158    fn generate_row_struct(
159        &self,
160        query_name: &str,
161        columns: &[ResolvedColumn],
162    ) -> Result<String, ScytheError> {
163        let struct_name = row_struct_name(query_name, &self.manifest.naming);
164        let mut out = String::new();
165        let _ = writeln!(out, "data class {}(", struct_name);
166        for col in columns.iter() {
167            let _ = writeln!(out, "    val {}: {},", col.field_name, col.full_type);
168        }
169        let _ = writeln!(out, ")");
170        Ok(out)
171    }
172
173    fn generate_model_struct(
174        &self,
175        table_name: &str,
176        columns: &[ResolvedColumn],
177    ) -> Result<String, ScytheError> {
178        let name = to_pascal_case(table_name);
179        let table_obj_name = format!("{}Table", name);
180        let mut out = String::new();
181        // TODO: IntIdTable is hardcoded — detecting the actual PK type (LongIdTable,
182        // UUIDTable, etc.) from schema DDL requires propagating PK column info through
183        // the analyzer. Follow-up: https://github.com/scythe-sql/scythe/issues/XXX
184        let _ = writeln!(
185            out,
186            "object {} : IntIdTable(\"{}\") {{",
187            table_obj_name, table_name
188        );
189        for col in columns.iter() {
190            let col_fn = exposed_column_fn(&col.lang_type);
191            let nullable_suffix = if col.nullable { ".nullable()" } else { "" };
192            // TODO: varchar length is hardcoded to 255 — column lengths from schema DDL
193            // are not propagated through the analyzer yet. Follow-up needed to thread
194            // length/precision metadata from DDL columns to codegen.
195            if col_fn == "varchar" {
196                let _ = writeln!(
197                    out,
198                    "    val {} = varchar(\"{}\", 255){}",
199                    col.field_name, col.name, nullable_suffix
200                );
201            } else {
202                let _ = writeln!(
203                    out,
204                    "    val {} = {}(\"{}\"){}",
205                    col.field_name, col_fn, col.name, nullable_suffix
206                );
207            }
208        }
209        let _ = writeln!(out, "}}");
210        Ok(out)
211    }
212
213    fn generate_query_fn(
214        &self,
215        analyzed: &AnalyzedQuery,
216        struct_name: &str,
217        columns: &[ResolvedColumn],
218        params: &[ResolvedParam],
219    ) -> Result<String, ScytheError> {
220        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
221        let sql = pg_to_jdbc_params(&super::clean_sql_oneline_with_optional(
222            &analyzed.sql,
223            &analyzed.optional_params,
224            &analyzed.params,
225        ));
226
227        let use_multiline_params = !params.is_empty();
228        let mut out = String::new();
229
230        // Helper: write function signature
231        let write_fn_sig =
232            |out: &mut String, name: &str, ret: &str, multiline: bool, params: &[ResolvedParam]| {
233                if multiline {
234                    let _ = writeln!(out, "fun {}(", name);
235                    for p in params {
236                        let _ = writeln!(out, "    {}: {},", p.field_name, p.full_type);
237                    }
238                    let _ = writeln!(out, "){} = transaction {{", ret);
239                } else {
240                    let _ = writeln!(out, "fun {}(){} = transaction {{", name, ret);
241                }
242            };
243
244        // Helper: build args list for exec()
245        let build_args = |params: &[ResolvedParam]| -> String {
246            if params.is_empty() {
247                return String::new();
248            }
249            let pairs: Vec<String> = params
250                .iter()
251                .map(|p| {
252                    format!(
253                        "{} to {}",
254                        exposed_column_type_class(&p.lang_type),
255                        p.field_name
256                    )
257                })
258                .collect();
259            format!(", listOf({})", pairs.join(", "))
260        };
261
262        match &analyzed.command {
263            QueryCommand::Exec => {
264                write_fn_sig(&mut out, &func_name, "", use_multiline_params, params);
265                let args = build_args(params);
266                let _ = writeln!(out, "    exec(\"{}\"{})", sql, args);
267                let _ = writeln!(out, "}}");
268            }
269            QueryCommand::ExecResult | QueryCommand::ExecRows => {
270                write_fn_sig(&mut out, &func_name, ": Int", use_multiline_params, params);
271                let args = build_args(params);
272                let _ = writeln!(out, "    exec(\"{}\"{}) ?: 0", sql, args);
273                let _ = writeln!(out, "}}");
274            }
275            QueryCommand::One => {
276                let ret = format!(": {}?", struct_name);
277                write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
278                let args = build_args(params);
279                let _ = writeln!(out, "    exec(\"{}\"{}) {{ rs ->", sql, args);
280                let _ = writeln!(out, "        if (rs.next()) {}(", struct_name);
281                for col in columns.iter() {
282                    let getter = rs_getter(&col.lang_type);
283                    let _ = writeln!(
284                        out,
285                        "            {} = rs.{}(\"{}\"),",
286                        col.field_name, getter, col.name
287                    );
288                }
289                let _ = writeln!(out, "        )");
290                let _ = writeln!(out, "        else null");
291                let _ = writeln!(out, "    }}");
292                let _ = writeln!(out, "}}");
293            }
294            QueryCommand::Batch => {
295                let batch_fn_name = format!("{}Batch", func_name);
296                if params.len() > 1 {
297                    let params_class_name =
298                        format!("{}BatchParams", to_pascal_case(&analyzed.name));
299                    let _ = writeln!(out, "data class {}(", params_class_name);
300                    for p in params {
301                        let _ = writeln!(out, "    val {}: {},", p.field_name, p.full_type);
302                    }
303                    let _ = writeln!(out, ")");
304                    let _ = writeln!(out);
305                    let _ = writeln!(out, "fun {}(", batch_fn_name);
306                    let _ = writeln!(out, "    items: List<{}>,", params_class_name);
307                    let _ = writeln!(out, ") = transaction {{");
308                    let _ = writeln!(out, "    for (item in items) {{");
309                    let args: Vec<String> = params
310                        .iter()
311                        .map(|p| {
312                            format!(
313                                "{} to item.{}",
314                                exposed_column_type_class(&p.lang_type),
315                                p.field_name
316                            )
317                        })
318                        .collect();
319                    let _ = writeln!(
320                        out,
321                        "        exec(\"{}\", listOf({}))",
322                        sql,
323                        args.join(", ")
324                    );
325                    let _ = writeln!(out, "    }}");
326                    let _ = writeln!(out, "}}");
327                } else if params.len() == 1 {
328                    let _ = writeln!(out, "fun {}(", batch_fn_name);
329                    let _ = writeln!(out, "    items: List<{}>,", params[0].full_type);
330                    let _ = writeln!(out, ") = transaction {{");
331                    let _ = writeln!(out, "    for (item in items) {{");
332                    let _ = writeln!(
333                        out,
334                        "        exec(\"{}\", listOf({} to item))",
335                        sql,
336                        exposed_column_type_class(&params[0].lang_type)
337                    );
338                    let _ = writeln!(out, "    }}");
339                    let _ = writeln!(out, "}}");
340                } else {
341                    let _ = writeln!(out, "fun {}(count: Int) = transaction {{", batch_fn_name);
342                    let _ = writeln!(out, "    repeat(count) {{");
343                    let _ = writeln!(out, "        exec(\"{}\")", sql);
344                    let _ = writeln!(out, "    }}");
345                    let _ = writeln!(out, "}}");
346                }
347            }
348            QueryCommand::Grouped => {
349                // Grouped queries are not yet supported by this backend.
350                return Err(ScytheError::new(
351                    ErrorCode::InternalError,
352                    "kotlin-exposed backend does not yet support :grouped queries".to_string(),
353                ));
354            }
355            QueryCommand::Many => {
356                let ret = format!(": List<{}>", struct_name);
357                write_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
358                let args = build_args(params);
359                let _ = writeln!(out, "    val result = mutableListOf<{}>()", struct_name);
360                let _ = writeln!(out, "    exec(\"{}\"{}) {{ rs ->", sql, args);
361                let _ = writeln!(out, "        while (rs.next()) {{");
362                let _ = writeln!(out, "            result.add(");
363                let _ = writeln!(out, "                {}(", struct_name);
364                for col in columns.iter() {
365                    let getter = rs_getter(&col.lang_type);
366                    let _ = writeln!(
367                        out,
368                        "                    {} = rs.{}(\"{}\"),",
369                        col.field_name, getter, col.name
370                    );
371                }
372                let _ = writeln!(out, "                ),");
373                let _ = writeln!(out, "            )");
374                let _ = writeln!(out, "        }}");
375                let _ = writeln!(out, "    }}");
376                let _ = writeln!(out, "    result");
377                let _ = writeln!(out, "}}");
378            }
379        }
380
381        Ok(out)
382    }
383
384    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
385        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
386        let mut out = String::new();
387        let _ = writeln!(out, "enum class {}(val value: String) {{", type_name);
388        for (i, value) in enum_info.values.iter().enumerate() {
389            let variant = enum_variant_name(value, &self.manifest.naming);
390            let sep = if i + 1 < enum_info.values.len() {
391                ","
392            } else {
393                ";"
394            };
395            let _ = writeln!(out, "    {}(\"{}\"){}", variant, value, sep);
396        }
397        let _ = writeln!(out, "}}");
398        Ok(out)
399    }
400
401    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
402        let name = to_pascal_case(&composite.sql_name);
403        let mut out = String::new();
404        let _ = writeln!(out, "data class {}(", name);
405        for field in composite.fields.iter() {
406            let field_name = to_camel_case(&field.name);
407            let field_type = resolve_type(&field.neutral_type, &self.manifest, false)
408                .map(|t| t.into_owned())
409                .unwrap_or_else(|_| "Any".to_string());
410            let _ = writeln!(out, "    val {}: {},", field_name, field_type);
411        }
412        let _ = writeln!(out, ")");
413        Ok(out)
414    }
415}