sqlx_exasol_impl/options/mod.rs

1mod builder;
2mod compression;
3mod error;
4mod protocol_version;
5mod ssl_mode;
6
7use std::{borrow::Cow, path::PathBuf, str::FromStr, sync::Arc};
8
9pub use builder::ExaConnectOptionsBuilder;
10pub use compression::ExaCompressionMode;
11use error::ExaConfigError;
12pub use protocol_version::ProtocolVersion;
13use sqlx_core::{
14    connection::{ConnectOptions, LogSettings},
15    net::tls::CertificateInput,
16    percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC},
17};
18pub use ssl_mode::ExaSslMode;
19use tracing::log;
20use url::Url;
21
22use crate::{
23    connection::{
24        websocket::request::{ExaLoginRequest, LoginRef},
25        ExaConnection,
26    },
27    error::ExaProtocolError,
28    responses::ExaRwAttributes,
29    SqlxError, SqlxResult,
30};
31
/// URL scheme accepted by (and emitted for) Exasol connection strings.
const URL_SCHEME: &str = "exa";

/// Default fetch size: 5 MiB (presumably bytes; see the `fetch-size` parameter).
const DEFAULT_FETCH_SIZE: usize = 5 << 20;
/// Default Exasol port.
const DEFAULT_PORT: u16 = 8563;
/// Default prepared statement cache capacity (presumably for `statement-cache-capacity`).
const DEFAULT_CACHE_CAPACITY: usize = 100;

