watermelon_mini/proto/connector.rs

1use std::io;
2
3use tokio::net::TcpStream;
4use tokio_rustls::{
5    TlsConnector,
6    rustls::{
7        self,
8        pki_types::{InvalidDnsNameError, ServerName},
9    },
10};
11use watermelon_net::{
12    Connection, StreamingConnection, connect_tcp,
13    error::{ConnectionReadError, StreamingReadError},
14    proto_connect,
15};
16#[cfg(feature = "websocket")]
17use watermelon_net::{WebsocketConnection, error::WebsocketReadError};
18#[cfg(feature = "websocket")]
19use watermelon_proto::proto::error::FrameDecoderError;
20use watermelon_proto::{
21    Connect, Host, NonStandardConnect, Protocol, ServerAddr, ServerInfo, Transport,
22    proto::{ServerOp, error::DecoderError},
23};
24
25use crate::{ConnectFlags, ConnectionCompression, util::MaybeConnection};
26
27use super::{
28    authenticator::{AuthenticationError, AuthenticationMethod},
29    connection::ConnectionSecurity,
30};
31
32#[derive(Debug, thiserror::Error)]
33pub enum ConnectError {
34    #[error("io error")]
35    Io(#[source] io::Error),
36    #[error("TLS error")]
37    Tls(rustls::Error),
38    #[error("invalid DNS name")]
39    InvalidDnsName(#[source] InvalidDnsNameError),
40    #[error("websocket not supported")]
41    WebsocketUnsupported,
42    #[error("unexpected ServerOp")]
43    UnexpectedServerOp,
44    #[error("decoder error")]
45    Decoder(#[source] DecoderError),
46    #[error("authentication error")]
47    Authentication(#[source] AuthenticationError),
48    #[error("connect")]
49    Connect(#[source] watermelon_net::error::ConnectError),
50}
51
52#[expect(clippy::too_many_lines)]
53pub(crate) async fn connect(
54    connector: &TlsConnector,
55    addr: &ServerAddr,
56    client_name: String,
57    auth_method: Option<&AuthenticationMethod>,
58    flags: ConnectFlags,
59) -> Result<
60    (
61        Connection<
62            ConnectionCompression<ConnectionSecurity<TcpStream>>,
63            ConnectionSecurity<TcpStream>,
64        >,
65        Box<ServerInfo>,
66    ),
67    ConnectError,
68> {
69    let conn = connect_tcp(addr).await.map_err(ConnectError::Io)?;
70    conn.set_nodelay(flags.tcp_nodelay)
71        .map_err(ConnectError::Io)?;
72    let mut conn = ConnectionSecurity::Plain(conn);
73
74    if matches!(addr.protocol(), Protocol::TLS) {
75        let domain = rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?;
76        conn = conn
77            .upgrade_tls(connector, domain.to_owned())
78            .await
79            .map_err(ConnectError::Io)?;
80    }
81
82    let mut conn = match addr.transport() {
83        Transport::TCP => Connection::Streaming(StreamingConnection::new(conn)),
84        #[cfg(feature = "websocket")]
85        Transport::Websocket => {
86            let uri = addr.to_string().parse().unwrap();
87            Connection::Websocket(
88                WebsocketConnection::new(uri, conn)
89                    .await
90                    .map_err(ConnectError::Io)?,
91            )
92        }
93        #[cfg(not(feature = "websocket"))]
94        Transport::Websocket => return Err(ConnectError::WebsocketUnsupported),
95    };
96    let info = match conn.read_next().await {
97        Ok(ServerOp::Info { info }) => info,
98        Ok(_) => return Err(ConnectError::UnexpectedServerOp),
99        Err(ConnectionReadError::Streaming(StreamingReadError::Io(err))) => {
100            return Err(ConnectError::Io(err));
101        }
102        Err(ConnectionReadError::Streaming(StreamingReadError::Decoder(err))) => {
103            return Err(ConnectError::Decoder(err));
104        }
105        #[cfg(feature = "websocket")]
106        Err(ConnectionReadError::Websocket(WebsocketReadError::Io(err))) => {
107            return Err(ConnectError::Io(err));
108        }
109        #[cfg(feature = "websocket")]
110        Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
111            FrameDecoderError::Decoder(err),
112        ))) => return Err(ConnectError::Decoder(err)),
113        #[cfg(feature = "websocket")]
114        Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
115            FrameDecoderError::IncompleteFrame,
116        ))) => todo!(),
117        #[cfg(feature = "websocket")]
118        Err(ConnectionReadError::Websocket(WebsocketReadError::Closed)) => todo!(),
119    };
120
121    let conn = match conn {
122        Connection::Streaming(streaming) => Connection::Streaming(
123            if matches!(
124                (addr.protocol(), info.tls_required),
125                (Protocol::PossiblyPlain, true)
126            ) {
127                let domain =
128                    rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?;
129                StreamingConnection::new(
130                    streaming
131                        .into_inner()
132                        .upgrade_tls(connector, domain.to_owned())
133                        .await
134                        .map_err(ConnectError::Io)?,
135                )
136            } else {
137                streaming
138            },
139        ),
140        Connection::Websocket(websocket) => Connection::Websocket(websocket),
141    };
142
143    let auth;
144    let auth_method = if let Some(auth_method) = auth_method {
145        Some(auth_method)
146    } else if let Some(auth_method) = AuthenticationMethod::try_from_addr(addr) {
147        auth = auth_method;
148        Some(&auth)
149    } else {
150        None
151    };
152
153    #[allow(unused_mut)]
154    let mut non_standard = NonStandardConnect::default();
155    #[cfg(feature = "non-standard-zstd")]
156    if matches!(conn, Connection::Streaming(_)) {
157        non_standard.zstd = flags.zstd_compression_level.is_some() && info.non_standard.zstd;
158    }
159
160    let mut connect = Connect {
161        verbose: true,
162        pedantic: false,
163        require_tls: false,
164        auth_token: None,
165        username: None,
166        password: None,
167        client_name: Some(client_name),
168        client_lang: "rust-watermelon",
169        client_version: env!("CARGO_PKG_VERSION"),
170        protocol: 1,
171        echo: flags.echo,
172        signature: None,
173        jwt: None,
174        supports_no_responders: true,
175        supports_headers: true,
176        nkey: None,
177        non_standard,
178    };
179    if let Some(auth_method) = auth_method {
180        auth_method
181            .prepare_for_auth(&info, &mut connect)
182            .map_err(ConnectError::Authentication)?;
183    }
184
185    let mut conn = match conn {
186        Connection::Streaming(streaming) => {
187            Connection::Streaming(streaming.replace_socket(|stream| {
188                MaybeConnection(Some(ConnectionCompression::Plain(stream)))
189            }))
190        }
191        Connection::Websocket(websocket) => Connection::Websocket(websocket),
192    };
193
194    #[cfg(feature = "non-standard-zstd")]
195    let zstd = connect.non_standard.zstd;
196
197    proto_connect(&mut conn, connect, |conn| {
198        #[cfg(feature = "non-standard-zstd")]
199        match conn {
200            Connection::Streaming(streaming) => {
201                if zstd {
202                    if let Some(zstd_compression_level) = flags.zstd_compression_level {
203                        let stream = streaming.socket_mut().0.take().unwrap();
204                        streaming.socket_mut().0 =
205                            Some(stream.upgrade_zstd(zstd_compression_level));
206                    }
207                }
208            }
209            Connection::Websocket(_websocket) => {}
210        }
211
212        let _ = conn;
213    })
214    .await
215    .map_err(ConnectError::Connect)?;
216
217    let conn = match conn {
218        Connection::Streaming(streaming) => {
219            Connection::Streaming(streaming.replace_socket(|stream| stream.0.unwrap()))
220        }
221        Connection::Websocket(websocket) => Connection::Websocket(websocket),
222    };
223
224    Ok((conn, info))
225}
226
227fn rustls_server_name_from_addr(addr: &ServerAddr) -> Result<ServerName<'_>, InvalidDnsNameError> {
228    match addr.host() {
229        Host::Ip(addr) => Ok(ServerName::IpAddress((*addr).into())),
230        Host::Dns(name) => <_ as AsRef<str>>::as_ref(name).try_into(),
231    }
232}