// syntaxdot_tokenizers/bert.rs

use std::convert::TryFrom;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;

use udgraph::graph::{Node, Sentence};
use wordpieces::WordPieces;

use super::{SentenceWithPieces, Tokenize};
use crate::TokenizerError;

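/// Word piece tokenizer for BERT models.
///
/// This tokenizer splits the tokens of a sentence into word pieces from a
/// `WordPieces` vocabulary. The piece sequence starts with the `[CLS]` piece;
/// a token that cannot be split into known pieces is represented by the
/// unknown piece (e.g. `[UNK]`).
///
/// Illustrative use (the vocabulary file name is a placeholder):
///
/// ```ignore
/// let tokenizer = BertTokenizer::open("bert-vocab.txt", "[UNK]")?;
/// let pieces = tokenizer.tokenize(sentence);
/// ```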
pub struct BertTokenizer {
    word_pieces: WordPieces,
    unknown_piece: String,
}

impl BertTokenizer {
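    /// Construct a tokenizer from word pieces and the marker for unknown
    /// pieces (e.g. `[UNK]`).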
    pub fn new(word_pieces: WordPieces, unknown_piece: impl Into<String>) -> Self {
        BertTokenizer {
            word_pieces,
            unknown_piece: unknown_piece.into(),
        }
    }

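    /// Construct a tokenizer, reading the word piece vocabulary from the file
    /// at `model_path`.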
    pub fn open<P>(model_path: P, unknown_piece: impl Into<String>) -> Result<Self, TokenizerError>
    where
        P: AsRef<Path>,
    {
        let model_path = model_path.as_ref();
        let f = File::open(model_path)
            .map_err(|err| TokenizerError::open_error(model_path.to_string_lossy(), err))?;
        Self::read(BufReader::new(f), unknown_piece)
    }

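    /// Construct a tokenizer, reading the word piece vocabulary line by line
    /// from `buf_read`.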
    pub fn read<R>(
        buf_read: R,
        unknown_piece: impl Into<String>,
    ) -> Result<BertTokenizer, TokenizerError>
    where
        R: BufRead,
    {
        let word_pieces = WordPieces::try_from(buf_read.lines())?;
        Ok(Self::new(word_pieces, unknown_piece))
    }
}

impl Tokenize for BertTokenizer {
    fn tokenize(&self, sentence: Sentence) -> SentenceWithPieces {
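        // Reserve space for roughly three pieces per token; the sentence
        // length includes the root node, which carries no token.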
        let mut pieces = Vec::with_capacity((sentence.len() - 1) * 3);
        let mut token_offsets = Vec::with_capacity(sentence.len());

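        // Every piece sequence starts with the [CLS] piece.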
        pieces.push(
            self.word_pieces
                .get_initial("[CLS]")
                .expect("BERT model does not have a [CLS] token") as i64,
        );

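        // Tokenize each token, recording the offset of its first word piece.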
        for token in sentence.iter().filter_map(Node::token) {
            token_offsets.push(pieces.len());

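            // Split the token into known word pieces. If the token cannot be
            // fully split, fall back to the unknown piece.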
            match self
                .word_pieces
                .split(token.form())
                .map(|piece| piece.idx().map(|piece| piece as i64))
                .collect::<Option<Vec<_>>>()
            {
                Some(word_pieces) => pieces.extend(word_pieces),
                None => pieces.push(
                    self.word_pieces
                        .get_initial(&self.unknown_piece)
                        .expect("Cannot get unknown piece") as i64,
                ),
            }
        }

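        // Return the pieces together with the original sentence and the
        // per-token piece offsets.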
        SentenceWithPieces {
            pieces: pieces.into(),
            sentence,
            token_offsets,
        }
    }
}

#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::convert::TryFrom;
    use std::fs::File;
    use std::io::{BufRead, BufReader};
    use std::iter::FromIterator;

    use ndarray::array;
    use udgraph::graph::Sentence;
    use udgraph::token::Token;
    use wordpieces::WordPieces;

    use super::BertTokenizer;
    use crate::Tokenize;

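    // The vocabulary path is supplied at compile time through the
    // BERT_BASE_GERMAN_CASED_VOCAB environment variable.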
    fn read_pieces() -> WordPieces {
        let f = File::open(env!("BERT_BASE_GERMAN_CASED_VOCAB")).unwrap();
        WordPieces::try_from(BufReader::new(f).lines()).unwrap()
    }

    fn sentence_from_forms(forms: &[&str]) -> Sentence {
        Sentence::from_iter(forms.iter().map(|&f| Token::new(f)))
    }

    #[test]
    fn test_pieces() {
        let tokenizer = BertTokenizer::new(read_pieces(), "[UNK]");

        let sentence = sentence_from_forms(&["Veruntreute", "die", "AWO", "Spendengeld", "?"]);

        let sentence_pieces = tokenizer.tokenize(sentence);
        assert_eq!(
            sentence_pieces.pieces,
            array![3i64, 133, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 26972]
        );
        assert_eq!(sentence_pieces.token_offsets, &[1, 4, 5, 8, 10]);
    }
}