1mod backends;
2
3use crate::chunking::{ChunkingAlgorithm, SlidingWindowChunker};
4use chrono::Duration;
5use indexmap::IndexMap;
6use rand::prelude::*;
7use rayon::prelude::*;
8use std::borrow::Cow;
9use std::collections::{HashMap, HashSet};
10use std::path::Path;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::sync::mpsc;
13use std::sync::{Arc, Mutex};
14use std::thread;
15
16use crate::config::{
17 ChunkingStrategy, NegativeStrategy, SamplerConfig, Selector, TextRecipe, TripletRecipe,
18};
19use crate::constants::sampler::AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME;
20use crate::constants::sampler::{
21 ANCHOR_POSITIVE_SWAP_MASK, EPOCH_SEED_OFFSET, EXHAUSTION_RETRY_LIMIT, NEG_REASON_WRONG_ARTICLE,
22 NEG_REASON_WRONG_DATE, NEG_REASON_WRONG_QA, PREFETCHER_SOURCE_ID, PREFETCHER_STOPPED_REASON,
23 RECIPE_LABEL_TEXT, RECIPE_LABEL_TRIPLETS, RECIPE_ORDER_MAX_WEIGHT_MULTIPLIER,
24 ROLE_LABEL_ANCHOR, ROLE_LABEL_CONTEXT, SAME_SELECTOR_PAIR_RETRY_LIMIT,
25};
26use crate::data::{
27 ChunkView, DataRecord, PairLabel, RecordChunk, RecordSection, SampleBatch, SamplePair,
28 SampleTriplet, SectionRole, TextBatch, TextSample, TripletBatch,
29};
30use crate::epoch::EpochTracker;
31use crate::errors::SamplerError;
32use crate::hash::{derive_epoch_seed, stable_hash_str};
33use crate::ingestion::IngestionManager;
34use crate::metadata::{META_FIELD_DATE, MetadataKey};
35use crate::metrics::{chunk_proximity_score, window_index_proximity};
36use crate::source::DataSource;
37use crate::splits::{
38 EpochStateStore, PersistedSamplerState, SamplerStateStore, SplitLabel, SplitStore,
39};
40use crate::tokenizer::{Tokenizer, WhitespaceTokenizer};
41use crate::types::{RecipeKey, RecordId, SourceId};
42use crate::utils::platform_newline;
43
/// Small seedable SplitMix64 generator whose entire state is a single `u64`.
///
/// Because the state is one word, it can be read with [`DeterministicRng::state`], persisted,
/// and restored bit-for-bit with [`DeterministicRng::from_state`].
#[derive(Debug, Clone)]
struct DeterministicRng {
    /// Current generator state; advanced by a fixed odd increment on every draw.
    state: u64,
}

impl DeterministicRng {
    /// Golden-ratio increment added to the state before each output is mixed.
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a generator whose initial state is `seed`.
    fn new(seed: u64) -> Self {
        Self::from_state(seed)
    }

    /// Recreates a generator from a state previously returned by [`DeterministicRng::state`].
    fn from_state(state: u64) -> Self {
        Self { state }
    }

    /// Returns the raw state, suitable for persisting and later passing to `from_state`.
    fn state(&self) -> u64 {
        self.state
    }

