1use crate::config::SamplerConfig;
2use crate::data::DataRecord;
3use crate::errors::SamplerError;
4use crate::hash::derive_epoch_seed;
5use crate::source::{DataSource, SourceCursor, SourceSnapshot};
6use crate::types::{RecordId, SourceId};
7use chrono::Utc;
8use indexmap::IndexMap;
9use std::collections::HashMap;
10use std::collections::VecDeque;
11use std::sync::{Arc, Condvar, Mutex, RwLock};
12use std::thread;
13use std::time::Duration;
14use tracing::debug;
15
/// Cloneable handle to a bounded, lock-protected record cache.
///
/// Both fields are `Arc`s, so every clone shares the same records and the same
/// ingest counter/condition variable.
#[derive(Clone)]
pub struct RecordCache {
    /// Cached records and eviction bookkeeping behind a read/write lock.
    inner: Arc<RwLock<RecordCacheInner>>,
    /// Ingest counter paired with the condvar that wakes waiters after each ingest.
    notifier: Arc<(Mutex<CacheStats>, Condvar)>,
}
22
/// Lock-protected state of a [`RecordCache`].
struct RecordCacheInner {
    /// Cached records keyed by record id.
    records: IndexMap<RecordId, CachedRecord>,
    /// Record ids in eviction order: evicted from the front, (re)inserted at the back.
    order: VecDeque<RecordId>,
    /// Maximum number of records retained; `0` means nothing is retained.
    max_records: usize,
    /// Counter bumped (saturating) for every ingested record and stamped on its entry.
    next_version: u64,
}
30
/// A cached record together with the cache version at which it was last written.
struct CachedRecord {
    /// The most recently ingested copy of the record.
    record: DataRecord,
    /// Value of `next_version` at the time this entry was inserted or updated.
    version: u64,
}
36
/// Counters guarded by the notifier mutex of a [`RecordCache`].
#[derive(Default)]
struct CacheStats {
    /// Number of non-empty batches ingested so far (saturating).
    ingests: u64,
}
42
43impl RecordCache {
44 pub fn new(max_records: usize) -> Self {
46 Self {
47 inner: Arc::new(RwLock::new(RecordCacheInner {
48 records: IndexMap::new(),
49 order: VecDeque::new(),
50 max_records,
51 next_version: 0,
52 })),
53 notifier: Arc::new((Mutex::new(CacheStats::default()), Condvar::new())),
54 }
55 }
56
57 pub fn ingest<I>(&self, records: I)
59 where
60 I: IntoIterator<Item = DataRecord>,
61 {
62 let mut batch: Vec<DataRecord> = records.into_iter().collect();
63 if batch.is_empty() {
64 return;
65 }
66 let mut inner = self.inner.write().expect("record cache poisoned");
67 inner.ingest_batch(&mut batch);
68 drop(inner);
69 let (lock, cvar) = &*self.notifier;
70 let mut stats = lock.lock().expect("record cache stats poisoned");
71 stats.ingests = stats.ingests.saturating_add(1);
72 cvar.notify_all();
73 }
74
75 pub fn clear(&self) {
77 let mut inner = self.inner.write().expect("record cache poisoned");
78 inner.records.clear();
79 inner.order.clear();
80 }
81
82 pub fn snapshot(&self) -> Vec<DataRecord> {
84 let inner = self.inner.read().expect("record cache poisoned");
85 inner
86 .records
87 .values()
88 .map(|entry| entry.record.clone())
89 .collect()
90 }
91
92 pub fn ingest_count(&self) -> u64 {
94 let (lock, _) = &*self.notifier;
95 lock.lock().expect("record cache stats poisoned").ingests
96 }
97
98 pub fn wait_for_ingest(&self, last_seen: u64, timeout: Duration) -> u64 {
100 let (lock, cvar) = &*self.notifier;
101 let mut stats = lock.lock().expect("record cache stats poisoned");
102 while stats.ingests <= last_seen {
103 let result = cvar
104 .wait_timeout(stats, timeout)
105 .expect("record cache stats poisoned");
106 stats = result.0;
107 if result.1.timed_out() {
108 break;
109 }
110 }
111 stats.ingests
112 }
113
114 pub fn wait_for_ingest_blocking(&self, last_seen: u64) -> u64 {
116 let (lock, cvar) = &*self.notifier;
117 let mut stats = lock.lock().expect("record cache stats poisoned");
118 while stats.ingests <= last_seen {
119 stats = cvar.wait(stats).expect("record cache stats poisoned");
120 }
121 stats.ingests
122 }
123
124 pub fn is_empty(&self) -> bool {
126 let inner = self.inner.read().expect("record cache poisoned");
127 inner.records.is_empty()
128 }
129
130 pub fn len(&self) -> usize {
132 let inner = self.inner.read().expect("record cache poisoned");
133 inner.records.len()
134 }
135}
136
137impl RecordCacheInner {
138 fn ingest_batch(&mut self, records: &mut Vec<DataRecord>) {
139 for record in records.drain(..) {
140 self.next_version = self.next_version.saturating_add(1);
141 let record_id = record.id.clone();
142 if self.records.contains_key(&record_id) {
143 if let Some(entry) = self.records.get_mut(&record_id) {
144 entry.record = record;
145 entry.version = self.next_version;
146 }
147 Self::refresh_order(&mut self.order, &record_id);
148 self.order.push_back(record_id);
149 } else {
150 self.order.push_back(record_id.clone());
151 self.records.insert(
152 record_id,
153 CachedRecord {
154 record,
155 version: self.next_version,
156 },
157 );
158 }
159 self.enforce_limit();
160 }
161 }
162
163 fn enforce_limit(&mut self) {
164 if self.max_records == 0 {
165 self.records.clear();
166 self.order.clear();
167 return;
168 }
169 while self.records.len() > self.max_records {
170 if let Some(oldest) = self.order.pop_front() {
171 self.records.swap_remove(&oldest);
172 } else {
173 break;
174 }
175 }
176 }
177
178 fn refresh_order(order: &mut VecDeque<RecordId>, id: &RecordId) {
179 if order.is_empty() {
180 return;
181 }
182 if let Some(pos) = order.iter().position(|existing| existing == id) {
183 order.remove(pos);
184 }
185 }
186}
187
/// Refreshes registered data sources and drains their records into per-source caches.
pub struct IngestionManager {
    /// Registered sources with their cursor, pending buffer, cache and refresh stats.
    sources: Vec<SourceState>,
    /// Capacity of each source cache; also used as the fetch limit and the default
    /// number of records drained per refresh.
    max_records: usize,
    /// Base sampler configuration; its seed is re-derived from `source_epoch` per refresh.
    sampler_config: SamplerConfig,
    /// Epoch combined with the base seed before a source is refreshed.
    source_epoch: u64,
    /// Bumped each time at least one source is scheduled for refresh.
    source_refresh_generation: u64,
    /// Ids of the sources scheduled during the most recent refresh pass.
    last_refreshed_sources: Vec<SourceId>,
    /// Rotating start index for round-robin draining and weighted tie-breaking.
    drain_start: usize,
}
208
/// Outcome of the most recent refresh attempt for one source, plus a running error count.
#[derive(Clone, Debug, Default)]
pub struct SourceRefreshStats {
    /// Wall-clock duration of the last refresh call, in milliseconds.
    pub last_refresh_ms: u128,
    /// Number of records returned by the last refresh (`0` after a failure).
    pub last_record_count: usize,
    /// Throughput of the last refresh; `0.0` after a failure or a zero-length measurement.
    pub last_records_per_sec: f64,
    /// Message of the last refresh error; reset to `None` by a successful refresh.
    pub last_error: Option<String>,
    /// Total number of failed refreshes (saturating).
    pub error_count: u64,
}
223
224impl IngestionManager {
225 pub fn new(max_records: usize, sampler_config: SamplerConfig) -> Self {
227 Self {
228 sources: Vec::new(),
229 max_records,
230 sampler_config,
231 source_epoch: 0,
232 source_refresh_generation: 0,
233 last_refreshed_sources: Vec::new(),
234 drain_start: 0,
235 }
236 }
237
    /// Returns the refresh generation, which is bumped whenever a refresh pass schedules
    /// at least one source.
    pub fn source_refresh_generation(&self) -> u64 {
        self.source_refresh_generation
    }
242
243 pub fn last_refreshed_sources(&self) -> &[SourceId] {
245 &self.last_refreshed_sources
246 }
247
    /// Sets the epoch that is combined with the base seed on subsequent refreshes.
    pub(crate) fn set_source_epoch(&mut self, epoch: u64) {
        self.source_epoch = epoch;
    }
259
    /// Returns the current source epoch (test-only accessor).
    #[cfg(test)]
    pub fn source_epoch(&self) -> u64 {
        self.source_epoch
    }
265
266 pub(crate) fn reset_stream_cursors(&mut self) {
272 for state in &mut self.sources {
273 state.cursor = None;
274 state.buffer.clear();
275 state.cache.clear();
276 }
277 }
278
279 pub fn register_source(&mut self, source: Box<dyn DataSource + 'static>) {
281 let cache = RecordCache::new(self.max_records);
282 self.sources.push(SourceState {
283 source,
284 cursor: None,
285 buffer: VecDeque::new(),
286 cache,
287 stats: SourceRefreshStats::default(),
288 });
289 }
290
291 pub fn load_cursors(&mut self, cursors: &[(SourceId, u64)]) {
293 if cursors.is_empty() {
294 return;
295 }
296 let mut map = std::collections::HashMap::with_capacity(cursors.len());
297 for (id, revision) in cursors {
298 map.insert(id.as_str(), *revision);
299 }
300 for state in &mut self.sources {
301 if let Some(revision) = map.get(state.source.id()) {
302 state.cursor = Some(SourceCursor {
303 last_seen: Utc::now(),
304 revision: *revision,
305 });
306 }
307 }
308 }
309
310 pub fn snapshot_cursors(&self) -> Vec<(SourceId, u64)> {
312 let mut out = Vec::new();
313 for state in &self.sources {
314 if let Some(cursor) = state.cursor.as_ref() {
315 out.push((state.source.id().to_string(), cursor.revision));
316 }
317 }
318 out
319 }
320
321 pub fn source_refresh_stats(&self) -> Vec<(SourceId, SourceRefreshStats)> {
323 self.sources
324 .iter()
325 .map(|state| (state.source.id().to_string(), state.stats.clone()))
326 .collect()
327 }
328
329 pub fn all_records_snapshot(&self) -> Vec<DataRecord> {
334 self.sources
335 .iter()
336 .flat_map(|s| s.cache.snapshot())
337 .collect()
338 }
339
340 pub fn all_caches_empty(&self) -> bool {
342 self.sources.iter().all(|s| s.cache.is_empty())
343 }
344
345 pub fn all_records_len(&self) -> usize {
347 self.sources.iter().map(|s| s.cache.len()).sum()
348 }
349
350 pub fn total_ingest_count(&self) -> u64 {
355 self.sources.iter().map(|s| s.cache.ingest_count()).sum()
356 }
357
    /// Refreshes sources whose buffer is empty, clears the caches and refills them with
    /// up to `max_records` records using round-robin draining.
    pub fn refresh_all(&mut self) {
        self.refresh_all_internal(false, None, None);
    }
362
    /// Refreshes sources whose buffer is empty and drains up to `step` buffered records
    /// into the caches (round-robin) without clearing the caches first.
    pub fn advance(&mut self, step: usize) {
        self.refresh_all_internal(false, Some(step), None);
    }
367
    /// Like [`advance`](Self::advance), but picks records across sources according to
    /// `weights`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_weights` when `weights` is rejected; in
    /// that case no refresh or drain takes place.
    pub fn advance_with_weights(
        &mut self,
        step: usize,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<(), SamplerError> {
        self.validate_weights(weights)?;
        self.refresh_all_internal(false, Some(step), Some(weights));
        Ok(())
    }
381
    /// Discards every buffer, refreshes all sources, clears the caches and refills them
    /// with up to `max_records` records using round-robin draining.
    pub fn force_refresh_all(&mut self) {
        self.refresh_all_internal(true, None, None);
    }
386
    /// Like [`refresh_all`](Self::refresh_all), but picks records across sources
    /// according to `weights`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_weights` when `weights` is rejected; in
    /// that case no refresh or drain takes place.
    pub fn refresh_all_with_weights(
        &mut self,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<(), SamplerError> {
        self.validate_weights(weights)?;
        self.refresh_all_internal(false, None, Some(weights));
        Ok(())
    }
399
    /// Like [`force_refresh_all`](Self::force_refresh_all), but picks records across
    /// sources according to `weights`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_weights` when `weights` is rejected; in
    /// that case no refresh or drain takes place.
    pub fn force_refresh_all_with_weights(
        &mut self,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<(), SamplerError> {
        self.validate_weights(weights)?;
        self.refresh_all_internal(true, None, Some(weights));
        Ok(())
    }
412
413 fn validate_weights(&self, weights: &HashMap<SourceId, f32>) -> Result<(), SamplerError> {
414 let known_ids: std::collections::HashSet<&str> =
415 self.sources.iter().map(|s| s.source.id()).collect();
416 for (id, &w) in weights {
417 if !known_ids.contains(id.as_str()) {
418 return Err(SamplerError::InvalidWeight {
419 source_id: id.clone(),
420 reason: "source is not registered".to_string(),
421 });
422 }
423 if w < 0.0 {
424 return Err(SamplerError::InvalidWeight {
425 source_id: id.clone(),
426 reason: format!("weight {w} is negative"),
427 });
428 }
429 }
430 Ok(())
431 }
432
    /// Shared implementation behind every public refresh/advance entry point.
    ///
    /// Works in three phases:
    /// 1. Plan: schedule every source when `force_refresh` is set (discarding its buffer
    ///    first), otherwise only the sources whose buffer is empty.
    /// 2. Fetch: refresh the scheduled sources on scoped threads, then record per-source
    ///    stats and append the fetched records to each source's buffer.
    /// 3. Drain: move buffered records into the per-source caches, by weighted selection
    ///    when `weights` is `Some`, otherwise round-robin starting at `drain_start`.
    ///
    /// `step` caps the number of records drained in total. When it is `None`, the caches
    /// are cleared first and up to `max_records` records are drained.
    fn refresh_all_internal(
        &mut self,
        force_refresh: bool,
        step: Option<usize>,
        weights: Option<&HashMap<SourceId, f32>>,
    ) {
        self.last_refreshed_sources.clear();
        // Each plan entry is (source index, cursor to resume from).
        let mut refresh_plan = Vec::new();
        for (idx, state) in self.sources.iter_mut().enumerate() {
            if force_refresh {
                state.buffer.clear();
            }
            if force_refresh || state.buffer.is_empty() {
                refresh_plan.push((idx, state.cursor.clone()));
            }
        }

        if !refresh_plan.is_empty() {
            self.source_refresh_generation = self.source_refresh_generation.saturating_add(1);
            self.last_refreshed_sources = refresh_plan
                .iter()
                .map(|(idx, _)| self.sources[*idx].source.id().to_string())
                .collect();
            // One slot per registered source; only scheduled sources get filled in.
            let mut results: Vec<
                Option<(Result<SourceSnapshot, SamplerError>, std::time::Duration)>,
            > = Vec::with_capacity(self.sources.len());
            results.resize_with(self.sources.len(), || None);
            let fetch_limit = self.max_records;
            let sampler_config = self.sampler_config.clone();
            thread::scope(|scope| {
                let mut handles = Vec::with_capacity(refresh_plan.len());
                for (idx, cursor) in &refresh_plan {
                    let source = &self.sources[*idx].source;
                    let cursor = cursor.clone();
                    let sampler_config = sampler_config.clone();
                    let source_epoch = self.source_epoch;
                    handles.push((
                        *idx,
                        scope.spawn(move || {
                            let start = std::time::Instant::now();
                            // Hand the source a config whose seed is derived from the
                            // base seed and the current source epoch.
                            let epoch_config = SamplerConfig {
                                seed: derive_epoch_seed(sampler_config.seed, source_epoch),
                                ..sampler_config
                            };
                            let result =
                                source.refresh(&epoch_config, cursor.as_ref(), Some(fetch_limit));
                            let elapsed = start.elapsed();
                            (result, elapsed)
                        }),
                    ));
                }
                for (idx, handle) in handles {
                    let result = match handle.join() {
                        Ok((result, elapsed)) => {
                            debug!(
                                source_id = %self.sources[idx].source.id(),
                                refresh_ms = elapsed.as_millis(),
                                "source refresh completed"
                            );
                            (result, elapsed)
                        }
                        // A panicking refresh is reported as an ordinary source error
                        // with a zero duration instead of propagating the panic.
                        Err(_) => (
                            Err(SamplerError::SourceUnavailable {
                                source_id: self.sources[idx].source.id().to_string(),
                                reason: "source refresh thread panicked".into(),
                            }),
                            std::time::Duration::from_secs(0),
                        ),
                    };
                    results[idx] = Some(result);
                }
            });

            for (idx, _) in refresh_plan {
                let Some((result, elapsed)) = results[idx].take() else {
                    continue;
                };
                match result {
                    Ok(snapshot) => {
                        let SourceSnapshot {
                            records,
                            cursor: next_cursor,
                        } = snapshot;
                        let record_count = records.len();
                        let seconds = elapsed.as_secs_f64();
                        // Guard against dividing by a zero-length measurement.
                        let per_sec = if seconds > 0.0 {
                            (record_count as f64) / seconds
                        } else {
                            0.0
                        };
                        let stats = &mut self.sources[idx].stats;
                        stats.last_refresh_ms = elapsed.as_millis();
                        stats.last_record_count = record_count;
                        stats.last_records_per_sec = per_sec;
                        stats.last_error = None;
                        debug!(
                            source_id = %self.sources[idx].source.id(),
                            record_count,
                            refresh_ms = elapsed.as_millis(),
                            records_per_sec = per_sec,
                            "source refresh ingested records"
                        );
                        // Overwrite each record's `source` with the registered id of
                        // the source that produced it.
                        let source_id = self.sources[idx].source.id().to_string();
                        let normalized = records
                            .into_iter()
                            .map(|mut record| {
                                record.source = source_id.clone();
                                record
                            })
                            .collect::<Vec<_>>();
                        self.sources[idx].buffer.extend(normalized);
                        self.sources[idx].cursor = Some(next_cursor);
                    }
                    Err(err) => {
                        // The cursor and buffer are left as they are on failure.
                        let stats = &mut self.sources[idx].stats;
                        stats.last_refresh_ms = elapsed.as_millis();
                        stats.last_record_count = 0;
                        stats.last_records_per_sec = 0.0;
                        stats.last_error = Some(err.to_string());
                        stats.error_count = stats.error_count.saturating_add(1);
                        eprintln!(
                            "[data_sampler] source '{}' refresh failed: {}",
                            self.sources[idx].source.id(),
                            err
                        );
                    }
                }
            }
        }

        // Without a step the caches are rebuilt from scratch; with a step they accumulate.
        if step.is_none() {
            for state in self.sources.iter_mut() {
                state.cache.clear();
            }
        }
        if self.max_records == 0 {
            return;
        }
        let target_limit = step.unwrap_or(self.max_records);
        if let Some(weights) = weights {
            self.weighted_drain_into_caches(target_limit, weights);
        } else {
            let n = self.sources.len();
            if n > 0 {
                // Collect per source first so each cache sees a single ingest per pass.
                let mut per_source: Vec<Vec<DataRecord>> = vec![Vec::new(); n];
                let mut total_drained = 0;
                let mut any_remaining = true;
                // Take one record per source per lap, starting at `drain_start`, until
                // the limit is reached or every buffer is exhausted.
                while total_drained < target_limit && any_remaining {
                    any_remaining = false;
                    for offset in 0..n {
                        if total_drained >= target_limit {
                            break;
                        }
                        let idx = (self.drain_start + offset) % n;
                        if let Some(record) = self.sources[idx].buffer.pop_front() {
                            per_source[idx].push(record);
                            total_drained += 1;
                            any_remaining = true;
                        }
                    }
                }
                // Rotate the starting source so the next pass begins one position later.
                if total_drained > 0 {
                    self.drain_start = (self.drain_start + 1) % n;
                }
                for (idx, batch) in per_source.into_iter().enumerate() {
                    if !batch.is_empty() {
                        self.sources[idx].cache.ingest(batch);
                    }
                }
            }
        }
    }
623
    /// Drains up to `limit` buffered records into the caches, choosing sources in
    /// proportion to `weights`.
    ///
    /// Sources missing from `weights` default to weight `1.0`. If no source has a
    /// positive weight, every source is treated as weight `1.0`. Sources with a
    /// non-positive weight or an empty buffer are skipped.
    ///
    /// Selection works on running credit: each round, every eligible source gains its
    /// weight, the source with the highest credit yields one record, and that source is
    /// then charged the sum of all eligible weights. Equal credit is resolved in favour
    /// of the source closest to `drain_start` in rotation order.
    fn weighted_drain_into_caches(&mut self, limit: usize, weights: &HashMap<SourceId, f32>) {
        let len = self.sources.len();
        if len == 0 {
            return;
        }
        let mut weight_values = Vec::with_capacity(len);
        let mut any_positive = false;
        for state in &self.sources {
            let weight = weights.get(state.source.id()).copied().unwrap_or(1.0);
            if weight > 0.0 {
                any_positive = true;
            }
            weight_values.push(weight);
        }
        // With no positive weight at all, fall back to treating every source equally.
        if !any_positive {
            weight_values.fill(1.0);
        }

        // Running credit per source, indexed like `self.sources`.
        let mut current = vec![0.0f32; len];
        let mut per_source: Vec<Vec<DataRecord>> = vec![Vec::new(); len];
        let mut total = 0;
        while total < limit {
            // Sum of the weights of the sources that can still yield a record.
            let mut total_weight = 0.0f32;
            for (idx, weight) in weight_values.iter().copied().enumerate().take(len) {
                if weight <= 0.0 {
                    continue;
                }
                if self.sources[idx].buffer.is_empty() {
                    continue;
                }
                total_weight += weight;
            }
            // Nothing eligible is left to drain.
            if total_weight == 0.0 {
                break;
            }

            let mut best_idx = None;
            let mut best_score = f32::MIN;
            // True when index `a` comes before index `b` in rotation order from `drain_start`.
            let closer_to_start = |a: usize, b: usize| -> bool {
                let da = (a + len - self.drain_start) % len;
                let db = (b + len - self.drain_start) % len;
                da < db
            };
            for idx in 0..len {
                if weight_values[idx] <= 0.0 {
                    continue;
                }
                if self.sources[idx].buffer.is_empty() {
                    continue;
                }
                current[idx] += weight_values[idx];
                let is_better = if current[idx] > best_score {
                    true
                } else if current[idx] == best_score {
                    closer_to_start(idx, best_idx.unwrap_or(0))
                } else {
                    false
                };
                if is_better {
                    best_score = current[idx];
                    best_idx = Some(idx);
                }
            }

            let idx = match best_idx {
                Some(idx) => idx,
                None => break,
            };
            // Charge the winner so that the other sources catch up in later rounds.
            current[idx] -= total_weight;
            if let Some(record) = self.sources[idx].buffer.pop_front() {
                per_source[idx].push(record);
                total += 1;
            }
        }

        // Rotate the tie-break origin only when something was actually drained.
        if total > 0 && len > 0 {
            self.drain_start = (self.drain_start + 1) % len;
        }

        for (idx, batch) in per_source.into_iter().enumerate() {
            if !batch.is_empty() {
                self.sources[idx].cache.ingest(batch);
            }
        }
    }
712
    /// Returns `true` when at least one source has been registered.
    pub fn has_sources(&self) -> bool {
        !self.sources.is_empty()
    }
717}
718
/// Everything the manager tracks for a single registered source.
struct SourceState {
    /// The registered data source.
    source: Box<dyn DataSource + 'static>,
    /// Cursor from the last successful refresh or from `load_cursors`; `None` until then.
    cursor: Option<SourceCursor>,
    /// Fetched records that have not yet been drained into the cache.
    buffer: VecDeque<DataRecord>,
    /// Bounded cache holding the records drained from `buffer`.
    cache: RecordCache,
    /// Outcome of the most recent refresh attempt.
    stats: SourceRefreshStats,
}
728
729#[cfg(test)]
730mod tests {
731 use super::*;
732 use crate::TripletSampler;
733 use crate::config::{Selector, TextRecipe, TripletRecipe};
734 use crate::data::{QualityScore, RecordSection, SectionRole};
735 use crate::sampler::Sampler;
736 use crate::splits::{DeterministicSplitStore, SplitLabel, SplitRatios};
737 use chrono::Utc;
738 use std::collections::HashMap;
739 use std::collections::VecDeque;
740 use std::sync::atomic::{AtomicUsize, Ordering};
741 use std::sync::{Arc, Mutex};
742
743 fn make_record(id: &str, source: &str) -> DataRecord {
744 let now = Utc::now();
745 DataRecord {
746 id: id.to_string(),
747 source: source.to_string(),
748 created_at: now,
749 updated_at: now,
750 quality: QualityScore { trust: 1.0 },
751 taxonomy: Vec::new(),
752 sections: vec![RecordSection {
753 role: SectionRole::Anchor,
754 heading: None,
755 text: id.to_string(),
756 sentences: vec![id.to_string()],
757 }],
758 meta_prefix: None,
759 }
760 }
761
    /// Test source that replays a fixed script of refresh results and counts refresh calls.
    struct ScriptedSource {
        /// Id reported through `DataSource::id`.
        id: String,
        /// Shared counter incremented on every `refresh` call.
        refreshes: Arc<AtomicUsize>,
        /// Remaining scripted results, consumed front to back.
        script: Arc<Mutex<VecDeque<Result<SourceSnapshot, SamplerError>>>>,
    }
767
768 impl ScriptedSource {
769 fn new(
770 id: &str,
771 refreshes: Arc<AtomicUsize>,
772 script: Vec<Result<SourceSnapshot, SamplerError>>,
773 ) -> Self {
774 Self {
775 id: id.to_string(),
776 refreshes,
777 script: Arc::new(Mutex::new(script.into_iter().collect())),
778 }
779 }
780 }
781
782 impl DataSource for ScriptedSource {
783 fn id(&self) -> &str {
784 &self.id
785 }
786
787 fn refresh(
788 &self,
789 _config: &SamplerConfig,
790 _cursor: Option<&SourceCursor>,
791 _limit: Option<usize>,
792 ) -> Result<SourceSnapshot, SamplerError> {
793 self.refreshes.fetch_add(1, Ordering::SeqCst);
794 let mut guard = self.script.lock().expect("script lock poisoned");
795 guard.pop_front().unwrap_or_else(|| {
796 Ok(SourceSnapshot {
797 records: Vec::new(),
798 cursor: SourceCursor {
799 last_seen: Utc::now(),
800 revision: 0,
801 },
802 })
803 })
804 }
805
806 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
807 Ok(0)
808 }
809
810 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
811 Vec::new()
812 }
813 }
814
    /// Test source whose `refresh` always panics.
    struct PanicSource {
        /// Id reported through `DataSource::id`.
        id: String,
    }
818
    impl DataSource for PanicSource {
        fn id(&self) -> &str {
            &self.id
        }

        /// Always panics, to exercise the manager's handling of a panicking refresh thread.
        fn refresh(
            &self,
            _config: &SamplerConfig,
            _cursor: Option<&SourceCursor>,
            _limit: Option<usize>,
        ) -> Result<SourceSnapshot, SamplerError> {
            panic!("panic source refresh")
        }

        fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
            Ok(0)
        }
    }
837
    /// Exercises the empty-batch no-op, both wait variants, eviction at capacity 2, and `clear`.
    #[test]
    fn record_cache_waits_len_and_clear_paths_are_covered() {
        let cache = RecordCache::new(2);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.ingest_count(), 0);

        // An empty batch must not bump the counter, so the timed wait expires at 0.
        cache.ingest(Vec::<DataRecord>::new());
        assert_eq!(cache.wait_for_ingest(0, Duration::from_millis(1)), 0);

        // A blocked waiter on a clone is released by an ingest through the original handle.
        let cache_for_waiter = cache.clone();
        let handle = std::thread::spawn(move || cache_for_waiter.wait_for_ingest_blocking(0));
        std::thread::sleep(Duration::from_millis(5));
        cache.ingest(vec![make_record("r1", "s")]);
        assert_eq!(handle.join().unwrap(), 1);
        assert_eq!(cache.ingest_count(), 1);

        // Capacity is 2, so adding r2 and r3 evicts r1.
        cache.ingest(vec![make_record("r2", "s"), make_record("r3", "s")]);
        assert_eq!(cache.len(), 2);
        let ids: Vec<String> = cache
            .snapshot()
            .into_iter()
            .map(|record| record.id)
            .collect();
        assert!(ids.contains(&"r2".to_string()));
        assert!(ids.contains(&"r3".to_string()));

        cache.clear();
        assert!(cache.is_empty());
    }
