1use std::{fmt, marker::PhantomData};
10
11use crate::IncomingMessage;
12use crate::codec::Codec;
13use serde::de::DeserializeOwned;
14use tracing::warn;
15
16use super::context::Context;
17use super::failure::FailurePolicy;
18use super::handler::{Handler, HandlerResult, Settle};
19
20pub fn typed<M, T, C, H>(codec: C, inner: H) -> Typed<M, T, C, H>
26where
27 M: IncomingMessage,
28 T: DeserializeOwned + Send + Sync,
29 C: Codec,
30 H: Handler<T>,
31{
32 Typed {
33 codec,
34 inner,
35 decode: FailurePolicy::Drop,
36 _phantom: PhantomData,
37 }
38}
39
/// Handler adapter that decodes the payload of an incoming `M` into a `T`
/// using codec `C`, then delegates to the wrapped handler `H`.
pub struct Typed<M, T, C, H> {
    /// Codec used to decode the raw message payload.
    codec: C,
    /// Wrapped handler that receives the decoded value.
    inner: H,
    /// Policy applied when the payload cannot be decoded.
    decode: FailurePolicy,
    // Marker tying `M` and `T` to the type without storing a value of either.
    // NOTE(review): the `fn(M, T)` form presumably keeps auto traits and drop
    // checking independent of `M`/`T` — confirm that is the intent.
    _phantom: PhantomData<fn(M, T)>,
}
48
49impl<M, T, C, H> Typed<M, T, C, H> {
50 #[must_use]
53 pub fn on_decode_failure(mut self, decode: FailurePolicy) -> Self {
54 self.decode = decode;
55 self
56 }
57}
58
59impl<M, T, C, H> fmt::Debug for Typed<M, T, C, H> {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 f.debug_struct("Typed")
62 .field("decode", &self.decode)
63 .finish_non_exhaustive()
64 }
65}
66
67impl<M, T, C, H> Handler<M> for Typed<M, T, C, H>
68where
69 M: IncomingMessage,
70 T: DeserializeOwned + Send + Sync,
71 C: Codec,
72 H: Handler<T>,
73{
74 async fn handle(&self, msg: &M, ctx: &mut Context<'_>) -> Settle {
75 match self.codec.decode::<T>(msg.payload()) {
76 Ok(value) => self.inner.handle(&value, ctx).await,
77 Err(err) => {
78 warn!(
79 target: "ruststream::dispatch",
80 subscription = %ctx.name(),
81 message_type = std::any::type_name::<T>(),
82 error = %err,
83 "codec decode failed",
84 );
85 match self.decode {
86 FailurePolicy::FailFast => {
87 ctx.fail_fast(&format!("decode failed: {err}"));
88 HandlerResult::drop()
89 }
90 other => other.settlement().unwrap_or_else(HandlerResult::drop),
91 }
92 .into()
93 }
94 }
95 }
96}
97
#[cfg(all(test, feature = "json"))]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicU32, Ordering},
    };

    use super::typed;
    use crate::codec::JsonCodec;
    use crate::runtime::context::{Context, State};
    use crate::runtime::dispatch::Delivery;
    use crate::runtime::failure::FailurePolicy;
    use crate::runtime::handler::{Handler, HandlerResult};
    use crate::{AckError, Headers, IncomingMessage};

    /// In-memory message stub: field `0` is the payload, field `1` the headers.
    struct StubMsg(Vec<u8>, Headers);

    impl IncomingMessage for StubMsg {
        fn payload(&self) -> &[u8] {
            &self.0
        }

        fn headers(&self) -> &Headers {
            &self.1
        }

        // Settling the stub always succeeds.
        async fn ack(self) -> Result<(), AckError> {
            Ok(())
        }

        async fn nack(self, _requeue: bool) -> Result<(), AckError> {
            Ok(())
        }
    }

    /// Inner handler that stores the decoded value in `seen` and acks.
    fn counting_inner(seen: &Arc<AtomicU32>) -> impl Handler<u32> {
        let shared = Arc::clone(seen);
        move |value: &u32, _ctx: &mut Context| {
            let slot = Arc::clone(&shared);
            let observed = *value;
            async move {
                slot.store(observed, Ordering::SeqCst);
                HandlerResult::Ack
            }
        }
    }

    #[tokio::test]
    async fn decoded_value_reaches_inner() {
        let seen = Arc::new(AtomicU32::new(0));
        let handler = typed(JsonCodec, counting_inner(&seen));
        let (state, delivery, headers) = (State::default(), Delivery::empty(), Headers::new());
        let mut ctx = Context::new("typed", &headers, &state, &delivery);

        // A valid JSON number decodes into the `u32` the inner handler expects.
        let msg = StubMsg(b"7".to_vec(), Headers::new());
        let outcome = handler.handle(&msg, &mut ctx).await.outcome();

        assert_eq!(outcome, HandlerResult::Ack);
        assert_eq!(seen.load(Ordering::SeqCst), 7);
    }

    #[tokio::test]
    async fn decode_failure_drops_by_default() {
        let seen = Arc::new(AtomicU32::new(0));
        let handler = typed(JsonCodec, counting_inner(&seen));
        let (state, delivery, headers) = (State::default(), Delivery::empty(), Headers::new());
        let mut ctx = Context::new("typed", &headers, &state, &delivery);

        let msg = StubMsg(b"not json".to_vec(), Headers::new());
        let outcome = handler.handle(&msg, &mut ctx).await.outcome();

        assert_eq!(outcome, HandlerResult::drop());
        assert_eq!(seen.load(Ordering::SeqCst), 0, "inner must not run");
    }

    #[tokio::test]
    async fn decode_failure_requeues_when_overridden() {
        let seen = Arc::new(AtomicU32::new(0));
        let base = typed(JsonCodec, counting_inner(&seen));
        let handler = base.on_decode_failure(FailurePolicy::Retry);
        let (state, delivery, headers) = (State::default(), Delivery::empty(), Headers::new());
        let mut ctx = Context::new("typed", &headers, &state, &delivery);

        let msg = StubMsg(b"not json".to_vec(), Headers::new());
        let outcome = handler.handle(&msg, &mut ctx).await.outcome();

        assert_eq!(outcome, HandlerResult::retry());
        assert_eq!(seen.load(Ordering::SeqCst), 0, "inner must not run");
    }

    #[tokio::test]
    async fn typed_handler_is_debug_and_stub_acks() {
        let seen = Arc::new(AtomicU32::new(0));
        let handler = typed(JsonCodec, counting_inner(&seen));
        let (state, delivery, headers) = (State::default(), Delivery::empty(), Headers::new());
        let mut ctx = Context::new("typed", &headers, &state, &delivery);
        let msg = StubMsg(b"5".to_vec(), Headers::new());

        // Drive the handler once, then check its `Debug` rendering.
        let _ = handler.handle(&msg, &mut ctx).await;
        let rendered = format!("{handler:?}");
        assert!(rendered.contains("Typed"));

        // Exercise the remaining stub methods directly.
        let other = StubMsg(b"x".to_vec(), Headers::new());
        assert!(other.headers().is_empty());
        other.ack().await.unwrap();

        let rejected = StubMsg(Vec::new(), Headers::new());
        rejected.nack(true).await.unwrap();
    }

    #[cfg(feature = "logging")]
    #[tokio::test]
    async fn decode_failure_log_names_subscription_and_type() {
        use std::collections::HashMap;
        use std::sync::Mutex;

        use tracing::field::{Field, Visit};
        use tracing_subscriber::Layer;
        use tracing_subscriber::layer::{Context as LayerContext, SubscriberExt as _};

        /// Collects an event's fields as `name -> rendered value`.
        #[derive(Default)]
        struct FieldGrab(HashMap<String, String>);

        impl Visit for FieldGrab {
            fn record_str(&mut self, field: &Field, value: &str) {
                let key = field.name().to_owned();
                self.0.insert(key, value.to_owned());
            }

            fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
                // Keeps any value already stored under this field name.
                let key = field.name().to_owned();
                self.0.entry(key).or_insert_with(|| format!("{value:?}"));
            }
        }

        /// Layer that appends the fields of every event to a shared list.
        struct Capture(Arc<Mutex<Vec<HashMap<String, String>>>>);

        impl<S: tracing::Subscriber> Layer<S> for Capture {
            fn on_event(&self, event: &tracing::Event<'_>, _ctx: LayerContext<'_, S>) {
                let mut visitor = FieldGrab::default();
                event.record(&mut visitor);
                let mut sink = self.0.lock().unwrap();
                sink.push(visitor.0);
            }
        }

        let events = Arc::new(Mutex::new(Vec::new()));
        let subscriber = tracing_subscriber::registry().with(Capture(Arc::clone(&events)));
        let guard = tracing::subscriber::set_default(subscriber);

        let seen = Arc::new(AtomicU32::new(0));
        let handler = typed(JsonCodec, counting_inner(&seen));
        let (state, delivery, headers) = (State::default(), Delivery::empty(), Headers::new());
        let mut ctx = Context::new("orders.inbound", &headers, &state, &delivery);
        let msg = StubMsg(b"not json".to_vec(), Headers::new());

        let outcome = handler.handle(&msg, &mut ctx).await.outcome();
        assert_eq!(outcome, HandlerResult::drop());
        // Stop capturing before inspecting what was recorded.
        drop(guard);

        let decode_event = {
            let captured = events.lock().unwrap();
            captured
                .iter()
                .find(|fields| {
                    fields.get("message").map(String::as_str) == Some("codec decode failed")
                })
                .cloned()
                .expect("a codec-decode-failed event must be emitted")
        };

        let subscription = decode_event.get("subscription").map(String::as_str);
        assert_eq!(subscription, Some("orders.inbound"));

        let message_type = decode_event.get("message_type").map(String::as_str);
        assert_eq!(message_type, Some("u32"));
    }
}