1use std::collections::hash_map::Entry;
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex};
19use std::time::Duration;
20
21use async_trait::async_trait;
22use tokio::sync::oneshot::{Receiver, Sender};
23use tokio::time::timeout;
24use tracing::{debug, info};
25
26use crate::{
27 LocalUriProvider, UCode, UListener, UMessage, UMessageBuilder, UMessageType, UStatus,
28 UTransport, UUri, UUID,
29};
30
31use super::{
32 build_message, CallOptions, RegistrationError, RpcClient, ServiceInvocationError, UPayload,
33};
34
35fn handle_response_message(response: UMessage) -> Result<Option<UPayload>, ServiceInvocationError> {
36 let Some(attribs) = response.attributes.as_ref() else {
37 return Err(ServiceInvocationError::InvalidArgument(
38 "response message does not contain attributes".to_string(),
39 ));
40 };
41
42 match attribs.commstatus.map(|v| v.enum_value_or_default()) {
43 Some(UCode::OK) | None => {
44 response.payload.map_or(Ok(None), |payload| {
46 Ok(Some(UPayload::new(
47 payload,
48 attribs.payload_format.enum_value_or_default(),
49 )))
50 })
51 }
52 Some(code) => {
53 let status = response.extract_protobuf().unwrap_or_else(|_e| {
55 UStatus::fail_with_code(code, "failed to invoke service operation")
56 });
57 Err(ServiceInvocationError::from(status))
58 }
59 }
60}
61
62struct ResponseListener {
63 pending_requests: Mutex<HashMap<UUID, Sender<UMessage>>>,
65}
66
67impl ResponseListener {
68 fn try_add_pending_request(
69 &self,
70 reqid: UUID,
71 ) -> Result<Receiver<UMessage>, ServiceInvocationError> {
72 let Ok(mut pending_requests) = self.pending_requests.lock() else {
73 return Err(ServiceInvocationError::Internal(
74 "failed to add response handler".to_string(),
75 ));
76 };
77
78 if let Entry::Vacant(entry) = pending_requests.entry(reqid) {
79 let (tx, rx) = tokio::sync::oneshot::channel();
80 entry.insert(tx);
81 Ok(rx)
82 } else {
83 Err(ServiceInvocationError::AlreadyExists(
84 "RPC request with given ID already pending".to_string(),
85 ))
86 }
87 }
88
89 fn handle_response(&self, reqid: &UUID, response_message: UMessage) {
90 let Ok(mut pending_requests) = self.pending_requests.lock() else {
91 info!(
92 request_id = reqid.to_hyphenated_string(),
93 "failed to process response message, cannot acquire lock for pending requests map"
94 );
95 return;
96 };
97 if let Some(sender) = pending_requests.remove(reqid) {
98 if let Err(_e) = sender.send(response_message) {
99 debug!(
101 request_id = reqid.to_hyphenated_string(),
102 "failed to deliver RPC Response message, channel already closed"
103 );
104 } else {
105 debug!(
106 request_id = reqid.to_hyphenated_string(),
107 "successfully delivered RPC Response message"
108 )
109 }
110 } else {
111 debug!(
113 request_id = reqid.to_hyphenated_string(),
114 "ignoring (duplicate?) RPC Response message with unknown request ID"
115 );
116 }
117 }
118
119 fn remove_pending_request(&self, reqid: &UUID) -> Option<Sender<UMessage>> {
120 self.pending_requests
121 .lock()
122 .map_or(None, |mut pending_requests| pending_requests.remove(reqid))
123 }
124
125 #[cfg(test)]
126 fn contains(&self, reqid: &UUID) -> bool {
127 self.pending_requests
128 .lock()
129 .is_ok_and(|pending_requests| pending_requests.contains_key(reqid))
130 }
131}
132
133#[async_trait]
134impl UListener for ResponseListener {
135 async fn on_receive(&self, msg: UMessage) {
136 let message_type = msg
137 .attributes
138 .get_or_default()
139 .type_
140 .enum_value_or_default();
141 if message_type != UMessageType::UMESSAGE_TYPE_RESPONSE {
142 debug!(
143 message_type = message_type.to_cloudevent_type(),
144 "service provider replied with message that is not an RPC Response"
145 );
146 return;
147 }
148
149 if let Some(reqid) = msg
150 .attributes
151 .as_ref()
152 .and_then(|attribs| attribs.reqid.clone().into_option())
153 {
154 self.handle_response(&reqid, msg);
155 } else {
156 debug!("ignoring malformed response message not containing request ID");
157 }
158 }
159}
160
161pub struct InMemoryRpcClient {
175 transport: Arc<dyn UTransport>,
176 uri_provider: Arc<dyn LocalUriProvider>,
177 response_listener: Arc<ResponseListener>,
178}
179
180impl InMemoryRpcClient {
181 pub async fn new(
193 transport: Arc<dyn UTransport>,
194 uri_provider: Arc<dyn LocalUriProvider>,
195 ) -> Result<Self, RegistrationError> {
196 let response_listener = Arc::new(ResponseListener {
197 pending_requests: Mutex::new(HashMap::new()),
198 });
199 transport
200 .register_listener(
201 &UUri::any(),
202 Some(&uri_provider.get_source_uri()),
203 response_listener.clone(),
204 )
205 .await
206 .map_err(RegistrationError::from)?;
207
208 Ok(InMemoryRpcClient {
209 transport,
210 uri_provider,
211 response_listener,
212 })
213 }
214
215 #[cfg(test)]
216 fn contains_pending_request(&self, reqid: &UUID) -> bool {
217 self.response_listener.contains(reqid)
218 }
219}
220
221#[async_trait]
222impl RpcClient for InMemoryRpcClient {
223 async fn invoke_method(
224 &self,
225 method: UUri,
226 call_options: CallOptions,
227 payload: Option<UPayload>,
228 ) -> Result<Option<UPayload>, ServiceInvocationError> {
229 let message_id = call_options.message_id().unwrap_or_else(UUID::build);
230
231 let mut builder = UMessageBuilder::request(
232 method.clone(),
233 self.uri_provider.get_source_uri(),
234 call_options.ttl(),
235 );
236 builder.with_message_id(message_id.clone());
237 if let Some(token) = call_options.token() {
238 builder.with_token(token.to_owned());
239 }
240 if let Some(priority) = call_options.priority() {
241 builder.with_priority(priority);
242 }
243 let rpc_request_message = build_message(&mut builder, payload)
244 .map_err(|e| ServiceInvocationError::InvalidArgument(e.to_string()))?;
245
246 let receiver = self
247 .response_listener
248 .try_add_pending_request(message_id.clone())?;
249 self.transport
250 .send(rpc_request_message)
251 .await
252 .map_err(|e| {
253 self.response_listener.remove_pending_request(&message_id);
254 e
255 })?;
256 debug!(
257 request_id = message_id.to_hyphenated_string(),
258 ttl = call_options.ttl(),
259 "successfully sent RPC Request message"
260 );
261
262 match timeout(Duration::from_millis(call_options.ttl() as u64), receiver).await {
263 Err(_) => {
264 debug!(
265 request_id = message_id.to_hyphenated_string(),
266 ttl = call_options.ttl(),
267 "invocation of service operation has timed out"
268 );
269 self.response_listener.remove_pending_request(&message_id);
270 Err(ServiceInvocationError::DeadlineExceeded)
271 }
272 Ok(result) => match result {
273 Ok(response_message) => handle_response_message(response_message),
274 Err(_e) => {
275 debug!(
276 request_id = message_id.to_hyphenated_string(),
277 "response listener failed to forward response message"
278 );
279 self.response_listener.remove_pending_request(&message_id);
280 Err(ServiceInvocationError::Internal(
281 "error receiving response message".to_string(),
282 ))
283 }
284 },
285 }
286 }
287}
288
#[cfg(test)]
mod tests {

