1mod backends;
2
3use crate::chunking::{ChunkingAlgorithm, SlidingWindowChunker};
4use chrono::Duration;
5use indexmap::IndexMap;
6use rand::prelude::*;
7use rayon::prelude::*;
8use std::borrow::Cow;
9use std::collections::{HashMap, HashSet};
10use std::path::Path;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::sync::mpsc;
13use std::sync::{Arc, Mutex};
14use std::thread;
15
16use crate::config::{
17 ChunkingStrategy, NegativeStrategy, SamplerConfig, Selector, TextRecipe, TripletRecipe,
18};
19use crate::constants::sampler::AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME;
20use crate::constants::sampler::{
21 ANCHOR_POSITIVE_SWAP_MASK, EPOCH_SEED_OFFSET, EXHAUSTION_RETRY_LIMIT, NEG_REASON_WRONG_ARTICLE,
22 NEG_REASON_WRONG_DATE, NEG_REASON_WRONG_QA, PREFETCHER_SOURCE_ID, PREFETCHER_STOPPED_REASON,
23 RECIPE_LABEL_TEXT, RECIPE_LABEL_TRIPLETS, RECIPE_ORDER_MAX_WEIGHT_MULTIPLIER,
24 ROLE_LABEL_ANCHOR, ROLE_LABEL_CONTEXT, SAME_SELECTOR_PAIR_RETRY_LIMIT,
25};
26use crate::data::{
27 ChunkView, DataRecord, PairLabel, RecordChunk, RecordSection, SampleBatch, SamplePair,
28 SampleTriplet, SectionRole, TextBatch, TextSample, TripletBatch,
29};
30use crate::epoch::EpochTracker;
31use crate::errors::SamplerError;
32use crate::hash::{derive_epoch_seed, stable_hash_str};
33use crate::ingestion::IngestionManager;
34use crate::metadata::{META_FIELD_DATE, MetadataKey};
35use crate::metrics::{chunk_proximity_score, window_index_proximity};
36use crate::source::DataSource;
37use crate::splits::{
38 EpochStateStore, PersistedSamplerState, SamplerStateStore, SplitLabel, SplitStore,
39};
40use crate::tokenizer::{Tokenizer, WhitespaceTokenizer};
41use crate::types::{RecipeKey, RecordId, SourceId};
42use crate::utils::platform_newline;
43
/// Deterministic SplitMix64-style pseudo-random generator.
///
/// The whole generator is a single `u64` counter, which makes it trivial to
/// snapshot (`state`) and restore (`from_state`) for resumable sampling.
#[derive(Debug, Clone)]
struct DeterministicRng {
    state: u64,
}

impl DeterministicRng {
    /// Creates a generator seeded with `seed`.
    fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Restores a generator from a previously captured raw state.
    fn from_state(state: u64) -> Self {
        Self { state }
    }

    /// Current raw state, suitable for persistence.
    fn state(&self) -> u64 {
        self.state
    }