868
869 #[test]
870 fn record_cache_zero_limit_discards_everything() {
871 let cache = RecordCache::new(0);
872 cache.ingest(vec![make_record("r1", "s")]);
873 assert!(cache.is_empty());
874 assert_eq!(cache.len(), 0);
875 }
876
    /// Cursor revisions round-trip through load/snapshot, are replaced by a refresh, and
    /// ingested records are re-tagged with the registered source id.
    #[test]
    fn manager_loads_and_snapshots_cursors_and_reports_has_sources() {
        let mut manager = IngestionManager::new(4, SamplerConfig::default());
        assert!(!manager.has_sources());
        manager.load_cursors(&[]);

        let refreshes = Arc::new(AtomicUsize::new(0));
        manager.register_source(Box::new(ScriptedSource::new(
            "cursor_source",
            refreshes,
            vec![Ok(SourceSnapshot {
                records: vec![make_record("id_1", "original_source")],
                cursor: SourceCursor {
                    last_seen: Utc::now(),
                    revision: 33,
                },
            })],
        )));
        assert!(manager.has_sources());

        manager.load_cursors(&[("cursor_source".to_string(), 7)]);
        let cursors = manager.snapshot_cursors();
        assert_eq!(cursors, vec![("cursor_source".to_string(), 7)]);

        // The refresh replaces the loaded revision with the one the source returned.
        manager.refresh_all();
        let updated = manager.snapshot_cursors();
        assert_eq!(updated, vec![("cursor_source".to_string(), 33)]);
        let records = manager.all_records_snapshot();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].source, "cursor_source");
    }
