Skip to main content

systemprompt_database/services/
executor.rs

1//! SQL batch and statement-by-statement execution helpers.
2
3use super::database::Database;
4use super::provider::DatabaseProvider;
5use crate::error::{DatabaseResult, RepositoryError};
6use crate::models::QueryResult;
7
/// Namespace type for the SQL execution helpers below. It holds no data, and
/// every helper is an associated function that takes no `self`, so the type
/// is trivially `Copy`.
#[derive(Clone, Copy, Debug)]
pub struct SqlExecutor;
10
11enum SplitState {
12    Normal,
13    SingleQuote,
14    DollarQuote(String),
15    LineComment,
16    BlockComment(u32),
17}
18
19fn dollar_tag_end(bytes: &[u8], start: usize) -> Option<usize> {
20    debug_assert_eq!(bytes[start], b'$');
21    let mut j = start + 1;
22    while j < bytes.len() {
23        let c = bytes[j];
24        if c == b'$' {
25            return Some(j);
26        }
27        if !(c.is_ascii_alphanumeric() || c == b'_') {
28            return None;
29        }
30        j += 1;
31    }
32    None
33}
34
35struct Splitter<'a> {
36    sql: &'a str,
37    bytes: &'a [u8],
38    i: usize,
39    start: usize,
40    has_content: bool,
41    statements: Vec<String>,
42}
43
44impl<'a> Splitter<'a> {
45    const fn new(sql: &'a str) -> Self {
46        Self {
47            sql,
48            bytes: sql.as_bytes(),
49            i: 0,
50            start: 0,
51            has_content: false,
52            statements: Vec::new(),
53        }
54    }
55
56    fn emit(&mut self, end: usize) {
57        if self.has_content {
58            let stmt = self.sql[self.start..end].trim();
59            if !stmt.is_empty() {
60                self.statements.push(stmt.to_string());
61            }
62        }
63        self.has_content = false;
64    }
65
66    fn step_normal(&mut self) -> SplitState {
67        match self.bytes[self.i] {
68            b'\'' => {
69                self.has_content = true;
70                self.i += 1;
71                SplitState::SingleQuote
72            },
73            b'-' if self.bytes.get(self.i + 1) == Some(&b'-') => {
74                self.i += 2;
75                SplitState::LineComment
76            },
77            b'/' if self.bytes.get(self.i + 1) == Some(&b'*') => {
78                self.i += 2;
79                SplitState::BlockComment(1)
80            },
81            b'$' => {
82                self.has_content = true;
83                if let Some(tag_end) = dollar_tag_end(self.bytes, self.i) {
84                    let tag = self.sql[self.i..=tag_end].to_string();
85                    self.i = tag_end + 1;
86                    SplitState::DollarQuote(tag)
87                } else {
88                    self.i += 1;
89                    SplitState::Normal
90                }
91            },
92            b';' => {
93                self.emit(self.i);
94                self.i += 1;
95                self.start = self.i;
96                SplitState::Normal
97            },
98            b => {
99                if !b.is_ascii_whitespace() {
100                    self.has_content = true;
101                }
102                self.i += 1;
103                SplitState::Normal
104            },
105        }
106    }
107
108    fn step_single_quote(&mut self) -> SplitState {
109        if self.bytes[self.i] == b'\'' {
110            if self.bytes.get(self.i + 1) == Some(&b'\'') {
111                self.i += 2;
112                SplitState::SingleQuote
113            } else {
114                self.i += 1;
115                SplitState::Normal
116            }
117        } else {
118            self.i += 1;
119            SplitState::SingleQuote
120        }
121    }
122
123    fn step_dollar_quote(&mut self, tag: String) -> SplitState {
124        let tag_bytes = tag.as_bytes();
125        if self.i + tag_bytes.len() <= self.bytes.len()
126            && self.bytes[self.i..self.i + tag_bytes.len()] == *tag_bytes
127        {
128            self.i += tag_bytes.len();
129            SplitState::Normal
130        } else {
131            self.i += 1;
132            SplitState::DollarQuote(tag)
133        }
134    }
135
136    fn step_line_comment(&mut self) -> SplitState {
137        let next = if self.bytes[self.i] == b'\n' {
138            SplitState::Normal
139        } else {
140            SplitState::LineComment
141        };
142        self.i += 1;
143        next
144    }
145
146    fn step_block_comment(&mut self, depth: u32) -> SplitState {
147        if self.bytes[self.i] == b'/' && self.bytes.get(self.i + 1) == Some(&b'*') {
148            self.i += 2;
149            SplitState::BlockComment(depth + 1)
150        } else if self.bytes[self.i] == b'*' && self.bytes.get(self.i + 1) == Some(&b'/') {
151            self.i += 2;
152            if depth == 1 {
153                SplitState::Normal
154            } else {
155                SplitState::BlockComment(depth - 1)
156            }
157        } else {
158            self.i += 1;
159            SplitState::BlockComment(depth)
160        }
161    }
162
163    fn run(mut self) -> DatabaseResult<Vec<String>> {
164        let mut state = SplitState::Normal;
165        while self.i < self.bytes.len() {
166            state = match state {
167                SplitState::Normal => self.step_normal(),
168                SplitState::SingleQuote => self.step_single_quote(),
169                SplitState::DollarQuote(tag) => self.step_dollar_quote(tag),
170                SplitState::LineComment => self.step_line_comment(),
171                SplitState::BlockComment(depth) => self.step_block_comment(depth),
172            };
173        }
174
175        match state {
176            SplitState::Normal | SplitState::LineComment => {
177                let end = self.sql.len();
178                self.emit(end);
179                Ok(self.statements)
180            },
181            SplitState::SingleQuote => Err(RepositoryError::Internal(
182                "Unterminated string literal in SQL".into(),
183            )),
184            SplitState::DollarQuote(tag) => Err(RepositoryError::Internal(format!(
185                "Unterminated dollar-quoted string: {tag}"
186            ))),
187            SplitState::BlockComment(_) => Err(RepositoryError::Internal(
188                "Unterminated block comment in SQL".into(),
189            )),
190        }
191    }
192}
193
194impl SqlExecutor {
195    pub async fn execute_statements(db: &Database, sql: &str) -> DatabaseResult<()> {
196        db.execute_batch(sql).await.map_err(|e| {
197            RepositoryError::Internal(format!("Failed to execute SQL statements: {e}"))
198        })
199    }
200
201    pub async fn execute_statements_parsed(
202        db: &dyn DatabaseProvider,
203        sql: &str,
204    ) -> DatabaseResult<()> {
205        let statements = Self::parse_sql_statements(sql)?;
206
207        for statement in statements {
208            db.execute_raw(&statement).await.map_err(|e| {
209                RepositoryError::Internal(format!(
210                    "Failed to execute SQL statement: {statement}: {e}"
211                ))
212            })?;
213        }
214
215        Ok(())
216    }
217
218    /// Split a Postgres SQL script into individual statements while preserving
219    /// the original source text. Splits on top-level `;`; ignores
220    /// semicolons inside single quotes, dollar-quoted bodies (`$$ … $$` and
221    /// `$tag$ … $tag$`), `--` line comments, and `/* … */` block comments
222    /// (nested). Unterminated quotes or comments return
223    /// `RepositoryError::Internal`; grammar errors are left for Postgres to
224    /// surface at execute time. Preserving the original text is the
225    /// reason this is hand-rolled rather than `sqlparser`: round-tripping
226    /// through `Statement::Display` drops syntactic detail such as the
227    /// empty parameter list on `CREATE FUNCTION foo()`, which Postgres then
228    /// rejects.
229    pub fn parse_sql_statements(sql: &str) -> DatabaseResult<Vec<String>> {
230        Splitter::new(sql).run()
231    }
232
233    pub async fn execute_query(db: &Database, query: &str) -> DatabaseResult<QueryResult> {
234        db.query(&query)
235            .await
236            .map_err(|e| RepositoryError::Internal(format!("Failed to execute query: {e}")))
237    }
238
239    pub async fn execute_file(db: &Database, file_path: &str) -> DatabaseResult<()> {
240        let sql = std::fs::read_to_string(file_path).map_err(|e| {
241            RepositoryError::Internal(format!("Failed to read SQL file: {file_path}: {e}"))
242        })?;
243        Self::execute_statements(db, &sql).await
244    }
245
246    pub async fn execute_file_parsed(
247        db: &dyn DatabaseProvider,
248        file_path: &str,
249    ) -> DatabaseResult<()> {
250        let sql = std::fs::read_to_string(file_path).map_err(|e| {
251            RepositoryError::Internal(format!("Failed to read SQL file: {file_path}: {e}"))
252        })?;
253        Self::execute_statements_parsed(db, &sql).await
254    }
255
256    pub async fn table_exists(db: &Database, table_name: &str) -> DatabaseResult<bool> {
257        let result = db
258            .query_with(
259                &"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = \
260                  'public' AND table_name = $1) as exists",
261                &[&table_name],
262            )
263            .await?;
264
265        result
266            .first()
267            .and_then(|row| row.get("exists"))
268            .and_then(serde_json::Value::as_bool)
269            .ok_or_else(|| RepositoryError::Internal("Failed to check table existence".to_string()))
270    }
271
272    pub async fn column_exists(
273        db: &Database,
274        table_name: &str,
275        column_name: &str,
276    ) -> DatabaseResult<bool> {
277        let result = db
278            .query_with(
279                &"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = \
280                  'public' AND table_name = $1 AND column_name = $2) as exists",
281                &[&table_name, &column_name],
282            )
283            .await?;
284
285        result
286            .first()
287            .and_then(|row| row.get("exists"))
288            .and_then(serde_json::Value::as_bool)
289            .ok_or_else(|| {
290                RepositoryError::Internal("Failed to check column existence".to_string())
291            })
292    }
293}