1use std::collections::hash_map::Entry;
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex};
19use std::time::Duration;
20
21use async_trait::async_trait;
22use tokio::sync::oneshot::{Receiver, Sender};
23use tokio::time::timeout;
24use tracing::{debug, info};
25
26use crate::{
27 LocalUriProvider, UCode, UListener, UMessage, UMessageBuilder, UStatus, UTransport, UUri, UUID,
28};
29
30use super::{
31 build_message, CallOptions, RegistrationError, RpcClient, ServiceInvocationError, UPayload,
32};
33
34fn handle_response_message(response: UMessage) -> Result<Option<UPayload>, ServiceInvocationError> {
36 match response.commstatus() {
37 Some(UCode::OK) | None => {
38 let payload_format = response.payload_format().unwrap_or_default();
40 Ok(response
41 .payload
42 .map(|payload| UPayload::new(payload, payload_format)))
43 }
44 Some(code) => {
45 let status = response.extract_protobuf().unwrap_or_else(|_e| {
47 UStatus::fail_with_code(code, "failed to invoke service operation")
48 });
49 Err(ServiceInvocationError::from(status))
50 }
51 }
52}
53
54struct ResponseListener {
55 pending_requests: Mutex<HashMap<UUID, Sender<UMessage>>>,
57}
58
59impl ResponseListener {
60 fn try_add_pending_request(
61 &self,
62 reqid: UUID,
63 ) -> Result<Receiver<UMessage>, ServiceInvocationError> {
64 let Ok(mut pending_requests) = self.pending_requests.lock() else {
65 return Err(ServiceInvocationError::Internal(
66 "failed to add response handler".to_string(),
67 ));
68 };
69
70 if let Entry::Vacant(entry) = pending_requests.entry(reqid) {
71 let (tx, rx) = tokio::sync::oneshot::channel();
72 entry.insert(tx);
73 Ok(rx)
74 } else {
75 Err(ServiceInvocationError::AlreadyExists(
76 "RPC request with given ID already pending".to_string(),
77 ))
78 }
79 }
80
81 fn handle_response(&self, response_message: UMessage) {
82 let reqid = response_message.request_id_unchecked().clone();
83 let response_sender = {
84 let Ok(mut pending_requests) = self.pending_requests.lock() else {
86 info!(
87 request_id = reqid.to_hyphenated_string(),
88 "failed to process response message, cannot acquire lock for pending requests map"
89 );
90 return;
91 };
92 pending_requests.remove(&reqid)
93 };
94 if let Some(sender) = response_sender {
95 if let Err(_e) = sender.send(response_message) {
96 debug!(
98 request_id = reqid.to_hyphenated_string(),
99 "failed to deliver RPC Response message, channel already closed"
100 );
101 } else {
102 debug!(
103 request_id = reqid.to_hyphenated_string(),
104 "successfully delivered RPC Response message"
105 )
106 }
107 } else {
108 debug!(
110 request_id = reqid.to_hyphenated_string(),
111 "ignoring (duplicate?) RPC Response message with unknown request ID"
112 );
113 }
114 }
115
116 fn remove_pending_request(&self, reqid: &UUID) -> Option<Sender<UMessage>> {
117 self.pending_requests
118 .lock()
119 .map_or(None, |mut pending_requests| pending_requests.remove(reqid))
120 }
121
122 #[cfg(test)]
123 fn contains(&self, reqid: &UUID) -> bool {
124 self.pending_requests
125 .lock()
126 .is_ok_and(|pending_requests| pending_requests.contains_key(reqid))
127 }
128}
129
130#[async_trait]
131impl UListener for ResponseListener {
132 async fn on_receive(&self, msg: UMessage) {
133 if msg.is_response() {
136 self.handle_response(msg);
137 } else {
138 debug!(
139 message_type = msg.type_unchecked().to_cloudevent_type(),
140 "ignoring non-response message received by RPC client"
141 );
142 }
143 }
144}
145
146pub struct InMemoryRpcClient {
160 transport: Arc<dyn UTransport>,
161 uri_provider: Arc<dyn LocalUriProvider>,
162 response_listener: Arc<ResponseListener>,
163}
164
165impl InMemoryRpcClient {
166 pub async fn new(
178 transport: Arc<dyn UTransport>,
179 uri_provider: Arc<dyn LocalUriProvider>,
180 ) -> Result<Self, RegistrationError> {
181 let response_listener = Arc::new(ResponseListener {
182 pending_requests: Mutex::new(HashMap::new()),
183 });
184 transport
185 .register_listener(
186 &UUri::any(),
187 Some(&uri_provider.get_source_uri()),
188 response_listener.clone(),
189 )
190 .await
191 .map_err(RegistrationError::from)?;
192
193 Ok(InMemoryRpcClient {
194 transport,
195 uri_provider,
196 response_listener,
197 })
198 }
199
200 #[cfg(test)]
201 fn contains_pending_request(&self, reqid: &UUID) -> bool {
202 self.response_listener.contains(reqid)
203 }
204}
205
206#[async_trait]
207impl RpcClient for InMemoryRpcClient {
208 async fn invoke_method(
209 &self,
210 method: UUri,
211 call_options: CallOptions,
212 payload: Option<UPayload>,
213 ) -> Result<Option<UPayload>, ServiceInvocationError> {
214 let message_id = call_options.message_id().unwrap_or_else(UUID::build);
215
216 let mut builder = UMessageBuilder::request(
217 method.clone(),
218 self.uri_provider.get_source_uri(),
219 call_options.ttl(),
220 );
221 builder.with_message_id(message_id.clone());
222 if let Some(token) = call_options.token() {
223 builder.with_token(token.to_owned());
224 }
225 if let Some(priority) = call_options.priority() {
226 builder.with_priority(priority);
227 }
228 let rpc_request_message = build_message(&mut builder, payload)
229 .map_err(|e| ServiceInvocationError::InvalidArgument(e.to_string()))?;
230
231 let receiver = self
232 .response_listener
233 .try_add_pending_request(message_id.clone())?;
234 self.transport
235 .send(rpc_request_message)
236 .await
237 .inspect_err(|_e| {
238 self.response_listener.remove_pending_request(&message_id);
239 })?;
240 debug!(
241 request_id = message_id.to_hyphenated_string(),
242 ttl = call_options.ttl(),
243 "successfully sent RPC Request message"
244 );
245
246 match timeout(Duration::from_millis(call_options.ttl() as u64), receiver).await {
247 Err(_) => {
248 debug!(
249 request_id = message_id.to_hyphenated_string(),
250 ttl = call_options.ttl(),
251 "invocation of service operation has timed out"
252 );
253 self.response_listener.remove_pending_request(&message_id);
254 Err(ServiceInvocationError::DeadlineExceeded)
255 }
256 Ok(result) => match result {
257 Ok(response_message) => handle_response_message(response_message),
258 Err(_e) => {
259 debug!(
260 request_id = message_id.to_hyphenated_string(),
261 "response listener failed to forward response message"
262 );
263 self.response_listener.remove_pending_request(&message_id);
264 Err(ServiceInvocationError::Internal(
265 "error receiving response message".to_string(),
266 ))
267 }
268 },
269 }
270 }
271}
272
#[cfg(test)]
mod tests {

