// up_rust/communication/in_memory_rpc_server.rs

/********************************************************************************
 * Copyright (c) 2024 Contributors to the Eclipse Foundation
 *
 * See the NOTICE file(s) distributed with this work for additional
 * information regarding copyright ownership.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * SPDX-License-Identifier: Apache-2.0
 ********************************************************************************/
13
// [impl->dsn~communication-layer-impl-default~1]
15
16use std::collections::hash_map::Entry;
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::time::Duration;
20
21use async_trait::async_trait;
22use tracing::{debug, info};
23
24use crate::{
25    communication::build_message, LocalUriProvider, UListener, UMessage, UMessageBuilder, UStatus,
26    UTransport, UUri,
27};
28
29use super::{RegistrationError, RequestHandler, RpcServer, ServiceInvocationError, UPayload};
30
31struct RequestListener {
32    request_handler: Arc<dyn RequestHandler>,
33    transport: Arc<dyn UTransport>,
34}
35
36impl RequestListener {
37    async fn process_valid_request(&self, resource_id: u16, request_message: UMessage) {
38        let transport_clone = self.transport.clone();
39        let request_handler_clone = self.request_handler.clone();
40        let mut response_builder =
41            UMessageBuilder::response_for_request(request_message.attributes_unchecked());
42
43        let request_message_id = request_message.id_unchecked().to_hyphenated_string();
44        let request_timeout = request_message.ttl_unchecked();
45        let payload_format = request_message.payload_format().unwrap_or_default();
46        let payload = request_message.payload;
47        let request_payload = payload.map(|data| UPayload::new(data, payload_format));
48
49        debug!(
50            ttl = request_timeout,
51            id = request_message_id,
52            resource_id = resource_id,
53            "processing RPC request"
54        );
55
56        let invocation_result_future = request_handler_clone.handle_request(
57            resource_id,
58            &request_message.attributes,
59            request_payload,
60        );
61        let outcome = tokio::time::timeout(
62            Duration::from_millis(request_timeout as u64),
63            invocation_result_future,
64        )
65        .await
66        .map_err(|_e| {
67            info!(ttl = request_timeout, "request handler timed out");
68            ServiceInvocationError::DeadlineExceeded
69        })
70        .and_then(|v| v);
71
72        let response = match outcome {
73            Ok(response_payload) => build_message(&mut response_builder, response_payload),
74            Err(e) => {
75                let error = UStatus::from(e);
76                response_builder
77                    .with_comm_status(error.get_code())
78                    .build_with_protobuf_payload(&error)
79            }
80        };
81
82        match response {
83            Ok(response_message) => {
84                if let Err(e) = transport_clone.send(response_message).await {
85                    info!(ucode = e.code.value(), "failed to send response message");
86                }
87            }
88            Err(e) => {
89                info!("failed to create response message: {}", e);
90            }
91        }
92    }
93}
94
95#[async_trait]
96impl UListener for RequestListener {
97    async fn on_receive(&self, msg: UMessage) {
98        if msg.is_request() {
99            // cannot fail because inbound messages are validated at the transport layer already
100            let method_id = msg.sink_unchecked().resource_id();
101            self.process_valid_request(method_id, msg).await;
102        } else {
103            debug!(
104                message_type = msg.type_unchecked().to_cloudevent_type(),
105                "ignoring non-request message received by RPC server"
106            );
107        }
108    }
109}
110
111/// An [`RpcServer`] which keeps all information about registered endpoints in memory.
112///
113/// The server requires an implementations of [`UTransport`] for receiving RPC Request messages
114/// from clients and sending back RPC Response messages.
115///
116/// For each [endpoint being registered](`Self::register_endpoint`), a [`UListener`] is created for
117/// the given request handler and registered with the underlying transport. The listener is also
118/// mapped to the endpoint's method resource ID in order to prevent registration of multiple
119/// request handlers for the same method.
120pub struct InMemoryRpcServer {
121    transport: Arc<dyn UTransport>,
122    uri_provider: Arc<dyn LocalUriProvider>,
123    request_listeners: tokio::sync::Mutex<HashMap<u16, Arc<dyn UListener>>>,
124}
125
126impl InMemoryRpcServer {
127    /// Creates a new RPC server for a given transport.
128    pub fn new(transport: Arc<dyn UTransport>, uri_provider: Arc<dyn LocalUriProvider>) -> Self {
129        InMemoryRpcServer {
130            transport,
131            uri_provider,
132            request_listeners: tokio::sync::Mutex::new(HashMap::new()),
133        }
134    }
135
136    fn validate_sink_filter(filter: &UUri) -> Result<(), RegistrationError> {
137        if !filter.is_rpc_method() {
138            return Err(RegistrationError::InvalidFilter(
139                "RPC endpoint's resource ID must be in range [0x0001, 0x7FFF]".to_string(),
140            ));
141        }
142        Ok(())
143    }
144
145    fn validate_origin_filter(filter: Option<&UUri>) -> Result<(), RegistrationError> {
146        if let Some(uri) = filter {
147            if !uri.is_rpc_response() {
148                return Err(RegistrationError::InvalidFilter(
149                    "origin filter's resource ID must be 0".to_string(),
150                ));
151            }
152        }
153        Ok(())
154    }
155
156    #[cfg(test)]
157    async fn contains_endpoint(&self, resource_id: u16) -> bool {
158        let listener_map = self.request_listeners.lock().await;
159        listener_map.contains_key(&resource_id)
160    }
161}
162
163#[async_trait]
164impl RpcServer for InMemoryRpcServer {
165    async fn register_endpoint(
166        &self,
167        origin_filter: Option<&UUri>,
168        resource_id: u16,
169        request_handler: Arc<dyn RequestHandler>,
170    ) -> Result<(), RegistrationError> {
171        Self::validate_origin_filter(origin_filter)?;
172        let sink_filter = self.uri_provider.get_resource_uri(resource_id);
173        Self::validate_sink_filter(&sink_filter)?;
174
175        let mut listener_map = self.request_listeners.lock().await;
176        if let Entry::Vacant(e) = listener_map.entry(resource_id) {
177            let listener = Arc::new(RequestListener {
178                request_handler,
179                transport: self.transport.clone(),
180            });
181            self.transport
182                .register_listener(
183                    origin_filter.unwrap_or(&UUri::any_with_resource_id(
184                        crate::uri::RESOURCE_ID_RESPONSE,
185                    )),
186                    Some(&sink_filter),
187                    listener.clone(),
188                )
189                .await
190                .map(|_| {
191                    e.insert(listener);
192                })
193                .map_err(RegistrationError::from)
194        } else {
195            Err(RegistrationError::MaxListenersExceeded)
196        }
197    }
198
199    async fn unregister_endpoint(
200        &self,
201        origin_filter: Option<&UUri>,
202        resource_id: u16,
203        _request_handler: Arc<dyn RequestHandler>,
204    ) -> Result<(), RegistrationError> {
205        Self::validate_origin_filter(origin_filter)?;
206        let sink_filter = self.uri_provider.get_resource_uri(resource_id);
207        Self::validate_sink_filter(&sink_filter)?;
208
209        let mut listener_map = self.request_listeners.lock().await;
210        if let Entry::Occupied(entry) = listener_map.entry(resource_id) {
211            let listener = entry.get().to_owned();
212            self.transport
213                .unregister_listener(
214                    origin_filter.unwrap_or(&UUri::any_with_resource_id(
215                        crate::uri::RESOURCE_ID_RESPONSE,
216                    )),
217                    Some(&sink_filter),
218                    listener,
219                )
220                .await
221                .map(|_| {
222                    entry.remove();
223                })
224                .map_err(RegistrationError::from)
225        } else {
226            Err(RegistrationError::NoSuchListener)
227        }
228    }
229}
230
/// Unit tests for endpoint (un)registration in [`InMemoryRpcServer`] and for request
/// processing in `RequestListener`, using mocked transports and request handlers.
#[cfg(test)]
mod tests {

