Skip to main content

spider_downloader/
client.rs

1//! Reqwest-based Downloader implementation for the `spider-lib` framework.
2//!
3//! This module provides `ReqwestClientDownloader`, a concrete implementation
4//! of the `Downloader` trait that leverages the `reqwest` HTTP client library.
5//! It is responsible for executing HTTP requests defined by `Request` objects
6//! and converting the received HTTP responses into `Response` objects suitable
7//! for further processing by the crawler.
8//!
9//! This downloader handles various HTTP methods, request bodies (JSON, form data, bytes),
10//! and integrates with the framework's error handling.
11
12use crate::Downloader;
13use async_trait::async_trait;
14use reqwest::{Client, Proxy};
15use spider_util::error::SpiderError;
16use spider_util::request::{Body, Request};
17use spider_util::response::Response;
18use std::time::Duration;
19use log::info;
20use tokio::sync::RwLock;
21use std::collections::HashMap;
22use std::sync::Arc;
23
24/// Concrete implementation of Downloader using reqwest client
25pub struct ReqwestClientDownloader {
26    client: Client,
27    timeout: Duration,
28    /// Per-host connection pools for better resource management
29    host_clients: Arc<RwLock<HashMap<String, Client>>>,
30}
31
32#[async_trait]
33impl Downloader for ReqwestClientDownloader {
34    type Client = Client;
35
36    /// Returns a reference to the underlying HTTP client.
37    fn client(&self) -> &Self::Client {
38        &self.client
39    }
40
41    async fn download(&self, request: Request) -> Result<Response, SpiderError> {
42        info!(
43            "Downloading {} (fingerprint: {})",
44            request.url,
45            request.fingerprint()
46        );
47
48        let url = request.url.clone();
49        let method = request.method.clone();
50        let headers = request.headers.clone();
51        let body = request.body.clone();
52        let meta = request.meta_inner().clone();
53
54        // Get host-specific client if available, otherwise use default
55        let host = url.host_str().unwrap_or("").to_string();
56        // Convert DashMap to HashMap for the host client creation
57        let meta_hashmap: std::collections::HashMap<String, serde_json::Value> = meta
58            .as_ref()
59            .map(|m| m.iter().map(|entry| (entry.key().clone(), entry.value().clone())).collect())
60            .unwrap_or_default();
61        let mut client_to_use = self.get_or_create_host_client(&host, &meta_hashmap).await;
62
63        // Check for proxy in metadata
64        if let Some(meta_map) = meta.as_ref()
65            && let Some(proxy_val) = meta_map.get("proxy")
66            && let Some(proxy_str) = proxy_val.as_str()
67        {
68            match Proxy::all(proxy_str) {
69                Ok(proxy) => {
70                    let new_client = Client::builder()
71                        .timeout(self.timeout)
72                        .proxy(proxy)
73                        .build()
74                        .map_err(|e| SpiderError::ReqwestError(e.into()))?;
75                    client_to_use = new_client;
76                }
77                Err(e) => {
78                    return Err(SpiderError::ReqwestError(e.into()));
79                }
80            }
81        }
82
83        let mut req_builder = client_to_use.request(method, url.clone());
84
85        if let Some(body_content) = body {
86            req_builder = match body_content {
87                Body::Json(json_val) => req_builder.json(&json_val),
88                Body::Form(form_val) => {
89                    let mut form_map = std::collections::HashMap::new();
90                    for entry in form_val.iter() {
91                        form_map.insert(entry.key().clone(), entry.value().clone());
92                    }
93                    req_builder.form(&form_map)
94                }
95                Body::Bytes(bytes_val) => req_builder.body(bytes_val),
96            };
97        }
98
99        let res = req_builder.headers(headers).send().await?;
100
101        let response_url = res.url().clone();
102        let status = res.status();
103        let response_headers = res.headers().clone();
104        let response_body = res.bytes().await?;
105
106        Ok(Response {
107            url: response_url,
108            status,
109            headers: response_headers,
110            body: response_body,
111            request_url: url,
112            meta,  // Pass meta directly (already Option<Arc<...>>)
113            cached: false,
114        })
115    }
116}
117
118impl ReqwestClientDownloader {
119    /// Creates a new `ReqwestClientDownloader` with a default timeout of 30 seconds.
120    pub fn new() -> Self {
121        Self::new_with_timeout(Duration::from_secs(30))
122    }
123
124    /// Creates a new `ReqwestClientDownloader` with a specified request timeout.
125    pub fn new_with_timeout(timeout: Duration) -> Self {
126        let base_client = Client::builder()
127            .timeout(timeout)
128            .pool_max_idle_per_host(200)
129            .pool_idle_timeout(Duration::from_secs(120))
130            .tcp_keepalive(Duration::from_secs(60))
131            .connect_timeout(Duration::from_secs(10))
132            .build()
133            .unwrap();
134            
135        ReqwestClientDownloader {
136            client: base_client.clone(),
137            timeout,
138            host_clients: Arc::new(RwLock::new(HashMap::new())),
139        }
140    }
141
142    /// Gets or creates a host-specific client with optimized settings for that host
143    async fn get_or_create_host_client(&self, host: &str, _meta: &std::collections::HashMap<String, serde_json::Value>) -> Client {
144        {
145            let clients = self.host_clients.read().await;
146            if let Some(client) = clients.get(host) {
147                return client.clone();
148            }
149        }
150
151        // Create a new client for this host with optimized settings
152        let host_specific_client = Client::builder()
153            .timeout(self.timeout)
154            .pool_max_idle_per_host(50) // Smaller pool per host to distribute connections
155            .pool_idle_timeout(Duration::from_secs(90))
156            .tcp_keepalive(Duration::from_secs(30))
157            .connect_timeout(Duration::from_secs(5))
158            .build()
159            .unwrap();
160
161        {
162            let mut clients = self.host_clients.write().await;
163            // Double-check pattern to avoid race condition
164            if let Some(client) = clients.get(host) {
165                return client.clone();
166            }
167            clients.insert(host.to_string(), host_specific_client.clone());
168        }
169
170        host_specific_client
171    }
172}
173
174impl Default for ReqwestClientDownloader {
175    fn default() -> Self {
176        Self::new()
177    }
178}