Skip to main content

triplets_core/
splits.rs

1use serde::{Deserialize, Serialize};
2use simd_r_drive::storage_engine::DataStore;
3use simd_r_drive::storage_engine::traits::{DataStoreReader, DataStoreWriter};
4use siphasher::sip::SipHasher;
5use std::collections::HashMap;
6use std::fmt;
7use std::fs;
8use std::hash::{Hash, Hasher};
9use std::io;
10use std::path::{Path, PathBuf};
11use std::sync::RwLock;
12use tempfile::TempDir;
13
14use crate::constants::splits::{
15    ALL_SPLITS, BITCODE_PREFIX, EPOCH_HASH_RECORD_VERSION, EPOCH_HASHES_PREFIX, EPOCH_META_PREFIX,
16    EPOCH_META_RECORD_VERSION, EPOCH_RECORD_TOMBSTONE, META_KEY, SAMPLER_STATE_KEY,
17    SAMPLER_STATE_RECORD_VERSION, SPLIT_PREFIX, STORE_VERSION,
18};
19use crate::data::RecordId;
20use crate::errors::SamplerError;
21use crate::types::SourceId;
22
/// Logical dataset partitions used during sampling.
///
/// Every record is assigned exactly one of these labels by a [`SplitStore`].
///
/// NOTE(review): the derived `bitcode`/`serde` encodings are assumed to depend
/// on variant order; confirm before reordering or inserting variants. The
/// single-byte split-key format read by `decode_label` (`b'0'`, `b'1'`, `b'2'`)
/// is a separate, hand-written mapping.
#[derive(
    Clone,
    Copy,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    bitcode::Encode,
    bitcode::Decode,
)]
pub enum SplitLabel {
    /// Training split.
    Train,
    /// Validation split.
    Validation,
    /// Test split.
    Test,
}
44
/// Ratio configuration for train/validation/test assignment.
///
/// Each field is a fraction of records; together they must sum to `1.0`
/// (checked by [`SplitRatios::normalized`]). A ratio of `0.0` is accepted.
///
/// File-backed stores embed these ratios in their metadata header, and a
/// reopened store must match them within a small tolerance. Field order is
/// therefore part of that persisted bitcode payload.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, bitcode::Encode, bitcode::Decode)]
pub struct SplitRatios {
    /// Fraction assigned to train.
    pub train: f32,
    /// Fraction assigned to validation.
    pub validation: f32,
    /// Fraction assigned to test.
    pub test: f32,
}
55
56impl Default for SplitRatios {
57    fn default() -> Self {
58        Self {
59            train: 0.8,
60            validation: 0.1,
61            test: 0.1,
62        }
63    }
64}
65
66impl SplitRatios {
67    /// Validate that ratios sum to `1.0` (within epsilon).
68    pub fn normalized(self) -> Result<Self, SamplerError> {
69        let sum = self.train + self.validation + self.test;
70        if (sum - 1.0).abs() > 1e-6 {
71            return Err(SamplerError::Configuration(
72                "split ratios must sum to 1.0".to_string(),
73            ));
74        }
75        Ok(self)
76    }
77}
78
79pub use crate::constants::splits::EPOCH_STATE_VERSION;
80
/// Persisted epoch cursor metadata for one split.
///
/// File stores encode this with `bitcode`, prefixed by a one-byte
/// `EPOCH_META_RECORD_VERSION` tag (see `encode_epoch_meta`).
/// NOTE(review): the derived bitcode encoding is assumed to be positional, so
/// reordering or adding fields likely requires bumping that version. Confirm.
#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
pub struct PersistedSplitMeta {
    /// Current epoch for this split.
    pub epoch: u64,
    /// Cursor offset within the epoch hash list.
    pub offset: u64,
    /// Checksum of the persisted hash list (presumably compared against
    /// `PersistedSplitHashes::checksum` on resume; verify against caller).
    pub hashes_checksum: u64,
}
91
/// Persisted deterministic epoch hash ordering for one split.
///
/// File stores encode this with `bitcode`, prefixed by a one-byte
/// `EPOCH_HASH_RECORD_VERSION` tag (see `encode_epoch_hashes`).
/// NOTE(review): field order is assumed to be part of that encoding. Confirm
/// before reordering.
#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
pub struct PersistedSplitHashes {
    /// Checksum of `hashes`.
    pub checksum: u64,
    /// Deterministic per-epoch hash ordering.
    pub hashes: Vec<u64>,
}
100
/// Persisted sampler runtime state (cursors, recipe indices, RNG).
///
/// Every field here is restored on resume.  Notable things NOT persisted:
///
/// - **`emitted_text_hashes`** — cross-batch text dedup restarts fresh after
///   a save/load boundary.  The same text may appear on both sides.
/// - **`source_wrapped`** — reconstructed from the restored cursors.
/// - **`chunk_index` / role cursors** — rebuilt from the ingested pool.
///
/// File stores encode this with `bitcode`, prefixed by a one-byte
/// `SAMPLER_STATE_RECORD_VERSION` tag (see `encode_sampler_state`).
/// NOTE(review): field order is assumed to be part of that encoding, so
/// layout changes should bump the record version. Confirm.
#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
pub struct PersistedSamplerState {
    /// Source-cycle round-robin index.
    pub source_cycle_idx: u64,
    /// Per-source record cursors.
    pub source_record_cursors: Vec<(SourceId, u64)>,
    /// Current epoch used for deterministic reshuffle.
    pub epoch: u64,
    /// Epoch step counter — incremented once per `next_*_batch` call.
    pub epoch_step: u64,
    /// Deterministic RNG internal state.
    pub rng_state: u64,
    /// Round-robin index for triplet recipes.
    pub triplet_recipe_rr_idx: u64,
    /// Round-robin index for text recipes.
    pub text_recipe_rr_idx: u64,
    /// Per-source I/O stream cursors for restart-resume support.
    pub source_stream_cursors: Vec<(SourceId, u64)>,
}
128
129/// Split assignment backend.
130///
131/// Implementations map `RecordId` values to split labels deterministically.
132pub trait SplitStore: Send + Sync {
133    /// Return split label for `id` if known/derivable.
134    fn label_for(&self, id: &RecordId) -> Option<SplitLabel>;
135    /// Persist an explicit split assignment for `id`.
136    fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError>;
137    /// Return configured split ratios.
138    fn ratios(&self) -> SplitRatios;
139    /// Return the split label for `id`, creating/deriving one when needed.
140    fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError>;
141}
142
143/// Persistence backend for epoch metadata and epoch hash orderings.
144pub trait EpochStateStore: Send + Sync {
145    /// Load split→epoch metadata map.
146    fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError>;
147    /// Load persisted epoch hashes for one split, if available.
148    fn load_epoch_hashes(
149        &self,
150        label: SplitLabel,
151    ) -> Result<Option<PersistedSplitHashes>, SamplerError>;
152    /// Persist split→epoch metadata map.
153    fn save_epoch_meta(
154        &self,
155        meta: &HashMap<SplitLabel, PersistedSplitMeta>,
156    ) -> Result<(), SamplerError>;
157    /// Persist epoch hash list for one split.
158    fn save_epoch_hashes(
159        &self,
160        label: SplitLabel,
161        hashes: &PersistedSplitHashes,
162    ) -> Result<(), SamplerError>;
163}
164
165/// Persistence backend for sampler runtime state.
166pub trait SamplerStateStore: Send + Sync {
167    /// Load persisted sampler runtime state, if present.
168    fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError>;
169    /// Save sampler runtime state, optionally mirroring to `save_path`.
170    fn save_sampler_state(
171        &self,
172        state: &PersistedSamplerState,
173        save_path: Option<&Path>,
174    ) -> Result<(), SamplerError>;
175}
176
177/// In-memory split store with deterministic assignment derivation.
178pub struct DeterministicSplitStore {
179    ratios: SplitRatios,
180    assignments: RwLock<HashMap<RecordId, SplitLabel>>,
181    seed: u64,
182    epoch_meta: RwLock<HashMap<SplitLabel, PersistedSplitMeta>>,
183    epoch_hashes: RwLock<HashMap<SplitLabel, PersistedSplitHashes>>,
184    sampler_state: RwLock<Option<PersistedSamplerState>>,
185}
186
187impl DeterministicSplitStore {
188    /// Create an in-memory split store configured with `ratios` and `seed`.
189    pub fn new(ratios: SplitRatios, seed: u64) -> Result<Self, SamplerError> {
190        ratios.normalized()?;
191        Ok(Self {
192            ratios,
193            assignments: RwLock::new(HashMap::new()),
194            seed,
195            epoch_meta: RwLock::new(HashMap::new()),
196            epoch_hashes: RwLock::new(HashMap::new()),
197            sampler_state: RwLock::new(None),
198        })
199    }
200
201    fn derive_label(&self, id: &RecordId) -> SplitLabel {
202        derive_label_for_id(id, self.seed, self.ratios)
203    }
204}
205
206impl SplitStore for DeterministicSplitStore {
207    fn label_for(&self, id: &RecordId) -> Option<SplitLabel> {
208        if let Some(label) = self.assignments.read().ok()?.get(id).copied() {
209            return Some(label);
210        }
211        Some(self.derive_label(id))
212    }
213
214    fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError> {
215        let mut guard = self
216            .assignments
217            .write()
218            .map_err(|_| SamplerError::SplitStore("lock poisoned".into()))?;
219        guard.insert(id, label);
220        Ok(())
221    }
222
223    fn ratios(&self) -> SplitRatios {
224        self.ratios
225    }
226
227    fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError> {
228        Ok(self.derive_label(&id))
229    }
230}
231
232impl EpochStateStore for DeterministicSplitStore {
233    fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError> {
234        self.epoch_meta
235            .read()
236            .map_err(|_| SamplerError::SplitStore("epoch meta lock poisoned".into()))
237            .map(|guard| guard.clone())
238    }
239
240    fn load_epoch_hashes(
241        &self,
242        label: SplitLabel,
243    ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
244        Ok(self
245            .epoch_hashes
246            .read()
247            .map_err(|_| SamplerError::SplitStore("epoch hashes lock poisoned".into()))?
248            .get(&label)
249            .cloned())
250    }
251
252    fn save_epoch_meta(
253        &self,
254        meta: &HashMap<SplitLabel, PersistedSplitMeta>,
255    ) -> Result<(), SamplerError> {
256        *self
257            .epoch_meta
258            .write()
259            .map_err(|_| SamplerError::SplitStore("epoch meta lock poisoned".into()))? =
260            meta.clone();
261        Ok(())
262    }
263
264    fn save_epoch_hashes(
265        &self,
266        label: SplitLabel,
267        hashes: &PersistedSplitHashes,
268    ) -> Result<(), SamplerError> {
269        self.epoch_hashes
270            .write()
271            .map_err(|_| SamplerError::SplitStore("epoch hashes lock poisoned".into()))?
272            .insert(label, hashes.clone());
273        Ok(())
274    }
275}
276
277impl SamplerStateStore for DeterministicSplitStore {
278    fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError> {
279        self.sampler_state
280            .read()
281            .map_err(|_| SamplerError::SplitStore("sampler state lock poisoned".into()))
282            .map(|guard| guard.clone())
283    }
284
285    fn save_sampler_state(
286        &self,
287        state: &PersistedSamplerState,
288        _save_path: Option<&Path>,
289    ) -> Result<(), SamplerError> {
290        *self
291            .sampler_state
292            .write()
293            .map_err(|_| SamplerError::SplitStore("sampler state lock poisoned".into()))? =
294            Some(state.clone());
295        Ok(())
296    }
297}
298
#[derive(Clone, Copy, Debug, bitcode::Encode, bitcode::Decode)]
/// Versioned metadata header stored in file-backed split stores.
///
/// Written under `META_KEY` the first time a store is opened, and checked
/// against the requested configuration on every later open.
///
/// NOTE(review): `version` lives inside the bitcode payload. Changing this
/// struct's layout would presumably make decoding fail before the version
/// check can report a clean mismatch. Confirm before editing fields.
struct StoreMeta {
    // Must equal `STORE_VERSION` for the store to open.
    version: u8,
    // Seed the store was created with; a reopen must use the same seed.
    seed: u64,
    // Ratios the store was created with; a reopen must match within tolerance.
    ratios: SplitRatios,
}
306
307fn encode_store_meta(meta: &StoreMeta) -> Vec<u8> {
308    encode_bitcode_payload(&bitcode::encode(meta))
309}
310
311fn decode_store_meta(bytes: &[u8]) -> Result<StoreMeta, SamplerError> {
312    let raw = decode_bitcode_payload(bytes)?;
313    bitcode::decode(&raw).map_err(|err| {
314        SamplerError::SplitStore(format!("failed to decode split store metadata: {err}"))
315    })
316}
317
/// File-backed split store for persistent runs.
///
/// Persists assignment metadata, epoch state, and sampler runtime state.
///
/// The store **always** works against a private temporary copy of the source
/// snapshot.  All reads and mutations accumulate in the temp file.
/// State is published to permanent storage only when
/// [`SamplerStateStore::save_sampler_state`] is called:
///
/// * `save_to == None`  → publish temp to `save_path` (may overwrite).
/// * `save_to == Some(p)` → publish temp to `p` only; `save_path` is left
///   untouched.
///
/// This guarantees that the original source file is never modified and that
/// no partial state leaks to the target before an explicit save.
///
/// `ensure` derives labels from `seed` and `ratios` instead of storing them
/// per record. The metadata header checked on open pins both values, so a
/// snapshot cannot be reopened with a different configuration.
pub struct FileSplitStore {
    // NOTE: field order matters on drop. Rust drops fields in declaration
    // order, so `store` (presumably holding the open working file) is dropped
    // before `_temp_dir` removes the directory that contains that file.
    store: DataStore,
    /// Working path: always a private temp file; all reads and writes go here.
    path: PathBuf,
    /// Declared save destination; published to on `save_sampler_state(None)`.
    save_path: PathBuf,
    ratios: SplitRatios,
    seed: u64,
    /// Keeps the temporary directory alive for the lifetime of this store.
    _temp_dir: TempDir,
}
344
345impl fmt::Debug for FileSplitStore {
346    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
347        f.debug_struct("FileSplitStore")
348            .field("path", &self.save_path)
349            .field("ratios", &self.ratios)
350            .field("seed", &self.seed)
351            .finish()
352    }
353}
354
355impl FileSplitStore {
356    /// Open (or create) a file-backed split store at `path`.
357    pub fn open<P: Into<PathBuf>>(
358        path: P,
359        ratios: SplitRatios,
360        seed: u64,
361    ) -> Result<Self, SamplerError> {
362        Self::open_with_load_path(None::<PathBuf>, path, ratios, seed)
363    }
364
365    /// Open (or create) a file-backed split store at `save_path`, optionally
366    /// bootstrapping initial state from `load_path`.
367    ///
368    /// Always stages through a **private temporary file**.  The source used to
369    /// seed the temp is chosen as follows:
370    ///
371    /// 1. `load_path` if supplied and it exists.
372    /// 2. `save_path` if it already exists.
373    /// 3. Nothing — a fresh empty store is created in the temp.
374    ///
375    /// All mutations accumulate in the temp.  State is published to permanent
376    /// storage only when [`SamplerStateStore::save_sampler_state`] is called.
377    /// The original source file is never modified.
378    pub fn open_with_load_path<LP: Into<PathBuf>, SP: Into<PathBuf>>(
379        load_path: Option<LP>,
380        save_path: SP,
381        ratios: SplitRatios,
382        seed: u64,
383    ) -> Result<Self, SamplerError> {
384        let ratios = ratios.normalized()?;
385        let save_path = coerce_store_path(save_path.into());
386        ensure_parent_dir(&save_path)?;
387
388        // Determine source to seed the working temp from.
389        // Priority: explicit load_path > existing save_path > fresh (nothing).
390        let source: Option<PathBuf> = if let Some(lp) = load_path {
391            let lp = coerce_store_path(lp.into());
392            if lp.exists() { Some(lp) } else { None }
393        } else if save_path.exists() {
394            Some(save_path.clone())
395        } else {
396            None
397        };
398
399        let temp_dir = tempfile::tempdir().map_err(|err| {
400            SamplerError::SplitStore(format!("failed to create temp dir for split store: {err}"))
401        })?;
402        let working_path = temp_dir.path().join("working_store.bin");
403        if let Some(src) = &source {
404            fs::copy(src, &working_path).map_err(|err| {
405                SamplerError::SplitStore(format!(
406                    "failed to copy split store from '{}' to temp: {err}",
407                    src.display()
408                ))
409            })?;
410        }
411
412        let raw_store = DataStore::open(&working_path).map_err(map_store_err)?;
413        let store = Self {
414            store: raw_store,
415            path: working_path,
416            save_path,
417            ratios,
418            seed,
419            _temp_dir: temp_dir,
420        };
421        store.verify_metadata()?;
422        Ok(store)
423    }
424
425    fn verify_metadata(&self) -> Result<(), SamplerError> {
426        match read_bytes(&self.store, META_KEY)? {
427            Some(bytes) => {
428                let meta = decode_store_meta(&bytes)?;
429                if meta.version != STORE_VERSION {
430                    return Err(SamplerError::SplitStore(format!(
431                        "split store version mismatch (expected {}, found {})",
432                        STORE_VERSION, meta.version
433                    )));
434                }
435                if meta.seed != self.seed {
436                    return Err(SamplerError::SplitStore(format!(
437                        "split store seed mismatch (expected {}, found {})",
438                        self.seed, meta.seed
439                    )));
440                }
441                if !ratios_close(meta.ratios, self.ratios) {
442                    return Err(SamplerError::SplitStore(
443                        "split store ratios mismatch".into(),
444                    ));
445                }
446            }
447            None => {
448                let blob = StoreMeta {
449                    version: STORE_VERSION,
450                    seed: self.seed,
451                    ratios: self.ratios,
452                };
453                let payload = encode_store_meta(&blob);
454                write_bytes(&self.store, META_KEY, &payload)?;
455            }
456        }
457        Ok(())
458    }
459
460    fn read_epoch_meta_entry(
461        &self,
462        label: SplitLabel,
463    ) -> Result<Option<PersistedSplitMeta>, SamplerError> {
464        let key = epoch_meta_key(label);
465        let entry = self.store.read(&key).map_err(map_store_err)?;
466        match entry {
467            None => Ok(None),
468            Some(bytes) => decode_epoch_meta(bytes.as_ref()),
469        }
470    }
471
472    fn write_epoch_meta_entry(
473        &self,
474        label: SplitLabel,
475        meta: Option<&PersistedSplitMeta>,
476    ) -> Result<(), SamplerError> {
477        let key = epoch_meta_key(label);
478        let payload = encode_epoch_meta(meta);
479        self.store
480            .write(&key, payload.as_slice())
481            .map_err(map_store_err)?;
482        Ok(())
483    }
484
485    fn read_epoch_hashes_entry(
486        &self,
487        label: SplitLabel,
488    ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
489        let key = epoch_hashes_key(label);
490        let entry = self.store.read(&key).map_err(map_store_err)?;
491        match entry {
492            None => Ok(None),
493            Some(bytes) => decode_epoch_hashes(bytes.as_ref()),
494        }
495    }
496
497    fn write_epoch_hashes_entry(
498        &self,
499        label: SplitLabel,
500        hashes: &PersistedSplitHashes,
501    ) -> Result<(), SamplerError> {
502        let key = epoch_hashes_key(label);
503        let payload = encode_epoch_hashes(hashes);
504        self.store
505            .write(&key, payload.as_slice())
506            .map_err(map_store_err)?;
507        Ok(())
508    }
509}
510
511impl SplitStore for FileSplitStore {
512    fn label_for(&self, id: &RecordId) -> Option<SplitLabel> {
513        let key = split_key(id);
514        if let Ok(Some(value)) = self.store.read(&key)
515            && let Ok(label) = decode_label(value.as_ref())
516        {
517            return Some(label);
518        }
519        Some(derive_label_for_id(id, self.seed, self.ratios))
520    }
521
522    fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError> {
523        let _ = (id, label);
524        Ok(())
525    }
526
527    fn ratios(&self) -> SplitRatios {
528        self.ratios
529    }
530
531    fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError> {
532        Ok(derive_label_for_id(&id, self.seed, self.ratios))
533    }
534}
535
536impl EpochStateStore for FileSplitStore {
537    fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError> {
538        let mut meta = HashMap::new();
539        for label in ALL_SPLITS {
540            if let Some(entry) = self.read_epoch_meta_entry(label)? {
541                meta.insert(label, entry);
542            }
543        }
544        Ok(meta)
545    }
546
547    fn load_epoch_hashes(
548        &self,
549        label: SplitLabel,
550    ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
551        self.read_epoch_hashes_entry(label)
552    }
553
554    fn save_epoch_meta(
555        &self,
556        meta: &HashMap<SplitLabel, PersistedSplitMeta>,
557    ) -> Result<(), SamplerError> {
558        for label in ALL_SPLITS {
559            self.write_epoch_meta_entry(label, meta.get(&label))?;
560        }
561        Ok(())
562    }
563
564    fn save_epoch_hashes(
565        &self,
566        label: SplitLabel,
567        hashes: &PersistedSplitHashes,
568    ) -> Result<(), SamplerError> {
569        self.write_epoch_hashes_entry(label, hashes)
570    }
571}
572
573impl SamplerStateStore for FileSplitStore {
574    fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError> {
575        match read_bytes(&self.store, SAMPLER_STATE_KEY)? {
576            Some(bytes) => decode_sampler_state(bytes.as_ref()),
577            None => Ok(None),
578        }
579    }
580
581    /// Persist `state` to the working temp store and publish to the destination.
582    ///
583    /// * `save_to == None`  → publish temp to `save_path` (may overwrite).
584    /// * `save_to == Some(p)` → publish temp to `p` only; `save_path` is left
585    ///   untouched.  `p` must not already exist.
586    fn save_sampler_state(
587        &self,
588        state: &PersistedSamplerState,
589        save_to: Option<&Path>,
590    ) -> Result<(), SamplerError> {
591        // Determine publish destination before writing anything.
592        let dest = if let Some(p) = save_to {
593            coerce_store_path(p.to_path_buf())
594        } else {
595            self.save_path.clone()
596        };
597
598        // Refuse to overwrite an explicitly-named destination that already exists.
599        // Saving back to the canonical save_path (None) is always allowed.
600        if save_to.is_some() && dest.exists() {
601            return Err(SamplerError::SplitStore(format!(
602                "refusing to overwrite existing split store '{}'; choose a new path",
603                dest.display()
604            )));
605        }
606
607        // Write state into the working temp store.
608        let payload = encode_sampler_state(state);
609        write_bytes(&self.store, SAMPLER_STATE_KEY, &payload)?;
610
611        // Publish: copy the temp store to the destination.
612        ensure_parent_dir(&dest)?;
613        fs::copy(&self.path, &dest).map_err(|err| {
614            SamplerError::SplitStore(format!(
615                "failed to publish split store to '{}': {err}",
616                dest.display()
617            ))
618        })?;
619
620        Ok(())
621    }
622}
623
624fn decode_label(bytes: &[u8]) -> Result<SplitLabel, SamplerError> {
625    match bytes.first() {
626        Some(b'0') => Ok(SplitLabel::Train),
627        Some(b'1') => Ok(SplitLabel::Validation),
628        Some(b'2') => Ok(SplitLabel::Test),
629        _ => Err(SamplerError::SplitStore("invalid split label".into())),
630    }
631}
632
633fn derive_label_for_id(id: &RecordId, seed: u64, ratios: SplitRatios) -> SplitLabel {
634    let mut hasher = SipHasher::new();
635    id.hash(&mut hasher);
636    seed.hash(&mut hasher);
637    let value = hasher.finish() as f64 / u64::MAX as f64;
638    let train_cut = ratios.train as f64;
639    let val_cut = train_cut + ratios.validation as f64;
640    if value < train_cut {
641        SplitLabel::Train
642    } else if value < val_cut {
643        SplitLabel::Validation
644    } else {
645        SplitLabel::Test
646    }
647}
648
649fn ratios_close(a: SplitRatios, b: SplitRatios) -> bool {
650    ((a.train - b.train).abs() + (a.validation - b.validation).abs() + (a.test - b.test).abs())
651        < 1e-5
652}
653
654fn split_key(id: &RecordId) -> Vec<u8> {
655    let mut key = Vec::with_capacity(SPLIT_PREFIX.len() + id.len());
656    key.extend_from_slice(SPLIT_PREFIX);
657    key.extend_from_slice(id.as_bytes());
658    key
659}
660
661fn read_bytes(store: &DataStore, key: &[u8]) -> Result<Option<Vec<u8>>, SamplerError> {
662    store
663        .read(key)
664        .map_err(map_store_err)?
665        .map(|entry| Ok(entry.as_ref().to_vec()))
666        .transpose()
667}
668
669fn write_bytes(store: &DataStore, key: &[u8], payload: &[u8]) -> Result<(), SamplerError> {
670    store.write(key, payload).map_err(map_store_err)?;
671    Ok(())
672}
673
674fn epoch_meta_key(label: SplitLabel) -> Vec<u8> {
675    let mut key = Vec::with_capacity(EPOCH_META_PREFIX.len() + 12);
676    key.extend_from_slice(EPOCH_META_PREFIX);
677    let suffix = match label {
678        SplitLabel::Train => b"train".as_ref(),
679        SplitLabel::Validation => b"validation".as_ref(),
680        SplitLabel::Test => b"test".as_ref(),
681    };
682    key.extend_from_slice(suffix);
683    key
684}
685
686fn epoch_hashes_key(label: SplitLabel) -> Vec<u8> {
687    let mut key = Vec::with_capacity(EPOCH_HASHES_PREFIX.len() + 12);
688    key.extend_from_slice(EPOCH_HASHES_PREFIX);
689    let suffix = match label {
690        SplitLabel::Train => b"train".as_ref(),
691        SplitLabel::Validation => b"validation".as_ref(),
692        SplitLabel::Test => b"test".as_ref(),
693    };
694    key.extend_from_slice(suffix);
695    key
696}
697
698fn encode_epoch_meta(meta: Option<&PersistedSplitMeta>) -> Vec<u8> {
699    match meta {
700        None => vec![EPOCH_RECORD_TOMBSTONE],
701        Some(meta) => {
702            let payload = encode_bitcode_payload(&bitcode::encode(meta));
703            let mut buf = Vec::with_capacity(1 + payload.len());
704            buf.push(EPOCH_META_RECORD_VERSION);
705            buf.extend_from_slice(&payload);
706            buf
707        }
708    }
709}
710
711fn decode_epoch_meta(bytes: &[u8]) -> Result<Option<PersistedSplitMeta>, SamplerError> {
712    if bytes.is_empty() || bytes[0] == EPOCH_RECORD_TOMBSTONE {
713        return Ok(None);
714    }
715    if bytes[0] != EPOCH_META_RECORD_VERSION {
716        return Err(SamplerError::SplitStore(
717            "epoch meta record version mismatch".into(),
718        ));
719    }
720    let raw = decode_bitcode_payload(&bytes[1..])?;
721    bitcode::decode(&raw)
722        .map(Some)
723        .map_err(|err| SamplerError::SplitStore(format!("corrupt epoch meta record: {err}")))
724}
725
726fn encode_epoch_hashes(hashes: &PersistedSplitHashes) -> Vec<u8> {
727    let payload = encode_bitcode_payload(&bitcode::encode(hashes));
728    let mut buf = Vec::with_capacity(1 + payload.len());
729    buf.push(EPOCH_HASH_RECORD_VERSION);
730    buf.extend_from_slice(&payload);
731    buf
732}
733
734fn decode_epoch_hashes(bytes: &[u8]) -> Result<Option<PersistedSplitHashes>, SamplerError> {
735    if bytes.is_empty() || bytes[0] == EPOCH_RECORD_TOMBSTONE {
736        return Ok(None);
737    }
738    if bytes[0] != EPOCH_HASH_RECORD_VERSION {
739        return Err(SamplerError::SplitStore(
740            "epoch hashes record version mismatch".into(),
741        ));
742    }
743    let raw = decode_bitcode_payload(&bytes[1..])?;
744    bitcode::decode(&raw)
745        .map(Some)
746        .map_err(|err| SamplerError::SplitStore(format!("corrupt epoch hashes record: {err}")))
747}
748
749fn encode_sampler_state(state: &PersistedSamplerState) -> Vec<u8> {
750    let payload = encode_bitcode_payload(&bitcode::encode(state));
751    let mut buf = Vec::with_capacity(1 + payload.len());
752    buf.push(SAMPLER_STATE_RECORD_VERSION);
753    buf.extend_from_slice(&payload);
754    buf
755}
756
757fn decode_sampler_state(bytes: &[u8]) -> Result<Option<PersistedSamplerState>, SamplerError> {
758    if bytes.is_empty() {
759        return Ok(None);
760    }
761    if bytes[0] != SAMPLER_STATE_RECORD_VERSION {
762        return Err(SamplerError::SplitStore(
763            "sampler state record version mismatch".into(),
764        ));
765    }
766    let raw = decode_bitcode_payload(&bytes[1..])?;
767    bitcode::decode(&raw)
768        .map(Some)
769        .map_err(|err| SamplerError::SplitStore(format!("corrupt sampler state record: {err}")))
770}
771
772fn encode_bitcode_payload(bytes: &[u8]) -> Vec<u8> {
773    let mut out = Vec::with_capacity(1 + bytes.len());
774    out.push(BITCODE_PREFIX);
775    out.extend_from_slice(bytes);
776    out
777}
778
779fn decode_bitcode_payload(bytes: &[u8]) -> Result<Vec<u8>, SamplerError> {
780    if bytes.first().copied() != Some(BITCODE_PREFIX) {
781        return Err(SamplerError::SplitStore(
782            "bitcode payload missing expected prefix".into(),
783        ));
784    }
785    Ok(bytes[1..].to_vec())
786}
787
/// Normalization hook for store paths; currently returns `path` unchanged.
///
/// NOTE(review): a test named `file_store_accepts_directory_path` exists, but
/// this hook does not map a directory to a file inside it. Opening a store
/// whose path is an existing directory would then fail at `fs::copy`. Confirm
/// the intended behavior.
fn coerce_store_path(path: PathBuf) -> PathBuf {
    std::convert::identity(path)
}
791
792fn ensure_parent_dir(path: &Path) -> Result<(), SamplerError> {
793    if let Some(parent) = path.parent()
794        && !parent.as_os_str().is_empty()
795    {
796        fs::create_dir_all(parent)?;
797    }
798    Ok(())
799}
800
801fn map_store_err(err: io::Error) -> SamplerError {
802    SamplerError::SplitStore(err.to_string())
803}
804
805#[cfg(test)]
806mod tests {
807    use super::*;
808    use std::collections::HashMap;
809    use tempfile::tempdir;
810
811    #[test]
812    fn split_ratios_reject_non_unit_sum() {
813        let invalid = SplitRatios {
814            train: 0.6,
815            validation: 0.3,
816            test: 0.3,
817        };
818
819        let err = DeterministicSplitStore::new(invalid, 1)
820            .err()
821            .expect("expected non-unit split ratios to fail");
822        assert!(matches!(
823            err,
824            SamplerError::Configuration(ref msg) if msg.contains("split ratios must sum to 1.0")
825        ));
826
827        let dir = tempdir().unwrap();
828        let path = dir.path().join("split_store.bin");
829        let err = FileSplitStore::open(&path, invalid, 1).unwrap_err();
830        assert!(matches!(
831            err,
832            SamplerError::Configuration(ref msg) if msg.contains("split ratios must sum to 1.0")
833        ));
834    }
835
836    #[test]
837    fn zero_test_ratio_never_assigns_test_labels() {
838        let ratios = SplitRatios {
839            train: 0.5,
840            validation: 0.5,
841            test: 0.0,
842        };
843        let store = DeterministicSplitStore::new(ratios, 7).unwrap();
844
845        let mut saw_train = false;
846        let mut saw_validation = false;
847        for idx in 0..20_000 {
848            let id = format!("record_{idx}");
849            let label = store.ensure(id).unwrap();
850            assert_ne!(label, SplitLabel::Test);
851            saw_train |= label == SplitLabel::Train;
852            saw_validation |= label == SplitLabel::Validation;
853            if saw_train && saw_validation {
854                break;
855            }
856        }
857
858        assert!(saw_train);
859        assert!(saw_validation);
860    }
861
862    /// `save(None)` must publish the temp to `save_path` on disk.  A fresh
863    /// `open` of that path must see everything that was written before the save.
864    #[test]
865    fn save_none_publishes_to_save_path_and_reloads_cleanly() {
866        let dir = tempdir().unwrap();
867        let path = dir.path().join("persist.bin");
868        let ratios = SplitRatios::default();
869
870        {
871            let store = FileSplitStore::open(&path, ratios, 123).unwrap();
872            let mut meta = HashMap::new();
873            meta.insert(
874                SplitLabel::Train,
875                PersistedSplitMeta {
876                    epoch: 3,
877                    offset: 7,
878                    hashes_checksum: 42,
879                },
880            );
881            store.save_epoch_meta(&meta).unwrap();
882            let state = PersistedSamplerState {
883                source_cycle_idx: 11,
884                source_record_cursors: vec![("s".to_string(), 1)],
885                epoch: 5,
886                epoch_step: 0,
887                rng_state: 99,
888                triplet_recipe_rr_idx: 2,
889                text_recipe_rr_idx: 3,
890                source_stream_cursors: vec![],
891            };
892            assert!(!path.exists(), "save_path must not exist before save(None)");
893            store.save_sampler_state(&state, None).unwrap();
894            assert!(
895                path.exists(),
896                "save(None) must publish to save_path on disk"
897            );
898        }
899
900        // Fresh open must read everything back from disk.
901        let reopened = FileSplitStore::open(&path, ratios, 123).unwrap();
902        let loaded_state = reopened.load_sampler_state().unwrap().unwrap();
903        assert_eq!(loaded_state.source_cycle_idx, 11);
904        assert_eq!(loaded_state.epoch, 5);
905        assert_eq!(loaded_state.rng_state, 99);
906        let loaded_meta = reopened.load_epoch_meta().unwrap();
907        assert_eq!(loaded_meta.get(&SplitLabel::Train).unwrap().epoch, 3);
908    }
909
910    #[test]
911    fn file_store_rejects_seed_mismatch() {
912        let dir = tempdir().unwrap();
913        let path = dir.path().join("splits.json");
914        let ratios = SplitRatios::default();
915        let store = FileSplitStore::open(&path, ratios, 123).unwrap();
916        store.ensure("abc".to_string()).unwrap();
917        // Publish to disk so the seed is committed before the next open.
918        store
919            .save_sampler_state(
920                &PersistedSamplerState {
921                    source_cycle_idx: 0,
922                    source_record_cursors: vec![],
923                    epoch: 0,
924                    epoch_step: 0,
925                    rng_state: 0,
926                    triplet_recipe_rr_idx: 0,
927                    text_recipe_rr_idx: 0,
928                    source_stream_cursors: vec![],
929                },
930                None,
931            )
932            .unwrap();
933        drop(store);
934
935        let err = FileSplitStore::open(&path, ratios, 999).unwrap_err();
936        assert!(matches!(
937            err,
938            SamplerError::SplitStore(msg) if msg.contains("seed")
939        ));
940    }
941
942    #[test]
943    fn file_store_accepts_directory_path() {
944        let dir = tempdir().unwrap();
945        let ratios = SplitRatios::default();
946        let err = FileSplitStore::open(dir.path(), ratios, 777).unwrap_err();
947        assert!(matches!(err, SamplerError::SplitStore(_)));
948    }
949
950    #[test]
951    fn bitcode_payload_requires_prefix() {
952        let err = decode_bitcode_payload(&[0x00, 0x01]).unwrap_err();
953        assert!(
954            matches!(err, SamplerError::SplitStore(msg) if msg.contains("missing expected prefix"))
955        );
956    }
957
958    #[test]
959    fn file_store_round_trips_epoch_and_sampler_state() {
960        let dir = tempdir().unwrap();
961        let path = dir.path().join("epoch_sampler_state.bin");
962        let store = FileSplitStore::open(&path, SplitRatios::default(), 222).unwrap();
963
964        assert!(store.load_epoch_hashes(SplitLabel::Test).unwrap().is_none());
965
966        let mut epoch_meta = HashMap::new();
967        epoch_meta.insert(
968            SplitLabel::Train,
969            PersistedSplitMeta {
970                epoch: 3,
971                offset: 7,
972                hashes_checksum: 42,
973            },
974        );
975        store.save_epoch_meta(&epoch_meta).unwrap();
976
977        let loaded_meta = store.load_epoch_meta().unwrap();
978        let loaded_train = loaded_meta.get(&SplitLabel::Train).unwrap();
979        assert_eq!(loaded_train.epoch, 3);
980        assert_eq!(loaded_train.offset, 7);
981        assert_eq!(loaded_train.hashes_checksum, 42);
982
983        let hashes = PersistedSplitHashes {
984            checksum: 99,
985            hashes: vec![10, 20, 30],
986        };
987        store
988            .save_epoch_hashes(SplitLabel::Validation, &hashes)
989            .unwrap();
990        let loaded_hashes = store
991            .load_epoch_hashes(SplitLabel::Validation)
992            .unwrap()
993            .unwrap();
994        assert_eq!(loaded_hashes.checksum, 99);
995        assert_eq!(loaded_hashes.hashes, vec![10, 20, 30]);
996
997        let state = PersistedSamplerState {
998            source_cycle_idx: 11,
999            source_record_cursors: vec![("source_a".to_string(), 1)],
1000            epoch: 8,
1001            epoch_step: 0,
1002            rng_state: 1234,
1003            triplet_recipe_rr_idx: 2,
1004            text_recipe_rr_idx: 5,
1005            source_stream_cursors: vec![("source_a".to_string(), 9)],
1006        };
1007        store.save_sampler_state(&state, None).unwrap();
1008        let loaded_state = store.load_sampler_state().unwrap().unwrap();
1009        assert_eq!(loaded_state.source_cycle_idx, 11);
1010        assert_eq!(loaded_state.epoch, 8);
1011        assert_eq!(loaded_state.rng_state, 1234);
1012        assert_eq!(loaded_state.triplet_recipe_rr_idx, 2);
1013        assert_eq!(loaded_state.text_recipe_rr_idx, 5);
1014        assert_eq!(
1015            loaded_state.source_record_cursors,
1016            vec![("source_a".to_string(), 1)]
1017        );
1018        assert_eq!(
1019            loaded_state.source_stream_cursors,
1020            vec![("source_a".to_string(), 9)]
1021        );
1022
1023        // Verify the same data survives a drop + fresh open from disk.
1024        drop(store);
1025        let reopened = FileSplitStore::open(&path, SplitRatios::default(), 222).unwrap();
1026        let disk_state = reopened.load_sampler_state().unwrap().unwrap();
1027        assert_eq!(disk_state.source_cycle_idx, 11);
1028        assert_eq!(disk_state.epoch, 8);
1029        assert_eq!(disk_state.rng_state, 1234);
1030        let disk_meta = reopened.load_epoch_meta().unwrap();
1031        assert_eq!(disk_meta.get(&SplitLabel::Train).unwrap().epoch, 3);
1032        let disk_hashes = reopened
1033            .load_epoch_hashes(SplitLabel::Validation)
1034            .unwrap()
1035            .unwrap();
1036        assert_eq!(disk_hashes.checksum, 99);
1037    }
1038
1039    #[test]
1040    fn split_keys_and_labels_cover_helper_paths() {
1041        let key = split_key(&"abc".to_string());
1042        assert!(key.starts_with(SPLIT_PREFIX));
1043
1044        assert!(matches!(decode_label(b"0"), Ok(SplitLabel::Train)));
1045        assert!(matches!(decode_label(b"1"), Ok(SplitLabel::Validation)));
1046        assert!(matches!(decode_label(b"2"), Ok(SplitLabel::Test)));
1047        assert!(decode_label(b"x").is_err());
1048
1049        let epoch_meta_train = epoch_meta_key(SplitLabel::Train);
1050        let epoch_hashes_train = epoch_hashes_key(SplitLabel::Train);
1051        let epoch_hashes_test = epoch_hashes_key(SplitLabel::Test);
1052        assert!(epoch_meta_train.starts_with(EPOCH_META_PREFIX));
1053        assert!(epoch_hashes_train.starts_with(EPOCH_HASHES_PREFIX));
1054        assert!(epoch_hashes_test.starts_with(EPOCH_HASHES_PREFIX));
1055    }
1056
1057    #[test]
1058    fn encode_decode_store_meta_roundtrip_and_corrupt_prefix_error() {
1059        let meta = StoreMeta {
1060            version: STORE_VERSION,
1061            seed: 55,
1062            ratios: SplitRatios::default(),
1063        };
1064        let encoded = encode_store_meta(&meta);
1065        let decoded = decode_store_meta(&encoded).unwrap();
1066        assert_eq!(decoded.version, STORE_VERSION);
1067        assert_eq!(decoded.seed, 55);
1068
1069        let err = decode_store_meta(&[0x00, 0x01]).unwrap_err();
1070        assert!(matches!(
1071            err,
1072            SamplerError::SplitStore(msg) if msg.contains("missing expected prefix")
1073        ));
1074    }
1075
1076    #[test]
1077    fn epoch_and_sampler_decoders_cover_tombstone_and_version_mismatch() {
1078        assert!(decode_epoch_meta(&[]).unwrap().is_none());
1079        assert!(
1080            decode_epoch_meta(&[EPOCH_RECORD_TOMBSTONE])
1081                .unwrap()
1082                .is_none()
1083        );
1084        assert!(decode_epoch_hashes(&[]).unwrap().is_none());
1085        assert!(
1086            decode_epoch_hashes(&[EPOCH_RECORD_TOMBSTONE])
1087                .unwrap()
1088                .is_none()
1089        );
1090        assert!(decode_sampler_state(&[]).unwrap().is_none());
1091
1092        let meta_mismatch = decode_epoch_meta(&[EPOCH_META_RECORD_VERSION.wrapping_add(1), 1]);
1093        assert!(matches!(
1094            meta_mismatch,
1095            Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1096        ));
1097        let hashes_mismatch = decode_epoch_hashes(&[EPOCH_HASH_RECORD_VERSION.wrapping_add(1), 1]);
1098        assert!(matches!(
1099            hashes_mismatch,
1100            Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1101        ));
1102        let state_mismatch =
1103            decode_sampler_state(&[SAMPLER_STATE_RECORD_VERSION.wrapping_add(1), 1]);
1104        assert!(matches!(
1105            state_mismatch,
1106            Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1107        ));
1108    }
1109
1110    #[test]
1111    fn split_store_trait_methods_and_path_helpers_are_exercised() {
1112        let dir = tempdir().unwrap();
1113        let file_path = dir.path().join("nested").join("store.bin");
1114        ensure_parent_dir(&file_path).unwrap();
1115        assert!(file_path.parent().unwrap().exists());
1116
1117        let existing_dir_path = coerce_store_path(dir.path().to_path_buf());
1118        assert_eq!(existing_dir_path, dir.path().to_path_buf());
1119
1120        let ratios = SplitRatios::default();
1121        let store = FileSplitStore::open(&file_path, ratios, 444).unwrap();
1122        assert!((store.ratios().train - ratios.train).abs() < 1e-6);
1123        store
1124            .upsert("record_1".to_string(), SplitLabel::Validation)
1125            .unwrap();
1126        let ensured = store.ensure("record_1".to_string()).unwrap();
1127        assert!(matches!(
1128            ensured,
1129            SplitLabel::Train | SplitLabel::Validation | SplitLabel::Test
1130        ));
1131
1132        let mapped = map_store_err(io::Error::other("boom"));
1133        assert!(matches!(mapped, SamplerError::SplitStore(msg) if msg.contains("boom")));
1134    }
1135
1136    #[test]
1137    fn epoch_and_sampler_encode_decode_roundtrips() {
1138        let meta = PersistedSplitMeta {
1139            epoch: 4,
1140            offset: 9,
1141            hashes_checksum: 21,
1142        };
1143        let encoded_meta = encode_epoch_meta(Some(&meta));
1144        let decoded_meta = decode_epoch_meta(&encoded_meta).unwrap().unwrap();
1145        assert_eq!(decoded_meta.epoch, 4);
1146        assert_eq!(decoded_meta.offset, 9);
1147
1148        let hashes = PersistedSplitHashes {
1149            checksum: 7,
1150            hashes: vec![1, 2, 3],
1151        };
1152        let encoded_hashes = encode_epoch_hashes(&hashes);
1153        let decoded_hashes = decode_epoch_hashes(&encoded_hashes).unwrap().unwrap();
1154        assert_eq!(decoded_hashes.checksum, 7);
1155        assert_eq!(decoded_hashes.hashes, vec![1, 2, 3]);
1156
1157        let state = PersistedSamplerState {
1158            source_cycle_idx: 1,
1159            source_record_cursors: vec![("s".to_string(), 2)],
1160            epoch: 3,
1161            epoch_step: 0,
1162            rng_state: 4,
1163            triplet_recipe_rr_idx: 5,
1164            text_recipe_rr_idx: 6,
1165            source_stream_cursors: vec![("s".to_string(), 7)],
1166        };
1167        let encoded_state = encode_sampler_state(&state);
1168        let decoded_state = decode_sampler_state(&encoded_state).unwrap().unwrap();
1169        assert_eq!(decoded_state.source_cycle_idx, 1);
1170        assert_eq!(decoded_state.epoch, 3);
1171        assert_eq!(decoded_state.rng_state, 4);
1172    }
1173
1174    #[test]
1175    fn deterministic_store_trait_methods_work() {
1176        let ratios = SplitRatios::default();
1177        let store = DeterministicSplitStore::new(ratios, 999).unwrap();
1178
1179        assert_eq!(store.ratios().train, ratios.train);
1180
1181        let id = "source::record".to_string();
1182        let derived = store.label_for(&id).unwrap();
1183        store.upsert(id.clone(), SplitLabel::Validation).unwrap();
1184        assert_eq!(store.label_for(&id), Some(SplitLabel::Validation));
1185        assert!(matches!(
1186            derived,
1187            SplitLabel::Train | SplitLabel::Validation | SplitLabel::Test
1188        ));
1189
1190        let mut meta = HashMap::new();
1191        meta.insert(
1192            SplitLabel::Test,
1193            PersistedSplitMeta {
1194                epoch: 1,
1195                offset: 2,
1196                hashes_checksum: 3,
1197            },
1198        );
1199        store.save_epoch_meta(&meta).unwrap();
1200        let loaded_meta = store.load_epoch_meta().unwrap();
1201        assert_eq!(loaded_meta.get(&SplitLabel::Test).unwrap().offset, 2);
1202
1203        assert!(
1204            store
1205                .load_epoch_hashes(SplitLabel::Train)
1206                .unwrap()
1207                .is_none()
1208        );
1209        store
1210            .save_epoch_hashes(
1211                SplitLabel::Train,
1212                &PersistedSplitHashes {
1213                    checksum: 11,
1214                    hashes: vec![4, 5],
1215                },
1216            )
1217            .unwrap();
1218        assert_eq!(
1219            store
1220                .load_epoch_hashes(SplitLabel::Train)
1221                .unwrap()
1222                .unwrap()
1223                .checksum,
1224            11
1225        );
1226
1227        assert!(store.load_sampler_state().unwrap().is_none());
1228        let sampler_state = PersistedSamplerState {
1229            source_cycle_idx: 1,
1230            source_record_cursors: vec![("s1".to_string(), 2)],
1231            epoch: 3,
1232            epoch_step: 0,
1233            rng_state: 4,
1234            triplet_recipe_rr_idx: 5,
1235            text_recipe_rr_idx: 6,
1236            source_stream_cursors: vec![("s1".to_string(), 7)],
1237        };
1238        store.save_sampler_state(&sampler_state, None).unwrap();
1239        assert_eq!(
1240            store.load_sampler_state().unwrap().unwrap().epoch,
1241            sampler_state.epoch
1242        );
1243    }
1244
1245    #[test]
1246    fn open_with_load_path_bootstraps_state_explicitly() {
1247        let dir = tempdir().unwrap();
1248        let path_a = dir.path().join("snapshot_a.bin");
1249        let path_b = dir.path().join("snapshot_b.bin");
1250        let ratios = SplitRatios::default();
1251
1252        let store_a = FileSplitStore::open(&path_a, ratios, 42).unwrap();
1253
1254        let mut meta = HashMap::new();
1255        meta.insert(
1256            SplitLabel::Train,
1257            PersistedSplitMeta {
1258                epoch: 5,
1259                offset: 3,
1260                hashes_checksum: 999,
1261            },
1262        );
1263        store_a.save_epoch_meta(&meta).unwrap();
1264
1265        let sampler_state = PersistedSamplerState {
1266            source_cycle_idx: 1,
1267            source_record_cursors: vec![("s1".to_string(), 2)],
1268            epoch: 7,
1269            epoch_step: 0,
1270            rng_state: 123,
1271            triplet_recipe_rr_idx: 4,
1272            text_recipe_rr_idx: 6,
1273            source_stream_cursors: vec![("s1".to_string(), 8)],
1274        };
1275        store_a.save_sampler_state(&sampler_state, None).unwrap();
1276        drop(store_a);
1277
1278        let store_b =
1279            FileSplitStore::open_with_load_path(Some(path_a.clone()), &path_b, ratios, 42).unwrap();
1280        assert_eq!(
1281            store_b
1282                .load_epoch_meta()
1283                .unwrap()
1284                .get(&SplitLabel::Train)
1285                .unwrap()
1286                .epoch,
1287            5
1288        );
1289        assert_eq!(store_b.load_sampler_state().unwrap().unwrap().epoch, 7);
1290
1291        let store_a_again = FileSplitStore::open(&path_a, ratios, 42).unwrap();
1292        assert_eq!(
1293            store_a_again
1294                .load_epoch_meta()
1295                .unwrap()
1296                .get(&SplitLabel::Train)
1297                .unwrap()
1298                .epoch,
1299            5
1300        );
1301        assert_eq!(
1302            store_a_again.load_sampler_state().unwrap().unwrap().epoch,
1303            7
1304        );
1305    }
1306
1307    #[test]
1308    fn save_sampler_state_to_new_path_copies_existing_store_first() {
1309        let dir = tempdir().unwrap();
1310        let path_a = dir.path().join("source_store.bin");
1311        let path_b = dir.path().join("mirror_store.bin");
1312        let ratios = SplitRatios::default();
1313
1314        let store_a = FileSplitStore::open(&path_a, ratios, 42).unwrap();
1315
1316        let assigned_id = "record_with_assignment".to_string();
1317        let assigned_key = split_key(&assigned_id);
1318        store_a.store.write(&assigned_key, b"1").unwrap();
1319
1320        let sampler_state = PersistedSamplerState {
1321            source_cycle_idx: 1,
1322            source_record_cursors: vec![("s1".to_string(), 2)],
1323            epoch: 9,
1324            epoch_step: 0,
1325            rng_state: 123,
1326            triplet_recipe_rr_idx: 4,
1327            text_recipe_rr_idx: 6,
1328            source_stream_cursors: vec![("s1".to_string(), 8)],
1329        };
1330
1331        store_a
1332            .save_sampler_state(&sampler_state, Some(path_b.as_path()))
1333            .unwrap();
1334
1335        // Destination gets the existing store data AND the new sampler state.
1336        let store_b = FileSplitStore::open(&path_b, ratios, 42).unwrap();
1337        assert_eq!(
1338            store_b.label_for(&assigned_id),
1339            Some(SplitLabel::Validation)
1340        );
1341        assert_eq!(store_b.load_sampler_state().unwrap().unwrap().epoch, 9);
1342
1343        // save_to=Some(path_b) must not publish to path_a (the canonical save_path).
1344        assert!(
1345            !path_a.exists(),
1346            "save_to=Some(...) must not publish to the canonical save_path"
1347        );
1348    }
1349
1350    /// Saving to a custom path on a plain `open` store must not mutate the
1351    /// working store file -- i.e. a previously-saved state is preserved.
1352    #[test]
1353    fn save_some_on_regular_open_does_not_modify_working_store() {
1354        let dir = tempdir().unwrap();
1355        let path_a = dir.path().join("working_store.bin");
1356        let path_b = dir.path().join("checkpoint_store.bin");
1357        let ratios = SplitRatios::default();
1358
1359        let store_a = FileSplitStore::open(&path_a, ratios, 42).unwrap();
1360
1361        // Establish a baseline state in the working store.
1362        let initial_state = PersistedSamplerState {
1363            source_cycle_idx: 1,
1364            source_record_cursors: vec![],
1365            epoch: 1,
1366            epoch_step: 0,
1367            rng_state: 0,
1368            triplet_recipe_rr_idx: 0,
1369            text_recipe_rr_idx: 0,
1370            source_stream_cursors: vec![],
1371        };
1372        store_a.save_sampler_state(&initial_state, None).unwrap();
1373
1374        // Snapshot a newer state to a separate checkpoint path.
1375        let checkpoint_state = PersistedSamplerState {
1376            source_cycle_idx: 99,
1377            source_record_cursors: vec![],
1378            epoch: 99,
1379            epoch_step: 0,
1380            rng_state: 42,
1381            triplet_recipe_rr_idx: 0,
1382            text_recipe_rr_idx: 0,
1383            source_stream_cursors: vec![],
1384        };
1385        store_a
1386            .save_sampler_state(&checkpoint_state, Some(path_b.as_path()))
1387            .unwrap();
1388
1389        // The on-disk save_path (path_a) must not have been overwritten by save_to=Some(...).
1390        // Re-open from disk to verify; path_a was last published by save_to=None above.
1391        drop(store_a);
1392        let store_a_disk = FileSplitStore::open(&path_a, ratios, 42).unwrap();
1393        assert_eq!(
1394            store_a_disk.load_sampler_state().unwrap().unwrap().epoch,
1395            1,
1396            "save_to=Some(...) must not overwrite the on-disk save_path"
1397        );
1398
1399        // The checkpoint must hold the new state.
1400        let store_b = FileSplitStore::open(&path_b, ratios, 42).unwrap();
1401        let state_from_b = store_b.load_sampler_state().unwrap().unwrap();
1402        assert_eq!(
1403            state_from_b.epoch, 99,
1404            "checkpoint store must hold the snapshotted state"
1405        );
1406    }
1407
1408    #[test]
1409    fn file_store_metadata_mismatch_and_debug_paths_are_covered() {
1410        let dir = tempdir().unwrap();
1411        let path = dir.path().join("meta_mismatch.bin");
1412        let ratios = SplitRatios::default();
1413        let store = FileSplitStore::open(&path, ratios, 123).unwrap();
1414
1415        let debug_repr = format!("{store:?}");
1416        assert!(debug_repr.contains("FileSplitStore"));
1417
1418        let wrong_version = StoreMeta {
1419            version: STORE_VERSION.wrapping_add(1),
1420            seed: 123,
1421            ratios,
1422        };
1423        let payload = encode_store_meta(&wrong_version);
1424        store.store.write(META_KEY, &payload).unwrap();
1425        // Publish the corrupted temp to disk so the next open reads the bad version.
1426        store
1427            .save_sampler_state(
1428                &PersistedSamplerState {
1429                    source_cycle_idx: 0,
1430                    source_record_cursors: vec![],
1431                    epoch: 0,
1432                    epoch_step: 0,
1433                    rng_state: 0,
1434                    triplet_recipe_rr_idx: 0,
1435                    text_recipe_rr_idx: 0,
1436                    source_stream_cursors: vec![],
1437                },
1438                None,
1439            )
1440            .unwrap();
1441        drop(store);
1442
1443        let err = FileSplitStore::open(&path, ratios, 123).unwrap_err();
1444        assert!(matches!(err, SamplerError::SplitStore(msg) if msg.contains("version mismatch")));
1445
1446        let ratio_path = dir.path().join("ratio_mismatch.bin");
1447        let baseline = FileSplitStore::open(&ratio_path, ratios, 777).unwrap();
1448        // Publish so ratio_path exists on disk before reopening with different ratios.
1449        baseline
1450            .save_sampler_state(
1451                &PersistedSamplerState {
1452                    source_cycle_idx: 0,
1453                    source_record_cursors: vec![],
1454                    epoch: 0,
1455                    epoch_step: 0,
1456                    rng_state: 0,
1457                    triplet_recipe_rr_idx: 0,
1458                    text_recipe_rr_idx: 0,
1459                    source_stream_cursors: vec![],
1460                },
1461                None,
1462            )
1463            .unwrap();
1464        drop(baseline);
1465
1466        let different_ratios = SplitRatios {
1467            train: 0.7,
1468            validation: 0.2,
1469            test: 0.1,
1470        };
1471        let err = FileSplitStore::open(&ratio_path, different_ratios, 777).unwrap_err();
1472        assert!(matches!(err, SamplerError::SplitStore(msg) if msg.contains("ratios mismatch")));
1473    }
1474
1475    #[test]
1476    fn split_decode_helpers_reject_corrupt_bitcode_payloads() {
1477        let store_meta_err = decode_store_meta(&[BITCODE_PREFIX, 0xFF, 0xEE]).unwrap_err();
1478        assert!(matches!(
1479            store_meta_err,
1480            SamplerError::SplitStore(msg) if msg.contains("failed to decode split store metadata")
1481        ));
1482
1483        let epoch_meta_err =
1484            decode_epoch_meta(&[EPOCH_META_RECORD_VERSION, BITCODE_PREFIX, 0xFF]).unwrap_err();
1485        assert!(
1486            matches!(epoch_meta_err, SamplerError::SplitStore(msg) if msg.contains("corrupt epoch meta record"))
1487        );
1488
1489        let epoch_hashes_err =
1490            decode_epoch_hashes(&[EPOCH_HASH_RECORD_VERSION, BITCODE_PREFIX, 0xFF]).unwrap_err();
1491        assert!(matches!(
1492            epoch_hashes_err,
1493            SamplerError::SplitStore(msg) if msg.contains("corrupt epoch hashes record")
1494        ));
1495
1496        let sampler_state_err =
1497            decode_sampler_state(&[SAMPLER_STATE_RECORD_VERSION, BITCODE_PREFIX, 0xFF])
1498                .unwrap_err();
1499        assert!(matches!(
1500            sampler_state_err,
1501            SamplerError::SplitStore(msg) if msg.contains("corrupt sampler state record")
1502        ));
1503    }
1504
1505    #[test]
1506    fn file_store_label_fallback_and_validation_keys_are_covered() {
1507        let dir = tempdir().unwrap();
1508        let path = dir.path().join("labels.bin");
1509        let store = FileSplitStore::open(&path, SplitRatios::default(), 42).unwrap();
1510
1511        let id = "bad_label_record".to_string();
1512        let expected = derive_label_for_id(&id, 42, SplitRatios::default());
1513        let key = split_key(&id);
1514
1515        store.store.write(&key, b"x").unwrap();
1516        assert_eq!(store.label_for(&id), Some(expected));
1517
1518        store.store.write(&key, b"1").unwrap();
1519        assert_eq!(store.label_for(&id), Some(SplitLabel::Validation));
1520
1521        let meta_validation = epoch_meta_key(SplitLabel::Validation);
1522        let hashes_validation = epoch_hashes_key(SplitLabel::Validation);
1523        assert!(meta_validation.starts_with(EPOCH_META_PREFIX));
1524        assert!(hashes_validation.starts_with(EPOCH_HASHES_PREFIX));
1525        assert!(meta_validation.ends_with(b"validation"));
1526        assert!(hashes_validation.ends_with(b"validation"));
1527    }
1528
1529    #[test]
1530    fn ensure_parent_dir_allows_plain_file_names() {
1531        ensure_parent_dir(Path::new("split_store_local.bin")).unwrap();
1532        let coerced = coerce_store_path(PathBuf::from("explicit_store.bin"));
1533        assert_eq!(coerced, PathBuf::from("explicit_store.bin"));
1534    }
1535
1536    // -----------------------------------------------------------------------
1537    // Temp-dir bootstrap contract
1538    // -----------------------------------------------------------------------
1539
1540    /// `open_with_load_path` must NOT modify the source file while the store is open.
1541    #[test]
1542    fn load_path_source_is_never_modified_while_open() {
1543        let dir = tempdir().unwrap();
1544        let source = dir.path().join("source.bin");
1545        let dest = dir.path().join("dest.bin");
1546        let ratios = SplitRatios::default();
1547
1548        // Seed the source with known state.
1549        let seeded = FileSplitStore::open(&source, ratios, 77).unwrap();
1550        let state = PersistedSamplerState {
1551            source_cycle_idx: 5,
1552            source_record_cursors: vec![("s".to_string(), 3)],
1553            epoch: 9,
1554            epoch_step: 0,
1555            rng_state: 42,
1556            triplet_recipe_rr_idx: 1,
1557            text_recipe_rr_idx: 2,
1558            source_stream_cursors: vec![("s".to_string(), 4)],
1559        };
1560        seeded.save_sampler_state(&state, None).unwrap();
1561        drop(seeded);
1562
1563        let source_size_before = std::fs::metadata(&source).unwrap().len();
1564
1565        // Open bootstrapped store and write mutations to it.
1566        let bootstrapped =
1567            FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 77).unwrap();
1568        let new_state = PersistedSamplerState {
1569            source_cycle_idx: 99,
1570            source_record_cursors: vec![("s".to_string(), 77)],
1571            epoch: 100,
1572            epoch_step: 0,
1573            rng_state: 0,
1574            triplet_recipe_rr_idx: 0,
1575            text_recipe_rr_idx: 0,
1576            source_stream_cursors: vec![],
1577        };
1578        bootstrapped.save_sampler_state(&new_state, None).unwrap();
1579        drop(bootstrapped);
1580
1581        // Source must be byte-identical to before bootstrap.
1582        let source_size_after = std::fs::metadata(&source).unwrap().len();
1583        assert_eq!(
1584            source_size_before, source_size_after,
1585            "source file was modified during bootstrapped open"
1586        );
1587
1588        // Verify source still holds the original state.
1589        let verify_source = FileSplitStore::open(&source, ratios, 77).unwrap();
1590        let loaded = verify_source.load_sampler_state().unwrap().unwrap();
1591        assert_eq!(loaded.source_cycle_idx, 5);
1592        assert_eq!(loaded.epoch, 9);
1593    }
1594
1595    /// `save_sampler_state(None)` on a bootstrapped store must publish to the
1596    /// declared `save_path` only, not to the source and not to a temp path.
1597    #[test]
1598    fn save_none_on_bootstrapped_store_publishes_to_save_path() {
1599        let dir = tempdir().unwrap();
1600        let source = dir.path().join("load.bin");
1601        let dest = dir.path().join("save.bin");
1602        let ratios = SplitRatios::default();
1603
1604        let _ = FileSplitStore::open(&source, ratios, 11).unwrap();
1605
1606        assert!(!dest.exists(), "dest must not exist before first save");
1607
1608        let store = FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 11).unwrap();
1609        let state = PersistedSamplerState {
1610            source_cycle_idx: 7,
1611            source_record_cursors: vec![],
1612            epoch: 2,
1613            epoch_step: 0,
1614            rng_state: 1,
1615            triplet_recipe_rr_idx: 0,
1616            text_recipe_rr_idx: 0,
1617            source_stream_cursors: vec![],
1618        };
1619        store.save_sampler_state(&state, None).unwrap();
1620        drop(store);
1621
1622        assert!(dest.exists(), "dest must exist after save(None)");
1623
1624        let loaded_dest = FileSplitStore::open(&dest, ratios, 11).unwrap();
1625        assert_eq!(
1626            loaded_dest
1627                .load_sampler_state()
1628                .unwrap()
1629                .unwrap()
1630                .source_cycle_idx,
1631            7
1632        );
1633    }
1634
1635    /// `save_sampler_state(Some(other))` on a bootstrapped store must publish to
1636    /// `other` only — the declared `save_path` must remain absent.
1637    #[test]
1638    fn save_some_on_bootstrapped_store_publishes_to_explicit_path_only() {
1639        let dir = tempdir().unwrap();
1640        let source = dir.path().join("load.bin");
1641        let save = dir.path().join("save.bin"); // canonical — should stay empty
1642        let other = dir.path().join("other.bin"); // explicit target
1643        let ratios = SplitRatios::default();
1644
1645        let _ = FileSplitStore::open(&source, ratios, 22).unwrap();
1646
1647        let store = FileSplitStore::open_with_load_path(Some(&source), &save, ratios, 22).unwrap();
1648        let state = PersistedSamplerState {
1649            source_cycle_idx: 3,
1650            source_record_cursors: vec![],
1651            epoch: 1,
1652            epoch_step: 0,
1653            rng_state: 0,
1654            triplet_recipe_rr_idx: 0,
1655            text_recipe_rr_idx: 0,
1656            source_stream_cursors: vec![],
1657        };
1658        store
1659            .save_sampler_state(&state, Some(other.as_path()))
1660            .unwrap();
1661        drop(store);
1662
1663        assert!(
1664            !save.exists(),
1665            "canonical save_path must not be created when saving to explicit path"
1666        );
1667        assert!(other.exists(), "explicit target must be created");
1668
1669        let loaded_other = FileSplitStore::open(&other, ratios, 22).unwrap();
1670        assert_eq!(
1671            loaded_other
1672                .load_sampler_state()
1673                .unwrap()
1674                .unwrap()
1675                .source_cycle_idx,
1676            3
1677        );
1678    }
1679
1680    /// Repeated `save(None)` calls on a bootstrapped store are idempotent: the
1681    /// canonical save_path is overwritten cleanly each time.
1682    #[test]
1683    fn repeated_save_none_on_bootstrapped_store_is_idempotent() {
1684        let dir = tempdir().unwrap();
1685        let source = dir.path().join("load.bin");
1686        let dest = dir.path().join("save.bin");
1687        let ratios = SplitRatios::default();
1688
1689        let _ = FileSplitStore::open(&source, ratios, 33).unwrap();
1690
1691        let store = FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 33).unwrap();
1692
1693        for cycle_idx in [1_u64, 2, 3] {
1694            let state = PersistedSamplerState {
1695                source_cycle_idx: cycle_idx,
1696                source_record_cursors: vec![],
1697                epoch: 0,
1698                epoch_step: 0,
1699                rng_state: 0,
1700                triplet_recipe_rr_idx: 0,
1701                text_recipe_rr_idx: 0,
1702                source_stream_cursors: vec![],
1703            };
1704            store.save_sampler_state(&state, None).unwrap();
1705        }
1706        drop(store);
1707
1708        let reloaded = FileSplitStore::open(&dest, ratios, 33).unwrap();
1709        assert_eq!(
1710            reloaded
1711                .load_sampler_state()
1712                .unwrap()
1713                .unwrap()
1714                .source_cycle_idx,
1715            3,
1716            "dest should hold the last-saved state"
1717        );
1718    }
1719}