1#[cfg(test)]
7use tokio_rustls::rustls::RootCertStore;
8use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
9use zcash_client_backend::proto::service::compact_tx_streamer_client::CompactTxStreamerClient;
10
11#[derive(Debug, thiserror::Error)]
12pub enum GetClientError {
13 #[error("bad uri: invalid scheme")]
14 InvalidScheme,
15
16 #[error("bad uri: invalid authority")]
17 InvalidAuthority,
18
19 #[error("bad uri: invalid path and/or query")]
20 InvalidPathAndQuery,
21
22 #[error(transparent)]
23 Transport(#[from] tonic::transport::Error),
24}
25
/// Reads the PEM-encoded test certificate from `test-data/localhost.pem`.
///
/// Returns `None` if the file cannot be read (for example, when it is absent),
/// so callers can fall back to their default behaviour.
#[cfg(test)]
fn load_test_cert_pem() -> Option<Vec<u8>> {
    const TEST_PEMFILE_PATH: &str = "test-data/localhost.pem";

    match std::fs::read(TEST_PEMFILE_PATH) {
        Ok(bytes) => Some(bytes),
        Err(_) => None,
    }
}
31fn client_tls_config() -> Result<ClientTlsConfig, GetClientError> {
32 #[cfg(test)]
34 {
35 if let Some(pem) = load_test_cert_pem() {
36 return Ok(
37 ClientTlsConfig::new().ca_certificate(tonic::transport::Certificate::from_pem(pem))
38 );
39 }
40 }
41
42 Ok(ClientTlsConfig::new())
43}
44pub async fn get_client(
52 uri: http::Uri,
53) -> Result<CompactTxStreamerClient<Channel>, GetClientError> {
54 let scheme = uri.scheme_str().ok_or(GetClientError::InvalidScheme)?;
55 if scheme != "http" && scheme != "https" {
56 return Err(GetClientError::InvalidScheme);
57 }
58 let _authority = uri.authority().ok_or(GetClientError::InvalidAuthority)?;
59
60 let endpoint = Endpoint::from_shared(uri.to_string())?.tcp_nodelay(true);
61
62 let channel = if scheme == "https" {
63 let tls = client_tls_config()?;
64 endpoint.tls_config(tls)?.connect().await?
65 } else {
66 endpoint.connect().await?
67 };
68
69 Ok(CompactTxStreamerClient::new(channel))
70}
71
/// Adds the certificate(s) in `test-data/localhost.pem` to `roots` as trust anchors.
///
/// Best-effort: if the file cannot be read, a message is printed to stderr and
/// `roots` is left untouched, so callers see a TLS verification failure later
/// rather than a panic here.
#[cfg(test)]
fn add_test_cert_to_roots(roots: &mut RootCertStore) {
    const TEST_PEMFILE_PATH: &str = "test-data/localhost.pem";

    eprintln!("Adding test cert to roots");

    let Ok(pem) = std::fs::read(TEST_PEMFILE_PATH) else {
        eprintln!("Test TLS cert not found at {TEST_PEMFILE_PATH}, skipping");
        return;
    };

    // PEM sections that fail to decode are skipped. The decoded certificates
    // are handed to rustls directly (no intermediate Vec). rustls reports how
    // many it accepted versus rejected, and that result is surfaced rather
    // than ignored.
    let mut reader = pem.as_slice();
    let certs = rustls_pemfile::certs(&mut reader).filter_map(Result::ok);
    let (added, rejected) = roots.add_parsable_certificates(certs);

    // An empty or partially rejected store would otherwise only show up as an
    // opaque handshake error in the calling test.
    if added == 0 || rejected > 0 {
        eprintln!(
            "Test TLS cert at {TEST_PEMFILE_PATH}: {added} root(s) added, {rejected} rejected"
        );
    }
}
92
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use http::{Request, Response};
    use hyper::{
        body::{Bytes, Incoming},
        service::service_fn,
    };
    use hyper_util::rt::TokioIo;
    use tokio::{net::TcpListener, sync::oneshot, time::timeout};
    use tokio_rustls::{TlsAcceptor, rustls};

    use super::*;

    /// Reads `test-data/localhost.pem` and returns every certificate that
    /// decodes cleanly from it (malformed PEM sections are skipped).
    ///
    /// Panics if the file cannot be read.
    fn read_localhost_certs() -> Vec<rustls::pki_types::CertificateDer<'static>> {
        let pem =
            std::fs::read("test-data/localhost.pem").expect("missing test-data/localhost.pem");
        let mut reader = std::io::BufReader::new(pem.as_slice());
        rustls_pemfile::certs(&mut reader)
            .filter_map(Result::ok)
            .collect()
    }

    /// The bundled test certificate must exist, contain at least one
    /// certificate, and every certificate must parse as X.509.
    #[test]
    fn localhost_cert_file_exists_and_is_parseable() {
        const CERT_PATH: &str = "test-data/localhost.pem";

        let certs = read_localhost_certs();
        assert!(!certs.is_empty(), "no certs found in {CERT_PATH}");

        for cert in &certs {
            let parsed = x509_parser::parse_x509_certificate(cert.as_ref());
            assert!(
                parsed.is_ok(),
                "failed to parse a cert from {CERT_PATH} as X.509"
            );
        }
    }

    /// The leaf certificate must declare `CA:FALSE` in its basic constraints.
    /// Presumably this is because webpki-style verifiers reject a CA
    /// certificate presented as an end-entity certificate.
    #[test]
    fn localhost_cert_is_end_entity_not_ca() {
        let certs = read_localhost_certs();
        let leaf = certs.first().expect("no certs found in localhost.pem");

        let (_, x509) =
            x509_parser::parse_x509_certificate(leaf.as_ref()).expect("failed to parse X.509");

        // `basic_constraints()` returns `Err` when the extension is duplicated
        // and `Ok(None)` when it is absent, so each case gets its own message.
        let constraints = x509
            .basic_constraints()
            .expect("invalid (duplicated) basic constraints extension")
            .expect("missing basic constraints extension");

        assert!(!constraints.value.ca, "localhost.pem must be CA:FALSE");
    }

    /// Builds a rustls server config with these properties:
    /// - it presents the certificates from `test-data/localhost.pem`;
    /// - it uses the private key from `test-data/localhost.key`;
    /// - it does not require client authentication.
    ///
    /// A rustls crypto provider must already be installed as the process default.
    fn load_test_server_config() -> std::sync::Arc<rustls::ServerConfig> {
        let certs = read_localhost_certs();

        let key_pem =
            std::fs::read("test-data/localhost.key").expect("missing test-data/localhost.key");
        let mut key_reader = std::io::BufReader::new(key_pem.as_slice());
        let key = rustls_pemfile::private_key(&mut key_reader)
            .expect("failed to read private key")
            .expect("no private key found");

        let config = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(certs, key)
            .expect("bad cert or key");

        std::sync::Arc::new(config)
    }

    /// End-to-end check that `add_test_cert_to_roots` lets a rustls client
    /// complete a TLS handshake and an HTTP/1.1 request against a local server
    /// that presents the test certificate.
    #[tokio::test]
    async fn add_test_cert_to_roots_enables_tls_handshake() {
        use http_body_util::Full;

        let _ = rustls::crypto::ring::default_provider().install_default();

        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind failed");
        let addr = listener.local_addr().expect("local_addr failed");

        let acceptor = TlsAcceptor::from(load_test_server_config());

        // The listener is already bound, so incoming connections queue
        // regardless. This signal only confirms that the server task has
        // started running.
        let (ready_tx, ready_rx) = oneshot::channel::<()>();

        let server_task = tokio::spawn(async move {
            let _ = ready_tx.send(());

            let (socket, _) = timeout(Duration::from_secs(3), listener.accept())
                .await
                .expect("server accept timed out")
                .expect("accept failed");

            let tls_stream = timeout(Duration::from_secs(3), acceptor.accept(socket))
                .await
                .expect("tls accept timed out")
                .expect("tls accept failed");

            let io = TokioIo::new(tls_stream);

            let svc = service_fn(|mut req: Request<Incoming>| async move {
                use http_body_util::BodyExt;

                // Drain the request body before answering.
                while let Some(frame) = req.body_mut().frame().await {
                    if frame.is_err() {
                        break;
                    }
                }

                let mut resp = Response::new(Full::new(Bytes::from_static(b"ok")));
                resp.headers_mut().insert(
                    http::header::CONNECTION,
                    http::HeaderValue::from_static("close"),
                );
                Ok::<_, hyper::Error>(resp)
            });

            timeout(
                Duration::from_secs(3),
                hyper::server::conn::http1::Builder::new()
                    .keep_alive(false)
                    .serve_connection(io, svc),
            )
            .await
            .expect("serve_connection timed out")
            .expect("serve_connection failed");
        });

        timeout(Duration::from_secs(1), ready_rx)
            .await
            .expect("server ready signal timed out")
            .expect("server dropped before ready");

        let mut roots = rustls::RootCertStore::empty();
        add_test_cert_to_roots(&mut roots);

        let client_config = rustls::ClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth();

        let https = hyper_rustls::HttpsConnectorBuilder::new()
            .with_tls_config(client_config)
            .https_only()
            .enable_http1()
            .build();

        let client =
            hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
                .build(https);

        let uri: http::Uri = format!("https://127.0.0.1:{}/", addr.port())
            .parse()
            .expect("bad uri");

        let req = Request::builder()
            .method("GET")
            .uri(uri)
            .body(Full::<Bytes>::new(Bytes::new()))
            .expect("request build failed");

        let res = timeout(Duration::from_secs(3), client.request(req))
            .await
            .expect("client request timed out")
            .expect("TLS handshake or request failed");

        assert!(res.status().is_success());

        timeout(Duration::from_secs(3), server_task)
            .await
            .expect("server task timed out")
            .expect("server task failed");
    }

    /// `get_client` must refuse a URI whose scheme is neither `http` nor
    /// `https`, reporting `InvalidScheme` before any connection is attempted.
    #[tokio::test]
    async fn rejects_non_http_schemes() {
        let uri: http::Uri = "ftp://example.com:1234".parse().unwrap();
        let res = get_client(uri).await;

        assert!(
            matches!(res, Err(GetClientError::InvalidScheme)),
            "expected get_client() to reject non-http(s) schemes with InvalidScheme"
        );
    }

    /// Connecting with tonic to a TLS server that only speaks HTTP/1.1 must
    /// fail. The client must not silently downgrade from HTTP/2.
    #[tokio::test]
    async fn https_connector_must_not_downgrade_to_http1() {
        use http_body_util::Full;

        let _ = rustls::crypto::ring::default_provider().install_default();

        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind failed");
        let addr = listener.local_addr().expect("local_addr failed");

        let acceptor = TlsAcceptor::from(load_test_server_config());

        let server_task = tokio::spawn(async move {
            let (socket, _) = timeout(Duration::from_secs(3), listener.accept())
                .await
                .expect("server accept timed out")
                .expect("accept failed");

            let tls_stream = acceptor.accept(socket).await.expect("tls accept failed");
            let io = TokioIo::new(tls_stream);

            let svc = service_fn(|_req: Request<Incoming>| async move {
                Ok::<_, hyper::Error>(Response::new(Full::new(Bytes::from_static(b"ok"))))
            });

            // HTTP/1.1-only server. Its outcome is irrelevant, because the
            // client is expected to give up.
            let _ = hyper::server::conn::http1::Builder::new()
                .serve_connection(io, svc)
                .await;
        });

        let uri: http::Uri = format!("https://127.0.0.1:{}", addr.port())
            .parse()
            .expect("bad base uri");

        let endpoint = tonic::transport::Endpoint::from_shared(uri.to_string())
            .expect("endpoint")
            .tcp_nodelay(true);

        let tls = client_tls_config().expect("tls config");
        let connecting = endpoint
            .tls_config(tls)
            .expect("tls_config failed")
            .connect();

        // Bound the attempt so a misbehaving handshake fails the test instead
        // of hanging it.
        let connect_res = timeout(Duration::from_secs(5), connecting)
            .await
            .expect("connect attempt timed out");

        assert!(
            connect_res.is_err(),
            "expected connect to fail (no downgrade to HTTP/1.1), but it succeeded"
        );

        server_task.abort();
    }
}