winterbaume_redshiftdata/
state.rs1use std::collections::HashMap;
2
3use chrono::Utc;
4use thiserror::Error;
5
6use crate::types::*;
7
/// In-memory state: stored statements plus the catalog metadata (databases,
/// schemas, tables) served by the list/describe methods.
#[derive(Debug, Default)]
pub struct RedshiftDataState {
    /// Statements keyed by the id generated when they were stored.
    pub statements: HashMap<String, Statement>,
    /// Database names; when empty, `list_databases` falls back to `dev` and `prod`.
    pub databases: Vec<String>,
    /// Schema names; when empty, `list_schemas` falls back to `public` and
    /// `information_schema`.
    pub schemas: Vec<String>,
    /// Table names, returned as-is by `list_tables`.
    pub table_names: Vec<String>,
    /// Column descriptions per table name, returned by `describe_table`.
    /// NOTE(review): each tuple is presumably (column name, type name) — confirm.
    pub table_columns: HashMap<String, Vec<(String, String)>>,
}
20
/// Errors returned by the `RedshiftDataState` operations.
#[derive(Debug, Error)]
pub enum RedshiftDataError {
    /// `execute_statement` was given an empty SQL string.
    #[error("Sql is required")]
    SqlRequired,

    /// `batch_execute_statement` was given no SQL statements.
    #[error("Sqls is required")]
    SqlsRequired,

