1use memmap2::Mmap;
2
3use std::fs::File;
4use std::io::prelude::*;
5use std::io::Result;
6use std::path::Path;
7
8const MAX_BUF_SIZE: usize = 4 * 1024 * 1024; #[cfg_attr(
11 target_family = "unix",
12 allow(unreachable_code),
13 allow(unused_mut),
14 allow(unused_variables)
15)]
16pub fn reverse_file<W: Write, P: AsRef<Path>>(writer: &mut W, path: Option<P>, separator: u8) -> Result<()> {
48 fn inner(writer: &mut dyn Write, path: Option<&Path>, separator: u8) -> Result<()> {
49 let mut temp_path = None;
50 {
51 let mmap;
52 let mut buf;
53 let bytes = match path {
54 #[cfg_attr(not(target_family = "unix"), allow(unused_labels))]
55 None => 'stdin: {
56 #[cfg(target_family = "unix")]
59 {
60 let stdin = std::io::stdin();
61 if let Ok(stdin) = unsafe { Mmap::map(&stdin) } {
62 mmap = stdin;
63 break 'stdin &mmap[..];
64 }
65 }
66
67 buf = vec![0; MAX_BUF_SIZE];
71 let mut reader = std::io::stdin();
72 let mut total_read = 0;
73
74 loop {
76 let bytes_read = reader.read(&mut buf[total_read..])?;
77 if bytes_read == 0 {
78 break &buf[0..total_read];
79 }
80 total_read += bytes_read;
81
82 if total_read == MAX_BUF_SIZE {
83 temp_path = Some(std::env::temp_dir().join(format!(".tac-{}", std::process::id())));
84 let mut temp_file = File::create(temp_path.as_ref().unwrap())?;
85 temp_file.write_all(&buf)?;
87 std::io::copy(&mut reader, &mut temp_file)?;
89 mmap = unsafe { Mmap::map(&temp_file)? };
90 break &mmap[..];
91 }
92 }
93 }
94 Some(path) => {
95 let file = File::open(path)?;
96 mmap = unsafe { Mmap::map(&file)? };
97 &mmap[..]
98 }
99 };
100
101 search_auto(bytes, separator, writer)?;
102 }
103
104 if let Some(ref path) = temp_path.as_ref() {
105 if let Err(e) = std::fs::remove_file(path) {
107 eprintln!("Error: failed to remove temporary file {}\n{}", path.display(), e)
108 };
109 }
110
111 writer.flush()?;
112 Ok(())
113 }
114 inner(writer, path.as_ref().map(AsRef::as_ref), separator)
115}
116
117fn search_auto(bytes: &[u8], separator: u8, mut output: &mut dyn Write) -> Result<()> {
118 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
119 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("lzcnt") && is_x86_feature_detected!("bmi2") {
120 return unsafe { search256(bytes, separator, &mut output) };
121 }
122
123 #[cfg(target_arch = "aarch64")]
124 if std::arch::is_aarch64_feature_detected!("neon") {
125 return unsafe { search128(bytes, separator, &mut output) };
126 }
127
128 search(bytes, separator, &mut output)
129}
130
131#[inline(always)]
133fn search(bytes: &[u8], separator: u8, output: &mut dyn Write) -> Result<()> {
134 let mut last_printed = bytes.len();
135 slow_search_and_print(bytes, 0, last_printed, &mut last_printed, separator, output)?;
136 output.write_all(&bytes[..last_printed])?;
137 Ok(())
138}
139
/// Scans `bytes[start..end]` from back to front. For every `separator` found,
/// it writes the record that follows it, `bytes[separator + 1..*stop]`, to
/// `output`.
///
/// `*stop` is the exclusive end of the not-yet-written data. After each write
/// it moves down to the byte just after the separator, so on return it tells
/// the caller where output must continue from.
///
/// # Panics
///
/// Panics if `end > bytes.len()` (out-of-bounds index).
#[inline(always)]
fn slow_search_and_print(
    bytes: &[u8],
    start: usize,
    end: usize,
    stop: &mut usize,
    separator: u8,
    output: &mut dyn Write,
) -> Result<()> {
    let separator_positions = (start..end).rev().filter(|&pos| bytes[pos] == separator);
    for pos in separator_positions {
        let record_start = pos + 1;
        output.write_all(&bytes[record_start..*stop])?;
        *stop = record_start;
    }
    Ok(())
}
160
/// AVX2 variant of `search`: writes the `separator`-terminated records of
/// `bytes` to `output` in reverse order.
///
/// Each loop iteration compares `SIZE` bytes (64 on x86_64, 32 on x86) using
/// aligned 32-byte loads. The unaligned tail and head are handled by
/// `slow_search_and_print`.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2, LZCNT and BMI2.
/// `search_auto` checks all three with `is_x86_feature_detected!`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "lzcnt")]
#[target_feature(enable = "bmi2")]
unsafe fn search256(bytes: &[u8], separator: u8, mut output: &mut dyn Write) -> Result<()> {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Bytes consumed per loop iteration: one 32-byte vector on x86 (mask fits
    // a u32), two vectors on x86_64 (mask fits a u64).
    #[cfg(target_arch = "x86")]
    const SIZE: u32 = 32;
    #[cfg(target_arch = "x86_64")]
    const SIZE: u32 = 64;

    // `_mm256_load_si256` requires 32-byte-aligned addresses.
    const ALIGNMENT: usize = std::mem::align_of::<__m256i>();

    let ptr = bytes.as_ptr();
    let len = bytes.len();
    // Exclusive end of the data not yet written to `output`.
    let mut last_printed = len;
    // Exclusive end of the data not yet scanned.
    let mut remaining = len;

    // With len >= 95, aligning the end down by at most 31 bytes still leaves
    // at least 64 bytes, i.e. at least one full SIZE-byte window.
    if len >= ALIGNMENT * 3 - 1 {
        // Distance from the end of `bytes` up to the next 32-byte boundary.
        let align_offset = unsafe { ptr.add(len) }.align_offset(ALIGNMENT);
        if align_offset != 0 {
            // Largest 32-byte-aligned position strictly below `len`.
            let aligned_index = len + align_offset - ALIGNMENT;
            debug_assert!(aligned_index < len && aligned_index > 0);
            debug_assert!((ptr as usize + aligned_index) % ALIGNMENT == 0);

            // Scalar pass over the unaligned tail `[aligned_index, len)`.
            slow_search_and_print(bytes, aligned_index, len, &mut last_printed, separator, &mut output)?;
            remaining = aligned_index;
        } else {
            // The end is already aligned; the vector loop can start at `len`.
            debug_assert!((ptr as usize + len) % ALIGNMENT == 0);
        }

        let pattern256 = unsafe { _mm256_set1_epi8(separator as i8) };
        while remaining >= SIZE as usize {
            let window_end_offset = remaining;
            // SAFETY: `remaining` starts 32-byte aligned and only decreases by
            // 32, and `remaining >= SIZE` guarantees every load stays inside
            // `bytes`, so the aligned loads are in bounds and correctly aligned.
            unsafe {
                remaining -= 32;
                let search256 = _mm256_load_si256(ptr.add(remaining) as *const __m256i);
                let result256 = _mm256_cmpeq_epi8(search256, pattern256);
                // Bit i is set iff byte `remaining + i` equals the separator.
                let part = _mm256_movemask_epi8(result256) as u32;
                let mut matches;

                #[cfg(target_arch = "x86")]
                {
                    matches = part;
                }
                #[cfg(target_arch = "x86_64")]
                {
                    // Load the next-lower 32 bytes too; the higher-addressed
                    // block goes in the upper half so bit 63 maps to byte
                    // `window_end_offset - 1`.
                    remaining -= 32;
                    let search256 = _mm256_load_si256(ptr.add(remaining) as *const __m256i);
                    let result256 = _mm256_cmpeq_epi8(search256, pattern256);
                    matches = ((part as u64) << 32) | _mm256_movemask_epi8(result256) as u32 as u64;
                }

                // Visit separators from the highest address down. The highest
                // set bit has `leading` zeros above it, so the separator sits
                // at `window_end_offset - 1 - leading`. The record after it
                // starts at `offset`.
                while matches != 0 {
                    let leading = matches.leading_zeros();
                    let offset = window_end_offset - leading as usize;

                    output.write_all(&bytes[offset..last_printed])?;
                    last_printed = offset;

                    // BZHI zeroes every bit at index >= SIZE - 1 - leading,
                    // clearing the bit just handled.
                    #[cfg(target_arch = "x86")]
                    {
                        matches = _bzhi_u32(matches, SIZE - 1 - leading);
                    }
                    #[cfg(target_arch = "x86_64")]
                    {
                        matches = _bzhi_u64(matches, SIZE - 1 - leading);
                    }
                }
            }
        }
    }

    // Scalar pass over whatever head (or short input) the vector loop did not cover.
    if remaining != 0 {
        slow_search_and_print(bytes, 0, remaining, &mut last_printed, separator, &mut output)?;
    }

    // The first record of the input is emitted last.
    output.write_all(&bytes[..last_printed])?;

    Ok(())
}
273
/// NEON (aarch64) variant of `search`: writes the `separator`-terminated
/// records of `bytes` to `output` in reverse order. Each loop iteration
/// compares 64 bytes.
///
/// # Safety
///
/// The caller must ensure the CPU supports NEON. `search_auto` checks this
/// with `is_aarch64_feature_detected!`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn search128(bytes: &[u8], separator: u8, mut output: &mut dyn Write) -> Result<()> {
    use core::arch::aarch64::*;

    // Empty input has no records. Returning here also keeps `len - 1` below
    // from underflowing. That used to panic in debug builds; in release builds
    // it wrapped to `usize::MAX` and led to an out-of-bounds `ptr.add`.
    if bytes.is_empty() {
        return Ok(());
    }

    let ptr = bytes.as_ptr();
    // Exclusive end of the data not yet written to `output`.
    let mut last_printed = bytes.len();
    // For inputs of at most 64 bytes, the final byte is never examined by the
    // scalar pass at the bottom. That is harmless: a separator there would
    // only produce an empty record.
    let mut index = last_printed - 1;

    if index >= 64 {
        // Move `index` down (by 1..=16 bytes) to a 16-byte boundary and scan
        // the tail above it with the scalar scanner.
        let align_offset = unsafe { ptr.add(index).align_offset(16) };
        let aligned_index = index + align_offset - 16;

        slow_search_and_print(
            bytes,
            aligned_index,
            last_printed,
            &mut last_printed,
            separator,
            &mut output,
        )?;
        index = aligned_index;

        let pattern128 = unsafe { vdupq_n_u8(separator) };
        // Each iteration compares the 64 bytes `[index - 64, index)`.
        while index >= 64 {
            let window_end_offset = index;
            // SAFETY: `index >= 64`, so all four 16-byte loads stay inside
            // `bytes`; `vld1q_u8` has no alignment requirement.
            unsafe {
                index -= 16;
                let window = ptr.add(index);
                let search128 = vld1q_u8(window);
                let result128_0 = vceqq_u8(search128, pattern128);

                index -= 16;
                let window = ptr.add(index);
                let search128 = vld1q_u8(window);
                let result128_1 = vceqq_u8(search128, pattern128);

                index -= 16;
                let window = ptr.add(index);
                let search128 = vld1q_u8(window);
                let result128_2 = vceqq_u8(search128, pattern128);

                index -= 16;
                let window = ptr.add(index);
                let search128 = vld1q_u8(window);
                let result128_3 = vceqq_u8(search128, pattern128);

                // Collapse the four 0x00/0xFF comparison vectors into one u64.
                // Bit k is set iff byte `window_end_offset - 64 + k` is a
                // separator. Each lane is ANDed with a distinct power-of-two
                // weight, so the pairwise additions act as bitwise ORs.
                let mut matches = {
                    let bit_mask: uint8x16_t = std::mem::transmute([
                        0x01u8, 0x02, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80,
                    ]);
                    let t0 = vandq_u8(result128_3, bit_mask);
                    let t1 = vandq_u8(result128_2, bit_mask);
                    let t2 = vandq_u8(result128_1, bit_mask);
                    let t3 = vandq_u8(result128_0, bit_mask);
                    let sum0 = vpaddq_u8(t0, t1);
                    let sum1 = vpaddq_u8(t2, t3);
                    let sum0 = vpaddq_u8(sum0, sum1);
                    let sum0 = vpaddq_u8(sum0, sum0);
                    vgetq_lane_u64(vreinterpretq_u64_u8(sum0), 0)
                };

                // Visit separators from the highest address down. The highest
                // set bit has `leading` zeros above it, so the separator sits
                // at `window_end_offset - 1 - leading`. The record after it
                // starts at `offset`.
                while matches != 0 {
                    let leading = matches.leading_zeros();
                    let offset = window_end_offset - leading as usize;

                    output.write_all(&bytes[offset..last_printed])?;
                    last_printed = offset;

                    // Clear the bit just handled.
                    matches &= !(1 << (64 - leading - 1));
                }
            }
        }
    }

    // Scalar pass over whatever head (or short input) the vector loop did not cover.
    if index != 0 {
        slow_search_and_print(bytes, 0, index, &mut last_printed, separator, &mut output)?;
    }

    // The first record of the input is emitted last.
    output.write_all(&bytes[0..last_printed])?;

    Ok(())
}
371
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;

    /// Number of leading offsets tried, so that every start alignment modulo
    /// the 32-byte vector width is covered.
    const MAX_START: usize = 32;
    /// Longest input tried. This covers the scalar-only sizes, the alignment
    /// thresholds and several full SIMD windows.
    const MAX_LEN: usize = 320;

    /// Deterministic xorshift64 byte stream, so failures are reproducible and
    /// the tests run on any OS. Roughly one byte in eight is forced to `b'.'`,
    /// so most SIMD windows contain several separators.
    fn pseudo_random_bytes(seed: u64, len: usize) -> Vec<u8> {
        let mut state = seed;
        (0..len)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                if state % 8 == 0 {
                    b'.'
                } else {
                    (state >> 56) as u8
                }
            })
            .collect()
    }

    /// Asserts that `candidate` produces exactly the same bytes as the scalar
    /// `search` for every slice `data[start..start + len]` with
    /// `start < MAX_START` and `1 <= len <= MAX_LEN`.
    fn check_against_scalar(candidate: impl Fn(&[u8], &mut Vec<u8>) -> std::io::Result<()>) {
        let data = pseudo_random_bytes(0x9E37_79B9_7F4A_7C15, MAX_START + MAX_LEN);
        let mut expected = Vec::new();
        let mut actual = Vec::new();
        for start in 0..MAX_START {
            for len in 1..=MAX_LEN {
                let input = &data[start..start + len];
                expected.clear();
                actual.clear();
                search(input, b'.', &mut expected).unwrap();
                candidate(input, &mut actual).unwrap();
                assert_eq!(expected, actual, "mismatch for start={start}, len={len}");
            }
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn test_x86_simd() {
        // `search256` is only sound on CPUs with these extensions; calling it
        // elsewhere is undefined behaviour, so skip instead.
        if !(is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("lzcnt")
            && is_x86_feature_detected!("bmi2"))
        {
            eprintln!("skipping test_x86_simd: avx2/lzcnt/bmi2 not available");
            return;
        }
        check_against_scalar(|buf, out| unsafe { search256(buf, b'.', out) });
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_aarch64_simd() {
        if !std::arch::is_aarch64_feature_detected!("neon") {
            eprintln!("skipping test_aarch64_simd: neon not available");
            return;
        }
        check_against_scalar(|buf, out| unsafe { search128(buf, b'.', out) });
    }

    #[test]
    fn test_search_auto_matches_scalar() {
        check_against_scalar(|buf, out| search_auto(buf, b'.', out));
    }
}