//! NLP primitives — string distance, n-grams, tokenization.
//!
//! Lightweight text analysis utilities for data cleaning and LLM training
//! pipelines. All functions are deterministic and allocation-conscious.

use std::collections::BTreeMap;
// ── Levenshtein Distance ──────────────────────────────────────────────────

/// Compute the Levenshtein edit distance between two strings.
///
/// Operates on Unicode scalar values (`char`s), so a multi-byte character
/// counts as a single edit. The previous byte-level comparison overcounted
/// for non-ASCII text (e.g. `levenshtein("é", "e")` returned 2, not 1).
/// Behavior for pure-ASCII input is unchanged.
///
/// Uses the classic two-row dynamic programming approach for O(min(m,n))
/// memory. Each cell considers insertion, deletion, and substitution.
pub fn levenshtein(a: &str, b: &str) -> usize {
    let a_chars: Vec<char> = a.chars().collect();
    let b_chars: Vec<char> = b.chars().collect();
    let m = a_chars.len();
    let n = b_chars.len();

    if m == 0 {
        return n;
    }
    if n == 0 {
        return m;
    }

    // Use the shorter string as the row to minimize memory.
    let (short, long) = if m <= n {
        (&a_chars, &b_chars)
    } else {
        (&b_chars, &a_chars)
    };
    let s_len = short.len();

    // prev_row[j] = distance between the first i chars of `long` and the
    // first j chars of `short` (for the i of the previous iteration).
    let mut prev_row: Vec<usize> = (0..=s_len).collect();
    let mut curr_row: Vec<usize> = vec![0; s_len + 1];

    for (i, &lc) in long.iter().enumerate() {
        curr_row[0] = i + 1;
        for j in 1..=s_len {
            let cost = usize::from(lc != short[j - 1]);
            curr_row[j] = (prev_row[j] + 1) // deletion
                .min(curr_row[j - 1] + 1) // insertion
                .min(prev_row[j - 1] + cost); // substitution
        }
        std::mem::swap(&mut prev_row, &mut curr_row);
    }

    prev_row[s_len]
}
51/// Normalized Levenshtein similarity in [0.0, 1.0].
52///
53/// Returns `1.0 - distance / max_len`. Two identical strings yield `1.0`;
54/// completely different strings of equal length yield `0.0`.
55pub fn levenshtein_similarity(a: &str, b: &str) -> f64 {
56    let max_len = a.len().max(b.len());
57    if max_len == 0 {
58        return 1.0;
59    }
60    let dist = levenshtein(a, b);
61    1.0 - (dist as f64) / (max_len as f64)
62}
63
64// ── Jaccard Similarity ────────────────────────────────────────────────────
65
66/// Jaccard similarity between the character-level n-gram sets of two strings.
67///
68/// `n` is the n-gram size. Returns `|A ∩ B| / |A ∪ B|` where A and B are
69/// the sets of n-grams. Uses `BTreeSet` for deterministic iteration.
70pub fn jaccard_ngram_similarity(a: &str, b: &str, n: usize) -> f64 {
71    if n == 0 || a.is_empty() || b.is_empty() {
72        return 0.0;
73    }
74
75    let set_a = char_ngram_set(a, n);
76    let set_b = char_ngram_set(b, n);
77
78    let intersection = set_a.iter().filter(|g| set_b.contains(*g)).count();
79    let union = {
80        let mut all = set_a.clone();
81        all.extend(set_b.iter().cloned());
82        all.len()
83    };
84
85    if union == 0 {
86        0.0
87    } else {
88        intersection as f64 / union as f64
89    }
90}
91
/// Collect the set of distinct character-level n-grams of `s`.
///
/// Returns an empty set when `n == 0` or when `s` has fewer than `n`
/// characters. The `n > 0` check matters: `slice::windows` panics on a
/// zero window size, so a direct call with `n == 0` previously panicked
/// (the only current caller, `jaccard_ngram_similarity`, guards it).
fn char_ngram_set(s: &str, n: usize) -> std::collections::BTreeSet<String> {
    let chars: Vec<char> = s.chars().collect();
    let mut set = std::collections::BTreeSet::new();
    if n > 0 && chars.len() >= n {
        for window in chars.windows(n) {
            set.insert(window.iter().collect());
        }
    }
    set
}
// ── N-Gram Extraction ─────────────────────────────────────────────────────