    /// One SplitMix64 step: advance the counter by the golden-ratio
    /// increment, then scramble it with two multiply-xorshift rounds.
    fn next_u64_internal(&mut self) -> u64 {
        self.state = self.state.wrapping_add(0x9E3779B97F4A7C15);
        let mut mixed = self.state;
        mixed = (mixed ^ (mixed >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
        mixed = (mixed ^ (mixed >> 27)).wrapping_mul(0x94D049BB133111EB);
        mixed ^ (mixed >> 31)
    }
}
83
84impl rand::RngCore for DeterministicRng {
85 fn next_u32(&mut self) -> u32 {
86 self.next_u64_internal() as u32
87 }
88
89 fn next_u64(&mut self) -> u64 {
90 self.next_u64_internal()
91 }
92
93 fn fill_bytes(&mut self, dest: &mut [u8]) {
94 let mut offset = 0;
95 while offset < dest.len() {
96 let value = self.next_u64_internal();
97 let bytes = value.to_le_bytes();
98 let remaining = dest.len() - offset;
99 let copy_len = remaining.min(bytes.len());
100 dest[offset..offset + copy_len].copy_from_slice(&bytes[..copy_len]);
101 offset += copy_len;
102 }
103 }
104}
105
106pub fn chunk_weight(strategy: &ChunkingStrategy, chunk: &RecordChunk) -> f32 {
115 let floor = strategy.chunk_weight_floor;
116 let trust = chunk.quality.trust.clamp(0.0, 1.0);
117 let base = match &chunk.view {
118 ChunkView::Window { index, .. } => window_index_proximity(*index),
119 ChunkView::SummaryFallback { weight, .. } => *weight,
120 };
121 (base * trust).max(floor)
122}
123
/// Batch-producing interface implemented by samplers.
///
/// The plain `next_*_batch` methods are convenience wrappers that delegate
/// to the corresponding `*_with_weights` variant with an empty weight map.
pub trait Sampler {
    /// Next pair batch for `split`, with no per-source weighting.
    fn next_pair_batch(&self, split: SplitLabel) -> Result<SampleBatch, SamplerError> {
        self.next_pair_batch_with_weights(split, &HashMap::new())
    }
    /// Next text batch for `split`, with no per-source weighting.
    fn next_text_batch(&self, split: SplitLabel) -> Result<TextBatch, SamplerError> {
        self.next_text_batch_with_weights(split, &HashMap::new())
    }
    /// Next triplet batch for `split`, with no per-source weighting.
    fn next_triplet_batch(&self, split: SplitLabel) -> Result<TripletBatch, SamplerError> {
        self.next_triplet_batch_with_weights(split, &HashMap::new())
    }
    /// Next pair batch, taking explicit per-source weights
    /// (weight semantics are defined by the implementation).
    fn next_pair_batch_with_weights(
        &self,
        split: SplitLabel,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<SampleBatch, SamplerError>;
    /// Next text batch, taking explicit per-source weights.
    fn next_text_batch_with_weights(
        &self,
        split: SplitLabel,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<TextBatch, SamplerError>;
    /// Next triplet batch, taking explicit per-source weights.
    fn next_triplet_batch_with_weights(
        &self,
        split: SplitLabel,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<TripletBatch, SamplerError>;
}
157
/// Handle to a background batch-producer thread and the channel it feeds.
pub struct BatchPrefetcher<T> {
    // `None` after shutdown; dropping the receiver disconnects the channel.
    receiver: Option<mpsc::Receiver<Result<T, SamplerError>>>,
    // Producer thread handle; taken and joined in `Drop`.
    handle: Option<thread::JoinHandle<()>>,
    // Counters shared with the producer thread.
    stats: Arc<PrefetcherStats>,
}
164
/// Counters updated by the producer thread and read by consumers.
#[derive(Default)]
struct PrefetcherStats {
    // Approximate number of results currently buffered in the channel.
    queued: AtomicUsize,
    // Total results successfully sent (both `Ok` and `Err` results count).
    produced: AtomicUsize,
    // Total producer invocations that returned an error.
    errors: AtomicUsize,
}
172
impl<T: Send + 'static> BatchPrefetcher<T> {
    /// Spawns a producer thread feeding a bounded channel (capacity at
    /// least 1). Both `Ok` and `Err` results are forwarded to consumers;
    /// errors are additionally counted in `stats.errors`.
    fn new<F>(capacity: usize, mut producer: F) -> Self
    where
        F: FnMut() -> Result<T, SamplerError> + Send + 'static,
    {
        let (sender, receiver) = mpsc::sync_channel(capacity.max(1));
        let stats = Arc::new(PrefetcherStats::default());
        let stats_thread = Arc::clone(&stats);
        let handle = thread::spawn(move || {
            loop {
                let result = producer();
                if result.is_err() {
                    stats_thread.errors.fetch_add(1, Ordering::Relaxed);
                }
                // Send fails only once the receiver is dropped; stop then.
                if sender.send(result).is_err() {
                    return;
                }
                // Counters are bumped after the (blocking) send succeeds,
                // so `queued` is only an approximation of channel depth.
                stats_thread.queued.fetch_add(1, Ordering::Relaxed);
                stats_thread.produced.fetch_add(1, Ordering::Relaxed);
            }
        });
        Self {
            receiver: Some(receiver),
            handle: Some(handle),
            stats,
        }
    }

    /// Blocks until the next prefetched result is available.
    ///
    /// Returns `SamplerError::SourceUnavailable` when the prefetcher has
    /// been shut down or the producer thread has terminated.
    pub fn next(&self) -> Result<T, SamplerError> {
        let receiver = self
            .receiver
            .as_ref()
            .ok_or_else(|| SamplerError::SourceUnavailable {
                source_id: PREFETCHER_SOURCE_ID.into(),
                reason: PREFETCHER_STOPPED_REASON.into(),
            })?;
        let result = receiver.recv().unwrap_or_else(|_| {
            Err(SamplerError::SourceUnavailable {
                source_id: PREFETCHER_SOURCE_ID.into(),
                reason: PREFETCHER_STOPPED_REASON.into(),
            })
        });
        // Saturating decrement: the producer increments only after its send
        // succeeds, so a consumer may receive an item before the matching
        // increment lands — never let the counter underflow.
        self.stats
            .queued
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |value| {
                Some(value.saturating_sub(1))
            })
            .ok();
        result
    }

    /// Approximate number of results currently buffered.
    pub fn queue_len(&self) -> usize {
        self.stats.queued.load(Ordering::Relaxed)
    }

    /// Total number of results delivered into the channel so far.
    pub fn produced_count(&self) -> usize {
        self.stats.produced.load(Ordering::Relaxed)
    }

    /// Total number of producer calls that returned an error.
    pub fn error_count(&self) -> usize {
        self.stats.errors.load(Ordering::Relaxed)
    }
}
240
241impl<T> Drop for BatchPrefetcher<T> {
242 fn drop(&mut self) {
243 self.receiver.take();
244 if let Some(handle) = self.handle.take() {
245 let _ = handle.join();
246 }
247 }
248}
249
/// Public sampler type: all mutable sampling state is serialized behind a
/// single mutex around [`TripletSamplerInner`].
pub struct TripletSampler<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> {
    inner: Mutex<TripletSamplerInner<S>>,
}
255
/// All sampler state; accessed only through the `TripletSampler` mutex.
struct TripletSamplerInner<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> {
    // Static sampling configuration (seed, recipes, allowed splits, …).
    config: SamplerConfig,
    // Algorithm used to split record sections into chunks.
    chunker: Arc<dyn ChunkingAlgorithm>,
    // Persistence backend for splits, epoch state, and sampler state.
    split_store: Arc<S>,
    // Pulls records from registered data sources.
    ingestion: IngestionManager,
    // All known records in stable insertion order (indexable by position).
    records: IndexMap<RecordId, Arc<DataRecord>>,
    // Deterministic RNG; its raw state is persisted and restored.
    rng: DeterministicRng,
    // Config-supplied triplet recipes (empty when per-source defaults apply).
    triplet_recipes: Vec<TripletRecipe>,
    // Text recipes, either config-supplied or derived from triplet recipes.
    text_recipes: Vec<TextRecipe>,
    // Per-source default triplet recipes, used when the config has none.
    source_triplet_recipes: HashMap<SourceId, Vec<TripletRecipe>>,
    // Sources flagged as containing sections long enough to window-chunk.
    sources_with_long_sections: HashSet<SourceId>,
    // Per-source derived text recipes (parallels `source_triplet_recipes`).
    source_text_recipes: HashMap<SourceId, Vec<TextRecipe>>,
    // True when triplet recipes came from the config rather than sources.
    using_config_triplet_recipes: bool,
    // True when text recipes came from the config rather than derivation.
    using_config_text_recipes: bool,
    // NOTE(review): presumably the last ingestion generation observed —
    // not used in this chunk; confirm against the rest of the file.
    last_observed_ingest: u64,
    // Tracks per-split record visitation order across an epoch.
    epoch_tracker: EpochTracker,
    // Rotating cursor per (record, section index) into the chunk pool.
    chunk_cursors: HashMap<(RecordId, usize), usize>,
    // Rotating cursor per (record, role label).
    role_cursors: HashMap<(RecordId, String), usize>,
    // Pluggable negative-mining backend (BM25 when the feature is enabled).
    negative_backend: Box<dyn backends::NegativeBackend>,
    // Chunk id -> owning record id (currently an identity mapping).
    chunk_index: HashMap<RecordId, RecordId>,
    // Sorted list of sources that currently have eligible records.
    source_order: Vec<SourceId>,
    // Position within the current source cycle; persisted.
    source_cycle_idx: usize,
    // Whether persisted sampler state has been loaded this run.
    source_state_loaded: bool,
    // NOTE(review): flag for persisted ingestion cursors — not read in this
    // chunk; confirm against the rest of the file.
    ingestion_cursors_loaded: bool,
    // Whether in-memory source state differs from the persisted snapshot.
    source_state_dirty: bool,
    // Eligible record positions (into `records`) per source, shuffled
    // deterministically per epoch.
    source_record_indices: HashMap<SourceId, Vec<usize>>,
    // Monotonic per-source cursor into `source_record_indices`; persisted.
    source_record_cursors: HashMap<SourceId, usize>,
    // Round-robin offset into the weighted triplet-recipe order; persisted.
    triplet_recipe_rr_idx: usize,
    // NOTE(review): looks like per-record dedup hashes of emitted text
    // samples — not used in this chunk; confirm.
    emitted_text_hashes: HashMap<RecordId, HashSet<u64>>,
    // Round-robin offset into the weighted text-recipe order; persisted.
    text_recipe_rr_idx: usize,
    // Current epoch number; seeds all per-epoch shuffles.
    epoch: u64,

    // Per-source flag: has the source's cursor wrapped its record list this
    // epoch? When every source has wrapped, the epoch advances.
    source_wrapped: HashMap<SourceId, bool>,
}
329
330impl<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> TripletSamplerInner<S> {
331 fn new(config: SamplerConfig, split_store: Arc<S>) -> Self {
332 Self::new_with_chunker(config, split_store, Arc::new(SlidingWindowChunker))
333 }
334
    /// Builds the sampler with an explicit chunking algorithm.
    ///
    /// Recipe resolution: config-supplied recipes win; otherwise text
    /// recipes are derived from the triplet recipes (and re-derived at the
    /// bottom, and again later as sources register their own defaults).
    fn new_with_chunker(
        config: SamplerConfig,
        split_store: Arc<S>,
        chunker: Arc<dyn ChunkingAlgorithm>,
    ) -> Self {
        // Ingestion buffer holds at least one batch, and never fewer than 2.
        let buffer_size = config.ingestion_max_records.max(config.batch_size).max(2);
        let using_config_triplet_recipes = !config.recipes.is_empty();
        let using_config_text_recipes = !config.text_recipes.is_empty();
        let triplet_recipes = if using_config_triplet_recipes {
            config.recipes.clone()
        } else {
            Vec::new()
        };
        let text_recipes = if using_config_text_recipes {
            config.text_recipes.clone()
        } else if !triplet_recipes.is_empty() {
            Self::build_derived_text_recipes(&triplet_recipes)
        } else {
            Vec::new()
        };
        let ingestion = IngestionManager::new(buffer_size, config.clone());
        // Epoch state persists through the split store; the seed is offset
        // from the sampling seed via EPOCH_SEED_OFFSET.
        let epoch_backend = Some(Arc::clone(&split_store) as Arc<dyn EpochStateStore>);
        let epoch_tracker = EpochTracker::new(
            true,
            epoch_backend,
            derive_epoch_seed(config.seed, EPOCH_SEED_OFFSET),
        );
        let mut sampler = Self {
            rng: DeterministicRng::new(config.seed),
            config,
            chunker,
            split_store,
            ingestion,
            records: IndexMap::new(),
            triplet_recipes,
            text_recipes,
            source_triplet_recipes: HashMap::new(),
            sources_with_long_sections: HashSet::new(),
            source_text_recipes: HashMap::new(),
            using_config_triplet_recipes,
            using_config_text_recipes,
            last_observed_ingest: 0,
            epoch_tracker,
            chunk_cursors: HashMap::new(),
            role_cursors: HashMap::new(),
            // Backend selection is a compile-time feature switch.
            negative_backend: {
                #[cfg(feature = "bm25-mining")]
                {
                    Box::new(backends::Bm25Backend::new())
                }
                #[cfg(not(feature = "bm25-mining"))]
                {
                    Box::new(backends::DefaultBackend)
                }
            },
            chunk_index: HashMap::new(),
            source_order: Vec::new(),
            source_cycle_idx: 0,
            source_state_loaded: false,
            ingestion_cursors_loaded: false,
            source_state_dirty: false,
            source_record_indices: HashMap::new(),
            source_record_cursors: HashMap::new(),
            emitted_text_hashes: HashMap::new(),
            triplet_recipe_rr_idx: 0,
            text_recipe_rr_idx: 0,
            epoch: 0,
            source_wrapped: HashMap::new(),
        };
        if !sampler.using_config_text_recipes {
            sampler.rebuild_derived_text_recipes();
        }
        sampler
    }
409
410 fn text_recipes(&self) -> &[TextRecipe] {
411 &self.text_recipes
412 }
413
    /// Seed derived from the configured base seed and the current epoch,
    /// so per-epoch shuffles differ deterministically between epochs.
    fn epoch_seed(&self) -> u64 {
        derive_epoch_seed(self.config.seed, self.epoch)
    }
419
420 fn register_source(
421 &mut self,
422 source: Box<dyn DataSource + 'static>,
423 ) -> Result<(), SamplerError> {
424 let source_id = source.id().to_string();
425 if !self.using_config_triplet_recipes {
426 let triplets = source.default_triplet_recipes();
427 if !triplets.is_empty() {
428 self.source_triplet_recipes
429 .insert(source_id.clone(), triplets.clone());
430 if !self.using_config_text_recipes {
431 let derived = Self::build_derived_text_recipes(&triplets);
432 self.source_text_recipes
433 .insert(source_id.clone(), derived.clone());
434 self.extend_text_recipes_unique(&derived);
435 }
436 }
437 }
438 self.ingestion.register_source(source)?;
439 Ok(())
440 }
441
    /// Forces the sampler to `epoch`, resetting all per-epoch progress:
    /// ingestion stream cursors, per-source record cursors, the source
    /// cycle, and the per-source wrap flags. Mirrors `advance_epoch`, but
    /// takes an explicit epoch and propagates load/rebuild errors.
    fn set_epoch(&mut self, epoch: u64) -> Result<(), SamplerError> {
        self.epoch_tracker.ensure_loaded()?;
        self.epoch_tracker.force_epoch(epoch);
        self.epoch = epoch;
        self.ingestion.set_epoch(epoch);
        self.ingestion.reset_stream_cursors();
        self.ingestion.reset_epoch_step();
        self.source_record_cursors.clear();
        self.source_cycle_idx = 0;
        for source in &self.source_order {
            self.source_wrapped.insert(source.clone(), false);
        }
        // Re-shuffle per-source record orders under the new epoch seed.
        self.rebuild_source_index()?;
        // Only multi-source setups need their cycle state re-persisted.
        self.source_state_dirty = self.source_order.len() > 1;
        Ok(())
    }
460
461 fn next_chunk_from_pool(
462 &mut self,
463 record_id: &str,
464 section_idx: usize,
465 pool: Vec<RecordChunk>,
466 ) -> Option<RecordChunk> {
467 if pool.is_empty() {
468 return None;
469 }
470 let key = (record_id.to_string(), section_idx);
471 if !self.chunk_cursors.contains_key(&key) {
472 let cursor_key = format!("{}::{}", record_id, section_idx);
478 let start = (stable_hash_str(self.epoch_seed(), &cursor_key) as usize) % pool.len();
479 self.chunk_cursors.insert(key.clone(), start);
480 }
481 let cursor = self.chunk_cursors.entry(key).or_insert(0);
482 if *cursor >= pool.len() {
483 *cursor = 0;
484 }
485 let chunk = pool.get(*cursor).cloned();
486 *cursor = (*cursor + 1) % pool.len();
487 chunk
488 }
489
490 fn prune_cursor_state(&mut self) {
491 if self.chunk_cursors.is_empty()
492 && self.role_cursors.is_empty()
493 && self.negative_backend.cursors_empty()
494 {
495 return;
496 }
497 let valid_ids: HashSet<RecordId> = self.records.keys().cloned().collect();
498 self.chunk_cursors
499 .retain(|(record_id, _), _| valid_ids.contains(record_id));
500 self.role_cursors
501 .retain(|(record_id, _), _| valid_ids.contains(record_id));
502 self.negative_backend.prune_cursors(&valid_ids);
503 }
504
505 fn rebuild_chunk_index(&mut self) {
506 self.chunk_index.clear();
507 for record in self.records.values() {
508 self.chunk_index
509 .insert(record.id.clone(), record.id.clone());
510 }
511 }
512
513 fn rebuild_source_index(&mut self) -> Result<(), SamplerError> {
514 self.source_record_indices.clear();
515 let mut label_cache: HashMap<RecordId, SplitLabel> = HashMap::new();
516 let allowed = self.allowed_target_splits();
517 let allowed_set: HashSet<SplitLabel> = allowed.into_iter().collect();
518 for (idx, record) in self.records.values().enumerate() {
519 let label = if let Some(label) = label_cache.get(&record.id) {
520 *label
521 } else {
522 let label = match self.split_store.label_for(&record.id) {
523 Some(label) => label,
524 None => self.split_store.ensure(record.id.clone())?,
525 };
526 label_cache.insert(record.id.clone(), label);
527 label
528 };
529 if !allowed_set.contains(&label) {
530 continue;
531 }
532 self.source_record_indices
533 .entry(record.source.clone())
534 .or_default()
535 .push(idx);
536 }
537
538 let shuffle_seed = self.epoch_seed();
539 for indices in self.source_record_indices.values_mut() {
540 indices.sort_by_key(|idx| {
541 self.records
542 .get_index(*idx)
543 .map(|(_, record)| stable_hash_str(shuffle_seed, &record.id))
544 .unwrap_or(0)
545 });
546 }
547
548 self.source_order = self.source_record_indices.keys().cloned().collect();
549 self.source_order.sort();
550 self.refresh_source_wrapped();
551
552 self.source_record_cursors
553 .retain(|source, _| self.source_record_indices.contains_key(source));
554 if self.source_state_loaded {
555 if self.source_order.is_empty() {
556 self.source_cycle_idx = 0;
557 }
558 self.source_state_dirty = self.source_order.len() > 1;
559 }
560 Ok(())
561 }
562
563 fn refresh_source_wrapped(&mut self) {
564 self.source_wrapped.clear();
565 for source in &self.source_order {
566 let len = self
567 .source_record_indices
568 .get(source)
569 .map(|items| items.len())
570 .unwrap_or(0);
571 if len == 0 {
572 self.source_wrapped.insert(source.clone(), false);
573 continue;
574 }
575 let cursor = self.source_record_cursors.get(source).copied().unwrap_or(0);
576 let wrapped = cursor > 0 && cursor % len == 0;
577 self.source_wrapped.insert(source.clone(), wrapped);
578 }
579 }
580
581 fn shuffled_source_cycle(&self, cycle: u64) -> Vec<SourceId> {
582 let mut sources = self.source_order.clone();
583 let seed = self.epoch_seed() ^ cycle;
584 sources.sort_by_key(|source| stable_hash_str(seed, source));
585 sources
586 }
587
    /// Lazily restores persisted sampler state (cursors, epoch, RNG state,
    /// recipe round-robin offsets) on first use; idempotent afterwards.
    /// Cursors for sources that are no longer indexed are ignored.
    fn ensure_source_state(&mut self) -> Result<(), SamplerError> {
        if self.source_state_loaded {
            return Ok(());
        }
        let persisted = self.split_store.load_sampler_state()?;
        self.source_cycle_idx = persisted
            .as_ref()
            .map(|state| state.source_cycle_idx as usize)
            .unwrap_or(0);
        if let Some(state) = persisted {
            for (source, cursor) in state.source_record_cursors {
                // Only restore cursors for sources that still exist.
                if self.source_record_indices.contains_key(&source) {
                    self.source_record_cursors.insert(source, cursor as usize);
                }
            }
            self.epoch = state.epoch;
            self.ingestion.set_epoch(state.epoch);
            self.rng = DeterministicRng::from_state(state.rng_state);
            self.triplet_recipe_rr_idx = state.triplet_recipe_rr_idx as usize;
            self.text_recipe_rr_idx = state.text_recipe_rr_idx as usize;
        }
        // Wrap flags are derived state; recompute from restored cursors.
        self.refresh_source_wrapped();
        self.source_state_loaded = true;
        self.source_state_dirty = true;
        Ok(())
    }
614
    /// Snapshots the resumable sampler state and writes it via the split
    /// store (optionally to `save_to`). No-op until persisted state has
    /// been loaded, so a premature save cannot clobber stored state.
    fn persist_source_state(&mut self, save_to: Option<&Path>) -> Result<(), SamplerError> {
        if !self.source_state_loaded {
            return Ok(());
        }
        let state = PersistedSamplerState {
            source_cycle_idx: self.source_cycle_idx as u64,
            source_record_cursors: self
                .source_record_cursors
                .iter()
                .map(|(source, cursor)| (source.clone(), *cursor as u64))
                .collect(),
            epoch: self.epoch,
            epoch_step: self.ingestion.epoch_step(),
            rng_state: self.rng.state(),
            triplet_recipe_rr_idx: self.triplet_recipe_rr_idx as u64,
            text_recipe_rr_idx: self.text_recipe_rr_idx as u64,
            source_stream_cursors: self.ingestion.snapshot_cursors(),
        };
        self.split_store.save_sampler_state(&state, save_to)?;
        self.source_state_dirty = false;
        Ok(())
    }
637
638 fn rebuild_derived_text_recipes(&mut self) {
639 if self.using_config_text_recipes {
640 return;
641 }
642 if self.triplet_recipes.is_empty() {
643 self.text_recipes.clear();
644 } else {
645 self.text_recipes = Self::build_derived_text_recipes(&self.triplet_recipes);
646 }
647 }
648
649 fn extend_text_recipes_unique(&mut self, recipes: &[TextRecipe]) {
650 for recipe in recipes {
651 if self
652 .text_recipes
653 .iter()
654 .any(|existing| existing.name == recipe.name)
655 {
656 continue;
657 }
658 self.text_recipes.push(recipe.clone());
659 }
660 }
661
662 fn configured_triplet_recipes_for_source<'a>(&'a self, source: &str) -> &'a [TripletRecipe] {
663 if self.using_config_triplet_recipes {
664 return &self.triplet_recipes;
665 }
666 self.source_triplet_recipes
667 .get(source)
668 .map(|recipes| recipes.as_slice())
669 .unwrap_or(&[])
670 }
671
672 fn contains_auto_chunk_pair_recipe(recipes: &[TripletRecipe]) -> bool {
674 recipes
675 .iter()
676 .any(|recipe| recipe.name.as_ref() == AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME)
677 }
678
679 fn source_supports_chunk_pair_recipe(&self, source: &str) -> bool {
680 if self.config.chunking.max_window_tokens == 0 {
681 return false;
682 }
683 self.sources_with_long_sections.contains(source)
684 }
685
686 fn should_auto_inject_chunk_pair_recipe(
692 &self,
693 source: &str,
694 recipes: &[TripletRecipe],
695 ) -> bool {
696 self.source_supports_chunk_pair_recipe(source)
697 && !Self::contains_auto_chunk_pair_recipe(recipes)
698 }
699
    /// The recipe auto-injected for long-section sources: anchor, positive,
    /// and negative all draw context-role chunks, negatives come from a
    /// different article, and the anchor may not equal the positive.
    fn source_chunk_pair_recipe() -> TripletRecipe {
        TripletRecipe {
            name: Cow::Borrowed(AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME),
            anchor: Selector::Role(SectionRole::Context),
            positive_selector: Selector::Role(SectionRole::Context),
            negative_selector: Selector::Role(SectionRole::Context),
            negative_strategy: NegativeStrategy::WrongArticle,
            weight: 1.0,
            instruction: None,
            allow_same_anchor_positive: false,
        }
    }
722
723 fn resolve_source_triplet_plan(&self, source: &str) -> (Vec<TripletRecipe>, bool) {
735 let mut recipes = self.configured_triplet_recipes_for_source(source).to_vec();
736 let mut auto_injected = false;
737 if self.should_auto_inject_chunk_pair_recipe(source, &recipes) {
738 recipes.push(Self::source_chunk_pair_recipe());
739 auto_injected = true;
740 }
741 (recipes, auto_injected)
742 }
743
    #[cfg(test)]
    /// Test helper: the fully resolved recipe list for `source`, including
    /// any auto-injected chunk-pair recipe.
    fn triplet_recipes_for_source(&self, source: &str) -> Vec<TripletRecipe> {
        self.resolve_source_triplet_plan(source).0
    }
748
749 fn triplet_recipe_count_for_source(&self, source: &str) -> usize {
750 let (recipes, _auto_injected) = self.resolve_source_triplet_plan(source);
751 recipes.len()
752 }
753
754 fn text_recipes_for_source<'a>(&'a self, source: &str) -> &'a [TextRecipe] {
755 if self.using_config_text_recipes || self.using_config_triplet_recipes {
756 return &self.text_recipes;
757 }
758 self.source_text_recipes
759 .get(source)
760 .map(|recipes| recipes.as_slice())
761 .unwrap_or(&[])
762 }
763
    /// Weighted-random ordering of triplet-recipe indices; thin wrapper
    /// around the free function `weighted_recipe_order`.
    fn recipe_order_weighted_shuffled(
        &mut self,
        weights: &[f32],
        rng: &mut DeterministicRng,
    ) -> Vec<usize> {
        weighted_recipe_order(weights, rng)
    }
781
782 fn recipe_order_weighted_cycled(
788 &mut self,
789 weights: &[f32],
790 rr_idx: usize,
791 rng: &mut DeterministicRng,
792 ) -> Vec<usize> {
793 let base = self.recipe_order_weighted_shuffled(weights, rng);
794 if base.is_empty() {
795 return base;
796 }
797 let start = rr_idx % base.len();
798 let mut order = Vec::with_capacity(base.len());
799 order.extend_from_slice(&base[start..]);
800 order.extend_from_slice(&base[..start]);
801 order
802 }
803
    /// Weighted-random ordering of text-recipe indices; thin wrapper around
    /// the free function `weighted_recipe_order`.
    fn text_recipe_order_weighted_shuffled(
        &mut self,
        weights: &[f32],
        rng: &mut DeterministicRng,
    ) -> Vec<usize> {
        weighted_recipe_order(weights, rng)
    }
