Skip to main content

spikard_cli/codegen/
sql.rs

//! Glue between the CLI and `spikard_codegen::sql`.
//!
//! Reads schema DDL and annotated query files from disk, runs scythe's parser
//! and analyzer, builds the handler set via `spikard_codegen::sql`, and writes
//! `handlers.json` (route list) and `spikard-sql.json` (sidecar) to the output
//! directory, plus `openapi.json` (the spec) when OpenAPI emission is enabled.
7
8use std::fs;
9use std::path::{Path, PathBuf};
10
11use anyhow::{Context, Result, anyhow, bail};
12use scythe_core::analyzer::AnalyzedQuery;
13use scythe_core::catalog::Catalog;
14use scythe_core::dialect::SqlDialect;
15use scythe_core::parser::parse_query_with_dialect;
16use spikard_codegen::sql::{BuildOptions, DecimalMode, LanguageBackend, OpenApiInfo, build_handler_set};
17
18use super::TargetLanguage;
19use super::engine::GeneratedAsset;
20
/// Output of [`generate_from_sql_dir`]: one entry per file the generator wrote
/// to the output directory.
///
/// `handlers.json` and `spikard-sql.json` are always written; `openapi.json`
/// is written (and listed here) only when `SqlCodegenConfig::emit_openapi` is
/// `true`.
#[derive(Debug)]
pub struct SqlCodegenOutput {
    /// Written files in the order they were produced, each paired with a
    /// short human-readable description.
    pub assets: Vec<GeneratedAsset>,
}
27
/// Everything [`generate_from_sql_dir`] needs to turn SQL files into handler
/// metadata.
#[derive(Debug, Clone)]
pub struct SqlCodegenConfig {
    /// Schema DDL sources. A directory is scanned (non-recursively) for
    /// `.sql` files; any other path is read as a single DDL file.
    pub schema_paths: Vec<PathBuf>,
    /// Directory scanned (non-recursively) for annotated `.sql` query files.
    /// A path to a single file is also accepted.
    pub queries_dir: PathBuf,
    /// Directory the JSON artifacts are written to; created if missing.
    pub output_dir: PathBuf,
    /// SQL dialect handed to the catalog builder and the query parser.
    pub dialect: SqlDialect,
    /// Target languages; one language backend is built per entry.
    pub languages: Vec<TargetLanguage>,
    /// Forwarded unchanged to `BuildOptions::decimal_mode`.
    pub decimal_mode: DecimalMode,
    /// Forwarded unchanged to `BuildOptions::strict`.
    pub strict: bool,
    /// When `true`, `openapi.json` is written alongside the other artifacts.
    pub emit_openapi: bool,
    /// Title passed to `OpenApiInfo::new`.
    pub api_title: String,
    /// Version passed to `OpenApiInfo::new`.
    pub api_version: String,
}
41
42pub fn generate_from_sql_dir(config: SqlCodegenConfig) -> Result<SqlCodegenOutput> {
43    let catalog = load_catalog(&config.schema_paths, &config.dialect)?;
44    let queries = load_queries(&config.queries_dir, &config.dialect, &catalog)?;
45    if queries.is_empty() {
46        bail!(
47            "No queries found in {}. Add at least one .sql file with `-- @name`, `-- @returns`, and `-- @http` annotations.",
48            config.queries_dir.display()
49        );
50    }
51
52    let info = OpenApiInfo::new(config.api_title.clone(), config.api_version.clone());
53    let opts = BuildOptions {
54        decimal_mode: config.decimal_mode,
55        strict: config.strict,
56    };
57
58    let backends: Vec<LanguageBackend<'_>> = config.languages.iter().map(|lang| language_backend(*lang)).collect();
59
60    let set = build_handler_set(&catalog, &queries, &info, &opts, &backends)
61        .context("Failed to build handler set from SQL annotations")?;
62
63    fs::create_dir_all(&config.output_dir)
64        .with_context(|| format!("Failed to create output directory {}", config.output_dir.display()))?;
65
66    let mut assets = Vec::new();
67
68    let routes_path = config.output_dir.join("handlers.json");
69    let routes_json = serde_json::to_string_pretty(&set.routes).context("Failed to serialize routes")?;
70    fs::write(&routes_path, &routes_json).with_context(|| format!("Failed to write {}", routes_path.display()))?;
71    assets.push(GeneratedAsset {
72        path: routes_path,
73        description: "SQL-derived route metadata".to_string(),
74    });
75
76    let sidecar_path = config.output_dir.join("spikard-sql.json");
77    let sidecar_json = serde_json::to_string_pretty(&set.sidecar).context("Failed to serialize sidecar")?;
78    fs::write(&sidecar_path, &sidecar_json).with_context(|| format!("Failed to write {}", sidecar_path.display()))?;
79    assets.push(GeneratedAsset {
80        path: sidecar_path,
81        description: "Per-language SQL→handler sidecar".to_string(),
82    });
83
84    if config.emit_openapi {
85        let openapi_path = config.output_dir.join("openapi.json");
86        let openapi_json = serde_json::to_string_pretty(&set.openapi).context("Failed to serialize OpenAPI spec")?;
87        fs::write(&openapi_path, &openapi_json)
88            .with_context(|| format!("Failed to write {}", openapi_path.display()))?;
89        assets.push(GeneratedAsset {
90            path: openapi_path,
91            description: "OpenAPI 3.1 spec derived from SQL annotations".to_string(),
92        });
93    }
94
95    Ok(SqlCodegenOutput { assets })
96}
97
98fn load_catalog(schema_paths: &[PathBuf], dialect: &SqlDialect) -> Result<Catalog> {
99    let mut ddl_strings: Vec<String> = Vec::new();
100    for path in schema_paths {
101        if path.is_dir() {
102            for entry in fs::read_dir(path).with_context(|| format!("Failed to read schema dir {}", path.display()))? {
103                let entry = entry?;
104                if entry.file_type()?.is_file() && has_sql_extension(&entry.path()) {
105                    ddl_strings.push(fs::read_to_string(entry.path())?);
106                }
107            }
108        } else {
109            ddl_strings.push(
110                fs::read_to_string(path).with_context(|| format!("Failed to read schema file {}", path.display()))?,
111            );
112        }
113    }
114    if ddl_strings.is_empty() {
115        bail!("No schema DDL found at the configured paths");
116    }
117    let refs: Vec<&str> = ddl_strings.iter().map(String::as_str).collect();
118    Catalog::from_ddl_with_dialect(&refs, dialect).map_err(|e| anyhow!("Failed to build catalog: {}", e))
119}
120
121fn load_queries(queries_dir: &Path, dialect: &SqlDialect, catalog: &Catalog) -> Result<Vec<AnalyzedQuery>> {
122    let mut entries: Vec<PathBuf> = if queries_dir.is_file() {
123        vec![queries_dir.to_path_buf()]
124    } else {
125        fs::read_dir(queries_dir)
126            .with_context(|| format!("Failed to read queries dir {}", queries_dir.display()))?
127            .filter_map(|e| e.ok())
128            .filter(|e| e.path().is_file() && has_sql_extension(&e.path()))
129            .map(|e| e.path())
130            .collect()
131    };
132    entries.sort();
133
134    let mut out = Vec::new();
135    for path in entries {
136        let body = fs::read_to_string(&path).with_context(|| format!("Failed to read {}", path.display()))?;
137        for chunk in split_queries(&body) {
138            if chunk.trim().is_empty() {
139                continue;
140            }
141            let query = parse_query_with_dialect(chunk, dialect)
142                .map_err(|e| anyhow!("Failed to parse query in {}: {}", path.display(), e))?;
143            let analyzed = scythe_core::analyzer::analyze(catalog, &query)
144                .map_err(|e| anyhow!("Failed to analyze query in {}: {}", path.display(), e))?;
145            out.push(analyzed);
146        }
147    }
148    Ok(out)
149}
150
/// Cut the text of a `.sql` file into one slice per query.
///
/// A query starts on any line whose first non-whitespace characters are
/// `-- @name` or `--@name` (compared ASCII case-insensitively) and extends up
/// to, but not including, the next such line or the end of the input. Text in
/// front of the first marker is discarded, and trailing `\n` characters are
/// trimmed from every slice. Input without a marker yields an empty vector.
fn split_queries(body: &str) -> Vec<&str> {
    const MARKERS: [&[u8]; 2] = [b"-- @name", b"--@name"];

    // True when `line`, ignoring leading whitespace, begins with a marker.
    let opens_query = |line: &str| {
        let head = line.trim_start().as_bytes();
        MARKERS.iter().any(|&marker| {
            head.get(..marker.len())
                .is_some_and(|prefix| prefix.eq_ignore_ascii_case(marker))
        })
    };

    // Byte offset of the start of every marker line, in file order.
    let mut boundaries: Vec<usize> = Vec::new();
    let mut offset = 0usize;
    for raw_line in body.split_inclusive('\n') {
        let line = raw_line.strip_suffix('\n').unwrap_or(raw_line);
        if opens_query(line) {
            boundaries.push(offset);
        }
        offset += raw_line.len();
    }

    boundaries
        .iter()
        .enumerate()
        .map(|(position, &begin)| {
            let end = boundaries.get(position + 1).copied().unwrap_or(body.len());
            body[begin..end].trim_end_matches('\n')
        })
        .collect()
}
177
/// Returns `true` when the path's extension is `sql`, compared ASCII
/// case-insensitively. Paths without an extension, or whose extension is not
/// valid UTF-8, yield `false`.
fn has_sql_extension(p: &Path) -> bool {
    match p.extension().and_then(|ext| ext.to_str()) {
        Some(ext) => ext.eq_ignore_ascii_case("sql"),
        None => false,
    }
}
184
/// Builds the [`LanguageBackend`] descriptor for one target language.
///
/// Each backend carries the language's lowercase name, the `scythe_module`
/// string recorded for that language, whether its handlers are flagged as
/// async, and two callbacks: one turning a query name into a function name
/// and one turning a neutral type name into a language type string.
///
/// The `match` has no wildcard arm, so adding a `TargetLanguage` variant
/// fails to compile until a backend is defined for it here.
fn language_backend(lang: TargetLanguage) -> LanguageBackend<'static> {
    match lang {
        // snake_case function names, async handlers.
        TargetLanguage::Python => LanguageBackend {
            name: "python",
            scythe_module: "queries",
            is_async: true,
            scythe_fn_for: &python_fn_name,
            lang_type_for: &python_lang_type,
        },
        // First-letter-lowercased function names, async handlers.
        TargetLanguage::TypeScript => LanguageBackend {
            name: "typescript",
            scythe_module: "./queries",
            is_async: true,
            scythe_fn_for: &camel_fn_name,
            lang_type_for: &typescript_lang_type,
        },
        // snake_case function names, async handlers.
        TargetLanguage::Rust => LanguageBackend {
            name: "rust",
            scythe_module: "crate::queries",
            is_async: true,
            scythe_fn_for: &snake_fn_name,
            lang_type_for: &rust_lang_type,
        },
        // snake_case function names, non-async handlers.
        TargetLanguage::Ruby => LanguageBackend {
            name: "ruby",
            scythe_module: "Queries",
            is_async: false,
            scythe_fn_for: &snake_fn_name,
            lang_type_for: &ruby_lang_type,
        },
        // First-letter-lowercased function names, non-async handlers.
        TargetLanguage::Php => LanguageBackend {
            name: "php",
            scythe_module: "Queries",
            is_async: false,
            scythe_fn_for: &camel_fn_name,
            lang_type_for: &php_lang_type,
        },
        // snake_case function names, non-async handlers.
        TargetLanguage::Elixir => LanguageBackend {
            name: "elixir",
            scythe_module: "Queries",
            is_async: false,
            scythe_fn_for: &snake_fn_name,
            lang_type_for: &elixir_lang_type,
        },
    }
}
231
/// Converts a query name such as `GetUser` into `get_user`.
///
/// Every ASCII uppercase letter is lowercased, and an underscore is inserted
/// in front of it only when the preceding character is an ASCII lowercase
/// letter or digit. Runs of uppercase letters are therefore not split
/// (`HTTPServer` becomes `httpserver`). All other characters are copied
/// through unchanged.
fn snake_fn_name(name: &str) -> String {
    let mut snake = String::with_capacity(name.len() + 4);
    let mut previous: Option<char> = None;
    for ch in name.chars() {
        if ch.is_ascii_uppercase() {
            let after_word_char = previous.is_some_and(|p| p.is_ascii_lowercase() || p.is_ascii_digit());
            if after_word_char {
                snake.push('_');
            }
            snake.push(ch.to_ascii_lowercase());
        } else {
            snake.push(ch);
        }
        previous = Some(ch);
    }
    snake
}
249
/// Lowercases the first character of `name` (ASCII only) and leaves the rest
/// untouched, e.g. `GetUser` becomes `getUser`. An empty name stays empty.
fn camel_fn_name(name: &str) -> String {
    let Some(first) = name.chars().next() else {
        return String::new();
    };
    let mut camel = String::with_capacity(name.len());
    camel.push(first.to_ascii_lowercase());
    camel.push_str(&name[first.len_utf8()..]);
    camel
}
257
/// Function-name callback for the Python backend; delegates to
/// `snake_fn_name`, so `GetUser` becomes `get_user`.
fn python_fn_name(name: &str) -> String {
    snake_fn_name(name)
}
261
/// Maps a neutral type name to a Python type-annotation string.
///
/// `array<T>` becomes `list[...]` with the element type mapped recursively
/// (elements are never marked nullable). Unrecognized names map to `Any`.
/// When `nullable` is set, ` | None` is appended.
///
/// The array element is extracted with `strip_prefix` / `strip_suffix`
/// instead of byte slicing, so malformed input such as `"array<"` or an
/// `array<` name without a closing `>` falls through to `Any` instead of
/// panicking or silently dropping the last character.
fn python_lang_type(neutral: &str, nullable: bool) -> String {
    let element = neutral
        .strip_prefix("array<")
        .and_then(|rest| rest.strip_suffix('>'));
    let base = match element {
        Some(inner) => format!("list[{}]", python_lang_type(inner, false)),
        None => match neutral {
            "int16" | "int32" | "int64" => "int",
            "float32" | "float64" => "float",
            "string" => "str",
            "bool" => "bool",
            "bytes" => "bytes",
            "uuid" => "UUID",
            "date" => "date",
            "datetime" | "datetime_tz" => "datetime",
            "time" | "time_tz" => "time",
            "decimal" => "Decimal",
            "json" => "Any",
            _ => "Any",
        }
        .to_string(),
    };
    if nullable {
        format!("{} | None", base)
    } else {
        base
    }
}
290
/// Maps a neutral type name to a TypeScript type string.
///
/// `array<T>` becomes `T[]` with the element type mapped recursively
/// (elements are never marked nullable). `int64` maps to `bigint`, the other
/// numeric types to `number`, and unrecognized names to `unknown`. When
/// `nullable` is set, ` | null` is appended.
///
/// The array element is extracted with `strip_prefix` / `strip_suffix`
/// instead of byte slicing, so malformed input such as `"array<"` or an
/// `array<` name without a closing `>` falls through to `unknown` instead of
/// panicking or silently dropping the last character.
fn typescript_lang_type(neutral: &str, nullable: bool) -> String {
    let element = neutral
        .strip_prefix("array<")
        .and_then(|rest| rest.strip_suffix('>'));
    let base = match element {
        Some(inner) => format!("{}[]", typescript_lang_type(inner, false)),
        None => match neutral {
            "int16" | "int32" | "float32" | "float64" => "number",
            "int64" => "bigint",
            "string" | "uuid" | "date" | "datetime" | "datetime_tz" | "time" | "time_tz" | "decimal" => "string",
            "bool" => "boolean",
            "bytes" => "Uint8Array",
            "json" => "unknown",
            _ => "unknown",
        }
        .to_string(),
    };
    if nullable {
        format!("{} | null", base)
    } else {
        base
    }
}
314
/// Maps a neutral type name to a Rust type path string.
///
/// `array<T>` becomes `Vec<...>` with the element type mapped recursively
/// (elements are never wrapped in `Option`). Unrecognized names map to
/// `serde_json::Value`. When `nullable` is set, the result is wrapped in
/// `Option<...>`.
///
/// The array element is extracted with `strip_prefix` / `strip_suffix`
/// instead of byte slicing, so malformed input such as `"array<"` or an
/// `array<` name without a closing `>` falls through to `serde_json::Value`
/// instead of panicking or silently dropping the last character.
fn rust_lang_type(neutral: &str, nullable: bool) -> String {
    let element = neutral
        .strip_prefix("array<")
        .and_then(|rest| rest.strip_suffix('>'));
    let base = match element {
        Some(inner) => format!("Vec<{}>", rust_lang_type(inner, false)),
        None => match neutral {
            "int16" => "i16",
            "int32" => "i32",
            "int64" => "i64",
            "float32" => "f32",
            "float64" => "f64",
            "string" => "String",
            "bool" => "bool",
            "bytes" => "Vec<u8>",
            "uuid" => "uuid::Uuid",
            "date" => "chrono::NaiveDate",
            "datetime" => "chrono::NaiveDateTime",
            "datetime_tz" => "chrono::DateTime<chrono::Utc>",
            "time" => "chrono::NaiveTime",
            "time_tz" => "chrono::NaiveTime",
            "decimal" => "rust_decimal::Decimal",
            "json" => "serde_json::Value",
            _ => "serde_json::Value",
        }
        .to_string(),
    };
    wrap_nullable_rust(base, nullable)
}

