1use std::io;
2
3use tokio::net::TcpStream;
4use tokio_rustls::{
5 rustls::pki_types::{InvalidDnsNameError, ServerName},
6 TlsConnector,
7};
8use watermelon_net::{
9 connect_tcp,
10 error::{ConnectionReadError, StreamingReadError},
11 proto_connect, Connection, StreamingConnection,
12};
13#[cfg(feature = "websocket")]
14use watermelon_net::{error::WebsocketReadError, WebsocketConnection};
15#[cfg(feature = "websocket")]
16use watermelon_proto::proto::error::FrameDecoderError;
17use watermelon_proto::{
18 proto::{error::DecoderError, ServerOp},
19 Connect, Host, NonStandardConnect, Protocol, ServerAddr, ServerInfo, Transport,
20};
21
22use crate::{util::MaybeConnection, ConnectFlags, ConnectionCompression};
23
24use super::{
25 authenticator::{AuthenticationError, AuthenticationMethod},
26 connection::ConnectionSecurity,
27};
28
29#[derive(Debug, thiserror::Error)]
30pub enum ConnectError {
31 #[error("io error")]
32 Io(#[source] io::Error),
33 #[error("invalid DNS name")]
34 InvalidDnsName(#[source] InvalidDnsNameError),
35 #[error("websocket not supported")]
36 WebsocketUnsupported,
37 #[error("unexpected ServerOp")]
38 UnexpectedServerOp,
39 #[error("decoder error")]
40 Decoder(#[source] DecoderError),
41 #[error("authentication error")]
42 Authentication(#[source] AuthenticationError),
43 #[error("connect")]
44 Connect(#[source] watermelon_net::error::ConnectError),
45}
46
47#[expect(clippy::too_many_lines)]
48pub(crate) async fn connect(
49 connector: &TlsConnector,
50 addr: &ServerAddr,
51 client_name: String,
52 auth_method: Option<&AuthenticationMethod>,
53 flags: ConnectFlags,
54) -> Result<
55 (
56 Connection<
57 ConnectionCompression<ConnectionSecurity<TcpStream>>,
58 ConnectionSecurity<TcpStream>,
59 >,
60 Box<ServerInfo>,
61 ),
62 ConnectError,
63> {
64 let conn = connect_tcp(addr).await.map_err(ConnectError::Io)?;
65 conn.set_nodelay(true).map_err(ConnectError::Io)?;
66 let mut conn = ConnectionSecurity::Plain(conn);
67
68 if matches!(addr.protocol(), Protocol::TLS) {
69 let domain = rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?;
70 conn = conn
71 .upgrade_tls(connector, domain.to_owned())
72 .await
73 .map_err(ConnectError::Io)?;
74 }
75
76 let mut conn = match addr.transport() {
77 Transport::TCP => Connection::Streaming(StreamingConnection::new(conn)),
78 #[cfg(feature = "websocket")]
79 Transport::Websocket => {
80 let uri = addr.to_string().parse().unwrap();
81 Connection::Websocket(
82 WebsocketConnection::new(uri, conn)
83 .await
84 .map_err(ConnectError::Io)?,
85 )
86 }
87 #[cfg(not(feature = "websocket"))]
88 Transport::Websocket => return Err(ConnectError::WebsocketUnsupported),
89 };
90 let info = match conn.read_next().await {
91 Ok(ServerOp::Info { info }) => info,
92 Ok(_) => return Err(ConnectError::UnexpectedServerOp),
93 Err(ConnectionReadError::Streaming(StreamingReadError::Io(err))) => {
94 return Err(ConnectError::Io(err))
95 }
96 Err(ConnectionReadError::Streaming(StreamingReadError::Decoder(err))) => {
97 return Err(ConnectError::Decoder(err))
98 }
99 #[cfg(feature = "websocket")]
100 Err(ConnectionReadError::Websocket(WebsocketReadError::Io(err))) => {
101 return Err(ConnectError::Io(err))
102 }
103 #[cfg(feature = "websocket")]
104 Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
105 FrameDecoderError::Decoder(err),
106 ))) => return Err(ConnectError::Decoder(err)),
107 #[cfg(feature = "websocket")]
108 Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
109 FrameDecoderError::IncompleteFrame,
110 ))) => todo!(),
111 #[cfg(feature = "websocket")]
112 Err(ConnectionReadError::Websocket(WebsocketReadError::Closed)) => todo!(),
113 };
114
115 let conn = match conn {
116 Connection::Streaming(streaming) => Connection::Streaming(
117 if matches!(
118 (addr.protocol(), info.tls_required),
119 (Protocol::PossiblyPlain, true)
120 ) {
121 let domain =
122 rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?;
123 StreamingConnection::new(
124 streaming
125 .into_inner()
126 .upgrade_tls(connector, domain.to_owned())
127 .await
128 .map_err(ConnectError::Io)?,
129 )
130 } else {
131 streaming
132 },
133 ),
134 Connection::Websocket(websocket) => Connection::Websocket(websocket),
135 };
136
137 let auth;
138 let auth_method = if let Some(auth_method) = auth_method {
139 Some(auth_method)
140 } else if let Some(auth_method) = AuthenticationMethod::try_from_addr(addr) {
141 auth = auth_method;
142 Some(&auth)
143 } else {
144 None
145 };
146
147 #[allow(unused_mut)]
148 let mut non_standard = NonStandardConnect::default();
149 #[cfg(feature = "non-standard-zstd")]
150 if matches!(conn, Connection::Streaming(_)) {
151 non_standard.zstd = flags.zstd_compression_level.is_some() && info.non_standard.zstd;
152 }
153
154 let mut connect = Connect {
155 verbose: true,
156 pedantic: false,
157 require_tls: false,
158 auth_token: None,
159 username: None,
160 password: None,
161 client_name: Some(client_name),
162 client_lang: "rust-watermelon",
163 client_version: env!("CARGO_PKG_VERSION"),
164 protocol: 1,
165 echo: flags.echo,
166 signature: None,
167 jwt: None,
168 supports_no_responders: true,
169 supports_headers: true,
170 nkey: None,
171 non_standard,
172 };
173 if let Some(auth_method) = auth_method {
174 auth_method
175 .prepare_for_auth(&info, &mut connect)
176 .map_err(ConnectError::Authentication)?;
177 }
178
179 let mut conn = match conn {
180 Connection::Streaming(streaming) => {
181 Connection::Streaming(streaming.replace_socket(|stream| {
182 MaybeConnection(Some(ConnectionCompression::Plain(stream)))
183 }))
184 }
185 Connection::Websocket(websocket) => Connection::Websocket(websocket),
186 };
187
188 #[cfg(feature = "non-standard-zstd")]
189 let zstd = connect.non_standard.zstd;
190
191 proto_connect(&mut conn, connect, |conn| {
192 #[cfg(feature = "non-standard-zstd")]
193 match conn {
194 Connection::Streaming(streaming) => {
195 if zstd {
196 if let Some(zstd_compression_level) = flags.zstd_compression_level {
197 let stream = streaming.socket_mut().0.take().unwrap();
198 streaming.socket_mut().0 =
199 Some(stream.upgrade_zstd(zstd_compression_level));
200 }
201 }
202 }
203 Connection::Websocket(_websocket) => {}
204 }
205
206 let _ = conn;
207 })
208 .await
209 .map_err(ConnectError::Connect)?;
210
211 let conn = match conn {
212 Connection::Streaming(streaming) => {
213 Connection::Streaming(streaming.replace_socket(|stream| stream.0.unwrap()))
214 }
215 Connection::Websocket(websocket) => Connection::Websocket(websocket),
216 };
217
218 Ok((conn, info))
219}
220
221fn rustls_server_name_from_addr(addr: &ServerAddr) -> Result<ServerName<'_>, InvalidDnsNameError> {
222 match addr.host() {
223 Host::Ip(addr) => Ok(ServerName::IpAddress((*addr).into())),
224 Host::Dns(name) => <_ as AsRef<str>>::as_ref(name).try_into(),
225 }
226}