908
    /// While a source still has buffered records, `advance` drains them without refetching.
    #[test]
    fn advance_uses_buffer_before_refreshing_again() {
        let refreshes = Arc::new(AtomicUsize::new(0));
        let mut manager = IngestionManager::new(5, SamplerConfig::default());
        manager.register_source(Box::new(ScriptedSource::new(
            "buffered",
            refreshes.clone(),
            vec![Ok(SourceSnapshot {
                records: vec![
                    make_record("a", "x"),
                    make_record("b", "x"),
                    make_record("c", "x"),
                ],
                cursor: SourceCursor {
                    last_seen: Utc::now(),
                    revision: 1,
                },
            })],
        )));

        manager.advance(1);
        assert_eq!(refreshes.load(Ordering::SeqCst), 1);
        assert_eq!(manager.all_records_len(), 1);

        // Second step is served from the buffer: still one refresh, cache grows to two.
        manager.advance(1);
        assert_eq!(refreshes.load(Ordering::SeqCst), 1);
        assert_eq!(manager.all_records_len(), 2);
    }
937
    /// A forced refresh discards leftover buffered records, refetches, and rebuilds the cache.
    #[test]
    fn force_refresh_clears_buffer_and_fetches_again() {
        let refreshes = Arc::new(AtomicUsize::new(0));
        let mut manager = IngestionManager::new(4, SamplerConfig::default());
        manager.register_source(Box::new(ScriptedSource::new(
            "force",
            refreshes.clone(),
            vec![
                Ok(SourceSnapshot {
                    records: vec![
                        make_record("r1", "x"),
                        make_record("r2", "x"),
                        make_record("r3", "x"),
                    ],
                    cursor: SourceCursor {
                        last_seen: Utc::now(),
                        revision: 10,
                    },
                }),
                Ok(SourceSnapshot {
                    records: vec![make_record("r4", "x")],
                    cursor: SourceCursor {
                        last_seen: Utc::now(),
                        revision: 11,
                    },
                }),
            ],
        )));

        manager.advance(1);
        assert_eq!(manager.all_records_len(), 1);
        assert_eq!(refreshes.load(Ordering::SeqCst), 1);

        // r2 and r3 are still buffered here; the forced refresh must drop them.
        manager.force_refresh_all();
        assert_eq!(refreshes.load(Ordering::SeqCst), 2);
        let records = manager.all_records_snapshot();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].id, "r4");
    }
