scraper/element_ref/
mod.rs

1//! Element references.
2
3use std::ops::Deref;
4
5use ego_tree::iter::{Edge, Traverse};
6use ego_tree::NodeRef;
7use fast_html5ever::serialize::{serialize, SerializeOpts, TraversalScope};
8
9use crate::node::Element;
10use crate::node::Node;
11use crate::selector::Selector;
12
13/// Wrapper around a reference to an element node.
14///
15/// This wrapper implements the `Element` trait from the `selectors` crate, which allows it to be
16/// matched against CSS selectors.
17#[derive(Debug, Clone, Copy, PartialEq, Eq)]
18pub struct ElementRef<'a> {
19    node: NodeRef<'a, Node>,
20    /// The language of the element. Not used atm.
21    pub lang: &'a str,
22}
23
24impl<'a> ElementRef<'a> {
25    fn new(node: NodeRef<'a, Node>) -> Self {
26        ElementRef { node, lang: "" }
27    }
28
29    /// Wraps a `NodeRef` only if it references a `Node::Element`.
30    pub fn wrap(node: NodeRef<'a, Node>) -> Option<Self> {
31        if node.value().is_element() {
32            Some(ElementRef::new(node))
33        } else {
34            None
35        }
36    }
37
38    /// Returns the `Element` referenced by `self`.
39    pub fn value(&self) -> &'a Element {
40        self.node.value().as_element().unwrap()
41    }
42
43    /// Returns an iterator over descendent elements matching a selector.
44    pub fn select<'b>(&self, selector: &'b Selector) -> Select<'a, 'b> {
45        let mut inner = self.traverse();
46        inner.next(); // Skip Edge::Open(self).
47
48        Select {
49            scope: *self,
50            inner,
51            selector,
52        }
53    }
54
55    fn serialize(&self, traversal_scope: TraversalScope) -> String {
56        let opts = SerializeOpts {
57            scripting_enabled: false, // It's not clear what this does.
58            traversal_scope,
59            create_missing_parent: false,
60        };
61        let mut buf = Vec::new();
62        let _ = serialize(&mut buf, self, opts);
63        // we need to get the initial encoding of the html lang if used.
64        auto_encoder::auto_encode_bytes(&buf)
65    }
66
67    /// Returns the HTML of this element.
68    pub fn html(&self) -> String {
69        self.serialize(TraversalScope::IncludeNode)
70    }
71
72    /// Returns the inner HTML of this element.
73    pub fn inner_html(&self) -> String {
74        self.serialize(TraversalScope::ChildrenOnly(None))
75    }
76
77    /// Returns the value of an attribute.
78    pub fn attr(&self, attr: &str) -> Option<&str> {
79        self.value().attr(attr)
80    }
81
82    /// Returns an iterator over descendent text nodes.
83    pub fn text(&self) -> Text<'a> {
84        Text {
85            inner: self.traverse(),
86        }
87    }
88}
89
90impl<'a> Deref for ElementRef<'a> {
91    type Target = NodeRef<'a, Node>;
92    fn deref(&self) -> &NodeRef<'a, Node> {
93        &self.node
94    }
95}
96
97/// Iterator over descendent elements matching a selector.
98#[derive(Debug, Clone)]
99pub struct Select<'a, 'b> {
100    scope: ElementRef<'a>,
101    inner: Traverse<'a, Node>,
102    selector: &'b Selector,
103}
104
105impl<'a, 'b> Iterator for Select<'a, 'b> {
106    type Item = ElementRef<'a>;
107
108    fn next(&mut self) -> Option<ElementRef<'a>> {
109        for edge in &mut self.inner {
110            if let Edge::Open(node) = edge {
111                if let Some(element) = ElementRef::wrap(node) {
112                    if self.selector.matches_with_scope(&element, Some(self.scope)) {
113                        return Some(element);
114                    }
115                }
116            }
117        }
118        None
119    }
120}
121
122/// Iterator over descendent text nodes.
123#[derive(Debug, Clone)]
124pub struct Text<'a> {
125    inner: Traverse<'a, Node>,
126}
127
128impl<'a> Iterator for Text<'a> {
129    type Item = &'a str;
130
131    fn next(&mut self) -> Option<&'a str> {
132        for edge in &mut self.inner {
133            if let Edge::Open(ref node) = edge {
134                // check if the element is not a script or link.
135                let processable = match node.parent() {
136                    Some(e) => {
137                        match e.value().as_element() {
138                            Some(n) => {
139                                let name = n.name();
140                                // prevent all script and style elements
141                                !(name == "script" || name == "style")
142                            }
143                            _ => true,
144                        }
145                    }
146                    _ => true,
147                };
148
149                if !processable {
150                    continue;
151                }
152
153                if let Node::Text(text) = node.value() {
154                    return Some(&**text);
155                }
156            }
157        }
158
159        None
160    }
161}
162
163mod element;
164mod serializable;
165
#[cfg(test)]
mod tests {
    use crate::html::Html;
    use crate::selector::Selector;

    #[test]
    fn test_scope() {
        let html = r"
            <div>
                <b>1</b>
                <span>
                    <span><b>2</b></span>
                    <b>3</b>
                </span>
            </div>
        ";
        let fragment = Html::parse_fragment(html);
        let sel1 = Selector::parse("div > span").unwrap();
        let sel2 = Selector::parse(":scope > b").unwrap();

        let element1 = fragment.select(&sel1).next().unwrap();
        let element2 = element1.select(&sel2).next().unwrap();
        assert_eq!(element2.inner_html(), "3");
    }

    /// `select` must only yield descendants, never the element it was called on.
    #[test]
    fn test_select_excludes_self() {
        let fragment = Html::parse_fragment("<div><div>inner</div></div>");
        let selector = Selector::parse("div").unwrap();

        let outer = fragment.select(&selector).next().unwrap();
        let matches = outer.select(&selector).collect::<Vec<_>>();

        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].inner_html(), "inner");
    }

    #[test]
    fn test_text() {
        let fragment = Html::parse_fragment("<h1>Hello, <i>world!</i></h1><script>window.var = true</script><style>.main { background: white };</style>");
        let selector = Selector::parse("h1").unwrap();

        let h1 = fragment.select(&selector).next().unwrap();
        let text = h1.text().collect::<Vec<_>>();

        assert_eq!(vec!["Hello, ", "world!"], text);
    }

    /// `text` must skip script/style content nested *inside* the selected element
    /// (`test_text` only has them as siblings, so it never exercises the filter).
    #[test]
    fn test_text_skips_nested_script_and_style() {
        let fragment = Html::parse_fragment(
            "<div>Hello, <script>window.var = true</script><style>.main { background: white };</style><i>world!</i></div>",
        );
        let selector = Selector::parse("div").unwrap();

        let div = fragment.select(&selector).next().unwrap();
        let text = div.text().collect::<Vec<_>>();

        assert_eq!(vec!["Hello, ", "world!"], text);
    }
}
201}