Skip to main content

scythe_codegen/backends/kotlin_r2dbc.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/kotlin-r2dbc.toml");
17const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/kotlin-r2dbc.mysql.toml");
18const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/kotlin-r2dbc.sqlite.toml");
19
20pub struct KotlinR2dbcBackend {
21    manifest: BackendManifest,
22}
23
24impl KotlinR2dbcBackend {
25    pub fn new(engine: &str) -> Result<Self, ScytheError> {
26        let default_toml = match engine {
27            "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
28            "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
29            "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
30            _ => {
31                return Err(ScytheError::new(
32                    ErrorCode::InternalError,
33                    format!("unsupported engine '{}' for kotlin-r2dbc backend", engine),
34                ));
35            }
36        };
37        let manifest_path = Path::new("backends/kotlin-r2dbc/manifest.toml");
38        let manifest = if manifest_path.exists() {
39            load_manifest(manifest_path)
40                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41        } else {
42            toml::from_str(default_toml)
43                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44        };
45        Ok(Self { manifest })
46    }
47}
48
/// Get the Kotlin class literal to pass to R2DBC `Row.get(name, Class)` for a
/// given Kotlin type name.
///
/// Kotlin primitives use `javaObjectType` (the boxed `java.lang.*` class):
/// `Boolean::class.java` would be the primitive `boolean` class, which cannot
/// carry SQL NULL and is not part of the R2DBC type mapping.
///
/// Other types are matched by substring, so both simple (`LocalDateTime`) and
/// qualified (`java.time.LocalDateTime`) names resolve. The `...DateTime`
/// names are tested before `LocalDate`; otherwise `LocalDateTime` (which
/// contains `LocalDate`) would be decoded as a date.
///
/// Unknown types fall back to `Any::class.java`.
fn r2dbc_row_class(kotlin_type: &str) -> &'static str {
    match kotlin_type {
        "Boolean" => "Boolean::class.javaObjectType",
        "Byte" => "Byte::class.javaObjectType",
        "Short" => "Short::class.javaObjectType",
        "Int" => "Int::class.javaObjectType",
        "Long" => "Long::class.javaObjectType",
        "Float" => "Float::class.javaObjectType",
        "Double" => "Double::class.javaObjectType",
        "String" => "String::class.java",
        "ByteArray" => "ByteArray::class.java",
        t if t.contains("BigDecimal") => "java.math.BigDecimal::class.java",
        // Longer names first: "LocalDateTime" also contains "LocalDate".
        t if t.contains("LocalDateTime") => "java.time.LocalDateTime::class.java",
        t if t.contains("OffsetDateTime") => "java.time.OffsetDateTime::class.java",
        t if t.contains("LocalDate") => "java.time.LocalDate::class.java",
        t if t.contains("LocalTime") => "java.time.LocalTime::class.java",
        t if t.contains("OffsetTime") => "java.time.OffsetTime::class.java",
        t if t.contains("UUID") => "java.util.UUID::class.java",
        _ => "Any::class.java",
    }
}
71
72impl CodegenBackend for KotlinR2dbcBackend {
73    fn name(&self) -> &str {
74        "kotlin-r2dbc"
75    }
76
77    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
78        &self.manifest
79    }
80
81    fn supported_engines(&self) -> &[&str] {
82        &["postgresql", "mysql", "sqlite"]
83    }
84
85    fn file_header(&self) -> String {
86        "import io.r2dbc.spi.ConnectionFactory\n\
87         import java.math.BigDecimal\n\
88         import java.time.*\n\
89         import java.util.UUID\n\
90         import kotlinx.coroutines.flow.Flow\n\
91         import kotlinx.coroutines.reactive.asFlow\n\
92         import kotlinx.coroutines.reactive.awaitFirst\n\
93         import kotlinx.coroutines.reactive.awaitFirstOrNull\n\
94         import reactor.core.publisher.Flux\n\
95         import reactor.core.publisher.Mono\n"
96            .to_string()
97    }
98
99    fn generate_row_struct(
100        &self,
101        query_name: &str,
102        columns: &[ResolvedColumn],
103    ) -> Result<String, ScytheError> {
104        let struct_name = row_struct_name(query_name, &self.manifest.naming);
105        let mut out = String::new();
106        let _ = writeln!(out, "data class {}(", struct_name);
107        for col in columns.iter() {
108            let _ = writeln!(out, "    val {}: {},", col.field_name, col.full_type);
109        }
110        let _ = writeln!(out, ")");
111        Ok(out)
112    }
113
114    fn generate_model_struct(
115        &self,
116        table_name: &str,
117        columns: &[ResolvedColumn],
118    ) -> Result<String, ScytheError> {
119        let name = to_pascal_case(table_name);
120        self.generate_row_struct(&name, columns)
121    }
122
123    fn generate_query_fn(
124        &self,
125        analyzed: &AnalyzedQuery,
126        struct_name: &str,
127        columns: &[ResolvedColumn],
128        params: &[ResolvedParam],
129    ) -> Result<String, ScytheError> {
130        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
131        let sql = super::clean_sql_oneline_with_optional(
132            &analyzed.sql,
133            &analyzed.optional_params,
134            &analyzed.params,
135        );
136
137        let use_multiline_params = !params.is_empty();
138
139        let mut out = String::new();
140
141        // Helper: write .bind() calls for R2DBC (0-based indexing)
142        let write_binds = |out: &mut String, indent: &str| {
143            for (i, param) in params.iter().enumerate() {
144                let _ = writeln!(out, "{}.bind({}, {})", indent, i, param.field_name);
145            }
146        };
147
148        // Helper: write row mapping expression for Kotlin
149        let write_row_map = |out: &mut String, indent: &str| {
150            let _ = writeln!(out, "{}{}(", indent, struct_name);
151            for col in columns.iter() {
152                let class = r2dbc_row_class(&col.lang_type);
153                let _ = writeln!(
154                    out,
155                    "{}    {} = row.get(\"{}\", {}),",
156                    indent, col.field_name, col.name, class
157                );
158            }
159            let _ = write!(out, "{})", indent);
160        };
161
162        // Helper: write suspend function signature
163        let write_suspend_fn_sig =
164            |out: &mut String, name: &str, ret: &str, multiline: bool, params: &[ResolvedParam]| {
165                if multiline {
166                    let _ = writeln!(out, "suspend fun {}(", name);
167                    let _ = writeln!(out, "    cf: ConnectionFactory,");
168                    for p in params {
169                        let _ = writeln!(out, "    {}: {},", p.field_name, p.full_type);
170                    }
171                    let _ = writeln!(out, "){} {{", ret);
172                } else {
173                    let _ = writeln!(out, "suspend fun {}(cf: ConnectionFactory){} {{", name, ret);
174                }
175            };
176
177        match &analyzed.command {
178            QueryCommand::Exec => {
179                write_suspend_fn_sig(&mut out, &func_name, "", use_multiline_params, params);
180                let _ = writeln!(out, "    val conn = Mono.from(cf.create()).awaitFirst()");
181                let _ = writeln!(out, "    try {{");
182                let _ = writeln!(out, "        val stmt = conn.createStatement(\"{}\")", sql);
183                write_binds(&mut out, "        stmt");
184                let _ = writeln!(
185                    out,
186                    "        Mono.from(stmt.execute()).flatMap {{ result -> Mono.from(result.rowsUpdated) }}.awaitFirstOrNull()"
187                );
188                let _ = writeln!(out, "    }} finally {{");
189                let _ = writeln!(out, "        Mono.from(conn.close()).awaitFirstOrNull()");
190                let _ = writeln!(out, "    }}");
191                let _ = writeln!(out, "}}");
192            }
193            QueryCommand::ExecResult | QueryCommand::ExecRows => {
194                write_suspend_fn_sig(&mut out, &func_name, ": Long", use_multiline_params, params);
195                let _ = writeln!(out, "    val conn = Mono.from(cf.create()).awaitFirst()");
196                let _ = writeln!(out, "    try {{");
197                let _ = writeln!(out, "        val stmt = conn.createStatement(\"{}\")", sql);
198                write_binds(&mut out, "        stmt");
199                let _ = writeln!(out, "        return Mono");
200                let _ = writeln!(out, "            .from(stmt.execute())");
201                let _ = writeln!(
202                    out,
203                    "            .flatMap {{ result -> Mono.from(result.rowsUpdated) }}"
204                );
205                let _ = writeln!(out, "            .awaitFirst()");
206                let _ = writeln!(out, "    }} finally {{");
207                let _ = writeln!(out, "        Mono.from(conn.close()).awaitFirstOrNull()");
208                let _ = writeln!(out, "    }}");
209                let _ = writeln!(out, "}}");
210            }
211            QueryCommand::One => {
212                let ret = format!(": {}?", struct_name);
213                write_suspend_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
214                let _ = writeln!(out, "    val conn = Mono.from(cf.create()).awaitFirst()");
215                let _ = writeln!(out, "    try {{");
216                let _ = writeln!(out, "        val stmt = conn.createStatement(\"{}\")", sql);
217                write_binds(&mut out, "        stmt");
218                let _ = writeln!(out, "        return Mono");
219                let _ = writeln!(out, "            .from(stmt.execute())");
220                let _ = writeln!(out, "            .flatMap {{ result ->");
221                let _ = writeln!(out, "                Mono.from(");
222                let _ = writeln!(out, "                    result.map {{ row, _ ->");
223                write_row_map(&mut out, "                        ");
224                let _ = writeln!(out);
225                let _ = writeln!(out, "                    }},");
226                let _ = writeln!(out, "                )");
227                let _ = writeln!(out, "            }}.awaitFirstOrNull()");
228                let _ = writeln!(out, "    }} finally {{");
229                let _ = writeln!(out, "        Mono.from(conn.close()).awaitFirstOrNull()");
230                let _ = writeln!(out, "    }}");
231                let _ = writeln!(out, "}}");
232            }
233            QueryCommand::Many => {
234                // :many returns Flow<T> (non-suspend function, expression body)
235                let ret = format!(": Flow<{}>", struct_name);
236                if use_multiline_params {
237                    let _ = writeln!(out, "fun {}(", func_name);
238                    let _ = writeln!(out, "    cf: ConnectionFactory,");
239                    for p in params {
240                        let _ = writeln!(out, "    {}: {},", p.field_name, p.full_type);
241                    }
242                    let _ = writeln!(out, "){} =", ret);
243                } else {
244                    let _ = writeln!(out, "fun {}(cf: ConnectionFactory){} =", func_name, ret);
245                }
246                let _ = writeln!(out, "    Flux");
247                let _ = writeln!(out, "        .usingWhen(");
248                let _ = writeln!(out, "            cf.create(),");
249                let _ = writeln!(out, "            {{ conn ->");
250                let _ = writeln!(
251                    out,
252                    "                val stmt = conn.createStatement(\"{}\")",
253                    sql
254                );
255                write_binds(&mut out, "                stmt");
256                let _ = writeln!(out, "                Flux");
257                let _ = writeln!(out, "                    .from(stmt.execute())");
258                let _ = writeln!(out, "                    .flatMap {{ result ->");
259                let _ = writeln!(out, "                        result.map {{ row, _ ->");
260                write_row_map(&mut out, "                            ");
261                let _ = writeln!(out);
262                let _ = writeln!(out, "                        }}");
263                let _ = writeln!(out, "                    }}");
264                let _ = writeln!(out, "            }},");
265                let _ = writeln!(out, "            {{ conn -> Mono.from(conn.close()) }},");
266                let _ = writeln!(out, "        ).asFlow()");
267            }
268            QueryCommand::Batch => {
269                let batch_fn_name = format!("{}Batch", func_name);
270                if params.len() > 1 {
271                    let params_class_name =
272                        format!("{}BatchParams", to_pascal_case(&analyzed.name));
273                    let _ = writeln!(out, "data class {}(", params_class_name);
274                    for p in params {
275                        let _ = writeln!(out, "    val {}: {},", p.field_name, p.full_type);
276                    }
277                    let _ = writeln!(out, ")");
278                    let _ = writeln!(out);
279                    let _ = writeln!(out, "suspend fun {}(", batch_fn_name);
280                    let _ = writeln!(out, "    cf: ConnectionFactory,");
281                    let _ = writeln!(out, "    items: List<{}>,", params_class_name);
282                    let _ = writeln!(out, ") {{");
283                    let _ = writeln!(out, "    val conn = Mono.from(cf.create()).awaitFirst()");
284                    let _ = writeln!(out, "    try {{");
285                    let _ = writeln!(
286                        out,
287                        "        Mono.from(conn.beginTransaction()).awaitFirstOrNull()"
288                    );
289                    let _ = writeln!(out, "        val stmt = conn.createStatement(\"{}\")", sql);
290                    let _ = writeln!(out, "        var first = true");
291                    let _ = writeln!(out, "        for (item in items) {{");
292                    let _ = writeln!(out, "            if (!first) stmt.add()");
293                    for (i, param) in params.iter().enumerate() {
294                        let _ = writeln!(
295                            out,
296                            "            stmt.bind({}, item.{})",
297                            i, param.field_name
298                        );
299                    }
300                    let _ = writeln!(out, "            first = false");
301                    let _ = writeln!(out, "        }}");
302                    let _ = writeln!(
303                        out,
304                        "        Flux.from(stmt.execute()).then().awaitFirstOrNull()"
305                    );
306                    let _ = writeln!(
307                        out,
308                        "        Mono.from(conn.commitTransaction()).awaitFirstOrNull()"
309                    );
310                    let _ = writeln!(out, "    }} catch (e: Exception) {{");
311                    let _ = writeln!(
312                        out,
313                        "        Mono.from(conn.rollbackTransaction()).awaitFirstOrNull()"
314                    );
315                    let _ = writeln!(out, "        throw e");
316                    let _ = writeln!(out, "    }} finally {{");
317                    let _ = writeln!(out, "        Mono.from(conn.close()).awaitFirstOrNull()");
318                    let _ = writeln!(out, "    }}");
319                    let _ = writeln!(out, "}}");
320                } else if params.len() == 1 {
321                    let _ = writeln!(out, "suspend fun {}(", batch_fn_name);
322                    let _ = writeln!(out, "    cf: ConnectionFactory,");
323                    let _ = writeln!(out, "    items: List<{}>,", params[0].full_type);
324                    let _ = writeln!(out, ") {{");
325                    let _ = writeln!(out, "    val conn = Mono.from(cf.create()).awaitFirst()");
326                    let _ = writeln!(out, "    try {{");
327                    let _ = writeln!(
328                        out,
329                        "        Mono.from(conn.beginTransaction()).awaitFirstOrNull()"
330                    );
331                    let _ = writeln!(out, "        val stmt = conn.createStatement(\"{}\")", sql);
332                    let _ = writeln!(out, "        var first = true");
333                    let _ = writeln!(out, "        for (item in items) {{");
334                    let _ = writeln!(out, "            if (!first) stmt.add()");
335                    let _ = writeln!(out, "            stmt.bind(0, item)");
336                    let _ = writeln!(out, "            first = false");
337                    let _ = writeln!(out, "        }}");
338                    let _ = writeln!(
339                        out,
340                        "        Flux.from(stmt.execute()).then().awaitFirstOrNull()"
341                    );
342                    let _ = writeln!(
343                        out,
344                        "        Mono.from(conn.commitTransaction()).awaitFirstOrNull()"
345                    );
346                    let _ = writeln!(out, "    }} catch (e: Exception) {{");
347                    let _ = writeln!(
348                        out,
349                        "        Mono.from(conn.rollbackTransaction()).awaitFirstOrNull()"
350                    );
351                    let _ = writeln!(out, "        throw e");
352                    let _ = writeln!(out, "    }} finally {{");
353                    let _ = writeln!(out, "        Mono.from(conn.close()).awaitFirstOrNull()");
354                    let _ = writeln!(out, "    }}");
355                    let _ = writeln!(out, "}}");
356                } else {
357                    let _ = writeln!(
358                        out,
359                        "suspend fun {}(cf: ConnectionFactory, count: Int) {{",
360                        batch_fn_name
361                    );
362                    let _ = writeln!(out, "    val conn = Mono.from(cf.create()).awaitFirst()");
363                    let _ = writeln!(out, "    try {{");
364                    let _ = writeln!(
365                        out,
366                        "        Mono.from(conn.beginTransaction()).awaitFirstOrNull()"
367                    );
368                    let _ = writeln!(out, "        val stmt = conn.createStatement(\"{}\")", sql);
369                    let _ = writeln!(out, "        repeat(count - 1) {{");
370                    let _ = writeln!(out, "            stmt.add()");
371                    let _ = writeln!(out, "        }}");
372                    let _ = writeln!(
373                        out,
374                        "        Flux.from(stmt.execute()).then().awaitFirstOrNull()"
375                    );
376                    let _ = writeln!(
377                        out,
378                        "        Mono.from(conn.commitTransaction()).awaitFirstOrNull()"
379                    );
380                    let _ = writeln!(out, "    }} catch (e: Exception) {{");
381                    let _ = writeln!(
382                        out,
383                        "        Mono.from(conn.rollbackTransaction()).awaitFirstOrNull()"
384                    );
385                    let _ = writeln!(out, "        throw e");
386                    let _ = writeln!(out, "    }} finally {{");
387                    let _ = writeln!(out, "        Mono.from(conn.close()).awaitFirstOrNull()");
388                    let _ = writeln!(out, "    }}");
389                    let _ = writeln!(out, "}}");
390                }
391            }
392            QueryCommand::Grouped => {
393                // Grouped queries are not yet supported for kotlin-r2dbc
394                return Err(ScytheError::new(
395                    ErrorCode::InternalError,
396                    "grouped queries are not yet supported for kotlin-r2dbc".to_string(),
397                ));
398            }
399        }
400
401        Ok(out)
402    }
403
404    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
405        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
406        let mut out = String::new();
407        let _ = writeln!(out, "enum class {}(val value: String) {{", type_name);
408        for (i, value) in enum_info.values.iter().enumerate() {
409            let variant = enum_variant_name(value, &self.manifest.naming);
410            let sep = if i + 1 < enum_info.values.len() {
411                ","
412            } else {
413                ";"
414            };
415            let _ = writeln!(out, "    {}(\"{}\"){}", variant, value, sep);
416        }
417        let _ = writeln!(out, "}}");
418        Ok(out)
419    }
420
421    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
422        let name = to_pascal_case(&composite.sql_name);
423        let mut out = String::new();
424        let _ = writeln!(out, "data class {}(", name);
425        for field in composite.fields.iter() {
426            let field_name = to_camel_case(&field.name);
427            let field_type = resolve_type(&field.neutral_type, &self.manifest, false)
428                .map(|t| t.into_owned())
429                .unwrap_or_else(|_| "Any".to_string());
430            let _ = writeln!(out, "    val {}: {},", field_name, field_type);
431        }
432        let _ = writeln!(out, ")");
433        Ok(out)
434    }
435}