Skip to main content

ruststream/runtime/
typed.rs

//! Typed handler adapter: turns a [`Handler<T>`](Handler) over a decoded value into a
//! [`Handler<M>`](Handler) by decoding the message payload via a [`Codec`].
//!
//! This is the decode boundary between the two middleware levels: raw (pre-decode) middleware
//! wrap the produced `Handler<M>`, while typed (post-decode) middleware wrap the
//! `inner: Handler<T>` passed to [`typed()`]. Both levels use the same
//! [`Layer`](super::Layer) / [`HandlerExt`](super::HandlerExt) machinery, applied to different
//! handler input types.
8
9use std::{fmt, marker::PhantomData};
10
11use crate::IncomingMessage;
12use crate::codec::Codec;
13use serde::de::DeserializeOwned;
14use tracing::warn;
15
16use super::context::Context;
17use super::failure::FailurePolicy;
18use super::handler::{Handler, HandlerResult, Settle};
19
20/// Build a `Handler<M>` that decodes the payload with `codec` into `T` and forwards `&T` to
21/// `inner`.
22///
23/// `inner` is any [`Handler<T>`](Handler) - a closure `Fn(&T) -> _` or a typed middleware stack
24/// built with [`HandlerExt::with`](super::HandlerExt::with).
25pub fn typed<M, T, C, H>(codec: C, inner: H) -> Typed<M, T, C, H>
26where
27    M: IncomingMessage,
28    T: DeserializeOwned + Send + Sync,
29    C: Codec,
30    H: Handler<T>,
31{
32    Typed {
33        codec,
34        inner,
35        decode: FailurePolicy::Drop,
36        _phantom: PhantomData,
37    }
38}
39
/// Handler produced by [`typed`]: decodes each message payload into `T` with a codec and
/// forwards the decoded value to an inner [`Handler<T>`](Handler).
///
/// Override the decode-failure policy with [`Typed::on_decode_failure`].
pub struct Typed<M, T, C, H> {
    /// Codec used to decode the raw message payload into `T`.
    codec: C,
    /// Typed handler (closure or middleware stack) that receives `&T` after a successful decode.
    inner: H,
    /// Policy applied when decoding fails; [`FailurePolicy::Drop`] unless overridden via
    /// [`Typed::on_decode_failure`].
    decode: FailurePolicy,
    // `M` and `T` are only used by the `Handler` impl, so they are tracked here without storing
    // values of either type. NOTE(review): a `fn(M, T)` marker keeps the struct's auto traits
    // (Send/Sync) independent of `M`/`T` but makes it contravariant in both — confirm that
    // variance is intended (`fn() -> (M, T)` would be the covariant equivalent).
    _phantom: PhantomData<fn(M, T)>,
}
48
49impl<M, T, C, H> Typed<M, T, C, H> {
50    /// Sets the [`FailurePolicy`] applied when the codec fails to decode an incoming payload. The
51    /// default is [`FailurePolicy::Drop`].
52    #[must_use]
53    pub fn on_decode_failure(mut self, decode: FailurePolicy) -> Self {
54        self.decode = decode;
55        self
56    }
57}
58
59impl<M, T, C, H> fmt::Debug for Typed<M, T, C, H> {
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        f.debug_struct("Typed")
62            .field("decode", &self.decode)
63            .finish_non_exhaustive()
64    }
65}
66
67impl<M, T, C, H> Handler<M> for Typed<M, T, C, H>
68where
69    M: IncomingMessage,
70    T: DeserializeOwned + Send + Sync,
71    C: Codec,
72    H: Handler<T>,
73{
74    async fn handle(&self, msg: &M, ctx: &mut Context<'_>) -> Settle {
75        match self.codec.decode::<T>(msg.payload()) {
76            Ok(value) => self.inner.handle(&value, ctx).await,
77            Err(err) => {
78                warn!(
79                    target: "ruststream::dispatch",
80                    subscription = %ctx.name(),
81                    message_type = std::any::type_name::<T>(),
82                    error = %err,
83                    "codec decode failed",
84                );
85                match self.decode {
86                    FailurePolicy::FailFast => {
87                        ctx.fail_fast(&format!("decode failed: {err}"));
88                        HandlerResult::drop()
89                    }
90                    other => other.settlement().unwrap_or_else(HandlerResult::drop),
91                }
92                .into()
93            }
94        }
95    }
96}
97
#[cfg(all(test, feature = "json"))]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicU32, Ordering},
    };

    use super::typed;
    use crate::codec::JsonCodec;
    use crate::runtime::context::{Context, State};
    use crate::runtime::dispatch::Delivery;
    use crate::runtime::failure::FailurePolicy;
    use crate::runtime::handler::{Handler, HandlerResult};
    use crate::{AckError, Headers, IncomingMessage};

    /// In-memory message fixture: a raw payload plus headers. Acking and nacking always succeed.
    struct StubMsg(Vec<u8>, Headers);

    impl StubMsg {
        /// Builds a stub carrying a copy of `payload` and an empty header set.
        fn with_payload(payload: &[u8]) -> Self {
            Self(payload.to_vec(), Headers::new())
        }
    }

    impl IncomingMessage for StubMsg {
        fn payload(&self) -> &[u8] {
            self.0.as_slice()
        }

        fn headers(&self) -> &Headers {
            &self.1
        }

        async fn ack(self) -> Result<(), AckError> {
            Ok(())
        }

        async fn nack(self, _requeue: bool) -> Result<(), AckError> {
            Ok(())
        }
    }

    /// Typed inner handler: stores the most recently decoded `u32` in `last_seen`, then acks.
    fn recording_inner(last_seen: &Arc<AtomicU32>) -> impl Handler<u32> {
        let sink = Arc::clone(last_seen);
        move |value: &u32, _ctx: &mut Context| {
            // Copy out everything the future needs so it borrows nothing from the arguments.
            let (sink, decoded) = (Arc::clone(&sink), *value);
            async move {
                sink.store(decoded, Ordering::SeqCst);
                HandlerResult::Ack
            }
        }
    }

    // Plain #[tokio::test]: nothing is spawned, the handler future is awaited inline.
    #[tokio::test]
    async fn decoded_value_reaches_inner() {
        let last_seen = Arc::new(AtomicU32::new(0));
        let handler = typed(JsonCodec, recording_inner(&last_seen));
        let state = State::default();
        let delivery = Delivery::empty();
        let headers = Headers::new();
        let mut ctx = Context::new("typed", &headers, &state, &delivery);

        let msg = StubMsg::with_payload(b"7");
        let outcome = handler.handle(&msg, &mut ctx).await.outcome();

        assert_eq!(outcome, HandlerResult::Ack);
        assert_eq!(last_seen.load(Ordering::SeqCst), 7);
    }

    #[tokio::test]
    async fn decode_failure_drops_by_default() {
        let last_seen = Arc::new(AtomicU32::new(0));
        let handler = typed(JsonCodec, recording_inner(&last_seen));
        let state = State::default();
        let delivery = Delivery::empty();
        let headers = Headers::new();
        let mut ctx = Context::new("typed", &headers, &state, &delivery);

        let garbage = StubMsg::with_payload(b"not json");
        let outcome = handler.handle(&garbage, &mut ctx).await.outcome();

        assert_eq!(outcome, HandlerResult::drop());
        assert_eq!(last_seen.load(Ordering::SeqCst), 0, "inner must not run");
    }

    #[tokio::test]
    async fn decode_failure_requeues_when_overridden() {
        let last_seen = Arc::new(AtomicU32::new(0));
        let handler = typed(JsonCodec, recording_inner(&last_seen))
            .on_decode_failure(FailurePolicy::Retry);
        let state = State::default();
        let delivery = Delivery::empty();
        let headers = Headers::new();
        let mut ctx = Context::new("typed", &headers, &state, &delivery);

        let garbage = StubMsg::with_payload(b"not json");
        let outcome = handler.handle(&garbage, &mut ctx).await.outcome();

        assert_eq!(outcome, HandlerResult::retry());
        assert_eq!(last_seen.load(Ordering::SeqCst), 0, "inner must not run");
    }

    #[tokio::test]
    async fn typed_handler_is_debug_and_stub_acks() {
        let last_seen = Arc::new(AtomicU32::new(0));
        let handler = typed(JsonCodec, recording_inner(&last_seen));
        let state = State::default();
        let delivery = Delivery::empty();
        let headers = Headers::new();
        let mut ctx = Context::new("typed", &headers, &state, &delivery);

        // One delivery fixes the handler's message type before its Debug output is inspected.
        let msg = StubMsg::with_payload(b"5");
        let _ = handler.handle(&msg, &mut ctx).await;
        let rendered = format!("{handler:?}");
        assert!(rendered.contains("Typed"));

        // Cover the StubMsg fixture's own IncomingMessage surface.
        let probe = StubMsg::with_payload(b"x");
        assert!(probe.headers().is_empty());
        probe.ack().await.unwrap();
        StubMsg::with_payload(b"").nack(true).await.unwrap();
    }

    // Records the fields of the event emitted on a decode failure so the test can check that the
    // diagnostic names the subscription and the target type. A tracing subscriber is required,
    // hence the `logging` feature gate.
    #[cfg(feature = "logging")]
    #[tokio::test]
    async fn decode_failure_log_names_subscription_and_type() {
        use std::collections::HashMap;
        use std::sync::Mutex;

        use tracing::field::{Field, Visit};
        use tracing_subscriber::Layer;
        use tracing_subscriber::layer::{Context as LayerContext, SubscriberExt as _};

        /// Field name -> rendered value for a single event.
        #[derive(Default)]
        struct FieldMap(HashMap<String, String>);

        impl Visit for FieldMap {
            fn record_str(&mut self, field: &Field, value: &str) {
                self.0.insert(field.name().into(), value.into());
            }

            // Only fills a slot that a more specific `record_*` call has not already set.
            fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
                self.0
                    .entry(field.name().into())
                    .or_insert_with(|| format!("{value:?}"));
            }
        }

        /// Layer that appends every event's fields to a shared list.
        struct EventSink(Arc<Mutex<Vec<HashMap<String, String>>>>);

        impl<S: tracing::Subscriber> Layer<S> for EventSink {
            fn on_event(&self, event: &tracing::Event<'_>, _ctx: LayerContext<'_, S>) {
                let mut fields = FieldMap::default();
                event.record(&mut fields);
                self.0.lock().unwrap().push(fields.0);
            }
        }

        let events = Arc::new(Mutex::new(Vec::new()));
        let subscriber_guard = tracing::subscriber::set_default(
            tracing_subscriber::registry().with(EventSink(Arc::clone(&events))),
        );

        let last_seen = Arc::new(AtomicU32::new(0));
        let handler = typed(JsonCodec, recording_inner(&last_seen));
        let state = State::default();
        let delivery = Delivery::empty();
        let headers = Headers::new();
        let mut ctx = Context::new("orders.inbound", &headers, &state, &delivery);
        let garbage = StubMsg::with_payload(b"not json");
        let outcome = handler.handle(&garbage, &mut ctx).await.outcome();
        assert_eq!(outcome, HandlerResult::drop());
        // Uninstall the capturing subscriber before inspecting what it recorded.
        drop(subscriber_guard);

        let decode_event = events
            .lock()
            .unwrap()
            .iter()
            .find(|fields| {
                fields
                    .get("message")
                    .is_some_and(|text| text == "codec decode failed")
            })
            .cloned()
            .expect("a codec-decode-failed event must be emitted");

        assert_eq!(
            decode_event.get("subscription").map(String::as_str),
            Some("orders.inbound")
        );
        assert_eq!(
            decode_event.get("message_type").map(String::as_str),
            Some("u32")
        );
    }
}