1use std::collections::hash_map::Entry;
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::time::Duration;
20
21use async_trait::async_trait;
22use tracing::{debug, info};
23
24use crate::{
25 communication::build_message, LocalUriProvider, UListener, UMessage, UMessageBuilder, UStatus,
26 UTransport, UUri,
27};
28
29use super::{RegistrationError, RequestHandler, RpcServer, ServiceInvocationError, UPayload};
30
31struct RequestListener {
32 request_handler: Arc<dyn RequestHandler>,
33 transport: Arc<dyn UTransport>,
34}
35
36impl RequestListener {
37 async fn process_valid_request(&self, resource_id: u16, request_message: UMessage) {
38 let transport_clone = self.transport.clone();
39 let request_handler_clone = self.request_handler.clone();
40 let mut response_builder =
41 UMessageBuilder::response_for_request(request_message.attributes_unchecked());
42
43 let request_message_id = request_message.id_unchecked().clone();
44 let request_timeout = request_message.ttl_unchecked();
45 let payload_format = request_message.payload_format().unwrap_or_default();
46 let payload = request_message.payload;
47 let request_payload = payload.map(|data| UPayload::new(data, payload_format));
48
49 debug!(ttl = request_timeout, id = %request_message_id, "processing RPC request");
50
51 let invocation_result_future = request_handler_clone.handle_request(
52 resource_id,
53 &request_message.attributes,
54 request_payload,
55 );
56 let outcome = tokio::time::timeout(
57 Duration::from_millis(request_timeout as u64),
58 invocation_result_future,
59 )
60 .await
61 .map_err(|_e| {
62 info!(ttl = request_timeout, "request handler timed out");
63 ServiceInvocationError::DeadlineExceeded
64 })
65 .and_then(|v| v);
66
67 let response = match outcome {
68 Ok(response_payload) => build_message(&mut response_builder, response_payload),
69 Err(e) => {
70 let error = UStatus::from(e);
71 response_builder
72 .with_comm_status(error.get_code())
73 .build_with_protobuf_payload(&error)
74 }
75 };
76
77 match response {
78 Ok(response_message) => {
79 if let Err(e) = transport_clone.send(response_message).await {
80 info!(ucode = e.code.value(), "failed to send response message");
81 }
82 }
83 Err(e) => {
84 info!("failed to create response message: {}", e);
85 }
86 }
87 }
88}
89
90#[async_trait]
91impl UListener for RequestListener {
92 async fn on_receive(&self, msg: UMessage) {
93 if msg.is_request() {
94 let method_id = msg.sink_unchecked().resource_id();
96 self.process_valid_request(method_id, msg).await;
97 } else {
98 debug!(
99 message_type = msg.type_unchecked().to_cloudevent_type(),
100 "ignoring non-request message received by RPC server"
101 );
102 }
103 }
104}
105
106pub struct InMemoryRpcServer {
116 transport: Arc<dyn UTransport>,
117 uri_provider: Arc<dyn LocalUriProvider>,
118 request_listeners: tokio::sync::Mutex<HashMap<u16, Arc<dyn UListener>>>,
119}
120
121impl InMemoryRpcServer {
122 pub fn new(transport: Arc<dyn UTransport>, uri_provider: Arc<dyn LocalUriProvider>) -> Self {
124 InMemoryRpcServer {
125 transport,
126 uri_provider,
127 request_listeners: tokio::sync::Mutex::new(HashMap::new()),
128 }
129 }
130
131 fn validate_sink_filter(filter: &UUri) -> Result<(), RegistrationError> {
132 if !filter.is_rpc_method() {
133 return Err(RegistrationError::InvalidFilter(
134 "RPC endpoint's resource ID must be in range [0x0001, 0x7FFF]".to_string(),
135 ));
136 }
137 Ok(())
138 }
139
140 fn validate_origin_filter(filter: Option<&UUri>) -> Result<(), RegistrationError> {
141 if let Some(uri) = filter {
142 if !uri.is_rpc_response() {
143 return Err(RegistrationError::InvalidFilter(
144 "origin filter's resource ID must be 0".to_string(),
145 ));
146 }
147 }
148 Ok(())
149 }
150
151 #[cfg(test)]
152 async fn contains_endpoint(&self, resource_id: u16) -> bool {
153 let listener_map = self.request_listeners.lock().await;
154 listener_map.contains_key(&resource_id)
155 }
156}
157
158#[async_trait]
159impl RpcServer for InMemoryRpcServer {
160 async fn register_endpoint(
161 &self,
162 origin_filter: Option<&UUri>,
163 resource_id: u16,
164 request_handler: Arc<dyn RequestHandler>,
165 ) -> Result<(), RegistrationError> {
166 Self::validate_origin_filter(origin_filter)?;
167 let sink_filter = self.uri_provider.get_resource_uri(resource_id);
168 Self::validate_sink_filter(&sink_filter)?;
169
170 let mut listener_map = self.request_listeners.lock().await;
171 if let Entry::Vacant(e) = listener_map.entry(resource_id) {
172 let listener = Arc::new(RequestListener {
173 request_handler,
174 transport: self.transport.clone(),
175 });
176 self.transport
177 .register_listener(
178 origin_filter.unwrap_or(&UUri::any_with_resource_id(
179 crate::uri::RESOURCE_ID_RESPONSE,
180 )),
181 Some(&sink_filter),
182 listener.clone(),
183 )
184 .await
185 .map(|_| {
186 e.insert(listener);
187 })
188 .map_err(RegistrationError::from)
189 } else {
190 Err(RegistrationError::MaxListenersExceeded)
191 }
192 }
193
194 async fn unregister_endpoint(
195 &self,
196 origin_filter: Option<&UUri>,
197 resource_id: u16,
198 _request_handler: Arc<dyn RequestHandler>,
199 ) -> Result<(), RegistrationError> {
200 Self::validate_origin_filter(origin_filter)?;
201 let sink_filter = self.uri_provider.get_resource_uri(resource_id);
202 Self::validate_sink_filter(&sink_filter)?;
203
204 let mut listener_map = self.request_listeners.lock().await;
205 if let Entry::Occupied(entry) = listener_map.entry(resource_id) {
206 let listener = entry.get().to_owned();
207 self.transport
208 .unregister_listener(
209 origin_filter.unwrap_or(&UUri::any_with_resource_id(
210 crate::uri::RESOURCE_ID_RESPONSE,
211 )),
212 Some(&sink_filter),
213 listener,
214 )
215 .await
216 .map(|_| {
217 entry.remove();
218 })
219 .map_err(RegistrationError::from)
220 } else {
221 Err(RegistrationError::NoSuchListener)
222 }
223 }
224}
225
#[cfg(test)]
mod tests {

