1use std::fs;
9use std::path::{Path, PathBuf};
10
11use anyhow::{Context, Result, anyhow, bail};
12use scythe_core::analyzer::AnalyzedQuery;
13use scythe_core::catalog::Catalog;
14use scythe_core::dialect::SqlDialect;
15use scythe_core::parser::parse_query_with_dialect;
16use spikard_codegen::sql::{BuildOptions, DecimalMode, LanguageBackend, OpenApiInfo, build_handler_set};
17
18use super::TargetLanguage;
19use super::engine::GeneratedAsset;
20
/// Result of [`generate_from_sql_dir`]: the files written to the output directory,
/// in the order they were written.
#[derive(Debug)]
pub struct SqlCodegenOutput {
    /// One entry per written file, carrying its path and a short human-readable description.
    pub assets: Vec<GeneratedAsset>,
}
27
/// Inputs for [`generate_from_sql_dir`].
#[derive(Debug, Clone)]
pub struct SqlCodegenConfig {
    /// Schema DDL sources. Each entry is either a directory (its top-level `.sql` files are
    /// read) or a single file (read regardless of extension).
    pub schema_paths: Vec<PathBuf>,
    /// Directory containing annotated `.sql` query files, or a single query file.
    pub queries_dir: PathBuf,
    /// Directory the generated JSON artifacts are written to; created if it does not exist.
    pub output_dir: PathBuf,
    /// SQL dialect used both to build the schema catalog and to parse the queries.
    pub dialect: SqlDialect,
    /// Target languages; one codegen backend is built per entry.
    pub languages: Vec<TargetLanguage>,
    /// Forwarded unchanged to `BuildOptions::decimal_mode`.
    pub decimal_mode: DecimalMode,
    /// Forwarded unchanged to `BuildOptions::strict`.
    pub strict: bool,
    /// When `true`, `openapi.json` is also written to `output_dir`.
    pub emit_openapi: bool,
    /// API title passed to `OpenApiInfo::new`.
    pub api_title: String,
    /// API version string passed to `OpenApiInfo::new`.
    pub api_version: String,
}
41
42pub fn generate_from_sql_dir(config: SqlCodegenConfig) -> Result<SqlCodegenOutput> {
43 let catalog = load_catalog(&config.schema_paths, &config.dialect)?;
44 let queries = load_queries(&config.queries_dir, &config.dialect, &catalog)?;
45 if queries.is_empty() {
46 bail!(
47 "No queries found in {}. Add at least one .sql file with `-- @name`, `-- @returns`, and `-- @http` annotations.",
48 config.queries_dir.display()
49 );
50 }
51
52 let info = OpenApiInfo::new(config.api_title.clone(), config.api_version.clone());
53 let opts = BuildOptions {
54 decimal_mode: config.decimal_mode,
55 strict: config.strict,
56 };
57
58 let backends: Vec<LanguageBackend<'_>> = config.languages.iter().map(|lang| language_backend(*lang)).collect();
59
60 let set = build_handler_set(&catalog, &queries, &info, &opts, &backends)
61 .context("Failed to build handler set from SQL annotations")?;
62
63 fs::create_dir_all(&config.output_dir)
64 .with_context(|| format!("Failed to create output directory {}", config.output_dir.display()))?;
65
66 let mut assets = Vec::new();
67
68 let routes_path = config.output_dir.join("handlers.json");
69 let routes_json = serde_json::to_string_pretty(&set.routes).context("Failed to serialize routes")?;
70 fs::write(&routes_path, &routes_json).with_context(|| format!("Failed to write {}", routes_path.display()))?;
71 assets.push(GeneratedAsset {
72 path: routes_path,
73 description: "SQL-derived route metadata".to_string(),
74 });
75
76 let sidecar_path = config.output_dir.join("spikard-sql.json");
77 let sidecar_json = serde_json::to_string_pretty(&set.sidecar).context("Failed to serialize sidecar")?;
78 fs::write(&sidecar_path, &sidecar_json).with_context(|| format!("Failed to write {}", sidecar_path.display()))?;
79 assets.push(GeneratedAsset {
80 path: sidecar_path,
81 description: "Per-language SQL→handler sidecar".to_string(),
82 });
83
84 if config.emit_openapi {
85 let openapi_path = config.output_dir.join("openapi.json");
86 let openapi_json = serde_json::to_string_pretty(&set.openapi).context("Failed to serialize OpenAPI spec")?;
87 fs::write(&openapi_path, &openapi_json)
88 .with_context(|| format!("Failed to write {}", openapi_path.display()))?;
89 assets.push(GeneratedAsset {
90 path: openapi_path,
91 description: "OpenAPI 3.1 spec derived from SQL annotations".to_string(),
92 });
93 }
94
95 Ok(SqlCodegenOutput { assets })
96}
97
98fn load_catalog(schema_paths: &[PathBuf], dialect: &SqlDialect) -> Result<Catalog> {
99 let mut ddl_strings: Vec<String> = Vec::new();
100 for path in schema_paths {
101 if path.is_dir() {
102 for entry in fs::read_dir(path).with_context(|| format!("Failed to read schema dir {}", path.display()))? {
103 let entry = entry?;
104 if entry.file_type()?.is_file() && has_sql_extension(&entry.path()) {
105 ddl_strings.push(fs::read_to_string(entry.path())?);
106 }
107 }
108 } else {
109 ddl_strings.push(
110 fs::read_to_string(path).with_context(|| format!("Failed to read schema file {}", path.display()))?,
111 );
112 }
113 }
114 if ddl_strings.is_empty() {
115 bail!("No schema DDL found at the configured paths");
116 }
117 let refs: Vec<&str> = ddl_strings.iter().map(String::as_str).collect();
118 Catalog::from_ddl_with_dialect(&refs, dialect).map_err(|e| anyhow!("Failed to build catalog: {}", e))
119}
120
121fn load_queries(queries_dir: &Path, dialect: &SqlDialect, catalog: &Catalog) -> Result<Vec<AnalyzedQuery>> {
122 let mut entries: Vec<PathBuf> = if queries_dir.is_file() {
123 vec![queries_dir.to_path_buf()]
124 } else {
125 fs::read_dir(queries_dir)
126 .with_context(|| format!("Failed to read queries dir {}", queries_dir.display()))?
127 .filter_map(|e| e.ok())
128 .filter(|e| e.path().is_file() && has_sql_extension(&e.path()))
129 .map(|e| e.path())
130 .collect()
131 };
132 entries.sort();
133
134 let mut out = Vec::new();
135 for path in entries {
136 let body = fs::read_to_string(&path).with_context(|| format!("Failed to read {}", path.display()))?;
137 for chunk in split_queries(&body) {
138 if chunk.trim().is_empty() {
139 continue;
140 }
141 let query = parse_query_with_dialect(chunk, dialect)
142 .map_err(|e| anyhow!("Failed to parse query in {}: {}", path.display(), e))?;
143 let analyzed = scythe_core::analyzer::analyze(catalog, &query)
144 .map_err(|e| anyhow!("Failed to analyze query in {}: {}", path.display(), e))?;
145 out.push(analyzed);
146 }
147 }
148 Ok(out)
149}
150
/// Splits a `.sql` file body into one chunk per annotated query.
///
/// A chunk starts at an `@name` annotation line: a line whose first non-blank text is a
/// `--` comment followed by the `@name` tag. The tag rules are:
/// - matching is ASCII case-insensitive;
/// - any whitespace is allowed between `--` and `@name`;
/// - `@name` must be a whole word, so `@namespace` does not count.
///
/// A chunk runs up to, but not including, the next such line, with trailing line breaks
/// (`\n` or `\r\n`) removed. Text before the first `@name` line is dropped. A body with
/// no `@name` line therefore yields no chunks.
fn split_queries(body: &str) -> Vec<&str> {
    let is_line_break = |c: char| c == '\r' || c == '\n';
    let mut chunks = Vec::new();
    let mut start: Option<usize> = None;
    // Byte offset of the current line within `body`.
    let mut line_start = 0usize;
    for line in body.split_inclusive('\n') {
        if is_name_annotation(line) {
            if let Some(s) = start {
                chunks.push(body[s..line_start].trim_end_matches(is_line_break));
            }
            start = Some(line_start);
        }
        line_start += line.len();
    }
    if let Some(s) = start {
        chunks.push(body[s..].trim_end_matches(is_line_break));
    }
    chunks
}

