1use std::io;
2
3use tokio::net::TcpStream;
4use tokio_rustls::{
5 TlsConnector,
6 rustls::{
7 self,
8 pki_types::{InvalidDnsNameError, ServerName},
9 },
10};
11use watermelon_net::{
12 Connection, StreamingConnection, connect_tcp,
13 error::{ConnectionReadError, StreamingReadError},
14 proto_connect,
15};
16#[cfg(feature = "websocket")]
17use watermelon_net::{WebsocketConnection, error::WebsocketReadError};
18#[cfg(feature = "websocket")]
19use watermelon_proto::proto::error::FrameDecoderError;
20use watermelon_proto::{
21 Connect, Host, NonStandardConnect, Protocol, ServerAddr, ServerInfo, Transport,
22 proto::{ServerOp, error::DecoderError},
23};
24
25use crate::{ConnectFlags, ConnectionCompression, util::MaybeConnection};
26
27use super::{
28 authenticator::{AuthenticationError, AuthenticationMethod},
29 connection::ConnectionSecurity,
30};
31
32#[derive(Debug, thiserror::Error)]
33pub enum ConnectError {
34 #[error("io error")]
35 Io(#[source] io::Error),
36 #[error("TLS error")]
37 Tls(rustls::Error),
38 #[error("invalid DNS name")]
39 InvalidDnsName(#[source] InvalidDnsNameError),
40 #[error("websocket not supported")]
41 WebsocketUnsupported,
42 #[error("unexpected ServerOp")]
43 UnexpectedServerOp,
44 #[error("decoder error")]
45 Decoder(#[source] DecoderError),
46 #[error("authentication error")]
47 Authentication(#[source] AuthenticationError),
48 #[error("connect")]
49 Connect(#[source] watermelon_net::error::ConnectError),
50}
51
52#[expect(clippy::too_many_lines)]
53pub(crate) async fn connect(
54 connector: &TlsConnector,
55 addr: &ServerAddr,
56 client_name: String,
57 auth_method: Option<&AuthenticationMethod>,
58 flags: ConnectFlags,
59) -> Result<
60 (
61 Connection<
62 ConnectionCompression<ConnectionSecurity<TcpStream>>,
63 ConnectionSecurity<TcpStream>,
64 >,
65 Box<ServerInfo>,
66 ),
67 ConnectError,
68> {
69 let conn = connect_tcp(addr).await.map_err(ConnectError::Io)?;
70 conn.set_nodelay(true).map_err(ConnectError::Io)?;
71 let mut conn = ConnectionSecurity::Plain(conn);
72
73 if matches!(addr.protocol(), Protocol::TLS) {
74 let domain = rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?;
75 conn = conn
76 .upgrade_tls(connector, domain.to_owned())
77 .await
78 .map_err(ConnectError::Io)?;
79 }
80
81 let mut conn = match addr.transport() {
82 Transport::TCP => Connection::Streaming(StreamingConnection::new(conn)),
83 #[cfg(feature = "websocket")]
84 Transport::Websocket => {
85 let uri = addr.to_string().parse().unwrap();
86 Connection::Websocket(
87 WebsocketConnection::new(uri, conn)
88 .await
89 .map_err(ConnectError::Io)?,
90 )
91 }
92 #[cfg(not(feature = "websocket"))]
93 Transport::Websocket => return Err(ConnectError::WebsocketUnsupported),
94 };
95 let info = match conn.read_next().await {
96 Ok(ServerOp::Info { info }) => info,
97 Ok(_) => return Err(ConnectError::UnexpectedServerOp),
98 Err(ConnectionReadError::Streaming(StreamingReadError::Io(err))) => {
99 return Err(ConnectError::Io(err));
100 }
101 Err(ConnectionReadError::Streaming(StreamingReadError::Decoder(err))) => {
102 return Err(ConnectError::Decoder(err));
103 }
104 #[cfg(feature = "websocket")]
105 Err(ConnectionReadError::Websocket(WebsocketReadError::Io(err))) => {
106 return Err(ConnectError::Io(err));
107 }
108 #[cfg(feature = "websocket")]
109 Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
110 FrameDecoderError::Decoder(err),
111 ))) => return Err(ConnectError::Decoder(err)),
112 #[cfg(feature = "websocket")]
113 Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
114 FrameDecoderError::IncompleteFrame,
115 ))) => todo!(),
116 #[cfg(feature = "websocket")]
117 Err(ConnectionReadError::Websocket(WebsocketReadError::Closed)) => todo!(),
118 };
119
120 let conn = match conn {
121 Connection::Streaming(streaming) => Connection::Streaming(
122 if matches!(
123 (addr.protocol(), info.tls_required),
124 (Protocol::PossiblyPlain, true)
125 ) {
126 let domain =
127 rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?;
128 StreamingConnection::new(
129 streaming
130 .into_inner()
131 .upgrade_tls(connector, domain.to_owned())
132 .await
133 .map_err(ConnectError::Io)?,
134 )
135 } else {
136 streaming
137 },
138 ),
139 Connection::Websocket(websocket) => Connection::Websocket(websocket),
140 };
141
142 let auth;
143 let auth_method = if let Some(auth_method) = auth_method {
144 Some(auth_method)
145 } else if let Some(auth_method) = AuthenticationMethod::try_from_addr(addr) {
146 auth = auth_method;
147 Some(&auth)
148 } else {
149 None
150 };
151
152 #[allow(unused_mut)]
153 let mut non_standard = NonStandardConnect::default();
154 #[cfg(feature = "non-standard-zstd")]
155 if matches!(conn, Connection::Streaming(_)) {
156 non_standard.zstd = flags.zstd_compression_level.is_some() && info.non_standard.zstd;
157 }
158
159 let mut connect = Connect {
160 verbose: true,
161 pedantic: false,
162 require_tls: false,
163 auth_token: None,
164 username: None,
165 password: None,
166 client_name: Some(client_name),
167 client_lang: "rust-watermelon",
168 client_version: env!("CARGO_PKG_VERSION"),
169 protocol: 1,
170 echo: flags.echo,
171 signature: None,
172 jwt: None,
173 supports_no_responders: true,
174 supports_headers: true,
175 nkey: None,
176 non_standard,
177 };
178 if let Some(auth_method) = auth_method {
179 auth_method
180 .prepare_for_auth(&info, &mut connect)
181 .map_err(ConnectError::Authentication)?;
182 }
183
184 let mut conn = match conn {
185 Connection::Streaming(streaming) => {
186 Connection::Streaming(streaming.replace_socket(|stream| {
187 MaybeConnection(Some(ConnectionCompression::Plain(stream)))
188 }))
189 }
190 Connection::Websocket(websocket) => Connection::Websocket(websocket),
191 };
192
193 #[cfg(feature = "non-standard-zstd")]
194 let zstd = connect.non_standard.zstd;
195
196 proto_connect(&mut conn, connect, |conn| {
197 #[cfg(feature = "non-standard-zstd")]
198 match conn {
199 Connection::Streaming(streaming) => {
200 if zstd {
201 if let Some(zstd_compression_level) = flags.zstd_compression_level {
202 let stream = streaming.socket_mut().0.take().unwrap();
203 streaming.socket_mut().0 =
204 Some(stream.upgrade_zstd(zstd_compression_level));
205 }
206 }
207 }
208 Connection::Websocket(_websocket) => {}
209 }
210
211 let _ = conn;
212 })
213 .await
214 .map_err(ConnectError::Connect)?;
215
216 let conn = match conn {
217 Connection::Streaming(streaming) => {
218 Connection::Streaming(streaming.replace_socket(|stream| stream.0.unwrap()))
219 }
220 Connection::Websocket(websocket) => Connection::Websocket(websocket),
221 };
222
223 Ok((conn, info))
224}
225
226fn rustls_server_name_from_addr(addr: &ServerAddr) -> Result<ServerName<'_>, InvalidDnsNameError> {
227 match addr.host() {
228 Host::Ip(addr) => Ok(ServerName::IpAddress((*addr).into())),
229 Host::Dns(name) => <_ as AsRef<str>>::as_ref(name).try_into(),
230 }
231}