use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use async_trait::async_trait;
use tokio::sync::oneshot::{Receiver, Sender};
use tokio::time::timeout;
use tracing::{debug, info};
use crate::{
LocalUriProvider, UCode, UListener, UMessage, UMessageBuilder, UMessageType, UStatus,
UTransport, UUri, UUID,
};
use super::{
build_message, CallOptions, RegistrationError, RpcClient, ServiceInvocationError, UPayload,
};
/// Converts an RPC response message into the outcome of the corresponding invocation.
///
/// If the message carries no commstatus, or a commstatus of `OK`, the message's payload (if
/// any) is returned, wrapped in a [`UPayload`] tagged with the message's payload format.
/// For any other commstatus the invocation is considered failed: the [`UStatus`] contained in
/// the payload is used if it can be extracted, otherwise a status with the commstatus code is
/// created.
///
/// # Errors
///
/// * [`ServiceInvocationError::RpcError`] with code `INTERNAL` if the message has no attributes.
/// * The error converted from the response's status if the commstatus indicates a failure.
fn handle_response_message(response: UMessage) -> Result<Option<UPayload>, ServiceInvocationError> {
    let attributes = match response.attributes.as_ref() {
        Some(attributes) => attributes,
        None => {
            return Err(ServiceInvocationError::RpcError(UStatus::fail_with_code(
                UCode::INTERNAL,
                "response message does not contain attributes",
            )))
        }
    };

    // A missing commstatus is treated the same way as an explicit OK.
    let failure_code = attributes
        .commstatus
        .map(|status| status.enum_value_or_default())
        .filter(|code| *code != UCode::OK);

    if let Some(code) = failure_code {
        let status = response.extract_protobuf().unwrap_or_else(|_e| {
            UStatus::fail_with_code(code, "failed to invoke service operation")
        });
        return Err(ServiceInvocationError::from(status));
    }

    let payload_format = attributes.payload_format.enum_value_or_default();
    Ok(response
        .payload
        .map(|data| UPayload::new(data, payload_format)))
}
/// Correlates incoming RPC response messages with the requests that are waiting for them.
///
/// Every pending request is keyed by its request message ID and mapped to the sending half
/// of a `tokio::sync::oneshot` channel; the receiving half is handed out to the caller
/// that registered the request.
struct ResponseListener {
    pending_requests: Mutex<HashMap<UUID, Sender<UMessage>>>,
}

impl ResponseListener {
    /// Registers a pending request and returns the receiver on which its response will be
    /// delivered.
    ///
    /// # Errors
    ///
    /// * [`ServiceInvocationError::AlreadyExists`] if a request with the same ID is already
    ///   pending.
    /// * [`ServiceInvocationError::Internal`] if the pending requests map cannot be locked.
    fn try_add_pending_request(
        &self,
        reqid: UUID,
    ) -> Result<Receiver<UMessage>, ServiceInvocationError> {
        let Ok(mut pending_requests) = self.pending_requests.lock() else {
            return Err(ServiceInvocationError::Internal(
                "failed to add response handler".to_string(),
            ));
        };
        match pending_requests.entry(reqid) {
            Entry::Vacant(entry) => {
                let (tx, rx) = tokio::sync::oneshot::channel();
                entry.insert(tx);
                Ok(rx)
            }
            Entry::Occupied(_) => Err(ServiceInvocationError::AlreadyExists(
                "RPC request with given ID already pending".to_string(),
            )),
        }
    }

    /// Delivers a response message to the request with the given ID, if one is pending.
    ///
    /// The pending entry is removed in the process. Responses for unknown requests, and
    /// responses whose receiver has already gone away, are logged at debug level and dropped.
    fn handle_response(&self, reqid: &UUID, response_message: UMessage) {
        // Hold the lock only while taking the sender out of the map; delivering the
        // message and logging do not need access to the map.
        let sender = match self.pending_requests.lock() {
            Ok(mut pending_requests) => pending_requests.remove(reqid),
            Err(_) => {
                info!(
                    request_id = reqid.to_hyphenated_string(),
                    "failed to process response message, cannot acquire lock for pending requests map"
                );
                return;
            }
        };

        let Some(sender) = sender else {
            debug!(
                request_id = reqid.to_hyphenated_string(),
                "ignoring response message for unknown request"
            );
            return;
        };

        if sender.send(response_message).is_err() {
            debug!(
                request_id = reqid.to_hyphenated_string(),
                "failed to deliver response message, channel already closed"
            );
        }
    }

    /// Removes the pending request with the given ID and returns its sender, if any.
    ///
    /// Returns `None` if no such request is pending or if the map cannot be locked.
    fn remove_pending_request(&self, reqid: &UUID) -> Option<Sender<UMessage>> {
        self.pending_requests
            .lock()
            .ok()
            .and_then(|mut pending_requests| pending_requests.remove(reqid))
    }

