// symproj/lib.rs
//! # symproj
//!
//! Symbolic projection and embeddings.
//!
//! Maps discrete symbols to continuous vectors using a Codebook.
//!
//! **Naming note**: this crate was previously named `proj`, but `proj` is already taken on crates.io
//! by GeoRust's PROJ bindings (geospatial). We publish this crate as `symproj`.
//!
//! ## Intuition First
//!
//! Imagine a library where every book has a call number. The call number
//! isn't just a label; it tells you where the book sits in a 3D space.
//! `symproj` is the system that maps "book names" (tokens) to "library coordinates" (vectors).
//!
//! ## Provenance (minimal citations)
//!
//! What this crate implements is the long-lived primitive:
//! \[
//! (t_1,\dots,t_n)\mapsto \mathbb{R}^d
//! \]
//! via (1) embedding lookup (a codebook) and (2) pooling (mean).
//!
//! - **Word embeddings / lookup tables**: Mikolov et al. (word2vec), 2013. [`arXiv:1301.3781`](https://arxiv.org/abs/1301.3781)
//! - **Subword tokenization**:
//!   - BPE for NMT: Sennrich et al., 2016. [`P16-1162`](https://aclanthology.org/P16-1162/)
//!   - SentencePiece / Unigram LM: Kudo, 2018. [`arXiv:1808.06226`](https://arxiv.org/abs/1808.06226)
//! - **Sentence embeddings baseline**: Arora et al. (SIF), 2017. [`ICLR OpenReview`](https://openreview.net/forum?id=SyK00v5xx)
//! - **Modern sentence embedding fine-tuning**:
//!   - SBERT: Reimers & Gurevych, 2019. [`D19-1410`](https://aclanthology.org/D19-1410/)
//!   - SimCSE: Gao et al., 2021. [`EMNLP 2021`](https://aclanthology.org/2021.emnlp-main.552/)
//! - **Retrieval context (token vectors + pooling/compression)**:
//!   - ColBERT (late interaction): Khattab & Zaharia, 2020. [`arXiv:2004.12832`](https://arxiv.org/abs/2004.12832)
//!
//! ## Nearby Rust ecosystem crates (context, not dependencies)
//!
//! - `tokenizers` (Hugging Face tokenization): <https://docs.rs/tokenizers/>
//! - `sentencepiece` (SentencePiece model loading): <https://crates.io/crates/sentencepiece>
//! - `finalfusion` / `rust2vec` (word embedding formats): <https://docs.rs/finalfusion/> / <https://docs.rs/rust2vec/>
//! - `fastembed` (embedding generation via ONNX): <https://docs.rs/fastembed/>
//! - `candle` (Rust ML runtime): <https://github.com/huggingface/candle>

use textprep::SubwordTokenizer;

/// Errors produced by codebook construction and encoding.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Two vectors that must share a dimension did not (see `remove_component_in_place`).
    #[error("Dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch { expected: usize, got: usize },
    /// A token ID fell outside the codebook's vocabulary (strict encoders only).
    #[error("Token not found in codebook: {0}")]
    TokenNotFound(u32),
    /// The `ids` and `weights` slices passed to a weighted encoder differ in length.
    #[error("Weight length mismatch: expected {expected}, got {got}")]
    WeightLenMismatch { expected: usize, got: usize },
    /// A codebook was constructed with `dim == 0`.
    #[error("dimension cannot be zero")]
    ZeroDimension,
    /// The flattened matrix length is not divisible by the vector dimension.
    #[error("matrix length {len} is not a multiple of dimension {dim}")]
    InvalidMatrixShape { len: usize, dim: usize },
}

/// Crate-wide result alias over [`Error`].
pub type Result<T> = std::result::Result<T, Error>;

