trojan_client/
connector.rs1use std::net::SocketAddr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use rustls::pki_types::ServerName;
8use tokio::net::TcpStream;
9use tokio_rustls::TlsConnector;
10use tokio_rustls::client::TlsStream;
11use tracing::debug;
12use trojan_config::TcpConfig;
13
14use crate::error::ClientError;
15
16pub struct ClientState {
18 pub hash_hex: String,
20 pub remote_addr: String,
22 pub tls_connector: TlsConnector,
24 pub sni: ServerName<'static>,
26 pub tcp_config: TcpConfig,
28 pub tls_handshake_timeout: Duration,
30}
31
32impl ClientState {
33 pub async fn connect(&self) -> Result<TlsStream<TcpStream>, ClientError> {
35 let addr: SocketAddr = tokio::net::lookup_host(&self.remote_addr)
37 .await?
38 .next()
39 .ok_or_else(|| ClientError::Resolve(self.remote_addr.clone()))?;
40
41 debug!(remote = %addr, "connecting to trojan server");
42
43 let tcp = TcpStream::connect(addr).await?;
45 apply_tcp_options(&tcp, &self.tcp_config)?;
46
47 let tls = tokio::time::timeout(
49 self.tls_handshake_timeout,
50 self.tls_connector.connect(self.sni.clone(), tcp),
51 )
52 .await
53 .map_err(|_| {
54 std::io::Error::new(std::io::ErrorKind::TimedOut, "TLS handshake timed out")
55 })??;
56
57 Ok(tls)
58 }
59}
60
61pub fn build_tls_config(
63 tls: &crate::config::ClientTlsConfig,
64) -> Result<rustls::ClientConfig, ClientError> {
65 let mut root_store = rustls::RootCertStore::empty();
66
67 if let Some(ca_path) = &tls.ca {
68 let ca_data = std::fs::read(ca_path)
69 .map_err(|e| ClientError::Config(format!("failed to read CA cert: {e}")))?;
70
71 let certs = rustls_pemfile::certs(&mut std::io::Cursor::new(&ca_data))
72 .collect::<Result<Vec<_>, _>>()
73 .map_err(|e| ClientError::Config(format!("failed to parse CA cert: {e}")))?;
74
75 for cert in certs {
76 root_store
77 .add(cert)
78 .map_err(|e| ClientError::Config(format!("failed to add CA cert: {e}")))?;
79 }
80 } else {
81 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
82 }
83
84 let mut config = if tls.skip_verify {
85 rustls::ClientConfig::builder()
86 .dangerous()
87 .with_custom_certificate_verifier(Arc::new(NoVerifier))
88 .with_no_client_auth()
89 } else {
90 rustls::ClientConfig::builder()
91 .with_root_certificates(root_store)
92 .with_no_client_auth()
93 };
94
95 config.alpn_protocols = tls.alpn.iter().map(|s| s.as_bytes().to_vec()).collect();
96
97 Ok(config)
98}
99
100pub fn resolve_sni(
102 tls: &crate::config::ClientTlsConfig,
103 remote: &str,
104) -> Result<ServerName<'static>, ClientError> {
105 let host = if let Some(sni) = &tls.sni {
106 sni.clone()
107 } else {
108 extract_host(remote)
109 };
110
111 ServerName::try_from(host)
112 .map_err(|e| ClientError::Config(format!("invalid SNI hostname: {e}")))
113}
114
/// Returns the host portion of a `host:port`-style address string.
///
/// - `[addr]...` (bracketed IPv6): the text between the brackets.
/// - Exactly one `:`: everything before it (the port is dropped).
/// - Otherwise (no colon, or a bare IPv6 literal with several colons): the
///   input unchanged.
fn extract_host(remote: &str) -> String {
    let bracketed = remote
        .strip_prefix('[')
        .and_then(|rest| rest.split_once(']'))
        .map(|(inner, _)| inner);
    if let Some(inner) = bracketed {
        return inner.to_string();
    }

    // "Exactly one colon" == there is a first colon and nothing after it has another.
    match remote.split_once(':') {
        Some((host, port)) if !port.contains(':') => host.to_string(),
        _ => remote.to_string(),
    }
}
131
#[cfg(test)]
mod tests {
    use super::{extract_host, resolve_sni};
    use rustls::pki_types::ServerName;