/// Returns `true` if `line` is a `-- @name ...` annotation that starts a new query chunk.
fn is_name_annotation(line: &str) -> bool {
    let Some(comment) = line.trim_start().strip_prefix("--") else {
        return false;
    };
    let tagged = comment.trim_start();
    match tagged.get(..5) {
        // `get(..5)` succeeding guarantees byte 5 is a char boundary, so the slice is safe.
        Some(tag) if tag.eq_ignore_ascii_case("@name") => !tagged[5..]
            .chars()
            .next()
            .is_some_and(|c| c.is_ascii_alphanumeric() || c == '_'),
        _ => false,
    }
}
177
/// Reports whether `p` ends in a `.sql` extension, compared ASCII case-insensitively
/// (`.SQL` and `.Sql` also match).
///
/// Returns `false` in these cases:
/// - the path has no extension (including dotfiles such as `.sql`);
/// - the extension is not valid UTF-8.
fn has_sql_extension(p: &Path) -> bool {
    match p.extension().and_then(|ext| ext.to_str()) {
        Some(ext) => ext.eq_ignore_ascii_case("sql"),
        None => false,
    }
}
184
/// Returns the [`LanguageBackend`] descriptor handed to `build_handler_set` for `lang`.
///
/// Each descriptor bundles:
/// - `name`: the language key. The sidecar's `by_language` map is keyed by it, e.g.
///   `"python"`.
/// - `scythe_module`: the module path recorded as `scythe_module` in the sidecar.
/// - `is_async`: NOTE(review): presumably whether generated calls into the scythe query
///   functions are awaited. Confirm in `spikard_codegen`.
/// - `scythe_fn_for`: converts a query's `@name` into the function name recorded as
///   `scythe_fn` (e.g. `GetUser` → `get_user` for Python).
/// - `lang_type_for`: maps a neutral column type name plus nullability to this
///   language's type annotation.
fn language_backend(lang: TargetLanguage) -> LanguageBackend<'static> {
    match lang {
        // snake_case function names, `X | None` optionals.
        TargetLanguage::Python => LanguageBackend {
            name: "python",
            scythe_module: "queries",
            is_async: true,
            scythe_fn_for: &python_fn_name,
            lang_type_for: &python_lang_type,
        },
        // camelCase function names; the module is a relative import path.
        TargetLanguage::TypeScript => LanguageBackend {
            name: "typescript",
            scythe_module: "./queries",
            is_async: true,
            scythe_fn_for: &camel_fn_name,
            lang_type_for: &typescript_lang_type,
        },
        // snake_case function names; the module is a crate-relative path.
        TargetLanguage::Rust => LanguageBackend {
            name: "rust",
            scythe_module: "crate::queries",
            is_async: true,
            scythe_fn_for: &snake_fn_name,
            lang_type_for: &rust_lang_type,
        },
        TargetLanguage::Ruby => LanguageBackend {
            name: "ruby",
            scythe_module: "Queries",
            is_async: false,
            scythe_fn_for: &snake_fn_name,
            lang_type_for: &ruby_lang_type,
        },
        TargetLanguage::Php => LanguageBackend {
            name: "php",
            scythe_module: "Queries",
            is_async: false,
            scythe_fn_for: &camel_fn_name,
            lang_type_for: &php_lang_type,
        },
        TargetLanguage::Elixir => LanguageBackend {
            name: "elixir",
            scythe_module: "Queries",
            is_async: false,
            scythe_fn_for: &snake_fn_name,
            lang_type_for: &elixir_lang_type,
        },
    }
}
231
/// Converts a PascalCase/camelCase query name to snake_case.
///
/// An underscore is inserted before an ASCII uppercase letter only when the preceding
/// character is an ASCII lowercase letter or digit. Runs of capitals are therefore
/// lowercased without separators (`HTTPServer` → `httpserver`). All other characters
/// are copied through unchanged.
fn snake_fn_name(name: &str) -> String {
    let mut snake = String::with_capacity(name.len() + 4);
    let mut after_word_char = false;
    for ch in name.chars() {
        if ch.is_ascii_uppercase() && after_word_char {
            snake.push('_');
        }
        // `to_ascii_lowercase` only changes `A`..=`Z`; every other char passes through.
        snake.push(ch.to_ascii_lowercase());
        after_word_char = ch.is_ascii_lowercase() || ch.is_ascii_digit();
    }
    snake
}

