Skip to main content

rustack_apigatewayv2_http/
service.rs

1//! API Gateway v2 HTTP service implementing the hyper `Service` trait.
2
3use std::{convert::Infallible, future::Future, pin::Pin, sync::Arc};
4
5use bytes::Bytes;
6use http_body_util::BodyExt;
7use hyper::body::Incoming;
8use rustack_apigatewayv2_model::error::ApiGatewayV2Error;
9
10use crate::{
11    body::ApiGatewayV2ResponseBody,
12    dispatch::{ApiGatewayV2Handler, dispatch_operation},
13    response::{CONTENT_TYPE, error_to_response},
14    router::resolve_operation,
15};
16
17/// Configuration for the API Gateway v2 HTTP service.
18#[derive(Clone)]
19pub struct ApiGatewayV2HttpConfig {
20    /// Whether to skip AWS signature validation.
21    pub skip_signature_validation: bool,
22    /// The AWS region this service is running in.
23    pub region: String,
24    /// Credential provider for signature validation.
25    pub credential_provider: Option<Arc<dyn rustack_auth::CredentialProvider>>,
26}
27
28impl std::fmt::Debug for ApiGatewayV2HttpConfig {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("ApiGatewayV2HttpConfig")
31            .field("skip_signature_validation", &self.skip_signature_validation)
32            .field("region", &self.region)
33            .field(
34                "credential_provider",
35                &self.credential_provider.as_ref().map(|_| "..."),
36            )
37            .finish()
38    }
39}
40
41impl Default for ApiGatewayV2HttpConfig {
42    fn default() -> Self {
43        Self {
44            skip_signature_validation: true,
45            region: "us-east-1".to_owned(),
46            credential_provider: None,
47        }
48    }
49}
50
51/// Hyper `Service` implementation for API Gateway v2.
52///
53/// Wraps an [`ApiGatewayV2Handler`] implementation and routes incoming HTTP
54/// requests to the appropriate operation handler using restJson1 URL-based
55/// routing.
56#[derive(Debug)]
57pub struct ApiGatewayV2HttpService<H: ApiGatewayV2Handler> {
58    handler: Arc<H>,
59    config: Arc<ApiGatewayV2HttpConfig>,
60}
61
62impl<H: ApiGatewayV2Handler> ApiGatewayV2HttpService<H> {
63    /// Create a new `ApiGatewayV2HttpService`.
64    pub fn new(handler: Arc<H>, config: ApiGatewayV2HttpConfig) -> Self {
65        Self {
66            handler,
67            config: Arc::new(config),
68        }
69    }
70}
71
72impl<H: ApiGatewayV2Handler> Clone for ApiGatewayV2HttpService<H> {
73    fn clone(&self) -> Self {
74        Self {
75            handler: Arc::clone(&self.handler),
76            config: Arc::clone(&self.config),
77        }
78    }
79}
80
81impl<H: ApiGatewayV2Handler> hyper::service::Service<http::Request<Incoming>>
82    for ApiGatewayV2HttpService<H>
83{
84    type Response = http::Response<ApiGatewayV2ResponseBody>;
85    type Error = Infallible;
86    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
87
88    fn call(&self, req: http::Request<Incoming>) -> Self::Future {
89        let handler = Arc::clone(&self.handler);
90        let config = Arc::clone(&self.config);
91        let request_id = uuid::Uuid::new_v4().to_string();
92
93        Box::pin(async move {
94            let response = process_request(req, handler.as_ref(), &config, &request_id).await;
95            let response = add_common_headers(response, &request_id);
96            Ok(response)
97        })
98    }
99}
100
101/// Process a single API Gateway v2 HTTP request through the full pipeline.
102async fn process_request<H: ApiGatewayV2Handler>(
103    req: http::Request<Incoming>,
104    handler: &H,
105    config: &ApiGatewayV2HttpConfig,
106    _request_id: &str,
107) -> http::Response<ApiGatewayV2ResponseBody> {
108    let (parts, incoming) = req.into_parts();
109
110    // 1. Route: extract operation from method + path.
111    let path = parts.uri.path();
112    let (op, path_params, success_status) = match resolve_operation(&parts.method, path) {
113        Ok(result) => result,
114        Err(err) => return wrap_error_response(&err),
115    };
116
117    // 2. Extract query string.
118    let query = parts.uri.query().unwrap_or("").to_owned();
119
120    // 3. Collect body.
121    let body = match collect_body(incoming).await {
122        Ok(body) => body,
123        Err(err) => return wrap_error_response(&err),
124    };
125
126    // 4. Authenticate (if enabled).
127    if !config.skip_signature_validation {
128        if let Some(ref cred_provider) = config.credential_provider {
129            let body_hash = rustack_auth::hash_payload(&body);
130            if let Err(auth_err) =
131                rustack_auth::verify_sigv4(&parts, &body_hash, cred_provider.as_ref())
132            {
133                let err = ApiGatewayV2Error::with_message(
134                    rustack_apigatewayv2_model::error::ApiGatewayV2ErrorCode::AccessDeniedException,
135                    auth_err.to_string(),
136                );
137                return wrap_error_response(&err);
138            }
139        }
140    }
141
142    // 5. Dispatch to handler.
143    match dispatch_operation(handler, op, path_params, query, parts.headers, body).await {
144        Ok(mut response) => {
145            // Override status if the handler returned 200 but the route specifies differently.
146            if response.status() == http::StatusCode::OK && success_status != 200 {
147                *response.status_mut() =
148                    http::StatusCode::from_u16(success_status).unwrap_or(http::StatusCode::OK);
149            }
150            response
151        }
152        Err(err) => wrap_error_response(&err),
153    }
154}
155
156/// Convert an `ApiGatewayV2Error` into an `ApiGatewayV2ResponseBody`-typed response.
157///
158/// Falls back to a plain-text 500 response if the error response itself
159/// cannot be constructed (extremely unlikely).
160fn wrap_error_response(error: &ApiGatewayV2Error) -> http::Response<ApiGatewayV2ResponseBody> {
161    if let Ok(bytes_response) = error_to_response(error) {
162        let (parts, body) = bytes_response.into_parts();
163        http::Response::from_parts(parts, ApiGatewayV2ResponseBody::from_bytes(body))
164    } else {
165        // Fallback: if we cannot even serialize the error, return a minimal 500.
166        let (parts, body) = http::Response::builder()
167            .status(http::StatusCode::INTERNAL_SERVER_ERROR)
168            .body(Bytes::from(r#"{"message":"Internal error"}"#))
169            .unwrap_or_default()
170            .into_parts();
171        http::Response::from_parts(parts, ApiGatewayV2ResponseBody::from_bytes(body))
172    }
173}
174
175/// Collect the incoming body into a single `Bytes` buffer.
176async fn collect_body(incoming: Incoming) -> Result<Bytes, ApiGatewayV2Error> {
177    incoming
178        .collect()
179        .await
180        .map(http_body_util::Collected::to_bytes)
181        .map_err(|e| ApiGatewayV2Error::internal_error(format!("Failed to read request body: {e}")))
182}
183
184/// Add common response headers to every API Gateway v2 response.
185fn add_common_headers(
186    mut response: http::Response<ApiGatewayV2ResponseBody>,
187    request_id: &str,
188) -> http::Response<ApiGatewayV2ResponseBody> {
189    let is_no_content = response.status() == http::StatusCode::NO_CONTENT;
190    let headers = response.headers_mut();
191
192    if let Ok(hv) = http::HeaderValue::from_str(request_id) {
193        headers.entry("x-amzn-requestid").or_insert(hv);
194    }
195
196    // Only set content-type for responses with a body (not 204 No Content).
197    if !is_no_content {
198        headers
199            .entry("content-type")
200            .or_insert(http::HeaderValue::from_static(CONTENT_TYPE));
201    }
202
203    headers.insert("server", http::HeaderValue::from_static("Rustack"));
204
205    // CORS headers.
206    headers.insert(
207        "access-control-allow-origin",
208        http::HeaderValue::from_static("*"),
209    );
210
211    response
212}