// systemprompt_database/services/executor.rs

//! SQL batch and statement-by-statement execution helpers.

3use super::database::Database;
4use super::provider::DatabaseProvider;
5use crate::error::{DatabaseResult, RepositoryError};
6use crate::models::QueryResult;

/// Unit type that groups the SQL execution helpers below as associated
/// functions; it carries no state of its own.
#[derive(Clone, Copy, Debug)]
pub struct SqlExecutor;

11enum SplitState {
12    Normal,
13    SingleQuote,
14    DollarQuote(String),
15    LineComment,
16    BlockComment(u32),
17}
18
19fn dollar_tag_end(bytes: &[u8], start: usize) -> Option<usize> {
20    debug_assert_eq!(bytes[start], b'$');
21    let mut j = start + 1;
22    while j < bytes.len() {
23        let c = bytes[j];
24        if c == b'$' {
25            return Some(j);
26        }
27        if !(c.is_ascii_alphanumeric() || c == b'_') {
28            return None;
29        }
30        j += 1;
31    }
32    None
33}
34
35impl SqlExecutor {
36    pub async fn execute_statements(db: &Database, sql: &str) -> DatabaseResult<()> {
37        db.execute_batch(sql).await.map_err(|e| {
38            RepositoryError::Internal(format!("Failed to execute SQL statements: {e}"))
39        })
40    }
41
42    pub async fn execute_statements_parsed(
43        db: &dyn DatabaseProvider,
44        sql: &str,
45    ) -> DatabaseResult<()> {
46        let statements = Self::parse_sql_statements(sql)?;
47
48        for statement in statements {
49            db.execute_raw(&statement).await.map_err(|e| {
50                RepositoryError::Internal(format!(
51                    "Failed to execute SQL statement: {statement}: {e}"
52                ))
53            })?;
54        }
55
56        Ok(())
57    }
58
59    /// Split a Postgres SQL script into individual statements while preserving
60    /// the original source text. Splits on top-level `;`; ignores
61    /// semicolons inside single quotes, dollar-quoted bodies (`$$ … $$` and
62    /// `$tag$ … $tag$`), `--` line comments, and `/* … */` block comments
63    /// (nested). Unterminated quotes or comments return
64    /// `RepositoryError::Internal`; grammar errors are left for Postgres to
65    /// surface at execute time. Preserving the original text is the
66    /// reason this is hand-rolled rather than `sqlparser`: round-tripping
67    /// through `Statement::Display` drops syntactic detail such as the
68    /// empty parameter list on `CREATE FUNCTION foo()`, which Postgres then
69    /// rejects.
70    pub fn parse_sql_statements(sql: &str) -> DatabaseResult<Vec<String>> {
71        let bytes = sql.as_bytes();
72        let mut statements = Vec::new();
73        let mut start = 0usize;
74        let mut i = 0usize;
75        let mut state = SplitState::Normal;
76        let mut has_content = false;
77        let mut emit = |sql: &str, start: usize, end: usize, has_content: &mut bool| {
78            if *has_content {
79                let stmt = sql[start..end].trim();
80                if !stmt.is_empty() {
81                    statements.push(stmt.to_string());
82                }
83            }
84            *has_content = false;
85        };
86
87        while i < bytes.len() {
88            match &mut state {
89                SplitState::Normal => match bytes[i] {
90                    b'\'' => {
91                        has_content = true;
92                        state = SplitState::SingleQuote;
93                        i += 1;
94                    },
95                    b'-' if bytes.get(i + 1) == Some(&b'-') => {
96                        state = SplitState::LineComment;
97                        i += 2;
98                    },
99                    b'/' if bytes.get(i + 1) == Some(&b'*') => {
100                        state = SplitState::BlockComment(1);
101                        i += 2;
102                    },
103                    b'$' => {
104                        has_content = true;
105                        if let Some(tag_end) = dollar_tag_end(bytes, i) {
106                            let tag = sql[i..=tag_end].to_string();
107                            state = SplitState::DollarQuote(tag);
108                            i = tag_end + 1;
109                        } else {
110                            i += 1;
111                        }
112                    },
113                    b';' => {
114                        emit(sql, start, i, &mut has_content);
115                        i += 1;
116                        start = i;
117                    },
118                    b => {
119                        if !b.is_ascii_whitespace() {
120                            has_content = true;
121                        }
122                        i += 1;
123                    },
124                },
125                SplitState::SingleQuote => {
126                    if bytes[i] == b'\'' {
127                        if bytes.get(i + 1) == Some(&b'\'') {
128                            i += 2;
129                        } else {
130                            state = SplitState::Normal;
131                            i += 1;
132                        }
133                    } else {
134                        i += 1;
135                    }
136                },
137                SplitState::DollarQuote(tag) => {
138                    let tag_bytes = tag.as_bytes();
139                    if i + tag_bytes.len() <= bytes.len()
140                        && &bytes[i..i + tag_bytes.len()] == tag_bytes
141                    {
142                        i += tag_bytes.len();
143                        state = SplitState::Normal;
144                    } else {
145                        i += 1;
146                    }
147                },
148                SplitState::LineComment => {
149                    if bytes[i] == b'\n' {
150                        state = SplitState::Normal;
151                    }
152                    i += 1;
153                },
154                SplitState::BlockComment(depth) => {
155                    if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'*') {
156                        *depth += 1;
157                        i += 2;
158                    } else if bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/') {
159                        *depth -= 1;
160                        i += 2;
161                        if *depth == 0 {
162                            state = SplitState::Normal;
163                        }
164                    } else {
165                        i += 1;
166                    }
167                },
168            }
169        }
170
171        match state {
172            SplitState::Normal | SplitState::LineComment => {
173                emit(sql, start, sql.len(), &mut has_content);
174                Ok(statements)
175            },
176            SplitState::SingleQuote => Err(RepositoryError::Internal(
177                "Unterminated string literal in SQL".into(),
178            )),
179            SplitState::DollarQuote(tag) => Err(RepositoryError::Internal(format!(
180                "Unterminated dollar-quoted string: {tag}"
181            ))),
182            SplitState::BlockComment(_) => Err(RepositoryError::Internal(
183                "Unterminated block comment in SQL".into(),
184            )),
185        }
186    }
187
188    pub async fn execute_query(db: &Database, query: &str) -> DatabaseResult<QueryResult> {
189        db.query(&query)
190            .await
191            .map_err(|e| RepositoryError::Internal(format!("Failed to execute query: {e}")))
192    }
193
194    pub async fn execute_file(db: &Database, file_path: &str) -> DatabaseResult<()> {
195        let sql = std::fs::read_to_string(file_path).map_err(|e| {
196            RepositoryError::Internal(format!("Failed to read SQL file: {file_path}: {e}"))
197        })?;
198        Self::execute_statements(db, &sql).await
199    }
200
201    pub async fn execute_file_parsed(
202        db: &dyn DatabaseProvider,
203        file_path: &str,
204    ) -> DatabaseResult<()> {
205        let sql = std::fs::read_to_string(file_path).map_err(|e| {
206            RepositoryError::Internal(format!("Failed to read SQL file: {file_path}: {e}"))
207        })?;
208        Self::execute_statements_parsed(db, &sql).await
209    }
210
211    pub async fn table_exists(db: &Database, table_name: &str) -> DatabaseResult<bool> {
212        let result = db
213            .query_with(
214                &"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = \
215                  'public' AND table_name = $1) as exists",
216                vec![serde_json::Value::String(table_name.to_string())],
217            )
218            .await?;
219
220        result
221            .first()
222            .and_then(|row| row.get("exists"))
223            .and_then(serde_json::Value::as_bool)
224            .ok_or_else(|| RepositoryError::Internal("Failed to check table existence".to_string()))
225    }
226
227    pub async fn column_exists(
228        db: &Database,
229        table_name: &str,
230        column_name: &str,
231    ) -> DatabaseResult<bool> {
232        let result = db
233            .query_with(
234                &"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = \
235                  'public' AND table_name = $1 AND column_name = $2) as exists",
236                vec![
237                    serde_json::Value::String(table_name.to_string()),
238                    serde_json::Value::String(column_name.to_string()),
239                ],
240            )
241            .await?;
242
243        result
244            .first()
245            .and_then(|row| row.get("exists"))
246            .and_then(serde_json::Value::as_bool)
247            .ok_or_else(|| {
248                RepositoryError::Internal("Failed to check column existence".to_string())
249            })
250    }
251}