systemprompt_database/services/
executor.rs1use super::database::Database;
4use super::provider::DatabaseProvider;
5use crate::error::{DatabaseResult, RepositoryError};
6use crate::models::QueryResult;
7
/// Stateless namespace type for the SQL execution and script-splitting helpers.
///
/// All functionality is exposed as associated functions; the type carries no data.
#[derive(Clone, Copy, Debug)]
pub struct SqlExecutor;
10
/// Lexical context tracked while a SQL script is split into statements.
enum SplitState {
    /// Outside any quoted string or comment; a `;` here terminates a statement.
    Normal,
    /// Inside a single-quoted string literal (`'...'`).
    SingleQuote,
    /// Inside a dollar-quoted string; holds the full opening tag, including both
    /// `$` delimiters (for example `$$` or `$body$`), which also closes the string.
    DollarQuote(String),
    /// Inside a `--` comment that runs to the end of the current line.
    LineComment,
    /// Inside a `/* ... */` comment; holds the current nesting depth.
    BlockComment(u32),
}
18
/// Finds the end of a dollar-quote opening tag that begins at `bytes[start]`.
///
/// `start` must be the index of a `$`. The tag body between the two `$`
/// delimiters may be empty (`$$`) or an identifier-like name made of ASCII
/// letters, digits and `_` that does not start with a digit. The leading-digit
/// rule keeps positional parameters such as `$1` from being mistaken for the
/// start of a dollar-quoted string (so `$1$` is not a tag).
///
/// Returns the index of the closing `$` of the tag, or `None` when the bytes
/// following `start` do not form a valid tag.
fn dollar_tag_end(bytes: &[u8], start: usize) -> Option<usize> {
    debug_assert_eq!(bytes.get(start), Some(&b'$'));
    let tag_start = start + 1;
    for (offset, &c) in bytes.get(tag_start..)?.iter().enumerate() {
        match c {
            b'$' => return Some(tag_start + offset),
            b'_' | b'a'..=b'z' | b'A'..=b'Z' => {}
            // Digits are allowed anywhere except as the first tag character.
            b'0'..=b'9' if offset > 0 => {}
            _ => return None,
        }
    }
    None
}
34
35impl SqlExecutor {
36 pub async fn execute_statements(db: &Database, sql: &str) -> DatabaseResult<()> {
37 db.execute_batch(sql).await.map_err(|e| {
38 RepositoryError::Internal(format!("Failed to execute SQL statements: {e}"))
39 })
40 }
41
42 pub async fn execute_statements_parsed(
43 db: &dyn DatabaseProvider,
44 sql: &str,
45 ) -> DatabaseResult<()> {
46 let statements = Self::parse_sql_statements(sql)?;
47
48 for statement in statements {
49 db.execute_raw(&statement).await.map_err(|e| {
50 RepositoryError::Internal(format!(
51 "Failed to execute SQL statement: {statement}: {e}"
52 ))
53 })?;
54 }
55
56 Ok(())
57 }
58
59 pub fn parse_sql_statements(sql: &str) -> DatabaseResult<Vec<String>> {
71 let bytes = sql.as_bytes();
72 let mut statements = Vec::new();
73 let mut start = 0usize;
74 let mut i = 0usize;
75 let mut state = SplitState::Normal;
76 let mut has_content = false;
77 let mut emit = |sql: &str, start: usize, end: usize, has_content: &mut bool| {
78 if *has_content {
79 let stmt = sql[start..end].trim();
80 if !stmt.is_empty() {
81 statements.push(stmt.to_string());
82 }
83 }
84 *has_content = false;
85 };
86
87 while i < bytes.len() {
88 match &mut state {
89 SplitState::Normal => match bytes[i] {
90 b'\'' => {
91 has_content = true;
92 state = SplitState::SingleQuote;
93 i += 1;
94 },
95 b'-' if bytes.get(i + 1) == Some(&b'-') => {
96 state = SplitState::LineComment;
97 i += 2;
98 },
99 b'/' if bytes.get(i + 1) == Some(&b'*') => {
100 state = SplitState::BlockComment(1);
101 i += 2;
102 },
103 b'$' => {
104 has_content = true;
105 if let Some(tag_end) = dollar_tag_end(bytes, i) {
106 let tag = sql[i..=tag_end].to_string();
107 state = SplitState::DollarQuote(tag);
108 i = tag_end + 1;
109 } else {
110 i += 1;
111 }
112 },
113 b';' => {
114 emit(sql, start, i, &mut has_content);
115 i += 1;
116 start = i;
117 },
118 b => {
119 if !b.is_ascii_whitespace() {
120 has_content = true;
121 }
122 i += 1;
123 },
124 },
125 SplitState::SingleQuote => {
126 if bytes[i] == b'\'' {
127 if bytes.get(i + 1) == Some(&b'\'') {
128 i += 2;
129 } else {
130 state = SplitState::Normal;
131 i += 1;
132 }
133 } else {
134 i += 1;
135 }
136 },
137 SplitState::DollarQuote(tag) => {
138 let tag_bytes = tag.as_bytes();
139 if i + tag_bytes.len() <= bytes.len()
140 && &bytes[i..i + tag_bytes.len()] == tag_bytes
141 {
142 i += tag_bytes.len();
143 state = SplitState::Normal;
144 } else {
145 i += 1;
146 }
147 },
148 SplitState::LineComment => {
149 if bytes[i] == b'\n' {
150 state = SplitState::Normal;
151 }
152 i += 1;
153 },
154 SplitState::BlockComment(depth) => {
155 if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'*') {
156 *depth += 1;
157 i += 2;
158 } else if bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/') {
159 *depth -= 1;
160 i += 2;
161 if *depth == 0 {
162 state = SplitState::Normal;
163 }
164 } else {
165 i += 1;
166 }
167 },
168 }
169 }
170
171 match state {
172 SplitState::Normal | SplitState::LineComment => {
173 emit(sql, start, sql.len(), &mut has_content);
174 Ok(statements)
175 },
176 SplitState::SingleQuote => Err(RepositoryError::Internal(
177 "Unterminated string literal in SQL".into(),
178 )),
179 SplitState::DollarQuote(tag) => Err(RepositoryError::Internal(format!(
180 "Unterminated dollar-quoted string: {tag}"
181 ))),
182 SplitState::BlockComment(_) => Err(RepositoryError::Internal(
183 "Unterminated block comment in SQL".into(),
184 )),
185 }
186 }
187
188 pub async fn execute_query(db: &Database, query: &str) -> DatabaseResult<QueryResult> {
189 db.query(&query)
190 .await
191 .map_err(|e| RepositoryError::Internal(format!("Failed to execute query: {e}")))
192 }
193
194 pub async fn execute_file(db: &Database, file_path: &str) -> DatabaseResult<()> {
195 let sql = std::fs::read_to_string(file_path).map_err(|e| {
196 RepositoryError::Internal(format!("Failed to read SQL file: {file_path}: {e}"))
197 })?;
198 Self::execute_statements(db, &sql).await
199 }
200
201 pub async fn execute_file_parsed(
202 db: &dyn DatabaseProvider,
203 file_path: &str,
204 ) -> DatabaseResult<()> {
205 let sql = std::fs::read_to_string(file_path).map_err(|e| {
206 RepositoryError::Internal(format!("Failed to read SQL file: {file_path}: {e}"))
207 })?;
208 Self::execute_statements_parsed(db, &sql).await
209 }
210
211 pub async fn table_exists(db: &Database, table_name: &str) -> DatabaseResult<bool> {
212 let result = db
213 .query_with(
214 &"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = \
215 'public' AND table_name = $1) as exists",
216 vec![serde_json::Value::String(table_name.to_string())],
217 )
218 .await?;
219
220 result
221 .first()
222 .and_then(|row| row.get("exists"))
223 .and_then(serde_json::Value::as_bool)
224 .ok_or_else(|| RepositoryError::Internal("Failed to check table existence".to_string()))
225 }
226
227 pub async fn column_exists(
228 db: &Database,
229 table_name: &str,
230 column_name: &str,
231 ) -> DatabaseResult<bool> {
232 let result = db
233 .query_with(
234 &"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = \
235 'public' AND table_name = $1 AND column_name = $2) as exists",
236 vec![
237 serde_json::Value::String(table_name.to_string()),
238 serde_json::Value::String(column_name.to_string()),
239 ],
240 )
241 .await?;
242
243 result
244 .first()
245 .and_then(|row| row.get("exists"))
246 .and_then(serde_json::Value::as_bool)
247 .ok_or_else(|| {
248 RepositoryError::Internal("Failed to check column existence".to_string())
249 })
250 }
251}