trojan_client/
connector.rs1use std::net::SocketAddr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use rustls::pki_types::ServerName;
8use tokio::net::TcpStream;
9use tokio_rustls::TlsConnector;
10use tokio_rustls::client::TlsStream;
11use tracing::debug;
12use trojan_config::TcpConfig;
13use trojan_dns::DnsResolver;
14
15use crate::error::ClientError;
16
17#[allow(missing_debug_implementations)]
19pub struct ClientState {
20 pub hash_hex: String,
22 pub remote_addr: String,
24 pub tls_connector: TlsConnector,
26 pub sni: ServerName<'static>,
28 pub tcp_config: TcpConfig,
30 pub tls_handshake_timeout: Duration,
32 pub dns_resolver: DnsResolver,
34}
35
36impl ClientState {
37 pub async fn connect(&self) -> Result<TlsStream<TcpStream>, ClientError> {
39 let addr: SocketAddr = self
41 .dns_resolver
42 .resolve(&self.remote_addr)
43 .await
44 .map_err(|_| ClientError::Resolve(self.remote_addr.clone()))?;
45
46 debug!(remote = %addr, "connecting to trojan server");
47
48 let tcp = TcpStream::connect(addr).await?;
50 apply_tcp_options(&tcp, &self.tcp_config)?;
51
52 let tls = tokio::time::timeout(
54 self.tls_handshake_timeout,
55 self.tls_connector.connect(self.sni.clone(), tcp),
56 )
57 .await
58 .map_err(|_| {
59 std::io::Error::new(std::io::ErrorKind::TimedOut, "TLS handshake timed out")
60 })??;
61
62 Ok(tls)
63 }
64}
65
66pub fn build_tls_config(
68 tls: &crate::config::ClientTlsConfig,
69) -> Result<rustls::ClientConfig, ClientError> {
70 let mut root_store = rustls::RootCertStore::empty();
71
72 if let Some(ca_path) = &tls.ca {
73 let ca_data = std::fs::read(ca_path)
74 .map_err(|e| ClientError::Config(format!("failed to read CA cert: {e}")))?;
75
76 let certs = rustls_pemfile::certs(&mut std::io::Cursor::new(&ca_data))
77 .collect::<Result<Vec<_>, _>>()
78 .map_err(|e| ClientError::Config(format!("failed to parse CA cert: {e}")))?;
79
80 for cert in certs {
81 root_store
82 .add(cert)
83 .map_err(|e| ClientError::Config(format!("failed to add CA cert: {e}")))?;
84 }
85 } else {
86 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
87 }
88
89 let mut config = if tls.skip_verify {
90 rustls::ClientConfig::builder()
91 .dangerous()
92 .with_custom_certificate_verifier(Arc::new(NoVerifier))
93 .with_no_client_auth()
94 } else {
95 rustls::ClientConfig::builder()
96 .with_root_certificates(root_store)
97 .with_no_client_auth()
98 };
99
100 config.alpn_protocols = tls.alpn.iter().map(|s| s.as_bytes().to_vec()).collect();
101
102 Ok(config)
103}
104
105pub fn resolve_sni(
107 tls: &crate::config::ClientTlsConfig,
108 remote: &str,
109) -> Result<ServerName<'static>, ClientError> {
110 let host = if let Some(sni) = &tls.sni {
111 sni.clone()
112 } else {
113 extract_host(remote)
114 };
115
116 ServerName::try_from(host)
117 .map_err(|e| ClientError::Config(format!("invalid SNI hostname: {e}")))
118}
119
/// Returns the host portion of an address string such as `host:port`.
///
/// - `"[addr]:port"` or `"[addr]"`: the text inside the first bracket pair
///   (`"[::1]:443"` becomes `"::1"`).
/// - Exactly one `:`: everything before it (`"example.com:443"` becomes `"example.com"`).
/// - Anything else (no colon, or several colons as in a bare IPv6 literal): the input,
///   unchanged.
fn extract_host(remote: &str) -> String {
    let bracketed = remote
        .strip_prefix('[')
        .and_then(|rest| rest.split_once(']'))
        .map(|(inner, _after)| inner);
    if let Some(inner) = bracketed {
        return inner.to_owned();
    }

    // A single colon means "host:port". With several colons the string is most likely
    // an unbracketed IPv6 address, whose port cannot be separated, so it is kept whole.
    match remote.split_once(':') {
        Some((host, tail)) if !tail.contains(':') => host.to_owned(),
        _ => remote.to_owned(),
    }
}
136
#[cfg(test)]
mod tests {
    use super::{extract_host, resolve_sni};
    use rustls::pki_types::ServerName;

