// socks_http_kit/http.rs

use std::{
    fmt::{self, Display, Formatter},
    io::{Error, ErrorKind, Result},
};

use base64::engine::{Engine, general_purpose::STANDARD};
use futures_util::StreamExt;
use httparse::{EMPTY_HEADER, Request, Response, Status};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio_util::codec::{BytesCodec, FramedRead};

use crate::{Address, AuthMethod};
13
/// Number of bytes (8 KiB) after which an HTTP request/response head that is still incomplete
/// is rejected with `HttpError::HeaderTooLong`.
const MAX_HEADER_SIZE: usize = 8 * 1024;
15
/// HTTP proxy response status codes as defined in RFC 7231.
///
/// Each variant's discriminant is its numeric HTTP status code, so `reply as u16` yields the
/// value written on the status line.
///
/// Reference: <https://datatracker.ietf.org/doc/html/rfc7231#section-6>
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum HttpReply {
    /// `200 OK`: the connection to the target was established successfully.
    Ok = 200,
    /// `403 Forbidden`: the proxy denies access to the requested target.
    Forbidden = 403,
    /// `407 Proxy Authentication Required`: the client must supply valid proxy credentials.
    ProxyAuthenticationRequired = 407,
    /// `502 Bad Gateway`: the proxy failed to connect to the target.
    BadGateway = 502,
    /// `504 Gateway Timeout`: connecting to the target timed out.
    GatewayTimeout = 504,
}
39
40/// Accepts an HTTP proxy connection request from a client.
41///
42/// This function reads and processes an HTTP CONNECT request from the client,
43/// validates authentication if required, and extracts the target address.
44///
45/// # Arguments
46/// * `stream` - A mutable reference to an asynchronous stream implementing `AsyncRead` + `Unpin`.
47/// * `auth_method` - The authentication method required for this connection.
48///
49/// # Returns
50/// * `Result<Address>` - The parsed target address on success, or an error if the request
51///   is invalid, authentication fails, or the connection cannot be established.
52///
53pub async fn http_accept<T>(stream: &mut T, auth_method: &AuthMethod) -> Result<Address>
54where
55    T: AsyncRead + Unpin,
56{
57    read_client_request(stream, auth_method).await
58}
59
60/// Completes an HTTP proxy connection by sending a response to the client.
61///
62/// After processing a client's HTTP CONNECT request with `http_accept`, this function
63/// sends the appropriate HTTP response to indicate success or failure.
64///
65/// # Arguments
66/// * `stream` - A mutable reference to an asynchronous stream.
67/// * `reply` - The HTTP reply status to send to the client.
68///
69/// # Returns
70/// * `Result<()>` - Success if the response is sent, or an IO error if writing fails.
71///
72pub async fn http_finalize_accept<T>(stream: &mut T, reply: &HttpReply) -> Result<()>
73where
74    T: AsyncWrite + Unpin,
75{
76    write_server_response(stream, reply).await
77}
78
79/// Establishes an HTTP proxy connection to a target server.
80///
81/// This function sends an HTTP CONNECT request to a proxy server with the specified
82/// target address and authentication credentials, then verifies the response.
83///
84/// # Arguments
85/// * `stream` - A mutable reference to an asynchronous stream implementing `AsyncRead` + `AsyncWrite` + `Unpin`.
86/// * `address` - The target address to connect to.
87/// * `auth_method` - The authentication method to use for this connection.
88///
89/// # Returns
90/// * `Result<()>` - Success if the connection is established, or an error if the request
91///   fails, authentication is rejected, or the server returns a non-200 status code.
92///
93pub async fn http_connect<T>(
94    stream: &mut T,
95    address: &Address,
96    auth_method: &AuthMethod,
97) -> Result<()>
98where
99    T: AsyncRead + AsyncWrite + Unpin,
100{
101    write_client_request(stream, address, auth_method).await?;
102    read_server_response(stream).await?;
103    Ok(())
104}
105
106/// The HTTP CONNECT request format complies with RFC 7231:
107///
108/// ```text
109/// CONNECT example.com:80 HTTP/1.1
110/// Host: example.com:80
111/// [Proxy-Authorization: Basic base64(username:password)]
112/// Connection: keep-alive
113/// ```
114async fn read_client_request<T>(reader: &mut T, auth_method: &AuthMethod) -> Result<Address>
115where
116    T: AsyncRead + Unpin,
117{
118    let mut framed = FramedRead::new(reader, BytesCodec::new());
119    let mut buffer = Vec::with_capacity(MAX_HEADER_SIZE);
120
121    let (path, auth_header_value) = loop {
122        let mut headers = [EMPTY_HEADER; 32];
123        let mut req = Request::new(&mut headers);
124        match req.parse(&buffer).map_err(HttpError::ParseRequestFailed)? {
125            Status::Complete(_) => {
126                let method = req.method.ok_or(HttpError::MissingMethod)?;
127                let path = req.path.ok_or(HttpError::MissingTargetPath)?;
128                let auth_header_value = headers
129                    .iter()
130                    .find(|h| h.name.eq_ignore_ascii_case("proxy-authorization"))
131                    .and_then(|h| std::str::from_utf8(h.value).ok())
132                    .map(String::from);
133
134                // Check and process extracted information
135                if method != "CONNECT" {
136                    return Err(HttpError::OnlyConnectSupported.into());
137                }
138
139                break (path, auth_header_value);
140            }
141            Status::Partial if buffer.len() >= MAX_HEADER_SIZE => {
142                return Err(HttpError::HeaderTooLong.into());
143            }
144            Status::Partial => match framed.next().await {
145                Some(Ok(bytes)) => buffer.extend_from_slice(&bytes),
146                Some(Err(e)) => return Err(e),
147                None => return Err(HttpError::ConnectionClosedHeaderIncomplete.into()),
148            },
149        }
150    };
151
152    // Verify authentication
153    if let AuthMethod::UserPass { username, password } = auth_method {
154        let Some(auth) = auth_header_value else {
155            return Err(HttpError::AuthenticationRequired.into());
156        };
157
158        if !auth.starts_with("Basic ") {
159            return Err(HttpError::OnlyBasicAuthSupported.into());
160        }
161
162        let base64_value = auth.trim_start_matches("Basic ").trim();
163        let decoded = STANDARD
164            .decode(base64_value)
165            .map_err(|_| HttpError::InvalidBase64Encoding)?;
166
167        let decoded_str = String::from_utf8_lossy(&decoded);
168        let creds: Vec<&str> = decoded_str.split(':').collect();
169
170        if creds.len() < 2 || creds[0] != username || creds[1] != password {
171            return Err(HttpError::InvalidCredentials.into());
172        }
173    }
174
175    // Parse the target address
176    Ok(Address::try_from(path)?)
177}
178
179async fn write_client_request<T>(
180    writer: &mut T,
181    address: &Address,
182    auth_method: &AuthMethod,
183) -> Result<()>
184where
185    T: AsyncWrite + Unpin,
186{
187    let target: String = address.into();
188
189    // Construct CONNECT request
190    let mut request = format!("CONNECT {} HTTP/1.1\r\n", target);
191
192    // Add Host header
193    request.push_str(&format!("Host: {}\r\n", target));
194
195    // Add authentication header (if required)
196    match auth_method {
197        AuthMethod::UserPass { username, password } => {
198            let credentials = format!("{}:{}", username, password);
199            let encoded = STANDARD.encode(credentials);
200            request.push_str(&format!("Proxy-Authorization: Basic {}\r\n", encoded));
201        }
202        AuthMethod::NoAuth => {} // No authentication required
203    }
204
205    // Add Connection header and end headers
206    request.push_str("Connection: keep-alive\r\n\r\n");
207
208    // Write the request
209    writer.write_all(request.as_bytes()).await
210}
211
212/// The HTTP response format complies with RFC 7231:
213///
214/// Successful response:
215/// ```text
216/// HTTP/1.1 200 OK
217/// Connection: keep-alive
218/// Content-Length: 0
219/// ```
220///
221/// Authentication required response:
222/// ```text
223/// HTTP/1.1 407 Proxy Authentication Required
224/// Proxy-Authenticate: Basic realm="Proxy"
225/// Connection: keep-alive
226/// Content-Length: 0
227/// ```
228async fn read_server_response<T>(reader: &mut T) -> Result<()>
229where
230    T: AsyncRead + Unpin,
231{
232    let mut framed = FramedRead::new(reader, BytesCodec::new());
233    let mut buffer = Vec::with_capacity(MAX_HEADER_SIZE);
234
235    loop {
236        let mut headers = [EMPTY_HEADER; 32];
237        let mut resp = Response::new(&mut headers);
238
239        match resp
240            .parse(&buffer)
241            .map_err(HttpError::ParseResponseFailed)?
242        {
243            Status::Complete(_) => {
244                let status_code = resp.code.ok_or(HttpError::MissingStatusCode)?;
245                let reason = String::from(resp.reason.unwrap_or("Unknown error"));
246
247                // Determine if status code indicates success
248                if status_code != 200 {
249                    return Err(HttpError::HttpError(status_code, reason).into());
250                }
251                return Ok(());
252            }
253            Status::Partial if buffer.len() >= MAX_HEADER_SIZE => {
254                return Err(HttpError::HeaderTooLong.into());
255            }
256            Status::Partial => match framed.next().await {
257                Some(Ok(bytes)) => buffer.extend_from_slice(&bytes),
258                Some(Err(e)) => return Err(e),
259                None => return Err(HttpError::ConnectionClosedHeaderIncomplete.into()),
260            },
261        }
262    }
263}
264
265async fn write_server_response<T>(writer: &mut T, reply: &HttpReply) -> Result<()>
266where
267    T: AsyncWrite + Unpin,
268{
269    // Get status code
270    let status_code = *reply as u16;
271
272    // Get status text
273    let status_text = match reply {
274        HttpReply::Ok => "OK",
275        HttpReply::Forbidden => "Forbidden",
276        HttpReply::ProxyAuthenticationRequired => "Proxy Authentication Required",
277        HttpReply::BadGateway => "Bad Gateway",
278        HttpReply::GatewayTimeout => "Gateway Timeout",
279    };
280
281    // Construct response
282    let mut response = format!("HTTP/1.1 {} {}\r\n", status_code, status_text);
283
284    // Add appropriate headers based on status code
285    if *reply == HttpReply::ProxyAuthenticationRequired {
286        response.push_str("Proxy-Authenticate: Basic realm=\"Proxy\"\r\n");
287    }
288
289    // Add standard headers
290    response.push_str("Connection: keep-alive\r\n");
291    response.push_str("Content-Length: 0\r\n\r\n");
292
293    // Write the response
294    writer.write_all(response.as_bytes()).await
295}
296
297/// Errors that can occur during HTTP proxy protocol operations.
298///
299/// Each variant represents a specific error condition that may arise when implementing
300/// or using the HTTP proxy protocol, particularly with the CONNECT method.
301#[derive(Clone, Debug, Eq, PartialEq)]
302#[non_exhaustive]
303pub enum HttpError {
304    /// Failed to parse the HTTP request due to a specific httparse error.
305    ParseRequestFailed(httparse::Error),
306    /// Failed to parse the HTTP response due to a specific httparse error.
307    ParseResponseFailed(httparse::Error),
308    /// HTTP header section exceeds the maximum buffer size.
309    HeaderTooLong,
310    /// Connection was closed before the complete HTTP header was received.
311    ConnectionClosedHeaderIncomplete,
312    /// HTTP request is missing the method field.
313    MissingMethod,
314    /// HTTP proxy implementation only supports the CONNECT method.
315    OnlyConnectSupported,
316    /// HTTP proxy authentication only supports the Basic scheme.
317    OnlyBasicAuthSupported,
318    /// Provided authorization header contains invalid Base64 encoding.
319    InvalidBase64Encoding,
320    /// Provided username/password combination is incorrect.
321    InvalidCredentials,
322    /// Proxy requires authentication but no credentials were provided.
323    AuthenticationRequired,
324    /// HTTP CONNECT request is missing the target host:port path.
325    MissingTargetPath,
326    /// HTTP response is missing the status code.
327    MissingStatusCode,
328    /// Server returned an HTTP error with specific status code and reason.
329    HttpError(u16, String),
330}
331
332impl Display for HttpError {
333    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
334        match self {
335            Self::ParseRequestFailed(reason) => {
336                write!(f, "Failed to parse HTTP request: {}", reason)
337            }
338            Self::ParseResponseFailed(reason) => {
339                write!(f, "Failed to parse HTTP response: {}", reason)
340            }
341            Self::HeaderTooLong => write!(f, "HTTP header exceeds maximum length"),
342            Self::ConnectionClosedHeaderIncomplete => {
343                write!(f, "Connection closed, HTTP header incomplete")
344            }
345            Self::MissingMethod => write!(f, "Missing HTTP method"),
346            Self::OnlyConnectSupported => write!(f, "Only CONNECT method is supported"),
347            Self::OnlyBasicAuthSupported => write!(f, "Only Basic authentication is supported"),
348            Self::InvalidBase64Encoding => write!(f, "Invalid Base64 encoding"),
349            Self::InvalidCredentials => write!(f, "Invalid credentials"),
350            Self::AuthenticationRequired => write!(f, "Authentication required"),
351            Self::MissingTargetPath => write!(f, "Missing target path"),
352            Self::MissingStatusCode => write!(f, "Missing status code"),
353            Self::HttpError(code, reason) => write!(f, "HTTP error: {} {}", code, reason),
354        }
355    }
356}
357
358impl std::error::Error for HttpError {}
359
360impl From<HttpError> for Error {
361    fn from(e: HttpError) -> Self {
362        match e {
363            HttpError::ParseRequestFailed(_) => Error::new(ErrorKind::InvalidData, e),
364            HttpError::ParseResponseFailed(_) => Error::new(ErrorKind::InvalidData, e),
365            HttpError::HeaderTooLong => Error::new(ErrorKind::InvalidData, e),
366            HttpError::ConnectionClosedHeaderIncomplete => Error::new(ErrorKind::UnexpectedEof, e),
367            HttpError::MissingMethod => Error::new(ErrorKind::InvalidData, e),
368            HttpError::OnlyConnectSupported => Error::new(ErrorKind::InvalidData, e),
369            HttpError::OnlyBasicAuthSupported => Error::new(ErrorKind::PermissionDenied, e),
370            HttpError::InvalidBase64Encoding => Error::new(ErrorKind::InvalidData, e),
371            HttpError::InvalidCredentials => Error::new(ErrorKind::PermissionDenied, e),
372            HttpError::AuthenticationRequired => Error::new(ErrorKind::PermissionDenied, e),
373            HttpError::MissingTargetPath => Error::new(ErrorKind::InvalidData, e),
374            HttpError::MissingStatusCode => Error::new(ErrorKind::InvalidData, e),
375            HttpError::HttpError(..) => Error::new(ErrorKind::Other, e),
376        }
377    }
378}
379
#[cfg(test)]
mod test {
    use std::net::{Ipv4Addr, Ipv6Addr};

