use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use async_trait::async_trait;
use tokio::sync::oneshot::{Receiver, Sender};
use tokio::time::timeout;
use tracing::{debug, info};
use crate::{
    LocalUriProvider, UCode, UListener, UMessage, UMessageBuilder, UMessageType, UStatus,
    UTransport, UUri, UUID,
};
use super::{
    build_message, CallOptions, RegistrationError, RpcClient, ServiceInvocationError, UPayload,
};
/// Turns a received RPC Response message into the outcome of a service invocation.
///
/// A response whose `commstatus` is absent or `OK` counts as success. Its payload
/// (if any) is returned together with the payload format taken from the attributes.
///
/// Any other `commstatus` counts as failure. The payload is then expected to hold a
/// `UStatus` describing the problem. If no such status can be extracted, a generic
/// status carrying the `commstatus` code is used instead.
///
/// # Errors
///
/// - `ServiceInvocationError::RpcError` (code `INTERNAL`) if the message has no attributes.
/// - The error converted from the failure status if `commstatus` is not `OK`.
fn handle_response_message(response: UMessage) -> Result<Option<UPayload>, ServiceInvocationError> {
    let attribs = response.attributes.as_ref().ok_or_else(|| {
        ServiceInvocationError::RpcError(UStatus::fail_with_code(
            UCode::INTERNAL,
            "response message does not contain attributes",
        ))
    })?;

    // Only a present, non-OK comm status signals a failed invocation.
    let failure_code = attribs
        .commstatus
        .map(|v| v.enum_value_or_default())
        .filter(|code| *code != UCode::OK);
    if let Some(code) = failure_code {
        let status = response.extract_protobuf().unwrap_or_else(|_e| {
            UStatus::fail_with_code(code, "failed to invoke service operation")
        });
        return Err(ServiceInvocationError::from(status));
    }

    let payload_format = attribs.payload_format.enum_value_or_default();
    Ok(response
        .payload
        .map(|payload| UPayload::new(payload, payload_format)))
}
/// Keeps track of RPC Requests that are waiting for a Response and hands incoming
/// Response messages over to them.
///
/// Every pending request is stored under its request message ID together with the
/// sending half of a `oneshot` channel. The receiving half is returned to the code
/// that registered the request.
struct ResponseListener {
    /// Request message ID -> channel over which the matching response is delivered.
    pending_requests: Mutex<HashMap<UUID, Sender<UMessage>>>,
}

impl ResponseListener {
    /// Registers a pending request for the given request ID.
    ///
    /// Returns the receiving half of the channel that the response will be delivered on.
    ///
    /// # Errors
    ///
    /// - `ServiceInvocationError::Internal` if the pending requests map cannot be locked.
    /// - `ServiceInvocationError::AlreadyExists` if a request with the same ID is
    ///   already pending.
    fn try_add_pending_request(
        &self,
        reqid: UUID,
    ) -> Result<Receiver<UMessage>, ServiceInvocationError> {
        let mut pending_requests = self.pending_requests.lock().map_err(|_e| {
            ServiceInvocationError::Internal("failed to add response handler".to_string())
        })?;
        match pending_requests.entry(reqid) {
            Entry::Occupied(_) => Err(ServiceInvocationError::AlreadyExists(
                "RPC request with given ID already pending".to_string(),
            )),
            Entry::Vacant(slot) => {
                let (response_tx, response_rx) = tokio::sync::oneshot::channel();
                slot.insert(response_tx);
                Ok(response_rx)
            }
        }
    }

    /// Delivers a response message to the request that is waiting for it.
    ///
    /// The matching pending request is removed from the map. The message is dropped,
    /// and a log entry is written, in these cases:
    /// - the map cannot be locked;
    /// - no request with the given ID is pending;
    /// - the receiving side has already gone away.
    fn handle_response(&self, reqid: &UUID, response_message: UMessage) {
        match self.pending_requests.lock() {
            Err(_poisoned) => {
                info!(
                    request_id = reqid.to_hyphenated_string(),
                    "failed to process response message, cannot acquire lock for pending requests map"
                );
            }
            Ok(mut pending_requests) => match pending_requests.remove(reqid) {
                None => {
                    debug!(
                        request_id = reqid.to_hyphenated_string(),
                        "ignoring response message for unknown request"
                    );
                }
                Some(sender) => {
                    // `send` hands the message back if the receiver has been dropped.
                    if sender.send(response_message).is_err() {
                        debug!(
                            request_id = reqid.to_hyphenated_string(),
                            "failed to deliver response message, channel already closed"
                        );
                    }
                }
            },
        }
    }

