1use std::collections::hash_map::Entry;
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::time::Duration;
20
21use async_trait::async_trait;
22use tracing::{debug, info};
23
24use crate::{
25 communication::build_message, LocalUriProvider, UListener, UMessage, UMessageBuilder, UStatus,
26 UTransport, UUri,
27};
28
29use super::{RegistrationError, RequestHandler, RpcServer, ServiceInvocationError, UPayload};
30
31struct RequestListener {
32 request_handler: Arc<dyn RequestHandler>,
33 transport: Arc<dyn UTransport>,
34}
35
36impl RequestListener {
37 async fn process_valid_request(&self, resource_id: u16, request_message: UMessage) {
38 let transport_clone = self.transport.clone();
39 let request_handler_clone = self.request_handler.clone();
40 let mut response_builder =
41 UMessageBuilder::response_for_request(request_message.attributes_unchecked());
42
43 let request_message_id = request_message.id_unchecked().to_hyphenated_string();
44 let request_timeout = request_message.ttl_unchecked();
45 let payload_format = request_message.payload_format().unwrap_or_default();
46 let payload = request_message.payload;
47 let request_payload = payload.map(|data| UPayload::new(data, payload_format));
48
49 debug!(
50 ttl = request_timeout,
51 id = request_message_id,
52 resource_id = resource_id,
53 "processing RPC request"
54 );
55
56 let invocation_result_future = request_handler_clone.handle_request(
57 resource_id,
58 &request_message.attributes,
59 request_payload,
60 );
61 let outcome = tokio::time::timeout(
62 Duration::from_millis(request_timeout as u64),
63 invocation_result_future,
64 )
65 .await
66 .map_err(|_e| {
67 info!(ttl = request_timeout, "request handler timed out");
68 ServiceInvocationError::DeadlineExceeded
69 })
70 .and_then(|v| v);
71
72 let response = match outcome {
73 Ok(response_payload) => build_message(&mut response_builder, response_payload),
74 Err(e) => {
75 let error = UStatus::from(e);
76 response_builder
77 .with_comm_status(error.get_code())
78 .build_with_protobuf_payload(&error)
79 }
80 };
81
82 match response {
83 Ok(response_message) => {
84 if let Err(e) = transport_clone.send(response_message).await {
85 info!(ucode = e.code.value(), "failed to send response message");
86 }
87 }
88 Err(e) => {
89 info!("failed to create response message: {}", e);
90 }
91 }
92 }
93}
94
95#[async_trait]
96impl UListener for RequestListener {
97 async fn on_receive(&self, msg: UMessage) {
98 if msg.is_request() {
99 let method_id = msg.sink_unchecked().resource_id();
101 self.process_valid_request(method_id, msg).await;
102 } else {
103 debug!(
104 message_type = msg.type_unchecked().to_cloudevent_type(),
105 "ignoring non-request message received by RPC server"
106 );
107 }
108 }
109}
110
111pub struct InMemoryRpcServer {
121 transport: Arc<dyn UTransport>,
122 uri_provider: Arc<dyn LocalUriProvider>,
123 request_listeners: tokio::sync::Mutex<HashMap<u16, Arc<dyn UListener>>>,
124}
125
126impl InMemoryRpcServer {
127 pub fn new(transport: Arc<dyn UTransport>, uri_provider: Arc<dyn LocalUriProvider>) -> Self {
129 InMemoryRpcServer {
130 transport,
131 uri_provider,
132 request_listeners: tokio::sync::Mutex::new(HashMap::new()),
133 }
134 }
135
136 fn validate_sink_filter(filter: &UUri) -> Result<(), RegistrationError> {
137 if !filter.is_rpc_method() {
138 return Err(RegistrationError::InvalidFilter(
139 "RPC endpoint's resource ID must be in range [0x0001, 0x7FFF]".to_string(),
140 ));
141 }
142 Ok(())
143 }
144
145 fn validate_origin_filter(filter: Option<&UUri>) -> Result<(), RegistrationError> {
146 if let Some(uri) = filter {
147 if !uri.is_rpc_response() {
148 return Err(RegistrationError::InvalidFilter(
149 "origin filter's resource ID must be 0".to_string(),
150 ));
151 }
152 }
153 Ok(())
154 }
155
156 #[cfg(test)]
157 async fn contains_endpoint(&self, resource_id: u16) -> bool {
158 let listener_map = self.request_listeners.lock().await;
159 listener_map.contains_key(&resource_id)
160 }
161}
162
163#[async_trait]
164impl RpcServer for InMemoryRpcServer {
165 async fn register_endpoint(
166 &self,
167 origin_filter: Option<&UUri>,
168 resource_id: u16,
169 request_handler: Arc<dyn RequestHandler>,
170 ) -> Result<(), RegistrationError> {
171 Self::validate_origin_filter(origin_filter)?;
172 let sink_filter = self.uri_provider.get_resource_uri(resource_id);
173 Self::validate_sink_filter(&sink_filter)?;
174
175 let mut listener_map = self.request_listeners.lock().await;
176 if let Entry::Vacant(e) = listener_map.entry(resource_id) {
177 let listener = Arc::new(RequestListener {
178 request_handler,
179 transport: self.transport.clone(),
180 });
181 self.transport
182 .register_listener(
183 origin_filter.unwrap_or(&UUri::any_with_resource_id(
184 crate::uri::RESOURCE_ID_RESPONSE,
185 )),
186 Some(&sink_filter),
187 listener.clone(),
188 )
189 .await
190 .map(|_| {
191 e.insert(listener);
192 })
193 .map_err(RegistrationError::from)
194 } else {
195 Err(RegistrationError::MaxListenersExceeded)
196 }
197 }
198
199 async fn unregister_endpoint(
200 &self,
201 origin_filter: Option<&UUri>,
202 resource_id: u16,
203 _request_handler: Arc<dyn RequestHandler>,
204 ) -> Result<(), RegistrationError> {
205 Self::validate_origin_filter(origin_filter)?;
206 let sink_filter = self.uri_provider.get_resource_uri(resource_id);
207 Self::validate_sink_filter(&sink_filter)?;
208
209 let mut listener_map = self.request_listeners.lock().await;
210 if let Entry::Occupied(entry) = listener_map.entry(resource_id) {
211 let listener = entry.get().to_owned();
212 self.transport
213 .unregister_listener(
214 origin_filter.unwrap_or(&UUri::any_with_resource_id(
215 crate::uri::RESOURCE_ID_RESPONSE,
216 )),
217 Some(&sink_filter),
218 listener,
219 )
220 .await
221 .map(|_| {
222 entry.remove();
223 })
224 .map_err(RegistrationError::from)
225 } else {
226 Err(RegistrationError::NoSuchListener)
227 }
228 }
229}
230
#[cfg(test)]
mod tests {

