use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use async_trait::async_trait;
use tokio::sync::oneshot::{Receiver, Sender};
use tokio::time::timeout;
use tracing::{debug, info};
use crate::{
LocalUriProvider, UCode, UListener, UMessage, UMessageBuilder, UMessageType, UStatus,
UTransport, UUri, UUID,
};
use super::{
build_message, CallOptions, RegistrationError, RpcClient, ServiceInvocationError, UPayload,
};
fn handle_response_message(response: UMessage) -> Result<Option<UPayload>, ServiceInvocationError> {
let Some(attribs) = response.attributes.as_ref() else {
return Err(ServiceInvocationError::InvalidArgument(
"response message does not contain attributes".to_string(),
));
};
match attribs.commstatus.map(|v| v.enum_value_or_default()) {
Some(UCode::OK) | None => {
response.payload.map_or(Ok(None), |payload| {
Ok(Some(UPayload::new(
payload,
attribs.payload_format.enum_value_or_default(),
)))
})
}
Some(code) => {
let status = response.extract_protobuf().unwrap_or_else(|_e| {
UStatus::fail_with_code(code, "failed to invoke service operation")
});
Err(ServiceInvocationError::from(status))
}
}
}
/// Routes incoming RPC Response messages to the tasks waiting for them.
///
/// Every pending RPC request is tracked by its request ID together with the sending half
/// of a `tokio::sync::oneshot` channel. The invoking task awaits the receiving half.
struct ResponseListener {
    pending_requests: Mutex<HashMap<UUID, Sender<UMessage>>>,
}

impl ResponseListener {
    /// Registers a pending request and returns the channel on which its response will arrive.
    ///
    /// # Errors
    ///
    /// - [`ServiceInvocationError::AlreadyExists`] if a request with the same ID is already
    ///   pending.
    /// - [`ServiceInvocationError::Internal`] if the pending requests map cannot be locked.
    fn try_add_pending_request(
        &self,
        reqid: UUID,
    ) -> Result<Receiver<UMessage>, ServiceInvocationError> {
        let Ok(mut pending_requests) = self.pending_requests.lock() else {
            return Err(ServiceInvocationError::Internal(
                "failed to add response handler".to_string(),
            ));
        };
        match pending_requests.entry(reqid) {
            Entry::Vacant(entry) => {
                let (tx, rx) = tokio::sync::oneshot::channel();
                entry.insert(tx);
                Ok(rx)
            }
            Entry::Occupied(_) => Err(ServiceInvocationError::AlreadyExists(
                "RPC request with given ID already pending".to_string(),
            )),
        }
    }

    /// Delivers a response message to the task waiting for the request with the given ID.
    ///
    /// The pending entry is removed in the process, so any further response carrying the
    /// same ID is ignored. Failures are only logged; there is no caller to report them to.
    fn handle_response(&self, reqid: &UUID, response_message: UMessage) {
        let Ok(mut pending_requests) = self.pending_requests.lock() else {
            info!(
                request_id = reqid.to_hyphenated_string(),
                "failed to process response message, cannot acquire lock for pending requests map"
            );
            return;
        };
        let sender = pending_requests.remove(reqid);
        // Release the lock before delivering and logging, so that the critical section
        // only covers the map update.
        drop(pending_requests);

        let Some(sender) = sender else {
            debug!(
                request_id = reqid.to_hyphenated_string(),
                "ignoring (duplicate?) RPC Response message with unknown request ID"
            );
            return;
        };
        if sender.send(response_message).is_err() {
            debug!(
                request_id = reqid.to_hyphenated_string(),
                "failed to deliver RPC Response message, channel already closed"
            );
        } else {
            debug!(
                request_id = reqid.to_hyphenated_string(),
                "successfully delivered RPC Response message"
            );
        }
    }

