scratchstack_aws_signature_hyper/
lib.rs

1mod service;
2pub use crate::service::AwsSigV4VerifierService;
3
4#[cfg(test)]
5mod tests {
6    use {
7        crate::AwsSigV4VerifierService,
8        chrono::{DateTime, Utc},
9        futures::stream::StreamExt,
10        http::StatusCode,
11        hyper::{
12            client::{connect::dns::GaiResolver, HttpConnector},
13            server::conn::AddrStream,
14            service::{make_service_fn, service_fn},
15            Body, Request, Response, Server,
16        },
17        log::debug,
18        rusoto_core::{DispatchSignedRequest, HttpClient, Region},
19        rusoto_credential::AwsCredentials,
20        rusoto_signature::SignedRequest,
21        scratchstack_aws_principal::PrincipalActor,
22        scratchstack_aws_signature::{
23            get_signing_key_fn, GetSigningKeyRequest, SignatureError, SigningKey, SigningKeyKind,
24        },
25        std::{
26            convert::Infallible,
27            future::Future,
28            net::{Ipv6Addr, SocketAddr, SocketAddrV6},
29            pin::Pin,
30            task::{Context, Poll},
31            time::Duration,
32        },
33        tower::{BoxError, Service},
34    };
35
36    #[test_log::test(tokio::test)]
37    async fn test_fn_wrapper() {
38        let sigfn = get_signing_key_fn(get_creds_fn);
39        let wrapped = service_fn(hello_response);
40        let make_svc = make_service_fn(|_socket: &AddrStream| async move {
41            Ok::<_, Infallible>(AwsSigV4VerifierService::new("local", "service", sigfn, wrapped))
42        });
43
44        let server = Server::bind(&SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 5937, 0, 0))).serve(make_svc);
45        let mut connector = HttpConnector::new_with_resolver(GaiResolver::new());
46        connector.set_connect_timeout(Some(Duration::from_millis(10)));
47        let client = HttpClient::<HttpConnector<GaiResolver>>::from_connector(connector);
48        match server
49            .with_graceful_shutdown(async {
50                let region = Region::Custom {
51                    name: "local".to_owned(),
52                    endpoint: "http://[::1]:5937".to_owned(),
53                };
54                let mut sr = SignedRequest::new("GET", "service", &region, "/");
55                sr.sign(&AwsCredentials::new(
56                    "AKIDEXAMPLE",
57                    "AWS4wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
58                    None,
59                    None,
60                ));
61                match client.dispatch(sr, Some(Duration::from_millis(100))).await {
62                    Ok(r) => {
63                        eprintln!("Response from server: {:?}", r.status);
64
65                        let mut body = r.body;
66                        while let Some(b_result) = body.next().await {
67                            match b_result {
68                                Ok(bytes) => eprint!("{:?}", bytes),
69                                Err(e) => {
70                                    eprintln!("Error while ready body: {:?}", e);
71                                    break;
72                                }
73                            }
74                        }
75                        eprintln!();
76                        assert_eq!(r.status, StatusCode::OK);
77                    }
78                    Err(e) => panic!("Error from server: {:?}", e),
79                };
80            })
81            .await
82        {
83            Ok(()) => println!("Server shutdown normally"),
84            Err(e) => panic!("Server shutdown with error {:?}", e),
85        }
86    }
87
88    #[test_log::test(tokio::test)]
89    async fn test_svc_wrapper() {
90        let make_svc = SpawnDummyHelloService {};
91        let server = Server::bind(&SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 5938, 0, 0))).serve(make_svc);
92        let mut connector = HttpConnector::new_with_resolver(GaiResolver::new());
93        connector.set_connect_timeout(Some(Duration::from_millis(10)));
94        let client = HttpClient::<HttpConnector<GaiResolver>>::from_connector(connector);
95        let mut status = StatusCode::OK;
96        match server
97            .with_graceful_shutdown(async {
98                let region = Region::Custom {
99                    name: "local".to_owned(),
100                    endpoint: "http://[::1]:5938".to_owned(),
101                };
102                let mut sr = SignedRequest::new("GET", "service", &region, "/");
103                sr.sign(&AwsCredentials::new(
104                    "AKIDEXAMPLE",
105                    "AWS4wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
106                    None,
107                    None,
108                ));
109                match client.dispatch(sr, Some(Duration::from_millis(100))).await {
110                    Ok(r) => {
111                        eprintln!("Response from server: {:?}", r.status);
112
113                        let mut body = r.body;
114                        while let Some(b_result) = body.next().await {
115                            match b_result {
116                                Ok(bytes) => eprint!("{:?}", bytes),
117                                Err(e) => {
118                                    eprintln!("Error while ready body: {:?}", e);
119                                    break;
120                                }
121                            }
122                        }
123                        eprintln!();
124                        status = r.status;
125                    }
126                    Err(e) => panic!("Error from server: {:?}", e),
127                };
128            })
129            .await
130        {
131            Ok(()) => println!("Server shutdown normally"),
132            Err(e) => panic!("Server shutdown with error {:?}", e),
133        }
134
135        assert_eq!(status, StatusCode::OK);
136    }
137
138    #[test_log::test(tokio::test)]
139    async fn test_svc_wrapper_bad_creds() {
140        let make_svc = SpawnDummyHelloService {};
141        let server = Server::bind(&SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 5939, 0, 0))).serve(make_svc);
142        let mut connector = HttpConnector::new_with_resolver(GaiResolver::new());
143        connector.set_connect_timeout(Some(Duration::from_millis(10)));
144        let client = HttpClient::<HttpConnector<GaiResolver>>::from_connector(connector);
145        match server
146            .with_graceful_shutdown(async {
147                let region = Region::Custom {
148                    name: "local".to_owned(),
149                    endpoint: "http://[::1]:5939".to_owned(),
150                };
151                let mut sr = SignedRequest::new("GET", "service", &region, "/");
152                sr.sign(&AwsCredentials::new("AKIDEXAMPLE", "WRONGKEY", None, None));
153                match client.dispatch(sr, Some(Duration::from_millis(100))).await {
154                    Ok(r) => {
155                        eprintln!("Response from server: {:?}", r.status);
156
157                        let mut body = r.body;
158                        while let Some(b_result) = body.next().await {
159                            match b_result {
160                                Ok(bytes) => eprint!("{:?}", bytes),
161                                Err(e) => {
162                                    eprintln!("Error while ready body: {:?}", e);
163                                    break;
164                                }
165                            }
166                        }
167                        eprintln!();
168                        assert_eq!(r.status, StatusCode::UNAUTHORIZED);
169                    }
170                    Err(e) => panic!("Error from server: {:?}", e),
171                };
172            })
173            .await
174        {
175            Ok(()) => println!("Server shutdown normally"),
176            Err(e) => panic!("Server shutdown with error {:?}", e),
177        }
178    }
179
180    async fn get_creds_fn(
181        signing_key_kind: SigningKeyKind,
182        access_key: String,
183        _session_token: Option<String>,
184        request_date: DateTime<Utc>,
185        region: String,
186        service: String,
187    ) -> Result<(PrincipalActor, SigningKey), SignatureError> {
188        if access_key == "AKIDEXAMPLE" {
189            let k_secret = SigningKey {
190                kind: SigningKeyKind::KSecret,
191                key: b"AWS4wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_vec(),
192            };
193
194            let principal = PrincipalActor::user("aws", "123456789012", "/", "test", "AIDAAAAAAAAAAAAAAAAA").unwrap();
195            Ok((principal, k_secret.derive(signing_key_kind, &request_date, region, service)))
196        } else {
197            Err(SignatureError::UnknownAccessKey {
198                access_key,
199            })
200        }
201    }
202
203    async fn hello_response(_req: Request<Body>) -> Result<Response<Body>, BoxError> {
204        Ok(Response::new(Body::from("Hello world")))
205    }
206
207    #[derive(Clone)]
208    struct SpawnDummyHelloService {}
209    impl Service<&AddrStream> for SpawnDummyHelloService {
210        type Response = AwsSigV4VerifierService<GetDummyCreds, HelloService>;
211        type Error = BoxError;
212        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
213
214        fn poll_ready(&mut self, _c: &mut Context) -> Poll<Result<(), Self::Error>> {
215            Poll::Ready(Ok(()))
216        }
217
218        fn call(&mut self, _addr: &AddrStream) -> Self::Future {
219            Box::pin(
220                async move { Ok(AwsSigV4VerifierService::new("local", "service", GetDummyCreds {}, HelloService {})) },
221            )
222        }
223    }
224
225    #[derive(Clone)]
226    struct GetDummyCreds {}
227    impl Service<GetSigningKeyRequest> for GetDummyCreds {
228        type Response = (PrincipalActor, SigningKey);
229        type Error = BoxError;
230        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
231
232        fn poll_ready(&mut self, _c: &mut Context) -> Poll<Result<(), Self::Error>> {
233            Poll::Ready(Ok(()))
234        }
235
236        fn call(&mut self, req: GetSigningKeyRequest) -> Self::Future {
237            Box::pin(async move {
238                if req.access_key == "AKIDEXAMPLE" {
239                    let k_secret = SigningKey {
240                        kind: SigningKeyKind::KSecret,
241                        key: b"AWS4wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_vec(),
242                    };
243                    debug!("secret key: {:?} {:02x?}", k_secret, &k_secret.key);
244
245                    let principal =
246                        PrincipalActor::user("aws", "123456789012", "/", "test", "AIDAAAAAAAAAAAAAAAAA").unwrap();
247                    let derived = k_secret.derive(req.signing_key_kind, &req.request_date, req.region, req.service);
248                    debug!("derived key: {:?} {:02x?}", derived, &derived.key);
249                    Ok((principal, derived))
250                } else {
251                    Err(SignatureError::UnknownAccessKey {
252                        access_key: req.access_key,
253                    }
254                    .into())
255                }
256            })
257        }
258    }
259
260    #[derive(Clone)]
261    struct HelloService {}
262    impl Service<Request<Body>> for HelloService {
263        type Response = Response<Body>;
264        type Error = BoxError;
265        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
266
267        fn poll_ready(&mut self, _c: &mut Context) -> Poll<Result<(), Self::Error>> {
268            Poll::Ready(Ok(()))
269        }
270
271        fn call(&mut self, req: Request<Body>) -> Self::Future {
272            Box::pin(async move {
273                let (parts, _body) = req.into_parts();
274                let principal = parts.extensions.get::<PrincipalActor>();
275
276                let (status, body) = match principal {
277                    Some(principal) => (StatusCode::OK, format!("Hello {:?}", principal)),
278                    None => (StatusCode::UNAUTHORIZED, "Unauthorized".to_string()),
279                };
280
281                match Response::builder().status(status).header("Content-Type", "text/plain").body(Body::from(body)) {
282                    Ok(r) => Ok(r),
283                    Err(e) => {
284                        eprintln!("Response builder: error: {:?}", e);
285                        Err(e.into())
286                    }
287                }
288            })
289        }
290    }
291}