snowflake_api/
connection.rs1use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
2use reqwest_middleware::ClientWithMiddleware;
3use reqwest_retry::policies::ExponentialBackoff;
4use reqwest_retry::RetryTransientMiddleware;
5use std::collections::HashMap;
6use std::time::{SystemTime, UNIX_EPOCH};
7use thiserror::Error;
8use url::Url;
9use uuid::Uuid;
10
11#[derive(Error, Debug)]
12pub enum ConnectionError {
13 #[error(transparent)]
14 RequestError(#[from] reqwest::Error),
15
16 #[error(transparent)]
17 RequestMiddlewareError(#[from] reqwest_middleware::Error),
18
19 #[error(transparent)]
20 UrlParsing(#[from] url::ParseError),
21
22 #[error(transparent)]
23 Deserialization(#[from] serde_json::Error),
24
25 #[error(transparent)]
26 InvalidHeader(#[from] header::InvalidHeaderValue),
27}
28
/// Static, per-request-kind settings resolved from a [`QueryType`].
struct QueryContext {
    /// Endpoint path relative to the host root. It must NOT start with `/`,
    /// because the caller inserts the separator when building the URL.
    path: &'static str,
    /// MIME type to send in the `Accept` header.
    accept_mime: &'static str,
}

/// The kinds of requests that can be issued.
///
/// Each variant selects an endpoint path and an `Accept` MIME type.
pub enum QueryType {
    LoginRequest,
    TokenRequest,
    CloseSession,
    JsonQuery,
    ArrowQuery,
}

impl QueryType {
    /// Returns the endpoint path and `Accept` MIME type for this request kind.
    ///
    /// All paths are relative (no leading slash) so that joining them to
    /// `https://host/` never yields a double slash.
    const fn query_context(&self) -> QueryContext {
        match self {
            Self::LoginRequest => QueryContext {
                path: "session/v1/login-request",
                accept_mime: "application/json",
            },
            Self::TokenRequest => QueryContext {
                // Previously "/session/token-request", which produced
                // "https://host//session/token-request" once joined.
                path: "session/token-request",
                accept_mime: "application/snowflake",
            },
            Self::CloseSession => QueryContext {
                path: "session",
                accept_mime: "application/snowflake",
            },
            Self::JsonQuery => QueryContext {
                path: "queries/v1/query-request",
                accept_mime: "application/json",
            },
            Self::ArrowQuery => QueryContext {
                path: "queries/v1/query-request",
                accept_mime: "application/snowflake",
            },
        }
    }
}
70
71pub struct Connection {
74 client: ClientWithMiddleware,
76}
77
78impl Connection {
79 pub fn new() -> Result<Self, ConnectionError> {
80 let client = Self::default_client_builder()?;
81
82 Ok(Self::new_with_middware(client.build()))
83 }
84
85 pub fn new_with_middware(client: ClientWithMiddleware) -> Self {
96 Self { client }
97 }
98
99 pub fn default_client_builder() -> Result<reqwest_middleware::ClientBuilder, ConnectionError> {
100 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
101
102 let client = reqwest::ClientBuilder::new()
103 .user_agent("Rust/0.0.1")
104 .gzip(true)
105 .referer(false);
106
107 #[cfg(debug_assertions)]
108 let client = client.connection_verbose(true);
109
110 let client = client.build()?;
111
112 Ok(reqwest_middleware::ClientBuilder::new(client)
113 .with(RetryTransientMiddleware::new_with_policy(retry_policy)))
114 }
115
116 pub async fn request<R: serde::de::DeserializeOwned>(
120 &self,
121 query_type: QueryType,
122 account_identifier: &str,
123 extra_get_params: &[(&str, &str)],
124 auth: Option<&str>,
125 body: impl serde::Serialize,
126 ) -> Result<R, ConnectionError> {
127 let context = query_type.query_context();
128
129 let request_id = Uuid::new_v4();
130 let request_guid = Uuid::new_v4();
131 let client_start_time = SystemTime::now()
132 .duration_since(UNIX_EPOCH)
133 .unwrap()
134 .as_secs()
135 .to_string();
136 let request_id = request_id.to_string();
138 let request_guid = request_guid.to_string();
139
140 let mut get_params = vec![
141 ("clientStartTime", client_start_time.as_str()),
142 ("requestId", request_id.as_str()),
143 ("request_guid", request_guid.as_str()),
144 ];
145 get_params.extend_from_slice(extra_get_params);
146
147 let url = format!(
148 "https://{}.snowflakecomputing.com/{}",
149 &account_identifier, context.path
150 );
151 let url = Url::parse_with_params(&url, get_params)?;
152
153 let mut headers = HeaderMap::new();
154
155 headers.append(
156 header::ACCEPT,
157 HeaderValue::from_static(context.accept_mime),
158 );
159 if let Some(auth) = auth {
160 let mut auth_val = HeaderValue::from_str(auth)?;
161 auth_val.set_sensitive(true);
162 headers.append(header::AUTHORIZATION, auth_val);
163 }
164
165 let resp = self
167 .client
168 .post(url)
169 .headers(headers)
170 .json(&body)
171 .send()
172 .await?;
173
174 Ok(resp.json::<R>().await?)
175 }
176
177 pub async fn get_chunk(
178 &self,
179 url: &str,
180 headers: &HashMap<String, String>,
181 ) -> Result<bytes::Bytes, ConnectionError> {
182 let mut header_map = HeaderMap::new();
183 for (k, v) in headers {
184 header_map.insert(
185 HeaderName::from_bytes(k.as_bytes()).unwrap(),
186 HeaderValue::from_bytes(v.as_bytes()).unwrap(),
187 );
188 }
189 let bytes = self
190 .client
191 .get(url)
192 .headers(header_map)
193 .send()
194 .await?
195 .bytes()
196 .await?;
197 Ok(bytes)
198 }
199}