up_rust/communication/
in_memory_rpc_client.rs

1/********************************************************************************
2 * Copyright (c) 2024 Contributors to the Eclipse Foundation
3 *
4 * See the NOTICE file(s) distributed with this work for additional
5 * information regarding copyright ownership.
6 *
7 * This program and the accompanying materials are made available under the
8 * terms of the Apache License Version 2.0 which is available at
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * SPDX-License-Identifier: Apache-2.0
12 ********************************************************************************/
13
14// [impl->dsn~communication-layer-impl-default~1]
15
16use std::collections::hash_map::Entry;
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex};
19use std::time::Duration;
20
21use async_trait::async_trait;
22use tokio::sync::oneshot::{Receiver, Sender};
23use tokio::time::timeout;
24use tracing::{debug, info};
25
26use crate::{
27    LocalUriProvider, UCode, UListener, UMessage, UMessageBuilder, UMessageType, UStatus,
28    UTransport, UUri, UUID,
29};
30
31use super::{
32    build_message, CallOptions, RegistrationError, RpcClient, ServiceInvocationError, UPayload,
33};
34
35fn handle_response_message(response: UMessage) -> Result<Option<UPayload>, ServiceInvocationError> {
36    let Some(attribs) = response.attributes.as_ref() else {
37        return Err(ServiceInvocationError::InvalidArgument(
38            "response message does not contain attributes".to_string(),
39        ));
40    };
41
42    match attribs.commstatus.map(|v| v.enum_value_or_default()) {
43        Some(UCode::OK) | None => {
44            // successful invocation
45            response.payload.map_or(Ok(None), |payload| {
46                Ok(Some(UPayload::new(
47                    payload,
48                    attribs.payload_format.enum_value_or_default(),
49                )))
50            })
51        }
52        Some(code) => {
53            // try to extract UStatus from response payload
54            let status = response.extract_protobuf().unwrap_or_else(|_e| {
55                UStatus::fail_with_code(code, "failed to invoke service operation")
56            });
57            Err(ServiceInvocationError::from(status))
58        }
59    }
60}
61
62struct ResponseListener {
63    // request ID -> sender for response message
64    pending_requests: Mutex<HashMap<UUID, Sender<UMessage>>>,
65}
66
67impl ResponseListener {
68    fn try_add_pending_request(
69        &self,
70        reqid: UUID,
71    ) -> Result<Receiver<UMessage>, ServiceInvocationError> {
72        let Ok(mut pending_requests) = self.pending_requests.lock() else {
73            return Err(ServiceInvocationError::Internal(
74                "failed to add response handler".to_string(),
75            ));
76        };
77
78        if let Entry::Vacant(entry) = pending_requests.entry(reqid) {
79            let (tx, rx) = tokio::sync::oneshot::channel();
80            entry.insert(tx);
81            Ok(rx)
82        } else {
83            Err(ServiceInvocationError::AlreadyExists(
84                "RPC request with given ID already pending".to_string(),
85            ))
86        }
87    }
88
89    fn handle_response(&self, reqid: &UUID, response_message: UMessage) {
90        let Ok(mut pending_requests) = self.pending_requests.lock() else {
91            info!(
92                request_id = reqid.to_hyphenated_string(),
93                "failed to process response message, cannot acquire lock for pending requests map"
94            );
95            return;
96        };
97        if let Some(sender) = pending_requests.remove(reqid) {
98            if let Err(_e) = sender.send(response_message) {
99                // channel seems to be closed already
100                debug!(
101                    request_id = reqid.to_hyphenated_string(),
102                    "failed to deliver RPC Response message, channel already closed"
103                );
104            } else {
105                debug!(
106                    request_id = reqid.to_hyphenated_string(),
107                    "successfully delivered RPC Response message"
108                )
109            }
110        } else {
111            // we seem to have received a duplicate of the response message, ignoring it ...
112            debug!(
113                request_id = reqid.to_hyphenated_string(),
114                "ignoring (duplicate?) RPC Response message with unknown request ID"
115            );
116        }
117    }
118
119    fn remove_pending_request(&self, reqid: &UUID) -> Option<Sender<UMessage>> {
120        self.pending_requests
121            .lock()
122            .map_or(None, |mut pending_requests| pending_requests.remove(reqid))
123    }
124
125    #[cfg(test)]
126    fn contains(&self, reqid: &UUID) -> bool {
127        self.pending_requests
128            .lock()
129            .is_ok_and(|pending_requests| pending_requests.contains_key(reqid))
130    }
131}
132
133#[async_trait]
134impl UListener for ResponseListener {
135    async fn on_receive(&self, msg: UMessage) {
136        let message_type = msg
137            .attributes
138            .get_or_default()
139            .type_
140            .enum_value_or_default();
141        if message_type != UMessageType::UMESSAGE_TYPE_RESPONSE {
142            debug!(
143                message_type = message_type.to_cloudevent_type(),
144                "service provider replied with message that is not an RPC Response"
145            );
146            return;
147        }
148
149        if let Some(reqid) = msg
150            .attributes
151            .as_ref()
152            .and_then(|attribs| attribs.reqid.clone().into_option())
153        {
154            self.handle_response(&reqid, msg);
155        } else {
156            debug!("ignoring malformed response message not containing request ID");
157        }
158    }
159}
160
161/// An [`RpcClient`] which keeps all information about pending requests in memory.
162///
163/// The client requires an implementations of [`UTransport`] for sending RPC Request messages
164/// to the service implementation and receiving its RPC Response messages.
165///
166/// During [startup](`Self::new`) the client registers a generic [`UListener`] with the transport
167/// for receiving all kinds of messages with a _sink_ address matching the client. The listener
168/// maintains an in-memory mapping of (pending) request IDs to response message handlers.
169///
170/// When an [`RPC call`](Self::invoke_method) is made, an RPC Request message is sent to the service
171/// implementation and a response handler is created and registered with the listener.
172/// When an RPC Response message arrives from the service, the corresponding handler is being looked
173/// up and invoked.
174pub struct InMemoryRpcClient {
175    transport: Arc<dyn UTransport>,
176    uri_provider: Arc<dyn LocalUriProvider>,
177    response_listener: Arc<ResponseListener>,
178}
179
180impl InMemoryRpcClient {
181    /// Creates a new RPC client for a given transport.
182    ///
183    /// # Arguments
184    ///
185    /// * `transport` - The uProtocol Transport Layer implementation to use for invoking service operations.
186    /// * `uri_provider` - The helper for creating URIs that represent local resources.
187    ///
188    /// # Errors
189    ///
190    /// Returns an error if the generic RPC Response listener could not be
191    /// registered with the given transport.
192    pub async fn new(
193        transport: Arc<dyn UTransport>,
194        uri_provider: Arc<dyn LocalUriProvider>,
195    ) -> Result<Self, RegistrationError> {
196        let response_listener = Arc::new(ResponseListener {
197            pending_requests: Mutex::new(HashMap::new()),
198        });
199        transport
200            .register_listener(
201                &UUri::any(),
202                Some(&uri_provider.get_source_uri()),
203                response_listener.clone(),
204            )
205            .await
206            .map_err(RegistrationError::from)?;
207
208        Ok(InMemoryRpcClient {
209            transport,
210            uri_provider,
211            response_listener,
212        })
213    }
214
215    #[cfg(test)]
216    fn contains_pending_request(&self, reqid: &UUID) -> bool {
217        self.response_listener.contains(reqid)
218    }
219}
220
221#[async_trait]
222impl RpcClient for InMemoryRpcClient {
223    async fn invoke_method(
224        &self,
225        method: UUri,
226        call_options: CallOptions,
227        payload: Option<UPayload>,
228    ) -> Result<Option<UPayload>, ServiceInvocationError> {
229        let message_id = call_options.message_id().unwrap_or_else(UUID::build);
230
231        let mut builder = UMessageBuilder::request(
232            method.clone(),
233            self.uri_provider.get_source_uri(),
234            call_options.ttl(),
235        );
236        builder.with_message_id(message_id.clone());
237        if let Some(token) = call_options.token() {
238            builder.with_token(token.to_owned());
239        }
240        if let Some(priority) = call_options.priority() {
241            builder.with_priority(priority);
242        }
243        let rpc_request_message = build_message(&mut builder, payload)
244            .map_err(|e| ServiceInvocationError::InvalidArgument(e.to_string()))?;
245
246        let receiver = self
247            .response_listener
248            .try_add_pending_request(message_id.clone())?;
249        self.transport
250            .send(rpc_request_message)
251            .await
252            .inspect_err(|_e| {
253                self.response_listener.remove_pending_request(&message_id);
254            })?;
255        debug!(
256            request_id = message_id.to_hyphenated_string(),
257            ttl = call_options.ttl(),
258            "successfully sent RPC Request message"
259        );
260
261        match timeout(Duration::from_millis(call_options.ttl() as u64), receiver).await {
262            Err(_) => {
263                debug!(
264                    request_id = message_id.to_hyphenated_string(),
265                    ttl = call_options.ttl(),
266                    "invocation of service operation has timed out"
267                );
268                self.response_listener.remove_pending_request(&message_id);
269                Err(ServiceInvocationError::DeadlineExceeded)
270            }
271            Ok(result) => match result {
272                Ok(response_message) => handle_response_message(response_message),
273                Err(_e) => {
274                    debug!(
275                        request_id = message_id.to_hyphenated_string(),
276                        "response listener failed to forward response message"
277                    );
278                    self.response_listener.remove_pending_request(&message_id);
279                    Err(ServiceInvocationError::Internal(
280                        "error receiving response message".to_string(),
281                    ))
282                }
283            },
284        }
285    }
286}
287
#[cfg(test)]
mod tests {