/// Extract character-level n-grams with frequency counts.
///
/// Returns a `BTreeMap` for deterministic ordering. An `n` of zero or a
/// string with fewer than `n` characters yields an empty map; previously
/// `n == 0` panicked inside `slice::windows` (zero window size).
pub fn char_ngrams(s: &str, n: usize) -> BTreeMap<String, usize> {
    let mut counts = BTreeMap::new();
    let chars: Vec<char> = s.chars().collect();
    // Guard n == 0: slice::windows panics on a zero window size.
    if n > 0 && chars.len() >= n {
        for window in chars.windows(n) {
            let gram: String = window.iter().collect();
            *counts.entry(gram).or_insert(0) += 1;
        }
    }
    counts
}
/// Extract word-level n-grams with frequency counts.
///
/// Splits on whitespace, then collects contiguous word windows joined by a
/// single space. An `n` of zero or fewer than `n` words yields an empty
/// map; previously `n == 0` panicked inside `slice::windows`.
pub fn word_ngrams(s: &str, n: usize) -> BTreeMap<String, usize> {
    let mut counts = BTreeMap::new();
    let words: Vec<&str> = s.split_whitespace().collect();
    // Guard n == 0: slice::windows panics on a zero window size.
    if n > 0 && words.len() >= n {
        for window in words.windows(n) {
            let gram = window.join(" ");
            *counts.entry(gram).or_insert(0) += 1;
        }
    }
    counts
}
// ── Tokenization ──────────────────────────────────────────────────────────

/// Simple whitespace tokenizer. Returns token spans as `(start, end)` byte offsets.
///
/// A token is a maximal run of bytes that are not ASCII whitespace; each
/// `(start, end)` pair is a half-open byte range into `s`.
pub fn tokenize_whitespace(s: &str) -> Vec<(usize, usize)> {
    let mut spans = Vec::new();
    // Byte offset where the current token began, if we are inside one.
    let mut token_start: Option<usize> = None;

    for (idx, byte) in s.bytes().enumerate() {
        if byte.is_ascii_whitespace() {
            // Whitespace closes any open token.
            if let Some(start) = token_start.take() {
                spans.push((start, idx));
            }
        } else if token_start.is_none() {
            token_start = Some(idx);
        }
    }
    // A token running to the end of the string is closed here.
    if let Some(start) = token_start {
        spans.push((start, s.len()));
    }
    spans
}
/// Word-and-punctuation tokenizer. Splits on whitespace, then separates
/// leading/trailing punctuation into their own tokens.
pub fn tokenize_words(s: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for piece in s.split_whitespace() {
        let chars: Vec<char> = piece.chars().collect();
        let n = chars.len();

        // Index of the first character that is not leading ASCII
        // punctuation; n when the whole piece is punctuation.
        let body_start = chars
            .iter()
            .position(|c| !c.is_ascii_punctuation())
            .unwrap_or(n);

        // One past the last non-punctuation character. Scanning in reverse
        // stops at the word body, so this never crosses body_start.
        let body_end = if body_start == n {
            n
        } else {
            n - chars[body_start..]
                .iter()
                .rev()
                .take_while(|c| c.is_ascii_punctuation())
                .count()
        };

        // Leading punctuation, one token per character.
        tokens.extend(chars[..body_start].iter().map(|c| c.to_string()));
        // The word body, if any.
        if body_end > body_start {
            tokens.push(chars[body_start..body_end].iter().collect());
        }
        // Trailing punctuation, one token per character.
        tokens.extend(chars[body_end..].iter().map(|c| c.to_string()));
    }
    tokens
}
/// Convert a string to lowercase, mapping only ASCII uppercase letters.
///
/// Non-ASCII characters pass through unchanged. Always allocates the
/// returned `String` — the old doc's "allocation-free for ASCII input"
/// claim was never true of this implementation.
pub fn ascii_lowercase(s: &str) -> String {
    // char::to_ascii_lowercase is equivalent to the previous manual
    // `(c as u8 + 32) as char` on A–Z, without the lossy u8 cast.
    s.chars().map(|c| c.to_ascii_lowercase()).collect()
}
/// Remove ASCII punctuation from a string.
pub fn strip_punctuation(s: &str) -> String {
    let mut cleaned = String::with_capacity(s.len());
    for ch in s.chars() {
        if !ch.is_ascii_punctuation() {
            cleaned.push(ch);
        }
    }
    cleaned
}
/// Compute term frequency (TF) for each word in a string.
///
/// Splits on whitespace, converts to lowercase, counts occurrences,
/// and normalizes by total word count.
pub fn term_frequency(s: &str) -> BTreeMap<String, f64> {
    let mut counts: BTreeMap<String, usize> = BTreeMap::new();
    let mut total = 0usize;
    for word in s.split_whitespace() {
        // ASCII-only lowering, matching ascii_lowercase() in this module.
        *counts.entry(word.to_ascii_lowercase()).or_insert(0) += 1;
        total += 1;
    }
    if total == 0 {
        return BTreeMap::new();
    }
    let denom = total as f64;
    counts
        .into_iter()
        .map(|(word, count)| (word, count as f64 / denom))
        .collect()
}
// ── Cosine Similarity ─────────────────────────────────────────────────────

