rustack_sts_http/
service.rs1use std::{convert::Infallible, future::Future, pin::Pin, sync::Arc};
8
9use bytes::Bytes;
10use http_body_util::BodyExt;
11use hyper::body::Incoming;
12use rustack_sts_model::error::StsError;
13
14use crate::{
15 body::StsResponseBody,
16 dispatch::{StsHandler, dispatch_operation},
17 request::{extract_access_key_from_auth, parse_form_params},
18 response::{CONTENT_TYPE, error_to_response},
19 router::resolve_operation,
20};
21
22#[derive(Clone)]
24pub struct StsHttpConfig {
25 pub skip_signature_validation: bool,
27 pub region: String,
29 pub credential_provider: Option<Arc<dyn rustack_auth::CredentialProvider>>,
31}
32
33impl std::fmt::Debug for StsHttpConfig {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("StsHttpConfig")
36 .field("skip_signature_validation", &self.skip_signature_validation)
37 .field("region", &self.region)
38 .field(
39 "credential_provider",
40 &self.credential_provider.as_ref().map(|_| "..."),
41 )
42 .finish()
43 }
44}
45
46impl Default for StsHttpConfig {
47 fn default() -> Self {
48 Self {
49 skip_signature_validation: true,
50 region: "us-east-1".to_owned(),
51 credential_provider: None,
52 }
53 }
54}
55
56#[derive(Debug)]
61pub struct StsHttpService<H: StsHandler> {
62 handler: Arc<H>,
63 config: Arc<StsHttpConfig>,
64}
65
66impl<H: StsHandler> StsHttpService<H> {
67 pub fn new(handler: Arc<H>, config: StsHttpConfig) -> Self {
69 Self {
70 handler,
71 config: Arc::new(config),
72 }
73 }
74}
75
76impl<H: StsHandler> Clone for StsHttpService<H> {
77 fn clone(&self) -> Self {
78 Self {
79 handler: Arc::clone(&self.handler),
80 config: Arc::clone(&self.config),
81 }
82 }
83}
84
85impl<H: StsHandler> hyper::service::Service<http::Request<Incoming>> for StsHttpService<H> {
86 type Response = http::Response<StsResponseBody>;
87 type Error = Infallible;
88 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
89
90 fn call(&self, req: http::Request<Incoming>) -> Self::Future {
91 let handler = Arc::clone(&self.handler);
92 let config = Arc::clone(&self.config);
93 let request_id = uuid::Uuid::new_v4().to_string();
94
95 Box::pin(async move {
96 let response = process_request(req, handler.as_ref(), &config, &request_id).await;
97 let response = add_common_headers(response, &request_id);
98 Ok(response)
99 })
100 }
101}
102
103async fn process_request<H: StsHandler>(
114 req: http::Request<Incoming>,
115 handler: &H,
116 config: &StsHttpConfig,
117 request_id: &str,
118) -> http::Response<StsResponseBody> {
119 let (parts, incoming) = req.into_parts();
120
121 if parts.method != http::Method::POST {
123 let err =
124 StsError::invalid_action(format!("STS requires POST method, got {}", parts.method));
125 return error_to_response(&err, request_id);
126 }
127
128 let body = match collect_body(incoming).await {
130 Ok(body) => body,
131 Err(err) => return error_to_response(&err, request_id),
132 };
133
134 let params = parse_form_params(&body);
136
137 let op = match resolve_operation(¶ms) {
139 Ok(op) => op,
140 Err(err) => return error_to_response(&err, request_id),
141 };
142
143 let caller_access_key = parts
145 .headers
146 .get("authorization")
147 .and_then(|v| v.to_str().ok())
148 .and_then(extract_access_key_from_auth);
149
150 if !config.skip_signature_validation {
152 if let Some(ref cred_provider) = config.credential_provider {
153 let body_hash = rustack_auth::hash_payload(&body);
154 if let Err(auth_err) =
155 rustack_auth::verify_sigv4(&parts, &body_hash, cred_provider.as_ref())
156 {
157 let err = StsError::invalid_client_token_id(auth_err.to_string());
158 return error_to_response(&err, request_id);
159 }
160 }
161 }
162
163 match dispatch_operation(handler, op, body, caller_access_key, request_id).await {
165 Ok(response) => response,
166 Err(err) => error_to_response(&err, request_id),
167 }
168}
169
170async fn collect_body(incoming: Incoming) -> Result<Bytes, StsError> {
172 incoming
173 .collect()
174 .await
175 .map(http_body_util::Collected::to_bytes)
176 .map_err(|e| StsError::internal_error(format!("Failed to read request body: {e}")))
177}
178
179fn add_common_headers(
181 mut response: http::Response<StsResponseBody>,
182 request_id: &str,
183) -> http::Response<StsResponseBody> {
184 let headers = response.headers_mut();
185
186 if let Ok(hv) = http::HeaderValue::from_str(request_id) {
187 headers.entry("x-amzn-requestid").or_insert(hv);
188 }
189
190 headers
191 .entry("content-type")
192 .or_insert(http::HeaderValue::from_static(CONTENT_TYPE));
193
194 headers.insert("server", http::HeaderValue::from_static("Rustack"));
195
196 headers.insert(
198 "access-control-allow-origin",
199 http::HeaderValue::from_static("*"),
200 );
201
202 response
203}