Skip to main content

sui_http/middleware/callback/
mod.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Middleware for observing both the request and response streams of a
5//! [`tower::Service`] via user-provided callback handlers.
6//!
7//! A [`MakeCallbackHandler`] produces a pair of handlers per request:
8//! a [`RequestHandler`] (invoked as the request body is polled by the
9//! inner service) and a [`ResponseHandler`] (invoked when the response
10//! materializes and as its body is polled by the caller). The inner
11//! service's request body is wrapped as a [`RequestBody`], and the
12//! response body handed back to the caller is wrapped as a
13//! [`ResponseBody`]; both carry their respective handler along with the
14//! data.
15//!
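//! The request side can be made a no-op by using the unit type `()`
//! (`type RequestHandler = ();`), which implements [`RequestHandler`] with
//! every method a no-op. [`ResponseHandler`] has required methods and no
//! such impl, so the response side always needs a real handler.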
18//!
19//! # Example
20//!
21//! ```
22//! use http::request;
23//! use http::response;
24//! use sui_http::middleware::callback::CallbackLayer;
25//! use sui_http::middleware::callback::MakeCallbackHandler;
26//! use sui_http::middleware::callback::RequestHandler;
27//! use sui_http::middleware::callback::ResponseHandler;
28//!
29//! /// A handler that counts bytes observed on one side of the exchange.
30//! #[derive(Default)]
31//! struct ByteCounter {
32//!     bytes: usize,
33//! }
34//!
35//! impl RequestHandler for ByteCounter {
36//!     fn on_body_chunk<B: bytes::Buf>(&mut self, chunk: &B) {
37//!         self.bytes += chunk.remaining();
38//!     }
39//! }
40//!
41//! impl ResponseHandler for ByteCounter {
42//!     fn on_response(&mut self, _parts: &response::Parts) {}
43//!     fn on_service_error<E: std::fmt::Display + 'static>(&mut self, _error: &E) {}
44//!     fn on_body_chunk<B: bytes::Buf>(&mut self, chunk: &B) {
45//!         self.bytes += chunk.remaining();
46//!     }
47//! }
48//!
49//! #[derive(Clone)]
50//! struct MakeByteCounter;
51//!
52//! impl MakeCallbackHandler for MakeByteCounter {
53//!     type RequestHandler = ByteCounter;
54//!     type ResponseHandler = ByteCounter;
55//!
56//!     fn make_handler(
57//!         &self,
58//!         _request: &request::Parts,
59//!     ) -> (Self::RequestHandler, Self::ResponseHandler) {
60//!         (ByteCounter::default(), ByteCounter::default())
61//!     }
62//! }
63//!
64//! let _layer = CallbackLayer::new(MakeByteCounter);
65//! ```
66//!
67//! # Body type change
68//!
69//! The wrapped [`Callback`] service hands the inner service a
70//! `Request<RequestBody<B, M::RequestHandler>>` rather than the original
71//! `Request<B>`. For body-polymorphic inner services (e.g. `axum::Router`
72//! or generic `tower` services), this is transparent.
73//!
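//! Monomorphic inner services that require a specific body type — for
//! example `tonic::transport::Channel`, which expects `tonic::body::Body` —
//! must convert the wrapped body back into that type at the call site, e.g.
//! with a `map_request` layer placed between [`CallbackLayer`] and the
//! inner service: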
77//!
78//! ```ignore
79//! let service = tower::ServiceBuilder::new()
80//!     .layer(CallbackLayer::new(MakeByteCounter))
//!     .map_request(|req: http::Request<_>| req.map(tonic::body::Body::new))
82//!     .service(tonic_service);
83//! ```
84//!
85//! [`Callback`]: self::Callback
86
87use http::HeaderMap;
88use http::request;
89use http::response;
90
91mod body;
92mod future;
93mod layer;
94mod service;
95
96pub use self::body::RequestBody;
97pub use self::body::ResponseBody;
98pub use self::future::ResponseFuture;
99pub use self::layer::CallbackLayer;
100pub use self::service::Callback;
101
102/// Factory for per-request callback handler pairs.
103///
104/// A single [`MakeCallbackHandler`] implementation produces, for each
105/// inbound request, one [`RequestHandler`] (observes the request body)
106/// and one [`ResponseHandler`] (observes the response and its body).
107pub trait MakeCallbackHandler {
108    /// Handler invoked while the request body is polled by the inner
109    /// service.
110    type RequestHandler: RequestHandler;
111    /// Handler invoked when the response materializes and while its body
112    /// is polled.
113    type ResponseHandler: ResponseHandler;
114
115    /// Build the handler pair for a single request.
116    fn make_handler(
117        &self,
118        request: &request::Parts,
119    ) -> (Self::RequestHandler, Self::ResponseHandler);
120}
121
122/// Observes the request body as it is polled by the inner service.
123///
124/// All methods default to no-ops, so implementors only override the
125/// events they care about. The unit type `()` has a blanket impl with
126/// every method a no-op; use `type RequestHandler = ();` when only the
127/// response side is interesting.
128pub trait RequestHandler {
129    /// Called once per data frame yielded by the request body.
130    fn on_body_chunk<B>(&mut self, _chunk: &B)
131    where
132        B: bytes::Buf,
133    {
134        // do nothing
135    }
136
137    /// Called at most once when the request body stream ends.
138    ///
139    /// `trailers` is `Some` if the final frame was a trailers frame,
140    /// otherwise `None`.
141    fn on_end_of_stream(&mut self, _trailers: Option<&HeaderMap>) {
142        // do nothing
143    }
144
145    /// Called when polling the request body yields an error.
146    fn on_body_error<E>(&mut self, _error: &E)
147    where
148        E: std::fmt::Display + 'static,
149    {
150        // do nothing
151    }
152}
153
154impl RequestHandler for () {}
155
156/// Observes the response as seen by the caller: the response parts, the
157/// response body, and the service-level error that occurs if the inner
158/// service's future resolves to `Err` before any response is produced.
159///
160/// Body-level methods default to no-ops.
161pub trait ResponseHandler {
162    /// Called exactly once when the inner service produces a response.
163    fn on_response(&mut self, response: &response::Parts);
164
165    /// Called when the inner service's future resolves to `Err` (no
166    /// response is produced). Response body errors are reported
167    /// separately through [`Self::on_body_error`].
168    fn on_service_error<E>(&mut self, error: &E)
169    where
170        E: std::fmt::Display + 'static;
171
172    /// Called once per data frame yielded by the response body.
173    fn on_body_chunk<B>(&mut self, _chunk: &B)
174    where
175        B: bytes::Buf,
176    {
177        // do nothing
178    }
179
180    /// Called at most once when the response body stream ends.
181    fn on_end_of_stream(&mut self, _trailers: Option<&HeaderMap>) {
182        // do nothing
183    }
184
185    /// Called when polling the response body yields an error.
186    fn on_body_error<E>(&mut self, _error: &E)
187    where
188        E: std::fmt::Display + 'static,
189    {
190        // do nothing
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Buf;
    use bytes::Bytes;
    use futures::stream;
    use http::Request;
    use http::Response;
    use http_body::Body;
    use http_body_util::BodyExt;
    use http_body_util::Full;
    use http_body_util::StreamBody;
    use std::convert::Infallible;
    use std::sync::Arc;
    use std::sync::Mutex;
    use tower::ServiceBuilder;
    use tower::ServiceExt;

    /// Events recorded by a test handler pair. One `Arc<Mutex<_>>` is shared
    /// between the request and response handlers so a test can assert on
    /// everything both sides observed. Each field records one kind of event,
    /// in the order those events fired.
    #[derive(Debug, Default, PartialEq, Eq)]
    struct Events {
        request_chunks: Vec<Vec<u8>>,
        request_end_trailers: Vec<Option<HeaderMap>>,
        request_body_errors: Vec<String>,
        response_seen: u32,
        response_chunks: Vec<Vec<u8>>,
        response_end_trailers: Vec<Option<HeaderMap>>,
        response_body_errors: Vec<String>,
        response_service_errors: Vec<String>,
    }

    /// Handler factory whose handlers all write into one shared `Events` log.
    #[derive(Clone, Default)]
    struct Recorder(Arc<Mutex<Events>>);

    /// Request-side handler backed by the shared log.
    struct ReqH(Arc<Mutex<Events>>);
    /// Response-side handler backed by the shared log.
    struct RespH(Arc<Mutex<Events>>);

    impl RequestHandler for ReqH {
        fn on_body_chunk<B: Buf>(&mut self, chunk: &B) {
            self.0
                .lock()
                .unwrap()
                .request_chunks
                .push(chunk.chunk().to_vec());
        }
        fn on_end_of_stream(&mut self, trailers: Option<&HeaderMap>) {
            self.0
                .lock()
                .unwrap()
                .request_end_trailers
                .push(trailers.cloned());
        }
        fn on_body_error<E: std::fmt::Display + 'static>(&mut self, error: &E) {
            self.0
                .lock()
                .unwrap()
                .request_body_errors
                .push(error.to_string());
        }
    }

    impl ResponseHandler for RespH {
        fn on_response(&mut self, _parts: &response::Parts) {
            self.0.lock().unwrap().response_seen += 1;
        }
        fn on_service_error<E: std::fmt::Display + 'static>(&mut self, error: &E) {
            self.0
                .lock()
                .unwrap()
                .response_service_errors
                .push(error.to_string());
        }
        fn on_body_chunk<B: Buf>(&mut self, chunk: &B) {
            self.0
                .lock()
                .unwrap()
                .response_chunks
                .push(chunk.chunk().to_vec());
        }
        fn on_end_of_stream(&mut self, trailers: Option<&HeaderMap>) {
            self.0
                .lock()
                .unwrap()
                .response_end_trailers
                .push(trailers.cloned());
        }
        fn on_body_error<E: std::fmt::Display + 'static>(&mut self, error: &E) {
            self.0
                .lock()
                .unwrap()
                .response_body_errors
                .push(error.to_string());
        }
    }

    impl MakeCallbackHandler for Recorder {
        type RequestHandler = ReqH;
        type ResponseHandler = RespH;

        fn make_handler(
            &self,
            _request: &request::Parts,
        ) -> (Self::RequestHandler, Self::ResponseHandler) {
            (ReqH(self.0.clone()), RespH(self.0.clone()))
        }
    }

    /// Polls `body` to completion and discards its data, so that the
    /// callback handler attached to it observes every frame. It is used for
    /// both request and response bodies. In a real server, hyper does this
    /// implicitly; in tests we have to poll the body ourselves.
    async fn drain<B: Body>(body: B) -> Result<(), B::Error> {
        // `collect` already polls every frame; the buffered data itself is
        // not needed, so drop it instead of copying it into `Bytes`.
        body.collect().await.map(drop)
    }

    #[tokio::test]
    async fn observes_request_chunks_and_clean_end() {
        let recorder = Recorder::default();
        let events = recorder.0.clone();

        let inner = tower::service_fn(
            |req: Request<RequestBody<Full<Bytes>, ReqH>>| async move {
                drain(req.into_body()).await.unwrap();
                Ok::<_, Infallible>(Response::new(Full::new(Bytes::from_static(b"ok"))))
            },
        );
        let svc = ServiceBuilder::new()
            .layer(CallbackLayer::new(recorder))
            .service(inner);

        let request = Request::new(Full::new(Bytes::from_static(b"hello world")));
        let response = svc.oneshot(request).await.unwrap();
        drain(response.into_body()).await.unwrap();

        let events = events.lock().unwrap();
        assert_eq!(events.request_chunks, vec![b"hello world".to_vec()]);
        assert_eq!(events.request_end_trailers, vec![None]);
        assert!(events.request_body_errors.is_empty());
        // Regression guard on the response side.
        assert_eq!(events.response_seen, 1);
        assert_eq!(events.response_chunks, vec![b"ok".to_vec()]);
        assert_eq!(events.response_end_trailers, vec![None]);
        assert!(events.response_body_errors.is_empty());
        assert!(events.response_service_errors.is_empty());
    }

    #[tokio::test]
    async fn observes_request_trailers_on_end() {
        let recorder = Recorder::default();
        let events = recorder.0.clone();

        let mut trailers = HeaderMap::new();
        trailers.insert("x-req-trailer", "abc".parse().unwrap());
        let frames: Vec<Result<http_body::Frame<Bytes>, Infallible>> = vec![
            Ok(http_body::Frame::data(Bytes::from_static(b"chunk-1"))),
            Ok(http_body::Frame::data(Bytes::from_static(b"chunk-2"))),
            Ok(http_body::Frame::trailers(trailers.clone())),
        ];
        let body = StreamBody::new(stream::iter(frames));

        let inner = tower::service_fn(
            |req: Request<RequestBody<StreamBody<_>, ReqH>>| async move {
                drain(req.into_body()).await.unwrap();
                Ok::<_, Infallible>(Response::new(Full::new(Bytes::new())))
            },
        );
        let svc = ServiceBuilder::new()
            .layer(CallbackLayer::new(recorder))
            .service(inner);

        let response = svc.oneshot(Request::new(body)).await.unwrap();
        drain(response.into_body()).await.unwrap();

        let events = events.lock().unwrap();
        assert_eq!(
            events.request_chunks,
            vec![b"chunk-1".to_vec(), b"chunk-2".to_vec()]
        );
        assert_eq!(events.request_end_trailers.len(), 1);
        assert_eq!(events.request_end_trailers[0].as_ref(), Some(&trailers));
        assert!(events.request_body_errors.is_empty());
    }

    #[tokio::test]
    async fn observes_request_body_error() {
        #[derive(Debug)]
        struct BodyErr;
        impl std::fmt::Display for BodyErr {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("boom")
            }
        }
        impl std::error::Error for BodyErr {}

        let recorder = Recorder::default();
        let events = recorder.0.clone();

        let frames: Vec<Result<http_body::Frame<Bytes>, BodyErr>> = vec![
            Ok(http_body::Frame::data(Bytes::from_static(b"partial"))),
            Err(BodyErr),
        ];
        let body = StreamBody::new(stream::iter(frames));

        let inner = tower::service_fn(
            |req: Request<RequestBody<StreamBody<_>, ReqH>>| async move {
                // Ignore the error; we just want to trigger it.
                let _ = drain(req.into_body()).await;
                Ok::<_, Infallible>(Response::new(Full::new(Bytes::new())))
            },
        );
        let svc = ServiceBuilder::new()
            .layer(CallbackLayer::new(recorder))
            .service(inner);

        let response = svc.oneshot(Request::new(body)).await.unwrap();
        drain(response.into_body()).await.unwrap();

        let events = events.lock().unwrap();
        assert_eq!(events.request_chunks, vec![b"partial".to_vec()]);
        assert_eq!(events.request_body_errors, vec!["boom".to_string()]);
        // An error terminates the stream; no clean end-of-stream fires.
        assert!(events.request_end_trailers.is_empty());
    }

    /// Checks at compile time and at runtime that `type RequestHandler = ();`
    /// is accepted, has no observable side effects, and leaves the response
    /// side fully functional.
    #[tokio::test]
    async fn unit_request_handler_is_noop() {
        #[derive(Clone)]
        struct MakeResponseOnly(Arc<Mutex<u32>>);

        struct CountResp(Arc<Mutex<u32>>);
        impl ResponseHandler for CountResp {
            fn on_response(&mut self, _parts: &response::Parts) {
                *self.0.lock().unwrap() += 1;
            }
            fn on_service_error<E: std::fmt::Display + 'static>(&mut self, _error: &E) {}
        }

        impl MakeCallbackHandler for MakeResponseOnly {
            type RequestHandler = ();
            type ResponseHandler = CountResp;

            fn make_handler(
                &self,
                _request: &request::Parts,
            ) -> (Self::RequestHandler, Self::ResponseHandler) {
                ((), CountResp(self.0.clone()))
            }
        }

        let counter = Arc::new(Mutex::new(0));
        let make = MakeResponseOnly(counter.clone());

        let inner = tower::service_fn(
            |req: Request<RequestBody<Full<Bytes>, ()>>| async move {
                drain(req.into_body()).await.unwrap();
                Ok::<_, Infallible>(Response::new(Full::new(Bytes::from_static(b"hi"))))
            },
        );
        let svc = ServiceBuilder::new()
            .layer(CallbackLayer::new(make))
            .service(inner);

        let response = svc
            .oneshot(Request::new(Full::new(Bytes::from_static(b"ping"))))
            .await
            .unwrap();
        drain(response.into_body()).await.unwrap();

        assert_eq!(*counter.lock().unwrap(), 1);
    }

    #[tokio::test]
    async fn observes_response_trailers_on_end() {
        let recorder = Recorder::default();
        let events = recorder.0.clone();

        let mut trailers = HeaderMap::new();
        trailers.insert("x-resp-trailer", "xyz".parse().unwrap());
        let frames: Vec<Result<http_body::Frame<Bytes>, Infallible>> = vec![
            Ok(http_body::Frame::data(Bytes::from_static(b"part-1"))),
            Ok(http_body::Frame::data(Bytes::from_static(b"part-2"))),
            Ok(http_body::Frame::trailers(trailers.clone())),
        ];
        // `StreamBody` isn't `Clone` and `service_fn` takes an `Fn`, so we
        // smuggle the single-use body through a `Mutex<Option<_>>`.
        let body_slot = Arc::new(Mutex::new(Some(StreamBody::new(stream::iter(frames)))));

        let inner = tower::service_fn({
            let body_slot = body_slot.clone();
            move |req: Request<RequestBody<Full<Bytes>, ReqH>>| {
                let body = body_slot.lock().unwrap().take().expect("called once");
                async move {
                    drain(req.into_body()).await.unwrap();
                    Ok::<_, Infallible>(Response::new(body))
                }
            }
        });
        let svc = ServiceBuilder::new()
            .layer(CallbackLayer::new(recorder))
            .service(inner);

        let response = svc
            .oneshot(Request::new(Full::new(Bytes::from_static(b"ping"))))
            .await
            .unwrap();
        drain(response.into_body()).await.unwrap();

        let events = events.lock().unwrap();
        assert_eq!(events.response_seen, 1);
        assert_eq!(
            events.response_chunks,
            vec![b"part-1".to_vec(), b"part-2".to_vec()]
        );
        assert_eq!(events.response_end_trailers.len(), 1);
        assert_eq!(events.response_end_trailers[0].as_ref(), Some(&trailers));
        assert!(events.response_body_errors.is_empty());
        assert!(events.response_service_errors.is_empty());
    }

    #[tokio::test]
    async fn observes_response_body_error() {
        #[derive(Debug)]
        struct BodyErr;
        impl std::fmt::Display for BodyErr {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("body-boom")
            }
        }
        impl std::error::Error for BodyErr {}

        let recorder = Recorder::default();
        let events = recorder.0.clone();

        let inner = tower::service_fn(
            |req: Request<RequestBody<Full<Bytes>, ReqH>>| async move {
                drain(req.into_body()).await.unwrap();
                let frames: Vec<Result<http_body::Frame<Bytes>, BodyErr>> = vec![
                    Ok(http_body::Frame::data(Bytes::from_static(b"partial"))),
                    Err(BodyErr),
                ];
                Ok::<_, Infallible>(Response::new(StreamBody::new(stream::iter(frames))))
            },
        );
        let svc = ServiceBuilder::new()
            .layer(CallbackLayer::new(recorder))
            .service(inner);

        let response = svc
            .oneshot(Request::new(Full::new(Bytes::new())))
            .await
            .unwrap();
        // Drain but ignore the body error; we only care about the callback.
        let _ = drain(response.into_body()).await;

        let events = events.lock().unwrap();
        assert_eq!(events.response_seen, 1);
        assert_eq!(events.response_chunks, vec![b"partial".to_vec()]);
        assert_eq!(events.response_body_errors, vec!["body-boom".to_string()]);
        assert!(events.response_service_errors.is_empty());
        // An error terminates the stream; no clean end-of-stream fires.
        assert!(events.response_end_trailers.is_empty());
    }

    #[tokio::test]
    async fn observes_service_error() {
        #[derive(Debug)]
        struct SvcErr;
        impl std::fmt::Display for SvcErr {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("svc-boom")
            }
        }
        impl std::error::Error for SvcErr {}

        let recorder = Recorder::default();
        let events = recorder.0.clone();

        let inner = tower::service_fn(
            |_req: Request<RequestBody<Full<Bytes>, ReqH>>| async move {
                Err::<Response<Full<Bytes>>, _>(SvcErr)
            },
        );
        let svc = ServiceBuilder::new()
            .layer(CallbackLayer::new(recorder))
            .service(inner);

        let result = svc
            .oneshot(Request::new(Full::new(Bytes::from_static(b"ping"))))
            .await;
        let err = match result {
            Ok(_) => panic!("expected service error"),
            Err(err) => err,
        };
        assert_eq!(err.to_string(), "svc-boom");

        let events = events.lock().unwrap();
        // The response itself never materialized.
        assert_eq!(events.response_seen, 0);
        assert!(events.response_chunks.is_empty());
        assert!(events.response_end_trailers.is_empty());
        assert!(events.response_body_errors.is_empty());
        // Service error routed to the response handler.
        assert_eq!(events.response_service_errors, vec!["svc-boom".to_string()]);
    }
}