// sqlx_exasol_impl/options/builder.rs

1use std::ops::RangeInclusive;
2
3use sqlx_core::{connection::LogSettings, net::tls::CertificateInput};
4
5use super::{
6    error::ExaConfigError, ssl_mode::ExaSslMode, ExaConnectOptions, Login, ProtocolVersion,
7    DEFAULT_CACHE_CAPACITY, DEFAULT_FETCH_SIZE, DEFAULT_PORT,
8};
9use crate::{options::compression::ExaCompressionMode, SqlxResult};
10
11/// Builder for [`ExaConnectOptions`].
12#[derive(Clone, Debug)]
13pub struct ExaConnectOptionsBuilder {
14    url_host: Option<String>,
15    url_port: u16,
16    extra_hosts: Vec<(String, u16)>,
17    ssl_mode: ExaSslMode,
18    ssl_ca: Option<CertificateInput>,
19    ssl_client_cert: Option<CertificateInput>,
20    ssl_client_key: Option<CertificateInput>,
21    statement_cache_capacity: usize,
22    username: Option<String>,
23    password: Option<String>,
24    access_token: Option<String>,
25    refresh_token: Option<String>,
26    schema: Option<String>,
27    protocol_version: ProtocolVersion,
28    fetch_size: usize,
29    query_timeout: u64,
30    compression_mode: ExaCompressionMode,
31    feedback_interval: u64,
32}
33
34impl Default for ExaConnectOptionsBuilder {
35    fn default() -> Self {
36        Self {
37            url_host: None,
38            url_port: DEFAULT_PORT,
39            extra_hosts: Vec::new(),
40            ssl_mode: ExaSslMode::default(),
41            ssl_ca: None,
42            ssl_client_cert: None,
43            ssl_client_key: None,
44            statement_cache_capacity: DEFAULT_CACHE_CAPACITY,
45            username: None,
46            password: None,
47            access_token: None,
48            refresh_token: None,
49            schema: None,
50            protocol_version: ProtocolVersion::default(),
51            fetch_size: DEFAULT_FETCH_SIZE,
52            query_timeout: 0,
53            compression_mode: ExaCompressionMode::default(),
54            feedback_interval: 1,
55        }
56    }
57}
58
59impl ExaConnectOptionsBuilder {
60    /// Consumes this builder and returns an instance of [`ExaConnectOptions`].
61    ///
62    /// # Errors
63    ///
64    /// Will return an error if no host or other than exactly one login method were provided.
65    pub fn build(self) -> SqlxResult<ExaConnectOptions> {
66        let url_host = self.url_host.ok_or(ExaConfigError::MissingHost)?;
67        let password = self.password.unwrap_or_default();
68
69        // Only one authentication method can be used at once
70        let login = match (self.username, self.access_token, self.refresh_token) {
71            (Some(username), None, None) => Login::Credentials { username, password },
72            (None, Some(access_token), None) => Login::AccessToken { access_token },
73            (None, None, Some(refresh_token)) => Login::RefreshToken { refresh_token },
74            (None, None, None) => return Err(ExaConfigError::MissingAuthMethod.into()),
75            _ => return Err(ExaConfigError::MultipleAuthMethods.into()),
76        };
77
78        let hosts = Some((url_host.clone(), self.url_port))
79            .into_iter()
80            .chain(self.extra_hosts)
81            .flat_map(|(host, port)| Self::parse_host(host).map(move |host| (host.into(), port)))
82            .collect();
83
84        let opts = ExaConnectOptions {
85            hosts,
86            ssl_mode: self.ssl_mode,
87            ssl_ca: self.ssl_ca,
88            ssl_client_cert: self.ssl_client_cert,
89            ssl_client_key: self.ssl_client_key,
90            statement_cache_capacity: self.statement_cache_capacity,
91            url_host,
92            url_port: self.url_port,
93            login,
94            schema: self.schema,
95            protocol_version: self.protocol_version,
96            fetch_size: self.fetch_size,
97            query_timeout: self.query_timeout,
98            compression_mode: self.compression_mode,
99            feedback_interval: self.feedback_interval,
100            log_settings: LogSettings::default(),
101        };
102
103        Ok(opts)
104    }
105
106    #[must_use = "call build() to get connection options"]
107    pub fn host(mut self, host: String) -> Self {
108        self.url_host = Some(host);
109        self
110    }
111
112    #[must_use = "call build() to get connection options"]
113    pub fn port(mut self, port: u16) -> Self {
114        self.url_port = port;
115        self
116    }
117
118    /// Appends an additional host to be used for randomly
119    /// connecting to Exasol nodes.
120    ///
121    /// Can be called multiple times.
122    #[must_use = "call build() to get connection options"]
123    pub fn extra_host(mut self, host: String, port: Option<u16>) -> Self {
124        self.extra_hosts.push((host, port.unwrap_or(DEFAULT_PORT)));
125        self
126    }
127
128    #[must_use = "call build() to get connection options"]
129    pub fn ssl_mode(mut self, ssl_mode: ExaSslMode) -> Self {
130        self.ssl_mode = ssl_mode;
131        self
132    }
133
134    #[must_use = "call build() to get connection options"]
135    pub fn ssl_ca(mut self, ssl_ca: CertificateInput) -> Self {
136        self.ssl_ca = Some(ssl_ca);
137        self
138    }
139
140    #[must_use = "call build() to get connection options"]
141    pub fn ssl_client_cert(mut self, ssl_client_cert: CertificateInput) -> Self {
142        self.ssl_client_cert = Some(ssl_client_cert);
143        self
144    }
145
146    #[must_use = "call build() to get connection options"]
147    pub fn ssl_client_key(mut self, ssl_client_key: CertificateInput) -> Self {
148        self.ssl_client_key = Some(ssl_client_key);
149        self
150    }
151
152    /// Sets the capacity of the statement cache.
153    ///
154    /// The cache is enabled by default. Setting the capacity to `0` disables the cache.
155    #[must_use = "call build() to get connection options"]
156    pub fn statement_cache_capacity(mut self, capacity: usize) -> Self {
157        self.statement_cache_capacity = capacity;
158        self
159    }
160
161    #[must_use = "call build() to get connection options"]
162    pub fn username(mut self, username: String) -> Self {
163        self.username = Some(username);
164        self
165    }
166
167    #[must_use = "call build() to get connection options"]
168    pub fn password(mut self, password: String) -> Self {
169        self.password = Some(password);
170        self
171    }
172
173    #[must_use = "call build() to get connection options"]
174    pub fn access_token(mut self, access_token: String) -> Self {
175        self.access_token = Some(access_token);
176        self
177    }
178
179    #[must_use = "call build() to get connection options"]
180    pub fn refresh_token(mut self, refresh_token: String) -> Self {
181        self.refresh_token = Some(refresh_token);
182        self
183    }
184
185    #[must_use = "call build() to get connection options"]
186    pub fn schema(mut self, schema: String) -> Self {
187        self.schema = Some(schema);
188        self
189    }
190
191    #[must_use = "call build() to get connection options"]
192    pub fn fetch_size(mut self, fetch_size: usize) -> Self {
193        self.fetch_size = fetch_size;
194        self
195    }
196
197    #[must_use = "call build() to get connection options"]
198    pub fn query_timeout(mut self, query_timeout: u64) -> Self {
199        self.query_timeout = query_timeout;
200        self
201    }
202
203    #[must_use = "call build() to get connection options"]
204    pub fn compression_mode(mut self, compression_mode: ExaCompressionMode) -> Self {
205        self.compression_mode = compression_mode;
206        self
207    }
208
209    #[must_use = "call build() to get connection options"]
210    pub fn feedback_interval(mut self, feedback_interval: u64) -> Self {
211        self.feedback_interval = feedback_interval;
212        self
213    }
214
215    /// Exasol supports host ranges, e.g: hostname1..4.com.
216    /// This method parses the provided host in the connection string and generates one for each
217    /// possible entry in the range.
218    ///
219    /// We do expect the range to be in the ascending order though, so `hostname4..1.com` will be
220    /// returned as is.
221    fn parse_host(host: String) -> HostKind {
222        // Loop through occurences of ranges (..) in reverse looking for one surrounded by digits.
223        for (idx, _) in host.rmatch_indices("..") {
224            let has_digit_before_range = idx
225                .checked_sub(1)
226                .and_then(|i| host.as_bytes().get(i))
227                .is_some_and(u8::is_ascii_digit);
228
229            let has_digit_after_range =
230                host.as_bytes().get(idx + 2).is_some_and(u8::is_ascii_digit);
231
232            // Move on if the range is not surrounded by digits.
233            if !has_digit_before_range || !has_digit_after_range {
234                continue;
235            }
236
237            let before_range = &host[..idx];
238            let after_range = &host[idx + 2..];
239
240            // Find the last non-digit character before the range index in the first part of
241            // the hostname and the first non-digit character right after the range dots, in the
242            // second part of the hostname.
243            //
244            // The start is incremented as the index is for the last non-numeric character.
245            //
246            // If no indexes are found, then we consider the beginning/end of string, respectively.
247            let prefix_idx = before_range
248                .rfind(|c: char| !c.is_ascii_digit())
249                .map(|i| i + 1)
250                .unwrap_or_default();
251            let suffix_idx = after_range
252                .find(|c: char| !c.is_ascii_digit())
253                .unwrap_or(after_range.len());
254
255            // Split the hostname parts to isolate components.
256            let (_, start_range) = before_range.split_at(prefix_idx);
257            let (end_range, _) = after_range.split_at(suffix_idx);
258
259            return match (start_range.parse::<usize>(), end_range.parse::<usize>()) {
260                (Ok(start), Ok(end)) if start < end => HostKind::Range {
261                    buffer: host,
262                    prefix_idx,
263                    suffix_idx: idx + 2 + suffix_idx,
264                    range: start..=end,
265                },
266                // Return the hostname as is if the range boundaries are not integers
267                // or if start >= end.
268                _ => HostKind::Single(host),
269            };
270        }
271
272        // No numeric range present, return singular hostname.
273        HostKind::Single(host)
274    }
275}
276
/// Outcome of parsing a possibly ranged host; iterating it yields the concrete host names.
#[derive(Clone, Debug, PartialEq)]
enum HostKind {
    /// A host without an expandable numeric range, yielded once as-is
    /// (an empty string yields nothing).
    Single(String),
    /// A host containing an ascending numeric range such as `myhost1..4.com`.
    Range {
        /// The original host string, range included.
        buffer: String,
        /// Byte offset where the start bound begins; `buffer[..prefix_idx]` is kept verbatim.
        prefix_idx: usize,
        /// Byte offset right after the end bound; `buffer[suffix_idx..]` is kept verbatim.
        suffix_idx: usize,
        /// The numbers substituted between the prefix and the suffix.
        range: RangeInclusive<usize>,
    },
}
287
288impl Iterator for HostKind {
289    type Item = String;
290
291    fn next(&mut self) -> Option<Self::Item> {
292        match self {
293            HostKind::Single(s) => (!s.is_empty()).then(|| std::mem::take(s)),
294            HostKind::Range {
295                buffer,
296                prefix_idx,
297                suffix_idx,
298                range,
299            } => range.next().map(|i| {
300                let (prefix, _) = buffer.split_at(*prefix_idx);
301                let (_, suffix) = buffer.split_at(*suffix_idx);
302                format!("{prefix}{i}{suffix}")
303            }),
304        }
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs [`ExaConnectOptionsBuilder::parse_host`] and collects every generated host, so that
    /// `assert_eq!` can print both sides when an expectation fails.
    fn expand(host: &str) -> Vec<String> {
        ExaConnectOptionsBuilder::parse_host(host.to_owned()).collect()
    }

    /// Asserts that `host` is not expanded and comes back unchanged as the only entry.
    fn assert_unchanged(host: &str) {
        assert_eq!(expand(host), [host]);
    }

    #[test]
    fn test_simple_ip() {
        assert_unchanged("10.10.10.10");
    }

    #[test]
    fn test_ip_range_end() {
        assert_eq!(
            expand("10.10.10.1..3"),
            ["10.10.10.1", "10.10.10.2", "10.10.10.3"]
        );
    }

    #[test]
    fn test_ip_range_start() {
        assert_eq!(
            expand("1..3.10.10.10"),
            ["1.10.10.10", "2.10.10.10", "3.10.10.10"]
        );
    }

    #[test]
    fn test_simple_hostname() {
        assert_unchanged("myhost.com");
    }

    #[test]
    fn test_hostname_with_range() {
        assert_eq!(
            expand("myhost1..4.com"),
            ["myhost1.com", "myhost2.com", "myhost3.com", "myhost4.com"]
        );
    }

    #[test]
    fn test_hostname_with_big_range() {
        assert_eq!(
            expand("myhost125..127.com"),
            ["myhost125.com", "myhost126.com", "myhost127.com"]
        );
    }

    #[test]
    fn test_hostname_with_inverse_range() {
        assert_unchanged("myhost127..125.com");
    }

    #[test]
    fn test_hostname_with_equal_range_bounds() {
        // The range must be strictly ascending to be expanded.
        assert_unchanged("myhost3..3.com");
    }

    #[test]
    fn test_hostname_with_numbers_no_range() {
        assert_unchanged("myhost1.4.com");
    }

    #[test]
    fn test_hostname_with_range_one_numbers() {
        assert_unchanged("myhost1..b.com");
    }

    #[test]
    fn test_hostname_with_range_no_numbers() {
        assert_unchanged("myhosta..b.com");
    }

    #[test]
    fn test_hostname_starts_with_range() {
        assert_unchanged("..myhost.com");
    }

    #[test]
    fn test_hostname_ends_with_range() {
        assert_unchanged("myhost.com..");
    }

    #[test]
    fn test_hostname_real_and_fake_range() {
        assert_eq!(
            expand("myhosta..bcdef1..3.com"),
            [
                "myhosta..bcdef1.com",
                "myhosta..bcdef2.com",
                "myhosta..bcdef3.com",
            ]
        );
    }

    #[test]
    fn test_hostname_two_valid_ranges() {
        // Only the right-most valid range is expanded.
        assert_eq!(
            expand("myhost1..3cdef4..7.com"),
            [
                "myhost1..3cdef4.com",
                "myhost1..3cdef5.com",
                "myhost1..3cdef6.com",
                "myhost1..3cdef7.com",
            ]
        );
    }

    #[test]
    fn test_build_requires_host() {
        let result = ExaConnectOptionsBuilder::default()
            .username("sys".to_owned())
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_build_requires_exactly_one_login_method() {
        let base = ExaConnectOptionsBuilder::default().host("myhost.com".to_owned());

        // No login method at all.
        assert!(base.clone().build().is_err());

        // Two login methods at once.
        assert!(base
            .clone()
            .username("sys".to_owned())
            .access_token("token".to_owned())
            .build()
            .is_err());

        // Exactly one login method.
        assert!(base.username("sys".to_owned()).build().is_ok());
    }
}