//! Spark Connect session handling: building a session over a gRPC channel
//! and running SQL against it.

use crate::client::ChannelBuilder;
use crate::client::HeaderInterceptor;
use crate::client::SparkClient;
use crate::query::SqlQueryBuilder;
use crate::spark;
use crate::spark::expression::Literal;
use crate::spark::spark_connect_service_client::SparkConnectServiceClient;
use crate::{error::SparkErrorKind, SparkError};

use arrow::record_batch::RecordBatch;
use std::sync::Arc;
use tokio::sync::RwLock;
use tonic::transport::Channel;
#[cfg(feature = "tls")]
use tonic::transport::ClientTlsConfig;
use tower::ServiceBuilder;

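/// Entry point for connecting to a Spark Connect server.
///
/// The sketch below is illustrative only (hence `ignore`): it assumes a
/// Spark Connect server listening on the default port 15002.
///
/// ```ignore
/// let spark = SparkSessionBuilder::new("sc://localhost:15002")
///     .build()
///     .await?;
/// ```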
#[derive(Clone, Debug)]
pub struct SparkSessionBuilder {
    channel_builder: ChannelBuilder,
}

impl SparkSessionBuilder {
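    /// Creates a builder from a connection string such as
    /// `sc://localhost:15002`.
    ///
    /// # Panics
    ///
    /// Panics if the connection string cannot be parsed.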
    pub fn new(connection: &str) -> Self {
        let channel_builder =
            ChannelBuilder::new(connection).expect("Invalid Spark connection string");
        Self { channel_builder }
    }

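    /// Connects to the endpoint and returns a ready [`SparkSession`].
    ///
    /// With the `tls` feature enabled, the channel is wrapped in TLS when
    /// the connection string requests SSL.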
    pub async fn build(&self) -> Result<SparkSession, SparkError> {
        // `endpoint` is only reassigned when the `tls` feature is enabled.
        #[cfg_attr(not(feature = "tls"), allow(unused_mut))]
        let mut endpoint = Channel::from_shared(self.channel_builder.endpoint())
            .map_err(|source| {
                SparkError::new(SparkErrorKind::InvalidConnectionUri {
                    source,
                    uri: self.channel_builder.endpoint(),
                })
            })?;

        #[cfg(feature = "tls")]
        if self.channel_builder.use_ssl {
            let tls_config = ClientTlsConfig::new()
                .domain_name(&self.channel_builder.host)
                .with_native_roots();

            endpoint = endpoint
                .tls_config(tls_config)
                .map_err(|source| SparkError::new(SparkErrorKind::Transport(source)))?;
        }

        let channel = ServiceBuilder::new().service(
            endpoint
                .connect()
                .await
                .map_err(|source| SparkError::new(SparkErrorKind::Transport(source)))?,
        );

        // Attach any configured headers to every outgoing request.
        let grpc_client = SparkConnectServiceClient::with_interceptor(
            channel,
            HeaderInterceptor::new(self.channel_builder.headers().unwrap_or_default()),
        );
        let spark_client = SparkClient::new(
            Arc::new(RwLock::new(grpc_client)),
            self.channel_builder.clone(),
        );

        Ok(SparkSession::new(spark_client))
    }
}

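/// An active Spark Connect session; the entry point for running SQL.
///
/// Cloning is cheap: clones share the underlying gRPC client.
///
/// Illustrative only (assumes a live server):
///
/// ```ignore
/// let spark = SparkSessionBuilder::new("sc://localhost:15002").build().await?;
/// println!("connected to Spark {}", spark.version().await?);
/// ```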
#[derive(Clone, Debug)]
pub struct SparkSession {
    client: SparkClient,
    session_id: String,
}

impl SparkSession {
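    /// Wraps a connected client, caching its session id.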
    pub(crate) fn new(client: SparkClient) -> Self {
        let session_id = client.session_id().to_string();
        Self { client, session_id }
    }

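    /// Returns the unique id of this session.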
    pub fn session_id(&self) -> String {
        self.session_id.clone()
    }

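    /// Returns a clone of the underlying client; clones share the same
    /// gRPC connection.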
    pub(crate) fn client(&self) -> SparkClient {
        self.client.clone()
    }

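    /// Executes a SQL command and returns a plan rooted at its result
    /// relation, ready to be passed to [`SparkSession::collect`].
    ///
    /// `params` fills positional `?` placeholders in order; pass an empty
    /// `Vec` when the query has none.
    ///
    /// Illustrative only (assumes a live server):
    ///
    /// ```ignore
    /// let plan = spark.sql("SELECT 1 AS id", vec![]).await?;
    /// let batches = spark.collect(plan).await?;
    /// ```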
    pub async fn sql(
        &self,
        query: &str,
        params: Vec<Literal>,
    ) -> Result<spark::Plan, SparkError> {
        let sql_cmd = spark::command::CommandType::SqlCommand(spark::SqlCommand {
            sql: query.to_string(),
            args: Default::default(),
            pos_args: params,
        });

        let plan = spark::Plan {
            op_type: Some(spark::plan::OpType::Command(spark::Command {
                command_type: Some(sql_cmd),
            })),
        };

        // Execute the command eagerly, then re-root the plan at its result
        // relation so it can be collected later.
        let mut client = self.client();
        let result = client.execute_plan(plan).await?;

        Ok(spark::Plan {
            op_type: Some(spark::plan::OpType::Root(result.relation()?)),
        })
    }

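    /// Starts a [`SqlQueryBuilder`] for a parameterized query.
    ///
    /// Illustrative only (assumes a live server); `bind` fills the `?`
    /// placeholders in order:
    ///
    /// ```ignore
    /// let batches = spark
    ///     .query("SELECT ? AS id")
    ///     .bind(42_i32)
    ///     .execute()
    ///     .await?;
    /// ```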
    pub fn query(&self, query: &str) -> SqlQueryBuilder<'_> {
        SqlQueryBuilder::new(self, query)
    }

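    /// Executes `plan` on the server and collects the result as Arrow
    /// [`RecordBatch`]es.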
    pub async fn collect(&self, plan: spark::Plan) -> Result<Vec<RecordBatch>, SparkError> {
        let mut client = self.client();

        Ok(client.execute_plan(plan).await?.batches())
    }

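    /// Interrupts every operation running in this session, returning the
    /// ids of the interrupted operations.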
    pub async fn interrupt_all(&self) -> Result<Vec<String>, SparkError> {
        let response = self
            .client()
            .interrupt(spark::interrupt_request::InterruptType::All, None)
            .await?;

        Ok(response.interrupted_ids())
    }

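    /// Interrupts a single operation by id, returning the ids of the
    /// operations that were actually interrupted.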
    pub async fn interrupt_operation(&self, op_id: &str) -> Result<Vec<String>, SparkError> {
        let response = self
            .client()
            .interrupt(
                spark::interrupt_request::InterruptType::OperationId,
                Some(op_id.to_string()),
            )
            .await?;

        Ok(response.interrupted_ids())
    }

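    /// Asks the server for its Spark version, e.g. `3.5.1`.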
    pub async fn version(&self) -> Result<String, SparkError> {
        let version = spark::analyze_plan_request::Analyze::SparkVersion(
            spark::analyze_plan_request::SparkVersion {},
        );

        let mut client = self.client();

        Ok(client.analyze(version).await?.spark_version()?)
    }
}

#[cfg(test)]
mod tests {
    use crate::test_utils::test_utils::setup_session;
    use crate::SparkError;

    use arrow::array::{Int32Array, StringArray};
    use regex::Regex;

    #[tokio::test]
    async fn test_session_create() {
        let spark = setup_session().await;
        assert!(spark.is_ok());
    }

    #[tokio::test]
    async fn test_session_version() -> Result<(), SparkError> {
        let spark = setup_session().await?;

        let version = spark.version().await?;

        // Expect a semantic version such as `3.5.1`.
        let re = Regex::new(r"^\d+\.\d+\.\d+$").unwrap();
        assert!(re.is_match(&version), "unexpected version format: {version}");
        Ok(())
    }

    #[tokio::test]
    async fn test_sql() {
        let session = setup_session().await.expect("Failed to create Spark session");

        let lazy_plan = session
            .sql("SELECT 1 AS id, 'hello' AS text", vec![])
            .await
            .expect("SQL query failed");
        let batches = session
            .collect(lazy_plan)
            .await
            .expect("Failed to collect batches");

        assert_eq!(batches.len(), 1, "Expected exactly one RecordBatch");
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 1, "Expected one row");
        assert_eq!(batch.num_columns(), 2, "Expected two columns");

        let id_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("Column 0 should be an Int32Array");
        assert_eq!(id_col.value(0), 1);
    }

    #[tokio::test]
    async fn test_sql_query_builder_bind() -> Result<(), SparkError> {
        let session = setup_session().await?;

        let batches = session
            .query("SELECT ? AS id, ? AS text")
            .bind(42_i32)
            .bind("world")
            .execute()
            .await?;

        assert_eq!(batches.len(), 1);
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 2);

        let id_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(id_col.value(0), 42);

        let text_col = batch
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(text_col.value(0), "world");

        Ok(())
    }
}