up_rust/communication/
usubscription_client.rs

/********************************************************************************
 * Copyright (c) 2024 Contributors to the Eclipse Foundation
 *
 * See the NOTICE file(s) distributed with this work for additional
 * information regarding copyright ownership.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * SPDX-License-Identifier: Apache-2.0
 ********************************************************************************/
13
14use std::sync::Arc;
15
16use async_trait::async_trait;
17
18use crate::{
19    core::usubscription::{
20        usubscription_uri, FetchSubscribersRequest, FetchSubscribersResponse,
21        FetchSubscriptionsRequest, FetchSubscriptionsResponse, NotificationsRequest,
22        NotificationsResponse, SubscriptionRequest, SubscriptionResponse, USubscription,
23        UnsubscribeRequest, UnsubscribeResponse, RESOURCE_ID_FETCH_SUBSCRIBERS,
24        RESOURCE_ID_FETCH_SUBSCRIPTIONS, RESOURCE_ID_REGISTER_FOR_NOTIFICATIONS,
25        RESOURCE_ID_SUBSCRIBE, RESOURCE_ID_UNREGISTER_FOR_NOTIFICATIONS, RESOURCE_ID_UNSUBSCRIBE,
26    },
27    UStatus,
28};
29
30use super::{CallOptions, RpcClient};
31
32/// A [`USubscription`] client implementation for invoking operations of a local USubscription service.
33///
34/// The client requires an [`RpcClient`] for performing the remote procedure calls.
35pub struct RpcClientUSubscription {
36    rpc_client: Arc<dyn RpcClient>,
37}
38
39impl RpcClientUSubscription {
40    /// Creates a new Notifier for a given transport.
41    ///
42    /// # Arguments
43    ///
44    /// * `rpc_client` - The client to use for performing the remote procedure calls on the USubscription service.
45    pub fn new(rpc_client: Arc<dyn RpcClient>) -> Self {
46        RpcClientUSubscription { rpc_client }
47    }
48
49    fn default_call_options() -> CallOptions {
50        CallOptions::for_rpc_request(5_000, None, None, None)
51    }
52}
53
54#[async_trait]
55impl USubscription for RpcClientUSubscription {
56    async fn subscribe(
57        &self,
58        subscription_request: SubscriptionRequest,
59    ) -> Result<SubscriptionResponse, UStatus> {
60        self.rpc_client
61            .invoke_proto_method::<_, SubscriptionResponse>(
62                usubscription_uri(RESOURCE_ID_SUBSCRIBE),
63                Self::default_call_options(),
64                subscription_request,
65            )
66            .await
67            .map_err(UStatus::from)
68    }
69
70    async fn unsubscribe(&self, unsubscribe_request: UnsubscribeRequest) -> Result<(), UStatus> {
71        self.rpc_client
72            .invoke_proto_method::<_, UnsubscribeResponse>(
73                usubscription_uri(RESOURCE_ID_UNSUBSCRIBE),
74                Self::default_call_options(),
75                unsubscribe_request,
76            )
77            .await
78            .map(|_response| ())
79            .map_err(UStatus::from)
80    }
81
82    async fn fetch_subscriptions(
83        &self,
84        fetch_subscriptions_request: FetchSubscriptionsRequest,
85    ) -> Result<FetchSubscriptionsResponse, UStatus> {
86        self.rpc_client
87            .invoke_proto_method::<_, FetchSubscriptionsResponse>(
88                usubscription_uri(RESOURCE_ID_FETCH_SUBSCRIPTIONS),
89                Self::default_call_options(),
90                fetch_subscriptions_request,
91            )
92            .await
93            .map_err(UStatus::from)
94    }
95
96    async fn register_for_notifications(
97        &self,
98        notifications_register_request: NotificationsRequest,
99    ) -> Result<(), UStatus> {
100        self.rpc_client
101            .invoke_proto_method::<_, NotificationsResponse>(
102                usubscription_uri(RESOURCE_ID_REGISTER_FOR_NOTIFICATIONS),
103                Self::default_call_options(),
104                notifications_register_request,
105            )
106            .await
107            .map(|_response| ())
108            .map_err(UStatus::from)
109    }
110
111    async fn unregister_for_notifications(
112        &self,
113        notifications_unregister_request: NotificationsRequest,
114    ) -> Result<(), UStatus> {
115        self.rpc_client
116            .invoke_proto_method::<_, NotificationsResponse>(
117                usubscription_uri(RESOURCE_ID_UNREGISTER_FOR_NOTIFICATIONS),
118                Self::default_call_options(),
119                notifications_unregister_request,
120            )
121            .await
122            .map(|_response| ())
123            .map_err(UStatus::from)
124    }
125
126    async fn fetch_subscribers(
127        &self,
128        fetch_subscribers_request: FetchSubscribersRequest,
129    ) -> Result<FetchSubscribersResponse, UStatus> {
130        self.rpc_client
131            .invoke_proto_method::<_, FetchSubscribersResponse>(
132                usubscription_uri(RESOURCE_ID_FETCH_SUBSCRIBERS),
133                Self::default_call_options(),
134                fetch_subscribers_request,
135            )
136            .await
137            .map_err(UStatus::from)
138    }
139}
140
#[cfg(test)]
mod tests {
    use mockall::Sequence;

