tdb_succinct/storage/
memory.rs1use std::{
2 pin::Pin,
3 sync::{Arc, RwLock},
4 task::{Context, Poll},
5};
6
7use async_trait::async_trait;
8use bytes::{Bytes, BytesMut};
9use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};
10
11use super::types::{FileLoad, FileStore, SyncableFile};
12
13enum MemoryBackedStoreContents {
14 Nonexistent,
15 Existent(Bytes),
16}
17
18#[derive(Clone)]
19pub struct MemoryBackedStore {
20 contents: Arc<RwLock<MemoryBackedStoreContents>>,
21}
22
23impl MemoryBackedStore {
24 pub fn new() -> Self {
25 Self {
26 contents: Arc::new(RwLock::new(MemoryBackedStoreContents::Nonexistent)),
27 }
28 }
29}
30
31pub struct MemoryBackedStoreWriter {
32 file: MemoryBackedStore,
33 bytes: BytesMut,
34}
35
36#[async_trait]
37impl SyncableFile for MemoryBackedStoreWriter {
38 async fn sync_all(self) -> io::Result<()> {
39 let mut contents = self.file.contents.write().unwrap();
40 *contents = MemoryBackedStoreContents::Existent(self.bytes.freeze());
41
42 Ok(())
43 }
44}
45
46impl std::io::Write for MemoryBackedStoreWriter {
47 fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
48 self.bytes.extend_from_slice(buf);
49
50 Ok(buf.len())
51 }
52
53 fn flush(&mut self) -> Result<(), std::io::Error> {
54 Ok(())
55 }
56}
57
58impl AsyncWrite for MemoryBackedStoreWriter {
59 fn poll_write(
60 self: Pin<&mut Self>,
61 _cx: &mut Context,
62 buf: &[u8],
63 ) -> Poll<Result<usize, io::Error>> {
64 Poll::Ready(std::io::Write::write(self.get_mut(), buf))
65 }
66
67 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
68 Poll::Ready(std::io::Write::flush(self.get_mut()))
69 }
70
71 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
72 self.poll_flush(cx)
73 }
74}
75
76#[async_trait]
77impl FileStore for MemoryBackedStore {
78 type Write = MemoryBackedStoreWriter;
79
80 async fn open_write(&self) -> io::Result<Self::Write> {
81 Ok(MemoryBackedStoreWriter {
82 file: self.clone(),
83 bytes: BytesMut::new(),
84 })
85 }
86}
87
88pub struct MemoryBackedStoreReader {
89 bytes: Bytes,
90 pos: usize,
91}
92
93impl std::io::Read for MemoryBackedStoreReader {
94 fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
95 if self.bytes.len() == self.pos {
96 Ok(0)
98 } else if self.bytes.len() < self.pos + buf.len() {
99 let len = self.bytes.len() - self.pos;
101 buf[..len].copy_from_slice(&self.bytes[self.pos..]);
102
103 self.pos += len;
104
105 Ok(len)
106 } else {
107 buf.copy_from_slice(&self.bytes[self.pos..self.pos + buf.len()]);
109
110 self.pos += buf.len();
111
112 Ok(buf.len())
113 }
114 }
115}
116
117impl AsyncRead for MemoryBackedStoreReader {
118 fn poll_read(
119 self: Pin<&mut Self>,
120 _cx: &mut Context,
121 buf: &mut ReadBuf,
122 ) -> Poll<Result<(), io::Error>> {
123 let slice = buf.initialize_unfilled();
124 let count = std::io::Read::read(self.get_mut(), slice);
125 if count.is_ok() {
126 buf.advance(*count.as_ref().unwrap());
127 }
128
129 Poll::Ready(count.map(|_| ()))
130 }
131}
132
133#[async_trait]
134impl FileLoad for MemoryBackedStore {
135 type Read = MemoryBackedStoreReader;
136
137 async fn exists(&self) -> io::Result<bool> {
138 match &*self.contents.read().unwrap() {
139 MemoryBackedStoreContents::Nonexistent => Ok(false),
140 _ => Ok(true),
141 }
142 }
143
144 async fn size(&self) -> io::Result<usize> {
145 match &*self.contents.read().unwrap() {
146 MemoryBackedStoreContents::Nonexistent => {
147 panic!("tried to retrieve size of nonexistent memory file")
148 }
149 MemoryBackedStoreContents::Existent(bytes) => Ok(bytes.len()),
150 }
151 }
152
153 async fn open_read_from(&self, offset: usize) -> io::Result<MemoryBackedStoreReader> {
154 match &*self.contents.read().unwrap() {
155 MemoryBackedStoreContents::Nonexistent => {
156 panic!("tried to open nonexistent memory file for reading")
157 }
158 MemoryBackedStoreContents::Existent(bytes) => Ok(MemoryBackedStoreReader {
159 bytes: bytes.clone(),
160 pos: offset,
161 }),
162 }
163 }
164
165 async fn map(&self) -> io::Result<Bytes> {
166 match &*self.contents.read().unwrap() {
167 MemoryBackedStoreContents::Nonexistent => Err(io::Error::new(
168 io::ErrorKind::NotFound,
169 "tried to open a nonexistent memory file for reading",
170 )),
171 MemoryBackedStoreContents::Existent(bytes) => Ok(bytes.clone()),
172 }
173 }
174}