    #[test]
    fn extract_host_parses_bracketed_ipv6() {
        assert_eq!(extract_host("[::1]:443"), "::1");
        assert_eq!(extract_host("[2001:db8::1]:8443"), "2001:db8::1");
        // Brackets without a port still yield the bare address.
        assert_eq!(extract_host("[::1]"), "::1");
    }

    #[test]
    fn extract_host_parses_hostname_and_port() {
        assert_eq!(extract_host("example.com:443"), "example.com");
        assert_eq!(extract_host("example.com"), "example.com");
    }

    #[test]
    fn extract_host_leaves_bare_ipv6_untouched() {
        // Without brackets there is no way to tell a port apart, so the input is kept.
        assert_eq!(extract_host("::1"), "::1");
        assert_eq!(extract_host("2001:db8::1"), "2001:db8::1");
    }

    #[test]
    fn resolve_sni_accepts_ipv6_literal() {
        let tls = crate::config::ClientTlsConfig::default();
        let sni = resolve_sni(&tls, "[::1]:443");
        assert!(sni.is_ok());
    }

    #[test]
    fn resolve_sni_uses_hostname_from_remote() {
        let tls = crate::config::ClientTlsConfig::default();
        let sni = resolve_sni(&tls, "example.com:443").expect("valid hostname");
        assert!(matches!(sni, ServerName::DnsName(_)));
    }

    #[test]
    fn resolve_sni_prefers_explicit_sni() {
        let mut tls = crate::config::ClientTlsConfig::default();
        tls.sni = Some("example.com".to_string());
        // The remote is an IP literal; a DNS name proves the explicit SNI was used.
        let sni = resolve_sni(&tls, "[::1]:443").expect("valid explicit SNI");
        assert!(matches!(sni, ServerName::DnsName(_)));
    }

    #[test]
    fn resolve_sni_rejects_invalid_hostname() {
        let tls = crate::config::ClientTlsConfig::default();
        assert!(resolve_sni(&tls, "not a host!:443").is_err());
    }
}
155
156fn apply_tcp_options(stream: &TcpStream, config: &TcpConfig) -> Result<(), ClientError> {
158 stream.set_nodelay(config.no_delay)?;
159
160 if config.keepalive_secs > 0 {
161 let sock = socket2::SockRef::from(stream);
162 let keepalive =
163 socket2::TcpKeepalive::new().with_time(Duration::from_secs(config.keepalive_secs));
164 sock.set_tcp_keepalive(&keepalive)?;
165 }
166
167 Ok(())
168}
169
170#[derive(Debug)]
172struct NoVerifier;
173
174impl rustls::client::danger::ServerCertVerifier for NoVerifier {
175 fn verify_server_cert(
176 &self,
177 _end_entity: &rustls::pki_types::CertificateDer<'_>,
178 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
179 _server_name: &ServerName<'_>,
180 _ocsp_response: &[u8],
181 _now: rustls::pki_types::UnixTime,
182 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
183 Ok(rustls::client::danger::ServerCertVerified::assertion())
184 }
185
186 fn verify_tls12_signature(
187 &self,
188 _message: &[u8],
189 _cert: &rustls::pki_types::CertificateDer<'_>,
190 _dss: &rustls::DigitallySignedStruct,
191 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
192 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
193 }
194
195 fn verify_tls13_signature(
196 &self,
197 _message: &[u8],
198 _cert: &rustls::pki_types::CertificateDer<'_>,
199 _dss: &rustls::DigitallySignedStruct,
200 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
201 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
202 }
203
204 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
205 rustls::crypto::CryptoProvider::get_default()
206 .map(|provider| provider.signature_verification_algorithms.supported_schemes())
207 .unwrap_or_default()
208 }
209}