1use crate::config::SamplerConfig;
2use crate::constants::splits::is_reserved_source_id;
3use crate::data::DataRecord;
4use crate::errors::SamplerError;
5use crate::hash::derive_epoch_seed;
6use crate::source::{DataSource, SourceCursor, SourceSnapshot};
7use crate::types::{RecordId, SourceId};
8use chrono::Utc;
9use indexmap::IndexMap;
10use std::collections::HashMap;
11use std::collections::VecDeque;
12use std::sync::{Arc, Condvar, Mutex, RwLock};
13use std::thread;
14use std::time::Duration;
15use tracing::debug;
16
17#[derive(Clone)]
19pub struct RecordCache {
20 inner: Arc<RwLock<RecordCacheInner>>,
21 notifier: Arc<(Mutex<CacheStats>, Condvar)>,
22}
23
24struct RecordCacheInner {
26 records: IndexMap<RecordId, CachedRecord>,
27 order: VecDeque<RecordId>,
28 max_records: usize,
29 next_version: u64,
30}
31
32struct CachedRecord {
34 record: DataRecord,
35 version: u64,
36}
37
/// Counters protected by the cache's notifier mutex.
#[derive(Default)]
struct CacheStats {
    /// Number of non-empty batches ingested so far (saturating).
    ingests: u64,
}
43
44impl RecordCache {
45 pub fn new(max_records: usize) -> Self {
47 Self {
48 inner: Arc::new(RwLock::new(RecordCacheInner {
49 records: IndexMap::new(),
50 order: VecDeque::new(),
51 max_records,
52 next_version: 0,
53 })),
54 notifier: Arc::new((Mutex::new(CacheStats::default()), Condvar::new())),
55 }
56 }
57
58 pub fn ingest<I>(&self, records: I)
60 where
61 I: IntoIterator<Item = DataRecord>,
62 {
63 let mut batch: Vec<DataRecord> = records.into_iter().collect();
64 if batch.is_empty() {
65 return;
66 }
67 let mut inner = self.inner.write().expect("record cache poisoned");
68 inner.ingest_batch(&mut batch);
69 drop(inner);
70 let (lock, cvar) = &*self.notifier;
71 let mut stats = lock.lock().expect("record cache stats poisoned");
72 stats.ingests = stats.ingests.saturating_add(1);
73 cvar.notify_all();
74 }
75
76 pub fn clear(&self) {
78 let mut inner = self.inner.write().expect("record cache poisoned");
79 inner.records.clear();
80 inner.order.clear();
81 }
82
83 pub fn snapshot(&self) -> Vec<DataRecord> {
85 let inner = self.inner.read().expect("record cache poisoned");
86 inner
87 .records
88 .values()
89 .map(|entry| entry.record.clone())
90 .collect()
91 }
92
93 pub fn ingest_count(&self) -> u64 {
95 let (lock, _) = &*self.notifier;
96 lock.lock().expect("record cache stats poisoned").ingests
97 }
98
99 pub fn wait_for_ingest(&self, last_seen: u64, timeout: Duration) -> u64 {
101 let (lock, cvar) = &*self.notifier;
102 let mut stats = lock.lock().expect("record cache stats poisoned");
103 while stats.ingests <= last_seen {
104 let result = cvar
105 .wait_timeout(stats, timeout)
106 .expect("record cache stats poisoned");
107 stats = result.0;
108 if result.1.timed_out() {
109 break;
110 }
111 }
112 stats.ingests
113 }
114
115 pub fn wait_for_ingest_blocking(&self, last_seen: u64) -> u64 {
117 let (lock, cvar) = &*self.notifier;
118 let mut stats = lock.lock().expect("record cache stats poisoned");
119 while stats.ingests <= last_seen {
120 stats = cvar.wait(stats).expect("record cache stats poisoned");
121 }
122 stats.ingests
123 }
124
125 pub fn is_empty(&self) -> bool {
127 let inner = self.inner.read().expect("record cache poisoned");
128 inner.records.is_empty()
129 }
130
131 pub fn len(&self) -> usize {
133 let inner = self.inner.read().expect("record cache poisoned");
134 inner.records.len()
135 }
136}
137
138impl RecordCacheInner {
139 fn ingest_batch(&mut self, records: &mut Vec<DataRecord>) {
140 for record in records.drain(..) {
141 self.next_version = self.next_version.saturating_add(1);
142 let record_id = record.id.clone();
143 if self.records.contains_key(&record_id) {
144 if let Some(entry) = self.records.get_mut(&record_id) {
145 entry.record = record;
146 entry.version = self.next_version;
147 }
148 Self::refresh_order(&mut self.order, &record_id);
149 self.order.push_back(record_id);
150 } else {
151 self.order.push_back(record_id.clone());
152 self.records.insert(
153 record_id,
154 CachedRecord {
155 record,
156 version: self.next_version,
157 },
158 );
159 }
160 self.enforce_limit();
161 }
162 }
163
164 fn enforce_limit(&mut self) {
165 if self.max_records == 0 {
166 self.records.clear();
167 self.order.clear();
168 return;
169 }
170 while self.records.len() > self.max_records {
171 if let Some(oldest) = self.order.pop_front() {
172 self.records.swap_remove(&oldest);
173 } else {
174 break;
175 }
176 }
177 }
178
179 fn refresh_order(order: &mut VecDeque<RecordId>, id: &RecordId) {
180 if order.is_empty() {
181 return;
182 }
183 if let Some(pos) = order.iter().position(|existing| existing == id) {
184 order.remove(pos);
185 }
186 }
187}
188
189pub struct IngestionManager {
191 sources: Vec<SourceState>,
192 max_records: usize,
193 sampler_config: SamplerConfig,
194 epoch: u64,
196 source_refresh_generation: u64,
198 last_refreshed_sources: Vec<SourceId>,
203 drain_start: usize,
208 epoch_step: u64,
211}
212
/// Outcome of the most recent refresh of a single source, plus a running
/// error counter.
#[derive(Clone, Debug, Default)]
pub struct SourceRefreshStats {
    /// Wall-clock duration of the last refresh, in milliseconds.
    pub last_refresh_ms: u128,
    /// Records returned by the last refresh (0 when it failed).
    pub last_record_count: usize,
    /// Throughput of the last refresh (0.0 when it failed or took no time).
    pub last_records_per_sec: f64,
    /// Error message from the last refresh, or `None` if it succeeded.
    pub last_error: Option<String>,
    /// Total number of failed refreshes (saturating).
    pub error_count: u64,
}
227
228impl IngestionManager {
229 pub fn new(max_records: usize, sampler_config: SamplerConfig) -> Self {
231 Self {
232 sources: Vec::new(),
233 max_records,
234 sampler_config,
235 epoch: 0,
236 source_refresh_generation: 0,
237 last_refreshed_sources: Vec::new(),
238 drain_start: 0,
239 epoch_step: 0,
240 }
241 }
242
243 pub fn source_refresh_generation(&self) -> u64 {
245 self.source_refresh_generation
246 }
247
248 pub fn last_refreshed_sources(&self) -> &[SourceId] {
250 &self.last_refreshed_sources
251 }
252
253 pub(crate) fn set_epoch(&mut self, epoch: u64) {
260 self.epoch = epoch;
261 }
262
263 pub(crate) fn reset_epoch_step(&mut self) {
267 self.epoch_step = 0;
268 }
269
270 pub(crate) fn increment_epoch_step(&mut self) {
273 self.epoch_step = self.epoch_step.saturating_add(1);
274 }
275
276 pub fn epoch_step(&self) -> u64 {
278 self.epoch_step
279 }
280
281 pub(crate) fn set_epoch_step(&mut self, step: u64) {
283 self.epoch_step = step;
284 }
285
286 #[cfg(test)]
288 pub fn epoch(&self) -> u64 {
289 self.epoch
290 }
291
292 pub(crate) fn reset_stream_cursors(&mut self) {
298 for state in &mut self.sources {
299 state.cursor = None;
300 state.buffer.clear();
301 state.cache.clear();
302 }
303 }
304
305 pub fn register_source(
310 &mut self,
311 source: Box<dyn DataSource + 'static>,
312 ) -> Result<(), SamplerError> {
313 let source_id = source.id().to_string();
314 if is_reserved_source_id(&source_id) {
315 return Err(SamplerError::ReservedSourceId(source_id));
316 }
317 let cache = RecordCache::new(self.max_records);
318 self.sources.push(SourceState {
319 source,
320 cursor: None,
321 buffer: VecDeque::new(),
322 cache,
323 stats: SourceRefreshStats::default(),
324 });
325 Ok(())
326 }
327
328 pub fn load_cursors(&mut self, cursors: &[(SourceId, u64)]) {
330 if cursors.is_empty() {
331 return;
332 }
333 let mut map = std::collections::HashMap::with_capacity(cursors.len());
334 for (id, revision) in cursors {
335 map.insert(id.as_str(), *revision);
336 }
337 for state in &mut self.sources {
338 if let Some(revision) = map.get(state.source.id()) {
339 state.cursor = Some(SourceCursor {
340 last_seen: Utc::now(),
341 revision: *revision,
342 });
343 }
344 }
345 }
346
347 pub fn snapshot_cursors(&self) -> Vec<(SourceId, u64)> {
349 let mut out = Vec::new();
350 for state in &self.sources {
351 if let Some(cursor) = state.cursor.as_ref() {
352 out.push((state.source.id().to_string(), cursor.revision));
353 }
354 }
355 out
356 }
357
358 pub fn source_refresh_stats(&self) -> Vec<(SourceId, SourceRefreshStats)> {
360 self.sources
361 .iter()
362 .map(|state| (state.source.id().to_string(), state.stats.clone()))
363 .collect()
364 }
365
366 pub fn all_records_snapshot(&self) -> Vec<DataRecord> {
371 self.sources
372 .iter()
373 .flat_map(|s| s.cache.snapshot())
374 .collect()
375 }
376
377 pub fn all_caches_empty(&self) -> bool {
379 self.sources.iter().all(|s| s.cache.is_empty())
380 }
381
382 pub fn all_records_len(&self) -> usize {
384 self.sources.iter().map(|s| s.cache.len()).sum()
385 }
386
387 pub fn total_ingest_count(&self) -> u64 {
392 self.sources.iter().map(|s| s.cache.ingest_count()).sum()
393 }
394
395 pub fn refresh_all(&mut self) {
397 self.refresh_all_internal(false, None, None);
398 }
399
400 pub fn advance(&mut self, step: usize) {
402 self.refresh_all_internal(false, Some(step), None);
403 }
404
405 pub fn advance_with_weights(
410 &mut self,
411 step: usize,
412 weights: &HashMap<SourceId, f32>,
413 ) -> Result<(), SamplerError> {
414 self.validate_weights(weights)?;
415 self.refresh_all_internal(false, Some(step), Some(weights));
416 Ok(())
417 }
418
419 pub fn force_refresh_all(&mut self) {
421 self.refresh_all_internal(true, None, None);
422 }
423
424 pub fn refresh_all_with_weights(
429 &mut self,
430 weights: &HashMap<SourceId, f32>,
431 ) -> Result<(), SamplerError> {
432 self.validate_weights(weights)?;
433 self.refresh_all_internal(false, None, Some(weights));
434 Ok(())
435 }
436
437 pub fn force_refresh_all_with_weights(
442 &mut self,
443 weights: &HashMap<SourceId, f32>,
444 ) -> Result<(), SamplerError> {
445 self.validate_weights(weights)?;
446 self.refresh_all_internal(true, None, Some(weights));
447 Ok(())
448 }
449
450 fn validate_weights(&self, weights: &HashMap<SourceId, f32>) -> Result<(), SamplerError> {
451 let known_ids: std::collections::HashSet<&str> =
452 self.sources.iter().map(|s| s.source.id()).collect();
453 for (id, &w) in weights {
454 if !known_ids.contains(id.as_str()) {
455 return Err(SamplerError::InvalidWeight {
456 source_id: id.clone(),
457 reason: "source is not registered".to_string(),
458 });
459 }
460 if w < 0.0 {
461 return Err(SamplerError::InvalidWeight {
462 source_id: id.clone(),
463 reason: format!("weight {w} is negative"),
464 });
465 }
466 }
467 Ok(())
468 }
469
470 fn refresh_all_internal(
477 &mut self,
478 force_refresh: bool,
479 step: Option<usize>,
480 weights: Option<&HashMap<SourceId, f32>>,
481 ) {
482 self.last_refreshed_sources.clear();
483 let mut refresh_plan = Vec::new();
484 for (idx, state) in self.sources.iter_mut().enumerate() {
485 if force_refresh {
486 state.buffer.clear();
487 }
488 if force_refresh || state.buffer.is_empty() {
489 refresh_plan.push((idx, state.cursor.clone()));
490 }
491 }
492
493 if !refresh_plan.is_empty() {
494 self.source_refresh_generation = self.source_refresh_generation.saturating_add(1);
495 self.last_refreshed_sources = refresh_plan
496 .iter()
497 .map(|(idx, _)| self.sources[*idx].source.id().to_string())
498 .collect();
499 let mut results: Vec<
500 Option<(Result<SourceSnapshot, SamplerError>, std::time::Duration)>,
501 > = Vec::with_capacity(self.sources.len());
502 results.resize_with(self.sources.len(), || None);
503 let fetch_limit = self.max_records;
504 let sampler_config = self.sampler_config.clone();
505 let step = self.epoch_step;
506 thread::scope(|scope| {
507 let mut handles = Vec::with_capacity(refresh_plan.len());
508 for (idx, cursor) in &refresh_plan {
509 let source = &self.sources[*idx].source;
510 let cursor = cursor.clone();
511 let sampler_config = sampler_config.clone();
512 let epoch = self.epoch;
513 handles.push((
514 *idx,
515 scope.spawn(move || {
516 let start = std::time::Instant::now();
517 let epoch_config = SamplerConfig {
523 seed: derive_epoch_seed(sampler_config.seed, epoch) ^ step,
524 ..sampler_config
525 };
526 let result =
527 source.refresh(&epoch_config, cursor.as_ref(), Some(fetch_limit));
528 let elapsed = start.elapsed();
529 (result, elapsed)
530 }),
531 ));
532 }
533 for (idx, handle) in handles {
534 let result = match handle.join() {
535 Ok((result, elapsed)) => {
536 debug!(
537 source_id = %self.sources[idx].source.id(),
538 refresh_ms = elapsed.as_millis(),
539 "source refresh completed"
540 );
541 (result, elapsed)
542 }
543 Err(_) => (
544 Err(SamplerError::SourceUnavailable {
545 source_id: self.sources[idx].source.id().to_string(),
546 reason: "source refresh thread panicked".into(),
547 }),
548 std::time::Duration::from_secs(0),
549 ),
550 };
551 results[idx] = Some(result);
552 }
553 });
554
555 for (idx, _) in refresh_plan {
556 let Some((result, elapsed)) = results[idx].take() else {
557 continue;
558 };
559 match result {
560 Ok(snapshot) => {
561 let SourceSnapshot {
562 records,
563 cursor: next_cursor,
564 } = snapshot;
565 let record_count = records.len();
566 let seconds = elapsed.as_secs_f64();
567 let per_sec = if seconds > 0.0 {
568 (record_count as f64) / seconds
569 } else {
570 0.0
571 };
572 let stats = &mut self.sources[idx].stats;
573 stats.last_refresh_ms = elapsed.as_millis();
574 stats.last_record_count = record_count;
575 stats.last_records_per_sec = per_sec;
576 stats.last_error = None;
577 debug!(
578 source_id = %self.sources[idx].source.id(),
579 record_count,
580 refresh_ms = elapsed.as_millis(),
581 records_per_sec = per_sec,
582 "source refresh ingested records"
583 );
584 let source_id = self.sources[idx].source.id().to_string();
585 let normalized = records
586 .into_iter()
587 .map(|mut record| {
588 record.source = source_id.clone();
589 record
590 })
591 .collect::<Vec<_>>();
592 self.sources[idx].buffer.extend(normalized);
593 self.sources[idx].cursor = Some(next_cursor);
594 }
595 Err(err) => {
596 let stats = &mut self.sources[idx].stats;
597 stats.last_refresh_ms = elapsed.as_millis();
598 stats.last_record_count = 0;
599 stats.last_records_per_sec = 0.0;
600 stats.last_error = Some(err.to_string());
601 stats.error_count = stats.error_count.saturating_add(1);
602 eprintln!(
603 "[data_sampler] source '{}' refresh failed: {}",
604 self.sources[idx].source.id(),
605 err
606 );
607 }
608 }
609 }
610 }
611
612 if step.is_none() {
616 for state in self.sources.iter_mut() {
617 state.cache.clear();
618 }
619 }
620 if self.max_records == 0 {
621 return;
622 }
623 let target_limit = step.unwrap_or(self.max_records);
624 if let Some(weights) = weights {
625 self.weighted_drain_into_caches(target_limit, weights);
626 } else {
627 let n = self.sources.len();
632 if n > 0 {
633 let mut per_source: Vec<Vec<DataRecord>> = vec![Vec::new(); n];
634 let mut total_drained = 0;
635 let mut any_remaining = true;
636 while total_drained < target_limit && any_remaining {
637 any_remaining = false;
638 for offset in 0..n {
639 if total_drained >= target_limit {
640 break;
641 }
642 let idx = (self.drain_start + offset) % n;
643 if let Some(record) = self.sources[idx].buffer.pop_front() {
644 per_source[idx].push(record);
645 total_drained += 1;
646 any_remaining = true;
647 }
648 }
649 }
650 if total_drained > 0 {
654 self.drain_start = (self.drain_start + 1) % n;
655 }
656 for (idx, batch) in per_source.into_iter().enumerate() {
657 if !batch.is_empty() {
658 self.sources[idx].cache.ingest(batch);
659 }
660 }
661 }
662 }
663 }
664
665 fn weighted_drain_into_caches(&mut self, limit: usize, weights: &HashMap<SourceId, f32>) {
666 let len = self.sources.len();
667 if len == 0 {
668 return;
669 }
670 let mut weight_values = Vec::with_capacity(len);
671 let mut any_positive = false;
672 for state in &self.sources {
673 let weight = weights.get(state.source.id()).copied().unwrap_or(1.0);
674 if weight > 0.0 {
675 any_positive = true;
676 }
677 weight_values.push(weight);
678 }
679 if !any_positive {
680 weight_values.fill(1.0);
681 }
682
683 let mut current = vec![0.0f32; len];
684 let mut per_source: Vec<Vec<DataRecord>> = vec![Vec::new(); len];
685 let mut total = 0;
686 while total < limit {
687 let mut total_weight = 0.0f32;
688 for (idx, weight) in weight_values.iter().copied().enumerate().take(len) {
689 if weight <= 0.0 {
690 continue;
691 }
692 if self.sources[idx].buffer.is_empty() {
693 continue;
694 }
695 total_weight += weight;
696 }
697 if total_weight == 0.0 {
698 break;
699 }
700
701 let mut best_idx = None;
702 let mut best_score = f32::MIN;
703 let closer_to_start = |a: usize, b: usize| -> bool {
707 let da = (a + len - self.drain_start) % len;
708 let db = (b + len - self.drain_start) % len;
709 da < db
710 };
711 for idx in 0..len {
712 if weight_values[idx] <= 0.0 {
713 continue;
714 }
715 if self.sources[idx].buffer.is_empty() {
716 continue;
717 }
718 current[idx] += weight_values[idx];
719 let is_better = if current[idx] > best_score {
720 true
721 } else if current[idx] == best_score {
722 closer_to_start(idx, best_idx.unwrap_or(0))
723 } else {
724 false
725 };
726 if is_better {
727 best_score = current[idx];
728 best_idx = Some(idx);
729 }
730 }
731
732 let idx = match best_idx {
733 Some(idx) => idx,
734 None => break,
735 };
736 current[idx] -= total_weight;
737 if let Some(record) = self.sources[idx].buffer.pop_front() {
738 per_source[idx].push(record);
739 total += 1;
740 }
741 }
742
743 if total > 0 && len > 0 {
744 self.drain_start = (self.drain_start + 1) % len;
745 }
746
747 for (idx, batch) in per_source.into_iter().enumerate() {
748 if !batch.is_empty() {
749 self.sources[idx].cache.ingest(batch);
750 }
751 }
752 }
753
754 pub fn has_sources(&self) -> bool {
756 !self.sources.is_empty()
757 }
758}
759
760struct SourceState {
762 source: Box<dyn DataSource + 'static>,
763 cursor: Option<SourceCursor>,
764 buffer: VecDeque<DataRecord>,
765 cache: RecordCache,
767 stats: SourceRefreshStats,
768}
769
770#[cfg(test)]
771mod tests {
772 use super::*;
773 use crate::TripletSampler;
774 use crate::config::{Selector, TextRecipe, TripletRecipe};
775 use crate::data::{QualityScore, RecordSection, SectionRole};
776 use crate::sampler::Sampler;
777 use crate::splits::{DeterministicSplitStore, SamplerStateStore, SplitLabel, SplitRatios};
778 use chrono::Utc;
779 use std::collections::HashMap;
780 use std::collections::VecDeque;
781 use std::sync::atomic::{AtomicUsize, Ordering};
782 use std::sync::{Arc, Mutex};
783
784 fn make_record(id: &str, source: &str) -> DataRecord {
785 let now = Utc::now();
786 DataRecord {
787 id: id.to_string(),
788 source: source.to_string(),
789 created_at: now,
790 updated_at: now,
791 quality: QualityScore { trust: 1.0 },
792 taxonomy: Vec::new(),
793 sections: vec![RecordSection {
794 role: SectionRole::Anchor,
795 heading: None,
796 text: id.to_string(),
797 sentences: vec![id.to_string()],
798 }],
799 meta_prefix: None,
800 }
801 }
802
803 struct ScriptedSource {
804 id: String,
805 refreshes: Arc<AtomicUsize>,
806 script: Arc<Mutex<VecDeque<Result<SourceSnapshot, SamplerError>>>>,
807 }
808
809 impl ScriptedSource {
810 fn new(
811 id: &str,
812 refreshes: Arc<AtomicUsize>,
813 script: Vec<Result<SourceSnapshot, SamplerError>>,
814 ) -> Self {
815 Self {
816 id: id.to_string(),
817 refreshes,
818 script: Arc::new(Mutex::new(script.into_iter().collect())),
819 }
820 }
821 }
822
823 impl DataSource for ScriptedSource {
824 fn id(&self) -> &str {
825 &self.id
826 }
827
828 fn refresh(
829 &self,
830 _config: &SamplerConfig,
831 _cursor: Option<&SourceCursor>,
832 _limit: Option<usize>,
833 ) -> Result<SourceSnapshot, SamplerError> {
834 self.refreshes.fetch_add(1, Ordering::SeqCst);
835 let mut guard = self.script.lock().expect("script lock poisoned");
836 guard.pop_front().unwrap_or_else(|| {
837 Ok(SourceSnapshot {
838 records: Vec::new(),
839 cursor: SourceCursor {
840 last_seen: Utc::now(),
841 revision: 0,
842 },
843 })
844 })
845 }
846
847 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
848 Ok(0)
849 }
850
851 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
852 Vec::new()
853 }
854 }
855
856 struct PanicSource {
857 id: String,
858 }
859
860 impl DataSource for PanicSource {
861 fn id(&self) -> &str {
862 &self.id
863 }
864
865 fn refresh(
866 &self,
867 _config: &SamplerConfig,
868 _cursor: Option<&SourceCursor>,
869 _limit: Option<usize>,
870 ) -> Result<SourceSnapshot, SamplerError> {
871 panic!("panic source refresh")
872 }
873
874 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
875 Ok(0)
876 }
877 }
878
879 #[test]
880 fn record_cache_waits_len_and_clear_paths_are_covered() {
881 let cache = RecordCache::new(2);
882 assert!(cache.is_empty());
883 assert_eq!(cache.len(), 0);
884 assert_eq!(cache.ingest_count(), 0);
885
886 cache.ingest(Vec::<DataRecord>::new());
887 assert_eq!(cache.wait_for_ingest(0, Duration::from_millis(1)), 0);
888
889 let cache_for_waiter = cache.clone();
890 let handle = std::thread::spawn(move || cache_for_waiter.wait_for_ingest_blocking(0));
891 std::thread::sleep(Duration::from_millis(5));
892 cache.ingest(vec![make_record("r1", "s")]);
893 assert_eq!(handle.join().unwrap(), 1);
894 assert_eq!(cache.ingest_count(), 1);
895
896 cache.ingest(vec![make_record("r2", "s"), make_record("r3", "s")]);
897 assert_eq!(cache.len(), 2);
898 let ids: Vec<String> = cache
899 .snapshot()
900 .into_iter()
901 .map(|record| record.id)
902 .collect();
903 assert!(ids.contains(&"r2".to_string()));
904 assert!(ids.contains(&"r3".to_string()));
905
906 cache.clear();
907 assert!(cache.is_empty());
908 }
909
910 #[test]
911 fn record_cache_zero_limit_discards_everything() {
912 let cache = RecordCache::new(0);
913 cache.ingest(vec![make_record("r1", "s")]);
914 assert!(cache.is_empty());
915 assert_eq!(cache.len(), 0);
916 }
917
918 #[test]
919 fn manager_loads_and_snapshots_cursors_and_reports_has_sources() {
920 let mut manager = IngestionManager::new(4, SamplerConfig::default());
921 assert!(!manager.has_sources());
922 manager.load_cursors(&[]);
923
924 let refreshes = Arc::new(AtomicUsize::new(0));
925 manager
926 .register_source(Box::new(ScriptedSource::new(
927 "cursor_source",
928 refreshes,
929 vec![Ok(SourceSnapshot {
930 records: vec![make_record("id_1", "original_source")],
931 cursor: SourceCursor {
932 last_seen: Utc::now(),
933 revision: 33,
934 },
935 })],
936 )))
937 .unwrap();
938 assert!(manager.has_sources());
939
940 manager.load_cursors(&[("cursor_source".to_string(), 7)]);
941 let cursors = manager.snapshot_cursors();
942 assert_eq!(cursors.len(), 1);
943 assert_eq!(cursors[0], ("cursor_source".to_string(), 7));
944
945 manager.refresh_all();
946 let updated = manager.snapshot_cursors();
947 assert_eq!(updated.len(), 1);
948 assert_eq!(updated[0], ("cursor_source".to_string(), 33));
949 let records = manager.all_records_snapshot();
950 assert_eq!(records.len(), 1);
951 assert_eq!(records[0].source, "cursor_source");
952 }
953
954 #[test]
955 fn advance_uses_buffer_before_refreshing_again() {
956 let refreshes = Arc::new(AtomicUsize::new(0));
957 let mut manager = IngestionManager::new(5, SamplerConfig::default());
958 manager
959 .register_source(Box::new(ScriptedSource::new(
960 "buffered",
961 refreshes.clone(),
962 vec![Ok(SourceSnapshot {
963 records: vec![
964 make_record("a", "x"),
965 make_record("b", "x"),
966 make_record("c", "x"),
967 ],
968 cursor: SourceCursor {
969 last_seen: Utc::now(),
970 revision: 1,
971 },
972 })],
973 )))
974 .unwrap();
975
976 manager.advance(1);
977 assert_eq!(refreshes.load(Ordering::SeqCst), 1);
978 assert_eq!(manager.all_records_len(), 1);
979
980 manager.advance(1);
981 assert_eq!(refreshes.load(Ordering::SeqCst), 1);
982 assert_eq!(manager.all_records_len(), 2);
983 }
984
985 #[test]
986 fn force_refresh_clears_buffer_and_fetches_again() {
987 let refreshes = Arc::new(AtomicUsize::new(0));
988 let mut manager = IngestionManager::new(4, SamplerConfig::default());
989 manager
990 .register_source(Box::new(ScriptedSource::new(
991 "force",
992 refreshes.clone(),
993 vec![
994 Ok(SourceSnapshot {
995 records: vec![
996 make_record("r1", "x"),
997 make_record("r2", "x"),
998 make_record("r3", "x"),
999 ],
1000 cursor: SourceCursor {
1001 last_seen: Utc::now(),
1002 revision: 10,
1003 },
1004 }),
1005 Ok(SourceSnapshot {
1006 records: vec![make_record("r4", "x")],
1007 cursor: SourceCursor {
1008 last_seen: Utc::now(),
1009 revision: 11,
1010 },
1011 }),
1012 ],
1013 )))
1014 .unwrap();
1015
1016 manager.advance(1);
1017 assert_eq!(manager.all_records_len(), 1);
1018 assert_eq!(refreshes.load(Ordering::SeqCst), 1);
1019
1020 manager.force_refresh_all();
1021 assert_eq!(refreshes.load(Ordering::SeqCst), 2);
1022 let records = manager.all_records_snapshot();
1023 assert_eq!(records.len(), 1);
1024 assert_eq!(records[0].id, "r4");
1025 }
1026
1027 #[test]
1028 fn weighted_drain_respects_zero_and_fallback_weights() {
1029 let mut manager = IngestionManager::new(6, SamplerConfig::default());
1030 manager
1031 .register_source(Box::new(ScriptedSource::new(
1032 "a",
1033 Arc::new(AtomicUsize::new(0)),
1034 vec![Ok(SourceSnapshot {
1035 records: vec![make_record("a1", "a"), make_record("a2", "a")],
1036 cursor: SourceCursor {
1037 last_seen: Utc::now(),
1038 revision: 1,
1039 },
1040 })],
1041 )))
1042 .unwrap();
1043 manager
1044 .register_source(Box::new(ScriptedSource::new(
1045 "b",
1046 Arc::new(AtomicUsize::new(0)),
1047 vec![Ok(SourceSnapshot {
1048 records: vec![make_record("b1", "b"), make_record("b2", "b")],
1049 cursor: SourceCursor {
1050 last_seen: Utc::now(),
1051 revision: 1,
1052 },
1053 })],
1054 )))
1055 .unwrap();
1056
1057 let mut only_b = HashMap::new();
1058 only_b.insert("a".to_string(), 0.0);
1059 only_b.insert("b".to_string(), 1.0);
1060 manager.refresh_all_with_weights(&only_b).unwrap();
1061 let ids: Vec<String> = manager
1062 .all_records_snapshot()
1063 .into_iter()
1064 .map(|record| record.id)
1065 .collect();
1066 assert!(ids.iter().all(|id| id.starts_with('b')));
1067
1068 let mut manager_fallback = IngestionManager::new(6, SamplerConfig::default());
1069 manager_fallback
1070 .register_source(Box::new(ScriptedSource::new(
1071 "a",
1072 Arc::new(AtomicUsize::new(0)),
1073 vec![Ok(SourceSnapshot {
1074 records: vec![make_record("a1", "a")],
1075 cursor: SourceCursor {
1076 last_seen: Utc::now(),
1077 revision: 2,
1078 },
1079 })],
1080 )))
1081 .unwrap();
1082 manager_fallback
1083 .register_source(Box::new(ScriptedSource::new(
1084 "b",
1085 Arc::new(AtomicUsize::new(0)),
1086 vec![Ok(SourceSnapshot {
1087 records: vec![make_record("b1", "b")],
1088 cursor: SourceCursor {
1089 last_seen: Utc::now(),
1090 revision: 2,
1091 },
1092 })],
1093 )))
1094 .unwrap();
1095
1096 let mut all_zero = HashMap::new();
1097 all_zero.insert("a".to_string(), 0.0);
1098 all_zero.insert("b".to_string(), 0.0);
1099 manager_fallback
1100 .refresh_all_with_weights(&all_zero)
1101 .unwrap();
1102 let ids: Vec<String> = manager_fallback
1103 .all_records_snapshot()
1104 .into_iter()
1105 .map(|record| record.id)
1106 .collect();
1107 assert!(ids.contains(&"a1".to_string()));
1108 assert!(ids.contains(&"b1".to_string()));
1109 }
1110
1111 #[test]
1112 fn refresh_errors_and_panics_update_source_stats() {
1113 let mut manager = IngestionManager::new(4, SamplerConfig::default());
1114 manager
1115 .register_source(Box::new(ScriptedSource::new(
1116 "err_source",
1117 Arc::new(AtomicUsize::new(0)),
1118 vec![Err(SamplerError::SourceUnavailable {
1119 source_id: "err_source".to_string(),
1120 reason: "boom".to_string(),
1121 })],
1122 )))
1123 .unwrap();
1124 manager
1125 .register_source(Box::new(PanicSource {
1126 id: "panic_source".to_string(),
1127 }))
1128 .unwrap();
1129
1130 manager.refresh_all();
1131 let stats = manager.source_refresh_stats();
1132 let err_stats = stats
1133 .iter()
1134 .find(|(source, _)| source == "err_source")
1135 .map(|(_, stats)| stats)
1136 .unwrap();
1137 assert_eq!(err_stats.error_count, 1);
1138 assert!(
1139 err_stats
1140 .last_error
1141 .as_ref()
1142 .is_some_and(|msg| msg.contains("boom"))
1143 );
1144
1145 let panic_stats = stats
1146 .iter()
1147 .find(|(source, _)| source == "panic_source")
1148 .map(|(_, stats)| stats)
1149 .unwrap();
1150 assert_eq!(panic_stats.error_count, 1);
1151 assert!(
1152 panic_stats
1153 .last_error
1154 .as_ref()
1155 .is_some_and(|msg| msg.contains("panicked"))
1156 );
1157 }
1158
1159 #[test]
1160 fn force_refresh_with_weights_path_is_exercised() {
1161 let mut manager = IngestionManager::new(3, SamplerConfig::default());
1162 manager
1163 .register_source(Box::new(ScriptedSource::new(
1164 "w",
1165 Arc::new(AtomicUsize::new(0)),
1166 vec![Ok(SourceSnapshot {
1167 records: vec![make_record("w1", "w")],
1168 cursor: SourceCursor {
1169 last_seen: Utc::now(),
1170 revision: 3,
1171 },
1172 })],
1173 )))
1174 .unwrap();
1175
1176 let mut weights = HashMap::new();
1177 weights.insert("w".to_string(), 1.0);
1178 manager.force_refresh_all_with_weights(&weights).unwrap();
1179 assert_eq!(manager.all_records_len(), 1);
1180 }
1181
1182 #[test]
1183 fn advance_with_weights_rejects_unknown_source() {
1184 let mut manager = IngestionManager::new(4, SamplerConfig::default());
1185 manager
1186 .register_source(Box::new(ScriptedSource::new(
1187 "known",
1188 Arc::new(AtomicUsize::new(0)),
1189 vec![],
1190 )))
1191 .unwrap();
1192
1193 let mut weights = HashMap::new();
1194 weights.insert("known".to_string(), 1.0);
1195 weights.insert("unknown".to_string(), 0.5);
1196
1197 let err = manager.advance_with_weights(1, &weights).unwrap_err();
1198 assert!(
1199 matches!(err, SamplerError::InvalidWeight { ref source_id, .. } if source_id == "unknown"),
1200 "expected InvalidWeight for unknown source, got {err:?}"
1201 );
1202 }
1203
1204 #[test]
1205 fn refresh_all_with_weights_rejects_negative_weight() {
1206 let mut manager = IngestionManager::new(4, SamplerConfig::default());
1207 manager
1208 .register_source(Box::new(ScriptedSource::new(
1209 "src",
1210 Arc::new(AtomicUsize::new(0)),
1211 vec![],
1212 )))
1213 .unwrap();
1214
1215 let mut weights = HashMap::new();
1216 weights.insert("src".to_string(), -1.0);
1217
1218 let err = manager.refresh_all_with_weights(&weights).unwrap_err();
1219 assert!(
1220 matches!(err, SamplerError::InvalidWeight { ref source_id, .. } if source_id == "src"),
1221 "expected InvalidWeight for negative weight, got {err:?}"
1222 );
1223 }
1224
1225 #[test]
1226 fn force_refresh_all_with_weights_rejects_unknown_source() {
1227 let mut manager = IngestionManager::new(4, SamplerConfig::default());
1228 manager
1229 .register_source(Box::new(ScriptedSource::new(
1230 "real",
1231 Arc::new(AtomicUsize::new(0)),
1232 vec![],
1233 )))
1234 .unwrap();
1235
1236 let mut weights = HashMap::new();
1237 weights.insert("ghost".to_string(), 1.0);
1238
1239 let err = manager
1240 .force_refresh_all_with_weights(&weights)
1241 .unwrap_err();
1242 assert!(
1243 matches!(err, SamplerError::InvalidWeight { ref source_id, .. } if source_id == "ghost"),
1244 "expected InvalidWeight for unknown source, got {err:?}"
1245 );
1246 }
1247
1248 struct SeedCapturingSource {
1250 id: String,
1251 received_seeds: Arc<Mutex<Vec<u64>>>,
1252 }
1253
1254 impl SeedCapturingSource {
1255 fn new(id: &str, received_seeds: Arc<Mutex<Vec<u64>>>) -> Self {
1256 Self {
1257 id: id.to_string(),
1258 received_seeds,
1259 }
1260 }
1261 }
1262
1263 impl DataSource for SeedCapturingSource {
1264 fn id(&self) -> &str {
1265 &self.id
1266 }
1267
1268 fn refresh(
1269 &self,
1270 config: &SamplerConfig,
1271 _cursor: Option<&SourceCursor>,
1272 _limit: Option<usize>,
1273 ) -> Result<SourceSnapshot, SamplerError> {
1274 self.received_seeds
1275 .lock()
1276 .expect("seed lock poisoned")
1277 .push(config.seed);
1278 Ok(SourceSnapshot {
1279 records: Vec::new(),
1280 cursor: SourceCursor {
1281 last_seen: Utc::now(),
1282 revision: 0,
1283 },
1284 })
1285 }
1286
1287 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
1288 Ok(0)
1289 }
1290
1291 fn default_triplet_recipes(&self) -> Vec<crate::config::TripletRecipe> {
1292 Vec::new()
1293 }
1294 }
1295
#[test]
fn epoch_xor_changes_seed_received_by_source() {
    let base_seed = 0xDEAD_BEEF_u64;
    let config = SamplerConfig {
        seed: base_seed,
        ..SamplerConfig::default()
    };

    // Each call yields a fresh manager with one seed-capturing source plus the
    // shared log that source writes into.
    let build = || {
        let sink = Arc::new(Mutex::new(Vec::<u64>::new()));
        let mut manager = IngestionManager::new(4, config.clone());
        manager
            .register_source(Box::new(SeedCapturingSource::new("src", Arc::clone(&sink))))
            .unwrap();
        (manager, sink)
    };

    // First manager stays on the default epoch; the second is moved to epoch 1.
    let (mut first_manager, epoch0_sink) = build();
    first_manager.refresh_all();

    let (mut second_manager, epoch1_sink) = build();
    second_manager.set_epoch(1);
    second_manager.refresh_all();

    let epoch0_seeds = epoch0_sink.lock().unwrap();
    let epoch1_seeds = epoch1_sink.lock().unwrap();

    assert!(!epoch0_seeds.is_empty(), "epoch-0 source was never refreshed");
    assert!(!epoch1_seeds.is_empty(), "epoch-1 source was never refreshed");

    let (seed_at_epoch0, seed_at_epoch1) = (epoch0_seeds[0], epoch1_seeds[0]);

    assert_ne!(
        seed_at_epoch0, seed_at_epoch1,
        "epoch 0 and epoch 1 both produced seed {seed_at_epoch0:#x}; \
         derive_epoch_seed is not reaching the source"
    );

    // Each seed must equal exactly what derive_epoch_seed yields for its epoch.
    assert_eq!(
        seed_at_epoch0,
        derive_epoch_seed(base_seed, 0),
        "epoch-0 seed mismatch (epoch_step=0)"
    );
    assert_eq!(
        seed_at_epoch1,
        derive_epoch_seed(base_seed, 1),
        "epoch-1 seed mismatch (epoch_step=0)"
    );
}
1361
#[test]
fn epoch_step_resets_on_epoch_change() {
    let config = SamplerConfig::default();
    let seeds = Arc::new(Mutex::new(Vec::new()));

    let mut manager = IngestionManager::new(4, config.clone());
    manager
        .register_source(Box::new(SeedCapturingSource::new(
            "src",
            Arc::clone(&seeds),
        )))
        .unwrap();

    // Epoch 0, step 0: the source must see the plain epoch-0 seed.
    manager.refresh_all();
    let epoch0_seed = seeds.lock().unwrap()[0];
    assert_eq!(
        epoch0_seed,
        derive_epoch_seed(config.seed, 0),
        "epoch 0 step 0 seed"
    );

    // refresh_all does not advance the step, so it is still 0 here. Move it off
    // zero first; otherwise the post-`set_epoch` check would pass even if
    // `set_epoch` never reset the step.
    manager.set_epoch_step(3);
    assert_eq!(manager.epoch_step(), 3, "set_epoch_step must take effect");

    manager.set_epoch(1);
    assert_eq!(
        manager.epoch_step(),
        0,
        "set_epoch must reset epoch_step to 0"
    );
    seeds.lock().unwrap().clear();

    manager.refresh_all();
    let epoch1_seed = seeds.lock().unwrap()[0];
    assert_eq!(
        epoch1_seed,
        derive_epoch_seed(config.seed, 1),
        "epoch 1 step 0 seed (must be ^0 since refresh_all no longer bumps step)"
    );
}
1401
#[test]
fn epoch_step_survives_sampler_save_and_load_state() {
    let store = Arc::new(DeterministicSplitStore::new(SplitRatios::default(), 1).unwrap());

    let sampler = TripletSampler::new(
        SamplerConfig {
            seed: 42,
            ..SamplerConfig::default()
        },
        Arc::clone(&store),
    );
    let refreshes = Arc::new(AtomicUsize::new(0));
    sampler
        .register_source(Box::new(ScriptedSource::new(
            "src",
            refreshes,
            vec![Ok(SourceSnapshot {
                records: vec![make_record("r1", "src")],
                cursor: SourceCursor {
                    last_seen: Utc::now(),
                    revision: 1,
                },
            })],
        )))
        .unwrap();

    // The batch result itself is irrelevant; only the step it advances matters.
    let _ = sampler.next_text_batch(SplitLabel::Train);

    sampler.save_sampler_state(None).unwrap();

    let loaded = store.load_sampler_state().unwrap().unwrap();
    let step_saved = loaded.epoch_step;
    assert!(
        step_saved > 0,
        "epoch_step must be >0 after batch call through TripletSampler"
    );

    // Rebuild a bare manager and restore cursors plus the step from persisted state.
    let mut manager2 = IngestionManager::new(4, SamplerConfig::default());
    manager2
        .register_source(Box::new(ScriptedSource::new(
            "src",
            Arc::new(AtomicUsize::new(0)),
            vec![Ok(SourceSnapshot {
                records: vec![make_record("r2", "src")],
                cursor: SourceCursor {
                    last_seen: Utc::now(),
                    revision: 2,
                },
            })],
        )))
        .unwrap();
    manager2.load_cursors(&loaded.source_stream_cursors);
    manager2.set_epoch_step(loaded.epoch_step);

    let step_before = manager2.epoch_step();
    assert_eq!(
        step_before, step_saved,
        "set_epoch_step must restore epoch_step to saved value"
    );

    manager2.refresh_all();
    let step_after = manager2.epoch_step();
    // `\` continuation keeps the message on one line instead of embedding a
    // newline plus source indentation into it.
    assert_eq!(
        step_after, step_saved,
        "epoch_step must survive refresh_all without increment (step is per-batch, \
         not per-refresh): loaded {step_saved}, got {step_after}"
    );
}
1485
#[test]
fn scripted_and_panic_sources_cover_default_trait_paths() {
    let config = SamplerConfig::default();

    // A ScriptedSource built with an empty script: refresh must still return an
    // empty revision-0 snapshot.
    let scripted = ScriptedSource::new("scripted", Arc::new(AtomicUsize::new(0)), vec![]);
    let fallback = scripted
        .refresh(&config, None, None)
        .expect("fallback snapshot");
    assert!(fallback.records.is_empty());
    assert_eq!(fallback.cursor.revision, 0);

    let scripted_count = scripted
        .reported_record_count(&config)
        .expect("record count");
    assert_eq!(scripted_count, 0);
    assert!(scripted.default_triplet_recipes().is_empty());

    // PanicSource's record-count path must succeed and report zero.
    let panic_source = PanicSource {
        id: String::from("panic_count"),
    };
    let panic_count = panic_source
        .reported_record_count(&config)
        .expect("record count");
    assert_eq!(panic_count, 0);
}
1516
#[test]
fn seed_capturing_source_trait_defaults_are_exercised() {
    let sink = Arc::new(Mutex::new(Vec::new()));
    let source = SeedCapturingSource::new("seed_defaults", sink);

    let reported = source
        .reported_record_count(&SamplerConfig::default())
        .expect("record count");
    assert_eq!(reported, 0);
    assert!(source.default_triplet_recipes().is_empty());
}
1528
#[test]
fn refresh_paths_handle_zero_capacity_and_no_sources() {
    // Zero capacity: even a refresh that yields a record must leave caches empty.
    let mut zero_capacity_manager = IngestionManager::new(0, SamplerConfig::default());
    let scripted = ScriptedSource::new(
        "zero_capacity",
        Arc::new(AtomicUsize::new(0)),
        vec![Ok(SourceSnapshot {
            records: vec![make_record("r1", "zero_capacity")],
            cursor: SourceCursor {
                last_seen: Utc::now(),
                revision: 1,
            },
        })],
    );
    zero_capacity_manager
        .register_source(Box::new(scripted))
        .unwrap();
    zero_capacity_manager.refresh_all();
    assert!(zero_capacity_manager.all_caches_empty());

    // No sources: a weighted refresh with an empty weight map must succeed and
    // leave every cache empty.
    let mut sourceless_manager = IngestionManager::new(4, SamplerConfig::default());
    let no_weights = HashMap::new();
    sourceless_manager
        .refresh_all_with_weights(&no_weights)
        .expect("no sources should not error");
    assert!(sourceless_manager.all_caches_empty());
}
1556
#[test]
fn drain_start_rotates_fairly_across_sources() {
    /// Yields ten records on every refresh and counts how often it is refreshed.
    struct FairSource {
        id: String,
        refresh_count: Arc<AtomicUsize>,
    }

    impl DataSource for FairSource {
        fn id(&self) -> &str {
            &self.id
        }

        fn refresh(
            &self,
            _config: &SamplerConfig,
            _cursor: Option<&SourceCursor>,
            _limit: Option<usize>,
        ) -> Result<SourceSnapshot, SamplerError> {
            self.refresh_count.fetch_add(1, Ordering::SeqCst);
            let mut records = Vec::with_capacity(10);
            for i in 0..10 {
                records.push(make_record(&format!("r{i}"), &self.id));
            }
            let cursor = SourceCursor {
                last_seen: Utc::now(),
                revision: 1,
            };
            Ok(SourceSnapshot { records, cursor })
        }

        fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
            Ok(10)
        }
    }

    let counts: [Arc<AtomicUsize>; 3] = Default::default();

    let mut manager = IngestionManager::new(30, SamplerConfig::default());
    for (i, count) in counts.iter().enumerate() {
        manager
            .register_source(Box::new(FairSource {
                id: format!("src_{i}"),
                refresh_count: Arc::clone(count),
            }))
            .unwrap();
    }

    // The initial refresh touches every source exactly once.
    manager.refresh_all();
    for count in &counts {
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    for _ in 0..33 {
        manager.advance(1);
    }

    let (r0, r1, r2) = (
        counts[0].load(Ordering::SeqCst),
        counts[1].load(Ordering::SeqCst),
        counts[2].load(Ordering::SeqCst),
    );
    let lowest = r0.min(r1).min(r2);
    let highest = r0.max(r1).max(r2);
    assert!(
        highest <= lowest + 1,
        "sources should refresh at roughly the same rate: got {r0}/{r1}/{r2}"
    );
}
1652
#[test]
fn direct_drain_start_rotates_fairly_with_batch_2_of_5() {
    /// Emits eight section-less records per refresh and counts its refreshes.
    struct SimpleSource {
        id: String,
        refresh_count: Arc<AtomicUsize>,
    }

    impl DataSource for SimpleSource {
        fn id(&self) -> &str {
            &self.id
        }

        fn refresh(
            &self,
            _: &SamplerConfig,
            _: Option<&SourceCursor>,
            _: Option<usize>,
        ) -> Result<SourceSnapshot, SamplerError> {
            self.refresh_count.fetch_add(1, Ordering::SeqCst);
            let now = Utc::now();
            let mut records: Vec<DataRecord> = Vec::with_capacity(8);
            for i in 0..8 {
                records.push(DataRecord {
                    id: format!("{}_r{i}", self.id),
                    source: self.id.clone(),
                    created_at: now,
                    updated_at: now,
                    quality: QualityScore { trust: 1.0 },
                    taxonomy: Vec::new(),
                    sections: Vec::new(),
                    meta_prefix: None,
                });
            }
            let cursor = SourceCursor {
                last_seen: now,
                revision: 1,
            };
            Ok(SourceSnapshot { records, cursor })
        }

        fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
            Ok(8)
        }
    }

    let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
    let mut manager = IngestionManager::new(40, SamplerConfig::default());
    for (i, count) in counts.iter().enumerate() {
        let source = SimpleSource {
            id: format!("src_{i}"),
            refresh_count: Arc::clone(count),
        };
        manager.register_source(Box::new(source)).unwrap();
    }

    // Initial refresh: each of the five sources is refreshed exactly once.
    manager.refresh_all();
    for count in &counts {
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    (0..80).for_each(|_| {
        manager.advance(2);
    });

    let totals: Vec<usize> = counts
        .iter()
        .map(|count| count.load(Ordering::SeqCst))
        .collect();
    // Pinned regression value for the direct (non-sampler) drain path.
    assert_eq!(
        totals,
        vec![5, 5, 6, 6, 6],
        "direct manager: unexpected refresh distribution"
    );
}
1729
1730 fn make_five_source_sampler(
1733 counts: &[Arc<AtomicUsize>],
1734 ) -> TripletSampler<DeterministicSplitStore> {
1735 struct Tracked {
1736 id: String,
1737 refresh_count: Arc<AtomicUsize>,
1738 }
1739 impl DataSource for Tracked {
1740 fn id(&self) -> &str {
1741 &self.id
1742 }
1743 fn refresh(
1744 &self,
1745 _: &SamplerConfig,
1746 _: Option<&SourceCursor>,
1747 _: Option<usize>,
1748 ) -> Result<SourceSnapshot, SamplerError> {
1749 self.refresh_count.fetch_add(1, Ordering::SeqCst);
1750 let now = Utc::now();
1751 let records: Vec<DataRecord> = (0..8)
1752 .map(|i| DataRecord {
1753 id: format!("{}_r{i}", self.id),
1754 source: self.id.clone(),
1755 created_at: now,
1756 updated_at: now,
1757 quality: QualityScore { trust: 1.0 },
1758 taxonomy: Vec::new(),
1759 sections: vec![RecordSection {
1760 role: SectionRole::Anchor,
1761 heading: None,
1762 text: format!("x{i}"),
1763 sentences: vec![format!("x{i}")],
1764 }],
1765 meta_prefix: None,
1766 })
1767 .collect();
1768 Ok(SourceSnapshot {
1769 records,
1770 cursor: SourceCursor {
1771 last_seen: now,
1772 revision: 1,
1773 },
1774 })
1775 }
1776 fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
1777 Ok(8)
1778 }
1779 }
1780
1781 let config = SamplerConfig {
1782 batch_size: 2,
1783 text_recipes: vec![TextRecipe {
1784 name: "anchor".into(),
1785 selector: Selector::Role(SectionRole::Anchor),
1786 weight: 1.0,
1787 instruction: None,
1788 }],
1789 split: SplitRatios {
1790 train: 1.0,
1791 validation: 0.0,
1792 test: 0.0,
1793 },
1794 allowed_splits: vec![SplitLabel::Train],
1795 ingestion_max_records: 4,
1801 ..SamplerConfig::default()
1802 };
1803 let store = Arc::new(DeterministicSplitStore::new(config.split, 99).unwrap());
1804 let sampler = TripletSampler::new(config, store);
1805
1806 for (i, count) in counts.iter().enumerate() {
1807 sampler
1808 .register_source(Box::new(Tracked {
1809 id: format!("src_{i}"),
1810 refresh_count: Arc::clone(count),
1811 }))
1812 .unwrap();
1813 }
1814 sampler
1815 }
1816
#[test]
fn sampler_unweighted_drain_distributes_evenly() {
    let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
    let sampler = make_five_source_sampler(&counts);

    // One priming batch plus 80 more: 81 unweighted batches in total.
    for _ in 0..=80 {
        sampler.next_text_batch(SplitLabel::Train).unwrap();
    }

    let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
    let (min, max) = totals
        .iter()
        .fold((usize::MAX, usize::MIN), |(lo, hi), &t| (lo.min(t), hi.max(t)));
    assert!(
        max <= min + 1,
        "unweighted: all sources must refresh at roughly the same rate: {totals:?}"
    );
    assert!(
        min >= 4,
        "unweighted: each source should have refreshed at least 4 times: {totals:?}"
    );
}
1839
#[test]
fn sampler_weighted_drain_with_equal_weights_distributes_evenly() {
    let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
    let sampler = make_five_source_sampler(&counts);

    // Every source gets the same weight, so the drain should stay balanced.
    let weights: HashMap<String, f32> = (0..5).map(|i| (format!("src_{i}"), 1.0)).collect();

    // One priming batch plus 80 more: 81 weighted batches in total.
    for _ in 0..=80 {
        sampler
            .next_text_batch_with_weights(SplitLabel::Train, &weights)
            .unwrap();
    }

    let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
    let (min, max) = totals
        .iter()
        .fold((usize::MAX, usize::MIN), |(lo, hi), &t| (lo.min(t), hi.max(t)));
    assert!(
        max <= min + 1,
        "weighted (equal): all sources must refresh at roughly the same rate: {totals:?}"
    );
    assert!(
        min >= 4,
        "weighted (equal): each source should have refreshed at least 4 times: {totals:?}"
    );
}
1871
#[test]
fn sampler_unweighted_and_weighted_match_distribution() {
    // Two identically configured samplers: one drained without weights, one with
    // equal weights. Each path is checked for balance on its own; the two totals
    // vectors are not compared against each other.
    let fresh_counts =
        || -> Vec<Arc<AtomicUsize>> { (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect() };
    let uc = fresh_counts();
    let wc = fresh_counts();
    let usampler = make_five_source_sampler(&uc);
    let wsampler = make_five_source_sampler(&wc);

    let weights: HashMap<String, f32> = (0..5).map(|i| (format!("src_{i}"), 1.0)).collect();

    // 81 interleaved rounds: one unweighted batch, then one weighted batch.
    for _ in 0..=80 {
        usampler.next_text_batch(SplitLabel::Train).unwrap();
        wsampler
            .next_text_batch_with_weights(SplitLabel::Train, &weights)
            .unwrap();
    }

    let read_all = |cs: &[Arc<AtomicUsize>]| -> Vec<usize> {
        cs.iter().map(|c| c.load(Ordering::SeqCst)).collect()
    };
    let ut = read_all(&uc);
    let wt = read_all(&wc);

    let spread = |v: &[usize]| (*v.iter().max().unwrap(), *v.iter().min().unwrap());
    let (umax, umin) = spread(&ut);
    let (wmax, wmin) = spread(&wt);
    assert!(umax <= umin + 1, "unweighted: {ut:?}");
    assert!(wmax <= wmin + 1, "weighted equal: {wt:?}");
}
1902
#[test]
fn sampler_weighted_drain_with_unequal_weights_respects_ratios() {
    let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
    let sampler = make_five_source_sampler(&counts);

    // src_3 gets double weight; every other source stays at 1.0.
    let weights: HashMap<String, f32> = (0..5)
        .map(|i| (format!("src_{i}"), if i == 3 { 2.0 } else { 1.0 }))
        .collect();

    // One priming batch plus 200 more: 201 weighted batches in total.
    for _ in 0..=200 {
        sampler
            .next_text_batch_with_weights(SplitLabel::Train, &weights)
            .unwrap();
    }

    let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
    let heavy = totals[3];
    let heavy_leads = totals
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != 3)
        .all(|(_, &t)| t < heavy);
    assert!(
        heavy_leads,
        "src_3 (w=2.0) must outpace all w=1.0 sources (totals: {totals:?})"
    );
    // Pinned regression value for this weight configuration.
    assert_eq!(
        totals,
        vec![11, 11, 11, 31, 17],
        "unequal-weights: unexpected refresh distribution"
    );
}
1952
#[test]
fn register_source_rejects_reserved_id_pattern() {
    use crate::constants::splits::is_reserved_source_id;

    let mut manager = IngestionManager::new(4, SamplerConfig::default());

    // Per these cases, ids of the form `__name__` with a non-empty name are
    // reserved; bare, one-sided, or single-underscore wrappings are not.
    for reserved in ["__meta__", "__anything__", "__x__"] {
        assert!(is_reserved_source_id(reserved));
    }
    for allowed in [
        "",
        "__",
        "___",
        "normal_source",
        "_prefix_suffix_",
        "__unclosed",
        "unopened__",
    ] {
        assert!(!is_reserved_source_id(allowed));
    }

    let outcome = manager.register_source(Box::new(ScriptedSource::new(
        "__reserved__",
        Arc::new(AtomicUsize::new(0)),
        vec![],
    )));
    let err = match outcome {
        Err(err) => err,
        Ok(_) => panic!("register_source should reject reserved source id"),
    };
    let is_reserved_error =
        matches!(&err, SamplerError::ReservedSourceId(id) if id == "__reserved__");
    assert!(
        is_reserved_error,
        "expected ReservedSourceId error, got: {err}"
    );

    // A rejected registration must not leave a source behind.
    assert!(!manager.has_sources());

    manager
        .register_source(Box::new(ScriptedSource::new(
            "valid_source",
            Arc::new(AtomicUsize::new(0)),
            vec![],
        )))
        .unwrap();
    assert!(manager.has_sources());
}
2000}