814
815 fn text_recipe_order_weighted_cycled(
816 &mut self,
817 weights: &[f32],
818 rr_idx: usize,
819 rng: &mut DeterministicRng,
820 ) -> Vec<usize> {
821 let base = self.text_recipe_order_weighted_shuffled(weights, rng);
822 if base.is_empty() {
823 return base;
824 }
825 let start = rr_idx % base.len();
826 let mut order = Vec::with_capacity(base.len());
827 order.extend_from_slice(&base[start..]);
828 order.extend_from_slice(&base[..start]);
829 order
830 }
831
832 fn allowed_target_splits(&self) -> Vec<SplitLabel> {
833 self.config.allowed_splits.clone()
834 }
835
836 fn ensure_split_allowed(&self, split: SplitLabel) -> Result<(), SamplerError> {
837 let allowed = self.allowed_target_splits();
838 if allowed.contains(&split) {
839 return Ok(());
840 }
841 Err(SamplerError::Configuration(format!(
842 "requested split {:?} is not in allowed_splits {:?}",
843 split, allowed
844 )))
845 }
846
847 fn ensure_split_has_records(&mut self, target_split: SplitLabel) -> Result<(), SamplerError> {
848 let records_by_split = self.records_by_split()?;
849 if records_by_split
850 .get(&target_split)
851 .map(|records| !records.is_empty())
852 .unwrap_or(false)
853 {
854 return Ok(());
855 }
856 Err(SamplerError::Exhausted(
857 "no records available for target split".into(),
858 ))
859 }
860
    /// Groups chunk ids (with their record's source) by split label.
    ///
    /// Records without an assigned split get one via `SplitStore::ensure`;
    /// a local cache avoids repeated store lookups for the same record.
    fn records_by_split(
        &self,
    ) -> Result<HashMap<SplitLabel, Vec<(RecordId, SourceId)>>, SamplerError> {
        let mut map: HashMap<SplitLabel, Vec<(RecordId, SourceId)>> = HashMap::new();
        let mut label_cache: HashMap<RecordId, SplitLabel> = HashMap::new();
        for (chunk_id, record_id) in &self.chunk_index {
            // Skip dangling index entries whose record has been removed.
            let Some(record) = self.records.get(record_id) else {
                continue;
            };
            let label = if let Some(label) = label_cache.get(record_id) {
                *label
            } else {
                let label = match self.split_store.label_for(record_id) {
                    Some(label) => label,
                    None => self.split_store.ensure(record_id.clone())?,
                };
                label_cache.insert(record_id.clone(), label);
                label
            };
            map.entry(label)
                .or_default()
                .push((chunk_id.clone(), record.source.clone()));
        }
        Ok(map)
    }
886
    /// Picks the next anchor record for `split`.
    ///
    /// With a `source` filter: walks that source's shuffled record list
    /// starting at a per-cycle hashed offset, advances the persistent
    /// cursor, and marks the source as "wrapped" whenever the cursor
    /// crosses a multiple of the list length (which can cascade into an
    /// epoch advance). Without a source: defers to the epoch tracker's
    /// record order, skipping ids no longer present.
    fn choose_anchor_record(
        &mut self,
        source: Option<&str>,
        split: SplitLabel,
    ) -> Option<Arc<DataRecord>> {
        if let Some(source) = source {
            let indices = self.source_record_indices.get(source)?;
            if indices.is_empty() {
                return None;
            }
            let mut cursor = *self.source_record_cursors.get(source).unwrap_or(&0);
            // A fresh hashed offset per full pass decorrelates cycles.
            let cycle = cursor / indices.len();
            let offset_seed = self.epoch_seed() ^ (cycle as u64);
            let offset = (stable_hash_str(offset_seed, source) as usize) % indices.len();
            let mut wrapped = false;
            let mut selected: Option<Arc<DataRecord>> = None;
            // At most one full pass over the source's records.
            for _ in 0..indices.len() {
                let pos = (cursor % indices.len()).saturating_add(offset) % indices.len();
                let idx = indices[pos];
                cursor = cursor.saturating_add(1);
                if cursor.is_multiple_of(indices.len()) {
                    wrapped = true;
                }
                if let Some((_, record)) = self.records.get_index(idx) {
                    // Only accept records assigned to the requested split.
                    if self.split_store.label_for(&record.id) != Some(split) {
                        continue;
                    }
                    selected = Some(Arc::clone(record));
                    break;
                }
            }
            // Persist cursor progress even when nothing was selected.
            self.source_record_cursors
                .insert(source.to_string(), cursor);
            if wrapped {
                self.mark_source_wrapped(source);
            }
            return selected;
        }
        while let Some(chunk_id) = self.epoch_tracker.next_record(split) {
            if let Some(record_id) = self.chunk_index.get(&chunk_id)
                && let Some(record) = self.records.get(record_id)
            {
                return Some(Arc::clone(record));
            }
        }
        None
    }
934
935 fn save_sampler_state(&mut self, save_to: Option<&Path>) -> Result<(), SamplerError> {
936 if self.epoch_tracker.is_enabled() {
937 self.epoch_tracker.persist()?;
938 }
939 self.persist_source_state(save_to)?;
940 Ok(())
941 }
942
943 fn mark_source_wrapped(&mut self, source: &str) {
944 self.source_wrapped.insert(source.to_string(), true);
945 if self.source_order.is_empty() {
946 return;
947 }
948 let all_wrapped = self
949 .source_order
950 .iter()
951 .all(|name| self.source_wrapped.get(name).copied().unwrap_or(false));
952 if all_wrapped {
953 self.advance_epoch();
954 }
955 }
956
    /// Moves to the next epoch once every source has wrapped: bumps the
    /// epoch counter, clears per-source cursors and wrap flags, resets the
    /// source cycle, and re-shuffles the source index under the new epoch
    /// seed. Unlike `set_epoch`, index-rebuild errors are swallowed here
    /// since this runs on an infallible path.
    fn advance_epoch(&mut self) {
        self.epoch = self.epoch.saturating_add(1);
        self.ingestion.set_epoch(self.epoch);
        self.source_record_cursors.clear();
        self.source_cycle_idx = 0;
        for source in &self.source_order {
            self.source_wrapped.insert(source.clone(), false);
        }
        // Best-effort rebuild; errors intentionally dropped.
        let _ = self.rebuild_source_index();
        self.source_state_dirty = self.source_order.len() > 1;
    }
982
983 fn select_temporal_neighbor(
984 &'_ self,
985 record: &DataRecord,
986 offset_days: i32,
987 ) -> Option<Arc<DataRecord>> {
988 let target = record.created_at + Duration::days(offset_days.into());
989 let key = record.taxonomy.first().cloned();
990 let record_split = self.split_store.label_for(&record.id)?;
991 self.records
992 .values()
993 .filter(|candidate| {
994 candidate.id != record.id
995 && self
996 .split_store
997 .label_for(&candidate.id)
998 .map(|label| label == record_split)
999 .unwrap_or(false)
1000 && (candidate.source == record.source
1001 || key
1002 .as_ref()
1003 .zip(candidate.taxonomy.first())
1004 .map(|(a, b)| a == b)
1005 .unwrap_or(false))
1006 })
1007 .min_by_key(|candidate| (candidate.created_at - target).num_seconds().abs())
1008 .cloned()
1009 }
1010
    /// Chooses a negative record for `anchor_record` under `strategy`.
    ///
    /// Every candidate pool is restricted to the anchor's split; the final
    /// pick is delegated to the negative backend. The boolean passed to the
    /// backend marks whether the pool is a relaxed fallback (the strategy's
    /// constraint could not be satisfied). Returns `None` when the anchor
    /// has no split label or the backend yields no candidate.
    fn select_negative_record(
        &self,
        anchor_record: &DataRecord,
        strategy: &NegativeStrategy,
        anchor_query_text: Option<&str>,
        rng: &mut dyn rand::RngCore,
    ) -> Option<(Arc<DataRecord>, bool)> {
        let anchor_split = self.split_store.label_for(&anchor_record.id)?;

        // Records lacking a label are excluded (treated as out-of-split).
        let in_anchor_split = |candidate: &DataRecord| {
            self.split_store
                .label_for(&candidate.id)
                .map(|label| label == anchor_split)
                .unwrap_or(false)
        };

        match strategy {
            // Prefer a different article from the same source and same date;
            // relax first to same-source/any-date, then to the whole split.
            NegativeStrategy::WrongArticle => {
                let anchor_date =
                    taxonomy_value(anchor_record, META_FIELD_DATE).map(|d| d.to_string());
                let mut same_date: Vec<Arc<DataRecord>> = self
                    .records
                    .values()
                    .filter(|candidate| {
                        candidate.source == anchor_record.source
                            && candidate.id != anchor_record.id
                            && in_anchor_split(candidate)
                    })
                    .filter(|candidate| {
                        anchor_date
                            .as_deref()
                            .zip(taxonomy_value(candidate, META_FIELD_DATE))
                            .map(|(a, b)| a == b)
                            .unwrap_or(false)
                    })
                    .cloned()
                    .collect();
                if same_date.is_empty() {
                    // Fallback 1: same source, any date (still not relaxed).
                    same_date = self
                        .records
                        .values()
                        .filter(|candidate| {
                            candidate.source == anchor_record.source
                                && candidate.id != anchor_record.id
                                && in_anchor_split(candidate)
                        })
                        .cloned()
                        .collect();
                }
                if !same_date.is_empty() {
                    return self.negative_backend.choose_negative(
                        anchor_record,
                        anchor_split,
                        same_date,
                        false,
                        anchor_query_text,
                        rng,
                    );
                }
                // Fallback 2: anything else in the split, marked relaxed.
                let pool = self
                    .records
                    .values()
                    .filter(|candidate| {
                        candidate.id != anchor_record.id && in_anchor_split(candidate)
                    })
                    .cloned()
                    .collect::<Vec<_>>();
                self.negative_backend.choose_negative(
                    anchor_record,
                    anchor_split,
                    pool,
                    true,
                    anchor_query_text,
                    rng,
                )
            }
            // Same source with a demonstrably different date. A date on one
            // side only counts as a mismatch; two missing dates do not.
            NegativeStrategy::WrongPublicationDate => {
                let anchor_date =
                    taxonomy_value(anchor_record, META_FIELD_DATE).map(|d| d.to_string());
                let pool: Vec<Arc<DataRecord>> = self
                    .records
                    .values()
                    .filter(|candidate| {
                        candidate.source == anchor_record.source
                            && candidate.id != anchor_record.id
                            && in_anchor_split(candidate)
                    })
                    .filter(|candidate| {
                        match (
                            anchor_date.as_deref(),
                            taxonomy_value(candidate, META_FIELD_DATE),
                        ) {
                            (Some(anchor), Some(candidate_date)) => anchor != candidate_date,
                            (Some(_), None) => true,
                            (None, Some(_)) => true,
                            (None, None) => false,
                        }
                    })
                    .cloned()
                    .collect();
                if pool.is_empty() {
                    // Relaxed fallback: anything else in the split.
                    let fallback_pool = self
                        .records
                        .values()
                        .filter(|candidate| {
                            candidate.id != anchor_record.id && in_anchor_split(candidate)
                        })
                        .cloned()
                        .collect::<Vec<_>>();

                    return self.negative_backend.choose_negative(
                        anchor_record,
                        anchor_split,
                        fallback_pool,
                        true,
                        anchor_query_text,
                        rng,
                    );
                }

                self.negative_backend.choose_negative(
                    anchor_record,
                    anchor_split,
                    pool,
                    false,
                    anchor_query_text,
                    rng,
                )
            }
            // Any other record from the same source; relaxed fallback to
            // the whole split when the source has no other records.
            NegativeStrategy::QuestionAnswerMismatch => {
                let pool: Vec<Arc<DataRecord>> = self
                    .records
                    .values()
                    .filter(|candidate| {
                        candidate.source == anchor_record.source
                            && candidate.id != anchor_record.id
                            && in_anchor_split(candidate)
                    })
                    .cloned()
                    .collect();
                if pool.is_empty() {
                    let fallback_pool = self
                        .records
                        .values()
                        .filter(|candidate| {
                            candidate.id != anchor_record.id && in_anchor_split(candidate)
                        })
                        .cloned()
                        .collect::<Vec<_>>();

                    return self.negative_backend.choose_negative(
                        anchor_record,
                        anchor_split,
                        fallback_pool,
                        true,
                        anchor_query_text,
                        rng,
                    );
                }

                self.negative_backend.choose_negative(
                    anchor_record,
                    anchor_split,
                    pool,
                    false,
                    anchor_query_text,
                    rng,
                )
            }
        }
    }
1186
    /// Whether `recipe` is the auto-injected long-section chunk-pair recipe.
    fn is_auto_chunk_pair_recipe(recipe: &TripletRecipe) -> bool {
        recipe.name.as_ref() == AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME
    }
1191
    /// Draws an anchor chunk and a positive chunk for `record`.
    ///
    /// When both selectors are identical, the pair is redrawn up to
    /// `SAME_SELECTOR_PAIR_RETRY_LIMIT` times until it passes
    /// `same_selector_pair_is_valid` (with `enforce_window_pair` forwarded);
    /// returns `None` if no valid pair is found. Pairs drawn from distinct
    /// selectors are accepted without this validation.
    fn select_anchor_positive_pair(
        &mut self,
        record: &DataRecord,
        anchor_selector: &Selector,
        positive_selector: &Selector,
        enforce_window_pair: bool,
    ) -> Option<(RecordChunk, RecordChunk)> {
        let mut anchor_chunk = self.select_chunk(record, anchor_selector)?;
        let mut positive_chunk = self.select_chunk(record, positive_selector)?;
        if anchor_selector == positive_selector {
            let mut retries = 0usize;
            while !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair)
                && retries < SAME_SELECTOR_PAIR_RETRY_LIMIT
            {
                // If either redraw fails, keep the last pair and fall
                // through to the final validity check below.
                let Some(redraw_anchor) = self.select_chunk(record, anchor_selector) else {
                    break;
                };
                let Some(redraw_positive) = self.select_chunk(record, positive_selector) else {
                    break;
                };
                anchor_chunk = redraw_anchor;
                positive_chunk = redraw_positive;
                retries += 1;
            }
            if !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair) {
                return None;
            }
        }
        Some((anchor_chunk, positive_chunk))
    }
