1use std::collections::hash_map::Entry;
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex};
19use std::time::Duration;
20
21use async_trait::async_trait;
22use tokio::sync::oneshot::{Receiver, Sender};
23use tokio::time::timeout;
24use tracing::{debug, info};
25
26use crate::{
27 LocalUriProvider, UCode, UListener, UMessage, UMessageBuilder, UMessageType, UStatus,
28 UTransport, UUri, UUID,
29};
30
31use super::{
32 build_message, CallOptions, RegistrationError, RpcClient, ServiceInvocationError, UPayload,
33};
34
35fn handle_response_message(response: UMessage) -> Result<Option<UPayload>, ServiceInvocationError> {
36 let Some(attribs) = response.attributes.as_ref() else {
37 return Err(ServiceInvocationError::InvalidArgument(
38 "response message does not contain attributes".to_string(),
39 ));
40 };
41
42 match attribs.commstatus.map(|v| v.enum_value_or_default()) {
43 Some(UCode::OK) | None => {
44 response.payload.map_or(Ok(None), |payload| {
46 Ok(Some(UPayload::new(
47 payload,
48 attribs.payload_format.enum_value_or_default(),
49 )))
50 })
51 }
52 Some(code) => {
53 let status = response.extract_protobuf().unwrap_or_else(|_e| {
55 UStatus::fail_with_code(code, "failed to invoke service operation")
56 });
57 Err(ServiceInvocationError::from(status))
58 }
59 }
60}
61
62struct ResponseListener {
63 pending_requests: Mutex<HashMap<UUID, Sender<UMessage>>>,
65}
66
67impl ResponseListener {
68 fn try_add_pending_request(
69 &self,
70 reqid: UUID,
71 ) -> Result<Receiver<UMessage>, ServiceInvocationError> {
72 let Ok(mut pending_requests) = self.pending_requests.lock() else {
73 return Err(ServiceInvocationError::Internal(
74 "failed to add response handler".to_string(),
75 ));
76 };
77
78 if let Entry::Vacant(entry) = pending_requests.entry(reqid) {
79 let (tx, rx) = tokio::sync::oneshot::channel();
80 entry.insert(tx);
81 Ok(rx)
82 } else {
83 Err(ServiceInvocationError::AlreadyExists(
84 "RPC request with given ID already pending".to_string(),
85 ))
86 }
87 }
88
89 fn handle_response(&self, reqid: &UUID, response_message: UMessage) {
90 let Ok(mut pending_requests) = self.pending_requests.lock() else {
91 info!(
92 request_id = reqid.to_hyphenated_string(),
93 "failed to process response message, cannot acquire lock for pending requests map"
94 );
95 return;
96 };
97 if let Some(sender) = pending_requests.remove(reqid) {
98 if let Err(_e) = sender.send(response_message) {
99 debug!(
101 request_id = reqid.to_hyphenated_string(),
102 "failed to deliver RPC Response message, channel already closed"
103 );
104 } else {
105 debug!(
106 request_id = reqid.to_hyphenated_string(),
107 "successfully delivered RPC Response message"
108 )
109 }
110 } else {
111 debug!(
113 request_id = reqid.to_hyphenated_string(),
114 "ignoring (duplicate?) RPC Response message with unknown request ID"
115 );
116 }
117 }
118
119 fn remove_pending_request(&self, reqid: &UUID) -> Option<Sender<UMessage>> {
120 self.pending_requests
121 .lock()
122 .map_or(None, |mut pending_requests| pending_requests.remove(reqid))
123 }
124
125 #[cfg(test)]
126 fn contains(&self, reqid: &UUID) -> bool {
127 self.pending_requests
128 .lock()
129 .is_ok_and(|pending_requests| pending_requests.contains_key(reqid))
130 }
131}
132
133#[async_trait]
134impl UListener for ResponseListener {
135 async fn on_receive(&self, msg: UMessage) {
136 let message_type = msg
137 .attributes
138 .get_or_default()
139 .type_
140 .enum_value_or_default();
141 if message_type != UMessageType::UMESSAGE_TYPE_RESPONSE {
142 debug!(
143 message_type = message_type.to_cloudevent_type(),
144 "service provider replied with message that is not an RPC Response"
145 );
146 return;
147 }
148
149 if let Some(reqid) = msg
150 .attributes
151 .as_ref()
152 .and_then(|attribs| attribs.reqid.clone().into_option())
153 {
154 self.handle_response(&reqid, msg);
155 } else {
156 debug!("ignoring malformed response message not containing request ID");
157 }
158 }
159}
160
/// An [`RpcClient`] that keeps track of its pending requests in memory.
///
/// When it is created, the client registers a single response listener with the transport.
/// It uses that listener to match incoming RPC Response messages to outstanding requests
/// by request ID.
pub struct InMemoryRpcClient {
    /// Transport used for sending RPC Request messages and for registering the response
    /// listener.
    transport: Arc<dyn UTransport>,
    /// Provides the local source URI. It is used when building RPC Request messages and as
    /// the sink filter when registering the response listener.
    uri_provider: Arc<dyn LocalUriProvider>,
    /// Listener that delivers RPC Response messages to the pending requests.
    response_listener: Arc<ResponseListener>,
}
179
180impl InMemoryRpcClient {
181 pub async fn new(
193 transport: Arc<dyn UTransport>,
194 uri_provider: Arc<dyn LocalUriProvider>,
195 ) -> Result<Self, RegistrationError> {
196 let response_listener = Arc::new(ResponseListener {
197 pending_requests: Mutex::new(HashMap::new()),
198 });
199 transport
200 .register_listener(
201 &UUri::any(),
202 Some(&uri_provider.get_source_uri()),
203 response_listener.clone(),
204 )
205 .await
206 .map_err(RegistrationError::from)?;
207
208 Ok(InMemoryRpcClient {
209 transport,
210 uri_provider,
211 response_listener,
212 })
213 }
214
215 #[cfg(test)]
216 fn contains_pending_request(&self, reqid: &UUID) -> bool {
217 self.response_listener.contains(reqid)
218 }
219}
220
221#[async_trait]
222impl RpcClient for InMemoryRpcClient {
223 async fn invoke_method(
224 &self,
225 method: UUri,
226 call_options: CallOptions,
227 payload: Option<UPayload>,
228 ) -> Result<Option<UPayload>, ServiceInvocationError> {
229 let message_id = call_options.message_id().unwrap_or_else(UUID::build);
230
231 let mut builder = UMessageBuilder::request(
232 method.clone(),
233 self.uri_provider.get_source_uri(),
234 call_options.ttl(),
235 );
236 builder.with_message_id(message_id.clone());
237 if let Some(token) = call_options.token() {
238 builder.with_token(token.to_owned());
239 }
240 if let Some(priority) = call_options.priority() {
241 builder.with_priority(priority);
242 }
243 let rpc_request_message = build_message(&mut builder, payload)
244 .map_err(|e| ServiceInvocationError::InvalidArgument(e.to_string()))?;
245
246 let receiver = self
247 .response_listener
248 .try_add_pending_request(message_id.clone())?;
249 self.transport
250 .send(rpc_request_message)
251 .await
252 .inspect_err(|_e| {
253 self.response_listener.remove_pending_request(&message_id);
254 })?;
255 debug!(
256 request_id = message_id.to_hyphenated_string(),
257 ttl = call_options.ttl(),
258 "successfully sent RPC Request message"
259 );
260
261 match timeout(Duration::from_millis(call_options.ttl() as u64), receiver).await {
262 Err(_) => {
263 debug!(
264 request_id = message_id.to_hyphenated_string(),
265 ttl = call_options.ttl(),
266 "invocation of service operation has timed out"
267 );
268 self.response_listener.remove_pending_request(&message_id);
269 Err(ServiceInvocationError::DeadlineExceeded)
270 }
271 Ok(result) => match result {
272 Ok(response_message) => handle_response_message(response_message),
273 Err(_e) => {
274 debug!(
275 request_id = message_id.to_hyphenated_string(),
276 "response listener failed to forward response message"
277 );
278 self.response_listener.remove_pending_request(&message_id);
279 Err(ServiceInvocationError::Internal(
280 "error receiving response message".to_string(),
281 ))
282 }
283 },
284 }
285 }
286}
287
#[cfg(test)]
mod tests {

