tac_k_lib/
lib.rs

use memmap2::Mmap;

use std::fs::File;
use std::io::prelude::*;
use std::io::Result;
use std::path::Path;

const MAX_BUF_SIZE: usize = 4 * 1024 * 1024; // 4 MiB

#[cfg_attr(
    target_family = "unix",
    allow(unreachable_code),
    allow(unused_mut),
    allow(unused_variables)
)]
/// Write the reversed content from `path` into `writer`, last line first.
///
/// If `path` is `Some(_)`, read from the file at the specified path.
/// If `path` is `None`, read from `stdin` instead.
///
/// `separator` is used to partition the content into lines.
/// This is normally the newline character, `b'\n'`.
///
/// Internally it uses the following instruction set extensions
/// to enable SIMD acceleration if available at runtime:
/// - AVX2/LZCNT(ABM)/BMI2 on x86/x86_64
/// - NEON on AArch64
///
/// ## Example
///
/// ```
/// use tac_k::reverse_file;
/// use std::path::Path;
///
/// // Read from `README.md` file, separated by '.'.
/// let mut result = vec![];
/// reverse_file(&mut result, Some("README.md"), b'.').unwrap();
///
/// assert!(std::str::from_utf8(&result).is_ok());
///
/// // Read from stdin.
/// let mut result = vec![];
/// reverse_file(&mut result, None::<&str>, b'.').unwrap();
///
/// assert!(result.is_empty());
/// ```
pub fn reverse_file<W: Write, P: AsRef<Path>>(writer: &mut W, path: Option<P>, separator: u8) -> Result<()> {
    fn inner(writer: &mut dyn Write, path: Option<&Path>, separator: u8) -> Result<()> {
        let mut temp_path = None;
        {
            let mmap;
            let mut buf;
            let bytes = match path {
                #[cfg_attr(not(target_family = "unix"), allow(unused_labels))]
                None => 'stdin: {
                    // Depending on what the STDIN fd actually points to, it may still be possible to
                    // mmap the input (e.g. in case of `tac - < foo.txt`).
                    #[cfg(target_family = "unix")]
                    {
                        let stdin = std::io::stdin();
                        if let Ok(stdin) = unsafe { Mmap::map(&stdin) } {
                            mmap = stdin;
                            break 'stdin &mmap[..];
                        }
                    }

                    // We unfortunately need to buffer the entirety of the stdin input first;
                    // we try to do so purely in memory but will switch to a backing file if
                    // the input exceeds MAX_BUF_SIZE.
                    buf = vec![0; MAX_BUF_SIZE];
                    let mut reader = std::io::stdin();
                    let mut total_read = 0;

                    // Once/if we switch to a file-backed buffer, `temp_path` will record its
                    // location so we can remove the file when we're done.
                    loop {
                        let bytes_read = reader.read(&mut buf[total_read..])?;
                        if bytes_read == 0 {
                            break &buf[0..total_read];
                        }
                        total_read += bytes_read;

                        if total_read == MAX_BUF_SIZE {
                            temp_path = Some(std::env::temp_dir().join(format!(".tac-{}", std::process::id())));
                            let mut temp_file = File::create(temp_path.as_ref().unwrap())?;
                            // Write everything we've read so far
                            temp_file.write_all(&buf)?;
                            // Copy remaining bytes directly from stdin
                            std::io::copy(&mut reader, &mut temp_file)?;
                            mmap = unsafe { Mmap::map(&temp_file)? };
                            break &mmap[..];
                        }
                    }
                }
                Some(path) => {
                    let file = File::open(path)?;
                    mmap = unsafe { Mmap::map(&file)? };
                    &mmap[..]
                }
            };

            search_auto(bytes, separator, writer)?;
        }

        if let Some(path) = &temp_path {
            // This should never fail unless we've somehow kept a handle open to it
            if let Err(e) = std::fs::remove_file(path) {
                eprintln!("Error: failed to remove temporary file {}\n{}", path.display(), e)
            };
        }

        writer.flush()?;
        Ok(())
    }
    inner(writer, path.as_ref().map(AsRef::as_ref), separator)
}

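/// Picks the fastest search implementation available at runtime: the AVX2 path
/// (`search256`) on x86/x86_64, the NEON path (`search128`) on AArch64, and the
/// naïve byte-by-byte `search` otherwise.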
fn search_auto(bytes: &[u8], separator: u8, mut output: &mut dyn Write) -> Result<()> {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("lzcnt") && is_x86_feature_detected!("bmi2") {
        return unsafe { search256(bytes, separator, &mut output) };
    }

    #[cfg(target_arch = "aarch64")]
    if std::arch::is_aarch64_feature_detected!("neon") {
        return unsafe { search128(bytes, separator, &mut output) };
    }

    search(bytes, separator, &mut output)
}

/// This is the default, naïve byte search
#[inline(always)]
fn search(bytes: &[u8], separator: u8, output: &mut dyn Write) -> Result<()> {
    let mut last_printed = bytes.len();
    slow_search_and_print(bytes, 0, last_printed, &mut last_printed, separator, output)?;
    output.write_all(&bytes[..last_printed])?;
    Ok(())
}

#[inline(always)]
/// Search a range index-by-index and write to `output` when a match is found. Primarily used to
/// search before/after the aligned portion of a range.
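///
/// For example, with `bytes = b"a.b.c"`, `separator = b'.'`, `start = 0`, `end = 5` and
/// `*stop == 5`, this writes `"c"` then `"b."` and leaves `*stop == 2`, so the caller can
/// finish by writing `bytes[..2]` (`"a."`).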
fn slow_search_and_print(
    bytes: &[u8],
    start: usize,
    end: usize,
    stop: &mut usize,
    separator: u8,
    output: &mut dyn Write,
) -> Result<()> {
    for index in (start..end).rev() {
        if bytes[index] == separator {
            output.write_all(&bytes[index + 1..*stop])?;
            *stop = index + 1;
        }
    }

    Ok(())
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "lzcnt")]
#[target_feature(enable = "bmi2")]
/// This is an AVX2-optimized newline search function that searches a 32-byte (256-bit) window
/// instead of scanning character-by-character (once aligned). This is a *safe* function, but must
/// be adorned with `unsafe` to guarantee it's not called without first checking for AVX2 support.
///
/// We need to explicitly enable lzcnt support for `leading_zeros()` to use the `lzcnt`
/// instruction instead of an extremely slow combination of branching + BSR.
///
/// BMI2 is explicitly opted into to inline the BZHI instruction; otherwise a call to the intrinsic
/// function is added and not inlined.
unsafe fn search256(bytes: &[u8], separator: u8, mut output: &mut dyn Write) -> Result<()> {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    #[cfg(target_arch = "x86")]
    const SIZE: u32 = 32;
    #[cfg(target_arch = "x86_64")]
    const SIZE: u32 = 64;

    const ALIGNMENT: usize = std::mem::align_of::<__m256i>();

    let ptr = bytes.as_ptr();
    let len = bytes.len();
    let mut last_printed = len;
    let mut remaining = len;

    // We should only use 32-byte (256-bit) aligned reads w/ AVX2 intrinsics.
    // Search unaligned bytes via slow method so subsequent haystack reads are always aligned.
    // Guaranteed to have at least two aligned blocks
    if len >= ALIGNMENT * 3 - 1 {
        // Regardless of whether or not the base pointer is aligned to a 32-byte address, we are
        // reading from an arbitrary offset (determined by the length of the lines) and so we must
        // first calculate a safe place to begin using SIMD operations from.
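        // For example, if `ptr + len` is 5 bytes past a 32-byte boundary, `align_offset` comes
        // back as 27 and `aligned_index` is `len - 5`, the last 32-byte-aligned index in `bytes`.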
        let align_offset = unsafe { ptr.add(len) }.align_offset(ALIGNMENT);
        if align_offset != 0 {
            let aligned_index = len + align_offset - ALIGNMENT;
            debug_assert!(aligned_index < len && aligned_index > 0);
            debug_assert!((ptr as usize + aligned_index) % ALIGNMENT == 0);

            // eprintln!("Unoptimized search from {} to {}", aligned_index, last_printed);
            slow_search_and_print(bytes, aligned_index, len, &mut last_printed, separator, &mut output)?;
            remaining = aligned_index;
        } else {
            // `bytes` ends in an aligned block, no need to offset
            debug_assert!((ptr as usize + len) % ALIGNMENT == 0);
        }

        let pattern256 = unsafe { _mm256_set1_epi8(separator as i8) };
        while remaining >= SIZE as usize {
            let window_end_offset = remaining;
            unsafe {
                remaining -= 32;
                let search256 = _mm256_load_si256(ptr.add(remaining) as *const __m256i);
                let result256 = _mm256_cmpeq_epi8(search256, pattern256);
                let part = _mm256_movemask_epi8(result256) as u32;
                let mut matches;

                // For 32-bit x86 architecture only one part can be loaded. 64-bit x86_64 can load another part
                // to find the matches.
                #[cfg(target_arch = "x86")]
                {
                    matches = part;
                }
                #[cfg(target_arch = "x86_64")]
                {
                    remaining -= 32;
                    let search256 = _mm256_load_si256(ptr.add(remaining) as *const __m256i);
                    let result256 = _mm256_cmpeq_epi8(search256, pattern256);
                    matches = ((part as u64) << 32) | _mm256_movemask_epi8(result256) as u32 as u64;
                }

                while matches != 0 {
                    // We would count *trailing* zeroes to find new lines in reverse order, but the
                    // result mask is in little endian (reversed) order, so we do the very
                    // opposite.
                    // core::intrinsics::ctlz() is not stabilized, but `leading_zeros()` will
                    // use it directly if the `lzcnt` feature is enabled.
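                    // At this point bit `k` of `matches` corresponds to `bytes[remaining + k]`,
                    // so the most significant set bit marks the separator closest to the end of
                    // this window.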
                    let leading = matches.leading_zeros();
                    let offset = window_end_offset - leading as usize;

                    output.write_all(&bytes[offset..last_printed])?;
                    last_printed = offset;

                    // Clear this match from the matches bitset.
                    #[cfg(target_arch = "x86")]
                    {
                        matches = _bzhi_u32(matches, SIZE - 1 - leading);
                    }
                    #[cfg(target_arch = "x86_64")]
                    {
                        matches = _bzhi_u64(matches, SIZE - 1 - leading);
                    }
                }
            }
        }
    }

    if remaining != 0 {
        // eprintln!("Unoptimized end search from {} to {}", 0, remaining);
        slow_search_and_print(bytes, 0, remaining, &mut last_printed, separator, &mut output)?;
    }

    // Regardless of whether or not `remaining` is zero, as this is predicated on `last_printed`
    output.write_all(&bytes[..last_printed])?;

    Ok(())
}

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
/// This is a NEON/AdvSIMD-optimized newline search function that searches a 16-byte (128-bit) window
/// instead of scanning character-by-character (once aligned).
unsafe fn search128(bytes: &[u8], separator: u8, mut output: &mut dyn Write) -> Result<()> {
    use core::arch::aarch64::*;

    let ptr = bytes.as_ptr();
    let mut last_printed = bytes.len();
    // Guard against empty input: the `last_printed - 1` below would otherwise underflow.
    if last_printed == 0 {
        return Ok(());
    }
    let mut index = last_printed - 1;

    if index >= 64 {
        // ARMv8 loads do not have alignment *requirements*, but there can be performance penalties
        // (e.g. seems to be about 2% slowdown on Cortex-A72 with a 500MB file) so let's align.
        // Search unaligned bytes via slow method so subsequent haystack reads are always aligned.
        let align_offset = unsafe { ptr.add(index).align_offset(16) };
        let aligned_index = index + align_offset - 16;

        // eprintln!("Unoptimized search from {} to {}", aligned_index, last_printed);
        slow_search_and_print(
            bytes,
            aligned_index,
            last_printed,
            &mut last_printed,
            separator,
            &mut output,
        )?;
        index = aligned_index;

        let pattern128 = unsafe { vdupq_n_u8(separator) };
        while index >= 64 {
            let window_end_offset = index;
            unsafe {
                index -= 16;
                let window = ptr.add(index);
                let search128 = vld1q_u8(window);
                let result128_0 = vceqq_u8(search128, pattern128);

                index -= 16;
                let window = ptr.add(index);
                let search128 = vld1q_u8(window);
                let result128_1 = vceqq_u8(search128, pattern128);

                index -= 16;
                let window = ptr.add(index);
                let search128 = vld1q_u8(window);
                let result128_2 = vceqq_u8(search128, pattern128);

                index -= 16;
                let window = ptr.add(index);
                let search128 = vld1q_u8(window);
                let result128_3 = vceqq_u8(search128, pattern128);

                // Bulk movemask as described in
                // https://branchfree.org/2019/04/01/fitting-my-head-through-the-arm-holes/
                let mut matches = {
                    let bit_mask: uint8x16_t = std::mem::transmute([
                        0x01u8, 0x02, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80,
                    ]);
                    let t0 = vandq_u8(result128_3, bit_mask);
                    let t1 = vandq_u8(result128_2, bit_mask);
                    let t2 = vandq_u8(result128_1, bit_mask);
                    let t3 = vandq_u8(result128_0, bit_mask);
                    let sum0 = vpaddq_u8(t0, t1);
                    let sum1 = vpaddq_u8(t2, t3);
                    let sum0 = vpaddq_u8(sum0, sum1);
                    let sum0 = vpaddq_u8(sum0, sum0);
                    vgetq_lane_u64(vreinterpretq_u64_u8(sum0), 0)
                };

                while matches != 0 {
                    // We would count *trailing* zeroes to find new lines in reverse order, but the
                    // result mask is in little endian (reversed) order, so we do the very
                    // opposite.
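                    // Bit `k` of `matches` corresponds to `bytes[index + k]`, so the most
                    // significant set bit is the separator closest to the end of this 64-byte window.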
                    let leading = matches.leading_zeros();
                    let offset = window_end_offset - leading as usize;

                    output.write_all(&bytes[offset..last_printed])?;
                    last_printed = offset;

                    // Clear this match from the matches bitset.
                    matches &= !(1 << (64 - leading - 1));
                }
            }
        }
    }

    if index != 0 {
        // eprintln!("Unoptimized end search from {} to {}", 0, index);
        slow_search_and_print(bytes, 0, index, &mut last_printed, separator, &mut output)?;
    }

    // Regardless of whether or not `index` is zero, as this is predicated on `last_printed`
    output.write_all(&bytes[0..last_printed])?;

    Ok(())
}

#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(target_os = "linux")]
    #[test]
    fn test_x86_simd() {
        // `search256` is compiled with `target_feature(enable = ...)`, so calling it on a CPU
        // without these extensions would be undefined behavior; skip the test in that case.
        if !(is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("lzcnt")
            && is_x86_feature_detected!("bmi2"))
        {
            return;
        }

        let mut file = File::open("/dev/urandom").unwrap();
        let mut buffer = [0; 1023];
        for _ in 0..100_000 {
            test(&buffer);
            file.read_exact(&mut buffer).unwrap();
        }

        fn test(buf: &[u8]) {
            let mut slow_result = Vec::new();
            let mut simd_result = Vec::new();
            search(buf, b'.', &mut slow_result).unwrap();
            unsafe { search256(buf, b'.', &mut simd_result).unwrap() };
            assert_eq!(slow_result, simd_result);
        }
    }
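
    // A NEON counterpart to the test above, added here as a minimal sketch that reuses the same
    // "compare against the naïve search" strategy; `test_aarch64_simd` is not part of the
    // original test suite. NEON is mandatory on AArch64, so calling the `target_feature`
    // function is sound without a runtime check.
    #[cfg(target_arch = "aarch64")]
    #[cfg(target_os = "linux")]
    #[test]
    fn test_aarch64_simd() {
        let mut file = File::open("/dev/urandom").unwrap();
        let mut buffer = [0; 1023];
        for _ in 0..100_000 {
            test(&buffer);
            file.read_exact(&mut buffer).unwrap();
        }

        fn test(buf: &[u8]) {
            let mut slow_result = Vec::new();
            let mut simd_result = Vec::new();
            search(buf, b'.', &mut slow_result).unwrap();
            unsafe { search128(buf, b'.', &mut simd_result).unwrap() };
            assert_eq!(slow_result, simd_result);
        }
    }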
}