Skip to main content

vortex_io/object_store/
write.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::io;
5use std::sync::Arc;
6
7use bytes::BytesMut;
8use futures::TryStreamExt;
9use futures::stream::FuturesUnordered;
10use object_store::MultipartUpload;
11use object_store::ObjectStore;
12use object_store::ObjectStoreExt;
13use object_store::PutPayload;
14use object_store::PutResult;
15use object_store::path::Path;
16use vortex_error::VortexResult;
17
18use crate::IoBuf;
19use crate::VortexWrite;
20
/// Adapter type to write data through a [`ObjectStore`] instance.
///
/// Bytes are staged in an internal buffer and submitted to a multipart upload in chunks.
/// After writing, the caller must make sure to call `shutdown`, in order to ensure the data is actually persisted.
pub struct ObjectStoreWrite {
    // The in-progress multipart upload that staged chunks are submitted to.
    upload: Box<dyn MultipartUpload>,
    // Staging buffer: writes accumulate here until there is enough data to upload as parts.
    buffer: BytesMut,
    // Result returned by the store when the upload completes; populated by `shutdown`.
    put_result: Option<PutResult>,
}
29
/// Size of each part submitted to the multipart upload (16 MiB).
const CHUNK_SIZE: usize = 16 * 1024 * 1024;
/// Staging threshold: `write_all` only starts uploading parts once more than this
/// many bytes (128 MiB) have accumulated in the buffer.
const BUFFER_SIZE: usize = 128 * 1024 * 1024;
32
33impl ObjectStoreWrite {
34    pub async fn new(object_store: Arc<dyn ObjectStore>, location: &Path) -> VortexResult<Self> {
35        let upload = object_store.put_multipart(location).await?;
36        Ok(Self {
37            upload,
38            buffer: BytesMut::with_capacity(CHUNK_SIZE),
39            put_result: None,
40        })
41    }
42
43    pub fn put_result(&self) -> Option<&PutResult> {
44        self.put_result.as_ref()
45    }
46}
47
48impl VortexWrite for ObjectStoreWrite {
49    async fn write_all<B: IoBuf>(&mut self, buffer: B) -> io::Result<B> {
50        self.buffer.extend_from_slice(buffer.as_slice());
51        let parts = FuturesUnordered::new();
52
53        // If the buffer is full
54        if self.buffer.len() > BUFFER_SIZE {
55            // Split off chunks while buffer is larger than CHUNKS_SIZE
56            while self.buffer.len() > CHUNK_SIZE {
57                let payload = self.buffer.split_to(CHUNK_SIZE).freeze();
58                let part_fut = self.upload.put_part(PutPayload::from_bytes(payload));
59
60                parts.push(part_fut);
61            }
62        }
63
64        parts.try_collect::<Vec<_>>().await?;
65
66        Ok(buffer)
67    }
68
69    async fn flush(&mut self) -> io::Result<()> {
70        let parts = FuturesUnordered::new();
71
72        while self.buffer.len() > CHUNK_SIZE {
73            let payload = self.buffer.split_to(CHUNK_SIZE).freeze();
74            let part_fut = self.upload.put_part(PutPayload::from_bytes(payload));
75
76            parts.push(part_fut);
77        }
78
79        parts.try_collect::<Vec<_>>().await?;
80
81        Ok(())
82    }
83
84    async fn shutdown(&mut self) -> io::Result<()> {
85        self.flush().await?;
86
87        if !self.buffer.is_empty() {
88            let payload = std::mem::take(&mut self.buffer).freeze();
89            self.upload
90                .put_part(PutPayload::from_bytes(payload))
91                .await?;
92        }
93
94        self.put_result = Some(self.upload.complete().await?);
95        Ok(())
96    }
97}
98
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use object_store::ObjectStore;
    use object_store::local::LocalFileSystem;
    use object_store::memory::InMemory;
    use object_store::path::Path;
    use rstest::rstest;
    use tempfile::tempdir;

    use super::*;

    // Note: Concurrent writes test removed because &mut self in write_all already ensures
    // exclusive access. Multiple writers would need to be created with separate buffers,
    // which is not the intended use case.

    // `#[rstest]` must come before the test attribute: rstest expands the `#[case]`s
    // and re-applies `#[tokio::test]` to each generated case. With `#[tokio::test]`
    // outermost it expands first and rejects the parametrized `chunk_size` argument.
    #[rstest]
    #[case(100)]
    #[case(8 * 1024 * 1024)]
    #[case(25 * 1024 * 1024)]
    #[case(26 * 1024 * 1024)]
    #[tokio::test]
    async fn test_object_store_writer_multiple_flushes(
        #[case] chunk_size: usize,
    ) -> anyhow::Result<()> {
        let temp_dir = tempdir()?;
        let local_store =
            Arc::new(LocalFileSystem::new_with_prefix(temp_dir.path())?) as Arc<dyn ObjectStore>;
        let memory_store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
        let location = Path::from("test.bin");

        for test_store in [memory_store, local_store] {
            let mut writer = ObjectStoreWrite::new(Arc::clone(&test_store), &location).await?;

            #[expect(clippy::cast_possible_truncation)]
            let data = (0..3)
                .map(|i| vec![i as u8; chunk_size])
                .collect::<Vec<_>>();

            // Interleave writes and flushes to exercise partial-chunk buffering.
            for chunk in &data {
                writer.write_all(chunk.clone()).await?;
                writer.flush().await?;
            }

            // Shutdown the writer to make sure data actually gets persisted.
            writer.shutdown().await?;

            // Verify all data was written, in order and without loss.
            let result = test_store.get(&location).await?;
            let bytes = result.bytes().await?;

            // Stdlib `concat` flattens the written chunks; no itertools needed.
            let expected_data = data.concat();
            assert_eq!(bytes, expected_data);
        }

        Ok(())
    }
}