use textprep::SubwordTokenizer;

/// Errors produced by codebook construction and encoding.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("Dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch { expected: usize, got: usize },
    #[error("Token not found in codebook: {0}")]
    TokenNotFound(u32),
    #[error("Weight length mismatch: expected {expected}, got {got}")]
    WeightLenMismatch { expected: usize, got: usize },
    #[error("dimension cannot be zero")]
    ZeroDimension,
    #[error("matrix length {len} is not a multiple of dimension {dim}")]
    InvalidMatrixShape { len: usize, dim: usize },
}

pub type Result<T> = std::result::Result<T, Error>;

/// A dense embedding table: `vocab_size()` rows of `dim` floats, stored row-major.
#[derive(Debug, Clone)]
pub struct Codebook {
    /// Row-major embedding matrix of length `vocab_size() * dim`.
    matrix: Vec<f32>,
    /// Dimensionality of each embedding row.
    dim: usize,
}

impl Codebook {
    /// Builds a codebook from a row-major matrix, validating its shape.
    pub fn new(matrix: Vec<f32>, dim: usize) -> Result<Self> {
        if dim == 0 {
            return Err(Error::ZeroDimension);
        }
        if !matrix.len().is_multiple_of(dim) {
            return Err(Error::InvalidMatrixShape {
                len: matrix.len(),
                dim,
            });
        }
        Ok(Self { matrix, dim })
    }

    /// Returns the embedding row for `id`, or `None` if `id` is out of range.
    pub fn get(&self, id: u32) -> Option<&[f32]> {
        let start = (id as usize) * self.dim;
        let end = start + self.dim;
        if end <= self.matrix.len() {
            Some(&self.matrix[start..end])
        } else {
            None
        }
    }

    /// Embedding dimensionality.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Number of rows (token ids) stored in the codebook.
    pub fn vocab_size(&self) -> usize {
        self.matrix.len() / self.dim
    }
}

impl Codebook {
    /// Mean-pools the embeddings for `ids`, silently skipping ids with no row
    /// in the codebook. Returns the zero vector when `ids` is empty or when
    /// no id resolves to an embedding.
    pub fn encode_ids(&self, ids: &[u32]) -> Vec<f32> {
        if ids.is_empty() {
            return vec![0.0; self.dim];
        }

        let embeddings: Vec<&[f32]> = ids.iter().filter_map(|&id| self.get(id)).collect();
        if embeddings.is_empty() {
            return vec![0.0; self.dim];
        }

        let mut out = vec![0.0; self.dim];
        let count = embeddings.len() as f32;
        for emb in &embeddings {
            for (o, &e) in out.iter_mut().zip(emb.iter()) {
                *o += e;
            }
        }
        for o in out.iter_mut() {
            *o /= count;
        }
        out
    }

    /// Mean-pools the embeddings for `ids`, returning `Error::TokenNotFound`
    /// for the first id missing from the codebook. Returns the zero vector
    /// when `ids` is empty.
    pub fn encode_ids_strict(&self, ids: &[u32]) -> Result<Vec<f32>> {
        if ids.is_empty() {
            return Ok(vec![0.0; self.dim]);
        }

        let mut embeddings: Vec<&[f32]> = Vec::with_capacity(ids.len());
        for &id in ids {
            let emb = self.get(id).ok_or(Error::TokenNotFound(id))?;
            embeddings.push(emb);
        }

        let mut out = vec![0.0; self.dim];
        let count = embeddings.len() as f32;
        for emb in &embeddings {
            for (o, &e) in out.iter_mut().zip(emb.iter()) {
                *o += e;
            }
        }
        for o in out.iter_mut() {
            *o /= count;
        }
        Ok(out)
    }

    /// Computes a weighted mean of the embeddings for `ids`.
    ///
    /// Errors with `WeightLenMismatch` when `ids` and `weights` differ in
    /// length, and with `TokenNotFound` for any id missing from the codebook.
    /// Zero weights are skipped; if the weights sum to zero or less, the zero
    /// vector is returned.
    pub fn encode_ids_weighted_strict(&self, ids: &[u32], weights: &[f32]) -> Result<Vec<f32>> {
        if ids.len() != weights.len() {
            return Err(Error::WeightLenMismatch {
                expected: ids.len(),
                got: weights.len(),
            });
        }
        if ids.is_empty() {
            return Ok(vec![0.0; self.dim]);
        }

        let dim = self.dim;
        let mut out = vec![0.0f32; dim];
        let mut sum_w = 0.0f32;

        for (&id, &w) in ids.iter().zip(weights.iter()) {
            let emb = self.get(id).ok_or(Error::TokenNotFound(id))?;
            if w == 0.0 {
                continue;
            }
            sum_w += w;
            for (o, &e) in out.iter_mut().zip(emb.iter()) {
                *o += w * e;
            }
        }

        if sum_w <= 0.0 {
            return Ok(vec![0.0; dim]);
        }

        for o in out.iter_mut() {
            *o /= sum_w;
        }
        Ok(out)
    }

    /// Returns one embedding per id, skipping ids not present in the codebook.
    pub fn encode_sequence_ids(&self, ids: &[u32]) -> Vec<Vec<f32>> {
        let mut result = Vec::with_capacity(ids.len());
        for &id in ids {
            if let Some(emb) = self.get(id) {
                result.push(emb.to_vec());
            }
        }
        result
    }
}

/// Smooth inverse frequency (SIF) weight `a / (a + p)`, where `p` is the
/// token's estimated unigram probability and `a` is the smoothing constant.
/// Returns 0.0 for a non-positive `a` or a negative `p`.
#[inline]
pub fn sif_weight(p: f32, a: f32) -> f32 {
    if a <= 0.0 {
        return 0.0;
    }
    if p < 0.0 {
        return 0.0;
    }
    a / (a + p)
}

/// Normalizes `v` to unit L2 norm in place. Leaves `v` unchanged when its
/// sum of squares is zero (or non-positive).
pub fn l2_normalize_in_place(v: &mut [f32]) {
    let mut ss = 0.0f32;
    for &x in v.iter() {
        ss += x * x;
    }
    if ss <= 0.0 {
        return;
    }
    let inv = 1.0f32 / ss.sqrt();
    for x in v.iter_mut() {
        *x *= inv;
    }
}

/// Removes the projection of `v` onto the unit vector `u_unit` in place:
/// `v <- v - (u_unit . v) * u_unit`. The caller must ensure `u_unit` has unit
/// norm. Errors with `DimensionMismatch` when the slices differ in length.
pub fn remove_component_in_place(v: &mut [f32], u_unit: &[f32]) -> Result<()> {
    if v.len() != u_unit.len() {
        return Err(Error::DimensionMismatch {
            expected: v.len(),
            got: u_unit.len(),
        });
    }
    let mut dot = 0.0f32;
    for i in 0..v.len() {
        dot += u_unit[i] * v[i];
    }
    for i in 0..v.len() {
        v[i] -= u_unit[i] * dot;
    }
    Ok(())
}

/// Ties a subword tokenizer to a codebook so raw text can be embedded directly.
pub struct Projection<T: SubwordTokenizer> {
    tokenizer: T,
    codebook: Codebook,
}

impl<T: SubwordTokenizer> Projection<T> {
    /// Bundles the tokenizer and codebook into a ready-to-use encoder.
    pub fn new(tokenizer: T, codebook: Codebook) -> Self {
        Self {
            tokenizer,
            codebook,
        }
    }

    /// Tokenizes `text` and mean-pools the resulting token embeddings.
    pub fn encode(&self, text: &str) -> Vec<f32> {
        let tokens = self.tokenizer.tokenize(text);
        self.codebook.encode_ids(&tokens)
    }

    /// Tokenizes `text` and returns one embedding per recognized token.
    pub fn encode_sequence(&self, text: &str) -> Vec<Vec<f32>> {
        let tokens = self.tokenizer.tokenize(text);
        self.codebook.encode_sequence_ids(&tokens)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use textprep::BpeTokenizer;

    #[test]
    fn test_projection_basic() {
        let mut vocab = HashMap::new();
        vocab.insert("apple".to_string(), 0);
        vocab.insert("pie".to_string(), 1);
        let tokenizer = BpeTokenizer::from_vocab(vocab);

        let matrix = vec![
            1.0, 0.0, 0.0, // "apple"
            0.0, 1.0, 0.0, // "pie"
        ];
        let codebook = Codebook::new(matrix, 3).unwrap();
        let proj = Projection::new(tokenizer, codebook);

        let vec = proj.encode("apple pie");
        // Mean of [1, 0, 0] and [0, 1, 0].
        assert!((vec[0] - 0.5).abs() < 1e-6);
        assert!((vec[1] - 0.5).abs() < 1e-6);
        assert!((vec[2] - 0.0).abs() < 1e-6);
    }
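
    // Illustrative addition (not part of the original suite): the lenient
    // `encode_ids` path skips ids with no codebook row and averages only the
    // ones it can resolve.
    #[test]
    fn encode_ids_skips_unknown_ids() {
        let codebook = Codebook::new(vec![2.0, 4.0, 6.0, 8.0], 2).unwrap();
        // Id 7 has no row, so the result is just the embedding of id 0.
        let v = codebook.encode_ids(&[0, 7]);
        assert!((v[0] - 2.0).abs() < 1e-6);
        assert!((v[1] - 4.0).abs() < 1e-6);
    }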

    #[test]
    fn test_codebook_rejects_zero_dim() {
        let err = Codebook::new(vec![1.0, 2.0, 3.0], 0).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("dimension cannot be zero"), "got: {msg}");
    }
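
    // Illustrative addition: exercises the bounds check in `Codebook::get`
    // together with `dim` and `vocab_size`, as implemented above.
    #[test]
    fn get_respects_matrix_bounds() {
        let codebook = Codebook::new(vec![1.0, 2.0, 3.0, 4.0], 2).unwrap();
        assert_eq!(codebook.dim(), 2);
        assert_eq!(codebook.vocab_size(), 2);
        assert_eq!(codebook.get(1), Some(&[3.0f32, 4.0][..]));
        assert_eq!(codebook.get(2), None);
    }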

    #[test]
    fn test_codebook_rejects_non_multiple() {
        let err = Codebook::new(vec![1.0, 2.0, 3.0], 2).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("not a multiple of dimension"), "got: {msg}");
    }
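
    // Illustrative addition: both pooling paths return the zero vector for an
    // empty id slice rather than dividing by zero.
    #[test]
    fn empty_ids_pool_to_zero_vector() {
        let codebook = Codebook::new(vec![1.0, 2.0, 3.0], 3).unwrap();
        assert_eq!(codebook.encode_ids(&[]), vec![0.0f32; 3]);
        assert_eq!(codebook.encode_ids_strict(&[]).unwrap(), vec![0.0f32; 3]);
    }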

    #[test]
    fn codebook_strict_errors_on_missing_token() {
        let codebook = Codebook::new(vec![1.0, 2.0], 2).unwrap();
        let err = codebook.encode_ids_strict(&[0, 9]).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("Token not found"), "got: {msg}");
    }
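
    // Illustrative addition: `encode_sequence_ids` drops unknown ids instead
    // of erroring, so the output can be shorter than the input.
    #[test]
    fn encode_sequence_ids_drops_unknown_ids() {
        let codebook = Codebook::new(vec![1.0, 2.0, 3.0, 4.0], 2).unwrap();
        let seq = codebook.encode_sequence_ids(&[0, 9, 1]);
        assert_eq!(seq, vec![vec![1.0f32, 2.0], vec![3.0, 4.0]]);
    }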

    #[test]
    fn weighted_mean_matches_unweighted_mean_when_all_weights_equal() {
        let matrix = vec![
            1.0, 0.0,
            0.0, 1.0,
        ];
        let codebook = Codebook::new(matrix, 2).unwrap();
        let ids = [0u32, 1u32];
        let w = [1.0f32, 1.0f32];
        let v = codebook.encode_ids_weighted_strict(&ids, &w).unwrap();
        assert!((v[0] - 0.5).abs() < 1e-6);
        assert!((v[1] - 0.5).abs() < 1e-6);
    }
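
    // Illustrative addition: when every weight is zero the weighted sum never
    // accumulates, so the strict weighted path falls back to the zero vector;
    // mismatched lengths are rejected up front.
    #[test]
    fn weighted_mean_handles_zero_weights_and_length_mismatch() {
        let codebook = Codebook::new(vec![1.0, 0.0, 0.0, 1.0], 2).unwrap();
        let v = codebook
            .encode_ids_weighted_strict(&[0, 1], &[0.0, 0.0])
            .unwrap();
        assert_eq!(v, vec![0.0f32, 0.0]);

        let err = codebook
            .encode_ids_weighted_strict(&[0, 1], &[1.0])
            .unwrap_err();
        assert!(err.to_string().contains("Weight length mismatch"));
    }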

    #[test]
    fn l2_normalize_has_unit_norm_when_nonzero() {
        let mut v = vec![3.0f32, 4.0];
        l2_normalize_in_place(&mut v);
        let norm = (v[0] * v[0] + v[1] * v[1]).sqrt();
        assert!((norm - 1.0).abs() < 1e-6, "norm={norm}");
    }
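
    // Illustrative addition: `sif_weight` should shrink toward zero as the
    // token probability grows and reach 1.0 for a zero-probability token.
    #[test]
    fn sif_weight_decreases_with_probability() {
        let a = 1e-3;
        assert!((sif_weight(0.0, a) - 1.0).abs() < 1e-6);
        assert!(sif_weight(1e-4, a) > sif_weight(1e-2, a));
        // Guard clauses: non-positive smoothing or negative probability.
        assert_eq!(sif_weight(0.5, 0.0), 0.0);
        assert_eq!(sif_weight(-0.1, a), 0.0);
    }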

    #[test]
    fn multilingual_vocab_smoke() {
        let mut vocab = HashMap::new();
        vocab.insert("東京".to_string(), 0);
        vocab.insert("Москва".to_string(), 1);
        vocab.insert("التقى".to_string(), 2);
        vocab.insert("राम".to_string(), 3);
        vocab.insert("François".to_string(), 4);
        let tokenizer = BpeTokenizer::from_vocab(vocab);

        let matrix = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let codebook = Codebook::new(matrix, 1).unwrap();
        let proj = Projection::new(tokenizer, codebook);

        let v = proj.encode("東京 Москва التقى राम François");
        // Mean of the five 1-d embeddings: (1 + 2 + 3 + 4 + 5) / 5 = 3.0.
        assert!((v[0] - 3.0).abs() < 1e-6, "got={:?}", v);
    }
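
    // Illustrative addition: after removing the component along a unit vector,
    // the coordinate along that vector should be (numerically) zero, and a
    // length mismatch should be rejected.
    #[test]
    fn remove_component_leaves_orthogonal_residual() {
        let mut v = vec![2.0f32, 3.0];
        let u = [1.0f32, 0.0];
        remove_component_in_place(&mut v, &u).unwrap();
        assert!(v[0].abs() < 1e-6);
        assert!((v[1] - 3.0).abs() < 1e-6);

        let err = remove_component_in_place(&mut v, &[1.0]).unwrap_err();
        assert!(err.to_string().contains("Dimension mismatch"));
    }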
}