1use std::{fmt::Display, num::ParseIntError, str::FromStr, time::Duration};
2
3use futures_core::future::BoxFuture;
4use log::LevelFilter;
5use scylla::{client::session::TlsContext, frame::Compression};
6use sqlx::{ConnectOptions, Error};
7use sqlx_core::connection::LogSettings;
8use url::Url;
9
10use crate::{ScyllaDBError, connection::ScyllaDBConnection};
11
/// Port appended to the URL host when the connection URL has no explicit port.
const DEFAULT_PORT: u16 = 9042;
/// Page size assigned by [`ScyllaDBConnectOptions::new`].
const DEFAULT_PAGE_SIZE: i32 = 5000;
/// Statement cache capacity assigned by [`ScyllaDBConnectOptions::new`].
const DEFAULT_STATEMENT_CACHE_CAPACITY: usize = 128;
15
/// Options for opening a ScyllaDB connection through sqlx.
///
/// Build it with [`ScyllaDBConnectOptions::new`] and the builder methods, or parse it
/// from a connection URL (`FromStr` / `ConnectOptions::from_url`).
#[derive(Debug, Clone)]
pub struct ScyllaDBConnectOptions {
    // Node addresses in insertion order; the URL host is stored as `host:port`,
    // entries from the `nodes` query parameter are stored as given.
    pub(crate) nodes: Vec<String>,
    // Keyspace name, if one was configured.
    pub(crate) keyspace: Option<String>,
    // Statement cache capacity (initialized to DEFAULT_STATEMENT_CACHE_CAPACITY).
    pub(crate) statement_cache_capacity: usize,
    // sqlx statement-logging settings, changed via `log_statements` / `log_slow_statements`.
    pub(crate) log_settings: LogSettings,
    // Whether TCP_NODELAY was requested.
    pub(crate) tcp_nodelay: bool,
    // Username/password credentials, if configured.
    pub(crate) authentication_options: Option<ScyllaDBAuthenticationOptions>,
    // Replication strategy and factor, if either was configured.
    pub(crate) replication_options: Option<ScyllaDBReplicationOptions>,
    // Compression settings, if a compressor was configured.
    pub(crate) compression_options: Option<ScyllaDBCompressionOptions>,
    // TLS file paths, if any TLS option was configured.
    pub(crate) tls_options: Option<ScyllaDBTLSOptions>,
    // TCP keepalive duration, if configured.
    pub(crate) tcp_keepalive: Option<Duration>,
    // Page size (initialized to DEFAULT_PAGE_SIZE).
    pub(crate) page_size: i32,
}
31
32impl ScyllaDBConnectOptions {
33 pub(crate) fn parse_from_url(url: &Url) -> Result<Self, Error> {
34 let mut options = Self::new();
35
36 let host = url.host_str();
37 if let Some(host) = host {
38 let port = url.port().unwrap_or(DEFAULT_PORT);
39 let node = format!("{}:{}", host, port);
40 options = options.add_node(node);
41 }
42
43 let path = url.path().trim_start_matches('/');
44 if !path.is_empty() {
45 options = options.keyspace(path);
46 }
47
48 let username = url.username();
49 if !username.is_empty() {
50 let password = url.password().unwrap_or_default();
51 options = options.user_authentication(username, password);
52 }
53
54 let query_pairs = url.query_pairs();
55 for (key, value) in query_pairs {
56 match key.as_ref() {
57 "nodes" => {
58 let nodes = value.split(",");
59 for node in nodes {
60 options = options.add_node(node);
61 }
62 }
63 "replication_strategy" => {
64 let strategy = ScyllaDBReplicationStrategy::from_str(&value)?;
65 options = options.replication_strategy(strategy);
66 }
67 "replication_factor" => {
68 let replication_factor = value.parse().map_err(|err: ParseIntError| {
69 let message = format!("Invalid replication_factor. {err}");
70 Error::Configuration(message.into())
71 })?;
72 options = options.replication_factor(replication_factor);
73 }
74 "compression" => {
75 let compressor = ScyllaDBCompressor::from_str(&value)?;
76 options = options.compressor(compressor);
77 }
78 "tcp_nodelay" => {
79 options = options.tcp_nodelay();
80 }
81 "tcp_keepalive" => {
82 let secs = value.parse().map_err(|err: ParseIntError| {
83 let message = format!("Invalid tcp_keepalive. {err}");
84 Error::Configuration(message.into())
85 })?;
86 options = options.tcp_keepalive(secs);
87 }
88 "page_size" => {
89 let page_size = value.parse().map_err(|err: ParseIntError| {
90 let message = format!("Invalid page_size. {err}");
91 Error::Configuration(message.into())
92 })?;
93 options = options.page_size(page_size);
94 }
95 "tls_rootcert" => {
96 options = options.tls_rootcert(value.to_string());
97 }
98 "tls_cert" => {
99 options = options.tls_cert(value.to_string());
100 }
101 "tls_key" => {
102 options = options.tls_key(value.to_string());
103 }
104 _ => eprintln!("Not supported options. {key}"),
105 }
106 }
107
108 Ok(options)
109 }
110}
111
112impl ScyllaDBConnectOptions {
113 pub fn new() -> Self {
115 Self {
116 nodes: vec![],
117 keyspace: None,
118 statement_cache_capacity: DEFAULT_STATEMENT_CACHE_CAPACITY,
119 log_settings: Default::default(),
120 tcp_nodelay: false,
121 authentication_options: None,
122 replication_options: None,
123 compression_options: None,
124 tls_options: None,
125 tcp_keepalive: None,
126 page_size: DEFAULT_PAGE_SIZE,
127 }
128 }
129
130 pub fn nodes(mut self, nodes: Vec<String>) -> Self {
132 self.nodes = nodes;
133 self
134 }
135
136 pub fn add_node(mut self, node: impl Into<String>) -> Self {
138 self.nodes.push(node.into());
139 self
140 }
141
142 pub fn keyspace(mut self, keyspace: impl Into<String>) -> Self {
144 self.keyspace = Some(keyspace.into());
145 self
146 }
147
148 pub fn user_authentication(
150 mut self,
151 username: impl Into<String>,
152 password: impl Into<String>,
153 ) -> Self {
154 let authentication_options = ScyllaDBAuthenticationOptions {
155 username: username.into(),
156 password: password.into(),
157 };
158 self.authentication_options = Some(authentication_options);
159 self
160 }
161
162 pub fn replication_strategy(mut self, strategy: ScyllaDBReplicationStrategy) -> Self {
164 let mut replication_options = self.replication_options_or_default();
165 replication_options.strategy = strategy;
166 self.replication_options = Some(replication_options);
167 self
168 }
169
170 pub fn replication_factor(mut self, factor: usize) -> Self {
172 let mut replication_options = self.replication_options_or_default();
173 replication_options.replication_factor = factor;
174 self.replication_options = Some(replication_options);
175 self
176 }
177
178 pub fn compressor(mut self, compressor: ScyllaDBCompressor) -> Self {
180 self.compression_options = Some(ScyllaDBCompressionOptions { compressor });
181 self
182 }
183
184 pub fn tls_rootcert(mut self, root_cert: impl Into<String>) -> Self {
186 let root_cert = root_cert.into();
187 if let Some(mut tls_options) = self.tls_options {
188 tls_options.root_cert = root_cert;
189 self.tls_options = Some(tls_options);
190 } else {
191 self.tls_options = Some(ScyllaDBTLSOptions {
192 root_cert,
193 ..Default::default()
194 });
195 }
196 self
197 }
198
199 pub fn tls_cert(mut self, cert: impl Into<String>) -> Self {
201 let cert = cert.into();
202 if let Some(mut tls_options) = self.tls_options {
203 tls_options.cert = Some(cert);
204 self.tls_options = Some(tls_options);
205 } else {
206 self.tls_options = Some(ScyllaDBTLSOptions {
207 cert: Some(cert),
208 ..Default::default()
209 });
210 }
211 self
212 }
213
214 pub fn tls_key(mut self, key: impl Into<String>) -> Self {
216 let key = key.into();
217 if let Some(mut tls_options) = self.tls_options {
218 tls_options.key = Some(key);
219 self.tls_options = Some(tls_options);
220 } else {
221 self.tls_options = Some(ScyllaDBTLSOptions {
222 key: Some(key),
223 ..Default::default()
224 });
225 }
226 self
227 }
228
229 pub fn tcp_nodelay(mut self) -> Self {
231 self.tcp_nodelay = true;
232 self
233 }
234
235 pub fn tcp_keepalive(mut self, secs: u64) -> Self {
237 self.tcp_keepalive = Some(Duration::from_secs(secs));
238 self
239 }
240
241 pub fn page_size(mut self, page_size: i32) -> Self {
243 self.page_size = page_size;
244 self
245 }
246
247 fn replication_options_or_default(&self) -> ScyllaDBReplicationOptions {
248 if let Some(replication_options) = self.replication_options {
249 replication_options
250 } else {
251 ScyllaDBReplicationOptions::default()
252 }
253 }
254}
255
256impl ConnectOptions for ScyllaDBConnectOptions {
257 type Connection = ScyllaDBConnection;
258
259 fn from_url(url: &Url) -> Result<Self, Error> {
260 Self::parse_from_url(url)
261 }
262
263 fn connect(&self) -> BoxFuture<'_, Result<Self::Connection, Error>>
264 where
265 Self::Connection: Sized,
266 {
267 Box::pin(async { ScyllaDBConnection::establish(self).await })
268 }
269
270 fn log_statements(mut self, level: LevelFilter) -> Self {
271 self.log_settings.log_statements(level);
272 self
273 }
274
275 fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self {
276 self.log_settings.log_slow_statements(level, duration);
277 self
278 }
279}
280
281impl FromStr for ScyllaDBConnectOptions {
282 type Err = Error;
283
284 fn from_str(s: &str) -> Result<Self, Self::Err> {
285 let url: Url = s.parse().map_err(Error::config)?;
286 Self::from_url(&url)
287 }
288}
289
/// Username/password credentials.
///
/// `Debug` is implemented by hand so the password never appears in debug output.
/// `ScyllaDBConnectOptions` derives `Debug` and embeds these credentials, so a derived
/// impl would leak the secret into logs and panic messages.
#[derive(Clone)]
pub(crate) struct ScyllaDBAuthenticationOptions {
    pub(crate) username: String,
    pub(crate) password: String,
}

