Skip to main content

trojan_client/
connector.rs

1//! TLS connection establishment to the remote trojan server.
2
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use rustls::pki_types::ServerName;
8use tokio::net::TcpStream;
9use tokio_rustls::TlsConnector;
10use tokio_rustls::client::TlsStream;
11use tracing::debug;
12use trojan_config::TcpConfig;
13use trojan_dns::DnsResolver;
14
15use crate::error::ClientError;
16
/// Shared client state for establishing outbound connections.
///
/// Holds everything [`ClientState::connect`] needs to reach the remote trojan
/// server (address, DNS resolver, socket options, TLS connector, SNI and
/// handshake timeout), and additionally carries the password hash.
// NOTE(review): `Debug` is deliberately not implemented. Presumably this is
// because some field types (e.g. `TlsConnector`, `DnsResolver`) do not
// implement `Debug`, and/or to keep `hash_hex` out of logs — confirm before
// adding a `Debug` impl.
#[allow(missing_debug_implementations)]
pub struct ClientState {
    /// SHA-224 hex hash of the password (56 hex characters).
    pub hash_hex: String,
    /// Remote trojan server address string (host:port); resolved through
    /// `dns_resolver` by `connect`.
    pub remote_addr: String,
    /// TLS connector used to run the handshake in `connect`.
    pub tls_connector: TlsConnector,
    /// TLS SNI server name passed to the connector during the handshake.
    pub sni: ServerName<'static>,
    /// TCP socket options applied to each new connection before the handshake.
    pub tcp_config: TcpConfig,
    /// Upper bound on how long the TLS handshake in `connect` may take.
    pub tls_handshake_timeout: Duration,
    /// DNS resolver used to turn `remote_addr` into a socket address.
    pub dns_resolver: DnsResolver,
}
35
36impl ClientState {
37    /// Establish a TLS connection to the remote trojan server.
38    pub async fn connect(&self) -> Result<TlsStream<TcpStream>, ClientError> {
39        // DNS resolve
40        let addr: SocketAddr = self
41            .dns_resolver
42            .resolve(&self.remote_addr)
43            .await
44            .map_err(|_| ClientError::Resolve(self.remote_addr.clone()))?;
45
46        debug!(remote = %addr, "connecting to trojan server");
47
48        // TCP connect
49        let tcp = TcpStream::connect(addr).await?;
50        apply_tcp_options(&tcp, &self.tcp_config)?;
51
52        // TLS handshake with timeout
53        let tls = tokio::time::timeout(
54            self.tls_handshake_timeout,
55            self.tls_connector.connect(self.sni.clone(), tcp),
56        )
57        .await
58        .map_err(|_| {
59            std::io::Error::new(std::io::ErrorKind::TimedOut, "TLS handshake timed out")
60        })??;
61
62        Ok(tls)
63    }
64}
65
66/// Build TLS client config from client TLS settings.
67pub fn build_tls_config(
68    tls: &crate::config::ClientTlsConfig,
69) -> Result<rustls::ClientConfig, ClientError> {
70    let mut root_store = rustls::RootCertStore::empty();
71
72    if let Some(ca_path) = &tls.ca {
73        let ca_data = std::fs::read(ca_path)
74            .map_err(|e| ClientError::Config(format!("failed to read CA cert: {e}")))?;
75
76        let certs = rustls_pemfile::certs(&mut std::io::Cursor::new(&ca_data))
77            .collect::<Result<Vec<_>, _>>()
78            .map_err(|e| ClientError::Config(format!("failed to parse CA cert: {e}")))?;
79
80        for cert in certs {
81            root_store
82                .add(cert)
83                .map_err(|e| ClientError::Config(format!("failed to add CA cert: {e}")))?;
84        }
85    } else {
86        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
87    }
88
89    let mut config = if tls.skip_verify {
90        rustls::ClientConfig::builder()
91            .dangerous()
92            .with_custom_certificate_verifier(Arc::new(NoVerifier))
93            .with_no_client_auth()
94    } else {
95        rustls::ClientConfig::builder()
96            .with_root_certificates(root_store)
97            .with_no_client_auth()
98    };
99
100    config.alpn_protocols = tls.alpn.iter().map(|s| s.as_bytes().to_vec()).collect();
101
102    Ok(config)
103}
104
105/// Extract the SNI hostname from config or the remote address.
106pub fn resolve_sni(
107    tls: &crate::config::ClientTlsConfig,
108    remote: &str,
109) -> Result<ServerName<'static>, ClientError> {
110    let host = if let Some(sni) = &tls.sni {
111        sni.clone()
112    } else {
113        extract_host(remote)
114    };
115
116    ServerName::try_from(host)
117        .map_err(|e| ClientError::Config(format!("invalid SNI hostname: {e}")))
118}
119
/// Return the host portion of a `host:port`-style address string.
///
/// - Input that starts with `[` and contains a `]` (bracketed IPv6, with or
///   without a port) yields the text between the brackets.
/// - Input with exactly one `:` yields everything before it.
/// - Anything else (no `:`, or an unbracketed IPv6 literal with several) is
///   returned unchanged.
fn extract_host(remote: &str) -> String {
    if let Some(inside) = remote
        .strip_prefix('[')
        .and_then(|rest| rest.find(']').map(|close| &rest[..close]))
    {
        return inside.to_owned();
    }

    let host = match remote.matches(':').count() {
        1 => remote.split_once(':').map_or(remote, |(head, _port)| head),
        _ => remote,
    };
    host.to_owned()
}
136
#[cfg(test)]
mod tests {
    use rustls::pki_types::ServerName;