    #[test]
    fn extract_host_parses_bracketed_ipv6() {
        assert_eq!(extract_host("[::1]:443"), "::1");
        assert_eq!(extract_host("[2001:db8::1]:8443"), "2001:db8::1");
        assert_eq!(extract_host("[::1]"), "::1");
    }

    #[test]
    fn extract_host_parses_hostname_and_port() {
        assert_eq!(extract_host("example.com:443"), "example.com");
        assert_eq!(extract_host("example.com"), "example.com");
    }

    #[test]
    fn extract_host_keeps_unbracketed_ipv6_whole() {
        // Several colons without brackets: no port can be split off, so the input is kept.
        assert_eq!(extract_host("2001:db8::1"), "2001:db8::1");
    }

    #[test]
    fn resolve_sni_accepts_ipv6_literal() {
        let mut tls = crate::config::ClientTlsConfig::default();
        // Force derivation from `remote` regardless of what the default config holds.
        tls.sni = None;
        let sni = resolve_sni(&tls, "[::1]:443")
            .expect("bracketed IPv6 literal should yield a valid server name");
        assert!(
            matches!(sni, ServerName::IpAddress(_)),
            "expected an IP server name, got {sni:?}"
        );
    }

    #[test]
    fn resolve_sni_derives_hostname_from_remote() {
        let mut tls = crate::config::ClientTlsConfig::default();
        tls.sni = None;
        let sni = resolve_sni(&tls, "example.com:443")
            .expect("hostname with port should yield a valid server name");
        assert!(
            matches!(sni, ServerName::DnsName(_)),
            "expected a DNS server name, got {sni:?}"
        );
    }

    #[test]
    fn resolve_sni_prefers_configured_sni() {
        let mut tls = crate::config::ClientTlsConfig::default();
        tls.sni = Some("example.org".to_string());
        // The remote is an IP literal, but the configured DNS name must win.
        let sni = resolve_sni(&tls, "[::1]:443").expect("configured SNI should be valid");
        assert!(
            matches!(sni, ServerName::DnsName(_)),
            "expected the configured DNS name, got {sni:?}"
        );
    }
}
160
161fn apply_tcp_options(stream: &TcpStream, config: &TcpConfig) -> Result<(), ClientError> {
163 stream.set_nodelay(config.no_delay)?;
164
165 if config.keepalive_secs > 0 {
166 let sock = socket2::SockRef::from(stream);
167 let keepalive =
168 socket2::TcpKeepalive::new().with_time(Duration::from_secs(config.keepalive_secs));
169 sock.set_tcp_keepalive(&keepalive)?;
170 }
171
172 Ok(())
173}
174
175#[derive(Debug)]
177struct NoVerifier;
178
179impl rustls::client::danger::ServerCertVerifier for NoVerifier {
180 fn verify_server_cert(
181 &self,
182 _end_entity: &rustls::pki_types::CertificateDer<'_>,
183 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
184 _server_name: &ServerName<'_>,
185 _ocsp_response: &[u8],
186 _now: rustls::pki_types::UnixTime,
187 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
188 Ok(rustls::client::danger::ServerCertVerified::assertion())
189 }
190
191 fn verify_tls12_signature(
192 &self,
193 _message: &[u8],
194 _cert: &rustls::pki_types::CertificateDer<'_>,
195 _dss: &rustls::DigitallySignedStruct,
196 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
197 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
198 }
199
200 fn verify_tls13_signature(
201 &self,
202 _message: &[u8],
203 _cert: &rustls::pki_types::CertificateDer<'_>,
204 _dss: &rustls::DigitallySignedStruct,
205 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
206 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
207 }
208
209 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
210 rustls::crypto::CryptoProvider::get_default()
211 .map(|provider| {
212 provider
213 .signature_verification_algorithms
214 .supported_schemes()
215 })
216 .unwrap_or_default()
217 }
218}