snowflake_api/
connection.rs

1use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
2use reqwest_middleware::ClientWithMiddleware;
3use reqwest_retry::policies::ExponentialBackoff;
4use reqwest_retry::RetryTransientMiddleware;
5use std::collections::HashMap;
6use std::time::{SystemTime, UNIX_EPOCH};
7use thiserror::Error;
8use url::Url;
9use uuid::Uuid;
10
11#[derive(Error, Debug)]
12pub enum ConnectionError {
13    #[error(transparent)]
14    RequestError(#[from] reqwest::Error),
15
16    #[error(transparent)]
17    RequestMiddlewareError(#[from] reqwest_middleware::Error),
18
19    #[error(transparent)]
20    UrlParsing(#[from] url::ParseError),
21
22    #[error(transparent)]
23    Deserialization(#[from] serde_json::Error),
24
25    #[error(transparent)]
26    InvalidHeader(#[from] header::InvalidHeaderValue),
27}
28
/// Routing information for one kind of API request.
///
/// The Snowflake API serves different request kinds from different endpoint
/// paths and answers them with different MIME types, so each request kind is
/// described by its own pair of values.
struct QueryContext {
    /// Endpoint path, appended to the account's base URL.
    path: &'static str,
    /// MIME type sent in the request's `Accept` header.
    accept_mime: &'static str,
}
35
/// The kinds of requests a `Connection` can send.
///
/// Each variant is mapped to an endpoint path and an `Accept` MIME type by
/// `QueryType::query_context`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryType {
    /// Session login request.
    LoginRequest,
    /// Session token request.
    TokenRequest,
    /// Request sent to the `session` endpoint to close a session.
    CloseSession,
    /// Query request that asks for an `application/json` response.
    JsonQuery,
    /// Query request that asks for an `application/snowflake` response
    /// (presumably Arrow-encoded, per the variant name).
    ArrowQuery,
}
43
44impl QueryType {
45    const fn query_context(&self) -> QueryContext {
46        match self {
47            Self::LoginRequest => QueryContext {
48                path: "session/v1/login-request",
49                accept_mime: "application/json",
50            },
51            Self::TokenRequest => QueryContext {
52                path: "/session/token-request",
53                accept_mime: "application/snowflake",
54            },
55            Self::CloseSession => QueryContext {
56                path: "session",
57                accept_mime: "application/snowflake",
58            },
59            Self::JsonQuery => QueryContext {
60                path: "queries/v1/query-request",
61                accept_mime: "application/json",
62            },
63            Self::ArrowQuery => QueryContext {
64                path: "queries/v1/query-request",
65                accept_mime: "application/snowflake",
66            },
67        }
68    }
69}
70
71/// Connection pool
72/// Minimal session will have at least 2 requests - login and query
73pub struct Connection {
74    // no need for Arc as it's already inside the reqwest client
75    client: ClientWithMiddleware,
76}
77
78impl Connection {
79    pub fn new() -> Result<Self, ConnectionError> {
80        let client = Self::default_client_builder()?;
81
82        Ok(Self::new_with_middware(client.build()))
83    }
84
85    /// Allow a user to provide their own middleware
86    ///
87    /// Users can provide their own middleware to the connection like this:
88    /// ```rust
89    /// use snowflake_api::connection::Connection;
90    /// let mut client = Connection::default_client_builder();
91    ///  // modify the client builder here
92    /// let connection = Connection::new_with_middware(client.unwrap().build());
93    /// ```
94    /// This is not intended to be called directly, but is used by `SnowflakeApiBuilder::with_client`
95    pub fn new_with_middware(client: ClientWithMiddleware) -> Self {
96        Self { client }
97    }
98
99    pub fn default_client_builder() -> Result<reqwest_middleware::ClientBuilder, ConnectionError> {
100        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
101
102        let client = reqwest::ClientBuilder::new()
103            .user_agent("Rust/0.0.1")
104            .gzip(true)
105            .referer(false);
106
107        #[cfg(debug_assertions)]
108        let client = client.connection_verbose(true);
109
110        let client = client.build()?;
111
112        Ok(reqwest_middleware::ClientBuilder::new(client)
113            .with(RetryTransientMiddleware::new_with_policy(retry_policy)))
114    }
115
116    /// Perform request of given query type with extra body or parameters
117    // todo: implement soft error handling
118    // todo: is there better way to not repeat myself?
119    pub async fn request<R: serde::de::DeserializeOwned>(
120        &self,
121        query_type: QueryType,
122        account_identifier: &str,
123        extra_get_params: &[(&str, &str)],
124        auth: Option<&str>,
125        body: impl serde::Serialize,
126    ) -> Result<R, ConnectionError> {
127        let context = query_type.query_context();
128
129        let request_id = Uuid::new_v4();
130        let request_guid = Uuid::new_v4();
131        let client_start_time = SystemTime::now()
132            .duration_since(UNIX_EPOCH)
133            .unwrap()
134            .as_secs()
135            .to_string();
136        // fixme: update uuid's on the retry
137        let request_id = request_id.to_string();
138        let request_guid = request_guid.to_string();
139
140        let mut get_params = vec![
141            ("clientStartTime", client_start_time.as_str()),
142            ("requestId", request_id.as_str()),
143            ("request_guid", request_guid.as_str()),
144        ];
145        get_params.extend_from_slice(extra_get_params);
146
147        let url = format!(
148            "https://{}.snowflakecomputing.com/{}",
149            &account_identifier, context.path
150        );
151        let url = Url::parse_with_params(&url, get_params)?;
152
153        let mut headers = HeaderMap::new();
154
155        headers.append(
156            header::ACCEPT,
157            HeaderValue::from_static(context.accept_mime),
158        );
159        if let Some(auth) = auth {
160            let mut auth_val = HeaderValue::from_str(auth)?;
161            auth_val.set_sensitive(true);
162            headers.append(header::AUTHORIZATION, auth_val);
163        }
164
165        // todo: persist client to use connection polling
166        let resp = self
167            .client
168            .post(url)
169            .headers(headers)
170            .json(&body)
171            .send()
172            .await?;
173
174        Ok(resp.json::<R>().await?)
175    }
176
177    pub async fn get_chunk(
178        &self,
179        url: &str,
180        headers: &HashMap<String, String>,
181    ) -> Result<bytes::Bytes, ConnectionError> {
182        let mut header_map = HeaderMap::new();
183        for (k, v) in headers {
184            header_map.insert(
185                HeaderName::from_bytes(k.as_bytes()).unwrap(),
186                HeaderValue::from_bytes(v.as_bytes()).unwrap(),
187            );
188        }
189        let bytes = self
190            .client
191            .get(url)
192            .headers(header_map)
193            .send()
194            .await?
195            .bytes()
196            .await?;
197        Ok(bytes)
198    }
199}