sqlx_exasol/options/builder.rs

1use std::{net::ToSocketAddrs, num::NonZeroUsize};
2
3use sqlx_core::{connection::LogSettings, net::tls::CertificateInput};
4
5use super::{
6    error::ExaConfigError, ssl_mode::ExaSslMode, ExaConnectOptions, Login, ProtocolVersion,
7    DEFAULT_CACHE_CAPACITY, DEFAULT_FETCH_SIZE, DEFAULT_PORT,
8};
9use crate::SqlxResult;
10
11/// Builder for [`ExaConnectOptions`].
12#[derive(Clone, Debug)]
13pub struct ExaConnectOptionsBuilder {
14    host: Option<String>,
15    port: u16,
16    ssl_mode: ExaSslMode,
17    ssl_ca: Option<CertificateInput>,
18    ssl_client_cert: Option<CertificateInput>,
19    ssl_client_key: Option<CertificateInput>,
20    statement_cache_capacity: NonZeroUsize,
21    username: Option<String>,
22    password: Option<String>,
23    access_token: Option<String>,
24    refresh_token: Option<String>,
25    schema: Option<String>,
26    protocol_version: ProtocolVersion,
27    fetch_size: usize,
28    query_timeout: u64,
29    compression: bool,
30    feedback_interval: u8,
31}
32
33impl Default for ExaConnectOptionsBuilder {
34    fn default() -> Self {
35        Self {
36            host: None,
37            port: DEFAULT_PORT,
38            ssl_mode: ExaSslMode::default(),
39            ssl_ca: None,
40            ssl_client_cert: None,
41            ssl_client_key: None,
42            statement_cache_capacity: DEFAULT_CACHE_CAPACITY,
43            username: None,
44            password: None,
45            access_token: None,
46            refresh_token: None,
47            schema: None,
48            protocol_version: ProtocolVersion::V3,
49            fetch_size: DEFAULT_FETCH_SIZE,
50            query_timeout: 0,
51            compression: false,
52            feedback_interval: 1,
53        }
54    }
55}
56
57impl ExaConnectOptionsBuilder {
58    /// Consumes this builder and returns an instance of [`ExaConnectOptions`].
59    ///
60    /// # Errors
61    ///
62    /// Will return an error if resolving the hostname to [`std::net::SocketAddr`] fails.
63    pub fn build(self) -> SqlxResult<ExaConnectOptions> {
64        let hostname = self.host.ok_or(ExaConfigError::MissingHost)?;
65        let password = self.password.unwrap_or_default();
66
67        // Only one authentication method can be used at once
68        let login = match (self.username, self.access_token, self.refresh_token) {
69            (Some(username), None, None) => Login::Credentials { username, password },
70            (None, Some(access_token), None) => Login::AccessToken { access_token },
71            (None, None, Some(refresh_token)) => Login::RefreshToken { refresh_token },
72            _ => return Err(ExaConfigError::MultipleAuthMethods.into()),
73        };
74
75        let hosts = Self::parse_hostname(hostname);
76        let mut hosts_details = Vec::with_capacity(hosts.len());
77
78        for host in hosts {
79            let addrs = (host.as_str(), self.port).to_socket_addrs()?.collect();
80            hosts_details.push((host, addrs));
81        }
82
83        let opts = ExaConnectOptions {
84            hosts_details,
85            port: self.port,
86            ssl_mode: self.ssl_mode,
87            ssl_ca: self.ssl_ca,
88            ssl_client_cert: self.ssl_client_cert,
89            ssl_client_key: self.ssl_client_key,
90            statement_cache_capacity: self.statement_cache_capacity,
91            login,
92            schema: self.schema,
93            protocol_version: self.protocol_version,
94            fetch_size: self.fetch_size,
95            query_timeout: self.query_timeout,
96            compression: self.compression,
97            feedback_interval: self.feedback_interval,
98            log_settings: LogSettings::default(),
99        };
100
101        Ok(opts)
102    }
103
104    #[must_use = "call build() to get connection options"]
105    pub fn host(mut self, host: String) -> Self {
106        self.host = Some(host);
107        self
108    }
109
110    #[must_use = "call build() to get connection options"]
111    pub fn port(mut self, port: u16) -> Self {
112        self.port = port;
113        self
114    }
115
116    #[must_use = "call build() to get connection options"]
117    pub fn ssl_mode(mut self, ssl_mode: ExaSslMode) -> Self {
118        self.ssl_mode = ssl_mode;
119        self
120    }
121
122    #[must_use = "call build() to get connection options"]
123    pub fn ssl_ca(mut self, ssl_ca: CertificateInput) -> Self {
124        self.ssl_ca = Some(ssl_ca);
125        self
126    }
127
128    #[must_use = "call build() to get connection options"]
129    pub fn ssl_client_cert(mut self, ssl_client_cert: CertificateInput) -> Self {
130        self.ssl_client_cert = Some(ssl_client_cert);
131        self
132    }
133
134    #[must_use = "call build() to get connection options"]
135    pub fn ssl_client_key(mut self, ssl_client_key: CertificateInput) -> Self {
136        self.ssl_client_key = Some(ssl_client_key);
137        self
138    }
139
140    #[must_use = "call build() to get connection options"]
141    pub fn statement_cache_capacity(mut self, capacity: NonZeroUsize) -> Self {
142        self.statement_cache_capacity = capacity;
143        self
144    }
145
146    #[must_use = "call build() to get connection options"]
147    pub fn username(mut self, username: String) -> Self {
148        self.username = Some(username);
149        self
150    }
151
152    #[must_use = "call build() to get connection options"]
153    pub fn password(mut self, password: String) -> Self {
154        self.password = Some(password);
155        self
156    }
157
158    #[must_use = "call build() to get connection options"]
159    pub fn access_token(mut self, access_token: String) -> Self {
160        self.access_token = Some(access_token);
161        self
162    }
163
164    #[must_use = "call build() to get connection options"]
165    pub fn refresh_token(mut self, refresh_token: String) -> Self {
166        self.refresh_token = Some(refresh_token);
167        self
168    }
169
170    #[must_use = "call build() to get connection options"]
171    pub fn schema(mut self, schema: String) -> Self {
172        self.schema = Some(schema);
173        self
174    }
175
176    #[must_use = "call build() to get connection options"]
177    pub fn protocol_version(mut self, protocol_version: ProtocolVersion) -> Self {
178        self.protocol_version = protocol_version;
179        self
180    }
181
182    #[must_use = "call build() to get connection options"]
183    pub fn fetch_size(mut self, fetch_size: usize) -> Self {
184        self.fetch_size = fetch_size;
185        self
186    }
187
188    #[must_use = "call build() to get connection options"]
189    pub fn query_timeout(mut self, query_timeout: u64) -> Self {
190        self.query_timeout = query_timeout;
191        self
192    }
193
194    #[must_use = "call build() to get connection options"]
195    pub fn compression(mut self, compression: bool) -> Self {
196        let feature_flag = cfg!(feature = "compression");
197
198        if feature_flag && !compression {
199            tracing::warn!("compression cannot be enabled without the 'compression' feature");
200        }
201
202        self.compression = compression && feature_flag;
203        self
204    }
205
206    #[must_use = "call build() to get connection options"]
207    pub fn feedback_interval(mut self, feedback_interval: u8) -> Self {
208        self.feedback_interval = feedback_interval;
209        self
210    }
211
212    /// Exasol supports host ranges, e.g: hostname4..1.com.
213    /// This method parses the provided host in the connection string and generates one for each
214    /// possible entry in the range.
215    ///
216    /// We do expect the range to be in the ascending order though, so `hostname4..1.com` won't
217    /// work.
218    fn parse_hostname(hostname: String) -> Vec<String> {
219        // If multiple hosts could not be generated, then the given hostname is the only one.
220        Self::_parse_hostname(&hostname).unwrap_or_else(|| vec![hostname])
221    }
222
223    /// This method is used to attempt to generate multiple hosts out of the given hostname.
224    ///
225    /// If that fails, we'll bail early and unwrap the option in a wrapper.
226    #[inline]
227    fn _parse_hostname(hostname: &str) -> Option<Vec<String>> {
228        let mut index_accum = 0;
229
230        // We loop through occurences of ranges (..) and try to find one surrounded by digits.
231        // If that happens, then we break out of the loop with the index of the range occurance.
232        let range_idx = loop {
233            let search_str = &hostname[index_accum..];
234
235            // No range? No problem! Return early.
236            let idx = search_str.find("..")?;
237
238            // While someone actually using something like "..thisismyhostname" in the connection
239            // string would be absolutely insane, it's still somewhat nicer not have this overflow.
240            //
241            // But really, if you read this and your host looks like that, you really should
242            // re-evaluate your taste in domain names.
243            //
244            // In any case, the index points to the range dots.
245            // We want to look before that, hence the substraction.
246            let before_opt = idx
247                .checked_sub(1)
248                .and_then(|i| search_str.as_bytes().get(i));
249
250            // Get the byte after the range dots.
251            let after_opt = search_str.as_bytes().get(idx + 2);
252
253            // Check if the range is surrounded by digits and if so, return its index.
254            // Continue to the next range if not.
255            break match (before_opt, after_opt) {
256                (Some(b), Some(a)) if b.is_ascii_digit() || a.is_ascii_digit() => idx + index_accum,
257                _ => {
258                    index_accum += idx + 2;
259                    continue;
260                }
261            };
262        };
263
264        let before_range = &hostname[..range_idx];
265        let after_range = &hostname[range_idx + 2..];
266
267        // We wanna find the last non-digit character before the range index in the first part of
268        // the hostname and the first non-digit character right after the range dots, in the
269        // second part of the hostname.
270        //
271        // The start is incremented as the index is for the last non-numeric character.
272        //
273        // If no indexes are found, then we consider the beginning/end of string, respectively.
274        let start_idx = before_range
275            .rfind(|c: char| !c.is_ascii_digit())
276            .map(|i| i + 1)
277            .unwrap_or_default();
278        let end_idx = after_range
279            .find(|c: char| !c.is_ascii_digit())
280            .unwrap_or(after_range.len());
281
282        // We split the hostname parts to isolate components.
283        let (prefix, start_range) = before_range.split_at(start_idx);
284        let (end_range, suffix) = after_range.split_at(end_idx);
285
286        // Return the hostname as is if the range boundaries are not integers.
287        let start = start_range.parse::<usize>().ok()?;
288        let end = end_range.parse::<usize>().ok()?;
289
290        let hosts = (start..=end)
291            .map(|i| format!("{prefix}{i}{suffix}"))
292            .collect();
293
294        Some(hosts)
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::ExaConnectOptionsBuilder;

    /// Runs `hostname` through the builder's host-range expansion.
    fn expand(hostname: &str) -> Vec<String> {
        ExaConnectOptionsBuilder::parse_hostname(String::from(hostname))
    }

    /// Asserts that `hostname` comes back verbatim as the only generated host.
    #[track_caller]
    fn assert_unchanged(hostname: &str) {
        assert_eq!(expand(hostname), vec![hostname]);
    }

    #[test]
    fn test_simple_ip() {
        assert_unchanged("10.10.10.10");
    }

    #[test]
    fn test_ip_range_end() {
        assert_eq!(
            expand("10.10.10.1..3"),
            ["10.10.10.1", "10.10.10.2", "10.10.10.3"]
        );
    }

    #[test]
    fn test_ip_range_start() {
        assert_eq!(
            expand("1..3.10.10.10"),
            ["1.10.10.10", "2.10.10.10", "3.10.10.10"]
        );
    }

    #[test]
    fn test_simple_hostname() {
        assert_unchanged("myhost.com");
    }

    #[test]
    fn test_hostname_with_range() {
        assert_eq!(
            expand("myhost1..4.com"),
            ["myhost1.com", "myhost2.com", "myhost3.com", "myhost4.com"]
        );
    }

    #[test]
    fn test_hostname_with_big_range() {
        assert_eq!(
            expand("myhost125..127.com"),
            ["myhost125.com", "myhost126.com", "myhost127.com"]
        );
    }

    #[test]
    fn test_hostname_with_inverse_range() {
        // Descending ranges are not supported and expand to nothing.
        assert!(expand("myhost127..125.com").is_empty());
    }

    #[test]
    fn test_hostname_with_numbers_no_range() {
        assert_unchanged("myhost1.4.com");
    }

    #[test]
    fn test_hostname_with_range_one_numbers() {
        assert_unchanged("myhost1..b.com");
    }

    #[test]
    fn test_hostname_with_range_no_numbers() {
        assert_unchanged("myhosta..b.com");
    }

    #[test]
    fn test_hostname_starts_with_range() {
        assert_unchanged("..myhost.com");
    }

    #[test]
    fn test_hostname_ends_with_range() {
        assert_unchanged("myhost.com..");
    }

    #[test]
    fn test_hostname_real_and_fake_range() {
        assert_eq!(
            expand("myhosta..bcdef1..3.com"),
            [
                "myhosta..bcdef1.com",
                "myhosta..bcdef2.com",
                "myhosta..bcdef3.com",
            ]
        );
    }

    #[test]
    fn test_hostname_two_valid_ranges() {
        // Only the first valid range is expanded; the second is kept verbatim.
        assert_eq!(
            expand("myhost1..3cdef4..7.com"),
            [
                "myhost1cdef4..7.com",
                "myhost2cdef4..7.com",
                "myhost3cdef4..7.com",
            ]
        );
    }
}