rust_tokenizers/vocab/
fnet_vocab.rs

1// Copyright 2018-2020 The HuggingFace Inc. team.
2// Copyright 2020 Marian Team Authors
3// Copyright 2019 Google LLC. All Rights Reserved.
4// Copyright 2019-2020 Guillaume Becquin
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//     http://www.apache.org/licenses/LICENSE-2.0
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::error::TokenizerError;
16use crate::vocab::base_vocab::{
17    read_protobuf_file, read_special_token_mapping_file, swap_key_values, SpecialTokenMap,
18};
19use crate::vocab::Vocab;
20use std::collections::HashMap;
21use std::path::Path;
22
/// # FNetVocab
/// Vocabulary for the FNet tokenizer. Alongside the base token/id mappings it tracks the
/// following special values:
/// - UNK token (always present, see `SpecialTokenMap::unk_token`)
/// - CLS token
/// - SEP token
/// - PAD token
/// - MASK token
///
/// Expects a SentencePiece BPE protobuf file when created from file.
#[derive(Debug, Clone)]
pub struct FNetVocab {
    /// A mapping of tokens as string to indices (i.e. the encoder base)
    pub values: HashMap<String, i64>,

    /// A mapping of token ids to strings (i.e. the decoder base)
    pub indices: HashMap<i64, String>,

    /// Special tokens used by the vocabulary (UNK is mandatory; the others are optional
    /// and fall back to the `DEFAULT_*` constants below when queried)
    pub special_token_map: SpecialTokenMap,

    /// A mapping of special value tokens as strings to IDs (i.e. the encoder base for special
    /// values), special values typically include things like BOS/EOS markers, class markers, mask
    /// markers and padding markers
    pub special_values: HashMap<String, i64>,

    /// A mapping of special value tokens as IDs to strings (i.e. the decoder base for special values)
    pub special_indices: HashMap<i64, String>,
}
50
// Default special-token strings used when no explicit special token mapping is supplied.
// Note the mixed conventions (SentencePiece-style `<...>` for UNK/PAD, BERT-style `[...]`
// for SEP/CLS/MASK) — presumably matching the upstream FNet pretrained configuration;
// verify against the reference tokenizer before changing.
const DEFAULT_UNK_TOKEN: &str = "<unk>";
const DEFAULT_PAD_TOKEN: &str = "<pad>";
const DEFAULT_SEP_TOKEN: &str = "[SEP]";
const DEFAULT_CLS_TOKEN: &str = "[CLS]";
const DEFAULT_MASK_TOKEN: &str = "[MASK]";
56
57impl FNetVocab {
58    pub fn get_pad_value(&self) -> &str {
59        self.special_token_map
60            .pad_token
61            .as_deref()
62            .unwrap_or(DEFAULT_PAD_TOKEN)
63    }
64
65    pub fn get_sep_value(&self) -> &str {
66        self.special_token_map
67            .sep_token
68            .as_deref()
69            .unwrap_or(DEFAULT_SEP_TOKEN)
70    }
71
72    pub fn get_cls_value(&self) -> &str {
73        self.special_token_map
74            .cls_token
75            .as_deref()
76            .unwrap_or(DEFAULT_CLS_TOKEN)
77    }
78
79    pub fn get_mask_value(&self) -> &str {
80        self.special_token_map
81            .mask_token
82            .as_deref()
83            .unwrap_or(DEFAULT_MASK_TOKEN)
84    }
85}
86
87impl Vocab for FNetVocab {
88    fn get_unknown_value(&self) -> &str {
89        &self.special_token_map.unk_token
90    }
91
92    fn values(&self) -> &HashMap<String, i64> {
93        &self.values
94    }
95
96    fn indices(&self) -> &HashMap<i64, String> {
97        &self.indices
98    }
99
100    fn special_values(&self) -> &HashMap<String, i64> {
101        &self.special_values
102    }
103
104    fn special_indices(&self) -> &HashMap<i64, String> {
105        &self.special_indices
106    }
107
108    fn values_mut(&mut self) -> &mut HashMap<String, i64> {
109        &mut self.values
110    }
111
112    fn indices_mut(&mut self) -> &mut HashMap<i64, String> {
113        &mut self.indices
114    }
115
116    fn special_values_mut(&mut self) -> &mut HashMap<String, i64> {
117        &mut self.special_values
118    }
119
120    fn special_indices_mut(&mut self) -> &mut HashMap<i64, String> {
121        &mut self.special_indices
122    }
123
124    fn from_file<P: AsRef<Path>>(path: P) -> Result<FNetVocab, TokenizerError> {
125        let values = read_protobuf_file(path)?;
126
127        let special_token_map = SpecialTokenMap {
128            unk_token: DEFAULT_UNK_TOKEN.to_string(),
129            pad_token: Some(DEFAULT_PAD_TOKEN.to_string()),
130            bos_token: None,
131            sep_token: Some(DEFAULT_SEP_TOKEN.to_string()),
132            cls_token: Some(DEFAULT_CLS_TOKEN.to_string()),
133            eos_token: None,
134            mask_token: Some(DEFAULT_MASK_TOKEN.to_string()),
135            additional_special_tokens: None,
136        };
137        Self::from_values_and_special_token_map(values, special_token_map)
138    }
139
140    fn from_file_with_special_token_mapping<P: AsRef<Path>, S: AsRef<Path>>(
141        path: P,
142        special_token_mapping_path: S,
143    ) -> Result<Self, TokenizerError> {
144        let values = read_protobuf_file(path)?;
145        let special_token_map = read_special_token_mapping_file(special_token_mapping_path)?;
146        Self::from_values_and_special_token_map(values, special_token_map)
147    }
148    fn from_values_and_special_token_map(
149        values: HashMap<String, i64>,
150        special_token_map: SpecialTokenMap,
151    ) -> Result<Self, TokenizerError>
152    where
153        Self: Sized,
154    {
155        let mut special_values = HashMap::new();
156        special_token_map.register_special_values(&values, &mut special_values)?;
157
158        let indices = swap_key_values(&values);
159        let special_indices = swap_key_values(&special_values);
160        Ok(Self {
161            values,
162            indices,
163            special_token_map,
164            special_values,
165            special_indices,
166        })
167    }
168    fn token_to_id(&self, token: &str) -> i64 {
169        self._token_to_id(
170            token,
171            &self.values,
172            &self.special_values,
173            self.get_unknown_value(),
174        )
175    }
176
177    fn id_to_token(&self, id: &i64) -> String {
178        self._id_to_token(
179            id,
180            &self.indices,
181            &self.special_indices,
182            self.get_unknown_value(),
183        )
184    }
185}