1225
1226 fn select_distinct_window_pair_for_auto_recipe(
1236 &mut self,
1237 recipe: &TripletRecipe,
1238 record: &DataRecord,
1239 ) -> Option<(RecordChunk, RecordChunk)> {
1240 if recipe.anchor != recipe.positive_selector {
1241 return None;
1242 }
1243 self.select_anchor_positive_pair(record, &recipe.anchor, &recipe.positive_selector, true)
1244 }
1245
1246 fn record_has_at_least_two_window_chunks_for_selector(
1251 &self,
1252 record: &DataRecord,
1253 selector: &Selector,
1254 ) -> bool {
1255 let section_indices: Vec<usize> = match selector {
1256 Selector::Role(role) => record
1257 .sections
1258 .iter()
1259 .enumerate()
1260 .filter(|(_, section)| roles_match(role, §ion.role))
1261 .map(|(idx, _)| idx)
1262 .collect(),
1263 Selector::Paragraph(idx) => {
1264 if *idx < record.sections.len() {
1265 vec![*idx]
1266 } else {
1267 Vec::new()
1268 }
1269 }
1270 Selector::Random => (0..record.sections.len()).collect(),
1271 Selector::TemporalOffset(_) => return false,
1272 };
1273
1274 let mut window_count = 0usize;
1275 for section_idx in section_indices {
1276 let Some(section) = record.sections.get(section_idx) else {
1277 continue;
1278 };
1279 let chunks = self.materialize_chunks(record, section_idx, section);
1280 window_count += chunks
1281 .iter()
1282 .filter(|chunk| matches!(chunk.view, ChunkView::Window { .. }))
1283 .count();
1284 if window_count >= 2 {
1285 return true;
1286 }
1287 }
1288 false
1289 }
1290
1291 fn build_triplet_with_selector_pair_policy(
1293 &mut self,
1294 recipe: &TripletRecipe,
1295 record: &DataRecord,
1296 enforce_window_pair: bool,
1297 rng: &mut DeterministicRng,
1298 ) -> Option<SampleTriplet> {
1299 let (mut anchor_chunk, mut positive_chunk) = self.select_anchor_positive_pair(
1300 record,
1301 &recipe.anchor,
1302 &recipe.positive_selector,
1303 enforce_window_pair,
1304 )?;
1305 let anchor_raw_text = anchor_chunk.text.clone();
1306 self.decorate_chunk(record, &mut anchor_chunk, rng);
1307 self.decorate_chunk(record, &mut positive_chunk, rng);
1308 self.finalize_triplet_with_negative(
1311 recipe,
1312 record,
1313 anchor_chunk,
1314 positive_chunk,
1315 &anchor_raw_text,
1316 rng,
1317 )
1318 }
1319
1320 fn make_auto_chunk_pair_triplet_with_anchor(
1330 &mut self,
1331 recipe: &TripletRecipe,
1332 record: &DataRecord,
1333 rng: &mut DeterministicRng,
1334 ) -> Option<SampleTriplet> {
1335 if !self.record_has_at_least_two_window_chunks_for_selector(record, &recipe.anchor) {
1336 return None;
1337 }
1338 let (mut anchor_chunk, mut positive_chunk) =
1339 self.select_distinct_window_pair_for_auto_recipe(recipe, record)?;
1340 let anchor_raw_text = anchor_chunk.text.clone();
1341 self.decorate_chunk(record, &mut anchor_chunk, rng);
1342 self.decorate_chunk(record, &mut positive_chunk, rng);
1343 self.finalize_triplet_with_negative(
1344 recipe,
1345 record,
1346 anchor_chunk,
1347 positive_chunk,
1348 &anchor_raw_text,
1349 rng,
1350 )
1351 }
1352
    /// Standard (non-auto) triplet construction for `record`: delegates to
    /// the selector-pair policy without enforcing a window pair.
    fn make_standard_triplet_with_anchor(
        &mut self,
        recipe: &TripletRecipe,
        record: &DataRecord,
        rng: &mut DeterministicRng,
    ) -> Option<SampleTriplet> {
        self.build_triplet_with_selector_pair_policy(recipe, record, false, rng)
    }
1361
    /// Complete a triplet: pick a negative record and chunk, randomly swap
    /// the anchor/positive legs, reject degenerate text collisions, and
    /// compute the final sample weight.
    ///
    /// `anchor_raw_text` is the anchor's text before decoration; it is fed to
    /// negative-record selection. When the negative came from the same-split
    /// fallback pool, the recipe name is suffixed with
    /// `_fallback_same_split` so downstream metrics can distinguish it.
    fn finalize_triplet_with_negative(
        &mut self,
        recipe: &TripletRecipe,
        record: &DataRecord,
        anchor_chunk: RecordChunk,
        positive_chunk: RecordChunk,
        anchor_raw_text: &str,
        rng: &mut DeterministicRng,
    ) -> Option<SampleTriplet> {
        let (negative_record, fallback_used) = self.select_negative_record(
            record,
            &recipe.negative_strategy,
            Some(anchor_raw_text),
            rng,
        )?;
        let mut negative_chunk = self.select_chunk(&negative_record, &recipe.negative_selector)?;
        self.decorate_chunk(&negative_record, &mut negative_chunk, rng);

        // Deterministic coin flip (masked rng draw) so neither leg is
        // systematically the anchor.
        let (anchor_chunk, positive_chunk) = if rng.next_u64() & ANCHOR_POSITIVE_SWAP_MASK == 0 {
            (positive_chunk, anchor_chunk)
        } else {
            (anchor_chunk, positive_chunk)
        };

        // Drop triplets whose legs collapse to identical text.
        if (!recipe.allow_same_anchor_positive && anchor_chunk.text == positive_chunk.text)
            || negative_chunk.text == positive_chunk.text
            || negative_chunk.text == anchor_chunk.text
        {
            return None;
        }

        // Final weight = recipe prior scaled by per-chunk quality average.
        let chunk_weight =
            self.triplet_chunk_weight(recipe, &anchor_chunk, &positive_chunk, &negative_chunk);
        let weight = recipe.weight * chunk_weight;
        let recipe_name = if fallback_used {
            format!("{}_fallback_same_split", recipe.name)
        } else {
            recipe.name.to_string()
        };
        Some(SampleTriplet {
            recipe: recipe_name,
            anchor: anchor_chunk,
            positive: positive_chunk,
            negative: negative_chunk,
            weight,
            instruction: recipe.instruction.as_ref().map(|s| s.to_string()),
        })
    }
1439
1440 fn make_triplet_with_anchor(
1441 &mut self,
1442 recipe: &TripletRecipe,
1443 record: &DataRecord,
1444 rng: &mut DeterministicRng,
1445 ) -> Option<SampleTriplet> {
1446 if Self::is_auto_chunk_pair_recipe(recipe) {
1447 return self.make_auto_chunk_pair_triplet_with_anchor(recipe, record, rng);
1448 }
1449 self.make_standard_triplet_with_anchor(recipe, record, rng)
1450 }
1451
1452 fn make_text_sample_for_split(
1453 &mut self,
1454 recipe: &TextRecipe,
1455 source: Option<&str>,
1456 split: SplitLabel,
1457 rng: &mut DeterministicRng,
1458 ) -> Option<TextSample> {
1459 let record = self.choose_anchor_record(source, split)?;
1460 let mut chunk = self.select_chunk(&record, &recipe.selector)?;
1461 self.decorate_chunk(&record, &mut chunk, rng);
1462 let weight = recipe.weight * self.chunk_weight(&chunk);
1463 Some(TextSample {
1464 recipe: recipe.name.to_string(),
1465 chunk,
1466 weight,
1467 instruction: recipe.instruction.as_ref().map(|s| s.to_string()),
1468 })
1469 }
1470
    /// Quality weight for a single chunk: delegates to the free function of
    /// the same name with this sampler's chunking configuration.
    fn chunk_weight(&self, chunk: &RecordChunk) -> f32 {
        chunk_weight(&self.config.chunking, chunk)
    }
1474
1475 fn triplet_chunk_weight(
1476 &self,
1477 recipe: &TripletRecipe,
1478 anchor: &RecordChunk,
1479 positive: &RecordChunk,
1480 negative: &RecordChunk,
1481 ) -> f32 {
1482 let floor = self.config.chunking.chunk_weight_floor;
1483 let negative_weight = negative.quality.trust.clamp(0.0, 1.0).max(floor);
1484 if Self::is_auto_chunk_pair_recipe(recipe) {
1485 let pair_trust = ((anchor.quality.trust.clamp(0.0, 1.0)
1488 + positive.quality.trust.clamp(0.0, 1.0))
1489 / 2.0)
1490 .clamp(0.0, 1.0);
1491 let pair_weight = (chunk_proximity_score(anchor, positive) * pair_trust).max(floor);
1492 return (pair_weight + pair_weight + negative_weight) / 3.0;
1494 }
1495 let pair_proximity = chunk_proximity_score(anchor, positive);
1498 let anchor_weight = (self.chunk_weight(anchor) * pair_proximity).max(floor);
1499 let positive_weight = (self.chunk_weight(positive) * pair_proximity).max(floor);
1500 (anchor_weight + positive_weight + negative_weight) / 3.0
1501 }
1502
1503 fn decorate_chunk(
1504 &mut self,
1505 record: &DataRecord,
1506 chunk: &mut RecordChunk,
1507 rng: &mut DeterministicRng,
1508 ) {
1509 chunk.kvp_meta = record
1510 .meta_prefix
1511 .as_ref()
1512 .map(|s| s.all_metadata())
1513 .unwrap_or_default();
1514 if let Some(spec) = record.meta_prefix.as_ref()
1515 && let Some(prefix) = spec.sample(rng)
1516 {
1517 let body_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&chunk.text);
1518 let prefix_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&prefix);
1519 let total_tokens = prefix_tokens.len() + body_tokens.len();
1520 let max_window = self.config.chunking.max_window_tokens;
1521 if max_window > 0 && total_tokens > max_window {
1522 if prefix_tokens.len() >= max_window {
1523 chunk.text = prefix_tokens
1524 .into_iter()
1525 .take(max_window)
1526 .collect::<Vec<_>>()
1527 .join(" ");
1528 chunk.tokens_estimate = max_window;
1529 } else {
1530 let remaining = max_window - prefix_tokens.len();
1531 let trimmed_body: Vec<&str> = body_tokens.into_iter().take(remaining).collect();
1532 chunk.text =
1533 format!("{}{}{}", prefix, platform_newline(), trimmed_body.join(" "));
1534 chunk.tokens_estimate = max_window;
1535 }
1536 } else {
1537 chunk.text = format!("{}{}{}", prefix, platform_newline(), chunk.text);
1538 chunk.tokens_estimate = total_tokens;
1539 }
1540 }
1541 }
1542
1543 fn select_chunk_parallel(
1547 &self,
1548 record: &DataRecord,
1549 selector: &Selector,
1550 rng: &mut DeterministicRng,
1551 ) -> Option<RecordChunk> {
1552 match selector {
1553 Selector::Role(role) => self.select_role_parallel(record, role, rng),
1554 Selector::Paragraph(idx) => record.sections.get(*idx).and_then(|section| {
1555 let pool = self.materialize_chunks(record, *idx, section);
1556 if pool.is_empty() {
1557 return None;
1558 }
1559 let i = rng.random_range(0..pool.len());
1560 pool.into_iter().nth(i)
1561 }),
1562 Selector::TemporalOffset(offset) => self
1563 .select_temporal_neighbor(record, *offset)
1564 .and_then(|neighbor| {
1565 self.select_role_parallel(&neighbor, &SectionRole::Context, rng)
1566 }),
1567 Selector::Random => {
1568 if record.sections.is_empty() {
1569 return None;
1570 }
1571 let idx = rng.random_range(0..record.sections.len());
1572 record.sections.get(idx).and_then(|section| {
1573 let pool = self.materialize_chunks(record, idx, section);
1574 if pool.is_empty() {
1575 return None;
1576 }
1577 let i = rng.random_range(0..pool.len());
1578 pool.into_iter().nth(i)
1579 })
1580 }
1581 }
1582 }
1583
1584 fn select_role_parallel(
1586 &self,
1587 record: &DataRecord,
1588 role: &SectionRole,
1589 rng: &mut DeterministicRng,
1590 ) -> Option<RecordChunk> {
1591 let indices: Vec<usize> = record
1592 .sections
1593 .iter()
1594 .enumerate()
1595 .filter(|(_, s)| roles_match(role, &s.role))
1596 .map(|(i, _)| i)
1597 .collect();
1598 if indices.is_empty() {
1599 return None;
1600 }
1601 let start = rng.random_range(0..indices.len());
1602 for offset in 0..indices.len() {
1603 let section_idx = indices[(start + offset) % indices.len()];
1604 let section = &record.sections[section_idx];
1605 let pool = self.materialize_chunks(record, section_idx, section);
1606 if !pool.is_empty() {
1607 let i = rng.random_range(0..pool.len());
1608 return pool.into_iter().nth(i);
1609 }
1610 }
1611 None
1612 }
1613
    /// Stateless counterpart of `decorate_chunk` used by the parallel
    /// prefetch path (`&self` only; randomness comes from the caller's rng).
    ///
    /// Copies the record's metadata key/values onto the chunk, then — if the
    /// record carries a prefix spec that samples a prefix — prepends it,
    /// trimming so prefix plus body never exceed `max_window_tokens`.
    fn decorate_chunk_parallel(
        &self,
        record: &DataRecord,
        chunk: &mut RecordChunk,
        rng: &mut DeterministicRng,
    ) {
        // Always expose the record's metadata on the chunk, even when no
        // prefix ends up being sampled.
        chunk.kvp_meta = record
            .meta_prefix
            .as_ref()
            .map(|s| s.all_metadata())
            .unwrap_or_default();
        if let Some(spec) = record.meta_prefix.as_ref()
            && let Some(prefix) = spec.sample(rng)
        {
            let body_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&chunk.text);
            let prefix_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&prefix);
            let total_tokens = prefix_tokens.len() + body_tokens.len();
            let max_window = self.config.chunking.max_window_tokens;
            // max_window == 0 acts as "no limit": concatenate untrimmed.
            if max_window > 0 && total_tokens > max_window {
                if prefix_tokens.len() >= max_window {
                    // Degenerate case: the prefix alone fills the window, so
                    // the body is dropped entirely.
                    chunk.text = prefix_tokens
                        .into_iter()
                        .take(max_window)
                        .collect::<Vec<_>>()
                        .join(" ");
                    chunk.tokens_estimate = max_window;
                } else {
                    // Keep the whole prefix; trim the body to the remainder.
                    let remaining = max_window - prefix_tokens.len();
                    let trimmed_body: Vec<&str> = body_tokens.into_iter().take(remaining).collect();
                    chunk.text =
                        format!("{}{}{}", prefix, platform_newline(), trimmed_body.join(" "));
                    chunk.tokens_estimate = max_window;
                }
            } else {
                chunk.text = format!("{}{}{}", prefix, platform_newline(), chunk.text);
                chunk.tokens_estimate = total_tokens;
            }
        }
    }