    // [utest->dsn~communication-layer-impl-default~1]

    use super::*;

    use protobuf::well_known_types::wrappers::StringValue;
    use test_case::test_case;
    use tokio::sync::Notify;

    use crate::{
        communication::rpc::MockRequestHandler, utransport::MockTransport, StaticUriProvider,
        UAttributes, UCode, UUri, UUID,
    };

    /// Creates a URI provider for a local uEntity with an empty authority,
    /// entity ID 0x0005 and major version 0x02.
    fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
        Arc::new(StaticUriProvider::new("", 0x0005, 0x02))
    }

    // Verifies that valid endpoints can be registered and unregistered again. It also checks
    // that the transport receives the expected origin filter (defaulting to a wildcard URI
    // with resource ID 0) and a sink filter that carries the endpoint's resource ID.
    #[test_case(None, 0x4A10; "for empty origin filter")]
    #[test_case(Some(UUri::try_from_parts("authority", 0xBF1A, 0x01, 0x0000).unwrap()), 0x4A10; "for specific origin filter")]
    #[test_case(Some(UUri::try_from_parts("*", 0xFFFF, 0x01, 0x0000).unwrap()), 0x7091; "for wildcard origin filter")]
    #[tokio::test]
    async fn test_register_endpoint_succeeds(origin_filter: Option<UUri>, resource_id: u16) {
        // GIVEN an RpcServer for a transport
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        let expected_source_filter = origin_filter
            .clone()
            .unwrap_or(UUri::any_with_resource_id(0));
        let param_check = move |source_filter: &UUri,
                                sink_filter: &Option<&UUri>,
                                _listener: &Arc<dyn UListener>| {
            source_filter == &expected_source_filter
                && sink_filter.is_some_and(|uri| uri.resource_id == resource_id as u32)
        };
        transport
            .expect_do_register_listener()
            .once()
            .withf(param_check.clone())
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        transport
            .expect_do_unregister_listener()
            .once()
            .withf(param_check)
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        // WHEN registering a request handler
        let register_result = rpc_server
            .register_endpoint(origin_filter.as_ref(), resource_id, request_handler.clone())
            .await;
        // THEN registration succeeds
        assert!(register_result.is_ok());
        assert!(rpc_server.contains_endpoint(resource_id).await);

        // and the handler can be unregistered again
        let unregister_result = rpc_server
            .unregister_endpoint(origin_filter.as_ref(), resource_id, request_handler)
            .await;
        assert!(unregister_result.is_ok());
        assert!(!rpc_server.contains_endpoint(resource_id).await);
    }

    // Verifies that invalid resource IDs and origin filters are rejected before the transport
    // is contacted, for both registration and unregistration.
    #[test_case(None, 0x0000; "for resource ID 0")]
    #[test_case(None, 0x8000; "for resource ID out of range")]
    #[test_case(Some(UUri::try_from_parts("*", 0xFFFF, 0xFF, 0x0001).unwrap()), 0x4A10; "for source filter with invalid resource ID")]
    #[tokio::test]
    async fn test_register_endpoint_fails(origin_filter: Option<UUri>, resource_id: u16) {
        // GIVEN an RpcServer for a transport
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        transport.expect_do_register_listener().never();
        transport.expect_do_unregister_listener().never();

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        // WHEN registering a request handler using invalid parameters
        let register_result = rpc_server
            .register_endpoint(origin_filter.as_ref(), resource_id, request_handler.clone())
            .await;
        // THEN registration fails
        assert!(register_result.is_err_and(|e| matches!(e, RegistrationError::InvalidFilter(_v))));
        assert!(!rpc_server.contains_endpoint(resource_id).await);

        // and an attempt to unregister the handler using the same invalid parameters also fails with the same error
        let unregister_result = rpc_server
            .unregister_endpoint(origin_filter.as_ref(), resource_id, request_handler)
            .await;
        assert!(unregister_result.is_err_and(|e| matches!(e, RegistrationError::InvalidFilter(_v))));
    }

    // Verifies that a second handler cannot be registered for an already registered method,
    // and that the transport is contacted only once.
    #[tokio::test]
    async fn test_register_endpoint_fails_for_duplicate_endpoint() {
        // GIVEN an RpcServer for a transport
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        transport
            .expect_do_register_listener()
            .once()
            .return_const(Ok(()));

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        // WHEN registering a request handler for an already existing endpoint
        assert!(rpc_server
            .register_endpoint(None, 0x5000, request_handler.clone())
            .await
            .is_ok());
        let result = rpc_server
            .register_endpoint(None, 0x5000, request_handler)
            .await;

        // THEN registration of the additional handler fails
        assert!(result.is_err_and(|e| matches!(e, RegistrationError::MaxListenersExceeded)));
        // but the original endpoint is still registered
        assert!(rpc_server.contains_endpoint(0x5000).await);
    }

    // Verifies that unregistering an unknown endpoint fails without contacting the transport.
    #[tokio::test]
    async fn test_unregister_endpoint_fails_for_non_existing_endpoint() {
        // GIVEN an RpcServer for a transport
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        transport.expect_do_unregister_listener().never();

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        // WHEN trying to unregister a non existing endpoint
        assert!(!rpc_server.contains_endpoint(0x5000).await);
        let result = rpc_server
            .unregister_endpoint(None, 0x5000, request_handler)
            .await;

        // THEN unregistration fails
        assert!(result.is_err_and(|e| matches!(e, RegistrationError::NoSuchListener)));
    }

    // Verifies that a request is passed to the handler with its payload, resource ID and
    // source. Also checks that the handler's payload is sent back in a response that refers
    // to the request ID and carries an OK (or absent) commstatus.
    #[tokio::test]
    async fn test_request_listener_invokes_operation_successfully() {
        let mut request_handler = MockRequestHandler::new();
        let mut transport = MockTransport::new();
        // signalled once the response has been handed to the transport
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        let request_payload = StringValue {
            value: "Hello".to_string(),
            ..Default::default()
        };
        let message_id = UUID::build();
        let message_id_clone = message_id.clone();
        let message_source = UUri::try_from("up://localhost/A100/1/0").unwrap();
        let message_source_clone = message_source.clone();

        request_handler
            .expect_handle_request()
            .once()
            .withf(move |resource_id, message_attributes, request_payload| {
                if let Some(pl) = request_payload {
                    let message_source = message_attributes.source.as_ref().unwrap();
                    let msg: StringValue = pl.extract_protobuf().unwrap();
                    msg.value == *"Hello"
                        && *resource_id == 0x7000_u16
                        && *message_source == message_source_clone
                } else {
                    false
                }
            })
            .returning(|_resource_id, _message_attributes, _request_payload| {
                let response_payload = UPayload::try_from_protobuf(StringValue {
                    value: "Hello World".to_string(),
                    ..Default::default()
                })
                .unwrap();
                Ok(Some(response_payload))
            });
        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                let msg: StringValue = response_message.extract_protobuf().unwrap();
                msg.value == *"Hello World"
                    && response_message.is_response()
                    && response_message
                        .commstatus()
                        .is_none_or(|code| code == UCode::OK)
                    && response_message.request_id_unchecked() == &message_id_clone
            })
            .returning(move |_msg| {
                notify_clone.notify_one();
                Ok(())
            });
        let request_message = UMessageBuilder::request(
            UUri::try_from("up://localhost/A200/1/7000").unwrap(),
            message_source,
            5_000,
        )
        .with_message_id(message_id)
        .build_with_protobuf_payload(&request_payload)
        .unwrap();

        let request_listener = RequestListener {
            request_handler: Arc::new(request_handler),
            transport: Arc::new(transport),
        };
        request_listener.on_receive(request_message).await;
        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
        assert!(result.is_ok());
    }

    // Verifies that a handler error is sent back as a UStatus payload, with the commstatus
    // set to the same code (NOT_FOUND here).
    #[tokio::test]
    async fn test_request_listener_invokes_operation_erroneously() {
        let mut request_handler = MockRequestHandler::new();
        let mut transport = MockTransport::new();
        // signalled once the response has been handed to the transport
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        let message_id = UUID::build();
        let message_id_clone = message_id.clone();

        request_handler
            .expect_handle_request()
            .once()
            .withf(|resource_id, _message_attributes, _request_payload| *resource_id == 0x7000_u16)
            .returning(|_resource_id, _message_attributes, _request_payload| {
                Err(ServiceInvocationError::NotFound(
                    "no such object".to_string(),
                ))
            });
        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                let error: UStatus = response_message.extract_protobuf().unwrap();
                error.get_code() == UCode::NOT_FOUND
                    && response_message.is_response()
                    && response_message.commstatus_unchecked() == error.get_code()
                    && response_message.request_id_unchecked() == &message_id_clone
            })
            .returning(move |_msg| {
                notify_clone.notify_one();
                Ok(())
            });
        let request_message = UMessageBuilder::request(
            UUri::try_from("up://localhost/A200/1/7000").unwrap(),
            UUri::try_from("up://localhost/A100/1/0").unwrap(),
            5_000,
        )
        .with_message_id(message_id)
        .build()
        .unwrap();

        let request_listener = RequestListener {
            request_handler: Arc::new(request_handler),
            transport: Arc::new(transport),
        };
        request_listener.on_receive(request_message).await;
        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
        assert!(result.is_ok());
    }

    // Verifies that a handler which does not finish within the request's TTL results in a
    // DEADLINE_EXCEEDED response.
    #[tokio::test]
    async fn test_request_listener_times_out() {
        // we need to manually implement the RequestHandler
        // because from within the MockRequestHandler's expectation
        // we cannot yield the current task (we can only use the blocking
        // thread::sleep function)
        struct NonRespondingHandler;
        #[async_trait]
        impl RequestHandler for NonRespondingHandler {
            async fn handle_request(
                &self,
                resource_id: u16,
                _message_attributes: &UAttributes,
                _request_payload: Option<UPayload>,
            ) -> Result<Option<UPayload>, ServiceInvocationError> {
                assert_eq!(resource_id, 0x7000);
                // this will yield the current task and allow the
                // RequestListener to run into the timeout
                tokio::time::sleep(Duration::from_millis(2000)).await;
                Ok(None)
            }
        }

        let request_handler = NonRespondingHandler {};
        let mut transport = MockTransport::new();
        // signalled once the response has been handed to the transport
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        let message_id = UUID::build();
        let message_id_clone = message_id.clone();

        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                let error: UStatus = response_message.extract_protobuf().unwrap();
                error.get_code() == UCode::DEADLINE_EXCEEDED
                    && response_message.is_response()
                    && response_message.commstatus_unchecked() == error.get_code()
                    && response_message.request_id_unchecked() == &message_id_clone
            })
            .returning(move |_msg| {
                notify_clone.notify_one();
                Ok(())
            });
        let request_message = UMessageBuilder::request(
            UUri::try_from("up://localhost/A200/1/7000").unwrap(),
            UUri::try_from("up://localhost/A100/1/0").unwrap(),
            // make sure this request times out very quickly
            100,
        )
        .with_message_id(message_id)
        .build()
        .expect("should have been able to create RPC Request message");

        let request_listener = RequestListener {
            request_handler: Arc::new(request_handler),
            transport: Arc::new(transport),
        };
        request_listener.on_receive(request_message).await;
        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
        assert!(result.is_ok());
    }
}