Skip to main content

winterbaume_redshiftdata/
handlers.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use serde_json::{Value, json};
6use winterbaume_core::{
7    BackendState, MockRequest, MockResponse, MockService, StateChangeNotifier, default_account_id,
8};
9
10use crate::backend::{InMemoryRedshiftQueryBackend, RedshiftQueryBackend};
11use crate::model::{
12    CancelStatementResponse, ColumnMetadata, DescribeStatementResponse, ExecuteStatementOutput,
13    Field, GetStatementResultResponse, ListStatementsResponse, SqlParameter, StatementData,
14};
15use crate::state::{RedshiftDataError, RedshiftDataState};
16use crate::types::{Statement, StatementParameter};
17use crate::views::RedshiftDataStateView;
18use crate::wire;
19
20pub struct RedshiftDataService {
21    pub(crate) query_backend: Arc<dyn RedshiftQueryBackend>,
22    pub(crate) state: Arc<BackendState<RedshiftDataState>>,
23    pub(crate) notifier: StateChangeNotifier<RedshiftDataStateView>,
24}
25
26impl RedshiftDataService {
27    pub fn new() -> Self {
28        Self {
29            query_backend: Arc::new(InMemoryRedshiftQueryBackend),
30            state: Arc::new(BackendState::new()),
31            notifier: StateChangeNotifier::new(),
32        }
33    }
34
35    /// Construct with a custom query execution backend.
36    pub fn with_query_backend(query_backend: Arc<dyn RedshiftQueryBackend>) -> Self {
37        Self {
38            query_backend,
39            state: Arc::new(BackendState::new()),
40            notifier: StateChangeNotifier::new(),
41        }
42    }
43}
44
45impl Default for RedshiftDataService {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51impl MockService for RedshiftDataService {
52    fn service_name(&self) -> &str {
53        "redshift-data"
54    }
55
56    fn url_patterns(&self) -> Vec<&str> {
57        vec![
58            r"https?://redshift-data\.(.+)\.amazonaws\.com",
59            r"https?://redshift-data\.amazonaws\.com",
60        ]
61    }
62
63    fn handle(
64        &self,
65        request: MockRequest,
66    ) -> Pin<Box<dyn Future<Output = MockResponse> + Send + '_>> {
67        Box::pin(async move { self.dispatch(request).await })
68    }
69}
70
71impl RedshiftDataService {
72    async fn dispatch(&self, request: MockRequest) -> MockResponse {
73        let region = winterbaume_core::auth::extract_region_from_uri(&request.uri);
74        let account_id = default_account_id();
75
76        // Extract action from X-Amz-Target header
77        // Format: "RedshiftData.ExecuteStatement"
78        let action = request
79            .headers
80            .get("x-amz-target")
81            .and_then(|v| v.to_str().ok())
82            .and_then(|v| v.rsplit('.').next())
83            .map(|s| s.to_string());
84
85        let action = match action {
86            Some(a) => a,
87            None => {
88                return json_error_response(400, "MissingAction", "Missing X-Amz-Target header");
89            }
90        };
91
92        // Validate the body is well-formed JSON up-front; the typed deserialisers in
93        // `wire` re-parse the bytes per operation.
94        if serde_json::from_slice::<Value>(&request.body).is_err() {
95            return json_error_response(400, "SerializationException", "Invalid JSON body");
96        }
97        let body_bytes: &[u8] = &request.body;
98
99        let state = self.state.get(account_id, &region);
100
101        match action.as_str() {
102            "ExecuteStatement" => self.handle_execute_statement(&state, body_bytes).await,
103            "BatchExecuteStatement" => {
104                self.handle_batch_execute_statement(&state, body_bytes)
105                    .await
106            }
107            "DescribeStatement" => self.handle_describe_statement(&state, body_bytes).await,
108            "DescribeTable" => self.handle_describe_table(&state, body_bytes).await,
109            "CancelStatement" => self.handle_cancel_statement(&state, body_bytes).await,
110            "ListStatements" => self.handle_list_statements(&state).await,
111            "GetStatementResult" => self.handle_get_statement_result(&state, body_bytes).await,
112            "GetStatementResultV2" => {
113                self.handle_get_statement_result_v2(&state, body_bytes)
114                    .await
115            }
116            "ListDatabases" => self.handle_list_databases(&state, body_bytes).await,
117            "ListSchemas" => self.handle_list_schemas(&state, body_bytes).await,
118            "ListTables" => self.handle_list_tables(&state, body_bytes).await,
119            _ => json_error_response(
120                400,
121                "InvalidAction",
122                &format!("Could not find operation {action} for RedshiftData"),
123            ),
124        }
125    }
126
127    async fn handle_execute_statement(
128        &self,
129        state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
130        body: &[u8],
131    ) -> MockResponse {
132        let input = match wire::deserialize_execute_statement_request(body) {
133            Ok(v) => v,
134            Err(e) => return json_error_response(400, "ValidationException", &e),
135        };
136        if input.sql.is_empty() {
137            return json_error_response(400, "ValidationException", "Sql is required");
138        }
139        let database = match input.database.as_deref() {
140            Some(d) => d,
141            None => {
142                return json_error_response(400, "ValidationException", "Database is required");
143            }
144        };
145
146        let sql = input.sql.as_str();
147        let cluster_identifier = input.cluster_identifier.as_deref();
148        let workgroup_name = input.workgroup_name.as_deref();
149        let db_user = input.db_user.as_deref();
150        let secret_arn = input.secret_arn.as_deref();
151
152        let parameters: Vec<StatementParameter> = input
153            .parameters
154            .unwrap_or_default()
155            .into_iter()
156            .map(|p| StatementParameter {
157                name: p.name,
158                value: p.value,
159            })
160            .collect();
161
162        let result = self.query_backend.execute_statement(sql.to_string()).await;
163
164        let mut state = state.write().await;
165        match state.execute_statement(
166            sql,
167            database,
168            cluster_identifier,
169            workgroup_name,
170            db_user,
171            secret_arn,
172            parameters,
173            result,
174        ) {
175            Ok(id) => {
176                let output = ExecuteStatementOutput {
177                    id: Some(id),
178                    created_at: Some(chrono::Utc::now().timestamp() as f64),
179                    database: Some(database.to_string()),
180                    cluster_identifier: cluster_identifier.map(String::from),
181                    workgroup_name: workgroup_name.map(String::from),
182                    db_user: db_user.map(String::from),
183                    secret_arn: secret_arn.map(String::from),
184                    ..Default::default()
185                };
186                wire::serialize_execute_statement_response(&output)
187            }
188            Err(e) => redshiftdata_error_response(&e),
189        }
190    }
191
192    async fn handle_describe_statement(
193        &self,
194        state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
195        body: &[u8],
196    ) -> MockResponse {
197        let input = match wire::deserialize_describe_statement_request(body) {
198            Ok(v) => v,
199            Err(e) => return json_error_response(400, "ValidationException", &e),
200        };
201        if input.id.is_empty() {
202            return json_error_response(400, "ValidationException", "Id is required");
203        }
204        let id = input.id.as_str();
205
206        let state = state.read().await;
207        match state.describe_statement(id) {
208            Ok(stmt) => {
209                let resp = statement_to_describe_response(stmt);
210                wire::serialize_describe_statement_response(&resp)
211            }
212            Err(e) => redshiftdata_error_response(&e),
213        }
214    }
215
216    async fn handle_cancel_statement(
217        &self,
218        state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
219        body: &[u8],
220    ) -> MockResponse {
221        let input = match wire::deserialize_cancel_statement_request(body) {
222            Ok(v) => v,
223            Err(e) => return json_error_response(400, "ValidationException", &e),
224        };
225        if input.id.is_empty() {
226            return json_error_response(400, "ValidationException", "Id is required");
227        }
228        let id = input.id.as_str();
229
230        let mut state = state.write().await;
231        match state.cancel_statement(id) {
232            Ok(status) => {
233                let resp = CancelStatementResponse {
234                    status: Some(status),
235                };
236                wire::serialize_cancel_statement_response(&resp)
237            }
238            Err(e) => redshiftdata_error_response(&e),
239        }
240    }
241
242    async fn handle_list_statements(
243        &self,
244        state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
245    ) -> MockResponse {
246        let state = state.read().await;
247        let stmts = state.list_statements();
248        let entries: Vec<StatementData> = stmts
249            .iter()
250            .map(|s| statement_to_statement_data(s))
251            .collect();
252
253        let resp = ListStatementsResponse {
254            statements: Some(entries),
255            next_token: None,
256        };
257        wire::serialize_list_statements_response(&resp)
258    }
259
260    async fn handle_get_statement_result(
261        &self,
262        state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
263        body: &[u8],
264    ) -> MockResponse {
265        let input = match wire::deserialize_get_statement_result_request(body) {
266            Ok(v) => v,
267            Err(e) => return json_error_response(400, "ValidationException", &e),
268        };
269        if input.id.is_empty() {
270            return json_error_response(400, "ValidationException", "Id is required");
271        }
272        let id = input.id.as_str();
273
274        let state = state.read().await;
275        match state.describe_statement(id) {
276            Ok(stmt) => {
277                let column_metadata: Vec<ColumnMetadata> = stmt
278                    .result_columns
279                    .iter()
280                    .map(|(name, type_str)| ColumnMetadata {
281                        name: Some(name.clone()),
282                        type_name: Some(type_str.clone()),
283                        ..Default::default()
284                    })
285                    .collect();
286
287                let records: Vec<Vec<Field>> = stmt
288                    .result_data
289                    .iter()
290                    .map(|row| {
291                        row.iter()
292                            .zip(stmt.result_columns.iter())
293                            .map(|(cell, (_, type_str))| string_to_field(cell, type_str))
294                            .collect()
295                    })
296                    .collect();
297
298                let total = records.len() as i64;
299                let resp = GetStatementResultResponse {
300                    records: Some(records),
301                    column_metadata: Some(column_metadata),
302                    total_num_rows: Some(total),
303                    next_token: None,
304                };
305                wire::serialize_get_statement_result_response(&resp)
306            }
307            Err(e) => redshiftdata_error_response(&e),
308        }
309    }
310
311    async fn handle_batch_execute_statement(
312        &self,
313        state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
314        body: &[u8],
315    ) -> MockResponse {
316        let input = match wire::deserialize_batch_execute_statement_request(body) {
317            Ok(v) => v,
318            Err(e) => return json_error_response(400, "ValidationException", &e),
319        };
320        if input.sqls.is_empty() {
321            return json_error_response(400, "ValidationException", "Sqls is required");
322        }
323        let sqls: Vec<String> = input.sqls;
324
325        let database = match input.database.as_deref() {
326            Some(d) => d,
327            None => {
328                return json_error_response(400, "ValidationException", "Database is required");
329            }
330        };
331
332        let cluster_identifier = input.cluster_identifier.as_deref();
333        let workgroup_name = input.workgroup_name.as_deref();
334        let db_user = input.db_user.as_deref();
335        let secret_arn = input.secret_arn.as_deref();
336        let statement_name = input.statement_name.as_deref();
337
338        let result = self.query_backend.batch_execute(sqls.clone()).await;
339
340        let mut state = state.write().await;
341        match state.batch_execute_statement(
342            sqls.clone(),
343            database,
344            cluster_identifier,
345            workgroup_name,
346            db_user,
347            secret_arn,
348            statement_name,
349            result,
350        ) {
351            Ok(id) => {
352                let output = crate::model::BatchExecuteStatementOutput {
353                    id: Some(id),
354                    created_at: Some(chrono::Utc::now().timestamp() as f64),
355                    database: Some(database.to_string()),
356                    cluster_identifier: cluster_identifier.map(String::from),
357                    workgroup_name: workgroup_name.map(String::from),
358                    db_user: db_user.map(String::from),
359                    secret_arn: secret_arn.map(String::from),
360                    ..Default::default()
361                };
362                wire::serialize_batch_execute_statement_response(&output)
363            }
364            Err(e) => redshiftdata_error_response(&e),
365        }
366    }
367
368    async fn handle_describe_table(
369        &self,
370        state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
371        body: &[u8],
372    ) -> MockResponse {
373        let input = match wire::deserialize_describe_table_request(body) {
374            Ok(v) => v,
375            Err(e) => return json_error_response(400, "ValidationException", &e),
376        };
377        let table_name = input.table.as_deref();
378
379        let state = state.read().await;
380        let columns = state.describe_table(table_name);
381        let column_list: Vec<ColumnMetadata> = columns
382            .iter()
383            .map(|(name, type_str)| ColumnMetadata {
384                name: Some(name.clone()),
385                type_name: Some(type_str.clone()),
386                ..Default::default()
387            })
388            .collect();
389
390        let resp = crate::model::DescribeTableResponse {
391            column_list: Some(column_list),
392            next_token: None,
393            table_name: table_name.map(String::from),
394        };
395        wire::serialize_describe_table_response(&resp)
396    }
397
398    async fn handle_get_statement_result_v2(
399        &self,
400        state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
401        body: &[u8],
402    ) -> MockResponse {
403        let input = match wire::deserialize_get_statement_result_v2_request(body) {
404            Ok(v) => v,
405            Err(e) => return json_error_response(400, "ValidationException", &e),
406        };
407        if input.id.is_empty() {
408            return json_error_response(400, "ValidationException", "Id is required");
409        }
410        let id = input.id.as_str();
411
412        let state = state.read().await;
413        match state.describe_statement(id) {
414            Ok(stmt) => {
415                let column_metadata: Vec<ColumnMetadata> = stmt
416                    .result_columns
417                    .iter()
418                    .map(|(name, type_str)| ColumnMetadata {
419                        name: Some(name.clone()),
420                        type_name: Some(type_str.clone()),
421                        ..Default::default()
422                    })
423                    .collect();
424
425                // V2 format: each record row is serialised as CSV
426                let records: Vec<crate::model::QueryRecords> = stmt
427                    .result_data
428                    .iter()
429                    .map(|row| {
430                        let csv_line: String = row
431                            .iter()
432                            .map(|cell| cell.as_deref().unwrap_or(""))
433                            .collect::<Vec<&str>>()
434                            .join(",");
435                        crate::model::QueryRecords {
436                            c_s_v_records: Some(csv_line),
437                        }
438                    })
439                    .collect();
440
441                let total = records.len() as i64;
442                let resp = crate::model::GetStatementResultV2Response {
443                    column_metadata: Some(column_metadata),
444                    next_token: None,
445                    records: Some(records),
446                    result_format: Some("CSV".to_string()),
447                    total_num_rows: Some(total),
448                };
449                wire::serialize_get_statement_result_v2_response(&resp)
450            }
451            Err(e) => redshiftdata_error_response(&e),
452        }
453    }
454
455    async fn handle_list_databases(
456        &self,
457        state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
458        body: &[u8],
459    ) -> MockResponse {
460        if let Err(e) = wire::deserialize_list_databases_request(body) {
461            return json_error_response(400, "ValidationException", &e);
462        }
463        let state = state.read().await;
464        let databases = state.list_databases();
465        let resp = crate::model::ListDatabasesResponse {
466            databases: Some(databases),
467            next_token: None,
468        };
469        wire::serialize_list_databases_response(&resp)
470    }
471
472    async fn handle_list_schemas(
473        &self,
474        state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
475        body: &[u8],
476    ) -> MockResponse {
477        if let Err(e) = wire::deserialize_list_schemas_request(body) {
478            return json_error_response(400, "ValidationException", &e);
479        }
480        let state = state.read().await;
481        let schemas = state.list_schemas();
482        let resp = crate::model::ListSchemasResponse {
483            schemas: Some(schemas),
484            next_token: None,
485        };
486        wire::serialize_list_schemas_response(&resp)
487    }
488
489    async fn handle_list_tables(
490        &self,
491        state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
492        body: &[u8],
493    ) -> MockResponse {
494        if let Err(e) = wire::deserialize_list_tables_request(body) {
495            return json_error_response(400, "ValidationException", &e);
496        }
497        let state = state.read().await;
498        let table_names = state.list_tables();
499        let tables: Vec<crate::model::TableMember> = table_names
500            .iter()
501            .map(|name| crate::model::TableMember {
502                name: Some(name.clone()),
503                ..Default::default()
504            })
505            .collect();
506        let resp = crate::model::ListTablesResponse {
507            tables: Some(tables),
508            next_token: None,
509        };
510        wire::serialize_list_tables_response(&resp)
511    }
512}
513
514/// Convert a state Statement to a wire DescribeStatementResponse.
515fn statement_to_describe_response(stmt: &Statement) -> DescribeStatementResponse {
516    let query_parameters: Option<Vec<SqlParameter>> = if stmt.parameters.is_empty() {
517        None
518    } else {
519        Some(
520            stmt.parameters
521                .iter()
522                .map(|p| SqlParameter {
523                    name: p.name.clone(),
524                    value: p.value.clone(),
525                })
526                .collect(),
527        )
528    };
529
530    DescribeStatementResponse {
531        id: Some(stmt.id.clone()),
532        status: Some(stmt.status.as_str().to_string()),
533        created_at: Some(stmt.created_at.timestamp() as f64),
534        updated_at: Some(stmt.updated_at.timestamp() as f64),
535        query_string: Some(stmt.query_string.clone()),
536        database: Some(stmt.database.clone()),
537        has_result_set: Some(stmt.has_result_set),
538        result_rows: Some(stmt.result_rows),
539        result_size: Some(stmt.result_size),
540        duration: Some(0),
541        cluster_identifier: stmt.cluster_identifier.clone(),
542        workgroup_name: stmt.workgroup_name.clone(),
543        db_user: stmt.db_user.clone(),
544        secret_arn: stmt.secret_arn.clone(),
545        query_parameters,
546        ..Default::default()
547    }
548}
549
550/// Convert a state Statement to a wire StatementData for list responses.
551fn statement_to_statement_data(stmt: &Statement) -> StatementData {
552    StatementData {
553        id: Some(stmt.id.clone()),
554        status: Some(stmt.status.as_str().to_string()),
555        created_at: Some(stmt.created_at.timestamp() as f64),
556        updated_at: Some(stmt.updated_at.timestamp() as f64),
557        query_string: Some(stmt.query_string.clone()),
558        statement_name: Some(String::new()),
559        is_batch_statement: Some(false),
560        secret_arn: stmt.secret_arn.clone(),
561        ..Default::default()
562    }
563}
564
565/// Convert a stored string cell to a typed [`Field`] using the column's
566/// declared type string (as returned by the query backend).
567fn string_to_field(value: &Option<String>, type_str: &str) -> Field {
568    match value {
569        None => Field {
570            is_null: Some(true),
571            ..Default::default()
572        },
573        Some(s) => {
574            let t = type_str.to_ascii_lowercase();
575            if t.contains("int") || t.contains("bigint") {
576                if let Ok(n) = s.parse::<i64>() {
577                    return Field {
578                        long_value: Some(n),
579                        ..Default::default()
580                    };
581                }
582            }
583            if t.contains("float") || t.contains("double") || t.contains("real") {
584                if let Ok(f) = s.parse::<f64>() {
585                    return Field {
586                        double_value: Some(f),
587                        ..Default::default()
588                    };
589                }
590            }
591            if t.contains("bool") {
592                let b = s == "true" || s == "1";
593                return Field {
594                    boolean_value: Some(b),
595                    ..Default::default()
596                };
597            }
598            Field {
599                string_value: Some(s.clone()),
600                ..Default::default()
601            }
602        }
603    }
604}
605
606fn redshiftdata_error_response(err: &RedshiftDataError) -> MockResponse {
607    let (status, error_type) = match err {
608        RedshiftDataError::SqlRequired
609        | RedshiftDataError::SqlsRequired
610        | RedshiftDataError::InvalidStatementId => (400, "ValidationException"),
611        RedshiftDataError::StatementNotFound => (404, "ResourceNotFoundException"),
612    };
613    let body = json!({
614        "__type": error_type,
615        "message": err.to_string(),
616    });
617    MockResponse::json(status, body.to_string())
618}
619
620fn json_error_response(status: u16, code: &str, message: &str) -> MockResponse {
621    let body = json!({
622        "__type": code,
623        "message": message,
624    });
625    MockResponse::json(status, body.to_string())
626}