977
    /// A zero-weight source is skipped, and all-zero weights fall back to equal treatment.
    #[test]
    fn weighted_drain_respects_zero_and_fallback_weights() {
        let mut manager = IngestionManager::new(6, SamplerConfig::default());
        manager.register_source(Box::new(ScriptedSource::new(
            "a",
            Arc::new(AtomicUsize::new(0)),
            vec![Ok(SourceSnapshot {
                records: vec![make_record("a1", "a"), make_record("a2", "a")],
                cursor: SourceCursor {
                    last_seen: Utc::now(),
                    revision: 1,
                },
            })],
        )));
        manager.register_source(Box::new(ScriptedSource::new(
            "b",
            Arc::new(AtomicUsize::new(0)),
            vec![Ok(SourceSnapshot {
                records: vec![make_record("b1", "b"), make_record("b2", "b")],
                cursor: SourceCursor {
                    last_seen: Utc::now(),
                    revision: 1,
                },
            })],
        )));

        // Source "a" has weight zero, so only "b" records may reach the caches.
        let mut only_b = HashMap::new();
        only_b.insert("a".to_string(), 0.0);
        only_b.insert("b".to_string(), 1.0);
        manager.refresh_all_with_weights(&only_b).unwrap();
        let ids: Vec<String> = manager
            .all_records_snapshot()
            .into_iter()
            .map(|record| record.id)
            .collect();
        assert!(ids.iter().all(|id| id.starts_with('b')));

        let mut manager_fallback = IngestionManager::new(6, SamplerConfig::default());
        manager_fallback.register_source(Box::new(ScriptedSource::new(
            "a",
            Arc::new(AtomicUsize::new(0)),
            vec![Ok(SourceSnapshot {
                records: vec![make_record("a1", "a")],
                cursor: SourceCursor {
                    last_seen: Utc::now(),
                    revision: 2,
                },
            })],
        )));
        manager_fallback.register_source(Box::new(ScriptedSource::new(
            "b",
            Arc::new(AtomicUsize::new(0)),
            vec![Ok(SourceSnapshot {
                records: vec![make_record("b1", "b")],
                cursor: SourceCursor {
                    last_seen: Utc::now(),
                    revision: 2,
                },
            })],
        )));

        // With every weight at zero, both sources must still be drained.
        let mut all_zero = HashMap::new();
        all_zero.insert("a".to_string(), 0.0);
        all_zero.insert("b".to_string(), 0.0);
        manager_fallback
            .refresh_all_with_weights(&all_zero)
            .unwrap();
        let ids: Vec<String> = manager_fallback
            .all_records_snapshot()
            .into_iter()
            .map(|record| record.id)
            .collect();
        assert!(ids.contains(&"a1".to_string()));
        assert!(ids.contains(&"b1".to_string()));
    }
