sequoia_git/
persistent_set.rs1use std::{
23 collections::BTreeSet,
24 io::{
25 Seek,
26 SeekFrom,
27 Write,
28 },
29 path::Path,
30};
31use buffered_reader::{BufferedReader};
32
/// Size, in bytes, of a single value kept in the set.
const VALUE_BYTES: usize = 32;

/// One element of the set: a fixed-size array of `VALUE_BYTES` bytes.
pub type Value = [u8; VALUE_BYTES];
35
36type File = buffered_reader::File<'static, ()>;
37
38pub struct Set {
39 header: Header,
40 store: File,
41 scratch: BTreeSet<Value>,
42}
43
44impl Set {
50 #[allow(dead_code)]
52 fn len(&self) -> usize {
53 usize::try_from(self.header.entries).expect("representable")
54 + self.scratch.len()
55 }
57
58 pub fn contains(&mut self, value: &Value) -> Result<bool> {
60 Ok(self.stored_values()?.binary_search(value).is_ok()
61 || self.scratch.contains(value))
62 }
63
64 pub fn insert(&mut self, value: Value) {
66 self.scratch.insert(value);
70 }
71
72 fn stored_values(&mut self) -> Result<&[Value]> {
73 let entries = self.header.entries as usize;
74 let bytes = self.store.data_hard(entries * VALUE_BYTES)?;
75 unsafe {
76 Ok(std::slice::from_raw_parts(bytes.as_ptr() as *const Value,
77 entries))
78 }
79 }
80
81 pub fn read<P: AsRef<Path>>(path: P, context: &str) -> Result<Self> {
82 assert_eq!(VALUE_BYTES, std::mem::size_of::<Value>());
86 assert_eq!(std::mem::size_of::<[Value; 2]>(),
87 2 * VALUE_BYTES,
88 "values are unpadded");
89
90 let context: [u8; CONTEXT_BYTES] = context.as_bytes()
91 .try_into()
92 .map_err(|_| Error::BadContext)?;
93
94 let (header, reader) = match File::open(path) {
95 Ok(mut f) => {
96 let header = Header::read(&mut f, context)?;
97 (header, f)
98 },
99 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
100 let t = tempfile::NamedTempFile::new()?;
101 let f = File::open(t.path())?;
104 (Header::new(context), f)
105 },
106 Err(e) => return Err(e.into()),
107 };
108
109 Ok(Set {
113 header,
114 store: reader,
115 scratch: Default::default(),
116 })
117 }
118
119 pub fn write<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
120 if self.scratch.is_empty() {
122 return Ok(());
123 }
124
125 let mut sink = tempfile::NamedTempFile::new_in(
126 path.as_ref().parent().ok_or(Error::BadPath)?)?;
127
128 let mut h = self.header.clone();
130 h.entries = 0; h.write(&mut sink)?;
132
133 let mut entries = 0;
135 let scratch = std::mem::replace(&mut self.scratch, Default::default());
136 let mut stored = self.stored_values()?;
137
138 for new in scratch.iter() {
139 let p = stored.partition_point(|v| v < new);
140
141 let before = &stored[..p];
142 let before_bytes = unsafe {
143 std::slice::from_raw_parts(before.as_ptr() as *const u8,
144 before.len() * VALUE_BYTES)
145 };
146 sink.write_all(before_bytes)?;
147 entries += p;
148
149 if before.is_empty() || &before[p - 1] != new {
151 sink.write_all(new)?;
152 entries += 1;
153 }
154
155 stored = &stored[p..];
157 }
158
159 {
161 let stored_bytes = unsafe {
162 std::slice::from_raw_parts(stored.as_ptr() as *const u8,
163 stored.len() * VALUE_BYTES)
164 };
165 sink.write_all(stored_bytes)?;
166 entries += stored.len();
167 }
168
169 self.scratch = scratch;
171
172 sink.as_file_mut().seek(SeekFrom::Start(0))?;
175 h.entries = entries.try_into().map_err(|_| Error::TooManyEntries)?;
176 h.write(&mut sink)?;
177 sink.flush()?;
178
179 sink.persist(path).map_err(|pe| pe.error)?;
180 Ok(())
181 }
182}
183
/// Exact length, in bytes, of the context recorded in the file
/// header.
const CONTEXT_BYTES: usize = 12;
185
186#[derive(Debug, Clone)]
187struct Header {
188 version: u8,
189 context: [u8; CONTEXT_BYTES],
190 entries: u32,
191}
192
193impl Header {
194 const MAGIC: &'static [u8; 15] = b"StoredSortedSet";
195
196 fn new(context: [u8; CONTEXT_BYTES]) -> Self {
197 Header {
198 version: 1,
199 context,
200 entries: 0,
201 }
202 }
203
204 fn read(reader: &mut File, context: [u8; CONTEXT_BYTES]) -> Result<Self> {
205 let m = reader.data_consume_hard(Self::MAGIC.len())?;
206 if &m[..Self::MAGIC.len()] != &Self::MAGIC[..] {
207 return Err(Error::BadMagic);
208 }
209 let v = reader.data_consume_hard(1)?;
210 let version = v[0];
211 if version != 1 {
212 return Err(Error::UnsupportedVersion(version));
213 }
214
215 let c = &reader.data_consume_hard(context.len())?[..context.len()];
216 if &c[..] != &context[..] {
217 return Err(Error::BadContext);
218 }
219
220 let e = &reader.data_consume_hard(4)?[..4];
221 let entries =
222 u32::from_be_bytes(e.try_into().expect("we read 4 bytes"));
223
224 Ok(Header {
225 version,
226 context,
227 entries,
228 })
229 }
230
231 fn write(&self, sink: &mut dyn Write) -> Result<()> {
232 sink.write_all(Self::MAGIC)?;
233 sink.write_all(&[self.version])?;
234 sink.write_all(&self.context)?;
235 sink.write_all(&self.entries.to_be_bytes())?;
236 Ok(())
237 }
238}
239
240#[derive(thiserror::Error, Debug)]
242pub enum Error {
243 #[error("Bad magic read from file")]
244 BadMagic,
245 #[error("Unsupported version: {0}")]
246 UnsupportedVersion(u8),
247 #[error("Bad context read from file")]
248 BadContext,
249 #[error("Too many entries")]
250 TooManyEntries,
251 #[error("Bad path")]
252 BadPath,
253 #[error("Io error")]
254 Io(#[from] std::io::Error),
255}
256
257pub type Result<T> = ::std::result::Result<T, Error>;