scirs2_text/huggingface_compat/tokenizer.rs

//! Hugging Face compatible tokenizer implementations
//!
//! This module provides tokenizer wrappers that are compatible with
//! Hugging Face tokenizer formats and APIs.
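//!
//! # Examples
//!
//! A minimal usage sketch. The construction below is illustrative:
//! `WordTokenizer` and `HfTokenizerConfig::default()` are assumptions, not
//! guaranteed APIs; substitute whatever tokenizer and config construction
//! this crate actually provides.
//!
//! ```ignore
//! let hf_tokenizer = HfTokenizer::new(
//!     Box::new(WordTokenizer::default()),
//!     HfTokenizerConfig::default(),
//! );
//! let encoded = hf_tokenizer.encode("hello world", true)?;
//! assert_eq!(encoded.input_ids.len(), encoded.attention_mask.len());
//! ```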

use super::config::HfTokenizerConfig;
use crate::error::Result;
use crate::tokenize::Tokenizer;
use std::collections::HashMap;

/// Hugging Face compatible tokenizer wrapper
pub struct HfTokenizer {
    /// Underlying tokenizer
    tokenizer: Box<dyn Tokenizer>,
    /// Tokenizer configuration
    config: HfTokenizerConfig,
    /// Vocabulary mapping
    vocab: HashMap<String, usize>,
    /// Reverse vocabulary mapping
    reverse_vocab: HashMap<usize, String>,
}

impl HfTokenizer {
    /// Create a new HF-compatible tokenizer
    pub fn new(tokenizer: Box<dyn Tokenizer>, config: HfTokenizerConfig) -> Self {
        // Build a basic vocabulary (in practice, this would be loaded from files)
        let mut vocab = HashMap::new();
        let mut reverse_vocab = HashMap::new();

        // Add special tokens. Keys are sorted first so that the assigned IDs
        // are deterministic; iterating the HashMap directly would hand out
        // different IDs from run to run.
        let mut special_tokens: Vec<&String> = config.special_tokens.keys().collect();
        special_tokens.sort();
        for (token_id, token) in special_tokens.into_iter().enumerate() {
            vocab.insert(token.clone(), token_id);
            reverse_vocab.insert(token_id, token.clone());
        }

        Self {
            tokenizer,
            config,
            vocab,
            reverse_vocab,
        }
    }

    /// Tokenize text with HF-compatible output
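    ///
    /// # Examples
    ///
    /// An illustrative sketch; assumes `hf_tokenizer` was built as in the
    /// module docs.
    ///
    /// ```ignore
    /// let encoded = hf_tokenizer.encode("hello world", true)?;
    /// // One ID per token, plus BOS/EOS when the config defines them
    /// assert_eq!(encoded.input_ids.len(), encoded.tokens.len());
    /// ```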
    pub fn encode(&self, text: &str, add_special_tokens: bool) -> Result<HfEncodedInput> {
        let mut tokens = self.tokenizer.tokenize(text)?;

        // Add special tokens if requested
        if add_special_tokens {
            if let Some(bos_token) = &self.config.bos_token {
                tokens.insert(0, bos_token.clone());
            }
            if let Some(eos_token) = &self.config.eos_token {
                tokens.push(eos_token.clone());
            }
        }

        // Convert tokens to IDs, falling back to the unknown-token ID
        // (or 0 if even the unknown token is unmapped)
        let input_ids: Vec<usize> = tokens
            .iter()
            .map(|token| {
                self.vocab.get(token).copied().unwrap_or_else(|| {
                    self.vocab.get(&self.config.unk_token).copied().unwrap_or(0)
                })
            })
            .collect();

        // Attention mask (1 for real tokens, 0 for padding); no padding is
        // applied here, so every position is a real token
        let attention_mask = vec![1; input_ids.len()];

        // Token type IDs (all 0 for a single sentence)
        let token_type_ids = vec![0; input_ids.len()];

        Ok(HfEncodedInput {
            input_ids,
            attention_mask,
            token_type_ids: Some(token_type_ids),
            tokens,
        })
    }

    /// Batch encode multiple texts
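    ///
    /// Each text is encoded independently; no cross-text padding or
    /// truncation is applied.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let batch = hf_tokenizer.encode_batch(&["first text", "second"], true)?;
    /// assert_eq!(batch.len(), 2);
    /// ```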
    pub fn encode_batch(
        &self,
        texts: &[&str],
        add_special_tokens: bool,
    ) -> Result<Vec<HfEncodedInput>> {
        texts
            .iter()
            .map(|text| self.encode(text, add_special_tokens))
            .collect()
    }

    /// Decode token IDs back to text
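    ///
    /// # Examples
    ///
    /// A sketch of a round trip; with `skip_special_tokens` set, BOS/EOS are
    /// dropped from the output.
    ///
    /// ```ignore
    /// let encoded = hf_tokenizer.encode("hello world", true)?;
    /// let text = hf_tokenizer.decode(&encoded.input_ids, true)?;
    /// ```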
    pub fn decode(&self, token_ids: &[usize], skip_special_tokens: bool) -> Result<String> {
        let tokens: Vec<String> = token_ids
            .iter()
            .filter_map(|&id| self.reverse_vocab.get(&id))
            .filter(|token| {
                if skip_special_tokens {
                    !self.config.special_tokens.contains_key(*token)
                } else {
                    true
                }
            })
            .cloned()
            .collect();

        // Joining with spaces is a simple approximation; no subword merging
        // or detokenization is performed
        Ok(tokens.join(" "))
    }

    /// Get vocabulary size
    pub fn vocab_size(&self) -> usize {
        self.vocab.len()
    }
}

/// HF-compatible encoded input format
#[derive(Debug, Clone)]
pub struct HfEncodedInput {
    /// Token IDs
    pub input_ids: Vec<usize>,
    /// Attention mask (1 for real tokens, 0 for padding)
    pub attention_mask: Vec<i32>,
    /// Token type IDs (for multi-sentence tasks)
    pub token_type_ids: Option<Vec<usize>>,
    /// Original tokens
    pub tokens: Vec<String>,
}