summa_core/components/tokenizers/html_tokenizer.rs

use std::cell::RefCell;
use std::collections::HashSet;

use super::tokenizer::TokenStream;

/// Tokenizer that strips HTML/XML markup and tokenizes the remaining text
/// by splitting on whitespaces and punctuation.
///
/// Content inside `ignored_tags` is dropped entirely, while `inlined_tags`
/// are removed but their content is glued to the surrounding text, so e.g.
/// `link<sup>1</sup>` is indexed as the single token `link1`.
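///
/// A minimal usage sketch; the exact re-export paths depend on the crate
/// layout, hence the example is not compiled as a doctest:
///
/// ```ignore
/// use std::collections::HashSet;
///
/// use tantivy::tokenizer::{TextAnalyzer, TokenStream};
///
/// let mut analyzer = TextAnalyzer::builder(HtmlTokenizer::new(
///     HashSet::from_iter(["formula".to_string()]), // ignored: content dropped
///     HashSet::from_iter(["sup".to_string()]),     // inlined: content glued
/// ))
/// .build();
/// let mut stream = analyzer.token_stream("hello <p>world</p>");
/// while stream.advance() {
///     println!("{:?}", stream.token());
/// }
/// ```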
#[derive(Clone)]
pub struct HtmlTokenizer {
    /// Tags whose entire content is excluded from indexing, e.g. `formula`.
    ignored_tags: HashSet<String>,
    /// Tags that are dropped while their content is glued to the
    /// surrounding text, e.g. `sup`.
    inlined_tags: HashSet<String>,
}

impl HtmlTokenizer {
    pub fn new(ignored_tags: HashSet<String>, inlined_tags: HashSet<String>) -> HtmlTokenizer {
        HtmlTokenizer { ignored_tags, inlined_tags }
    }
}

pub struct HtmlTokenStream<'a> {
    text: &'a str,
    html_tokenizer: xmlparser::Tokenizer<'a>,
    /// Word-level stream over the text span currently being emitted.
    current_nested_token_stream: TokenStream<'a>,
    ignored_tags: &'a HashSet<String>,
    inlined_tags: &'a HashSet<String>,
    /// Position of the last emitted token (starts at `usize::MAX`, i.e. before position 0).
    position: usize,
    /// Byte ranges of markup lying between glued text segments, handed to the
    /// nested stream so that token offsets stay correct.
    skip_list: RefCell<Option<Vec<(usize, usize)>>>,
    current_state: HtmlTokenizerState,
    /// Lookahead: the next markup token pulled from `xmlparser`.
    next_token: Option<Result<xmlparser::Token<'a>, xmlparser::Error>>,
}

impl HtmlTokenStream<'_> {
    /// Record a byte range of markup that the nested token stream must skip.
    pub fn add_new_skip(&self, start: usize, end: usize) {
        let mut skip_list = self.skip_list.borrow_mut();
        match skip_list.as_mut() {
            None => *skip_list = Some(vec![(start, end)]),
            Some(skip_list) => skip_list.push((start, end)),
        }
    }

    /// Hand the collected span `[start, end)` to the nested word tokenizer
    /// and switch to the `Emit` state, consuming the accumulated skip list.
    pub fn emit(&mut self, start: usize, end: usize) {
        self.current_nested_token_stream =
            TokenStream::new_with_offset_and_position(&self.text[start..end], start, self.position, self.skip_list.borrow_mut().take());
        self.current_state = HtmlTokenizerState::Emit;
    }

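    /// Consume markup until the element the caller just opened is closed,
    /// tracking nesting depth, then prefetch the token that follows it.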
    pub fn skip_tag(&mut self) {
        let mut depth = 1;
        while let Some(Ok(next_token)) = self.html_tokenizer.next() {
            match next_token {
                xmlparser::Token::ElementStart { .. } => {
                    depth += 1;
                }
                // Self-closing elements emit `ElementStart` followed by
                // `ElementEnd::Empty`, so balance the counter for them too.
                xmlparser::Token::ElementEnd {
                    end: xmlparser::ElementEnd::Close(..) | xmlparser::ElementEnd::Empty,
                    ..
                } => {
                    depth -= 1;
                }
                _ => {}
            }
            if depth == 0 {
                self.next_token = self.html_tokenizer.next();
                // Stop at the matching close; without this `break` the loop
                // would silently swallow the rest of the document.
                break;
            }
        }
    }
}

impl tantivy::tokenizer::Tokenizer for HtmlTokenizer {
    type TokenStream<'a> = HtmlTokenStream<'a>;

    fn token_stream<'a>(&'a mut self, text: &'a str) -> HtmlTokenStream<'a> {
        let html_tokenizer = xmlparser::Tokenizer::from_fragment(text, 0..text.len());
        HtmlTokenStream {
            text,
            html_tokenizer,
            current_nested_token_stream: TokenStream::new(""),
            ignored_tags: &self.ignored_tags,
            inlined_tags: &self.inlined_tags,
            // `usize::MAX` denotes "before the first token": the first
            // emitted token is assigned position 0.
            position: usize::MAX,
            skip_list: RefCell::default(),
            current_state: HtmlTokenizerState::BeginReading,
            next_token: None,
        }
    }
}

#[derive(Debug)]
enum CollectedToken {
    /// Nothing collected since the last emit.
    None,
    /// A contiguous byte range of text collected from the source.
    Ref { start: usize, end: usize },
}

#[derive(Debug)]
enum HtmlTokenizerState {
    /// Reset the skip list and prefetch the next markup token.
    BeginReading,
    /// Accumulate a text span, extending it across inlined tags.
    CollectToken { collected_token: CollectedToken },
    /// Drain the nested word-level token stream over the emitted span.
    Emit,
}

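// `advance` drives a three-state machine over the markup stream: BeginReading
// resets per-run state and prefetches a token; CollectToken accumulates a
// contiguous text span, dropping ignored elements and gluing across inlined
// ones; Emit drains the nested word tokenizer over the collected span.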
impl<'a> tantivy::tokenizer::TokenStream for HtmlTokenStream<'a> {
    fn advance(&mut self) -> bool {
        loop {
            match &self.current_state {
                HtmlTokenizerState::BeginReading => {
                    *self.skip_list.borrow_mut() = None;
                    self.next_token = self.html_tokenizer.next();
                    self.current_state = HtmlTokenizerState::CollectToken {
                        collected_token: CollectedToken::None,
                    };
                }
                HtmlTokenizerState::CollectToken { collected_token } => match self.next_token {
                    Some(next_token) => match next_token {
                        // Markup that carries no indexable text: skip to the next token.
                        Ok(xmlparser::Token::Declaration { .. })
                        | Ok(xmlparser::Token::ProcessingInstruction { .. })
                        | Ok(xmlparser::Token::Comment { .. })
                        | Ok(xmlparser::Token::DtdStart { .. })
                        | Ok(xmlparser::Token::EmptyDtd { .. })
                        | Ok(xmlparser::Token::DtdEnd { .. })
                        | Ok(xmlparser::Token::Attribute { .. })
                        | Ok(xmlparser::Token::Cdata { .. })
                        | Ok(xmlparser::Token::EntityDeclaration { .. })
                        | Ok(xmlparser::Token::ElementEnd {
                            end: xmlparser::ElementEnd::Open,
                            ..
                        })
                        | Ok(xmlparser::Token::ElementEnd {
                            end: xmlparser::ElementEnd::Empty,
                            ..
                        }) => {
                            self.next_token = self.html_tokenizer.next();
                        }
                        Ok(xmlparser::Token::ElementStart { local: start, .. }) => {
                            if self.ignored_tags.contains(start.as_str()) {
                                // Drop the whole element, tracking nesting depth
                                // until the matching close tag.
                                let mut depth = 1;
                                while let Some(Ok(next_token)) = self.html_tokenizer.next() {
                                    match next_token {
                                        xmlparser::Token::ElementStart { .. } => {
                                            depth += 1;
                                        }
                                        // Self-closing elements emit `ElementStart` followed by
                                        // `ElementEnd::Empty`, so balance the counter for them too.
                                        xmlparser::Token::ElementEnd {
                                            end: xmlparser::ElementEnd::Close(..) | xmlparser::ElementEnd::Empty,
                                            ..
                                        } => {
                                            depth -= 1;
                                        }
                                        _ => {}
                                    }
                                    if depth == 0 {
                                        break;
                                    }
                                }
                            } else if self.inlined_tags.contains(start.as_str()) {
                                // Consume the rest of the tag (attributes included) up to
                                // `>` or `/>`, then keep collecting text as if it were absent.
                                while let Some(Ok(next_token)) = self.html_tokenizer.next() {
                                    if let xmlparser::Token::ElementEnd {
                                        end: xmlparser::ElementEnd::Open | xmlparser::ElementEnd::Empty,
                                        ..
                                    } = next_token
                                    {
                                        break;
                                    }
                                }
                                self.next_token = self.html_tokenizer.next();
                                continue;
                            }
                            // Any other tag ends the current text run: flush it if non-empty.
                            match collected_token {
                                CollectedToken::None => self.current_state = HtmlTokenizerState::BeginReading,
                                CollectedToken::Ref { start, end } => self.emit(*start, *end),
                            }
                        }
                        Ok(xmlparser::Token::ElementEnd {
                            end: xmlparser::ElementEnd::Close(_, local),
                            ..
                        }) => {
                            if self.inlined_tags.contains(local.as_str()) {
                                self.next_token = self.html_tokenizer.next();
                                continue;
                            }
                            match collected_token {
                                CollectedToken::None => self.current_state = HtmlTokenizerState::BeginReading,
                                CollectedToken::Ref { start, end } => self.emit(*start, *end),
                            }
                        }
                        Ok(xmlparser::Token::Text { text }) => {
                            let new_collected_token = match collected_token {
                                CollectedToken::None => CollectedToken::Ref {
                                    start: text.start(),
                                    end: text.end(),
                                },
                                // Merge with the text collected so far; the gap between the
                                // two segments is markup and goes onto the skip list.
                                CollectedToken::Ref { start, end } => {
                                    if *end < text.start() {
                                        self.add_new_skip(*end, text.start());
                                    }
                                    CollectedToken::Ref {
                                        start: *start,
                                        end: text.end(),
                                    }
                                }
                            };
                            self.current_state = HtmlTokenizerState::CollectToken {
                                collected_token: new_collected_token,
                            };
                            self.next_token = self.html_tokenizer.next();
                        }
                        // On malformed markup, flush whatever has been collected so far.
                        Err(_) => match collected_token {
                            CollectedToken::None => self.current_state = HtmlTokenizerState::BeginReading,
                            CollectedToken::Ref { start, end } => self.emit(*start, *end),
                        },
                    },
                    // End of input: flush the tail, or signal exhaustion.
                    None => match collected_token {
                        CollectedToken::None => return false,
                        CollectedToken::Ref { start, end } => self.emit(*start, *end),
                    },
                },
                HtmlTokenizerState::Emit => {
                    if self.current_nested_token_stream.advance() {
                        self.position = self.current_nested_token_stream.token().position;
                        return true;
                    }
                    self.current_state = HtmlTokenizerState::BeginReading;
                }
            }
        }
    }

    fn token(&self) -> &tantivy::tokenizer::Token {
        self.current_nested_token_stream.token()
    }

    fn token_mut(&mut self) -> &mut tantivy::tokenizer::Token {
        self.current_nested_token_stream.token_mut()
    }
}

#[cfg(test)]
pub mod tests {
    use std::collections::HashSet;

    use tantivy::tokenizer::{LowerCaser, RemoveLongFilter, StopWordFilter, TextAnalyzer, Token, TokenizerManager};

    use super::HtmlTokenizer;
    use crate::components::tokenizers::tokenizer::tests::assert_tokenization;
    use crate::components::STOP_WORDS;

    #[test]
    fn test_html_tokenization() {
        let tokenizer_manager = TokenizerManager::default();
        tokenizer_manager.register(
            "tokenizer",
            TextAnalyzer::builder(HtmlTokenizer::new(
                HashSet::from_iter(["formula".to_string()]),
                HashSet::from_iter(["sup".to_string()]),
            ))
            .filter(RemoveLongFilter::limit(40))
            .filter(LowerCaser)
            .filter(StopWordFilter::remove(STOP_WORDS.map(String::from).to_vec()))
            .build(),
        );
        let mut tokenizer = tokenizer_manager.get("tokenizer").unwrap();
        let t_ref = &mut tokenizer;
        // Plain text without markup.
        assert_tokenization(
            t_ref,
            "Hello, world!",
            &[
                Token {
                    offset_from: 0,
                    offset_to: 5,
                    position: 0,
                    text: "hello".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 7,
                    offset_to: 12,
                    position: 1,
                    text: "world".to_string(),
                    position_length: 1,
                },
            ],
        );
        // Regular tags are stripped; offsets point into the original markup.
        assert_tokenization(
            t_ref,
            "<article>test1 <t2>test2 TEST3</t2></article>",
            &[
                Token {
                    offset_from: 9,
                    offset_to: 14,
                    position: 0,
                    text: "test1".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 19,
                    offset_to: 24,
                    position: 1,
                    text: "test2".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 25,
                    offset_to: 30,
                    position: 2,
                    text: "test3".to_string(),
                    position_length: 1,
                },
            ],
        );
        // Content of the ignored <formula> disappears, but positions keep
        // incrementing across the surrounding text runs.
        assert_tokenization(
            t_ref,
            "<article>test1 test2<p>link link2</p><formula>1 + 2</formula><p>link3 link4</p></article>",
            &[
                Token {
                    offset_from: 9,
                    offset_to: 14,
                    position: 0,
                    text: "test1".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 15,
                    offset_to: 20,
                    position: 1,
                    text: "test2".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 23,
                    offset_to: 27,
                    position: 2,
                    text: "link".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 28,
                    offset_to: 33,
                    position: 3,
                    text: "link2".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 64,
                    offset_to: 69,
                    position: 4,
                    text: "link3".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 70,
                    offset_to: 75,
                    position: 5,
                    text: "link4".to_string(),
                    position_length: 1,
                },
            ],
        );
        // The same holds with unclosed <p> tags.
        assert_tokenization(
            t_ref,
            "test1 test2<p>link link2<formula>1 + 2</formula><p>link3 link4",
            &[
                Token {
                    offset_from: 0,
                    offset_to: 5,
                    position: 0,
                    text: "test1".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 6,
                    offset_to: 11,
                    position: 1,
                    text: "test2".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 14,
                    offset_to: 18,
                    position: 2,
                    text: "link".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 19,
                    offset_to: 24,
                    position: 3,
                    text: "link2".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 51,
                    offset_to: 56,
                    position: 4,
                    text: "link3".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 57,
                    offset_to: 62,
                    position: 5,
                    text: "link4".to_string(),
                    position_length: 1,
                },
            ],
        );
        // And without any wrapper element at all.
        assert_tokenization(
            t_ref,
            "link link2<formula>1 + 2</formula>link3 link4",
            &[
                Token {
                    offset_from: 0,
                    offset_to: 4,
                    position: 0,
                    text: "link".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 5,
                    offset_to: 10,
                    position: 1,
                    text: "link2".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 34,
                    offset_to: 39,
                    position: 2,
                    text: "link3".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 40,
                    offset_to: 45,
                    position: 3,
                    text: "link4".to_string(),
                    position_length: 1,
                },
            ],
        );
        // <i> is neither ignored nor inlined here, so it terminates the
        // current text run; the four variants below exercise offsets with
        // and without whitespace around the tag.
        assert_tokenization(
            t_ref,
            "link link2<i>link</i>link3 link4",
            &[
                Token {
                    offset_from: 0,
                    offset_to: 4,
                    position: 0,
                    text: "link".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 5,
                    offset_to: 10,
                    position: 1,
                    text: "link2".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 13,
                    offset_to: 17,
                    position: 2,
                    text: "link".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 21,
                    offset_to: 26,
                    position: 3,
                    text: "link3".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 27,
                    offset_to: 32,
                    position: 4,
                    text: "link4".to_string(),
                    position_length: 1,
                },
            ],
        );
        assert_tokenization(
            t_ref,
            "link link2 <i>link</i>link3 link4",
            &[
                Token {
                    offset_from: 0,
                    offset_to: 4,
                    position: 0,
                    text: "link".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 5,
                    offset_to: 10,
                    position: 1,
                    text: "link2".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 14,
                    offset_to: 18,
                    position: 2,
                    text: "link".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 22,
                    offset_to: 27,
                    position: 3,
                    text: "link3".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 28,
                    offset_to: 33,
                    position: 4,
                    text: "link4".to_string(),
                    position_length: 1,
                },
            ],
        );
        assert_tokenization(
            t_ref,
            "link link2 <i>link</i> link3 link4",
            &[
                Token {
                    offset_from: 0,
                    offset_to: 4,
                    position: 0,
                    text: "link".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 5,
                    offset_to: 10,
                    position: 1,
                    text: "link2".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 14,
                    offset_to: 18,
                    position: 2,
                    text: "link".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 23,
                    offset_to: 28,
                    position: 3,
                    text: "link3".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 29,
                    offset_to: 34,
                    position: 4,
                    text: "link4".to_string(),
                    position_length: 1,
                },
            ],
        );
        assert_tokenization(
            t_ref,
            "link link2<i>link</i> link3 link4",
            &[
                Token {
                    offset_from: 0,
                    offset_to: 4,
                    position: 0,
                    text: "link".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 5,
                    offset_to: 10,
                    position: 1,
                    text: "link2".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 13,
                    offset_to: 17,
                    position: 2,
                    text: "link".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 22,
                    offset_to: 27,
                    position: 3,
                    text: "link3".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 28,
                    offset_to: 33,
                    position: 4,
                    text: "link4".to_string(),
                    position_length: 1,
                },
            ],
        );
        // Inlined <sup> is dropped and its content glued to the adjacent
        // text: "link<sup>1</sup>2" indexes as the single token "link12".
        assert_tokenization(
            t_ref,
            "link<sup>1</sup>2 link<sup>3</sup>",
            &[
                Token {
                    offset_from: 0,
                    offset_to: 17,
                    position: 0,
                    text: "link12".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 18,
                    offset_to: 28,
                    position: 1,
                    text: "link3".to_string(),
                    position_length: 1,
                },
            ],
        );
        // Attributes on an inlined tag are skipped along with the tag itself.
        assert_tokenization(
            t_ref,
            "link<sup attr=\"1\">1</sup>",
            &[Token {
                offset_from: 0,
                offset_to: 19,
                position: 0,
                text: "link1".to_string(),
                position_length: 1,
            }],
        );
        // A namespaced tag is matched by its local name ("p" here), which is
        // neither ignored nor inlined, so it behaves like a regular tag.
        assert_tokenization(
            t_ref,
            "link<mll:p attr=\"1\">1</mll:p>",
            &[
                Token {
                    offset_from: 0,
                    offset_to: 4,
                    position: 0,
                    text: "link".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 20,
                    offset_to: 21,
                    position: 1,
                    text: "1".to_string(),
                    position_length: 1,
                },
            ],
        );
        // Unclosed tags: collection simply runs to the end of the input.
        assert_tokenization(
            t_ref,
            "<p>test1 <sup>test2",
            &[
                Token {
                    offset_from: 3,
                    offset_to: 8,
                    position: 0,
                    text: "test1".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 14,
                    offset_to: 19,
                    position: 1,
                    text: "test2".to_string(),
                    position_length: 1,
                },
            ],
        );
        // Without whitespace, the unclosed inlined <sup> still glues the text.
        assert_tokenization(
            t_ref,
            "<p>test1<sup>test2",
            &[Token {
                offset_from: 3,
                offset_to: 18,
                position: 0,
                text: "test1test2".to_string(),
                position_length: 1,
            }],
        );
        // Malformed markup: only the text collected before the parse error is emitted.
        assert_tokenization(
            t_ref,
            "test1<p <b>>test2</b>",
            &[Token {
                offset_from: 0,
                offset_to: 5,
                position: 0,
                text: "test1".to_string(),
                position_length: 1,
            }],
        );
    }
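
    // A supplementary sketch (not part of the original suite): nested ignored
    // tags should be skipped as a whole thanks to the depth counter. The
    // expected offsets are computed by hand from the input string and assume
    // the nested word tokenizer splits on whitespace as in the cases above.
    #[test]
    fn test_nested_ignored_tags() {
        let tokenizer_manager = TokenizerManager::default();
        tokenizer_manager.register(
            "tokenizer",
            TextAnalyzer::builder(HtmlTokenizer::new(
                HashSet::from_iter(["formula".to_string()]),
                HashSet::from_iter(["sup".to_string()]),
            ))
            .filter(LowerCaser)
            .build(),
        );
        let mut tokenizer = tokenizer_manager.get("tokenizer").unwrap();
        assert_tokenization(
            &mut tokenizer,
            "alpha<formula>x<formula>y</formula>z</formula> beta",
            &[
                Token {
                    offset_from: 0,
                    offset_to: 5,
                    position: 0,
                    text: "alpha".to_string(),
                    position_length: 1,
                },
                Token {
                    offset_from: 47,
                    offset_to: 51,
                    position: 1,
                    text: "beta".to_string(),
                    position_length: 1,
                },
            ],
        );
    }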
}