Skip to main content

systemprompt_database/services/
executor.rs

1use super::database::Database;
2use super::provider::DatabaseProvider;
3use crate::models::QueryResult;
4use anyhow::{Context, Result};
5
6#[derive(Debug, Copy, Clone)]
7pub struct SqlExecutor;
8
9impl SqlExecutor {
10    pub async fn execute_statements(db: &Database, sql: &str) -> Result<()> {
11        db.execute_batch(sql)
12            .await
13            .context("Failed to execute SQL statements")
14    }
15
16    pub async fn execute_statements_parsed(db: &dyn DatabaseProvider, sql: &str) -> Result<()> {
17        let statements = Self::parse_sql_statements(sql);
18
19        for statement in statements {
20            db.execute_raw(&statement)
21                .await
22                .with_context(|| format!("Failed to execute SQL statement: {statement}"))?;
23        }
24
25        Ok(())
26    }
27
28    pub fn parse_sql_statements(sql: &str) -> Vec<String> {
29        let mut statements = Vec::new();
30        let mut current_statement = String::new();
31        let mut in_trigger = false;
32        let mut in_dollar_quote = false;
33        let mut dollar_count = 0;
34
35        for line in sql.lines() {
36            let trimmed = line.trim();
37
38            if Self::should_skip_line(trimmed) {
39                continue;
40            }
41
42            current_statement.push_str(line);
43            current_statement.push('\n');
44
45            if trimmed.contains("$$") {
46                dollar_count += trimmed.matches("$$").count();
47                in_dollar_quote = dollar_count % 2 == 1;
48            }
49
50            if trimmed.starts_with("CREATE TRIGGER")
51                || trimmed.starts_with("CREATE OR REPLACE FUNCTION")
52            {
53                in_trigger = true;
54            }
55
56            if Self::is_statement_complete(trimmed, in_trigger, in_dollar_quote) {
57                let stmt = current_statement.trim().to_string();
58                if !stmt.is_empty() {
59                    statements.push(stmt);
60                }
61                current_statement.clear();
62                in_trigger = false;
63                dollar_count = 0;
64            }
65        }
66
67        let stmt = current_statement.trim().to_string();
68        if !stmt.is_empty() {
69            statements.push(stmt);
70        }
71
72        statements
73    }
74
75    fn should_skip_line(line: &str) -> bool {
76        line.starts_with("--") || line.is_empty()
77    }
78
79    fn is_statement_complete(line: &str, in_trigger: bool, in_dollar_quote: bool) -> bool {
80        if in_dollar_quote {
81            return false;
82        }
83
84        if in_trigger {
85            return line == "END;" || line.ends_with("LANGUAGE plpgsql;");
86        }
87
88        line.ends_with(';')
89    }
90
91    pub async fn execute_query(db: &Database, query: &str) -> Result<QueryResult> {
92        db.query(&query).await.context("Failed to execute query")
93    }
94
95    pub async fn execute_file(db: &Database, file_path: &str) -> Result<()> {
96        let sql = std::fs::read_to_string(file_path)
97            .with_context(|| format!("Failed to read SQL file: {file_path}"))?;
98
99        Self::execute_statements(db, &sql).await
100    }
101
102    pub async fn execute_file_parsed(db: &dyn DatabaseProvider, file_path: &str) -> Result<()> {
103        let sql = std::fs::read_to_string(file_path)
104            .with_context(|| format!("Failed to read SQL file: {file_path}"))?;
105
106        Self::execute_statements_parsed(db, &sql).await
107    }
108
109    pub async fn table_exists(db: &Database, table_name: &str) -> Result<bool> {
110        let result = db
111            .query_with(
112                &"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = \
113                  'public' AND table_name = $1) as exists",
114                vec![serde_json::Value::String(table_name.to_string())],
115            )
116            .await?;
117
118        result
119            .first()
120            .and_then(|row| row.get("exists"))
121            .and_then(serde_json::Value::as_bool)
122            .ok_or_else(|| anyhow::anyhow!("Failed to check table existence"))
123    }
124
125    pub async fn column_exists(db: &Database, table_name: &str, column_name: &str) -> Result<bool> {
126        let result = db
127            .query_with(
128                &"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = \
129                  'public' AND table_name = $1 AND column_name = $2) as exists",
130                vec![
131                    serde_json::Value::String(table_name.to_string()),
132                    serde_json::Value::String(column_name.to_string()),
133                ],
134            )
135            .await?;
136
137        result
138            .first()
139            .and_then(|row| row.get("exists"))
140            .and_then(serde_json::Value::as_bool)
141            .ok_or_else(|| anyhow::anyhow!("Failed to check column existence"))
142    }
143}