systemprompt_database/services/
executor.rs1use super::database::Database;
2use super::provider::DatabaseProvider;
3use crate::models::QueryResult;
4use anyhow::{Context, Result};
5
/// Stateless namespace type for SQL helpers: running scripts (either as one
/// batch or split client-side into statements), ad-hoc queries, and
/// `information_schema` existence checks.
///
/// The type carries no data; every helper is an associated function.
#[derive(Clone, Copy, Debug)]
pub struct SqlExecutor;
8
9impl SqlExecutor {
10 pub async fn execute_statements(db: &Database, sql: &str) -> Result<()> {
11 db.execute_batch(sql)
12 .await
13 .context("Failed to execute SQL statements")
14 }
15
16 pub async fn execute_statements_parsed(db: &dyn DatabaseProvider, sql: &str) -> Result<()> {
17 let statements = Self::parse_sql_statements(sql);
18
19 for statement in statements {
20 db.execute_raw(&statement)
21 .await
22 .with_context(|| format!("Failed to execute SQL statement: {statement}"))?;
23 }
24
25 Ok(())
26 }
27
28 pub fn parse_sql_statements(sql: &str) -> Vec<String> {
29 let mut statements = Vec::new();
30 let mut current_statement = String::new();
31 let mut in_trigger = false;
32 let mut in_dollar_quote = false;
33 let mut dollar_count = 0;
34
35 for line in sql.lines() {
36 let trimmed = line.trim();
37
38 if Self::should_skip_line(trimmed) {
39 continue;
40 }
41
42 current_statement.push_str(line);
43 current_statement.push('\n');
44
45 if trimmed.contains("$$") {
46 dollar_count += trimmed.matches("$$").count();
47 in_dollar_quote = dollar_count % 2 == 1;
48 }
49
50 if trimmed.starts_with("CREATE TRIGGER")
51 || trimmed.starts_with("CREATE OR REPLACE FUNCTION")
52 {
53 in_trigger = true;
54 }
55
56 if Self::is_statement_complete(trimmed, in_trigger, in_dollar_quote) {
57 let stmt = current_statement.trim().to_string();
58 if !stmt.is_empty() {
59 statements.push(stmt);
60 }
61 current_statement.clear();
62 in_trigger = false;
63 dollar_count = 0;
64 }
65 }
66
67 let stmt = current_statement.trim().to_string();
68 if !stmt.is_empty() {
69 statements.push(stmt);
70 }
71
72 statements
73 }
74
75 fn should_skip_line(line: &str) -> bool {
76 line.starts_with("--") || line.is_empty()
77 }
78
79 fn is_statement_complete(line: &str, in_trigger: bool, in_dollar_quote: bool) -> bool {
80 if in_dollar_quote {
81 return false;
82 }
83
84 if in_trigger {
85 return line == "END;" || line.ends_with("LANGUAGE plpgsql;");
86 }
87
88 line.ends_with(';')
89 }
90
91 pub async fn execute_query(db: &Database, query: &str) -> Result<QueryResult> {
92 db.query(&query).await.context("Failed to execute query")
93 }
94
95 pub async fn execute_file(db: &Database, file_path: &str) -> Result<()> {
96 let sql = std::fs::read_to_string(file_path)
97 .with_context(|| format!("Failed to read SQL file: {file_path}"))?;
98
99 Self::execute_statements(db, &sql).await
100 }
101
102 pub async fn execute_file_parsed(db: &dyn DatabaseProvider, file_path: &str) -> Result<()> {
103 let sql = std::fs::read_to_string(file_path)
104 .with_context(|| format!("Failed to read SQL file: {file_path}"))?;
105
106 Self::execute_statements_parsed(db, &sql).await
107 }
108
109 pub async fn table_exists(db: &Database, table_name: &str) -> Result<bool> {
110 let result = db
111 .query_with(
112 &"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = \
113 'public' AND table_name = $1) as exists",
114 vec![serde_json::Value::String(table_name.to_string())],
115 )
116 .await?;
117
118 result
119 .first()
120 .and_then(|row| row.get("exists"))
121 .and_then(serde_json::Value::as_bool)
122 .ok_or_else(|| anyhow::anyhow!("Failed to check table existence"))
123 }
124
125 pub async fn column_exists(db: &Database, table_name: &str, column_name: &str) -> Result<bool> {
126 let result = db
127 .query_with(
128 &"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = \
129 'public' AND table_name = $1 AND column_name = $2) as exists",
130 vec![
131 serde_json::Value::String(table_name.to_string()),
132 serde_json::Value::String(column_name.to_string()),
133 ],
134 )
135 .await?;
136
137 result
138 .first()
139 .and_then(|row| row.get("exists"))
140 .and_then(serde_json::Value::as_bool)
141 .ok_or_else(|| anyhow::anyhow!("Failed to check column existence"))
142 }
143}