1mod backends;
2
3use crate::chunking::{ChunkingAlgorithm, SlidingWindowChunker};
4use chrono::Duration;
5use indexmap::IndexMap;
6use rand::prelude::*;
7use rayon::prelude::*;
8use std::borrow::Cow;
9use std::collections::{HashMap, HashSet};
10use std::path::Path;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::sync::mpsc;
13use std::sync::{Arc, Mutex};
14use std::thread;
15
16use crate::config::{
17 ChunkingStrategy, NegativeStrategy, SamplerConfig, Selector, TextRecipe, TripletRecipe,
18};
19use crate::constants::sampler::AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME;
20use crate::constants::sampler::{
21 ANCHOR_POSITIVE_SWAP_MASK, EPOCH_SEED_OFFSET, EXHAUSTION_RETRY_LIMIT, NEG_REASON_WRONG_ARTICLE,
22 NEG_REASON_WRONG_DATE, NEG_REASON_WRONG_QA, PREFETCHER_SOURCE_ID, PREFETCHER_STOPPED_REASON,
23 RECIPE_LABEL_TEXT, RECIPE_LABEL_TRIPLETS, RECIPE_ORDER_MAX_WEIGHT_MULTIPLIER,
24 ROLE_LABEL_ANCHOR, ROLE_LABEL_CONTEXT, SAME_SELECTOR_PAIR_RETRY_LIMIT,
25};
26use crate::data::{
27 ChunkView, DataRecord, PairLabel, RecordChunk, RecordSection, SampleBatch, SamplePair,
28 SampleTriplet, SectionRole, TextBatch, TextSample, TripletBatch,
29};
30use crate::epoch::EpochTracker;
31use crate::errors::SamplerError;
32use crate::hash::{derive_epoch_seed, stable_hash_str};
33use crate::ingestion::IngestionManager;
34use crate::metadata::{META_FIELD_DATE, MetadataKey};
35use crate::metrics::{chunk_proximity_score, window_index_proximity};
36use crate::source::DataSource;
37use crate::splits::{
38 EpochStateStore, PersistedSamplerState, SamplerStateStore, SplitLabel, SplitStore,
39};
40use crate::tokenizer::{Tokenizer, WhitespaceTokenizer};
41use crate::types::{RecipeKey, RecordId, SourceId};
42use crate::utils::platform_newline;
43
44#[derive(Debug, Clone)]
57struct DeterministicRng {
59 state: u64,
60}
61
62impl DeterministicRng {
63 fn new(seed: u64) -> Self {
64 Self { state: seed }
65 }
66
67 fn from_state(state: u64) -> Self {
68 Self { state }
69 }
70
71 fn state(&self) -> u64 {
72 self.state
73 }
74
75 fn next_u64_internal(&mut self) -> u64 {
76 let mut z = self.state.wrapping_add(0x9E3779B97F4A7C15);
77 self.state = z;
78 z = (z ^ (z >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
79 z = (z ^ (z >> 27)).wrapping_mul(0x94D049BB133111EB);
80 z ^ (z >> 31)
81 }
82}
83
84impl rand::RngCore for DeterministicRng {
85 fn next_u32(&mut self) -> u32 {
86 self.next_u64_internal() as u32
87 }
88
89 fn next_u64(&mut self) -> u64 {
90 self.next_u64_internal()
91 }
92
93 fn fill_bytes(&mut self, dest: &mut [u8]) {
94 let mut offset = 0;
95 while offset < dest.len() {
96 let value = self.next_u64_internal();
97 let bytes = value.to_le_bytes();
98 let remaining = dest.len() - offset;
99 let copy_len = remaining.min(bytes.len());
100 dest[offset..offset + copy_len].copy_from_slice(&bytes[..copy_len]);
101 offset += copy_len;
102 }
103 }
104}
105
106pub fn chunk_weight(strategy: &ChunkingStrategy, chunk: &RecordChunk) -> f32 {
115 let floor = strategy.chunk_weight_floor;
116 let trust = chunk.quality.trust.clamp(0.0, 1.0);
117 let base = match &chunk.view {
118 ChunkView::Window { index, .. } => window_index_proximity(*index),
119 ChunkView::SummaryFallback { weight, .. } => *weight,
120 };
121 (base * trust).max(floor)
122}
123
124pub trait Sampler {
126 fn next_pair_batch(&self, split: SplitLabel) -> Result<SampleBatch, SamplerError> {
128 self.next_pair_batch_with_weights(split, &HashMap::new())
129 }
130 fn next_text_batch(&self, split: SplitLabel) -> Result<TextBatch, SamplerError> {
132 self.next_text_batch_with_weights(split, &HashMap::new())
133 }
134 fn next_triplet_batch(&self, split: SplitLabel) -> Result<TripletBatch, SamplerError> {
136 self.next_triplet_batch_with_weights(split, &HashMap::new())
137 }
138 fn next_pair_batch_with_weights(
140 &self,
141 split: SplitLabel,
142 weights: &HashMap<SourceId, f32>,
143 ) -> Result<SampleBatch, SamplerError>;
144 fn next_text_batch_with_weights(
146 &self,
147 split: SplitLabel,
148 weights: &HashMap<SourceId, f32>,
149 ) -> Result<TextBatch, SamplerError>;
150 fn next_triplet_batch_with_weights(
152 &self,
153 split: SplitLabel,
154 weights: &HashMap<SourceId, f32>,
155 ) -> Result<TripletBatch, SamplerError>;
156}
157
158pub struct BatchPrefetcher<T> {
160 receiver: Option<mpsc::Receiver<Result<T, SamplerError>>>,
161 handle: Option<thread::JoinHandle<()>>,
162 stats: Arc<PrefetcherStats>,
163}
164
/// Counters shared between a `BatchPrefetcher` and its producer thread.
///
/// All counters start at zero (`Default`).
#[derive(Default)]
struct PrefetcherStats {
    /// Approximate number of results buffered in the channel and not yet returned by `next`.
    queued: AtomicUsize,
    /// Number of results successfully handed to the channel.
    produced: AtomicUsize,
    /// Number of producer calls that returned an error.
    errors: AtomicUsize,
}
172
173impl<T: Send + 'static> BatchPrefetcher<T> {
174 fn new<F>(capacity: usize, mut producer: F) -> Self
175 where
176 F: FnMut() -> Result<T, SamplerError> + Send + 'static,
177 {
178 let (sender, receiver) = mpsc::sync_channel(capacity.max(1));
179 let stats = Arc::new(PrefetcherStats::default());
180 let stats_thread = Arc::clone(&stats);
181 let handle = thread::spawn(move || {
182 loop {
183 let result = producer();
184 if result.is_err() {
185 stats_thread.errors.fetch_add(1, Ordering::Relaxed);
186 }
187 if sender.send(result).is_err() {
188 return;
189 }
190 stats_thread.queued.fetch_add(1, Ordering::Relaxed);
191 stats_thread.produced.fetch_add(1, Ordering::Relaxed);
192 }
193 });
194 Self {
195 receiver: Some(receiver),
196 handle: Some(handle),
197 stats,
198 }
199 }
200
201 pub fn next(&self) -> Result<T, SamplerError> {
203 let receiver = self
204 .receiver
205 .as_ref()
206 .ok_or_else(|| SamplerError::SourceUnavailable {
207 source_id: PREFETCHER_SOURCE_ID.into(),
208 reason: PREFETCHER_STOPPED_REASON.into(),
209 })?;
210 let result = receiver.recv().unwrap_or_else(|_| {
211 Err(SamplerError::SourceUnavailable {
212 source_id: PREFETCHER_SOURCE_ID.into(),
213 reason: PREFETCHER_STOPPED_REASON.into(),
214 })
215 });
216 self.stats
217 .queued
218 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |value| {
219 Some(value.saturating_sub(1))
220 })
221 .ok();
222 result
223 }
224
225 pub fn queue_len(&self) -> usize {
227 self.stats.queued.load(Ordering::Relaxed)
228 }
229
230 pub fn produced_count(&self) -> usize {
232 self.stats.produced.load(Ordering::Relaxed)
233 }
234
235 pub fn error_count(&self) -> usize {
237 self.stats.errors.load(Ordering::Relaxed)
238 }
239}
240
241impl<T> Drop for BatchPrefetcher<T> {
242 fn drop(&mut self) {
243 self.receiver.take();
244 if let Some(handle) = self.handle.take() {
245 let _ = handle.join();
246 }
247 }
248}
249
250pub struct TripletSampler<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> {
253 inner: Mutex<TripletSamplerInner<S>>,
254}
255
/// Mutable state behind `TripletSampler`: loaded records, recipes, per-source traversal
/// cursors, and the persisted sampling state.
struct TripletSamplerInner<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> {
    /// Sampler configuration (seed, batch sizing, recipes, chunking, allowed splits).
    config: SamplerConfig,
    /// Chunking algorithm; presumably used when materializing section chunks — confirm.
    chunker: Arc<dyn ChunkingAlgorithm>,
    /// Store for split labels, epoch state, and sampler state.
    split_store: Arc<S>,
    /// Manages registered data sources and their ingestion.
    ingestion: IngestionManager,
    /// Loaded records keyed by id, in insertion order; positions are referenced by
    /// `source_record_indices`.
    records: IndexMap<RecordId, Arc<DataRecord>>,
    /// Deterministic RNG whose state is persisted with the sampler state.
    rng: DeterministicRng,
    /// Triplet recipes taken from the config (empty when the config supplies none).
    triplet_recipes: Vec<TripletRecipe>,
    /// Text recipes: configured, or derived from triplet recipes.
    text_recipes: Vec<TextRecipe>,
    /// Per-source default triplet recipes (only populated when the config supplies none).
    source_triplet_recipes: HashMap<SourceId, Vec<TripletRecipe>>,
    /// Sources eligible for the auto-injected long-section chunk-pair recipe.
    sources_with_long_sections: HashSet<SourceId>,
    /// Text recipes derived from each source's default triplet recipes.
    source_text_recipes: HashMap<SourceId, Vec<TextRecipe>>,
    /// True when `config.recipes` was non-empty at construction.
    using_config_triplet_recipes: bool,
    /// True when `config.text_recipes` was non-empty at construction.
    using_config_text_recipes: bool,
    /// NOTE(review): not read in the visible code; presumably the last ingestion counter seen.
    last_observed_ingest: u64,
    /// Epoch-based record iteration, used when no source is specified.
    epoch_tracker: EpochTracker,
    /// Round-robin chunk cursor per (record id, section index).
    chunk_cursors: HashMap<(RecordId, usize), usize>,
    /// Cursor per (record id, role key); NOTE(review): only pruned in the visible code.
    role_cursors: HashMap<(RecordId, String), usize>,
    /// Strategy used to choose negatives (BM25 when the `bm25-mining` feature is enabled).
    negative_backend: Box<dyn backends::NegativeBackend>,
    /// Chunk id -> owning record id (currently an identity map, see `rebuild_chunk_index`).
    chunk_index: HashMap<RecordId, RecordId>,
    /// Sorted ids of sources that have records in the allowed splits.
    source_order: Vec<SourceId>,
    /// Position in the source rotation; persisted.
    source_cycle_idx: usize,
    /// Whether persisted source state has been loaded.
    source_state_loaded: bool,
    /// NOTE(review): not read in the visible code; presumably tracks stream-cursor restore.
    ingestion_cursors_loaded: bool,
    /// Whether source state changed since the last persist.
    source_state_dirty: bool,
    /// Per source, indices into `records` (allowed splits only), shuffled per source epoch.
    source_record_indices: HashMap<SourceId, Vec<usize>>,
    /// Per source, how many positions of its index list have been consumed.
    source_record_cursors: HashMap<SourceId, usize>,
    /// Round-robin start offset for triplet recipe ordering; persisted.
    triplet_recipe_rr_idx: usize,
    /// Hashes of emitted text; cleared whenever the source epoch advances.
    emitted_text_hashes: HashSet<u64>,
    /// Round-robin start offset for text recipe ordering; persisted.
    text_recipe_rr_idx: usize,
    /// Current source epoch; mixed into the shuffle seeds.
    source_epoch: u64,
    /// Whether each source has completed a full pass in the current source epoch.
    source_wrapped: HashMap<SourceId, bool>,
}
323
324impl<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> TripletSamplerInner<S> {
325 fn new(config: SamplerConfig, split_store: Arc<S>) -> Self {
326 Self::new_with_chunker(config, split_store, Arc::new(SlidingWindowChunker))
327 }
328
329 fn new_with_chunker(
330 config: SamplerConfig,
331 split_store: Arc<S>,
332 chunker: Arc<dyn ChunkingAlgorithm>,
333 ) -> Self {
334 let buffer_size = config.ingestion_max_records.max(config.batch_size).max(2);
335 let using_config_triplet_recipes = !config.recipes.is_empty();
336 let using_config_text_recipes = !config.text_recipes.is_empty();
337 let triplet_recipes = if using_config_triplet_recipes {
338 config.recipes.clone()
339 } else {
340 Vec::new()
341 };
342 let text_recipes = if using_config_text_recipes {
343 config.text_recipes.clone()
344 } else if !triplet_recipes.is_empty() {
345 Self::build_derived_text_recipes(&triplet_recipes)
346 } else {
347 Vec::new()
348 };
349 let ingestion = IngestionManager::new(buffer_size, config.clone());
350 let epoch_backend = Some(Arc::clone(&split_store) as Arc<dyn EpochStateStore>);
351 let epoch_tracker = EpochTracker::new(
352 true,
353 epoch_backend,
354 derive_epoch_seed(config.seed, EPOCH_SEED_OFFSET),
355 );
356 let mut sampler = Self {
357 rng: DeterministicRng::new(config.seed),
358 config,
359 chunker,
360 split_store,
361 ingestion,
362 records: IndexMap::new(),
363 triplet_recipes,
364 text_recipes,
365 source_triplet_recipes: HashMap::new(),
366 sources_with_long_sections: HashSet::new(),
367 source_text_recipes: HashMap::new(),
368 using_config_triplet_recipes,
369 using_config_text_recipes,
370 last_observed_ingest: 0,
371 epoch_tracker,
372 chunk_cursors: HashMap::new(),
373 role_cursors: HashMap::new(),
374 negative_backend: {
375 #[cfg(feature = "bm25-mining")]
376 {
377 Box::new(backends::Bm25Backend::new())
378 }
379 #[cfg(not(feature = "bm25-mining"))]
380 {
381 Box::new(backends::DefaultBackend)
382 }
383 },
384 chunk_index: HashMap::new(),
385 source_order: Vec::new(),
386 source_cycle_idx: 0,
387 source_state_loaded: false,
388 ingestion_cursors_loaded: false,
389 source_state_dirty: false,
390 source_record_indices: HashMap::new(),
391 source_record_cursors: HashMap::new(),
392 emitted_text_hashes: HashSet::new(),
393 triplet_recipe_rr_idx: 0,
394 text_recipe_rr_idx: 0,
395 source_epoch: 0,
396 source_wrapped: HashMap::new(),
397 };
398 if !sampler.using_config_text_recipes {
399 sampler.rebuild_derived_text_recipes();
400 }
401 sampler
402 }
403
404 fn text_recipes(&self) -> &[TextRecipe] {
405 &self.text_recipes
406 }
407
408 fn epoch_seed(&self) -> u64 {
411 derive_epoch_seed(self.config.seed, self.source_epoch)
412 }
413
414 fn register_source(&mut self, source: Box<dyn DataSource + 'static>) {
415 let source_id = source.id().to_string();
416 if !self.using_config_triplet_recipes {
417 let triplets = source.default_triplet_recipes();
418 if !triplets.is_empty() {
419 self.source_triplet_recipes
420 .insert(source_id.clone(), triplets.clone());
421 if !self.using_config_text_recipes {
422 let derived = Self::build_derived_text_recipes(&triplets);
423 self.source_text_recipes
424 .insert(source_id.clone(), derived.clone());
425 self.extend_text_recipes_unique(&derived);
426 }
427 }
428 }
429 self.ingestion.register_source(source);
430 }
431
432 fn set_epoch(&mut self, epoch: u64) -> Result<(), SamplerError> {
433 self.epoch_tracker.ensure_loaded()?;
434 self.epoch_tracker.force_epoch(epoch);
435 self.source_epoch = epoch;
436 self.ingestion.set_source_epoch(epoch);
437 self.ingestion.reset_stream_cursors();
438 self.source_record_cursors.clear();
439 self.source_cycle_idx = 0;
440 for source in &self.source_order {
441 self.source_wrapped.insert(source.clone(), false);
442 }
443 self.rebuild_source_index()?;
444 self.source_state_dirty = self.source_order.len() > 1;
445 Ok(())
446 }
447
448 fn next_chunk_from_pool(
449 &mut self,
450 record_id: &str,
451 section_idx: usize,
452 pool: Vec<RecordChunk>,
453 ) -> Option<RecordChunk> {
454 if pool.is_empty() {
455 return None;
456 }
457 let key = (record_id.to_string(), section_idx);
458 if !self.chunk_cursors.contains_key(&key) {
459 let cursor_key = format!("{}::{}", record_id, section_idx);
465 let start = (stable_hash_str(self.epoch_seed(), &cursor_key) as usize) % pool.len();
466 self.chunk_cursors.insert(key.clone(), start);
467 }
468 let cursor = self.chunk_cursors.entry(key).or_insert(0);
469 if *cursor >= pool.len() {
470 *cursor = 0;
471 }
472 let chunk = pool.get(*cursor).cloned();
473 *cursor = (*cursor + 1) % pool.len();
474 chunk
475 }
476
477 fn prune_cursor_state(&mut self) {
478 if self.chunk_cursors.is_empty()
479 && self.role_cursors.is_empty()
480 && self.negative_backend.cursors_empty()
481 {
482 return;
483 }
484 let valid_ids: HashSet<RecordId> = self.records.keys().cloned().collect();
485 self.chunk_cursors
486 .retain(|(record_id, _), _| valid_ids.contains(record_id));
487 self.role_cursors
488 .retain(|(record_id, _), _| valid_ids.contains(record_id));
489 self.negative_backend.prune_cursors(&valid_ids);
490 }
491
492 fn rebuild_chunk_index(&mut self) {
493 self.chunk_index.clear();
494 for record in self.records.values() {
495 self.chunk_index
496 .insert(record.id.clone(), record.id.clone());
497 }
498 }
499
500 fn rebuild_source_index(&mut self) -> Result<(), SamplerError> {
501 self.source_record_indices.clear();
502 let mut label_cache: HashMap<RecordId, SplitLabel> = HashMap::new();
503 let allowed = self.allowed_target_splits();
504 let allowed_set: HashSet<SplitLabel> = allowed.into_iter().collect();
505 for (idx, record) in self.records.values().enumerate() {
506 let label = if let Some(label) = label_cache.get(&record.id) {
507 *label
508 } else {
509 let label = match self.split_store.label_for(&record.id) {
510 Some(label) => label,
511 None => self.split_store.ensure(record.id.clone())?,
512 };
513 label_cache.insert(record.id.clone(), label);
514 label
515 };
516 if !allowed_set.contains(&label) {
517 continue;
518 }
519 self.source_record_indices
520 .entry(record.source.clone())
521 .or_default()
522 .push(idx);
523 }
524
525 let shuffle_seed = self.epoch_seed();
526 for indices in self.source_record_indices.values_mut() {
527 indices.sort_by_key(|idx| {
528 self.records
529 .get_index(*idx)
530 .map(|(_, record)| stable_hash_str(shuffle_seed, &record.id))
531 .unwrap_or(0)
532 });
533 }
534
535 self.source_order = self.source_record_indices.keys().cloned().collect();
536 self.source_order.sort();
537 self.refresh_source_wrapped();
538
539 self.source_record_cursors
540 .retain(|source, _| self.source_record_indices.contains_key(source));
541 if self.source_state_loaded {
542 if self.source_order.is_empty() {
543 self.source_cycle_idx = 0;
544 }
545 self.source_state_dirty = self.source_order.len() > 1;
546 }
547 Ok(())
548 }
549
550 fn refresh_source_wrapped(&mut self) {
551 self.source_wrapped.clear();
552 for source in &self.source_order {
553 let len = self
554 .source_record_indices
555 .get(source)
556 .map(|items| items.len())
557 .unwrap_or(0);
558 if len == 0 {
559 self.source_wrapped.insert(source.clone(), false);
560 continue;
561 }
562 let cursor = self.source_record_cursors.get(source).copied().unwrap_or(0);
563 let wrapped = cursor > 0 && cursor % len == 0;
564 self.source_wrapped.insert(source.clone(), wrapped);
565 }
566 }
567
568 fn shuffled_source_cycle(&self, cycle: u64) -> Vec<SourceId> {
569 let mut sources = self.source_order.clone();
570 let seed = self.epoch_seed() ^ cycle;
571 sources.sort_by_key(|source| stable_hash_str(seed, source));
572 sources
573 }
574
575 fn ensure_source_state(&mut self) -> Result<(), SamplerError> {
576 if self.source_state_loaded {
577 return Ok(());
578 }
579 let persisted = self.split_store.load_sampler_state()?;
580 self.source_cycle_idx = persisted
581 .as_ref()
582 .map(|state| state.source_cycle_idx as usize)
583 .unwrap_or(0);
584 if let Some(state) = persisted {
585 for (source, cursor) in state.source_record_cursors {
586 if self.source_record_indices.contains_key(&source) {
587 self.source_record_cursors.insert(source, cursor as usize);
588 }
589 }
590 self.source_epoch = state.source_epoch;
591 self.ingestion.set_source_epoch(state.source_epoch);
592 self.rng = DeterministicRng::from_state(state.rng_state);
593 self.triplet_recipe_rr_idx = state.triplet_recipe_rr_idx as usize;
594 self.text_recipe_rr_idx = state.text_recipe_rr_idx as usize;
595 }
596 self.refresh_source_wrapped();
597 self.source_state_loaded = true;
598 self.source_state_dirty = true;
599 Ok(())
600 }
601
602 fn persist_source_state(&mut self, save_to: Option<&Path>) -> Result<(), SamplerError> {
603 if !self.source_state_loaded {
604 return Ok(());
605 }
606 let state = PersistedSamplerState {
607 source_cycle_idx: self.source_cycle_idx as u64,
608 source_record_cursors: self
609 .source_record_cursors
610 .iter()
611 .map(|(source, cursor)| (source.clone(), *cursor as u64))
612 .collect(),
613 source_epoch: self.source_epoch,
614 rng_state: self.rng.state(),
615 triplet_recipe_rr_idx: self.triplet_recipe_rr_idx as u64,
616 text_recipe_rr_idx: self.text_recipe_rr_idx as u64,
617 source_stream_cursors: self.ingestion.snapshot_cursors(),
618 };
619 self.split_store.save_sampler_state(&state, save_to)?;
620 self.source_state_dirty = false;
621 Ok(())
622 }
623
624 fn rebuild_derived_text_recipes(&mut self) {
625 if self.using_config_text_recipes {
626 return;
627 }
628 if self.triplet_recipes.is_empty() {
629 self.text_recipes.clear();
630 } else {
631 self.text_recipes = Self::build_derived_text_recipes(&self.triplet_recipes);
632 }
633 }
634
635 fn extend_text_recipes_unique(&mut self, recipes: &[TextRecipe]) {
636 for recipe in recipes {
637 if self
638 .text_recipes
639 .iter()
640 .any(|existing| existing.name == recipe.name)
641 {
642 continue;
643 }
644 self.text_recipes.push(recipe.clone());
645 }
646 }
647
648 fn configured_triplet_recipes_for_source<'a>(&'a self, source: &str) -> &'a [TripletRecipe] {
649 if self.using_config_triplet_recipes {
650 return &self.triplet_recipes;
651 }
652 self.source_triplet_recipes
653 .get(source)
654 .map(|recipes| recipes.as_slice())
655 .unwrap_or(&[])
656 }
657
658 fn contains_auto_chunk_pair_recipe(recipes: &[TripletRecipe]) -> bool {
660 recipes
661 .iter()
662 .any(|recipe| recipe.name.as_ref() == AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME)
663 }
664
665 fn source_supports_chunk_pair_recipe(&self, source: &str) -> bool {
666 if self.config.chunking.max_window_tokens == 0 {
667 return false;
668 }
669 self.sources_with_long_sections.contains(source)
670 }
671
672 fn should_auto_inject_chunk_pair_recipe(
678 &self,
679 source: &str,
680 recipes: &[TripletRecipe],
681 ) -> bool {
682 self.source_supports_chunk_pair_recipe(source)
683 && !Self::contains_auto_chunk_pair_recipe(recipes)
684 }
685
686 fn source_chunk_pair_recipe() -> TripletRecipe {
697 TripletRecipe {
698 name: Cow::Borrowed(AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME),
699 anchor: Selector::Role(SectionRole::Context),
700 positive_selector: Selector::Role(SectionRole::Context),
701 negative_selector: Selector::Role(SectionRole::Context),
702 negative_strategy: NegativeStrategy::WrongArticle,
703 weight: 1.0,
704 instruction: None,
705 allow_same_anchor_positive: false,
706 }
707 }
708
709 fn resolve_source_triplet_plan(&self, source: &str) -> (Vec<TripletRecipe>, bool) {
721 let mut recipes = self.configured_triplet_recipes_for_source(source).to_vec();
722 let mut auto_injected = false;
723 if self.should_auto_inject_chunk_pair_recipe(source, &recipes) {
724 recipes.push(Self::source_chunk_pair_recipe());
725 auto_injected = true;
726 }
727 (recipes, auto_injected)
728 }
729
730 #[cfg(test)]
731 fn triplet_recipes_for_source(&self, source: &str) -> Vec<TripletRecipe> {
732 self.resolve_source_triplet_plan(source).0
733 }
734
735 fn triplet_recipe_count_for_source(&self, source: &str) -> usize {
736 let (recipes, _auto_injected) = self.resolve_source_triplet_plan(source);
737 recipes.len()
738 }
739
740 fn text_recipes_for_source<'a>(&'a self, source: &str) -> &'a [TextRecipe] {
741 if self.using_config_text_recipes || self.using_config_triplet_recipes {
742 return &self.text_recipes;
743 }
744 self.source_text_recipes
745 .get(source)
746 .map(|recipes| recipes.as_slice())
747 .unwrap_or(&[])
748 }
749
750 fn recipe_order_weighted_shuffled(
761 &mut self,
762 weights: &[f32],
763 rng: &mut DeterministicRng,
764 ) -> Vec<usize> {
765 weighted_recipe_order(weights, rng)
766 }
767
768 fn recipe_order_weighted_cycled(
774 &mut self,
775 weights: &[f32],
776 rr_idx: usize,
777 rng: &mut DeterministicRng,
778 ) -> Vec<usize> {
779 let base = self.recipe_order_weighted_shuffled(weights, rng);
780 if base.is_empty() {
781 return base;
782 }
783 let start = rr_idx % base.len();
784 let mut order = Vec::with_capacity(base.len());
785 order.extend_from_slice(&base[start..]);
786 order.extend_from_slice(&base[..start]);
787 order
788 }
789
790 fn text_recipe_order_weighted_shuffled(
794 &mut self,
795 weights: &[f32],
796 rng: &mut DeterministicRng,
797 ) -> Vec<usize> {
798 weighted_recipe_order(weights, rng)
799 }
800
801 fn text_recipe_order_weighted_cycled(
802 &mut self,
803 weights: &[f32],
804 rr_idx: usize,
805 rng: &mut DeterministicRng,
806 ) -> Vec<usize> {
807 let base = self.text_recipe_order_weighted_shuffled(weights, rng);
808 if base.is_empty() {
809 return base;
810 }
811 let start = rr_idx % base.len();
812 let mut order = Vec::with_capacity(base.len());
813 order.extend_from_slice(&base[start..]);
814 order.extend_from_slice(&base[..start]);
815 order
816 }
817
818 fn allowed_target_splits(&self) -> Vec<SplitLabel> {
819 self.config.allowed_splits.clone()
820 }
821
822 fn ensure_split_allowed(&self, split: SplitLabel) -> Result<(), SamplerError> {
823 let allowed = self.allowed_target_splits();
824 if allowed.contains(&split) {
825 return Ok(());
826 }
827 Err(SamplerError::Configuration(format!(
828 "requested split {:?} is not in allowed_splits {:?}",
829 split, allowed
830 )))
831 }
832
833 fn ensure_split_has_records(&mut self, target_split: SplitLabel) -> Result<(), SamplerError> {
834 let records_by_split = self.records_by_split()?;
835 if records_by_split
836 .get(&target_split)
837 .map(|records| !records.is_empty())
838 .unwrap_or(false)
839 {
840 return Ok(());
841 }
842 Err(SamplerError::Exhausted(
843 "no records available for target split".into(),
844 ))
845 }
846
847 fn records_by_split(
848 &self,
849 ) -> Result<HashMap<SplitLabel, Vec<(RecordId, SourceId)>>, SamplerError> {
850 let mut map: HashMap<SplitLabel, Vec<(RecordId, SourceId)>> = HashMap::new();
851 let mut label_cache: HashMap<RecordId, SplitLabel> = HashMap::new();
852 for (chunk_id, record_id) in &self.chunk_index {
853 let Some(record) = self.records.get(record_id) else {
854 continue;
855 };
856 let label = if let Some(label) = label_cache.get(record_id) {
857 *label
858 } else {
859 let label = match self.split_store.label_for(record_id) {
860 Some(label) => label,
861 None => self.split_store.ensure(record_id.clone())?,
862 };
863 label_cache.insert(record_id.clone(), label);
864 label
865 };
866 map.entry(label)
867 .or_default()
868 .push((chunk_id.clone(), record.source.clone()));
869 }
870 Ok(map)
871 }
872
873 fn choose_anchor_record(
874 &mut self,
875 source: Option<&str>,
876 split: SplitLabel,
877 ) -> Option<Arc<DataRecord>> {
878 if let Some(source) = source {
879 let indices = self.source_record_indices.get(source)?;
880 if indices.is_empty() {
881 return None;
882 }
883 let mut cursor = *self.source_record_cursors.get(source).unwrap_or(&0);
884 let cycle = cursor / indices.len();
885 let offset_seed = self.epoch_seed() ^ (cycle as u64);
886 let offset = (stable_hash_str(offset_seed, source) as usize) % indices.len();
887 let mut wrapped = false;
888 let mut selected: Option<Arc<DataRecord>> = None;
889 for _ in 0..indices.len() {
890 let pos = (cursor % indices.len()).saturating_add(offset) % indices.len();
891 let idx = indices[pos];
892 cursor = cursor.saturating_add(1);
893 if cursor.is_multiple_of(indices.len()) {
894 wrapped = true;
895 }
896 if let Some((_, record)) = self.records.get_index(idx) {
897 if self.split_store.label_for(&record.id) != Some(split) {
898 continue;
899 }
900 selected = Some(Arc::clone(record));
901 break;
902 }
903 }
904 self.source_record_cursors
905 .insert(source.to_string(), cursor);
906 if wrapped {
907 self.mark_source_wrapped(source);
908 }
909 return selected;
910 }
911 while let Some(chunk_id) = self.epoch_tracker.next_record(split) {
912 if let Some(record_id) = self.chunk_index.get(&chunk_id)
913 && let Some(record) = self.records.get(record_id)
914 {
915 return Some(Arc::clone(record));
916 }
917 }
918 None
919 }
920
921 fn save_sampler_state(&mut self, save_to: Option<&Path>) -> Result<(), SamplerError> {
922 if self.epoch_tracker.is_enabled() {
923 self.epoch_tracker.persist()?;
924 }
925 self.persist_source_state(save_to)?;
926 Ok(())
927 }
928
929 fn mark_source_wrapped(&mut self, source: &str) {
930 self.source_wrapped.insert(source.to_string(), true);
931 if self.source_order.is_empty() {
932 return;
933 }
934 let all_wrapped = self
935 .source_order
936 .iter()
937 .all(|name| self.source_wrapped.get(name).copied().unwrap_or(false));
938 if all_wrapped {
939 self.advance_source_epoch();
940 }
941 }
942
943 fn advance_source_epoch(&mut self) {
944 self.emitted_text_hashes.clear();
950 self.source_epoch = self.source_epoch.saturating_add(1);
953 self.ingestion.set_source_epoch(self.source_epoch);
954 self.source_record_cursors.clear();
958 self.source_cycle_idx = 0;
959 for source in &self.source_order {
962 self.source_wrapped.insert(source.clone(), false);
963 }
964 let _ = self.rebuild_source_index();
965 self.source_state_dirty = self.source_order.len() > 1;
966 }
967
968 fn select_temporal_neighbor(
969 &'_ self,
970 record: &DataRecord,
971 offset_days: i32,
972 ) -> Option<Arc<DataRecord>> {
973 let target = record.created_at + Duration::days(offset_days.into());
974 let key = record.taxonomy.first().cloned();
975 let record_split = self.split_store.label_for(&record.id)?;
976 self.records
977 .values()
978 .filter(|candidate| {
979 candidate.id != record.id
980 && self
981 .split_store
982 .label_for(&candidate.id)
983 .map(|label| label == record_split)
984 .unwrap_or(false)
985 && (candidate.source == record.source
986 || key
987 .as_ref()
988 .zip(candidate.taxonomy.first())
989 .map(|(a, b)| a == b)
990 .unwrap_or(false))
991 })
992 .min_by_key(|candidate| (candidate.created_at - target).num_seconds().abs())
993 .cloned()
994 }
995
996 fn select_negative_record(
997 &self,
998 anchor_record: &DataRecord,
999 strategy: &NegativeStrategy,
1000 anchor_query_text: Option<&str>,
1001 rng: &mut dyn rand::RngCore,
1002 ) -> Option<(Arc<DataRecord>, bool)> {
1003 let anchor_split = self.split_store.label_for(&anchor_record.id)?;
1004
1005 let in_anchor_split = |candidate: &DataRecord| {
1006 self.split_store
1007 .label_for(&candidate.id)
1008 .map(|label| label == anchor_split)
1009 .unwrap_or(false)
1010 };
1011
1012 match strategy {
1013 NegativeStrategy::WrongArticle => {
1014 let anchor_date =
1015 taxonomy_value(anchor_record, META_FIELD_DATE).map(|d| d.to_string());
1016 let mut same_date: Vec<Arc<DataRecord>> = self
1017 .records
1018 .values()
1019 .filter(|candidate| {
1020 candidate.source == anchor_record.source
1021 && candidate.id != anchor_record.id
1022 && in_anchor_split(candidate)
1023 })
1024 .filter(|candidate| {
1025 anchor_date
1026 .as_deref()
1027 .zip(taxonomy_value(candidate, META_FIELD_DATE))
1028 .map(|(a, b)| a == b)
1029 .unwrap_or(false)
1030 })
1031 .cloned()
1032 .collect();
1033 if same_date.is_empty() {
1034 same_date = self
1035 .records
1036 .values()
1037 .filter(|candidate| {
1038 candidate.source == anchor_record.source
1039 && candidate.id != anchor_record.id
1040 && in_anchor_split(candidate)
1041 })
1042 .cloned()
1043 .collect();
1044 }
1045 if !same_date.is_empty() {
1046 return self.negative_backend.choose_negative(
1047 anchor_record,
1048 anchor_split,
1049 same_date,
1050 false,
1051 anchor_query_text,
1052 rng,
1053 );
1054 }
1055 let pool = self
1056 .records
1057 .values()
1058 .filter(|candidate| {
1059 candidate.id != anchor_record.id && in_anchor_split(candidate)
1060 })
1061 .cloned()
1062 .collect::<Vec<_>>();
1063 self.negative_backend.choose_negative(
1064 anchor_record,
1065 anchor_split,
1066 pool,
1067 true,
1068 anchor_query_text,
1069 rng,
1070 )
1071 }
1072 NegativeStrategy::WrongPublicationDate => {
1073 let anchor_date =
1074 taxonomy_value(anchor_record, META_FIELD_DATE).map(|d| d.to_string());
1075 let pool: Vec<Arc<DataRecord>> = self
1076 .records
1077 .values()
1078 .filter(|candidate| {
1079 candidate.source == anchor_record.source
1080 && candidate.id != anchor_record.id
1081 && in_anchor_split(candidate)
1082 })
1083 .filter(|candidate| {
1084 match (
1085 anchor_date.as_deref(),
1086 taxonomy_value(candidate, META_FIELD_DATE),
1087 ) {
1088 (Some(anchor), Some(candidate_date)) => anchor != candidate_date,
1089 (Some(_), None) => true,
1090 (None, Some(_)) => true,
1091 (None, None) => false,
1092 }
1093 })
1094 .cloned()
1095 .collect();
1096 if pool.is_empty() {
1097 let fallback_pool = self
1100 .records
1101 .values()
1102 .filter(|candidate| {
1103 candidate.id != anchor_record.id && in_anchor_split(candidate)
1104 })
1105 .cloned()
1106 .collect::<Vec<_>>();
1107
1108 return self.negative_backend.choose_negative(
1109 anchor_record,
1110 anchor_split,
1111 fallback_pool,
1112 true,
1113 anchor_query_text,
1114 rng,
1115 );
1116 }
1117
1118 self.negative_backend.choose_negative(
1119 anchor_record,
1120 anchor_split,
1121 pool,
1122 false,
1123 anchor_query_text,
1124 rng,
1125 )
1126 }
1127 NegativeStrategy::QuestionAnswerMismatch => {
1128 let pool: Vec<Arc<DataRecord>> = self
1129 .records
1130 .values()
1131 .filter(|candidate| {
1132 candidate.source == anchor_record.source
1133 && candidate.id != anchor_record.id
1134 && in_anchor_split(candidate)
1135 })
1136 .cloned()
1137 .collect();
1138 if pool.is_empty() {
1139 let fallback_pool = self
1142 .records
1143 .values()
1144 .filter(|candidate| {
1145 candidate.id != anchor_record.id && in_anchor_split(candidate)
1146 })
1147 .cloned()
1148 .collect::<Vec<_>>();
1149
1150 return self.negative_backend.choose_negative(
1151 anchor_record,
1152 anchor_split,
1153 fallback_pool,
1154 true,
1155 anchor_query_text,
1156 rng,
1157 );
1158 }
1159
1160 self.negative_backend.choose_negative(
1161 anchor_record,
1162 anchor_split,
1163 pool,
1164 false,
1165 anchor_query_text,
1166 rng,
1167 )
1168 }
1169 }
1170 }
1171
1172 fn is_auto_chunk_pair_recipe(recipe: &TripletRecipe) -> bool {
1174 recipe.name.as_ref() == AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME
1175 }
1176
1177 fn select_anchor_positive_pair(
1181 &mut self,
1182 record: &DataRecord,
1183 anchor_selector: &Selector,
1184 positive_selector: &Selector,
1185 enforce_window_pair: bool,
1186 ) -> Option<(RecordChunk, RecordChunk)> {
1187 let mut anchor_chunk = self.select_chunk(record, anchor_selector)?;
1188 let mut positive_chunk = self.select_chunk(record, positive_selector)?;
1189 if anchor_selector == positive_selector {
1190 let mut retries = 0usize;
1191 while !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair)
1192 && retries < SAME_SELECTOR_PAIR_RETRY_LIMIT
1193 {
1194 let Some(redraw_anchor) = self.select_chunk(record, anchor_selector) else {
1195 break;
1196 };
1197 let Some(redraw_positive) = self.select_chunk(record, positive_selector) else {
1198 break;
1199 };
1200 anchor_chunk = redraw_anchor;
1201 positive_chunk = redraw_positive;
1202 retries += 1;
1203 }
1204 if !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair) {
1205 return None;
1206 }
1207 }
1208 Some((anchor_chunk, positive_chunk))
1209 }
1210
1211 fn select_distinct_window_pair_for_auto_recipe(
1221 &mut self,
1222 recipe: &TripletRecipe,
1223 record: &DataRecord,
1224 ) -> Option<(RecordChunk, RecordChunk)> {
1225 if recipe.anchor != recipe.positive_selector {
1226 return None;
1227 }
1228 self.select_anchor_positive_pair(record, &recipe.anchor, &recipe.positive_selector, true)
1229 }
1230
1231 fn record_has_at_least_two_window_chunks_for_selector(
1236 &self,
1237 record: &DataRecord,
1238 selector: &Selector,
1239 ) -> bool {
1240 let section_indices: Vec<usize> = match selector {
1241 Selector::Role(role) => record
1242 .sections
1243 .iter()
1244 .enumerate()
1245 .filter(|(_, section)| roles_match(role, §ion.role))
1246 .map(|(idx, _)| idx)
1247 .collect(),
1248 Selector::Paragraph(idx) => {
1249 if *idx < record.sections.len() {
1250 vec![*idx]
1251 } else {
1252 Vec::new()
1253 }
1254 }
1255 Selector::Random => (0..record.sections.len()).collect(),
1256 Selector::TemporalOffset(_) => return false,
1257 };
1258
1259 let mut window_count = 0usize;
1260 for section_idx in section_indices {
1261 let Some(section) = record.sections.get(section_idx) else {
1262 continue;
1263 };
1264 let chunks = self.materialize_chunks(record, section_idx, section);
1265 window_count += chunks
1266 .iter()
1267 .filter(|chunk| matches!(chunk.view, ChunkView::Window { .. }))
1268 .count();
1269 if window_count >= 2 {
1270 return true;
1271 }
1272 }
1273 false
1274 }
1275
1276 fn build_triplet_with_selector_pair_policy(
1278 &mut self,
1279 recipe: &TripletRecipe,
1280 record: &DataRecord,
1281 enforce_window_pair: bool,
1282 rng: &mut DeterministicRng,
1283 ) -> Option<SampleTriplet> {
1284 let (mut anchor_chunk, mut positive_chunk) = self.select_anchor_positive_pair(
1285 record,
1286 &recipe.anchor,
1287 &recipe.positive_selector,
1288 enforce_window_pair,
1289 )?;
1290 let anchor_raw_text = anchor_chunk.text.clone();
1291 self.decorate_chunk(record, &mut anchor_chunk, rng);
1292 self.decorate_chunk(record, &mut positive_chunk, rng);
1293 self.finalize_triplet_with_negative(
1296 recipe,
1297 record,
1298 anchor_chunk,
1299 positive_chunk,
1300 &anchor_raw_text,
1301 rng,
1302 )
1303 }
1304
1305 fn make_auto_chunk_pair_triplet_with_anchor(
1315 &mut self,
1316 recipe: &TripletRecipe,
1317 record: &DataRecord,
1318 rng: &mut DeterministicRng,
1319 ) -> Option<SampleTriplet> {
1320 if !self.record_has_at_least_two_window_chunks_for_selector(record, &recipe.anchor) {
1321 return None;
1322 }
1323 let (mut anchor_chunk, mut positive_chunk) =
1324 self.select_distinct_window_pair_for_auto_recipe(recipe, record)?;
1325 let anchor_raw_text = anchor_chunk.text.clone();
1326 self.decorate_chunk(record, &mut anchor_chunk, rng);
1327 self.decorate_chunk(record, &mut positive_chunk, rng);
1328 self.finalize_triplet_with_negative(
1329 recipe,
1330 record,
1331 anchor_chunk,
1332 positive_chunk,
1333 &anchor_raw_text,
1334 rng,
1335 )
1336 }
1337
1338 fn make_standard_triplet_with_anchor(
1339 &mut self,
1340 recipe: &TripletRecipe,
1341 record: &DataRecord,
1342 rng: &mut DeterministicRng,
1343 ) -> Option<SampleTriplet> {
1344 self.build_triplet_with_selector_pair_policy(recipe, record, false, rng)
1345 }
1346
1347 fn finalize_triplet_with_negative(
1366 &mut self,
1367 recipe: &TripletRecipe,
1368 record: &DataRecord,
1369 anchor_chunk: RecordChunk,
1370 positive_chunk: RecordChunk,
1371 anchor_raw_text: &str,
1372 rng: &mut DeterministicRng,
1373 ) -> Option<SampleTriplet> {
1374 let (negative_record, fallback_used) = self.select_negative_record(
1375 record,
1376 &recipe.negative_strategy,
1377 Some(anchor_raw_text),
1378 rng,
1379 )?;
1380 let mut negative_chunk = self.select_chunk(&negative_record, &recipe.negative_selector)?;
1381 self.decorate_chunk(&negative_record, &mut negative_chunk, rng);
1382
1383 let (anchor_chunk, positive_chunk) = if rng.next_u64() & ANCHOR_POSITIVE_SWAP_MASK == 0 {
1385 (positive_chunk, anchor_chunk)
1386 } else {
1387 (anchor_chunk, positive_chunk)
1388 };
1389
1390 if (!recipe.allow_same_anchor_positive && anchor_chunk.text == positive_chunk.text)
1401 || negative_chunk.text == positive_chunk.text
1402 || negative_chunk.text == anchor_chunk.text
1403 {
1404 return None;
1405 }
1406
1407 let chunk_weight =
1408 self.triplet_chunk_weight(recipe, &anchor_chunk, &positive_chunk, &negative_chunk);
1409 let weight = recipe.weight * chunk_weight;
1410 let recipe_name = if fallback_used {
1411 format!("{}_fallback_same_split", recipe.name)
1412 } else {
1413 recipe.name.to_string()
1414 };
1415 Some(SampleTriplet {
1416 recipe: recipe_name,
1417 anchor: anchor_chunk,
1418 positive: positive_chunk,
1419 negative: negative_chunk,
1420 weight,
1421 instruction: recipe.instruction.as_ref().map(|s| s.to_string()),
1422 })
1423 }
1424
1425 fn make_triplet_with_anchor(
1426 &mut self,
1427 recipe: &TripletRecipe,
1428 record: &DataRecord,
1429 rng: &mut DeterministicRng,
1430 ) -> Option<SampleTriplet> {
1431 if Self::is_auto_chunk_pair_recipe(recipe) {
1432 return self.make_auto_chunk_pair_triplet_with_anchor(recipe, record, rng);
1433 }
1434 self.make_standard_triplet_with_anchor(recipe, record, rng)
1435 }
1436
1437 fn make_text_sample_for_split(
1438 &mut self,
1439 recipe: &TextRecipe,
1440 source: Option<&str>,
1441 split: SplitLabel,
1442 rng: &mut DeterministicRng,
1443 ) -> Option<TextSample> {
1444 let record = self.choose_anchor_record(source, split)?;
1445 let mut chunk = self.select_chunk(&record, &recipe.selector)?;
1446 self.decorate_chunk(&record, &mut chunk, rng);
1447 let weight = recipe.weight * self.chunk_weight(&chunk);
1448 Some(TextSample {
1449 recipe: recipe.name.to_string(),
1450 chunk,
1451 weight,
1452 instruction: recipe.instruction.as_ref().map(|s| s.to_string()),
1453 })
1454 }
1455
1456 fn chunk_weight(&self, chunk: &RecordChunk) -> f32 {
1457 chunk_weight(&self.config.chunking, chunk)
1458 }
1459
1460 fn triplet_chunk_weight(
1461 &self,
1462 recipe: &TripletRecipe,
1463 anchor: &RecordChunk,
1464 positive: &RecordChunk,
1465 negative: &RecordChunk,
1466 ) -> f32 {
1467 let floor = self.config.chunking.chunk_weight_floor;
1468 let negative_weight = negative.quality.trust.clamp(0.0, 1.0).max(floor);
1469 if Self::is_auto_chunk_pair_recipe(recipe) {
1470 let pair_trust = ((anchor.quality.trust.clamp(0.0, 1.0)
1473 + positive.quality.trust.clamp(0.0, 1.0))
1474 / 2.0)
1475 .clamp(0.0, 1.0);
1476 let pair_weight = (chunk_proximity_score(anchor, positive) * pair_trust).max(floor);
1477 return (pair_weight + pair_weight + negative_weight) / 3.0;
1479 }
1480 let pair_proximity = chunk_proximity_score(anchor, positive);
1483 let anchor_weight = (self.chunk_weight(anchor) * pair_proximity).max(floor);
1484 let positive_weight = (self.chunk_weight(positive) * pair_proximity).max(floor);
1485 (anchor_weight + positive_weight + negative_weight) / 3.0
1486 }
1487
1488 fn decorate_chunk(
1489 &mut self,
1490 record: &DataRecord,
1491 chunk: &mut RecordChunk,
1492 rng: &mut DeterministicRng,
1493 ) {
1494 chunk.kvp_meta = record
1495 .meta_prefix
1496 .as_ref()
1497 .map(|s| s.all_metadata())
1498 .unwrap_or_default();
1499 if let Some(spec) = record.meta_prefix.as_ref()
1500 && let Some(prefix) = spec.sample(rng)
1501 {
1502 let body_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&chunk.text);
1503 let prefix_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&prefix);
1504 let total_tokens = prefix_tokens.len() + body_tokens.len();
1505 let max_window = self.config.chunking.max_window_tokens;
1506 if max_window > 0 && total_tokens > max_window {
1507 if prefix_tokens.len() >= max_window {
1508 chunk.text = prefix_tokens
1509 .into_iter()
1510 .take(max_window)
1511 .collect::<Vec<_>>()
1512 .join(" ");
1513 chunk.tokens_estimate = max_window;
1514 } else {
1515 let remaining = max_window - prefix_tokens.len();
1516 let trimmed_body: Vec<&str> = body_tokens.into_iter().take(remaining).collect();
1517 chunk.text =
1518 format!("{}{}{}", prefix, platform_newline(), trimmed_body.join(" "));
1519 chunk.tokens_estimate = max_window;
1520 }
1521 } else {
1522 chunk.text = format!("{}{}{}", prefix, platform_newline(), chunk.text);
1523 chunk.tokens_estimate = total_tokens;
1524 }
1525 }
1526 }
1527
1528 fn select_chunk_parallel(
1532 &self,
1533 record: &DataRecord,
1534 selector: &Selector,
1535 rng: &mut DeterministicRng,
1536 ) -> Option<RecordChunk> {
1537 match selector {
1538 Selector::Role(role) => self.select_role_parallel(record, role, rng),
1539 Selector::Paragraph(idx) => record.sections.get(*idx).and_then(|section| {
1540 let pool = self.materialize_chunks(record, *idx, section);
1541 if pool.is_empty() {
1542 return None;
1543 }
1544 let i = rng.random_range(0..pool.len());
1545 pool.into_iter().nth(i)
1546 }),
1547 Selector::TemporalOffset(offset) => self
1548 .select_temporal_neighbor(record, *offset)
1549 .and_then(|neighbor| {
1550 self.select_role_parallel(&neighbor, &SectionRole::Context, rng)
1551 }),
1552 Selector::Random => {
1553 if record.sections.is_empty() {
1554 return None;
1555 }
1556 let idx = rng.random_range(0..record.sections.len());
1557 record.sections.get(idx).and_then(|section| {
1558 let pool = self.materialize_chunks(record, idx, section);
1559 if pool.is_empty() {
1560 return None;
1561 }
1562 let i = rng.random_range(0..pool.len());
1563 pool.into_iter().nth(i)
1564 })
1565 }
1566 }
1567 }
1568
1569 fn select_role_parallel(
1571 &self,
1572 record: &DataRecord,
1573 role: &SectionRole,
1574 rng: &mut DeterministicRng,
1575 ) -> Option<RecordChunk> {
1576 let indices: Vec<usize> = record
1577 .sections
1578 .iter()
1579 .enumerate()
1580 .filter(|(_, s)| roles_match(role, &s.role))
1581 .map(|(i, _)| i)
1582 .collect();
1583 if indices.is_empty() {
1584 return None;
1585 }
1586 let start = rng.random_range(0..indices.len());
1587 for offset in 0..indices.len() {
1588 let section_idx = indices[(start + offset) % indices.len()];
1589 let section = &record.sections[section_idx];
1590 let pool = self.materialize_chunks(record, section_idx, section);
1591 if !pool.is_empty() {
1592 let i = rng.random_range(0..pool.len());
1593 return pool.into_iter().nth(i);
1594 }
1595 }
1596 None
1597 }
1598
1599 fn decorate_chunk_parallel(
1601 &self,
1602 record: &DataRecord,
1603 chunk: &mut RecordChunk,
1604 rng: &mut DeterministicRng,
1605 ) {
1606 chunk.kvp_meta = record
1607 .meta_prefix
1608 .as_ref()
1609 .map(|s| s.all_metadata())
1610 .unwrap_or_default();
1611 if let Some(spec) = record.meta_prefix.as_ref()
1612 && let Some(prefix) = spec.sample(rng)
1613 {
1614 let body_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&chunk.text);
1615 let prefix_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&prefix);
1616 let total_tokens = prefix_tokens.len() + body_tokens.len();
1617 let max_window = self.config.chunking.max_window_tokens;
1618 if max_window > 0 && total_tokens > max_window {
1619 if prefix_tokens.len() >= max_window {
1620 chunk.text = prefix_tokens
1621 .into_iter()
1622 .take(max_window)
1623 .collect::<Vec<_>>()
1624 .join(" ");
1625 chunk.tokens_estimate = max_window;
1626 } else {
1627 let remaining = max_window - prefix_tokens.len();
1628 let trimmed_body: Vec<&str> = body_tokens.into_iter().take(remaining).collect();
1629 chunk.text =
1630 format!("{}{}{}", prefix, platform_newline(), trimmed_body.join(" "));
1631 chunk.tokens_estimate = max_window;
1632 }
1633 } else {
1634 chunk.text = format!("{}{}{}", prefix, platform_newline(), chunk.text);
1635 chunk.tokens_estimate = total_tokens;
1636 }
1637 }
1638 }
1639
1640 fn select_anchor_positive_parallel(
1642 &self,
1643 record: &DataRecord,
1644 anchor_selector: &Selector,
1645 positive_selector: &Selector,
1646 enforce_window_pair: bool,
1647 rng: &mut DeterministicRng,
1648 ) -> Option<(RecordChunk, RecordChunk)> {
1649 let anchor_chunk = self.select_chunk_parallel(record, anchor_selector, rng)?;
1650 let mut positive_chunk = self.select_chunk_parallel(record, positive_selector, rng)?;
1651 if anchor_selector == positive_selector {
1652 let mut retries = 0usize;
1653 while !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair)
1654 && retries < SAME_SELECTOR_PAIR_RETRY_LIMIT
1655 {
1656 positive_chunk = self.select_chunk_parallel(record, positive_selector, rng)?;
1657 retries += 1;
1658 }
1659 if !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair) {
1660 return None;
1661 }
1662 }
1663 Some((anchor_chunk, positive_chunk))
1665 }
1666
1667 fn select_anchor_positive_for_recipe(
1670 &self,
1671 recipe: &TripletRecipe,
1672 anchor_record: &DataRecord,
1673 rng: &mut DeterministicRng,
1674 ) -> Option<(RecordChunk, RecordChunk, String)> {
1675 if Self::is_auto_chunk_pair_recipe(recipe) {
1676 if !self
1677 .record_has_at_least_two_window_chunks_for_selector(anchor_record, &recipe.anchor)
1678 {
1679 return None;
1680 }
1681 let mut anchor_chunk =
1682 self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
1683 let mut positive_chunk =
1684 self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
1685 let mut tries = 0usize;
1686 while !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, true) {
1687 tries += 1;
1688 if tries >= SAME_SELECTOR_PAIR_RETRY_LIMIT {
1689 return None;
1690 }
1691 anchor_chunk = self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
1692 positive_chunk = self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
1693 }
1694 let anchor_raw_text = anchor_chunk.text.clone();
1695 self.decorate_chunk_parallel(anchor_record, &mut anchor_chunk, rng);
1696 self.decorate_chunk_parallel(anchor_record, &mut positive_chunk, rng);
1697 return Some((anchor_chunk, positive_chunk, anchor_raw_text));
1698 }
1699 let (mut anchor_chunk, mut positive_chunk) = self.select_anchor_positive_parallel(
1700 anchor_record,
1701 &recipe.anchor,
1702 &recipe.positive_selector,
1703 false,
1704 rng,
1705 )?;
1706 let anchor_raw_text = anchor_chunk.text.clone();
1707 self.decorate_chunk_parallel(anchor_record, &mut anchor_chunk, rng);
1708 self.decorate_chunk_parallel(anchor_record, &mut positive_chunk, rng);
1709 Some((anchor_chunk, positive_chunk, anchor_raw_text))
1710 }
1711
1712 fn select_chunk(&mut self, record: &DataRecord, selector: &Selector) -> Option<RecordChunk> {
1713 match selector {
1714 Selector::Role(role) => self.select_by_role(record, role),
1715 Selector::Paragraph(idx) => record.sections.get(*idx).and_then(|section| {
1716 let pool = self.materialize_chunks(record, *idx, section);
1717 self.next_chunk_from_pool(&record.id, *idx, pool)
1718 }),
1719 Selector::TemporalOffset(offset) => self
1720 .select_temporal_neighbor(record, *offset)
1721 .and_then(|neighbor| self.select_by_role(&neighbor, &SectionRole::Context)),
1722 Selector::Random => {
1723 if record.sections.is_empty() {
1724 return None;
1725 }
1726 let idx = self.rng.random_range(0..record.sections.len());
1727 record.sections.get(idx).and_then(|section| {
1728 let pool = self.materialize_chunks(record, idx, section);
1729 self.next_chunk_from_pool(&record.id, idx, pool)
1730 })
1731 }
1732 }
1733 }
1734
1735 fn select_by_role(&mut self, record: &DataRecord, role: &SectionRole) -> Option<RecordChunk> {
1736 let indices: Vec<usize> = record
1737 .sections
1738 .iter()
1739 .enumerate()
1740 .filter(|(_, section)| roles_match(role, §ion.role))
1741 .map(|(idx, _)| idx)
1742 .collect();
1743 if indices.is_empty() {
1744 return None;
1745 }
1746 let key = role_cursor_key(&record.id, role);
1747 let start_offset = self
1748 .role_cursors
1749 .get(&key)
1750 .and_then(|last_idx| indices.iter().position(|idx| idx == last_idx))
1751 .map(|pos| (pos + 1) % indices.len())
1752 .unwrap_or_else(|| {
1753 let seed_key = format!("{}::{}", key.0, key.1);
1757 (stable_hash_str(self.epoch_seed(), &seed_key) as usize) % indices.len()
1758 });
1759 for offset in 0..indices.len() {
1760 let section_idx = indices[(start_offset + offset) % indices.len()];
1761 let section = &record.sections[section_idx];
1762 let pool = self.materialize_chunks(record, section_idx, section);
1763 if let Some(chunk) = self.next_chunk_from_pool(&record.id, section_idx, pool) {
1764 self.role_cursors.insert(key.clone(), section_idx);
1765 return Some(chunk);
1766 }
1767 }
1768 None
1769 }
1770
1771 fn materialize_chunks(
1784 &self,
1785 record: &DataRecord,
1786 section_idx: usize,
1787 section: &RecordSection,
1788 ) -> Vec<RecordChunk> {
1789 self.chunker
1790 .materialize(&self.config.chunking, record, section_idx, section)
1791 }
1792
1793 fn build_derived_text_recipes(recipes: &[TripletRecipe]) -> Vec<TextRecipe> {
1794 let mut derived = Vec::new();
1795 for recipe in recipes {
1796 let base = recipe.name.as_ref();
1797 derived.push(TextRecipe {
1798 name: Cow::Owned(format!("{base}_anchor")),
1799 selector: recipe.anchor.clone(),
1800 weight: recipe.weight,
1801 instruction: None,
1802 });
1803 derived.push(TextRecipe {
1804 name: Cow::Owned(format!("{base}_positive")),
1805 selector: recipe.positive_selector.clone(),
1806 weight: recipe.weight,
1807 instruction: None,
1808 });
1809 derived.push(TextRecipe {
1810 name: Cow::Owned(format!("{base}_negative")),
1811 selector: recipe.negative_selector.clone(),
1812 weight: recipe.weight,
1813 instruction: None,
1814 });
1815 }
1816 derived
1817 }
1818
1819 fn record_has_long_anchor_or_context_section(&self, record: &DataRecord) -> bool {
1822 let window = self.config.chunking.max_window_tokens;
1823 if window == 0 {
1824 return false;
1825 }
1826 record.sections.iter().any(|section| {
1827 matches!(section.role, SectionRole::Anchor | SectionRole::Context)
1828 && WhitespaceTokenizer.token_count(§ion.text) > window
1829 })
1830 }
1831
1832 fn sync_records_from_cache(&mut self) -> Result<(), SamplerError> {
1833 let mut snapshot = self.ingestion.all_records_snapshot();
1834 snapshot.sort_by(|a, b| a.id.cmp(&b.id));
1835 self.records.clear();
1840 self.sources_with_long_sections.clear();
1844 self.emitted_text_hashes.clear();
1850 self.negative_backend.on_sync_start();
1855 for record in snapshot {
1856 if self.split_store.label_for(&record.id).is_none() {
1857 self.split_store.ensure(record.id.clone())?;
1858 }
1859 if self.record_has_long_anchor_or_context_section(&record) {
1860 self.sources_with_long_sections
1862 .insert(record.source.clone());
1863 }
1864 self.records.insert(record.id.clone(), Arc::new(record));
1865 }
1866 self.prune_cursor_state();
1867 self.rebuild_chunk_index();
1868 self.rebuild_source_index()?;
1869 Ok(())
1870 }
1871
1872 fn ingest_internal_for_split(&mut self, target_split: SplitLabel) -> Result<(), SamplerError> {
1873 if !self.ingestion.has_sources() {
1874 return Ok(());
1875 }
1876 if !self.ingestion_cursors_loaded {
1877 if let Some(state) = self.split_store.load_sampler_state()? {
1878 self.ingestion.load_cursors(&state.source_stream_cursors);
1879 self.ingestion.set_source_epoch(state.source_epoch);
1880 }
1881 self.ingestion_cursors_loaded = true;
1882 }
1883 if self.ingestion.all_caches_empty() {
1884 self.ingestion.refresh_all();
1885 } else {
1886 self.ingestion.advance(self.config.batch_size);
1887 }
1888 let observed = self.ingestion.total_ingest_count();
1889 if observed == self.last_observed_ingest && !self.records.is_empty() {
1890 return Ok(());
1891 }
1892 self.last_observed_ingest = observed;
1893 self.sync_records_from_cache()?;
1894 let max_window_tokens = self.config.chunking.max_window_tokens;
1895 self.negative_backend.on_records_refreshed(
1896 &self.records,
1897 max_window_tokens,
1898 &|id| self.split_store.label_for(id),
1899 self.ingestion.last_refreshed_sources(),
1900 );
1901 self.epoch_tracker.ensure_loaded()?;
1904 let records_by_split = self.records_by_split()?;
1905 self.epoch_tracker
1906 .reconcile(target_split, &records_by_split);
1907 self.ensure_source_state()?;
1908 Ok(())
1909 }
1910
1911 #[cfg(test)]
1912 fn ingest_internal(&mut self, split: SplitLabel) -> Result<(), SamplerError> {
1913 self.ingest_internal_for_split(split)
1914 }
1915
1916 fn ingest_internal_with_weights_for_split(
1917 &mut self,
1918 target_split: SplitLabel,
1919 weights: &HashMap<SourceId, f32>,
1920 ) -> Result<(), SamplerError> {
1921 if !self.ingestion.has_sources() {
1922 return Ok(());
1923 }
1924 if !self.ingestion_cursors_loaded {
1925 if let Some(state) = self.split_store.load_sampler_state()? {
1926 self.ingestion.load_cursors(&state.source_stream_cursors);
1927 self.ingestion.set_source_epoch(state.source_epoch);
1928 }
1929 self.ingestion_cursors_loaded = true;
1930 }
1931 if self.ingestion.all_caches_empty() {
1932 self.ingestion.refresh_all_with_weights(weights)?;
1933 } else {
1934 self.ingestion
1935 .advance_with_weights(self.config.batch_size, weights)?;
1936 }
1937 let observed = self.ingestion.total_ingest_count();
1938 if observed == self.last_observed_ingest && !self.records.is_empty() {
1939 return Ok(());
1940 }
1941 self.last_observed_ingest = observed;
1942 self.sync_records_from_cache()?;
1943 let max_window_tokens = self.config.chunking.max_window_tokens;
1944 self.negative_backend.on_records_refreshed(
1945 &self.records,
1946 max_window_tokens,
1947 &|id| self.split_store.label_for(id),
1948 self.ingestion.last_refreshed_sources(),
1949 );
1950 self.epoch_tracker.ensure_loaded()?;
1951 let records_by_split = self.records_by_split()?;
1952 self.epoch_tracker
1953 .reconcile(target_split, &records_by_split);
1954 self.ensure_source_state()?;
1955 Ok(())
1956 }
1957
1958 fn ingest_with_weights_fallback(
1963 &mut self,
1964 target_split: SplitLabel,
1965 weights: Option<&HashMap<SourceId, f32>>,
1966 ) -> Result<(), SamplerError> {
1967 match weights {
1968 Some(weights)
1969 if !weights.is_empty()
1970 && !weights
1971 .values()
1972 .all(|&w| w == *weights.values().next().unwrap()) =>
1973 {
1974 self.ingest_internal_with_weights_for_split(target_split, weights)?
1975 }
1976 _ => self.ingest_internal_for_split(target_split)?,
1977 }
1978 Ok(())
1979 }
1980
1981 fn force_ingest_refresh_with_weights_for_split(
1982 &mut self,
1983 target_split: SplitLabel,
1984 weights: &HashMap<SourceId, f32>,
1985 ) -> Result<(), SamplerError> {
1986 if !self.ingestion.has_sources() {
1987 return Ok(());
1988 }
1989 if !self.ingestion_cursors_loaded {
1990 if let Some(state) = self.split_store.load_sampler_state()? {
1991 self.ingestion.load_cursors(&state.source_stream_cursors);
1992 self.ingestion.set_source_epoch(state.source_epoch);
1993 }
1994 self.ingestion_cursors_loaded = true;
1995 }
1996 self.ingestion.force_refresh_all_with_weights(weights)?;
1997 self.last_observed_ingest = self.ingestion.total_ingest_count();
1998 self.sync_records_from_cache()?;
1999 let max_window_tokens = self.config.chunking.max_window_tokens;
2000 self.negative_backend.on_records_refreshed(
2001 &self.records,
2002 max_window_tokens,
2003 &|id| self.split_store.label_for(id),
2004 self.ingestion.last_refreshed_sources(),
2005 );
2006 self.epoch_tracker.ensure_loaded()?;
2007 let records_by_split = self.records_by_split()?;
2008 self.epoch_tracker
2009 .reconcile(target_split, &records_by_split);
2010 self.ensure_source_state()?;
2011 Ok(())
2012 }
2013
2014 fn sample_source_triplet_candidate(
2024 &mut self,
2025 source: &str,
2026 target_split: SplitLabel,
2027 recipe_orders: &mut HashMap<RecipeKey, Vec<usize>>,
2028 recipe_positions: &mut HashMap<RecipeKey, usize>,
2029 rng: &mut DeterministicRng,
2030 ) -> (Option<(TripletRecipe, SampleTriplet)>, usize) {
2031 let (recipes, _auto_injected) = self.resolve_source_triplet_plan(source);
2034 if recipes.is_empty() {
2035 return (None, 0);
2036 }
2037 if !recipe_orders.contains_key(source) {
2038 let recipe_weights: Vec<f32> = recipes.iter().map(|r| r.weight).collect();
2039 let order =
2040 self.recipe_order_weighted_cycled(&recipe_weights, self.triplet_recipe_rr_idx, rng);
2041 recipe_orders.insert(source.to_string(), order);
2042 }
2043 let order = recipe_orders
2044 .get(source)
2045 .expect("recipe order missing for source");
2046 let pos = recipe_positions.entry(source.to_string()).or_insert(0);
2047 let Some(anchor) = self.choose_anchor_record(Some(source), target_split) else {
2048 return (None, 0);
2049 };
2050
2051 let mut attempts = 0usize;
2052 for offset in 0..order.len() {
2053 let idx = order[(*pos + offset) % order.len()];
2054 attempts = attempts.saturating_add(1);
2055 let recipe = recipes[idx].clone();
2056 if let Some(sample) = self.make_triplet_with_anchor(&recipe, &anchor, rng) {
2060 *pos = (*pos + offset + 1) % order.len();
2061 return (Some((recipe, sample)), attempts);
2062 }
2063 }
2064
2065 (None, attempts)
2066 }
2067
2068 fn next_pair_batch_inner_with_weights(
2069 &mut self,
2070 target_split: SplitLabel,
2071 weights: Option<&HashMap<SourceId, f32>>,
2072 ) -> Result<SampleBatch, SamplerError> {
2073 if let Some(weights) = weights {
2074 if weights.is_empty()
2075 || weights
2076 .values()
2077 .all(|&w| w == *weights.values().next().unwrap())
2078 {
2079 self.ingest_internal_for_split(target_split)?;
2080 } else {
2081 self.ingest_internal_with_weights_for_split(target_split, weights)?;
2082 }
2083 } else {
2084 self.ingest_internal_for_split(target_split)?;
2085 }
2086 self.ensure_split_has_records(target_split)?;
2087 let sources = self.source_order.clone();
2088 if sources.is_empty() {
2089 if self.triplet_recipes.is_empty() {
2090 return Err(SamplerError::Configuration(
2091 "no triplet recipes available".into(),
2092 ));
2093 }
2094 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2095 let recipe_weights: Vec<f32> = self.triplet_recipes.iter().map(|r| r.weight).collect();
2096 let recipe_order = self.recipe_order_weighted_cycled(
2097 &recipe_weights,
2098 self.triplet_recipe_rr_idx,
2099 &mut rng,
2100 );
2101 let mut pairs = Vec::new();
2102 let mut seen = HashSet::new();
2103 let mut last_recipe_name = None;
2104 let mut recipe_pos = 0usize;
2105 let mut recipe_steps = 0usize;
2106 let attempts = self.config.batch_size * 4 * recipe_order.len().max(1);
2107 for _ in 0..attempts {
2108 if pairs.len() >= self.config.batch_size {
2109 break;
2110 }
2111 let Some(anchor) = self.choose_anchor_record(None, target_split) else {
2112 break;
2113 };
2114 let mut triplet = None;
2115 for offset in 0..recipe_order.len() {
2116 let idx = recipe_order[(recipe_pos + offset) % recipe_order.len()];
2117 recipe_steps = recipe_steps.saturating_add(1);
2118 let recipe = self.triplet_recipes[idx].clone();
2119 last_recipe_name = Some(recipe.name.clone());
2120 if let Some(sample) = self.make_triplet_with_anchor(&recipe, &anchor, &mut rng)
2121 {
2122 triplet = Some((recipe, sample));
2123 recipe_pos = (recipe_pos + offset + 1) % recipe_order.len();
2124 break;
2125 }
2126 }
2127 if let Some((recipe, triplet)) = triplet {
2128 let key = (
2129 triplet.anchor.record_id.clone(),
2130 triplet.positive.record_id.clone(),
2131 triplet.negative.record_id.clone(),
2132 );
2133 if seen.insert(key) {
2134 let SampleTriplet {
2135 recipe: triplet_recipe_name,
2136 anchor,
2137 positive,
2138 negative,
2139 weight,
2140 instruction,
2141 } = triplet;
2142 if pairs.len() < self.config.batch_size {
2143 pairs.push(SamplePair {
2144 recipe: triplet_recipe_name.clone(),
2145 anchor: anchor.clone(),
2146 positive: positive.clone(),
2147 weight,
2148 instruction: instruction.clone(),
2149 label: PairLabel::Positive,
2150 reason: None,
2151 });
2152 }
2153 if pairs.len() < self.config.batch_size {
2154 pairs.push(SamplePair {
2155 recipe: triplet_recipe_name,
2156 anchor,
2157 positive: negative,
2158 weight,
2159 instruction,
2160 label: PairLabel::Negative,
2161 reason: Some(
2162 strategy_reason(&recipe.negative_strategy).to_string(),
2163 ),
2164 });
2165 }
2166 }
2167 }
2168 }
2169 if recipe_steps > 0 {
2170 self.triplet_recipe_rr_idx =
2171 self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
2172 }
2173 self.rng = rng;
2174 pad_with_reuse(&mut pairs, self.config.batch_size);
2175 if pairs.len() == self.config.batch_size {
2176 return Ok(SampleBatch { pairs });
2177 }
2178 return Err(SamplerError::Exhausted(
2179 last_recipe_name
2180 .unwrap_or(Cow::Borrowed(RECIPE_LABEL_TRIPLETS))
2181 .to_string(),
2182 ));
2183 }
2184
2185 let mut pairs = Vec::new();
2186 let mut seen = HashSet::new();
2187 let mut source_steps = 0usize;
2188 let mut cycle = (self.source_cycle_idx / sources.len()) as u64;
2189 let mut source_idx = self.source_cycle_idx % sources.len();
2190 let mut cycle_sources = self.shuffled_source_cycle(cycle);
2191 let mut recipe_orders: HashMap<RecipeKey, Vec<usize>> = HashMap::new();
2192 let mut recipe_positions: HashMap<RecipeKey, usize> = HashMap::new();
2193 let mut recipe_steps = 0usize;
2194 let max_recipe_len = sources
2195 .iter()
2196 .map(|source| self.triplet_recipe_count_for_source(source))
2197 .max()
2198 .unwrap_or(1)
2199 .max(1);
2200 let attempts = self.config.batch_size * 4 * sources.len() * max_recipe_len;
2201 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2202 for _ in 0..attempts {
2203 if pairs.len() >= self.config.batch_size {
2204 break;
2205 }
2206 let source = cycle_sources[source_idx].as_str();
2207 let (triplet, attempts_used) = self.sample_source_triplet_candidate(
2208 source,
2209 target_split,
2210 &mut recipe_orders,
2211 &mut recipe_positions,
2212 &mut rng,
2213 );
2214 recipe_steps = recipe_steps.saturating_add(attempts_used);
2215 if let Some((recipe, triplet)) = triplet {
2216 let key = (
2217 triplet.anchor.record_id.clone(),
2218 triplet.positive.record_id.clone(),
2219 triplet.negative.record_id.clone(),
2220 );
2221 if seen.insert(key) {
2222 let SampleTriplet {
2223 recipe: triplet_recipe_name,
2224 anchor,
2225 positive,
2226 negative,
2227 weight,
2228 instruction,
2229 } = triplet;
2230 if pairs.len() < self.config.batch_size {
2231 pairs.push(SamplePair {
2232 recipe: triplet_recipe_name.clone(),
2233 anchor: anchor.clone(),
2234 positive: positive.clone(),
2235 weight,
2236 instruction: instruction.clone(),
2237 label: PairLabel::Positive,
2238 reason: None,
2239 });
2240 }
2241 if pairs.len() < self.config.batch_size {
2242 pairs.push(SamplePair {
2243 recipe: triplet_recipe_name,
2244 anchor,
2245 positive: negative,
2246 weight,
2247 instruction,
2248 label: PairLabel::Negative,
2249 reason: Some(strategy_reason(&recipe.negative_strategy).to_string()),
2250 });
2251 }
2252 }
2253 }
2254 source_idx += 1;
2255 source_steps += 1;
2256 if source_idx >= cycle_sources.len() {
2257 source_idx = 0;
2258 cycle = cycle.saturating_add(1);
2259 cycle_sources = self.shuffled_source_cycle(cycle);
2260 }
2261 }
2262 if recipe_steps > 0 {
2263 self.triplet_recipe_rr_idx = self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
2264 }
2265 self.rng = rng;
2266 pad_with_reuse(&mut pairs, self.config.batch_size);
2267 if pairs.len() == self.config.batch_size {
2268 self.source_cycle_idx = self.source_cycle_idx.saturating_add(source_steps);
2269 self.source_state_dirty = sources.len() > 1;
2270 return Ok(SampleBatch { pairs });
2271 }
2272 Err(SamplerError::Exhausted(RECIPE_LABEL_TRIPLETS.into()))
2273 }
2274
2275 fn try_emit_text_sample(
2280 &mut self,
2281 sample: &TextSample,
2282 seen: &mut HashSet<(String, String)>,
2283 ) -> bool {
2284 let key = text_dedup_key(&sample.chunk);
2285 seen.insert(key)
2286 && self.emitted_text_hashes.insert({
2287 let h = stable_hash_str(0, &sample.chunk.record_id);
2288 stable_hash_str(h, &sample.chunk.text)
2289 })
2290 }
2291
2292 fn next_text_batch_inner_with_weights(
2293 &mut self,
2294 target_split: SplitLabel,
2295 weights: Option<&HashMap<SourceId, f32>>,
2296 ) -> Result<TextBatch, SamplerError> {
2297 self.ingest_with_weights_fallback(target_split, weights)?;
2298 self.ensure_split_has_records(target_split)?;
2299 let sources = self.source_order.clone();
2300 if sources.is_empty() {
2301 if self.text_recipes.is_empty() {
2302 return Err(SamplerError::Configuration(
2303 "no text recipes configured".into(),
2304 ));
2305 }
2306 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2307 let recipe_weights: Vec<f32> = self.text_recipes.iter().map(|r| r.weight).collect();
2308 let recipe_order = self.text_recipe_order_weighted_cycled(
2309 &recipe_weights,
2310 self.text_recipe_rr_idx,
2311 &mut rng,
2312 );
2313 let mut samples = Vec::new();
2314 let mut seen = HashSet::new();
2315 let mut last_recipe_name = None;
2316 let mut recipe_pos = 0usize;
2317 let mut recipe_steps = 0usize;
2318 let attempts = self.config.batch_size * 4 * recipe_order.len().max(1);
2319 for _ in 0..attempts {
2320 if samples.len() >= self.config.batch_size {
2321 break;
2322 }
2323 let recipe_idx = recipe_order[recipe_pos];
2324 recipe_pos = (recipe_pos + 1) % recipe_order.len();
2325 recipe_steps = recipe_steps.saturating_add(1);
2326 let recipe = self.text_recipes[recipe_idx].clone();
2327 last_recipe_name = Some(recipe.name.clone());
2328 if let Some(sample) =
2329 self.make_text_sample_for_split(&recipe, None, target_split, &mut rng)
2330 && self.try_emit_text_sample(&sample, &mut seen)
2331 {
2332 samples.push(sample);
2333 }
2334 }
2335 if recipe_steps > 0 {
2336 self.text_recipe_rr_idx = self.text_recipe_rr_idx.saturating_add(recipe_steps);
2337 }
2338 self.rng = rng;
2339 pad_with_reuse(&mut samples, self.config.batch_size);
2340 if samples.len() == self.config.batch_size {
2341 return Ok(TextBatch { samples });
2342 }
2343 return Err(SamplerError::Exhausted(
2344 last_recipe_name
2345 .unwrap_or(Cow::Borrowed(RECIPE_LABEL_TEXT))
2346 .to_string(),
2347 ));
2348 }
2349
2350 let mut samples = Vec::new();
2351 let mut seen = HashSet::new();
2352 let mut source_steps = 0usize;
2353 let mut cycle = (self.source_cycle_idx / sources.len()) as u64;
2354 let mut idx = self.source_cycle_idx % sources.len();
2355 let mut cycle_sources = self.shuffled_source_cycle(cycle);
2356 let mut recipe_orders: HashMap<RecipeKey, Vec<usize>> = HashMap::new();
2357 let mut recipe_positions: HashMap<RecipeKey, usize> = HashMap::new();
2358 let mut recipe_steps = 0usize;
2359 let max_recipe_len = sources
2360 .iter()
2361 .map(|source| self.text_recipes_for_source(source).len())
2362 .max()
2363 .unwrap_or(1)
2364 .max(1);
2365 let attempts = self.config.batch_size * 4 * sources.len() * max_recipe_len;
2366 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2367 for _ in 0..attempts {
2368 if samples.len() >= self.config.batch_size {
2369 break;
2370 }
2371 let source = cycle_sources[idx].as_str();
2372 let recipes = self.text_recipes_for_source(source).to_vec();
2373 if recipes.is_empty() {
2374 idx += 1;
2375 source_steps += 1;
2376 if idx >= cycle_sources.len() {
2377 idx = 0;
2378 cycle = cycle.saturating_add(1);
2379 cycle_sources = self.shuffled_source_cycle(cycle);
2380 }
2381 continue;
2382 }
2383 if !recipe_orders.contains_key(source) {
2384 let recipe_weights: Vec<f32> = recipes.iter().map(|r| r.weight).collect();
2385 let order = self.text_recipe_order_weighted_cycled(
2386 &recipe_weights,
2387 self.text_recipe_rr_idx,
2388 &mut rng,
2389 );
2390 recipe_orders.insert(source.to_string(), order);
2391 }
2392 let order = recipe_orders
2393 .get(source)
2394 .expect("recipe order missing for source");
2395 let pos = recipe_positions.entry(source.to_string()).or_insert(0);
2396 let mut sample: Option<(TextRecipe, TextSample)> = None;
2397 for offset in 0..order.len() {
2398 let recipe_idx = order[(*pos + offset) % order.len()];
2399 let recipe = recipes[recipe_idx].clone();
2400 if let Some(item) =
2401 self.make_text_sample_for_split(&recipe, Some(source), target_split, &mut rng)
2402 {
2403 recipe_steps = recipe_steps.saturating_add(offset + 1);
2404 *pos = (*pos + offset + 1) % order.len();
2405 sample = Some((recipe, item));
2406 break;
2407 }
2408 }
2409 if sample.is_none() {
2410 recipe_steps = recipe_steps.saturating_add(order.len());
2411 }
2412 if let Some((_recipe, sample)) = sample
2413 && self.try_emit_text_sample(&sample, &mut seen)
2414 {
2415 samples.push(sample);
2416 }
2417 idx += 1;
2418 source_steps += 1;
2419 if idx >= cycle_sources.len() {
2420 idx = 0;
2421 cycle = cycle.saturating_add(1);
2422 cycle_sources = self.shuffled_source_cycle(cycle);
2423 }
2424 }
2425 if samples.len() != self.config.batch_size {
2426 pad_with_reuse(&mut samples, self.config.batch_size);
2427 }
2428 if samples.len() != self.config.batch_size {
2429 self.rng = rng;
2430 return Err(SamplerError::Exhausted(RECIPE_LABEL_TEXT.into()));
2431 }
2432 self.rng = rng;
2433 self.source_cycle_idx = self.source_cycle_idx.saturating_add(source_steps);
2434 self.source_state_dirty = sources.len() > 1;
2435 if recipe_steps > 0 {
2436 self.text_recipe_rr_idx = self.text_recipe_rr_idx.saturating_add(recipe_steps);
2437 }
2438 Ok(TextBatch { samples })
2439 }
2440
    /// Builds one [`TripletBatch`] (anchor / positive / negative) for `target_split`.
    ///
    /// Ingestion is refreshed first (passing per-source `weights` when given) and the
    /// split must contain records.
    ///
    /// * `source_order` empty: a fresh anchor is chosen per attempt and triplet recipes
    ///   are tried in the order returned by `recipe_order_weighted_cycled` until one
    ///   yields a triplet.
    /// * otherwise the batch is built in three phases:
    ///   1. sequentially plan up to `batch_size * 4 * max_recipe_len` slots. Each slot
    ///      holds an anchor, its source's recipes, and a per-slot RNG seed; sources are
    ///      walked in `shuffled_source_cycle` order.
    ///   2. in parallel (rayon), pick an anchor/positive pair per slot using an RNG
    ///      seeded from the slot's `fork_seed`.
    ///   3. sequentially, in slot order, pick a negative with the shared RNG and assemble
    ///      and filter the triplet.
    ///
    /// Triplets are de-duplicated by their (anchor, positive, negative) record ids. An
    /// under-full batch is padded by reuse. `triplet_recipe_rr_idx` advances whenever
    /// recipe steps were taken; `source_cycle_idx` advances only on success.
    ///
    /// # Errors
    /// * `SamplerError::Configuration` when there are no sources and no triplet recipes.
    /// * `SamplerError::Exhausted` when not a single triplet could be produced.
    /// * Any error from `ingest_with_weights_fallback` / `ensure_split_has_records`.
    fn next_triplet_batch_inner_with_weights(
        &mut self,
        target_split: SplitLabel,
        weights: Option<&HashMap<SourceId, f32>>,
    ) -> Result<TripletBatch, SamplerError> {
        self.ingest_with_weights_fallback(target_split, weights)?;
        self.ensure_split_has_records(target_split)?;
        let sources = self.source_order.clone();
        if sources.is_empty() {
            if self.triplet_recipes.is_empty() {
                return Err(SamplerError::Configuration(
                    "no triplet recipes configured".into(),
                ));
            }
            // Move the RNG out of `self` so it can be passed alongside `&mut self`
            // calls; it is written back before returning.
            let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
            let recipe_weights: Vec<f32> = self.triplet_recipes.iter().map(|r| r.weight).collect();
            let recipe_order = self.recipe_order_weighted_cycled(
                &recipe_weights,
                self.triplet_recipe_rr_idx,
                &mut rng,
            );
            let mut triplets = Vec::new();
            let mut seen = HashSet::new();
            let mut last_recipe_name = None;
            let mut recipe_pos = 0usize;
            let mut recipe_steps = 0usize;
            let attempts = self.config.batch_size * 4 * recipe_order.len().max(1);
            for _ in 0..attempts {
                if triplets.len() >= self.config.batch_size {
                    break;
                }
                // No anchor available for this split: stop early.
                let Some(anchor) = self.choose_anchor_record(None, target_split) else {
                    break;
                };
                let mut triplet = None;
                // Try each recipe once, starting at the cursor; the cursor moves past
                // the first recipe that succeeds. An empty order simply yields `None`.
                for offset in 0..recipe_order.len() {
                    let idx = recipe_order[(recipe_pos + offset) % recipe_order.len()];
                    recipe_steps = recipe_steps.saturating_add(1);
                    let recipe = self.triplet_recipes[idx].clone();
                    last_recipe_name = Some(recipe.name.clone());
                    if let Some(sample) = self.make_triplet_with_anchor(&recipe, &anchor, &mut rng)
                    {
                        triplet = Some(sample);
                        recipe_pos = (recipe_pos + offset + 1) % recipe_order.len();
                        break;
                    }
                }
                if let Some(triplet) = triplet {
                    // De-duplicate on the three record ids (not on chunk text).
                    let key = (
                        triplet.anchor.record_id.clone(),
                        triplet.positive.record_id.clone(),
                        triplet.negative.record_id.clone(),
                    );
                    if seen.insert(key) {
                        triplets.push(triplet);
                    }
                }
            }
            if recipe_steps > 0 {
                self.triplet_recipe_rr_idx =
                    self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
            }
            self.rng = rng;
            pad_with_reuse(&mut triplets, self.config.batch_size);
            if triplets.len() == self.config.batch_size {
                return Ok(TripletBatch { triplets });
            }
            return Err(SamplerError::Exhausted(
                last_recipe_name
                    .unwrap_or(Cow::Borrowed(RECIPE_LABEL_TRIPLETS))
                    .to_string(),
            ));
        }

        let mut triplets: Vec<SampleTriplet> = Vec::new();
        let mut seen: HashSet<(RecordId, RecordId, RecordId)> = HashSet::new();
        let mut source_steps = 0usize;
        let mut cycle = (self.source_cycle_idx / sources.len()) as u64;
        let mut source_idx = self.source_cycle_idx % sources.len();
        let mut cycle_sources = self.shuffled_source_cycle(cycle);
        let mut recipe_steps = 0usize;
        let max_recipe_len = sources
            .iter()
            .map(|source| self.triplet_recipe_count_for_source(source.as_str()))
            .max()
            .unwrap_or(1)
            .max(1);
        let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));

        /// Phase-1 output: one anchor plus the recipes of its source and a seed for
        /// the slot's private RNG used in the parallel phase.
        struct SlotPlan {
            anchor: Arc<DataRecord>,
            recipes: Vec<TripletRecipe>,
            fork_seed: u64,
        }
        let target_slots = self.config.batch_size * 4 * max_recipe_len;
        let mut slot_plans: Vec<SlotPlan> = Vec::with_capacity(target_slots);

        // Phase 1 (sequential): walk sources in cycle order, planning at most one slot
        // per step. A seed is drawn for every source that has recipes, even if no
        // anchor is then found, which keeps shared-RNG consumption per step fixed.
        for _ in 0..target_slots {
            // Defensive only: the loop already runs at most `target_slots` times and
            // pushes at most one plan per iteration.
            if slot_plans.len() >= target_slots {
                break;
            }
            let source = cycle_sources[source_idx].as_str();
            let (recipes, _) = self.resolve_source_triplet_plan(source);
            if !recipes.is_empty() {
                let fork_seed = rng.next_u64();
                if let Some(anchor) = self.choose_anchor_record(Some(source), target_split) {
                    slot_plans.push(SlotPlan {
                        anchor,
                        recipes,
                        fork_seed,
                    });
                    // Counts one recipe step per planned slot.
                    recipe_steps = recipe_steps.saturating_add(1);
                }
            }
            source_idx += 1;
            source_steps += 1;
            if source_idx >= cycle_sources.len() {
                source_idx = 0;
                cycle = cycle.saturating_add(1);
                cycle_sources = self.shuffled_source_cycle(cycle);
            }
        }
        self.rng = rng;

        /// Phase-2 output: an anchor/positive selection for one slot, plus the
        /// anchor's raw text (later passed to negative selection as the query text).
        struct SlotCandidate {
            recipe: TripletRecipe,
            anchor: Arc<DataRecord>,
            anchor_chunk: RecordChunk,
            positive_chunk: RecordChunk,
            anchor_raw_text: String,
        }
        // Phase 2 (parallel): each slot uses only its own RNG seeded from `fork_seed`
        // and shared `&self` access. NOTE(review): this assumes
        // `select_anchor_positive_for_recipe` has no shared mutable state — confirm.
        let mut raw_candidates: Vec<(usize, Option<SlotCandidate>)> = slot_plans
            .par_iter()
            .enumerate()
            .map(|(slot_idx, plan)| {
                let mut fork_rng = DeterministicRng::new(plan.fork_seed);
                let weights: Vec<f32> = plan.recipes.iter().map(|r| r.weight).collect();
                let order = weighted_recipe_order(&weights, &mut fork_rng);
                let mut candidate = None;
                // First recipe (in weighted order) that yields an anchor/positive wins.
                for &idx in &order {
                    let recipe = &plan.recipes[idx];
                    if let Some((ac, pc, raw)) =
                        self.select_anchor_positive_for_recipe(recipe, &plan.anchor, &mut fork_rng)
                    {
                        candidate = Some(SlotCandidate {
                            recipe: recipe.clone(),
                            anchor: Arc::clone(&plan.anchor),
                            anchor_chunk: ac,
                            positive_chunk: pc,
                            anchor_raw_text: raw,
                        });
                        break;
                    }
                }
                (slot_idx, candidate)
            })
            .collect();

        // Restore slot order before consuming the shared RNG. rayon's indexed `collect`
        // already preserves order; the sort makes that requirement explicit.
        raw_candidates.sort_unstable_by_key(|(i, _)| *i);
        // Phase 3 (sequential): negatives use the shared RNG, so slots are processed in
        // order to keep results deterministic.
        let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
        for (_, candidate) in raw_candidates {
            if triplets.len() >= self.config.batch_size {
                break;
            }
            let Some(sc) = candidate else { continue };
            let (negative_record, fallback_used) = match self.select_negative_record(
                &sc.anchor,
                &sc.recipe.negative_strategy,
                Some(&sc.anchor_raw_text),
                &mut rng,
            ) {
                Some(pair) => pair,
                None => continue,
            };
            let mut negative_chunk = match self.select_chunk_parallel(
                &negative_record,
                &sc.recipe.negative_selector,
                &mut rng,
            ) {
                Some(c) => c,
                None => continue,
            };
            self.decorate_chunk_parallel(&negative_record, &mut negative_chunk, &mut rng);

            // Randomly swap the anchor/positive roles when the masked RNG bits are zero.
            let (anchor_chunk, positive_chunk) = if rng.next_u64() & ANCHOR_POSITIVE_SWAP_MASK == 0
            {
                (sc.positive_chunk, sc.anchor_chunk)
            } else {
                (sc.anchor_chunk, sc.positive_chunk)
            };

            // Reject degenerate triplets: identical anchor/positive text (unless the
            // recipe allows it) or a negative whose text equals either side.
            if (!sc.recipe.allow_same_anchor_positive && anchor_chunk.text == positive_chunk.text)
                || negative_chunk.text == positive_chunk.text
                || negative_chunk.text == anchor_chunk.text
            {
                continue;
            }

            let chunk_weight = self.triplet_chunk_weight(
                &sc.recipe,
                &anchor_chunk,
                &positive_chunk,
                &negative_chunk,
            );
            let weight = sc.recipe.weight * chunk_weight;
            // The recipe name is suffixed when `select_negative_record` reported that
            // its fallback was used.
            let recipe_name = if fallback_used {
                format!("{}_fallback_same_split", sc.recipe.name)
            } else {
                sc.recipe.name.to_string()
            };
            let triplet = SampleTriplet {
                recipe: recipe_name,
                anchor: anchor_chunk,
                positive: positive_chunk,
                negative: negative_chunk,
                weight,
                instruction: sc.recipe.instruction.as_ref().map(|s| s.to_string()),
            };
            let key = (
                triplet.anchor.record_id.clone(),
                triplet.positive.record_id.clone(),
                triplet.negative.record_id.clone(),
            );
            if seen.insert(key) && triplets.len() < self.config.batch_size {
                triplets.push(triplet);
            }
        }
        self.rng = rng;

        if recipe_steps > 0 {
            self.triplet_recipe_rr_idx = self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
        }
        pad_with_reuse(&mut triplets, self.config.batch_size);
        if triplets.len() == self.config.batch_size {
            // Source cursor only advances when a full batch is returned.
            self.source_cycle_idx = self.source_cycle_idx.saturating_add(source_steps);
            self.source_state_dirty = sources.len() > 1;
            let batch = TripletBatch { triplets };
            return Ok(batch);
        }
        Err(SamplerError::Exhausted(RECIPE_LABEL_TRIPLETS.into()))
    }
