1use std::collections::hash_map::Entry;
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::time::Duration;
20
21use async_trait::async_trait;
22use tracing::{debug, info};
23
24use crate::{
25 communication::build_message, LocalUriProvider, UAttributes, UAttributesError,
26 UAttributesValidators, UCode, UListener, UMessage, UMessageBuilder, UStatus, UTransport, UUri,
27};
28
29use super::{RegistrationError, RequestHandler, RpcServer, ServiceInvocationError, UPayload};
30
31struct RequestListener {
32 request_handler: Arc<dyn RequestHandler>,
33 transport: Arc<dyn UTransport>,
34}
35
36impl RequestListener {
37 async fn process_valid_request(&self, resource_id: u16, request_message: UMessage) {
38 let transport_clone = self.transport.clone();
39 let request_handler_clone = self.request_handler.clone();
40
41 let request_id = request_message
42 .attributes
43 .get_or_default()
44 .id
45 .get_or_default();
46 let request_timeout = request_message
47 .attributes
48 .get_or_default()
49 .ttl
50 .unwrap_or(10_000);
51 let payload = request_message.payload;
52 let payload_format = request_message
53 .attributes
54 .get_or_default()
55 .payload_format
56 .enum_value_or_default();
57 let request_payload = payload.map(|data| UPayload::new(data, payload_format));
58
59 debug!(ttl = request_timeout, id = %request_id, "processing RPC request");
60
61 let invocation_result_future = request_handler_clone.handle_request(
62 resource_id,
63 &request_message.attributes,
64 request_payload,
65 );
66 let outcome = tokio::time::timeout(
67 Duration::from_millis(request_timeout as u64),
68 invocation_result_future,
69 )
70 .await
71 .map_err(|_e| {
72 info!(ttl = request_timeout, "request handler timed out");
73 ServiceInvocationError::DeadlineExceeded
74 })
75 .and_then(|v| v);
76
77 let response = match outcome {
78 Ok(response_payload) => {
79 let mut builder = UMessageBuilder::response_for_request(
80 request_message.attributes.get_or_default(),
81 );
82 build_message(&mut builder, response_payload)
83 }
84 Err(e) => {
85 let error = UStatus::from(e);
86 UMessageBuilder::response_for_request(request_message.attributes.get_or_default())
87 .with_comm_status(error.get_code())
88 .build_with_protobuf_payload(&error)
89 }
90 };
91
92 match response {
93 Ok(response_message) => {
94 if let Err(e) = transport_clone.send(response_message).await {
95 info!(ucode = e.code.value(), "failed to send response message");
96 }
97 }
98 Err(e) => {
99 info!("failed to create response message: {}", e);
100 }
101 }
102 }
103
104 async fn process_invalid_request(
105 &self,
106 validation_error: UAttributesError,
107 request_attributes: &UAttributes,
108 ) {
109 let (Some(id), Some(source_address)) = (
111 request_attributes.id.to_owned().into_option(),
112 request_attributes
113 .source
114 .to_owned()
115 .into_option()
116 .filter(|uri| uri.is_rpc_response()),
117 ) else {
118 debug!("invalid request message does not contain enough data to create response");
119 return;
120 };
121
122 debug!(id = %id, "processing invalid request message");
123
124 let response_payload =
125 UStatus::fail_with_code(UCode::INVALID_ARGUMENT, validation_error.to_string());
126 let Ok(response_message) = UMessageBuilder::response(
127 source_address,
128 id,
129 request_attributes.sink.get_or_default().to_owned(),
130 )
131 .with_comm_status(response_payload.get_code())
132 .build_with_protobuf_payload(&response_payload) else {
133 info!("failed to create error message");
134 return;
135 };
136
137 if let Err(e) = self.transport.send(response_message).await {
138 info!(ucode = e.code.value(), "failed to send error response");
139 }
140 }
141}
142
143#[async_trait]
144impl UListener for RequestListener {
145 async fn on_receive(&self, msg: UMessage) {
146 let Some(attributes) = msg.attributes.as_ref() else {
147 debug!("ignoring invalid message having no attributes");
148 return;
149 };
150
151 let validator = UAttributesValidators::Request.validator();
152 if let Err(e) = validator.validate(attributes) {
153 self.process_invalid_request(e, attributes).await;
154 } else if let Some(resource_id) = attributes
155 .sink
156 .as_ref()
157 .and_then(|uri| u16::try_from(uri.resource_id).ok())
158 {
159 self.process_valid_request(resource_id, msg).await;
161 }
162 }
163}
164
165pub struct InMemoryRpcServer {
175 transport: Arc<dyn UTransport>,
176 uri_provider: Arc<dyn LocalUriProvider>,
177 request_listeners: tokio::sync::Mutex<HashMap<u16, Arc<dyn UListener>>>,
178}
179
180impl InMemoryRpcServer {
181 pub fn new(transport: Arc<dyn UTransport>, uri_provider: Arc<dyn LocalUriProvider>) -> Self {
183 InMemoryRpcServer {
184 transport,
185 uri_provider,
186 request_listeners: tokio::sync::Mutex::new(HashMap::new()),
187 }
188 }
189
190 fn validate_sink_filter(filter: &UUri) -> Result<(), RegistrationError> {
191 if !filter.is_rpc_method() {
192 return Err(RegistrationError::InvalidFilter(
193 "RPC endpoint's resource ID must be in range [0x0001, 0x7FFF]".to_string(),
194 ));
195 }
196 Ok(())
197 }
198
199 fn validate_origin_filter(filter: Option<&UUri>) -> Result<(), RegistrationError> {
200 if let Some(uri) = filter {
201 if !uri.is_rpc_response() {
202 return Err(RegistrationError::InvalidFilter(
203 "origin filter's resource ID must be 0".to_string(),
204 ));
205 }
206 }
207 Ok(())
208 }
209
210 #[cfg(test)]
211 async fn contains_endpoint(&self, resource_id: u16) -> bool {
212 let listener_map = self.request_listeners.lock().await;
213 listener_map.contains_key(&resource_id)
214 }
215}
216
217#[async_trait]
218impl RpcServer for InMemoryRpcServer {
219 async fn register_endpoint(
220 &self,
221 origin_filter: Option<&UUri>,
222 resource_id: u16,
223 request_handler: Arc<dyn RequestHandler>,
224 ) -> Result<(), RegistrationError> {
225 Self::validate_origin_filter(origin_filter)?;
226 let sink_filter = self.uri_provider.get_resource_uri(resource_id);
227 Self::validate_sink_filter(&sink_filter)?;
228
229 let mut listener_map = self.request_listeners.lock().await;
230 if let Entry::Vacant(e) = listener_map.entry(resource_id) {
231 let listener = Arc::new(RequestListener {
232 request_handler,
233 transport: self.transport.clone(),
234 });
235 self.transport
236 .register_listener(
237 origin_filter.unwrap_or(&UUri::any_with_resource_id(
238 crate::uri::RESOURCE_ID_RESPONSE,
239 )),
240 Some(&sink_filter),
241 listener.clone(),
242 )
243 .await
244 .map(|_| {
245 e.insert(listener);
246 })
247 .map_err(RegistrationError::from)
248 } else {
249 Err(RegistrationError::MaxListenersExceeded)
250 }
251 }
252
253 async fn unregister_endpoint(
254 &self,
255 origin_filter: Option<&UUri>,
256 resource_id: u16,
257 _request_handler: Arc<dyn RequestHandler>,
258 ) -> Result<(), RegistrationError> {
259 Self::validate_origin_filter(origin_filter)?;
260 let sink_filter = self.uri_provider.get_resource_uri(resource_id);
261 Self::validate_sink_filter(&sink_filter)?;
262
263 let mut listener_map = self.request_listeners.lock().await;
264 if let Entry::Occupied(entry) = listener_map.entry(resource_id) {
265 let listener = entry.get().to_owned();
266 self.transport
267 .unregister_listener(
268 origin_filter.unwrap_or(&UUri::any_with_resource_id(
269 crate::uri::RESOURCE_ID_RESPONSE,
270 )),
271 Some(&sink_filter),
272 listener,
273 )
274 .await
275 .map(|_| {
276 entry.remove();
277 })
278 .map_err(RegistrationError::from)
279 } else {
280 Err(RegistrationError::NoSuchListener)
281 }
282 }
283}
284
#[cfg(test)]
mod tests {

