spark_connect/session.rs

//! High-level user-facing interface for Spark Connect.
//!
//! This module provides [`SparkSession`] — the main entry point for interacting
//! with a Spark Connect server. It exposes a familiar API surface inspired by
//! PySpark and Scala's `SparkSession`, while delegating low-level gRPC work to
//! [`SparkClient`](crate::SparkClient).
//!
//! # Typical usage
//!
//! ```
//! use spark_connect::SparkSessionBuilder;
//!
//! # tokio_test::block_on(async {
//! let session = SparkSessionBuilder::new("sc://localhost:15002")
//!     .build()
//!     .await
//!     .expect("failed to connect");
//!
//! println!("Connected to Spark session: {}", session.session_id());
//! # });
//! ```
//!
//! `SparkSession` provides an ergonomic API for executing SQL, analyzing
//! plans, and inspecting results — without exposing internal client plumbing.

use crate::client::{ChannelBuilder, HeaderInterceptor, SparkClient};
use crate::error::SparkErrorKind;
use crate::query::SqlQueryBuilder;
use crate::spark;
use crate::spark::expression::Literal;
use crate::spark::spark_connect_service_client::SparkConnectServiceClient;
use crate::SparkError;

use arrow::record_batch::RecordBatch;
use std::sync::Arc;
use tokio::sync::RwLock;
use tonic::transport::Channel;
#[cfg(feature = "tls")]
use tonic::transport::ClientTlsConfig;
use tower::ServiceBuilder;
/// Builder for creating [`SparkSession`] instances.
///
/// Configures a connection to a Spark Connect endpoint
/// following the URL format defined by
/// [Apache Spark's client connection spec](https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md).
///
/// # Example
///
/// ```
/// use spark_connect::SparkSessionBuilder;
///
/// # tokio_test::block_on(async {
/// let session = SparkSessionBuilder::new("sc://localhost:15002")
///     .build()
///     .await
///     .unwrap();
///
/// println!("Session ID: {}", session.session_id());
/// # });
/// ```
#[derive(Clone, Debug)]
pub struct SparkSessionBuilder {
    channel_builder: ChannelBuilder,
}

impl SparkSessionBuilder {
    /// Creates a new builder from a Spark Connect connection string.
    ///
    /// The connection string must follow the format:
    /// `sc://<host>:<port>/;key1=value1;key2=value2;...`
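    ///
    /// For illustration, a bare endpoint and one with an optional `;key=value`
    /// parameter (keys follow the spec linked from [`SparkSessionBuilder`]; the
    /// value shown is a placeholder):
    ///
    /// ```
    /// use spark_connect::SparkSessionBuilder;
    ///
    /// let builder = SparkSessionBuilder::new("sc://localhost:15002");
    /// let builder = SparkSessionBuilder::new("sc://localhost:15002/;user_id=alice");
    /// ```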
    ///
    /// # Panics
    ///
    /// Panics if the connection string cannot be parsed.
    pub fn new(connection: &str) -> Self {
        let channel_builder =
            ChannelBuilder::new(connection).expect("Invalid Spark connection string");
        Self { channel_builder }
    }

    /// Establishes a connection and returns a ready-to-use [`SparkSession`].
    ///
    /// This method performs:
    /// - gRPC channel setup;
    /// - metadata interceptor attachment;
    /// - [`SparkClient`](crate::SparkClient) initialization.
    pub async fn build(&self) -> Result<SparkSession, SparkError> {
        // Create the gRPC endpoint.
        let mut endpoint =
            Channel::from_shared(self.channel_builder.endpoint()).map_err(|source| {
                SparkError::new(SparkErrorKind::InvalidConnectionUri {
                    source,
                    uri: self.channel_builder.endpoint(),
                })
            })?;

        // Configure TLS if enabled, so the correct domain name (SNI)
        // is sent during the handshake.
        #[cfg(feature = "tls")]
        if self.channel_builder.use_ssl {
            let tls_config = ClientTlsConfig::new()
                .domain_name(&self.channel_builder.host)
                // Use system root certificates.
                .with_native_roots();

            endpoint = endpoint
                .tls_config(tls_config)
                .map_err(|source| SparkError::new(SparkErrorKind::Transport(source)))?;
        }

        // Connect to the endpoint and build the channel.
        let channel = ServiceBuilder::new().service(
            endpoint
                .connect()
                .await
                .map_err(|source| SparkError::new(SparkErrorKind::Transport(source)))?,
        );

        let grpc_client = SparkConnectServiceClient::with_interceptor(
            channel,
            HeaderInterceptor::new(self.channel_builder.headers().unwrap_or_default()),
        );
        let spark_client = SparkClient::new(
            Arc::new(RwLock::new(grpc_client)),
            self.channel_builder.clone(),
        );

        Ok(SparkSession::new(spark_client))
    }
}

/// Represents a logical connection to a Spark Connect backend.
///
/// `SparkSession` is the main entry point for executing commands, analyzing
/// queries, and retrieving results from Spark Connect.
///
/// It wraps an internal [`SparkClient`](crate::SparkClient) and tracks session
/// state (such as the `session_id`).
///
/// # Examples
///
/// ```
/// use spark_connect::SparkSessionBuilder;
///
/// # tokio_test::block_on(async {
/// let session = SparkSessionBuilder::new("sc://localhost:15002")
///     .build()
///     .await
///     .unwrap();
///
/// println!("Session ID: {}", session.session_id());
/// # });
/// ```
#[derive(Clone, Debug)]
pub struct SparkSession {
    client: SparkClient,
    session_id: String,
}

impl SparkSession {
    /// Creates a new session from a [`SparkClient`].
    ///
    /// Usually invoked internally by [`SparkSessionBuilder::build`].
    pub(crate) fn new(client: SparkClient) -> Self {
        let session_id = client.session_id().to_string();
        Self { client, session_id }
    }

    /// Returns the unique session identifier for this connection.
    pub fn session_id(&self) -> String {
        self.session_id.to_string()
    }

    /// Returns a clone of the underlying [`SparkClient`].
    ///
    /// Although exposed for advanced use cases, callers should prefer the
    /// higher-level methods on `SparkSession` instead of manipulating the
    /// client directly.
    pub(crate) fn client(&self) -> SparkClient {
        self.client.clone()
    }
    /// Executes a SQL query and returns a lazy [`Plan`](crate::spark::Plan).
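    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a Spark Connect server is reachable at
    /// `sc://localhost:15002`:
    ///
    /// ```
    /// # use spark_connect::SparkSessionBuilder;
    /// # tokio_test::block_on(async {
    /// # let session = SparkSessionBuilder::new("sc://localhost:15002")
    /// #     .build()
    /// #     .await
    /// #     .unwrap();
    /// let plan = session
    ///     .sql("SELECT 1 AS id, 'hello' AS text", vec![])
    ///     .await
    ///     .expect("SQL query failed");
    /// let batches = session.collect(plan).await.expect("collect failed");
    /// # });
    /// ```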
    pub async fn sql(
        &self,
        query: &str,
        params: Vec<Literal>,
    ) -> Result<spark::Plan, SparkError> {
        let sql_cmd = spark::command::CommandType::SqlCommand(spark::SqlCommand {
            sql: query.to_string(),
            args: Default::default(),
            pos_args: params,
        });

        // Wrap the SQL command in a plan and execute it on the server.
        let plan = spark::Plan {
            op_type: Some(spark::plan::OpType::Command(spark::Command {
                command_type: Some(sql_cmd),
            })),
        };
        let mut client = self.client();
        let result = client.execute_plan(plan).await?;

        // Return the server's resulting relation as a new root plan.
        Ok(spark::Plan {
            op_type: Some(spark::plan::OpType::Root(result.relation()?)),
        })
    }

    /// Alternative, [sqlx-like](https://docs.rs/sqlx/latest/sqlx/) query interface.
    ///
    /// Returns a [`SqlQueryBuilder`] on which parameters can be bound with
    /// `bind()` before running the query with `execute()`.
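    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a session connected as in the other examples:
    ///
    /// ```
    /// # use spark_connect::SparkSessionBuilder;
    /// # tokio_test::block_on(async {
    /// # let session = SparkSessionBuilder::new("sc://localhost:15002")
    /// #     .build()
    /// #     .await
    /// #     .unwrap();
    /// let batches = session
    ///     .query("SELECT ? AS id, ? AS text")
    ///     .bind(42_i32)
    ///     .bind("world")
    ///     .execute()
    ///     .await
    ///     .expect("query failed");
    /// # });
    /// ```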
    pub fn query(&self, query: &str) -> SqlQueryBuilder<'_> {
        SqlQueryBuilder::new(self, query)
    }

    /// Collects the results of a lazy [`Plan`](crate::spark::Plan) into Arrow [`RecordBatch`]es.
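    ///
    /// # Example
    ///
    /// A minimal sketch, continuing from [`SparkSession::sql`]:
    ///
    /// ```
    /// # use spark_connect::SparkSessionBuilder;
    /// # tokio_test::block_on(async {
    /// # let session = SparkSessionBuilder::new("sc://localhost:15002")
    /// #     .build()
    /// #     .await
    /// #     .unwrap();
    /// let plan = session.sql("SELECT 1 AS id", vec![]).await.unwrap();
    /// let batches = session.collect(plan).await.unwrap();
    /// assert_eq!(batches[0].num_rows(), 1);
    /// # });
    /// ```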
    pub async fn collect(&self, plan: spark::Plan) -> Result<Vec<RecordBatch>, SparkError> {
        let mut client = self.client();

        Ok(client.execute_plan(plan).await?.batches())
    }

    /// Interrupts all operations currently running in this session.
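    ///
    /// A minimal sketch; the returned vector holds the IDs of the interrupted
    /// operations:
    ///
    /// ```
    /// # use spark_connect::SparkSessionBuilder;
    /// # tokio_test::block_on(async {
    /// # let session = SparkSessionBuilder::new("sc://localhost:15002")
    /// #     .build()
    /// #     .await
    /// #     .unwrap();
    /// let interrupted = session.interrupt_all().await.unwrap();
    /// println!("Interrupted {} operation(s)", interrupted.len());
    /// # });
    /// ```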
    pub async fn interrupt_all(&self) -> Result<Vec<String>, SparkError> {
        let response = self
            .client()
            .interrupt(spark::interrupt_request::InterruptType::All, None)
            .await?;

        Ok(response.interrupted_ids())
    }

    /// Interrupts a specific operation by its ID.
    pub async fn interrupt_operation(&self, op_id: &str) -> Result<Vec<String>, SparkError> {
        let response = self
            .client()
            .interrupt(
                spark::interrupt_request::InterruptType::OperationId,
                Some(op_id.to_string()),
            )
            .await?;

        Ok(response.interrupted_ids())
    }

    /// Requests the version of the Spark Connect server.
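    ///
    /// A minimal sketch; yields a version string such as `"3.5.0"`:
    ///
    /// ```
    /// # use spark_connect::SparkSessionBuilder;
    /// # tokio_test::block_on(async {
    /// # let session = SparkSessionBuilder::new("sc://localhost:15002")
    /// #     .build()
    /// #     .await
    /// #     .unwrap();
    /// let version = session.version().await.unwrap();
    /// println!("Spark version: {}", version);
    /// # });
    /// ```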
    pub async fn version(&self) -> Result<String, SparkError> {
        let version = spark::analyze_plan_request::Analyze::SparkVersion(
            spark::analyze_plan_request::SparkVersion {},
        );

        let mut client = self.client();

        Ok(client.analyze(version).await?.spark_version()?)
    }
}

#[cfg(test)]
mod tests {
    use crate::test_utils::test_utils::setup_session;
    use crate::SparkError;

    use arrow::array::{Int32Array, StringArray};
    use regex::Regex;

    #[tokio::test]
    async fn test_session_create() {
        let spark = setup_session().await;
        assert!(spark.is_ok());
    }

    /// Verifies that the client can connect, establish a session, and perform
    /// a basic analysis operation (fetching the Spark version).
    /// This tests `SparkClient::new` and `SparkClient::analyze`.
    #[tokio::test]
    async fn test_session_version() -> Result<(), SparkError> {
        // Arrange: start the server and create a session.
        let spark = setup_session().await?;

        // Act: the version() method on SparkSession triggers the
        // underlying SparkClient::analyze call.
        let version = spark.version().await?;

        // Assert: check for a valid version string.
        let re = Regex::new(r"^\d+\.\d+\.\d+$").unwrap();
        assert!(re.is_match(&version), "Version {} invalid", version);
        Ok(())
    }

    /// Verifies that the client can execute a SQL query
    /// and correctly retrieve the resulting Arrow RecordBatches.
    /// This tests `SparkClient::execute_command_and_fetch`.
    #[tokio::test]
    async fn test_sql() {
        // Arrange: start the server and create a session.
        let session = setup_session()
            .await
            .expect("Failed to create Spark session");

        // Act: execute a simple SQL query.
        let lazy_plan = session
            .sql("SELECT 1 AS id, 'hello' AS text", vec![])
            .await
            .expect("SQL query failed");
        let batches = session
            .collect(lazy_plan)
            .await
            .expect("Failed to collect batches");

        // Assert: validate the structure and content of the returned data.
        assert_eq!(batches.len(), 1, "Expected exactly one RecordBatch");
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 1, "Expected one row");
        assert_eq!(batch.num_columns(), 2, "Expected two columns");

        // Verify the data in the first column (id).
        let id_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("Column 0 should be an Int32Array");
        assert_eq!(id_col.value(0), 1);
    }

    #[tokio::test]
    async fn test_sql_query_builder_bind() -> Result<(), SparkError> {
        let session = setup_session().await?;

        // Use SqlQueryBuilder and bind positional parameters.
        let batches = session
            .query("SELECT ? AS id, ? AS text")
            .bind(42_i32)
            .bind("world")
            .execute()
            .await?;

        assert_eq!(batches.len(), 1);
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 2);

        let id_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(id_col.value(0), 42);

        let text_col = batch
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(text_col.value(0), "world");

        Ok(())
    }
}
351}