    /// A statement id failed `is_valid_statement_id`; the message states the
    /// accepted id pattern. `{{`/`}}` are format-string escapes, so the rendered
    /// message shows single braces (e.g. `{8}`).
    #[error(
        "id must satisfy regex pattern: ^[a-z0-9]{{8}}(-[a-z0-9]{{4}}){{3}}-[a-z0-9]{{12}}(:\\d+)?$"
    )]
    InvalidStatementId,

    /// The id is well-formed but no statement is stored under it.
    #[error("Query does not exist.")]
    StatementNotFound,
}
37
/// Returns `true` if `id` matches the statement-id pattern
/// `^[a-z0-9]{8}(-[a-z0-9]{4}){3}-[a-z0-9]{12}(:\d+)?$`, the same pattern
/// quoted in the `InvalidStatementId` error message.
///
/// The id is five `-`-separated segments of lengths 8-4-4-4-12, made of ASCII
/// lowercase letters and digits. It may be followed by an optional `:<digits>`
/// suffix, which must be a non-empty run of ASCII digits.
fn is_valid_statement_id(id: &str) -> bool {
    const SEGMENT_LENGTHS: [usize; 5] = [8, 4, 4, 4, 12];

    // Split off the optional `:<digits>` suffix. Requiring it to be all digits
    // also rejects an empty suffix and any second ':'.
    let uuid_part = match id.split_once(':') {
        Some((head, suffix)) => {
            if suffix.is_empty() || !suffix.bytes().all(|b| b.is_ascii_digit()) {
                return false;
            }
            head
        }
        None => id,
    };

    let segments: Vec<&str> = uuid_part.split('-').collect();
    segments.len() == SEGMENT_LENGTHS.len()
        && segments
            .iter()
            .zip(SEGMENT_LENGTHS.iter())
            .all(|(seg, &len)| {
                seg.len() == len
                    && seg
                        .bytes()
                        .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit())
            })
}
61
62impl RedshiftDataState {
63 pub fn execute_statement(
64 &mut self,
65 sql: &str,
66 database: &str,
67 cluster_identifier: Option<&str>,
68 workgroup_name: Option<&str>,
69 db_user: Option<&str>,
70 secret_arn: Option<&str>,
71 parameters: Vec<StatementParameter>,
72 result: crate::backend::StatementResult,
73 ) -> Result<String, RedshiftDataError> {
74 if sql.is_empty() {
75 return Err(RedshiftDataError::SqlRequired);
76 }
77
78 let id = uuid::Uuid::new_v4().to_string();
79 let now = Utc::now();
80 let status = if result.error.is_some() {
81 StatementStatus::Failed
82 } else {
83 StatementStatus::Finished
84 };
85 let result_rows = result.rows.len() as i64;
86 let has_result_set = !result.columns.is_empty();
87
88 let statement = Statement {
89 id: id.clone(),
90 sql: sql.to_string(),
91 cluster_identifier: cluster_identifier.map(String::from),
92 workgroup_name: workgroup_name.map(String::from),
93 database: database.to_string(),
94 db_user: db_user.map(String::from),
95 secret_arn: secret_arn.map(String::from),
96 status,
97 created_at: now,
98 updated_at: now,
99 result_rows,
100 result_size: 0,
101 has_result_set,
102 query_string: sql.to_string(),
103 parameters,
104 sqls: vec![],
105 statement_name: None,
106 is_batch: false,
107 result_columns: result.columns,
108 result_data: result.rows,
109 error_message: result.error,
110 };
111
112 self.statements.insert(id.clone(), statement);
113 Ok(id)
114 }
115
116 pub fn describe_statement(&self, id: &str) -> Result<&Statement, RedshiftDataError> {
117 if !is_valid_statement_id(id) {
118 return Err(RedshiftDataError::InvalidStatementId);
119 }
120 self.statements
121 .get(id)
122 .ok_or(RedshiftDataError::StatementNotFound)
123 }
124
125 pub fn cancel_statement(&mut self, id: &str) -> Result<bool, RedshiftDataError> {
126 if !is_valid_statement_id(id) {
127 return Err(RedshiftDataError::InvalidStatementId);
128 }
129 let statement = self
130 .statements
131 .get_mut(id)
132 .ok_or(RedshiftDataError::StatementNotFound)?;
133
134 match statement.status {
135 StatementStatus::Submitted | StatementStatus::Started => {
136 statement.status = StatementStatus::Aborted;
137 statement.updated_at = Utc::now();
138 Ok(true)
139 }
140 _ => {
141 Ok(false)
143 }
144 }
145 }
146
147 pub fn list_statements(&self) -> Vec<&Statement> {
148 let mut stmts: Vec<&Statement> = self.statements.values().collect();
149 stmts.sort_by_key(|s| std::cmp::Reverse(s.created_at));
150 stmts
151 }
152
153 pub fn batch_execute_statement(
154 &mut self,
155 sqls: Vec<String>,
156 database: &str,
157 cluster_identifier: Option<&str>,
158 workgroup_name: Option<&str>,
159 db_user: Option<&str>,
160 secret_arn: Option<&str>,
161 statement_name: Option<&str>,
162 result: crate::backend::StatementResult,
163 ) -> Result<String, RedshiftDataError> {
164 if sqls.is_empty() {
165 return Err(RedshiftDataError::SqlsRequired);
166 }
167
168 let id = uuid::Uuid::new_v4().to_string();
169 let now = Utc::now();
170 let query_string = sqls.first().cloned().unwrap_or_default();
171 let status = if result.error.is_some() {
172 StatementStatus::Failed
173 } else {
174 StatementStatus::Finished
175 };
176
177 let statement = Statement {
178 id: id.clone(),
179 sql: query_string.clone(),
180 cluster_identifier: cluster_identifier.map(String::from),
181 workgroup_name: workgroup_name.map(String::from),
182 database: database.to_string(),
183 db_user: db_user.map(String::from),
184 secret_arn: secret_arn.map(String::from),
185 status,
186 created_at: now,
187 updated_at: now,
188 result_rows: 0,
189 result_size: 0,
190 has_result_set: false,
191 query_string,
192 parameters: vec![],
193 sqls,
194 statement_name: statement_name.map(String::from),
195 is_batch: true,
196 result_columns: vec![],
197 result_data: vec![],
198 error_message: result.error,
199 };
200
201 self.statements.insert(id.clone(), statement);
202 Ok(id)
203 }
204
205 pub fn list_databases(&self) -> Vec<String> {
208 if self.databases.is_empty() {
209 vec!["dev".to_string(), "prod".to_string()]
210 } else {
211 self.databases.clone()
212 }
213 }
214
215 pub fn list_schemas(&self) -> Vec<String> {
216 if self.schemas.is_empty() {
217 vec!["public".to_string(), "information_schema".to_string()]
218 } else {
219 self.schemas.clone()
220 }
221 }
222
223 pub fn list_tables(&self) -> Vec<String> {
224 self.table_names.clone()
225 }
226
227 pub fn describe_table(&self, table: Option<&str>) -> Vec<(String, String)> {
228 if let Some(name) = table {
229 self.table_columns.get(name).cloned().unwrap_or_default()
230 } else {
231 vec![]
232 }
233 }
234}