    /// Removes the pending request with the given ID and returns its response sender.
    ///
    /// Returns `None` if no such request is pending or if the map cannot be locked.
    fn remove_pending_request(&self, reqid: &UUID) -> Option<Sender<UMessage>> {
        self.pending_requests
            .lock()
            .ok()
            .and_then(|mut pending_requests| pending_requests.remove(reqid))
    }

    /// Checks whether a request with the given ID is currently pending.
    ///
    /// Returns `false` if the map cannot be locked.
    #[cfg(test)]
    fn contains(&self, reqid: &UUID) -> bool {
        self.pending_requests
            .lock()
            .is_ok_and(|pending_requests| pending_requests.contains_key(reqid))
    }
}
#[async_trait]
impl UListener for ResponseListener {
    /// Forwards an RPC Response message to the pending request it belongs to.
    ///
    /// Messages that are not of type `UMESSAGE_TYPE_RESPONSE` are ignored, and so are
    /// responses that do not carry a request ID. Both cases are logged at debug level.
    async fn on_receive(&self, msg: UMessage) {
        let kind = msg
            .attributes
            .get_or_default()
            .type_
            .enum_value_or_default();
        if kind != UMessageType::UMESSAGE_TYPE_RESPONSE {
            debug!(
                message_type = kind.to_cloudevent_type(),
                "service provider replied with message that is not an RPC Response"
            );
            return;
        }

        let request_id = msg
            .attributes
            .as_ref()
            .and_then(|attributes| attributes.reqid.clone().into_option());
        let Some(request_id) = request_id else {
            debug!("ignoring malformed response message not containing request ID");
            return;
        };
        self.handle_response(&request_id, msg);
    }
}
/// An RPC client that keeps track of its pending requests in an in-memory map.
///
/// RPC Request messages are sent via the given transport. The matching RPC Response
/// messages are received by a listener that the client registers on that same transport.
pub struct InMemoryRpcClient {
    transport: Arc<dyn UTransport>,
    uri_provider: Arc<dyn LocalUriProvider>,
    response_listener: Arc<ResponseListener>,
}

impl InMemoryRpcClient {
    /// Creates a new client on top of a transport.
    ///
    /// A response listener is registered on the transport. It accepts messages from any
    /// source (`UUri::any()`) that are addressed to the provider's source URI.
    ///
    /// # Errors
    ///
    /// Returns a [`RegistrationError`] if the transport rejects the listener registration.
    pub async fn new(
        transport: Arc<dyn UTransport>,
        uri_provider: Arc<dyn LocalUriProvider>,
    ) -> Result<Self, RegistrationError> {
        let pending_requests = Mutex::new(HashMap::new());
        let response_listener = Arc::new(ResponseListener { pending_requests });

        let registration = transport
            .register_listener(
                &UUri::any(),
                Some(&uri_provider.get_source_uri()),
                response_listener.clone(),
            )
            .await;
        if let Err(status) = registration {
            return Err(RegistrationError::from(status));
        }

        Ok(Self {
            transport,
            uri_provider,
            response_listener,
        })
    }