    use super::*;

    use protobuf::well_known_types::wrappers::StringValue;
    use test_case::test_case;
    use tokio::sync::Notify;

    use crate::{
        communication::rpc::MockRequestHandler, utransport::MockTransport, StaticUriProvider,
        UAttributes, UMessageType, UPriority, UUri, UUID,
    };

    /// The URI of the method that the test requests are addressed to.
    const METHOD_URI: &str = "up://localhost/A200/1/7000";
    /// The URI that the test requests originate from.
    const REPLY_TO_URI: &str = "up://localhost/A100/1/0";

    fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
        Arc::new(StaticUriProvider::new("", 0x0005, 0x02))
    }

    /// Creates the listener under test from the given handler and (mock) transport.
    fn new_request_listener(
        request_handler: impl RequestHandler + 'static,
        transport: MockTransport,
    ) -> RequestListener {
        RequestListener {
            request_handler: Arc::new(request_handler),
            transport: Arc::new(transport),
        }
    }

    /// Checks whether the message's request ID equals the given ID.
    fn is_reply_to(message: &UMessage, request_id: &UUID) -> bool {
        message.attributes.get_or_default().reqid.get_or_default() == request_id
    }

    /// Checks whether the message has a communication status with the given code.
    fn has_commstatus(message: &UMessage, expected_code: UCode) -> bool {
        message
            .attributes
            .get_or_default()
            .commstatus
            .is_some_and(|status| status.enum_value_or_default() == expected_code)
    }

    /// Asserts that the given notification is triggered within two seconds.
    async fn assert_notified(notify: &Notify) {
        let outcome = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
        assert!(outcome.is_ok());
    }

    #[test_case(None, 0x4A10; "for empty origin filter")]
    #[test_case(Some(UUri::try_from_parts("authority", 0xBF1A, 0x01, 0x0000).unwrap()), 0x4A10; "for specific origin filter")]
    #[test_case(Some(UUri::try_from_parts("*", 0xFFFF, 0x01, 0x0000).unwrap()), 0x7091; "for wildcard origin filter")]
    #[tokio::test]
    async fn test_register_endpoint_succeeds(origin_filter: Option<UUri>, resource_id: u16) {
        // GIVEN an RpcServer for a transport which accepts one registration and one deregistration
        let handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let source_filter_to_expect = origin_filter
            .clone()
            .unwrap_or(UUri::any_with_resource_id(0));
        let filters_match = move |source_filter: &UUri,
                                  sink_filter: &Option<&UUri>,
                                  _listener: &Arc<dyn UListener>| {
            *source_filter == source_filter_to_expect
                && sink_filter.is_some_and(|uri| uri.resource_id == u32::from(resource_id))
        };
        transport
            .expect_do_register_listener()
            .once()
            .withf(filters_match.clone())
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        transport
            .expect_do_unregister_listener()
            .once()
            .withf(filters_match)
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), new_uri_provider());

        // WHEN registering a request handler
        let registration = rpc_server
            .register_endpoint(origin_filter.as_ref(), resource_id, handler.clone())
            .await;

        // THEN the endpoint is known to the server
        assert!(registration.is_ok());
        assert!(rpc_server.contains_endpoint(resource_id).await);

        // AND WHEN unregistering the request handler again
        let deregistration = rpc_server
            .unregister_endpoint(origin_filter.as_ref(), resource_id, handler)
            .await;

        // THEN the endpoint is gone
        assert!(deregistration.is_ok());
        assert!(!rpc_server.contains_endpoint(resource_id).await);
    }

    #[test_case(None, 0x0000; "for resource ID 0")]
    #[test_case(None, 0x8000; "for resource ID out of range")]
    #[test_case(Some(UUri::try_from_parts("*", 0xFFFF, 0xFF, 0x0001).unwrap()), 0x4A10; "for source filter with invalid resource ID")]
    #[tokio::test]
    async fn test_register_endpoint_fails(origin_filter: Option<UUri>, resource_id: u16) {
        // GIVEN an RpcServer for a transport which must never be touched
        let handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        transport.expect_do_register_listener().never();
        transport.expect_do_unregister_listener().never();

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), new_uri_provider());

        // WHEN registering a request handler using invalid parameters
        let registration = rpc_server
            .register_endpoint(origin_filter.as_ref(), resource_id, handler.clone())
            .await;

        // THEN the registration is rejected and no endpoint is added
        assert!(matches!(
            registration,
            Err(RegistrationError::InvalidFilter(_))
        ));
        assert!(!rpc_server.contains_endpoint(resource_id).await);

        // AND unregistering using the same parameters is rejected as well
        let deregistration = rpc_server
            .unregister_endpoint(origin_filter.as_ref(), resource_id, handler)
            .await;
        assert!(matches!(
            deregistration,
            Err(RegistrationError::InvalidFilter(_))
        ));
    }

    #[tokio::test]
    async fn test_register_endpoint_fails_for_duplicate_endpoint() {
        // GIVEN an RpcServer with an endpoint that has already been registered
        let handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        transport
            .expect_do_register_listener()
            .once()
            .return_const(Ok(()));

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), new_uri_provider());

        let first_registration = rpc_server
            .register_endpoint(None, 0x5000, handler.clone())
            .await;
        assert!(first_registration.is_ok());

        // WHEN registering a handler for the same endpoint again
        let second_registration = rpc_server.register_endpoint(None, 0x5000, handler).await;

        // THEN the second registration fails while the original endpoint remains in place
        assert!(matches!(
            second_registration,
            Err(RegistrationError::MaxListenersExceeded)
        ));
        assert!(rpc_server.contains_endpoint(0x5000).await);
    }

    #[tokio::test]
    async fn test_unregister_endpoint_fails_for_non_existing_endpoint() {
        // GIVEN an RpcServer without any registered endpoints
        let handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        transport.expect_do_unregister_listener().never();

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), new_uri_provider());
        assert!(!rpc_server.contains_endpoint(0x5000).await);

        // WHEN unregistering a handler for an unknown endpoint
        let deregistration = rpc_server.unregister_endpoint(None, 0x5000, handler).await;

        // THEN the attempt fails without involving the transport
        assert!(matches!(
            deregistration,
            Err(RegistrationError::NoSuchListener)
        ));
    }

    #[tokio::test]
    async fn test_request_listener_returns_response_for_invalid_request() {
        // GIVEN a listener whose handler must never be invoked
        let mut request_handler = MockRequestHandler::new();
        let mut transport = MockTransport::new();
        let response_sent = Arc::new(Notify::new());
        let response_sent_signal = response_sent.clone();
        let message_id = UUID::build();
        let expected_request_id = message_id.clone();

        request_handler.expect_handle_request().never();
        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                if !response_message.is_response() {
                    return false;
                }
                if !is_reply_to(response_message, &expected_request_id) {
                    return false;
                }
                let error: UStatus = response_message.extract_protobuf().unwrap();
                error.get_code() == UCode::INVALID_ARGUMENT
                    && has_commstatus(response_message, error.get_code())
            })
            .returning(move |_msg| {
                response_sent_signal.notify_one();
                Ok(())
            });

        // WHEN a request without a TTL is received
        let attributes_without_ttl = UAttributes {
            type_: UMessageType::UMESSAGE_TYPE_REQUEST.into(),
            sink: UUri::try_from(METHOD_URI).ok().into(),
            source: UUri::try_from(REPLY_TO_URI).ok().into(),
            id: Some(message_id).into(),
            priority: UPriority::UPRIORITY_CS5.into(),
            ..Default::default()
        };
        assert!(
            UAttributesValidators::Request
                .validator()
                .validate(&attributes_without_ttl)
                .is_err(),
            "request message attributes are supposed to be invalid (no TTL)"
        );
        let request_without_ttl = UMessage {
            attributes: Some(attributes_without_ttl).into(),
            ..Default::default()
        };

        let request_listener = new_request_listener(request_handler, transport);
        request_listener.on_receive(request_without_ttl).await;

        // THEN an error response is sent back to the client
        assert_notified(&response_sent).await;
    }

    #[tokio::test]
    async fn test_request_listener_ignores_invalid_request() {
        // GIVEN a listener whose handler and transport must never be invoked
        let mut request_handler = MockRequestHandler::new();
        request_handler.expect_handle_request().never();
        let mut transport = MockTransport::new();
        transport.expect_do_send().never();

        // WHEN a request without a message ID is received
        let attributes_without_id = UAttributes {
            type_: UMessageType::UMESSAGE_TYPE_REQUEST.into(),
            sink: UUri::try_from(METHOD_URI).ok().into(),
            source: UUri::try_from(REPLY_TO_URI).ok().into(),
            ttl: Some(5_000),
            id: None.into(),
            priority: UPriority::UPRIORITY_CS5.into(),
            ..Default::default()
        };
        assert!(
            UAttributesValidators::Request
                .validator()
                .validate(&attributes_without_id)
                .is_err(),
            "request message attributes are supposed to be invalid (no ID)"
        );
        let request_without_id = UMessage {
            attributes: Some(attributes_without_id).into(),
            ..Default::default()
        };

        let request_listener = new_request_listener(request_handler, transport);
        request_listener.on_receive(request_without_id).await;

        // THEN nothing happens: the mocks verify that neither the handler
        // nor the transport have been used
    }

    #[tokio::test]
    async fn test_request_listener_invokes_operation_successfully() {
        // GIVEN a listener whose handler replies with "Hello World"
        let mut request_handler = MockRequestHandler::new();
        let mut transport = MockTransport::new();
        let response_sent = Arc::new(Notify::new());
        let response_sent_signal = response_sent.clone();
        let greeting = StringValue {
            value: "Hello".to_string(),
            ..Default::default()
        };
        let message_id = UUID::build();
        let expected_request_id = message_id.clone();
        let message_source = UUri::try_from(REPLY_TO_URI).unwrap();
        let expected_source = message_source.clone();

        request_handler
            .expect_handle_request()
            .once()
            .withf(move |resource_id, message_attributes, request_payload| {
                let Some(payload) = request_payload else {
                    return false;
                };
                let actual_source = message_attributes.source.as_ref().unwrap();
                let msg: StringValue = payload.extract_protobuf().unwrap();
                msg.value == "Hello"
                    && *resource_id == 0x7000_u16
                    && *actual_source == expected_source
            })
            .returning(|_resource_id, _message_attributes, _request_payload| {
                let reply = StringValue {
                    value: "Hello World".to_string(),
                    ..Default::default()
                };
                Ok(Some(UPayload::try_from_protobuf(reply).unwrap()))
            });
        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                let msg: StringValue = response_message.extract_protobuf().unwrap();
                msg.value == "Hello World"
                    && response_message.is_response()
                    && response_message
                        .attributes
                        .get_or_default()
                        .commstatus
                        .map_or(true, |v| v.enum_value_or_default() == UCode::OK)
                    && is_reply_to(response_message, &expected_request_id)
            })
            .returning(move |_msg| {
                response_sent_signal.notify_one();
                Ok(())
            });

        // WHEN a valid request is received
        let request_message =
            UMessageBuilder::request(UUri::try_from(METHOD_URI).unwrap(), message_source, 5_000)
                .with_message_id(message_id)
                .build_with_protobuf_payload(&greeting)
                .unwrap();

        let request_listener = new_request_listener(request_handler, transport);
        request_listener.on_receive(request_message).await;

        // THEN the handler's reply is sent back to the client
        assert_notified(&response_sent).await;
    }

    #[tokio::test]
    async fn test_request_listener_invokes_operation_erroneously() {
        // GIVEN a listener whose handler fails with a NotFound error
        let mut request_handler = MockRequestHandler::new();
        let mut transport = MockTransport::new();
        let response_sent = Arc::new(Notify::new());
        let response_sent_signal = response_sent.clone();
        let message_id = UUID::build();
        let expected_request_id = message_id.clone();

        request_handler
            .expect_handle_request()
            .once()
            .withf(|resource_id, _message_attributes, _request_payload| *resource_id == 0x7000_u16)
            .returning(|_resource_id, _message_attributes, _request_payload| {
                Err(ServiceInvocationError::NotFound(
                    "no such object".to_string(),
                ))
            });
        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                let error: UStatus = response_message.extract_protobuf().unwrap();
                error.get_code() == UCode::NOT_FOUND
                    && response_message.is_response()
                    && has_commstatus(response_message, error.get_code())
                    && is_reply_to(response_message, &expected_request_id)
            })
            .returning(move |_msg| {
                response_sent_signal.notify_one();
                Ok(())
            });

        // WHEN a valid request is received
        let request_message = UMessageBuilder::request(
            UUri::try_from(METHOD_URI).unwrap(),
            UUri::try_from(REPLY_TO_URI).unwrap(),
            5_000,
        )
        .with_message_id(message_id)
        .build()
        .unwrap();

        let request_listener = new_request_listener(request_handler, transport);
        request_listener.on_receive(request_message).await;

        // THEN the handler's error is reported back to the client
        assert_notified(&response_sent).await;
    }

    #[tokio::test]
    async fn test_request_listener_times_out() {
        // GIVEN a listener whose handler takes much longer than the request's TTL
        struct NonRespondingHandler;
        #[async_trait]
        impl RequestHandler for NonRespondingHandler {
            async fn handle_request(
                &self,
                resource_id: u16,
                _message_attributes: &UAttributes,
                _request_payload: Option<UPayload>,
            ) -> Result<Option<UPayload>, ServiceInvocationError> {
                assert_eq!(resource_id, 0x7000);
                // sleep for considerably longer than the 100ms TTL of the request below
                tokio::time::sleep(Duration::from_millis(2000)).await;
                Ok(None)
            }
        }

        let mut transport = MockTransport::new();
        let response_sent = Arc::new(Notify::new());
        let response_sent_signal = response_sent.clone();
        let message_id = UUID::build();
        let expected_request_id = message_id.clone();

        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                let error: UStatus = response_message.extract_protobuf().unwrap();
                error.get_code() == UCode::DEADLINE_EXCEEDED
                    && response_message.is_response()
                    && has_commstatus(response_message, error.get_code())
                    && is_reply_to(response_message, &expected_request_id)
            })
            .returning(move |_msg| {
                response_sent_signal.notify_one();
                Ok(())
            });

        // WHEN a request with a TTL of 100ms is received
        let request_message = UMessageBuilder::request(
            UUri::try_from(METHOD_URI).unwrap(),
            UUri::try_from(REPLY_TO_URI).unwrap(),
            100,
        )
        .with_message_id(message_id)
        .build()
        .expect("should have been able to create RPC Request message");

        let request_listener = new_request_listener(NonRespondingHandler {}, transport);
        request_listener.on_receive(request_message).await;

        // THEN a response indicating the exceeded deadline is sent back to the client
        assert_notified(&response_sent).await;
    }
}