Skip to main content

sql_composer_sqlx/
lib.rs

1//! sqlx integration for sql-composer.
2//!
//! Provides verification of composed SQL against a live PostgreSQL connection
//! (`postgres` feature) and optional syntax validation via sqlparser
//! (`validate` feature).
5
6pub use sql_composer;
7
8use sql_composer::composer::ComposedSql;
9
10/// Errors specific to the sqlx integration.
11#[derive(Debug, thiserror::Error)]
12pub enum Error {
13    /// An error from sql-composer core.
14    #[error("composer error: {0}")]
15    Composer(#[from] sql_composer::Error),
16
17    /// An error from sqlx during verification.
18    #[error("sqlx error: {0}")]
19    Sqlx(#[from] sqlx::Error),
20
21    /// SQL syntax validation failed (requires `validate` feature).
22    #[error("SQL syntax error: {0}")]
23    Syntax(String),
24}
25
26/// A specialized `Result` type for sqlx integration operations.
27pub type Result<T> = std::result::Result<T, Error>;
28
/// Verify composed SQL statements against a PostgreSQL database.
///
/// Connects to the database and attempts to `PREPARE` each statement (then
/// immediately `DEALLOCATE`s it). This validates that the SQL syntax is
/// correct and that referenced tables/columns exist.
///
/// # Errors
///
/// Returns [`Error::Sqlx`] if connecting fails or if any statement cannot be
/// prepared or deallocated. Verification stops at the first failing statement.
/// The pool is closed in every case, including on failure.
#[cfg(feature = "postgres")]
pub async fn verify_postgres(database_url: &str, statements: &[&ComposedSql]) -> Result<()> {
    use sqlx::postgres::PgPoolOptions;
    use sqlx::Executor;

    let pool = PgPoolOptions::new()
        .max_connections(1)
        .connect(database_url)
        .await?;

    // Run the checks in an inner block so that an early `?` exit still falls
    // through to `pool.close()` below instead of skipping it.
    let outcome = async {
        // PREPARE/DEALLOCATE are session-scoped, so hold one connection for
        // the whole loop rather than checking one out of the pool per call.
        let mut conn = pool.acquire().await?;

        for (i, stmt) in statements.iter().enumerate() {
            let name = format!("_sqlc_verify_{i}");

            (&mut *conn)
                .execute(sqlx::query(&format!("PREPARE {name} AS {}", stmt.sql)))
                .await?;

            (&mut *conn)
                .execute(sqlx::query(&format!("DEALLOCATE {name}")))
                .await?;
        }

        Ok::<(), Error>(())
    }
    .await;

    pool.close().await;
    outcome
}
58
/// Check that `sql` is syntactically valid for `dialect`, without touching a database.
///
/// Bind placeholders are first rewritten to a literal so the parser sees plain
/// expressions. Only syntax is checked; table and column existence are not.
///
/// # Errors
///
/// Returns [`Error::Syntax`] carrying the parser's message when parsing fails.
#[cfg(feature = "validate")]
pub fn validate_syntax(sql: &str, dialect: sql_composer::Dialect) -> Result<()> {
    use sqlparser::dialect::{MySqlDialect, PostgreSqlDialect, SQLiteDialect};
    use sqlparser::parser::Parser;

    // The parser does not understand `$1` / `?` placeholders, so swap them out first.
    let normalized = normalize_placeholders(sql);

    let parsed = match dialect {
        sql_composer::Dialect::Postgres => Parser::parse_sql(&PostgreSqlDialect {}, &normalized),
        sql_composer::Dialect::Mysql => Parser::parse_sql(&MySqlDialect {}, &normalized),
        sql_composer::Dialect::Sqlite => Parser::parse_sql(&SQLiteDialect {}, &normalized),
    };

    parsed
        .map(|_statements| ())
        .map_err(|err| Error::Syntax(err.to_string()))
}
80
81/// Replace dialect-specific placeholders with literal `1` for syntax validation.
82#[cfg(feature = "validate")]
fn normalize_placeholders(sql: &str) -> String {
    // Walk the input as a shrinking slice. A `?` (optionally followed by
    // digits, e.g. `?3`) or a `$` followed by at least one digit (e.g. `$12`)
    // is collapsed into the single character `1`. A `$` with no digits after
    // it is copied through unchanged, and so is every other character.
    let mut out = String::with_capacity(sql.len());
    let mut rest = sql;

    while let Some(marker) = rest.chars().next() {
        rest = &rest[marker.len_utf8()..];

        if marker != '$' && marker != '?' {
            out.push(marker);
            continue;
        }

        // Length (in bytes == chars, since they are ASCII) of the digit run
        // immediately after the marker; those digits are dropped.
        let digit_run = rest.bytes().take_while(u8::is_ascii_digit).count();
        rest = &rest[digit_run..];

        let is_placeholder = marker == '?' || digit_run > 0;
        out.push(if is_placeholder { '1' } else { marker });
    }

    out
}
111
#[cfg(test)]
mod tests {
    #[cfg(feature = "validate")]
    mod validate_tests {
        use crate::{normalize_placeholders, validate_syntax};
        use sql_composer::Dialect;

        /// Fails the test, showing the parser's message, if `sql` is rejected.
        fn assert_parses(sql: &str, dialect: Dialect) {
            if let Err(err) = validate_syntax(sql, dialect) {
                panic!("expected {sql:?} to be accepted, got: {err}");
            }
        }

        /// Checks every `(input, expected)` pair against `normalize_placeholders`.
        fn assert_normalizes(cases: &[(&str, &str)]) {
            for &(input, expected) in cases {
                assert_eq!(normalize_placeholders(input), expected, "input: {input:?}");
            }
        }

        #[test]
        fn test_validate_syntax_postgres() {
            assert_parses("SELECT 1", Dialect::Postgres);
        }

        #[test]
        fn test_validate_syntax_mysql() {
            assert_parses("SELECT 1", Dialect::Mysql);
        }

        #[test]
        fn test_validate_syntax_sqlite() {
            assert_parses("SELECT 1", Dialect::Sqlite);
        }

        #[test]
        fn test_validate_syntax_invalid() {
            assert!(validate_syntax("SELECTT 1 FROMM", Dialect::Postgres).is_err());
        }

        #[test]
        fn test_validate_syntax_with_placeholders() {
            // `$1` is rewritten to a literal before the parser ever sees it.
            assert_parses("SELECT * FROM users WHERE id = $1", Dialect::Postgres);
        }

        #[test]
        fn test_normalize_placeholders_postgres() {
            assert_normalizes(&[
                ("$1", "1"),
                ("$10", "1"),
                ("WHERE a = $1 AND b = $2", "WHERE a = 1 AND b = 1"),
            ]);
        }

        #[test]
        fn test_normalize_placeholders_mysql() {
            assert_normalizes(&[
                ("?", "1"),
                ("WHERE a = ? AND b = ?", "WHERE a = 1 AND b = 1"),
            ]);
        }

        #[test]
        fn test_normalize_placeholders_sqlite() {
            assert_normalizes(&[
                ("?1", "1"),
                ("WHERE a = ?1 AND b = ?2", "WHERE a = 1 AND b = 1"),
            ]);
        }

        #[test]
        fn test_normalize_preserves_dollar_without_digits() {
            // A lone `$` is not a placeholder and must come through untouched.
            assert_normalizes(&[("$", "$")]);
        }
    }
}