sql_check_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3
4#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
5mod database;
6#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
7mod runtime;
8
9/// Internal procedural macro for SQL validation.
10/// Do not use directly - use the `check!` macro from the `sql-check` crate instead.
11#[proc_macro]
12pub fn check_impl(input: TokenStream) -> TokenStream {
13    let input_sql = syn::parse_macro_input!(input as syn::LitStr);
14    let sql_value = input_sql.value();
15
16    // Get DATABASE_URL
17    let database_url = match env("DATABASE_URL") {
18        Ok(url) => url,
19        Err(e) => {
20            // If DATABASE_URL is not set, emit a compile error
21            return quote! {
22                compile_error!(#e)
23            }
24            .into();
25        }
26    };
27
28    // Perform SQL validation
29    match validate_sql(&sql_value, &database_url) {
30        Ok(()) => {
31            // Validation succeeded, return the SQL string
32            quote! { #input_sql }.into()
33        }
34        Err(e) => {
35            // Validation failed, emit a compile error
36            let error_msg = format!("SQL validation failed: {}", e);
37            quote! {
38                compile_error!(#error_msg)
39            }
40            .into()
41        }
42    }
43}
44
45fn validate_sql(sql: &str, database_url: &str) -> Result<(), String> {
46    // Parse URL to determine database type
47    // Handle both uppercase and lowercase schemes
48    let url_lower = database_url.to_lowercase();
49
50    if url_lower.starts_with("postgres://") || url_lower.starts_with("postgresql://") {
51        validate_postgres(sql, database_url)
52    } else if url_lower.starts_with("mysql://") {
53        validate_mysql(sql, database_url)
54    } else if url_lower.starts_with("sqlite://") || url_lower.starts_with("sqlite:") {
55        validate_sqlite(sql, database_url)
56    } else {
57        Err(format!("Unsupported database URL scheme. Expected postgres://, mysql://, or sqlite://. Got: {}", 
58            database_url.split("://").next().unwrap_or("unknown")))
59    }
60}
61
/// Validates `sql` by calling `DatabaseExt::describe_blocking` for PostgreSQL
/// against `database_url` with the default driver configuration. The returned
/// description is discarded; only success or failure matters here.
///
/// # Errors
///
/// Any error from `describe_blocking`, rendered through its `Display` impl.
#[cfg(feature = "postgres")]
fn validate_postgres(sql: &str, database_url: &str) -> Result<(), String> {
    let defaults = sqlx_core::config::drivers::Config::default();
    <sqlx_postgres::Postgres as database::DatabaseExt>::describe_blocking(
        sql,
        database_url,
        &defaults,
    )
    .map_err(|err| err.to_string())?;
    Ok(())
}
76
/// Stand-in used when the `postgres` feature is disabled: always fails,
/// telling the user which feature to enable.
#[cfg(not(feature = "postgres"))]
fn validate_postgres(_sql: &str, _database_url: &str) -> Result<(), String> {
    let message = "PostgreSQL support not enabled. Enable the 'postgres' feature.";
    Err(String::from(message))
}
81
/// Validates `sql` by calling `DatabaseExt::describe_blocking` for MySQL
/// against `database_url` with the default driver configuration. The returned
/// description is discarded; only success or failure matters here.
///
/// # Errors
///
/// Any error from `describe_blocking`, rendered through its `Display` impl.
#[cfg(feature = "mysql")]
fn validate_mysql(sql: &str, database_url: &str) -> Result<(), String> {
    let defaults = sqlx_core::config::drivers::Config::default();
    <sqlx_mysql::MySql as database::DatabaseExt>::describe_blocking(
        sql,
        database_url,
        &defaults,
    )
    .map_err(|err| err.to_string())?;
    Ok(())
}
96
/// Stand-in used when the `mysql` feature is disabled: always fails,
/// telling the user which feature to enable.
#[cfg(not(feature = "mysql"))]
fn validate_mysql(_sql: &str, _database_url: &str) -> Result<(), String> {
    let message = "MySQL support not enabled. Enable the 'mysql' feature.";
    Err(String::from(message))
}
101
/// Validates `sql` by calling `DatabaseExt::describe_blocking` for SQLite
/// against `database_url` with the default driver configuration. The returned
/// description is discarded; only success or failure matters here.
///
/// # Errors
///
/// Any error from `describe_blocking`, rendered through its `Display` impl.
#[cfg(feature = "sqlite")]
fn validate_sqlite(sql: &str, database_url: &str) -> Result<(), String> {
    let defaults = sqlx_core::config::drivers::Config::default();
    <sqlx_sqlite::Sqlite as database::DatabaseExt>::describe_blocking(
        sql,
        database_url,
        &defaults,
    )
    .map_err(|err| err.to_string())?;
    Ok(())
}
116
/// Stand-in used when the `sqlite` feature is disabled: always fails,
/// telling the user which feature to enable.
#[cfg(not(feature = "sqlite"))]
fn validate_sqlite(_sql: &str, _database_url: &str) -> Result<(), String> {
    let message = "SQLite support not enabled. Enable the 'sqlite' feature.";
    Err(String::from(message))
}
121
/// Reads the environment variable `var`.
///
/// # Errors
///
/// Returns a human-readable message, suitable for a compile error, in two cases:
/// - the variable is not set;
/// - the variable is set but its value is not valid Unicode.
fn env(var: &str) -> Result<String, String> {
    std::env::var(var).map_err(|err| match err {
        std::env::VarError::NotPresent => format!(
            "Environment variable {} must be set to use SQL validation",
            var
        ),
        // The variable *is* set here, so "must be set" would be misleading.
        std::env::VarError::NotUnicode(_) => format!(
            "Environment variable {} must contain valid Unicode to use SQL validation",
            var
        ),
    })
}