use crate::ConsoleWriterPipeline;
use crate::checkpoint::Checkpoint;
use crate::downloader::Downloader;
use crate::downloaders::reqwest_client::ReqwestClientDownloader;
use crate::error::SpiderError;
use crate::middleware::Middleware;
use crate::middlewares::user_agent::{UserAgentMiddleware, UserAgentSource};
use crate::pipeline::Pipeline;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use num_cpus;
use std::fs;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, warn};

use super::Crawler;

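/// Concurrency settings used when running a [`Crawler`].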
pub struct CrawlerConfig {
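    /// Maximum number of requests that may be downloaded concurrently.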
    pub max_concurrent_downloads: usize,
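    /// Number of worker tasks used to parse downloaded responses.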
    pub parser_workers: usize,
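    /// Maximum number of items that may be processed by pipelines concurrently.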
    pub max_concurrent_pipelines: usize,
}

impl Default for CrawlerConfig {
    fn default() -> Self {
        CrawlerConfig {
            max_concurrent_downloads: 5,
            parser_workers: num_cpus::get(),
            max_concurrent_pipelines: 5,
        }
    }
}

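/// Builder for configuring and constructing a [`Crawler`].
///
/// A minimal usage sketch; `MySpider` is a placeholder for any type that
/// implements [`Spider`], and the default [`ReqwestClientDownloader`] is used:
///
/// ```ignore
/// let crawler = CrawlerBuilder::<MySpider>::new(MySpider::new())
///     .max_concurrent_downloads(10)
///     .with_checkpoint_path("crawl.checkpoint")
///     .with_checkpoint_interval(Duration::from_secs(60))
///     .build()
///     .await?;
/// ```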
pub struct CrawlerBuilder<S: Spider, D = ReqwestClientDownloader>
where
    D: Downloader,
{
    crawler_config: CrawlerConfig,
    downloader: D,
    spider: Option<S>,
    middlewares: Vec<Box<dyn Middleware<D::Client> + Send + Sync>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    checkpoint_path: Option<PathBuf>,
    checkpoint_interval: Option<Duration>,
    _phantom: PhantomData<S>,
}

impl<S: Spider, D: Default + Downloader> Default for CrawlerBuilder<S, D> {
    fn default() -> Self {
        Self {
            crawler_config: CrawlerConfig::default(),
            downloader: D::default(),
            spider: None,
            middlewares: Vec::new(),
            item_pipelines: Vec::new(),
            checkpoint_path: None,
            checkpoint_interval: None,
            _phantom: PhantomData,
        }
    }
}

impl<S: Spider, D: Downloader> CrawlerBuilder<S, D> {
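    /// Creates a builder for the given spider, using `D::default()` as the downloader.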
    pub fn new(spider: S) -> Self
    where
        D: Default,
    {
        Self {
            spider: Some(spider),
            ..Default::default()
        }
    }

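    /// Sets the maximum number of concurrent downloads (default: 5).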
    pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_downloads = limit;
        self
    }

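    /// Sets the number of parser worker tasks (default: the number of logical CPUs).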
    pub fn max_parser_workers(mut self, limit: usize) -> Self {
        self.crawler_config.parser_workers = limit;
        self
    }

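    /// Sets the maximum number of items processed concurrently by pipelines (default: 5).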
    pub fn max_concurrent_pipelines(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_pipelines = limit;
        self
    }

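    /// Replaces the default downloader with a custom implementation.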
    pub fn downloader(mut self, downloader: D) -> Self {
        self.downloader = downloader;
        self
    }

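    /// Appends a middleware to the request/response processing chain.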
    pub fn add_middleware<M>(mut self, middleware: M) -> Self
    where
        M: Middleware<D::Client> + Send + Sync + 'static,
    {
        self.middlewares.push(Box::new(middleware));
        self
    }

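    /// Appends an item pipeline for processing scraped items.
    ///
    /// If no pipeline is added, [`build`](Self::build) falls back to a
    /// [`ConsoleWriterPipeline`].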
    pub fn add_pipeline<P>(mut self, pipeline: P) -> Self
    where
        P: Pipeline<S::Item> + 'static,
    {
        self.item_pipelines.push(Box::new(pipeline));
        self
    }

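    /// Sets the file path used to load and persist crawl checkpoints.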
    pub fn with_checkpoint_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.checkpoint_path = Some(path.as_ref().to_path_buf());
        self
    }

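    /// Sets how often a checkpoint is written while the crawl is running.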
    pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

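    /// Validates the configuration and assembles the [`Crawler`].
    ///
    /// Returns a [`SpiderError::ConfigurationError`] if no spider was supplied
    /// or if any concurrency limit is zero.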
    pub async fn build(mut self) -> Result<Crawler<S, D::Client>, SpiderError>
    where
        D: Downloader + Send + Sync + 'static,
        D::Client: Send + Sync,
    {
        if self.item_pipelines.is_empty() {
            self = self.add_pipeline(ConsoleWriterPipeline::new());
        }

        let spider = self.spider.take().ok_or_else(|| {
            SpiderError::ConfigurationError("Crawler must have a spider.".to_string())
        })?;

        if self.crawler_config.max_concurrent_downloads == 0 {
            return Err(SpiderError::ConfigurationError(
                "max_concurrent_downloads must be greater than 0.".to_string(),
            ));
        }
        if self.crawler_config.parser_workers == 0 {
            return Err(SpiderError::ConfigurationError(
                "parser_workers must be greater than 0.".to_string(),
            ));
        }

        let mut initial_scheduler_state = None;
        let mut loaded_pipelines_state = None;

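        // Try to restore a previous crawl from the checkpoint file, if one was
        // configured. Failure to read or deserialize it is logged and ignored.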
        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        initial_scheduler_state = Some(checkpoint.scheduler);
                        loaded_pipelines_state = Some(checkpoint.pipelines);
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        let (scheduler_arc, req_rx) = Scheduler::new(initial_scheduler_state.clone());

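        // Hand each pipeline's saved state back to the matching pipeline, matched by name.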
        if let Some(pipeline_states) = loaded_pipelines_state {
            for (name, state) in pipeline_states {
                if let Some(pipeline) = self.item_pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

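        // Guarantee a User-Agent: if no UserAgentMiddleware was added, prepend one
        // that sends "<package name>/<package version>".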
        let has_user_agent_middleware = self
            .middlewares
            .iter()
            .any(|m| m.name() == "UserAgentMiddleware");

        if !has_user_agent_middleware {
            let pkg_name = env!("CARGO_PKG_NAME");
            let pkg_version = env!("CARGO_PKG_VERSION");
            let default_user_agent = format!("{}/{}", pkg_name, pkg_version);

            let default_user_agent_mw = UserAgentMiddleware::builder()
                .source(UserAgentSource::List(vec![default_user_agent.clone()]))
                .fallback_user_agent(default_user_agent)
                .build()?;
            self.middlewares.insert(0, Box::new(default_user_agent_mw));
        }

        let downloader_arc = Arc::new(self.downloader);

        let crawler = Crawler::new(
            scheduler_arc,
            req_rx,
            downloader_arc,
            self.middlewares,
            spider,
            self.item_pipelines,
            self.crawler_config.max_concurrent_downloads,
            self.crawler_config.parser_workers,
            self.crawler_config.max_concurrent_pipelines,
            self.checkpoint_path.take(),
            self.checkpoint_interval,
        );

        Ok(crawler)
    }
}