up_rust/communication/
in_memory_rpc_server.rs

1/********************************************************************************
2 * Copyright (c) 2024 Contributors to the Eclipse Foundation
3 *
4 * See the NOTICE file(s) distributed with this work for additional
5 * information regarding copyright ownership.
6 *
7 * This program and the accompanying materials are made available under the
8 * terms of the Apache License Version 2.0 which is available at
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * SPDX-License-Identifier: Apache-2.0
12 ********************************************************************************/
13
14// [impl->dsn~communication-layer-impl-default~1]
15
16use std::collections::hash_map::Entry;
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::time::Duration;
20
21use async_trait::async_trait;
22use tracing::{debug, info};
23
24use crate::{
25    communication::build_message, LocalUriProvider, UListener, UMessage, UMessageBuilder, UStatus,
26    UTransport, UUri,
27};
28
29use super::{RegistrationError, RequestHandler, RpcServer, ServiceInvocationError, UPayload};
30
31struct RequestListener {
32    request_handler: Arc<dyn RequestHandler>,
33    transport: Arc<dyn UTransport>,
34}
35
36impl RequestListener {
37    async fn process_valid_request(&self, resource_id: u16, request_message: UMessage) {
38        let transport_clone = self.transport.clone();
39        let request_handler_clone = self.request_handler.clone();
40        let mut response_builder =
41            UMessageBuilder::response_for_request(request_message.attributes_unchecked());
42
43        let request_message_id = request_message.id_unchecked().clone();
44        let request_timeout = request_message.ttl_unchecked();
45        let payload_format = request_message.payload_format().unwrap_or_default();
46        let payload = request_message.payload;
47        let request_payload = payload.map(|data| UPayload::new(data, payload_format));
48
49        debug!(ttl = request_timeout, id = %request_message_id, "processing RPC request");
50
51        let invocation_result_future = request_handler_clone.handle_request(
52            resource_id,
53            &request_message.attributes,
54            request_payload,
55        );
56        let outcome = tokio::time::timeout(
57            Duration::from_millis(request_timeout as u64),
58            invocation_result_future,
59        )
60        .await
61        .map_err(|_e| {
62            info!(ttl = request_timeout, "request handler timed out");
63            ServiceInvocationError::DeadlineExceeded
64        })
65        .and_then(|v| v);
66
67        let response = match outcome {
68            Ok(response_payload) => build_message(&mut response_builder, response_payload),
69            Err(e) => {
70                let error = UStatus::from(e);
71                response_builder
72                    .with_comm_status(error.get_code())
73                    .build_with_protobuf_payload(&error)
74            }
75        };
76
77        match response {
78            Ok(response_message) => {
79                if let Err(e) = transport_clone.send(response_message).await {
80                    info!(ucode = e.code.value(), "failed to send response message");
81                }
82            }
83            Err(e) => {
84                info!("failed to create response message: {}", e);
85            }
86        }
87    }
88}
89
90#[async_trait]
91impl UListener for RequestListener {
92    async fn on_receive(&self, msg: UMessage) {
93        if msg.is_request() {
94            // cannot fail because inbound messages are validated at the transport layer already
95            let method_id = msg.sink_unchecked().resource_id();
96            self.process_valid_request(method_id, msg).await;
97        } else {
98            debug!(
99                message_type = msg.type_unchecked().to_cloudevent_type(),
100                "ignoring non-request message received by RPC server"
101            );
102        }
103    }
104}
105
106/// An [`RpcServer`] which keeps all information about registered endpoints in memory.
107///
108/// The server requires an implementations of [`UTransport`] for receiving RPC Request messages
109/// from clients and sending back RPC Response messages.
110///
111/// For each [endpoint being registered](`Self::register_endpoint`), a [`UListener`] is created for
112/// the given request handler and registered with the underlying transport. The listener is also
113/// mapped to the endpoint's method resource ID in order to prevent registration of multiple
114/// request handlers for the same method.
115pub struct InMemoryRpcServer {
116    transport: Arc<dyn UTransport>,
117    uri_provider: Arc<dyn LocalUriProvider>,
118    request_listeners: tokio::sync::Mutex<HashMap<u16, Arc<dyn UListener>>>,
119}
120
121impl InMemoryRpcServer {
122    /// Creates a new RPC server for a given transport.
123    pub fn new(transport: Arc<dyn UTransport>, uri_provider: Arc<dyn LocalUriProvider>) -> Self {
124        InMemoryRpcServer {
125            transport,
126            uri_provider,
127            request_listeners: tokio::sync::Mutex::new(HashMap::new()),
128        }
129    }
130
131    fn validate_sink_filter(filter: &UUri) -> Result<(), RegistrationError> {
132        if !filter.is_rpc_method() {
133            return Err(RegistrationError::InvalidFilter(
134                "RPC endpoint's resource ID must be in range [0x0001, 0x7FFF]".to_string(),
135            ));
136        }
137        Ok(())
138    }
139
140    fn validate_origin_filter(filter: Option<&UUri>) -> Result<(), RegistrationError> {
141        if let Some(uri) = filter {
142            if !uri.is_rpc_response() {
143                return Err(RegistrationError::InvalidFilter(
144                    "origin filter's resource ID must be 0".to_string(),
145                ));
146            }
147        }
148        Ok(())
149    }
150
151    #[cfg(test)]
152    async fn contains_endpoint(&self, resource_id: u16) -> bool {
153        let listener_map = self.request_listeners.lock().await;
154        listener_map.contains_key(&resource_id)
155    }
156}
157
158#[async_trait]
159impl RpcServer for InMemoryRpcServer {
160    async fn register_endpoint(
161        &self,
162        origin_filter: Option<&UUri>,
163        resource_id: u16,
164        request_handler: Arc<dyn RequestHandler>,
165    ) -> Result<(), RegistrationError> {
166        Self::validate_origin_filter(origin_filter)?;
167        let sink_filter = self.uri_provider.get_resource_uri(resource_id);
168        Self::validate_sink_filter(&sink_filter)?;
169
170        let mut listener_map = self.request_listeners.lock().await;
171        if let Entry::Vacant(e) = listener_map.entry(resource_id) {
172            let listener = Arc::new(RequestListener {
173                request_handler,
174                transport: self.transport.clone(),
175            });
176            self.transport
177                .register_listener(
178                    origin_filter.unwrap_or(&UUri::any_with_resource_id(
179                        crate::uri::RESOURCE_ID_RESPONSE,
180                    )),
181                    Some(&sink_filter),
182                    listener.clone(),
183                )
184                .await
185                .map(|_| {
186                    e.insert(listener);
187                })
188                .map_err(RegistrationError::from)
189        } else {
190            Err(RegistrationError::MaxListenersExceeded)
191        }
192    }
193
194    async fn unregister_endpoint(
195        &self,
196        origin_filter: Option<&UUri>,
197        resource_id: u16,
198        _request_handler: Arc<dyn RequestHandler>,
199    ) -> Result<(), RegistrationError> {
200        Self::validate_origin_filter(origin_filter)?;
201        let sink_filter = self.uri_provider.get_resource_uri(resource_id);
202        Self::validate_sink_filter(&sink_filter)?;
203
204        let mut listener_map = self.request_listeners.lock().await;
205        if let Entry::Occupied(entry) = listener_map.entry(resource_id) {
206            let listener = entry.get().to_owned();
207            self.transport
208                .unregister_listener(
209                    origin_filter.unwrap_or(&UUri::any_with_resource_id(
210                        crate::uri::RESOURCE_ID_RESPONSE,
211                    )),
212                    Some(&sink_filter),
213                    listener,
214                )
215                .await
216                .map(|_| {
217                    entry.remove();
218                })
219                .map_err(RegistrationError::from)
220        } else {
221            Err(RegistrationError::NoSuchListener)
222        }
223    }
224}
225
226#[cfg(test)]
227mod tests {
228
229    // [utest->dsn~communication-layer-impl-default~1]
230
231    use super::*;
232
233    use protobuf::well_known_types::wrappers::StringValue;
234    use test_case::test_case;
235    use tokio::sync::Notify;
236
237    use crate::{
238        communication::rpc::MockRequestHandler, utransport::MockTransport, StaticUriProvider,
239        UAttributes, UCode, UUri, UUID,
240    };
241
242    fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
243        Arc::new(StaticUriProvider::new("", 0x0005, 0x02))
244    }
245
246    #[test_case(None, 0x4A10; "for empty origin filter")]
247    #[test_case(Some(UUri::try_from_parts("authority", 0xBF1A, 0x01, 0x0000).unwrap()), 0x4A10; "for specific origin filter")]
248    #[test_case(Some(UUri::try_from_parts("*", 0xFFFF, 0x01, 0x0000).unwrap()), 0x7091; "for wildcard origin filter")]
249    #[tokio::test]
250    async fn test_register_endpoint_succeeds(origin_filter: Option<UUri>, resource_id: u16) {
251        // GIVEN an RpcServer for a transport
252        let request_handler = Arc::new(MockRequestHandler::new());
253        let mut transport = MockTransport::new();
254        let uri_provider = new_uri_provider();
255        let expected_source_filter = origin_filter
256            .clone()
257            .unwrap_or(UUri::any_with_resource_id(0));
258        let param_check = move |source_filter: &UUri,
259                                sink_filter: &Option<&UUri>,
260                                _listener: &Arc<dyn UListener>| {
261            source_filter == &expected_source_filter
262                && sink_filter.is_some_and(|uri| uri.resource_id == resource_id as u32)
263        };
264        transport
265            .expect_do_register_listener()
266            .once()
267            .withf(param_check.clone())
268            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
269        transport
270            .expect_do_unregister_listener()
271            .once()
272            .withf(param_check)
273            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
274
275        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);
276
277        // WHEN registering a request handler
278        let register_result = rpc_server
279            .register_endpoint(origin_filter.as_ref(), resource_id, request_handler.clone())
280            .await;
281        // THEN registration succeeds
282        assert!(register_result.is_ok());
283        assert!(rpc_server.contains_endpoint(resource_id).await);
284
285        // and the handler can be unregistered again
286        let unregister_result = rpc_server
287            .unregister_endpoint(origin_filter.as_ref(), resource_id, request_handler)
288            .await;
289        assert!(unregister_result.is_ok());
290        assert!(!rpc_server.contains_endpoint(resource_id).await);
291    }
292
293    #[test_case(None, 0x0000; "for resource ID 0")]
294    #[test_case(None, 0x8000; "for resource ID out of range")]
295    #[test_case(Some(UUri::try_from_parts("*", 0xFFFF, 0xFF, 0x0001).unwrap()), 0x4A10; "for source filter with invalid resource ID")]
296    #[tokio::test]
297    async fn test_register_endpoint_fails(origin_filter: Option<UUri>, resource_id: u16) {
298        // GIVEN an RpcServer for a transport
299        let request_handler = Arc::new(MockRequestHandler::new());
300        let mut transport = MockTransport::new();
301        let uri_provider = new_uri_provider();
302        transport.expect_do_register_listener().never();
303        transport.expect_do_unregister_listener().never();
304
305        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);
306
307        // WHEN registering a request handler using invalid parameters
308        let register_result = rpc_server
309            .register_endpoint(origin_filter.as_ref(), resource_id, request_handler.clone())
310            .await;
311        // THEN registration fails
312        assert!(register_result.is_err_and(|e| matches!(e, RegistrationError::InvalidFilter(_v))));
313        assert!(!rpc_server.contains_endpoint(resource_id).await);
314
315        // and an attempt to unregister the handler using the same invalid parameters also fails with the same error
316        let unregister_result = rpc_server
317            .unregister_endpoint(origin_filter.as_ref(), resource_id, request_handler)
318            .await;
319        assert!(unregister_result.is_err_and(|e| matches!(e, RegistrationError::InvalidFilter(_v))));
320    }
321
322    #[tokio::test]
323    async fn test_register_endpoint_fails_for_duplicate_endpoint() {
324        // GIVEN an RpcServer for a transport
325        let request_handler = Arc::new(MockRequestHandler::new());
326        let mut transport = MockTransport::new();
327        let uri_provider = new_uri_provider();
328        transport
329            .expect_do_register_listener()
330            .once()
331            .return_const(Ok(()));
332
333        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);
334
335        // WHEN registering a request handler for an already existing endpoint
336        assert!(rpc_server
337            .register_endpoint(None, 0x5000, request_handler.clone())
338            .await
339            .is_ok());
340        let result = rpc_server
341            .register_endpoint(None, 0x5000, request_handler)
342            .await;
343
344        // THEN registration of the additional handler fails
345        assert!(result.is_err_and(|e| matches!(e, RegistrationError::MaxListenersExceeded)));
346        // but the original endpoint is still registered
347        assert!(rpc_server.contains_endpoint(0x5000).await);
348    }
349
350    #[tokio::test]
351    async fn test_unregister_endpoint_fails_for_non_existing_endpoint() {
352        // GIVEN an RpcServer for a transport
353        let request_handler = Arc::new(MockRequestHandler::new());
354        let mut transport = MockTransport::new();
355        let uri_provider = new_uri_provider();
356        transport.expect_do_unregister_listener().never();
357
358        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);
359
360        // WHEN trying to unregister a non existing endpoint
361        assert!(!rpc_server.contains_endpoint(0x5000).await);
362        let result = rpc_server
363            .unregister_endpoint(None, 0x5000, request_handler)
364            .await;
365
366        // THEN registration fails
367        assert!(result.is_err_and(|e| matches!(e, RegistrationError::NoSuchListener)));
368    }
369
370    #[tokio::test]
371    async fn test_request_listener_invokes_operation_successfully() {
372        let mut request_handler = MockRequestHandler::new();
373        let mut transport = MockTransport::new();
374        let notify = Arc::new(Notify::new());
375        let notify_clone = notify.clone();
376        let request_payload = StringValue {
377            value: "Hello".to_string(),
378            ..Default::default()
379        };
380        let message_id = UUID::build();
381        let message_id_clone = message_id.clone();
382        let message_source = UUri::try_from("up://localhost/A100/1/0").unwrap();
383        let message_source_clone = message_source.clone();
384
385        request_handler
386            .expect_handle_request()
387            .once()
388            .withf(move |resource_id, message_attributes, request_payload| {
389                if let Some(pl) = request_payload {
390                    let message_source = message_attributes.source.as_ref().unwrap();
391                    let msg: StringValue = pl.extract_protobuf().unwrap();
392                    msg.value == *"Hello"
393                        && *resource_id == 0x7000_u16
394                        && *message_source == message_source_clone
395                } else {
396                    false
397                }
398            })
399            .returning(|_resource_id, _message_attributes, _request_payload| {
400                let response_payload = UPayload::try_from_protobuf(StringValue {
401                    value: "Hello World".to_string(),
402                    ..Default::default()
403                })
404                .unwrap();
405                Ok(Some(response_payload))
406            });
407        transport
408            .expect_do_send()
409            .once()
410            .withf(move |response_message| {
411                let msg: StringValue = response_message.extract_protobuf().unwrap();
412                msg.value == *"Hello World"
413                    && response_message.is_response()
414                    && response_message
415                        .commstatus()
416                        .is_none_or(|code| code == UCode::OK)
417                    && response_message.request_id_unchecked() == &message_id_clone
418            })
419            .returning(move |_msg| {
420                notify_clone.notify_one();
421                Ok(())
422            });
423        let request_message = UMessageBuilder::request(
424            UUri::try_from("up://localhost/A200/1/7000").unwrap(),
425            message_source,
426            5_000,
427        )
428        .with_message_id(message_id)
429        .build_with_protobuf_payload(&request_payload)
430        .unwrap();
431
432        let request_listener = RequestListener {
433            request_handler: Arc::new(request_handler),
434            transport: Arc::new(transport),
435        };
436        request_listener.on_receive(request_message).await;
437        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
438        assert!(result.is_ok());
439    }
440
441    #[tokio::test]
442    async fn test_request_listener_invokes_operation_erroneously() {
443        let mut request_handler = MockRequestHandler::new();
444        let mut transport = MockTransport::new();
445        let notify = Arc::new(Notify::new());
446        let notify_clone = notify.clone();
447        let message_id = UUID::build();
448        let message_id_clone = message_id.clone();
449
450        request_handler
451            .expect_handle_request()
452            .once()
453            .withf(|resource_id, _message_attributes, _request_payload| *resource_id == 0x7000_u16)
454            .returning(|_resource_id, _message_attributes, _request_payload| {
455                Err(ServiceInvocationError::NotFound(
456                    "no such object".to_string(),
457                ))
458            });
459        transport
460            .expect_do_send()
461            .once()
462            .withf(move |response_message| {
463                let error: UStatus = response_message.extract_protobuf().unwrap();
464                error.get_code() == UCode::NOT_FOUND
465                    && response_message.is_response()
466                    && response_message.commstatus_unchecked() == error.get_code()
467                    && response_message.request_id_unchecked() == &message_id_clone
468            })
469            .returning(move |_msg| {
470                notify_clone.notify_one();
471                Ok(())
472            });
473        let request_message = UMessageBuilder::request(
474            UUri::try_from("up://localhost/A200/1/7000").unwrap(),
475            UUri::try_from("up://localhost/A100/1/0").unwrap(),
476            5_000,
477        )
478        .with_message_id(message_id)
479        .build()
480        .unwrap();
481
482        let request_listener = RequestListener {
483            request_handler: Arc::new(request_handler),
484            transport: Arc::new(transport),
485        };
486        request_listener.on_receive(request_message).await;
487        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
488        assert!(result.is_ok());
489    }
490
491    #[tokio::test]
492    async fn test_request_listener_times_out() {
493        // we need to manually implement the RequestHandler
494        // because from within the MockRequestHandler's expectation
495        // we cannot yield the current task (we can only use the blocking
496        // thread::sleep function)
497        struct NonRespondingHandler;
498        #[async_trait]
499        impl RequestHandler for NonRespondingHandler {
500            async fn handle_request(
501                &self,
502                resource_id: u16,
503                _message_attributes: &UAttributes,
504                _request_payload: Option<UPayload>,
505            ) -> Result<Option<UPayload>, ServiceInvocationError> {
506                assert_eq!(resource_id, 0x7000);
507                // this will yield the current task and allow the
508                // RequestListener to run into the timeout
509                tokio::time::sleep(Duration::from_millis(2000)).await;
510                Ok(None)
511            }
512        }
513
514        let request_handler = NonRespondingHandler {};
515        let mut transport = MockTransport::new();
516        let notify = Arc::new(Notify::new());
517        let notify_clone = notify.clone();
518        let message_id = UUID::build();
519        let message_id_clone = message_id.clone();
520
521        transport
522            .expect_do_send()
523            .once()
524            .withf(move |response_message| {
525                let error: UStatus = response_message.extract_protobuf().unwrap();
526                error.get_code() == UCode::DEADLINE_EXCEEDED
527                    && response_message.is_response()
528                    && response_message.commstatus_unchecked() == error.get_code()
529                    && response_message.request_id_unchecked() == &message_id_clone
530            })
531            .returning(move |_msg| {
532                notify_clone.notify_one();
533                Ok(())
534            });
535        let request_message = UMessageBuilder::request(
536            UUri::try_from("up://localhost/A200/1/7000").unwrap(),
537            UUri::try_from("up://localhost/A100/1/0").unwrap(),
538            // make sure this request times out very quickly
539            100,
540        )
541        .with_message_id(message_id)
542        .build()
543        .expect("should have been able to create RPC Request message");
544
545        let request_listener = RequestListener {
546            request_handler: Arc::new(request_handler),
547            transport: Arc::new(transport),
548        };
549        request_listener.on_receive(request_message).await;
550        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
551        assert!(result.is_ok());
552    }
553}