1use std::{
2 fmt::{self, Display, Formatter},
3 io::{Error, ErrorKind, Result},
4};
5
6use base64::engine::{Engine, general_purpose::STANDARD};
7use futures_util::StreamExt;
8use httparse::{EMPTY_HEADER, Request, Response, Status};
9use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
10use tokio_util::codec::{BytesCodec, FramedRead};
11
12use crate::{Address, AuthMethod};
13
/// Largest HTTP header, in bytes, that the request/response readers in this
/// module will buffer (start line plus all header fields).
const MAX_HEADER_SIZE: usize = 8 * 1024;
15
/// Status the proxy reports back to an HTTP `CONNECT` client.
///
/// Each discriminant is the numeric HTTP status code, so `reply as u16`
/// yields the code written on the status line.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum HttpReply {
    /// `200 OK`.
    Ok = 200,
    /// `403 Forbidden`.
    Forbidden = 403,
    /// `407 Proxy Authentication Required`.
    ProxyAuthenticationRequired = 407,
    /// `502 Bad Gateway`.
    BadGateway = 502,
    /// `504 Gateway Timeout`.
    GatewayTimeout = 504,
}
39
40pub async fn http_accept<T>(stream: &mut T, auth_method: &AuthMethod) -> Result<Address>
54where
55 T: AsyncRead + Unpin,
56{
57 read_client_request(stream, auth_method).await
58}
59
60pub async fn http_finalize_accept<T>(stream: &mut T, reply: &HttpReply) -> Result<()>
73where
74 T: AsyncWrite + Unpin,
75{
76 write_server_response(stream, reply).await
77}
78
79pub async fn http_connect<T>(
94 stream: &mut T,
95 address: &Address,
96 auth_method: &AuthMethod,
97) -> Result<()>
98where
99 T: AsyncRead + AsyncWrite + Unpin,
100{
101 write_client_request(stream, address, auth_method).await?;
102 read_server_response(stream).await?;
103 Ok(())
104}
105
106async fn read_client_request<T>(reader: &mut T, auth_method: &AuthMethod) -> Result<Address>
115where
116 T: AsyncRead + Unpin,
117{
118 let mut framed = FramedRead::new(reader, BytesCodec::new());
119 let mut buffer = Vec::with_capacity(MAX_HEADER_SIZE);
120
121 let (path, auth_header_value) = loop {
122 let mut headers = [EMPTY_HEADER; 32];
123 let mut req = Request::new(&mut headers);
124 match req.parse(&buffer).map_err(HttpError::ParseRequestFailed)? {
125 Status::Complete(_) => {
126 let method = req.method.ok_or(HttpError::MissingMethod)?;
127 let path = req.path.ok_or(HttpError::MissingTargetPath)?;
128 let auth_header_value = headers
129 .iter()
130 .find(|h| h.name.eq_ignore_ascii_case("proxy-authorization"))
131 .and_then(|h| std::str::from_utf8(h.value).ok())
132 .map(String::from);
133
134 if method != "CONNECT" {
136 return Err(HttpError::OnlyConnectSupported.into());
137 }
138
139 break (path, auth_header_value);
140 }
141 Status::Partial if buffer.len() >= MAX_HEADER_SIZE => {
142 return Err(HttpError::HeaderTooLong.into());
143 }
144 Status::Partial => match framed.next().await {
145 Some(Ok(bytes)) => buffer.extend_from_slice(&bytes),
146 Some(Err(e)) => return Err(e),
147 None => return Err(HttpError::ConnectionClosedHeaderIncomplete.into()),
148 },
149 }
150 };
151
152 if let AuthMethod::UserPass { username, password } = auth_method {
154 let Some(auth) = auth_header_value else {
155 return Err(HttpError::AuthenticationRequired.into());
156 };
157
158 if !auth.starts_with("Basic ") {
159 return Err(HttpError::OnlyBasicAuthSupported.into());
160 }
161
162 let base64_value = auth.trim_start_matches("Basic ").trim();
163 let decoded = STANDARD
164 .decode(base64_value)
165 .map_err(|_| HttpError::InvalidBase64Encoding)?;
166
167 let decoded_str = String::from_utf8_lossy(&decoded);
168 let creds: Vec<&str> = decoded_str.split(':').collect();
169
170 if creds.len() < 2 || creds[0] != username || creds[1] != password {
171 return Err(HttpError::InvalidCredentials.into());
172 }
173 }
174
175 Ok(Address::try_from(path)?)
177}
178
179async fn write_client_request<T>(
180 writer: &mut T,
181 address: &Address,
182 auth_method: &AuthMethod,
183) -> Result<()>
184where
185 T: AsyncWrite + Unpin,
186{
187 let target: String = address.into();
188
189 let mut request = format!("CONNECT {} HTTP/1.1\r\n", target);
191
192 request.push_str(&format!("Host: {}\r\n", target));
194
195 match auth_method {
197 AuthMethod::UserPass { username, password } => {
198 let credentials = format!("{}:{}", username, password);
199 let encoded = STANDARD.encode(credentials);
200 request.push_str(&format!("Proxy-Authorization: Basic {}\r\n", encoded));
201 }
202 AuthMethod::NoAuth => {} }
204
205 request.push_str("Connection: keep-alive\r\n\r\n");
207
208 writer.write_all(request.as_bytes()).await
210}
211
212async fn read_server_response<T>(reader: &mut T) -> Result<()>
229where
230 T: AsyncRead + Unpin,
231{
232 let mut framed = FramedRead::new(reader, BytesCodec::new());
233 let mut buffer = Vec::with_capacity(MAX_HEADER_SIZE);
234
235 loop {
236 let mut headers = [EMPTY_HEADER; 32];
237 let mut resp = Response::new(&mut headers);
238
239 match resp
240 .parse(&buffer)
241 .map_err(HttpError::ParseResponseFailed)?
242 {
243 Status::Complete(_) => {
244 let status_code = resp.code.ok_or(HttpError::MissingStatusCode)?;
245 let reason = String::from(resp.reason.unwrap_or("Unknown error"));
246
247 if status_code != 200 {
249 return Err(HttpError::HttpError(status_code, reason).into());
250 }
251 return Ok(());
252 }
253 Status::Partial if buffer.len() >= MAX_HEADER_SIZE => {
254 return Err(HttpError::HeaderTooLong.into());
255 }
256 Status::Partial => match framed.next().await {
257 Some(Ok(bytes)) => buffer.extend_from_slice(&bytes),
258 Some(Err(e)) => return Err(e),
259 None => return Err(HttpError::ConnectionClosedHeaderIncomplete.into()),
260 },
261 }
262 }
263}
264
265async fn write_server_response<T>(writer: &mut T, reply: &HttpReply) -> Result<()>
266where
267 T: AsyncWrite + Unpin,
268{
269 let status_code = *reply as u16;
271
272 let status_text = match reply {
274 HttpReply::Ok => "OK",
275 HttpReply::Forbidden => "Forbidden",
276 HttpReply::ProxyAuthenticationRequired => "Proxy Authentication Required",
277 HttpReply::BadGateway => "Bad Gateway",
278 HttpReply::GatewayTimeout => "Gateway Timeout",
279 };
280
281 let mut response = format!("HTTP/1.1 {} {}\r\n", status_code, status_text);
283
284 if *reply == HttpReply::ProxyAuthenticationRequired {
286 response.push_str("Proxy-Authenticate: Basic realm=\"Proxy\"\r\n");
287 }
288
289 response.push_str("Connection: keep-alive\r\n");
291 response.push_str("Content-Length: 0\r\n\r\n");
292
293 writer.write_all(response.as_bytes()).await
295}
296
/// Failure modes of the HTTP `CONNECT` handshake implemented in this module.
///
/// Values are converted into [`std::io::Error`] (with an [`ErrorKind`] chosen
/// per variant) through `From<HttpError> for Error`. Callers therefore see I/O
/// errors and can recover the variant with `get_ref()` followed by
/// `downcast_ref::<HttpError>()`.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum HttpError {
    /// `httparse` rejected the client's request header; carries the parser error.
    ParseRequestFailed(httparse::Error),
    /// `httparse` rejected the proxy's response header; carries the parser error.
    ParseResponseFailed(httparse::Error),
    /// The header hit the [`MAX_HEADER_SIZE`] byte limit.
    HeaderTooLong,
    /// The stream ended before a complete header was received.
    ConnectionClosedHeaderIncomplete,
    /// The parsed request had no method.
    MissingMethod,
    /// The request used a method other than `CONNECT`.
    OnlyConnectSupported,
    /// The `Proxy-Authorization` field did not use the Basic scheme.
    OnlyBasicAuthSupported,
    /// The Basic credentials were not valid base64.
    InvalidBase64Encoding,
    /// The decoded credentials did not match the configured username/password.
    InvalidCredentials,
    /// Credentials were required, but no usable (UTF-8) `Proxy-Authorization`
    /// field was present.
    AuthenticationRequired,
    /// The parsed request had no request-target.
    MissingTargetPath,
    /// The parsed response had no status code.
    MissingStatusCode,
    /// The proxy answered with a status treated as failure; carries the status
    /// code and reason phrase.
    HttpError(u16, String),
}
331
332impl Display for HttpError {
333 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
334 match self {
335 Self::ParseRequestFailed(reason) => {
336 write!(f, "Failed to parse HTTP request: {}", reason)
337 }
338 Self::ParseResponseFailed(reason) => {
339 write!(f, "Failed to parse HTTP response: {}", reason)
340 }
341 Self::HeaderTooLong => write!(f, "HTTP header exceeds maximum length"),
342 Self::ConnectionClosedHeaderIncomplete => {
343 write!(f, "Connection closed, HTTP header incomplete")
344 }
345 Self::MissingMethod => write!(f, "Missing HTTP method"),
346 Self::OnlyConnectSupported => write!(f, "Only CONNECT method is supported"),
347 Self::OnlyBasicAuthSupported => write!(f, "Only Basic authentication is supported"),
348 Self::InvalidBase64Encoding => write!(f, "Invalid Base64 encoding"),
349 Self::InvalidCredentials => write!(f, "Invalid credentials"),
350 Self::AuthenticationRequired => write!(f, "Authentication required"),
351 Self::MissingTargetPath => write!(f, "Missing target path"),
352 Self::MissingStatusCode => write!(f, "Missing status code"),
353 Self::HttpError(code, reason) => write!(f, "HTTP error: {} {}", code, reason),
354 }
355 }
356}
357
358impl std::error::Error for HttpError {}
359
360impl From<HttpError> for Error {
361 fn from(e: HttpError) -> Self {
362 match e {
363 HttpError::ParseRequestFailed(_) => Error::new(ErrorKind::InvalidData, e),
364 HttpError::ParseResponseFailed(_) => Error::new(ErrorKind::InvalidData, e),
365 HttpError::HeaderTooLong => Error::new(ErrorKind::InvalidData, e),
366 HttpError::ConnectionClosedHeaderIncomplete => Error::new(ErrorKind::UnexpectedEof, e),
367 HttpError::MissingMethod => Error::new(ErrorKind::InvalidData, e),
368 HttpError::OnlyConnectSupported => Error::new(ErrorKind::InvalidData, e),
369 HttpError::OnlyBasicAuthSupported => Error::new(ErrorKind::PermissionDenied, e),
370 HttpError::InvalidBase64Encoding => Error::new(ErrorKind::InvalidData, e),
371 HttpError::InvalidCredentials => Error::new(ErrorKind::PermissionDenied, e),
372 HttpError::AuthenticationRequired => Error::new(ErrorKind::PermissionDenied, e),
373 HttpError::MissingTargetPath => Error::new(ErrorKind::InvalidData, e),
374 HttpError::MissingStatusCode => Error::new(ErrorKind::InvalidData, e),
375 HttpError::HttpError(..) => Error::new(ErrorKind::Other, e),
376 }
377 }
378}
379
// Unit tests for the handshake helpers. `create_mock_stream` yields a pair of
// connected mock streams: bytes written on the first are read from the second.
#[cfg(test)]
mod test {
    use std::net::{Ipv4Addr, Ipv6Addr};

