Skip to main content

rustack_sts_http/
service.rs

1//! STS HTTP service implementing the hyper `Service` trait.
2//!
3//! STS uses the `awsQuery` protocol where the request body is
4//! `application/x-www-form-urlencoded` and the response is `text/xml`.
5//! The `Action=` form parameter routes to the appropriate operation.
6
7use std::{convert::Infallible, future::Future, pin::Pin, sync::Arc};
8
9use bytes::Bytes;
10use http_body_util::BodyExt;
11use hyper::body::Incoming;
12use rustack_sts_model::error::StsError;
13
14use crate::{
15    body::StsResponseBody,
16    dispatch::{StsHandler, dispatch_operation},
17    request::{extract_access_key_from_auth, parse_form_params},
18    response::{CONTENT_TYPE, error_to_response},
19    router::resolve_operation,
20};
21
22/// Configuration for the STS HTTP service.
23#[derive(Clone)]
24pub struct StsHttpConfig {
25    /// Whether to skip AWS signature validation.
26    pub skip_signature_validation: bool,
27    /// The AWS region this service is running in.
28    pub region: String,
29    /// Credential provider for signature validation.
30    pub credential_provider: Option<Arc<dyn rustack_auth::CredentialProvider>>,
31}
32
33impl std::fmt::Debug for StsHttpConfig {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("StsHttpConfig")
36            .field("skip_signature_validation", &self.skip_signature_validation)
37            .field("region", &self.region)
38            .field(
39                "credential_provider",
40                &self.credential_provider.as_ref().map(|_| "..."),
41            )
42            .finish()
43    }
44}
45
46impl Default for StsHttpConfig {
47    fn default() -> Self {
48        Self {
49            skip_signature_validation: true,
50            region: "us-east-1".to_owned(),
51            credential_provider: None,
52        }
53    }
54}
55
56/// Hyper `Service` implementation for STS.
57///
58/// Wraps an [`StsHandler`] implementation and routes incoming HTTP
59/// requests to the appropriate STS operation handler.
60#[derive(Debug)]
61pub struct StsHttpService<H: StsHandler> {
62    handler: Arc<H>,
63    config: Arc<StsHttpConfig>,
64}
65
66impl<H: StsHandler> StsHttpService<H> {
67    /// Create a new `StsHttpService`.
68    pub fn new(handler: Arc<H>, config: StsHttpConfig) -> Self {
69        Self {
70            handler,
71            config: Arc::new(config),
72        }
73    }
74}
75
76impl<H: StsHandler> Clone for StsHttpService<H> {
77    fn clone(&self) -> Self {
78        Self {
79            handler: Arc::clone(&self.handler),
80            config: Arc::clone(&self.config),
81        }
82    }
83}
84
85impl<H: StsHandler> hyper::service::Service<http::Request<Incoming>> for StsHttpService<H> {
86    type Response = http::Response<StsResponseBody>;
87    type Error = Infallible;
88    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
89
90    fn call(&self, req: http::Request<Incoming>) -> Self::Future {
91        let handler = Arc::clone(&self.handler);
92        let config = Arc::clone(&self.config);
93        let request_id = uuid::Uuid::new_v4().to_string();
94
95        Box::pin(async move {
96            let response = process_request(req, handler.as_ref(), &config, &request_id).await;
97            let response = add_common_headers(response, &request_id);
98            Ok(response)
99        })
100    }
101}
102
103/// Process a single STS HTTP request through the full pipeline.
104///
105/// Pipeline:
106/// 1. Verify POST method
107/// 2. Collect body
108/// 3. Parse form params from body
109/// 4. Resolve operation from `Action=` param
110/// 5. Extract caller access key from Authorization header
111/// 6. Authenticate (if enabled)
112/// 7. Dispatch to handler
113async fn process_request<H: StsHandler>(
114    req: http::Request<Incoming>,
115    handler: &H,
116    config: &StsHttpConfig,
117    request_id: &str,
118) -> http::Response<StsResponseBody> {
119    let (parts, incoming) = req.into_parts();
120
121    // 1. Verify POST method.
122    if parts.method != http::Method::POST {
123        let err =
124            StsError::invalid_action(format!("STS requires POST method, got {}", parts.method));
125        return error_to_response(&err, request_id);
126    }
127
128    // 2. Collect body.
129    let body = match collect_body(incoming).await {
130        Ok(body) => body,
131        Err(err) => return error_to_response(&err, request_id),
132    };
133
134    // 3. Parse form params.
135    let params = parse_form_params(&body);
136
137    // 4. Resolve operation.
138    let op = match resolve_operation(&params) {
139        Ok(op) => op,
140        Err(err) => return error_to_response(&err, request_id),
141    };
142
143    // 5. Extract caller access key from Authorization header.
144    let caller_access_key = parts
145        .headers
146        .get("authorization")
147        .and_then(|v| v.to_str().ok())
148        .and_then(extract_access_key_from_auth);
149
150    // 6. Authenticate (if enabled).
151    if !config.skip_signature_validation {
152        if let Some(ref cred_provider) = config.credential_provider {
153            let body_hash = rustack_auth::hash_payload(&body);
154            if let Err(auth_err) =
155                rustack_auth::verify_sigv4(&parts, &body_hash, cred_provider.as_ref())
156            {
157                let err = StsError::invalid_client_token_id(auth_err.to_string());
158                return error_to_response(&err, request_id);
159            }
160        }
161    }
162
163    // 7. Dispatch to handler.
164    match dispatch_operation(handler, op, body, caller_access_key, request_id).await {
165        Ok(response) => response,
166        Err(err) => error_to_response(&err, request_id),
167    }
168}
169
170/// Collect the incoming body into a single `Bytes` buffer.
171async fn collect_body(incoming: Incoming) -> Result<Bytes, StsError> {
172    incoming
173        .collect()
174        .await
175        .map(http_body_util::Collected::to_bytes)
176        .map_err(|e| StsError::internal_error(format!("Failed to read request body: {e}")))
177}
178
179/// Add common response headers to every STS response.
180fn add_common_headers(
181    mut response: http::Response<StsResponseBody>,
182    request_id: &str,
183) -> http::Response<StsResponseBody> {
184    let headers = response.headers_mut();
185
186    if let Ok(hv) = http::HeaderValue::from_str(request_id) {
187        headers.entry("x-amzn-requestid").or_insert(hv);
188    }
189
190    headers
191        .entry("content-type")
192        .or_insert(http::HeaderValue::from_static(CONTENT_TYPE));
193
194    headers.insert("server", http::HeaderValue::from_static("Rustack"));
195
196    // CORS headers.
197    headers.insert(
198        "access-control-allow-origin",
199        http::HeaderValue::from_static("*"),
200    );
201
202    response
203}