61/// A Codebook maps token IDs to dense vectors.
62#[derive(Debug, Clone)]
63pub struct Codebook {
64    /// Flattened embedding matrix [vocab_size * dim]
65    matrix: Vec<f32>,
66    /// Dimension of each vector
67    dim: usize,
68}
69
70impl Codebook {
71    /// Create a new Codebook from a flattened matrix and dimension.
72    pub fn new(matrix: Vec<f32>, dim: usize) -> Result<Self> {
73        if dim == 0 {
74            return Err(Error::ZeroDimension);
75        }
76        if !matrix.len().is_multiple_of(dim) {
77            return Err(Error::InvalidMatrixShape {
78                len: matrix.len(),
79                dim,
80            });
81        }
82        Ok(Self { matrix, dim })
83    }
84
85    /// Get the vector for a token ID.
86    pub fn get(&self, id: u32) -> Option<&[f32]> {
87        let start = (id as usize) * self.dim;
88        let end = start + self.dim;
89        if end <= self.matrix.len() {
90            Some(&self.matrix[start..end])
91        } else {
92            None
93        }
94    }
95
96    /// Get the embedding dimension.
97    pub fn dim(&self) -> usize {
98        self.dim
99    }
100
101    /// Number of token IDs representable by this codebook.
102    pub fn vocab_size(&self) -> usize {
103        self.matrix.len() / self.dim
104    }
105}
106
107impl Codebook {
108    /// Encode a token-id sequence into a single vector using mean pooling.
109    ///
110    /// This is **lenient**: token IDs not present in the codebook are skipped.
111    pub fn encode_ids(&self, ids: &[u32]) -> Vec<f32> {
112        if ids.is_empty() {
113            return vec![0.0; self.dim];
114        }
115
116        let embeddings: Vec<&[f32]> = ids.iter().filter_map(|&id| self.get(id)).collect();
117        if embeddings.is_empty() {
118            return vec![0.0; self.dim];
119        }
120
121        let mut out = vec![0.0; self.dim];
122        let count = embeddings.len() as f32;
123        for emb in &embeddings {
124            for (o, &e) in out.iter_mut().zip(emb.iter()) {
125                *o += e;
126            }
127        }
128        for o in out.iter_mut() {
129            *o /= count;
130        }
131        out
132    }
133
134    /// Encode token IDs into a single vector using mean pooling (strict).
135    ///
136    /// Unlike [`Self::encode_ids`], this returns an error if any token ID is not present in the
137    /// codebook. This is useful when you need a “closed vocabulary” contract.
138    pub fn encode_ids_strict(&self, ids: &[u32]) -> Result<Vec<f32>> {
139        if ids.is_empty() {
140            return Ok(vec![0.0; self.dim]);
141        }
142
143        let mut embeddings: Vec<&[f32]> = Vec::with_capacity(ids.len());
144        for &id in ids {
145            let emb = self.get(id).ok_or(Error::TokenNotFound(id))?;
146            embeddings.push(emb);
147        }
148
149        let mut out = vec![0.0; self.dim];
150        let count = embeddings.len() as f32;
151        for emb in &embeddings {
152            for (o, &e) in out.iter_mut().zip(emb.iter()) {
153                *o += e;
154            }
155        }
156        for o in out.iter_mut() {
157            *o /= count;
158        }
159        Ok(out)
160    }
161
162    /// Encode token IDs into a single vector using a weighted mean (strict).
163    ///
164    /// We compute:
165    /// \[
166    /// v = \frac{\sum_i w_i \, E\[t_i\]}{\sum_i w_i}
167    /// \]
168    /// with the convention that if \(\sum_i w_i \le 0\), we return the zero vector.
169    ///
170    /// Weighting is one route toward SIF-style baselines (Arora et al., 2017).
171    pub fn encode_ids_weighted_strict(&self, ids: &[u32], weights: &[f32]) -> Result<Vec<f32>> {
172        if ids.len() != weights.len() {
173            return Err(Error::WeightLenMismatch {
174                expected: ids.len(),
175                got: weights.len(),
176            });
177        }
178        if ids.is_empty() {
179            return Ok(vec![0.0; self.dim]);
180        }
181
182        let dim = self.dim;
183        let mut out = vec![0.0f32; dim];
184        let mut sum_w = 0.0f32;
185
186        for (&id, &w) in ids.iter().zip(weights.iter()) {
187            let emb = self.get(id).ok_or(Error::TokenNotFound(id))?;
188            if w == 0.0 {
189                continue;
190            }
191            sum_w += w;
192            for (o, &e) in out.iter_mut().zip(emb.iter()) {
193                *o += w * e;
194            }
195        }
196
197        if sum_w <= 0.0 {
198            return Ok(vec![0.0; dim]);
199        }
200
201        for o in out.iter_mut() {
202            *o /= sum_w;
203        }
204        Ok(out)
205    }
206
207    /// Encode token ids into a sequence of vectors (no pooling).
208    pub fn encode_sequence_ids(&self, ids: &[u32]) -> Vec<Vec<f32>> {
209        let mut result = Vec::with_capacity(ids.len());
210        for &id in ids {
211            if let Some(emb) = self.get(id) {
212                result.push(emb.to_vec());
213            }
214        }
215        result
216    }
217}
218
/// SIF (Smooth Inverse Frequency) weight from Arora et al. (2017):
/// \[
/// w(p) = \frac{a}{a + p}
/// \]
/// where \(p\) is token probability and \(a\) is a small smoothing constant (often \(10^{-3}\)).
///
/// Degenerate inputs — a non-positive smoothing constant or a negative
/// probability — yield a weight of `0.0`.
#[inline]
pub fn sif_weight(p: f32, a: f32) -> f32 {
    if a <= 0.0 || p < 0.0 {
        0.0
    } else {
        a / (a + p)
    }
}

