Skip to main content

triplets_core/
splits.rs

1use serde::{Deserialize, Serialize};
2use simd_r_drive::storage_engine::DataStore;
3use simd_r_drive::storage_engine::traits::{DataStoreReader, DataStoreWriter};
4use std::collections::HashMap;
5use std::fmt;
6use std::fs;
7use std::hash::{Hash, Hasher};
8use std::io;
9use std::path::{Path, PathBuf};
10use std::sync::RwLock;
11use tempfile::TempDir;
12
13use crate::constants::splits::{
14    ALL_SPLITS, BITCODE_PREFIX, EPOCH_HASH_RECORD_VERSION, EPOCH_HASHES_PREFIX, EPOCH_META_PREFIX,
15    EPOCH_META_RECORD_VERSION, EPOCH_RECORD_TOMBSTONE, META_KEY, SAMPLER_STATE_KEY,
16    SAMPLER_STATE_RECORD_VERSION, SPLIT_PREFIX, STORE_VERSION,
17};
18use crate::data::RecordId;
19use crate::errors::SamplerError;
20use crate::types::SourceId;
21
/// Logical dataset partitions used during sampling.
///
/// The type is `Copy`, `Eq` and `Hash`, which lets it serve as the key of the
/// per-split maps handled by the epoch-state stores in this module.
#[derive(
    Clone,
    Copy,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    bitcode::Encode,
    bitcode::Decode,
)]
pub enum SplitLabel {
    /// Training split.
    Train,
    /// Validation split.
    Validation,
    /// Test split.
    Test,
}
43
/// Ratio configuration for train/validation/test assignment.
///
/// Constructing the struct performs no validation; call
/// [`SplitRatios::normalized`] to check that the three fractions sum to `1.0`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, bitcode::Encode, bitcode::Decode)]
pub struct SplitRatios {
    /// Fraction assigned to train.
    pub train: f32,
    /// Fraction assigned to validation.
    pub validation: f32,
    /// Fraction assigned to test.
    pub test: f32,
}
54
55impl Default for SplitRatios {
56    fn default() -> Self {
57        Self {
58            train: 0.8,
59            validation: 0.1,
60            test: 0.1,
61        }
62    }
63}
64
65impl SplitRatios {
66    /// Validate that ratios sum to `1.0` (within epsilon).
67    pub fn normalized(self) -> Result<Self, SamplerError> {
68        let sum = self.train + self.validation + self.test;
69        if (sum - 1.0).abs() > 1e-6 {
70            return Err(SamplerError::Configuration(
71                "split ratios must sum to 1.0".to_string(),
72            ));
73        }
74        Ok(self)
75    }
76}
77
78pub use crate::constants::splits::EPOCH_STATE_VERSION;
79
/// Persisted epoch cursor metadata for one split.
///
/// Stored per [`SplitLabel`] by [`EpochStateStore`] implementations.
#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
pub struct PersistedSplitMeta {
    /// Current epoch for this split.
    pub epoch: u64,
    /// Cursor offset within the epoch hash list.
    pub offset: u64,
    /// Checksum of the persisted hash list.
    pub hashes_checksum: u64,
}
90
/// Persisted deterministic epoch hash ordering for one split.
///
/// Stored per [`SplitLabel`] by [`EpochStateStore`] implementations.
#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
pub struct PersistedSplitHashes {
    /// Checksum of `hashes`.
    pub checksum: u64,
    /// Deterministic per-epoch hash ordering.
    pub hashes: Vec<u64>,
}
99
/// Persisted sampler runtime state (cursors, recipe indices, RNG).
///
/// Loaded and saved as a single record through [`SamplerStateStore`].
#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
pub struct PersistedSamplerState {
    /// Source-cycle round-robin index.
    pub source_cycle_idx: u64,
    /// Per-source record cursors.
    pub source_record_cursors: Vec<(SourceId, u64)>,
    /// Current source epoch used for deterministic reshuffle.
    pub source_epoch: u64,
    /// Deterministic RNG internal state.
    pub rng_state: u64,
    /// Round-robin index for triplet recipes.
    pub triplet_recipe_rr_idx: u64,
    /// Round-robin index for text recipes.
    pub text_recipe_rr_idx: u64,
    /// Persisted source stream refresh cursors.
    pub source_stream_cursors: Vec<(SourceId, u64)>,
}
118
/// Split assignment backend.
///
/// Implementations map `RecordId` values to split labels deterministically.
/// All methods take `&self`, and the trait requires `Send + Sync`, so
/// implementations must handle any mutation through interior mutability.
pub trait SplitStore: Send + Sync {
    /// Return split label for `id` if known/derivable.
    fn label_for(&self, id: &RecordId) -> Option<SplitLabel>;
    /// Persist an explicit split assignment for `id`.
    fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError>;
    /// Return configured split ratios.
    fn ratios(&self) -> SplitRatios;
    /// Return the split label for `id`, creating/deriving one when needed.
    fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError>;
}
132
/// Persistence backend for epoch metadata and epoch hash orderings.
///
/// Metadata is loaded and saved as a whole split→meta map, whereas hash lists
/// are loaded and saved one split at a time.
pub trait EpochStateStore: Send + Sync {
    /// Load split→epoch metadata map.
    fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError>;
    /// Load persisted epoch hashes for one split, if available.
    fn load_epoch_hashes(
        &self,
        label: SplitLabel,
    ) -> Result<Option<PersistedSplitHashes>, SamplerError>;
    /// Persist split→epoch metadata map.
    fn save_epoch_meta(
        &self,
        meta: &HashMap<SplitLabel, PersistedSplitMeta>,
    ) -> Result<(), SamplerError>;
    /// Persist epoch hash list for one split.
    fn save_epoch_hashes(
        &self,
        label: SplitLabel,
        hashes: &PersistedSplitHashes,
    ) -> Result<(), SamplerError>;
}
154
/// Persistence backend for sampler runtime state.
pub trait SamplerStateStore: Send + Sync {
    /// Load persisted sampler runtime state, if present.
    fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError>;
    /// Save sampler runtime state, optionally mirroring to `save_path`.
    ///
    /// How `save_path` is interpreted is up to the implementation: the
    /// in-memory store in this module ignores it, while the file-backed store
    /// uses it as an alternative publish destination.
    fn save_sampler_state(
        &self,
        state: &PersistedSamplerState,
        save_path: Option<&Path>,
    ) -> Result<(), SamplerError>;
}
166
/// In-memory split store with deterministic assignment derivation.
///
/// Nothing is written to disk; epoch and sampler state live in `RwLock`-guarded
/// fields for the lifetime of the value.
pub struct DeterministicSplitStore {
    /// Ratios used when deriving a label from a record id.
    ratios: SplitRatios,
    /// Explicit per-record assignments recorded through `upsert`.
    assignments: RwLock<HashMap<RecordId, SplitLabel>>,
    /// Seed mixed into the record-id hash during derivation.
    seed: u64,
    /// Last saved split→epoch metadata map.
    epoch_meta: RwLock<HashMap<SplitLabel, PersistedSplitMeta>>,
    /// Last saved epoch hash list per split.
    epoch_hashes: RwLock<HashMap<SplitLabel, PersistedSplitHashes>>,
    /// Last saved sampler runtime state, if any.
    sampler_state: RwLock<Option<PersistedSamplerState>>,
}
176
177impl DeterministicSplitStore {
178    /// Create an in-memory split store configured with `ratios` and `seed`.
179    pub fn new(ratios: SplitRatios, seed: u64) -> Result<Self, SamplerError> {
180        ratios.normalized()?;
181        Ok(Self {
182            ratios,
183            assignments: RwLock::new(HashMap::new()),
184            seed,
185            epoch_meta: RwLock::new(HashMap::new()),
186            epoch_hashes: RwLock::new(HashMap::new()),
187            sampler_state: RwLock::new(None),
188        })
189    }
190
191    fn derive_label(&self, id: &RecordId) -> SplitLabel {
192        derive_label_for_id(id, self.seed, self.ratios)
193    }
194}
195
196impl SplitStore for DeterministicSplitStore {
197    fn label_for(&self, id: &RecordId) -> Option<SplitLabel> {
198        if let Some(label) = self.assignments.read().ok()?.get(id).copied() {
199            return Some(label);
200        }
201        Some(self.derive_label(id))
202    }
203
204    fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError> {
205        let mut guard = self
206            .assignments
207            .write()
208            .map_err(|_| SamplerError::SplitStore("lock poisoned".into()))?;
209        guard.insert(id, label);
210        Ok(())
211    }
212
213    fn ratios(&self) -> SplitRatios {
214        self.ratios
215    }
216
217    fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError> {
218        Ok(self.derive_label(&id))
219    }
220}
221
222impl EpochStateStore for DeterministicSplitStore {
223    fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError> {
224        self.epoch_meta
225            .read()
226            .map_err(|_| SamplerError::SplitStore("epoch meta lock poisoned".into()))
227            .map(|guard| guard.clone())
228    }
229
230    fn load_epoch_hashes(
231        &self,
232        label: SplitLabel,
233    ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
234        Ok(self
235            .epoch_hashes
236            .read()
237            .map_err(|_| SamplerError::SplitStore("epoch hashes lock poisoned".into()))?
238            .get(&label)
239            .cloned())
240    }
241
242    fn save_epoch_meta(
243        &self,
244        meta: &HashMap<SplitLabel, PersistedSplitMeta>,
245    ) -> Result<(), SamplerError> {
246        *self
247            .epoch_meta
248            .write()
249            .map_err(|_| SamplerError::SplitStore("epoch meta lock poisoned".into()))? =
250            meta.clone();
251        Ok(())
252    }
253
254    fn save_epoch_hashes(
255        &self,
256        label: SplitLabel,
257        hashes: &PersistedSplitHashes,
258    ) -> Result<(), SamplerError> {
259        self.epoch_hashes
260            .write()
261            .map_err(|_| SamplerError::SplitStore("epoch hashes lock poisoned".into()))?
262            .insert(label, hashes.clone());
263        Ok(())
264    }
265}
266
267impl SamplerStateStore for DeterministicSplitStore {
268    fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError> {
269        self.sampler_state
270            .read()
271            .map_err(|_| SamplerError::SplitStore("sampler state lock poisoned".into()))
272            .map(|guard| guard.clone())
273    }
274
275    fn save_sampler_state(
276        &self,
277        state: &PersistedSamplerState,
278        _save_path: Option<&Path>,
279    ) -> Result<(), SamplerError> {
280        *self
281            .sampler_state
282            .write()
283            .map_err(|_| SamplerError::SplitStore("sampler state lock poisoned".into()))? =
284            Some(state.clone());
285        Ok(())
286    }
287}
288
/// Versioned metadata header stored in file-backed split stores.
///
/// Written under `META_KEY` when a store is first created and compared
/// against the requested version, seed and ratios on every later open.
#[derive(Clone, Copy, Debug, bitcode::Encode, bitcode::Decode)]
struct StoreMeta {
    /// Store format version; checked against `STORE_VERSION`.
    version: u8,
    /// Seed the store was created with.
    seed: u64,
    /// Ratios the store was created with.
    ratios: SplitRatios,
}
296
297fn encode_store_meta(meta: &StoreMeta) -> Vec<u8> {
298    encode_bitcode_payload(&bitcode::encode(meta))
299}
300
301fn decode_store_meta(bytes: &[u8]) -> Result<StoreMeta, SamplerError> {
302    let raw = decode_bitcode_payload(bytes)?;
303    bitcode::decode(&raw).map_err(|err| {
304        SamplerError::SplitStore(format!("failed to decode split store metadata: {err}"))
305    })
306}
307
/// File-backed split store for persistent runs.
///
/// Persists assignment metadata, epoch state, and sampler runtime state.
///
/// The store **always** works against a private temporary copy of the source
/// snapshot.  All reads and mutations accumulate in the temp file.
/// State is published to permanent storage only when
/// [`SamplerStateStore::save_sampler_state`] is called:
///
/// * `save_to == None`  → publish temp to `save_path` (may overwrite).
/// * `save_to == Some(p)` → publish temp to `p` only; `save_path` is left
///   untouched.
///
/// This guarantees that the original source file is never modified and that
/// no partial state leaks to the target before an explicit save.
pub struct FileSplitStore {
    /// Handle opened on the working temp file (`path`).
    store: DataStore,
    /// Working path: always a private temp file; all reads and writes go here.
    path: PathBuf,
    /// Declared save destination; published to on `save_sampler_state(None)`.
    save_path: PathBuf,
    /// Ratios validated at open time and checked against the stored header.
    ratios: SplitRatios,
    /// Seed checked against the stored header and used for label derivation.
    seed: u64,
    /// Keeps the temporary directory alive for the lifetime of this store.
    _temp_dir: TempDir,
}
334
335impl fmt::Debug for FileSplitStore {
336    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
337        f.debug_struct("FileSplitStore")
338            .field("path", &self.save_path)
339            .field("ratios", &self.ratios)
340            .field("seed", &self.seed)
341            .finish()
342    }
343}
344
345impl FileSplitStore {
346    /// Open (or create) a file-backed split store at `path`.
347    pub fn open<P: Into<PathBuf>>(
348        path: P,
349        ratios: SplitRatios,
350        seed: u64,
351    ) -> Result<Self, SamplerError> {
352        Self::open_with_load_path(None::<PathBuf>, path, ratios, seed)
353    }
354
355    /// Open (or create) a file-backed split store at `save_path`, optionally
356    /// bootstrapping initial state from `load_path`.
357    ///
358    /// Always stages through a **private temporary file**.  The source used to
359    /// seed the temp is chosen as follows:
360    ///
361    /// 1. `load_path` if supplied and it exists.
362    /// 2. `save_path` if it already exists.
363    /// 3. Nothing — a fresh empty store is created in the temp.
364    ///
365    /// All mutations accumulate in the temp.  State is published to permanent
366    /// storage only when [`SamplerStateStore::save_sampler_state`] is called.
367    /// The original source file is never modified.
368    pub fn open_with_load_path<LP: Into<PathBuf>, SP: Into<PathBuf>>(
369        load_path: Option<LP>,
370        save_path: SP,
371        ratios: SplitRatios,
372        seed: u64,
373    ) -> Result<Self, SamplerError> {
374        let ratios = ratios.normalized()?;
375        let save_path = coerce_store_path(save_path.into());
376        ensure_parent_dir(&save_path)?;
377
378        // Determine source to seed the working temp from.
379        // Priority: explicit load_path > existing save_path > fresh (nothing).
380        let source: Option<PathBuf> = if let Some(lp) = load_path {
381            let lp = coerce_store_path(lp.into());
382            if lp.exists() { Some(lp) } else { None }
383        } else if save_path.exists() {
384            Some(save_path.clone())
385        } else {
386            None
387        };
388
389        let temp_dir = tempfile::tempdir().map_err(|err| {
390            SamplerError::SplitStore(format!("failed to create temp dir for split store: {err}"))
391        })?;
392        let working_path = temp_dir.path().join("working_store.bin");
393        if let Some(src) = &source {
394            fs::copy(src, &working_path).map_err(|err| {
395                SamplerError::SplitStore(format!(
396                    "failed to copy split store from '{}' to temp: {err}",
397                    src.display()
398                ))
399            })?;
400        }
401
402        let raw_store = DataStore::open(&working_path).map_err(map_store_err)?;
403        let store = Self {
404            store: raw_store,
405            path: working_path,
406            save_path,
407            ratios,
408            seed,
409            _temp_dir: temp_dir,
410        };
411        store.verify_metadata()?;
412        Ok(store)
413    }
414
415    fn verify_metadata(&self) -> Result<(), SamplerError> {
416        match read_bytes(&self.store, META_KEY)? {
417            Some(bytes) => {
418                let meta = decode_store_meta(&bytes)?;
419                if meta.version != STORE_VERSION {
420                    return Err(SamplerError::SplitStore(format!(
421                        "split store version mismatch (expected {}, found {})",
422                        STORE_VERSION, meta.version
423                    )));
424                }
425                if meta.seed != self.seed {
426                    return Err(SamplerError::SplitStore(format!(
427                        "split store seed mismatch (expected {}, found {})",
428                        self.seed, meta.seed
429                    )));
430                }
431                if !ratios_close(meta.ratios, self.ratios) {
432                    return Err(SamplerError::SplitStore(
433                        "split store ratios mismatch".into(),
434                    ));
435                }
436            }
437            None => {
438                let blob = StoreMeta {
439                    version: STORE_VERSION,
440                    seed: self.seed,
441                    ratios: self.ratios,
442                };
443                let payload = encode_store_meta(&blob);
444                write_bytes(&self.store, META_KEY, &payload)?;
445            }
446        }
447        Ok(())
448    }
449
450    fn read_epoch_meta_entry(
451        &self,
452        label: SplitLabel,
453    ) -> Result<Option<PersistedSplitMeta>, SamplerError> {
454        let key = epoch_meta_key(label);
455        let entry = self.store.read(&key).map_err(map_store_err)?;
456        match entry {
457            None => Ok(None),
458            Some(bytes) => decode_epoch_meta(bytes.as_ref()),
459        }
460    }
461
462    fn write_epoch_meta_entry(
463        &self,
464        label: SplitLabel,
465        meta: Option<&PersistedSplitMeta>,
466    ) -> Result<(), SamplerError> {
467        let key = epoch_meta_key(label);
468        let payload = encode_epoch_meta(meta);
469        self.store
470            .write(&key, payload.as_slice())
471            .map_err(map_store_err)?;
472        Ok(())
473    }
474
475    fn read_epoch_hashes_entry(
476        &self,
477        label: SplitLabel,
478    ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
479        let key = epoch_hashes_key(label);
480        let entry = self.store.read(&key).map_err(map_store_err)?;
481        match entry {
482            None => Ok(None),
483            Some(bytes) => decode_epoch_hashes(bytes.as_ref()),
484        }
485    }
486
487    fn write_epoch_hashes_entry(
488        &self,
489        label: SplitLabel,
490        hashes: &PersistedSplitHashes,
491    ) -> Result<(), SamplerError> {
492        let key = epoch_hashes_key(label);
493        let payload = encode_epoch_hashes(hashes);
494        self.store
495            .write(&key, payload.as_slice())
496            .map_err(map_store_err)?;
497        Ok(())
498    }
499}
500
501impl SplitStore for FileSplitStore {
502    fn label_for(&self, id: &RecordId) -> Option<SplitLabel> {
503        let key = split_key(id);
504        if let Ok(Some(value)) = self.store.read(&key)
505            && let Ok(label) = decode_label(value.as_ref())
506        {
507            return Some(label);
508        }
509        Some(derive_label_for_id(id, self.seed, self.ratios))
510    }
511
512    fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError> {
513        let _ = (id, label);
514        Ok(())
515    }
516
517    fn ratios(&self) -> SplitRatios {
518        self.ratios
519    }
520
521    fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError> {
522        Ok(derive_label_for_id(&id, self.seed, self.ratios))
523    }
524}
525
526impl EpochStateStore for FileSplitStore {
527    fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError> {
528        let mut meta = HashMap::new();
529        for label in ALL_SPLITS {
530            if let Some(entry) = self.read_epoch_meta_entry(label)? {
531                meta.insert(label, entry);
532            }
533        }
534        Ok(meta)
535    }
536
537    fn load_epoch_hashes(
538        &self,
539        label: SplitLabel,
540    ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
541        self.read_epoch_hashes_entry(label)
542    }
543
544    fn save_epoch_meta(
545        &self,
546        meta: &HashMap<SplitLabel, PersistedSplitMeta>,
547    ) -> Result<(), SamplerError> {
548        for label in ALL_SPLITS {
549            self.write_epoch_meta_entry(label, meta.get(&label))?;
550        }
551        Ok(())
552    }
553
554    fn save_epoch_hashes(
555        &self,
556        label: SplitLabel,
557        hashes: &PersistedSplitHashes,
558    ) -> Result<(), SamplerError> {
559        self.write_epoch_hashes_entry(label, hashes)
560    }
561}
562
563impl SamplerStateStore for FileSplitStore {
564    fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError> {
565        match read_bytes(&self.store, SAMPLER_STATE_KEY)? {
566            Some(bytes) => decode_sampler_state(bytes.as_ref()),
567            None => Ok(None),
568        }
569    }
570
571    /// Persist `state` to the working temp store and publish to the destination.
572    ///
573    /// * `save_to == None`  → publish temp to `save_path` (may overwrite).
574    /// * `save_to == Some(p)` → publish temp to `p` only; `save_path` is left
575    ///   untouched.  `p` must not already exist.
576    fn save_sampler_state(
577        &self,
578        state: &PersistedSamplerState,
579        save_to: Option<&Path>,
580    ) -> Result<(), SamplerError> {
581        // Determine publish destination before writing anything.
582        let dest = if let Some(p) = save_to {
583            coerce_store_path(p.to_path_buf())
584        } else {
585            self.save_path.clone()
586        };
587
588        // Refuse to overwrite an explicitly-named destination that already exists.
589        // Saving back to the canonical save_path (None) is always allowed.
590        if save_to.is_some() && dest.exists() {
591            return Err(SamplerError::SplitStore(format!(
592                "refusing to overwrite existing split store '{}'; choose a new path",
593                dest.display()
594            )));
595        }
596
597        // Write state into the working temp store.
598        let payload = encode_sampler_state(state);
599        write_bytes(&self.store, SAMPLER_STATE_KEY, &payload)?;
600
601        // Publish: copy the temp store to the destination.
602        ensure_parent_dir(&dest)?;
603        fs::copy(&self.path, &dest).map_err(|err| {
604            SamplerError::SplitStore(format!(
605                "failed to publish split store to '{}': {err}",
606                dest.display()
607            ))
608        })?;
609
610        Ok(())
611    }
612}
613
614fn decode_label(bytes: &[u8]) -> Result<SplitLabel, SamplerError> {
615    match bytes.first() {
616        Some(b'0') => Ok(SplitLabel::Train),
617        Some(b'1') => Ok(SplitLabel::Validation),
618        Some(b'2') => Ok(SplitLabel::Test),
619        _ => Err(SamplerError::SplitStore("invalid split label".into())),
620    }
621}
622
623fn derive_label_for_id(id: &RecordId, seed: u64, ratios: SplitRatios) -> SplitLabel {
624    let mut hasher = std::collections::hash_map::DefaultHasher::new();
625    id.hash(&mut hasher);
626    seed.hash(&mut hasher);
627    let value = hasher.finish() as f64 / u64::MAX as f64;
628    let train_cut = ratios.train as f64;
629    let val_cut = train_cut + ratios.validation as f64;
630    if value < train_cut {
631        SplitLabel::Train
632    } else if value < val_cut {
633        SplitLabel::Validation
634    } else {
635        SplitLabel::Test
636    }
637}
638
639fn ratios_close(a: SplitRatios, b: SplitRatios) -> bool {
640    ((a.train - b.train).abs() + (a.validation - b.validation).abs() + (a.test - b.test).abs())
641        < 1e-5
642}
643
644fn split_key(id: &RecordId) -> Vec<u8> {
645    let mut key = Vec::with_capacity(SPLIT_PREFIX.len() + id.len());
646    key.extend_from_slice(SPLIT_PREFIX);
647    key.extend_from_slice(id.as_bytes());
648    key
649}
650
651fn read_bytes(store: &DataStore, key: &[u8]) -> Result<Option<Vec<u8>>, SamplerError> {
652    store
653        .read(key)
654        .map_err(map_store_err)?
655        .map(|entry| Ok(entry.as_ref().to_vec()))
656        .transpose()
657}
658
659fn write_bytes(store: &DataStore, key: &[u8], payload: &[u8]) -> Result<(), SamplerError> {
660    store.write(key, payload).map_err(map_store_err)?;
661    Ok(())
662}
663
664fn epoch_meta_key(label: SplitLabel) -> Vec<u8> {
665    let mut key = Vec::with_capacity(EPOCH_META_PREFIX.len() + 12);
666    key.extend_from_slice(EPOCH_META_PREFIX);
667    let suffix = match label {
668        SplitLabel::Train => b"train".as_ref(),
669        SplitLabel::Validation => b"validation".as_ref(),
670        SplitLabel::Test => b"test".as_ref(),
671    };
672    key.extend_from_slice(suffix);
673    key
674}
675
676fn epoch_hashes_key(label: SplitLabel) -> Vec<u8> {
677    let mut key = Vec::with_capacity(EPOCH_HASHES_PREFIX.len() + 12);
678    key.extend_from_slice(EPOCH_HASHES_PREFIX);
679    let suffix = match label {
680        SplitLabel::Train => b"train".as_ref(),
681        SplitLabel::Validation => b"validation".as_ref(),
682        SplitLabel::Test => b"test".as_ref(),
683    };
684    key.extend_from_slice(suffix);
685    key
686}
687
688fn encode_epoch_meta(meta: Option<&PersistedSplitMeta>) -> Vec<u8> {
689    match meta {
690        None => vec![EPOCH_RECORD_TOMBSTONE],
691        Some(meta) => {
692            let payload = encode_bitcode_payload(&bitcode::encode(meta));
693            let mut buf = Vec::with_capacity(1 + payload.len());
694            buf.push(EPOCH_META_RECORD_VERSION);
695            buf.extend_from_slice(&payload);
696            buf
697        }
698    }
699}
700
701fn decode_epoch_meta(bytes: &[u8]) -> Result<Option<PersistedSplitMeta>, SamplerError> {
702    if bytes.is_empty() || bytes[0] == EPOCH_RECORD_TOMBSTONE {
703        return Ok(None);
704    }
705    if bytes[0] != EPOCH_META_RECORD_VERSION {
706        return Err(SamplerError::SplitStore(
707            "epoch meta record version mismatch".into(),
708        ));
709    }
710    let raw = decode_bitcode_payload(&bytes[1..])?;
711    bitcode::decode(&raw)
712        .map(Some)
713        .map_err(|err| SamplerError::SplitStore(format!("corrupt epoch meta record: {err}")))
714}
715
716fn encode_epoch_hashes(hashes: &PersistedSplitHashes) -> Vec<u8> {
717    let payload = encode_bitcode_payload(&bitcode::encode(hashes));
718    let mut buf = Vec::with_capacity(1 + payload.len());
719    buf.push(EPOCH_HASH_RECORD_VERSION);
720    buf.extend_from_slice(&payload);
721    buf
722}
723
724fn decode_epoch_hashes(bytes: &[u8]) -> Result<Option<PersistedSplitHashes>, SamplerError> {
725    if bytes.is_empty() || bytes[0] == EPOCH_RECORD_TOMBSTONE {
726        return Ok(None);
727    }
728    if bytes[0] != EPOCH_HASH_RECORD_VERSION {
729        return Err(SamplerError::SplitStore(
730            "epoch hashes record version mismatch".into(),
731        ));
732    }
733    let raw = decode_bitcode_payload(&bytes[1..])?;
734    bitcode::decode(&raw)
735        .map(Some)
736        .map_err(|err| SamplerError::SplitStore(format!("corrupt epoch hashes record: {err}")))
737}
738
739fn encode_sampler_state(state: &PersistedSamplerState) -> Vec<u8> {
740    let payload = encode_bitcode_payload(&bitcode::encode(state));
741    let mut buf = Vec::with_capacity(1 + payload.len());
742    buf.push(SAMPLER_STATE_RECORD_VERSION);
743    buf.extend_from_slice(&payload);
744    buf
745}
746
747fn decode_sampler_state(bytes: &[u8]) -> Result<Option<PersistedSamplerState>, SamplerError> {
748    if bytes.is_empty() {
749        return Ok(None);
750    }
751    if bytes[0] != SAMPLER_STATE_RECORD_VERSION {
752        return Err(SamplerError::SplitStore(
753            "sampler state record version mismatch".into(),
754        ));
755    }
756    let raw = decode_bitcode_payload(&bytes[1..])?;
757    bitcode::decode(&raw)
758        .map(Some)
759        .map_err(|err| SamplerError::SplitStore(format!("corrupt sampler state record: {err}")))
760}
761
762fn encode_bitcode_payload(bytes: &[u8]) -> Vec<u8> {
763    let mut out = Vec::with_capacity(1 + bytes.len());
764    out.push(BITCODE_PREFIX);
765    out.extend_from_slice(bytes);
766    out
767}
768
769fn decode_bitcode_payload(bytes: &[u8]) -> Result<Vec<u8>, SamplerError> {
770    if bytes.first().copied() != Some(BITCODE_PREFIX) {
771        return Err(SamplerError::SplitStore(
772            "bitcode payload missing expected prefix".into(),
773        ));
774    }
775    Ok(bytes[1..].to_vec())
776}
777
/// Normalization hook for store paths.
///
/// Currently the identity function: `path` is returned unchanged.
fn coerce_store_path(path: PathBuf) -> PathBuf {
    path
}
781
782fn ensure_parent_dir(path: &Path) -> Result<(), SamplerError> {
783    if let Some(parent) = path.parent()
784        && !parent.as_os_str().is_empty()
785    {
786        fs::create_dir_all(parent)?;
787    }
788    Ok(())
789}
790
791fn map_store_err(err: io::Error) -> SamplerError {
792    SamplerError::SplitStore(err.to_string())
793}
794
795#[cfg(test)]
796mod tests {
797    use super::*;
798    use std::collections::HashMap;
799    use tempfile::tempdir;
800
801    #[test]
802    fn split_ratios_reject_non_unit_sum() {
803        let invalid = SplitRatios {
804            train: 0.6,
805            validation: 0.3,
806            test: 0.3,
807        };
808
809        let err = DeterministicSplitStore::new(invalid, 1)
810            .err()
811            .expect("expected non-unit split ratios to fail");
812        assert!(matches!(
813            err,
814            SamplerError::Configuration(ref msg) if msg.contains("split ratios must sum to 1.0")
815        ));
816
817        let dir = tempdir().unwrap();
818        let path = dir.path().join("split_store.bin");
819        let err = FileSplitStore::open(&path, invalid, 1).unwrap_err();
820        assert!(matches!(
821            err,
822            SamplerError::Configuration(ref msg) if msg.contains("split ratios must sum to 1.0")
823        ));
824    }
825
826    #[test]
827    fn zero_test_ratio_never_assigns_test_labels() {
828        let ratios = SplitRatios {
829            train: 0.5,
830            validation: 0.5,
831            test: 0.0,
832        };
833        let store = DeterministicSplitStore::new(ratios, 7).unwrap();
834
835        let mut saw_train = false;
836        let mut saw_validation = false;
837        for idx in 0..20_000 {
838            let id = format!("record_{idx}");
839            let label = store.ensure(id).unwrap();
840            assert_ne!(label, SplitLabel::Test);
841            saw_train |= label == SplitLabel::Train;
842            saw_validation |= label == SplitLabel::Validation;
843            if saw_train && saw_validation {
844                break;
845            }
846        }
847
848        assert!(saw_train);
849        assert!(saw_validation);
850    }
851
852    /// `save(None)` must publish the temp to `save_path` on disk.  A fresh
853    /// `open` of that path must see everything that was written before the save.
854    #[test]
855    fn save_none_publishes_to_save_path_and_reloads_cleanly() {
856        let dir = tempdir().unwrap();
857        let path = dir.path().join("persist.bin");
858        let ratios = SplitRatios::default();
859
860        {
861            let store = FileSplitStore::open(&path, ratios, 123).unwrap();
862            let mut meta = HashMap::new();
863            meta.insert(
864                SplitLabel::Train,
865                PersistedSplitMeta {
866                    epoch: 3,
867                    offset: 7,
868                    hashes_checksum: 42,
869                },
870            );
871            store.save_epoch_meta(&meta).unwrap();
872            let state = PersistedSamplerState {
873                source_cycle_idx: 11,
874                source_record_cursors: vec![("s".to_string(), 1)],
875                source_epoch: 5,
876                rng_state: 99,
877                triplet_recipe_rr_idx: 2,
878                text_recipe_rr_idx: 3,
879                source_stream_cursors: vec![],
880            };
881            assert!(!path.exists(), "save_path must not exist before save(None)");
882            store.save_sampler_state(&state, None).unwrap();
883            assert!(
884                path.exists(),
885                "save(None) must publish to save_path on disk"
886            );
887        }
888
889        // Fresh open must read everything back from disk.
890        let reopened = FileSplitStore::open(&path, ratios, 123).unwrap();
891        let loaded_state = reopened.load_sampler_state().unwrap().unwrap();
892        assert_eq!(loaded_state.source_cycle_idx, 11);
893        assert_eq!(loaded_state.source_epoch, 5);
894        assert_eq!(loaded_state.rng_state, 99);
895        let loaded_meta = reopened.load_epoch_meta().unwrap();
896        assert_eq!(loaded_meta.get(&SplitLabel::Train).unwrap().epoch, 3);
897    }
898
899    #[test]
900    fn file_store_rejects_seed_mismatch() {
901        let dir = tempdir().unwrap();
902        let path = dir.path().join("splits.json");
903        let ratios = SplitRatios::default();
904        let store = FileSplitStore::open(&path, ratios, 123).unwrap();
905        store.ensure("abc".to_string()).unwrap();
906        // Publish to disk so the seed is committed before the next open.
907        store
908            .save_sampler_state(
909                &PersistedSamplerState {
910                    source_cycle_idx: 0,
911                    source_record_cursors: vec![],
912                    source_epoch: 0,
913                    rng_state: 0,
914                    triplet_recipe_rr_idx: 0,
915                    text_recipe_rr_idx: 0,
916                    source_stream_cursors: vec![],
917                },
918                None,
919            )
920            .unwrap();
921        drop(store);
922
923        let err = FileSplitStore::open(&path, ratios, 999).unwrap_err();
924        assert!(matches!(
925            err,
926            SamplerError::SplitStore(msg) if msg.contains("seed")
927        ));
928    }
929
930    #[test]
931    fn file_store_accepts_directory_path() {
932        let dir = tempdir().unwrap();
933        let ratios = SplitRatios::default();
934        let err = FileSplitStore::open(dir.path(), ratios, 777).unwrap_err();
935        assert!(matches!(err, SamplerError::SplitStore(_)));
936    }
937
938    #[test]
939    fn bitcode_payload_requires_prefix() {
940        let err = decode_bitcode_payload(&[0x00, 0x01]).unwrap_err();
941        assert!(
942            matches!(err, SamplerError::SplitStore(msg) if msg.contains("missing expected prefix"))
943        );
944    }
945
946    #[test]
947    fn file_store_round_trips_epoch_and_sampler_state() {
948        let dir = tempdir().unwrap();
949        let path = dir.path().join("epoch_sampler_state.bin");
950        let store = FileSplitStore::open(&path, SplitRatios::default(), 222).unwrap();
951
952        assert!(store.load_epoch_hashes(SplitLabel::Test).unwrap().is_none());
953
954        let mut epoch_meta = HashMap::new();
955        epoch_meta.insert(
956            SplitLabel::Train,
957            PersistedSplitMeta {
958                epoch: 3,
959                offset: 7,
960                hashes_checksum: 42,
961            },
962        );
963        store.save_epoch_meta(&epoch_meta).unwrap();
964
965        let loaded_meta = store.load_epoch_meta().unwrap();
966        let loaded_train = loaded_meta.get(&SplitLabel::Train).unwrap();
967        assert_eq!(loaded_train.epoch, 3);
968        assert_eq!(loaded_train.offset, 7);
969        assert_eq!(loaded_train.hashes_checksum, 42);
970
971        let hashes = PersistedSplitHashes {
972            checksum: 99,
973            hashes: vec![10, 20, 30],
974        };
975        store
976            .save_epoch_hashes(SplitLabel::Validation, &hashes)
977            .unwrap();
978        let loaded_hashes = store
979            .load_epoch_hashes(SplitLabel::Validation)
980            .unwrap()
981            .unwrap();
982        assert_eq!(loaded_hashes.checksum, 99);
983        assert_eq!(loaded_hashes.hashes, vec![10, 20, 30]);
984
985        let state = PersistedSamplerState {
986            source_cycle_idx: 11,
987            source_record_cursors: vec![("source_a".to_string(), 1)],
988            source_epoch: 8,
989            rng_state: 1234,
990            triplet_recipe_rr_idx: 2,
991            text_recipe_rr_idx: 5,
992            source_stream_cursors: vec![("source_a".to_string(), 9)],
993        };
994        store.save_sampler_state(&state, None).unwrap();
995        let loaded_state = store.load_sampler_state().unwrap().unwrap();
996        assert_eq!(loaded_state.source_cycle_idx, 11);
997        assert_eq!(loaded_state.source_epoch, 8);
998        assert_eq!(loaded_state.rng_state, 1234);
999        assert_eq!(loaded_state.triplet_recipe_rr_idx, 2);
1000        assert_eq!(loaded_state.text_recipe_rr_idx, 5);
1001        assert_eq!(
1002            loaded_state.source_record_cursors,
1003            vec![("source_a".to_string(), 1)]
1004        );
1005        assert_eq!(
1006            loaded_state.source_stream_cursors,
1007            vec![("source_a".to_string(), 9)]
1008        );
1009
1010        // Verify the same data survives a drop + fresh open from disk.
1011        drop(store);
1012        let reopened = FileSplitStore::open(&path, SplitRatios::default(), 222).unwrap();
1013        let disk_state = reopened.load_sampler_state().unwrap().unwrap();
1014        assert_eq!(disk_state.source_cycle_idx, 11);
1015        assert_eq!(disk_state.source_epoch, 8);
1016        assert_eq!(disk_state.rng_state, 1234);
1017        let disk_meta = reopened.load_epoch_meta().unwrap();
1018        assert_eq!(disk_meta.get(&SplitLabel::Train).unwrap().epoch, 3);
1019        let disk_hashes = reopened
1020            .load_epoch_hashes(SplitLabel::Validation)
1021            .unwrap()
1022            .unwrap();
1023        assert_eq!(disk_hashes.checksum, 99);
1024    }
1025
1026    #[test]
1027    fn split_keys_and_labels_cover_helper_paths() {
1028        let key = split_key(&"abc".to_string());
1029        assert!(key.starts_with(SPLIT_PREFIX));
1030
1031        assert!(matches!(decode_label(b"0"), Ok(SplitLabel::Train)));
1032        assert!(matches!(decode_label(b"1"), Ok(SplitLabel::Validation)));
1033        assert!(matches!(decode_label(b"2"), Ok(SplitLabel::Test)));
1034        assert!(decode_label(b"x").is_err());
1035
1036        let epoch_meta_train = epoch_meta_key(SplitLabel::Train);
1037        let epoch_hashes_train = epoch_hashes_key(SplitLabel::Train);
1038        let epoch_hashes_test = epoch_hashes_key(SplitLabel::Test);
1039        assert!(epoch_meta_train.starts_with(EPOCH_META_PREFIX));
1040        assert!(epoch_hashes_train.starts_with(EPOCH_HASHES_PREFIX));
1041        assert!(epoch_hashes_test.starts_with(EPOCH_HASHES_PREFIX));
1042    }
1043
1044    #[test]
1045    fn encode_decode_store_meta_roundtrip_and_corrupt_prefix_error() {
1046        let meta = StoreMeta {
1047            version: STORE_VERSION,
1048            seed: 55,
1049            ratios: SplitRatios::default(),
1050        };
1051        let encoded = encode_store_meta(&meta);
1052        let decoded = decode_store_meta(&encoded).unwrap();
1053        assert_eq!(decoded.version, STORE_VERSION);
1054        assert_eq!(decoded.seed, 55);
1055
1056        let err = decode_store_meta(&[0x00, 0x01]).unwrap_err();
1057        assert!(matches!(
1058            err,
1059            SamplerError::SplitStore(msg) if msg.contains("missing expected prefix")
1060        ));
1061    }
1062
1063    #[test]
1064    fn epoch_and_sampler_decoders_cover_tombstone_and_version_mismatch() {
1065        assert!(decode_epoch_meta(&[]).unwrap().is_none());
1066        assert!(
1067            decode_epoch_meta(&[EPOCH_RECORD_TOMBSTONE])
1068                .unwrap()
1069                .is_none()
1070        );
1071        assert!(decode_epoch_hashes(&[]).unwrap().is_none());
1072        assert!(
1073            decode_epoch_hashes(&[EPOCH_RECORD_TOMBSTONE])
1074                .unwrap()
1075                .is_none()
1076        );
1077        assert!(decode_sampler_state(&[]).unwrap().is_none());
1078
1079        let meta_mismatch = decode_epoch_meta(&[EPOCH_META_RECORD_VERSION.wrapping_add(1), 1]);
1080        assert!(matches!(
1081            meta_mismatch,
1082            Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1083        ));
1084        let hashes_mismatch = decode_epoch_hashes(&[EPOCH_HASH_RECORD_VERSION.wrapping_add(1), 1]);
1085        assert!(matches!(
1086            hashes_mismatch,
1087            Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1088        ));
1089        let state_mismatch =
1090            decode_sampler_state(&[SAMPLER_STATE_RECORD_VERSION.wrapping_add(1), 1]);
1091        assert!(matches!(
1092            state_mismatch,
1093            Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1094        ));
1095    }
1096
1097    #[test]
1098    fn split_store_trait_methods_and_path_helpers_are_exercised() {
1099        let dir = tempdir().unwrap();
1100        let file_path = dir.path().join("nested").join("store.bin");
1101        ensure_parent_dir(&file_path).unwrap();
1102        assert!(file_path.parent().unwrap().exists());
1103
1104        let existing_dir_path = coerce_store_path(dir.path().to_path_buf());
1105        assert_eq!(existing_dir_path, dir.path().to_path_buf());
1106
1107        let ratios = SplitRatios::default();
1108        let store = FileSplitStore::open(&file_path, ratios, 444).unwrap();
1109        assert!((store.ratios().train - ratios.train).abs() < 1e-6);
1110        store
1111            .upsert("record_1".to_string(), SplitLabel::Validation)
1112            .unwrap();
1113        let ensured = store.ensure("record_1".to_string()).unwrap();
1114        assert!(matches!(
1115            ensured,
1116            SplitLabel::Train | SplitLabel::Validation | SplitLabel::Test
1117        ));
1118
1119        let mapped = map_store_err(io::Error::other("boom"));
1120        assert!(matches!(mapped, SamplerError::SplitStore(msg) if msg.contains("boom")));
1121    }
1122
1123    #[test]
1124    fn epoch_and_sampler_encode_decode_roundtrips() {
1125        let meta = PersistedSplitMeta {
1126            epoch: 4,
1127            offset: 9,
1128            hashes_checksum: 21,
1129        };
1130        let encoded_meta = encode_epoch_meta(Some(&meta));
1131        let decoded_meta = decode_epoch_meta(&encoded_meta).unwrap().unwrap();
1132        assert_eq!(decoded_meta.epoch, 4);
1133        assert_eq!(decoded_meta.offset, 9);
1134
1135        let hashes = PersistedSplitHashes {
1136            checksum: 7,
1137            hashes: vec![1, 2, 3],
1138        };
1139        let encoded_hashes = encode_epoch_hashes(&hashes);
1140        let decoded_hashes = decode_epoch_hashes(&encoded_hashes).unwrap().unwrap();
1141        assert_eq!(decoded_hashes.checksum, 7);
1142        assert_eq!(decoded_hashes.hashes, vec![1, 2, 3]);
1143
1144        let state = PersistedSamplerState {
1145            source_cycle_idx: 1,
1146            source_record_cursors: vec![("s".to_string(), 2)],
1147            source_epoch: 3,
1148            rng_state: 4,
1149            triplet_recipe_rr_idx: 5,
1150            text_recipe_rr_idx: 6,
1151            source_stream_cursors: vec![("s".to_string(), 7)],
1152        };
1153        let encoded_state = encode_sampler_state(&state);
1154        let decoded_state = decode_sampler_state(&encoded_state).unwrap().unwrap();
1155        assert_eq!(decoded_state.source_cycle_idx, 1);
1156        assert_eq!(decoded_state.source_epoch, 3);
1157        assert_eq!(decoded_state.rng_state, 4);
1158    }
1159
1160    #[test]
1161    fn deterministic_store_trait_methods_work() {
1162        let ratios = SplitRatios::default();
1163        let store = DeterministicSplitStore::new(ratios, 999).unwrap();
1164
1165        assert_eq!(store.ratios().train, ratios.train);
1166
1167        let id = "source::record".to_string();
1168        let derived = store.label_for(&id).unwrap();
1169        store.upsert(id.clone(), SplitLabel::Validation).unwrap();
1170        assert_eq!(store.label_for(&id), Some(SplitLabel::Validation));
1171        assert!(matches!(
1172            derived,
1173            SplitLabel::Train | SplitLabel::Validation | SplitLabel::Test
1174        ));
1175
1176        let mut meta = HashMap::new();
1177        meta.insert(
1178            SplitLabel::Test,
1179            PersistedSplitMeta {
1180                epoch: 1,
1181                offset: 2,
1182                hashes_checksum: 3,
1183            },
1184        );
1185        store.save_epoch_meta(&meta).unwrap();
1186        let loaded_meta = store.load_epoch_meta().unwrap();
1187        assert_eq!(loaded_meta.get(&SplitLabel::Test).unwrap().offset, 2);
1188
1189        assert!(
1190            store
1191                .load_epoch_hashes(SplitLabel::Train)
1192                .unwrap()
1193                .is_none()
1194        );
1195        store
1196            .save_epoch_hashes(
1197                SplitLabel::Train,
1198                &PersistedSplitHashes {
1199                    checksum: 11,
1200                    hashes: vec![4, 5],
1201                },
1202            )
1203            .unwrap();
1204        assert_eq!(
1205            store
1206                .load_epoch_hashes(SplitLabel::Train)
1207                .unwrap()
1208                .unwrap()
1209                .checksum,
1210            11
1211        );
1212
1213        assert!(store.load_sampler_state().unwrap().is_none());
1214        let sampler_state = PersistedSamplerState {
1215            source_cycle_idx: 1,
1216            source_record_cursors: vec![("s1".to_string(), 2)],
1217            source_epoch: 3,
1218            rng_state: 4,
1219            triplet_recipe_rr_idx: 5,
1220            text_recipe_rr_idx: 6,
1221            source_stream_cursors: vec![("s1".to_string(), 7)],
1222        };
1223        store.save_sampler_state(&sampler_state, None).unwrap();
1224        assert_eq!(
1225            store.load_sampler_state().unwrap().unwrap().source_epoch,
1226            sampler_state.source_epoch
1227        );
1228    }
1229
1230    #[test]
1231    fn open_with_load_path_bootstraps_state_explicitly() {
1232        let dir = tempdir().unwrap();
1233        let path_a = dir.path().join("snapshot_a.bin");
1234        let path_b = dir.path().join("snapshot_b.bin");
1235        let ratios = SplitRatios::default();
1236
1237        let store_a = FileSplitStore::open(&path_a, ratios, 42).unwrap();
1238
1239        let mut meta = HashMap::new();
1240        meta.insert(
1241            SplitLabel::Train,
1242            PersistedSplitMeta {
1243                epoch: 5,
1244                offset: 3,
1245                hashes_checksum: 999,
1246            },
1247        );
1248        store_a.save_epoch_meta(&meta).unwrap();
1249
1250        let sampler_state = PersistedSamplerState {
1251            source_cycle_idx: 1,
1252            source_record_cursors: vec![("s1".to_string(), 2)],
1253            source_epoch: 7,
1254            rng_state: 123,
1255            triplet_recipe_rr_idx: 4,
1256            text_recipe_rr_idx: 6,
1257            source_stream_cursors: vec![("s1".to_string(), 8)],
1258        };
1259        store_a.save_sampler_state(&sampler_state, None).unwrap();
1260        drop(store_a);
1261
1262        let store_b =
1263            FileSplitStore::open_with_load_path(Some(path_a.clone()), &path_b, ratios, 42).unwrap();
1264        assert_eq!(
1265            store_b
1266                .load_epoch_meta()
1267                .unwrap()
1268                .get(&SplitLabel::Train)
1269                .unwrap()
1270                .epoch,
1271            5
1272        );
1273        assert_eq!(
1274            store_b.load_sampler_state().unwrap().unwrap().source_epoch,
1275            7
1276        );
1277
1278        let store_a_again = FileSplitStore::open(&path_a, ratios, 42).unwrap();
1279        assert_eq!(
1280            store_a_again
1281                .load_epoch_meta()
1282                .unwrap()
1283                .get(&SplitLabel::Train)
1284                .unwrap()
1285                .epoch,
1286            5
1287        );
1288        assert_eq!(
1289            store_a_again
1290                .load_sampler_state()
1291                .unwrap()
1292                .unwrap()
1293                .source_epoch,
1294            7
1295        );
1296    }
1297
1298    #[test]
1299    fn save_sampler_state_to_new_path_copies_existing_store_first() {
1300        let dir = tempdir().unwrap();
1301        let path_a = dir.path().join("source_store.bin");
1302        let path_b = dir.path().join("mirror_store.bin");
1303        let ratios = SplitRatios::default();
1304
1305        let store_a = FileSplitStore::open(&path_a, ratios, 42).unwrap();
1306
1307        let assigned_id = "record_with_assignment".to_string();
1308        let assigned_key = split_key(&assigned_id);
1309        store_a.store.write(&assigned_key, b"1").unwrap();
1310
1311        let sampler_state = PersistedSamplerState {
1312            source_cycle_idx: 1,
1313            source_record_cursors: vec![("s1".to_string(), 2)],
1314            source_epoch: 9,
1315            rng_state: 123,
1316            triplet_recipe_rr_idx: 4,
1317            text_recipe_rr_idx: 6,
1318            source_stream_cursors: vec![("s1".to_string(), 8)],
1319        };
1320
1321        store_a
1322            .save_sampler_state(&sampler_state, Some(path_b.as_path()))
1323            .unwrap();
1324
1325        // Destination gets the existing store data AND the new sampler state.
1326        let store_b = FileSplitStore::open(&path_b, ratios, 42).unwrap();
1327        assert_eq!(
1328            store_b.label_for(&assigned_id),
1329            Some(SplitLabel::Validation)
1330        );
1331        assert_eq!(
1332            store_b.load_sampler_state().unwrap().unwrap().source_epoch,
1333            9
1334        );
1335
1336        // save_to=Some(path_b) must not publish to path_a (the canonical save_path).
1337        assert!(
1338            !path_a.exists(),
1339            "save_to=Some(...) must not publish to the canonical save_path"
1340        );
1341    }
1342
1343    /// Saving to a custom path on a plain `open` store must not mutate the
1344    /// working store file -- i.e. a previously-saved state is preserved.
1345    #[test]
1346    fn save_some_on_regular_open_does_not_modify_working_store() {
1347        let dir = tempdir().unwrap();
1348        let path_a = dir.path().join("working_store.bin");
1349        let path_b = dir.path().join("checkpoint_store.bin");
1350        let ratios = SplitRatios::default();
1351
1352        let store_a = FileSplitStore::open(&path_a, ratios, 42).unwrap();
1353
1354        // Establish a baseline state in the working store.
1355        let initial_state = PersistedSamplerState {
1356            source_cycle_idx: 1,
1357            source_record_cursors: vec![],
1358            source_epoch: 1,
1359            rng_state: 0,
1360            triplet_recipe_rr_idx: 0,
1361            text_recipe_rr_idx: 0,
1362            source_stream_cursors: vec![],
1363        };
1364        store_a.save_sampler_state(&initial_state, None).unwrap();
1365
1366        // Snapshot a newer state to a separate checkpoint path.
1367        let checkpoint_state = PersistedSamplerState {
1368            source_cycle_idx: 99,
1369            source_record_cursors: vec![],
1370            source_epoch: 99,
1371            rng_state: 42,
1372            triplet_recipe_rr_idx: 0,
1373            text_recipe_rr_idx: 0,
1374            source_stream_cursors: vec![],
1375        };
1376        store_a
1377            .save_sampler_state(&checkpoint_state, Some(path_b.as_path()))
1378            .unwrap();
1379
1380        // The on-disk save_path (path_a) must not have been overwritten by save_to=Some(...).
1381        // Re-open from disk to verify; path_a was last published by save_to=None above.
1382        drop(store_a);
1383        let store_a_disk = FileSplitStore::open(&path_a, ratios, 42).unwrap();
1384        assert_eq!(
1385            store_a_disk
1386                .load_sampler_state()
1387                .unwrap()
1388                .unwrap()
1389                .source_epoch,
1390            1,
1391            "save_to=Some(...) must not overwrite the on-disk save_path"
1392        );
1393
1394        // The checkpoint must hold the new state.
1395        let store_b = FileSplitStore::open(&path_b, ratios, 42).unwrap();
1396        let state_from_b = store_b.load_sampler_state().unwrap().unwrap();
1397        assert_eq!(
1398            state_from_b.source_epoch, 99,
1399            "checkpoint store must hold the snapshotted state"
1400        );
1401    }
1402
1403    #[test]
1404    fn file_store_metadata_mismatch_and_debug_paths_are_covered() {
1405        let dir = tempdir().unwrap();
1406        let path = dir.path().join("meta_mismatch.bin");
1407        let ratios = SplitRatios::default();
1408        let store = FileSplitStore::open(&path, ratios, 123).unwrap();
1409
1410        let debug_repr = format!("{store:?}");
1411        assert!(debug_repr.contains("FileSplitStore"));
1412
1413        let wrong_version = StoreMeta {
1414            version: STORE_VERSION.wrapping_add(1),
1415            seed: 123,
1416            ratios,
1417        };
1418        let payload = encode_store_meta(&wrong_version);
1419        store.store.write(META_KEY, &payload).unwrap();
1420        // Publish the corrupted temp to disk so the next open reads the bad version.
1421        store
1422            .save_sampler_state(
1423                &PersistedSamplerState {
1424                    source_cycle_idx: 0,
1425                    source_record_cursors: vec![],
1426                    source_epoch: 0,
1427                    rng_state: 0,
1428                    triplet_recipe_rr_idx: 0,
1429                    text_recipe_rr_idx: 0,
1430                    source_stream_cursors: vec![],
1431                },
1432                None,
1433            )
1434            .unwrap();
1435        drop(store);
1436
1437        let err = FileSplitStore::open(&path, ratios, 123).unwrap_err();
1438        assert!(matches!(err, SamplerError::SplitStore(msg) if msg.contains("version mismatch")));
1439
1440        let ratio_path = dir.path().join("ratio_mismatch.bin");
1441        let baseline = FileSplitStore::open(&ratio_path, ratios, 777).unwrap();
1442        // Publish so ratio_path exists on disk before reopening with different ratios.
1443        baseline
1444            .save_sampler_state(
1445                &PersistedSamplerState {
1446                    source_cycle_idx: 0,
1447                    source_record_cursors: vec![],
1448                    source_epoch: 0,
1449                    rng_state: 0,
1450                    triplet_recipe_rr_idx: 0,
1451                    text_recipe_rr_idx: 0,
1452                    source_stream_cursors: vec![],
1453                },
1454                None,
1455            )
1456            .unwrap();
1457        drop(baseline);
1458
1459        let different_ratios = SplitRatios {
1460            train: 0.7,
1461            validation: 0.2,
1462            test: 0.1,
1463        };
1464        let err = FileSplitStore::open(&ratio_path, different_ratios, 777).unwrap_err();
1465        assert!(matches!(err, SamplerError::SplitStore(msg) if msg.contains("ratios mismatch")));
1466    }
1467
1468    #[test]
1469    fn split_decode_helpers_reject_corrupt_bitcode_payloads() {
1470        let store_meta_err = decode_store_meta(&[BITCODE_PREFIX, 0xFF, 0xEE]).unwrap_err();
1471        assert!(matches!(
1472            store_meta_err,
1473            SamplerError::SplitStore(msg) if msg.contains("failed to decode split store metadata")
1474        ));
1475
1476        let epoch_meta_err =
1477            decode_epoch_meta(&[EPOCH_META_RECORD_VERSION, BITCODE_PREFIX, 0xFF]).unwrap_err();
1478        assert!(
1479            matches!(epoch_meta_err, SamplerError::SplitStore(msg) if msg.contains("corrupt epoch meta record"))
1480        );
1481
1482        let epoch_hashes_err =
1483            decode_epoch_hashes(&[EPOCH_HASH_RECORD_VERSION, BITCODE_PREFIX, 0xFF]).unwrap_err();
1484        assert!(matches!(
1485            epoch_hashes_err,
1486            SamplerError::SplitStore(msg) if msg.contains("corrupt epoch hashes record")
1487        ));
1488
1489        let sampler_state_err =
1490            decode_sampler_state(&[SAMPLER_STATE_RECORD_VERSION, BITCODE_PREFIX, 0xFF])
1491                .unwrap_err();
1492        assert!(matches!(
1493            sampler_state_err,
1494            SamplerError::SplitStore(msg) if msg.contains("corrupt sampler state record")
1495        ));
1496    }
1497
1498    #[test]
1499    fn file_store_label_fallback_and_validation_keys_are_covered() {
1500        let dir = tempdir().unwrap();
1501        let path = dir.path().join("labels.bin");
1502        let store = FileSplitStore::open(&path, SplitRatios::default(), 42).unwrap();
1503
1504        let id = "bad_label_record".to_string();
1505        let expected = derive_label_for_id(&id, 42, SplitRatios::default());
1506        let key = split_key(&id);
1507
1508        store.store.write(&key, b"x").unwrap();
1509        assert_eq!(store.label_for(&id), Some(expected));
1510
1511        store.store.write(&key, b"1").unwrap();
1512        assert_eq!(store.label_for(&id), Some(SplitLabel::Validation));
1513
1514        let meta_validation = epoch_meta_key(SplitLabel::Validation);
1515        let hashes_validation = epoch_hashes_key(SplitLabel::Validation);
1516        assert!(meta_validation.starts_with(EPOCH_META_PREFIX));
1517        assert!(hashes_validation.starts_with(EPOCH_HASHES_PREFIX));
1518        assert!(meta_validation.ends_with(b"validation"));
1519        assert!(hashes_validation.ends_with(b"validation"));
1520    }
1521
1522    #[test]
1523    fn ensure_parent_dir_allows_plain_file_names() {
1524        ensure_parent_dir(Path::new("split_store_local.bin")).unwrap();
1525        let coerced = coerce_store_path(PathBuf::from("explicit_store.bin"));
1526        assert_eq!(coerced, PathBuf::from("explicit_store.bin"));
1527    }
1528
1529    // -----------------------------------------------------------------------
1530    // Temp-dir bootstrap contract
1531    // -----------------------------------------------------------------------
1532
1533    /// `open_with_load_path` must NOT modify the source file while the store is open.
1534    #[test]
1535    fn load_path_source_is_never_modified_while_open() {
1536        let dir = tempdir().unwrap();
1537        let source = dir.path().join("source.bin");
1538        let dest = dir.path().join("dest.bin");
1539        let ratios = SplitRatios::default();
1540
1541        // Seed the source with known state.
1542        let seeded = FileSplitStore::open(&source, ratios, 77).unwrap();
1543        let state = PersistedSamplerState {
1544            source_cycle_idx: 5,
1545            source_record_cursors: vec![("s".to_string(), 3)],
1546            source_epoch: 9,
1547            rng_state: 42,
1548            triplet_recipe_rr_idx: 1,
1549            text_recipe_rr_idx: 2,
1550            source_stream_cursors: vec![("s".to_string(), 4)],
1551        };
1552        seeded.save_sampler_state(&state, None).unwrap();
1553        drop(seeded);
1554
1555        let source_size_before = std::fs::metadata(&source).unwrap().len();
1556
1557        // Open bootstrapped store and write mutations to it.
1558        let bootstrapped =
1559            FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 77).unwrap();
1560        let new_state = PersistedSamplerState {
1561            source_cycle_idx: 99,
1562            source_record_cursors: vec![("s".to_string(), 77)],
1563            source_epoch: 100,
1564            rng_state: 0,
1565            triplet_recipe_rr_idx: 0,
1566            text_recipe_rr_idx: 0,
1567            source_stream_cursors: vec![],
1568        };
1569        bootstrapped.save_sampler_state(&new_state, None).unwrap();
1570        drop(bootstrapped);
1571
1572        // Source must be byte-identical to before bootstrap.
1573        let source_size_after = std::fs::metadata(&source).unwrap().len();
1574        assert_eq!(
1575            source_size_before, source_size_after,
1576            "source file was modified during bootstrapped open"
1577        );
1578
1579        // Verify source still holds the original state.
1580        let verify_source = FileSplitStore::open(&source, ratios, 77).unwrap();
1581        let loaded = verify_source.load_sampler_state().unwrap().unwrap();
1582        assert_eq!(loaded.source_cycle_idx, 5);
1583        assert_eq!(loaded.source_epoch, 9);
1584    }
1585
1586    /// `save_sampler_state(None)` on a bootstrapped store must publish to the
1587    /// declared `save_path` only, not to the source and not to a temp path.
1588    #[test]
1589    fn save_none_on_bootstrapped_store_publishes_to_save_path() {
1590        let dir = tempdir().unwrap();
1591        let source = dir.path().join("load.bin");
1592        let dest = dir.path().join("save.bin");
1593        let ratios = SplitRatios::default();
1594
1595        let _ = FileSplitStore::open(&source, ratios, 11).unwrap();
1596
1597        assert!(!dest.exists(), "dest must not exist before first save");
1598
1599        let store = FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 11).unwrap();
1600        let state = PersistedSamplerState {
1601            source_cycle_idx: 7,
1602            source_record_cursors: vec![],
1603            source_epoch: 2,
1604            rng_state: 1,
1605            triplet_recipe_rr_idx: 0,
1606            text_recipe_rr_idx: 0,
1607            source_stream_cursors: vec![],
1608        };
1609        store.save_sampler_state(&state, None).unwrap();
1610        drop(store);
1611
1612        assert!(dest.exists(), "dest must exist after save(None)");
1613
1614        let loaded_dest = FileSplitStore::open(&dest, ratios, 11).unwrap();
1615        assert_eq!(
1616            loaded_dest
1617                .load_sampler_state()
1618                .unwrap()
1619                .unwrap()
1620                .source_cycle_idx,
1621            7
1622        );
1623    }
1624
1625    /// `save_sampler_state(Some(other))` on a bootstrapped store must publish to
1626    /// `other` only — the declared `save_path` must remain absent.
1627    #[test]
1628    fn save_some_on_bootstrapped_store_publishes_to_explicit_path_only() {
1629        let dir = tempdir().unwrap();
1630        let source = dir.path().join("load.bin");
1631        let save = dir.path().join("save.bin"); // canonical — should stay empty
1632        let other = dir.path().join("other.bin"); // explicit target
1633        let ratios = SplitRatios::default();
1634
1635        let _ = FileSplitStore::open(&source, ratios, 22).unwrap();
1636
1637        let store = FileSplitStore::open_with_load_path(Some(&source), &save, ratios, 22).unwrap();
1638        let state = PersistedSamplerState {
1639            source_cycle_idx: 3,
1640            source_record_cursors: vec![],
1641            source_epoch: 1,
1642            rng_state: 0,
1643            triplet_recipe_rr_idx: 0,
1644            text_recipe_rr_idx: 0,
1645            source_stream_cursors: vec![],
1646        };
1647        store
1648            .save_sampler_state(&state, Some(other.as_path()))
1649            .unwrap();
1650        drop(store);
1651
1652        assert!(
1653            !save.exists(),
1654            "canonical save_path must not be created when saving to explicit path"
1655        );
1656        assert!(other.exists(), "explicit target must be created");
1657
1658        let loaded_other = FileSplitStore::open(&other, ratios, 22).unwrap();
1659        assert_eq!(
1660            loaded_other
1661                .load_sampler_state()
1662                .unwrap()
1663                .unwrap()
1664                .source_cycle_idx,
1665            3
1666        );
1667    }
1668
1669    /// Repeated `save(None)` calls on a bootstrapped store are idempotent: the
1670    /// canonical save_path is overwritten cleanly each time.
1671    #[test]
1672    fn repeated_save_none_on_bootstrapped_store_is_idempotent() {
1673        let dir = tempdir().unwrap();
1674        let source = dir.path().join("load.bin");
1675        let dest = dir.path().join("save.bin");
1676        let ratios = SplitRatios::default();
1677
1678        let _ = FileSplitStore::open(&source, ratios, 33).unwrap();
1679
1680        let store = FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 33).unwrap();
1681
1682        for cycle_idx in [1_u64, 2, 3] {
1683            let state = PersistedSamplerState {
1684                source_cycle_idx: cycle_idx,
1685                source_record_cursors: vec![],
1686                source_epoch: 0,
1687                rng_state: 0,
1688                triplet_recipe_rr_idx: 0,
1689                text_recipe_rr_idx: 0,
1690                source_stream_cursors: vec![],
1691            };
1692            store.save_sampler_state(&state, None).unwrap();
1693        }
1694        drop(store);
1695
1696        let reloaded = FileSplitStore::open(&dest, ratios, 33).unwrap();
1697        assert_eq!(
1698            reloaded
1699                .load_sampler_state()
1700                .unwrap()
1701                .unwrap()
1702                .source_cycle_idx,
1703            3,
1704            "dest should hold the last-saved state"
1705        );
1706    }
1707}