    use super::*;
    use crate::{
        communication::{rpc::MockRpcClient, ServiceInvocationError, UPayload},
        core::usubscription::{Request, SubscriptionResponse},
        UCode, UUri,
    };
    use std::sync::Arc;

    // Every test follows the same script: the mocked RPC client first fails the
    // invocation with an internal error and then answers a second invocation of the
    // same method with a default response. The client under test is expected to
    // map the failure to `UCode::INTERNAL` and to succeed on the second attempt.

    /// Creates the topic that all test requests refer to.
    fn topic() -> UUri {
        UUri::try_from_parts("other", 0xd5a3, 0x01, 0xd3fe).unwrap()
    }

    /// Makes the mock fail exactly one invocation of `expected_method` with an internal error.
    ///
    /// The expectation is appended to `seq` and only matches invocations that carry a payload.
    fn expect_failing_invocation(
        rpc_client: &mut MockRpcClient,
        seq: &mut Sequence,
        expected_method: UUri,
    ) {
        rpc_client
            .expect_invoke_method()
            .once()
            .in_sequence(seq)
            .withf(move |method, _options, payload| {
                method == &expected_method && payload.is_some()
            })
            .return_const(Err(ServiceInvocationError::Internal(
                "internal error".to_string(),
            )));
    }

    #[tokio::test]
    async fn test_subscribe_invokes_rpc_client() {
        let request = SubscriptionRequest {
            topic: Some(topic()).into(),
            ..Default::default()
        };
        let expected_request = request.clone();
        let mut rpc_client = MockRpcClient::new();
        let mut seq = Sequence::new();
        expect_failing_invocation(
            &mut rpc_client,
            &mut seq,
            usubscription_uri(RESOURCE_ID_SUBSCRIBE),
        );
        rpc_client
            .expect_invoke_method()
            .once()
            .in_sequence(&mut seq)
            .withf(move |method, _options, payload| {
                // A missing or undecodable payload yields `false` so that mockall
                // reports the mismatch instead of the predicate panicking.
                let actual_request = payload
                    .clone()
                    .and_then(|p| p.extract_protobuf::<SubscriptionRequest>().ok());
                method == &usubscription_uri(RESOURCE_ID_SUBSCRIBE)
                    && actual_request.as_ref() == Some(&expected_request)
            })
            .returning(|_method, _options, _payload| {
                let response = SubscriptionResponse::default();
                Ok(Some(UPayload::try_from_protobuf(response).unwrap()))
            });

        let usubscription_client = RpcClientUSubscription::new(Arc::new(rpc_client));

        assert!(usubscription_client
            .subscribe(request.clone())
            .await
            .is_err_and(|e| e.get_code() == UCode::INTERNAL));
        assert!(usubscription_client.subscribe(request).await.is_ok());
    }

