// spark_rust/rpc/connections/connection.rs
use crate::common_types::types::certificates;
use crate::error::{NetworkError, SparkSdkError};
use crate::rpc::traits::SparkRpcConnection;
use crate::rpc::SparkRpcClient;
use tonic::async_trait;
use tonic::transport::Uri;
use tonic::transport::{Certificate, Channel, ClientTlsConfig};
16
17#[derive(Debug, Clone)]
21pub(crate) struct SparkConnection {
22 pub(crate) channel: Channel,
24}
25
26#[async_trait]
27impl SparkRpcConnection for SparkConnection {
28 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
37 async fn establish_connection(uri: Uri) -> Result<SparkRpcClient, SparkSdkError> {
38 if uri.scheme_str().is_some_and(|scheme| scheme == "http") {
39 let channel = Channel::from_shared(uri.to_string())
40 .map_err(|err| {
41 SparkSdkError::from(NetworkError::RpcConnection {
42 uri: uri.to_string(),
43 details: Some(err.to_string()),
44 })
45 })?
46 .connect()
47 .await?;
48
49 return Ok(SparkRpcClient::DefaultConnection(SparkConnection {
50 channel,
51 }));
52 }
53
54 let server_root_ca_cert = Certificate::from_pem(certificates::amazon_root_ca::CA_PEM);
55
56 let mut tls = ClientTlsConfig::new().ca_certificate(server_root_ca_cert);
58 if let Some(host) = uri.host() {
59 tls = tls.domain_name(host);
60 }
61
62 let channel = Channel::from_shared(uri.to_string())
64 .unwrap()
65 .tls_config(tls)?
66 .connect()
67 .await?;
68
69 let connection = SparkConnection { channel };
70
71 Ok(SparkRpcClient::DefaultConnection(connection))
72 }
73}