1use std::collections::hash_map::Entry;
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::time::Duration;
20
21use async_trait::async_trait;
22use tracing::{debug, info};
23
24use crate::{
25 communication::build_message, LocalUriProvider, UAttributes, UAttributesError,
26 UAttributesValidators, UCode, UListener, UMessage, UMessageBuilder, UStatus, UTransport, UUri,
27};
28
29use super::{RegistrationError, RequestHandler, RpcServer, ServiceInvocationError, UPayload};
30
/// A [`UListener`] that is registered with a transport for a single RPC endpoint.
///
/// Incoming messages are validated as RPC requests. Valid requests are passed on to the
/// endpoint's [`RequestHandler`], and the outcome (or an error status) is sent back as a
/// response message using the transport.
struct RequestListener {
    /// The handler that performs the actual invocation of the endpoint's operation.
    request_handler: Arc<dyn RequestHandler>,
    /// The transport used for sending response messages.
    transport: Arc<dyn UTransport>,
}
35
36impl RequestListener {
37 async fn process_valid_request(&self, resource_id: u16, request_message: UMessage) {
38 let transport_clone = self.transport.clone();
39 let request_handler_clone = self.request_handler.clone();
40
41 let request_id = request_message
42 .attributes
43 .get_or_default()
44 .id
45 .get_or_default();
46 let request_timeout = request_message
47 .attributes
48 .get_or_default()
49 .ttl
50 .unwrap_or(10_000);
51 let payload = request_message.payload;
52 let payload_format = request_message
53 .attributes
54 .get_or_default()
55 .payload_format
56 .enum_value_or_default();
57 let request_payload = payload.map(|data| UPayload::new(data, payload_format));
58
59 debug!(ttl = request_timeout, id = %request_id, "processing RPC request");
60
61 let invocation_result_future = request_handler_clone.handle_request(
62 resource_id,
63 &request_message.attributes,
64 request_payload,
65 );
66 let outcome = tokio::time::timeout(
67 Duration::from_millis(request_timeout as u64),
68 invocation_result_future,
69 )
70 .await
71 .map_err(|_e| {
72 info!(ttl = request_timeout, "request handler timed out");
73 ServiceInvocationError::DeadlineExceeded
74 })
75 .and_then(|v| v);
76
77 let response = match outcome {
78 Ok(response_payload) => {
79 let mut builder = UMessageBuilder::response_for_request(
80 request_message.attributes.get_or_default(),
81 );
82 build_message(&mut builder, response_payload)
83 }
84 Err(e) => {
85 let error = UStatus::from(e);
86 UMessageBuilder::response_for_request(request_message.attributes.get_or_default())
87 .with_comm_status(error.get_code())
88 .build_with_protobuf_payload(&error)
89 }
90 };
91
92 match response {
93 Ok(response_message) => {
94 if let Err(e) = transport_clone.send(response_message).await {
95 info!(ucode = e.code.value(), "failed to send response message");
96 }
97 }
98 Err(e) => {
99 info!("failed to create response message: {}", e);
100 }
101 }
102 }
103
104 async fn process_invalid_request(
105 &self,
106 validation_error: UAttributesError,
107 request_attributes: &UAttributes,
108 ) {
109 let (Some(id), Some(source_address)) = (
111 request_attributes.id.to_owned().into_option(),
112 request_attributes
113 .source
114 .to_owned()
115 .into_option()
116 .filter(|uri| uri.is_rpc_response()),
117 ) else {
118 debug!("invalid request message does not contain enough data to create response");
119 return;
120 };
121
122 debug!(id = %id, "processing invalid request message");
123
124 let response_payload =
125 UStatus::fail_with_code(UCode::INVALID_ARGUMENT, validation_error.to_string());
126 let Ok(response_message) = UMessageBuilder::response(
127 source_address,
128 id,
129 request_attributes.sink.get_or_default().to_owned(),
130 )
131 .with_comm_status(response_payload.get_code())
132 .build_with_protobuf_payload(&response_payload) else {
133 info!("failed to create error message");
134 return;
135 };
136
137 if let Err(e) = self.transport.send(response_message).await {
138 info!(ucode = e.code.value(), "failed to send error response");
139 }
140 }
141}
142
143#[async_trait]
144impl UListener for RequestListener {
145 async fn on_receive(&self, msg: UMessage) {
146 let Some(attributes) = msg.attributes.as_ref() else {
147 debug!("ignoring invalid message having no attributes");
148 return;
149 };
150
151 let validator = UAttributesValidators::Request.validator();
152 if let Err(e) = validator.validate(attributes) {
153 self.process_invalid_request(e, attributes).await;
154 } else if let Some(resource_id) = attributes
155 .sink
156 .as_ref()
157 .and_then(|uri| u16::try_from(uri.resource_id).ok())
158 {
159 self.process_valid_request(resource_id, msg).await;
161 }
162 }
163}
164
/// An [`RpcServer`] that keeps track of its registered endpoints in memory.
///
/// For each registered endpoint, a request listener is registered with the underlying
/// [`UTransport`]. The listener validates incoming request messages, passes valid requests
/// on to the endpoint's [`RequestHandler`] and sends the outcome back via the transport.
/// At most one handler can be registered per resource ID.
pub struct InMemoryRpcServer {
    /// The transport used for registering request listeners (and sending responses).
    transport: Arc<dyn UTransport>,
    /// Provides the local URIs of endpoints, which are used as sink filters.
    uri_provider: Arc<dyn LocalUriProvider>,
    /// The listeners currently registered with the transport, keyed by resource ID.
    request_listeners: tokio::sync::Mutex<HashMap<u16, Arc<dyn UListener>>>,
}
179
180impl InMemoryRpcServer {
181 pub fn new(transport: Arc<dyn UTransport>, uri_provider: Arc<dyn LocalUriProvider>) -> Self {
183 InMemoryRpcServer {
184 transport,
185 uri_provider,
186 request_listeners: tokio::sync::Mutex::new(HashMap::new()),
187 }
188 }
189
190 fn validate_sink_filter(filter: &UUri) -> Result<(), RegistrationError> {
191 if !filter.is_rpc_method() {
192 return Err(RegistrationError::InvalidFilter(
193 "RPC endpoint's resource ID must be in range [0x0001, 0x7FFF]".to_string(),
194 ));
195 }
196 Ok(())
197 }
198
199 fn validate_origin_filter(filter: Option<&UUri>) -> Result<(), RegistrationError> {
200 if let Some(uri) = filter {
201 if !uri.is_rpc_response() {
202 return Err(RegistrationError::InvalidFilter(
203 "origin filter's resource ID must be 0".to_string(),
204 ));
205 }
206 }
207 Ok(())
208 }
209
210 #[cfg(test)]
211 async fn contains_endpoint(&self, resource_id: u16) -> bool {
212 let listener_map = self.request_listeners.lock().await;
213 listener_map.contains_key(&resource_id)
214 }
215}
216
217#[async_trait]
218impl RpcServer for InMemoryRpcServer {
219 async fn register_endpoint(
220 &self,
221 origin_filter: Option<&UUri>,
222 resource_id: u16,
223 request_handler: Arc<dyn RequestHandler>,
224 ) -> Result<(), RegistrationError> {
225 Self::validate_origin_filter(origin_filter)?;
226 let sink_filter = self.uri_provider.get_resource_uri(resource_id);
227 Self::validate_sink_filter(&sink_filter)?;
228
229 let mut listener_map = self.request_listeners.lock().await;
230 if let Entry::Vacant(e) = listener_map.entry(resource_id) {
231 let listener = Arc::new(RequestListener {
232 request_handler,
233 transport: self.transport.clone(),
234 });
235 self.transport
236 .register_listener(
237 origin_filter.unwrap_or(&UUri::any_with_resource_id(
238 crate::uri::RESOURCE_ID_RESPONSE,
239 )),
240 Some(&sink_filter),
241 listener.clone(),
242 )
243 .await
244 .map(|_| {
245 e.insert(listener);
246 })
247 .map_err(RegistrationError::from)
248 } else {
249 Err(RegistrationError::MaxListenersExceeded)
250 }
251 }
252
253 async fn unregister_endpoint(
254 &self,
255 origin_filter: Option<&UUri>,
256 resource_id: u16,
257 _request_handler: Arc<dyn RequestHandler>,
258 ) -> Result<(), RegistrationError> {
259 Self::validate_origin_filter(origin_filter)?;
260 let sink_filter = self.uri_provider.get_resource_uri(resource_id);
261 Self::validate_sink_filter(&sink_filter)?;
262
263 let mut listener_map = self.request_listeners.lock().await;
264 if let Entry::Occupied(entry) = listener_map.entry(resource_id) {
265 let listener = entry.get().to_owned();
266 self.transport
267 .unregister_listener(
268 origin_filter.unwrap_or(&UUri::any_with_resource_id(
269 crate::uri::RESOURCE_ID_RESPONSE,
270 )),
271 Some(&sink_filter),
272 listener,
273 )
274 .await
275 .map(|_| {
276 entry.remove();
277 })
278 .map_err(RegistrationError::from)
279 } else {
280 Err(RegistrationError::NoSuchListener)
281 }
282 }
283}
284
#[cfg(test)]
mod tests {

