1use std::{net::ToSocketAddrs, num::NonZeroUsize};
2
3use sqlx_core::{connection::LogSettings, net::tls::CertificateInput};
4
5use super::{
6 error::ExaConfigError, ssl_mode::ExaSslMode, ExaConnectOptions, Login, ProtocolVersion,
7 DEFAULT_CACHE_CAPACITY, DEFAULT_FETCH_SIZE, DEFAULT_PORT,
8};
9use crate::SqlxResult;
10
11#[derive(Clone, Debug)]
13pub struct ExaConnectOptionsBuilder {
14 host: Option<String>,
15 port: u16,
16 ssl_mode: ExaSslMode,
17 ssl_ca: Option<CertificateInput>,
18 ssl_client_cert: Option<CertificateInput>,
19 ssl_client_key: Option<CertificateInput>,
20 statement_cache_capacity: NonZeroUsize,
21 username: Option<String>,
22 password: Option<String>,
23 access_token: Option<String>,
24 refresh_token: Option<String>,
25 schema: Option<String>,
26 protocol_version: ProtocolVersion,
27 fetch_size: usize,
28 query_timeout: u64,
29 compression: bool,
30 feedback_interval: u8,
31}
32
33impl Default for ExaConnectOptionsBuilder {
34 fn default() -> Self {
35 Self {
36 host: None,
37 port: DEFAULT_PORT,
38 ssl_mode: ExaSslMode::default(),
39 ssl_ca: None,
40 ssl_client_cert: None,
41 ssl_client_key: None,
42 statement_cache_capacity: DEFAULT_CACHE_CAPACITY,
43 username: None,
44 password: None,
45 access_token: None,
46 refresh_token: None,
47 schema: None,
48 protocol_version: ProtocolVersion::V3,
49 fetch_size: DEFAULT_FETCH_SIZE,
50 query_timeout: 0,
51 compression: false,
52 feedback_interval: 1,
53 }
54 }
55}
56
57impl ExaConnectOptionsBuilder {
58 pub fn build(self) -> SqlxResult<ExaConnectOptions> {
64 let hostname = self.host.ok_or(ExaConfigError::MissingHost)?;
65 let password = self.password.unwrap_or_default();
66
67 let login = match (self.username, self.access_token, self.refresh_token) {
69 (Some(username), None, None) => Login::Credentials { username, password },
70 (None, Some(access_token), None) => Login::AccessToken { access_token },
71 (None, None, Some(refresh_token)) => Login::RefreshToken { refresh_token },
72 _ => return Err(ExaConfigError::MultipleAuthMethods.into()),
73 };
74
75 let hosts = Self::parse_hostname(hostname);
76 let mut hosts_details = Vec::with_capacity(hosts.len());
77
78 for host in hosts {
79 let addrs = (host.as_str(), self.port).to_socket_addrs()?.collect();
80 hosts_details.push((host, addrs));
81 }
82
83 let opts = ExaConnectOptions {
84 hosts_details,
85 port: self.port,
86 ssl_mode: self.ssl_mode,
87 ssl_ca: self.ssl_ca,
88 ssl_client_cert: self.ssl_client_cert,
89 ssl_client_key: self.ssl_client_key,
90 statement_cache_capacity: self.statement_cache_capacity,
91 login,
92 schema: self.schema,
93 protocol_version: self.protocol_version,
94 fetch_size: self.fetch_size,
95 query_timeout: self.query_timeout,
96 compression: self.compression,
97 feedback_interval: self.feedback_interval,
98 log_settings: LogSettings::default(),
99 };
100
101 Ok(opts)
102 }
103
104 #[must_use = "call build() to get connection options"]
105 pub fn host(mut self, host: String) -> Self {
106 self.host = Some(host);
107 self
108 }
109
110 #[must_use = "call build() to get connection options"]
111 pub fn port(mut self, port: u16) -> Self {
112 self.port = port;
113 self
114 }
115
116 #[must_use = "call build() to get connection options"]
117 pub fn ssl_mode(mut self, ssl_mode: ExaSslMode) -> Self {
118 self.ssl_mode = ssl_mode;
119 self
120 }
121
122 #[must_use = "call build() to get connection options"]
123 pub fn ssl_ca(mut self, ssl_ca: CertificateInput) -> Self {
124 self.ssl_ca = Some(ssl_ca);
125 self
126 }
127
128 #[must_use = "call build() to get connection options"]
129 pub fn ssl_client_cert(mut self, ssl_client_cert: CertificateInput) -> Self {
130 self.ssl_client_cert = Some(ssl_client_cert);
131 self
132 }
133
134 #[must_use = "call build() to get connection options"]
135 pub fn ssl_client_key(mut self, ssl_client_key: CertificateInput) -> Self {
136 self.ssl_client_key = Some(ssl_client_key);
137 self
138 }
139
140 #[must_use = "call build() to get connection options"]
141 pub fn statement_cache_capacity(mut self, capacity: NonZeroUsize) -> Self {
142 self.statement_cache_capacity = capacity;
143 self
144 }
145
146 #[must_use = "call build() to get connection options"]
147 pub fn username(mut self, username: String) -> Self {
148 self.username = Some(username);
149 self
150 }
151
152 #[must_use = "call build() to get connection options"]
153 pub fn password(mut self, password: String) -> Self {
154 self.password = Some(password);
155 self
156 }
157
158 #[must_use = "call build() to get connection options"]
159 pub fn access_token(mut self, access_token: String) -> Self {
160 self.access_token = Some(access_token);
161 self
162 }
163
164 #[must_use = "call build() to get connection options"]
165 pub fn refresh_token(mut self, refresh_token: String) -> Self {
166 self.refresh_token = Some(refresh_token);
167 self
168 }
169
170 #[must_use = "call build() to get connection options"]
171 pub fn schema(mut self, schema: String) -> Self {
172 self.schema = Some(schema);
173 self
174 }
175
176 #[must_use = "call build() to get connection options"]
177 pub fn protocol_version(mut self, protocol_version: ProtocolVersion) -> Self {
178 self.protocol_version = protocol_version;
179 self
180 }
181
182 #[must_use = "call build() to get connection options"]
183 pub fn fetch_size(mut self, fetch_size: usize) -> Self {
184 self.fetch_size = fetch_size;
185 self
186 }
187
188 #[must_use = "call build() to get connection options"]
189 pub fn query_timeout(mut self, query_timeout: u64) -> Self {
190 self.query_timeout = query_timeout;
191 self
192 }
193
194 #[must_use = "call build() to get connection options"]
195 pub fn compression(mut self, compression: bool) -> Self {
196 let feature_flag = cfg!(feature = "compression");
197
198 if feature_flag && !compression {
199 tracing::warn!("compression cannot be enabled without the 'compression' feature");
200 }
201
202 self.compression = compression && feature_flag;
203 self
204 }
205
206 #[must_use = "call build() to get connection options"]
207 pub fn feedback_interval(mut self, feedback_interval: u8) -> Self {
208 self.feedback_interval = feedback_interval;
209 self
210 }
211
212 fn parse_hostname(hostname: String) -> Vec<String> {
219 Self::_parse_hostname(&hostname).unwrap_or_else(|| vec![hostname])
221 }
222
223 #[inline]
227 fn _parse_hostname(hostname: &str) -> Option<Vec<String>> {
228 let mut index_accum = 0;
229
230 let range_idx = loop {
233 let search_str = &hostname[index_accum..];
234
235 let idx = search_str.find("..")?;
237
238 let before_opt = idx
247 .checked_sub(1)
248 .and_then(|i| search_str.as_bytes().get(i));
249
250 let after_opt = search_str.as_bytes().get(idx + 2);
252
253 break match (before_opt, after_opt) {
256 (Some(b), Some(a)) if b.is_ascii_digit() || a.is_ascii_digit() => idx + index_accum,
257 _ => {
258 index_accum += idx + 2;
259 continue;
260 }
261 };
262 };
263
264 let before_range = &hostname[..range_idx];
265 let after_range = &hostname[range_idx + 2..];
266
267 let start_idx = before_range
275 .rfind(|c: char| !c.is_ascii_digit())
276 .map(|i| i + 1)
277 .unwrap_or_default();
278 let end_idx = after_range
279 .find(|c: char| !c.is_ascii_digit())
280 .unwrap_or(after_range.len());
281
282 let (prefix, start_range) = before_range.split_at(start_idx);
284 let (end_range, suffix) = after_range.split_at(end_idx);
285
286 let start = start_range.parse::<usize>().ok()?;
288 let end = end_range.parse::<usize>().ok()?;
289
290 let hosts = (start..=end)
291 .map(|i| format!("{prefix}{i}{suffix}"))
292 .collect();
293
294 Some(hosts)
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::ExaConnectOptionsBuilder;

    /// Runs the hostname range expansion on `hostname`.
    fn expand(hostname: &str) -> Vec<String> {
        ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned())
    }

    /// Asserts that `hostname` is returned untouched as the only host.
    fn assert_unchanged(hostname: &str) {
        assert_eq!(expand(hostname), [hostname]);
    }

    #[test]
    fn test_simple_ip() {
        assert_unchanged("10.10.10.10");
    }

    #[test]
    fn test_ip_range_end() {
        assert_eq!(
            expand("10.10.10.1..3"),
            ["10.10.10.1", "10.10.10.2", "10.10.10.3"]
        );
    }

    #[test]
    fn test_ip_range_start() {
        assert_eq!(
            expand("1..3.10.10.10"),
            ["1.10.10.10", "2.10.10.10", "3.10.10.10"]
        );
    }

    #[test]
    fn test_simple_hostname() {
        assert_unchanged("myhost.com");
    }

    #[test]
    fn test_hostname_with_range() {
        assert_eq!(
            expand("myhost1..4.com"),
            ["myhost1.com", "myhost2.com", "myhost3.com", "myhost4.com"]
        );
    }

    #[test]
    fn test_hostname_with_big_range() {
        assert_eq!(
            expand("myhost125..127.com"),
            ["myhost125.com", "myhost126.com", "myhost127.com"]
        );
    }

    #[test]
    fn test_hostname_with_inverse_range() {
        // An inverted range yields no hosts at all.
        assert!(expand("myhost127..125.com").is_empty());
    }

    #[test]
    fn test_hostname_with_numbers_no_range() {
        assert_unchanged("myhost1.4.com");
    }

    #[test]
    fn test_hostname_with_range_one_numbers() {
        assert_unchanged("myhost1..b.com");
    }

    #[test]
    fn test_hostname_with_range_no_numbers() {
        assert_unchanged("myhosta..b.com");
    }

    #[test]
    fn test_hostname_starts_with_range() {
        assert_unchanged("..myhost.com");
    }

    #[test]
    fn test_hostname_ends_with_range() {
        assert_unchanged("myhost.com..");
    }

    #[test]
    fn test_hostname_real_and_fake_range() {
        // The first `..` has no digits around it and is kept verbatim.
        assert_eq!(
            expand("myhosta..bcdef1..3.com"),
            [
                "myhosta..bcdef1.com",
                "myhosta..bcdef2.com",
                "myhosta..bcdef3.com",
            ]
        );
    }

    #[test]
    fn test_hostname_two_valid_ranges() {
        // Only the first valid range is expanded.
        assert_eq!(
            expand("myhost1..3cdef4..7.com"),
            [
                "myhost1cdef4..7.com",
                "myhost2cdef4..7.com",
                "myhost3cdef4..7.com",
            ]
        );
    }
}
328 #[test]
329 fn test_simple_hostname() {
330 let hostname = "myhost.com";
331
332 let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned());
333 assert_eq!(generated, vec!(hostname));
334 }
335
336 #[test]
337 fn test_hostname_with_range() {
338 let hostname = "myhost1..4.com";
339 let expected = vec!["myhost1.com", "myhost2.com", "myhost3.com", "myhost4.com"];
340
341 let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned());
342 assert_eq!(generated, expected);
343 }
344
345 #[test]
346 fn test_hostname_with_big_range() {
347 let hostname = "myhost125..127.com";
348 let expected = vec!["myhost125.com", "myhost126.com", "myhost127.com"];
349
350 let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned());
351 assert_eq!(generated, expected);
352 }
353
354 #[test]
355 fn test_hostname_with_inverse_range() {
356 let hostname = "myhost127..125.com";
357
358 let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned());
359 assert!(generated.is_empty());
360 }
361
362 #[test]
363 fn test_hostname_with_numbers_no_range() {
364 let hostname = "myhost1.4.com";
365
366 let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned());
367 assert_eq!(generated, vec![hostname]);
368 }
369
370 #[test]
371 fn test_hostname_with_range_one_numbers() {
372 let hostname = "myhost1..b.com";
373
374 let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned());
375 assert_eq!(generated, vec![hostname]);
376 }
377
378 #[test]
379 fn test_hostname_with_range_no_numbers() {
380 let hostname = "myhosta..b.com";
381
382 let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned());
383 assert_eq!(generated, vec![hostname]);
384 }
385
386 #[test]
387 fn test_hostname_starts_with_range() {
388 let hostname = "..myhost.com";
389
390 let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned());
391 assert_eq!(generated, vec![hostname]);
392 }
393
394 #[test]
395 fn test_hostname_ends_with_range() {
396 let hostname = "myhost.com..";
397
398 let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned());
399 assert_eq!(generated, vec![hostname]);
400 }
401
402 #[test]
403 fn test_hostname_real_and_fake_range() {
404 let hostname = "myhosta..bcdef1..3.com";
405 let expected = vec![
406 "myhosta..bcdef1.com",
407 "myhosta..bcdef2.com",
408 "myhosta..bcdef3.com",
409 ];
410
411 let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned());
412 assert_eq!(generated, expected);
413 }
414
415 #[test]
416 fn test_hostname_two_valid_ranges() {
417 let hostname = "myhost1..3cdef4..7.com";
418 let expected = vec![
419 "myhost1cdef4..7.com",
420 "myhost2cdef4..7.com",
421 "myhost3cdef4..7.com",
422 ];
423
424 let generated = ExaConnectOptionsBuilder::parse_hostname(hostname.to_owned());
425 assert_eq!(generated, expected);
426 }
427}