use crate::ConsoleWriterPipeline;
#[cfg(feature = "checkpoint")]
use crate::checkpoint::Checkpoint;
use crate::downloader::Downloader;
use crate::downloaders::reqwest_client::ReqwestClientDownloader;
use crate::error::SpiderError;
use crate::middleware::Middleware;
use crate::middlewares::user_agent::{UserAgentMiddleware, UserAgentSource};
use crate::pipeline::Pipeline;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use num_cpus;
#[cfg(feature = "checkpoint")]
use std::fs;
use std::marker::PhantomData;
#[cfg(feature = "checkpoint")]
use std::path::{Path, PathBuf};
use std::sync::Arc;
#[cfg(feature = "checkpoint")]
use std::time::Duration;
#[cfg(feature = "checkpoint")]
use tracing::{debug, warn};

use super::Crawler;
use crate::stats::StatCollector;

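/// Concurrency limits for a crawl: simultaneous downloads, parser worker
/// tasks, and simultaneous pipeline executions. `Default` uses 5 concurrent
/// downloads, one parser worker per CPU core, and 5 concurrent pipelines.
///
/// A minimal sketch of overriding a single limit with struct-update syntax
/// (the value shown is illustrative, not a recommendation):
///
/// ```ignore
/// let config = CrawlerConfig {
///     max_concurrent_downloads: 16,
///     ..CrawlerConfig::default()
/// };
/// ```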
pub struct CrawlerConfig {
    /// Maximum number of requests downloaded concurrently.
    pub max_concurrent_downloads: usize,
    /// Number of worker tasks parsing downloaded responses.
    pub parser_workers: usize,
    /// Maximum number of items processed by pipelines concurrently.
    pub max_concurrent_pipelines: usize,
}

impl Default for CrawlerConfig {
    fn default() -> Self {
        CrawlerConfig {
            max_concurrent_downloads: 5,
            parser_workers: num_cpus::get(),
            max_concurrent_pipelines: 5,
        }
    }
}

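/// Builder that assembles a [`Crawler`] from a spider, a downloader,
/// middlewares, and item pipelines.
///
/// A minimal usage sketch, assuming `MySpider` is some type implementing
/// [`Spider`] (its definition is not shown here):
///
/// ```ignore
/// let crawler = CrawlerBuilder::new(MySpider::new())
///     .max_concurrent_downloads(8)
///     .add_pipeline(ConsoleWriterPipeline::new())
///     .build()
///     .await?;
/// ```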
pub struct CrawlerBuilder<S: Spider, D = ReqwestClientDownloader>
where
    D: Downloader,
{
    crawler_config: CrawlerConfig,
    downloader: D,
    spider: Option<S>,
    middlewares: Vec<Box<dyn Middleware<D::Client> + Send + Sync>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    #[cfg(feature = "checkpoint")]
    checkpoint_path: Option<PathBuf>,
    #[cfg(feature = "checkpoint")]
    checkpoint_interval: Option<Duration>,
    _phantom: PhantomData<S>,
}

impl<S: Spider, D: Default + Downloader> Default for CrawlerBuilder<S, D> {
    fn default() -> Self {
        Self {
            crawler_config: CrawlerConfig::default(),
            downloader: D::default(),
            spider: None,
            middlewares: Vec::new(),
            item_pipelines: Vec::new(),
            #[cfg(feature = "checkpoint")]
            checkpoint_path: None,
            #[cfg(feature = "checkpoint")]
            checkpoint_interval: None,
            _phantom: PhantomData,
        }
    }
}

impl<S: Spider> CrawlerBuilder<S> {
    /// Creates a builder for the given spider, using the default
    /// [`ReqwestClientDownloader`] and default [`CrawlerConfig`] limits.
    pub fn new(spider: S) -> Self {
        Self {
            spider: Some(spider),
            ..Default::default()
        }
    }
}

impl<S: Spider, D: Downloader> CrawlerBuilder<S, D> {
    /// Sets the maximum number of concurrent downloads.
    pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_downloads = limit;
        self
    }

    /// Sets the number of parser worker tasks.
    pub fn max_parser_workers(mut self, limit: usize) -> Self {
        self.crawler_config.parser_workers = limit;
        self
    }

    /// Sets the maximum number of concurrent pipeline executions.
    pub fn max_concurrent_pipelines(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_pipelines = limit;
        self
    }

    /// Replaces the default downloader with a custom one.
    pub fn downloader(mut self, downloader: D) -> Self {
        self.downloader = downloader;
        self
    }

    /// Appends a middleware to the middleware chain.
    pub fn add_middleware<M>(mut self, middleware: M) -> Self
    where
        M: Middleware<D::Client> + Send + Sync + 'static,
    {
        self.middlewares.push(Box::new(middleware));
        self
    }

    /// Appends an item pipeline. If none are added, a
    /// [`ConsoleWriterPipeline`] is installed by `build()`.
    pub fn add_pipeline<P>(mut self, pipeline: P) -> Self
    where
        P: Pipeline<S::Item> + 'static,
    {
        self.item_pipelines.push(Box::new(pipeline));
        self
    }

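    /// Sets the checkpoint file path. During `build()`, an existing
    /// checkpoint at this path is read and, if it deserializes, its
    /// scheduler state seeds the new scheduler and its pipeline states are
    /// restored into the matching pipelines; the path is then forwarded to
    /// the [`Crawler`].
    ///
    /// A minimal usage sketch, assuming the `checkpoint` feature is enabled
    /// and `MySpider` is some [`Spider`] implementation (the file name is
    /// illustrative):
    ///
    /// ```ignore
    /// let crawler = CrawlerBuilder::new(MySpider::new())
    ///     .with_checkpoint_path("crawl.checkpoint")
    ///     .with_checkpoint_interval(Duration::from_secs(60))
    ///     .build()
    ///     .await?;
    /// ```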
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.checkpoint_path = Some(path.as_ref().to_path_buf());
        self
    }

    /// Sets the checkpoint interval forwarded to the [`Crawler`].
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

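    /// Validates the configuration and assembles the [`Crawler`].
    ///
    /// If no pipeline was added, a [`ConsoleWriterPipeline`] is installed;
    /// if no `UserAgentMiddleware` is present, a default one using
    /// `CARGO_PKG_NAME/CARGO_PKG_VERSION` is prepended. With the
    /// `checkpoint` feature enabled, any checkpoint at the configured path
    /// is loaded to seed the scheduler and restore pipeline state.
    ///
    /// Returns a [`SpiderError::ConfigurationError`] if no spider was set
    /// or a concurrency limit is zero. A minimal sketch (with a hypothetical
    /// `MySpider`):
    ///
    /// ```ignore
    /// // A zero download limit is rejected at build time.
    /// let result = CrawlerBuilder::new(MySpider::new())
    ///     .max_concurrent_downloads(0)
    ///     .build()
    ///     .await;
    /// assert!(result.is_err());
    /// ```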
    pub async fn build(mut self) -> Result<Crawler<S, D::Client>, SpiderError>
    where
        D: Downloader + Send + Sync + 'static,
        D::Client: Send + Sync,
    {
        // Fall back to printing items to the console when no pipeline was added.
        if self.item_pipelines.is_empty() {
            self = self.add_pipeline(ConsoleWriterPipeline::new());
        }

        let spider = self.spider.take().ok_or_else(|| {
            SpiderError::ConfigurationError("Crawler must have a spider.".to_string())
        })?;

        if self.crawler_config.max_concurrent_downloads == 0 {
            return Err(SpiderError::ConfigurationError(
                "max_concurrent_downloads must be greater than 0.".to_string(),
            ));
        }
        if self.crawler_config.parser_workers == 0 {
            return Err(SpiderError::ConfigurationError(
                "parser_workers must be greater than 0.".to_string(),
            ));
        }

        #[cfg(feature = "checkpoint")]
        let mut initial_scheduler_state = None;
        #[cfg(not(feature = "checkpoint"))]
        let initial_scheduler_state = None;
        #[cfg(feature = "checkpoint")]
        let mut loaded_pipelines_state = None;

        // Try to resume from a previous checkpoint; failures only log a
        // warning and the crawl starts fresh.
        #[cfg(feature = "checkpoint")]
        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        initial_scheduler_state = Some(checkpoint.scheduler);
                        loaded_pipelines_state = Some(checkpoint.pipelines);
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        // Hand each pipeline its saved state, matched by pipeline name.
        #[cfg(feature = "checkpoint")]
        if let Some(pipeline_states) = loaded_pipelines_state {
            for (name, state) in pipeline_states {
                if let Some(pipeline) = self.item_pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

        let (scheduler_arc, req_rx) = Scheduler::new(initial_scheduler_state);

        // Prepend a default user agent middleware ("<pkg name>/<pkg version>")
        // unless one was configured explicitly.
        let has_user_agent_middleware = self
            .middlewares
            .iter()
            .any(|m| m.name() == "UserAgentMiddleware");

        if !has_user_agent_middleware {
            let pkg_name = env!("CARGO_PKG_NAME");
            let pkg_version = env!("CARGO_PKG_VERSION");
            let default_user_agent = format!("{}/{}", pkg_name, pkg_version);

            let default_user_agent_mw = UserAgentMiddleware::builder()
                .source(UserAgentSource::List(vec![default_user_agent.clone()]))
                .fallback_user_agent(default_user_agent)
                .build()?;
            self.middlewares.insert(0, Box::new(default_user_agent_mw));
        }

        let downloader_arc = Arc::new(self.downloader);

        let stats = Arc::new(StatCollector::new());
        let crawler = Crawler::new(
            scheduler_arc,
            req_rx,
            downloader_arc,
            self.middlewares,
            spider,
            self.item_pipelines,
            self.crawler_config.max_concurrent_downloads,
            self.crawler_config.parser_workers,
            self.crawler_config.max_concurrent_pipelines,
            #[cfg(feature = "checkpoint")]
            self.checkpoint_path.take(),
            #[cfg(feature = "checkpoint")]
            self.checkpoint_interval,
            stats,
        );

        Ok(crawler)
    }
}