    use super::*;

    use protobuf::well_known_types::wrappers::StringValue;
    use test_case::test_case;
    use tokio::sync::Notify;

    use crate::{
        communication::rpc::MockRequestHandler, utransport::MockTransport, StaticUriProvider,
        UAttributes, UMessageType, UPriority, UUri, UUID,
    };

    /// Creates the URI provider used by the RPC servers under test.
    // NOTE(review): arguments presumably are (authority, entity ID, major version) — confirm.
    fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
        Arc::new(StaticUriProvider::new("", 0x0005, 0x02))
    }

    /// Verifies that an endpoint can be registered and unregistered again, and that the
    /// transport is called with the same source and sink filters both times.
    #[test_case(None, 0x4A10; "for empty origin filter")]
    #[test_case(Some(UUri::try_from_parts("authority", 0xBF1A, 0x01, 0x0000).unwrap()), 0x4A10; "for specific origin filter")]
    #[test_case(Some(UUri::try_from_parts("*", 0xFFFF, 0x01, 0x0000).unwrap()), 0x7091; "for wildcard origin filter")]
    #[tokio::test]
    async fn test_register_endpoint_succeeds(origin_filter: Option<UUri>, resource_id: u16) {
        // GIVEN an RPC server using a transport that accepts (un-)registration of listeners
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        // without an explicit origin filter, the server is expected to use a filter
        // matching any origin with resource ID 0
        let expected_source_filter = origin_filter
            .clone()
            .unwrap_or(UUri::any_with_resource_id(0));
        // the same filters are expected for registration and unregistration
        let param_check = move |source_filter: &UUri,
                                sink_filter: &Option<&UUri>,
                                _listener: &Arc<dyn UListener>| {
            source_filter == &expected_source_filter
                && sink_filter.is_some_and(|uri| uri.resource_id == resource_id as u32)
        };
        transport
            .expect_do_register_listener()
            .once()
            .withf(param_check.clone())
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        transport
            .expect_do_unregister_listener()
            .once()
            .withf(param_check)
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        // WHEN registering an endpoint
        let register_result = rpc_server
            .register_endpoint(origin_filter.as_ref(), resource_id, request_handler.clone())
            .await;
        // THEN registration succeeds and the endpoint is known to the server
        assert!(register_result.is_ok());
        assert!(rpc_server.contains_endpoint(resource_id).await);

        // AND WHEN unregistering the endpoint again
        let unregister_result = rpc_server
            .unregister_endpoint(origin_filter.as_ref(), resource_id, request_handler)
            .await;
        // THEN unregistration succeeds and the endpoint is no longer known to the server
        assert!(unregister_result.is_ok());
        assert!(!rpc_server.contains_endpoint(resource_id).await);
    }

    /// Verifies that (un-)registering an endpoint fails with `InvalidFilter` for resource IDs
    /// outside of [0x0001, 0x7FFF] and for origin filters whose resource ID is not 0,
    /// without the transport being involved.
    #[test_case(None, 0x0000; "for resource ID 0")]
    #[test_case(None, 0x8000; "for resource ID out of range")]
    #[test_case(Some(UUri::try_from_parts("*", 0xFFFF, 0xFF, 0x0001).unwrap()), 0x4A10; "for source filter with invalid resource ID")]
    #[tokio::test]
    async fn test_register_endpoint_fails(origin_filter: Option<UUri>, resource_id: u16) {
        // GIVEN an RPC server whose transport must not be asked to (un-)register any listener
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        transport.expect_do_register_listener().never();
        transport.expect_do_unregister_listener().never();

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        // WHEN registering an endpoint using invalid parameters
        let register_result = rpc_server
            .register_endpoint(origin_filter.as_ref(), resource_id, request_handler.clone())
            .await;
        // THEN registration fails and the endpoint is not known to the server
        assert!(register_result.is_err_and(|e| matches!(e, RegistrationError::InvalidFilter(_v))));
        assert!(!rpc_server.contains_endpoint(resource_id).await);

        // AND unregistering using the same parameters fails as well
        let unregister_result = rpc_server
            .unregister_endpoint(origin_filter.as_ref(), resource_id, request_handler)
            .await;
        assert!(unregister_result.is_err_and(|e| matches!(e, RegistrationError::InvalidFilter(_v))));
    }

    /// Verifies that a second handler cannot be registered for a resource ID that already
    /// has an endpoint, and that the existing endpoint stays registered.
    #[tokio::test]
    async fn test_register_endpoint_fails_for_duplicate_endpoint() {
        // GIVEN an RPC server whose transport accepts exactly one listener registration
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        transport
            .expect_do_register_listener()
            .once()
            .return_const(Ok(()));

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        // WHEN registering an endpoint for the same resource ID twice
        assert!(rpc_server
            .register_endpoint(None, 0x5000, request_handler.clone())
            .await
            .is_ok());
        let result = rpc_server
            .register_endpoint(None, 0x5000, request_handler)
            .await;

        // THEN the second attempt fails and the first endpoint is still registered
        assert!(result.is_err_and(|e| matches!(e, RegistrationError::MaxListenersExceeded)));
        assert!(rpc_server.contains_endpoint(0x5000).await);
    }

    /// Verifies that unregistering an endpoint that has never been registered fails with
    /// `NoSuchListener`, without the transport being involved.
    #[tokio::test]
    async fn test_unregister_endpoint_fails_for_non_existing_endpoint() {
        // GIVEN an RPC server without any endpoints
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        transport.expect_do_unregister_listener().never();

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        // WHEN unregistering an endpoint that is not known to the server
        assert!(!rpc_server.contains_endpoint(0x5000).await);
        let result = rpc_server
            .unregister_endpoint(None, 0x5000, request_handler)
            .await;

        // THEN unregistration fails
        assert!(result.is_err_and(|e| matches!(e, RegistrationError::NoSuchListener)));
    }

    /// Verifies that a request that fails validation (here: missing TTL) but contains an ID
    /// and a reply-to address is answered with an `INVALID_ARGUMENT` error response, without
    /// the request handler being invoked.
    #[tokio::test]
    async fn test_request_listener_returns_response_for_invalid_request() {
        // GIVEN a request listener whose handler must not be invoked
        let mut request_handler = MockRequestHandler::new();
        let mut transport = MockTransport::new();
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        let message_id = UUID::build();
        let request_id = message_id.clone();

        request_handler.expect_handle_request().never();
        // and a transport that expects exactly one INVALID_ARGUMENT response for the request
        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                if !response_message.is_response() {
                    return false;
                }
                if response_message.request_id_unchecked() != &request_id {
                    return false;
                }
                let error: UStatus = response_message.extract_protobuf().unwrap();
                error.get_code() == UCode::INVALID_ARGUMENT
                    && response_message.commstatus_unchecked() == error.get_code()
            })
            .returning(move |_msg| {
                notify_clone.notify_one();
                Ok(())
            });

        // WHEN receiving a request message without a TTL
        let invalid_request_attributes = UAttributes {
            type_: UMessageType::UMESSAGE_TYPE_REQUEST.into(),
            sink: UUri::try_from("up://localhost/A200/1/7000").ok().into(),
            source: UUri::try_from("up://localhost/A100/1/0").ok().into(),
            id: Some(message_id.clone()).into(),
            priority: UPriority::UPRIORITY_CS5.into(),
            ..Default::default()
        };
        assert!(
            UAttributesValidators::Request
                .validator()
                .validate(&invalid_request_attributes)
                .is_err(),
            "request message attributes are supposed to be invalid (no TTL)"
        );
        let invalid_request_message = UMessage {
            attributes: Some(invalid_request_attributes).into(),
            ..Default::default()
        };

        let request_listener = RequestListener {
            request_handler: Arc::new(request_handler),
            transport: Arc::new(transport),
        };
        request_listener.on_receive(invalid_request_message).await;

        // THEN the error response is sent (signalled by the transport mock) in time
        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
        assert!(result.is_ok());
    }

    /// Verifies that an invalid request that lacks the data needed for a response (here:
    /// missing message ID) is dropped: neither the handler is invoked nor a response is sent.
    #[tokio::test]
    async fn test_request_listener_ignores_invalid_request() {
        // GIVEN a request listener whose handler must not be invoked and whose transport
        // must not send anything
        let mut request_handler = MockRequestHandler::new();
        request_handler.expect_handle_request().never();
        let mut transport = MockTransport::new();
        transport.expect_do_send().never();

        // WHEN receiving a request message without a message ID
        let invalid_request_attributes = UAttributes {
            type_: UMessageType::UMESSAGE_TYPE_REQUEST.into(),
            sink: UUri::try_from("up://localhost/A200/1/7000").ok().into(),
            source: UUri::try_from("up://localhost/A100/1/0").ok().into(),
            ttl: Some(5_000),
            id: None.into(),
            priority: UPriority::UPRIORITY_CS5.into(),
            ..Default::default()
        };
        assert!(
            UAttributesValidators::Request
                .validator()
                .validate(&invalid_request_attributes)
                .is_err(),
            "request message attributes are supposed to be invalid (no ID)"
        );
        let invalid_request_message = UMessage {
            attributes: Some(invalid_request_attributes).into(),
            ..Default::default()
        };

        let request_listener = RequestListener {
            request_handler: Arc::new(request_handler),
            transport: Arc::new(transport),
        };
        request_listener.on_receive(invalid_request_message).await;

        // THEN nothing happens; the `never()` expectations are presumably verified when the
        // mocks are dropped at the end of the test
    }

    /// Verifies that a valid request is passed to the handler (with resource ID, attributes
    /// and payload) and that the handler's result is sent back as a successful response.
    #[tokio::test]
    async fn test_request_listener_invokes_operation_successfully() {
        // GIVEN a request listener
        let mut request_handler = MockRequestHandler::new();
        let mut transport = MockTransport::new();
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        let request_payload = StringValue {
            value: "Hello".to_string(),
            ..Default::default()
        };
        let message_id = UUID::build();
        let message_id_clone = message_id.clone();
        let message_source = UUri::try_from("up://localhost/A100/1/0").unwrap();
        let message_source_clone = message_source.clone();

        // whose handler expects the request's payload, resource ID and source
        // and replies with "Hello World"
        request_handler
            .expect_handle_request()
            .once()
            .withf(move |resource_id, message_attributes, request_payload| {
                if let Some(pl) = request_payload {
                    let message_source = message_attributes.source.as_ref().unwrap();
                    let msg: StringValue = pl.extract_protobuf().unwrap();
                    msg.value == *"Hello"
                        && *resource_id == 0x7000_u16
                        && *message_source == message_source_clone
                } else {
                    false
                }
            })
            .returning(|_resource_id, _message_attributes, _request_payload| {
                let response_payload = UPayload::try_from_protobuf(StringValue {
                    value: "Hello World".to_string(),
                    ..Default::default()
                })
                .unwrap();
                Ok(Some(response_payload))
            });
        // and whose transport expects a successful response carrying the handler's result
        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                let msg: StringValue = response_message.extract_protobuf().unwrap();
                msg.value == *"Hello World"
                    && response_message.is_response()
                    && response_message
                        .commstatus()
                        .is_none_or(|code| code == UCode::OK)
                    && response_message.request_id_unchecked() == &message_id_clone
            })
            .returning(move |_msg| {
                notify_clone.notify_one();
                Ok(())
            });
        // WHEN receiving a valid request for resource ID 0x7000
        let request_message = UMessageBuilder::request(
            UUri::try_from("up://localhost/A200/1/7000").unwrap(),
            message_source,
            5_000,
        )
        .with_message_id(message_id)
        .build_with_protobuf_payload(&request_payload)
        .unwrap();

        let request_listener = RequestListener {
            request_handler: Arc::new(request_handler),
            transport: Arc::new(transport),
        };
        request_listener.on_receive(request_message).await;
        // THEN the response is sent (signalled by the transport mock) in time
        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
        assert!(result.is_ok());
    }

    /// Verifies that an error returned by the handler is sent back as an error response
    /// whose payload and commstatus carry the corresponding status code.
    #[tokio::test]
    async fn test_request_listener_invokes_operation_erroneously() {
        // GIVEN a request listener whose handler fails with NotFound
        let mut request_handler = MockRequestHandler::new();
        let mut transport = MockTransport::new();
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        let message_id = UUID::build();
        let message_id_clone = message_id.clone();

        request_handler
            .expect_handle_request()
            .once()
            .withf(|resource_id, _message_attributes, _request_payload| *resource_id == 0x7000_u16)
            .returning(|_resource_id, _message_attributes, _request_payload| {
                Err(ServiceInvocationError::NotFound(
                    "no such object".to_string(),
                ))
            });
        // and whose transport expects a NOT_FOUND error response for the request
        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                let error: UStatus = response_message.extract_protobuf().unwrap();
                error.get_code() == UCode::NOT_FOUND
                    && response_message.is_response()
                    && response_message.commstatus_unchecked() == error.get_code()
                    && response_message.request_id_unchecked() == &message_id_clone
            })
            .returning(move |_msg| {
                notify_clone.notify_one();
                Ok(())
            });
        // WHEN receiving a valid request for resource ID 0x7000
        let request_message = UMessageBuilder::request(
            UUri::try_from("up://localhost/A200/1/7000").unwrap(),
            UUri::try_from("up://localhost/A100/1/0").unwrap(),
            5_000,
        )
        .with_message_id(message_id)
        .build()
        .unwrap();

        let request_listener = RequestListener {
            request_handler: Arc::new(request_handler),
            transport: Arc::new(transport),
        };
        request_listener.on_receive(request_message).await;
        // THEN the error response is sent (signalled by the transport mock) in time
        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
        assert!(result.is_ok());
    }

    /// Verifies that a handler which does not complete within the request's TTL (100 ms)
    /// results in a `DEADLINE_EXCEEDED` error response.
    #[tokio::test]
    async fn test_request_listener_times_out() {
        // A handler that takes far longer (2 s) than the request's TTL allows.
        struct NonRespondingHandler;
        #[async_trait]
        impl RequestHandler for NonRespondingHandler {
            async fn handle_request(
                &self,
                resource_id: u16,
                _message_attributes: &UAttributes,
                _request_payload: Option<UPayload>,
            ) -> Result<Option<UPayload>, ServiceInvocationError> {
                assert_eq!(resource_id, 0x7000);
                tokio::time::sleep(Duration::from_millis(2000)).await;
                Ok(None)
            }
        }

        // GIVEN a request listener using the non-responding handler
        let request_handler = NonRespondingHandler {};
        let mut transport = MockTransport::new();
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        let message_id = UUID::build();
        let message_id_clone = message_id.clone();

        // and a transport that expects a DEADLINE_EXCEEDED error response for the request
        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                let error: UStatus = response_message.extract_protobuf().unwrap();
                error.get_code() == UCode::DEADLINE_EXCEEDED
                    && response_message.is_response()
                    && response_message.commstatus_unchecked() == error.get_code()
                    && response_message.request_id_unchecked() == &message_id_clone
            })
            .returning(move |_msg| {
                notify_clone.notify_one();
                Ok(())
            });
        // WHEN receiving a request with a TTL of 100 ms
        let request_message = UMessageBuilder::request(
            UUri::try_from("up://localhost/A200/1/7000").unwrap(),
            UUri::try_from("up://localhost/A100/1/0").unwrap(),
            100,
        )
        .with_message_id(message_id)
        .build()
        .expect("should have been able to create RPC Request message");

        let request_listener = RequestListener {
            request_handler: Arc::new(request_handler),
            transport: Arc::new(transport),
        };
        request_listener.on_receive(request_message).await;
        // THEN the error response is sent (signalled by the transport mock) in time
        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
        assert!(result.is_ok());
    }
}