tdb_succinct/storage/memory.rs

1use std::{
2    pin::Pin,
3    sync::{Arc, RwLock},
4    task::{Context, Poll},
5};
6
7use async_trait::async_trait;
8use bytes::{Bytes, BytesMut};
9use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};
10
11use super::types::{FileLoad, FileStore, SyncableFile};
12
13enum MemoryBackedStoreContents {
14    Nonexistent,
15    Existent(Bytes),
16}
17
18#[derive(Clone)]
19pub struct MemoryBackedStore {
20    contents: Arc<RwLock<MemoryBackedStoreContents>>,
21}
22
23impl MemoryBackedStore {
24    pub fn new() -> Self {
25        Self {
26            contents: Arc::new(RwLock::new(MemoryBackedStoreContents::Nonexistent)),
27        }
28    }
29}
30
31pub struct MemoryBackedStoreWriter {
32    file: MemoryBackedStore,
33    bytes: BytesMut,
34}
35
36#[async_trait]
37impl SyncableFile for MemoryBackedStoreWriter {
38    async fn sync_all(self) -> io::Result<()> {
39        let mut contents = self.file.contents.write().unwrap();
40        *contents = MemoryBackedStoreContents::Existent(self.bytes.freeze());
41
42        Ok(())
43    }
44}
45
46impl std::io::Write for MemoryBackedStoreWriter {
47    fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
48        self.bytes.extend_from_slice(buf);
49
50        Ok(buf.len())
51    }
52
53    fn flush(&mut self) -> Result<(), std::io::Error> {
54        Ok(())
55    }
56}
57
58impl AsyncWrite for MemoryBackedStoreWriter {
59    fn poll_write(
60        self: Pin<&mut Self>,
61        _cx: &mut Context,
62        buf: &[u8],
63    ) -> Poll<Result<usize, io::Error>> {
64        Poll::Ready(std::io::Write::write(self.get_mut(), buf))
65    }
66
67    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
68        Poll::Ready(std::io::Write::flush(self.get_mut()))
69    }
70
71    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
72        self.poll_flush(cx)
73    }
74}
75
76#[async_trait]
77impl FileStore for MemoryBackedStore {
78    type Write = MemoryBackedStoreWriter;
79
80    async fn open_write(&self) -> io::Result<Self::Write> {
81        Ok(MemoryBackedStoreWriter {
82            file: self.clone(),
83            bytes: BytesMut::new(),
84        })
85    }
86}
87
88pub struct MemoryBackedStoreReader {
89    bytes: Bytes,
90    pos: usize,
91}
92
93impl std::io::Read for MemoryBackedStoreReader {
94    fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
95        if self.bytes.len() == self.pos {
96            // end of file
97            Ok(0)
98        } else if self.bytes.len() < self.pos + buf.len() {
99            // read up to end
100            let len = self.bytes.len() - self.pos;
101            buf[..len].copy_from_slice(&self.bytes[self.pos..]);
102
103            self.pos += len;
104
105            Ok(len)
106        } else {
107            // read full buf
108            buf.copy_from_slice(&self.bytes[self.pos..self.pos + buf.len()]);
109
110            self.pos += buf.len();
111
112            Ok(buf.len())
113        }
114    }
115}
116
117impl AsyncRead for MemoryBackedStoreReader {
118    fn poll_read(
119        self: Pin<&mut Self>,
120        _cx: &mut Context,
121        buf: &mut ReadBuf,
122    ) -> Poll<Result<(), io::Error>> {
123        let slice = buf.initialize_unfilled();
124        let count = std::io::Read::read(self.get_mut(), slice);
125        if count.is_ok() {
126            buf.advance(*count.as_ref().unwrap());
127        }
128
129        Poll::Ready(count.map(|_| ()))
130    }
131}
132
133#[async_trait]
134impl FileLoad for MemoryBackedStore {
135    type Read = MemoryBackedStoreReader;
136
137    async fn exists(&self) -> io::Result<bool> {
138        match &*self.contents.read().unwrap() {
139            MemoryBackedStoreContents::Nonexistent => Ok(false),
140            _ => Ok(true),
141        }
142    }
143
144    async fn size(&self) -> io::Result<usize> {
145        match &*self.contents.read().unwrap() {
146            MemoryBackedStoreContents::Nonexistent => {
147                panic!("tried to retrieve size of nonexistent memory file")
148            }
149            MemoryBackedStoreContents::Existent(bytes) => Ok(bytes.len()),
150        }
151    }
152
153    async fn open_read_from(&self, offset: usize) -> io::Result<MemoryBackedStoreReader> {
154        match &*self.contents.read().unwrap() {
155            MemoryBackedStoreContents::Nonexistent => {
156                panic!("tried to open nonexistent memory file for reading")
157            }
158            MemoryBackedStoreContents::Existent(bytes) => Ok(MemoryBackedStoreReader {
159                bytes: bytes.clone(),
160                pos: offset,
161            }),
162        }
163    }
164
165    async fn map(&self) -> io::Result<Bytes> {
166        match &*self.contents.read().unwrap() {
167            MemoryBackedStoreContents::Nonexistent => Err(io::Error::new(
168                io::ErrorKind::NotFound,
169                "tried to open a nonexistent memory file for reading",
170            )),
171            MemoryBackedStoreContents::Existent(bytes) => Ok(bytes.clone()),
172        }
173    }
174}