Skip to main content

simd_normalizer/
casefold.rs

1//! Unicode simple case folding (CaseFolding.txt, status C+S).
2//!
3//! Provides character-level and string-level case folding for case-insensitive
4//! matching. Supports both standard folding and Turkish/Azerbaijani locale mode.
5
6use alloc::borrow::Cow;
7use alloc::string::String;
8
9use crate::tables;
10
/// Selects which rule set the case-folding routines apply.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum CaseFoldMode {
    /// Locale-independent Unicode case folding (CaseFolding.txt, status C+S).
    Standard,
    /// Folding tailored for Turkish and Azerbaijani text.
    ///
    /// Differs from [`CaseFoldMode::Standard`] only for the dotted/dotless
    /// `I` pair:
    /// - U+0049 (I) → U+0131 (ı) instead of U+0069 (i)
    /// - U+0130 (İ) → U+0069 (i) instead of the standard mapping
    Turkish,
}
23
24/// Fold a single character using simple case folding.
25///
26/// Returns the folded character, or the input character unchanged if no
27/// folding applies.
28#[inline]
29pub fn casefold_char(c: char, mode: CaseFoldMode) -> char {
30    // Turkish exceptions override the standard mapping.
31    if mode == CaseFoldMode::Turkish
32        && let Some(folded) = tables::turkish_casefold(c)
33    {
34        return folded;
35    }
36    tables::lookup_casefold(c).unwrap_or(c)
37}
38
39/// Fold a string using simple case folding.
40///
41/// Returns `Cow::Borrowed` if the string is already fully case-folded
42/// (no characters changed).
43///
44/// In `CaseFoldMode::Standard`, the implementation runs an ASCII fast path:
45/// 64-byte chunks are scanned via the dispatched SIMD scanner with bound
46/// `0x80`. Chunks with zero non-ASCII bytes get a scalar `mask | 0x20`-style
47/// lowercase pass (no per-byte trie lookup), which dominates throughput on
48/// ASCII / Latin-1 inputs. Non-ASCII chunks fall back to the per-codepoint
49/// trie-driven path. Other modes (Turkish, etc.) skip the fast path because
50/// their override rules apply within the ASCII range.
51pub fn casefold<'a>(input: &'a str, mode: CaseFoldMode) -> Cow<'a, str> {
52    if input.is_empty() {
53        return Cow::Borrowed(input);
54    }
55
56    if mode == CaseFoldMode::Standard {
57        casefold_ascii_fastpath(input)
58    } else {
59        casefold_scalar(input, mode)
60    }
61}
62
63/// Scalar fallback used by both the non-Standard modes and the ASCII fast
64/// path's tail / non-ASCII region. Walks codepoints through the casefold
65/// trie; returns `Cow::Borrowed` if nothing changed.
66fn casefold_scalar<'a>(input: &'a str, mode: CaseFoldMode) -> Cow<'a, str> {
67    // Quick scan: find first character that would change.
68    let mut scan_iter = input.char_indices();
69    let first_change = loop {
70        match scan_iter.next() {
71            None => return Cow::Borrowed(input),
72            Some((idx, ch)) => {
73                let folded = casefold_char(ch, mode);
74                if folded != ch {
75                    break idx;
76                }
77            },
78        }
79    };
80
81    // Build the output: copy unchanged prefix, then fold the rest.
82    let mut out = String::with_capacity(input.len());
83    out.push_str(&input[..first_change]);
84
85    for ch in input[first_change..].chars() {
86        out.push(casefold_char(ch, mode));
87    }
88
89    Cow::Owned(out)
90}
91
92/// Standard-mode casefold with a 64-byte SIMD-driven ASCII fast path.
93///
94/// We walk the input in 64-byte chunks, scanning with `bound = 0x80` to
95/// detect any non-ASCII byte. ASCII-only chunks are lowercased via the
96/// scalar `0x41..=0x5A → +0x20` rule (no trie lookup, single byte per
97/// position). The first chunk containing a non-ASCII byte switches the
98/// remainder of the input over to the per-codepoint trie-driven path.
99fn casefold_ascii_fastpath<'a>(input: &'a str) -> Cow<'a, str> {
100    let bytes = input.as_bytes();
101    let len = bytes.len();
102    let ptr = bytes.as_ptr();
103
104    // First scan: locate the first non-ASCII byte (if any) and the first
105    // ASCII uppercase byte (if any), to decide whether allocation is needed.
106    let mut pos = 0usize;
107    let mut first_change: Option<usize> = None;
108
109    // SIMD-driven ASCII probe: 64-byte chunks scanning for any byte >= 0x80.
110    while pos + 64 <= len {
111        // SAFETY: `pos + 64 <= len`, so the pointer is valid for 64 bytes.
112        let nonascii = unsafe { crate::simd::scan_chunk(ptr.add(pos), 0x80) };
113        if nonascii != 0 {
114            // Non-ASCII somewhere in this chunk — break out and delegate to
115            // the scalar path for the entire input. Trying to splice the
116            // ASCII prefix here would not save work (the scalar path's own
117            // pre-scan does the same prefix detection in tight scalar code).
118            return casefold_scalar(input, CaseFoldMode::Standard);
119        }
120        // Pure-ASCII chunk: probe for an uppercase byte to decide whether
121        // we even need to allocate. We use a second SIMD scan with bound
122        // `0x41` (`'A'`) and refine in scalar if any byte >= 'A' exists,
123        // since 'A'..='Z' is a tiny window inside [0x41, 0x80).
124        let upper_or_more = unsafe { crate::simd::scan_chunk(ptr.add(pos), b'A') };
125        if upper_or_more != 0 {
126            // Some byte is >= 'A'. Find the first byte that is uppercase ASCII.
127            let mut mask = upper_or_more;
128            while mask != 0 {
129                let bit = mask.trailing_zeros() as usize;
130                mask &= mask.wrapping_sub(1);
131                let b = bytes[pos + bit];
132                if b.is_ascii_uppercase() {
133                    first_change = Some(pos + bit);
134                    break;
135                }
136            }
137            if first_change.is_some() {
138                break;
139            }
140        }
141        pos += 64;
142    }
143
144    // Tail (or whole input if it's < 64 bytes): scan byte-by-byte for the
145    // first uppercase ASCII or any non-ASCII byte.
146    if first_change.is_none() {
147        let mut tail = pos;
148        while tail < len {
149            let b = bytes[tail];
150            if b >= 0x80 {
151                // Hit a non-ASCII byte before finding any uppercase: defer to
152                // scalar (it will re-scan, but the input is by definition
153                // not pure ASCII so the SIMD fast path is exhausted anyway).
154                return casefold_scalar(input, CaseFoldMode::Standard);
155            }
156            if b.is_ascii_uppercase() {
157                first_change = Some(tail);
158                break;
159            }
160            tail += 1;
161        }
162    }
163
164    let Some(start) = first_change else {
165        // Pure ASCII, no uppercase: borrowed.
166        return Cow::Borrowed(input);
167    };
168
169    // We have a definite change at `start`. Build the output:
170    //   - copy bytes [0, start) verbatim
171    //   - lowercase bytes [start, ?) in scalar: `b | 0x20` for 'A'..='Z',
172    //     copy others, until we hit either end-of-input or a non-ASCII byte.
173    //   - if we hit a non-ASCII byte, append the per-codepoint folded tail.
174    let mut out = String::with_capacity(len);
175    // SAFETY: bytes [0, start) are pure ASCII (we only walked past them
176    // when no byte was >= 0x80), so they are valid UTF-8.
177    out.push_str(unsafe { core::str::from_utf8_unchecked(&bytes[..start]) });
178
179    let mut i = start;
180    while i < len {
181        let b = bytes[i];
182        if b >= 0x80 {
183            // Switch to per-codepoint fallback for the rest of the input.
184            // SAFETY: `i` is on a UTF-8 boundary because we only advanced
185            // through ASCII bytes (each 1 byte wide) up to this point.
186            let rest = unsafe { core::str::from_utf8_unchecked(&bytes[i..]) };
187            for ch in rest.chars() {
188                out.push(casefold_char(ch, CaseFoldMode::Standard));
189            }
190            return Cow::Owned(out);
191        }
192        if b.is_ascii_uppercase() {
193            // Lowercase via OR with 0x20.
194            out.push((b | 0x20) as char);
195        } else {
196            out.push(b as char);
197        }
198        i += 1;
199    }
200    Cow::Owned(out)
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Character-level tests ----

    #[test]
    fn fold_ascii_uppercase() {
        assert_eq!(casefold_char('A', CaseFoldMode::Standard), 'a');
        assert_eq!(casefold_char('Z', CaseFoldMode::Standard), 'z');
    }

    #[test]
    fn fold_ascii_lowercase_unchanged() {
        assert_eq!(casefold_char('a', CaseFoldMode::Standard), 'a');
        assert_eq!(casefold_char('z', CaseFoldMode::Standard), 'z');
    }

    #[test]
    fn fold_digit_unchanged() {
        assert_eq!(casefold_char('0', CaseFoldMode::Standard), '0');
        assert_eq!(casefold_char('9', CaseFoldMode::Standard), '9');
    }

    #[test]
    fn fold_latin_extended() {
        // U+00C0 À → U+00E0 à
        assert_eq!(
            casefold_char('\u{00C0}', CaseFoldMode::Standard),
            '\u{00E0}'
        );
        // U+00D6 Ö → U+00F6 ö
        assert_eq!(
            casefold_char('\u{00D6}', CaseFoldMode::Standard),
            '\u{00F6}'
        );
    }

    #[test]
    fn fold_greek() {
        // U+0391 Α → U+03B1 α
        assert_eq!(
            casefold_char('\u{0391}', CaseFoldMode::Standard),
            '\u{03B1}'
        );
        // U+03A3 Σ → U+03C3 σ
        assert_eq!(
            casefold_char('\u{03A3}', CaseFoldMode::Standard),
            '\u{03C3}'
        );
    }

    #[test]
    fn fold_cyrillic() {
        // U+0410 А → U+0430 а
        assert_eq!(
            casefold_char('\u{0410}', CaseFoldMode::Standard),
            '\u{0430}'
        );
    }

    #[test]
    fn fold_micro_sign() {
        // U+00B5 µ (MICRO SIGN) → U+03BC μ (GREEK SMALL LETTER MU)
        assert_eq!(
            casefold_char('\u{00B5}', CaseFoldMode::Standard),
            '\u{03BC}'
        );
    }

    #[test]
    fn fold_sharp_s() {
        // U+1E9E ẞ (LATIN CAPITAL LETTER SHARP S) → U+00DF ß
        assert_eq!(
            casefold_char('\u{1E9E}', CaseFoldMode::Standard),
            '\u{00DF}'
        );
    }

    // ---- Turkish mode ----

    #[test]
    fn fold_turkish_dotless_i() {
        // Standard: I → i
        assert_eq!(casefold_char('I', CaseFoldMode::Standard), 'i');
        // Turkish: I → ı (U+0131)
        assert_eq!(casefold_char('I', CaseFoldMode::Turkish), '\u{0131}');
    }

    #[test]
    fn fold_turkish_dotted_capital_i() {
        // Turkish: İ (U+0130) → i
        assert_eq!(casefold_char('\u{0130}', CaseFoldMode::Turkish), 'i');
    }

    #[test]
    fn fold_turkish_other_chars_unchanged() {
        // Non-I characters should fold the same in Turkish mode.
        assert_eq!(casefold_char('A', CaseFoldMode::Turkish), 'a');
        assert_eq!(casefold_char('a', CaseFoldMode::Turkish), 'a');
    }

    // ---- String-level tests ----

    #[test]
    fn fold_string_ascii() {
        let result = casefold("Hello World", CaseFoldMode::Standard);
        assert_eq!(&*result, "hello world");
    }

    #[test]
    fn fold_string_already_folded() {
        let result = casefold("hello world", CaseFoldMode::Standard);
        assert!(matches!(result, Cow::Borrowed(_)));
        assert_eq!(&*result, "hello world");
    }

    #[test]
    fn fold_string_empty() {
        let result = casefold("", CaseFoldMode::Standard);
        assert!(matches!(result, Cow::Borrowed(_)));
    }

    #[test]
    fn fold_string_mixed() {
        let result = casefold("Ströme", CaseFoldMode::Standard);
        assert_eq!(&*result, "ströme");
    }

    #[test]
    fn fold_string_turkish() {
        let result = casefold("Istanbul", CaseFoldMode::Turkish);
        // I → ı in Turkish mode
        assert_eq!(&*result, "\u{0131}stanbul");
    }

    #[test]
    fn fold_string_all_ascii_lowercase() {
        // Should return borrowed.
        let result = casefold(
            "abcdefghijklmnopqrstuvwxyz0123456789",
            CaseFoldMode::Standard,
        );
        assert!(matches!(result, Cow::Borrowed(_)));
    }

    // ---- Long inputs (>= 64 bytes) ----
    //
    // The Standard-mode fast path only enters its 64-byte chunk loop for
    // inputs of at least 64 bytes, so these cases cover what the short
    // strings above cannot reach.

    /// Reference implementation: fold every character individually.
    fn fold_per_char(input: &str) -> String {
        input
            .chars()
            .map(|c| casefold_char(c, CaseFoldMode::Standard))
            .collect()
    }

    #[test]
    fn fold_long_ascii_lowercase_borrowed() {
        // 128 bytes = two full chunks, letters present but none uppercase.
        let input = "abcdefgh".repeat(16);
        let result = casefold(&input, CaseFoldMode::Standard);
        assert!(matches!(result, Cow::Borrowed(_)));
        assert_eq!(&*result, input.as_str());
    }

    #[test]
    fn fold_long_ascii_without_letters_borrowed() {
        // 70 bytes, every byte below 'A'.
        let input = "0123456789".repeat(7);
        let result = casefold(&input, CaseFoldMode::Standard);
        assert!(matches!(result, Cow::Borrowed(_)));
        assert_eq!(&*result, input.as_str());
    }

    #[test]
    fn fold_long_ascii_uppercase_in_first_chunk() {
        let mut input = String::from("Hello World ");
        input.push_str(&"x".repeat(80));
        let mut expected = String::from("hello world ");
        expected.push_str(&"x".repeat(80));

        let result = casefold(&input, CaseFoldMode::Standard);
        assert_eq!(&*result, expected.as_str());
    }

    #[test]
    fn fold_long_ascii_uppercase_in_second_chunk() {
        // 130 lowercase bytes, then uppercase inside the second full chunk
        // region once the total length reaches 192 bytes.
        let mut input = "a".repeat(70);
        input.push_str("Hello World");
        input.push_str(&"b".repeat(120));
        let mut expected = "a".repeat(70);
        expected.push_str("hello world");
        expected.push_str(&"b".repeat(120));

        let result = casefold(&input, CaseFoldMode::Standard);
        assert_eq!(&*result, expected.as_str());
    }

    #[test]
    fn fold_long_ascii_uppercase_in_tail() {
        // One full lowercase chunk followed by a 3-byte uppercase tail.
        let mut input = "a".repeat(64);
        input.push_str("XYZ");
        let mut expected = "a".repeat(64);
        expected.push_str("xyz");

        let result = casefold(&input, CaseFoldMode::Standard);
        assert_eq!(&*result, expected.as_str());
    }

    #[test]
    fn fold_long_uppercase_then_non_ascii() {
        // Uppercase ASCII first, non-ASCII later: the ASCII run is
        // lowercased and the rest is folded per codepoint.
        let mut input = String::from("Q");
        input.push_str(&"a".repeat(70));
        input.push_str("\u{00D6}Z");
        let mut expected = String::from("q");
        expected.push_str(&"a".repeat(70));
        expected.push_str("\u{00F6}z");

        let result = casefold(&input, CaseFoldMode::Standard);
        assert_eq!(&*result, expected.as_str());
    }

    #[test]
    fn fold_long_non_ascii_already_folded_borrowed() {
        // 40 × U+00E9 (2 bytes each) = 80 bytes, nothing to fold.
        let input = "\u{00E9}".repeat(40);
        let result = casefold(&input, CaseFoldMode::Standard);
        assert!(matches!(result, Cow::Borrowed(_)));
        assert_eq!(&*result, input.as_str());
    }

    #[test]
    fn fold_long_non_ascii_before_uppercase() {
        // Non-ASCII appears in the first chunk, uppercase ASCII afterwards.
        let mut input = String::from("\u{00C0}");
        input.push_str(&"m".repeat(70));
        input.push_str("END");
        let mut expected = String::from("\u{00E0}");
        expected.push_str(&"m".repeat(70));
        expected.push_str("end");

        let result = casefold(&input, CaseFoldMode::Standard);
        assert_eq!(&*result, expected.as_str());
    }

    #[test]
    fn fold_long_matches_per_char_reference() {
        // The string-level result must agree with folding each character
        // on its own, whichever internal path is taken.
        let input = "The Quick Brown Fox \u{0391}\u{03A3} Jumps Over 123 ".repeat(5);
        let result = casefold(&input, CaseFoldMode::Standard);
        assert_eq!(&*result, fold_per_char(&input).as_str());
    }
}