impl std::fmt::Debug for ScyllaDBAuthenticationOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScyllaDBAuthenticationOptions")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
295
296#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
298pub enum ScyllaDBReplicationStrategy {
299 #[default]
301 SimpleStrategy,
302 NetworkTopologyStrategy,
304}
305
306impl FromStr for ScyllaDBReplicationStrategy {
307 type Err = ScyllaDBError;
308
309 fn from_str(s: &str) -> Result<Self, Self::Err> {
310 let class = match s {
311 "simple" => Self::SimpleStrategy,
312 "network_topology" => Self::NetworkTopologyStrategy,
313 "SimpleStrategy" => Self::SimpleStrategy,
314 "NetworkTopologyStrategy" => Self::NetworkTopologyStrategy,
315 _ => {
316 return Err(ScyllaDBError::ConfigurationError(format!(
317 "replication_strategy '{s}' is invalid."
318 )));
319 }
320 };
321
322 Ok(class)
323 }
324}
325
326impl Display for ScyllaDBReplicationStrategy {
327 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328 match self {
329 ScyllaDBReplicationStrategy::SimpleStrategy => write!(f, "SimpleStrategy"),
330 ScyllaDBReplicationStrategy::NetworkTopologyStrategy => {
331 write!(f, "NetworkTopologyStrategy")
332 }
333 }
334 }
335}
336
337#[derive(Debug, Clone, Copy)]
338pub(crate) struct ScyllaDBReplicationOptions {
339 pub(crate) strategy: ScyllaDBReplicationStrategy,
340 pub(crate) replication_factor: usize,
341}
342
343impl Default for ScyllaDBReplicationOptions {
344 fn default() -> Self {
345 Self {
346 strategy: Default::default(),
347 replication_factor: 1,
348 }
349 }
350}
351
352#[derive(Debug, Clone, Copy, PartialEq, Eq)]
354pub enum ScyllaDBCompressor {
355 LZ4Compressor,
357 SnappyCompressor,
359}
360
361impl Into<Compression> for ScyllaDBCompressor {
362 fn into(self) -> Compression {
363 match self {
364 ScyllaDBCompressor::LZ4Compressor => Compression::Lz4,
365 ScyllaDBCompressor::SnappyCompressor => Compression::Snappy,
366 }
367 }
368}
369
370impl FromStr for ScyllaDBCompressor {
371 type Err = ScyllaDBError;
372
373 fn from_str(s: &str) -> Result<Self, Self::Err> {
374 let compressor = match s.to_ascii_lowercase().as_str() {
375 "lz4" => Self::LZ4Compressor,
376 "snappy" => Self::SnappyCompressor,
377 _ => {
378 return Err(ScyllaDBError::ConfigurationError(format!(
379 "compressor '{s}' is invalid."
380 )));
381 }
382 };
383
384 Ok(compressor)
385 }
386}
387
388#[derive(Debug, Clone, Copy)]
389pub(crate) struct ScyllaDBCompressionOptions {
390 pub(crate) compressor: ScyllaDBCompressor,
391}
392
393#[derive(Debug, Clone, Default)]
394pub(crate) struct ScyllaDBTLSOptions {
395 root_cert: String,
396 cert: Option<String>,
397 key: Option<String>,
398}
399
400impl TryInto<TlsContext> for ScyllaDBTLSOptions {
401 type Error = sqlx::Error;
402
403 fn try_into(self) -> Result<TlsContext, Self::Error> {
404 #[cfg(feature = "openssl-010")]
405 {
406 let ssl_context: openssl_010::ssl::SslContext = self.try_into()?;
407 return Ok(TlsContext::OpenSsl010(ssl_context));
408 }
409
410 #[allow(unreachable_code)]
411 #[cfg(feature = "rustls-023")]
412 {
413 let client_config: rustls_023::ClientConfig = self.try_into()?;
414 return Ok(TlsContext::Rustls023(std::sync::Arc::new(client_config)));
415 }
416
417 #[allow(unreachable_code)]
418 Err(Error::Configuration(
419 "To enable TLS, specify the ‘openssl-010’ or ‘rustls-023’ feature.".into(),
420 ))
421 }
422}
423
/// Builds an OpenSSL client context from the configured PEM files.
///
/// The server certificate is always verified (`SslVerifyMode::PEER`) against the
/// certificates in `root_cert`. Previously, verification was switched off whenever no
/// client certificate was configured, which made the configured CA pointless.
///
/// When `cert` and `key` are both set, the certificate chain and private key are
/// loaded and checked against each other.
///
/// # Errors
///
/// Returns a configuration error in these cases:
/// * `root_cert` is empty;
/// * only one of `cert` / `key` is set;
/// * OpenSSL fails to load a file, or the key does not match the certificate.
#[cfg(feature = "openssl-010")]
impl TryFrom<ScyllaDBTLSOptions> for openssl_010::ssl::SslContext {
    type Error = sqlx::Error;

