spark_rust/rpc/connections/
connection.rs

1//! Connection implementation for establishing secure RPC connections to Spark nodes.
2//!
3//! This module provides the default connection implementation used by the SDK to establish
4//! TLS-secured connections to Spark RPC endpoints. It handles:
5//! - TLS configuration with Amazon root CA certificate verification
6//! - Domain name validation
7//! - Channel establishment and management
8
9use crate::common_types::types::certificates;
10use crate::error::{NetworkError, SparkSdkError};
11use crate::rpc::traits::SparkRpcConnection;
12use crate::rpc::SparkRpcClient;
13use tonic::async_trait;
14use tonic::transport::Uri;
15use tonic::transport::{Certificate, Channel, ClientTlsConfig};
16
17/// Default connection type for establishing secure RPC connections to Spark nodes.
18///
19/// This connection type uses TLS with Amazon root CA certificate verification.
20#[derive(Debug, Clone)]
21pub(crate) struct SparkConnection {
22    /// The client to use for the Spark RPC service
23    pub(crate) channel: Channel,
24}
25
26#[async_trait]
27impl SparkRpcConnection for SparkConnection {
28    /// Establishes a new secure RPC connection to a Spark node.
29    ///
30    /// # Arguments
31    /// * `uri` - The URI of the Spark node to connect to
32    ///
33    /// # Returns
34    /// - `Ok(SparkRpcClient)` if the connection was established successfully
35    /// - `Err(SparkSdkError)` if there was an error establishing the connection
36    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
37    async fn establish_connection(uri: Uri) -> Result<SparkRpcClient, SparkSdkError> {
38        if uri.scheme_str().is_some_and(|scheme| scheme == "http") {
39            let channel = Channel::from_shared(uri.to_string())
40                .map_err(|err| {
41                    SparkSdkError::from(NetworkError::RpcConnection {
42                        uri: uri.to_string(),
43                        details: Some(err.to_string()),
44                    })
45                })?
46                .connect()
47                .await?;
48
49            return Ok(SparkRpcClient::DefaultConnection(SparkConnection {
50                channel,
51            }));
52        }
53
54        let server_root_ca_cert = Certificate::from_pem(certificates::amazon_root_ca::CA_PEM);
55
56        // create tls config
57        let mut tls = ClientTlsConfig::new().ca_certificate(server_root_ca_cert);
58        if let Some(host) = uri.host() {
59            tls = tls.domain_name(host);
60        }
61
62        // create channel with tls
63        let channel = Channel::from_shared(uri.to_string())
64            .unwrap()
65            .tls_config(tls)?
66            .connect()
67            .await?;
68
69        let connection = SparkConnection { channel };
70
71        Ok(SparkRpcClient::DefaultConnection(connection))
72    }
73}