sui_http/middleware/callback/
mod.rs1use http::HeaderMap;
88use http::request;
89use http::response;
90
91mod body;
92mod future;
93mod layer;
94mod service;
95
96pub use self::body::RequestBody;
97pub use self::body::ResponseBody;
98pub use self::future::ResponseFuture;
99pub use self::layer::CallbackLayer;
100pub use self::service::Callback;
101
102pub trait MakeCallbackHandler {
108 type RequestHandler: RequestHandler;
111 type ResponseHandler: ResponseHandler;
114
115 fn make_handler(
117 &self,
118 request: &request::Parts,
119 ) -> (Self::RequestHandler, Self::ResponseHandler);
120}
121
122pub trait RequestHandler {
129 fn on_body_chunk<B>(&mut self, _chunk: &B)
131 where
132 B: bytes::Buf,
133 {
134 }
136
137 fn on_end_of_stream(&mut self, _trailers: Option<&HeaderMap>) {
142 }
144
145 fn on_body_error<E>(&mut self, _error: &E)
147 where
148 E: std::fmt::Display + 'static,
149 {
150 }
152}
153
154impl RequestHandler for () {}
155
156pub trait ResponseHandler {
162 fn on_response(&mut self, response: &response::Parts);
164
165 fn on_service_error<E>(&mut self, error: &E)
169 where
170 E: std::fmt::Display + 'static;
171
172 fn on_body_chunk<B>(&mut self, _chunk: &B)
174 where
175 B: bytes::Buf,
176 {
177 }
179
180 fn on_end_of_stream(&mut self, _trailers: Option<&HeaderMap>) {
182 }
184
185 fn on_body_error<E>(&mut self, _error: &E)
187 where
188 E: std::fmt::Display + 'static,
189 {
190 }
192}
193
#[cfg(test)]
mod tests {
    // End-to-end tests for the callback middleware. Each test wraps a small
    // `tower::service_fn` in `CallbackLayer` and drives one request through it.
    // It then inspects the callbacks recorded in a shared `Events` log.

    use super::*;
    use bytes::Buf;
    use bytes::Bytes;
    use futures::stream;
    use http::Request;
    use http::Response;
    use http_body::Body;
    use http_body_util::BodyExt;
    use http_body_util::Full;
    use http_body_util::StreamBody;
    use std::convert::Infallible;
    use std::sync::Arc;
    use std::sync::Mutex;
    use tower::ServiceBuilder;
    use tower::ServiceExt;

    /// Log of every callback invocation observed during one test.
    #[derive(Debug, Default, PartialEq, Eq)]
    struct Events {
        /// Bytes of each chunk passed to `RequestHandler::on_body_chunk`.
        request_chunks: Vec<Vec<u8>>,
        /// Trailers passed to each `RequestHandler::on_end_of_stream` call.
        request_end_trailers: Vec<Option<HeaderMap>>,
        /// `Display` output of each error passed to `RequestHandler::on_body_error`.
        request_body_errors: Vec<String>,
        /// Number of `ResponseHandler::on_response` calls.
        response_seen: u32,
        /// Bytes of each chunk passed to `ResponseHandler::on_body_chunk`.
        response_chunks: Vec<Vec<u8>>,
        /// Trailers passed to each `ResponseHandler::on_end_of_stream` call.
        response_end_trailers: Vec<Option<HeaderMap>>,
        /// `Display` output of each error passed to `ResponseHandler::on_body_error`.
        response_body_errors: Vec<String>,
        /// `Display` output of each error passed to `ResponseHandler::on_service_error`.
        response_service_errors: Vec<String>,
    }

    /// Handler factory whose handlers all append to the same shared log.
    #[derive(Clone, Default)]
    struct Recorder(Arc<Mutex<Events>>);

    /// Request-side handler produced by `Recorder`.
    struct ReqH(Arc<Mutex<Events>>);
    /// Response-side handler produced by `Recorder`.
    struct RespH(Arc<Mutex<Events>>);

    /// Locks `events` and applies `update` to the log.
    ///
    /// This centralizes the lock/unwrap boilerplate for the handler impls.
    /// The guard is dropped as soon as `update` returns.
    fn record(events: &Mutex<Events>, update: impl FnOnce(&mut Events)) {
        update(&mut *events.lock().unwrap());
    }

