use crate::ConsoleWriterPipeline;
#[cfg(feature = "checkpoint")]
use crate::checkpoint::Checkpoint;
use crate::downloader::Downloader;
use crate::downloaders::reqwest_client::ReqwestClientDownloader;
use crate::error::SpiderError;
use crate::middleware::Middleware;
use crate::middlewares::user_agent::{UserAgentMiddleware, UserAgentSource};
use crate::pipeline::Pipeline;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use num_cpus;
#[cfg(feature = "middleware-cookies")]
use std::any::Any;
#[cfg(feature = "checkpoint")]
use std::fs;
use std::marker::PhantomData;
#[cfg(feature = "checkpoint")]
use std::path::{Path, PathBuf};
use std::sync::Arc;
#[cfg(feature = "checkpoint")]
use std::time::Duration;
#[cfg(feature = "checkpoint")]
use tracing::{debug, warn};
#[cfg(feature = "middleware-cookies")]
use tracing::info;

#[cfg(feature = "middleware-cookies")]
use crate::middlewares::cookies::CookieMiddleware;
#[cfg(feature = "middleware-cookies")]
use cookie_store::CookieStore;
#[cfg(feature = "middleware-cookies")]
use tokio::sync::Mutex;

use super::Crawler;
use crate::stats::StatCollector;

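/// Concurrency limits for a [`Crawler`].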
pub struct CrawlerConfig {
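    /// Maximum number of requests fetched concurrently (default: 5).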
    pub max_concurrent_downloads: usize,
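    /// Number of parser worker tasks (default: one per logical CPU).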
    pub parser_workers: usize,
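    /// Maximum number of items processed by pipelines at once (default: 5).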
    pub max_concurrent_pipelines: usize,
}

impl Default for CrawlerConfig {
    fn default() -> Self {
        CrawlerConfig {
            max_concurrent_downloads: 5,
            parser_workers: num_cpus::get(),
            max_concurrent_pipelines: 5,
        }
    }
}

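/// Fluent builder for a [`Crawler`].
///
/// # Example
///
/// A minimal sketch, assuming a `MySpider` type that implements [`Spider`]
/// (`MySpider` and its `Default` impl are hypothetical):
///
/// ```ignore
/// let crawler = CrawlerBuilder::new(MySpider::default())
///     .max_concurrent_downloads(10)
///     .build()
///     .await?;
/// ```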
pub struct CrawlerBuilder<S: Spider, D = ReqwestClientDownloader>
where
    D: Downloader,
{
    crawler_config: CrawlerConfig,
    downloader: D,
    spider: Option<S>,
    middlewares: Vec<Box<dyn Middleware<D::Client> + Send + Sync>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    #[cfg(feature = "checkpoint")]
    checkpoint_path: Option<PathBuf>,
    #[cfg(feature = "checkpoint")]
    checkpoint_interval: Option<Duration>,
    _phantom: PhantomData<S>,
}

impl<S: Spider, D: Default + Downloader> Default for CrawlerBuilder<S, D> {
    fn default() -> Self {
        Self {
            crawler_config: CrawlerConfig::default(),
            downloader: D::default(),
            spider: None,
            middlewares: Vec::new(),
            item_pipelines: Vec::new(),
            #[cfg(feature = "checkpoint")]
            checkpoint_path: None,
            #[cfg(feature = "checkpoint")]
            checkpoint_interval: None,
            _phantom: PhantomData,
        }
    }
}

impl<S: Spider> CrawlerBuilder<S> {
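    /// Creates a builder for `spider` using the default
    /// [`ReqwestClientDownloader`].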
    pub fn new(spider: S) -> Self {
        Self {
            spider: Some(spider),
            ..Default::default()
        }
    }
}

impl<S: Spider, D: Downloader> CrawlerBuilder<S, D> {
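    /// Sets the maximum number of concurrent downloads.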
    pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_downloads = limit;
        self
    }

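    /// Sets the number of parser worker tasks.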
    pub fn max_parser_workers(mut self, limit: usize) -> Self {
        self.crawler_config.parser_workers = limit;
        self
    }

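    /// Sets the maximum number of concurrently running pipelines.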
    pub fn max_concurrent_pipelines(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_pipelines = limit;
        self
    }

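    /// Replaces the downloader used to fetch requests.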
    pub fn downloader(mut self, downloader: D) -> Self {
        self.downloader = downloader;
        self
    }

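    /// Appends a middleware to the chain.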
    pub fn add_middleware<M>(mut self, middleware: M) -> Self
    where
        M: Middleware<D::Client> + Send + Sync + 'static,
    {
        self.middlewares.push(Box::new(middleware));
        self
    }

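    /// Appends an item pipeline. If none is added, a
    /// [`ConsoleWriterPipeline`] is used by default.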
    pub fn add_pipeline<P>(mut self, pipeline: P) -> Self
    where
        P: Pipeline<S::Item> + 'static,
    {
        self.item_pipelines.push(Box::new(pipeline));
        self
    }

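    /// Sets the file used to persist and restore crawl checkpoints.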
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.checkpoint_path = Some(path.as_ref().to_path_buf());
        self
    }

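    /// Sets how often the crawler writes a checkpoint.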
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

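    /// Validates the configuration and assembles the [`Crawler`].
    ///
    /// Fills in defaults first: a [`ConsoleWriterPipeline`] when no pipeline
    /// was registered, and a `UserAgentMiddleware` advertising
    /// `<crate-name>/<crate-version>` when none was added.
    ///
    /// # Errors
    ///
    /// Returns [`SpiderError::ConfigurationError`] if no spider was set, or
    /// if `max_concurrent_downloads` or `parser_workers` is zero.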
    pub async fn build(mut self) -> Result<Crawler<S, D::Client>, SpiderError>
    where
        D: Downloader + Send + Sync + 'static,
        D::Client: Send + Sync,
    {
        if self.item_pipelines.is_empty() {
            self = self.add_pipeline(ConsoleWriterPipeline::new());
        }

        let spider = self.spider.take().ok_or_else(|| {
            SpiderError::ConfigurationError("Crawler must have a spider.".to_string())
        })?;

        if self.crawler_config.max_concurrent_downloads == 0 {
            return Err(SpiderError::ConfigurationError(
                "max_concurrent_downloads must be greater than 0.".to_string(),
            ));
        }
        if self.crawler_config.parser_workers == 0 {
            return Err(SpiderError::ConfigurationError(
                "parser_workers must be greater than 0.".to_string(),
            ));
        }

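        // State restored from a checkpoint, when available. The cfg
        // duplication keeps each local defined (and `mut` only where it is
        // reassigned) across every feature combination.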
        #[cfg(feature = "checkpoint")]
        let mut initial_scheduler_state = None;
        #[cfg(not(feature = "checkpoint"))]
        let initial_scheduler_state = None;
        #[cfg(feature = "checkpoint")]
        let mut loaded_pipelines_state = None;
        #[cfg(all(feature = "checkpoint", feature = "middleware-cookies"))]
        let mut loaded_cookie_store: Option<CookieStore> = None;
        #[cfg(all(not(feature = "checkpoint"), feature = "middleware-cookies"))]
        let loaded_cookie_store: Option<CookieStore> = None;

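        // Loading a checkpoint is best-effort: a missing or unreadable file
        // is logged with a warning and the crawl starts fresh.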
        #[cfg(feature = "checkpoint")]
        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        initial_scheduler_state = Some(checkpoint.scheduler);
                        loaded_pipelines_state = Some(checkpoint.pipelines);

                        #[cfg(feature = "middleware-cookies")]
                        {
                            info!("Checkpoint loaded, restoring cookie store data.");
                            loaded_cookie_store = Some(checkpoint.cookie_store);
                        }
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

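        // Hand each pipeline back its persisted state, matched by `name()`.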
        #[cfg(feature = "checkpoint")]
        if let Some(pipeline_states) = loaded_pipelines_state {
            for (name, state) in pipeline_states {
                if let Some(pipeline) = self.item_pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

        let (scheduler_arc, req_rx) = Scheduler::new(initial_scheduler_state);

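        // Guarantee a User-Agent: if the user did not register a
        // UserAgentMiddleware, prepend a default advertising "<crate>/<version>".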
        let has_user_agent_middleware = self
            .middlewares
            .iter()
            .any(|m| m.name() == "UserAgentMiddleware");

        if !has_user_agent_middleware {
            let pkg_name = env!("CARGO_PKG_NAME");
            let pkg_version = env!("CARGO_PKG_VERSION");
            let default_user_agent = format!("{}/{}", pkg_name, pkg_version);

            let default_user_agent_mw = UserAgentMiddleware::builder()
                .source(UserAgentSource::List(vec![default_user_agent.clone()]))
                .fallback_user_agent(default_user_agent)
                .build()?;
            self.middlewares.insert(0, Box::new(default_user_agent_mw));
        }

        let downloader_arc = Arc::new(self.downloader);
        let stats = Arc::new(StatCollector::new());

        let crawler = {
            #[cfg(not(feature = "middleware-cookies"))]
            {
                Crawler::new(
                    scheduler_arc,
                    req_rx,
                    downloader_arc,
                    self.middlewares,
                    spider,
                    self.item_pipelines,
                    self.crawler_config.max_concurrent_downloads,
                    self.crawler_config.parser_workers,
                    self.crawler_config.max_concurrent_pipelines,
                    #[cfg(feature = "checkpoint")]
                    self.checkpoint_path.take(),
                    #[cfg(feature = "checkpoint")]
                    self.checkpoint_interval,
                    stats,
                )
            }

            #[cfg(feature = "middleware-cookies")]
            {
                let mut final_cookie_store =
                    Arc::new(Mutex::new(loaded_cookie_store.unwrap_or_default()));

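                // Prefer the store of a user-supplied CookieMiddleware over
                // the checkpointed one. The downcast assumes `Any` is
                // reachable through the `Middleware` trait object (trait
                // upcasting).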
                for mw_box in &self.middlewares {
                    if let Some(cookie_mw) =
                        (mw_box.as_ref() as &dyn Any).downcast_ref::<CookieMiddleware>()
                    {
                        info!(
                            "Found CookieMiddleware, using its cookie store for Crawler. This overrides any checkpointed store."
                        );
                        final_cookie_store = cookie_mw.store.clone();
                        break;
                    }
                }

                Crawler::new(
                    scheduler_arc,
                    req_rx,
                    downloader_arc,
                    self.middlewares,
                    spider,
                    self.item_pipelines,
                    self.crawler_config.max_concurrent_downloads,
                    self.crawler_config.parser_workers,
                    self.crawler_config.max_concurrent_pipelines,
                    #[cfg(feature = "checkpoint")]
                    self.checkpoint_path.take(),
                    #[cfg(feature = "checkpoint")]
                    self.checkpoint_interval,
                    stats,
                    final_cookie_store,
                )
            }
        };

        Ok(crawler)
    }
}