    use super::*;

    use protobuf::well_known_types::wrappers::StringValue;
    use test_case::test_case;
    use tokio::sync::Notify;

    use crate::{
        communication::rpc::MockRequestHandler, utransport::MockTransport, StaticUriProvider,
        UAttributes, UCode, UUri, UUID,
    };

    /// Creates the URI provider used by all server tests.
    ///
    /// NOTE(review): the arguments are presumably (authority, entity ID, major version);
    /// confirm against `StaticUriProvider::new`.
    fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
        Arc::new(StaticUriProvider::new("", 0x0005, 0x02))
    }

    /// Verifies a successful register/unregister round trip.
    ///
    /// Each step must pass the expected source and sink filters to the transport exactly once.
    #[test_case(None, 0x4A10; "for empty origin filter")]
    #[test_case(Some(UUri::try_from_parts("authority", 0xBF1A, 0x01, 0x0000).unwrap()), 0x4A10; "for specific origin filter")]
    #[test_case(Some(UUri::try_from_parts("*", 0xFFFF, 0x01, 0x0000).unwrap()), 0x7091; "for wildcard origin filter")]
    #[tokio::test]
    async fn test_register_endpoint_succeeds(origin_filter: Option<UUri>, resource_id: u16) {
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        // Without an explicit origin filter, the server is expected to fall back to the
        // `any` URI with resource ID 0.
        let expected_source_filter = origin_filter
            .clone()
            .unwrap_or(UUri::any_with_resource_id(0));
        // The same filter check is shared by the register and unregister expectations.
        let param_check = move |source_filter: &UUri,
                                sink_filter: &Option<&UUri>,
                                _listener: &Arc<dyn UListener>| {
            source_filter == &expected_source_filter
                && sink_filter.is_some_and(|uri| uri.resource_id == resource_id as u32)
        };
        transport
            .expect_do_register_listener()
            .once()
            .withf(param_check.clone())
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        transport
            .expect_do_unregister_listener()
            .once()
            .withf(param_check)
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        let register_result = rpc_server
            .register_endpoint(origin_filter.as_ref(), resource_id, request_handler.clone())
            .await;
        assert!(register_result.is_ok());
        assert!(rpc_server.contains_endpoint(resource_id).await);

        let unregister_result = rpc_server
            .unregister_endpoint(origin_filter.as_ref(), resource_id, request_handler)
            .await;
        assert!(unregister_result.is_ok());
        assert!(!rpc_server.contains_endpoint(resource_id).await);
    }

    /// Verifies that invalid origin filters and invalid resource IDs are rejected.
    ///
    /// Both register and unregister must return `InvalidFilter`, and the transport must never
    /// be called.
    #[test_case(None, 0x0000; "for resource ID 0")]
    #[test_case(None, 0x8000; "for resource ID out of range")]
    #[test_case(Some(UUri::try_from_parts("*", 0xFFFF, 0xFF, 0x0001).unwrap()), 0x4A10; "for source filter with invalid resource ID")]
    #[tokio::test]
    async fn test_register_endpoint_fails(origin_filter: Option<UUri>, resource_id: u16) {
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        transport.expect_do_register_listener().never();
        transport.expect_do_unregister_listener().never();

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        let register_result = rpc_server
            .register_endpoint(origin_filter.as_ref(), resource_id, request_handler.clone())
            .await;
        assert!(register_result.is_err_and(|e| matches!(e, RegistrationError::InvalidFilter(_v))));
        assert!(!rpc_server.contains_endpoint(resource_id).await);

        let unregister_result = rpc_server
            .unregister_endpoint(origin_filter.as_ref(), resource_id, request_handler)
            .await;
        assert!(unregister_result.is_err_and(|e| matches!(e, RegistrationError::InvalidFilter(_v))));
    }

    /// Verifies that a second registration for an already registered resource ID fails with
    /// `MaxListenersExceeded`.
    ///
    /// The transport must only be called for the first registration, and the first endpoint
    /// must remain registered.
    #[tokio::test]
    async fn test_register_endpoint_fails_for_duplicate_endpoint() {
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        transport
            .expect_do_register_listener()
            .once()
            .return_const(Ok(()));

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        assert!(rpc_server
            .register_endpoint(None, 0x5000, request_handler.clone())
            .await
            .is_ok());
        let result = rpc_server
            .register_endpoint(None, 0x5000, request_handler)
            .await;

        assert!(result.is_err_and(|e| matches!(e, RegistrationError::MaxListenersExceeded)));
        assert!(rpc_server.contains_endpoint(0x5000).await);
    }

    /// Verifies that unregistering an endpoint that was never registered fails with
    /// `NoSuchListener`, without calling the transport.
    #[tokio::test]
    async fn test_unregister_endpoint_fails_for_non_existing_endpoint() {
        let request_handler = Arc::new(MockRequestHandler::new());
        let mut transport = MockTransport::new();
        let uri_provider = new_uri_provider();
        transport.expect_do_unregister_listener().never();

        let rpc_server = InMemoryRpcServer::new(Arc::new(transport), uri_provider);

        assert!(!rpc_server.contains_endpoint(0x5000).await);
        let result = rpc_server
            .unregister_endpoint(None, 0x5000, request_handler)
            .await;

        assert!(result.is_err_and(|e| matches!(e, RegistrationError::NoSuchListener)));
    }

    /// Verifies that a request is passed to the handler and that the handler's response
    /// payload is sent back.
    ///
    /// The handler must receive the method ID, the request payload and the request's
    /// attributes. The response must refer to the request's message ID.
    #[tokio::test]
    async fn test_request_listener_invokes_operation_successfully() {
        let mut request_handler = MockRequestHandler::new();
        let mut transport = MockTransport::new();
        // Signalled by the transport mock once the response has been sent.
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        let request_payload = StringValue {
            value: "Hello".to_string(),
            ..Default::default()
        };
        let message_id = UUID::build();
        let message_id_clone = message_id.clone();
        let message_source = UUri::try_from("up://localhost/A100/1/0").unwrap();
        let message_source_clone = message_source.clone();

        request_handler
            .expect_handle_request()
            .once()
            .withf(move |resource_id, message_attributes, request_payload| {
                if let Some(pl) = request_payload {
                    let message_source = message_attributes.source.as_ref().unwrap();
                    let msg: StringValue = pl.extract_protobuf().unwrap();
                    msg.value == *"Hello"
                        && *resource_id == 0x7000_u16
                        && *message_source == message_source_clone
                } else {
                    false
                }
            })
            .returning(|_resource_id, _message_attributes, _request_payload| {
                let response_payload = UPayload::try_from_protobuf(StringValue {
                    value: "Hello World".to_string(),
                    ..Default::default()
                })
                .unwrap();
                Ok(Some(response_payload))
            });
        // The response must carry the handler's payload, be a response message, have no
        // commstatus or an OK one, and refer to the request's message ID.
        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                let msg: StringValue = response_message.extract_protobuf().unwrap();
                msg.value == *"Hello World"
                    && response_message.is_response()
                    && response_message
                        .commstatus()
                        .is_none_or(|code| code == UCode::OK)
                    && response_message.request_id_unchecked() == &message_id_clone
            })
            .returning(move |_msg| {
                notify_clone.notify_one();
                Ok(())
            });
        let request_message = UMessageBuilder::request(
            UUri::try_from("up://localhost/A200/1/7000").unwrap(),
            message_source,
            5_000,
        )
        .with_message_id(message_id)
        .build_with_protobuf_payload(&request_payload)
        .unwrap();

        let request_listener = RequestListener {
            request_handler: Arc::new(request_handler),
            transport: Arc::new(transport),
        };
        request_listener.on_receive(request_message).await;
        // Fail (instead of hanging) if the response is never sent.
        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
        assert!(result.is_ok());
    }

    /// Verifies that an error returned by the handler is sent back as an error response.
    ///
    /// The `UStatus` must be the payload, with a matching commstatus and the request's
    /// message ID.
    #[tokio::test]
    async fn test_request_listener_invokes_operation_erroneously() {
        let mut request_handler = MockRequestHandler::new();
        let mut transport = MockTransport::new();
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        let message_id = UUID::build();
        let message_id_clone = message_id.clone();

        request_handler
            .expect_handle_request()
            .once()
            .withf(|resource_id, _message_attributes, _request_payload| *resource_id == 0x7000_u16)
            .returning(|_resource_id, _message_attributes, _request_payload| {
                Err(ServiceInvocationError::NotFound(
                    "no such object".to_string(),
                ))
            });
        // `NotFound` is expected to surface as a `NOT_FOUND` status, both in the payload and
        // in the commstatus.
        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                let error: UStatus = response_message.extract_protobuf().unwrap();
                error.get_code() == UCode::NOT_FOUND
                    && response_message.is_response()
                    && response_message.commstatus_unchecked() == error.get_code()
                    && response_message.request_id_unchecked() == &message_id_clone
            })
            .returning(move |_msg| {
                notify_clone.notify_one();
                Ok(())
            });
        let request_message = UMessageBuilder::request(
            UUri::try_from("up://localhost/A200/1/7000").unwrap(),
            UUri::try_from("up://localhost/A100/1/0").unwrap(),
            5_000,
        )
        .with_message_id(message_id)
        .build()
        .unwrap();

        let request_listener = RequestListener {
            request_handler: Arc::new(request_handler),
            transport: Arc::new(transport),
        };
        request_listener.on_receive(request_message).await;
        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
        assert!(result.is_ok());
    }

    /// Verifies that a `DEADLINE_EXCEEDED` error response is sent when the handler does not
    /// complete within the request's TTL.
    #[tokio::test]
    async fn test_request_listener_times_out() {
        // A handler that needs 2000 ms to complete. That is far longer than the 100 ms TTL of
        // the request used below.
        struct NonRespondingHandler;
        #[async_trait]
        impl RequestHandler for NonRespondingHandler {
            async fn handle_request(
                &self,
                resource_id: u16,
                _message_attributes: &UAttributes,
                _request_payload: Option<UPayload>,
            ) -> Result<Option<UPayload>, ServiceInvocationError> {
                assert_eq!(resource_id, 0x7000);
                tokio::time::sleep(Duration::from_millis(2000)).await;
                Ok(None)
            }
        }

        let request_handler = NonRespondingHandler {};
        let mut transport = MockTransport::new();
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        let message_id = UUID::build();
        let message_id_clone = message_id.clone();

        transport
            .expect_do_send()
            .once()
            .withf(move |response_message| {
                let error: UStatus = response_message.extract_protobuf().unwrap();
                error.get_code() == UCode::DEADLINE_EXCEEDED
                    && response_message.is_response()
                    && response_message.commstatus_unchecked() == error.get_code()
                    && response_message.request_id_unchecked() == &message_id_clone
            })
            .returning(move |_msg| {
                notify_clone.notify_one();
                Ok(())
            });
        // TTL of 100 ms: the server's timeout must fire long before the handler completes.
        let request_message = UMessageBuilder::request(
            UUri::try_from("up://localhost/A200/1/7000").unwrap(),
            UUri::try_from("up://localhost/A100/1/0").unwrap(),
            100,
        )
        .with_message_id(message_id)
        .build()
        .expect("should have been able to create RPC Request message");

        let request_listener = RequestListener {
            request_handler: Arc::new(request_handler),
            transport: Arc::new(transport),
        };
        request_listener.on_receive(request_message).await;
        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
        assert!(result.is_ok());
    }
}