pickaxe/
document.rs

1use std::{
2    fmt::{Debug, Formatter},
3    sync::Arc,
4    collections::HashMap,
5};
6use ego_tree::{NodeId, NodeRef};
7use scraper::{ElementRef, Html, Node, Selector};
8
9use crate::xpath::{parse_xpath, NodeAccessor};
10use crate::errors::{Result, PackageError};
11
12
/// The result of an XPath query.
///
/// A query yields either a matched node or a string extracted from a matched node
/// (its text or one of its attribute values), depending on the accessor at the end
/// of the expression.
#[derive(Debug)]
pub enum XPathResult {
    /// A matched element, produced when the expression has no text/attribute accessor.
    Node(HtmlNode),
    /// A string value, produced by a text or attribute accessor.
    String(String),
}
19
20impl XPathResult {
21    /// Get the result as a string.
22    pub fn as_string(&self) -> Option<&str> {
23        match self {
24            XPathResult::String(s) => Some(s),
25            _ => None,
26        }
27    }
28
29    /// Get the result as a node.
30    pub fn as_node(&self) -> Option<&HtmlNode> {
31        match self {
32            XPathResult::Node(node) => Some(node),
33            _ => None,
34        }
35    }
36
37    /// Convert the result into a string.
38    pub fn into_string(self) -> Option<String> {
39        match self {
40            XPathResult::String(s) => Some(s),
41            _ => None,
42        }
43    }
44
45    /// Convert the result into a node.
46    pub fn into_node(self) -> Option<HtmlNode> {
47        match self {
48            XPathResult::Node(node) => Some(node),
49            _ => None,
50        }
51    }
52}
53
/// An HTML document that has been parsed into a virtual DOM.
///
/// Cloning is cheap with respect to the DOM (it is shared through an [`Arc`]), but the
/// raw source string is copied.
#[derive(Clone)]
pub struct HtmlDocument {
    // The HTML source text exactly as it was handed to the constructor.
    raw: String,
    // The parsed DOM, shared with every `HtmlNode` created from this document.
    dom: Arc<Html>,
    // Whether the source was parsed as a fragment rather than a full document;
    // this changes which node `root()` reports.
    is_fragment: bool,
}
61
// SAFETY: `HtmlDocument` exposes no way to mutate the DOM after construction; every
// method in this file only reads through the shared `Arc<Html>`. Any future API that
// modifies the DOM must wrap it in a synchronization primitive before these impls can
// remain sound.
// NOTE(review): this argument only covers what is visible in this file. It assumes that
// `Html` and the node data it stores perform no hidden interior mutation (for example
// non-atomic reference counting) when read, cloned, or dropped from several threads.
// Confirm against the `scraper` version in use.
// Don't try this at home kids.
unsafe impl Send for HtmlDocument {}
unsafe impl Sync for HtmlDocument {}
68
69impl HtmlDocument {
70    /// Create a new `HtmlDocument` from a string of HTML and a [`Html`] DOM.
71    /// 
72    /// * `raw` - The raw HTML string.
73    /// * `dom` - The parsed HTML DOM.
74    pub fn new(raw: String, dom: Html, is_fragment: bool) -> Self {
75        Self {
76            raw,
77            dom: Arc::new(dom),
78            is_fragment,
79        }
80    }
81
82    /// Parse an HTML document from a string.
83    /// 
84    /// * `html` - The HTML string to parse.
85    pub fn from_str(html: String) -> Self {
86        let is_fragment = !html.contains("html");
87        let dom = if is_fragment {
88            Html::parse_fragment(&html)
89        } else {
90            Html::parse_document(&html)
91        };
92
93        Self::new(html, dom, is_fragment)
94    }
95
96    /// The raw HTML string that was parsed.
97    pub fn raw(&self) -> &str {
98        &self.raw
99    }
100
101    /// Get the root node of the document.
102    pub fn root(&self) -> HtmlNode {
103        HtmlNode::new(
104            self.dom.clone(),
105            match self.is_fragment {
106                false => self.dom
107                            .root_element()
108                            .id(),
109                true => self.dom
110                            .root_element()
111                            .children()
112                            .next()
113                            .expect("no root element")
114                            .id(),
115            }
116        )
117    }
118
119    /// Query the document for matching elements using a CSS selector..
120    ///
121    /// * `selector` - The CSS selector to query for.
122    pub fn find_all(&self, selector: &str) -> Result<Vec<HtmlNode>> {
123        self.root()
124            .find_all(selector)
125    }
126
127    /// Query the document for matching elements using an XPath expression.
128    /// 
129    /// * `xpath` - The XPath expression to query for.
130    pub fn find_all_xpath(&self, xpath: &str) -> Result<Vec<XPathResult>> {
131        self.root()
132            .find_all_xpath(xpath)
133    }
134
135    /// Query the document for the first matching element using a CSS selector.
136    /// 
137    /// * `selector` - The CSS selector to query for.
138    pub fn find(&self, selector: &str) -> Result<Option<HtmlNode>> {
139        self.root()
140            .find(selector)
141    }
142
143    /// Query the document for the first matching element using an XPath expression.
144    /// 
145    /// * `xpath` - The XPath expression to query for.
146    pub fn find_xpath(&self, xpath: &str) -> Result<Option<XPathResult>> {
147        self.root()
148            .find_xpath(xpath)
149    }
150
151    /// Query the document for the nth matching element using a CSS selector.
152    /// 
153    /// * `selector` - The CSS selector to query for.
154    /// * `n` - The index of the element to get.
155    pub fn find_nth(&self, selector: &str, n: usize) -> Result<Option<HtmlNode>> {
156        self.root()
157            .find_nth(selector, n)
158    }
159
160    /// Query the document for the nth matching element using an XPath expression.
161    /// 
162    /// * `xpath` - The XPath expression to query for.
163    /// * `n` - The index of the element to get.
164    pub fn find_nth_xpath(&self, xpath: &str, n: usize) -> Result<Option<XPathResult>> {
165        self.root()
166            .find_nth_xpath(xpath, n)
167    }
168
169    /// Get the children of the document. This is the same as getting the children
170    /// of the root node.
171    pub fn children(&self) -> Vec<HtmlNode> {
172        self.root()
173            .children()
174    }
175}
176
177impl Debug for HtmlDocument {
178    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
179        write!(f, "<HtmlDocument is_fragment={:?}>", self.is_fragment)
180    }
181}
182
/// An HTML element in a document.
///
/// A node is a lightweight handle: it stores a shared reference to the whole DOM plus
/// the id of one node inside it, so cloning a node never copies the DOM.
#[derive(Clone)]
pub struct HtmlNode {
    // The DOM this node belongs to, shared with the owning document and sibling nodes.
    dom: Arc<Html>,
    // Identifier of the node inside `dom`'s tree; resolved on every access.
    node: NodeId,
}
189
// SAFETY: `HtmlNode` exposes no way to mutate the DOM; every method in this file only
// reads through the shared `Arc<Html>`. Any future API that modifies the DOM must wrap
// it in a synchronization primitive before these impls can remain sound.
// NOTE(review): this argument only covers what is visible in this file. It assumes that
// `Html` and the node data it stores perform no hidden interior mutation (for example
// non-atomic reference counting) when read, cloned, or dropped from several threads.
// Confirm against the `scraper` version in use.
// Don't try this at home kids.
unsafe impl Send for HtmlNode {}
unsafe impl Sync for HtmlNode {}
196
197impl HtmlNode {
198    /// Create a new [`HtmlNode`] from a [`Html`] DOM and a [`NodeId`].
199    /// 
200    /// * `dom` - The parsed HTML DOM.
201    /// * `node` - The HTML node.
202    pub fn new(dom: Arc<Html>, node: NodeId) -> Self {
203        Self { dom, node }
204    }
205
206    fn node(&self) -> Option<NodeRef<'_, Node>> {
207        self.dom
208            .tree
209            .get(self.node)
210    }
211
212    fn element(&self) -> Option<ElementRef<'_>> {
213        ElementRef::wrap(self.node()?)
214    }
215
216    /// Get the tag name of the node.
217    pub fn tag_name(&self) -> &str {
218        self.element()
219            .expect("element not found")
220            .value()
221            .name()
222    }
223
224    /// Find all elements matching a CSS selector.
225    /// 
226    /// * `selector` - The CSS selector to query for.
227    pub fn find_all(&self, selector: &str) -> Result<Vec<HtmlNode>> {
228        Ok(
229            self.element()
230                .expect("element not found")
231                .select(
232                    &Selector::parse(selector)
233                        .map_err(|e| PackageError::SelectorParseError(e.to_string()))?
234                )
235                .map(|element| HtmlNode::new(self.dom.clone(), element.id()))
236                .collect()
237        )
238    }
239
240    /// Find all elements matching an XPath expression.
241    /// 
242    /// * `xpath` - The XPath expression to query for.
243    pub fn find_all_xpath(&self, xpath: &str) -> Result<Vec<XPathResult>> {
244        fn resolve_accessor(node: &HtmlNode, accessor: &NodeAccessor) -> Option<XPathResult> {
245            match accessor {
246                NodeAccessor::Text { recursive } => {
247                    Some(XPathResult::String(if *recursive {
248                        node.inner_text()
249                    } else {
250                        node.text()
251                    }))
252                },
253                NodeAccessor::Attribute(name) => {
254                    Some(XPathResult::String(node
255                        .get_attribute(name.as_str())?
256                        .to_string()
257                    ))
258                },
259                _ => Some(XPathResult::Node(node.clone())),
260            }
261        }
262
263        match parse_xpath(xpath) {
264            Some((selector, accessor)) => {
265                match selector.as_str() {
266                    "" => Ok(
267                        resolve_accessor(self, &accessor)
268                            .into_iter()
269                            .collect()
270                    ),
271                    _ => Ok(
272                        self.find_all(&selector)?
273                            .into_iter()
274                            .filter_map(|node| resolve_accessor(&node, &accessor))
275                            .collect()
276                    ),
277                }
278            },
279            None => Ok(Vec::new()),
280        }
281    }
282
283    /// Find the first element matching a CSS selector.
284    /// 
285    /// * `selector` - The CSS selector to query for.
286    pub fn find(&self, selector: &str) -> Result<Option<HtmlNode>> {
287        Ok(
288            self.find_all(selector)?
289                .into_iter()
290                .next()
291        )
292    }
293
294    /// Find the first element matching an XPath expression.
295    /// 
296    /// * `xpath` - The XPath expression to query for.
297    pub fn find_xpath(&self, xpath: &str) -> Result<Option<XPathResult>> {
298        Ok(
299            self.find_all_xpath(xpath)?
300                .into_iter()
301                .next()
302        )
303    }
304
305    /// Find the nth element matching a CSS selector.
306    ///
307    /// * `selector` - The CSS selector to query for.
308    /// * `n` - The index of the element to get.
309    pub fn find_nth(&self, selector: &str, n: usize) -> Result<Option<HtmlNode>> {
310        Ok(
311            self.find_all(selector)?
312                .into_iter()
313                .nth(n)
314        )
315    }
316
317    /// Find the nth element matching an XPath expression.
318    /// 
319    /// * `xpath` - The XPath expression to query for.
320    /// * `n` - The index of the element to get.
321    pub fn find_nth_xpath(&self, xpath: &str, n: usize) -> Result<Option<XPathResult>> {
322        Ok(
323            self.find_all_xpath(xpath)?
324                .into_iter()
325                .nth(n)
326        )
327    }
328
329    /// Get the attributes of the node.
330    pub fn attributes(&self) -> HashMap<&str, Option<&str>> {
331        self.element()
332            .expect("element not found")
333            .value()
334            .attrs()
335            .map(|(k, v)| (k, Some(v)))
336            .collect()
337    }
338
339    /// Get the attribute of the node with the specified name.
340    ///
341    /// * `name` - The name of the attribute to get.
342    pub fn get_attribute(&self, name: &str) -> Option<&str> {
343        self.element()
344            .expect("element not found")
345            .value()
346            .attr(name)
347    }
348
349    /// Get the text content of the node.
350    pub fn text(&self) -> String {
351        self.element()
352            .expect("element not found")
353            .text()
354            .next()
355            .unwrap_or("")
356            .to_string()
357    }
358
359    /// Get the inner text of the node.
360    pub fn inner_text(&self) -> String {
361        self.element()
362            .expect("element not found")
363            .text()
364            .collect::<Vec<_>>()
365            .join("")
366    }
367
368    /// Get the inner HTML of the node.
369    pub fn inner_html(&self) -> String {
370        self.element()
371            .expect("element not found")
372            .inner_html()
373    }
374
375    /// Get the outer HTML of the node.
376    pub fn outer_html(&self) -> String {
377        self.element()
378            .expect("element not found")
379            .html()
380    }
381
382    /// Get the children of the node.
383    pub fn children(&self) -> Vec<HtmlNode> {
384        self.element()
385            .expect("element not found")
386            .children()
387            .filter_map(|child| child
388                .value()
389                .is_element()
390                .then(|| HtmlNode::new(self.dom.clone(), child.id()))
391            )
392            .collect()
393    }
394}
395
396impl Debug for HtmlNode {
397    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
398        write!(f, "<{}", self.tag_name())?;
399
400        for (key, value) in self.attributes() {
401            if let Some(val) = value {
402                write!(f, " {}=\"{}\"", key, val)?;
403            }
404        }
405
406        write!(f, ">")
407    }
408}
409
410
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixtures. Each constant holds exactly the markup the individual tests use.

    /// A document with a single heading.
    const SIMPLE: &str = "<html><head></head><body><h1>Hello, world!</h1></body></html>";
    /// A document whose heading carries an `id` attribute.
    const TITLED: &str =
        r#"<html><head></head><body><h1 id="title">Hello, world!</h1></body></html>"#;
    /// A document whose heading mixes direct text with a nested element.
    const NESTED: &str =
        "<html><head></head><body><h1>Hello, <span>world!</span></h1></body></html>";
    /// Like `NESTED`, but the heading also carries an `id` attribute.
    const TITLED_NESTED: &str =
        r#"<html><head></head><body><h1 id="title">Hello, <span>world!</span></h1></body></html>"#;
    /// A document with two sibling headings.
    const TWO_HEADINGS: &str =
        "<html><head></head><body><h1>Hello, world!</h1><h1>Not hello world</h1></body></html>";

    /// Parse `html` the same way library users do.
    fn parse(html: &str) -> HtmlDocument {
        HtmlDocument::from_str(html.to_string())
    }

    /// First element matching `selector`; fails the test if the query errors or is empty.
    fn select_one(doc: &HtmlDocument, selector: &str) -> HtmlNode {
        doc.find(selector).unwrap().unwrap()
    }

    /// The `<body>` element of the parsed `html`.
    fn body_of(html: &str) -> HtmlNode {
        select_one(&parse(html), "body")
    }

    /// Unwrap an XPath hit that must be a node.
    fn node_of(hit: Option<XPathResult>) -> HtmlNode {
        hit.unwrap().into_node().unwrap()
    }

    /// Unwrap an XPath hit that must be a string.
    fn string_of(hit: Option<XPathResult>) -> String {
        hit.unwrap().into_string().unwrap()
    }

    #[test]
    fn test_html_document_from_str() {
        let doc = parse(SIMPLE);

        assert_eq!(doc.raw, SIMPLE);
        assert_eq!(doc.root().tag_name(), "html");
    }

    #[test]
    fn test_html_document_fragment_from_str() {
        let doc = parse("<h1>Hello, world!</h1>");
        let root = doc.root();

        assert_eq!(root.tag_name(), "h1");
        assert_eq!(root.children().len(), 0);
        assert_eq!(root.inner_text(), "Hello, world!");
    }

    #[test]
    fn test_html_document_query_selector() {
        let doc = parse(
            r#"<html><head></head><body><h1 id="test">Hello, world!</h1><h1>Not hello world</h1></body></html>"#,
        );

        let heading = select_one(&doc, "h1#test");
        assert_eq!(heading.inner_text(), "Hello, world!");
    }

    #[test]
    fn test_html_node_text() {
        let heading = select_one(&parse(NESTED), "h1");
        assert_eq!(heading.text(), "Hello, ");
    }

    #[test]
    fn test_html_node_inner_text() {
        let root = parse(NESTED).root();
        assert_eq!(root.inner_text(), "Hello, world!");
    }

    #[test]
    fn test_html_node_inner_html() {
        let doc =
            parse("<html><head></head><body><h1>Hello,<span>world!</span></h1></body></html>");

        let heading = select_one(&doc, "h1");
        assert_eq!(heading.inner_html(), "Hello,<span>world!</span>");
    }

    #[test]
    fn test_html_node_outer_html() {
        let heading = select_one(&parse(SIMPLE), "h1");
        assert_eq!(heading.outer_html(), "<h1>Hello, world!</h1>");
    }

    #[test]
    fn test_html_node_get_attribute() {
        let heading = select_one(&parse(TITLED), "h1");
        assert_eq!(heading.get_attribute("id"), Some("title"));
    }

    #[test]
    fn test_html_node_children() {
        let body = body_of(
            "<html><head></head><body><h1>Hello, world!</h1><p>Paragraph</p></body></html>",
        );
        let kids = body.children();

        assert_eq!(kids.len(), 2);
        assert_eq!(kids[0].inner_text(), "Hello, world!");
        assert_eq!(kids[1].inner_text(), "Paragraph");
    }

    #[test]
    fn test_html_document_find_xpath() {
        let doc = parse(TITLED);

        let heading = node_of(doc.find_xpath("//h1[@id='title']").unwrap());
        assert_eq!(heading.inner_text(), "Hello, world!");
    }

    #[test]
    fn test_html_document_find_xpath_attribute() {
        let doc = parse(TITLED);

        let id = string_of(doc.find_xpath("//h1/@id").unwrap());
        assert_eq!(id, "title");
    }

    #[test]
    fn test_html_document_find_xpath_text() {
        let doc = parse(TITLED_NESTED);

        let text = string_of(doc.find_xpath("//h1/text()").unwrap());
        assert_eq!(text, "Hello, ");
    }

    #[test]
    fn test_html_document_find_xpath_inner_text() {
        let doc = parse(TITLED_NESTED);

        let text = string_of(doc.find_xpath("//h1//text()").unwrap());
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_html_document_find_xpath_bad() {
        let doc = parse(TITLED);

        let missing = doc.find_xpath("//h1[@id='title']/@src").unwrap();
        assert!(missing.is_none());
    }

    #[test]
    fn test_html_document_find_nth() {
        let doc = parse(TWO_HEADINGS);

        let second = doc.find_nth("h1", 1).unwrap().unwrap();
        assert_eq!(second.inner_text(), "Not hello world");
    }

    #[test]
    fn test_html_document_find_nth_selector() {
        let second = select_one(&parse(TWO_HEADINGS), "h1:nth-of-type(2)");
        assert_eq!(second.inner_text(), "Not hello world");
    }

    #[test]
    fn test_html_document_find_xpath_nth() {
        let doc = parse(TWO_HEADINGS);

        let second = node_of(doc.find_xpath("//h1[2]").unwrap());
        assert_eq!(second.inner_text(), "Not hello world");
    }

    #[test]
    fn test_html_node_find_all() {
        let headings = body_of(TWO_HEADINGS).find_all("h1").unwrap();

        assert_eq!(headings.len(), 2);
        assert_eq!(headings[0].inner_text(), "Hello, world!");
        assert_eq!(headings[1].inner_text(), "Not hello world");
    }

    #[test]
    fn test_html_node_find_all_xpath() {
        let hits = body_of(TWO_HEADINGS).find_all_xpath("//h1").unwrap();

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].as_node().unwrap().inner_text(), "Hello, world!");
        assert_eq!(hits[1].as_node().unwrap().inner_text(), "Not hello world");
    }

    #[test]
    fn test_html_node_find_nth() {
        let second = body_of(TWO_HEADINGS).find_nth("h1", 1).unwrap().unwrap();

        assert_eq!(second.inner_text(), "Not hello world");
    }

    #[test]
    fn test_html_node_find_nth_xpath() {
        let body = body_of(TWO_HEADINGS);
        let second = node_of(body.find_nth_xpath("//h1", 1).unwrap());

        assert_eq!(second.inner_text(), "Not hello world");
    }

    #[test]
    fn test_html_node_relative_find_xpath() {
        let body = body_of(TWO_HEADINGS);
        let second = node_of(body.find_xpath(".//h1[2]").unwrap());

        assert_eq!(second.inner_text(), "Not hello world");
    }

    #[test]
    fn test_html_node_relative_find_xpath_text_accessor() {
        let doc = parse(TWO_HEADINGS);

        let first = node_of(doc.find_xpath("//h1[1]").unwrap());
        let text = string_of(first.find_xpath("./text()").unwrap());

        assert_eq!(text, "Hello, world!");
    }
}