sos_filesystem/
vault_writer.rs

1//! Write vault changes to a file on disc.
2use async_fd_lock::{LockRead, LockWrite};
3use async_trait::async_trait;
4use binary_stream::futures::{BinaryReader, BinaryWriter};
5use sos_core::{
6    commit::CommitHash,
7    crypto::AeadPack,
8    encode,
9    encoding::encoding_options,
10    events::{ReadEvent, WriteEvent},
11    SecretId, VaultCommit, VaultEntry, VaultFlags,
12};
13use sos_vault::{Contents, EncryptedEntry, Header, Summary, Vault};
14use sos_vfs::{self as vfs, OpenOptions};
15use std::io::Cursor;
16use std::{borrow::Cow, io::SeekFrom, ops::Range, path::Path, path::PathBuf};
17use tokio::io::BufWriter;
18use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
19
20/// Write changes to a vault file on disc.
21pub struct VaultFileWriter<E>
22where
23    E: std::error::Error
24        + std::fmt::Debug
25        + From<sos_core::Error>
26        + From<sos_vault::Error>
27        + From<std::io::Error>
28        + Send
29        + Sync
30        + 'static,
31{
32    pub(crate) file_path: PathBuf,
33    marker: std::marker::PhantomData<E>,
34}
35
36impl<E> VaultFileWriter<E>
37where
38    E: std::error::Error
39        + std::fmt::Debug
40        + From<sos_core::Error>
41        + From<sos_vault::Error>
42        + From<std::io::Error>
43        + Send
44        + Sync
45        + 'static,
46{
47    /// Create a new vault file writer.
48    pub fn new<P: AsRef<Path>>(path: P) -> Self {
49        let file_path = path.as_ref().to_path_buf();
50        Self {
51            file_path,
52            marker: std::marker::PhantomData,
53        }
54    }
55
56    /// Check the identity bytes and return the byte offset of the
57    /// beginning of the vault content area.
58    async fn check_identity(&self) -> Result<u64, E> {
59        Ok(Header::read_content_offset(&self.file_path).await?)
60    }
61
62    /// Write out the header preserving the existing content bytes.
63    async fn write_header(
64        &self,
65        content_offset: u64,
66        header: &Header,
67    ) -> Result<(), E> {
68        let head = encode(header).await?;
69        let mut file = OpenOptions::new()
70            .read(true)
71            .write(true)
72            .open(&self.file_path)
73            .await?;
74
75        // Read the content into memory
76        file.seek(SeekFrom::Start(content_offset)).await?;
77        let mut content = Vec::new();
78        file.read_to_end(&mut content).await?;
79
80        // Rewind and truncate the file
81        file.rewind().await?;
82        file.set_len(0).await?;
83
84        let mut guard = file.lock_write().await.map_err(|e| e.error)?;
85
86        // Write out the header
87        guard.write_all(&head).await?;
88
89        // Write out the content
90        guard.write_all(&content).await?;
91        guard.flush().await?;
92
93        Ok(())
94    }
95
96    /// Splice a file preserving the head and tail and
97    /// optionally inserting content in between.
98    async fn splice(
99        &self,
100        head: Range<u64>,
101        tail: Range<u64>,
102        content: Option<&[u8]>,
103    ) -> Result<(), E> {
104        let end = {
105            let file =
106                OpenOptions::new().read(true).open(&self.file_path).await?;
107            let mut guard = file.lock_read().await.map_err(|e| e.error)?;
108
109            // Read the tail into memory
110            guard.seek(SeekFrom::Start(tail.start)).await?;
111            let mut end = Vec::new();
112            guard.read_to_end(&mut end).await?;
113
114            end
115        };
116
117        let file =
118            OpenOptions::new().write(true).open(&self.file_path).await?;
119
120        let mut guard = file.lock_write().await.map_err(|e| e.error)?;
121
122        if head.start == 0 {
123            // Rewind and truncate the file to the head
124            guard.rewind().await?;
125            guard.inner_mut().set_len(head.end).await?;
126        } else {
127            unreachable!("file splice head range always starts at zero");
128        }
129
130        // Must seek to the end before writing out the content or tail
131        guard.seek(SeekFrom::End(0)).await?;
132
133        // Inject the content if necessary
134        if let Some(content) = content {
135            guard.write_all(content).await?;
136        }
137
138        // Write out the end portion
139        guard.write_all(&end).await?;
140        guard.flush().await?;
141
142        Ok(())
143    }
144
145    /// Find the byte offset of a row.
146    ///
147    /// Returns the content offset and the byte
148    /// offset and row length of the row if it exists.
149    async fn find_row(
150        &self,
151        id: &SecretId,
152    ) -> Result<(u64, Option<(u64, u32)>), E> {
153        let content_offset = self.check_identity().await?;
154
155        let file =
156            OpenOptions::new().read(true).open(&self.file_path).await?;
157        let mut guard = file.lock_read().await.map_err(|e| e.error)?;
158
159        let mut reader = BinaryReader::new(&mut guard, encoding_options());
160        reader.seek(SeekFrom::Start(content_offset)).await?;
161
162        // Scan all the rows
163        let mut current_pos = reader.stream_position().await?;
164        while let Ok(row_len) = reader.read_u32().await {
165            let row_id: [u8; 16] = reader
166                .read_bytes(16)
167                .await?
168                .as_slice()
169                .try_into()
170                .map_err(sos_core::Error::from)?;
171            let row_id = SecretId::from_bytes(row_id);
172            if id == &row_id {
173                // Need to backtrack as we just read the row length and UUID;
174                // calling decode_row() will try to read the length and UUID.
175                reader.seek(SeekFrom::Start(current_pos)).await?;
176                return Ok((content_offset, Some((current_pos, row_len))));
177            }
178
179            // Move on to the next row
180            reader
181                .seek(SeekFrom::Start(current_pos + 8 + row_len as u64))
182                .await?;
183            current_pos = reader.stream_position().await?;
184        }
185
186        Ok((content_offset, None))
187    }
188}
189
190#[async_trait]
191impl<E> EncryptedEntry for VaultFileWriter<E>
192where
193    E: std::error::Error
194        + std::fmt::Debug
195        + From<sos_vault::Error>
196        + From<sos_core::Error>
197        + From<std::io::Error>
198        + Send
199        + Sync
200        + 'static,
201{
202    type Error = E;
203
204    async fn summary(&self) -> Result<Summary, Self::Error> {
205        Ok(Header::read_summary_file(&self.file_path).await?)
206    }
207
208    async fn vault_name(&self) -> Result<Cow<'_, str>, Self::Error> {
209        let header = Header::read_header_file(&self.file_path).await?;
210        let name = header.name().to_string();
211        Ok(Cow::Owned(name))
212    }
213
214    async fn set_vault_name(
215        &mut self,
216        name: String,
217    ) -> Result<WriteEvent, Self::Error> {
218        let content_offset = self.check_identity().await?;
219        let mut header = Header::read_header_file(&self.file_path).await?;
220        header.set_name(name.clone());
221        self.write_header(content_offset, &header).await?;
222        Ok(WriteEvent::SetVaultName(name))
223    }
224
225    async fn set_vault_flags(
226        &mut self,
227        flags: VaultFlags,
228    ) -> Result<WriteEvent, Self::Error> {
229        let content_offset = self.check_identity().await?;
230        let mut header = Header::read_header_file(&self.file_path).await?;
231        *header.flags_mut() = flags.clone();
232        self.write_header(content_offset, &header).await?;
233        Ok(WriteEvent::SetVaultFlags(flags))
234    }
235
236    async fn set_vault_meta(
237        &mut self,
238        meta_data: AeadPack,
239    ) -> Result<WriteEvent, Self::Error> {
240        let content_offset = self.check_identity().await?;
241        let mut header = Header::read_header_file(&self.file_path).await?;
242        header.set_meta(Some(meta_data.clone()));
243        self.write_header(content_offset, &header).await?;
244        Ok(WriteEvent::SetVaultMeta(meta_data))
245    }
246
247    async fn create_secret(
248        &mut self,
249        commit: CommitHash,
250        secret: VaultEntry,
251    ) -> Result<WriteEvent, Self::Error> {
252        let id = SecretId::new_v4();
253        self.insert_secret(id, commit, secret).await
254    }
255
256    async fn insert_secret(
257        &mut self,
258        id: SecretId,
259        commit: CommitHash,
260        secret: VaultEntry,
261    ) -> Result<WriteEvent, Self::Error> {
262        let _summary = self.summary().await?;
263
264        // Encode the row into a buffer
265        let mut buffer = Vec::new();
266        let mut writer =
267            BinaryWriter::new(Cursor::new(&mut buffer), encoding_options());
268        let row = VaultCommit(commit, secret);
269        Contents::encode_row(&mut writer, &id, &row).await?;
270        writer.flush().await?;
271
272        // Append to the file
273        let file = OpenOptions::new()
274            .read(true)
275            .write(true)
276            .append(true)
277            .open(&self.file_path)
278            .await?;
279        let mut guard = file.lock_write().await.map_err(|e| e.error)?;
280        guard.write_all(&buffer).await?;
281        guard.flush().await?;
282
283        Ok(WriteEvent::CreateSecret(id, row))
284    }
285
286    async fn read_secret<'a>(
287        &'a self,
288        id: &SecretId,
289    ) -> Result<Option<(Cow<'a, VaultCommit>, ReadEvent)>, Self::Error> {
290        let _summary = self.summary().await?;
291        let event = ReadEvent::ReadSecret(*id);
292        let (_, row) = self.find_row(id).await?;
293        if let Some((row_offset, _)) = row {
294            let file =
295                OpenOptions::new().read(true).open(&self.file_path).await?;
296            let mut guard = file.lock_read().await.map_err(|e| e.error)?;
297
298            let mut reader =
299                BinaryReader::new(&mut guard, encoding_options());
300            reader.seek(SeekFrom::Start(row_offset)).await?;
301            let (_, value) = Contents::decode_row(&mut reader).await?;
302            Ok(Some((Cow::Owned(value), event)))
303        } else {
304            Ok(None)
305        }
306    }
307
308    async fn update_secret(
309        &mut self,
310        id: &SecretId,
311        commit: CommitHash,
312        secret: VaultEntry,
313    ) -> Result<Option<WriteEvent>, Self::Error> {
314        let _summary = self.summary().await?;
315        let (_content_offset, row) = self.find_row(id).await?;
316        if let Some((row_offset, row_len)) = row {
317            // Prepare the row
318            let mut buffer = Vec::new();
319            let mut stream = BufWriter::new(Cursor::new(&mut buffer));
320            let mut writer =
321                BinaryWriter::new(&mut stream, encoding_options());
322
323            let row = VaultCommit(commit, secret);
324            Contents::encode_row(&mut writer, id, &row).await?;
325            writer.flush().await?;
326
327            // Splice the row into the file
328            let length = writer.len().await?;
329
330            let head = 0..row_offset;
331            // Row offset is before the row length u32 so we
332            // need to account for that too
333            let tail = (row_offset + 8 + row_len as u64)..length;
334
335            self.splice(head, tail, Some(&buffer)).await?;
336
337            Ok(Some(WriteEvent::UpdateSecret(*id, row)))
338        } else {
339            Ok(None)
340        }
341    }
342
343    async fn delete_secret(
344        &mut self,
345        id: &SecretId,
346    ) -> Result<Option<WriteEvent>, Self::Error> {
347        let _summary = self.summary().await?;
348        let (_content_offset, row) = self.find_row(id).await?;
349        if let Some((row_offset, row_len)) = row {
350            let length = vfs::metadata(&self.file_path).await?.len();
351
352            let head = 0..row_offset;
353            // Row offset is before the row length u32 so we
354            // need to account for that too
355            let tail = (row_offset + 8 + row_len as u64)..length;
356
357            self.splice(head, tail, None).await?;
358
359            Ok(Some(WriteEvent::DeleteSecret(*id)))
360        } else {
361            Ok(None)
362        }
363    }
364
365    async fn replace_vault(
366        &mut self,
367        vault: &Vault,
368    ) -> Result<(), Self::Error> {
369        let buffer = encode(vault).await?;
370
371        let file =
372            OpenOptions::new().write(true).open(&self.file_path).await?;
373        let mut guard = file.lock_write().await.map_err(|e| e.error)?;
374        guard.write_all(&buffer).await?;
375        guard.flush().await?;
376
377        Ok(())
378    }
379}