/// Lowercases the first character of `name` if it is an ASCII letter (`GetUser` → `getUser`).
///
/// The rest of the name is left untouched. A non-ASCII first character is kept as-is.
fn camel_fn_name(name: &str) -> String {
    let mut camel = name.to_string();
    // `get_mut(..1)` is `None` for an empty string or a multi-byte first character; in both
    // cases there is nothing to lowercase.
    if let Some(head) = camel.get_mut(..1) {
        head.make_ascii_lowercase();
    }
    camel
}

/// Function-name convention for the Python backend: snake_case via [`snake_fn_name`].
fn python_fn_name(name: &str) -> String {
    snake_fn_name(name)
}
261
/// Maps a neutral scythe type name to a Python type annotation.
///
/// Mapping rules:
/// - `array<T>` becomes `list[...]`, with the element mapped recursively. Elements are
///   never marked optional.
/// - JSON and unknown types fall back to `Any`.
/// - When `nullable` is true, the result is suffixed with ` | None`.
///
/// Malformed array names such as `array<` (no closing `>`) are treated as unknown rather
/// than sliced, so they cannot panic.
fn python_lang_type(neutral: &str, nullable: bool) -> String {
    let optional = if nullable { " | None" } else { "" };
    if let Some(element) = neutral.strip_prefix("array<").and_then(|rest| rest.strip_suffix('>')) {
        return format!("list[{}]{}", python_lang_type(element, false), optional);
    }
    let base = match neutral {
        "int16" | "int32" | "int64" => "int",
        "float32" | "float64" => "float",
        "string" => "str",
        "bool" => "bool",
        "bytes" => "bytes",
        "uuid" => "UUID",
        "date" => "date",
        "datetime" | "datetime_tz" => "datetime",
        "time" | "time_tz" => "time",
        "decimal" => "Decimal",
        "json" => "Any",
        _ => "Any",
    };
    format!("{}{}", base, optional)
}
290
/// Maps a neutral scythe type name to a TypeScript type.
///
/// Mapping rules:
/// - `array<T>` becomes `T[]`, with the element mapped recursively. Elements are never
///   marked nullable.
/// - `int64` maps to `bigint`; the other numeric types map to `number`.
/// - Textual and temporal types, plus `decimal`, map to `string`.
/// - JSON and unknown types map to `unknown`.
/// - When `nullable` is true, the result is suffixed with ` | null`.
///
/// Malformed array names such as `array<` (no closing `>`) are treated as unknown rather
/// than sliced, so they cannot panic.
fn typescript_lang_type(neutral: &str, nullable: bool) -> String {
    let optional = if nullable { " | null" } else { "" };
    if let Some(element) = neutral.strip_prefix("array<").and_then(|rest| rest.strip_suffix('>')) {
        return format!("{}[]{}", typescript_lang_type(element, false), optional);
    }
    let base = match neutral {
        "int16" | "int32" | "float32" | "float64" => "number",
        "int64" => "bigint",
        "string" | "uuid" | "date" | "datetime" | "datetime_tz" | "time" | "time_tz" | "decimal" => "string",
        "bool" => "boolean",
        "bytes" => "Uint8Array",
        "json" => "unknown",
        _ => "unknown",
    };
    format!("{}{}", base, optional)
}
314
/// Maps a neutral scythe type name to a Rust type path.
///
/// Mapping rules:
/// - `array<T>` becomes `Vec<...>`, with the element mapped recursively. Elements are
///   never wrapped in `Option`.
/// - JSON and unknown types fall back to `serde_json::Value`.
/// - Nullable types are wrapped by [`wrap_nullable_rust`].
///
/// Malformed array names such as `array<` (no closing `>`) are treated as unknown rather
/// than sliced, so they cannot panic.
fn rust_lang_type(neutral: &str, nullable: bool) -> String {
    if let Some(element) = neutral.strip_prefix("array<").and_then(|rest| rest.strip_suffix('>')) {
        return wrap_nullable_rust(format!("Vec<{}>", rust_lang_type(element, false)), nullable);
    }
    let base = match neutral {
        "int16" => "i16",
        "int32" => "i32",
        "int64" => "i64",
        "float32" => "f32",
        "float64" => "f64",
        "string" => "String",
        "bool" => "bool",
        "bytes" => "Vec<u8>",
        "uuid" => "uuid::Uuid",
        "date" => "chrono::NaiveDate",
        "datetime" => "chrono::NaiveDateTime",
        "datetime_tz" => "chrono::DateTime<chrono::Utc>",
        "time" => "chrono::NaiveTime",
        // chrono has no offset-aware time-of-day type, so the offset is not represented.
        "time_tz" => "chrono::NaiveTime",
        "decimal" => "rust_decimal::Decimal",
        "json" => "serde_json::Value",
        _ => "serde_json::Value",
    };
    wrap_nullable_rust(base.to_string(), nullable)
}

