1#![doc(
2 issue_tracker_base_url = "https://github.com/mycelial/snowflake-rs/issues",
3 test(no_crate_inject)
4)]
5#![doc = include_str!("../README.md")]
6#![warn(clippy::all, clippy::pedantic)]
7#![allow(
8clippy::must_use_candidate,
9clippy::missing_errors_doc,
10clippy::module_name_repetitions,
11clippy::struct_field_names,
12clippy::future_not_send, clippy::missing_panics_doc
14)]
15
16use std::fmt::{Display, Formatter};
17use std::io;
18use std::sync::Arc;
19
20use arrow::error::ArrowError;
21use arrow::ipc::reader::StreamReader;
22use arrow::record_batch::RecordBatch;
23use base64::Engine;
24use bytes::{Buf, Bytes};
25use futures::future::try_join_all;
26use regex::Regex;
27use reqwest_middleware::ClientWithMiddleware;
28use thiserror::Error;
29
30use responses::ExecResponse;
31use session::{AuthError, Session};
32
33use crate::connection::QueryType;
34use crate::connection::{Connection, ConnectionError};
35use crate::requests::ExecRequest;
36use crate::responses::{ExecResponseRowType, SnowflakeType};
37use crate::session::AuthError::MissingEnvArgument;
38
39pub mod connection;
40#[cfg(feature = "polars")]
41mod polars;
42mod put;
43mod requests;
44mod responses;
45mod session;
46
/// Error type returned by the public API of this crate.
///
/// Variants marked `#[error(transparent)]` display exactly like the error they
/// wrap, and their `#[from]` attribute lets `?` convert the wrapped error
/// into this type automatically.
#[derive(Error, Debug)]
pub enum SnowflakeApiError {
    /// A `ConnectionError` from the `connection` module.
    #[error(transparent)]
    RequestError(#[from] ConnectionError),

    /// An `AuthError` from the `session` module (for example a missing
    /// environment variable in `AuthArgs::from_env`).
    #[error(transparent)]
    AuthError(#[from] AuthError),

    /// Base64 decoding failed; in this file that happens while decoding the
    /// `rowset_base64` field of a query response.
    #[error(transparent)]
    ResponseDeserializationError(#[from] base64::DecodeError),

    /// An error raised by the `arrow` crate, e.g. while reading record batches.
    #[error(transparent)]
    ArrowError(#[from] arrow::error::ArrowError),

    /// The S3 bucket path of a PUT request is invalid; carries the offending path.
    #[error("S3 bucket path in PUT request is invalid: `{0}`")]
    InvalidBucketPath(String),

    /// No file name could be extracted from a local path; carries that path.
    #[error("Couldn't extract filename from the local path: `{0}`")]
    InvalidLocalPath(String),

    /// A local `std::io::Error`.
    #[error(transparent)]
    LocalIoError(#[from] io::Error),

    /// An error raised by the `object_store` crate.
    #[error(transparent)]
    ObjectStoreError(#[from] object_store::Error),

    /// A path error raised by the `object_store` crate.
    #[error(transparent)]
    ObjectStorePathError(#[from] object_store::path::Error),

    /// Joining a Tokio task failed.
    #[error(transparent)]
    TokioTaskJoinError(#[from] tokio::task::JoinError),

    /// The API answered with an error response: `(error code, message)`.
    /// The message is empty when the response carried none.
    #[error("Snowflake API error. Code: `{0}`. Message: `{1}`")]
    ApiError(String, String),

    /// The API returned an empty response.
    #[error("Snowflake API empty response could mean that query wasn't executed correctly or API call was faulty")]
    EmptyResponse,

    /// A query response reported rows but contained neither a JSON rowset
    /// nor a base64 rowset.
    #[error("No usable rowsets were included in the response")]
    BrokenResponse,

    /// The requested feature is not implemented; carries a description of it.
    #[error("Following feature is not implemented yet: {0}")]
    Unimplemented(String),

    /// The response kind did not match the statement kind (e.g. a query
    /// response to a PUT statement, or a PUT/GET response to a query).
    #[error("Unexpected API response")]
    UnexpectedResponse,

    /// A glob pattern could not be parsed.
    #[error(transparent)]
    GlobPatternError(#[from] glob::PatternError),

    /// Iterating over glob matches failed.
    #[error(transparent)]
    GlobError(#[from] glob::GlobError),
}
100
/// A query result that the API delivered as JSON, plus its column descriptions.
pub struct JsonResult {
    /// The JSON payload; `exec_arrow_raw` fills it from the response's `rowset` field.
    pub value: serde_json::Value,
    /// One `FieldSchema` per entry of the response's `rowtype` list, in the same order.
    pub schema: Vec<FieldSchema>,
}
109
110impl Display for JsonResult {
111 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
112 write!(f, "{}", self.value)
113 }
114}
115
/// Description of a single result column, converted from an
/// `ExecResponseRowType` entry of the API response.
pub struct FieldSchema {
    /// Name as reported in the response.
    pub name: String,
    /// Snowflake type as reported in the response.
    pub type_: SnowflakeType,
    /// Scale, when the response provides one.
    pub scale: Option<i64>,
    /// Precision, when the response provides one.
    pub precision: Option<i64>,
    /// Nullability flag as reported in the response.
    pub nullable: bool,
}
125
126impl From<ExecResponseRowType> for FieldSchema {
127 fn from(value: ExecResponseRowType) -> Self {
128 FieldSchema {
129 name: value.name,
130 type_: value.type_,
131 scale: value.scale,
132 precision: value.precision,
133 nullable: value.nullable,
134 }
135 }
136}
137
/// Result of `SnowflakeApi::exec`, with any Arrow payload already decoded.
pub enum QueryResult {
    /// Record batches decoded from the Arrow payload of the response.
    Arrow(Vec<RecordBatch>),
    /// The response carried a JSON rowset instead of Arrow data.
    Json(JsonResult),
    /// Nothing to return: the response reported zero rows, or the statement was a PUT.
    Empty,
}
146
/// Result of `SnowflakeApi::exec_raw`, with any Arrow payload still undecoded.
///
/// Use `deserialize_arrow` to turn it into a `QueryResult`.
pub enum RawQueryResult {
    /// Undecoded buffers, each read as an Arrow IPC stream by `deserialize_arrow`.
    /// Downloaded chunks come first; the decoded inline base64 rowset, if any, is last.
    Bytes(Vec<Bytes>),
    /// The response carried a JSON rowset instead of Arrow data.
    Json(JsonResult),
    /// Nothing to return: the response reported zero rows, or the statement was a PUT.
    Empty,
}
158
159impl RawQueryResult {
160 pub fn deserialize_arrow(self) -> Result<QueryResult, ArrowError> {
161 match self {
162 RawQueryResult::Bytes(bytes) => {
163 Self::flat_bytes_to_batches(bytes).map(QueryResult::Arrow)
164 }
165 RawQueryResult::Json(j) => Ok(QueryResult::Json(j)),
166 RawQueryResult::Empty => Ok(QueryResult::Empty),
167 }
168 }
169
170 fn flat_bytes_to_batches(bytes: Vec<Bytes>) -> Result<Vec<RecordBatch>, ArrowError> {
171 let mut res = vec![];
172 for b in bytes {
173 let mut batches = Self::bytes_to_batches(b)?;
174 res.append(&mut batches);
175 }
176 Ok(res)
177 }
178
179 fn bytes_to_batches(bytes: Bytes) -> Result<Vec<RecordBatch>, ArrowError> {
180 let record_batches = StreamReader::try_new(bytes.reader(), None)?;
181 record_batches.into_iter().collect()
182 }
183}
184
/// Everything needed to open a session: account, user, optional session
/// defaults and the credentials to authenticate with.
pub struct AuthArgs {
    /// Account identifier; `SnowflakeApiBuilder::build` also stores an
    /// upper-cased copy of it in the resulting `SnowflakeApi`.
    pub account_identifier: String,
    /// Optional warehouse passed to the session.
    pub warehouse: Option<String>,
    /// Optional database passed to the session.
    pub database: Option<String>,
    /// Optional schema passed to the session.
    pub schema: Option<String>,
    /// User name passed to the session.
    pub username: String,
    /// Optional role passed to the session.
    pub role: Option<String>,
    /// Credentials: either a password or a private key.
    pub auth_type: AuthType,
}
194
195impl AuthArgs {
196 pub fn from_env() -> Result<AuthArgs, SnowflakeApiError> {
197 let auth_type = if let Ok(password) = std::env::var("SNOWFLAKE_PASSWORD") {
198 Ok(AuthType::Password(PasswordArgs { password }))
199 } else if let Ok(private_key_pem) = std::env::var("SNOWFLAKE_PRIVATE_KEY") {
200 Ok(AuthType::Certificate(CertificateArgs { private_key_pem }))
201 } else {
202 Err(MissingEnvArgument(
203 "SNOWFLAKE_PASSWORD or SNOWFLAKE_PRIVATE_KEY".to_owned(),
204 ))
205 };
206
207 Ok(AuthArgs {
208 account_identifier: std::env::var("SNOWFLAKE_ACCOUNT")
209 .map_err(|_| MissingEnvArgument("SNOWFLAKE_ACCOUNT".to_owned()))?,
210 warehouse: std::env::var("SNOWLFLAKE_WAREHOUSE").ok(),
211 database: std::env::var("SNOWFLAKE_DATABASE").ok(),
212 schema: std::env::var("SNOWFLAKE_SCHEMA").ok(),
213 username: std::env::var("SNOWFLAKE_USER")
214 .map_err(|_| MissingEnvArgument("SNOWFLAKE_USER".to_owned()))?,
215 role: std::env::var("SNOWFLAKE_ROLE").ok(),
216 auth_type: auth_type?,
217 })
218 }
219}
220
/// The credential kind used to authenticate a session.
pub enum AuthType {
    /// Authenticate with a password.
    Password(PasswordArgs),
    /// Authenticate with a private key.
    Certificate(CertificateArgs),
}
225
/// Credentials for password authentication.
pub struct PasswordArgs {
    /// The password handed to `Session::password_auth`.
    pub password: String,
}
229
/// Credentials for private-key authentication.
pub struct CertificateArgs {
    /// The private key handed to `Session::cert_auth`
    /// (PEM text, judging by the field name; confirm against `session`).
    pub private_key_pem: String,
}
233
/// Builder for `SnowflakeApi` that optionally accepts a caller-supplied HTTP client.
#[must_use]
pub struct SnowflakeApiBuilder {
    /// Account, user and credentials the session will be created with.
    pub auth: AuthArgs,
    /// Caller-supplied client; when `None`, `build` creates a default `Connection`.
    client: Option<ClientWithMiddleware>,
}
239
240impl SnowflakeApiBuilder {
241 pub fn new(auth: AuthArgs) -> Self {
242 Self { auth, client: None }
243 }
244
245 pub fn with_client(mut self, client: ClientWithMiddleware) -> Self {
246 self.client = Some(client);
247 self
248 }
249
250 pub fn build(self) -> Result<SnowflakeApi, SnowflakeApiError> {
251 let connection = match self.client {
252 Some(client) => Arc::new(Connection::new_with_middware(client)),
253 None => Arc::new(Connection::new()?),
254 };
255
256 let session = match self.auth.auth_type {
257 AuthType::Password(args) => Session::password_auth(
258 Arc::clone(&connection),
259 &self.auth.account_identifier,
260 self.auth.warehouse.as_deref(),
261 self.auth.database.as_deref(),
262 self.auth.schema.as_deref(),
263 &self.auth.username,
264 self.auth.role.as_deref(),
265 &args.password,
266 ),
267 AuthType::Certificate(args) => Session::cert_auth(
268 Arc::clone(&connection),
269 &self.auth.account_identifier,
270 self.auth.warehouse.as_deref(),
271 self.auth.database.as_deref(),
272 self.auth.schema.as_deref(),
273 &self.auth.username,
274 self.auth.role.as_deref(),
275 &args.private_key_pem,
276 ),
277 };
278
279 let account_identifier = self.auth.account_identifier.to_uppercase();
280
281 Ok(SnowflakeApi::new(
282 Arc::clone(&connection),
283 session,
284 account_identifier,
285 ))
286 }
287}
288
/// Client used to execute SQL statements against Snowflake.
pub struct SnowflakeApi {
    /// Connection used for API requests and chunk downloads; shared with the session.
    connection: Arc<Connection>,
    /// Session that supplies the auth token for each request.
    session: Session,
    /// Account identifier sent with every request (upper-cased by the
    /// constructors in this file other than `new`).
    account_identifier: String,
}
295
296impl SnowflakeApi {
297 pub fn new(connection: Arc<Connection>, session: Session, account_identifier: String) -> Self {
299 Self {
300 connection,
301 session,
302 account_identifier,
303 }
304 }
305 pub fn with_password_auth(
307 account_identifier: &str,
308 warehouse: Option<&str>,
309 database: Option<&str>,
310 schema: Option<&str>,
311 username: &str,
312 role: Option<&str>,
313 password: &str,
314 ) -> Result<Self, SnowflakeApiError> {
315 let connection = Arc::new(Connection::new()?);
316
317 let session = Session::password_auth(
318 Arc::clone(&connection),
319 account_identifier,
320 warehouse,
321 database,
322 schema,
323 username,
324 role,
325 password,
326 );
327
328 let account_identifier = account_identifier.to_uppercase();
329 Ok(Self::new(
330 Arc::clone(&connection),
331 session,
332 account_identifier,
333 ))
334 }
335
336 pub fn with_certificate_auth(
338 account_identifier: &str,
339 warehouse: Option<&str>,
340 database: Option<&str>,
341 schema: Option<&str>,
342 username: &str,
343 role: Option<&str>,
344 private_key_pem: &str,
345 ) -> Result<Self, SnowflakeApiError> {
346 let connection = Arc::new(Connection::new()?);
347
348 let session = Session::cert_auth(
349 Arc::clone(&connection),
350 account_identifier,
351 warehouse,
352 database,
353 schema,
354 username,
355 role,
356 private_key_pem,
357 );
358
359 let account_identifier = account_identifier.to_uppercase();
360 Ok(Self::new(
361 Arc::clone(&connection),
362 session,
363 account_identifier,
364 ))
365 }
366
367 pub fn from_env() -> Result<Self, SnowflakeApiError> {
368 SnowflakeApiBuilder::new(AuthArgs::from_env()?).build()
369 }
370
371 pub async fn close_session(&mut self) -> Result<(), SnowflakeApiError> {
375 self.session.close().await?;
376 Ok(())
377 }
378
379 pub async fn exec(&self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
382 let raw = self.exec_raw(sql).await?;
383 let res = raw.deserialize_arrow()?;
384 Ok(res)
385 }
386
387 pub async fn exec_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
391 let put_re = Regex::new(r"(?i)^(?:/\*.*\*/\s*)*put\s+").unwrap();
392
393 if put_re.is_match(sql) {
395 log::info!("Detected PUT query");
396 self.exec_put(sql).await.map(|()| RawQueryResult::Empty)
397 } else {
398 self.exec_arrow_raw(sql).await
399 }
400 }
401
402 async fn exec_put(&self, sql: &str) -> Result<(), SnowflakeApiError> {
403 let resp = self
404 .run_sql::<ExecResponse>(sql, QueryType::JsonQuery)
405 .await?;
406 log::debug!("Got PUT response: {resp:?}");
407
408 match resp {
409 ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
410 ExecResponse::PutGet(pg) => put::put(pg).await,
411 ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
412 e.data.error_code,
413 e.message.unwrap_or_default(),
414 )),
415 }
416 }
417
418 #[cfg(debug_assertions)]
420 pub async fn exec_response(&mut self, sql: &str) -> Result<ExecResponse, SnowflakeApiError> {
421 self.run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
422 .await
423 }
424
425 #[cfg(debug_assertions)]
427 pub async fn exec_json(&mut self, sql: &str) -> Result<serde_json::Value, SnowflakeApiError> {
428 self.run_sql::<serde_json::Value>(sql, QueryType::JsonQuery)
429 .await
430 }
431
432 async fn exec_arrow_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
433 let resp = self
434 .run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
435 .await?;
436 log::debug!("Got query response: {resp:?}");
437
438 let resp = match resp {
439 ExecResponse::Query(qr) => Ok(qr),
441 ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse),
442 ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
443 e.data.error_code,
444 e.message.unwrap_or_default(),
445 )),
446 }?;
447
448 if resp.data.returned == 0 {
451 log::debug!("Got response with 0 rows");
452 Ok(RawQueryResult::Empty)
453 } else if let Some(value) = resp.data.rowset {
454 log::debug!("Got JSON response");
455 Ok(RawQueryResult::Json(JsonResult {
459 value,
460 schema: resp.data.rowtype.into_iter().map(Into::into).collect(),
461 }))
462 } else if let Some(base64) = resp.data.rowset_base64 {
463 let mut chunks = try_join_all(resp.data.chunks.iter().map(|chunk| {
465 self.connection
466 .get_chunk(&chunk.url, &resp.data.chunk_headers)
467 }))
468 .await?;
469
470 if !base64.is_empty() {
473 log::debug!("Got base64 encoded response");
474 let bytes = Bytes::from(base64::engine::general_purpose::STANDARD.decode(base64)?);
475 chunks.push(bytes);
476 }
477
478 Ok(RawQueryResult::Bytes(chunks))
479 } else {
480 Err(SnowflakeApiError::BrokenResponse)
481 }
482 }
483
484 async fn run_sql<R: serde::de::DeserializeOwned>(
485 &self,
486 sql_text: &str,
487 query_type: QueryType,
488 ) -> Result<R, SnowflakeApiError> {
489 log::debug!("Executing: {sql_text}");
490
491 let parts = self.session.get_token().await?;
492
493 let body = ExecRequest {
494 sql_text: sql_text.to_string(),
495 async_exec: false,
496 sequence_id: parts.sequence_id,
497 is_internal: false,
498 };
499
500 let resp = self
501 .connection
502 .request::<R>(
503 query_type,
504 &self.account_identifier,
505 &[],
506 Some(&parts.session_token_auth_header),
507 body,
508 )
509 .await?;
510
511 Ok(resp)
512 }
513}