1mod backends;
2
3use crate::chunking::{ChunkingAlgorithm, SlidingWindowChunker};
4use chrono::Duration;
5use indexmap::IndexMap;
6use rand::prelude::*;
7use rayon::prelude::*;
8use std::borrow::Cow;
9use std::collections::{HashMap, HashSet};
10use std::path::Path;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::sync::mpsc;
13use std::sync::{Arc, Mutex};
14use std::thread;
15
16use crate::config::{
17 ChunkingStrategy, NegativeStrategy, SamplerConfig, Selector, TextRecipe, TripletRecipe,
18};
19use crate::constants::sampler::AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME;
20use crate::constants::sampler::{
21 ANCHOR_POSITIVE_SWAP_MASK, EPOCH_SEED_OFFSET, EXHAUSTION_RETRY_LIMIT, NEG_REASON_WRONG_ARTICLE,
22 NEG_REASON_WRONG_DATE, NEG_REASON_WRONG_QA, PREFETCHER_SOURCE_ID, PREFETCHER_STOPPED_REASON,
23 RECIPE_LABEL_TEXT, RECIPE_LABEL_TRIPLETS, RECIPE_ORDER_MAX_WEIGHT_MULTIPLIER,
24 ROLE_LABEL_ANCHOR, ROLE_LABEL_CONTEXT, SAME_SELECTOR_PAIR_RETRY_LIMIT,
25};
26use crate::data::{
27 ChunkView, DataRecord, PairLabel, RecordChunk, RecordSection, SampleBatch, SamplePair,
28 SampleTriplet, SectionRole, TextBatch, TextSample, TripletBatch,
29};
30use crate::epoch::EpochTracker;
31use crate::errors::SamplerError;
32use crate::hash::{derive_epoch_seed, stable_hash_str};
33use crate::ingestion::IngestionManager;
34use crate::metadata::{META_FIELD_DATE, MetadataKey};
35use crate::metrics::{chunk_proximity_score, window_index_proximity};
36use crate::source::DataSource;
37use crate::splits::{
38 EpochStateStore, PersistedSamplerState, SamplerStateStore, SplitLabel, SplitStore,
39};
40use crate::tokenizer::{Tokenizer, WhitespaceTokenizer};
41use crate::types::{RecipeKey, RecordId, SourceId};
42use crate::utils::platform_newline;
43
/// Seedable pseudo-random generator whose entire state is a single `u64`.
///
/// Each draw follows SplitMix64:
/// - The counter is advanced by a fixed odd increment.
/// - The new counter value is scrambled with two xor-shift/multiply rounds.
///
/// Because the whole state is one integer, it can be persisted and later restored with
/// [`DeterministicRng::from_state`] to resume the exact same sequence.
#[derive(Debug, Clone)]
struct DeterministicRng {
    /// Counter advanced on every draw; this is also the value that gets persisted.
    state: u64,
}

impl DeterministicRng {
    /// Additive step applied to the counter before each output is mixed.
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a generator whose counter starts at `seed`.
    fn new(seed: u64) -> Self {
        Self::from_state(seed)
    }

    /// Rebuilds a generator from a counter previously returned by [`Self::state`].
    fn from_state(state: u64) -> Self {
        DeterministicRng { state }
    }

    /// Returns the current counter so the sequence can be resumed later.
    fn state(&self) -> u64 {
        self.state
    }

    /// Advances the counter and returns the next mixed 64-bit output.
    fn next_u64_internal(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::INCREMENT);
        let first = (self.state ^ (self.state >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        let second = (first ^ (first >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        second ^ (second >> 31)
    }
}
83
84impl rand::RngCore for DeterministicRng {
85 fn next_u32(&mut self) -> u32 {
86 self.next_u64_internal() as u32
87 }
88
89 fn next_u64(&mut self) -> u64 {
90 self.next_u64_internal()
91 }
92
93 fn fill_bytes(&mut self, dest: &mut [u8]) {
94 let mut offset = 0;
95 while offset < dest.len() {
96 let value = self.next_u64_internal();
97 let bytes = value.to_le_bytes();
98 let remaining = dest.len() - offset;
99 let copy_len = remaining.min(bytes.len());
100 dest[offset..offset + copy_len].copy_from_slice(&bytes[..copy_len]);
101 offset += copy_len;
102 }
103 }
104}
105
106pub fn chunk_weight(strategy: &ChunkingStrategy, chunk: &RecordChunk) -> f32 {
115 let floor = strategy.chunk_weight_floor;
116 let trust = chunk.quality.trust.clamp(0.0, 1.0);
117 let base = match &chunk.view {
118 ChunkView::Window { index, .. } => window_index_proximity(*index),
119 ChunkView::SummaryFallback { weight, .. } => *weight,
120 };
121 (base * trust).max(floor)
122}
123
124pub trait Sampler {
126 fn next_pair_batch(&self, split: SplitLabel) -> Result<SampleBatch, SamplerError> {
128 self.next_pair_batch_with_weights(split, &HashMap::new())
129 }
130 fn next_text_batch(&self, split: SplitLabel) -> Result<TextBatch, SamplerError> {
132 self.next_text_batch_with_weights(split, &HashMap::new())
133 }
134 fn next_triplet_batch(&self, split: SplitLabel) -> Result<TripletBatch, SamplerError> {
136 self.next_triplet_batch_with_weights(split, &HashMap::new())
137 }
138 fn next_pair_batch_with_weights(
140 &self,
141 split: SplitLabel,
142 weights: &HashMap<SourceId, f32>,
143 ) -> Result<SampleBatch, SamplerError>;
144 fn next_text_batch_with_weights(
146 &self,
147 split: SplitLabel,
148 weights: &HashMap<SourceId, f32>,
149 ) -> Result<TextBatch, SamplerError>;
150 fn next_triplet_batch_with_weights(
152 &self,
153 split: SplitLabel,
154 weights: &HashMap<SourceId, f32>,
155 ) -> Result<TripletBatch, SamplerError>;
156}
157
158pub struct BatchPrefetcher<T> {
160 receiver: Option<mpsc::Receiver<Result<T, SamplerError>>>,
161 handle: Option<thread::JoinHandle<()>>,
162 stats: Arc<PrefetcherStats>,
163}
164
165#[derive(Default)]
166struct PrefetcherStats {
168 queued: AtomicUsize,
169 produced: AtomicUsize,
170 errors: AtomicUsize,
171}
172
173impl<T: Send + 'static> BatchPrefetcher<T> {
174 fn new<F>(capacity: usize, mut producer: F) -> Self
175 where
176 F: FnMut() -> Result<T, SamplerError> + Send + 'static,
177 {
178 let (sender, receiver) = mpsc::sync_channel(capacity.max(1));
179 let stats = Arc::new(PrefetcherStats::default());
180 let stats_thread = Arc::clone(&stats);
181 let handle = thread::spawn(move || {
182 loop {
183 let result = producer();
184 if result.is_err() {
185 stats_thread.errors.fetch_add(1, Ordering::Relaxed);
186 }
187 if sender.send(result).is_err() {
188 return;
189 }
190 stats_thread.queued.fetch_add(1, Ordering::Relaxed);
191 stats_thread.produced.fetch_add(1, Ordering::Relaxed);
192 }
193 });
194 Self {
195 receiver: Some(receiver),
196 handle: Some(handle),
197 stats,
198 }
199 }
200
201 pub fn next(&self) -> Result<T, SamplerError> {
203 let receiver = self
204 .receiver
205 .as_ref()
206 .ok_or_else(|| SamplerError::SourceUnavailable {
207 source_id: PREFETCHER_SOURCE_ID.into(),
208 reason: PREFETCHER_STOPPED_REASON.into(),
209 })?;
210 let result = receiver.recv().unwrap_or_else(|_| {
211 Err(SamplerError::SourceUnavailable {
212 source_id: PREFETCHER_SOURCE_ID.into(),
213 reason: PREFETCHER_STOPPED_REASON.into(),
214 })
215 });
216 self.stats
217 .queued
218 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |value| {
219 Some(value.saturating_sub(1))
220 })
221 .ok();
222 result
223 }
224
225 pub fn queue_len(&self) -> usize {
227 self.stats.queued.load(Ordering::Relaxed)
228 }
229
230 pub fn produced_count(&self) -> usize {
232 self.stats.produced.load(Ordering::Relaxed)
233 }
234
235 pub fn error_count(&self) -> usize {
237 self.stats.errors.load(Ordering::Relaxed)
238 }
239}
240
241impl<T> Drop for BatchPrefetcher<T> {
242 fn drop(&mut self) {
243 self.receiver.take();
244 if let Some(handle) = self.handle.take() {
245 let _ = handle.join();
246 }
247 }
248}
249
/// Public sampler handle: all sampler state lives in a single `TripletSamplerInner`
/// behind a `Mutex`.
///
/// NOTE(review): the mutex presumably lets `Sampler`'s `&self` methods mutate that
/// state — confirm against the trait impl.
pub struct TripletSampler<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> {
    /// Mutex-guarded sampler state.
    inner: Mutex<TripletSamplerInner<S>>,
}
255
/// Mutable state behind [`TripletSampler`].
///
/// Holds configuration, loaded records, recipes, per-record and per-source cursors, and
/// the resumable RNG/epoch state.
struct TripletSamplerInner<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> {
    /// Sampler configuration (seed, batch size, recipes, allowed splits, chunking).
    config: SamplerConfig,
    /// Algorithm used to split record sections into chunks.
    chunker: Arc<dyn ChunkingAlgorithm>,
    /// Backing store for split labels, epoch state and persisted sampler state.
    split_store: Arc<S>,
    /// Owns registered data sources; receives the source epoch and stream-cursor resets.
    ingestion: IngestionManager,
    /// Loaded records keyed by id. Positions in this map are what `source_record_indices` stores.
    records: IndexMap<RecordId, Arc<DataRecord>>,
    /// Sampler-level RNG; its state is saved with the persisted sampler state.
    rng: DeterministicRng,
    /// Triplet recipes from config (empty when the config defines none).
    triplet_recipes: Vec<TripletRecipe>,
    /// Text recipes from config, or derived from triplet recipes.
    text_recipes: Vec<TextRecipe>,
    /// Per-source default triplet recipes, recorded only when the config defines none.
    source_triplet_recipes: HashMap<SourceId, Vec<TripletRecipe>>,
    /// Sources eligible for the auto-injected long-section chunk-pair recipe.
    sources_with_long_sections: HashSet<SourceId>,
    /// Per-source text recipes derived from `source_triplet_recipes`.
    source_text_recipes: HashMap<SourceId, Vec<TextRecipe>>,
    /// `true` when `config.recipes` was non-empty at construction.
    using_config_triplet_recipes: bool,
    /// `true` when `config.text_recipes` was non-empty at construction.
    using_config_text_recipes: bool,
    /// NOTE(review): initialized to 0; not read in the code visible here.
    last_observed_ingest: u64,
    /// Yields per-split record ids when no source is requested; persisted through `split_store`.
    epoch_tracker: EpochTracker,
    /// Round-robin chunk cursor per (record id, section index).
    chunk_cursors: HashMap<(RecordId, usize), usize>,
    /// Cursor per (record id, role string); pruned together with the other cursors.
    role_cursors: HashMap<(RecordId, String), usize>,
    /// Negative-selection backend (BM25 when the `bm25-mining` feature is enabled).
    negative_backend: Box<dyn backends::NegativeBackend>,
    /// Maps chunk ids to record ids; `rebuild_chunk_index` maps every record id to itself.
    chunk_index: HashMap<RecordId, RecordId>,
    /// Sorted ids of sources that have at least one record in an allowed split.
    source_order: Vec<SourceId>,
    /// Source rotation position; persisted, and reset on epoch changes.
    source_cycle_idx: usize,
    /// Whether persisted sampler state has been loaded (gates persisting).
    source_state_loaded: bool,
    /// NOTE(review): presumably whether ingestion stream cursors were restored; not read here.
    ingestion_cursors_loaded: bool,
    /// Whether source state changed since the last persist.
    source_state_dirty: bool,
    /// Per-source indices into `records`, filtered to allowed splits and shuffled per epoch.
    source_record_indices: HashMap<SourceId, Vec<usize>>,
    /// Per-source cursor into `source_record_indices` (not reduced modulo the list length).
    source_record_cursors: HashMap<SourceId, usize>,
    /// Round-robin offset into the triplet recipe order; persisted.
    triplet_recipe_rr_idx: usize,
    /// Hashes of emitted texts. NOTE(review): usage is not visible in this part of the file.
    emitted_text_hashes: HashSet<u64>,
    /// Round-robin offset into the text recipe order; persisted.
    text_recipe_rr_idx: usize,
    /// Current source epoch; feeds the per-epoch shuffle seed.
    source_epoch: u64,

