small_http/
client.rs

1/*
2 * Copyright (c) 2025 Bastiaan van der Plaat
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7use std::collections::HashMap;
8use std::net::TcpStream;
9use std::sync::{Arc, Mutex};
10
11use crate::header_map::HeaderMap;
12use crate::request::{FetchError, Request};
13use crate::response::Response;
14use crate::KEEP_ALIVE_TIMEOUT;
15
16// MARK: HTTP Client
17/// HTTP client
18#[derive(Default, Clone)]
19pub struct Client {
20    connection_pool: Arc<Mutex<ConnectionPool>>,
21    headers: HeaderMap,
22}
23
24impl Client {
25    /// Create a new HTTP client
26    pub fn new() -> Self {
27        Self::default()
28    }
29
30    /// Set header
31    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
32        self.headers.insert(name.into(), value.into());
33        self
34    }
35
36    /// Fetch a request
37    pub fn fetch(&mut self, mut request: Request) -> Result<Response, FetchError> {
38        // Add client headers to request
39        for (name, value) in self.headers.iter() {
40            request = request.header(name, value);
41        }
42
43        // Get or create connection
44        let addr = format!(
45            "{}:{}",
46            request.url.host().expect("No host in URL"),
47            request.url.port().unwrap_or(80)
48        );
49        let mut stream = self
50            .connection_pool
51            .lock()
52            .expect("Can't lock connection pool")
53            .take_connection(&addr)
54            .ok_or(FetchError)?;
55        stream
56            .set_read_timeout(Some(KEEP_ALIVE_TIMEOUT))
57            .map_err(|_| FetchError)?;
58
59        // Send request and read response
60        request.write_to_stream(&mut stream, true);
61        let res = Response::read_from_stream(&mut stream).map_err(|_| FetchError)?;
62
63        // Return connection
64        self.connection_pool
65            .lock()
66            .expect("Can't lock connection pool")
67            .return_connection(&addr, stream);
68        Ok(res)
69    }
70}
71
72// MARK: ConnectionPool
/// Idle TCP connections grouped by the address string they were opened to.
#[derive(Default)]
struct ConnectionPool {
    /// Idle connections per address. Connections are pushed on return and
    /// popped on take, so the most recently returned one is reused first.
    connections: HashMap<String, Vec<TcpStream>>,
}
77
78impl ConnectionPool {
79    fn take_connection(&mut self, addr: &str) -> Option<TcpStream> {
80        // Insert addr into connection pool if it doesn't exist
81        if !self.connections.contains_key(addr) {
82            self.connections.insert(addr.to_string(), Vec::new());
83        }
84
85        // Check if we have a connections for the addr
86        if let Some(connections) = self.connections.get_mut(addr) {
87            // Check if we have a connection available
88            if let Some(conn) = connections.pop() {
89                return Some(conn);
90            }
91
92            // Open connection and return it
93            if let Ok(conn) = TcpStream::connect(addr) {
94                return Some(conn);
95            }
96        }
97
98        // No connection available
99        None
100    }
101
102    fn return_connection(&mut self, addr: &str, conn: TcpStream) {
103        // Insert connection back into pool
104        if let Some(connections) = self.connections.get_mut(addr) {
105            connections.push(conn);
106        }
107    }
108}
109
110// MARK: Tests
#[cfg(test)]
mod test {
    use std::io::{Read, Write};
    use std::net::{Ipv4Addr, TcpListener};
    use std::thread;

    use super::*;

    /// Ten sequential fetches must all be served over one TCP connection.
    ///
    /// The server thread accepts only once. A client that failed to reuse its
    /// pooled connection would send to a connection nobody answers, its read
    /// would eventually time out, and the `unwrap` below would fail.
    #[test]
    fn test_client_multiple_requests() {
        // Start test server
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).unwrap();
        let server_addr = listener.local_addr().unwrap();
        thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let mut buf = [0; 512];
            loop {
                // Stop once the client closes the connection (or it errors)
                // instead of spinning on writes to a dead socket.
                // NOTE(review): assumes each request fits in one 512-byte read.
                match stream.read(&mut buf) {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {}
                }
                // The test relies on the connection staying open, so
                // advertise keep-alive rather than a close-style header.
                let written = stream.write_all(
                    b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\nConnection: keep-alive\r\n\r\ntest",
                );
                if written.is_err() {
                    break;
                }
            }
        });

        // Create client and fetch multiple requests
        let mut client = Client::new();
        for _ in 0..10 {
            client
                .fetch(Request::get(format!("http://{}/", server_addr)))
                .unwrap();
        }
    }
}