    fn try_from(options: ScyllaDBTLSOptions) -> Result<Self, Self::Error> {
        use openssl_010::ssl::{SslContextBuilder, SslFiletype, SslMethod, SslVerifyMode};

        fn config_error(err: openssl_010::error::ErrorStack) -> Error {
            Error::Configuration(Box::new(err))
        }

        if options.root_cert.is_empty() {
            return Err(Error::Configuration("tls_rootcert is required.".into()));
        }

        let mut builder = SslContextBuilder::new(SslMethod::tls()).map_err(config_error)?;

        // Loads every certificate in the file, so CA bundles work as well as single CAs.
        builder
            .set_ca_file(&options.root_cert)
            .map_err(config_error)?;
        builder.set_verify(SslVerifyMode::PEER);

        match (&options.cert, &options.key) {
            (Some(cert), Some(key)) => {
                // The chain file may contain intermediates after the leaf certificate.
                builder
                    .set_certificate_chain_file(cert)
                    .map_err(config_error)?;
                builder
                    .set_private_key_file(key, SslFiletype::PEM)
                    .map_err(config_error)?;
                builder.check_private_key().map_err(config_error)?;
            }
            (Some(_), None) => {
                return Err(Error::Configuration(
                    "Client private key is required.".into(),
                ));
            }
            (None, Some(_)) => {
                return Err(Error::Configuration(
                    "Client certificate is required.".into(),
                ));
            }
            (None, None) => {}
        }

        Ok(builder.build())
    }
}
502
/// Builds a rustls client configuration from the configured PEM files.
///
/// Every certificate in `root_cert` is trusted, so CA bundles work. When `cert` and
/// `key` are both set, the full certificate chain in `cert` is presented for client
/// authentication.
///
/// # Errors
///
/// Returns a configuration error in these cases:
/// * `root_cert` is empty or contains no certificate;
/// * `cert` contains no certificate;
/// * only one of `cert` / `key` is set;
/// * a file cannot be read or parsed, or rustls rejects the certificate/key pair.
#[cfg(feature = "rustls-023")]
impl TryFrom<ScyllaDBTLSOptions> for rustls_023::ClientConfig {
    type Error = sqlx::Error;

