1use std::collections::HashMap;
11use std::fs::File;
12use std::io::{Seek, SeekFrom, Write};
13
14use rvf_types::cow_map::CowMapEntry;
15use rvf_types::{ErrorCode, RvfError};
16
17use crate::cow_map::CowMap;
18use crate::store::simple_shake256_256;
19
/// Record describing a cluster-level change made while flushing writes.
///
/// `CowEngine::flush_writes` returns one of these for every cluster that was
/// copied from the parent file into the local file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WitnessEvent {
    /// Numeric tag identifying the kind of event (`0x0E` for the events
    /// produced by `CowEngine::flush_writes`).
    pub event_type: u8,
    /// Id of the cluster the event refers to.
    pub cluster_id: u32,
    /// Hash of the cluster contents as read from the parent file.
    pub parent_cluster_hash: [u8; 32],
    /// Hash of the cluster contents after the pending writes were applied.
    pub new_cluster_hash: [u8; 32],
}
31
/// A single buffered vector write that has not yet been flushed to disk.
struct PendingWrite {
    /// Byte offset of the vector from the start of its cluster.
    vector_offset_in_cluster: u32,
    /// Raw bytes to copy into the cluster at that offset.
    data: Vec<u8>,
}
39
40pub struct CowEngine {
42 cow_map: CowMap,
44 cluster_size: u32,
46 vectors_per_cluster: u32,
48 bytes_per_vector: u32,
50 l0_cache: HashMap<u32, u64>,
52 write_buffer: HashMap<u32, Vec<PendingWrite>>,
54 frozen: bool,
56 snapshot_epoch: u32,
58}
59
60impl CowEngine {
61 pub fn new(cluster_size: u32, vectors_per_cluster: u32, bytes_per_vector: u32) -> Self {
66 assert!(vectors_per_cluster > 0, "vectors_per_cluster must be > 0");
67 Self {
68 cow_map: CowMap::new_flat(0),
69 cluster_size,
70 vectors_per_cluster,
71 bytes_per_vector,
72 l0_cache: HashMap::new(),
73 write_buffer: HashMap::new(),
74 frozen: false,
75 snapshot_epoch: 0,
76 }
77 }
78
79 pub fn from_parent(
84 cluster_count: u32,
85 cluster_size: u32,
86 vectors_per_cluster: u32,
87 bytes_per_vector: u32,
88 ) -> Self {
89 assert!(vectors_per_cluster > 0, "vectors_per_cluster must be > 0");
90 Self {
91 cow_map: CowMap::new_parent_ref(cluster_count),
92 cluster_size,
93 vectors_per_cluster,
94 bytes_per_vector,
95 l0_cache: HashMap::new(),
96 write_buffer: HashMap::new(),
97 frozen: false,
98 snapshot_epoch: 0,
99 }
100 }
101
102 pub fn cow_map(&self) -> &CowMap {
104 &self.cow_map
105 }
106
107 pub fn read_vector(
109 &self,
110 vector_id: u64,
111 file: &File,
112 parent: Option<&File>,
113 ) -> Result<Vec<u8>, RvfError> {
114 let cluster_id = (vector_id / self.vectors_per_cluster as u64) as u32;
115 let vector_index_in_cluster = (vector_id % self.vectors_per_cluster as u64) as u32;
116 let vector_offset = vector_index_in_cluster * self.bytes_per_vector;
117
118 let cluster_data = self.read_cluster(cluster_id, file, parent)?;
119
120 let start = vector_offset as usize;
121 let end = start + self.bytes_per_vector as usize;
122 if end > cluster_data.len() {
123 return Err(RvfError::Code(ErrorCode::ClusterNotFound));
124 }
125
126 Ok(cluster_data[start..end].to_vec())
127 }
128
129 pub fn read_cluster(
131 &self,
132 cluster_id: u32,
133 file: &File,
134 parent: Option<&File>,
135 ) -> Result<Vec<u8>, RvfError> {
136 if let Some(&cached_offset) = self.l0_cache.get(&cluster_id) {
138 return read_bytes_at(file, cached_offset, self.cluster_size as usize);
139 }
140
141 match self.cow_map.lookup(cluster_id) {
142 CowMapEntry::LocalOffset(offset) => {
143 read_bytes_at(file, offset, self.cluster_size as usize)
144 }
145 CowMapEntry::ParentRef => {
146 let parent_file = parent.ok_or(RvfError::Code(ErrorCode::ParentChainBroken))?;
147 let parent_offset = cluster_id as u64 * self.cluster_size as u64;
148 read_bytes_at(parent_file, parent_offset, self.cluster_size as usize)
149 }
150 CowMapEntry::Unallocated => {
151 Ok(vec![0u8; self.cluster_size as usize])
153 }
154 }
155 }
156
157 pub fn write_vector(
161 &mut self,
162 vector_id: u64,
163 data: &[u8],
164 ) -> Result<(), RvfError> {
165 if self.frozen {
166 return Err(RvfError::Code(ErrorCode::SnapshotFrozen));
167 }
168 if data.len() != self.bytes_per_vector as usize {
169 return Err(RvfError::Code(ErrorCode::DimensionMismatch));
170 }
171
172 let cluster_id = (vector_id / self.vectors_per_cluster as u64) as u32;
173 let vector_index_in_cluster = (vector_id % self.vectors_per_cluster as u64) as u32;
174 let vector_offset = vector_index_in_cluster * self.bytes_per_vector;
175
176 self.write_buffer
177 .entry(cluster_id)
178 .or_default()
179 .push(PendingWrite {
180 vector_offset_in_cluster: vector_offset,
181 data: data.to_vec(),
182 });
183
184 Ok(())
185 }
186
187 pub fn flush_writes(
190 &mut self,
191 file: &mut File,
192 parent: Option<&File>,
193 ) -> Result<Vec<WitnessEvent>, RvfError> {
194 if self.frozen {
195 return Err(RvfError::Code(ErrorCode::SnapshotFrozen));
196 }
197
198 let pending: Vec<(u32, Vec<PendingWrite>)> =
199 self.write_buffer.drain().collect();
200
201 let mut witness_events = Vec::new();
202
203 for (cluster_id, writes) in pending {
204 let entry = self.cow_map.lookup(cluster_id);
205
206 let mut cluster_data = match entry {
208 CowMapEntry::LocalOffset(offset) => {
209 read_bytes_at(file, offset, self.cluster_size as usize)?
211 }
212 CowMapEntry::ParentRef => {
213 let parent_file =
215 parent.ok_or(RvfError::Code(ErrorCode::ParentChainBroken))?;
216 let parent_offset = cluster_id as u64 * self.cluster_size as u64;
217 let parent_data =
218 read_bytes_at(parent_file, parent_offset, self.cluster_size as usize)?;
219 let parent_hash = simple_shake256_256(&parent_data);
220
221 let new_offset = file
223 .seek(SeekFrom::End(0))
224 .map_err(|_| RvfError::Code(ErrorCode::FsyncFailed))?;
225
226 file.write_all(&parent_data)
228 .map_err(|_| RvfError::Code(ErrorCode::FsyncFailed))?;
229
230 self.cow_map
232 .update(cluster_id, CowMapEntry::LocalOffset(new_offset));
233 self.l0_cache.insert(cluster_id, new_offset);
234
235 witness_events.push(WitnessEvent {
237 event_type: 0x0E, cluster_id,
239 parent_cluster_hash: parent_hash,
240 new_cluster_hash: [0u8; 32], });
242
243 parent_data
244 }
245 CowMapEntry::Unallocated => {
246 let zeroed = vec![0u8; self.cluster_size as usize];
248 let new_offset = file
249 .seek(SeekFrom::End(0))
250 .map_err(|_| RvfError::Code(ErrorCode::FsyncFailed))?;
251 file.write_all(&zeroed)
252 .map_err(|_| RvfError::Code(ErrorCode::FsyncFailed))?;
253 self.cow_map
254 .update(cluster_id, CowMapEntry::LocalOffset(new_offset));
255 self.l0_cache.insert(cluster_id, new_offset);
256 zeroed
257 }
258 };
259
260 for pw in &writes {
262 let start = pw.vector_offset_in_cluster as usize;
263 let end = start + pw.data.len();
264 if end > cluster_data.len() {
265 return Err(RvfError::Code(ErrorCode::ClusterNotFound));
266 }
267 cluster_data[start..end].copy_from_slice(&pw.data);
268 }
269
270 if let CowMapEntry::LocalOffset(offset) = self.cow_map.lookup(cluster_id) {
272 file.seek(SeekFrom::Start(offset))
273 .map_err(|_| RvfError::Code(ErrorCode::FsyncFailed))?;
274 file.write_all(&cluster_data)
275 .map_err(|_| RvfError::Code(ErrorCode::FsyncFailed))?;
276
277 let new_hash = simple_shake256_256(&cluster_data);
279 for event in witness_events.iter_mut().rev() {
280 if event.cluster_id == cluster_id {
281 event.new_cluster_hash = new_hash;
282 break;
283 }
284 }
285 }
286 }
287
288 file.sync_all()
289 .map_err(|_| RvfError::Code(ErrorCode::FsyncFailed))?;
290
291 Ok(witness_events)
292 }
293
294 pub fn freeze(&mut self, epoch: u32) -> Result<(), RvfError> {
296 if self.frozen {
297 return Err(RvfError::Code(ErrorCode::SnapshotFrozen));
298 }
299 if !self.write_buffer.is_empty() {
300 return Err(RvfError::Code(ErrorCode::FsyncFailed));
301 }
302 self.frozen = true;
303 self.snapshot_epoch = epoch;
304 Ok(())
305 }
306
307 pub fn is_frozen(&self) -> bool {
309 self.frozen
310 }
311
312 pub fn snapshot_epoch(&self) -> u32 {
314 self.snapshot_epoch
315 }
316
317 pub fn stats(&self) -> CowStats {
319 CowStats {
320 cluster_count: self.cow_map.cluster_count(),
321 local_cluster_count: self.cow_map.local_cluster_count(),
322 cluster_size: self.cluster_size,
323 vectors_per_cluster: self.vectors_per_cluster,
324 frozen: self.frozen,
325 snapshot_epoch: self.snapshot_epoch,
326 pending_writes: self.write_buffer.values().map(|v| v.len()).sum(),
327 }
328 }
329}
330
/// Point-in-time summary of a `CowEngine`, as returned by `CowEngine::stats`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CowStats {
    /// Cluster count reported by the engine's cluster map.
    pub cluster_count: u32,
    /// Local cluster count reported by the engine's cluster map.
    pub local_cluster_count: u32,
    /// Size of one cluster in bytes.
    pub cluster_size: u32,
    /// Number of vectors stored per cluster.
    pub vectors_per_cluster: u32,
    /// Whether the engine has been frozen.
    pub frozen: bool,
    /// Epoch recorded when the engine was frozen (0 if it was not).
    pub snapshot_epoch: u32,
    /// Total number of buffered vector writes not yet flushed.
    pub pending_writes: usize,
}
348
349#[cfg(unix)]
353fn read_bytes_at(file: &File, offset: u64, len: usize) -> Result<Vec<u8>, RvfError> {
354 use std::os::unix::fs::FileExt;
355 let mut buf = vec![0u8; len];
356 file.read_exact_at(&mut buf, offset)
357 .map_err(|_| RvfError::Code(ErrorCode::ClusterNotFound))?;
358 Ok(buf)
359}
360
/// Reads exactly `len` bytes from `file` starting at byte `offset`.
///
/// Fallback for non-unix targets: seeks through a `BufReader` wrapped around
/// the shared handle and then reads, which moves the file's cursor. A failed
/// seek is reported as `FsyncFailed`; a failed or short read as
/// `ClusterNotFound`.
#[cfg(not(unix))]
fn read_bytes_at(file: &File, offset: u64, len: usize) -> Result<Vec<u8>, RvfError> {
    let mut source = std::io::BufReader::new(file);
    if source.seek(SeekFrom::Start(offset)).is_err() {
        return Err(RvfError::Code(ErrorCode::FsyncFailed));
    }

    let mut bytes = vec![0u8; len];
    match std::io::Read::read_exact(&mut source, &mut bytes) {
        Ok(()) => Ok(bytes),
        Err(_) => Err(RvfError::Code(ErrorCode::ClusterNotFound)),
    }
}
375
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Builds a parent file of `cluster_count` clusters in which every byte of
    /// cluster `i` equals `i & 0xFF`, so tests can tell clusters apart.
    fn create_parent_file(cluster_size: u32, cluster_count: u32) -> NamedTempFile {
        let mut f = NamedTempFile::new().unwrap();
        for cluster_id in 0..cluster_count {
            let data = vec![(cluster_id & 0xFF) as u8; cluster_size as usize];
            f.write_all(&data).unwrap();
        }
        f.flush().unwrap();
        f
    }

    #[test]
    fn cow_read_from_parent() {
        let cluster_size = 256u32;
        let vecs_per_cluster = 4u32;
        let bytes_per_vec = 64u32;

        let parent_file = create_parent_file(cluster_size, 4);
        let child_file = NamedTempFile::new().unwrap();

        let engine = CowEngine::from_parent(4, cluster_size, vecs_per_cluster, bytes_per_vec);

        // Nothing was written locally, so the cluster must come from the parent.
        let data = engine
            .read_cluster(2, child_file.as_file(), Some(parent_file.as_file()))
            .unwrap();
        assert_eq!(data.len(), cluster_size as usize);
        assert!(data.iter().all(|&b| b == 2));
    }

    #[test]
    fn cow_write_triggers_copy() {
        let cluster_size = 128u32;
        let vecs_per_cluster = 2u32;
        let bytes_per_vec = 64u32;

        let parent_file = create_parent_file(cluster_size, 2);
        let child_file = NamedTempFile::new().unwrap();

        let mut engine = CowEngine::from_parent(2, cluster_size, vecs_per_cluster, bytes_per_vec);

        let new_data = vec![0xAA; bytes_per_vec as usize];
        engine.write_vector(0, &new_data).unwrap();

        let events = engine
            .flush_writes(
                &mut child_file.as_file().try_clone().unwrap(),
                Some(parent_file.as_file()),
            )
            .unwrap();

        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event_type, 0x0E);
        assert_eq!(events[0].cluster_id, 0);

        assert_eq!(engine.cow_map().local_cluster_count(), 1);

        // The flushed bytes must be readable, the untouched half of the copied
        // cluster must keep the parent's contents, and cluster 1 must still be
        // served from the parent.
        let child = child_file.as_file();
        let parent = Some(parent_file.as_file());
        assert_eq!(engine.read_vector(0, child, parent).unwrap(), new_data);
        assert!(engine
            .read_vector(1, child, parent)
            .unwrap()
            .iter()
            .all(|&b| b == 0));
        assert!(engine
            .read_vector(2, child, parent)
            .unwrap()
            .iter()
            .all(|&b| b == 1));
    }

    #[test]
    fn cow_write_coalescing() {
        let cluster_size = 128u32;
        let vecs_per_cluster = 2u32;
        let bytes_per_vec = 64u32;

        let parent_file = create_parent_file(cluster_size, 2);
        let child_file = NamedTempFile::new().unwrap();

        let mut engine = CowEngine::from_parent(2, cluster_size, vecs_per_cluster, bytes_per_vec);

        // Both vectors live in cluster 0, so a single copy must be emitted.
        let data_a = vec![0xAA; bytes_per_vec as usize];
        let data_b = vec![0xBB; bytes_per_vec as usize];
        engine.write_vector(0, &data_a).unwrap();
        engine.write_vector(1, &data_b).unwrap();

        let events = engine
            .flush_writes(
                &mut child_file.as_file().try_clone().unwrap(),
                Some(parent_file.as_file()),
            )
            .unwrap();

        assert_eq!(events.len(), 1);
        assert_eq!(events[0].cluster_id, 0);

        let child = child_file.as_file();
        let parent = Some(parent_file.as_file());
        assert_eq!(engine.read_vector(0, child, parent).unwrap(), data_a);
        assert_eq!(engine.read_vector(1, child, parent).unwrap(), data_b);
    }

    #[test]
    fn cow_frozen_rejects_writes() {
        let mut engine = CowEngine::new(128, 2, 64);
        engine.freeze(1).unwrap();
        assert!(engine.is_frozen());
        assert_eq!(engine.snapshot_epoch(), 1);

        let result = engine.write_vector(0, &[0u8; 64]);
        assert!(matches!(
            result,
            Err(RvfError::Code(ErrorCode::SnapshotFrozen))
        ));
    }

    #[test]
    fn cow_read_unallocated_returns_zeros() {
        let engine = CowEngine::new(128, 2, 64);
        let child_file = NamedTempFile::new().unwrap();

        let data = engine
            .read_cluster(0, child_file.as_file(), None)
            .unwrap();
        assert_eq!(data.len(), 128);
        assert!(data.iter().all(|&b| b == 0));
    }

    #[test]
    fn cow_stats() {
        let mut engine = CowEngine::from_parent(4, 256, 4, 64);
        let stats = engine.stats();
        assert_eq!(stats.cluster_count, 4);
        assert_eq!(stats.local_cluster_count, 0);
        assert!(!stats.frozen);

        engine.write_vector(0, &[0u8; 64]).unwrap();
        let stats = engine.stats();
        assert_eq!(stats.pending_writes, 1);
    }
}