Skip to main content

svod_model/
sentencepiece.rs

//! SentencePiece `.model` (protobuf) loader. Pieces retain their `▁`
//! (U+2581) prefix on word-initial tokens; consumers replace `▁` with a
//! space for natural detokenization.
4
5use std::path::{Path, PathBuf};
6
7use snafu::{ResultExt, Snafu};
8
/// Partial decode of SentencePiece's `ModelProto`: only field 1, the
/// `repeated SentencePiece pieces` array, is declared.
///
/// prost skips any tag not declared on this struct, so every other section of
/// the model file is ignored during decoding. Tag source of truth: upstream
/// `google/sentencepiece` repo's `src/sentencepiece_model.proto`.
#[derive(prost::Message)]
struct SpModelProto {
    // Piece order is significant: the index of a piece in this vector is its
    // token id.
    #[prost(message, repeated, tag = "1")]
    pieces: Vec<SpPiece>,
}
17
/// Partial decode of one `ModelProto.SentencePiece` entry: the piece text
/// (tag 1) and the piece type (tag 3). Undeclared fields are skipped by prost.
#[derive(prost::Message)]
struct SpPiece {
    /// The piece string, e.g. `"▁hello"` for a word-initial piece (`U+2581` is
    /// SentencePiece's space marker), or a special piece such as `"<unk>"`
    /// (the `UNKNOWN`-type piece, not a `CONTROL` one). `None` when the field
    /// is absent from the encoded entry.
    #[prost(string, optional, tag = "1")]
    piece: Option<String>,
    /// Raw value of the proto2 enum
    /// `enum Type { NORMAL = 1; UNKNOWN = 2; CONTROL = 3; USER_DEFINED = 4; BYTE = 6; UNUSED = 5 }`.
    /// Decoded as a plain `int32` rather than a Rust enum, so values outside
    /// that list still decode and callers decide how to treat them. `None`
    /// when the field is absent (upstream's proto2 default is `NORMAL`).
    #[prost(int32, optional, tag = "3")]
    r#type: Option<i32>,
}
28
/// Errors produced while loading a SentencePiece model (see [`load_vocab`]).
#[derive(Debug, Snafu)]
// Makes the generated `IoSnafu` / `DecodeSnafu` context selectors `pub`.
#[snafu(visibility(pub))]
pub enum Error {
    /// The model file at `path` could not be read from disk.
    #[snafu(display("reading SentencePiece model from {}: {source}", path.display()))]
    Io { path: PathBuf, source: std::io::Error },
    /// The file at `path` was read, but its bytes did not decode as a
    /// SentencePiece `ModelProto`.
    #[snafu(display("parsing SentencePiece model at {}: {source}", path.display()))]
    Decode { path: PathBuf, source: prost::DecodeError },
}
37
38pub type Result<T> = std::result::Result<T, Error>;
39
40/// Read a SentencePiece `.model` file and return per-id raw pieces.
41///
42/// Special tokens (UNKNOWN=2, CONTROL=3, BYTE=6, UNUSED=5) are mapped to the
43/// empty string so they elide from the transcript on the (rare) chance the
44/// model emits one.
45pub fn load_vocab(path: &Path) -> Result<Vec<String>> {
46    use prost::Message;
47    let bytes = std::fs::read(path).context(IoSnafu { path: path.to_path_buf() })?;
48    let proto = SpModelProto::decode(&*bytes).context(DecodeSnafu { path: path.to_path_buf() })?;
49    let mut pieces = Vec::with_capacity(proto.pieces.len());
50    for p in proto.pieces {
51        let kind = p.r#type.unwrap_or(1);
52        // Type 1 = NORMAL, 4 = USER_DEFINED. Everything else (UNKNOWN,
53        // CONTROL, BYTE, UNUSED) is non-emittable: store empty so the
54        // transcript stays clean if the predictor accidentally lands there.
55        let s = if kind == 1 || kind == 4 { p.piece.unwrap_or_default() } else { String::new() };
56        pieces.push(s);
57    }
58    Ok(pieces)
59}