    fn try_from(options: ScyllaDBTLSOptions) -> Result<Self, Self::Error> {
        use rustls_023::{
            ClientConfig, RootCertStore,
            pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
        };

        fn config_error(err: impl std::error::Error + Send + Sync + 'static) -> Error {
            Error::Configuration(Box::new(err))
        }

        if options.root_cert.is_empty() {
            return Err(Error::Configuration("tls_rootcert is required.".into()));
        }

        let mut root_store = RootCertStore::empty();
        for ca in CertificateDer::pem_file_iter(&options.root_cert).map_err(config_error)? {
            root_store
                .add(ca.map_err(config_error)?)
                .map_err(config_error)?;
        }
        if root_store.is_empty() {
            return Err(Error::Configuration(
                "tls_rootcert contains no certificate.".into(),
            ));
        }

        let builder = ClientConfig::builder().with_root_certificates(root_store);

        match (&options.cert, &options.key) {
            (Some(cert), Some(key)) => {
                // Load the whole chain (leaf first) so intermediates are sent too.
                let cert_chain = CertificateDer::pem_file_iter(cert)
                    .map_err(config_error)?
                    .collect::<Result<Vec<_>, _>>()
                    .map_err(config_error)?;
                if cert_chain.is_empty() {
                    return Err(Error::Configuration(
                        "tls_cert contains no certificate.".into(),
                    ));
                }
                let priv_key = PrivateKeyDer::from_pem_file(key).map_err(config_error)?;

                builder
                    .with_client_auth_cert(cert_chain, priv_key)
                    .map_err(config_error)
            }
            (Some(_), None) => Err(Error::Configuration(
                "Client private key is required.".into(),
            )),
            (None, Some(_)) => Err(Error::Configuration(
                "Client certificate is required.".into(),
            )),
            (None, None) => Ok(builder.with_no_client_auth()),
        }
    }
}
544
#[cfg(test)]
mod tests {
    use std::{str::FromStr, time::Duration};

