watermelon_mini/proto/
connector.rs

1use std::io;
2
3use tokio::net::TcpStream;
4use tokio_rustls::{
5    TlsConnector,
6    rustls::{
7        self,
8        pki_types::{InvalidDnsNameError, ServerName},
9    },
10};
11use watermelon_net::{
12    Connection, StreamingConnection, connect_tcp,
13    error::{ConnectionReadError, StreamingReadError},
14    proto_connect,
15};
16#[cfg(feature = "websocket")]
17use watermelon_net::{WebsocketConnection, error::WebsocketReadError};
18#[cfg(feature = "websocket")]
19use watermelon_proto::proto::error::FrameDecoderError;
20use watermelon_proto::{
21    Connect, Host, NonStandardConnect, Protocol, ServerAddr, ServerInfo, Transport,
22    proto::{ServerOp, error::DecoderError},
23};
24
25use crate::{ConnectFlags, ConnectionCompression, util::MaybeConnection};
26
27use super::{
28    authenticator::{AuthenticationError, AuthenticationMethod},
29    connection::ConnectionSecurity,
30};
31
32#[derive(Debug, thiserror::Error)]
33pub enum ConnectError {
34    #[error("io error")]
35    Io(#[source] io::Error),
36    #[error("TLS error")]
37    Tls(rustls::Error),
38    #[error("invalid DNS name")]
39    InvalidDnsName(#[source] InvalidDnsNameError),
40    #[error("websocket not supported")]
41    WebsocketUnsupported,
42    #[error("unexpected ServerOp")]
43    UnexpectedServerOp,
44    #[error("decoder error")]
45    Decoder(#[source] DecoderError),
46    #[error("authentication error")]
47    Authentication(#[source] AuthenticationError),
48    #[error("connect")]
49    Connect(#[source] watermelon_net::error::ConnectError),
50}
51
52#[expect(clippy::too_many_lines)]
53pub(crate) async fn connect(
54    connector: &TlsConnector,
55    addr: &ServerAddr,
56    client_name: String,
57    auth_method: Option<&AuthenticationMethod>,
58    flags: ConnectFlags,
59) -> Result<
60    (
61        Connection<
62            ConnectionCompression<ConnectionSecurity<TcpStream>>,
63            ConnectionSecurity<TcpStream>,
64        >,
65        Box<ServerInfo>,
66    ),
67    ConnectError,
68> {
69    let conn = connect_tcp(addr).await.map_err(ConnectError::Io)?;
70    conn.set_nodelay(true).map_err(ConnectError::Io)?;
71    let mut conn = ConnectionSecurity::Plain(conn);
72
73    if matches!(addr.protocol(), Protocol::TLS) {
74        let domain = rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?;
75        conn = conn
76            .upgrade_tls(connector, domain.to_owned())
77            .await
78            .map_err(ConnectError::Io)?;
79    }
80
81    let mut conn = match addr.transport() {
82        Transport::TCP => Connection::Streaming(StreamingConnection::new(conn)),
83        #[cfg(feature = "websocket")]
84        Transport::Websocket => {
85            let uri = addr.to_string().parse().unwrap();
86            Connection::Websocket(
87                WebsocketConnection::new(uri, conn)
88                    .await
89                    .map_err(ConnectError::Io)?,
90            )
91        }
92        #[cfg(not(feature = "websocket"))]
93        Transport::Websocket => return Err(ConnectError::WebsocketUnsupported),
94    };
95    let info = match conn.read_next().await {
96        Ok(ServerOp::Info { info }) => info,
97        Ok(_) => return Err(ConnectError::UnexpectedServerOp),
98        Err(ConnectionReadError::Streaming(StreamingReadError::Io(err))) => {
99            return Err(ConnectError::Io(err));
100        }
101        Err(ConnectionReadError::Streaming(StreamingReadError::Decoder(err))) => {
102            return Err(ConnectError::Decoder(err));
103        }
104        #[cfg(feature = "websocket")]
105        Err(ConnectionReadError::Websocket(WebsocketReadError::Io(err))) => {
106            return Err(ConnectError::Io(err));
107        }
108        #[cfg(feature = "websocket")]
109        Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
110            FrameDecoderError::Decoder(err),
111        ))) => return Err(ConnectError::Decoder(err)),
112        #[cfg(feature = "websocket")]
113        Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
114            FrameDecoderError::IncompleteFrame,
115        ))) => todo!(),
116        #[cfg(feature = "websocket")]
117        Err(ConnectionReadError::Websocket(WebsocketReadError::Closed)) => todo!(),
118    };
119
120    let conn = match conn {
121        Connection::Streaming(streaming) => Connection::Streaming(
122            if matches!(
123                (addr.protocol(), info.tls_required),
124                (Protocol::PossiblyPlain, true)
125            ) {
126                let domain =
127                    rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?;
128                StreamingConnection::new(
129                    streaming
130                        .into_inner()
131                        .upgrade_tls(connector, domain.to_owned())
132                        .await
133                        .map_err(ConnectError::Io)?,
134                )
135            } else {
136                streaming
137            },
138        ),
139        Connection::Websocket(websocket) => Connection::Websocket(websocket),
140    };
141
142    let auth;
143    let auth_method = if let Some(auth_method) = auth_method {
144        Some(auth_method)
145    } else if let Some(auth_method) = AuthenticationMethod::try_from_addr(addr) {
146        auth = auth_method;
147        Some(&auth)
148    } else {
149        None
150    };
151
152    #[allow(unused_mut)]
153    let mut non_standard = NonStandardConnect::default();
154    #[cfg(feature = "non-standard-zstd")]
155    if matches!(conn, Connection::Streaming(_)) {
156        non_standard.zstd = flags.zstd_compression_level.is_some() && info.non_standard.zstd;
157    }
158
159    let mut connect = Connect {
160        verbose: true,
161        pedantic: false,
162        require_tls: false,
163        auth_token: None,
164        username: None,
165        password: None,
166        client_name: Some(client_name),
167        client_lang: "rust-watermelon",
168        client_version: env!("CARGO_PKG_VERSION"),
169        protocol: 1,
170        echo: flags.echo,
171        signature: None,
172        jwt: None,
173        supports_no_responders: true,
174        supports_headers: true,
175        nkey: None,
176        non_standard,
177    };
178    if let Some(auth_method) = auth_method {
179        auth_method
180            .prepare_for_auth(&info, &mut connect)
181            .map_err(ConnectError::Authentication)?;
182    }
183
184    let mut conn = match conn {
185        Connection::Streaming(streaming) => {
186            Connection::Streaming(streaming.replace_socket(|stream| {
187                MaybeConnection(Some(ConnectionCompression::Plain(stream)))
188            }))
189        }
190        Connection::Websocket(websocket) => Connection::Websocket(websocket),
191    };
192
193    #[cfg(feature = "non-standard-zstd")]
194    let zstd = connect.non_standard.zstd;
195
196    proto_connect(&mut conn, connect, |conn| {
197        #[cfg(feature = "non-standard-zstd")]
198        match conn {
199            Connection::Streaming(streaming) => {
200                if zstd {
201                    if let Some(zstd_compression_level) = flags.zstd_compression_level {
202                        let stream = streaming.socket_mut().0.take().unwrap();
203                        streaming.socket_mut().0 =
204                            Some(stream.upgrade_zstd(zstd_compression_level));
205                    }
206                }
207            }
208            Connection::Websocket(_websocket) => {}
209        }
210
211        let _ = conn;
212    })
213    .await
214    .map_err(ConnectError::Connect)?;
215
216    let conn = match conn {
217        Connection::Streaming(streaming) => {
218            Connection::Streaming(streaming.replace_socket(|stream| stream.0.unwrap()))
219        }
220        Connection::Websocket(websocket) => Connection::Websocket(websocket),
221    };
222
223    Ok((conn, info))
224}
225
226fn rustls_server_name_from_addr(addr: &ServerAddr) -> Result<ServerName<'_>, InvalidDnsNameError> {
227    match addr.host() {
228        Host::Ip(addr) => Ok(ServerName::IpAddress((*addr).into())),
229        Host::Dns(name) => <_ as AsRef<str>>::as_ref(name).try_into(),
230    }
231}