1053
    /// Both a returned error and a panicking refresh are recorded in the source stats.
    #[test]
    fn refresh_errors_and_panics_update_source_stats() {
        let mut manager = IngestionManager::new(4, SamplerConfig::default());
        manager.register_source(Box::new(ScriptedSource::new(
            "err_source",
            Arc::new(AtomicUsize::new(0)),
            vec![Err(SamplerError::SourceUnavailable {
                source_id: "err_source".to_string(),
                reason: "boom".to_string(),
            })],
        )));
        manager.register_source(Box::new(PanicSource {
            id: "panic_source".to_string(),
        }));

        manager.refresh_all();
        let stats = manager.source_refresh_stats();
        let err_stats = stats
            .iter()
            .find(|(source, _)| source == "err_source")
            .map(|(_, stats)| stats)
            .unwrap();
        assert_eq!(err_stats.error_count, 1);
        assert!(
            err_stats
                .last_error
                .as_ref()
                .is_some_and(|msg| msg.contains("boom"))
        );

        // The panic is converted into an error whose message mentions "panicked".
        let panic_stats = stats
            .iter()
            .find(|(source, _)| source == "panic_source")
            .map(|(_, stats)| stats)
            .unwrap();
        assert_eq!(panic_stats.error_count, 1);
        assert!(
            panic_stats
                .last_error
                .as_ref()
                .is_some_and(|msg| msg.contains("panicked"))
        );
    }
1097
    /// The forced, weighted refresh path ingests the single record the source returns.
    #[test]
    fn force_refresh_with_weights_path_is_exercised() {
        let mut manager = IngestionManager::new(3, SamplerConfig::default());
        manager.register_source(Box::new(ScriptedSource::new(
            "w",
            Arc::new(AtomicUsize::new(0)),
            vec![Ok(SourceSnapshot {
                records: vec![make_record("w1", "w")],
                cursor: SourceCursor {
                    last_seen: Utc::now(),
                    revision: 3,
                },
            })],
        )));

        let mut weights = HashMap::new();
        weights.insert("w".to_string(), 1.0);
        manager.force_refresh_all_with_weights(&weights).unwrap();
        assert_eq!(manager.all_records_len(), 1);
    }
