use crate::ConsoleWriterPipeline;
use crate::checkpoint::Checkpoint;
use crate::downloader::Downloader;
use crate::downloaders::reqwest_client::ReqwestClientDownloader;
use crate::error::SpiderError;
use crate::middleware::Middleware;
use crate::middlewares::user_agent::{UserAgentMiddleware, UserAgentSource};
use crate::pipeline::Pipeline;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use num_cpus;
use std::fs;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, warn};

use super::Crawler;

/// Concurrency settings for the crawler's download, parse, and pipeline stages.
pub struct CrawlerConfig {
    /// Maximum number of requests downloaded concurrently.
    pub max_concurrent_downloads: usize,
    /// Number of worker tasks that parse downloaded responses.
    pub parser_workers: usize,
    /// Maximum number of items processed by pipelines concurrently.
    pub max_concurrent_pipelines: usize,
}

impl Default for CrawlerConfig {
    fn default() -> Self {
        CrawlerConfig {
            max_concurrent_downloads: 5,
            parser_workers: num_cpus::get(),
            max_concurrent_pipelines: 5,
        }
    }
}

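/// Builder for assembling a [`Crawler`] from a [`Spider`], a [`Downloader`],
/// middlewares, and item pipelines.
///
/// A minimal usage sketch (assuming a user-defined `MySpider` type that
/// implements [`Spider`]; how you construct your spider will differ):
///
/// ```ignore
/// let crawler = CrawlerBuilder::new(MySpider::default())
///     .max_concurrent_downloads(10)
///     .max_parser_workers(4)
///     .with_checkpoint_path("crawl.checkpoint")
///     .build()
///     .await?;
/// ```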
pub struct CrawlerBuilder<S: Spider, D = ReqwestClientDownloader>
where
    D: Downloader,
{
    crawler_config: CrawlerConfig,
    downloader: D,
    spider: Option<S>,
    middlewares: Vec<Box<dyn Middleware<D::Client> + Send + Sync>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    checkpoint_path: Option<PathBuf>,
    checkpoint_interval: Option<Duration>,
    _phantom: PhantomData<S>,
}

impl<S: Spider, D: Default + Downloader> Default for CrawlerBuilder<S, D> {
    fn default() -> Self {
        Self {
            crawler_config: CrawlerConfig::default(),
            downloader: D::default(),
            spider: None,
            middlewares: Vec::new(),
            item_pipelines: Vec::new(),
            checkpoint_path: None,
            checkpoint_interval: None,
            _phantom: PhantomData,
        }
    }
}

impl<S: Spider, D: Downloader> CrawlerBuilder<S, D> {
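    /// Creates a builder for the given spider, using `D::default()` as the downloader.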
    pub fn new(spider: S) -> Self
    where
        D: Default,
    {
        Self {
            spider: Some(spider),
            ..Default::default()
        }
    }

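    /// Sets the maximum number of concurrent downloads (default: 5).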
    pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_downloads = limit;
        self
    }

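    /// Sets the number of parser worker tasks (default: one per CPU core).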
    pub fn max_parser_workers(mut self, limit: usize) -> Self {
        self.crawler_config.parser_workers = limit;
        self
    }

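    /// Sets the maximum number of items processed by pipelines at once (default: 5).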
    pub fn max_concurrent_pipelines(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_pipelines = limit;
        self
    }

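    /// Replaces the default downloader with the given one.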
    pub fn downloader(mut self, downloader: D) -> Self {
        self.downloader = downloader;
        self
    }

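    /// Appends a middleware. If no `UserAgentMiddleware` is added, a default one
    /// is inserted automatically at build time.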
    pub fn add_middleware<M>(mut self, middleware: M) -> Self
    where
        M: Middleware<D::Client> + Send + Sync + 'static,
    {
        self.middlewares.push(Box::new(middleware));
        self
    }

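    /// Adds an item pipeline. If no pipeline is added, a `ConsoleWriterPipeline` is used.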
    pub fn add_pipeline<P>(mut self, pipeline: P) -> Self
    where
        P: Pipeline<S::Item> + 'static,
    {
        self.item_pipelines.push(Box::new(pipeline));
        self
    }

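    /// Sets the checkpoint file path. If the file exists, scheduler and pipeline
    /// state are restored from it when `build` is called.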
    pub fn with_checkpoint_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.checkpoint_path = Some(path.as_ref().to_path_buf());
        self
    }

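    /// Sets the interval used for periodic checkpointing.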
    pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

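    /// Validates the configuration, restores any checkpoint state, installs a
    /// default `UserAgentMiddleware` when none was added, and assembles the crawler.
    ///
    /// Returns a [`SpiderError::ConfigurationError`] if no spider was provided or
    /// if `max_concurrent_downloads` or `parser_workers` is zero.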
    pub async fn build(mut self) -> Result<Crawler<S, D::Client>, SpiderError>
    where
        D: Downloader + Send + Sync + 'static,
        D::Client: Send + Sync,
    {
        // Fall back to the console writer when no pipeline was configured.
        if self.item_pipelines.is_empty() {
            self = self.add_pipeline(ConsoleWriterPipeline::new());
        }

        let spider = self.spider.take().ok_or_else(|| {
            SpiderError::ConfigurationError("Crawler must have a spider.".to_string())
        })?;

        if self.crawler_config.max_concurrent_downloads == 0 {
            return Err(SpiderError::ConfigurationError(
                "max_concurrent_downloads must be greater than 0.".to_string(),
            ));
        }
        if self.crawler_config.parser_workers == 0 {
            return Err(SpiderError::ConfigurationError(
                "parser_workers must be greater than 0.".to_string(),
            ));
        }

        // Restore scheduler and pipeline state from a previous run, if a checkpoint exists.
        let mut initial_scheduler_state = None;
        let mut loaded_pipelines_state = None;

        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        initial_scheduler_state = Some(checkpoint.scheduler);
                        loaded_pipelines_state = Some(checkpoint.pipelines);
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        let (scheduler_arc, req_rx) = Scheduler::new(initial_scheduler_state);

        // Hand each pipeline its saved state, matched by pipeline name.
        if let Some(pipeline_states) = loaded_pipelines_state {
            for (name, state) in pipeline_states {
                if let Some(pipeline) = self.item_pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

        // Ensure requests always carry a User-Agent; default to "<crate name>/<crate version>".
        let has_user_agent_middleware = self
            .middlewares
            .iter()
            .any(|m| m.name() == "UserAgentMiddleware");

        if !has_user_agent_middleware {
            let pkg_name = env!("CARGO_PKG_NAME");
            let pkg_version = env!("CARGO_PKG_VERSION");
            let default_user_agent = format!("{}/{}", pkg_name, pkg_version);

            let default_user_agent_mw = UserAgentMiddleware::builder()
                .source(UserAgentSource::List(vec![default_user_agent.clone()]))
                .fallback_user_agent(default_user_agent)
                .build()?;
            self.middlewares.insert(0, Box::new(default_user_agent_mw));
        }

        let downloader_arc = Arc::new(self.downloader);

        let crawler = Crawler::new(
            scheduler_arc,
            req_rx,
            downloader_arc,
            self.middlewares,
            spider,
            self.item_pipelines,
            self.crawler_config.max_concurrent_downloads,
            self.crawler_config.parser_workers,
            self.crawler_config.max_concurrent_pipelines,
            self.checkpoint_path.take(),
            self.checkpoint_interval,
        );

        Ok(crawler)
    }
}