Skip to main content

rustack_ses_http/
service.rs

1//! SES HTTP service implementing the hyper `Service` trait.
2//!
3//! SES v1 uses the `awsQuery` protocol (form-urlencoded request, XML response).
4//! SES v2 uses `restJson1` (JSON request/response, path-based routing).
5
6use std::{convert::Infallible, future::Future, pin::Pin, sync::Arc};
7
8use bytes::Bytes;
9use http_body_util::BodyExt;
10use hyper::body::Incoming;
11use rustack_ses_model::error::SesError;
12
13use crate::{
14    body::SesResponseBody,
15    dispatch::{SesHandler, dispatch_operation},
16    request::parse_form_params,
17    response::{XML_CONTENT_TYPE, error_to_response},
18    router::resolve_operation,
19};
20
21/// Configuration for the SES HTTP service.
22#[derive(Clone)]
23pub struct SesHttpConfig {
24    /// Whether to skip AWS signature validation.
25    pub skip_signature_validation: bool,
26    /// The AWS region this service is running in.
27    pub region: String,
28    /// Credential provider for signature validation.
29    pub credential_provider: Option<Arc<dyn rustack_auth::CredentialProvider>>,
30}
31
32impl std::fmt::Debug for SesHttpConfig {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("SesHttpConfig")
35            .field("skip_signature_validation", &self.skip_signature_validation)
36            .field("region", &self.region)
37            .field(
38                "credential_provider",
39                &self.credential_provider.as_ref().map(|_| "..."),
40            )
41            .finish()
42    }
43}
44
45impl Default for SesHttpConfig {
46    fn default() -> Self {
47        Self {
48            skip_signature_validation: true,
49            region: "us-east-1".to_owned(),
50            credential_provider: None,
51        }
52    }
53}
54
55/// Hyper `Service` implementation for SES v1 (awsQuery).
56#[derive(Debug)]
57pub struct SesHttpService<H: SesHandler> {
58    handler: Arc<H>,
59    config: Arc<SesHttpConfig>,
60}
61
62impl<H: SesHandler> SesHttpService<H> {
63    /// Create a new `SesHttpService`.
64    pub fn new(handler: Arc<H>, config: SesHttpConfig) -> Self {
65        Self {
66            handler,
67            config: Arc::new(config),
68        }
69    }
70}
71
72impl<H: SesHandler> Clone for SesHttpService<H> {
73    fn clone(&self) -> Self {
74        Self {
75            handler: Arc::clone(&self.handler),
76            config: Arc::clone(&self.config),
77        }
78    }
79}
80
81impl<H: SesHandler> hyper::service::Service<http::Request<Incoming>> for SesHttpService<H> {
82    type Response = http::Response<SesResponseBody>;
83    type Error = Infallible;
84    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
85
86    fn call(&self, req: http::Request<Incoming>) -> Self::Future {
87        let handler = Arc::clone(&self.handler);
88        let config = Arc::clone(&self.config);
89        let request_id = uuid::Uuid::new_v4().to_string();
90
91        Box::pin(async move {
92            let response = process_request(req, handler.as_ref(), &config, &request_id).await;
93            let response = add_common_headers(response, &request_id);
94            Ok(response)
95        })
96    }
97}
98
99/// Process a single SES v1 HTTP request through the full pipeline.
100///
101/// Pipeline:
102/// 1. Verify POST method (SES v1 only accepts POST)
103/// 2. Collect body
104/// 3. Parse form params from body
105/// 4. Resolve operation from `Action=` param
106/// 5. Authenticate (if enabled)
107/// 6. Dispatch to handler (pass raw body bytes)
108async fn process_request<H: SesHandler>(
109    req: http::Request<Incoming>,
110    handler: &H,
111    config: &SesHttpConfig,
112    request_id: &str,
113) -> http::Response<SesResponseBody> {
114    let (parts, incoming) = req.into_parts();
115
116    // 1. Verify POST method (SES v1 only accepts POST).
117    if parts.method != http::Method::POST {
118        let err = SesError::invalid_parameter_value(format!(
119            "SES requires POST method, got {}",
120            parts.method
121        ));
122        return error_to_response(&err, request_id);
123    }
124
125    // 2. Collect body.
126    let body = match collect_body(incoming).await {
127        Ok(body) => body,
128        Err(err) => return error_to_response(&err, request_id),
129    };
130
131    // 3. Parse form params to extract Action for routing.
132    let params = parse_form_params(&body);
133
134    // 4. Resolve operation from Action= param.
135    let op = match resolve_operation(&params) {
136        Ok(op) => op,
137        Err(err) => return error_to_response(&err, request_id),
138    };
139
140    // 5. Authenticate (if enabled).
141    if !config.skip_signature_validation {
142        if let Some(ref cred_provider) = config.credential_provider {
143            let body_hash = rustack_auth::hash_payload(&body);
144            if let Err(auth_err) =
145                rustack_auth::verify_sigv4(&parts, &body_hash, cred_provider.as_ref())
146            {
147                let err = SesError::internal_error(auth_err.to_string());
148                return error_to_response(&err, request_id);
149            }
150        }
151    }
152
153    // 6. Dispatch to handler (pass raw body so handler can re-parse as needed).
154    match dispatch_operation(handler, op, body).await {
155        Ok(response) => response,
156        Err(err) => error_to_response(&err, request_id),
157    }
158}
159
160/// Collect the incoming body into a single `Bytes` buffer.
161async fn collect_body(incoming: Incoming) -> Result<Bytes, SesError> {
162    incoming
163        .collect()
164        .await
165        .map(http_body_util::Collected::to_bytes)
166        .map_err(|e| SesError::internal_error(format!("Failed to read request body: {e}")))
167}
168
169/// Add common response headers to every SES response.
170fn add_common_headers(
171    mut response: http::Response<SesResponseBody>,
172    request_id: &str,
173) -> http::Response<SesResponseBody> {
174    let headers = response.headers_mut();
175
176    if let Ok(hv) = http::HeaderValue::from_str(request_id) {
177        headers.entry("x-amzn-requestid").or_insert(hv);
178    }
179
180    headers
181        .entry("content-type")
182        .or_insert(http::HeaderValue::from_static(XML_CONTENT_TYPE));
183
184    headers.insert("server", http::HeaderValue::from_static("Rustack"));
185
186    // CORS headers.
187    headers.insert(
188        "access-control-allow-origin",
189        http::HeaderValue::from_static("*"),
190    );
191
192    response
193}