    use crate::{
        ScyllaDBConnectOptions,
        options::{ScyllaDBCompressor, ScyllaDBReplicationStrategy},
    };

    /// Every supported URL component and query parameter lands in the matching field.
    #[test]
    fn test_parse_url() -> anyhow::Result<()> {
        const URL: &str = "scylladb://my_name:my_passwd@localhost/my_keyspace?nodes=example.test,example2.test:9043&tcp_nodelay&tcp_keepalive=40&compression=lz4&replication_strategy=simple&replication_factor=2&page_size=10&tls_rootcert=/etc/tls/root.pem&tls_cert=/etc/tls/client.pem&tls_key=/etc/tls/client.key";
        let options: ScyllaDBConnectOptions = URL.parse()?;

        assert_eq!("my_keyspace", options.keyspace.unwrap());
        assert!(options.tcp_nodelay);
        assert_eq!(40, options.tcp_keepalive.unwrap().as_secs());

        let authentication_options = options.authentication_options.clone().unwrap();
        assert_eq!("my_name", &authentication_options.username);
        assert_eq!("my_passwd", &authentication_options.password);

        assert_eq!(
            vec!["localhost:9042", "example.test", "example2.test:9043"],
            options.nodes
        );

        let compression_options = options.compression_options.unwrap();
        assert_eq!(
            ScyllaDBCompressor::LZ4Compressor,
            compression_options.compressor
        );

        let replication_options = options.replication_options.unwrap();
        assert_eq!(
            ScyllaDBReplicationStrategy::SimpleStrategy,
            replication_options.strategy
        );
        assert_eq!(2, replication_options.replication_factor);

        let page_size = options.page_size;
        assert_eq!(10, page_size);

        let tls_options = options.tls_options.unwrap();

        assert_eq!("/etc/tls/root.pem", tls_options.root_cert);
        assert_eq!("/etc/tls/client.pem", tls_options.cert.unwrap());
        assert_eq!("/etc/tls/client.key", tls_options.key.unwrap());

        Ok(())
    }

