// wg_utils/unsafe/io.rs

1use std::fs::File;
2use std::io::{self, Read, Seek, SeekFrom, Write};
3use std::mem;
4use std::os::unix::io::{AsRawFd, RawFd};
5use std::ptr;
6use std::slice;
7
/// Thin wrapper around a raw Unix file descriptor whose methods call libc directly.
///
/// When `owned` is `true` the descriptor is closed when the wrapper is dropped;
/// when `false` the wrapper only borrows a descriptor managed elsewhere.
#[derive(Debug)]
pub struct RawIO {
    /// Descriptor passed to every libc call.
    fd: RawFd,
    /// Whether this wrapper is responsible for closing `fd` on drop.
    owned: bool,
}
12
13impl RawIO {
14    pub unsafe fn from_raw_fd(fd: RawFd, owned: bool) -> Self {
15        Self { fd, owned }
16    }
17
18    pub unsafe fn from_file(file: File) -> Self {
19        let fd = file.as_raw_fd();
20        mem::forget(file);
21        Self { fd, owned: true }
22    }
23
24    pub fn raw_fd(&self) -> RawFd {
25        self.fd
26    }
27
28    pub unsafe fn read_direct(&self, buf: *mut u8, len: usize) -> io::Result<usize> {
29        let ret = libc::read(self.fd, buf as *mut libc::c_void, len);
30        if ret < 0 {
31            Err(io::Error::last_os_error())
32        } else {
33            Ok(ret as usize)
34        }
35    }
36
37    pub unsafe fn write_direct(&self, buf: *const u8, len: usize) -> io::Result<usize> {
38        let ret = libc::write(self.fd, buf as *const libc::c_void, len);
39        if ret < 0 {
40            Err(io::Error::last_os_error())
41        } else {
42            Ok(ret as usize)
43        }
44    }
45
46    pub fn seek(&self, pos: i64, whence: i32) -> io::Result<u64> {
47        let ret = unsafe { libc::lseek(self.fd, pos, whence) };
48        if ret < 0 {
49            Err(io::Error::last_os_error())
50        } else {
51            Ok(ret as u64)
52        }
53    }
54
55    pub fn fsync(&self) -> io::Result<()> {
56        let ret = unsafe { libc::fsync(self.fd) };
57        if ret < 0 {
58            Err(io::Error::last_os_error())
59        } else {
60            Ok(())
61        }
62    }
63
64    pub unsafe fn mmap(
65        &self,
66        len: usize,
67        prot: i32,
68        flags: i32,
69        offset: i64,
70    ) -> io::Result<*mut u8> {
71        let ptr = libc::mmap(ptr::null_mut(), len, prot, flags, self.fd, offset);
72
73        if ptr == libc::MAP_FAILED {
74            Err(io::Error::last_os_error())
75        } else {
76            Ok(ptr as *mut u8)
77        }
78    }
79
80    pub unsafe fn munmap(&self, addr: *mut u8, len: usize) -> io::Result<()> {
81        let ret = libc::munmap(addr as *mut libc::c_void, len);
82        if ret < 0 {
83            Err(io::Error::last_os_error())
84        } else {
85            Ok(())
86        }
87    }
88
89    pub unsafe fn madvise(&self, addr: *mut u8, len: usize, advice: i32) -> io::Result<()> {
90        let ret = libc::madvise(addr as *mut libc::c_void, len, advice);
91        if ret < 0 {
92            Err(io::Error::last_os_error())
93        } else {
94            Ok(())
95        }
96    }
97
98    pub fn pread(&self, buf: &mut [u8], offset: i64) -> io::Result<usize> {
99        let ret = unsafe {
100            libc::pread(
101                self.fd,
102                buf.as_mut_ptr() as *mut libc::c_void,
103                buf.len(),
104                offset,
105            )
106        };
107
108        if ret < 0 {
109            Err(io::Error::last_os_error())
110        } else {
111            Ok(ret as usize)
112        }
113    }
114
115    pub fn pwrite(&self, buf: &[u8], offset: i64) -> io::Result<usize> {
116        let ret = unsafe {
117            libc::pwrite(
118                self.fd,
119                buf.as_ptr() as *const libc::c_void,
120                buf.len(),
121                offset,
122            )
123        };
124
125        if ret < 0 {
126            Err(io::Error::last_os_error())
127        } else {
128            Ok(ret as usize)
129        }
130    }
131
132    pub fn readv(&self, iovecs: &mut [libc::iovec]) -> io::Result<usize> {
133        let ret = unsafe { libc::readv(self.fd, iovecs.as_mut_ptr(), iovecs.len() as i32) };
134
135        if ret < 0 {
136            Err(io::Error::last_os_error())
137        } else {
138            Ok(ret as usize)
139        }
140    }
141
142    pub fn writev(&self, iovecs: &[libc::iovec]) -> io::Result<usize> {
143        let ret = unsafe { libc::writev(self.fd, iovecs.as_ptr(), iovecs.len() as i32) };
144
145        if ret < 0 {
146            Err(io::Error::last_os_error())
147        } else {
148            Ok(ret as usize)
149        }
150    }
151
152    pub unsafe fn read_vectored_direct(&self, bufs: &mut [&mut [u8]]) -> io::Result<usize> {
153        let mut iovecs: Vec<libc::iovec> = Vec::with_capacity(bufs.len());
154
155        for buf in bufs.iter_mut() {
156            iovecs.push(libc::iovec {
157                iov_base: buf.as_mut_ptr() as *mut libc::c_void,
158                iov_len: buf.len(),
159            });
160        }
161
162        self.readv(&mut iovecs)
163    }
164
165    pub unsafe fn write_vectored_direct(&self, bufs: &[&[u8]]) -> io::Result<usize> {
166        let mut iovecs: Vec<libc::iovec> = Vec::with_capacity(bufs.len());
167
168        for buf in bufs.iter() {
169            iovecs.push(libc::iovec {
170                iov_base: buf.as_ptr() as *mut libc::c_void,
171                iov_len: buf.len(),
172            });
173        }
174
175        self.writev(&iovecs)
176    }
177
178    pub fn allocate(&self, offset: i64, len: i64) -> io::Result<()> {
179        let ret = unsafe { libc::fallocate(self.fd, 0, offset, len) };
180        if ret < 0 {
181            Err(io::Error::last_os_error())
182        } else {
183            Ok(())
184        }
185    }
186
187    pub fn truncate(&self, len: i64) -> io::Result<()> {
188        let ret = unsafe { libc::ftruncate(self.fd, len) };
189        if ret < 0 {
190            Err(io::Error::last_os_error())
191        } else {
192            Ok(())
193        }
194    }
195}
196
197impl Drop for RawIO {
198    fn drop(&mut self) {
199        if self.owned {
200            unsafe { libc::close(self.fd) };
201        }
202    }
203}
204
205pub struct MemoryMappedFile {
206    addr: *mut u8,
207    len: usize,
208    io: Option<RawIO>,
209}
210
211impl MemoryMappedFile {
212    pub unsafe fn new(file: File, len: usize, write: bool) -> io::Result<Self> {
213        let io = RawIO::from_file(file);
214
215        let prot = libc::PROT_READ | if write { libc::PROT_WRITE } else { 0 };
216        let flags = libc::MAP_SHARED;
217
218        let addr = io.mmap(len, prot, flags, 0)?;
219
220        Ok(Self {
221            addr,
222            len,
223            io: Some(io),
224        })
225    }
226
227    pub unsafe fn anonymous(len: usize) -> io::Result<Self> {
228        let prot = libc::PROT_READ | libc::PROT_WRITE;
229        let flags = libc::MAP_PRIVATE | libc::MAP_ANONYMOUS;
230
231        let addr = libc::mmap(ptr::null_mut(), len, prot, flags, -1, 0);
232
233        if addr == libc::MAP_FAILED {
234            Err(io::Error::last_os_error())
235        } else {
236            Ok(Self {
237                addr: addr as *mut u8,
238                len,
239                io: None,
240            })
241        }
242    }
243
244    pub fn as_slice(&self) -> &[u8] {
245        unsafe { slice::from_raw_parts(self.addr, self.len) }
246    }
247
248    pub fn as_mut_slice(&mut self) -> &mut [u8] {
249        unsafe { slice::from_raw_parts_mut(self.addr, self.len) }
250    }
251
252    pub fn as_ptr(&self) -> *const u8 {
253        self.addr
254    }
255
256    pub fn as_mut_ptr(&mut self) -> *mut u8 {
257        self.addr
258    }
259
260    pub fn advise(&self, advice: i32) -> io::Result<()> {
261        if let Some(ref io) = self.io {
262            unsafe { io.madvise(self.addr, self.len, advice) }
263        } else {
264            unsafe {
265                let ret = libc::madvise(self.addr as *mut libc::c_void, self.len, advice);
266                if ret < 0 {
267                    Err(io::Error::last_os_error())
268                } else {
269                    Ok(())
270                }
271            }
272        }
273    }
274
275    pub fn sync(&self, sync_flags: i32) -> io::Result<()> {
276        unsafe {
277            let ret = libc::msync(self.addr as *mut libc::c_void, self.len, sync_flags);
278            if ret < 0 {
279                Err(io::Error::last_os_error())
280            } else {
281                Ok(())
282            }
283        }
284    }
285
286    pub fn len(&self) -> usize {
287        self.len
288    }
289}
290
291impl Drop for MemoryMappedFile {
292    fn drop(&mut self) {
293        unsafe {
294            libc::munmap(self.addr as *mut libc::c_void, self.len);
295        }
296    }
297}
298
299pub fn direct_copy(src: &RawIO, dst: &RawIO, buffer_size: usize) -> io::Result<u64> {
300    let mut buffer = Vec::with_capacity(buffer_size);
301    unsafe {
302        buffer.set_len(buffer_size);
303    }
304
305    let mut total_copied = 0;
306
307    loop {
308        let read_bytes = unsafe { src.read_direct(buffer.as_mut_ptr(), buffer_size) }?;
309        if read_bytes == 0 {
310            break;
311        }
312
313        let mut written = 0;
314        while written < read_bytes {
315            let n =
316                unsafe { dst.write_direct(buffer.as_ptr().add(written), read_bytes - written) }?;
317
318            if n == 0 {
319                return Err(io::Error::new(io::ErrorKind::WriteZero, "failed to write"));
320            }
321
322            written += n;
323        }
324
325        total_copied += read_bytes as u64;
326    }
327
328    Ok(total_copied)
329}
330
331#[cfg(target_os = "linux")]
332pub fn splice_copy(src: &RawIO, dst: &RawIO, len: usize) -> io::Result<u64> {
333    let mut total = 0;
334    let mut remaining = len;
335
336    while remaining > 0 {
337        let chunk_size = remaining.min(0x7ffff000);
338        let ret = unsafe {
339            libc::splice(
340                src.raw_fd(),
341                ptr::null_mut(),
342                dst.raw_fd(),
343                ptr::null_mut(),
344                chunk_size,
345                libc::SPLICE_F_MOVE,
346            )
347        };
348
349        if ret < 0 {
350            return Err(io::Error::last_os_error());
351        }
352
353        if ret == 0 {
354            break;
355        }
356
357        total += ret as u64;
358        remaining -= ret as usize;
359    }
360
361    Ok(total)
362}