    /// Removes the pending request with the given ID and returns its sender, if any.
    ///
    /// Returns `None` if no such request is pending or the map cannot be locked.
    fn remove_pending_request(&self, reqid: &UUID) -> Option<Sender<UMessage>> {
        let mut pending_requests = self.pending_requests.lock().ok()?;
        pending_requests.remove(reqid)
    }

    /// Checks if a request with the given ID is currently pending.
    ///
    /// Returns `false` if the map cannot be locked.
    #[cfg(test)]
    fn contains(&self, reqid: &UUID) -> bool {
        self.pending_requests
            .lock()
            .is_ok_and(|pending_requests| pending_requests.contains_key(reqid))
    }
}
#[async_trait]
impl UListener for ResponseListener {
    /// Forwards an incoming RPC Response to the pending request it answers.
    ///
    /// Messages that are not RPC Responses, and Responses without a request ID, are
    /// ignored after logging a debug message.
    async fn on_receive(&self, msg: UMessage) {
        let attributes = msg.attributes.get_or_default();
        let message_type = attributes.type_.enum_value_or_default();
        if message_type != UMessageType::UMESSAGE_TYPE_RESPONSE {
            debug!(
                message_type = message_type.to_cloudevent_type(),
                "service provider replied with message that is not an RPC Response"
            );
            return;
        }

        let Some(reqid) = attributes.reqid.clone().into_option() else {
            debug!("ignoring malformed response message not containing request ID");
            return;
        };
        self.handle_response(&reqid, msg);
    }
}
/// An RPC client that keeps track of outstanding RPC Requests in memory.
///
/// On creation, the client registers a single listener with the transport. That
/// listener accepts messages from any source that are addressed to the client's own
/// source URI, and it forwards RPC Responses to the matching pending request.
pub struct InMemoryRpcClient {
    transport: Arc<dyn UTransport>,
    uri_provider: Arc<dyn LocalUriProvider>,
    response_listener: Arc<ResponseListener>,
}

impl InMemoryRpcClient {
    /// Creates a new RPC client on top of the given transport.
    ///
    /// # Arguments
    ///
    /// * `transport` - The transport used for sending requests and receiving responses.
    /// * `uri_provider` - Provides the client's source URI. It is used as the sink
    ///   filter when registering the response listener.
    ///
    /// # Errors
    ///
    /// Returns an error if the response listener cannot be registered with the transport.
    pub async fn new(
        transport: Arc<dyn UTransport>,
        uri_provider: Arc<dyn LocalUriProvider>,
    ) -> Result<Self, RegistrationError> {
        let pending_requests = Mutex::new(HashMap::new());
        let response_listener = Arc::new(ResponseListener { pending_requests });
        let local_uri = uri_provider.get_source_uri();
        transport
            .register_listener(&UUri::any(), Some(&local_uri), response_listener.clone())
            .await?;
        Ok(Self {
            transport,
            uri_provider,
            response_listener,
        })
    }

    /// Checks if a request with the given ID is still waiting for its response.
    #[cfg(test)]
    fn contains_pending_request(&self, reqid: &UUID) -> bool {
        self.response_listener.contains(reqid)
    }
}
/// Runs the wrapped closure when dropped.
///
/// `invoke_method` uses this to remove its pending request from the response listener on
/// every exit path. That includes the case where the future returned by `invoke_method`
/// is dropped before it has completed, e.g. because the caller stopped waiting for it.
struct OnDrop<F: FnMut()>(F);

impl<F: FnMut()> Drop for OnDrop<F> {
    fn drop(&mut self) {
        (self.0)();
    }
}