    /// A bare host URL uses the default port and leaves everything else at its default.
    #[test]
    fn test_parse_url_defaults() -> anyhow::Result<()> {
        let options: ScyllaDBConnectOptions = "scylladb://localhost".parse()?;

        assert_eq!(vec!["localhost:9042"], options.nodes);
        assert!(options.keyspace.is_none());
        assert!(options.authentication_options.is_none());
        assert!(options.replication_options.is_none());
        assert!(options.compression_options.is_none());
        assert!(options.tls_options.is_none());
        assert!(options.tcp_keepalive.is_none());
        assert!(!options.tcp_nodelay);
        assert_eq!(5000, options.page_size);

        Ok(())
    }

    /// Unparsable numeric parameters are rejected instead of silently ignored.
    #[test]
    fn test_parse_url_invalid_numbers() {
        for url in [
            "scylladb://localhost?page_size=abc",
            "scylladb://localhost?replication_factor=-1",
            "scylladb://localhost?tcp_keepalive=soon",
        ] {
            assert!(url.parse::<ScyllaDBConnectOptions>().is_err(), "{url}");
        }
    }

    #[test]
    fn test_add_nodes() -> anyhow::Result<()> {
        let options = ScyllaDBConnectOptions::new();

        assert_eq!(0, options.nodes.len());

        let options = options.add_node("example1.test:9043");

        assert_eq!(vec!["example1.test:9043"], options.nodes);

        Ok(())
    }

    #[test]
    fn test_keyspace() -> anyhow::Result<()> {
        let options = ScyllaDBConnectOptions::new();

        assert!(options.keyspace.is_none());

        let options = options.keyspace("test");

        assert_eq!("test", options.keyspace.unwrap());

        Ok(())
    }

    #[test]
    fn test_user_authentication() -> anyhow::Result<()> {
        let options = ScyllaDBConnectOptions::new();

        assert!(options.authentication_options.is_none());

        let options = options.user_authentication("my_name", "my_password");

        let authentication_options = options.authentication_options.unwrap();
        assert_eq!("my_name", &authentication_options.username);
        assert_eq!("my_password", &authentication_options.password);

        Ok(())
    }

    #[test]
    fn test_replication_strategy() -> anyhow::Result<()> {
        let options = ScyllaDBConnectOptions::new();

        assert!(options.replication_options.is_none());

        let options =
            options.replication_strategy(ScyllaDBReplicationStrategy::NetworkTopologyStrategy);

        let replication_options = options.replication_options.unwrap();
        assert_eq!(
            ScyllaDBReplicationStrategy::NetworkTopologyStrategy,
            replication_options.strategy
        );
        assert_eq!(1, replication_options.replication_factor);

        Ok(())
    }

    #[test]
    fn test_replication_strategy_from_str() -> anyhow::Result<()> {
        assert_eq!(
            ScyllaDBReplicationStrategy::SimpleStrategy,
            ScyllaDBReplicationStrategy::from_str("simple")?
        );

        assert_eq!(
            ScyllaDBReplicationStrategy::SimpleStrategy,
            ScyllaDBReplicationStrategy::from_str("SimpleStrategy")?
        );

        assert_eq!(
            ScyllaDBReplicationStrategy::NetworkTopologyStrategy,
            ScyllaDBReplicationStrategy::from_str("network_topology")?
        );

        assert_eq!(
            ScyllaDBReplicationStrategy::NetworkTopologyStrategy,
            ScyllaDBReplicationStrategy::from_str("NetworkTopologyStrategy")?
        );

        assert!(ScyllaDBReplicationStrategy::from_str("everywhere").is_err());

        Ok(())
    }

