1use serde::{Deserialize, Serialize};
2use simd_r_drive::storage_engine::DataStore;
3use simd_r_drive::storage_engine::traits::{DataStoreReader, DataStoreWriter};
4use std::collections::HashMap;
5use std::fmt;
6use std::fs;
7use std::hash::{Hash, Hasher};
8use std::io;
9use std::path::{Path, PathBuf};
10use std::sync::RwLock;
11use tempfile::TempDir;
12
13use crate::constants::splits::{
14 ALL_SPLITS, BITCODE_PREFIX, EPOCH_HASH_RECORD_VERSION, EPOCH_HASHES_PREFIX, EPOCH_META_PREFIX,
15 EPOCH_META_RECORD_VERSION, EPOCH_RECORD_TOMBSTONE, META_KEY, SAMPLER_STATE_KEY,
16 SAMPLER_STATE_RECORD_VERSION, SPLIT_PREFIX, STORE_VERSION,
17};
18use crate::data::RecordId;
19use crate::errors::SamplerError;
20use crate::types::SourceId;
21
22#[derive(
24 Clone,
25 Copy,
26 Debug,
27 PartialEq,
28 Eq,
29 Hash,
30 Serialize,
31 Deserialize,
32 bitcode::Encode,
33 bitcode::Decode,
34)]
35pub enum SplitLabel {
36 Train,
38 Validation,
40 Test,
42}
43
44#[derive(Clone, Copy, Debug, Serialize, Deserialize, bitcode::Encode, bitcode::Decode)]
46pub struct SplitRatios {
47 pub train: f32,
49 pub validation: f32,
51 pub test: f32,
53}
54
55impl Default for SplitRatios {
56 fn default() -> Self {
57 Self {
58 train: 0.8,
59 validation: 0.1,
60 test: 0.1,
61 }
62 }
63}
64
65impl SplitRatios {
66 pub fn normalized(self) -> Result<Self, SamplerError> {
68 let sum = self.train + self.validation + self.test;
69 if (sum - 1.0).abs() > 1e-6 {
70 return Err(SamplerError::Configuration(
71 "split ratios must sum to 1.0".to_string(),
72 ));
73 }
74 Ok(self)
75 }
76}
77
78pub use crate::constants::splits::EPOCH_STATE_VERSION;
79
80#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
82pub struct PersistedSplitMeta {
83 pub epoch: u64,
85 pub offset: u64,
87 pub hashes_checksum: u64,
89}
90
91#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
93pub struct PersistedSplitHashes {
94 pub checksum: u64,
96 pub hashes: Vec<u64>,
98}
99
100#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
102pub struct PersistedSamplerState {
103 pub source_cycle_idx: u64,
105 pub source_record_cursors: Vec<(SourceId, u64)>,
107 pub source_epoch: u64,
109 pub rng_state: u64,
111 pub triplet_recipe_rr_idx: u64,
113 pub text_recipe_rr_idx: u64,
115 pub source_stream_cursors: Vec<(SourceId, u64)>,
119}
120
121pub trait SplitStore: Send + Sync {
125 fn label_for(&self, id: &RecordId) -> Option<SplitLabel>;
127 fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError>;
129 fn ratios(&self) -> SplitRatios;
131 fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError>;
133}
134
135pub trait EpochStateStore: Send + Sync {
137 fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError>;
139 fn load_epoch_hashes(
141 &self,
142 label: SplitLabel,
143 ) -> Result<Option<PersistedSplitHashes>, SamplerError>;
144 fn save_epoch_meta(
146 &self,
147 meta: &HashMap<SplitLabel, PersistedSplitMeta>,
148 ) -> Result<(), SamplerError>;
149 fn save_epoch_hashes(
151 &self,
152 label: SplitLabel,
153 hashes: &PersistedSplitHashes,
154 ) -> Result<(), SamplerError>;
155}
156
157pub trait SamplerStateStore: Send + Sync {
159 fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError>;
161 fn save_sampler_state(
163 &self,
164 state: &PersistedSamplerState,
165 save_path: Option<&Path>,
166 ) -> Result<(), SamplerError>;
167}
168
169pub struct DeterministicSplitStore {
171 ratios: SplitRatios,
172 assignments: RwLock<HashMap<RecordId, SplitLabel>>,
173 seed: u64,
174 epoch_meta: RwLock<HashMap<SplitLabel, PersistedSplitMeta>>,
175 epoch_hashes: RwLock<HashMap<SplitLabel, PersistedSplitHashes>>,
176 sampler_state: RwLock<Option<PersistedSamplerState>>,
177}
178
179impl DeterministicSplitStore {
180 pub fn new(ratios: SplitRatios, seed: u64) -> Result<Self, SamplerError> {
182 ratios.normalized()?;
183 Ok(Self {
184 ratios,
185 assignments: RwLock::new(HashMap::new()),
186 seed,
187 epoch_meta: RwLock::new(HashMap::new()),
188 epoch_hashes: RwLock::new(HashMap::new()),
189 sampler_state: RwLock::new(None),
190 })
191 }
192
193 fn derive_label(&self, id: &RecordId) -> SplitLabel {
194 derive_label_for_id(id, self.seed, self.ratios)
195 }
196}
197
198impl SplitStore for DeterministicSplitStore {
199 fn label_for(&self, id: &RecordId) -> Option<SplitLabel> {
200 if let Some(label) = self.assignments.read().ok()?.get(id).copied() {
201 return Some(label);
202 }
203 Some(self.derive_label(id))
204 }
205
206 fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError> {
207 let mut guard = self
208 .assignments
209 .write()
210 .map_err(|_| SamplerError::SplitStore("lock poisoned".into()))?;
211 guard.insert(id, label);
212 Ok(())
213 }
214
215 fn ratios(&self) -> SplitRatios {
216 self.ratios
217 }
218
219 fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError> {
220 Ok(self.derive_label(&id))
221 }
222}
223
224impl EpochStateStore for DeterministicSplitStore {
225 fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError> {
226 self.epoch_meta
227 .read()
228 .map_err(|_| SamplerError::SplitStore("epoch meta lock poisoned".into()))
229 .map(|guard| guard.clone())
230 }
231
232 fn load_epoch_hashes(
233 &self,
234 label: SplitLabel,
235 ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
236 Ok(self
237 .epoch_hashes
238 .read()
239 .map_err(|_| SamplerError::SplitStore("epoch hashes lock poisoned".into()))?
240 .get(&label)
241 .cloned())
242 }
243
244 fn save_epoch_meta(
245 &self,
246 meta: &HashMap<SplitLabel, PersistedSplitMeta>,
247 ) -> Result<(), SamplerError> {
248 *self
249 .epoch_meta
250 .write()
251 .map_err(|_| SamplerError::SplitStore("epoch meta lock poisoned".into()))? =
252 meta.clone();
253 Ok(())
254 }
255
256 fn save_epoch_hashes(
257 &self,
258 label: SplitLabel,
259 hashes: &PersistedSplitHashes,
260 ) -> Result<(), SamplerError> {
261 self.epoch_hashes
262 .write()
263 .map_err(|_| SamplerError::SplitStore("epoch hashes lock poisoned".into()))?
264 .insert(label, hashes.clone());
265 Ok(())
266 }
267}
268
269impl SamplerStateStore for DeterministicSplitStore {
270 fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError> {
271 self.sampler_state
272 .read()
273 .map_err(|_| SamplerError::SplitStore("sampler state lock poisoned".into()))
274 .map(|guard| guard.clone())
275 }
276
277 fn save_sampler_state(
278 &self,
279 state: &PersistedSamplerState,
280 _save_path: Option<&Path>,
281 ) -> Result<(), SamplerError> {
282 *self
283 .sampler_state
284 .write()
285 .map_err(|_| SamplerError::SplitStore("sampler state lock poisoned".into()))? =
286 Some(state.clone());
287 Ok(())
288 }
289}
290
291#[derive(Clone, Copy, Debug, bitcode::Encode, bitcode::Decode)]
292struct StoreMeta {
294 version: u8,
295 seed: u64,
296 ratios: SplitRatios,
297}
298
299fn encode_store_meta(meta: &StoreMeta) -> Vec<u8> {
300 encode_bitcode_payload(&bitcode::encode(meta))
301}
302
303fn decode_store_meta(bytes: &[u8]) -> Result<StoreMeta, SamplerError> {
304 let raw = decode_bitcode_payload(bytes)?;
305 bitcode::decode(&raw).map_err(|err| {
306 SamplerError::SplitStore(format!("failed to decode split store metadata: {err}"))
307 })
308}
309
310pub struct FileSplitStore {
326 store: DataStore,
327 path: PathBuf,
329 save_path: PathBuf,
331 ratios: SplitRatios,
332 seed: u64,
333 _temp_dir: TempDir,
335}
336
337impl fmt::Debug for FileSplitStore {
338 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
339 f.debug_struct("FileSplitStore")
340 .field("path", &self.save_path)
341 .field("ratios", &self.ratios)
342 .field("seed", &self.seed)
343 .finish()
344 }
345}
346
347impl FileSplitStore {
348 pub fn open<P: Into<PathBuf>>(
350 path: P,
351 ratios: SplitRatios,
352 seed: u64,
353 ) -> Result<Self, SamplerError> {
354 Self::open_with_load_path(None::<PathBuf>, path, ratios, seed)
355 }
356
357 pub fn open_with_load_path<LP: Into<PathBuf>, SP: Into<PathBuf>>(
371 load_path: Option<LP>,
372 save_path: SP,
373 ratios: SplitRatios,
374 seed: u64,
375 ) -> Result<Self, SamplerError> {
376 let ratios = ratios.normalized()?;
377 let save_path = coerce_store_path(save_path.into());
378 ensure_parent_dir(&save_path)?;
379
380 let source: Option<PathBuf> = if let Some(lp) = load_path {
383 let lp = coerce_store_path(lp.into());
384 if lp.exists() { Some(lp) } else { None }
385 } else if save_path.exists() {
386 Some(save_path.clone())
387 } else {
388 None
389 };
390
391 let temp_dir = tempfile::tempdir().map_err(|err| {
392 SamplerError::SplitStore(format!("failed to create temp dir for split store: {err}"))
393 })?;
394 let working_path = temp_dir.path().join("working_store.bin");
395 if let Some(src) = &source {
396 fs::copy(src, &working_path).map_err(|err| {
397 SamplerError::SplitStore(format!(
398 "failed to copy split store from '{}' to temp: {err}",
399 src.display()
400 ))
401 })?;
402 }
403
404 let raw_store = DataStore::open(&working_path).map_err(map_store_err)?;
405 let store = Self {
406 store: raw_store,
407 path: working_path,
408 save_path,
409 ratios,
410 seed,
411 _temp_dir: temp_dir,
412 };
413 store.verify_metadata()?;
414 Ok(store)
415 }
416
417 fn verify_metadata(&self) -> Result<(), SamplerError> {
418 match read_bytes(&self.store, META_KEY)? {
419 Some(bytes) => {
420 let meta = decode_store_meta(&bytes)?;
421 if meta.version != STORE_VERSION {
422 return Err(SamplerError::SplitStore(format!(
423 "split store version mismatch (expected {}, found {})",
424 STORE_VERSION, meta.version
425 )));
426 }
427 if meta.seed != self.seed {
428 return Err(SamplerError::SplitStore(format!(
429 "split store seed mismatch (expected {}, found {})",
430 self.seed, meta.seed
431 )));
432 }
433 if !ratios_close(meta.ratios, self.ratios) {
434 return Err(SamplerError::SplitStore(
435 "split store ratios mismatch".into(),
436 ));
437 }
438 }
439 None => {
440 let blob = StoreMeta {
441 version: STORE_VERSION,
442 seed: self.seed,
443 ratios: self.ratios,
444 };
445 let payload = encode_store_meta(&blob);
446 write_bytes(&self.store, META_KEY, &payload)?;
447 }
448 }
449 Ok(())
450 }
451
452 fn read_epoch_meta_entry(
453 &self,
454 label: SplitLabel,
455 ) -> Result<Option<PersistedSplitMeta>, SamplerError> {
456 let key = epoch_meta_key(label);
457 let entry = self.store.read(&key).map_err(map_store_err)?;
458 match entry {
459 None => Ok(None),
460 Some(bytes) => decode_epoch_meta(bytes.as_ref()),
461 }
462 }
463
464 fn write_epoch_meta_entry(
465 &self,
466 label: SplitLabel,
467 meta: Option<&PersistedSplitMeta>,
468 ) -> Result<(), SamplerError> {
469 let key = epoch_meta_key(label);
470 let payload = encode_epoch_meta(meta);
471 self.store
472 .write(&key, payload.as_slice())
473 .map_err(map_store_err)?;
474 Ok(())
475 }
476
477 fn read_epoch_hashes_entry(
478 &self,
479 label: SplitLabel,
480 ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
481 let key = epoch_hashes_key(label);
482 let entry = self.store.read(&key).map_err(map_store_err)?;
483 match entry {
484 None => Ok(None),
485 Some(bytes) => decode_epoch_hashes(bytes.as_ref()),
486 }
487 }
488
489 fn write_epoch_hashes_entry(
490 &self,
491 label: SplitLabel,
492 hashes: &PersistedSplitHashes,
493 ) -> Result<(), SamplerError> {
494 let key = epoch_hashes_key(label);
495 let payload = encode_epoch_hashes(hashes);
496 self.store
497 .write(&key, payload.as_slice())
498 .map_err(map_store_err)?;
499 Ok(())
500 }
501}
502
503impl SplitStore for FileSplitStore {
504 fn label_for(&self, id: &RecordId) -> Option<SplitLabel> {
505 let key = split_key(id);
506 if let Ok(Some(value)) = self.store.read(&key)
507 && let Ok(label) = decode_label(value.as_ref())
508 {
509 return Some(label);
510 }
511 Some(derive_label_for_id(id, self.seed, self.ratios))
512 }
513
514 fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError> {
515 let _ = (id, label);
516 Ok(())
517 }
518
519 fn ratios(&self) -> SplitRatios {
520 self.ratios
521 }
522
523 fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError> {
524 Ok(derive_label_for_id(&id, self.seed, self.ratios))
525 }
526}
527
528impl EpochStateStore for FileSplitStore {
529 fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError> {
530 let mut meta = HashMap::new();
531 for label in ALL_SPLITS {
532 if let Some(entry) = self.read_epoch_meta_entry(label)? {
533 meta.insert(label, entry);
534 }
535 }
536 Ok(meta)
537 }
538
539 fn load_epoch_hashes(
540 &self,
541 label: SplitLabel,
542 ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
543 self.read_epoch_hashes_entry(label)
544 }
545
546 fn save_epoch_meta(
547 &self,
548 meta: &HashMap<SplitLabel, PersistedSplitMeta>,
549 ) -> Result<(), SamplerError> {
550 for label in ALL_SPLITS {
551 self.write_epoch_meta_entry(label, meta.get(&label))?;
552 }
553 Ok(())
554 }
555
556 fn save_epoch_hashes(
557 &self,
558 label: SplitLabel,
559 hashes: &PersistedSplitHashes,
560 ) -> Result<(), SamplerError> {
561 self.write_epoch_hashes_entry(label, hashes)
562 }
563}
564
565impl SamplerStateStore for FileSplitStore {
566 fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError> {
567 match read_bytes(&self.store, SAMPLER_STATE_KEY)? {
568 Some(bytes) => decode_sampler_state(bytes.as_ref()),
569 None => Ok(None),
570 }
571 }
572
573 fn save_sampler_state(
579 &self,
580 state: &PersistedSamplerState,
581 save_to: Option<&Path>,
582 ) -> Result<(), SamplerError> {
583 let dest = if let Some(p) = save_to {
585 coerce_store_path(p.to_path_buf())
586 } else {
587 self.save_path.clone()
588 };
589
590 if save_to.is_some() && dest.exists() {
593 return Err(SamplerError::SplitStore(format!(
594 "refusing to overwrite existing split store '{}'; choose a new path",
595 dest.display()
596 )));
597 }
598
599 let payload = encode_sampler_state(state);
601 write_bytes(&self.store, SAMPLER_STATE_KEY, &payload)?;
602
603 ensure_parent_dir(&dest)?;
605 fs::copy(&self.path, &dest).map_err(|err| {
606 SamplerError::SplitStore(format!(
607 "failed to publish split store to '{}': {err}",
608 dest.display()
609 ))
610 })?;
611
612 Ok(())
613 }
614}
615
616fn decode_label(bytes: &[u8]) -> Result<SplitLabel, SamplerError> {
617 match bytes.first() {
618 Some(b'0') => Ok(SplitLabel::Train),
619 Some(b'1') => Ok(SplitLabel::Validation),
620 Some(b'2') => Ok(SplitLabel::Test),
621 _ => Err(SamplerError::SplitStore("invalid split label".into())),
622 }
623}
624
625fn derive_label_for_id(id: &RecordId, seed: u64, ratios: SplitRatios) -> SplitLabel {
626 let mut hasher = std::collections::hash_map::DefaultHasher::new();
627 id.hash(&mut hasher);
628 seed.hash(&mut hasher);
629 let value = hasher.finish() as f64 / u64::MAX as f64;
630 let train_cut = ratios.train as f64;
631 let val_cut = train_cut + ratios.validation as f64;
632 if value < train_cut {
633 SplitLabel::Train
634 } else if value < val_cut {
635 SplitLabel::Validation
636 } else {
637 SplitLabel::Test
638 }
639}
640
641fn ratios_close(a: SplitRatios, b: SplitRatios) -> bool {
642 ((a.train - b.train).abs() + (a.validation - b.validation).abs() + (a.test - b.test).abs())
643 < 1e-5
644}
645
646fn split_key(id: &RecordId) -> Vec<u8> {
647 let mut key = Vec::with_capacity(SPLIT_PREFIX.len() + id.len());
648 key.extend_from_slice(SPLIT_PREFIX);
649 key.extend_from_slice(id.as_bytes());
650 key
651}
652
653fn read_bytes(store: &DataStore, key: &[u8]) -> Result<Option<Vec<u8>>, SamplerError> {
654 store
655 .read(key)
656 .map_err(map_store_err)?
657 .map(|entry| Ok(entry.as_ref().to_vec()))
658 .transpose()
659}
660
661fn write_bytes(store: &DataStore, key: &[u8], payload: &[u8]) -> Result<(), SamplerError> {
662 store.write(key, payload).map_err(map_store_err)?;
663 Ok(())
664}
665
666fn epoch_meta_key(label: SplitLabel) -> Vec<u8> {
667 let mut key = Vec::with_capacity(EPOCH_META_PREFIX.len() + 12);
668 key.extend_from_slice(EPOCH_META_PREFIX);
669 let suffix = match label {
670 SplitLabel::Train => b"train".as_ref(),
671 SplitLabel::Validation => b"validation".as_ref(),
672 SplitLabel::Test => b"test".as_ref(),
673 };
674 key.extend_from_slice(suffix);
675 key
676}
677
678fn epoch_hashes_key(label: SplitLabel) -> Vec<u8> {
679 let mut key = Vec::with_capacity(EPOCH_HASHES_PREFIX.len() + 12);
680 key.extend_from_slice(EPOCH_HASHES_PREFIX);
681 let suffix = match label {
682 SplitLabel::Train => b"train".as_ref(),
683 SplitLabel::Validation => b"validation".as_ref(),
684 SplitLabel::Test => b"test".as_ref(),
685 };
686 key.extend_from_slice(suffix);
687 key
688}
689
690fn encode_epoch_meta(meta: Option<&PersistedSplitMeta>) -> Vec<u8> {
691 match meta {
692 None => vec![EPOCH_RECORD_TOMBSTONE],
693 Some(meta) => {
694 let payload = encode_bitcode_payload(&bitcode::encode(meta));
695 let mut buf = Vec::with_capacity(1 + payload.len());
696 buf.push(EPOCH_META_RECORD_VERSION);
697 buf.extend_from_slice(&payload);
698 buf
699 }
700 }
701}
702
703fn decode_epoch_meta(bytes: &[u8]) -> Result<Option<PersistedSplitMeta>, SamplerError> {
704 if bytes.is_empty() || bytes[0] == EPOCH_RECORD_TOMBSTONE {
705 return Ok(None);
706 }
707 if bytes[0] != EPOCH_META_RECORD_VERSION {
708 return Err(SamplerError::SplitStore(
709 "epoch meta record version mismatch".into(),
710 ));
711 }
712 let raw = decode_bitcode_payload(&bytes[1..])?;
713 bitcode::decode(&raw)
714 .map(Some)
715 .map_err(|err| SamplerError::SplitStore(format!("corrupt epoch meta record: {err}")))
716}
717
718fn encode_epoch_hashes(hashes: &PersistedSplitHashes) -> Vec<u8> {
719 let payload = encode_bitcode_payload(&bitcode::encode(hashes));
720 let mut buf = Vec::with_capacity(1 + payload.len());
721 buf.push(EPOCH_HASH_RECORD_VERSION);
722 buf.extend_from_slice(&payload);
723 buf
724}
725
726fn decode_epoch_hashes(bytes: &[u8]) -> Result<Option<PersistedSplitHashes>, SamplerError> {
727 if bytes.is_empty() || bytes[0] == EPOCH_RECORD_TOMBSTONE {
728 return Ok(None);
729 }
730 if bytes[0] != EPOCH_HASH_RECORD_VERSION {
731 return Err(SamplerError::SplitStore(
732 "epoch hashes record version mismatch".into(),
733 ));
734 }
735 let raw = decode_bitcode_payload(&bytes[1..])?;
736 bitcode::decode(&raw)
737 .map(Some)
738 .map_err(|err| SamplerError::SplitStore(format!("corrupt epoch hashes record: {err}")))
739}
740
741fn encode_sampler_state(state: &PersistedSamplerState) -> Vec<u8> {
742 let payload = encode_bitcode_payload(&bitcode::encode(state));
743 let mut buf = Vec::with_capacity(1 + payload.len());
744 buf.push(SAMPLER_STATE_RECORD_VERSION);
745 buf.extend_from_slice(&payload);
746 buf
747}
748
749fn decode_sampler_state(bytes: &[u8]) -> Result<Option<PersistedSamplerState>, SamplerError> {
750 if bytes.is_empty() {
751 return Ok(None);
752 }
753 if bytes[0] != SAMPLER_STATE_RECORD_VERSION {
754 return Err(SamplerError::SplitStore(
755 "sampler state record version mismatch".into(),
756 ));
757 }
758 let raw = decode_bitcode_payload(&bytes[1..])?;
759 bitcode::decode(&raw)
760 .map(Some)
761 .map_err(|err| SamplerError::SplitStore(format!("corrupt sampler state record: {err}")))
762}
763
764fn encode_bitcode_payload(bytes: &[u8]) -> Vec<u8> {
765 let mut out = Vec::with_capacity(1 + bytes.len());
766 out.push(BITCODE_PREFIX);
767 out.extend_from_slice(bytes);
768 out
769}
770
771fn decode_bitcode_payload(bytes: &[u8]) -> Result<Vec<u8>, SamplerError> {
772 if bytes.first().copied() != Some(BITCODE_PREFIX) {
773 return Err(SamplerError::SplitStore(
774 "bitcode payload missing expected prefix".into(),
775 ));
776 }
777 Ok(bytes[1..].to_vec())
778}
779
/// Normalization hook for store paths; currently returns `path` unchanged.
fn coerce_store_path(path: PathBuf) -> PathBuf {
    // Identity on purpose: callers route every store path through here.
    path
}
783
784fn ensure_parent_dir(path: &Path) -> Result<(), SamplerError> {
785 if let Some(parent) = path.parent()
786 && !parent.as_os_str().is_empty()
787 {
788 fs::create_dir_all(parent)?;
789 }
790 Ok(())
791}
792
793fn map_store_err(err: io::Error) -> SamplerError {
794 SamplerError::SplitStore(err.to_string())
795}
796
797#[cfg(test)]
798mod tests {
799 use super::*;
800 use std::collections::HashMap;
801 use tempfile::tempdir;
802
803 #[test]
804 fn split_ratios_reject_non_unit_sum() {
805 let invalid = SplitRatios {
806 train: 0.6,
807 validation: 0.3,
808 test: 0.3,
809 };
810
811 let err = DeterministicSplitStore::new(invalid, 1)
812 .err()
813 .expect("expected non-unit split ratios to fail");
814 assert!(matches!(
815 err,
816 SamplerError::Configuration(ref msg) if msg.contains("split ratios must sum to 1.0")
817 ));
818
819 let dir = tempdir().unwrap();
820 let path = dir.path().join("split_store.bin");
821 let err = FileSplitStore::open(&path, invalid, 1).unwrap_err();
822 assert!(matches!(
823 err,
824 SamplerError::Configuration(ref msg) if msg.contains("split ratios must sum to 1.0")
825 ));
826 }
827
828 #[test]
829 fn zero_test_ratio_never_assigns_test_labels() {
830 let ratios = SplitRatios {
831 train: 0.5,
832 validation: 0.5,
833 test: 0.0,
834 };
835 let store = DeterministicSplitStore::new(ratios, 7).unwrap();
836
837 let mut saw_train = false;
838 let mut saw_validation = false;
839 for idx in 0..20_000 {
840 let id = format!("record_{idx}");
841 let label = store.ensure(id).unwrap();
842 assert_ne!(label, SplitLabel::Test);
843 saw_train |= label == SplitLabel::Train;
844 saw_validation |= label == SplitLabel::Validation;
845 if saw_train && saw_validation {
846 break;
847 }
848 }
849
850 assert!(saw_train);
851 assert!(saw_validation);
852 }
853
854 #[test]
857 fn save_none_publishes_to_save_path_and_reloads_cleanly() {
858 let dir = tempdir().unwrap();
859 let path = dir.path().join("persist.bin");
860 let ratios = SplitRatios::default();
861
862 {
863 let store = FileSplitStore::open(&path, ratios, 123).unwrap();
864 let mut meta = HashMap::new();
865 meta.insert(
866 SplitLabel::Train,
867 PersistedSplitMeta {
868 epoch: 3,
869 offset: 7,
870 hashes_checksum: 42,
871 },
872 );
873 store.save_epoch_meta(&meta).unwrap();
874 let state = PersistedSamplerState {
875 source_cycle_idx: 11,
876 source_record_cursors: vec![("s".to_string(), 1)],
877 source_epoch: 5,
878 rng_state: 99,
879 triplet_recipe_rr_idx: 2,
880 text_recipe_rr_idx: 3,
881 source_stream_cursors: vec![],
882 };
883 assert!(!path.exists(), "save_path must not exist before save(None)");
884 store.save_sampler_state(&state, None).unwrap();
885 assert!(
886 path.exists(),
887 "save(None) must publish to save_path on disk"
888 );
889 }
890
891 let reopened = FileSplitStore::open(&path, ratios, 123).unwrap();
893 let loaded_state = reopened.load_sampler_state().unwrap().unwrap();
894 assert_eq!(loaded_state.source_cycle_idx, 11);
895 assert_eq!(loaded_state.source_epoch, 5);
896 assert_eq!(loaded_state.rng_state, 99);
897 let loaded_meta = reopened.load_epoch_meta().unwrap();
898 assert_eq!(loaded_meta.get(&SplitLabel::Train).unwrap().epoch, 3);
899 }
900
901 #[test]
902 fn file_store_rejects_seed_mismatch() {
903 let dir = tempdir().unwrap();
904 let path = dir.path().join("splits.json");
905 let ratios = SplitRatios::default();
906 let store = FileSplitStore::open(&path, ratios, 123).unwrap();
907 store.ensure("abc".to_string()).unwrap();
908 store
910 .save_sampler_state(
911 &PersistedSamplerState {
912 source_cycle_idx: 0,
913 source_record_cursors: vec![],
914 source_epoch: 0,
915 rng_state: 0,
916 triplet_recipe_rr_idx: 0,
917 text_recipe_rr_idx: 0,
918 source_stream_cursors: vec![],
919 },
920 None,
921 )
922 .unwrap();
923 drop(store);
924
925 let err = FileSplitStore::open(&path, ratios, 999).unwrap_err();
926 assert!(matches!(
927 err,
928 SamplerError::SplitStore(msg) if msg.contains("seed")
929 ));
930 }
931
932 #[test]
933 fn file_store_accepts_directory_path() {
934 let dir = tempdir().unwrap();
935 let ratios = SplitRatios::default();
936 let err = FileSplitStore::open(dir.path(), ratios, 777).unwrap_err();
937 assert!(matches!(err, SamplerError::SplitStore(_)));
938 }
939
940 #[test]
941 fn bitcode_payload_requires_prefix() {
942 let err = decode_bitcode_payload(&[0x00, 0x01]).unwrap_err();
943 assert!(
944 matches!(err, SamplerError::SplitStore(msg) if msg.contains("missing expected prefix"))
945 );
946 }
947
948 #[test]
949 fn file_store_round_trips_epoch_and_sampler_state() {
950 let dir = tempdir().unwrap();
951 let path = dir.path().join("epoch_sampler_state.bin");
952 let store = FileSplitStore::open(&path, SplitRatios::default(), 222).unwrap();
953
954 assert!(store.load_epoch_hashes(SplitLabel::Test).unwrap().is_none());
955
956 let mut epoch_meta = HashMap::new();
957 epoch_meta.insert(
958 SplitLabel::Train,
959 PersistedSplitMeta {
960 epoch: 3,
961 offset: 7,
962 hashes_checksum: 42,
963 },
964 );
965 store.save_epoch_meta(&epoch_meta).unwrap();
966
967 let loaded_meta = store.load_epoch_meta().unwrap();
968 let loaded_train = loaded_meta.get(&SplitLabel::Train).unwrap();
969 assert_eq!(loaded_train.epoch, 3);
970 assert_eq!(loaded_train.offset, 7);
971 assert_eq!(loaded_train.hashes_checksum, 42);
972
973 let hashes = PersistedSplitHashes {
974 checksum: 99,
975 hashes: vec![10, 20, 30],
976 };
977 store
978 .save_epoch_hashes(SplitLabel::Validation, &hashes)
979 .unwrap();
980 let loaded_hashes = store
981 .load_epoch_hashes(SplitLabel::Validation)
982 .unwrap()
983 .unwrap();
984 assert_eq!(loaded_hashes.checksum, 99);
985 assert_eq!(loaded_hashes.hashes, vec![10, 20, 30]);
986
987 let state = PersistedSamplerState {
988 source_cycle_idx: 11,
989 source_record_cursors: vec![("source_a".to_string(), 1)],
990 source_epoch: 8,
991 rng_state: 1234,
992 triplet_recipe_rr_idx: 2,
993 text_recipe_rr_idx: 5,
994 source_stream_cursors: vec![("source_a".to_string(), 9)],
995 };
996 store.save_sampler_state(&state, None).unwrap();
997 let loaded_state = store.load_sampler_state().unwrap().unwrap();
998 assert_eq!(loaded_state.source_cycle_idx, 11);
999 assert_eq!(loaded_state.source_epoch, 8);
1000 assert_eq!(loaded_state.rng_state, 1234);
1001 assert_eq!(loaded_state.triplet_recipe_rr_idx, 2);
1002 assert_eq!(loaded_state.text_recipe_rr_idx, 5);
1003 assert_eq!(
1004 loaded_state.source_record_cursors,
1005 vec![("source_a".to_string(), 1)]
1006 );
1007 assert_eq!(
1008 loaded_state.source_stream_cursors,
1009 vec![("source_a".to_string(), 9)]
1010 );
1011
1012 drop(store);
1014 let reopened = FileSplitStore::open(&path, SplitRatios::default(), 222).unwrap();
1015 let disk_state = reopened.load_sampler_state().unwrap().unwrap();
1016 assert_eq!(disk_state.source_cycle_idx, 11);
1017 assert_eq!(disk_state.source_epoch, 8);
1018 assert_eq!(disk_state.rng_state, 1234);
1019 let disk_meta = reopened.load_epoch_meta().unwrap();
1020 assert_eq!(disk_meta.get(&SplitLabel::Train).unwrap().epoch, 3);
1021 let disk_hashes = reopened
1022 .load_epoch_hashes(SplitLabel::Validation)
1023 .unwrap()
1024 .unwrap();
1025 assert_eq!(disk_hashes.checksum, 99);
1026 }
1027
1028 #[test]
1029 fn split_keys_and_labels_cover_helper_paths() {
1030 let key = split_key(&"abc".to_string());
1031 assert!(key.starts_with(SPLIT_PREFIX));
1032
1033 assert!(matches!(decode_label(b"0"), Ok(SplitLabel::Train)));
1034 assert!(matches!(decode_label(b"1"), Ok(SplitLabel::Validation)));
1035 assert!(matches!(decode_label(b"2"), Ok(SplitLabel::Test)));
1036 assert!(decode_label(b"x").is_err());
1037
1038 let epoch_meta_train = epoch_meta_key(SplitLabel::Train);
1039 let epoch_hashes_train = epoch_hashes_key(SplitLabel::Train);
1040 let epoch_hashes_test = epoch_hashes_key(SplitLabel::Test);
1041 assert!(epoch_meta_train.starts_with(EPOCH_META_PREFIX));
1042 assert!(epoch_hashes_train.starts_with(EPOCH_HASHES_PREFIX));
1043 assert!(epoch_hashes_test.starts_with(EPOCH_HASHES_PREFIX));
1044 }
1045
1046 #[test]
1047 fn encode_decode_store_meta_roundtrip_and_corrupt_prefix_error() {
1048 let meta = StoreMeta {
1049 version: STORE_VERSION,
1050 seed: 55,
1051 ratios: SplitRatios::default(),
1052 };
1053 let encoded = encode_store_meta(&meta);
1054 let decoded = decode_store_meta(&encoded).unwrap();
1055 assert_eq!(decoded.version, STORE_VERSION);
1056 assert_eq!(decoded.seed, 55);
1057
1058 let err = decode_store_meta(&[0x00, 0x01]).unwrap_err();
1059 assert!(matches!(
1060 err,
1061 SamplerError::SplitStore(msg) if msg.contains("missing expected prefix")
1062 ));
1063 }
1064
1065 #[test]
1066 fn epoch_and_sampler_decoders_cover_tombstone_and_version_mismatch() {
1067 assert!(decode_epoch_meta(&[]).unwrap().is_none());
1068 assert!(
1069 decode_epoch_meta(&[EPOCH_RECORD_TOMBSTONE])
1070 .unwrap()
1071 .is_none()
1072 );
1073 assert!(decode_epoch_hashes(&[]).unwrap().is_none());
1074 assert!(
1075 decode_epoch_hashes(&[EPOCH_RECORD_TOMBSTONE])
1076 .unwrap()
1077 .is_none()
1078 );
1079 assert!(decode_sampler_state(&[]).unwrap().is_none());
1080
1081 let meta_mismatch = decode_epoch_meta(&[EPOCH_META_RECORD_VERSION.wrapping_add(1), 1]);
1082 assert!(matches!(
1083 meta_mismatch,
1084 Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1085 ));
1086 let hashes_mismatch = decode_epoch_hashes(&[EPOCH_HASH_RECORD_VERSION.wrapping_add(1), 1]);
1087 assert!(matches!(
1088 hashes_mismatch,
1089 Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1090 ));
1091 let state_mismatch =
1092 decode_sampler_state(&[SAMPLER_STATE_RECORD_VERSION.wrapping_add(1), 1]);
1093 assert!(matches!(
1094 state_mismatch,
1095 Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1096 ));
1097 }
1098
1099 #[test]
1100 fn split_store_trait_methods_and_path_helpers_are_exercised() {
1101 let dir = tempdir().unwrap();
1102 let file_path = dir.path().join("nested").join("store.bin");
1103 ensure_parent_dir(&file_path).unwrap();
1104 assert!(file_path.parent().unwrap().exists());
1105
1106 let existing_dir_path = coerce_store_path(dir.path().to_path_buf());
1107 assert_eq!(existing_dir_path, dir.path().to_path_buf());
1108
1109 let ratios = SplitRatios::default();
1110 let store = FileSplitStore::open(&file_path, ratios, 444).unwrap();
1111 assert!((store.ratios().train - ratios.train).abs() < 1e-6);
1112 store
1113 .upsert("record_1".to_string(), SplitLabel::Validation)
1114 .unwrap();
1115 let ensured = store.ensure("record_1".to_string()).unwrap();
1116 assert!(matches!(
1117 ensured,
1118 SplitLabel::Train | SplitLabel::Validation | SplitLabel::Test
1119 ));
1120
1121 let mapped = map_store_err(io::Error::other("boom"));
1122 assert!(matches!(mapped, SamplerError::SplitStore(msg) if msg.contains("boom")));
1123 }
1124
1125 #[test]
1126 fn epoch_and_sampler_encode_decode_roundtrips() {
1127 let meta = PersistedSplitMeta {
1128 epoch: 4,
1129 offset: 9,
1130 hashes_checksum: 21,
1131 };
1132 let encoded_meta = encode_epoch_meta(Some(&meta));
1133 let decoded_meta = decode_epoch_meta(&encoded_meta).unwrap().unwrap();
1134 assert_eq!(decoded_meta.epoch, 4);
1135 assert_eq!(decoded_meta.offset, 9);
1136
1137 let hashes = PersistedSplitHashes {
1138 checksum: 7,
1139 hashes: vec![1, 2, 3],
1140 };
1141 let encoded_hashes = encode_epoch_hashes(&hashes);
1142 let decoded_hashes = decode_epoch_hashes(&encoded_hashes).unwrap().unwrap();
1143 assert_eq!(decoded_hashes.checksum, 7);
1144 assert_eq!(decoded_hashes.hashes, vec![1, 2, 3]);
1145
1146 let state = PersistedSamplerState {
1147 source_cycle_idx: 1,
1148 source_record_cursors: vec![("s".to_string(), 2)],
1149 source_epoch: 3,
1150 rng_state: 4,
1151 triplet_recipe_rr_idx: 5,
1152 text_recipe_rr_idx: 6,
1153 source_stream_cursors: vec![("s".to_string(), 7)],
1154 };
1155 let encoded_state = encode_sampler_state(&state);
1156 let decoded_state = decode_sampler_state(&encoded_state).unwrap().unwrap();
1157 assert_eq!(decoded_state.source_cycle_idx, 1);
1158 assert_eq!(decoded_state.source_epoch, 3);
1159 assert_eq!(decoded_state.rng_state, 4);
1160 }
1161
/// Walks `DeterministicSplitStore` through its store operations: ratio
/// access, label lookup and override, and save/load of epoch metadata, epoch
/// hashes, and sampler state.
#[test]
fn deterministic_store_trait_methods_work() {
    let ratios = SplitRatios::default();
    let store = DeterministicSplitStore::new(ratios, 999).unwrap();
    assert_eq!(store.ratios().train, ratios.train);

    // A label is available before any explicit assignment; an upsert then
    // overrides it.
    let record_id = String::from("source::record");
    let label_before_upsert = store.label_for(&record_id).unwrap();
    assert!(matches!(
        label_before_upsert,
        SplitLabel::Train | SplitLabel::Validation | SplitLabel::Test
    ));
    store
        .upsert(record_id.clone(), SplitLabel::Validation)
        .unwrap();
    assert_eq!(store.label_for(&record_id), Some(SplitLabel::Validation));

    // Epoch metadata: what is saved is what is loaded.
    let test_meta = PersistedSplitMeta {
        epoch: 1,
        offset: 2,
        hashes_checksum: 3,
    };
    let epoch_meta = HashMap::from([(SplitLabel::Test, test_meta)]);
    store.save_epoch_meta(&epoch_meta).unwrap();
    let reloaded_meta = store.load_epoch_meta().unwrap();
    let reloaded_test = reloaded_meta.get(&SplitLabel::Test).unwrap();
    assert_eq!(reloaded_test.offset, 2);

    // Epoch hashes: absent until saved.
    let hashes_before = store.load_epoch_hashes(SplitLabel::Train).unwrap();
    assert!(hashes_before.is_none());
    let train_hashes = PersistedSplitHashes {
        checksum: 11,
        hashes: vec![4, 5],
    };
    store
        .save_epoch_hashes(SplitLabel::Train, &train_hashes)
        .unwrap();
    let hashes_after = store
        .load_epoch_hashes(SplitLabel::Train)
        .unwrap()
        .unwrap();
    assert_eq!(hashes_after.checksum, 11);

    // Sampler state: absent until saved.
    let state_before = store.load_sampler_state().unwrap();
    assert!(state_before.is_none());
    let sampler_state = PersistedSamplerState {
        source_cycle_idx: 1,
        source_record_cursors: vec![(String::from("s1"), 2)],
        source_epoch: 3,
        rng_state: 4,
        triplet_recipe_rr_idx: 5,
        text_recipe_rr_idx: 6,
        source_stream_cursors: vec![(String::from("s1"), 7)],
    };
    store.save_sampler_state(&sampler_state, None).unwrap();
    let state_after = store.load_sampler_state().unwrap().unwrap();
    assert_eq!(state_after.source_epoch, sampler_state.source_epoch);
}
1231
/// Seeds a store at `snapshot_a.bin`, then checks that a store opened with
/// that file as its load path sees the seeded epoch metadata and sampler
/// state, and that reopening `snapshot_a.bin` directly still yields the same
/// values.
#[test]
fn open_with_load_path_bootstraps_state_explicitly() {
    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let path_a = dir.path().join("snapshot_a.bin");
    let path_b = dir.path().join("snapshot_b.bin");

    // Seed snapshot A and close it before anything else opens the file.
    {
        let seed_store = FileSplitStore::open(&path_a, ratios, 42).unwrap();

        let train_meta = PersistedSplitMeta {
            epoch: 5,
            offset: 3,
            hashes_checksum: 999,
        };
        let epoch_meta = HashMap::from([(SplitLabel::Train, train_meta)]);
        seed_store.save_epoch_meta(&epoch_meta).unwrap();

        let sampler_state = PersistedSamplerState {
            source_cycle_idx: 1,
            source_record_cursors: vec![(String::from("s1"), 2)],
            source_epoch: 7,
            rng_state: 123,
            triplet_recipe_rr_idx: 4,
            text_recipe_rr_idx: 6,
            source_stream_cursors: vec![(String::from("s1"), 8)],
        };
        seed_store.save_sampler_state(&sampler_state, None).unwrap();
    }

    // Shared expectation: the seeded epoch and sampler epoch are visible.
    let assert_seeded_state = |store: &FileSplitStore| {
        let epoch_meta = store.load_epoch_meta().unwrap();
        assert_eq!(epoch_meta.get(&SplitLabel::Train).unwrap().epoch, 5);
        let sampler = store.load_sampler_state().unwrap().unwrap();
        assert_eq!(sampler.source_epoch, 7);
    };

    let store_b =
        FileSplitStore::open_with_load_path(Some(path_a.clone()), &path_b, ratios, 42).unwrap();
    assert_seeded_state(&store_b);

    // Snapshot A is reopened while the bootstrapped store is still alive.
    let store_a_again = FileSplitStore::open(&path_a, ratios, 42).unwrap();
    assert_seeded_state(&store_a_again);
}
1299
/// Saving sampler state to an explicit path must yield a store at that path
/// which contains both the split assignment written beforehand and the saved
/// sampler state, without creating a file at the store's own path.
#[test]
fn save_sampler_state_to_new_path_copies_existing_store_first() {
    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let path_a = dir.path().join("source_store.bin");
    let path_b = dir.path().join("mirror_store.bin");

    let store_a = FileSplitStore::open(&path_a, ratios, 42).unwrap();

    // Write a raw assignment record; the assertion below expects the payload
    // `b"1"` to read back as `SplitLabel::Validation`.
    let assigned_id = String::from("record_with_assignment");
    store_a
        .store
        .write(&split_key(&assigned_id), b"1")
        .unwrap();

    let sampler_state = PersistedSamplerState {
        source_cycle_idx: 1,
        source_record_cursors: vec![(String::from("s1"), 2)],
        source_epoch: 9,
        rng_state: 123,
        triplet_recipe_rr_idx: 4,
        text_recipe_rr_idx: 6,
        source_stream_cursors: vec![(String::from("s1"), 8)],
    };
    store_a
        .save_sampler_state(&sampler_state, Some(path_b.as_path()))
        .unwrap();

    // The mirror must carry the earlier assignment as well as the new state.
    let store_b = FileSplitStore::open(&path_b, ratios, 42).unwrap();
    let mirrored_label = store_b.label_for(&assigned_id);
    assert_eq!(mirrored_label, Some(SplitLabel::Validation));
    let mirrored_state = store_b.load_sampler_state().unwrap().unwrap();
    assert_eq!(mirrored_state.source_epoch, 9);

    assert!(
        !path_a.exists(),
        "save_to=Some(...) must not publish to the canonical save_path"
    );
}
1344
/// After a regular save followed by a save to an explicit checkpoint path,
/// the store's own file must still hold the first state while the checkpoint
/// file holds the second.
#[test]
fn save_some_on_regular_open_does_not_modify_working_store() {
    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let path_a = dir.path().join("working_store.bin");
    let path_b = dir.path().join("checkpoint_store.bin");

    // Only these three fields differ between the two states used below.
    let make_state = |cycle_idx, epoch, rng_state| PersistedSamplerState {
        source_cycle_idx: cycle_idx,
        source_record_cursors: vec![],
        source_epoch: epoch,
        rng_state,
        triplet_recipe_rr_idx: 0,
        text_recipe_rr_idx: 0,
        source_stream_cursors: vec![],
    };

    let store_a = FileSplitStore::open(&path_a, ratios, 42).unwrap();

    // Baseline state goes to the store's own path.
    let initial_state = make_state(1, 1, 0);
    store_a.save_sampler_state(&initial_state, None).unwrap();

    // Checkpoint state goes to the explicit path only.
    let checkpoint_state = make_state(99, 99, 42);
    store_a
        .save_sampler_state(&checkpoint_state, Some(path_b.as_path()))
        .unwrap();

    drop(store_a);

    let store_a_disk = FileSplitStore::open(&path_a, ratios, 42).unwrap();
    let state_from_a = store_a_disk.load_sampler_state().unwrap().unwrap();
    assert_eq!(
        state_from_a.source_epoch, 1,
        "save_to=Some(...) must not overwrite the on-disk save_path"
    );

    let store_b = FileSplitStore::open(&path_b, ratios, 42).unwrap();
    let state_from_b = store_b.load_sampler_state().unwrap().unwrap();
    assert_eq!(
        state_from_b.source_epoch, 99,
        "checkpoint store must hold the snapshotted state"
    );
}
1404
/// Covers the `Debug` output of `FileSplitStore` and the two metadata checks
/// performed on reopen: a store version mismatch and a split-ratio mismatch.
#[test]
fn file_store_metadata_mismatch_and_debug_paths_are_covered() {
    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();

    // Both scenarios save the same all-zero sampler state before closing.
    let blank_state = || PersistedSamplerState {
        source_cycle_idx: 0,
        source_record_cursors: vec![],
        source_epoch: 0,
        rng_state: 0,
        triplet_recipe_rr_idx: 0,
        text_recipe_rr_idx: 0,
        source_stream_cursors: vec![],
    };

    // --- Version mismatch -------------------------------------------------
    let path = dir.path().join("meta_mismatch.bin");
    let store = FileSplitStore::open(&path, ratios, 123).unwrap();
    assert!(format!("{store:?}").contains("FileSplitStore"));

    // Overwrite the metadata record with a version the opener should reject.
    let bumped_meta = StoreMeta {
        version: STORE_VERSION.wrapping_add(1),
        seed: 123,
        ratios,
    };
    store
        .store
        .write(META_KEY, &encode_store_meta(&bumped_meta))
        .unwrap();
    // NOTE(review): presumably this save is what puts the tampered store on
    // disk at `path` so the reopen below can see it; confirm against
    // `save_sampler_state`.
    store.save_sampler_state(&blank_state(), None).unwrap();
    drop(store);

    let version_err = FileSplitStore::open(&path, ratios, 123).unwrap_err();
    assert!(
        matches!(version_err, SamplerError::SplitStore(msg) if msg.contains("version mismatch"))
    );

    // --- Ratio mismatch ---------------------------------------------------
    let ratio_path = dir.path().join("ratio_mismatch.bin");
    let baseline = FileSplitStore::open(&ratio_path, ratios, 777).unwrap();
    baseline.save_sampler_state(&blank_state(), None).unwrap();
    drop(baseline);

    let different_ratios = SplitRatios {
        train: 0.7,
        validation: 0.2,
        test: 0.1,
    };
    let ratio_err = FileSplitStore::open(&ratio_path, different_ratios, 777).unwrap_err();
    assert!(
        matches!(ratio_err, SamplerError::SplitStore(msg) if msg.contains("ratios mismatch"))
    );
}
1469
/// Feeds each decode helper a payload with a valid-looking prefix followed by
/// garbage and checks that it fails with a `SamplerError::SplitStore` whose
/// message names the corrupt record kind.
#[test]
fn split_decode_helpers_reject_corrupt_bitcode_payloads() {
    // True when `err` is a `SplitStore` error whose message contains `needle`.
    let is_split_store_err = |err: SamplerError, needle: &str| {
        matches!(err, SamplerError::SplitStore(msg) if msg.contains(needle))
    };

    let corrupt_store_meta = [BITCODE_PREFIX, 0xFF, 0xEE];
    let store_meta_err = decode_store_meta(&corrupt_store_meta).unwrap_err();
    assert!(is_split_store_err(
        store_meta_err,
        "failed to decode split store metadata"
    ));

    let corrupt_epoch_meta = [EPOCH_META_RECORD_VERSION, BITCODE_PREFIX, 0xFF];
    let epoch_meta_err = decode_epoch_meta(&corrupt_epoch_meta).unwrap_err();
    assert!(is_split_store_err(
        epoch_meta_err,
        "corrupt epoch meta record"
    ));

    let corrupt_epoch_hashes = [EPOCH_HASH_RECORD_VERSION, BITCODE_PREFIX, 0xFF];
    let epoch_hashes_err = decode_epoch_hashes(&corrupt_epoch_hashes).unwrap_err();
    assert!(is_split_store_err(
        epoch_hashes_err,
        "corrupt epoch hashes record"
    ));

    let corrupt_sampler_state = [SAMPLER_STATE_RECORD_VERSION, BITCODE_PREFIX, 0xFF];
    let sampler_state_err = decode_sampler_state(&corrupt_sampler_state).unwrap_err();
    assert!(is_split_store_err(
        sampler_state_err,
        "corrupt sampler state record"
    ));
}
1499
/// Checks that an unrecognised stored label payload falls back to the derived
/// label, that the payload `b"1"` reads back as `Validation`, and that the
/// epoch key builders produce the expected prefix and suffix for
/// `SplitLabel::Validation`.
#[test]
fn file_store_label_fallback_and_validation_keys_are_covered() {
    let dir = tempdir().unwrap();
    let store_path = dir.path().join("labels.bin");
    let store = FileSplitStore::open(&store_path, SplitRatios::default(), 42).unwrap();

    let record_id = String::from("bad_label_record");
    let record_key = split_key(&record_id);
    let derived_label = derive_label_for_id(&record_id, 42, SplitRatios::default());

    // An unknown payload byte: lookup should match the derived label.
    store.store.write(&record_key, b"x").unwrap();
    let fallback_label = store.label_for(&record_id);
    assert_eq!(fallback_label, Some(derived_label));

    // A recognised payload byte: lookup should return the stored label.
    store.store.write(&record_key, b"1").unwrap();
    let stored_label = store.label_for(&record_id);
    assert_eq!(stored_label, Some(SplitLabel::Validation));

    let validation_meta_key = epoch_meta_key(SplitLabel::Validation);
    assert!(validation_meta_key.starts_with(EPOCH_META_PREFIX));
    assert!(validation_meta_key.ends_with(b"validation"));

    let validation_hashes_key = epoch_hashes_key(SplitLabel::Validation);
    assert!(validation_hashes_key.starts_with(EPOCH_HASHES_PREFIX));
    assert!(validation_hashes_key.ends_with(b"validation"));
}
1523
/// A bare file name (no directory component) must be accepted by
/// `ensure_parent_dir` and must pass through `coerce_store_path` unchanged.
#[test]
fn ensure_parent_dir_allows_plain_file_names() {
    let bare_name = Path::new("split_store_local.bin");
    ensure_parent_dir(bare_name).unwrap();

    let explicit = PathBuf::from("explicit_store.bin");
    let coerced = coerce_store_path(explicit.clone());
    assert_eq!(coerced, explicit);
}
1530
/// Bootstraps a store from `source.bin` via `open_with_load_path`, saves new
/// sampler state through it, and checks that the source file on disk is left
/// untouched: same length, same bytes, and the originally saved state still
/// loads from it.
#[test]
fn load_path_source_is_never_modified_while_open() {
    let dir = tempdir().unwrap();
    let source = dir.path().join("source.bin");
    let dest = dir.path().join("dest.bin");
    let ratios = SplitRatios::default();

    // Seed the source file with a known sampler state and close it.
    let seeded = FileSplitStore::open(&source, ratios, 77).unwrap();
    let state = PersistedSamplerState {
        source_cycle_idx: 5,
        source_record_cursors: vec![("s".to_string(), 3)],
        source_epoch: 9,
        rng_state: 42,
        triplet_recipe_rr_idx: 1,
        text_recipe_rr_idx: 2,
        source_stream_cursors: vec![("s".to_string(), 4)],
    };
    seeded.save_sampler_state(&state, None).unwrap();
    drop(seeded);

    // Snapshot the full contents, not just the length: a same-size rewrite of
    // the source would otherwise go unnoticed.
    let source_bytes_before = std::fs::read(&source).unwrap();

    // Open with `source` as the load path and write different state through
    // the bootstrapped store.
    let bootstrapped =
        FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 77).unwrap();
    let new_state = PersistedSamplerState {
        source_cycle_idx: 99,
        source_record_cursors: vec![("s".to_string(), 77)],
        source_epoch: 100,
        rng_state: 0,
        triplet_recipe_rr_idx: 0,
        text_recipe_rr_idx: 0,
        source_stream_cursors: vec![],
    };
    bootstrapped.save_sampler_state(&new_state, None).unwrap();
    drop(bootstrapped);

    let source_bytes_after = std::fs::read(&source).unwrap();
    assert_eq!(
        source_bytes_before.len(),
        source_bytes_after.len(),
        "source file was modified during bootstrapped open"
    );
    // `assert!` rather than `assert_eq!` so a failure does not dump both
    // byte buffers into the test output.
    assert!(
        source_bytes_before == source_bytes_after,
        "source file contents changed during bootstrapped open"
    );

    // The source must still decode to the originally seeded state.
    let verify_source = FileSplitStore::open(&source, ratios, 77).unwrap();
    let loaded = verify_source.load_sampler_state().unwrap().unwrap();
    assert_eq!(loaded.source_cycle_idx, 5);
    assert_eq!(loaded.source_epoch, 9);
}
1587
/// On a store opened with a separate load path, saving sampler state with no
/// explicit target must create the save-path file, and that file must hold
/// the saved state.
#[test]
fn save_none_on_bootstrapped_store_publishes_to_save_path() {
    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let source = dir.path().join("load.bin");
    let dest = dir.path().join("save.bin");

    // Open and immediately close a store at the load path.
    drop(FileSplitStore::open(&source, ratios, 11).unwrap());

    assert!(!dest.exists(), "dest must not exist before first save");

    {
        let store =
            FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 11).unwrap();
        let state = PersistedSamplerState {
            source_cycle_idx: 7,
            source_record_cursors: vec![],
            source_epoch: 2,
            rng_state: 1,
            triplet_recipe_rr_idx: 0,
            text_recipe_rr_idx: 0,
            source_stream_cursors: vec![],
        };
        store.save_sampler_state(&state, None).unwrap();
    }

    assert!(dest.exists(), "dest must exist after save(None)");

    let loaded_dest = FileSplitStore::open(&dest, ratios, 11).unwrap();
    let published_state = loaded_dest.load_sampler_state().unwrap().unwrap();
    assert_eq!(published_state.source_cycle_idx, 7);
}
1626
/// On a store opened with a separate load path, saving sampler state to an
/// explicit target must create only that target: the store's own save path
/// stays absent and the explicit file holds the saved state.
#[test]
fn save_some_on_bootstrapped_store_publishes_to_explicit_path_only() {
    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let source = dir.path().join("load.bin");
    let save = dir.path().join("save.bin");
    let other = dir.path().join("other.bin");

    // Open and immediately close a store at the load path.
    drop(FileSplitStore::open(&source, ratios, 22).unwrap());

    {
        let store =
            FileSplitStore::open_with_load_path(Some(&source), &save, ratios, 22).unwrap();
        let state = PersistedSamplerState {
            source_cycle_idx: 3,
            source_record_cursors: vec![],
            source_epoch: 1,
            rng_state: 0,
            triplet_recipe_rr_idx: 0,
            text_recipe_rr_idx: 0,
            source_stream_cursors: vec![],
        };
        store
            .save_sampler_state(&state, Some(other.as_path()))
            .unwrap();
    }

    assert!(
        !save.exists(),
        "canonical save_path must not be created when saving to explicit path"
    );
    assert!(other.exists(), "explicit target must be created");

    let loaded_other = FileSplitStore::open(&other, ratios, 22).unwrap();
    let explicit_state = loaded_other.load_sampler_state().unwrap().unwrap();
    assert_eq!(explicit_state.source_cycle_idx, 3);
}
1670
/// Saving sampler state several times in a row on a bootstrapped store must
/// succeed each time, and the save-path file must end up holding the state
/// from the final save.
#[test]
fn repeated_save_none_on_bootstrapped_store_is_idempotent() {
    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let source = dir.path().join("load.bin");
    let dest = dir.path().join("save.bin");

    // Open and immediately close a store at the load path.
    drop(FileSplitStore::open(&source, ratios, 33).unwrap());

    {
        let store =
            FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 33).unwrap();

        // Three consecutive saves, each with a distinct cycle index.
        for cycle_idx in 1_u64..=3 {
            let state = PersistedSamplerState {
                source_cycle_idx: cycle_idx,
                source_record_cursors: vec![],
                source_epoch: 0,
                rng_state: 0,
                triplet_recipe_rr_idx: 0,
                text_recipe_rr_idx: 0,
                source_stream_cursors: vec![],
            };
            store.save_sampler_state(&state, None).unwrap();
        }
    }

    let reloaded = FileSplitStore::open(&dest, ratios, 33).unwrap();
    let final_state = reloaded.load_sampler_state().unwrap().unwrap();
    assert_eq!(
        final_state.source_cycle_idx, 3,
        "dest should hold the last-saved state"
    );
}
1709}