    use super::*;
    use crate::test_utils::create_mock_stream;

    /// Credentials used by every test that needs a password-protected proxy.
    fn user_pass_auth() -> AuthMethod {
        AuthMethod::UserPass {
            username: "user".to_string(),
            password: "pass".to_string(),
        }
    }

    /// Asserts that `err` has `expected_kind` and returns the `HttpError` it wraps.
    fn http_error_in(err: &Error, expected_kind: ErrorKind) -> &HttpError {
        assert_eq!(err.kind(), expected_kind);
        err.get_ref()
            .and_then(|inner| inner.downcast_ref::<HttpError>())
            .expect("Expected HttpError")
    }

    /// Pushes `raw` into a fresh mock connection and parses it as a client request.
    async fn read_raw_request(raw: &[u8], auth_method: &AuthMethod) -> Result<Address> {
        let (client, mut server) = create_mock_stream();
        client.write_immediate(raw).unwrap();
        read_client_request(&mut server, auth_method).await
    }

    /// Pushes `raw` into a fresh mock connection and parses it as a server response.
    async fn read_raw_response(raw: &[u8]) -> Result<()> {
        let (client, mut server) = create_mock_stream();
        client.write_immediate(raw).unwrap();
        read_server_response(&mut server).await
    }

    #[tokio::test]
    async fn test_client_request_write_read() {
        let (mut stream1, mut stream2) = create_mock_stream();
        let addresses = [
            Address::IPv4((Ipv4Addr::new(192, 168, 1, 1), 8080)),
            Address::DomainName(("example.com".to_string(), 443)),
            Address::IPv6((
                Ipv6Addr::new(0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x01),
                8080,
            )),
        ];
        let auth_methods = [user_pass_auth(), AuthMethod::NoAuth];

        // Every address goes over the same stream, once with and once without credentials.
        for address in &addresses {
            for auth_method in &auth_methods {
                write_client_request(&mut stream1, address, auth_method)
                    .await
                    .unwrap();
                let received = read_client_request(&mut stream2, auth_method)
                    .await
                    .unwrap();
                assert_eq!(&received, address);
            }
        }
    }