2698
2699 #[cfg(test)]
2705 #[cfg(feature = "bm25-mining")]
2706 fn bm25_backend_mut(&mut self) -> &mut backends::Bm25Backend {
2707 self.negative_backend
2708 .as_any_mut()
2709 .downcast_mut::<backends::Bm25Backend>()
2710 .expect("bm25_backend_mut: negative_backend is Bm25Backend when bm25-mining feature is active")
2711 }
2712
2713 #[cfg(test)]
2715 fn recipe_order_weighted_shuffled_seeded(&mut self, weights: &[f32]) -> Vec<usize> {
2716 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2717 let result = self.recipe_order_weighted_shuffled(weights, &mut rng);
2718 self.rng = rng;
2719 result
2720 }
2721
2722 #[cfg(test)]
2724 fn recipe_order_weighted_cycled_seeded(
2725 &mut self,
2726 weights: &[f32],
2727 rr_idx: usize,
2728 ) -> Vec<usize> {
2729 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2730 let result = self.recipe_order_weighted_cycled(weights, rr_idx, &mut rng);
2731 self.rng = rng;
2732 result
2733 }
2734
2735 #[cfg(test)]
2737 fn text_recipe_order_weighted_shuffled_seeded(&mut self, weights: &[f32]) -> Vec<usize> {
2738 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2739 let result = self.text_recipe_order_weighted_shuffled(weights, &mut rng);
2740 self.rng = rng;
2741 result
2742 }
2743
2744 #[cfg(test)]
2746 fn text_recipe_order_weighted_cycled_seeded(
2747 &mut self,
2748 weights: &[f32],
2749 rr_idx: usize,
2750 ) -> Vec<usize> {
2751 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2752 let result = self.text_recipe_order_weighted_cycled(weights, rr_idx, &mut rng);
2753 self.rng = rng;
2754 result
2755 }
2756
2757 #[cfg(test)]
2759 fn select_negative_record_seeded(
2760 &mut self,
2761 anchor_record: &DataRecord,
2762 strategy: &NegativeStrategy,
2763 anchor_query_text: Option<&str>,
2764 ) -> Option<(Arc<DataRecord>, bool)> {
2765 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2766 let result =
2767 self.select_negative_record(anchor_record, strategy, anchor_query_text, &mut rng);
2768 self.rng = rng;
2769 result
2770 }
2771
2772 #[cfg(test)]
2774 fn make_triplet_with_anchor_seeded(
2775 &mut self,
2776 recipe: &TripletRecipe,
2777 anchor: &DataRecord,
2778 ) -> Option<SampleTriplet> {
2779 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2780 let result = self.make_triplet_with_anchor(recipe, anchor, &mut rng);
2781 self.rng = rng;
2782 result
2783 }
2784
2785 #[cfg(test)]
2787 fn decorate_chunk_seeded(&mut self, record: &DataRecord, chunk: &mut RecordChunk) {
2788 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2789 self.decorate_chunk(record, chunk, &mut rng);
2790 self.rng = rng;
2791 }
2792
2793 #[cfg(all(feature = "bm25-mining", feature = "extended-metrics"))]
2799 fn bm25_fallback_stats(&self) -> (u64, u64) {
2800 self.negative_backend.bm25_fallback_stats()
2801 }
2802
2803 #[cfg(test)]
2808 #[cfg(feature = "bm25-mining")]
2809 fn bm25_ranked_candidates(&mut self, anchor: &crate::data::DataRecord) -> Vec<RecordId> {
2810 let split = self
2811 .split_store
2812 .label_for(&anchor.id)
2813 .unwrap_or(SplitLabel::Train);
2814 self.negative_backend
2815 .as_any_mut()
2816 .downcast_mut::<backends::Bm25Backend>()
2817 .expect("bm25_ranked_candidates: Bm25Backend")
2818 .ranked_candidates_pub(anchor, split)
2819 }
2820}
2821
2822fn weighted_recipe_order(weights: &[f32], rng: &mut DeterministicRng) -> Vec<usize> {
2833 let nonzero: Vec<(usize, f32)> = weights
2834 .iter()
2835 .enumerate()
2836 .filter(|(_, w)| **w > 0.0)
2837 .map(|(i, &w)| (i, w))
2838 .collect();
2839 if nonzero.is_empty() {
2840 return Vec::new();
2841 }
2842 let w_min = nonzero
2843 .iter()
2844 .map(|(_, w)| *w)
2845 .fold(f32::INFINITY, f32::min);
2846 let mut order: Vec<usize> = Vec::new();
2847 for (recipe_idx, w) in &nonzero {
2848 let tickets = ((w / w_min).round() as usize).clamp(1, RECIPE_ORDER_MAX_WEIGHT_MULTIPLIER);
2849 for _ in 0..tickets {
2850 order.push(*recipe_idx);
2851 }
2852 }
2853 order.shuffle(rng);
2854 order
2855}
2856
2857fn same_selector_pair_is_valid(
2858 anchor_chunk: &RecordChunk,
2859 positive_chunk: &RecordChunk,
2860 enforce_window_pair: bool,
2861) -> bool {
2862 if triplet_chunk_key(anchor_chunk) == triplet_chunk_key(positive_chunk) {
2863 return false;
2864 }
2865 if !enforce_window_pair {
2866 return true;
2867 }
2868 matches!(
2869 (&anchor_chunk.view, &positive_chunk.view),
2870 (ChunkView::Window { .. }, ChunkView::Window { .. })
2871 )
2872}
2873
2874impl<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> TripletSampler<S> {
2875 pub fn new(config: SamplerConfig, split_store: Arc<S>) -> Self {
2877 let inner = TripletSamplerInner::new(config, split_store);
2878 Self {
2879 inner: Mutex::new(inner),
2880 }
2881 }
2882
2883 pub fn new_with_chunker(
2885 config: SamplerConfig,
2886 split_store: Arc<S>,
2887 chunker: Arc<dyn ChunkingAlgorithm>,
2888 ) -> Self {
2889 let inner = TripletSamplerInner::new_with_chunker(config, split_store, chunker);
2890 Self {
2891 inner: Mutex::new(inner),
2892 }
2893 }
2894
2895 pub fn next_pair_batch_for_split(
2897 &self,
2898 split: SplitLabel,
2899 ) -> Result<SampleBatch, SamplerError> {
2900 self.next_pair_batch_with_weights_for_split(split, &HashMap::new())
2901 }
2902
2903 pub fn next_text_batch_for_split(&self, split: SplitLabel) -> Result<TextBatch, SamplerError> {
2905 self.next_text_batch_with_weights_for_split(split, &HashMap::new())
2906 }
2907
2908 pub fn next_triplet_batch_for_split(
2910 &self,
2911 split: SplitLabel,
2912 ) -> Result<TripletBatch, SamplerError> {
2913 self.next_triplet_batch_with_weights_for_split(split, &HashMap::new())
2914 }
2915
2916 pub fn next_pair_batch_with_weights_for_split(
2918 &self,
2919 split: SplitLabel,
2920 weights: &HashMap<SourceId, f32>,
2921 ) -> Result<SampleBatch, SamplerError> {
2922 let mut inner = self.inner.lock().unwrap();
2923 inner.ensure_split_allowed(split)?;
2924 for attempt in 0..=EXHAUSTION_RETRY_LIMIT {
2925 match inner.next_pair_batch_inner_with_weights(split, Some(weights)) {
2926 Ok(batch) => return Ok(batch),
2927 Err(SamplerError::Exhausted(_)) => {
2928 if attempt < EXHAUSTION_RETRY_LIMIT {
2929 inner.force_ingest_refresh_with_weights_for_split(split, weights)?;
2930 }
2931 }
2932 Err(err) => return Err(err),
2933 }
2934 }
2935 Err(SamplerError::Exhausted(RECIPE_LABEL_TRIPLETS.into()))
2936 }
2937
2938 pub fn next_text_batch_with_weights_for_split(
2940 &self,
2941 split: SplitLabel,
2942 weights: &HashMap<SourceId, f32>,
2943 ) -> Result<TextBatch, SamplerError> {
2944 let mut inner = self.inner.lock().unwrap();
2945 inner.ensure_split_allowed(split)?;
2946 for attempt in 0..=EXHAUSTION_RETRY_LIMIT {
2947 match inner.next_text_batch_inner_with_weights(split, Some(weights)) {
2948 Ok(batch) => return Ok(batch),
2949 Err(SamplerError::Exhausted(_)) => {
2950 if attempt < EXHAUSTION_RETRY_LIMIT {
2951 inner.force_ingest_refresh_with_weights_for_split(split, weights)?;
2952 }
2953 }
2954 Err(err) => return Err(err),
2955 }
2956 }
2957 Err(SamplerError::Exhausted(RECIPE_LABEL_TEXT.into()))
2958 }
2959
2960 pub fn next_triplet_batch_with_weights_for_split(
2962 &self,
2963 split: SplitLabel,
2964 weights: &HashMap<SourceId, f32>,
2965 ) -> Result<TripletBatch, SamplerError> {
2966 let mut inner = self.inner.lock().unwrap();
2967 inner.ensure_split_allowed(split)?;
2968 for attempt in 0..=EXHAUSTION_RETRY_LIMIT {
2969 match inner.next_triplet_batch_inner_with_weights(split, Some(weights)) {
2970 Ok(batch) => return Ok(batch),
2971 Err(SamplerError::Exhausted(_)) => {
2972 if attempt < EXHAUSTION_RETRY_LIMIT {
2973 inner.force_ingest_refresh_with_weights_for_split(split, weights)?;
2974 }
2975 }
2976 Err(err) => return Err(err),
2977 }
2978 }
2979 Err(SamplerError::Exhausted(RECIPE_LABEL_TRIPLETS.into()))
2980 }
2981
2982 pub fn prefetch_triplet_batches(
2984 self: Arc<Self>,
2985 split: SplitLabel,
2986 capacity: usize,
2987 ) -> BatchPrefetcher<TripletBatch> {
2988 BatchPrefetcher::new(capacity, move || self.next_triplet_batch_for_split(split))
2989 }
2990
2991 pub fn prefetch_triplet_batches_with_weights(
2993 self: Arc<Self>,
2994 split: SplitLabel,
2995 capacity: usize,
2996 weights: HashMap<SourceId, f32>,
2997 ) -> BatchPrefetcher<TripletBatch> {
2998 BatchPrefetcher::new(capacity, move || {
2999 self.next_triplet_batch_with_weights_for_split(split, &weights)
3000 })
3001 }
3002
3003 pub fn prefetch_pair_batches(
3005 self: Arc<Self>,
3006 split: SplitLabel,
3007 capacity: usize,
3008 ) -> BatchPrefetcher<SampleBatch> {
3009 BatchPrefetcher::new(capacity, move || self.next_pair_batch_for_split(split))
3010 }
3011
3012 pub fn prefetch_pair_batches_with_weights(
3014 self: Arc<Self>,
3015 split: SplitLabel,
3016 capacity: usize,
3017 weights: HashMap<SourceId, f32>,
3018 ) -> BatchPrefetcher<SampleBatch> {
3019 BatchPrefetcher::new(capacity, move || {
3020 self.next_pair_batch_with_weights_for_split(split, &weights)
3021 })
3022 }
3023
3024 pub fn prefetch_text_batches(
3026 self: Arc<Self>,
3027 split: SplitLabel,
3028 capacity: usize,
3029 ) -> BatchPrefetcher<TextBatch> {
3030 BatchPrefetcher::new(capacity, move || self.next_text_batch_for_split(split))
3031 }
3032
3033 pub fn prefetch_text_batches_with_weights(
3035 self: Arc<Self>,
3036 split: SplitLabel,
3037 capacity: usize,
3038 weights: HashMap<SourceId, f32>,
3039 ) -> BatchPrefetcher<TextBatch> {
3040 BatchPrefetcher::new(capacity, move || {
3041 self.next_text_batch_with_weights_for_split(split, &weights)
3042 })
3043 }
3044
3045 pub fn text_recipes(&self) -> Vec<TextRecipe> {
3047 let inner = self.inner.lock().unwrap();
3048 inner.text_recipes().to_vec()
3049 }
3050
3051 pub fn register_source(&self, source: Box<dyn DataSource + 'static>) {
3053 let mut inner = self.inner.lock().unwrap();
3054 inner.register_source(source);
3055 }
3056
3057 pub fn set_epoch(&self, epoch: u64) -> Result<(), SamplerError> {
3059 let mut inner = self.inner.lock().unwrap();
3060 inner.set_epoch(epoch)
3061 }
3062
3063 pub fn save_sampler_state(&self, save_to: Option<&Path>) -> Result<(), SamplerError> {
3068 let mut inner = self.inner.lock().unwrap();
3069 inner.save_sampler_state(save_to)
3070 }
3071
3072 #[cfg(all(feature = "bm25-mining", feature = "extended-metrics"))]
3080 pub fn bm25_fallback_stats(&self) -> (u64, u64) {
3081 let inner = self.inner.lock().unwrap();
3082 inner.bm25_fallback_stats()
3083 }
3084}
3085
3086impl<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> Sampler for TripletSampler<S> {
3087 fn next_pair_batch(&self, split: SplitLabel) -> Result<SampleBatch, SamplerError> {
3088 self.next_pair_batch_for_split(split)
3089 }
3090
3091 fn next_pair_batch_with_weights(
3092 &self,
3093 split: SplitLabel,
3094 weights: &HashMap<SourceId, f32>,
3095 ) -> Result<SampleBatch, SamplerError> {
3096 self.next_pair_batch_with_weights_for_split(split, weights)
3097 }
3098
3099 fn next_text_batch(&self, split: SplitLabel) -> Result<TextBatch, SamplerError> {
3100 self.next_text_batch_for_split(split)
3101 }
3102
3103 fn next_text_batch_with_weights(
3104 &self,
3105 split: SplitLabel,
3106 weights: &HashMap<SourceId, f32>,
3107 ) -> Result<TextBatch, SamplerError> {
3108 self.next_text_batch_with_weights_for_split(split, weights)
3109 }
3110
3111 fn next_triplet_batch(&self, split: SplitLabel) -> Result<TripletBatch, SamplerError> {
3112 self.next_triplet_batch_for_split(split)
3113 }
3114
3115 fn next_triplet_batch_with_weights(
3116 &self,
3117 split: SplitLabel,
3118 weights: &HashMap<SourceId, f32>,
3119 ) -> Result<TripletBatch, SamplerError> {
3120 self.next_triplet_batch_with_weights_for_split(split, weights)
3121 }
3122}
3123
3124fn roles_match(target: &SectionRole, candidate: &SectionRole) -> bool {
3125 target == candidate
3126}
3127
3128fn role_cursor_key(record_id: &RecordId, role: &SectionRole) -> (RecordId, String) {
3129 (record_id.clone(), role_label(role))
3130}
3131
3132fn role_label(role: &SectionRole) -> String {
3133 match role {
3134 SectionRole::Anchor => ROLE_LABEL_ANCHOR.into(),
3135 SectionRole::Context => ROLE_LABEL_CONTEXT.into(),
3136 }
3137}
3138
3139fn taxonomy_value(record: &DataRecord, field: MetadataKey) -> Option<&str> {
3140 record.taxonomy.iter().find_map(|entry| field.strip(entry))
3141}
3142
3143fn strategy_reason(strategy: &NegativeStrategy) -> &'static str {
3144 match strategy {
3145 NegativeStrategy::WrongPublicationDate => NEG_REASON_WRONG_DATE,
3146 NegativeStrategy::WrongArticle => NEG_REASON_WRONG_ARTICLE,
3147 NegativeStrategy::QuestionAnswerMismatch => NEG_REASON_WRONG_QA,
3148 }
3149}
3150
3151fn text_dedup_key(chunk: &RecordChunk) -> (String, String) {
3159 (chunk.record_id.clone(), chunk.text.clone())
3160}
3161
3162fn triplet_chunk_key(chunk: &RecordChunk) -> String {
3163 match &chunk.view {
3164 ChunkView::Window { index, .. } => {
3165 format!("{}|{}|w|{}", chunk.record_id, chunk.section_idx, index)
3166 }
3167 ChunkView::SummaryFallback { strategy, .. } => {
3168 format!("{}|{}|s|{}", chunk.record_id, chunk.section_idx, strategy)
3169 }
3170 }
3171}
3172
/// Pads `items` up to `target` elements by cycling through its existing elements.
///
/// Padded position `len + k` receives a clone of original element `k % len`. Nothing
/// happens when `items` is empty (nothing to reuse) or already has at least `target`
/// elements.
///
/// Unlike a snapshot-based approach, only the padded elements are cloned. The prefix
/// `items[..len]` never changes while appending, so it can be copied in place with
/// `extend_from_within`.
fn pad_with_reuse<T: Clone>(items: &mut Vec<T>, target: usize) {
    let base_len = items.len();
    if base_len == 0 || base_len >= target {
        return;
    }
    items.reserve(target - base_len);
    while items.len() < target {
        // Each round copies at most one full cycle of the original prefix.
        let take = (target - items.len()).min(base_len);
        items.extend_from_within(..take);
    }
}
3183
3184#[cfg(test)]
3185mod tests;