// up_rust/local_transport.rs

use std::{collections::HashSet, sync::Arc};

use tokio::sync::RwLock;

use crate::{ComparableListener, UListener, UMessage, UStatus, UTransport, UUri};

25#[derive(Eq, PartialEq, Hash)]
26struct RegisteredListener {
27 source_filter: UUri,
28 sink_filter: Option<UUri>,
29 listener: ComparableListener,
30}
31
32impl RegisteredListener {
33 fn matches(&self, source: &UUri, sink: Option<&UUri>) -> bool {
34 if !self.source_filter.matches(source) {
35 return false;
36 }
37
38 if let Some(pattern) = &self.sink_filter {
39 sink.is_some_and(|candidate_sink| pattern.matches(candidate_sink))
40 } else {
41 sink.is_none()
42 }
43 }
44 fn matches_msg(&self, msg: &UMessage) -> bool {
45 if let Some(source) = msg
46 .attributes
47 .as_ref()
48 .and_then(|attribs| attribs.source.as_ref())
49 {
50 self.matches(
51 source,
52 msg.attributes
53 .as_ref()
54 .and_then(|attribs| attribs.sink.as_ref()),
55 )
56 } else {
57 false
58 }
59 }
60 async fn on_receive(&self, msg: UMessage) {
61 self.listener.on_receive(msg).await
62 }
63}
64
65#[derive(Default)]
70pub struct LocalTransport {
71 listeners: RwLock<HashSet<RegisteredListener>>,
72}
73
74impl LocalTransport {
75 async fn dispatch(&self, message: UMessage) {
76 let listeners = self.listeners.read().await;
77 for listener in listeners.iter() {
78 if listener.matches_msg(&message) {
79 listener.on_receive(message.clone()).await;
80 }
81 }
82 }
83}
84
85#[async_trait::async_trait]
86impl UTransport for LocalTransport {
87 async fn send(&self, message: UMessage) -> Result<(), UStatus> {
88 self.dispatch(message).await;
89 Ok(())
90 }
91
92 async fn register_listener(
93 &self,
94 source_filter: &UUri,
95 sink_filter: Option<&UUri>,
96 listener: Arc<dyn UListener>,
97 ) -> Result<(), UStatus> {
98 let registered_listener = RegisteredListener {
99 source_filter: source_filter.to_owned(),
100 sink_filter: sink_filter.map(|u| u.to_owned()),
101 listener: ComparableListener::new(listener),
102 };
103 let mut listeners = self.listeners.write().await;
104 if listeners.contains(®istered_listener) {
105 Err(UStatus::fail_with_code(
106 crate::UCode::ALREADY_EXISTS,
107 "listener already registered for filters",
108 ))
109 } else {
110 listeners.insert(registered_listener);
111 Ok(())
112 }
113 }
114
115 async fn unregister_listener(
116 &self,
117 source_filter: &UUri,
118 sink_filter: Option<&UUri>,
119 listener: Arc<dyn UListener>,
120 ) -> Result<(), UStatus> {
121 let registered_listener = RegisteredListener {
122 source_filter: source_filter.to_owned(),
123 sink_filter: sink_filter.map(|u| u.to_owned()),
124 listener: ComparableListener::new(listener),
125 };
126 let mut listeners = self.listeners.write().await;
127 if listeners.remove(®istered_listener) {
128 Ok(())
129 } else {
130 Err(UStatus::fail_with_code(
131 crate::UCode::NOT_FOUND,
132 "no such listener registered for filters",
133 ))
134 }
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{utransport::MockUListener, LocalUriProvider, StaticUriProvider, UMessageBuilder};

    /// Resource that the test messages are published to.
    const RESOURCE_ID: u16 = 0xa1b3;

    /// Creates the URI provider shared by all tests.
    fn new_uri_provider() -> StaticUriProvider {
        StaticUriProvider::new("my-vehicle", 0x100d, 0x02)
    }

    /// Builds a publish message originating from `RESOURCE_ID`.
    fn new_publish_message(uri_provider: &StaticUriProvider) -> UMessage {
        UMessageBuilder::publish(uri_provider.get_resource_uri(RESOURCE_ID))
            .build()
            .unwrap()
    }

    /// A listener registered for the published topic receives the message
    /// exactly once: the message sent after unregistering must not arrive.
    #[tokio::test]
    async fn test_send_dispatches_to_matching_listener() {
        let mut mock = MockUListener::new();
        mock.expect_on_receive().once().return_const(());
        let listener_ref = Arc::new(mock);
        let uri_provider = new_uri_provider();
        let transport = LocalTransport::default();
        let topic_filter = uri_provider.get_resource_uri(RESOURCE_ID);

        transport
            .register_listener(&topic_filter, None, listener_ref.clone())
            .await
            .unwrap();
        let _ = transport.send(new_publish_message(&uri_provider)).await;

        transport
            .unregister_listener(&topic_filter, None, listener_ref)
            .await
            .unwrap();
        let _ = transport.send(new_publish_message(&uri_provider)).await;
    }

    /// A listener registered for a different resource never receives the
    /// message, neither while registered nor after unregistering.
    #[tokio::test]
    async fn test_send_does_not_dispatch_to_non_matching_listener() {
        let mut mock = MockUListener::new();
        mock.expect_on_receive().never().return_const(());
        let listener_ref = Arc::new(mock);
        let uri_provider = new_uri_provider();
        let transport = LocalTransport::default();
        let other_filter = uri_provider.get_resource_uri(RESOURCE_ID + 10);

        transport
            .register_listener(&other_filter, None, listener_ref.clone())
            .await
            .unwrap();
        let _ = transport.send(new_publish_message(&uri_provider)).await;

        transport
            .unregister_listener(&other_filter, None, listener_ref)
            .await
            .unwrap();
        let _ = transport.send(new_publish_message(&uri_provider)).await;
    }
}