    #[tokio::test]
    async fn test_unsubscribe_invokes_rpc_client() {
        let request = UnsubscribeRequest {
            topic: Some(topic()).into(),
            ..Default::default()
        };
        let expected_request = request.clone();
        let mut rpc_client = MockRpcClient::new();
        let mut seq = Sequence::new();
        expect_failing_invocation(
            &mut rpc_client,
            &mut seq,
            usubscription_uri(RESOURCE_ID_UNSUBSCRIBE),
        );
        rpc_client
            .expect_invoke_method()
            .once()
            .in_sequence(&mut seq)
            .withf(move |method, _options, payload| {
                // See `test_subscribe_invokes_rpc_client` for why this does not unwrap.
                let actual_request = payload
                    .clone()
                    .and_then(|p| p.extract_protobuf::<UnsubscribeRequest>().ok());
                method == &usubscription_uri(RESOURCE_ID_UNSUBSCRIBE)
                    && actual_request.as_ref() == Some(&expected_request)
            })
            .returning(|_method, _options, _payload| {
                let response = UnsubscribeResponse::default();
                Ok(Some(UPayload::try_from_protobuf(response).unwrap()))
            });

        let usubscription_client = RpcClientUSubscription::new(Arc::new(rpc_client));

        assert!(usubscription_client
            .unsubscribe(request.clone())
            .await
            .is_err_and(|e| e.get_code() == UCode::INTERNAL));
        assert!(usubscription_client.unsubscribe(request).await.is_ok());
    }

    #[tokio::test]
    async fn test_fetch_subscriptions_invokes_rpc_client() {
        let request = FetchSubscriptionsRequest {
            request: Some(Request::Topic(topic())),
            ..Default::default()
        };
        let expected_request = request.clone();
        let mut rpc_client = MockRpcClient::new();
        let mut seq = Sequence::new();
        expect_failing_invocation(
            &mut rpc_client,
            &mut seq,
            usubscription_uri(RESOURCE_ID_FETCH_SUBSCRIPTIONS),
        );
        rpc_client
            .expect_invoke_method()
            .once()
            .in_sequence(&mut seq)
            .withf(move |method, _options, payload| {
                // See `test_subscribe_invokes_rpc_client` for why this does not unwrap.
                let actual_request = payload
                    .clone()
                    .and_then(|p| p.extract_protobuf::<FetchSubscriptionsRequest>().ok());
                method == &usubscription_uri(RESOURCE_ID_FETCH_SUBSCRIPTIONS)
                    && actual_request.as_ref() == Some(&expected_request)
            })
            .returning(|_method, _options, _payload| {
                let response = FetchSubscriptionsResponse::default();
                Ok(Some(UPayload::try_from_protobuf(response).unwrap()))
            });

        let usubscription_client = RpcClientUSubscription::new(Arc::new(rpc_client));

        assert!(usubscription_client
            .fetch_subscriptions(request.clone())
            .await
            .is_err_and(|e| e.get_code() == UCode::INTERNAL));
        assert!(usubscription_client
            .fetch_subscriptions(request)
            .await
            .is_ok());
    }

