whichlang/
lib.rs

1pub use crate::weights::{Lang, LANGUAGES};
2
3#[allow(clippy::all)]
4mod weights;
5
6const NUM_LANGUAGES: usize = LANGUAGES.len();
7
/// Number of hash buckets: feature hashes are reduced modulo this value (2^12).
#[doc(hidden)]
pub const DIMENSION: usize = 4096;
/// Keeps the two most recent bytes (low 16 bits) of the packed n-gram window.
const BIGRAM_MASK: u32 = 0x0000_FFFF;
/// Keeps the three most recent bytes (low 24 bits) of the packed n-gram window.
const TRIGRAM_MASK: u32 = 0x00FF_FFFF;
12
/// A single feature extracted from the input text.
#[doc(hidden)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Feature {
    /// Lowercased ASCII characters packed one byte each into a `u32`,
    /// with the most recent character in the lowest byte.
    AsciiNGram(u32),
    /// A non-ASCII character, hashed by its code point divided by 128.
    Unicode(char),
    /// A non-ASCII character, hashed by the class index `classify_codepoint` assigns to it.
    UnicodeClass(char),
}
20
/// Base seed for feature hashing; `Feature::to_hash` XORs small constants into it so that
/// each feature kind is hashed with a different seed.
const SEED: u32 = 3_242_157_231u32;

/// Mixes one 32-bit key `k` with `seed` and returns a 32-bit hash.
///
/// All multiplications wrap on overflow. The exact bit pattern matters: callers reduce the
/// result modulo `DIMENSION` to pick a row of the weight table.
#[inline(always)]
fn murmurhash2(k: u32, seed: u32) -> u32 {
    const M: u32 = 0x5bd1_e995;

    // Scramble the key on its own first.
    let mixed_key = {
        let scaled = k.wrapping_mul(M);
        (scaled ^ (scaled >> 24)).wrapping_mul(M)
    };

    // Fold the scrambled key into the seeded state, then run the final avalanche steps.
    let state = seed.wrapping_mul(M) ^ mixed_key;
    let state = (state ^ (state >> 13)).wrapping_mul(M);
    state ^ (state >> 15)
}
36
37impl Feature {
38    #[inline(always)]
39    pub fn to_hash(&self) -> u32 {
40        match self {
41            Feature::AsciiNGram(ngram) => murmurhash2(*ngram, SEED),
42            Feature::Unicode(chr) => murmurhash2(*chr as u32 / 128, SEED ^ 2),
43            Feature::UnicodeClass(chr) => murmurhash2(classify_codepoint(*chr), SEED ^ 4),
44        }
45    }
46}
47
48pub fn detect_language(text: &str) -> Lang {
49    let mut scores: [f32; NUM_LANGUAGES] = Default::default();
50    let mut num_features: u32 = 0;
51    emit_tokens(
52        text,
53        #[inline(always)]
54        |token| {
55            num_features += 1u32;
56            let bucket = token.to_hash() % DIMENSION as u32;
57            let idx = bucket as usize * NUM_LANGUAGES;
58            let per_language_scores = &weights::WEIGHTS[idx..idx + NUM_LANGUAGES];
59            for i in 0..NUM_LANGUAGES {
60                scores[i] += per_language_scores[i];
61            }
62        },
63    );
64    if num_features == 0 {
65        // By default, we return English
66        return Lang::Eng;
67    }
68
69    let sqrt_inv_num_features = 1.0f32 / (num_features as f32).sqrt();
70    #[allow(clippy::needless_range_loop)]
71    for i in 0..NUM_LANGUAGES {
72        // Ok so the sqrt(num_features) is not really the norm, but whatever.
73        scores[i] = scores[i] * sqrt_inv_num_features + weights::INTERCEPTS[i];
74    }
75
76    let lang_id = scores
77        .iter()
78        .enumerate()
79        .max_by(|(_, &score_left), (_, &score_right)| score_left.partial_cmp(&score_right).unwrap())
80        .map(|(pos, _val)| pos)
81        .unwrap();
82    weights::LANGUAGES[lang_id]
83}
84
85#[doc(hidden)]
86pub fn emit_tokens(text: &str, mut listener: impl FnMut(Feature)) {
87    let mut prev = ' ' as u32;
88    let mut num_previous_ascii_chr = 1;
89    for chr in text.chars() {
90        let code = chr.to_ascii_lowercase() as u32;
91        if !chr.is_ascii() {
92            listener(Feature::Unicode(chr));
93            listener(Feature::UnicodeClass(chr));
94            num_previous_ascii_chr = 0;
95            continue;
96        }
97        prev = prev << 8 | code;
98        match num_previous_ascii_chr {
99            0 => {
100                num_previous_ascii_chr = 1;
101            }
102            1 => {
103                listener(Feature::AsciiNGram(prev & BIGRAM_MASK));
104                num_previous_ascii_chr = 2;
105            }
106            2 => {
107                listener(Feature::AsciiNGram(prev & BIGRAM_MASK));
108                listener(Feature::AsciiNGram(prev & TRIGRAM_MASK));
109                num_previous_ascii_chr = 3;
110            }
111            3 => {
112                listener(Feature::AsciiNGram(prev & BIGRAM_MASK));
113                listener(Feature::AsciiNGram(prev & TRIGRAM_MASK));
114                listener(Feature::AsciiNGram(prev));
115            }
116            _ => {
117                unreachable!();
118            }
119        }
120        if !chr.is_alphanumeric() {
121            prev = ' ' as u32;
122        }
123    }
124}
125
// Named boundaries of the script ranges used as class separators by `classify_codepoint`.
// The values feed the feature hash, so they are kept exactly as they are.
const JP_PUNCT_START: u32 = 0x3000;
const JP_PUNCT_END: u32 = 0x303F;
const JP_HIRAGANA_START: u32 = 0x3040;
const JP_HIRAGANA_END: u32 = 0x309F;
const JP_KATAKANA_START: u32 = 0x30A0;
const JP_KATAKANA_END: u32 = 0x30FF;
const CJK_KANJI_START: u32 = 0x4E00;
const CJK_KANJI_END: u32 = 0x9FAF;
const JP_HALFWIDTH_KATAKANA_START: u32 = 0xFF61;
const JP_HALFWIDTH_KATAKANA_END: u32 = 0xFF90;

