1use crate::{Error, Result};
2
/// Kind of a vocabulary entry, decoded from the `type` field (3) of a
/// `SentencePiece` message.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PieceType {
    Normal,
    Unknown,
    Control,
    UserDefined,
    Unused,
    Byte,
}

impl PieceType {
    /// Converts a raw wire value into a [`PieceType`].
    ///
    /// Values 2 through 6 select `Unknown`, `Control`, `UserDefined`,
    /// `Unused` and `Byte` respectively; every other value (including
    /// negative or out-of-range ones) falls back to `Normal`.
    pub(crate) fn from_i32(value: i32) -> Self {
        const KNOWN: [(i32, PieceType); 5] = [
            (2, PieceType::Unknown),
            (3, PieceType::Control),
            (4, PieceType::UserDefined),
            (5, PieceType::Unused),
            (6, PieceType::Byte),
        ];
        KNOWN
            .iter()
            .find(|&&(code, _)| code == value)
            .map_or(PieceType::Normal, |&(_, kind)| kind)
    }
}
25
/// Model family recorded in the `model_type` field (3) of a `TrainerSpec`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelType {
    Unigram,
    Bpe,
    Word,
    Char,
}

impl ModelType {
    /// Converts a raw wire value into a [`ModelType`].
    ///
    /// 2, 3 and 4 select `Bpe`, `Word` and `Char`; any other value
    /// (including negative or out-of-range ones) becomes `Unigram`.
    pub(crate) fn from_i32(value: i32) -> Self {
        const KNOWN: [(i32, ModelType); 3] = [
            (2, ModelType::Bpe),
            (3, ModelType::Word),
            (4, ModelType::Char),
        ];
        KNOWN
            .iter()
            .find(|&&(code, _)| code == value)
            .map_or(ModelType::Unigram, |&(_, kind)| kind)
    }
}
44
45#[derive(Clone, Debug, Default)]
46pub(crate) struct ModelProto {
47 pub(crate) pieces: Vec<SentencePiece>,
48 pub(crate) trainer_spec: TrainerSpec,
49 pub(crate) normalizer_spec: NormalizerSpec,
50 pub(crate) denormalizer_spec: Option<NormalizerSpec>,
51 pub(crate) self_test_data: SelfTestData,
52}
53
54#[derive(Clone, Debug)]
55pub(crate) struct SentencePiece {
56 pub(crate) piece: String,
57 pub(crate) score: f32,
58 pub(crate) kind: PieceType,
59}
60
61#[derive(Clone, Debug)]
62pub(crate) struct TrainerSpec {
63 pub(crate) model_type: ModelType,
64 pub(crate) treat_whitespace_as_suffix: bool,
65 pub(crate) allow_whitespace_only_pieces: bool,
66 pub(crate) byte_fallback: bool,
67 pub(crate) unk_id: i32,
68 pub(crate) bos_id: i32,
69 pub(crate) eos_id: i32,
70 pub(crate) pad_id: i32,
71 pub(crate) unk_surface: String,
72 pub(crate) unk_piece: String,
73 pub(crate) bos_piece: String,
74 pub(crate) eos_piece: String,
75 pub(crate) pad_piece: String,
76}
77
/// Subset of the SentencePiece `NormalizerSpec` message.
///
/// The same type describes both the normalizer and the optional
/// denormalizer of a model.
#[derive(Debug, Clone)]
pub(crate) struct NormalizerSpec {
    /// Field 1; empty when absent.
    pub(crate) name: String,
    /// Raw bytes of field 2, kept undecoded; empty when absent.
    pub(crate) precompiled_charsmap: Vec<u8>,
    /// Field 3.
    pub(crate) add_dummy_prefix: bool,
    /// Field 4.
    pub(crate) remove_extra_whitespaces: bool,
    /// Field 5.
    pub(crate) escape_whitespaces: bool,
}
86
/// Self-test section of a model (field 4 of `ModelProto`).
#[derive(Debug, Clone, Default)]
pub(crate) struct SelfTestData {
    /// Samples (repeated field 1), in encoded order.
    pub(crate) samples: Vec<SelfTestSample>,
}