    #[tokio::test]
    async fn test_server_response_write_read() {
        let (mut stream1, mut stream2) = create_mock_stream();
        write_server_response(&mut stream1, &HttpReply::Ok)
            .await
            .unwrap();
        read_server_response(&mut stream2).await.unwrap();
    }

    #[tokio::test]
    async fn test_read_client_request_missing_method() {
        // The request line starts with the version instead of a method.
        let err = read_raw_request(
            b"HTTP/1.1\r\nHost: example.com:80\r\n\r\n",
            &AuthMethod::NoAuth,
        )
        .await
        .unwrap_err();
        assert!(matches!(
            http_error_in(&err, ErrorKind::InvalidData),
            HttpError::ParseRequestFailed(_)
        ));
    }

    #[tokio::test]
    async fn test_read_client_request_missing_path() {
        // CONNECT is followed directly by the version, with no host:port target.
        let err = read_raw_request(
            b"CONNECT HTTP/1.1\r\nHost: example.com:80\r\n\r\n",
            &AuthMethod::NoAuth,
        )
        .await
        .unwrap_err();
        assert!(matches!(
            http_error_in(&err, ErrorKind::InvalidData),
            HttpError::ParseRequestFailed(_)
        ));
    }

    #[tokio::test]
    async fn test_read_client_request_non_connect_method() {
        let err = read_raw_request(
            b"GET example.com:80 HTTP/1.1\r\nHost: example.com:80\r\n\r\n",
            &AuthMethod::NoAuth,
        )
        .await
        .unwrap_err();
        assert_eq!(
            http_error_in(&err, ErrorKind::InvalidData),
            &HttpError::OnlyConnectSupported
        );
    }