/// Strictly increasing code-point boundaries that split the code space into classes.
/// The binary search in `classify_codepoint` relies on this ordering.
const CODEPOINT_CLASS_BOUNDARIES: &[u32] = &[
    160,
    161,
    171,
    172,
    173,
    174,
    187,
    192,
    196,
    199,
    200,
    201,
    202,
    205,
    214,
    220,
    223,
    224,
    225,
    226,
    227,
    228,
    231,
    232,
    233,
    234,
    235,
    236,
    237,
    238,
    239,
    242,
    243,
    244,
    245,
    246,
    249,
    250,
    251,
    252,
    333,
    339,
    JP_PUNCT_START,
    JP_PUNCT_END,
    JP_HIRAGANA_START,
    JP_HIRAGANA_END,
    JP_KATAKANA_START,
    JP_KATAKANA_END,
    CJK_KANJI_START,
    CJK_KANJI_END,
    JP_HALFWIDTH_KATAKANA_START,
    JP_HALFWIDTH_KATAKANA_END,
];

/// Returns the class index of `chr`: the number of boundaries strictly below its code point.
///
/// A code point equal to a boundary gets that boundary's index, and a code point between two
/// boundaries gets the index of the upper one. Anything above the last boundary gets the
/// table length.
fn classify_codepoint(chr: char) -> u32 {
    let code_point = chr as u32;
    // An exact hit and an insertion point are deliberately treated the same way.
    match CODEPOINT_CLASS_BOUNDARIES.binary_search(&code_point) {
        Ok(index) | Err(index) => index as u32,
    }
}
195
#[cfg(test)]
mod tests {
    use crate::detect_language;
    use crate::emit_tokens;
    use crate::Feature;
    use crate::Lang;