#[async_trait]
impl RpcClient for InMemoryRpcClient {
    /// Sends an RPC Request to `method` and waits for the corresponding RPC Response.
    ///
    /// The request uses the message ID from `call_options` if one is set; otherwise a new
    /// ID is generated. The response is awaited for at most `call_options.ttl()`
    /// milliseconds.
    ///
    /// # Errors
    ///
    /// - `ServiceInvocationError::InvalidArgument` if the request message cannot be built.
    /// - `ServiceInvocationError::AlreadyExists` if a request with the same message ID is
    ///   already pending.
    /// - The error converted from the transport's status if the request cannot be sent.
    /// - `ServiceInvocationError::DeadlineExceeded` if no response arrives in time.
    /// - The error conveyed by the response if the service provider reports a failure.
    async fn invoke_method(
        &self,
        method: UUri,
        call_options: CallOptions,
        payload: Option<UPayload>,
    ) -> Result<Option<UPayload>, ServiceInvocationError> {
        let message_id = call_options.message_id().unwrap_or_else(UUID::build);
        // `method` is not used after this point, so it is moved into the builder.
        let mut builder = UMessageBuilder::request(
            method,
            self.uri_provider.get_source_uri(),
            call_options.ttl(),
        );
        builder.with_message_id(message_id.clone());
        if let Some(token) = call_options.token() {
            builder.with_token(token.to_owned());
        }
        if let Some(priority) = call_options.priority() {
            builder.with_priority(priority);
        }
        let rpc_request_message = build_message(&mut builder, payload)
            .map_err(|e| ServiceInvocationError::InvalidArgument(e.to_string()))?;

        let receiver = self
            .response_listener
            .try_add_pending_request(message_id.clone())?;
        // From here on, the pending request must be removed however this function is left:
        // - on a send error;
        // - on timeout;
        // - when this future is dropped while awaiting the transport or the response.
        // Otherwise the entry would stay in the map forever, and a new request with the same
        // ID would be rejected as already pending. On the success path the listener has
        // already removed the entry, so the removal here is a no-op.
        let _pending_request_cleanup = OnDrop(|| {
            self.response_listener.remove_pending_request(&message_id);
        });

        self.transport.send(rpc_request_message).await?;

        match timeout(
            Duration::from_millis(u64::from(call_options.ttl())),
            receiver,
        )
        .await
        {
            Ok(Ok(response_message)) => handle_response_message(response_message),
            // Either the TTL elapsed before a response arrived, or the sending half of the
            // channel was dropped without delivering a response.
            Ok(Err(_)) | Err(_) => Err(ServiceInvocationError::DeadlineExceeded),
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `InMemoryRpcClient`, driven by a mocked transport and URI provider.
    use super::*;
    use protobuf::{well_known_types::wrappers::StringValue, Enum};
    use crate::{
        utransport::{MockLocalUriProvider, MockTransport},
        UMessageBuilder, UPriority, UUri,
    };
    /// Creates a URI provider whose source URI has `ue_id` 0x0005, major version 0x02
    /// and resource ID 0x0000.
    fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
        let mut mock_uri_locator = MockLocalUriProvider::new();
        mock_uri_locator.expect_get_source_uri().returning(|| UUri {
            ue_id: 0x0005,
            ue_version_major: 0x02,
            resource_id: 0x0000,
            ..Default::default()
        });
        Arc::new(mock_uri_locator)
    }
    /// Returns the URI of the service operation that the tests invoke.
    fn service_method_uri() -> UUri {
        UUri {
            ue_id: 0x0001,
            ue_version_major: 0x01,
            resource_id: 0x1000,
            ..Default::default()
        }
    }
    // The transport rejects the listener registration with RESOURCE_EXHAUSTED; creating
    // the client is expected to fail with `RegistrationError::MaxListenersExceeded`.
    #[tokio::test]
    async fn test_registration_of_response_listener_fails() {
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| {
                Err(UStatus::fail_with_code(
                    UCode::RESOURCE_EXHAUSTED,
                    "max number of listeners exceeded",
                ))
            });
        let creation_attempt =
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider()).await;
        assert!(
            creation_attempt.is_err_and(|e| matches!(e, RegistrationError::MaxListenersExceeded))
        );
    }
    // Sending the request fails with UNAVAILABLE. The invocation is expected to fail with
    // `ServiceInvocationError::Unavailable` and to leave no pending request behind.
    #[tokio::test]
    async fn test_invoke_method_fails_with_transport_error() {
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .once()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        mock_transport
            .expect_do_send()
            .returning(|_request_message| {
                Err(UStatus::fail_with_code(
                    UCode::UNAVAILABLE,
                    "transport not available",
                ))
            });
        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();
        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;
        assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::Unavailable(_msg))));
        assert!(!client.contains_pending_request(&message_id));
    }
    // Happy path. The request must carry the message ID, priority, TTL and token from the
    // call options. The mocked transport replies with "Hello " + the request's value.
    #[tokio::test]
    async fn test_invoke_method_succeeds() {
        let message_id = UUID::build();
        let call_options = CallOptions::for_rpc_request(
            5_000,
            Some(message_id.clone()),
            Some("my_token".to_string()),
            Some(crate::UPriority::UPRIORITY_CS6),
        );
        // The listener registered by the client is captured here, so that the mocked
        // `send` can feed the response message back to it.
        let (captured_listener_tx, captured_listener_rx) = std::sync::mpsc::channel();
        let mut mock_transport = MockTransport::default();
        mock_transport.expect_do_register_listener().returning(
            move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            },
        );
        let expected_message_id = message_id.clone();
        mock_transport
            .expect_do_send()
            // only accept a request carrying the expected ID, priority, TTL and token
            .withf(move |request_message| {
                request_message
                    .attributes
                    .as_ref()
                    .map_or(false, |attribs| {
                        attribs.id.as_ref() == Some(&expected_message_id)
                            && attribs.priority.value() == UPriority::UPRIORITY_CS6.value()
                            && attribs.ttl == Some(5_000)
                            && attribs.token == Some("my_token".to_string())
                    })
            })
            .returning(move |request_message| {
                let request_payload: StringValue = request_message.extract_protobuf().unwrap();
                let response_payload = StringValue {
                    value: format!("Hello {}", request_payload.value),
                    ..Default::default()
                };
                let response_message = UMessageBuilder::response_for_request(
                    request_message.attributes.as_ref().unwrap(),
                )
                .build_with_protobuf_payload(&response_payload)
                .unwrap();
                // deliver the response asynchronously via the captured listener
                let captured_listener = captured_listener_rx.recv().unwrap().to_owned();
                tokio::spawn(async move { captured_listener.on_receive(response_message).await });
                Ok(())
            });
        let rpc_client = Arc::new(
            InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
                .await
                .unwrap(),
        );
        let client: Arc<dyn RpcClient> = rpc_client.clone();
        let request_payload = StringValue {
            value: "World".to_string(),
            ..Default::default()
        };
        let response: StringValue = client
            .invoke_proto_method(service_method_uri(), call_options, request_payload)
            .await
            .expect("invoking method should have succeeded");
        assert_eq!(response.value, "Hello World");
        assert!(!rpc_client.contains_pending_request(&message_id));
    }
    // The mocked transport answers every request with a response carrying comm status
    // NOT_FOUND and a `UStatus` payload. The invocation is expected to fail with
    // `ServiceInvocationError::NotFound` and to leave no pending request behind.
    #[tokio::test]
    async fn test_invoke_method_fails_with_remote_error() {
        let (captured_listener_tx, captured_listener_rx) = std::sync::mpsc::channel();
        let mut mock_transport = MockTransport::default();
        mock_transport.expect_do_register_listener().returning(
            move |_source_filter, _sink_filter, listener| {
                captured_listener_tx
                    .send(listener)
                    .map_err(|_e| UStatus::fail("cannot capture listener"))
            },
        );
        mock_transport
            .expect_do_send()
            .returning(move |request_message| {
                let error = UStatus::fail_with_code(UCode::NOT_FOUND, "no such object");
                let response_message = UMessageBuilder::response_for_request(
                    request_message.attributes.as_ref().unwrap(),
                )
                .with_comm_status(UCode::NOT_FOUND)
                .build_with_protobuf_payload(&error)
                .unwrap();
                let captured_listener = captured_listener_rx.recv().unwrap().to_owned();
                tokio::spawn(async move { captured_listener.on_receive(response_message).await });
                Ok(())
            });
        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();
        let message_id = UUID::build();
        let call_options =
            CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;
        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::NotFound(_msg)) }));
        assert!(!client.contains_pending_request(&message_id));
    }
    // The request is sent successfully but no response is ever delivered. With a TTL of
    // 20 ms, the invocation is expected to fail with `DeadlineExceeded` and to leave no
    // pending request behind.
    #[tokio::test]
    async fn test_invoke_method_times_out() {
        let mut mock_transport = MockTransport::default();
        mock_transport
            .expect_do_register_listener()
            .returning(|_source_filter, _sink_filter, _listener| Ok(()));
        mock_transport
            .expect_do_send()
            .returning(|_request_message| Ok(()));
        let client = InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
            .await
            .unwrap();
        let message_id = UUID::build();
        let call_options = CallOptions::for_rpc_request(20, Some(message_id.clone()), None, None);
        let response = client
            .invoke_method(service_method_uri(), call_options, None)
            .await;
        assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::DeadlineExceeded) }));
        assert!(!client.contains_pending_request(&message_id));
    }
}