Skip to main content

scah_query_ir/query/compiler/
query.rs

1use std::ops::Range;
2
3use super::builder::{QueryBuilder, Save, SelectionKind};
4use super::error::SelectorParseError;
5use super::transition::Transition;
6use crate::query::selector::Combinator;
7
8pub trait QuerySpec<'query> {
9    fn states(&self) -> &[Transition<'query>];
10    fn queries(&self) -> &[QuerySection<'query>];
11    fn exit_at_section_end(&self) -> Option<usize>;
12
13    fn get_transition(&self, state: usize) -> &Transition<'query> {
14        &self.states()[state]
15    }
16
17    fn get_section_selection_kind(&self, section_index: usize) -> SelectionKind {
18        self.queries()[section_index].kind
19    }
20
21    fn get_selection(&self, section_index: usize) -> &QuerySection<'query> {
22        &self.queries()[section_index]
23    }
24
25    fn is_descendant(&self, state: usize) -> bool {
26        self.get_transition(state).guard == Combinator::Descendant
27    }
28
29    fn is_save_point(&self, position: &Position) -> bool {
30        debug_assert!(
31            self.queries()[position.selection]
32                .range
33                .contains(&position.state)
34        );
35        self.queries()[position.selection].range.end - 1 == position.state
36    }
37
38    fn is_last_save_point(&self, position: &Position) -> bool {
39        debug_assert!(position.selection < self.queries().len());
40        let is_last_query = self.queries().len() - 1 == position.selection;
41        let is_last_state = self.queries()[position.selection].range.end - 1 == position.state;
42        is_last_query && is_last_state
43    }
44}
45
/// A cursor into a compiled query: the active section (`selection`) and the
/// state index it currently points at in the shared transition table.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` (both fields are plain
/// `usize`), so positions can be stored in `HashSet`/`HashMap`.
#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
pub struct Position {
    /// Index of the active section in the query's section list.
    pub selection: usize,
    /// Index into the query's transition table; expected to lie inside the
    /// active section's `range`.
    pub state: usize,
}
51
52impl Position {
53    pub fn next_transition<'query, Q: QuerySpec<'query> + ?Sized>(
54        &self,
55        query: &Q,
56    ) -> Option<usize> {
57        debug_assert!(self.selection < query.queries().len());
58        debug_assert!(query.queries()[self.selection].range.contains(&self.state));
59
60        let selection_range = &query.queries()[self.selection].range;
61        if self.state + 1 < selection_range.end {
62            Some(self.state + 1)
63        } else {
64            None
65        }
66    }
67
68    pub fn next_child<'query, Q: QuerySpec<'query> + ?Sized>(&self, query: &Q) -> Option<Self> {
69        debug_assert!(self.selection < query.queries().len());
70        debug_assert!(query.queries()[self.selection].range.contains(&self.state));
71
72        if self.selection == query.queries().len() - 1 {
73            return None;
74        }
75
76        let next_selection_index = self.selection + 1;
77        let next_selection = &query.queries()[next_selection_index];
78        if next_selection.parent.is_some_and(|p| p == self.selection) {
79            return Some(Self {
80                selection: next_selection_index,
81                state: query.queries()[next_selection_index].range.start,
82            });
83        }
84
85        None
86    }
87
88    pub fn next_sibling<'query, Q: QuerySpec<'query> + ?Sized>(&self, query: &Q) -> Option<Self> {
89        debug_assert!(self.selection < query.queries().len());
90        debug_assert!(query.queries()[self.selection].range.contains(&self.state));
91
92        query.queries()[self.selection]
93            .next_sibling
94            .map(|sibling| Self {
95                selection: sibling,
96                state: query.queries()[sibling].range.start,
97            })
98    }
99
100    pub fn back<'query, Q: QuerySpec<'query> + ?Sized>(&mut self, query: &Q) {
101        debug_assert!(self.selection < query.queries().len());
102        debug_assert!(self.state < query.queries()[self.selection].range.end);
103
104        let selection = &query.queries()[self.selection];
105        if self.state > selection.range.start {
106            self.state -= 1;
107        } else if let Some(parent) = selection.parent {
108            self.selection = parent;
109            self.state = query.queries()[self.selection].range.end - 1;
110        }
111    }
112}
113
114#[derive(Debug, Clone, PartialEq)]
115pub struct QuerySection<'query> {
116    pub source: &'query str,
117    pub range: Range<usize>,
118    pub parent: Option<usize>,
119    pub next_sibling: Option<usize>,
120    pub save: Save,
121    pub kind: SelectionKind,
122}
123
124impl<'query> QuerySection<'query> {
125    pub fn new(
126        source: &'query str,
127        save: Save,
128        kind: SelectionKind,
129        range: Range<usize>,
130        parent: Option<usize>,
131    ) -> Self {
132        Self {
133            source,
134            save,
135            kind,
136            range,
137            parent,
138            next_sibling: None,
139        }
140    }
141
142    pub const fn new_const(
143        source: &'query str,
144        save: Save,
145        kind: SelectionKind,
146        range: Range<usize>,
147        parent: Option<usize>,
148        next_sibling: Option<usize>,
149    ) -> Self {
150        Self {
151            source,
152            save,
153            kind,
154            range,
155            parent,
156            next_sibling,
157        }
158    }
159}
160
161#[derive(Debug, PartialEq, Clone)]
162pub struct Query<'query> {
163    pub states: Box<[Transition<'query>]>,
164    pub queries: Box<[QuerySection<'query>]>,
165    pub exit_at_section_end: Option<usize>,
166}
167
168impl<'query> QuerySpec<'query> for Query<'query> {
169    fn states(&self) -> &[Transition<'query>] {
170        &self.states
171    }
172
173    fn queries(&self) -> &[QuerySection<'query>] {
174        &self.queries
175    }
176
177    fn exit_at_section_end(&self) -> Option<usize> {
178        self.exit_at_section_end
179    }
180}
181
182#[derive(Debug, PartialEq, Clone)]
183pub struct StaticQuery<'query, const N_STATES: usize, const N_SECTIONS: usize> {
184    pub states: [Transition<'query>; N_STATES],
185    pub queries: [QuerySection<'query>; N_SECTIONS],
186    pub exit_at_section_end: Option<usize>,
187}
188
189impl<'query, const N_STATES: usize, const N_SECTIONS: usize>
190    StaticQuery<'query, N_STATES, N_SECTIONS>
191{
192    pub const fn new(
193        states: [Transition<'query>; N_STATES],
194        queries: [QuerySection<'query>; N_SECTIONS],
195        exit_at_section_end: Option<usize>,
196    ) -> Self {
197        Self {
198            states,
199            queries,
200            exit_at_section_end,
201        }
202    }
203}
204
205impl<'query, const N_STATES: usize, const N_SECTIONS: usize> QuerySpec<'query>
206    for StaticQuery<'query, N_STATES, N_SECTIONS>
207{
208    fn states(&self) -> &[Transition<'query>] {
209        &self.states
210    }
211
212    fn queries(&self) -> &[QuerySection<'query>] {
213        &self.queries
214    }
215
216    fn exit_at_section_end(&self) -> Option<usize> {
217        self.exit_at_section_end
218    }
219}
220
221impl<'query> Query<'query> {
222    pub fn first(
223        query: &'query str,
224        save: Save,
225    ) -> Result<QueryBuilder<'query>, SelectorParseError> {
226        let states = Transition::generate_transitions_from_string(query)?;
227        let queries = vec![QuerySection::new(
228            query,
229            save,
230            SelectionKind::First,
231            0..states.len(),
232            None,
233        )];
234
235        Ok(QueryBuilder {
236            states,
237            selection: queries,
238        })
239    }
240
241    pub fn all(query: &'query str, save: Save) -> Result<QueryBuilder<'query>, SelectorParseError> {
242        let states = Transition::generate_transitions_from_string(query)?;
243        let queries = vec![QuerySection::new(
244            query,
245            save,
246            SelectionKind::All,
247            0..states.len(),
248            None,
249        )];
250
251        Ok(QueryBuilder {
252            states,
253            selection: queries,
254        })
255    }
256}
257
// Unit tests for building queries through `Query::first` / `Query::all` and
// the chained `QueryBuilder` API.
#[cfg(test)]
mod tests {
    use crate::query::compiler::transition::Transition;
    use crate::query::selector::AttributeSelection;
    use crate::query::selector::AttributeSelectionKind;
    use crate::query::selector::AttributeSelections;
    use crate::query::selector::ClassSelections;
    use crate::query::selector::Combinator;
    use crate::query::selector::ElementPredicate;
    use crate::{Query, QuerySection, Save, SelectionKind};

    // A single bare tag selector compiles to one descendant-guarded transition
    // and one root section of kind `All` covering state range 0..1.
    #[test]
    fn test_query_builder_one_selection() {
        let query = Query::all("a", Save::all()).unwrap().build();

        assert_eq!(
            query.states.iter().as_slice(),
            [Transition {
                predicate: ElementPredicate {
                    name: Some("a"),
                    id: None,
                    classes: ClassSelections::from_static(&[]),
                    attributes: AttributeSelections::from_static(&[])
                },
                guard: Combinator::Descendant,
            }]
        );

        assert_eq!(
            query.queries.iter().as_slice(),
            [QuerySection {
                source: "a",
                save: Save::all(),
                kind: SelectionKind::All,
                parent: None,
                range: 0..1,
                next_sibling: None,
            }]
        );
    }

    // Chaining `.all(...)` after `Query::first(...)` appends the second
    // selector's transitions after the first one's in the shared state table.
    // Only `states` is checked here; the section list is not asserted.
    #[test]
    fn test_query_builder_chainned_selection() {
        let query = Query::first("span", Save::all())
            .unwrap()
            .all("a", Save::all())
            .unwrap()
            .build();

        assert_eq!(
            query.states.iter().as_slice(),
            [
                Transition {
                    predicate: ElementPredicate {
                        name: Some("span"),
                        id: None,
                        classes: ClassSelections::from_static(&[]),
                        attributes: AttributeSelections::from_static(&[])
                    },
                    guard: Combinator::Descendant,
                },
                Transition {
                    predicate: ElementPredicate {
                        name: Some("a"),
                        id: None,
                        classes: ClassSelections::from_static(&[]),
                        attributes: AttributeSelections::from_static(&[])
                    },
                    guard: Combinator::Descendant,
                }
            ]
        );
    }

    // Compound selectors (tag + id + class + attribute-prefix) in a chain:
    // each compound selector yields exactly one state, and the second state's
    // predicate carries every parsed component.
    #[test]
    fn test_query_builder_chainned_multi_element_selection() {
        let query = Query::first("span#top.inner", Save::all())
            .unwrap()
            .all("a#link1.foo[href^=\"https\"]", Save::all())
            .unwrap()
            .build();

        assert_eq!(query.states.len(), 2);
        assert_eq!(query.queries.len(), 2);
        assert_eq!(
            query.states[1].predicate,
            ElementPredicate {
                name: Some("a"),
                id: Some("link1"),
                classes: ClassSelections::from_static(&["foo"]),
                attributes: AttributeSelections::from(vec![AttributeSelection {
                    name: "href",
                    value: Some("https"),
                    kind: AttributeSelectionKind::Prefix,
                }]),
            }
        );
    }

    // Branching with `.then(...)`: two child sections are created under the
    // root, and the first child's `next_sibling` links to the second, which
    // has no further sibling.
    #[test]
    fn test_query_builder_chainned_multi_element_selection_with_branching() {
        let query = Query::first("div", Save::all())
            .unwrap()
            .then(|ctx| {
                Ok([
                    ctx.all("a", Save::all())?,
                    ctx.first("p.note", Save::none())?,
                ])
            })
            .unwrap()
            .build();

        assert_eq!(query.queries.len(), 3);
        assert_eq!(query.queries[1].next_sibling, Some(2));
        assert_eq!(query.queries[2].next_sibling, None);
    }
}