    #[tokio::test]
    async fn test_read_client_request_very_large_header() {
        // A single header value alone is MAX_HEADER_SIZE bytes long.
        let pieces: [&[u8]; 4] = [
            b"CONNECT example.com:80 HTTP/1.1\r\nHost: example.com:80\r\n",
            b"X-Custom-Header: ",
            &[b'A'; MAX_HEADER_SIZE],
            b"\r\n\r\n",
        ];
        let oversized = pieces.concat();

        let err = read_raw_request(&oversized, &AuthMethod::NoAuth)
            .await
            .unwrap_err();
        assert_eq!(
            http_error_in(&err, ErrorKind::InvalidData),
            &HttpError::HeaderTooLong
        );
    }

    #[tokio::test]
    async fn test_read_client_request_incomplete_header() {
        let (mut client, mut server) = create_mock_stream();
        client
            .write_immediate(b"CONNECT example.com:80 HTTP/1.1\r\nHost: example.com:80\r\n")
            .unwrap();
        // Closing the client side means the terminating blank line never arrives.
        client.shutdown().await.unwrap();

        let err = read_client_request(&mut server, &AuthMethod::NoAuth)
            .await
            .unwrap_err();
        assert_eq!(
            http_error_in(&err, ErrorKind::UnexpectedEof),
            &HttpError::ConnectionClosedHeaderIncomplete
        );
    }