    /// Checks if a request with the given ID is pending (`false` if the map cannot be locked).
    #[cfg(test)]
    fn contains(&self, reqid: &UUID) -> bool {
        self.pending_requests
            .lock()
            .is_ok_and(|pending_requests| pending_requests.contains_key(reqid))
    }
}
#[async_trait]
impl UListener for ResponseListener {
    /// Forwards an incoming RPC Response message to the request that is waiting for it.
    ///
    /// Messages of any other type, as well as responses lacking a request ID, are logged at
    /// debug level and otherwise ignored.
    async fn on_receive(&self, msg: UMessage) {
        let attributes = msg.attributes.get_or_default();
        let message_type = attributes.type_.enum_value_or_default();

        if message_type != UMessageType::UMESSAGE_TYPE_RESPONSE {
            debug!(
                message_type = message_type.to_cloudevent_type(),
                "service provider replied with message that is not an RPC Response"
            );
            return;
        }

        let Some(request_id) = attributes.reqid.as_ref().cloned() else {
            debug!("ignoring malformed response message not containing request ID");
            return;
        };

        self.handle_response(&request_id, msg);
    }
}
pub struct InMemoryRpcClient {
transport: Arc<dyn UTransport>,
uri_provider: Arc<dyn LocalUriProvider>,
response_listener: Arc<ResponseListener>,
}
impl InMemoryRpcClient {
pub async fn new(
transport: Arc<dyn UTransport>,
uri_provider: Arc<dyn LocalUriProvider>,
) -> Result<Self, RegistrationError> {
let response_listener = Arc::new(ResponseListener {
pending_requests: Mutex::new(HashMap::new()),
});
transport
.register_listener(
&UUri::any(),
Some(&uri_provider.get_source_uri()),
response_listener.clone(),
)
.await
.map_err(RegistrationError::from)?;
Ok(InMemoryRpcClient {
transport,
uri_provider,
response_listener,
})
}
#[cfg(test)]
fn contains_pending_request(&self, reqid: &UUID) -> bool {
self.response_listener.contains(reqid)
}
}
#[async_trait]
impl RpcClient for InMemoryRpcClient {
async fn invoke_method(
&self,
method: UUri,
call_options: CallOptions,
payload: Option<UPayload>,
) -> Result<Option<UPayload>, ServiceInvocationError> {
let message_id = call_options.message_id().unwrap_or_else(UUID::build);
let mut builder = UMessageBuilder::request(
method.clone(),
self.uri_provider.get_source_uri(),
call_options.ttl(),
);
builder.with_message_id(message_id.clone());
if let Some(token) = call_options.token() {
builder.with_token(token.to_owned());
}
if let Some(priority) = call_options.priority() {
builder.with_priority(priority);
}
let rpc_request_message = build_message(&mut builder, payload)
.map_err(|e| ServiceInvocationError::InvalidArgument(e.to_string()))?;
let receiver = self
.response_listener
.try_add_pending_request(message_id.clone())?;
self.transport
.send(rpc_request_message)
.await
.map_err(|e| {
self.response_listener.remove_pending_request(&message_id);
e
})?;
if let Ok(Ok(response_message)) =
timeout(Duration::from_millis(call_options.ttl() as u64), receiver).await
{
handle_response_message(response_message)
} else {
self.response_listener.remove_pending_request(&message_id);
Err(ServiceInvocationError::DeadlineExceeded)
}
}
}
// Unit tests for `InMemoryRpcClient`, driven by mocked transport and URI provider instances.
#[cfg(test)]
mod tests {
    use super::*;
    use protobuf::{well_known_types::wrappers::StringValue, Enum};
    use crate::{
        utransport::{MockLocalUriProvider, MockTransport},
        UMessageBuilder, UPriority, UUri,
    };

