1#![forbid(unsafe_code)]
2
3use std::sync::Arc;
4
5use rustls_platform_verifier::Verifier;
6use tokio::net::TcpStream;
7use tokio_rustls::{
8 TlsConnector,
9 rustls::{self, ClientConfig, crypto::CryptoProvider, version::TLS13},
10};
11use watermelon_net::Connection;
12use watermelon_proto::{ServerAddr, ServerInfo};
13
14#[cfg(feature = "non-standard-zstd")]
15pub use self::non_standard_zstd::ZstdStream;
16use self::proto::connect;
17pub use self::proto::{
18 AuthenticationMethod, ConnectError, ConnectionCompression, ConnectionSecurity,
19};
20
21#[cfg(feature = "non-standard-zstd")]
22pub(crate) mod non_standard_zstd;
23mod proto;
24mod util;
25
/// Options controlling how [`easy_connect`] sets up a connection.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build
/// it with a struct literal (or `..Default::default()` update syntax).
/// Start from [`ConnectFlags::default`] and assign the public fields instead:
///
/// ```ignore
/// let mut flags = ConnectFlags::default();
/// flags.echo = true;
/// ```
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ConnectFlags {
    /// `echo` option handed to the connection handshake together with the
    /// rest of these flags. Defaults to `false`.
    ///
    /// NOTE(review): presumably, when `true`, the server delivers messages
    /// published on this connection back to this connection's own
    /// subscriptions. Confirm against `proto::connect`.
    pub echo: bool,
    /// Compression level used for the non-standard zstd connection
    /// compression. Only present with the `non-standard-zstd` feature.
    /// Defaults to `Some(3)`.
    ///
    /// NOTE(review): presumably `None` disables zstd compression. Confirm
    /// against `proto::connect`.
    #[cfg(feature = "non-standard-zstd")]
    pub zstd_compression_level: Option<u8>,
}
33
34#[cfg_attr(not(feature = "non-standard-zstd"), expect(clippy::derivable_impls))]
35impl Default for ConnectFlags {
36 fn default() -> Self {
37 Self {
38 echo: false,
39 #[cfg(feature = "non-standard-zstd")]
40 zstd_compression_level: Some(3),
41 }
42 }
43}
44
45#[expect(
54 clippy::missing_panics_doc,
55 reason = "the crypto_provider function always returns a provider that supports TLS 1.3"
56)]
57pub async fn easy_connect(
58 addr: &ServerAddr,
59 auth: Option<&AuthenticationMethod>,
60 flags: ConnectFlags,
61) -> Result<
62 (
63 Connection<
64 ConnectionCompression<ConnectionSecurity<TcpStream>>,
65 ConnectionSecurity<TcpStream>,
66 >,
67 Box<ServerInfo>,
68 ),
69 ConnectError,
70> {
71 let provider = Arc::new(crypto_provider());
72 let connector = TlsConnector::from(Arc::new(
73 ClientConfig::builder_with_provider(Arc::clone(&provider))
74 .with_protocol_versions(&[&TLS13])
75 .unwrap()
76 .dangerous()
77 .with_custom_certificate_verifier(Arc::new(
78 Verifier::new(provider).map_err(ConnectError::Tls)?,
79 ))
80 .with_no_client_auth(),
81 ));
82
83 let (conn, info) = connect(&connector, addr, "watermelon".to_owned(), auth, flags).await?;
84 Ok((conn, info))
85}
86
87fn crypto_provider() -> CryptoProvider {
88 #[cfg(feature = "aws-lc-rs")]
89 return rustls::crypto::aws_lc_rs::default_provider();
90 #[cfg(all(not(feature = "aws-lc-rs"), feature = "ring"))]
91 return rustls::crypto::ring::default_provider();
92 #[cfg(not(any(feature = "aws-lc-rs", feature = "ring")))]
93 compile_error!("Please enable the `aws-lc-rs` or the `ring` feature")
94}