    /// Advances the state and returns the next 64-bit output (SplitMix64 finalizer).
    fn next_u64_internal(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::INCREMENT);
        let mixed = (self.state ^ (self.state >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        let mixed = (mixed ^ (mixed >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        mixed ^ (mixed >> 31)
    }
}
83
84impl rand::RngCore for DeterministicRng {
85 fn next_u32(&mut self) -> u32 {
86 self.next_u64_internal() as u32
87 }
88
89 fn next_u64(&mut self) -> u64 {
90 self.next_u64_internal()
91 }
92
93 fn fill_bytes(&mut self, dest: &mut [u8]) {
94 let mut offset = 0;
95 while offset < dest.len() {
96 let value = self.next_u64_internal();
97 let bytes = value.to_le_bytes();
98 let remaining = dest.len() - offset;
99 let copy_len = remaining.min(bytes.len());
100 dest[offset..offset + copy_len].copy_from_slice(&bytes[..copy_len]);
101 offset += copy_len;
102 }
103 }
104}
105
106pub fn chunk_weight(strategy: &ChunkingStrategy, chunk: &RecordChunk) -> f32 {
115 let floor = strategy.chunk_weight_floor;
116 let trust = chunk.quality.trust.clamp(0.0, 1.0);
117 let base = match &chunk.view {
118 ChunkView::Window { index, .. } => window_index_proximity(*index),
119 ChunkView::SummaryFallback { weight, .. } => *weight,
120 };
121 (base * trust).max(floor)
122}
123
124pub trait Sampler {
126 fn next_pair_batch(&self, split: SplitLabel) -> Result<SampleBatch, SamplerError> {
128 self.next_pair_batch_with_weights(split, &HashMap::new())
129 }
130 fn next_text_batch(&self, split: SplitLabel) -> Result<TextBatch, SamplerError> {
132 self.next_text_batch_with_weights(split, &HashMap::new())
133 }
134 fn next_triplet_batch(&self, split: SplitLabel) -> Result<TripletBatch, SamplerError> {
136 self.next_triplet_batch_with_weights(split, &HashMap::new())
137 }
138 fn next_pair_batch_with_weights(
140 &self,
141 split: SplitLabel,
142 weights: &HashMap<SourceId, f32>,
143 ) -> Result<SampleBatch, SamplerError>;
144 fn next_text_batch_with_weights(
146 &self,
147 split: SplitLabel,
148 weights: &HashMap<SourceId, f32>,
149 ) -> Result<TextBatch, SamplerError>;
150 fn next_triplet_batch_with_weights(
152 &self,
153 split: SplitLabel,
154 weights: &HashMap<SourceId, f32>,
155 ) -> Result<TripletBatch, SamplerError>;
156}
157
/// Background producer that keeps a bounded queue of batches ready for the consumer.
pub struct BatchPrefetcher<T> {
    /// Consuming end of the bounded channel; `Option` so `Drop` can release it before joining
    /// the worker thread.
    receiver: Option<mpsc::Receiver<Result<T, SamplerError>>>,
    /// Worker thread handle, joined in `Drop`.
    handle: Option<thread::JoinHandle<()>>,
    /// Counters shared with the worker thread.
    stats: Arc<PrefetcherStats>,
}
164
/// Counters shared between a `BatchPrefetcher` and its worker thread.
#[derive(Default)]
struct PrefetcherStats {
    /// Items handed to the channel and not yet taken by `next`.
    queued: AtomicUsize,
    /// Items successfully sent to the channel.
    produced: AtomicUsize,
    /// Producer calls that returned `Err`.
    errors: AtomicUsize,
}
172
173impl<T: Send + 'static> BatchPrefetcher<T> {
174 fn new<F>(capacity: usize, mut producer: F) -> Self
175 where
176 F: FnMut() -> Result<T, SamplerError> + Send + 'static,
177 {
178 let (sender, receiver) = mpsc::sync_channel(capacity.max(1));
179 let stats = Arc::new(PrefetcherStats::default());
180 let stats_thread = Arc::clone(&stats);
181 let handle = thread::spawn(move || {
182 loop {
183 let result = producer();
184 if result.is_err() {
185 stats_thread.errors.fetch_add(1, Ordering::Relaxed);
186 }
187 if sender.send(result).is_err() {
188 return;
189 }
190 stats_thread.queued.fetch_add(1, Ordering::Relaxed);
191 stats_thread.produced.fetch_add(1, Ordering::Relaxed);
192 }
193 });
194 Self {
195 receiver: Some(receiver),
196 handle: Some(handle),
197 stats,
198 }
199 }
200
201 pub fn next(&self) -> Result<T, SamplerError> {
203 let receiver = self
204 .receiver
205 .as_ref()
206 .ok_or_else(|| SamplerError::SourceUnavailable {
207 source_id: PREFETCHER_SOURCE_ID.into(),
208 reason: PREFETCHER_STOPPED_REASON.into(),
209 })?;
210 let result = receiver.recv().unwrap_or_else(|_| {
211 Err(SamplerError::SourceUnavailable {
212 source_id: PREFETCHER_SOURCE_ID.into(),
213 reason: PREFETCHER_STOPPED_REASON.into(),
214 })
215 });
216 self.stats
217 .queued
218 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |value| {
219 Some(value.saturating_sub(1))
220 })
221 .ok();
222 result
223 }
224
225 pub fn queue_len(&self) -> usize {
227 self.stats.queued.load(Ordering::Relaxed)
228 }
229
230 pub fn produced_count(&self) -> usize {
232 self.stats.produced.load(Ordering::Relaxed)
233 }
234
235 pub fn error_count(&self) -> usize {
237 self.stats.errors.load(Ordering::Relaxed)
238 }
239}
240
241impl<T> Drop for BatchPrefetcher<T> {
242 fn drop(&mut self) {
243 self.receiver.take();
244 if let Some(handle) = self.handle.take() {
245 let _ = handle.join();
246 }
247 }
248}
249
/// Public sampler handle.
///
/// All mutable sampling state lives in `TripletSamplerInner`, guarded by a single mutex.
pub struct TripletSampler<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> {
    /// Sampler state, serialized behind one lock.
    inner: Mutex<TripletSamplerInner<S>>,
}
255
/// Mutable sampler state: loaded records, recipes, cursors, and per-source round-robin
/// bookkeeping.
struct TripletSamplerInner<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> {
    /// Configuration this sampler was built from.
    config: SamplerConfig,
    /// Chunking algorithm; `new` installs `SlidingWindowChunker`.
    chunker: Arc<dyn ChunkingAlgorithm>,
    /// Split labels (`label_for`/`ensure`), epoch state, and persisted sampler state.
    split_store: Arc<S>,
    /// Registered data sources and their stream cursors.
    ingestion: IngestionManager,
    /// Records currently held, keyed by id. Insertion order matters: `source_record_indices`
    /// stores positions into this map (`get_index`).
    records: IndexMap<RecordId, Arc<DataRecord>>,
    /// Deterministic RNG; its state is persisted with the sampler state.
    rng: DeterministicRng,
    /// Triplet recipes from the config (empty when the config supplies none).
    triplet_recipes: Vec<TripletRecipe>,
    /// Active text recipes: configured ones, or ones derived from triplet recipes.
    text_recipes: Vec<TextRecipe>,
    /// Per-source default triplet recipes, filled by `register_source` when the config has no
    /// triplet recipes.
    source_triplet_recipes: HashMap<SourceId, Vec<TripletRecipe>>,
    /// Sources eligible for the auto-injected long-section chunk-pair recipe.
    sources_with_long_sections: HashSet<SourceId>,
    /// Per-source text recipes derived in `register_source`.
    source_text_recipes: HashMap<SourceId, Vec<TextRecipe>>,
    /// True when `config.recipes` was non-empty at construction.
    using_config_triplet_recipes: bool,
    /// True when `config.text_recipes` was non-empty at construction.
    using_config_text_recipes: bool,
    /// NOTE(review): not read in the visible part of this file; presumably the last ingestion
    /// counter observed — confirm.
    last_observed_ingest: u64,
    /// Per-split record iterator used when no source is specified (see `choose_anchor_record`).
    epoch_tracker: EpochTracker,
    /// (record id, section index) -> next position in that section's chunk pool.
    chunk_cursors: HashMap<(RecordId, usize), usize>,
    /// (record id, role key) -> cursor; pruned with the other cursors, used outside this view.
    role_cursors: HashMap<(RecordId, String), usize>,
    /// Picks negatives from candidate pools (BM25 with the `bm25-mining` feature).
    negative_backend: Box<dyn backends::NegativeBackend>,
    /// Chunk id -> record id; `rebuild_chunk_index` currently maps each record id to itself.
    chunk_index: HashMap<RecordId, RecordId>,
    /// Sorted ids of sources with at least one record in an allowed split.
    source_order: Vec<SourceId>,
    /// Position in the source round-robin; persisted and reset on epoch changes.
    source_cycle_idx: usize,
    /// Whether persisted source state has been loaded (`ensure_source_state`).
    source_state_loaded: bool,
    /// NOTE(review): not read in the visible part of this file; presumably tracks whether
    /// ingestion stream cursors were restored — confirm.
    ingestion_cursors_loaded: bool,
    /// Set when source state should be persisted again; cleared by `persist_source_state`.
    source_state_dirty: bool,
    /// Per-source positions into `records` (allowed splits only), shuffled per epoch.
    source_record_indices: HashMap<SourceId, Vec<usize>>,
    /// Per-source count of index positions visited during the current epoch.
    source_record_cursors: HashMap<SourceId, usize>,
    /// Round-robin offset for triplet recipe ordering; persisted.
    triplet_recipe_rr_idx: usize,
    /// Round-robin offset for text recipe ordering; persisted.
    text_recipe_rr_idx: usize,
    /// Current source epoch; advances once every source has wrapped.
    source_epoch: u64,
    /// Per-source flag: has the source's cursor wrapped during the current epoch.
    source_wrapped: HashMap<SourceId, bool>,
}
319
320impl<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> TripletSamplerInner<S> {
321 fn new(config: SamplerConfig, split_store: Arc<S>) -> Self {
322 Self::new_with_chunker(config, split_store, Arc::new(SlidingWindowChunker))
323 }
324
325 fn new_with_chunker(
326 config: SamplerConfig,
327 split_store: Arc<S>,
328 chunker: Arc<dyn ChunkingAlgorithm>,
329 ) -> Self {
330 let buffer_size = config.ingestion_max_records.max(config.batch_size).max(2);
331 let using_config_triplet_recipes = !config.recipes.is_empty();
332 let using_config_text_recipes = !config.text_recipes.is_empty();
333 let triplet_recipes = if using_config_triplet_recipes {
334 config.recipes.clone()
335 } else {
336 Vec::new()
337 };
338 let text_recipes = if using_config_text_recipes {
339 config.text_recipes.clone()
340 } else if !triplet_recipes.is_empty() {
341 Self::build_derived_text_recipes(&triplet_recipes)
342 } else {
343 Vec::new()
344 };
345 let ingestion = IngestionManager::new(buffer_size, config.clone());
346 let epoch_backend = Some(Arc::clone(&split_store) as Arc<dyn EpochStateStore>);
347 let epoch_tracker = EpochTracker::new(
348 true,
349 epoch_backend,
350 derive_epoch_seed(config.seed, EPOCH_SEED_OFFSET),
351 );
352 let mut sampler = Self {
353 rng: DeterministicRng::new(config.seed),
354 config,
355 chunker,
356 split_store,
357 ingestion,
358 records: IndexMap::new(),
359 triplet_recipes,
360 text_recipes,
361 source_triplet_recipes: HashMap::new(),
362 sources_with_long_sections: HashSet::new(),
363 source_text_recipes: HashMap::new(),
364 using_config_triplet_recipes,
365 using_config_text_recipes,
366 last_observed_ingest: 0,
367 epoch_tracker,
368 chunk_cursors: HashMap::new(),
369 role_cursors: HashMap::new(),
370 negative_backend: {
371 #[cfg(feature = "bm25-mining")]
372 {
373 Box::new(backends::Bm25Backend::new())
374 }
375 #[cfg(not(feature = "bm25-mining"))]
376 {
377 Box::new(backends::DefaultBackend)
378 }
379 },
380 chunk_index: HashMap::new(),
381 source_order: Vec::new(),
382 source_cycle_idx: 0,
383 source_state_loaded: false,
384 ingestion_cursors_loaded: false,
385 source_state_dirty: false,
386 source_record_indices: HashMap::new(),
387 source_record_cursors: HashMap::new(),
388 triplet_recipe_rr_idx: 0,
389 text_recipe_rr_idx: 0,
390 source_epoch: 0,
391 source_wrapped: HashMap::new(),
392 };
393 if !sampler.using_config_text_recipes {
394 sampler.rebuild_derived_text_recipes();
395 }
396 sampler
397 }
398
399 fn text_recipes(&self) -> &[TextRecipe] {
400 &self.text_recipes
401 }
402
403 fn epoch_seed(&self) -> u64 {
406 derive_epoch_seed(self.config.seed, self.source_epoch)
407 }
408
409 fn register_source(&mut self, source: Box<dyn DataSource + 'static>) {
410 let source_id = source.id().to_string();
411 if !self.using_config_triplet_recipes {
412 let triplets = source.default_triplet_recipes();
413 if !triplets.is_empty() {
414 self.source_triplet_recipes
415 .insert(source_id.clone(), triplets.clone());
416 if !self.using_config_text_recipes {
417 let derived = Self::build_derived_text_recipes(&triplets);
418 self.source_text_recipes
419 .insert(source_id.clone(), derived.clone());
420 self.extend_text_recipes_unique(&derived);
421 }
422 }
423 }
424 self.ingestion.register_source(source);
425 }
426
427 fn set_epoch(&mut self, epoch: u64) -> Result<(), SamplerError> {
428 self.epoch_tracker.ensure_loaded()?;
429 self.epoch_tracker.force_epoch(epoch);
430 self.source_epoch = epoch;
431 self.ingestion.set_source_epoch(epoch);
432 self.ingestion.reset_stream_cursors();
433 self.source_record_cursors.clear();
434 self.source_cycle_idx = 0;
435 for source in &self.source_order {
436 self.source_wrapped.insert(source.clone(), false);
437 }
438 self.rebuild_source_index()?;
439 self.source_state_dirty = self.source_order.len() > 1;
440 Ok(())
441 }
442
443 fn next_chunk_from_pool(
444 &mut self,
445 record_id: &str,
446 section_idx: usize,
447 pool: Vec<RecordChunk>,
448 ) -> Option<RecordChunk> {
449 if pool.is_empty() {
450 return None;
451 }
452 let key = (record_id.to_string(), section_idx);
453 if !self.chunk_cursors.contains_key(&key) {
454 let cursor_key = format!("{}::{}", record_id, section_idx);
460 let start = (stable_hash_str(self.epoch_seed(), &cursor_key) as usize) % pool.len();
461 self.chunk_cursors.insert(key.clone(), start);
462 }
463 let cursor = self.chunk_cursors.entry(key).or_insert(0);
464 if *cursor >= pool.len() {
465 *cursor = 0;
466 }
467 let chunk = pool.get(*cursor).cloned();
468 *cursor = (*cursor + 1) % pool.len();
469 chunk
470 }
471
472 fn prune_cursor_state(&mut self) {
473 if self.chunk_cursors.is_empty()
474 && self.role_cursors.is_empty()
475 && self.negative_backend.cursors_empty()
476 {
477 return;
478 }
479 let valid_ids: HashSet<RecordId> = self.records.keys().cloned().collect();
480 self.chunk_cursors
481 .retain(|(record_id, _), _| valid_ids.contains(record_id));
482 self.role_cursors
483 .retain(|(record_id, _), _| valid_ids.contains(record_id));
484 self.negative_backend.prune_cursors(&valid_ids);
485 }
486
487 fn rebuild_chunk_index(&mut self) {
488 self.chunk_index.clear();
489 for record in self.records.values() {
490 self.chunk_index
491 .insert(record.id.clone(), record.id.clone());
492 }
493 }
494
495 fn rebuild_source_index(&mut self) -> Result<(), SamplerError> {
496 self.source_record_indices.clear();
497 let mut label_cache: HashMap<RecordId, SplitLabel> = HashMap::new();
498 let allowed = self.allowed_target_splits();
499 let allowed_set: HashSet<SplitLabel> = allowed.into_iter().collect();
500 for (idx, record) in self.records.values().enumerate() {
501 let label = if let Some(label) = label_cache.get(&record.id) {
502 *label
503 } else {
504 let label = match self.split_store.label_for(&record.id) {
505 Some(label) => label,
506 None => self.split_store.ensure(record.id.clone())?,
507 };
508 label_cache.insert(record.id.clone(), label);
509 label
510 };
511 if !allowed_set.contains(&label) {
512 continue;
513 }
514 self.source_record_indices
515 .entry(record.source.clone())
516 .or_default()
517 .push(idx);
518 }
519
520 let shuffle_seed = self.epoch_seed();
521 for indices in self.source_record_indices.values_mut() {
522 indices.sort_by_key(|idx| {
523 self.records
524 .get_index(*idx)
525 .map(|(_, record)| stable_hash_str(shuffle_seed, &record.id))
526 .unwrap_or(0)
527 });
528 }
529
530 self.source_order = self.source_record_indices.keys().cloned().collect();
531 self.source_order.sort();
532 self.refresh_source_wrapped();
533
534 self.source_record_cursors
535 .retain(|source, _| self.source_record_indices.contains_key(source));
536 if self.source_state_loaded {
537 if self.source_order.is_empty() {
538 self.source_cycle_idx = 0;
539 }
540 self.source_state_dirty = self.source_order.len() > 1;
541 }
542 Ok(())
543 }
544
545 fn refresh_source_wrapped(&mut self) {
546 self.source_wrapped.clear();
547 for source in &self.source_order {
548 let len = self
549 .source_record_indices
550 .get(source)
551 .map(|items| items.len())
552 .unwrap_or(0);
553 if len == 0 {
554 self.source_wrapped.insert(source.clone(), false);
555 continue;
556 }
557 let cursor = self.source_record_cursors.get(source).copied().unwrap_or(0);
558 let wrapped = cursor > 0 && cursor % len == 0;
559 self.source_wrapped.insert(source.clone(), wrapped);
560 }
561 }
562
563 fn shuffled_source_cycle(&self, cycle: u64) -> Vec<SourceId> {
564 let mut sources = self.source_order.clone();
565 let seed = self.epoch_seed() ^ cycle;
566 sources.sort_by_key(|source| stable_hash_str(seed, source));
567 sources
568 }
569
570 fn ensure_source_state(&mut self) -> Result<(), SamplerError> {
571 if self.source_state_loaded {
572 return Ok(());
573 }
574 let persisted = self.split_store.load_sampler_state()?;
575 self.source_cycle_idx = persisted
576 .as_ref()
577 .map(|state| state.source_cycle_idx as usize)
578 .unwrap_or(0);
579 if let Some(state) = persisted {
580 for (source, cursor) in state.source_record_cursors {
581 if self.source_record_indices.contains_key(&source) {
582 self.source_record_cursors.insert(source, cursor as usize);
583 }
584 }
585 self.source_epoch = state.source_epoch;
586 self.ingestion.set_source_epoch(state.source_epoch);
587 self.rng = DeterministicRng::from_state(state.rng_state);
588 self.triplet_recipe_rr_idx = state.triplet_recipe_rr_idx as usize;
589 self.text_recipe_rr_idx = state.text_recipe_rr_idx as usize;
590 }
591 self.refresh_source_wrapped();
592 self.source_state_loaded = true;
593 self.source_state_dirty = true;
594 Ok(())
595 }
596
597 fn persist_source_state(&mut self, save_to: Option<&Path>) -> Result<(), SamplerError> {
598 if !self.source_state_loaded {
599 return Ok(());
600 }
601 let state = PersistedSamplerState {
602 source_cycle_idx: self.source_cycle_idx as u64,
603 source_record_cursors: self
604 .source_record_cursors
605 .iter()
606 .map(|(source, cursor)| (source.clone(), *cursor as u64))
607 .collect(),
608 source_epoch: self.source_epoch,
609 rng_state: self.rng.state(),
610 triplet_recipe_rr_idx: self.triplet_recipe_rr_idx as u64,
611 text_recipe_rr_idx: self.text_recipe_rr_idx as u64,
612 source_stream_cursors: self.ingestion.snapshot_cursors(),
613 };
614 self.split_store.save_sampler_state(&state, save_to)?;
615 self.source_state_dirty = false;
616 Ok(())
617 }
618
619 fn rebuild_derived_text_recipes(&mut self) {
620 if self.using_config_text_recipes {
621 return;
622 }
623 if self.triplet_recipes.is_empty() {
624 self.text_recipes.clear();
625 } else {
626 self.text_recipes = Self::build_derived_text_recipes(&self.triplet_recipes);
627 }
628 }
629
630 fn extend_text_recipes_unique(&mut self, recipes: &[TextRecipe]) {
631 for recipe in recipes {
632 if self
633 .text_recipes
634 .iter()
635 .any(|existing| existing.name == recipe.name)
636 {
637 continue;
638 }
639 self.text_recipes.push(recipe.clone());
640 }
641 }
642
643 fn configured_triplet_recipes_for_source<'a>(&'a self, source: &str) -> &'a [TripletRecipe] {
644 if self.using_config_triplet_recipes {
645 return &self.triplet_recipes;
646 }
647 self.source_triplet_recipes
648 .get(source)
649 .map(|recipes| recipes.as_slice())
650 .unwrap_or(&[])
651 }
652
653 fn contains_auto_chunk_pair_recipe(recipes: &[TripletRecipe]) -> bool {
655 recipes
656 .iter()
657 .any(|recipe| recipe.name.as_ref() == AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME)
658 }
659
660 fn source_supports_chunk_pair_recipe(&self, source: &str) -> bool {
661 if self.config.chunking.max_window_tokens == 0 {
662 return false;
663 }
664 self.sources_with_long_sections.contains(source)
665 }
666
667 fn should_auto_inject_chunk_pair_recipe(
673 &self,
674 source: &str,
675 recipes: &[TripletRecipe],
676 ) -> bool {
677 self.source_supports_chunk_pair_recipe(source)
678 && !Self::contains_auto_chunk_pair_recipe(recipes)
679 }
680
681 fn source_chunk_pair_recipe() -> TripletRecipe {
692 TripletRecipe {
693 name: Cow::Borrowed(AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME),
694 anchor: Selector::Role(SectionRole::Context),
695 positive_selector: Selector::Role(SectionRole::Context),
696 negative_selector: Selector::Role(SectionRole::Context),
697 negative_strategy: NegativeStrategy::WrongArticle,
698 weight: 1.0,
699 instruction: None,
700 allow_same_anchor_positive: false,
701 }
702 }
703
704 fn resolve_source_triplet_plan(&self, source: &str) -> (Vec<TripletRecipe>, bool) {
716 let mut recipes = self.configured_triplet_recipes_for_source(source).to_vec();
717 let mut auto_injected = false;
718 if self.should_auto_inject_chunk_pair_recipe(source, &recipes) {
719 recipes.push(Self::source_chunk_pair_recipe());
720 auto_injected = true;
721 }
722 (recipes, auto_injected)
723 }
724
725 #[cfg(test)]
726 fn triplet_recipes_for_source(&self, source: &str) -> Vec<TripletRecipe> {
727 self.resolve_source_triplet_plan(source).0
728 }
729
730 fn triplet_recipe_count_for_source(&self, source: &str) -> usize {
731 let (recipes, _auto_injected) = self.resolve_source_triplet_plan(source);
732 recipes.len()
733 }
734
735 fn text_recipes_for_source<'a>(&'a self, source: &str) -> &'a [TextRecipe] {
736 if self.using_config_text_recipes || self.using_config_triplet_recipes {
737 return &self.text_recipes;
738 }
739 self.source_text_recipes
740 .get(source)
741 .map(|recipes| recipes.as_slice())
742 .unwrap_or(&[])
743 }
744
745 fn recipe_order_weighted_shuffled(
756 &mut self,
757 weights: &[f32],
758 rng: &mut DeterministicRng,
759 ) -> Vec<usize> {
760 weighted_recipe_order(weights, rng)
761 }
762
763 fn recipe_order_weighted_cycled(
769 &mut self,
770 weights: &[f32],
771 rr_idx: usize,
772 rng: &mut DeterministicRng,
773 ) -> Vec<usize> {
774 let base = self.recipe_order_weighted_shuffled(weights, rng);
775 if base.is_empty() {
776 return base;
777 }
778 let start = rr_idx % base.len();
779 let mut order = Vec::with_capacity(base.len());
780 order.extend_from_slice(&base[start..]);
781 order.extend_from_slice(&base[..start]);
782 order
783 }
784
785 fn text_recipe_order_weighted_shuffled(
789 &mut self,
790 weights: &[f32],
791 rng: &mut DeterministicRng,
792 ) -> Vec<usize> {
793 weighted_recipe_order(weights, rng)
794 }
795
796 fn text_recipe_order_weighted_cycled(
797 &mut self,
798 weights: &[f32],
799 rr_idx: usize,
800 rng: &mut DeterministicRng,
801 ) -> Vec<usize> {
802 let base = self.text_recipe_order_weighted_shuffled(weights, rng);
803 if base.is_empty() {
804 return base;
805 }
806 let start = rr_idx % base.len();
807 let mut order = Vec::with_capacity(base.len());
808 order.extend_from_slice(&base[start..]);
809 order.extend_from_slice(&base[..start]);
810 order
811 }
812
813 fn allowed_target_splits(&self) -> Vec<SplitLabel> {
814 self.config.allowed_splits.clone()
815 }
816
817 fn ensure_split_allowed(&self, split: SplitLabel) -> Result<(), SamplerError> {
818 let allowed = self.allowed_target_splits();
819 if allowed.contains(&split) {
820 return Ok(());
821 }
822 Err(SamplerError::Configuration(format!(
823 "requested split {:?} is not in allowed_splits {:?}",
824 split, allowed
825 )))
826 }
827
828 fn ensure_split_has_records(&mut self, target_split: SplitLabel) -> Result<(), SamplerError> {
829 let records_by_split = self.records_by_split()?;
830 if records_by_split
831 .get(&target_split)
832 .map(|records| !records.is_empty())
833 .unwrap_or(false)
834 {
835 return Ok(());
836 }
837 Err(SamplerError::Exhausted(
838 "no records available for target split".into(),
839 ))
840 }
841
842 fn records_by_split(
843 &self,
844 ) -> Result<HashMap<SplitLabel, Vec<(RecordId, SourceId)>>, SamplerError> {
845 let mut map: HashMap<SplitLabel, Vec<(RecordId, SourceId)>> = HashMap::new();
846 let mut label_cache: HashMap<RecordId, SplitLabel> = HashMap::new();
847 for (chunk_id, record_id) in &self.chunk_index {
848 let Some(record) = self.records.get(record_id) else {
849 continue;
850 };
851 let label = if let Some(label) = label_cache.get(record_id) {
852 *label
853 } else {
854 let label = match self.split_store.label_for(record_id) {
855 Some(label) => label,
856 None => self.split_store.ensure(record_id.clone())?,
857 };
858 label_cache.insert(record_id.clone(), label);
859 label
860 };
861 map.entry(label)
862 .or_default()
863 .push((chunk_id.clone(), record.source.clone()));
864 }
865 Ok(map)
866 }
867
    /// Picks the next anchor record for `split`.
    ///
    /// With `source`: walks that source's shuffled index from its stored cursor, rotated by an
    /// offset derived from the epoch seed, the current cycle (`cursor / len`), and the source
    /// id. The first record labelled `split` is returned. The cursor advances once per position
    /// visited, including skipped ones, so a fruitless scan advances it by `indices.len()` and
    /// returns `None`. Reaching a multiple of the index length marks the source as wrapped,
    /// which may advance the source epoch (see `mark_source_wrapped`).
    ///
    /// Without `source`: draws ids from the epoch tracker for `split`, skipping ids that are no
    /// longer in the chunk index or loaded records.
    fn choose_anchor_record(
        &mut self,
        source: Option<&str>,
        split: SplitLabel,
    ) -> Option<Arc<DataRecord>> {
        if let Some(source) = source {
            let indices = self.source_record_indices.get(source)?;
            if indices.is_empty() {
                return None;
            }
            let mut cursor = *self.source_record_cursors.get(source).unwrap_or(&0);
            // NOTE(review): `cycle` is taken once from the starting cursor, so the rotation
            // offset does not change if the cursor crosses a cycle boundary mid-scan.
            let cycle = cursor / indices.len();
            let offset_seed = self.epoch_seed() ^ (cycle as u64);
            let offset = (stable_hash_str(offset_seed, source) as usize) % indices.len();
            let mut wrapped = false;
            let mut selected: Option<Arc<DataRecord>> = None;
            for _ in 0..indices.len() {
                let pos = (cursor % indices.len()).saturating_add(offset) % indices.len();
                let idx = indices[pos];
                cursor = cursor.saturating_add(1);
                if cursor.is_multiple_of(indices.len()) {
                    wrapped = true;
                }
                if let Some((_, record)) = self.records.get_index(idx) {
                    // Records from other splits are skipped but still consume a cursor step.
                    if self.split_store.label_for(&record.id) != Some(split) {
                        continue;
                    }
                    selected = Some(Arc::clone(record));
                    break;
                }
            }
            // Store the cursor before `mark_source_wrapped`, which may clear all cursors when
            // it advances the source epoch.
            self.source_record_cursors
                .insert(source.to_string(), cursor);
            if wrapped {
                self.mark_source_wrapped(source);
            }
            return selected;
        }
        while let Some(chunk_id) = self.epoch_tracker.next_record(split) {
            if let Some(record_id) = self.chunk_index.get(&chunk_id)
                && let Some(record) = self.records.get(record_id)
            {
                return Some(Arc::clone(record));
            }
        }
        None
    }