    use super::*;
    use crate::test_utils::create_mock_stream;

    // Round-trips every address/auth combination through
    // `write_client_request` -> `read_client_request` on the same stream pair.
    #[tokio::test]
    async fn test_client_request_write_read() {
        let (mut stream1, mut stream2) = create_mock_stream();
        let all_addresses = [
            Address::IPv4((Ipv4Addr::new(192, 168, 1, 1), 8080)),
            Address::DomainName(("example.com".to_string(), 443)),
            // NOTE(review): these segments give 20:1:d:b8::1; presumably
            // 2001:db8::1 was intended. Harmless for a round-trip check.
            Address::IPv6((
                Ipv6Addr::new(0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x01),
                8080,
            )),
        ];
        let all_auth_methods = [
            AuthMethod::UserPass {
                username: "user".to_string(),
                password: "pass".to_string(),
            },
            AuthMethod::NoAuth,
        ];
        for address in all_addresses.iter() {
            for auth_method in all_auth_methods.iter() {
                write_client_request(&mut stream1, address, auth_method)
                    .await
                    .unwrap();
                let received_addr = read_client_request(&mut stream2, auth_method)
                    .await
                    .unwrap();
                assert_eq!(address, &received_addr);
            }
        }
    }

    // A 200 reply produced by `write_server_response` is accepted by
    // `read_server_response`.
    #[tokio::test]
    async fn test_server_response_write_read() {
        let (mut stream1, mut stream2) = create_mock_stream();
        write_server_response(&mut stream1, &HttpReply::Ok)
            .await
            .unwrap();
        read_server_response(&mut stream2).await.unwrap();
    }

