spider_lib/
engine.rs
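//! Crawl engine: wires together the [`Scheduler`], [`Downloader`], a [`Spider`]
//! implementation, and its item pipelines as a set of cooperating async actors.
//!
//! Rough data flow for one crawl (as implemented below):
//!
//! ```text
//! scheduler -> req_rx -> downloader actor -> response channel -> parser actor
//! parser actor -> follow-up requests back to the scheduler
//! parser actor -> item channel -> item processor actor -> item pipelines
//! ```
//!
//! Shared atomic counters in [`EngineState`] track in-flight downloads, parses,
//! and item processing so `monitor_crawl` can detect when the crawl is idle.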

use crate::downloader::Downloader;
use crate::error::SpiderError;
use crate::item::{CrawlOutput, ScrapedItem};
use crate::pipeline::ItemPipeline;
use crate::request::Request;
use crate::response::Response;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use crate::state::EngineState;
use anyhow::Result;
use futures_util::future::join_all;
use kanal::{AsyncReceiver, AsyncSender, bounded_async};
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
use tracing::{debug, error, info, instrument, warn};

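/// Orchestrates a single crawl: owns the spider, its item pipelines, and the
/// handles needed to spawn the downloader, parser, and item processor actors.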
pub struct Engine<S: Spider> {
    scheduler: Arc<Scheduler>,
    req_rx: AsyncReceiver<Request>,
    downloader: Arc<Downloader>,
    spider: Arc<S>,
    item_pipelines: Vec<Box<dyn ItemPipeline<S::Item>>>,
    max_concurrent_downloads: usize,
    #[allow(dead_code)]
    parser_workers: usize,
}

impl<S> Engine<S>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    pub(crate) fn new(
        scheduler: Arc<Scheduler>,
        req_rx: AsyncReceiver<Request>,
        downloader: Arc<Downloader>,
        spider: S,
        item_pipelines: Vec<Box<dyn ItemPipeline<S::Item>>>,
        max_concurrent_downloads: usize,
        parser_workers: usize,
    ) -> Self {
        Engine {
            scheduler,
            req_rx,
            downloader,
            spider: Arc::new(spider),
            item_pipelines,
            max_concurrent_downloads,
            parser_workers,
        }
    }

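    /// Runs the crawl to completion: seeds the scheduler with the spider's
    /// start URLs, spawns the downloader, parser, and item processor actors,
    /// waits for the crawl to go idle, then shuts the scheduler down, joins
    /// every actor task, and closes the item pipelines.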
    #[instrument(skip_all)]
    pub async fn start_crawl(self) -> Result<(), SpiderError> {
        info!("Engine starting crawl");

        let Engine {
            scheduler,
            req_rx,
            downloader,
            spider,
            item_pipelines,
            max_concurrent_downloads,
            parser_workers,
        } = self;

        let state = EngineState::new();
        let pipelines = Arc::new(item_pipelines);
        let channel_capacity = max_concurrent_downloads * 2;

        let (res_tx, res_rx) = bounded_async(channel_capacity);
        let (item_tx, item_rx) = bounded_async(channel_capacity);

        let initial_requests_task =
            spawn_initial_requests_task::<S>(scheduler.clone(), spider.clone());

        let downloader_actor = spawn_downloader_actor::<S>(
            req_rx,
            downloader,
            state.clone(),
            res_tx,
            max_concurrent_downloads,
        );

        let parser_actor = spawn_parser_actor::<S>(
            scheduler.clone(),
            spider.clone(),
            state.clone(),
            res_rx,
            item_tx,
            parser_workers,
        );

        let item_processor_actor =
            spawn_item_processor_actor::<S>(state.clone(), item_rx, pipelines.clone());

        monitor_crawl(scheduler.clone(), state.clone()).await;
        info!("Crawl has become idle, initiating shutdown.");

        scheduler.shutdown().await?;
        initial_requests_task.await.map_err(|e| {
            SpiderError::InternalError(format!("Initial requests task failed: {}", e))
        })?;

        downloader_actor.await.map_err(|e| {
            SpiderError::InternalError(format!("Downloader actor task failed: {}", e))
        })?;

        parser_actor
            .await
            .map_err(|e| SpiderError::InternalError(format!("Parser actor task failed: {}", e)))?;

        item_processor_actor.await.map_err(|e| {
            SpiderError::InternalError(format!("Item processor actor task failed: {}", e))
        })?;

        // Close all pipelines
        info!("Closing item pipelines...");
        let closing_futures: Vec<_> = pipelines.iter().map(|p| p.close()).collect();
        join_all(closing_futures).await;

        info!("Crawl finished successfully.");
        Ok(())
    }
}

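/// Enqueues the spider's start URLs (with any URL fragment stripped) as the
/// initial requests, each targeting the spider's "parse" callback.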
fn spawn_initial_requests_task<S>(
    scheduler: Arc<Scheduler>,
    spider: Arc<S>,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    tokio::spawn(async move {
        for mut url in spider.start_urls() {
            url.set_fragment(None);
            let req = Request::new(url, spider.name(), "parse");
            if let Err(e) = scheduler.enqueue_request(req).await {
                error!("Failed to enqueue initial request: {}", e);
            }
        }
    })
}

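/// Polls the scheduler and the engine counters until both report idle. The
/// idle check is repeated after a short delay so the crawl is not shut down
/// while work is still being handed off between actors.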
async fn monitor_crawl(scheduler: Arc<Scheduler>, state: Arc<EngineState>) {
    loop {
        if scheduler.is_idle() && state.is_idle() {
            // Short delay to ensure no tasks are in the process of being created.
            tokio::time::sleep(Duration::from_millis(50)).await;
            if scheduler.is_idle() && state.is_idle() {
                break;
            }
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
}

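/// Actor that receives requests from `req_rx` and downloads them concurrently,
/// bounded by a semaphore with `max_concurrent_downloads` permits. Successful
/// responses are forwarded on `res_tx`, and the in-flight request counter is
/// kept up to date for idle detection.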
fn spawn_downloader_actor<S>(
    req_rx: AsyncReceiver<Request>,
    downloader: Arc<Downloader>,
    state: Arc<EngineState>,
    res_tx: AsyncSender<Response>,
    max_concurrent_downloads: usize,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let semaphore = Arc::new(Semaphore::new(max_concurrent_downloads));
    let mut tasks = JoinSet::new();

    tokio::spawn(async move {
        while let Ok(request) = req_rx.recv().await {
            let permit = match semaphore.clone().acquire_owned().await {
                Ok(p) => p,
                Err(_) => {
                    warn!("Semaphore closed, shutting down downloader actor.");
                    break;
                }
            };

            state.in_flight_requests.fetch_add(1, Ordering::SeqCst);
            let downloader_clone = downloader.clone();
            let res_tx_clone = res_tx.clone();
            let state_clone = state.clone();

            tasks.spawn(async move {
                debug!("Downloading {}", request.url);
                match downloader_clone.download(request).await {
                    Ok(response) => {
                        if res_tx_clone.send(response).await.is_err() {
                            error!("Response channel closed, cannot send downloaded response.");
                        }
                    }
                    Err(e) => error!("Download error: {:?}", e),
                }
                state_clone
                    .in_flight_requests
                    .fetch_sub(1, Ordering::SeqCst);
                drop(permit);
            });
        }
        // Wait for all download tasks to complete
        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("A download task failed: {:?}", e);
            }
        }
    })
}

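/// Actor that fans downloaded responses out to `parser_workers` worker tasks,
/// each of which calls the spider's `parse`. Parsed output is split into
/// follow-up requests (sent back to the scheduler) and scraped items (sent to
/// the item processor).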
fn spawn_parser_actor<S>(
    scheduler: Arc<Scheduler>,
    spider: Arc<S>,
    state: Arc<EngineState>,
    res_rx: AsyncReceiver<Response>,
    item_tx: AsyncSender<S::Item>,
    parser_workers: usize,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let mut tasks = JoinSet::new();
    let (internal_parse_tx, internal_parse_rx) = bounded_async::<Response>(parser_workers * 2);

    // Spawn N parsing worker tasks
    for _ in 0..parser_workers {
        let internal_parse_rx_clone = internal_parse_rx.clone();
        let spider_clone = spider.clone();
        let scheduler_clone = scheduler.clone();
        let item_tx_clone = item_tx.clone();
        let state_clone = state.clone();

        tasks.spawn(async move {
            while let Ok(response) = internal_parse_rx_clone.recv().await {
                debug!("Parsing response from {}", response.url);
                match spider_clone.parse(response).await {
                    Ok(outputs) => {
                        process_crawl_outputs::<S>(
                            outputs,
                            scheduler_clone.clone(),
                            item_tx_clone.clone(),
                            state_clone.clone(),
                        )
                        .await;
                    }
                    Err(e) => error!("Spider parsing error: {:?}", e),
                }
                state_clone.parsing_responses.fetch_sub(1, Ordering::SeqCst);
            }
        });
    }

    tokio::spawn(async move {
        while let Ok(response) = res_rx.recv().await {
            state.parsing_responses.fetch_add(1, Ordering::SeqCst);
            if internal_parse_tx.send(response).await.is_err() {
                error!("Internal parse channel closed, cannot send response to parser worker.");
                state.parsing_responses.fetch_sub(1, Ordering::SeqCst); // Decrement if send fails
            }
        }
        // Drop the sender to signal worker tasks to shut down
        drop(internal_parse_tx);

        // Wait for all parsing tasks to complete
        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("A parsing worker task failed: {:?}", e);
            }
        }
    })
}

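/// Routes a spider's `CrawlOutput`: follow-up requests are enqueued on the
/// scheduler, and scraped items are sent to the item channel, incrementing the
/// processing-items counter per item for idle tracking.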
async fn process_crawl_outputs<S>(
    outputs: CrawlOutput<S::Item>,
    scheduler: Arc<Scheduler>,
    item_tx: AsyncSender<S::Item>,
    state: Arc<EngineState>,
) where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    for request in outputs.requests {
        if let Err(e) = scheduler.enqueue_request(request).await {
            error!("Failed to enqueue request from parser: {}", e);
        }
    }
    for item in outputs.items {
        state.processing_items.fetch_add(1, Ordering::SeqCst);
        if item_tx.send(item).await.is_err() {
            error!("Item channel closed, cannot send scraped item.");
            state.processing_items.fetch_sub(1, Ordering::SeqCst); // Decrement if send fails
        }
    }
}

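/// Actor that receives scraped items and runs each through the configured
/// pipelines in order. A pipeline may pass the item on, drop it (`Ok(None)`),
/// or fail; pipeline errors are logged and stop further processing of that item.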
fn spawn_item_processor_actor<S>(
    state: Arc<EngineState>,
    item_rx: AsyncReceiver<S::Item>,
    pipelines: Arc<Vec<Box<dyn ItemPipeline<S::Item>>>>,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let mut tasks = JoinSet::new();
    tokio::spawn(async move {
        while let Ok(item) = item_rx.recv().await {
            let state_clone = state.clone();
            let pipelines_clone = pipelines.clone();
            tasks.spawn(async move {
                debug!("Processing item...");
                let mut item_to_process = Some(item);
                for pipeline in pipelines_clone.iter() {
                    if let Some(it) = item_to_process.take() {
                        match pipeline.process_item(it).await {
                            Ok(Some(next_item)) => item_to_process = Some(next_item),
                            Ok(None) => break, // Item was dropped
                            Err(e) => {
                                error!("Pipeline error for {}: {:?}", pipeline.name(), e);
                                break;
                            }
                        }
                    } else {
                        break;
                    }
                }
                state_clone.processing_items.fetch_sub(1, Ordering::SeqCst);
            });
        }
        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("An item processing task failed: {:?}", e);
            }
        }
    })
}