systemprompt_database/services/
executor.rs1use super::database::Database;
4use super::provider::DatabaseProvider;
5use crate::error::{DatabaseResult, RepositoryError};
6use crate::models::QueryResult;
7
/// Namespace for the SQL execution helpers; a unit struct that carries no state.
#[derive(Clone, Copy, Debug)]
pub struct SqlExecutor;
10
11impl SqlExecutor {
12 pub async fn execute_statements(db: &Database, sql: &str) -> DatabaseResult<()> {
13 db.execute_batch(sql).await.map_err(|e| {
14 RepositoryError::Internal(format!("Failed to execute SQL statements: {e}"))
15 })
16 }
17
18 pub async fn execute_statements_parsed(
19 db: &dyn DatabaseProvider,
20 sql: &str,
21 ) -> DatabaseResult<()> {
22 let statements = Self::parse_sql_statements(sql);
23
24 for statement in statements {
25 db.execute_raw(&statement).await.map_err(|e| {
26 RepositoryError::Internal(format!(
27 "Failed to execute SQL statement: {statement}: {e}"
28 ))
29 })?;
30 }
31
32 Ok(())
33 }
34
35 pub fn parse_sql_statements(sql: &str) -> Vec<String> {
36 let mut statements = Vec::new();
37 let mut current_statement = String::new();
38 let mut in_trigger = false;
39 let mut in_dollar_quote = false;
40 let mut dollar_count = 0;
41
42 for line in sql.lines() {
43 let trimmed = line.trim();
44
45 if Self::should_skip_line(trimmed) {
46 continue;
47 }
48
49 current_statement.push_str(line);
50 current_statement.push('\n');
51
52 if trimmed.contains("$$") {
53 dollar_count += trimmed.matches("$$").count();
54 in_dollar_quote = dollar_count % 2 == 1;
55 }
56
57 if trimmed.starts_with("CREATE TRIGGER")
58 || trimmed.starts_with("CREATE OR REPLACE FUNCTION")
59 {
60 in_trigger = true;
61 }
62
63 if Self::is_statement_complete(trimmed, in_trigger, in_dollar_quote) {
64 let stmt = current_statement.trim().to_string();
65 if !stmt.is_empty() {
66 statements.push(stmt);
67 }
68 current_statement.clear();
69 in_trigger = false;
70 dollar_count = 0;
71 }
72 }
73
74 let stmt = current_statement.trim().to_string();
75 if !stmt.is_empty() {
76 statements.push(stmt);
77 }
78
79 statements
80 }
81
82 fn should_skip_line(line: &str) -> bool {
83 line.starts_with("--") || line.is_empty()
84 }
85
86 fn is_statement_complete(line: &str, in_trigger: bool, in_dollar_quote: bool) -> bool {
87 if in_dollar_quote {
88 return false;
89 }
90
91 if in_trigger {
92 return line == "END;" || line.ends_with("LANGUAGE plpgsql;");
93 }
94
95 line.ends_with(';')
96 }
97
98 pub async fn execute_query(db: &Database, query: &str) -> DatabaseResult<QueryResult> {
99 db.query(&query)
100 .await
101 .map_err(|e| RepositoryError::Internal(format!("Failed to execute query: {e}")))
102 }
103
104 pub async fn execute_file(db: &Database, file_path: &str) -> DatabaseResult<()> {
105 let sql = std::fs::read_to_string(file_path).map_err(|e| {
106 RepositoryError::Internal(format!("Failed to read SQL file: {file_path}: {e}"))
107 })?;
108 Self::execute_statements(db, &sql).await
109 }
110
111 pub async fn execute_file_parsed(
112 db: &dyn DatabaseProvider,
113 file_path: &str,
114 ) -> DatabaseResult<()> {
115 let sql = std::fs::read_to_string(file_path).map_err(|e| {
116 RepositoryError::Internal(format!("Failed to read SQL file: {file_path}: {e}"))
117 })?;
118 Self::execute_statements_parsed(db, &sql).await
119 }
120
121 pub async fn table_exists(db: &Database, table_name: &str) -> DatabaseResult<bool> {
122 let result = db
123 .query_with(
124 &"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = \
125 'public' AND table_name = $1) as exists",
126 vec![serde_json::Value::String(table_name.to_string())],
127 )
128 .await?;
129
130 result
131 .first()
132 .and_then(|row| row.get("exists"))
133 .and_then(serde_json::Value::as_bool)
134 .ok_or_else(|| RepositoryError::Internal("Failed to check table existence".to_string()))
135 }
136
137 pub async fn column_exists(
138 db: &Database,
139 table_name: &str,
140 column_name: &str,
141 ) -> DatabaseResult<bool> {
142 let result = db
143 .query_with(
144 &"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = \
145 'public' AND table_name = $1 AND column_name = $2) as exists",
146 vec![
147 serde_json::Value::String(table_name.to_string()),
148 serde_json::Value::String(column_name.to_string()),
149 ],
150 )
151 .await?;
152
153 result
154 .first()
155 .and_then(|row| row.get("exists"))
156 .and_then(serde_json::Value::as_bool)
157 .ok_or_else(|| {
158 RepositoryError::Internal("Failed to check column existence".to_string())
159 })
160 }
161}