    // Request line without a method. Despite the test name, `httparse` rejects
    // this line itself, so the surfaced error is `ParseRequestFailed`, not
    // `MissingMethod`.
    #[tokio::test]
    async fn test_read_client_request_missing_method() {
        let test_input = b"HTTP/1.1\r\nHost: example.com:80\r\n\r\n";

        let (client, mut server) = create_mock_stream();
        client.write_immediate(test_input).unwrap();

        let result = read_client_request(&mut server, &AuthMethod::NoAuth).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert!(matches!(
            err.get_ref().unwrap().downcast_ref::<HttpError>().unwrap(),
            &HttpError::ParseRequestFailed(_)
        ));
    }

    // Request line without a target. As above, `httparse` rejects it first, so
    // the error is `ParseRequestFailed` rather than `MissingTargetPath`.
    #[tokio::test]
    async fn test_read_client_request_missing_path() {
        let test_input = b"CONNECT HTTP/1.1\r\nHost: example.com:80\r\n\r\n";

        let (client, mut server) = create_mock_stream();
        client.write_immediate(test_input).unwrap();

        let result = read_client_request(&mut server, &AuthMethod::NoAuth).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert!(matches!(
            err.get_ref().unwrap().downcast_ref::<HttpError>().unwrap(),
            &HttpError::ParseRequestFailed(_)
        ));
    }