// Names of the query parameters understood in connection strings. The meaning of each one is
// documented on `ExaConnectOptions`.
const ACCESS_TOKEN: &str = "access-token";
const REFRESH_TOKEN: &str = "refresh-token";
const SSL_MODE: &str = "ssl-mode";
const SSL_CA: &str = "ssl-ca";
const SSL_CERT: &str = "ssl-cert";
const SSL_KEY: &str = "ssl-key";
const STATEMENT_CACHE_CAPACITY: &str = "statement-cache-capacity";
const FETCH_SIZE: &str = "fetch-size";
const QUERY_TIMEOUT: &str = "query-timeout";
const COMPRESSION: &str = "compression";
const FEEDBACK_INTERVAL: &str = "feedback-interval";
49
/// Options for connecting to the Exasol database. Implementor of [`ConnectOptions`].
///
/// While generally automatically created through a connection string,
/// [`ExaConnectOptions::builder()`] can be used to get an [`ExaConnectOptionsBuilder`].
///
/// Connection string format:
/// `exa://[username[:password]@]host[:port][/schema][?param=value&...]`
///
/// Connection options:
/// - `access-token`: Use an access token for login instead of credentials
/// - `refresh-token`: Use a refresh token for login instead of credentials
/// - `ssl-mode`: Select a specific SSL behavior.
/// - `ssl-ca`: Path to the certificate authority file to use
/// - `ssl-cert`: Path to the client certificate file to use
/// - `ssl-key`: Path to the client SSL key file to use
/// - `statement-cache-capacity`: Set the capacity of the LRU prepared statements cache
/// - `fetch-size`: Sets the size of data chunks when retrieving result sets
/// - `query-timeout`: The query timeout amount, in seconds. 0 means no timeout
/// - `compression`: Set the desired compression mode.
/// - `feedback-interval`: Interval at which Exasol sends keep-alive Pong frames
#[derive(Debug, Clone)]
pub struct ExaConnectOptions {
    /// Host/port pairs. NOTE(review): presumably the concrete hosts to try, derived from
    /// `url_host` by the builder — confirm in `builder.rs`.
    pub(crate) hosts: Vec<(Arc<str>, u16)>,
    /// TLS behavior (`ssl-mode` parameter).
    pub(crate) ssl_mode: ExaSslMode,
    /// Certificate authority (`ssl-ca` parameter, parsed from URLs as a file path).
    pub(crate) ssl_ca: Option<CertificateInput>,
    /// Client certificate (`ssl-cert` parameter, parsed from URLs as a file path).
    pub(crate) ssl_client_cert: Option<CertificateInput>,
    /// Client key (`ssl-key` parameter, parsed from URLs as a file path).
    pub(crate) ssl_client_key: Option<CertificateInput>,
    /// Capacity of the prepared statements cache (`statement-cache-capacity`).
    pub(crate) statement_cache_capacity: usize,
    /// Schema to open, taken from the URL path.
    pub(crate) schema: Option<String>,
    /// Requested compression mode; resolved against compiled-in support at login.
    pub(crate) compression_mode: ExaCompressionMode,
    /// Statement logging configuration (see `log_statements` / `log_slow_statements`).
    pub(crate) log_settings: LogSettings,
    /// Host as configured; used to rebuild the URL in `to_url_lossy`.
    url_host: String,
    /// Port as configured; used to rebuild the URL in `to_url_lossy`.
    url_port: u16,
    /// Authentication method: credentials, access token or refresh token.
    login: Login,
    /// Protocol version sent in the login request.
    protocol_version: ProtocolVersion,
    /// Result set chunk size (`fetch-size`); presumably bytes — TODO confirm.
    fetch_size: usize,
    /// Query timeout in seconds; 0 means no timeout (`query-timeout`).
    query_timeout: u64,
    /// Keep-alive feedback interval (`feedback-interval`); presumably seconds — TODO confirm.
    feedback_interval: u64,
}
86
87impl ExaConnectOptions {
88    #[must_use = "call build() to get connection options"]
89    pub fn builder() -> ExaConnectOptionsBuilder {
90        ExaConnectOptionsBuilder::default()
91    }
92
93    /// Create an [`ExaConnectOptionsBuilder`] by starting from an [`Url`].
94    ///
95    /// # Errors
96    ///
97    /// Returns an error if parsing the [`Url`] fails.
98    #[must_use = "call build() to get connection options"]
99    pub fn builder_from_url(url: &Url) -> SqlxResult<ExaConnectOptionsBuilder> {
100        let scheme = url.scheme();
101
102        if URL_SCHEME != scheme {
103            return Err(ExaConfigError::InvalidUrlScheme(scheme.to_owned()).into());
104        }
105
106        let mut builder = Self::builder();
107
108        if let Some(host) = url.host_str() {
109            builder = builder.host(host.to_owned());
110        }
111
112        let username = url.username();
113        if !username.is_empty() {
114            let username = percent_decode_str(username)
115                .decode_utf8()
116                .map_err(SqlxError::config)?;
117            builder = builder.username(username.to_string());
118        }
119
120        if let Some(password) = url.password() {
121            let password = percent_decode_str(password)
122                .decode_utf8()
123                .map_err(SqlxError::config)?;
124            builder = builder.password(password.to_string());
125        }
126
127        if let Some(port) = url.port() {
128            builder = builder.port(port);
129        }
130
131        let path = url.path().trim_start_matches('/');
132
133        if !path.is_empty() {
134            let db_schema = percent_decode_str(path)
135                .decode_utf8()
136                .map_err(SqlxError::config)?;
137            builder = builder.schema(db_schema.to_string());
138        }
139
140        for (name, value) in url.query_pairs() {
141            match name.as_ref() {
142                ACCESS_TOKEN => builder = builder.access_token(value.to_string()),
143
144                REFRESH_TOKEN => builder = builder.refresh_token(value.to_string()),
145
146                SSL_MODE => {
147                    let ssl_mode = value.parse::<ExaSslMode>()?;
148                    builder = builder.ssl_mode(ssl_mode);
149                }
150
151                SSL_CA => {
152                    let ssl_ca = CertificateInput::File(PathBuf::from(value.to_string()));
153                    builder = builder.ssl_ca(ssl_ca);
154                }
155
156                SSL_CERT => {
157                    let ssl_cert = CertificateInput::File(PathBuf::from(value.to_string()));
158                    builder = builder.ssl_client_cert(ssl_cert);
159                }
160
161                SSL_KEY => {
162                    let ssl_key = CertificateInput::File(PathBuf::from(value.to_string()));
163                    builder = builder.ssl_client_key(ssl_key);
164                }
165
166                STATEMENT_CACHE_CAPACITY => {
167                    let capacity = value
168                        .parse::<usize>()
169                        .map_err(|_| ExaConfigError::InvalidParameter(STATEMENT_CACHE_CAPACITY))?;
170                    builder = builder.statement_cache_capacity(capacity);
171                }
172
173                FETCH_SIZE => {
174                    let fetch_size = value
175                        .parse::<usize>()
176                        .map_err(|_| ExaConfigError::InvalidParameter(FETCH_SIZE))?;
177                    builder = builder.fetch_size(fetch_size);
178                }
179
180                QUERY_TIMEOUT => {
181                    let query_timeout = value
182                        .parse::<u64>()
183                        .map_err(|_| ExaConfigError::InvalidParameter(QUERY_TIMEOUT))?;
184                    builder = builder.query_timeout(query_timeout);
185                }
186
187                COMPRESSION => {
188                    let compression_mode = value
189                        .parse::<ExaCompressionMode>()
190                        .map_err(|_| ExaConfigError::InvalidParameter(COMPRESSION))?;
191                    builder = builder.compression_mode(compression_mode);
192                }
193
194                FEEDBACK_INTERVAL => {
195                    let feedback_interval = value
196                        .parse::<u64>()
197                        .map_err(|_| ExaConfigError::InvalidParameter(FEEDBACK_INTERVAL))?;
198                    builder = builder.feedback_interval(feedback_interval);
199                }
200
201                _ => {
202                    return Err(SqlxError::Protocol(format!(
203                        "Unknown connection string parameter: {value}"
204                    )))
205                }
206            }
207        }
208
209        Ok(builder)
210    }
211
212    /// Create an [`ExaConnectOptionsBuilder`] by starting from a connection string.
213    ///
214    /// # Errors
215    ///
216    /// Returns an error if parsing the connection string fails.
217    #[must_use = "call build() to get connection options"]
218    pub fn builder_from_str(s: &str) -> SqlxResult<ExaConnectOptionsBuilder> {
219        let url = Url::parse(s)
220            .map_err(From::from)
221            .map_err(SqlxError::Configuration)?;
222        Self::builder_from_url(&url)
223    }
224}
225
226impl FromStr for ExaConnectOptions {
227    type Err = SqlxError;
228
229    fn from_str(s: &str) -> Result<Self, Self::Err> {
230        Self::builder_from_str(s)?.build()
231    }
232}
233
234impl ConnectOptions for ExaConnectOptions {
235    type Connection = ExaConnection;
236
237    fn from_url(url: &Url) -> SqlxResult<Self> {
238        Self::builder_from_url(url)?.build()
239    }
240
241    fn to_url_lossy(&self) -> Url {
242        let mut url = Url::parse(&format!(
243            "{URL_SCHEME}://{}:{}",
244            self.url_host, self.url_port
245        ))
246        .expect("generated URL must be correct");
247
248        if let Some(schema) = &self.schema {
249            url.set_path(schema);
250        }
251
252        match &self.login {
253            Login::Credentials { username, password } => {
254                url.set_username(username).ok();
255                let password = utf8_percent_encode(password, NON_ALPHANUMERIC).to_string();
256                url.set_password(Some(&password)).ok();
257            }
258            Login::AccessToken { access_token } => {
259                url.query_pairs_mut()
260                    .append_pair(ACCESS_TOKEN, access_token);
261            }
262            Login::RefreshToken { refresh_token } => {
263                url.query_pairs_mut()
264                    .append_pair(REFRESH_TOKEN, refresh_token);
265            }
266        }
267
268        url.query_pairs_mut()
269            .append_pair(SSL_MODE, self.ssl_mode.as_ref());
270
271        if let Some(ssl_ca) = &self.ssl_ca {
272            url.query_pairs_mut()
273                .append_pair(SSL_CA, &ssl_ca.to_string());
274        }
275
276        if let Some(ssl_cert) = &self.ssl_client_cert {
277            url.query_pairs_mut()
278                .append_pair(SSL_CERT, &ssl_cert.to_string());
279        }
280
281        if let Some(ssl_key) = &self.ssl_client_key {
282            url.query_pairs_mut()
283                .append_pair(SSL_KEY, &ssl_key.to_string());
284        }
285
286        url.query_pairs_mut().append_pair(
287            STATEMENT_CACHE_CAPACITY,
288            &self.statement_cache_capacity.to_string(),
289        );
290
291        url.query_pairs_mut()
292            .append_pair(FETCH_SIZE, &self.fetch_size.to_string());
293
294        url.query_pairs_mut()
295            .append_pair(QUERY_TIMEOUT, &self.query_timeout.to_string());
296
297        url.query_pairs_mut()
298            .append_pair(COMPRESSION, self.compression_mode.as_ref());
299
300        url.query_pairs_mut()
301            .append_pair(FEEDBACK_INTERVAL, &self.feedback_interval.to_string());
302
303        url
304    }
305
306    async fn connect(&self) -> SqlxResult<Self::Connection>
307    where
308        Self::Connection: Sized,
309    {
310        ExaConnection::establish(self).await
311    }
312
313    fn log_statements(mut self, level: log::LevelFilter) -> Self {
314        self.log_settings.log_statements(level);
315        self
316    }
317
318    fn log_slow_statements(
319        mut self,
320        level: log::LevelFilter,
321        duration: std::time::Duration,
322    ) -> Self {
323        self.log_settings.log_slow_statements(level, duration);
324        self
325    }
326}
327
328impl<'a> TryFrom<&'a ExaConnectOptions> for ExaLoginRequest<'a> {
329    type Error = ExaProtocolError;
330
331    fn try_from(value: &'a ExaConnectOptions) -> Result<Self, Self::Error> {
332        let crate_version = option_env!("CARGO_PKG_VERSION").unwrap_or("UNKNOWN");
333
334        let attributes = ExaRwAttributes::new(
335            value.schema.as_deref().map(Cow::Borrowed),
336            value.feedback_interval,
337            value.query_timeout,
338        );
339
340        let compression_supported = cfg!(feature = "compression");
341
342        let use_compression = match value.compression_mode {
343            ExaCompressionMode::Disabled => false,
344            ExaCompressionMode::Preferred if !compression_supported => {
345                tracing::debug!("not using compression: compression support not compiled in");
346                false
347            }
348            ExaCompressionMode::Preferred => true,
349            ExaCompressionMode::Required if compression_supported => true,
350            ExaCompressionMode::Required => return Err(ExaProtocolError::CompressionDisabled),
351        };
352
353        let output = Self {
354            protocol_version: value.protocol_version,
355            fetch_size: value.fetch_size,
356            statement_cache_capacity: value.statement_cache_capacity,
357            login: (&value.login).into(),
358            use_compression,
359            client_name: "sqlx-exasol",
360            client_version: crate_version,
361            client_os: std::env::consts::OS,
362            client_runtime: "RUST",
363            attributes,
364        };
365
366        Ok(output)
367    }
368}
369
/// Enum representing the possible ways of authenticating a connection.
/// The variant chosen dictates which login process is called.
///
/// When options come from a connection string, the `access-token` and `refresh-token`
/// query parameters select the token variants; username/password in the URL userinfo
/// select [`Login::Credentials`].
#[derive(Clone, Debug)]
pub enum Login {
    /// Username/password authentication; written to the URL userinfo by `to_url_lossy`.
    Credentials { username: String, password: String },
    /// Authentication with an access token (`access-token` parameter).
    AccessToken { access_token: String },
    /// Authentication with a refresh token (`refresh-token` parameter).
    RefreshToken { refresh_token: String },
}
378
379impl<'a> From<&'a Login> for LoginRef<'a> {
380    fn from(value: &'a Login) -> Self {
381        match value {
382            Login::Credentials { username, password } => LoginRef::Credentials {
383                username,
384                password: Cow::Borrowed(password),
385            },
386            Login::AccessToken { access_token } => LoginRef::AccessToken { access_token },
387            Login::RefreshToken { refresh_token } => LoginRef::RefreshToken { refresh_token },
388        }
389    }
390}
391
/// Helper containing TLS related options.
///
/// Borrows the TLS settings of an [`ExaConnectOptions`] so they can be passed around
/// without cloning the certificate inputs.
#[derive(Debug, Clone, Copy)]
// All fields share the `ssl_` prefix, which `clippy::struct_field_names` would flag; the
// names intentionally mirror the `ExaConnectOptions` fields they borrow from.
#[allow(clippy::struct_field_names)]
pub struct ExaTlsOptionsRef<'a> {
    /// Requested TLS behavior.
    pub ssl_mode: ExaSslMode,
    /// Certificate authority, if configured.
    pub ssl_ca: Option<&'a CertificateInput>,
    /// Client certificate, if configured.
    pub ssl_client_cert: Option<&'a CertificateInput>,
    /// Client key, if configured.
    pub ssl_client_key: Option<&'a CertificateInput>,
}
401
402impl<'a> From<&'a ExaConnectOptions> for ExaTlsOptionsRef<'a> {
403    fn from(value: &'a ExaConnectOptions) -> Self {
404        ExaTlsOptionsRef {
405            ssl_mode: value.ssl_mode,
406            ssl_ca: value.ssl_ca.as_ref(),
407            ssl_client_cert: value.ssl_client_cert.as_ref(),
408            ssl_client_key: value.ssl_client_key.as_ref(),
409        }
410    }
411}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Parses `url` into options, panicking if the connection string is rejected.
    fn parse(url: &str) -> ExaConnectOptions {
        ExaConnectOptions::from_str(url).unwrap()
    }

    /// Collects the query parameters of `url` into a map for order-independent lookups.
    fn query_map(url: &Url) -> HashMap<String, String> {
        url.query_pairs().into_owned().collect()
    }

    /// Options builder pointing at `localhost` with `user`/`pass` credentials.
    fn credentials_builder() -> ExaConnectOptionsBuilder {
        ExaConnectOptions::builder()
            .host("localhost".to_string())
            .username("user".to_string())
            .password("pass".to_string())
    }

    #[test]
    fn test_from_url_basic() {
        let options = parse("exa://user:pass@localhost:8563/schema");

        assert_eq!(options.url_host, "localhost");
        assert_eq!(options.url_port, 8563);
        assert_eq!(options.schema.as_deref(), Some("schema"));

        let Login::Credentials { username, password } = &options.login else {
            panic!("Expected credentials login");
        };
        assert_eq!(username, "user");
        assert_eq!(password, "pass");
    }

    #[test]
    fn test_from_url_with_query_params() {
        let options =
            parse("exa://localhost:8563?access-token=token123&compression=disabled&fetch-size=1024");

        let Login::AccessToken { access_token } = &options.login else {
            panic!("Expected access token login");
        };
        assert_eq!(access_token, "token123");

        assert_eq!(options.compression_mode, ExaCompressionMode::Disabled);
        assert_eq!(options.fetch_size, 1024);
    }

    #[test]
    fn test_from_url_refresh_token() {
        let options = parse("exa://localhost:8563?refresh-token=refresh123");

        let Login::RefreshToken { refresh_token } = &options.login else {
            panic!("Expected refresh token login");
        };
        assert_eq!(refresh_token, "refresh123");
    }

    #[test]
    fn test_from_url_ssl_params() {
        let options =
            parse("exa://user:p@ssw0rd@localhost:8563?ssl-mode=required&ssl-ca=/path/to/ca.crt");

        assert_eq!(options.ssl_mode, ExaSslMode::Required);
        assert!(options.ssl_ca.is_some());
    }

    #[test]
    fn test_from_url_numeric_params() {
        let options = parse(
            "exa://user:p@ssw0rd@localhost:8563?statement-cache-capacity=50&\
             query-timeout=30&feedback-interval=10",
        );

        assert_eq!(options.statement_cache_capacity, 50);
        assert_eq!(options.query_timeout, 30);
        assert_eq!(options.feedback_interval, 10);
    }

    #[test]
    fn test_from_url_invalid_scheme() {
        assert!(ExaConnectOptions::from_str("mysql://localhost:8563").is_err());
    }

    #[test]
    fn test_from_url_unknown_parameter() {
        assert!(ExaConnectOptions::from_str("exa://localhost:8563?unknown-param=value").is_err());
    }

    #[test]
    fn test_to_url_lossy_credentials() {
        let url = ExaConnectOptions::builder()
            .host("localhost".to_string())
            .port(8563)
            .username("user".to_string())
            .password("pass".to_string())
            .schema("schema".to_string())
            .build()
            .unwrap()
            .to_url_lossy();

        assert_eq!(url.scheme(), "exa");
        assert_eq!(url.host_str(), Some("localhost"));
        assert_eq!(url.port(), Some(8563));
        assert_eq!(url.path(), "/schema");
        assert_eq!(url.username(), "user");
        assert_eq!(url.password(), Some("pass"));
    }

    #[test]
    fn test_to_url_lossy_access_token() {
        let url = ExaConnectOptions::builder()
            .host("localhost".to_string())
            .access_token("token123".to_string())
            .build()
            .unwrap()
            .to_url_lossy();

        let query = query_map(&url);
        assert_eq!(query.get(ACCESS_TOKEN).map(String::as_str), Some("token123"));
    }

    #[test]
    fn test_to_url_lossy_refresh_token() {
        let url = ExaConnectOptions::builder()
            .host("localhost".to_string())
            .refresh_token("refresh123".to_string())
            .build()
            .unwrap()
            .to_url_lossy();

        let query = query_map(&url);
        assert_eq!(query.get(REFRESH_TOKEN).map(String::as_str), Some("refresh123"));
    }

    #[test]
    fn test_to_url_lossy_all_params() {
        let url = ExaConnectOptions::builder()
            .host("localhost".to_string())
            .port(8563)
            .username("user".to_string())
            .password("pass".to_string())
            .schema("schema".to_string())
            .compression_mode(ExaCompressionMode::Disabled)
            .fetch_size(2048)
            .query_timeout(60)
            .feedback_interval(5)
            .statement_cache_capacity(200)
            .build()
            .unwrap()
            .to_url_lossy();

        let query = query_map(&url);
        let lookup = |key: &str| query.get(key).map(String::as_str);

        assert_eq!(lookup(COMPRESSION), Some("disabled"));
        assert_eq!(lookup(FETCH_SIZE), Some("2048"));
        assert_eq!(lookup(QUERY_TIMEOUT), Some("60"));
        assert_eq!(lookup(FEEDBACK_INTERVAL), Some("5"));
        assert_eq!(lookup(STATEMENT_CACHE_CAPACITY), Some("200"));
    }

    #[test]
    fn test_roundtrip_conversion() {
        let original =
            parse("exa://user:pass@localhost:8563/schema?compression=preferred&fetch-size=1024");
        let reparsed = ExaConnectOptions::from_url(&original.to_url_lossy()).unwrap();

        assert_eq!(original.url_host, reparsed.url_host);
        assert_eq!(original.url_port, reparsed.url_port);
        assert_eq!(original.schema, reparsed.schema);
        assert_eq!(original.compression_mode, reparsed.compression_mode);
        assert_eq!(original.fetch_size, reparsed.fetch_size);
    }

    #[test]
    fn test_compression_modes() {
        for (value, expected) in [
            ("disabled", ExaCompressionMode::Disabled),
            ("preferred", ExaCompressionMode::Preferred),
            ("required", ExaCompressionMode::Required),
        ] {
            let options = parse(&format!("exa://user:pass@localhost:8563?compression={value}"));
            assert_eq!(options.compression_mode, expected);
        }
    }

    #[test]
    fn test_ssl_modes() {
        for (value, expected) in [
            ("disabled", ExaSslMode::Disabled),
            ("preferred", ExaSslMode::Preferred),
            ("required", ExaSslMode::Required),
        ] {
            let options = parse(&format!("exa://user:pass@localhost:8563?ssl-mode={value}"));
            assert_eq!(options.ssl_mode, expected);
        }
    }

    #[test]
    fn test_compression_and_ssl_modes_together() {
        let options = parse("exa://user:pass@localhost:8563?compression=required&ssl-mode=required");

        assert_eq!(options.compression_mode, ExaCompressionMode::Required);
        assert_eq!(options.ssl_mode, ExaSslMode::Required);
    }

    #[test]
    fn test_compression_mode_to_url_lossy() {
        // The compression mode must be serialized back into the URL.
        let url = credentials_builder()
            .compression_mode(ExaCompressionMode::Required)
            .build()
            .unwrap()
            .to_url_lossy();

        let query = query_map(&url);
        assert_eq!(query.get(COMPRESSION).map(String::as_str), Some("required"));
    }

    #[test]
    fn test_ssl_mode_to_url_lossy() {
        // The SSL mode must be serialized back into the URL.
        let url = credentials_builder()
            .ssl_mode(ExaSslMode::Required)
            .build()
            .unwrap()
            .to_url_lossy();

        let query = query_map(&url);
        assert_eq!(query.get(SSL_MODE).map(String::as_str), Some("required"));
    }
}