up_rust/communication/
in_memory_rpc_client.rs

/********************************************************************************
 * Copyright (c) 2024 Contributors to the Eclipse Foundation
 *
 * See the NOTICE file(s) distributed with this work for additional
 * information regarding copyright ownership.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * SPDX-License-Identifier: Apache-2.0
 ********************************************************************************/

// [impl->req~up-language-comm-api-default-impl~1]
15
16use std::collections::hash_map::Entry;
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex};
19use std::time::Duration;
20
21use async_trait::async_trait;
22use tokio::sync::oneshot::{Receiver, Sender};
23use tokio::time::timeout;
24use tracing::{debug, info};
25
26use crate::{
27    LocalUriProvider, UCode, UListener, UMessage, UMessageBuilder, UMessageType, UStatus,
28    UTransport, UUri, UUID,
29};
30
31use super::{
32    build_message, CallOptions, RegistrationError, RpcClient, ServiceInvocationError, UPayload,
33};
34
35fn handle_response_message(response: UMessage) -> Result<Option<UPayload>, ServiceInvocationError> {
36    let Some(attribs) = response.attributes.as_ref() else {
37        return Err(ServiceInvocationError::InvalidArgument(
38            "response message does not contain attributes".to_string(),
39        ));
40    };
41
42    match attribs.commstatus.map(|v| v.enum_value_or_default()) {
43        Some(UCode::OK) | None => {
44            // successful invocation
45            response.payload.map_or(Ok(None), |payload| {
46                Ok(Some(UPayload::new(
47                    payload,
48                    attribs.payload_format.enum_value_or_default(),
49                )))
50            })
51        }
52        Some(code) => {
53            // try to extract UStatus from response payload
54            let status = response.extract_protobuf().unwrap_or_else(|_e| {
55                UStatus::fail_with_code(code, "failed to invoke service operation")
56            });
57            Err(ServiceInvocationError::from(status))
58        }
59    }
60}
61
/// Keeps track of RPC requests that are still waiting for their response.
///
/// Each pending request is represented by the sending half of a oneshot channel
/// which is used to hand the response message over to the waiting caller.
struct ResponseListener {
    // request ID -> sender for response message
    pending_requests: Mutex<HashMap<UUID, Sender<UMessage>>>,
}
66
67impl ResponseListener {
68    fn try_add_pending_request(
69        &self,
70        reqid: UUID,
71    ) -> Result<Receiver<UMessage>, ServiceInvocationError> {
72        let Ok(mut pending_requests) = self.pending_requests.lock() else {
73            return Err(ServiceInvocationError::Internal(
74                "failed to add response handler".to_string(),
75            ));
76        };
77
78        if let Entry::Vacant(entry) = pending_requests.entry(reqid) {
79            let (tx, rx) = tokio::sync::oneshot::channel();
80            entry.insert(tx);
81            Ok(rx)
82        } else {
83            Err(ServiceInvocationError::AlreadyExists(
84                "RPC request with given ID already pending".to_string(),
85            ))
86        }
87    }
88
89    fn handle_response(&self, reqid: &UUID, response_message: UMessage) {
90        let Ok(mut pending_requests) = self.pending_requests.lock() else {
91            info!(
92                request_id = reqid.to_hyphenated_string(),
93                "failed to process response message, cannot acquire lock for pending requests map"
94            );
95            return;
96        };
97        if let Some(sender) = pending_requests.remove(reqid) {
98            if let Err(_e) = sender.send(response_message) {
99                // channel seems to be closed already
100                debug!(
101                    request_id = reqid.to_hyphenated_string(),
102                    "failed to deliver RPC Response message, channel already closed"
103                );
104            } else {
105                debug!(
106                    request_id = reqid.to_hyphenated_string(),
107                    "successfully delivered RPC Response message"
108                )
109            }
110        } else {
111            // we seem to have received a duplicate of the response message, ignoring it ...
112            debug!(
113                request_id = reqid.to_hyphenated_string(),
114                "ignoring (duplicate?) RPC Response message with unknown request ID"
115            );
116        }
117    }
118
119    fn remove_pending_request(&self, reqid: &UUID) -> Option<Sender<UMessage>> {
120        self.pending_requests
121            .lock()
122            .map_or(None, |mut pending_requests| pending_requests.remove(reqid))
123    }
124
125    #[cfg(test)]
126    fn contains(&self, reqid: &UUID) -> bool {
127        self.pending_requests
128            .lock()
129            .is_ok_and(|pending_requests| pending_requests.contains_key(reqid))
130    }
131}
132
133#[async_trait]
134impl UListener for ResponseListener {
135    async fn on_receive(&self, msg: UMessage) {
136        let message_type = msg
137            .attributes
138            .get_or_default()
139            .type_
140            .enum_value_or_default();
141        if message_type != UMessageType::UMESSAGE_TYPE_RESPONSE {
142            debug!(
143                message_type = message_type.to_cloudevent_type(),
144                "service provider replied with message that is not an RPC Response"
145            );
146            return;
147        }
148
149        if let Some(reqid) = msg
150            .attributes
151            .as_ref()
152            .and_then(|attribs| attribs.reqid.clone().into_option())
153        {
154            self.handle_response(&reqid, msg);
155        } else {
156            debug!("ignoring malformed response message not containing request ID");
157        }
158    }
159}
160
/// An [`RpcClient`] which keeps all information about pending requests in memory.
///
/// The client requires an implementation of [`UTransport`] for sending RPC Request messages
/// to the service implementation and receiving its RPC Response messages.
///
/// During [startup](`Self::new`) the client registers a generic [`UListener`] with the transport
/// for receiving all kinds of messages with a _sink_ address matching the client. The listener
/// maintains an in-memory mapping of (pending) request IDs to response message handlers.
///
/// When an [`RPC call`](Self::invoke_method) is made, an RPC Request message is sent to the service
/// implementation and a response handler is created and registered with the listener.
/// When an RPC Response message arrives from the service, the corresponding handler is being looked
/// up and invoked.
pub struct InMemoryRpcClient {
    // used for sending RPC Request messages
    transport: Arc<dyn UTransport>,
    // provides the source address put into RPC Request messages
    uri_provider: Arc<dyn LocalUriProvider>,
    // the listener registered with the transport, holds the pending requests
    response_listener: Arc<ResponseListener>,
}
179
180impl InMemoryRpcClient {
181    /// Creates a new RPC client for a given transport.
182    ///
183    /// # Arguments
184    ///
185    /// * `transport` - The uProtocol Transport Layer implementation to use for invoking service operations.
186    /// * `uri_provider` - The helper for creating URIs that represent local resources.
187    ///
188    /// # Errors
189    ///
190    /// Returns an error if the generic RPC Response listener could not be
191    /// registered with the given transport.
192    pub async fn new(
193        transport: Arc<dyn UTransport>,
194        uri_provider: Arc<dyn LocalUriProvider>,
195    ) -> Result<Self, RegistrationError> {
196        let response_listener = Arc::new(ResponseListener {
197            pending_requests: Mutex::new(HashMap::new()),
198        });
199        transport
200            .register_listener(
201                &UUri::any(),
202                Some(&uri_provider.get_source_uri()),
203                response_listener.clone(),
204            )
205            .await
206            .map_err(RegistrationError::from)?;
207
208        Ok(InMemoryRpcClient {
209            transport,
210            uri_provider,
211            response_listener,
212        })
213    }
214
215    #[cfg(test)]
216    fn contains_pending_request(&self, reqid: &UUID) -> bool {
217        self.response_listener.contains(reqid)
218    }
219}
220
221#[async_trait]
222impl RpcClient for InMemoryRpcClient {
223    async fn invoke_method(
224        &self,
225        method: UUri,
226        call_options: CallOptions,
227        payload: Option<UPayload>,
228    ) -> Result<Option<UPayload>, ServiceInvocationError> {
229        let message_id = call_options.message_id().unwrap_or_else(UUID::build);
230
231        let mut builder = UMessageBuilder::request(
232            method.clone(),
233            self.uri_provider.get_source_uri(),
234            call_options.ttl(),
235        );
236        builder.with_message_id(message_id.clone());
237        if let Some(token) = call_options.token() {
238            builder.with_token(token.to_owned());
239        }
240        if let Some(priority) = call_options.priority() {
241            builder.with_priority(priority);
242        }
243        let rpc_request_message = build_message(&mut builder, payload)
244            .map_err(|e| ServiceInvocationError::InvalidArgument(e.to_string()))?;
245
246        let receiver = self
247            .response_listener
248            .try_add_pending_request(message_id.clone())?;
249        self.transport
250            .send(rpc_request_message)
251            .await
252            .map_err(|e| {
253                self.response_listener.remove_pending_request(&message_id);
254                e
255            })?;
256        debug!(
257            request_id = message_id.to_hyphenated_string(),
258            ttl = call_options.ttl(),
259            "successfully sent RPC Request message"
260        );
261
262        match timeout(Duration::from_millis(call_options.ttl() as u64), receiver).await {
263            Err(_) => {
264                debug!(
265                    request_id = message_id.to_hyphenated_string(),
266                    ttl = call_options.ttl(),
267                    "invocation of service operation has timed out"
268                );
269                self.response_listener.remove_pending_request(&message_id);
270                Err(ServiceInvocationError::DeadlineExceeded)
271            }
272            Ok(result) => match result {
273                Ok(response_message) => handle_response_message(response_message),
274                Err(_e) => {
275                    debug!(
276                        request_id = message_id.to_hyphenated_string(),
277                        "response listener failed to forward response message"
278                    );
279                    self.response_listener.remove_pending_request(&message_id);
280                    Err(ServiceInvocationError::Internal(
281                        "error receiving response message".to_string(),
282                    ))
283                }
284            },
285        }
286    }
287}
288
#[cfg(test)]
mod tests {

