Skip to main content

triplets_core/
splits.rs

1use serde::{Deserialize, Serialize};
2use simd_r_drive::storage_engine::DataStore;
3use simd_r_drive::storage_engine::traits::{DataStoreReader, DataStoreWriter};
4use std::collections::HashMap;
5use std::fmt;
6use std::fs;
7use std::hash::{Hash, Hasher};
8use std::io;
9use std::path::{Path, PathBuf};
10use std::sync::RwLock;
11use tempfile::TempDir;
12
13use crate::constants::splits::{
14    ALL_SPLITS, BITCODE_PREFIX, EPOCH_HASH_RECORD_VERSION, EPOCH_HASHES_PREFIX, EPOCH_META_PREFIX,
15    EPOCH_META_RECORD_VERSION, EPOCH_RECORD_TOMBSTONE, META_KEY, SAMPLER_STATE_KEY,
16    SAMPLER_STATE_RECORD_VERSION, SPLIT_PREFIX, STORE_VERSION,
17};
18use crate::data::RecordId;
19use crate::errors::SamplerError;
20use crate::types::SourceId;
21
/// Logical dataset partitions used during sampling.
///
/// The label is `Copy` and hashable; elsewhere in this file it is used as a
/// `HashMap` key for per-split epoch state. It can be serialized through both
/// serde and bitcode.
///
/// NOTE(review): the variant order presumably feeds the derived encodings —
/// confirm before reordering or inserting variants.
#[derive(
    Clone,
    Copy,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    bitcode::Encode,
    bitcode::Decode,
)]
pub enum SplitLabel {
    /// Training split.
    Train,
    /// Validation split.
    Validation,
    /// Test split.
    Test,
}
43
/// Ratio configuration for train/validation/test assignment.
///
/// Each field is the fraction of records that should land in the matching
/// split. The three fractions are expected to sum to `1.0`; use
/// [`SplitRatios::normalized`] to check that before relying on the values.
/// The `Default` implementation yields an 80/10/10 partition.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, bitcode::Encode, bitcode::Decode)]
pub struct SplitRatios {
    /// Fraction assigned to train.
    pub train: f32,
    /// Fraction assigned to validation.
    pub validation: f32,
    /// Fraction assigned to test.
    pub test: f32,
}
54
55impl Default for SplitRatios {
56    fn default() -> Self {
57        Self {
58            train: 0.8,
59            validation: 0.1,
60            test: 0.1,
61        }
62    }
63}
64
65impl SplitRatios {
66    /// Validate that ratios sum to `1.0` (within epsilon).
67    pub fn normalized(self) -> Result<Self, SamplerError> {
68        let sum = self.train + self.validation + self.test;
69        if (sum - 1.0).abs() > 1e-6 {
70            return Err(SamplerError::Configuration(
71                "split ratios must sum to 1.0".to_string(),
72            ));
73        }
74        Ok(self)
75    }
76}
77
78pub use crate::constants::splits::EPOCH_STATE_VERSION;
79
/// Persisted epoch cursor metadata for one split.
///
/// One value is stored per [`SplitLabel`]; see [`EpochStateStore`] for the
/// load/save entry points.
///
/// NOTE(review): field order presumably determines the bitcode layout of
/// stored records — confirm before reordering.
#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
pub struct PersistedSplitMeta {
    /// Current epoch for this split.
    pub epoch: u64,
    /// Cursor offset within the epoch hash list.
    pub offset: u64,
    /// Checksum of the persisted hash list.
    pub hashes_checksum: u64,
}
90
/// Persisted deterministic epoch hash ordering for one split.
///
/// Stored and retrieved per [`SplitLabel`] through [`EpochStateStore`].
/// This type only carries the data; nothing in this file computes or
/// verifies `checksum`.
#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
pub struct PersistedSplitHashes {
    /// Checksum of `hashes`.
    pub checksum: u64,
    /// Deterministic per-epoch hash ordering.
    pub hashes: Vec<u64>,
}
99
/// Persisted sampler runtime state (cursors, recipe indices, RNG).
///
/// Loaded and saved as a single unit through [`SamplerStateStore`].
#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
pub struct PersistedSamplerState {
    /// Source-cycle round-robin index.
    pub source_cycle_idx: u64,
    /// Per-source record cursors, as `(source, cursor)` pairs.
    pub source_record_cursors: Vec<(SourceId, u64)>,
    /// Current source epoch used for deterministic reshuffle.
    pub source_epoch: u64,
    /// Deterministic RNG internal state.
    pub rng_state: u64,
    /// Round-robin index for triplet recipes.
    pub triplet_recipe_rr_idx: u64,
    /// Round-robin index for text recipes.
    pub text_recipe_rr_idx: u64,
    /// Persisted source stream refresh cursors.
    ///
    /// Stores per-source I/O cursors and, for backward compatibility, the
    /// step counter as a synthetic `(STEP_CURSOR_KEY, revision)` entry.
    pub source_stream_cursors: Vec<(SourceId, u64)>,
}
120
/// Split assignment backend.
///
/// Implementations map `RecordId` values to split labels deterministically.
/// The trait requires `Send + Sync`, and every method takes `&self`, so
/// implementations that mutate state need interior mutability.
pub trait SplitStore: Send + Sync {
    /// Return split label for `id` if known/derivable.
    ///
    /// `None` means the implementation could not produce a label at all.
    fn label_for(&self, id: &RecordId) -> Option<SplitLabel>;
    /// Persist an explicit split assignment for `id`.
    ///
    /// # Errors
    ///
    /// Returns a [`SamplerError`] when the assignment cannot be recorded.
    fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError>;
    /// Return configured split ratios.
    fn ratios(&self) -> SplitRatios;
    /// Return the split label for `id`, creating/deriving one when needed.
    ///
    /// # Errors
    ///
    /// Returns a [`SamplerError`] when no label can be produced.
    fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError>;
}
134
/// Persistence backend for epoch metadata and epoch hash orderings.
///
/// Metadata is exchanged as a whole split→meta map, whereas hash lists are
/// loaded and saved one split at a time. All methods report failures as
/// [`SamplerError`].
pub trait EpochStateStore: Send + Sync {
    /// Load split→epoch metadata map.
    ///
    /// Splits without stored metadata are simply absent from the map.
    fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError>;
    /// Load persisted epoch hashes for one split, if available.
    ///
    /// Returns `Ok(None)` when nothing has been stored for `label`.
    fn load_epoch_hashes(
        &self,
        label: SplitLabel,
    ) -> Result<Option<PersistedSplitHashes>, SamplerError>;
    /// Persist split→epoch metadata map.
    fn save_epoch_meta(
        &self,
        meta: &HashMap<SplitLabel, PersistedSplitMeta>,
    ) -> Result<(), SamplerError>;
    /// Persist epoch hash list for one split.
    fn save_epoch_hashes(
        &self,
        label: SplitLabel,
        hashes: &PersistedSplitHashes,
    ) -> Result<(), SamplerError>;
}
156
/// Persistence backend for sampler runtime state.
///
/// The whole [`PersistedSamplerState`] is loaded and saved as one value.
pub trait SamplerStateStore: Send + Sync {
    /// Load persisted sampler runtime state, if present.
    ///
    /// Returns `Ok(None)` when no state has been saved yet.
    fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError>;
    /// Save sampler runtime state, optionally mirroring to `save_path`.
    ///
    /// How `save_path` is interpreted is up to the implementation; the
    /// in-memory store in this file ignores it, while the file-backed store
    /// treats it as an alternative publish destination.
    fn save_sampler_state(
        &self,
        state: &PersistedSamplerState,
        save_path: Option<&Path>,
    ) -> Result<(), SamplerError>;
}
168
/// In-memory split store with deterministic assignment derivation.
///
/// Nothing is written to disk: assignments, epoch state, and sampler state
/// all live in `RwLock`-protected fields for the lifetime of the value.
pub struct DeterministicSplitStore {
    /// Split ratios used when a label is derived from the record hash.
    ratios: SplitRatios,
    /// Explicit assignments recorded via `upsert`; consulted by `label_for`
    /// before falling back to the derived label.
    assignments: RwLock<HashMap<RecordId, SplitLabel>>,
    /// Seed mixed into the record hash during label derivation.
    seed: u64,
    /// In-memory backing for the `EpochStateStore` metadata map.
    epoch_meta: RwLock<HashMap<SplitLabel, PersistedSplitMeta>>,
    /// In-memory backing for the per-split `EpochStateStore` hash lists.
    epoch_hashes: RwLock<HashMap<SplitLabel, PersistedSplitHashes>>,
    /// In-memory backing for `SamplerStateStore`; `None` until first save.
    sampler_state: RwLock<Option<PersistedSamplerState>>,
}
178
179impl DeterministicSplitStore {
180    /// Create an in-memory split store configured with `ratios` and `seed`.
181    pub fn new(ratios: SplitRatios, seed: u64) -> Result<Self, SamplerError> {
182        ratios.normalized()?;
183        Ok(Self {
184            ratios,
185            assignments: RwLock::new(HashMap::new()),
186            seed,
187            epoch_meta: RwLock::new(HashMap::new()),
188            epoch_hashes: RwLock::new(HashMap::new()),
189            sampler_state: RwLock::new(None),
190        })
191    }
192
193    fn derive_label(&self, id: &RecordId) -> SplitLabel {
194        derive_label_for_id(id, self.seed, self.ratios)
195    }
196}
197
198impl SplitStore for DeterministicSplitStore {
199    fn label_for(&self, id: &RecordId) -> Option<SplitLabel> {
200        if let Some(label) = self.assignments.read().ok()?.get(id).copied() {
201            return Some(label);
202        }
203        Some(self.derive_label(id))
204    }
205
206    fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError> {
207        let mut guard = self
208            .assignments
209            .write()
210            .map_err(|_| SamplerError::SplitStore("lock poisoned".into()))?;
211        guard.insert(id, label);
212        Ok(())
213    }
214
215    fn ratios(&self) -> SplitRatios {
216        self.ratios
217    }
218
219    fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError> {
220        Ok(self.derive_label(&id))
221    }
222}
223
224impl EpochStateStore for DeterministicSplitStore {
225    fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError> {
226        self.epoch_meta
227            .read()
228            .map_err(|_| SamplerError::SplitStore("epoch meta lock poisoned".into()))
229            .map(|guard| guard.clone())
230    }
231
232    fn load_epoch_hashes(
233        &self,
234        label: SplitLabel,
235    ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
236        Ok(self
237            .epoch_hashes
238            .read()
239            .map_err(|_| SamplerError::SplitStore("epoch hashes lock poisoned".into()))?
240            .get(&label)
241            .cloned())
242    }
243
244    fn save_epoch_meta(
245        &self,
246        meta: &HashMap<SplitLabel, PersistedSplitMeta>,
247    ) -> Result<(), SamplerError> {
248        *self
249            .epoch_meta
250            .write()
251            .map_err(|_| SamplerError::SplitStore("epoch meta lock poisoned".into()))? =
252            meta.clone();
253        Ok(())
254    }
255
256    fn save_epoch_hashes(
257        &self,
258        label: SplitLabel,
259        hashes: &PersistedSplitHashes,
260    ) -> Result<(), SamplerError> {
261        self.epoch_hashes
262            .write()
263            .map_err(|_| SamplerError::SplitStore("epoch hashes lock poisoned".into()))?
264            .insert(label, hashes.clone());
265        Ok(())
266    }
267}
268
269impl SamplerStateStore for DeterministicSplitStore {
270    fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError> {
271        self.sampler_state
272            .read()
273            .map_err(|_| SamplerError::SplitStore("sampler state lock poisoned".into()))
274            .map(|guard| guard.clone())
275    }
276
277    fn save_sampler_state(
278        &self,
279        state: &PersistedSamplerState,
280        _save_path: Option<&Path>,
281    ) -> Result<(), SamplerError> {
282        *self
283            .sampler_state
284            .write()
285            .map_err(|_| SamplerError::SplitStore("sampler state lock poisoned".into()))? =
286            Some(state.clone());
287        Ok(())
288    }
289}
290
#[derive(Clone, Copy, Debug, bitcode::Encode, bitcode::Decode)]
/// Versioned metadata header stored in file-backed split stores.
///
/// Written under `META_KEY` when a store is first created and compared
/// against the caller's configuration on every later open.
///
/// NOTE(review): field order presumably determines the bitcode layout of the
/// stored header — confirm before reordering.
struct StoreMeta {
    /// Store format version; compared against `STORE_VERSION` on open.
    version: u8,
    /// Seed the store was created with.
    seed: u64,
    /// Split ratios the store was created with.
    ratios: SplitRatios,
}
298
299fn encode_store_meta(meta: &StoreMeta) -> Vec<u8> {
300    encode_bitcode_payload(&bitcode::encode(meta))
301}
302
303fn decode_store_meta(bytes: &[u8]) -> Result<StoreMeta, SamplerError> {
304    let raw = decode_bitcode_payload(bytes)?;
305    bitcode::decode(&raw).map_err(|err| {
306        SamplerError::SplitStore(format!("failed to decode split store metadata: {err}"))
307    })
308}
309
/// File-backed split store for persistent runs.
///
/// Persists assignment metadata, epoch state, and sampler runtime state.
///
/// The store **always** works against a private temporary copy of the source
/// snapshot.  All reads and mutations accumulate in the temp file.
/// State is published to permanent storage only when
/// [`SamplerStateStore::save_sampler_state`] is called:
///
/// * `save_to == None`  → publish temp to `save_path` (may overwrite).
/// * `save_to == Some(p)` → publish temp to `p` only; `save_path` is left
///   untouched.
///
/// This guarantees that the original source file is never modified and that
/// no partial state leaks to the target before an explicit save.
///
/// The temporary directory is owned by the store (`_temp_dir`), so the
/// working copy exists only as long as the store value does.
pub struct FileSplitStore {
    /// Open handle on the working copy; every read and write goes through it.
    store: DataStore,
    /// Working path: always a private temp file; all reads and writes go here.
    path: PathBuf,
    /// Declared save destination; published to on `save_sampler_state(None)`.
    save_path: PathBuf,
    /// Split ratios this store was opened with (validated on open).
    ratios: SplitRatios,
    /// Seed this store was opened with; checked against the stored header.
    seed: u64,
    /// Keeps the temporary directory alive for the lifetime of this store.
    _temp_dir: TempDir,
}
336
337impl fmt::Debug for FileSplitStore {
338    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
339        f.debug_struct("FileSplitStore")
340            .field("path", &self.save_path)
341            .field("ratios", &self.ratios)
342            .field("seed", &self.seed)
343            .finish()
344    }
345}
346
347impl FileSplitStore {
348    /// Open (or create) a file-backed split store at `path`.
349    pub fn open<P: Into<PathBuf>>(
350        path: P,
351        ratios: SplitRatios,
352        seed: u64,
353    ) -> Result<Self, SamplerError> {
354        Self::open_with_load_path(None::<PathBuf>, path, ratios, seed)
355    }
356
357    /// Open (or create) a file-backed split store at `save_path`, optionally
358    /// bootstrapping initial state from `load_path`.
359    ///
360    /// Always stages through a **private temporary file**.  The source used to
361    /// seed the temp is chosen as follows:
362    ///
363    /// 1. `load_path` if supplied and it exists.
364    /// 2. `save_path` if it already exists.
365    /// 3. Nothing — a fresh empty store is created in the temp.
366    ///
367    /// All mutations accumulate in the temp.  State is published to permanent
368    /// storage only when [`SamplerStateStore::save_sampler_state`] is called.
369    /// The original source file is never modified.
370    pub fn open_with_load_path<LP: Into<PathBuf>, SP: Into<PathBuf>>(
371        load_path: Option<LP>,
372        save_path: SP,
373        ratios: SplitRatios,
374        seed: u64,
375    ) -> Result<Self, SamplerError> {
376        let ratios = ratios.normalized()?;
377        let save_path = coerce_store_path(save_path.into());
378        ensure_parent_dir(&save_path)?;
379
380        // Determine source to seed the working temp from.
381        // Priority: explicit load_path > existing save_path > fresh (nothing).
382        let source: Option<PathBuf> = if let Some(lp) = load_path {
383            let lp = coerce_store_path(lp.into());
384            if lp.exists() { Some(lp) } else { None }
385        } else if save_path.exists() {
386            Some(save_path.clone())
387        } else {
388            None
389        };
390
391        let temp_dir = tempfile::tempdir().map_err(|err| {
392            SamplerError::SplitStore(format!("failed to create temp dir for split store: {err}"))
393        })?;
394        let working_path = temp_dir.path().join("working_store.bin");
395        if let Some(src) = &source {
396            fs::copy(src, &working_path).map_err(|err| {
397                SamplerError::SplitStore(format!(
398                    "failed to copy split store from '{}' to temp: {err}",
399                    src.display()
400                ))
401            })?;
402        }
403
404        let raw_store = DataStore::open(&working_path).map_err(map_store_err)?;
405        let store = Self {
406            store: raw_store,
407            path: working_path,
408            save_path,
409            ratios,
410            seed,
411            _temp_dir: temp_dir,
412        };
413        store.verify_metadata()?;
414        Ok(store)
415    }
416
417    fn verify_metadata(&self) -> Result<(), SamplerError> {
418        match read_bytes(&self.store, META_KEY)? {
419            Some(bytes) => {
420                let meta = decode_store_meta(&bytes)?;
421                if meta.version != STORE_VERSION {
422                    return Err(SamplerError::SplitStore(format!(
423                        "split store version mismatch (expected {}, found {})",
424                        STORE_VERSION, meta.version
425                    )));
426                }
427                if meta.seed != self.seed {
428                    return Err(SamplerError::SplitStore(format!(
429                        "split store seed mismatch (expected {}, found {})",
430                        self.seed, meta.seed
431                    )));
432                }
433                if !ratios_close(meta.ratios, self.ratios) {
434                    return Err(SamplerError::SplitStore(
435                        "split store ratios mismatch".into(),
436                    ));
437                }
438            }
439            None => {
440                let blob = StoreMeta {
441                    version: STORE_VERSION,
442                    seed: self.seed,
443                    ratios: self.ratios,
444                };
445                let payload = encode_store_meta(&blob);
446                write_bytes(&self.store, META_KEY, &payload)?;
447            }
448        }
449        Ok(())
450    }
451
452    fn read_epoch_meta_entry(
453        &self,
454        label: SplitLabel,
455    ) -> Result<Option<PersistedSplitMeta>, SamplerError> {
456        let key = epoch_meta_key(label);
457        let entry = self.store.read(&key).map_err(map_store_err)?;
458        match entry {
459            None => Ok(None),
460            Some(bytes) => decode_epoch_meta(bytes.as_ref()),
461        }
462    }
463
464    fn write_epoch_meta_entry(
465        &self,
466        label: SplitLabel,
467        meta: Option<&PersistedSplitMeta>,
468    ) -> Result<(), SamplerError> {
469        let key = epoch_meta_key(label);
470        let payload = encode_epoch_meta(meta);
471        self.store
472            .write(&key, payload.as_slice())
473            .map_err(map_store_err)?;
474        Ok(())
475    }
476
477    fn read_epoch_hashes_entry(
478        &self,
479        label: SplitLabel,
480    ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
481        let key = epoch_hashes_key(label);
482        let entry = self.store.read(&key).map_err(map_store_err)?;
483        match entry {
484            None => Ok(None),
485            Some(bytes) => decode_epoch_hashes(bytes.as_ref()),
486        }
487    }
488
489    fn write_epoch_hashes_entry(
490        &self,
491        label: SplitLabel,
492        hashes: &PersistedSplitHashes,
493    ) -> Result<(), SamplerError> {
494        let key = epoch_hashes_key(label);
495        let payload = encode_epoch_hashes(hashes);
496        self.store
497            .write(&key, payload.as_slice())
498            .map_err(map_store_err)?;
499        Ok(())
500    }
501}
502
503impl SplitStore for FileSplitStore {
504    fn label_for(&self, id: &RecordId) -> Option<SplitLabel> {
505        let key = split_key(id);
506        if let Ok(Some(value)) = self.store.read(&key)
507            && let Ok(label) = decode_label(value.as_ref())
508        {
509            return Some(label);
510        }
511        Some(derive_label_for_id(id, self.seed, self.ratios))
512    }
513
514    fn upsert(&self, id: RecordId, label: SplitLabel) -> Result<(), SamplerError> {
515        let _ = (id, label);
516        Ok(())
517    }
518
519    fn ratios(&self) -> SplitRatios {
520        self.ratios
521    }
522
523    fn ensure(&self, id: RecordId) -> Result<SplitLabel, SamplerError> {
524        Ok(derive_label_for_id(&id, self.seed, self.ratios))
525    }
526}
527
528impl EpochStateStore for FileSplitStore {
529    fn load_epoch_meta(&self) -> Result<HashMap<SplitLabel, PersistedSplitMeta>, SamplerError> {
530        let mut meta = HashMap::new();
531        for label in ALL_SPLITS {
532            if let Some(entry) = self.read_epoch_meta_entry(label)? {
533                meta.insert(label, entry);
534            }
535        }
536        Ok(meta)
537    }
538
539    fn load_epoch_hashes(
540        &self,
541        label: SplitLabel,
542    ) -> Result<Option<PersistedSplitHashes>, SamplerError> {
543        self.read_epoch_hashes_entry(label)
544    }
545
546    fn save_epoch_meta(
547        &self,
548        meta: &HashMap<SplitLabel, PersistedSplitMeta>,
549    ) -> Result<(), SamplerError> {
550        for label in ALL_SPLITS {
551            self.write_epoch_meta_entry(label, meta.get(&label))?;
552        }
553        Ok(())
554    }
555
556    fn save_epoch_hashes(
557        &self,
558        label: SplitLabel,
559        hashes: &PersistedSplitHashes,
560    ) -> Result<(), SamplerError> {
561        self.write_epoch_hashes_entry(label, hashes)
562    }
563}
564
565impl SamplerStateStore for FileSplitStore {
566    fn load_sampler_state(&self) -> Result<Option<PersistedSamplerState>, SamplerError> {
567        match read_bytes(&self.store, SAMPLER_STATE_KEY)? {
568            Some(bytes) => decode_sampler_state(bytes.as_ref()),
569            None => Ok(None),
570        }
571    }
572
573    /// Persist `state` to the working temp store and publish to the destination.
574    ///
575    /// * `save_to == None`  → publish temp to `save_path` (may overwrite).
576    /// * `save_to == Some(p)` → publish temp to `p` only; `save_path` is left
577    ///   untouched.  `p` must not already exist.
578    fn save_sampler_state(
579        &self,
580        state: &PersistedSamplerState,
581        save_to: Option<&Path>,
582    ) -> Result<(), SamplerError> {
583        // Determine publish destination before writing anything.
584        let dest = if let Some(p) = save_to {
585            coerce_store_path(p.to_path_buf())
586        } else {
587            self.save_path.clone()
588        };
589
590        // Refuse to overwrite an explicitly-named destination that already exists.
591        // Saving back to the canonical save_path (None) is always allowed.
592        if save_to.is_some() && dest.exists() {
593            return Err(SamplerError::SplitStore(format!(
594                "refusing to overwrite existing split store '{}'; choose a new path",
595                dest.display()
596            )));
597        }
598
599        // Write state into the working temp store.
600        let payload = encode_sampler_state(state);
601        write_bytes(&self.store, SAMPLER_STATE_KEY, &payload)?;
602
603        // Publish: copy the temp store to the destination.
604        ensure_parent_dir(&dest)?;
605        fs::copy(&self.path, &dest).map_err(|err| {
606            SamplerError::SplitStore(format!(
607                "failed to publish split store to '{}': {err}",
608                dest.display()
609            ))
610        })?;
611
612        Ok(())
613    }
614}
615
616fn decode_label(bytes: &[u8]) -> Result<SplitLabel, SamplerError> {
617    match bytes.first() {
618        Some(b'0') => Ok(SplitLabel::Train),
619        Some(b'1') => Ok(SplitLabel::Validation),
620        Some(b'2') => Ok(SplitLabel::Test),
621        _ => Err(SamplerError::SplitStore("invalid split label".into())),
622    }
623}
624
625fn derive_label_for_id(id: &RecordId, seed: u64, ratios: SplitRatios) -> SplitLabel {
626    let mut hasher = std::collections::hash_map::DefaultHasher::new();
627    id.hash(&mut hasher);
628    seed.hash(&mut hasher);
629    let value = hasher.finish() as f64 / u64::MAX as f64;
630    let train_cut = ratios.train as f64;
631    let val_cut = train_cut + ratios.validation as f64;
632    if value < train_cut {
633        SplitLabel::Train
634    } else if value < val_cut {
635        SplitLabel::Validation
636    } else {
637        SplitLabel::Test
638    }
639}
640
641fn ratios_close(a: SplitRatios, b: SplitRatios) -> bool {
642    ((a.train - b.train).abs() + (a.validation - b.validation).abs() + (a.test - b.test).abs())
643        < 1e-5
644}
645
646fn split_key(id: &RecordId) -> Vec<u8> {
647    let mut key = Vec::with_capacity(SPLIT_PREFIX.len() + id.len());
648    key.extend_from_slice(SPLIT_PREFIX);
649    key.extend_from_slice(id.as_bytes());
650    key
651}
652
653fn read_bytes(store: &DataStore, key: &[u8]) -> Result<Option<Vec<u8>>, SamplerError> {
654    store
655        .read(key)
656        .map_err(map_store_err)?
657        .map(|entry| Ok(entry.as_ref().to_vec()))
658        .transpose()
659}
660
661fn write_bytes(store: &DataStore, key: &[u8], payload: &[u8]) -> Result<(), SamplerError> {
662    store.write(key, payload).map_err(map_store_err)?;
663    Ok(())
664}
665
666fn epoch_meta_key(label: SplitLabel) -> Vec<u8> {
667    let mut key = Vec::with_capacity(EPOCH_META_PREFIX.len() + 12);
668    key.extend_from_slice(EPOCH_META_PREFIX);
669    let suffix = match label {
670        SplitLabel::Train => b"train".as_ref(),
671        SplitLabel::Validation => b"validation".as_ref(),
672        SplitLabel::Test => b"test".as_ref(),
673    };
674    key.extend_from_slice(suffix);
675    key
676}
677
678fn epoch_hashes_key(label: SplitLabel) -> Vec<u8> {
679    let mut key = Vec::with_capacity(EPOCH_HASHES_PREFIX.len() + 12);
680    key.extend_from_slice(EPOCH_HASHES_PREFIX);
681    let suffix = match label {
682        SplitLabel::Train => b"train".as_ref(),
683        SplitLabel::Validation => b"validation".as_ref(),
684        SplitLabel::Test => b"test".as_ref(),
685    };
686    key.extend_from_slice(suffix);
687    key
688}
689
690fn encode_epoch_meta(meta: Option<&PersistedSplitMeta>) -> Vec<u8> {
691    match meta {
692        None => vec![EPOCH_RECORD_TOMBSTONE],
693        Some(meta) => {
694            let payload = encode_bitcode_payload(&bitcode::encode(meta));
695            let mut buf = Vec::with_capacity(1 + payload.len());
696            buf.push(EPOCH_META_RECORD_VERSION);
697            buf.extend_from_slice(&payload);
698            buf
699        }
700    }
701}
702
703fn decode_epoch_meta(bytes: &[u8]) -> Result<Option<PersistedSplitMeta>, SamplerError> {
704    if bytes.is_empty() || bytes[0] == EPOCH_RECORD_TOMBSTONE {
705        return Ok(None);
706    }
707    if bytes[0] != EPOCH_META_RECORD_VERSION {
708        return Err(SamplerError::SplitStore(
709            "epoch meta record version mismatch".into(),
710        ));
711    }
712    let raw = decode_bitcode_payload(&bytes[1..])?;
713    bitcode::decode(&raw)
714        .map(Some)
715        .map_err(|err| SamplerError::SplitStore(format!("corrupt epoch meta record: {err}")))
716}
717
718fn encode_epoch_hashes(hashes: &PersistedSplitHashes) -> Vec<u8> {
719    let payload = encode_bitcode_payload(&bitcode::encode(hashes));
720    let mut buf = Vec::with_capacity(1 + payload.len());
721    buf.push(EPOCH_HASH_RECORD_VERSION);
722    buf.extend_from_slice(&payload);
723    buf
724}
725
726fn decode_epoch_hashes(bytes: &[u8]) -> Result<Option<PersistedSplitHashes>, SamplerError> {
727    if bytes.is_empty() || bytes[0] == EPOCH_RECORD_TOMBSTONE {
728        return Ok(None);
729    }
730    if bytes[0] != EPOCH_HASH_RECORD_VERSION {
731        return Err(SamplerError::SplitStore(
732            "epoch hashes record version mismatch".into(),
733        ));
734    }
735    let raw = decode_bitcode_payload(&bytes[1..])?;
736    bitcode::decode(&raw)
737        .map(Some)
738        .map_err(|err| SamplerError::SplitStore(format!("corrupt epoch hashes record: {err}")))
739}
740
741fn encode_sampler_state(state: &PersistedSamplerState) -> Vec<u8> {
742    let payload = encode_bitcode_payload(&bitcode::encode(state));
743    let mut buf = Vec::with_capacity(1 + payload.len());
744    buf.push(SAMPLER_STATE_RECORD_VERSION);
745    buf.extend_from_slice(&payload);
746    buf
747}
748
749fn decode_sampler_state(bytes: &[u8]) -> Result<Option<PersistedSamplerState>, SamplerError> {
750    if bytes.is_empty() {
751        return Ok(None);
752    }
753    if bytes[0] != SAMPLER_STATE_RECORD_VERSION {
754        return Err(SamplerError::SplitStore(
755            "sampler state record version mismatch".into(),
756        ));
757    }
758    let raw = decode_bitcode_payload(&bytes[1..])?;
759    bitcode::decode(&raw)
760        .map(Some)
761        .map_err(|err| SamplerError::SplitStore(format!("corrupt sampler state record: {err}")))
762}
763
764fn encode_bitcode_payload(bytes: &[u8]) -> Vec<u8> {
765    let mut out = Vec::with_capacity(1 + bytes.len());
766    out.push(BITCODE_PREFIX);
767    out.extend_from_slice(bytes);
768    out
769}
770
771fn decode_bitcode_payload(bytes: &[u8]) -> Result<Vec<u8>, SamplerError> {
772    if bytes.first().copied() != Some(BITCODE_PREFIX) {
773        return Err(SamplerError::SplitStore(
774            "bitcode payload missing expected prefix".into(),
775        ));
776    }
777    Ok(bytes[1..].to_vec())
778}
779
/// Normalization hook for store paths.
///
/// Currently the identity function: the path is returned unchanged.
fn coerce_store_path(path: PathBuf) -> PathBuf {
    std::convert::identity(path)
}
783
784fn ensure_parent_dir(path: &Path) -> Result<(), SamplerError> {
785    if let Some(parent) = path.parent()
786        && !parent.as_os_str().is_empty()
787    {
788        fs::create_dir_all(parent)?;
789    }
790    Ok(())
791}
792
793fn map_store_err(err: io::Error) -> SamplerError {
794    SamplerError::SplitStore(err.to_string())
795}
796
797#[cfg(test)]
798mod tests {
799    use super::*;
800    use std::collections::HashMap;
801    use tempfile::tempdir;
802
803    #[test]
804    fn split_ratios_reject_non_unit_sum() {
805        let invalid = SplitRatios {
806            train: 0.6,
807            validation: 0.3,
808            test: 0.3,
809        };
810
811        let err = DeterministicSplitStore::new(invalid, 1)
812            .err()
813            .expect("expected non-unit split ratios to fail");
814        assert!(matches!(
815            err,
816            SamplerError::Configuration(ref msg) if msg.contains("split ratios must sum to 1.0")
817        ));
818
819        let dir = tempdir().unwrap();
820        let path = dir.path().join("split_store.bin");
821        let err = FileSplitStore::open(&path, invalid, 1).unwrap_err();
822        assert!(matches!(
823            err,
824            SamplerError::Configuration(ref msg) if msg.contains("split ratios must sum to 1.0")
825        ));
826    }
827
828    #[test]
829    fn zero_test_ratio_never_assigns_test_labels() {
830        let ratios = SplitRatios {
831            train: 0.5,
832            validation: 0.5,
833            test: 0.0,
834        };
835        let store = DeterministicSplitStore::new(ratios, 7).unwrap();
836
837        let mut saw_train = false;
838        let mut saw_validation = false;
839        for idx in 0..20_000 {
840            let id = format!("record_{idx}");
841            let label = store.ensure(id).unwrap();
842            assert_ne!(label, SplitLabel::Test);
843            saw_train |= label == SplitLabel::Train;
844            saw_validation |= label == SplitLabel::Validation;
845            if saw_train && saw_validation {
846                break;
847            }
848        }
849
850        assert!(saw_train);
851        assert!(saw_validation);
852    }
853
854    /// `save(None)` must publish the temp to `save_path` on disk.  A fresh
855    /// `open` of that path must see everything that was written before the save.
856    #[test]
857    fn save_none_publishes_to_save_path_and_reloads_cleanly() {
858        let dir = tempdir().unwrap();
859        let path = dir.path().join("persist.bin");
860        let ratios = SplitRatios::default();
861
862        {
863            let store = FileSplitStore::open(&path, ratios, 123).unwrap();
864            let mut meta = HashMap::new();
865            meta.insert(
866                SplitLabel::Train,
867                PersistedSplitMeta {
868                    epoch: 3,
869                    offset: 7,
870                    hashes_checksum: 42,
871                },
872            );
873            store.save_epoch_meta(&meta).unwrap();
874            let state = PersistedSamplerState {
875                source_cycle_idx: 11,
876                source_record_cursors: vec![("s".to_string(), 1)],
877                source_epoch: 5,
878                rng_state: 99,
879                triplet_recipe_rr_idx: 2,
880                text_recipe_rr_idx: 3,
881                source_stream_cursors: vec![],
882            };
883            assert!(!path.exists(), "save_path must not exist before save(None)");
884            store.save_sampler_state(&state, None).unwrap();
885            assert!(
886                path.exists(),
887                "save(None) must publish to save_path on disk"
888            );
889        }
890
891        // Fresh open must read everything back from disk.
892        let reopened = FileSplitStore::open(&path, ratios, 123).unwrap();
893        let loaded_state = reopened.load_sampler_state().unwrap().unwrap();
894        assert_eq!(loaded_state.source_cycle_idx, 11);
895        assert_eq!(loaded_state.source_epoch, 5);
896        assert_eq!(loaded_state.rng_state, 99);
897        let loaded_meta = reopened.load_epoch_meta().unwrap();
898        assert_eq!(loaded_meta.get(&SplitLabel::Train).unwrap().epoch, 3);
899    }
900
901    #[test]
902    fn file_store_rejects_seed_mismatch() {
903        let dir = tempdir().unwrap();
904        let path = dir.path().join("splits.json");
905        let ratios = SplitRatios::default();
906        let store = FileSplitStore::open(&path, ratios, 123).unwrap();
907        store.ensure("abc".to_string()).unwrap();
908        // Publish to disk so the seed is committed before the next open.
909        store
910            .save_sampler_state(
911                &PersistedSamplerState {
912                    source_cycle_idx: 0,
913                    source_record_cursors: vec![],
914                    source_epoch: 0,
915                    rng_state: 0,
916                    triplet_recipe_rr_idx: 0,
917                    text_recipe_rr_idx: 0,
918                    source_stream_cursors: vec![],
919                },
920                None,
921            )
922            .unwrap();
923        drop(store);
924
925        let err = FileSplitStore::open(&path, ratios, 999).unwrap_err();
926        assert!(matches!(
927            err,
928            SamplerError::SplitStore(msg) if msg.contains("seed")
929        ));
930    }
931
932    #[test]
933    fn file_store_accepts_directory_path() {
934        let dir = tempdir().unwrap();
935        let ratios = SplitRatios::default();
936        let err = FileSplitStore::open(dir.path(), ratios, 777).unwrap_err();
937        assert!(matches!(err, SamplerError::SplitStore(_)));
938    }
939
940    #[test]
941    fn bitcode_payload_requires_prefix() {
942        let err = decode_bitcode_payload(&[0x00, 0x01]).unwrap_err();
943        assert!(
944            matches!(err, SamplerError::SplitStore(msg) if msg.contains("missing expected prefix"))
945        );
946    }
947
948    #[test]
949    fn file_store_round_trips_epoch_and_sampler_state() {
950        let dir = tempdir().unwrap();
951        let path = dir.path().join("epoch_sampler_state.bin");
952        let store = FileSplitStore::open(&path, SplitRatios::default(), 222).unwrap();
953
954        assert!(store.load_epoch_hashes(SplitLabel::Test).unwrap().is_none());
955
956        let mut epoch_meta = HashMap::new();
957        epoch_meta.insert(
958            SplitLabel::Train,
959            PersistedSplitMeta {
960                epoch: 3,
961                offset: 7,
962                hashes_checksum: 42,
963            },
964        );
965        store.save_epoch_meta(&epoch_meta).unwrap();
966
967        let loaded_meta = store.load_epoch_meta().unwrap();
968        let loaded_train = loaded_meta.get(&SplitLabel::Train).unwrap();
969        assert_eq!(loaded_train.epoch, 3);
970        assert_eq!(loaded_train.offset, 7);
971        assert_eq!(loaded_train.hashes_checksum, 42);
972
973        let hashes = PersistedSplitHashes {
974            checksum: 99,
975            hashes: vec![10, 20, 30],
976        };
977        store
978            .save_epoch_hashes(SplitLabel::Validation, &hashes)
979            .unwrap();
980        let loaded_hashes = store
981            .load_epoch_hashes(SplitLabel::Validation)
982            .unwrap()
983            .unwrap();
984        assert_eq!(loaded_hashes.checksum, 99);
985        assert_eq!(loaded_hashes.hashes, vec![10, 20, 30]);
986
987        let state = PersistedSamplerState {
988            source_cycle_idx: 11,
989            source_record_cursors: vec![("source_a".to_string(), 1)],
990            source_epoch: 8,
991            rng_state: 1234,
992            triplet_recipe_rr_idx: 2,
993            text_recipe_rr_idx: 5,
994            source_stream_cursors: vec![("source_a".to_string(), 9)],
995        };
996        store.save_sampler_state(&state, None).unwrap();
997        let loaded_state = store.load_sampler_state().unwrap().unwrap();
998        assert_eq!(loaded_state.source_cycle_idx, 11);
999        assert_eq!(loaded_state.source_epoch, 8);
1000        assert_eq!(loaded_state.rng_state, 1234);
1001        assert_eq!(loaded_state.triplet_recipe_rr_idx, 2);
1002        assert_eq!(loaded_state.text_recipe_rr_idx, 5);
1003        assert_eq!(
1004            loaded_state.source_record_cursors,
1005            vec![("source_a".to_string(), 1)]
1006        );
1007        assert_eq!(
1008            loaded_state.source_stream_cursors,
1009            vec![("source_a".to_string(), 9)]
1010        );
1011
1012        // Verify the same data survives a drop + fresh open from disk.
1013        drop(store);
1014        let reopened = FileSplitStore::open(&path, SplitRatios::default(), 222).unwrap();
1015        let disk_state = reopened.load_sampler_state().unwrap().unwrap();
1016        assert_eq!(disk_state.source_cycle_idx, 11);
1017        assert_eq!(disk_state.source_epoch, 8);
1018        assert_eq!(disk_state.rng_state, 1234);
1019        let disk_meta = reopened.load_epoch_meta().unwrap();
1020        assert_eq!(disk_meta.get(&SplitLabel::Train).unwrap().epoch, 3);
1021        let disk_hashes = reopened
1022            .load_epoch_hashes(SplitLabel::Validation)
1023            .unwrap()
1024            .unwrap();
1025        assert_eq!(disk_hashes.checksum, 99);
1026    }
1027
1028    #[test]
1029    fn split_keys_and_labels_cover_helper_paths() {
1030        let key = split_key(&"abc".to_string());
1031        assert!(key.starts_with(SPLIT_PREFIX));
1032
1033        assert!(matches!(decode_label(b"0"), Ok(SplitLabel::Train)));
1034        assert!(matches!(decode_label(b"1"), Ok(SplitLabel::Validation)));
1035        assert!(matches!(decode_label(b"2"), Ok(SplitLabel::Test)));
1036        assert!(decode_label(b"x").is_err());
1037
1038        let epoch_meta_train = epoch_meta_key(SplitLabel::Train);
1039        let epoch_hashes_train = epoch_hashes_key(SplitLabel::Train);
1040        let epoch_hashes_test = epoch_hashes_key(SplitLabel::Test);
1041        assert!(epoch_meta_train.starts_with(EPOCH_META_PREFIX));
1042        assert!(epoch_hashes_train.starts_with(EPOCH_HASHES_PREFIX));
1043        assert!(epoch_hashes_test.starts_with(EPOCH_HASHES_PREFIX));
1044    }
1045
1046    #[test]
1047    fn encode_decode_store_meta_roundtrip_and_corrupt_prefix_error() {
1048        let meta = StoreMeta {
1049            version: STORE_VERSION,
1050            seed: 55,
1051            ratios: SplitRatios::default(),
1052        };
1053        let encoded = encode_store_meta(&meta);
1054        let decoded = decode_store_meta(&encoded).unwrap();
1055        assert_eq!(decoded.version, STORE_VERSION);
1056        assert_eq!(decoded.seed, 55);
1057
1058        let err = decode_store_meta(&[0x00, 0x01]).unwrap_err();
1059        assert!(matches!(
1060            err,
1061            SamplerError::SplitStore(msg) if msg.contains("missing expected prefix")
1062        ));
1063    }
1064
1065    #[test]
1066    fn epoch_and_sampler_decoders_cover_tombstone_and_version_mismatch() {
1067        assert!(decode_epoch_meta(&[]).unwrap().is_none());
1068        assert!(
1069            decode_epoch_meta(&[EPOCH_RECORD_TOMBSTONE])
1070                .unwrap()
1071                .is_none()
1072        );
1073        assert!(decode_epoch_hashes(&[]).unwrap().is_none());
1074        assert!(
1075            decode_epoch_hashes(&[EPOCH_RECORD_TOMBSTONE])
1076                .unwrap()
1077                .is_none()
1078        );
1079        assert!(decode_sampler_state(&[]).unwrap().is_none());
1080
1081        let meta_mismatch = decode_epoch_meta(&[EPOCH_META_RECORD_VERSION.wrapping_add(1), 1]);
1082        assert!(matches!(
1083            meta_mismatch,
1084            Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1085        ));
1086        let hashes_mismatch = decode_epoch_hashes(&[EPOCH_HASH_RECORD_VERSION.wrapping_add(1), 1]);
1087        assert!(matches!(
1088            hashes_mismatch,
1089            Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1090        ));
1091        let state_mismatch =
1092            decode_sampler_state(&[SAMPLER_STATE_RECORD_VERSION.wrapping_add(1), 1]);
1093        assert!(matches!(
1094            state_mismatch,
1095            Err(SamplerError::SplitStore(msg)) if msg.contains("version mismatch")
1096        ));
1097    }
1098
1099    #[test]
1100    fn split_store_trait_methods_and_path_helpers_are_exercised() {
1101        let dir = tempdir().unwrap();
1102        let file_path = dir.path().join("nested").join("store.bin");
1103        ensure_parent_dir(&file_path).unwrap();
1104        assert!(file_path.parent().unwrap().exists());
1105
1106        let existing_dir_path = coerce_store_path(dir.path().to_path_buf());
1107        assert_eq!(existing_dir_path, dir.path().to_path_buf());
1108
1109        let ratios = SplitRatios::default();
1110        let store = FileSplitStore::open(&file_path, ratios, 444).unwrap();
1111        assert!((store.ratios().train - ratios.train).abs() < 1e-6);
1112        store
1113            .upsert("record_1".to_string(), SplitLabel::Validation)
1114            .unwrap();
1115        let ensured = store.ensure("record_1".to_string()).unwrap();
1116        assert!(matches!(
1117            ensured,
1118            SplitLabel::Train | SplitLabel::Validation | SplitLabel::Test
1119        ));
1120
1121        let mapped = map_store_err(io::Error::other("boom"));
1122        assert!(matches!(mapped, SamplerError::SplitStore(msg) if msg.contains("boom")));
1123    }
1124
1125    #[test]
1126    fn epoch_and_sampler_encode_decode_roundtrips() {
1127        let meta = PersistedSplitMeta {
1128            epoch: 4,
1129            offset: 9,
1130            hashes_checksum: 21,
1131        };
1132        let encoded_meta = encode_epoch_meta(Some(&meta));
1133        let decoded_meta = decode_epoch_meta(&encoded_meta).unwrap().unwrap();
1134        assert_eq!(decoded_meta.epoch, 4);
1135        assert_eq!(decoded_meta.offset, 9);
1136
1137        let hashes = PersistedSplitHashes {
1138            checksum: 7,
1139            hashes: vec![1, 2, 3],
1140        };
1141        let encoded_hashes = encode_epoch_hashes(&hashes);
1142        let decoded_hashes = decode_epoch_hashes(&encoded_hashes).unwrap().unwrap();
1143        assert_eq!(decoded_hashes.checksum, 7);
1144        assert_eq!(decoded_hashes.hashes, vec![1, 2, 3]);
1145
1146        let state = PersistedSamplerState {
1147            source_cycle_idx: 1,
1148            source_record_cursors: vec![("s".to_string(), 2)],
1149            source_epoch: 3,
1150            rng_state: 4,
1151            triplet_recipe_rr_idx: 5,
1152            text_recipe_rr_idx: 6,
1153            source_stream_cursors: vec![("s".to_string(), 7)],
1154        };
1155        let encoded_state = encode_sampler_state(&state);
1156        let decoded_state = decode_sampler_state(&encoded_state).unwrap().unwrap();
1157        assert_eq!(decoded_state.source_cycle_idx, 1);
1158        assert_eq!(decoded_state.source_epoch, 3);
1159        assert_eq!(decoded_state.rng_state, 4);
1160    }
1161
1162    #[test]
1163    fn deterministic_store_trait_methods_work() {
1164        let ratios = SplitRatios::default();
1165        let store = DeterministicSplitStore::new(ratios, 999).unwrap();
1166
1167        assert_eq!(store.ratios().train, ratios.train);
1168
1169        let id = "source::record".to_string();
1170        let derived = store.label_for(&id).unwrap();
1171        store.upsert(id.clone(), SplitLabel::Validation).unwrap();
1172        assert_eq!(store.label_for(&id), Some(SplitLabel::Validation));
1173        assert!(matches!(
1174            derived,
1175            SplitLabel::Train | SplitLabel::Validation | SplitLabel::Test
1176        ));
1177
1178        let mut meta = HashMap::new();
1179        meta.insert(
1180            SplitLabel::Test,
1181            PersistedSplitMeta {
1182                epoch: 1,
1183                offset: 2,
1184                hashes_checksum: 3,
1185            },
1186        );
1187        store.save_epoch_meta(&meta).unwrap();
1188        let loaded_meta = store.load_epoch_meta().unwrap();
1189        assert_eq!(loaded_meta.get(&SplitLabel::Test).unwrap().offset, 2);
1190
1191        assert!(
1192            store
1193                .load_epoch_hashes(SplitLabel::Train)
1194                .unwrap()
1195                .is_none()
1196        );
1197        store
1198            .save_epoch_hashes(
1199                SplitLabel::Train,
1200                &PersistedSplitHashes {
1201                    checksum: 11,
1202                    hashes: vec![4, 5],
1203                },
1204            )
1205            .unwrap();
1206        assert_eq!(
1207            store
1208                .load_epoch_hashes(SplitLabel::Train)
1209                .unwrap()
1210                .unwrap()
1211                .checksum,
1212            11
1213        );
1214
1215        assert!(store.load_sampler_state().unwrap().is_none());
1216        let sampler_state = PersistedSamplerState {
1217            source_cycle_idx: 1,
1218            source_record_cursors: vec![("s1".to_string(), 2)],
1219            source_epoch: 3,
1220            rng_state: 4,
1221            triplet_recipe_rr_idx: 5,
1222            text_recipe_rr_idx: 6,
1223            source_stream_cursors: vec![("s1".to_string(), 7)],
1224        };
1225        store.save_sampler_state(&sampler_state, None).unwrap();
1226        assert_eq!(
1227            store.load_sampler_state().unwrap().unwrap().source_epoch,
1228            sampler_state.source_epoch
1229        );
1230    }
1231
1232    #[test]
1233    fn open_with_load_path_bootstraps_state_explicitly() {
1234        let dir = tempdir().unwrap();
1235        let path_a = dir.path().join("snapshot_a.bin");
1236        let path_b = dir.path().join("snapshot_b.bin");
1237        let ratios = SplitRatios::default();
1238
1239        let store_a = FileSplitStore::open(&path_a, ratios, 42).unwrap();
1240
1241        let mut meta = HashMap::new();
1242        meta.insert(
1243            SplitLabel::Train,
1244            PersistedSplitMeta {
1245                epoch: 5,
1246                offset: 3,
1247                hashes_checksum: 999,
1248            },
1249        );
1250        store_a.save_epoch_meta(&meta).unwrap();
1251
1252        let sampler_state = PersistedSamplerState {
1253            source_cycle_idx: 1,
1254            source_record_cursors: vec![("s1".to_string(), 2)],
1255            source_epoch: 7,
1256            rng_state: 123,
1257            triplet_recipe_rr_idx: 4,
1258            text_recipe_rr_idx: 6,
1259            source_stream_cursors: vec![("s1".to_string(), 8)],
1260        };
1261        store_a.save_sampler_state(&sampler_state, None).unwrap();
1262        drop(store_a);
1263
1264        let store_b =
1265            FileSplitStore::open_with_load_path(Some(path_a.clone()), &path_b, ratios, 42).unwrap();
1266        assert_eq!(
1267            store_b
1268                .load_epoch_meta()
1269                .unwrap()
1270                .get(&SplitLabel::Train)
1271                .unwrap()
1272                .epoch,
1273            5
1274        );
1275        assert_eq!(
1276            store_b.load_sampler_state().unwrap().unwrap().source_epoch,
1277            7
1278        );
1279
1280        let store_a_again = FileSplitStore::open(&path_a, ratios, 42).unwrap();
1281        assert_eq!(
1282            store_a_again
1283                .load_epoch_meta()
1284                .unwrap()
1285                .get(&SplitLabel::Train)
1286                .unwrap()
1287                .epoch,
1288            5
1289        );
1290        assert_eq!(
1291            store_a_again
1292                .load_sampler_state()
1293                .unwrap()
1294                .unwrap()
1295                .source_epoch,
1296            7
1297        );
1298    }
1299
1300    #[test]
1301    fn save_sampler_state_to_new_path_copies_existing_store_first() {
1302        let dir = tempdir().unwrap();
1303        let path_a = dir.path().join("source_store.bin");
1304        let path_b = dir.path().join("mirror_store.bin");
1305        let ratios = SplitRatios::default();
1306
1307        let store_a = FileSplitStore::open(&path_a, ratios, 42).unwrap();
1308
1309        let assigned_id = "record_with_assignment".to_string();
1310        let assigned_key = split_key(&assigned_id);
1311        store_a.store.write(&assigned_key, b"1").unwrap();
1312
1313        let sampler_state = PersistedSamplerState {
1314            source_cycle_idx: 1,
1315            source_record_cursors: vec![("s1".to_string(), 2)],
1316            source_epoch: 9,
1317            rng_state: 123,
1318            triplet_recipe_rr_idx: 4,
1319            text_recipe_rr_idx: 6,
1320            source_stream_cursors: vec![("s1".to_string(), 8)],
1321        };
1322
1323        store_a
1324            .save_sampler_state(&sampler_state, Some(path_b.as_path()))
1325            .unwrap();
1326
1327        // Destination gets the existing store data AND the new sampler state.
1328        let store_b = FileSplitStore::open(&path_b, ratios, 42).unwrap();
1329        assert_eq!(
1330            store_b.label_for(&assigned_id),
1331            Some(SplitLabel::Validation)
1332        );
1333        assert_eq!(
1334            store_b.load_sampler_state().unwrap().unwrap().source_epoch,
1335            9
1336        );
1337
1338        // save_to=Some(path_b) must not publish to path_a (the canonical save_path).
1339        assert!(
1340            !path_a.exists(),
1341            "save_to=Some(...) must not publish to the canonical save_path"
1342        );
1343    }
1344
1345    /// Saving to a custom path on a plain `open` store must not mutate the
1346    /// working store file -- i.e. a previously-saved state is preserved.
1347    #[test]
1348    fn save_some_on_regular_open_does_not_modify_working_store() {
1349        let dir = tempdir().unwrap();
1350        let path_a = dir.path().join("working_store.bin");
1351        let path_b = dir.path().join("checkpoint_store.bin");
1352        let ratios = SplitRatios::default();
1353
1354        let store_a = FileSplitStore::open(&path_a, ratios, 42).unwrap();
1355
1356        // Establish a baseline state in the working store.
1357        let initial_state = PersistedSamplerState {
1358            source_cycle_idx: 1,
1359            source_record_cursors: vec![],
1360            source_epoch: 1,
1361            rng_state: 0,
1362            triplet_recipe_rr_idx: 0,
1363            text_recipe_rr_idx: 0,
1364            source_stream_cursors: vec![],
1365        };
1366        store_a.save_sampler_state(&initial_state, None).unwrap();
1367
1368        // Snapshot a newer state to a separate checkpoint path.
1369        let checkpoint_state = PersistedSamplerState {
1370            source_cycle_idx: 99,
1371            source_record_cursors: vec![],
1372            source_epoch: 99,
1373            rng_state: 42,
1374            triplet_recipe_rr_idx: 0,
1375            text_recipe_rr_idx: 0,
1376            source_stream_cursors: vec![],
1377        };
1378        store_a
1379            .save_sampler_state(&checkpoint_state, Some(path_b.as_path()))
1380            .unwrap();
1381
1382        // The on-disk save_path (path_a) must not have been overwritten by save_to=Some(...).
1383        // Re-open from disk to verify; path_a was last published by save_to=None above.
1384        drop(store_a);
1385        let store_a_disk = FileSplitStore::open(&path_a, ratios, 42).unwrap();
1386        assert_eq!(
1387            store_a_disk
1388                .load_sampler_state()
1389                .unwrap()
1390                .unwrap()
1391                .source_epoch,
1392            1,
1393            "save_to=Some(...) must not overwrite the on-disk save_path"
1394        );
1395
1396        // The checkpoint must hold the new state.
1397        let store_b = FileSplitStore::open(&path_b, ratios, 42).unwrap();
1398        let state_from_b = store_b.load_sampler_state().unwrap().unwrap();
1399        assert_eq!(
1400            state_from_b.source_epoch, 99,
1401            "checkpoint store must hold the snapshotted state"
1402        );
1403    }
1404
1405    #[test]
1406    fn file_store_metadata_mismatch_and_debug_paths_are_covered() {
1407        let dir = tempdir().unwrap();
1408        let path = dir.path().join("meta_mismatch.bin");
1409        let ratios = SplitRatios::default();
1410        let store = FileSplitStore::open(&path, ratios, 123).unwrap();
1411
1412        let debug_repr = format!("{store:?}");
1413        assert!(debug_repr.contains("FileSplitStore"));
1414
1415        let wrong_version = StoreMeta {
1416            version: STORE_VERSION.wrapping_add(1),
1417            seed: 123,
1418            ratios,
1419        };
1420        let payload = encode_store_meta(&wrong_version);
1421        store.store.write(META_KEY, &payload).unwrap();
1422        // Publish the corrupted temp to disk so the next open reads the bad version.
1423        store
1424            .save_sampler_state(
1425                &PersistedSamplerState {
1426                    source_cycle_idx: 0,
1427                    source_record_cursors: vec![],
1428                    source_epoch: 0,
1429                    rng_state: 0,
1430                    triplet_recipe_rr_idx: 0,
1431                    text_recipe_rr_idx: 0,
1432                    source_stream_cursors: vec![],
1433                },
1434                None,
1435            )
1436            .unwrap();
1437        drop(store);
1438
1439        let err = FileSplitStore::open(&path, ratios, 123).unwrap_err();
1440        assert!(matches!(err, SamplerError::SplitStore(msg) if msg.contains("version mismatch")));
1441
1442        let ratio_path = dir.path().join("ratio_mismatch.bin");
1443        let baseline = FileSplitStore::open(&ratio_path, ratios, 777).unwrap();
1444        // Publish so ratio_path exists on disk before reopening with different ratios.
1445        baseline
1446            .save_sampler_state(
1447                &PersistedSamplerState {
1448                    source_cycle_idx: 0,
1449                    source_record_cursors: vec![],
1450                    source_epoch: 0,
1451                    rng_state: 0,
1452                    triplet_recipe_rr_idx: 0,
1453                    text_recipe_rr_idx: 0,
1454                    source_stream_cursors: vec![],
1455                },
1456                None,
1457            )
1458            .unwrap();
1459        drop(baseline);
1460
1461        let different_ratios = SplitRatios {
1462            train: 0.7,
1463            validation: 0.2,
1464            test: 0.1,
1465        };
1466        let err = FileSplitStore::open(&ratio_path, different_ratios, 777).unwrap_err();
1467        assert!(matches!(err, SamplerError::SplitStore(msg) if msg.contains("ratios mismatch")));
1468    }
1469
1470    #[test]
1471    fn split_decode_helpers_reject_corrupt_bitcode_payloads() {
1472        let store_meta_err = decode_store_meta(&[BITCODE_PREFIX, 0xFF, 0xEE]).unwrap_err();
1473        assert!(matches!(
1474            store_meta_err,
1475            SamplerError::SplitStore(msg) if msg.contains("failed to decode split store metadata")
1476        ));
1477
1478        let epoch_meta_err =
1479            decode_epoch_meta(&[EPOCH_META_RECORD_VERSION, BITCODE_PREFIX, 0xFF]).unwrap_err();
1480        assert!(
1481            matches!(epoch_meta_err, SamplerError::SplitStore(msg) if msg.contains("corrupt epoch meta record"))
1482        );
1483
1484        let epoch_hashes_err =
1485            decode_epoch_hashes(&[EPOCH_HASH_RECORD_VERSION, BITCODE_PREFIX, 0xFF]).unwrap_err();
1486        assert!(matches!(
1487            epoch_hashes_err,
1488            SamplerError::SplitStore(msg) if msg.contains("corrupt epoch hashes record")
1489        ));
1490
1491        let sampler_state_err =
1492            decode_sampler_state(&[SAMPLER_STATE_RECORD_VERSION, BITCODE_PREFIX, 0xFF])
1493                .unwrap_err();
1494        assert!(matches!(
1495            sampler_state_err,
1496            SamplerError::SplitStore(msg) if msg.contains("corrupt sampler state record")
1497        ));
1498    }
1499
1500    #[test]
1501    fn file_store_label_fallback_and_validation_keys_are_covered() {
1502        let dir = tempdir().unwrap();
1503        let path = dir.path().join("labels.bin");
1504        let store = FileSplitStore::open(&path, SplitRatios::default(), 42).unwrap();
1505
1506        let id = "bad_label_record".to_string();
1507        let expected = derive_label_for_id(&id, 42, SplitRatios::default());
1508        let key = split_key(&id);
1509
1510        store.store.write(&key, b"x").unwrap();
1511        assert_eq!(store.label_for(&id), Some(expected));
1512
1513        store.store.write(&key, b"1").unwrap();
1514        assert_eq!(store.label_for(&id), Some(SplitLabel::Validation));
1515
1516        let meta_validation = epoch_meta_key(SplitLabel::Validation);
1517        let hashes_validation = epoch_hashes_key(SplitLabel::Validation);
1518        assert!(meta_validation.starts_with(EPOCH_META_PREFIX));
1519        assert!(hashes_validation.starts_with(EPOCH_HASHES_PREFIX));
1520        assert!(meta_validation.ends_with(b"validation"));
1521        assert!(hashes_validation.ends_with(b"validation"));
1522    }
1523
1524    #[test]
1525    fn ensure_parent_dir_allows_plain_file_names() {
1526        ensure_parent_dir(Path::new("split_store_local.bin")).unwrap();
1527        let coerced = coerce_store_path(PathBuf::from("explicit_store.bin"));
1528        assert_eq!(coerced, PathBuf::from("explicit_store.bin"));
1529    }
1530
1531    // -----------------------------------------------------------------------
1532    // Temp-dir bootstrap contract
1533    // -----------------------------------------------------------------------
1534
1535    /// `open_with_load_path` must NOT modify the source file while the store is open.
1536    #[test]
1537    fn load_path_source_is_never_modified_while_open() {
1538        let dir = tempdir().unwrap();
1539        let source = dir.path().join("source.bin");
1540        let dest = dir.path().join("dest.bin");
1541        let ratios = SplitRatios::default();
1542
1543        // Seed the source with known state.
1544        let seeded = FileSplitStore::open(&source, ratios, 77).unwrap();
1545        let state = PersistedSamplerState {
1546            source_cycle_idx: 5,
1547            source_record_cursors: vec![("s".to_string(), 3)],
1548            source_epoch: 9,
1549            rng_state: 42,
1550            triplet_recipe_rr_idx: 1,
1551            text_recipe_rr_idx: 2,
1552            source_stream_cursors: vec![("s".to_string(), 4)],
1553        };
1554        seeded.save_sampler_state(&state, None).unwrap();
1555        drop(seeded);
1556
1557        let source_size_before = std::fs::metadata(&source).unwrap().len();
1558
1559        // Open bootstrapped store and write mutations to it.
1560        let bootstrapped =
1561            FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 77).unwrap();
1562        let new_state = PersistedSamplerState {
1563            source_cycle_idx: 99,
1564            source_record_cursors: vec![("s".to_string(), 77)],
1565            source_epoch: 100,
1566            rng_state: 0,
1567            triplet_recipe_rr_idx: 0,
1568            text_recipe_rr_idx: 0,
1569            source_stream_cursors: vec![],
1570        };
1571        bootstrapped.save_sampler_state(&new_state, None).unwrap();
1572        drop(bootstrapped);
1573
1574        // Source must be byte-identical to before bootstrap.
1575        let source_size_after = std::fs::metadata(&source).unwrap().len();
1576        assert_eq!(
1577            source_size_before, source_size_after,
1578            "source file was modified during bootstrapped open"
1579        );
1580
1581        // Verify source still holds the original state.
1582        let verify_source = FileSplitStore::open(&source, ratios, 77).unwrap();
1583        let loaded = verify_source.load_sampler_state().unwrap().unwrap();
1584        assert_eq!(loaded.source_cycle_idx, 5);
1585        assert_eq!(loaded.source_epoch, 9);
1586    }
1587
1588    /// `save_sampler_state(None)` on a bootstrapped store must publish to the
1589    /// declared `save_path` only, not to the source and not to a temp path.
1590    #[test]
1591    fn save_none_on_bootstrapped_store_publishes_to_save_path() {
1592        let dir = tempdir().unwrap();
1593        let source = dir.path().join("load.bin");
1594        let dest = dir.path().join("save.bin");
1595        let ratios = SplitRatios::default();
1596
1597        let _ = FileSplitStore::open(&source, ratios, 11).unwrap();
1598
1599        assert!(!dest.exists(), "dest must not exist before first save");
1600
1601        let store = FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 11).unwrap();
1602        let state = PersistedSamplerState {
1603            source_cycle_idx: 7,
1604            source_record_cursors: vec![],
1605            source_epoch: 2,
1606            rng_state: 1,
1607            triplet_recipe_rr_idx: 0,
1608            text_recipe_rr_idx: 0,
1609            source_stream_cursors: vec![],
1610        };
1611        store.save_sampler_state(&state, None).unwrap();
1612        drop(store);
1613
1614        assert!(dest.exists(), "dest must exist after save(None)");
1615
1616        let loaded_dest = FileSplitStore::open(&dest, ratios, 11).unwrap();
1617        assert_eq!(
1618            loaded_dest
1619                .load_sampler_state()
1620                .unwrap()
1621                .unwrap()
1622                .source_cycle_idx,
1623            7
1624        );
1625    }
1626
1627    /// `save_sampler_state(Some(other))` on a bootstrapped store must publish to
1628    /// `other` only — the declared `save_path` must remain absent.
1629    #[test]
1630    fn save_some_on_bootstrapped_store_publishes_to_explicit_path_only() {
1631        let dir = tempdir().unwrap();
1632        let source = dir.path().join("load.bin");
1633        let save = dir.path().join("save.bin"); // canonical — should stay empty
1634        let other = dir.path().join("other.bin"); // explicit target
1635        let ratios = SplitRatios::default();
1636
1637        let _ = FileSplitStore::open(&source, ratios, 22).unwrap();
1638
1639        let store = FileSplitStore::open_with_load_path(Some(&source), &save, ratios, 22).unwrap();
1640        let state = PersistedSamplerState {
1641            source_cycle_idx: 3,
1642            source_record_cursors: vec![],
1643            source_epoch: 1,
1644            rng_state: 0,
1645            triplet_recipe_rr_idx: 0,
1646            text_recipe_rr_idx: 0,
1647            source_stream_cursors: vec![],
1648        };
1649        store
1650            .save_sampler_state(&state, Some(other.as_path()))
1651            .unwrap();
1652        drop(store);
1653
1654        assert!(
1655            !save.exists(),
1656            "canonical save_path must not be created when saving to explicit path"
1657        );
1658        assert!(other.exists(), "explicit target must be created");
1659
1660        let loaded_other = FileSplitStore::open(&other, ratios, 22).unwrap();
1661        assert_eq!(
1662            loaded_other
1663                .load_sampler_state()
1664                .unwrap()
1665                .unwrap()
1666                .source_cycle_idx,
1667            3
1668        );
1669    }
1670
1671    /// Repeated `save(None)` calls on a bootstrapped store are idempotent: the
1672    /// canonical save_path is overwritten cleanly each time.
1673    #[test]
1674    fn repeated_save_none_on_bootstrapped_store_is_idempotent() {
1675        let dir = tempdir().unwrap();
1676        let source = dir.path().join("load.bin");
1677        let dest = dir.path().join("save.bin");
1678        let ratios = SplitRatios::default();
1679
1680        let _ = FileSplitStore::open(&source, ratios, 33).unwrap();
1681
1682        let store = FileSplitStore::open_with_load_path(Some(&source), &dest, ratios, 33).unwrap();
1683
1684        for cycle_idx in [1_u64, 2, 3] {
1685            let state = PersistedSamplerState {
1686                source_cycle_idx: cycle_idx,
1687                source_record_cursors: vec![],
1688                source_epoch: 0,
1689                rng_state: 0,
1690                triplet_recipe_rr_idx: 0,
1691                text_recipe_rr_idx: 0,
1692                source_stream_cursors: vec![],
1693            };
1694            store.save_sampler_state(&state, None).unwrap();
1695        }
1696        drop(store);
1697
1698        let reloaded = FileSplitStore::open(&dest, ratios, 33).unwrap();
1699        assert_eq!(
1700            reloaded
1701                .load_sampler_state()
1702                .unwrap()
1703                .unwrap()
1704                .source_cycle_idx,
1705            3,
1706            "dest should hold the last-saved state"
1707        );
1708    }
1709}