1#[cfg(not(feature = "std"))]
2extern crate alloc;
3#[cfg(not(feature = "std"))]
4use alloc::{sync::Arc, vec::Vec};
5
6#[cfg(feature = "std")]
7use std::sync::Arc;
8
9use crate::{Frame, Message};
10
11#[cfg(feature = "derive")]
12use crate::Errorizable;
13
/// Shorthand for results whose error type is [`RouterError`].
pub type Result<T> = ::core::result::Result<T, RouterError>;

/// Errors returned while routing a [`Frame`] to a handler.
///
/// Derives `Clone`/`PartialEq`/`Eq` so callers can compare routing outcomes
/// directly (e.g. `assert_eq!(res, Err(RouterError::UnknownRoute))`).
#[cfg_attr(feature = "derive", derive(Errorizable))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RouterError {
    /// `dispatch` was asked to route a message type that has no configured route.
    #[cfg_attr(feature = "derive", error("No route configured for provided message"))]
    UnknownRoute,
}
22
23crate::impl_error_display!(RouterError {
24 UnknownRoute => "No route configured for provided message",
25});
26
27pub trait RouterPolicy: Send + Sync {
28 fn dispatch<T: Message + Send + 'static>(&self, message: Arc<Frame>) -> Result<()>;
29}
30
31#[macro_export]
32macro_rules! routes {
33 (@dispatch $self:ident, $message:ident, [ $( ($MsgTy:ty, $this:ident, $arg:pat_param, $handler:block) ),* ]) => {
35 $(
36 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$MsgTy>() {
37 let $arg = $message;
38 let $this = $self;
39 { $handler }
40 return Ok(());
41 }
42 )*
43 Err($crate::router::RouterError::UnknownRoute)
44 };
45
46 (@dispatch $self:ident, $message:ident, [ $( ($MsgTy:ty, $arg:pat_param, $handler:block) ),* ]) => {
48 $(
49 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$MsgTy>() {
50 let $arg = $message;
51 { $handler }
52 return Ok(());
53 }
54 )*
55 Err($crate::router::RouterError::UnknownRoute)
56 };
57
58 (
59 $RouterName:ident { $( $field:ident : $fty:ty ),* $(,)? } :
60 $(
61 $MsgTy:ty | $($arg:ident),* | $handler:block
62 )+
63 ) => {
64 struct $RouterName { $( $field : $fty ),* }
65 impl $crate::router::RouterPolicy for $RouterName {
66 #[cfg(not(feature = "std"))]
67 fn dispatch<T: $crate::Message + Send + 'static>(&self, message: alloc::sync::Arc<$crate::Frame>) -> $crate::router::Result<()> {
68 $(
69 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$MsgTy>() {
70 let ($($arg),*) = (self, message);
71 $handler;
72 return Ok(());
73 }
74 )*
75 Err($crate::router::RouterError::UnknownRoute)
76 }
77
78 #[cfg(feature = "std")]
79 fn dispatch<T: $crate::Message + Send + 'static>(&self, message: std::sync::Arc<$crate::Frame>) -> $crate::router::Result<()> {
80 $(
81 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$MsgTy>() {
82 let ($($arg),*) = (self, message);
83 $handler;
84 return Ok(());
85 }
86 )*
87 Err($crate::router::RouterError::UnknownRoute)
88 }
89 }
90 };
91}
92
#[cfg(test)]
mod tests {
    //! Routing tests: frames are composed, dispatched through a router,
    //! forwarded over `mpsc` channels, and decoded on the receiving side.

    use std::sync::{mpsc, Arc};
    use std::time::Duration;

    use crate::compose;
    use crate::der::Sequence;
    use crate::router::{RouterError, RouterPolicy};
    use crate::Frame;
    // Only needed by the `derive(Beamable)` attributes below, which are
    // themselves gated on the `derive` feature.
    #[cfg(feature = "derive")]
    use crate::Beamable;

    #[cfg_attr(feature = "derive", derive(Beamable))]
    #[derive(Sequence, Clone, Debug, PartialEq)]
    pub struct HealthCheck {
        pub uptime: u64,
    }

    #[cfg(not(feature = "derive"))]
    impl crate::Message for HealthCheck {
        const MUST_BE_CONFIDENTIAL: bool = false;
        const MUST_BE_NON_REPUDIABLE: bool = false;
        const MUST_BE_COMPRESSED: bool = false;
        const MUST_BE_PRIORITIZED: bool = false;
        const MIN_VERSION: crate::Version = crate::Version::V0;
    }

    #[cfg_attr(feature = "derive", derive(Beamable))]
    #[derive(Sequence, Clone, Debug, PartialEq)]
    pub struct Payment {
        pub from: String,
        pub amount: u64,
    }