    use super::*;

    use protobuf::{well_known_types::wrappers::StringValue, Enum};
    use tokio::{join, sync::Notify};

    use crate::{utransport::MockTransport, StaticUriProvider, UMessageBuilder, UPriority, UUri};

    /// Creates the static URI provider of the local uEntity used by all tests.
    fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
        Arc::new(StaticUriProvider::new("", 0x0005, 0x02))
    }

    /// URI of the (remote) service operation invoked by the tests.
    fn service_method_uri() -> UUri {
        UUri {
            ue_id: 0x0001,
            ue_version_major: 0x01,
            resource_id: 0x1000,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_registration_of_response_listener_fails() {
        // GIVEN a transport that fails to register the client's response listener
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| {
                Err(UStatus::fail_with_code(
                    UCode::RESOURCE_EXHAUSTED,
                    "max number of listeners exceeded",
                ))
            });

        // WHEN trying to create an RPC client on top of that transport
        let creation_attempt =
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider()).await;

        // THEN creation fails and the RESOURCE_EXHAUSTED status surfaces as
        // RegistrationError::MaxListenersExceeded
        assert!(
            creation_attempt.is_err_and(|e| matches!(e, RegistrationError::MaxListenersExceeded))
        );
    }

    #[tokio::test]
    async fn test_invoke_method_fails_with_transport_error() {
        // GIVEN an RPC client whose transport fails to send any message
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        mock_transport
            .expect_do_send()
            .returning(|_request_message| {
                Err(UStatus::fail_with_code(
                    UCode::UNAVAILABLE,
                    "transport not available",
                ))
            });
        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        // WHEN invoking a remote service operation
        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        // THEN the transport's UNAVAILABLE status is returned as error
        assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::Unavailable(_msg))));
        // AND the request is no longer tracked as pending
        assert!(!client.contains_pending_request(&message_id));
    }

    #[tokio::test]
    async fn test_invoke_method_succeeds() {
        let message_id = UUID::build();
        let call_options = CallOptions::for_rpc_request(
            5_000,
            Some(message_id.clone()),
            Some("my_token".to_string()),
            Some(crate::UPriority::UPRIORITY_CS6),
        );

        // The registered response listener is captured so that the test can later feed the
        // RPC Response message into it directly.
        let (captured_listener_tx, captured_listener_rx) = tokio::sync::oneshot::channel();
        // Signals that the mock transport has sent the RPC Request message.
        let request_sent = Arc::new(Notify::new());
        let request_sent_clone = request_sent.clone();

        // GIVEN a transport that accepts exactly one listener registration and exactly one
        // request whose attributes reflect the call options
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .return_once(move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            });
        let expected_message_id = message_id.clone();
        mock_transport
            .expect_do_send()
            .once()
            .withf(move |request_message| {
                request_message.attributes.as_ref().is_some_and(|attribs| {
                    attribs.id.as_ref() == Some(&expected_message_id)
                        && attribs.priority.value() == UPriority::UPRIORITY_CS6.value()
                        && attribs.ttl == Some(5_000)
                        && attribs.token == Some("my_token".to_string())
                })
            })
            .returning(move |_request_message| {
                request_sent_clone.notify_one();
                Ok(())
            });

        let uri_provider = new_uri_provider();
        let rpc_client = Arc::new(
            InMemoryRpcClient::new(Arc::new(mock_transport), uri_provider.clone())
                .await
                .unwrap(),
        );
        let client: Arc<dyn RpcClient> = rpc_client.clone();

        // WHEN invoking a remote service operation (in the background, because the call only
        // completes once the response has been fed into the listener)
        let response_handle = tokio::spawn(async move {
            let request_payload = StringValue {
                value: "World".to_string(),
                ..Default::default()
            };
            client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    call_options,
                    request_payload,
                )
                .await
        });

        // AND the service provider replies with a matching RPC Response message
        let response_payload = StringValue {
            value: "Hello World".to_string(),
            ..Default::default()
        };
        let response_message = UMessageBuilder::response(
            uri_provider.get_source_uri(),
            message_id.clone(),
            service_method_uri(),
        )
        .build_with_protobuf_payload(&response_payload)
        .unwrap();

        // wait until the listener has been captured and the request has been sent
        let (response_listener_result, _) = join!(captured_listener_rx, request_sent.notified());
        let response_listener = response_listener_result.unwrap();

        let cloned_response_message = response_message.clone();
        let cloned_response_listener = response_listener.clone();
        tokio::spawn(async move {
            cloned_response_listener
                .on_receive(cloned_response_message)
                .await
        });

        // THEN the response payload is returned to the caller
        let response = response_handle.await.unwrap();
        assert!(response.is_ok_and(|payload| payload.value == *"Hello World"));
        // AND the request is no longer tracked as pending
        assert!(!rpc_client.contains_pending_request(&message_id));

        // AND a duplicate of the response is ignored without re-creating a pending entry
        response_listener.on_receive(response_message).await;
        assert!(!rpc_client.contains_pending_request(&message_id));
    }

    #[tokio::test]
    async fn test_invoke_method_fails_on_repeated_invocation() {
        let message_id = UUID::build();
        // Signals that the first request has been sent, i.e. is pending.
        let first_request_sent = Arc::new(Notify::new());
        let first_request_sent_clone = first_request_sent.clone();

        // GIVEN a transport that expects exactly one request with the given message ID
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .return_const(Ok(()));
        let expected_message_id = message_id.clone();
        mock_transport
            .expect_do_send()
            .once()
            .withf(move |request_message| {
                request_message
                    .attributes
                    .as_ref()
                    .is_some_and(|attribs| attribs.id.as_ref() == Some(&expected_message_id))
            })
            .returning(move |_request_message| {
                first_request_sent_clone.notify_one();
                Ok(())
            });

        let in_memory_rpc_client = Arc::new(
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
                .await
                .unwrap(),
        );
        let rpc_client: Arc<dyn RpcClient> = in_memory_rpc_client.clone();

        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let cloned_call_options = call_options.clone();
        let cloned_rpc_client = rpc_client.clone();

        // AND a first invocation that is still waiting for its response (no response is
        // ever delivered in this test)
        tokio::spawn(async move {
            let request_payload = StringValue {
                value: "World".to_string(),
                ..Default::default()
            };
            cloned_rpc_client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    cloned_call_options,
                    request_payload,
                )
                .await
        });

        first_request_sent.notified().await;

        // WHEN invoking the operation again using the same message ID
        let request_payload = StringValue {
            value: "World".to_string(),
            ..Default::default()
        };
        let second_request_handle = tokio::spawn(async move {
            rpc_client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    call_options,
                    request_payload,
                )
                .await
        });

        // THEN the second invocation fails because the request ID is already in use
        let response = second_request_handle.await.unwrap();
        assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::AlreadyExists(_))));
        // AND the first invocation's request is still pending
        assert!(in_memory_rpc_client.contains_pending_request(&message_id));
    }

    #[tokio::test]
    async fn test_invoke_method_fails_with_remote_error() {
        // A std channel is used because the captured listener is read synchronously from
        // within the mock's `returning` closure.
        let (captured_listener_tx, captured_listener_rx) = std::sync::mpsc::channel();

        // GIVEN a transport that answers each request with a NOT_FOUND response
        let mut mock_transport = MockTransport::default();
        mock_transport.expect_do_register_listener().returning(
            move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            },
        );
        mock_transport
            .expect_do_send()
            .returning(move |request_message| {
                let error = UStatus::fail_with_code(UCode::NOT_FOUND, "no such object");
                let response_message = UMessageBuilder::response_for_request(
                    request_message.attributes.as_ref().unwrap(),
                )
                .with_comm_status(UCode::NOT_FOUND)
                .build_with_protobuf_payload(&error)
                .unwrap();
                let captured_listener = captured_listener_rx.recv().unwrap().to_owned();
                tokio::spawn(async move { captured_listener.on_receive(response_message).await });
                Ok(())
            });

        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        // WHEN invoking a remote service operation
        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        // THEN the service provider's NOT_FOUND status is returned as error
        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::NotFound(_msg)) }));
        // AND the request is no longer tracked as pending
        assert!(!client.contains_pending_request(&message_id));
    }

    #[tokio::test]
    async fn test_invoke_method_times_out() {
        // GIVEN a transport that sends requests successfully but never delivers a response
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        mock_transport
            .expect_do_send()
            .returning(|_request_message| Ok(()));

        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        // WHEN invoking a remote service operation with a short (20 ms) TTL
        let message_id = UUID::build();
        let call_options = CallOptions::for_rpc_request(20, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        // THEN the invocation fails with DeadlineExceeded
        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::DeadlineExceeded) }));
        // AND the request is no longer tracked as pending
        assert!(!client.contains_pending_request(&message_id));
    }

    #[test]
    fn test_handle_response_message_fails_for_missing_attributes() {
        // A response message without attributes is rejected as an invalid argument.
        let response_msg = UMessage {
            ..Default::default()
        };
        let result = handle_response_message(response_msg);
        assert!(result.is_err_and(|e| matches!(e, ServiceInvocationError::InvalidArgument(_))));
    }
}