// syntaxdot_tokenizers/bert.rs

use std::convert::TryFrom;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;

use udgraph::graph::{Node, Sentence};
use wordpieces::WordPieces;

use super::{SentenceWithPieces, Tokenize};
use crate::TokenizerError;

/// BERT word piece tokenizer.
///
/// This tokenizer splits CoNLL-X tokens into word pieces. For
/// example, a sentence such as:
///
/// > Veruntreute die AWO Spendengeld ?
///
/// could be split (depending on the vocabulary) into the following
/// word pieces:
///
/// > Ver ##unt ##reute die A ##W ##O Spenden ##geld [UNK]
///
/// The vocabulary index of each such piece is returned.
///
/// The unknown token (here `[UNK]`) can be specified while
/// constructing a tokenizer.
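///
/// A minimal usage sketch. The vocabulary file name below is a
/// placeholder, and the `use` paths assume that `BertTokenizer` and the
/// `Tokenize` trait are re-exported at the crate root:
///
/// ```no_run
/// use std::iter::FromIterator;
///
/// use syntaxdot_tokenizers::{BertTokenizer, Tokenize};
/// use udgraph::graph::Sentence;
/// use udgraph::token::Token;
///
/// // Open a word piece vocabulary (placeholder file name); `[UNK]` is
/// // used for tokens that cannot be split into known pieces.
/// let tokenizer = BertTokenizer::open("vocab.txt", "[UNK]").unwrap();
///
/// // Build a sentence from plain token forms and split it into pieces.
/// let sentence = Sentence::from_iter(
///     ["Veruntreute", "die", "AWO", "Spendengeld", "?"]
///         .iter()
///         .map(|&form| Token::new(form)),
/// );
/// let with_pieces = tokenizer.tokenize(sentence);
/// ```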
pub struct BertTokenizer {
    word_pieces: WordPieces,
    unknown_piece: String,
}

impl BertTokenizer {
    /// Construct a tokenizer from word pieces and the unknown piece.
    pub fn new(word_pieces: WordPieces, unknown_piece: impl Into<String>) -> Self {
        BertTokenizer {
            word_pieces,
            unknown_piece: unknown_piece.into(),
        }
    }

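    /// Construct a tokenizer, loading the word piece vocabulary from a file.
    ///
    /// `unknown_piece` is the piece that tokens which cannot be split into
    /// known pieces are mapped to.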
    pub fn open<P>(model_path: P, unknown_piece: impl Into<String>) -> Result<Self, TokenizerError>
    where
        P: AsRef<Path>,
    {
        let model_path = model_path.as_ref();
        let f = File::open(model_path)
            .map_err(|err| TokenizerError::open_error(model_path.to_string_lossy(), err))?;
        Self::read(BufReader::new(f), unknown_piece)
    }

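    /// Construct a tokenizer, reading the word piece vocabulary from a
    /// buffered reader (one word piece per line).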
    pub fn read<R>(
        buf_read: R,
        unknown_piece: impl Into<String>,
    ) -> Result<BertTokenizer, TokenizerError>
    where
        R: BufRead,
    {
        let word_pieces = WordPieces::try_from(buf_read.lines())?;
        Ok(Self::new(word_pieces, unknown_piece))
    }
}

impl Tokenize for BertTokenizer {
    fn tokenize(&self, sentence: Sentence) -> SentenceWithPieces {
        // An average of three pieces per token ought to be enough for
        // everyone ;).
        let mut pieces = Vec::with_capacity((sentence.len() - 1) * 3);
        let mut token_offsets = Vec::with_capacity(sentence.len());

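        // Every sequence is prefixed with the special [CLS] piece, as
        // expected by BERT models.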
        pieces.push(
            self.word_pieces
                .get_initial("[CLS]")
                .expect("BERT model does not have a [CLS] token") as i64,
        );

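        // Split each token into word pieces, remembering at which piece
        // each token starts.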
        for token in sentence.iter().filter_map(Node::token) {
            token_offsets.push(pieces.len());

            match self
                .word_pieces
                .split(token.form())
                .map(|piece| piece.idx().map(|piece| piece as i64))
                .collect::<Option<Vec<_>>>()
            {
                Some(word_pieces) => pieces.extend(word_pieces),
                // A token that cannot be split into known pieces is
                // represented by the unknown piece.
                None => pieces.push(
                    self.word_pieces
                        .get_initial(&self.unknown_piece)
                        .expect("Cannot get unknown piece") as i64,
                ),
            }
        }

        SentenceWithPieces {
            pieces: pieces.into(),
            sentence,
            token_offsets,
        }
    }
}

#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::convert::TryFrom;
    use std::fs::File;
    use std::io::{BufRead, BufReader};
    use std::iter::FromIterator;

    use ndarray::array;
    use udgraph::graph::Sentence;
    use udgraph::token::Token;
    use wordpieces::WordPieces;

    use super::BertTokenizer;
    use crate::Tokenize;

    fn read_pieces() -> WordPieces {
        let f = File::open(env!("BERT_BASE_GERMAN_CASED_VOCAB")).unwrap();
        WordPieces::try_from(BufReader::new(f).lines()).unwrap()
    }

    fn sentence_from_forms(forms: &[&str]) -> Sentence {
        Sentence::from_iter(forms.iter().map(|&f| Token::new(f)))
    }

    #[test]
    fn test_pieces() {
        let tokenizer = BertTokenizer::new(read_pieces(), "[UNK]");

        let sentence = sentence_from_forms(&["Veruntreute", "die", "AWO", "Spendengeld", "?"]);

        let sentence_pieces = tokenizer.tokenize(sentence);
        assert_eq!(
            sentence_pieces.pieces,
            array![3i64, 133, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 26972]
        );
        assert_eq!(sentence_pieces.token_offsets, &[1, 4, 5, 8, 10]);
    }
}