    // A well-formed `GET` is rejected because only `CONNECT` is supported.
    #[tokio::test]
    async fn test_read_client_request_non_connect_method() {
        let test_input = b"GET example.com:80 HTTP/1.1\r\nHost: example.com:80\r\n\r\n";

        let (client, mut server) = create_mock_stream();
        client.write_immediate(test_input).unwrap();

        let result = read_client_request(&mut server, &AuthMethod::NoAuth).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            err.get_ref().unwrap().downcast_ref::<HttpError>().unwrap(),
            &HttpError::OnlyConnectSupported
        );
    }

    // A single header field of MAX_HEADER_SIZE bytes pushes the whole header
    // over the limit and must yield `HeaderTooLong`.
    #[tokio::test]
    async fn test_read_client_request_very_large_header() {
        let mut large_header = Vec::with_capacity(10000);
        large_header
            .extend_from_slice(b"CONNECT example.com:80 HTTP/1.1\r\nHost: example.com:80\r\n");
        large_header.extend_from_slice(b"X-Custom-Header: ");
        large_header.extend_from_slice(&vec![b'A'; MAX_HEADER_SIZE]);
        large_header.extend_from_slice(b"\r\n\r\n");

        let (client, mut server) = create_mock_stream();
        client.write_immediate(&large_header).unwrap();

        let result = read_client_request(&mut server, &AuthMethod::NoAuth).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            err.get_ref().unwrap().downcast_ref::<HttpError>().unwrap(),
            &HttpError::HeaderTooLong
        );
    }

    // The header lacks its terminating blank line and the writer shuts down,
    // so the reader hits EOF mid-header.
    #[tokio::test]
    async fn test_read_client_request_incomplete_header() {
        let test_input = b"CONNECT example.com:80 HTTP/1.1\r\nHost: example.com:80\r\n";

        let (mut client, mut server) = create_mock_stream();
        client.write_immediate(test_input).unwrap();

        client.shutdown().await.unwrap();

        let result = read_client_request(&mut server, &AuthMethod::NoAuth).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
        assert_eq!(
            err.get_ref().unwrap().downcast_ref::<HttpError>().unwrap(),
            &HttpError::ConnectionClosedHeaderIncomplete
        );
    }

    // Credentials are configured but the request has no Proxy-Authorization.
    #[tokio::test]
    async fn test_read_client_request_no_auth_but_required() {
        let test_input = b"CONNECT example.com:80 HTTP/1.1\r\nHost: example.com:80\r\n\r\n";

        let (client, mut server) = create_mock_stream();
        client.write_immediate(test_input).unwrap();

        let auth_method = AuthMethod::UserPass {
            username: "user".to_string(),
            password: "pass".to_string(),
        };

        let result = read_client_request(&mut server, &auth_method).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::PermissionDenied);
        assert_eq!(
            err.get_ref().unwrap().downcast_ref::<HttpError>().unwrap(),
            &HttpError::AuthenticationRequired
        );
    }

    // A Digest-scheme Proxy-Authorization is rejected as unsupported.
    #[tokio::test]
    async fn test_read_client_request_non_basic_auth() {
        let test_input = b"CONNECT example.com:80 HTTP/1.1\r\n\
                          Host: example.com:80\r\n\
                          Proxy-Authorization: Digest username=\"user\", realm=\"proxy\"\r\n\r\n";

        let (client, mut server) = create_mock_stream();
        client.write_immediate(test_input).unwrap();

        let auth_method = AuthMethod::UserPass {
            username: "user".to_string(),
            password: "pass".to_string(),
        };

        let result = read_client_request(&mut server, &auth_method).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::PermissionDenied);
        assert_eq!(
            err.get_ref().unwrap().downcast_ref::<HttpError>().unwrap(),
            &HttpError::OnlyBasicAuthSupported
        );
    }

    // Basic scheme with a token that is not valid base64.
    #[tokio::test]
    async fn test_read_client_request_invalid_base64() {
        let test_input = b"CONNECT example.com:80 HTTP/1.1\r\n\
                          Host: example.com:80\r\n\
                          Proxy-Authorization: Basic !@#$%^&*\r\n\r\n";

        let (client, mut server) = create_mock_stream();
        client.write_immediate(test_input).unwrap();

        let auth_method = AuthMethod::UserPass {
            username: "user".to_string(),
            password: "pass".to_string(),
        };

        let result = read_client_request(&mut server, &auth_method).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            err.get_ref().unwrap().downcast_ref::<HttpError>().unwrap(),
            &HttpError::InvalidBase64Encoding
        );
    }

    // Well-formed Basic credentials that do not match the configured pair.
    #[tokio::test]
    async fn test_read_client_request_invalid_credentials() {
        let encoded = STANDARD.encode("wrong:credentials");

        let request = format!(
            "CONNECT example.com:80 HTTP/1.1\r\n\
             Host: example.com:80\r\n\
             Proxy-Authorization: Basic {}\r\n\r\n",
            encoded
        );

        let (client, mut server) = create_mock_stream();
        client.write_immediate(request.as_bytes()).unwrap();

        let auth_method = AuthMethod::UserPass {
            username: "user".to_string(),
            password: "pass".to_string(),
        };

        let result = read_client_request(&mut server, &auth_method).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::PermissionDenied);
        assert_eq!(
            err.get_ref().unwrap().downcast_ref::<HttpError>().unwrap(),
            &HttpError::InvalidCredentials
        );
    }

    // Status line without a code. `httparse` rejects it, so the error is
    // `ParseResponseFailed` rather than `MissingStatusCode`.
    #[tokio::test]
    async fn test_read_server_response_missing_status_code() {
        let test_input = b"HTTP/1.1 OK\r\nContent-Length: 0\r\n\r\n";

        let (client, mut server) = create_mock_stream();
        client.write_immediate(test_input).unwrap();

        let result = read_server_response(&mut server).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert!(matches!(
            err.get_ref().unwrap().downcast_ref::<HttpError>().unwrap(),
            &HttpError::ParseResponseFailed(_)
        ));
    }

    // A 403 reply surfaces as `HttpError::HttpError` with code and reason intact.
    #[tokio::test]
    async fn test_read_server_response_http_error() {
        let test_input = b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n";

        let (client, mut server) = create_mock_stream();
        client.write_immediate(test_input).unwrap();

        let result = read_server_response(&mut server).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::Other);

        if let Some(http_err) = err.get_ref().unwrap().downcast_ref::<HttpError>() {
            match http_err {
                HttpError::HttpError(code, reason) => {
                    assert_eq!(*code, 403);
                    assert_eq!(reason, "Forbidden");
                }
                _ => panic!("Expected HttpError::HttpError variant"),
            }
        } else {
            panic!("Expected HttpError");
        }
    }

    // A response header over MAX_HEADER_SIZE bytes must yield `HeaderTooLong`.
    #[tokio::test]
    async fn test_read_server_header_too_large() {
        let mut large_header = Vec::with_capacity(10000);
        large_header.extend_from_slice(b"HTTP/1.1 200 OK\r\n");
        large_header.extend_from_slice(b"X-Custom-Header: ");
        large_header.extend_from_slice(&vec![b'A'; MAX_HEADER_SIZE]);
        large_header.extend_from_slice(b"\r\n\r\n");

        let (client, mut server) = create_mock_stream();
        client.write_immediate(&large_header).unwrap();

        let result = read_server_response(&mut server).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            err.get_ref().unwrap().downcast_ref::<HttpError>().unwrap(),
            &HttpError::HeaderTooLong
        );
    }

    // The response header lacks its terminating blank line before the writer
    // shuts down.
    #[tokio::test]
    async fn test_read_server_response_incomplete_header() {
        let test_input = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n";

        let (mut client, mut server) = create_mock_stream();
        client.write_immediate(test_input).unwrap();

        client.shutdown().await.unwrap();

        let result = read_server_response(&mut server).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
        assert_eq!(
            err.get_ref().unwrap().downcast_ref::<HttpError>().unwrap(),
            &HttpError::ConnectionClosedHeaderIncomplete
        );
    }
}