/// Wraps `t` in `Option<...>` when `nullable` is set; otherwise returns it
/// unchanged.
fn wrap_nullable_rust(t: String, nullable: bool) -> String {
    if nullable { format!("Option<{}>", t) } else { t }
}
344
/// Maps a neutral type name to a Ruby type string.
///
/// `array<T>` becomes `Array<...>` with the element type mapped recursively
/// (elements are never marked nullable). Unrecognized names map to `Object`.
/// When `nullable` is set, `?` is appended.
///
/// The array element is extracted with `strip_prefix` / `strip_suffix`
/// instead of byte slicing, so malformed input such as `"array<"` or an
/// `array<` name without a closing `>` falls through to `Object` instead of
/// panicking or silently dropping the last character.
fn ruby_lang_type(neutral: &str, nullable: bool) -> String {
    let element = neutral
        .strip_prefix("array<")
        .and_then(|rest| rest.strip_suffix('>'));
    let base = match element {
        Some(inner) => format!("Array<{}>", ruby_lang_type(inner, false)),
        None => match neutral {
            "int16" | "int32" | "int64" => "Integer",
            "float32" | "float64" => "Float",
            "string" | "uuid" => "String",
            // NOTE(review): Ruby has no core `Bool` class; presumably this is
            // a label for whatever consumes these strings — confirm.
            "bool" => "Bool",
            "bytes" => "String",
            "date" => "Date",
            "datetime" | "datetime_tz" => "DateTime",
            "time" | "time_tz" => "Time",
            "decimal" => "BigDecimal",
            "json" => "Hash",
            _ => "Object",
        }
        .to_string(),
    };
    if nullable {
        format!("{}?", base)
    } else {
        base
    }
}
372
/// Maps a neutral type name to a PHP type declaration string.
///
/// Any name starting with `array<` maps to plain `array` (the element type is
/// not expressed). Unrecognized names and `json` map to `mixed`. When
/// `nullable` is set, the type is prefixed with `?`, except for `mixed`:
/// PHP's `mixed` already includes `null`, and `?mixed` is rejected by PHP as
/// a fatal error, so `mixed` is returned unprefixed.
fn php_lang_type(neutral: &str, nullable: bool) -> String {
    let base = match neutral {
        n if n.starts_with("array<") => "array",
        "int16" | "int32" | "int64" => "int",
        "float32" | "float64" => "float",
        "string" | "uuid" | "date" | "datetime" | "datetime_tz" | "time" | "time_tz" | "decimal" | "bytes" => "string",
        "bool" => "bool",
        "json" => "mixed",
        _ => "mixed",
    };
    if nullable && base != "mixed" {
        format!("?{}", base)
    } else {
        base.to_string()
    }
}
389
/// Maps a neutral type name to an Elixir typespec string.
///
/// `array<T>` becomes `[...]` with the element type mapped recursively
/// (elements are never marked nullable). Unrecognized names map to `any()`.
/// When `nullable` is set, ` | nil` is appended.
///
/// The array element is extracted with `strip_prefix` / `strip_suffix`
/// instead of byte slicing, so malformed input such as `"array<"` or an
/// `array<` name without a closing `>` falls through to `any()` instead of
/// panicking or silently dropping the last character.
fn elixir_lang_type(neutral: &str, nullable: bool) -> String {
    let element = neutral
        .strip_prefix("array<")
        .and_then(|rest| rest.strip_suffix('>'));
    let base = match element {
        Some(inner) => format!("[{}]", elixir_lang_type(inner, false)),
        None => match neutral {
            "int16" | "int32" | "int64" => "integer()",
            "float32" | "float64" => "float()",
            "string" | "uuid" | "date" | "datetime" | "datetime_tz" | "time" | "time_tz" | "decimal" => "String.t()",
            "bool" => "boolean()",
            "bytes" => "binary()",
            "json" => "map()",
            _ => "any()",
        }
        .to_string(),
    };
    if nullable {
        format!("{} | nil", base)
    } else {
        base
    }
}
413
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Writes `contents` to `target`, panicking on any I/O error.
    fn write(target: &Path, contents: &str) {
        std::fs::write(target, contents).unwrap();
    }

    /// Reads `file` and parses it as JSON, panicking on any failure.
    fn read_json(file: &Path) -> serde_json::Value {
        let raw = std::fs::read_to_string(file).unwrap();
        serde_json::from_str(&raw).unwrap()
    }

    #[test]
    fn split_queries_separates_at_at_name() {
        let chunks = split_queries(
            "-- @name First\n-- @returns :one\nSELECT 1;\n\n-- @name Second\n-- @returns :many\nSELECT 2;\n",
        );
        assert_eq!(chunks.len(), 2);
        assert!(chunks[0].contains("First"));
        assert!(chunks[1].contains("Second"));
    }

    #[test]
    fn split_queries_handles_single_query() {
        let chunks = split_queries("-- @name Only\n-- @returns :one\nSELECT 1;");
        assert_eq!(chunks.len(), 1);
    }

    #[test]
    fn end_to_end_smoke_writes_three_files() {
        // Lay out a throwaway project: one schema file, one query file.
        let workspace = tempdir().unwrap();
        let root = workspace.path();

        let schema_path = root.join("schema.sql");
        write(
            &schema_path,
            "CREATE TABLE users (id BIGSERIAL PRIMARY KEY, email TEXT NOT NULL);",
        );

        let queries_dir = root.join("queries");
        std::fs::create_dir_all(&queries_dir).unwrap();
        write(
            &queries_dir.join("users.sql"),
            "-- @name GetUser\n-- @returns :one\n-- @http GET /users/{id}\nSELECT id, email FROM users WHERE id = $1;",
        );

        let output_dir = root.join("out");
        let config = SqlCodegenConfig {
            schema_paths: vec![schema_path],
            queries_dir,
            output_dir: output_dir.clone(),
            dialect: SqlDialect::PostgreSQL,
            languages: vec![TargetLanguage::Python],
            decimal_mode: DecimalMode::StringPattern,
            strict: false,
            emit_openapi: true,
            api_title: "Demo".into(),
            api_version: "0.1.0".into(),
        };
        let output = generate_from_sql_dir(config).unwrap();

        // All three artifacts are reported and exist on disk.
        assert_eq!(output.assets.len(), 3);
        for file_name in ["handlers.json", "openapi.json", "spikard-sql.json"] {
            assert!(output_dir.join(file_name).exists());
        }

        // The spec exposes the annotated route.
        let openapi = read_json(&output_dir.join("openapi.json"));
        assert_eq!(openapi["openapi"], "3.1.0");
        assert!(openapi["paths"]["/users/{id}"]["get"].is_object());

        // The sidecar maps the query to the Python function and module.
        let sidecar = read_json(&output_dir.join("spikard-sql.json"));
        let entry = &sidecar["by_language"]["python"]["GetUser"];
        assert_eq!(entry["scythe_fn"], "get_user");
        assert_eq!(entry["scythe_module"], "queries");
    }

    #[test]
    fn snake_and_camel_helpers() {
        let snake_cases = [("GetUser", "get_user"), ("ListActiveUsers", "list_active_users")];
        for (input, expected) in snake_cases {
            assert_eq!(snake_fn_name(input), expected);
        }
        assert_eq!(camel_fn_name("GetUser"), "getUser");
    }

    #[test]
    fn python_lang_type_optional_wraps_with_none() {
        assert_eq!(python_lang_type("string", true), "str | None");
        assert_eq!(python_lang_type("int64", false), "int");
    }

    #[test]
    fn rust_lang_type_wraps_option() {
        assert_eq!(rust_lang_type("string", true), "Option<String>");
        assert_eq!(rust_lang_type("int32", false), "i32");
    }
}