1use {
2 chrono::Duration,
3 futures::stream::StreamExt,
4 http::request::Parts,
5 hyper::{
6 body::{Body, Bytes},
7 Error as HyperError, Request, Response,
8 },
9 log::{debug, warn},
10 scratchstack_aws_principal::PrincipalActor,
11 scratchstack_aws_signature::{
12 sigv4_verify, GetSigningKeyRequest, Request as AwsSigVerifyRequest, SigningKey, SigningKeyKind,
13 },
14 std::{
15 any::type_name,
16 fmt::{Debug, Display, Formatter, Result as FmtResult},
17 future::Future,
18 pin::Pin,
19 task::{Context, Poll},
20 },
21 tower::{buffer::Buffer, BoxError, Service, ServiceExt},
22};
23
24#[derive(Clone)]
26pub struct AwsSigV4VerifierService<G, S>
27where
28 G: Service<GetSigningKeyRequest, Response = (PrincipalActor, SigningKey)> + Clone + Send + 'static,
29 G::Future: Send,
30 G::Error: Into<BoxError> + Send + Sync,
31 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
32 S::Future: Send,
33 S::Error: Into<BoxError> + Send + Sync,
34{
35 pub signing_key_kind: SigningKeyKind,
36 pub allowed_mismatch: Option<Duration>,
37 pub region: String,
38 pub service: String,
39 pub get_signing_key: Buffer<G, GetSigningKeyRequest>,
40 pub implementation: Buffer<S, Request<Body>>,
41}
42
43impl<G, S> AwsSigV4VerifierService<G, S>
44where
45 G: Service<GetSigningKeyRequest, Response = (PrincipalActor, SigningKey)> + Clone + Send + 'static,
46 G::Future: Send,
47 G::Error: Into<BoxError> + Send + Sync,
48 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
49 S::Future: Send,
50 S::Error: Into<BoxError> + Send + Sync,
51{
52 pub fn new<S1, S2>(region: S1, service: S2, get_signing_key: G, implementation: S) -> Self
53 where
54 S1: Into<String>,
55 S2: Into<String>,
56 {
57 AwsSigV4VerifierService {
58 signing_key_kind: SigningKeyKind::KSigning,
59 allowed_mismatch: Some(Duration::minutes(5)),
60 region: region.into(),
61 service: service.into(),
62 get_signing_key: Buffer::new(get_signing_key, 10),
63 implementation: Buffer::new(implementation, 10),
64 }
65 }
66}
67
68impl<G, S> Debug for AwsSigV4VerifierService<G, S>
69where
70 G: Service<GetSigningKeyRequest, Response = (PrincipalActor, SigningKey)> + Clone + Send + 'static,
71 G::Future: Send,
72 G::Error: Into<BoxError> + Send + Sync,
73 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
74 S::Future: Send,
75 S::Error: Into<BoxError> + Send + Sync,
76{
77 fn fmt(&self, f: &mut Formatter) -> FmtResult {
78 f.debug_struct("AwsSigV4VerifierService")
79 .field("region", &self.region)
80 .field("service", &self.service)
81 .field("get_signing_key", &type_name::<G>())
82 .field("implementation", &type_name::<S>())
83 .finish()
84 }
85}
86
87impl<G, S> Display for AwsSigV4VerifierService<G, S>
88where
89 G: Service<GetSigningKeyRequest, Response = (PrincipalActor, SigningKey)> + Clone + Send + 'static,
90 G::Future: Send,
91 G::Error: Into<BoxError> + Send + Sync,
92 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
93 S::Future: Send,
94 S::Error: Into<BoxError> + Send + Sync,
95{
96 fn fmt(&self, f: &mut Formatter) -> FmtResult {
97 Debug::fmt(self, f)
98 }
99}
100
101impl<G, S> Service<Request<Body>> for AwsSigV4VerifierService<G, S>
113where
114 G: Service<GetSigningKeyRequest, Response = (PrincipalActor, SigningKey)> + Clone + Send + 'static,
115 G::Future: Send,
116 G::Error: Into<BoxError> + Send + Sync,
117 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
118 S::Future: Send,
119 S::Error: Into<BoxError> + Send + Sync,
120{
121 type Response = S::Response;
122 type Error = BoxError;
123 type Future = Pin<Box<dyn Future<Output = Result<Response<Body>, BoxError>> + Send>>;
124
125 fn poll_ready(&mut self, c: &mut Context) -> Poll<Result<(), Self::Error>> {
126 match self.get_signing_key.poll_ready(c) {
127 Poll::Ready(r) => match r {
128 Ok(()) => match self.implementation.poll_ready(c) {
129 Poll::Ready(r) => match r {
130 Ok(()) => Poll::Ready(Ok(())),
131 Err(e) => Poll::Ready(Err(e)),
132 },
133 Poll::Pending => Poll::Pending,
134 },
135 Err(e) => Poll::Ready(Err(e)),
136 },
137 Poll::Pending => Poll::Pending,
138 }
139 }
140
141 fn call(&mut self, req: Request<Body>) -> Self::Future {
142 let (parts, body) = req.into_parts();
143 let allowed_mismatch = self.allowed_mismatch;
144 let region = self.region.clone();
145 let service = self.service.clone();
146 let signing_key_kind = self.signing_key_kind;
147 let get_signing_key = self.get_signing_key.clone();
148 let implementation = self.implementation.clone();
149
150 Box::pin(handle_call(
151 parts,
152 body,
153 allowed_mismatch,
154 region,
155 service,
156 get_signing_key,
157 signing_key_kind,
158 implementation,
159 ))
160 }
161}
162
163#[allow(clippy::too_many_arguments)]
164async fn handle_call<G, S>(
165 mut parts: Parts,
166 body: Body,
167 allowed_mismatch: Option<Duration>,
168 region: String,
169 service: String,
170 get_signing_key: Buffer<G, GetSigningKeyRequest>,
171 signing_key_kind: SigningKeyKind,
172 implementation: Buffer<S, Request<Body>>,
173) -> Result<Response<Body>, BoxError>
174where
175 G: Service<GetSigningKeyRequest, Response = (PrincipalActor, SigningKey)> + Clone + Send + 'static,
176 G::Future: Send,
177 G::Error: Into<BoxError> + Send + Sync,
178 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
179 S::Future: Send,
180 S::Error: Into<BoxError> + Send + Sync,
181{
182 debug!("Request: {} {}", parts.method, parts.uri);
183 debug!("Request headers:");
184 for (key, value) in &parts.headers {
185 let value_disp = value.to_str().unwrap_or("<INVALID>");
186 debug!("{}: {}", key, value_disp);
187 }
188
189 match body_to_bytes(body).await {
191 Err(e) => Err(e.into()),
192 Ok(body) => {
193 let aws_req = AwsSigVerifyRequest::from_http_request_parts(&parts, Some(body.clone()));
194 let sig_req = match aws_req.to_get_signing_key_request(signing_key_kind, ®ion, &service) {
195 Ok(sig_req) => Some(sig_req),
196 Err(e) => {
197 warn!("Failed to generate a GetSigningKeyRequest request from Request: {:?}", e);
198 None
199 }
200 };
201 if let Some(sig_req) = sig_req {
202 match get_signing_key.oneshot(sig_req).await {
203 Ok((principal, signing_key)) => {
204 debug!("Get signing key returned principal {:?}", principal);
205 match sigv4_verify(&aws_req, &signing_key, allowed_mismatch, ®ion, &service) {
206 Ok(()) => {
207 debug!("Signature verified; adding principal to request: {:?}", principal);
208 parts.extensions.insert(principal);
209 }
210 Err(e) => warn!("Signature mismatch: {:?}", e),
211 }
212 }
213 Err(e) => warn!("Get signing key failed: {:?}", e),
214 }
215 }
216
217 let new_body = Bytes::copy_from_slice(&body);
218 let new_req = Request::from_parts(parts, Body::from(new_body));
219 match implementation.oneshot(new_req).await {
220 Ok(r) => Ok(r),
221 Err(e) => Err(e),
222 }
223 }
224 }
225}
226
227async fn body_to_bytes(mut body: Body) -> Result<Vec<u8>, HyperError> {
228 let mut result = Vec::<u8>::new();
229
230 loop {
231 match body.next().await {
232 None => break,
233 Some(chunk_result) => match chunk_result {
234 Ok(chunk) => result.append(&mut chunk.to_vec()),
235 Err(e) => return Err(e),
236 },
237 }
238 }
239
240 Ok(result)
241}