    // [utest->req~up-language-comm-api-default-impl~1]

    use super::*;

    use protobuf::{well_known_types::wrappers::StringValue, Enum};
    use tokio::{join, sync::Notify};

    use crate::{utransport::MockTransport, StaticUriProvider, UMessageBuilder, UPriority, UUri};

    /// Creates the URI provider that is used for the client under test.
    fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
        Arc::new(StaticUriProvider::new("", 0x0005, 0x02))
    }

    /// Creates the URI of the service method that the tests invoke.
    fn service_method_uri() -> UUri {
        UUri {
            ue_id: 0x0001,
            ue_version_major: 0x01,
            resource_id: 0x1000,
            ..Default::default()
        }
    }

    /// Verifies that a client cannot be created if the transport rejects
    /// the registration of the response listener.
    #[tokio::test]
    async fn test_registration_of_response_listener_fails() {
        // GIVEN a transport
        let mut mock_transport = MockTransport::default();
        // with the maximum number of listeners already registered
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| {
                Err(UStatus::fail_with_code(
                    UCode::RESOURCE_EXHAUSTED,
                    "max number of listeners exceeded",
                ))
            });

        // WHEN trying to create an RpcClient for the transport
        let creation_attempt =
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider()).await;

        // THEN the attempt fails with a MaxListenersExceeded error
        assert!(
            creation_attempt.is_err_and(|e| matches!(e, RegistrationError::MaxListenersExceeded))
        );
    }

    /// Verifies that an error that occurs while sending the request is
    /// returned to the caller and that no pending request is left behind.
    #[tokio::test]
    async fn test_invoke_method_fails_with_transport_error() {
        // GIVEN an RPC client
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        // with a transport that fails with an error when invoking a method
        mock_transport
            .expect_do_send()
            .returning(|_request_message| {
                Err(UStatus::fail_with_code(
                    UCode::UNAVAILABLE,
                    "transport not available",
                ))
            });
        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        // WHEN invoking a remote service operation
        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        // THEN the invocation fails with the error caused at the Transport Layer
        assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::Unavailable(_msg))));
        assert!(!client.contains_pending_request(&message_id));
    }

    /// Verifies the happy path: the request message contains the properties from
    /// the call options, the response payload is returned to the caller and a
    /// duplicate of the response message is ignored.
    #[tokio::test]
    async fn test_invoke_method_succeeds() {
        let message_id = UUID::build();
        let call_options = CallOptions::for_rpc_request(
            5_000,
            Some(message_id.clone()),
            Some("my_token".to_string()),
            Some(crate::UPriority::UPRIORITY_CS6),
        );

        // used to get hold of the listener that the client registers with the transport
        let (captured_listener_tx, captured_listener_rx) = tokio::sync::oneshot::channel();
        // signals that the request message has been passed to the transport
        let request_sent = Arc::new(Notify::new());
        let request_sent_clone = request_sent.clone();

        // GIVEN an RPC client
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .return_once(move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            });
        let expected_message_id = message_id.clone();
        mock_transport
            .expect_do_send()
            .once()
            .withf(move |request_message| {
                request_message.attributes.as_ref().is_some_and(|attribs| {
                    attribs.id.as_ref() == Some(&expected_message_id)
                        && attribs.priority.value() == UPriority::UPRIORITY_CS6.value()
                        && attribs.ttl == Some(5_000)
                        && attribs.token == Some("my_token".to_string())
                })
            })
            .returning(move |_request_message| {
                request_sent_clone.notify_one();
                Ok(())
            });

        let uri_provider = new_uri_provider();
        let rpc_client = Arc::new(
            InMemoryRpcClient::new(Arc::new(mock_transport), uri_provider.clone())
                .await
                .unwrap(),
        );
        let client: Arc<dyn RpcClient> = rpc_client.clone();

        // WHEN invoking a remote service operation
        let response_handle = tokio::spawn(async move {
            let request_payload = StringValue {
                value: "World".to_string(),
                ..Default::default()
            };
            client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    call_options,
                    request_payload,
                )
                .await
        });

        // AND the remote service sends the corresponding RPC Response message
        let response_payload = StringValue {
            value: "Hello World".to_string(),
            ..Default::default()
        };
        let response_message = UMessageBuilder::response(
            uri_provider.get_source_uri(),
            message_id.clone(),
            service_method_uri(),
        )
        .build_with_protobuf_payload(&response_payload)
        .unwrap();

        // wait for the RPC Request message having been sent
        let (response_listener_result, _) = join!(captured_listener_rx, request_sent.notified());
        let response_listener = response_listener_result.unwrap();

        // send the RPC Response message which completes the request
        let cloned_response_message = response_message.clone();
        let cloned_response_listener = response_listener.clone();
        tokio::spawn(async move {
            cloned_response_listener
                .on_receive(cloned_response_message)
                .await
        });

        // THEN the response contains the expected payload
        let response = response_handle.await.unwrap();
        assert!(response.is_ok_and(|payload| payload.value == *"Hello World"));
        assert!(!rpc_client.contains_pending_request(&message_id));

        // AND if the remote service sends its response message again
        response_listener.on_receive(response_message).await;
        // the duplicate response is silently ignored
        assert!(!rpc_client.contains_pending_request(&message_id));
    }

    /// Verifies that a message ID cannot be used for a second request
    /// while the first request using that ID is still pending.
    #[tokio::test]
    async fn test_invoke_method_fails_on_repeated_invocation() {
        let message_id = UUID::build();
        let first_request_sent = Arc::new(Notify::new());
        let first_request_sent_clone = first_request_sent.clone();

        // GIVEN an RPC client
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .return_const(Ok(()));
        let expected_message_id = message_id.clone();
        // the transport expects exactly one request message, the second
        // invocation is supposed to fail before anything is being sent
        mock_transport
            .expect_do_send()
            .once()
            .withf(move |request_message| {
                request_message
                    .attributes
                    .as_ref()
                    .is_some_and(|attribs| attribs.id.as_ref() == Some(&expected_message_id))
            })
            .returning(move |_request_message| {
                first_request_sent_clone.notify_one();
                Ok(())
            });

        let in_memory_rpc_client = Arc::new(
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
                .await
                .unwrap(),
        );
        let rpc_client: Arc<dyn RpcClient> = in_memory_rpc_client.clone();

        // WHEN invoking a remote service operation
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let cloned_call_options = call_options.clone();
        let cloned_rpc_client = rpc_client.clone();

        // the task is not awaited, no response is ever delivered for the first request
        tokio::spawn(async move {
            let request_payload = StringValue {
                value: "World".to_string(),
                ..Default::default()
            };
            cloned_rpc_client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    cloned_call_options,
                    request_payload,
                )
                .await
        });

        // we wait for the first request message having been sent via the transport
        // in order to be sure that the pending request has been added to the client's
        // internal state
        first_request_sent.notified().await;

        // AND invoking the same operation before the response to the first request has arrived
        let request_payload = StringValue {
            value: "World".to_string(),
            ..Default::default()
        };
        let second_request_handle = tokio::spawn(async move {
            rpc_client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    call_options,
                    request_payload,
                )
                .await
        });

        // THEN the second invocation fails
        let response = second_request_handle.await.unwrap();
        assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::AlreadyExists(_))));
        // because there is a pending request for the message ID used in both requests
        assert!(in_memory_rpc_client.contains_pending_request(&message_id));
    }

    /// Verifies that an error reported by the service in its response
    /// message is returned to the caller.
    #[tokio::test]
    async fn test_invoke_method_fails_with_remote_error() {
        // used to get hold of the listener that the client registers with the transport
        let (captured_listener_tx, captured_listener_rx) = std::sync::mpsc::channel();

        // GIVEN an RPC client
        let mut mock_transport = MockTransport::default();
        mock_transport.expect_do_register_listener().returning(
            move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            },
        );
        // and a remote service operation that returns an error
        mock_transport
            .expect_do_send()
            .returning(move |request_message| {
                let error = UStatus::fail_with_code(UCode::NOT_FOUND, "no such object");
                let response_message = UMessageBuilder::response_for_request(
                    request_message.attributes.as_ref().unwrap(),
                )
                .with_comm_status(UCode::NOT_FOUND)
                .build_with_protobuf_payload(&error)
                .unwrap();
                // the listener has been put into the channel during registration already
                let captured_listener = captured_listener_rx.recv().unwrap().to_owned();
                // deliver the response from a separate task
                tokio::spawn(async move { captured_listener.on_receive(response_message).await });
                Ok(())
            });

        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        // WHEN invoking the remote service operation
        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        // THEN the invocation has failed with the error returned from the service
        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::NotFound(_msg)) }));
        assert!(!client.contains_pending_request(&message_id));
    }

    /// Verifies that an invocation fails once the TTL has expired without a response
    /// having arrived and that the pending request is removed afterwards.
    #[tokio::test]
    async fn test_invoke_method_times_out() {
        // GIVEN an RPC client
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        // and a remote service operation that does not return a response
        mock_transport
            .expect_do_send()
            .returning(|_request_message| Ok(()));

        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        // WHEN invoking the remote service operation
        let message_id = UUID::build();
        // with a short TTL of 20ms
        let call_options = CallOptions::for_rpc_request(20, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        // THEN the invocation times out
        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::DeadlineExceeded) }));
        assert!(!client.contains_pending_request(&message_id));
    }

    /// Verifies that a response message without any attributes is rejected.
    #[test]
    fn test_handle_response_message_fails_for_missing_attributes() {
        let response_msg = UMessage {
            ..Default::default()
        };
        let result = handle_response_message(response_msg);
        assert!(result.is_err_and(|e| matches!(e, ServiceInvocationError::InvalidArgument(_))));
    }
}