1use serde::{Deserialize, Serialize};
2use simd_r_drive::storage_engine::DataStore;
3use simd_r_drive::storage_engine::traits::{DataStoreReader, DataStoreWriter};
4use siphasher::sip::SipHasher;
5use std::collections::HashMap;
6use std::fmt;
7use std::fs;
8use std::hash::{Hash, Hasher};
9use std::io;
10use std::path::{Path, PathBuf};
11use std::sync::RwLock;
12use tempfile::TempDir;
13
14use crate::constants::splits::{
15 ALL_SPLITS, BITCODE_PREFIX, EPOCH_HASH_RECORD_VERSION, EPOCH_HASHES_PREFIX, EPOCH_META_PREFIX,
16 EPOCH_META_RECORD_VERSION, EPOCH_RECORD_TOMBSTONE, META_KEY, SAMPLER_STATE_KEY,
17 SAMPLER_STATE_RECORD_VERSION, SPLIT_PREFIX, STORE_VERSION,
18};
19use crate::data::RecordId;
20use crate::errors::SamplerError;
21use crate::types::SourceId;
22
23#[derive(
25 Clone,
26 Copy,
27 Debug,
28 PartialEq,
29 Eq,
30 Hash,
31 Serialize,
32 Deserialize,
33 bitcode::Encode,
34 bitcode::Decode,
35)]
36pub enum SplitLabel {
37 Train,
39 Validation,
41 Test,
43}
44
45#[derive(Clone, Copy, Debug, Serialize, Deserialize, bitcode::Encode, bitcode::Decode)]
47pub struct SplitRatios {
48 pub train: f32,
50 pub validation: f32,
52 pub test: f32,
54}
55
56impl Default for SplitRatios {
57 fn default() -> Self {
58 Self {
59 train: 0.8,
60 validation: 0.1,
61 test: 0.1,
62 }
63 }
64}
65
66impl SplitRatios {
67 pub fn normalized(self) -> Result<Self, SamplerError> {
69 let sum = self.train + self.validation + self.test;
70 if (sum - 1.0).abs() > 1e-6 {
71 return Err(SamplerError::Configuration(
72 "split ratios must sum to 1.0".to_string(),
73 ));
74 }
75 Ok(self)
76 }
77}
78
79pub use crate::constants::splits::EPOCH_STATE_VERSION;
80
81#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
83pub struct PersistedSplitMeta {
84 pub epoch: u64,
86 pub offset: u64,
88 pub hashes_checksum: u64,
90}
91
92#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
94pub struct PersistedSplitHashes {
95 pub checksum: u64,
97 pub hashes: Vec<u64>,
99}
100
101#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
110pub struct PersistedSamplerState {
111 pub source_cycle_idx: u64,
113 pub source_record_cursors: Vec<(SourceId, u64)>,
115 pub epoch: u64,
117 pub epoch_step: u64,
119 pub rng_state: u64,
121 pub triplet_recipe_rr_idx: u64,
123 pub text_recipe_rr_idx: u64,
125 pub source_stream_cursors: Vec<(SourceId, u64)>,
127}
128
129pub trait SplitStore: Send + Sync {
133 fn label_for(&self, id: &RecordId) -> Option<SplitLabel>;
135 fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError>;
137 fn ratios(&self) -> SplitRatios;
139 fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError>;
141}
142
143pub trait EpochStateStore: Send + Sync {
145 fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError>;
147 fn load_epoch_hashes(
149 &self,
150 label: SplitLabel,
151 ) -> Result<Option<PersistedSplitHashes>, SamplerError>;
152 fn save_epoch_meta(
154 &self,
155 meta: &HashMap<SplitLabel, PersistedSplitMeta>,
156 ) -> Result<(), SamplerError>;
157 fn save_epoch_hashes(
159 &self,
160 label: SplitLabel,
161 hashes: &PersistedSplitHashes,
162 ) -> Result<(), SamplerError>;
163}
164
165pub trait SamplerStateStore: Send + Sync {
167 fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError>;
169 fn save_sampler_state(
171 &self,
172 state: &PersistedSamplerState,
173 save_path: Option<&Path>,
174 ) -> Result<(), SamplerError>;
175}
176
177pub struct DeterministicSplitStore {
179 ratios: SplitRatios,
180 assignments: RwLock<HashMap<RecordId, SplitLabel>>,
181 seed: u64,
182 epoch_meta: RwLock<HashMap<SplitLabel, PersistedSplitMeta>>,
183 epoch_hashes: RwLock<HashMap<SplitLabel, PersistedSplitHashes>>,
184 sampler_state: RwLock<Option<PersistedSamplerState>>,
185}
186
187impl DeterministicSplitStore {
188 pub fn new(ratios: SplitRatios, seed: u64) -> Result<Self, SamplerError> {
190 ratios.normalized()?;
191 Ok(Self {
192 ratios,
193 assignments: RwLock::new(HashMap::new()),
194 seed,
195 epoch_meta: RwLock::new(HashMap::new()),
196 epoch_hashes: RwLock::new(HashMap::new()),
197 sampler_state: RwLock::new(None),
198 })
199 }
200
201 fn derive_label(&self, id: &RecordId) -> SplitLabel {
202 derive_label_for_id(id, self.seed, self.ratios)
203 }
204}
205
206impl SplitStore for DeterministicSplitStore {
207 fn label_for(&self, id: &RecordId) -> Option<SplitLabel> {
208 if let Some(label) = self.assignments.read().ok()?.get(id).copied() {
209 return Some(label);
210 }
211 Some(self.derive_label(id))
212 }
213
214 fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError> {
215 let mut guard = self
216 .assignments
217 .write()
218 .map_err(|_| SamplerError::SplitStore("lock poisoned".into()))?;
219 guard.insert(id, label);
220 Ok(())
221 }
222
223 fn ratios(&self) -> SplitRatios {
224 self.ratios
225 }
226
227 fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError> {
228 Ok(self.derive_label(&id))
229 }
230}
231
232impl EpochStateStore for DeterministicSplitStore {
233 fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError> {
234 self.epoch_meta
235 .read()
236 .map_err(|_| SamplerError::SplitStore("epoch meta lock poisoned".into()))
237 .map(|guard| guard.clone())
238 }
239
240 fn load_epoch_hashes(
241 &self,
242 label: SplitLabel,
243 ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
244 Ok(self
245 .epoch_hashes
246 .read()
247 .map_err(|_| SamplerError::SplitStore("epoch hashes lock poisoned".into()))?
248 .get(&label)
249 .cloned())
250 }
251
252 fn save_epoch_meta(
253 &self,
254 meta: &HashMap<SplitLabel, PersistedSplitMeta>,
255 ) -> Result<(), SamplerError> {
256 *self
257 .epoch_meta
258 .write()
259 .map_err(|_| SamplerError::SplitStore("epoch meta lock poisoned".into()))? =
260 meta.clone();
261 Ok(())
262 }
263
264 fn save_epoch_hashes(
265 &self,
266 label: SplitLabel,
267 hashes: &PersistedSplitHashes,
268 ) -> Result<(), SamplerError> {
269 self.epoch_hashes
270 .write()
271 .map_err(|_| SamplerError::SplitStore("epoch hashes lock poisoned".into()))?
272 .insert(label, hashes.clone());
273 Ok(())
274 }
275}
276
277impl SamplerStateStore for DeterministicSplitStore {
278 fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError> {
279 self.sampler_state
280 .read()
281 .map_err(|_| SamplerError::SplitStore("sampler state lock poisoned".into()))
282 .map(|guard| guard.clone())
283 }
284
285 fn save_sampler_state(
286 &self,
287 state: &PersistedSamplerState,
288 _save_path: Option<&Path>,
289 ) -> Result<(), SamplerError> {
290 *self
291 .sampler_state
292 .write()
293 .map_err(|_| SamplerError::SplitStore("sampler state lock poisoned".into()))? =
294 Some(state.clone());
295 Ok(())
296 }
297}
298
299#[derive(Clone, Copy, Debug, bitcode::Encode, bitcode::Decode)]
300struct StoreMeta {
302 version: u8,
303 seed: u64,
304 ratios: SplitRatios,
305}
306
307fn encode_store_meta(meta: &StoreMeta) -> Vec<u8> {
308 encode_bitcode_payload(&bitcode::encode(meta))
309}
310
311fn decode_store_meta(bytes: &[u8]) -> Result<StoreMeta, SamplerError> {
312 let raw = decode_bitcode_payload(bytes)?;
313 bitcode::decode(&raw).map_err(|err| {
314 SamplerError::SplitStore(format!("failed to decode split store metadata: {err}"))
315 })
316}
317
318pub struct FileSplitStore {
334 store: DataStore,
335 path: PathBuf,
337 save_path: PathBuf,
339 ratios: SplitRatios,
340 seed: u64,
341 _temp_dir: TempDir,
343}
344
345impl fmt::Debug for FileSplitStore {
346 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
347 f.debug_struct("FileSplitStore")
348 .field("path", &self.save_path)
349 .field("ratios", &self.ratios)
350 .field("seed", &self.seed)
351 .finish()
352 }
353}
354
355impl FileSplitStore {
356 pub fn open<P: Into<PathBuf>>(
358 path: P,
359 ratios: SplitRatios,
360 seed: u64,
361 ) -> Result<Self, SamplerError> {
362 Self::open_with_load_path(None::<PathBuf>, path, ratios, seed)
363 }
364
365 pub fn open_with_load_path<LP: Into<PathBuf>, SP: Into<PathBuf>>(
379 load_path: Option<LP>,
380 save_path: SP,
381 ratios: SplitRatios,
382 seed: u64,
383 ) -> Result<Self, SamplerError> {
384 let ratios = ratios.normalized()?;
385 let save_path = coerce_store_path(save_path.into());
386 ensure_parent_dir(&save_path)?;
387
388 let source: Option<PathBuf> = if let Some(lp) = load_path {
391 let lp = coerce_store_path(lp.into());
392 if lp.exists() { Some(lp) } else { None }
393 } else if save_path.exists() {
394 Some(save_path.clone())
395 } else {
396 None
397 };
398
399 let temp_dir = tempfile::tempdir().map_err(|err| {
400 SamplerError::SplitStore(format!("failed to create temp dir for split store: {err}"))
401 })?;
402 let working_path = temp_dir.path().join("working_store.bin");
403 if let Some(src) = &source {
404 fs::copy(src, &working_path).map_err(|err| {
405 SamplerError::SplitStore(format!(
406 "failed to copy split store from '{}' to temp: {err}",
407 src.display()
408 ))
409 })?;
410 }
411
412 let raw_store = DataStore::open(&working_path).map_err(map_store_err)?;
413 let store = Self {
414 store: raw_store,
415 path: working_path,
416 save_path,
417 ratios,
418 seed,
419 _temp_dir: temp_dir,
420 };
421 store.verify_metadata()?;
422 Ok(store)
423 }
424
425 fn verify_metadata(&self) -> Result<(), SamplerError> {
426 match read_bytes(&self.store, META_KEY)? {
427 Some(bytes) => {
428 let meta = decode_store_meta(&bytes)?;
429 if meta.version != STORE_VERSION {
430 return Err(SamplerError::SplitStore(format!(
431 "split store version mismatch (expected {}, found {})",
432 STORE_VERSION, meta.version
433 )));
434 }
435 if meta.seed != self.seed {
436 return Err(SamplerError::SplitStore(format!(
437 "split store seed mismatch (expected {}, found {})",
438 self.seed, meta.seed
439 )));
440 }
441 if !ratios_close(meta.ratios, self.ratios) {
442 return Err(SamplerError::SplitStore(
443 "split store ratios mismatch".into(),
444 ));
445 }
446 }
447 None => {
448 let blob = StoreMeta {
449 version: STORE_VERSION,
450 seed: self.seed,
451 ratios: self.ratios,
452 };
453 let payload = encode_store_meta(&blob);
454 write_bytes(&self.store, META_KEY, &payload)?;
455 }
456 }
457 Ok(())
458 }
459
460 fn read_epoch_meta_entry(
461 &self,
462 label: SplitLabel,
463 ) -> Result<Option<PersistedSplitMeta>, SamplerError> {
464 let key = epoch_meta_key(label);
465 let entry = self.store.read(&key).map_err(map_store_err)?;
466 match entry {
467 None => Ok(None),
468 Some(bytes) => decode_epoch_meta(bytes.as_ref()),
469 }
470 }
471
472 fn write_epoch_meta_entry(
473 &self,
474 label: SplitLabel,
475 meta: Option<&PersistedSplitMeta>,
476 ) -> Result<(), SamplerError> {
477 let key = epoch_meta_key(label);
478 let payload = encode_epoch_meta(meta);
479 self.store
480 .write(&key, payload.as_slice())
481 .map_err(map_store_err)?;
482 Ok(())
483 }
484
485 fn read_epoch_hashes_entry(
486 &self,
487 label: SplitLabel,
488 ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
489 let key = epoch_hashes_key(label);
490 let entry = self.store.read(&key).map_err(map_store_err)?;
491 match entry {
492 None => Ok(None),
493 Some(bytes) => decode_epoch_hashes(bytes.as_ref()),
494 }
495 }
496
497 fn write_epoch_hashes_entry(
498 &self,
499 label: SplitLabel,
500 hashes: &PersistedSplitHashes,
501 ) -> Result<(), SamplerError> {
502 let key = epoch_hashes_key(label);
503 let payload = encode_epoch_hashes(hashes);
504 self.store
505 .write(&key, payload.as_slice())
506 .map_err(map_store_err)?;
507 Ok(())
508 }
509}
510
511impl SplitStore for FileSplitStore {
512 fn label_for(&self, id: &RecordId) -> Option<SplitLabel> {
513 let key = split_key(id);
514 if let Ok(Some(value)) = self.store.read(&key)
515 && let Ok(label) = decode_label(value.as_ref())
516 {
517 return Some(label);
518 }
519 Some(derive_label_for_id(id, self.seed, self.ratios))
520 }
521
522 fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError> {
523 let _ = (id, label);
524 Ok(())
525 }
526
527 fn ratios(&self) -> SplitRatios {
528 self.ratios
529 }
530
531 fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError> {
532 Ok(derive_label_for_id(&id, self.seed, self.ratios))
533 }
534}
535
536impl EpochStateStore for FileSplitStore {
537 fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError> {
538 let mut meta = HashMap::new();
539 for label in ALL_SPLITS {
540 if let Some(entry) = self.read_epoch_meta_entry(label)? {
541 meta.insert(label, entry);
542 }
543 }
544 Ok(meta)
545 }
546
547 fn load_epoch_hashes(
548 &self,
549 label: SplitLabel,
550 ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
551 self.read_epoch_hashes_entry(label)
552 }
553
554 fn save_epoch_meta(
555 &self,
556 meta: &HashMap<SplitLabel, PersistedSplitMeta>,
557 ) -> Result<(), SamplerError> {
558 for label in ALL_SPLITS {
559 self.write_epoch_meta_entry(label, meta.get(&label))?;
560 }
561 Ok(())
562 }
563
564 fn save_epoch_hashes(
565 &self,
566 label: SplitLabel,
567 hashes: &PersistedSplitHashes,
568 ) -> Result<(), SamplerError> {
569 self.write_epoch_hashes_entry(label, hashes)
570 }
571}
572
573impl SamplerStateStore for FileSplitStore {
574 fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError> {
575 match read_bytes(&self.store, SAMPLER_STATE_KEY)? {
576 Some(bytes) => decode_sampler_state(bytes.as_ref()),
577 None => Ok(None),
578 }
579 }
580
581 fn save_sampler_state(
587 &self,
588 state: &PersistedSamplerState,
589 save_to: Option<&Path>,
590 ) -> Result<(), SamplerError> {
591 let dest = if let Some(p) = save_to {
593 coerce_store_path(p.to_path_buf())
594 } else {
595 self.save_path.clone()
596 };
597
598 if save_to.is_some() && dest.exists() {
601 return Err(SamplerError::SplitStore(format!(
602 "refusing to overwrite existing split store '{}'; choose a new path",
603 dest.display()
604 )));
605 }
606
607 let payload = encode_sampler_state(state);
609 write_bytes(&self.store, SAMPLER_STATE_KEY, &payload)?;
610
611 ensure_parent_dir(&dest)?;
613 fs::copy(&self.path, &dest).map_err(|err| {
614 SamplerError::SplitStore(format!(
615 "failed to publish split store to '{}': {err}",
616 dest.display()
617 ))
618 })?;
619
620 Ok(())
621 }
622}
623
624fn decode_label(bytes: &[u8]) -> Result<SplitLabel, SamplerError> {
625 match bytes.first() {
626 Some(b'0') => Ok(SplitLabel::Train),
627 Some(b'1') => Ok(SplitLabel::Validation),
628 Some(b'2') => Ok(SplitLabel::Test),
629 _ => Err(SamplerError::SplitStore("invalid split label".into())),
630 }
631}
632
633fn derive_label_for_id(id: &RecordId, seed: u64, ratios: SplitRatios) -> SplitLabel {
634 let mut hasher = SipHasher::new();
635 id.hash(&mut hasher);
636 seed.hash(&mut hasher);
637 let value = hasher.finish() as f64 / u64::MAX as f64;
638 let train_cut = ratios.train as f64;
639 let val_cut = train_cut + ratios.validation as f64;
640 if value < train_cut {
641 SplitLabel::Train
642 } else if value < val_cut {
643 SplitLabel::Validation
644 } else {
645 SplitLabel::Test
646 }
647}
648
649fn ratios_close(a: SplitRatios, b: SplitRatios) -> bool {
650 ((a.train - b.train).abs() + (a.validation - b.validation).abs() + (a.test - b.test).abs())
651 < 1e-5
652}
653
654fn split_key(id: &RecordId) -> Vec<u8> {
655 let mut key = Vec::with_capacity(SPLIT_PREFIX.len() + id.len());
656 key.extend_from_slice(SPLIT_PREFIX);
657 key.extend_from_slice(id.as_bytes());
658 key
659}
660
661fn read_bytes(store: &DataStore, key: &[u8]) -> Result<Option<Vec<u8>>, SamplerError> {
662 store
663 .read(key)
664 .map_err(map_store_err)?
665 .map(|entry| Ok(entry.as_ref().to_vec()))
666 .transpose()
667}
668
669fn write_bytes(store: &DataStore, key: &[u8], payload: &[u8]) -> Result<(), SamplerError> {
670 store.write(key, payload).map_err(map_store_err)?;
671 Ok(())
672}
673
674fn epoch_meta_key(label: SplitLabel) -> Vec<u8> {
675 let mut key = Vec::with_capacity(EPOCH_META_PREFIX.len() + 12);
676 key.extend_from_slice(EPOCH_META_PREFIX);
677 let suffix = match label {
678 SplitLabel::Train => b"train".as_ref(),
679 SplitLabel::Validation => b"validation".as_ref(),
680 SplitLabel::Test => b"test".as_ref(),
681 };
682 key.extend_from_slice(suffix);
683 key
684}
685
686fn epoch_hashes_key(label: SplitLabel) -> Vec<u8> {
687 let mut key = Vec::with_capacity(EPOCH_HASHES_PREFIX.len() + 12);
688 key.extend_from_slice(EPOCH_HASHES_PREFIX);
689 let suffix = match label {
690 SplitLabel::Train => b"train".as_ref(),
691 SplitLabel::Validation => b"validation".as_ref(),
692 SplitLabel::Test => b"test".as_ref(),
693 };
694 key.extend_from_slice(suffix);
695 key
696}
697
698fn encode_epoch_meta(meta: Option<&PersistedSplitMeta>) -> Vec<u8> {
699 match meta {
700 None => vec![EPOCH_RECORD_TOMBSTONE],
701 Some(meta) => {
702 let payload = encode_bitcode_payload(&bitcode::encode(meta));
703 let mut buf = Vec::with_capacity(1 + payload.len());
704 buf.push(EPOCH_META_RECORD_VERSION);
705 buf.extend_from_slice(&payload);
706 buf
707 }
708 }
709}
710
711fn decode_epoch_meta(bytes: &[u8]) -> Result<Option<PersistedSplitMeta>, SamplerError> {
712 if bytes.is_empty() || bytes[0] == EPOCH_RECORD_TOMBSTONE {
713 return Ok(None);
714 }
715 if bytes[0] != EPOCH_META_RECORD_VERSION {
716 return Err(SamplerError::SplitStore(
717 "epoch meta record version mismatch".into(),
718 ));
719 }
720 let raw = decode_bitcode_payload(&bytes[1..])?;
721 bitcode::decode(&raw)
722 .map(Some)
723 .map_err(|err| SamplerError::SplitStore(format!("corrupt epoch meta record: {err}")))
724}
725
726fn encode_epoch_hashes(hashes: &PersistedSplitHashes) -> Vec<u8> {
727 let payload = encode_bitcode_payload(&bitcode::encode(hashes));
728 let mut buf = Vec::with_capacity(1 + payload.len());
729 buf.push(EPOCH_HASH_RECORD_VERSION);
730 buf.extend_from_slice(&payload);
731 buf
732}
733
734fn decode_epoch_hashes(bytes: &[u8]) -> Result<Option<PersistedSplitHashes>, SamplerError> {
735 if bytes.is_empty() || bytes[0] == EPOCH_RECORD_TOMBSTONE {
736 return Ok(None);
737 }
738 if bytes[0] != EPOCH_HASH_RECORD_VERSION {
739 return Err(SamplerError::SplitStore(
740 "epoch hashes record version mismatch".into(),
741 ));
742 }
743 let raw = decode_bitcode_payload(&bytes[1..])?;
744 bitcode::decode(&raw)
745 .map(Some)
746 .map_err(|err| SamplerError::SplitStore(format!("corrupt epoch hashes record: {err}")))
747}
748
749fn encode_sampler_state(state: &PersistedSamplerState) -> Vec<u8> {
750 let payload = encode_bitcode_payload(&bitcode::encode(state));
751 let mut buf = Vec::with_capacity(1 + payload.len());
752 buf.push(SAMPLER_STATE_RECORD_VERSION);
753 buf.extend_from_slice(&payload);
754 buf
755}
756
757fn decode_sampler_state(bytes: &[u8]) -> Result<Option<PersistedSamplerState>, SamplerError> {
758 if bytes.is_empty() {
759 return Ok(None);
760 }
761 if bytes[0] != SAMPLER_STATE_RECORD_VERSION {
762 return Err(SamplerError::SplitStore(
763 "sampler state record version mismatch".into(),
764 ));
765 }
766 let raw = decode_bitcode_payload(&bytes[1..])?;
767 bitcode::decode(&raw)
768 .map(Some)
769 .map_err(|err| SamplerError::SplitStore(format!("corrupt sampler state record: {err}")))
770}
771
772fn encode_bitcode_payload(bytes: &[u8]) -> Vec<u8> {
773 let mut out = Vec::with_capacity(1 + bytes.len());
774 out.push(BITCODE_PREFIX);
775 out.extend_from_slice(bytes);
776 out
777}
778
779fn decode_bitcode_payload(bytes: &[u8]) -> Result<Vec<u8>, SamplerError> {
780 if bytes.first().copied() != Some(BITCODE_PREFIX) {
781 return Err(SamplerError::SplitStore(
782 "bitcode payload missing expected prefix".into(),
783 ));
784 }
785 Ok(bytes[1..].to_vec())
786}
787
/// Normalization hook applied to every store path before use.
///
/// Currently the identity function: the path is returned untouched.
fn coerce_store_path(path: PathBuf) -> PathBuf {
    let unchanged = path;
    unchanged
}
791
792fn ensure_parent_dir(path: &Path) -> Result<(), SamplerError> {
793 if let Some(parent) = path.parent()
794 && !parent.as_os_str().is_empty()
795 {
796 fs::create_dir_all(parent)?;
797 }
798 Ok(())
799}
800
801fn map_store_err(err: io::Error) -> SamplerError {
802 SamplerError::SplitStore(err.to_string())
803}
804
805#[cfg(test)]
806mod tests {
807 use super::*;
808 use std::collections::HashMap;
809 use tempfile::tempdir;
810
811 #[test]
812 fn split_ratios_reject_non_unit_sum() {
813 let invalid = SplitRatios {
814 train: 0.6,
815 validation: 0.3,
816 test: 0.3,
817 };
818
819 let err = DeterministicSplitStore::new(invalid, 1)
820 .err()
821 .expect("expected non-unit split ratios to fail");
822 assert!(matches!(
823 err,
824 SamplerError::Configuration(ref msg) if msg.contains("split ratios must sum to 1.0")
825 ));
826
827 let dir = tempdir().unwrap();
828 let path = dir.path().join("split_store.bin");
829 let err = FileSplitStore::open(&path, invalid, 1).unwrap_err();
830 assert!(matches!(
831 err,
832 SamplerError::Configuration(ref msg) if msg.contains("split ratios must sum to 1.0")
833 ));
834 }
835
836 #[test]
837 fn zero_test_ratio_never_assigns_test_labels() {
838 let ratios = SplitRatios {
839 train: 0.5,
840 validation: 0.5,
841 test: 0.0,
842 };
843 let store = DeterministicSplitStore::new(ratios, 7).unwrap();
844
845 let mut saw_train = false;
846 let mut saw_validation = false;
847 for idx in 0..20_000 {
848 let id = format!("record_{idx}");
849 let label = store.ensure(id).unwrap();
850 assert_ne!(label, SplitLabel::Test);
851 saw_train |= label == SplitLabel::Train;
852 saw_validation |= label == SplitLabel::Validation;
853 if saw_train && saw_validation {
854 break;
855 }
856 }
857
858 assert!(saw_train);
859 assert!(saw_validation);
860 }
861
862 #[test]
865 fn save_none_publishes_to_save_path_and_reloads_cleanly() {
866 let dir = tempdir().unwrap();
867 let path = dir.path().join("persist.bin");
868 let ratios = SplitRatios::default();
869
870 {
871 let store = FileSplitStore::open(&path, ratios, 123).unwrap();
872 let mut meta = HashMap::new();
873 meta.insert(
874 SplitLabel::Train,
875 PersistedSplitMeta {
876 epoch: 3,
877 offset: 7,
878 hashes_checksum: 42,
879 },
880 );
881 store.save_epoch_meta(&meta).unwrap();
882 let state = PersistedSamplerState {
883 source_cycle_idx: 11,
884 source_record_cursors: vec![("s".to_string(), 1)],
885 epoch: 5,
886 epoch_step: 0,
887 rng_state: 99,
888 triplet_recipe_rr_idx: 2,
889 text_recipe_rr_idx: 3,
890 source_stream_cursors: vec![],
891 };
892 assert!(!path.exists(), "save_path must not exist before save(None)");
893 store.save_sampler_state(&state, None).unwrap();
894 assert!(
895 path.exists(),
896 "save(None) must publish to save_path on disk"
897 );
898 }
899
900 let reopened = FileSplitStore::open(&path, ratios, 123).unwrap();
902 let loaded_state = reopened.load_sampler_state().unwrap().unwrap();
903 assert_eq!(loaded_state.source_cycle_idx, 11);
904 assert_eq!(loaded_state.epoch, 5);
905 assert_eq!(loaded_state.rng_state, 99);
906 let loaded_meta = reopened.load_epoch_meta().unwrap();
907 assert_eq!(loaded_meta.get(&SplitLabel::Train).unwrap().epoch, 3);
908 }
909
910 #[test]
911 fn file_store_rejects_seed_mismatch() {
912 let dir = tempdir().unwrap();
913 let path = dir.path().join("splits.json");
914 let ratios = SplitRatios::default();
915 let store = FileSplitStore::open(&path, ratios, 123).unwrap();
916 store.ensure("abc".to_string()).unwrap();
917 store
919 .save_sampler_state(
920 &PersistedSamplerState {
921 source_cycle_idx: 0,
922 source_record_cursors: vec![],
923 epoch: 0,
924 epoch_step: 0,
925 rng_state: 0,
926 triplet_recipe_rr_idx: 0,
927 text_recipe_rr_idx: 0,
928 source_stream_cursors: vec![],
929 },
930 None,
931 )
932 .unwrap();
933 drop(store);
934
935 let err = FileSplitStore::open(&path, ratios, 999).unwrap_err();
936 assert!(matches!(
937 err,
938 SamplerError::SplitStore(msg) if msg.contains("seed")
939 ));
940 }
941
942 #[test]
943 fn file_store_accepts_directory_path() {
944 let dir = tempdir().unwrap();
945 let ratios = SplitRatios::default();
946 let err = FileSplitStore::open(dir.path(), ratios, 777).unwrap_err();
947 assert!(matches!(err, SamplerError::SplitStore(_)));
948 }
949
950 #[test]
951 fn bitcode_payload_requires_prefix() {
952 let err = decode_bitcode_payload(&[0x00, 0x01]).unwrap_err();
953 assert!(
954 matches!(err, SamplerError::SplitStore(msg) if msg.contains("missing expected prefix"))
955 );
956 }
957
958 #[test]
959 fn file_store_round_trips_epoch_and_sampler_state() {
960 let dir = tempdir().unwrap();
961 let path = dir.path().join("epoch_sampler_state.bin");
962 let store = FileSplitStore::open(&path, SplitRatios::default(), 222).unwrap();
963
964 assert!(store.load_epoch_hashes(SplitLabel::Test).unwrap().is_none());
965
966 let mut epoch_meta = HashMap::new();
967 epoch_meta.insert(
968 SplitLabel::Train,
969 PersistedSplitMeta {
970 epoch: 3,
971 offset: 7,
972 hashes_checksum: 42,
973 },
974 );
975 store.save_epoch_meta(&epoch_meta).unwrap();
976
977 let loaded_meta = store.load_epoch_meta().unwrap();
978 let loaded_train = loaded_meta.get(&SplitLabel::Train).unwrap();
979 assert_eq!(loaded_train.epoch, 3);
980 assert_eq!(loaded_train.offset, 7);
981 assert_eq!(loaded_train.hashes_checksum, 42);
982
983 let hashes = PersistedSplitHashes {
984 checksum: 99,
985 hashes: vec![10, 20, 30],
986 };
987 store
988 .save_epoch_hashes(SplitLabel::Validation, &hashes)
989 .unwrap();
990 let loaded_hashes = store
991 .load_epoch_hashes(SplitLabel::Validation)
992 .unwrap()
993 .unwrap();
994 assert_eq!(loaded_hashes.checksum, 99);
995 assert_eq!(loaded_hashes.hashes, vec![10, 20, 30]);
996
997 let state = PersistedSamplerState {
998 source_cycle_idx: 11,
999 source_record_cursors: vec![("source_a".to_string(), 1)],
1000 epoch: 8,
1001 epoch_step: 0,
1002 rng_state: 1234,
1003 triplet_recipe_rr_idx: 2,
1004 text_recipe_rr_idx: 5,
1005 source_stream_cursors: vec![("source_a".to_string(), 9)],
1006 };
1007 store.save_sampler_state(&state, None).unwrap();
1008 let loaded_state = store.load_sampler_state().unwrap().unwrap();
1009 assert_eq!(loaded_state.source_cycle_idx, 11);
1010 assert_eq!(loaded_state.epoch, 8);
1011 assert_eq!(loaded_state.rng_state, 1234);
1012 assert_eq!(loaded_state.triplet_recipe_rr_idx, 2);
1013 assert_eq!(loaded_state.text_recipe_rr_idx, 5);
1014 assert_eq!(
1015 loaded_state.source_record_cursors,
1016 vec![("source_a".to_string(), 1)]
1017 );
1018 assert_eq!(
1019 loaded_state.source_stream_cursors,
1020 vec![("source_a".to_string(), 9)]
1021 );
1022
1023 drop(store);
1025 let reopened = FileSplitStore::open(&path, SplitRatios::default(), 222).unwrap();
1026 let disk_state = reopened.load_sampler_state().unwrap().unwrap();
1027 assert_eq!(disk_state.source_cycle_idx, 11);
1028 assert_eq!(disk_state.epoch, 8);
1029 assert_eq!(disk_state.rng_state, 1234);
1030 let disk_meta = reopened.load_epoch_meta().unwrap();
1031 assert_eq!(disk_meta.get(&SplitLabel::Train).unwrap().epoch, 3);
1032 let disk_hashes = reopened
1033 .load_epoch_hashes(SplitLabel::Validation)
1034 .unwrap()
1035 .unwrap();
1036 assert_eq!(disk_hashes.checksum, 99);
1037 }
1038
1039 #[test]
1040 fn split_keys_and_labels_cover_helper_paths() {
1041 let key = split_key(&"abc".to_string());
1042 assert!(key.starts_with(SPLIT_PREFIX));
1043
1044 assert!(matches!(decode_label(b"0"), Ok(SplitLabel::Train)));
1045 assert!(matches!(decode_label(b"1"), Ok(SplitLabel::Validation)));
1046 assert!(matches!(decode_label(b"2"), Ok(SplitLabel::Test)));
1047 assert!(decode_label(b"x").is_err());
1048
1049 let epoch_meta_train = epoch_meta_key(SplitLabel::Train);
1050 let epoch_hashes_train = epoch_hashes_key(SplitLabel::Train);
1051 let epoch_hashes_test = epoch_hashes_key(SplitLabel::Test);
1052 assert!(epoch_meta_train.starts_with(EPOCH_META_PREFIX));
1053 assert!(epoch_hashes_train.starts_with(EPOCH_HASHES_PREFIX));
1054 assert!(epoch_hashes_test.starts_with(EPOCH_HASHES_PREFIX));
1055 }
1056
1057 #[test]
1058 fn encode_decode_store_meta_roundtrip_and_corrupt_prefix_error() {
1059 let meta = StoreMeta {
1060 version: STORE_VERSION,
1061 seed: 55,
1062 ratios: SplitRatios::default(),
1063 };
1064 let encoded = encode_store_meta(&meta);
1065 let decoded = decode_store_meta(&encoded).unwrap();
1066 assert_eq!(decoded.version, STORE_VERSION);
1067 assert_eq!(decoded.seed, 55);
1068
1069 let err = decode_store_meta(&[0x00, 0x01]).unwrap_err();
1070 assert!(matches!(
1071 err,
1072 SamplerError::SplitStore(msg) if msg.contains("missing expected prefix")
1073 ));
1074 }
1075
1076 #[test]
1077 fn epoch_and_sampler_decoders_cover_tombstone_and_version_mismatch() {
1078 assert!(decode_epoch_meta(&[]).unwrap().is_none());
1079 assert!(
1080 decode_epoch_meta(&[EPOCH_RECORD_TOMBSTONE])
1081 .unwrap()
1082 .is_none()
1083 );
1084 assert!(decode_epoch_hashes(&[]).unwrap().is_none());
1085 assert!(
1086 decode_epoch_hashes(&[EPOCH_RECORD_TOMBSTONE])
1087 .unwrap()
1088 .is_none()
1089 );
1090 assert!(decode_sampler_state(&[]).unwrap().is_none());
1091
1092 let meta_mismatch = decode_epoch_meta(&[EPOCH_META_RECORD_VERSION.wrapping_add(1), 1]);
1093 assert!(matches!(
1094 meta_mismatch,
1095 Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1096 ));
1097 let hashes_mismatch = decode_epoch_hashes(&[EPOCH_HASH_RECORD_VERSION.wrapping_add(1), 1]);
1098 assert!(matches!(
1099 hashes_mismatch,
1100 Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1101 ));
1102 let state_mismatch =
1103 decode_sampler_state(&[SAMPLER_STATE_RECORD_VERSION.wrapping_add(1), 1]);
1104 assert!(matches!(
1105 state_mismatch,
1106 Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1107 ));
1108 }
1109
1110 #[test]
1111 fn split_store_trait_methods_and_path_helpers_are_exercised() {
1112 let dir = tempdir().unwrap();
1113 let file_path = dir.path().join("nested").join("store.bin");
1114 ensure_parent_dir(&file_path).unwrap();
1115 assert!(file_path.parent().unwrap().exists());
1116
1117 let existing_dir_path = coerce_store_path(dir.path().to_path_buf());
1118 assert_eq!(existing_dir_path, dir.path().to_path_buf());
1119
1120 let ratios = SplitRatios::default();
1121 let store = FileSplitStore::open(&file_path, ratios, 444).unwrap();
1122 assert!((store.ratios().train - ratios.train).abs() < 1e-6);
1123 store
1124 .upsert("record_1".to_string(), SplitLabel::Validation)
1125 .unwrap();
1126 let ensured = store.ensure("record_1".to_string()).unwrap();
1127 assert!(matches!(
1128 ensured,
1129 SplitLabel::Train | SplitLabel::Validation | SplitLabel::Test
1130 ));
1131
1132 let mapped = map_store_err(io::Error::other("boom"));
1133 assert!(matches!(mapped, SamplerError::SplitStore(msg) if msg.contains("boom")));
1134 }
1135
1136 #[test]
1137 fn epoch_and_sampler_encode_decode_roundtrips() {
1138 let meta = PersistedSplitMeta {
1139 epoch: 4,
1140 offset: 9,
1141 hashes_checksum: 21,
1142 };
1143 let encoded_meta = encode_epoch_meta(Some(&meta));
1144 let decoded_meta = decode_epoch_meta(&encoded_meta).unwrap().unwrap();
1145 assert_eq!(decoded_meta.epoch, 4);
1146 assert_eq!(decoded_meta.offset, 9);
1147
1148 let hashes = PersistedSplitHashes {
1149 checksum: 7,
1150 hashes: vec![1, 2, 3],
1151 };
1152 let encoded_hashes = encode_epoch_hashes(&hashes);
1153 let decoded_hashes = decode_epoch_hashes(&encoded_hashes).unwrap().unwrap();
1154 assert_eq!(decoded_hashes.checksum, 7);
1155 assert_eq!(decoded_hashes.hashes, vec![1, 2, 3]);
1156
1157 let state = PersistedSamplerState {
1158 source_cycle_idx: 1,
1159 source_record_cursors: vec![("s".to_string(), 2)],
1160 epoch: 3,
1161 epoch_step: 0,
1162 rng_state: 4,
1163 triplet_recipe_rr_idx: 5,
1164 text_recipe_rr_idx: 6,
1165 source_stream_cursors: vec![("s".to_string(), 7)],
1166 };
1167 let encoded_state = encode_sampler_state(&state);
1168 let decoded_state = decode_sampler_state(&encoded_state).unwrap().unwrap();
1169 assert_eq!(decoded_state.source_cycle_idx, 1);
1170 assert_eq!(decoded_state.epoch, 3);
1171 assert_eq!(decoded_state.rng_state, 4);
1172 }
1173
// Walks through the `DeterministicSplitStore` API used by the sampler:
// ratios, label lookup and override, epoch metadata, epoch hashes, and
// sampler state persistence.
#[test]
fn deterministic_store_trait_methods_work() {
    let split_ratios = SplitRatios::default();
    let det_store = DeterministicSplitStore::new(split_ratios, 999).unwrap();

    assert_eq!(det_store.ratios().train, split_ratios.train);

    // Capture the label before the override, then force Validation.
    let record_id = "source::record".to_string();
    let label_before_override = det_store.label_for(&record_id).unwrap();
    det_store
        .upsert(record_id.clone(), SplitLabel::Validation)
        .unwrap();
    let label_after_override = det_store.label_for(&record_id);
    assert_eq!(label_after_override, Some(SplitLabel::Validation));
    assert!(matches!(
        label_before_override,
        SplitLabel::Train | SplitLabel::Validation | SplitLabel::Test
    ));

    // Epoch metadata saved for one split must be readable again.
    let epoch_meta = HashMap::from([(
        SplitLabel::Test,
        PersistedSplitMeta {
            epoch: 1,
            offset: 2,
            hashes_checksum: 3,
        },
    )]);
    det_store.save_epoch_meta(&epoch_meta).unwrap();
    let reloaded_meta = det_store.load_epoch_meta().unwrap();
    let test_meta = reloaded_meta.get(&SplitLabel::Test).unwrap();
    assert_eq!(test_meta.offset, 2);

    // Epoch hashes start out absent and are returned once saved.
    let hashes_before_save = det_store.load_epoch_hashes(SplitLabel::Train).unwrap();
    assert!(hashes_before_save.is_none());
    let train_hashes = PersistedSplitHashes {
        checksum: 11,
        hashes: vec![4, 5],
    };
    det_store
        .save_epoch_hashes(SplitLabel::Train, &train_hashes)
        .unwrap();
    let hashes_after_save = det_store
        .load_epoch_hashes(SplitLabel::Train)
        .unwrap()
        .unwrap();
    assert_eq!(hashes_after_save.checksum, 11);

    // Sampler state follows the same absent-then-present pattern.
    let state_before_save = det_store.load_sampler_state().unwrap();
    assert!(state_before_save.is_none());
    let sampler_state = PersistedSamplerState {
        source_cycle_idx: 1,
        source_record_cursors: vec![("s1".to_string(), 2)],
        epoch: 3,
        epoch_step: 0,
        rng_state: 4,
        triplet_recipe_rr_idx: 5,
        text_recipe_rr_idx: 6,
        source_stream_cursors: vec![("s1".to_string(), 7)],
    };
    det_store.save_sampler_state(&sampler_state, None).unwrap();
    let state_after_save = det_store.load_sampler_state().unwrap().unwrap();
    assert_eq!(state_after_save.epoch, sampler_state.epoch);
}
1244
// Saves epoch metadata and sampler state through a store opened at `path_a`,
// then opens a second store with `path_a` as the explicit load path and
// `path_b` as the save path. The second store must see the saved state, and
// re-opening `path_a` afterwards must still return the same values.
#[test]
fn open_with_load_path_bootstraps_state_explicitly() {
    let dir = tempdir().unwrap();
    let path_a = dir.path().join("snapshot_a.bin");
    let path_b = dir.path().join("snapshot_b.bin");
    let ratios = SplitRatios::default();

    let store_a = FileSplitStore::open(&path_a, ratios, 42).unwrap();

    let mut meta = HashMap::new();
    meta.insert(
        SplitLabel::Train,
        PersistedSplitMeta {
            epoch: 5,
            offset: 3,
            hashes_checksum: 999,
        },
    );
    store_a.save_epoch_meta(&meta).unwrap();

    let sampler_state = PersistedSamplerState {
        source_cycle_idx: 1,
        source_record_cursors: vec![("s1".to_string(), 2)],
        epoch: 7,
        epoch_step: 0,
        rng_state: 123,
        triplet_recipe_rr_idx: 4,
        text_recipe_rr_idx: 6,
        source_stream_cursors: vec![("s1".to_string(), 8)],
    };
    store_a.save_sampler_state(&sampler_state, None).unwrap();
    // Release the first handle before the file is used as a load source.
    drop(store_a);

    // Borrow the load path instead of cloning it; `path_a` is reused below.
    let store_b =
        FileSplitStore::open_with_load_path(Some(&path_a), &path_b, ratios, 42).unwrap();
    assert_eq!(
        store_b
            .load_epoch_meta()
            .unwrap()
            .get(&SplitLabel::Train)
            .unwrap()
            .epoch,
        5
    );
    assert_eq!(store_b.load_sampler_state().unwrap().unwrap().epoch, 7);

    // The original snapshot must still open and hold the same values.
    let store_a_again = FileSplitStore::open(&path_a, ratios, 42).unwrap();
    assert_eq!(
        store_a_again
            .load_epoch_meta()
            .unwrap()
            .get(&SplitLabel::Train)
            .unwrap()
            .epoch,
        5
    );
    assert_eq!(
        store_a_again.load_sampler_state().unwrap().unwrap().epoch,
        7
    );
}
1306
// Writes a raw split assignment into a store opened at `path_a`, saves the
// sampler state to an explicit second path, and checks that the second file
// carries both the assignment and the state while `path_a` is not created.
#[test]
fn save_sampler_state_to_new_path_copies_existing_store_first() {
    let tmp = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let path_a = tmp.path().join("source_store.bin");
    let path_b = tmp.path().join("mirror_store.bin");

    let origin = FileSplitStore::open(&path_a, ratios, 42).unwrap();

    // Raw payload `b"1"` is expected to read back as `Validation` below.
    let assigned_id = "record_with_assignment".to_string();
    origin
        .store
        .write(&split_key(&assigned_id), b"1")
        .unwrap();

    let state_to_mirror = PersistedSamplerState {
        source_cycle_idx: 1,
        source_record_cursors: vec![("s1".to_string(), 2)],
        epoch: 9,
        epoch_step: 0,
        rng_state: 123,
        triplet_recipe_rr_idx: 4,
        text_recipe_rr_idx: 6,
        source_stream_cursors: vec![("s1".to_string(), 8)],
    };
    origin
        .save_sampler_state(&state_to_mirror, Some(path_b.as_path()))
        .unwrap();

    let mirror = FileSplitStore::open(&path_b, ratios, 42).unwrap();
    let mirrored_label = mirror.label_for(&assigned_id);
    assert_eq!(mirrored_label, Some(SplitLabel::Validation));
    let mirrored_state = mirror.load_sampler_state().unwrap().unwrap();
    assert_eq!(mirrored_state.epoch, 9);

    assert!(
        !path_a.exists(),
        "save_to=Some(...) must not publish to the canonical save_path"
    );
}
1349
// Saves one state with `None`, then a different state to an explicit
// checkpoint path. After re-opening, `path_a` must still hold the first
// state and the checkpoint file must hold the second.
#[test]
fn save_some_on_regular_open_does_not_modify_working_store() {
    let tmp = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let path_a = tmp.path().join("working_store.bin");
    let path_b = tmp.path().join("checkpoint_store.bin");

    // Two states that differ in cycle index, epoch, and RNG state.
    let baseline = PersistedSamplerState {
        source_cycle_idx: 1,
        source_record_cursors: vec![],
        epoch: 1,
        epoch_step: 0,
        rng_state: 0,
        triplet_recipe_rr_idx: 0,
        text_recipe_rr_idx: 0,
        source_stream_cursors: vec![],
    };
    let snapshot = PersistedSamplerState {
        source_cycle_idx: 99,
        source_record_cursors: vec![],
        epoch: 99,
        epoch_step: 0,
        rng_state: 42,
        triplet_recipe_rr_idx: 0,
        text_recipe_rr_idx: 0,
        source_stream_cursors: vec![],
    };

    let working = FileSplitStore::open(&path_a, ratios, 42).unwrap();
    working.save_sampler_state(&baseline, None).unwrap();
    working
        .save_sampler_state(&snapshot, Some(path_b.as_path()))
        .unwrap();
    drop(working);

    let reopened_working = FileSplitStore::open(&path_a, ratios, 42).unwrap();
    let working_epoch = reopened_working
        .load_sampler_state()
        .unwrap()
        .unwrap()
        .epoch;
    assert_eq!(
        working_epoch, 1,
        "save_to=Some(...) must not overwrite the on-disk save_path"
    );

    let reopened_checkpoint = FileSplitStore::open(&path_b, ratios, 42).unwrap();
    let checkpoint_epoch = reopened_checkpoint
        .load_sampler_state()
        .unwrap()
        .unwrap()
        .epoch;
    assert_eq!(
        checkpoint_epoch, 99,
        "checkpoint store must hold the snapshotted state"
    );
}
1407
// Covers the `Debug` output of `FileSplitStore` and the two re-open failures
// asserted below: a stored version that differs from `STORE_VERSION`, and
// ratios that differ from the ones the store was saved with.
#[test]
fn file_store_metadata_mismatch_and_debug_paths_are_covered() {
    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();

    // All-zero sampler state, saved with `None` in both scenarios.
    let blank_state = PersistedSamplerState {
        source_cycle_idx: 0,
        source_record_cursors: vec![],
        epoch: 0,
        epoch_step: 0,
        rng_state: 0,
        triplet_recipe_rr_idx: 0,
        text_recipe_rr_idx: 0,
        source_stream_cursors: vec![],
    };

    // Scenario 1: Debug output, then a version mismatch on re-open.
    let path = dir.path().join("meta_mismatch.bin");
    let store = FileSplitStore::open(&path, ratios, 123).unwrap();

    let debug_repr = format!("{store:?}");
    assert!(debug_repr.contains("FileSplitStore"));

    // Overwrite the stored metadata with a bumped version number.
    let bumped_meta = StoreMeta {
        version: STORE_VERSION.wrapping_add(1),
        seed: 123,
        ratios,
    };
    let bumped_bytes = encode_store_meta(&bumped_meta);
    store.store.write(META_KEY, &bumped_bytes).unwrap();
    store.save_sampler_state(&blank_state, None).unwrap();
    drop(store);

    let version_err = FileSplitStore::open(&path, ratios, 123).unwrap_err();
    assert!(
        matches!(version_err, SamplerError::SplitStore(msg) if msg.contains("version mismatch"))
    );

    // Scenario 2: same seed, different ratios on re-open.
    let ratio_path = dir.path().join("ratio_mismatch.bin");
    let baseline = FileSplitStore::open(&ratio_path, ratios, 777).unwrap();
    baseline.save_sampler_state(&blank_state, None).unwrap();
    drop(baseline);

    let different_ratios = SplitRatios {
        train: 0.7,
        validation: 0.2,
        test: 0.1,
    };
    let ratio_err = FileSplitStore::open(&ratio_path, different_ratios, 777).unwrap_err();
    assert!(
        matches!(ratio_err, SamplerError::SplitStore(msg) if msg.contains("ratios mismatch"))
    );
}
1474
// Feeds each decode helper a short payload with a valid prefix followed by
// junk bytes and checks that the error message names the record kind.
#[test]
fn split_decode_helpers_reject_corrupt_bitcode_payloads() {
    // Store metadata: bitcode prefix followed by two junk bytes.
    let bad_store_meta = [BITCODE_PREFIX, 0xFF, 0xEE];
    let err = decode_store_meta(&bad_store_meta).unwrap_err();
    assert!(matches!(
        err,
        SamplerError::SplitStore(msg) if msg.contains("failed to decode split store metadata")
    ));

    // Epoch meta: record version, bitcode prefix, one junk byte.
    let bad_epoch_meta = [EPOCH_META_RECORD_VERSION, BITCODE_PREFIX, 0xFF];
    let err = decode_epoch_meta(&bad_epoch_meta).unwrap_err();
    assert!(matches!(
        err,
        SamplerError::SplitStore(msg) if msg.contains("corrupt epoch meta record")
    ));

    // Epoch hashes: same layout with the hashes record version.
    let bad_epoch_hashes = [EPOCH_HASH_RECORD_VERSION, BITCODE_PREFIX, 0xFF];
    let err = decode_epoch_hashes(&bad_epoch_hashes).unwrap_err();
    assert!(matches!(
        err,
        SamplerError::SplitStore(msg) if msg.contains("corrupt epoch hashes record")
    ));

    // Sampler state: same layout with the sampler-state record version.
    let bad_sampler_state = [SAMPLER_STATE_RECORD_VERSION, BITCODE_PREFIX, 0xFF];
    let err = decode_sampler_state(&bad_sampler_state).unwrap_err();
    assert!(matches!(
        err,
        SamplerError::SplitStore(msg) if msg.contains("corrupt sampler state record")
    ));
}
1504
// Checks how `label_for` reads raw stored payloads, and that the epoch
// meta/hashes keys for the Validation split carry the expected prefix and
// suffix.
#[test]
fn file_store_label_fallback_and_validation_keys_are_covered() {
    let tmp = tempdir().unwrap();
    let store_path = tmp.path().join("labels.bin");
    let file_store = FileSplitStore::open(&store_path, SplitRatios::default(), 42).unwrap();

    let record_id = "bad_label_record".to_string();
    let raw_key = split_key(&record_id);
    let derived_label = derive_label_for_id(&record_id, 42, SplitRatios::default());

    // Payload `b"x"`: the lookup is expected to match the derived label.
    file_store.store.write(&raw_key, b"x").unwrap();
    let label_for_bad_payload = file_store.label_for(&record_id);
    assert_eq!(label_for_bad_payload, Some(derived_label));

    // Payload `b"1"`: the lookup is expected to return `Validation`.
    file_store.store.write(&raw_key, b"1").unwrap();
    let label_for_good_payload = file_store.label_for(&record_id);
    assert_eq!(label_for_good_payload, Some(SplitLabel::Validation));

    let validation_meta_key = epoch_meta_key(SplitLabel::Validation);
    assert!(validation_meta_key.starts_with(EPOCH_META_PREFIX));
    assert!(validation_meta_key.ends_with(b"validation"));

    let validation_hashes_key = epoch_hashes_key(SplitLabel::Validation);
    assert!(validation_hashes_key.starts_with(EPOCH_HASHES_PREFIX));
    assert!(validation_hashes_key.ends_with(b"validation"));
}
1528
// A bare file name must be accepted by `ensure_parent_dir`, and
// `coerce_store_path` must return such a path unchanged.
#[test]
fn ensure_parent_dir_allows_plain_file_names() {
    let bare_name = Path::new("split_store_local.bin");
    ensure_parent_dir(bare_name).unwrap();

    let actual = coerce_store_path(PathBuf::from("explicit_store.bin"));
    let expected = PathBuf::from("explicit_store.bin");
    assert_eq!(actual, expected);
}
1535
// Seeds a store at `source`, opens a second store that loads from `source`
// but saves to `dest`, writes new state through it, and then checks that the
// `source` file has the same size and still holds the seeded state.
#[test]
fn load_path_source_is_never_modified_while_open() {
    let tmp = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let source = tmp.path().join("source.bin");
    let dest = tmp.path().join("dest.bin");

    // Byte length of the source file at the moment of the call.
    let source_len = || std::fs::metadata(&source).unwrap().len();

    let seed_store = FileSplitStore::open(&source, ratios, 77).unwrap();
    let seeded_state = PersistedSamplerState {
        source_cycle_idx: 5,
        source_record_cursors: vec![("s".to_string(), 3)],
        epoch: 9,
        epoch_step: 0,
        rng_state: 42,
        triplet_recipe_rr_idx: 1,
        text_recipe_rr_idx: 2,
        source_stream_cursors: vec![("s".to_string(), 4)],
    };
    seed_store.save_sampler_state(&seeded_state, None).unwrap();
    drop(seed_store);

    let source_size_before = source_len();

    // Write different state through a store whose save path is `dest`.
    let overwrite_state = PersistedSamplerState {
        source_cycle_idx: 99,
        source_record_cursors: vec![("s".to_string(), 77)],
        epoch: 100,
        epoch_step: 0,
        rng_state: 0,
        triplet_recipe_rr_idx: 0,
        text_recipe_rr_idx: 0,
        source_stream_cursors: vec![],
    };
    let bootstrapped =
        FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 77).unwrap();
    bootstrapped
        .save_sampler_state(&overwrite_state, None)
        .unwrap();
    drop(bootstrapped);

    let source_size_after = source_len();
    assert_eq!(
        source_size_before, source_size_after,
        "source file was modified during bootstrapped open"
    );

    let reopened_source = FileSplitStore::open(&source, ratios, 77).unwrap();
    let source_state = reopened_source.load_sampler_state().unwrap().unwrap();
    assert_eq!(source_state.source_cycle_idx, 5);
    assert_eq!(source_state.epoch, 9);
}
1594
// Opens a store with separate load and save paths and checks that saving
// with `None` creates the save path and that the saved state can be read
// back from it.
#[test]
fn save_none_on_bootstrapped_store_publishes_to_save_path() {
    let tmp = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let source = tmp.path().join("load.bin");
    let dest = tmp.path().join("save.bin");

    // Open and immediately release a store at the load path.
    drop(FileSplitStore::open(&source, ratios, 11).unwrap());

    assert!(!dest.exists(), "dest must not exist before first save");

    let state_to_publish = PersistedSamplerState {
        source_cycle_idx: 7,
        source_record_cursors: vec![],
        epoch: 2,
        epoch_step: 0,
        rng_state: 1,
        triplet_recipe_rr_idx: 0,
        text_recipe_rr_idx: 0,
        source_stream_cursors: vec![],
    };
    let bootstrapped =
        FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 11).unwrap();
    bootstrapped
        .save_sampler_state(&state_to_publish, None)
        .unwrap();
    drop(bootstrapped);

    assert!(dest.exists(), "dest must exist after save(None)");

    let reopened_dest = FileSplitStore::open(&dest, ratios, 11).unwrap();
    let published_state = reopened_dest.load_sampler_state().unwrap().unwrap();
    assert_eq!(published_state.source_cycle_idx, 7);
}
1634
// Opens a store with separate load and save paths, saves to a third explicit
// path, and checks that only the explicit path is created and holds the state.
#[test]
fn save_some_on_bootstrapped_store_publishes_to_explicit_path_only() {
    let tmp = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let source = tmp.path().join("load.bin");
    let save = tmp.path().join("save.bin");
    let other = tmp.path().join("other.bin");

    // Open and immediately release a store at the load path.
    drop(FileSplitStore::open(&source, ratios, 22).unwrap());

    let explicit_state = PersistedSamplerState {
        source_cycle_idx: 3,
        source_record_cursors: vec![],
        epoch: 1,
        epoch_step: 0,
        rng_state: 0,
        triplet_recipe_rr_idx: 0,
        text_recipe_rr_idx: 0,
        source_stream_cursors: vec![],
    };
    let bootstrapped =
        FileSplitStore::open_with_load_path(Some(&source), &save, ratios, 22).unwrap();
    bootstrapped
        .save_sampler_state(&explicit_state, Some(other.as_path()))
        .unwrap();
    drop(bootstrapped);

    assert!(
        !save.exists(),
        "canonical save_path must not be created when saving to explicit path"
    );
    assert!(other.exists(), "explicit target must be created");

    let reopened_other = FileSplitStore::open(&other, ratios, 22).unwrap();
    let explicit_loaded = reopened_other.load_sampler_state().unwrap().unwrap();
    assert_eq!(explicit_loaded.source_cycle_idx, 3);
}
1679
// Saves three successive states with `None` through a store that has
// separate load and save paths; the save path must end up holding the last.
#[test]
fn repeated_save_none_on_bootstrapped_store_is_idempotent() {
    let tmp = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let source = tmp.path().join("load.bin");
    let dest = tmp.path().join("save.bin");

    // Open and immediately release a store at the load path.
    drop(FileSplitStore::open(&source, ratios, 33).unwrap());

    let bootstrapped =
        FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 33).unwrap();

    // Only the cycle index changes between saves.
    for cycle_idx in 1_u64..=3 {
        let nth_state = PersistedSamplerState {
            source_cycle_idx: cycle_idx,
            source_record_cursors: vec![],
            epoch: 0,
            epoch_step: 0,
            rng_state: 0,
            triplet_recipe_rr_idx: 0,
            text_recipe_rr_idx: 0,
            source_stream_cursors: vec![],
        };
        bootstrapped.save_sampler_state(&nth_state, None).unwrap();
    }
    drop(bootstrapped);

    let reopened_dest = FileSplitStore::open(&dest, ratios, 33).unwrap();
    let final_cycle_idx = reopened_dest
        .load_sampler_state()
        .unwrap()
        .unwrap()
        .source_cycle_idx;
    assert_eq!(final_cycle_idx, 3, "dest should hold the last-saved state");
}
1719}