watermelon_mini/proto/
connector.rs

1use std::io;
2
3use tokio::net::TcpStream;
4use tokio_rustls::{
5    rustls::pki_types::{InvalidDnsNameError, ServerName},
6    TlsConnector,
7};
8use watermelon_net::{
9    connect_tcp,
10    error::{ConnectionReadError, StreamingReadError},
11    proto_connect, Connection, StreamingConnection,
12};
13#[cfg(feature = "websocket")]
14use watermelon_net::{error::WebsocketReadError, WebsocketConnection};
15#[cfg(feature = "websocket")]
16use watermelon_proto::proto::error::FrameDecoderError;
17use watermelon_proto::{
18    proto::{error::DecoderError, ServerOp},
19    Connect, Host, NonStandardConnect, Protocol, ServerAddr, ServerInfo, Transport,
20};
21
22use crate::{util::MaybeConnection, ConnectFlags, ConnectionCompression};
23
24use super::{
25    authenticator::{AuthenticationError, AuthenticationMethod},
26    connection::ConnectionSecurity,
27};
28
29#[derive(Debug, thiserror::Error)]
30pub enum ConnectError {
31    #[error("io error")]
32    Io(#[source] io::Error),
33    #[error("invalid DNS name")]
34    InvalidDnsName(#[source] InvalidDnsNameError),
35    #[error("websocket not supported")]
36    WebsocketUnsupported,
37    #[error("unexpected ServerOp")]
38    UnexpectedServerOp,
39    #[error("decoder error")]
40    Decoder(#[source] DecoderError),
41    #[error("authentication error")]
42    Authentication(#[source] AuthenticationError),
43    #[error("connect")]
44    Connect(#[source] watermelon_net::error::ConnectError),
45}
46
47#[expect(clippy::too_many_lines)]
48pub(crate) async fn connect(
49    connector: &TlsConnector,
50    addr: &ServerAddr,
51    client_name: String,
52    auth_method: Option<&AuthenticationMethod>,
53    flags: ConnectFlags,
54) -> Result<
55    (
56        Connection<
57            ConnectionCompression<ConnectionSecurity<TcpStream>>,
58            ConnectionSecurity<TcpStream>,
59        >,
60        Box<ServerInfo>,
61    ),
62    ConnectError,
63> {
64    let conn = connect_tcp(addr).await.map_err(ConnectError::Io)?;
65    conn.set_nodelay(true).map_err(ConnectError::Io)?;
66    let mut conn = ConnectionSecurity::Plain(conn);
67
68    if matches!(addr.protocol(), Protocol::TLS) {
69        let domain = rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?;
70        conn = conn
71            .upgrade_tls(connector, domain.to_owned())
72            .await
73            .map_err(ConnectError::Io)?;
74    }
75
76    let mut conn = match addr.transport() {
77        Transport::TCP => Connection::Streaming(StreamingConnection::new(conn)),
78        #[cfg(feature = "websocket")]
79        Transport::Websocket => {
80            let uri = addr.to_string().parse().unwrap();
81            Connection::Websocket(
82                WebsocketConnection::new(uri, conn)
83                    .await
84                    .map_err(ConnectError::Io)?,
85            )
86        }
87        #[cfg(not(feature = "websocket"))]
88        Transport::Websocket => return Err(ConnectError::WebsocketUnsupported),
89    };
90    let info = match conn.read_next().await {
91        Ok(ServerOp::Info { info }) => info,
92        Ok(_) => return Err(ConnectError::UnexpectedServerOp),
93        Err(ConnectionReadError::Streaming(StreamingReadError::Io(err))) => {
94            return Err(ConnectError::Io(err))
95        }
96        Err(ConnectionReadError::Streaming(StreamingReadError::Decoder(err))) => {
97            return Err(ConnectError::Decoder(err))
98        }
99        #[cfg(feature = "websocket")]
100        Err(ConnectionReadError::Websocket(WebsocketReadError::Io(err))) => {
101            return Err(ConnectError::Io(err))
102        }
103        #[cfg(feature = "websocket")]
104        Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
105            FrameDecoderError::Decoder(err),
106        ))) => return Err(ConnectError::Decoder(err)),
107        #[cfg(feature = "websocket")]
108        Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
109            FrameDecoderError::IncompleteFrame,
110        ))) => todo!(),
111        #[cfg(feature = "websocket")]
112        Err(ConnectionReadError::Websocket(WebsocketReadError::Closed)) => todo!(),
113    };
114
115    let conn = match conn {
116        Connection::Streaming(streaming) => Connection::Streaming(
117            if matches!(
118                (addr.protocol(), info.tls_required),
119                (Protocol::PossiblyPlain, true)
120            ) {
121                let domain =
122                    rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?;
123                StreamingConnection::new(
124                    streaming
125                        .into_inner()
126                        .upgrade_tls(connector, domain.to_owned())
127                        .await
128                        .map_err(ConnectError::Io)?,
129                )
130            } else {
131                streaming
132            },
133        ),
134        Connection::Websocket(websocket) => Connection::Websocket(websocket),
135    };
136
137    let auth;
138    let auth_method = if let Some(auth_method) = auth_method {
139        Some(auth_method)
140    } else if let Some(auth_method) = AuthenticationMethod::try_from_addr(addr) {
141        auth = auth_method;
142        Some(&auth)
143    } else {
144        None
145    };
146
147    #[allow(unused_mut)]
148    let mut non_standard = NonStandardConnect::default();
149    #[cfg(feature = "non-standard-zstd")]
150    if matches!(conn, Connection::Streaming(_)) {
151        non_standard.zstd = flags.zstd_compression_level.is_some() && info.non_standard.zstd;
152    }
153
154    let mut connect = Connect {
155        verbose: true,
156        pedantic: false,
157        require_tls: false,
158        auth_token: None,
159        username: None,
160        password: None,
161        client_name: Some(client_name),
162        client_lang: "rust-watermelon",
163        client_version: env!("CARGO_PKG_VERSION"),
164        protocol: 1,
165        echo: flags.echo,
166        signature: None,
167        jwt: None,
168        supports_no_responders: true,
169        supports_headers: true,
170        nkey: None,
171        non_standard,
172    };
173    if let Some(auth_method) = auth_method {
174        auth_method
175            .prepare_for_auth(&info, &mut connect)
176            .map_err(ConnectError::Authentication)?;
177    }
178
179    let mut conn = match conn {
180        Connection::Streaming(streaming) => {
181            Connection::Streaming(streaming.replace_socket(|stream| {
182                MaybeConnection(Some(ConnectionCompression::Plain(stream)))
183            }))
184        }
185        Connection::Websocket(websocket) => Connection::Websocket(websocket),
186    };
187
188    #[cfg(feature = "non-standard-zstd")]
189    let zstd = connect.non_standard.zstd;
190
191    proto_connect(&mut conn, connect, |conn| {
192        #[cfg(feature = "non-standard-zstd")]
193        match conn {
194            Connection::Streaming(streaming) => {
195                if zstd {
196                    if let Some(zstd_compression_level) = flags.zstd_compression_level {
197                        let stream = streaming.socket_mut().0.take().unwrap();
198                        streaming.socket_mut().0 =
199                            Some(stream.upgrade_zstd(zstd_compression_level));
200                    }
201                }
202            }
203            Connection::Websocket(_websocket) => {}
204        }
205
206        let _ = conn;
207    })
208    .await
209    .map_err(ConnectError::Connect)?;
210
211    let conn = match conn {
212        Connection::Streaming(streaming) => {
213            Connection::Streaming(streaming.replace_socket(|stream| stream.0.unwrap()))
214        }
215        Connection::Websocket(websocket) => Connection::Websocket(websocket),
216    };
217
218    Ok((conn, info))
219}
220
221fn rustls_server_name_from_addr(addr: &ServerAddr) -> Result<ServerName<'_>, InvalidDnsNameError> {
222    match addr.host() {
223        Host::Ip(addr) => Ok(ServerName::IpAddress((*addr).into())),
224        Host::Dns(name) => <_ as AsRef<str>>::as_ref(name).try_into(),
225    }
226}