Skip to main content

sentencepiece_rs/
proto.rs

1use crate::{Error, Result};
2
/// Category attached to a vocabulary piece.
///
/// Raw numeric values are converted with [`PieceType::from_i32`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PieceType {
    /// Fallback for every raw value outside `2..=6`.
    Normal,
    /// Raw value `2`.
    Unknown,
    /// Raw value `3`.
    Control,
    /// Raw value `4`.
    UserDefined,
    /// Raw value `5`.
    Unused,
    /// Raw value `6`.
    Byte,
}

impl PieceType {
    /// Maps a raw numeric piece type onto a [`PieceType`].
    ///
    /// The values `2` through `6` select `Unknown`, `Control`, `UserDefined`,
    /// `Unused` and `Byte` respectively. Every other value — including `1`,
    /// zero and negative numbers — yields [`PieceType::Normal`]; this
    /// conversion never fails.
    pub(crate) fn from_i32(value: i32) -> Self {
        let special = match value {
            2 => Some(Self::Unknown),
            3 => Some(Self::Control),
            4 => Some(Self::UserDefined),
            5 => Some(Self::Unused),
            6 => Some(Self::Byte),
            _ => None,
        };
        special.unwrap_or(Self::Normal)
    }
}
25
/// Segmentation algorithm recorded in a model's trainer settings.
///
/// Raw numeric values are converted with [`ModelType::from_i32`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelType {
    /// Fallback for every raw value outside `2..=4`.
    Unigram,
    /// Raw value `2`.
    Bpe,
    /// Raw value `3`.
    Word,
    /// Raw value `4`.
    Char,
}