    // [utest->dsn~communication-layer-impl-default~1]

    use super::*;

    use protobuf::well_known_types::wrappers::StringValue;
    use tokio::{join, sync::Notify};

    use crate::{utransport::MockTransport, StaticUriProvider, UMessageBuilder, UPriority, UUri};

    /// Creates the URI provider that represents the client under test.
    fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
        Arc::new(StaticUriProvider::new("", 0x0005, 0x02))
    }

    /// Returns the URI of the (remote) service operation invoked by the tests.
    fn service_method_uri() -> UUri {
        UUri {
            ue_id: 0x0001,
            ue_version_major: 0x01,
            resource_id: 0x1000,
            ..Default::default()
        }
    }

    /// Verifies that creating a client fails if the transport refuses to register the
    /// response listener. A RESOURCE_EXHAUSTED status is expected to be surfaced as
    /// `RegistrationError::MaxListenersExceeded`.
    #[tokio::test]
    async fn test_registration_of_response_listener_fails() {
        // GIVEN a transport
        let mut mock_transport = MockTransport::default();
        // with the maximum number of listeners already registered
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| {
                Err(UStatus::fail_with_code(
                    UCode::RESOURCE_EXHAUSTED,
                    "max number of listeners exceeded",
                ))
            });

        // WHEN trying to create an RpcClient for the transport
        let creation_attempt =
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider()).await;

        // THEN the attempt fails with a MaxListenersExceeded error
        assert!(
            creation_attempt.is_err_and(|e| matches!(e, RegistrationError::MaxListenersExceeded))
        );
    }

    /// Verifies that a failure to send the request surfaces as the corresponding
    /// invocation error and that no pending request is left behind.
    #[tokio::test]
    async fn test_invoke_method_fails_with_transport_error() {
        // GIVEN an RPC client
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        // with a transport that fails with an error when invoking a method
        mock_transport
            .expect_do_send()
            .returning(|_request_message| {
                Err(UStatus::fail_with_code(
                    UCode::UNAVAILABLE,
                    "transport not available",
                ))
            });
        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        // WHEN invoking a remote service operation
        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        // THEN the invocation fails with the error caused at the Transport Layer
        assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::Unavailable(_msg))));
        assert!(!client.contains_pending_request(&message_id));
    }

    /// Verifies the happy path: the request carries the call options' attributes, the
    /// response payload is returned to the caller, and a duplicate response is ignored.
    #[tokio::test]
    async fn test_invoke_method_succeeds() {
        let message_id = UUID::build();
        let call_options = CallOptions::for_rpc_request(
            5_000,
            Some(message_id.clone()),
            Some("my_token".to_string()),
            Some(crate::UPriority::UPRIORITY_CS6),
        );

        // the listener registered by the client is captured so that the test can
        // play the role of the transport delivering the response
        let (captured_listener_tx, captured_listener_rx) = tokio::sync::oneshot::channel();
        let request_sent = Arc::new(Notify::new());
        let request_sent_clone = request_sent.clone();

        // GIVEN an RPC client
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .return_once(move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            });
        let expected_message_id = message_id.clone();
        mock_transport
            .expect_do_send()
            .once()
            .withf(move |request_message| {
                request_message.id_unchecked() == &expected_message_id
                    && request_message.priority_unchecked() == UPriority::UPRIORITY_CS6
                    && request_message.ttl_unchecked() == 5_000
                    && request_message.token() == Some(&String::from("my_token"))
            })
            .returning(move |_request_message| {
                request_sent_clone.notify_one();
                Ok(())
            });

        let uri_provider = new_uri_provider();
        let rpc_client = Arc::new(
            InMemoryRpcClient::new(Arc::new(mock_transport), uri_provider.clone())
                .await
                .unwrap(),
        );
        let client: Arc<dyn RpcClient> = rpc_client.clone();

        // WHEN invoking a remote service operation
        let response_handle = tokio::spawn(async move {
            let request_payload = StringValue {
                value: "World".to_string(),
                ..Default::default()
            };
            client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    call_options,
                    request_payload,
                )
                .await
        });

        // AND the remote service sends the corresponding RPC Response message
        let response_payload = StringValue {
            value: "Hello World".to_string(),
            ..Default::default()
        };
        let response_message = UMessageBuilder::response(
            uri_provider.get_source_uri(),
            message_id.clone(),
            service_method_uri(),
        )
        .build_with_protobuf_payload(&response_payload)
        .unwrap();

        // wait for the RPC Request message having been sent
        let (response_listener_result, _) = join!(captured_listener_rx, request_sent.notified());
        let response_listener = response_listener_result.unwrap();

        // send the RPC Response message which completes the request
        let cloned_response_message = response_message.clone();
        let cloned_response_listener = response_listener.clone();
        tokio::spawn(async move {
            cloned_response_listener
                .on_receive(cloned_response_message)
                .await
        });

        // THEN the response contains the expected payload
        let response = response_handle.await.unwrap();
        assert!(response.is_ok_and(|payload| payload.value == *"Hello World"));
        assert!(!rpc_client.contains_pending_request(&message_id));

        // AND if the remote service sends its response message again
        response_listener.on_receive(response_message).await;
        // the duplicate response is silently ignored
        assert!(!rpc_client.contains_pending_request(&message_id));
    }

    /// Verifies that a second invocation reusing the message ID of a still pending
    /// request is rejected, and that the first request stays pending.
    #[tokio::test]
    async fn test_invoke_method_fails_on_repeated_invocation() {
        let message_id = UUID::build();
        let first_request_sent = Arc::new(Notify::new());
        let first_request_sent_clone = first_request_sent.clone();

        // GIVEN an RPC client
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .return_const(Ok(()));
        let expected_message_id = message_id.clone();
        // only a single request is expected to reach the transport
        mock_transport
            .expect_do_send()
            .once()
            .withf(move |request_message| request_message.id_unchecked() == &expected_message_id)
            .returning(move |_request_message| {
                first_request_sent_clone.notify_one();
                Ok(())
            });

        let in_memory_rpc_client = Arc::new(
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
                .await
                .unwrap(),
        );
        let rpc_client: Arc<dyn RpcClient> = in_memory_rpc_client.clone();

        // WHEN invoking a remote service operation
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let cloned_call_options = call_options.clone();
        let cloned_rpc_client = rpc_client.clone();

        tokio::spawn(async move {
            let request_payload = StringValue {
                value: "World".to_string(),
                ..Default::default()
            };
            cloned_rpc_client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    cloned_call_options,
                    request_payload,
                )
                .await
        });

        // we wait for the first request message having been sent via the transport
        // in order to be sure that the pending request has been added to the client's
        // internal state
        first_request_sent.notified().await;

        // AND invoking the same operation before the response to the first request has arrived
        let request_payload = StringValue {
            value: "World".to_string(),
            ..Default::default()
        };
        let second_request_handle = tokio::spawn(async move {
            rpc_client
                .invoke_proto_method::<_, StringValue>(
                    service_method_uri(),
                    call_options,
                    request_payload,
                )
                .await
        });

        // THEN the second invocation fails
        let response = second_request_handle.await.unwrap();
        assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::AlreadyExists(_))));
        // because there is a pending request for the message ID used in both requests
        assert!(in_memory_rpc_client.contains_pending_request(&message_id));
    }

    /// Verifies that an error reported by the service (commstatus plus UStatus payload)
    /// is surfaced as the corresponding invocation error.
    #[tokio::test]
    async fn test_invoke_method_fails_with_remote_error() {
        // a std (blocking) channel is used because the listener is received from within
        // the synchronous mock closure below; the listener is sent during client creation,
        // i.e. before any request is sent
        let (captured_listener_tx, captured_listener_rx) = std::sync::mpsc::channel();

        // GIVEN an RPC client
        let mut mock_transport = MockTransport::default();
        mock_transport.expect_do_register_listener().returning(
            move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            },
        );
        // and a remote service operation that returns an error
        mock_transport
            .expect_do_send()
            .returning(move |request_message| {
                let error = UStatus::fail_with_code(UCode::NOT_FOUND, "no such object");
                let response_message = UMessageBuilder::response_for_request(
                    request_message.attributes.as_ref().unwrap(),
                )
                .with_comm_status(UCode::NOT_FOUND)
                .build_with_protobuf_payload(&error)
                .unwrap();
                let captured_listener = captured_listener_rx.recv().unwrap().to_owned();
                tokio::spawn(async move { captured_listener.on_receive(response_message).await });
                Ok(())
            });

        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        // WHEN invoking the remote service operation
        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        // THEN the invocation has failed with the error returned from the service
        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::NotFound(_msg)) }));
        assert!(!client.contains_pending_request(&message_id));
    }

    /// Verifies that an invocation without a response fails once its TTL (20 ms here)
    /// has expired, and that the pending request is cleaned up.
    #[tokio::test]
    async fn test_invoke_method_times_out() {
        // GIVEN an RPC client
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        // and a remote service operation that does not return a response
        mock_transport
            .expect_do_send()
            .returning(|_request_message| Ok(()));

        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();

        // WHEN invoking the remote service operation
        let message_id = UUID::build();
        let call_options = CallOptions::for_rpc_request(20, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;

        // THEN the invocation times out
        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::DeadlineExceeded) }));
        assert!(!client.contains_pending_request(&message_id));
    }

    /// Verifies that a response message without attributes is rejected as invalid.
    #[test]
    fn test_handle_response_message_fails_for_missing_attributes() {
        let response_msg = UMessage {
            ..Default::default()
        };
        let result = handle_response_message(response_msg);
        assert!(result.is_err_and(|e| matches!(e, ServiceInvocationError::InvalidArgument(_))));
    }
}