1use std::io;
2
3use tokio::net::TcpStream;
4use tokio_rustls::{
5 TlsConnector,
6 rustls::{
7 self,
8 pki_types::{InvalidDnsNameError, ServerName},
9 },
10};
11use watermelon_net::{
12 Connection, StreamingConnection, connect_tcp,
13 error::{ConnectionReadError, StreamingReadError},
14 proto_connect,
15};
16#[cfg(feature = "websocket")]
17use watermelon_net::{WebsocketConnection, error::WebsocketReadError};
18#[cfg(feature = "websocket")]
19use watermelon_proto::proto::error::FrameDecoderError;
20use watermelon_proto::{
21 Connect, Host, NonStandardConnect, Protocol, ServerAddr, ServerInfo, Transport,
22 proto::{ServerOp, error::DecoderError},
23};
24
25use crate::{ConnectFlags, ConnectionCompression, util::MaybeConnection};
26
27use super::{
28 authenticator::{AuthenticationError, AuthenticationMethod},
29 connection::ConnectionSecurity,
30};
31
32#[derive(Debug, thiserror::Error)]
33pub enum ConnectError {
34 #[error("io error")]
35 Io(#[source] io::Error),
36 #[error("TLS error")]
37 Tls(rustls::Error),
38 #[error("invalid DNS name")]
39 InvalidDnsName(#[source] InvalidDnsNameError),
40 #[error("websocket not supported")]
41 WebsocketUnsupported,
42 #[error("unexpected ServerOp")]
43 UnexpectedServerOp,
44 #[error("decoder error")]
45 Decoder(#[source] DecoderError),
46 #[error("authentication error")]
47 Authentication(#[source] AuthenticationError),
48 #[error("connect")]
49 Connect(#[source] watermelon_net::error::ConnectError),
50}
51
52#[expect(clippy::too_many_lines)]
53pub(crate) async fn connect(
54 connector: &TlsConnector,
55 addr: &ServerAddr,
56 client_name: String,
57 auth_method: Option<&AuthenticationMethod>,
58 flags: ConnectFlags,
59) -> Result<
60 (
61 Connection<
62 ConnectionCompression<ConnectionSecurity<TcpStream>>,
63 ConnectionSecurity<TcpStream>,
64 >,
65 Box<ServerInfo>,
66 ),
67 ConnectError,
68> {
69 let conn = connect_tcp(addr).await.map_err(ConnectError::Io)?;
70 conn.set_nodelay(flags.tcp_nodelay)
71 .map_err(ConnectError::Io)?;
72 let mut conn = ConnectionSecurity::Plain(conn);
73
74 if matches!(addr.protocol(), Protocol::TLS) {
75 let domain = rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?;
76 conn = conn
77 .upgrade_tls(connector, domain.to_owned())
78 .await
79 .map_err(ConnectError::Io)?;
80 }
81
82 let mut conn = match addr.transport() {
83 Transport::TCP => Connection::Streaming(StreamingConnection::new(conn)),
84 #[cfg(feature = "websocket")]
85 Transport::Websocket => {
86 let uri = addr.to_string().parse().unwrap();
87 Connection::Websocket(
88 WebsocketConnection::new(uri, conn)
89 .await
90 .map_err(ConnectError::Io)?,
91 )
92 }
93 #[cfg(not(feature = "websocket"))]
94 Transport::Websocket => return Err(ConnectError::WebsocketUnsupported),
95 };
96 let info = match conn.read_next().await {
97 Ok(ServerOp::Info { info }) => info,
98 Ok(_) => return Err(ConnectError::UnexpectedServerOp),
99 Err(ConnectionReadError::Streaming(StreamingReadError::Io(err))) => {
100 return Err(ConnectError::Io(err));
101 }
102 Err(ConnectionReadError::Streaming(StreamingReadError::Decoder(err))) => {
103 return Err(ConnectError::Decoder(err));
104 }
105 #[cfg(feature = "websocket")]
106 Err(ConnectionReadError::Websocket(WebsocketReadError::Io(err))) => {
107 return Err(ConnectError::Io(err));
108 }
109 #[cfg(feature = "websocket")]
110 Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
111 FrameDecoderError::Decoder(err),
112 ))) => return Err(ConnectError::Decoder(err)),
113 #[cfg(feature = "websocket")]
114 Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
115 FrameDecoderError::IncompleteFrame,
116 ))) => todo!(),
117 #[cfg(feature = "websocket")]
118 Err(ConnectionReadError::Websocket(WebsocketReadError::Closed)) => todo!(),
119 };
120
121 let conn = match conn {
122 Connection::Streaming(streaming) => Connection::Streaming(
123 if matches!(
124 (addr.protocol(), info.tls_required),
125 (Protocol::PossiblyPlain, true)
126 ) {
127 let domain =
128 rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?;
129 StreamingConnection::new(
130 streaming
131 .into_inner()
132 .upgrade_tls(connector, domain.to_owned())
133 .await
134 .map_err(ConnectError::Io)?,
135 )
136 } else {
137 streaming
138 },
139 ),
140 Connection::Websocket(websocket) => Connection::Websocket(websocket),
141 };
142
143 let auth;
144 let auth_method = if let Some(auth_method) = auth_method {
145 Some(auth_method)
146 } else if let Some(auth_method) = AuthenticationMethod::try_from_addr(addr) {
147 auth = auth_method;
148 Some(&auth)
149 } else {
150 None
151 };
152
153 #[allow(unused_mut)]
154 let mut non_standard = NonStandardConnect::default();
155 #[cfg(feature = "non-standard-zstd")]
156 if matches!(conn, Connection::Streaming(_)) {
157 non_standard.zstd = flags.zstd_compression_level.is_some() && info.non_standard.zstd;
158 }
159
160 let mut connect = Connect {
161 verbose: true,
162 pedantic: false,
163 require_tls: false,
164 auth_token: None,
165 username: None,
166 password: None,
167 client_name: Some(client_name),
168 client_lang: "rust-watermelon",
169 client_version: env!("CARGO_PKG_VERSION"),
170 protocol: 1,
171 echo: flags.echo,
172 signature: None,
173 jwt: None,
174 supports_no_responders: true,
175 supports_headers: true,
176 nkey: None,
177 non_standard,
178 };
179 if let Some(auth_method) = auth_method {
180 auth_method
181 .prepare_for_auth(&info, &mut connect)
182 .map_err(ConnectError::Authentication)?;
183 }
184
185 let mut conn = match conn {
186 Connection::Streaming(streaming) => {
187 Connection::Streaming(streaming.replace_socket(|stream| {
188 MaybeConnection(Some(ConnectionCompression::Plain(stream)))
189 }))
190 }
191 Connection::Websocket(websocket) => Connection::Websocket(websocket),
192 };
193
194 #[cfg(feature = "non-standard-zstd")]
195 let zstd = connect.non_standard.zstd;
196
197 proto_connect(&mut conn, connect, |conn| {
198 #[cfg(feature = "non-standard-zstd")]
199 match conn {
200 Connection::Streaming(streaming) => {
201 if zstd {
202 if let Some(zstd_compression_level) = flags.zstd_compression_level {
203 let stream = streaming.socket_mut().0.take().unwrap();
204 streaming.socket_mut().0 =
205 Some(stream.upgrade_zstd(zstd_compression_level));
206 }
207 }
208 }
209 Connection::Websocket(_websocket) => {}
210 }
211
212 let _ = conn;
213 })
214 .await
215 .map_err(ConnectError::Connect)?;
216
217 let conn = match conn {
218 Connection::Streaming(streaming) => {
219 Connection::Streaming(streaming.replace_socket(|stream| stream.0.unwrap()))
220 }
221 Connection::Websocket(websocket) => Connection::Websocket(websocket),
222 };
223
224 Ok((conn, info))
225}
226
227fn rustls_server_name_from_addr(addr: &ServerAddr) -> Result<ServerName<'_>, InvalidDnsNameError> {
228 match addr.host() {
229 Host::Ip(addr) => Ok(ServerName::IpAddress((*addr).into())),
230 Host::Dns(name) => <_ as AsRef<str>>::as_ref(name).try_into(),
231 }
232}