    /// `Display` output parses back to the same strategy.
    #[test]
    fn test_replication_strategy_display_round_trip() -> anyhow::Result<()> {
        for strategy in [
            ScyllaDBReplicationStrategy::SimpleStrategy,
            ScyllaDBReplicationStrategy::NetworkTopologyStrategy,
        ] {
            assert_eq!(
                strategy,
                ScyllaDBReplicationStrategy::from_str(&strategy.to_string())?
            );
        }

        Ok(())
    }

    #[test]
    fn test_replication_factor() -> anyhow::Result<()> {
        let options = ScyllaDBConnectOptions::new();

        assert!(options.replication_options.is_none());

        let options = options.replication_factor(2);

        let replication_options = options.replication_options.unwrap();
        assert_eq!(
            ScyllaDBReplicationStrategy::SimpleStrategy,
            replication_options.strategy
        );
        assert_eq!(2, replication_options.replication_factor);

        Ok(())
    }

    /// Setting the factor and then the strategy keeps both values.
    #[test]
    fn test_replication_options_are_merged() {
        let options = ScyllaDBConnectOptions::new()
            .replication_factor(3)
            .replication_strategy(ScyllaDBReplicationStrategy::NetworkTopologyStrategy);

        let replication_options = options.replication_options.unwrap();
        assert_eq!(
            ScyllaDBReplicationStrategy::NetworkTopologyStrategy,
            replication_options.strategy
        );
        assert_eq!(3, replication_options.replication_factor);
    }

    #[test]
    fn test_compressor() -> anyhow::Result<()> {
        let options = ScyllaDBConnectOptions::new();

        assert!(options.compression_options.is_none());

        let options = options.compressor(ScyllaDBCompressor::SnappyCompressor);

        let compression_options = options.compression_options.unwrap();
        assert_eq!(
            ScyllaDBCompressor::SnappyCompressor,
            compression_options.compressor
        );

        Ok(())
    }

    /// Compressor names are case-insensitive; unknown names are rejected.
    #[test]
    fn test_compressor_from_str() -> anyhow::Result<()> {
        assert_eq!(
            ScyllaDBCompressor::LZ4Compressor,
            ScyllaDBCompressor::from_str("LZ4")?
        );
        assert_eq!(
            ScyllaDBCompressor::SnappyCompressor,
            ScyllaDBCompressor::from_str("snappy")?
        );
        assert!(ScyllaDBCompressor::from_str("zstd").is_err());

        Ok(())
    }

    /// Each TLS setter keeps the paths set by the other setters.
    #[test]
    fn test_tls_options_are_merged() {
        let options = ScyllaDBConnectOptions::new()
            .tls_cert("client.pem")
            .tls_key("client.key")
            .tls_rootcert("root.pem");

        let tls_options = options.tls_options.unwrap();
        assert_eq!("root.pem", tls_options.root_cert);
        assert_eq!(Some("client.pem".to_string()), tls_options.cert);
        assert_eq!(Some("client.key".to_string()), tls_options.key);
    }

    #[test]
    fn test_tcp_nodelay() -> anyhow::Result<()> {
        let options = ScyllaDBConnectOptions::new();

        assert!(!options.tcp_nodelay);

        let options = options.tcp_nodelay();

        assert!(options.tcp_nodelay);

        Ok(())
    }

    #[test]
    fn test_tcp_keepalive() -> anyhow::Result<()> {
        let options = ScyllaDBConnectOptions::new();

        assert!(options.tcp_keepalive.is_none());

        let options = options.tcp_keepalive(20);

        assert_eq!(Duration::from_secs(20), options.tcp_keepalive.unwrap());

        Ok(())
    }

    #[test]
    fn test_page_size() -> anyhow::Result<()> {
        let options = ScyllaDBConnectOptions::new();

        assert_eq!(5000, options.page_size);

        let options = options.page_size(200);

        assert_eq!(200, options.page_size);

        Ok(())
    }
}