    impl RequestHandler for ReqH {
        fn on_body_chunk<B: Buf>(&mut self, chunk: &B) {
            // `Buf::chunk` is the first contiguous slice of the buffer.
            let bytes = chunk.chunk().to_vec();
            record(&self.0, |events| events.request_chunks.push(bytes));
        }

        fn on_end_of_stream(&mut self, trailers: Option<&HeaderMap>) {
            record(&self.0, |events| events.request_end_trailers.push(trailers.cloned()));
        }

        fn on_body_error<E: std::fmt::Display + 'static>(&mut self, error: &E) {
            let message = error.to_string();
            record(&self.0, |events| events.request_body_errors.push(message));
        }
    }

    impl ResponseHandler for RespH {
        fn on_response(&mut self, _parts: &response::Parts) {
            record(&self.0, |events| events.response_seen += 1);
        }

        fn on_service_error<E: std::fmt::Display + 'static>(&mut self, error: &E) {
            let message = error.to_string();
            record(&self.0, |events| events.response_service_errors.push(message));
        }

        fn on_body_chunk<B: Buf>(&mut self, chunk: &B) {
            // `Buf::chunk` is the first contiguous slice of the buffer.
            let bytes = chunk.chunk().to_vec();
            record(&self.0, |events| events.response_chunks.push(bytes));
        }

        fn on_end_of_stream(&mut self, trailers: Option<&HeaderMap>) {
            record(&self.0, |events| events.response_end_trailers.push(trailers.cloned()));
        }

        fn on_body_error<E: std::fmt::Display + 'static>(&mut self, error: &E) {
            let message = error.to_string();
            record(&self.0, |events| events.response_body_errors.push(message));
        }
    }

    impl MakeCallbackHandler for Recorder {
        type RequestHandler = ReqH;
        type ResponseHandler = RespH;

        fn make_handler(
            &self,
            _request: &request::Parts,
        ) -> (Self::RequestHandler, Self::ResponseHandler) {
            (ReqH(self.0.clone()), RespH(self.0.clone()))
        }
    }

    /// Reads `body` to completion and discards its contents.
    ///
    /// Tests drain every body so that all frames pass through the callback
    /// wrappers before the recorded events are inspected. `collect` pins the
    /// body internally, so no `Unpin` bound is needed.
    async fn drain<B: Body>(body: B) -> Result<(), B::Error> {
        body.collect().await.map(|_| ())
    }

    #[tokio::test]
    async fn observes_request_chunks_and_clean_end() {
        let recorder = Recorder::default();
        let events = recorder.0.clone();

        let inner = tower::service_fn(
            |req: Request<RequestBody<Full<Bytes>, ReqH>>| async move {
                drain(req.into_body()).await.unwrap();
                Ok::<_, Infallible>(Response::new(Full::new(Bytes::from_static(b"ok"))))
            },
        );
        let svc = ServiceBuilder::new()
            .layer(CallbackLayer::new(recorder))
            .service(inner);

        let request = Request::new(Full::new(Bytes::from_static(b"hello world")));
        let response = svc.oneshot(request).await.unwrap();
        drain(response.into_body()).await.unwrap();

        let events = events.lock().unwrap();
        // A single-frame body without trailers ends with `None`.
        assert_eq!(events.request_chunks, vec![b"hello world".to_vec()]);
        assert_eq!(events.request_end_trailers, vec![None]);
        assert!(events.request_body_errors.is_empty());
        // The response head is reported exactly once.
        assert_eq!(events.response_seen, 1);
        assert_eq!(events.response_chunks, vec![b"ok".to_vec()]);
        assert_eq!(events.response_end_trailers, vec![None]);
        assert!(events.response_body_errors.is_empty());
        assert!(events.response_service_errors.is_empty());
    }

    #[tokio::test]
    async fn observes_request_trailers_on_end() {
        let recorder = Recorder::default();
        let events = recorder.0.clone();

        // Two data frames followed by a trailers frame.
        let mut trailers = HeaderMap::new();
        trailers.insert("x-req-trailer", "abc".parse().unwrap());
        let frames: Vec<Result<http_body::Frame<Bytes>, Infallible>> = vec![
            Ok(http_body::Frame::data(Bytes::from_static(b"chunk-1"))),
            Ok(http_body::Frame::data(Bytes::from_static(b"chunk-2"))),
            Ok(http_body::Frame::trailers(trailers.clone())),
        ];
        let body = StreamBody::new(stream::iter(frames));

        let inner = tower::service_fn(
            |req: Request<RequestBody<StreamBody<_>, ReqH>>| async move {
                drain(req.into_body()).await.unwrap();
                Ok::<_, Infallible>(Response::new(Full::new(Bytes::new())))
            },
        );
        let svc = ServiceBuilder::new()
            .layer(CallbackLayer::new(recorder))
            .service(inner);

        let response = svc.oneshot(Request::new(body)).await.unwrap();
        drain(response.into_body()).await.unwrap();

        let events = events.lock().unwrap();
        assert_eq!(
            events.request_chunks,
            vec![b"chunk-1".to_vec(), b"chunk-2".to_vec()]
        );
        // End-of-stream fires once and carries the trailers frame.
        assert_eq!(events.request_end_trailers.len(), 1);
        assert_eq!(events.request_end_trailers[0].as_ref(), Some(&trailers));
        assert!(events.request_body_errors.is_empty());
    }

    #[tokio::test]
    async fn observes_request_body_error() {
        #[derive(Debug)]
        struct BodyErr;
        impl std::fmt::Display for BodyErr {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("boom")
            }
        }
        impl std::error::Error for BodyErr {}

        let recorder = Recorder::default();
        let events = recorder.0.clone();

        // One good frame, then a failure.
        let frames: Vec<Result<http_body::Frame<Bytes>, BodyErr>> = vec![
            Ok(http_body::Frame::data(Bytes::from_static(b"partial"))),
            Err(BodyErr),
        ];
        let body = StreamBody::new(stream::iter(frames));

        let inner = tower::service_fn(
            |req: Request<RequestBody<StreamBody<_>, ReqH>>| async move {
                // Draining is expected to fail; the error is checked through
                // the recorded events instead.
                let _ = drain(req.into_body()).await;
                Ok::<_, Infallible>(Response::new(Full::new(Bytes::new())))
            },
        );
        let svc = ServiceBuilder::new()
            .layer(CallbackLayer::new(recorder))
            .service(inner);

        let response = svc.oneshot(Request::new(body)).await.unwrap();
        drain(response.into_body()).await.unwrap();

        let events = events.lock().unwrap();
        assert_eq!(events.request_chunks, vec![b"partial".to_vec()]);
        assert_eq!(events.request_body_errors, vec!["boom".to_string()]);
        // A body that errored must not also report a clean end of stream.
        assert!(events.request_end_trailers.is_empty());
    }

    // The unit type can stand in as the request handler when only response
    // callbacks are of interest.
    #[tokio::test]
    async fn unit_request_handler_is_noop() {
        #[derive(Clone)]
        struct MakeResponseOnly(Arc<Mutex<u32>>);

        struct CountResp(Arc<Mutex<u32>>);
        impl ResponseHandler for CountResp {
            fn on_response(&mut self, _parts: &response::Parts) {
                *self.0.lock().unwrap() += 1;
            }
            fn on_service_error<E: std::fmt::Display + 'static>(&mut self, _error: &E) {}
        }

        impl MakeCallbackHandler for MakeResponseOnly {
            type RequestHandler = ();
            type ResponseHandler = CountResp;

            fn make_handler(
                &self,
                _request: &request::Parts,
            ) -> (Self::RequestHandler, Self::ResponseHandler) {
                ((), CountResp(self.0.clone()))
            }
        }

        let counter = Arc::new(Mutex::new(0));
        let make = MakeResponseOnly(counter.clone());

        let inner = tower::service_fn(
            |req: Request<RequestBody<Full<Bytes>, ()>>| async move {
                drain(req.into_body()).await.unwrap();
                Ok::<_, Infallible>(Response::new(Full::new(Bytes::from_static(b"hi"))))
            },
        );
        let svc = ServiceBuilder::new()
            .layer(CallbackLayer::new(make))
            .service(inner);

        let response = svc
            .oneshot(Request::new(Full::new(Bytes::from_static(b"ping"))))
            .await
            .unwrap();
        drain(response.into_body()).await.unwrap();

        assert_eq!(*counter.lock().unwrap(), 1);
    }

    #[tokio::test]
    async fn observes_response_trailers_on_end() {
        let recorder = Recorder::default();
        let events = recorder.0.clone();

        let mut trailers = HeaderMap::new();
        trailers.insert("x-resp-trailer", "xyz".parse().unwrap());
        let frames: Vec<Result<http_body::Frame<Bytes>, Infallible>> = vec![
            Ok(http_body::Frame::data(Bytes::from_static(b"part-1"))),
            Ok(http_body::Frame::data(Bytes::from_static(b"part-2"))),
            Ok(http_body::Frame::trailers(trailers.clone())),
        ];
        // `service_fn` needs a reusable closure, so the one-shot body is
        // parked in a slot and taken out on the (single) call.
        let body_slot = Arc::new(Mutex::new(Some(StreamBody::new(stream::iter(frames)))));

        let inner = tower::service_fn({
            let body_slot = body_slot.clone();
            move |req: Request<RequestBody<Full<Bytes>, ReqH>>| {
                let body = body_slot.lock().unwrap().take().expect("called once");
                async move {
                    drain(req.into_body()).await.unwrap();
                    Ok::<_, Infallible>(Response::new(body))
                }
            }
        });
        let svc = ServiceBuilder::new()
            .layer(CallbackLayer::new(recorder))
            .service(inner);

        let response = svc
            .oneshot(Request::new(Full::new(Bytes::from_static(b"ping"))))
            .await
            .unwrap();
        drain(response.into_body()).await.unwrap();

        let events = events.lock().unwrap();
        assert_eq!(events.response_seen, 1);
        assert_eq!(
            events.response_chunks,
            vec![b"part-1".to_vec(), b"part-2".to_vec()]
        );
        assert_eq!(events.response_end_trailers.len(), 1);
        assert_eq!(events.response_end_trailers[0].as_ref(), Some(&trailers));
        assert!(events.response_body_errors.is_empty());
        assert!(events.response_service_errors.is_empty());
    }

    #[tokio::test]
    async fn observes_response_body_error() {
        #[derive(Debug)]
        struct BodyErr;
        impl std::fmt::Display for BodyErr {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("body-boom")
            }
        }
        impl std::error::Error for BodyErr {}

        let recorder = Recorder::default();
        let events = recorder.0.clone();

        let inner = tower::service_fn(
            |req: Request<RequestBody<Full<Bytes>, ReqH>>| async move {
                drain(req.into_body()).await.unwrap();
                let frames: Vec<Result<http_body::Frame<Bytes>, BodyErr>> = vec![
                    Ok(http_body::Frame::data(Bytes::from_static(b"partial"))),
                    Err(BodyErr),
                ];
                Ok::<_, Infallible>(Response::new(StreamBody::new(stream::iter(frames))))
            },
        );
        let svc = ServiceBuilder::new()
            .layer(CallbackLayer::new(recorder))
            .service(inner);

        let response = svc
            .oneshot(Request::new(Full::new(Bytes::new())))
            .await
            .unwrap();
        // The response body is expected to fail part-way through.
        let _ = drain(response.into_body()).await;

        let events = events.lock().unwrap();
        assert_eq!(events.response_seen, 1);
        assert_eq!(events.response_chunks, vec![b"partial".to_vec()]);
        assert_eq!(events.response_body_errors, vec!["body-boom".to_string()]);
        // A body error is not a service error, and no clean end is reported.
        assert!(events.response_service_errors.is_empty());
        assert!(events.response_end_trailers.is_empty());
    }

    #[tokio::test]
    async fn observes_service_error() {
        #[derive(Debug)]
        struct SvcErr;
        impl std::fmt::Display for SvcErr {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("svc-boom")
            }
        }
        impl std::error::Error for SvcErr {}

        let recorder = Recorder::default();
        let events = recorder.0.clone();

        let inner = tower::service_fn(
            |_req: Request<RequestBody<Full<Bytes>, ReqH>>| async move {
                Err::<Response<Full<Bytes>>, _>(SvcErr)
            },
        );
        let svc = ServiceBuilder::new()
            .layer(CallbackLayer::new(recorder))
            .service(inner);

        let result = svc
            .oneshot(Request::new(Full::new(Bytes::from_static(b"ping"))))
            .await;
        // The Ok type is not `Debug`, so `expect_err` is unavailable; use let-else.
        let Err(err) = result else {
            panic!("expected service error");
        };
        assert_eq!(err.to_string(), "svc-boom");

        let events = events.lock().unwrap();
        // Without a response there is no head and no response-body activity.
        assert_eq!(events.response_seen, 0);
        assert!(events.response_chunks.is_empty());
        assert!(events.response_end_trailers.is_empty());
        assert!(events.response_body_errors.is_empty());
        assert_eq!(events.response_service_errors, vec!["svc-boom".to_string()]);
    }
}