1use std::collections::HashMap;
8use std::net::TcpStream;
9use std::sync::{Arc, Mutex};
10
11use crate::header_map::HeaderMap;
12use crate::request::{FetchError, Request};
13use crate::response::Response;
14use crate::KEEP_ALIVE_TIMEOUT;
15
16#[derive(Default, Clone)]
19pub struct Client {
20 connection_pool: Arc<Mutex<ConnectionPool>>,
21 headers: HeaderMap,
22}
23
24impl Client {
25 pub fn new() -> Self {
27 Self::default()
28 }
29
30 pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
32 self.headers.insert(name.into(), value.into());
33 self
34 }
35
36 pub fn fetch(&mut self, mut request: Request) -> Result<Response, FetchError> {
38 for (name, value) in self.headers.iter() {
40 request = request.header(name, value);
41 }
42
43 let addr = format!(
45 "{}:{}",
46 request.url.host().expect("No host in URL"),
47 request.url.port().unwrap_or(80)
48 );
49 let mut stream = self
50 .connection_pool
51 .lock()
52 .expect("Can't lock connection pool")
53 .take_connection(&addr)
54 .ok_or(FetchError)?;
55 stream
56 .set_read_timeout(Some(KEEP_ALIVE_TIMEOUT))
57 .map_err(|_| FetchError)?;
58
59 request.write_to_stream(&mut stream, true);
61 let res = Response::read_from_stream(&mut stream).map_err(|_| FetchError)?;
62
63 self.connection_pool
65 .lock()
66 .expect("Can't lock connection pool")
67 .return_connection(&addr, stream);
68 Ok(res)
69 }
70}
71
/// Idle TCP connections kept for reuse, grouped by the `host:port` address
/// string they were opened to.
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<String, Vec<TcpStream>>,
}

impl ConnectionPool {
    /// Returns a connection to `addr`.
    ///
    /// It prefers the most recently returned idle connection whose peer has
    /// not closed it, and otherwise opens a new one. Idle connections that the
    /// peer has shut down are dropped, which closes the socket, instead of
    /// being handed out. Returns `None` only if opening a new connection fails.
    fn take_connection(&mut self, addr: &str) -> Option<TcpStream> {
        if let Some(idle) = self.connections.get_mut(addr) {
            while let Some(conn) = idle.pop() {
                if Self::is_reusable(&conn) {
                    return Some(conn);
                }
            }
        }
        TcpStream::connect(addr).ok()
    }

    /// Puts `conn` back into the pool so later requests to `addr` can reuse it.
    fn return_connection(&mut self, addr: &str, conn: TcpStream) {
        match self.connections.get_mut(addr) {
            Some(idle) => idle.push(conn),
            None => {
                self.connections.insert(addr.to_owned(), vec![conn]);
            }
        }
    }

    /// Checks, without blocking or consuming data, whether an idle connection
    /// may still be used.
    ///
    /// A zero-length peek means the peer performed an orderly shutdown, for
    /// example a server closing an idle keep-alive connection. Reusing such a
    /// socket would make the next request fail. Bytes that are already pending
    /// are left in place for the reader.
    fn is_reusable(conn: &TcpStream) -> bool {
        if conn.set_nonblocking(true).is_err() {
            return false;
        }
        let mut probe = [0u8; 1];
        let alive = match conn.peek(&mut probe) {
            Ok(0) => false,
            Ok(_) => true,
            // `WouldBlock` means nothing has arrived and the socket is still open.
            Err(e) => e.kind() == std::io::ErrorKind::WouldBlock,
        };
        // Restore blocking mode, because the client reads with a timeout.
        conn.set_nonblocking(false).is_ok() && alive
    }
}
109
#[cfg(test)]
mod test {
    use std::io::{Read, Write};
    use std::net::{Ipv4Addr, TcpListener};
    use std::thread;

    use super::*;

    /// Canned keep-alive response returned for every request.
    const RESPONSE: &[u8] =
        b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\nConnection: keep-alive\r\n\r\ntest";

    /// Returns the length of the first complete request head in `buf`, if any.
    /// The head runs through the blank line that ends the headers.
    fn request_head_len(buf: &[u8]) -> Option<usize> {
        buf.windows(4).position(|w| w == b"\r\n\r\n").map(|i| i + 4)
    }

    #[test]
    fn test_client_multiple_requests() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).unwrap();
        let server_addr = listener.local_addr().unwrap();
        thread::spawn(move || {
            // Accept a single connection only. All ten requests must travel over
            // it, so a client that fails to reuse pooled connections never gets
            // a reply.
            let (mut stream, _) = listener.accept().unwrap();
            let mut pending = Vec::new();
            let mut buf = [0; 512];
            loop {
                // Answer once per complete request rather than once per read(),
                // so a request split across several segments gets one response.
                while let Some(len) = request_head_len(&pending) {
                    pending.drain(..len);
                    if stream.write_all(RESPONSE).is_err() {
                        return;
                    }
                }
                match stream.read(&mut buf) {
                    // The client went away: stop quietly instead of panicking.
                    Ok(0) | Err(_) => return,
                    Ok(n) => pending.extend_from_slice(&buf[..n]),
                }
            }
        });

        let mut client = Client::new();
        for _ in 0..10 {
            client
                .fetch(Request::get(format!("http://{}/", server_addr)))
                .unwrap();
        }
    }
}