    /// Creates a mock URI provider whose source URI has entity ID `0x0005`, major version
    /// `0x02` and resource ID `0x0000`.
    fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
        let mut mock_uri_locator = MockLocalUriProvider::new();
        mock_uri_locator.expect_get_source_uri().returning(|| UUri {
            ue_id: 0x0005,
            ue_version_major: 0x02,
            resource_id: 0x0000,
            ..Default::default()
        });
        Arc::new(mock_uri_locator)
    }

    /// Returns the URI of the service method that the tests invoke.
    fn service_method_uri() -> UUri {
        UUri {
            ue_id: 0x0001,
            ue_version_major: 0x01,
            resource_id: 0x1000,
            ..Default::default()
        }
    }

    /// Asserts that client creation fails if the transport rejects the response listener
    /// with `RESOURCE_EXHAUSTED`, and that this surfaces as
    /// `RegistrationError::MaxListenersExceeded`.
    #[tokio::test]
    async fn test_registration_of_response_listener_fails() {
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| {
                Err(UStatus::fail_with_code(
                    UCode::RESOURCE_EXHAUSTED,
                    "max number of listeners exceeded",
                ))
            });
        let creation_attempt =
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider()).await;
        assert!(
            creation_attempt.is_err_and(|e| matches!(e, RegistrationError::MaxListenersExceeded))
        );
    }

    /// Asserts that an `UNAVAILABLE` error from the transport's `send` is reported as
    /// `ServiceInvocationError::Unavailable`, and that no pending request is left behind.
    #[tokio::test]
    async fn test_invoke_method_fails_with_transport_error() {
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        mock_transport
            .expect_do_send()
            .returning(|_request_message| {
                Err(UStatus::fail_with_code(
                    UCode::UNAVAILABLE,
                    "transport not available",
                ))
            });
        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();
        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;
        assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::Unavailable(_msg))));
        assert!(!client.contains_pending_request(&message_id));
    }

    /// Asserts that a successful invocation:
    /// - sends a request carrying the message ID, priority, TTL and token from the call options;
    /// - returns the payload of the response delivered to the registered listener;
    /// - leaves no pending request behind.
    #[tokio::test]
    async fn test_invoke_method_succeeds() {
        let message_id = UUID::build();
        let call_options = CallOptions::for_rpc_request(
            5_000,
            Some(message_id.clone()),
            Some("my_token".to_string()),
            Some(crate::UPriority::UPRIORITY_CS6),
        );
        // Hands the listener registered by the client over to the `send` mock below, so that
        // the mock can deliver the response to it.
        let (captured_listener_tx, captured_listener_rx) = std::sync::mpsc::channel();
        let mut mock_transport = MockTransport::default();
        mock_transport.expect_do_register_listener().returning(
            move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            },
        );
        let expected_message_id = message_id.clone();
        // Only matches a request that carries the ID, priority, TTL and token from the call
        // options.
        mock_transport
            .expect_do_send()
            .withf(move |request_message| {
                request_message
                    .attributes
                    .as_ref()
                    .map_or(false, |attribs| {
                        attribs.id.as_ref() == Some(&expected_message_id)
                            && attribs.priority.value() == UPriority::UPRIORITY_CS6.value()
                            && attribs.ttl == Some(5_000)
                            && attribs.token == Some("my_token".to_string())
                    })
            })
            .returning(move |request_message| {
                // Replies with "Hello <value>" for a request payload of "<value>".
                let request_payload: StringValue = request_message.extract_protobuf().unwrap();
                let response_payload = StringValue {
                    value: format!("Hello {}", request_payload.value),
                    ..Default::default()
                };
                let response_message = UMessageBuilder::response_for_request(
                    request_message.attributes.as_ref().unwrap(),
                )
                .build_with_protobuf_payload(&response_payload)
                .unwrap();
                let captured_listener = captured_listener_rx.recv().unwrap().to_owned();
                // Delivers the response from a separate task rather than from within `send`.
                tokio::spawn(async move { captured_listener.on_receive(response_message).await });
                Ok(())
            });
        let rpc_client = Arc::new(
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
                .await
                .unwrap(),
        );
        let client: Arc<dyn RpcClient> = rpc_client.clone();
        let request_payload = StringValue {
            value: "World".to_string(),
            ..Default::default()
        };
        let response: StringValue = client
            .invoke_proto_method(service_method_uri(), call_options, request_payload)
            .await
            .expect("invoking method should have succeeded");
        assert_eq!(response.value, "Hello World");
        assert!(!rpc_client.contains_pending_request(&message_id));
    }

    /// Asserts that a response with commstatus `NOT_FOUND` and a `UStatus` payload is
    /// reported as `ServiceInvocationError::NotFound`, and that no pending request is left
    /// behind.
    #[tokio::test]
    async fn test_invoke_method_fails_with_remote_error() {
        // Hands the registered listener over to the `send` mock (see test_invoke_method_succeeds).
        let (captured_listener_tx, captured_listener_rx) = std::sync::mpsc::channel();
        let mut mock_transport = MockTransport::default();
        mock_transport.expect_do_register_listener().returning(
            move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            },
        );
        mock_transport
            .expect_do_send()
            .returning(move |request_message| {
                // Replies with a failed commstatus and the error status as payload.
                let error = UStatus::fail_with_code(UCode::NOT_FOUND, "no such object");
                let response_message = UMessageBuilder::response_for_request(
                    request_message.attributes.as_ref().unwrap(),
                )
                .with_comm_status(UCode::NOT_FOUND)
                .build_with_protobuf_payload(&error)
                .unwrap();
                let captured_listener = captured_listener_rx.recv().unwrap().to_owned();
                tokio::spawn(async move { captured_listener.on_receive(response_message).await });
                Ok(())
            });
        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();
        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;
        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::NotFound(_msg)) }));
        assert!(!client.contains_pending_request(&message_id));
    }

    /// Asserts that an invocation for which no response ever arrives fails with
    /// `ServiceInvocationError::DeadlineExceeded` after its (20 ms) TTL, and that no pending
    /// request is left behind.
    #[tokio::test]
    async fn test_invoke_method_times_out() {
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        // Accepts the request but never delivers a response.
        mock_transport
            .expect_do_send()
            .returning(|_request_message| Ok(()));
        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();
        let message_id = UUID::build();
        let call_options = CallOptions::for_rpc_request(20, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;
        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::DeadlineExceeded) }));
        assert!(!client.contains_pending_request(&message_id));
    }
}