915
916 fn save_sampler_state(&mut self, save_to: Option<&Path>) -> Result<(), SamplerError> {
917 if self.epoch_tracker.is_enabled() {
918 self.epoch_tracker.persist()?;
919 }
920 self.persist_source_state(save_to)?;
921 Ok(())
922 }
923
924 fn mark_source_wrapped(&mut self, source: &str) {
925 self.source_wrapped.insert(source.to_string(), true);
926 if self.source_order.is_empty() {
927 return;
928 }
929 let all_wrapped = self
930 .source_order
931 .iter()
932 .all(|name| self.source_wrapped.get(name).copied().unwrap_or(false));
933 if all_wrapped {
934 self.advance_source_epoch();
935 }
936 }
937
938 fn advance_source_epoch(&mut self) {
939 self.source_epoch = self.source_epoch.saturating_add(1);
940 self.ingestion.set_source_epoch(self.source_epoch);
941 self.source_record_cursors.clear();
942 self.source_cycle_idx = 0;
943 for source in &self.source_order {
944 self.source_wrapped.insert(source.clone(), false);
945 }
946 let _ = self.rebuild_source_index();
947 self.source_state_dirty = self.source_order.len() > 1;
948 }
949
950 fn select_temporal_neighbor(
951 &'_ self,
952 record: &DataRecord,
953 offset_days: i32,
954 ) -> Option<Arc<DataRecord>> {
955 let target = record.created_at + Duration::days(offset_days.into());
956 let key = record.taxonomy.first().cloned();
957 let record_split = self.split_store.label_for(&record.id)?;
958 self.records
959 .values()
960 .filter(|candidate| {
961 candidate.id != record.id
962 && self
963 .split_store
964 .label_for(&candidate.id)
965 .map(|label| label == record_split)
966 .unwrap_or(false)
967 && (candidate.source == record.source
968 || key
969 .as_ref()
970 .zip(candidate.taxonomy.first())
971 .map(|(a, b)| a == b)
972 .unwrap_or(false))
973 })
974 .min_by_key(|candidate| (candidate.created_at - target).num_seconds().abs())
975 .cloned()
976 }
977
978 fn select_negative_record(
979 &self,
980 anchor_record: &DataRecord,
981 strategy: &NegativeStrategy,
982 anchor_query_text: Option<&str>,
983 rng: &mut dyn rand::RngCore,
984 ) -> Option<(Arc<DataRecord>, bool)> {
985 let anchor_split = self.split_store.label_for(&anchor_record.id)?;
986
987 let in_anchor_split = |candidate: &DataRecord| {
988 self.split_store
989 .label_for(&candidate.id)
990 .map(|label| label == anchor_split)
991 .unwrap_or(false)
992 };
993
994 match strategy {
995 NegativeStrategy::WrongArticle => {
996 let anchor_date =
997 taxonomy_value(anchor_record, META_FIELD_DATE).map(|d| d.to_string());
998 let mut same_date: Vec<Arc<DataRecord>> = self
999 .records
1000 .values()
1001 .filter(|candidate| {
1002 candidate.source == anchor_record.source
1003 && candidate.id != anchor_record.id
1004 && in_anchor_split(candidate)
1005 })
1006 .filter(|candidate| {
1007 anchor_date
1008 .as_deref()
1009 .zip(taxonomy_value(candidate, META_FIELD_DATE))
1010 .map(|(a, b)| a == b)
1011 .unwrap_or(false)
1012 })
1013 .cloned()
1014 .collect();
1015 if same_date.is_empty() {
1016 same_date = self
1017 .records
1018 .values()
1019 .filter(|candidate| {
1020 candidate.source == anchor_record.source
1021 && candidate.id != anchor_record.id
1022 && in_anchor_split(candidate)
1023 })
1024 .cloned()
1025 .collect();
1026 }
1027 if !same_date.is_empty() {
1028 return self.negative_backend.choose_negative(
1029 anchor_record,
1030 anchor_split,
1031 same_date,
1032 false,
1033 anchor_query_text,
1034 rng,
1035 );
1036 }
1037 let pool = self
1038 .records
1039 .values()
1040 .filter(|candidate| {
1041 candidate.id != anchor_record.id && in_anchor_split(candidate)
1042 })
1043 .cloned()
1044 .collect::<Vec<_>>();
1045 self.negative_backend.choose_negative(
1046 anchor_record,
1047 anchor_split,
1048 pool,
1049 true,
1050 anchor_query_text,
1051 rng,
1052 )
1053 }
1054 NegativeStrategy::WrongPublicationDate => {
1055 let anchor_date =
1056 taxonomy_value(anchor_record, META_FIELD_DATE).map(|d| d.to_string());
1057 let pool: Vec<Arc<DataRecord>> = self
1058 .records
1059 .values()
1060 .filter(|candidate| {
1061 candidate.source == anchor_record.source
1062 && candidate.id != anchor_record.id
1063 && in_anchor_split(candidate)
1064 })
1065 .filter(|candidate| {
1066 match (
1067 anchor_date.as_deref(),
1068 taxonomy_value(candidate, META_FIELD_DATE),
1069 ) {
1070 (Some(anchor), Some(candidate_date)) => anchor != candidate_date,
1071 (Some(_), None) => true,
1072 (None, Some(_)) => true,
1073 (None, None) => false,
1074 }
1075 })
1076 .cloned()
1077 .collect();
1078 if pool.is_empty() {
1079 let fallback_pool = self
1082 .records
1083 .values()
1084 .filter(|candidate| {
1085 candidate.id != anchor_record.id && in_anchor_split(candidate)
1086 })
1087 .cloned()
1088 .collect::<Vec<_>>();
1089
1090 return self.negative_backend.choose_negative(
1091 anchor_record,
1092 anchor_split,
1093 fallback_pool,
1094 true,
1095 anchor_query_text,
1096 rng,
1097 );
1098 }
1099
1100 self.negative_backend.choose_negative(
1101 anchor_record,
1102 anchor_split,
1103 pool,
1104 false,
1105 anchor_query_text,
1106 rng,
1107 )
1108 }
1109 NegativeStrategy::QuestionAnswerMismatch => {
1110 let pool: Vec<Arc<DataRecord>> = self
1111 .records
1112 .values()
1113 .filter(|candidate| {
1114 candidate.source == anchor_record.source
1115 && candidate.id != anchor_record.id
1116 && in_anchor_split(candidate)
1117 })
1118 .cloned()
1119 .collect();
1120 if pool.is_empty() {
1121 let fallback_pool = self
1124 .records
1125 .values()
1126 .filter(|candidate| {
1127 candidate.id != anchor_record.id && in_anchor_split(candidate)
1128 })
1129 .cloned()
1130 .collect::<Vec<_>>();
1131
1132 return self.negative_backend.choose_negative(
1133 anchor_record,
1134 anchor_split,
1135 fallback_pool,
1136 true,
1137 anchor_query_text,
1138 rng,
1139 );
1140 }
1141
1142 self.negative_backend.choose_negative(
1143 anchor_record,
1144 anchor_split,
1145 pool,
1146 false,
1147 anchor_query_text,
1148 rng,
1149 )
1150 }
1151 }
1152 }
1153
1154 fn is_auto_chunk_pair_recipe(recipe: &TripletRecipe) -> bool {
1156 recipe.name.as_ref() == AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME
1157 }
1158
1159 fn select_anchor_positive_pair(
1163 &mut self,
1164 record: &DataRecord,
1165 anchor_selector: &Selector,
1166 positive_selector: &Selector,
1167 enforce_window_pair: bool,
1168 ) -> Option<(RecordChunk, RecordChunk)> {
1169 let mut anchor_chunk = self.select_chunk(record, anchor_selector)?;
1170 let mut positive_chunk = self.select_chunk(record, positive_selector)?;
1171 if anchor_selector == positive_selector {
1172 let mut retries = 0usize;
1173 while !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair)
1174 && retries < SAME_SELECTOR_PAIR_RETRY_LIMIT
1175 {
1176 let Some(redraw_anchor) = self.select_chunk(record, anchor_selector) else {
1177 break;
1178 };
1179 let Some(redraw_positive) = self.select_chunk(record, positive_selector) else {
1180 break;
1181 };
1182 anchor_chunk = redraw_anchor;
1183 positive_chunk = redraw_positive;
1184 retries += 1;
1185 }
1186 if !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair) {
1187 return None;
1188 }
1189 }
1190 Some((anchor_chunk, positive_chunk))
1191 }
1192
1193 fn select_distinct_window_pair_for_auto_recipe(
1203 &mut self,
1204 recipe: &TripletRecipe,
1205 record: &DataRecord,
1206 ) -> Option<(RecordChunk, RecordChunk)> {
1207 if recipe.anchor != recipe.positive_selector {
1208 return None;
1209 }
1210 self.select_anchor_positive_pair(record, &recipe.anchor, &recipe.positive_selector, true)
1211 }
1212
1213 fn record_has_at_least_two_window_chunks_for_selector(
1218 &self,
1219 record: &DataRecord,
1220 selector: &Selector,
1221 ) -> bool {
1222 let section_indices: Vec<usize> = match selector {
1223 Selector::Role(role) => record
1224 .sections
1225 .iter()
1226 .enumerate()
1227 .filter(|(_, section)| roles_match(role, §ion.role))
1228 .map(|(idx, _)| idx)
1229 .collect(),
1230 Selector::Paragraph(idx) => {
1231 if *idx < record.sections.len() {
1232 vec![*idx]
1233 } else {
1234 Vec::new()
1235 }
1236 }
1237 Selector::Random => (0..record.sections.len()).collect(),
1238 Selector::TemporalOffset(_) => return false,
1239 };
1240
1241 let mut window_count = 0usize;
1242 for section_idx in section_indices {
1243 let Some(section) = record.sections.get(section_idx) else {
1244 continue;
1245 };
1246 let chunks = self.materialize_chunks(record, section_idx, section);
1247 window_count += chunks
1248 .iter()
1249 .filter(|chunk| matches!(chunk.view, ChunkView::Window { .. }))
1250 .count();
1251 if window_count >= 2 {
1252 return true;
1253 }
1254 }
1255 false
1256 }
1257
1258 fn build_triplet_with_selector_pair_policy(
1260 &mut self,
1261 recipe: &TripletRecipe,
1262 record: &DataRecord,
1263 enforce_window_pair: bool,
1264 rng: &mut DeterministicRng,
1265 ) -> Option<SampleTriplet> {
1266 let (mut anchor_chunk, mut positive_chunk) = self.select_anchor_positive_pair(
1267 record,
1268 &recipe.anchor,
1269 &recipe.positive_selector,
1270 enforce_window_pair,
1271 )?;
1272 let anchor_raw_text = anchor_chunk.text.clone();
1273 self.decorate_chunk(record, &mut anchor_chunk, rng);
1274 self.decorate_chunk(record, &mut positive_chunk, rng);
1275 self.finalize_triplet_with_negative(
1278 recipe,
1279 record,
1280 anchor_chunk,
1281 positive_chunk,
1282 &anchor_raw_text,
1283 rng,
1284 )
1285 }
1286
1287 fn make_auto_chunk_pair_triplet_with_anchor(
1297 &mut self,
1298 recipe: &TripletRecipe,
1299 record: &DataRecord,
1300 rng: &mut DeterministicRng,
1301 ) -> Option<SampleTriplet> {
1302 if !self.record_has_at_least_two_window_chunks_for_selector(record, &recipe.anchor) {
1303 return None;
1304 }
1305 let (mut anchor_chunk, mut positive_chunk) =
1306 self.select_distinct_window_pair_for_auto_recipe(recipe, record)?;
1307 let anchor_raw_text = anchor_chunk.text.clone();
1308 self.decorate_chunk(record, &mut anchor_chunk, rng);
1309 self.decorate_chunk(record, &mut positive_chunk, rng);
1310 self.finalize_triplet_with_negative(
1311 recipe,
1312 record,
1313 anchor_chunk,
1314 positive_chunk,
1315 &anchor_raw_text,
1316 rng,
1317 )
1318 }
1319
1320 fn make_standard_triplet_with_anchor(
1321 &mut self,
1322 recipe: &TripletRecipe,
1323 record: &DataRecord,
1324 rng: &mut DeterministicRng,
1325 ) -> Option<SampleTriplet> {
1326 self.build_triplet_with_selector_pair_policy(recipe, record, false, rng)
1327 }
1328
1329 fn finalize_triplet_with_negative(
1348 &mut self,
1349 recipe: &TripletRecipe,
1350 record: &DataRecord,
1351 anchor_chunk: RecordChunk,
1352 positive_chunk: RecordChunk,
1353 anchor_raw_text: &str,
1354 rng: &mut DeterministicRng,
1355 ) -> Option<SampleTriplet> {
1356 let (negative_record, fallback_used) = self.select_negative_record(
1357 record,
1358 &recipe.negative_strategy,
1359 Some(anchor_raw_text),
1360 rng,
1361 )?;
1362 let mut negative_chunk = self.select_chunk(&negative_record, &recipe.negative_selector)?;
1363 self.decorate_chunk(&negative_record, &mut negative_chunk, rng);
1364
1365 let (anchor_chunk, positive_chunk) = if rng.next_u64() & ANCHOR_POSITIVE_SWAP_MASK == 0 {
1367 (positive_chunk, anchor_chunk)
1368 } else {
1369 (anchor_chunk, positive_chunk)
1370 };
1371
1372 if (!recipe.allow_same_anchor_positive && anchor_chunk.text == positive_chunk.text)
1383 || negative_chunk.text == positive_chunk.text
1384 || negative_chunk.text == anchor_chunk.text
1385 {
1386 return None;
1387 }
1388
1389 let chunk_weight =
1390 self.triplet_chunk_weight(recipe, &anchor_chunk, &positive_chunk, &negative_chunk);
1391 let weight = recipe.weight * chunk_weight;
1392 let recipe_name = if fallback_used {
1393 format!("{}_fallback_same_split", recipe.name)
1394 } else {
1395 recipe.name.to_string()
1396 };
1397 Some(SampleTriplet {
1398 recipe: recipe_name,
1399 anchor: anchor_chunk,
1400 positive: positive_chunk,
1401 negative: negative_chunk,
1402 weight,
1403 instruction: recipe.instruction.as_ref().map(|s| s.to_string()),
1404 })
1405 }
1406
1407 fn make_triplet_with_anchor(
1408 &mut self,
1409 recipe: &TripletRecipe,
1410 record: &DataRecord,
1411 rng: &mut DeterministicRng,
1412 ) -> Option<SampleTriplet> {
1413 if Self::is_auto_chunk_pair_recipe(recipe) {
1414 return self.make_auto_chunk_pair_triplet_with_anchor(recipe, record, rng);
1415 }
1416 self.make_standard_triplet_with_anchor(recipe, record, rng)
1417 }
1418
1419 fn make_text_sample_for_split(
1420 &mut self,
1421 recipe: &TextRecipe,
1422 source: Option<&str>,
1423 split: SplitLabel,
1424 rng: &mut DeterministicRng,
1425 ) -> Option<TextSample> {
1426 let record = self.choose_anchor_record(source, split)?;
1427 let mut chunk = self.select_chunk(&record, &recipe.selector)?;
1428 self.decorate_chunk(&record, &mut chunk, rng);
1429 let weight = recipe.weight * self.chunk_weight(&chunk);
1430 Some(TextSample {
1431 recipe: recipe.name.to_string(),
1432 chunk,
1433 weight,
1434 instruction: recipe.instruction.as_ref().map(|s| s.to_string()),
1435 })
1436 }
1437
1438 fn chunk_weight(&self, chunk: &RecordChunk) -> f32 {
1439 chunk_weight(&self.config.chunking, chunk)
1440 }
1441
1442 fn triplet_chunk_weight(
1443 &self,
1444 recipe: &TripletRecipe,
1445 anchor: &RecordChunk,
1446 positive: &RecordChunk,
1447 negative: &RecordChunk,
1448 ) -> f32 {
1449 let floor = self.config.chunking.chunk_weight_floor;
1450 let negative_weight = negative.quality.trust.clamp(0.0, 1.0).max(floor);
1451 if Self::is_auto_chunk_pair_recipe(recipe) {
1452 let pair_trust = ((anchor.quality.trust.clamp(0.0, 1.0)
1455 + positive.quality.trust.clamp(0.0, 1.0))
1456 / 2.0)
1457 .clamp(0.0, 1.0);
1458 let pair_weight = (chunk_proximity_score(anchor, positive) * pair_trust).max(floor);
1459 return (pair_weight + pair_weight + negative_weight) / 3.0;
1461 }
1462 let pair_proximity = chunk_proximity_score(anchor, positive);
1465 let anchor_weight = (self.chunk_weight(anchor) * pair_proximity).max(floor);
1466 let positive_weight = (self.chunk_weight(positive) * pair_proximity).max(floor);
1467 (anchor_weight + positive_weight + negative_weight) / 3.0
1468 }
1469
1470 fn decorate_chunk(
1471 &mut self,
1472 record: &DataRecord,
1473 chunk: &mut RecordChunk,
1474 rng: &mut DeterministicRng,
1475 ) {
1476 chunk.kvp_meta = record
1477 .meta_prefix
1478 .as_ref()
1479 .map(|s| s.all_metadata())
1480 .unwrap_or_default();
1481 if let Some(spec) = record.meta_prefix.as_ref()
1482 && let Some(prefix) = spec.sample(rng)
1483 {
1484 let body_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&chunk.text);
1485 let prefix_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&prefix);
1486 let total_tokens = prefix_tokens.len() + body_tokens.len();
1487 let max_window = self.config.chunking.max_window_tokens;
1488 if max_window > 0 && total_tokens > max_window {
1489 if prefix_tokens.len() >= max_window {
1490 chunk.text = prefix_tokens
1491 .into_iter()
1492 .take(max_window)
1493 .collect::<Vec<_>>()
1494 .join(" ");
1495 chunk.tokens_estimate = max_window;
1496 } else {
1497 let remaining = max_window - prefix_tokens.len();
1498 let trimmed_body: Vec<&str> = body_tokens.into_iter().take(remaining).collect();
1499 chunk.text =
1500 format!("{}{}{}", prefix, platform_newline(), trimmed_body.join(" "));
1501 chunk.tokens_estimate = max_window;
1502 }
1503 } else {
1504 chunk.text = format!("{}{}{}", prefix, platform_newline(), chunk.text);
1505 chunk.tokens_estimate = total_tokens;
1506 }
1507 }
1508 }
1509
1510 fn select_chunk_parallel(
1514 &self,
1515 record: &DataRecord,
1516 selector: &Selector,
1517 rng: &mut DeterministicRng,
1518 ) -> Option<RecordChunk> {
1519 match selector {
1520 Selector::Role(role) => self.select_role_parallel(record, role, rng),
1521 Selector::Paragraph(idx) => record.sections.get(*idx).and_then(|section| {
1522 let pool = self.materialize_chunks(record, *idx, section);
1523 if pool.is_empty() {
1524 return None;
1525 }
1526 let i = rng.random_range(0..pool.len());
1527 pool.into_iter().nth(i)
1528 }),
1529 Selector::TemporalOffset(offset) => self
1530 .select_temporal_neighbor(record, *offset)
1531 .and_then(|neighbor| {
1532 self.select_role_parallel(&neighbor, &SectionRole::Context, rng)
1533 }),
1534 Selector::Random => {
1535 if record.sections.is_empty() {
1536 return None;
1537 }
1538 let idx = rng.random_range(0..record.sections.len());
1539 record.sections.get(idx).and_then(|section| {
1540 let pool = self.materialize_chunks(record, idx, section);
1541 if pool.is_empty() {
1542 return None;
1543 }
1544 let i = rng.random_range(0..pool.len());
1545 pool.into_iter().nth(i)
1546 })
1547 }
1548 }
1549 }
1550
1551 fn select_role_parallel(
1553 &self,
1554 record: &DataRecord,
1555 role: &SectionRole,
1556 rng: &mut DeterministicRng,
1557 ) -> Option<RecordChunk> {
1558 let indices: Vec<usize> = record
1559 .sections
1560 .iter()
1561 .enumerate()
1562 .filter(|(_, s)| roles_match(role, &s.role))
1563 .map(|(i, _)| i)
1564 .collect();
1565 if indices.is_empty() {
1566 return None;
1567 }
1568 let start = rng.random_range(0..indices.len());
1569 for offset in 0..indices.len() {
1570 let section_idx = indices[(start + offset) % indices.len()];
1571 let section = &record.sections[section_idx];
1572 let pool = self.materialize_chunks(record, section_idx, section);
1573 if !pool.is_empty() {
1574 let i = rng.random_range(0..pool.len());
1575 return pool.into_iter().nth(i);
1576 }
1577 }
1578 None
1579 }
1580
1581 fn decorate_chunk_parallel(
1583 &self,
1584 record: &DataRecord,
1585 chunk: &mut RecordChunk,
1586 rng: &mut DeterministicRng,
1587 ) {
1588 chunk.kvp_meta = record
1589 .meta_prefix
1590 .as_ref()
1591 .map(|s| s.all_metadata())
1592 .unwrap_or_default();
1593 if let Some(spec) = record.meta_prefix.as_ref()
1594 && let Some(prefix) = spec.sample(rng)
1595 {
1596 let body_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&chunk.text);
1597 let prefix_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&prefix);
1598 let total_tokens = prefix_tokens.len() + body_tokens.len();
1599 let max_window = self.config.chunking.max_window_tokens;
1600 if max_window > 0 && total_tokens > max_window {
1601 if prefix_tokens.len() >= max_window {
1602 chunk.text = prefix_tokens
1603 .into_iter()
1604 .take(max_window)
1605 .collect::<Vec<_>>()
1606 .join(" ");
1607 chunk.tokens_estimate = max_window;
1608 } else {
1609 let remaining = max_window - prefix_tokens.len();
1610 let trimmed_body: Vec<&str> = body_tokens.into_iter().take(remaining).collect();
1611 chunk.text =
1612 format!("{}{}{}", prefix, platform_newline(), trimmed_body.join(" "));
1613 chunk.tokens_estimate = max_window;
1614 }
1615 } else {
1616 chunk.text = format!("{}{}{}", prefix, platform_newline(), chunk.text);
1617 chunk.tokens_estimate = total_tokens;
1618 }
1619 }
1620 }
1621
1622 fn select_anchor_positive_parallel(
1624 &self,
1625 record: &DataRecord,
1626 anchor_selector: &Selector,
1627 positive_selector: &Selector,
1628 enforce_window_pair: bool,
1629 rng: &mut DeterministicRng,
1630 ) -> Option<(RecordChunk, RecordChunk)> {
1631 let anchor_chunk = self.select_chunk_parallel(record, anchor_selector, rng)?;
1632 let mut positive_chunk = self.select_chunk_parallel(record, positive_selector, rng)?;
1633 if anchor_selector == positive_selector {
1634 let mut retries = 0usize;
1635 while !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair)
1636 && retries < SAME_SELECTOR_PAIR_RETRY_LIMIT
1637 {
1638 positive_chunk = self.select_chunk_parallel(record, positive_selector, rng)?;
1639 retries += 1;
1640 }
1641 if !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair) {
1642 return None;
1643 }
1644 }
1645 Some((anchor_chunk, positive_chunk))
1647 }
1648
1649 fn select_anchor_positive_for_recipe(
1652 &self,
1653 recipe: &TripletRecipe,
1654 anchor_record: &DataRecord,
1655 rng: &mut DeterministicRng,
1656 ) -> Option<(RecordChunk, RecordChunk, String)> {
1657 if Self::is_auto_chunk_pair_recipe(recipe) {
1658 if !self
1659 .record_has_at_least_two_window_chunks_for_selector(anchor_record, &recipe.anchor)
1660 {
1661 return None;
1662 }
1663 let mut anchor_chunk =
1664 self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
1665 let mut positive_chunk =
1666 self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
1667 let mut tries = 0usize;
1668 while !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, true) {
1669 tries += 1;
1670 if tries >= SAME_SELECTOR_PAIR_RETRY_LIMIT {
1671 return None;
1672 }
1673 anchor_chunk = self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
1674 positive_chunk = self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
1675 }
1676 let anchor_raw_text = anchor_chunk.text.clone();
1677 self.decorate_chunk_parallel(anchor_record, &mut anchor_chunk, rng);
1678 self.decorate_chunk_parallel(anchor_record, &mut positive_chunk, rng);
1679 return Some((anchor_chunk, positive_chunk, anchor_raw_text));
1680 }
1681 let (mut anchor_chunk, mut positive_chunk) = self.select_anchor_positive_parallel(
1682 anchor_record,
1683 &recipe.anchor,
1684 &recipe.positive_selector,
1685 false,
1686 rng,
1687 )?;
1688 let anchor_raw_text = anchor_chunk.text.clone();
1689 self.decorate_chunk_parallel(anchor_record, &mut anchor_chunk, rng);
1690 self.decorate_chunk_parallel(anchor_record, &mut positive_chunk, rng);
1691 Some((anchor_chunk, positive_chunk, anchor_raw_text))
1692 }
1693
1694 fn select_chunk(&mut self, record: &DataRecord, selector: &Selector) -> Option<RecordChunk> {
1695 match selector {
1696 Selector::Role(role) => self.select_by_role(record, role),
1697 Selector::Paragraph(idx) => record.sections.get(*idx).and_then(|section| {
1698 let pool = self.materialize_chunks(record, *idx, section);
1699 self.next_chunk_from_pool(&record.id, *idx, pool)
1700 }),
1701 Selector::TemporalOffset(offset) => self
1702 .select_temporal_neighbor(record, *offset)
1703 .and_then(|neighbor| self.select_by_role(&neighbor, &SectionRole::Context)),
1704 Selector::Random => {
1705 if record.sections.is_empty() {
1706 return None;
1707 }
1708 let idx = self.rng.random_range(0..record.sections.len());
1709 record.sections.get(idx).and_then(|section| {
1710 let pool = self.materialize_chunks(record, idx, section);
1711 self.next_chunk_from_pool(&record.id, idx, pool)
1712 })
1713 }
1714 }
1715 }
1716
1717 fn select_by_role(&mut self, record: &DataRecord, role: &SectionRole) -> Option<RecordChunk> {
1718 let indices: Vec<usize> = record
1719 .sections
1720 .iter()
1721 .enumerate()
1722 .filter(|(_, section)| roles_match(role, §ion.role))
1723 .map(|(idx, _)| idx)
1724 .collect();
1725 if indices.is_empty() {
1726 return None;
1727 }
1728 let key = role_cursor_key(&record.id, role);
1729 let start_offset = self
1730 .role_cursors
1731 .get(&key)
1732 .and_then(|last_idx| indices.iter().position(|idx| idx == last_idx))
1733 .map(|pos| (pos + 1) % indices.len())
1734 .unwrap_or_else(|| {
1735 let seed_key = format!("{}::{}", key.0, key.1);
1739 (stable_hash_str(self.epoch_seed(), &seed_key) as usize) % indices.len()
1740 });
1741 for offset in 0..indices.len() {
1742 let section_idx = indices[(start_offset + offset) % indices.len()];
1743 let section = &record.sections[section_idx];
1744 let pool = self.materialize_chunks(record, section_idx, section);
1745 if let Some(chunk) = self.next_chunk_from_pool(&record.id, section_idx, pool) {
1746 self.role_cursors.insert(key.clone(), section_idx);
1747 return Some(chunk);
1748 }
1749 }
1750 None
1751 }
1752
1753 fn materialize_chunks(
1766 &self,
1767 record: &DataRecord,
1768 section_idx: usize,
1769 section: &RecordSection,
1770 ) -> Vec<RecordChunk> {
1771 self.chunker
1772 .materialize(&self.config.chunking, record, section_idx, section)
1773 }
1774
1775 fn build_derived_text_recipes(recipes: &[TripletRecipe]) -> Vec<TextRecipe> {
1776 let mut derived = Vec::new();
1777 for recipe in recipes {
1778 let base = recipe.name.as_ref();
1779 derived.push(TextRecipe {
1780 name: Cow::Owned(format!("{base}_anchor")),
1781 selector: recipe.anchor.clone(),
1782 weight: recipe.weight,
1783 instruction: None,
1784 });
1785 derived.push(TextRecipe {
1786 name: Cow::Owned(format!("{base}_positive")),
1787 selector: recipe.positive_selector.clone(),
1788 weight: recipe.weight,
1789 instruction: None,
1790 });
1791 derived.push(TextRecipe {
1792 name: Cow::Owned(format!("{base}_negative")),
1793 selector: recipe.negative_selector.clone(),
1794 weight: recipe.weight,
1795 instruction: None,
1796 });
1797 }
1798 derived
1799 }
1800
1801 fn record_has_long_anchor_or_context_section(&self, record: &DataRecord) -> bool {
1804 let window = self.config.chunking.max_window_tokens;
1805 if window == 0 {
1806 return false;
1807 }
1808 record.sections.iter().any(|section| {
1809 matches!(section.role, SectionRole::Anchor | SectionRole::Context)
1810 && WhitespaceTokenizer.token_count(§ion.text) > window
1811 })
1812 }
1813
1814 fn sync_records_from_cache(&mut self) -> Result<(), SamplerError> {
1815 let mut snapshot = self.ingestion.all_records_snapshot();
1816 snapshot.sort_by(|a, b| a.id.cmp(&b.id));
1817 self.records.clear();
1818 self.sources_with_long_sections.clear();
1819 self.negative_backend.on_sync_start();
1821 for record in snapshot {
1822 if self.split_store.label_for(&record.id).is_none() {
1823 self.split_store.ensure(record.id.clone())?;
1824 }
1825 if self.record_has_long_anchor_or_context_section(&record) {
1826 self.sources_with_long_sections
1828 .insert(record.source.clone());
1829 }
1830 self.records.insert(record.id.clone(), Arc::new(record));
1831 }
1832 self.prune_cursor_state();
1833 self.rebuild_chunk_index();
1834 self.rebuild_source_index()?;
1835 Ok(())
1836 }
1837
1838 fn ingest_internal_for_split(&mut self, target_split: SplitLabel) -> Result<(), SamplerError> {
1839 if !self.ingestion.has_sources() {
1840 return Ok(());
1841 }
1842 if !self.ingestion_cursors_loaded {
1843 if let Some(state) = self.split_store.load_sampler_state()? {
1844 self.ingestion.load_cursors(&state.source_stream_cursors);
1845 self.ingestion.set_source_epoch(state.source_epoch);
1846 }
1847 self.ingestion_cursors_loaded = true;
1848 }
1849 if self.ingestion.all_caches_empty() {
1850 self.ingestion.refresh_all();
1851 } else {
1852 self.ingestion.advance(self.config.batch_size);
1853 }
1854 let observed = self.ingestion.total_ingest_count();
1855 if observed == self.last_observed_ingest && !self.records.is_empty() {
1856 return Ok(());
1857 }
1858 self.last_observed_ingest = observed;
1859 self.sync_records_from_cache()?;
1860 let max_window_tokens = self.config.chunking.max_window_tokens;
1861 self.negative_backend.on_records_refreshed(
1862 &self.records,
1863 max_window_tokens,
1864 &|id| self.split_store.label_for(id),
1865 self.ingestion.last_refreshed_sources(),
1866 );
1867 self.epoch_tracker.ensure_loaded()?;
1870 let records_by_split = self.records_by_split()?;
1871 self.epoch_tracker
1872 .reconcile(target_split, &records_by_split);
1873 self.ensure_source_state()?;
1874 Ok(())
1875 }
1876
1877 #[cfg(test)]
1878 fn ingest_internal(&mut self, split: SplitLabel) -> Result<(), SamplerError> {
1879 self.ingest_internal_for_split(split)
1880 }
1881
1882 fn ingest_internal_with_weights_for_split(
1883 &mut self,
1884 target_split: SplitLabel,
1885 weights: &HashMap<SourceId, f32>,
1886 ) -> Result<(), SamplerError> {
1887 if !self.ingestion.has_sources() {
1888 return Ok(());
1889 }
1890 if !self.ingestion_cursors_loaded {
1891 if let Some(state) = self.split_store.load_sampler_state()? {
1892 self.ingestion.load_cursors(&state.source_stream_cursors);
1893 self.ingestion.set_source_epoch(state.source_epoch);
1894 }
1895 self.ingestion_cursors_loaded = true;
1896 }
1897 if self.ingestion.all_caches_empty() {
1898 self.ingestion.refresh_all_with_weights(weights)?;
1899 } else {
1900 self.ingestion
1901 .advance_with_weights(self.config.batch_size, weights)?;
1902 }
1903 let observed = self.ingestion.total_ingest_count();
1904 if observed == self.last_observed_ingest && !self.records.is_empty() {
1905 return Ok(());
1906 }
1907 self.last_observed_ingest = observed;
1908 self.sync_records_from_cache()?;
1909 let max_window_tokens = self.config.chunking.max_window_tokens;
1910 self.negative_backend.on_records_refreshed(
1911 &self.records,
1912 max_window_tokens,
1913 &|id| self.split_store.label_for(id),
1914 self.ingestion.last_refreshed_sources(),
1915 );
1916 self.epoch_tracker.ensure_loaded()?;
1917 let records_by_split = self.records_by_split()?;
1918 self.epoch_tracker
1919 .reconcile(target_split, &records_by_split);
1920 self.ensure_source_state()?;
1921 Ok(())
1922 }
1923
1924 fn ingest_with_weights_fallback(
1929 &mut self,
1930 target_split: SplitLabel,
1931 weights: Option<&HashMap<SourceId, f32>>,
1932 ) -> Result<(), SamplerError> {
1933 match weights {
1934 Some(weights)
1935 if !weights.is_empty()
1936 && !weights
1937 .values()
1938 .all(|&w| w == *weights.values().next().unwrap()) =>
1939 {
1940 self.ingest_internal_with_weights_for_split(target_split, weights)?
1941 }
1942 _ => self.ingest_internal_for_split(target_split)?,
1943 }
1944 Ok(())
1945 }
1946
1947 fn force_ingest_refresh_with_weights_for_split(
1948 &mut self,
1949 target_split: SplitLabel,
1950 weights: &HashMap<SourceId, f32>,
1951 ) -> Result<(), SamplerError> {
1952 if !self.ingestion.has_sources() {
1953 return Ok(());
1954 }
1955 if !self.ingestion_cursors_loaded {
1956 if let Some(state) = self.split_store.load_sampler_state()? {
1957 self.ingestion.load_cursors(&state.source_stream_cursors);
1958 self.ingestion.set_source_epoch(state.source_epoch);
1959 }
1960 self.ingestion_cursors_loaded = true;
1961 }
1962 self.ingestion.force_refresh_all_with_weights(weights)?;
1963 self.last_observed_ingest = self.ingestion.total_ingest_count();
1964 self.sync_records_from_cache()?;
1965 let max_window_tokens = self.config.chunking.max_window_tokens;
1966 self.negative_backend.on_records_refreshed(
1967 &self.records,
1968 max_window_tokens,
1969 &|id| self.split_store.label_for(id),
1970 self.ingestion.last_refreshed_sources(),
1971 );
1972 self.epoch_tracker.ensure_loaded()?;
1973 let records_by_split = self.records_by_split()?;
1974 self.epoch_tracker
1975 .reconcile(target_split, &records_by_split);
1976 self.ensure_source_state()?;
1977 Ok(())
1978 }
1979
1980 fn sample_source_triplet_candidate(
1990 &mut self,
1991 source: &str,
1992 target_split: SplitLabel,
1993 recipe_orders: &mut HashMap<RecipeKey, Vec<usize>>,
1994 recipe_positions: &mut HashMap<RecipeKey, usize>,
1995 rng: &mut DeterministicRng,
1996 ) -> (Option<(TripletRecipe, SampleTriplet)>, usize) {
1997 let (recipes, _auto_injected) = self.resolve_source_triplet_plan(source);
2000 if recipes.is_empty() {
2001 return (None, 0);
2002 }
2003 if !recipe_orders.contains_key(source) {
2004 let recipe_weights: Vec<f32> = recipes.iter().map(|r| r.weight).collect();
2005 let order =
2006 self.recipe_order_weighted_cycled(&recipe_weights, self.triplet_recipe_rr_idx, rng);
2007 recipe_orders.insert(source.to_string(), order);
2008 }
2009 let order = recipe_orders
2010 .get(source)
2011 .expect("recipe order missing for source");
2012 let pos = recipe_positions.entry(source.to_string()).or_insert(0);
2013 let Some(anchor) = self.choose_anchor_record(Some(source), target_split) else {
2014 return (None, 0);
2015 };
2016
2017 let mut attempts = 0usize;
2018 for offset in 0..order.len() {
2019 let idx = order[(*pos + offset) % order.len()];
2020 attempts = attempts.saturating_add(1);
2021 let recipe = recipes[idx].clone();
2022 if let Some(sample) = self.make_triplet_with_anchor(&recipe, &anchor, rng) {
2026 *pos = (*pos + offset + 1) % order.len();
2027 return (Some((recipe, sample)), attempts);
2028 }
2029 }
2030
2031 (None, attempts)
2032 }
2033
2034 fn next_pair_batch_inner_with_weights(
2035 &mut self,
2036 target_split: SplitLabel,
2037 weights: Option<&HashMap<SourceId, f32>>,
2038 ) -> Result<SampleBatch, SamplerError> {
2039 if let Some(weights) = weights {
2040 if weights.is_empty()
2041 || weights
2042 .values()
2043 .all(|&w| w == *weights.values().next().unwrap())
2044 {
2045 self.ingest_internal_for_split(target_split)?;
2046 } else {
2047 self.ingest_internal_with_weights_for_split(target_split, weights)?;
2048 }
2049 } else {
2050 self.ingest_internal_for_split(target_split)?;
2051 }
2052 self.ensure_split_has_records(target_split)?;
2053 let sources = self.source_order.clone();
2054 if sources.is_empty() {
2055 if self.triplet_recipes.is_empty() {
2056 return Err(SamplerError::Configuration(
2057 "no triplet recipes available".into(),
2058 ));
2059 }
2060 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2061 let recipe_weights: Vec<f32> = self.triplet_recipes.iter().map(|r| r.weight).collect();
2062 let recipe_order = self.recipe_order_weighted_cycled(
2063 &recipe_weights,
2064 self.triplet_recipe_rr_idx,
2065 &mut rng,
2066 );
2067 let mut pairs = Vec::new();
2068 let mut seen = HashSet::new();
2069 let mut last_recipe_name = None;
2070 let mut recipe_pos = 0usize;
2071 let mut recipe_steps = 0usize;
2072 let attempts = self.config.batch_size * 4 * recipe_order.len().max(1);
2073 for _ in 0..attempts {
2074 if pairs.len() >= self.config.batch_size {
2075 break;
2076 }
2077 let Some(anchor) = self.choose_anchor_record(None, target_split) else {
2078 break;
2079 };
2080 let mut triplet = None;
2081 for offset in 0..recipe_order.len() {
2082 let idx = recipe_order[(recipe_pos + offset) % recipe_order.len()];
2083 recipe_steps = recipe_steps.saturating_add(1);
2084 let recipe = self.triplet_recipes[idx].clone();
2085 last_recipe_name = Some(recipe.name.clone());
2086 if let Some(sample) = self.make_triplet_with_anchor(&recipe, &anchor, &mut rng)
2087 {
2088 triplet = Some((recipe, sample));
2089 recipe_pos = (recipe_pos + offset + 1) % recipe_order.len();
2090 break;
2091 }
2092 }
2093 if let Some((recipe, triplet)) = triplet {
2094 let key = (
2095 triplet.anchor.record_id.clone(),
2096 triplet.positive.record_id.clone(),
2097 triplet.negative.record_id.clone(),
2098 );
2099 if seen.insert(key) {
2100 let SampleTriplet {
2101 recipe: triplet_recipe_name,
2102 anchor,
2103 positive,
2104 negative,
2105 weight,
2106 instruction,
2107 } = triplet;
2108 if pairs.len() < self.config.batch_size {
2109 pairs.push(SamplePair {
2110 recipe: triplet_recipe_name.clone(),
2111 anchor: anchor.clone(),
2112 positive: positive.clone(),
2113 weight,
2114 instruction: instruction.clone(),
2115 label: PairLabel::Positive,
2116 reason: None,
2117 });
2118 }
2119 if pairs.len() < self.config.batch_size {
2120 pairs.push(SamplePair {
2121 recipe: triplet_recipe_name,
2122 anchor,
2123 positive: negative,
2124 weight,
2125 instruction,
2126 label: PairLabel::Negative,
2127 reason: Some(
2128 strategy_reason(&recipe.negative_strategy).to_string(),
2129 ),
2130 });
2131 }
2132 }
2133 }
2134 }
2135 if recipe_steps > 0 {
2136 self.triplet_recipe_rr_idx =
2137 self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
2138 }
2139 self.rng = rng;
2140 pad_with_reuse(&mut pairs, self.config.batch_size);
2141 if pairs.len() == self.config.batch_size {
2142 return Ok(SampleBatch { pairs });
2143 }
2144 return Err(SamplerError::Exhausted(
2145 last_recipe_name
2146 .unwrap_or(Cow::Borrowed(RECIPE_LABEL_TRIPLETS))
2147 .to_string(),
2148 ));
2149 }
2150
2151 let mut pairs = Vec::new();
2152 let mut seen = HashSet::new();
2153 let mut source_steps = 0usize;
2154 let mut cycle = (self.source_cycle_idx / sources.len()) as u64;
2155 let mut source_idx = self.source_cycle_idx % sources.len();
2156 let mut cycle_sources = self.shuffled_source_cycle(cycle);
2157 let mut recipe_orders: HashMap<RecipeKey, Vec<usize>> = HashMap::new();
2158 let mut recipe_positions: HashMap<RecipeKey, usize> = HashMap::new();
2159 let mut recipe_steps = 0usize;
2160 let max_recipe_len = sources
2161 .iter()
2162 .map(|source| self.triplet_recipe_count_for_source(source))
2163 .max()
2164 .unwrap_or(1)
2165 .max(1);
2166 let attempts = self.config.batch_size * 4 * sources.len() * max_recipe_len;
2167 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2168 for _ in 0..attempts {
2169 if pairs.len() >= self.config.batch_size {
2170 break;
2171 }
2172 let source = cycle_sources[source_idx].as_str();
2173 let (triplet, attempts_used) = self.sample_source_triplet_candidate(
2174 source,
2175 target_split,
2176 &mut recipe_orders,
2177 &mut recipe_positions,
2178 &mut rng,
2179 );
2180 recipe_steps = recipe_steps.saturating_add(attempts_used);
2181 if let Some((recipe, triplet)) = triplet {
2182 let key = (
2183 triplet.anchor.record_id.clone(),
2184 triplet.positive.record_id.clone(),
2185 triplet.negative.record_id.clone(),
2186 );
2187 if seen.insert(key) {
2188 let SampleTriplet {
2189 recipe: triplet_recipe_name,
2190 anchor,
2191 positive,
2192 negative,
2193 weight,
2194 instruction,
2195 } = triplet;
2196 if pairs.len() < self.config.batch_size {
2197 pairs.push(SamplePair {
2198 recipe: triplet_recipe_name.clone(),
2199 anchor: anchor.clone(),
2200 positive: positive.clone(),
2201 weight,
2202 instruction: instruction.clone(),
2203 label: PairLabel::Positive,
2204 reason: None,
2205 });
2206 }
2207 if pairs.len() < self.config.batch_size {
2208 pairs.push(SamplePair {
2209 recipe: triplet_recipe_name,
2210 anchor,
2211 positive: negative,
2212 weight,
2213 instruction,
2214 label: PairLabel::Negative,
2215 reason: Some(strategy_reason(&recipe.negative_strategy).to_string()),
2216 });
2217 }
2218 }
2219 }
2220 source_idx += 1;
2221 source_steps += 1;
2222 if source_idx >= cycle_sources.len() {
2223 source_idx = 0;
2224 cycle = cycle.saturating_add(1);
2225 cycle_sources = self.shuffled_source_cycle(cycle);
2226 }
2227 }
2228 if recipe_steps > 0 {
2229 self.triplet_recipe_rr_idx = self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
2230 }
2231 self.rng = rng;
2232 pad_with_reuse(&mut pairs, self.config.batch_size);
2233 if pairs.len() == self.config.batch_size {
2234 self.source_cycle_idx = self.source_cycle_idx.saturating_add(source_steps);
2235 self.source_state_dirty = sources.len() > 1;
2236 return Ok(SampleBatch { pairs });
2237 }
2238 Err(SamplerError::Exhausted(RECIPE_LABEL_TRIPLETS.into()))
2239 }
2240
2241 fn next_text_batch_inner_with_weights(
2242 &mut self,
2243 target_split: SplitLabel,
2244 weights: Option<&HashMap<SourceId, f32>>,
2245 ) -> Result<TextBatch, SamplerError> {
2246 self.ingest_with_weights_fallback(target_split, weights)?;
2247 self.ensure_split_has_records(target_split)?;
2248 let sources = self.source_order.clone();
2249 if sources.is_empty() {
2250 if self.text_recipes.is_empty() {
2251 return Err(SamplerError::Configuration(
2252 "no text recipes configured".into(),
2253 ));
2254 }
2255 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2256 let recipe_weights: Vec<f32> = self.text_recipes.iter().map(|r| r.weight).collect();
2257 let recipe_order = self.text_recipe_order_weighted_cycled(
2258 &recipe_weights,
2259 self.text_recipe_rr_idx,
2260 &mut rng,
2261 );
2262 let mut samples = Vec::new();
2263 let mut seen = HashSet::new();
2264 let mut last_recipe_name = None;
2265 let mut recipe_pos = 0usize;
2266 let mut recipe_steps = 0usize;
2267 let attempts = self.config.batch_size * 4 * recipe_order.len().max(1);
2268 for _ in 0..attempts {
2269 if samples.len() >= self.config.batch_size {
2270 break;
2271 }
2272 let recipe_idx = recipe_order[recipe_pos];
2273 recipe_pos = (recipe_pos + 1) % recipe_order.len();
2274 recipe_steps = recipe_steps.saturating_add(1);
2275 let recipe = self.text_recipes[recipe_idx].clone();
2276 last_recipe_name = Some(recipe.name.clone());
2277 if let Some(sample) =
2278 self.make_text_sample_for_split(&recipe, None, target_split, &mut rng)
2279 {
2280 let key = chunk_key(&sample.chunk);
2281 if seen.insert(key) {
2282 samples.push(sample);
2283 }
2284 }
2285 }
2286 if recipe_steps > 0 {
2287 self.text_recipe_rr_idx = self.text_recipe_rr_idx.saturating_add(recipe_steps);
2288 }
2289 self.rng = rng;
2290 pad_with_reuse(&mut samples, self.config.batch_size);
2291 if samples.len() == self.config.batch_size {
2292 return Ok(TextBatch { samples });
2293 }
2294 return Err(SamplerError::Exhausted(
2295 last_recipe_name
2296 .unwrap_or(Cow::Borrowed(RECIPE_LABEL_TEXT))
2297 .to_string(),
2298 ));
2299 }
2300
2301 let mut samples = Vec::new();
2302 let mut seen = HashSet::new();
2303 let mut source_steps = 0usize;
2304 let mut cycle = (self.source_cycle_idx / sources.len()) as u64;
2305 let mut idx = self.source_cycle_idx % sources.len();
2306 let mut cycle_sources = self.shuffled_source_cycle(cycle);
2307 let mut recipe_orders: HashMap<RecipeKey, Vec<usize>> = HashMap::new();
2308 let mut recipe_positions: HashMap<RecipeKey, usize> = HashMap::new();
2309 let mut recipe_steps = 0usize;
2310 let max_recipe_len = sources
2311 .iter()
2312 .map(|source| self.text_recipes_for_source(source).len())
2313 .max()
2314 .unwrap_or(1)
2315 .max(1);
2316 let attempts = self.config.batch_size * 4 * sources.len() * max_recipe_len;
2317 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2318 for _ in 0..attempts {
2319 if samples.len() >= self.config.batch_size {
2320 break;
2321 }
2322 let source = cycle_sources[idx].as_str();
2323 let recipes = self.text_recipes_for_source(source).to_vec();
2324 if recipes.is_empty() {
2325 idx += 1;
2326 source_steps += 1;
2327 if idx >= cycle_sources.len() {
2328 idx = 0;
2329 cycle = cycle.saturating_add(1);
2330 cycle_sources = self.shuffled_source_cycle(cycle);
2331 }
2332 continue;
2333 }
2334 if !recipe_orders.contains_key(source) {
2335 let recipe_weights: Vec<f32> = recipes.iter().map(|r| r.weight).collect();
2336 let order = self.text_recipe_order_weighted_cycled(
2337 &recipe_weights,
2338 self.text_recipe_rr_idx,
2339 &mut rng,
2340 );
2341 recipe_orders.insert(source.to_string(), order);
2342 }
2343 let order = recipe_orders
2344 .get(source)
2345 .expect("recipe order missing for source");
2346 let pos = recipe_positions.entry(source.to_string()).or_insert(0);
2347 let mut sample: Option<(TextRecipe, TextSample)> = None;
2348 for offset in 0..order.len() {
2349 let recipe_idx = order[(*pos + offset) % order.len()];
2350 let recipe = recipes[recipe_idx].clone();
2351 if let Some(item) =
2352 self.make_text_sample_for_split(&recipe, Some(source), target_split, &mut rng)
2353 {
2354 recipe_steps = recipe_steps.saturating_add(offset + 1);
2355 *pos = (*pos + offset + 1) % order.len();
2356 sample = Some((recipe, item));
2357 break;
2358 }
2359 }
2360 if sample.is_none() {
2361 recipe_steps = recipe_steps.saturating_add(order.len());
2362 }
2363 if let Some((_recipe, sample)) = sample {
2364 let key = chunk_key(&sample.chunk);
2365 if seen.insert(key) {
2366 samples.push(sample);
2367 }
2368 }
2369 idx += 1;
2370 source_steps += 1;
2371 if idx >= cycle_sources.len() {
2372 idx = 0;
2373 cycle = cycle.saturating_add(1);
2374 cycle_sources = self.shuffled_source_cycle(cycle);
2375 }
2376 }
2377 if samples.len() != self.config.batch_size {
2378 pad_with_reuse(&mut samples, self.config.batch_size);
2379 }
2380 if samples.len() != self.config.batch_size {
2381 self.rng = rng;
2382 return Err(SamplerError::Exhausted(RECIPE_LABEL_TEXT.into()));
2383 }
2384 self.rng = rng;
2385 self.source_cycle_idx = self.source_cycle_idx.saturating_add(source_steps);
2386 self.source_state_dirty = sources.len() > 1;
2387 if recipe_steps > 0 {
2388 self.text_recipe_rr_idx = self.text_recipe_rr_idx.saturating_add(recipe_steps);
2389 }
2390 Ok(TextBatch { samples })
2391 }
2392
    /// Assembles one [`TripletBatch`] of exactly `config.batch_size` (anchor, positive,
    /// negative) triplets for `target_split`.
    ///
    /// Ingestion runs first, with the optional per-source `weights`. Then the split is
    /// checked for records.
    ///
    /// * **No registered sources**: triplet recipes are visited in the order returned by
    ///   `recipe_order_weighted_cycled`, seeded with `triplet_recipe_rr_idx`.
    ///   - Anchors are picked by `choose_anchor_record(None, ..)`.
    ///   - For each anchor, recipes are tried from the current position until
    ///     `make_triplet_with_anchor` succeeds.
    /// * **Registered sources**: a three-stage pipeline.
    ///   1. Sequentially walk sources in the shuffled per-cycle order. Draw one fork seed
    ///      from the shared RNG for every source that has recipes. Record a slot plan when
    ///      an anchor is found.
    ///   2. In parallel (rayon), each slot uses a private `DeterministicRng` seeded from its
    ///      fork seed to order its recipes and pick an anchor/positive chunk pair.
    ///   3. Sequentially, in slot order, with the shared RNG:
    ///      - pick a negative record and chunk;
    ///      - randomly swap anchor and positive;
    ///      - reject triplets with identical texts;
    ///      - build the triplet.
    ///
    /// Triplets are de-duplicated by their (anchor, positive, negative) record ids. A
    /// partial batch is filled by [`pad_with_reuse`].
    ///
    /// State updates:
    /// * The shared RNG is always put back.
    /// * `triplet_recipe_rr_idx` advances even when the batch fails.
    /// * `source_cycle_idx` and `source_state_dirty` are only updated on success.
    ///
    /// # Errors
    ///
    /// * [`SamplerError::Configuration`] if there are no sources and no triplet recipes.
    /// * [`SamplerError::Exhausted`] if not a single triplet could be produced.
    ///   - The no-source mode reports the last recipe name tried, or `RECIPE_LABEL_TRIPLETS`.
    ///   - The source mode always reports `RECIPE_LABEL_TRIPLETS`.
    /// * Errors from ingestion or `ensure_split_has_records` are propagated unchanged.
    fn next_triplet_batch_inner_with_weights(
        &mut self,
        target_split: SplitLabel,
        weights: Option<&HashMap<SourceId, f32>>,
    ) -> Result<TripletBatch, SamplerError> {
        self.ingest_with_weights_fallback(target_split, weights)?;
        self.ensure_split_has_records(target_split)?;
        let sources = self.source_order.clone();
        if sources.is_empty() {
            if self.triplet_recipes.is_empty() {
                return Err(SamplerError::Configuration(
                    "no triplet recipes configured".into(),
                ));
            }
            // Detach the RNG so it can be passed as `&mut` alongside calls on `self`.
            // It is restored before returning.
            let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
            let recipe_weights: Vec<f32> = self.triplet_recipes.iter().map(|r| r.weight).collect();
            let recipe_order = self.recipe_order_weighted_cycled(
                &recipe_weights,
                self.triplet_recipe_rr_idx,
                &mut rng,
            );
            let mut triplets = Vec::new();
            let mut seen = HashSet::new();
            let mut last_recipe_name = None;
            let mut recipe_pos = 0usize;
            let mut recipe_steps = 0usize;
            let attempts = self.config.batch_size * 4 * recipe_order.len().max(1);
            for _ in 0..attempts {
                if triplets.len() >= self.config.batch_size {
                    break;
                }
                // No anchor available for this split: stop early.
                let Some(anchor) = self.choose_anchor_record(None, target_split) else {
                    break;
                };
                let mut triplet = None;
                // Try each recipe once, starting at the round-robin position. Every try
                // counts as a recipe step, whether or not it succeeds.
                for offset in 0..recipe_order.len() {
                    let idx = recipe_order[(recipe_pos + offset) % recipe_order.len()];
                    recipe_steps = recipe_steps.saturating_add(1);
                    let recipe = self.triplet_recipes[idx].clone();
                    last_recipe_name = Some(recipe.name.clone());
                    if let Some(sample) = self.make_triplet_with_anchor(&recipe, &anchor, &mut rng)
                    {
                        triplet = Some(sample);
                        recipe_pos = (recipe_pos + offset + 1) % recipe_order.len();
                        break;
                    }
                }
                if let Some(triplet) = triplet {
                    // De-duplicate on the three record ids, not on chunk contents.
                    let key = (
                        triplet.anchor.record_id.clone(),
                        triplet.positive.record_id.clone(),
                        triplet.negative.record_id.clone(),
                    );
                    if seen.insert(key) {
                        triplets.push(triplet);
                    }
                }
            }
            // Advanced before the success check, so failed batches also move the cursor.
            if recipe_steps > 0 {
                self.triplet_recipe_rr_idx =
                    self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
            }
            self.rng = rng;
            pad_with_reuse(&mut triplets, self.config.batch_size);
            if triplets.len() == self.config.batch_size {
                return Ok(TripletBatch { triplets });
            }
            return Err(SamplerError::Exhausted(
                last_recipe_name
                    .unwrap_or(Cow::Borrowed(RECIPE_LABEL_TRIPLETS))
                    .to_string(),
            ));
        }

        let mut triplets: Vec<SampleTriplet> = Vec::new();
        let mut seen: HashSet<(RecordId, RecordId, RecordId)> = HashSet::new();
        let mut source_steps = 0usize;
        let mut cycle = (self.source_cycle_idx / sources.len()) as u64;
        let mut source_idx = self.source_cycle_idx % sources.len();
        let mut cycle_sources = self.shuffled_source_cycle(cycle);
        // In this mode, counts planned slots (one per successful anchor pick), not recipe tries.
        let mut recipe_steps = 0usize;
        let max_recipe_len = sources
            .iter()
            .map(|source| self.triplet_recipe_count_for_source(source.as_str()))
            .max()
            .unwrap_or(1)
            .max(1);
        let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));

        // Stage 1 output: one unit of work per source visit that found an anchor.
        struct SlotPlan {
            /// Anchor record chosen for this slot's source.
            anchor: Arc<DataRecord>,
            /// The source's triplet recipes, from `resolve_source_triplet_plan`.
            recipes: Vec<TripletRecipe>,
            /// Seed for this slot's private RNG in the parallel stage.
            fork_seed: u64,
        }
        let target_slots = self.config.batch_size * 4 * max_recipe_len;
        let mut slot_plans: Vec<SlotPlan> = Vec::with_capacity(target_slots);

        // Stage 1 (sequential): walk sources and plan slots.
        for _ in 0..target_slots {
            // NOTE(review): the loop runs exactly `target_slots` times and pushes at most one
            // plan per iteration, so this check can never fire.
            if slot_plans.len() >= target_slots {
                break;
            }
            let source = cycle_sources[source_idx].as_str();
            let (recipes, _) = self.resolve_source_triplet_plan(source);
            if !recipes.is_empty() {
                // Drawn before the anchor lookup, so the shared RNG advances even when no
                // anchor is found for this source.
                let fork_seed = rng.next_u64();
                if let Some(anchor) = self.choose_anchor_record(Some(source), target_split) {
                    slot_plans.push(SlotPlan {
                        anchor,
                        recipes,
                        fork_seed,
                    });
                    recipe_steps = recipe_steps.saturating_add(1);
                }
            }
            source_idx += 1;
            source_steps += 1;
            if source_idx >= cycle_sources.len() {
                source_idx = 0;
                cycle = cycle.saturating_add(1);
                cycle_sources = self.shuffled_source_cycle(cycle);
            }
        }
        // Put the RNG back for the parallel stage, which borrows `self` immutably.
        // It is detached again afterwards.
        self.rng = rng;

        // Stage 2 output: the recipe that produced an anchor/positive pair, plus the anchor's
        // raw text, which is later passed to `select_negative_record` as the query text.
        struct SlotCandidate {
            recipe: TripletRecipe,
            anchor: Arc<DataRecord>,
            anchor_chunk: RecordChunk,
            positive_chunk: RecordChunk,
            anchor_raw_text: String,
        }
        // Stage 2 (parallel): each slot uses only its own fork RNG. The random draws
        // therefore do not depend on thread scheduling.
        let mut raw_candidates: Vec<(usize, Option<SlotCandidate>)> = slot_plans
            .par_iter()
            .enumerate()
            .map(|(slot_idx, plan)| {
                let mut fork_rng = DeterministicRng::new(plan.fork_seed);
                let weights: Vec<f32> = plan.recipes.iter().map(|r| r.weight).collect();
                let order = weighted_recipe_order(&weights, &mut fork_rng);
                let mut candidate = None;
                // The first recipe in the weighted order that yields a pair wins.
                for &idx in &order {
                    let recipe = &plan.recipes[idx];
                    if let Some((ac, pc, raw)) =
                        self.select_anchor_positive_for_recipe(recipe, &plan.anchor, &mut fork_rng)
                    {
                        candidate = Some(SlotCandidate {
                            recipe: recipe.clone(),
                            anchor: Arc::clone(&plan.anchor),
                            anchor_chunk: ac,
                            positive_chunk: pc,
                            anchor_raw_text: raw,
                        });
                        break;
                    }
                }
                (slot_idx, candidate)
            })
            .collect();

        // Process candidates in slot order so the sequential stage is deterministic.
        raw_candidates.sort_unstable_by_key(|(i, _)| *i);
        // Stage 3 (sequential, shared RNG): negatives, swap, validation, assembly.
        let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
        for (_, candidate) in raw_candidates {
            if triplets.len() >= self.config.batch_size {
                break;
            }
            let Some(sc) = candidate else { continue };
            let (negative_record, fallback_used) = match self.select_negative_record(
                &sc.anchor,
                &sc.recipe.negative_strategy,
                Some(&sc.anchor_raw_text),
                &mut rng,
            ) {
                Some(pair) => pair,
                None => continue,
            };
            let mut negative_chunk = match self.select_chunk_parallel(
                &negative_record,
                &sc.recipe.negative_selector,
                &mut rng,
            ) {
                Some(c) => c,
                None => continue,
            };
            self.decorate_chunk_parallel(&negative_record, &mut negative_chunk, &mut rng);

            // Swap anchor and positive when the masked bits of a fresh draw are all zero.
            // In Rust `&` binds tighter than `==`.
            let (anchor_chunk, positive_chunk) = if rng.next_u64() & ANCHOR_POSITIVE_SWAP_MASK == 0
            {
                (sc.positive_chunk, sc.anchor_chunk)
            } else {
                (sc.anchor_chunk, sc.positive_chunk)
            };

            // Reject degenerate triplets. Anchor and positive may share text only when the
            // recipe allows it. The negative must differ textually from both.
            if (!sc.recipe.allow_same_anchor_positive && anchor_chunk.text == positive_chunk.text)
                || negative_chunk.text == positive_chunk.text
                || negative_chunk.text == anchor_chunk.text
            {
                continue;
            }

            let chunk_weight = self.triplet_chunk_weight(
                &sc.recipe,
                &anchor_chunk,
                &positive_chunk,
                &negative_chunk,
            );
            let weight = sc.recipe.weight * chunk_weight;
            // `fallback_used` comes from `select_negative_record`. It is surfaced by
            // suffixing the recipe name.
            let recipe_name = if fallback_used {
                format!("{}_fallback_same_split", sc.recipe.name)
            } else {
                sc.recipe.name.to_string()
            };
            let triplet = SampleTriplet {
                recipe: recipe_name,
                anchor: anchor_chunk,
                positive: positive_chunk,
                negative: negative_chunk,
                weight,
                instruction: sc.recipe.instruction.as_ref().map(|s| s.to_string()),
            };
            let key = (
                triplet.anchor.record_id.clone(),
                triplet.positive.record_id.clone(),
                triplet.negative.record_id.clone(),
            );
            // The length check is redundant with the break at the top of the loop.
            if seen.insert(key) && triplets.len() < self.config.batch_size {
                triplets.push(triplet);
            }
        }
        self.rng = rng;

        // Advanced before the success check, so failed batches also move the cursor.
        if recipe_steps > 0 {
            self.triplet_recipe_rr_idx = self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
        }
        pad_with_reuse(&mut triplets, self.config.batch_size);
        if triplets.len() == self.config.batch_size {
            self.source_cycle_idx = self.source_cycle_idx.saturating_add(source_steps);
            self.source_state_dirty = sources.len() > 1;
            let batch = TripletBatch { triplets };
            return Ok(batch);
        }
        Err(SamplerError::Exhausted(RECIPE_LABEL_TRIPLETS.into()))
    }