/// Wraps `t` in `Option<...>` when `nullable` is true; otherwise returns it unchanged.
fn wrap_nullable_rust(t: String, nullable: bool) -> String {
    if nullable { format!("Option<{}>", t) } else { t }
}
344
/// Maps a neutral scythe type name to a Ruby type name.
///
/// Mapping rules:
/// - `array<T>` becomes `Array<...>`, with the element mapped recursively. Elements are
///   never marked nullable.
/// - JSON maps to `Hash`, and unknown types map to `Object`.
/// - Nullable types get a trailing `?`.
///
/// Malformed array names such as `array<` (no closing `>`) are treated as unknown rather
/// than sliced, so they cannot panic.
fn ruby_lang_type(neutral: &str, nullable: bool) -> String {
    let optional = if nullable { "?" } else { "" };
    if let Some(element) = neutral.strip_prefix("array<").and_then(|rest| rest.strip_suffix('>')) {
        return format!("Array<{}>{}", ruby_lang_type(element, false), optional);
    }
    let base = match neutral {
        "int16" | "int32" | "int64" => "Integer",
        "float32" | "float64" => "Float",
        "string" | "uuid" => "String",
        // NOTE(review): `Bool` is not a core Ruby class (YARD writes `Boolean`, RBS `bool`).
        // It is kept as-is; confirm which annotation dialect consumers expect.
        "bool" => "Bool",
        "bytes" => "String",
        "date" => "Date",
        "datetime" | "datetime_tz" => "DateTime",
        "time" | "time_tz" => "Time",
        "decimal" => "BigDecimal",
        "json" => "Hash",
        _ => "Object",
    };
    format!("{}{}", base, optional)
}
372
/// Maps a neutral scythe type name to a PHP type declaration.
///
/// Mapping rules:
/// - Any `array<...>` type collapses to `array`, since the element type is not expressed.
/// - Textual, temporal, decimal and binary types map to `string`.
/// - JSON and unknown types map to `mixed`.
/// - Nullable types get a leading `?`.
fn php_lang_type(neutral: &str, nullable: bool) -> String {
    let base = if neutral.starts_with("array<") {
        "array"
    } else {
        match neutral {
            "int16" | "int32" | "int64" => "int",
            "float32" | "float64" => "float",
            "bool" => "bool",
            "string" | "uuid" | "date" | "datetime" | "datetime_tz" | "time" | "time_tz" | "decimal" | "bytes" => {
                "string"
            }
            _ => "mixed",
        }
    };
    let prefix = if nullable { "?" } else { "" };
    format!("{prefix}{base}")
}
389
/// Maps a neutral scythe type name to an Elixir typespec.
///
/// Mapping rules:
/// - `array<T>` becomes `[...]`, with the element mapped recursively. Elements are never
///   marked nullable.
/// - Textual, temporal and decimal types map to `String.t()`.
/// - JSON maps to `map()`, and unknown types map to `any()`.
/// - Nullable types get a ` | nil` suffix.
///
/// Malformed array names such as `array<` (no closing `>`) are treated as unknown rather
/// than sliced, so they cannot panic.
fn elixir_lang_type(neutral: &str, nullable: bool) -> String {
    let optional = if nullable { " | nil" } else { "" };
    if let Some(element) = neutral.strip_prefix("array<").and_then(|rest| rest.strip_suffix('>')) {
        return format!("[{}]{}", elixir_lang_type(element, false), optional);
    }
    let base = match neutral {
        "int16" | "int32" | "int64" => "integer()",
        "float32" | "float64" => "float()",
        "string" | "uuid" | "date" | "datetime" | "datetime_tz" | "time" | "time_tz" | "decimal" => "String.t()",
        "bool" => "boolean()",
        "bytes" => "binary()",
        "json" => "map()",
        _ => "any()",
    };
    format!("{}{}", base, optional)
}
413
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Writes a fixture file, panicking on any I/O failure.
    fn write(path: &Path, body: &str) {
        std::fs::write(path, body).unwrap();
    }

    /// Reads a generated artifact and parses it as JSON.
    fn read_json(path: &Path) -> serde_json::Value {
        let text = std::fs::read_to_string(path).unwrap();
        serde_json::from_str(&text).unwrap()
    }

    #[test]
    fn split_queries_separates_at_at_name() {
        let chunks = split_queries(
            "-- @name First\n-- @returns :one\nSELECT 1;\n\n-- @name Second\n-- @returns :many\nSELECT 2;\n",
        );
        assert_eq!(chunks.len(), 2);
        assert!(chunks[0].contains("First"));
        assert!(chunks[1].contains("Second"));
    }

    #[test]
    fn split_queries_handles_single_query() {
        let chunks = split_queries("-- @name Only\n-- @returns :one\nSELECT 1;");
        assert_eq!(chunks.len(), 1);
    }

    #[test]
    fn end_to_end_smoke_writes_three_files() {
        // `tmp` must stay alive for the whole test; dropping it deletes the directory.
        let tmp = tempdir().unwrap();
        let root = tmp.path();

        let schema_path = root.join("schema.sql");
        write(&schema_path, "CREATE TABLE users (id BIGSERIAL PRIMARY KEY, email TEXT NOT NULL);");

        let queries_dir = root.join("queries");
        std::fs::create_dir_all(&queries_dir).unwrap();
        write(
            &queries_dir.join("users.sql"),
            "-- @name GetUser\n-- @returns :one\n-- @http GET /users/{id}\nSELECT id, email FROM users WHERE id = $1;",
        );

        let output_dir = root.join("out");
        let config = SqlCodegenConfig {
            schema_paths: vec![schema_path],
            queries_dir,
            output_dir: output_dir.clone(),
            dialect: SqlDialect::PostgreSQL,
            languages: vec![TargetLanguage::Python],
            decimal_mode: DecimalMode::StringPattern,
            strict: false,
            emit_openapi: true,
            api_title: "Demo".into(),
            api_version: "0.1.0".into(),
        };
        let output = generate_from_sql_dir(config).unwrap();

        assert_eq!(output.assets.len(), 3);
        for file_name in ["handlers.json", "openapi.json", "spikard-sql.json"] {
            assert!(output_dir.join(file_name).exists(), "{file_name} was not written");
        }

        let openapi = read_json(&output_dir.join("openapi.json"));
        assert_eq!(openapi["openapi"], "3.1.0");
        assert!(openapi["paths"]["/users/{id}"]["get"].is_object());

        let sidecar = read_json(&output_dir.join("spikard-sql.json"));
        let get_user = &sidecar["by_language"]["python"]["GetUser"];
        assert_eq!(get_user["scythe_fn"], "get_user");
        assert_eq!(get_user["scythe_module"], "queries");
    }

    #[test]
    fn snake_and_camel_helpers() {
        for (input, expected) in [("GetUser", "get_user"), ("ListActiveUsers", "list_active_users")] {
            assert_eq!(snake_fn_name(input), expected);
        }
        assert_eq!(camel_fn_name("GetUser"), "getUser");
    }

    #[test]
    fn python_lang_type_optional_wraps_with_none() {
        assert_eq!(python_lang_type("string", true), "str | None");
        assert_eq!(python_lang_type("int64", false), "int");
    }

    #[test]
    fn rust_lang_type_wraps_option() {
        assert_eq!(rust_lang_type("string", true), "Option<String>");
        assert_eq!(rust_lang_type("int32", false), "i32");
    }
}