1654
1655 fn select_anchor_positive_parallel(
1657 &self,
1658 record: &DataRecord,
1659 anchor_selector: &Selector,
1660 positive_selector: &Selector,
1661 enforce_window_pair: bool,
1662 rng: &mut DeterministicRng,
1663 ) -> Option<(RecordChunk, RecordChunk)> {
1664 let anchor_chunk = self.select_chunk_parallel(record, anchor_selector, rng)?;
1665 let mut positive_chunk = self.select_chunk_parallel(record, positive_selector, rng)?;
1666 if anchor_selector == positive_selector {
1667 let mut retries = 0usize;
1668 while !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair)
1669 && retries < SAME_SELECTOR_PAIR_RETRY_LIMIT
1670 {
1671 positive_chunk = self.select_chunk_parallel(record, positive_selector, rng)?;
1672 retries += 1;
1673 }
1674 if !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair) {
1675 return None;
1676 }
1677 }
1678 Some((anchor_chunk, positive_chunk))
1680 }
1681
    /// Draw and decorate the anchor/positive pair for `recipe` on the
    /// parallel path, returning the pair plus the anchor's undecorated text
    /// (used downstream to pick a negative).
    fn select_anchor_positive_for_recipe(
        &self,
        recipe: &TripletRecipe,
        anchor_record: &DataRecord,
        rng: &mut DeterministicRng,
    ) -> Option<(RecordChunk, RecordChunk, String)> {
        if Self::is_auto_chunk_pair_recipe(recipe) {
            // Auto recipe: both legs come from the anchor selector and must
            // form a valid distinct-window pair.
            if !self
                .record_has_at_least_two_window_chunks_for_selector(anchor_record, &recipe.anchor)
            {
                return None;
            }
            let mut anchor_chunk =
                self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
            let mut positive_chunk =
                self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
            let mut tries = 0usize;
            // Redraw both sides until valid; gives up after at most
            // SAME_SELECTOR_PAIR_RETRY_LIMIT - 1 redraws (counter is checked
            // after incrementing, before the redraw).
            while !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, true) {
                tries += 1;
                if tries >= SAME_SELECTOR_PAIR_RETRY_LIMIT {
                    return None;
                }
                anchor_chunk = self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
                positive_chunk = self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
            }
            // Capture raw text before decoration mutates the chunk body.
            let anchor_raw_text = anchor_chunk.text.clone();
            self.decorate_chunk_parallel(anchor_record, &mut anchor_chunk, rng);
            self.decorate_chunk_parallel(anchor_record, &mut positive_chunk, rng);
            return Some((anchor_chunk, positive_chunk, anchor_raw_text));
        }
        // Standard recipes: no window-pair enforcement.
        let (mut anchor_chunk, mut positive_chunk) = self.select_anchor_positive_parallel(
            anchor_record,
            &recipe.anchor,
            &recipe.positive_selector,
            false,
            rng,
        )?;
        let anchor_raw_text = anchor_chunk.text.clone();
        self.decorate_chunk_parallel(anchor_record, &mut anchor_chunk, rng);
        self.decorate_chunk_parallel(anchor_record, &mut positive_chunk, rng);
        Some((anchor_chunk, positive_chunk, anchor_raw_text))
    }
1726
1727 fn select_chunk(&mut self, record: &DataRecord, selector: &Selector) -> Option<RecordChunk> {
1728 match selector {
1729 Selector::Role(role) => self.select_by_role(record, role),
1730 Selector::Paragraph(idx) => record.sections.get(*idx).and_then(|section| {
1731 let pool = self.materialize_chunks(record, *idx, section);
1732 self.next_chunk_from_pool(&record.id, *idx, pool)
1733 }),
1734 Selector::TemporalOffset(offset) => self
1735 .select_temporal_neighbor(record, *offset)
1736 .and_then(|neighbor| self.select_by_role(&neighbor, &SectionRole::Context)),
1737 Selector::Random => {
1738 if record.sections.is_empty() {
1739 return None;
1740 }
1741 let idx = self.rng.random_range(0..record.sections.len());
1742 record.sections.get(idx).and_then(|section| {
1743 let pool = self.materialize_chunks(record, idx, section);
1744 self.next_chunk_from_pool(&record.id, idx, pool)
1745 })
1746 }
1747 }
1748 }
1749
    /// Stateful role-based selection: rotates a per-(record, role) cursor so
    /// repeated draws walk through the matching sections round-robin, with a
    /// deterministic hash-seeded starting point on first use.
    fn select_by_role(&mut self, record: &DataRecord, role: &SectionRole) -> Option<RecordChunk> {
        let indices: Vec<usize> = record
            .sections
            .iter()
            .enumerate()
            .filter(|(_, section)| roles_match(role, &section.role))
            .map(|(idx, _)| idx)
            .collect();
        if indices.is_empty() {
            return None;
        }
        let key = role_cursor_key(&record.id, role);
        // Resume one past the last section served for this (record, role);
        // otherwise derive a stable start from the epoch seed so different
        // records don't all begin at their first matching section.
        let start_offset = self
            .role_cursors
            .get(&key)
            .and_then(|last_idx| indices.iter().position(|idx| idx == last_idx))
            .map(|pos| (pos + 1) % indices.len())
            .unwrap_or_else(|| {
                let seed_key = format!("{}::{}", key.0, key.1);
                (stable_hash_str(self.epoch_seed(), &seed_key) as usize) % indices.len()
            });
        for offset in 0..indices.len() {
            let section_idx = indices[(start_offset + offset) % indices.len()];
            let section = &record.sections[section_idx];
            let pool = self.materialize_chunks(record, section_idx, section);
            if let Some(chunk) = self.next_chunk_from_pool(&record.id, section_idx, pool) {
                // Remember where we served from so the next call advances.
                self.role_cursors.insert(key.clone(), section_idx);
                return Some(chunk);
            }
        }
        None
    }
1785
    /// Expand one record section into its chunk pool; pure delegation to the
    /// configured chunker with this sampler's chunking settings.
    fn materialize_chunks(
        &self,
        record: &DataRecord,
        section_idx: usize,
        section: &RecordSection,
    ) -> Vec<RecordChunk> {
        self.chunker
            .materialize(&self.config.chunking, record, section_idx, section)
    }
1807
1808 fn build_derived_text_recipes(recipes: &[TripletRecipe]) -> Vec<TextRecipe> {
1809 let mut derived = Vec::new();
1810 for recipe in recipes {
1811 let base = recipe.name.as_ref();
1812 derived.push(TextRecipe {
1813 name: Cow::Owned(format!("{base}_anchor")),
1814 selector: recipe.anchor.clone(),
1815 weight: recipe.weight,
1816 instruction: None,
1817 });
1818 derived.push(TextRecipe {
1819 name: Cow::Owned(format!("{base}_positive")),
1820 selector: recipe.positive_selector.clone(),
1821 weight: recipe.weight,
1822 instruction: None,
1823 });
1824 derived.push(TextRecipe {
1825 name: Cow::Owned(format!("{base}_negative")),
1826 selector: recipe.negative_selector.clone(),
1827 weight: recipe.weight,
1828 instruction: None,
1829 });
1830 }
1831 derived
1832 }
1833
1834 fn record_has_long_anchor_or_context_section(&self, record: &DataRecord) -> bool {
1837 let window = self.config.chunking.max_window_tokens;
1838 if window == 0 {
1839 return false;
1840 }
1841 record.sections.iter().any(|section| {
1842 matches!(section.role, SectionRole::Anchor | SectionRole::Context)
1843 && WhitespaceTokenizer.token_count(§ion.text) > window
1844 })
1845 }
1846
    /// Rebuild the in-memory record map from the ingestion cache snapshot.
    ///
    /// Records are re-inserted in sorted-id order for determinism; ids no
    /// longer present have their emitted-text hashes dropped, split labels
    /// are ensured for newcomers, the long-section source set is recomputed,
    /// and all derived indices are rebuilt afterwards.
    fn sync_records_from_cache(&mut self) -> Result<(), SamplerError> {
        let mut snapshot = self.ingestion.all_records_snapshot();
        // Sort for a deterministic insertion order regardless of cache order.
        snapshot.sort_by(|a, b| a.id.cmp(&b.id));

        let old_ids: HashSet<String> = self.records.keys().cloned().collect();
        let incoming_ids: HashSet<String> = snapshot.iter().map(|r| r.id.clone()).collect();
        let evicted: HashSet<&String> = old_ids.difference(&incoming_ids).collect();

        self.records.clear();
        self.sources_with_long_sections.clear();
        // Forget dedup state for records that left the cache so their texts
        // can be emitted again if they ever return.
        for evicted_id in &evicted {
            self.emitted_text_hashes.remove(*evicted_id);
        }
        self.negative_backend.on_sync_start();
        for record in snapshot {
            // Assign a split lazily the first time a record id is seen.
            if self.split_store.label_for(&record.id).is_none() {
                self.split_store.ensure(record.id.clone())?;
            }
            if self.record_has_long_anchor_or_context_section(&record) {
                self.sources_with_long_sections
                    .insert(record.source.clone());
            }
            self.records.insert(record.id.clone(), Arc::new(record));
        }
        self.prune_cursor_state();
        self.rebuild_chunk_index();
        self.rebuild_source_index()?;
        Ok(())
    }
1898
1899 fn ingest_internal_for_split(&mut self, target_split: SplitLabel) -> Result<(), SamplerError> {
1900 if !self.ingestion.has_sources() {
1901 return Ok(());
1902 }
1903 if !self.ingestion_cursors_loaded {
1904 if let Some(state) = self.split_store.load_sampler_state()? {
1905 self.ingestion.load_cursors(&state.source_stream_cursors);
1906 self.ingestion.set_epoch(state.epoch);
1907 self.ingestion.set_epoch_step(state.epoch_step);
1908 }
1909 self.ingestion_cursors_loaded = true;
1910 }
1911 if self.ingestion.all_caches_empty() {
1912 self.ingestion.refresh_all();
1913 } else {
1914 self.ingestion.advance(self.config.batch_size);
1915 }
1916 let observed = self.ingestion.total_ingest_count();
1917 if observed == self.last_observed_ingest && !self.records.is_empty() {
1918 return Ok(());
1919 }
1920 self.last_observed_ingest = observed;
1921 self.sync_records_from_cache()?;
1922 let max_window_tokens = self.config.chunking.max_window_tokens;
1923 self.negative_backend.on_records_refreshed(
1924 &self.records,
1925 max_window_tokens,
1926 &|id| self.split_store.label_for(id),
1927 self.ingestion.last_refreshed_sources(),
1928 );
1929 self.epoch_tracker.ensure_loaded()?;
1932 let records_by_split = self.records_by_split()?;
1933 self.epoch_tracker
1934 .reconcile(target_split, &records_by_split);
1935 self.ensure_source_state()?;
1936 Ok(())
1937 }
1938
1939 fn ensure_ingestion_cursors_loaded(&mut self) -> Result<(), SamplerError> {
1948 if !self.ingestion_cursors_loaded {
1949 if let Some(state) = self.split_store.load_sampler_state()? {
1950 self.ingestion.load_cursors(&state.source_stream_cursors);
1951 self.ingestion.set_epoch(state.epoch);
1952 self.ingestion.set_epoch_step(state.epoch_step);
1953 }
1954 self.ingestion_cursors_loaded = true;
1955 }
1956 Ok(())
1957 }
1958
    /// Test-only alias for `ingest_internal_for_split`.
    #[cfg(test)]
    fn ingest_internal(&mut self, split: SplitLabel) -> Result<(), SamplerError> {
        self.ingest_internal_for_split(split)
    }
1963
1964 fn ingest_internal_with_weights_for_split(
1965 &mut self,
1966 target_split: SplitLabel,
1967 weights: &HashMap<SourceId, f32>,
1968 ) -> Result<(), SamplerError> {
1969 if !self.ingestion.has_sources() {
1970 return Ok(());
1971 }
1972 if !self.ingestion_cursors_loaded {
1973 if let Some(state) = self.split_store.load_sampler_state()? {
1974 self.ingestion.load_cursors(&state.source_stream_cursors);
1975 self.ingestion.set_epoch(state.epoch);
1976 self.ingestion.set_epoch_step(state.epoch_step);
1977 }
1978 self.ingestion_cursors_loaded = true;
1979 }
1980 if self.ingestion.all_caches_empty() {
1981 self.ingestion.refresh_all_with_weights(weights)?;
1982 } else {
1983 self.ingestion
1984 .advance_with_weights(self.config.batch_size, weights)?;
1985 }
1986 let observed = self.ingestion.total_ingest_count();
1987 if observed == self.last_observed_ingest && !self.records.is_empty() {
1988 return Ok(());
1989 }
1990 self.last_observed_ingest = observed;
1991 self.sync_records_from_cache()?;
1992 let max_window_tokens = self.config.chunking.max_window_tokens;
1993 self.negative_backend.on_records_refreshed(
1994 &self.records,
1995 max_window_tokens,
1996 &|id| self.split_store.label_for(id),
1997 self.ingestion.last_refreshed_sources(),
1998 );
1999 self.epoch_tracker.ensure_loaded()?;
2000 let records_by_split = self.records_by_split()?;
2001 self.epoch_tracker
2002 .reconcile(target_split, &records_by_split);
2003 self.ensure_source_state()?;
2004 Ok(())
2005 }
2006
2007 fn ingest_with_weights_fallback(
2012 &mut self,
2013 target_split: SplitLabel,
2014 weights: Option<&HashMap<SourceId, f32>>,
2015 ) -> Result<(), SamplerError> {
2016 match weights {
2017 Some(weights)
2018 if !weights.is_empty()
2019 && !weights
2020 .values()
2021 .all(|&w| w == *weights.values().next().unwrap()) =>
2022 {
2023 self.ingest_internal_with_weights_for_split(target_split, weights)?
2024 }
2025 _ => self.ingest_internal_for_split(target_split)?,
2026 }
2027 Ok(())
2028 }
2029
2030 fn force_ingest_refresh_with_weights_for_split(
2031 &mut self,
2032 target_split: SplitLabel,
2033 weights: &HashMap<SourceId, f32>,
2034 ) -> Result<(), SamplerError> {
2035 if !self.ingestion.has_sources() {
2036 return Ok(());
2037 }
2038 if !self.ingestion_cursors_loaded {
2039 if let Some(state) = self.split_store.load_sampler_state()? {
2040 self.ingestion.load_cursors(&state.source_stream_cursors);
2041 self.ingestion.set_epoch(state.epoch);
2042 self.ingestion.set_epoch_step(state.epoch_step);
2043 }
2044 self.ingestion_cursors_loaded = true;
2045 }
2046 self.ingestion.force_refresh_all_with_weights(weights)?;
2047 self.last_observed_ingest = self.ingestion.total_ingest_count();
2048 self.sync_records_from_cache()?;
2049 let max_window_tokens = self.config.chunking.max_window_tokens;
2050 self.negative_backend.on_records_refreshed(
2051 &self.records,
2052 max_window_tokens,
2053 &|id| self.split_store.label_for(id),
2054 self.ingestion.last_refreshed_sources(),
2055 );
2056 self.epoch_tracker.ensure_loaded()?;
2057 let records_by_split = self.records_by_split()?;
2058 self.epoch_tracker
2059 .reconcile(target_split, &records_by_split);
2060 self.ensure_source_state()?;
2061 Ok(())
2062 }
2063
    /// Try to build one triplet for `source` in `target_split`, rotating
    /// through the source's weighted recipe order.
    ///
    /// `recipe_orders` / `recipe_positions` are caches owned by the caller
    /// (scoped to one batch) holding each source's recipe rotation and its
    /// current position. Returns the (recipe, triplet) on success plus the
    /// number of recipe attempts consumed, which the caller uses to advance
    /// the global round-robin index.
    fn sample_source_triplet_candidate(
        &mut self,
        source: &str,
        target_split: SplitLabel,
        recipe_orders: &mut HashMap<RecipeKey, Vec<usize>>,
        recipe_positions: &mut HashMap<RecipeKey, usize>,
        rng: &mut DeterministicRng,
    ) -> (Option<(TripletRecipe, SampleTriplet)>, usize) {
        let (recipes, _auto_injected) = self.resolve_source_triplet_plan(source);
        if recipes.is_empty() {
            return (None, 0);
        }
        // Lazily build this source's weighted recipe order once per batch.
        if !recipe_orders.contains_key(source) {
            let recipe_weights: Vec<f32> = recipes.iter().map(|r| r.weight).collect();
            let order =
                self.recipe_order_weighted_cycled(&recipe_weights, self.triplet_recipe_rr_idx, rng);
            recipe_orders.insert(source.to_string(), order);
        }
        let order = recipe_orders
            .get(source)
            .expect("recipe order missing for source");
        let pos = recipe_positions.entry(source.to_string()).or_insert(0);
        let Some(anchor) = self.choose_anchor_record(Some(source), target_split) else {
            return (None, 0);
        };

        // Walk the rotation starting at the cached position; the first
        // recipe that yields a triplet wins and the position advances past
        // it so the next call continues the rotation.
        let mut attempts = 0usize;
        for offset in 0..order.len() {
            let idx = order[(*pos + offset) % order.len()];
            attempts = attempts.saturating_add(1);
            let recipe = recipes[idx].clone();
            if let Some(sample) = self.make_triplet_with_anchor(&recipe, &anchor, rng) {
                *pos = (*pos + offset + 1) % order.len();
                return (Some((recipe, sample)), attempts);
            }
        }

        (None, attempts)
    }