    #[tokio::test]
    async fn test_fetch_subscribers_invokes_rpc_client() {
        let request = FetchSubscribersRequest {
            topic: Some(topic()).into(),
            ..Default::default()
        };
        let expected_request = request.clone();
        let mut rpc_client = MockRpcClient::new();
        let mut seq = Sequence::new();
        expect_failing_invocation(
            &mut rpc_client,
            &mut seq,
            usubscription_uri(RESOURCE_ID_FETCH_SUBSCRIBERS),
        );
        rpc_client
            .expect_invoke_method()
            .once()
            .in_sequence(&mut seq)
            .withf(move |method, _options, payload| {
                // See `test_subscribe_invokes_rpc_client` for why this does not unwrap.
                let actual_request = payload
                    .clone()
                    .and_then(|p| p.extract_protobuf::<FetchSubscribersRequest>().ok());
                method == &usubscription_uri(RESOURCE_ID_FETCH_SUBSCRIBERS)
                    && actual_request.as_ref() == Some(&expected_request)
            })
            .returning(|_method, _options, _payload| {
                let response = FetchSubscribersResponse::default();
                Ok(Some(UPayload::try_from_protobuf(response).unwrap()))
            });

        let usubscription_client = RpcClientUSubscription::new(Arc::new(rpc_client));

        assert!(usubscription_client
            .fetch_subscribers(request.clone())
            .await
            .is_err_and(|e| e.get_code() == UCode::INTERNAL));
        assert!(usubscription_client
            .fetch_subscribers(request)
            .await
            .is_ok());
    }

    #[tokio::test]
    async fn test_register_for_notifications_invokes_rpc_client() {
        let request = NotificationsRequest {
            topic: Some(topic()).into(),
            ..Default::default()
        };
        let expected_request = request.clone();
        let mut rpc_client = MockRpcClient::new();
        let mut seq = Sequence::new();
        expect_failing_invocation(
            &mut rpc_client,
            &mut seq,
            usubscription_uri(RESOURCE_ID_REGISTER_FOR_NOTIFICATIONS),
        );
        rpc_client
            .expect_invoke_method()
            .once()
            .in_sequence(&mut seq)
            .withf(move |method, _options, payload| {
                // See `test_subscribe_invokes_rpc_client` for why this does not unwrap.
                let actual_request = payload
                    .clone()
                    .and_then(|p| p.extract_protobuf::<NotificationsRequest>().ok());
                method == &usubscription_uri(RESOURCE_ID_REGISTER_FOR_NOTIFICATIONS)
                    && actual_request.as_ref() == Some(&expected_request)
            })
            .returning(|_method, _options, _payload| {
                let response = NotificationsResponse::default();
                Ok(Some(UPayload::try_from_protobuf(response).unwrap()))
            });

        let usubscription_client = RpcClientUSubscription::new(Arc::new(rpc_client));

        assert!(usubscription_client
            .register_for_notifications(request.clone())
            .await
            .is_err_and(|e| e.get_code() == UCode::INTERNAL));
        assert!(usubscription_client
            .register_for_notifications(request)
            .await
            .is_ok());
    }

    #[tokio::test]
    async fn test_unregister_for_notifications_invokes_rpc_client() {
        let request = NotificationsRequest {
            topic: Some(topic()).into(),
            ..Default::default()
        };
        let expected_request = request.clone();
        let mut rpc_client = MockRpcClient::new();
        let mut seq = Sequence::new();
        expect_failing_invocation(
            &mut rpc_client,
            &mut seq,
            usubscription_uri(RESOURCE_ID_UNREGISTER_FOR_NOTIFICATIONS),
        );
        rpc_client
            .expect_invoke_method()
            .once()
            .in_sequence(&mut seq)
            .withf(move |method, _options, payload| {
                // See `test_subscribe_invokes_rpc_client` for why this does not unwrap.
                let actual_request = payload
                    .clone()
                    .and_then(|p| p.extract_protobuf::<NotificationsRequest>().ok());
                method == &usubscription_uri(RESOURCE_ID_UNREGISTER_FOR_NOTIFICATIONS)
                    && actual_request.as_ref() == Some(&expected_request)
            })
            .returning(|_method, _options, _payload| {
                let response = NotificationsResponse::default();
                Ok(Some(UPayload::try_from_protobuf(response).unwrap()))
            });

        let usubscription_client = RpcClientUSubscription::new(Arc::new(rpc_client));

        assert!(usubscription_client
            .unregister_for_notifications(request.clone())
            .await
            .is_err_and(|e| e.get_code() == UCode::INTERNAL));
        assert!(usubscription_client
            .unregister_for_notifications(request)
            .await
            .is_ok());
    }
}