impl ModelType {
    /// Maps a raw numeric model type onto a [`ModelType`].
    ///
    /// The values `2`, `3` and `4` select `Bpe`, `Word` and `Char`. Every
    /// other value — including `1`, zero and negative numbers — yields
    /// [`ModelType::Unigram`]; this conversion never fails.
    pub(crate) fn from_i32(value: i32) -> Self {
        let explicit = match value {
            2 => Some(Self::Bpe),
            3 => Some(Self::Word),
            4 => Some(Self::Char),
            _ => None,
        };
        explicit.unwrap_or(Self::Unigram)
    }
}
44
45#[derive(Clone, Debug, Default)]
46pub(crate) struct ModelProto {
47    pub(crate) pieces: Vec<SentencePiece>,
48    pub(crate) trainer_spec: TrainerSpec,
49    pub(crate) normalizer_spec: NormalizerSpec,
50    pub(crate) denormalizer_spec: Option<NormalizerSpec>,
51    pub(crate) self_test_data: SelfTestData,
52}
53
54#[derive(Clone, Debug)]
55pub(crate) struct SentencePiece {
56    pub(crate) piece: String,
57    pub(crate) score: f32,
58    pub(crate) kind: PieceType,
59}
60
61#[derive(Clone, Debug)]
62pub(crate) struct TrainerSpec {
63    pub(crate) model_type: ModelType,
64    pub(crate) treat_whitespace_as_suffix: bool,
65    pub(crate) allow_whitespace_only_pieces: bool,
66    pub(crate) byte_fallback: bool,
67    pub(crate) unk_id: i32,
68    pub(crate) bos_id: i32,
69    pub(crate) eos_id: i32,
70    pub(crate) pad_id: i32,
71    pub(crate) unk_surface: String,
72    pub(crate) unk_piece: String,
73    pub(crate) bos_piece: String,
74    pub(crate) eos_piece: String,
75    pub(crate) pad_piece: String,
76}
77
/// Normalizer settings decoded from a model file.
///
/// Also used for the optional denormalizer settings of a model.
#[derive(Debug, Clone)]
pub(crate) struct NormalizerSpec {
    /// Field 1; empty when absent.
    pub(crate) name: String,
    /// Field 2, stored as the raw bytes found in the message; empty when absent.
    pub(crate) precompiled_charsmap: Vec<u8>,
    /// Field 3; defaults to `true`.
    pub(crate) add_dummy_prefix: bool,
    /// Field 4; defaults to `true`.
    pub(crate) remove_extra_whitespaces: bool,
    /// Field 5; defaults to `true`.
    pub(crate) escape_whitespaces: bool,
}
86
87#[derive(Clone, Debug, Default)]
88pub(crate) struct SelfTestData {
89    pub(crate) samples: Vec<SelfTestSample>,
90}
91
/// A single self-test sample: an input text paired with its expected result.
#[derive(Debug, Clone)]
pub(crate) struct SelfTestSample {
    /// Field 1; empty when absent.
    pub(crate) input: String,
    /// Field 2; empty when absent.
    pub(crate) expected: String,
}
97
98impl Default for TrainerSpec {
99    fn default() -> Self {
100        Self {
101            model_type: ModelType::Unigram,
102            treat_whitespace_as_suffix: false,
103            allow_whitespace_only_pieces: false,
104            byte_fallback: false,
105            unk_id: 0,
106            bos_id: 1,
107            eos_id: 2,
108            pad_id: -1,
109            unk_surface: crate::util::DEFAULT_UNKNOWN_SURFACE.to_owned(),
110            unk_piece: "<unk>".to_owned(),
111            bos_piece: "<s>".to_owned(),
112            eos_piece: "</s>".to_owned(),
113            pad_piece: "<pad>".to_owned(),
114        }
115    }
116}
117
118impl Default for NormalizerSpec {
119    fn default() -> Self {
120        Self {
121            name: String::new(),
122            precompiled_charsmap: Vec::new(),
123            add_dummy_prefix: true,
124            remove_extra_whitespaces: true,
125            escape_whitespaces: true,
126        }
127    }
128}
129
130impl ModelProto {
131    pub(crate) fn decode(bytes: &[u8]) -> Result<Self> {
132        let mut proto = Self::default();
133        let mut reader = ProtoReader::new(bytes);
134        while let Some((field, wire)) = reader.read_key()? {
135            match field {
136                1 if wire == WireType::LengthDelimited => {
137                    let bytes = reader.read_len()?;
138                    proto.pieces.push(decode_sentence_piece(bytes)?);
139                }
140                2 if wire == WireType::LengthDelimited => {
141                    proto.trainer_spec = decode_trainer_spec(reader.read_len()?)?;
142                }
143                3 if wire == WireType::LengthDelimited => {
144                    proto.normalizer_spec = decode_normalizer_spec(reader.read_len()?)?;
145                }
146                4 if wire == WireType::LengthDelimited => {
147                    proto.self_test_data = decode_self_test_data(reader.read_len()?)?;
148                }
149                5 if wire == WireType::LengthDelimited => {
150                    proto.denormalizer_spec = Some(decode_normalizer_spec(reader.read_len()?)?);
151                }
152                _ => reader.skip(wire)?,
153            }
154        }
155        Ok(proto)
156    }
157}
158
159fn decode_sentence_piece(bytes: &[u8]) -> Result<SentencePiece> {
160    let mut piece = SentencePiece {
161        piece: String::new(),
162        score: 0.0,
163        kind: PieceType::Normal,
164    };
165    let mut reader = ProtoReader::new(bytes);
166    while let Some((field, wire)) = reader.read_key()? {
167        match field {
168            1 if wire == WireType::LengthDelimited => {
169                piece.piece = reader.read_string()?;
170            }
171            2 if wire == WireType::ThirtyTwoBit => {
172                piece.score = reader.read_f32()?;
173            }
174            3 if wire == WireType::Varint => {
175                piece.kind = PieceType::from_i32(reader.read_varint()? as i32);
176            }
177            _ => reader.skip(wire)?,
178        }
179    }
180    Ok(piece)
181}
182
183fn decode_trainer_spec(bytes: &[u8]) -> Result<TrainerSpec> {
184    let mut spec = TrainerSpec::default();
185    let mut reader = ProtoReader::new(bytes);
186    while let Some((field, wire)) = reader.read_key()? {
187        match field {
188            3 if wire == WireType::Varint => {
189                spec.model_type = ModelType::from_i32(reader.read_varint()? as i32);
190            }
191            24 if wire == WireType::Varint => {
192                spec.treat_whitespace_as_suffix = reader.read_bool()?;
193            }
194            26 if wire == WireType::Varint => {
195                spec.allow_whitespace_only_pieces = reader.read_bool()?;
196            }
197            35 if wire == WireType::Varint => {
198                spec.byte_fallback = reader.read_bool()?;
199            }
200            40 if wire == WireType::Varint => {
201                spec.unk_id = reader.read_varint()? as i32;
202            }
203            41 if wire == WireType::Varint => {
204                spec.bos_id = reader.read_varint()? as i32;
205            }
206            42 if wire == WireType::Varint => {
207                spec.eos_id = reader.read_varint()? as i32;
208            }
209            43 if wire == WireType::Varint => {
210                spec.pad_id = reader.read_varint()? as i32;
211            }
212            44 if wire == WireType::LengthDelimited => {
213                spec.unk_surface = reader.read_string()?;
214            }
215            45 if wire == WireType::LengthDelimited => {
216                spec.unk_piece = reader.read_string()?;
217            }
218            46 if wire == WireType::LengthDelimited => {
219                spec.bos_piece = reader.read_string()?;
220            }
221            47 if wire == WireType::LengthDelimited => {
222                spec.eos_piece = reader.read_string()?;
223            }
224            48 if wire == WireType::LengthDelimited => {
225                spec.pad_piece = reader.read_string()?;
226            }
227            _ => reader.skip(wire)?,
228        }
229    }
230    Ok(spec)
231}
232
233fn decode_normalizer_spec(bytes: &[u8]) -> Result<NormalizerSpec> {
234    let mut spec = NormalizerSpec::default();
235    let mut reader = ProtoReader::new(bytes);
236    while let Some((field, wire)) = reader.read_key()? {
237        match field {
238            1 if wire == WireType::LengthDelimited => {
239                spec.name = reader.read_string()?;
240            }
241            2 if wire == WireType::LengthDelimited => {
242                spec.precompiled_charsmap = reader.read_len()?.to_vec();
243            }
244            3 if wire == WireType::Varint => {
245                spec.add_dummy_prefix = reader.read_bool()?;
246            }
247            4 if wire == WireType::Varint => {
248                spec.remove_extra_whitespaces = reader.read_bool()?;
249            }
250            5 if wire == WireType::Varint => {
251                spec.escape_whitespaces = reader.read_bool()?;
252            }
253            _ => reader.skip(wire)?,
254        }
255    }
256    Ok(spec)
257}
258
259fn decode_self_test_data(bytes: &[u8]) -> Result<SelfTestData> {
260    let mut data = SelfTestData::default();
261    let mut reader = ProtoReader::new(bytes);
262    while let Some((field, wire)) = reader.read_key()? {
263        match field {
264            1 if wire == WireType::LengthDelimited => {
265                data.samples
266                    .push(decode_self_test_sample(reader.read_len()?)?);
267            }
268            _ => reader.skip(wire)?,
269        }
270    }
271    Ok(data)
272}
273
274fn decode_self_test_sample(bytes: &[u8]) -> Result<SelfTestSample> {
275    let mut sample = SelfTestSample {
276        input: String::new(),
277        expected: String::new(),
278    };
279    let mut reader = ProtoReader::new(bytes);
280    while let Some((field, wire)) = reader.read_key()? {
281        match field {
282            1 if wire == WireType::LengthDelimited => {
283                sample.input = reader.read_string()?;
284            }
285            2 if wire == WireType::LengthDelimited => {
286                sample.expected = reader.read_string()?;
287            }
288            _ => reader.skip(wire)?,
289        }
290    }
291    Ok(sample)
292}
293
294#[derive(Clone, Copy, Debug, Eq, PartialEq)]
295enum WireType {
296    Varint,
297    SixtyFourBit,
298    LengthDelimited,
299    ThirtyTwoBit,
300}
301
302impl WireType {
303    fn from_key(value: u64) -> Result<Self> {
304        match value & 0b111 {
305            0 => Ok(Self::Varint),
306            1 => Ok(Self::SixtyFourBit),
307            2 => Ok(Self::LengthDelimited),
308            5 => Ok(Self::ThirtyTwoBit),
309            wire => Err(Error::model_parse(format!(
310                "unsupported protobuf wire type {wire}"
311            ))),
312        }
313    }
314}
315
316struct ProtoReader<'a> {
317    bytes: &'a [u8],
318    position: usize,
319}
320
321impl<'a> ProtoReader<'a> {
322    fn new(bytes: &'a [u8]) -> Self {
323        Self { bytes, position: 0 }
324    }
325
326    fn read_key(&mut self) -> Result<Option<(u32, WireType)>> {
327        if self.position == self.bytes.len() {
328            return Ok(None);
329        }
330
331        let key = self.read_varint()?;
332        let field = (key >> 3) as u32;
333        if field == 0 {
334            return Err(Error::model_parse("protobuf field number 0 is invalid"));
335        }
336        Ok(Some((field, WireType::from_key(key)?)))
337    }
338
339    fn read_varint(&mut self) -> Result<u64> {
340        let mut value = 0u64;
341        for shift in (0..64).step_by(7) {
342            let byte = *self
343                .bytes
344                .get(self.position)
345                .ok_or_else(|| Error::model_parse("unexpected end of protobuf varint"))?;
346            self.position += 1;
347            value |= u64::from(byte & 0x7f) << shift;
348            if byte & 0x80 == 0 {
349                return Ok(value);
350            }
351        }
352        Err(Error::model_parse("protobuf varint is too long"))
353    }
354
355    fn read_bool(&mut self) -> Result<bool> {
356        Ok(self.read_varint()? != 0)
357    }
358
359    fn read_len(&mut self) -> Result<&'a [u8]> {
360        let len = self.read_varint()? as usize;
361        let end = self
362            .position
363            .checked_add(len)
364            .ok_or_else(|| Error::model_parse("protobuf length overflow"))?;
365        if end > self.bytes.len() {
366            return Err(Error::model_parse(
367                "unexpected end of protobuf length field",
368            ));
369        }
370        let out = &self.bytes[self.position..end];
371        self.position = end;
372        Ok(out)
373    }
374
375    fn read_string(&mut self) -> Result<String> {
376        let bytes = self.read_len()?;
377        String::from_utf8(bytes.to_vec())
378            .map_err(|_| Error::model_parse("protobuf string is not valid UTF-8"))
379    }
380
381    fn read_f32(&mut self) -> Result<f32> {
382        let end = self
383            .position
384            .checked_add(4)
385            .ok_or_else(|| Error::model_parse("protobuf fixed32 overflow"))?;
386        if end > self.bytes.len() {
387            return Err(Error::model_parse("unexpected end of protobuf fixed32"));
388        }
389        let bytes = [
390            self.bytes[self.position],
391            self.bytes[self.position + 1],
392            self.bytes[self.position + 2],
393            self.bytes[self.position + 3],
394        ];
395        self.position = end;
396        Ok(f32::from_le_bytes(bytes))
397    }
398
399    fn skip(&mut self, wire: WireType) -> Result<()> {
400        match wire {
401            WireType::Varint => {
402                self.read_varint()?;
403            }
404            WireType::SixtyFourBit => {
405                self.skip_bytes(8)?;
406            }
407            WireType::LengthDelimited => {
408                let len = self.read_varint()? as usize;
409                self.skip_bytes(len)?;
410            }
411            WireType::ThirtyTwoBit => {
412                self.skip_bytes(4)?;
413            }
414        }
415        Ok(())
416    }
417
418    fn skip_bytes(&mut self, len: usize) -> Result<()> {
419        let end = self
420            .position
421            .checked_add(len)
422            .ok_or_else(|| Error::model_parse("protobuf skip overflow"))?;
423        if end > self.bytes.len() {
424            return Err(Error::model_parse(
425                "unexpected end while skipping protobuf field",
426            ));
427        }
428        self.position = end;
429        Ok(())
430    }
431}