2117
    /// Assemble one batch of anchor/positive and anchor/negative pairs for
    /// `target_split`, optionally biasing ingestion by per-source weights.
    ///
    /// Two paths:
    /// * no registered sources — draw directly from the global triplet
    ///   recipes in weighted rotated order;
    /// * otherwise — round-robin over sources (visit order reshuffled once
    ///   per full cycle), delegating per-source recipe rotation to
    ///   `sample_source_triplet_candidate`.
    ///
    /// Each accepted triplet contributes a Positive pair and a Negative
    /// pair; duplicates (same anchor/positive/negative record-id tuple) are
    /// skipped. Short batches are padded via `pad_with_reuse`; a batch that
    /// still cannot be filled yields `SamplerError::Exhausted`.
    fn next_pair_batch_inner_with_weights(
        &mut self,
        target_split: SplitLabel,
        weights: Option<&HashMap<SourceId, f32>>,
    ) -> Result<SampleBatch, SamplerError> {
        self.ingest_with_weights_fallback(target_split, weights)?;
        self.ensure_split_has_records(target_split)?;
        let sources = self.source_order.clone();
        if sources.is_empty() {
            if self.triplet_recipes.is_empty() {
                return Err(SamplerError::Configuration(
                    "no triplet recipes available".into(),
                ));
            }
            // Temporarily take ownership of the rng so `self` can be
            // borrowed mutably alongside it; restored before returning.
            let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
            let recipe_weights: Vec<f32> = self.triplet_recipes.iter().map(|r| r.weight).collect();
            let recipe_order = self.recipe_order_weighted_cycled(
                &recipe_weights,
                self.triplet_recipe_rr_idx,
                &mut rng,
            );
            let mut pairs = Vec::new();
            let mut seen = HashSet::new();
            let mut last_recipe_name = None;
            let mut recipe_pos = 0usize;
            let mut recipe_steps = 0usize;
            let attempts = self.config.batch_size * 4 * recipe_order.len().max(1);
            for _ in 0..attempts {
                if pairs.len() >= self.config.batch_size {
                    break;
                }
                let Some(anchor) = self.choose_anchor_record(None, target_split) else {
                    break;
                };
                // Try recipes in rotated order until one produces a triplet
                // for this anchor record.
                let mut triplet = None;
                for offset in 0..recipe_order.len() {
                    let idx = recipe_order[(recipe_pos + offset) % recipe_order.len()];
                    recipe_steps = recipe_steps.saturating_add(1);
                    let recipe = self.triplet_recipes[idx].clone();
                    last_recipe_name = Some(recipe.name.clone());
                    if let Some(sample) = self.make_triplet_with_anchor(&recipe, &anchor, &mut rng)
                    {
                        triplet = Some((recipe, sample));
                        recipe_pos = (recipe_pos + offset + 1) % recipe_order.len();
                        break;
                    }
                }
                if let Some((recipe, triplet)) = triplet {
                    // Dedup on the (anchor, positive, negative) record ids.
                    let key = (
                        triplet.anchor.record_id.clone(),
                        triplet.positive.record_id.clone(),
                        triplet.negative.record_id.clone(),
                    );
                    if seen.insert(key) {
                        let SampleTriplet {
                            recipe: triplet_recipe_name,
                            anchor,
                            positive,
                            negative,
                            weight,
                            instruction,
                        } = triplet;
                        if pairs.len() < self.config.batch_size {
                            pairs.push(SamplePair {
                                recipe: triplet_recipe_name.clone(),
                                anchor: anchor.clone(),
                                positive: positive.clone(),
                                weight,
                                instruction: instruction.clone(),
                                label: PairLabel::Positive,
                                reason: None,
                            });
                        }
                        if pairs.len() < self.config.batch_size {
                            pairs.push(SamplePair {
                                recipe: triplet_recipe_name,
                                anchor,
                                positive: negative,
                                weight,
                                instruction,
                                label: PairLabel::Negative,
                                reason: Some(
                                    strategy_reason(&recipe.negative_strategy).to_string(),
                                ),
                            });
                        }
                    }
                }
            }
            // Advance the recipe round-robin cursor by the number of recipe
            // slots consumed while building this batch.
            if recipe_steps > 0 {
                self.triplet_recipe_rr_idx =
                    self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
            }
            self.rng = rng;
            pad_with_reuse(&mut pairs, self.config.batch_size);
            if pairs.len() == self.config.batch_size {
                return Ok(SampleBatch { pairs });
            }
            return Err(SamplerError::Exhausted(
                last_recipe_name
                    .unwrap_or(Cow::Borrowed(RECIPE_LABEL_TRIPLETS))
                    .to_string(),
            ));
        }

        // Source-aware path: walk sources round-robin, reshuffling the visit
        // order once per full cycle of the source list.
        let mut pairs = Vec::new();
        let mut seen = HashSet::new();
        let mut source_steps = 0usize;
        let mut cycle = (self.source_cycle_idx / sources.len()) as u64;
        let mut source_idx = self.source_cycle_idx % sources.len();
        let mut cycle_sources = self.shuffled_source_cycle(cycle);
        let mut recipe_orders: HashMap<RecipeKey, Vec<usize>> = HashMap::new();
        let mut recipe_positions: HashMap<RecipeKey, usize> = HashMap::new();
        let mut recipe_steps = 0usize;
        let max_recipe_len = sources
            .iter()
            .map(|source| self.triplet_recipe_count_for_source(source))
            .max()
            .unwrap_or(1)
            .max(1);
        let attempts = self.config.batch_size * 4 * sources.len() * max_recipe_len;
        let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
        for _ in 0..attempts {
            if pairs.len() >= self.config.batch_size {
                break;
            }
            let source = cycle_sources[source_idx].as_str();
            let (triplet, attempts_used) = self.sample_source_triplet_candidate(
                source,
                target_split,
                &mut recipe_orders,
                &mut recipe_positions,
                &mut rng,
            );
            recipe_steps = recipe_steps.saturating_add(attempts_used);
            if let Some((recipe, triplet)) = triplet {
                // Dedup on the (anchor, positive, negative) record ids.
                let key = (
                    triplet.anchor.record_id.clone(),
                    triplet.positive.record_id.clone(),
                    triplet.negative.record_id.clone(),
                );
                if seen.insert(key) {
                    let SampleTriplet {
                        recipe: triplet_recipe_name,
                        anchor,
                        positive,
                        negative,
                        weight,
                        instruction,
                    } = triplet;
                    if pairs.len() < self.config.batch_size {
                        pairs.push(SamplePair {
                            recipe: triplet_recipe_name.clone(),
                            anchor: anchor.clone(),
                            positive: positive.clone(),
                            weight,
                            instruction: instruction.clone(),
                            label: PairLabel::Positive,
                            reason: None,
                        });
                    }
                    if pairs.len() < self.config.batch_size {
                        pairs.push(SamplePair {
                            recipe: triplet_recipe_name,
                            anchor,
                            positive: negative,
                            weight,
                            instruction,
                            label: PairLabel::Negative,
                            reason: Some(strategy_reason(&recipe.negative_strategy).to_string()),
                        });
                    }
                }
            }
            // Move to the next source; on wrap-around, start a new shuffled
            // cycle of the source list.
            source_idx += 1;
            source_steps += 1;
            if source_idx >= cycle_sources.len() {
                source_idx = 0;
                cycle = cycle.saturating_add(1);
                cycle_sources = self.shuffled_source_cycle(cycle);
            }
        }
        if recipe_steps > 0 {
            self.triplet_recipe_rr_idx = self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
        }
        self.rng = rng;
        pad_with_reuse(&mut pairs, self.config.batch_size);
        if pairs.len() == self.config.batch_size {
            self.source_cycle_idx = self.source_cycle_idx.saturating_add(source_steps);
            // Rotation state only needs persisting with multiple sources.
            self.source_state_dirty = sources.len() > 1;
            return Ok(SampleBatch { pairs });
        }
        Err(SamplerError::Exhausted(RECIPE_LABEL_TRIPLETS.into()))
    }
2312
2313 fn try_emit_text_sample(
2318 &mut self,
2319 sample: &TextSample,
2320 batch_texts: &mut HashSet<String>,
2321 ) -> bool {
2322 batch_texts.insert(sample.chunk.text.clone())
2323 && self
2324 .emitted_text_hashes
2325 .entry(sample.chunk.record_id.clone())
2326 .or_default()
2327 .insert(stable_hash_str(0, &sample.chunk.text))
2328 }
2329
    /// Assembles one `TextBatch` for `target_split`, optionally biasing
    /// ingestion with per-source `weights`.
    ///
    /// Two paths:
    /// - No registered sources: draw directly from `self.text_recipes` using a
    ///   weighted, cycled recipe order.
    /// - With sources: round-robin over a shuffled source cycle, each source
    ///   keeping its own weighted recipe order and rotation position.
    ///
    /// Round-robin cursors (`text_recipe_rr_idx`, `source_cycle_idx`) are only
    /// advanced by the steps actually consumed, keeping batch composition
    /// deterministic and resumable. Returns `SamplerError::Exhausted` when a
    /// full batch cannot be produced even after `pad_with_reuse`.
    fn next_text_batch_inner_with_weights(
        &mut self,
        target_split: SplitLabel,
        weights: Option<&HashMap<SourceId, f32>>,
    ) -> Result<TextBatch, SamplerError> {
        self.ingest_with_weights_fallback(target_split, weights)?;
        self.ensure_split_has_records(target_split)?;
        let sources = self.source_order.clone();
        if sources.is_empty() {
            if self.text_recipes.is_empty() {
                return Err(SamplerError::Configuration(
                    "no text recipes configured".into(),
                ));
            }
            // Detach the RNG so it can be borrowed mutably alongside `self`;
            // it is restored (with advanced state) before returning.
            let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
            let recipe_weights: Vec<f32> = self.text_recipes.iter().map(|r| r.weight).collect();
            let recipe_order = self.text_recipe_order_weighted_cycled(
                &recipe_weights,
                self.text_recipe_rr_idx,
                &mut rng,
            );
            let mut samples = Vec::new();
            let mut batch_texts = HashSet::new();
            let mut last_recipe_name = None;
            let mut recipe_pos = 0usize;
            let mut recipe_steps = 0usize;
            // Bounded attempt budget so an unproductive recipe set cannot
            // loop forever.
            let attempts = self.config.batch_size * 4 * recipe_order.len().max(1);
            for _ in 0..attempts {
                if samples.len() >= self.config.batch_size {
                    break;
                }
                let recipe_idx = recipe_order[recipe_pos];
                recipe_pos = (recipe_pos + 1) % recipe_order.len();
                recipe_steps = recipe_steps.saturating_add(1);
                let recipe = self.text_recipes[recipe_idx].clone();
                last_recipe_name = Some(recipe.name.clone());
                // A sample counts only if it passes batch- and record-level
                // dedup in `try_emit_text_sample`.
                if let Some(sample) =
                    self.make_text_sample_for_split(&recipe, None, target_split, &mut rng)
                    && self.try_emit_text_sample(&sample, &mut batch_texts)
                {
                    samples.push(sample);
                }
            }
            if recipe_steps > 0 {
                self.text_recipe_rr_idx = self.text_recipe_rr_idx.saturating_add(recipe_steps);
            }
            self.rng = rng;
            pad_with_reuse(&mut samples, self.config.batch_size);
            if samples.len() == self.config.batch_size {
                return Ok(TextBatch { samples });
            }
            // Report the last recipe tried so exhaustion is attributable.
            return Err(SamplerError::Exhausted(
                last_recipe_name
                    .unwrap_or(Cow::Borrowed(RECIPE_LABEL_TEXT))
                    .to_string(),
            ));
        }

        let mut samples = Vec::new();
        let mut batch_texts = HashSet::new();
        let mut source_steps = 0usize;
        // Recover the persisted position within the shuffled source cycle.
        let mut cycle = (self.source_cycle_idx / sources.len()) as u64;
        let mut idx = self.source_cycle_idx % sources.len();
        let mut cycle_sources = self.shuffled_source_cycle(cycle);
        // Per-source recipe order and rotation position, built lazily.
        let mut recipe_orders: HashMap<RecipeKey, Vec<usize>> = HashMap::new();
        let mut recipe_positions: HashMap<RecipeKey, usize> = HashMap::new();
        let mut recipe_steps = 0usize;
        let max_recipe_len = sources
            .iter()
            .map(|source| self.text_recipes_for_source(source).len())
            .max()
            .unwrap_or(1)
            .max(1);
        let attempts = self.config.batch_size * 4 * sources.len() * max_recipe_len;
        let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
        for _ in 0..attempts {
            if samples.len() >= self.config.batch_size {
                break;
            }
            let source = cycle_sources[idx].as_str();
            let recipes = self.text_recipes_for_source(source).to_vec();
            if recipes.is_empty() {
                // Source without text recipes: still advance the cursor so the
                // round-robin stays fair.
                idx += 1;
                source_steps += 1;
                if idx >= cycle_sources.len() {
                    idx = 0;
                    cycle = cycle.saturating_add(1);
                    cycle_sources = self.shuffled_source_cycle(cycle);
                }
                continue;
            }
            if !recipe_orders.contains_key(source) {
                let recipe_weights: Vec<f32> = recipes.iter().map(|r| r.weight).collect();
                let order = self.text_recipe_order_weighted_cycled(
                    &recipe_weights,
                    self.text_recipe_rr_idx,
                    &mut rng,
                );
                recipe_orders.insert(source.to_string(), order);
            }
            let order = recipe_orders
                .get(source)
                .expect("recipe order missing for source");
            let pos = recipe_positions.entry(source.to_string()).or_insert(0);
            let mut sample: Option<(TextRecipe, TextSample)> = None;
            // Try each recipe at most once per source visit, starting from the
            // source's rotation position.
            for offset in 0..order.len() {
                let recipe_idx = order[(*pos + offset) % order.len()];
                let recipe = recipes[recipe_idx].clone();
                if let Some(item) =
                    self.make_text_sample_for_split(&recipe, Some(source), target_split, &mut rng)
                    && self.try_emit_text_sample(&item, &mut batch_texts)
                {
                    recipe_steps = recipe_steps.saturating_add(offset + 1);
                    *pos = (*pos + offset + 1) % order.len();
                    sample = Some((recipe, item));
                    break;
                }
            }
            if sample.is_none() {
                // All recipes failed for this visit; still account the steps.
                recipe_steps = recipe_steps.saturating_add(order.len());
            }
            if let Some((_recipe, sample)) = sample {
                samples.push(sample);
            }
            idx += 1;
            source_steps += 1;
            if idx >= cycle_sources.len() {
                idx = 0;
                cycle = cycle.saturating_add(1);
                cycle_sources = self.shuffled_source_cycle(cycle);
            }
        }
        if samples.len() != self.config.batch_size {
            pad_with_reuse(&mut samples, self.config.batch_size);
        }
        if samples.len() != self.config.batch_size {
            // On failure the RNG is restored but cursors are NOT advanced, so
            // a retry replays the same deterministic sequence.
            self.rng = rng;
            return Err(SamplerError::Exhausted(RECIPE_LABEL_TEXT.into()));
        }
        self.rng = rng;
        self.source_cycle_idx = self.source_cycle_idx.saturating_add(source_steps);
        self.source_state_dirty = sources.len() > 1;
        if recipe_steps > 0 {
            self.text_recipe_rr_idx = self.text_recipe_rr_idx.saturating_add(recipe_steps);
        }
        Ok(TextBatch { samples })
    }