2650
2651 #[cfg(test)]
2657 #[cfg(feature = "bm25-mining")]
2658 fn bm25_backend_mut(&mut self) -> &mut backends::Bm25Backend {
2659 self.negative_backend
2660 .as_any_mut()
2661 .downcast_mut::<backends::Bm25Backend>()
2662 .expect("bm25_backend_mut: negative_backend is Bm25Backend when bm25-mining feature is active")
2663 }
2664
2665 #[cfg(test)]
2667 fn recipe_order_weighted_shuffled_seeded(&mut self, weights: &[f32]) -> Vec<usize> {
2668 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2669 let result = self.recipe_order_weighted_shuffled(weights, &mut rng);
2670 self.rng = rng;
2671 result
2672 }
2673
2674 #[cfg(test)]
2676 fn recipe_order_weighted_cycled_seeded(
2677 &mut self,
2678 weights: &[f32],
2679 rr_idx: usize,
2680 ) -> Vec<usize> {
2681 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2682 let result = self.recipe_order_weighted_cycled(weights, rr_idx, &mut rng);
2683 self.rng = rng;
2684 result
2685 }
2686
2687 #[cfg(test)]
2689 fn text_recipe_order_weighted_shuffled_seeded(&mut self, weights: &[f32]) -> Vec<usize> {
2690 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2691 let result = self.text_recipe_order_weighted_shuffled(weights, &mut rng);
2692 self.rng = rng;
2693 result
2694 }
2695
2696 #[cfg(test)]
2698 fn text_recipe_order_weighted_cycled_seeded(
2699 &mut self,
2700 weights: &[f32],
2701 rr_idx: usize,
2702 ) -> Vec<usize> {
2703 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2704 let result = self.text_recipe_order_weighted_cycled(weights, rr_idx, &mut rng);
2705 self.rng = rng;
2706 result
2707 }
2708
2709 #[cfg(test)]
2711 fn select_negative_record_seeded(
2712 &mut self,
2713 anchor_record: &DataRecord,
2714 strategy: &NegativeStrategy,
2715 anchor_query_text: Option<&str>,
2716 ) -> Option<(Arc<DataRecord>, bool)> {
2717 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2718 let result =
2719 self.select_negative_record(anchor_record, strategy, anchor_query_text, &mut rng);
2720 self.rng = rng;
2721 result
2722 }
2723
2724 #[cfg(test)]
2726 fn make_triplet_with_anchor_seeded(
2727 &mut self,
2728 recipe: &TripletRecipe,
2729 anchor: &DataRecord,
2730 ) -> Option<SampleTriplet> {
2731 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2732 let result = self.make_triplet_with_anchor(recipe, anchor, &mut rng);
2733 self.rng = rng;
2734 result
2735 }
2736
2737 #[cfg(test)]
2739 fn decorate_chunk_seeded(&mut self, record: &DataRecord, chunk: &mut RecordChunk) {
2740 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2741 self.decorate_chunk(record, chunk, &mut rng);
2742 self.rng = rng;
2743 }
2744
2745 #[cfg(all(feature = "bm25-mining", feature = "extended-metrics"))]
2751 fn bm25_fallback_stats(&self) -> (u64, u64) {
2752 self.negative_backend.bm25_fallback_stats()
2753 }
2754
2755 #[cfg(test)]
2760 #[cfg(feature = "bm25-mining")]
2761 fn bm25_ranked_candidates(&mut self, anchor: &crate::data::DataRecord) -> Vec<RecordId> {
2762 let split = self
2763 .split_store
2764 .label_for(&anchor.id)
2765 .unwrap_or(SplitLabel::Train);
2766 self.negative_backend
2767 .as_any_mut()
2768 .downcast_mut::<backends::Bm25Backend>()
2769 .expect("bm25_ranked_candidates: Bm25Backend")
2770 .ranked_candidates_pub(anchor, split)
2771 }
2772}
2773
2774fn weighted_recipe_order(weights: &[f32], rng: &mut DeterministicRng) -> Vec<usize> {
2785 let nonzero: Vec<(usize, f32)> = weights
2786 .iter()
2787 .enumerate()
2788 .filter(|(_, w)| **w > 0.0)
2789 .map(|(i, &w)| (i, w))
2790 .collect();
2791 if nonzero.is_empty() {
2792 return Vec::new();
2793 }
2794 let w_min = nonzero
2795 .iter()
2796 .map(|(_, w)| *w)
2797 .fold(f32::INFINITY, f32::min);
2798 let mut order: Vec<usize> = Vec::new();
2799 for (recipe_idx, w) in &nonzero {
2800 let tickets = ((w / w_min).round() as usize).clamp(1, RECIPE_ORDER_MAX_WEIGHT_MULTIPLIER);
2801 for _ in 0..tickets {
2802 order.push(*recipe_idx);
2803 }
2804 }
2805 order.shuffle(rng);
2806 order
2807}
2808
2809fn same_selector_pair_is_valid(
2810 anchor_chunk: &RecordChunk,
2811 positive_chunk: &RecordChunk,
2812 enforce_window_pair: bool,
2813) -> bool {
2814 if chunk_key(anchor_chunk) == chunk_key(positive_chunk) {
2815 return false;
2816 }
2817 if !enforce_window_pair {
2818 return true;
2819 }
2820 matches!(
2821 (&anchor_chunk.view, &positive_chunk.view),
2822 (ChunkView::Window { .. }, ChunkView::Window { .. })
2823 )
2824}
2825
2826impl<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> TripletSampler<S> {
2827 pub fn new(config: SamplerConfig, split_store: Arc<S>) -> Self {
2829 let inner = TripletSamplerInner::new(config, split_store);
2830 Self {
2831 inner: Mutex::new(inner),
2832 }
2833 }
2834
2835 pub fn new_with_chunker(
2837 config: SamplerConfig,
2838 split_store: Arc<S>,
2839 chunker: Arc<dyn ChunkingAlgorithm>,
2840 ) -> Self {
2841 let inner = TripletSamplerInner::new_with_chunker(config, split_store, chunker);
2842 Self {
2843 inner: Mutex::new(inner),
2844 }
2845 }
2846
2847 pub fn next_pair_batch_for_split(
2849 &self,
2850 split: SplitLabel,
2851 ) -> Result<SampleBatch, SamplerError> {
2852 self.next_pair_batch_with_weights_for_split(split, &HashMap::new())
2853 }
2854
2855 pub fn next_text_batch_for_split(&self, split: SplitLabel) -> Result<TextBatch, SamplerError> {
2857 self.next_text_batch_with_weights_for_split(split, &HashMap::new())
2858 }
2859
2860 pub fn next_triplet_batch_for_split(
2862 &self,
2863 split: SplitLabel,
2864 ) -> Result<TripletBatch, SamplerError> {
2865 self.next_triplet_batch_with_weights_for_split(split, &HashMap::new())
2866 }
2867
2868 pub fn next_pair_batch_with_weights_for_split(
2870 &self,
2871 split: SplitLabel,
2872 weights: &HashMap<SourceId, f32>,
2873 ) -> Result<SampleBatch, SamplerError> {
2874 let mut inner = self.inner.lock().unwrap();
2875 inner.ensure_split_allowed(split)?;
2876 for attempt in 0..=EXHAUSTION_RETRY_LIMIT {
2877 match inner.next_pair_batch_inner_with_weights(split, Some(weights)) {
2878 Ok(batch) => return Ok(batch),
2879 Err(SamplerError::Exhausted(_)) => {
2880 if attempt < EXHAUSTION_RETRY_LIMIT {
2881 inner.force_ingest_refresh_with_weights_for_split(split, weights)?;
2882 }
2883 }
2884 Err(err) => return Err(err),
2885 }
2886 }
2887 Err(SamplerError::Exhausted(RECIPE_LABEL_TRIPLETS.into()))
2888 }
2889
2890 pub fn next_text_batch_with_weights_for_split(
2892 &self,
2893 split: SplitLabel,
2894 weights: &HashMap<SourceId, f32>,
2895 ) -> Result<TextBatch, SamplerError> {
2896 let mut inner = self.inner.lock().unwrap();
2897 inner.ensure_split_allowed(split)?;
2898 for attempt in 0..=EXHAUSTION_RETRY_LIMIT {
2899 match inner.next_text_batch_inner_with_weights(split, Some(weights)) {
2900 Ok(batch) => return Ok(batch),
2901 Err(SamplerError::Exhausted(_)) => {
2902 if attempt < EXHAUSTION_RETRY_LIMIT {
2903 inner.force_ingest_refresh_with_weights_for_split(split, weights)?;
2904 }
2905 }
2906 Err(err) => return Err(err),
2907 }
2908 }
2909 Err(SamplerError::Exhausted(RECIPE_LABEL_TEXT.into()))
2910 }
2911
2912 pub fn next_triplet_batch_with_weights_for_split(
2914 &self,
2915 split: SplitLabel,
2916 weights: &HashMap<SourceId, f32>,
2917 ) -> Result<TripletBatch, SamplerError> {
2918 let mut inner = self.inner.lock().unwrap();
2919 inner.ensure_split_allowed(split)?;
2920 for attempt in 0..=EXHAUSTION_RETRY_LIMIT {
2921 match inner.next_triplet_batch_inner_with_weights(split, Some(weights)) {
2922 Ok(batch) => return Ok(batch),
2923 Err(SamplerError::Exhausted(_)) => {
2924 if attempt < EXHAUSTION_RETRY_LIMIT {
2925 inner.force_ingest_refresh_with_weights_for_split(split, weights)?;
2926 }
2927 }
2928 Err(err) => return Err(err),
2929 }
2930 }
2931 Err(SamplerError::Exhausted(RECIPE_LABEL_TRIPLETS.into()))
2932 }
2933
2934 pub fn prefetch_triplet_batches(
2936 self: Arc<Self>,
2937 split: SplitLabel,
2938 capacity: usize,
2939 ) -> BatchPrefetcher<TripletBatch> {
2940 BatchPrefetcher::new(capacity, move || self.next_triplet_batch_for_split(split))
2941 }
2942
2943 pub fn prefetch_triplet_batches_with_weights(
2945 self: Arc<Self>,
2946 split: SplitLabel,
2947 capacity: usize,
2948 weights: HashMap<SourceId, f32>,
2949 ) -> BatchPrefetcher<TripletBatch> {
2950 BatchPrefetcher::new(capacity, move || {
2951 self.next_triplet_batch_with_weights_for_split(split, &weights)
2952 })
2953 }
2954
2955 pub fn prefetch_pair_batches(
2957 self: Arc<Self>,
2958 split: SplitLabel,
2959 capacity: usize,
2960 ) -> BatchPrefetcher<SampleBatch> {
2961 BatchPrefetcher::new(capacity, move || self.next_pair_batch_for_split(split))
2962 }
2963
2964 pub fn prefetch_pair_batches_with_weights(
2966 self: Arc<Self>,
2967 split: SplitLabel,
2968 capacity: usize,
2969 weights: HashMap<SourceId, f32>,
2970 ) -> BatchPrefetcher<SampleBatch> {
2971 BatchPrefetcher::new(capacity, move || {
2972 self.next_pair_batch_with_weights_for_split(split, &weights)
2973 })
2974 }
2975
2976 pub fn prefetch_text_batches(
2978 self: Arc<Self>,
2979 split: SplitLabel,
2980 capacity: usize,
2981 ) -> BatchPrefetcher<TextBatch> {
2982 BatchPrefetcher::new(capacity, move || self.next_text_batch_for_split(split))
2983 }
2984
2985 pub fn prefetch_text_batches_with_weights(
2987 self: Arc<Self>,
2988 split: SplitLabel,
2989 capacity: usize,
2990 weights: HashMap<SourceId, f32>,
2991 ) -> BatchPrefetcher<TextBatch> {
2992 BatchPrefetcher::new(capacity, move || {
2993 self.next_text_batch_with_weights_for_split(split, &weights)
2994 })
2995 }
2996
2997 pub fn text_recipes(&self) -> Vec<TextRecipe> {
2999 let inner = self.inner.lock().unwrap();
3000 inner.text_recipes().to_vec()
3001 }
3002
3003 pub fn register_source(&self, source: Box<dyn DataSource + 'static>) {
3005 let mut inner = self.inner.lock().unwrap();
3006 inner.register_source(source);
3007 }
3008
3009 pub fn set_epoch(&self, epoch: u64) -> Result<(), SamplerError> {
3011 let mut inner = self.inner.lock().unwrap();
3012 inner.set_epoch(epoch)
3013 }
3014
3015 pub fn save_sampler_state(&self, save_to: Option<&Path>) -> Result<(), SamplerError> {
3020 let mut inner = self.inner.lock().unwrap();
3021 inner.save_sampler_state(save_to)
3022 }
3023
3024 #[cfg(all(feature = "bm25-mining", feature = "extended-metrics"))]
3032 pub fn bm25_fallback_stats(&self) -> (u64, u64) {
3033 let inner = self.inner.lock().unwrap();
3034 inner.bm25_fallback_stats()
3035 }
3036}
3037
3038impl<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> Sampler for TripletSampler<S> {
3039 fn next_pair_batch(&self, split: SplitLabel) -> Result<SampleBatch, SamplerError> {
3040 self.next_pair_batch_for_split(split)
3041 }
3042
3043 fn next_pair_batch_with_weights(
3044 &self,
3045 split: SplitLabel,
3046 weights: &HashMap<SourceId, f32>,
3047 ) -> Result<SampleBatch, SamplerError> {
3048 self.next_pair_batch_with_weights_for_split(split, weights)
3049 }
3050
3051 fn next_text_batch(&self, split: SplitLabel) -> Result<TextBatch, SamplerError> {
3052 self.next_text_batch_for_split(split)
3053 }
3054
3055 fn next_text_batch_with_weights(
3056 &self,
3057 split: SplitLabel,
3058 weights: &HashMap<SourceId, f32>,
3059 ) -> Result<TextBatch, SamplerError> {
3060 self.next_text_batch_with_weights_for_split(split, weights)
3061 }
3062
3063 fn next_triplet_batch(&self, split: SplitLabel) -> Result<TripletBatch, SamplerError> {
3064 self.next_triplet_batch_for_split(split)
3065 }
3066
3067 fn next_triplet_batch_with_weights(
3068 &self,
3069 split: SplitLabel,
3070 weights: &HashMap<SourceId, f32>,
3071 ) -> Result<TripletBatch, SamplerError> {
3072 self.next_triplet_batch_with_weights_for_split(split, weights)
3073 }
3074}
3075
3076fn roles_match(target: &SectionRole, candidate: &SectionRole) -> bool {
3077 target == candidate
3078}
3079
3080fn role_cursor_key(record_id: &RecordId, role: &SectionRole) -> (RecordId, String) {
3081 (record_id.clone(), role_label(role))
3082}
3083
3084fn role_label(role: &SectionRole) -> String {
3085 match role {
3086 SectionRole::Anchor => ROLE_LABEL_ANCHOR.into(),
3087 SectionRole::Context => ROLE_LABEL_CONTEXT.into(),
3088 }
3089}
3090
3091fn taxonomy_value(record: &DataRecord, field: MetadataKey) -> Option<&str> {
3092 record.taxonomy.iter().find_map(|entry| field.strip(entry))
3093}
3094
3095fn strategy_reason(strategy: &NegativeStrategy) -> &'static str {
3096 match strategy {
3097 NegativeStrategy::WrongPublicationDate => NEG_REASON_WRONG_DATE,
3098 NegativeStrategy::WrongArticle => NEG_REASON_WRONG_ARTICLE,
3099 NegativeStrategy::QuestionAnswerMismatch => NEG_REASON_WRONG_QA,
3100 }
3101}
3102
3103fn chunk_key(chunk: &RecordChunk) -> String {
3104 match &chunk.view {
3105 ChunkView::Window { index, .. } => {
3106 format!("{}|{}|w|{}", chunk.record_id, chunk.section_idx, index)
3107 }
3108 ChunkView::SummaryFallback { strategy, .. } => {
3109 format!("{}|{}|s|{}", chunk.record_id, chunk.section_idx, strategy)
3110 }
3111 }
3112}
3113
/// Grows `items` to exactly `target` elements by appending clones of its existing
/// elements in cyclic order: `items[0]`, `items[1]`, …, wrapping around.
///
/// Does nothing in two cases:
/// * `items` is empty, so there is nothing to reuse;
/// * `items` already holds at least `target` elements. Longer vectors are never truncated.
fn pad_with_reuse<T: Clone>(items: &mut Vec<T>, target: usize) {
    let base_len = items.len();
    if base_len == 0 || base_len >= target {
        return;
    }
    let missing = target - base_len;
    items.reserve(missing);
    for offset in 0..missing {
        // `offset % base_len < base_len`, so this always reads one of the original
        // elements. No snapshot copy of the whole vector is needed.
        let reused = items[offset % base_len].clone();
        items.push(reused);
    }
}
3124
3125#[cfg(test)]
3126mod tests;