up_rust/communication/udiscovery_client.rs

/********************************************************************************
 * Copyright (c) 2024 Contributors to the Eclipse Foundation
 *
 * See the NOTICE file(s) distributed with this work for additional
 * information regarding copyright ownership.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * SPDX-License-Identifier: Apache-2.0
 ********************************************************************************/
13use std::sync::Arc;
14
15use async_trait::async_trait;
16
17use crate::{
18    core::udiscovery::{
19        udiscovery_uri, FindServicesRequest, FindServicesResponse, GetServiceTopicsRequest,
20        GetServiceTopicsResponse, ServiceTopicInfo, UDiscovery, RESOURCE_ID_FIND_SERVICES,
21        RESOURCE_ID_GET_SERVICE_TOPICS,
22    },
23    UStatus, UUri,
24};
25
26use super::{CallOptions, RpcClient};
27
28/// A [`UDiscovery`] client implementation for invoking operations of a local uDiscovery service.
29///
30/// The client requires an [`RpcClient`] for performing the remote procedure calls.
31pub struct RpcClientUDiscovery {
32    rpc_client: Arc<dyn RpcClient>,
33}
34
35impl RpcClientUDiscovery {
36    /// Creates a new uDiscovery client for a given transport.
37    ///
38    /// # Arguments
39    ///
40    /// * `rpc_client` - The client to use for performing the remote procedure calls on the service.
41    pub fn new(rpc_client: Arc<dyn RpcClient>) -> Self {
42        RpcClientUDiscovery { rpc_client }
43    }
44
45    fn default_call_options() -> CallOptions {
46        CallOptions::for_rpc_request(5_000, None, None, None)
47    }
48}
49
50#[async_trait]
51impl UDiscovery for RpcClientUDiscovery {
52    async fn find_services(
53        &self,
54        uri_pattern: UUri,
55        recursive: bool,
56    ) -> Result<Vec<UUri>, UStatus> {
57        let request_message = FindServicesRequest {
58            uri: Some(uri_pattern).into(),
59            recursive,
60            ..Default::default()
61        };
62        self.rpc_client
63            .invoke_proto_method::<_, FindServicesResponse>(
64                udiscovery_uri(RESOURCE_ID_FIND_SERVICES),
65                Self::default_call_options(),
66                request_message,
67            )
68            .await
69            .map(|response_message| {
70                response_message
71                    .uris
72                    .as_ref()
73                    .map_or(vec![], |batch| batch.uris.to_owned())
74            })
75            .map_err(UStatus::from)
76    }
77
78    async fn get_service_topics(
79        &self,
80        topic_pattern: UUri,
81        recursive: bool,
82    ) -> Result<Vec<ServiceTopicInfo>, UStatus> {
83        let request_message = GetServiceTopicsRequest {
84            topic: Some(topic_pattern).into(),
85            recursive,
86            ..Default::default()
87        };
88        self.rpc_client
89            .invoke_proto_method::<_, GetServiceTopicsResponse>(
90                udiscovery_uri(RESOURCE_ID_GET_SERVICE_TOPICS),
91                Self::default_call_options(),
92                request_message,
93            )
94            .await
95            .map(|response_message| response_message.topics.to_owned())
96            .map_err(UStatus::from)
97    }
98}
99
#[cfg(test)]
mod tests {
    use mockall::Sequence;

    use super::*;
    use crate::{
        communication::{rpc::MockRpcClient, UPayload},
        up_core_api::uri::UUriBatch,
        UCode, UUri,
    };
    use std::sync::Arc;