    #[tokio::test]
    async fn test_read_client_request_no_auth_but_required() {
        // A well-formed request that simply omits Proxy-Authorization.
        let err = read_raw_request(
            b"CONNECT example.com:80 HTTP/1.1\r\nHost: example.com:80\r\n\r\n",
            &user_pass_auth(),
        )
        .await
        .unwrap_err();
        assert_eq!(
            http_error_in(&err, ErrorKind::PermissionDenied),
            &HttpError::AuthenticationRequired
        );
    }

    #[tokio::test]
    async fn test_read_client_request_non_basic_auth() {
        let raw = b"CONNECT example.com:80 HTTP/1.1\r\n\
            Host: example.com:80\r\n\
            Proxy-Authorization: Digest username=\"user\", realm=\"proxy\"\r\n\r\n";

        let err = read_raw_request(raw, &user_pass_auth())
            .await
            .unwrap_err();
        assert_eq!(
            http_error_in(&err, ErrorKind::PermissionDenied),
            &HttpError::OnlyBasicAuthSupported
        );
    }

    #[tokio::test]
    async fn test_read_client_request_invalid_base64() {
        let raw = b"CONNECT example.com:80 HTTP/1.1\r\n\
            Host: example.com:80\r\n\
            Proxy-Authorization: Basic !@#$%^&*\r\n\r\n";

        let err = read_raw_request(raw, &user_pass_auth())
            .await
            .unwrap_err();
        assert_eq!(
            http_error_in(&err, ErrorKind::InvalidData),
            &HttpError::InvalidBase64Encoding
        );
    }

