1use serde::{Deserialize, Serialize};
2use simd_r_drive::storage_engine::DataStore;
3use simd_r_drive::storage_engine::traits::{DataStoreReader, DataStoreWriter};
4use std::collections::HashMap;
5use std::fmt;
6use std::fs;
7use std::hash::{Hash, Hasher};
8use std::io;
9use std::path::{Path, PathBuf};
10use std::sync::RwLock;
11use tempfile::TempDir;
12
13use crate::constants::splits::{
14 ALL_SPLITS, BITCODE_PREFIX, EPOCH_HASH_RECORD_VERSION, EPOCH_HASHES_PREFIX, EPOCH_META_PREFIX,
15 EPOCH_META_RECORD_VERSION, EPOCH_RECORD_TOMBSTONE, META_KEY, SAMPLER_STATE_KEY,
16 SAMPLER_STATE_RECORD_VERSION, SPLIT_PREFIX, STORE_VERSION,
17};
18use crate::data::RecordId;
19use crate::errors::SamplerError;
20use crate::types::SourceId;
21
/// Identifies which dataset split a record belongs to.
///
/// The label is `Copy` and hashable (this file uses it as a `HashMap` key for
/// per-split epoch state) and can be serialized through both `serde` and
/// `bitcode`.
#[derive(
    Clone,
    Copy,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    bitcode::Encode,
    bitcode::Decode,
)]
pub enum SplitLabel {
    /// Training split.
    Train,
    /// Validation split.
    Validation,
    /// Test split.
    Test,
}
43
/// Fractions of records routed to each split.
///
/// `SplitRatios::normalized` checks that the three values sum to `1.0`
/// (within `1e-6`); the stores in this file call it before accepting a ratio
/// set.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, bitcode::Encode, bitcode::Decode)]
pub struct SplitRatios {
    /// Fraction assigned to [`SplitLabel::Train`].
    pub train: f32,
    /// Fraction assigned to [`SplitLabel::Validation`].
    pub validation: f32,
    /// Fraction assigned to [`SplitLabel::Test`].
    pub test: f32,
}
54
55impl Default for SplitRatios {
56 fn default() -> Self {
57 Self {
58 train: 0.8,
59 validation: 0.1,
60 test: 0.1,
61 }
62 }
63}
64
65impl SplitRatios {
66 pub fn normalized(self) -> Result<Self, SamplerError> {
68 let sum = self.train + self.validation + self.test;
69 if (sum - 1.0).abs() > 1e-6 {
70 return Err(SamplerError::Configuration(
71 "split ratios must sum to 1.0".to_string(),
72 ));
73 }
74 Ok(self)
75 }
76}
77
78pub use crate::constants::splits::EPOCH_STATE_VERSION;
79
/// Per-split epoch metadata record, stored via `bitcode`.
///
/// One record is kept per [`SplitLabel`] by the `EpochStateStore`
/// implementations in this file.
#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
pub struct PersistedSplitMeta {
    /// Epoch counter for the split (semantics are defined by the caller).
    pub epoch: u64,
    /// Offset within the epoch — presumably a position in the hash list; verify against caller.
    pub offset: u64,
    /// Checksum of the associated hash list — presumably matches
    /// `PersistedSplitHashes::checksum`; verify against caller.
    pub hashes_checksum: u64,
}
90
/// Per-split list of hashes together with a checksum, stored via `bitcode`.
#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
pub struct PersistedSplitHashes {
    /// Checksum supplied by the caller; this file stores it verbatim and never recomputes it.
    pub checksum: u64,
    /// The persisted hash values, in the order supplied by the caller.
    pub hashes: Vec<u64>,
}
99
/// Snapshot of sampler progress, stored via `bitcode`.
///
/// The fields are written and read back verbatim by the `SamplerStateStore`
/// implementations in this file; their meaning is defined by the sampler that
/// produces them.
#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
pub struct PersistedSamplerState {
    /// Source cycling index — presumably the position in a round-robin over sources; TODO confirm.
    pub source_cycle_idx: u64,
    /// Per-source record cursor, as `(source id, cursor)` pairs.
    pub source_record_cursors: Vec<(SourceId, u64)>,
    /// Source epoch counter.
    pub source_epoch: u64,
    /// Saved random-number-generator state.
    pub rng_state: u64,
    /// Round-robin index for triplet recipes (per the field name; verify against the sampler).
    pub triplet_recipe_rr_idx: u64,
    /// Round-robin index for text recipes (per the field name; verify against the sampler).
    pub text_recipe_rr_idx: u64,
    /// Per-source stream cursor, as `(source id, cursor)` pairs.
    pub source_stream_cursors: Vec<(SourceId, u64)>,
}
118
/// Maps record ids to split labels.
///
/// Implementations must be shareable across threads (`Send + Sync`).
pub trait SplitStore: Send + Sync {
    /// Returns the label for `id`, or `None` when it cannot be determined.
    fn label_for(&self, id: &RecordId) -> Option<SplitLabel>;
    /// Records an explicit label for `id`.
    ///
    /// Of the implementations in this file, `DeterministicSplitStore` keeps
    /// the assignment in memory while `FileSplitStore` ignores it.
    fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError>;
    /// Returns the split ratios this store was configured with.
    fn ratios(&self) -> SplitRatios;
    /// Returns the label `id` should carry; both implementations in this file
    /// derive it from the id, the seed and the ratios.
    fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError>;
}
132
/// Persistence interface for per-split epoch metadata and hash lists.
///
/// Implementations must be shareable across threads (`Send + Sync`).
pub trait EpochStateStore: Send + Sync {
    /// Loads the stored metadata for every split that has an entry.
    fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError>;
    /// Loads the stored hash list for `label`, or `None` when nothing is stored.
    fn load_epoch_hashes(
        &self,
        label: SplitLabel,
    ) -> Result<Option<PersistedSplitHashes>, SamplerError>;
    /// Stores `meta` as the complete per-split metadata set.
    ///
    /// In both implementations in this file, splits absent from `meta` no
    /// longer have an entry afterwards.
    fn save_epoch_meta(
        &self,
        meta: &HashMap<SplitLabel, PersistedSplitMeta>,
    ) -> Result<(), SamplerError>;
    /// Stores `hashes` as the hash list for `label`, replacing any previous one.
    fn save_epoch_hashes(
        &self,
        label: SplitLabel,
        hashes: &PersistedSplitHashes,
    ) -> Result<(), SamplerError>;
}
154
/// Persistence interface for the sampler state snapshot.
///
/// Implementations must be shareable across threads (`Send + Sync`).
pub trait SamplerStateStore: Send + Sync {
    /// Loads the stored sampler state, or `None` when nothing has been saved.
    fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError>;
    /// Stores `state`.
    ///
    /// `save_path` optionally names an alternative destination. Of the
    /// implementations in this file, `FileSplitStore` publishes its store file
    /// there, while `DeterministicSplitStore` ignores the argument.
    fn save_sampler_state(
        &self,
        state: &PersistedSamplerState,
        save_path: Option<&Path>,
    ) -> Result<(), SamplerError>;
}
166
/// In-memory store that derives split labels from the record id, a seed and
/// the configured ratios.
///
/// Nothing is written to disk; epoch and sampler state live behind `RwLock`s
/// for the lifetime of the value.
pub struct DeterministicSplitStore {
    /// Ratios used when deriving labels.
    ratios: SplitRatios,
    /// Explicit assignments recorded through `upsert`; consulted by `label_for`.
    assignments: RwLock<HashMap<RecordId, SplitLabel>>,
    /// Seed mixed into the hash when deriving labels.
    seed: u64,
    /// In-memory per-split epoch metadata.
    epoch_meta: RwLock<HashMap<SplitLabel, PersistedSplitMeta>>,
    /// In-memory per-split hash lists.
    epoch_hashes: RwLock<HashMap<SplitLabel, PersistedSplitHashes>>,
    /// In-memory sampler state snapshot, `None` until first saved.
    sampler_state: RwLock<Option<PersistedSamplerState>>,
}
176
177impl DeterministicSplitStore {
178 pub fn new(ratios: SplitRatios, seed: u64) -> Result<Self, SamplerError> {
180 ratios.normalized()?;
181 Ok(Self {
182 ratios,
183 assignments: RwLock::new(HashMap::new()),
184 seed,
185 epoch_meta: RwLock::new(HashMap::new()),
186 epoch_hashes: RwLock::new(HashMap::new()),
187 sampler_state: RwLock::new(None),
188 })
189 }
190
191 fn derive_label(&self, id: &RecordId) -> SplitLabel {
192 derive_label_for_id(id, self.seed, self.ratios)
193 }
194}
195
196impl SplitStore for DeterministicSplitStore {
197 fn label_for(&self, id: &RecordId) -> Option<SplitLabel> {
198 if let Some(label) = self.assignments.read().ok()?.get(id).copied() {
199 return Some(label);
200 }
201 Some(self.derive_label(id))
202 }
203
204 fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError> {
205 let mut guard = self
206 .assignments
207 .write()
208 .map_err(|_| SamplerError::SplitStore("lock poisoned".into()))?;
209 guard.insert(id, label);
210 Ok(())
211 }
212
213 fn ratios(&self) -> SplitRatios {
214 self.ratios
215 }
216
217 fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError> {
218 Ok(self.derive_label(&id))
219 }
220}
221
222impl EpochStateStore for DeterministicSplitStore {
223 fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError> {
224 self.epoch_meta
225 .read()
226 .map_err(|_| SamplerError::SplitStore("epoch meta lock poisoned".into()))
227 .map(|guard| guard.clone())
228 }
229
230 fn load_epoch_hashes(
231 &self,
232 label: SplitLabel,
233 ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
234 Ok(self
235 .epoch_hashes
236 .read()
237 .map_err(|_| SamplerError::SplitStore("epoch hashes lock poisoned".into()))?
238 .get(&label)
239 .cloned())
240 }
241
242 fn save_epoch_meta(
243 &self,
244 meta: &HashMap<SplitLabel, PersistedSplitMeta>,
245 ) -> Result<(), SamplerError> {
246 *self
247 .epoch_meta
248 .write()
249 .map_err(|_| SamplerError::SplitStore("epoch meta lock poisoned".into()))? =
250 meta.clone();
251 Ok(())
252 }
253
254 fn save_epoch_hashes(
255 &self,
256 label: SplitLabel,
257 hashes: &PersistedSplitHashes,
258 ) -> Result<(), SamplerError> {
259 self.epoch_hashes
260 .write()
261 .map_err(|_| SamplerError::SplitStore("epoch hashes lock poisoned".into()))?
262 .insert(label, hashes.clone());
263 Ok(())
264 }
265}
266
267impl SamplerStateStore for DeterministicSplitStore {
268 fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError> {
269 self.sampler_state
270 .read()
271 .map_err(|_| SamplerError::SplitStore("sampler state lock poisoned".into()))
272 .map(|guard| guard.clone())
273 }
274
275 fn save_sampler_state(
276 &self,
277 state: &PersistedSamplerState,
278 _save_path: Option<&Path>,
279 ) -> Result<(), SamplerError> {
280 *self
281 .sampler_state
282 .write()
283 .map_err(|_| SamplerError::SplitStore("sampler state lock poisoned".into()))? =
284 Some(state.clone());
285 Ok(())
286 }
287}
288
/// Header record written under `META_KEY` describing how a file store was
/// created; `FileSplitStore::verify_metadata` compares it against the values
/// supplied when the store is reopened.
#[derive(Clone, Copy, Debug, bitcode::Encode, bitcode::Decode)]
struct StoreMeta {
    /// Store format version; compared against `STORE_VERSION`.
    version: u8,
    /// Seed the store was created with.
    seed: u64,
    /// Ratios the store was created with.
    ratios: SplitRatios,
}
296
297fn encode_store_meta(meta: &StoreMeta) -> Vec<u8> {
298 encode_bitcode_payload(&bitcode::encode(meta))
299}
300
301fn decode_store_meta(bytes: &[u8]) -> Result<StoreMeta, SamplerError> {
302 let raw = decode_bitcode_payload(bytes)?;
303 bitcode::decode(&raw).map_err(|err| {
304 SamplerError::SplitStore(format!("failed to decode split store metadata: {err}"))
305 })
306}
307
/// Split, epoch and sampler-state store backed by a `DataStore` file.
///
/// The store operates on a working copy inside a temporary directory;
/// `save_sampler_state` copies that working file to the publish destination.
pub struct FileSplitStore {
    /// Underlying key/value store, opened on the working copy at `path`.
    store: DataStore,
    /// Location of the working copy inside the temporary directory.
    path: PathBuf,
    /// Default destination the working copy is published to on save.
    save_path: PathBuf,
    /// Ratios this store was opened with (already validated).
    ratios: SplitRatios,
    /// Seed this store was opened with.
    seed: u64,
    /// Owns the temporary directory holding the working copy; kept so the
    /// directory lives as long as the store does.
    _temp_dir: TempDir,
}
334
335impl fmt::Debug for FileSplitStore {
336 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
337 f.debug_struct("FileSplitStore")
338 .field("path", &self.save_path)
339 .field("ratios", &self.ratios)
340 .field("seed", &self.seed)
341 .finish()
342 }
343}
344
345impl FileSplitStore {
346 pub fn open<P: Into<PathBuf>>(
348 path: P,
349 ratios: SplitRatios,
350 seed: u64,
351 ) -> Result<Self, SamplerError> {
352 Self::open_with_load_path(None::<PathBuf>, path, ratios, seed)
353 }
354
355 pub fn open_with_load_path<LP: Into<PathBuf>, SP: Into<PathBuf>>(
369 load_path: Option<LP>,
370 save_path: SP,
371 ratios: SplitRatios,
372 seed: u64,
373 ) -> Result<Self, SamplerError> {
374 let ratios = ratios.normalized()?;
375 let save_path = coerce_store_path(save_path.into());
376 ensure_parent_dir(&save_path)?;
377
378 let source: Option<PathBuf> = if let Some(lp) = load_path {
381 let lp = coerce_store_path(lp.into());
382 if lp.exists() { Some(lp) } else { None }
383 } else if save_path.exists() {
384 Some(save_path.clone())
385 } else {
386 None
387 };
388
389 let temp_dir = tempfile::tempdir().map_err(|err| {
390 SamplerError::SplitStore(format!("failed to create temp dir for split store: {err}"))
391 })?;
392 let working_path = temp_dir.path().join("working_store.bin");
393 if let Some(src) = &source {
394 fs::copy(src, &working_path).map_err(|err| {
395 SamplerError::SplitStore(format!(
396 "failed to copy split store from '{}' to temp: {err}",
397 src.display()
398 ))
399 })?;
400 }
401
402 let raw_store = DataStore::open(&working_path).map_err(map_store_err)?;
403 let store = Self {
404 store: raw_store,
405 path: working_path,
406 save_path,
407 ratios,
408 seed,
409 _temp_dir: temp_dir,
410 };
411 store.verify_metadata()?;
412 Ok(store)
413 }
414
415 fn verify_metadata(&self) -> Result<(), SamplerError> {
416 match read_bytes(&self.store, META_KEY)? {
417 Some(bytes) => {
418 let meta = decode_store_meta(&bytes)?;
419 if meta.version != STORE_VERSION {
420 return Err(SamplerError::SplitStore(format!(
421 "split store version mismatch (expected {}, found {})",
422 STORE_VERSION, meta.version
423 )));
424 }
425 if meta.seed != self.seed {
426 return Err(SamplerError::SplitStore(format!(
427 "split store seed mismatch (expected {}, found {})",
428 self.seed, meta.seed
429 )));
430 }
431 if !ratios_close(meta.ratios, self.ratios) {
432 return Err(SamplerError::SplitStore(
433 "split store ratios mismatch".into(),
434 ));
435 }
436 }
437 None => {
438 let blob = StoreMeta {
439 version: STORE_VERSION,
440 seed: self.seed,
441 ratios: self.ratios,
442 };
443 let payload = encode_store_meta(&blob);
444 write_bytes(&self.store, META_KEY, &payload)?;
445 }
446 }
447 Ok(())
448 }
449
450 fn read_epoch_meta_entry(
451 &self,
452 label: SplitLabel,
453 ) -> Result<Option<PersistedSplitMeta>, SamplerError> {
454 let key = epoch_meta_key(label);
455 let entry = self.store.read(&key).map_err(map_store_err)?;
456 match entry {
457 None => Ok(None),
458 Some(bytes) => decode_epoch_meta(bytes.as_ref()),
459 }
460 }
461
462 fn write_epoch_meta_entry(
463 &self,
464 label: SplitLabel,
465 meta: Option<&PersistedSplitMeta>,
466 ) -> Result<(), SamplerError> {
467 let key = epoch_meta_key(label);
468 let payload = encode_epoch_meta(meta);
469 self.store
470 .write(&key, payload.as_slice())
471 .map_err(map_store_err)?;
472 Ok(())
473 }
474
475 fn read_epoch_hashes_entry(
476 &self,
477 label: SplitLabel,
478 ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
479 let key = epoch_hashes_key(label);
480 let entry = self.store.read(&key).map_err(map_store_err)?;
481 match entry {
482 None => Ok(None),
483 Some(bytes) => decode_epoch_hashes(bytes.as_ref()),
484 }
485 }
486
487 fn write_epoch_hashes_entry(
488 &self,
489 label: SplitLabel,
490 hashes: &PersistedSplitHashes,
491 ) -> Result<(), SamplerError> {
492 let key = epoch_hashes_key(label);
493 let payload = encode_epoch_hashes(hashes);
494 self.store
495 .write(&key, payload.as_slice())
496 .map_err(map_store_err)?;
497 Ok(())
498 }
499}
500
501impl SplitStore for FileSplitStore {
502 fn label_for(&self, id: &RecordId) -> Option<SplitLabel> {
503 let key = split_key(id);
504 if let Ok(Some(value)) = self.store.read(&key)
505 && let Ok(label) = decode_label(value.as_ref())
506 {
507 return Some(label);
508 }
509 Some(derive_label_for_id(id, self.seed, self.ratios))
510 }
511
512 fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError> {
513 let _ = (id, label);
514 Ok(())
515 }
516
517 fn ratios(&self) -> SplitRatios {
518 self.ratios
519 }
520
521 fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError> {
522 Ok(derive_label_for_id(&id, self.seed, self.ratios))
523 }
524}
525
526impl EpochStateStore for FileSplitStore {
527 fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError> {
528 let mut meta = HashMap::new();
529 for label in ALL_SPLITS {
530 if let Some(entry) = self.read_epoch_meta_entry(label)? {
531 meta.insert(label, entry);
532 }
533 }
534 Ok(meta)
535 }
536
537 fn load_epoch_hashes(
538 &self,
539 label: SplitLabel,
540 ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
541 self.read_epoch_hashes_entry(label)
542 }
543
544 fn save_epoch_meta(
545 &self,
546 meta: &HashMap<SplitLabel, PersistedSplitMeta>,
547 ) -> Result<(), SamplerError> {
548 for label in ALL_SPLITS {
549 self.write_epoch_meta_entry(label, meta.get(&label))?;
550 }
551 Ok(())
552 }
553
554 fn save_epoch_hashes(
555 &self,
556 label: SplitLabel,
557 hashes: &PersistedSplitHashes,
558 ) -> Result<(), SamplerError> {
559 self.write_epoch_hashes_entry(label, hashes)
560 }
561}
562
563impl SamplerStateStore for FileSplitStore {
564 fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError> {
565 match read_bytes(&self.store, SAMPLER_STATE_KEY)? {
566 Some(bytes) => decode_sampler_state(bytes.as_ref()),
567 None => Ok(None),
568 }
569 }
570
571 fn save_sampler_state(
577 &self,
578 state: &PersistedSamplerState,
579 save_to: Option<&Path>,
580 ) -> Result<(), SamplerError> {
581 let dest = if let Some(p) = save_to {
583 coerce_store_path(p.to_path_buf())
584 } else {
585 self.save_path.clone()
586 };
587
588 if save_to.is_some() && dest.exists() {
591 return Err(SamplerError::SplitStore(format!(
592 "refusing to overwrite existing split store '{}'; choose a new path",
593 dest.display()
594 )));
595 }
596
597 let payload = encode_sampler_state(state);
599 write_bytes(&self.store, SAMPLER_STATE_KEY, &payload)?;
600
601 ensure_parent_dir(&dest)?;
603 fs::copy(&self.path, &dest).map_err(|err| {
604 SamplerError::SplitStore(format!(
605 "failed to publish split store to '{}': {err}",
606 dest.display()
607 ))
608 })?;
609
610 Ok(())
611 }
612}
613
614fn decode_label(bytes: &[u8]) -> Result<SplitLabel, SamplerError> {
615 match bytes.first() {
616 Some(b'0') => Ok(SplitLabel::Train),
617 Some(b'1') => Ok(SplitLabel::Validation),
618 Some(b'2') => Ok(SplitLabel::Test),
619 _ => Err(SamplerError::SplitStore("invalid split label".into())),
620 }
621}
622
623fn derive_label_for_id(id: &RecordId, seed: u64, ratios: SplitRatios) -> SplitLabel {
624 let mut hasher = std::collections::hash_map::DefaultHasher::new();
625 id.hash(&mut hasher);
626 seed.hash(&mut hasher);
627 let value = hasher.finish() as f64 / u64::MAX as f64;
628 let train_cut = ratios.train as f64;
629 let val_cut = train_cut + ratios.validation as f64;
630 if value < train_cut {
631 SplitLabel::Train
632 } else if value < val_cut {
633 SplitLabel::Validation
634 } else {
635 SplitLabel::Test
636 }
637}
638
639fn ratios_close(a: SplitRatios, b: SplitRatios) -> bool {
640 ((a.train - b.train).abs() + (a.validation - b.validation).abs() + (a.test - b.test).abs())
641 < 1e-5
642}
643
644fn split_key(id: &RecordId) -> Vec<u8> {
645 let mut key = Vec::with_capacity(SPLIT_PREFIX.len() + id.len());
646 key.extend_from_slice(SPLIT_PREFIX);
647 key.extend_from_slice(id.as_bytes());
648 key
649}
650
651fn read_bytes(store: &DataStore, key: &[u8]) -> Result<Option<Vec<u8>>, SamplerError> {
652 store
653 .read(key)
654 .map_err(map_store_err)?
655 .map(|entry| Ok(entry.as_ref().to_vec()))
656 .transpose()
657}
658
659fn write_bytes(store: &DataStore, key: &[u8], payload: &[u8]) -> Result<(), SamplerError> {
660 store.write(key, payload).map_err(map_store_err)?;
661 Ok(())
662}
663
664fn epoch_meta_key(label: SplitLabel) -> Vec<u8> {
665 let mut key = Vec::with_capacity(EPOCH_META_PREFIX.len() + 12);
666 key.extend_from_slice(EPOCH_META_PREFIX);
667 let suffix = match label {
668 SplitLabel::Train => b"train".as_ref(),
669 SplitLabel::Validation => b"validation".as_ref(),
670 SplitLabel::Test => b"test".as_ref(),
671 };
672 key.extend_from_slice(suffix);
673 key
674}
675
676fn epoch_hashes_key(label: SplitLabel) -> Vec<u8> {
677 let mut key = Vec::with_capacity(EPOCH_HASHES_PREFIX.len() + 12);
678 key.extend_from_slice(EPOCH_HASHES_PREFIX);
679 let suffix = match label {
680 SplitLabel::Train => b"train".as_ref(),
681 SplitLabel::Validation => b"validation".as_ref(),
682 SplitLabel::Test => b"test".as_ref(),
683 };
684 key.extend_from_slice(suffix);
685 key
686}
687
688fn encode_epoch_meta(meta: Option<&PersistedSplitMeta>) -> Vec<u8> {
689 match meta {
690 None => vec![EPOCH_RECORD_TOMBSTONE],
691 Some(meta) => {
692 let payload = encode_bitcode_payload(&bitcode::encode(meta));
693 let mut buf = Vec::with_capacity(1 + payload.len());
694 buf.push(EPOCH_META_RECORD_VERSION);
695 buf.extend_from_slice(&payload);
696 buf
697 }
698 }
699}
700
701fn decode_epoch_meta(bytes: &[u8]) -> Result<Option<PersistedSplitMeta>, SamplerError> {
702 if bytes.is_empty() || bytes[0] == EPOCH_RECORD_TOMBSTONE {
703 return Ok(None);
704 }
705 if bytes[0] != EPOCH_META_RECORD_VERSION {
706 return Err(SamplerError::SplitStore(
707 "epoch meta record version mismatch".into(),
708 ));
709 }
710 let raw = decode_bitcode_payload(&bytes[1..])?;
711 bitcode::decode(&raw)
712 .map(Some)
713 .map_err(|err| SamplerError::SplitStore(format!("corrupt epoch meta record: {err}")))
714}
715
716fn encode_epoch_hashes(hashes: &PersistedSplitHashes) -> Vec<u8> {
717 let payload = encode_bitcode_payload(&bitcode::encode(hashes));
718 let mut buf = Vec::with_capacity(1 + payload.len());
719 buf.push(EPOCH_HASH_RECORD_VERSION);
720 buf.extend_from_slice(&payload);
721 buf
722}
723
724fn decode_epoch_hashes(bytes: &[u8]) -> Result<Option<PersistedSplitHashes>, SamplerError> {
725 if bytes.is_empty() || bytes[0] == EPOCH_RECORD_TOMBSTONE {
726 return Ok(None);
727 }
728 if bytes[0] != EPOCH_HASH_RECORD_VERSION {
729 return Err(SamplerError::SplitStore(
730 "epoch hashes record version mismatch".into(),
731 ));
732 }
733 let raw = decode_bitcode_payload(&bytes[1..])?;
734 bitcode::decode(&raw)
735 .map(Some)
736 .map_err(|err| SamplerError::SplitStore(format!("corrupt epoch hashes record: {err}")))
737}
738
739fn encode_sampler_state(state: &PersistedSamplerState) -> Vec<u8> {
740 let payload = encode_bitcode_payload(&bitcode::encode(state));
741 let mut buf = Vec::with_capacity(1 + payload.len());
742 buf.push(SAMPLER_STATE_RECORD_VERSION);
743 buf.extend_from_slice(&payload);
744 buf
745}
746
747fn decode_sampler_state(bytes: &[u8]) -> Result<Option<PersistedSamplerState>, SamplerError> {
748 if bytes.is_empty() {
749 return Ok(None);
750 }
751 if bytes[0] != SAMPLER_STATE_RECORD_VERSION {
752 return Err(SamplerError::SplitStore(
753 "sampler state record version mismatch".into(),
754 ));
755 }
756 let raw = decode_bitcode_payload(&bytes[1..])?;
757 bitcode::decode(&raw)
758 .map(Some)
759 .map_err(|err| SamplerError::SplitStore(format!("corrupt sampler state record: {err}")))
760}
761
762fn encode_bitcode_payload(bytes: &[u8]) -> Vec<u8> {
763 let mut out = Vec::with_capacity(1 + bytes.len());
764 out.push(BITCODE_PREFIX);
765 out.extend_from_slice(bytes);
766 out
767}
768
769fn decode_bitcode_payload(bytes: &[u8]) -> Result<Vec<u8>, SamplerError> {
770 if bytes.first().copied() != Some(BITCODE_PREFIX) {
771 return Err(SamplerError::SplitStore(
772 "bitcode payload missing expected prefix".into(),
773 ));
774 }
775 Ok(bytes[1..].to_vec())
776}
777
/// Normalization hook for store paths.
///
/// Currently the identity function: the path is returned unchanged (directory
/// paths included). Every store path in this file is routed through it.
fn coerce_store_path(path: PathBuf) -> PathBuf {
    path
}
781
782fn ensure_parent_dir(path: &Path) -> Result<(), SamplerError> {
783 if let Some(parent) = path.parent()
784 && !parent.as_os_str().is_empty()
785 {
786 fs::create_dir_all(parent)?;
787 }
788 Ok(())
789}
790
791fn map_store_err(err: io::Error) -> SamplerError {
792 SamplerError::SplitStore(err.to_string())
793}
794
795#[cfg(test)]
796mod tests {
797 use super::*;
798 use std::collections::HashMap;
799 use tempfile::tempdir;
800
801 #[test]
802 fn split_ratios_reject_non_unit_sum() {
803 let invalid = SplitRatios {
804 train: 0.6,
805 validation: 0.3,
806 test: 0.3,
807 };
808
809 let err = DeterministicSplitStore::new(invalid, 1)
810 .err()
811 .expect("expected non-unit split ratios to fail");
812 assert!(matches!(
813 err,
814 SamplerError::Configuration(ref msg) if msg.contains("split ratios must sum to 1.0")
815 ));
816
817 let dir = tempdir().unwrap();
818 let path = dir.path().join("split_store.bin");
819 let err = FileSplitStore::open(&path, invalid, 1).unwrap_err();
820 assert!(matches!(
821 err,
822 SamplerError::Configuration(ref msg) if msg.contains("split ratios must sum to 1.0")
823 ));
824 }
825
826 #[test]
827 fn zero_test_ratio_never_assigns_test_labels() {
828 let ratios = SplitRatios {
829 train: 0.5,
830 validation: 0.5,
831 test: 0.0,
832 };
833 let store = DeterministicSplitStore::new(ratios, 7).unwrap();
834
835 let mut saw_train = false;
836 let mut saw_validation = false;
837 for idx in 0..20_000 {
838 let id = format!("record_{idx}");
839 let label = store.ensure(id).unwrap();
840 assert_ne!(label, SplitLabel::Test);
841 saw_train |= label == SplitLabel::Train;
842 saw_validation |= label == SplitLabel::Validation;
843 if saw_train && saw_validation {
844 break;
845 }
846 }
847
848 assert!(saw_train);
849 assert!(saw_validation);
850 }
851
852 #[test]
855 fn save_none_publishes_to_save_path_and_reloads_cleanly() {
856 let dir = tempdir().unwrap();
857 let path = dir.path().join("persist.bin");
858 let ratios = SplitRatios::default();
859
860 {
861 let store = FileSplitStore::open(&path, ratios, 123).unwrap();
862 let mut meta = HashMap::new();
863 meta.insert(
864 SplitLabel::Train,
865 PersistedSplitMeta {
866 epoch: 3,
867 offset: 7,
868 hashes_checksum: 42,
869 },
870 );
871 store.save_epoch_meta(&meta).unwrap();
872 let state = PersistedSamplerState {
873 source_cycle_idx: 11,
874 source_record_cursors: vec![("s".to_string(), 1)],
875 source_epoch: 5,
876 rng_state: 99,
877 triplet_recipe_rr_idx: 2,
878 text_recipe_rr_idx: 3,
879 source_stream_cursors: vec![],
880 };
881 assert!(!path.exists(), "save_path must not exist before save(None)");
882 store.save_sampler_state(&state, None).unwrap();
883 assert!(
884 path.exists(),
885 "save(None) must publish to save_path on disk"
886 );
887 }
888
889 let reopened = FileSplitStore::open(&path, ratios, 123).unwrap();
891 let loaded_state = reopened.load_sampler_state().unwrap().unwrap();
892 assert_eq!(loaded_state.source_cycle_idx, 11);
893 assert_eq!(loaded_state.source_epoch, 5);
894 assert_eq!(loaded_state.rng_state, 99);
895 let loaded_meta = reopened.load_epoch_meta().unwrap();
896 assert_eq!(loaded_meta.get(&SplitLabel::Train).unwrap().epoch, 3);
897 }
898
899 #[test]
900 fn file_store_rejects_seed_mismatch() {
901 let dir = tempdir().unwrap();
902 let path = dir.path().join("splits.json");
903 let ratios = SplitRatios::default();
904 let store = FileSplitStore::open(&path, ratios, 123).unwrap();
905 store.ensure("abc".to_string()).unwrap();
906 store
908 .save_sampler_state(
909 &PersistedSamplerState {
910 source_cycle_idx: 0,
911 source_record_cursors: vec![],
912 source_epoch: 0,
913 rng_state: 0,
914 triplet_recipe_rr_idx: 0,
915 text_recipe_rr_idx: 0,
916 source_stream_cursors: vec![],
917 },
918 None,
919 )
920 .unwrap();
921 drop(store);
922
923 let err = FileSplitStore::open(&path, ratios, 999).unwrap_err();
924 assert!(matches!(
925 err,
926 SamplerError::SplitStore(msg) if msg.contains("seed")
927 ));
928 }
929
930 #[test]
931 fn file_store_accepts_directory_path() {
932 let dir = tempdir().unwrap();
933 let ratios = SplitRatios::default();
934 let err = FileSplitStore::open(dir.path(), ratios, 777).unwrap_err();
935 assert!(matches!(err, SamplerError::SplitStore(_)));
936 }
937
938 #[test]
939 fn bitcode_payload_requires_prefix() {
940 let err = decode_bitcode_payload(&[0x00, 0x01]).unwrap_err();
941 assert!(
942 matches!(err, SamplerError::SplitStore(msg) if msg.contains("missing expected prefix"))
943 );
944 }
945
946 #[test]
947 fn file_store_round_trips_epoch_and_sampler_state() {
948 let dir = tempdir().unwrap();
949 let path = dir.path().join("epoch_sampler_state.bin");
950 let store = FileSplitStore::open(&path, SplitRatios::default(), 222).unwrap();
951
952 assert!(store.load_epoch_hashes(SplitLabel::Test).unwrap().is_none());
953
954 let mut epoch_meta = HashMap::new();
955 epoch_meta.insert(
956 SplitLabel::Train,
957 PersistedSplitMeta {
958 epoch: 3,
959 offset: 7,
960 hashes_checksum: 42,
961 },
962 );
963 store.save_epoch_meta(&epoch_meta).unwrap();
964
965 let loaded_meta = store.load_epoch_meta().unwrap();
966 let loaded_train = loaded_meta.get(&SplitLabel::Train).unwrap();
967 assert_eq!(loaded_train.epoch, 3);
968 assert_eq!(loaded_train.offset, 7);
969 assert_eq!(loaded_train.hashes_checksum, 42);
970
971 let hashes = PersistedSplitHashes {
972 checksum: 99,
973 hashes: vec![10, 20, 30],
974 };
975 store
976 .save_epoch_hashes(SplitLabel::Validation, &hashes)
977 .unwrap();
978 let loaded_hashes = store
979 .load_epoch_hashes(SplitLabel::Validation)
980 .unwrap()
981 .unwrap();
982 assert_eq!(loaded_hashes.checksum, 99);
983 assert_eq!(loaded_hashes.hashes, vec![10, 20, 30]);
984
985 let state = PersistedSamplerState {
986 source_cycle_idx: 11,
987 source_record_cursors: vec![("source_a".to_string(), 1)],
988 source_epoch: 8,
989 rng_state: 1234,
990 triplet_recipe_rr_idx: 2,
991 text_recipe_rr_idx: 5,
992 source_stream_cursors: vec![("source_a".to_string(), 9)],
993 };
994 store.save_sampler_state(&state, None).unwrap();
995 let loaded_state = store.load_sampler_state().unwrap().unwrap();
996 assert_eq!(loaded_state.source_cycle_idx, 11);
997 assert_eq!(loaded_state.source_epoch, 8);
998 assert_eq!(loaded_state.rng_state, 1234);
999 assert_eq!(loaded_state.triplet_recipe_rr_idx, 2);
1000 assert_eq!(loaded_state.text_recipe_rr_idx, 5);
1001 assert_eq!(
1002 loaded_state.source_record_cursors,
1003 vec![("source_a".to_string(), 1)]
1004 );
1005 assert_eq!(
1006 loaded_state.source_stream_cursors,
1007 vec![("source_a".to_string(), 9)]
1008 );
1009
1010 drop(store);
1012 let reopened = FileSplitStore::open(&path, SplitRatios::default(), 222).unwrap();
1013 let disk_state = reopened.load_sampler_state().unwrap().unwrap();
1014 assert_eq!(disk_state.source_cycle_idx, 11);
1015 assert_eq!(disk_state.source_epoch, 8);
1016 assert_eq!(disk_state.rng_state, 1234);
1017 let disk_meta = reopened.load_epoch_meta().unwrap();
1018 assert_eq!(disk_meta.get(&SplitLabel::Train).unwrap().epoch, 3);
1019 let disk_hashes = reopened
1020 .load_epoch_hashes(SplitLabel::Validation)
1021 .unwrap()
1022 .unwrap();
1023 assert_eq!(disk_hashes.checksum, 99);
1024 }
1025
1026 #[test]
1027 fn split_keys_and_labels_cover_helper_paths() {
1028 let key = split_key(&"abc".to_string());
1029 assert!(key.starts_with(SPLIT_PREFIX));
1030
1031 assert!(matches!(decode_label(b"0"), Ok(SplitLabel::Train)));
1032 assert!(matches!(decode_label(b"1"), Ok(SplitLabel::Validation)));
1033 assert!(matches!(decode_label(b"2"), Ok(SplitLabel::Test)));
1034 assert!(decode_label(b"x").is_err());
1035
1036 let epoch_meta_train = epoch_meta_key(SplitLabel::Train);
1037 let epoch_hashes_train = epoch_hashes_key(SplitLabel::Train);
1038 let epoch_hashes_test = epoch_hashes_key(SplitLabel::Test);
1039 assert!(epoch_meta_train.starts_with(EPOCH_META_PREFIX));
1040 assert!(epoch_hashes_train.starts_with(EPOCH_HASHES_PREFIX));
1041 assert!(epoch_hashes_test.starts_with(EPOCH_HASHES_PREFIX));
1042 }
1043
1044 #[test]
1045 fn encode_decode_store_meta_roundtrip_and_corrupt_prefix_error() {
1046 let meta = StoreMeta {
1047 version: STORE_VERSION,
1048 seed: 55,
1049 ratios: SplitRatios::default(),
1050 };
1051 let encoded = encode_store_meta(&meta);
1052 let decoded = decode_store_meta(&encoded).unwrap();
1053 assert_eq!(decoded.version, STORE_VERSION);
1054 assert_eq!(decoded.seed, 55);
1055
1056 let err = decode_store_meta(&[0x00, 0x01]).unwrap_err();
1057 assert!(matches!(
1058 err,
1059 SamplerError::SplitStore(msg) if msg.contains("missing expected prefix")
1060 ));
1061 }
1062
1063 #[test]
1064 fn epoch_and_sampler_decoders_cover_tombstone_and_version_mismatch() {
1065 assert!(decode_epoch_meta(&[]).unwrap().is_none());
1066 assert!(
1067 decode_epoch_meta(&[EPOCH_RECORD_TOMBSTONE])
1068 .unwrap()
1069 .is_none()
1070 );
1071 assert!(decode_epoch_hashes(&[]).unwrap().is_none());
1072 assert!(
1073 decode_epoch_hashes(&[EPOCH_RECORD_TOMBSTONE])
1074 .unwrap()
1075 .is_none()
1076 );
1077 assert!(decode_sampler_state(&[]).unwrap().is_none());
1078
1079 let meta_mismatch = decode_epoch_meta(&[EPOCH_META_RECORD_VERSION.wrapping_add(1), 1]);
1080 assert!(matches!(
1081 meta_mismatch,
1082 Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1083 ));
1084 let hashes_mismatch = decode_epoch_hashes(&[EPOCH_HASH_RECORD_VERSION.wrapping_add(1), 1]);
1085 assert!(matches!(
1086 hashes_mismatch,
1087 Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1088 ));
1089 let state_mismatch =
1090 decode_sampler_state(&[SAMPLER_STATE_RECORD_VERSION.wrapping_add(1), 1]);
1091 assert!(matches!(
1092 state_mismatch,
1093 Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1094 ));
1095 }
1096
1097 #[test]
1098 fn split_store_trait_methods_and_path_helpers_are_exercised() {
1099 let dir = tempdir().unwrap();
1100 let file_path = dir.path().join("nested").join("store.bin");
1101 ensure_parent_dir(&file_path).unwrap();
1102 assert!(file_path.parent().unwrap().exists());
1103
1104 let existing_dir_path = coerce_store_path(dir.path().to_path_buf());
1105 assert_eq!(existing_dir_path, dir.path().to_path_buf());
1106
1107 let ratios = SplitRatios::default();
1108 let store = FileSplitStore::open(&file_path, ratios, 444).unwrap();
1109 assert!((store.ratios().train - ratios.train).abs() < 1e-6);
1110 store
1111 .upsert("record_1".to_string(), SplitLabel::Validation)
1112 .unwrap();
1113 let ensured = store.ensure("record_1".to_string()).unwrap();
1114 assert!(matches!(
1115 ensured,
1116 SplitLabel::Train | SplitLabel::Validation | SplitLabel::Test
1117 ));
1118
1119 let mapped = map_store_err(io::Error::other("boom"));
1120 assert!(matches!(mapped, SamplerError::SplitStore(msg) if msg.contains("boom")));
1121 }
1122
1123 #[test]
1124 fn epoch_and_sampler_encode_decode_roundtrips() {
1125 let meta = PersistedSplitMeta {
1126 epoch: 4,
1127 offset: 9,
1128 hashes_checksum: 21,
1129 };
1130 let encoded_meta = encode_epoch_meta(Some(&meta));
1131 let decoded_meta = decode_epoch_meta(&encoded_meta).unwrap().unwrap();
1132 assert_eq!(decoded_meta.epoch, 4);
1133 assert_eq!(decoded_meta.offset, 9);
1134
1135 let hashes = PersistedSplitHashes {
1136 checksum: 7,
1137 hashes: vec![1, 2, 3],
1138 };
1139 let encoded_hashes = encode_epoch_hashes(&hashes);
1140 let decoded_hashes = decode_epoch_hashes(&encoded_hashes).unwrap().unwrap();
1141 assert_eq!(decoded_hashes.checksum, 7);
1142 assert_eq!(decoded_hashes.hashes, vec![1, 2, 3]);
1143
1144 let state = PersistedSamplerState {
1145 source_cycle_idx: 1,
1146 source_record_cursors: vec![("s".to_string(), 2)],
1147 source_epoch: 3,
1148 rng_state: 4,
1149 triplet_recipe_rr_idx: 5,
1150 text_recipe_rr_idx: 6,
1151 source_stream_cursors: vec![("s".to_string(), 7)],
1152 };
1153 let encoded_state = encode_sampler_state(&state);
1154 let decoded_state = decode_sampler_state(&encoded_state).unwrap().unwrap();
1155 assert_eq!(decoded_state.source_cycle_idx, 1);
1156 assert_eq!(decoded_state.source_epoch, 3);
1157 assert_eq!(decoded_state.rng_state, 4);
1158 }
1159
#[test]
/// Walks a `DeterministicSplitStore` through ratios, label lookup/override,
/// epoch meta, epoch hashes, and sampler-state save/load.
fn deterministic_store_trait_methods_work() {
    let ratios = SplitRatios::default();
    let store = DeterministicSplitStore::new(ratios, 999).unwrap();
    assert_eq!(store.ratios().train, ratios.train);

    // A label is available before any explicit assignment; an upsert then
    // overrides it.
    let record_id = "source::record".to_string();
    let derived = store.label_for(&record_id).unwrap();
    assert!(matches!(
        derived,
        SplitLabel::Train | SplitLabel::Validation | SplitLabel::Test
    ));
    store
        .upsert(record_id.clone(), SplitLabel::Validation)
        .unwrap();
    assert_eq!(store.label_for(&record_id), Some(SplitLabel::Validation));

    // Epoch meta saved for one split is readable afterwards.
    let epoch_meta = HashMap::from([(
        SplitLabel::Test,
        PersistedSplitMeta {
            epoch: 1,
            offset: 2,
            hashes_checksum: 3,
        },
    )]);
    store.save_epoch_meta(&epoch_meta).unwrap();
    let reloaded_meta = store.load_epoch_meta().unwrap();
    assert_eq!(reloaded_meta.get(&SplitLabel::Test).unwrap().offset, 2);

    // Epoch hashes start absent and appear after a save.
    let hashes_before = store.load_epoch_hashes(SplitLabel::Train).unwrap();
    assert!(hashes_before.is_none());
    let train_hashes = PersistedSplitHashes {
        checksum: 11,
        hashes: vec![4, 5],
    };
    store
        .save_epoch_hashes(SplitLabel::Train, &train_hashes)
        .unwrap();
    let hashes_after = store
        .load_epoch_hashes(SplitLabel::Train)
        .unwrap()
        .unwrap();
    assert_eq!(hashes_after.checksum, 11);

    // Sampler state starts absent and appears after a save.
    assert!(store.load_sampler_state().unwrap().is_none());
    let sampler_state = PersistedSamplerState {
        source_cycle_idx: 1,
        source_record_cursors: vec![("s1".to_string(), 2)],
        source_epoch: 3,
        rng_state: 4,
        triplet_recipe_rr_idx: 5,
        text_recipe_rr_idx: 6,
        source_stream_cursors: vec![("s1".to_string(), 7)],
    };
    store.save_sampler_state(&sampler_state, None).unwrap();
    let reloaded_state = store.load_sampler_state().unwrap().unwrap();
    assert_eq!(reloaded_state.source_epoch, sampler_state.source_epoch);
}
1229
#[test]
/// Saves epoch meta and sampler state into snapshot A, then opens a second
/// store with A as the load path and B as the save path. The bootstrapped
/// store must see A's state, and reopening A afterwards must still return it.
fn open_with_load_path_bootstraps_state_explicitly() {
    let dir = tempdir().unwrap();
    let path_a = dir.path().join("snapshot_a.bin");
    let path_b = dir.path().join("snapshot_b.bin");
    let ratios = SplitRatios::default();

    let store_a = FileSplitStore::open(&path_a, ratios, 42).unwrap();

    let mut meta = HashMap::new();
    meta.insert(
        SplitLabel::Train,
        PersistedSplitMeta {
            epoch: 5,
            offset: 3,
            hashes_checksum: 999,
        },
    );
    store_a.save_epoch_meta(&meta).unwrap();

    let sampler_state = PersistedSamplerState {
        source_cycle_idx: 1,
        source_record_cursors: vec![("s1".to_string(), 2)],
        source_epoch: 7,
        rng_state: 123,
        triplet_recipe_rr_idx: 4,
        text_recipe_rr_idx: 6,
        source_stream_cursors: vec![("s1".to_string(), 8)],
    };
    store_a.save_sampler_state(&sampler_state, None).unwrap();
    drop(store_a);

    // Borrow the load path rather than cloning it: `path_a` is reused below,
    // and the other bootstrapping tests in this module pass a borrowed path.
    let store_b =
        FileSplitStore::open_with_load_path(Some(&path_a), &path_b, ratios, 42).unwrap();
    assert_eq!(
        store_b
            .load_epoch_meta()
            .unwrap()
            .get(&SplitLabel::Train)
            .unwrap()
            .epoch,
        5
    );
    assert_eq!(
        store_b.load_sampler_state().unwrap().unwrap().source_epoch,
        7
    );

    // The load-path snapshot must still be readable with its original state.
    let store_a_again = FileSplitStore::open(&path_a, ratios, 42).unwrap();
    assert_eq!(
        store_a_again
            .load_epoch_meta()
            .unwrap()
            .get(&SplitLabel::Train)
            .unwrap()
            .epoch,
        5
    );
    assert_eq!(
        store_a_again
            .load_sampler_state()
            .unwrap()
            .unwrap()
            .source_epoch,
        7
    );
}
1297
#[test]
/// Saving sampler state to an explicit target must carry the store's existing
/// split assignment into that target and must not create the canonical path.
fn save_sampler_state_to_new_path_copies_existing_store_first() {
    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let path_a = dir.path().join("source_store.bin");
    let path_b = dir.path().join("mirror_store.bin");

    let store_a = FileSplitStore::open(&path_a, ratios, 42).unwrap();

    // Write a raw assignment straight into the backing store; the assertion
    // below expects the byte `1` to read back as `Validation`.
    let assigned_id = "record_with_assignment".to_string();
    store_a
        .store
        .write(&split_key(&assigned_id), b"1")
        .unwrap();

    let snapshot = PersistedSamplerState {
        source_cycle_idx: 1,
        source_record_cursors: vec![("s1".to_string(), 2)],
        source_epoch: 9,
        rng_state: 123,
        triplet_recipe_rr_idx: 4,
        text_recipe_rr_idx: 6,
        source_stream_cursors: vec![("s1".to_string(), 8)],
    };
    store_a
        .save_sampler_state(&snapshot, Some(path_b.as_path()))
        .unwrap();

    // The mirror holds both the assignment and the sampler state.
    let store_b = FileSplitStore::open(&path_b, ratios, 42).unwrap();
    let mirrored_label = store_b.label_for(&assigned_id);
    assert_eq!(mirrored_label, Some(SplitLabel::Validation));
    let mirrored_state = store_b.load_sampler_state().unwrap().unwrap();
    assert_eq!(mirrored_state.source_epoch, 9);

    let canonical_exists = path_a.exists();
    assert!(
        !canonical_exists,
        "save_to=Some(...) must not publish to the canonical save_path"
    );
}
1342
#[test]
/// After a `save(None)` followed by a `save(Some(path_b))` with different
/// state, the file at the canonical path must still hold the first state
/// while the checkpoint file holds the second.
fn save_some_on_regular_open_does_not_modify_working_store() {
    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let path_a = dir.path().join("working_store.bin");
    let path_b = dir.path().join("checkpoint_store.bin");

    let store_a = FileSplitStore::open(&path_a, ratios, 42).unwrap();

    // First state goes to the canonical path.
    let initial_state = PersistedSamplerState {
        source_cycle_idx: 1,
        source_record_cursors: vec![],
        source_epoch: 1,
        rng_state: 0,
        triplet_recipe_rr_idx: 0,
        text_recipe_rr_idx: 0,
        source_stream_cursors: vec![],
    };
    store_a.save_sampler_state(&initial_state, None).unwrap();

    // Second, clearly different state goes to the explicit checkpoint path.
    let checkpoint_state = PersistedSamplerState {
        source_cycle_idx: 99,
        source_record_cursors: vec![],
        source_epoch: 99,
        rng_state: 42,
        triplet_recipe_rr_idx: 0,
        text_recipe_rr_idx: 0,
        source_stream_cursors: vec![],
    };
    store_a
        .save_sampler_state(&checkpoint_state, Some(path_b.as_path()))
        .unwrap();
    drop(store_a);

    let store_a_disk = FileSplitStore::open(&path_a, ratios, 42).unwrap();
    let state_from_a = store_a_disk.load_sampler_state().unwrap().unwrap();
    assert_eq!(
        state_from_a.source_epoch, 1,
        "save_to=Some(...) must not overwrite the on-disk save_path"
    );

    let store_b = FileSplitStore::open(&path_b, ratios, 42).unwrap();
    let state_from_b = store_b.load_sampler_state().unwrap().unwrap();
    assert_eq!(
        state_from_b.source_epoch, 99,
        "checkpoint store must hold the snapshotted state"
    );
}
1402
#[test]
/// Covers the `Debug` output of `FileSplitStore` and the two reopen failures
/// asserted here: a stored version that differs from `STORE_VERSION`, and
/// ratios that differ from the ones the store was saved with.
fn file_store_metadata_mismatch_and_debug_paths_are_covered() {
    // Both scenarios save the same all-zero sampler state before reopening.
    let blank_state = || PersistedSamplerState {
        source_cycle_idx: 0,
        source_record_cursors: vec![],
        source_epoch: 0,
        rng_state: 0,
        triplet_recipe_rr_idx: 0,
        text_recipe_rr_idx: 0,
        source_stream_cursors: vec![],
    };

    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();

    // --- Version mismatch -------------------------------------------------
    let path = dir.path().join("meta_mismatch.bin");
    let store = FileSplitStore::open(&path, ratios, 123).unwrap();

    let debug_repr = format!("{store:?}");
    assert!(debug_repr.contains("FileSplitStore"));

    // Overwrite the metadata record with a bumped version, then save.
    let payload = encode_store_meta(&StoreMeta {
        version: STORE_VERSION.wrapping_add(1),
        seed: 123,
        ratios,
    });
    store.store.write(META_KEY, &payload).unwrap();
    store.save_sampler_state(&blank_state(), None).unwrap();
    drop(store);

    let version_err = FileSplitStore::open(&path, ratios, 123).unwrap_err();
    assert!(
        matches!(version_err, SamplerError::SplitStore(msg) if msg.contains("version mismatch"))
    );

    // --- Ratio mismatch ---------------------------------------------------
    let ratio_path = dir.path().join("ratio_mismatch.bin");
    let baseline = FileSplitStore::open(&ratio_path, ratios, 777).unwrap();
    baseline.save_sampler_state(&blank_state(), None).unwrap();
    drop(baseline);

    let different_ratios = SplitRatios {
        train: 0.7,
        validation: 0.2,
        test: 0.1,
    };
    let ratio_err = FileSplitStore::open(&ratio_path, different_ratios, 777).unwrap_err();
    assert!(matches!(ratio_err, SamplerError::SplitStore(msg) if msg.contains("ratios mismatch")));
}
1467
#[test]
/// Feeds each decode helper a payload with a valid leading marker followed by
/// junk bytes and checks that it returns the expected `SplitStore` error text.
fn split_decode_helpers_reject_corrupt_bitcode_payloads() {
    // Store metadata.
    let bad_store_meta = [BITCODE_PREFIX, 0xFF, 0xEE];
    let store_meta_err = decode_store_meta(&bad_store_meta).unwrap_err();
    assert!(matches!(
        store_meta_err,
        SamplerError::SplitStore(msg) if msg.contains("failed to decode split store metadata")
    ));

    // Epoch meta record.
    let bad_epoch_meta = [EPOCH_META_RECORD_VERSION, BITCODE_PREFIX, 0xFF];
    let epoch_meta_err = decode_epoch_meta(&bad_epoch_meta).unwrap_err();
    assert!(matches!(
        epoch_meta_err,
        SamplerError::SplitStore(msg) if msg.contains("corrupt epoch meta record")
    ));

    // Epoch hashes record.
    let bad_epoch_hashes = [EPOCH_HASH_RECORD_VERSION, BITCODE_PREFIX, 0xFF];
    let epoch_hashes_err = decode_epoch_hashes(&bad_epoch_hashes).unwrap_err();
    assert!(matches!(
        epoch_hashes_err,
        SamplerError::SplitStore(msg) if msg.contains("corrupt epoch hashes record")
    ));

    // Sampler state record.
    let bad_sampler_state = [SAMPLER_STATE_RECORD_VERSION, BITCODE_PREFIX, 0xFF];
    let sampler_state_err = decode_sampler_state(&bad_sampler_state).unwrap_err();
    assert!(matches!(
        sampler_state_err,
        SamplerError::SplitStore(msg) if msg.contains("corrupt sampler state record")
    ));
}
1497
#[test]
/// Checks that an unrecognised stored label value yields the same label as
/// `derive_label_for_id`, that the value `1` reads back as `Validation`, and
/// that the validation epoch keys carry the expected prefix and suffix.
fn file_store_label_fallback_and_validation_keys_are_covered() {
    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let store = FileSplitStore::open(&dir.path().join("labels.bin"), ratios, 42).unwrap();

    let id = "bad_label_record".to_string();
    let key = split_key(&id);
    let expected = derive_label_for_id(&id, 42, ratios);

    // Unrecognised payload: lookup is expected to match the derived label.
    store.store.write(&key, b"x").unwrap();
    let fallback = store.label_for(&id);
    assert_eq!(fallback, Some(expected));

    // Recognised payload: the stored value wins.
    store.store.write(&key, b"1").unwrap();
    let stored = store.label_for(&id);
    assert_eq!(stored, Some(SplitLabel::Validation));

    // Key layout for the validation split.
    let meta_validation = epoch_meta_key(SplitLabel::Validation);
    assert!(meta_validation.starts_with(EPOCH_META_PREFIX));
    assert!(meta_validation.ends_with(b"validation"));

    let hashes_validation = epoch_hashes_key(SplitLabel::Validation);
    assert!(hashes_validation.starts_with(EPOCH_HASHES_PREFIX));
    assert!(hashes_validation.ends_with(b"validation"));
}
1521
#[test]
/// A bare file name is accepted by `ensure_parent_dir`, and
/// `coerce_store_path` returns an explicit `.bin` file name unchanged.
fn ensure_parent_dir_allows_plain_file_names() {
    let bare_name = Path::new("split_store_local.bin");
    ensure_parent_dir(bare_name).unwrap();

    let explicit_name = "explicit_store.bin";
    let coerced = coerce_store_path(PathBuf::from(explicit_name));
    assert_eq!(coerced, PathBuf::from(explicit_name));
}
1528
#[test]
/// Saves state through a store bootstrapped from `source` and checks that the
/// source file keeps its byte length and still returns its original state.
fn load_path_source_is_never_modified_while_open() {
    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let source = dir.path().join("source.bin");
    let dest = dir.path().join("dest.bin");

    // Seed the source snapshot with a recognisable state.
    let original_state = PersistedSamplerState {
        source_cycle_idx: 5,
        source_record_cursors: vec![("s".to_string(), 3)],
        source_epoch: 9,
        rng_state: 42,
        triplet_recipe_rr_idx: 1,
        text_recipe_rr_idx: 2,
        source_stream_cursors: vec![("s".to_string(), 4)],
    };
    let seeded = FileSplitStore::open(&source, ratios, 77).unwrap();
    seeded.save_sampler_state(&original_state, None).unwrap();
    drop(seeded);

    let source_size_before = std::fs::metadata(&source).unwrap().len();

    // Bootstrap from the source and save a different state through the new store.
    let replacement_state = PersistedSamplerState {
        source_cycle_idx: 99,
        source_record_cursors: vec![("s".to_string(), 77)],
        source_epoch: 100,
        rng_state: 0,
        triplet_recipe_rr_idx: 0,
        text_recipe_rr_idx: 0,
        source_stream_cursors: vec![],
    };
    let bootstrapped =
        FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 77).unwrap();
    bootstrapped
        .save_sampler_state(&replacement_state, None)
        .unwrap();
    drop(bootstrapped);

    let source_size_after = std::fs::metadata(&source).unwrap().len();
    assert_eq!(
        source_size_before, source_size_after,
        "source file was modified during bootstrapped open"
    );

    // Size alone could miss a same-length rewrite, so read the state back too.
    let verify_source = FileSplitStore::open(&source, ratios, 77).unwrap();
    let loaded = verify_source.load_sampler_state().unwrap().unwrap();
    assert_eq!(loaded.source_cycle_idx, 5);
    assert_eq!(loaded.source_epoch, 9);
}
1585
#[test]
/// On a store opened with a separate load path, `save_sampler_state(.., None)`
/// must create the save path and the saved state must be readable from it.
fn save_none_on_bootstrapped_store_publishes_to_save_path() {
    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let source = dir.path().join("load.bin");
    let dest = dir.path().join("save.bin");

    // Open and immediately release a store at the load path.
    drop(FileSplitStore::open(&source, ratios, 11).unwrap());

    assert!(!dest.exists(), "dest must not exist before first save");

    let store = FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 11).unwrap();
    let saved_state = PersistedSamplerState {
        source_cycle_idx: 7,
        source_record_cursors: vec![],
        source_epoch: 2,
        rng_state: 1,
        triplet_recipe_rr_idx: 0,
        text_recipe_rr_idx: 0,
        source_stream_cursors: vec![],
    };
    store.save_sampler_state(&saved_state, None).unwrap();
    drop(store);

    assert!(dest.exists(), "dest must exist after save(None)");

    let loaded_dest = FileSplitStore::open(&dest, ratios, 11).unwrap();
    let reloaded = loaded_dest.load_sampler_state().unwrap().unwrap();
    assert_eq!(reloaded.source_cycle_idx, 7);
}
1624
#[test]
/// On a store opened with a separate load path, saving to an explicit target
/// must create only that target and leave the canonical save path absent.
fn save_some_on_bootstrapped_store_publishes_to_explicit_path_only() {
    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let source = dir.path().join("load.bin");
    // Canonical save path given at open time.
    let save = dir.path().join("save.bin");
    // Explicit target passed to `save_sampler_state`.
    let other = dir.path().join("other.bin");

    // Open and immediately release a store at the load path.
    drop(FileSplitStore::open(&source, ratios, 22).unwrap());

    let store = FileSplitStore::open_with_load_path(Some(&source), &save, ratios, 22).unwrap();
    let saved_state = PersistedSamplerState {
        source_cycle_idx: 3,
        source_record_cursors: vec![],
        source_epoch: 1,
        rng_state: 0,
        triplet_recipe_rr_idx: 0,
        text_recipe_rr_idx: 0,
        source_stream_cursors: vec![],
    };
    store
        .save_sampler_state(&saved_state, Some(other.as_path()))
        .unwrap();
    drop(store);

    let canonical_exists = save.exists();
    assert!(
        !canonical_exists,
        "canonical save_path must not be created when saving to explicit path"
    );
    assert!(other.exists(), "explicit target must be created");

    let loaded_other = FileSplitStore::open(&other, ratios, 22).unwrap();
    let reloaded = loaded_other.load_sampler_state().unwrap().unwrap();
    assert_eq!(reloaded.source_cycle_idx, 3);
}
1668
#[test]
/// Calls `save_sampler_state(.., None)` three times on a bootstrapped store
/// and checks that the save path ends up holding the last state written.
fn repeated_save_none_on_bootstrapped_store_is_idempotent() {
    let dir = tempdir().unwrap();
    let ratios = SplitRatios::default();
    let source = dir.path().join("load.bin");
    let dest = dir.path().join("save.bin");

    // Open and immediately release a store at the load path.
    drop(FileSplitStore::open(&source, ratios, 33).unwrap());

    let store = FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 33).unwrap();

    // Only the cycle index changes between the successive saves.
    for cycle_idx in 1_u64..=3 {
        let nth_state = PersistedSamplerState {
            source_cycle_idx: cycle_idx,
            source_record_cursors: vec![],
            source_epoch: 0,
            rng_state: 0,
            triplet_recipe_rr_idx: 0,
            text_recipe_rr_idx: 0,
            source_stream_cursors: vec![],
        };
        store.save_sampler_state(&nth_state, None).unwrap();
    }
    drop(store);

    let reloaded = FileSplitStore::open(&dest, ratios, 33).unwrap();
    let final_state = reloaded.load_sampler_state().unwrap().unwrap();
    assert_eq!(
        final_state.source_cycle_idx, 3,
        "dest should hold the last-saved state"
    );
}
1707}