    use super::*;

    use protobuf::well_known_types::wrappers::StringValue;
    use tokio::{join, sync::Notify};

    use crate::{utransport::MockTransport, StaticUriProvider, UMessageBuilder, UPriority, UUri};

    /// Creates the static URI provider of the local uEntity used by the client under test.
    fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
        Arc::new(StaticUriProvider::new("", 0x0005, 0x02))
    }

    /// URI of the service operation that the tests invoke.
    fn service_method_uri() -> UUri {
        UUri {
            ue_id: 0x0001,
            ue_version_major: 0x01,
            resource_id: 0x1000,
            ..Default::default()
        }
    }

    /// Creating the client fails if the transport rejects the response listener.
    /// A `RESOURCE_EXHAUSTED` status is expected to surface as `MaxListenersExceeded`.
    #[tokio::test]
    async fn test_registration_of_response_listener_fails() {
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| {
                Err(UStatus::fail_with_code(
                    UCode::RESOURCE_EXHAUSTED,
                    "max number of listeners exceeded",
                ))
            });

        let creation_attempt =
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider()).await;

        assert!(
            creation_attempt.is_err_and(|e| matches!(e, RegistrationError::MaxListenersExceeded))
        );
    }

    /// A failure to send the request is reported to the caller, and the request's response
    /// handler is removed again.
    #[tokio::test]
    async fn test_invoke_method_fails_with_transport_error() {
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        mock_transport
            .expect_do_send()
            .returning(|_request_message| {
                Err(UStatus::fail_with_code(
                    UCode::UNAVAILABLE,
                    "transport not available",
                ))
            });
        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::Unavailable(_msg))));
        // the handler registered before sending must not linger
        assert!(!client.contains_pending_request(&message_id));
    }

    /// Happy path covering several behaviors:
    /// - the request carries the message ID, priority, TTL and token from the call options;
    /// - the matching response payload is returned to the caller;
    /// - the handler is removed afterwards;
    /// - a repeated delivery of the same response is ignored.
    #[tokio::test]
    async fn test_invoke_method_succeeds() {
        let message_id = UUID::build();
        let call_options = CallOptions::for_rpc_request(
            5_000,
            Some(message_id.clone()),
            Some("my_token".to_string()),
            Some(crate::UPriority::UPRIORITY_CS6),
        );

        // used to get hold of the listener that the client registers with the transport
        let (captured_listener_tx, captured_listener_rx) = tokio::sync::oneshot::channel();
        // signalled by the mock transport once the request has been sent
        let request_sent = Arc::new(Notify::new());
        let request_sent_clone = request_sent.clone();

        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .return_once(move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            });
        let expected_message_id = message_id.clone();
        mock_transport
            .expect_do_send()
            .once()
            // the request must reflect the call options
            .withf(move |request_message| {
                request_message.id_unchecked() == &expected_message_id
                    && request_message.priority_unchecked() == UPriority::UPRIORITY_CS6
                    && request_message.ttl_unchecked() == 5_000
                    && request_message.token() == Some(&String::from("my_token"))
            })
            .returning(move |_request_message| {
                request_sent_clone.notify_one();
                Ok(())
            });

        let uri_provider = new_uri_provider();
        let rpc_client = Arc::new(
            InMemoryRpcClient::new(Arc::new(mock_transport), uri_provider.clone())
                .await
                .unwrap(),
        );
        let client: Arc<dyn RpcClient> = rpc_client.clone();

        // invoke the operation in the background, it waits for the response
        let response_handle = tokio::spawn(async move {
            let request_payload = StringValue {
                value: "World".to_string(),
                ..Default::default()
            };
            client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    call_options,
                    request_payload,
                )
                .await
        });

        // the response that the (simulated) service sends back for the request
        let response_payload = StringValue {
            value: "Hello World".to_string(),
            ..Default::default()
        };
        let response_message = UMessageBuilder::response(
            uri_provider.get_source_uri(),
            message_id.clone(),
            service_method_uri(),
        )
        .build_with_protobuf_payload(&response_payload)
        .unwrap();

        // wait until the listener has been captured and the request has been sent
        let (response_listener_result, _) = join!(captured_listener_rx, request_sent.notified());
        let response_listener = response_listener_result.unwrap();

        // deliver the response to the client's listener, as the transport would
        let cloned_response_message = response_message.clone();
        let cloned_response_listener = response_listener.clone();
        tokio::spawn(async move {
            cloned_response_listener
                .on_receive(cloned_response_message)
                .await
        });

        let response = response_handle.await.unwrap();
        assert!(response.is_ok_and(|payload| payload.value == *"Hello World"));
        assert!(!rpc_client.contains_pending_request(&message_id));

        // delivering the same response again must not re-create or disturb any state
        response_listener.on_receive(response_message).await;
        assert!(!rpc_client.contains_pending_request(&message_id));
    }

    /// A second invocation that reuses the message ID of a still pending request is
    /// rejected with `AlreadyExists`, and the first request stays pending.
    #[tokio::test]
    async fn test_invoke_method_fails_on_repeated_invocation() {
        let message_id = UUID::build();
        let first_request_sent = Arc::new(Notify::new());
        let first_request_sent_clone = first_request_sent.clone();

        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .return_const(Ok(()));
        let expected_message_id = message_id.clone();
        // only the first request may reach the transport
        mock_transport
            .expect_do_send()
            .once()
            .withf(move |request_message| request_message.id_unchecked() == &expected_message_id)
            .returning(move |_request_message| {
                first_request_sent_clone.notify_one();
                Ok(())
            });

        let in_memory_rpc_client = Arc::new(
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
                .await
                .unwrap(),
        );
        let rpc_client: Arc<dyn RpcClient> = in_memory_rpc_client.clone();

        // both invocations use the same call options and thus the same message ID
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let cloned_call_options = call_options.clone();
        let cloned_rpc_client = rpc_client.clone();

        // first invocation, never gets a response within the test
        tokio::spawn(async move {
            let request_payload = StringValue {
                value: "World".to_string(),
                ..Default::default()
            };
            cloned_rpc_client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    cloned_call_options,
                    request_payload,
                )
                .await
        });

        // make sure the first request is pending before issuing the second one
        first_request_sent.notified().await;

        let request_payload = StringValue {
            value: "World".to_string(),
            ..Default::default()
        };
        let second_request_handle = tokio::spawn(async move {
            rpc_client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    call_options,
                    request_payload,
                )
                .await
        });

        let response = second_request_handle.await.unwrap();
        assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::AlreadyExists(_))));
        // the rejected second invocation must not remove the first request's handler
        assert!(in_memory_rpc_client.contains_pending_request(&message_id));
    }

    /// A response with a non-OK commstatus is turned into the corresponding error.
    /// The handler is removed afterwards.
    #[tokio::test]
    async fn test_invoke_method_fails_with_remote_error() {
        let (captured_listener_tx, captured_listener_rx) = std::sync::mpsc::channel();

        let mut mock_transport = MockTransport::default();
        mock_transport.expect_do_register_listener().returning(
            move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            },
        );
        // the mock transport immediately answers each request with a NOT_FOUND response,
        // delivered to the listener that was captured during registration
        mock_transport
            .expect_do_send()
            .returning(move |request_message| {
                let error = UStatus::fail_with_code(UCode::NOT_FOUND, "no such object");
                let response_message = UMessageBuilder::response_for_request(
                    request_message.attributes.as_ref().unwrap(),
                )
                .with_comm_status(UCode::NOT_FOUND)
                .build_with_protobuf_payload(&error)
                .unwrap();
                let captured_listener = captured_listener_rx.recv().unwrap().to_owned();
                tokio::spawn(async move { captured_listener.on_receive(response_message).await });
                Ok(())
            });

        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::NotFound(_msg)) }));
        assert!(!client.contains_pending_request(&message_id));
    }

    /// Without a response, the invocation fails with `DeadlineExceeded` once the TTL
    /// (20 ms here) has elapsed. The handler is removed afterwards.
    #[tokio::test]
    async fn test_invoke_method_times_out() {
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        mock_transport
            .expect_do_send()
            .returning(|_request_message| Ok(()));

        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        let message_id = UUID::build();
        let call_options = CallOptions::for_rpc_request(20, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::DeadlineExceeded) }));
        assert!(!client.contains_pending_request(&message_id));
    }

    /// A response message without attributes is rejected as an invalid argument.
    #[test]
    fn test_handle_response_message_fails_for_missing_attributes() {
        let response_msg = UMessage {
            ..Default::default()
        };
        let result = handle_response_message(response_msg);
        assert!(result.is_err_and(|e| matches!(e, ServiceInvocationError::InvalidArgument(_))));
    }
}