Skip to main content

rustack_cloudfront_http/
service.rs

1//! CloudFront hyper `Service`.
2
3use std::{convert::Infallible, future::Future, pin::Pin, sync::Arc};
4
5use bytes::Bytes;
6use http_body_util::{BodyExt, Full};
7use hyper::body::{Body, Frame, Incoming};
8
9use crate::{
10    dispatch::{CloudFrontHandler, dispatch as dispatch_op},
11    response::error_response,
12    router::resolve,
13};
14
15/// Response body type for the CloudFront HTTP service.
16#[derive(Debug)]
17pub struct HttpBody {
18    inner: Full<Bytes>,
19}
20
21impl Default for HttpBody {
22    fn default() -> Self {
23        Self {
24            inner: Full::new(Bytes::new()),
25        }
26    }
27}
28
29impl From<String> for HttpBody {
30    fn from(s: String) -> Self {
31        Self {
32            inner: Full::new(Bytes::from(s)),
33        }
34    }
35}
36
37impl From<Bytes> for HttpBody {
38    fn from(b: Bytes) -> Self {
39        Self {
40            inner: Full::new(b),
41        }
42    }
43}
44
45impl Body for HttpBody {
46    type Data = Bytes;
47    type Error = Infallible;
48
49    fn poll_frame(
50        self: Pin<&mut Self>,
51        cx: &mut std::task::Context<'_>,
52    ) -> std::task::Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
53        // SAFETY: we never move `inner` out.
54        let inner = unsafe { self.map_unchecked_mut(|s| &mut s.inner) };
55        inner.poll_frame(cx)
56    }
57}
58
59/// Configuration for the CloudFront HTTP service.
60#[derive(Clone)]
61pub struct CloudFrontHttpConfig {
62    /// Whether to skip SigV4 validation.
63    pub skip_signature_validation: bool,
64    /// Region string to report to clients.
65    pub region: String,
66    /// Optional credential provider for SigV4 verification.
67    pub credential_provider: Option<Arc<dyn rustack_auth::CredentialProvider>>,
68}
69
70impl std::fmt::Debug for CloudFrontHttpConfig {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        f.debug_struct("CloudFrontHttpConfig")
73            .field("skip_signature_validation", &self.skip_signature_validation)
74            .field("region", &self.region)
75            .field(
76                "credential_provider",
77                &self.credential_provider.as_ref().map(|_| "..."),
78            )
79            .finish()
80    }
81}
82
83impl Default for CloudFrontHttpConfig {
84    fn default() -> Self {
85        Self {
86            skip_signature_validation: true,
87            region: "us-east-1".to_owned(),
88            credential_provider: None,
89        }
90    }
91}
92
93/// CloudFront HTTP service.
94#[derive(Debug)]
95pub struct CloudFrontHttpService<H: CloudFrontHandler> {
96    handler: Arc<H>,
97    config: Arc<CloudFrontHttpConfig>,
98}
99
100impl<H: CloudFrontHandler> CloudFrontHttpService<H> {
101    /// Create a new service.
102    pub fn new(handler: Arc<H>, config: CloudFrontHttpConfig) -> Self {
103        Self {
104            handler,
105            config: Arc::new(config),
106        }
107    }
108}
109
110impl<H: CloudFrontHandler> Clone for CloudFrontHttpService<H> {
111    fn clone(&self) -> Self {
112        Self {
113            handler: Arc::clone(&self.handler),
114            config: Arc::clone(&self.config),
115        }
116    }
117}
118
119impl<H: CloudFrontHandler> hyper::service::Service<http::Request<Incoming>>
120    for CloudFrontHttpService<H>
121{
122    type Response = http::Response<HttpBody>;
123    type Error = Infallible;
124    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
125
126    fn call(&self, req: http::Request<Incoming>) -> Self::Future {
127        let handler = Arc::clone(&self.handler);
128        let config = Arc::clone(&self.config);
129        let request_id = uuid::Uuid::new_v4().to_string();
130        Box::pin(async move { Ok(serve(req, handler.as_ref(), &config, request_id).await) })
131    }
132}
133
134async fn serve<H: CloudFrontHandler>(
135    req: http::Request<Incoming>,
136    handler: &H,
137    _config: &CloudFrontHttpConfig,
138    request_id: String,
139) -> http::Response<HttpBody> {
140    let (parts, body) = req.into_parts();
141    let body_bytes = match body.collect().await {
142        Ok(c) => c.to_bytes(),
143        Err(e) => {
144            let err = rustack_cloudfront_model::CloudFrontError::Internal(format!(
145                "failed to read body: {e}"
146            ));
147            return error_response(&err, &request_id);
148        }
149    };
150
151    let route = match resolve(&parts.method, &parts.uri) {
152        Ok(r) => r,
153        Err(e) => return error_response(&e, &request_id),
154    };
155
156    let if_match = parts
157        .headers
158        .get(http::header::IF_MATCH)
159        .and_then(|v| v.to_str().ok());
160
161    let mut resp = dispatch_op(
162        handler,
163        route,
164        &parts.uri,
165        &parts.headers,
166        if_match,
167        body_bytes,
168        &request_id,
169    )
170    .await;
171
172    if let Ok(hv) = http::HeaderValue::from_str(&request_id) {
173        resp.headers_mut().entry("x-amzn-requestid").or_insert(hv);
174    }
175    resp
176}