Skip to main content

tur_rs/
storage.rs

1use std::path::Path;
2
use anyhow::{Context, Result};
4#[cfg(target_os = "linux")]
5use anyhow::anyhow;
6use tokio::fs::File;
7#[cfg(not(any(target_os = "linux", target_os = "macos")))]
8use tokio::fs::OpenOptions;
9use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
10#[cfg(target_os = "linux")]
11use tokio::sync::{mpsc, oneshot};
12
13mod aligned_buffer;
14#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
15mod common;
16#[cfg(target_os = "linux")]
17mod linux;
18#[cfg(target_os = "macos")]
19mod macos;
20#[cfg(target_os = "windows")]
21mod windows;
22
23#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
24use common as platform;
25#[cfg(target_os = "linux")]
26use linux as platform;
27#[cfg(target_os = "macos")]
28use macos as platform;
29#[cfg(target_os = "windows")]
30use windows as platform;
31
32use aligned_buffer::AlignedBuffer;
33#[cfg(target_os = "linux")]
34use aligned_buffer::LinuxIoUringCommand;
35
/// Knobs that are handed to the platform module when a download file is opened
/// via [`open_download_file_for_write_with_config`], steering which write
/// backend it selects.
///
/// The precise effect of each flag lives in the platform module; the notes
/// below describe the intent suggested by the field names.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct StorageConfig {
    /// Permit a positional-write (`pwrite`-style) backend. Default: `true`.
    pub use_pwrite: bool,
    /// Permit a splice-based backend. Default: `true`.
    pub use_splice: bool,
    /// Opt out of the io_uring backend. Default: `false`.
    pub no_io_uring: bool,
    /// Opt out of direct I/O. Default: `false`.
    pub no_direct_io: bool,
}