1118
1119 #[test]
1120 fn advance_with_weights_rejects_unknown_source() {
1121 let mut manager = IngestionManager::new(4, SamplerConfig::default());
1122 manager.register_source(Box::new(ScriptedSource::new(
1123 "known",
1124 Arc::new(AtomicUsize::new(0)),
1125 vec![],
1126 )));
1127
1128 let mut weights = HashMap::new();
1129 weights.insert("known".to_string(), 1.0);
1130 weights.insert("unknown".to_string(), 0.5);
1131
1132 let err = manager.advance_with_weights(1, &weights).unwrap_err();
1133 assert!(
1134 matches!(err, SamplerError::InvalidWeight { ref source_id, .. } if source_id == "unknown"),
1135 "expected InvalidWeight for unknown source, got {err:?}"
1136 );
1137 }
1138
1139 #[test]
1140 fn refresh_all_with_weights_rejects_negative_weight() {
1141 let mut manager = IngestionManager::new(4, SamplerConfig::default());
1142 manager.register_source(Box::new(ScriptedSource::new(
1143 "src",
1144 Arc::new(AtomicUsize::new(0)),
1145 vec![],
1146 )));
1147
1148 let mut weights = HashMap::new();
1149 weights.insert("src".to_string(), -1.0);
1150
1151 let err = manager.refresh_all_with_weights(&weights).unwrap_err();
1152 assert!(
1153 matches!(err, SamplerError::InvalidWeight { ref source_id, .. } if source_id == "src"),
1154 "expected InvalidWeight for negative weight, got {err:?}"
1155 );
1156 }
1157
1158 #[test]
1159 fn force_refresh_all_with_weights_rejects_unknown_source() {
1160 let mut manager = IngestionManager::new(4, SamplerConfig::default());
1161 manager.register_source(Box::new(ScriptedSource::new(
1162 "real",
1163 Arc::new(AtomicUsize::new(0)),
1164 vec![],
1165 )));
1166
1167 let mut weights = HashMap::new();
1168 weights.insert("ghost".to_string(), 1.0);
1169
1170 let err = manager
1171 .force_refresh_all_with_weights(&weights)
1172 .unwrap_err();
1173 assert!(
1174 matches!(err, SamplerError::InvalidWeight { ref source_id, .. } if source_id == "ghost"),
1175 "expected InvalidWeight for unknown source, got {err:?}"
1176 );
1177 }
1178
    /// Test source that records the seed of every config passed to `refresh`.
    struct SeedCapturingSource {
        /// Id reported through `DataSource::id`.
        id: String,
        /// Seeds observed so far, in call order; shared with the test body.
        received_seeds: Arc<Mutex<Vec<u64>>>,
    }
1184
1185 impl SeedCapturingSource {
1186 fn new(id: &str, received_seeds: Arc<Mutex<Vec<u64>>>) -> Self {
1187 Self {
1188 id: id.to_string(),
1189 received_seeds,
1190 }
1191 }
1192 }
1193
    impl DataSource for SeedCapturingSource {
        fn id(&self) -> &str {
            &self.id
        }

        /// Records `config.seed` and returns an empty snapshot with revision `0`.
        fn refresh(
            &self,
            config: &SamplerConfig,
            _cursor: Option<&SourceCursor>,
            _limit: Option<usize>,
        ) -> Result<SourceSnapshot, SamplerError> {
            self.received_seeds
                .lock()
                .expect("seed lock poisoned")
                .push(config.seed);
            Ok(SourceSnapshot {
                records: Vec::new(),
                cursor: SourceCursor {
                    last_seen: Utc::now(),
                    revision: 0,
                },
            })
        }

        fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
            Ok(0)
        }

        fn default_triplet_recipes(&self) -> Vec<crate::config::TripletRecipe> {
            Vec::new()
        }
    }
1226
    /// The seed a source receives depends on the manager's source epoch and equals
    /// `derive_epoch_seed(base_seed, epoch)`.
    #[test]
    fn source_epoch_xor_changes_seed_received_by_source() {
        let base_seed = 0xDEAD_BEEF_u64;
        let config = SamplerConfig {
            seed: base_seed,
            ..SamplerConfig::default()
        };

        let seeds_epoch0 = Arc::new(Mutex::new(Vec::<u64>::new()));
        let seeds_epoch1 = Arc::new(Mutex::new(Vec::<u64>::new()));

        // First manager stays at the default epoch (0).
        let mut manager = IngestionManager::new(4, config.clone());
        manager.register_source(Box::new(SeedCapturingSource::new(
            "src",
            Arc::clone(&seeds_epoch0),
        )));
        manager.refresh_all();

        // Second manager is moved to epoch 1 before refreshing.
        let mut manager2 = IngestionManager::new(4, config.clone());
        manager2.register_source(Box::new(SeedCapturingSource::new(
            "src",
            Arc::clone(&seeds_epoch1),
        )));
        manager2.set_source_epoch(1);
        manager2.refresh_all();

        let received0 = seeds_epoch0.lock().unwrap();
        let received1 = seeds_epoch1.lock().unwrap();

        assert!(!received0.is_empty(), "epoch-0 source was never refreshed");
        assert!(!received1.is_empty(), "epoch-1 source was never refreshed");

        let seed_at_epoch0 = received0[0];
        let seed_at_epoch1 = received1[0];

        assert_ne!(
            seed_at_epoch0, seed_at_epoch1,
            "epoch 0 and epoch 1 both produced seed {seed_at_epoch0:#x}; \
             derive_epoch_seed is not reaching the source"
        );

        assert_eq!(
            seed_at_epoch0,
            derive_epoch_seed(base_seed, 0),
            "epoch-0 seed mismatch"
        );
        assert_eq!(
            seed_at_epoch1,
            derive_epoch_seed(base_seed, 1),
            "epoch-1 seed mismatch"
        );
    }