    /// Whether each source's cursor has completed a full pass in the current epoch.
    source_wrapped: HashMap<SourceId, bool>,
}
325
326impl<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> TripletSamplerInner<S> {
327 fn new(config: SamplerConfig, split_store: Arc<S>) -> Self {
328 Self::new_with_chunker(config, split_store, Arc::new(SlidingWindowChunker))
329 }
330
331 fn new_with_chunker(
332 config: SamplerConfig,
333 split_store: Arc<S>,
334 chunker: Arc<dyn ChunkingAlgorithm>,
335 ) -> Self {
336 let buffer_size = config.ingestion_max_records.max(config.batch_size).max(2);
337 let using_config_triplet_recipes = !config.recipes.is_empty();
338 let using_config_text_recipes = !config.text_recipes.is_empty();
339 let triplet_recipes = if using_config_triplet_recipes {
340 config.recipes.clone()
341 } else {
342 Vec::new()
343 };
344 let text_recipes = if using_config_text_recipes {
345 config.text_recipes.clone()
346 } else if !triplet_recipes.is_empty() {
347 Self::build_derived_text_recipes(&triplet_recipes)
348 } else {
349 Vec::new()
350 };
351 let ingestion = IngestionManager::new(buffer_size, config.clone());
352 let epoch_backend = Some(Arc::clone(&split_store) as Arc<dyn EpochStateStore>);
353 let epoch_tracker = EpochTracker::new(
354 true,
355 epoch_backend,
356 derive_epoch_seed(config.seed, EPOCH_SEED_OFFSET),
357 );
358 let mut sampler = Self {
359 rng: DeterministicRng::new(config.seed),
360 config,
361 chunker,
362 split_store,
363 ingestion,
364 records: IndexMap::new(),
365 triplet_recipes,
366 text_recipes,
367 source_triplet_recipes: HashMap::new(),
368 sources_with_long_sections: HashSet::new(),
369 source_text_recipes: HashMap::new(),
370 using_config_triplet_recipes,
371 using_config_text_recipes,
372 last_observed_ingest: 0,
373 epoch_tracker,
374 chunk_cursors: HashMap::new(),
375 role_cursors: HashMap::new(),
376 negative_backend: {
377 #[cfg(feature = "bm25-mining")]
378 {
379 Box::new(backends::Bm25Backend::new())
380 }
381 #[cfg(not(feature = "bm25-mining"))]
382 {
383 Box::new(backends::DefaultBackend)
384 }
385 },
386 chunk_index: HashMap::new(),
387 source_order: Vec::new(),
388 source_cycle_idx: 0,
389 source_state_loaded: false,
390 ingestion_cursors_loaded: false,
391 source_state_dirty: false,
392 source_record_indices: HashMap::new(),
393 source_record_cursors: HashMap::new(),
394 emitted_text_hashes: HashSet::new(),
395 triplet_recipe_rr_idx: 0,
396 text_recipe_rr_idx: 0,
397 source_epoch: 0,
398 source_wrapped: HashMap::new(),
399 };
400 if !sampler.using_config_text_recipes {
401 sampler.rebuild_derived_text_recipes();
402 }
403 sampler
404 }
405
406 fn text_recipes(&self) -> &[TextRecipe] {
407 &self.text_recipes
408 }
409
410 fn epoch_seed(&self) -> u64 {
413 derive_epoch_seed(self.config.seed, self.source_epoch)
414 }
415
416 fn register_source(
417 &mut self,
418 source: Box<dyn DataSource + 'static>,
419 ) -> Result<(), SamplerError> {
420 let source_id = source.id().to_string();
421 if !self.using_config_triplet_recipes {
422 let triplets = source.default_triplet_recipes();
423 if !triplets.is_empty() {
424 self.source_triplet_recipes
425 .insert(source_id.clone(), triplets.clone());
426 if !self.using_config_text_recipes {
427 let derived = Self::build_derived_text_recipes(&triplets);
428 self.source_text_recipes
429 .insert(source_id.clone(), derived.clone());
430 self.extend_text_recipes_unique(&derived);
431 }
432 }
433 }
434 self.ingestion.register_source(source)?;
435 Ok(())
436 }
437
438 fn set_epoch(&mut self, epoch: u64) -> Result<(), SamplerError> {
439 self.epoch_tracker.ensure_loaded()?;
440 self.epoch_tracker.force_epoch(epoch);
441 self.source_epoch = epoch;
442 self.ingestion.set_source_epoch(epoch);
443 self.ingestion.reset_stream_cursors();
444 self.ingestion.reset_step_counter();
447 self.source_record_cursors.clear();
448 self.source_cycle_idx = 0;
449 for source in &self.source_order {
450 self.source_wrapped.insert(source.clone(), false);
451 }
452 self.rebuild_source_index()?;
453 self.source_state_dirty = self.source_order.len() > 1;
454 Ok(())
455 }
456
457 fn next_chunk_from_pool(
458 &mut self,
459 record_id: &str,
460 section_idx: usize,
461 pool: Vec<RecordChunk>,
462 ) -> Option<RecordChunk> {
463 if pool.is_empty() {
464 return None;
465 }
466 let key = (record_id.to_string(), section_idx);
467 if !self.chunk_cursors.contains_key(&key) {
468 let cursor_key = format!("{}::{}", record_id, section_idx);
474 let start = (stable_hash_str(self.epoch_seed(), &cursor_key) as usize) % pool.len();
475 self.chunk_cursors.insert(key.clone(), start);
476 }
477 let cursor = self.chunk_cursors.entry(key).or_insert(0);
478 if *cursor >= pool.len() {
479 *cursor = 0;
480 }
481 let chunk = pool.get(*cursor).cloned();
482 *cursor = (*cursor + 1) % pool.len();
483 chunk
484 }
485
486 fn prune_cursor_state(&mut self) {
487 if self.chunk_cursors.is_empty()
488 && self.role_cursors.is_empty()
489 && self.negative_backend.cursors_empty()
490 {
491 return;
492 }
493 let valid_ids: HashSet<RecordId> = self.records.keys().cloned().collect();
494 self.chunk_cursors
495 .retain(|(record_id, _), _| valid_ids.contains(record_id));
496 self.role_cursors
497 .retain(|(record_id, _), _| valid_ids.contains(record_id));
498 self.negative_backend.prune_cursors(&valid_ids);
499 }
500
501 fn rebuild_chunk_index(&mut self) {
502 self.chunk_index.clear();
503 for record in self.records.values() {
504 self.chunk_index
505 .insert(record.id.clone(), record.id.clone());
506 }
507 }
508
509 fn rebuild_source_index(&mut self) -> Result<(), SamplerError> {
510 self.source_record_indices.clear();
511 let mut label_cache: HashMap<RecordId, SplitLabel> = HashMap::new();
512 let allowed = self.allowed_target_splits();
513 let allowed_set: HashSet<SplitLabel> = allowed.into_iter().collect();
514 for (idx, record) in self.records.values().enumerate() {
515 let label = if let Some(label) = label_cache.get(&record.id) {
516 *label
517 } else {
518 let label = match self.split_store.label_for(&record.id) {
519 Some(label) => label,
520 None => self.split_store.ensure(record.id.clone())?,
521 };
522 label_cache.insert(record.id.clone(), label);
523 label
524 };
525 if !allowed_set.contains(&label) {
526 continue;
527 }
528 self.source_record_indices
529 .entry(record.source.clone())
530 .or_default()
531 .push(idx);
532 }
533
534 let shuffle_seed = self.epoch_seed();
535 for indices in self.source_record_indices.values_mut() {
536 indices.sort_by_key(|idx| {
537 self.records
538 .get_index(*idx)
539 .map(|(_, record)| stable_hash_str(shuffle_seed, &record.id))
540 .unwrap_or(0)
541 });
542 }
543
544 self.source_order = self.source_record_indices.keys().cloned().collect();
545 self.source_order.sort();
546 self.refresh_source_wrapped();
547
548 self.source_record_cursors
549 .retain(|source, _| self.source_record_indices.contains_key(source));
550 if self.source_state_loaded {
551 if self.source_order.is_empty() {
552 self.source_cycle_idx = 0;
553 }
554 self.source_state_dirty = self.source_order.len() > 1;
555 }
556 Ok(())
557 }
558
559 fn refresh_source_wrapped(&mut self) {
560 self.source_wrapped.clear();
561 for source in &self.source_order {
562 let len = self
563 .source_record_indices
564 .get(source)
565 .map(|items| items.len())
566 .unwrap_or(0);
567 if len == 0 {
568 self.source_wrapped.insert(source.clone(), false);
569 continue;
570 }
571 let cursor = self.source_record_cursors.get(source).copied().unwrap_or(0);
572 let wrapped = cursor > 0 && cursor % len == 0;
573 self.source_wrapped.insert(source.clone(), wrapped);
574 }
575 }
576
577 fn shuffled_source_cycle(&self, cycle: u64) -> Vec<SourceId> {
578 let mut sources = self.source_order.clone();
579 let seed = self.epoch_seed() ^ cycle;
580 sources.sort_by_key(|source| stable_hash_str(seed, source));
581 sources
582 }
583
584 fn ensure_source_state(&mut self) -> Result<(), SamplerError> {
585 if self.source_state_loaded {
586 return Ok(());
587 }
588 let persisted = self.split_store.load_sampler_state()?;
589 self.source_cycle_idx = persisted
590 .as_ref()
591 .map(|state| state.source_cycle_idx as usize)
592 .unwrap_or(0);
593 if let Some(state) = persisted {
594 for (source, cursor) in state.source_record_cursors {
595 if self.source_record_indices.contains_key(&source) {
596 self.source_record_cursors.insert(source, cursor as usize);
597 }
598 }
599 self.source_epoch = state.source_epoch;
600 self.ingestion.set_source_epoch(state.source_epoch);
601 self.rng = DeterministicRng::from_state(state.rng_state);
602 self.triplet_recipe_rr_idx = state.triplet_recipe_rr_idx as usize;
603 self.text_recipe_rr_idx = state.text_recipe_rr_idx as usize;
604 }
605 self.refresh_source_wrapped();
606 self.source_state_loaded = true;
607 self.source_state_dirty = true;
608 Ok(())
609 }
610
611 fn persist_source_state(&mut self, save_to: Option<&Path>) -> Result<(), SamplerError> {
612 if !self.source_state_loaded {
613 return Ok(());
614 }
615 let state = PersistedSamplerState {
616 source_cycle_idx: self.source_cycle_idx as u64,
617 source_record_cursors: self
618 .source_record_cursors
619 .iter()
620 .map(|(source, cursor)| (source.clone(), *cursor as u64))
621 .collect(),
622 source_epoch: self.source_epoch,
623 rng_state: self.rng.state(),
624 triplet_recipe_rr_idx: self.triplet_recipe_rr_idx as u64,
625 text_recipe_rr_idx: self.text_recipe_rr_idx as u64,
626 source_stream_cursors: self.ingestion.snapshot_cursors(),
627 };
628 self.split_store.save_sampler_state(&state, save_to)?;
629 self.source_state_dirty = false;
630 Ok(())
631 }
632
633 fn rebuild_derived_text_recipes(&mut self) {
634 if self.using_config_text_recipes {
635 return;
636 }
637 if self.triplet_recipes.is_empty() {
638 self.text_recipes.clear();
639 } else {
640 self.text_recipes = Self::build_derived_text_recipes(&self.triplet_recipes);
641 }
642 }
643
644 fn extend_text_recipes_unique(&mut self, recipes: &[TextRecipe]) {
645 for recipe in recipes {
646 if self
647 .text_recipes
648 .iter()
649 .any(|existing| existing.name == recipe.name)
650 {
651 continue;
652 }
653 self.text_recipes.push(recipe.clone());
654 }
655 }
656
657 fn configured_triplet_recipes_for_source<'a>(&'a self, source: &str) -> &'a [TripletRecipe] {
658 if self.using_config_triplet_recipes {
659 return &self.triplet_recipes;
660 }
661 self.source_triplet_recipes
662 .get(source)
663 .map(|recipes| recipes.as_slice())
664 .unwrap_or(&[])
665 }
666
667 fn contains_auto_chunk_pair_recipe(recipes: &[TripletRecipe]) -> bool {
669 recipes
670 .iter()
671 .any(|recipe| recipe.name.as_ref() == AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME)
672 }
673
674 fn source_supports_chunk_pair_recipe(&self, source: &str) -> bool {
675 if self.config.chunking.max_window_tokens == 0 {
676 return false;
677 }
678 self.sources_with_long_sections.contains(source)
679 }
680
681 fn should_auto_inject_chunk_pair_recipe(
687 &self,
688 source: &str,
689 recipes: &[TripletRecipe],
690 ) -> bool {
691 self.source_supports_chunk_pair_recipe(source)
692 && !Self::contains_auto_chunk_pair_recipe(recipes)
693 }
694
695 fn source_chunk_pair_recipe() -> TripletRecipe {
706 TripletRecipe {
707 name: Cow::Borrowed(AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME),
708 anchor: Selector::Role(SectionRole::Context),
709 positive_selector: Selector::Role(SectionRole::Context),
710 negative_selector: Selector::Role(SectionRole::Context),
711 negative_strategy: NegativeStrategy::WrongArticle,
712 weight: 1.0,
713 instruction: None,
714 allow_same_anchor_positive: false,
715 }
716 }
717
718 fn resolve_source_triplet_plan(&self, source: &str) -> (Vec<TripletRecipe>, bool) {
730 let mut recipes = self.configured_triplet_recipes_for_source(source).to_vec();
731 let mut auto_injected = false;
732 if self.should_auto_inject_chunk_pair_recipe(source, &recipes) {
733 recipes.push(Self::source_chunk_pair_recipe());
734 auto_injected = true;
735 }
736 (recipes, auto_injected)
737 }
738
739 #[cfg(test)]
740 fn triplet_recipes_for_source(&self, source: &str) -> Vec<TripletRecipe> {
741 self.resolve_source_triplet_plan(source).0
742 }
743
744 fn triplet_recipe_count_for_source(&self, source: &str) -> usize {
745 let (recipes, _auto_injected) = self.resolve_source_triplet_plan(source);
746 recipes.len()
747 }
748
749 fn text_recipes_for_source<'a>(&'a self, source: &str) -> &'a [TextRecipe] {
750 if self.using_config_text_recipes || self.using_config_triplet_recipes {
751 return &self.text_recipes;
752 }
753 self.source_text_recipes
754 .get(source)
755 .map(|recipes| recipes.as_slice())
756 .unwrap_or(&[])
757 }
758
759 fn recipe_order_weighted_shuffled(
770 &mut self,
771 weights: &[f32],
772 rng: &mut DeterministicRng,
773 ) -> Vec<usize> {
774 weighted_recipe_order(weights, rng)
775 }
776
777 fn recipe_order_weighted_cycled(
783 &mut self,
784 weights: &[f32],
785 rr_idx: usize,
786 rng: &mut DeterministicRng,
787 ) -> Vec<usize> {
788 let base = self.recipe_order_weighted_shuffled(weights, rng);
789 if base.is_empty() {
790 return base;
791 }
792 let start = rr_idx % base.len();
793 let mut order = Vec::with_capacity(base.len());
794 order.extend_from_slice(&base[start..]);
795 order.extend_from_slice(&base[..start]);
796 order
797 }
798
799 fn text_recipe_order_weighted_shuffled(
803 &mut self,
804 weights: &[f32],
805 rng: &mut DeterministicRng,
806 ) -> Vec<usize> {
807 weighted_recipe_order(weights, rng)
808 }
809
810 fn text_recipe_order_weighted_cycled(
811 &mut self,
812 weights: &[f32],
813 rr_idx: usize,
814 rng: &mut DeterministicRng,
815 ) -> Vec<usize> {
816 let base = self.text_recipe_order_weighted_shuffled(weights, rng);
817 if base.is_empty() {
818 return base;
819 }
820 let start = rr_idx % base.len();
821 let mut order = Vec::with_capacity(base.len());
822 order.extend_from_slice(&base[start..]);
823 order.extend_from_slice(&base[..start]);
824 order
825 }
826
827 fn allowed_target_splits(&self) -> Vec<SplitLabel> {
828 self.config.allowed_splits.clone()
829 }
830
831 fn ensure_split_allowed(&self, split: SplitLabel) -> Result<(), SamplerError> {
832 let allowed = self.allowed_target_splits();
833 if allowed.contains(&split) {
834 return Ok(());
835 }
836 Err(SamplerError::Configuration(format!(
837 "requested split {:?} is not in allowed_splits {:?}",
838 split, allowed
839 )))
840 }
841
842 fn ensure_split_has_records(&mut self, target_split: SplitLabel) -> Result<(), SamplerError> {
843 let records_by_split = self.records_by_split()?;
844 if records_by_split
845 .get(&target_split)
846 .map(|records| !records.is_empty())
847 .unwrap_or(false)
848 {
849 return Ok(());
850 }
851 Err(SamplerError::Exhausted(
852 "no records available for target split".into(),
853 ))
854 }
855
856 fn records_by_split(
857 &self,
858 ) -> Result<HashMap<SplitLabel, Vec<(RecordId, SourceId)>>, SamplerError> {
859 let mut map: HashMap<SplitLabel, Vec<(RecordId, SourceId)>> = HashMap::new();
860 let mut label_cache: HashMap<RecordId, SplitLabel> = HashMap::new();
861 for (chunk_id, record_id) in &self.chunk_index {
862 let Some(record) = self.records.get(record_id) else {
863 continue;
864 };
865 let label = if let Some(label) = label_cache.get(record_id) {
866 *label
867 } else {
868 let label = match self.split_store.label_for(record_id) {
869 Some(label) => label,
870 None => self.split_store.ensure(record_id.clone())?,
871 };
872 label_cache.insert(record_id.clone(), label);
873 label
874 };
875 map.entry(label)
876 .or_default()
877 .push((chunk_id.clone(), record.source.clone()));
878 }
879 Ok(map)
880 }
881
882 fn choose_anchor_record(
883 &mut self,
884 source: Option<&str>,
885 split: SplitLabel,
886 ) -> Option<Arc<DataRecord>> {
887 if let Some(source) = source {
888 let indices = self.source_record_indices.get(source)?;
889 if indices.is_empty() {
890 return None;
891 }
892 let mut cursor = *self.source_record_cursors.get(source).unwrap_or(&0);
893 let cycle = cursor / indices.len();
894 let offset_seed = self.epoch_seed() ^ (cycle as u64);
895 let offset = (stable_hash_str(offset_seed, source) as usize) % indices.len();
896 let mut wrapped = false;
897 let mut selected: Option<Arc<DataRecord>> = None;
898 for _ in 0..indices.len() {
899 let pos = (cursor % indices.len()).saturating_add(offset) % indices.len();
900 let idx = indices[pos];
901 cursor = cursor.saturating_add(1);
902 if cursor.is_multiple_of(indices.len()) {
903 wrapped = true;
904 }
905 if let Some((_, record)) = self.records.get_index(idx) {
906 if self.split_store.label_for(&record.id) != Some(split) {
907 continue;
908 }
909 selected = Some(Arc::clone(record));
910 break;
911 }
912 }
913 self.source_record_cursors
914 .insert(source.to_string(), cursor);
915 if wrapped {
916 self.mark_source_wrapped(source);
917 }
918 return selected;
919 }
920 while let Some(chunk_id) = self.epoch_tracker.next_record(split) {
921 if let Some(record_id) = self.chunk_index.get(&chunk_id)
922 && let Some(record) = self.records.get(record_id)
923 {
924 return Some(Arc::clone(record));
925 }
926 }
927 None
928 }
929
930 fn save_sampler_state(&mut self, save_to: Option<&Path>) -> Result<(), SamplerError> {
931 if self.epoch_tracker.is_enabled() {
932 self.epoch_tracker.persist()?;
933 }
934 self.persist_source_state(save_to)?;
935 Ok(())
936 }
937
938 fn mark_source_wrapped(&mut self, source: &str) {
939 self.source_wrapped.insert(source.to_string(), true);
940 if self.source_order.is_empty() {
941 return;
942 }
943 let all_wrapped = self
944 .source_order
945 .iter()
946 .all(|name| self.source_wrapped.get(name).copied().unwrap_or(false));
947 if all_wrapped {
948 self.advance_source_epoch();
949 }
950 }
951
952 fn advance_source_epoch(&mut self) {
953 self.source_epoch = self.source_epoch.saturating_add(1);
960 self.ingestion.set_source_epoch(self.source_epoch);
961 self.ingestion.reset_step_counter();
963 self.source_record_cursors.clear();
967 self.source_cycle_idx = 0;
968 for source in &self.source_order {
971 self.source_wrapped.insert(source.clone(), false);
972 }
973 let _ = self.rebuild_source_index();
974 self.source_state_dirty = self.source_order.len() > 1;
975 }
976
977 fn select_temporal_neighbor(
978 &'_ self,
979 record: &DataRecord,
980 offset_days: i32,
981 ) -> Option<Arc<DataRecord>> {
982 let target = record.created_at + Duration::days(offset_days.into());
983 let key = record.taxonomy.first().cloned();
984 let record_split = self.split_store.label_for(&record.id)?;
985 self.records
986 .values()
987 .filter(|candidate| {
988 candidate.id != record.id
989 && self
990 .split_store
991 .label_for(&candidate.id)
992 .map(|label| label == record_split)
993 .unwrap_or(false)
994 && (candidate.source == record.source
995 || key
996 .as_ref()
997 .zip(candidate.taxonomy.first())
998 .map(|(a, b)| a == b)
999 .unwrap_or(false))
1000 })
1001 .min_by_key(|candidate| (candidate.created_at - target).num_seconds().abs())
1002 .cloned()
1003 }
1004
1005 fn select_negative_record(
1006 &self,
1007 anchor_record: &DataRecord,
1008 strategy: &NegativeStrategy,
1009 anchor_query_text: Option<&str>,
1010 rng: &mut dyn rand::RngCore,
1011 ) -> Option<(Arc<DataRecord>, bool)> {
1012 let anchor_split = self.split_store.label_for(&anchor_record.id)?;
1013
1014 let in_anchor_split = |candidate: &DataRecord| {
1015 self.split_store
1016 .label_for(&candidate.id)
1017 .map(|label| label == anchor_split)
1018 .unwrap_or(false)
1019 };
1020
1021 match strategy {
1022 NegativeStrategy::WrongArticle => {
1023 let anchor_date =
1024 taxonomy_value(anchor_record, META_FIELD_DATE).map(|d| d.to_string());
1025 let mut same_date: Vec<Arc<DataRecord>> = self
1026 .records
1027 .values()
1028 .filter(|candidate| {
1029 candidate.source == anchor_record.source
1030 && candidate.id != anchor_record.id
1031 && in_anchor_split(candidate)
1032 })
1033 .filter(|candidate| {
1034 anchor_date
1035 .as_deref()
1036 .zip(taxonomy_value(candidate, META_FIELD_DATE))
1037 .map(|(a, b)| a == b)
1038 .unwrap_or(false)
1039 })
1040 .cloned()
1041 .collect();
1042 if same_date.is_empty() {
1043 same_date = self
1044 .records
1045 .values()
1046 .filter(|candidate| {
1047 candidate.source == anchor_record.source
1048 && candidate.id != anchor_record.id
1049 && in_anchor_split(candidate)
1050 })
1051 .cloned()
1052 .collect();
1053 }
1054 if !same_date.is_empty() {
1055 return self.negative_backend.choose_negative(
1056 anchor_record,
1057 anchor_split,
1058 same_date,
1059 false,
1060 anchor_query_text,
1061 rng,
1062 );
1063 }
1064 let pool = self
1065 .records
1066 .values()
1067 .filter(|candidate| {
1068 candidate.id != anchor_record.id && in_anchor_split(candidate)
1069 })
1070 .cloned()
1071 .collect::<Vec<_>>();
1072 self.negative_backend.choose_negative(
1073 anchor_record,
1074 anchor_split,
1075 pool,
1076 true,
1077 anchor_query_text,
1078 rng,
1079 )
1080 }
1081 NegativeStrategy::WrongPublicationDate => {
1082 let anchor_date =
1083 taxonomy_value(anchor_record, META_FIELD_DATE).map(|d| d.to_string());
1084 let pool: Vec<Arc<DataRecord>> = self
1085 .records
1086 .values()
1087 .filter(|candidate| {
1088 candidate.source == anchor_record.source
1089 && candidate.id != anchor_record.id
1090 && in_anchor_split(candidate)
1091 })
1092 .filter(|candidate| {
1093 match (
1094 anchor_date.as_deref(),
1095 taxonomy_value(candidate, META_FIELD_DATE),
1096 ) {
1097 (Some(anchor), Some(candidate_date)) => anchor != candidate_date,
1098 (Some(_), None) => true,
1099 (None, Some(_)) => true,
1100 (None, None) => false,
1101 }
1102 })
1103 .cloned()
1104 .collect();
1105 if pool.is_empty() {
1106 let fallback_pool = self
1109 .records
1110 .values()
1111 .filter(|candidate| {
1112 candidate.id != anchor_record.id && in_anchor_split(candidate)
1113 })
1114 .cloned()
1115 .collect::<Vec<_>>();
1116
1117 return self.negative_backend.choose_negative(
1118 anchor_record,
1119 anchor_split,
1120 fallback_pool,
1121 true,
1122 anchor_query_text,
1123 rng,
1124 );
1125 }
1126
1127 self.negative_backend.choose_negative(
1128 anchor_record,
1129 anchor_split,
1130 pool,
1131 false,
1132 anchor_query_text,
1133 rng,
1134 )
1135 }
1136 NegativeStrategy::QuestionAnswerMismatch => {
1137 let pool: Vec<Arc<DataRecord>> = self
1138 .records
1139 .values()
1140 .filter(|candidate| {
1141 candidate.source == anchor_record.source
1142 && candidate.id != anchor_record.id
1143 && in_anchor_split(candidate)
1144 })
1145 .cloned()
1146 .collect();
1147 if pool.is_empty() {
1148 let fallback_pool = self
1151 .records
1152 .values()
1153 .filter(|candidate| {
1154 candidate.id != anchor_record.id && in_anchor_split(candidate)
1155 })
1156 .cloned()
1157 .collect::<Vec<_>>();
1158
1159 return self.negative_backend.choose_negative(
1160 anchor_record,
1161 anchor_split,
1162 fallback_pool,
1163 true,
1164 anchor_query_text,
1165 rng,
1166 );
1167 }
1168
1169 self.negative_backend.choose_negative(
1170 anchor_record,
1171 anchor_split,
1172 pool,
1173 false,
1174 anchor_query_text,
1175 rng,
1176 )
1177 }
1178 }
1179 }
1180
1181 fn is_auto_chunk_pair_recipe(recipe: &TripletRecipe) -> bool {
1183 recipe.name.as_ref() == AUTO_INJECTED_LONG_SECTION_CHUNK_PAIR_RECIPE_NAME
1184 }
1185
1186 fn select_anchor_positive_pair(
1190 &mut self,
1191 record: &DataRecord,
1192 anchor_selector: &Selector,
1193 positive_selector: &Selector,
1194 enforce_window_pair: bool,
1195 ) -> Option<(RecordChunk, RecordChunk)> {
1196 let mut anchor_chunk = self.select_chunk(record, anchor_selector)?;
1197 let mut positive_chunk = self.select_chunk(record, positive_selector)?;
1198 if anchor_selector == positive_selector {
1199 let mut retries = 0usize;
1200 while !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair)
1201 && retries < SAME_SELECTOR_PAIR_RETRY_LIMIT
1202 {
1203 let Some(redraw_anchor) = self.select_chunk(record, anchor_selector) else {
1204 break;
1205 };
1206 let Some(redraw_positive) = self.select_chunk(record, positive_selector) else {
1207 break;
1208 };
1209 anchor_chunk = redraw_anchor;
1210 positive_chunk = redraw_positive;
1211 retries += 1;
1212 }
1213 if !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair) {
1214 return None;
1215 }
1216 }
1217 Some((anchor_chunk, positive_chunk))
1218 }
1219
1220 fn select_distinct_window_pair_for_auto_recipe(
1230 &mut self,
1231 recipe: &TripletRecipe,
1232 record: &DataRecord,
1233 ) -> Option<(RecordChunk, RecordChunk)> {
1234 if recipe.anchor != recipe.positive_selector {
1235 return None;
1236 }
1237 self.select_anchor_positive_pair(record, &recipe.anchor, &recipe.positive_selector, true)
1238 }
1239
1240 fn record_has_at_least_two_window_chunks_for_selector(
1245 &self,
1246 record: &DataRecord,
1247 selector: &Selector,
1248 ) -> bool {
1249 let section_indices: Vec<usize> = match selector {
1250 Selector::Role(role) => record
1251 .sections
1252 .iter()
1253 .enumerate()
1254 .filter(|(_, section)| roles_match(role, §ion.role))
1255 .map(|(idx, _)| idx)
1256 .collect(),
1257 Selector::Paragraph(idx) => {
1258 if *idx < record.sections.len() {
1259 vec![*idx]
1260 } else {
1261 Vec::new()
1262 }
1263 }
1264 Selector::Random => (0..record.sections.len()).collect(),
1265 Selector::TemporalOffset(_) => return false,
1266 };
1267
1268 let mut window_count = 0usize;
1269 for section_idx in section_indices {
1270 let Some(section) = record.sections.get(section_idx) else {
1271 continue;
1272 };
1273 let chunks = self.materialize_chunks(record, section_idx, section);
1274 window_count += chunks
1275 .iter()
1276 .filter(|chunk| matches!(chunk.view, ChunkView::Window { .. }))
1277 .count();
1278 if window_count >= 2 {
1279 return true;
1280 }
1281 }
1282 false
1283 }
1284
1285 fn build_triplet_with_selector_pair_policy(
1287 &mut self,
1288 recipe: &TripletRecipe,
1289 record: &DataRecord,
1290 enforce_window_pair: bool,
1291 rng: &mut DeterministicRng,
1292 ) -> Option<SampleTriplet> {
1293 let (mut anchor_chunk, mut positive_chunk) = self.select_anchor_positive_pair(
1294 record,
1295 &recipe.anchor,
1296 &recipe.positive_selector,
1297 enforce_window_pair,
1298 )?;
1299 let anchor_raw_text = anchor_chunk.text.clone();
1300 self.decorate_chunk(record, &mut anchor_chunk, rng);
1301 self.decorate_chunk(record, &mut positive_chunk, rng);
1302 self.finalize_triplet_with_negative(
1305 recipe,
1306 record,
1307 anchor_chunk,
1308 positive_chunk,
1309 &anchor_raw_text,
1310 rng,
1311 )
1312 }
1313
1314 fn make_auto_chunk_pair_triplet_with_anchor(
1324 &mut self,
1325 recipe: &TripletRecipe,
1326 record: &DataRecord,
1327 rng: &mut DeterministicRng,
1328 ) -> Option<SampleTriplet> {
1329 if !self.record_has_at_least_two_window_chunks_for_selector(record, &recipe.anchor) {
1330 return None;
1331 }
1332 let (mut anchor_chunk, mut positive_chunk) =
1333 self.select_distinct_window_pair_for_auto_recipe(recipe, record)?;
1334 let anchor_raw_text = anchor_chunk.text.clone();
1335 self.decorate_chunk(record, &mut anchor_chunk, rng);
1336 self.decorate_chunk(record, &mut positive_chunk, rng);
1337 self.finalize_triplet_with_negative(
1338 recipe,
1339 record,
1340 anchor_chunk,
1341 positive_chunk,
1342 &anchor_raw_text,
1343 rng,
1344 )
1345 }
1346
1347 fn make_standard_triplet_with_anchor(
1348 &mut self,
1349 recipe: &TripletRecipe,
1350 record: &DataRecord,
1351 rng: &mut DeterministicRng,
1352 ) -> Option<SampleTriplet> {
1353 self.build_triplet_with_selector_pair_policy(recipe, record, false, rng)
1354 }
1355
1356 fn finalize_triplet_with_negative(
1375 &mut self,
1376 recipe: &TripletRecipe,
1377 record: &DataRecord,
1378 anchor_chunk: RecordChunk,
1379 positive_chunk: RecordChunk,
1380 anchor_raw_text: &str,
1381 rng: &mut DeterministicRng,
1382 ) -> Option<SampleTriplet> {
1383 let (negative_record, fallback_used) = self.select_negative_record(
1384 record,
1385 &recipe.negative_strategy,
1386 Some(anchor_raw_text),
1387 rng,
1388 )?;
1389 let mut negative_chunk = self.select_chunk(&negative_record, &recipe.negative_selector)?;
1390 self.decorate_chunk(&negative_record, &mut negative_chunk, rng);
1391
1392 let (anchor_chunk, positive_chunk) = if rng.next_u64() & ANCHOR_POSITIVE_SWAP_MASK == 0 {
1394 (positive_chunk, anchor_chunk)
1395 } else {
1396 (anchor_chunk, positive_chunk)
1397 };
1398
1399 if (!recipe.allow_same_anchor_positive && anchor_chunk.text == positive_chunk.text)
1410 || negative_chunk.text == positive_chunk.text
1411 || negative_chunk.text == anchor_chunk.text
1412 {
1413 return None;
1414 }
1415
1416 let chunk_weight =
1417 self.triplet_chunk_weight(recipe, &anchor_chunk, &positive_chunk, &negative_chunk);
1418 let weight = recipe.weight * chunk_weight;
1419 let recipe_name = if fallback_used {
1420 format!("{}_fallback_same_split", recipe.name)
1421 } else {
1422 recipe.name.to_string()
1423 };
1424 Some(SampleTriplet {
1425 recipe: recipe_name,
1426 anchor: anchor_chunk,
1427 positive: positive_chunk,
1428 negative: negative_chunk,
1429 weight,
1430 instruction: recipe.instruction.as_ref().map(|s| s.to_string()),
1431 })
1432 }
1433
1434 fn make_triplet_with_anchor(
1435 &mut self,
1436 recipe: &TripletRecipe,
1437 record: &DataRecord,
1438 rng: &mut DeterministicRng,
1439 ) -> Option<SampleTriplet> {
1440 if Self::is_auto_chunk_pair_recipe(recipe) {
1441 return self.make_auto_chunk_pair_triplet_with_anchor(recipe, record, rng);
1442 }
1443 self.make_standard_triplet_with_anchor(recipe, record, rng)
1444 }
1445
1446 fn make_text_sample_for_split(
1447 &mut self,
1448 recipe: &TextRecipe,
1449 source: Option<&str>,
1450 split: SplitLabel,
1451 rng: &mut DeterministicRng,
1452 ) -> Option<TextSample> {
1453 let record = self.choose_anchor_record(source, split)?;
1454 let mut chunk = self.select_chunk(&record, &recipe.selector)?;
1455 self.decorate_chunk(&record, &mut chunk, rng);
1456 let weight = recipe.weight * self.chunk_weight(&chunk);
1457 Some(TextSample {
1458 recipe: recipe.name.to_string(),
1459 chunk,
1460 weight,
1461 instruction: recipe.instruction.as_ref().map(|s| s.to_string()),
1462 })
1463 }
1464
1465 fn chunk_weight(&self, chunk: &RecordChunk) -> f32 {
1466 chunk_weight(&self.config.chunking, chunk)
1467 }
1468
1469 fn triplet_chunk_weight(
1470 &self,
1471 recipe: &TripletRecipe,
1472 anchor: &RecordChunk,
1473 positive: &RecordChunk,
1474 negative: &RecordChunk,
1475 ) -> f32 {
1476 let floor = self.config.chunking.chunk_weight_floor;
1477 let negative_weight = negative.quality.trust.clamp(0.0, 1.0).max(floor);
1478 if Self::is_auto_chunk_pair_recipe(recipe) {
1479 let pair_trust = ((anchor.quality.trust.clamp(0.0, 1.0)
1482 + positive.quality.trust.clamp(0.0, 1.0))
1483 / 2.0)
1484 .clamp(0.0, 1.0);
1485 let pair_weight = (chunk_proximity_score(anchor, positive) * pair_trust).max(floor);
1486 return (pair_weight + pair_weight + negative_weight) / 3.0;
1488 }
1489 let pair_proximity = chunk_proximity_score(anchor, positive);
1492 let anchor_weight = (self.chunk_weight(anchor) * pair_proximity).max(floor);
1493 let positive_weight = (self.chunk_weight(positive) * pair_proximity).max(floor);
1494 (anchor_weight + positive_weight + negative_weight) / 3.0
1495 }
1496
1497 fn decorate_chunk(
1498 &mut self,
1499 record: &DataRecord,
1500 chunk: &mut RecordChunk,
1501 rng: &mut DeterministicRng,
1502 ) {
1503 chunk.kvp_meta = record
1504 .meta_prefix
1505 .as_ref()
1506 .map(|s| s.all_metadata())
1507 .unwrap_or_default();
1508 if let Some(spec) = record.meta_prefix.as_ref()
1509 && let Some(prefix) = spec.sample(rng)
1510 {
1511 let body_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&chunk.text);
1512 let prefix_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&prefix);
1513 let total_tokens = prefix_tokens.len() + body_tokens.len();
1514 let max_window = self.config.chunking.max_window_tokens;
1515 if max_window > 0 && total_tokens > max_window {
1516 if prefix_tokens.len() >= max_window {
1517 chunk.text = prefix_tokens
1518 .into_iter()
1519 .take(max_window)
1520 .collect::<Vec<_>>()
1521 .join(" ");
1522 chunk.tokens_estimate = max_window;
1523 } else {
1524 let remaining = max_window - prefix_tokens.len();
1525 let trimmed_body: Vec<&str> = body_tokens.into_iter().take(remaining).collect();
1526 chunk.text =
1527 format!("{}{}{}", prefix, platform_newline(), trimmed_body.join(" "));
1528 chunk.tokens_estimate = max_window;
1529 }
1530 } else {
1531 chunk.text = format!("{}{}{}", prefix, platform_newline(), chunk.text);
1532 chunk.tokens_estimate = total_tokens;
1533 }
1534 }
1535 }
1536
1537 fn select_chunk_parallel(
1541 &self,
1542 record: &DataRecord,
1543 selector: &Selector,
1544 rng: &mut DeterministicRng,
1545 ) -> Option<RecordChunk> {
1546 match selector {
1547 Selector::Role(role) => self.select_role_parallel(record, role, rng),
1548 Selector::Paragraph(idx) => record.sections.get(*idx).and_then(|section| {
1549 let pool = self.materialize_chunks(record, *idx, section);
1550 if pool.is_empty() {
1551 return None;
1552 }
1553 let i = rng.random_range(0..pool.len());
1554 pool.into_iter().nth(i)
1555 }),
1556 Selector::TemporalOffset(offset) => self
1557 .select_temporal_neighbor(record, *offset)
1558 .and_then(|neighbor| {
1559 self.select_role_parallel(&neighbor, &SectionRole::Context, rng)
1560 }),
1561 Selector::Random => {
1562 if record.sections.is_empty() {
1563 return None;
1564 }
1565 let idx = rng.random_range(0..record.sections.len());
1566 record.sections.get(idx).and_then(|section| {
1567 let pool = self.materialize_chunks(record, idx, section);
1568 if pool.is_empty() {
1569 return None;
1570 }
1571 let i = rng.random_range(0..pool.len());
1572 pool.into_iter().nth(i)
1573 })
1574 }
1575 }
1576 }
1577
1578 fn select_role_parallel(
1580 &self,
1581 record: &DataRecord,
1582 role: &SectionRole,
1583 rng: &mut DeterministicRng,
1584 ) -> Option<RecordChunk> {
1585 let indices: Vec<usize> = record
1586 .sections
1587 .iter()
1588 .enumerate()
1589 .filter(|(_, s)| roles_match(role, &s.role))
1590 .map(|(i, _)| i)
1591 .collect();
1592 if indices.is_empty() {
1593 return None;
1594 }
1595 let start = rng.random_range(0..indices.len());
1596 for offset in 0..indices.len() {
1597 let section_idx = indices[(start + offset) % indices.len()];
1598 let section = &record.sections[section_idx];
1599 let pool = self.materialize_chunks(record, section_idx, section);
1600 if !pool.is_empty() {
1601 let i = rng.random_range(0..pool.len());
1602 return pool.into_iter().nth(i);
1603 }
1604 }
1605 None
1606 }
1607
1608 fn decorate_chunk_parallel(
1610 &self,
1611 record: &DataRecord,
1612 chunk: &mut RecordChunk,
1613 rng: &mut DeterministicRng,
1614 ) {
1615 chunk.kvp_meta = record
1616 .meta_prefix
1617 .as_ref()
1618 .map(|s| s.all_metadata())
1619 .unwrap_or_default();
1620 if let Some(spec) = record.meta_prefix.as_ref()
1621 && let Some(prefix) = spec.sample(rng)
1622 {
1623 let body_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&chunk.text);
1624 let prefix_tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&prefix);
1625 let total_tokens = prefix_tokens.len() + body_tokens.len();
1626 let max_window = self.config.chunking.max_window_tokens;
1627 if max_window > 0 && total_tokens > max_window {
1628 if prefix_tokens.len() >= max_window {
1629 chunk.text = prefix_tokens
1630 .into_iter()
1631 .take(max_window)
1632 .collect::<Vec<_>>()
1633 .join(" ");
1634 chunk.tokens_estimate = max_window;
1635 } else {
1636 let remaining = max_window - prefix_tokens.len();
1637 let trimmed_body: Vec<&str> = body_tokens.into_iter().take(remaining).collect();
1638 chunk.text =
1639 format!("{}{}{}", prefix, platform_newline(), trimmed_body.join(" "));
1640 chunk.tokens_estimate = max_window;
1641 }
1642 } else {
1643 chunk.text = format!("{}{}{}", prefix, platform_newline(), chunk.text);
1644 chunk.tokens_estimate = total_tokens;
1645 }
1646 }
1647 }
1648
1649 fn select_anchor_positive_parallel(
1651 &self,
1652 record: &DataRecord,
1653 anchor_selector: &Selector,
1654 positive_selector: &Selector,
1655 enforce_window_pair: bool,
1656 rng: &mut DeterministicRng,
1657 ) -> Option<(RecordChunk, RecordChunk)> {
1658 let anchor_chunk = self.select_chunk_parallel(record, anchor_selector, rng)?;
1659 let mut positive_chunk = self.select_chunk_parallel(record, positive_selector, rng)?;
1660 if anchor_selector == positive_selector {
1661 let mut retries = 0usize;
1662 while !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair)
1663 && retries < SAME_SELECTOR_PAIR_RETRY_LIMIT
1664 {
1665 positive_chunk = self.select_chunk_parallel(record, positive_selector, rng)?;
1666 retries += 1;
1667 }
1668 if !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, enforce_window_pair) {
1669 return None;
1670 }
1671 }
1672 Some((anchor_chunk, positive_chunk))
1674 }
1675
1676 fn select_anchor_positive_for_recipe(
1679 &self,
1680 recipe: &TripletRecipe,
1681 anchor_record: &DataRecord,
1682 rng: &mut DeterministicRng,
1683 ) -> Option<(RecordChunk, RecordChunk, String)> {
1684 if Self::is_auto_chunk_pair_recipe(recipe) {
1685 if !self
1686 .record_has_at_least_two_window_chunks_for_selector(anchor_record, &recipe.anchor)
1687 {
1688 return None;
1689 }
1690 let mut anchor_chunk =
1691 self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
1692 let mut positive_chunk =
1693 self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
1694 let mut tries = 0usize;
1695 while !same_selector_pair_is_valid(&anchor_chunk, &positive_chunk, true) {
1696 tries += 1;
1697 if tries >= SAME_SELECTOR_PAIR_RETRY_LIMIT {
1698 return None;
1699 }
1700 anchor_chunk = self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
1701 positive_chunk = self.select_chunk_parallel(anchor_record, &recipe.anchor, rng)?;
1702 }
1703 let anchor_raw_text = anchor_chunk.text.clone();
1704 self.decorate_chunk_parallel(anchor_record, &mut anchor_chunk, rng);
1705 self.decorate_chunk_parallel(anchor_record, &mut positive_chunk, rng);
1706 return Some((anchor_chunk, positive_chunk, anchor_raw_text));
1707 }
1708 let (mut anchor_chunk, mut positive_chunk) = self.select_anchor_positive_parallel(
1709 anchor_record,
1710 &recipe.anchor,
1711 &recipe.positive_selector,
1712 false,
1713 rng,
1714 )?;
1715 let anchor_raw_text = anchor_chunk.text.clone();
1716 self.decorate_chunk_parallel(anchor_record, &mut anchor_chunk, rng);
1717 self.decorate_chunk_parallel(anchor_record, &mut positive_chunk, rng);
1718 Some((anchor_chunk, positive_chunk, anchor_raw_text))
1719 }
1720
1721 fn select_chunk(&mut self, record: &DataRecord, selector: &Selector) -> Option<RecordChunk> {
1722 match selector {
1723 Selector::Role(role) => self.select_by_role(record, role),
1724 Selector::Paragraph(idx) => record.sections.get(*idx).and_then(|section| {
1725 let pool = self.materialize_chunks(record, *idx, section);
1726 self.next_chunk_from_pool(&record.id, *idx, pool)
1727 }),
1728 Selector::TemporalOffset(offset) => self
1729 .select_temporal_neighbor(record, *offset)
1730 .and_then(|neighbor| self.select_by_role(&neighbor, &SectionRole::Context)),
1731 Selector::Random => {
1732 if record.sections.is_empty() {
1733 return None;
1734 }
1735 let idx = self.rng.random_range(0..record.sections.len());
1736 record.sections.get(idx).and_then(|section| {
1737 let pool = self.materialize_chunks(record, idx, section);
1738 self.next_chunk_from_pool(&record.id, idx, pool)
1739 })
1740 }
1741 }
1742 }
1743
1744 fn select_by_role(&mut self, record: &DataRecord, role: &SectionRole) -> Option<RecordChunk> {
1745 let indices: Vec<usize> = record
1746 .sections
1747 .iter()
1748 .enumerate()
1749 .filter(|(_, section)| roles_match(role, §ion.role))
1750 .map(|(idx, _)| idx)
1751 .collect();
1752 if indices.is_empty() {
1753 return None;
1754 }
1755 let key = role_cursor_key(&record.id, role);
1756 let start_offset = self
1757 .role_cursors
1758 .get(&key)
1759 .and_then(|last_idx| indices.iter().position(|idx| idx == last_idx))
1760 .map(|pos| (pos + 1) % indices.len())
1761 .unwrap_or_else(|| {
1762 let seed_key = format!("{}::{}", key.0, key.1);
1766 (stable_hash_str(self.epoch_seed(), &seed_key) as usize) % indices.len()
1767 });
1768 for offset in 0..indices.len() {
1769 let section_idx = indices[(start_offset + offset) % indices.len()];
1770 let section = &record.sections[section_idx];
1771 let pool = self.materialize_chunks(record, section_idx, section);
1772 if let Some(chunk) = self.next_chunk_from_pool(&record.id, section_idx, pool) {
1773 self.role_cursors.insert(key.clone(), section_idx);
1774 return Some(chunk);
1775 }
1776 }
1777 None
1778 }
1779
1780 fn materialize_chunks(
1793 &self,
1794 record: &DataRecord,
1795 section_idx: usize,
1796 section: &RecordSection,
1797 ) -> Vec<RecordChunk> {
1798 self.chunker
1799 .materialize(&self.config.chunking, record, section_idx, section)
1800 }
1801
1802 fn build_derived_text_recipes(recipes: &[TripletRecipe]) -> Vec<TextRecipe> {
1803 let mut derived = Vec::new();
1804 for recipe in recipes {
1805 let base = recipe.name.as_ref();
1806 derived.push(TextRecipe {
1807 name: Cow::Owned(format!("{base}_anchor")),
1808 selector: recipe.anchor.clone(),
1809 weight: recipe.weight,
1810 instruction: None,
1811 });
1812 derived.push(TextRecipe {
1813 name: Cow::Owned(format!("{base}_positive")),
1814 selector: recipe.positive_selector.clone(),
1815 weight: recipe.weight,
1816 instruction: None,
1817 });
1818 derived.push(TextRecipe {
1819 name: Cow::Owned(format!("{base}_negative")),
1820 selector: recipe.negative_selector.clone(),
1821 weight: recipe.weight,
1822 instruction: None,
1823 });
1824 }
1825 derived
1826 }
1827
1828 fn record_has_long_anchor_or_context_section(&self, record: &DataRecord) -> bool {
1831 let window = self.config.chunking.max_window_tokens;
1832 if window == 0 {
1833 return false;
1834 }
1835 record.sections.iter().any(|section| {
1836 matches!(section.role, SectionRole::Anchor | SectionRole::Context)
1837 && WhitespaceTokenizer.token_count(§ion.text) > window
1838 })
1839 }
1840
1841 fn sync_records_from_cache(&mut self) -> Result<(), SamplerError> {
1842 let mut snapshot = self.ingestion.all_records_snapshot();
1843 snapshot.sort_by(|a, b| a.id.cmp(&b.id));
1844 let old_ids: HashSet<&str> = self.records.keys().map(|s| s.as_str()).collect();
1847 let incoming_ids: HashSet<&str> = snapshot.iter().map(|r| r.id.as_str()).collect();
1848 let pool_changed = old_ids != incoming_ids;
1849
1850 self.records.clear();
1855 self.sources_with_long_sections.clear();
1859 if pool_changed {
1864 self.emitted_text_hashes.clear();
1865 }
1866 self.negative_backend.on_sync_start();
1871 for record in snapshot {
1872 if self.split_store.label_for(&record.id).is_none() {
1873 self.split_store.ensure(record.id.clone())?;
1874 }
1875 if self.record_has_long_anchor_or_context_section(&record) {
1876 self.sources_with_long_sections
1878 .insert(record.source.clone());
1879 }
1880 self.records.insert(record.id.clone(), Arc::new(record));
1881 }
1882 self.prune_cursor_state();
1883 self.rebuild_chunk_index();
1884 self.rebuild_source_index()?;
1885 Ok(())
1886 }
1887
1888 fn ingest_internal_for_split(&mut self, target_split: SplitLabel) -> Result<(), SamplerError> {
1889 if !self.ingestion.has_sources() {
1890 return Ok(());
1891 }
1892 if !self.ingestion_cursors_loaded {
1893 if let Some(state) = self.split_store.load_sampler_state()? {
1894 self.ingestion.load_cursors(&state.source_stream_cursors);
1895 self.ingestion.set_source_epoch(state.source_epoch);
1896 }
1897 self.ingestion_cursors_loaded = true;
1898 }
1899 if self.ingestion.all_caches_empty() {
1900 self.ingestion.refresh_all();
1901 } else {
1902 self.ingestion.advance(self.config.batch_size);
1903 }
1904 let observed = self.ingestion.total_ingest_count();
1905 if observed == self.last_observed_ingest && !self.records.is_empty() {
1906 return Ok(());
1907 }
1908 self.last_observed_ingest = observed;
1909 self.sync_records_from_cache()?;
1910 let max_window_tokens = self.config.chunking.max_window_tokens;
1911 self.negative_backend.on_records_refreshed(
1912 &self.records,
1913 max_window_tokens,
1914 &|id| self.split_store.label_for(id),
1915 self.ingestion.last_refreshed_sources(),
1916 );
1917 self.epoch_tracker.ensure_loaded()?;
1920 let records_by_split = self.records_by_split()?;
1921 self.epoch_tracker
1922 .reconcile(target_split, &records_by_split);
1923 self.ensure_source_state()?;
1924 Ok(())
1925 }
1926
1927 #[cfg(test)]
1928 fn ingest_internal(&mut self, split: SplitLabel) -> Result<(), SamplerError> {
1929 self.ingest_internal_for_split(split)
1930 }
1931
1932 fn ingest_internal_with_weights_for_split(
1933 &mut self,
1934 target_split: SplitLabel,
1935 weights: &HashMap<SourceId, f32>,
1936 ) -> Result<(), SamplerError> {
1937 if !self.ingestion.has_sources() {
1938 return Ok(());
1939 }
1940 if !self.ingestion_cursors_loaded {
1941 if let Some(state) = self.split_store.load_sampler_state()? {
1942 self.ingestion.load_cursors(&state.source_stream_cursors);
1943 self.ingestion.set_source_epoch(state.source_epoch);
1944 }
1945 self.ingestion_cursors_loaded = true;
1946 }
1947 if self.ingestion.all_caches_empty() {
1948 self.ingestion.refresh_all_with_weights(weights)?;
1949 } else {
1950 self.ingestion
1951 .advance_with_weights(self.config.batch_size, weights)?;
1952 }
1953 let observed = self.ingestion.total_ingest_count();
1954 if observed == self.last_observed_ingest && !self.records.is_empty() {
1955 return Ok(());
1956 }
1957 self.last_observed_ingest = observed;
1958 self.sync_records_from_cache()?;
1959 let max_window_tokens = self.config.chunking.max_window_tokens;
1960 self.negative_backend.on_records_refreshed(
1961 &self.records,
1962 max_window_tokens,
1963 &|id| self.split_store.label_for(id),
1964 self.ingestion.last_refreshed_sources(),
1965 );
1966 self.epoch_tracker.ensure_loaded()?;
1967 let records_by_split = self.records_by_split()?;
1968 self.epoch_tracker
1969 .reconcile(target_split, &records_by_split);
1970 self.ensure_source_state()?;
1971 Ok(())
1972 }
1973
1974 fn ingest_with_weights_fallback(
1979 &mut self,
1980 target_split: SplitLabel,
1981 weights: Option<&HashMap<SourceId, f32>>,
1982 ) -> Result<(), SamplerError> {
1983 match weights {
1984 Some(weights)
1985 if !weights.is_empty()
1986 && !weights
1987 .values()
1988 .all(|&w| w == *weights.values().next().unwrap()) =>
1989 {
1990 self.ingest_internal_with_weights_for_split(target_split, weights)?
1991 }
1992 _ => self.ingest_internal_for_split(target_split)?,
1993 }
1994 Ok(())
1995 }
1996
1997 fn force_ingest_refresh_with_weights_for_split(
1998 &mut self,
1999 target_split: SplitLabel,
2000 weights: &HashMap<SourceId, f32>,
2001 ) -> Result<(), SamplerError> {
2002 if !self.ingestion.has_sources() {
2003 return Ok(());
2004 }
2005 if !self.ingestion_cursors_loaded {
2006 if let Some(state) = self.split_store.load_sampler_state()? {
2007 self.ingestion.load_cursors(&state.source_stream_cursors);
2008 self.ingestion.set_source_epoch(state.source_epoch);
2009 }
2010 self.ingestion_cursors_loaded = true;
2011 }
2012 self.ingestion.force_refresh_all_with_weights(weights)?;
2013 self.last_observed_ingest = self.ingestion.total_ingest_count();
2014 self.sync_records_from_cache()?;
2015 let max_window_tokens = self.config.chunking.max_window_tokens;
2016 self.negative_backend.on_records_refreshed(
2017 &self.records,
2018 max_window_tokens,
2019 &|id| self.split_store.label_for(id),
2020 self.ingestion.last_refreshed_sources(),
2021 );
2022 self.epoch_tracker.ensure_loaded()?;
2023 let records_by_split = self.records_by_split()?;
2024 self.epoch_tracker
2025 .reconcile(target_split, &records_by_split);
2026 self.ensure_source_state()?;
2027 Ok(())
2028 }
2029
2030 fn sample_source_triplet_candidate(
2040 &mut self,
2041 source: &str,
2042 target_split: SplitLabel,
2043 recipe_orders: &mut HashMap<RecipeKey, Vec<usize>>,
2044 recipe_positions: &mut HashMap<RecipeKey, usize>,
2045 rng: &mut DeterministicRng,
2046 ) -> (Option<(TripletRecipe, SampleTriplet)>, usize) {
2047 let (recipes, _auto_injected) = self.resolve_source_triplet_plan(source);
2050 if recipes.is_empty() {
2051 return (None, 0);
2052 }
2053 if !recipe_orders.contains_key(source) {
2054 let recipe_weights: Vec<f32> = recipes.iter().map(|r| r.weight).collect();
2055 let order =
2056 self.recipe_order_weighted_cycled(&recipe_weights, self.triplet_recipe_rr_idx, rng);
2057 recipe_orders.insert(source.to_string(), order);
2058 }
2059 let order = recipe_orders
2060 .get(source)
2061 .expect("recipe order missing for source");
2062 let pos = recipe_positions.entry(source.to_string()).or_insert(0);
2063 let Some(anchor) = self.choose_anchor_record(Some(source), target_split) else {
2064 return (None, 0);
2065 };
2066
2067 let mut attempts = 0usize;
2068 for offset in 0..order.len() {
2069 let idx = order[(*pos + offset) % order.len()];
2070 attempts = attempts.saturating_add(1);
2071 let recipe = recipes[idx].clone();
2072 if let Some(sample) = self.make_triplet_with_anchor(&recipe, &anchor, rng) {
2076 *pos = (*pos + offset + 1) % order.len();
2077 return (Some((recipe, sample)), attempts);
2078 }
2079 }
2080
2081 (None, attempts)
2082 }
2083
2084 fn next_pair_batch_inner_with_weights(
2085 &mut self,
2086 target_split: SplitLabel,
2087 weights: Option<&HashMap<SourceId, f32>>,
2088 ) -> Result<SampleBatch, SamplerError> {
2089 if let Some(weights) = weights {
2090 if weights.is_empty()
2091 || weights
2092 .values()
2093 .all(|&w| w == *weights.values().next().unwrap())
2094 {
2095 self.ingest_internal_for_split(target_split)?;
2096 } else {
2097 self.ingest_internal_with_weights_for_split(target_split, weights)?;
2098 }
2099 } else {
2100 self.ingest_internal_for_split(target_split)?;
2101 }
2102 self.ensure_split_has_records(target_split)?;
2103 let sources = self.source_order.clone();
2104 if sources.is_empty() {
2105 if self.triplet_recipes.is_empty() {
2106 return Err(SamplerError::Configuration(
2107 "no triplet recipes available".into(),
2108 ));
2109 }
2110 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2111 let recipe_weights: Vec<f32> = self.triplet_recipes.iter().map(|r| r.weight).collect();
2112 let recipe_order = self.recipe_order_weighted_cycled(
2113 &recipe_weights,
2114 self.triplet_recipe_rr_idx,
2115 &mut rng,
2116 );
2117 let mut pairs = Vec::new();
2118 let mut seen = HashSet::new();
2119 let mut last_recipe_name = None;
2120 let mut recipe_pos = 0usize;
2121 let mut recipe_steps = 0usize;
2122 let attempts = self.config.batch_size * 4 * recipe_order.len().max(1);
2123 for _ in 0..attempts {
2124 if pairs.len() >= self.config.batch_size {
2125 break;
2126 }
2127 let Some(anchor) = self.choose_anchor_record(None, target_split) else {
2128 break;
2129 };
2130 let mut triplet = None;
2131 for offset in 0..recipe_order.len() {
2132 let idx = recipe_order[(recipe_pos + offset) % recipe_order.len()];
2133 recipe_steps = recipe_steps.saturating_add(1);
2134 let recipe = self.triplet_recipes[idx].clone();
2135 last_recipe_name = Some(recipe.name.clone());
2136 if let Some(sample) = self.make_triplet_with_anchor(&recipe, &anchor, &mut rng)
2137 {
2138 triplet = Some((recipe, sample));
2139 recipe_pos = (recipe_pos + offset + 1) % recipe_order.len();
2140 break;
2141 }
2142 }
2143 if let Some((recipe, triplet)) = triplet {
2144 let key = (
2145 triplet.anchor.record_id.clone(),
2146 triplet.positive.record_id.clone(),
2147 triplet.negative.record_id.clone(),
2148 );
2149 if seen.insert(key) {
2150 let SampleTriplet {
2151 recipe: triplet_recipe_name,
2152 anchor,
2153 positive,
2154 negative,
2155 weight,
2156 instruction,
2157 } = triplet;
2158 if pairs.len() < self.config.batch_size {
2159 pairs.push(SamplePair {
2160 recipe: triplet_recipe_name.clone(),
2161 anchor: anchor.clone(),
2162 positive: positive.clone(),
2163 weight,
2164 instruction: instruction.clone(),
2165 label: PairLabel::Positive,
2166 reason: None,
2167 });
2168 }
2169 if pairs.len() < self.config.batch_size {
2170 pairs.push(SamplePair {
2171 recipe: triplet_recipe_name,
2172 anchor,
2173 positive: negative,
2174 weight,
2175 instruction,
2176 label: PairLabel::Negative,
2177 reason: Some(
2178 strategy_reason(&recipe.negative_strategy).to_string(),
2179 ),
2180 });
2181 }
2182 }
2183 }
2184 }
2185 if recipe_steps > 0 {
2186 self.triplet_recipe_rr_idx =
2187 self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
2188 }
2189 self.rng = rng;
2190 pad_with_reuse(&mut pairs, self.config.batch_size);
2191 if pairs.len() == self.config.batch_size {
2192 return Ok(SampleBatch { pairs });
2193 }
2194 return Err(SamplerError::Exhausted(
2195 last_recipe_name
2196 .unwrap_or(Cow::Borrowed(RECIPE_LABEL_TRIPLETS))
2197 .to_string(),
2198 ));
2199 }
2200
2201 let mut pairs = Vec::new();
2202 let mut seen = HashSet::new();
2203 let mut source_steps = 0usize;
2204 let mut cycle = (self.source_cycle_idx / sources.len()) as u64;
2205 let mut source_idx = self.source_cycle_idx % sources.len();
2206 let mut cycle_sources = self.shuffled_source_cycle(cycle);
2207 let mut recipe_orders: HashMap<RecipeKey, Vec<usize>> = HashMap::new();
2208 let mut recipe_positions: HashMap<RecipeKey, usize> = HashMap::new();
2209 let mut recipe_steps = 0usize;
2210 let max_recipe_len = sources
2211 .iter()
2212 .map(|source| self.triplet_recipe_count_for_source(source))
2213 .max()
2214 .unwrap_or(1)
2215 .max(1);
2216 let attempts = self.config.batch_size * 4 * sources.len() * max_recipe_len;
2217 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2218 for _ in 0..attempts {
2219 if pairs.len() >= self.config.batch_size {
2220 break;
2221 }
2222 let source = cycle_sources[source_idx].as_str();
2223 let (triplet, attempts_used) = self.sample_source_triplet_candidate(
2224 source,
2225 target_split,
2226 &mut recipe_orders,
2227 &mut recipe_positions,
2228 &mut rng,
2229 );
2230 recipe_steps = recipe_steps.saturating_add(attempts_used);
2231 if let Some((recipe, triplet)) = triplet {
2232 let key = (
2233 triplet.anchor.record_id.clone(),
2234 triplet.positive.record_id.clone(),
2235 triplet.negative.record_id.clone(),
2236 );
2237 if seen.insert(key) {
2238 let SampleTriplet {
2239 recipe: triplet_recipe_name,
2240 anchor,
2241 positive,
2242 negative,
2243 weight,
2244 instruction,
2245 } = triplet;
2246 if pairs.len() < self.config.batch_size {
2247 pairs.push(SamplePair {
2248 recipe: triplet_recipe_name.clone(),
2249 anchor: anchor.clone(),
2250 positive: positive.clone(),
2251 weight,
2252 instruction: instruction.clone(),
2253 label: PairLabel::Positive,
2254 reason: None,
2255 });
2256 }
2257 if pairs.len() < self.config.batch_size {
2258 pairs.push(SamplePair {
2259 recipe: triplet_recipe_name,
2260 anchor,
2261 positive: negative,
2262 weight,
2263 instruction,
2264 label: PairLabel::Negative,
2265 reason: Some(strategy_reason(&recipe.negative_strategy).to_string()),
2266 });
2267 }
2268 }
2269 }
2270 source_idx += 1;
2271 source_steps += 1;
2272 if source_idx >= cycle_sources.len() {
2273 source_idx = 0;
2274 cycle = cycle.saturating_add(1);
2275 cycle_sources = self.shuffled_source_cycle(cycle);
2276 }
2277 }
2278 if recipe_steps > 0 {
2279 self.triplet_recipe_rr_idx = self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
2280 }
2281 self.rng = rng;
2282 pad_with_reuse(&mut pairs, self.config.batch_size);
2283 if pairs.len() == self.config.batch_size {
2284 self.source_cycle_idx = self.source_cycle_idx.saturating_add(source_steps);
2285 self.source_state_dirty = sources.len() > 1;
2286 return Ok(SampleBatch { pairs });
2287 }
2288 Err(SamplerError::Exhausted(RECIPE_LABEL_TRIPLETS.into()))
2289 }
2290
2291 fn try_emit_text_sample(&mut self, sample: &TextSample) -> bool {
2296 self.emitted_text_hashes
2297 .insert(stable_hash_str(0, &sample.chunk.text))
2298 }
2299
2300 fn next_text_batch_inner_with_weights(
2301 &mut self,
2302 target_split: SplitLabel,
2303 weights: Option<&HashMap<SourceId, f32>>,
2304 ) -> Result<TextBatch, SamplerError> {
2305 self.ingest_with_weights_fallback(target_split, weights)?;
2306 self.ensure_split_has_records(target_split)?;
2307 let sources = self.source_order.clone();
2308 if sources.is_empty() {
2309 if self.text_recipes.is_empty() {
2310 return Err(SamplerError::Configuration(
2311 "no text recipes configured".into(),
2312 ));
2313 }
2314 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2315 let recipe_weights: Vec<f32> = self.text_recipes.iter().map(|r| r.weight).collect();
2316 let recipe_order = self.text_recipe_order_weighted_cycled(
2317 &recipe_weights,
2318 self.text_recipe_rr_idx,
2319 &mut rng,
2320 );
2321 let mut samples = Vec::new();
2322 let mut last_recipe_name = None;
2323 let mut recipe_pos = 0usize;
2324 let mut recipe_steps = 0usize;
2325 let attempts = self.config.batch_size * 4 * recipe_order.len().max(1);
2326 for _ in 0..attempts {
2327 if samples.len() >= self.config.batch_size {
2328 break;
2329 }
2330 let recipe_idx = recipe_order[recipe_pos];
2331 recipe_pos = (recipe_pos + 1) % recipe_order.len();
2332 recipe_steps = recipe_steps.saturating_add(1);
2333 let recipe = self.text_recipes[recipe_idx].clone();
2334 last_recipe_name = Some(recipe.name.clone());
2335 if let Some(sample) =
2336 self.make_text_sample_for_split(&recipe, None, target_split, &mut rng)
2337 && self.try_emit_text_sample(&sample)
2338 {
2339 samples.push(sample);
2340 }
2341 }
2342 if recipe_steps > 0 {
2343 self.text_recipe_rr_idx = self.text_recipe_rr_idx.saturating_add(recipe_steps);
2344 }
2345 self.rng = rng;
2346 pad_with_reuse(&mut samples, self.config.batch_size);
2347 if samples.len() == self.config.batch_size {
2348 return Ok(TextBatch { samples });
2349 }
2350 return Err(SamplerError::Exhausted(
2351 last_recipe_name
2352 .unwrap_or(Cow::Borrowed(RECIPE_LABEL_TEXT))
2353 .to_string(),
2354 ));
2355 }
2356
2357 let mut samples = Vec::new();
2358 let mut source_steps = 0usize;
2359 let mut cycle = (self.source_cycle_idx / sources.len()) as u64;
2360 let mut idx = self.source_cycle_idx % sources.len();
2361 let mut cycle_sources = self.shuffled_source_cycle(cycle);
2362 let mut recipe_orders: HashMap<RecipeKey, Vec<usize>> = HashMap::new();
2363 let mut recipe_positions: HashMap<RecipeKey, usize> = HashMap::new();
2364 let mut recipe_steps = 0usize;
2365 let max_recipe_len = sources
2366 .iter()
2367 .map(|source| self.text_recipes_for_source(source).len())
2368 .max()
2369 .unwrap_or(1)
2370 .max(1);
2371 let attempts = self.config.batch_size * 4 * sources.len() * max_recipe_len;
2372 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2373 for _ in 0..attempts {
2374 if samples.len() >= self.config.batch_size {
2375 break;
2376 }
2377 let source = cycle_sources[idx].as_str();
2378 let recipes = self.text_recipes_for_source(source).to_vec();
2379 if recipes.is_empty() {
2380 idx += 1;
2381 source_steps += 1;
2382 if idx >= cycle_sources.len() {
2383 idx = 0;
2384 cycle = cycle.saturating_add(1);
2385 cycle_sources = self.shuffled_source_cycle(cycle);
2386 }
2387 continue;
2388 }
2389 if !recipe_orders.contains_key(source) {
2390 let recipe_weights: Vec<f32> = recipes.iter().map(|r| r.weight).collect();
2391 let order = self.text_recipe_order_weighted_cycled(
2392 &recipe_weights,
2393 self.text_recipe_rr_idx,
2394 &mut rng,
2395 );
2396 recipe_orders.insert(source.to_string(), order);
2397 }
2398 let order = recipe_orders
2399 .get(source)
2400 .expect("recipe order missing for source");
2401 let pos = recipe_positions.entry(source.to_string()).or_insert(0);
2402 let mut sample: Option<(TextRecipe, TextSample)> = None;
2403 for offset in 0..order.len() {
2404 let recipe_idx = order[(*pos + offset) % order.len()];
2405 let recipe = recipes[recipe_idx].clone();
2406 if let Some(item) =
2407 self.make_text_sample_for_split(&recipe, Some(source), target_split, &mut rng)
2408 {
2409 recipe_steps = recipe_steps.saturating_add(offset + 1);
2410 *pos = (*pos + offset + 1) % order.len();
2411 sample = Some((recipe, item));
2412 break;
2413 }
2414 }
2415 if sample.is_none() {
2416 recipe_steps = recipe_steps.saturating_add(order.len());
2417 }
2418 if let Some((_recipe, sample)) = sample
2419 && self.try_emit_text_sample(&sample)
2420 {
2421 samples.push(sample);
2422 }
2423 idx += 1;
2424 source_steps += 1;
2425 if idx >= cycle_sources.len() {
2426 idx = 0;
2427 cycle = cycle.saturating_add(1);
2428 cycle_sources = self.shuffled_source_cycle(cycle);
2429 }
2430 }
2431 if samples.len() != self.config.batch_size {
2432 pad_with_reuse(&mut samples, self.config.batch_size);
2433 }
2434 if samples.len() != self.config.batch_size {
2435 self.rng = rng;
2436 return Err(SamplerError::Exhausted(RECIPE_LABEL_TEXT.into()));
2437 }
2438 self.rng = rng;
2439 self.source_cycle_idx = self.source_cycle_idx.saturating_add(source_steps);
2440 self.source_state_dirty = sources.len() > 1;
2441 if recipe_steps > 0 {
2442 self.text_recipe_rr_idx = self.text_recipe_rr_idx.saturating_add(recipe_steps);
2443 }
2444 Ok(TextBatch { samples })
2445 }
2446
    /// Assembles one [`TripletBatch`] of `self.config.batch_size` triplets for `target_split`.
    ///
    /// `weights` is forwarded to `ingest_with_weights_fallback` before sampling.
    ///
    /// * No registered sources (`self.source_order` empty): anchors come from
    ///   `choose_anchor_record(None, ..)`. Triplets are built one at a time with
    ///   `make_triplet_with_anchor`, walking a weighted order of all `self.triplet_recipes`.
    /// * Otherwise the batch is built in three phases:
    ///   1. Sequentially plan up to `batch_size * 4 * max_recipe_len` slots (anchor, recipes,
    ///      forked RNG seed) while walking the shuffled source cycle.
    ///   2. In parallel (rayon), pick an anchor/positive chunk pair per slot, using only that
    ///      slot's forked RNG.
    ///   3. Sequentially, in slot order, pick negatives with the sampler RNG and assemble
    ///      triplets.
    ///
    /// In both paths triplets are de-duplicated by their (anchor, positive, negative) record
    /// ids. A short but non-empty batch is padded by repeating earlier triplets
    /// (`pad_with_reuse`).
    ///
    /// State: the sampler RNG and `triplet_recipe_rr_idx` are updated whether or not a full
    /// batch is produced. `source_cycle_idx` / `source_state_dirty` only change on success.
    ///
    /// # Errors
    ///
    /// * Errors from `ingest_with_weights_fallback` / `ensure_split_has_records`.
    /// * `SamplerError::Configuration` when there are no sources and no triplet recipes.
    /// * `SamplerError::Exhausted` when no triplet at all could be produced. In the no-source
    ///   path the payload is the last recipe name tried, or `RECIPE_LABEL_TRIPLETS`.
    fn next_triplet_batch_inner_with_weights(
        &mut self,
        target_split: SplitLabel,
        weights: Option<&HashMap<SourceId, f32>>,
    ) -> Result<TripletBatch, SamplerError> {
        self.ingest_with_weights_fallback(target_split, weights)?;
        self.ensure_split_has_records(target_split)?;
        let sources = self.source_order.clone();
        // No registered sources: build the batch sequentially from the global recipe list.
        if sources.is_empty() {
            if self.triplet_recipes.is_empty() {
                return Err(SamplerError::Configuration(
                    "no triplet recipes configured".into(),
                ));
            }
            // Move the RNG out of `self` so it can be passed next to `&mut self` calls.
            // It is written back before returning.
            let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
            let recipe_weights: Vec<f32> = self.triplet_recipes.iter().map(|r| r.weight).collect();
            let recipe_order = self.recipe_order_weighted_cycled(
                &recipe_weights,
                self.triplet_recipe_rr_idx,
                &mut rng,
            );
            let mut triplets = Vec::new();
            // Keys are (anchor, positive, negative) record ids.
            let mut seen = HashSet::new();
            let mut last_recipe_name = None;
            let mut recipe_pos = 0usize;
            // Counts every recipe tried; added to `triplet_recipe_rr_idx` afterwards.
            let mut recipe_steps = 0usize;
            let attempts = self.config.batch_size * 4 * recipe_order.len().max(1);
            for _ in 0..attempts {
                if triplets.len() >= self.config.batch_size {
                    break;
                }
                let Some(anchor) = self.choose_anchor_record(None, target_split) else {
                    break;
                };
                let mut triplet = None;
                // Try recipes starting at `recipe_pos`. On success the cursor moves just past
                // the recipe that worked; when none works it stays where it was.
                for offset in 0..recipe_order.len() {
                    let idx = recipe_order[(recipe_pos + offset) % recipe_order.len()];
                    recipe_steps = recipe_steps.saturating_add(1);
                    let recipe = self.triplet_recipes[idx].clone();
                    last_recipe_name = Some(recipe.name.clone());
                    if let Some(sample) = self.make_triplet_with_anchor(&recipe, &anchor, &mut rng)
                    {
                        triplet = Some(sample);
                        recipe_pos = (recipe_pos + offset + 1) % recipe_order.len();
                        break;
                    }
                }
                if let Some(triplet) = triplet {
                    let key = (
                        triplet.anchor.record_id.clone(),
                        triplet.positive.record_id.clone(),
                        triplet.negative.record_id.clone(),
                    );
                    if seen.insert(key) {
                        triplets.push(triplet);
                    }
                }
            }
            if recipe_steps > 0 {
                self.triplet_recipe_rr_idx =
                    self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
            }
            self.rng = rng;
            pad_with_reuse(&mut triplets, self.config.batch_size);
            if triplets.len() == self.config.batch_size {
                return Ok(TripletBatch { triplets });
            }
            return Err(SamplerError::Exhausted(
                last_recipe_name
                    .unwrap_or(Cow::Borrowed(RECIPE_LABEL_TRIPLETS))
                    .to_string(),
            ));
        }

        let mut triplets: Vec<SampleTriplet> = Vec::new();
        // (anchor, positive, negative) record ids of triplets already accepted.
        let mut seen: HashSet<(RecordId, RecordId, RecordId)> = HashSet::new();
        let mut source_steps = 0usize;
        let mut cycle = (self.source_cycle_idx / sources.len()) as u64;
        let mut source_idx = self.source_cycle_idx % sources.len();
        let mut cycle_sources = self.shuffled_source_cycle(cycle);
        // Counts planned slots; added to `triplet_recipe_rr_idx` at the end.
        let mut recipe_steps = 0usize;
        let max_recipe_len = sources
            .iter()
            .map(|source| self.triplet_recipe_count_for_source(source.as_str()))
            .max()
            .unwrap_or(1)
            .max(1);
        let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));

        // Phase 1 output: one planned attempt per slot.
        struct SlotPlan {
            // Anchor record chosen for this slot's source.
            anchor: Arc<DataRecord>,
            // Recipes resolved for this slot's source.
            recipes: Vec<TripletRecipe>,
            // Seed drawn from the sampler RNG; phase 2 builds a private RNG from it.
            fork_seed: u64,
        }
        // Over-provision slots: some will fail in phase 2 or 3, or be de-duplicated.
        let target_slots = self.config.batch_size * 4 * max_recipe_len;
        let mut slot_plans: Vec<SlotPlan> = Vec::with_capacity(target_slots);

        // Phase 1 (sequential): walk the source cycle. The seed is drawn before the anchor is
        // chosen, and only for sources with recipes; this order determines how the sampler RNG
        // is consumed.
        for _ in 0..target_slots {
            // Defensive: `slot_plans` grows by at most one per iteration.
            if slot_plans.len() >= target_slots {
                break;
            }
            let source = cycle_sources[source_idx].as_str();
            let (recipes, _) = self.resolve_source_triplet_plan(source);
            if !recipes.is_empty() {
                let fork_seed = rng.next_u64();
                if let Some(anchor) = self.choose_anchor_record(Some(source), target_split) {
                    slot_plans.push(SlotPlan {
                        anchor,
                        recipes,
                        fork_seed,
                    });
                    recipe_steps = recipe_steps.saturating_add(1);
                }
            }
            source_idx += 1;
            source_steps += 1;
            if source_idx >= cycle_sources.len() {
                source_idx = 0;
                cycle = cycle.saturating_add(1);
                cycle_sources = self.shuffled_source_cycle(cycle);
            }
        }
        // Restore the RNG before the parallel phase, which only borrows `self` immutably.
        self.rng = rng;

        // Phase 2 output: the first recipe (in the slot's weighted order) that yielded an
        // anchor/positive pair for the slot's anchor.
        struct SlotCandidate {
            recipe: TripletRecipe,
            anchor: Arc<DataRecord>,
            anchor_chunk: RecordChunk,
            positive_chunk: RecordChunk,
            // Third value returned by `select_anchor_positive_for_recipe`; phase 3 passes it
            // to `select_negative_record` as the anchor query text.
            anchor_raw_text: String,
        }
        // Phase 2 (parallel): each slot uses only its own forked RNG, so results do not depend
        // on thread scheduling.
        let mut raw_candidates: Vec<(usize, Option<SlotCandidate>)> = slot_plans
            .par_iter()
            .enumerate()
            .map(|(slot_idx, plan)| {
                let mut fork_rng = DeterministicRng::new(plan.fork_seed);
                let weights: Vec<f32> = plan.recipes.iter().map(|r| r.weight).collect();
                let order = weighted_recipe_order(&weights, &mut fork_rng);
                let mut candidate = None;
                for &idx in &order {
                    let recipe = &plan.recipes[idx];
                    if let Some((ac, pc, raw)) =
                        self.select_anchor_positive_for_recipe(recipe, &plan.anchor, &mut fork_rng)
                    {
                        candidate = Some(SlotCandidate {
                            recipe: recipe.clone(),
                            anchor: Arc::clone(&plan.anchor),
                            anchor_chunk: ac,
                            positive_chunk: pc,
                            anchor_raw_text: raw,
                        });
                        break;
                    }
                }
                (slot_idx, candidate)
            })
            .collect();

        // Process candidates strictly in slot order so phase 3's RNG use is deterministic.
        raw_candidates.sort_unstable_by_key(|(i, _)| *i);
        // Phase 3 (sequential): negatives and final assembly use the sampler RNG again.
        let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
        for (_, candidate) in raw_candidates {
            if triplets.len() >= self.config.batch_size {
                break;
            }
            let Some(sc) = candidate else { continue };
            let (negative_record, fallback_used) = match self.select_negative_record(
                &sc.anchor,
                &sc.recipe.negative_strategy,
                Some(&sc.anchor_raw_text),
                &mut rng,
            ) {
                Some(pair) => pair,
                None => continue,
            };
            let mut negative_chunk = match self.select_chunk_parallel(
                &negative_record,
                &sc.recipe.negative_selector,
                &mut rng,
            ) {
                Some(c) => c,
                None => continue,
            };
            self.decorate_chunk_parallel(&negative_record, &mut negative_chunk, &mut rng);

            // Swap anchor and positive roles when the masked random bits are all zero.
            let (anchor_chunk, positive_chunk) = if rng.next_u64() & ANCHOR_POSITIVE_SWAP_MASK == 0
            {
                (sc.positive_chunk, sc.anchor_chunk)
            } else {
                (sc.anchor_chunk, sc.positive_chunk)
            };

            // Reject degenerate triplets: identical anchor/positive text (unless the recipe
            // allows it), or a negative whose text equals either side.
            if (!sc.recipe.allow_same_anchor_positive && anchor_chunk.text == positive_chunk.text)
                || negative_chunk.text == positive_chunk.text
                || negative_chunk.text == anchor_chunk.text
            {
                continue;
            }

            let chunk_weight = self.triplet_chunk_weight(
                &sc.recipe,
                &anchor_chunk,
                &positive_chunk,
                &negative_chunk,
            );
            let weight = sc.recipe.weight * chunk_weight;
            // Mark triplets whose negative came from the fallback path of
            // `select_negative_record`.
            let recipe_name = if fallback_used {
                format!("{}_fallback_same_split", sc.recipe.name)
            } else {
                sc.recipe.name.to_string()
            };
            let triplet = SampleTriplet {
                recipe: recipe_name,
                anchor: anchor_chunk,
                positive: positive_chunk,
                negative: negative_chunk,
                weight,
                instruction: sc.recipe.instruction.as_ref().map(|s| s.to_string()),
            };
            let key = (
                triplet.anchor.record_id.clone(),
                triplet.positive.record_id.clone(),
                triplet.negative.record_id.clone(),
            );
            if seen.insert(key) && triplets.len() < self.config.batch_size {
                triplets.push(triplet);
            }
        }
        self.rng = rng;

        if recipe_steps > 0 {
            self.triplet_recipe_rr_idx = self.triplet_recipe_rr_idx.saturating_add(recipe_steps);
        }
        pad_with_reuse(&mut triplets, self.config.batch_size);
        // The source cursor only advances when a full batch is returned.
        if triplets.len() == self.config.batch_size {
            self.source_cycle_idx = self.source_cycle_idx.saturating_add(source_steps);
            self.source_state_dirty = sources.len() > 1;
            let batch = TripletBatch { triplets };
            return Ok(batch);
        }
        Err(SamplerError::Exhausted(RECIPE_LABEL_TRIPLETS.into()))
    }