    /// Builds the `Feature::AsciiNGram` for an ASCII string of at most four bytes: the bytes
    /// are right-aligned in a big-endian `u32`, so the last character sits in the lowest byte.
    fn ascii_ngram_feature(text: &str) -> Feature {
        assert!(text.is_ascii());
        let mut bytes: [u8; 4] = [0u8; 4];
        assert!(text.len() <= 4);
        bytes[4 - text.len()..].copy_from_slice(text.as_bytes());
        Feature::AsciiNGram(u32::from_be_bytes(bytes))
    }

    #[test]
    fn test_emit_tokens() {
        // The separator and the final punctuation mark are the fullwidth forms U+3000
        // (ideographic space) and U+FF01 (fullwidth exclamation mark). They are written as
        // escapes so they cannot be mistaken for, or normalised into, their ASCII look-alikes:
        // `emit_tokens` only emits `Unicode`/`UnicodeClass` features for non-ASCII characters.
        let mut tokens = Vec::new();
        emit_tokens("hello\u{3000}こん\u{ff01}", |token| tokens.push(token));
        assert_eq!(
            &tokens,
            &[
                ascii_ngram_feature(" h"),
                ascii_ngram_feature("he"),
                ascii_ngram_feature(" he"),
                ascii_ngram_feature("el"),
                ascii_ngram_feature("hel"),
                ascii_ngram_feature(" hel"),
                ascii_ngram_feature("ll"),
                ascii_ngram_feature("ell"),
                ascii_ngram_feature("hell"),
                ascii_ngram_feature("lo"),
                ascii_ngram_feature("llo"),
                ascii_ngram_feature("ello"),
                Feature::Unicode('\u{3000}'),
                Feature::UnicodeClass('\u{3000}'),
                Feature::Unicode('こ'),
                Feature::UnicodeClass('こ'),
                Feature::Unicode('ん'),
                Feature::UnicodeClass('ん'),
                Feature::Unicode('\u{ff01}'),
                Feature::UnicodeClass('\u{ff01}'),
            ]
        );
    }

    #[test]
    fn test_empty_str() {
        // No features at all: `detect_language` falls back to English.
        assert_eq!(detect_language(""), Lang::Eng);
    }

    #[test]
    fn test_detect_language() {
        // English
        assert_eq!(detect_language("Hello, happy tax payer"), Lang::Eng);
        // French
        assert_eq!(detect_language("Bonjour joyeux contribuable"), Lang::Fra);
        // German
        assert_eq!(detect_language("Hallo glücklicher Steuerzahler"), Lang::Deu);
        // Japanese
        assert_eq!(detect_language("こんにちは幸せな税金納め"), Lang::Jpn);
        // Mandarin chinese
        assert_eq!(detect_language("你好幸福的纳税人"), Lang::Cmn);
        // Turkish
        assert_eq!(detect_language("Merhaba, mutlu vergi mükellefi"), Lang::Tur);
        // Dutch
        assert_eq!(detect_language("Hallo, blije belastingbetaler"), Lang::Nld);
        // Korean
        assert_eq!(detect_language("안녕하세요 행복한 납세자입니다"), Lang::Kor);
        // Italian
        assert_eq!(detect_language("Ciao, felice contribuente!"), Lang::Ita);
        // Spanish
        assert_eq!(detect_language("Hola feliz contribuyente"), Lang::Spa);
        assert_eq!(detect_language("¡Hola!"), Lang::Spa);
        // Portuguese
        assert_eq!(detect_language("Olá feliz contribuinte"), Lang::Por);
    }
}