Skip to main content

winterbaume_redshiftdata/
state.rs

1use std::collections::HashMap;
2
3use chrono::Utc;
4use thiserror::Error;
5
6use crate::types::*;
7
8#[derive(Debug, Default)]
9pub struct RedshiftDataState {
10    pub statements: HashMap<String, Statement>,
11    /// Databases available in the mock cluster catalogue.
12    pub databases: Vec<String>,
13    /// Schemas available in the mock cluster catalogue.
14    pub schemas: Vec<String>,
15    /// Table names available in the mock cluster catalogue.
16    pub table_names: Vec<String>,
17    /// Column metadata keyed by table name.
18    pub table_columns: HashMap<String, Vec<(String, String)>>,
19}
20
21#[derive(Debug, Error)]
22pub enum RedshiftDataError {
23    #[error("Sql is required")]
24    SqlRequired,
25
26    #[error("Sqls is required")]
27    SqlsRequired,
28
29    #[error(
30        "id must satisfy regex pattern: ^[a-z0-9]{{8}}(-[a-z0-9]{{4}}){{3}}-[a-z0-9]{{12}}(:\\d+)?$"
31    )]
32    InvalidStatementId,
33
34    #[error("Query does not exist.")]
35    StatementNotFound,
36}
37
/// Check whether `id` has the shape of a statement id.
///
/// The accepted shape is the one quoted in the `InvalidStatementId` message:
/// `^[a-z0-9]{8}(-[a-z0-9]{4}){3}-[a-z0-9]{12}(:\d+)?$`, i.e. five
/// dash-separated groups of lowercase ASCII letters/digits with lengths
/// 8-4-4-4-12, optionally followed by `:` and one or more ASCII digits.
///
/// Returns `true` only when the whole string matches; a malformed suffix
/// (`"<uuid>:"`, `"<uuid>:abc"`, `"<uuid>:1:2"`) is rejected.
fn is_valid_statement_id(id: &str) -> bool {
    const SEGMENT_LENGTHS: [usize; 5] = [8, 4, 4, 4, 12];

    // Split off the optional `:<digits>` suffix at the first colon. Anything
    // after that colon must be digits only, so a second colon is rejected too.
    let (uuid_part, suffix) = match id.split_once(':') {
        Some((uuid, index)) => (uuid, Some(index)),
        None => (id, None),
    };

    if let Some(index) = suffix {
        if index.is_empty() || !index.bytes().all(|b| b.is_ascii_digit()) {
            return false;
        }
    }

    // Walk the dash-separated groups in lock-step with the expected lengths.
    // Only ASCII bytes pass the character check, so byte length equals
    // character count for every accepted group.
    let mut segments = uuid_part.split('-');
    let groups_ok = SEGMENT_LENGTHS.iter().all(|&expected| {
        segments.next().map_or(false, |segment| {
            segment.len() == expected
                && segment
                    .bytes()
                    .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit())
        })
    });

    // Exactly five groups: nothing may be left over after the last one.
    groups_ok && segments.next().is_none()
}
61
62impl RedshiftDataState {
63    pub fn execute_statement(
64        &mut self,
65        sql: &str,
66        database: &str,
67        cluster_identifier: Option<&str>,
68        workgroup_name: Option<&str>,
69        db_user: Option<&str>,
70        secret_arn: Option<&str>,
71        parameters: Vec<StatementParameter>,
72        result: crate::backend::StatementResult,
73    ) -> Result<String, RedshiftDataError> {
74        if sql.is_empty() {
75            return Err(RedshiftDataError::SqlRequired);
76        }
77
78        let id = uuid::Uuid::new_v4().to_string();
79        let now = Utc::now();
80        let status = if result.error.is_some() {
81            StatementStatus::Failed
82        } else {
83            StatementStatus::Finished
84        };
85        let result_rows = result.rows.len() as i64;
86        let has_result_set = !result.columns.is_empty();
87
88        let statement = Statement {
89            id: id.clone(),
90            sql: sql.to_string(),
91            cluster_identifier: cluster_identifier.map(String::from),
92            workgroup_name: workgroup_name.map(String::from),
93            database: database.to_string(),
94            db_user: db_user.map(String::from),
95            secret_arn: secret_arn.map(String::from),
96            status,
97            created_at: now,
98            updated_at: now,
99            result_rows,
100            result_size: 0,
101            has_result_set,
102            query_string: sql.to_string(),
103            parameters,
104            sqls: vec![],
105            statement_name: None,
106            is_batch: false,
107            result_columns: result.columns,
108            result_data: result.rows,
109            error_message: result.error,
110        };
111
112        self.statements.insert(id.clone(), statement);
113        Ok(id)
114    }
115
116    pub fn describe_statement(&self, id: &str) -> Result<&Statement, RedshiftDataError> {
117        if !is_valid_statement_id(id) {
118            return Err(RedshiftDataError::InvalidStatementId);
119        }
120        self.statements
121            .get(id)
122            .ok_or(RedshiftDataError::StatementNotFound)
123    }
124
125    pub fn cancel_statement(&mut self, id: &str) -> Result<bool, RedshiftDataError> {
126        if !is_valid_statement_id(id) {
127            return Err(RedshiftDataError::InvalidStatementId);
128        }
129        let statement = self
130            .statements
131            .get_mut(id)
132            .ok_or(RedshiftDataError::StatementNotFound)?;
133
134        match statement.status {
135            StatementStatus::Submitted | StatementStatus::Started => {
136                statement.status = StatementStatus::Aborted;
137                statement.updated_at = Utc::now();
138                Ok(true)
139            }
140            _ => {
141                // Already finished, failed, or aborted - cancel returns false
142                Ok(false)
143            }
144        }
145    }
146
147    pub fn list_statements(&self) -> Vec<&Statement> {
148        let mut stmts: Vec<&Statement> = self.statements.values().collect();
149        stmts.sort_by_key(|s| std::cmp::Reverse(s.created_at));
150        stmts
151    }
152
153    pub fn batch_execute_statement(
154        &mut self,
155        sqls: Vec<String>,
156        database: &str,
157        cluster_identifier: Option<&str>,
158        workgroup_name: Option<&str>,
159        db_user: Option<&str>,
160        secret_arn: Option<&str>,
161        statement_name: Option<&str>,
162        result: crate::backend::StatementResult,
163    ) -> Result<String, RedshiftDataError> {
164        if sqls.is_empty() {
165            return Err(RedshiftDataError::SqlsRequired);
166        }
167
168        let id = uuid::Uuid::new_v4().to_string();
169        let now = Utc::now();
170        let query_string = sqls.first().cloned().unwrap_or_default();
171        let status = if result.error.is_some() {
172            StatementStatus::Failed
173        } else {
174            StatementStatus::Finished
175        };
176
177        let statement = Statement {
178            id: id.clone(),
179            sql: query_string.clone(),
180            cluster_identifier: cluster_identifier.map(String::from),
181            workgroup_name: workgroup_name.map(String::from),
182            database: database.to_string(),
183            db_user: db_user.map(String::from),
184            secret_arn: secret_arn.map(String::from),
185            status,
186            created_at: now,
187            updated_at: now,
188            result_rows: 0,
189            result_size: 0,
190            has_result_set: false,
191            query_string,
192            parameters: vec![],
193            sqls,
194            statement_name: statement_name.map(String::from),
195            is_batch: true,
196            result_columns: vec![],
197            result_data: vec![],
198            error_message: result.error,
199        };
200
201        self.statements.insert(id.clone(), statement);
202        Ok(id)
203    }
204
205    // --- Catalogue operations ---
206
207    pub fn list_databases(&self) -> Vec<String> {
208        if self.databases.is_empty() {
209            vec!["dev".to_string(), "prod".to_string()]
210        } else {
211            self.databases.clone()
212        }
213    }
214
215    pub fn list_schemas(&self) -> Vec<String> {
216        if self.schemas.is_empty() {
217            vec!["public".to_string(), "information_schema".to_string()]
218        } else {
219            self.schemas.clone()
220        }
221    }
222
223    pub fn list_tables(&self) -> Vec<String> {
224        self.table_names.clone()
225    }
226
227    pub fn describe_table(&self, table: Option<&str>) -> Vec<(String, String)> {
228        if let Some(name) = table {
229            self.table_columns.get(name).cloned().unwrap_or_default()
230        } else {
231            vec![]
232        }
233    }
234}