/// Cosine similarity between two term-frequency vectors.
///
/// Both vectors are represented as `BTreeMap<String, f64>`. Returns a value
/// in [0.0, 1.0] for non-negative vectors; 0.0 when either vector is zero.
pub fn cosine_similarity(a: &BTreeMap<String, f64>, b: &BTreeMap<String, f64>) -> f64 {
    // Dot product over the keys of `a` that also appear in `b`.
    let dot: f64 = a
        .iter()
        .filter_map(|(key, va)| b.get(key).map(|vb| va * vb))
        .sum();
    let norm_a: f64 = a.values().map(|v| v * v).sum::<f64>().sqrt();
    let norm_b: f64 = b.values().map(|v| v * v).sum::<f64>().sqrt();

    let denom = norm_a * norm_b;
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
// ── Tests ─────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    // -- Levenshtein distance ---------------------------------------------

    #[test]
    fn test_levenshtein_identical() {
        assert_eq!(levenshtein("hello", "hello"), 0);
    }

    #[test]
    fn test_levenshtein_insert() {
        assert_eq!(levenshtein("abc", "abcd"), 1);
    }

    #[test]
    fn test_levenshtein_delete() {
        assert_eq!(levenshtein("abcd", "abc"), 1);
    }

    #[test]
    fn test_levenshtein_substitute() {
        assert_eq!(levenshtein("abc", "axc"), 1);
    }

    // Degenerate inputs: distance to/from "" is the other string's length.
    #[test]
    fn test_levenshtein_empty() {
        assert_eq!(levenshtein("", "hello"), 5);
        assert_eq!(levenshtein("hello", ""), 5);
        assert_eq!(levenshtein("", ""), 0);
    }

    // Classic textbook example: 2 substitutions + 1 insertion.
    #[test]
    fn test_levenshtein_kitten_sitting() {
        assert_eq!(levenshtein("kitten", "sitting"), 3);
    }

    #[test]
    fn test_levenshtein_similarity() {
        let sim = levenshtein_similarity("hello", "hello");
        assert!((sim - 1.0).abs() < 1e-10);
        let sim2 = levenshtein_similarity("abc", "xyz");
        assert!((sim2 - 0.0).abs() < 1e-10);
    }

    // -- Jaccard n-gram similarity ----------------------------------------

    #[test]
    fn test_jaccard_identical() {
        let sim = jaccard_ngram_similarity("hello", "hello", 2);
        assert!((sim - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_jaccard_disjoint() {
        let sim = jaccard_ngram_similarity("abc", "xyz", 2);
        assert!((sim - 0.0).abs() < 1e-10);
    }

    // -- N-gram extraction ------------------------------------------------

    #[test]
    fn test_char_ngrams() {
        let grams = char_ngrams("hello", 2);
        assert_eq!(grams["he"], 1);
        assert_eq!(grams["el"], 1);
        assert_eq!(grams["ll"], 1);
        assert_eq!(grams["lo"], 1);
        assert_eq!(grams.len(), 4);
    }

    #[test]
    fn test_word_ngrams() {
        let grams = word_ngrams("the quick brown fox", 2);
        assert_eq!(grams["the quick"], 1);
        assert_eq!(grams["quick brown"], 1);
        assert_eq!(grams["brown fox"], 1);
        assert_eq!(grams.len(), 3);
    }

    // -- Tokenization -----------------------------------------------------

    // Spans are byte offsets; leading/trailing whitespace produces no span.
    #[test]
    fn test_tokenize_whitespace() {
        let spans = tokenize_whitespace("  hello   world  ");
        assert_eq!(spans, vec![(2, 7), (10, 15)]);
    }

    #[test]
    fn test_tokenize_words() {
        let tokens = tokenize_words("Hello, world! (test)");
        assert_eq!(tokens, vec!["Hello", ",", "world", "!", "(", "test", ")"]);
    }

    #[test]
    fn test_ascii_lowercase() {
        assert_eq!(ascii_lowercase("Hello WORLD"), "hello world");
    }

    #[test]
    fn test_strip_punctuation() {
        assert_eq!(strip_punctuation("hello, world!"), "hello world");
    }

    // -- Term frequency / cosine similarity -------------------------------

    #[test]
    fn test_term_frequency() {
        let tf = term_frequency("the cat sat on the mat");
        assert!((tf["the"] - 2.0 / 6.0).abs() < 1e-10);
        assert!((tf["cat"] - 1.0 / 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let tf = term_frequency("hello world");
        let sim = cosine_similarity(&tf, &tf);
        assert!((sim - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = term_frequency("cat dog");
        let b = term_frequency("fish bird");
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 0.0).abs() < 1e-10);
    }

    // Repeated runs must give identical results (BTreeMap/BTreeSet ordering).
    #[test]
    fn test_determinism() {
        for _ in 0..10 {
            assert_eq!(levenshtein("kitten", "sitting"), 3);
            let grams = char_ngrams("deterministic", 3);
            assert_eq!(grams.len(), 11);
        }
    }
}