2477
    /// Assembles one `TripletBatch` for `target_split`, optionally biasing
    /// ingestion with per-source `weights`.
    ///
    /// Without registered sources, triplets are produced sequentially from
    /// `self.triplet_recipes`. With sources, the work is split into:
    /// 1. a sequential planning pass that round-robins over a shuffled source
    ///    cycle, picking an anchor record and a fresh fork seed per slot;
    /// 2. a rayon-parallel pass that selects anchor/positive chunks per slot,
    ///    each slot using its own `DeterministicRng` built from the fork seed
    ///    so parallelism does not perturb determinism;
    /// 3. a sequential finalization pass (negatives, swap, dedup, weighting)
    ///    over the candidates restored to slot order.
    fn next_triplet_batch_inner_with_weights(
        &mut self,
        target_split: SplitLabel,
        weights: Option<&HashMap<SourceId, f32>>,
    ) -> Result<TripletBatch, SamplerError> {
        self.ingest_with_weights_fallback(target_split, weights)?;
        self.ensure_split_has_records(target_split)?;
        let sources = self.source_order.clone();
        if sources.is_empty() {
            if self.triplet_recipes.is_empty() {
                return Err(SamplerError::Configuration(
                    "no triplet recipes configured".into(),
                ));
            }
            // Detach the RNG so it can be borrowed mutably alongside `self`.
            let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
            let recipe_weights: Vec<f32> = self.triplet_recipes.iter().map(|r| r.weight).collect();
            let recipe_order = self.recipe_order_weighted_cycled(
                &recipe_weights,
                self.triplet_recipe_rr_idx,
                &mut rng,
            );
            let mut triplets = Vec::new();
            let mut seen = HashSet::new();
            let mut last_recipe_name = None;
            let mut recipe_pos = 0usize;
            let mut recipe_steps = 0usize;
            // Bounded attempt budget to avoid spinning on barren recipes.
            let attempts = self.config.batch_size * 4 * recipe_order.len().max(1);
            for _ in 0..attempts {
                if triplets.len() >= self.config.batch_size {
                    break;
                }
                let Some(anchor) = self.choose_anchor_record(None, target_split) else {
                    break;
                };
                let mut triplet = None;
                // First recipe (in rotation order) that yields a triplet for
                // this anchor wins; rotation resumes after it next time.
                for offset in 0..recipe_order.len() {
                    let idx = recipe_order[(recipe_pos + offset) % recipe_order.len()];
                    recipe_steps = recipe_steps.saturating_add(1);
                    let recipe = self.triplet_recipes[idx].clone();
                    last_recipe_name = Some(recipe.name.clone());
                    if let Some(sample) = self.make_triplet_with_anchor(&recipe, &anchor, &mut rng)
                    {
                        triplet = Some(sample);
                        recipe_pos = (recipe_pos + offset + 1) % recipe_order.len();
                        break;
                    }
                }
                if let Some(triplet) = triplet {
                    // Dedup on the (anchor, positive, negative) record ids.
                    let key = (
                        triplet.anchor.record_id.clone(),
                        triplet.positive.record_id.clone(),
                        triplet.negative.record_id.clone(),
                    );
                    if seen.insert(key) {
                        triplets.push(triplet);
                    }
                }
            }
            if recipe_steps > 0 {
                self.triplet_recipe_rr_idx =
                    self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
            }
            self.rng = rng;
            pad_with_reuse(&mut triplets, self.config.batch_size);
            if triplets.len() == self.config.batch_size {
                return Ok(TripletBatch { triplets });
            }
            return Err(SamplerError::Exhausted(
                last_recipe_name
                    .unwrap_or(Cow::Borrowed(RECIPE_LABEL_TRIPLETS))
                    .to_string(),
            ));
        }

        let mut triplets: Vec<SampleTriplet> = Vec::new();
        let mut seen: HashSet<(RecordId, RecordId, RecordId)> = HashSet::new();
        let mut source_steps = 0usize;
        // Recover the persisted position within the shuffled source cycle.
        let mut cycle = (self.source_cycle_idx / sources.len()) as u64;
        let mut source_idx = self.source_cycle_idx % sources.len();
        let mut cycle_sources = self.shuffled_source_cycle(cycle);
        let mut recipe_steps = 0usize;
        let max_recipe_len = sources
            .iter()
            .map(|source| self.triplet_recipe_count_for_source(source.as_str()))
            .max()
            .unwrap_or(1)
            .max(1);
        let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));

        // Phase 1 plan: one anchor + recipe set + RNG fork seed per slot.
        struct SlotPlan {
            anchor: Arc<DataRecord>,
            recipes: Vec<TripletRecipe>,
            fork_seed: u64,
        }
        let target_slots = self.config.batch_size * 4 * max_recipe_len;
        let mut slot_plans: Vec<SlotPlan> = Vec::with_capacity(target_slots);

        for _ in 0..target_slots {
            if slot_plans.len() >= target_slots {
                break;
            }
            let source = cycle_sources[source_idx].as_str();
            let (recipes, _) = self.resolve_source_triplet_plan(source);
            if !recipes.is_empty() {
                // Seed drawn sequentially from the main RNG: the parallel
                // phase stays deterministic regardless of thread scheduling.
                let fork_seed = rng.next_u64();
                if let Some(anchor) = self.choose_anchor_record(Some(source), target_split) {
                    slot_plans.push(SlotPlan {
                        anchor,
                        recipes,
                        fork_seed,
                    });
                    recipe_steps = recipe_steps.saturating_add(1);
                }
            }
            source_idx += 1;
            source_steps += 1;
            if source_idx >= cycle_sources.len() {
                source_idx = 0;
                cycle = cycle.saturating_add(1);
                cycle_sources = self.shuffled_source_cycle(cycle);
            }
        }
        self.rng = rng;

        // Phase 2 result: anchor/positive chunk pair chosen for a slot.
        struct SlotCandidate {
            recipe: TripletRecipe,
            anchor: Arc<DataRecord>,
            anchor_chunk: RecordChunk,
            positive_chunk: RecordChunk,
            anchor_raw_text: String,
        }
        let mut raw_candidates: Vec<(usize, Option<SlotCandidate>)> = slot_plans
            .par_iter()
            .enumerate()
            .map(|(slot_idx, plan)| {
                // Each slot gets its own RNG built from its fork seed.
                let mut fork_rng = DeterministicRng::new(plan.fork_seed);
                let weights: Vec<f32> = plan.recipes.iter().map(|r| r.weight).collect();
                let order = weighted_recipe_order(&weights, &mut fork_rng);
                let mut candidate = None;
                for &idx in &order {
                    let recipe = &plan.recipes[idx];
                    if let Some((ac, pc, raw)) =
                        self.select_anchor_positive_for_recipe(recipe, &plan.anchor, &mut fork_rng)
                    {
                        candidate = Some(SlotCandidate {
                            recipe: recipe.clone(),
                            anchor: Arc::clone(&plan.anchor),
                            anchor_chunk: ac,
                            positive_chunk: pc,
                            anchor_raw_text: raw,
                        });
                        break;
                    }
                }
                (slot_idx, candidate)
            })
            .collect();

        // Restore planning order so phase 3 consumes the main RNG in a
        // reproducible sequence.
        raw_candidates.sort_unstable_by_key(|(i, _)| *i);
        let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
        for (_, candidate) in raw_candidates {
            if triplets.len() >= self.config.batch_size {
                break;
            }
            let Some(sc) = candidate else { continue };
            let (negative_record, fallback_used) = match self.select_negative_record(
                &sc.anchor,
                &sc.recipe.negative_strategy,
                Some(&sc.anchor_raw_text),
                &mut rng,
            ) {
                Some(pair) => pair,
                None => continue,
            };
            let mut negative_chunk = match self.select_chunk_parallel(
                &negative_record,
                &sc.recipe.negative_selector,
                &mut rng,
            ) {
                Some(c) => c,
                None => continue,
            };
            self.decorate_chunk_parallel(&negative_record, &mut negative_chunk, &mut rng);

            // Randomly swap anchor/positive to avoid positional bias.
            let (anchor_chunk, positive_chunk) = if rng.next_u64() & ANCHOR_POSITIVE_SWAP_MASK == 0
            {
                (sc.positive_chunk, sc.anchor_chunk)
            } else {
                (sc.anchor_chunk, sc.positive_chunk)
            };

            // Reject degenerate triplets whose member texts collide.
            if (!sc.recipe.allow_same_anchor_positive && anchor_chunk.text == positive_chunk.text)
                || negative_chunk.text == positive_chunk.text
                || negative_chunk.text == anchor_chunk.text
            {
                continue;
            }

            let chunk_weight = self.triplet_chunk_weight(
                &sc.recipe,
                &anchor_chunk,
                &positive_chunk,
                &negative_chunk,
            );
            let weight = sc.recipe.weight * chunk_weight;
            // Mark triplets whose negative came from the same-split fallback.
            let recipe_name = if fallback_used {
                format!("{}_fallback_same_split", sc.recipe.name)
            } else {
                sc.recipe.name.to_string()
            };
            let triplet = SampleTriplet {
                recipe: recipe_name,
                anchor: anchor_chunk,
                positive: positive_chunk,
                negative: negative_chunk,
                weight,
                instruction: sc.recipe.instruction.as_ref().map(|s| s.to_string()),
            };
            let key = (
                triplet.anchor.record_id.clone(),
                triplet.positive.record_id.clone(),
                triplet.negative.record_id.clone(),
            );
            if seen.insert(key) && triplets.len() < self.config.batch_size {
                triplets.push(triplet);
            }
        }
        self.rng = rng;

        if recipe_steps > 0 {
            self.triplet_recipe_rr_idx = self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
        }
        pad_with_reuse(&mut triplets, self.config.batch_size);
        if triplets.len() == self.config.batch_size {
            // Cursors advance only on success so a retry replays identically.
            self.source_cycle_idx = self.source_cycle_idx.saturating_add(source_steps);
            self.source_state_dirty = sources.len() > 1;
            let batch = TripletBatch { triplets };
            return Ok(batch);
        }
        Err(SamplerError::Exhausted(RECIPE_LABEL_TRIPLETS.into()))
    }
2735
    /// Test-only accessor that downcasts the negative-mining backend to the
    /// concrete BM25 implementation.
    ///
    /// # Panics
    /// Panics if the active backend is not a `Bm25Backend`.
    #[cfg(test)]
    #[cfg(feature = "bm25-mining")]
    fn bm25_backend_mut(&mut self) -> &mut backends::Bm25Backend {
        self.negative_backend
            .as_any_mut()
            .downcast_mut::<backends::Bm25Backend>()
            .expect("bm25_backend_mut: negative_backend is Bm25Backend when bm25-mining feature is active")
    }
2749
2750 #[cfg(test)]
2752 fn recipe_order_weighted_shuffled_seeded(&mut self, weights: &[f32]) -> Vec<usize> {
2753 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2754 let result = self.recipe_order_weighted_shuffled(weights, &mut rng);
2755 self.rng = rng;
2756 result
2757 }
2758
2759 #[cfg(test)]
2761 fn recipe_order_weighted_cycled_seeded(
2762 &mut self,
2763 weights: &[f32],
2764 rr_idx: usize,
2765 ) -> Vec<usize> {
2766 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2767 let result = self.recipe_order_weighted_cycled(weights, rr_idx, &mut rng);
2768 self.rng = rng;
2769 result
2770 }
2771
2772 #[cfg(test)]
2774 fn text_recipe_order_weighted_shuffled_seeded(&mut self, weights: &[f32]) -> Vec<usize> {
2775 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2776 let result = self.text_recipe_order_weighted_shuffled(weights, &mut rng);
2777 self.rng = rng;
2778 result
2779 }
2780
2781 #[cfg(test)]
2783 fn text_recipe_order_weighted_cycled_seeded(
2784 &mut self,
2785 weights: &[f32],
2786 rr_idx: usize,
2787 ) -> Vec<usize> {
2788 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2789 let result = self.text_recipe_order_weighted_cycled(weights, rr_idx, &mut rng);
2790 self.rng = rng;
2791 result
2792 }
2793
2794 #[cfg(test)]
2796 fn select_negative_record_seeded(
2797 &mut self,
2798 anchor_record: &DataRecord,
2799 strategy: &NegativeStrategy,
2800 anchor_query_text: Option<&str>,
2801 ) -> Option<(Arc<DataRecord>, bool)> {
2802 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2803 let result =
2804 self.select_negative_record(anchor_record, strategy, anchor_query_text, &mut rng);
2805 self.rng = rng;
2806 result
2807 }
2808
2809 #[cfg(test)]
2811 fn make_triplet_with_anchor_seeded(
2812 &mut self,
2813 recipe: &TripletRecipe,
2814 anchor: &DataRecord,
2815 ) -> Option<SampleTriplet> {
2816 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2817 let result = self.make_triplet_with_anchor(recipe, anchor, &mut rng);
2818 self.rng = rng;
2819 result
2820 }
2821
2822 #[cfg(test)]
2824 fn decorate_chunk_seeded(&mut self, record: &DataRecord, chunk: &mut RecordChunk) {
2825 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2826 self.decorate_chunk(record, chunk, &mut rng);
2827 self.rng = rng;
2828 }
2829
    /// Exposes the negative backend's BM25 fallback counters.
    ///
    /// NOTE(review): the tuple's field order/semantics are defined by the
    /// backend implementation — confirm in `backends` before relying on it.
    #[cfg(all(feature = "bm25-mining", feature = "extended-metrics"))]
    fn bm25_fallback_stats(&self) -> (u64, u64) {
        self.negative_backend.bm25_fallback_stats()
    }
