up_rust/communication/
in_memory_rpc_client.rs

1/********************************************************************************
2 * Copyright (c) 2024 Contributors to the Eclipse Foundation
3 *
4 * See the NOTICE file(s) distributed with this work for additional
5 * information regarding copyright ownership.
6 *
7 * This program and the accompanying materials are made available under the
8 * terms of the Apache License Version 2.0 which is available at
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * SPDX-License-Identifier: Apache-2.0
12 ********************************************************************************/
13
14// [impl->dsn~communication-layer-impl-default~1]
15
16use std::collections::hash_map::Entry;
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex};
19use std::time::Duration;
20
21use async_trait::async_trait;
22use tokio::sync::oneshot::{Receiver, Sender};
23use tokio::time::timeout;
24use tracing::{debug, info};
25
26use crate::{
27    LocalUriProvider, UCode, UListener, UMessage, UMessageBuilder, UStatus, UTransport, UUri, UUID,
28};
29
30use super::{
31    build_message, CallOptions, RegistrationError, RpcClient, ServiceInvocationError, UPayload,
32};
33
34/// Handles an RPC Response message received from the transport layer.
35fn handle_response_message(response: UMessage) -> Result<Option<UPayload>, ServiceInvocationError> {
36    match response.commstatus() {
37        Some(UCode::OK) | None => {
38            // successful invocation
39            let payload_format = response.payload_format().unwrap_or_default();
40            Ok(response
41                .payload
42                .map(|payload| UPayload::new(payload, payload_format)))
43        }
44        Some(code) => {
45            // try to extract UStatus from response payload
46            let status = response.extract_protobuf().unwrap_or_else(|_e| {
47                UStatus::fail_with_code(code, "failed to invoke service operation")
48            });
49            Err(ServiceInvocationError::from(status))
50        }
51    }
52}
53
54struct ResponseListener {
55    // request ID -> sender for response message
56    pending_requests: Mutex<HashMap<UUID, Sender<UMessage>>>,
57}
58
59impl ResponseListener {
60    fn try_add_pending_request(
61        &self,
62        reqid: UUID,
63    ) -> Result<Receiver<UMessage>, ServiceInvocationError> {
64        let Ok(mut pending_requests) = self.pending_requests.lock() else {
65            return Err(ServiceInvocationError::Internal(
66                "failed to add response handler".to_string(),
67            ));
68        };
69
70        if let Entry::Vacant(entry) = pending_requests.entry(reqid) {
71            let (tx, rx) = tokio::sync::oneshot::channel();
72            entry.insert(tx);
73            Ok(rx)
74        } else {
75            Err(ServiceInvocationError::AlreadyExists(
76                "RPC request with given ID already pending".to_string(),
77            ))
78        }
79    }
80
81    fn handle_response(&self, response_message: UMessage) {
82        let reqid = response_message.request_id_unchecked().clone();
83        let response_sender = {
84            // drop lock as soon as possible
85            let Ok(mut pending_requests) = self.pending_requests.lock() else {
86                info!(
87                    request_id = reqid.to_hyphenated_string(),
88                    "failed to process response message, cannot acquire lock for pending requests map"
89                );
90                return;
91            };
92            pending_requests.remove(&reqid)
93        };
94        if let Some(sender) = response_sender {
95            if let Err(_e) = sender.send(response_message) {
96                // channel seems to be closed already
97                debug!(
98                    request_id = reqid.to_hyphenated_string(),
99                    "failed to deliver RPC Response message, channel already closed"
100                );
101            } else {
102                debug!(
103                    request_id = reqid.to_hyphenated_string(),
104                    "successfully delivered RPC Response message"
105                )
106            }
107        } else {
108            // we seem to have received a duplicate of the response message, ignoring it ...
109            debug!(
110                request_id = reqid.to_hyphenated_string(),
111                "ignoring (duplicate?) RPC Response message with unknown request ID"
112            );
113        }
114    }
115
116    fn remove_pending_request(&self, reqid: &UUID) -> Option<Sender<UMessage>> {
117        self.pending_requests
118            .lock()
119            .map_or(None, |mut pending_requests| pending_requests.remove(reqid))
120    }
121
122    #[cfg(test)]
123    fn contains(&self, reqid: &UUID) -> bool {
124        self.pending_requests
125            .lock()
126            .is_ok_and(|pending_requests| pending_requests.contains_key(reqid))
127    }
128}
129
130#[async_trait]
131impl UListener for ResponseListener {
132    async fn on_receive(&self, msg: UMessage) {
133        // it is sufficient to check if the message is a response
134        // because the transport implementation forwards valid UMessages only
135        if msg.is_response() {
136            self.handle_response(msg);
137        } else {
138            debug!(
139                message_type = msg.type_unchecked().to_cloudevent_type(),
140                "ignoring non-response message received by RPC client"
141            );
142        }
143    }
144}
145
146/// An [`RpcClient`] which keeps all information about pending requests in memory.
147///
148/// The client requires an implementations of [`UTransport`] for sending RPC Request messages
149/// to the service implementation and receiving its RPC Response messages.
150///
151/// During [startup](`Self::new`) the client registers a generic [`UListener`] with the transport
152/// for receiving all kinds of messages with a _sink_ address matching the client. The listener
153/// maintains an in-memory mapping of (pending) request IDs to response message handlers.
154///
155/// When an [`RPC call`](Self::invoke_method) is made, an RPC Request message is sent to the service
156/// implementation and a response handler is created and registered with the listener.
157/// When an RPC Response message arrives from the service, the corresponding handler is being looked
158/// up and invoked.
159pub struct InMemoryRpcClient {
160    transport: Arc<dyn UTransport>,
161    uri_provider: Arc<dyn LocalUriProvider>,
162    response_listener: Arc<ResponseListener>,
163}
164
165impl InMemoryRpcClient {
166    /// Creates a new RPC client for a given transport.
167    ///
168    /// # Arguments
169    ///
170    /// * `transport` - The uProtocol Transport Layer implementation to use for invoking service operations.
171    /// * `uri_provider` - The helper for creating URIs that represent local resources.
172    ///
173    /// # Errors
174    ///
175    /// Returns an error if the generic RPC Response listener could not be
176    /// registered with the given transport.
177    pub async fn new(
178        transport: Arc<dyn UTransport>,
179        uri_provider: Arc<dyn LocalUriProvider>,
180    ) -> Result<Self, RegistrationError> {
181        let response_listener = Arc::new(ResponseListener {
182            pending_requests: Mutex::new(HashMap::new()),
183        });
184        transport
185            .register_listener(
186                &UUri::any(),
187                Some(&uri_provider.get_source_uri()),
188                response_listener.clone(),
189            )
190            .await
191            .map_err(RegistrationError::from)?;
192
193        Ok(InMemoryRpcClient {
194            transport,
195            uri_provider,
196            response_listener,
197        })
198    }
199
200    #[cfg(test)]
201    fn contains_pending_request(&self, reqid: &UUID) -> bool {
202        self.response_listener.contains(reqid)
203    }
204}
205
206#[async_trait]
207impl RpcClient for InMemoryRpcClient {
208    async fn invoke_method(
209        &self,
210        method: UUri,
211        call_options: CallOptions,
212        payload: Option<UPayload>,
213    ) -> Result<Option<UPayload>, ServiceInvocationError> {
214        let message_id = call_options.message_id().unwrap_or_else(UUID::build);
215
216        let mut builder = UMessageBuilder::request(
217            method.clone(),
218            self.uri_provider.get_source_uri(),
219            call_options.ttl(),
220        );
221        builder.with_message_id(message_id.clone());
222        if let Some(token) = call_options.token() {
223            builder.with_token(token.to_owned());
224        }
225        if let Some(priority) = call_options.priority() {
226            builder.with_priority(priority);
227        }
228        let rpc_request_message = build_message(&mut builder, payload)
229            .map_err(|e| ServiceInvocationError::InvalidArgument(e.to_string()))?;
230
231        let receiver = self
232            .response_listener
233            .try_add_pending_request(message_id.clone())?;
234        self.transport
235            .send(rpc_request_message)
236            .await
237            .inspect_err(|_e| {
238                self.response_listener.remove_pending_request(&message_id);
239            })?;
240        debug!(
241            request_id = message_id.to_hyphenated_string(),
242            ttl = call_options.ttl(),
243            "successfully sent RPC Request message"
244        );
245
246        match timeout(Duration::from_millis(call_options.ttl() as u64), receiver).await {
247            Err(_) => {
248                debug!(
249                    request_id = message_id.to_hyphenated_string(),
250                    ttl = call_options.ttl(),
251                    "invocation of service operation has timed out"
252                );
253                self.response_listener.remove_pending_request(&message_id);
254                Err(ServiceInvocationError::DeadlineExceeded)
255            }
256            Ok(result) => match result {
257                Ok(response_message) => handle_response_message(response_message),
258                Err(_e) => {
259                    debug!(
260                        request_id = message_id.to_hyphenated_string(),
261                        "response listener failed to forward response message"
262                    );
263                    self.response_listener.remove_pending_request(&message_id);
264                    Err(ServiceInvocationError::Internal(
265                        "error receiving response message".to_string(),
266                    ))
267                }
268            },
269        }
270    }
271}
272
/// Unit tests for [`InMemoryRpcClient`], driven by a mocked [`UTransport`].
#[cfg(test)]
mod tests {

