Skip to main content

scrape_core/query/
selector.rs

1//! CSS selector parsing and matching via the `selectors` crate.
2//!
3//! This module provides integration with Mozilla's `selectors` crate for CSS selector
4//! parsing and element matching. The key types are:
5//!
6//! - [`ScrapeSelector`] - Marker type implementing [`selectors::SelectorImpl`]
7//! - [`ElementWrapper`] - Adapter implementing [`selectors::Element`] for our DOM
8
9use std::{
10    borrow::Borrow,
11    fmt,
12    hash::{Hash, Hasher},
13};
14
15use cssparser::ToCss;
16use selectors::{
17    Element, OpaqueElement, SelectorList,
18    attr::{AttrSelectorOperation, CaseSensitivity, NamespaceConstraint},
19    context::{MatchingForInvalidation, NeedsSelectorFlags, QuirksMode, SelectorCaches},
20    matching::{ElementSelectorFlags, MatchingContext, MatchingMode},
21    parser::{ParseRelative, Parser, SelectorImpl, SelectorParseErrorKind},
22};
23
24use super::error::{QueryError, QueryResult};
25use crate::dom::{Document, NodeId};
26
/// Owned CSS string value satisfying the trait bounds `selectors` places on its
/// string-like associated types (used here for attribute values and namespace URLs).
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
pub struct CssString(String);

impl CssString {
    /// Builds a `CssString` from anything convertible into an owned `String`.
    pub fn new(s: impl Into<String>) -> Self {
        let owned: String = s.into();
        Self(owned)
    }

    /// Borrows the wrapped text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
42
43impl From<&str> for CssString {
44    fn from(s: &str) -> Self {
45        Self(s.to_owned())
46    }
47}
48
49impl AsRef<str> for CssString {
50    fn as_ref(&self) -> &str {
51        &self.0
52    }
53}
54
55impl ToCss for CssString {
56    fn to_css<W>(&self, dest: &mut W) -> fmt::Result
57    where
58        W: fmt::Write,
59    {
60        cssparser::serialize_identifier(&self.0, dest)
61    }
62}
63
64impl Borrow<str> for CssString {
65    fn borrow(&self) -> &str {
66        &self.0
67    }
68}
69
70impl precomputed_hash::PrecomputedHash for CssString {
71    #[allow(clippy::cast_possible_truncation)]
72    fn precomputed_hash(&self) -> u32 {
73        use std::collections::hash_map::DefaultHasher;
74
75        let mut hasher = DefaultHasher::new();
76        self.0.hash(&mut hasher);
77        // Intentional truncation for hash value
78        hasher.finish() as u32
79    }
80}
81
/// A local name (tag name) that implements the traits required by `selectors`.
///
/// The same type is also used as the selector `Identifier` (ID and class names) and
/// as the namespace-prefix type.
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
pub struct CssLocalName(String);

impl CssLocalName {
    /// Creates a new local name, normalized to ASCII lowercase.
    ///
    /// Intended for building HTML tag names by hand; tag names are ASCII
    /// case-insensitive.
    pub fn new(s: impl Into<String>) -> Self {
        Self(s.into().to_ascii_lowercase())
    }

