snowflake_api/
lib.rs

1#![doc(
2    issue_tracker_base_url = "https://github.com/mycelial/snowflake-rs/issues",
3    test(no_crate_inject)
4)]
5#![doc = include_str!("../README.md")]
6#![warn(clippy::all, clippy::pedantic)]
7#![allow(
8clippy::must_use_candidate,
9clippy::missing_errors_doc,
10clippy::module_name_repetitions,
11clippy::struct_field_names,
12clippy::future_not_send, // This one seems like something we should eventually fix
13clippy::missing_panics_doc
14)]
15
16use std::fmt::{Display, Formatter};
17use std::io;
18use std::sync::Arc;
19
20use arrow::error::ArrowError;
21use arrow::ipc::reader::StreamReader;
22use arrow::record_batch::RecordBatch;
23use base64::Engine;
24use bytes::{Buf, Bytes};
25use futures::future::try_join_all;
26use regex::Regex;
27use reqwest_middleware::ClientWithMiddleware;
28use thiserror::Error;
29
30use responses::ExecResponse;
31use session::{AuthError, Session};
32
33use crate::connection::QueryType;
34use crate::connection::{Connection, ConnectionError};
35use crate::requests::ExecRequest;
36use crate::responses::{ExecResponseRowType, SnowflakeType};
37use crate::session::AuthError::MissingEnvArgument;
38
39pub mod connection;
40#[cfg(feature = "polars")]
41mod polars;
42mod put;
43mod requests;
44mod responses;
45mod session;
46
/// Error type returned by the public functions of this crate.
///
/// Variants marked `transparent` forward the display output and source of the
/// error they wrap; the remaining variants carry their own message.
#[derive(Error, Debug)]
pub enum SnowflakeApiError {
    /// A [`ConnectionError`] raised by the `connection` module.
    #[error(transparent)]
    RequestError(#[from] ConnectionError),

    /// An [`AuthError`] raised by the `session` module, e.g. a missing
    /// environment variable in [`AuthArgs::from_env`].
    #[error(transparent)]
    AuthError(#[from] AuthError),

    /// The base64-encoded rowset of a response could not be decoded.
    #[error(transparent)]
    ResponseDeserializationError(#[from] base64::DecodeError),

    /// Arrow failed while reading an IPC payload into record batches.
    #[error(transparent)]
    ArrowError(#[from] arrow::error::ArrowError),

    /// Carries the offending bucket path.
    // NOTE(review): not constructed in this file; presumably raised by the `put` module.
    #[error("S3 bucket path in PUT request is invalid: `{0}`")]
    InvalidBucketPath(String),

    /// Carries the offending local path.
    // NOTE(review): not constructed in this file; presumably raised by the `put` module.
    #[error("Couldn't extract filename from the local path: `{0}`")]
    InvalidLocalPath(String),

    /// A local I/O operation failed.
    #[error(transparent)]
    LocalIoError(#[from] io::Error),

    /// An error reported by the `object_store` crate.
    #[error(transparent)]
    ObjectStoreError(#[from] object_store::Error),

    /// A path was rejected by the `object_store` crate.
    #[error(transparent)]
    ObjectStorePathError(#[from] object_store::path::Error),

    /// A spawned Tokio task could not be joined.
    #[error(transparent)]
    TokioTaskJoinError(#[from] tokio::task::JoinError),

    /// The API answered with an error response.
    /// Fields are the error code and the message (empty when the response had none).
    #[error("Snowflake API error. Code: `{0}`. Message: `{1}`")]
    ApiError(String, String),

    /// The API returned an empty response.
    // NOTE(review): not constructed in this file.
    #[error("Snowflake API empty response could mean that query wasn't executed correctly or API call was faulty")]
    EmptyResponse,

    /// The response reported rows but contained neither a JSON rowset nor a
    /// base64-encoded rowset.
    #[error("No usable rowsets were included in the response")]
    BrokenResponse,

    /// The requested feature is not supported yet; carries a description of it.
    #[error("Following feature is not implemented yet: {0}")]
    Unimplemented(String),

    /// The response kind did not match the statement kind, e.g. a query
    /// response to a PUT statement or a PUT/GET response to a query.
    #[error("Unexpected API response")]
    UnexpectedResponse,

    /// A glob pattern could not be parsed.
    #[error(transparent)]
    GlobPatternError(#[from] glob::PatternError),

    /// Iterating over glob matches failed.
    #[error(transparent)]
    GlobError(#[from] glob::GlobError),
}
100
/// JSON-encoded query result together with the schema of its columns.
///
/// Even if Arrow is specified as a return type, non-select queries
/// will return a JSON array of arrays: `[[42, "answer"], [43, "non-answer"]]`.
pub struct JsonResult {
    // todo: can it _only_ be a json array of arrays or something else too?
    /// The rowset exactly as it was deserialized from the response.
    pub value: serde_json::Value,
    /// Column descriptions; field ordering matches the array ordering.
    pub schema: Vec<FieldSchema>,
}
109
110impl Display for JsonResult {
111    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
112        write!(f, "{}", self.value)
113    }
114}
115
/// Description of a single result column.
///
/// Based on the [`ExecResponseRowType`]; every field is taken over verbatim
/// from the response row type of the same name.
pub struct FieldSchema {
    /// Column name.
    pub name: String,
    // todo: is it a good idea to expose internal response struct to the user?
    /// Column type as reported in the response.
    pub type_: SnowflakeType,
    /// Scale, when the response provides one.
    pub scale: Option<i64>,
    /// Precision, when the response provides one.
    pub precision: Option<i64>,
    /// Nullability flag as reported in the response.
    pub nullable: bool,
}
125
126impl From<ExecResponseRowType> for FieldSchema {
127    fn from(value: ExecResponseRowType) -> Self {
128        FieldSchema {
129            name: value.name,
130            type_: value.type_,
131            scale: value.scale,
132            precision: value.precision,
133            nullable: value.nullable,
134        }
135    }
136}
137
/// Container for query result.
/// Arrow is returned by-default for all SELECT statements,
/// unless there is session configuration issue or it's a different statement type.
pub enum QueryResult {
    /// Record batches decoded from the Arrow IPC payloads of the response.
    Arrow(Vec<RecordBatch>),
    /// JSON rowset with its schema, passed through as received.
    Json(JsonResult),
    /// No data: the response reported zero rows, or the statement was a PUT.
    Empty,
}
146
/// Raw query result, before any Arrow decoding has taken place.
/// Can be transformed into [`QueryResult`] with [`RawQueryResult::deserialize_arrow`].
pub enum RawQueryResult {
    /// Arrow IPC chunks, one `Bytes` buffer per payload
    /// see: <https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc>
    Bytes(Vec<Bytes>),
    /// Json payload is deserialized,
    /// as it's already a part of REST response
    Json(JsonResult),
    /// No data: the response reported zero rows, or the statement was a PUT.
    Empty,
}
158
159impl RawQueryResult {
160    pub fn deserialize_arrow(self) -> Result<QueryResult, ArrowError> {
161        match self {
162            RawQueryResult::Bytes(bytes) => {
163                Self::flat_bytes_to_batches(bytes).map(QueryResult::Arrow)
164            }
165            RawQueryResult::Json(j) => Ok(QueryResult::Json(j)),
166            RawQueryResult::Empty => Ok(QueryResult::Empty),
167        }
168    }
169
170    fn flat_bytes_to_batches(bytes: Vec<Bytes>) -> Result<Vec<RecordBatch>, ArrowError> {
171        let mut res = vec![];
172        for b in bytes {
173            let mut batches = Self::bytes_to_batches(b)?;
174            res.append(&mut batches);
175        }
176        Ok(res)
177    }
178
179    fn bytes_to_batches(bytes: Bytes) -> Result<Vec<RecordBatch>, ArrowError> {
180        let record_batches = StreamReader::try_new(bytes.reader(), None)?;
181        record_batches.into_iter().collect()
182    }
183}
184
/// Everything needed to open a session: account, user, optional session
/// defaults and the credentials themselves.
///
/// Consumed by [`SnowflakeApiBuilder`]; can be filled from the environment
/// with [`AuthArgs::from_env`].
pub struct AuthArgs {
    /// Account identifier (required).
    pub account_identifier: String,
    /// Warehouse to use, if any.
    pub warehouse: Option<String>,
    /// Database to use, if any.
    pub database: Option<String>,
    /// Schema to use, if any.
    pub schema: Option<String>,
    /// User name (required).
    pub username: String,
    /// Role to use, if any.
    pub role: Option<String>,
    /// Credentials: either a password or a private key.
    pub auth_type: AuthType,
}
194
195impl AuthArgs {
196    pub fn from_env() -> Result<AuthArgs, SnowflakeApiError> {
197        let auth_type = if let Ok(password) = std::env::var("SNOWFLAKE_PASSWORD") {
198            Ok(AuthType::Password(PasswordArgs { password }))
199        } else if let Ok(private_key_pem) = std::env::var("SNOWFLAKE_PRIVATE_KEY") {
200            Ok(AuthType::Certificate(CertificateArgs { private_key_pem }))
201        } else {
202            Err(MissingEnvArgument(
203                "SNOWFLAKE_PASSWORD or SNOWFLAKE_PRIVATE_KEY".to_owned(),
204            ))
205        };
206
207        Ok(AuthArgs {
208            account_identifier: std::env::var("SNOWFLAKE_ACCOUNT")
209                .map_err(|_| MissingEnvArgument("SNOWFLAKE_ACCOUNT".to_owned()))?,
210            warehouse: std::env::var("SNOWLFLAKE_WAREHOUSE").ok(),
211            database: std::env::var("SNOWFLAKE_DATABASE").ok(),
212            schema: std::env::var("SNOWFLAKE_SCHEMA").ok(),
213            username: std::env::var("SNOWFLAKE_USER")
214                .map_err(|_| MissingEnvArgument("SNOWFLAKE_USER".to_owned()))?,
215            role: std::env::var("SNOWFLAKE_ROLE").ok(),
216            auth_type: auth_type?,
217        })
218    }
219}
220
/// Credentials used to authenticate the session.
pub enum AuthType {
    /// Authenticate with a password.
    Password(PasswordArgs),
    /// Authenticate with a PEM-encoded private key.
    Certificate(CertificateArgs),
}
225
/// Payload of [`AuthType::Password`].
pub struct PasswordArgs {
    /// The user's password, in plain text.
    pub password: String,
}
229
/// Payload of [`AuthType::Certificate`].
pub struct CertificateArgs {
    /// Private key in PEM format.
    pub private_key_pem: String,
}
233
/// Builder for [`SnowflakeApi`].
///
/// Holds the authentication arguments and, optionally, a caller-supplied HTTP
/// client with middleware; [`SnowflakeApiBuilder::build`] turns it into an API handle.
#[must_use]
pub struct SnowflakeApiBuilder {
    /// Authentication arguments the session will be created from.
    pub auth: AuthArgs,
    // `None` means `build` creates the default connection.
    client: Option<ClientWithMiddleware>,
}
239
240impl SnowflakeApiBuilder {
241    pub fn new(auth: AuthArgs) -> Self {
242        Self { auth, client: None }
243    }
244
245    pub fn with_client(mut self, client: ClientWithMiddleware) -> Self {
246        self.client = Some(client);
247        self
248    }
249
250    pub fn build(self) -> Result<SnowflakeApi, SnowflakeApiError> {
251        let connection = match self.client {
252            Some(client) => Arc::new(Connection::new_with_middware(client)),
253            None => Arc::new(Connection::new()?),
254        };
255
256        let session = match self.auth.auth_type {
257            AuthType::Password(args) => Session::password_auth(
258                Arc::clone(&connection),
259                &self.auth.account_identifier,
260                self.auth.warehouse.as_deref(),
261                self.auth.database.as_deref(),
262                self.auth.schema.as_deref(),
263                &self.auth.username,
264                self.auth.role.as_deref(),
265                &args.password,
266            ),
267            AuthType::Certificate(args) => Session::cert_auth(
268                Arc::clone(&connection),
269                &self.auth.account_identifier,
270                self.auth.warehouse.as_deref(),
271                self.auth.database.as_deref(),
272                self.auth.schema.as_deref(),
273                &self.auth.username,
274                self.auth.role.as_deref(),
275                &args.private_key_pem,
276            ),
277        };
278
279        let account_identifier = self.auth.account_identifier.to_uppercase();
280
281        Ok(SnowflakeApi::new(
282            Arc::clone(&connection),
283            session,
284            account_identifier,
285        ))
286    }
287}
288
/// Snowflake API, keeps connection pool and manages session for you
pub struct SnowflakeApi {
    // Shared with the session; used for query requests and chunk downloads.
    connection: Arc<Connection>,
    // Supplies the token and sequence id for every request.
    session: Session,
    // Sent with every request. The constructors in this file other than `new`
    // store it upper-cased.
    account_identifier: String,
}
295
296impl SnowflakeApi {
297    /// Create a new `SnowflakeApi` object with an existing connection and session.
298    pub fn new(connection: Arc<Connection>, session: Session, account_identifier: String) -> Self {
299        Self {
300            connection,
301            session,
302            account_identifier,
303        }
304    }
305    /// Initialize object with password auth. Authentication happens on the first request.
306    pub fn with_password_auth(
307        account_identifier: &str,
308        warehouse: Option<&str>,
309        database: Option<&str>,
310        schema: Option<&str>,
311        username: &str,
312        role: Option<&str>,
313        password: &str,
314    ) -> Result<Self, SnowflakeApiError> {
315        let connection = Arc::new(Connection::new()?);
316
317        let session = Session::password_auth(
318            Arc::clone(&connection),
319            account_identifier,
320            warehouse,
321            database,
322            schema,
323            username,
324            role,
325            password,
326        );
327
328        let account_identifier = account_identifier.to_uppercase();
329        Ok(Self::new(
330            Arc::clone(&connection),
331            session,
332            account_identifier,
333        ))
334    }
335
336    /// Initialize object with private certificate auth. Authentication happens on the first request.
337    pub fn with_certificate_auth(
338        account_identifier: &str,
339        warehouse: Option<&str>,
340        database: Option<&str>,
341        schema: Option<&str>,
342        username: &str,
343        role: Option<&str>,
344        private_key_pem: &str,
345    ) -> Result<Self, SnowflakeApiError> {
346        let connection = Arc::new(Connection::new()?);
347
348        let session = Session::cert_auth(
349            Arc::clone(&connection),
350            account_identifier,
351            warehouse,
352            database,
353            schema,
354            username,
355            role,
356            private_key_pem,
357        );
358
359        let account_identifier = account_identifier.to_uppercase();
360        Ok(Self::new(
361            Arc::clone(&connection),
362            session,
363            account_identifier,
364        ))
365    }
366
367    pub fn from_env() -> Result<Self, SnowflakeApiError> {
368        SnowflakeApiBuilder::new(AuthArgs::from_env()?).build()
369    }
370
371    /// Closes the current session, this is necessary to clean up temporary objects (tables, functions, etc)
372    /// which are Snowflake session dependent.
373    /// If another request is made the new session will be initiated.
374    pub async fn close_session(&mut self) -> Result<(), SnowflakeApiError> {
375        self.session.close().await?;
376        Ok(())
377    }
378
379    /// Execute a single query against API.
380    /// If statement is PUT, then file will be uploaded to the Snowflake-managed storage
381    pub async fn exec(&self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
382        let raw = self.exec_raw(sql).await?;
383        let res = raw.deserialize_arrow()?;
384        Ok(res)
385    }
386
387    /// Executes a single query against API.
388    /// If statement is PUT, then file will be uploaded to the Snowflake-managed storage
389    /// Returns raw bytes in the Arrow response
390    pub async fn exec_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
391        let put_re = Regex::new(r"(?i)^(?:/\*.*\*/\s*)*put\s+").unwrap();
392
393        // put commands go through a different flow and result is side-effect
394        if put_re.is_match(sql) {
395            log::info!("Detected PUT query");
396            self.exec_put(sql).await.map(|()| RawQueryResult::Empty)
397        } else {
398            self.exec_arrow_raw(sql).await
399        }
400    }
401
402    async fn exec_put(&self, sql: &str) -> Result<(), SnowflakeApiError> {
403        let resp = self
404            .run_sql::<ExecResponse>(sql, QueryType::JsonQuery)
405            .await?;
406        log::debug!("Got PUT response: {resp:?}");
407
408        match resp {
409            ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
410            ExecResponse::PutGet(pg) => put::put(pg).await,
411            ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
412                e.data.error_code,
413                e.message.unwrap_or_default(),
414            )),
415        }
416    }
417
418    /// Useful for debugging to get the straight query response
419    #[cfg(debug_assertions)]
420    pub async fn exec_response(&mut self, sql: &str) -> Result<ExecResponse, SnowflakeApiError> {
421        self.run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
422            .await
423    }
424
425    /// Useful for debugging to get raw JSON response
426    #[cfg(debug_assertions)]
427    pub async fn exec_json(&mut self, sql: &str) -> Result<serde_json::Value, SnowflakeApiError> {
428        self.run_sql::<serde_json::Value>(sql, QueryType::JsonQuery)
429            .await
430    }
431
432    async fn exec_arrow_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
433        let resp = self
434            .run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
435            .await?;
436        log::debug!("Got query response: {resp:?}");
437
438        let resp = match resp {
439            // processable response
440            ExecResponse::Query(qr) => Ok(qr),
441            ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse),
442            ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
443                e.data.error_code,
444                e.message.unwrap_or_default(),
445            )),
446        }?;
447
448        // if response was empty, base64 data is empty string
449        // todo: still return empty arrow batch with proper schema? (schema always included)
450        if resp.data.returned == 0 {
451            log::debug!("Got response with 0 rows");
452            Ok(RawQueryResult::Empty)
453        } else if let Some(value) = resp.data.rowset {
454            log::debug!("Got JSON response");
455            // NOTE: json response could be chunked too. however, go clients should receive arrow by-default,
456            // unless user sets session variable to return json. This case was added for debugging and status
457            // information being passed through that fields.
458            Ok(RawQueryResult::Json(JsonResult {
459                value,
460                schema: resp.data.rowtype.into_iter().map(Into::into).collect(),
461            }))
462        } else if let Some(base64) = resp.data.rowset_base64 {
463            // fixme: is it possible to give streaming interface?
464            let mut chunks = try_join_all(resp.data.chunks.iter().map(|chunk| {
465                self.connection
466                    .get_chunk(&chunk.url, &resp.data.chunk_headers)
467            }))
468            .await?;
469
470            // fixme: should base64 chunk go first?
471            // fixme: if response is chunked is it both base64 + chunks or just chunks?
472            if !base64.is_empty() {
473                log::debug!("Got base64 encoded response");
474                let bytes = Bytes::from(base64::engine::general_purpose::STANDARD.decode(base64)?);
475                chunks.push(bytes);
476            }
477
478            Ok(RawQueryResult::Bytes(chunks))
479        } else {
480            Err(SnowflakeApiError::BrokenResponse)
481        }
482    }
483
484    async fn run_sql<R: serde::de::DeserializeOwned>(
485        &self,
486        sql_text: &str,
487        query_type: QueryType,
488    ) -> Result<R, SnowflakeApiError> {
489        log::debug!("Executing: {sql_text}");
490
491        let parts = self.session.get_token().await?;
492
493        let body = ExecRequest {
494            sql_text: sql_text.to_string(),
495            async_exec: false,
496            sequence_id: parts.sequence_id,
497            is_internal: false,
498        };
499
500        let resp = self
501            .connection
502            .request::<R>(
503                query_type,
504                &self.account_identifier,
505                &[],
506                Some(&parts.session_token_auth_header),
507                body,
508            )
509            .await?;
510
511        Ok(resp)
512    }
513}