    use super::*;

    use protobuf::well_known_types::wrappers::StringValue;
    use tokio::{join, sync::Notify};

    use crate::{utransport::MockTransport, StaticUriProvider, UMessageBuilder, UPriority, UUri};

    /// Creates the URI provider representing the local (client) uEntity used by all tests.
    fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
        Arc::new(StaticUriProvider::new("", 0x0005, 0x02))
    }

    /// Returns the URI of the (fake) remote service operation being invoked.
    fn service_method_uri() -> UUri {
        UUri {
            ue_id: 0x0001,
            ue_version_major: 0x01,
            resource_id: 0x1000,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_registration_of_response_listener_fails() {
        let mut mock_transport = MockTransport::default();
        // The transport rejects the client's response listener registration.
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| {
                Err(UStatus::fail_with_code(
                    UCode::RESOURCE_EXHAUSTED,
                    "max number of listeners exceeded",
                ))
            });

        let creation_attempt =
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider()).await;

        // RESOURCE_EXHAUSTED from the transport must surface as MaxListenersExceeded.
        assert!(
            creation_attempt.is_err_and(|e| matches!(e, RegistrationError::MaxListenersExceeded))
        );
    }

    #[tokio::test]
    async fn test_invoke_method_fails_with_transport_error() {
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        // Sending the request message fails on the transport level.
        mock_transport
            .expect_do_send()
            .returning(|_request_message| {
                Err(UStatus::fail_with_code(
                    UCode::UNAVAILABLE,
                    "transport not available",
                ))
            });
        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::Unavailable(_msg))));
        // A failed send must not leave the request registered as pending.
        assert!(!client.contains_pending_request(&message_id));
    }

    #[tokio::test]
    async fn test_invoke_method_succeeds() {
        let message_id = UUID::build();
        let call_options = CallOptions::for_rpc_request(
            5_000,
            Some(message_id.clone()),
            Some("my_token".to_string()),
            Some(crate::UPriority::UPRIORITY_CS6),
        );

        // Used to capture the listener the client registers, so the test can deliver a response.
        let (captured_listener_tx, captured_listener_rx) = tokio::sync::oneshot::channel();
        // Signaled by the mock transport once the request message has been sent.
        let request_sent = Arc::new(Notify::new());
        let request_sent_clone = request_sent.clone();

        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .return_once(move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            });
        let expected_message_id = message_id.clone();
        // Verify that all call options end up in the request message's attributes.
        mock_transport
            .expect_do_send()
            .once()
            .withf(move |request_message| {
                request_message.id_unchecked() == &expected_message_id
                    && request_message.priority_unchecked() == UPriority::UPRIORITY_CS6
                    && request_message.ttl_unchecked() == 5_000
                    && request_message.token() == Some(&String::from("my_token"))
            })
            .returning(move |_request_message| {
                request_sent_clone.notify_one();
                Ok(())
            });

        let uri_provider = new_uri_provider();
        let rpc_client = Arc::new(
            InMemoryRpcClient::new(Arc::new(mock_transport), uri_provider.clone())
                .await
                .unwrap(),
        );
        let client: Arc<dyn RpcClient> = rpc_client.clone();

        // Run the invocation in the background; it blocks until a response is delivered.
        let response_handle = tokio::spawn(async move {
            let request_payload = StringValue {
                value: "World".to_string(),
                ..Default::default()
            };
            client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    call_options,
                    request_payload,
                )
                .await
        });

        // The response the (fake) service sends back for the request above.
        let response_payload = StringValue {
            value: "Hello World".to_string(),
            ..Default::default()
        };
        let response_message = UMessageBuilder::response(
            uri_provider.get_source_uri(),
            message_id.clone(),
            service_method_uri(),
        )
        .build_with_protobuf_payload(&response_payload)
        .unwrap();

        // Only deliver the response once the listener is captured and the request has been sent.
        let (response_listener_result, _) = join!(captured_listener_rx, request_sent.notified());
        let response_listener = response_listener_result.unwrap();

        let cloned_response_message = response_message.clone();
        let cloned_response_listener = response_listener.clone();
        tokio::spawn(async move {
            cloned_response_listener
                .on_receive(cloned_response_message)
                .await
        });

        let response = response_handle.await.unwrap();
        assert!(response.is_ok_and(|payload| payload.value == *"Hello World"));
        assert!(!rpc_client.contains_pending_request(&message_id));

        // Deliver the same response again; it must be ignored because no request is pending.
        response_listener.on_receive(response_message).await;
        assert!(!rpc_client.contains_pending_request(&message_id));
    }

    #[tokio::test]
    async fn test_invoke_method_fails_on_repeated_invocation() {
        let message_id = UUID::build();
        let first_request_sent = Arc::new(Notify::new());
        let first_request_sent_clone = first_request_sent.clone();

        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .return_const(Ok(()));
        let expected_message_id = message_id.clone();
        // Only the first invocation may reach the transport; it is never answered.
        mock_transport
            .expect_do_send()
            .once()
            .withf(move |request_message| request_message.id_unchecked() == &expected_message_id)
            .returning(move |_request_message| {
                first_request_sent_clone.notify_one();
                Ok(())
            });

        let in_memory_rpc_client = Arc::new(
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
                .await
                .unwrap(),
        );
        let rpc_client: Arc<dyn RpcClient> = in_memory_rpc_client.clone();

        // Both invocations use the same message ID.
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let cloned_call_options = call_options.clone();
        let cloned_rpc_client = rpc_client.clone();

        tokio::spawn(async move {
            let request_payload = StringValue {
                value: "World".to_string(),
                ..Default::default()
            };
            cloned_rpc_client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    cloned_call_options,
                    request_payload,
                )
                .await
        });

        // Wait until the first request is pending before attempting the second one.
        first_request_sent.notified().await;

        let request_payload = StringValue {
            value: "World".to_string(),
            ..Default::default()
        };
        let second_request_handle = tokio::spawn(async move {
            rpc_client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    call_options,
                    request_payload,
                )
                .await
        });

        let response = second_request_handle.await.unwrap();
        assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::AlreadyExists(_))));
        // The rejected second attempt must not remove the first invocation's pending request.
        assert!(in_memory_rpc_client.contains_pending_request(&message_id));
    }

    #[tokio::test]
    async fn test_invoke_method_fails_with_remote_error() {
        let (captured_listener_tx, captured_listener_rx) = std::sync::mpsc::channel();

        let mut mock_transport = MockTransport::default();
        mock_transport.expect_do_register_listener().returning(
            move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            },
        );
        // Answer the request immediately with a NOT_FOUND response delivered to the captured listener.
        mock_transport
            .expect_do_send()
            .returning(move |request_message| {
                let error = UStatus::fail_with_code(UCode::NOT_FOUND, "no such object");
                let response_message = UMessageBuilder::response_for_request(
                    request_message.attributes.as_ref().unwrap(),
                )
                .with_comm_status(UCode::NOT_FOUND)
                .build_with_protobuf_payload(&error)
                .unwrap();
                let captured_listener = captured_listener_rx.recv().unwrap().to_owned();
                tokio::spawn(async move { captured_listener.on_receive(response_message).await });
                Ok(())
            });

        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        // The remote UStatus code must be mapped to the corresponding invocation error.
        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::NotFound(_msg)) }));
        assert!(!client.contains_pending_request(&message_id));
    }

    #[tokio::test]
    async fn test_invoke_method_times_out() {
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        // The request is sent successfully, but no response is ever delivered.
        mock_transport
            .expect_do_send()
            .returning(|_request_message| Ok(()));

        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        let message_id = UUID::build();
        // A short TTL (20 ms) keeps the test fast.
        let call_options = CallOptions::for_rpc_request(20, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::DeadlineExceeded) }));
        // A timed-out request must not remain registered as pending.
        assert!(!client.contains_pending_request(&message_id));
    }
}