uwuifier/
lib.rs

1//! fastest text uwuifier in the west
2#![cfg(any(target_arch = "x86", target_arch = "x86_64"))]
3
4#[cfg(target_arch = "x86")]
5use std::arch::x86::*;
6#[cfg(target_arch = "x86_64")]
7use std::arch::x86_64::*;
8
9use std::{ptr, str};
10
11pub mod rng;
12use rng::XorShift32;
13
14pub mod bitap;
15use bitap::Bitap8x16;
16
/// sixteen bytes of storage with 16-byte alignment, so the contents can be read with the
/// aligned `_mm_load_si128` load (as `emoji_sse` does for emoticon table entries).
#[repr(align(16))]
struct A([u8; 16]);
19
/// round up `n` to the next multiple of 16. useful for allocating buffers
///
/// values that are already a multiple of 16 (including 0) come back unchanged.
///
/// # example:
/// ```
/// use uwuifier::round_up16;
/// assert_eq!(round_up16(17), 32);
/// ```
#[inline]
pub fn round_up16(n: usize) -> usize {
    // bump past the boundary, then clear the low four bits; same as `(n + 15) / 16 * 16`
    (n + 15) & !15
}
29
30#[inline(always)]
31fn pad_zeros(bytes: &mut [u8], len: usize) {
32    for i in len..round_up16(len) {
33        unsafe { *bytes.get_unchecked_mut(i) = 0u8; }
34    }
35}
36
37/// uwuify a string slice
38///
39/// requires the sse4.1 x86 feature
40///
41/// this is probably fine for one-off use, but not very efficient if called multiple times.
42/// use `uwuify_sse` to reduce memory allocations
43///
44/// # example:
45/// ```
46/// use uwuifier::uwuify_str_sse;
47/// assert_eq!(uwuify_str_sse("hello world"), "hewwo wowwd");
48/// ```
49pub fn uwuify_str_sse(s: &str) -> String {
50    let bytes = s.as_bytes();
51    let mut temp1 = vec![0u8; round_up16(bytes.len()) * 16];
52    let mut temp2 = vec![0u8; round_up16(bytes.len()) * 16];
53    unsafe { str::from_utf8_unchecked(uwuify_sse(bytes, &mut temp1, &mut temp2)).to_owned() }
54}
55
56/// uwuify some bytes
57///
58/// requires the sse4.1 x86 feature
59///
60/// `temp_bytes1` and `temp_bytes2` must be buffers of size `round_up16(bytes.len()) * 16`,
61/// because this is the worst-case size of the output. yes, it is annoying to allocate by
62/// hand, but simd :)
63///
64/// the returned slice is the uwu'd result. when working with utf-8 strings, just pass in
65/// the string as raw bytes and convert the output slice back to a string afterwards.
66/// there's also the `uwuify_str_sse` function that is suitable for one-off use with a string slice
67///
68/// # example:
69/// ```
70/// use uwuifier::{uwuify_sse, round_up16};
71/// let s = "hello world";
72/// let b = s.as_bytes();
73/// let mut temp1 = vec![0u8; round_up16(b.len()) * 16];
74/// let mut temp2 = vec![0u8; round_up16(b.len()) * 16];
75/// let res = uwuify_sse(b, &mut temp1, &mut temp2);
76/// assert_eq!(std::str::from_utf8(res).unwrap(), "hewwo wowwd");
77/// ```
78pub fn uwuify_sse<'a>(bytes: &[u8], temp_bytes1: &'a mut [u8], temp_bytes2: &'a mut [u8]) -> &'a [u8] {
79    if !is_x86_feature_detected!("sse4.1") {
80        panic!("sse4.1 feature not detected!");
81    }
82    assert!(temp_bytes1.len() >= round_up16(bytes.len()) * 16);
83    assert!(temp_bytes2.len() >= round_up16(bytes.len()) * 16);
84
85    // only the highest quality seed will do
86    let mut rng = XorShift32::new(b"uwu!");
87
88    let mut len = bytes.len();
89
90    unsafe {
91        // bitap_sse will not read past len, unlike the other passes
92        len = bitap_sse(bytes, len, temp_bytes1);
93        pad_zeros(temp_bytes1, len);
94        len = nya_ify_sse(temp_bytes1, len, temp_bytes2);
95        pad_zeros(temp_bytes2, len);
96        len = replace_and_stutter_sse(&mut rng, temp_bytes2, len, temp_bytes1);
97        pad_zeros(temp_bytes1, len);
98        len = emoji_sse(&mut rng, temp_bytes1, len, temp_bytes2);
99        &temp_bytes2[..len]
100    }
101}
102
103#[target_feature(enable = "sse4.1")]
104unsafe fn bitap_sse(in_bytes: &[u8], mut len: usize, out_bytes: &mut [u8]) -> usize {
105    let mut out_ptr = out_bytes.as_mut_ptr();
106    let mut bitap = Bitap8x16::new();
107    let iter_len = len;
108
109    for i in 0..iter_len {
110        let c = *in_bytes.get_unchecked(i);
111        ptr::write(out_ptr, c);
112        out_ptr = out_ptr.add(1);
113
114        if let Some(m) = bitap.next(c) {
115            let replace = _mm_load_si128(m.replace_ptr);
116            _mm_storeu_si128(out_ptr.sub(m.match_len) as *mut __m128i, replace);
117            out_ptr = out_ptr.add(m.replace_len).sub(m.match_len);
118            len = len + m.replace_len - m.match_len;
119            bitap.reset();
120        }
121    }
122
123    len
124}
125
126const fn str_to_bytes(s: &str) -> A {
127    let bytes = s.as_bytes();
128    let mut res = A([0u8; 16]);
129    let mut i = 0;
130
131    while i < bytes.len() {
132        res.0[i] = bytes[i];
133        i += 1;
134    }
135
136    res
137}
138
139// this lookup table needs to be power of two sized
140const LUT_SIZE: usize = 32;
141static LUT: [A; LUT_SIZE] = [
142    str_to_bytes(" rawr x3"),
143    str_to_bytes(" OwO"),
144    str_to_bytes(" UwU"),
145    str_to_bytes(" o.O"),
146    str_to_bytes(" -.-"),
147    str_to_bytes(" >w<"),
148    str_to_bytes(" (⑅˘꒳˘)"),
149    str_to_bytes(" (ꈍᴗꈍ)"),
150    str_to_bytes(" (˘ω˘)"),
151    str_to_bytes(" (U ᵕ U❁)"),
152    str_to_bytes(" σωσ"),
153    str_to_bytes(" òωó"),
154    str_to_bytes(" (///ˬ///✿)"),
155    str_to_bytes(" (U ﹏ U)"),
156    str_to_bytes(" ( ͡o ω ͡o )"),
157    str_to_bytes(" ʘwʘ"),
158    str_to_bytes(" :3"),
159    str_to_bytes(" :3"), // important enough to have twice
160    str_to_bytes(" XD"),
161    str_to_bytes(" nyaa~~"),
162    str_to_bytes(" mya"),
163    str_to_bytes(" >_<"),
164    str_to_bytes(" 😳"),
165    str_to_bytes(" 🥺"),
166    str_to_bytes(" 😳😳😳"),
167    str_to_bytes(" rawr"),
168    str_to_bytes(" ^^"),
169    str_to_bytes(" ^^;;"),
170    str_to_bytes(" (ˆ ﻌ ˆ)♡"),
171    str_to_bytes(" ^•ﻌ•^"),
172    str_to_bytes(" /(^•ω•^)"),
173    str_to_bytes(" (✿oωo)")
174];
175
/// number of bytes in `b` before its first zero byte, or `b.len()` if there is none.
///
/// emoticon table entries are zero-padded to 16 bytes; this recovers each one's real length.
const fn bytes_len(b: &[u8]) -> usize {
    let mut n = 0;
    loop {
        if n == b.len() || b[n] == 0 {
            break n;
        }
        n += 1;
    }
}
185
186const fn get_len(a: &[A]) -> [usize; LUT_SIZE] {
187    let mut res = [0usize; LUT_SIZE];
188    let mut i = 0;
189
190    while i < a.len() {
191        res[i] = bytes_len(&a[i].0);
192        i += 1;
193    }
194
195    res
196}
197
198static LUT_LEN: [usize; LUT_SIZE] = get_len(&LUT);
199
200#[target_feature(enable = "sse4.1")]
201unsafe fn emoji_sse(rng: &mut XorShift32, in_bytes: &[u8], mut len: usize, out_bytes: &mut [u8]) -> usize {
202    let in_ptr = in_bytes.as_ptr();
203    let mut out_ptr = out_bytes.as_mut_ptr();
204
205    let splat_period = _mm_set1_epi8(b'.' as i8);
206    let splat_comma = _mm_set1_epi8(b',' as i8);
207    let splat_exclamation = _mm_set1_epi8(b'!' as i8);
208    let splat_space = _mm_set1_epi8(b' ' as i8);
209    let splat_tab = _mm_set1_epi8(b'\t' as i8);
210    let splat_newline = _mm_set1_epi8(b'\n' as i8);
211    let indexes = _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
212
213    let lut_bits = LUT.len().trailing_zeros() as u32;
214
215    let iter_len = round_up16(len);
216
217    for i in (0..iter_len).step_by(16) {
218        let vec = _mm_loadu_si128(in_ptr.add(i) as *const __m128i);
219        let mut punctuation_mask = _mm_or_si128(
220            _mm_cmpeq_epi8(vec, splat_comma),
221            _mm_or_si128(_mm_cmpeq_epi8(vec, splat_period), _mm_cmpeq_epi8(vec, splat_exclamation))
222        );
223        // multiple punctuation in a row means no emoji
224        let mut multiple_mask = _mm_and_si128(punctuation_mask, _mm_slli_si128(punctuation_mask, 1));
225        multiple_mask = _mm_or_si128(multiple_mask, _mm_srli_si128(multiple_mask, 1));
226        // punctuation must be followed by a space or else no emoji
227        let space_mask = _mm_or_si128(
228            _mm_cmpeq_epi8(vec, splat_space),
229            _mm_or_si128(_mm_cmpeq_epi8(vec, splat_tab), _mm_cmpeq_epi8(vec, splat_newline))
230        );
231        punctuation_mask = _mm_andnot_si128(
232            multiple_mask,
233            _mm_and_si128(punctuation_mask, _mm_srli_si128(space_mask, 1))
234        );
235        let insert_mask = _mm_movemask_epi8(punctuation_mask) as u32;
236
237        _mm_storeu_si128(out_ptr as *mut __m128i, vec);
238
239        // be lazy and only allow one emoji per vector
240        if insert_mask != 0 {
241            let insert_idx = insert_mask.trailing_zeros() as usize + 1;
242            let rand_idx = rng.gen_bits(lut_bits) as usize;
243            let insert = LUT.get_unchecked(rand_idx);
244            let insert_len = *LUT_LEN.get_unchecked(rand_idx);
245            let insert_vec = _mm_load_si128(insert.0.as_ptr() as *const __m128i);
246            _mm_storeu_si128(out_ptr.add(insert_idx) as *mut __m128i, insert_vec);
247
248            // shuffle to shift right by amount only known at runtime
249            let rest_vec = _mm_shuffle_epi8(vec, _mm_add_epi8(indexes, _mm_set1_epi8(insert_idx as i8)));
250            _mm_storeu_si128(out_ptr.add(insert_idx + insert_len) as *mut __m128i, rest_vec);
251            out_ptr = out_ptr.add(insert_len);
252            len += insert_len;
253        }
254
255        out_ptr = out_ptr.add(16);
256    }
257
258    len
259}
260
261#[target_feature(enable = "sse4.1")]
262unsafe fn nya_ify_sse(in_bytes: &[u8], mut len: usize, out_bytes: &mut [u8]) -> usize {
263    let in_ptr = in_bytes.as_ptr();
264    let mut out_ptr = out_bytes.as_mut_ptr();
265
266    let bit5 = _mm_set1_epi8(0b0010_0000);
267    let splat_n = _mm_set1_epi8(b'n' as i8);
268    let splat_space = _mm_set1_epi8(b' ' as i8);
269    let splat_tab = _mm_set1_epi8(b'\t' as i8);
270    let splat_newline = _mm_set1_epi8(b'\n' as i8);
271    let indexes = _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
272
273    let iter_len = round_up16(len);
274
275    for i in (0..iter_len).step_by(16) {
276        let vec = _mm_loadu_si128(in_ptr.add(i) as *const __m128i);
277        let n_mask = _mm_cmpeq_epi8(_mm_or_si128(vec, bit5), splat_n);
278        let space_mask = _mm_or_si128(
279            _mm_cmpeq_epi8(vec, splat_space),
280            _mm_or_si128(_mm_cmpeq_epi8(vec, splat_tab), _mm_cmpeq_epi8(vec, splat_newline))
281        );
282        // only nya-ify if its space followed by 'n'
283        let space_and_n_mask = _mm_and_si128(_mm_slli_si128(space_mask, 1), n_mask);
284        let mut nya_mask = _mm_movemask_epi8(space_and_n_mask) as u32;
285
286        _mm_storeu_si128(out_ptr as *mut __m128i, vec);
287
288        // try to nya-ify as many as possible in the current vector
289        while nya_mask != 0 {
290            let nya_idx = nya_mask.trailing_zeros() as usize;
291            ptr::write(out_ptr.add(nya_idx + 1), b'y');
292            // shuffle to shift by amount only known at runtime
293            let shifted = _mm_shuffle_epi8(vec, _mm_add_epi8(indexes, _mm_set1_epi8(nya_idx as i8 + 1)));
294            _mm_storeu_si128(out_ptr.add(nya_idx + 2) as *mut __m128i, shifted);
295            out_ptr = out_ptr.add(1);
296            len += 1;
297            nya_mask &= nya_mask - 1;
298        }
299
300        out_ptr = out_ptr.add(16);
301    }
302
303    len
304}
305
306#[target_feature(enable = "sse4.1")]
307unsafe fn replace_and_stutter_sse(rng: &mut XorShift32, in_bytes: &[u8], mut len: usize, out_bytes: &mut [u8]) -> usize {
308    let in_ptr = in_bytes.as_ptr();
309    let mut out_ptr = out_bytes.as_mut_ptr();
310
311    let bit5 = _mm_set1_epi8(0b0010_0000);
312    let splat_backtick = _mm_set1_epi8(b'`' as i8);
313    let splat_open_brace = _mm_set1_epi8(b'{' as i8);
314    let splat_l = _mm_set1_epi8(b'l' as i8);
315    let splat_r = _mm_set1_epi8(b'r' as i8);
316    let splat_w = _mm_set1_epi8(b'w' as i8);
317    let splat_space = _mm_set1_epi8(b' ' as i8);
318    let splat_tab = _mm_set1_epi8(b'\t' as i8);
319    let splat_newline = _mm_set1_epi8(b'\n' as i8);
320    let indexes = _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
321
322    let iter_len = round_up16(len);
323
324    for i in (0..iter_len).step_by(16) {
325        // replace 'l' and 'r' with 'w'
326        let vec = _mm_loadu_si128(in_ptr.add(i) as *const __m128i);
327        let vec_but_lower = _mm_or_si128(vec, bit5);
328        let alpha_mask = _mm_and_si128(_mm_cmpgt_epi8(vec_but_lower, splat_backtick), _mm_cmpgt_epi8(splat_open_brace, vec_but_lower));
329        let replace_mask = _mm_or_si128(_mm_cmpeq_epi8(vec_but_lower, splat_l), _mm_cmpeq_epi8(vec_but_lower, splat_r));
330        let replaced = _mm_blendv_epi8(vec_but_lower, splat_w, replace_mask);
331        // make sure only alphabetical characters are lowercased and replaced
332        let mut res = _mm_blendv_epi8(vec, replaced, alpha_mask);
333
334        // sometimes, add a stutter if there is a space, tab, or newline followed by any letter
335        let space_mask = _mm_or_si128(
336            _mm_cmpeq_epi8(vec, splat_space),
337            _mm_or_si128(_mm_cmpeq_epi8(vec, splat_tab), _mm_cmpeq_epi8(vec, splat_newline))
338        );
339        let space_and_alpha_mask = _mm_and_si128(_mm_slli_si128(space_mask, 1), alpha_mask);
340        let stutter_mask = _mm_movemask_epi8(space_and_alpha_mask) as u32;
341
342        _mm_storeu_si128(out_ptr as *mut __m128i, res);
343
344        if stutter_mask != 0 {
345            let stutter_idx = stutter_mask.trailing_zeros() as usize;
346            // shuffle to shift by amount only known at runtime
347            res = _mm_shuffle_epi8(res, _mm_add_epi8(indexes, _mm_set1_epi8(stutter_idx as i8)));
348            _mm_storeu_si128(out_ptr.add(stutter_idx) as *mut __m128i, _mm_insert_epi8(res, b'-' as i32, 1));
349            // decide whether to stutter in a branchless way
350            // a branch would mispredict often since this is random
351            let increment = if rng.gen_bool() { 2 } else { 0 };
352            _mm_storeu_si128(out_ptr.add(stutter_idx + increment) as *mut __m128i, res);
353            out_ptr = out_ptr.add(increment);
354            len += increment;
355        }
356
357        out_ptr = out_ptr.add(16);
358    }
359
360    len
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// shared fixture: the output is deterministic because the rng seed is fixed
    const INPUT: &str = "Hey, I think I really love you. Do you want a headpat?";
    const EXPECTED: &str = "hey, (ꈍᴗꈍ) i think i weawwy wuv you. ^•ﻌ•^ do y-you want a headpat?";

    #[test]
    fn test_uwuify_sse() {
        // 1024 == round_up16(INPUT.len()) * 16, the minimum scratch size
        let mut scratch_a = [0u8; 1024];
        let mut scratch_b = [0u8; 1024];

        let out = uwuify_sse(INPUT.as_bytes(), &mut scratch_a, &mut scratch_b);
        assert_eq!(std::str::from_utf8(out).unwrap(), EXPECTED);
    }

    #[test]
    fn test_uwuify_str_sse() {
        assert_eq!(uwuify_str_sse(INPUT), EXPECTED);
    }
}