1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use http::header::HeaderName;
7use serde_json::json;
8use winterbaume_core::{
9 BackendState, MockRequest, MockResponse, MockService, StateChangeNotifier, StatefulService,
10 default_account_id,
11};
12
13use crate::state::{SdbError, SdbState};
14use crate::views::SdbStateView;
15use crate::wire;
16
/// Response header that carries the AWS error code (for example
/// `NoSuchExportException`) on REST-JSON error responses built in this module.
const X_AMZN_ERRORTYPE: HeaderName = HeaderName::from_static("x-amzn-errortype");
18
/// Mock of the SimpleDB v2 REST-JSON API (`StartDomainExport`, `GetExport`,
/// `ListExports`).
pub struct SimpleDbV2Service {
    /// Backend state, looked up per account id and region via `get(account_id, region)`.
    pub(crate) state: Arc<BackendState<SdbState>>,
    /// Notifier parameterized over `SdbStateView`; presumably used to publish
    /// state-change events to observers — TODO confirm against `notify_state_changed`.
    pub(crate) notifier: StateChangeNotifier<SdbStateView>,
}
23
24impl SimpleDbV2Service {
25 pub fn new() -> Self {
26 Self {
27 state: Arc::new(BackendState::new()),
28 notifier: StateChangeNotifier::new(),
29 }
30 }
31
32 pub async fn with_domain(self, region: &str, domain_name: &str) -> Self {
35 let state = self.state.get(default_account_id(), region);
36 state.write().await.add_domain(domain_name);
37 self
38 }
39}
40
41impl Default for SimpleDbV2Service {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl MockService for SimpleDbV2Service {
48 fn service_name(&self) -> &str {
49 "sdb"
50 }
51
52 fn url_patterns(&self) -> Vec<&str> {
53 vec![
54 r"https?://sdb\.(.+)\.amazonaws\.com",
55 r"https?://sdb\.amazonaws\.com",
56 ]
57 }
58
59 fn handle(
60 &self,
61 request: MockRequest,
62 ) -> Pin<Box<dyn Future<Output = MockResponse> + Send + '_>> {
63 Box::pin(async move { self.dispatch(request).await })
64 }
65}
66
67impl SimpleDbV2Service {
68 async fn dispatch(&self, request: MockRequest) -> MockResponse {
69 let region = winterbaume_core::auth::extract_region_from_uri(&request.uri);
70 let account_id = default_account_id();
71 let state = self.state.get(account_id, ®ion);
72
73 let path = extract_path(&request.uri);
75 let raw_query = extract_query_string(&request.uri);
76 let query_map: HashMap<String, String> = winterbaume_core::parse_query_string(&raw_query);
77
78 let response = match path.as_str() {
79 "/v2/StartDomainExport" => {
80 self.handle_start_domain_export(
81 &state,
82 &request,
83 &[],
84 &query_map,
85 account_id,
86 ®ion,
87 )
88 .await
89 }
90 "/v2/GetExport" => {
91 self.handle_get_export(&state, &request, &[], &query_map)
92 .await
93 }
94 "/v2/ListExports" => {
95 self.handle_list_exports(&state, &request, &[], &query_map)
96 .await
97 }
98 _ => rest_json_error(404, "UnknownOperationException", "Not found"),
99 };
100 if response.status / 100 == 2 {
101 self.notify_state_changed(account_id, ®ion).await;
102 }
103 response
104 }
105
106 #[allow(clippy::too_many_arguments)]
107 async fn handle_start_domain_export(
108 &self,
109 state: &Arc<tokio::sync::RwLock<SdbState>>,
110 request: &MockRequest,
111 labels: &[(&str, &str)],
112 query: &HashMap<String, String>,
113 account_id: &str,
114 region: &str,
115 ) -> MockResponse {
116 let input = match wire::deserialize_start_domain_export_request(request, labels, query) {
117 Ok(v) => v,
118 Err(_) => return rest_json_error(400, "SerializationException", "Invalid JSON body"),
119 };
120 if input.domain_name.is_empty() {
121 return rest_json_error(
122 400,
123 "InvalidParameterValueException",
124 "Missing 'domainName'",
125 );
126 }
127 if input.s3_bucket.is_empty() {
128 return rest_json_error(400, "InvalidParameterValueException", "Missing 's3Bucket'");
129 }
130 let client_token = match input.client_token.as_deref() {
131 Some(t) if !t.is_empty() => t,
132 _ => {
133 return rest_json_error(
134 400,
135 "InvalidParameterValueException",
136 "clientToken is required",
137 );
138 }
139 };
140
141 let mut state = state.write().await;
142 match state.start_domain_export(
143 &input.domain_name,
144 &input.s3_bucket,
145 input.s3_key_prefix.as_deref(),
146 input.s3_sse_algorithm.as_deref(),
147 input.s3_sse_kms_key_id.as_deref(),
148 input.s3_bucket_owner.as_deref(),
149 Some(client_token),
150 account_id,
151 region,
152 ) {
153 Ok(export) => {
154 wire::serialize_start_domain_export_response(&wire::StartDomainExportResponse {
155 client_token: Some(export.client_token.clone()),
156 export_arn: Some(export.export_arn.clone()),
157 requested_at: Some(export.requested_at.timestamp() as f64),
158 })
159 }
160 Err(e) => sdb_error_response(&e),
161 }
162 }
163
164 async fn handle_get_export(
165 &self,
166 state: &Arc<tokio::sync::RwLock<SdbState>>,
167 request: &MockRequest,
168 labels: &[(&str, &str)],
169 query: &HashMap<String, String>,
170 ) -> MockResponse {
171 let input = match wire::deserialize_get_export_request(request, labels, query) {
172 Ok(v) => v,
173 Err(_) => return rest_json_error(400, "SerializationException", "Invalid JSON body"),
174 };
175 if input.export_arn.is_empty() {
176 return rest_json_error(400, "InvalidParameterValueException", "Missing 'exportArn'");
177 }
178
179 let state = state.read().await;
180 match state.get_export(&input.export_arn) {
181 Ok(export) => wire::serialize_get_export_response(&wire::GetExportResponse {
182 export_arn: Some(export.export_arn.clone()),
183 client_token: Some(export.client_token.clone()),
184 export_status: Some(export.export_status.clone()),
185 domain_name: Some(export.domain_name.clone()),
186 requested_at: Some(export.requested_at.timestamp() as f64),
187 s3_bucket: Some(export.s3_bucket.clone()),
188 s3_key_prefix: export.s3_key_prefix.clone(),
189 s3_sse_algorithm: export.s3_sse_algorithm.clone(),
190 s3_sse_kms_key_id: export.s3_sse_kms_key_id.clone(),
191 s3_bucket_owner: export.s3_bucket_owner.clone(),
192 failure_code: export.failure_code.clone(),
193 failure_message: export.failure_message.clone(),
194 export_manifest: export.export_manifest.clone(),
195 items_count: export.items_count,
196 export_data_cutoff_time: export
197 .export_data_cutoff_time
198 .map(|dt| dt.timestamp() as f64),
199 }),
200 Err(e) => sdb_error_response(&e),
201 }
202 }
203
204 async fn handle_list_exports(
205 &self,
206 state: &Arc<tokio::sync::RwLock<SdbState>>,
207 request: &MockRequest,
208 labels: &[(&str, &str)],
209 query: &HashMap<String, String>,
210 ) -> MockResponse {
211 let input = match wire::deserialize_list_exports_request(request, labels, query) {
212 Ok(v) => v,
213 Err(_) => return rest_json_error(400, "SerializationException", "Invalid JSON body"),
214 };
215
216 let state = state.read().await;
217 match state.list_exports(
218 input.domain_name.as_deref(),
219 input.max_results,
220 input.next_token.as_deref(),
221 ) {
222 Ok((summaries, next_token)) => {
223 let entries: Vec<wire::ExportSummary> = summaries
224 .iter()
225 .map(|s| wire::ExportSummary {
226 export_arn: Some(s.export_arn.clone()),
227 export_status: Some(s.export_status.clone()),
228 requested_at: Some(s.requested_at.timestamp() as f64),
229 domain_name: Some(s.domain_name.clone()),
230 })
231 .collect();
232
233 wire::serialize_list_exports_response(&wire::ListExportsResponse {
234 export_summaries: Some(entries),
235 next_token,
236 })
237 }
238 Err(e) => sdb_error_response(&e),
239 }
240 }
241}
242
243fn sdb_error_response(err: &SdbError) -> MockResponse {
244 let (status, error_type) = match err {
245 SdbError::NoSuchDomain { .. } => (400, "NoSuchDomainException"),
246 SdbError::NoSuchExport { .. } => (400, "NoSuchExportException"),
247 SdbError::Conflict => (400, "ConflictException"),
248 };
249 let body = json!({
250 "Type": "User",
251 "Message": err.to_string(),
252 });
253 let mut resp = MockResponse::rest_json(status, body.to_string());
254 resp.headers
255 .insert(X_AMZN_ERRORTYPE, error_type.parse().unwrap());
256 resp
257}
258
259fn rest_json_error(status: u16, code: &str, message: &str) -> MockResponse {
260 let body = json!({
261 "Type": "User",
262 "Message": message,
263 });
264 let mut resp = MockResponse::rest_json(status, body.to_string());
265 resp.headers.insert(X_AMZN_ERRORTYPE, code.parse().unwrap());
266 resp
267}
268
269fn extract_path(uri: &str) -> String {
270 if let Ok(parsed) = uri.parse::<http::Uri>() {
272 parsed.path().to_string()
273 } else {
274 if let Some(pos) = uri.find("amazonaws.com") {
276 let rest = &uri[pos + "amazonaws.com".len()..];
277 rest.split('?').next().unwrap_or("/").to_string()
278 } else {
279 "/".to_string()
280 }
281 }
282}
283
284fn extract_query_string(uri: &str) -> String {
285 if let Ok(parsed) = uri.parse::<http::Uri>() {
286 parsed.query().unwrap_or("").to_string()
287 } else if let Some(idx) = uri.find('?') {
288 uri[idx + 1..].to_string()
289 } else {
290 String::new()
291 }
292}