2704
2705 #[cfg(test)]
2711 #[cfg(feature = "bm25-mining")]
2712 fn bm25_backend_mut(&mut self) -> &mut backends::Bm25Backend {
2713 self.negative_backend
2714 .as_any_mut()
2715 .downcast_mut::<backends::Bm25Backend>()
2716 .expect("bm25_backend_mut: negative_backend is Bm25Backend when bm25-mining feature is active")
2717 }
2718
2719 #[cfg(test)]
2721 fn recipe_order_weighted_shuffled_seeded(&mut self, weights: &[f32]) -> Vec<usize> {
2722 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2723 let result = self.recipe_order_weighted_shuffled(weights, &mut rng);
2724 self.rng = rng;
2725 result
2726 }
2727
2728 #[cfg(test)]
2730 fn recipe_order_weighted_cycled_seeded(
2731 &mut self,
2732 weights: &[f32],
2733 rr_idx: usize,
2734 ) -> Vec<usize> {
2735 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2736 let result = self.recipe_order_weighted_cycled(weights, rr_idx, &mut rng);
2737 self.rng = rng;
2738 result
2739 }
2740
2741 #[cfg(test)]
2743 fn text_recipe_order_weighted_shuffled_seeded(&mut self, weights: &[f32]) -> Vec<usize> {
2744 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2745 let result = self.text_recipe_order_weighted_shuffled(weights, &mut rng);
2746 self.rng = rng;
2747 result
2748 }
2749
2750 #[cfg(test)]
2752 fn text_recipe_order_weighted_cycled_seeded(
2753 &mut self,
2754 weights: &[f32],
2755 rr_idx: usize,
2756 ) -> Vec<usize> {
2757 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2758 let result = self.text_recipe_order_weighted_cycled(weights, rr_idx, &mut rng);
2759 self.rng = rng;
2760 result
2761 }
2762
2763 #[cfg(test)]
2765 fn select_negative_record_seeded(
2766 &mut self,
2767 anchor_record: &DataRecord,
2768 strategy: &NegativeStrategy,
2769 anchor_query_text: Option<&str>,
2770 ) -> Option<(Arc<DataRecord>, bool)> {
2771 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2772 let result =
2773 self.select_negative_record(anchor_record, strategy, anchor_query_text, &mut rng);
2774 self.rng = rng;
2775 result
2776 }
2777
2778 #[cfg(test)]
2780 fn make_triplet_with_anchor_seeded(
2781 &mut self,
2782 recipe: &TripletRecipe,
2783 anchor: &DataRecord,
2784 ) -> Option<SampleTriplet> {
2785 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2786 let result = self.make_triplet_with_anchor(recipe, anchor, &mut rng);
2787 self.rng = rng;
2788 result
2789 }
2790
2791 #[cfg(test)]
2793 fn decorate_chunk_seeded(&mut self, record: &DataRecord, chunk: &mut RecordChunk) {
2794 let mut rng = std::mem::replace(&mut self.rng, DeterministicRng::new(0));
2795 self.decorate_chunk(record, chunk, &mut rng);
2796 self.rng = rng;
2797 }
2798
2799 #[cfg(all(feature = "bm25-mining", feature = "extended-metrics"))]
2805 fn bm25_fallback_stats(&self) -> (u64, u64) {
2806 self.negative_backend.bm25_fallback_stats()
2807 }
2808
2809 #[cfg(test)]
2814 #[cfg(feature = "bm25-mining")]
2815 fn bm25_ranked_candidates(&mut self, anchor: &crate::data::DataRecord) -> Vec<RecordId> {
2816 let split = self
2817 .split_store
2818 .label_for(&anchor.id)
2819 .unwrap_or(SplitLabel::Train);
2820 self.negative_backend
2821 .as_any_mut()
2822 .downcast_mut::<backends::Bm25Backend>()
2823 .expect("bm25_ranked_candidates: Bm25Backend")
2824 .ranked_candidates_pub(anchor, split)
2825 }
2826}
2827
2828fn weighted_recipe_order(weights: &[f32], rng: &mut DeterministicRng) -> Vec<usize> {
2839 let nonzero: Vec<(usize, f32)> = weights
2840 .iter()
2841 .enumerate()
2842 .filter(|(_, w)| **w > 0.0)
2843 .map(|(i, &w)| (i, w))
2844 .collect();
2845 if nonzero.is_empty() {
2846 return Vec::new();
2847 }
2848 let w_min = nonzero
2849 .iter()
2850 .map(|(_, w)| *w)
2851 .fold(f32::INFINITY, f32::min);
2852 let mut order: Vec<usize> = Vec::new();
2853 for (recipe_idx, w) in &nonzero {
2854 let tickets = ((w / w_min).round() as usize).clamp(1, RECIPE_ORDER_MAX_WEIGHT_MULTIPLIER);
2855 for _ in 0..tickets {
2856 order.push(*recipe_idx);
2857 }
2858 }
2859 order.shuffle(rng);
2860 order
2861}
2862
2863fn same_selector_pair_is_valid(
2864 anchor_chunk: &RecordChunk,
2865 positive_chunk: &RecordChunk,
2866 enforce_window_pair: bool,
2867) -> bool {
2868 if triplet_chunk_key(anchor_chunk) == triplet_chunk_key(positive_chunk) {
2869 return false;
2870 }
2871 if !enforce_window_pair {
2872 return true;
2873 }
2874 matches!(
2875 (&anchor_chunk.view, &positive_chunk.view),
2876 (ChunkView::Window { .. }, ChunkView::Window { .. })
2877 )
2878}
2879
2880impl<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> TripletSampler<S> {
2881 pub fn new(config: SamplerConfig, split_store: Arc<S>) -> Self {
2883 let inner = TripletSamplerInner::new(config, split_store);
2884 Self {
2885 inner: Mutex::new(inner),
2886 }
2887 }
2888
2889 pub fn new_with_chunker(
2891 config: SamplerConfig,
2892 split_store: Arc<S>,
2893 chunker: Arc<dyn ChunkingAlgorithm>,
2894 ) -> Self {
2895 let inner = TripletSamplerInner::new_with_chunker(config, split_store, chunker);
2896 Self {
2897 inner: Mutex::new(inner),
2898 }
2899 }
2900
2901 pub fn next_pair_batch_for_split(
2903 &self,
2904 split: SplitLabel,
2905 ) -> Result<SampleBatch, SamplerError> {
2906 self.next_pair_batch_with_weights_for_split(split, &HashMap::new())
2907 }
2908
2909 pub fn next_text_batch_for_split(&self, split: SplitLabel) -> Result<TextBatch, SamplerError> {
2911 self.next_text_batch_with_weights_for_split(split, &HashMap::new())
2912 }
2913
2914 pub fn next_triplet_batch_for_split(
2916 &self,
2917 split: SplitLabel,
2918 ) -> Result<TripletBatch, SamplerError> {
2919 self.next_triplet_batch_with_weights_for_split(split, &HashMap::new())
2920 }
2921
2922 pub fn next_pair_batch_with_weights_for_split(
2924 &self,
2925 split: SplitLabel,
2926 weights: &HashMap<SourceId, f32>,
2927 ) -> Result<SampleBatch, SamplerError> {
2928 let mut inner = self.inner.lock().unwrap();
2929 inner.ensure_split_allowed(split)?;
2930 for attempt in 0..=EXHAUSTION_RETRY_LIMIT {
2931 match inner.next_pair_batch_inner_with_weights(split, Some(weights)) {
2932 Ok(batch) => return Ok(batch),
2933 Err(SamplerError::Exhausted(_)) => {
2934 if attempt < EXHAUSTION_RETRY_LIMIT {
2935 inner.force_ingest_refresh_with_weights_for_split(split, weights)?;
2936 }
2937 }
2938 Err(err) => return Err(err),
2939 }
2940 }
2941 Err(SamplerError::Exhausted(RECIPE_LABEL_TRIPLETS.into()))
2942 }
2943
2944 pub fn next_text_batch_with_weights_for_split(
2946 &self,
2947 split: SplitLabel,
2948 weights: &HashMap<SourceId, f32>,
2949 ) -> Result<TextBatch, SamplerError> {
2950 let mut inner = self.inner.lock().unwrap();
2951 inner.ensure_split_allowed(split)?;
2952 for attempt in 0..=EXHAUSTION_RETRY_LIMIT {
2953 match inner.next_text_batch_inner_with_weights(split, Some(weights)) {
2954 Ok(batch) => return Ok(batch),
2955 Err(SamplerError::Exhausted(_)) => {
2956 if attempt < EXHAUSTION_RETRY_LIMIT {
2957 inner.force_ingest_refresh_with_weights_for_split(split, weights)?;
2958 }
2959 }
2960 Err(err) => return Err(err),
2961 }
2962 }
2963 Err(SamplerError::Exhausted(RECIPE_LABEL_TEXT.into()))
2964 }
2965
2966 pub fn next_triplet_batch_with_weights_for_split(
2968 &self,
2969 split: SplitLabel,
2970 weights: &HashMap<SourceId, f32>,
2971 ) -> Result<TripletBatch, SamplerError> {
2972 let mut inner = self.inner.lock().unwrap();
2973 inner.ensure_split_allowed(split)?;
2974 for attempt in 0..=EXHAUSTION_RETRY_LIMIT {
2975 match inner.next_triplet_batch_inner_with_weights(split, Some(weights)) {
2976 Ok(batch) => return Ok(batch),
2977 Err(SamplerError::Exhausted(_)) => {
2978 if attempt < EXHAUSTION_RETRY_LIMIT {
2979 inner.force_ingest_refresh_with_weights_for_split(split, weights)?;
2980 }
2981 }
2982 Err(err) => return Err(err),
2983 }
2984 }
2985 Err(SamplerError::Exhausted(RECIPE_LABEL_TRIPLETS.into()))
2986 }
2987
2988 pub fn prefetch_triplet_batches(
2990 self: Arc<Self>,
2991 split: SplitLabel,
2992 capacity: usize,
2993 ) -> BatchPrefetcher<TripletBatch> {
2994 BatchPrefetcher::new(capacity, move || self.next_triplet_batch_for_split(split))
2995 }
2996
2997 pub fn prefetch_triplet_batches_with_weights(
2999 self: Arc<Self>,
3000 split: SplitLabel,
3001 capacity: usize,
3002 weights: HashMap<SourceId, f32>,
3003 ) -> BatchPrefetcher<TripletBatch> {
3004 BatchPrefetcher::new(capacity, move || {
3005 self.next_triplet_batch_with_weights_for_split(split, &weights)
3006 })
3007 }
3008
3009 pub fn prefetch_pair_batches(
3011 self: Arc<Self>,
3012 split: SplitLabel,
3013 capacity: usize,
3014 ) -> BatchPrefetcher<SampleBatch> {
3015 BatchPrefetcher::new(capacity, move || self.next_pair_batch_for_split(split))
3016 }
3017
3018 pub fn prefetch_pair_batches_with_weights(
3020 self: Arc<Self>,
3021 split: SplitLabel,
3022 capacity: usize,
3023 weights: HashMap<SourceId, f32>,
3024 ) -> BatchPrefetcher<SampleBatch> {
3025 BatchPrefetcher::new(capacity, move || {
3026 self.next_pair_batch_with_weights_for_split(split, &weights)
3027 })
3028 }
3029
3030 pub fn prefetch_text_batches(
3032 self: Arc<Self>,
3033 split: SplitLabel,
3034 capacity: usize,
3035 ) -> BatchPrefetcher<TextBatch> {
3036 BatchPrefetcher::new(capacity, move || self.next_text_batch_for_split(split))
3037 }
3038
3039 pub fn prefetch_text_batches_with_weights(
3041 self: Arc<Self>,
3042 split: SplitLabel,
3043 capacity: usize,
3044 weights: HashMap<SourceId, f32>,
3045 ) -> BatchPrefetcher<TextBatch> {
3046 BatchPrefetcher::new(capacity, move || {
3047 self.next_text_batch_with_weights_for_split(split, &weights)
3048 })
3049 }
3050
3051 pub fn text_recipes(&self) -> Vec<TextRecipe> {
3053 let inner = self.inner.lock().unwrap();
3054 inner.text_recipes().to_vec()
3055 }
3056
3057 pub fn register_source(
3062 &self,
3063 source: Box<dyn DataSource + 'static>,
3064 ) -> Result<(), SamplerError> {
3065 let mut inner = self.inner.lock().unwrap();
3066 inner.register_source(source)
3067 }
3068
3069 pub fn set_epoch(&self, epoch: u64) -> Result<(), SamplerError> {
3071 let mut inner = self.inner.lock().unwrap();
3072 inner.set_epoch(epoch)
3073 }
3074
3075 pub fn save_sampler_state(&self, save_to: Option<&Path>) -> Result<(), SamplerError> {
3080 let mut inner = self.inner.lock().unwrap();
3081 inner.save_sampler_state(save_to)
3082 }
3083
3084 #[cfg(all(feature = "bm25-mining", feature = "extended-metrics"))]
3092 pub fn bm25_fallback_stats(&self) -> (u64, u64) {
3093 let inner = self.inner.lock().unwrap();
3094 inner.bm25_fallback_stats()
3095 }
3096}
3097
3098impl<S: SplitStore + EpochStateStore + SamplerStateStore + 'static> Sampler for TripletSampler<S> {
3099 fn next_pair_batch(&self, split: SplitLabel) -> Result<SampleBatch, SamplerError> {
3100 self.next_pair_batch_for_split(split)
3101 }
3102
3103 fn next_pair_batch_with_weights(
3104 &self,
3105 split: SplitLabel,
3106 weights: &HashMap<SourceId, f32>,
3107 ) -> Result<SampleBatch, SamplerError> {
3108 self.next_pair_batch_with_weights_for_split(split, weights)
3109 }
3110
3111 fn next_text_batch(&self, split: SplitLabel) -> Result<TextBatch, SamplerError> {
3112 self.next_text_batch_for_split(split)
3113 }
3114
3115 fn next_text_batch_with_weights(
3116 &self,
3117 split: SplitLabel,
3118 weights: &HashMap<SourceId, f32>,
3119 ) -> Result<TextBatch, SamplerError> {
3120 self.next_text_batch_with_weights_for_split(split, weights)
3121 }
3122
3123 fn next_triplet_batch(&self, split: SplitLabel) -> Result<TripletBatch, SamplerError> {
3124 self.next_triplet_batch_for_split(split)
3125 }
3126
3127 fn next_triplet_batch_with_weights(
3128 &self,
3129 split: SplitLabel,
3130 weights: &HashMap<SourceId, f32>,
3131 ) -> Result<TripletBatch, SamplerError> {
3132 self.next_triplet_batch_with_weights_for_split(split, weights)
3133 }
3134}
3135
3136fn roles_match(target: &SectionRole, candidate: &SectionRole) -> bool {
3137 target == candidate
3138}
3139
3140fn role_cursor_key(record_id: &RecordId, role: &SectionRole) -> (RecordId, String) {
3141 (record_id.clone(), role_label(role))
3142}
3143
3144fn role_label(role: &SectionRole) -> String {
3145 match role {
3146 SectionRole::Anchor => ROLE_LABEL_ANCHOR.into(),
3147 SectionRole::Context => ROLE_LABEL_CONTEXT.into(),
3148 }
3149}
3150
3151fn taxonomy_value(record: &DataRecord, field: MetadataKey) -> Option<&str> {
3152 record.taxonomy.iter().find_map(|entry| field.strip(entry))
3153}
3154
3155fn strategy_reason(strategy: &NegativeStrategy) -> &'static str {
3156 match strategy {
3157 NegativeStrategy::WrongPublicationDate => NEG_REASON_WRONG_DATE,
3158 NegativeStrategy::WrongArticle => NEG_REASON_WRONG_ARTICLE,
3159 NegativeStrategy::QuestionAnswerMismatch => NEG_REASON_WRONG_QA,
3160 }
3161}
3162
/// Test helper: owned `(record_id, text)` key of a chunk.
#[cfg(test)]
fn text_dedup_key(chunk: &RecordChunk) -> (String, String) {
    let RecordChunk { record_id, text, .. } = chunk;
    (record_id.to_owned(), text.to_owned())
}
3174
3175fn triplet_chunk_key(chunk: &RecordChunk) -> String {
3176 match &chunk.view {
3177 ChunkView::Window { index, .. } => {
3178 format!("{}|{}|w|{}", chunk.record_id, chunk.section_idx, index)
3179 }
3180 ChunkView::SummaryFallback { strategy, .. } => {
3181 format!("{}|{}|s|{}", chunk.record_id, chunk.section_idx, strategy)
3182 }
3183 }
3184}
3185
/// Pads `items` up to `target` elements by cycling through its existing elements.
///
/// Appended elements are clones of `items[0]`, `items[1]`, ... in order, wrapping around as
/// often as needed. Does nothing when `items` is empty (there is nothing to reuse) or already
/// holds at least `target` elements. Never truncates.
fn pad_with_reuse<T: Clone>(items: &mut Vec<T>, target: usize) {
    let base_len = items.len();
    if base_len == 0 || base_len >= target {
        return;
    }
    let missing = target - base_len;
    items.reserve(missing);
    // `idx % base_len` always points into the original prefix, which is never modified, so
    // elements can be cloned in place without first snapshotting the whole vector.
    for idx in 0..missing {
        let reused = items[idx % base_len].clone();
        items.push(reused);
    }
}
3196
3197#[cfg(test)]
3198mod tests;