Skip to main content

systemprompt_database/services/
executor.rs

1//! SQL batch and statement-by-statement execution helpers.
2
3use super::database::Database;
4use super::provider::DatabaseProvider;
5use crate::error::{DatabaseResult, RepositoryError};
6use crate::models::QueryResult;
7
/// Zero-sized namespace type grouping SQL execution helpers.
///
/// It holds no state; every helper is an associated function taking the database
/// handle explicitly.
#[derive(Clone, Copy, Debug)]
pub struct SqlExecutor;
10
11impl SqlExecutor {
12    pub async fn execute_statements(db: &Database, sql: &str) -> DatabaseResult<()> {
13        db.execute_batch(sql).await.map_err(|e| {
14            RepositoryError::Internal(format!("Failed to execute SQL statements: {e}"))
15        })
16    }
17
18    pub async fn execute_statements_parsed(
19        db: &dyn DatabaseProvider,
20        sql: &str,
21    ) -> DatabaseResult<()> {
22        let statements = Self::parse_sql_statements(sql);
23
24        for statement in statements {
25            db.execute_raw(&statement).await.map_err(|e| {
26                RepositoryError::Internal(format!(
27                    "Failed to execute SQL statement: {statement}: {e}"
28                ))
29            })?;
30        }
31
32        Ok(())
33    }
34
35    pub fn parse_sql_statements(sql: &str) -> Vec<String> {
36        let mut statements = Vec::new();
37        let mut current_statement = String::new();
38        let mut in_trigger = false;
39        let mut in_dollar_quote = false;
40        let mut dollar_count = 0;
41
42        for line in sql.lines() {
43            let trimmed = line.trim();
44
45            if Self::should_skip_line(trimmed) {
46                continue;
47            }
48
49            current_statement.push_str(line);
50            current_statement.push('\n');
51
52            if trimmed.contains("$$") {
53                dollar_count += trimmed.matches("$$").count();
54                in_dollar_quote = dollar_count % 2 == 1;
55            }
56
57            if trimmed.starts_with("CREATE TRIGGER")
58                || trimmed.starts_with("CREATE OR REPLACE FUNCTION")
59            {
60                in_trigger = true;
61            }
62
63            if Self::is_statement_complete(trimmed, in_trigger, in_dollar_quote) {
64                let stmt = current_statement.trim().to_string();
65                if !stmt.is_empty() {
66                    statements.push(stmt);
67                }
68                current_statement.clear();
69                in_trigger = false;
70                dollar_count = 0;
71            }
72        }
73
74        let stmt = current_statement.trim().to_string();
75        if !stmt.is_empty() {
76            statements.push(stmt);
77        }
78
79        statements
80    }
81
82    fn should_skip_line(line: &str) -> bool {
83        line.starts_with("--") || line.is_empty()
84    }
85
86    fn is_statement_complete(line: &str, in_trigger: bool, in_dollar_quote: bool) -> bool {
87        if in_dollar_quote {
88            return false;
89        }
90
91        if in_trigger {
92            return line == "END;" || line.ends_with("LANGUAGE plpgsql;");
93        }
94
95        line.ends_with(';')
96    }
97
98    pub async fn execute_query(db: &Database, query: &str) -> DatabaseResult<QueryResult> {
99        db.query(&query)
100            .await
101            .map_err(|e| RepositoryError::Internal(format!("Failed to execute query: {e}")))
102    }
103
104    pub async fn execute_file(db: &Database, file_path: &str) -> DatabaseResult<()> {
105        let sql = std::fs::read_to_string(file_path).map_err(|e| {
106            RepositoryError::Internal(format!("Failed to read SQL file: {file_path}: {e}"))
107        })?;
108        Self::execute_statements(db, &sql).await
109    }
110
111    pub async fn execute_file_parsed(
112        db: &dyn DatabaseProvider,
113        file_path: &str,
114    ) -> DatabaseResult<()> {
115        let sql = std::fs::read_to_string(file_path).map_err(|e| {
116            RepositoryError::Internal(format!("Failed to read SQL file: {file_path}: {e}"))
117        })?;
118        Self::execute_statements_parsed(db, &sql).await
119    }
120
121    pub async fn table_exists(db: &Database, table_name: &str) -> DatabaseResult<bool> {
122        let result = db
123            .query_with(
124                &"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = \
125                  'public' AND table_name = $1) as exists",
126                vec![serde_json::Value::String(table_name.to_string())],
127            )
128            .await?;
129
130        result
131            .first()
132            .and_then(|row| row.get("exists"))
133            .and_then(serde_json::Value::as_bool)
134            .ok_or_else(|| RepositoryError::Internal("Failed to check table existence".to_string()))
135    }
136
137    pub async fn column_exists(
138        db: &Database,
139        table_name: &str,
140        column_name: &str,
141    ) -> DatabaseResult<bool> {
142        let result = db
143            .query_with(
144                &"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = \
145                  'public' AND table_name = $1 AND column_name = $2) as exists",
146                vec![
147                    serde_json::Value::String(table_name.to_string()),
148                    serde_json::Value::String(column_name.to_string()),
149                ],
150            )
151            .await?;
152
153        result
154            .first()
155            .and_then(|row| row.get("exists"))
156            .and_then(serde_json::Value::as_bool)
157            .ok_or_else(|| {
158                RepositoryError::Internal("Failed to check column existence".to_string())
159            })
160    }
161}