    #[tokio::test]
    async fn test_find_services_invokes_rpc_client() {
        let service_pattern_uri = UUri::try_from_parts("other", 0xFFFF_D5A3, 0x01, 0xFFFF).unwrap();
        // Use a non-default value so that the test actually verifies that the flag
        // is forwarded in the request sent to the service.
        let recursive = true;
        let expected_request = FindServicesRequest {
            uri: Some(service_pattern_uri.clone()).into(),
            recursive,
            ..Default::default()
        };
        let mut rpc_client = MockRpcClient::new();
        let mut seq = Sequence::new();
        rpc_client
            .expect_invoke_method()
            .once()
            .in_sequence(&mut seq)
            .withf(|method, _options, payload| {
                method == &udiscovery_uri(RESOURCE_ID_FIND_SERVICES) && payload.is_some()
            })
            .return_const(Err(crate::communication::ServiceInvocationError::Internal(
                "internal error".to_string(),
            )));
        rpc_client
            .expect_invoke_method()
            .once()
            .in_sequence(&mut seq)
            .withf(move |method, _options, payload| {
                // A missing or undecodable payload makes the matcher fail instead of panicking.
                method == &udiscovery_uri(RESOURCE_ID_FIND_SERVICES)
                    && payload
                        .to_owned()
                        .and_then(|p| p.extract_protobuf::<FindServicesRequest>().ok())
                        .is_some_and(|request| request == expected_request)
            })
            .returning(move |_method, _options, _payload| {
                let response = FindServicesResponse {
                    uris: Some(UUriBatch {
                        uris: vec![UUri::try_from_parts("other", 0x0004_D5A3, 0x01, 0xD3FE)
                            .expect("failed to create query result")],
                        ..Default::default()
                    })
                    .into(),
                    ..Default::default()
                };
                Ok(Some(UPayload::try_from_protobuf(response).unwrap()))
            });

        let udiscovery_client = RpcClientUDiscovery::new(Arc::new(rpc_client));

        assert!(udiscovery_client
            .find_services(service_pattern_uri.clone(), recursive)
            .await
            .is_err_and(|e| e.get_code() == UCode::INTERNAL));
        assert!(udiscovery_client
            .find_services(service_pattern_uri.clone(), recursive)
            .await
            .is_ok_and(|result| result.len() == 1 && service_pattern_uri.matches(&result[0])));
    }

    #[tokio::test]
    async fn test_get_service_topics_invokes_rpc_client() {
        let topic_pattern_uri = UUri::try_from_parts("*", 0xFFFF_D5A3, 0x01, 0xFFFF).unwrap();
        // Use a non-default value so that the test actually verifies that the flag
        // is forwarded in the request sent to the service.
        let recursive = true;
        let expected_request = GetServiceTopicsRequest {
            topic: Some(topic_pattern_uri.clone()).into(),
            recursive,
            ..Default::default()
        };
        let mut rpc_client = MockRpcClient::new();
        let mut seq = Sequence::new();
        rpc_client
            .expect_invoke_method()
            .once()
            .in_sequence(&mut seq)
            .withf(|method, _options, payload| {
                method == &udiscovery_uri(RESOURCE_ID_GET_SERVICE_TOPICS) && payload.is_some()
            })
            .return_const(Err(crate::communication::ServiceInvocationError::Internal(
                "internal error".to_string(),
            )));
        rpc_client
            .expect_invoke_method()
            .once()
            .in_sequence(&mut seq)
            .withf(move |method, _options, payload| {
                // A missing or undecodable payload makes the matcher fail instead of panicking.
                method == &udiscovery_uri(RESOURCE_ID_GET_SERVICE_TOPICS)
                    && payload
                        .to_owned()
                        .and_then(|p| p.extract_protobuf::<GetServiceTopicsRequest>().ok())
                        .is_some_and(|request| request == expected_request)
            })
            .returning(move |_method, _options, _payload| {
                let topic_info = ServiceTopicInfo {
                    topic: Some(
                        UUri::try_from_parts("other", 0x0004_D5A3, 0x01, 0xD3FE)
                            .expect("failed to create query result"),
                    )
                    .into(),
                    ttl: 600,
                    info: None.into(),
                    ..Default::default()
                };
                let response = GetServiceTopicsResponse {
                    topics: vec![topic_info],
                    ..Default::default()
                };
                Ok(Some(UPayload::try_from_protobuf(response).unwrap()))
            });

        let udiscovery_client = RpcClientUDiscovery::new(Arc::new(rpc_client));

        assert!(udiscovery_client
            .get_service_topics(topic_pattern_uri.clone(), recursive)
            .await
            .is_err_and(|e| e.get_code() == UCode::INTERNAL));
        assert!(udiscovery_client
            .get_service_topics(topic_pattern_uri.clone(), recursive)
            .await
            .is_ok_and(|result| result.len() == 1
                && topic_pattern_uri.matches(result[0].topic.as_ref().unwrap())));
    }
}