    use super::{extract_host, resolve_sni};

    #[test]
    fn extract_host_parses_bracketed_ipv6() {
        assert_eq!(extract_host("[::1]:443"), "::1");
        assert_eq!(extract_host("[2001:db8::1]:8443"), "2001:db8::1");
        assert_eq!(extract_host("[::1]"), "::1");
    }

    #[test]
    fn extract_host_parses_hostname_and_port() {
        assert_eq!(extract_host("example.com:443"), "example.com");
        assert_eq!(extract_host("example.com"), "example.com");
    }

    #[test]
    fn extract_host_keeps_unbracketed_ipv6_intact() {
        // More than one ':' is ambiguous, so no port is stripped.
        assert_eq!(extract_host("::1"), "::1");
        assert_eq!(extract_host("2001:db8::1"), "2001:db8::1");
    }

    #[test]
    fn resolve_sni_accepts_ipv6_literal() {
        // Pin `sni` to None so the remote address is actually used,
        // regardless of what `Default` provides.
        let mut tls = crate::config::ClientTlsConfig::default();
        tls.sni = None;
        let sni = resolve_sni(&tls, "[::1]:443").expect("IPv6 literal should be a valid SNI");
        assert!(matches!(sni, ServerName::IpAddress(_)));
    }

    #[test]
    fn resolve_sni_uses_host_from_remote() {
        let mut tls = crate::config::ClientTlsConfig::default();
        tls.sni = None;
        let sni = resolve_sni(&tls, "example.com:443").expect("hostname should be a valid SNI");
        assert_eq!(
            sni,
            ServerName::try_from("example.com").expect("valid DNS name")
        );
    }

    #[test]
    fn resolve_sni_prefers_explicit_sni() {
        let mut tls = crate::config::ClientTlsConfig::default();
        tls.sni = Some("sni.example.org".to_string());
        let sni = resolve_sni(&tls, "203.0.113.7:443").expect("explicit SNI should be valid");
        assert_eq!(
            sni,
            ServerName::try_from("sni.example.org").expect("valid DNS name")
        );
    }
}
160
161/// Apply TCP socket options.
162fn apply_tcp_options(stream: &TcpStream, config: &TcpConfig) -> Result<(), ClientError> {
163    stream.set_nodelay(config.no_delay)?;
164
165    if config.keepalive_secs > 0 {
166        let sock = socket2::SockRef::from(stream);
167        let keepalive =
168            socket2::TcpKeepalive::new().with_time(Duration::from_secs(config.keepalive_secs));
169        sock.set_tcp_keepalive(&keepalive)?;
170    }
171
172    Ok(())
173}
174
175/// Certificate verifier that accepts any certificate (for skip_verify mode).
176#[derive(Debug)]
177struct NoVerifier;
178
179impl rustls::client::danger::ServerCertVerifier for NoVerifier {
180    fn verify_server_cert(
181        &self,
182        _end_entity: &rustls::pki_types::CertificateDer<'_>,
183        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
184        _server_name: &ServerName<'_>,
185        _ocsp_response: &[u8],
186        _now: rustls::pki_types::UnixTime,
187    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
188        Ok(rustls::client::danger::ServerCertVerified::assertion())
189    }
190
191    fn verify_tls12_signature(
192        &self,
193        _message: &[u8],
194        _cert: &rustls::pki_types::CertificateDer<'_>,
195        _dss: &rustls::DigitallySignedStruct,
196    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
197        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
198    }
199
200    fn verify_tls13_signature(
201        &self,
202        _message: &[u8],
203        _cert: &rustls::pki_types::CertificateDer<'_>,
204        _dss: &rustls::DigitallySignedStruct,
205    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
206        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
207    }
208
209    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
210        rustls::crypto::CryptoProvider::get_default()
211            .map(|provider| {
212                provider
213                    .signature_verification_algorithms
214                    .supported_schemes()
215            })
216            .unwrap_or_default()
217    }
218}