scratchstack_aws_signature_hyper/
service.rs

1use {
2    chrono::Duration,
3    futures::stream::StreamExt,
4    http::request::Parts,
5    hyper::{
6        body::{Body, Bytes},
7        Error as HyperError, Request, Response,
8    },
9    log::{debug, warn},
10    scratchstack_aws_principal::PrincipalActor,
11    scratchstack_aws_signature::{
12        sigv4_verify, GetSigningKeyRequest, Request as AwsSigVerifyRequest, SigningKey, SigningKeyKind,
13    },
14    std::{
15        any::type_name,
16        fmt::{Debug, Display, Formatter, Result as FmtResult},
17        future::Future,
18        pin::Pin,
19        task::{Context, Poll},
20    },
21    tower::{buffer::Buffer, BoxError, Service, ServiceExt},
22};
23
24/// AWSSigV4VerifierService implements a Hyper service that authenticates a request against AWS SigV4 signing protocol.
25#[derive(Clone)]
26pub struct AwsSigV4VerifierService<G, S>
27where
28    G: Service<GetSigningKeyRequest, Response = (PrincipalActor, SigningKey)> + Clone + Send + 'static,
29    G::Future: Send,
30    G::Error: Into<BoxError> + Send + Sync,
31    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
32    S::Future: Send,
33    S::Error: Into<BoxError> + Send + Sync,
34{
35    pub signing_key_kind: SigningKeyKind,
36    pub allowed_mismatch: Option<Duration>,
37    pub region: String,
38    pub service: String,
39    pub get_signing_key: Buffer<G, GetSigningKeyRequest>,
40    pub implementation: Buffer<S, Request<Body>>,
41}
42
43impl<G, S> AwsSigV4VerifierService<G, S>
44where
45    G: Service<GetSigningKeyRequest, Response = (PrincipalActor, SigningKey)> + Clone + Send + 'static,
46    G::Future: Send,
47    G::Error: Into<BoxError> + Send + Sync,
48    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
49    S::Future: Send,
50    S::Error: Into<BoxError> + Send + Sync,
51{
52    pub fn new<S1, S2>(region: S1, service: S2, get_signing_key: G, implementation: S) -> Self
53    where
54        S1: Into<String>,
55        S2: Into<String>,
56    {
57        AwsSigV4VerifierService {
58            signing_key_kind: SigningKeyKind::KSigning,
59            allowed_mismatch: Some(Duration::minutes(5)),
60            region: region.into(),
61            service: service.into(),
62            get_signing_key: Buffer::new(get_signing_key, 10),
63            implementation: Buffer::new(implementation, 10),
64        }
65    }
66}
67
68impl<G, S> Debug for AwsSigV4VerifierService<G, S>
69where
70    G: Service<GetSigningKeyRequest, Response = (PrincipalActor, SigningKey)> + Clone + Send + 'static,
71    G::Future: Send,
72    G::Error: Into<BoxError> + Send + Sync,
73    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
74    S::Future: Send,
75    S::Error: Into<BoxError> + Send + Sync,
76{
77    fn fmt(&self, f: &mut Formatter) -> FmtResult {
78        f.debug_struct("AwsSigV4VerifierService")
79            .field("region", &self.region)
80            .field("service", &self.service)
81            .field("get_signing_key", &type_name::<G>())
82            .field("implementation", &type_name::<S>())
83            .finish()
84    }
85}
86
87impl<G, S> Display for AwsSigV4VerifierService<G, S>
88where
89    G: Service<GetSigningKeyRequest, Response = (PrincipalActor, SigningKey)> + Clone + Send + 'static,
90    G::Future: Send,
91    G::Error: Into<BoxError> + Send + Sync,
92    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
93    S::Future: Send,
94    S::Error: Into<BoxError> + Send + Sync,
95{
96    fn fmt(&self, f: &mut Formatter) -> FmtResult {
97        Debug::fmt(self, f)
98    }
99}
100
101// impl<S, GSK> AwsSigV4VerifierService<S, GSK>
102// where
103//     S: Service<Request<Body>, Response=Response<Body>> + Send + Sync + 'static,
104//     S::Error: From<HyperError>,
105//     GSK: GetSigningKey + Clone + Send + Sync + 'static,
106//     GSK::Future: Send + Sync,
107// {
108//     async fn handle_call(&mut self, req: Request<Body>) -> Result<Response<Body>, <Self as Service<Request<Body>>>::Error> {
109//     }
110// }
111
112impl<G, S> Service<Request<Body>> for AwsSigV4VerifierService<G, S>
113where
114    G: Service<GetSigningKeyRequest, Response = (PrincipalActor, SigningKey)> + Clone + Send + 'static,
115    G::Future: Send,
116    G::Error: Into<BoxError> + Send + Sync,
117    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
118    S::Future: Send,
119    S::Error: Into<BoxError> + Send + Sync,
120{
121    type Response = S::Response;
122    type Error = BoxError;
123    type Future = Pin<Box<dyn Future<Output = Result<Response<Body>, BoxError>> + Send>>;
124
125    fn poll_ready(&mut self, c: &mut Context) -> Poll<Result<(), Self::Error>> {
126        match self.get_signing_key.poll_ready(c) {
127            Poll::Ready(r) => match r {
128                Ok(()) => match self.implementation.poll_ready(c) {
129                    Poll::Ready(r) => match r {
130                        Ok(()) => Poll::Ready(Ok(())),
131                        Err(e) => Poll::Ready(Err(e)),
132                    },
133                    Poll::Pending => Poll::Pending,
134                },
135                Err(e) => Poll::Ready(Err(e)),
136            },
137            Poll::Pending => Poll::Pending,
138        }
139    }
140
141    fn call(&mut self, req: Request<Body>) -> Self::Future {
142        let (parts, body) = req.into_parts();
143        let allowed_mismatch = self.allowed_mismatch;
144        let region = self.region.clone();
145        let service = self.service.clone();
146        let signing_key_kind = self.signing_key_kind;
147        let get_signing_key = self.get_signing_key.clone();
148        let implementation = self.implementation.clone();
149
150        Box::pin(handle_call(
151            parts,
152            body,
153            allowed_mismatch,
154            region,
155            service,
156            get_signing_key,
157            signing_key_kind,
158            implementation,
159        ))
160    }
161}
162
163#[allow(clippy::too_many_arguments)]
164async fn handle_call<G, S>(
165    mut parts: Parts,
166    body: Body,
167    allowed_mismatch: Option<Duration>,
168    region: String,
169    service: String,
170    get_signing_key: Buffer<G, GetSigningKeyRequest>,
171    signing_key_kind: SigningKeyKind,
172    implementation: Buffer<S, Request<Body>>,
173) -> Result<Response<Body>, BoxError>
174where
175    G: Service<GetSigningKeyRequest, Response = (PrincipalActor, SigningKey)> + Clone + Send + 'static,
176    G::Future: Send,
177    G::Error: Into<BoxError> + Send + Sync,
178    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
179    S::Future: Send,
180    S::Error: Into<BoxError> + Send + Sync,
181{
182    debug!("Request: {} {}", parts.method, parts.uri);
183    debug!("Request headers:");
184    for (key, value) in &parts.headers {
185        let value_disp = value.to_str().unwrap_or("<INVALID>");
186        debug!("{}: {}", key, value_disp);
187    }
188
189    // We need the actual body in order to compute the signature.
190    match body_to_bytes(body).await {
191        Err(e) => Err(e.into()),
192        Ok(body) => {
193            let aws_req = AwsSigVerifyRequest::from_http_request_parts(&parts, Some(body.clone()));
194            let sig_req = match aws_req.to_get_signing_key_request(signing_key_kind, &region, &service) {
195                Ok(sig_req) => Some(sig_req),
196                Err(e) => {
197                    warn!("Failed to generate a GetSigningKeyRequest request from Request: {:?}", e);
198                    None
199                }
200            };
201            if let Some(sig_req) = sig_req {
202                match get_signing_key.oneshot(sig_req).await {
203                    Ok((principal, signing_key)) => {
204                        debug!("Get signing key returned principal {:?}", principal);
205                        match sigv4_verify(&aws_req, &signing_key, allowed_mismatch, &region, &service) {
206                            Ok(()) => {
207                                debug!("Signature verified; adding principal to request: {:?}", principal);
208                                parts.extensions.insert(principal);
209                            }
210                            Err(e) => warn!("Signature mismatch: {:?}", e),
211                        }
212                    }
213                    Err(e) => warn!("Get signing key failed: {:?}", e),
214                }
215            }
216
217            let new_body = Bytes::copy_from_slice(&body);
218            let new_req = Request::from_parts(parts, Body::from(new_body));
219            match implementation.oneshot(new_req).await {
220                Ok(r) => Ok(r),
221                Err(e) => Err(e),
222            }
223        }
224    }
225}
226
227async fn body_to_bytes(mut body: Body) -> Result<Vec<u8>, HyperError> {
228    let mut result = Vec::<u8>::new();
229
230    loop {
231        match body.next().await {
232            None => break,
233            Some(chunk_result) => match chunk_result {
234                Ok(chunk) => result.append(&mut chunk.to_vec()),
235                Err(e) => return Err(e),
236            },
237        }
238    }
239
240    Ok(result)
241}