1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15
16const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/kotlin-r2dbc.toml");
17const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/kotlin-r2dbc.mysql.toml");
18const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/kotlin-r2dbc.sqlite.toml");
19
20pub struct KotlinR2dbcBackend {
21 manifest: BackendManifest,
22}
23
24impl KotlinR2dbcBackend {
25 pub fn new(engine: &str) -> Result<Self, ScytheError> {
26 let default_toml = match engine {
27 "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
28 "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
29 "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
30 _ => {
31 return Err(ScytheError::new(
32 ErrorCode::InternalError,
33 format!("unsupported engine '{}' for kotlin-r2dbc backend", engine),
34 ));
35 }
36 };
37 let manifest_path = Path::new("backends/kotlin-r2dbc/manifest.toml");
38 let manifest = if manifest_path.exists() {
39 load_manifest(manifest_path)
40 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41 } else {
42 toml::from_str(default_toml)
43 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44 };
45 Ok(Self { manifest })
46 }
47}
48
/// Map a Kotlin type name to the `Class` token handed to R2DBC (`Row.get(name, type)` /
/// `Statement.bindNull(index, type)`).
///
/// A trailing `?` (nullable type) is ignored, and package qualifiers such as `java.time.` are
/// accepted. Every Kotlin primitive maps to its boxed `javaObjectType`: `Boolean::class.java`
/// and friends denote JVM primitive classes (`boolean`, `byte`, ...), not the wrapper types
/// R2DBC drivers decode into. Types without a dedicated mapping (including generic types such
/// as `List<LocalDate>`) fall back to `Any::class.java`.
fn r2dbc_row_class(kotlin_type: &str) -> &str {
    let base = kotlin_type.trim().trim_end_matches('?');
    // Exact matching on the simple name avoids the substring trap where
    // "LocalDateTime" would otherwise be caught by a "LocalDate" check.
    let simple = base.rsplit('.').next().unwrap_or(base);
    match simple {
        "Boolean" => "Boolean::class.javaObjectType",
        "Byte" => "Byte::class.javaObjectType",
        "Short" => "Short::class.javaObjectType",
        "Int" => "Int::class.javaObjectType",
        "Long" => "Long::class.javaObjectType",
        "Float" => "Float::class.javaObjectType",
        "Double" => "Double::class.javaObjectType",
        "String" => "String::class.java",
        "ByteArray" => "ByteArray::class.java",
        "BigDecimal" => "java.math.BigDecimal::class.java",
        "LocalDate" => "java.time.LocalDate::class.java",
        "LocalTime" => "java.time.LocalTime::class.java",
        "OffsetTime" => "java.time.OffsetTime::class.java",
        "LocalDateTime" => "java.time.LocalDateTime::class.java",
        "OffsetDateTime" => "java.time.OffsetDateTime::class.java",
        "UUID" => "java.util.UUID::class.java",
        _ => "Any::class.java",
    }
}
71
72impl CodegenBackend for KotlinR2dbcBackend {
73 fn name(&self) -> &str {
74 "kotlin-r2dbc"
75 }
76
77 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
78 &self.manifest
79 }
80
81 fn supported_engines(&self) -> &[&str] {
82 &["postgresql", "mysql", "sqlite"]
83 }
84
85 fn file_header(&self) -> String {
86 "import io.r2dbc.spi.ConnectionFactory\n\
87 import java.math.BigDecimal\n\
88 import java.time.*\n\
89 import java.util.UUID\n\
90 import kotlinx.coroutines.flow.Flow\n\
91 import kotlinx.coroutines.reactive.asFlow\n\
92 import kotlinx.coroutines.reactive.awaitFirst\n\
93 import kotlinx.coroutines.reactive.awaitFirstOrNull\n\
94 import reactor.core.publisher.Flux\n\
95 import reactor.core.publisher.Mono\n"
96 .to_string()
97 }
98
99 fn generate_row_struct(
100 &self,
101 query_name: &str,
102 columns: &[ResolvedColumn],
103 ) -> Result<String, ScytheError> {
104 let struct_name = row_struct_name(query_name, &self.manifest.naming);
105 let mut out = String::new();
106 let _ = writeln!(out, "data class {}(", struct_name);
107 for col in columns.iter() {
108 let _ = writeln!(out, " val {}: {},", col.field_name, col.full_type);
109 }
110 let _ = writeln!(out, ")");
111 Ok(out)
112 }
113
114 fn generate_model_struct(
115 &self,
116 table_name: &str,
117 columns: &[ResolvedColumn],
118 ) -> Result<String, ScytheError> {
119 let name = to_pascal_case(table_name);
120 self.generate_row_struct(&name, columns)
121 }
122
123 fn generate_query_fn(
124 &self,
125 analyzed: &AnalyzedQuery,
126 struct_name: &str,
127 columns: &[ResolvedColumn],
128 params: &[ResolvedParam],
129 ) -> Result<String, ScytheError> {
130 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
131 let sql = super::clean_sql_oneline_with_optional(
132 &analyzed.sql,
133 &analyzed.optional_params,
134 &analyzed.params,
135 );
136
137 let use_multiline_params = !params.is_empty();
138
139 let mut out = String::new();
140
141 let write_binds = |out: &mut String, indent: &str| {
143 for (i, param) in params.iter().enumerate() {
144 let _ = writeln!(out, "{}.bind({}, {})", indent, i, param.field_name);
145 }
146 };
147
148 let write_row_map = |out: &mut String, indent: &str| {
150 let _ = writeln!(out, "{}{}(", indent, struct_name);
151 for col in columns.iter() {
152 let class = r2dbc_row_class(&col.lang_type);
153 let _ = writeln!(
154 out,
155 "{} {} = row.get(\"{}\", {}),",
156 indent, col.field_name, col.name, class
157 );
158 }
159 let _ = write!(out, "{})", indent);
160 };
161
162 let write_suspend_fn_sig =
164 |out: &mut String, name: &str, ret: &str, multiline: bool, params: &[ResolvedParam]| {
165 if multiline {
166 let _ = writeln!(out, "suspend fun {}(", name);
167 let _ = writeln!(out, " cf: ConnectionFactory,");
168 for p in params {
169 let _ = writeln!(out, " {}: {},", p.field_name, p.full_type);
170 }
171 let _ = writeln!(out, "){} {{", ret);
172 } else {
173 let _ = writeln!(out, "suspend fun {}(cf: ConnectionFactory){} {{", name, ret);
174 }
175 };
176
177 match &analyzed.command {
178 QueryCommand::Exec => {
179 write_suspend_fn_sig(&mut out, &func_name, "", use_multiline_params, params);
180 let _ = writeln!(out, " val conn = Mono.from(cf.create()).awaitFirst()");
181 let _ = writeln!(out, " try {{");
182 let _ = writeln!(out, " val stmt = conn.createStatement(\"{}\")", sql);
183 write_binds(&mut out, " stmt");
184 let _ = writeln!(
185 out,
186 " Mono.from(stmt.execute()).flatMap {{ result -> Mono.from(result.rowsUpdated) }}.awaitFirstOrNull()"
187 );
188 let _ = writeln!(out, " }} finally {{");
189 let _ = writeln!(out, " Mono.from(conn.close()).awaitFirstOrNull()");
190 let _ = writeln!(out, " }}");
191 let _ = writeln!(out, "}}");
192 }
193 QueryCommand::ExecResult | QueryCommand::ExecRows => {
194 write_suspend_fn_sig(&mut out, &func_name, ": Long", use_multiline_params, params);
195 let _ = writeln!(out, " val conn = Mono.from(cf.create()).awaitFirst()");
196 let _ = writeln!(out, " try {{");
197 let _ = writeln!(out, " val stmt = conn.createStatement(\"{}\")", sql);
198 write_binds(&mut out, " stmt");
199 let _ = writeln!(out, " return Mono");
200 let _ = writeln!(out, " .from(stmt.execute())");
201 let _ = writeln!(
202 out,
203 " .flatMap {{ result -> Mono.from(result.rowsUpdated) }}"
204 );
205 let _ = writeln!(out, " .awaitFirst()");
206 let _ = writeln!(out, " }} finally {{");
207 let _ = writeln!(out, " Mono.from(conn.close()).awaitFirstOrNull()");
208 let _ = writeln!(out, " }}");
209 let _ = writeln!(out, "}}");
210 }
211 QueryCommand::One => {
212 let ret = format!(": {}?", struct_name);
213 write_suspend_fn_sig(&mut out, &func_name, &ret, use_multiline_params, params);
214 let _ = writeln!(out, " val conn = Mono.from(cf.create()).awaitFirst()");
215 let _ = writeln!(out, " try {{");
216 let _ = writeln!(out, " val stmt = conn.createStatement(\"{}\")", sql);
217 write_binds(&mut out, " stmt");
218 let _ = writeln!(out, " return Mono");
219 let _ = writeln!(out, " .from(stmt.execute())");
220 let _ = writeln!(out, " .flatMap {{ result ->");
221 let _ = writeln!(out, " Mono.from(");
222 let _ = writeln!(out, " result.map {{ row, _ ->");
223 write_row_map(&mut out, " ");
224 let _ = writeln!(out);
225 let _ = writeln!(out, " }},");
226 let _ = writeln!(out, " )");
227 let _ = writeln!(out, " }}.awaitFirstOrNull()");
228 let _ = writeln!(out, " }} finally {{");
229 let _ = writeln!(out, " Mono.from(conn.close()).awaitFirstOrNull()");
230 let _ = writeln!(out, " }}");
231 let _ = writeln!(out, "}}");
232 }
233 QueryCommand::Many => {
234 let ret = format!(": Flow<{}>", struct_name);
236 if use_multiline_params {
237 let _ = writeln!(out, "fun {}(", func_name);
238 let _ = writeln!(out, " cf: ConnectionFactory,");
239 for p in params {
240 let _ = writeln!(out, " {}: {},", p.field_name, p.full_type);
241 }
242 let _ = writeln!(out, "){} =", ret);
243 } else {
244 let _ = writeln!(out, "fun {}(cf: ConnectionFactory){} =", func_name, ret);
245 }
246 let _ = writeln!(out, " Flux");
247 let _ = writeln!(out, " .usingWhen(");
248 let _ = writeln!(out, " cf.create(),");
249 let _ = writeln!(out, " {{ conn ->");
250 let _ = writeln!(
251 out,
252 " val stmt = conn.createStatement(\"{}\")",
253 sql
254 );
255 write_binds(&mut out, " stmt");
256 let _ = writeln!(out, " Flux");
257 let _ = writeln!(out, " .from(stmt.execute())");
258 let _ = writeln!(out, " .flatMap {{ result ->");
259 let _ = writeln!(out, " result.map {{ row, _ ->");
260 write_row_map(&mut out, " ");
261 let _ = writeln!(out);
262 let _ = writeln!(out, " }}");
263 let _ = writeln!(out, " }}");
264 let _ = writeln!(out, " }},");
265 let _ = writeln!(out, " {{ conn -> Mono.from(conn.close()) }},");
266 let _ = writeln!(out, " ).asFlow()");
267 }
268 QueryCommand::Batch => {
269 let batch_fn_name = format!("{}Batch", func_name);
270 if params.len() > 1 {
271 let params_class_name =
272 format!("{}BatchParams", to_pascal_case(&analyzed.name));
273 let _ = writeln!(out, "data class {}(", params_class_name);
274 for p in params {
275 let _ = writeln!(out, " val {}: {},", p.field_name, p.full_type);
276 }
277 let _ = writeln!(out, ")");
278 let _ = writeln!(out);
279 let _ = writeln!(out, "suspend fun {}(", batch_fn_name);
280 let _ = writeln!(out, " cf: ConnectionFactory,");
281 let _ = writeln!(out, " items: List<{}>,", params_class_name);
282 let _ = writeln!(out, ") {{");
283 let _ = writeln!(out, " val conn = Mono.from(cf.create()).awaitFirst()");
284 let _ = writeln!(out, " try {{");
285 let _ = writeln!(
286 out,
287 " Mono.from(conn.beginTransaction()).awaitFirstOrNull()"
288 );
289 let _ = writeln!(out, " val stmt = conn.createStatement(\"{}\")", sql);
290 let _ = writeln!(out, " var first = true");
291 let _ = writeln!(out, " for (item in items) {{");
292 let _ = writeln!(out, " if (!first) stmt.add()");
293 for (i, param) in params.iter().enumerate() {
294 let _ = writeln!(
295 out,
296 " stmt.bind({}, item.{})",
297 i, param.field_name
298 );
299 }
300 let _ = writeln!(out, " first = false");
301 let _ = writeln!(out, " }}");
302 let _ = writeln!(
303 out,
304 " Flux.from(stmt.execute()).then().awaitFirstOrNull()"
305 );
306 let _ = writeln!(
307 out,
308 " Mono.from(conn.commitTransaction()).awaitFirstOrNull()"
309 );
310 let _ = writeln!(out, " }} catch (e: Exception) {{");
311 let _ = writeln!(
312 out,
313 " Mono.from(conn.rollbackTransaction()).awaitFirstOrNull()"
314 );
315 let _ = writeln!(out, " throw e");
316 let _ = writeln!(out, " }} finally {{");
317 let _ = writeln!(out, " Mono.from(conn.close()).awaitFirstOrNull()");
318 let _ = writeln!(out, " }}");
319 let _ = writeln!(out, "}}");
320 } else if params.len() == 1 {
321 let _ = writeln!(out, "suspend fun {}(", batch_fn_name);
322 let _ = writeln!(out, " cf: ConnectionFactory,");
323 let _ = writeln!(out, " items: List<{}>,", params[0].full_type);
324 let _ = writeln!(out, ") {{");
325 let _ = writeln!(out, " val conn = Mono.from(cf.create()).awaitFirst()");
326 let _ = writeln!(out, " try {{");
327 let _ = writeln!(
328 out,
329 " Mono.from(conn.beginTransaction()).awaitFirstOrNull()"
330 );
331 let _ = writeln!(out, " val stmt = conn.createStatement(\"{}\")", sql);
332 let _ = writeln!(out, " var first = true");
333 let _ = writeln!(out, " for (item in items) {{");
334 let _ = writeln!(out, " if (!first) stmt.add()");
335 let _ = writeln!(out, " stmt.bind(0, item)");
336 let _ = writeln!(out, " first = false");
337 let _ = writeln!(out, " }}");
338 let _ = writeln!(
339 out,
340 " Flux.from(stmt.execute()).then().awaitFirstOrNull()"
341 );
342 let _ = writeln!(
343 out,
344 " Mono.from(conn.commitTransaction()).awaitFirstOrNull()"
345 );
346 let _ = writeln!(out, " }} catch (e: Exception) {{");
347 let _ = writeln!(
348 out,
349 " Mono.from(conn.rollbackTransaction()).awaitFirstOrNull()"
350 );
351 let _ = writeln!(out, " throw e");
352 let _ = writeln!(out, " }} finally {{");
353 let _ = writeln!(out, " Mono.from(conn.close()).awaitFirstOrNull()");
354 let _ = writeln!(out, " }}");
355 let _ = writeln!(out, "}}");
356 } else {
357 let _ = writeln!(
358 out,
359 "suspend fun {}(cf: ConnectionFactory, count: Int) {{",
360 batch_fn_name
361 );
362 let _ = writeln!(out, " val conn = Mono.from(cf.create()).awaitFirst()");
363 let _ = writeln!(out, " try {{");
364 let _ = writeln!(
365 out,
366 " Mono.from(conn.beginTransaction()).awaitFirstOrNull()"
367 );
368 let _ = writeln!(out, " val stmt = conn.createStatement(\"{}\")", sql);
369 let _ = writeln!(out, " repeat(count - 1) {{");
370 let _ = writeln!(out, " stmt.add()");
371 let _ = writeln!(out, " }}");
372 let _ = writeln!(
373 out,
374 " Flux.from(stmt.execute()).then().awaitFirstOrNull()"
375 );
376 let _ = writeln!(
377 out,
378 " Mono.from(conn.commitTransaction()).awaitFirstOrNull()"
379 );
380 let _ = writeln!(out, " }} catch (e: Exception) {{");
381 let _ = writeln!(
382 out,
383 " Mono.from(conn.rollbackTransaction()).awaitFirstOrNull()"
384 );
385 let _ = writeln!(out, " throw e");
386 let _ = writeln!(out, " }} finally {{");
387 let _ = writeln!(out, " Mono.from(conn.close()).awaitFirstOrNull()");
388 let _ = writeln!(out, " }}");
389 let _ = writeln!(out, "}}");
390 }
391 }
392 QueryCommand::Grouped => {
393 return Err(ScytheError::new(
395 ErrorCode::InternalError,
396 "grouped queries are not yet supported for kotlin-r2dbc".to_string(),
397 ));
398 }
399 }
400
401 Ok(out)
402 }
403
404 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
405 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
406 let mut out = String::new();
407 let _ = writeln!(out, "enum class {}(val value: String) {{", type_name);
408 for (i, value) in enum_info.values.iter().enumerate() {
409 let variant = enum_variant_name(value, &self.manifest.naming);
410 let sep = if i + 1 < enum_info.values.len() {
411 ","
412 } else {
413 ";"
414 };
415 let _ = writeln!(out, " {}(\"{}\"){}", variant, value, sep);
416 }
417 let _ = writeln!(out, "}}");
418 Ok(out)
419 }
420
421 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
422 let name = to_pascal_case(&composite.sql_name);
423 let mut out = String::new();
424 let _ = writeln!(out, "data class {}(", name);
425 for field in composite.fields.iter() {
426 let field_name = to_camel_case(&field.name);
427 let field_type = resolve_type(&field.neutral_type, &self.manifest, false)
428 .map(|t| t.into_owned())
429 .unwrap_or_else(|_| "Any".to_string());
430 let _ = writeln!(out, " val {}: {},", field_name, field_type);
431 }
432 let _ = writeln!(out, ")");
433 Ok(out)
434 }
435}