systemprompt_database/services/
executor.rs1use super::database::Database;
4use super::provider::DatabaseProvider;
5use crate::error::{DatabaseResult, RepositoryError};
6use crate::models::QueryResult;
7
/// Stateless namespace type: every SQL helper in this file is an associated
/// function on `SqlExecutor`, so the struct itself carries no data.
#[derive(Clone, Copy, Debug)]
pub struct SqlExecutor;
10
/// Lexical context the statement splitter is in while scanning SQL text.
enum SplitState {
    /// Plain SQL outside any literal or comment; `;` ends a statement here.
    Normal,
    /// Inside a `'...'` string literal.
    SingleQuote,
    /// Inside a dollar-quoted body; the payload is the full delimiter
    /// (both `$` included) that closes it.
    DollarQuote(String),
    /// Inside a `--` comment, which runs until the next newline.
    LineComment,
    /// Inside a `/* ... */` comment; the payload is the current nesting depth.
    BlockComment(u32),
}
18
/// Finds the `$` that closes a dollar-quote opening delimiter.
///
/// `bytes[start]` must be the `$` that begins the candidate delimiter. The
/// bytes between the two `$` form the (possibly empty) tag, which follows
/// PostgreSQL's rules for unquoted identifiers without `$`:
///
/// * the first byte is an ASCII letter, `_`, or a non-ASCII byte;
/// * later bytes may additionally be ASCII digits.
///
/// A tag may not start with a digit, so `$1$` is *not* a delimiter: it is the
/// positional parameter `$1` followed by a separate `$`.
///
/// Returns the index of the closing `$`, or `None` when the text at `start`
/// does not form a complete `$tag$` delimiter.
fn dollar_tag_end(bytes: &[u8], start: usize) -> Option<usize> {
    debug_assert_eq!(bytes.get(start), Some(&b'$'));
    let first = start + 1;
    let candidate = bytes.get(first..)?;
    for (offset, &c) in candidate.iter().enumerate() {
        match c {
            b'$' => return Some(first + offset),
            // Leading digit means a positional parameter such as `$1`, not a tag.
            b'0'..=b'9' if offset == 0 => return None,
            // Non-ASCII bytes belong to multi-byte identifier characters.
            b'_' | 0x80..=0xFF => {},
            c if c.is_ascii_alphanumeric() => {},
            _ => return None,
        }
    }
    None
}
34
/// Cursor-based scanner that cuts a SQL script into individual statements.
struct Splitter<'a> {
    /// The script being split.
    sql: &'a str,
    /// Byte view of `sql`, used for single-byte lookups while scanning.
    bytes: &'a [u8],

    /// Byte offset where the statement currently being scanned begins.
    start: usize,
    /// Byte offset of the next byte to examine.
    i: usize,

    /// Whether the current statement has seen anything worth emitting.
    has_content: bool,
    /// Statements collected so far, in source order.
    statements: Vec<String>,
}
43
44impl<'a> Splitter<'a> {
45 const fn new(sql: &'a str) -> Self {
46 Self {
47 sql,
48 bytes: sql.as_bytes(),
49 i: 0,
50 start: 0,
51 has_content: false,
52 statements: Vec::new(),
53 }
54 }
55
56 fn emit(&mut self, end: usize) {
57 if self.has_content {
58 let stmt = self.sql[self.start..end].trim();
59 if !stmt.is_empty() {
60 self.statements.push(stmt.to_string());
61 }
62 }
63 self.has_content = false;
64 }
65
66 fn step_normal(&mut self) -> SplitState {
67 match self.bytes[self.i] {
68 b'\'' => {
69 self.has_content = true;
70 self.i += 1;
71 SplitState::SingleQuote
72 },
73 b'-' if self.bytes.get(self.i + 1) == Some(&b'-') => {
74 self.i += 2;
75 SplitState::LineComment
76 },
77 b'/' if self.bytes.get(self.i + 1) == Some(&b'*') => {
78 self.i += 2;
79 SplitState::BlockComment(1)
80 },
81 b'$' => {
82 self.has_content = true;
83 if let Some(tag_end) = dollar_tag_end(self.bytes, self.i) {
84 let tag = self.sql[self.i..=tag_end].to_string();
85 self.i = tag_end + 1;
86 SplitState::DollarQuote(tag)
87 } else {
88 self.i += 1;
89 SplitState::Normal
90 }
91 },
92 b';' => {
93 self.emit(self.i);
94 self.i += 1;
95 self.start = self.i;
96 SplitState::Normal
97 },
98 b => {
99 if !b.is_ascii_whitespace() {
100 self.has_content = true;
101 }
102 self.i += 1;
103 SplitState::Normal
104 },
105 }
106 }
107
108 fn step_single_quote(&mut self) -> SplitState {
109 if self.bytes[self.i] == b'\'' {
110 if self.bytes.get(self.i + 1) == Some(&b'\'') {
111 self.i += 2;
112 SplitState::SingleQuote
113 } else {
114 self.i += 1;
115 SplitState::Normal
116 }
117 } else {
118 self.i += 1;
119 SplitState::SingleQuote
120 }
121 }
122
123 fn step_dollar_quote(&mut self, tag: String) -> SplitState {
124 let tag_bytes = tag.as_bytes();
125 if self.i + tag_bytes.len() <= self.bytes.len()
126 && self.bytes[self.i..self.i + tag_bytes.len()] == *tag_bytes
127 {
128 self.i += tag_bytes.len();
129 SplitState::Normal
130 } else {
131 self.i += 1;
132 SplitState::DollarQuote(tag)
133 }
134 }
135
136 fn step_line_comment(&mut self) -> SplitState {
137 let next = if self.bytes[self.i] == b'\n' {
138 SplitState::Normal
139 } else {
140 SplitState::LineComment
141 };
142 self.i += 1;
143 next
144 }
145
146 fn step_block_comment(&mut self, depth: u32) -> SplitState {
147 if self.bytes[self.i] == b'/' && self.bytes.get(self.i + 1) == Some(&b'*') {
148 self.i += 2;
149 SplitState::BlockComment(depth + 1)
150 } else if self.bytes[self.i] == b'*' && self.bytes.get(self.i + 1) == Some(&b'/') {
151 self.i += 2;
152 if depth == 1 {
153 SplitState::Normal
154 } else {
155 SplitState::BlockComment(depth - 1)
156 }
157 } else {
158 self.i += 1;
159 SplitState::BlockComment(depth)
160 }
161 }
162
163 fn run(mut self) -> DatabaseResult<Vec<String>> {
164 let mut state = SplitState::Normal;
165 while self.i < self.bytes.len() {
166 state = match state {
167 SplitState::Normal => self.step_normal(),
168 SplitState::SingleQuote => self.step_single_quote(),
169 SplitState::DollarQuote(tag) => self.step_dollar_quote(tag),
170 SplitState::LineComment => self.step_line_comment(),
171 SplitState::BlockComment(depth) => self.step_block_comment(depth),
172 };
173 }
174
175 match state {
176 SplitState::Normal | SplitState::LineComment => {
177 let end = self.sql.len();
178 self.emit(end);
179 Ok(self.statements)
180 },
181 SplitState::SingleQuote => Err(RepositoryError::Internal(
182 "Unterminated string literal in SQL".into(),
183 )),
184 SplitState::DollarQuote(tag) => Err(RepositoryError::Internal(format!(
185 "Unterminated dollar-quoted string: {tag}"
186 ))),
187 SplitState::BlockComment(_) => Err(RepositoryError::Internal(
188 "Unterminated block comment in SQL".into(),
189 )),
190 }
191 }
192}
193
194impl SqlExecutor {
195 pub async fn execute_statements(db: &Database, sql: &str) -> DatabaseResult<()> {
196 db.execute_batch(sql).await.map_err(|e| {
197 RepositoryError::Internal(format!("Failed to execute SQL statements: {e}"))
198 })
199 }
200
201 pub async fn execute_statements_parsed(
202 db: &dyn DatabaseProvider,
203 sql: &str,
204 ) -> DatabaseResult<()> {
205 let statements = Self::parse_sql_statements(sql)?;
206
207 for statement in statements {
208 db.execute_raw(&statement).await.map_err(|e| {
209 RepositoryError::Internal(format!(
210 "Failed to execute SQL statement: {statement}: {e}"
211 ))
212 })?;
213 }
214
215 Ok(())
216 }
217
218 pub fn parse_sql_statements(sql: &str) -> DatabaseResult<Vec<String>> {
230 Splitter::new(sql).run()
231 }
232
233 pub async fn execute_query(db: &Database, query: &str) -> DatabaseResult<QueryResult> {
234 db.query(&query)
235 .await
236 .map_err(|e| RepositoryError::Internal(format!("Failed to execute query: {e}")))
237 }
238
239 pub async fn execute_file(db: &Database, file_path: &str) -> DatabaseResult<()> {
240 let sql = std::fs::read_to_string(file_path).map_err(|e| {
241 RepositoryError::Internal(format!("Failed to read SQL file: {file_path}: {e}"))
242 })?;
243 Self::execute_statements(db, &sql).await
244 }
245
246 pub async fn execute_file_parsed(
247 db: &dyn DatabaseProvider,
248 file_path: &str,
249 ) -> DatabaseResult<()> {
250 let sql = std::fs::read_to_string(file_path).map_err(|e| {
251 RepositoryError::Internal(format!("Failed to read SQL file: {file_path}: {e}"))
252 })?;
253 Self::execute_statements_parsed(db, &sql).await
254 }
255
256 pub async fn table_exists(db: &Database, table_name: &str) -> DatabaseResult<bool> {
257 let result = db
258 .query_with(
259 &"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = \
260 'public' AND table_name = $1) as exists",
261 &[&table_name],
262 )
263 .await?;
264
265 result
266 .first()
267 .and_then(|row| row.get("exists"))
268 .and_then(serde_json::Value::as_bool)
269 .ok_or_else(|| RepositoryError::Internal("Failed to check table existence".to_string()))
270 }
271
272 pub async fn column_exists(
273 db: &Database,
274 table_name: &str,
275 column_name: &str,
276 ) -> DatabaseResult<bool> {
277 let result = db
278 .query_with(
279 &"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = \
280 'public' AND table_name = $1 AND column_name = $2) as exists",
281 &[&table_name, &column_name],
282 )
283 .await?;
284
285 result
286 .first()
287 .and_then(|row| row.get("exists"))
288 .and_then(serde_json::Value::as_bool)
289 .ok_or_else(|| {
290 RepositoryError::Internal("Failed to check column existence".to_string())
291 })
292 }
293}