/// L2-normalize a vector in place.
///
/// If the input has norm 0 (e.g. the zero vector or an empty slice), this is a no-op.
pub fn l2_normalize_in_place(v: &mut [f32]) {
    // Sum of squares, accumulated left-to-right exactly as a manual loop would.
    let ss: f32 = v.iter().map(|&x| x * x).sum();
    if ss <= 0.0 {
        return;
    }
    let inv = ss.sqrt().recip();
    for x in v.iter_mut() {
        *x *= inv;
    }
}

252/// Remove a (unit) component direction \(u\) from a vector \(v\):
253/// \[
254/// v \leftarrow v - u \,(u \cdot v)
255/// \]
256///
257/// This is the “remove top PC” post-step used in SIF-style baselines (when \(u\) is the top PC).
258pub fn remove_component_in_place(v: &mut [f32], u_unit: &[f32]) -> Result<()> {
259    if v.len() != u_unit.len() {
260        return Err(Error::DimensionMismatch {
261            expected: v.len(),
262            got: u_unit.len(),
263        });
264    }
265    let mut dot = 0.0f32;
266    for i in 0..v.len() {
267        dot += u_unit[i] * v[i];
268    }
269    for i in 0..v.len() {
270        v[i] -= u_unit[i] * dot;
271    }
272    Ok(())
273}
274
275/// A Projection combines a Tokenizer and a Codebook.
276pub struct Projection<T: SubwordTokenizer> {
277    tokenizer: T,
278    codebook: Codebook,
279}
280
281impl<T: SubwordTokenizer> Projection<T> {
282    /// Create a new Projection.
283    pub fn new(tokenizer: T, codebook: Codebook) -> Self {
284        Self {
285            tokenizer,
286            codebook,
287        }
288    }
289
290    /// Encode text into a single vector using mean pooling.
291    pub fn encode(&self, text: &str) -> Vec<f32> {
292        let tokens = self.tokenizer.tokenize(text);
293        self.codebook.encode_ids(&tokens)
294    }
295
296    /// Encode text into a sequence of vectors (no pooling).
297    pub fn encode_sequence(&self, text: &str) -> Vec<Vec<f32>> {
298        let tokens = self.tokenizer.tokenize(text);
299        self.codebook.encode_sequence_ids(&tokens)
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use textprep::BpeTokenizer;

    // End-to-end: tokenizer + codebook produce the mean of the token vectors.
    #[test]
    fn test_projection_basic() {
        let mut vocab = HashMap::new();
        vocab.insert("apple".to_string(), 0);
        vocab.insert("pie".to_string(), 1);
        let tokenizer = BpeTokenizer::from_vocab(vocab);

        let matrix = vec![
            1.0, 0.0, 0.0, // apple
            0.0, 1.0, 0.0, // pie
        ];
        let codebook = Codebook::new(matrix, 3).unwrap();
        let proj = Projection::new(tokenizer, codebook);

        let vec = proj.encode("apple pie");
        // Mean pooling: ( [1,0,0] + [0,1,0] ) / 2 = [0.5, 0.5, 0]
        assert!((vec[0] - 0.5).abs() < 1e-6);
        assert!((vec[1] - 0.5).abs() < 1e-6);
        assert!((vec[2] - 0.0).abs() < 1e-6);
    }

    // Constructor validation: dim == 0 is rejected with a descriptive message.
    #[test]
    fn test_codebook_rejects_zero_dim() {
        let err = Codebook::new(vec![1.0, 2.0, 3.0], 0).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("dimension cannot be zero"), "got: {msg}");
    }

    // Constructor validation: matrix length must divide evenly by dim.
    #[test]
    fn test_codebook_rejects_non_multiple() {
        let err = Codebook::new(vec![1.0, 2.0, 3.0], 2).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("not a multiple of dimension"), "got: {msg}");
    }

    // Strict encoding must fail (not skip) when an ID is out of vocabulary.
    #[test]
    fn codebook_strict_errors_on_missing_token() {
        let codebook = Codebook::new(vec![1.0, 2.0], 2).unwrap(); // vocab_size=1
        let err = codebook.encode_ids_strict(&[0, 9]).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("Token not found"), "got: {msg}");
    }

    // With uniform weights the weighted mean reduces to the plain mean.
    #[test]
    fn weighted_mean_matches_unweighted_mean_when_all_weights_equal() {
        let matrix = vec![
            1.0, 0.0, // id=0
            0.0, 1.0, // id=1
        ];
        let codebook = Codebook::new(matrix, 2).unwrap();
        let ids = [0u32, 1u32];
        let w = [1.0f32, 1.0f32];
        let v = codebook.encode_ids_weighted_strict(&ids, &w).unwrap();
        assert!((v[0] - 0.5).abs() < 1e-6);
        assert!((v[1] - 0.5).abs() < 1e-6);
    }

    // 3-4-5 triangle: [3,4] normalizes to unit length.
    #[test]
    fn l2_normalize_has_unit_norm_when_nonzero() {
        let mut v = vec![3.0f32, 4.0];
        l2_normalize_in_place(&mut v);
        let norm = (v[0] * v[0] + v[1] * v[1]).sqrt();
        assert!((norm - 1.0).abs() < 1e-6, "norm={norm}");
    }

    #[test]
    fn multilingual_vocab_smoke() {
        // Minimal multilingual coverage (scripts + diacritics).
        let mut vocab = HashMap::new();
        vocab.insert("東京".to_string(), 0);
        vocab.insert("Москва".to_string(), 1);
        vocab.insert("التقى".to_string(), 2);
        vocab.insert("राम".to_string(), 3);
        vocab.insert("François".to_string(), 4);
        let tokenizer = BpeTokenizer::from_vocab(vocab);

        // 5 tokens, 1-D vectors for simplicity.
        let matrix = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let codebook = Codebook::new(matrix, 1).unwrap();
        let proj = Projection::new(tokenizer, codebook);

        // Mean of 1..=5 is 3.0.
        let v = proj.encode("東京 Москва التقى राम François");
        assert!((v[0] - 3.0).abs() < 1e-6, "got={:?}", v);
    }
}