1use std::convert::TryFrom;
4use std::fmt;
5
6pub use cssparser::ToCss;
7use html5ever::{LocalName, Namespace};
8use precomputed_hash::PrecomputedHash;
9use selectors::{
10 matching,
11 parser::{self, ParseRelative, SelectorList, SelectorParseErrorKind},
12};
13
14#[cfg(feature = "serde")]
15use serde::{de::Visitor, Deserialize, Serialize};
16
17use crate::error::SelectorErrorKind;
18use crate::ElementRef;
19
20#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct Selector {
25 selectors: SelectorList<Simple>,
27}
28
29impl Selector {
30 pub fn parse(selectors: &str) -> Result<Self, SelectorErrorKind<'_>> {
32 let mut parser_input = cssparser::ParserInput::new(selectors);
33 let mut parser = cssparser::Parser::new(&mut parser_input);
34
35 SelectorList::parse(&Parser, &mut parser, ParseRelative::No)
36 .map(|selectors| Self { selectors })
37 .map_err(SelectorErrorKind::from)
38 }
39
40 pub fn matches(&self, element: &ElementRef) -> bool {
42 self.matches_with_scope(element, None)
43 }
44
45 pub fn matches_with_scope(&self, element: &ElementRef, scope: Option<ElementRef>) -> bool {
49 self.matches_with_scope_and_cache(element, scope, &mut Default::default())
50 }
51
52 pub(crate) fn matches_with_scope_and_cache(
56 &self,
57 element: &ElementRef,
58 scope: Option<ElementRef>,
59 caches: &mut matching::SelectorCaches,
60 ) -> bool {
61 let mut context = matching::MatchingContext::new(
62 matching::MatchingMode::Normal,
63 None,
64 caches,
65 matching::QuirksMode::NoQuirks,
66 matching::NeedsSelectorFlags::No,
67 matching::MatchingForInvalidation::No,
68 );
69 context.scope_element = scope.map(|x| selectors::Element::opaque(&x));
70 self.selectors
71 .slice()
72 .iter()
73 .any(|s| matching::matches_selector(s, 0, None, element, &mut context))
74 }
75}
76
77impl ToCss for Selector {
78 fn to_css<W>(&self, dest: &mut W) -> fmt::Result
79 where
80 W: fmt::Write,
81 {
82 self.selectors.to_css(dest)
83 }
84}
85
/// Serialises a [`Selector`] as the string produced by its `ToCss` impl.
#[cfg(feature = "serde")]
impl Serialize for Selector {
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let css_text = self.to_css_string();
        serializer.serialize_str(css_text.as_str())
    }
}
92
/// Deserialises a [`Selector`] from a string, parsing it with
/// [`Selector::parse`] via `SelectorVisitor`.
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Selector {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        deserializer.deserialize_str(SelectorVisitor)
    }
}
99
/// Serde visitor that turns a string into a [`Selector`].
#[cfg(feature = "serde")]
struct SelectorVisitor;
102
#[cfg(feature = "serde")]
impl Visitor<'_> for SelectorVisitor {
    type Value = Selector;

    /// Describes the expected input for serde's error messages.
    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a css selector string")
    }

    /// Parses `v` as a selector group; a parse failure is reported through
    /// `E::custom` using the parse error's display text.
    fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
        match Selector::parse(v) {
            Ok(selector) => Ok(selector),
            Err(parse_err) => Err(E::custom(parse_err)),
        }
    }
}
115
116#[derive(Clone, Copy, Debug)]
118pub struct Parser;
119impl<'i> parser::Parser<'i> for Parser {
120 type Impl = Simple;
121 type Error = SelectorParseErrorKind<'i>;
122
123 fn parse_is_and_where(&self) -> bool {
124 true
125 }
126
127 fn parse_has(&self) -> bool {
128 true
129 }
130}
131
132#[derive(Debug, Clone, Copy, PartialEq, Eq)]
134pub struct Simple;
135
136impl parser::SelectorImpl for Simple {
137 type AttrValue = CssString;
138 type Identifier = CssLocalName;
139 type LocalName = CssLocalName;
140 type NamespacePrefix = CssLocalName;
141 type NamespaceUrl = Namespace;
142 type BorrowedNamespaceUrl = Namespace;
143 type BorrowedLocalName = CssLocalName;
144
145 type NonTSPseudoClass = NonTSPseudoClass;
146 type PseudoElement = PseudoElement;
147
148 type ExtraMatchingData<'a> = ();
150}
151
152#[derive(Debug, Clone, PartialEq, Eq)]
154pub struct CssString(pub String);
155
156impl<'a> From<&'a str> for CssString {
157 fn from(val: &'a str) -> Self {
158 Self(val.to_owned())
159 }
160}
161
162impl AsRef<str> for CssString {
163 fn as_ref(&self) -> &str {
164 &self.0
165 }
166}
167
168impl ToCss for CssString {
169 fn to_css<W>(&self, dest: &mut W) -> fmt::Result
170 where
171 W: fmt::Write,
172 {
173 cssparser::serialize_string(&self.0, dest)
174 }
175}
176
177#[derive(Debug, Default, Clone, PartialEq, Eq)]
179pub struct CssLocalName(pub LocalName);
180
181impl<'a> From<&'a str> for CssLocalName {
182 fn from(val: &'a str) -> Self {
183 Self(val.into())
184 }
185}
186
187impl ToCss for CssLocalName {
188 fn to_css<W>(&self, dest: &mut W) -> fmt::Result
189 where
190 W: fmt::Write,
191 {
192 dest.write_str(&self.0)
193 }
194}
195
196impl PrecomputedHash for CssLocalName {
197 fn precomputed_hash(&self) -> u32 {
198 self.0.precomputed_hash()
199 }
200}
201
202#[derive(Debug, Clone, Copy, PartialEq, Eq)]
204pub enum NonTSPseudoClass {}
205
206impl parser::NonTSPseudoClass for NonTSPseudoClass {
207 type Impl = Simple;
208
209 fn is_active_or_hover(&self) -> bool {
210 false
211 }
212
213 fn is_user_action_state(&self) -> bool {
214 false
215 }
216}
217
218impl ToCss for NonTSPseudoClass {
219 fn to_css<W>(&self, dest: &mut W) -> fmt::Result
220 where
221 W: fmt::Write,
222 {
223 dest.write_str("")
224 }
225}
226
227#[derive(Debug, Clone, Copy, PartialEq, Eq)]
229pub enum PseudoElement {}
230
231impl parser::PseudoElement for PseudoElement {
232 type Impl = Simple;
233}
234
235impl ToCss for PseudoElement {
236 fn to_css<W>(&self, dest: &mut W) -> fmt::Result
237 where
238 W: fmt::Write,
239 {
240 dest.write_str("")
241 }
242}
243
244impl<'i> TryFrom<&'i str> for Selector {
245 type Error = SelectorErrorKind<'i>;
246
247 fn try_from(s: &'i str) -> Result<Self, Self::Error> {
248 Selector::parse(s)
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryInto;

    #[test]
    fn selector_conversions() {
        let s = "#testid.testclass";
        let _sel: Selector = s.try_into().unwrap();

        let s = s.to_owned();
        let _sel: Selector = (*s).try_into().unwrap();
    }

    #[test]
    fn invalid_selector_conversions() {
        // Assert the parse failure directly: a bare `#[should_panic]` would
        // also pass on any unrelated panic.
        let s = "<failing selector>";
        let result: Result<Selector, _> = s.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn has_selector() {
        let s = ":has(a)";
        let _sel: Selector = s.try_into().unwrap();
    }

    #[test]
    fn is_selector() {
        let s = ":is(a)";
        let _sel: Selector = s.try_into().unwrap();
    }

    #[test]
    fn where_selector() {
        let s = ":where(a)";
        let _sel: Selector = s.try_into().unwrap();
    }

    #[test]
    fn to_css_round_trips() {
        // Serialising a selector and parsing the result must yield an equal
        // selector (this also underpins the serde impls).
        let original = Selector::parse("div.a > p, #b").unwrap();
        let css = original.to_css_string();
        let reparsed = Selector::parse(&css).unwrap();
        assert_eq!(original, reparsed);
    }
}