    use super::*;

    use protobuf::well_known_types::wrappers::StringValue;
    use test_case::test_case;
    use tokio::sync::Notify;

    use crate::{
        communication::rpc::MockRequestHandler, utransport::MockTransport, StaticUriProvider,
        UAttributes, UCode, UUri, UUID,
    };

    /// Creates the URI provider of the local uEntity used by all tests.
    // NOTE(review): arguments presumably are (authority, entity ID, major version) — confirm.
    fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
        Arc::new(StaticUriProvider::new("", 0x0005, 0x02))
    }

    // Registering and then unregistering an endpoint succeeds. The transport is expected to
    // (un)register exactly one listener with the expected source and sink filters.
    #[test_case(None, 0x4A10; "for empty origin filter")]
    #[test_case(Some(UUri::try_from_parts("authority", 0xBF1A, 0x01, 0x0000).unwrap()), 0x4A10; "for specific origin filter")]
    #[test_case(Some(UUri::try_from_parts("*", 0xFFFF, 0x01, 0x0000).unwrap()), 0x7091; "for wildcard origin filter")]
    #[tokio::test]
    async fn test_register_endpoint_succeeds(origin_filter: Option<UUri>, resource_id: u16) {
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        // without an explicit origin filter, the server is expected to fall back to a
        // filter matching any origin with resource ID 0
        let expected_source_filter = origin_filter
            .clone()
            .unwrap_or(UUri::any_with_resource_id(0));
        // shared by the register and unregister expectations, which must use the same filters
        let param_check = move |source_filter: &UUri,
                                sink_filter: &Option<&UUri>,
                                _listener: &Arc<dyn UListener>| {
            source_filter == &expected_source_filter
                && sink_filter.is_some_and(|uri| uri.resource_id == resource_id as u32)
        };
        transport
            .expect_do_register_listener()
            .once()
            .withf(param_check.clone())
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        transport
            .expect_do_unregister_listener()
            .once()
            .withf(param_check)
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        // register the endpoint
        let register_result = rpc_server
            .register_endpoint(origin_filter.as_ref(), resource_id, request_handler.clone())
            .await;
        assert!(register_result.is_ok());
        assert!(rpc_server.contains_endpoint(resource_id).await);

        // and unregister it again
        let unregister_result = rpc_server
            .unregister_endpoint(origin_filter.as_ref(), resource_id, request_handler)
            .await;
        assert!(unregister_result.is_ok());
        assert!(!rpc_server.contains_endpoint(resource_id).await);
    }

    // (Un)registering an endpoint fails with InvalidFilter, without invoking the transport, if
    // the resource ID is not a valid RPC method ID or the origin filter's resource ID is not 0.
    #[test_case(None, 0x0000; "for resource ID 0")]
    #[test_case(None, 0x8000; "for resource ID out of range")]
    #[test_case(Some(UUri::try_from_parts("*", 0xFFFF, 0xFF, 0x0001).unwrap()), 0x4A10; "for source filter with invalid resource ID")]
    #[tokio::test]
    async fn test_register_endpoint_fails(origin_filter: Option<UUri>, resource_id: u16) {
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        transport.expect_do_register_listener().never();
        transport.expect_do_unregister_listener().never();

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        // registration is rejected and no endpoint is recorded
        let register_result = rpc_server
            .register_endpoint(origin_filter.as_ref(), resource_id, request_handler.clone())
            .await;
        assert!(register_result.is_err_and(|e| matches!(e, RegistrationError::InvalidFilter(_v))));
        assert!(!rpc_server.contains_endpoint(resource_id).await);

        // unregistration is rejected by the same validation
        let unregister_result = rpc_server
            .unregister_endpoint(origin_filter.as_ref(), resource_id, request_handler)
            .await;
        assert!(unregister_result.is_err_and(|e| matches!(e, RegistrationError::InvalidFilter(_v))));
    }

    // Registering a second endpoint for an already registered resource ID fails, and the
    // transport is asked to register a listener only once.
    #[tokio::test]
    async fn test_register_endpoint_fails_for_duplicate_endpoint() {
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        transport
            .expect_do_register_listener()
            .once()
            .return_const(Ok(()));

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        // the first registration succeeds
        assert!(rpc_server
            .register_endpoint(None, 0x5000, request_handler.clone())
            .await
            .is_ok());
        let result = rpc_server
            .register_endpoint(None, 0x5000, request_handler)
            .await;

        // the second one is rejected while the original endpoint stays registered
        assert!(result.is_err_and(|e| matches!(e, RegistrationError::MaxListenersExceeded)));
        assert!(rpc_server.contains_endpoint(0x5000).await);
    }

    // Unregistering an endpoint that has never been registered fails without invoking the
    // transport.
    #[tokio::test]
    async fn test_unregister_endpoint_fails_for_non_existing_endpoint() {
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        transport.expect_do_unregister_listener().never();

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        // precondition: no endpoint registered for the resource ID
        assert!(!rpc_server.contains_endpoint(0x5000).await);
        let result = rpc_server
            .unregister_endpoint(None, 0x5000, request_handler)
            .await;

        assert!(result.is_err_and(|e| matches!(e, RegistrationError::NoSuchListener)));
    }

    // A request is passed to the handler with its resource ID, attributes and payload. The
    // handler's result is sent back as a response carrying the request's message ID.
    #[tokio::test]
    async fn test_request_listener_invokes_operation_successfully() {
        let mut request_handler = MockRequestHandler::new();
        let mut transport = MockTransport::new();
        // signalled by the transport mock once the response has been sent
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        let request_payload = StringValue {
            value: "Hello".to_string(),
            ..Default::default()
        };
        let message_id = UUID::build();
        let message_id_clone = message_id.clone();
        let message_source = UUri::try_from("up://localhost/A100/1/0").unwrap();
        let message_source_clone = message_source.clone();

        request_handler
            .expect_handle_request()
            .once()
            .withf(move |resource_id, message_attributes, request_payload| {
                if let Some(pl) = request_payload {
                    let message_source = message_attributes.source.as_ref().unwrap();
                    let msg: StringValue = pl.extract_protobuf().unwrap();
                    msg.value == *"Hello"
                        && *resource_id == 0x7000_u16
                        && *message_source == message_source_clone
                } else {
                    false
                }
            })
            .returning(|_resource_id, _message_attributes, _request_payload| {
                let response_payload = UPayload::try_from_protobuf(StringValue {
                    value: "Hello World".to_string(),
                    ..Default::default()
                })
                .unwrap();
                Ok(Some(response_payload))
            });
        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                let msg: StringValue = response_message.extract_protobuf().unwrap();
                msg.value == *"Hello World"
                    && response_message.is_response()
                    // a successful response has either no commstatus or OK
                    && response_message
                        .commstatus()
                        .is_none_or(|code| code == UCode::OK)
                    && response_message.request_id_unchecked() == &message_id_clone
            })
            .returning(move |_msg| {
                notify_clone.notify_one();
                Ok(())
            });
        let request_message = UMessageBuilder::request(
            UUri::try_from("up://localhost/A200/1/7000").unwrap(),
            message_source,
            5_000,
        )
        .with_message_id(message_id)
        .build_with_protobuf_payload(&request_payload)
        .unwrap();

        let request_listener = RequestListener {
            request_handler: Arc::new(request_handler),
            transport: Arc::new(transport),
        };
        request_listener.on_receive(request_message).await;
        // the response must have been sent within 2 seconds
        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
        assert!(result.is_ok());
    }

    // An error returned by the handler is sent back as a response whose commstatus and
    // UStatus payload carry the error's code.
    #[tokio::test]
    async fn test_request_listener_invokes_operation_erroneously() {
        let mut request_handler = MockRequestHandler::new();
        let mut transport = MockTransport::new();
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        let message_id = UUID::build();
        let message_id_clone = message_id.clone();

        request_handler
            .expect_handle_request()
            .once()
            .withf(|resource_id, _message_attributes, _request_payload| *resource_id == 0x7000_u16)
            .returning(|_resource_id, _message_attributes, _request_payload| {
                Err(ServiceInvocationError::NotFound(
                    "no such object".to_string(),
                ))
            });
        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                let error: UStatus = response_message.extract_protobuf().unwrap();
                error.get_code() == UCode::NOT_FOUND
                    && response_message.is_response()
                    && response_message.commstatus_unchecked() == error.get_code()
                    && response_message.request_id_unchecked() == &message_id_clone
            })
            .returning(move |_msg| {
                notify_clone.notify_one();
                Ok(())
            });
        let request_message = UMessageBuilder::request(
            UUri::try_from("up://localhost/A200/1/7000").unwrap(),
            UUri::try_from("up://localhost/A100/1/0").unwrap(),
            5_000,
        )
        .with_message_id(message_id)
        .build()
        .unwrap();

        let request_listener = RequestListener {
            request_handler: Arc::new(request_handler),
            transport: Arc::new(transport),
        };
        request_listener.on_receive(request_message).await;
        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
        assert!(result.is_ok());
    }

    // A handler that does not complete within the request's TTL results in a
    // DEADLINE_EXCEEDED response.
    #[tokio::test]
    async fn test_request_listener_times_out() {
        struct NonRespondingHandler;
        // A handler that sleeps for 2000 ms, i.e. much longer than the request's TTL used
        // below.
        #[async_trait]
        impl RequestHandler for NonRespondingHandler {
            async fn handle_request(
                &self,
                resource_id: u16,
                _message_attributes: &UAttributes,
                _request_payload: Option<UPayload>,
            ) -> Result<Option<UPayload>, ServiceInvocationError> {
                assert_eq!(resource_id, 0x7000);
                tokio::time::sleep(Duration::from_millis(2000)).await;
                // not expected to be reached, because the invocation is expected to time out
                // before the sleep ends
                Ok(None)
            }
        }

        let request_handler = NonRespondingHandler {};
        let mut transport = MockTransport::new();
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        let message_id = UUID::build();
        let message_id_clone = message_id.clone();

        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                let error: UStatus = response_message.extract_protobuf().unwrap();
                error.get_code() == UCode::DEADLINE_EXCEEDED
                    && response_message.is_response()
                    && response_message.commstatus_unchecked() == error.get_code()
                    && response_message.request_id_unchecked() == &message_id_clone
            })
            .returning(move |_msg| {
                notify_clone.notify_one();
                Ok(())
            });
        let request_message = UMessageBuilder::request(
            UUri::try_from("up://localhost/A200/1/7000").unwrap(),
            UUri::try_from("up://localhost/A100/1/0").unwrap(),
            // TTL much shorter than the handler's sleep
            100,
        )
        .with_message_id(message_id)
        .build()
        .expect("should have been able to create RPC Request message");

        let request_listener = RequestListener {
            request_handler: Arc::new(request_handler),
            transport: Arc::new(transport),
        };
        request_listener.on_receive(request_message).await;
        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
        assert!(result.is_ok());
    }
}