    #[tokio::test]
    async fn test_read_client_request_invalid_credentials() {
        // Valid Base64, but for credentials the proxy does not accept.
        let raw = format!(
            "CONNECT example.com:80 HTTP/1.1\r\n\
             Host: example.com:80\r\n\
             Proxy-Authorization: Basic {}\r\n\r\n",
            STANDARD.encode("wrong:credentials")
        );

        let err = read_raw_request(raw.as_bytes(), &user_pass_auth())
            .await
            .unwrap_err();
        assert_eq!(
            http_error_in(&err, ErrorKind::PermissionDenied),
            &HttpError::InvalidCredentials
        );
    }

    #[tokio::test]
    async fn test_read_server_response_missing_status_code() {
        let err = read_raw_response(b"HTTP/1.1 OK\r\nContent-Length: 0\r\n\r\n")
            .await
            .unwrap_err();
        assert!(matches!(
            http_error_in(&err, ErrorKind::InvalidData),
            HttpError::ParseResponseFailed(_)
        ));
    }

    #[tokio::test]
    async fn test_read_server_response_http_error() {
        let err = read_raw_response(b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n")
            .await
            .unwrap_err();

        // The error must carry both the status code and the server's reason phrase.
        match http_error_in(&err, ErrorKind::Other) {
            HttpError::HttpError(code, reason) => {
                assert_eq!(*code, 403);
                assert_eq!(reason, "Forbidden");
            }
            _ => panic!("Expected HttpError::HttpError variant"),
        }
    }

    #[tokio::test]
    async fn test_read_server_header_too_large() {
        let pieces: [&[u8]; 4] = [
            b"HTTP/1.1 200 OK\r\n",
            b"X-Custom-Header: ",
            &[b'A'; MAX_HEADER_SIZE],
            b"\r\n\r\n",
        ];
        let oversized = pieces.concat();

        let err = read_raw_response(&oversized).await.unwrap_err();
        assert_eq!(
            http_error_in(&err, ErrorKind::InvalidData),
            &HttpError::HeaderTooLong
        );
    }

    #[tokio::test]
    async fn test_read_server_response_incomplete_header() {
        let (mut client, mut server) = create_mock_stream();
        client
            .write_immediate(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n")
            .unwrap();
        // Closing the sending side means the terminating blank line never arrives.
        client.shutdown().await.unwrap();

        let err = read_server_response(&mut server).await.unwrap_err();
        assert_eq!(
            http_error_in(&err, ErrorKind::UnexpectedEof),
            &HttpError::ConnectionClosedHeaderIncomplete
        );
    }
}