    /// Returns the underlying string.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

/// Converts a string slice **without** changing its case.
///
/// `selectors` creates every parsed identifier through this conversion, including
/// ID and class names. Those are compared case-sensitively in `has_id`/`has_class`
/// (no-quirks mode). Lowercasing here made `#Main` and `.Foo` unable to ever match
/// `id="Main"` / `class="Foo"`.
///
/// Tag and attribute names do not need lowercasing here: they are compared
/// ASCII-case-insensitively at match time (`has_local_name`, `attr_matches`,
/// `is_same_type`).
impl From<&str> for CssLocalName {
    fn from(s: &str) -> Self {
        Self(s.to_owned())
    }
}
103
104impl AsRef<str> for CssLocalName {
105    fn as_ref(&self) -> &str {
106        &self.0
107    }
108}
109
110impl ToCss for CssLocalName {
111    fn to_css<W>(&self, dest: &mut W) -> fmt::Result
112    where
113        W: fmt::Write,
114    {
115        dest.write_str(&self.0)
116    }
117}
118
119impl Borrow<str> for CssLocalName {
120    fn borrow(&self) -> &str {
121        &self.0
122    }
123}
124
125impl precomputed_hash::PrecomputedHash for CssLocalName {
126    #[allow(clippy::cast_possible_truncation)]
127    fn precomputed_hash(&self) -> u32 {
128        use std::collections::hash_map::DefaultHasher;
129
130        let mut hasher = DefaultHasher::new();
131        self.0.hash(&mut hasher);
132        // Intentional truncation for hash value
133        hasher.finish() as u32
134    }
135}
136
/// Zero-sized marker type that configures the `selectors` crate for this DOM.
///
/// It carries no data. Its [`SelectorImpl`] implementation names the concrete types
/// used for tag names, identifiers, attribute values and pseudo-classes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScrapeSelector;
143
/// Pseudo-class variants (non-tree-structural).
///
/// Only pseudo-classes that can be evaluated statically, without browser state
/// (hover, focus, visited history, ...), are supported.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NonTSPseudoClass {
    /// The `:link` pseudo-class. Matches `<a>`, `<area>` and `<link>` elements that
    /// have an `href` attribute. With no browsing history, every link is unvisited.
    Link,
    /// The `:any-link` pseudo-class. Matches the same elements as `:link`.
    AnyLink,
}
155
156impl selectors::parser::NonTSPseudoClass for NonTSPseudoClass {
157    type Impl = ScrapeSelector;
158
159    fn is_active_or_hover(&self) -> bool {
160        false
161    }
162
163    fn is_user_action_state(&self) -> bool {
164        false
165    }
166}
167
168impl ToCss for NonTSPseudoClass {
169    fn to_css<W>(&self, dest: &mut W) -> fmt::Result
170    where
171        W: fmt::Write,
172    {
173        match self {
174            Self::Link => dest.write_str(":link"),
175            Self::AnyLink => dest.write_str(":any-link"),
176        }
177    }
178}
179
/// Pseudo-elements (none are supported for matching).
///
/// The enum has no variants, so no value of this type can ever be constructed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PseudoElement {}
183
184impl selectors::parser::PseudoElement for PseudoElement {
185    type Impl = ScrapeSelector;
186}
187
188impl ToCss for PseudoElement {
189    fn to_css<W>(&self, _dest: &mut W) -> fmt::Result
190    where
191        W: fmt::Write,
192    {
193        // PseudoElement is an uninhabited type (no variants), so this is unreachable
194        unreachable!("PseudoElement has no variants")
195    }
196}
197
198impl SelectorImpl for ScrapeSelector {
199    type ExtraMatchingData<'a> = ();
200    type AttrValue = CssString;
201    type Identifier = CssLocalName;
202    type LocalName = CssLocalName;
203    type NamespaceUrl = CssString;
204    type NamespacePrefix = CssLocalName;
205    type BorrowedLocalName = CssLocalName;
206    type BorrowedNamespaceUrl = CssString;
207    type NonTSPseudoClass = NonTSPseudoClass;
208    type PseudoElement = PseudoElement;
209}
210
211/// Custom selector parser for our implementation.
212struct SelectorParser;
213
214impl<'i> Parser<'i> for SelectorParser {
215    type Impl = ScrapeSelector;
216    type Error = SelectorParseErrorKind<'i>;
217
218    fn parse_non_ts_pseudo_class(
219        &self,
220        location: cssparser::SourceLocation,
221        name: cssparser::CowRcStr<'i>,
222    ) -> Result<NonTSPseudoClass, cssparser::ParseError<'i, Self::Error>> {
223        match name.as_ref() {
224            "link" => Ok(NonTSPseudoClass::Link),
225            "any-link" => Ok(NonTSPseudoClass::AnyLink),
226            _ => Err(cssparser::ParseError {
227                kind: cssparser::ParseErrorKind::Custom(
228                    SelectorParseErrorKind::UnsupportedPseudoClassOrElement(name),
229                ),
230                location,
231            }),
232        }
233    }
234}
235
236/// Parses a CSS selector string into a compiled selector list.
237///
238/// # Errors
239///
240/// Returns [`QueryError::InvalidSelector`] if the selector syntax is invalid.
241///
242/// # Examples
243///
244/// ```rust
245/// use scrape_core::query::parse_selector;
246///
247/// let selectors = parse_selector("div.container > span").unwrap();
248/// ```
249pub fn parse_selector(selector: &str) -> QueryResult<SelectorList<ScrapeSelector>> {
250    let mut parser_input = cssparser::ParserInput::new(selector);
251    let mut parser = cssparser::Parser::new(&mut parser_input);
252
253    SelectorList::parse(&SelectorParser, &mut parser, ParseRelative::No).map_err(|e| {
254        // Sanitize error messages to expose only position info, avoiding potential
255        // information disclosure from internal parser state in public error messages.
256        QueryError::invalid_selector(format!(
257            "invalid selector at line {}, column {}",
258            e.location.line, e.location.column
259        ))
260    })
261}
262
263/// Adapter wrapping a DOM node for selector matching.
264///
265/// This type implements the [`selectors::Element`] trait, allowing our
266/// arena-based DOM to be matched against CSS selectors.
267#[derive(Debug, Clone, Copy)]
268pub struct ElementWrapper<'a> {
269    doc: &'a Document,
270    id: NodeId,
271}
272
273impl<'a> ElementWrapper<'a> {
274    /// Creates a new element wrapper.
275    #[must_use]
276    pub fn new(doc: &'a Document, id: NodeId) -> Self {
277        Self { doc, id }
278    }
279
280    /// Returns the node ID.
281    #[must_use]
282    pub fn node_id(&self) -> NodeId {
283        self.id
284    }
285
286    /// Returns a reference to the document.
287    #[must_use]
288    pub fn document(&self) -> &'a Document {
289        self.doc
290    }
291}
292
293impl PartialEq for ElementWrapper<'_> {
294    fn eq(&self, other: &Self) -> bool {
295        // Document equality via pointer comparison ensures elements from different documents
296        // are never considered equal, maintaining correctness for cross-document operations.
297        // NodeId equality alone is insufficient since different documents may have nodes
298        // with the same ID but different content.
299        std::ptr::eq(self.doc, other.doc) && self.id == other.id
300    }
301}
302
303impl Eq for ElementWrapper<'_> {}
304
305impl Element for ElementWrapper<'_> {
306    type Impl = ScrapeSelector;
307
308    fn opaque(&self) -> OpaqueElement {
309        OpaqueElement::new(self)
310    }
311
312    fn parent_element(&self) -> Option<Self> {
313        let parent_id = self.doc.parent(self.id)?;
314        let parent_node = self.doc.get(parent_id)?;
315        if parent_node.kind.is_element() { Some(Self::new(self.doc, parent_id)) } else { None }
316    }
317
318    fn parent_node_is_shadow_root(&self) -> bool {
319        false
320    }
321
322    fn containing_shadow_host(&self) -> Option<Self> {
323        None
324    }
325
326    fn is_pseudo_element(&self) -> bool {
327        false
328    }
329
330    fn prev_sibling_element(&self) -> Option<Self> {
331        let mut current = self.doc.prev_sibling(self.id);
332        while let Some(sibling_id) = current {
333            if let Some(node) = self.doc.get(sibling_id)
334                && node.kind.is_element()
335            {
336                return Some(Self::new(self.doc, sibling_id));
337            }
338            current = self.doc.prev_sibling(sibling_id);
339        }
340        None
341    }
342
343    fn next_sibling_element(&self) -> Option<Self> {
344        let mut current = self.doc.next_sibling(self.id);
345        while let Some(sibling_id) = current {
346            if let Some(node) = self.doc.get(sibling_id)
347                && node.kind.is_element()
348            {
349                return Some(Self::new(self.doc, sibling_id));
350            }
351            current = self.doc.next_sibling(sibling_id);
352        }
353        None
354    }
355
356    fn first_element_child(&self) -> Option<Self> {
357        for child_id in self.doc.children(self.id) {
358            if let Some(node) = self.doc.get(child_id)
359                && node.kind.is_element()
360            {
361                return Some(Self::new(self.doc, child_id));
362            }
363        }
364        None
365    }
366
367    fn is_html_element_in_html_document(&self) -> bool {
368        true
369    }
370
371    fn has_local_name(&self, local_name: &<Self::Impl as SelectorImpl>::BorrowedLocalName) -> bool {
372        self.doc
373            .get(self.id)
374            .and_then(|n| n.kind.tag_name())
375            .is_some_and(|name| name.eq_ignore_ascii_case(local_name.as_str()))
376    }
377
378    fn has_namespace(&self, _ns: &<Self::Impl as SelectorImpl>::BorrowedNamespaceUrl) -> bool {
379        // We don't track namespaces, so match everything
380        true
381    }
382
383    fn is_same_type(&self, other: &Self) -> bool {
384        self.doc
385            .get(self.id)
386            .and_then(|n| n.kind.tag_name())
387            .zip(other.doc.get(other.id).and_then(|n| n.kind.tag_name()))
388            .is_some_and(|(a, b)| a.eq_ignore_ascii_case(b))
389    }
390
391    fn attr_matches(
392        &self,
393        ns: &NamespaceConstraint<&<Self::Impl as SelectorImpl>::NamespaceUrl>,
394        local_name: &<Self::Impl as SelectorImpl>::BorrowedLocalName,
395        operation: &AttrSelectorOperation<&<Self::Impl as SelectorImpl>::AttrValue>,
396    ) -> bool {
397        // In HTML, we don't track namespaces, so we accept all namespace constraints
398        // - NamespaceConstraint::Any: matches any namespace (e.g., [*|href])
399        // - NamespaceConstraint::Specific: matches a specific namespace (we ignore since HTML has
400        //   no namespaces)
401        let _ = ns;
402
403        let Some(node) = self.doc.get(self.id) else { return false };
404        let Some(attrs) = node.kind.attributes() else { return false };
405
406        // HTML attribute names are case-insensitive
407        let attr_name = local_name.as_str();
408        let value = attrs.iter().find(|(k, _)| k.eq_ignore_ascii_case(attr_name)).map(|(_, v)| v);
409
410        let Some(value) = value else { return false };
411
412        operation.eval_str(value)
413    }
414
415    fn match_non_ts_pseudo_class(
416        &self,
417        pc: &NonTSPseudoClass,
418        _context: &mut MatchingContext<Self::Impl>,
419    ) -> bool {
420        match pc {
421            NonTSPseudoClass::Link | NonTSPseudoClass::AnyLink => {
422                // Match <a>, <area>, or <link> elements with href
423                let Some(node) = self.doc.get(self.id) else { return false };
424                let Some(tag_name) = node.kind.tag_name() else { return false };
425                let Some(attrs) = node.kind.attributes() else { return false };
426
427                matches!(tag_name, "a" | "area" | "link") && attrs.contains_key("href")
428            }
429        }
430    }
431
432    fn match_pseudo_element(
433        &self,
434        _pe: &PseudoElement,
435        _context: &mut MatchingContext<Self::Impl>,
436    ) -> bool {
437        // No pseudo-elements supported
438        false
439    }
440
441    fn is_link(&self) -> bool {
442        let Some(node) = self.doc.get(self.id) else { return false };
443        let Some(tag_name) = node.kind.tag_name() else { return false };
444        let Some(attrs) = node.kind.attributes() else { return false };
445
446        matches!(tag_name, "a" | "area" | "link") && attrs.contains_key("href")
447    }
448
449    fn is_html_slot_element(&self) -> bool {
450        false
451    }
452
453    fn has_id(
454        &self,
455        id: &<Self::Impl as SelectorImpl>::Identifier,
456        case_sensitivity: CaseSensitivity,
457    ) -> bool {
458        let Some(node) = self.doc.get(self.id) else { return false };
459        let Some(attrs) = node.kind.attributes() else { return false };
460        let Some(element_id) = attrs.get("id") else { return false };
461
462        case_sensitivity.eq(element_id.as_bytes(), id.as_str().as_bytes())
463    }
464
465    fn has_class(
466        &self,
467        name: &<Self::Impl as SelectorImpl>::Identifier,
468        case_sensitivity: CaseSensitivity,
469    ) -> bool {
470        let Some(node) = self.doc.get(self.id) else { return false };
471        let Some(attrs) = node.kind.attributes() else { return false };
472        let Some(class_attr) = attrs.get("class") else { return false };
473
474        class_attr
475            .split_whitespace()
476            .any(|class| case_sensitivity.eq(class.as_bytes(), name.as_str().as_bytes()))
477    }
478
479    fn imported_part(
480        &self,
481        _name: &<Self::Impl as SelectorImpl>::Identifier,
482    ) -> Option<<Self::Impl as SelectorImpl>::Identifier> {
483        None
484    }
485
486    fn is_part(&self, _name: &<Self::Impl as SelectorImpl>::Identifier) -> bool {
487        false
488    }
489
490    fn is_empty(&self) -> bool {
491        // Element is empty if it has no element or text children
492        for child_id in self.doc.children(self.id) {
493            if let Some(node) = self.doc.get(child_id) {
494                match &node.kind {
495                    crate::dom::NodeKind::Element { .. } => return false,
496                    crate::dom::NodeKind::Text { content } => {
497                        if !content.trim().is_empty() {
498                            return false;
499                        }
500                    }
501                    crate::dom::NodeKind::Comment { .. } => {}
502                }
503            }
504        }
505        true
506    }
507
508    fn is_root(&self) -> bool {
509        self.doc.root().is_some_and(|_root_id| {
510            // Walk up to find the html element
511            self.doc
512                .get(self.id)
513                .is_some_and(|node| node.kind.tag_name().is_some_and(|name| name == "html"))
514                && self.parent_element().is_none()
515        })
516    }
517
518    fn apply_selector_flags(&self, _flags: ElementSelectorFlags) {
519        // No-op: we don't need to track selector flags
520    }
521
522    fn add_element_unique_hashes(&self, _filter: &mut selectors::bloom::BloomFilter) -> bool {
523        false
524    }
525
526    fn has_custom_state(&self, _name: &<Self::Impl as SelectorImpl>::Identifier) -> bool {
527        false
528    }
529}
530
531/// Checks if an element matches a selector list.
532///
533/// This creates new [`SelectorCaches`] for each call. For batch operations
534/// (e.g., iterating over many elements), use [`matches_selector_with_caches`]
535/// to reuse caches and avoid allocation overhead.
536///
537/// # Examples
538///
539/// ```rust
540/// use scrape_core::{
541///     Html5everParser, Parser,
542///     query::{matches_selector, parse_selector},
543/// };
544///
545/// let parser = Html5everParser;
546/// let doc = parser.parse("<div class=\"foo\"><span id=\"bar\">text</span></div>").unwrap();
547/// let selectors = parse_selector("span#bar").unwrap();
548///
549/// // Find span element and check if it matches
550/// for (id, node) in doc.nodes() {
551///     if node.kind.tag_name() == Some("span") {
552///         assert!(matches_selector(&doc, id, &selectors));
553///     }
554/// }
555/// ```
556#[must_use]
557pub fn matches_selector(
558    doc: &Document,
559    id: NodeId,
560    selectors: &SelectorList<ScrapeSelector>,
561) -> bool {
562    let mut caches = SelectorCaches::default();
563    matches_selector_with_caches(doc, id, selectors, &mut caches)
564}
565
566/// Checks if an element matches a selector list, reusing provided caches.
567///
568/// This is more efficient than [`matches_selector`] when matching many elements
569/// against the same selector, as it avoids creating new [`SelectorCaches`]
570/// for each element.
571///
572/// # Examples
573///
574/// ```rust
575/// use scrape_core::{
576///     Html5everParser, Parser,
577///     query::{matches_selector_with_caches, parse_selector},
578/// };
579/// use selectors::context::SelectorCaches;
580///
581/// let parser = Html5everParser;
582/// let doc = parser.parse("<ul><li>A</li><li>B</li><li>C</li></ul>").unwrap();
583/// let selectors = parse_selector("li").unwrap();
584///
585/// // Reuse caches for efficiency when matching many elements
586/// let mut caches = SelectorCaches::default();
587/// let count = doc
588///     .nodes()
589///     .filter(|(id, n)| {
590///         n.kind.is_element() && matches_selector_with_caches(&doc, *id, &selectors, &mut caches)
591///     })
592///     .count();
593/// assert_eq!(count, 3);
594/// ```
595#[must_use]
596pub fn matches_selector_with_caches(
597    doc: &Document,
598    id: NodeId,
599    selectors: &SelectorList<ScrapeSelector>,
600    caches: &mut SelectorCaches,
601) -> bool {
602    let element = ElementWrapper::new(doc, id);
603    let mut context = MatchingContext::new(
604        MatchingMode::Normal,
605        None,
606        caches,
607        QuirksMode::NoQuirks,
608        NeedsSelectorFlags::No,
609        MatchingForInvalidation::No,
610    );
611
612    selectors.slice().iter().any(|selector| {
613        selectors::matching::matches_selector(selector, 0, None, &element, &mut context)
614    })
615}
616
617/// Checks if an element matches a selector list.
618///
619/// This is a convenience wrapper around [`matches_selector`] for use with `Tag::closest()`.
620///
621/// # Examples
622///
623/// ```rust
624/// use scrape_core::{
625///     Html5everParser, Parser,
626///     query::{matches_selector_list, parse_selector},
627/// };
628///
629/// let parser = Html5everParser;
630/// let doc = parser.parse("<div class='foo'><span id='bar'>text</span></div>").unwrap();
631/// let selectors = parse_selector("span#bar").unwrap();
632///
633/// // Find span element and check if it matches
634/// for (id, node) in doc.nodes() {
635///     if node.kind.tag_name() == Some("span") {
636///         assert!(matches_selector_list(&doc, id, &selectors));
637///     }
638/// }
639/// ```
640#[must_use]
641pub fn matches_selector_list(
642    doc: &Document,
643    id: NodeId,
644    selector_list: &SelectorList<ScrapeSelector>,
645) -> bool {
646    matches_selector(doc, id, selector_list)
647}
648
649#[cfg(test)]
650mod tests {
651    use super::*;
652    use crate::parser::{Html5everParser, Parser};
653
654    fn parse_doc(html: &str) -> Document {
655        Html5everParser.parse(html).unwrap()
656    }
657
658    fn find_element_by_tag(doc: &Document, tag: &str) -> Option<NodeId> {
659        doc.nodes().find(|(_, n)| n.kind.tag_name() == Some(tag)).map(|(id, _)| id)
660    }
661
662    #[test]
663    fn test_parse_simple_selector() {
664        let selectors = parse_selector("div").unwrap();
665        assert_eq!(selectors.slice().len(), 1);
666    }
667
668    #[test]
669    fn test_parse_class_selector() {
670        let selectors = parse_selector(".foo").unwrap();
671        assert_eq!(selectors.slice().len(), 1);
672    }
673
674    #[test]
675    fn test_parse_id_selector() {
676        let selectors = parse_selector("#bar").unwrap();
677        assert_eq!(selectors.slice().len(), 1);
678    }
679
680    #[test]
681    fn test_parse_compound_selector() {
682        let selectors = parse_selector("div.foo#bar").unwrap();
683        assert_eq!(selectors.slice().len(), 1);
684    }
685
686    #[test]
687    fn test_parse_descendant_combinator() {
688        let selectors = parse_selector("div span").unwrap();
689        assert_eq!(selectors.slice().len(), 1);
690    }
691
692    #[test]
693    fn test_parse_child_combinator() {
694        let selectors = parse_selector("div > span").unwrap();
695        assert_eq!(selectors.slice().len(), 1);
696    }
697
698    #[test]
699    fn test_parse_adjacent_sibling() {
700        let selectors = parse_selector("h1 + p").unwrap();
701        assert_eq!(selectors.slice().len(), 1);
702    }
703
704    #[test]
705    fn test_parse_general_sibling() {
706        let selectors = parse_selector("h1 ~ p").unwrap();
707        assert_eq!(selectors.slice().len(), 1);
708    }
709
710    #[test]
711    fn test_parse_attribute_exists() {
712        let selectors = parse_selector("[href]").unwrap();
713        assert_eq!(selectors.slice().len(), 1);
714    }
715
716    #[test]
717    fn test_parse_attribute_equals() {
718        let selectors = parse_selector("[type=\"text\"]").unwrap();
719        assert_eq!(selectors.slice().len(), 1);
720    }
721
722    #[test]
723    fn test_parse_multiple_selectors() {
724        let selectors = parse_selector("div, span, p").unwrap();
725        assert_eq!(selectors.slice().len(), 3);
726    }
727
728    #[test]
729    fn test_parse_invalid_selector() {
730        let result = parse_selector("[");
731        assert!(result.is_err());
732    }
733
734    #[test]
735    fn test_match_tag_selector() {
736        let doc = parse_doc("<div><span>text</span></div>");
737        let span_id = find_element_by_tag(&doc, "span").unwrap();
738        let selectors = parse_selector("span").unwrap();
739        assert!(matches_selector(&doc, span_id, &selectors));
740    }
741
742    #[test]
743    fn test_match_class_selector() {
744        let doc = parse_doc("<div class=\"foo bar\">text</div>");
745        let div_id = find_element_by_tag(&doc, "div").unwrap();
746
747        let selectors = parse_selector(".foo").unwrap();
748        assert!(matches_selector(&doc, div_id, &selectors));
749
750        let selectors = parse_selector(".bar").unwrap();
751        assert!(matches_selector(&doc, div_id, &selectors));
752
753        let selectors = parse_selector(".baz").unwrap();
754        assert!(!matches_selector(&doc, div_id, &selectors));
755    }
756
757    #[test]
758    fn test_match_id_selector() {
759        let doc = parse_doc("<div id=\"main\">text</div>");
760        let div_id = find_element_by_tag(&doc, "div").unwrap();
761
762        let selectors = parse_selector("#main").unwrap();
763        assert!(matches_selector(&doc, div_id, &selectors));
764
765        let selectors = parse_selector("#other").unwrap();
766        assert!(!matches_selector(&doc, div_id, &selectors));
767    }
768
769    #[test]
770    fn test_match_compound_selector() {
771        let doc = parse_doc("<div class=\"foo\" id=\"bar\">text</div>");
772        let div_id = find_element_by_tag(&doc, "div").unwrap();
773
774        let selectors = parse_selector("div.foo#bar").unwrap();
775        assert!(matches_selector(&doc, div_id, &selectors));
776
777        let selectors = parse_selector("div.foo#baz").unwrap();
778        assert!(!matches_selector(&doc, div_id, &selectors));
779    }
780
781    #[test]
782    fn test_match_attribute_exists() {
783        let doc = parse_doc("<a href=\"/page\">link</a>");
784        let a_id = find_element_by_tag(&doc, "a").unwrap();
785
786        // Verify we have the right element
787        let node = doc.get(a_id).unwrap();
788        let attrs = node.kind.attributes().unwrap();
789        assert!(attrs.contains_key("href"), "Element should have href attribute: {attrs:?}");
790
791        let selectors = parse_selector("[href]").unwrap();
792        assert_eq!(selectors.slice().len(), 1, "Should have one selector");
793        assert!(matches_selector(&doc, a_id, &selectors), "Element with href should match [href]");
794
795        let selectors = parse_selector("[title]").unwrap();
796        assert!(!matches_selector(&doc, a_id, &selectors));
797    }
798
799    #[test]
800    fn test_match_attribute_equals() {
801        let doc = parse_doc("<input type=\"text\">");
802        let input_id = find_element_by_tag(&doc, "input").unwrap();
803
804        let selectors = parse_selector("[type=\"text\"]").unwrap();
805        assert!(matches_selector(&doc, input_id, &selectors));
806
807        let selectors = parse_selector("[type=\"password\"]").unwrap();
808        assert!(!matches_selector(&doc, input_id, &selectors));
809    }
810
811    #[test]
812    fn test_element_is_empty() {
813        let doc = parse_doc("<div></div><span>text</span>");
814        let div_id = find_element_by_tag(&doc, "div").unwrap();
815        let span_id = find_element_by_tag(&doc, "span").unwrap();
816
817        let selectors = parse_selector(":empty").unwrap();
818        assert!(matches_selector(&doc, div_id, &selectors));
819        assert!(!matches_selector(&doc, span_id, &selectors));
820    }
821
822    #[test]
823    fn test_element_first_child() {
824        let doc = parse_doc("<ul><li>first</li><li>second</li></ul>");
825
826        // Find first li
827        let first_li =
828            doc.nodes().find(|(_, n)| n.kind.tag_name() == Some("li")).map(|(id, _)| id).unwrap();
829
830        let selectors = parse_selector("li:first-child").unwrap();
831        assert!(matches_selector(&doc, first_li, &selectors));
832    }
833
834    #[test]
835    fn test_match_not_selector() {
836        let doc = parse_doc("<div class=\"foo\">a</div><div class=\"bar\">b</div>");
837
838        let divs: Vec<_> = doc
839            .nodes()
840            .filter(|(_, n)| n.kind.tag_name() == Some("div"))
841            .map(|(id, _)| id)
842            .collect();
843
844        let selectors = parse_selector("div:not(.foo)").unwrap();
845
846        // Only the second div (with class="bar") should match
847        let match_count = divs.iter().filter(|id| matches_selector(&doc, **id, &selectors)).count();
848        assert_eq!(match_count, 1);
849    }
850
851    // ==================== Attribute Substring Selectors ====================
852
853    #[test]
854    fn test_match_attribute_prefix() {
855        let doc = parse_doc(
856            r#"<a href="https://example.com">secure</a><a href="http://example.com">insecure</a>"#,
857        );
858
859        let links: Vec<_> =
860            doc.nodes().filter(|(_, n)| n.kind.tag_name() == Some("a")).map(|(id, _)| id).collect();
861        assert_eq!(links.len(), 2);
862
863        let selectors = parse_selector("[href^=\"https\"]").unwrap();
864        let match_count =
865            links.iter().filter(|id| matches_selector(&doc, **id, &selectors)).count();
866        assert_eq!(match_count, 1, "[attr^=prefix] should match elements starting with prefix");
867    }
868
869    #[test]
870    fn test_match_attribute_suffix() {
871        let doc = parse_doc(r#"<a href="/page.html">html</a><a href="/page.pdf">pdf</a>"#);
872
873        let links: Vec<_> =
874            doc.nodes().filter(|(_, n)| n.kind.tag_name() == Some("a")).map(|(id, _)| id).collect();
875        assert_eq!(links.len(), 2);
876
877        let selectors = parse_selector("[href$=\".html\"]").unwrap();
878        let match_count =
879            links.iter().filter(|id| matches_selector(&doc, **id, &selectors)).count();
880        assert_eq!(match_count, 1, "[attr$=suffix] should match elements ending with suffix");
881    }
882
883    #[test]
884    fn test_match_attribute_contains() {
885        let doc = parse_doc(r#"<a href="/foo/bar/baz">yes</a><a href="/qux">no</a>"#);
886
887        let links: Vec<_> =
888            doc.nodes().filter(|(_, n)| n.kind.tag_name() == Some("a")).map(|(id, _)| id).collect();
889        assert_eq!(links.len(), 2);
890
891        let selectors = parse_selector("[href*=\"bar\"]").unwrap();
892        let match_count =
893            links.iter().filter(|id| matches_selector(&doc, **id, &selectors)).count();
894        assert_eq!(match_count, 1, "[attr*=substring] should match elements containing substring");
895    }
896
897    #[test]
898    fn test_match_attribute_word() {
899        let doc = parse_doc(r#"<div class="foo bar baz">yes</div><div class="foobar">no</div>"#);
900
901        let divs: Vec<_> = doc
902            .nodes()
903            .filter(|(_, n)| n.kind.tag_name() == Some("div"))
904            .map(|(id, _)| id)
905            .collect();
906        assert_eq!(divs.len(), 2);
907
908        let selectors = parse_selector("[class~=\"bar\"]").unwrap();
909        let match_count = divs.iter().filter(|id| matches_selector(&doc, **id, &selectors)).count();
910        assert_eq!(
911            match_count, 1,
912            "[attr~=word] should match elements with word in space-separated list"
913        );
914    }
915
916    #[test]
917    fn test_match_attribute_lang() {
918        let doc = parse_doc(
919            r#"<div lang="en-US">US</div><div lang="en-GB">GB</div><div lang="fr">FR</div>"#,
920        );
921
922        let divs: Vec<_> = doc
923            .nodes()
924            .filter(|(_, n)| n.kind.tag_name() == Some("div"))
925            .map(|(id, _)| id)
926            .collect();
927        assert_eq!(divs.len(), 3);
928
929        let selectors = parse_selector("[lang|=\"en\"]").unwrap();
930        let match_count = divs.iter().filter(|id| matches_selector(&doc, **id, &selectors)).count();
931        assert_eq!(match_count, 2, "[attr|=lang] should match 'en' and 'en-*' values");
932    }
933
934    // ==================== Pseudo-class Selectors ====================
935
936    #[test]
937    fn test_match_nth_child_even() {
938        let doc = parse_doc("<ul><li>1</li><li>2</li><li>3</li><li>4</li></ul>");
939
940        let lis: Vec<_> = doc
941            .nodes()
942            .filter(|(_, n)| n.kind.tag_name() == Some("li"))
943            .map(|(id, _)| id)
944            .collect();
945        assert_eq!(lis.len(), 4);
946
947        let selectors = parse_selector("li:nth-child(even)").unwrap();
948        let match_count = lis.iter().filter(|id| matches_selector(&doc, **id, &selectors)).count();
949        assert_eq!(match_count, 2, ":nth-child(even) should match 2nd and 4th elements");
950    }
951
952    #[test]
953    fn test_match_nth_child_2n_plus_1() {
954        let doc = parse_doc("<ul><li>1</li><li>2</li><li>3</li><li>4</li></ul>");
955
956        let lis: Vec<_> = doc
957            .nodes()
958            .filter(|(_, n)| n.kind.tag_name() == Some("li"))
959            .map(|(id, _)| id)
960            .collect();
961        assert_eq!(lis.len(), 4);
962
963        let selectors = parse_selector("li:nth-child(2n+1)").unwrap();
964        let match_count = lis.iter().filter(|id| matches_selector(&doc, **id, &selectors)).count();
965        assert_eq!(match_count, 2, ":nth-child(2n+1) should match odd elements (1st and 3rd)");
966    }
967
968    #[test]
969    fn test_match_last_child() {
970        let doc = parse_doc("<ul><li id=\"first\">1</li><li id=\"last\">2</li></ul>");
971
972        let lis: Vec<_> = doc
973            .nodes()
974            .filter(|(_, n)| n.kind.tag_name() == Some("li"))
975            .map(|(id, _)| id)
976            .collect();
977        assert_eq!(lis.len(), 2);
978
979        let selectors = parse_selector("li:last-child").unwrap();
980        let matches: Vec<_> =
981            lis.iter().filter(|id| matches_selector(&doc, **id, &selectors)).collect();
982        assert_eq!(matches.len(), 1, ":last-child should match exactly one element");
983
984        // Verify it's the last one
985        let last_id = matches[0];
986        let node = doc.get(*last_id).unwrap();
987        let attrs = node.kind.attributes().unwrap();
988        assert_eq!(attrs.get("id"), Some(&"last".to_string()));
989    }
990
991    // ==================== Sibling Combinator Selectors ====================
992
993    #[test]
994    fn test_match_adjacent_sibling() {
995        let doc = parse_doc("<h1>Title</h1><p>First paragraph</p><p>Second paragraph</p>");
996
997        let ps: Vec<_> =
998            doc.nodes().filter(|(_, n)| n.kind.tag_name() == Some("p")).map(|(id, _)| id).collect();
999        assert_eq!(ps.len(), 2);
1000
1001        let selectors = parse_selector("h1 + p").unwrap();
1002        let match_count = ps.iter().filter(|id| matches_selector(&doc, **id, &selectors)).count();
1003        assert_eq!(match_count, 1, "h1 + p should match only the immediately adjacent paragraph");
1004    }
1005
1006    #[test]
1007    fn test_match_general_sibling() {
1008        let doc = parse_doc("<h1>Title</h1><p>First</p><p>Second</p>");
1009
1010        let ps: Vec<_> =
1011            doc.nodes().filter(|(_, n)| n.kind.tag_name() == Some("p")).map(|(id, _)| id).collect();
1012        assert_eq!(ps.len(), 2);
1013
1014        let selectors = parse_selector("h1 ~ p").unwrap();
1015        let match_count = ps.iter().filter(|id| matches_selector(&doc, **id, &selectors)).count();
1016        assert_eq!(match_count, 2, "h1 ~ p should match all following sibling paragraphs");
1017    }
1018
1019    #[test]
1020    fn test_match_general_sibling_not_preceding() {
1021        let doc = parse_doc("<p>Before</p><h1>Title</h1><p>After</p>");
1022
1023        let ps: Vec<_> =
1024            doc.nodes().filter(|(_, n)| n.kind.tag_name() == Some("p")).map(|(id, _)| id).collect();
1025        assert_eq!(ps.len(), 2);
1026
1027        let selectors = parse_selector("h1 ~ p").unwrap();
1028        let match_count = ps.iter().filter(|id| matches_selector(&doc, **id, &selectors)).count();
1029        assert_eq!(match_count, 1, "h1 ~ p should not match paragraphs preceding h1");
1030    }
1031
1032    #[test]
1033    fn test_match_adjacent_sibling_requires_immediate() {
1034        let doc = parse_doc("<h1>Title</h1><div>Separator</div><p>Paragraph</p>");
1035
1036        let p_id = find_element_by_tag(&doc, "p").unwrap();
1037
1038        let selectors = parse_selector("h1 + p").unwrap();
1039        assert!(
1040            !matches_selector(&doc, p_id, &selectors),
1041            "h1 + p should not match when div is between them"
1042        );
1043    }
1044
1045    // ==================== matches_selector_with_caches ====================
1046
1047    #[test]
1048    fn test_matches_selector_with_caches() {
1049        let doc = parse_doc("<ul><li>A</li><li>B</li><li>C</li></ul>");
1050        let selectors = parse_selector("li").unwrap();
1051
1052        let mut caches = SelectorCaches::default();
1053        let count = doc
1054            .nodes()
1055            .filter(|(id, n)| {
1056                n.kind.is_element()
1057                    && matches_selector_with_caches(&doc, *id, &selectors, &mut caches)
1058            })
1059            .count();
1060        assert_eq!(count, 3);
1061    }
1062}