2839
2840 #[cfg(test)]
2845 #[cfg(feature = "bm25-mining")]
2846 fn bm25_ranked_candidates(&mut self, anchor: &crate::data::DataRecord) -> Vec<RecordId> {
2847 let split = self
2848 .split_store
2849 .label_for(&anchor.id)
2850 .unwrap_or(SplitLabel::Train);
2851 self.negative_backend
2852 .as_any_mut()
2853 .downcast_mut::<backends::Bm25Backend>()
2854 .expect("bm25_ranked_candidates: Bm25Backend")
2855 .ranked_candidates_pub(anchor, split)
2856 }
2857}
2858
2859fn weighted_recipe_order(weights: &[f32], rng: &mut DeterministicRng) -> Vec<usize> {
2870 let nonzero: Vec<(usize, f32)> = weights
2871 .iter()
2872 .enumerate()
2873 .filter(|(_, w)| **w > 0.0)
2874 .map(|(i, &w)| (i, w))
2875 .collect();
2876 if nonzero.is_empty() {
2877 return Vec::new();
2878 }
2879 let w_min = nonzero
2880 .iter()
2881 .map(|(_, w)| *w)
2882 .fold(f32::INFINITY, f32::min);
2883 let mut order: Vec<usize> = Vec::new();
2884 for (recipe_idx, w) in &nonzero {
2885 let tickets = ((w / w_min).round() as usize).clamp(1, RECIPE_ORDER_MAX_WEIGHT_MULTIPLIER);
2886 for _ in 0..tickets {
2887 order.push(*recipe_idx);
2888 }
2889 }
2890 order.shuffle(rng);
2891 order
2892}
2893
2894fn same_selector_pair_is_valid(
2895 anchor_chunk: &RecordChunk,
2896 positive_chunk: &RecordChunk,
2897 enforce_window_pair: bool,
2898) -> bool {
2899 if triplet_chunk_key(anchor_chunk) == triplet_chunk_key(positive_chunk) {
2900 return false;
2901 }
2902 if !enforce_window_pair {
2903 return true;
2904 }
2905 matches!(
2906 (&anchor_chunk.view, &positive_chunk.view),
2907 (ChunkView::Window { .. }, ChunkView::Window { .. })
2908 )
2909}
2910
/// Thread-safe public facade over `TripletSamplerInner`: every entry point
/// locks the inner sampler for the duration of the call.
impl<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> TripletSampler<S> {
    /// Creates a sampler with the default chunking behavior.
    pub fn new(config: SamplerConfig, split_store: Arc<S>) -> Self {
        let inner = TripletSamplerInner::new(config, split_store);
        Self {
            inner: Mutex::new(inner),
        }
    }

    /// Creates a sampler using a caller-provided chunking algorithm.
    pub fn new_with_chunker(
        config: SamplerConfig,
        split_store: Arc<S>,
        chunker: Arc<dyn ChunkingAlgorithm>,
    ) -> Self {
        let inner = TripletSamplerInner::new_with_chunker(config, split_store, chunker);
        Self {
            inner: Mutex::new(inner),
        }
    }

    /// Next pair batch for `split` with no per-source weight overrides.
    pub fn next_pair_batch_for_split(
        &self,
        split: SplitLabel,
    ) -> Result<SampleBatch, SamplerError> {
        self.next_pair_batch_with_weights_for_split(split, &HashMap::new())
    }

    /// Next text batch for `split` with no per-source weight overrides.
    pub fn next_text_batch_for_split(&self, split: SplitLabel) -> Result<TextBatch, SamplerError> {
        self.next_text_batch_with_weights_for_split(split, &HashMap::new())
    }

    /// Next triplet batch for `split` with no per-source weight overrides.
    pub fn next_triplet_batch_for_split(
        &self,
        split: SplitLabel,
    ) -> Result<TripletBatch, SamplerError> {
        self.next_triplet_batch_with_weights_for_split(split, &HashMap::new())
    }

    /// Shared retry loop for all batch kinds.
    ///
    /// Locks the inner sampler, validates the split, counts one epoch step,
    /// then invokes `inner_fn` up to `EXHAUSTION_RETRY_LIMIT + 1` times,
    /// forcing an ingestion refresh between `Exhausted` attempts. Any other
    /// error is returned immediately.
    fn next_batch_with_retry<B>(
        &self,
        split: SplitLabel,
        weights: &HashMap<SourceId, f32>,
        inner_fn: impl Fn(
            &mut TripletSamplerInner<S>,
            SplitLabel,
            Option<&HashMap<SourceId, f32>>,
        ) -> Result<B, SamplerError>,
    ) -> Result<B, SamplerError> {
        let mut inner = self.inner.lock().unwrap();
        inner.ensure_split_allowed(split)?;
        inner.ensure_ingestion_cursors_loaded()?;
        // Each caller-visible batch request counts as one epoch step.
        inner.ingestion.increment_epoch_step();
        for attempt in 0..=EXHAUSTION_RETRY_LIMIT {
            match inner_fn(&mut inner, split, Some(weights)) {
                Ok(batch) => return Ok(batch),
                Err(SamplerError::Exhausted(_)) if attempt < EXHAUSTION_RETRY_LIMIT => {
                    inner.force_ingest_refresh_with_weights_for_split(split, weights)?;
                }
                Err(err) => return Err(err),
            }
        }
        // On the final attempt the `Exhausted` guard fails, so every error
        // returns via the last match arm; this point is never reached.
        unreachable!()
    }

    /// Next pair batch for `split`, biasing sources by `weights`; retries on
    /// exhaustion via `next_batch_with_retry`.
    pub fn next_pair_batch_with_weights_for_split(
        &self,
        split: SplitLabel,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<SampleBatch, SamplerError> {
        self.next_batch_with_retry(split, weights, |inner, split, weights| {
            inner.next_pair_batch_inner_with_weights(split, weights)
        })
    }

    /// Next text batch for `split`, biasing sources by `weights`; retries on
    /// exhaustion via `next_batch_with_retry`.
    pub fn next_text_batch_with_weights_for_split(
        &self,
        split: SplitLabel,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<TextBatch, SamplerError> {
        self.next_batch_with_retry(split, weights, |inner, split, weights| {
            inner.next_text_batch_inner_with_weights(split, weights)
        })
    }

    /// Next triplet batch for `split`, biasing sources by `weights`; retries
    /// on exhaustion via `next_batch_with_retry`.
    pub fn next_triplet_batch_with_weights_for_split(
        &self,
        split: SplitLabel,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<TripletBatch, SamplerError> {
        self.next_batch_with_retry(split, weights, |inner, split, weights| {
            inner.next_triplet_batch_inner_with_weights(split, weights)
        })
    }

    /// Spawns a background prefetcher producing triplet batches for `split`.
    /// Takes `Arc<Self>` so the producer closure can own the sampler.
    pub fn prefetch_triplet_batches(
        self: Arc<Self>,
        split: SplitLabel,
        capacity: usize,
    ) -> BatchPrefetcher<TripletBatch> {
        BatchPrefetcher::new(capacity, move || self.next_triplet_batch_for_split(split))
    }

    /// As `prefetch_triplet_batches`, with per-source weight overrides.
    pub fn prefetch_triplet_batches_with_weights(
        self: Arc<Self>,
        split: SplitLabel,
        capacity: usize,
        weights: HashMap<SourceId, f32>,
    ) -> BatchPrefetcher<TripletBatch> {
        BatchPrefetcher::new(capacity, move || {
            self.next_triplet_batch_with_weights_for_split(split, &weights)
        })
    }

    /// Spawns a background prefetcher producing pair batches for `split`.
    pub fn prefetch_pair_batches(
        self: Arc<Self>,
        split: SplitLabel,
        capacity: usize,
    ) -> BatchPrefetcher<SampleBatch> {
        BatchPrefetcher::new(capacity, move || self.next_pair_batch_for_split(split))
    }

    /// As `prefetch_pair_batches`, with per-source weight overrides.
    pub fn prefetch_pair_batches_with_weights(
        self: Arc<Self>,
        split: SplitLabel,
        capacity: usize,
        weights: HashMap<SourceId, f32>,
    ) -> BatchPrefetcher<SampleBatch> {
        BatchPrefetcher::new(capacity, move || {
            self.next_pair_batch_with_weights_for_split(split, &weights)
        })
    }

    /// Spawns a background prefetcher producing text batches for `split`.
    pub fn prefetch_text_batches(
        self: Arc<Self>,
        split: SplitLabel,
        capacity: usize,
    ) -> BatchPrefetcher<TextBatch> {
        BatchPrefetcher::new(capacity, move || self.next_text_batch_for_split(split))
    }

    /// As `prefetch_text_batches`, with per-source weight overrides.
    pub fn prefetch_text_batches_with_weights(
        self: Arc<Self>,
        split: SplitLabel,
        capacity: usize,
        weights: HashMap<SourceId, f32>,
    ) -> BatchPrefetcher<TextBatch> {
        BatchPrefetcher::new(capacity, move || {
            self.next_text_batch_with_weights_for_split(split, &weights)
        })
    }

    /// Returns a snapshot (clone) of the configured text recipes.
    pub fn text_recipes(&self) -> Vec<TextRecipe> {
        let inner = self.inner.lock().unwrap();
        inner.text_recipes().to_vec()
    }

    /// Registers a new data source with the inner sampler.
    pub fn register_source(
        &self,
        source: Box<dyn DataSource + 'static>,
    ) -> Result<(), SamplerError> {
        let mut inner = self.inner.lock().unwrap();
        inner.register_source(source)
    }

    /// Sets the current epoch on the inner sampler.
    pub fn set_epoch(&self, epoch: u64) -> Result<(), SamplerError> {
        let mut inner = self.inner.lock().unwrap();
        inner.set_epoch(epoch)
    }

    /// Persists the sampler state, optionally to an explicit path.
    pub fn save_sampler_state(&self, save_to: Option<&Path>) -> Result<(), SamplerError> {
        let mut inner = self.inner.lock().unwrap();
        inner.save_sampler_state(save_to)
    }

    /// BM25 fallback counters from the inner sampler (feature-gated).
    #[cfg(all(feature = "bm25-mining", feature = "extended-metrics"))]
    pub fn bm25_fallback_stats(&self) -> (u64, u64) {
        let inner = self.inner.lock().unwrap();
        inner.bm25_fallback_stats()
    }
}
3128
/// `Sampler` trait implementation: every method forwards to the matching
/// `*_for_split` inherent method on `TripletSampler`.
impl<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> Sampler for TripletSampler<S> {
    fn next_pair_batch(&self, split: SplitLabel) -> Result<SampleBatch, SamplerError> {
        self.next_pair_batch_for_split(split)
    }

    fn next_pair_batch_with_weights(
        &self,
        split: SplitLabel,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<SampleBatch, SamplerError> {
        self.next_pair_batch_with_weights_for_split(split, weights)
    }

    fn next_text_batch(&self, split: SplitLabel) -> Result<TextBatch, SamplerError> {
        self.next_text_batch_for_split(split)
    }

    fn next_text_batch_with_weights(
        &self,
        split: SplitLabel,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<TextBatch, SamplerError> {
        self.next_text_batch_with_weights_for_split(split, weights)
    }

    fn next_triplet_batch(&self, split: SplitLabel) -> Result<TripletBatch, SamplerError> {
        self.next_triplet_batch_for_split(split)
    }

    fn next_triplet_batch_with_weights(
        &self,
        split: SplitLabel,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<TripletBatch, SamplerError> {
        self.next_triplet_batch_with_weights_for_split(split, weights)
    }
}
3166
/// True when two section roles are equal; kept as a named helper so call
/// sites read as intent ("does this candidate match the target role?").
fn roles_match(target: &SectionRole, candidate: &SectionRole) -> bool {
    target == candidate
}
3170
3171fn role_cursor_key(record_id: &RecordId, role: &SectionRole) -> (RecordId, String) {
3172 (record_id.clone(), role_label(role))
3173}
3174
/// Maps a section role to its stable string label (the constants from
/// `crate::constants::sampler`), used when keying role cursors.
fn role_label(role: &SectionRole) -> String {
    // Exhaustive match: adding a `SectionRole` variant forces an update here.
    match role {
        SectionRole::Anchor => ROLE_LABEL_ANCHOR.into(),
        SectionRole::Context => ROLE_LABEL_CONTEXT.into(),
    }
}
3181
3182fn taxonomy_value(record: &DataRecord, field: MetadataKey) -> Option<&str> {
3183 record.taxonomy.iter().find_map(|entry| field.strip(entry))
3184}
3185
/// Maps a negative-mining strategy to the static reason label attached to
/// negative pairs (see `SamplePair.reason`).
fn strategy_reason(strategy: &NegativeStrategy) -> &'static str {
    // Exhaustive match: adding a strategy variant forces an update here.
    match strategy {
        NegativeStrategy::WrongPublicationDate => NEG_REASON_WRONG_DATE,
        NegativeStrategy::WrongArticle => NEG_REASON_WRONG_ARTICLE,
        NegativeStrategy::QuestionAnswerMismatch => NEG_REASON_WRONG_QA,
    }
}
3193
/// Test-only: dedup identity of a text chunk is (owning record id, exact
/// chunk text).
#[cfg(test)]
fn text_dedup_key(chunk: &RecordChunk) -> (String, String) {
    let record = chunk.record_id.clone();
    let text = chunk.text.clone();
    (record, text)
}
3205
3206fn triplet_chunk_key(chunk: &RecordChunk) -> String {
3207 match &chunk.view {
3208 ChunkView::Window { index, .. } => {
3209 format!("{}|{}|w|{}", chunk.record_id, chunk.section_idx, index)
3210 }
3211 ChunkView::SummaryFallback { strategy, .. } => {
3212 format!("{}|{}|s|{}", chunk.record_id, chunk.section_idx, strategy)
3213 }
3214 }
3215}
3216
/// Pads `items` up to `target` elements by cloning existing elements in
/// round-robin order (item 0, 1, 2, …, wrapping as needed).
///
/// No-op when `items` is empty (nothing to reuse) or already holds `target`
/// or more elements. Only the original prefix is used as the padding seed,
/// so the result is identical to cycling the pre-call contents.
fn pad_with_reuse<T: Clone>(items: &mut Vec<T>, target: usize) {
    let base_len = items.len();
    if base_len == 0 || base_len >= target {
        return;
    }
    // Reserve once instead of growing incrementally, and clone elements
    // directly from the original prefix — the previous version cloned the
    // whole Vec into a temporary seed, cloning every element twice.
    items.reserve(target - base_len);
    for idx in 0..(target - base_len) {
        let filler = items[idx % base_len].clone();
        items.push(filler);
    }
}
3227
3228#[cfg(test)]
3229mod tests;