    #[cfg(not(feature = "derive"))]
    impl crate::Message for Payment {
        const MUST_BE_CONFIDENTIAL: bool = false;
        const MUST_BE_NON_REPUDIABLE: bool = false;
        const MUST_BE_COMPRESSED: bool = false;
        const MUST_BE_PRIORITIZED: bool = false;
        const MIN_VERSION: crate::Version = crate::Version::V0;
    }

    /// Composes the `i`-th payment frame (id `p-{i}`, amount `i`).
    fn payment_frame(i: usize) -> Result<Frame, Box<dyn std::error::Error>> {
        Ok(compose! {
            V0: id: format!("p-{i}"),
            order: 1u64,
            message: Payment {
                from: "alice".into(),
                amount: i as u64
            }
        }?)
    }

    /// Composes the `i`-th health-check frame (id `h-{i}`, uptime `i`).
    fn health_frame(i: usize) -> Result<Frame, Box<dyn std::error::Error>> {
        Ok(compose! {
            V0: id: format!("h-{i}"),
            order: 1u64,
            message: HealthCheck {
                uptime: i as u64
            }
        }?)
    }

    /// Each message type is delivered, in order, to its own channel.
    #[test]
    fn test_mpsc_channel_routing() -> Result<(), Box<dyn std::error::Error>> {
        #[cfg(feature = "derive")]
        routes! {
            ChannelRouter {
                payment_tx: mpsc::Sender<Arc<Frame>>,
                health_tx: mpsc::Sender<Arc<Frame>>,
            }:
            Payment |router, msg| {
                let _ = router.payment_tx.send(msg);
            }
            HealthCheck |router, msg| {
                let _ = router.health_tx.send(msg);
            }
        }

        // Hand-written equivalent of the `routes!` expansion above.
        #[cfg(not(feature = "derive"))]
        struct ChannelRouter {
            payment_tx: mpsc::Sender<Arc<Frame>>,
            health_tx: mpsc::Sender<Arc<Frame>>,
        }

        #[cfg(not(feature = "derive"))]
        impl RouterPolicy for ChannelRouter {
            fn dispatch<M: crate::Message + Send + 'static>(
                &self,
                message: Arc<Frame>,
            ) -> crate::router::Result<()> {
                if std::any::TypeId::of::<M>() == std::any::TypeId::of::<Payment>() {
                    let _ = self.payment_tx.send(message);
                    return Ok(());
                }

                if std::any::TypeId::of::<M>() == std::any::TypeId::of::<HealthCheck>() {
                    let _ = self.health_tx.send(message);
                    return Ok(());
                }

                Err(RouterError::UnknownRoute)
            }
        }

        let (payment_tx, payment_rx) = mpsc::channel::<Arc<Frame>>();
        let (health_tx, health_rx) = mpsc::channel::<Arc<Frame>>();
        let router = ChannelRouter { payment_tx, health_tx };

        const N: usize = 5;
        for i in 0..N {
            router.dispatch::<Payment>(Arc::new(payment_frame(i)?))?;
            router.dispatch::<HealthCheck>(Arc::new(health_frame(i)?))?;
        }

        // Every frame was sent synchronously above; the timeout only guards
        // against a routing bug leaving a channel empty.
        let timeout = Duration::from_millis(200);
        for i in 0..N {
            let received_payment = payment_rx.recv_timeout(timeout)?;
            let message: Payment = crate::decode(&received_payment.message)?;
            assert_eq!(received_payment.metadata.id, format!("p-{i}").as_bytes());
            assert_eq!(message, Payment { from: "alice".into(), amount: i as u64 });

            let received_health = health_rx.recv_timeout(timeout)?;
            let message: HealthCheck = crate::decode(&received_health.message)?;
            assert_eq!(received_health.metadata.id, format!("h-{i}").as_bytes());
            assert_eq!(message, HealthCheck { uptime: i as u64 });
        }

        Ok(())
    }

    /// A message type without a route is rejected and never forwarded.
    // Gated like the `routes!` invocation in the test above.
    #[cfg(feature = "derive")]
    #[test]
    fn test_unknown_route_is_rejected() -> Result<(), Box<dyn std::error::Error>> {
        routes! {
            PaymentOnlyRouter {
                payment_tx: mpsc::Sender<Arc<Frame>>,
            }:
            Payment |router, msg| {
                let _ = router.payment_tx.send(msg);
            }
        }

        let (payment_tx, payment_rx) = mpsc::channel::<Arc<Frame>>();
        let router = PaymentOnlyRouter { payment_tx };

        let outcome = router.dispatch::<HealthCheck>(Arc::new(health_frame(0)?));
        assert!(matches!(outcome, Err(RouterError::UnknownRoute)));
        assert!(payment_rx.try_recv().is_err());

        router.dispatch::<Payment>(Arc::new(payment_frame(0)?))?;
        let delivered = payment_rx.try_recv()?;
        assert_eq!(delivered.metadata.id, "p-0".as_bytes());

        Ok(())
    }
}