swiftide_integrations/scraping/
loader.rs

1use derive_builder::Builder;
2use spider::website::Website;
3
4use swiftide_core::{
5    Loader,
6    indexing::{IndexingStream, TextNode},
7};
8
9#[derive(Debug, Builder, Clone)]
10#[builder(pattern = "owned")]
11/// Scrapes a given website
12///
13/// Under the hood uses the `spider` crate to scrape the website.
14/// For more configuration options see their documentation.
15pub struct ScrapingLoader {
16    spider_website: Website,
17}
18
19impl ScrapingLoader {
20    pub fn builder() -> ScrapingLoaderBuilder {
21        ScrapingLoaderBuilder::default()
22    }
23
24    // Constructs a scrapingloader from a `spider::Website` configuration
25    #[allow(dead_code)]
26    pub fn from_spider(spider_website: Website) -> Self {
27        Self { spider_website }
28    }
29
30    /// Constructs a scrapingloader from a given url
31    pub fn from_url(url: impl AsRef<str>) -> Self {
32        Self::from_spider(Website::new(url.as_ref()))
33    }
34}
35
36impl Loader for ScrapingLoader {
37    type Output = String;
38
39    fn into_stream(mut self) -> IndexingStream<String> {
40        let (tx, rx) = tokio::sync::mpsc::channel(1000);
41        let mut spider_rx = self
42            .spider_website
43            .subscribe(0)
44            .expect("Failed to subscribe to spider");
45        tracing::info!("Subscribed to spider");
46
47        let _recv_thread = tokio::spawn(async move {
48            while let Ok(res) = spider_rx.recv().await {
49                let html = res.get_html();
50                let original_size = html.len();
51
52                let node = TextNode::builder()
53                    .chunk(html)
54                    .original_size(original_size)
55                    .path(res.get_url())
56                    .build();
57
58                tracing::debug!(?node, "[Spider] Received node from spider");
59
60                if let Err(error) = tx.send(node).await {
61                    tracing::error!(?error, "[Spider] Failed to send node to stream");
62                    break;
63                }
64            }
65        });
66
67        let mut spider_website = self.spider_website;
68
69        let _scrape_thread = tokio::spawn(async move {
70            tracing::info!("[Spider] Starting scrape loop");
71            // TODO: It would be much nicer if this used `scrape` instead, as it is supposedly
72            // more concurrent
73            spider_website.crawl().await;
74            tracing::info!("[Spider] Scrape loop finished");
75        });
76
77        // NOTE: Handles should stay alive because of rx, but feels a bit fishy
78        rx.into()
79    }
80
81    fn into_stream_boxed(self: Box<Self>) -> IndexingStream<String> {
82        self.into_stream()
83    }
84}
85
86#[cfg(test)]
87mod tests {
88    use super::*;
89    use anyhow::Result;
90    use futures_util::StreamExt;
91    use swiftide_core::indexing::Loader;
92    use wiremock::matchers::{method, path};
93    use wiremock::{Mock, MockServer, Request, ResponseTemplate};
94
95    #[test_log::test(tokio::test(flavor = "multi_thread"))]
96    async fn test_scraping_loader_with_wiremock() {
97        // Set up the wiremock server to simulate the remote web server
98        let mock_server = MockServer::start().await;
99
100        // Mocked response for the page we will scrape
101        let body = "<html><body><h1>Test Page</h1></body></html>";
102        Mock::given(method("GET"))
103            .and(path("/"))
104            .respond_with(ResponseTemplate::new(200).set_body_string(body))
105            .mount(&mock_server)
106            .await;
107
108        // Create an instance of ScrapingLoader using the mock server's URL
109        let loader = ScrapingLoader::from_url(mock_server.uri());
110
111        // Execute the into_stream method
112        let stream = loader.into_stream();
113
114        // Process the stream to check if we get the expected result
115        let nodes = stream.collect::<Vec<Result<TextNode>>>().await;
116
117        assert_eq!(nodes.len(), 1);
118
119        let first_node = nodes.first().unwrap().as_ref().unwrap();
120
121        assert_eq!(first_node.chunk, body);
122    }
123
124    #[test_log::test(tokio::test(flavor = "multi_thread"))]
125    async fn test_scraping_loader_multiple_pages() {
126        // Set up the wiremock server to simulate the remote web server
127        let mock_server = MockServer::start().await;
128
129        // Mocked response for the page we will scrape
130        let body = "<html><body><h1>Test Page</h1><a href=\"/other\">link</a></body></html>";
131        Mock::given(method("GET"))
132            .and(path("/"))
133            .respond_with(ResponseTemplate::new(200).set_body_string(body))
134            .mount(&mock_server)
135            .await;
136
137        let body2 = "<html><body><h1>Test Page 2</h1></body></html>";
138        Mock::given(method("GET"))
139            .and(path("/other"))
140            .respond_with(move |_req: &Request| {
141                std::thread::sleep(std::time::Duration::from_secs(1));
142                ResponseTemplate::new(200).set_body_string(body2)
143            })
144            .mount(&mock_server)
145            .await;
146
147        // Create an instance of ScrapingLoader using the mock server's URL
148        let loader = ScrapingLoader::from_url(mock_server.uri());
149
150        // Execute the into_stream method
151        let stream = loader.into_stream();
152
153        // Process the stream to check if we get the expected result
154        let mut nodes = stream.collect::<Vec<Result<TextNode>>>().await;
155
156        assert_eq!(nodes.len(), 2);
157
158        let first_node = nodes.pop().unwrap().unwrap();
159
160        assert_eq!(first_node.chunk, body2);
161
162        let second_node = nodes.pop().unwrap().unwrap();
163
164        assert_eq!(second_node.chunk, body);
165    }
166}