1286
/// Exercises the fallback paths of the `ScriptedSource` and `PanicSource` test doubles.
///
/// A `ScriptedSource` built with an empty script must yield an empty snapshot
/// at revision 0, report zero records and offer no triplet recipes. A
/// `PanicSource` must likewise report zero records.
#[test]
fn scripted_and_panic_sources_cover_default_trait_paths() {
    let scripted = ScriptedSource::new("scripted", Arc::new(AtomicUsize::new(0)), Vec::new());

    // With nothing scripted, refresh is expected to hand back an empty snapshot.
    let SourceSnapshot { records, cursor } = scripted
        .refresh(&SamplerConfig::default(), None, None)
        .expect("fallback snapshot");
    assert!(records.is_empty());
    assert_eq!(cursor.revision, 0);

    let scripted_count = scripted
        .reported_record_count(&SamplerConfig::default())
        .expect("record count");
    assert_eq!(scripted_count, 0);

    let scripted_recipes = scripted.default_triplet_recipes();
    assert!(scripted_recipes.is_empty());

    let panic_source = PanicSource {
        id: String::from("panic_count"),
    };
    let panic_count = panic_source
        .reported_record_count(&SamplerConfig::default())
        .expect("record count");
    assert_eq!(panic_count, 0);
}
1317
/// Confirms that `SeedCapturingSource` reports zero records and no triplet recipes.
#[test]
fn seed_capturing_source_trait_defaults_are_exercised() {
    let seed_log = Arc::new(Mutex::new(Vec::new()));
    let source = SeedCapturingSource::new("seed_defaults", seed_log);

    let reported = source
        .reported_record_count(&SamplerConfig::default())
        .expect("record count");
    assert_eq!(reported, 0);

    let recipes = source.default_triplet_recipes();
    assert!(recipes.is_empty());
}
1329
/// Covers two edge cases of the refresh paths.
///
/// 1. A manager constructed with `0` as its first argument (the "zero
///    capacity" case named by this test) must leave all caches empty after a
///    refresh, even though its source has one record scripted.
/// 2. A manager with no registered sources must accept
///    `refresh_all_with_weights` with an empty weight map without error and
///    keep its caches empty.
#[test]
fn refresh_paths_handle_zero_capacity_and_no_sources() {
    let mut manager = IngestionManager::new(0, SamplerConfig::default());

    // The source is scripted with a single one-record snapshot.
    let scripted_snapshot = SourceSnapshot {
        records: vec![make_record("r1", "zero_capacity")],
        cursor: SourceCursor {
            last_seen: Utc::now(),
            revision: 1,
        },
    };
    let source = ScriptedSource::new(
        "zero_capacity",
        Arc::new(AtomicUsize::new(0)),
        vec![Ok(scripted_snapshot)],
    );
    manager.register_source(Box::new(source));
    manager.refresh_all();
    assert!(manager.all_caches_empty());

    // Second case: nothing registered, weighted refresh with no weights.
    let mut empty_manager = IngestionManager::new(4, SamplerConfig::default());
    let refresh_outcome = empty_manager.refresh_all_with_weights(&HashMap::new());
    refresh_outcome.expect("no sources should not error");
    assert!(empty_manager.all_caches_empty());
}
1355
/// Checks that repeated `advance(1)` calls refresh three sources at a similar rate.
///
/// Each source counts its own refreshes. After the initial `refresh_all` every
/// counter must be exactly 1; after 33 further single-step advances the
/// highest and lowest counters may differ by at most one.
#[test]
fn drain_start_rotates_fairly_across_sources() {
    /// Source that counts its refreshes and returns ten records per refresh.
    struct FairSource {
        id: String,
        refresh_count: Arc<AtomicUsize>,
    }

    impl DataSource for FairSource {
        fn id(&self) -> &str {
            self.id.as_str()
        }

        fn refresh(
            &self,
            _config: &SamplerConfig,
            _cursor: Option<&SourceCursor>,
            _limit: Option<usize>,
        ) -> Result<SourceSnapshot, SamplerError> {
            self.refresh_count.fetch_add(1, Ordering::SeqCst);

            let mut records = Vec::with_capacity(10);
            for i in 0..10 {
                records.push(make_record(&format!("r{i}"), &self.id));
            }
            let cursor = SourceCursor {
                last_seen: Utc::now(),
                revision: 1,
            };
            Ok(SourceSnapshot { records, cursor })
        }

        fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
            Ok(10)
        }
    }

    // One refresh counter per source, in registration order.
    let counters: Vec<Arc<AtomicUsize>> =
        (0..3).map(|_| Arc::new(AtomicUsize::new(0))).collect();

    let mut manager = IngestionManager::new(30, SamplerConfig::default());
    for (id, counter) in ["src_0", "src_1", "src_2"].iter().zip(counters.iter()) {
        manager.register_source(Box::new(FairSource {
            id: id.to_string(),
            refresh_count: Arc::clone(counter),
        }));
    }

    manager.refresh_all();
    for counter in &counters {
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    let mut remaining_steps = 33;
    while remaining_steps > 0 {
        manager.advance(1);
        remaining_steps -= 1;
    }

    let r0 = counters[0].load(Ordering::SeqCst);
    let r1 = counters[1].load(Ordering::SeqCst);
    let r2 = counters[2].load(Ordering::SeqCst);

    let slowest = r0.min(r1.min(r2));
    let fastest = r0.max(r1.max(r2));
    assert!(
        fastest <= slowest + 1,
        "sources should refresh at roughly the same rate: got {r0}/{r1}/{r2}"
    );
}
1445
/// Pins the refresh distribution of five sources driven directly through the manager.
///
/// Every source returns eight records per refresh and counts its refreshes.
/// After `refresh_all` each counter must be 1; after 80 calls to `advance(2)`
/// the counters must equal exactly `[5, 5, 6, 6, 6]`.
#[test]
fn direct_drain_start_rotates_fairly_with_batch_2_of_5() {
    /// Source that counts its refreshes and returns eight section-less records.
    struct SimpleSource {
        id: String,
        refresh_count: Arc<AtomicUsize>,
    }

    impl DataSource for SimpleSource {
        fn id(&self) -> &str {
            self.id.as_str()
        }

        fn refresh(
            &self,
            _: &SamplerConfig,
            _: Option<&SourceCursor>,
            _: Option<usize>,
        ) -> Result<SourceSnapshot, SamplerError> {
            self.refresh_count.fetch_add(1, Ordering::SeqCst);

            // A single timestamp is shared by all records and the cursor.
            let now = Utc::now();
            let mut records: Vec<DataRecord> = Vec::with_capacity(8);
            for i in 0..8 {
                records.push(DataRecord {
                    id: format!("{}_r{i}", self.id),
                    source: self.id.clone(),
                    created_at: now,
                    updated_at: now,
                    quality: QualityScore { trust: 1.0 },
                    taxonomy: Vec::new(),
                    sections: Vec::new(),
                    meta_prefix: None,
                });
            }
            let cursor = SourceCursor {
                last_seen: now,
                revision: 1,
            };
            Ok(SourceSnapshot { records, cursor })
        }

        fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
            Ok(8)
        }
    }

    let counts: Vec<Arc<AtomicUsize>> = std::iter::repeat_with(|| Arc::new(AtomicUsize::new(0)))
        .take(5)
        .collect();

    let mut manager = IngestionManager::new(40, SamplerConfig::default());
    counts.iter().enumerate().for_each(|(i, count)| {
        manager.register_source(Box::new(SimpleSource {
            id: format!("src_{i}"),
            refresh_count: Arc::clone(count),
        }));
    });

    manager.refresh_all();
    for count in counts.iter() {
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    let mut remaining_steps = 80;
    while remaining_steps > 0 {
        manager.advance(2);
        remaining_steps -= 1;
    }

    let mut totals: Vec<usize> = Vec::with_capacity(counts.len());
    for count in &counts {
        totals.push(count.load(Ordering::SeqCst));
    }
    assert_eq!(
        totals,
        vec![5, 5, 6, 6, 6],
        "direct manager: unexpected refresh distribution"
    );
}
1520
1521 fn make_five_source_sampler(
1524 counts: &[Arc<AtomicUsize>],
1525 ) -> TripletSampler<DeterministicSplitStore> {
1526 struct Tracked {
1527 id: String,
1528 refresh_count: Arc<AtomicUsize>,
1529 }
1530 impl DataSource for Tracked {
1531 fn id(&self) -> &str {
1532 &self.id
1533 }
1534 fn refresh(
1535 &self,
1536 _: &SamplerConfig,
1537 _: Option<&SourceCursor>,
1538 _: Option<usize>,
1539 ) -> Result<SourceSnapshot, SamplerError> {
1540 self.refresh_count.fetch_add(1, Ordering::SeqCst);
1541 let now = Utc::now();
1542 let records: Vec<DataRecord> = (0..8)
1543 .map(|i| DataRecord {
1544 id: format!("{}_r{i}", self.id),
1545 source: self.id.clone(),
1546 created_at: now,
1547 updated_at: now,
1548 quality: QualityScore { trust: 1.0 },
1549 taxonomy: Vec::new(),
1550 sections: vec![RecordSection {
1551 role: SectionRole::Anchor,
1552 heading: None,
1553 text: format!("x{i}"),
1554 sentences: vec![format!("x{i}")],
1555 }],
1556 meta_prefix: None,
1557 })
1558 .collect();
1559 Ok(SourceSnapshot {
1560 records,
1561 cursor: SourceCursor {
1562 last_seen: now,
1563 revision: 1,
1564 },
1565 })
1566 }
1567 fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
1568 Ok(8)
1569 }
1570 }
1571
1572 let config = SamplerConfig {
1573 batch_size: 2,
1574 text_recipes: vec![TextRecipe {
1575 name: "anchor".into(),
1576 selector: Selector::Role(SectionRole::Anchor),
1577 weight: 1.0,
1578 instruction: None,
1579 }],
1580 split: SplitRatios {
1581 train: 1.0,
1582 validation: 0.0,
1583 test: 0.0,
1584 },
1585 allowed_splits: vec![SplitLabel::Train],
1586 ingestion_max_records: 40,
1587 ..SamplerConfig::default()
1588 };
1589 let store = Arc::new(DeterministicSplitStore::new(config.split, 99).unwrap());
1590 let sampler = TripletSampler::new(config, store);
1591
1592 for (i, count) in counts.iter().enumerate() {
1593 sampler.register_source(Box::new(Tracked {
1594 id: format!("src_{i}"),
1595 refresh_count: Arc::clone(count),
1596 }));
1597 }
1598 sampler
1599 }
1600
/// Checks that unweighted text batches refresh all five sources at a similar rate.
///
/// One initial batch is requested, followed by 80 more. Afterwards the refresh
/// counters may differ by at most one, and every source must have refreshed at
/// least four times.
#[test]
fn sampler_unweighted_drain_distributes_evenly() {
    let counts: Vec<Arc<AtomicUsize>> = std::iter::repeat_with(|| Arc::new(AtomicUsize::new(0)))
        .take(5)
        .collect();
    let sampler = make_five_source_sampler(&counts);

    // Initial batch, then the 80 batches whose refresh pattern is measured.
    sampler.next_text_batch(SplitLabel::Train).unwrap();
    let mut remaining_batches = 80;
    while remaining_batches > 0 {
        sampler.next_text_batch(SplitLabel::Train).unwrap();
        remaining_batches -= 1;
    }

    let mut totals: Vec<usize> = Vec::with_capacity(counts.len());
    for count in &counts {
        totals.push(count.load(Ordering::SeqCst));
    }
    let fewest = totals.iter().copied().min().unwrap();
    let most = totals.iter().copied().max().unwrap();

    assert!(
        most <= fewest + 1,
        "unweighted: all sources must refresh at roughly the same rate: {totals:?}"
    );
    assert!(
        fewest >= 4,
        "unweighted: each source should have refreshed at least 4 times: {totals:?}"
    );
}
1623
/// Checks that weighted text batches with identical weights refresh sources evenly.
///
/// All five sources get weight 1.0. After one initial batch and 80 further
/// batches, the refresh counters may differ by at most one and each source
/// must have refreshed at least four times.
#[test]
fn sampler_weighted_drain_with_equal_weights_distributes_evenly() {
    let counts: Vec<Arc<AtomicUsize>> = std::iter::repeat_with(|| Arc::new(AtomicUsize::new(0)))
        .take(5)
        .collect();
    let sampler = make_five_source_sampler(&counts);

    // Every source id maps to the same weight.
    let mut weights: HashMap<String, f32> = HashMap::new();
    for i in 0..5 {
        weights.insert(format!("src_{i}"), 1.0);
    }

    // Initial batch, then the 80 batches whose refresh pattern is measured.
    sampler
        .next_text_batch_with_weights(SplitLabel::Train, &weights)
        .unwrap();
    let mut remaining_batches = 80;
    while remaining_batches > 0 {
        sampler
            .next_text_batch_with_weights(SplitLabel::Train, &weights)
            .unwrap();
        remaining_batches -= 1;
    }

    let mut totals: Vec<usize> = Vec::with_capacity(counts.len());
    for count in &counts {
        totals.push(count.load(Ordering::SeqCst));
    }
    let fewest = totals.iter().copied().min().unwrap();
    let most = totals.iter().copied().max().unwrap();

    assert!(
        most <= fewest + 1,
        "weighted (equal): all sources must refresh at roughly the same rate: {totals:?}"
    );
    assert!(
        fewest >= 4,
        "weighted (equal): each source should have refreshed at least 4 times: {totals:?}"
    );
}
1655
/// Runs an unweighted and an equal-weighted sampler side by side and checks both stay balanced.
///
/// Each sampler gets its own set of five refresh counters. After one initial
/// batch and 80 further batches per sampler, each sampler's counters may
/// differ by at most one. Note that the two distributions are checked
/// independently; they are not compared element-by-element.
#[test]
fn sampler_unweighted_and_weighted_match_distribution() {
    let new_counters = || -> Vec<Arc<AtomicUsize>> {
        std::iter::repeat_with(|| Arc::new(AtomicUsize::new(0)))
            .take(5)
            .collect()
    };
    let uc = new_counters();
    let wc = new_counters();
    let usampler = make_five_source_sampler(&uc);
    let wsampler = make_five_source_sampler(&wc);

    // Every source id maps to the same weight.
    let mut weights: HashMap<String, f32> = HashMap::new();
    for i in 0..5 {
        weights.insert(format!("src_{i}"), 1.0);
    }

    // Initial batch on each sampler, then 80 interleaved batches.
    usampler.next_text_batch(SplitLabel::Train).unwrap();
    wsampler
        .next_text_batch_with_weights(SplitLabel::Train, &weights)
        .unwrap();
    let mut remaining_rounds = 80;
    while remaining_rounds > 0 {
        usampler.next_text_batch(SplitLabel::Train).unwrap();
        wsampler
            .next_text_batch_with_weights(SplitLabel::Train, &weights)
            .unwrap();
        remaining_rounds -= 1;
    }

    let mut ut: Vec<usize> = Vec::with_capacity(uc.len());
    for counter in &uc {
        ut.push(counter.load(Ordering::SeqCst));
    }
    let mut wt: Vec<usize> = Vec::with_capacity(wc.len());
    for counter in &wc {
        wt.push(counter.load(Ordering::SeqCst));
    }

    let (umax, umin) = (
        ut.iter().copied().max().unwrap(),
        ut.iter().copied().min().unwrap(),
    );
    let (wmax, wmin) = (
        wt.iter().copied().max().unwrap(),
        wt.iter().copied().min().unwrap(),
    );
    assert!(umax <= umin + 1, "unweighted: {ut:?}");
    assert!(wmax <= wmin + 1, "weighted equal: {wt:?}");
}
1686
/// Pins the refresh distribution when one source carries double weight.
///
/// `src_3` is given weight 2.0 and the other four sources weight 1.0. After
/// one initial batch and 200 further weighted batches, `src_3` must have
/// refreshed strictly more often than every other source, and the counters
/// must equal exactly `[6, 7, 7, 26, 12]`.
#[test]
fn sampler_weighted_drain_with_unequal_weights_respects_ratios() {
    let counts: Vec<Arc<AtomicUsize>> = std::iter::repeat_with(|| Arc::new(AtomicUsize::new(0)))
        .take(5)
        .collect();
    let sampler = make_five_source_sampler(&counts);

    // Only index 3 receives the heavier weight.
    let weights: HashMap<String, f32> = (0..5)
        .map(|i| (format!("src_{i}"), if i == 3 { 2.0f32 } else { 1.0 }))
        .collect();

    // Initial batch, then the 200 batches whose refresh pattern is measured.
    sampler
        .next_text_batch_with_weights(SplitLabel::Train, &weights)
        .unwrap();
    let mut remaining_batches = 200;
    while remaining_batches > 0 {
        sampler
            .next_text_batch_with_weights(SplitLabel::Train, &weights)
            .unwrap();
        remaining_batches -= 1;
    }

    let mut totals: Vec<usize> = Vec::with_capacity(counts.len());
    for count in &counts {
        totals.push(count.load(Ordering::SeqCst));
    }

    // Compare every unit-weight source against the double-weight one.
    let heavy_total = totals[3];
    let heavy_outpaces_rest = totals
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != 3)
        .all(|(_, &t)| t < heavy_total);
    assert!(
        heavy_outpaces_rest,
        "src_3 (w=2.0) must outpace all w=1.0 sources (totals: {totals:?})"
    );
    assert_eq!(
        totals,
        vec![6, 7, 7, 26, 12],
        "unequal-weights: unexpected refresh distribution"
    );
}
1731}