#![doc(
issue_tracker_base_url = "https://github.com/mycelial/snowflake-rs/issues",
test(no_crate_inject)
)]
#![doc = include_str!("../README.md")]
#![warn(clippy::all, clippy::pedantic)]
#![allow(
clippy::must_use_candidate,
clippy::missing_errors_doc,
clippy::module_name_repetitions,
clippy::struct_field_names,
clippy::future_not_send, clippy::missing_panics_doc
)]
use std::fmt::{Display, Formatter};
use std::io;
use std::sync::Arc;
use arrow::error::ArrowError;
use arrow::ipc::reader::StreamReader;
use arrow::record_batch::RecordBatch;
use base64::Engine;
use bytes::{Buf, Bytes};
use futures::future::try_join_all;
use regex::Regex;
use reqwest_middleware::ClientWithMiddleware;
use thiserror::Error;
use responses::ExecResponse;
use session::{AuthError, Session};
use crate::connection::QueryType;
use crate::connection::{Connection, ConnectionError};
use crate::requests::ExecRequest;
use crate::responses::{ExecResponseRowType, SnowflakeType};
use crate::session::AuthError::MissingEnvArgument;
pub mod connection;
#[cfg(feature = "polars")]
mod polars;
mod put;
mod requests;
mod responses;
mod session;
#[derive(Error, Debug)]
pub enum SnowflakeApiError {
#[error(transparent)]
RequestError(#[from] ConnectionError),
#[error(transparent)]
AuthError(#[from] AuthError),
#[error(transparent)]
ResponseDeserializationError(#[from] base64::DecodeError),
#[error(transparent)]
ArrowError(#[from] arrow::error::ArrowError),
#[error("S3 bucket path in PUT request is invalid: `{0}`")]
InvalidBucketPath(String),
#[error("Couldn't extract filename from the local path: `{0}`")]
InvalidLocalPath(String),
#[error(transparent)]
LocalIoError(#[from] io::Error),
#[error(transparent)]
ObjectStoreError(#[from] object_store::Error),
#[error(transparent)]
ObjectStorePathError(#[from] object_store::path::Error),
#[error(transparent)]
TokioTaskJoinError(#[from] tokio::task::JoinError),
#[error("Snowflake API error. Code: `{0}`. Message: `{1}`")]
ApiError(String, String),
#[error("Snowflake API empty response could mean that query wasn't executed correctly or API call was faulty")]
EmptyResponse,
#[error("No usable rowsets were included in the response")]
BrokenResponse,
#[error("Following feature is not implemented yet: {0}")]
Unimplemented(String),
#[error("Unexpected API response")]
UnexpectedResponse,
#[error(transparent)]
GlobPatternError(#[from] glob::PatternError),
#[error(transparent)]
GlobError(#[from] glob::GlobError),
}
pub struct JsonResult {
pub value: serde_json::Value,
pub schema: Vec<FieldSchema>,
}
impl Display for JsonResult {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.value)
}
}
pub struct FieldSchema {
pub name: String,
pub type_: SnowflakeType,
pub scale: Option<i64>,
pub precision: Option<i64>,
pub nullable: bool,
}
impl From<ExecResponseRowType> for FieldSchema {
fn from(value: ExecResponseRowType) -> Self {
FieldSchema {
name: value.name,
type_: value.type_,
scale: value.scale,
precision: value.precision,
nullable: value.nullable,
}
}
}
pub enum QueryResult {
Arrow(Vec<RecordBatch>),
Json(JsonResult),
Empty,
}
pub enum RawQueryResult {
Bytes(Vec<Bytes>),
Json(JsonResult),
Empty,
}
impl RawQueryResult {
pub fn deserialize_arrow(self) -> Result<QueryResult, ArrowError> {
match self {
RawQueryResult::Bytes(bytes) => {
Self::flat_bytes_to_batches(bytes).map(QueryResult::Arrow)
}
RawQueryResult::Json(j) => Ok(QueryResult::Json(j)),
RawQueryResult::Empty => Ok(QueryResult::Empty),
}
}
fn flat_bytes_to_batches(bytes: Vec<Bytes>) -> Result<Vec<RecordBatch>, ArrowError> {
let mut res = vec![];
for b in bytes {
let mut batches = Self::bytes_to_batches(b)?;
res.append(&mut batches);
}
Ok(res)
}
fn bytes_to_batches(bytes: Bytes) -> Result<Vec<RecordBatch>, ArrowError> {
let record_batches = StreamReader::try_new_unbuffered(bytes.reader(), None)?;
record_batches.into_iter().collect()
}
}
pub struct AuthArgs {
pub account_identifier: String,
pub warehouse: Option<String>,
pub database: Option<String>,
pub schema: Option<String>,
pub username: String,
pub role: Option<String>,
pub auth_type: AuthType,
}
impl AuthArgs {
pub fn from_env() -> Result<AuthArgs, SnowflakeApiError> {
let auth_type = if let Ok(password) = std::env::var("SNOWFLAKE_PASSWORD") {
Ok(AuthType::Password(PasswordArgs { password }))
} else if let Ok(private_key_pem) = std::env::var("SNOWFLAKE_PRIVATE_KEY") {
Ok(AuthType::Certificate(CertificateArgs { private_key_pem }))
} else {
Err(MissingEnvArgument(
"SNOWFLAKE_PASSWORD or SNOWFLAKE_PRIVATE_KEY".to_owned(),
))
};
Ok(AuthArgs {
account_identifier: std::env::var("SNOWFLAKE_ACCOUNT")
.map_err(|_| MissingEnvArgument("SNOWFLAKE_ACCOUNT".to_owned()))?,
warehouse: std::env::var("SNOWLFLAKE_WAREHOUSE").ok(),
database: std::env::var("SNOWFLAKE_DATABASE").ok(),
schema: std::env::var("SNOWFLAKE_SCHEMA").ok(),
username: std::env::var("SNOWFLAKE_USER")
.map_err(|_| MissingEnvArgument("SNOWFLAKE_USER".to_owned()))?,
role: std::env::var("SNOWFLAKE_ROLE").ok(),
auth_type: auth_type?,
})
}
}
pub enum AuthType {
Password(PasswordArgs),
Certificate(CertificateArgs),
}
pub struct PasswordArgs {
pub password: String,
}
pub struct CertificateArgs {
pub private_key_pem: String,
}
#[must_use]
pub struct SnowflakeApiBuilder {
pub auth: AuthArgs,
client: Option<ClientWithMiddleware>,
}
impl SnowflakeApiBuilder {
pub fn new(auth: AuthArgs) -> Self {
Self { auth, client: None }
}
pub fn with_client(mut self, client: ClientWithMiddleware) -> Self {
self.client = Some(client);
self
}
pub fn build(self) -> Result<SnowflakeApi, SnowflakeApiError> {
let connection = match self.client {
Some(client) => Arc::new(Connection::new_with_middware(client)),
None => Arc::new(Connection::new()?),
};
let session = match self.auth.auth_type {
AuthType::Password(args) => Session::password_auth(
Arc::clone(&connection),
&self.auth.account_identifier,
self.auth.warehouse.as_deref(),
self.auth.database.as_deref(),
self.auth.schema.as_deref(),
&self.auth.username,
self.auth.role.as_deref(),
&args.password,
),
AuthType::Certificate(args) => Session::cert_auth(
Arc::clone(&connection),
&self.auth.account_identifier,
self.auth.warehouse.as_deref(),
self.auth.database.as_deref(),
self.auth.schema.as_deref(),
&self.auth.username,
self.auth.role.as_deref(),
&args.private_key_pem,
),
};
let account_identifier = self.auth.account_identifier.to_uppercase();
Ok(SnowflakeApi::new(
Arc::clone(&connection),
session,
account_identifier,
))
}
}
pub struct SnowflakeApi {
connection: Arc<Connection>,
session: Session,
account_identifier: String,
}
impl SnowflakeApi {
pub fn new(connection: Arc<Connection>, session: Session, account_identifier: String) -> Self {
Self {
connection,
session,
account_identifier,
}
}
pub fn with_password_auth(
account_identifier: &str,
warehouse: Option<&str>,
database: Option<&str>,
schema: Option<&str>,
username: &str,
role: Option<&str>,
password: &str,
) -> Result<Self, SnowflakeApiError> {
let connection = Arc::new(Connection::new()?);
let session = Session::password_auth(
Arc::clone(&connection),
account_identifier,
warehouse,
database,
schema,
username,
role,
password,
);
let account_identifier = account_identifier.to_uppercase();
Ok(Self::new(
Arc::clone(&connection),
session,
account_identifier,
))
}
pub fn with_certificate_auth(
account_identifier: &str,
warehouse: Option<&str>,
database: Option<&str>,
schema: Option<&str>,
username: &str,
role: Option<&str>,
private_key_pem: &str,
) -> Result<Self, SnowflakeApiError> {
let connection = Arc::new(Connection::new()?);
let session = Session::cert_auth(
Arc::clone(&connection),
account_identifier,
warehouse,
database,
schema,
username,
role,
private_key_pem,
);
let account_identifier = account_identifier.to_uppercase();
Ok(Self::new(
Arc::clone(&connection),
session,
account_identifier,
))
}
pub fn from_env() -> Result<Self, SnowflakeApiError> {
SnowflakeApiBuilder::new(AuthArgs::from_env()?).build()
}
pub async fn close_session(&mut self) -> Result<(), SnowflakeApiError> {
self.session.close().await?;
Ok(())
}
pub async fn exec(&self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
let raw = self.exec_raw(sql).await?;
let res = raw.deserialize_arrow()?;
Ok(res)
}
pub async fn exec_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
let put_re = Regex::new(r"(?i)^(?:/\*.*\*/\s*)*put\s+").unwrap();
if put_re.is_match(sql) {
log::info!("Detected PUT query");
self.exec_put(sql).await.map(|()| RawQueryResult::Empty)
} else {
self.exec_arrow_raw(sql).await
}
}
async fn exec_put(&self, sql: &str) -> Result<(), SnowflakeApiError> {
let resp = self
.run_sql::<ExecResponse>(sql, QueryType::JsonQuery)
.await?;
log::debug!("Got PUT response: {:?}", resp);
match resp {
ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::PutGet(pg) => put::put(pg).await,
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
e.data.error_code,
e.message.unwrap_or_default(),
)),
}
}
#[cfg(debug_assertions)]
pub async fn exec_response(&mut self, sql: &str) -> Result<ExecResponse, SnowflakeApiError> {
self.run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
.await
}
#[cfg(debug_assertions)]
pub async fn exec_json(&mut self, sql: &str) -> Result<serde_json::Value, SnowflakeApiError> {
self.run_sql::<serde_json::Value>(sql, QueryType::JsonQuery)
.await
}
async fn exec_arrow_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
let resp = self
.run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
.await?;
log::debug!("Got query response: {:?}", resp);
let resp = match resp {
ExecResponse::Query(qr) => Ok(qr),
ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
e.data.error_code,
e.message.unwrap_or_default(),
)),
}?;
if resp.data.returned == 0 {
log::debug!("Got response with 0 rows");
Ok(RawQueryResult::Empty)
} else if let Some(value) = resp.data.rowset {
log::debug!("Got JSON response");
Ok(RawQueryResult::Json(JsonResult {
value,
schema: resp.data.rowtype.into_iter().map(Into::into).collect(),
}))
} else if let Some(base64) = resp.data.rowset_base64 {
let mut chunks = try_join_all(resp.data.chunks.iter().map(|chunk| {
self.connection
.get_chunk(&chunk.url, &resp.data.chunk_headers)
}))
.await?;
if !base64.is_empty() {
log::debug!("Got base64 encoded response");
let bytes = Bytes::from(base64::engine::general_purpose::STANDARD.decode(base64)?);
chunks.push(bytes);
}
Ok(RawQueryResult::Bytes(chunks))
} else {
Err(SnowflakeApiError::BrokenResponse)
}
}
async fn run_sql<R: serde::de::DeserializeOwned>(
&self,
sql_text: &str,
query_type: QueryType,
) -> Result<R, SnowflakeApiError> {
log::debug!("Executing: {}", sql_text);
let parts = self.session.get_token().await?;
let body = ExecRequest {
sql_text: sql_text.to_string(),
async_exec: false,
sequence_id: parts.sequence_id,
is_internal: false,
};
let resp = self
.connection
.request::<R>(
query_type,
&self.account_identifier,
&[],
Some(&parts.session_token_auth_header),
body,
)
.await?;
Ok(resp)
}
}