    // [utest->dsn~communication-layer-impl-default~1]

    use super::*;

    use protobuf::well_known_types::wrappers::StringValue;
    use tokio::{join, sync::Notify};

    use crate::{utransport::MockTransport, StaticUriProvider, UMessageBuilder, UPriority, UUri};

    /// Creates the URI provider used by the RPC clients under test.
    fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
        Arc::new(StaticUriProvider::new("", 0x0005, 0x02))
    }

    /// Returns the URI of the remote service operation invoked by the tests.
    fn service_method_uri() -> UUri {
        UUri {
            ue_id: 0x0001,
            ue_version_major: 0x01,
            resource_id: 0x1000,
            ..Default::default()
        }
    }

    /// Verifies that creating a client fails if the transport rejects the registration
    /// of the response listener.
    #[tokio::test]
    async fn test_registration_of_response_listener_fails() {
        // GIVEN a transport
        let mut mock_transport = MockTransport::default();
        // with the maximum number of listeners already registered
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| {
                Err(UStatus::fail_with_code(
                    UCode::RESOURCE_EXHAUSTED,
                    "max number of listeners exceeded",
                ))
            });

        // WHEN trying to create an RpcClient for the transport
        let creation_attempt =
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider()).await;

        // THEN the attempt fails with a MaxListenersExceeded error
        // (the RESOURCE_EXHAUSTED status is expected to be mapped to this variant)
        assert!(
            creation_attempt.is_err_and(|e| matches!(e, RegistrationError::MaxListenersExceeded))
        );
    }

    /// Verifies that an error returned by the transport when sending the request is
    /// passed on to the caller and that no pending request is left behind.
    #[tokio::test]
    async fn test_invoke_method_fails_with_transport_error() {
        // GIVEN an RPC client
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        // with a transport that fails with an error when invoking a method
        mock_transport
            .expect_do_send()
            .returning(|_request_message| {
                Err(UStatus::fail_with_code(
                    UCode::UNAVAILABLE,
                    "transport not available",
                ))
            });
        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        // WHEN invoking a remote service operation
        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        // THEN the invocation fails with the error caused at the Transport Layer
        assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::Unavailable(_msg))));
        // AND the request is no longer pending
        assert!(!client.contains_pending_request(&message_id));
    }

    /// Verifies a successful round trip.
    ///
    /// The test checks that:
    /// * the request is sent with the message ID, priority, TTL and token from the call options,
    /// * the response payload is returned to the caller,
    /// * the pending request is removed,
    /// * a duplicate response is ignored.
    #[tokio::test]
    async fn test_invoke_method_succeeds() {
        let message_id = UUID::build();
        let call_options = CallOptions::for_rpc_request(
            5_000,
            Some(message_id.clone()),
            Some("my_token".to_string()),
            Some(crate::UPriority::UPRIORITY_CS6),
        );

        // used for getting hold of the listener that the client registers with the transport
        let (captured_listener_tx, captured_listener_rx) = tokio::sync::oneshot::channel();
        // signals that the mocked transport has been asked to send the request
        let request_sent = Arc::new(Notify::new());
        let request_sent_clone = request_sent.clone();

        // GIVEN an RPC client
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .return_once(move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            });
        let expected_message_id = message_id.clone();
        mock_transport
            .expect_do_send()
            .once()
            .withf(move |request_message| {
                request_message.id_unchecked() == &expected_message_id
                    && request_message.priority_unchecked() == UPriority::UPRIORITY_CS6
                    && request_message.ttl_unchecked() == 5_000
                    && request_message.token() == Some(&String::from("my_token"))
            })
            .returning(move |_request_message| {
                request_sent_clone.notify_one();
                Ok(())
            });

        let uri_provider = new_uri_provider();
        let rpc_client = Arc::new(
            InMemoryRpcClient::new(Arc::new(mock_transport), uri_provider.clone())
                .await
                .unwrap(),
        );
        let client: Arc<dyn RpcClient> = rpc_client.clone();

        // WHEN invoking a remote service operation
        let response_handle = tokio::spawn(async move {
            let request_payload = StringValue {
                value: "World".to_string(),
                ..Default::default()
            };
            client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    call_options,
                    request_payload,
                )
                .await
        });

        // AND the remote service sends the corresponding RPC Response message
        let response_payload = StringValue {
            value: "Hello World".to_string(),
            ..Default::default()
        };
        let response_message = UMessageBuilder::response(
            uri_provider.get_source_uri(),
            message_id.clone(),
            service_method_uri(),
        )
        .build_with_protobuf_payload(&response_payload)
        .unwrap();

        // wait for the RPC Request message having been sent
        let (response_listener_result, _) = join!(captured_listener_rx, request_sent.notified());
        let response_listener = response_listener_result.unwrap();

        // send the RPC Response message which completes the request
        let cloned_response_message = response_message.clone();
        let cloned_response_listener = response_listener.clone();
        tokio::spawn(async move {
            cloned_response_listener
                .on_receive(cloned_response_message)
                .await
        });

        // THEN the response contains the expected payload
        let response = response_handle.await.unwrap();
        assert!(response.is_ok_and(|payload| payload.value == *"Hello World"));
        assert!(!rpc_client.contains_pending_request(&message_id));

        // AND if the remote service sends its response message again
        response_listener.on_receive(response_message).await;
        // the duplicate response is silently ignored
        assert!(!rpc_client.contains_pending_request(&message_id));
    }

    /// Verifies that a second invocation using the message ID of a request that is still
    /// pending is rejected, and that the first request stays pending.
    #[tokio::test]
    async fn test_invoke_method_fails_on_repeated_invocation() {
        let message_id = UUID::build();
        let first_request_sent = Arc::new(Notify::new());
        let first_request_sent_clone = first_request_sent.clone();

        // GIVEN an RPC client
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .return_const(Ok(()));
        let expected_message_id = message_id.clone();
        // only a single request is expected to reach the transport
        mock_transport
            .expect_do_send()
            .once()
            .withf(move |request_message| request_message.id_unchecked() == &expected_message_id)
            .returning(move |_request_message| {
                first_request_sent_clone.notify_one();
                Ok(())
            });

        let in_memory_rpc_client = Arc::new(
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
                .await
                .unwrap(),
        );
        let rpc_client: Arc<dyn RpcClient> = in_memory_rpc_client.clone();

        // WHEN invoking a remote service operation
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let cloned_call_options = call_options.clone();
        let cloned_rpc_client = rpc_client.clone();

        tokio::spawn(async move {
            let request_payload = StringValue {
                value: "World".to_string(),
                ..Default::default()
            };
            cloned_rpc_client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    cloned_call_options,
                    request_payload,
                )
                .await
        });

        // we wait for the first request message having been sent via the transport
        // in order to be sure that the pending request has been added to the client's
        // internal state
        first_request_sent.notified().await;

        // AND invoking the same operation before the response to the first request has arrived
        let request_payload = StringValue {
            value: "World".to_string(),
            ..Default::default()
        };
        let second_request_handle = tokio::spawn(async move {
            rpc_client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    call_options,
                    request_payload,
                )
                .await
        });

        // THEN the second invocation fails
        let response = second_request_handle.await.unwrap();
        assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::AlreadyExists(_))));
        // because there is a pending request for the message ID used in both requests
        assert!(in_memory_rpc_client.contains_pending_request(&message_id));
    }

    /// Verifies that an error reported by the service in its response (commstatus plus
    /// `UStatus` payload) is returned to the caller and that the pending request is removed.
    #[tokio::test]
    async fn test_invoke_method_fails_with_remote_error() {
        // a std channel is used because the listener is read from within the mocked
        // (synchronous) send function below
        let (captured_listener_tx, captured_listener_rx) = std::sync::mpsc::channel();

        // GIVEN an RPC client
        let mut mock_transport = MockTransport::default();
        mock_transport.expect_do_register_listener().returning(
            move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            },
        );
        // and a remote service operation that returns an error
        mock_transport
            .expect_do_send()
            .returning(move |request_message| {
                let error = UStatus::fail_with_code(UCode::NOT_FOUND, "no such object");
                let response_message = UMessageBuilder::response_for_request(
                    request_message.attributes.as_ref().unwrap(),
                )
                .with_comm_status(UCode::NOT_FOUND)
                .build_with_protobuf_payload(&error)
                .unwrap();
                // deliver the error response via the listener captured during registration
                let captured_listener = captured_listener_rx.recv().unwrap().to_owned();
                tokio::spawn(async move { captured_listener.on_receive(response_message).await });
                Ok(())
            });

        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        // WHEN invoking the remote service operation
        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        // THEN the invocation has failed with the error returned from the service
        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::NotFound(_msg)) }));
        assert!(!client.contains_pending_request(&message_id));
    }

    /// Verifies that an invocation without a response fails with `DeadlineExceeded`
    /// once the TTL has expired and that the pending request is removed.
    #[tokio::test]
    async fn test_invoke_method_times_out() {
        // GIVEN an RPC client
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        // and a remote service operation that does not return a response
        mock_transport
            .expect_do_send()
            .returning(|_request_message| Ok(()));

        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        // WHEN invoking the remote service operation
        // (with a short TTL of 20 ms to keep the test fast)
        let message_id = UUID::build();
        let call_options = CallOptions::for_rpc_request(20, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        // THEN the invocation times out
        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::DeadlineExceeded) }));
        assert!(!client.contains_pending_request(&message_id));
    }
}