/// One self-test sample: an input string and the string recorded as
/// `expected` for it.
#[derive(Debug, Clone)]
pub(crate) struct SelfTestSample {
    /// Field 1; empty when absent.
    pub(crate) input: String,
    /// Field 2; empty when absent.
    pub(crate) expected: String,
}
97
98impl Default for TrainerSpec {
99 fn default() -> Self {
100 Self {
101 model_type: ModelType::Unigram,
102 treat_whitespace_as_suffix: false,
103 allow_whitespace_only_pieces: false,
104 byte_fallback: false,
105 unk_id: 0,
106 bos_id: 1,
107 eos_id: 2,
108 pad_id: -1,
109 unk_surface: crate::util::DEFAULT_UNKNOWN_SURFACE.to_owned(),
110 unk_piece: "<unk>".to_owned(),
111 bos_piece: "<s>".to_owned(),
112 eos_piece: "</s>".to_owned(),
113 pad_piece: "<pad>".to_owned(),
114 }
115 }
116}
117
118impl Default for NormalizerSpec {
119 fn default() -> Self {
120 Self {
121 name: String::new(),
122 precompiled_charsmap: Vec::new(),
123 add_dummy_prefix: true,
124 remove_extra_whitespaces: true,
125 escape_whitespaces: true,
126 }
127 }
128}
129
130impl ModelProto {
131 pub(crate) fn decode(bytes: &[u8]) -> Result<Self> {
132 let mut proto = Self::default();
133 let mut reader = ProtoReader::new(bytes);
134 while let Some((field, wire)) = reader.read_key()? {
135 match field {
136 1 if wire == WireType::LengthDelimited => {
137 let bytes = reader.read_len()?;
138 proto.pieces.push(decode_sentence_piece(bytes)?);
139 }
140 2 if wire == WireType::LengthDelimited => {
141 proto.trainer_spec = decode_trainer_spec(reader.read_len()?)?;
142 }
143 3 if wire == WireType::LengthDelimited => {
144 proto.normalizer_spec = decode_normalizer_spec(reader.read_len()?)?;
145 }
146 4 if wire == WireType::LengthDelimited => {
147 proto.self_test_data = decode_self_test_data(reader.read_len()?)?;
148 }
149 5 if wire == WireType::LengthDelimited => {
150 proto.denormalizer_spec = Some(decode_normalizer_spec(reader.read_len()?)?);
151 }
152 _ => reader.skip(wire)?,
153 }
154 }
155 Ok(proto)
156 }
157}
158
159fn decode_sentence_piece(bytes: &[u8]) -> Result<SentencePiece> {
160 let mut piece = SentencePiece {
161 piece: String::new(),
162 score: 0.0,
163 kind: PieceType::Normal,
164 };
165 let mut reader = ProtoReader::new(bytes);
166 while let Some((field, wire)) = reader.read_key()? {
167 match field {
168 1 if wire == WireType::LengthDelimited => {
169 piece.piece = reader.read_string()?;
170 }
171 2 if wire == WireType::ThirtyTwoBit => {
172 piece.score = reader.read_f32()?;
173 }
174 3 if wire == WireType::Varint => {
175 piece.kind = PieceType::from_i32(reader.read_varint()? as i32);
176 }
177 _ => reader.skip(wire)?,
178 }
179 }
180 Ok(piece)
181}
182
183fn decode_trainer_spec(bytes: &[u8]) -> Result<TrainerSpec> {
184 let mut spec = TrainerSpec::default();
185 let mut reader = ProtoReader::new(bytes);
186 while let Some((field, wire)) = reader.read_key()? {
187 match field {
188 3 if wire == WireType::Varint => {
189 spec.model_type = ModelType::from_i32(reader.read_varint()? as i32);
190 }
191 24 if wire == WireType::Varint => {
192 spec.treat_whitespace_as_suffix = reader.read_bool()?;
193 }
194 26 if wire == WireType::Varint => {
195 spec.allow_whitespace_only_pieces = reader.read_bool()?;
196 }
197 35 if wire == WireType::Varint => {
198 spec.byte_fallback = reader.read_bool()?;
199 }
200 40 if wire == WireType::Varint => {
201 spec.unk_id = reader.read_varint()? as i32;
202 }
203 41 if wire == WireType::Varint => {
204 spec.bos_id = reader.read_varint()? as i32;
205 }
206 42 if wire == WireType::Varint => {
207 spec.eos_id = reader.read_varint()? as i32;
208 }
209 43 if wire == WireType::Varint => {
210 spec.pad_id = reader.read_varint()? as i32;
211 }
212 44 if wire == WireType::LengthDelimited => {
213 spec.unk_surface = reader.read_string()?;
214 }
215 45 if wire == WireType::LengthDelimited => {
216 spec.unk_piece = reader.read_string()?;
217 }
218 46 if wire == WireType::LengthDelimited => {
219 spec.bos_piece = reader.read_string()?;
220 }
221 47 if wire == WireType::LengthDelimited => {
222 spec.eos_piece = reader.read_string()?;
223 }
224 48 if wire == WireType::LengthDelimited => {
225 spec.pad_piece = reader.read_string()?;
226 }
227 _ => reader.skip(wire)?,
228 }
229 }
230 Ok(spec)
231}
232
233fn decode_normalizer_spec(bytes: &[u8]) -> Result<NormalizerSpec> {
234 let mut spec = NormalizerSpec::default();
235 let mut reader = ProtoReader::new(bytes);
236 while let Some((field, wire)) = reader.read_key()? {
237 match field {
238 1 if wire == WireType::LengthDelimited => {
239 spec.name = reader.read_string()?;
240 }
241 2 if wire == WireType::LengthDelimited => {
242 spec.precompiled_charsmap = reader.read_len()?.to_vec();
243 }
244 3 if wire == WireType::Varint => {
245 spec.add_dummy_prefix = reader.read_bool()?;
246 }
247 4 if wire == WireType::Varint => {
248 spec.remove_extra_whitespaces = reader.read_bool()?;
249 }
250 5 if wire == WireType::Varint => {
251 spec.escape_whitespaces = reader.read_bool()?;
252 }
253 _ => reader.skip(wire)?,
254 }
255 }
256 Ok(spec)
257}
258
259fn decode_self_test_data(bytes: &[u8]) -> Result<SelfTestData> {
260 let mut data = SelfTestData::default();
261 let mut reader = ProtoReader::new(bytes);
262 while let Some((field, wire)) = reader.read_key()? {
263 match field {
264 1 if wire == WireType::LengthDelimited => {
265 data.samples
266 .push(decode_self_test_sample(reader.read_len()?)?);
267 }
268 _ => reader.skip(wire)?,
269 }
270 }
271 Ok(data)
272}
273
274fn decode_self_test_sample(bytes: &[u8]) -> Result<SelfTestSample> {
275 let mut sample = SelfTestSample {
276 input: String::new(),
277 expected: String::new(),
278 };
279 let mut reader = ProtoReader::new(bytes);
280 while let Some((field, wire)) = reader.read_key()? {
281 match field {
282 1 if wire == WireType::LengthDelimited => {
283 sample.input = reader.read_string()?;
284 }
285 2 if wire == WireType::LengthDelimited => {
286 sample.expected = reader.read_string()?;
287 }
288 _ => reader.skip(wire)?,
289 }
290 }
291 Ok(sample)
292}
293
294#[derive(Clone, Copy, Debug, Eq, PartialEq)]
295enum WireType {
296 Varint,
297 SixtyFourBit,
298 LengthDelimited,
299 ThirtyTwoBit,
300}
301
302impl WireType {
303 fn from_key(value: u64) -> Result<Self> {
304 match value & 0b111 {
305 0 => Ok(Self::Varint),
306 1 => Ok(Self::SixtyFourBit),
307 2 => Ok(Self::LengthDelimited),
308 5 => Ok(Self::ThirtyTwoBit),
309 wire => Err(Error::model_parse(format!(
310 "unsupported protobuf wire type {wire}"
311 ))),
312 }
313 }
314}
315
316struct ProtoReader<'a> {
317 bytes: &'a [u8],
318 position: usize,
319}
320
321impl<'a> ProtoReader<'a> {
322 fn new(bytes: &'a [u8]) -> Self {
323 Self { bytes, position: 0 }
324 }
325
326 fn read_key(&mut self) -> Result<Option<(u32, WireType)>> {
327 if self.position == self.bytes.len() {
328 return Ok(None);
329 }
330
331 let key = self.read_varint()?;
332 let field = (key >> 3) as u32;
333 if field == 0 {
334 return Err(Error::model_parse("protobuf field number 0 is invalid"));
335 }
336 Ok(Some((field, WireType::from_key(key)?)))
337 }
338
339 fn read_varint(&mut self) -> Result<u64> {
340 let mut value = 0u64;
341 for shift in (0..64).step_by(7) {
342 let byte = *self
343 .bytes
344 .get(self.position)
345 .ok_or_else(|| Error::model_parse("unexpected end of protobuf varint"))?;
346 self.position += 1;
347 value |= u64::from(byte & 0x7f) << shift;
348 if byte & 0x80 == 0 {
349 return Ok(value);
350 }
351 }
352 Err(Error::model_parse("protobuf varint is too long"))
353 }
354
355 fn read_bool(&mut self) -> Result<bool> {
356 Ok(self.read_varint()? != 0)
357 }
358
359 fn read_len(&mut self) -> Result<&'a [u8]> {
360 let len = self.read_varint()? as usize;
361 let end = self
362 .position
363 .checked_add(len)
364 .ok_or_else(|| Error::model_parse("protobuf length overflow"))?;
365 if end > self.bytes.len() {
366 return Err(Error::model_parse(
367 "unexpected end of protobuf length field",
368 ));
369 }
370 let out = &self.bytes[self.position..end];
371 self.position = end;
372 Ok(out)
373 }
374
375 fn read_string(&mut self) -> Result<String> {
376 let bytes = self.read_len()?;
377 String::from_utf8(bytes.to_vec())
378 .map_err(|_| Error::model_parse("protobuf string is not valid UTF-8"))
379 }
380
381 fn read_f32(&mut self) -> Result<f32> {
382 let end = self
383 .position
384 .checked_add(4)
385 .ok_or_else(|| Error::model_parse("protobuf fixed32 overflow"))?;
386 if end > self.bytes.len() {
387 return Err(Error::model_parse("unexpected end of protobuf fixed32"));
388 }
389 let bytes = [
390 self.bytes[self.position],
391 self.bytes[self.position + 1],
392 self.bytes[self.position + 2],
393 self.bytes[self.position + 3],
394 ];
395 self.position = end;
396 Ok(f32::from_le_bytes(bytes))
397 }
398
399 fn skip(&mut self, wire: WireType) -> Result<()> {
400 match wire {
401 WireType::Varint => {
402 self.read_varint()?;
403 }
404 WireType::SixtyFourBit => {
405 self.skip_bytes(8)?;
406 }
407 WireType::LengthDelimited => {
408 let len = self.read_varint()? as usize;
409 self.skip_bytes(len)?;
410 }
411 WireType::ThirtyTwoBit => {
412 self.skip_bytes(4)?;
413 }
414 }
415 Ok(())
416 }
417
418 fn skip_bytes(&mut self, len: usize) -> Result<()> {
419 let end = self
420 .position
421 .checked_add(len)
422 .ok_or_else(|| Error::model_parse("protobuf skip overflow"))?;
423 if end > self.bytes.len() {
424 return Err(Error::model_parse(
425 "unexpected end while skipping protobuf field",
426 ));
427 }
428 self.position = end;
429 Ok(())
430 }
431}