    /// Checks whether a request with the given ID is still awaiting its response.
    #[cfg(test)]
    fn contains_pending_request(&self, reqid: &UUID) -> bool {
        self.response_listener.contains(reqid)
    }
}
#[async_trait]
impl RpcClient for InMemoryRpcClient {
async fn invoke_method(
&self,
method: UUri,
call_options: CallOptions,
payload: Option<UPayload>,
) -> Result<Option<UPayload>, ServiceInvocationError> {
let message_id = call_options.message_id().unwrap_or_else(UUID::build);
let mut builder = UMessageBuilder::request(
method.clone(),
self.uri_provider.get_source_uri(),
call_options.ttl(),
);
builder.with_message_id(message_id.clone());
if let Some(token) = call_options.token() {
builder.with_token(token.to_owned());
}
if let Some(priority) = call_options.priority() {
builder.with_priority(priority);
}
let rpc_request_message = build_message(&mut builder, payload)
.map_err(|e| ServiceInvocationError::InvalidArgument(e.to_string()))?;
let receiver = self
.response_listener
.try_add_pending_request(message_id.clone())?;
self.transport
.send(rpc_request_message)
.await
.map_err(|e| {
self.response_listener.remove_pending_request(&message_id);
e
})?;
debug!(
request_id = message_id.to_hyphenated_string(),
ttl = call_options.ttl(),
"successfully sent RPC Request message"
);
match timeout(Duration::from_millis(call_options.ttl() as u64), receiver).await {
Err(_) => {
debug!(
request_id = message_id.to_hyphenated_string(),
ttl = call_options.ttl(),
"invocation of service operation has timed out"
);
self.response_listener.remove_pending_request(&message_id);
Err(ServiceInvocationError::DeadlineExceeded)
}
Ok(result) => match result {
Ok(response_message) => handle_response_message(response_message),
Err(_e) => {
debug!(
request_id = message_id.to_hyphenated_string(),
"response listener failed to forward response message"
);
self.response_listener.remove_pending_request(&message_id);
Err(ServiceInvocationError::Internal(
"error receiving response message".to_string(),
))
}
},
}
}
}
// Unit tests for `InMemoryRpcClient`, driven by mocked transport and URI provider
// implementations.
#[cfg(test)]
mod tests {
use super::*;
use protobuf::{well_known_types::wrappers::StringValue, Enum};
use tokio::{join, sync::Notify};
use crate::{
utransport::{MockLocalUriProvider, MockTransport},
UMessageBuilder, UPriority, UUri,
};
// Creates a URI provider mock whose source URI has ue_id 0x0005, major version 0x02
// and resource ID 0x0000.
fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
let mut mock_uri_locator = MockLocalUriProvider::new();
mock_uri_locator.expect_get_source_uri().returning(|| UUri {
ue_id: 0x0005,
ue_version_major: 0x02,
resource_id: 0x0000,
..Default::default()
});
Arc::new(mock_uri_locator)
}
// The URI of the (fictitious) service operation invoked by the tests.
fn service_method_uri() -> UUri {
UUri {
ue_id: 0x0001,
ue_version_major: 0x01,
resource_id: 0x1000,
..Default::default()
}
}
// Client creation must fail if the transport rejects the response listener; a
// RESOURCE_EXHAUSTED status is expected to surface as `MaxListenersExceeded`.
#[tokio::test]
async fn test_registration_of_response_listener_fails() {
let mut mock_transport = MockTransport::default();
mock_transport
.expect_do_register_listener()
.once()
.returning(|_source_filter, _sink_filter, _listener| {
Err(UStatus::fail_with_code(
UCode::RESOURCE_EXHAUSTED,
"max number of listeners exceeded",
))
});
let creation_attempt =
InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider()).await;
assert!(
creation_attempt.is_err_and(|e| matches!(e, RegistrationError::MaxListenersExceeded))
);
}
// A failure to send the request must be reported to the caller (UNAVAILABLE ->
// `Unavailable`) and must not leave the request in the pending requests map.
#[tokio::test]
async fn test_invoke_method_fails_with_transport_error() {
let mut mock_transport = MockTransport::default();
mock_transport
.expect_do_register_listener()
.once()
.returning(|_source_filter, _sink_filter, _listener| Ok(()));
mock_transport
.expect_do_send()
.returning(|_request_message| {
Err(UStatus::fail_with_code(
UCode::UNAVAILABLE,
"transport not available",
))
});
let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
.await
.unwrap();
let message_id = UUID::build();
let call_options =
CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
let response = client
.invoke_method(service_method_uri(), call_options, None)
.await;
assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::Unavailable(_msg))));
assert!(!client.contains_pending_request(&message_id));
}
// Happy path: the request carries the call options' message ID, priority, TTL and token,
// and the response delivered through the registered listener is returned to the caller.
// Delivering the same response a second time must be ignored.
#[tokio::test]
async fn test_invoke_method_succeeds() {
let message_id = UUID::build();
let call_options = CallOptions::for_rpc_request(
5_000,
Some(message_id.clone()),
Some("my_token".to_string()),
Some(crate::UPriority::UPRIORITY_CS6),
);
// Used to get hold of the listener that the client registers with the transport.
let (captured_listener_tx, captured_listener_rx) = tokio::sync::oneshot::channel();
// Signalled by the mock transport once the request has been sent.
let request_sent = Arc::new(Notify::new());
let request_sent_clone = request_sent.clone();
let mut mock_transport = MockTransport::default();
mock_transport
.expect_do_register_listener()
.once()
.return_once(move |_source_filter, _sink_filter, listener| {
captured_listener_tx
.send(listener)
.map_err(|_e| UStatus::fail("cannot capture listener"))
});
let expected_message_id = message_id.clone();
mock_transport
.expect_do_send()
.once()
.withf(move |request_message| {
request_message
.attributes
.as_ref()
.map_or(false, |attribs| {
attribs.id.as_ref() == Some(&expected_message_id)
&& attribs.priority.value() == UPriority::UPRIORITY_CS6.value()
&& attribs.ttl == Some(5_000)
&& attribs.token == Some("my_token".to_string())
})
})
.returning(move |_request_message| {
request_sent_clone.notify_one();
Ok(())
});
let uri_provider = new_uri_provider();
let rpc_client = Arc::new(
InMemoryRpcClient::new(Arc::new(mock_transport), uri_provider.clone())
.await
.unwrap(),
);
let client: Arc<dyn RpcClient> = rpc_client.clone();
// Invoke the operation in a separate task so that the response can be injected below.
let response_handle = tokio::spawn(async move {
let request_payload = StringValue {
value: "World".to_string(),
..Default::default()
};
client
.invoke_proto_method::<_, StringValue>(
service_method_uri(),
call_options,
request_payload,
)
.await
});
let response_payload = StringValue {
value: "Hello World".to_string(),
..Default::default()
};
let response_message = UMessageBuilder::response(
uri_provider.get_source_uri(),
message_id.clone(),
service_method_uri(),
)
.build_with_protobuf_payload(&response_payload)
.unwrap();
// Wait until the listener has been captured and the request has been sent before
// delivering the response.
let (response_listener_result, _) = join!(captured_listener_rx, request_sent.notified());
let response_listener = response_listener_result.unwrap();
let cloned_response_message = response_message.clone();
let cloned_response_listener = response_listener.clone();
tokio::spawn(async move {
cloned_response_listener
.on_receive(cloned_response_message)
.await
});
let response = response_handle.await.unwrap();
assert!(response.is_ok_and(|payload| payload.value == *"Hello World"));
assert!(!rpc_client.contains_pending_request(&message_id));
// A duplicate response for an already completed request must not re-create an entry.
response_listener.on_receive(response_message).await;
assert!(!rpc_client.contains_pending_request(&message_id));
}
// A second invocation using the message ID of a still pending request must fail with
// `AlreadyExists` and must leave the first request pending.
#[tokio::test]
async fn test_invoke_method_fails_on_repeated_invocation() {
let message_id = UUID::build();
let first_request_sent = Arc::new(Notify::new());
let first_request_sent_clone = first_request_sent.clone();
let mut mock_transport = MockTransport::default();
mock_transport
.expect_do_register_listener()
.once()
.return_const(Ok(()));
let expected_message_id = message_id.clone();
// Only a single request is expected to reach the transport.
mock_transport
.expect_do_send()
.once()
.withf(move |request_message| {
request_message
.attributes
.as_ref()
.map_or(false, |attribs| {
attribs.id.as_ref() == Some(&expected_message_id)
})
})
.returning(move |_request_message| {
first_request_sent_clone.notify_one();
Ok(())
});
let in_memory_rpc_client = Arc::new(
InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
.await
.unwrap(),
);
let rpc_client: Arc<dyn RpcClient> = in_memory_rpc_client.clone();
let call_options =
CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
let cloned_call_options = call_options.clone();
let cloned_rpc_client = rpc_client.clone();
// First invocation: no response is ever delivered, so it stays pending.
tokio::spawn(async move {
let request_payload = StringValue {
value: "World".to_string(),
..Default::default()
};
cloned_rpc_client
.invoke_proto_method::<_, StringValue>(
service_method_uri(),
cloned_call_options,
request_payload,
)
.await
});
first_request_sent.notified().await;
let request_payload = StringValue {
value: "World".to_string(),
..Default::default()
};
// Second invocation with the same message ID.
let second_request_handle = tokio::spawn(async move {
rpc_client
.invoke_proto_method::<_, StringValue>(
service_method_uri(),
call_options,
request_payload,
)
.await
});
let response = second_request_handle.await.unwrap();
assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::AlreadyExists(_))));
assert!(in_memory_rpc_client.contains_pending_request(&message_id));
}
// A response carrying a NOT_FOUND commstatus (with a UStatus payload) must be mapped to
// `ServiceInvocationError::NotFound`, and the request must no longer be pending.
#[tokio::test]
async fn test_invoke_method_fails_with_remote_error() {
let (captured_listener_tx, captured_listener_rx) = std::sync::mpsc::channel();
let mut mock_transport = MockTransport::default();
mock_transport.expect_do_register_listener().returning(
move |_source_filter, _sink_filter, listener| {
captured_listener_tx
.send(listener)
.map_err(|_e| UStatus::fail("cannot capture listener"))
},
);
// The mock transport answers each request immediately with an error response, which
// is delivered to the captured listener from a separate task.
mock_transport
.expect_do_send()
.returning(move |request_message| {
let error = UStatus::fail_with_code(UCode::NOT_FOUND, "no such object");
let response_message = UMessageBuilder::response_for_request(
request_message.attributes.as_ref().unwrap(),
)
.with_comm_status(UCode::NOT_FOUND)
.build_with_protobuf_payload(&error)
.unwrap();
// The listener was sent on the channel during client creation, so this blocking
// `recv` is expected to return right away.
let captured_listener = captured_listener_rx.recv().unwrap().to_owned();
tokio::spawn(async move { captured_listener.on_receive(response_message).await });
Ok(())
});
let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
.await
.unwrap();
let message_id = UUID::build();
let call_options =
CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
let response = client
.invoke_method(service_method_uri(), call_options, None)
.await;
assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::NotFound(_msg)) }));
assert!(!client.contains_pending_request(&message_id));
}
// Without any response, the invocation must fail with `DeadlineExceeded` once the TTL
// (20 ms here) has elapsed, and the request must no longer be pending.
#[tokio::test]
async fn test_invoke_method_times_out() {
let mut mock_transport = MockTransport::default();
mock_transport
.expect_do_register_listener()
.returning(|_source_filter, _sink_filter, _listener| Ok(()));
mock_transport
.expect_do_send()
.returning(|_request_message| Ok(()));
let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
.await
.unwrap();
let message_id = UUID::build();
let call_options = CallOptions::for_rpc_request(20, Some(message_id.clone()), None, None);
let response = client
.invoke_method(service_method_uri(), call_options, None)
.await;
assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::DeadlineExceeded) }));
assert!(!client.contains_pending_request(&message_id));
}
// A response message without attributes is rejected as an invalid argument.
#[test]
fn test_handle_response_message_fails_for_missing_attributes() {
let response_msg = UMessage {
..Default::default()
};
let result = handle_response_message(response_msg);
assert!(result.is_err_and(|e| matches!(e, ServiceInvocationError::InvalidArgument(_))));
}
}