use crate::ConsoleWriterPipeline;
#[cfg(feature = "checkpoint")]
use crate::checkpoint::Checkpoint;
use crate::{CookieStore, SchedulerCheckpoint};

use crate::downloader::Downloader;
use crate::downloaders::reqwest_client::ReqwestClientDownloader;
use crate::error::SpiderError;
use crate::middleware::Middleware;
use crate::middlewares::user_agent::{UserAgentMiddleware, UserAgentSource};
use crate::pipeline::Pipeline;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use num_cpus;
#[cfg(feature = "middleware-cookies")]
use std::any::Any;
#[cfg(feature = "checkpoint")]
use std::fs;
use std::marker::PhantomData;
#[cfg(feature = "checkpoint")]
use std::path::{Path, PathBuf};
use std::sync::Arc;
#[cfg(feature = "checkpoint")]
use std::time::Duration;
#[cfg(feature = "middleware-cookies")]
use tracing::info;
#[cfg(feature = "checkpoint")]
use tracing::{debug, warn};

#[cfg(feature = "middleware-cookies")]
use crate::middlewares::cookies::CookieMiddleware;

#[cfg(feature = "middleware-cookies")]
use tokio::sync::Mutex;

use super::Crawler;
use crate::stats::StatCollector;

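/// Concurrency settings used when constructing a [`Crawler`].
///
/// The defaults allow 5 concurrent downloads, one parser worker per logical
/// CPU (via `num_cpus::get()`), and 5 concurrent pipeline executions.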
pub struct CrawlerConfig {
    /// Maximum number of requests downloaded concurrently.
    pub max_concurrent_downloads: usize,
    /// Number of worker tasks that parse downloaded responses.
    pub parser_workers: usize,
    /// Maximum number of items processed by pipelines concurrently.
    pub max_concurrent_pipelines: usize,
}

impl Default for CrawlerConfig {
    fn default() -> Self {
        CrawlerConfig {
            max_concurrent_downloads: 5,
            parser_workers: num_cpus::get(),
            max_concurrent_pipelines: 5,
        }
    }
}

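/// Fluent builder for a [`Crawler`].
///
/// Illustrative usage sketch; `MySpider` is a placeholder for any type
/// implementing [`Spider`]:
///
/// ```ignore
/// let crawler = CrawlerBuilder::new(MySpider::default())
///     .max_concurrent_downloads(10)
///     .add_pipeline(ConsoleWriterPipeline::new())
///     .build()
///     .await?;
/// ```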
pub struct CrawlerBuilder<S: Spider, D = ReqwestClientDownloader>
where
    D: Downloader,
{
    crawler_config: CrawlerConfig,
    downloader: D,
    spider: Option<S>,
    middlewares: Vec<Box<dyn Middleware<D::Client> + Send + Sync>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    #[cfg(feature = "checkpoint")]
    checkpoint_path: Option<PathBuf>,
    #[cfg(feature = "checkpoint")]
    checkpoint_interval: Option<Duration>,
    _phantom: PhantomData<S>,
}

impl<S: Spider, D: Default + Downloader> Default for CrawlerBuilder<S, D> {
    fn default() -> Self {
        Self {
            crawler_config: CrawlerConfig::default(),
            downloader: D::default(),
            spider: None,
            middlewares: Vec::new(),
            item_pipelines: Vec::new(),
            #[cfg(feature = "checkpoint")]
            checkpoint_path: None,
            #[cfg(feature = "checkpoint")]
            checkpoint_interval: None,
            _phantom: PhantomData,
        }
    }
}

impl<S: Spider> CrawlerBuilder<S> {
    /// Creates a builder for `spider` using the default [`ReqwestClientDownloader`].
    pub fn new(spider: S) -> Self {
        Self {
            spider: Some(spider),
            ..Default::default()
        }
    }
}

impl<S: Spider, D: Downloader> CrawlerBuilder<S, D> {
    /// Sets the maximum number of concurrent downloads (must be greater than 0).
    pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_downloads = limit;
        self
    }

    /// Sets the number of parser workers (must be greater than 0).
    pub fn max_parser_workers(mut self, limit: usize) -> Self {
        self.crawler_config.parser_workers = limit;
        self
    }

    /// Sets the maximum number of concurrently running pipeline executions.
    pub fn max_concurrent_pipelines(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_pipelines = limit;
        self
    }

    /// Replaces the default downloader with a custom implementation.
    pub fn downloader(mut self, downloader: D) -> Self {
        self.downloader = downloader;
        self
    }

    /// Appends a middleware to the request/response processing chain.
    pub fn add_middleware<M>(mut self, middleware: M) -> Self
    where
        D: Downloader,
        M: Middleware<D::Client> + Send + Sync + 'static,
    {
        self.middlewares.push(Box::new(middleware));
        self
    }

    /// Appends an item pipeline that receives items scraped by the spider.
    pub fn add_pipeline<P>(mut self, pipeline: P) -> Self
    where
        P: Pipeline<S::Item> + 'static,
    {
        self.item_pipelines.push(Box::new(pipeline));
        self
    }

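    /// Sets the file used for checkpoint persistence. If a checkpoint already
    /// exists at `path`, it is loaded and restored when `build()` is called.
    ///
    /// Sketch of combining it with an interval; the file name and `MySpider`
    /// type are placeholders:
    ///
    /// ```ignore
    /// let crawler = CrawlerBuilder::new(MySpider::default())
    ///     .with_checkpoint_path("crawl.checkpoint")
    ///     .with_checkpoint_interval(Duration::from_secs(30))
    ///     .build()
    ///     .await?;
    /// ```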
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.checkpoint_path = Some(path.as_ref().to_path_buf());
        self
    }

    /// Sets the interval used for periodic checkpointing.
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

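    /// Validates the configuration, installs default components (a
    /// `ConsoleWriterPipeline` when no pipeline was added and a
    /// `UserAgentMiddleware` when none is present), restores checkpoint and
    /// cookie state where those features are enabled, and assembles the
    /// final [`Crawler`].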
    #[allow(unused_variables)]
    pub async fn build(mut self) -> Result<Crawler<S, D::Client>, SpiderError>
    where
        D: Downloader + Send + Sync + 'static,
        D::Client: Send + Sync + Clone,
    {
        let spider = self.validate_and_get_spider()?;
        self.ensure_default_components().await?;

        #[cfg(feature = "checkpoint")]
        let (initial_scheduler_state, loaded_cookie_store) =
            self.load_and_restore_checkpoint_state().await?;

        #[cfg(not(feature = "checkpoint"))]
        let (initial_scheduler_state, loaded_cookie_store): (
            Option<SchedulerCheckpoint>,
            Option<CookieStore>,
        ) = (None, None);

        let (scheduler_arc, req_rx) = Scheduler::new(initial_scheduler_state);

        #[cfg(feature = "middleware-cookies")]
        let final_cookie_store = self.resolve_final_cookie_store(loaded_cookie_store);

        let downloader_arc = Arc::new(self.downloader);
        let stats = Arc::new(StatCollector::new());

        let crawler = Crawler::new(
            scheduler_arc,
            req_rx,
            downloader_arc,
            self.middlewares,
            spider,
            self.item_pipelines,
            self.crawler_config.max_concurrent_downloads,
            self.crawler_config.parser_workers,
            self.crawler_config.max_concurrent_pipelines,
            #[cfg(feature = "checkpoint")]
            self.checkpoint_path.take(),
            #[cfg(feature = "checkpoint")]
            self.checkpoint_interval,
            stats,
            #[cfg(feature = "middleware-cookies")]
            final_cookie_store,
        );

        Ok(crawler)
    }

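    /// Chooses the cookie store handed to the crawler: the store of a
    /// registered `CookieMiddleware` takes precedence, otherwise the store
    /// restored from a checkpoint (or an empty default) is used.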
    #[cfg(feature = "middleware-cookies")]
    fn resolve_final_cookie_store(
        &self,
        loaded_cookie_store: Option<CookieStore>,
    ) -> Arc<Mutex<CookieStore>> {
        let mut final_cookie_store = Arc::new(Mutex::new(loaded_cookie_store.unwrap_or_default()));

        for mw_box in &self.middlewares {
            if let Some(cookie_mw) =
                (mw_box.as_ref() as &dyn Any).downcast_ref::<CookieMiddleware>()
            {
                info!(
                    "Found CookieMiddleware, using its cookie store for Crawler. This overrides any checkpointed store."
                );
                final_cookie_store = cookie_mw.store.clone();
                break;
            }
        }
        final_cookie_store
    }

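    /// Loads the checkpoint file (if a path is configured), restores each
    /// pipeline's saved state by name, and returns the scheduler and cookie
    /// store snapshots used to seed the crawl. Read or deserialization errors
    /// are logged as warnings and the crawl starts from scratch.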
    #[cfg(feature = "checkpoint")]
    async fn load_and_restore_checkpoint_state(
        &mut self,
    ) -> Result<(Option<SchedulerCheckpoint>, Option<CookieStore>), SpiderError> {
        let mut initial_scheduler_state = None;
        let mut loaded_pipelines_state = None;
        #[cfg(feature = "middleware-cookies")]
        let mut loaded_cookie_store = None;

        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        initial_scheduler_state = Some(checkpoint.scheduler);
                        loaded_pipelines_state = Some(checkpoint.pipelines);

                        #[cfg(feature = "middleware-cookies")]
                        {
                            info!("Checkpoint loaded, restoring cookie store data.");
                            loaded_cookie_store = Some(checkpoint.cookie_store);
                        }
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        if let Some(pipeline_states) = loaded_pipelines_state {
            for (name, state) in pipeline_states {
                if let Some(pipeline) = self.item_pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

        #[cfg(not(feature = "middleware-cookies"))]
        return Ok((initial_scheduler_state, None));
        #[cfg(feature = "middleware-cookies")]
        return Ok((initial_scheduler_state, loaded_cookie_store));
    }

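    /// Checks that the concurrency limits are non-zero and that a spider was
    /// supplied, returning the spider by value.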
    fn validate_and_get_spider(&mut self) -> Result<S, SpiderError> {
        if self.crawler_config.max_concurrent_downloads == 0 {
            return Err(SpiderError::ConfigurationError(
                "max_concurrent_downloads must be greater than 0.".to_string(),
            ));
        }
        if self.crawler_config.parser_workers == 0 {
            return Err(SpiderError::ConfigurationError(
                "parser_workers must be greater than 0.".to_string(),
            ));
        }
        self.spider.take().ok_or_else(|| {
            SpiderError::ConfigurationError("Crawler must have a spider.".to_string())
        })
    }

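    /// Installs fallback components: a `ConsoleWriterPipeline` when no item
    /// pipeline was configured, and a `UserAgentMiddleware` (defaulting to a
    /// "<crate name>/<crate version>" user agent) when none was added.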
    async fn ensure_default_components(&mut self) -> Result<(), SpiderError> {
        if self.item_pipelines.is_empty() {
            self.item_pipelines
                .push(Box::new(ConsoleWriterPipeline::new()));
        }

        let has_user_agent_middleware = self
            .middlewares
            .iter()
            .any(|m| m.name() == "UserAgentMiddleware");

        if !has_user_agent_middleware {
            let pkg_name = env!("CARGO_PKG_NAME");
            let pkg_version = env!("CARGO_PKG_VERSION");
            let default_user_agent = format!("{}/{}", pkg_name, pkg_version);

            let default_user_agent_mw = UserAgentMiddleware::builder()
                .source(UserAgentSource::List(vec![default_user_agent.clone()]))
                .fallback_user_agent(default_user_agent)
                .build()?;
            self.middlewares.insert(0, Box::new(default_user_agent_mw));
        }

        Ok(())
    }
}