1use std::sync::Arc;
15use thiserror::Error;
16
17use async_trait::async_trait;
18use protobuf::MessageFull;
19
20use crate::communication::RegistrationError;
21use crate::{UAttributes, UCode, UStatus, UUri};
22
23use super::{CallOptions, UPayload};
24
25#[derive(Clone, Error, Debug)]
28pub enum ServiceInvocationError {
29 #[error("entity already exists: {0}")]
31 AlreadyExists(String),
32 #[error("request timed out")]
37 DeadlineExceeded,
38 #[error("failed precondition: {0}")]
40 FailedPrecondition(String),
41 #[error("internal error: {0}")]
43 Internal(String),
44 #[error("invalid argument: {0}")]
46 InvalidArgument(String),
47 #[error("no such entity: {0}")]
49 NotFound(String),
50 #[error("permission denied: {0}")]
52 PermissionDenied(String),
53 #[error("resource exhausted: {0}")]
55 ResourceExhausted(String),
56 #[error("unknown error: {0}")]
58 RpcError(UStatus),
59 #[error("unauthenticated")]
61 Unauthenticated,
62 #[error("resource unavailable: {0}")]
64 Unavailable(String),
65 #[error("unimplemented: {0}")]
67 Unimplemented(String),
68}
69
70impl From<UStatus> for ServiceInvocationError {
71 fn from(value: UStatus) -> Self {
72 match value.code.enum_value() {
73 Ok(UCode::ALREADY_EXISTS) => ServiceInvocationError::AlreadyExists(value.get_message()),
74 Ok(UCode::DEADLINE_EXCEEDED) => ServiceInvocationError::DeadlineExceeded,
75 Ok(UCode::FAILED_PRECONDITION) => {
76 ServiceInvocationError::FailedPrecondition(value.get_message())
77 }
78 Ok(UCode::INTERNAL) => ServiceInvocationError::Internal(value.get_message()),
79 Ok(UCode::INVALID_ARGUMENT) => {
80 ServiceInvocationError::InvalidArgument(value.get_message())
81 }
82 Ok(UCode::NOT_FOUND) => ServiceInvocationError::NotFound(value.get_message()),
83 Ok(UCode::PERMISSION_DENIED) => {
84 ServiceInvocationError::PermissionDenied(value.get_message())
85 }
86 Ok(UCode::RESOURCE_EXHAUSTED) => {
87 ServiceInvocationError::ResourceExhausted(value.get_message())
88 }
89 Ok(UCode::UNAUTHENTICATED) => ServiceInvocationError::Unauthenticated,
90 Ok(UCode::UNAVAILABLE) => ServiceInvocationError::Unavailable(value.get_message()),
91 Ok(UCode::UNIMPLEMENTED) => ServiceInvocationError::Unimplemented(value.get_message()),
92 _ => ServiceInvocationError::RpcError(value),
93 }
94 }
95}
96
97impl From<ServiceInvocationError> for UStatus {
98 fn from(value: ServiceInvocationError) -> Self {
99 match value {
100 ServiceInvocationError::AlreadyExists(msg) => {
101 UStatus::fail_with_code(UCode::ALREADY_EXISTS, msg)
102 }
103 ServiceInvocationError::DeadlineExceeded => {
104 UStatus::fail_with_code(UCode::DEADLINE_EXCEEDED, "request timed out")
105 }
106 ServiceInvocationError::FailedPrecondition(msg) => {
107 UStatus::fail_with_code(UCode::FAILED_PRECONDITION, msg)
108 }
109 ServiceInvocationError::Internal(msg) => UStatus::fail_with_code(UCode::INTERNAL, msg),
110 ServiceInvocationError::InvalidArgument(msg) => {
111 UStatus::fail_with_code(UCode::INVALID_ARGUMENT, msg)
112 }
113 ServiceInvocationError::NotFound(msg) => UStatus::fail_with_code(UCode::NOT_FOUND, msg),
114 ServiceInvocationError::PermissionDenied(msg) => {
115 UStatus::fail_with_code(UCode::PERMISSION_DENIED, msg)
116 }
117 ServiceInvocationError::ResourceExhausted(msg) => {
118 UStatus::fail_with_code(UCode::RESOURCE_EXHAUSTED, msg)
119 }
120 ServiceInvocationError::Unauthenticated => {
121 UStatus::fail_with_code(UCode::UNAUTHENTICATED, "client must authenticate")
122 }
123 ServiceInvocationError::Unavailable(msg) => {
124 UStatus::fail_with_code(UCode::UNAVAILABLE, msg)
125 }
126 ServiceInvocationError::Unimplemented(msg) => {
127 UStatus::fail_with_code(UCode::UNIMPLEMENTED, msg)
128 }
129 _ => UStatus::fail_with_code(UCode::UNKNOWN, "unknown"),
130 }
131 }
132}
133
134#[cfg_attr(any(test, feature = "test-util"), mockall::automock)]
141#[async_trait]
142pub trait RpcClient: Send + Sync {
143 async fn invoke_method(
159 &self,
160 method: UUri,
161 call_options: CallOptions,
162 payload: Option<UPayload>,
163 ) -> Result<Option<UPayload>, ServiceInvocationError>;
164}
165
166impl dyn RpcClient {
167 pub async fn invoke_proto_method<T, R>(
184 &self,
185 method: UUri,
186 call_options: CallOptions,
187 request_message: T,
188 ) -> Result<R, ServiceInvocationError>
189 where
190 T: MessageFull,
191 R: MessageFull,
192 {
193 let payload = UPayload::try_from_protobuf(request_message)
194 .map_err(|e| ServiceInvocationError::InvalidArgument(e.to_string()))?;
195
196 let result = self
197 .invoke_method(method, call_options, Some(payload))
198 .await?;
199
200 if let Some(result) = result {
201 UPayload::extract_protobuf::<R>(&result)
202 .map_err(|e| ServiceInvocationError::InvalidArgument(e.to_string()))
203 } else {
204 Err(ServiceInvocationError::InvalidArgument(
205 "No payload".to_string(),
206 ))
207 }
208 }
209}
210
211#[cfg_attr(any(test, feature = "test-util"), mockall::automock)]
215#[async_trait]
216pub trait RequestHandler: Send + Sync {
217 async fn handle_request(
237 &self,
238 resource_id: u16,
239 message_attributes: &UAttributes,
240 request_payload: Option<UPayload>,
241 ) -> Result<Option<UPayload>, ServiceInvocationError>;
242}
243
244#[async_trait]
251pub trait RpcServer {
252 async fn register_endpoint(
269 &self,
270 origin_filter: Option<&UUri>,
271 resource_id: u16,
272 request_handler: Arc<dyn RequestHandler>,
273 ) -> Result<(), RegistrationError>;
274
275 async fn unregister_endpoint(
287 &self,
288 origin_filter: Option<&UUri>,
289 resource_id: u16,
290 request_handler: Arc<dyn RequestHandler>,
291 ) -> Result<(), RegistrationError>;
292}
293
// Mock backing the `RpcServer` test double. The trait's async methods are forwarded to these
// mockable inherent methods (see the `RpcServer for MockRpcServerImpl` impl), which use an
// explicit `'a` lifetime tying `&self` and the borrowed origin filter together.
#[cfg(not(tarpaulin_include))]
#[cfg(any(test, feature = "test-util"))]
mockall::mock! {
    pub RpcServerImpl {
        // Mockable counterpart of `RpcServer::register_endpoint`.
        pub async fn do_register_endpoint<'a>(&'a self, origin_filter: Option<&'a UUri>, resource_id: u16, request_handler: Arc<dyn RequestHandler>) -> Result<(), RegistrationError>;
        // Mockable counterpart of `RpcServer::unregister_endpoint`.
        pub async fn do_unregister_endpoint<'a>(&'a self, origin_filter: Option<&'a UUri>, resource_id: u16, request_handler: Arc<dyn RequestHandler>) -> Result<(), RegistrationError>;
    }
}
304
/// Makes the mock usable wherever an [`RpcServer`] is expected by delegating each trait
/// method to the corresponding mockable `do_*` method.
#[cfg(not(tarpaulin_include))]
#[cfg(any(test, feature = "test-util"))]
#[async_trait]
impl RpcServer for MockRpcServerImpl {
    async fn register_endpoint(
        &self,
        origin_filter: Option<&UUri>,
        resource_id: u16,
        request_handler: Arc<dyn RequestHandler>,
    ) -> Result<(), RegistrationError> {
        Self::do_register_endpoint(self, origin_filter, resource_id, request_handler).await
    }