impl Default for StorageConfig {
    /// Nothing opted out: pwrite and splice permitted, io_uring and direct I/O
    /// not disabled.
    fn default() -> Self {
        StorageConfig {
            no_direct_io: false,
            no_io_uring: false,
            use_splice: true,
            use_pwrite: true,
        }
    }
}
54
55pub fn prepare_download_file(path: &Path, total_size: u64) -> Result<()> {
56    if let Some(parent) = path.parent() {
57        std::fs::create_dir_all(parent)?;
58    }
59    let file = std::fs::File::create(path)?;
60    file.set_len(total_size)?;
61    platform::prepare_download_file(&file, total_size)?;
62    Ok(())
63}
64
/// Identifies the write strategy a [`DownloadFile`] ended up using.
///
/// Every variant except `Standard` is named after the platform it belongs to;
/// `Standard` is what targets other than Linux, macOS and Windows report.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StorageBackendKind {
    /// Platform-neutral backend.
    Standard,
    /// Linux: plain Tokio file writes.
    LinuxTokio,
    /// Linux: positional writes on a std file handle.
    LinuxPwrite,
    /// Linux: writes routed through a pipe with splice.
    LinuxSplice,
    /// Linux: aligned writes submitted to an io_uring worker thread.
    LinuxIoUring,
    /// macOS: positional writes.
    MacosPwrite,
    /// macOS: no-cache variant.
    MacosNoCache,
    /// Windows: positional writes.
    WindowsPwrite,
    /// Windows: direct (aligned) I/O.
    WindowsDirectIo,
    /// Windows: sequential variant.
    WindowsSequential,
}
78
79enum DownloadFileInner {
80    #[cfg_attr(target_os = "linux", allow(dead_code))]
81    Tokio(File),
82    #[cfg(target_os = "linux")]
83    LinuxPwrite(std::fs::File),
84    #[cfg(target_os = "linux")]
85    LinuxSplice {
86        file: std::fs::File,
87        pipe_read: std::fs::File,
88        pipe_write: std::fs::File,
89    },
90    #[cfg(target_os = "linux")]
91    LinuxIoUring {
92        tx: mpsc::Sender<LinuxIoUringCommand>,
93        fallback: File,
94    },
95    #[cfg(target_os = "macos")]
96    MacosPwrite(std::fs::File),
97    #[cfg(target_os = "windows")]
98    WindowsPwrite(std::fs::File),
99    #[cfg(target_os = "windows")]
100    WindowsDirectIo(std::fs::File),
101}
102
103pub struct DownloadFile {
104    inner: DownloadFileInner,
105    backend: StorageBackendKind,
106}
107
108impl DownloadFile {
109    pub fn backend(&self) -> StorageBackendKind {
110        self.backend
111    }
112
113    pub fn direct_io_alignment(&self) -> Option<usize> {
114        match self.backend {
115            StorageBackendKind::LinuxIoUring => Some(platform::DIRECT_IO_ALIGNMENT),
116            StorageBackendKind::WindowsDirectIo => Some(platform::DIRECT_IO_ALIGNMENT),
117            _ => None,
118        }
119    }
120
121    pub async fn write_all_at(&mut self, offset: u64, data: &[u8]) -> Result<()> {
122        match &mut self.inner {
123            DownloadFileInner::Tokio(file) => {
124                platform::write_all_at_tokio(file, offset, data).await
125            }
126            #[cfg(target_os = "linux")]
127            DownloadFileInner::LinuxPwrite(file) => {
128                platform::write_all_at_pwrite(file, offset, data).await
129            }
130            #[cfg(target_os = "linux")]
131            DownloadFileInner::LinuxSplice {
132                pipe_write,
133                pipe_read,
134                file,
135            } => platform::write_all_at_splice(file, pipe_read, pipe_write, offset, data).await,
136            #[cfg(target_os = "linux")]
137            DownloadFileInner::LinuxIoUring { tx, fallback } => {
138                let alignment = platform::DIRECT_IO_ALIGNMENT as u64;
139                let start = offset;
140                let end = offset + data.len() as u64;
141
142                let aligned_start = if start % alignment == 0 {
143                    start
144                } else {
145                    start + (alignment - (start % alignment))
146                };
147                let aligned_end = end - (end % alignment);
148
149                if aligned_start >= aligned_end {
150                    return platform::write_all_at_tokio(fallback, offset, data).await;
151                }
152
153                let prefix_len = aligned_start.saturating_sub(start) as usize;
154                if prefix_len > 0 {
155                    platform::write_all_at_tokio(fallback, offset, &data[..prefix_len]).await?;
156                }
157
158                let middle_start = prefix_len;
159                let middle_len = (aligned_end - aligned_start) as usize;
160                let middle_end = middle_start + middle_len;
161                if middle_len > 0 {
162                    let mut aligned = AlignedBuffer::new(middle_len, platform::DIRECT_IO_ALIGNMENT);
163                    aligned.as_mut_slice()[..middle_len]
164                        .copy_from_slice(&data[middle_start..middle_end]);
165
166                    let (resp_tx, resp_rx) = oneshot::channel();
167                    tx.send(LinuxIoUringCommand::WriteAllAt {
168                        offset: aligned_start,
169                        data: aligned,
170                        resp: resp_tx,
171                    })
172                    .await
173                    .map_err(|_| anyhow!("io_uring backend thread is not available"))?;
174                    resp_rx
175                        .await
176                        .map_err(|_| anyhow!("io_uring backend response channel closed"))??;
177                }
178
179                if middle_end < data.len() {
180                    platform::write_all_at_tokio(fallback, aligned_end, &data[middle_end..])
181                        .await?;
182                }
183
184                Ok(())
185            }
186            #[cfg(target_os = "macos")]
187            DownloadFileInner::MacosPwrite(file) => {
188                platform::write_all_at_pwrite(file, offset, data).await
189            }
190            #[cfg(target_os = "windows")]
191            DownloadFileInner::WindowsPwrite(file) => {
192                platform::write_all_at_windows_pwrite(file, offset, data).await
193            }
194            #[cfg(target_os = "windows")]
195            DownloadFileInner::WindowsDirectIo(file) => {
196                platform::write_all_at_windows_direct_io(file, offset, data).await
197            }
198        }
199    }
200}
201
202impl Drop for DownloadFile {
203    fn drop(&mut self) {
204        #[cfg(target_os = "linux")]
205        if let DownloadFileInner::LinuxIoUring { tx, .. } = &self.inner {
206            let _ = tx.try_send(LinuxIoUringCommand::Shutdown);
207        }
208    }
209}
210
211pub async fn open_download_file_for_write(path: &Path) -> Result<DownloadFile> {
212    open_download_file_for_write_with_config(path, &StorageConfig::default()).await
213}
214
215pub async fn open_download_file_for_write_with_config(
216    path: &Path,
217    config: &StorageConfig,
218) -> Result<DownloadFile> {
219    platform::open_download_file_for_write(path, config).await
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Deletes the wrapped directory tree when dropped, so the temp directory
    /// is cleaned up even if an assertion in the test panics.
    struct TempDirGuard(std::path::PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            // Best effort: a leftover temp dir must not mask the test outcome.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[tokio::test]
    async fn selects_a_supported_backend_for_current_platform() {
        let dir = TempDirGuard(
            std::env::temp_dir().join(format!("tur-storage-{}", uuid::Uuid::new_v4())),
        );
        std::fs::create_dir_all(&dir.0).unwrap();
        let path = dir.0.join("file.bin");
        prepare_download_file(&path, 8192).unwrap();

        let file = open_download_file_for_write_with_config(&path, &StorageConfig::default())
            .await
            .unwrap();
        let backend = file.backend();
        // Close the handle before the guard removes the directory: Windows may
        // refuse to delete a file that is still open.
        drop(file);

        #[cfg(target_os = "linux")]
        assert!(matches!(
            backend,
            StorageBackendKind::LinuxIoUring
                | StorageBackendKind::LinuxSplice
                | StorageBackendKind::LinuxPwrite
                | StorageBackendKind::LinuxTokio
        ));
        #[cfg(target_os = "macos")]
        assert!(matches!(backend, StorageBackendKind::MacosPwrite));
        #[cfg(target_os = "windows")]
        assert!(matches!(
            backend,
            StorageBackendKind::WindowsDirectIo | StorageBackendKind::WindowsPwrite
        ));
        #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
        assert!(matches!(backend, StorageBackendKind::Standard));
    }
}