    async fn unregister_endpoint(
        &self,
        origin_filter: Option<&UUri>,
        resource_id: u16,
        request_handler: Arc<dyn RequestHandler>,
    ) -> Result<(), RegistrationError> {
        Self::do_unregister_endpoint(self, origin_filter, resource_id, request_handler).await
    }
}
330
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use protobuf::well_known_types::wrappers::StringValue;

    use crate::{communication::CallOptions, UUri};

    use super::*;

    /// Sends a `StringValue("hello")` request through `invoke_proto_method` on the given mock
    /// client and asks for a `StringValue` response.
    async fn invoke_hello(
        mock_client: MockRpcClient,
    ) -> Result<StringValue, ServiceInvocationError> {
        let client: Arc<dyn RpcClient> = Arc::new(mock_client);
        let mut greeting = StringValue::new();
        greeting.value = "hello".to_string();
        client
            .invoke_proto_method::<StringValue, StringValue>(
                UUri::try_from_parts("", 0x1000, 0x01, 0x0001).unwrap(),
                CallOptions::for_rpc_request(5_000, None, None, None),
                greeting,
            )
            .await
    }

    #[tokio::test]
    async fn test_invoke_proto_method_fails_for_unexpected_return_type() {
        let mut mock_client = MockRpcClient::new();
        // Respond with a UStatus message although the caller expects a StringValue.
        mock_client
            .expect_invoke_method()
            .once()
            .returning(|_method, _options, _payload| {
                let status = UStatus::fail_with_code(UCode::INTERNAL, "internal error");
                Ok(Some(UPayload::try_from_protobuf(status).unwrap()))
            });

        let outcome = invoke_hello(mock_client).await;

        assert!(outcome.is_err_and(|e| matches!(e, ServiceInvocationError::InvalidArgument(_))));
    }

    #[tokio::test]
    async fn test_invoke_proto_method_fails_for_missing_response_payload() {
        let mut mock_client = MockRpcClient::new();
        // Respond without any payload at all.
        mock_client
            .expect_invoke_method()
            .once()
            .return_const(Ok(None));

        let outcome = invoke_hello(mock_client).await;

        assert!(outcome.is_err_and(|e| matches!(e, ServiceInvocationError::InvalidArgument(_))));
    }
}